Skip to main content

vector_ta/indicators/
natr.rs

1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{
4    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
5    make_uninit_matrix,
6};
7use aligned_vec::{AVec, CACHELINE_ALIGN};
8#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
9use core::arch::x86_64::*;
10#[cfg(not(target_arch = "wasm32"))]
11use rayon::prelude::*;
12use std::convert::AsRef;
13use thiserror::Error;
14
15#[cfg(feature = "python")]
16use crate::utilities::kernel_validation::validate_kernel;
17#[cfg(feature = "python")]
18use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
19#[cfg(feature = "python")]
20use pyo3::exceptions::PyValueError;
21#[cfg(feature = "python")]
22use pyo3::prelude::*;
23#[cfg(feature = "python")]
24use pyo3::types::PyDict;
25#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
26use serde::{Deserialize, Serialize};
27#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
28use wasm_bindgen::prelude::*;
29
30#[derive(Debug, Clone)]
31pub enum NatrData<'a> {
32    Candles {
33        candles: &'a Candles,
34    },
35    Slices {
36        high: &'a [f64],
37        low: &'a [f64],
38        close: &'a [f64],
39    },
40}
41
/// Result of a single NATR computation: one value per input bar
/// (NaN during the warm-up prefix).
#[derive(Debug, Clone)]
pub struct NatrOutput {
    /// NATR values, aligned index-for-index with the input series.
    pub values: Vec<f64>,
}
46
/// Tunable parameters for NATR.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct NatrParams {
    /// ATR window length; `None` means "use the default" (14).
    pub period: Option<usize>,
}

