1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{
4 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
5 make_uninit_matrix,
6};
7use aligned_vec::{AVec, CACHELINE_ALIGN};
8#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
9use core::arch::x86_64::*;
10#[cfg(not(target_arch = "wasm32"))]
11use rayon::prelude::*;
12use std::convert::AsRef;
13use thiserror::Error;
14
15#[cfg(feature = "python")]
16use crate::utilities::kernel_validation::validate_kernel;
17#[cfg(feature = "python")]
18use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
19#[cfg(feature = "python")]
20use pyo3::exceptions::PyValueError;
21#[cfg(feature = "python")]
22use pyo3::prelude::*;
23#[cfg(feature = "python")]
24use pyo3::types::PyDict;
25#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
26use serde::{Deserialize, Serialize};
27#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
28use wasm_bindgen::prelude::*;
29
30#[derive(Debug, Clone)]
31pub enum NatrData<'a> {
32 Candles {
33 candles: &'a Candles,
34 },
35 Slices {
36 high: &'a [f64],
37 low: &'a [f64],
38 close: &'a [f64],
39 },
40}
41
/// Result of a single NATR computation.
#[derive(Clone, Debug)]
pub struct NatrOutput {
    /// NATR values, one slot per input bar.
    pub values: Vec<f64>,
}
46
/// Tunable parameters of the NATR indicator.
#[derive(Clone, Debug)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct NatrParams {
    /// Smoothing period in bars; `None` means "use the default" (14 in this file).
    pub period: Option<usize>,
}
55
56impl Default for NatrParams {
57 fn default() -> Self {
58 Self { period: Some(14) }
59 }
60}
61
62#[derive(Debug, Clone)]
63pub struct NatrInput<'a> {
64 pub data: NatrData<'a>,
65 pub params: NatrParams,
66}
67
68impl<'a> NatrInput<'a> {
69 #[inline]
70 pub fn from_candles(candles: &'a Candles, params: NatrParams) -> Self {
71 Self {
72 data: NatrData::Candles { candles },
73 params,
74 }
75 }
76 #[inline]
77 pub fn from_slices(
78 high: &'a [f64],
79 low: &'a [f64],
80 close: &'a [f64],
81 params: NatrParams,
82 ) -> Self {
83 Self {
84 data: NatrData::Slices { high, low, close },
85 params,
86 }
87 }
88 #[inline]
89 pub fn with_default_candles(candles: &'a Candles) -> Self {
90 Self {
91 data: NatrData::Candles { candles },
92 params: NatrParams::default(),
93 }
94 }
95 #[inline]
96 pub fn get_period(&self) -> usize {
97 self.params.period.unwrap_or(14)
98 }
99}
100
101#[derive(Copy, Clone, Debug)]
102pub struct NatrBuilder {
103 period: Option<usize>,
104 kernel: Kernel,
105}
106
107impl Default for NatrBuilder {
108 fn default() -> Self {
109 Self {
110 period: None,
111 kernel: Kernel::Auto,
112 }
113 }
114}
115
116impl NatrBuilder {
117 #[inline(always)]
118 pub fn new() -> Self {
119 Self::default()
120 }
121 #[inline(always)]
122 pub fn period(mut self, n: usize) -> Self {
123 self.period = Some(n);
124 self
125 }
126 #[inline(always)]
127 pub fn kernel(mut self, k: Kernel) -> Self {
128 self.kernel = k;
129 self
130 }
131 #[inline(always)]
132 pub fn apply(self, c: &Candles) -> Result<NatrOutput, NatrError> {
133 let p = NatrParams {
134 period: self.period,
135 };
136 let i = NatrInput::from_candles(c, p);
137 natr_with_kernel(&i, self.kernel)
138 }
139 #[inline(always)]
140 pub fn apply_slices(
141 self,
142 high: &[f64],
143 low: &[f64],
144 close: &[f64],
145 ) -> Result<NatrOutput, NatrError> {
146 let p = NatrParams {
147 period: self.period,
148 };
149 let i = NatrInput::from_slices(high, low, close, p);
150 natr_with_kernel(&i, self.kernel)
151 }
152 #[inline(always)]
153 pub fn into_stream(self) -> Result<NatrStream, NatrError> {
154 let p = NatrParams {
155 period: self.period,
156 };
157 NatrStream::try_new(p)
158 }
159}
160
161#[derive(Debug, Error)]
162pub enum NatrError {
163 #[error("natr: Empty data provided for NATR.")]
164 EmptyInputData,
165 #[error("natr: All values are NaN.")]
166 AllValuesNaN,
167 #[error("natr: Empty data provided for NATR.")]
168 EmptyData,
169 #[error("natr: Invalid period: period = {period}, data length = {data_len}")]
170 InvalidPeriod { period: usize, data_len: usize },
171 #[error("natr: Not enough valid data: needed = {needed}, valid = {valid}")]
172 NotEnoughValidData { needed: usize, valid: usize },
173 #[error("natr: Output length mismatch: expected = {expected}, got = {got}")]
174 OutputLengthMismatch { expected: usize, got: usize },
175 #[error("natr: Invalid range: start={start}, end={end}, step={step}")]
176 InvalidRange {
177 start: String,
178 end: String,
179 step: String,
180 },
181 #[error("natr: Invalid kernel for batch: {0:?}")]
182 InvalidKernelForBatch(Kernel),
183 #[error("natr: Mismatched lengths: expected = {expected}, actual = {actual}")]
184 MismatchedLength { expected: usize, actual: usize },
185}
186
187#[inline]
188pub fn natr(input: &NatrInput) -> Result<NatrOutput, NatrError> {
189 natr_with_kernel(input, Kernel::Auto)
190}
191
192pub fn natr_with_kernel(input: &NatrInput, kernel: Kernel) -> Result<NatrOutput, NatrError> {
193 let (high, low, close) = match &input.data {
194 NatrData::Candles { candles } => {
195 let high = source_type(candles, "high");
196 let low = source_type(candles, "low");
197 let close = source_type(candles, "close");
198 (high, low, close)
199 }
200 NatrData::Slices { high, low, close } => (*high, *low, *close),
201 };
202
203 if high.is_empty() || low.is_empty() || close.is_empty() {
204 return Err(NatrError::EmptyInputData);
205 }
206
207 let len_h = high.len();
208 let len_l = low.len();
209 let len_c = close.len();
210 if len_h != len_l || len_h != len_c {
211 return Err(NatrError::MismatchedLength {
212 expected: len_h,
213 actual: if len_l != len_h { len_l } else { len_c },
214 });
215 }
216 let len = len_h;
217
218 let period = input.get_period();
219 if period == 0 || period > len {
220 return Err(NatrError::InvalidPeriod {
221 period,
222 data_len: len,
223 });
224 }
225
226 let first_valid_idx = {
227 let first_valid_idx_h = high.iter().position(|&x| !x.is_nan());
228 let first_valid_idx_l = low.iter().position(|&x| !x.is_nan());
229 let first_valid_idx_c = close.iter().position(|&x| !x.is_nan());
230
231 match (first_valid_idx_h, first_valid_idx_l, first_valid_idx_c) {
232 (Some(h), Some(l), Some(c)) => Some(h.max(l).max(c)),
233 _ => None,
234 }
235 };
236
237 let first_valid_idx = match first_valid_idx {
238 Some(idx) => idx,
239 None => return Err(NatrError::AllValuesNaN),
240 };
241
242 if (len - first_valid_idx) < period {
243 return Err(NatrError::NotEnoughValidData {
244 needed: period,
245 valid: len - first_valid_idx,
246 });
247 }
248
249 let mut out = alloc_with_nan_prefix(len, first_valid_idx + period - 1);
250
251 let chosen = match kernel {
252 Kernel::Auto => Kernel::Scalar,
253 other => other,
254 };
255
256 unsafe {
257 match chosen {
258 Kernel::Scalar | Kernel::ScalarBatch => {
259 natr_scalar(high, low, close, period, first_valid_idx, &mut out)
260 }
261 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
262 Kernel::Avx2 | Kernel::Avx2Batch => {
263 natr_avx2(high, low, close, period, first_valid_idx, &mut out)
264 }
265 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
266 Kernel::Avx512 | Kernel::Avx512Batch => {
267 natr_avx512(high, low, close, period, first_valid_idx, &mut out)
268 }
269 _ => unreachable!(),
270 }
271 }
272
273 Ok(NatrOutput { values: out })
274}
275
/// Scalar NATR kernel.
///
/// Seeds the ATR with the mean true range of the `period` bars starting at
/// `first` (the very first bar contributes `high - low` only), then applies
/// `atr += (tr - atr) / period` for every following bar. Each output is
/// `atr / close * 100`, or NaN when the close is zero or not finite.
///
/// Only `out[first + period - 1..]` is written; earlier slots are left untouched.
/// Nothing is written when `first >= out.len()`.
///
/// # Panics
/// Panics on out-of-bounds indexing if the slices are shorter than the range
/// being processed.
#[inline]
pub fn natr_scalar(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first: usize,
    out: &mut [f64],
) {
    /// True range of one bar given the previous close.
    #[inline(always)]
    fn true_range(hi: f64, lo: f64, prev_close: f64) -> f64 {
        (hi - lo)
            .max((hi - prev_close).abs())
            .max((lo - prev_close).abs())
    }

    /// ATR expressed as a percentage of `close`; NaN for a zero or non-finite close.
    #[inline(always)]
    fn as_percent(atr: f64, close: f64) -> f64 {
        if close.is_finite() && close != 0.0 {
            (atr / close) * 100.0
        } else {
            f64::NAN
        }
    }

    let n = out.len();
    if first >= n {
        return;
    }

    let alpha: f64 = 1.0 / (period as f64);
    let seed_idx = first + period - 1;

    // Warm-up: accumulate the true range over the seed window.
    let mut tr_sum = 0.0;
    tr_sum += high[first] - low[first];
    for i in (first + 1)..=seed_idx {
        tr_sum += true_range(high[i], low[i], close[i - 1]);
    }

    let mut atr = tr_sum * alpha;
    out[seed_idx] = as_percent(atr, close[seed_idx]);

    // Recursive smoothing for every bar after the seed.
    for i in (seed_idx + 1)..n {
        let tr = true_range(high[i], low[i], close[i - 1]);
        atr = (tr - atr).mul_add(alpha, atr);
        out[i] = as_percent(atr, close[i]);
    }
}
385
/// Safe entry point for the AVX-512 NATR kernel.
///
/// The vector kernel reads and writes through raw pointers and is compiled with
/// `avx512f` enabled, so it is only sound when all slices have the same length,
/// the seed window `first..first + period` fits inside them, and the running
/// CPU supports AVX-512F. Those conditions are checked here; when any of them
/// fails, the bounds-checked scalar kernel is used instead.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn natr_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first: usize,
    out: &mut [f64],
) {
    let len = out.len();
    let same_len = high.len() == len && low.len() == len && close.len() == len;
    let window_fits = period != 0
        && first
            .checked_add(period)
            .map_or(false, |window_end| window_end <= len);

    if same_len && window_fits && std::arch::is_x86_feature_detected!("avx512f") {
        // SAFETY: all four slices have `len` elements, `period >= 1` and
        // `first + period <= len`, so every index the kernel touches
        // (`first..len`, plus `close[i - 1]` for `i > first`) is in bounds;
        // AVX-512F support was just verified at runtime.
        unsafe {
            natr_avx512_body(high, low, close, period, first, out);
        }
    } else {
        natr_scalar(high, low, close, period, first, out);
    }
}
400
/// AVX2 NATR kernel.
///
/// True ranges are computed four bars at a time with 256-bit vectors; the ATR
/// recursion itself is sequential and is applied lane by lane. Bars that do not
/// fill a whole vector are handled by scalar loops. Only
/// `out[first + period - 1..]` is written.
///
/// # Safety
/// * The CPU must support AVX2.
/// * `high`, `low`, `close` and `out` must all have the same length (only
///   checked with `debug_assert!`); the kernel indexes through raw pointers.
/// * `period >= 1` and `first + period <= out.len()` unless `first >= out.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
#[inline]
pub unsafe fn natr_avx2(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first: usize,
    out: &mut [f64],
) {
    use core::arch::x86_64::*;

    const LANES: usize = 4;

    /// True range of one bar given the previous close.
    #[inline(always)]
    fn scalar_tr(hi: f64, lo: f64, prev_close: f64) -> f64 {
        (hi - lo)
            .max((hi - prev_close).abs())
            .max((lo - prev_close).abs())
    }

    /// ATR expressed as a percentage of `close`; NaN for a zero or non-finite close.
    #[inline(always)]
    fn as_percent(atr: f64, close: f64) -> f64 {
        if close.is_finite() && close != 0.0 {
            (atr / close) * 100.0
        } else {
            f64::NAN
        }
    }

    debug_assert!(high.len() == low.len() && high.len() == close.len() && high.len() == out.len());
    let n = out.len();
    if first >= n {
        return;
    }

    let alpha: f64 = 1.0 / (period as f64);

    let hp = high.as_ptr();
    let lp = low.as_ptr();
    let cp = close.as_ptr();
    let op = out.as_mut_ptr();

    // The first bar has no previous close, so its true range is high - low.
    let mut tr_sum = *hp.add(first) - *lp.add(first);
    let mut i = first + 1;
    let warm_end = first + period - 1;

    // `andnot(-0.0, x)` clears the sign bit, i.e. computes |x| per lane.
    let abs_mask = _mm256_set1_pd(-0.0_f64);

    // Warm-up, vector part: sum the true ranges four bars at a time.
    while i + 3 <= warm_end {
        let vh = _mm256_loadu_pd(hp.add(i));
        let vl = _mm256_loadu_pd(lp.add(i));
        let vpc = _mm256_loadu_pd(cp.add(i - 1));

        let range = _mm256_sub_pd(vh, vl);
        let up = _mm256_andnot_pd(abs_mask, _mm256_sub_pd(vh, vpc));
        let down = _mm256_andnot_pd(abs_mask, _mm256_sub_pd(vl, vpc));
        let vtr = _mm256_max_pd(_mm256_max_pd(range, up), down);

        let mut lanes = [0.0f64; LANES];
        _mm256_storeu_pd(lanes.as_mut_ptr(), vtr);
        tr_sum += lanes[0] + lanes[1] + lanes[2] + lanes[3];

        i += LANES;
    }
    // Warm-up, scalar remainder.
    while i <= warm_end {
        tr_sum += scalar_tr(*hp.add(i), *lp.add(i), *cp.add(i - 1));
        i += 1;
    }

    let mut atr = tr_sum * alpha;
    *op.add(warm_end) = as_percent(atr, *cp.add(warm_end));

    // Main loop: vector true ranges, sequential ATR recursion per lane.
    let mut idx = warm_end + 1;
    while idx + 3 < n {
        let vh = _mm256_loadu_pd(hp.add(idx));
        let vl = _mm256_loadu_pd(lp.add(idx));
        let vpc = _mm256_loadu_pd(cp.add(idx - 1));

        let range = _mm256_sub_pd(vh, vl);
        let up = _mm256_andnot_pd(abs_mask, _mm256_sub_pd(vh, vpc));
        let down = _mm256_andnot_pd(abs_mask, _mm256_sub_pd(vl, vpc));
        let vtr = _mm256_max_pd(_mm256_max_pd(range, up), down);

        let mut lanes = [0.0f64; LANES];
        _mm256_storeu_pd(lanes.as_mut_ptr(), vtr);

        for (k, &tr_k) in lanes.iter().enumerate() {
            atr = (tr_k - atr).mul_add(alpha, atr);
            *op.add(idx + k) = as_percent(atr, *cp.add(idx + k));
        }

        idx += LANES;
    }
    // Scalar tail for the last (fewer than four) bars.
    while idx < n {
        let tr = scalar_tr(*hp.add(idx), *lp.add(idx), *cp.add(idx - 1));
        atr = (tr - atr).mul_add(alpha, atr);
        *op.add(idx) = as_percent(atr, *cp.add(idx));
        idx += 1;
    }
}
536
/// AVX-512 NATR kernel body.
///
/// True ranges are computed eight bars at a time with 512-bit vectors; the ATR
/// recursion is sequential and applied lane by lane. Bars that do not fill a
/// whole vector are handled by scalar loops. Only `out[first + period - 1..]`
/// is written.
///
/// # Safety
/// * The CPU must support AVX-512F.
/// * `high`, `low`, `close` and `out` must all have the same length (only
///   checked with `debug_assert!`); the kernel indexes through raw pointers.
/// * `period >= 1` and `first + period <= out.len()` unless `first >= out.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn natr_avx512_body(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first: usize,
    out: &mut [f64],
) {
    use core::arch::x86_64::*;

    const LANES: usize = 8;

    /// True range of one bar given the previous close.
    #[inline(always)]
    fn scalar_tr(hi: f64, lo: f64, prev_close: f64) -> f64 {
        (hi - lo)
            .max((hi - prev_close).abs())
            .max((lo - prev_close).abs())
    }

    /// ATR expressed as a percentage of `close`; NaN for a zero or non-finite close.
    #[inline(always)]
    fn as_percent(atr: f64, close: f64) -> f64 {
        if close.is_finite() && close != 0.0 {
            (atr / close) * 100.0
        } else {
            f64::NAN
        }
    }

    debug_assert!(high.len() == low.len() && high.len() == close.len() && high.len() == out.len());
    let n = out.len();
    if first >= n {
        return;
    }

    let alpha: f64 = 1.0 / (period as f64);

    let hp = high.as_ptr();
    let lp = low.as_ptr();
    let cp = close.as_ptr();
    let op = out.as_mut_ptr();

    // The first bar has no previous close, so its true range is high - low.
    let mut tr_sum = *hp.add(first) - *lp.add(first);
    let mut i = first + 1;
    let warm_end = first + period - 1;

    // `andnot(-0.0, x)` clears the sign bit, i.e. computes |x| per lane.
    let abs_mask = _mm512_set1_pd(-0.0_f64);

    // Warm-up, vector part: sum the true ranges eight bars at a time.
    while i + 7 <= warm_end {
        let vh = _mm512_loadu_pd(hp.add(i));
        let vl = _mm512_loadu_pd(lp.add(i));
        let vpc = _mm512_loadu_pd(cp.add(i - 1));

        let range = _mm512_sub_pd(vh, vl);
        let up = _mm512_andnot_pd(abs_mask, _mm512_sub_pd(vh, vpc));
        let down = _mm512_andnot_pd(abs_mask, _mm512_sub_pd(vl, vpc));
        let vtr = _mm512_max_pd(_mm512_max_pd(range, up), down);

        let mut lanes = [0.0f64; LANES];
        _mm512_storeu_pd(lanes.as_mut_ptr(), vtr);
        tr_sum += lanes.iter().sum::<f64>();

        i += LANES;
    }
    // Warm-up, scalar remainder.
    while i <= warm_end {
        tr_sum += scalar_tr(*hp.add(i), *lp.add(i), *cp.add(i - 1));
        i += 1;
    }

    let mut atr = tr_sum * alpha;
    *op.add(warm_end) = as_percent(atr, *cp.add(warm_end));

    // Main loop: vector true ranges, sequential ATR recursion per lane.
    let mut idx = warm_end + 1;
    while idx + 7 < n {
        let vh = _mm512_loadu_pd(hp.add(idx));
        let vl = _mm512_loadu_pd(lp.add(idx));
        let vpc = _mm512_loadu_pd(cp.add(idx - 1));

        let range = _mm512_sub_pd(vh, vl);
        let up = _mm512_andnot_pd(abs_mask, _mm512_sub_pd(vh, vpc));
        let down = _mm512_andnot_pd(abs_mask, _mm512_sub_pd(vl, vpc));
        let vtr = _mm512_max_pd(_mm512_max_pd(range, up), down);

        let mut lanes = [0.0f64; LANES];
        _mm512_storeu_pd(lanes.as_mut_ptr(), vtr);

        for (k, &tr_k) in lanes.iter().enumerate() {
            atr = (tr_k - atr).mul_add(alpha, atr);
            *op.add(idx + k) = as_percent(atr, *cp.add(idx + k));
        }

        idx += LANES;
    }
    // Scalar tail for the last (fewer than eight) bars.
    while idx < n {
        let tr = scalar_tr(*hp.add(idx), *lp.add(idx), *cp.add(idx - 1));
        atr = (tr - atr).mul_add(alpha, atr);
        *op.add(idx) = as_percent(atr, *cp.add(idx));
        idx += 1;
    }
}
697
/// "Short period" AVX-512 entry point; forwards unchanged to the shared kernel body.
///
/// # Safety
/// Same contract as the kernel body it calls: the CPU must support AVX-512F,
/// all slices must have the same length, and the seed window must fit in them.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn natr_avx512_short(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first: usize,
    out: &mut [f64],
) {
    // No short-specific specialisation: the generic body handles every period.
    natr_avx512_body(high, low, close, period, first, out)
}
710
/// "Long period" AVX-512 entry point; forwards unchanged to the shared kernel body.
///
/// # Safety
/// Same contract as the kernel body it calls: the CPU must support AVX-512F,
/// all slices must have the same length, and the seed window must fit in them.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn natr_avx512_long(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first: usize,
    out: &mut [f64],
) {
    // No long-specific specialisation: the generic body handles every period.
    natr_avx512_body(high, low, close, period, first, out)
}
723
/// Incremental NATR calculator that consumes one bar at a time.
#[derive(Clone, Debug)]
pub struct NatrStream {
    /// Smoothing period in bars.
    period: usize,
    /// Smoothing factor, `1 / period`.
    alpha: f64,
    /// Percentage scale factor (100).
    k100: f64,

