Skip to main content

vector_ta/indicators/
atr.rs

1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArray2, PyArrayMethods, PyUntypedArrayMethods};
3#[cfg(feature = "python")]
4use pyo3::exceptions::{PyBufferError, PyValueError};
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::{PyDict, PyList};
9
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use serde::{Deserialize, Serialize};
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use wasm_bindgen::prelude::*;
14
15#[cfg(all(feature = "python", feature = "cuda"))]
16use crate::cuda::{cuda_available, CudaAtr};
17use crate::utilities::data_loader::{source_type, Candles};
18use crate::utilities::enums::Kernel;
19use crate::utilities::helpers::{
20    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
21    make_uninit_matrix,
22};
23#[cfg(feature = "python")]
24use crate::utilities::kernel_validation::validate_kernel;
25use aligned_vec::{AVec, CACHELINE_ALIGN};
26#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
27use core::arch::x86_64::*;
28#[cfg(not(target_arch = "wasm32"))]
29use rayon::prelude::*;
30use std::error::Error;
31use thiserror::Error;
32
33#[derive(Debug, Clone)]
34pub enum AtrData<'a> {
35    Candles {
36        candles: &'a Candles,
37    },
38    Slices {
39        high: &'a [f64],
40        low: &'a [f64],
41        close: &'a [f64],
42    },
43}
44
/// Result of a single ATR run.
#[derive(Clone, Debug)]
pub struct AtrOutput {
    /// ATR series, one entry per input bar. Entries before the warm-up index
    /// are expected to be NaN (they are filled by the allocation helper, not
    /// by the kernels).
    pub values: Vec<f64>,
}
49
/// Tunable parameters of the ATR indicator.
#[derive(Clone, Debug)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct AtrParams {
    /// Smoothing window in bars. `None` is resolved to 14 by the consumers
    /// in this module.
    pub length: Option<usize>,
}