impl Default for NatrParams {
    /// The conventional 14-bar period.
    fn default() -> Self {
        NatrParams { period: Some(14) }
    }
}
61
62#[derive(Debug, Clone)]
63pub struct NatrInput<'a> {
64    pub data: NatrData<'a>,
65    pub params: NatrParams,
66}
67
68impl<'a> NatrInput<'a> {
69    #[inline]
70    pub fn from_candles(candles: &'a Candles, params: NatrParams) -> Self {
71        Self {
72            data: NatrData::Candles { candles },
73            params,
74        }
75    }
76    #[inline]
77    pub fn from_slices(
78        high: &'a [f64],
79        low: &'a [f64],
80        close: &'a [f64],
81        params: NatrParams,
82    ) -> Self {
83        Self {
84            data: NatrData::Slices { high, low, close },
85            params,
86        }
87    }
88    #[inline]
89    pub fn with_default_candles(candles: &'a Candles) -> Self {
90        Self {
91            data: NatrData::Candles { candles },
92            params: NatrParams::default(),
93        }
94    }
95    #[inline]
96    pub fn get_period(&self) -> usize {
97        self.params.period.unwrap_or(14)
98    }
99}
100
101#[derive(Copy, Clone, Debug)]
102pub struct NatrBuilder {
103    period: Option<usize>,
104    kernel: Kernel,
105}
106
107impl Default for NatrBuilder {
108    fn default() -> Self {
109        Self {
110            period: None,
111            kernel: Kernel::Auto,
112        }
113    }
114}
115
116impl NatrBuilder {
117    #[inline(always)]
118    pub fn new() -> Self {
119        Self::default()
120    }
121    #[inline(always)]
122    pub fn period(mut self, n: usize) -> Self {
123        self.period = Some(n);
124        self
125    }
126    #[inline(always)]
127    pub fn kernel(mut self, k: Kernel) -> Self {
128        self.kernel = k;
129        self
130    }
131    #[inline(always)]
132    pub fn apply(self, c: &Candles) -> Result<NatrOutput, NatrError> {
133        let p = NatrParams {
134            period: self.period,
135        };
136        let i = NatrInput::from_candles(c, p);
137        natr_with_kernel(&i, self.kernel)
138    }
139    #[inline(always)]
140    pub fn apply_slices(
141        self,
142        high: &[f64],
143        low: &[f64],
144        close: &[f64],
145    ) -> Result<NatrOutput, NatrError> {
146        let p = NatrParams {
147            period: self.period,
148        };
149        let i = NatrInput::from_slices(high, low, close, p);
150        natr_with_kernel(&i, self.kernel)
151    }
152    #[inline(always)]
153    pub fn into_stream(self) -> Result<NatrStream, NatrError> {
154        let p = NatrParams {
155            period: self.period,
156        };
157        NatrStream::try_new(p)
158    }
159}
160
161#[derive(Debug, Error)]
162pub enum NatrError {
163    #[error("natr: Empty data provided for NATR.")]
164    EmptyInputData,
165    #[error("natr: All values are NaN.")]
166    AllValuesNaN,
167    #[error("natr: Empty data provided for NATR.")]
168    EmptyData,
169    #[error("natr: Invalid period: period = {period}, data length = {data_len}")]
170    InvalidPeriod { period: usize, data_len: usize },
171    #[error("natr: Not enough valid data: needed = {needed}, valid = {valid}")]
172    NotEnoughValidData { needed: usize, valid: usize },
173    #[error("natr: Output length mismatch: expected = {expected}, got = {got}")]
174    OutputLengthMismatch { expected: usize, got: usize },
175    #[error("natr: Invalid range: start={start}, end={end}, step={step}")]
176    InvalidRange {
177        start: String,
178        end: String,
179        step: String,
180    },
181    #[error("natr: Invalid kernel for batch: {0:?}")]
182    InvalidKernelForBatch(Kernel),
183    #[error("natr: Mismatched lengths: expected = {expected}, actual = {actual}")]
184    MismatchedLength { expected: usize, actual: usize },
185}
186
187#[inline]
188pub fn natr(input: &NatrInput) -> Result<NatrOutput, NatrError> {
189    natr_with_kernel(input, Kernel::Auto)
190}
191
192pub fn natr_with_kernel(input: &NatrInput, kernel: Kernel) -> Result<NatrOutput, NatrError> {
193    let (high, low, close) = match &input.data {
194        NatrData::Candles { candles } => {
195            let high = source_type(candles, "high");
196            let low = source_type(candles, "low");
197            let close = source_type(candles, "close");
198            (high, low, close)
199        }
200        NatrData::Slices { high, low, close } => (*high, *low, *close),
201    };
202
203    if high.is_empty() || low.is_empty() || close.is_empty() {
204        return Err(NatrError::EmptyInputData);
205    }
206
207    let len_h = high.len();
208    let len_l = low.len();
209    let len_c = close.len();
210    if len_h != len_l || len_h != len_c {
211        return Err(NatrError::MismatchedLength {
212            expected: len_h,
213            actual: if len_l != len_h { len_l } else { len_c },
214        });
215    }
216    let len = len_h;
217
218    let period = input.get_period();
219    if period == 0 || period > len {
220        return Err(NatrError::InvalidPeriod {
221            period,
222            data_len: len,
223        });
224    }
225
226    let first_valid_idx = {
227        let first_valid_idx_h = high.iter().position(|&x| !x.is_nan());
228        let first_valid_idx_l = low.iter().position(|&x| !x.is_nan());
229        let first_valid_idx_c = close.iter().position(|&x| !x.is_nan());
230
231        match (first_valid_idx_h, first_valid_idx_l, first_valid_idx_c) {
232            (Some(h), Some(l), Some(c)) => Some(h.max(l).max(c)),
233            _ => None,
234        }
235    };
236
237    let first_valid_idx = match first_valid_idx {
238        Some(idx) => idx,
239        None => return Err(NatrError::AllValuesNaN),
240    };
241
242    if (len - first_valid_idx) < period {
243        return Err(NatrError::NotEnoughValidData {
244            needed: period,
245            valid: len - first_valid_idx,
246        });
247    }
248
249    let mut out = alloc_with_nan_prefix(len, first_valid_idx + period - 1);
250
251    let chosen = match kernel {
252        Kernel::Auto => Kernel::Scalar,
253        other => other,
254    };
255
256    unsafe {
257        match chosen {
258            Kernel::Scalar | Kernel::ScalarBatch => {
259                natr_scalar(high, low, close, period, first_valid_idx, &mut out)
260            }
261            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
262            Kernel::Avx2 | Kernel::Avx2Batch => {
263                natr_avx2(high, low, close, period, first_valid_idx, &mut out)
264            }
265            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
266            Kernel::Avx512 | Kernel::Avx512Batch => {
267                natr_avx512(high, low, close, period, first_valid_idx, &mut out)
268            }
269            _ => unreachable!(),
270        }
271    }
272
273    Ok(NatrOutput { values: out })
274}
275
276#[inline]
277pub fn natr_scalar(
278    high: &[f64],
279    low: &[f64],
280    close: &[f64],
281    period: usize,
282    first: usize,
283    out: &mut [f64],
284) {
285    let len = out.len();
286    if first >= len {
287        return;
288    }
289
290    let inv_p: f64 = 1.0 / (period as f64);
291    let k100: f64 = 100.0;
292
293    let warm_end = first + period - 1;
294    let mut sum_tr = 0.0;
295
296    sum_tr += high[first] - low[first];
297    for i in (first + 1)..=warm_end {
298        let hi = high[i];
299        let lo = low[i];
300        let pc = close[i - 1];
301        let tr = (hi - lo).max((hi - pc).abs()).max((lo - pc).abs());
302        sum_tr += tr;
303    }
304
305    let mut atr = sum_tr * inv_p;
306    let c_we = close[warm_end];
307    out[warm_end] = if c_we.is_finite() && c_we != 0.0 {
308        (atr / c_we) * k100
309    } else {
310        f64::NAN
311    };
312
313    let mut idx = warm_end + 1;
314    while idx + 3 < len {
315        let pc0 = close[idx - 1];
316        let pc1 = close[idx + 0];
317        let pc2 = close[idx + 1];
318        let pc3 = close[idx + 2];
319
320        let h0 = high[idx + 0];
321        let h1 = high[idx + 1];
322        let h2 = high[idx + 2];
323        let h3 = high[idx + 3];
324        let l0 = low[idx + 0];
325        let l1 = low[idx + 1];
326        let l2 = low[idx + 2];
327        let l3 = low[idx + 3];
328
329        let tr0 = (h0 - l0).max((h0 - pc0).abs()).max((l0 - pc0).abs());
330        let tr1 = (h1 - l1).max((h1 - pc1).abs()).max((l1 - pc1).abs());
331        let tr2 = (h2 - l2).max((h2 - pc2).abs()).max((l2 - pc2).abs());
332        let tr3 = (h3 - l3).max((h3 - pc3).abs()).max((l3 - pc3).abs());
333
334        atr = (tr0 - atr).mul_add(inv_p, atr);
335        let c0 = close[idx + 0];
336        out[idx + 0] = if c0.is_finite() && c0 != 0.0 {
337            (atr / c0) * k100
338        } else {
339            f64::NAN
340        };
341
342        atr = (tr1 - atr).mul_add(inv_p, atr);
343        let c1v = close[idx + 1];
344        out[idx + 1] = if c1v.is_finite() && c1v != 0.0 {
345            (atr / c1v) * k100
346        } else {
347            f64::NAN
348        };
349
350        atr = (tr2 - atr).mul_add(inv_p, atr);
351        let c2v = close[idx + 2];
352        out[idx + 2] = if c2v.is_finite() && c2v != 0.0 {
353            (atr / c2v) * k100
354        } else {
355            f64::NAN
356        };
357
358        atr = (tr3 - atr).mul_add(inv_p, atr);
359        let c3v = close[idx + 3];
360        out[idx + 3] = if c3v.is_finite() && c3v != 0.0 {
361            (atr / c3v) * k100
362        } else {
363            f64::NAN
364        };
365
366        idx += 4;
367    }
368
369    while idx < len {
370        let hi = high[idx];
371        let lo = low[idx];
372        let pc = close[idx - 1];
373        let tr = (hi - lo).max((hi - pc).abs()).max((lo - pc).abs());
374        atr = (tr - atr).mul_add(inv_p, atr);
375
376        let cv = close[idx];
377        out[idx] = if cv.is_finite() && cv != 0.0 {
378            (atr / cv) * k100
379        } else {
380            f64::NAN
381        };
382        idx += 1;
383    }
384}
385
386#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
387#[inline]
388pub fn natr_avx512(
389    high: &[f64],
390    low: &[f64],
391    close: &[f64],
392    period: usize,
393    first: usize,
394    out: &mut [f64],
395) {
396    unsafe {
397        natr_avx512_body(high, low, close, period, first, out);
398    }
399}
400
401#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
402#[target_feature(enable = "avx2")]
403#[inline]
404pub unsafe fn natr_avx2(
405    high: &[f64],
406    low: &[f64],
407    close: &[f64],
408    period: usize,
409    first: usize,
410    out: &mut [f64],
411) {
412    use core::arch::x86_64::*;
413    debug_assert!(high.len() == low.len() && high.len() == close.len() && high.len() == out.len());
414    let len = out.len();
415    if first >= len {
416        return;
417    }
418
419    let inv_p: f64 = 1.0 / (period as f64);
420    let k100: f64 = 100.0;
421
422    let h = high.as_ptr();
423    let l = low.as_ptr();
424    let c = close.as_ptr();
425    let o = out.as_mut_ptr();
426
427    let mut i = first;
428    let mut sum_tr = *h.add(i) - *l.add(i);
429    i += 1;
430    let warm_end = first + period - 1;
431
432    let sign = _mm256_set1_pd(-0.0_f64);
433    while i + 3 <= warm_end {
434        let vh = _mm256_loadu_pd(h.add(i));
435        let vl = _mm256_loadu_pd(l.add(i));
436        let vpc = _mm256_loadu_pd(c.add(i - 1));
437
438        let vhl = _mm256_sub_pd(vh, vl);
439        let vhc = _mm256_sub_pd(vh, vpc);
440        let vlc = _mm256_sub_pd(vl, vpc);
441        let abs_hc = _mm256_andnot_pd(sign, vhc);
442        let abs_lc = _mm256_andnot_pd(sign, vlc);
443        let mx1 = _mm256_max_pd(vhl, abs_hc);
444        let vtr = _mm256_max_pd(mx1, abs_lc);
445
446        let mut tmp = [0.0f64; 4];
447        _mm256_storeu_pd(tmp.as_mut_ptr(), vtr);
448        sum_tr += tmp[0] + tmp[1] + tmp[2] + tmp[3];
449
450        i += 4;
451    }
452    while i <= warm_end {
453        let hi = *h.add(i);
454        let lo = *l.add(i);
455        let pc = *c.add(i - 1);
456        let tr = (hi - lo).max((hi - pc).abs()).max((lo - pc).abs());
457        sum_tr += tr;
458        i += 1;
459    }
460
461    let mut atr = sum_tr * inv_p;
462    let c_we = *c.add(warm_end);
463    *o.add(warm_end) = if c_we.is_finite() && c_we != 0.0 {
464        (atr / c_we) * k100
465    } else {
466        f64::NAN
467    };
468
469    let mut idx = warm_end + 1;
470    while idx + 3 < len {
471        let vh = _mm256_loadu_pd(h.add(idx));
472        let vl = _mm256_loadu_pd(l.add(idx));
473        let vpc = _mm256_loadu_pd(c.add(idx - 1));
474
475        let vhl = _mm256_sub_pd(vh, vl);
476        let vhc = _mm256_sub_pd(vh, vpc);
477        let vlc = _mm256_sub_pd(vl, vpc);
478        let abs_hc = _mm256_andnot_pd(sign, vhc);
479        let abs_lc = _mm256_andnot_pd(sign, vlc);
480        let mx1 = _mm256_max_pd(vhl, abs_hc);
481        let vtr = _mm256_max_pd(mx1, abs_lc);
482
483        let mut tr = [0.0f64; 4];
484        _mm256_storeu_pd(tr.as_mut_ptr(), vtr);
485
486        atr = (tr[0] - atr).mul_add(inv_p, atr);
487        let c0 = *c.add(idx + 0);
488        *o.add(idx + 0) = if c0.is_finite() && c0 != 0.0 {
489            (atr / c0) * k100
490        } else {
491            f64::NAN
492        };
493
494        atr = (tr[1] - atr).mul_add(inv_p, atr);
495        let c1 = *c.add(idx + 1);
496        *o.add(idx + 1) = if c1.is_finite() && c1 != 0.0 {
497            (atr / c1) * k100
498        } else {
499            f64::NAN
500        };
501
502        atr = (tr[2] - atr).mul_add(inv_p, atr);
503        let c2 = *c.add(idx + 2);
504        *o.add(idx + 2) = if c2.is_finite() && c2 != 0.0 {
505            (atr / c2) * k100
506        } else {
507            f64::NAN
508        };
509
510        atr = (tr[3] - atr).mul_add(inv_p, atr);
511        let c3 = *c.add(idx + 3);
512        *o.add(idx + 3) = if c3.is_finite() && c3 != 0.0 {
513            (atr / c3) * k100
514        } else {
515            f64::NAN
516        };
517
518        idx += 4;
519    }
520    while idx < len {
521        let hi = *h.add(idx);
522        let lo = *l.add(idx);
523        let pc = *c.add(idx - 1);
524        let tr = (hi - lo).max((hi - pc).abs()).max((lo - pc).abs());
525        atr = (tr - atr).mul_add(inv_p, atr);
526
527        let cv = *c.add(idx);
528        *o.add(idx) = if cv.is_finite() && cv != 0.0 {
529            (atr / cv) * k100
530        } else {
531            f64::NAN
532        };
533        idx += 1;
534    }
535}
536
537#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
538#[inline]
539#[target_feature(enable = "avx512f")]
540unsafe fn natr_avx512_body(
541    high: &[f64],
542    low: &[f64],
543    close: &[f64],
544    period: usize,
545    first: usize,
546    out: &mut [f64],
547) {
548    use core::arch::x86_64::*;
549    debug_assert!(high.len() == low.len() && high.len() == close.len() && high.len() == out.len());
550    let len = out.len();
551    if first >= len {
552        return;
553    }
554
555    let inv_p: f64 = 1.0 / (period as f64);
556    let k100: f64 = 100.0;
557
558    let h = high.as_ptr();
559    let l = low.as_ptr();
560    let c = close.as_ptr();
561    let o = out.as_mut_ptr();
562
563    let mut i = first;
564    let mut sum_tr = *h.add(i) - *l.add(i);
565    i += 1;
566    let warm_end = first + period - 1;
567
568    let sign = _mm512_set1_pd(-0.0_f64);
569    while i + 7 <= warm_end {
570        let vh = _mm512_loadu_pd(h.add(i));
571        let vl = _mm512_loadu_pd(l.add(i));
572        let vpc = _mm512_loadu_pd(c.add(i - 1));
573
574        let vhl = _mm512_sub_pd(vh, vl);
575        let vhc = _mm512_sub_pd(vh, vpc);
576        let vlc = _mm512_sub_pd(vl, vpc);
577        let abs_hc = _mm512_andnot_pd(sign, vhc);
578        let abs_lc = _mm512_andnot_pd(sign, vlc);
579        let mx1 = _mm512_max_pd(vhl, abs_hc);
580        let vtr = _mm512_max_pd(mx1, abs_lc);
581
582        let mut tmp = [0.0f64; 8];
583        _mm512_storeu_pd(tmp.as_mut_ptr(), vtr);
584        sum_tr += tmp.iter().sum::<f64>();
585
586        i += 8;
587    }
588    while i <= warm_end {
589        let hi = *h.add(i);
590        let lo = *l.add(i);
591        let pc = *c.add(i - 1);
592        let tr = (hi - lo).max((hi - pc).abs()).max((lo - pc).abs());
593        sum_tr += tr;
594        i += 1;
595    }
596
597    let mut atr = sum_tr * inv_p;
598    let c_we = *c.add(warm_end);
599    *o.add(warm_end) = if c_we.is_finite() && c_we != 0.0 {
600        (atr / c_we) * k100
601    } else {
602        f64::NAN
603    };
604
605    let mut idx = warm_end + 1;
606    while idx + 7 < len {
607        let vh = _mm512_loadu_pd(h.add(idx));
608        let vl = _mm512_loadu_pd(l.add(idx));
609        let vpc = _mm512_loadu_pd(c.add(idx - 1));
610
611        let vhl = _mm512_sub_pd(vh, vl);
612        let vhc = _mm512_sub_pd(vh, vpc);
613        let vlc = _mm512_sub_pd(vl, vpc);
614        let abs_hc = _mm512_andnot_pd(sign, vhc);
615        let abs_lc = _mm512_andnot_pd(sign, vlc);
616        let mx1 = _mm512_max_pd(vhl, abs_hc);
617        let vtr = _mm512_max_pd(mx1, abs_lc);
618
619        let mut tr = [0.0f64; 8];
620        _mm512_storeu_pd(tr.as_mut_ptr(), vtr);
621
622        atr = (tr[0] - atr).mul_add(inv_p, atr);
623        let c0 = *c.add(idx + 0);
624        *o.add(idx + 0) = if c0.is_finite() && c0 != 0.0 {
625            (atr / c0) * k100
626        } else {
627            f64::NAN
628        };
629        atr = (tr[1] - atr).mul_add(inv_p, atr);
630        let c1 = *c.add(idx + 1);
631        *o.add(idx + 1) = if c1.is_finite() && c1 != 0.0 {
632            (atr / c1) * k100
633        } else {
634            f64::NAN
635        };
636        atr = (tr[2] - atr).mul_add(inv_p, atr);
637        let c2 = *c.add(idx + 2);
638        *o.add(idx + 2) = if c2.is_finite() && c2 != 0.0 {
639            (atr / c2) * k100
640        } else {
641            f64::NAN
642        };
643        atr = (tr[3] - atr).mul_add(inv_p, atr);
644        let c3 = *c.add(idx + 3);
645        *o.add(idx + 3) = if c3.is_finite() && c3 != 0.0 {
646            (atr / c3) * k100
647        } else {
648            f64::NAN
649        };
650        atr = (tr[4] - atr).mul_add(inv_p, atr);
651        let c4 = *c.add(idx + 4);
652        *o.add(idx + 4) = if c4.is_finite() && c4 != 0.0 {
653            (atr / c4) * k100
654        } else {
655            f64::NAN
656        };
657        atr = (tr[5] - atr).mul_add(inv_p, atr);
658        let c5 = *c.add(idx + 5);
659        *o.add(idx + 5) = if c5.is_finite() && c5 != 0.0 {
660            (atr / c5) * k100
661        } else {
662            f64::NAN
663        };
664        atr = (tr[6] - atr).mul_add(inv_p, atr);
665        let c6 = *c.add(idx + 6);
666        *o.add(idx + 6) = if c6.is_finite() && c6 != 0.0 {
667            (atr / c6) * k100
668        } else {
669            f64::NAN
670        };
671        atr = (tr[7] - atr).mul_add(inv_p, atr);
672        let c7 = *c.add(idx + 7);
673        *o.add(idx + 7) = if c7.is_finite() && c7 != 0.0 {
674            (atr / c7) * k100
675        } else {
676            f64::NAN
677        };
678
679        idx += 8;
680    }
681    while idx < len {
682        let hi = *h.add(idx);
683        let lo = *l.add(idx);
684        let pc = *c.add(idx - 1);
685        let tr = (hi - lo).max((hi - pc).abs()).max((lo - pc).abs());
686        atr = (tr - atr).mul_add(inv_p, atr);
687
688        let cv = *c.add(idx);
689        *o.add(idx) = if cv.is_finite() && cv != 0.0 {
690            (atr / cv) * k100
691        } else {
692            f64::NAN
693        };
694        idx += 1;
695    }
696}
697
698#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
699#[inline]
700pub unsafe fn natr_avx512_short(
701    high: &[f64],
702    low: &[f64],
703    close: &[f64],
704    period: usize,
705    first: usize,
706    out: &mut [f64],
707) {
708    natr_avx512_body(high, low, close, period, first, out);
709}
710
711#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
712#[inline]
713pub unsafe fn natr_avx512_long(
714    high: &[f64],
715    low: &[f64],
716    close: &[f64],
717    period: usize,
718    first: usize,
719    out: &mut [f64],
720) {
721    natr_avx512_body(high, low, close, period, first, out);
722}
723
724#[derive(Debug, Clone)]
725pub struct NatrStream {
726    period: usize,
727    alpha: f64,
728    k100: f64,
729
730    count: usize,
731    sum_tr: f64,
732    atr: f64,
733    prev_close: f64,
734    have_prev: bool,
735    ready: bool,
736}
737
738impl NatrStream {
739    #[inline(always)]
740    pub fn try_new(params: NatrParams) -> Result<Self, NatrError> {
741        let period = params.period.unwrap_or(14);
742        if period == 0 {
743            return Err(NatrError::InvalidPeriod {
744                period,
745                data_len: 0,
746            });
747        }
748        Ok(Self {
749            period,
750            alpha: 1.0 / (period as f64),
751            k100: 100.0,
752            count: 0,
753            sum_tr: 0.0,
754            atr: 0.0,
755            prev_close: 0.0,
756            have_prev: false,
757            ready: false,
758        })
759    }
760
761    #[inline(always)]
762    pub fn update(&mut self, high: f64, low: f64, close: f64) -> Option<f64> {
763        let tr = if self.have_prev {
764            let pc = self.prev_close;
765            high.max(pc) - low.min(pc)
766        } else {
767            high - low
768        };
769
770        self.prev_close = close;
771        self.have_prev = true;
772
773        self.count += 1;
774
775        if !self.ready {
776            self.sum_tr += tr;
777            if self.count == self.period {
778                self.atr = self.sum_tr / (self.period as f64);
779                self.ready = true;
780            } else {
781                return None;
782            }
783        } else {
784            self.atr = (tr - self.atr).mul_add(self.alpha, self.atr);
785        }
786
787        if close.is_finite() && close != 0.0 {
788            Some((self.atr / close) * self.k100)
789        } else {
790            Some(f64::NAN)
791        }
792    }
793}
794
795#[derive(Clone, Debug)]
796pub struct NatrBatchRange {
797    pub period: (usize, usize, usize),
798}
799
800impl Default for NatrBatchRange {
801    fn default() -> Self {
802        Self {
803            period: (14, 263, 1),
804        }
805    }
806}
807
808#[derive(Clone, Debug, Default)]
809pub struct NatrBatchBuilder {
810    range: NatrBatchRange,
811    kernel: Kernel,
812}
813
814impl NatrBatchBuilder {
815    pub fn new() -> Self {
816        Self::default()
817    }
818    pub fn kernel(mut self, k: Kernel) -> Self {
819        self.kernel = k;
820        self
821    }
822    #[inline]
823    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
824        self.range.period = (start, end, step);
825        self
826    }
827    #[inline]
828    pub fn period_static(mut self, p: usize) -> Self {
829        self.range.period = (p, p, 0);
830        self
831    }
832    pub fn apply_slices(
833        self,
834        high: &[f64],
835        low: &[f64],
836        close: &[f64],
837    ) -> Result<NatrBatchOutput, NatrError> {
838        natr_batch_with_kernel(high, low, close, &self.range, self.kernel)
839    }
840    pub fn apply_candles(self, c: &Candles) -> Result<NatrBatchOutput, NatrError> {
841        let high = source_type(c, "high");
842        let low = source_type(c, "low");
843        let close = source_type(c, "close");
844        self.apply_slices(high, low, close)
845    }
846    pub fn with_default_candles(c: &Candles, k: Kernel) -> Result<NatrBatchOutput, NatrError> {
847        NatrBatchBuilder::new().kernel(k).apply_candles(c)
848    }
849}
850
851pub fn natr_batch_with_kernel(
852    high: &[f64],
853    low: &[f64],
854    close: &[f64],
855    sweep: &NatrBatchRange,
856    k: Kernel,
857) -> Result<NatrBatchOutput, NatrError> {
858    let kernel = match k {
859        Kernel::Auto => detect_best_batch_kernel(),
860        other if other.is_batch() => other,
861        _ => {
862            return Err(NatrError::InvalidKernelForBatch(k));
863        }
864    };
865    let simd = match kernel {
866        Kernel::Avx512Batch => Kernel::Avx512,
867        Kernel::Avx2Batch => Kernel::Avx2,
868        Kernel::ScalarBatch => Kernel::Scalar,
869        _ => unreachable!(),
870    };
871    natr_batch_par_slice(high, low, close, sweep, simd)
872}
873
874#[derive(Clone, Debug)]
875pub struct NatrBatchOutput {
876    pub values: Vec<f64>,
877    pub combos: Vec<NatrParams>,
878    pub rows: usize,
879    pub cols: usize,
880}
881impl NatrBatchOutput {
882    pub fn row_for_params(&self, p: &NatrParams) -> Option<usize> {
883        self.combos
884            .iter()
885            .position(|c| c.period.unwrap_or(14) == p.period.unwrap_or(14))
886    }
887    pub fn values_for(&self, p: &NatrParams) -> Option<&[f64]> {
888        self.row_for_params(p).map(|row| {
889            let start = row * self.cols;
890            &self.values[start..start + self.cols]
891        })
892    }
893}
894
895#[inline(always)]
896fn expand_grid(r: &NatrBatchRange) -> Result<Vec<NatrParams>, NatrError> {
897    fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, NatrError> {
898        if step == 0 || start == end {
899            return Ok(vec![start]);
900        }
901
902        let mut values = Vec::new();
903        let step_u = step;
904
905        if start <= end {
906            let mut v = start;
907            loop {
908                if v > end {
909                    break;
910                }
911                values.push(v);
912                match v.checked_add(step_u) {
913                    Some(next) => v = next,
914                    None => break,
915                }
916            }
917        } else {
918            let mut v = start;
919            loop {
920                if v < end {
921                    break;
922                }
923                values.push(v);
924                match v.checked_sub(step_u) {
925                    Some(next) => v = next,
926                    None => break,
927                }
928            }
929        }
930
931        if values.is_empty() {
932            return Err(NatrError::InvalidRange {
933                start: start.to_string(),
934                end: end.to_string(),
935                step: step.to_string(),
936            });
937        }
938
939        Ok(values)
940    }
941
942    let periods = axis_usize(r.period)?;
943
944    let mut out = Vec::with_capacity(periods.len());
945    for p in periods {
946        out.push(NatrParams { period: Some(p) });
947    }
948    Ok(out)
949}
950
951#[inline(always)]
952pub fn natr_batch_slice(
953    high: &[f64],
954    low: &[f64],
955    close: &[f64],
956    sweep: &NatrBatchRange,
957    kern: Kernel,
958) -> Result<NatrBatchOutput, NatrError> {
959    natr_batch_inner(high, low, close, sweep, kern, false)
960}
961
962#[inline(always)]
963pub fn natr_batch_par_slice(
964    high: &[f64],
965    low: &[f64],
966    close: &[f64],
967    sweep: &NatrBatchRange,
968    kern: Kernel,
969) -> Result<NatrBatchOutput, NatrError> {
970    natr_batch_inner(high, low, close, sweep, kern, true)
971}
972
973#[inline(always)]
974fn natr_batch_inner(
975    high: &[f64],
976    low: &[f64],
977    close: &[f64],
978    sweep: &NatrBatchRange,
979    kern: Kernel,
980    parallel: bool,
981) -> Result<NatrBatchOutput, NatrError> {
982    let combos = expand_grid(sweep)?;
983
984    let len_h = high.len();
985    let len_l = low.len();
986    let len_c = close.len();
987    if len_h != len_l || len_h != len_c {
988        return Err(NatrError::MismatchedLength {
989            expected: len_h,
990            actual: if len_l != len_h { len_l } else { len_c },
991        });
992    }
993    let len = len_h;
994
995    let first = high
996        .iter()
997        .position(|x| !x.is_nan())
998        .unwrap_or(0)
999        .max(low.iter().position(|x| !x.is_nan()).unwrap_or(0))
1000        .max(close.iter().position(|x| !x.is_nan()).unwrap_or(0));
1001    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1002    if len - first < max_p {
1003        return Err(NatrError::NotEnoughValidData {
1004            needed: max_p,
1005            valid: len - first,
1006        });
1007    }
1008    let rows = combos.len();
1009    let cols = len;
1010
1011    let mut buf_mu = make_uninit_matrix(rows, cols);
1012    let warm: Vec<usize> = combos
1013        .iter()
1014        .map(|c| first + c.period.unwrap() - 1)
1015        .collect();
1016    init_matrix_prefixes(&mut buf_mu, cols, &warm);
1017
1018    let mut buf_guard = core::mem::ManuallyDrop::new(buf_mu);
1019    let out: &mut [f64] = unsafe {
1020        core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
1021    };
1022
1023    let use_tr_shared = combos.len() >= 24;
1024
1025    if use_tr_shared {
1026        let mut tr: AVec<f64> = AVec::with_capacity(CACHELINE_ALIGN, len);
1027        tr.resize(len, 0.0);
1028        if first < len {
1029            tr[first] = high[first] - low[first];
1030            let mut i = first + 1;
1031            while i < len {
1032                let hi = high[i];
1033                let lo = low[i];
1034                let pc = close[i - 1];
1035                let trv = (hi - lo).max((hi - pc).abs()).max((lo - pc).abs());
1036                tr[i] = trv;
1037                i += 1;
1038            }
1039        }
1040
1041        let mut pref: AVec<f64> = AVec::with_capacity(CACHELINE_ALIGN, len + 1);
1042        pref.resize(len + 1, 0.0);
1043        for i in first..len {
1044            pref[i + 1] = pref[i] + tr[i];
1045        }
1046
1047        let do_row = |row: usize, out_row: &mut [f64]| {
1048            let period = combos[row].period.unwrap();
1049            let inv_p = 1.0 / (period as f64);
1050            let k100 = 100.0;
1051            let warm_end = first + period - 1;
1052
1053            let sum_tr = pref[warm_end + 1] - pref[first];
1054            let mut atr = sum_tr * inv_p;
1055            let cw = close[warm_end];
1056            out_row[warm_end] = if cw.is_finite() && cw != 0.0 {
1057                (atr / cw) * k100
1058            } else {
1059                f64::NAN
1060            };
1061
1062            let mut i = warm_end + 1;
1063            while i < len {
1064                atr = (tr[i] - atr).mul_add(inv_p, atr);
1065                let cv = close[i];
1066                out_row[i] = if cv.is_finite() && cv != 0.0 {
1067                    (atr / cv) * k100
1068                } else {
1069                    f64::NAN
1070                };
1071                i += 1;
1072            }
1073        };
1074
1075        if parallel {
1076            #[cfg(not(target_arch = "wasm32"))]
1077            {
1078                out.par_chunks_mut(cols)
1079                    .enumerate()
1080                    .for_each(|(row, slice)| do_row(row, slice));
1081            }
1082
1083            #[cfg(target_arch = "wasm32")]
1084            {
1085                for (row, slice) in out.chunks_mut(cols).enumerate() {
1086                    do_row(row, slice);
1087                }
1088            }
1089        } else {
1090            for (row, slice) in out.chunks_mut(cols).enumerate() {
1091                do_row(row, slice);
1092            }
1093        }
1094    } else {
1095        let do_row = |row: usize, out_row: &mut [f64]| unsafe {
1096            let period = combos[row].period.unwrap();
1097            match kern {
1098                Kernel::Scalar => natr_row_scalar(high, low, close, first, period, out_row),
1099                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1100                Kernel::Avx2 => natr_row_avx2(high, low, close, first, period, out_row),
1101                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1102                Kernel::Avx512 => natr_row_avx512(high, low, close, first, period, out_row),
1103                _ => unreachable!(),
1104            }
1105        };
1106
1107        if parallel {
1108            #[cfg(not(target_arch = "wasm32"))]
1109            {
1110                out.par_chunks_mut(cols)
1111                    .enumerate()
1112                    .for_each(|(row, slice)| do_row(row, slice));
1113            }
1114
1115            #[cfg(target_arch = "wasm32")]
1116            {
1117                for (row, slice) in out.chunks_mut(cols).enumerate() {
1118                    do_row(row, slice);
1119                }
1120            }
1121        } else {
1122            for (row, slice) in out.chunks_mut(cols).enumerate() {
1123                do_row(row, slice);
1124            }
1125        }
1126    }
1127
1128    let values = unsafe {
1129        Vec::from_raw_parts(
1130            buf_guard.as_mut_ptr() as *mut f64,
1131            buf_guard.len(),
1132            buf_guard.capacity(),
1133        )
1134    };
1135
1136    Ok(NatrBatchOutput {
1137        values,
1138        combos,
1139        rows,
1140        cols,
1141    })
1142}
1143
1144#[inline(always)]
1145fn natr_batch_inner_into(
1146    high: &[f64],
1147    low: &[f64],
1148    close: &[f64],
1149    sweep: &NatrBatchRange,
1150    kern: Kernel,
1151    parallel: bool,
1152    out: &mut [f64],
1153) -> Result<Vec<NatrParams>, NatrError> {
1154    let combos = expand_grid(sweep)?;
1155
1156    let len_h = high.len();
1157    let len_l = low.len();
1158    let len_c = close.len();
1159    if len_h != len_l || len_h != len_c {
1160        return Err(NatrError::MismatchedLength {
1161            expected: len_h,
1162            actual: if len_l != len_h { len_l } else { len_c },
1163        });
1164    }
1165    let len = len_h;
1166
1167    let first = high
1168        .iter()
1169        .position(|x| !x.is_nan())
1170        .unwrap_or(0)
1171        .max(low.iter().position(|x| !x.is_nan()).unwrap_or(0))
1172        .max(close.iter().position(|x| !x.is_nan()).unwrap_or(0));
1173    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1174    if len - first < max_p {
1175        return Err(NatrError::NotEnoughValidData {
1176            needed: max_p,
1177            valid: len - first,
1178        });
1179    }
1180    let rows = combos.len();
1181    let cols = len;
1182
1183    for (row, combo) in combos.iter().enumerate() {
1184        let period = combo.period.unwrap();
1185        let warmup_end = first + period - 1;
1186        let row_start = row * cols;
1187        for i in 0..warmup_end.min(cols) {
1188            out[row_start + i] = f64::NAN;
1189        }
1190    }
1191
1192    let use_tr_shared = combos.len() >= 24;
1193    if use_tr_shared {
1194        let mut tr: AVec<f64> = AVec::with_capacity(CACHELINE_ALIGN, len);
1195        tr.resize(len, 0.0);
1196        if first < len {
1197            tr[first] = high[first] - low[first];
1198            let mut i = first + 1;
1199            while i < len {
1200                let hi = high[i];
1201                let lo = low[i];
1202                let pc = close[i - 1];
1203                let trv = (hi - lo).max((hi - pc).abs()).max((lo - pc).abs());
1204                tr[i] = trv;
1205                i += 1;
1206            }
1207        }
1208
1209        let mut pref: AVec<f64> = AVec::with_capacity(CACHELINE_ALIGN, len + 1);
1210        pref.resize(len + 1, 0.0);
1211        for i in first..len {
1212            pref[i + 1] = pref[i] + tr[i];
1213        }
1214
1215        let do_row = |row: usize, out_row: &mut [f64]| {
1216            let period = combos[row].period.unwrap();
1217            let inv_p = 1.0 / (period as f64);
1218            let k100 = 100.0;
1219            let warm_end = first + period - 1;
1220
1221            let sum_tr = pref[warm_end + 1] - pref[first];
1222            let mut atr = sum_tr * inv_p;
1223            let cw = close[warm_end];
1224            out_row[warm_end] = if cw.is_finite() && cw != 0.0 {
1225                (atr / cw) * k100
1226            } else {
1227                f64::NAN
1228            };
1229
1230            let mut i = warm_end + 1;
1231            while i < len {
1232                atr = (tr[i] - atr).mul_add(inv_p, atr);
1233                let cv = close[i];
1234                out_row[i] = if cv.is_finite() && cv != 0.0 {
1235                    (atr / cv) * k100
1236                } else {
1237                    f64::NAN
1238                };
1239                i += 1;
1240            }
1241        };
1242
1243        if parallel {
1244            #[cfg(not(target_arch = "wasm32"))]
1245            {
1246                out.par_chunks_mut(cols)
1247                    .enumerate()
1248                    .for_each(|(row, slice)| do_row(row, slice));
1249            }
1250
1251            #[cfg(target_arch = "wasm32")]
1252            {
1253                for (row, slice) in out.chunks_mut(cols).enumerate() {
1254                    do_row(row, slice);
1255                }
1256            }
1257        } else {
1258            for (row, slice) in out.chunks_mut(cols).enumerate() {
1259                do_row(row, slice);
1260            }
1261        }
1262    } else {
1263        let do_row = |row: usize, out_row: &mut [f64]| unsafe {
1264            let period = combos[row].period.unwrap();
1265            match kern {
1266                Kernel::Scalar => natr_row_scalar(high, low, close, first, period, out_row),
1267                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1268                Kernel::Avx2 => natr_row_avx2(high, low, close, first, period, out_row),
1269                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1270                Kernel::Avx512 => natr_row_avx512(high, low, close, first, period, out_row),
1271                _ => unreachable!(),
1272            }
1273        };
1274
1275        if parallel {
1276            #[cfg(not(target_arch = "wasm32"))]
1277            {
1278                out.par_chunks_mut(cols)
1279                    .enumerate()
1280                    .for_each(|(row, slice)| do_row(row, slice));
1281            }
1282
1283            #[cfg(target_arch = "wasm32")]
1284            {
1285                for (row, slice) in out.chunks_mut(cols).enumerate() {
1286                    do_row(row, slice);
1287                }
1288            }
1289        } else {
1290            for (row, slice) in out.chunks_mut(cols).enumerate() {
1291                do_row(row, slice);
1292            }
1293        }
1294    }
1295
1296    Ok(combos)
1297}
1298
1299#[inline(always)]
1300unsafe fn natr_row_scalar(
1301    high: &[f64],
1302    low: &[f64],
1303    close: &[f64],
1304    first: usize,
1305    period: usize,
1306    out: &mut [f64],
1307) {
1308    natr_scalar(high, low, close, period, first, out)
1309}
1310
/// Writes one row of NATR values into `out` from a precomputed true-range
/// series `tr`.
///
/// The ATR is seeded with the mean of `tr[first..=first + period - 1]`. It is then
/// updated for each later index as `atr = (tr[i] - atr) * (1 / period) + atr`.
/// Each output is `100 * atr / close[i]`, or NaN when `close[i]` is zero or not
/// finite. Indices before `first + period - 1` are left untouched, so the caller
/// can pre-fill its own warm-up prefix.
///
/// Nothing is written when `period == 0` or when the warm-up window
/// `first..first + period` does not fit inside `out`.
///
/// # Safety
/// The body performs no unsafe operations; it is `unsafe` only to match the
/// other row kernels. `tr` and `close` must be at least `out.len()` long,
/// otherwise the function panics on slicing.
#[inline(always)]
unsafe fn natr_row_scalar_from_tr(
    tr: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    let len = out.len();
    if period == 0 || first >= len {
        return;
    }
    // Last index of the warm-up window; bail out if it would overflow or fall
    // past the end of the row instead of indexing out of bounds.
    let warm_end = match first.checked_add(period - 1) {
        Some(w) if w < len => w,
        _ => return,
    };

    let inv_p = 1.0 / (period as f64);
    let k100 = 100.0;
    let to_natr = |atr: f64, c: f64| {
        if c.is_finite() && c != 0.0 {
            (atr / c) * k100
        } else {
            f64::NAN
        }
    };

    // Seed with a plain left-to-right sum of the first `period` true ranges.
    let sum_tr = tr[first..=warm_end].iter().fold(0.0, |acc, &v| acc + v);
    let mut atr = sum_tr * inv_p;
    out[warm_end] = to_natr(atr, close[warm_end]);

    let rest = warm_end + 1;
    for ((o, &t), &c) in out[rest..]
        .iter_mut()
        .zip(&tr[rest..len])
        .zip(&close[rest..len])
    {
        atr = (t - atr).mul_add(inv_p, atr);
        *o = to_natr(atr, c);
    }
}
1351
/// AVX2 row kernel for the batch paths.
///
/// Forwards to `natr_avx2`, which takes `period` before `first`. This adapter
/// exposes the `(…, first, period, out)` order used by every `natr_row_*` helper.
///
/// # Safety
/// Same contract as `natr_avx2`.
/// NOTE(review): presumably the caller must ensure AVX2 is available at runtime; confirm.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn natr_row_avx2(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    let (window, start) = (period, first);
    natr_avx2(high, low, close, window, start, out);
}
1364
/// AVX-512 row kernel for the batch paths.
///
/// Forwards to `natr_avx512`, swapping the argument order from
/// `(…, first, period, out)` to the `(…, period, first, out)` that function expects.
///
/// # Safety
/// Same contract as `natr_avx512`.
/// NOTE(review): presumably the caller must ensure AVX-512 is available at runtime; confirm.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn natr_row_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    let (window, start) = (period, first);
    natr_avx512(high, low, close, window, start, out);
}
1377
/// AVX-512 row kernel variant labelled for short periods.
///
/// Currently identical to `natr_row_avx512_long`: both forward to
/// `natr_avx512_body` with `period` placed before `first`.
///
/// # Safety
/// Same contract as `natr_avx512_body`.
/// NOTE(review): presumably requires AVX-512 at runtime; confirm.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn natr_row_avx512_short(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    let (window, start) = (period, first);
    natr_avx512_body(high, low, close, window, start, out);
}
1390
/// AVX-512 row kernel variant labelled for long periods.
///
/// Currently identical to `natr_row_avx512_short`: both forward to
/// `natr_avx512_body` with `period` placed before `first`.
///
/// # Safety
/// Same contract as `natr_avx512_body`.
/// NOTE(review): presumably requires AVX-512 at runtime; confirm.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn natr_row_avx512_long(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    let (window, start) = (period, first);
    natr_avx512_body(high, low, close, window, start, out);
}
1403
/// AVX2 slot for the shared-true-range row kernel.
///
/// There is no vectorised implementation yet; this simply runs
/// `natr_row_scalar_from_tr` with the same arguments.
///
/// # Safety
/// Same requirements as `natr_row_scalar_from_tr`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn natr_row_avx2_from_tr(
    tr: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    natr_row_scalar_from_tr(tr, close, first, period, out);
}
1415
/// AVX-512 slot for the shared-true-range row kernel.
///
/// There is no vectorised implementation yet; this simply runs
/// `natr_row_scalar_from_tr` with the same arguments.
///
/// # Safety
/// Same requirements as `natr_row_scalar_from_tr`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn natr_row_avx512_from_tr(
    tr: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    natr_row_scalar_from_tr(tr, close, first, period, out);
}
1427
1428#[cfg(test)]
1429mod tests {
1430    use super::*;
1431    use crate::skip_if_unsupported;
1432    use crate::utilities::data_loader::read_candles_from_csv;
1433    #[cfg(feature = "proptest")]
1434    use proptest::prelude::*;
1435
1436    #[test]
1437    fn test_natr_into_matches_api() -> Result<(), Box<dyn std::error::Error>> {
1438        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1439        let candles = read_candles_from_csv(file_path)?;
1440
1441        let input = NatrInput::with_default_candles(&candles);
1442
1443        let baseline = natr(&input)?.values;
1444
1445        let mut out = vec![0.0; candles.close.len()];
1446        #[allow(unused_variables)]
1447        {
1448            #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1449            {
1450                natr_into(&input, &mut out)?;
1451            }
1452        }
1453
1454        assert_eq!(baseline.len(), out.len());
1455
1456        fn eq_or_both_nan(a: f64, b: f64) -> bool {
1457            (a.is_nan() && b.is_nan()) || (a - b).abs() <= 1e-12
1458        }
1459
1460        for i in 0..baseline.len() {
1461            assert!(
1462                eq_or_both_nan(baseline[i], out[i]),
1463                "NATR into parity mismatch at index {}: baseline={}, into={}",
1464                i,
1465                baseline[i],
1466                out[i]
1467            );
1468        }
1469        Ok(())
1470    }
1471
1472    fn check_natr_partial_params(
1473        test_name: &str,
1474        kernel: Kernel,
1475    ) -> Result<(), Box<dyn std::error::Error>> {
1476        skip_if_unsupported!(kernel, test_name);
1477        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1478        let candles = read_candles_from_csv(file_path)?;
1479        let default_params = NatrParams { period: None };
1480        let input_default = NatrInput::from_candles(&candles, default_params);
1481        let output_default = natr_with_kernel(&input_default, kernel)?;
1482        assert_eq!(output_default.values.len(), candles.close.len());
1483        let params_period_7 = NatrParams { period: Some(7) };
1484        let input_period_7 = NatrInput::from_candles(&candles, params_period_7);
1485        let output_period_7 = natr_with_kernel(&input_period_7, kernel)?;
1486        assert_eq!(output_period_7.values.len(), candles.close.len());
1487        Ok(())
1488    }
1489
1490    fn check_natr_accuracy(
1491        test_name: &str,
1492        kernel: Kernel,
1493    ) -> Result<(), Box<dyn std::error::Error>> {
1494        skip_if_unsupported!(kernel, test_name);
1495        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1496        let candles = read_candles_from_csv(file_path)?;
1497        let close_prices = candles.select_candle_field("close").unwrap();
1498        let params = NatrParams { period: Some(14) };
1499        let input = NatrInput::from_candles(&candles, params.clone());
1500        let natr_result = natr_with_kernel(&input, kernel)?;
1501        assert_eq!(natr_result.values.len(), close_prices.len());
1502        let expected_last_five = [
1503            1.5465877404905772,
1504            1.4773840355794576,
1505            1.4201627494720954,
1506            1.3556212509014807,
1507            1.3836271128536142,
1508        ];
1509        let start_index = natr_result.values.len() - 5;
1510        let result_last_five = &natr_result.values[start_index..];
1511        for (i, &value) in result_last_five.iter().enumerate() {
1512            let expected_value = expected_last_five[i];
1513            assert!(
1514                (value - expected_value).abs() < 1e-8,
1515                "NATR mismatch at index {}: expected {}, got {}",
1516                i,
1517                expected_value,
1518                value
1519            );
1520        }
1521        let period = params.period.unwrap();
1522        for i in 0..(period - 1) {
1523            assert!(natr_result.values[i].is_nan());
1524        }
1525        Ok(())
1526    }
1527
1528    fn check_natr_with_zero_period(
1529        test_name: &str,
1530        kernel: Kernel,
1531    ) -> Result<(), Box<dyn std::error::Error>> {
1532        skip_if_unsupported!(kernel, test_name);
1533        let high = [10.0, 20.0, 30.0];
1534        let low = [5.0, 10.0, 15.0];
1535        let close = [7.0, 14.0, 25.0];
1536        let params = NatrParams { period: Some(0) };
1537        let input = NatrInput::from_slices(&high, &low, &close, params);
1538        let result = natr_with_kernel(&input, kernel);
1539        assert!(result.is_err());
1540        Ok(())
1541    }
1542
1543    fn check_natr_period_exceeds_length(
1544        test_name: &str,
1545        kernel: Kernel,
1546    ) -> Result<(), Box<dyn std::error::Error>> {
1547        skip_if_unsupported!(kernel, test_name);
1548        let high = [10.0, 20.0, 30.0];
1549        let low = [5.0, 10.0, 15.0];
1550        let close = [7.0, 14.0, 25.0];
1551        let params = NatrParams { period: Some(10) };
1552        let input = NatrInput::from_slices(&high, &low, &close, params);
1553        let result = natr_with_kernel(&input, kernel);
1554        assert!(result.is_err());
1555        Ok(())
1556    }
1557
1558    fn check_natr_very_small_dataset(
1559        test_name: &str,
1560        kernel: Kernel,
1561    ) -> Result<(), Box<dyn std::error::Error>> {
1562        skip_if_unsupported!(kernel, test_name);
1563        let high = [42.0];
1564        let low = [40.0];
1565        let close = [41.0];
1566        let params = NatrParams { period: Some(14) };
1567        let input = NatrInput::from_slices(&high, &low, &close, params);
1568        let result = natr_with_kernel(&input, kernel);
1569        assert!(result.is_err());
1570        Ok(())
1571    }
1572
1573    fn check_natr_all_values_nan(
1574        test_name: &str,
1575        kernel: Kernel,
1576    ) -> Result<(), Box<dyn std::error::Error>> {
1577        skip_if_unsupported!(kernel, test_name);
1578        let high = [f64::NAN, f64::NAN];
1579        let low = [f64::NAN, f64::NAN];
1580        let close = [f64::NAN, f64::NAN];
1581        let params = NatrParams { period: Some(2) };
1582        let input = NatrInput::from_slices(&high, &low, &close, params);
1583        let result = natr_with_kernel(&input, kernel);
1584        assert!(result.is_err());
1585        Ok(())
1586    }
1587
1588    fn check_natr_not_enough_valid_data(
1589        test_name: &str,
1590        kernel: Kernel,
1591    ) -> Result<(), Box<dyn std::error::Error>> {
1592        skip_if_unsupported!(kernel, test_name);
1593        let high = [f64::NAN, 10.0];
1594        let low = [f64::NAN, 5.0];
1595        let close = [f64::NAN, 7.0];
1596        let params = NatrParams { period: Some(5) };
1597        let input = NatrInput::from_slices(&high, &low, &close, params);
1598        let result = natr_with_kernel(&input, kernel);
1599        assert!(result.is_err());
1600        Ok(())
1601    }
1602
1603    fn check_natr_slice_data_reinput(
1604        test_name: &str,
1605        kernel: Kernel,
1606    ) -> Result<(), Box<dyn std::error::Error>> {
1607        skip_if_unsupported!(kernel, test_name);
1608        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1609        let candles = read_candles_from_csv(file_path)?;
1610        let first_params = NatrParams { period: Some(14) };
1611        let first_input = NatrInput::from_candles(&candles, first_params);
1612        let first_result = natr_with_kernel(&first_input, kernel)?;
1613        assert_eq!(first_result.values.len(), candles.close.len());
1614
1615        let second_params = NatrParams { period: Some(14) };
1616        let second_input = NatrInput::from_slices(
1617            &first_result.values,
1618            &first_result.values,
1619            &first_result.values,
1620            second_params,
1621        );
1622        let second_result = natr_with_kernel(&second_input, kernel)?;
1623        assert_eq!(second_result.values.len(), first_result.values.len());
1624
1625        for i in 28..second_result.values.len() {
1626            assert!(
1627                !second_result.values[i].is_nan(),
1628                "Expected no NaN after index 28, but found NaN at index {}",
1629                i
1630            );
1631        }
1632        Ok(())
1633    }
1634
1635    fn check_natr_nan_check(
1636        test_name: &str,
1637        kernel: Kernel,
1638    ) -> Result<(), Box<dyn std::error::Error>> {
1639        skip_if_unsupported!(kernel, test_name);
1640        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1641        let candles = read_candles_from_csv(file_path)?;
1642        let params = NatrParams { period: Some(14) };
1643        let input = NatrInput::from_candles(&candles, params);
1644        let natr_result = natr_with_kernel(&input, kernel)?;
1645        assert_eq!(natr_result.values.len(), candles.close.len());
1646        if natr_result.values.len() > 30 {
1647            for i in 30..natr_result.values.len() {
1648                assert!(
1649                    !natr_result.values[i].is_nan(),
1650                    "Expected no NaN after index 30, but found NaN at index {}",
1651                    i
1652                );
1653            }
1654        }
1655        Ok(())
1656    }
1657
1658    fn check_natr_default_candles(
1659        test_name: &str,
1660        kernel: Kernel,
1661    ) -> Result<(), Box<dyn std::error::Error>> {
1662        skip_if_unsupported!(kernel, test_name);
1663        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1664        let candles = read_candles_from_csv(file_path)?;
1665        let input = NatrInput::with_default_candles(&candles);
1666        match input.data {
1667            NatrData::Candles { .. } => {}
1668            _ => panic!("Expected NatrData::Candles variant"),
1669        }
1670        let output = natr_with_kernel(&input, kernel)?;
1671        assert_eq!(output.values.len(), candles.close.len());
1672        Ok(())
1673    }
1674
1675    fn check_natr_streaming(
1676        test_name: &str,
1677        kernel: Kernel,
1678    ) -> Result<(), Box<dyn std::error::Error>> {
1679        skip_if_unsupported!(kernel, test_name);
1680        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1681        let candles = read_candles_from_csv(file_path)?;
1682        let period = 14;
1683        let high = &candles.high;
1684        let low = &candles.low;
1685        let close = &candles.close;
1686        let input = NatrInput::from_slices(
1687            high,
1688            low,
1689            close,
1690            NatrParams {
1691                period: Some(period),
1692            },
1693        );
1694        let batch_output = natr_with_kernel(&input, kernel)?.values;
1695
1696        let mut stream = NatrStream::try_new(NatrParams {
1697            period: Some(period),
1698        })?;
1699        let mut stream_values = Vec::with_capacity(close.len());
1700        for ((&h, &l), &c) in high.iter().zip(low.iter()).zip(close.iter()) {
1701            match stream.update(h, l, c) {
1702                Some(natr_val) => stream_values.push(natr_val),
1703                None => stream_values.push(f64::NAN),
1704            }
1705        }
1706        assert_eq!(batch_output.len(), stream_values.len());
1707        for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
1708            if b.is_nan() && s.is_nan() {
1709                continue;
1710            }
1711            let diff = (b - s).abs();
1712            assert!(
1713                diff < 1e-9,
1714                "[{}] NATR streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
1715                test_name,
1716                i,
1717                b,
1718                s,
1719                diff
1720            );
1721        }
1722        Ok(())
1723    }
1724
1725    #[cfg(debug_assertions)]
1726    fn check_natr_no_poison(
1727        test_name: &str,
1728        kernel: Kernel,
1729    ) -> Result<(), Box<dyn std::error::Error>> {
1730        skip_if_unsupported!(kernel, test_name);
1731
1732        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1733        let candles = read_candles_from_csv(file_path)?;
1734
1735        let test_params = vec![
1736            NatrParams::default(),
1737            NatrParams { period: Some(2) },
1738            NatrParams { period: Some(5) },
1739            NatrParams { period: Some(7) },
1740            NatrParams { period: Some(10) },
1741            NatrParams { period: Some(20) },
1742            NatrParams { period: Some(30) },
1743            NatrParams { period: Some(50) },
1744            NatrParams { period: Some(100) },
1745            NatrParams { period: Some(200) },
1746        ];
1747
1748        for (param_idx, params) in test_params.iter().enumerate() {
1749            let input = NatrInput::from_candles(&candles, params.clone());
1750            let output = natr_with_kernel(&input, kernel)?;
1751
1752            for (i, &val) in output.values.iter().enumerate() {
1753                if val.is_nan() {
1754                    continue;
1755                }
1756
1757                let bits = val.to_bits();
1758
1759                if bits == 0x11111111_11111111 {
1760                    panic!(
1761                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1762						 with params: period={} (param set {})",
1763                        test_name,
1764                        val,
1765                        bits,
1766                        i,
1767                        params.period.unwrap_or(14),
1768                        param_idx
1769                    );
1770                }
1771
1772                if bits == 0x22222222_22222222 {
1773                    panic!(
1774                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1775						 with params: period={} (param set {})",
1776                        test_name,
1777                        val,
1778                        bits,
1779                        i,
1780                        params.period.unwrap_or(14),
1781                        param_idx
1782                    );
1783                }
1784
1785                if bits == 0x33333333_33333333 {
1786                    panic!(
1787                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1788						 with params: period={} (param set {})",
1789                        test_name,
1790                        val,
1791                        bits,
1792                        i,
1793                        params.period.unwrap_or(14),
1794                        param_idx
1795                    );
1796                }
1797            }
1798        }
1799
1800        Ok(())
1801    }
1802
    /// Release-build stub: the poison scan is only compiled with
    /// `debug_assertions`, so this variant always succeeds.
    #[cfg(not(debug_assertions))]
    fn check_natr_no_poison(
        _test_name: &str,
        _kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        Ok(())
    }
