1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArray2, PyArrayMethods, PyUntypedArrayMethods};
3#[cfg(feature = "python")]
4use pyo3::exceptions::{PyBufferError, PyValueError};
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::{PyDict, PyList};
9
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use serde::{Deserialize, Serialize};
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use wasm_bindgen::prelude::*;
14
15#[cfg(all(feature = "python", feature = "cuda"))]
16use crate::cuda::{cuda_available, CudaAtr};
17use crate::utilities::data_loader::{source_type, Candles};
18use crate::utilities::enums::Kernel;
19use crate::utilities::helpers::{
20 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
21 make_uninit_matrix,
22};
23#[cfg(feature = "python")]
24use crate::utilities::kernel_validation::validate_kernel;
25use aligned_vec::{AVec, CACHELINE_ALIGN};
26#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
27use core::arch::x86_64::*;
28#[cfg(not(target_arch = "wasm32"))]
29use rayon::prelude::*;
30use std::error::Error;
31use thiserror::Error;
32
/// Input source for an ATR computation: either a full candle container or
/// three parallel high/low/close slices.
#[derive(Debug, Clone)]
pub enum AtrData<'a> {
    /// Borrowed candle container; high/low/close are extracted by field name.
    Candles {
        candles: &'a Candles,
    },
    /// Raw parallel slices; all three must be the same length.
    Slices {
        high: &'a [f64],
        low: &'a [f64],
        close: &'a [f64],
    },
}
44
/// Result of a single ATR run: one value per input bar, with NaN filling the
/// warmup prefix before the first computable ATR.
#[derive(Debug, Clone)]
pub struct AtrOutput {
    pub values: Vec<f64>,
}
49
/// ATR parameters; `length` is the Wilder smoothing period and defaults to
/// 14 when `None`.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct AtrParams {
    pub length: Option<usize>,
}
58
impl Default for AtrParams {
    /// The conventional ATR period of 14.
    fn default() -> Self {
        Self { length: Some(14) }
    }
}
64
/// Bundles an ATR data source with its parameters.
#[derive(Debug, Clone)]
pub struct AtrInput<'a> {
    pub data: AtrData<'a>,
    pub params: AtrParams,
}
70
71impl<'a> AtrInput<'a> {
72 #[inline]
73 pub fn from_candles(candles: &'a Candles, params: AtrParams) -> Self {
74 Self {
75 data: AtrData::Candles { candles },
76 params,
77 }
78 }
79 #[inline]
80 pub fn from_slices(
81 high: &'a [f64],
82 low: &'a [f64],
83 close: &'a [f64],
84 params: AtrParams,
85 ) -> Self {
86 Self {
87 data: AtrData::Slices { high, low, close },
88 params,
89 }
90 }
91 #[inline]
92 pub fn with_default_candles(candles: &'a Candles) -> Self {
93 Self::from_candles(candles, AtrParams::default())
94 }
95 #[inline]
96 pub fn get_length(&self) -> usize {
97 self.params.length.unwrap_or(14)
98 }
99}
100
/// Fluent builder for one-off ATR runs over candles or raw slices.
#[derive(Copy, Clone, Debug)]
pub struct AtrBuilder {
    // None means "use the default period (14)".
    length: Option<usize>,
    kernel: Kernel,
}
106
impl Default for AtrBuilder {
    /// No length override; kernel resolved automatically at run time.
    fn default() -> Self {
        Self {
            length: None,
            kernel: Kernel::Auto,
        }
    }
}
115
116impl AtrBuilder {
117 #[inline(always)]
118 pub fn new() -> Self {
119 Self::default()
120 }
121 #[inline(always)]
122 pub fn length(mut self, n: usize) -> Self {
123 self.length = Some(n);
124 self
125 }
126 #[inline(always)]
127 pub fn kernel(mut self, k: Kernel) -> Self {
128 self.kernel = k;
129 self
130 }
131 #[inline(always)]
132 pub fn apply(self, c: &Candles) -> Result<AtrOutput, AtrError> {
133 let p = AtrParams {
134 length: self.length,
135 };
136 let i = AtrInput::from_candles(c, p);
137 atr_with_kernel(&i, self.kernel)
138 }
139 #[inline(always)]
140 pub fn apply_slices(
141 self,
142 high: &[f64],
143 low: &[f64],
144 close: &[f64],
145 ) -> Result<AtrOutput, AtrError> {
146 let p = AtrParams {
147 length: self.length,
148 };
149 let i = AtrInput::from_slices(high, low, close, p);
150 atr_with_kernel(&i, self.kernel)
151 }
152 #[inline(always)]
153 pub fn into_stream(self) -> Result<AtrStream, AtrError> {
154 let p = AtrParams {
155 length: self.length,
156 };
157 AtrStream::try_new(p)
158 }
159}
160
/// Errors produced by the ATR entry points (single-shot, into-slice,
/// streaming, and batch).
#[derive(Debug, Error)]
pub enum AtrError {
    #[error("atr: Input data slice is empty.")]
    EmptyInputData,
    #[error("atr: All values are NaN.")]
    AllValuesNaN,
    #[error("atr: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },
    #[error("atr: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    #[error("atr: Output slice length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    #[error("atr: Invalid range: start = {start}, end = {end}, step = {step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },
    // Raised when a non-batch kernel (other than Auto) is passed to a batch API.
    #[error("atr: Invalid kernel type for batch operation: {0:?}")]
    InvalidKernelForBatch(Kernel),