impl Default for AtrParams {
    /// Default parameters: a 14-bar window.
    fn default() -> Self {
        AtrParams { length: Some(14) }
    }
}
64
65#[derive(Debug, Clone)]
66pub struct AtrInput<'a> {
67    pub data: AtrData<'a>,
68    pub params: AtrParams,
69}
70
71impl<'a> AtrInput<'a> {
72    #[inline]
73    pub fn from_candles(candles: &'a Candles, params: AtrParams) -> Self {
74        Self {
75            data: AtrData::Candles { candles },
76            params,
77        }
78    }
79    #[inline]
80    pub fn from_slices(
81        high: &'a [f64],
82        low: &'a [f64],
83        close: &'a [f64],
84        params: AtrParams,
85    ) -> Self {
86        Self {
87            data: AtrData::Slices { high, low, close },
88            params,
89        }
90    }
91    #[inline]
92    pub fn with_default_candles(candles: &'a Candles) -> Self {
93        Self::from_candles(candles, AtrParams::default())
94    }
95    #[inline]
96    pub fn get_length(&self) -> usize {
97        self.params.length.unwrap_or(14)
98    }
99}
100
101#[derive(Copy, Clone, Debug)]
102pub struct AtrBuilder {
103    length: Option<usize>,
104    kernel: Kernel,
105}
106
107impl Default for AtrBuilder {
108    fn default() -> Self {
109        Self {
110            length: None,
111            kernel: Kernel::Auto,
112        }
113    }
114}
115
116impl AtrBuilder {
117    #[inline(always)]
118    pub fn new() -> Self {
119        Self::default()
120    }
121    #[inline(always)]
122    pub fn length(mut self, n: usize) -> Self {
123        self.length = Some(n);
124        self
125    }
126    #[inline(always)]
127    pub fn kernel(mut self, k: Kernel) -> Self {
128        self.kernel = k;
129        self
130    }
131    #[inline(always)]
132    pub fn apply(self, c: &Candles) -> Result<AtrOutput, AtrError> {
133        let p = AtrParams {
134            length: self.length,
135        };
136        let i = AtrInput::from_candles(c, p);
137        atr_with_kernel(&i, self.kernel)
138    }
139    #[inline(always)]
140    pub fn apply_slices(
141        self,
142        high: &[f64],
143        low: &[f64],
144        close: &[f64],
145    ) -> Result<AtrOutput, AtrError> {
146        let p = AtrParams {
147            length: self.length,
148        };
149        let i = AtrInput::from_slices(high, low, close, p);
150        atr_with_kernel(&i, self.kernel)
151    }
152    #[inline(always)]
153    pub fn into_stream(self) -> Result<AtrStream, AtrError> {
154        let p = AtrParams {
155            length: self.length,
156        };
157        AtrStream::try_new(p)
158    }
159}
160
161#[derive(Debug, Error)]
162pub enum AtrError {
163    #[error("atr: Input data slice is empty.")]
164    EmptyInputData,
165    #[error("atr: All values are NaN.")]
166    AllValuesNaN,
167    #[error("atr: Invalid period: period = {period}, data length = {data_len}")]
168    InvalidPeriod { period: usize, data_len: usize },
169    #[error("atr: Not enough valid data: needed = {needed}, valid = {valid}")]
170    NotEnoughValidData { needed: usize, valid: usize },
171    #[error("atr: Output slice length mismatch: expected = {expected}, got = {got}")]
172    OutputLengthMismatch { expected: usize, got: usize },
173    #[error("atr: Invalid range: start = {start}, end = {end}, step = {step}")]
174    InvalidRange {
175        start: usize,
176        end: usize,
177        step: usize,
178    },
179    #[error("atr: Invalid kernel type for batch operation: {0:?}")]
180    InvalidKernelForBatch(Kernel),
181
182    #[error("Invalid length for ATR calculation (length={length}).")]
183    InvalidLength { length: usize },
184    #[error("Inconsistent slice lengths for ATR calculation: high={high_len}, low={low_len}, close={close_len}")]
185    InconsistentSliceLengths {
186        high_len: usize,
187        low_len: usize,
188        close_len: usize,
189    },
190    #[error("atr: No candles available for ATR calculation.")]
191    NoCandlesAvailable,
192    #[error("Not enough data to calculate ATR: length={length}, data length={data_len}")]
193    NotEnoughData { length: usize, data_len: usize },
194}
195
/// Returns the index of the first bar whose high, low and close are all
/// non-NaN.
///
/// Only indices present in all three slices are inspected, so slices of
/// different lengths never cause an out-of-bounds access. When no complete
/// bar exists, the length of the longest slice is returned; a caller's
/// `first >= series.len()` check therefore reports "no valid data" no matter
/// which of the three series it measures against. For equal-length inputs
/// this is simply the common length.
#[inline(always)]
fn first_valid_hlc(high: &[f64], low: &[f64], close: &[f64]) -> usize {
    high.iter()
        .zip(low)
        .zip(close)
        .position(|((h, l), c)| !(h.is_nan() || l.is_nan() || c.is_nan()))
        .unwrap_or_else(|| high.len().max(low.len()).max(close.len()))
}
208
209#[inline(always)]
210fn atr_prepare_full<'a>(
211    high: &'a [f64],
212    low: &'a [f64],
213    close: &'a [f64],
214    length: usize,
215) -> Result<(&'a [f64], &'a [f64], &'a [f64], usize, usize), AtrError> {
216    let (high, low, close, length) = atr_prepare(high, low, close, length)?;
217    let first = first_valid_hlc(high, low, close);
218    if first >= close.len() {
219        return Err(AtrError::AllValuesNaN);
220    }
221    let valid = close.len().saturating_sub(first);
222    if valid < length {
223        return Err(AtrError::NotEnoughValidData {
224            needed: length,
225            valid,
226        });
227    }
228    let warmup = first + length - 1;
229    Ok((high, low, close, first, warmup))
230}
231
232#[inline]
233pub fn atr(input: &AtrInput) -> Result<AtrOutput, AtrError> {
234    atr_with_kernel(input, Kernel::Auto)
235}
236
237pub fn atr_with_kernel(input: &AtrInput, kernel: Kernel) -> Result<AtrOutput, AtrError> {
238    let (high, low, close) = match &input.data {
239        AtrData::Candles { candles } => (
240            candles
241                .select_candle_field("high")
242                .map_err(|_| AtrError::NoCandlesAvailable)?,
243            candles
244                .select_candle_field("low")
245                .map_err(|_| AtrError::NoCandlesAvailable)?,
246            candles
247                .select_candle_field("close")
248                .map_err(|_| AtrError::NoCandlesAvailable)?,
249        ),
250        AtrData::Slices { high, low, close } => {
251            if high.len() != low.len() || low.len() != close.len() {
252                return Err(AtrError::InconsistentSliceLengths {
253                    high_len: high.len(),
254                    low_len: low.len(),
255                    close_len: close.len(),
256                });
257            }
258            (*high, *low, *close)
259        }
260    };
261
262    let len = close.len();
263    let length = input.get_length();
264    if length == 0 {
265        return Err(AtrError::InvalidLength { length });
266    }
267    if len == 0 {
268        return Err(AtrError::NoCandlesAvailable);
269    }
270    if length > len {
271        return Err(AtrError::NotEnoughData {
272            length,
273            data_len: len,
274        });
275    }
276
277    let chosen = match kernel {
278        Kernel::Auto => Kernel::Scalar,
279        k => k,
280    };
281
282    let (_, _, _, first, warmup) = atr_prepare_full(high, low, close, length)?;
283    let mut out = alloc_with_nan_prefix(len, warmup);
284    atr_compute_into(high, low, close, length, first, chosen, &mut out);
285    Ok(AtrOutput { values: out })
286}
287
288#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
289pub fn atr_into(input: &AtrInput, out: &mut [f64]) -> Result<(), AtrError> {
290    let (high, low, close) = match &input.data {
291        AtrData::Candles { candles } => (&candles.high[..], &candles.low[..], &candles.close[..]),
292        AtrData::Slices { high, low, close } => (*high, *low, *close),
293    };
294
295    let length = input.params.length.unwrap_or(14);
296    let (high, low, close, length) = atr_prepare(high, low, close, length)?;
297
298    let first = first_valid_hlc(high, low, close);
299    let valid = close.len().saturating_sub(first);
300    if valid < length {
301        return Err(AtrError::NotEnoughValidData {
302            needed: length,
303            valid,
304        });
305    }
306    let warmup = first + length - 1;
307
308    if out.len() != close.len() {
309        return Err(AtrError::OutputLengthMismatch {
310            expected: close.len(),
311            got: out.len(),
312        });
313    }
314
315    let prefix = warmup.min(out.len());
316    for v in &mut out[..prefix] {
317        *v = f64::from_bits(0x7ff8_0000_0000_0000);
318    }
319
320    let chosen = match Kernel::Auto {
321        Kernel::Auto => Kernel::Scalar,
322        k => k,
323    };
324    atr_compute_into(high, low, close, length, first, chosen, out);
325    Ok(())
326}
327
/// Scalar ATR kernel.
///
/// Seeds the average with the arithmetic mean of the true ranges of bars
/// `first..=first + length - 1` (the first bar uses `high - low` because it
/// has no previous close), writes that seed at index `first + length - 1`,
/// and then applies Wilder smoothing with `alpha = 1 / length` for every
/// later bar. Entries of `out` before the seed index are left untouched.
///
/// # Panics
/// Panics if the four slices differ in length, if `length == 0`, or if
/// `first + length` exceeds the series length. These checks run once per
/// call and replace unchecked indexing that was previously guarded only by
/// debug assertions.
#[inline(always)]
fn atr_compute_into_scalar(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    length: usize,
    first: usize,
    out: &mut [f64],
) {
    /// True range of one bar given the previous close. Written with explicit
    /// `>` comparisons so a NaN in `high - low` is propagated.
    #[inline(always)]
    fn true_range(hi: f64, lo: f64, prev_close: f64) -> f64 {
        let mut tr = hi - lo;
        let hc = (hi - prev_close).abs();
        if hc > tr {
            tr = hc;
        }
        let lc = (lo - prev_close).abs();
        if lc > tr {
            tr = lc;
        }
        tr
    }

    let n = close.len();
    assert!(
        high.len() == n && low.len() == n && out.len() == n,
        "atr: high, low, close and out must have equal lengths"
    );
    assert!(
        length >= 1 && first < n && length <= n - first,
        "atr: seed window (first = {first}, length = {length}) must fit in {n} bars"
    );

    let warm = first + length - 1;
    let period = length as f64;
    let alpha = 1.0 / period;

    // Seed: sum of true ranges over the warm-up window.
    let mut prev_c = close[first];
    let mut sum_tr = high[first] - low[first];
    let seed = first + 1..warm + 1;
    let seed_bars = high[seed.clone()]
        .iter()
        .zip(&low[seed.clone()])
        .zip(&close[seed]);
    for ((&hi, &lo), &c) in seed_bars {
        sum_tr += true_range(hi, lo, prev_c);
        prev_c = c;
    }

    let mut rma = sum_tr / period;
    out[warm] = rma;

    // Smoothing: rma <- rma - alpha * rma + alpha * tr. The recurrence is
    // serial, so a plain loop is used and unrolling is left to the compiler.
    let tail = warm + 1;
    let tail_bars = high[tail..].iter().zip(&low[tail..]).zip(&close[tail..]);
    for (dst, ((&hi, &lo), &c)) in out[tail..].iter_mut().zip(tail_bars) {
        rma = (-alpha).mul_add(rma, rma) + alpha * true_range(hi, lo, prev_c);
        *dst = rma;
        prev_c = c;
    }
}
461
462#[inline]
463pub fn atr_scalar(high: &[f64], low: &[f64], close: &[f64], length: usize, out: &mut [f64]) {
464    atr_compute_into_scalar(high, low, close, length, 0, out);
465}
466
467#[inline(always)]
468fn atr_compute_into(
469    high: &[f64],
470    low: &[f64],
471    close: &[f64],
472    length: usize,
473    first: usize,
474    kern: Kernel,
475    out: &mut [f64],
476) {
477    unsafe {
478        #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
479        {
480            if matches!(kern, Kernel::Scalar | Kernel::ScalarBatch) {
481                atr_compute_into_scalar(high, low, close, length, first, out);
482                return;
483            }
484        }
485        match kern {
486            Kernel::Scalar | Kernel::ScalarBatch => {
487                atr_compute_into_scalar(high, low, close, length, first, out)
488            }
489            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
490            Kernel::Avx2 | Kernel::Avx2Batch => {
491                atr_compute_into_avx2(high, low, close, length, first, out)
492            }
493            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
494            Kernel::Avx512 | Kernel::Avx512Batch => {
495                atr_compute_into_avx512(high, low, close, length, first, out)
496            }
497            _ => unreachable!(),
498        }
499    }
500}
501
/// WebAssembly SIMD128 entry point for ATR.
///
/// There is currently no SIMD128-specific implementation; the call is
/// forwarded to [`atr_scalar`].
///
/// # Safety
/// The body performs no unsafe operations of its own; the function is
/// `unsafe` only by signature.
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
#[inline]
unsafe fn atr_simd128(high: &[f64], low: &[f64], close: &[f64], length: usize, out: &mut [f64]) {
    atr_scalar(high, low, close, length, out);
}
509
/// AVX2 ATR kernel.
///
/// The seed (mean true range over bars `first..=first + length - 1`) is
/// computed with scalar code. After the seed, true ranges are computed four
/// bars at a time with 256-bit vectors, while the Wilder recurrence itself is
/// applied lane by lane because each step depends on the previous one. A
/// scalar loop handles the remaining bars.
///
/// The vector `max` operands are ordered so that NaN handling matches the
/// scalar kernel exactly: `_mm256_max_pd(a, b)` returns `b` whenever the
/// comparison is unordered, which mirrors `if a > b { a } else { b }`.
///
/// # Safety
/// * `high`, `low`, `close` and `out` must all have the same length.
/// * `length >= 1` and `first + length <= close.len()`.
/// * The CPU must support the AVX/AVX2 instructions used here.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn atr_compute_into_avx2(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    length: usize,
    first: usize,
    out: &mut [f64],
) {
    use core::arch::x86_64::*;

    /// Scalar true range; the reference semantics for the vector path.
    #[inline(always)]
    fn true_range(hi: f64, lo: f64, prev_close: f64) -> f64 {
        let mut tr = hi - lo;
        let hc = (hi - prev_close).abs();
        if hc > tr {
            tr = hc;
        }
        let lc = (lo - prev_close).abs();
        if lc > tr {
            tr = lc;
        }
        tr
    }

    debug_assert_eq!(high.len(), low.len());
    debug_assert_eq!(low.len(), close.len());
    debug_assert_eq!(out.len(), close.len());

    let warm = first + length - 1;
    let alpha = 1.0 / (length as f64);

    // Seed window.
    let mut sum_tr = *high.get_unchecked(first) - *low.get_unchecked(first);
    if warm > first {
        let mut i = first + 1;
        let mut prev_c = *close.get_unchecked(i - 1);
        while i <= warm {
            sum_tr += true_range(*high.get_unchecked(i), *low.get_unchecked(i), prev_c);
            prev_c = *close.get_unchecked(i);
            i += 1;
        }
    }

    let mut rma = sum_tr / (length as f64);
    *out.get_unchecked_mut(warm) = rma;

    let mut i = warm + 1;
    let n = out.len();

    // Clearing the sign bit yields |x|.
    let mask_abs = _mm256_castsi256_pd(_mm256_set1_epi64x(0x7fff_ffff_ffff_ffffu64 as i64));

    while i + 3 < n {
        let v_hi = _mm256_loadu_pd(high.as_ptr().add(i));
        let v_lo = _mm256_loadu_pd(low.as_ptr().add(i));
        // Previous closes for lanes i..i+3 are close[i-1..i+3].
        let v_pc = _mm256_loadu_pd(close.as_ptr().add(i - 1));

        let v_hl = _mm256_sub_pd(v_hi, v_lo);
        let v_hc = _mm256_and_pd(_mm256_sub_pd(v_hi, v_pc), mask_abs);
        let v_lc = _mm256_and_pd(_mm256_sub_pd(v_lo, v_pc), mask_abs);

        // Candidate first, running maximum second, so an unordered compare
        // keeps the running value exactly like the scalar `if x > tr`.
        let v_m1 = _mm256_max_pd(v_hc, v_hl);
        let v_tr = _mm256_max_pd(v_lc, v_m1);

        let mut buf = [0.0f64; 4];
        _mm256_storeu_pd(buf.as_mut_ptr(), v_tr);

        for (lane, &tr) in buf.iter().enumerate() {
            rma = (-alpha).mul_add(rma, rma) + alpha * tr;
            *out.get_unchecked_mut(i + lane) = rma;
        }

        i += 4;
    }

    // Scalar tail for the last (< 4) bars.
    if i < n {
        let mut prev_c = *close.get_unchecked(i - 1);
        while i < n {
            let tr = true_range(*high.get_unchecked(i), *low.get_unchecked(i), prev_c);
            rma = (-alpha).mul_add(rma, rma) + alpha * tr;
            *out.get_unchecked_mut(i) = rma;

            prev_c = *close.get_unchecked(i);
            i += 1;
        }
    }
}
616
/// AVX-512 ATR kernel.
///
/// Same structure as the AVX2 kernel, but true ranges after the seed are
/// computed eight bars at a time with 512-bit vectors. The seed and the
/// trailing (< 8) bars use scalar code, and the Wilder recurrence is applied
/// lane by lane.
///
/// The vector `max` operands are ordered so that NaN handling matches the
/// scalar kernel exactly: `_mm512_max_pd(a, b)` returns `b` whenever the
/// comparison is unordered, which mirrors `if a > b { a } else { b }`.
///
/// # Safety
/// * `high`, `low`, `close` and `out` must all have the same length.
/// * `length >= 1` and `first + length <= close.len()`.
/// * The CPU must support the AVX-512 instructions used here.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn atr_compute_into_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    length: usize,
    first: usize,
    out: &mut [f64],
) {
    use core::arch::x86_64::*;

    /// Scalar true range; the reference semantics for the vector path.
    #[inline(always)]
    fn true_range(hi: f64, lo: f64, prev_close: f64) -> f64 {
        let mut tr = hi - lo;
        let hc = (hi - prev_close).abs();
        if hc > tr {
            tr = hc;
        }
        let lc = (lo - prev_close).abs();
        if lc > tr {
            tr = lc;
        }
        tr
    }

    debug_assert_eq!(high.len(), low.len());
    debug_assert_eq!(low.len(), close.len());
    debug_assert_eq!(out.len(), close.len());

    let warm = first + length - 1;
    let alpha = 1.0 / (length as f64);

    // Seed window.
    let mut sum_tr = *high.get_unchecked(first) - *low.get_unchecked(first);
    if warm > first {
        let mut i = first + 1;
        let mut prev_c = *close.get_unchecked(i - 1);
        while i <= warm {
            sum_tr += true_range(*high.get_unchecked(i), *low.get_unchecked(i), prev_c);
            prev_c = *close.get_unchecked(i);
            i += 1;
        }
    }

    let mut rma = sum_tr / (length as f64);
    *out.get_unchecked_mut(warm) = rma;

    let mut i = warm + 1;
    let n = out.len();

    // Clearing the sign bit yields |x|.
    let mask_abs = _mm512_castsi512_pd(_mm512_set1_epi64(0x7fff_ffff_ffff_ffffu64 as i64));

    while i + 7 < n {
        let v_hi = _mm512_loadu_pd(high.as_ptr().add(i));
        let v_lo = _mm512_loadu_pd(low.as_ptr().add(i));
        // Previous closes for lanes i..i+7 are close[i-1..i+7].
        let v_pc = _mm512_loadu_pd(close.as_ptr().add(i - 1));

        let v_hl = _mm512_sub_pd(v_hi, v_lo);
        let v_hc = _mm512_and_pd(_mm512_sub_pd(v_hi, v_pc), mask_abs);
        let v_lc = _mm512_and_pd(_mm512_sub_pd(v_lo, v_pc), mask_abs);

        // Candidate first, running maximum second, so an unordered compare
        // keeps the running value exactly like the scalar `if x > tr`.
        let v_m1 = _mm512_max_pd(v_hc, v_hl);
        let v_tr = _mm512_max_pd(v_lc, v_m1);

        let mut buf = [0.0f64; 8];
        _mm512_storeu_pd(buf.as_mut_ptr(), v_tr);

        for (lane, &tr) in buf.iter().enumerate() {
            rma = (-alpha).mul_add(rma, rma) + alpha * tr;
            *out.get_unchecked_mut(i + lane) = rma;
        }

        i += 8;
    }

    // Scalar tail for the last (< 8) bars.
    if i < n {
        let mut prev_c = *close.get_unchecked(i - 1);
        while i < n {
            let tr = true_range(*high.get_unchecked(i), *low.get_unchecked(i), prev_c);
            rma = (-alpha).mul_add(rma, rma) + alpha * tr;
            *out.get_unchecked_mut(i) = rma;

            prev_c = *close.get_unchecked(i);
            i += 1;
        }
    }
}
732
/// Runs the AVX2 ATR kernel over the whole series (seed window starts at
/// bar 0).
///
/// # Panics
/// Panics if the four slices differ in length or if `length` is not in
/// `1..=close.len()`. The kernel uses unchecked indexing, so these
/// preconditions are enforced here rather than left to the caller.
///
/// NOTE(review): CPU support for AVX2 is not verified by this function; the
/// caller is expected to have selected this kernel appropriately.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn atr_avx2(high: &[f64], low: &[f64], close: &[f64], length: usize, out: &mut [f64]) {
    let n = close.len();
    assert!(
        high.len() == n && low.len() == n && out.len() == n,
        "atr: high, low, close and out must have equal lengths"
    );
    assert!(
        length >= 1 && length <= n,
        "atr: length must be between 1 and the series length"
    );
    // SAFETY: the assertions above establish equal slice lengths and
    // `0 + length <= n`, which are the kernel's bounds preconditions.
    unsafe { atr_compute_into_avx2(high, low, close, length, 0, out) }
}
738
/// Runs the AVX-512 ATR kernel over the whole series (seed window starts at
/// bar 0).
///
/// # Panics
/// Panics if the four slices differ in length or if `length` is not in
/// `1..=close.len()`. The kernel uses unchecked indexing, so these
/// preconditions are enforced here rather than left to the caller.
///
/// NOTE(review): CPU support for AVX-512 is not verified by this function;
/// the caller is expected to have selected this kernel appropriately.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn atr_avx512(high: &[f64], low: &[f64], close: &[f64], length: usize, out: &mut [f64]) {
    let n = close.len();
    assert!(
        high.len() == n && low.len() == n && out.len() == n,
        "atr: high, low, close and out must have equal lengths"
    );
    assert!(
        length >= 1 && length <= n,
        "atr: length must be between 1 and the series length"
    );
    // SAFETY: the assertions above establish equal slice lengths and
    // `0 + length <= n`, which are the kernel's bounds preconditions.
    unsafe { atr_compute_into_avx512(high, low, close, length, 0, out) }
}
744
/// AVX-512 entry point for short windows.
///
/// Currently forwards to the common AVX-512 kernel with the seed window
/// starting at bar 0.
///
/// # Safety
/// Same contract as the AVX-512 kernel: all four slices have equal length,
/// `1 <= length <= close.len()`, and the CPU supports AVX-512.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn atr_avx512_short(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    length: usize,
    out: &mut [f64],
) {
    let first_bar = 0;
    atr_compute_into_avx512(high, low, close, length, first_bar, out)
}
/// AVX-512 entry point for long windows.
///
/// Currently forwards to the common AVX-512 kernel with the seed window
/// starting at bar 0.
///
/// # Safety
/// Same contract as the AVX-512 kernel: all four slices have equal length,
/// `1 <= length <= close.len()`, and the CPU supports AVX-512.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn atr_avx512_long(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    length: usize,
    out: &mut [f64],
) {
    let first_bar = 0;
    atr_compute_into_avx512(high, low, close, length, first_bar, out)
}
767
768#[derive(Debug, Clone)]
769pub struct AtrStream {
770    length: usize,
771    alpha: f64,
772    prev_close: f64,
773    rma: f64,
774    warm_sum: f64,
775    warm_count: usize,
776    seeded: bool,
777}
778
779impl AtrStream {
780    #[inline(always)]
781    pub fn try_new(params: AtrParams) -> Result<Self, AtrError> {
782        let length = params.length.unwrap_or(14);
783        if length == 0 {
784            return Err(AtrError::InvalidLength { length });
785        }
786        Ok(Self {
787            length,
788            alpha: 1.0 / (length as f64),
789            prev_close: f64::NAN,
790            rma: f64::NAN,
791            warm_sum: 0.0,
792            warm_count: 0,
793            seeded: false,
794        })
795    }
796
797    #[inline(always)]
798    pub fn update(&mut self, high: f64, low: f64, close: f64) -> Option<f64> {
799        debug_assert!(
800            high.is_finite() && low.is_finite() && close.is_finite(),
801            "Streaming ATR assumes finite inputs; prefilter NaNs/Infs upstream if needed",
802        );
803
804        let tr = if self.prev_close.is_nan() {
805            high - low
806        } else {
807            let up = if high > self.prev_close {
808                high
809            } else {
810                self.prev_close
811            };
812            let dn = if low < self.prev_close {
813                low
814            } else {
815                self.prev_close
816            };
817            up - dn
818        };
819
820        self.prev_close = close;
821
822        if !self.seeded {
823            self.warm_sum += tr;
824            self.warm_count += 1;
825
826            if self.warm_count == self.length {
827                self.rma = self.warm_sum * self.alpha;
828                self.seeded = true;
829                return Some(self.rma);
830            }
831            return None;
832        }
833
834        self.rma = self.alpha.mul_add(tr - self.rma, self.rma);
835        Some(self.rma)
836    }
837}
838
/// Sweep specification for batch ATR runs.
#[derive(Debug, Clone)]
pub struct AtrBatchRange {
    /// `(start, end, step)` for the window length.
    pub length: (usize, usize, usize),
}