1810
1811    fn check_batch_default_row(
1812        test: &str,
1813        kernel: Kernel,
1814    ) -> Result<(), Box<dyn std::error::Error>> {
1815        skip_if_unsupported!(kernel, test);
1816        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1817        let c = read_candles_from_csv(file)?;
1818        let output = NatrBatchBuilder::new().kernel(kernel).apply_candles(&c)?;
1819        let def = NatrParams::default();
1820        let row = output.values_for(&def).expect("default row missing");
1821        assert_eq!(row.len(), c.close.len());
1822        Ok(())
1823    }
1824
1825    #[cfg(debug_assertions)]
1826    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn std::error::Error>> {
1827        skip_if_unsupported!(kernel, test);
1828
1829        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1830        let c = read_candles_from_csv(file)?;
1831
1832        let test_configs = vec![
1833            (2, 10, 2),
1834            (5, 25, 5),
1835            (30, 60, 15),
1836            (2, 5, 1),
1837            (10, 20, 2),
1838            (50, 100, 10),
1839            (14, 14, 0),
1840        ];
1841
1842        for (cfg_idx, &(p_start, p_end, p_step)) in test_configs.iter().enumerate() {
1843            let output = NatrBatchBuilder::new()
1844                .kernel(kernel)
1845                .period_range(p_start, p_end, p_step)
1846                .apply_candles(&c)?;
1847
1848            for (idx, &val) in output.values.iter().enumerate() {
1849                if val.is_nan() {
1850                    continue;
1851                }
1852
1853                let bits = val.to_bits();
1854                let row = idx / output.cols;
1855                let col = idx % output.cols;
1856                let combo = &output.combos[row];
1857
1858                if bits == 0x11111111_11111111 {
1859                    panic!(
1860                        "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
1861						 at row {} col {} (flat index {}) with params: period={}",
1862                        test,
1863                        cfg_idx,
1864                        val,
1865                        bits,
1866                        row,
1867                        col,
1868                        idx,
1869                        combo.period.unwrap_or(14)
1870                    );
1871                }
1872
1873                if bits == 0x22222222_22222222 {
1874                    panic!(
1875                        "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
1876						 at row {} col {} (flat index {}) with params: period={}",
1877                        test,
1878                        cfg_idx,
1879                        val,
1880                        bits,
1881                        row,
1882                        col,
1883                        idx,
1884                        combo.period.unwrap_or(14)
1885                    );
1886                }
1887
1888                if bits == 0x33333333_33333333 {
1889                    panic!(
1890                        "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
1891						 at row {} col {} (flat index {}) with params: period={}",
1892                        test,
1893                        cfg_idx,
1894                        val,
1895                        bits,
1896                        row,
1897                        col,
1898                        idx,
1899                        combo.period.unwrap_or(14)
1900                    );
1901                }
1902            }
1903        }
1904
1905        Ok(())
1906    }
1907
    /// Release-build stub: the batch poison scan is only compiled with
    /// `debug_assertions`, so this variant always succeeds.
    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(
        _test: &str,
        _kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        Ok(())
    }