    #[error("Invalid length for ATR calculation (length={length}).")]
    InvalidLength { length: usize },
    #[error("Inconsistent slice lengths for ATR calculation: high={high_len}, low={low_len}, close={close_len}")]
    InconsistentSliceLengths {
        high_len: usize,
        low_len: usize,
        close_len: usize,
    },
    #[error("atr: No candles available for ATR calculation.")]
    NoCandlesAvailable,
    #[error("Not enough data to calculate ATR: length={length}, data length={data_len}")]
    NotEnoughData { length: usize, data_len: usize },
}
195
/// Index of the first bar where high, low, and close are all non-NaN;
/// returns `close.len()` when no such bar exists.
#[inline(always)]
fn first_valid_hlc(high: &[f64], low: &[f64], close: &[f64]) -> usize {
    (0..close.len())
        .find(|&i| !high[i].is_nan() && !low[i].is_nan() && !close[i].is_nan())
        .unwrap_or(close.len())
}
208
209#[inline(always)]
210fn atr_prepare_full<'a>(
211 high: &'a [f64],
212 low: &'a [f64],
213 close: &'a [f64],
214 length: usize,
215) -> Result<(&'a [f64], &'a [f64], &'a [f64], usize, usize), AtrError> {
216 let (high, low, close, length) = atr_prepare(high, low, close, length)?;
217 let first = first_valid_hlc(high, low, close);
218 if first >= close.len() {
219 return Err(AtrError::AllValuesNaN);
220 }
221 let valid = close.len().saturating_sub(first);
222 if valid < length {
223 return Err(AtrError::NotEnoughValidData {
224 needed: length,
225 valid,
226 });
227 }
228 let warmup = first + length - 1;
229 Ok((high, low, close, first, warmup))
230}
231
/// Computes ATR with automatic kernel selection.
///
/// Convenience wrapper around [`atr_with_kernel`] using `Kernel::Auto`.
#[inline]
pub fn atr(input: &AtrInput) -> Result<AtrOutput, AtrError> {
    atr_with_kernel(input, Kernel::Auto)
}
236
237pub fn atr_with_kernel(input: &AtrInput, kernel: Kernel) -> Result<AtrOutput, AtrError> {
238 let (high, low, close) = match &input.data {
239 AtrData::Candles { candles } => (
240 candles
241 .select_candle_field("high")
242 .map_err(|_| AtrError::NoCandlesAvailable)?,
243 candles
244 .select_candle_field("low")
245 .map_err(|_| AtrError::NoCandlesAvailable)?,
246 candles
247 .select_candle_field("close")
248 .map_err(|_| AtrError::NoCandlesAvailable)?,
249 ),
250 AtrData::Slices { high, low, close } => {
251 if high.len() != low.len() || low.len() != close.len() {
252 return Err(AtrError::InconsistentSliceLengths {
253 high_len: high.len(),
254 low_len: low.len(),
255 close_len: close.len(),
256 });
257 }
258 (*high, *low, *close)
259 }
260 };
261
262 let len = close.len();
263 let length = input.get_length();
264 if length == 0 {
265 return Err(AtrError::InvalidLength { length });
266 }
267 if len == 0 {
268 return Err(AtrError::NoCandlesAvailable);
269 }
270 if length > len {
271 return Err(AtrError::NotEnoughData {
272 length,
273 data_len: len,
274 });
275 }
276
277 let chosen = match kernel {
278 Kernel::Auto => Kernel::Scalar,
279 k => k,
280 };
281
282 let (_, _, _, first, warmup) = atr_prepare_full(high, low, close, length)?;
283 let mut out = alloc_with_nan_prefix(len, warmup);
284 atr_compute_into(high, low, close, length, first, chosen, &mut out);
285 Ok(AtrOutput { values: out })
286}
287
288#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
289pub fn atr_into(input: &AtrInput, out: &mut [f64]) -> Result<(), AtrError> {
290 let (high, low, close) = match &input.data {
291 AtrData::Candles { candles } => (&candles.high[..], &candles.low[..], &candles.close[..]),
292 AtrData::Slices { high, low, close } => (*high, *low, *close),
293 };
294
295 let length = input.params.length.unwrap_or(14);
296 let (high, low, close, length) = atr_prepare(high, low, close, length)?;
297
298 let first = first_valid_hlc(high, low, close);
299 let valid = close.len().saturating_sub(first);
300 if valid < length {
301 return Err(AtrError::NotEnoughValidData {
302 needed: length,
303 valid,
304 });
305 }
306 let warmup = first + length - 1;
307
308 if out.len() != close.len() {
309 return Err(AtrError::OutputLengthMismatch {
310 expected: close.len(),
311 got: out.len(),
312 });
313 }
314
315 let prefix = warmup.min(out.len());
316 for v in &mut out[..prefix] {
317 *v = f64::from_bits(0x7ff8_0000_0000_0000);
318 }
319
320 let chosen = match Kernel::Auto {
321 Kernel::Auto => Kernel::Scalar,
322 k => k,
323 };
324 atr_compute_into(high, low, close, length, first, chosen, out);
325 Ok(())
326}
327
/// Scalar ATR kernel.
///
/// Seeds the running mean (RMA) with the arithmetic average of the first
/// `length` true ranges starting at `first`, writes that seed at index
/// `first + length - 1`, then applies Wilder smoothing
/// `rma = rma + alpha * (tr - rma)` (expressed in FMA form) for every later
/// bar. The main loop is manually unrolled 4x; the `rma` recurrence stays
/// serial, but the four TR computations can pipeline.
///
/// Caller contract (checked only via debug asserts): all slices are the same
/// length, `out[..first + length - 1]` already holds the NaN prefix, and
/// `first + length - 1 < out.len()`.
#[inline(always)]
fn atr_compute_into_scalar(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    length: usize,
    first: usize,
    out: &mut [f64],
) {
    debug_assert_eq!(high.len(), low.len());
    debug_assert_eq!(low.len(), close.len());
    debug_assert_eq!(out.len(), close.len());

    // Index of the seed value; alpha is Wilder's smoothing factor 1/length.
    let warm = first + length - 1;
    let alpha = 1.0 / (length as f64);

    // SAFETY: all indices below stay within first..out.len(), which the
    // caller guarantees is in bounds for all four equal-length slices.
    unsafe {
        // The first valid bar has no previous close: TR is just high - low.
        let mut sum_tr = *high.get_unchecked(first) - *low.get_unchecked(first);

        if warm > first {
            let mut i = first + 1;
            let mut prev_c = *close.get_unchecked(i - 1);
            while i <= warm {
                let hi = *high.get_unchecked(i);
                let lo = *low.get_unchecked(i);

                // True range = max(h-l, |h-prev_close|, |l-prev_close|).
                let mut tr = hi - lo;
                let hc = (hi - prev_c).abs();
                if hc > tr {
                    tr = hc;
                }
                let lc = (lo - prev_c).abs();
                if lc > tr {
                    tr = lc;
                }

                sum_tr += tr;
                prev_c = *close.get_unchecked(i);
                i += 1;
            }
        }

        // Seed the RMA with the simple average of the warmup TRs.
        let mut rma = sum_tr / (length as f64);
        *out.get_unchecked_mut(warm) = rma;

        let mut i = warm + 1;
        let n = out.len();

        // NOTE(review): i = warm + 1 >= 1 always, so the else branch is dead;
        // kept as-is to avoid touching behavior in a documentation pass.
        let mut prev_c = if i > 0 {
            *close.get_unchecked(i - 1)
        } else {
            *close.get_unchecked(0)
        };

        // Main loop, unrolled 4x: each step computes one bar's TR and folds
        // it into the RMA via rma = rma - alpha*rma + alpha*tr.
        while i + 3 < n {
            let (hi0, lo0) = (*high.get_unchecked(i), *low.get_unchecked(i));
            let mut tr0 = hi0 - lo0;
            let hc0 = (hi0 - prev_c).abs();
            if hc0 > tr0 {
                tr0 = hc0;
            }
            let lc0 = (lo0 - prev_c).abs();
            if lc0 > tr0 {
                tr0 = lc0;
            }
            rma = (-alpha).mul_add(rma, rma) + alpha * tr0;
            *out.get_unchecked_mut(i) = rma;

            let prev0 = *close.get_unchecked(i);
            let (hi1, lo1) = (*high.get_unchecked(i + 1), *low.get_unchecked(i + 1));
            let mut tr1 = hi1 - lo1;
            let hc1 = (hi1 - prev0).abs();
            if hc1 > tr1 {
                tr1 = hc1;
            }
            let lc1 = (lo1 - prev0).abs();
            if lc1 > tr1 {
                tr1 = lc1;
            }
            rma = (-alpha).mul_add(rma, rma) + alpha * tr1;
            *out.get_unchecked_mut(i + 1) = rma;

            let prev1 = *close.get_unchecked(i + 1);
            let (hi2, lo2) = (*high.get_unchecked(i + 2), *low.get_unchecked(i + 2));
            let mut tr2 = hi2 - lo2;
            let hc2 = (hi2 - prev1).abs();
            if hc2 > tr2 {
                tr2 = hc2;
            }
            let lc2 = (lo2 - prev1).abs();
            if lc2 > tr2 {
                tr2 = lc2;
            }
            rma = (-alpha).mul_add(rma, rma) + alpha * tr2;
            *out.get_unchecked_mut(i + 2) = rma;

            let prev2 = *close.get_unchecked(i + 2);
            let (hi3, lo3) = (*high.get_unchecked(i + 3), *low.get_unchecked(i + 3));
            let mut tr3 = hi3 - lo3;
            let hc3 = (hi3 - prev2).abs();
            if hc3 > tr3 {
                tr3 = hc3;
            }
            let lc3 = (lo3 - prev2).abs();
            if lc3 > tr3 {
                tr3 = lc3;
            }
            rma = (-alpha).mul_add(rma, rma) + alpha * tr3;
            *out.get_unchecked_mut(i + 3) = rma;

            i += 4;
            prev_c = *close.get_unchecked(i - 1);
        }

        // Tail loop for the final 0..=3 bars.
        while i < n {
            let (hi, lo) = (*high.get_unchecked(i), *low.get_unchecked(i));
            let mut tr = hi - lo;
            let hc = (hi - prev_c).abs();
            if hc > tr {
                tr = hc;
            }
            let lc = (lo - prev_c).abs();
            if lc > tr {
                tr = lc;
            }
            rma = (-alpha).mul_add(rma, rma) + alpha * tr;
            *out.get_unchecked_mut(i) = rma;

            prev_c = *close.get_unchecked(i);
            i += 1;
        }
    }
}
461
/// Scalar ATR over full slices, assuming data is valid from index 0
/// (i.e. no leading NaNs); see `atr_compute_into_scalar` for the contract.
#[inline]
pub fn atr_scalar(high: &[f64], low: &[f64], close: &[f64], length: usize, out: &mut [f64]) {
    atr_compute_into_scalar(high, low, close, length, 0, out);
}
466
467#[inline(always)]
468fn atr_compute_into(
469 high: &[f64],
470 low: &[f64],
471 close: &[f64],
472 length: usize,
473 first: usize,
474 kern: Kernel,
475 out: &mut [f64],
476) {
477 unsafe {
478 #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
479 {
480 if matches!(kern, Kernel::Scalar | Kernel::ScalarBatch) {
481 atr_compute_into_scalar(high, low, close, length, first, out);
482 return;
483 }
484 }
485 match kern {
486 Kernel::Scalar | Kernel::ScalarBatch => {
487 atr_compute_into_scalar(high, low, close, length, first, out)
488 }
489 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
490 Kernel::Avx2 | Kernel::Avx2Batch => {
491 atr_compute_into_avx2(high, low, close, length, first, out)
492 }
493 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
494 Kernel::Avx512 | Kernel::Avx512Batch => {
495 atr_compute_into_avx512(high, low, close, length, first, out)
496 }
497 _ => unreachable!(),
498 }
499 }
500}
501
/// wasm simd128 entry point.
///
/// There is no simd128 implementation yet — the Wilder recurrence is serial —
/// so this delegates to the scalar kernel. The unused
/// `use core::arch::wasm32::*;` glob import has been removed.
///
/// # Safety
/// Same contract as `atr_scalar` / the scalar kernel.
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
#[inline]
unsafe fn atr_simd128(high: &[f64], low: &[f64], close: &[f64], length: usize, out: &mut [f64]) {
    atr_scalar(high, low, close, length, out);
}
509
/// AVX2 ATR kernel.
///
/// True ranges are computed 4 lanes at a time (the unaligned `close` load at
/// `i - 1` supplies each lane's previous close); the Wilder recurrence is
/// inherently serial, so the 4 TRs are spilled to a stack buffer and folded
/// into `rma` one by one.
///
/// # Safety
/// Caller must ensure AVX2 is available on the running CPU and uphold the
/// same slice-length / warmup-bounds contract as the scalar kernel.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn atr_compute_into_avx2(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    length: usize,
    first: usize,
    out: &mut [f64],
) {
    use core::arch::x86_64::*;

    debug_assert_eq!(high.len(), low.len());
    debug_assert_eq!(low.len(), close.len());
    debug_assert_eq!(out.len(), close.len());

    let warm = first + length - 1;
    let alpha = 1.0 / (length as f64);

    // Scalar warmup: average the first `length` true ranges.
    let mut sum_tr = *high.get_unchecked(first) - *low.get_unchecked(first);
    if warm > first {
        let mut i = first + 1;
        let mut prev_c = *close.get_unchecked(i - 1);
        while i <= warm {
            let hi = *high.get_unchecked(i);
            let lo = *low.get_unchecked(i);

            let mut tr = hi - lo;
            let hc = (hi - prev_c).abs();
            if hc > tr {
                tr = hc;
            }
            let lc = (lo - prev_c).abs();
            if lc > tr {
                tr = lc;
            }

            sum_tr += tr;
            prev_c = *close.get_unchecked(i);
            i += 1;
        }
    }

    let mut rma = sum_tr / (length as f64);
    *out.get_unchecked_mut(warm) = rma;

    let mut i = warm + 1;
    let n = out.len();

    // All-ones except the sign bit: AND implements abs() branchlessly.
    let mask_abs = _mm256_castsi256_pd(_mm256_set1_epi64x(0x7fff_ffff_ffff_ffffu64 as i64));

    while i + 3 < n {
        let v_hi = _mm256_loadu_pd(high.as_ptr().add(i));
        let v_lo = _mm256_loadu_pd(low.as_ptr().add(i));

        // Previous closes for lanes i..i+3 are close[i-1..i+2].
        let v_pc = _mm256_loadu_pd(close.as_ptr().add(i - 1));

        let v_hl = _mm256_sub_pd(v_hi, v_lo);

        let v_hc = _mm256_and_pd(_mm256_sub_pd(v_hi, v_pc), mask_abs);

        let v_lc = _mm256_and_pd(_mm256_sub_pd(v_lo, v_pc), mask_abs);

        let v_m1 = _mm256_max_pd(v_hl, v_hc);
        let v_tr = _mm256_max_pd(v_m1, v_lc);

        // Spill the 4 TRs and apply the serial Wilder recurrence per lane.
        let mut buf = [0.0f64; 4];
        _mm256_storeu_pd(buf.as_mut_ptr(), v_tr);

        rma = (-alpha).mul_add(rma, rma) + alpha * buf[0];
        *out.get_unchecked_mut(i) = rma;

        rma = (-alpha).mul_add(rma, rma) + alpha * buf[1];
        *out.get_unchecked_mut(i + 1) = rma;

        rma = (-alpha).mul_add(rma, rma) + alpha * buf[2];
        *out.get_unchecked_mut(i + 2) = rma;

        rma = (-alpha).mul_add(rma, rma) + alpha * buf[3];
        *out.get_unchecked_mut(i + 3) = rma;

        i += 4;
    }

    // Scalar tail for the final 0..=3 bars.
    if i < n {
        let mut prev_c = *close.get_unchecked(i - 1);
        while i < n {
            let hi = *high.get_unchecked(i);
            let lo = *low.get_unchecked(i);
            let mut tr = hi - lo;
            let hc = (hi - prev_c).abs();
            if hc > tr {
                tr = hc;
            }
            let lc = (lo - prev_c).abs();
            if lc > tr {
                tr = lc;
            }
            rma = (-alpha).mul_add(rma, rma) + alpha * tr;
            *out.get_unchecked_mut(i) = rma;

            prev_c = *close.get_unchecked(i);
            i += 1;
        }
    }
}
616
/// AVX512 ATR kernel.
///
/// Same structure as the AVX2 kernel but with 8 lanes per iteration: TRs are
/// vectorized, spilled to a stack buffer, and the serial Wilder recurrence is
/// applied lane by lane.
///
/// # Safety
/// Caller must ensure AVX512F is available on the running CPU and uphold the
/// same slice-length / warmup-bounds contract as the scalar kernel.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn atr_compute_into_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    length: usize,
    first: usize,
    out: &mut [f64],
) {
    use core::arch::x86_64::*;

    debug_assert_eq!(high.len(), low.len());
    debug_assert_eq!(low.len(), close.len());
    debug_assert_eq!(out.len(), close.len());

    let warm = first + length - 1;
    let alpha = 1.0 / (length as f64);

    // Scalar warmup: average the first `length` true ranges.
    let mut sum_tr = *high.get_unchecked(first) - *low.get_unchecked(first);
    if warm > first {
        let mut i = first + 1;
        let mut prev_c = *close.get_unchecked(i - 1);
        while i <= warm {
            let hi = *high.get_unchecked(i);
            let lo = *low.get_unchecked(i);

            let mut tr = hi - lo;
            let hc = (hi - prev_c).abs();
            if hc > tr {
                tr = hc;
            }
            let lc = (lo - prev_c).abs();
            if lc > tr {
                tr = lc;
            }

            sum_tr += tr;
            prev_c = *close.get_unchecked(i);
            i += 1;
        }
    }

    let mut rma = sum_tr / (length as f64);
    *out.get_unchecked_mut(warm) = rma;

    let mut i = warm + 1;
    let n = out.len();

    // All-ones except the sign bit: AND implements abs() branchlessly.
    let mask_abs = _mm512_castsi512_pd(_mm512_set1_epi64(0x7fff_ffff_ffff_ffffu64 as i64));

    while i + 7 < n {
        let v_hi = _mm512_loadu_pd(high.as_ptr().add(i));
        let v_lo = _mm512_loadu_pd(low.as_ptr().add(i));
        // Previous closes for lanes i..i+7 are close[i-1..i+6].
        let v_pc = _mm512_loadu_pd(close.as_ptr().add(i - 1));

        let v_hl = _mm512_sub_pd(v_hi, v_lo);
        let v_hc = _mm512_and_pd(_mm512_sub_pd(v_hi, v_pc), mask_abs);
        let v_lc = _mm512_and_pd(_mm512_sub_pd(v_lo, v_pc), mask_abs);

        let v_m1 = _mm512_max_pd(v_hl, v_hc);
        let v_tr = _mm512_max_pd(v_m1, v_lc);

        // Spill the 8 TRs and apply the serial Wilder recurrence per lane.
        let mut buf = [0.0f64; 8];
        _mm512_storeu_pd(buf.as_mut_ptr(), v_tr);

        rma = (-alpha).mul_add(rma, rma) + alpha * buf[0];
        *out.get_unchecked_mut(i) = rma;

        rma = (-alpha).mul_add(rma, rma) + alpha * buf[1];
        *out.get_unchecked_mut(i + 1) = rma;

        rma = (-alpha).mul_add(rma, rma) + alpha * buf[2];
        *out.get_unchecked_mut(i + 2) = rma;

        rma = (-alpha).mul_add(rma, rma) + alpha * buf[3];
        *out.get_unchecked_mut(i + 3) = rma;

        rma = (-alpha).mul_add(rma, rma) + alpha * buf[4];
        *out.get_unchecked_mut(i + 4) = rma;

        rma = (-alpha).mul_add(rma, rma) + alpha * buf[5];
        *out.get_unchecked_mut(i + 5) = rma;

        rma = (-alpha).mul_add(rma, rma) + alpha * buf[6];
        *out.get_unchecked_mut(i + 6) = rma;

        rma = (-alpha).mul_add(rma, rma) + alpha * buf[7];
        *out.get_unchecked_mut(i + 7) = rma;

        i += 8;
    }

    // Scalar tail for the final 0..=7 bars.
    if i < n {
        let mut prev_c = *close.get_unchecked(i - 1);
        while i < n {
            let hi = *high.get_unchecked(i);
            let lo = *low.get_unchecked(i);
            let mut tr = hi - lo;
            let hc = (hi - prev_c).abs();
            if hc > tr {
                tr = hc;
            }
            let lc = (lo - prev_c).abs();
            if lc > tr {
                tr = lc;
            }
            rma = (-alpha).mul_add(rma, rma) + alpha * tr;
            *out.get_unchecked_mut(i) = rma;

            prev_c = *close.get_unchecked(i);
            i += 1;
        }
    }
}
732
/// AVX2 ATR over full slices (`first = 0`, i.e. assumes no leading NaNs).
///
/// NOTE(review): not marked `unsafe` although the callee requires AVX2 at
/// run time — confirm callers gate on CPU feature detection.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn atr_avx2(high: &[f64], low: &[f64], close: &[f64], length: usize, out: &mut [f64]) {
    unsafe { atr_compute_into_avx2(high, low, close, length, 0, out) }
}
738
/// AVX512 ATR over full slices (`first = 0`, i.e. assumes no leading NaNs).
///
/// NOTE(review): not marked `unsafe` although the callee requires AVX512 at
/// run time — confirm callers gate on CPU feature detection.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn atr_avx512(high: &[f64], low: &[f64], close: &[f64], length: usize, out: &mut [f64]) {
    unsafe { atr_compute_into_avx512(high, low, close, length, 0, out) }
}
744
/// Short-period AVX512 entry point; currently an alias for the generic
/// AVX512 kernel with `first = 0` (no short-period specialization yet).
///
/// # Safety
/// Requires AVX512 support and the scalar kernel's slice/bounds contract.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn atr_avx512_short(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    length: usize,
    out: &mut [f64],
) {
    atr_compute_into_avx512(high, low, close, length, 0, out)
}
/// Long-period AVX512 entry point; currently an alias for the generic
/// AVX512 kernel with `first = 0` (no long-period specialization yet).
///
/// # Safety
/// Requires AVX512 support and the scalar kernel's slice/bounds contract.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn atr_avx512_long(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    length: usize,
    out: &mut [f64],
) {
    atr_compute_into_avx512(high, low, close, length, 0, out)
}
767
/// Incremental (bar-by-bar) ATR state using Wilder smoothing.
#[derive(Debug, Clone)]
pub struct AtrStream {
    // Smoothing period.
    length: usize,
    // Precomputed 1.0 / length.
    alpha: f64,
    // Previous bar's close; NaN until the first update.
    prev_close: f64,
    // Current smoothed ATR; NaN until seeded.
    rma: f64,
    // Sum of true ranges accumulated during warmup.
    warm_sum: f64,
    // Number of bars consumed during warmup.
    warm_count: usize,
    // True once the initial average has seeded `rma`.
    seeded: bool,
}
778
779impl AtrStream {
780 #[inline(always)]
781 pub fn try_new(params: AtrParams) -> Result<Self, AtrError> {
782 let length = params.length.unwrap_or(14);
783 if length == 0 {
784 return Err(AtrError::InvalidLength { length });
785 }
786 Ok(Self {
787 length,
788 alpha: 1.0 / (length as f64),
789 prev_close: f64::NAN,
790 rma: f64::NAN,
791 warm_sum: 0.0,
792 warm_count: 0,
793 seeded: false,
794 })
795 }
796
797 #[inline(always)]
798 pub fn update(&mut self, high: f64, low: f64, close: f64) -> Option<f64> {
799 debug_assert!(
800 high.is_finite() && low.is_finite() && close.is_finite(),
801 "Streaming ATR assumes finite inputs; prefilter NaNs/Infs upstream if needed",
802 );
803
804 let tr = if self.prev_close.is_nan() {
805 high - low
806 } else {
807 let up = if high > self.prev_close {
808 high
809 } else {
810 self.prev_close
811 };
812 let dn = if low < self.prev_close {
813 low
814 } else {
815 self.prev_close
816 };
817 up - dn
818 };
819
820 self.prev_close = close;
821
822 if !self.seeded {
823 self.warm_sum += tr;
824 self.warm_count += 1;
825
826 if self.warm_count == self.length {
827 self.rma = self.warm_sum * self.alpha;
828 self.seeded = true;
829 return Some(self.rma);
830 }
831 return None;
832 }
833
834 self.rma = self.alpha.mul_add(tr - self.rma, self.rma);
835 Some(self.rma)
836 }
837}
838
/// Parameter sweep for batch ATR: `(start, end, step)` over `length`.
/// A step of 0 (or start == end) pins the sweep to `start` alone.
#[derive(Clone, Debug)]
pub struct AtrBatchRange {
    pub length: (usize, usize, usize),
}
impl Default for AtrBatchRange {
    /// Sweeps lengths 14..=263 inclusive in steps of 1.
    fn default() -> Self {
        Self {
            length: (14, 263, 1),
        }
    }
}
/// Fluent builder for batch (parameter-sweep) ATR runs.
#[derive(Clone, Debug, Default)]
pub struct AtrBatchBuilder {
    range: AtrBatchRange,
    kernel: Kernel,
}
855impl AtrBatchBuilder {
856 pub fn new() -> Self {
857 Self::default()
858 }
859 pub fn kernel(mut self, k: Kernel) -> Self {
860 self.kernel = k;
861 self
862 }
863 #[inline]
864 pub fn length_range(mut self, start: usize, end: usize, step: usize) -> Self {
865 self.range.length = (start, end, step);
866 self
867 }
868 #[inline]
869 pub fn length_static(mut self, p: usize) -> Self {
870 self.range.length = (p, p, 0);
871 self
872 }
873 pub fn apply_slices(
874 self,
875 high: &[f64],
876 low: &[f64],
877 close: &[f64],
878 ) -> Result<AtrBatchOutput, AtrError> {
879 atr_batch_with_kernel(high, low, close, &self.range, self.kernel)
880 }
881 pub fn apply_candles(self, c: &Candles) -> Result<AtrBatchOutput, AtrError> {
882 let high = c
883 .select_candle_field("high")
884 .map_err(|_| AtrError::NoCandlesAvailable)?;
885 let low = c
886 .select_candle_field("low")
887 .map_err(|_| AtrError::NoCandlesAvailable)?;
888 let close = c
889 .select_candle_field("close")
890 .map_err(|_| AtrError::NoCandlesAvailable)?;
891 self.apply_slices(high, low, close)
892 }
893}
894
/// Row-major matrix of batch results: `rows` parameter combos x `cols` bars.
#[derive(Clone, Debug)]
pub struct AtrBatchOutput {
    pub values: Vec<f64>,
    pub combos: Vec<AtrParams>,
    pub rows: usize,
    pub cols: usize,
}
902impl AtrBatchOutput {
903 pub fn row_for_params(&self, p: &AtrParams) -> Option<usize> {
904 self.combos
905 .iter()
906 .position(|c| c.length.unwrap_or(14) == p.length.unwrap_or(14))
907 }
908 pub fn values_for(&self, p: &AtrParams) -> Option<&[f64]> {
909 self.row_for_params(p).map(|row| {
910 let start = row * self.cols;
911 &self.values[start..start + self.cols]
912 })
913 }
914}
915
916#[inline(always)]
917fn expand_grid(r: &AtrBatchRange) -> Vec<AtrParams> {
918 let (start, end, step) = r.length;
919 if step == 0 || start == end {
920 return vec![AtrParams {
921 length: Some(start),
922 }];
923 }
924 if start < end {
925 (start..=end)
926 .step_by(step)
927 .map(|l| AtrParams { length: Some(l) })
928 .collect()
929 } else {
930 let mut v: Vec<usize> = (end..=start).step_by(step).collect();
931 v.reverse();
932 v.into_iter()
933 .map(|l| AtrParams { length: Some(l) })
934 .collect()
935 }
936}
937
/// Batch ATR entry point: resolves/validates the batch kernel, maps it to its
/// per-row SIMD counterpart, and runs the parallel sweep.
///
/// # Errors
/// `InvalidKernelForBatch` when a non-batch kernel other than `Auto` is
/// passed; otherwise whatever the sweep itself produces.
pub fn atr_batch_with_kernel(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    sweep: &AtrBatchRange,
    k: Kernel,
) -> Result<AtrBatchOutput, AtrError> {
    let kernel = match k {
        Kernel::Auto => detect_best_batch_kernel(),
        other if other.is_batch() => other,
        other => return Err(AtrError::InvalidKernelForBatch(other)),
    };
    // Map the batch kernel onto the SIMD kernel used for the TR precompute.
    let simd = match kernel {
        Kernel::Avx512Batch => Kernel::Avx512,
        Kernel::Avx2Batch => Kernel::Avx2,
        Kernel::ScalarBatch => Kernel::Scalar,
        _ => unreachable!(),
    };
    atr_batch_par_slice(high, low, close, sweep, simd)
}
958
/// Sequential (single-threaded) batch ATR over raw slices.
#[inline(always)]
pub fn atr_batch_slice(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    sweep: &AtrBatchRange,
    kern: Kernel,
) -> Result<AtrBatchOutput, AtrError> {
    atr_batch_inner(high, low, close, sweep, kern, false)
}
/// Parallel batch ATR over raw slices (rayon on non-wasm targets; falls back
/// to a sequential loop on wasm).
#[inline(always)]
pub fn atr_batch_par_slice(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    sweep: &AtrBatchRange,
    kern: Kernel,
) -> Result<AtrBatchOutput, AtrError> {
    atr_batch_inner(high, low, close, sweep, kern, true)
}
979
980fn atr_batch_inner_into(
981 high: &[f64],
982 low: &[f64],
983 close: &[f64],
984 sweep: &AtrBatchRange,
985 kern: Kernel,
986 parallel: bool,
987 out: &mut [f64],
988) -> Result<Vec<AtrParams>, AtrError> {
989 let combos = expand_grid(sweep);
990 if combos.is_empty() {
991 let (s, e, st) = sweep.length;
992 return Err(AtrError::InvalidRange {
993 start: s,
994 end: e,
995 step: st,
996 });
997 }
998 let rows = combos.len();
999 let cols = high.len();
1000 let expected = rows.checked_mul(cols).ok_or(AtrError::InvalidRange {
1001 start: sweep.length.0,
1002 end: sweep.length.1,
1003 step: sweep.length.2,
1004 })?;
1005 if out.len() != expected {
1006 return Err(AtrError::OutputLengthMismatch {
1007 expected,
1008 got: out.len(),
1009 });
1010 }
1011
1012 let first = first_valid_hlc(high, low, close);
1013 if first >= cols {
1014 return Err(AtrError::AllValuesNaN);
1015 }
1016
1017 let mut tr = AVec::<f64>::with_capacity(CACHELINE_ALIGN, cols);
1018 unsafe {
1019 tr.set_len(cols);
1020 }
1021
1022 for v in &mut tr[..] {
1023 *v = 0.0;
1024 }
1025
1026 match kern_to_simd(kern) {
1027 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1028 Kernel::Avx512 => unsafe {
1029 precompute_tr_into_avx512(high, low, close, first, &mut tr);
1030 },
1031 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1032 Kernel::Avx2 => unsafe {
1033 precompute_tr_into_avx2(high, low, close, first, &mut tr);
1034 },
1035 _ => {
1036 precompute_tr_into_scalar(high, low, close, first, &mut tr);
1037 }
1038 }
1039
1040 let mut ps = AVec::<f64>::with_capacity(CACHELINE_ALIGN, cols + 1);
1041 unsafe {
1042 ps.set_len(cols + 1);
1043 }
1044 ps[0] = 0.0;
1045
1046 for i in 0..cols {
1047 ps[i + 1] = ps[i] + tr[i];
1048 }
1049
1050 let do_row = |row: usize, dst: &mut [f64]| {
1051 let length = combos[row].length.unwrap();
1052 let warm = first + length - 1;
1053
1054 for v in &mut dst[..warm] {
1055 *v = f64::NAN;
1056 }
1057
1058 let sum_tr = ps[warm + 1] - ps[first];
1059 let mut rma = sum_tr / (length as f64);
1060 dst[warm] = rma;
1061 let alpha = 1.0 / (length as f64);
1062 let mut i = warm + 1;
1063 while i < cols {
1064 let tri = tr[i];
1065 rma = (-alpha).mul_add(rma, rma) + alpha * tri;
1066 dst[i] = rma;
1067 i += 1;
1068 }
1069 };
1070
1071 #[inline(always)]
1072 fn kern_to_simd(k: Kernel) -> Kernel {
1073 match k {
1074 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1075 Kernel::Avx512Batch => Kernel::Avx512,
1076 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1077 Kernel::Avx2Batch => Kernel::Avx2,
1078 Kernel::ScalarBatch => Kernel::Scalar,
1079 other => other,
1080 }
1081 }
1082
1083 if parallel {
1084 #[cfg(not(target_arch = "wasm32"))]
1085 out.par_chunks_mut(cols)
1086 .enumerate()
1087 .for_each(|(r, row)| do_row(r, row));
1088 #[cfg(target_arch = "wasm32")]
1089 for (r, row) in out.chunks_mut(cols).enumerate() {
1090 do_row(r, row);
1091 }
1092 } else {
1093 for (r, row) in out.chunks_mut(cols).enumerate() {
1094 do_row(r, row);
1095 }
1096 }
1097
1098 Ok(combos)
1099}
1100
/// Allocating batch ATR: builds the `rows x cols` result matrix for every
/// combo in `sweep` and returns it together with its parameter grid.
///
/// NOTE(review): sweep lengths are not validated against the data span here —
/// a combo with `first_valid + length - 1 >= cols` panics on `ps[warm + 1]`,
/// and `length == 0` underflows the warmup index. Confirm callers pass sane
/// sweeps or add validation upstream.
fn atr_batch_inner(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    sweep: &AtrBatchRange,
    kern: Kernel,
    parallel: bool,
) -> Result<AtrBatchOutput, AtrError> {
    let combos = expand_grid(sweep);
    if combos.is_empty() {
        let (s, e, st) = sweep.length;
        return Err(AtrError::InvalidRange {
            start: s,
            end: e,
            step: st,
        });
    }
    let len = close.len();
    let rows = combos.len();
    let cols = len;

    // Uninitialized rows x cols matrix; each row's NaN warmup prefix is
    // initialized below, and do_row fills the remainder.
    let mut buf_mu = make_uninit_matrix(rows, cols);

    let first_valid = first_valid_hlc(high, low, close);

    // Per-row warmup index (first_valid + length - 1).
    let warm: Vec<usize> = combos
        .iter()
        .map(|c| first_valid + c.length.unwrap() - 1)
        .collect();

    init_matrix_prefixes(&mut buf_mu, cols, &warm);

    // Reinterpret the prefix-initialized MaybeUninit buffer as &mut [f64].
    // Assumes make_uninit_matrix's element type is layout-compatible with
    // f64 and that init_matrix_prefixes + do_row together initialize every
    // element — TODO confirm against the helpers' contracts.
    let mut buf_guard = std::mem::ManuallyDrop::new(buf_mu);
    let values: &mut [f64] = unsafe {
        std::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
    };

    // Cache-line-aligned TR scratch, zeroed so the prefix sums below are
    // well-defined even before first_valid.
    let mut tr = AVec::<f64>::with_capacity(CACHELINE_ALIGN, cols);
    unsafe {
        tr.set_len(cols);
    }
    for v in &mut tr[..] {
        *v = 0.0;
    }
    match kern {
        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
        Kernel::Avx512 => unsafe {
            precompute_tr_into_avx512(high, low, close, first_valid, &mut tr)
        },
        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
        Kernel::Avx2 => unsafe { precompute_tr_into_avx2(high, low, close, first_valid, &mut tr) },
        _ => precompute_tr_into_scalar(high, low, close, first_valid, &mut tr),
    }
    // Prefix sums of TR so each row can seed its RMA with one subtraction.
    let mut ps = AVec::<f64>::with_capacity(CACHELINE_ALIGN, cols + 1);
    unsafe { ps.set_len(cols + 1) };
    ps[0] = 0.0;
    for i in 0..cols {
        ps[i + 1] = ps[i] + tr[i];
    }

    // Seeds each row with the average of its first `length` TRs, then runs
    // the Wilder recurrence to the end of the row.
    let do_row = |row: usize, out_row: &mut [f64]| {
        let length = combos[row].length.unwrap();
        let warm = first_valid + length - 1;

        let sum_tr = ps[warm + 1] - ps[first_valid];
        let mut rma = sum_tr / (length as f64);
        out_row[warm] = rma;
        let alpha = 1.0 / (length as f64);
        let mut i = warm + 1;
        while i < cols {
            let tri = tr[i];
            rma = (-alpha).mul_add(rma, rma) + alpha * tri;
            out_row[i] = rma;
            i += 1;
        }
    };
    if parallel {
        #[cfg(not(target_arch = "wasm32"))]
        {
            values
                .par_chunks_mut(cols)
                .enumerate()
                .for_each(|(row, slice)| do_row(row, slice));
        }

        #[cfg(target_arch = "wasm32")]
        {
            for (row, slice) in values.chunks_mut(cols).enumerate() {
                do_row(row, slice);
            }
        }
    } else {
        for (row, slice) in values.chunks_mut(cols).enumerate() {
            do_row(row, slice);
        }
    }

    // Reassemble the Vec<f64> from the ManuallyDrop'd buffer; ptr/len/cap
    // come from the same allocation, which is never dropped elsewhere.
    let final_values = unsafe {
        Vec::from_raw_parts(
            buf_guard.as_mut_ptr() as *mut f64,
            buf_guard.len(),
            buf_guard.capacity(),
        )
    };

    Ok(AtrBatchOutput {
        values: final_values,
        combos,
        rows,
        cols,
    })
}
1213
/// Per-row helper: locates the first valid bar, then runs the scalar kernel.
///
/// # Safety
/// Same contract as `atr_compute_into` (equal slice lengths, warmup index in
/// bounds, NaN prefix already written).
#[inline(always)]
unsafe fn atr_row_scalar(high: &[f64], low: &[f64], close: &[f64], length: usize, out: &mut [f64]) {
    let first = first_valid_hlc(high, low, close);
    atr_compute_into(high, low, close, length, first, Kernel::Scalar, out);
}
1219
/// Per-row helper: locates the first valid bar, then runs the AVX2 kernel.
///
/// # Safety
/// Requires AVX2 support; same contract as `atr_compute_into`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn atr_row_avx2(high: &[f64], low: &[f64], close: &[f64], length: usize, out: &mut [f64]) {
    let first = first_valid_hlc(high, low, close);
    atr_compute_into(high, low, close, length, first, Kernel::Avx2, out);
}
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn atr_row_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    length: usize,
    out: &mut [f64],
) {
    // Dispatch on window size: lengths up to 32 take the short-path entry,
    // everything larger takes the long path.
    match length {
        0..=32 => atr_row_avx512_short(high, low, close, length, out),
        _ => atr_row_avx512_long(high, low, close, length, out),
    }
}
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn atr_row_avx512_short(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    length: usize,
    out: &mut [f64],
) {
    // NOTE(review): currently identical to the long path — both delegate to
    // the generic AVX-512 compute routine. The split exists for the
    // `atr_row_avx512` dispatcher; presumably a specialized short-window
    // kernel was planned here. Confirm before collapsing the two.
    atr_compute_into(
        high,
        low,
        close,
        length,
        first_valid_hlc(high, low, close),
        Kernel::Avx512,
        out,
    );
}
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn atr_row_avx512_long(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    length: usize,
    out: &mut [f64],
) {
    // Find the first bar where all three series are non-NaN, then run the
    // AVX-512 ATR kernel from there to the end of the row.
    let start = first_valid_hlc(high, low, close);
    atr_compute_into(high, low, close, length, start, Kernel::Avx512, out);
}
1265
#[inline(always)]
fn precompute_tr_into_scalar(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    tr_out: &mut [f64],
) {
    // True-range precompute (scalar fallback).
    //
    // TR[first] = high - low (there is no prior close at the first valid bar);
    // for every later bar:
    //   TR[i] = max(high - low, |high - prev_close|, |low - prev_close|).
    //
    // Entries before `first` are left untouched. The explicit `>` comparisons
    // (rather than `f64::max`) preserve NaN propagation exactly as the SIMD
    // variants' scalar tails do.
    let n = tr_out.len();
    if first >= n {
        return;
    }
    tr_out[first] = high[first] - low[first];
    for i in (first + 1)..n {
        let h = high[i];
        let l = low[i];
        let prev_close = close[i - 1];
        // Start from the plain high-low range and widen it with the gaps
        // relative to the previous close.
        let mut range = h - l;
        let up_gap = (h - prev_close).abs();
        if up_gap > range {
            range = up_gap;
        }
        let down_gap = (l - prev_close).abs();
        if down_gap > range {
            range = down_gap;
        }
        tr_out[i] = range;
    }
}
1296
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
// AVX2 true-range precompute: 4 f64 lanes per iteration, scalar tail.
// Semantics match `precompute_tr_into_scalar`: TR[first] = high - low,
// then TR[i] = max(hi - lo, |hi - prev_close|, |lo - prev_close|).
// Entries before `first` are left untouched.
//
// Safety: caller must ensure high/low/close each have at least `tr_out.len()`
// elements and that the CPU supports AVX2.
unsafe fn precompute_tr_into_avx2(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    tr_out: &mut [f64],
) {
    use core::arch::x86_64::*;
    if first >= tr_out.len() {
        return;
    }
    // First valid bar has no previous close: TR is just the high-low range.
    tr_out[first] = *high.get_unchecked(first) - *low.get_unchecked(first);
    let mut i = first + 1;
    let n = tr_out.len();
    // All-ones except the sign bit: AND-ing with this clears the sign,
    // i.e. a branch-free |x| for packed doubles.
    let mask_abs = _mm256_castsi256_pd(_mm256_set1_epi64x(0x7fff_ffff_ffff_ffffu64 as i64));
    // Main loop: process 4 bars per iteration while a full vector fits.
    while i + 3 < n {
        let v_hi = _mm256_loadu_pd(high.as_ptr().add(i));
        let v_lo = _mm256_loadu_pd(low.as_ptr().add(i));
        // Previous closes: same load shifted back by one element.
        let v_pc = _mm256_loadu_pd(close.as_ptr().add(i - 1));

        let v_hl = _mm256_sub_pd(v_hi, v_lo);
        let v_hc = _mm256_and_pd(_mm256_sub_pd(v_hi, v_pc), mask_abs);
        let v_lc = _mm256_and_pd(_mm256_sub_pd(v_lo, v_pc), mask_abs);
        let v_m1 = _mm256_max_pd(v_hl, v_hc);
        let v_tr = _mm256_max_pd(v_m1, v_lc);
        _mm256_storeu_pd(tr_out.as_mut_ptr().add(i), v_tr);
        i += 4;
    }
    // Scalar tail for the remaining (< 4) bars; explicit `>` comparisons
    // keep NaN handling identical to the scalar kernel.
    while i < n {
        let hi = *high.get_unchecked(i);
        let lo = *low.get_unchecked(i);
        let pc = *close.get_unchecked(i - 1);
        let mut tr = hi - lo;
        let hc = (hi - pc).abs();
        if hc > tr {
            tr = hc;
        }
        let lc = (lo - pc).abs();
        if lc > tr {
            tr = lc;
        }
        *tr_out.get_unchecked_mut(i) = tr;
        i += 1;
    }
}
1344
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
// AVX-512 true-range precompute: 8 f64 lanes per iteration, scalar tail.
// Same contract as `precompute_tr_into_scalar` / `precompute_tr_into_avx2`.
//
// Safety: caller must ensure high/low/close each have at least `tr_out.len()`
// elements and that the CPU supports AVX-512F.
unsafe fn precompute_tr_into_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    tr_out: &mut [f64],
) {
    use core::arch::x86_64::*;
    if first >= tr_out.len() {
        return;
    }
    // First valid bar has no previous close: TR is just the high-low range.
    tr_out[first] = *high.get_unchecked(first) - *low.get_unchecked(first);
    let mut i = first + 1;
    let n = tr_out.len();
    // Sign-bit-cleared mask implements packed |x| via AND.
    let mask_abs = _mm512_castsi512_pd(_mm512_set1_epi64(0x7fff_ffff_ffff_ffffu64 as i64));
    // Main loop: 8 bars at a time while a full 512-bit vector fits.
    while i + 7 < n {
        let v_hi = _mm512_loadu_pd(high.as_ptr().add(i));
        let v_lo = _mm512_loadu_pd(low.as_ptr().add(i));
        // Previous closes: same load shifted back by one element.
        let v_pc = _mm512_loadu_pd(close.as_ptr().add(i - 1));
        let v_hl = _mm512_sub_pd(v_hi, v_lo);
        let v_hc = _mm512_and_pd(_mm512_sub_pd(v_hi, v_pc), mask_abs);
        let v_lc = _mm512_and_pd(_mm512_sub_pd(v_lo, v_pc), mask_abs);
        let v_m1 = _mm512_max_pd(v_hl, v_hc);
        let v_tr = _mm512_max_pd(v_m1, v_lc);
        _mm512_storeu_pd(tr_out.as_mut_ptr().add(i), v_tr);
        i += 8;
    }
    // Scalar tail for the remaining (< 8) bars; `>` comparisons keep NaN
    // handling identical to the scalar kernel.
    while i < n {
        let hi = *high.get_unchecked(i);
        let lo = *low.get_unchecked(i);
        let pc = *close.get_unchecked(i - 1);
        let mut tr = hi - lo;
        let hc = (hi - pc).abs();
        if hc > tr {
            tr = hc;
        }
        let lc = (lo - pc).abs();
        if lc > tr {
            tr = lc;
        }
        *tr_out.get_unchecked_mut(i) = tr;
        i += 1;
    }
}
1391
#[cfg(test)]
mod tests {
    // Test suite for the ATR indicator: reference-value accuracy, parameter
    // validation, poison-value scans of uninitialized-buffer helpers, batch
    // API checks, and (behind the `proptest` feature) property-based tests.
    // Per-kernel variants are generated by the macros at the bottom.
    use super::*;
    use crate::skip_if_unsupported;
    use crate::utilities::data_loader::read_candles_from_csv;
    use crate::utilities::enums::Kernel;
    #[cfg(feature = "proptest")]
    use proptest::prelude::*;

