1#[cfg(all(feature = "python", feature = "cuda"))]
2use cust::memory::DeviceBuffer;
3#[cfg(feature = "python")]
4use numpy::{IntoPyArray, PyArray1};
5#[cfg(feature = "python")]
6use pyo3::exceptions::PyValueError;
7#[cfg(feature = "python")]
8use pyo3::prelude::*;
9#[cfg(feature = "python")]
10use pyo3::types::PyDict;
11
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use serde::{Deserialize, Serialize};
14#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
15use wasm_bindgen::prelude::*;
16
17use crate::utilities::data_loader::{source_type, Candles};
18use crate::utilities::enums::Kernel;
19use crate::utilities::helpers::{
20 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
21 make_uninit_matrix,
22};
23#[cfg(feature = "python")]
24use crate::utilities::kernel_validation::validate_kernel;
25use aligned_vec::{AVec, CACHELINE_ALIGN};
26#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
27use core::arch::x86_64::*;
28#[cfg(not(target_arch = "wasm32"))]
29use rayon::prelude::*;
30use std::convert::AsRef;
31use std::error::Error;
32use thiserror::Error;
33
/// Source of the high/low series consumed by the DTI calculation.
#[derive(Debug, Clone)]
pub enum DtiData<'a> {
    /// Read the `"high"` and `"low"` fields from a candle set when the
    /// indicator runs.
    Candles { candles: &'a Candles },
    /// Use caller-provided high and low slices directly; their lengths are
    /// checked against each other when the indicator runs.
    Slices { high: &'a [f64], low: &'a [f64] },
}
39
/// Result of a DTI run: one value per input bar.
///
/// Entries up to and including the first bar where both high and low are
/// non-NaN are `NaN`. Later entries are `100 * num / den` of the smoothed
/// signed and absolute bar-to-bar moves, or `0.0` when `den` is zero or NaN.
#[derive(Debug, Clone)]
pub struct DtiOutput {
    pub values: Vec<f64>,
}
44
/// Periods of the three cascaded EMA stages used by DTI.
///
/// `r` smooths the raw bar-to-bar move, `s` smooths the `r` stage, and `u`
/// smooths the `s` stage. A `None` period means "use the default". The
/// defaults (`r = 14`, `s = 10`, `u = 5`) are the values filled in by
/// [`DtiParams::default`].
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct DtiParams {
    pub r: Option<usize>,
    pub s: Option<usize>,
    pub u: Option<usize>,
}

impl Default for DtiParams {
    /// All three periods set explicitly: `r = 14`, `s = 10`, `u = 5`.
    fn default() -> Self {
        let (r, s, u) = (14, 10, 5);
        Self {
            r: Some(r),
            s: Some(s),
            u: Some(u),
        }
    }
}
65
66#[derive(Debug, Clone)]
67pub struct DtiInput<'a> {
68 pub data: DtiData<'a>,
69 pub params: DtiParams,
70}
71
72impl<'a> DtiInput<'a> {
73 #[inline]
74 pub fn from_candles(candles: &'a Candles, params: DtiParams) -> Self {
75 Self {
76 data: DtiData::Candles { candles },
77 params,
78 }
79 }
80
81 #[inline]
82 pub fn from_slices(high: &'a [f64], low: &'a [f64], params: DtiParams) -> Self {
83 Self {
84 data: DtiData::Slices { high, low },
85 params,
86 }
87 }
88
89 #[inline]
90 pub fn with_default_candles(candles: &'a Candles) -> Self {
91 Self::from_candles(candles, DtiParams::default())
92 }
93
94 #[inline]
95 pub fn get_r(&self) -> usize {
96 self.params.r.unwrap_or(14)
97 }
98 #[inline]
99 pub fn get_s(&self) -> usize {
100 self.params.s.unwrap_or(10)
101 }
102 #[inline]
103 pub fn get_u(&self) -> usize {
104 self.params.u.unwrap_or(5)
105 }
106}
107
108#[derive(Copy, Clone, Debug)]
109pub struct DtiBuilder {
110 r: Option<usize>,
111 s: Option<usize>,
112 u: Option<usize>,
113 kernel: Kernel,
114}
115
116impl Default for DtiBuilder {
117 fn default() -> Self {
118 Self {
119 r: None,
120 s: None,
121 u: None,
122 kernel: Kernel::Auto,
123 }
124 }
125}
126
127impl DtiBuilder {
128 #[inline(always)]
129 pub fn new() -> Self {
130 Self::default()
131 }
132 #[inline(always)]
133 pub fn r(mut self, n: usize) -> Self {
134 self.r = Some(n);
135 self
136 }
137 #[inline(always)]
138 pub fn s(mut self, n: usize) -> Self {
139 self.s = Some(n);
140 self
141 }
142 #[inline(always)]
143 pub fn u(mut self, n: usize) -> Self {
144 self.u = Some(n);
145 self
146 }
147 #[inline(always)]
148 pub fn kernel(mut self, k: Kernel) -> Self {
149 self.kernel = k;
150 self
151 }
152 #[inline(always)]
153 pub fn apply(self, c: &Candles) -> Result<DtiOutput, DtiError> {
154 let p = DtiParams {
155 r: self.r,
156 s: self.s,
157 u: self.u,
158 };
159 let i = DtiInput::from_candles(c, p);
160 dti_with_kernel(&i, self.kernel)
161 }
162 #[inline(always)]
163 pub fn apply_slices(self, high: &[f64], low: &[f64]) -> Result<DtiOutput, DtiError> {
164 let p = DtiParams {
165 r: self.r,
166 s: self.s,
167 u: self.u,
168 };
169 let i = DtiInput::from_slices(high, low, p);
170 dti_with_kernel(&i, self.kernel)
171 }
172 #[inline(always)]
173 pub fn into_stream(self) -> Result<DtiStream, DtiError> {
174 let p = DtiParams {
175 r: self.r,
176 s: self.s,
177 u: self.u,
178 };
179 DtiStream::try_new(p)
180 }
181}
182
/// Errors produced by the DTI entry points and [`DtiStream::try_new`].
#[derive(Debug, Error)]
pub enum DtiError {
    /// `high` or `low` is empty.
    #[error("dti: Input data slice is empty.")]
    EmptyInputData,
    /// Reading the `"high"`/`"low"` candle field failed; carries the
    /// underlying error text.
    #[error("dti: Candle field error: {0}")]
    CandleFieldError(String),
    /// A period is 0 or exceeds the data length. `DtiStream::try_new` reports
    /// a zero period as `period: 0, data_len: 0`.
    #[error("dti: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },
    /// Fewer bars remain from the first valid index than a period needs.
    #[error("dti: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// No index has both `high` and `low` non-NaN.
    #[error("dti: All high/low values are NaN.")]
    AllValuesNaN,
    /// `high` and `low` have different lengths.
    #[error("dti: Length mismatch: high length = {high}, low length = {low}")]
    LengthMismatch { high: usize, low: usize },
    /// The destination slice passed to `dti_into_slice` has the wrong length.
    #[error("dti: output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// NOTE(review): not raised by the code visible here; presumably returned
    /// by the batch API when given a non-batch kernel. Confirm.
    #[error("dti: invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(Kernel),
    /// NOTE(review): not raised by the code visible here; presumably returned
    /// by the batch range expansion for a malformed sweep. Confirm.
    #[error("dti: invalid range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },
}
208
209#[inline]
210pub fn dti(input: &DtiInput) -> Result<DtiOutput, DtiError> {
211 dti_with_kernel(input, Kernel::Auto)
212}
213
214#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
215#[inline]
216pub fn dti_into(input: &DtiInput, out: &mut [f64]) -> Result<(), DtiError> {
217 dti_into_slice(out, input, Kernel::Auto)
218}
219
220#[inline]
221pub fn dti_into_slice(dst: &mut [f64], input: &DtiInput, kern: Kernel) -> Result<(), DtiError> {
222 let (high, low) = match &input.data {
223 DtiData::Candles { candles } => {
224 let high = candles
225 .select_candle_field("high")
226 .map_err(|e| DtiError::CandleFieldError(e.to_string()))?;
227 let low = candles
228 .select_candle_field("low")
229 .map_err(|e| DtiError::CandleFieldError(e.to_string()))?;
230 (high, low)
231 }
232 DtiData::Slices { high, low } => (*high, *low),
233 };
234
235 if high.is_empty() || low.is_empty() {
236 return Err(DtiError::EmptyInputData);
237 }
238 let len = high.len();
239 if low.len() != len {
240 return Err(DtiError::LengthMismatch {
241 high: high.len(),
242 low: low.len(),
243 });
244 }
245 if dst.len() != len {
246 return Err(DtiError::OutputLengthMismatch {
247 expected: len,
248 got: dst.len(),
249 });
250 }
251
252 let first_valid_idx = (0..len)
253 .find(|&i| !high[i].is_nan() && !low[i].is_nan())
254 .ok_or(DtiError::AllValuesNaN)?;
255
256 let r = input.get_r();
257 let s = input.get_s();
258 let u = input.get_u();
259 for &p in &[r, s, u] {
260 if p == 0 || p > len {
261 return Err(DtiError::InvalidPeriod {
262 period: p,
263 data_len: len,
264 });
265 }
266 if len - first_valid_idx < p {
267 return Err(DtiError::NotEnoughValidData {
268 needed: p,
269 valid: len - first_valid_idx,
270 });
271 }
272 }
273
274 let chosen = match kern {
275 Kernel::Auto => detect_best_kernel(),
276 k => k,
277 };
278
279 unsafe {
280 #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
281 {
282 match chosen {
283 Kernel::Scalar | Kernel::ScalarBatch => {
284 dti_simd128(high, low, r, s, u, first_valid_idx, dst)
285 }
286 _ => dti_scalar(high, low, r, s, u, first_valid_idx, dst),
287 }
288 }
289 #[cfg(not(all(target_arch = "wasm32", target_feature = "simd128")))]
290 match chosen {
291 Kernel::Scalar | Kernel::ScalarBatch => {
292 dti_scalar(high, low, r, s, u, first_valid_idx, dst)
293 }
294 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
295 Kernel::Avx2 | Kernel::Avx2Batch => dti_avx2(high, low, r, s, u, first_valid_idx, dst),
296 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
297 Kernel::Avx512 | Kernel::Avx512Batch => {
298 dti_avx512(high, low, r, s, u, first_valid_idx, dst)
299 }
300 _ => unreachable!(),
301 }
302 }
303
304 for v in &mut dst[..=first_valid_idx] {
305 *v = f64::NAN;
306 }
307 Ok(())
308}
309
310pub fn dti_with_kernel(input: &DtiInput, kernel: Kernel) -> Result<DtiOutput, DtiError> {
311 let (high, low) = match &input.data {
312 DtiData::Candles { candles } => {
313 let high = candles
314 .select_candle_field("high")
315 .map_err(|e| DtiError::CandleFieldError(e.to_string()))?;
316 let low = candles
317 .select_candle_field("low")
318 .map_err(|e| DtiError::CandleFieldError(e.to_string()))?;
319 (high, low)
320 }
321 DtiData::Slices { high, low } => (*high, *low),
322 };
323
324 if high.is_empty() || low.is_empty() {
325 return Err(DtiError::EmptyInputData);
326 }
327 let len = high.len();
328 if low.len() != len {
329 return Err(DtiError::LengthMismatch {
330 high: high.len(),
331 low: low.len(),
332 });
333 }
334
335 let first_valid_idx = match (0..len).find(|&i| !high[i].is_nan() && !low[i].is_nan()) {
336 Some(idx) => idx,
337 None => return Err(DtiError::AllValuesNaN),
338 };
339
340 let r = input.get_r();
341 let s = input.get_s();
342 let u = input.get_u();
343
344 for &period in &[r, s, u] {
345 if period == 0 || period > len {
346 return Err(DtiError::InvalidPeriod {
347 period,
348 data_len: len,
349 });
350 }
351 if (len - first_valid_idx) < period {
352 return Err(DtiError::NotEnoughValidData {
353 needed: period,
354 valid: len - first_valid_idx,
355 });
356 }
357 }
358
359 let chosen = match kernel {
360 Kernel::Auto => detect_best_kernel(),
361 other => other,
362 };
363
364 let mut out = alloc_with_nan_prefix(len, first_valid_idx + 1);
365 unsafe {
366 #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
367 {
368 match chosen {
369 Kernel::Scalar | Kernel::ScalarBatch => {
370 dti_simd128(high, low, r, s, u, first_valid_idx, &mut out)
371 }
372 _ => dti_scalar(high, low, r, s, u, first_valid_idx, &mut out),
373 }
374 return Ok(DtiOutput { values: out });
375 }
376
377 #[cfg(not(all(target_arch = "wasm32", target_feature = "simd128")))]
378 match chosen {
379 Kernel::Scalar | Kernel::ScalarBatch => {
380 dti_scalar(high, low, r, s, u, first_valid_idx, &mut out)
381 }
382 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
383 Kernel::Avx2 | Kernel::Avx2Batch => {
384 dti_avx2(high, low, r, s, u, first_valid_idx, &mut out)
385 }
386 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
387 Kernel::Avx512 | Kernel::Avx512Batch => {
388 dti_avx512(high, low, r, s, u, first_valid_idx, &mut out)
389 }
390 _ => unreachable!(),
391 }
392 }
393 Ok(DtiOutput { values: out })
394}
395
/// Scalar single-series DTI kernel.
///
/// For each bar `i > first_valid_idx`, the signed move is the positive part of
/// `high[i] - high[i - 1]` minus the magnitude of the negative part of
/// `low[i] - low[i - 1]`. A NaN difference fails both comparisons and counts
/// as 0.
///
/// Two chains of three cascaded EMAs smooth the move and its absolute value.
/// Each stage uses the smoothing factor `2 / (p + 1)`, with `p` taken from
/// `r`, `s` and `u` in turn. The output is `out[i] = 100 * signed / absolute`,
/// or `0.0` when the absolute chain is zero or NaN.
///
/// `out[first_valid_idx]` is set to NaN. Entries before it are not written.
///
/// # Panics
/// Panics if `low` or `out` is shorter than `high`, or if `first_valid_idx`
/// is out of bounds for `out`.
#[inline]
pub fn dti_scalar(
    high: &[f64],
    low: &[f64],
    r: usize,
    s: usize,
    u: usize,
    first_valid_idx: usize,
    out: &mut [f64],
) {
    let smoothing = |period: usize| 2.0 / (period as f64 + 1.0);
    let (w_r, w_s, w_u) = (smoothing(r), smoothing(s), smoothing(u));
    let (k_r, k_s, k_u) = (1.0 - w_r, 1.0 - w_s, 1.0 - w_u);

    // State per stage, in order [r, s, u], for the signed and absolute chains.
    let mut signed = [0.0f64; 3];
    let mut magnitude = [0.0f64; 3];

    out[first_valid_idx] = f64::NAN;

    let n = high.len();
    let mut i = first_valid_idx + 1;
    while i < n {
        let hi_move = high[i] - high[i - 1];
        let lo_move = low[i] - low[i - 1];
        let up = if hi_move > 0.0 { hi_move } else { 0.0 };
        let down = if lo_move < 0.0 { -lo_move } else { 0.0 };
        let step = up - down;

        signed[0] = w_r * step + k_r * signed[0];
        signed[1] = w_s * signed[0] + k_s * signed[1];
        signed[2] = w_u * signed[1] + k_u * signed[2];

        magnitude[0] = w_r * step.abs() + k_r * magnitude[0];
        magnitude[1] = w_s * magnitude[0] + k_s * magnitude[1];
        magnitude[2] = w_u * magnitude[1] + k_u * magnitude[2];

        let den = magnitude[2];
        out[i] = if den.is_nan() || den == 0.0 {
            0.0
        } else {
            100.0 * signed[2] / den
        };
        i += 1;
    }
}
446
/// WebAssembly `simd128` entry point for the single-series kernel.
///
/// Currently this forwards straight to [`dti_scalar`]; there is no
/// SIMD-specific code yet.
///
/// # Safety
/// The function is declared `unsafe`, but its body only calls the safe
/// `dti_scalar`. It therefore adds no obligations beyond that function's
/// (which panics on short `low`/`out`).
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
#[inline]
unsafe fn dti_simd128(
    high: &[f64],
    low: &[f64],
    r: usize,
    s: usize,
    u: usize,
    first_valid_idx: usize,
    out: &mut [f64],
) {
    dti_scalar(high, low, r, s, u, first_valid_idx, out);
}
460
/// Single-series DTI kernel using two-lane SSE2 `f64` vectors.
///
/// Despite the name, this uses 128-bit `_mm_*_pd` operations. Lane 0 carries
/// the signed-move smoother chain and lane 1 the absolute-move chain, so both
/// chains advance with one multiply/add per stage. The per-lane floating-point
/// operations, including the move computation, are the same as in
/// [`dti_scalar`]. A NaN delta therefore contributes 0 instead of poisoning
/// the EMAs.
///
/// Writes `out[first_valid_idx + 1..high.len()]`. `out[..=first_valid_idx]`
/// is left for the caller.
///
/// # Panics
/// Panics if `low` or `out` is shorter than `high` when there is at least one
/// bar to compute. The loop uses unchecked pointer accesses, so this check is
/// what keeps this safe function sound.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn dti_avx2(
    high: &[f64],
    low: &[f64],
    r: usize,
    s: usize,
    u: usize,
    first_valid_idx: usize,
    out: &mut [f64],
) {
    use core::arch::x86_64::*;

    let len = high.len();
    let start = first_valid_idx + 1;
    if start >= len {
        return;
    }
    assert!(
        low.len() >= len && out.len() >= len,
        "dti_avx2: low/out must be at least as long as high"
    );

    let alpha_r = 2.0 / (r as f64 + 1.0);
    let alpha_s = 2.0 / (s as f64 + 1.0);
    let alpha_u = 2.0 / (u as f64 + 1.0);
    let ar1 = 1.0 - alpha_r;
    let as1 = 1.0 - alpha_s;
    let au1 = 1.0 - alpha_u;

    // SAFETY: SSE2 is part of the x86_64 baseline. Every pointer offset below
    // lies in `start - 1..len`, and the assert above guarantees that range is
    // in bounds for `high`, `low` and `out`.
    unsafe {
        let vr_a = _mm_set1_pd(alpha_r);
        let vs_a = _mm_set1_pd(alpha_s);
        let vu_a = _mm_set1_pd(alpha_u);
        let vr_b = _mm_set1_pd(ar1);
        let vs_b = _mm_set1_pd(as1);
        let vu_b = _mm_set1_pd(au1);

        let mut v_er = _mm_set1_pd(0.0);
        let mut v_es = _mm_set1_pd(0.0);
        let mut v_eu = _mm_set1_pd(0.0);

        let ph = high.as_ptr();
        let pl = low.as_ptr();
        let po = out.as_mut_ptr();

        for i in start..len {
            let dh = *ph.add(i) - *ph.add(i - 1);
            let dl = *pl.add(i) - *pl.add(i - 1);

            // Same branchy form as `dti_scalar`. A NaN delta fails both
            // comparisons and contributes 0. The previous branchless
            // `0.5 * (|d| + d)` form turned NaN (or overflowing) deltas into
            // NaN and poisoned every later output.
            let up = if dh > 0.0 { dh } else { 0.0 };
            let down = if dl < 0.0 { -dl } else { 0.0 };
            let x = up - down;

            // Lane 0 = signed move, lane 1 = its magnitude.
            let vx = _mm_set_pd(x.abs(), x);

            v_er = _mm_add_pd(_mm_mul_pd(vr_a, vx), _mm_mul_pd(vr_b, v_er));
            v_es = _mm_add_pd(_mm_mul_pd(vs_a, v_er), _mm_mul_pd(vs_b, v_es));
            v_eu = _mm_add_pd(_mm_mul_pd(vu_a, v_es), _mm_mul_pd(vu_b, v_eu));

            let mut lanes = [0.0f64; 2];
            _mm_storeu_pd(lanes.as_mut_ptr(), v_eu);
            let (num, den) = (lanes[0], lanes[1]);

            *po.add(i) = if !den.is_nan() && den != 0.0 {
                100.0 * num / den
            } else {
                0.0
            };
        }
    }
}
539
/// AVX-512 entry point for the single-series kernel.
///
/// Currently forwards to [`dti_avx2`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn dti_avx512(
    high: &[f64],
    low: &[f64],
    r: usize,
    s: usize,
    u: usize,
    first_valid_idx: usize,
    out: &mut [f64],
) {
    dti_avx2(high, low, r, s, u, first_valid_idx, out)
}