1915
1916    #[cfg(feature = "proptest")]
1917    fn check_natr_property(
1918        test_name: &str,
1919        kernel: Kernel,
1920    ) -> Result<(), Box<dyn std::error::Error>> {
1921        skip_if_unsupported!(kernel, test_name);
1922
1923        let strat = (2usize..=50, 50usize..=400, 0usize..=2)
1924            .prop_flat_map(|(period, len, scenario)| {
1925                let close_strategy = match scenario {
1926                    0 => prop::collection::vec(
1927                        (1.0f64..1000.0f64).prop_filter("finite", |x| x.is_finite()),
1928                        len,
1929                    )
1930                    .boxed(),
1931                    1 => prop::collection::vec(
1932                        (0.01f64..1.0f64).prop_filter("finite", |x| x.is_finite()),
1933                        len,
1934                    )
1935                    .boxed(),
1936                    _ => (1.0f64..100.0f64)
1937                        .prop_map(move |val| vec![val; len])
1938                        .boxed(),
1939                };
1940
1941                (close_strategy, Just(period), Just(len), Just(scenario))
1942            })
1943            .prop_flat_map(|(close_prices, period, len, scenario)| {
1944                let mut high_vec = Vec::with_capacity(len);
1945                let mut low_vec = Vec::with_capacity(len);
1946
1947                for (i, &close) in close_prices.iter().enumerate() {
1948                    if scenario == 2 {
1949                        high_vec.push(close);
1950                        low_vec.push(close);
1951                    } else {
1952                        let volatility_factor = 0.001 + 0.20 * ((i * 7919) % 100) as f64 / 100.0;
1953                        let spread = close * volatility_factor;
1954
1955                        let high = close + spread * 0.5;
1956                        let low = close - spread * 0.5;
1957
1958                        high_vec.push(high);
1959                        low_vec.push(low.max(0.001));
1960                    }
1961                }
1962
1963                (
1964                    Just(high_vec),
1965                    Just(low_vec),
1966                    Just(close_prices),
1967                    Just(period),
1968                    Just(scenario),
1969                )
1970            });
1971
1972        proptest::test_runner::TestRunner::default().run(
1973            &strat,
1974            |(high, low, close, period, scenario)| {
1975                let params = NatrParams {
1976                    period: Some(period),
1977                };
1978                let input = NatrInput::from_slices(&high, &low, &close, params);
1979
1980                let result = natr_with_kernel(&input, kernel)?;
1981
1982                let ref_result = natr_with_kernel(&input, Kernel::Scalar)?;
1983
1984                prop_assert_eq!(result.values.len(), high.len());
1985                prop_assert_eq!(result.values.len(), low.len());
1986                prop_assert_eq!(result.values.len(), close.len());
1987
1988                for i in 0..(period - 1) {
1989                    prop_assert!(
1990                        result.values[i].is_nan(),
1991                        "Expected NaN at index {} during warmup, got {}",
1992                        i,
1993                        result.values[i]
1994                    );
1995                }
1996
1997                for i in period..result.values.len() {
1998                    if result.values[i].is_finite() {
1999                        prop_assert!(
2000                            result.values[i] >= 0.0,
2001                            "NATR should be non-negative at index {}: got {}",
2002                            i,
2003                            result.values[i]
2004                        );
2005
2006                        prop_assert!(
2007                            result.values[i] < 10000.0,
2008                            "NATR seems unreasonably high at index {}: got {}",
2009                            i,
2010                            result.values[i]
2011                        );
2012                    }
2013                }
2014
2015                for i in 0..result.values.len() {
2016                    let val = result.values[i];
2017                    let ref_val = ref_result.values[i];
2018
2019                    if val.is_nan() && ref_val.is_nan() {
2020                        continue;
2021                    }
2022
2023                    if val.is_finite() && ref_val.is_finite() {
2024                        let diff = (val - ref_val).abs();
2025                        let tolerance = (ref_val.abs() * 1e-10).max(1e-10);
2026                        prop_assert!(
2027                            diff <= tolerance,
2028                            "Kernel mismatch at index {}: {} vs {} (diff: {})",
2029                            i,
2030                            val,
2031                            ref_val,
2032                            diff
2033                        );
2034                    } else {
2035                        prop_assert_eq!(
2036                            val.is_finite(),
2037                            ref_val.is_finite(),
2038                            "Finite status mismatch at index {}: {} vs {}",
2039                            i,
2040                            val,
2041                            ref_val
2042                        );
2043                    }
2044                }
2045
2046                if scenario == 2 {
2047                    let is_constant = high
2048                        .iter()
2049                        .zip(&low)
2050                        .zip(&close)
2051                        .all(|((h, l), c)| (*h - *l).abs() < 1e-10 && (*h - *c).abs() < 1e-10);
2052
2053                    if is_constant && result.values.len() > period + 5 {
2054                        for i in (period + 5)..result.values.len() {
2055                            if result.values[i].is_finite() {
2056                                prop_assert!(
2057                                    result.values[i].abs() < 1e-10,
2058                                    "NATR should be 0 for constant prices at index {}, got {}",
2059                                    i,
2060                                    result.values[i]
2061                                );
2062                            }
2063                        }
2064                    }
2065                }
2066
2067                if scenario == 1 {
2068                    for i in period..result.values.len() {
2069                        if result.values[i].is_finite() && close[i] > 0.0 {
2070                            prop_assert!(
2071                                result.values[i] >= 0.0 && result.values[i] < 100000.0,
2072                                "NATR out of bounds with small prices at index {}: got {}",
2073                                i,
2074                                result.values[i]
2075                            );
2076                        }
2077                    }
2078                }
2079
2080                if close.iter().any(|&c| c.abs() < 1e-10) {
2081                    for (i, &c) in close.iter().enumerate() {
2082                        if c.abs() < 1e-10 && i >= period - 1 {
2083                            prop_assert!(
2084                                result.values[i] == 0.0 || result.values[i].is_nan(),
2085                                "NATR should be 0 or NaN when close is 0, got {} at index {}",
2086                                result.values[i],
2087                                i
2088                            );
2089                        }
2090                    }
2091                }
2092
2093                #[cfg(debug_assertions)]
2094                {
2095                    for (i, &val) in result.values.iter().enumerate() {
2096                        if val.is_finite() {
2097                            let bits = val.to_bits();
2098                            prop_assert!(
2099                                bits != 0x11111111_11111111
2100                                    && bits != 0x22222222_22222222
2101                                    && bits != 0x33333333_33333333,
2102                                "Found poison value at index {}: {} (0x{:016X})",
2103                                i,
2104                                val,
2105                                bits
2106                            );
2107                        }
2108                    }
2109                }
2110
2111                Ok(())
2112            },
2113        )?;
2114
2115        Ok(())
2116    }
2117
2118    macro_rules! generate_all_natr_tests {
2119        ($($test_fn:ident),*) => {
2120            paste::paste! {
2121                $(
2122                    #[test]
2123                    fn [<$test_fn _scalar_f64>]() {
2124                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
2125                    }
2126                )*
2127                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2128                $(
2129                    #[test]
2130                    fn [<$test_fn _avx2_f64>]() {
2131                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
2132                    }
2133                    #[test]
2134                    fn [<$test_fn _avx512_f64>]() {
2135                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
2136                    }
2137                )*
2138                #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
2139                $(
2140                    #[test]
2141                    fn [<$test_fn _simd128_f64>]() {
2142                        let _ = $test_fn(stringify!([<$test_fn _simd128_f64>]), Kernel::Scalar);
2143                    }
2144                )*
2145            }
2146        }
2147    }
2148
    // Property-based test; only built when the `proptest` feature is enabled.
    #[cfg(feature = "proptest")]
    generate_all_natr_tests!(check_natr_property);

    // Deterministic single-series checks, instantiated per kernel by
    // `generate_all_natr_tests!`.
    generate_all_natr_tests!(
        check_natr_partial_params,
        check_natr_accuracy,
        check_natr_with_zero_period,
        check_natr_period_exceeds_length,
        check_natr_very_small_dataset,
        check_natr_all_values_nan,
        check_natr_not_enough_valid_data,
        check_natr_slice_data_reinput,
        check_natr_nan_check,
        check_natr_default_candles,
        check_natr_streaming,
        check_natr_no_poison
    );