    /// Number of bars seen so far.
    count: usize,
    /// Running sum of true ranges during warm-up.
    sum_tr: f64,
    /// Current ATR value; meaningful once `ready` is set.
    atr: f64,
    /// Close of the previous bar.
    prev_close: f64,
    /// Whether `prev_close` holds a real value yet.
    have_prev: bool,
    /// Whether the warm-up window has been completed.
    ready: bool,
}
737
738impl NatrStream {
739 #[inline(always)]
740 pub fn try_new(params: NatrParams) -> Result<Self, NatrError> {
741 let period = params.period.unwrap_or(14);
742 if period == 0 {
743 return Err(NatrError::InvalidPeriod {
744 period,
745 data_len: 0,
746 });
747 }
748 Ok(Self {
749 period,
750 alpha: 1.0 / (period as f64),
751 k100: 100.0,
752 count: 0,
753 sum_tr: 0.0,
754 atr: 0.0,
755 prev_close: 0.0,
756 have_prev: false,
757 ready: false,
758 })
759 }
760
761 #[inline(always)]
762 pub fn update(&mut self, high: f64, low: f64, close: f64) -> Option<f64> {
763 let tr = if self.have_prev {
764 let pc = self.prev_close;
765 high.max(pc) - low.min(pc)
766 } else {
767 high - low
768 };
769
770 self.prev_close = close;
771 self.have_prev = true;
772
773 self.count += 1;
774
775 if !self.ready {
776 self.sum_tr += tr;
777 if self.count == self.period {
778 self.atr = self.sum_tr / (self.period as f64);
779 self.ready = true;
780 } else {
781 return None;
782 }
783 } else {
784 self.atr = (tr - self.atr).mul_add(self.alpha, self.atr);
785 }
786
787 if close.is_finite() && close != 0.0 {
788 Some((self.atr / close) * self.k100)
789 } else {
790 Some(f64::NAN)
791 }
792 }
793}
794
/// Parameter sweep for batch NATR computations.
#[derive(Debug, Clone)]
pub struct NatrBatchRange {
    /// Period sweep as `(start, end, step)`.
    pub period: (usize, usize, usize),
}
799
800impl Default for NatrBatchRange {
801 fn default() -> Self {
802 Self {
803 period: (14, 263, 1),
804 }
805 }
806}
807
808#[derive(Clone, Debug, Default)]
809pub struct NatrBatchBuilder {
810 range: NatrBatchRange,
811 kernel: Kernel,
812}
813
814impl NatrBatchBuilder {
815 pub fn new() -> Self {
816 Self::default()
817 }
818 pub fn kernel(mut self, k: Kernel) -> Self {
819 self.kernel = k;
820 self
821 }
822 #[inline]
823 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
824 self.range.period = (start, end, step);
825 self
826 }
827 #[inline]
828 pub fn period_static(mut self, p: usize) -> Self {
829 self.range.period = (p, p, 0);
830 self
831 }
832 pub fn apply_slices(
833 self,
834 high: &[f64],
835 low: &[f64],
836 close: &[f64],
837 ) -> Result<NatrBatchOutput, NatrError> {
838 natr_batch_with_kernel(high, low, close, &self.range, self.kernel)
839 }
840 pub fn apply_candles(self, c: &Candles) -> Result<NatrBatchOutput, NatrError> {
841 let high = source_type(c, "high");
842 let low = source_type(c, "low");
843 let close = source_type(c, "close");
844 self.apply_slices(high, low, close)
845 }
846 pub fn with_default_candles(c: &Candles, k: Kernel) -> Result<NatrBatchOutput, NatrError> {
847 NatrBatchBuilder::new().kernel(k).apply_candles(c)
848 }
849}
850
851pub fn natr_batch_with_kernel(
852 high: &[f64],
853 low: &[f64],
854 close: &[f64],
855 sweep: &NatrBatchRange,
856 k: Kernel,
857) -> Result<NatrBatchOutput, NatrError> {
858 let kernel = match k {
859 Kernel::Auto => detect_best_batch_kernel(),
860 other if other.is_batch() => other,
861 _ => {
862 return Err(NatrError::InvalidKernelForBatch(k));
863 }
864 };
865 let simd = match kernel {
866 Kernel::Avx512Batch => Kernel::Avx512,
867 Kernel::Avx2Batch => Kernel::Avx2,
868 Kernel::ScalarBatch => Kernel::Scalar,
869 _ => unreachable!(),
870 };
871 natr_batch_par_slice(high, low, close, sweep, simd)
872}
873
874#[derive(Clone, Debug)]
875pub struct NatrBatchOutput {
876 pub values: Vec<f64>,
877 pub combos: Vec<NatrParams>,
878 pub rows: usize,
879 pub cols: usize,
880}
881impl NatrBatchOutput {
882 pub fn row_for_params(&self, p: &NatrParams) -> Option<usize> {
883 self.combos
884 .iter()
885 .position(|c| c.period.unwrap_or(14) == p.period.unwrap_or(14))
886 }
887 pub fn values_for(&self, p: &NatrParams) -> Option<&[f64]> {
888 self.row_for_params(p).map(|row| {
889 let start = row * self.cols;
890 &self.values[start..start + self.cols]
891 })
892 }
893}
894
895#[inline(always)]
896fn expand_grid(r: &NatrBatchRange) -> Result<Vec<NatrParams>, NatrError> {
897 fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, NatrError> {
898 if step == 0 || start == end {
899 return Ok(vec![start]);
900 }
901
902 let mut values = Vec::new();
903 let step_u = step;
904
905 if start <= end {
906 let mut v = start;
907 loop {
908 if v > end {
909 break;
910 }
911 values.push(v);
912 match v.checked_add(step_u) {
913 Some(next) => v = next,
914 None => break,
915 }
916 }
917 } else {
918 let mut v = start;
919 loop {
920 if v < end {
921 break;
922 }
923 values.push(v);
924 match v.checked_sub(step_u) {
925 Some(next) => v = next,
926 None => break,
927 }
928 }
929 }
930
931 if values.is_empty() {
932 return Err(NatrError::InvalidRange {
933 start: start.to_string(),
934 end: end.to_string(),
935 step: step.to_string(),
936 });
937 }
938
939 Ok(values)
940 }
941
942 let periods = axis_usize(r.period)?;
943
944 let mut out = Vec::with_capacity(periods.len());
945 for p in periods {
946 out.push(NatrParams { period: Some(p) });
947 }
948 Ok(out)
949}
950
951#[inline(always)]
952pub fn natr_batch_slice(
953 high: &[f64],
954 low: &[f64],
955 close: &[f64],
956 sweep: &NatrBatchRange,
957 kern: Kernel,
958) -> Result<NatrBatchOutput, NatrError> {
959 natr_batch_inner(high, low, close, sweep, kern, false)
960}
961
962#[inline(always)]
963pub fn natr_batch_par_slice(
964 high: &[f64],
965 low: &[f64],
966 close: &[f64],
967 sweep: &NatrBatchRange,
968 kern: Kernel,
969) -> Result<NatrBatchOutput, NatrError> {
970 natr_batch_inner(high, low, close, sweep, kern, true)
971}
972
973#[inline(always)]
974fn natr_batch_inner(
975 high: &[f64],
976 low: &[f64],
977 close: &[f64],
978 sweep: &NatrBatchRange,
979 kern: Kernel,
980 parallel: bool,
981) -> Result<NatrBatchOutput, NatrError> {
982 let combos = expand_grid(sweep)?;
983
984 let len_h = high.len();
985 let len_l = low.len();
986 let len_c = close.len();
987 if len_h != len_l || len_h != len_c {
988 return Err(NatrError::MismatchedLength {
989 expected: len_h,
990 actual: if len_l != len_h { len_l } else { len_c },
991 });
992 }
993 let len = len_h;
994
995 let first = high
996 .iter()
997 .position(|x| !x.is_nan())
998 .unwrap_or(0)
999 .max(low.iter().position(|x| !x.is_nan()).unwrap_or(0))
1000 .max(close.iter().position(|x| !x.is_nan()).unwrap_or(0));
1001 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1002 if len - first < max_p {
1003 return Err(NatrError::NotEnoughValidData {
1004 needed: max_p,
1005 valid: len - first,
1006 });
1007 }
1008 let rows = combos.len();
1009 let cols = len;
1010
1011 let mut buf_mu = make_uninit_matrix(rows, cols);
1012 let warm: Vec<usize> = combos
1013 .iter()
1014 .map(|c| first + c.period.unwrap() - 1)
1015 .collect();
1016 init_matrix_prefixes(&mut buf_mu, cols, &warm);
1017
1018 let mut buf_guard = core::mem::ManuallyDrop::new(buf_mu);
1019 let out: &mut [f64] = unsafe {
1020 core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
1021 };
1022
1023 let use_tr_shared = combos.len() >= 24;
1024
1025 if use_tr_shared {
1026 let mut tr: AVec<f64> = AVec::with_capacity(CACHELINE_ALIGN, len);
1027 tr.resize(len, 0.0);
1028 if first < len {
1029 tr[first] = high[first] - low[first];
1030 let mut i = first + 1;
1031 while i < len {
1032 let hi = high[i];
1033 let lo = low[i];
1034 let pc = close[i - 1];
1035 let trv = (hi - lo).max((hi - pc).abs()).max((lo - pc).abs());
1036 tr[i] = trv;
1037 i += 1;
1038 }
1039 }
1040
1041 let mut pref: AVec<f64> = AVec::with_capacity(CACHELINE_ALIGN, len + 1);
1042 pref.resize(len + 1, 0.0);
1043 for i in first..len {
1044 pref[i + 1] = pref[i] + tr[i];
1045 }
1046
1047 let do_row = |row: usize, out_row: &mut [f64]| {
1048 let period = combos[row].period.unwrap();
1049 let inv_p = 1.0 / (period as f64);
1050 let k100 = 100.0;
1051 let warm_end = first + period - 1;
1052
1053 let sum_tr = pref[warm_end + 1] - pref[first];
1054 let mut atr = sum_tr * inv_p;
1055 let cw = close[warm_end];
1056 out_row[warm_end] = if cw.is_finite() && cw != 0.0 {
1057 (atr / cw) * k100
1058 } else {
1059 f64::NAN
1060 };
1061
1062 let mut i = warm_end + 1;
1063 while i < len {
1064 atr = (tr[i] - atr).mul_add(inv_p, atr);
1065 let cv = close[i];
1066 out_row[i] = if cv.is_finite() && cv != 0.0 {
1067 (atr / cv) * k100
1068 } else {
1069 f64::NAN
1070 };
1071 i += 1;
1072 }
1073 };
1074
1075 if parallel {
1076 #[cfg(not(target_arch = "wasm32"))]
1077 {
1078 out.par_chunks_mut(cols)
1079 .enumerate()
1080 .for_each(|(row, slice)| do_row(row, slice));
1081 }
1082
1083 #[cfg(target_arch = "wasm32")]
1084 {
1085 for (row, slice) in out.chunks_mut(cols).enumerate() {
1086 do_row(row, slice);
1087 }
1088 }
1089 } else {
1090 for (row, slice) in out.chunks_mut(cols).enumerate() {
1091 do_row(row, slice);
1092 }
1093 }
1094 } else {
1095 let do_row = |row: usize, out_row: &mut [f64]| unsafe {
1096 let period = combos[row].period.unwrap();
1097 match kern {
1098 Kernel::Scalar => natr_row_scalar(high, low, close, first, period, out_row),
1099 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1100 Kernel::Avx2 => natr_row_avx2(high, low, close, first, period, out_row),
1101 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1102 Kernel::Avx512 => natr_row_avx512(high, low, close, first, period, out_row),
1103 _ => unreachable!(),
1104 }
1105 };
1106
1107 if parallel {
1108 #[cfg(not(target_arch = "wasm32"))]
1109 {
1110 out.par_chunks_mut(cols)
1111 .enumerate()
1112 .for_each(|(row, slice)| do_row(row, slice));
1113 }
1114
1115 #[cfg(target_arch = "wasm32")]
1116 {
1117 for (row, slice) in out.chunks_mut(cols).enumerate() {
1118 do_row(row, slice);
1119 }
1120 }
1121 } else {
1122 for (row, slice) in out.chunks_mut(cols).enumerate() {
1123 do_row(row, slice);
1124 }
1125 }
1126 }
1127
1128 let values = unsafe {
1129 Vec::from_raw_parts(
1130 buf_guard.as_mut_ptr() as *mut f64,
1131 buf_guard.len(),
1132 buf_guard.capacity(),
1133 )
1134 };
1135
1136 Ok(NatrBatchOutput {
1137 values,
1138 combos,
1139 rows,
1140 cols,
1141 })
1142}
1143
1144#[inline(always)]
1145fn natr_batch_inner_into(
1146 high: &[f64],
1147 low: &[f64],
1148 close: &[f64],
1149 sweep: &NatrBatchRange,
1150 kern: Kernel,
1151 parallel: bool,
1152 out: &mut [f64],
1153) -> Result<Vec<NatrParams>, NatrError> {
1154 let combos = expand_grid(sweep)?;
1155
1156 let len_h = high.len();
1157 let len_l = low.len();
1158 let len_c = close.len();
1159 if len_h != len_l || len_h != len_c {
1160 return Err(NatrError::MismatchedLength {
1161 expected: len_h,
1162 actual: if len_l != len_h { len_l } else { len_c },
1163 });
1164 }
1165 let len = len_h;
1166
1167 let first = high
1168 .iter()
1169 .position(|x| !x.is_nan())
1170 .unwrap_or(0)
1171 .max(low.iter().position(|x| !x.is_nan()).unwrap_or(0))
1172 .max(close.iter().position(|x| !x.is_nan()).unwrap_or(0));
1173 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1174 if len - first < max_p {
1175 return Err(NatrError::NotEnoughValidData {
1176 needed: max_p,
1177 valid: len - first,
1178 });
1179 }
1180 let rows = combos.len();
1181 let cols = len;
1182
1183 for (row, combo) in combos.iter().enumerate() {
1184 let period = combo.period.unwrap();
1185 let warmup_end = first + period - 1;
1186 let row_start = row * cols;
1187 for i in 0..warmup_end.min(cols) {
1188 out[row_start + i] = f64::NAN;
1189 }
1190 }
1191
1192 let use_tr_shared = combos.len() >= 24;
1193 if use_tr_shared {
1194 let mut tr: AVec<f64> = AVec::with_capacity(CACHELINE_ALIGN, len);
1195 tr.resize(len, 0.0);
1196 if first < len {
1197 tr[first] = high[first] - low[first];
1198 let mut i = first + 1;
1199 while i < len {
1200 let hi = high[i];
1201 let lo = low[i];
1202 let pc = close[i - 1];
1203 let trv = (hi - lo).max((hi - pc).abs()).max((lo - pc).abs());
1204 tr[i] = trv;
1205 i += 1;
1206 }
1207 }
1208
1209 let mut pref: AVec<f64> = AVec::with_capacity(CACHELINE_ALIGN, len + 1);
1210 pref.resize(len + 1, 0.0);
1211 for i in first..len {
1212 pref[i + 1] = pref[i] + tr[i];
1213 }
1214
1215 let do_row = |row: usize, out_row: &mut [f64]| {
1216 let period = combos[row].period.unwrap();
1217 let inv_p = 1.0 / (period as f64);
1218 let k100 = 100.0;
1219 let warm_end = first + period - 1;
1220
1221 let sum_tr = pref[warm_end + 1] - pref[first];
1222 let mut atr = sum_tr * inv_p;
1223 let cw = close[warm_end];
1224 out_row[warm_end] = if cw.is_finite() && cw != 0.0 {
1225 (atr / cw) * k100
1226 } else {
1227 f64::NAN
1228 };
1229
1230 let mut i = warm_end + 1;
1231 while i < len {
1232 atr = (tr[i] - atr).mul_add(inv_p, atr);
1233 let cv = close[i];
1234 out_row[i] = if cv.is_finite() && cv != 0.0 {
1235 (atr / cv) * k100
1236 } else {
1237 f64::NAN
1238 };
1239 i += 1;
1240 }
1241 };
1242
1243 if parallel {
1244 #[cfg(not(target_arch = "wasm32"))]
1245 {
1246 out.par_chunks_mut(cols)
1247 .enumerate()
1248 .for_each(|(row, slice)| do_row(row, slice));
1249 }
1250
1251 #[cfg(target_arch = "wasm32")]
1252 {
1253 for (row, slice) in out.chunks_mut(cols).enumerate() {
1254 do_row(row, slice);
1255 }
1256 }
1257 } else {
1258 for (row, slice) in out.chunks_mut(cols).enumerate() {
1259 do_row(row, slice);
1260 }
1261 }
1262 } else {
1263 let do_row = |row: usize, out_row: &mut [f64]| unsafe {
1264 let period = combos[row].period.unwrap();
1265 match kern {
1266 Kernel::Scalar => natr_row_scalar(high, low, close, first, period, out_row),
1267 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1268 Kernel::Avx2 => natr_row_avx2(high, low, close, first, period, out_row),
1269 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1270 Kernel::Avx512 => natr_row_avx512(high, low, close, first, period, out_row),
1271 _ => unreachable!(),
1272 }
1273 };
1274
1275 if parallel {
1276 #[cfg(not(target_arch = "wasm32"))]
1277 {
1278 out.par_chunks_mut(cols)
1279 .enumerate()
1280 .for_each(|(row, slice)| do_row(row, slice));
1281 }
1282
1283 #[cfg(target_arch = "wasm32")]
1284 {
1285 for (row, slice) in out.chunks_mut(cols).enumerate() {
1286 do_row(row, slice);
1287 }
1288 }
1289 } else {
1290 for (row, slice) in out.chunks_mut(cols).enumerate() {
1291 do_row(row, slice);
1292 }
1293 }
1294 }
1295
1296 Ok(combos)
1297}
1298
1299#[inline(always)]
1300unsafe fn natr_row_scalar(
1301 high: &[f64],
1302 low: &[f64],
1303 close: &[f64],
1304 first: usize,
1305 period: usize,
1306 out: &mut [f64],
1307) {
1308 natr_scalar(high, low, close, period, first, out)
1309}
1310
/// Fills one output row with NATR values from a precomputed true-range series.
///
/// The average true range is seeded with the arithmetic mean of
/// `tr[first..first + period]` and then smoothed with a factor of `1 / period`
/// for every later bar. Each output is `atr / close * 100`; a bar whose close
/// is zero or non-finite yields `NaN`.
///
/// Only indices `first + period - 1 ..` of `out` are written; earlier cells are
/// left untouched for the caller to fill.
///
/// The call is a no-op (nothing is written) when `period` is zero, when
/// `first` is past the end of `out`, or when the seeding window does not fit
/// inside `out`.
///
/// # Panics
/// Panics if `tr` or `close` is shorter than `out` (checked once, before any
/// output cell is written).
///
/// # Safety
/// The body contains no unsafe operations; the function is `unsafe` only to
/// match the signature of the other row kernels.
#[inline(always)]
unsafe fn natr_row_scalar_from_tr(
    tr: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    let len = out.len();
    if period == 0 || first >= len {
        return;
    }
    // Last index of the seeding window. `checked_add` guards against overflow,
    // and a window that would run past the row is rejected instead of
    // panicking on an out-of-bounds index.
    let warm_end = match first.checked_add(period - 1) {
        Some(idx) if idx < len => idx,
        _ => return,
    };

    // One up-front bounds check for both inputs.
    let tr = &tr[..len];
    let close = &close[..len];

    let inv_p = 1.0 / (period as f64);
    let as_percent = |atr: f64, c: f64| {
        if c.is_finite() && c != 0.0 {
            (atr / c) * 100.0
        } else {
            f64::NAN
        }
    };

    // Seed: simple mean of the first `period` true ranges (summed left to right).
    let seed_sum = tr[first..=warm_end].iter().fold(0.0, |acc, &v| acc + v);
    let mut atr = seed_sum * inv_p;
    out[warm_end] = as_percent(atr, close[warm_end]);

    // Recursive smoothing for the remaining bars.
    let tail = warm_end + 1;
    for ((slot, &t), &c) in out[tail..]
        .iter_mut()
        .zip(&tr[tail..])
        .zip(&close[tail..])
    {
        atr = (t - atr).mul_add(inv_p, atr);
        *slot = as_percent(atr, c);
    }
}
1351
/// Batch-row adapter around the AVX2 NATR kernel.
///
/// Only reorders `(first, period)` into the kernel's `(period, first)` order
/// and forwards the call to `natr_avx2`.
///
/// # Safety
/// Callers must uphold the requirements of `natr_avx2`.
/// NOTE(review): presumably the CPU must support AVX2 — confirm against the kernel.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn natr_row_avx2(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    natr_avx2(high, low, close, period, first, out);
}
1364
/// Batch-row adapter around the AVX-512 NATR kernel.
///
/// Only reorders `(first, period)` into the kernel's `(period, first)` order
/// and forwards the call to `natr_avx512`.
///
/// # Safety
/// Callers must uphold the requirements of `natr_avx512`.
/// NOTE(review): presumably the CPU must support AVX-512 — confirm against the kernel.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn natr_row_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    natr_avx512(high, low, close, period, first, out)
}
1377
/// Batch-row entry point for the "short" AVX-512 variant.
///
/// Forwards straight to `natr_avx512_body` (with `first`/`period` reordered);
/// in this file it is indistinguishable from `natr_row_avx512_long`.
///
/// # Safety
/// Callers must uphold the requirements of `natr_avx512_body`.
/// NOTE(review): kernel preconditions are not visible here — confirm.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn natr_row_avx512_short(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    natr_avx512_body(high, low, close, period, first, out);
}
1390
/// Batch-row entry point for the "long" AVX-512 variant.
///
/// Forwards straight to `natr_avx512_body` (with `first`/`period` reordered);
/// in this file it is indistinguishable from `natr_row_avx512_short`.
///
/// # Safety
/// Callers must uphold the requirements of `natr_avx512_body`.
/// NOTE(review): kernel preconditions are not visible here — confirm.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn natr_row_avx512_long(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    natr_avx512_body(high, low, close, period, first, out);
}
1403
/// AVX2-named row helper operating on a precomputed true-range series.
///
/// There is no dedicated SIMD path: the call is delegated unchanged to
/// `natr_row_scalar_from_tr`.
///
/// # Safety
/// Same contract as `natr_row_scalar_from_tr`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn natr_row_avx2_from_tr(
    tr: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    natr_row_scalar_from_tr(tr, close, first, period, out);
}
1415
/// AVX-512-named row helper operating on a precomputed true-range series.
///
/// There is no dedicated SIMD path: the call is delegated unchanged to
/// `natr_row_scalar_from_tr`.
///
/// # Safety
/// Same contract as `natr_row_scalar_from_tr`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn natr_row_avx512_from_tr(
    tr: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    natr_row_scalar_from_tr(tr, close, first, period, out);
}
1427
/// Unit, parity and property tests for the NATR indicator.
///
/// Most checks are written as `check_*` functions taking a test name and a
/// kernel; the macros at the bottom of the module stamp out one `#[test]` per
/// kernel flavour. Generated tests return the check's `Result`, so an `Err`
/// (I/O failure, indicator error, failed property) fails the test instead of
/// being discarded.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::skip_if_unsupported;
    use crate::utilities::data_loader::read_candles_from_csv;
    #[cfg(feature = "proptest")]
    use proptest::prelude::*;

    /// `natr_into` must produce the same values as the allocating `natr` API.
    #[test]
    fn test_natr_into_matches_api() -> Result<(), Box<dyn std::error::Error>> {
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let input = NatrInput::with_default_candles(&candles);

        let baseline = natr(&input)?.values;

        let mut out = vec![0.0; candles.close.len()];
        #[allow(unused_variables)]
        {
            #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
            {
                natr_into(&input, &mut out)?;
            }
        }

        assert_eq!(baseline.len(), out.len());

        // NaN never compares equal, so treat "both NaN" as a match explicitly.
        fn eq_or_both_nan(a: f64, b: f64) -> bool {
            (a.is_nan() && b.is_nan()) || (a - b).abs() <= 1e-12
        }

        for i in 0..baseline.len() {
            assert!(
                eq_or_both_nan(baseline[i], out[i]),
                "NATR into parity mismatch at index {}: baseline={}, into={}",
                i,
                baseline[i],
                out[i]
            );
        }
        Ok(())
    }

    /// A missing period (`None`) and an explicit period both produce full-length output.
    fn check_natr_partial_params(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let default_params = NatrParams { period: None };
        let input_default = NatrInput::from_candles(&candles, default_params);
        let output_default = natr_with_kernel(&input_default, kernel)?;
        assert_eq!(output_default.values.len(), candles.close.len());
        let params_period_7 = NatrParams { period: Some(7) };
        let input_period_7 = NatrInput::from_candles(&candles, params_period_7);
        let output_period_7 = natr_with_kernel(&input_period_7, kernel)?;
        assert_eq!(output_period_7.values.len(), candles.close.len());
        Ok(())
    }

    /// The last five period-14 values match the stored reference values, and
    /// the first `period - 1` outputs are NaN.
    fn check_natr_accuracy(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let close_prices = candles.select_candle_field("close").unwrap();
        let params = NatrParams { period: Some(14) };
        let input = NatrInput::from_candles(&candles, params.clone());
        let natr_result = natr_with_kernel(&input, kernel)?;
        assert_eq!(natr_result.values.len(), close_prices.len());
        let expected_last_five = [
            1.5465877404905772,
            1.4773840355794576,
            1.4201627494720954,
            1.3556212509014807,
            1.3836271128536142,
        ];
        let start_index = natr_result.values.len() - 5;
        let result_last_five = &natr_result.values[start_index..];
        for (i, &value) in result_last_five.iter().enumerate() {
            let expected_value = expected_last_five[i];
            assert!(
                (value - expected_value).abs() < 1e-8,
                "NATR mismatch at index {}: expected {}, got {}",
                i,
                expected_value,
                value
            );
        }
        let period = params.period.unwrap();
        for i in 0..(period - 1) {
            assert!(natr_result.values[i].is_nan());
        }
        Ok(())
    }

    /// A period of zero is rejected.
    fn check_natr_with_zero_period(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let high = [10.0, 20.0, 30.0];
        let low = [5.0, 10.0, 15.0];
        let close = [7.0, 14.0, 25.0];
        let params = NatrParams { period: Some(0) };
        let input = NatrInput::from_slices(&high, &low, &close, params);
        let result = natr_with_kernel(&input, kernel);
        assert!(result.is_err());
        Ok(())
    }

    /// A period longer than the input is rejected.
    fn check_natr_period_exceeds_length(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let high = [10.0, 20.0, 30.0];
        let low = [5.0, 10.0, 15.0];
        let close = [7.0, 14.0, 25.0];
        let params = NatrParams { period: Some(10) };
        let input = NatrInput::from_slices(&high, &low, &close, params);
        let result = natr_with_kernel(&input, kernel);
        assert!(result.is_err());
        Ok(())
    }

    /// A single bar cannot satisfy a period of 14.
    fn check_natr_very_small_dataset(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let high = [42.0];
        let low = [40.0];
        let close = [41.0];
        let params = NatrParams { period: Some(14) };
        let input = NatrInput::from_slices(&high, &low, &close, params);
        let result = natr_with_kernel(&input, kernel);
        assert!(result.is_err());
        Ok(())
    }

    /// Input consisting only of NaN is rejected.
    fn check_natr_all_values_nan(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let high = [f64::NAN, f64::NAN];
        let low = [f64::NAN, f64::NAN];
        let close = [f64::NAN, f64::NAN];
        let params = NatrParams { period: Some(2) };
        let input = NatrInput::from_slices(&high, &low, &close, params);
        let result = natr_with_kernel(&input, kernel);
        assert!(result.is_err());
        Ok(())
    }

    /// Too few non-NaN bars for the requested period is rejected.
    fn check_natr_not_enough_valid_data(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let high = [f64::NAN, 10.0];
        let low = [f64::NAN, 5.0];
        let close = [f64::NAN, 7.0];
        let params = NatrParams { period: Some(5) };
        let input = NatrInput::from_slices(&high, &low, &close, params);
        let result = natr_with_kernel(&input, kernel);
        assert!(result.is_err());
        Ok(())
    }

    /// Feeding NATR output back in as high/low/close yields no NaN from index 28 on.
    fn check_natr_slice_data_reinput(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let first_params = NatrParams { period: Some(14) };
        let first_input = NatrInput::from_candles(&candles, first_params);
        let first_result = natr_with_kernel(&first_input, kernel)?;
        assert_eq!(first_result.values.len(), candles.close.len());

        let second_params = NatrParams { period: Some(14) };
        let second_input = NatrInput::from_slices(
            &first_result.values,
            &first_result.values,
            &first_result.values,
            second_params,
        );
        let second_result = natr_with_kernel(&second_input, kernel)?;
        assert_eq!(second_result.values.len(), first_result.values.len());

        for i in 28..second_result.values.len() {
            assert!(
                !second_result.values[i].is_nan(),
                "Expected no NaN after index 28, but found NaN at index {}",
                i
            );
        }
        Ok(())
    }

    /// With period 14 on the sample data, no NaN appears from index 30 on.
    fn check_natr_nan_check(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let params = NatrParams { period: Some(14) };
        let input = NatrInput::from_candles(&candles, params);
        let natr_result = natr_with_kernel(&input, kernel)?;
        assert_eq!(natr_result.values.len(), candles.close.len());
        if natr_result.values.len() > 30 {
            for i in 30..natr_result.values.len() {
                assert!(
                    !natr_result.values[i].is_nan(),
                    "Expected no NaN after index 30, but found NaN at index {}",
                    i
                );
            }
        }
        Ok(())
    }

    /// `with_default_candles` builds a `Candles` input that computes successfully.
    fn check_natr_default_candles(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let input = NatrInput::with_default_candles(&candles);
        match input.data {
            NatrData::Candles { .. } => {}
            _ => panic!("Expected NatrData::Candles variant"),
        }
        let output = natr_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());
        Ok(())
    }

    /// The streaming API agrees with the batch computation bar by bar
    /// (a `None` from the stream is compared as NaN).
    fn check_natr_streaming(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let period = 14;
        let high = &candles.high;
        let low = &candles.low;
        let close = &candles.close;
        let input = NatrInput::from_slices(
            high,
            low,
            close,
            NatrParams {
                period: Some(period),
            },
        );
        let batch_output = natr_with_kernel(&input, kernel)?.values;

        let mut stream = NatrStream::try_new(NatrParams {
            period: Some(period),
        })?;
        let mut stream_values = Vec::with_capacity(close.len());
        for ((&h, &l), &c) in high.iter().zip(low.iter()).zip(close.iter()) {
            match stream.update(h, l, c) {
                Some(natr_val) => stream_values.push(natr_val),
                None => stream_values.push(f64::NAN),
            }
        }
        assert_eq!(batch_output.len(), stream_values.len());
        for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
            if b.is_nan() && s.is_nan() {
                continue;
            }
            let diff = (b - s).abs();
            assert!(
                diff < 1e-9,
                "[{}] NATR streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
                test_name,
                i,
                b,
                s,
                diff
            );
        }
        Ok(())
    }

    /// Debug builds only: no output cell may still hold one of the three
    /// poison bit patterns named in the panic messages below.
    #[cfg(debug_assertions)]
    fn check_natr_no_poison(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);

        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let test_params = vec![
            NatrParams::default(),
            NatrParams { period: Some(2) },
            NatrParams { period: Some(5) },
            NatrParams { period: Some(7) },
            NatrParams { period: Some(10) },
            NatrParams { period: Some(20) },
            NatrParams { period: Some(30) },
            NatrParams { period: Some(50) },
            NatrParams { period: Some(100) },
            NatrParams { period: Some(200) },
        ];

        for (param_idx, params) in test_params.iter().enumerate() {
            let input = NatrInput::from_candles(&candles, params.clone());
            let output = natr_with_kernel(&input, kernel)?;

            for (i, &val) in output.values.iter().enumerate() {
                if val.is_nan() {
                    continue;
                }

                let bits = val.to_bits();

                if bits == 0x11111111_11111111 {
                    panic!(
                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
                         with params: period={} (param set {})",
                        test_name,
                        val,
                        bits,
                        i,
                        params.period.unwrap_or(14),
                        param_idx
                    );
                }

                if bits == 0x22222222_22222222 {
                    panic!(
                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
                         with params: period={} (param set {})",
                        test_name,
                        val,
                        bits,
                        i,
                        params.period.unwrap_or(14),
                        param_idx
                    );
                }

                if bits == 0x33333333_33333333 {
                    panic!(
                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
                         with params: period={} (param set {})",
                        test_name,
                        val,
                        bits,
                        i,
                        params.period.unwrap_or(14),
                        param_idx
                    );
                }
            }
        }

        Ok(())
    }

    /// Release builds: the poison check is a no-op.
    #[cfg(not(debug_assertions))]
    fn check_natr_no_poison(
        _test_name: &str,
        _kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        Ok(())
    }

    /// The default batch sweep contains a row for the default parameters.
    fn check_batch_default_row(
        test: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test);
        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;
        let output = NatrBatchBuilder::new().kernel(kernel).apply_candles(&c)?;
        let def = NatrParams::default();
        let row = output.values_for(&def).expect("default row missing");
        assert_eq!(row.len(), c.close.len());
        Ok(())
    }

    /// Debug builds only: no batch output cell may still hold a poison bit pattern.
    #[cfg(debug_assertions)]
    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test);

        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;

        // (start, end, step) period sweeps.
        let test_configs = vec![
            (2, 10, 2),
            (5, 25, 5),
            (30, 60, 15),
            (2, 5, 1),
            (10, 20, 2),
            (50, 100, 10),
            (14, 14, 0),
        ];

        for (cfg_idx, &(p_start, p_end, p_step)) in test_configs.iter().enumerate() {
            let output = NatrBatchBuilder::new()
                .kernel(kernel)
                .period_range(p_start, p_end, p_step)
                .apply_candles(&c)?;

            for (idx, &val) in output.values.iter().enumerate() {
                if val.is_nan() {
                    continue;
                }

                let bits = val.to_bits();
                let row = idx / output.cols;
                let col = idx % output.cols;
                let combo = &output.combos[row];

                if bits == 0x11111111_11111111 {
                    panic!(
                        "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
                         at row {} col {} (flat index {}) with params: period={}",
                        test,
                        cfg_idx,
                        val,
                        bits,
                        row,
                        col,
                        idx,
                        combo.period.unwrap_or(14)
                    );
                }

                if bits == 0x22222222_22222222 {
                    panic!(
                        "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
                         at row {} col {} (flat index {}) with params: period={}",
                        test,
                        cfg_idx,
                        val,
                        bits,
                        row,
                        col,
                        idx,
                        combo.period.unwrap_or(14)
                    );
                }

                if bits == 0x33333333_33333333 {
                    panic!(
                        "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
                         at row {} col {} (flat index {}) with params: period={}",
                        test,
                        cfg_idx,
                        val,
                        bits,
                        row,
                        col,
                        idx,
                        combo.period.unwrap_or(14)
                    );
                }
            }
        }

        Ok(())
    }

    /// Release builds: the batch poison check is a no-op.
    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(
        _test: &str,
        _kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        Ok(())
    }

    /// Property test over synthetic price series.
    ///
    /// Three scenarios are generated: ordinary prices (0), small prices (1) and
    /// constant prices with high == low == close (2). The checks cover output
    /// length, the NaN warmup, sign and magnitude bounds, agreement with the
    /// scalar kernel, zero NATR on constant prices, and absence of poison bits.
    #[cfg(feature = "proptest")]
    fn check_natr_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);

        let strat = (2usize..=50, 50usize..=400, 0usize..=2)
            .prop_flat_map(|(period, len, scenario)| {
                let close_strategy = match scenario {
                    0 => prop::collection::vec(
                        (1.0f64..1000.0f64).prop_filter("finite", |x| x.is_finite()),
                        len,
                    )
                    .boxed(),
                    1 => prop::collection::vec(
                        (0.01f64..1.0f64).prop_filter("finite", |x| x.is_finite()),
                        len,
                    )
                    .boxed(),
                    _ => (1.0f64..100.0f64)
                        .prop_map(move |val| vec![val; len])
                        .boxed(),
                };

                (close_strategy, Just(period), Just(len), Just(scenario))
            })
            .prop_flat_map(|(close_prices, period, len, scenario)| {
                let mut high_vec = Vec::with_capacity(len);
                let mut low_vec = Vec::with_capacity(len);

                for (i, &close) in close_prices.iter().enumerate() {
                    if scenario == 2 {
                        high_vec.push(close);
                        low_vec.push(close);
                    } else {
                        // Deterministic pseudo-random spread derived from the bar index.
                        let volatility_factor = 0.001 + 0.20 * ((i * 7919) % 100) as f64 / 100.0;
                        let spread = close * volatility_factor;

                        let high = close + spread * 0.5;
                        let low = close - spread * 0.5;

                        high_vec.push(high);
                        low_vec.push(low.max(0.001));
                    }
                }

                (
                    Just(high_vec),
                    Just(low_vec),
                    Just(close_prices),
                    Just(period),
                    Just(scenario),
                )
            });

        proptest::test_runner::TestRunner::default().run(
            &strat,
            |(high, low, close, period, scenario)| {
                let params = NatrParams {
                    period: Some(period),
                };
                let input = NatrInput::from_slices(&high, &low, &close, params);

                let result = natr_with_kernel(&input, kernel)?;

                let ref_result = natr_with_kernel(&input, Kernel::Scalar)?;

                prop_assert_eq!(result.values.len(), high.len());
                prop_assert_eq!(result.values.len(), low.len());
                prop_assert_eq!(result.values.len(), close.len());

                for i in 0..(period - 1) {
                    prop_assert!(
                        result.values[i].is_nan(),
                        "Expected NaN at index {} during warmup, got {}",
                        i,
                        result.values[i]
                    );
                }

                for i in period..result.values.len() {
                    if result.values[i].is_finite() {
                        prop_assert!(
                            result.values[i] >= 0.0,
                            "NATR should be non-negative at index {}: got {}",
                            i,
                            result.values[i]
                        );

                        prop_assert!(
                            result.values[i] < 10000.0,
                            "NATR seems unreasonably high at index {}: got {}",
                            i,
                            result.values[i]
                        );
                    }
                }

                for i in 0..result.values.len() {
                    let val = result.values[i];
                    let ref_val = ref_result.values[i];

                    if val.is_nan() && ref_val.is_nan() {
                        continue;
                    }

                    if val.is_finite() && ref_val.is_finite() {
                        let diff = (val - ref_val).abs();
                        let tolerance = (ref_val.abs() * 1e-10).max(1e-10);
                        prop_assert!(
                            diff <= tolerance,
                            "Kernel mismatch at index {}: {} vs {} (diff: {})",
                            i,
                            val,
                            ref_val,
                            diff
                        );
                    } else {
                        prop_assert_eq!(
                            val.is_finite(),
                            ref_val.is_finite(),
                            "Finite status mismatch at index {}: {} vs {}",
                            i,
                            val,
                            ref_val
                        );
                    }
                }

                if scenario == 2 {
                    let is_constant = high
                        .iter()
                        .zip(&low)
                        .zip(&close)
                        .all(|((h, l), c)| (*h - *l).abs() < 1e-10 && (*h - *c).abs() < 1e-10);

                    if is_constant && result.values.len() > period + 5 {
                        for i in (period + 5)..result.values.len() {
                            if result.values[i].is_finite() {
                                prop_assert!(
                                    result.values[i].abs() < 1e-10,
                                    "NATR should be 0 for constant prices at index {}, got {}",
                                    i,
                                    result.values[i]
                                );
                            }
                        }
                    }
                }

                if scenario == 1 {
                    for i in period..result.values.len() {
                        if result.values[i].is_finite() && close[i] > 0.0 {
                            prop_assert!(
                                result.values[i] >= 0.0 && result.values[i] < 100000.0,
                                "NATR out of bounds with small prices at index {}: got {}",
                                i,
                                result.values[i]
                            );
                        }
                    }
                }

                if close.iter().any(|&c| c.abs() < 1e-10) {
                    for (i, &c) in close.iter().enumerate() {
                        if c.abs() < 1e-10 && i >= period - 1 {
                            prop_assert!(
                                result.values[i] == 0.0 || result.values[i].is_nan(),
                                "NATR should be 0 or NaN when close is 0, got {} at index {}",
                                result.values[i],
                                i
                            );
                        }
                    }
                }

                #[cfg(debug_assertions)]
                {
                    for (i, &val) in result.values.iter().enumerate() {
                        if val.is_finite() {
                            let bits = val.to_bits();
                            prop_assert!(
                                bits != 0x11111111_11111111
                                    && bits != 0x22222222_22222222
                                    && bits != 0x33333333_33333333,
                                "Found poison value at index {}: {} (0x{:016X})",
                                i,
                                val,
                                bits
                            );
                        }
                    }
                }

                Ok(())
            },
        )?;

        Ok(())
    }

    // Generates one `#[test]` per kernel flavour for every listed check.
    //
    // Each generated test returns the check's `Result`, so an `Err` fails the
    // test rather than being dropped. The `#[cfg]` attributes sit inside the
    // repetition: an attribute placed in front of `$( ... )*` only applies to
    // the first expanded item.
    macro_rules! generate_all_natr_tests {
        ($($test_fn:ident),*) => {
            paste::paste! {
                $(
                    #[test]
                    fn [<$test_fn _scalar_f64>]() -> Result<(), Box<dyn std::error::Error>> {
                        $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar)
                    }

                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx2_f64>]() -> Result<(), Box<dyn std::error::Error>> {
                        $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2)
                    }

                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx512_f64>]() -> Result<(), Box<dyn std::error::Error>> {
                        $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512)
                    }

                    #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
                    #[test]
                    fn [<$test_fn _simd128_f64>]() -> Result<(), Box<dyn std::error::Error>> {
                        $test_fn(stringify!([<$test_fn _simd128_f64>]), Kernel::Scalar)
                    }
                )*
            }
        }
    }

    #[cfg(feature = "proptest")]
    generate_all_natr_tests!(check_natr_property);

    generate_all_natr_tests!(
        check_natr_partial_params,
        check_natr_accuracy,
        check_natr_with_zero_period,
        check_natr_period_exceeds_length,
        check_natr_very_small_dataset,
        check_natr_all_values_nan,
        check_natr_not_enough_valid_data,
        check_natr_slice_data_reinput,
        check_natr_nan_check,
        check_natr_default_candles,
        check_natr_streaming,
        check_natr_no_poison
    );

    // Generates the batch-kernel variants of a check; as above, the generated
    // tests return the check's `Result` so errors are not discarded.
    macro_rules! gen_batch_tests {
        ($fn_name:ident) => {
            paste::paste! {
                #[test]
                fn [<$fn_name _scalar>]() -> Result<(), Box<dyn std::error::Error>> {
                    $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch)
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$fn_name _avx2>]() -> Result<(), Box<dyn std::error::Error>> {
                    $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch)
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$fn_name _avx512>]() -> Result<(), Box<dyn std::error::Error>> {
                    $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch)
                }
                #[test]
                fn [<$fn_name _auto_detect>]() -> Result<(), Box<dyn std::error::Error>> {
                    $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto)
                }
            }
        };
    }
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
}
2190
// Python entry point `natr(high, low, close, period, kernel=None)`.
//
// Borrows the three NumPy arrays as contiguous slices, validates the optional
// kernel name, runs the computation with the GIL released and hands the
// result back as a new 1-D NumPy array. Indicator errors are surfaced as
// Python `ValueError`s carrying the error's display text.
#[cfg(feature = "python")]
#[pyfunction(name = "natr")]
#[pyo3(signature = (high, low, close, period, kernel=None))]
pub fn natr_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let (h, l, c) = (high.as_slice()?, low.as_slice()?, close.as_slice()?);
    let selected = validate_kernel(kernel, false)?;

    let input = NatrInput::from_slices(
        h,
        l,
        c,
        NatrParams {
            period: Some(period),
        },
    );

    // The heavy lifting happens without holding the GIL.
    let computed = py.allow_threads(|| natr_with_kernel(&input, selected));
    match computed {
        Ok(output) => Ok(output.values.into_pyarray(py)),
        Err(err) => Err(PyValueError::new_err(err.to_string())),
    }
}
2218
// Python entry point `natr_batch(high, low, close, period_range, kernel=None)`.
//
// `period_range` is forwarded as the `period` field of `NatrBatchRange`.
// Returns a dict with:
//   "values"  - a (rows, cols) array, one row per parameter combination and
//               `cols` equal to the input length
//   "periods" - the period used for each row, as u64
// Mismatched input lengths, an invalid sweep, a rows*cols overflow and
// indicator errors are all raised as Python `ValueError`s.
#[cfg(feature = "python")]
#[pyfunction(name = "natr_batch")]
#[pyo3(signature = (high, low, close, period_range, kernel=None))]
pub fn natr_batch_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let (h, l, c) = (high.as_slice()?, low.as_slice()?, close.as_slice()?);
    let cols = h.len();
    if l.len() != cols || c.len() != cols {
        return Err(PyValueError::new_err("natr: Mismatched input lengths"));
    }

    let kern = validate_kernel(kernel, true)?;
    let sweep = NatrBatchRange {
        period: period_range,
    };

    // The grid is expanded here only to size the output buffer.
    let rows = expand_grid(&sweep)
        .map_err(|e| PyValueError::new_err(e.to_string()))?
        .len();
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("natr_batch: rows*cols overflow"))?;

    // The array is allocated without initialisation and written in place.
    // NOTE(review): this relies on `natr_batch_inner_into` writing every cell — confirm.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            let resolved = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                explicit => explicit,
            };
            // Map the batch kernel onto the per-row kernel it drives.
            let row_kernel = match resolved {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                other => other,
            };
            natr_batch_inner_into(h, l, c, &sweep, row_kernel, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;

    let periods: Vec<u64> = combos
        .iter()
        .map(|combo| combo.period.unwrap() as u64)
        .collect();
    dict.set_item("periods", periods.into_pyarray(py))?;

    Ok(dict)
}
2288
// Python-visible wrapper (exposed as `NatrStream`) around the Rust streaming
// NATR calculator; all state lives in the wrapped `NatrStream`.
#[cfg(feature = "python")]
#[pyclass(name = "NatrStream")]
pub struct NatrStreamPy {
    // The wrapped incremental calculator.
    stream: NatrStream,
}
2294
#[cfg(feature = "python")]
#[pymethods]
impl NatrStreamPy {
    // Python constructor: builds the underlying stream for `period`.
    // A construction failure is raised as `ValueError` with the error's text.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        NatrStream::try_new(NatrParams {
            period: Some(period),
        })
        .map(|stream| Self { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    // Feeds one bar to the stream and returns whatever the stream yields
    // (`None` surfaces in Python as `None`).
    fn update(&mut self, high: f64, low: f64, close: f64) -> Option<f64> {
        self.stream.update(high, low, close)
    }
}
2312
2313pub fn natr_into_slice(dst: &mut [f64], input: &NatrInput, kern: Kernel) -> Result<(), NatrError> {
2314 let (high, low, close, period) = match &input.data {
2315 NatrData::Candles { candles } => (
2316 candles.high.as_slice(),
2317 candles.low.as_slice(),
2318 candles.close.as_slice(),
2319 input.params.period.unwrap_or(14),
2320 ),
2321 NatrData::Slices { high, low, close } => {
2322 (*high, *low, *close, input.params.period.unwrap_or(14))
2323 }
2324 };
2325
2326 let len = high.len().min(low.len()).min(close.len());
2327
2328 if dst.len() != len {
2329 return Err(NatrError::OutputLengthMismatch {
2330 expected: len,
2331 got: dst.len(),
2332 });
2333 }
2334
2335 if len == 0 {
2336 return Err(NatrError::EmptyInputData);
2337 }
2338 if period == 0 {
2339 return Err(NatrError::InvalidPeriod {
2340 period,
2341 data_len: len,
2342 });
2343 }
2344 if period > len {
2345 return Err(NatrError::InvalidPeriod {
2346 period,
2347 data_len: len,
2348 });
2349 }
2350
2351 let first_valid_idx = {
2352 let first_valid_idx_h = high.iter().position(|&x| !x.is_nan());
2353 let first_valid_idx_l = low.iter().position(|&x| !x.is_nan());
2354 let first_valid_idx_c = close.iter().position(|&x| !x.is_nan());
2355
2356 match (first_valid_idx_h, first_valid_idx_l, first_valid_idx_c) {
2357 (Some(h), Some(l), Some(c)) => Some(h.max(l).max(c)),
2358 _ => None,
2359 }
2360 };
2361
2362 let first_valid_idx = match first_valid_idx {
2363 Some(idx) => idx,
2364 None => return Err(NatrError::AllValuesNaN),
2365 };
2366
2367 if (len - first_valid_idx) < period {
2368 return Err(NatrError::NotEnoughValidData {
2369 needed: period,
2370 valid: len - first_valid_idx,
2371 });
2372 }
2373
2374 let chosen = match kern {
2375 Kernel::Auto => Kernel::Scalar,
2376 other => other,
2377 };
2378
2379 unsafe {
2380 match chosen {
2381 Kernel::Scalar | Kernel::ScalarBatch => {
2382 natr_scalar(high, low, close, period, first_valid_idx, dst)
2383 }
2384 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2385 Kernel::Avx2 | Kernel::Avx2Batch => {
2386 natr_avx2(high, low, close, period, first_valid_idx, dst)
2387 }
2388 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2389 Kernel::Avx512 | Kernel::Avx512Batch => {
2390 natr_avx512(high, low, close, period, first_valid_idx, dst)
2391 }
2392 _ => unreachable!(),
2393 }
2394 }
2395
2396 for v in &mut dst[..(first_valid_idx + period - 1)] {
2397 *v = f64::NAN;
2398 }
2399
2400 Ok(())
2401}
2402
2403#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
2404#[inline]
2405pub fn natr_into(input: &NatrInput, out: &mut [f64]) -> Result<(), NatrError> {
2406 natr_into_slice(out, input, Kernel::Auto)
2407}
2408
/// WASM entry point: computes NATR for three equally sized price slices and
/// returns a freshly allocated vector with one value per input element.
///
/// # Errors
/// Returns a `JsValue` string when the slice lengths differ, or when
/// `natr_into_slice` rejects the input (the error is stringified).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn natr_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
) -> Result<Vec<f64>, JsValue> {
    let n = high.len();
    if low.len() != n || close.len() != n {
        return Err(JsValue::from_str("natr: Mismatched input lengths"));
    }

    let input = NatrInput::from_slices(
        high,
        low,
        close,
        NatrParams {
            period: Some(period),
        },
    );

    let mut values = vec![0.0; n];
    match natr_into_slice(&mut values, &input, detect_best_kernel()) {
        Ok(()) => Ok(values),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
2429
/// WASM entry point: computes NATR over caller-owned buffers addressed by raw
/// pointers and writes `len` results to `out_ptr`.
///
/// The output range may coincide with, or partially overlap, any of the input
/// ranges; in that case the result is computed into a temporary buffer first
/// and copied back afterwards, so the inputs are never overwritten while they
/// are still being read.
///
/// # Caller requirements
/// `high_ptr`, `low_ptr` and `close_ptr` must each be readable for `len`
/// `f64` values and `out_ptr` must be writable for `len` `f64` values.
///
/// # Errors
/// Returns a `JsValue` string when any pointer is null or when
/// `natr_into_slice` rejects the input (the error is stringified).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn natr_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if high_ptr.is_null() || low_ptr.is_null() || close_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // Detect any byte-range overlap between the output and an input, not just
    // identical start addresses: building a `&mut [f64]` over memory that is
    // also borrowed as `&[f64]` would violate the slice aliasing rules.
    let span = len.saturating_mul(std::mem::size_of::<f64>());
    let out_lo = out_ptr as usize;
    let out_hi = out_lo.saturating_add(span);
    let overlaps_out = |p: *const f64| {
        let lo = p as usize;
        let hi = lo.saturating_add(span);
        lo < out_hi && out_lo < hi
    };
    let aliased = overlaps_out(high_ptr) || overlaps_out(low_ptr) || overlaps_out(close_ptr);

    // SAFETY: the pointers are non-null (checked above) and the caller
    // guarantees each one is valid for `len` elements. A mutable slice over
    // the output is only created when it is disjoint from every input.
    unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);
        let params = NatrParams {
            period: Some(period),
        };
        let input = NatrInput::from_slices(high, low, close, params);

        if aliased {
            let mut temp = vec![0.0; len];
            natr_into_slice(&mut temp, &input, detect_best_kernel())
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            // `temp` is a fresh allocation, so it cannot overlap `out_ptr`;
            // the inputs are not read again after this point.
            std::ptr::copy_nonoverlapping(temp.as_ptr(), out_ptr, len);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            natr_into_slice(out, &input, detect_best_kernel())
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }
    Ok(())
}
2467
/// Allocates room for `len` `f64` values and hands the raw pointer to the
/// caller without running the vector's destructor.
///
/// The memory is not initialized. It is expected to be released with
/// `natr_free`, passing the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn natr_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` suppresses the destructor, so the allocation outlives
    // this function and ownership passes to the caller.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