/// "Short" AVX-512 variant; currently identical to [`dti_avx2`].
///
/// # Safety
/// The body only calls the safe [`dti_avx2`], so callers take on no extra
/// obligations here. NOTE(review): the `unsafe` marker appears to exist only
/// for signature parity with other kernels; confirm.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn dti_avx512_short(
    high: &[f64],
    low: &[f64],
    r: usize,
    s: usize,
    u: usize,
    first_valid_idx: usize,
    out: &mut [f64],
) {
    dti_avx2(high, low, r, s, u, first_valid_idx, out)
}

/// "Long" AVX-512 variant; currently identical to [`dti_avx2`].
///
/// # Safety
/// The body only calls the safe [`dti_avx2`], so callers take on no extra
/// obligations here. NOTE(review): the `unsafe` marker appears to exist only
/// for signature parity with other kernels; confirm.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn dti_avx512_long(
    high: &[f64],
    low: &[f64],
    r: usize,
    s: usize,
    u: usize,
    first_valid_idx: usize,
    out: &mut [f64],
) {
    dti_avx2(high, low, r, s, u, first_valid_idx, out)
}
581
/// Scalar row kernel that works directly from `high`/`low`.
///
/// Forwards to [`dti_scalar`] with identical arguments and semantics.
#[inline(always)]
pub fn dti_row_scalar(
    high: &[f64],
    low: &[f64],
    r: usize,
    s: usize,
    u: usize,
    first_valid_idx: usize,
    out: &mut [f64],
) {
    dti_scalar(high, low, r, s, u, first_valid_idx, out)
}
594
595#[inline(always)]
596fn dti_precompute_base(high: &[f64], low: &[f64], start: usize) -> (AVec<f64>, AVec<f64>) {
597 let len = high.len();
598 let mut x = AVec::<f64>::new(CACHELINE_ALIGN);
599 let mut ax = AVec::<f64>::new(CACHELINE_ALIGN);
600 x.resize(len, 0.0);
601 ax.resize(len, 0.0);
602 for i in (start)..len {
603 let dh = high[i] - high[i - 1];
604 let dl = low[i] - low[i - 1];
605 let x_hmu = if dh > 0.0 { dh } else { 0.0 };
606 let x_lmd = if dl < 0.0 { -dl } else { 0.0 };
607 let v = x_hmu - x_lmd;
608 x[i] = v;
609 ax[i] = v.abs();
610 }
611 (x, ax)
612}
613
/// Scalar row kernel over precomputed moves.
///
/// `x` holds the signed per-bar moves and `ax` their magnitudes. Both are run
/// through three cascaded EMAs with factors `2 / (p + 1)` for `p` = `r`, `s`,
/// `u`. The kernel writes `out[i] = 100 * signed / magnitude` for
/// `i in start..x.len()`, or `0.0` when the magnitude chain is zero or NaN.
/// Entries before `start` are untouched, and nothing is written if
/// `start >= x.len()`.
#[inline(always)]
fn dti_row_scalar_from_base(
    x: &[f64],
    ax: &[f64],
    r: usize,
    s: usize,
    u: usize,
    start: usize,
    out: &mut [f64],
) {
    let n = x.len();
    if start >= n {
        return;
    }

    let weight = |period: usize| 2.0 / (period as f64 + 1.0);
    let (wr, ws, wu) = (weight(r), weight(s), weight(u));
    let (kr, ks, ku) = (1.0 - wr, 1.0 - ws, 1.0 - wu);

    // Signed chain (sr -> ss -> su) and magnitude chain (mr -> ms -> mu).
    let (mut sr, mut ss, mut su) = (0.0f64, 0.0f64, 0.0f64);
    let (mut mr, mut ms, mut mu) = (0.0f64, 0.0f64, 0.0f64);

    for idx in start..n {
        sr = wr * x[idx] + kr * sr;
        ss = ws * sr + ks * ss;
        su = wu * ss + ku * su;

        mr = wr * ax[idx] + kr * mr;
        ms = ws * mr + ks * ms;
        mu = wu * ms + ku * mu;

        out[idx] = if mu.is_nan() || mu == 0.0 {
            0.0
        } else {
            100.0 * su / mu
        };
    }
}
656
/// Row kernel over precomputed moves (`x`, `ax`) using two-lane SSE2 `f64`
/// vectors. Lane 0 carries the signed chain and lane 1 the magnitude chain.
///
/// It performs the same per-lane multiply/add sequence as
/// `dti_row_scalar_from_base` and writes `out[start..x.len()]`.
///
/// # Panics
/// Panics if `ax` or `out` is shorter than `x` when `start < x.len()`. The loop
/// uses unchecked pointer accesses, so this check keeps the safe signature
/// sound.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
fn dti_row_avx2_from_base(
    x: &[f64],
    ax: &[f64],
    r: usize,
    s: usize,
    u: usize,
    start: usize,
    out: &mut [f64],
) {
    use core::arch::x86_64::*;

    let len = x.len();
    if start >= len {
        return;
    }
    assert!(
        ax.len() >= len && out.len() >= len,
        "dti_row_avx2_from_base: ax/out must be at least as long as x"
    );

    let alpha_r = 2.0 / (r as f64 + 1.0);
    let alpha_s = 2.0 / (s as f64 + 1.0);
    let alpha_u = 2.0 / (u as f64 + 1.0);
    let ar1 = 1.0 - alpha_r;
    let as1 = 1.0 - alpha_s;
    let au1 = 1.0 - alpha_u;

    // SAFETY: SSE2 is part of the x86_64 baseline. Every offset `i` satisfies
    // `start <= i < len`, and the assert above guarantees `len` is in bounds
    // for `x`, `ax` and `out`.
    unsafe {
        let vr_a = _mm_set1_pd(alpha_r);
        let vs_a = _mm_set1_pd(alpha_s);
        let vu_a = _mm_set1_pd(alpha_u);
        let vr_b = _mm_set1_pd(ar1);
        let vs_b = _mm_set1_pd(as1);
        let vu_b = _mm_set1_pd(au1);

        let mut v_er = _mm_set1_pd(0.0);
        let mut v_es = _mm_set1_pd(0.0);
        let mut v_eu = _mm_set1_pd(0.0);
        let px = x.as_ptr();
        let pax = ax.as_ptr();
        let po = out.as_mut_ptr();

        for i in start..len {
            let vx = _mm_set_pd(*pax.add(i), *px.add(i));
            v_er = _mm_add_pd(_mm_mul_pd(vr_a, vx), _mm_mul_pd(vr_b, v_er));
            v_es = _mm_add_pd(_mm_mul_pd(vs_a, v_er), _mm_mul_pd(vs_b, v_es));
            v_eu = _mm_add_pd(_mm_mul_pd(vu_a, v_es), _mm_mul_pd(vu_b, v_eu));

            let mut lanes = [0.0f64; 2];
            _mm_storeu_pd(lanes.as_mut_ptr(), v_eu);
            let (num, den) = (lanes[0], lanes[1]);
            *po.add(i) = if !den.is_nan() && den != 0.0 {
                100.0 * num / den
            } else {
                0.0
            };
        }
    }
}
716
/// AVX-512 row kernel over precomputed moves.
///
/// Currently forwards to `dti_row_avx2_from_base` with identical arguments.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
fn dti_row_avx512_from_base(
    x: &[f64],
    ax: &[f64],
    r: usize,
    s: usize,
    u: usize,
    start: usize,
    out: &mut [f64],
) {
    dti_row_avx2_from_base(x, ax, r, s, u, start, out)
}
730
/// AVX2 row kernel working directly from `high`/`low`; forwards to [`dti_avx2`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn dti_row_avx2(
    high: &[f64],
    low: &[f64],
    r: usize,
    s: usize,
    u: usize,
    first_valid_idx: usize,
    out: &mut [f64],
) {
    dti_avx2(high, low, r, s, u, first_valid_idx, out)
}