    // `length: None` must fall back to the default and still produce a
    // full-length output; an explicit Some(14) must do the same.
    fn check_atr_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let partial_params = AtrParams { length: None };
        let input_partial = AtrInput::from_candles(&candles, partial_params);
        let result_partial = atr_with_kernel(&input_partial, kernel)?;
        assert_eq!(result_partial.values.len(), candles.close.len());
        let zero_and_none_params = AtrParams { length: Some(14) };
        let input_zero_and_none = AtrInput::from_candles(&candles, zero_and_none_params);
        let result_zero_and_none = atr_with_kernel(&input_zero_and_none, kernel)?;
        assert_eq!(result_zero_and_none.values.len(), candles.close.len());
        Ok(())
    }

    // Compares the last five ATR values against precomputed reference values
    // for the bundled Bitfinex 4h dataset (tolerance 1e-2).
    fn check_atr_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let input = AtrInput::with_default_candles(&candles);
        let result = atr_with_kernel(&input, kernel)?;
        let expected_last_five = [916.89, 874.33, 838.45, 801.92, 811.57];
        assert!(result.values.len() >= 5, "Not enough ATR values");
        assert_eq!(
            result.values.len(),
            candles.close.len(),
            "ATR output length does not match input length!"
        );
        let start_index = result.values.len().saturating_sub(5);
        let last_five = &result.values[start_index..];
        for (i, &value) in last_five.iter().enumerate() {
            assert!(
                (value - expected_last_five[i]).abs() < 1e-2,
                "ATR value mismatch at index {}: expected {}, got {}",
                i,
                expected_last_five[i],
                value
            );
        }
        let length = 14;
        // Past the warmup, every non-NaN value should be finite.
        for val in result.values.iter().skip(length - 1) {
            if !val.is_nan() {
                assert!(
                    val.is_finite(),
                    "ATR output should be finite after RMA stabilizes"
                );
            }
        }
        Ok(())
    }

    // `with_default_candles` must select the Candles variant and the default
    // length, and still produce a full-length output.
    fn check_atr_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let input = AtrInput::with_default_candles(&candles);
        match input.data {
            AtrData::Candles { .. } => {}
            _ => panic!("Expected AtrData::Candles variant"),
        }
        let default_params = AtrParams::default();
        assert_eq!(input.params.length, default_params.length);
        let output = atr_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());
        Ok(())
    }

    // A zero length is invalid and must be rejected with an error.
    fn check_atr_zero_length(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let zero_length_params = AtrParams { length: Some(0) };
        let input_zero_length = AtrInput::from_candles(&candles, zero_length_params);
        let result_zero_length = atr_with_kernel(&input_zero_length, kernel);
        assert!(result_zero_length.is_err());
        Ok(())
    }

    // A length larger than the data set must be rejected with an error.
    fn check_atr_length_exceeding_data_length(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let too_long_params = AtrParams {
            length: Some(candles.close.len() + 10),
        };
        let input_too_long = AtrInput::from_candles(&candles, too_long_params);
        let result_too_long = atr_with_kernel(&input_too_long, kernel);
        assert!(result_too_long.is_err());
        Ok(())
    }

    // A single-bar series with length 14 cannot produce an ATR: expect Err.
    fn check_atr_very_small_data_set(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let high = [10.0];
        let low = [5.0];
        let close = [7.0];
        let params = AtrParams { length: Some(14) };
        let input = AtrInput::from_slices(&high, &low, &close, params);
        let result = atr_with_kernel(&input, kernel);
        assert!(result.is_err());
        Ok(())
    }

    // The `atr_into` (caller-provided buffer) API must match `atr` exactly,
    // bit-for-bit on non-finite values and to 1e-12 on finite ones.
    #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
    #[test]
    fn test_atr_into_matches_api() -> Result<(), Box<dyn Error>> {
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let input = AtrInput::with_default_candles(&candles);

        let baseline = atr(&input)?;

        let mut out = vec![0.0f64; candles.close.len()];
        atr_into(&input, &mut out)?;

        assert_eq!(baseline.values.len(), out.len());

        // NaN/inf compare by bit pattern; finite values by absolute tolerance.
        fn eq_or_nan_bits(a: f64, b: f64) -> bool {
            if !a.is_finite() || !b.is_finite() {
                a.to_bits() == b.to_bits()
            } else {
                (a - b).abs() <= 1e-12
            }
        }

        for i in 0..out.len() {
            assert!(
                eq_or_nan_bits(baseline.values[i], out[i]),
                "Mismatch at {}: api={} into={}",
                i,
                baseline.values[i],
                out[i]
            );
        }
        Ok(())
    }

    // Feeding an ATR output back in as high/low/close (same slice three
    // times) must still produce a full-length result without error.
    fn check_atr_with_slice_data_reinput(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let first_params = AtrParams { length: Some(14) };
        let first_input = AtrInput::from_candles(&candles, first_params);
        let first_result = atr_with_kernel(&first_input, kernel)?;
        assert_eq!(first_result.values.len(), candles.close.len());
        let second_params = AtrParams { length: Some(5) };
        let second_input = AtrInput::from_slices(
            &first_result.values,
            &first_result.values,
            &first_result.values,
            second_params,
        );
        let second_result = atr_with_kernel(&second_input, kernel)?;
        assert_eq!(second_result.values.len(), first_result.values.len());
        Ok(())
    }

    // After index 240 (well past any warmup for length 14) no NaN should
    // remain in the output.
    fn check_atr_accuracy_nan_check(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let params = AtrParams { length: Some(14) };
        let input = AtrInput::from_candles(&candles, params);
        let result = atr_with_kernel(&input, kernel)?;
        assert_eq!(result.values.len(), candles.close.len());
        if result.values.len() > 240 {
            for i in 240..result.values.len() {
                assert!(!result.values[i].is_nan());
            }
        }
        Ok(())
    }

    // Debug builds fill uninitialized buffers with recognizable bit patterns;
    // this scans the output for those sentinels to catch any cell the kernels
    // failed to overwrite. (Patterns match the helpers imported at file top:
    // alloc_with_nan_prefix, init_matrix_prefixes, make_uninit_matrix.)
    #[cfg(debug_assertions)]
    fn check_atr_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);

        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let test_lengths = vec![2, 5, 10, 14, 20, 50, 100, 200];

        for length in test_lengths {
            let params = AtrParams {
                length: Some(length),
            };
            let input = AtrInput::from_candles(&candles, params);
            let output = atr_with_kernel(&input, kernel)?;

            for (i, &val) in output.values.iter().enumerate() {
                if val.is_nan() {
                    continue;
                }

                let bits = val.to_bits();

                if bits == 0x11111111_11111111 {
                    panic!(
                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} with length={}",
                        test_name, val, bits, i, length
                    );
                }

                if bits == 0x22222222_22222222 {
                    panic!(
                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} with length={}",
                        test_name, val, bits, i, length
                    );
                }

                if bits == 0x33333333_33333333 {
                    panic!(
                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} with length={}",
                        test_name, val, bits, i, length
                    );
                }
            }
        }

        Ok(())
    }

    // Release builds have no poison fill; the check is a no-op there.
    #[cfg(not(debug_assertions))]
    fn check_atr_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    // Property-based invariants: warmup NaNs, non-negativity, TR upper bound,
    // cross-kernel agreement vs the scalar reference, convergence to zero for
    // constant prices, RMA step-size bound, and length=1 equals raw TR.
    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_atr_property(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        // Strategy: random length in 2..=50, data length in length..400,
        // then three raw price series of that length.
        let strat = (2usize..=50)
            .prop_flat_map(|length| {
                (length..400).prop_flat_map(move |data_len| {
                    (
                        prop::collection::vec(
                            (10.0f64..10000.0f64).prop_filter("finite", |x| x.is_finite()),
                            data_len,
                        ),
                        prop::collection::vec(
                            (10.0f64..10000.0f64).prop_filter("finite", |x| x.is_finite()),
                            data_len,
                        ),
                        prop::collection::vec(
                            (10.0f64..10000.0f64).prop_filter("finite", |x| x.is_finite()),
                            data_len,
                        ),
                        Just(length),
                    )
                })
            })
            .prop_map(|(high_raw, low_raw, close_raw, length)| {
                let len = high_raw.len();
                assert_eq!(low_raw.len(), len);
                assert_eq!(close_raw.len(), len);

                let mut high = Vec::with_capacity(len);
                let mut low = Vec::with_capacity(len);
                let mut close = Vec::with_capacity(len);

                // Normalize so that low <= close <= high holds at every bar.
                for i in 0..len {
                    let h = high_raw[i].max(low_raw[i]);
                    let l = high_raw[i].min(low_raw[i]);

                    let c = close_raw[i].max(l).min(h);

                    high.push(h);
                    low.push(l);
                    close.push(c);
                }

                (high, low, close, length)
            });

        proptest::test_runner::TestRunner::default().run(
            &strat,
            |(high, low, close, length)| {
                let params = AtrParams {
                    length: Some(length),
                };
                let input = AtrInput::from_slices(&high, &low, &close, params);

                let AtrOutput { values: out } = atr_with_kernel(&input, kernel)?;
                let AtrOutput { values: ref_out } = atr_with_kernel(&input, Kernel::Scalar)?;

                prop_assert_eq!(out.len(), high.len(), "Output length mismatch");

                // Warmup region must be NaN.
                for i in 0..(length - 1) {
                    prop_assert!(
                        out[i].is_nan(),
                        "Expected NaN during warmup at index {}, got {}",
                        i,
                        out[i]
                    );
                }

                // ATR is an average of non-negative true ranges.
                for (i, &val) in out.iter().enumerate().skip(length - 1) {
                    if !val.is_nan() {
                        prop_assert!(
                            val >= 0.0,
                            "ATR must be non-negative at index {}: got {}",
                            i,
                            val
                        );
                    }
                }

                // The ATR can never exceed the largest single true range seen.
                let mut max_true_range = 0.0f64;
                for i in 0..high.len() {
                    let tr = if i == 0 {
                        high[0] - low[0]
                    } else {
                        let hl = high[i] - low[i];
                        let hc = (high[i] - close[i - 1]).abs();
                        let lc = (low[i] - close[i - 1]).abs();
                        hl.max(hc).max(lc)
                    };
                    max_true_range = max_true_range.max(tr);
                }

                for (i, &val) in out.iter().enumerate().skip(length - 1) {
                    if !val.is_nan() && val.is_finite() {
                        prop_assert!(
                            val <= max_true_range + 1e-9,
                            "ATR at index {} exceeds max true range: {} > {}",
                            i,
                            val,
                            max_true_range
                        );
                    }
                }

                // Kernel under test must agree with the scalar reference:
                // bit-identical non-finite values, finite within 1e-9 or 4 ULP.
                for i in 0..out.len() {
                    let y = out[i];
                    let r = ref_out[i];

                    if !y.is_finite() || !r.is_finite() {
                        prop_assert_eq!(
                            y.to_bits(),
                            r.to_bits(),
                            "NaN/infinite mismatch at index {}: {} vs {}",
                            i,
                            y,
                            r
                        );
                        continue;
                    }

                    let y_bits = y.to_bits();
                    let r_bits = r.to_bits();
                    let ulp_diff: u64 = y_bits.abs_diff(r_bits);

                    prop_assert!(
                        (y - r).abs() <= 1e-9 || ulp_diff <= 4,
                        "Kernel mismatch at index {}: {} vs {} (ULP={})",
                        i,
                        y,
                        r,
                        ulp_diff
                    );
                }

                // For constant prices the true range is 0, so the RMA decays
                // toward 0 given enough bars.
                let first_price = high[0];
                let is_constant = high.iter().all(|&h| (h - first_price).abs() < 1e-10)
                    && low.iter().all(|&l| (l - first_price).abs() < 1e-10)
                    && close.iter().all(|&c| (c - first_price).abs() < 1e-10);

                if is_constant {
                    if out.len() >= length * 3 {
                        let last_values = &out[out.len().saturating_sub(5)..];
                        for &val in last_values {
                            if !val.is_nan() && val.is_finite() {
                                prop_assert!(
                                    val < 1e-6,
                                    "ATR should converge to 0 for constant prices, got {}",
                                    val
                                );
                            }
                        }
                    }
                }

                // RMA update: atr[i] = atr[i-1] + (tr - atr[i-1]) / length, so
                // each step moves by at most |tr - atr[i-1]| / length.
                if out.len() >= length + 10 {
                    for i in (length + 1)..out.len() {
                        if !out[i].is_nan() && !out[i - 1].is_nan() {
                            let tr = {
                                let hl = high[i] - low[i];
                                let hc = (high[i] - close[i - 1]).abs();
                                let lc = (low[i] - close[i - 1]).abs();
                                hl.max(hc).max(lc)
                            };

                            let expected_change_bound = (tr - out[i - 1]).abs() / length as f64;
                            let actual_change = (out[i] - out[i - 1]).abs();

                            prop_assert!(
                                actual_change <= expected_change_bound + 1e-9,
                                "ATR change at index {} exceeds RMA bound: {} > {}",
                                i,
                                actual_change,
                                expected_change_bound
                            );
                        }
                    }
                }

                // Degenerate case: length 1 makes the ATR equal the raw TR.
                if length == 1 {
                    for i in 0..out.len() {
                        if !out[i].is_nan() {
                            let tr = if i == 0 {
                                high[0] - low[0]
                            } else {
                                let hl = high[i] - low[i];
                                let hc = (high[i] - close[i - 1]).abs();
                                let lc = (low[i] - close[i - 1]).abs();
                                hl.max(hc).max(lc)
                            };
                            prop_assert!(
                                (out[i] - tr).abs() <= 1e-9,
                                "Length=1 ATR should equal TR at index {}: {} vs {}",
                                i,
                                out[i],
                                tr
                            );
                        }
                    }
                }

                Ok(())
            },
        )?;

        Ok(())
    }

    // Expands each check_* function into per-kernel #[test] variants:
    // scalar always, AVX2/AVX-512 only under nightly-avx on x86_64.
    macro_rules! generate_all_atr_tests {
        ($($test_fn:ident),*) => {
            paste::paste! {
                $(
                    #[test]
                    fn [<$test_fn _scalar_f64>]() {
                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
                    }
                )*
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                $(
                    #[test]
                    fn [<$test_fn _avx2_f64>]() {
                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
                    }
                    #[test]
                    fn [<$test_fn _avx512_f64>]() {
                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
                    }
                )*
            }
        }
    }

    generate_all_atr_tests!(
        check_atr_partial_params,
        check_atr_accuracy,
        check_atr_default_candles,
        check_atr_zero_length,
        check_atr_length_exceeding_data_length,
        check_atr_very_small_data_set,
        check_atr_with_slice_data_reinput,
        check_atr_accuracy_nan_check,
        check_atr_no_poison
    );

    #[cfg(feature = "proptest")]
    generate_all_atr_tests!(check_atr_property);

    // The batch builder's row for default params must match the single-run
    // reference values for the bundled dataset.
    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);

        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;
        let output = AtrBatchBuilder::new().kernel(kernel).apply_candles(&c)?;

        let def = AtrParams::default();
        let row = output.values_for(&def).expect("default row missing");

        assert_eq!(row.len(), c.close.len());

        let expected = [916.89, 874.33, 838.45, 801.92, 811.57];
        let start = row.len() - 5;
        for (i, &v) in row[start..].iter().enumerate() {
            assert!(
                (v - expected[i]).abs() < 1e-2,
                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
            );
        }
        Ok(())
    }

    // Same sentinel scan as check_atr_no_poison, but over the batch matrix for
    // a spread of (start, end, step) length ranges, including a zero step.
    #[cfg(debug_assertions)]
    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);

        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;

        let test_configs = vec![
            (2, 10, 1),
            (5, 25, 5),
            (10, 50, 10),
            (14, 140, 14),
            (50, 200, 50),
            (100, 100, 0),
        ];

        for (start, end, step) in test_configs {
            let output = AtrBatchBuilder::new()
                .kernel(kernel)
                .length_range(start, end, step)
                .apply_candles(&c)?;

            for (idx, &val) in output.values.iter().enumerate() {
                if val.is_nan() {
                    continue;
                }

                let bits = val.to_bits();
                // Recover (row, col) from the flat row-major index for the
                // panic message.
                let row = idx / output.cols;
                let col = idx % output.cols;
                let length = output.combos[row].length.unwrap_or(14);

                if bits == 0x11111111_11111111 {
                    panic!(
                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at row {} col {} (flat index {}) with length={} in range ({},{},{})",
                        test, val, bits, row, col, idx, length, start, end, step
                    );
                }

                if bits == 0x22222222_22222222 {
                    panic!(
                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at row {} col {} (flat index {}) with length={} in range ({},{},{})",
                        test, val, bits, row, col, idx, length, start, end, step
                    );
                }

                if bits == 0x33333333_33333333 {
                    panic!(
                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at row {} col {} (flat index {}) with length={} in range ({},{},{})",
                        test, val, bits, row, col, idx, length, start, end, step
                    );
                }
            }
        }

        Ok(())
    }

    // Release builds: no poison fill, no-op.
    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    // Expands a batch check into scalar/AVX2/AVX-512/auto-detect variants
    // using the *Batch kernel enum values.
    macro_rules! gen_batch_tests {
        ($fn_name:ident) => {
            paste::paste! {
                #[test] fn [<$fn_name _scalar>]() {
                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx2>]() {
                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx512>]() {
                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
                }
                #[test] fn [<$fn_name _auto_detect>]() {
                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
                }
            }
        };
    }
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
}
1995
#[cfg(feature = "python")]
use pyo3::create_exception;