2476
/// Releases a buffer previously obtained from `natr_alloc`.
///
/// `ptr` must come from `natr_alloc(len)` with the same `len`, and must not be
/// used after this call. A null `ptr` is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn natr_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: the caller guarantees `ptr` was produced by `natr_alloc(len)`,
    // i.e. by a `Vec<f64>` with capacity `len`. The vector is rebuilt with
    // length 0 because nothing guarantees the elements were ever initialized;
    // dropping it only returns the allocation.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
2486
/// Configuration object that `natr_batch_unified_js` deserializes from its
/// JS `config` argument.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct NatrBatchConfig {
    /// Period sweep, forwarded unchanged into `NatrBatchRange::period`.
    /// NOTE(review): presumably `(start, end, step)`, matching the argument
    /// order used by `natr_batch_into` — confirm whether `end` is inclusive.
    pub period_range: (usize, usize, usize),
}
2492
/// Result object that `natr_batch_unified_js` serializes back to JavaScript.
/// Every field is copied from the value returned by `natr_batch_inner`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct NatrBatchJsOutput {
    /// Flat buffer of computed values.
    /// NOTE(review): presumably `rows * cols` elements in row-major order — confirm.
    pub values: Vec<f64>,
    /// Parameter sets that were evaluated.
    /// NOTE(review): presumably one entry per row of `values` — confirm.
    pub combos: Vec<NatrParams>,
    /// Row count reported by the batch computation.
    pub rows: usize,
    /// Column count reported by the batch computation.
    pub cols: usize,
}
2501
/// WASM entry point exported to JavaScript as `natr_batch`: runs NATR for a
/// sweep of periods described by `config` and returns a serialized
/// `NatrBatchJsOutput`.
///
/// # Errors
/// Returns a `JsValue` string when the slice lengths differ, when `config`
/// cannot be deserialized into `NatrBatchConfig`, when the batch computation
/// fails, or when the result cannot be serialized.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = natr_batch)]
pub fn natr_batch_unified_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let n = high.len();
    if low.len() != n || close.len() != n {
        return Err(JsValue::from_str("natr: Mismatched input lengths"));
    }

    let parsed = serde_wasm_bindgen::from_value::<NatrBatchConfig>(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;
    let sweep = NatrBatchRange {
        period: parsed.period_range,
    };

    let batch = natr_batch_inner(high, low, close, &sweep, detect_best_kernel(), false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let payload = NatrBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };
    match serde_wasm_bindgen::to_value(&payload) {
        Ok(js) => Ok(js),
        Err(e) => Err(JsValue::from_str(&format!("Serialization error: {}", e))),
    }
}
2532
/// WASM entry point: runs a NATR period sweep over caller-owned buffers and
/// writes the results to `out_ptr`.
///
/// The sweep `(period_start, period_end, period_step)` is expanded with
/// `expand_grid`; the number of expanded combinations is the row count, and
/// the output region is treated as `rows * len` `f64` values.
///
/// # Caller requirements
/// Each input pointer must be readable for `len` `f64` values and `out_ptr`
/// must be writable for `rows * len` `f64` values.
///
/// # Errors
/// Returns a `JsValue` string when any pointer is null, when the sweep cannot
/// be expanded, when `rows * len` overflows, or when the batch computation
/// fails. On success returns the row count.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn natr_batch_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    let any_null =
        high_ptr.is_null() || low_ptr.is_null() || close_ptr.is_null() || out_ptr.is_null();
    if any_null {
        return Err(JsValue::from_str("null pointer passed to natr_batch_into"));
    }

    let sweep = NatrBatchRange {
        period: (period_start, period_end, period_step),
    };

    // One output row per expanded parameter combination.
    let rows = match expand_grid(&sweep) {
        Ok(combos) => combos.len(),
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };
    let total = match rows.checked_mul(len) {
        Some(count) => count,
        None => return Err(JsValue::from_str("natr_batch_into: rows*cols overflow")),
    };

    // SAFETY: all pointers are non-null (checked above); the caller guarantees
    // the inputs are valid for `len` elements and the output for `total`.
    unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);
        let out = std::slice::from_raw_parts_mut(out_ptr, total);

        natr_batch_inner_into(high, low, close, &sweep, detect_best_kernel(), false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(rows)
}
2573
/// Registers the NATR Python bindings on module `m`: the CPU functions always,
/// and the CUDA device-array class plus CUDA functions when the `cuda` feature
/// is enabled.
#[cfg(feature = "python")]
pub fn register_natr_module(m: &Bound<'_, pyo3::types::PyModule>) -> PyResult<()> {
    let single = wrap_pyfunction!(natr_py, m)?;
    m.add_function(single)?;
    let batch = wrap_pyfunction!(natr_batch_py, m)?;
    m.add_function(batch)?;

    #[cfg(feature = "cuda")]
    {
        m.add_class::<NatrDeviceArrayF32Py>()?;
        let cuda_batch = wrap_pyfunction!(natr_cuda_batch_dev_py, m)?;
        m.add_function(cuda_batch)?;
        let cuda_many = wrap_pyfunction!(natr_cuda_many_series_one_param_dev_py, m)?;
        m.add_function(cuda_many)?;
    }

    Ok(())
}
2586
2587#[cfg(all(feature = "python", feature = "cuda"))]
2588use crate::cuda::natr_wrapper::DeviceArrayF32Natr;
2589#[cfg(all(feature = "python", feature = "cuda"))]
2590use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
2591#[cfg(all(feature = "python", feature = "cuda"))]
2592use cust::context::Context;
2593#[cfg(all(feature = "python", feature = "cuda"))]
2594use cust::memory::DeviceBuffer;
2595#[cfg(all(feature = "python", feature = "cuda"))]
2596use pyo3::exceptions::PyBufferError;
2597#[cfg(all(feature = "python", feature = "cuda"))]
2598use pyo3::PyObject;
2599#[cfg(all(feature = "python", feature = "cuda"))]
2600use std::sync::Arc;
2601
// Python-visible handle to a 2-D f32 result buffer living on a CUDA device.
// Exposed to Python through `__cuda_array_interface__` and `__dlpack__`.
// (Plain `//` comments are used so the Python class docstring is unchanged.)
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct NatrDeviceArrayF32Py {
    // Device allocation; becomes `None` once `__dlpack__` has taken it.
    pub(crate) buf: Option<DeviceBuffer<f32>>,
    // First dimension of the reported shape.
    pub(crate) rows: usize,
    // Second dimension of the reported shape.
    pub(crate) cols: usize,
    // NOTE(review): not read by the methods in this file; presumably held to
    // keep the CUDA context alive as long as the buffer — confirm.
    pub(crate) ctx: Arc<Context>,
    // Device ordinal reported by `__dlpack_device__`.
    pub(crate) device_id: u32,
}
2611
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl NatrDeviceArrayF32Py {
    // Builds the `__cuda_array_interface__` dictionary (version 3) describing
    // the buffer as a `(rows, cols)` little-endian f32 array with explicit
    // byte strides. Fails once the buffer has been handed off via `__dlpack__`.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let elem_size = std::mem::size_of::<f32>();
        let iface = PyDict::new(py);
        iface.set_item("shape", (self.rows, self.cols))?;
        iface.set_item("typestr", "<f4")?;
        iface.set_item("strides", (self.cols * elem_size, elem_size))?;

        let device_buf = match self.buf.as_ref() {
            Some(b) => b,
            None => {
                return Err(PyValueError::new_err(
                    "buffer already exported via __dlpack__",
                ))
            }
        };
        let raw_addr = device_buf.as_device_ptr().as_raw() as usize;
        // The second tuple element is the interface's read-only flag.
        iface.set_item("data", (raw_addr, false))?;

        iface.set_item("version", 3)?;
        Ok(iface)
    }

    // Returns the `(device_type, device_id)` pair; the device type is the
    // constant 2.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.device_id as i32)
    }

    // Exports the buffer through DLPack. The buffer is moved out of `self`,
    // so a second call fails. A `dl_device` that parses as `(i32, i32)` and
    // differs from this buffer's device is rejected, as is `copy=True`.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<PyObject>,
        max_version: Option<PyObject>,
        dl_device: Option<PyObject>,
        copy: Option<PyObject>,
    ) -> PyResult<PyObject> {
        let (kdl, alloc_dev) = self.__dlpack_device__();

        // A `dl_device` that is not an `(i32, i32)` tuple is ignored.
        let requested = dl_device
            .as_ref()
            .and_then(|dev| dev.extract::<(i32, i32)>(py).ok());
        if let Some(wanted) = requested {
            if wanted != (kdl, alloc_dev) {
                // An unparsable `copy` counts as "no copy requested" here.
                let wants_copy = match copy.as_ref() {
                    Some(flag) => flag.extract::<bool>(py).unwrap_or(false),
                    None => false,
                };
                let message = if wants_copy {
                    "device copy not implemented for __dlpack__"
                } else {
                    "__dlpack__: requested device does not match producer buffer"
                };
                return Err(PyBufferError::new_err(message));
            }
        }
        let _ = stream;

        // Here an unparsable `copy` is an error and propagates to the caller.
        let copy_requested = match copy.as_ref() {
            Some(flag) => flag.extract::<bool>(py)?,
            None => false,
        };
        if copy_requested {
            return Err(PyBufferError::new_err(
                "__dlpack__(copy=True) not supported for natr CUDA buffers",
            ));
        }

        let (rows, cols) = (self.rows, self.cols);
        let buf = match self.buf.take() {
            Some(b) => b,
            None => {
                return Err(PyValueError::new_err(
                    "__dlpack__ may only be called once",
                ))
            }
        };

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
2695
// Python binding `natr_cuda_batch_dev`: runs a NATR period sweep on a CUDA
// device for one high/low/close series and returns a device-array handle.
// Raises `ValueError` when CUDA is unavailable, when the input lengths differ,
// or when the CUDA wrapper reports an error. The GIL is released while the
// device work runs.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "natr_cuda_batch_dev")]
#[pyo3(signature = (high, low, close, period_range, device_id=0))]
pub fn natr_cuda_batch_dev_py(
    py: Python<'_>,
    high: numpy::PyReadonlyArray1<'_, f32>,
    low: numpy::PyReadonlyArray1<'_, f32>,
    close: numpy::PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<NatrDeviceArrayF32Py> {
    use crate::cuda::cuda_available;
    use crate::cuda::CudaNatr;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let (h, l, c) = (high.as_slice()?, low.as_slice()?, close.as_slice()?);
    let n = h.len();
    if l.len() != n || c.len() != n {
        return Err(PyValueError::new_err("mismatched input lengths"));
    }

    let sweep = NatrBatchRange {
        period: period_range,
    };
    let dev = py.allow_threads(|| {
        let mut cuda =
            CudaNatr::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        cuda.natr_batch_dev(h, l, c, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    // Re-wrap the wrapper's device array in the Python-visible handle.
    Ok(NatrDeviceArrayF32Py {
        buf: Some(dev.buf),
        rows: dev.rows,
        cols: dev.cols,
        ctx: dev.ctx,
        device_id: dev.device_id,
    })
}
2743
// Python binding `natr_cuda_many_series_one_param_dev`: runs NATR with a
// single `period` on a CUDA device over 2-D high/low/close arrays and returns
// a device-array handle. Raises `ValueError` when CUDA is unavailable, when
// the three arrays do not share one shape, or when the CUDA wrapper reports an
// error. The GIL is released while the device work runs.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "natr_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm, low_tm, close_tm, period, device_id=0))]
pub fn natr_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    high_tm: numpy::PyReadonlyArray2<'_, f32>,
    low_tm: numpy::PyReadonlyArray2<'_, f32>,
    close_tm: numpy::PyReadonlyArray2<'_, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<NatrDeviceArrayF32Py> {
    use crate::cuda::cuda_available;
    use crate::cuda::CudaNatr;
    use numpy::PyUntypedArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    // The dimensions are taken from `high_tm` only, so the other two arrays
    // must have exactly the same shape. Equal element counts are not enough:
    // (2, 6) and (3, 4) would otherwise be silently misinterpreted.
    if low_tm.shape() != high_tm.shape() || close_tm.shape() != high_tm.shape() {
        return Err(PyValueError::new_err("mismatched input shapes"));
    }
    let h = high_tm.as_slice()?;
    let l = low_tm.as_slice()?;
    let c = close_tm.as_slice()?;
    let rows = high_tm.shape()[0];
    let cols = high_tm.shape()[1];

    let dev = py.allow_threads(|| {
        let mut cuda =
            CudaNatr::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        cuda.natr_many_series_one_param_time_major_dev(h, l, c, cols, rows, period)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;
    // Re-wrap the wrapper's device array in the Python-visible handle.
    let DeviceArrayF32Natr {
        buf,
        rows,
        cols,
        ctx,
        device_id,
    } = dev;
    Ok(NatrDeviceArrayF32Py {
        buf: Some(buf),
        rows,
        cols,
        ctx,
        device_id,
    })
}