impl Default for AtrBatchRange {
    /// Default sweep: lengths 14 through 263 in steps of 1.
    fn default() -> Self {
        let length = (14, 263, 1);
        AtrBatchRange { length }
    }
}
850#[derive(Clone, Debug, Default)]
851pub struct AtrBatchBuilder {
852    range: AtrBatchRange,
853    kernel: Kernel,
854}
855impl AtrBatchBuilder {
856    pub fn new() -> Self {
857        Self::default()
858    }
859    pub fn kernel(mut self, k: Kernel) -> Self {
860        self.kernel = k;
861        self
862    }
863    #[inline]
864    pub fn length_range(mut self, start: usize, end: usize, step: usize) -> Self {
865        self.range.length = (start, end, step);
866        self
867    }
868    #[inline]
869    pub fn length_static(mut self, p: usize) -> Self {
870        self.range.length = (p, p, 0);
871        self
872    }
873    pub fn apply_slices(
874        self,
875        high: &[f64],
876        low: &[f64],
877        close: &[f64],
878    ) -> Result<AtrBatchOutput, AtrError> {
879        atr_batch_with_kernel(high, low, close, &self.range, self.kernel)
880    }
881    pub fn apply_candles(self, c: &Candles) -> Result<AtrBatchOutput, AtrError> {
882        let high = c
883            .select_candle_field("high")
884            .map_err(|_| AtrError::NoCandlesAvailable)?;
885        let low = c
886            .select_candle_field("low")
887            .map_err(|_| AtrError::NoCandlesAvailable)?;
888        let close = c
889            .select_candle_field("close")
890            .map_err(|_| AtrError::NoCandlesAvailable)?;
891        self.apply_slices(high, low, close)
892    }
893}
894
895#[derive(Clone, Debug)]
896pub struct AtrBatchOutput {
897    pub values: Vec<f64>,
898    pub combos: Vec<AtrParams>,
899    pub rows: usize,
900    pub cols: usize,
901}
902impl AtrBatchOutput {
903    pub fn row_for_params(&self, p: &AtrParams) -> Option<usize> {
904        self.combos
905            .iter()
906            .position(|c| c.length.unwrap_or(14) == p.length.unwrap_or(14))
907    }
908    pub fn values_for(&self, p: &AtrParams) -> Option<&[f64]> {
909        self.row_for_params(p).map(|row| {
910            let start = row * self.cols;
911            &self.values[start..start + self.cols]
912        })
913    }
914}
915
916#[inline(always)]
917fn expand_grid(r: &AtrBatchRange) -> Vec<AtrParams> {
918    let (start, end, step) = r.length;
919    if step == 0 || start == end {
920        return vec![AtrParams {
921            length: Some(start),
922        }];
923    }
924    if start < end {
925        (start..=end)
926            .step_by(step)
927            .map(|l| AtrParams { length: Some(l) })
928            .collect()
929    } else {
930        let mut v: Vec<usize> = (end..=start).step_by(step).collect();
931        v.reverse();
932        v.into_iter()
933            .map(|l| AtrParams { length: Some(l) })
934            .collect()
935    }
936}
937
938pub fn atr_batch_with_kernel(
939    high: &[f64],
940    low: &[f64],
941    close: &[f64],
942    sweep: &AtrBatchRange,
943    k: Kernel,
944) -> Result<AtrBatchOutput, AtrError> {
945    let kernel = match k {
946        Kernel::Auto => detect_best_batch_kernel(),
947        other if other.is_batch() => other,
948        other => return Err(AtrError::InvalidKernelForBatch(other)),
949    };
950    let simd = match kernel {
951        Kernel::Avx512Batch => Kernel::Avx512,
952        Kernel::Avx2Batch => Kernel::Avx2,
953        Kernel::ScalarBatch => Kernel::Scalar,
954        _ => unreachable!(),
955    };
956    atr_batch_par_slice(high, low, close, sweep, simd)
957}
958
959#[inline(always)]
960pub fn atr_batch_slice(
961    high: &[f64],
962    low: &[f64],
963    close: &[f64],
964    sweep: &AtrBatchRange,
965    kern: Kernel,
966) -> Result<AtrBatchOutput, AtrError> {
967    atr_batch_inner(high, low, close, sweep, kern, false)
968}
969#[inline(always)]
970pub fn atr_batch_par_slice(
971    high: &[f64],
972    low: &[f64],
973    close: &[f64],
974    sweep: &AtrBatchRange,
975    kern: Kernel,
976) -> Result<AtrBatchOutput, AtrError> {
977    atr_batch_inner(high, low, close, sweep, kern, true)
978}
979
980fn atr_batch_inner_into(
981    high: &[f64],
982    low: &[f64],
983    close: &[f64],
984    sweep: &AtrBatchRange,
985    kern: Kernel,
986    parallel: bool,
987    out: &mut [f64],
988) -> Result<Vec<AtrParams>, AtrError> {
989    let combos = expand_grid(sweep);
990    if combos.is_empty() {
991        let (s, e, st) = sweep.length;
992        return Err(AtrError::InvalidRange {
993            start: s,
994            end: e,
995            step: st,
996        });
997    }
998    let rows = combos.len();
999    let cols = high.len();
1000    let expected = rows.checked_mul(cols).ok_or(AtrError::InvalidRange {
1001        start: sweep.length.0,
1002        end: sweep.length.1,
1003        step: sweep.length.2,
1004    })?;
1005    if out.len() != expected {
1006        return Err(AtrError::OutputLengthMismatch {
1007            expected,
1008            got: out.len(),
1009        });
1010    }
1011
1012    let first = first_valid_hlc(high, low, close);
1013    if first >= cols {
1014        return Err(AtrError::AllValuesNaN);
1015    }
1016
1017    let mut tr = AVec::<f64>::with_capacity(CACHELINE_ALIGN, cols);
1018    unsafe {
1019        tr.set_len(cols);
1020    }
1021
1022    for v in &mut tr[..] {
1023        *v = 0.0;
1024    }
1025
1026    match kern_to_simd(kern) {
1027        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1028        Kernel::Avx512 => unsafe {
1029            precompute_tr_into_avx512(high, low, close, first, &mut tr);
1030        },
1031        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1032        Kernel::Avx2 => unsafe {
1033            precompute_tr_into_avx2(high, low, close, first, &mut tr);
1034        },
1035        _ => {
1036            precompute_tr_into_scalar(high, low, close, first, &mut tr);
1037        }
1038    }
1039
1040    let mut ps = AVec::<f64>::with_capacity(CACHELINE_ALIGN, cols + 1);
1041    unsafe {
1042        ps.set_len(cols + 1);
1043    }
1044    ps[0] = 0.0;
1045
1046    for i in 0..cols {
1047        ps[i + 1] = ps[i] + tr[i];
1048    }
1049
1050    let do_row = |row: usize, dst: &mut [f64]| {
1051        let length = combos[row].length.unwrap();
1052        let warm = first + length - 1;
1053
1054        for v in &mut dst[..warm] {
1055            *v = f64::NAN;
1056        }
1057
1058        let sum_tr = ps[warm + 1] - ps[first];
1059        let mut rma = sum_tr / (length as f64);
1060        dst[warm] = rma;
1061        let alpha = 1.0 / (length as f64);
1062        let mut i = warm + 1;
1063        while i < cols {
1064            let tri = tr[i];
1065            rma = (-alpha).mul_add(rma, rma) + alpha * tri;
1066            dst[i] = rma;
1067            i += 1;
1068        }
1069    };
1070
1071    #[inline(always)]
1072    fn kern_to_simd(k: Kernel) -> Kernel {
1073        match k {
1074            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1075            Kernel::Avx512Batch => Kernel::Avx512,
1076            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1077            Kernel::Avx2Batch => Kernel::Avx2,
1078            Kernel::ScalarBatch => Kernel::Scalar,
1079            other => other,
1080        }
1081    }
1082
1083    if parallel {
1084        #[cfg(not(target_arch = "wasm32"))]
1085        out.par_chunks_mut(cols)
1086            .enumerate()
1087            .for_each(|(r, row)| do_row(r, row));
1088        #[cfg(target_arch = "wasm32")]
1089        for (r, row) in out.chunks_mut(cols).enumerate() {
1090            do_row(r, row);
1091        }
1092    } else {
1093        for (r, row) in out.chunks_mut(cols).enumerate() {
1094            do_row(r, row);
1095        }
1096    }
1097
1098    Ok(combos)
1099}
1100
1101fn atr_batch_inner(
1102    high: &[f64],
1103    low: &[f64],
1104    close: &[f64],
1105    sweep: &AtrBatchRange,
1106    kern: Kernel,
1107    parallel: bool,
1108) -> Result<AtrBatchOutput, AtrError> {
1109    let combos = expand_grid(sweep);
1110    if combos.is_empty() {
1111        let (s, e, st) = sweep.length;
1112        return Err(AtrError::InvalidRange {
1113            start: s,
1114            end: e,
1115            step: st,
1116        });
1117    }
1118    let len = close.len();
1119    let rows = combos.len();
1120    let cols = len;
1121
1122    let mut buf_mu = make_uninit_matrix(rows, cols);
1123
1124    let first_valid = first_valid_hlc(high, low, close);
1125
1126    let warm: Vec<usize> = combos
1127        .iter()
1128        .map(|c| first_valid + c.length.unwrap() - 1)
1129        .collect();
1130
1131    init_matrix_prefixes(&mut buf_mu, cols, &warm);
1132
1133    let mut buf_guard = std::mem::ManuallyDrop::new(buf_mu);
1134    let values: &mut [f64] = unsafe {
1135        std::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
1136    };
1137
1138    let mut tr = AVec::<f64>::with_capacity(CACHELINE_ALIGN, cols);
1139    unsafe {
1140        tr.set_len(cols);
1141    }
1142    for v in &mut tr[..] {
1143        *v = 0.0;
1144    }
1145    match kern {
1146        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1147        Kernel::Avx512 => unsafe {
1148            precompute_tr_into_avx512(high, low, close, first_valid, &mut tr)
1149        },
1150        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1151        Kernel::Avx2 => unsafe { precompute_tr_into_avx2(high, low, close, first_valid, &mut tr) },
1152        _ => precompute_tr_into_scalar(high, low, close, first_valid, &mut tr),
1153    }
1154    let mut ps = AVec::<f64>::with_capacity(CACHELINE_ALIGN, cols + 1);
1155    unsafe { ps.set_len(cols + 1) };
1156    ps[0] = 0.0;
1157    for i in 0..cols {
1158        ps[i + 1] = ps[i] + tr[i];
1159    }
1160
1161    let do_row = |row: usize, out_row: &mut [f64]| {
1162        let length = combos[row].length.unwrap();
1163        let warm = first_valid + length - 1;
1164
1165        let sum_tr = ps[warm + 1] - ps[first_valid];
1166        let mut rma = sum_tr / (length as f64);
1167        out_row[warm] = rma;
1168        let alpha = 1.0 / (length as f64);
1169        let mut i = warm + 1;
1170        while i < cols {
1171            let tri = tr[i];
1172            rma = (-alpha).mul_add(rma, rma) + alpha * tri;
1173            out_row[i] = rma;
1174            i += 1;
1175        }
1176    };
1177    if parallel {
1178        #[cfg(not(target_arch = "wasm32"))]
1179        {
1180            values
1181                .par_chunks_mut(cols)
1182                .enumerate()
1183                .for_each(|(row, slice)| do_row(row, slice));
1184        }
1185
1186        #[cfg(target_arch = "wasm32")]
1187        {
1188            for (row, slice) in values.chunks_mut(cols).enumerate() {
1189                do_row(row, slice);
1190            }
1191        }
1192    } else {
1193        for (row, slice) in values.chunks_mut(cols).enumerate() {
1194            do_row(row, slice);
1195        }
1196    }
1197
1198    let final_values = unsafe {
1199        Vec::from_raw_parts(
1200            buf_guard.as_mut_ptr() as *mut f64,
1201            buf_guard.len(),
1202            buf_guard.capacity(),
1203        )
1204    };
1205
1206    Ok(AtrBatchOutput {
1207        values: final_values,
1208        combos,
1209        rows,
1210        cols,
1211    })
1212}
1213
1214#[inline(always)]
1215unsafe fn atr_row_scalar(high: &[f64], low: &[f64], close: &[f64], length: usize, out: &mut [f64]) {
1216    let first = first_valid_hlc(high, low, close);
1217    atr_compute_into(high, low, close, length, first, Kernel::Scalar, out);
1218}
1219
/// Fills `out` with one ATR row using the AVX2 kernel.
///
/// Looks up the first usable bar with `first_valid_hlc` and delegates the
/// computation to `atr_compute_into` with `Kernel::Avx2`.
///
/// # Safety
/// NOTE(review): presumably the caller must ensure the CPU supports AVX2 and
/// uphold the contract of `atr_compute_into` - confirm there.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn atr_row_avx2(high: &[f64], low: &[f64], close: &[f64], length: usize, out: &mut [f64]) {
    let start = first_valid_hlc(high, low, close);
    atr_compute_into(high, low, close, length, start, Kernel::Avx2, out);
}
/// Fills `out` with one ATR row using the AVX-512 kernel.
///
/// Lengths up to and including 32 go to `atr_row_avx512_short`; longer ones go
/// to `atr_row_avx512_long`.
///
/// # Safety
/// The same requirements as the two functions this one dispatches to.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn atr_row_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    length: usize,
    out: &mut [f64],
) {
    match length {
        0..=32 => atr_row_avx512_short(high, low, close, length, out),
        _ => atr_row_avx512_long(high, low, close, length, out),
    }
}
/// AVX-512 row entry point used by `atr_row_avx512` for lengths of 32 or less.
///
/// Looks up the first usable bar with `first_valid_hlc` and delegates the
/// computation to `atr_compute_into` with `Kernel::Avx512`.
///
/// # Safety
/// NOTE(review): presumably the caller must ensure the CPU supports AVX-512
/// and uphold the contract of `atr_compute_into` - confirm there.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn atr_row_avx512_short(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    length: usize,
    out: &mut [f64],
) {
    let start = first_valid_hlc(high, low, close);
    atr_compute_into(high, low, close, length, start, Kernel::Avx512, out);
}
/// AVX-512 row entry point used by `atr_row_avx512` for lengths above 32.
///
/// Looks up the first usable bar with `first_valid_hlc` and delegates the
/// computation to `atr_compute_into` with `Kernel::Avx512`.
///
/// # Safety
/// NOTE(review): presumably the caller must ensure the CPU supports AVX-512
/// and uphold the contract of `atr_compute_into` - confirm there.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn atr_row_avx512_long(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    length: usize,
    out: &mut [f64],
) {
    let start = first_valid_hlc(high, low, close);
    atr_compute_into(high, low, close, length, start, Kernel::Avx512, out);
}
1265
/// Writes the true range of every bar from `first` onward into `tr_out`.
///
/// * `tr_out[first]` is `high[first] - low[first]` (no previous close is used).
/// * For each later index `i` below `tr_out.len()`, the value is the largest of
///   `high[i] - low[i]`, `|high[i] - close[i - 1]|` and `|low[i] - close[i - 1]|`.
///
/// Entries before `first` are left untouched, and nothing is written when
/// `first >= tr_out.len()`. The maximum is taken with `>` comparisons, so a NaN
/// high-low range is never replaced by one of the other candidates.
///
/// # Panics
/// Panics if an input slice is too short for an index that is read.
#[inline(always)]
fn precompute_tr_into_scalar(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    tr_out: &mut [f64],
) {
    let n = tr_out.len();
    if first >= n {
        return;
    }

    tr_out[first] = high[first] - low[first];

    for idx in (first + 1)..n {
        let bar_high = high[idx];
        let bar_low = low[idx];
        let prev_close = close[idx - 1];

        let mut widest = bar_high - bar_low;
        let gap_from_high = (bar_high - prev_close).abs();
        if gap_from_high > widest {
            widest = gap_from_high;
        }
        let gap_from_low = (bar_low - prev_close).abs();
        if gap_from_low > widest {
            widest = gap_from_low;
        }

        tr_out[idx] = widest;
    }
}
1296
/// AVX2 variant of the true-range precomputation.
///
/// Writes `high[first] - low[first]` at `tr_out[first]`, then for every later
/// index `i` below `tr_out.len()` the largest of `high[i] - low[i]`,
/// `|high[i] - close[i - 1]|` and `|low[i] - close[i - 1]|`. Four bars are
/// processed per vector step and the remainder with scalar code. Entries before
/// `first` are left untouched; nothing happens when `first >= tr_out.len()`.
///
/// # Panics
/// Panics if `high`, `low` or `close` is shorter than `tr_out`. The inputs are
/// re-sliced to `tr_out.len()` once up front so the unchecked accesses below
/// cannot run past a short input slice.
///
/// # Safety
/// The body uses AVX2 intrinsics; the caller must ensure the running CPU
/// supports them.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn precompute_tr_into_avx2(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    tr_out: &mut [f64],
) {
    use core::arch::x86_64::*;
    let n = tr_out.len();
    if first >= n {
        return;
    }
    // One bounds check per input; every index used below is < n.
    let (high, low, close) = (&high[..n], &low[..n], &close[..n]);

    tr_out[first] = *high.get_unchecked(first) - *low.get_unchecked(first);
    let mut i = first + 1;
    // Clearing the sign bit yields the absolute value of each lane.
    let mask_abs = _mm256_castsi256_pd(_mm256_set1_epi64x(0x7fff_ffff_ffff_ffffu64 as i64));
    // `i + 3 < n` keeps all four lanes of each load/store inside the slices;
    // `i >= first + 1 >= 1`, so `i - 1` never underflows.
    while i + 3 < n {
        let v_hi = _mm256_loadu_pd(high.as_ptr().add(i));
        let v_lo = _mm256_loadu_pd(low.as_ptr().add(i));
        let v_pc = _mm256_loadu_pd(close.as_ptr().add(i - 1));

        let v_hl = _mm256_sub_pd(v_hi, v_lo);
        let v_hc = _mm256_and_pd(_mm256_sub_pd(v_hi, v_pc), mask_abs);
        let v_lc = _mm256_and_pd(_mm256_sub_pd(v_lo, v_pc), mask_abs);
        let v_m1 = _mm256_max_pd(v_hl, v_hc);
        let v_tr = _mm256_max_pd(v_m1, v_lc);
        _mm256_storeu_pd(tr_out.as_mut_ptr().add(i), v_tr);
        i += 4;
    }
    // Scalar tail for the last (up to three) bars.
    while i < n {
        let hi = *high.get_unchecked(i);
        let lo = *low.get_unchecked(i);
        let pc = *close.get_unchecked(i - 1);
        let mut tr = hi - lo;
        let hc = (hi - pc).abs();
        if hc > tr {
            tr = hc;
        }
        let lc = (lo - pc).abs();
        if lc > tr {
            tr = lc;
        }
        *tr_out.get_unchecked_mut(i) = tr;
        i += 1;
    }
}
1344
/// AVX-512 variant of the true-range precomputation.
///
/// Writes `high[first] - low[first]` at `tr_out[first]`, then for every later
/// index `i` below `tr_out.len()` the largest of `high[i] - low[i]`,
/// `|high[i] - close[i - 1]|` and `|low[i] - close[i - 1]|`. Eight bars are
/// processed per vector step and the remainder with scalar code. Entries before
/// `first` are left untouched; nothing happens when `first >= tr_out.len()`.
///
/// # Panics
/// Panics if `high`, `low` or `close` is shorter than `tr_out`. The inputs are
/// re-sliced to `tr_out.len()` once up front so the unchecked accesses below
/// cannot run past a short input slice.
///
/// # Safety
/// The body uses AVX-512 intrinsics; the caller must ensure the running CPU
/// supports them.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn precompute_tr_into_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    tr_out: &mut [f64],
) {
    use core::arch::x86_64::*;
    let n = tr_out.len();
    if first >= n {
        return;
    }
    // One bounds check per input; every index used below is < n.
    let (high, low, close) = (&high[..n], &low[..n], &close[..n]);

    tr_out[first] = *high.get_unchecked(first) - *low.get_unchecked(first);
    let mut i = first + 1;
    // Clearing the sign bit yields the absolute value of each lane.
    let mask_abs = _mm512_castsi512_pd(_mm512_set1_epi64(0x7fff_ffff_ffff_ffffu64 as i64));
    // `i + 7 < n` keeps all eight lanes of each load/store inside the slices;
    // `i >= first + 1 >= 1`, so `i - 1` never underflows.
    while i + 7 < n {
        let v_hi = _mm512_loadu_pd(high.as_ptr().add(i));
        let v_lo = _mm512_loadu_pd(low.as_ptr().add(i));
        let v_pc = _mm512_loadu_pd(close.as_ptr().add(i - 1));
        let v_hl = _mm512_sub_pd(v_hi, v_lo);
        let v_hc = _mm512_and_pd(_mm512_sub_pd(v_hi, v_pc), mask_abs);
        let v_lc = _mm512_and_pd(_mm512_sub_pd(v_lo, v_pc), mask_abs);
        let v_m1 = _mm512_max_pd(v_hl, v_hc);
        let v_tr = _mm512_max_pd(v_m1, v_lc);
        _mm512_storeu_pd(tr_out.as_mut_ptr().add(i), v_tr);
        i += 8;
    }
    // Scalar tail for the last (up to seven) bars.
    while i < n {
        let hi = *high.get_unchecked(i);
        let lo = *low.get_unchecked(i);
        let pc = *close.get_unchecked(i - 1);
        let mut tr = hi - lo;
        let hc = (hi - pc).abs();
        if hc > tr {
            tr = hc;
        }
        let lc = (lo - pc).abs();
        if lc > tr {
            tr = lc;
        }
        *tr_out.get_unchecked_mut(i) = tr;
        i += 1;
    }
}
1391
1392#[cfg(test)]
1393mod tests {
1394    use super::*;
1395    use crate::skip_if_unsupported;
1396    use crate::utilities::data_loader::read_candles_from_csv;
1397    use crate::utilities::enums::Kernel;
1398    #[cfg(feature = "proptest")]
1399    use proptest::prelude::*;
1400
1401    fn check_atr_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1402        skip_if_unsupported!(kernel, test_name);
1403        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1404        let candles = read_candles_from_csv(file_path)?;
1405        let partial_params = AtrParams { length: None };
1406        let input_partial = AtrInput::from_candles(&candles, partial_params);
1407        let result_partial = atr_with_kernel(&input_partial, kernel)?;
1408        assert_eq!(result_partial.values.len(), candles.close.len());
1409        let zero_and_none_params = AtrParams { length: Some(14) };
1410        let input_zero_and_none = AtrInput::from_candles(&candles, zero_and_none_params);
1411        let result_zero_and_none = atr_with_kernel(&input_zero_and_none, kernel)?;
1412        assert_eq!(result_zero_and_none.values.len(), candles.close.len());
1413        Ok(())
1414    }
1415
1416    fn check_atr_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1417        skip_if_unsupported!(kernel, test_name);
1418        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1419        let candles = read_candles_from_csv(file_path)?;
1420        let input = AtrInput::with_default_candles(&candles);
1421        let result = atr_with_kernel(&input, kernel)?;
1422        let expected_last_five = [916.89, 874.33, 838.45, 801.92, 811.57];
1423        assert!(result.values.len() >= 5, "Not enough ATR values");
1424        assert_eq!(
1425            result.values.len(),
1426            candles.close.len(),
1427            "ATR output length does not match input length!"
1428        );
1429        let start_index = result.values.len().saturating_sub(5);
1430        let last_five = &result.values[start_index..];
1431        for (i, &value) in last_five.iter().enumerate() {
1432            assert!(
1433                (value - expected_last_five[i]).abs() < 1e-2,
1434                "ATR value mismatch at index {}: expected {}, got {}",
1435                i,
1436                expected_last_five[i],
1437                value
1438            );
1439        }
1440        let length = 14;
1441        for val in result.values.iter().skip(length - 1) {
1442            if !val.is_nan() {
1443                assert!(
1444                    val.is_finite(),
1445                    "ATR output should be finite after RMA stabilizes"
1446                );
1447            }
1448        }
1449        Ok(())
1450    }
1451
1452    fn check_atr_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1453        skip_if_unsupported!(kernel, test_name);
1454        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1455        let candles = read_candles_from_csv(file_path)?;
1456        let input = AtrInput::with_default_candles(&candles);
1457        match input.data {
1458            AtrData::Candles { .. } => {}
1459            _ => panic!("Expected AtrData::Candles variant"),
1460        }
1461        let default_params = AtrParams::default();
1462        assert_eq!(input.params.length, default_params.length);
1463        let output = atr_with_kernel(&input, kernel)?;
1464        assert_eq!(output.values.len(), candles.close.len());
1465        Ok(())
1466    }
1467
1468    fn check_atr_zero_length(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1469        skip_if_unsupported!(kernel, test_name);
1470        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1471        let candles = read_candles_from_csv(file_path)?;
1472        let zero_length_params = AtrParams { length: Some(0) };
1473        let input_zero_length = AtrInput::from_candles(&candles, zero_length_params);
1474        let result_zero_length = atr_with_kernel(&input_zero_length, kernel);
1475        assert!(result_zero_length.is_err());
1476        Ok(())
1477    }
1478
1479    fn check_atr_length_exceeding_data_length(
1480        test_name: &str,
1481        kernel: Kernel,
1482    ) -> Result<(), Box<dyn Error>> {
1483        skip_if_unsupported!(kernel, test_name);
1484        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1485        let candles = read_candles_from_csv(file_path)?;
1486        let too_long_params = AtrParams {
1487            length: Some(candles.close.len() + 10),
1488        };
1489        let input_too_long = AtrInput::from_candles(&candles, too_long_params);
1490        let result_too_long = atr_with_kernel(&input_too_long, kernel);
1491        assert!(result_too_long.is_err());
1492        Ok(())
1493    }
1494
1495    fn check_atr_very_small_data_set(
1496        test_name: &str,
1497        kernel: Kernel,
1498    ) -> Result<(), Box<dyn Error>> {
1499        skip_if_unsupported!(kernel, test_name);
1500        let high = [10.0];
1501        let low = [5.0];
1502        let close = [7.0];
1503        let params = AtrParams { length: Some(14) };
1504        let input = AtrInput::from_slices(&high, &low, &close, params);
1505        let result = atr_with_kernel(&input, kernel);
1506        assert!(result.is_err());
1507        Ok(())
1508    }
1509
/// `atr_into` must write the same series that `atr` returns: finite values
/// within 1e-12 of each other, non-finite values bit-for-bit identical.
#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
#[test]
fn test_atr_into_matches_api() -> Result<(), Box<dyn Error>> {
    fn values_agree(lhs: f64, rhs: f64) -> bool {
        if lhs.is_finite() && rhs.is_finite() {
            (lhs - rhs).abs() <= 1e-12
        } else {
            lhs.to_bits() == rhs.to_bits()
        }
    }

    let candles = read_candles_from_csv("src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv")?;
    let input = AtrInput::with_default_candles(&candles);

    let baseline = atr(&input)?;

    let mut out = vec![0.0f64; candles.close.len()];
    atr_into(&input, &mut out)?;

    assert_eq!(baseline.values.len(), out.len());

    for (i, (&api, &into)) in baseline.values.iter().zip(out.iter()).enumerate() {
        assert!(
            values_agree(api, into),
            "Mismatch at {}: api={} into={}",
            i,
            api,
            into
        );
    }
    Ok(())
}
1543
1544    fn check_atr_with_slice_data_reinput(
1545        test_name: &str,
1546        kernel: Kernel,
1547    ) -> Result<(), Box<dyn Error>> {
1548        skip_if_unsupported!(kernel, test_name);
1549        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1550        let candles = read_candles_from_csv(file_path)?;
1551        let first_params = AtrParams { length: Some(14) };
1552        let first_input = AtrInput::from_candles(&candles, first_params);
1553        let first_result = atr_with_kernel(&first_input, kernel)?;
1554        assert_eq!(first_result.values.len(), candles.close.len());
1555        let second_params = AtrParams { length: Some(5) };
1556        let second_input = AtrInput::from_slices(
1557            &first_result.values,
1558            &first_result.values,
1559            &first_result.values,
1560            second_params,
1561        );
1562        let second_result = atr_with_kernel(&second_input, kernel)?;
1563        assert_eq!(second_result.values.len(), first_result.values.len());
1564        Ok(())
1565    }
1566
1567    fn check_atr_accuracy_nan_check(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1568        skip_if_unsupported!(kernel, test_name);
1569        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1570        let candles = read_candles_from_csv(file_path)?;
1571        let params = AtrParams { length: Some(14) };
1572        let input = AtrInput::from_candles(&candles, params);
1573        let result = atr_with_kernel(&input, kernel)?;
1574        assert_eq!(result.values.len(), candles.close.len());
1575        if result.values.len() > 240 {
1576            for i in 240..result.values.len() {
1577                assert!(!result.values[i].is_nan());
1578            }
1579        }
1580        Ok(())
1581    }
1582
1583    #[cfg(debug_assertions)]
1584    fn check_atr_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1585        skip_if_unsupported!(kernel, test_name);
1586
1587        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1588        let candles = read_candles_from_csv(file_path)?;
1589
1590        let test_lengths = vec![2, 5, 10, 14, 20, 50, 100, 200];
1591
1592        for length in test_lengths {
1593            let params = AtrParams {
1594                length: Some(length),
1595            };
1596            let input = AtrInput::from_candles(&candles, params);
1597            let output = atr_with_kernel(&input, kernel)?;
1598
1599            for (i, &val) in output.values.iter().enumerate() {
1600                if val.is_nan() {
1601                    continue;
1602                }
1603
1604                let bits = val.to_bits();
1605
1606                if bits == 0x11111111_11111111 {
1607                    panic!(
1608						"[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} with length={}",
1609						test_name, val, bits, i, length
1610					);
1611                }
1612
1613                if bits == 0x22222222_22222222 {
1614                    panic!(
1615						"[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} with length={}",
1616						test_name, val, bits, i, length
1617					);
1618                }
1619
1620                if bits == 0x33333333_33333333 {
1621                    panic!(
1622						"[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} with length={}",
1623						test_name, val, bits, i, length
1624					);
1625                }
1626            }
1627        }
1628
1629        Ok(())
1630    }
1631
1632    #[cfg(not(debug_assertions))]
1633    fn check_atr_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1634        Ok(())
1635    }
1636
/// Property test for ATR on random but well-formed bars (`low <= close <= high`).
///
/// For lengths 2..=50 and series shorter than 400 bars it checks:
/// output length, NaN warm-up, non-negativity, the upper bound given by the
/// largest true range, agreement with the scalar kernel, convergence to zero
/// on constant prices, the per-step change bound of the smoothing, and
/// (for length 1) equality with the true range.
#[cfg(feature = "proptest")]
#[allow(clippy::float_cmp)]
fn check_atr_property(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
    use proptest::prelude::*;
    skip_if_unsupported!(kernel, test_name);

    let strat = (2usize..=50)
        .prop_flat_map(|length| {
            (length..400).prop_flat_map(move |data_len| {
                (
                    prop::collection::vec(
                        (10.0f64..10000.0f64).prop_filter("finite", |x| x.is_finite()),
                        data_len,
                    ),
                    prop::collection::vec(
                        (10.0f64..10000.0f64).prop_filter("finite", |x| x.is_finite()),
                        data_len,
                    ),
                    prop::collection::vec(
                        (10.0f64..10000.0f64).prop_filter("finite", |x| x.is_finite()),
                        data_len,
                    ),
                    Just(length),
                )
            })
        })
        .prop_map(|(high_raw, low_raw, close_raw, length)| {
            let len = high_raw.len();
            assert_eq!(low_raw.len(), len);
            assert_eq!(close_raw.len(), len);

            let mut high = Vec::with_capacity(len);
            let mut low = Vec::with_capacity(len);
            let mut close = Vec::with_capacity(len);

            // Order the two random prices into high/low and clamp the close
            // between them so every generated bar is well-formed.
            for ((&a, &b), &raw_close) in high_raw.iter().zip(&low_raw).zip(&close_raw) {
                let h = a.max(b);
                let l = a.min(b);
                high.push(h);
                low.push(l);
                close.push(raw_close.max(l).min(h));
            }

            (high, low, close, length)
        });

    proptest::test_runner::TestRunner::default().run(
        &strat,
        |(high, low, close, length)| {
            let input = AtrInput::from_slices(
                &high,
                &low,
                &close,
                AtrParams {
                    length: Some(length),
                },
            );

            let AtrOutput { values: out } = atr_with_kernel(&input, kernel)?;
            let AtrOutput { values: ref_out } = atr_with_kernel(&input, Kernel::Scalar)?;

            prop_assert_eq!(out.len(), high.len(), "Output length mismatch");

            let warmup = length - 1;

            // Reference true range of bar `i`; bar 0 has no previous close.
            let true_range = |i: usize| -> f64 {
                if i == 0 {
                    high[0] - low[0]
                } else {
                    let hl = high[i] - low[i];
                    let hc = (high[i] - close[i - 1]).abs();
                    let lc = (low[i] - close[i - 1]).abs();
                    hl.max(hc).max(lc)
                }
            };

            for (i, value) in out.iter().enumerate().take(warmup) {
                prop_assert!(
                    value.is_nan(),
                    "Expected NaN during warmup at index {}, got {}",
                    i,
                    value
                );
            }

            for (i, &val) in out.iter().enumerate().skip(warmup) {
                if val.is_nan() {
                    continue;
                }
                prop_assert!(
                    val >= 0.0,
                    "ATR must be non-negative at index {}: got {}",
                    i,
                    val
                );
            }

            let max_true_range = (0..high.len())
                .map(|i| true_range(i))
                .fold(0.0f64, f64::max);

            for (i, &val) in out.iter().enumerate().skip(warmup) {
                if val.is_finite() {
                    prop_assert!(
                        val <= max_true_range + 1e-9,
                        "ATR at index {} exceeds max true range: {} > {}",
                        i,
                        val,
                        max_true_range
                    );
                }
            }

            for (i, &y) in out.iter().enumerate() {
                let r = ref_out[i];

                if !(y.is_finite() && r.is_finite()) {
                    prop_assert_eq!(
                        y.to_bits(),
                        r.to_bits(),
                        "NaN/infinite mismatch at index {}: {} vs {}",
                        i,
                        y,
                        r
                    );
                    continue;
                }

                let ulp_diff: u64 = y.to_bits().abs_diff(r.to_bits());

                prop_assert!(
                    (y - r).abs() <= 1e-9 || ulp_diff <= 4,
                    "Kernel mismatch at index {}: {} vs {} (ULP={})",
                    i,
                    y,
                    r,
                    ulp_diff
                );
            }

            let first_price = high[0];
            let near_first = |p: &f64| (*p - first_price).abs() < 1e-10;
            let is_constant = high.iter().all(&near_first)
                && low.iter().all(&near_first)
                && close.iter().all(&near_first);

            if is_constant && out.len() >= length * 3 {
                for &val in &out[out.len().saturating_sub(5)..] {
                    if val.is_finite() {
                        prop_assert!(
                            val < 1e-6,
                            "ATR should converge to 0 for constant prices, got {}",
                            val
                        );
                    }
                }
            }

            if out.len() >= length + 10 {
                for i in (length + 1)..out.len() {
                    let (previous, current) = (out[i - 1], out[i]);
                    if current.is_nan() || previous.is_nan() {
                        continue;
                    }

                    // One smoothing step moves the value by at most
                    // |TR - previous| / length.
                    let expected_change_bound = (true_range(i) - previous).abs() / length as f64;
                    let actual_change = (current - previous).abs();

                    prop_assert!(
                        actual_change <= expected_change_bound + 1e-9,
                        "ATR change at index {} exceeds RMA bound: {} > {}",
                        i,
                        actual_change,
                        expected_change_bound
                    );
                }
            }

            if length == 1 {
                for (i, &val) in out.iter().enumerate() {
                    if val.is_nan() {
                        continue;
                    }
                    let tr = true_range(i);
                    prop_assert!(
                        (val - tr).abs() <= 1e-9,
                        "Length=1 ATR should equal TR at index {}: {} vs {}",
                        i,
                        val,
                        tr
                    );
                }
            }

            Ok(())
        },
    )?;

    Ok(())
}
1846
1847    macro_rules! generate_all_atr_tests {
1848        ($($test_fn:ident),*) => {
1849            paste::paste! {
1850                $(
1851                    #[test]
1852                    fn [<$test_fn _scalar_f64>]() {
1853                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1854                    }
1855                )*
1856                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1857                $(
1858                    #[test]
1859                    fn [<$test_fn _avx2_f64>]() {
1860                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1861                    }
1862                    #[test]
1863                    fn [<$test_fn _avx512_f64>]() {
1864                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1865                    }
1866                )*
1867            }
1868        }
1869    }
1870
1871    generate_all_atr_tests!(
1872        check_atr_partial_params,
1873        check_atr_accuracy,
1874        check_atr_default_candles,
1875        check_atr_zero_length,
1876        check_atr_length_exceeding_data_length,
1877        check_atr_very_small_data_set,
1878        check_atr_with_slice_data_reinput,
1879        check_atr_accuracy_nan_check,
1880        check_atr_no_poison
1881    );
1882
1883    #[cfg(feature = "proptest")]
1884    generate_all_atr_tests!(check_atr_property);
1885
1886    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1887        skip_if_unsupported!(kernel, test);
1888
1889        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1890        let c = read_candles_from_csv(file)?;
1891        let output = AtrBatchBuilder::new().kernel(kernel).apply_candles(&c)?;
1892
1893        let def = AtrParams::default();
1894        let row = output.values_for(&def).expect("default row missing");
1895
1896        assert_eq!(row.len(), c.close.len());
1897
1898        let expected = [916.89, 874.33, 838.45, 801.92, 811.57];
1899        let start = row.len() - 5;
1900        for (i, &v) in row[start..].iter().enumerate() {
1901            assert!(
1902                (v - expected[i]).abs() < 1e-2,
1903                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
1904            );
1905        }
1906        Ok(())
1907    }
1908
1909    #[cfg(debug_assertions)]
1910    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1911        skip_if_unsupported!(kernel, test);
1912
1913        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1914        let c = read_candles_from_csv(file)?;
1915
1916        let test_configs = vec![
1917            (2, 10, 1),
1918            (5, 25, 5),
1919            (10, 50, 10),
1920            (14, 140, 14),
1921            (50, 200, 50),
1922            (100, 100, 0),
1923        ];
1924
1925        for (start, end, step) in test_configs {
1926            let output = AtrBatchBuilder::new()
1927                .kernel(kernel)
1928                .length_range(start, end, step)
1929                .apply_candles(&c)?;
1930
1931            for (idx, &val) in output.values.iter().enumerate() {
1932                if val.is_nan() {
1933                    continue;
1934                }
1935
1936                let bits = val.to_bits();
1937                let row = idx / output.cols;
1938                let col = idx % output.cols;
1939                let length = output.combos[row].length.unwrap_or(14);
1940
1941                if bits == 0x11111111_11111111 {
1942                    panic!(
1943                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at row {} col {} (flat index {}) with length={} in range ({},{},{})",
1944                        test, val, bits, row, col, idx, length, start, end, step
1945                    );
1946                }
1947
1948                if bits == 0x22222222_22222222 {
1949                    panic!(
1950                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at row {} col {} (flat index {}) with length={} in range ({},{},{})",
1951                        test, val, bits, row, col, idx, length, start, end, step
1952                    );
1953                }
1954
1955                if bits == 0x33333333_33333333 {
1956                    panic!(
1957                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at row {} col {} (flat index {}) with length={} in range ({},{},{})",
1958                        test, val, bits, row, col, idx, length, start, end, step
1959                    );
1960                }
1961            }
1962        }
1963
1964        Ok(())
1965    }
1966
1967    #[cfg(not(debug_assertions))]
1968    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1969        Ok(())
1970    }
1971
1972    macro_rules! gen_batch_tests {
1973        ($fn_name:ident) => {
1974            paste::paste! {
1975                #[test] fn [<$fn_name _scalar>]()      {
1976                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
1977                }
1978                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1979                #[test] fn [<$fn_name _avx2>]()        {
1980                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
1981                }
1982                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1983                #[test] fn [<$fn_name _avx512>]()      {
1984                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
1985                }
1986                #[test] fn [<$fn_name _auto_detect>]() {
1987                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
1988                }
1989            }
1990        };
1991    }
1992    gen_batch_tests!(check_batch_default_row);
1993    gen_batch_tests!(check_batch_no_poison);
1994}
1995
1996#[cfg(feature = "python")]
1997use pyo3::create_exception;
1998
// Python exception types for the `atr` module, one per `AtrError` variant.
// All of them derive from `ValueError` (`PyValueError`), so a generic
// `except ValueError` on the Python side catches every one of them.
// Declarations are independent of each other and listed alphabetically.
#[cfg(feature = "python")]
create_exception!(atr, AllValuesNaNError, PyValueError);
#[cfg(feature = "python")]
create_exception!(atr, EmptyInputDataError, PyValueError);
#[cfg(feature = "python")]
create_exception!(atr, InconsistentSliceLengthsError, PyValueError);
#[cfg(feature = "python")]
create_exception!(atr, InvalidKernelForBatchError, PyValueError);
#[cfg(feature = "python")]
create_exception!(atr, InvalidLengthError, PyValueError);
#[cfg(feature = "python")]
create_exception!(atr, InvalidPeriodError, PyValueError);
#[cfg(feature = "python")]
create_exception!(atr, InvalidRangeError, PyValueError);
#[cfg(feature = "python")]
create_exception!(atr, NoCandlesAvailableError, PyValueError);
#[cfg(feature = "python")]
create_exception!(atr, NotEnoughDataError, PyValueError);
#[cfg(feature = "python")]
create_exception!(atr, NotEnoughValidDataError, PyValueError);
#[cfg(feature = "python")]
create_exception!(atr, OutputLengthMismatchError, PyValueError);
2021
/// Maps each `AtrError` variant onto its dedicated Python exception type,
/// formatting the variant's payload into the exception message.
#[cfg(feature = "python")]
impl From<AtrError> for PyErr {
    fn from(err: AtrError) -> PyErr {
        match err {
            // Variants without payload carry a fixed message.
            AtrError::EmptyInputData => {
                EmptyInputDataError::new_err("atr: Input data slice is empty.")
            }
            AtrError::AllValuesNaN => AllValuesNaNError::new_err("atr: All values are NaN."),
            AtrError::NoCandlesAvailable => {
                NoCandlesAvailableError::new_err("No candles available for ATR calculation.")
            }

            // Variants with payload: build the message first, then wrap it.
            AtrError::InvalidPeriod { period, data_len } => {
                let msg = format!(
                    "atr: Invalid period: period = {}, data length = {}",
                    period, data_len
                );
                InvalidPeriodError::new_err(msg)
            }
            AtrError::NotEnoughValidData { needed, valid } => {
                let msg = format!(
                    "atr: Not enough valid data: needed = {}, valid = {}",
                    needed, valid
                );
                NotEnoughValidDataError::new_err(msg)
            }
            AtrError::OutputLengthMismatch { expected, got } => {
                let msg = format!(
                    "atr: Output slice length mismatch: expected = {}, got = {}",
                    expected, got
                );
                OutputLengthMismatchError::new_err(msg)
            }
            AtrError::InvalidRange { start, end, step } => {
                let msg = format!(
                    "atr: Invalid range: start = {}, end = {}, step = {}",
                    start, end, step
                );
                InvalidRangeError::new_err(msg)
            }
            AtrError::InvalidKernelForBatch(kernel) => {
                let msg = format!(
                    "atr: Invalid kernel type for batch operation: {:?}",
                    kernel
                );
                InvalidKernelForBatchError::new_err(msg)
            }
            AtrError::InvalidLength { length } => {
                let msg = format!("Invalid length for ATR calculation (length={}).", length);
                InvalidLengthError::new_err(msg)
            }
            AtrError::InconsistentSliceLengths {
                high_len: h,
                low_len: l,
                close_len: c,
            } => {
                let msg = format!(
                    "Inconsistent slice lengths for ATR calculation: high={}, low={}, close={}",
                    h, l, c
                );
                InconsistentSliceLengthsError::new_err(msg)
            }
            AtrError::NotEnoughData { length, data_len } => {
                let msg = format!(
                    "Not enough data to calculate ATR: length={}, data length={}",
                    length, data_len
                );
                NotEnoughDataError::new_err(msg)
            }
        }
    }
}
2076
2077#[cfg(all(feature = "python", feature = "cuda"))]
2078use crate::cuda::atr_wrapper::DeviceArrayF32Atr;
2079#[cfg(all(feature = "python", feature = "cuda"))]
2080use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
2081
// Python-visible handle around a device-side f32 matrix produced by the ATR CUDA
// wrappers. Exposes `__cuda_array_interface__` and the DLPack protocol methods.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct DeviceArrayF32Py {
    pub(crate) inner: DeviceArrayF32Atr,
}