// Dedicated Python exception classes for each ATR error case, exposed in the
// `atr` module. All subclass `ValueError`, so Python callers catching
// `ValueError` keep working while gaining more specific types.
#[cfg(feature = "python")]
create_exception!(atr, InvalidLengthError, PyValueError);
#[cfg(feature = "python")]
create_exception!(atr, InconsistentSliceLengthsError, PyValueError);
#[cfg(feature = "python")]
create_exception!(atr, NoCandlesAvailableError, PyValueError);
#[cfg(feature = "python")]
create_exception!(atr, NotEnoughDataError, PyValueError);
#[cfg(feature = "python")]
create_exception!(atr, EmptyInputDataError, PyValueError);
#[cfg(feature = "python")]
create_exception!(atr, AllValuesNaNError, PyValueError);
#[cfg(feature = "python")]
create_exception!(atr, InvalidPeriodError, PyValueError);
#[cfg(feature = "python")]
create_exception!(atr, NotEnoughValidDataError, PyValueError);
#[cfg(feature = "python")]
create_exception!(atr, OutputLengthMismatchError, PyValueError);
#[cfg(feature = "python")]
create_exception!(atr, InvalidRangeError, PyValueError);
#[cfg(feature = "python")]
create_exception!(atr, InvalidKernelForBatchError, PyValueError);
2021
#[cfg(feature = "python")]
impl From<AtrError> for PyErr {
    /// Map each ATR domain error onto its dedicated Python exception class
    /// (all of which subclass `ValueError` on the Python side).
    fn from(error: AtrError) -> PyErr {
        match error {
            AtrError::InvalidLength { length } => InvalidLengthError::new_err(format!(
                "Invalid length for ATR calculation (length={}).",
                length
            )),
            AtrError::InconsistentSliceLengths {
                high_len,
                low_len,
                close_len,
            } => InconsistentSliceLengthsError::new_err(format!(
                "Inconsistent slice lengths for ATR calculation: high={}, low={}, close={}",
                high_len, low_len, close_len
            )),
            AtrError::NoCandlesAvailable => {
                NoCandlesAvailableError::new_err("No candles available for ATR calculation.")
            }
            AtrError::NotEnoughData { length, data_len } => NotEnoughDataError::new_err(format!(
                "Not enough data to calculate ATR: length={}, data length={}",
                length, data_len
            )),
            AtrError::EmptyInputData => {
                EmptyInputDataError::new_err("atr: Input data slice is empty.")
            }
            AtrError::AllValuesNaN => AllValuesNaNError::new_err("atr: All values are NaN."),
            AtrError::InvalidPeriod { period, data_len } => InvalidPeriodError::new_err(format!(
                "atr: Invalid period: period = {}, data length = {}",
                period, data_len
            )),
            AtrError::NotEnoughValidData { needed, valid } => {
                NotEnoughValidDataError::new_err(format!(
                    "atr: Not enough valid data: needed = {}, valid = {}",
                    needed, valid
                ))
            }
            AtrError::OutputLengthMismatch { expected, got } => {
                OutputLengthMismatchError::new_err(format!(
                    "atr: Output slice length mismatch: expected = {}, got = {}",
                    expected, got
                ))
            }
            AtrError::InvalidRange { start, end, step } => InvalidRangeError::new_err(format!(
                "atr: Invalid range: start = {}, end = {}, step = {}",
                start, end, step
            )),
            AtrError::InvalidKernelForBatch(kernel) => InvalidKernelForBatchError::new_err(
                format!("atr: Invalid kernel type for batch operation: {:?}", kernel),
            ),
        }
    }
}
2076
2077#[cfg(all(feature = "python", feature = "cuda"))]
2078use crate::cuda::atr_wrapper::DeviceArrayF32Atr;
2079#[cfg(all(feature = "python", feature = "cuda"))]
2080use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
2081
#[cfg(all(feature = "python", feature = "cuda"))]
// Python-visible wrapper around an ATR result matrix living in CUDA device
// memory. Marked `unsendable` so the GIL-bound CUDA context is never moved
// across Python threads.
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct DeviceArrayF32Py {
    // Owned device buffer plus (rows, cols) shape, context handle and device
    // id — see the accessor methods in the `#[pymethods]` impl below.
    pub(crate) inner: DeviceArrayF32Atr,
}
2087
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl DeviceArrayF32Py {
    // Exposes the buffer via the `__cuda_array_interface__` protocol
    // (version 3) so consumers like Numba/CuPy can wrap it zero-copy.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let d = PyDict::new(py);
        d.set_item("shape", (self.inner.rows, self.inner.cols))?;
        // "<f4": little-endian 32-bit float.
        d.set_item("typestr", "<f4")?;
        // Row-major strides in BYTES: one row is `cols` f32s.
        d.set_item(
            "strides",
            (
                self.inner.cols * std::mem::size_of::<f32>(),
                std::mem::size_of::<f32>(),
            ),
        )?;
        // (pointer, read_only=false) — consumer may write into the buffer.
        d.set_item("data", (self.inner.device_ptr() as usize, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    // DLPack device tuple: device type 2 (CUDA in the DLPack enum) plus the
    // ordinal the buffer was allocated on.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.inner.device_id as i32)
    }

    // DLPack export. Consumes the device buffer (ownership moves into the
    // returned capsule); this object is left holding an empty placeholder.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<PyObject>,
        max_version: Option<PyObject>,
        dl_device: Option<PyObject>,
        copy: Option<PyObject>,
    ) -> PyResult<PyObject> {
        use cust::memory::DeviceBuffer;

        // If the consumer requested a specific device, it must match the one
        // this buffer lives on; cross-device copies are not implemented.
        let (kdl, alloc_dev) = self.__dlpack_device__();
        if let Some(dev_obj) = dl_device.as_ref() {
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(PyBufferError::new_err(
                            "device copy not implemented for __dlpack__",
                        ));
                    } else {
                        return Err(PyBufferError::new_err(
                            "__dlpack__: requested device does not match producer buffer",
                        ));
                    }
                }
            }
        }
        // Stream synchronization is not handled here; the argument is ignored.
        let _ = stream;

        // copy=True is explicitly unsupported (export is move-only).
        if let Some(copy_obj) = copy.as_ref() {
            let do_copy: bool = copy_obj.extract(py)?;
            if do_copy {
                return Err(PyBufferError::new_err(
                    "__dlpack__(copy=True) not supported for atr CUDA buffers",
                ));
            }
        }

        // Swap a zero-length placeholder into `self.inner` so the real buffer
        // can be moved out and handed to the DLPack capsule.
        let dummy =
            DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let rows = self.inner.rows;
        let cols = self.inner.cols;
        let ctx = self.inner.ctx.clone();
        let device_id = self.inner.device_id;
        let inner = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32Atr {
                buf: dummy,
                rows: 0,
                cols: 0,
                ctx,
                device_id,
            },
        );

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, inner.buf, rows, cols, alloc_dev, max_version_bound)
    }
}
2177
2178#[inline(always)]
2179fn atr_prepare_from_input<'a>(
2180 input: &'a AtrInput,
2181) -> Result<(&'a [f64], &'a [f64], &'a [f64], usize, usize), AtrError> {
2182 let (high, low, close) = match &input.data {
2183 AtrData::Candles { candles } => (&candles.high[..], &candles.low[..], &candles.close[..]),
2184 AtrData::Slices { high, low, close } => (*high, *low, *close),
2185 };
2186
2187 let length = input.params.length.unwrap_or(14);
2188 let (high, low, close, length) = atr_prepare(high, low, close, length)?;
2189 let warmup = length - 1;
2190 Ok((high, low, close, length, warmup))
2191}
2192
2193#[inline(always)]
2194fn atr_prepare<'a>(
2195 high: &'a [f64],
2196 low: &'a [f64],
2197 close: &'a [f64],
2198 length: usize,
2199) -> Result<(&'a [f64], &'a [f64], &'a [f64], usize), AtrError> {
2200 if high.len() != low.len() || low.len() != close.len() {
2201 return Err(AtrError::InconsistentSliceLengths {
2202 high_len: high.len(),
2203 low_len: low.len(),
2204 close_len: close.len(),
2205 });
2206 }
2207
2208 if close.is_empty() {
2209 return Err(AtrError::NoCandlesAvailable);
2210 }
2211
2212 if length == 0 {
2213 return Err(AtrError::InvalidLength { length });
2214 }
2215
2216 if length > close.len() {
2217 return Err(AtrError::NotEnoughData {
2218 length,
2219 data_len: close.len(),
2220 });
2221 }
2222
2223 Ok((high, low, close, length))
2224}
2225
#[cfg(feature = "python")]
#[pyfunction(name = "atr")]
#[pyo3(signature = (high, low, close, length=14, kernel=None))]
/// Python entry point: compute the ATR of three 1-D f64 arrays.
///
/// `length` defaults to 14; `kernel` optionally selects a SIMD kernel by
/// name (validated, batch kernels rejected). Returns a NumPy array of the
/// same length as the inputs.
pub fn atr_py<'py>(
    py: Python<'py>,
    high: numpy::PyReadonlyArray1<'py, f64>,
    low: numpy::PyReadonlyArray1<'py, f64>,
    close: numpy::PyReadonlyArray1<'py, f64>,
    length: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    let kernel_enum = validate_kernel(kernel, false)?;

    // Requires contiguous input arrays; errors otherwise.
    let high_slice = high.as_slice()?;
    let low_slice = low.as_slice()?;
    let close_slice = close.as_slice()?;

    let params = AtrParams {
        length: Some(length),
    };
    let input = AtrInput::from_slices(high_slice, low_slice, close_slice, params);

    // Release the GIL while the Rust kernel runs. Route domain errors through
    // `From<AtrError> for PyErr` so Python callers see the dedicated exception
    // classes (InvalidLengthError, NotEnoughDataError, ... — all subclasses of
    // ValueError) instead of a stringly-typed generic PyValueError.
    let result_vec: Vec<f64> = py
        .allow_threads(|| atr_with_kernel(&input, kernel_enum).map(|output| output.values))
        .map_err(PyErr::from)?;

    Ok(result_vec.into_pyarray(py))
}
2256
#[cfg(feature = "python")]
// Python-visible streaming ATR, exported to Python as `AtrStream`.
#[pyclass(name = "AtrStream")]
pub struct AtrStreamPy {
    // Underlying Rust streaming state; see the `#[pymethods]` impl below.
    stream: AtrStream,
}
2262
#[cfg(feature = "python")]
#[pymethods]
impl AtrStreamPy {
    #[new]
    /// Creates a streaming ATR; `length=None` uses the `AtrParams` default.
    ///
    /// Errors from `AtrStream::try_new` propagate via `?` (converted to a
    /// Python exception by the crate's `From<AtrError> for PyErr` impl —
    /// defined elsewhere in the crate).
    pub fn new(length: Option<usize>) -> PyResult<Self> {
        let params = AtrParams { length };
        let stream = AtrStream::try_new(params)?;
        Ok(Self { stream })
    }