2166
2167    macro_rules! gen_batch_tests {
2168        ($fn_name:ident) => {
2169            paste::paste! {
2170                #[test] fn [<$fn_name _scalar>]()      {
2171                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
2172                }
2173                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2174                #[test] fn [<$fn_name _avx2>]()        {
2175                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
2176                }
2177                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2178                #[test] fn [<$fn_name _avx512>]()      {
2179                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
2180                }
2181                #[test] fn [<$fn_name _auto_detect>]() {
2182                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
2183                }
2184            }
2185        };
2186    }
2187    gen_batch_tests!(check_batch_default_row);
2188    gen_batch_tests!(check_batch_no_poison);
2189}
2190
/// Python `natr(high, low, close, period, kernel=None)`.
///
/// Returns the NATR series as a float64 array. Raises `ValueError` if `high`, `low`
/// and `close` differ in length, or if the core computation rejects the
/// parameters/data. Errors from `validate_kernel` are propagated as-is.
/// The computation runs with the GIL released.
#[cfg(feature = "python")]
#[pyfunction(name = "natr")]
#[pyo3(signature = (high, low, close, period, kernel=None))]
pub fn natr_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let high_slice = high.as_slice()?;
    let low_slice = low.as_slice()?;
    let close_slice = close.as_slice()?;
    // Reject ragged inputs up front, matching `natr_batch` and the wasm bindings,
    // rather than depending on how the core routine treats unequal lengths.
    if high_slice.len() != low_slice.len() || high_slice.len() != close_slice.len() {
        return Err(PyValueError::new_err("natr: Mismatched input lengths"));
    }
    let kern = validate_kernel(kernel, false)?;

    let params = NatrParams {
        period: Some(period),
    };
    let input = NatrInput::from_slices(high_slice, low_slice, close_slice, params);

    let result_vec: Vec<f64> = py
        .allow_threads(|| natr_with_kernel(&input, kern).map(|o| o.values))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok(result_vec.into_pyarray(py))
}
2218
/// Python `natr_batch(high, low, close, period_range, kernel=None)`.
///
/// Evaluates NATR for every period produced by `expand_grid` from `period_range`
/// (`(start, end, step)`). Returns a dict with:
/// - `values`: a `(number_of_periods, len(high))` float64 matrix, one row per period;
/// - `periods`: the period used for each row.
///
/// Raises `ValueError` on mismatched input lengths, an invalid range, a size
/// overflow, or a failure in the batch computation. The computation runs with the
/// GIL released.
#[cfg(feature = "python")]
#[pyfunction(name = "natr_batch")]
#[pyo3(signature = (high, low, close, period_range, kernel=None))]
pub fn natr_batch_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let (h, l, c) = (high.as_slice()?, low.as_slice()?, close.as_slice()?);
    let n_cols = h.len();
    if l.len() != n_cols || c.len() != n_cols {
        return Err(PyValueError::new_err("natr: Mismatched input lengths"));
    }

    let requested = validate_kernel(kernel, true)?;
    let sweep = NatrBatchRange {
        period: period_range,
    };
    let n_rows = expand_grid(&sweep)
        .map_err(|e| PyValueError::new_err(e.to_string()))?
        .len();
    let total = match n_rows.checked_mul(n_cols) {
        Some(cells) => cells,
        None => return Err(PyValueError::new_err("natr_batch: rows*cols overflow")),
    };

    // SAFETY: the array starts uninitialised. It is only handed to Python after
    // `natr_batch_inner_into` succeeds, which is expected to write every cell.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let out_cells = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            let resolved = if matches!(requested, Kernel::Auto) {
                detect_best_batch_kernel()
            } else {
                requested
            };
            // Each batch kernel is executed row-by-row with its single-series counterpart.
            let row_kernel = match resolved {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                other => other,
            };
            natr_batch_inner_into(h, l, c, &sweep, row_kernel, true, out_cells)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((n_rows, n_cols))?)?;
    let periods: Vec<u64> = combos
        .iter()
        .map(|p| p.period.unwrap() as u64)
        .collect();
    dict.set_item("periods", periods.into_pyarray(py))?;

    Ok(dict)
}
2288
/// Python class `NatrStream`: an incremental NATR calculator fed one bar at a time.
#[cfg(feature = "python")]
#[pyclass(name = "NatrStream")]
pub struct NatrStreamPy {
    stream: NatrStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl NatrStreamPy {
    /// Builds a stream for `period`. Raises `ValueError` if `NatrStream::try_new`
    /// rejects the parameters.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        NatrStream::try_new(NatrParams {
            period: Some(period),
        })
        .map(|stream| Self { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Forwards one `(high, low, close)` bar to the underlying stream and returns
    /// its result (`None` when the stream reports no value for this bar).
    fn update(&mut self, high: f64, low: f64, close: f64) -> Option<f64> {
        let bar = (high, low, close);
        self.stream.update(bar.0, bar.1, bar.2)
    }
}
2312
2313pub fn natr_into_slice(dst: &mut [f64], input: &NatrInput, kern: Kernel) -> Result<(), NatrError> {
2314    let (high, low, close, period) = match &input.data {
2315        NatrData::Candles { candles } => (
2316            candles.high.as_slice(),
2317            candles.low.as_slice(),
2318            candles.close.as_slice(),
2319            input.params.period.unwrap_or(14),
2320        ),
2321        NatrData::Slices { high, low, close } => {
2322            (*high, *low, *close, input.params.period.unwrap_or(14))
2323        }
2324    };
2325
2326    let len = high.len().min(low.len()).min(close.len());
2327
2328    if dst.len() != len {
2329        return Err(NatrError::OutputLengthMismatch {
2330            expected: len,
2331            got: dst.len(),
2332        });
2333    }
2334
2335    if len == 0 {
2336        return Err(NatrError::EmptyInputData);
2337    }
2338    if period == 0 {
2339        return Err(NatrError::InvalidPeriod {
2340            period,
2341            data_len: len,
2342        });
2343    }
2344    if period > len {
2345        return Err(NatrError::InvalidPeriod {
2346            period,
2347            data_len: len,
2348        });
2349    }
2350
2351    let first_valid_idx = {
2352        let first_valid_idx_h = high.iter().position(|&x| !x.is_nan());
2353        let first_valid_idx_l = low.iter().position(|&x| !x.is_nan());
2354        let first_valid_idx_c = close.iter().position(|&x| !x.is_nan());
2355
2356        match (first_valid_idx_h, first_valid_idx_l, first_valid_idx_c) {
2357            (Some(h), Some(l), Some(c)) => Some(h.max(l).max(c)),
2358            _ => None,
2359        }
2360    };
2361
2362    let first_valid_idx = match first_valid_idx {
2363        Some(idx) => idx,
2364        None => return Err(NatrError::AllValuesNaN),
2365    };
2366
2367    if (len - first_valid_idx) < period {
2368        return Err(NatrError::NotEnoughValidData {
2369            needed: period,
2370            valid: len - first_valid_idx,
2371        });
2372    }
2373
2374    let chosen = match kern {
2375        Kernel::Auto => Kernel::Scalar,
2376        other => other,
2377    };
2378
2379    unsafe {
2380        match chosen {
2381            Kernel::Scalar | Kernel::ScalarBatch => {
2382                natr_scalar(high, low, close, period, first_valid_idx, dst)
2383            }
2384            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2385            Kernel::Avx2 | Kernel::Avx2Batch => {
2386                natr_avx2(high, low, close, period, first_valid_idx, dst)
2387            }
2388            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2389            Kernel::Avx512 | Kernel::Avx512Batch => {
2390                natr_avx512(high, low, close, period, first_valid_idx, dst)
2391            }
2392            _ => unreachable!(),
2393        }
2394    }
2395
2396    for v in &mut dst[..(first_valid_idx + period - 1)] {
2397        *v = f64::NAN;
2398    }
2399
2400    Ok(())
2401}
2402
/// Writes the NATR series for `input` into `out` (native builds only).
///
/// Thin wrapper over `natr_into_slice` with `Kernel::Auto`, which that function
/// resolves to the scalar kernel. `out.len()` must equal the common input length;
/// see `natr_into_slice` for the error conditions.
#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
#[inline]
pub fn natr_into(input: &NatrInput, out: &mut [f64]) -> Result<(), NatrError> {
    natr_into_slice(out, input, Kernel::Auto)
}
2408
/// JS `natr_js(high, low, close, period)`: returns the NATR series as a new array.
///
/// All three inputs must have the same length. Any validation or computation
/// error is surfaced to JS as a string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn natr_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
) -> Result<Vec<f64>, JsValue> {
    let n = high.len();
    if low.len() != n || close.len() != n {
        return Err(JsValue::from_str("natr: Mismatched input lengths"));
    }

    let input = NatrInput::from_slices(
        high,
        low,
        close,
        NatrParams {
            period: Some(period),
        },
    );
    let mut values = vec![0.0; n];
    match natr_into_slice(&mut values, &input, detect_best_kernel()) {
        Ok(()) => Ok(values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
2429
/// Pointer-based NATR entry point for JS callers that manage wasm memory directly
/// (e.g. buffers obtained from `natr_alloc`).
///
/// Reads `len` values from each of `high_ptr`, `low_ptr` and `close_ptr` and writes
/// `len` NATR values to `out_ptr`. If `out_ptr` equals one of the input pointers, the
/// result is computed into a temporary buffer first and then copied out, so in-place
/// use is supported.
/// NOTE(review): only exact pointer equality is detected; partially overlapping
/// buffers are not handled.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn natr_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if high_ptr.is_null() || low_ptr.is_null() || close_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // SAFETY: pointers are non-null (checked above); the caller must guarantee each
    // refers to at least `len` initialised, properly aligned f64 values.
    unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);
        let params = NatrParams {
            period: Some(period),
        };
        let input = NatrInput::from_slices(high, low, close, params);

        if high_ptr == out_ptr || low_ptr == out_ptr || close_ptr == out_ptr {
            // Output aliases an input: compute into scratch space so the inputs are
            // not overwritten while still being read, then copy the result over.
            let mut temp = vec![0.0; len];
            natr_into_slice(&mut temp, &input, detect_best_kernel())
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            out.copy_from_slice(&temp);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            natr_into_slice(out, &input, detect_best_kernel())
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
        Ok(())
    }
}
2467
/// Allocates an uninitialised heap buffer with room for `len` `f64` values and returns
/// its pointer to JS. Release it with `natr_free(ptr, len)`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn natr_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` hands ownership of the allocation to the caller, the same as
    // `mem::forget`, without a separate forget step.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
2476
/// Frees a buffer previously returned by `natr_alloc(len)`. A null pointer is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn natr_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` must come from `natr_alloc(len)` with the same `len` and must not
    // have been freed already; rebuilding the Vec and dropping it releases the memory.
    drop(unsafe { Vec::from_raw_parts(ptr, len, len) });
}
2486
/// Configuration object accepted by the JS `natr_batch` binding.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct NatrBatchConfig {
    /// `(start, end, step)` period sweep, expanded into concrete periods by the batch routine.
    pub period_range: (usize, usize, usize),
}