/// AVX-512 row kernel working directly from `high`/`low`; forwards to [`dti_avx512`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn dti_row_avx512(
    high: &[f64],
    low: &[f64],
    r: usize,
    s: usize,
    u: usize,
    first_valid_idx: usize,
    out: &mut [f64],
) {
    dti_avx512(high, low, r, s, u, first_valid_idx, out)
}

/// Safe wrapper around [`dti_avx512_short`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn dti_row_avx512_short(
    high: &[f64],
    low: &[f64],
    r: usize,
    s: usize,
    u: usize,
    first_valid_idx: usize,
    out: &mut [f64],
) {
    // SAFETY: `dti_avx512_short` only forwards to the safe `dti_avx2`.
    unsafe { dti_avx512_short(high, low, r, s, u, first_valid_idx, out) }
}

/// Safe wrapper around [`dti_avx512_long`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn dti_row_avx512_long(
    high: &[f64],
    low: &[f64],
    r: usize,
    s: usize,
    u: usize,
    first_valid_idx: usize,
    out: &mut [f64],
) {
    // SAFETY: `dti_avx512_long` only forwards to the safe `dti_avx2`.
    unsafe { dti_avx512_long(high, low, r, s, u, first_valid_idx, out) }
}

/// WebAssembly `simd128` row kernel; safe wrapper around `dti_simd128`.
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
#[inline(always)]
pub fn dti_row_simd128(
    high: &[f64],
    low: &[f64],
    r: usize,
    s: usize,
    u: usize,
    first_valid_idx: usize,
    out: &mut [f64],
) {
    // SAFETY: `dti_simd128` only forwards to the safe `dti_scalar`.
    unsafe { dti_simd128(high, low, r, s, u, first_valid_idx, out) }
}
800
801#[derive(Debug, Clone)]
802pub struct DtiStream {
803 r: usize,
804 s: usize,
805 u: usize,
806 alpha_r: f64,
807 alpha_s: f64,
808 alpha_u: f64,
809 alpha_r_1: f64,
810 alpha_s_1: f64,
811 alpha_u_1: f64,
812 e0_r: f64,
813 e0_s: f64,
814 e0_u: f64,
815 e1_r: f64,
816 e1_s: f64,
817 e1_u: f64,
818 last_high: Option<f64>,
819 last_low: Option<f64>,
820 initialized: bool,
821}
822
823impl DtiStream {
824 pub fn try_new(params: DtiParams) -> Result<Self, DtiError> {
825 let r = params.r.unwrap_or(14);
826 let s = params.s.unwrap_or(10);
827 let u = params.u.unwrap_or(5);
828 if r == 0 || s == 0 || u == 0 {
829 return Err(DtiError::InvalidPeriod {
830 period: 0,
831 data_len: 0,
832 });
833 }
834 let alpha_r = 2.0 / (r as f64 + 1.0);
835 let alpha_s = 2.0 / (s as f64 + 1.0);
836 let alpha_u = 2.0 / (u as f64 + 1.0);
837 let alpha_r_1 = 1.0 - alpha_r;
838 let alpha_s_1 = 1.0 - alpha_s;
839 let alpha_u_1 = 1.0 - alpha_u;
840 Ok(Self {
841 r,
842 s,
843 u,
844 alpha_r,
845 alpha_s,
846 alpha_u,
847 alpha_r_1,
848 alpha_s_1,
849 alpha_u_1,
850 e0_r: 0.0,
851 e0_s: 0.0,
852 e0_u: 0.0,
853 e1_r: 0.0,
854 e1_s: 0.0,
855 e1_u: 0.0,
856 last_high: None,
857 last_low: None,
858 initialized: false,
859 })
860 }
861
862 #[inline(always)]
863 pub fn update(&mut self, high: f64, low: f64) -> Option<f64> {
864 if let (Some(last_h), Some(last_l)) = (self.last_high, self.last_low) {
865 let dh = high - last_h;
866 let dl = low - last_l;
867 let x_hmu = if dh > 0.0 { dh } else { 0.0 };
868 let x_lmd = if dl < 0.0 { -dl } else { 0.0 };
869 let x_price = x_hmu - x_lmd;
870 let x_price_abs = x_price.abs();
871
872 self.e0_r = (x_price - self.e0_r).mul_add(self.alpha_r, self.e0_r);
873 self.e0_s = (self.e0_r - self.e0_s).mul_add(self.alpha_s, self.e0_s);
874 self.e0_u = (self.e0_s - self.e0_u).mul_add(self.alpha_u, self.e0_u);
875
876 self.e1_r = (x_price_abs - self.e1_r).mul_add(self.alpha_r, self.e1_r);
877 self.e1_s = (self.e1_r - self.e1_s).mul_add(self.alpha_s, self.e1_s);
878 self.e1_u = (self.e1_s - self.e1_u).mul_add(self.alpha_u, self.e1_u);
879
880 self.last_high = Some(high);
881 self.last_low = Some(low);
882
883 if !self.e1_u.is_nan() && self.e1_u != 0.0 {
884 Some(fast_div_approx(self.e0_u * 100.0, self.e1_u))
885 } else {
886 Some(0.0)
887 }
888 } else {
889 self.last_high = Some(high);
890 self.last_low = Some(low);
891 self.initialized = true;
892 None
893 }
894 }
895
896 #[inline(always)]
897 pub fn reset(&mut self) {
898 self.e0_r = 0.0;
899 self.e0_s = 0.0;
900 self.e0_u = 0.0;
901 self.e1_r = 0.0;
902 self.e1_s = 0.0;
903 self.e1_u = 0.0;
904 self.last_high = None;
905 self.last_low = None;
906 self.initialized = false;
907 }
908
909 #[inline(always)]
910 pub fn update_delta(&mut self, dh: f64, dl: f64) -> Option<f64> {
911 if let (Some(prev_h), Some(prev_l)) = (self.last_high, self.last_low) {
912 let high = prev_h + dh;
913 let low = prev_l + dl;
914
915 let x_hmu = if dh > 0.0 { dh } else { 0.0 };
916 let x_lmd = if dl < 0.0 { -dl } else { 0.0 };
917 let x_price = x_hmu - x_lmd;
918 let x_price_abs = x_price.abs();
919
920 self.e0_r = (x_price - self.e0_r).mul_add(self.alpha_r, self.e0_r);
921 self.e0_s = (self.e0_r - self.e0_s).mul_add(self.alpha_s, self.e0_s);
922 self.e0_u = (self.e0_s - self.e0_u).mul_add(self.alpha_u, self.e0_u);
923
924 self.e1_r = (x_price_abs - self.e1_r).mul_add(self.alpha_r, self.e1_r);
925 self.e1_s = (self.e1_r - self.e1_s).mul_add(self.alpha_s, self.e1_s);
926 self.e1_u = (self.e1_s - self.e1_u).mul_add(self.alpha_u, self.e1_u);
927
928 self.last_high = Some(high);
929 self.last_low = Some(low);
930
931 if !self.e1_u.is_nan() && self.e1_u != 0.0 {
932 Some(fast_div_approx(self.e0_u * 100.0, self.e1_u))
933 } else {
934 Some(0.0)
935 }
936 } else {
937 None
938 }
939 }
940}
941
/// Approximates `num / den` with an `f32` reciprocal refined by one
/// Newton–Raphson step.
///
/// When `|den|` falls outside the normal `f32` range, or is NaN, it falls back
/// to exact `f64` division. The approximate path is not guaranteed to be
/// correctly rounded. `den` must be non-zero, which is checked only in debug
/// builds.
#[inline(always)]
fn fast_div_approx(num: f64, den: f64) -> f64 {
    debug_assert!(den != 0.0);

    let magnitude = den.abs();
    let fits_f32 = magnitude >= f32::MIN_POSITIVE as f64 && magnitude <= f32::MAX as f64;
    if !fits_f32 {
        return num / den;
    }
    let seed = f64::from(1.0f32 / den as f32);
    let refined = seed * (2.0 - den * seed);
    num * refined
}
955
// Python binding for a single DTI run over float64 `high`/`low` arrays.
// `kernel` is an optional kernel name checked by `validate_kernel`. Any DTI
// error is raised as `ValueError`. The computation runs with the GIL released.
#[cfg(feature = "python")]
#[pyfunction(name = "dti")]
#[pyo3(signature = (high, low, r, s, u, kernel=None))]
pub fn dti_py<'py>(
    py: Python<'py>,
    high: numpy::PyReadonlyArray1<'py, f64>,
    low: numpy::PyReadonlyArray1<'py, f64>,
    r: usize,
    s: usize,
    u: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let (high_slice, low_slice) = (high.as_slice()?, low.as_slice()?);
    let kern = validate_kernel(kernel, false)?;

    let input = DtiInput::from_slices(
        high_slice,
        low_slice,
        DtiParams {
            r: Some(r),
            s: Some(s),
            u: Some(u),
        },
    );

    let output = py
        .allow_threads(|| dti_with_kernel(&input, kern))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok(output.values.into_pyarray(py))
}
987
// Python `DtiStream` class: a thin wrapper around the Rust `DtiStream` for
// bar-by-bar updates.
#[cfg(feature = "python")]
#[pyclass(name = "DtiStream")]
pub struct DtiStreamPy {
    stream: DtiStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl DtiStreamPy {
    // Builds the stream from explicit periods. Invalid periods raise ValueError.
    #[new]
    fn new(r: usize, s: usize, u: usize) -> PyResult<Self> {
        let params = DtiParams {
            r: Some(r),
            s: Some(s),
            u: Some(u),
        };
        DtiStream::try_new(params)
            .map(|stream| DtiStreamPy { stream })
            .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    // Feeds one bar and returns whatever the Rust stream's `update` returns.
    fn update(&mut self, high: f64, low: f64) -> Option<f64> {
        self.stream.update(high, low)
    }
}
1013
// Python binding: evaluates DTI for every (r, s, u) combination produced by
// `expand_grid` from the three ranges (presumably `(start, end, step)`; confirm
// against `DtiBatchRange`). It returns a dict with:
// - a `values` matrix of shape (combinations, len(high)), and
// - `r`, `s`, `u` arrays with the parameters of each row.
// Errors are raised as `ValueError`.
//
// The kernel name is validated before the output buffer is allocated, and the
// buffer size is overflow-checked. Previously `rows * cols` could wrap.
#[cfg(feature = "python")]
#[pyfunction(name = "dti_batch")]
#[pyo3(signature = (high, low, r_range, s_range, u_range, kernel=None))]
pub fn dti_batch_py<'py>(
    py: Python<'py>,
    high: numpy::PyReadonlyArray1<'py, f64>,
    low: numpy::PyReadonlyArray1<'py, f64>,
    r_range: (usize, usize, usize),
    s_range: (usize, usize, usize),
    u_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let high_slice = high.as_slice()?;
    let low_slice = low.as_slice()?;
    // Reject a bad kernel name before allocating a potentially large matrix.
    let kern = validate_kernel(kernel, true)?;

    let sweep = DtiBatchRange {
        r: r_range,
        s: s_range,
        u: u_range,
    };

    let rows = expand_grid(&sweep).len();
    let cols = high_slice.len();
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("dti_batch: rows * cols overflows usize"))?;