    /// Feeds one bar (high/low/close) and returns the ATR value if the
    /// underlying stream produced one, `None` otherwise (presumably during
    /// warmup — see `AtrStream::update`).
    pub fn update(&mut self, high: f64, low: f64, close: f64) -> Option<f64> {
        self.stream.update(high, low, close)
    }
}
2277
#[cfg(feature = "python")]
#[pyfunction(name = "atr_batch")]
#[pyo3(signature = (high, low, close, length_range, kernel=None))]
/// Python binding: computes ATR for every length in `length_range`
/// (start, end, step), returning a dict with a `(rows, cols)` `values`
/// matrix and the matching `lengths` vector.
///
/// Errors (kernel validation, overflow, computation) surface as `ValueError`.
pub fn atr_batch_py<'py>(
    py: Python<'py>,
    high: numpy::PyReadonlyArray1<'py, f64>,
    low: numpy::PyReadonlyArray1<'py, f64>,
    close: numpy::PyReadonlyArray1<'py, f64>,
    length_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    // `IntoPyArray`/`PyArrayMethods` come from the module-level cfg(python)
    // imports; the previous function-local `use` was redundant.
    let k = validate_kernel(kernel, true)?;
    let hs = high.as_slice()?;
    let ls = low.as_slice()?;
    let cs = close.as_slice()?;

    let range = AtrBatchRange {
        length: length_range,
    };
    let combos = expand_grid(&range);
    let rows = combos.len();
    let cols = cs.len();
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("atr_batch: rows*cols overflow"))?;

    // SAFETY: the uninitialized array is fully written by
    // `atr_batch_inner_into` before it is returned to Python.
    let out_arr = unsafe { numpy::PyArray1::<f64>::new(py, [total], false) };
    // SAFETY: `out_arr` was just created and is not aliased; the GIL is held.
    let buf = unsafe { out_arr.as_slice_mut()? };

    py.allow_threads(|| {
        // Normalize the user's kernel choice to a batch kernel, then map it
        // back to the per-row SIMD kernel the inner routine expects.
        let simd = match match k {
            Kernel::Auto => detect_best_batch_kernel(),
            k if k.is_batch() => k,
            Kernel::Scalar => Kernel::ScalarBatch,
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx2 => Kernel::Avx2Batch,
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx512 => Kernel::Avx512Batch,
            _ => Kernel::ScalarBatch,
        } {
            Kernel::Avx512Batch => Kernel::Avx512,
            Kernel::Avx2Batch => Kernel::Avx2,
            Kernel::ScalarBatch => Kernel::Scalar,
            _ => unreachable!(),
        };
        // Discard the inner routine's value; the removed `.map_err(|e| e)`
        // was a no-op.
        atr_batch_inner_into(hs, ls, cs, &range, simd, true, buf).map(|_| ())
    })
    .map_err(|e: AtrError| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "lengths",
        combos
            .iter()
            .map(|p| p.length.expect("expand_grid always sets length"))
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    Ok(dict)
}
2343
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "atr_cuda_batch_dev")]
#[pyo3(signature = (high, low, close, length_range, device_id=0))]
/// Python binding: runs the ATR length sweep on the GPU and returns a
/// device-resident array handle (no host copy).
pub fn atr_cuda_batch_dev_py(
    py: Python<'_>,
    high: numpy::PyReadonlyArray1<'_, f32>,
    low: numpy::PyReadonlyArray1<'_, f32>,
    close: numpy::PyReadonlyArray1<'_, f32>,
    length_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let (h, l, c) = (high.as_slice()?, low.as_slice()?, close.as_slice()?);
    if !(h.len() == l.len() && l.len() == c.len()) {
        return Err(PyValueError::new_err("input length mismatch"));
    }
    let sweep = AtrBatchRange {
        length: length_range,
    };
    // Release the GIL while the device context is created and the sweep runs.
    let inner = py.allow_threads(|| {
        CudaAtr::new(device_id)
            .map_err(|e| PyValueError::new_err(e.to_string()))?
            .atr_batch_dev(h, l, c, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;
    Ok(DeviceArrayF32Py { inner })
}
2374
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "atr_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm, low_tm, close_tm, cols, rows, length, device_id=0))]
/// Python binding: computes ATR on the GPU for many time-major series with a
/// single shared `length`, returning a device-resident array handle.
pub fn atr_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    high_tm: numpy::PyReadonlyArray1<'_, f32>,
    low_tm: numpy::PyReadonlyArray1<'_, f32>,
    close_tm: numpy::PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    length: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let (h, l, c) = (
        high_tm.as_slice()?,
        low_tm.as_slice()?,
        close_tm.as_slice()?,
    );
    // Each time-major buffer must hold exactly cols * rows samples.
    let expected = cols
        .checked_mul(rows)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;
    if [h.len(), l.len(), c.len()].iter().any(|&n| n != expected) {
        return Err(PyValueError::new_err("time-major input length mismatch"));
    }
    // Release the GIL while the device computation runs.
    let inner = py.allow_threads(|| {
        CudaAtr::new(device_id)
            .map_err(|e| PyValueError::new_err(e.to_string()))?
            .atr_many_series_one_param_time_major_dev(h, l, c, cols, rows, length)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;
    Ok(DeviceArrayF32Py { inner })
}
2407
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
/// Computes ATR directly into a caller-provided slice.
///
/// `dst` must have exactly the input length. The prefix up to the warmup
/// point (`first_valid + length - 1`) is set to NaN; the remainder is filled
/// by the selected compute kernel.
///
/// # Errors
/// Returns an `AtrError` for inconsistent slice lengths, empty input, a zero
/// or too-large `length`, too few valid (non-NaN-leading) bars, or a `dst`
/// whose length differs from the input.
pub fn atr_into_slice(dst: &mut [f64], input: &AtrInput, kern: Kernel) -> Result<(), AtrError> {
    let (high, low, close) = match &input.data {
        AtrData::Candles { candles } => (&candles.high[..], &candles.low[..], &candles.close[..]),
        AtrData::Slices { high, low, close } => (*high, *low, *close),
    };

    let length = input.params.length.unwrap_or(14);
    let (high, low, close, length) = atr_prepare(high, low, close, length)?;
    let first = first_valid_hlc(high, low, close);
    let valid = close.len().saturating_sub(first);
    if valid < length {
        return Err(AtrError::NotEnoughValidData {
            needed: length,
            valid,
        });
    }
    let warm = first + length - 1;

    if dst.len() != close.len() {
        return Err(AtrError::OutputLengthMismatch {
            expected: close.len(),
            got: dst.len(),
        });
    }

    // NaN-fill the warmup prefix; `slice::fill` replaces the manual loop.
    dst[..warm].fill(f64::NAN);

    // In this wasm build, Auto resolves to the scalar kernel.
    let k = match kern {
        Kernel::Auto => Kernel::Scalar,
        k => k,
    };
    atr_compute_into(high, low, close, length, first, k, dst);
    Ok(())
}
2445
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "atr")]
/// JS binding: computes ATR over the given series and returns a fresh vector.
pub fn atr_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    length: usize,
) -> Result<Vec<f64>, JsError> {
    let input = AtrInput::from_slices(
        high,
        low,
        close,
        AtrParams {
            length: Some(length),
        },
    );

    // Allocate the output and let the slice-based entry point fill it.
    let mut out = vec![0.0; high.len()];
    atr_into_slice(&mut out, &input, Kernel::Auto).map_err(|e| JsError::new(&e.to_string()))?;
    Ok(out)
}
2464
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "atrBatch")]
/// JS binding: computes ATR for every length in `[length_start..=length_end]`
/// stepped by `length_step`, returning the flattened row-major value matrix.
pub fn atr_batch_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    length_start: usize,
    length_end: usize,
    length_step: usize,
) -> Result<Vec<f64>, JsError> {
    let range = AtrBatchRange {
        length: (length_start, length_end, length_step),
    };
    match atr_batch_with_kernel(high, low, close, &range, Kernel::Auto) {
        Ok(out) => Ok(out.values),
        Err(e) => Err(JsError::new(&e.to_string())),
    }
}
2482
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "atrBatchMetadata")]
/// JS binding: returns the length of each row of a batch sweep (as f64 for
/// JS consumption), in the same order the batch functions produce rows.
pub fn atr_batch_metadata_js(
    length_start: usize,
    length_end: usize,
    length_step: usize,
) -> Vec<f64> {
    let combos = expand_grid(&AtrBatchRange {
        length: (length_start, length_end, length_step),
    });

    let mut lengths = Vec::with_capacity(combos.len());
    for combo in &combos {
        lengths.push(combo.length.unwrap_or(14) as f64);
    }
    lengths
}
2500
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "atr_batch", skip_jsdoc)]
/// JS binding: batch ATR driven by a config object
/// `{ length_range: [start, end, step] }`, returning
/// `{ values, combos, rows, cols }` as a JS value.
pub fn atr_batch_unified_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    config: JsValue,
) -> Result<JsValue, JsError> {
    // Shape of the incoming JS config object.
    #[derive(Deserialize)]
    struct BatchConfig {
        length_range: [usize; 3],
    }

    // Shape of the value handed back to JS.
    #[derive(Serialize)]
    struct BatchResult {
        values: Vec<f64>,
        combos: Vec<AtrParams>,
        rows: usize,
        cols: usize,
    }

    let cfg: BatchConfig =
        serde_wasm_bindgen::from_value(config).map_err(|e| JsError::new(&e.to_string()))?;
    let [start, end, step] = cfg.length_range;
    let range = AtrBatchRange {
        length: (start, end, step),
    };

    let out = atr_batch_with_kernel(high, low, close, &range, Kernel::Auto)
        .map_err(|e| JsError::new(&e.to_string()))?;

    serde_wasm_bindgen::to_value(&BatchResult {
        values: out.values,
        combos: out.combos,
        rows: out.rows,
        cols: out.cols,
    })
    .map_err(|e| JsError::new(&e.to_string()))
}
2545
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// Allocates a zero-initialized buffer of `len` f64s and leaks it to JS.
///
/// The buffer is allocated as a boxed slice so its allocation is *exactly*
/// `len` elements. The matching `atr_free(ptr, len)` rebuilds a `Vec` via
/// `Vec::from_raw_parts(ptr, len, len)`, whose contract requires the exact
/// allocated capacity — `Vec::with_capacity(len)` (the previous approach)
/// only guarantees capacity >= `len`, so the old alloc/free pairing was
/// potential undefined behavior. `len == 0` yields a dangling (but valid)
/// pointer, which `atr_free(ptr, 0)` also accepts.
pub fn atr_alloc(len: usize) -> *mut f64 {
    let boxed: Box<[f64]> = vec![0.0; len].into_boxed_slice();
    Box::into_raw(boxed) as *mut f64
}
2554
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// Releases a buffer previously handed to JS by `atr_alloc(len)`.
///
/// `ptr` must originate from `atr_alloc` with this exact `len`; passing any
/// other pointer, a wrong `len`, or freeing twice is undefined behavior.
pub fn atr_free(ptr: *mut f64, len: usize) {
    unsafe {
        // SAFETY: caller guarantees `ptr`/`len` come from `atr_alloc`, so
        // reconstructing the Vec (len == capacity == `len`) and dropping it
        // deallocates that same allocation exactly once.
        let _ = Vec::from_raw_parts(ptr, len, len);
    }
}
2562
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// Computes ATR from raw WASM linear-memory pointers, writing `len` values
/// to `out_ptr`.
///
/// Each input pointer must address `len` readable f64s and `out_ptr` `len`
/// writable f64s. If the output aliases an input, the result is staged in a
/// temporary buffer so the computation never reads data it has already
/// overwritten.
pub fn atr_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    length: usize,
) -> Result<(), JsError> {
    if high_ptr.is_null() || low_ptr.is_null() || close_ptr.is_null() || out_ptr.is_null() {
        return Err(JsError::new("null pointer passed to atr_into"));
    }

    // SAFETY: caller guarantees each pointer addresses `len` valid f64s in
    // WASM linear memory (e.g. buffers obtained from `atr_alloc`).
    unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);

        let params = AtrParams {
            length: Some(length),
        };
        let input = AtrInput::from_slices(high, low, close, params);

        // Aliased output: compute into scratch, then copy, to avoid
        // clobbering an input mid-computation.
        if high_ptr == out_ptr || low_ptr == out_ptr || close_ptr == out_ptr {
            let mut temp = vec![0.0; len];
            atr_into_slice(&mut temp, &input, Kernel::Auto)
                .map_err(|e| JsError::new(&e.to_string()))?;
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            out.copy_from_slice(&temp);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            atr_into_slice(out, &input, Kernel::Auto).map_err(|e| JsError::new(&e.to_string()))?;
        }
        Ok(())
    }
}
2600
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// Computes a batch ATR length sweep from raw WASM linear-memory pointers.
///
/// Inputs are `len` f64s each; `out_ptr` must address `rows * cols` writable
/// f64s, where `rows` is the number of lengths expanded from
/// `(length_start, length_end, length_step)` and `cols == len`.
pub fn atr_batch_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    length_start: usize,
    length_end: usize,
    length_step: usize,
) -> Result<(), JsError> {
    if high_ptr.is_null() || low_ptr.is_null() || close_ptr.is_null() || out_ptr.is_null() {
        return Err(JsError::new("null pointer passed to atr_batch_into"));
    }

    // SAFETY: caller guarantees the input pointers address `len` valid f64s
    // and `out_ptr` addresses `rows * cols` writable f64s in linear memory.
    unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);

        let range = AtrBatchRange {
            length: (length_start, length_end, length_step),
        };

        let combos = expand_grid(&range);
        let rows = combos.len();
        let cols = len;
        let output_size = rows
            .checked_mul(cols)
            .ok_or_else(|| JsError::new("atr_batch_into: rows*cols overflow"))?;

        // Aliased output: compute the whole batch into a fresh allocation,
        // then copy into place, so inputs are never overwritten mid-sweep.
        if high_ptr == out_ptr || low_ptr == out_ptr || close_ptr == out_ptr {
            let output = atr_batch_with_kernel(high, low, close, &range, Kernel::Auto)
                .map_err(|e| JsError::new(&e.to_string()))?;
            let out_slice = std::slice::from_raw_parts_mut(out_ptr, output_size);
            out_slice.copy_from_slice(&output.values);
        } else {
            let out_slice = std::slice::from_raw_parts_mut(out_ptr, output_size);

            // Map the detected batch kernel back to the per-row SIMD kernel
            // expected by `atr_batch_inner_into`.
            let kernel = match detect_best_batch_kernel() {
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                Kernel::Avx512Batch => Kernel::Avx512,
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                _ => Kernel::Scalar,
            };

            atr_batch_inner_into(high, low, close, &range, kernel, false, out_slice)
                .map_err(|e| JsError::new(&e.to_string()))?;
        }

        Ok(())
    }
}
2657
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
#[deprecated(
    since = "1.0.0",
    note = "For weight reuse patterns, use the fast/unsafe API with persistent buffers"
)]
/// Deprecated JS-facing streaming ATR context wrapping the native `AtrStream`.
pub struct AtrContext {
    // Incremental ATR state; rebuilt from scratch by `reset()`.
    stream: AtrStream,
}
2667
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
#[allow(deprecated)]
impl AtrContext {
    #[wasm_bindgen(constructor)]
    #[deprecated(
        since = "1.0.0",
        note = "For weight reuse patterns, use the fast/unsafe API with persistent buffers"
    )]
    /// Builds a streaming ATR context for the given period.
    pub fn new(length: usize) -> Result<AtrContext, JsError> {
        let stream = AtrStream::try_new(AtrParams {
            length: Some(length),
        })
        .map_err(|e| JsError::new(&e.to_string()))?;
        Ok(AtrContext { stream })
    }

    #[wasm_bindgen]
    /// Feeds one bar and returns the ATR value when the stream produces one.
    pub fn update(&mut self, high: f64, low: f64, close: f64) -> Option<f64> {
        self.stream.update(high, low, close)
    }

    #[wasm_bindgen]
    /// Discards all accumulated state by rebuilding the stream with its
    /// current period.
    pub fn reset(&mut self) -> Result<(), JsError> {
        let fresh = AtrStream::try_new(AtrParams {
            length: Some(self.stream.length),
        })
        .map_err(|e| JsError::new(&e.to_string()))?;
        self.stream = fresh;
        Ok(())
    }
}
2700
#[cfg(feature = "python")]
/// Registers the ATR exception types on the given Python module so callers
/// can catch them by name from Python.
pub fn register_atr_exceptions(m: &Bound<'_, PyModule>) -> PyResult<()> {
    // Fetch the interpreter token once instead of per registration.
    let py = m.py();
    m.add("InvalidLengthError", py.get_type::<InvalidLengthError>())?;
    m.add(
        "InconsistentSliceLengthsError",
        py.get_type::<InconsistentSliceLengthsError>(),
    )?;
    m.add(
        "NoCandlesAvailableError",
        py.get_type::<NoCandlesAvailableError>(),
    )?;
    m.add("NotEnoughDataError", py.get_type::<NotEnoughDataError>())?;
    Ok(())
}