/// Result object returned to JS by the `natr_batch` binding.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct NatrBatchJsOutput {
    /// Flattened `rows * cols` output values (one row per parameter combination).
    pub values: Vec<f64>,
    /// Parameters used for each row.
    pub combos: Vec<NatrParams>,
    pub rows: usize,
    pub cols: usize,
}
2501
/// JS `natr_batch(high, low, close, config)`: NATR over a period sweep.
///
/// `config` is deserialised into `NatrBatchConfig`. The result is serialised as a
/// `NatrBatchJsOutput` object (`values`, `combos`, `rows`, `cols`). Input-length,
/// config, computation and serialisation errors are all returned as JS strings.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = natr_batch)]
pub fn natr_batch_unified_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let n = high.len();
    if low.len() != n || close.len() != n {
        return Err(JsValue::from_str("natr: Mismatched input lengths"));
    }

    let cfg = match serde_wasm_bindgen::from_value::<NatrBatchConfig>(config) {
        Ok(parsed) => parsed,
        Err(e) => return Err(JsValue::from_str(&format!("Invalid config: {}", e))),
    };

    let batch = natr_batch_inner(
        high,
        low,
        close,
        &NatrBatchRange {
            period: cfg.period_range,
        },
        detect_best_kernel(),
        false,
    )
    .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let payload = NatrBatchJsOutput {
        rows: batch.rows,
        cols: batch.cols,
        combos: batch.combos,
        values: batch.values,
    };
    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2532