#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl DeviceArrayF32Py {
    // Builds the interface dict: (rows, cols) shape, little-endian f32 typestr,
    // byte strides for a row-major layout, the device pointer, and version 3.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let elem_bytes = std::mem::size_of::<f32>();
        let (rows, cols) = (self.inner.rows, self.inner.cols);

        let iface = PyDict::new(py);
        iface.set_item("shape", (rows, cols))?;
        iface.set_item("typestr", "<f4")?;
        iface.set_item("strides", (cols * elem_bytes, elem_bytes))?;
        // The `false` flag marks the buffer as not read-only.
        iface.set_item("data", (self.inner.device_ptr() as usize, false))?;

        iface.set_item("version", 3)?;
        Ok(iface)
    }

    // Returns (device_type, device_id); the device type is the constant 2
    // (presumably DLPack's kDLCUDA - confirm against the DLPack header).
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.inner.device_id as i32)
    }

    // Hands the device buffer over to a DLPack capsule.
    //
    // - A `dl_device` that parses as `(i32, i32)` and differs from this buffer's
    //   device is rejected; an unparsable `dl_device` is ignored.
    // - `copy=True` is rejected; a non-bool `copy` propagates the extraction error.
    // - `stream` is accepted but unused.
    //
    // After a successful call `self.inner` holds an empty 0x0 placeholder buffer;
    // the original buffer is moved into the exporter.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<PyObject>,
        max_version: Option<PyObject>,
        dl_device: Option<PyObject>,
        copy: Option<PyObject>,
    ) -> PyResult<PyObject> {
        use cust::memory::DeviceBuffer;

        let (kdl, alloc_dev) = self.__dlpack_device__();

        let requested_device = dl_device
            .as_ref()
            .and_then(|dev| dev.extract::<(i32, i32)>(py).ok());
        if let Some((dev_ty, dev_id)) = requested_device {
            if !(dev_ty == kdl && dev_id == alloc_dev) {
                // Here a malformed `copy` argument is treated as "no copy requested".
                let wants_copy = matches!(
                    copy.as_ref().map(|flag| flag.extract::<bool>(py)),
                    Some(Ok(true))
                );
                let reason = if wants_copy {
                    "device copy not implemented for __dlpack__"
                } else {
                    "__dlpack__: requested device does not match producer buffer"
                };
                return Err(PyBufferError::new_err(reason));
            }
        }
        let _ = stream;

        let copy_requested = match copy.as_ref() {
            Some(flag) => flag.extract::<bool>(py)?,
            None => false,
        };
        if copy_requested {
            return Err(PyBufferError::new_err(
                "__dlpack__(copy=True) not supported for atr CUDA buffers",
            ));
        }

        // Swap the live buffer out for an empty one so ownership can move to the capsule.
        let empty =
            DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let (rows, cols) = (self.inner.rows, self.inner.cols);
        let placeholder = DeviceArrayF32Atr {
            buf: empty,
            rows: 0,
            cols: 0,
            ctx: self.inner.ctx.clone(),
            device_id: self.inner.device_id,
        };
        let taken = std::mem::replace(&mut self.inner, placeholder);

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, taken.buf, rows, cols, alloc_dev, max_version_bound)
    }
}
2177
2178#[inline(always)]
2179fn atr_prepare_from_input<'a>(
2180    input: &'a AtrInput,
2181) -> Result<(&'a [f64], &'a [f64], &'a [f64], usize, usize), AtrError> {
2182    let (high, low, close) = match &input.data {
2183        AtrData::Candles { candles } => (&candles.high[..], &candles.low[..], &candles.close[..]),
2184        AtrData::Slices { high, low, close } => (*high, *low, *close),
2185    };
2186
2187    let length = input.params.length.unwrap_or(14);
2188    let (high, low, close, length) = atr_prepare(high, low, close, length)?;
2189    let warmup = length - 1;
2190    Ok((high, low, close, length, warmup))
2191}
2192
2193#[inline(always)]
2194fn atr_prepare<'a>(
2195    high: &'a [f64],
2196    low: &'a [f64],
2197    close: &'a [f64],
2198    length: usize,
2199) -> Result<(&'a [f64], &'a [f64], &'a [f64], usize), AtrError> {
2200    if high.len() != low.len() || low.len() != close.len() {
2201        return Err(AtrError::InconsistentSliceLengths {
2202            high_len: high.len(),
2203            low_len: low.len(),
2204            close_len: close.len(),
2205        });
2206    }
2207
2208    if close.is_empty() {
2209        return Err(AtrError::NoCandlesAvailable);
2210    }
2211
2212    if length == 0 {
2213        return Err(AtrError::InvalidLength { length });
2214    }
2215
2216    if length > close.len() {
2217        return Err(AtrError::NotEnoughData {
2218            length,
2219            data_len: close.len(),
2220        });
2221    }
2222
2223    Ok((high, low, close, length))
2224}
2225
// Python entry point `atr(high, low, close, length=14, kernel=None)`.
// Validates the kernel name, borrows the three NumPy arrays as slices, runs the
// computation with the GIL released and returns the result as a new 1-D array.
// Any `AtrError` is surfaced as a `ValueError` carrying the error's display text.
#[cfg(feature = "python")]
#[pyfunction(name = "atr")]
#[pyo3(signature = (high, low, close, length=14, kernel=None))]
pub fn atr_py<'py>(
    py: Python<'py>,
    high: numpy::PyReadonlyArray1<'py, f64>,
    low: numpy::PyReadonlyArray1<'py, f64>,
    close: numpy::PyReadonlyArray1<'py, f64>,
    length: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let chosen_kernel = validate_kernel(kernel, false)?;

    // `as_slice` fails for non-contiguous arrays; that error is propagated as-is.
    let (h, l, c) = (high.as_slice()?, low.as_slice()?, close.as_slice()?);
    let input = AtrInput::from_slices(
        h,
        l,
        c,
        AtrParams {
            length: Some(length),
        },
    );

    let computed = py.allow_threads(|| {
        atr_with_kernel(&input, chosen_kernel).map(|output| output.values)
    });

    match computed {
        Ok(values) => Ok(values.into_pyarray(py)),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
2256
// Python wrapper (`AtrStream`) around the incremental `AtrStream` calculator.
#[cfg(feature = "python")]
#[pyclass(name = "AtrStream")]
pub struct AtrStreamPy {
    stream: AtrStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl AtrStreamPy {
    // Creates the stream; construction errors are converted through `PyErr::from`.
    #[new]
    pub fn new(length: Option<usize>) -> PyResult<Self> {
        AtrStream::try_new(AtrParams { length })
            .map(|stream| Self { stream })
            .map_err(PyErr::from)
    }

    // Feeds one bar into the stream and forwards whatever the stream returns.
    pub fn update(&mut self, high: f64, low: f64, close: f64) -> Option<f64> {
        let stream = &mut self.stream;
        stream.update(high, low, close)
    }
}
2277
// Python entry point `atr_batch(high, low, close, length_range, kernel=None)`.
// Computes one ATR row per length in `length_range = (start, end, step)` and
// returns a dict with
//   "values":  2-D array of shape (rows, cols), written in place by the batch kernel,
//   "lengths": 1-D array holding the length used for each row.
#[cfg(feature = "python")]
#[pyfunction(name = "atr_batch")]
#[pyo3(signature = (high, low, close, length_range, kernel=None))]
pub fn atr_batch_py<'py>(
    py: Python<'py>,
    high: numpy::PyReadonlyArray1<'py, f64>,
    low: numpy::PyReadonlyArray1<'py, f64>,
    close: numpy::PyReadonlyArray1<'py, f64>,
    length_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let k = validate_kernel(kernel, true)?;
    let (hs, ls, cs) = (high.as_slice()?, low.as_slice()?, close.as_slice()?);

    let range = AtrBatchRange {
        length: length_range,
    };
    let combos = expand_grid(&range);
    let (rows, cols) = (combos.len(), cs.len());
    let total = match rows.checked_mul(cols) {
        Some(n) => n,
        None => return Err(PyValueError::new_err("atr_batch: rows*cols overflow")),
    };

    // The flat output array is allocated uninitialised and handed to the batch
    // routine as a mutable slice; it is reshaped to (rows, cols) afterwards.
    let out_arr = unsafe { numpy::PyArray1::<f64>::new(py, [total], false) };
    let buf = unsafe { out_arr.as_slice_mut()? };

    let outcome = py.allow_threads(|| -> Result<(), AtrError> {
        // Step 1: normalise whatever was requested to a batch kernel.
        let batch_kernel = match k {
            Kernel::Auto => detect_best_batch_kernel(),
            other if other.is_batch() => other,
            Kernel::Scalar => Kernel::ScalarBatch,
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx2 => Kernel::Avx2Batch,
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx512 => Kernel::Avx512Batch,
            _ => Kernel::ScalarBatch,
        };
        // Step 2: map the batch kernel to the per-row kernel the inner routine takes.
        let simd = match batch_kernel {
            Kernel::Avx512Batch => Kernel::Avx512,
            Kernel::Avx2Batch => Kernel::Avx2,
            Kernel::ScalarBatch => Kernel::Scalar,
            _ => unreachable!(),
        };
        atr_batch_inner_into(hs, ls, cs, &range, simd, true, buf).map(|_| ())
    });
    if let Err(e) = outcome {
        return Err(PyValueError::new_err(e.to_string()));
    }

    let row_lengths: Vec<_> = combos.iter().map(|p| p.length.unwrap()).collect();

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item("lengths", row_lengths.into_pyarray(py))?;
    Ok(dict)
}
2343
// Python entry point `atr_cuda_batch_dev(high, low, close, length_range, device_id=0)`.
// Runs the length sweep on the given CUDA device and returns the device-resident
// result wrapped in `DeviceArrayF32Py`. Every failure is reported as `ValueError`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "atr_cuda_batch_dev")]
#[pyo3(signature = (high, low, close, length_range, device_id=0))]
pub fn atr_cuda_batch_dev_py(
    py: Python<'_>,
    high: numpy::PyReadonlyArray1<'_, f32>,
    low: numpy::PyReadonlyArray1<'_, f32>,
    close: numpy::PyReadonlyArray1<'_, f32>,
    length_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let (hs, ls, cs) = (high.as_slice()?, low.as_slice()?, close.as_slice()?);
    let same_len = hs.len() == ls.len() && ls.len() == cs.len();
    if !same_len {
        return Err(PyValueError::new_err("input length mismatch"));
    }

    let sweep = AtrBatchRange {
        length: length_range,
    };

    // Device setup and kernel launch happen with the GIL released.
    let inner = py.allow_threads(|| {
        CudaAtr::new(device_id)
            .map_err(|e| PyValueError::new_err(e.to_string()))
            .and_then(|cuda| {
                cuda.atr_batch_dev(hs, ls, cs, &sweep)
                    .map_err(|e| PyValueError::new_err(e.to_string()))
            })
    })?;

    Ok(DeviceArrayF32Py { inner })
}
2374
// Python entry point
// `atr_cuda_many_series_one_param_dev(high_tm, low_tm, close_tm, cols, rows, length, device_id=0)`.
// Each input is a flat array that must contain exactly `cols * rows` elements
// (named "time-major" by the parameter names and error text). One `length` is
// applied to all series on the given CUDA device; the result stays on the device.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "atr_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm, low_tm, close_tm, cols, rows, length, device_id=0))]
pub fn atr_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    high_tm: numpy::PyReadonlyArray1<'_, f32>,
    low_tm: numpy::PyReadonlyArray1<'_, f32>,
    close_tm: numpy::PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    length: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let (h, l, c) = (high_tm.as_slice()?, low_tm.as_slice()?, close_tm.as_slice()?);

    let expected = match cols.checked_mul(rows) {
        Some(n) => n,
        None => return Err(PyValueError::new_err("rows*cols overflow")),
    };
    let all_match = h.len() == expected && l.len() == expected && c.len() == expected;
    if !all_match {
        return Err(PyValueError::new_err("time-major input length mismatch"));
    }

    // Device setup and kernel launch happen with the GIL released.
    let inner = py.allow_threads(|| {
        CudaAtr::new(device_id)
            .map_err(|e| PyValueError::new_err(e.to_string()))
            .and_then(|cuda| {
                cuda.atr_many_series_one_param_time_major_dev(h, l, c, cols, rows, length)
                    .map_err(|e| PyValueError::new_err(e.to_string()))
            })
    })?;

    Ok(DeviceArrayF32Py { inner })
}
2407
/// Computes ATR for `input` directly into `dst` (wasm builds only).
///
/// The first `first + length - 1` entries of `dst` are set to NaN, where `first`
/// is the index returned by `first_valid_hlc`; the remaining entries are filled by
/// `atr_compute_into`. `Kernel::Auto` is resolved to `Kernel::Scalar`.
///
/// # Errors
/// Any error from [`atr_prepare`], then `NotEnoughValidData` when fewer than
/// `length` samples remain from `first` onwards, then `OutputLengthMismatch`
/// when `dst` and the input differ in length.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
pub fn atr_into_slice(dst: &mut [f64], input: &AtrInput, kern: Kernel) -> Result<(), AtrError> {
    let requested = input.params.length.unwrap_or(14);
    let (high, low, close) = match &input.data {
        AtrData::Slices { high, low, close } => (*high, *low, *close),
        AtrData::Candles { candles } => (&candles.high[..], &candles.low[..], &candles.close[..]),
    };
    let (high, low, close, length) = atr_prepare(high, low, close, requested)?;

    let first = first_valid_hlc(high, low, close);
    let valid = close.len().saturating_sub(first);
    if valid < length {
        return Err(AtrError::NotEnoughValidData {
            needed: length,
            valid,
        });
    }

    if dst.len() != close.len() {
        return Err(AtrError::OutputLengthMismatch {
            expected: close.len(),
            got: dst.len(),
        });
    }

    // Warm-up prefix: no ATR value is defined before this index.
    let warm = first + length - 1;
    dst[..warm].fill(f64::NAN);

    let resolved = if matches!(kern, Kernel::Auto) {
        Kernel::Scalar
    } else {
        kern
    };
    atr_compute_into(high, low, close, length, first, resolved, dst);
    Ok(())
}
2445
// JS binding `atr(high, low, close, length)`: allocates an output vector sized
// like `high`, fills it through `atr_into_slice` and returns it. Errors are
// converted to `JsError` using their display text.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "atr")]
pub fn atr_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    length: usize,
) -> Result<Vec<f64>, JsError> {
    let input = AtrInput::from_slices(
        high,
        low,
        close,
        AtrParams {
            length: Some(length),
        },
    );

    let mut output = vec![0.0; high.len()];
    match atr_into_slice(&mut output, &input, Kernel::Auto) {
        Ok(()) => Ok(output),
        Err(e) => Err(JsError::new(&e.to_string())),
    }
}
2464
// JS binding `atrBatch(...)`: runs the batch computation for the length sweep
// (start, end, step) with `Kernel::Auto` and returns the flat `values` vector of
// the batch output.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "atrBatch")]
pub fn atr_batch_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    length_start: usize,
    length_end: usize,
    length_step: usize,
) -> Result<Vec<f64>, JsError> {
    let sweep = AtrBatchRange {
        length: (length_start, length_end, length_step),
    };

    atr_batch_with_kernel(high, low, close, &sweep, Kernel::Auto)
        .map(|batch| batch.values)
        .map_err(|e| JsError::new(&e.to_string()))
}
2482
// JS binding `atrBatchMetadata(start, end, step)`: expands the sweep and returns
// the length of every combination as `f64`, in expansion order. A combination
// without an explicit length is reported as 14.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "atrBatchMetadata")]
pub fn atr_batch_metadata_js(
    length_start: usize,
    length_end: usize,
    length_step: usize,
) -> Vec<f64> {
    let combos = expand_grid(&AtrBatchRange {
        length: (length_start, length_end, length_step),
    });

    let mut lengths = Vec::with_capacity(combos.len());
    for params in &combos {
        lengths.push(params.length.unwrap_or(14) as f64);
    }
    lengths
}
2500
// JS binding `atr_batch(high, low, close, config)`.
// `config` must deserialize to `{ length_range: [start, end, step] }`; the result
// is a serialized object `{ values, combos, rows, cols }` taken from the batch output.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "atr_batch", skip_jsdoc)]
pub fn atr_batch_unified_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    config: JsValue,
) -> Result<JsValue, JsError> {
    // Shape of the incoming JS config object.
    #[derive(Deserialize)]
    struct BatchConfig {
        length_range: [usize; 3],
    }

    // Shape of the object handed back to JS.
    #[derive(Serialize)]
    struct BatchResult {
        values: Vec<f64>,
        combos: Vec<AtrParams>,
        rows: usize,
        cols: usize,
    }

    let parsed = serde_wasm_bindgen::from_value::<BatchConfig>(config)
        .map_err(|e| JsError::new(&e.to_string()))?;
    let [start, end, step] = parsed.length_range;
    let range = AtrBatchRange {
        length: (start, end, step),
    };

    let batch = match atr_batch_with_kernel(high, low, close, &range, Kernel::Auto) {
        Ok(batch) => batch,
        Err(e) => return Err(JsError::new(&e.to_string())),
    };

    let payload = BatchResult {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };

    serde_wasm_bindgen::to_value(&payload).map_err(|e| JsError::new(&e.to_string()))
}
2545
// Allocates room for `len` f64 values and returns the raw pointer without
// freeing it; the memory is uninitialised. The buffer is meant to be released
// with `atr_free(ptr, len)` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn atr_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` keeps the allocation alive after this function returns.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
2554
// Releases a buffer previously obtained from `atr_alloc(len)`.
//
// `ptr` must either be null (the call is then a no-op) or a pointer returned by
// `atr_alloc` with exactly this `len` that has not been freed yet.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn atr_free(ptr: *mut f64, len: usize) {
    // Rebuilding a Vec from a null pointer is undefined behaviour, so bail out early.
    if ptr.is_null() {
        return;
    }
    // SAFETY: per the contract above, `ptr` comes from a `Vec<f64>` with capacity
    // `len`. The Vec is rebuilt with length 0 so no element is claimed to be
    // initialised; dropping it only releases the allocation (f64 has no destructor).
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
2562
// Raw-pointer JS binding: reads `len` values from each input pointer and writes
// `len` ATR values to `out_ptr`.
//
// Null pointers are rejected. When `out_ptr` is identical to one of the input
// pointers the result is computed into a temporary vector first and copied over
// afterwards; partially overlapping ranges are not detected.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn atr_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    length: usize,
) -> Result<(), JsError> {
    let any_null =
        high_ptr.is_null() || low_ptr.is_null() || close_ptr.is_null() || out_ptr.is_null();
    if any_null {
        return Err(JsError::new("null pointer passed to atr_into"));
    }

    // The caller is responsible for every pointer referencing at least `len` f64 values.
    unsafe {
        let input = AtrInput::from_slices(
            std::slice::from_raw_parts(high_ptr, len),
            std::slice::from_raw_parts(low_ptr, len),
            std::slice::from_raw_parts(close_ptr, len),
            AtrParams {
                length: Some(length),
            },
        );

        let in_place = high_ptr == out_ptr || low_ptr == out_ptr || close_ptr == out_ptr;
        if !in_place {
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            atr_into_slice(out, &input, Kernel::Auto).map_err(|e| JsError::new(&e.to_string()))?;
        } else {
            // Output aliases an input: compute into scratch space, then copy.
            let mut scratch = vec![0.0; len];
            atr_into_slice(&mut scratch, &input, Kernel::Auto)
                .map_err(|e| JsError::new(&e.to_string()))?;
            std::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&scratch);
        }
    }

    Ok(())
}
2600
// Raw-pointer JS binding for the batch computation: reads `len` values from each
// input pointer and writes `rows * len` values to `out_ptr`, where `rows` is the
// number of combinations produced by the (start, end, step) sweep.
//
// Null pointers are rejected. When `out_ptr` is identical to one of the input
// pointers the batch is computed into an owned buffer first and copied over.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn atr_batch_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    length_start: usize,
    length_end: usize,
    length_step: usize,
) -> Result<(), JsError> {
    let any_null =
        high_ptr.is_null() || low_ptr.is_null() || close_ptr.is_null() || out_ptr.is_null();
    if any_null {
        return Err(JsError::new("null pointer passed to atr_batch_into"));
    }

    let range = AtrBatchRange {
        length: (length_start, length_end, length_step),
    };
    let rows = expand_grid(&range).len();
    let output_size = match rows.checked_mul(len) {
        Some(n) => n,
        None => return Err(JsError::new("atr_batch_into: rows*cols overflow")),
    };
    let in_place = high_ptr == out_ptr || low_ptr == out_ptr || close_ptr == out_ptr;

    // The caller is responsible for the inputs referencing `len` values each and
    // the output referencing `rows * len` values.
    unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);

        if in_place {
            // Compute first, materialise the output slice only afterwards.
            let batch = atr_batch_with_kernel(high, low, close, &range, Kernel::Auto)
                .map_err(|e| JsError::new(&e.to_string()))?;
            std::slice::from_raw_parts_mut(out_ptr, output_size).copy_from_slice(&batch.values);
        } else {
            let out_slice = std::slice::from_raw_parts_mut(out_ptr, output_size);

            // Translate the detected batch kernel into the per-row kernel.
            let kernel = match detect_best_batch_kernel() {
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                Kernel::Avx512Batch => Kernel::Avx512,
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                _ => Kernel::Scalar,
            };

            atr_batch_inner_into(high, low, close, &range, kernel, false, out_slice)
                .map_err(|e| JsError::new(&e.to_string()))?;
        }
    }

    Ok(())
}
2657
// Deprecated JS-facing wrapper around an `AtrStream`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
#[deprecated(
    since = "1.0.0",
    note = "For weight reuse patterns, use the fast/unsafe API with persistent buffers"
)]
pub struct AtrContext {
    stream: AtrStream,
}