    // The array is created uninitialized and handed to
    // `dti_batch_inner_into`, which is expected to fill it.
    // NOTE(review): confirm it writes every element on success.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            let kernel = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k => k,
            };
            let simd = match kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                _ => unreachable!(),
            };
            dti_batch_inner_into(high_slice, low_slice, &sweep, simd, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "r",
        combos
            .iter()
            .map(|p| p.r.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "s",
        combos
            .iter()
            .map(|p| p.s.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "u",
        combos
            .iter()
            .map(|p| p.u.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    Ok(dict)
}
1091
1092#[cfg(all(feature = "python", feature = "cuda"))]
1093use crate::cuda::oscillators::dti_wrapper::DeviceArrayF32Dti;
1094#[cfg(all(feature = "python", feature = "cuda"))]
1095use crate::cuda::oscillators::CudaDti;
1096#[cfg(all(feature = "python", feature = "cuda"))]
1097use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
1098
// Python-visible handle to a DTI result matrix in CUDA device memory: f32,
// `rows x cols`, row-major according to the strides reported below. It is
// exposed through `__cuda_array_interface__` and DLPack. The class is marked
// `unsendable`, so PyO3 restricts it to the thread that created it.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct DtiDeviceArrayF32Py {
    pub(crate) inner: DeviceArrayF32Dti,
}

#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl DtiDeviceArrayF32Py {
    // CUDA Array Interface (version 3) description of the buffer. It reports:
    // - shape (rows, cols);
    // - little-endian float32 ("<f4");
    // - byte strides for a row-major layout;
    // - a `(pointer, read_only=false)` data tuple.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let d = PyDict::new(py);
        d.set_item("shape", (self.inner.rows, self.inner.cols))?;
        d.set_item("typestr", "<f4")?;
        d.set_item(
            "strides",
            (
                self.inner.cols * std::mem::size_of::<f32>(),
                std::mem::size_of::<f32>(),
            ),
        )?;
        d.set_item("data", (self.inner.device_ptr() as usize, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    // DLPack device tuple `(device_type, device_id)`. 2 is kDLCUDA in DLPack's
    // DLDeviceType enum. The ordinal is asked from the driver for this
    // buffer's pointer. If that query fails, the code falls back to the
    // current context's device. If that call also fails, the result is
    // ignored; NOTE(review): presumably `inner.device_id` is then kept, but
    // confirm `cuCtxGetDevice` leaves its output untouched on error.
    fn __dlpack_device__(&self) -> PyResult<(i32, i32)> {
        let mut device_ordinal: i32 = self.inner.device_id as i32;
        unsafe {
            let attr = cust::sys::CUpointer_attribute::CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL;
            let mut value = std::mem::MaybeUninit::<i32>::uninit();
            let err = cust::sys::cuPointerGetAttribute(
                value.as_mut_ptr() as *mut std::ffi::c_void,
                attr,
                self.inner.buf.as_device_ptr().as_raw(),
            );
            if err == cust::sys::CUresult::CUDA_SUCCESS {
                // `value` is only read after the driver reports success.
                device_ordinal = value.assume_init();
            } else {
                // Best-effort fallback; the returned status is discarded.
                let _ = cust::sys::cuCtxGetDevice(&mut device_ordinal);
            }
        }
        Ok((2, device_ordinal))
    }

    // DLPack export. The steps are:
    // 1. If an explicit `dl_device` extracts as `(i32, i32)`, it must match
    //    this buffer's device; cross-device copies are not implemented.
    // 2. The device buffer is moved out of `self` into the capsule built by
    //    `export_f32_cuda_dlpack_2d`, leaving an empty 0x0 placeholder.
    // `stream` is accepted but ignored.
    // NOTE(review): because ownership moves, a second `__dlpack__` call, or a
    // later `__cuda_array_interface__`, sees the empty placeholder instead of
    // the data. Confirm callers only export once.
    // NOTE(review): a `dl_device` that does not extract as `(i32, i32)` is
    // silently ignored.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        let (kdl, alloc_dev) = self.__dlpack_device__()?;
        if let Some(dev_obj) = dl_device.as_ref() {
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(PyValueError::new_err(
                            "device copy not implemented for __dlpack__",
                        ));
                    } else {
                        return Err(PyValueError::new_err("dl_device mismatch for __dlpack__"));
                    }
                }
            }
        }
        let _ = stream;

        // Swap in an empty buffer that shares the same context, so `self`
        // stays valid after the real buffer is handed to the capsule.
        let ctx_arc = self.inner.ctx.clone();
        let dummy =
            DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let inner = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32Dti {
                buf: dummy,
                rows: 0,
                cols: 0,
                ctx: ctx_arc,
                device_id: alloc_dev as u32,
            },
        );

        let rows = inner.rows;
        let cols = inner.cols;
        let buf = inner.buf;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