/// Pointer-based batch NATR for JS callers managing wasm memory directly.
///
/// Evaluates every period produced by `expand_grid` from
/// `(period_start, period_end, period_step)` over `len` bars. It writes
/// `rows * len` values to `out_ptr` and returns `rows`, the number of period
/// combinations, so the caller can size and interpret the output.
/// NOTE(review): unlike `natr_into`, there is no handling for `out_ptr` aliasing an
/// input buffer — confirm callers never do this.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn natr_batch_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if high_ptr.is_null() || low_ptr.is_null() || close_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to natr_batch_into"));
    }

    // SAFETY: pointers are non-null (checked above). The caller must guarantee each
    // input holds `len` initialised f64 values and that `out_ptr` has room for
    // `rows * len` values.
    unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);

        let sweep = NatrBatchRange {
            period: (period_start, period_end, period_step),
        };

        // Expand the sweep first so the output size is known before touching `out_ptr`.
        let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
        let rows = combos.len();
        let cols = len;
        let total = rows
            .checked_mul(cols)
            .ok_or_else(|| JsValue::from_str("natr_batch_into: rows*cols overflow"))?;

        let out = std::slice::from_raw_parts_mut(out_ptr, total);

        natr_batch_inner_into(high, low, close, &sweep, detect_best_kernel(), false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

        Ok(rows)
    }
}
2573
/// Registers the NATR Python functions on module `m`: `natr` and `natr_batch`, plus
/// the CUDA device-array class and CUDA entry points when the `cuda` feature is on.
/// NOTE(review): the `NatrStream` class (`NatrStreamPy`) is not added here — confirm
/// it is registered elsewhere.
#[cfg(feature = "python")]
pub fn register_natr_module(m: &Bound<'_, pyo3::types::PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(natr_py, m)?)?;
    m.add_function(wrap_pyfunction!(natr_batch_py, m)?)?;
    #[cfg(feature = "cuda")]
    {
        m.add_class::<NatrDeviceArrayF32Py>()?;
        m.add_function(wrap_pyfunction!(natr_cuda_batch_dev_py, m)?)?;
        m.add_function(wrap_pyfunction!(natr_cuda_many_series_one_param_dev_py, m)?)?;
    }
    Ok(())
}
2586
2587#[cfg(all(feature = "python", feature = "cuda"))]
2588use crate::cuda::natr_wrapper::DeviceArrayF32Natr;
2589#[cfg(all(feature = "python", feature = "cuda"))]
2590use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
2591#[cfg(all(feature = "python", feature = "cuda"))]
2592use cust::context::Context;
2593#[cfg(all(feature = "python", feature = "cuda"))]
2594use cust::memory::DeviceBuffer;
2595#[cfg(all(feature = "python", feature = "cuda"))]
2596use pyo3::exceptions::PyBufferError;
2597#[cfg(all(feature = "python", feature = "cuda"))]
2598use pyo3::PyObject;
2599#[cfg(all(feature = "python", feature = "cuda"))]
2600use std::sync::Arc;
2601
// Python handle to a row-major `rows x cols` f32 NATR result living on a CUDA device.
// It is exposed through `__cuda_array_interface__` and a one-shot `__dlpack__` export.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct NatrDeviceArrayF32Py {
    // Device buffer; becomes `None` after it has been handed off via `__dlpack__`.
    pub(crate) buf: Option<DeviceBuffer<f32>>,
    pub(crate) rows: usize,
    pub(crate) cols: usize,
    // Stored alongside the buffer — presumably to keep the CUDA context alive for
    // as long as the buffer exists; confirm.
    pub(crate) ctx: Arc<Context>,
    // Device ordinal reported by `__dlpack_device__`.
    pub(crate) device_id: u32,
}
2611
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl NatrDeviceArrayF32Py {
    // CUDA Array Interface (version 3) descriptor for the row-major `rows x cols` f32
    // matrix. Fails once the buffer has been exported through `__dlpack__`.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let d = PyDict::new(py);
        d.set_item("shape", (self.rows, self.cols))?;
        // "<f4": little-endian 32-bit float.
        d.set_item("typestr", "<f4")?;
        // Byte strides for a C-contiguous (row-major) layout.
        d.set_item(
            "strides",
            (
                self.cols * std::mem::size_of::<f32>(),
                std::mem::size_of::<f32>(),
            ),
        )?;
        let ptr = self
            .buf
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("buffer already exported via __dlpack__"))?
            .as_device_ptr()
            .as_raw() as usize;
        // (device pointer, read_only = false)
        d.set_item("data", (ptr, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    // DLPack device tuple: (2 = kDLCUDA, device ordinal).
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.device_id as i32)
    }

    // DLPack export. It moves the device buffer out, so it can succeed only once.
    // Requests for a different device, and `copy=True`, are rejected.
    // NOTE(review): `stream` is accepted but ignored, so no producer/consumer stream
    // synchronisation is performed here — confirm this is acceptable.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<PyObject>,
        max_version: Option<PyObject>,
        dl_device: Option<PyObject>,
        copy: Option<PyObject>,
    ) -> PyResult<PyObject> {
        let (kdl, alloc_dev) = self.__dlpack_device__();
        // A requested device that cannot be parsed as (i32, i32) is silently ignored.
        if let Some(dev_obj) = dl_device.as_ref() {
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(PyBufferError::new_err(
                            "device copy not implemented for __dlpack__",
                        ));
                    } else {
                        return Err(PyBufferError::new_err(
                            "__dlpack__: requested device does not match producer buffer",
                        ));
                    }
                }
            }
        }
        let _ = stream;

        if let Some(copy_obj) = copy.as_ref() {
            let do_copy: bool = copy_obj.extract(py)?;
            if do_copy {
                return Err(PyBufferError::new_err(
                    "__dlpack__(copy=True) not supported for natr CUDA buffers",
                ));
            }
        }

        let rows = self.rows;
        let cols = self.cols;
        // Take ownership of the buffer; later calls (and `__cuda_array_interface__`) then fail.
        let buf = self
            .buf
            .take()
            .ok_or_else(|| PyValueError::new_err("__dlpack__ may only be called once"))?;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