#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
#[allow(deprecated)]
impl AtrContext {
    // Builds a context for the given ATR length; stream construction errors
    // become `JsError`s carrying the error's display text.
    #[wasm_bindgen(constructor)]
    #[deprecated(
        since = "1.0.0",
        note = "For weight reuse patterns, use the fast/unsafe API with persistent buffers"
    )]
    pub fn new(length: usize) -> Result<AtrContext, JsError> {
        match AtrStream::try_new(AtrParams {
            length: Some(length),
        }) {
            Ok(stream) => Ok(AtrContext { stream }),
            Err(e) => Err(JsError::new(&e.to_string())),
        }
    }

    // Feeds one bar into the stream and forwards whatever the stream returns.
    #[wasm_bindgen]
    pub fn update(&mut self, high: f64, low: f64, close: f64) -> Option<f64> {
        let stream = &mut self.stream;
        stream.update(high, low, close)
    }

    // Replaces the stream with a freshly constructed one using the same length.
    // The old stream is kept if construction fails.
    #[wasm_bindgen]
    pub fn reset(&mut self) -> Result<(), JsError> {
        let rebuilt = AtrStream::try_new(AtrParams {
            length: Some(self.stream.length),
        });
        match rebuilt {
            Ok(fresh) => {
                self.stream = fresh;
                Ok(())
            }
            Err(e) => Err(JsError::new(&e.to_string())),
        }
    }
}
2700
/// Adds every ATR exception type declared in this file to the Python module `m`,
/// each under its Rust type name.
///
/// All eleven types are registered, so any exception produced by the
/// `From<AtrError> for PyErr` conversion can be imported and caught by name
/// from Python.
///
/// # Errors
/// Returns the first error reported by `m.add`.
#[cfg(feature = "python")]
pub fn register_atr_exceptions(m: &Bound<'_, PyModule>) -> PyResult<()> {
    // Registers each listed type under the identical attribute name.
    macro_rules! export_exceptions {
        ($($exc:ident),+ $(,)?) => {
            $( m.add(stringify!($exc), m.py().get_type::<$exc>())?; )+
        };
    }

    export_exceptions!(
        InvalidLengthError,
        InconsistentSliceLengthsError,
        NoCandlesAvailableError,
        NotEnoughDataError,
        EmptyInputDataError,
        AllValuesNaNError,
        InvalidPeriodError,
        NotEnoughValidDataError,
        OutputLengthMismatchError,
        InvalidRangeError,
        InvalidKernelForBatchError,
    );
    Ok(())
}