1197
// Python binding: runs the DTI parameter sweep on the GPU selected by
// `device_id`, using f32 inputs. It returns the device-resident result matrix
// together with a dict of per-row `r`, `s`, `u` arrays. `ValueError` is raised
// when CUDA is unavailable or the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "dti_cuda_batch_dev")]
#[pyo3(signature = (high_f32, low_f32, r_range, s_range, u_range, device_id=0))]
pub fn dti_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    high_f32: numpy::PyReadonlyArray1<'py, f32>,
    low_f32: numpy::PyReadonlyArray1<'py, f32>,
    r_range: (usize, usize, usize),
    s_range: (usize, usize, usize),
    u_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<(DtiDeviceArrayF32Py, Bound<'py, pyo3::types::PyDict>)> {
    use crate::cuda::cuda_available;
    use numpy::IntoPyArray;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let (high, low) = (high_f32.as_slice()?, low_f32.as_slice()?);
    let sweep = DtiBatchRange {
        r: r_range,
        s: s_range,
        u: u_range,
    };

    // Device work runs with the GIL released.
    let (inner, combos) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaDti::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let produced = cuda
            .dti_batch_dev(high, low, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok(produced)
    })?;

    let r_col: Vec<u64> = combos.iter().map(|c| c.r.unwrap() as u64).collect();
    let s_col: Vec<u64> = combos.iter().map(|c| c.s.unwrap() as u64).collect();
    let u_col: Vec<u64> = combos.iter().map(|c| c.u.unwrap() as u64).collect();

    let dict = PyDict::new(py);
    for (key, column) in [("r", r_col), ("s", s_col), ("u", u_col)] {
        dict.set_item(key, column.into_pyarray(py))?;
    }
    Ok((DtiDeviceArrayF32Py { inner }, dict))
}
1236
/// Python entry point: DTI for many series sharing one `(r, s, u)` parameter set,
/// computed on a CUDA device and returned as a device-resident array.
///
/// `high_tm_f32` and `low_tm_f32` are 2-D float32 arrays of identical shape `(rows, cols)`.
/// Both must be viewable as a flat slice (`as_slice`), i.e. contiguous. Axis 0 is passed as
/// `rows` and axis 1 as `cols` to `dti_many_series_one_param_time_major_dev`. The `_tm`
/// naming suggests axis 0 is time and axis 1 is the series index; confirm against the
/// CUDA wrapper.
///
/// # Errors
/// Raises `ValueError` if CUDA is unavailable, an input is not sliceable, the shapes
/// disagree, `rows * cols` overflows, or the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "dti_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm_f32, low_tm_f32, r, s, u, device_id=0))]
pub fn dti_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    high_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    low_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    r: usize,
    s: usize,
    u: usize,
    device_id: usize,
) -> PyResult<DtiDeviceArrayF32Py> {
    use crate::cuda::cuda_available;
    use numpy::PyUntypedArrayMethods;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let high_flat = high_tm_f32.as_slice()?;
    let low_flat = low_tm_f32.as_slice()?;
    let rows = high_tm_f32.shape()[0];
    let cols = high_tm_f32.shape()[1];
    let elems = cols
        .checked_mul(rows)
        .ok_or_else(|| PyValueError::new_err("matrix size overflow"))?;
    if low_tm_f32.shape() != [rows, cols] {
        return Err(PyValueError::new_err("high/low shapes mismatch"));
    }
    // Defensive: the flat views must cover exactly rows * cols elements.
    if high_tm_f32.len() != elems || low_tm_f32.len() != elems {
        return Err(PyValueError::new_err("high/low flattened sizes mismatch"));
    }
    let params = DtiParams {
        r: Some(r),
        s: Some(s),
        u: Some(u),
    };
    // The GIL is released while the device work runs.
    let inner = py.allow_threads(|| {
        let cuda = CudaDti::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        cuda.dti_many_series_one_param_time_major_dev(high_flat, low_flat, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;
    Ok(DtiDeviceArrayF32Py { inner })
}
1279
/// Parameter sweep for batch DTI, one `(start, end, step)` triple per parameter.
///
/// An axis yields only `start` when `step == 0` or `start == end`. Otherwise it walks from
/// `start` toward `end` (ascending or descending) in increments of `step`, keeping values
/// within the inclusive bounds.
#[derive(Debug, Clone)]
pub struct DtiBatchRange {
    /// Sweep for the `r` period.
    pub r: (usize, usize, usize),
    /// Sweep for the `s` period.
    pub s: (usize, usize, usize),
    /// Sweep for the `u` period.
    pub u: (usize, usize, usize),
}

impl Default for DtiBatchRange {
    /// Sweeps `r` over `14..=263` (step 1) with `s = 10` and `u = 5` held fixed.
    fn default() -> Self {
        let r = (14, 263, 1);
        let s = (10, 10, 0);
        let u = (5, 5, 0);
        DtiBatchRange { r, s, u }
    }
}
1296
1297#[derive(Clone, Debug, Default)]
1298pub struct DtiBatchBuilder {
1299 range: DtiBatchRange,
1300 kernel: Kernel,
1301}
1302
1303impl DtiBatchBuilder {
1304 pub fn new() -> Self {
1305 Self::default()
1306 }
1307 pub fn kernel(mut self, k: Kernel) -> Self {
1308 self.kernel = k;
1309 self
1310 }
1311 #[inline]
1312 pub fn r_range(mut self, start: usize, end: usize, step: usize) -> Self {
1313 self.range.r = (start, end, step);
1314 self
1315 }
1316 #[inline]
1317 pub fn r_static(mut self, p: usize) -> Self {
1318 self.range.r = (p, p, 0);
1319 self
1320 }
1321 #[inline]
1322 pub fn s_range(mut self, start: usize, end: usize, step: usize) -> Self {
1323 self.range.s = (start, end, step);
1324 self
1325 }
1326 #[inline]
1327 pub fn s_static(mut self, x: usize) -> Self {
1328 self.range.s = (x, x, 0);
1329 self
1330 }
1331 #[inline]
1332 pub fn u_range(mut self, start: usize, end: usize, step: usize) -> Self {
1333 self.range.u = (start, end, step);
1334 self
1335 }
1336 #[inline]
1337 pub fn u_static(mut self, s: usize) -> Self {
1338 self.range.u = (s, s, 0);
1339 self
1340 }
1341 pub fn apply_slices(self, high: &[f64], low: &[f64]) -> Result<DtiBatchOutput, DtiError> {
1342 dti_batch_with_kernel(high, low, &self.range, self.kernel)
1343 }
1344 pub fn with_default_slices(
1345 high: &[f64],
1346 low: &[f64],
1347 k: Kernel,
1348 ) -> Result<DtiBatchOutput, DtiError> {
1349 DtiBatchBuilder::new().kernel(k).apply_slices(high, low)
1350 }
1351 pub fn apply_candles(self, c: &Candles) -> Result<DtiBatchOutput, DtiError> {
1352 let high = c
1353 .select_candle_field("high")
1354 .map_err(|e| DtiError::CandleFieldError(e.to_string()))?;
1355 let low = c
1356 .select_candle_field("low")
1357 .map_err(|e| DtiError::CandleFieldError(e.to_string()))?;
1358 self.apply_slices(high, low)
1359 }
1360 pub fn with_default_candles(c: &Candles) -> Result<DtiBatchOutput, DtiError> {
1361 DtiBatchBuilder::new().kernel(Kernel::Auto).apply_candles(c)
1362 }
1363}
1364
1365#[derive(Clone, Debug)]
1366pub struct DtiBatchOutput {
1367 pub values: Vec<f64>,
1368 pub combos: Vec<DtiParams>,
1369 pub rows: usize,
1370 pub cols: usize,
1371}
1372
1373impl DtiBatchOutput {
1374 pub fn row_for_params(&self, p: &DtiParams) -> Option<usize> {
1375 self.combos.iter().position(|c| {
1376 c.r.unwrap_or(14) == p.r.unwrap_or(14)
1377 && c.s.unwrap_or(10) == p.s.unwrap_or(10)
1378 && c.u.unwrap_or(5) == p.u.unwrap_or(5)
1379 })
1380 }
1381 pub fn values_for(&self, p: &DtiParams) -> Option<&[f64]> {
1382 self.row_for_params(p).map(|row| {
1383 let start = row * self.cols;
1384 &self.values[start..start + self.cols]
1385 })
1386 }
1387}
1388
1389#[inline(always)]
1390fn expand_grid(r: &DtiBatchRange) -> Vec<DtiParams> {
1391 #[inline(always)]
1392 fn axis_usize((start, end, step): (usize, usize, usize)) -> Vec<usize> {
1393 if step == 0 || start == end {
1394 return vec![start];
1395 }
1396 let mut v = Vec::new();
1397 if start < end {
1398 let mut cur = start;
1399 while cur <= end {
1400 v.push(cur);
1401 match cur.checked_add(step) {
1402 Some(nxt) => {
1403 if nxt > end {
1404 break;
1405 }
1406 cur = nxt;
1407 }
1408 None => break,
1409 }
1410 }
1411 } else {
1412 let mut cur = start;
1413 loop {
1414 if cur < end {
1415 break;
1416 }
1417 v.push(cur);
1418 if cur == end {
1419 break;
1420 }
1421 match cur.checked_sub(step) {
1422 Some(nxt) => {
1423 if nxt < end {
1424 break;
1425 }
1426 cur = nxt;
1427 }
1428 None => break,
1429 }
1430 }
1431 }
1432 v
1433 }
1434
1435 let rr = axis_usize(r.r);
1436 let ss = axis_usize(r.s);
1437 let uu = axis_usize(r.u);
1438
1439 let cap = rr
1440 .len()
1441 .checked_mul(ss.len())
1442 .and_then(|x| x.checked_mul(uu.len()))
1443 .unwrap_or(0);
1444 let mut out = Vec::with_capacity(cap);
1445 for &rv in &rr {
1446 for &sv in &ss {
1447 for &uv in &uu {
1448 out.push(DtiParams {
1449 r: Some(rv),
1450 s: Some(sv),
1451 u: Some(uv),
1452 });
1453 }
1454 }
1455 }
1456 out
1457}
1458
1459#[inline(always)]
1460pub fn dti_batch_with_kernel(
1461 high: &[f64],
1462 low: &[f64],
1463 sweep: &DtiBatchRange,
1464 k: Kernel,
1465) -> Result<DtiBatchOutput, DtiError> {
1466 let kernel = match k {
1467 Kernel::Auto => detect_best_batch_kernel(),
1468 other if other.is_batch() => other,
1469 other => return Err(DtiError::InvalidKernelForBatch(other)),
1470 };
1471 let simd = match kernel {
1472 Kernel::Avx512Batch => Kernel::Avx512,
1473 Kernel::Avx2Batch => Kernel::Avx2,
1474 Kernel::ScalarBatch => Kernel::Scalar,
1475 _ => unreachable!(),
1476 };
1477 dti_batch_par_slice(high, low, sweep, simd)
1478}
1479
1480#[inline(always)]
1481pub fn dti_batch_slice(
1482 high: &[f64],
1483 low: &[f64],
1484 sweep: &DtiBatchRange,
1485 kern: Kernel,
1486) -> Result<DtiBatchOutput, DtiError> {
1487 dti_batch_inner(high, low, sweep, kern, false)
1488}
1489
1490#[inline(always)]
1491pub fn dti_batch_par_slice(
1492 high: &[f64],
1493 low: &[f64],
1494 sweep: &DtiBatchRange,
1495 kern: Kernel,
1496) -> Result<DtiBatchOutput, DtiError> {
1497 dti_batch_inner(high, low, sweep, kern, true)
1498}
1499
1500#[inline(always)]
1501fn dti_batch_inner(
1502 high: &[f64],
1503 low: &[f64],
1504 sweep: &DtiBatchRange,
1505 kern: Kernel,
1506 parallel: bool,
1507) -> Result<DtiBatchOutput, DtiError> {
1508 let combos = expand_grid(sweep);
1509 if combos.is_empty() {
1510 return Err(DtiError::InvalidRange {
1511 start: 0,
1512 end: 0,
1513 step: 0,
1514 });
1515 }
1516 let len = high.len();
1517 if low.len() != len {
1518 return Err(DtiError::LengthMismatch {
1519 high: high.len(),
1520 low: low.len(),
1521 });
1522 }
1523 let first_valid = (0..len)
1524 .find(|&i| !high[i].is_nan() && !low[i].is_nan())
1525 .ok_or(DtiError::AllValuesNaN)?;
1526 let max_p = combos
1527 .iter()
1528 .map(|c| c.r.unwrap().max(c.s.unwrap()).max(c.u.unwrap()))
1529 .max()
1530 .unwrap();
1531 if len - first_valid < max_p {
1532 return Err(DtiError::NotEnoughValidData {
1533 needed: max_p,
1534 valid: len - first_valid,
1535 });
1536 }
1537 let rows = combos.len();
1538 let cols = len;
1539 let expected = rows.checked_mul(cols).ok_or(DtiError::InvalidRange {
1540 start: 0,
1541 end: 0,
1542 step: 0,
1543 })?;
1544
1545 let warmup_periods: Vec<usize> = combos.iter().map(|_| first_valid + 1).collect();
1546
1547 let mut buf_mu = make_uninit_matrix(rows, cols);
1548 init_matrix_prefixes(&mut buf_mu, cols, &warmup_periods);
1549
1550 let uninit_ptr = buf_mu.as_mut_ptr();
1551 let values = unsafe { std::slice::from_raw_parts_mut(uninit_ptr as *mut f64, expected) };
1552
1553 let start = first_valid + 1;
1554 let (x_base, ax_base) = dti_precompute_base(high, low, start);
1555
1556 let do_row = |row: usize, out_row: &mut [f64]| unsafe {
1557 let prm = &combos[row];
1558 #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
1559 {
1560 if matches!(kern, Kernel::Scalar | Kernel::Auto) {
1561 dti_row_simd128(
1562 high,
1563 low,
1564 prm.r.unwrap(),
1565 prm.s.unwrap(),
1566 prm.u.unwrap(),
1567 first_valid,
1568 out_row,
1569 );
1570 return;
1571 }
1572 }
1573
1574 match kern {
1575 Kernel::Scalar | Kernel::Auto => dti_row_scalar_from_base(
1576 &x_base,
1577 &ax_base,
1578 prm.r.unwrap(),
1579 prm.s.unwrap(),
1580 prm.u.unwrap(),
1581 start,
1582 out_row,
1583 ),
1584 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1585 Kernel::Avx2 => dti_row_avx2_from_base(
1586 &x_base,
1587 &ax_base,
1588 prm.r.unwrap(),
1589 prm.s.unwrap(),
1590 prm.u.unwrap(),
1591 start,
1592 out_row,
1593 ),
1594 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1595 Kernel::Avx512 => dti_row_avx512_from_base(
1596 &x_base,
1597 &ax_base,
1598 prm.r.unwrap(),
1599 prm.s.unwrap(),
1600 prm.u.unwrap(),
1601 start,
1602 out_row,
1603 ),
1604 _ => dti_row_scalar_from_base(
1605 &x_base,
1606 &ax_base,
1607 prm.r.unwrap(),
1608 prm.s.unwrap(),
1609 prm.u.unwrap(),
1610 start,
1611 out_row,
1612 ),
1613 }
1614 };
1615 if parallel {
1616 #[cfg(not(target_arch = "wasm32"))]
1617 {
1618 values
1619 .par_chunks_mut(cols)
1620 .enumerate()
1621 .for_each(|(row, slice)| do_row(row, slice));
1622 }
1623
1624 #[cfg(target_arch = "wasm32")]
1625 {
1626 for (row, slice) in values.chunks_mut(cols).enumerate() {
1627 do_row(row, slice);
1628 }
1629 }
1630 } else {
1631 for (row, slice) in values.chunks_mut(cols).enumerate() {
1632 do_row(row, slice);
1633 }
1634 }
1635
1636 let values = unsafe { Vec::from_raw_parts(uninit_ptr as *mut f64, expected, expected) };
1637 std::mem::forget(buf_mu);
1638
1639 Ok(DtiBatchOutput {
1640 values,
1641 combos,
1642 rows,
1643 cols,
1644 })
1645}
1646
/// Batch DTI that writes into a caller-provided buffer instead of allocating one.
///
/// Expands `sweep` into parameter combinations and validates `high`/`low`. It then
/// NaN-fills each row's warmup prefix (`first_valid + 1` entries) and computes one row
/// per combination into `out`. `out` is row-major with `rows = combos.len()` and
/// `cols = high.len()`.
///
/// Kernel dispatch:
/// - `Scalar`/`Auto` use `dti_row_scalar`.
/// - `Avx2`/`Avx512` use their kernels when built with `nightly-avx` on x86_64.
/// - Any other value, batch variants included, falls back to `dti_row_scalar`.
/// - On wasm32 with `simd128`, `Scalar`/`Auto` use `dti_row_simd128`.
///
/// The rows are computed directly from `high`/`low`, not from a precomputed base.
///
/// # Errors
/// - `InvalidRange` if the grid is empty or `rows * cols` overflows.
/// - `LengthMismatch` if `high` and `low` differ in length.
/// - `AllValuesNaN` if no index has both `high` and `low` non-NaN.
/// - `NotEnoughValidData` if fewer than `max(r, s, u)` samples follow the first valid index.
/// - `OutputLengthMismatch` if `out.len() != rows * cols`.
///
/// Returns the expanded parameter combinations in row order.
#[inline(always)]
pub fn dti_batch_inner_into(
    high: &[f64],
    low: &[f64],
    sweep: &DtiBatchRange,
    kern: Kernel,
    parallel: bool,
    out: &mut [f64],
) -> Result<Vec<DtiParams>, DtiError> {
    let combos = expand_grid(sweep);
    if combos.is_empty() {
        return Err(DtiError::InvalidRange {
            start: 0,
            end: 0,
            step: 0,
        });
    }
    let len = high.len();
    if low.len() != len {
        return Err(DtiError::LengthMismatch {
            high: high.len(),
            low: low.len(),
        });
    }
    // First index where both inputs are non-NaN; computation starts after it.
    let first_valid = (0..len)
        .find(|&i| !high[i].is_nan() && !low[i].is_nan())
        .ok_or(DtiError::AllValuesNaN)?;
    // Longest period across all combos determines the minimum usable data length.
    let max_p = combos
        .iter()
        .map(|c| c.r.unwrap().max(c.s.unwrap()).max(c.u.unwrap()))
        .max()
        .unwrap();
    if len - first_valid < max_p {
        return Err(DtiError::NotEnoughValidData {
            needed: max_p,
            valid: len - first_valid,
        });
    }

    let rows = combos.len();
    let cols = len;
    let expected = rows.checked_mul(cols).ok_or(DtiError::InvalidRange {
        start: 0,
        end: 0,
        step: 0,
    })?;
    if out.len() != expected {
        return Err(DtiError::OutputLengthMismatch {
            expected,
            got: out.len(),
        });
    }

    // Same warmup length for every row.
    let warm: Vec<usize> = vec![first_valid + 1; rows];
    // View `out` as MaybeUninit so it can be passed to `init_matrix_prefixes`, which takes
    // `&mut [MaybeUninit<f64>]`; the two element types share layout.
    let out_mu: &mut [std::mem::MaybeUninit<f64>] = unsafe {
        std::slice::from_raw_parts_mut(
            out.as_mut_ptr() as *mut std::mem::MaybeUninit<f64>,
            out.len(),
        )
    };
    init_matrix_prefixes(out_mu, cols, &warm);

    // Back to a plain f64 view for the row kernels.
    let values: &mut [f64] =
        unsafe { std::slice::from_raw_parts_mut(out_mu.as_mut_ptr() as *mut f64, out_mu.len()) };

    let do_row = |row: usize, row_slice: &mut [f64]| unsafe {
        let prm = &combos[row];
        // On wasm32 + simd128 the scalar request is served by the SIMD128 kernel.
        #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
        {
            if matches!(kern, Kernel::Scalar | Kernel::Auto) {
                dti_row_simd128(
                    high,
                    low,
                    prm.r.unwrap(),
                    prm.s.unwrap(),
                    prm.u.unwrap(),
                    first_valid,
                    row_slice,
                );
                return;
            }
        }
        match kern {
            Kernel::Scalar | Kernel::Auto => dti_row_scalar(
                high,
                low,
                prm.r.unwrap(),
                prm.s.unwrap(),
                prm.u.unwrap(),
                first_valid,
                row_slice,
            ),
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx2 => dti_row_avx2(
                high,
                low,
                prm.r.unwrap(),
                prm.s.unwrap(),
                prm.u.unwrap(),
                first_valid,
                row_slice,
            ),
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx512 => dti_row_avx512(
                high,
                low,
                prm.r.unwrap(),
                prm.s.unwrap(),
                prm.u.unwrap(),
                first_valid,
                row_slice,
            ),
            // Everything else (including batch variants) falls back to scalar.
            _ => dti_row_scalar(
                high,
                low,
                prm.r.unwrap(),
                prm.s.unwrap(),
                prm.u.unwrap(),
                first_valid,
                row_slice,
            ),
        }
    };

    if parallel {
        // rayon is unavailable on wasm32, so the "parallel" request runs sequentially there.
        #[cfg(not(target_arch = "wasm32"))]
        values
            .par_chunks_mut(cols)
            .enumerate()
            .for_each(|(row, sl)| do_row(row, sl));
        #[cfg(target_arch = "wasm32")]
        for (row, sl) in values.chunks_mut(cols).enumerate() {
            do_row(row, sl);
        }
    } else {
        for (row, sl) in values.chunks_mut(cols).enumerate() {
            do_row(row, sl);
        }
    }

    Ok(combos)
}
1789
1790#[inline(always)]
1791pub fn expand_grid_dti(r: &DtiBatchRange) -> Vec<DtiParams> {
1792 expand_grid(r)
1793}
1794
/// wasm-bindgen export: single-parameter DTI over `high`/`low` with `Kernel::Auto`.
///
/// Returns the DTI values, or a JS string error carrying the `DtiError` message.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dti_js(
    high: &[f64],
    low: &[f64],
    r: usize,
    s: usize,
    u: usize,
) -> Result<Vec<f64>, JsValue> {
    let input = DtiInput {
        data: DtiData::Slices { high, low },
        params: DtiParams {
            r: Some(r),
            s: Some(s),
            u: Some(u),
        },
    };
    dti_with_kernel(&input, Kernel::Auto)
        .map(|output| output.values)
        .map_err(|e| JsValue::from_str(&e.to_string()))
}
1817
/// wasm-bindgen export: single-parameter DTI from raw pointers into `out_ptr`.
///
/// All three pointers must address `len` `f64`s. If `out_ptr` equals `high_ptr` or
/// `low_ptr`, the result is computed into a scratch buffer first and then copied.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dti_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    r: usize,
    s: usize,
    u: usize,
) -> Result<(), JsValue> {
    if high_ptr.is_null() || low_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // SAFETY: both pointers are non-null; the JS caller guarantees `len` readable f64s each.
    let (high, low) = unsafe {
        (
            std::slice::from_raw_parts(high_ptr, len),
            std::slice::from_raw_parts(low_ptr, len),
        )
    };
    let input = DtiInput {
        data: DtiData::Slices { high, low },
        params: DtiParams {
            r: Some(r),
            s: Some(s),
            u: Some(u),
        },
    };

    let aliased = out_ptr as *const f64 == high_ptr || out_ptr as *const f64 == low_ptr;
    if aliased {
        let mut scratch = vec![0.0; len];
        dti_into_slice(&mut scratch, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // SAFETY: `out_ptr` is non-null and the caller guarantees `len` writable f64s.
        unsafe { std::slice::from_raw_parts_mut(out_ptr, len) }.copy_from_slice(&scratch);
    } else {
        // SAFETY: as above; `out_ptr` does not start at either input.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        dti_into_slice(out, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }
    Ok(())
}
1858
/// wasm-bindgen export: allocates an uninitialised buffer with capacity for `len` f64s.
///
/// Release it with `dti_free(ptr, len)`, passing the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dti_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop keeps the allocation alive after this function returns.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1867
/// wasm-bindgen export: frees a buffer obtained from `dti_alloc`.
///
/// `len` must be the value originally passed to `dti_alloc`. A null `ptr` is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dti_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` came from `dti_alloc(len)`, i.e. `Vec::<f64>::with_capacity(len)`.
    drop(unsafe { Vec::from_raw_parts(ptr, len, len) });
}
1877
/// JS-side configuration for `dti_batch`, deserialised from a `JsValue`.
///
/// Each field is a `(start, end, step)` sweep, copied into a `DtiBatchRange`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct DtiBatchConfig {
    /// Sweep for `r`.
    pub r_range: (usize, usize, usize),
    /// Sweep for `s`.
    pub s_range: (usize, usize, usize),
    /// Sweep for `u`.
    pub u_range: (usize, usize, usize),
}
1885
/// Serializable mirror of `DtiBatchOutput` returned to JS by `dti_batch`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct DtiBatchJsOutput {
    /// Row-major `rows x cols` DTI values.
    pub values: Vec<f64>,
    /// Parameter combination for each row, in row order.
    pub combos: Vec<DtiParams>,
    /// Number of rows (parameter combinations).
    pub rows: usize,
    /// Number of columns (input length).
    pub cols: usize,
}
1894
/// wasm-bindgen export (`dti_batch`): runs a DTI parameter sweep described by `config`.
///
/// Rows are computed sequentially using the kernel from `detect_best_kernel()`. The result
/// is returned as a serialised `DtiBatchJsOutput`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = dti_batch)]
pub fn dti_batch_js(high: &[f64], low: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let cfg: DtiBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = DtiBatchRange {
        r: cfg.r_range,
        s: cfg.s_range,
        u: cfg.u_range,
    };

    let DtiBatchOutput {
        values,
        combos,
        rows,
        cols,
    } = dti_batch_inner(high, low, &sweep, detect_best_kernel(), false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let payload = DtiBatchJsOutput {
        values,
        combos,
        rows,
        cols,
    };
    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
1919
/// wasm-bindgen export: batch DTI from raw pointers into `out_ptr`.
///
/// `high_ptr` and `low_ptr` must address `len` f64s. `out_ptr` must address
/// `rows * len` f64s, where `rows` is the number of `(r, s, u)` combinations the three
/// sweeps expand to. If the output region overlaps either input, the result is computed
/// into a scratch buffer and copied afterwards. The mutable output slice is only created
/// once the inputs are no longer being read, or when the regions are disjoint.
///
/// Returns the number of rows written.
///
/// # Errors
/// JS string errors for null pointers, row/size overflow, or any `DtiError`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dti_batch_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    r_start: usize,
    r_end: usize,
    r_step: usize,
    s_start: usize,
    s_end: usize,
    s_step: usize,
    u_start: usize,
    u_end: usize,
    u_step: usize,
) -> Result<usize, JsValue> {
    // Number of values `expand_grid` produces for one `(start, end, step)` axis.
    fn axis_count(start: usize, end: usize, step: usize) -> usize {
        if step == 0 || start == end {
            return 1;
        }
        if start < end {
            ((end - start) / step) + 1
        } else {
            ((start - end) / step) + 1
        }
    }

    // Whether the f64 regions `[a, a + a_len)` and `[b, b + b_len)` share memory
    // (identical start pointers count as overlapping).
    fn overlaps(a: *const f64, a_len: usize, b: *const f64, b_len: usize) -> bool {
        let elem = std::mem::size_of::<f64>();
        let a_lo = a as usize;
        let b_lo = b as usize;
        let a_hi = a_lo.saturating_add(a_len.saturating_mul(elem));
        let b_hi = b_lo.saturating_add(b_len.saturating_mul(elem));
        a_lo == b_lo || (a_lo < b_hi && b_lo < a_hi)
    }

    if high_ptr.is_null() || low_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided to dti_batch_into"));
    }

    let sweep = DtiBatchRange {
        r: (r_start, r_end, r_step),
        s: (s_start, s_end, s_step),
        u: (u_start, u_end, u_step),
    };

    let r_count = axis_count(r_start, r_end, r_step);
    let s_count = axis_count(s_start, s_end, s_step);
    let u_count = axis_count(u_start, u_end, u_step);
    let total_rows = r_count
        .checked_mul(s_count)
        .and_then(|x| x.checked_mul(u_count))
        .ok_or(JsValue::from_str(
            "range expansion overflow in dti_batch_into",
        ))?;
    let total_len = total_rows
        .checked_mul(len)
        .ok_or(JsValue::from_str("size overflow in dti_batch_into"))?;

    // SAFETY: both pointers are non-null; the JS caller guarantees `len` readable f64s each.
    let (high, low) = unsafe {
        (
            std::slice::from_raw_parts(high_ptr, len),
            std::slice::from_raw_parts(low_ptr, len),
        )
    };

    let out_const = out_ptr as *const f64;
    let aliased =
        overlaps(out_const, total_len, high_ptr, len) || overlaps(out_const, total_len, low_ptr, len);

    if aliased {
        let mut scratch = vec![0.0; total_len];
        dti_batch_inner_into(high, low, &sweep, Kernel::Auto, false, &mut scratch)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // SAFETY: the inputs are no longer read, and the caller guarantees `total_len`
        // writable f64s at the non-null `out_ptr`.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total_len) };
        out.copy_from_slice(&scratch);
    } else {
        // SAFETY: the output region is disjoint from both inputs (checked above), so this
        // `&mut` slice does not alias `high`/`low`. The caller guarantees `total_len`
        // writable f64s.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total_len) };
        dti_batch_inner_into(high, low, &sweep, Kernel::Auto, false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(total_rows)
}
1990
1991#[cfg(test)]
1992mod tests {
1993 use super::*;
1994 use crate::skip_if_unsupported;
1995 use crate::utilities::data_loader::read_candles_from_csv;
1996 fn check_dti_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1997 skip_if_unsupported!(kernel, test_name);
1998 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1999 let candles = read_candles_from_csv(file_path)?;
2000 let default_params = DtiParams {
2001 r: None,
2002 s: None,
2003 u: None,
2004 };
2005 let input = DtiInput::from_candles(&candles, default_params);
2006 let output = dti_with_kernel(&input, kernel)?;
2007 assert_eq!(output.values.len(), candles.close.len());
2008 Ok(())
2009 }
2010 fn check_dti_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2011 skip_if_unsupported!(kernel, test_name);
2012 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2013 let candles = read_candles_from_csv(file_path)?;
2014 let input = DtiInput::with_default_candles(&candles);
2015 let result = dti_with_kernel(&input, kernel)?;
2016 let expected_last_five = [
2017 -39.0091620347991,
2018 -39.75219264093014,
2019 -40.53941417932286,
2020 -41.2787749205189,
2021 -42.93758699380749,
2022 ];
2023 let start = result.values.len().saturating_sub(5);
2024 for (i, &val) in result.values[start..].iter().enumerate() {
2025 let diff = (val - expected_last_five[i]).abs();
2026 assert!(
2027 diff < 1e-6,
2028 "[{}] DTI {:?} mismatch at idx {}: got {}, expected {}",
2029 test_name,
2030 kernel,
2031 i,
2032 val,
2033 expected_last_five[i]
2034 );
2035 }
2036 Ok(())
2037 }
2038 fn check_dti_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2039 skip_if_unsupported!(kernel, test_name);
2040 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2041 let candles = read_candles_from_csv(file_path)?;
2042 let input = DtiInput::with_default_candles(&candles);
2043 let output = dti_with_kernel(&input, kernel)?;
2044 assert_eq!(output.values.len(), candles.close.len());
2045 Ok(())
2046 }
2047 fn check_dti_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2048 skip_if_unsupported!(kernel, test_name);
2049 let high = [10.0, 11.0, 12.0];
2050 let low = [9.0, 10.0, 11.0];
2051 let params = DtiParams {
2052 r: Some(0),
2053 s: Some(10),
2054 u: Some(5),
2055 };
2056 let input = DtiInput::from_slices(&high, &low, params);
2057 let res = dti_with_kernel(&input, kernel);
2058 assert!(
2059 res.is_err(),
2060 "[{}] DTI should fail with zero period",
2061 test_name
2062 );
2063 Ok(())
2064 }
2065 fn check_dti_period_exceeds_length(
2066 test_name: &str,
2067 kernel: Kernel,
2068 ) -> Result<(), Box<dyn Error>> {
2069 skip_if_unsupported!(kernel, test_name);
2070 let high = [10.0, 11.0];
2071 let low = [9.0, 10.0];
2072 let params = DtiParams {
2073 r: Some(14),
2074 s: Some(10),
2075 u: Some(5),
2076 };
2077 let input = DtiInput::from_slices(&high, &low, params);
2078 let res = dti_with_kernel(&input, kernel);
2079 assert!(
2080 res.is_err(),
2081 "[{}] DTI should fail with period exceeding length",
2082 test_name
2083 );
2084 Ok(())
2085 }
2086 fn check_dti_all_nan(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2087 skip_if_unsupported!(kernel, test_name);
2088 let high = [f64::NAN, f64::NAN, f64::NAN];
2089 let low = [f64::NAN, f64::NAN, f64::NAN];
2090 let params = DtiParams::default();
2091 let input = DtiInput::from_slices(&high, &low, params);
2092 let res = dti_with_kernel(&input, kernel);
2093 assert!(res.is_err(), "[{}] DTI should fail with all NaN", test_name);
2094 Ok(())
2095 }
2096 fn check_dti_empty_data(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2097 skip_if_unsupported!(kernel, test_name);
2098 let high: [f64; 0] = [];
2099 let low: [f64; 0] = [];
2100 let params = DtiParams::default();
2101 let input = DtiInput::from_slices(&high, &low, params);
2102 let res = dti_with_kernel(&input, kernel);
2103 assert!(
2104 res.is_err(),
2105 "[{}] DTI should fail with empty data",
2106 test_name
2107 );
2108 Ok(())
2109 }
2110 fn check_dti_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2111 skip_if_unsupported!(kernel, test_name);
2112 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2113 let candles = read_candles_from_csv(file_path)?;
2114 let high = candles.select_candle_field("high")?;
2115 let low = candles.select_candle_field("low")?;
2116 let params = DtiParams::default();
2117 let input = DtiInput::from_slices(high, low, params.clone());
2118 let batch_output = dti_with_kernel(&input, kernel)?.values;
2119 let mut stream = DtiStream::try_new(params)?;
2120 let mut stream_values = Vec::with_capacity(high.len());
2121 for (&h, &l) in high.iter().zip(low.iter()) {
2122 match stream.update(h, l) {
2123 Some(val) => stream_values.push(val),
2124 None => stream_values.push(f64::NAN),
2125 }
2126 }
2127 assert_eq!(batch_output.len(), stream_values.len());
2128 for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
2129 if b.is_nan() && s.is_nan() {
2130 continue;
2131 }
2132 let diff = (b - s).abs();
2133 assert!(
2134 diff < 1e-9,
2135 "[{}] DTI streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
2136 test_name,
2137 i,
2138 b,
2139 s,
2140 diff
2141 );
2142 }
2143 Ok(())
2144 }
2145
2146 fn check_dti_length_mismatch(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2147 skip_if_unsupported!(kernel, test_name);
2148
2149 let high = vec![10.0, 11.0, 12.0];
2150 let low = vec![9.0, 10.0];
2151 let params = DtiParams::default();
2152 let input = DtiInput::from_slices(&high, &low, params);
2153
2154 let result = dti_with_kernel(&input, kernel);
2155 assert!(
2156 result.is_err(),
2157 "[{}] Should fail with mismatched lengths",
2158 test_name
2159 );
2160
2161 match result.unwrap_err() {
2162 DtiError::LengthMismatch { high: h, low: l } => {
2163 assert_eq!(h, 3, "[{}] High length should be 3", test_name);
2164 assert_eq!(l, 2, "[{}] Low length should be 2", test_name);
2165 }
2166 e => panic!(
2167 "[{}] Expected LengthMismatch error, got: {:?}",
2168 test_name, e
2169 ),
2170 }
2171
2172 let mut out = vec![0.0; high.len()];
2173 let result = dti_into_slice(&mut out, &input, kernel);
2174 assert!(
2175 result.is_err(),
2176 "[{}] dti_into_slice should fail with mismatched lengths",
2177 test_name
2178 );
2179
2180 match result.unwrap_err() {
2181 DtiError::LengthMismatch { high: h, low: l } => {
2182 assert_eq!(h, 3, "[{}] High length should be 3", test_name);
2183 assert_eq!(l, 2, "[{}] Low length should be 2", test_name);
2184 }
2185 e => panic!(
2186 "[{}] Expected LengthMismatch error from dti_into_slice, got: {:?}",
2187 test_name, e
2188 ),
2189 }
2190
2191 let sweep = DtiBatchRange::default();
2192 let result = dti_batch_inner(&high, &low, &sweep, kernel, false);
2193 assert!(
2194 result.is_err(),
2195 "[{}] Batch should fail with mismatched lengths",
2196 test_name
2197 );
2198
2199 match result.unwrap_err() {
2200 DtiError::LengthMismatch { high: h, low: l } => {
2201 assert_eq!(h, 3, "[{}] Batch high length should be 3", test_name);
2202 assert_eq!(l, 2, "[{}] Batch low length should be 2", test_name);
2203 }
2204 e => panic!(
2205 "[{}] Expected LengthMismatch error from batch, got: {:?}",
2206 test_name, e
2207 ),
2208 }
2209
2210 Ok(())
2211 }
2212
2213 #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
2214 fn check_dti_wasm_kernel_fallback(
2215 test_name: &str,
2216 _kernel: Kernel,
2217 ) -> Result<(), Box<dyn Error>> {
2218 let high = vec![10.0, 11.0, 12.0, 13.0, 14.0];
2219 let low = vec![9.0, 10.0, 11.0, 12.0, 13.0];
2220 let params = DtiParams::default();
2221 let input = DtiInput::from_slices(&high, &low, params);
2222
2223 let scalar_result = dti_with_kernel(&input, Kernel::Scalar)?;
2224
2225 let avx2_result = dti_with_kernel(&input, Kernel::Avx2)?;
2226 assert_eq!(
2227 scalar_result.values.len(),
2228 avx2_result.values.len(),
2229 "[{}] AVX2 fallback length mismatch",
2230 test_name
2231 );
2232 for (i, (s, a)) in scalar_result
2233 .values
2234 .iter()
2235 .zip(avx2_result.values.iter())
2236 .enumerate()
2237 {
2238 if s.is_nan() && a.is_nan() {
2239 continue;
2240 }
2241 assert!(
2242 (s - a).abs() < 1e-10,
2243 "[{}] AVX2 fallback mismatch at {}: scalar={}, avx2={}",
2244 test_name,
2245 i,
2246 s,
2247 a
2248 );
2249 }
2250
2251 let avx512_result = dti_with_kernel(&input, Kernel::Avx512)?;
2252 assert_eq!(
2253 scalar_result.values.len(),
2254 avx512_result.values.len(),
2255 "[{}] AVX512 fallback length mismatch",
2256 test_name
2257 );
2258 for (i, (s, a)) in scalar_result
2259 .values
2260 .iter()
2261 .zip(avx512_result.values.iter())
2262 .enumerate()
2263 {
2264 if s.is_nan() && a.is_nan() {
2265 continue;
2266 }
2267 assert!(
2268 (s - a).abs() < 1e-10,
2269 "[{}] AVX512 fallback mismatch at {}: scalar={}, avx512={}",
2270 test_name,
2271 i,
2272 s,
2273 a
2274 );
2275 }
2276
2277 let mut out_scalar = vec![0.0; high.len()];
2278 let mut out_avx = vec![0.0; high.len()];
2279
2280 dti_into_slice(&mut out_scalar, &input, Kernel::Scalar)?;
2281 dti_into_slice(&mut out_avx, &input, Kernel::Avx2)?;
2282
2283 for (i, (s, a)) in out_scalar.iter().zip(out_avx.iter()).enumerate() {
2284 if s.is_nan() && a.is_nan() {
2285 continue;
2286 }
2287 assert!(
2288 (s - a).abs() < 1e-10,
2289 "[{}] dti_into_slice AVX2 fallback mismatch at {}",
2290 test_name,
2291 i
2292 );
2293 }
2294
2295 Ok(())
2296 }
2297
2298 #[cfg(not(all(target_arch = "wasm32", target_feature = "simd128")))]
2299 fn check_dti_wasm_kernel_fallback(
2300 _test_name: &str,
2301 _kernel: Kernel,
2302 ) -> Result<(), Box<dyn Error>> {
2303 Ok(())
2304 }
2305
2306 #[cfg(debug_assertions)]
2307 fn check_dti_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2308 skip_if_unsupported!(kernel, test_name);
2309
2310 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2311 let candles = read_candles_from_csv(file_path)?;
2312
2313 let test_params = vec![
2314 DtiParams::default(),
2315 DtiParams {
2316 r: Some(1),
2317 s: Some(1),
2318 u: Some(1),
2319 },
2320 DtiParams {
2321 r: Some(5),
2322 s: Some(5),
2323 u: Some(5),
2324 },
2325 DtiParams {
2326 r: Some(20),
2327 s: Some(15),
2328 u: Some(10),
2329 },
2330 DtiParams {
2331 r: Some(50),
2332 s: Some(30),
2333 u: Some(20),
2334 },
2335 DtiParams {
2336 r: Some(100),
2337 s: Some(50),
2338 u: Some(25),
2339 },
2340 DtiParams {
2341 r: Some(14),
2342 s: Some(5),
2343 u: Some(20),
2344 },
2345 DtiParams {
2346 r: Some(30),
2347 s: Some(10),
2348 u: Some(5),
2349 },
2350 DtiParams {
2351 r: Some(10),
2352 s: Some(20),
2353 u: Some(15),
2354 },
2355 DtiParams {
2356 r: Some(2),
2357 s: Some(10),
2358 u: Some(5),
2359 },
2360 ];
2361
2362 for (param_idx, params) in test_params.iter().enumerate() {
2363 let input = DtiInput::from_candles(&candles, params.clone());
2364 let output = dti_with_kernel(&input, kernel)?;
2365
2366 for (i, &val) in output.values.iter().enumerate() {
2367 if val.is_nan() {
2368 continue;
2369 }
2370
2371 let bits = val.to_bits();
2372
2373 if bits == 0x11111111_11111111 {
2374 panic!(
2375 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
2376 with params: r={}, s={}, u={} (param set {})",
2377 test_name,
2378 val,
2379 bits,
2380 i,
2381 params.r.unwrap_or(14),
2382 params.s.unwrap_or(10),
2383 params.u.unwrap_or(5),
2384 param_idx
2385 );
2386 }
2387
2388 if bits == 0x22222222_22222222 {
2389 panic!(
2390 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
2391 with params: r={}, s={}, u={} (param set {})",
2392 test_name,
2393 val,
2394 bits,
2395 i,
2396 params.r.unwrap_or(14),
2397 params.s.unwrap_or(10),
2398 params.u.unwrap_or(5),
2399 param_idx
2400 );
2401 }
2402
2403 if bits == 0x33333333_33333333 {
2404 panic!(
2405 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
2406 with params: r={}, s={}, u={} (param set {})",
2407 test_name,
2408 val,
2409 bits,
2410 i,
2411 params.r.unwrap_or(14),
2412 params.s.unwrap_or(10),
2413 params.u.unwrap_or(5),
2414 param_idx
2415 );
2416 }
2417 }
2418 }
2419
2420 Ok(())
2421 }
2422
2423 #[cfg(not(debug_assertions))]
2424 fn check_dti_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2425 Ok(())
2426 }
2427
2428 #[cfg(feature = "proptest")]
2429 fn check_dti_property(
2430 test_name: &str,
2431 kernel: Kernel,
2432 ) -> Result<(), Box<dyn std::error::Error>> {
2433 use proptest::prelude::*;
2434 skip_if_unsupported!(kernel, test_name);
2435
2436 let strat = (2usize..=50)
2437 .prop_flat_map(|max_period| {
2438 let min_len = max_period * 3;
2439 (
2440 (100.0f64..5000.0f64, 0.01f64..0.1f64),
2441 (
2442 1usize..=max_period,
2443 1usize..=max_period,
2444 1usize..=max_period,
2445 ),
2446 min_len..400,
2447 )
2448 })
2449 .prop_flat_map(|((base_price, volatility), (r, s, u), len)| {
2450 let price_changes = prop::collection::vec((-1.0f64..1.0f64), len);
2451
2452 (
2453 Just((base_price, volatility)),
2454 Just((r, s, u)),
2455 price_changes,
2456 )
2457 })
2458 .prop_map(|((base_price, volatility), (r, s, u), changes)| {
2459 let len = changes.len();
2460 let mut high = Vec::with_capacity(len);
2461 let mut low = Vec::with_capacity(len);
2462 let mut current_price = base_price;
2463
2464 for change_factor in changes {
2465 let change = change_factor * volatility * current_price;
2466 current_price = (current_price + change).max(10.0);
2467
2468 let daily_range = current_price * volatility * (0.5 + change_factor.abs());
2469 let mid_adjustment = change_factor * daily_range * 0.25;
2470
2471 let daily_high = current_price + daily_range / 2.0 + mid_adjustment;
2472 let daily_low = current_price - daily_range / 2.0 + mid_adjustment;
2473
2474 high.push(daily_high);
2475 low.push(daily_low.max(1.0));
2476 }
2477
2478 (high, low, r, s, u)
2479 });
2480
2481 proptest::test_runner::TestRunner::default()
2482 .run(&strat, |(high, low, r, s, u)| {
2483 let params = DtiParams {
2484 r: Some(r),
2485 s: Some(s),
2486 u: Some(u),
2487 };
2488 let input = DtiInput::from_slices(&high, &low, params);
2489
2490 let result = dti_with_kernel(&input, kernel);
2491 prop_assert!(result.is_ok(), "DTI computation failed: {:?}", result);
2492 let DtiOutput { values: out } = result.unwrap();
2493
2494 let DtiOutput { values: ref_out } =
2495 dti_with_kernel(&input, Kernel::Scalar).unwrap();
2496
2497 prop_assert_eq!(out.len(), high.len(), "Output length mismatch");
2498
2499 prop_assert!(out[0].is_nan(), "First value should be NaN");
2500
2501 let finite_values: Vec<f64> =
2502 out.iter().copied().filter(|v| v.is_finite()).collect();
2503 if !finite_values.is_empty() {
2504 let max_val = finite_values
2505 .iter()
2506 .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
2507 let min_val = finite_values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
2508
2509 prop_assert!(
2510 max_val <= 100.0001 && min_val >= -100.0001,
2511 "DTI values exceed mathematical bounds: [{:.6}, {:.6}]",
2512 min_val,
2513 max_val
2514 );
2515 }
2516
2517 for i in 0..out.len() {
2518 let y = out[i];
2519 let r = ref_out[i];
2520
2521 if !y.is_finite() || !r.is_finite() {
2522 prop_assert_eq!(
2523 y.to_bits(),
2524 r.to_bits(),
2525 "NaN/finite mismatch at idx {}: {} vs {}",
2526 i,
2527 y,
2528 r
2529 );
2530 } else {
2531 let diff = (y - r).abs();
2532 let ulp_diff = y.to_bits().abs_diff(r.to_bits());
2533 prop_assert!(
2534 diff <= 1e-9 || ulp_diff <= 10,
2535 "Kernel mismatch at idx {}: {} vs {} (diff={}, ulp={})",
2536 i,
2537 y,
2538 r,
2539 diff,
2540 ulp_diff
2541 );
2542 }
2543 }
2544
2545 let is_strong_uptrend = high
2546 .windows(5)
2547 .all(|w| w.windows(2).all(|pair| pair[1] > pair[0] * 1.001))
2548 && low
2549 .windows(5)
2550 .all(|w| w.windows(2).all(|pair| pair[1] > pair[0] * 1.001));
2551
2552 let is_strong_downtrend = high
2553 .windows(5)
2554 .all(|w| w.windows(2).all(|pair| pair[1] < pair[0] * 0.999))
2555 && low
2556 .windows(5)
2557 .all(|w| w.windows(2).all(|pair| pair[1] < pair[0] * 0.999));
2558
2559 if (is_strong_uptrend || is_strong_downtrend) && out.len() > r + s + u + 10 {
2560 let later_values: Vec<f64> = out[out.len() - 10..]
2561 .iter()
2562 .copied()
2563 .filter(|v| v.is_finite())
2564 .collect();
2565 if later_values.len() >= 5 {
2566 let avg = later_values.iter().sum::<f64>() / later_values.len() as f64;
2567 if is_strong_uptrend {
2568 prop_assert!(
2569 avg > 0.0,
2570 "DTI should be positive in strong uptrend: avg={:.2}",
2571 avg
2572 );
2573 }
2574 if is_strong_downtrend {
2575 prop_assert!(
2576 avg < 0.0,
2577 "DTI should be negative in strong downtrend: avg={:.2}",
2578 avg
2579 );
2580 }
2581 }
2582 }
2583
2584 #[cfg(debug_assertions)]
2585 for (i, &val) in out.iter().enumerate() {
2586 if val.is_nan() {
2587 continue;
2588 }
2589 let bits = val.to_bits();
2590 prop_assert!(
2591 bits != 0x11111111_11111111
2592 && bits != 0x22222222_22222222
2593 && bits != 0x33333333_33333333,
2594 "Found poison value at index {}: {} (0x{:016X})",
2595 i,
2596 val,
2597 bits
2598 );
2599 }
2600
2601 if r == 1 && s == 1 && u == 1 && out.len() > 10 {
2602 let responsive_values: Vec<f64> = out[2..10]
2603 .iter()
2604 .copied()
2605 .filter(|v| v.is_finite() && v.abs() > 0.0)
2606 .collect();
2607 prop_assert!(
2608 !responsive_values.is_empty(),
2609 "DTI with period=1 should produce non-zero values quickly"
2610 );
2611 }
2612
2613 let is_zero_volatility = high
2614 .iter()
2615 .zip(low.iter())
2616 .all(|(h, l)| (h - l).abs() < 1e-10);
2617 if is_zero_volatility && out.len() > 2 {
2618 for i in 2..out.len() {
2619 if out[i].is_finite() {
2620 prop_assert!(
2621 out[i].abs() < 1e-10,
2622 "DTI should be 0 with zero volatility at index {}: {}",
2623 i,
2624 out[i]
2625 );
2626 }
2627 }
2628 }
2629
2630 if high.len() >= 10 {
2631 let spreads: Vec<f64> =
2632 high.iter().zip(low.iter()).map(|(h, l)| h - l).collect();
2633 let first_spread = spreads[0];
2634 let is_constant_spread = spreads
2635 .iter()
2636 .all(|&s| (s - first_spread).abs() < first_spread * 0.01);
2637
2638 let high_changes: Vec<f64> =
2639 high.windows(2).map(|w| (w[1] - w[0]).abs()).collect();
2640 let low_changes: Vec<f64> =
2641 low.windows(2).map(|w| (w[1] - w[0]).abs()).collect();
2642 let avg_high_change =
2643 high_changes.iter().sum::<f64>() / high_changes.len() as f64;
2644 let avg_low_change = low_changes.iter().sum::<f64>() / low_changes.len() as f64;
2645 let is_stable =
2646 avg_high_change < high[0] * 0.001 && avg_low_change < low[0] * 0.001;
2647
2648 if is_constant_spread && is_stable && out.len() > r + s + u + 5 {
2649 let last_values: Vec<f64> = out[out.len() - 5..]
2650 .iter()
2651 .copied()
2652 .filter(|v| v.is_finite())
2653 .collect();
2654 if last_values.len() >= 3 {
2655 let avg_abs = last_values.iter().map(|v| v.abs()).sum::<f64>()
2656 / last_values.len() as f64;
2657 prop_assert!(
2658 avg_abs < 10.0,
2659 "DTI should converge near 0 with constant spread and stable prices: avg_abs={:.2}",
2660 avg_abs
2661 );
2662 }
2663 }
2664 }
2665
2666 if high.len() >= 10 {
2667 let high_rising = high
2668 .windows(5)
2669 .all(|w| w.windows(2).all(|pair| pair[1] >= pair[0]));
2670 let low_rising = low
2671 .windows(5)
2672 .all(|w| w.windows(2).all(|pair| pair[1] >= pair[0]));
2673 let high_falling = high
2674 .windows(5)
2675 .all(|w| w.windows(2).all(|pair| pair[1] <= pair[0]));
2676 let low_falling = low
2677 .windows(5)
2678 .all(|w| w.windows(2).all(|pair| pair[1] <= pair[0]));
2679
2680 if (high_rising && low_rising) && out.len() > 10 {
2681 let mid_to_end: Vec<f64> = out[out.len() / 2..]
2682 .iter()
2683 .copied()
2684 .filter(|v| v.is_finite())
2685 .collect();
2686 if mid_to_end.len() >= 3 {
2687 let positive_count = mid_to_end.iter().filter(|&&v| v > 0.0).count();
2688 prop_assert!(
2689 positive_count >= mid_to_end.len() / 2,
2690 "DTI should be mostly positive when prices consistently rise"
2691 );
2692 }
2693 }
2694
2695 if (high_falling && low_falling) && out.len() > 10 {
2696 let mid_to_end: Vec<f64> = out[out.len() / 2..]
2697 .iter()
2698 .copied()
2699 .filter(|v| v.is_finite())
2700 .collect();
2701 if mid_to_end.len() >= 3 {
2702 let negative_count = mid_to_end.iter().filter(|&&v| v < 0.0).count();
2703 prop_assert!(
2704 negative_count >= mid_to_end.len() / 2,
2705 "DTI should be mostly negative when prices consistently fall"
2706 );
2707 }
2708 }
2709 }
2710
2711 Ok(())
2712 })
2713 .unwrap();
2714
2715 Ok(())
2716 }
2717
2718 #[cfg(not(feature = "proptest"))]
2719 fn check_dti_property(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2720 skip_if_unsupported!(kernel, test_name);
2721 Ok(())
2722 }
2723
2724 macro_rules! generate_all_dti_tests {
2725 ($($test_fn:ident),*) => {
2726 paste::paste! {
2727 $(
2728 #[test]
2729 fn [<$test_fn _scalar_f64>]() {
2730 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
2731 }
2732 )*
2733 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2734 $(
2735 #[test]
2736 fn [<$test_fn _avx2_f64>]() {
2737 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
2738 }
2739 #[test]
2740 fn [<$test_fn _avx512_f64>]() {
2741 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
2742 }
2743 )*
2744 #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
2745 $(
2746 #[test]
2747 fn [<$test_fn _simd128_f64>]() {
2748 let _ = $test_fn(stringify!([<$test_fn _simd128_f64>]), Kernel::Scalar);
2749 }
2750 )*
2751 }
2752 }
2753 }
2754 generate_all_dti_tests!(
2755 check_dti_partial_params,
2756 check_dti_accuracy,
2757 check_dti_default_candles,
2758 check_dti_zero_period,
2759 check_dti_period_exceeds_length,
2760 check_dti_all_nan,
2761 check_dti_empty_data,
2762 check_dti_streaming,
2763 check_dti_length_mismatch,
2764 check_dti_wasm_kernel_fallback,
2765 check_dti_no_poison,
2766 check_dti_property
2767 );
2768 fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2769 skip_if_unsupported!(kernel, test);
2770 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2771 let c = read_candles_from_csv(file)?;
2772 let output = DtiBatchBuilder::new().kernel(kernel).apply_candles(&c)?;
2773 let def = DtiParams::default();
2774 let row = output.values_for(&def).expect("default row missing");
2775 assert_eq!(row.len(), c.close.len());
2776 let expected = [
2777 -39.0091620347991,
2778 -39.75219264093014,
2779 -40.53941417932286,
2780 -41.2787749205189,
2781 -42.93758699380749,
2782 ];
2783 let start = row.len() - 5;
2784 for (i, &v) in row[start..].iter().enumerate() {
2785 assert!(
2786 (v - expected[i]).abs() < 1e-6,
2787 "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
2788 );
2789 }
2790 Ok(())
2791 }
2792
2793 #[cfg(debug_assertions)]
2794 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2795 skip_if_unsupported!(kernel, test);
2796
2797 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2798 let c = read_candles_from_csv(file)?;
2799
2800 let test_configs = vec![
2801 (1, 5, 1, 1, 5, 1, 1, 5, 1),
2802 (5, 15, 5, 5, 15, 5, 5, 15, 5),
2803 (10, 30, 10, 10, 20, 10, 5, 15, 5),
2804 (14, 14, 0, 10, 10, 0, 1, 10, 1),
2805 (1, 20, 5, 10, 10, 0, 5, 5, 0),
2806 (20, 50, 15, 15, 30, 15, 10, 20, 10),
2807 (2, 6, 2, 8, 12, 2, 3, 9, 3),
2808 (50, 100, 25, 30, 60, 30, 20, 40, 20),
2809 (14, 14, 0, 5, 20, 5, 5, 20, 5),
2810 (5, 5, 0, 5, 5, 0, 1, 10, 1),
2811 ];
2812
2813 for (cfg_idx, &(r_start, r_end, r_step, s_start, s_end, s_step, u_start, u_end, u_step)) in
2814 test_configs.iter().enumerate()
2815 {
2816 let output = DtiBatchBuilder::new()
2817 .kernel(kernel)
2818 .r_range(r_start, r_end, r_step)
2819 .s_range(s_start, s_end, s_step)
2820 .u_range(u_start, u_end, u_step)
2821 .apply_candles(&c)?;
2822
2823 for (idx, &val) in output.values.iter().enumerate() {
2824 if val.is_nan() {
2825 continue;
2826 }
2827
2828 let bits = val.to_bits();
2829 let row = idx / output.cols;
2830 let col = idx % output.cols;
2831 let combo = &output.combos[row];
2832
2833 if bits == 0x11111111_11111111 {
2834 panic!(
2835 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
2836 at row {} col {} (flat index {}) with params: r={}, s={}, u={}",
2837 test,
2838 cfg_idx,
2839 val,
2840 bits,
2841 row,
2842 col,
2843 idx,
2844 combo.r.unwrap_or(14),
2845 combo.s.unwrap_or(10),
2846 combo.u.unwrap_or(5)
2847 );
2848 }
2849
2850 if bits == 0x22222222_22222222 {
2851 panic!(
2852 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
2853 at row {} col {} (flat index {}) with params: r={}, s={}, u={}",
2854 test,
2855 cfg_idx,
2856 val,
2857 bits,
2858 row,
2859 col,
2860 idx,
2861 combo.r.unwrap_or(14),
2862 combo.s.unwrap_or(10),
2863 combo.u.unwrap_or(5)
2864 );
2865 }
2866
2867 if bits == 0x33333333_33333333 {
2868 panic!(
2869 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
2870 at row {} col {} (flat index {}) with params: r={}, s={}, u={}",
2871 test,
2872 cfg_idx,
2873 val,
2874 bits,
2875 row,
2876 col,
2877 idx,
2878 combo.r.unwrap_or(14),
2879 combo.s.unwrap_or(10),
2880 combo.u.unwrap_or(5)
2881 );
2882 }
2883 }
2884 }
2885
2886 Ok(())
2887 }
2888
2889 #[cfg(not(debug_assertions))]
2890 fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2891 Ok(())
2892 }
2893
2894 macro_rules! gen_batch_tests {
2895 ($fn_name:ident) => {
2896 paste::paste! {
2897 #[test] fn [<$fn_name _scalar>]() {
2898 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
2899 }
2900 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2901 #[test] fn [<$fn_name _avx2>]() {
2902 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
2903 }
2904 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2905 #[test] fn [<$fn_name _avx512>]() {
2906 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
2907 }
2908 #[test] fn [<$fn_name _auto_detect>]() {
2909 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]),
2910 Kernel::Auto);
2911 }
2912 }
2913 };
2914 }
2915 gen_batch_tests!(check_batch_default_row);
2916 gen_batch_tests!(check_batch_no_poison);
2917
/// `dti_into` must write exactly what `dti` returns for the default
/// parameters on the 4h Bitfinex fixture.
///
/// Values are compared with `==`, except that two NaNs count as equal.
#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
#[test]
fn test_dti_into_matches_api() -> Result<(), Box<dyn Error>> {
    let candles = read_candles_from_csv("src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv")?;
    let input = DtiInput::with_default_candles(&candles);

    let api_values = dti(&input)?.values;

    let mut written = vec![0.0f64; candles.close.len()];
    dti_into(&input, &mut written)?;

    assert_eq!(api_values.len(), written.len());
    for (i, (&api, &into)) in api_values.iter().zip(written.iter()).enumerate() {
        let same = (api.is_nan() && into.is_nan()) || api == into;
        assert!(same, "Mismatch at index {}: api={} into={}", i, api, into);
    }
    Ok(())
}
2946}