2695
/// Python `natr_cuda_batch_dev(high, low, close, period_range, device_id=0)`.
///
/// Runs the NATR period sweep on the given CUDA device over float32 inputs of equal
/// length. The result stays on the device, wrapped in a `NatrDeviceArrayF32Py`.
/// Raises `ValueError` if CUDA is unavailable, the inputs differ in length, or the
/// CUDA wrapper reports an error. The GPU work runs with the GIL released.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "natr_cuda_batch_dev")]
#[pyo3(signature = (high, low, close, period_range, device_id=0))]
pub fn natr_cuda_batch_dev_py(
    py: Python<'_>,
    high: numpy::PyReadonlyArray1<'_, f32>,
    low: numpy::PyReadonlyArray1<'_, f32>,
    close: numpy::PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<NatrDeviceArrayF32Py> {
    use crate::cuda::cuda_available;
    use crate::cuda::CudaNatr;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let (h, l, c) = (high.as_slice()?, low.as_slice()?, close.as_slice()?);
    let n = h.len();
    if l.len() != n || c.len() != n {
        return Err(PyValueError::new_err("mismatched input lengths"));
    }

    let sweep = NatrBatchRange {
        period: period_range,
    };
    let dev = py.allow_threads(|| {
        let mut cuda =
            CudaNatr::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        cuda.natr_batch_dev(h, l, c, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    Ok(NatrDeviceArrayF32Py {
        buf: Some(dev.buf),
        rows: dev.rows,
        cols: dev.cols,
        ctx: dev.ctx,
        device_id: dev.device_id,
    })
}
2743
/// Python `natr_cuda_many_series_one_param_dev(high_tm, low_tm, close_tm, period, device_id=0)`.
///
/// Computes NATR with a single `period` for many series at once on the given CUDA
/// device. The inputs are time-major 2-D float32 arrays (presumably shape
/// `(time, series)` — confirm against `CudaNatr`) and must all have the same shape.
/// The result stays on the device, wrapped in a `NatrDeviceArrayF32Py`.
/// Raises `ValueError` if CUDA is unavailable, the shapes differ, or the CUDA
/// wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "natr_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm, low_tm, close_tm, period, device_id=0))]
pub fn natr_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    high_tm: numpy::PyReadonlyArray2<'_, f32>,
    low_tm: numpy::PyReadonlyArray2<'_, f32>,
    close_tm: numpy::PyReadonlyArray2<'_, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<NatrDeviceArrayF32Py> {
    use crate::cuda::cuda_available;
    use crate::cuda::CudaNatr;
    use numpy::PyUntypedArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    // Only `high_tm`'s shape is forwarded as rows/cols below, so the other two inputs
    // must match it exactly (mirrors the length check in `natr_cuda_batch_dev`).
    if low_tm.shape() != high_tm.shape() || close_tm.shape() != high_tm.shape() {
        return Err(PyValueError::new_err("mismatched input shapes"));
    }
    let h = high_tm.as_slice()?;
    let l = low_tm.as_slice()?;
    let c = close_tm.as_slice()?;
    let rows = high_tm.shape()[0];
    let cols = high_tm.shape()[1];

    let dev = py.allow_threads(|| {
        let mut cuda =
            CudaNatr::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        cuda.natr_many_series_one_param_time_major_dev(h, l, c, cols, rows, period)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;
    let DeviceArrayF32Natr {
        buf,
        rows,
        cols,
        ctx,
        device_id,
    } = dev;
    Ok(NatrDeviceArrayF32Py {
        buf: Some(buf),
        rows,
        cols,
        ctx,
        device_id,
    })
}