Skip to main content

vector_ta/indicators/
dti.rs

1#[cfg(all(feature = "python", feature = "cuda"))]
2use cust::memory::DeviceBuffer;
3#[cfg(feature = "python")]
4use numpy::{IntoPyArray, PyArray1};
5#[cfg(feature = "python")]
6use pyo3::exceptions::PyValueError;
7#[cfg(feature = "python")]
8use pyo3::prelude::*;
9#[cfg(feature = "python")]
10use pyo3::types::PyDict;
11
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use serde::{Deserialize, Serialize};
14#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
15use wasm_bindgen::prelude::*;
16
17use crate::utilities::data_loader::{source_type, Candles};
18use crate::utilities::enums::Kernel;
19use crate::utilities::helpers::{
20    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
21    make_uninit_matrix,
22};
23#[cfg(feature = "python")]
24use crate::utilities::kernel_validation::validate_kernel;
25use aligned_vec::{AVec, CACHELINE_ALIGN};
26#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
27use core::arch::x86_64::*;
28#[cfg(not(target_arch = "wasm32"))]
29use rayon::prelude::*;
30use std::convert::AsRef;
31use std::error::Error;
32use thiserror::Error;
33
/// Source of the high/low series consumed by the DTI computation.
///
/// `Candles` borrows a candle set whose `"high"` and `"low"` fields are selected
/// at compute time; `Slices` borrows the two series directly.
#[derive(Debug, Clone)]
pub enum DtiData<'a> {
    /// A candle set; `high`/`low` are looked up via `select_candle_field`.
    Candles { candles: &'a Candles },
    /// Explicit high and low series (must have equal length to be accepted).
    Slices { high: &'a [f64], low: &'a [f64] },
}
39
/// Result of a DTI computation.
#[derive(Debug, Clone)]
pub struct DtiOutput {
    /// One value per input bar. The bar at the first valid index is NaN, and the
    /// entries before it form the prefix produced by `alloc_with_nan_prefix`
    /// (presumably NaN as well — confirm against that helper).
    pub values: Vec<f64>,
}
44
/// Smoothing periods for the DTI indicator.
///
/// Each period `p` is turned into an EMA weight `2 / (p + 1)` by the kernels.
/// `None` means "use the default": r = 14, s = 10, u = 5 (see `DtiInput::get_r`,
/// `get_s`, `get_u`).
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct DtiParams {
    /// Period of the first EMA, applied to the raw price-movement series.
    pub r: Option<usize>,
    /// Period of the second EMA, applied to the output of the `r` stage.
    pub s: Option<usize>,
    /// Period of the third EMA, applied to the output of the `s` stage.
    pub u: Option<usize>,
}
55
/// Default DTI periods: r = 14, s = 10, u = 5.
impl Default for DtiParams {
    fn default() -> Self {
        Self {
            r: Some(14),
            s: Some(10),
            u: Some(5),
        }
    }
}
65
/// Complete input for a DTI computation: the data source plus its periods.
#[derive(Debug, Clone)]
pub struct DtiInput<'a> {
    /// Borrowed high/low data (candles or raw slices).
    pub data: DtiData<'a>,
    /// Smoothing periods; `None` entries fall back to the defaults.
    pub params: DtiParams,
}
71
72impl<'a> DtiInput<'a> {
73    #[inline]
74    pub fn from_candles(candles: &'a Candles, params: DtiParams) -> Self {
75        Self {
76            data: DtiData::Candles { candles },
77            params,
78        }
79    }
80
81    #[inline]
82    pub fn from_slices(high: &'a [f64], low: &'a [f64], params: DtiParams) -> Self {
83        Self {
84            data: DtiData::Slices { high, low },
85            params,
86        }
87    }
88
89    #[inline]
90    pub fn with_default_candles(candles: &'a Candles) -> Self {
91        Self::from_candles(candles, DtiParams::default())
92    }
93
94    #[inline]
95    pub fn get_r(&self) -> usize {
96        self.params.r.unwrap_or(14)
97    }
98    #[inline]
99    pub fn get_s(&self) -> usize {
100        self.params.s.unwrap_or(10)
101    }
102    #[inline]
103    pub fn get_u(&self) -> usize {
104        self.params.u.unwrap_or(5)
105    }
106}
107
/// Fluent builder for one-off DTI computations and streams.
///
/// Unset periods stay `None` and fall back to the defaults downstream.
#[derive(Copy, Clone, Debug)]
pub struct DtiBuilder {
    // First EMA period (None => default).
    r: Option<usize>,
    // Second EMA period (None => default).
    s: Option<usize>,
    // Third EMA period (None => default).
    u: Option<usize>,
    // Kernel passed to `dti_with_kernel`.
    kernel: Kernel,
}
115
/// A builder with no explicit periods and automatic kernel selection.
impl Default for DtiBuilder {
    fn default() -> Self {
        Self {
            r: None,
            s: None,
            u: None,
            kernel: Kernel::Auto,
        }
    }
}
126
127impl DtiBuilder {
128    #[inline(always)]
129    pub fn new() -> Self {
130        Self::default()
131    }
132    #[inline(always)]
133    pub fn r(mut self, n: usize) -> Self {
134        self.r = Some(n);
135        self
136    }
137    #[inline(always)]
138    pub fn s(mut self, n: usize) -> Self {
139        self.s = Some(n);
140        self
141    }
142    #[inline(always)]
143    pub fn u(mut self, n: usize) -> Self {
144        self.u = Some(n);
145        self
146    }
147    #[inline(always)]
148    pub fn kernel(mut self, k: Kernel) -> Self {
149        self.kernel = k;
150        self
151    }
152    #[inline(always)]
153    pub fn apply(self, c: &Candles) -> Result<DtiOutput, DtiError> {
154        let p = DtiParams {
155            r: self.r,
156            s: self.s,
157            u: self.u,
158        };
159        let i = DtiInput::from_candles(c, p);
160        dti_with_kernel(&i, self.kernel)
161    }
162    #[inline(always)]
163    pub fn apply_slices(self, high: &[f64], low: &[f64]) -> Result<DtiOutput, DtiError> {
164        let p = DtiParams {
165            r: self.r,
166            s: self.s,
167            u: self.u,
168        };
169        let i = DtiInput::from_slices(high, low, p);
170        dti_with_kernel(&i, self.kernel)
171    }
172    #[inline(always)]
173    pub fn into_stream(self) -> Result<DtiStream, DtiError> {
174        let p = DtiParams {
175            r: self.r,
176            s: self.s,
177            u: self.u,
178        };
179        DtiStream::try_new(p)
180    }
181}
182
/// Errors produced by the DTI functions.
#[derive(Debug, Error)]
pub enum DtiError {
    /// The high or low series is empty.
    #[error("dti: Input data slice is empty.")]
    EmptyInputData,
    /// Selecting a candle field failed; carries the underlying message.
    #[error("dti: Candle field error: {0}")]
    CandleFieldError(String),
    /// A period is zero or exceeds the data length.
    #[error("dti: Invalid period: period = {period}, data_len = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },
    /// Fewer bars remain after the first valid bar than a period requires.
    #[error("dti: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// No bar has both high and low non-NaN.
    #[error("dti: All high/low values are NaN.")]
    AllValuesNaN,
    /// High and low series have different lengths.
    #[error("dti: Length mismatch: high length = {high}, low length = {low}")]
    LengthMismatch { high: usize, low: usize },
    /// A caller-provided output buffer has the wrong length.
    #[error("dti: output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// A non-batch kernel was supplied to a batch entry point (presumably; the
    /// batch code is outside this view — confirm).
    #[error("dti: invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(Kernel),
    /// A parameter sweep range is invalid (used by batch code not visible here).
    #[error("dti: invalid range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },
}
208
/// Computes DTI for `input` using automatic kernel selection.
///
/// Equivalent to `dti_with_kernel(input, Kernel::Auto)`.
#[inline]
pub fn dti(input: &DtiInput) -> Result<DtiOutput, DtiError> {
    dti_with_kernel(input, Kernel::Auto)
}
213
/// Computes DTI into the caller-provided `out` buffer with `Kernel::Auto`.
///
/// Not compiled for the wasm32 + `wasm` feature build. `out` must match the
/// input length (see `dti_into_slice`).
#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
#[inline]
pub fn dti_into(input: &DtiInput, out: &mut [f64]) -> Result<(), DtiError> {
    dti_into_slice(out, input, Kernel::Auto)
}
219
220#[inline]
221pub fn dti_into_slice(dst: &mut [f64], input: &DtiInput, kern: Kernel) -> Result<(), DtiError> {
222    let (high, low) = match &input.data {
223        DtiData::Candles { candles } => {
224            let high = candles
225                .select_candle_field("high")
226                .map_err(|e| DtiError::CandleFieldError(e.to_string()))?;
227            let low = candles
228                .select_candle_field("low")
229                .map_err(|e| DtiError::CandleFieldError(e.to_string()))?;
230            (high, low)
231        }
232        DtiData::Slices { high, low } => (*high, *low),
233    };
234
235    if high.is_empty() || low.is_empty() {
236        return Err(DtiError::EmptyInputData);
237    }
238    let len = high.len();
239    if low.len() != len {
240        return Err(DtiError::LengthMismatch {
241            high: high.len(),
242            low: low.len(),
243        });
244    }
245    if dst.len() != len {
246        return Err(DtiError::OutputLengthMismatch {
247            expected: len,
248            got: dst.len(),
249        });
250    }
251
252    let first_valid_idx = (0..len)
253        .find(|&i| !high[i].is_nan() && !low[i].is_nan())
254        .ok_or(DtiError::AllValuesNaN)?;
255
256    let r = input.get_r();
257    let s = input.get_s();
258    let u = input.get_u();
259    for &p in &[r, s, u] {
260        if p == 0 || p > len {
261            return Err(DtiError::InvalidPeriod {
262                period: p,
263                data_len: len,
264            });
265        }
266        if len - first_valid_idx < p {
267            return Err(DtiError::NotEnoughValidData {
268                needed: p,
269                valid: len - first_valid_idx,
270            });
271        }
272    }
273
274    let chosen = match kern {
275        Kernel::Auto => detect_best_kernel(),
276        k => k,
277    };
278
279    unsafe {
280        #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
281        {
282            match chosen {
283                Kernel::Scalar | Kernel::ScalarBatch => {
284                    dti_simd128(high, low, r, s, u, first_valid_idx, dst)
285                }
286                _ => dti_scalar(high, low, r, s, u, first_valid_idx, dst),
287            }
288        }
289        #[cfg(not(all(target_arch = "wasm32", target_feature = "simd128")))]
290        match chosen {
291            Kernel::Scalar | Kernel::ScalarBatch => {
292                dti_scalar(high, low, r, s, u, first_valid_idx, dst)
293            }
294            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
295            Kernel::Avx2 | Kernel::Avx2Batch => dti_avx2(high, low, r, s, u, first_valid_idx, dst),
296            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
297            Kernel::Avx512 | Kernel::Avx512Batch => {
298                dti_avx512(high, low, r, s, u, first_valid_idx, dst)
299            }
300            _ => unreachable!(),
301        }
302    }
303
304    for v in &mut dst[..=first_valid_idx] {
305        *v = f64::NAN;
306    }
307    Ok(())
308}
309
310pub fn dti_with_kernel(input: &DtiInput, kernel: Kernel) -> Result<DtiOutput, DtiError> {
311    let (high, low) = match &input.data {
312        DtiData::Candles { candles } => {
313            let high = candles
314                .select_candle_field("high")
315                .map_err(|e| DtiError::CandleFieldError(e.to_string()))?;
316            let low = candles
317                .select_candle_field("low")
318                .map_err(|e| DtiError::CandleFieldError(e.to_string()))?;
319            (high, low)
320        }
321        DtiData::Slices { high, low } => (*high, *low),
322    };
323
324    if high.is_empty() || low.is_empty() {
325        return Err(DtiError::EmptyInputData);
326    }
327    let len = high.len();
328    if low.len() != len {
329        return Err(DtiError::LengthMismatch {
330            high: high.len(),
331            low: low.len(),
332        });
333    }
334
335    let first_valid_idx = match (0..len).find(|&i| !high[i].is_nan() && !low[i].is_nan()) {
336        Some(idx) => idx,
337        None => return Err(DtiError::AllValuesNaN),
338    };
339
340    let r = input.get_r();
341    let s = input.get_s();
342    let u = input.get_u();
343
344    for &period in &[r, s, u] {
345        if period == 0 || period > len {
346            return Err(DtiError::InvalidPeriod {
347                period,
348                data_len: len,
349            });
350        }
351        if (len - first_valid_idx) < period {
352            return Err(DtiError::NotEnoughValidData {
353                needed: period,
354                valid: len - first_valid_idx,
355            });
356        }
357    }
358
359    let chosen = match kernel {
360        Kernel::Auto => detect_best_kernel(),
361        other => other,
362    };
363
364    let mut out = alloc_with_nan_prefix(len, first_valid_idx + 1);
365    unsafe {
366        #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
367        {
368            match chosen {
369                Kernel::Scalar | Kernel::ScalarBatch => {
370                    dti_simd128(high, low, r, s, u, first_valid_idx, &mut out)
371                }
372                _ => dti_scalar(high, low, r, s, u, first_valid_idx, &mut out),
373            }
374            return Ok(DtiOutput { values: out });
375        }
376
377        #[cfg(not(all(target_arch = "wasm32", target_feature = "simd128")))]
378        match chosen {
379            Kernel::Scalar | Kernel::ScalarBatch => {
380                dti_scalar(high, low, r, s, u, first_valid_idx, &mut out)
381            }
382            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
383            Kernel::Avx2 | Kernel::Avx2Batch => {
384                dti_avx2(high, low, r, s, u, first_valid_idx, &mut out)
385            }
386            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
387            Kernel::Avx512 | Kernel::Avx512Batch => {
388                dti_avx512(high, low, r, s, u, first_valid_idx, &mut out)
389            }
390            _ => unreachable!(),
391        }
392    }
393    Ok(DtiOutput { values: out })
394}
395
/// Reference scalar DTI kernel.
///
/// For each bar `i > first_valid_idx` it computes the movement
/// `max(Δhigh, 0) - max(-Δlow, 0)` (NaN deltas count as 0). It then runs two
/// three-stage EMA chains (periods r → s → u): one over the signed movement and
/// one over its absolute value. It writes `100 * signed / absolute`, or `0.0` when
/// the absolute chain is zero or NaN.
///
/// `out[first_valid_idx]` is set to NaN. Entries before it are not touched.
/// Panics if `out` or `low` is shorter than `high`.
#[inline]
pub fn dti_scalar(
    high: &[f64],
    low: &[f64],
    r: usize,
    s: usize,
    u: usize,
    first_valid_idx: usize,
    out: &mut [f64],
) {
    let n = high.len();
    let weights = |period: usize| {
        let a = 2.0 / (period as f64 + 1.0);
        (a, 1.0 - a)
    };
    let (ra, rb) = weights(r);
    let (sa, sb) = weights(s);
    let (ua, ub) = weights(u);

    // One step of the r -> s -> u EMA cascade; chain = [r-stage, s-stage, u-stage].
    let advance = |chain: &mut [f64; 3], input: f64| {
        chain[0] = ra * input + rb * chain[0];
        chain[1] = sa * chain[0] + sb * chain[1];
        chain[2] = ua * chain[1] + ub * chain[2];
    };

    let mut signed = [0.0f64; 3];
    let mut magnitude = [0.0f64; 3];

    out[first_valid_idx] = f64::NAN;
    for i in (first_valid_idx + 1)..n {
        let up = high[i] - high[i - 1];
        let down = low[i] - low[i - 1];
        let rise = if up > 0.0 { up } else { 0.0 };
        let fall = if down < 0.0 { -down } else { 0.0 };
        let movement = rise - fall;

        advance(&mut signed, movement);
        advance(&mut magnitude, movement.abs());

        let den = magnitude[2];
        out[i] = if den.is_nan() || den == 0.0 {
            0.0
        } else {
            100.0 * signed[2] / den
        };
    }
}
446
/// wasm32 simd128 entry point for DTI.
///
/// Currently it only forwards to `dti_scalar`; no SIMD code is used.
///
/// # Safety
/// No requirements beyond those of `dti_scalar` are visible here. The function
/// is `unsafe` only for parity with other kernel entry points.
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
#[inline]
unsafe fn dti_simd128(
    high: &[f64],
    low: &[f64],
    r: usize,
    s: usize,
    u: usize,
    first_valid_idx: usize,
    out: &mut [f64],
) {
    dti_scalar(high, low, r, s, u, first_valid_idx, out);
}
460
/// SSE2-vectorised DTI kernel (used for the `Avx2` kernel selection).
///
/// Both EMA cascades run in one 128-bit register: lane 0 holds the signed
/// movement chain and lane 1 holds the absolute-value chain. Writes
/// `out[first_valid_idx + 1..]` with the same values as `dti_scalar`, including
/// for NaN or infinite price deltas. Entries up to `first_valid_idx` are left
/// untouched.
///
/// # Panics
/// Panics if `low` or `out` is shorter than `high`. This replaces the previous
/// unchecked raw-pointer access, which was undefined behaviour for short buffers.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn dti_avx2(
    high: &[f64],
    low: &[f64],
    r: usize,
    s: usize,
    u: usize,
    first_valid_idx: usize,
    out: &mut [f64],
) {
    use core::arch::x86_64::*;

    let len = high.len();
    let start = first_valid_idx + 1;
    if start >= len {
        return;
    }
    // Checked re-slicing: the loop indexes `low` and `out` up to `len`, so a short
    // buffer panics here instead of being read/written out of bounds.
    let low = &low[..len];
    let out = &mut out[..len];

    let alpha_r = 2.0 / (r as f64 + 1.0);
    let alpha_s = 2.0 / (s as f64 + 1.0);
    let alpha_u = 2.0 / (u as f64 + 1.0);

    // SAFETY: only register-level SSE2 intrinsics are used (SSE2 is part of the
    // x86_64 baseline); every memory access goes through bounds-checked slices.
    unsafe {
        let vr_a = _mm_set1_pd(alpha_r);
        let vs_a = _mm_set1_pd(alpha_s);
        let vu_a = _mm_set1_pd(alpha_u);
        let vr_b = _mm_set1_pd(1.0 - alpha_r);
        let vs_b = _mm_set1_pd(1.0 - alpha_s);
        let vu_b = _mm_set1_pd(1.0 - alpha_u);

        // Lane 0: signed-movement chain; lane 1: absolute-movement chain.
        let mut v_er = _mm_set1_pd(0.0);
        let mut v_es = _mm_set1_pd(0.0);
        let mut v_eu = _mm_set1_pd(0.0);

        for i in start..len {
            let dh = high[i] - high[i - 1];
            let dl = low[i] - low[i - 1];
            // Same comparisons as `dti_scalar`, so NaN and infinite deltas give
            // identical results in both kernels. The former `0.5 * (|d| + d)` form
            // turned NaN and -inf deltas into NaN and poisoned the EMAs.
            let x_hmu = if dh > 0.0 { dh } else { 0.0 };
            let x_lmd = if dl < 0.0 { -dl } else { 0.0 };
            let x = x_hmu - x_lmd;

            let vx = _mm_set_pd(x.abs(), x);
            v_er = _mm_add_pd(_mm_mul_pd(vr_a, vx), _mm_mul_pd(vr_b, v_er));
            v_es = _mm_add_pd(_mm_mul_pd(vs_a, v_er), _mm_mul_pd(vs_b, v_es));
            v_eu = _mm_add_pd(_mm_mul_pd(vu_a, v_es), _mm_mul_pd(vu_b, v_eu));

            let num = _mm_cvtsd_f64(v_eu);
            let den = _mm_cvtsd_f64(_mm_unpackhi_pd(v_eu, v_eu));
            out[i] = if !den.is_nan() && den != 0.0 {
                100.0 * num / den
            } else {
                0.0
            };
        }
    }
}
539
/// AVX-512 kernel selection for DTI.
///
/// There is no dedicated implementation yet; it delegates to `dti_avx2`, which
/// uses 128-bit SSE2 registers.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn dti_avx512(
    high: &[f64],
    low: &[f64],
    r: usize,
    s: usize,
    u: usize,
    first_valid_idx: usize,
    out: &mut [f64],
) {
    dti_avx2(high, low, r, s, u, first_valid_idx, out)
}
553
/// Short-period AVX-512 variant; currently delegates to `dti_avx2`.
///
/// # Safety
/// No requirements beyond those of the safe `dti_avx2` are visible here.
/// NOTE(review): the `unsafe` marker appears to exist only for signature parity.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn dti_avx512_short(
    high: &[f64],
    low: &[f64],
    r: usize,
    s: usize,
    u: usize,
    first_valid_idx: usize,
    out: &mut [f64],
) {
    dti_avx2(high, low, r, s, u, first_valid_idx, out)
}
567
/// Long-period AVX-512 variant; currently delegates to `dti_avx2`.
///
/// # Safety
/// No requirements beyond those of the safe `dti_avx2` are visible here.
/// NOTE(review): the `unsafe` marker appears to exist only for signature parity.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn dti_avx512_long(
    high: &[f64],
    low: &[f64],
    r: usize,
    s: usize,
    u: usize,
    first_valid_idx: usize,
    out: &mut [f64],
) {
    dti_avx2(high, low, r, s, u, first_valid_idx, out)
}
581
/// Computes one DTI row with the scalar kernel (thin wrapper over `dti_scalar`).
#[inline(always)]
pub fn dti_row_scalar(
    high: &[f64],
    low: &[f64],
    r: usize,
    s: usize,
    u: usize,
    first_valid_idx: usize,
    out: &mut [f64],
) {
    dti_scalar(high, low, r, s, u, first_valid_idx, out)
}
594
595#[inline(always)]
596fn dti_precompute_base(high: &[f64], low: &[f64], start: usize) -> (AVec<f64>, AVec<f64>) {
597    let len = high.len();
598    let mut x = AVec::<f64>::new(CACHELINE_ALIGN);
599    let mut ax = AVec::<f64>::new(CACHELINE_ALIGN);
600    x.resize(len, 0.0);
601    ax.resize(len, 0.0);
602    for i in (start)..len {
603        let dh = high[i] - high[i - 1];
604        let dl = low[i] - low[i - 1];
605        let x_hmu = if dh > 0.0 { dh } else { 0.0 };
606        let x_lmd = if dl < 0.0 { -dl } else { 0.0 };
607        let v = x_hmu - x_lmd;
608        x[i] = v;
609        ax[i] = v.abs();
610    }
611    (x, ax)
612}
613
/// Runs the DTI triple-EMA recursion over precomputed movement series.
///
/// `x` is the signed per-bar movement and `ax` its absolute value (as built by
/// `dti_precompute_base`). For every `i` in `start..x.len()`, `out[i]` receives
/// `100 * ema3(x) / ema3(ax)`, or `0.0` when the denominator is zero or NaN.
/// Entries before `start` are not touched; nothing is written if `start >= x.len()`.
/// Panics if `ax` or `out` is shorter than `x`.
#[inline(always)]
fn dti_row_scalar_from_base(
    x: &[f64],
    ax: &[f64],
    r: usize,
    s: usize,
    u: usize,
    start: usize,
    out: &mut [f64],
) {
    let n = x.len();
    if n <= start {
        return;
    }

    let weights = |period: usize| {
        let a = 2.0 / (period as f64 + 1.0);
        (a, 1.0 - a)
    };
    let (ra, rb) = weights(r);
    let (sa, sb) = weights(s);
    let (ua, ub) = weights(u);

    // Stage order: [r, s, u] for the signed and the absolute chain.
    let mut signed = [0.0f64; 3];
    let mut unsigned = [0.0f64; 3];

    let lanes = x[start..]
        .iter()
        .zip(&ax[start..n])
        .zip(&mut out[start..n]);
    for ((&xi, &axi), slot) in lanes {
        signed[0] = ra * xi + rb * signed[0];
        signed[1] = sa * signed[0] + sb * signed[1];
        signed[2] = ua * signed[1] + ub * signed[2];

        unsigned[0] = ra * axi + rb * unsigned[0];
        unsigned[1] = sa * unsigned[0] + sb * unsigned[1];
        unsigned[2] = ua * unsigned[1] + ub * unsigned[2];

        let den = unsigned[2];
        *slot = if den.is_nan() || den == 0.0 {
            0.0
        } else {
            100.0 * signed[2] / den
        };
    }
}
656
/// SSE2-vectorised DTI recursion over precomputed movement series.
///
/// Lane 0 carries the signed chain (`x`) and lane 1 the absolute chain (`ax`).
/// For every `i` in `start..x.len()`, `out[i]` receives `100 * signed / absolute`,
/// or `0.0` when the absolute chain is zero or NaN. The values match
/// `dti_row_scalar_from_base`.
///
/// # Panics
/// Panics if `ax` or `out` is shorter than `x`. This replaces the previous
/// unchecked raw-pointer access, which was undefined behaviour for short buffers.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
fn dti_row_avx2_from_base(
    x: &[f64],
    ax: &[f64],
    r: usize,
    s: usize,
    u: usize,
    start: usize,
    out: &mut [f64],
) {
    use core::arch::x86_64::*;

    let len = x.len();
    if start >= len {
        return;
    }
    // Checked re-slicing so every access below stays in bounds.
    let ax = &ax[..len];
    let out = &mut out[..len];

    let alpha_r = 2.0 / (r as f64 + 1.0);
    let alpha_s = 2.0 / (s as f64 + 1.0);
    let alpha_u = 2.0 / (u as f64 + 1.0);

    // SAFETY: only register-level SSE2 intrinsics are used (SSE2 is part of the
    // x86_64 baseline); all loads/stores go through bounds-checked slices.
    unsafe {
        let vr_a = _mm_set1_pd(alpha_r);
        let vs_a = _mm_set1_pd(alpha_s);
        let vu_a = _mm_set1_pd(alpha_u);
        let vr_b = _mm_set1_pd(1.0 - alpha_r);
        let vs_b = _mm_set1_pd(1.0 - alpha_s);
        let vu_b = _mm_set1_pd(1.0 - alpha_u);

        let mut v_er = _mm_set1_pd(0.0);
        let mut v_es = _mm_set1_pd(0.0);
        let mut v_eu = _mm_set1_pd(0.0);

        let lanes = x[start..]
            .iter()
            .zip(&ax[start..])
            .zip(&mut out[start..]);
        for ((&xi, &axi), slot) in lanes {
            let vx = _mm_set_pd(axi, xi);
            v_er = _mm_add_pd(_mm_mul_pd(vr_a, vx), _mm_mul_pd(vr_b, v_er));
            v_es = _mm_add_pd(_mm_mul_pd(vs_a, v_er), _mm_mul_pd(vs_b, v_es));
            v_eu = _mm_add_pd(_mm_mul_pd(vu_a, v_es), _mm_mul_pd(vu_b, v_eu));

            let num = _mm_cvtsd_f64(v_eu);
            let den = _mm_cvtsd_f64(_mm_unpackhi_pd(v_eu, v_eu));
            *slot = if !den.is_nan() && den != 0.0 {
                100.0 * num / den
            } else {
                0.0
            };
        }
    }
}
716
/// AVX-512 selection for the precomputed-base row kernel; currently delegates to
/// `dti_row_avx2_from_base`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
fn dti_row_avx512_from_base(
    x: &[f64],
    ax: &[f64],
    r: usize,
    s: usize,
    u: usize,
    start: usize,
    out: &mut [f64],
) {
    dti_row_avx2_from_base(x, ax, r, s, u, start, out)
}
730
/// Computes one DTI row with the `Avx2` kernel (thin wrapper over `dti_avx2`).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn dti_row_avx2(
    high: &[f64],
    low: &[f64],
    r: usize,
    s: usize,
    u: usize,
    first_valid_idx: usize,
    out: &mut [f64],
) {
    dti_avx2(high, low, r, s, u, first_valid_idx, out)
}
744
/// Computes one DTI row with the `Avx512` kernel (thin wrapper over `dti_avx512`).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn dti_row_avx512(
    high: &[f64],
    low: &[f64],
    r: usize,
    s: usize,
    u: usize,
    first_valid_idx: usize,
    out: &mut [f64],
) {
    dti_avx512(high, low, r, s, u, first_valid_idx, out)
}
758
/// Safe row wrapper around `dti_avx512_short`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn dti_row_avx512_short(
    high: &[f64],
    low: &[f64],
    r: usize,
    s: usize,
    u: usize,
    first_valid_idx: usize,
    out: &mut [f64],
) {
    // SAFETY: dti_avx512_short only forwards to the safe dti_avx2.
    unsafe { dti_avx512_short(high, low, r, s, u, first_valid_idx, out) }
}
772
/// Safe row wrapper around `dti_avx512_long`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn dti_row_avx512_long(
    high: &[f64],
    low: &[f64],
    r: usize,
    s: usize,
    u: usize,
    first_valid_idx: usize,
    out: &mut [f64],
) {
    // SAFETY: dti_avx512_long only forwards to the safe dti_avx2.
    unsafe { dti_avx512_long(high, low, r, s, u, first_valid_idx, out) }
}
786
/// Safe row wrapper around the wasm32 `dti_simd128` entry point.
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
#[inline(always)]
pub fn dti_row_simd128(
    high: &[f64],
    low: &[f64],
    r: usize,
    s: usize,
    u: usize,
    first_valid_idx: usize,
    out: &mut [f64],
) {
    // SAFETY: dti_simd128 only forwards to the safe dti_scalar.
    unsafe { dti_simd128(high, low, r, s, u, first_valid_idx, out) }
}
800
/// Incremental (bar-by-bar) DTI calculator.
///
/// Keeps the six EMA states and the previous bar's high/low between calls.
#[derive(Debug, Clone)]
pub struct DtiStream {
    // Configured periods. NOTE(review): stored but not read in the visible code.
    r: usize,
    s: usize,
    u: usize,
    // EMA weights 2 / (period + 1) for the r, s and u stages.
    alpha_r: f64,
    alpha_s: f64,
    alpha_u: f64,
    // Complementary weights 1 - alpha. NOTE(review): the update path uses the
    // mul_add form and does not read these.
    alpha_r_1: f64,
    alpha_s_1: f64,
    alpha_u_1: f64,
    // Signed-movement EMA cascade (r -> s -> u).
    e0_r: f64,
    e0_s: f64,
    e0_u: f64,
    // Absolute-movement EMA cascade (r -> s -> u).
    e1_r: f64,
    e1_s: f64,
    e1_u: f64,
    // Previous bar; None until the first update.
    last_high: Option<f64>,
    last_low: Option<f64>,
    // Set once the first bar has been seen; cleared by reset().
    initialized: bool,
}
822
823impl DtiStream {
824    pub fn try_new(params: DtiParams) -> Result<Self, DtiError> {
825        let r = params.r.unwrap_or(14);
826        let s = params.s.unwrap_or(10);
827        let u = params.u.unwrap_or(5);
828        if r == 0 || s == 0 || u == 0 {
829            return Err(DtiError::InvalidPeriod {
830                period: 0,
831                data_len: 0,
832            });
833        }
834        let alpha_r = 2.0 / (r as f64 + 1.0);
835        let alpha_s = 2.0 / (s as f64 + 1.0);
836        let alpha_u = 2.0 / (u as f64 + 1.0);
837        let alpha_r_1 = 1.0 - alpha_r;
838        let alpha_s_1 = 1.0 - alpha_s;
839        let alpha_u_1 = 1.0 - alpha_u;
840        Ok(Self {
841            r,
842            s,
843            u,
844            alpha_r,
845            alpha_s,
846            alpha_u,
847            alpha_r_1,
848            alpha_s_1,
849            alpha_u_1,
850            e0_r: 0.0,
851            e0_s: 0.0,
852            e0_u: 0.0,
853            e1_r: 0.0,
854            e1_s: 0.0,
855            e1_u: 0.0,
856            last_high: None,
857            last_low: None,
858            initialized: false,
859        })
860    }
861
862    #[inline(always)]
863    pub fn update(&mut self, high: f64, low: f64) -> Option<f64> {
864        if let (Some(last_h), Some(last_l)) = (self.last_high, self.last_low) {
865            let dh = high - last_h;
866            let dl = low - last_l;
867            let x_hmu = if dh > 0.0 { dh } else { 0.0 };
868            let x_lmd = if dl < 0.0 { -dl } else { 0.0 };
869            let x_price = x_hmu - x_lmd;
870            let x_price_abs = x_price.abs();
871
872            self.e0_r = (x_price - self.e0_r).mul_add(self.alpha_r, self.e0_r);
873            self.e0_s = (self.e0_r - self.e0_s).mul_add(self.alpha_s, self.e0_s);
874            self.e0_u = (self.e0_s - self.e0_u).mul_add(self.alpha_u, self.e0_u);
875
876            self.e1_r = (x_price_abs - self.e1_r).mul_add(self.alpha_r, self.e1_r);
877            self.e1_s = (self.e1_r - self.e1_s).mul_add(self.alpha_s, self.e1_s);
878            self.e1_u = (self.e1_s - self.e1_u).mul_add(self.alpha_u, self.e1_u);
879
880            self.last_high = Some(high);
881            self.last_low = Some(low);
882
883            if !self.e1_u.is_nan() && self.e1_u != 0.0 {
884                Some(fast_div_approx(self.e0_u * 100.0, self.e1_u))
885            } else {
886                Some(0.0)
887            }
888        } else {
889            self.last_high = Some(high);
890            self.last_low = Some(low);
891            self.initialized = true;
892            None
893        }
894    }
895
896    #[inline(always)]
897    pub fn reset(&mut self) {
898        self.e0_r = 0.0;
899        self.e0_s = 0.0;
900        self.e0_u = 0.0;
901        self.e1_r = 0.0;
902        self.e1_s = 0.0;
903        self.e1_u = 0.0;
904        self.last_high = None;
905        self.last_low = None;
906        self.initialized = false;
907    }
908
909    #[inline(always)]
910    pub fn update_delta(&mut self, dh: f64, dl: f64) -> Option<f64> {
911        if let (Some(prev_h), Some(prev_l)) = (self.last_high, self.last_low) {
912            let high = prev_h + dh;
913            let low = prev_l + dl;
914
915            let x_hmu = if dh > 0.0 { dh } else { 0.0 };
916            let x_lmd = if dl < 0.0 { -dl } else { 0.0 };
917            let x_price = x_hmu - x_lmd;
918            let x_price_abs = x_price.abs();
919
920            self.e0_r = (x_price - self.e0_r).mul_add(self.alpha_r, self.e0_r);
921            self.e0_s = (self.e0_r - self.e0_s).mul_add(self.alpha_s, self.e0_s);
922            self.e0_u = (self.e0_s - self.e0_u).mul_add(self.alpha_u, self.e0_u);
923
924            self.e1_r = (x_price_abs - self.e1_r).mul_add(self.alpha_r, self.e1_r);
925            self.e1_s = (self.e1_r - self.e1_s).mul_add(self.alpha_s, self.e1_s);
926            self.e1_u = (self.e1_s - self.e1_u).mul_add(self.alpha_u, self.e1_u);
927
928            self.last_high = Some(high);
929            self.last_low = Some(low);
930
931            if !self.e1_u.is_nan() && self.e1_u != 0.0 {
932                Some(fast_div_approx(self.e0_u * 100.0, self.e1_u))
933            } else {
934                Some(0.0)
935            }
936        } else {
937            None
938        }
939    }
940}
941
/// Approximates `num / den` with an f32 reciprocal seed refined by one
/// Newton–Raphson step (`r1 = r0 * (2 - den * r0)`).
///
/// The fast path is used only while the seed `1 / den` is a *normal* f32, that is
/// `2^-126 <= |den| <= 2^126`. Previously it ran up to `f32::MAX`, where the seed
/// is subnormal and loses precision, so the result was noticeably inexact.
/// Outside that range, and for NaN or infinite `den`, plain division is used.
/// `den` must be non-zero (checked in debug builds only).
#[inline(always)]
fn fast_div_approx(num: f64, den: f64) -> f64 {
    debug_assert!(den != 0.0);

    // Smallest normal f32 magnitude and its reciprocal (2^-126 and 2^126).
    const MIN_SEED_DEN: f64 = f32::MIN_POSITIVE as f64;
    const MAX_SEED_DEN: f64 = 1.0 / (f32::MIN_POSITIVE as f64);

    let ad = den.abs();
    if (MIN_SEED_DEN..=MAX_SEED_DEN).contains(&ad) {
        let r0 = (1.0f32 / den as f32) as f64;
        let r1 = r0 * (2.0 - den * r0);
        num * r1
    } else {
        num / den
    }
}
955
/// Python binding: computes DTI over `high`/`low` NumPy arrays.
///
/// `kernel` is an optional kernel name validated by `validate_kernel`. The
/// computation runs with the GIL released. Any `DtiError` is raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "dti")]
#[pyo3(signature = (high, low, r, s, u, kernel=None))]
pub fn dti_py<'py>(
    py: Python<'py>,
    high: numpy::PyReadonlyArray1<'py, f64>,
    low: numpy::PyReadonlyArray1<'py, f64>,
    r: usize,
    s: usize,
    u: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let (high_slice, low_slice) = (high.as_slice()?, low.as_slice()?);
    let kern = validate_kernel(kernel, false)?;

    let input = DtiInput::from_slices(
        high_slice,
        low_slice,
        DtiParams {
            r: Some(r),
            s: Some(s),
            u: Some(u),
        },
    );

    let values = py
        .allow_threads(|| dti_with_kernel(&input, kern))
        .map(|output| output.values)
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok(values.into_pyarray(py))
}
987
/// Python-visible wrapper (`DtiStream`) around the Rust `DtiStream`.
#[cfg(feature = "python")]
#[pyclass(name = "DtiStream")]
pub struct DtiStreamPy {
    // Underlying incremental calculator.
    stream: DtiStream,
}
993
#[cfg(feature = "python")]
#[pymethods]
impl DtiStreamPy {
    /// Creates a stream with explicit periods. A zero period raises `ValueError`.
    #[new]
    fn new(r: usize, s: usize, u: usize) -> PyResult<Self> {
        DtiStream::try_new(DtiParams {
            r: Some(r),
            s: Some(s),
            u: Some(u),
        })
        .map(|stream| DtiStreamPy { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feeds one bar. Returns `None` for the first bar and the DTI value afterwards.
    fn update(&mut self, high: f64, low: f64) -> Option<f64> {
        let stream = &mut self.stream;
        stream.update(high, low)
    }
}
1013
/// Python binding: computes DTI for every (r, s, u) combination in the sweep.
///
/// Returns a dict with:
/// - `"values"`: a `(rows, len)` matrix, one row per combination;
/// - `"r"`, `"s"`, `"u"`: the parameters of each row.
///
/// Raises `ValueError` in these cases:
/// - the kernel resolves to a non-batch kernel (previously an `unreachable!()`
///   panic inside the GIL-released section);
/// - `rows * len` overflows;
/// - the batch computation fails.
#[cfg(feature = "python")]
#[pyfunction(name = "dti_batch")]
#[pyo3(signature = (high, low, r_range, s_range, u_range, kernel=None))]
pub fn dti_batch_py<'py>(
    py: Python<'py>,
    high: numpy::PyReadonlyArray1<'py, f64>,
    low: numpy::PyReadonlyArray1<'py, f64>,
    r_range: (usize, usize, usize),
    s_range: (usize, usize, usize),
    u_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let high_slice = high.as_slice()?;
    let low_slice = low.as_slice()?;
    // Validate the kernel before allocating a potentially large output array.
    let kern = validate_kernel(kernel, true)?;

    let sweep = DtiBatchRange {
        r: r_range,
        s: s_range,
        u: u_range,
    };

    let rows = expand_grid(&sweep).len();
    let cols = high_slice.len();
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("dti_batch: rows * cols overflows usize"))?;

    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            let kernel = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k => k,
            };
            let simd = match kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                // Report instead of panicking while the GIL is released.
                other => return Err(DtiError::InvalidKernelForBatch(other)),
            };
            dti_batch_inner_into(high_slice, low_slice, &sweep, simd, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "r",
        combos
            .iter()
            .map(|p| p.r.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "s",
        combos
            .iter()
            .map(|p| p.s.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "u",
        combos
            .iter()
            .map(|p| p.u.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    Ok(dict)
}
1091
1092#[cfg(all(feature = "python", feature = "cuda"))]
1093use crate::cuda::oscillators::dti_wrapper::DeviceArrayF32Dti;
1094#[cfg(all(feature = "python", feature = "cuda"))]
1095use crate::cuda::oscillators::CudaDti;
1096#[cfg(all(feature = "python", feature = "cuda"))]
1097use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
1098
/// Python handle to a DTI result matrix held in CUDA device memory as `f32`.
///
/// The class is declared `unsendable`, so pyo3 restricts it to the thread that created it.
/// Consumers read it through `__cuda_array_interface__` or `__dlpack__`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct DtiDeviceArrayF32Py {
    /// Device buffer plus its shape, CUDA context and device id.
    pub(crate) inner: DeviceArrayF32Dti,
}
1104
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl DtiDeviceArrayF32Py {
    /// CUDA Array Interface (version 3) description of the device matrix.
    ///
    /// Describes a row-major `(rows, cols)` float32 array (`typestr` `<f4`) with byte
    /// strides `(cols * 4, 4)`. The `data` entry is `(device_ptr, false)`, meaning the
    /// array is not read-only. No `stream` key is provided.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let d = PyDict::new(py);
        d.set_item("shape", (self.inner.rows, self.inner.cols))?;
        d.set_item("typestr", "<f4")?;
        // Strides are in bytes: a row spans `cols` f32 values; an element is one f32.
        d.set_item(
            "strides",
            (
                self.inner.cols * std::mem::size_of::<f32>(),
                std::mem::size_of::<f32>(),
            ),
        )?;
        // (pointer, read_only) pair.
        d.set_item("data", (self.inner.device_ptr() as usize, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    /// DLPack `__dlpack_device__`: returns `(2, device_ordinal)`. `2` is the DLPack
    /// CUDA device type (`kDLCUDA`).
    ///
    /// The ordinal is read from the buffer's `CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL`.
    /// If that query fails, it falls back to `cuCtxGetDevice`. That call's result is
    /// ignored, so if it also fails the stored `device_id` presumably remains.
    fn __dlpack_device__(&self) -> PyResult<(i32, i32)> {
        // Start from the device id recorded alongside the buffer.
        let mut device_ordinal: i32 = self.inner.device_id as i32;
        unsafe {
            let attr = cust::sys::CUpointer_attribute::CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL;
            let mut value = std::mem::MaybeUninit::<i32>::uninit();
            let err = cust::sys::cuPointerGetAttribute(
                value.as_mut_ptr() as *mut std::ffi::c_void,
                attr,
                self.inner.buf.as_device_ptr().as_raw(),
            );
            if err == cust::sys::CUresult::CUDA_SUCCESS {
                // Only read `value` after the driver reports it was written.
                device_ordinal = value.assume_init();
            } else {
                // Fallback: device of the current context; failure is deliberately ignored.
                let _ = cust::sys::cuCtxGetDevice(&mut device_ordinal);
            }
        }
        Ok((2, device_ordinal))
    }

    /// DLPack `__dlpack__`: hands the device buffer to `export_f32_cuda_dlpack_2d` and
    /// returns its result.
    ///
    /// - `dl_device`: if it extracts as an `(int, int)` pair that differs from
    ///   `__dlpack_device__()`, raises `ValueError`. The message depends on whether
    ///   `copy` is true, but no cross-device copy is ever performed. Values that are not
    ///   an int pair are ignored.
    /// - `stream`: accepted but not used.
    /// - `max_version`: forwarded to the exporter.
    ///
    /// The buffer is moved out of `self` and replaced with an empty `0 x 0` placeholder.
    /// NOTE(review): a second `__dlpack__` call, or a later `__cuda_array_interface__`,
    /// therefore sees an empty array instead of raising. Confirm this is intended.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        let (kdl, alloc_dev) = self.__dlpack_device__()?;
        if let Some(dev_obj) = dl_device.as_ref() {
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    // `copy` only affects which error is reported; both paths fail.
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(PyValueError::new_err(
                            "device copy not implemented for __dlpack__",
                        ));
                    } else {
                        return Err(PyValueError::new_err("dl_device mismatch for __dlpack__"));
                    }
                }
            }
        }
        // No stream synchronization is performed.
        let _ = stream;

        // Swap an empty placeholder into `self` so ownership of the real allocation can
        // move into the exported capsule.
        let ctx_arc = self.inner.ctx.clone();
        let dummy =
            DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let inner = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32Dti {
                buf: dummy,
                rows: 0,
                cols: 0,
                ctx: ctx_arc,
                device_id: alloc_dev as u32,
            },
        );

        let rows = inner.rows;
        let cols = inner.cols;
        let buf = inner.buf;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
1197
/// Runs the DTI parameter sweep on a CUDA device and leaves the result on the GPU.
///
/// Returns the device matrix wrapper together with a dict of uint64 arrays `r`, `s`
/// and `u` giving each row's parameters. Raises `ValueError` when CUDA is unavailable
/// or the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "dti_cuda_batch_dev")]
#[pyo3(signature = (high_f32, low_f32, r_range, s_range, u_range, device_id=0))]
pub fn dti_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    high_f32: numpy::PyReadonlyArray1<'py, f32>,
    low_f32: numpy::PyReadonlyArray1<'py, f32>,
    r_range: (usize, usize, usize),
    s_range: (usize, usize, usize),
    u_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<(DtiDeviceArrayF32Py, Bound<'py, pyo3::types::PyDict>)> {
    use crate::cuda::cuda_available;
    use numpy::IntoPyArray;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let high = high_f32.as_slice()?;
    let low = low_f32.as_slice()?;
    let sweep = DtiBatchRange {
        r: r_range,
        s: s_range,
        u: u_range,
    };

    // Device work runs with the GIL released; wrapper errors become ValueError inside.
    let (inner, combos) = py.allow_threads(|| {
        let engine = CudaDti::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        engine
            .dti_batch_dev(high, low, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    // Collect the three parameter columns in one pass over the combinations.
    let mut r_col: Vec<u64> = Vec::new();
    let mut s_col: Vec<u64> = Vec::new();
    let mut u_col: Vec<u64> = Vec::new();
    for combo in combos.iter() {
        r_col.push(combo.r.unwrap() as u64);
        s_col.push(combo.s.unwrap() as u64);
        u_col.push(combo.u.unwrap() as u64);
    }

    let meta = PyDict::new(py);
    meta.set_item("r", r_col.into_pyarray(py))?;
    meta.set_item("s", s_col.into_pyarray(py))?;
    meta.set_item("u", u_col.into_pyarray(py))?;
    Ok((DtiDeviceArrayF32Py { inner }, meta))
}
1236
/// Computes DTI with a single `(r, s, u)` set across many series on a CUDA device.
///
/// `high_tm_f32` and `low_tm_f32` are 2-D float32 arrays of identical shape and must be
/// contiguous, because `as_slice` fails otherwise. The `tm` suffix suggests a time-major
/// layout (rows = time steps, cols = series); confirm against
/// `CudaDti::dti_many_series_one_param_time_major_dev`, which receives `cols` then `rows`.
/// Raises `ValueError` when CUDA is unavailable, the shapes disagree, or the wrapper fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "dti_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm_f32, low_tm_f32, r, s, u, device_id=0))]
pub fn dti_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    high_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    low_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    r: usize,
    s: usize,
    u: usize,
    device_id: usize,
) -> PyResult<DtiDeviceArrayF32Py> {
    use crate::cuda::cuda_available;
    use numpy::PyUntypedArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let high_flat = high_tm_f32.as_slice()?;
    let low_flat = low_tm_f32.as_slice()?;

    let dims = high_tm_f32.shape();
    let (rows, cols) = (dims[0], dims[1]);
    let elems = match cols.checked_mul(rows) {
        Some(total) => total,
        None => return Err(PyValueError::new_err("matrix size overflow")),
    };

    let same_shape = low_tm_f32.shape() == [rows, cols];
    if !same_shape {
        return Err(PyValueError::new_err("high/low shapes mismatch"));
    }
    let sizes_ok = high_tm_f32.len() == elems && low_tm_f32.len() == elems;
    if !sizes_ok {
        return Err(PyValueError::new_err("high/low flattened sizes mismatch"));
    }

    let params = DtiParams {
        r: Some(r),
        s: Some(s),
        u: Some(u),
    };
    let inner = py.allow_threads(|| {
        let engine = CudaDti::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        engine
            .dti_many_series_one_param_time_major_dev(high_flat, low_flat, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;
    Ok(DtiDeviceArrayF32Py { inner })
}
1279
/// Parameter sweep for batch DTI: one `(start, end, step)` triple per period.
///
/// Each axis is expanded inclusively by `expand_grid`. A `step` of `0`, or
/// `start == end`, yields just `start`. `start > end` walks downward.
#[derive(Clone, Debug)]
pub struct DtiBatchRange {
    /// Sweep for the `r` period.
    pub r: (usize, usize, usize),
    /// Sweep for the `s` period.
    pub s: (usize, usize, usize),
    /// Sweep for the `u` period.
    pub u: (usize, usize, usize),
}

impl Default for DtiBatchRange {
    /// Sweeps `r` across `14..=263` in steps of 1, with `s = 10` and `u = 5` held fixed.
    fn default() -> Self {
        let fixed = |value: usize| (value, value, 0);
        DtiBatchRange {
            r: (14, 263, 1),
            s: fixed(10),
            u: fixed(5),
        }
    }
}
1296
/// Builder for batch DTI sweeps.
///
/// Holds the parameter ranges and the kernel that its `apply_*` methods pass to
/// `dti_batch_with_kernel`. The derived `Default` uses `DtiBatchRange::default()`
/// and `Kernel`'s own default.
#[derive(Clone, Debug, Default)]
pub struct DtiBatchBuilder {
    /// `(start, end, step)` sweeps for `r`, `s` and `u`.
    range: DtiBatchRange,
    /// Kernel selection; must be `Auto` or a batch kernel when applied.
    kernel: Kernel,
}
1302
1303impl DtiBatchBuilder {
1304    pub fn new() -> Self {
1305        Self::default()
1306    }
1307    pub fn kernel(mut self, k: Kernel) -> Self {
1308        self.kernel = k;
1309        self
1310    }
1311    #[inline]
1312    pub fn r_range(mut self, start: usize, end: usize, step: usize) -> Self {
1313        self.range.r = (start, end, step);
1314        self
1315    }
1316    #[inline]
1317    pub fn r_static(mut self, p: usize) -> Self {
1318        self.range.r = (p, p, 0);
1319        self
1320    }
1321    #[inline]
1322    pub fn s_range(mut self, start: usize, end: usize, step: usize) -> Self {
1323        self.range.s = (start, end, step);
1324        self
1325    }
1326    #[inline]
1327    pub fn s_static(mut self, x: usize) -> Self {
1328        self.range.s = (x, x, 0);
1329        self
1330    }
1331    #[inline]
1332    pub fn u_range(mut self, start: usize, end: usize, step: usize) -> Self {
1333        self.range.u = (start, end, step);
1334        self
1335    }
1336    #[inline]
1337    pub fn u_static(mut self, s: usize) -> Self {
1338        self.range.u = (s, s, 0);
1339        self
1340    }
1341    pub fn apply_slices(self, high: &[f64], low: &[f64]) -> Result<DtiBatchOutput, DtiError> {
1342        dti_batch_with_kernel(high, low, &self.range, self.kernel)
1343    }
1344    pub fn with_default_slices(
1345        high: &[f64],
1346        low: &[f64],
1347        k: Kernel,
1348    ) -> Result<DtiBatchOutput, DtiError> {
1349        DtiBatchBuilder::new().kernel(k).apply_slices(high, low)
1350    }
1351    pub fn apply_candles(self, c: &Candles) -> Result<DtiBatchOutput, DtiError> {
1352        let high = c
1353            .select_candle_field("high")
1354            .map_err(|e| DtiError::CandleFieldError(e.to_string()))?;
1355        let low = c
1356            .select_candle_field("low")
1357            .map_err(|e| DtiError::CandleFieldError(e.to_string()))?;
1358        self.apply_slices(high, low)
1359    }
1360    pub fn with_default_candles(c: &Candles) -> Result<DtiBatchOutput, DtiError> {
1361        DtiBatchBuilder::new().kernel(Kernel::Auto).apply_candles(c)
1362    }
1363}
1364
/// Result of a batch DTI run: a row-major `rows x cols` matrix with one row per
/// parameter combination.
#[derive(Clone, Debug)]
pub struct DtiBatchOutput {
    /// Flattened row-major values; row `i` occupies `values[i * cols..(i + 1) * cols]`.
    pub values: Vec<f64>,
    /// Parameter combination of each row, in row order.
    pub combos: Vec<DtiParams>,
    /// Number of rows (`combos.len()` for outputs built by the batch functions here).
    pub rows: usize,
    /// Number of columns (the length of the input series).
    pub cols: usize,
}
1372
1373impl DtiBatchOutput {
1374    pub fn row_for_params(&self, p: &DtiParams) -> Option<usize> {
1375        self.combos.iter().position(|c| {
1376            c.r.unwrap_or(14) == p.r.unwrap_or(14)
1377                && c.s.unwrap_or(10) == p.s.unwrap_or(10)
1378                && c.u.unwrap_or(5) == p.u.unwrap_or(5)
1379        })
1380    }
1381    pub fn values_for(&self, p: &DtiParams) -> Option<&[f64]> {
1382        self.row_for_params(p).map(|row| {
1383            let start = row * self.cols;
1384            &self.values[start..start + self.cols]
1385        })
1386    }
1387}
1388
1389#[inline(always)]
1390fn expand_grid(r: &DtiBatchRange) -> Vec<DtiParams> {
1391    #[inline(always)]
1392    fn axis_usize((start, end, step): (usize, usize, usize)) -> Vec<usize> {
1393        if step == 0 || start == end {
1394            return vec![start];
1395        }
1396        let mut v = Vec::new();
1397        if start < end {
1398            let mut cur = start;
1399            while cur <= end {
1400                v.push(cur);
1401                match cur.checked_add(step) {
1402                    Some(nxt) => {
1403                        if nxt > end {
1404                            break;
1405                        }
1406                        cur = nxt;
1407                    }
1408                    None => break,
1409                }
1410            }
1411        } else {
1412            let mut cur = start;
1413            loop {
1414                if cur < end {
1415                    break;
1416                }
1417                v.push(cur);
1418                if cur == end {
1419                    break;
1420                }
1421                match cur.checked_sub(step) {
1422                    Some(nxt) => {
1423                        if nxt < end {
1424                            break;
1425                        }
1426                        cur = nxt;
1427                    }
1428                    None => break,
1429                }
1430            }
1431        }
1432        v
1433    }
1434
1435    let rr = axis_usize(r.r);
1436    let ss = axis_usize(r.s);
1437    let uu = axis_usize(r.u);
1438
1439    let cap = rr
1440        .len()
1441        .checked_mul(ss.len())
1442        .and_then(|x| x.checked_mul(uu.len()))
1443        .unwrap_or(0);
1444    let mut out = Vec::with_capacity(cap);
1445    for &rv in &rr {
1446        for &sv in &ss {
1447            for &uv in &uu {
1448                out.push(DtiParams {
1449                    r: Some(rv),
1450                    s: Some(sv),
1451                    u: Some(uv),
1452                });
1453            }
1454        }
1455    }
1456    out
1457}
1458
1459#[inline(always)]
1460pub fn dti_batch_with_kernel(
1461    high: &[f64],
1462    low: &[f64],
1463    sweep: &DtiBatchRange,
1464    k: Kernel,
1465) -> Result<DtiBatchOutput, DtiError> {
1466    let kernel = match k {
1467        Kernel::Auto => detect_best_batch_kernel(),
1468        other if other.is_batch() => other,
1469        other => return Err(DtiError::InvalidKernelForBatch(other)),
1470    };
1471    let simd = match kernel {
1472        Kernel::Avx512Batch => Kernel::Avx512,
1473        Kernel::Avx2Batch => Kernel::Avx2,
1474        Kernel::ScalarBatch => Kernel::Scalar,
1475        _ => unreachable!(),
1476    };
1477    dti_batch_par_slice(high, low, sweep, simd)
1478}
1479
1480#[inline(always)]
1481pub fn dti_batch_slice(
1482    high: &[f64],
1483    low: &[f64],
1484    sweep: &DtiBatchRange,
1485    kern: Kernel,
1486) -> Result<DtiBatchOutput, DtiError> {
1487    dti_batch_inner(high, low, sweep, kern, false)
1488}
1489
1490#[inline(always)]
1491pub fn dti_batch_par_slice(
1492    high: &[f64],
1493    low: &[f64],
1494    sweep: &DtiBatchRange,
1495    kern: Kernel,
1496) -> Result<DtiBatchOutput, DtiError> {
1497    dti_batch_inner(high, low, sweep, kern, true)
1498}
1499
/// Shared implementation behind `dti_batch_slice` and `dti_batch_par_slice`.
///
/// Evaluates DTI for every `(r, s, u)` combination from `expand_grid(sweep)` and returns
/// a row-major `rows x cols` matrix: one row per combination, with `cols == high.len()`.
///
/// `kern` selects the per-row routine:
/// - `Scalar`/`Auto` use the scalar path.
/// - `Avx2`/`Avx512` use SIMD paths only when built with `nightly-avx` on x86_64.
/// - Any other value falls back to scalar.
///
/// On wasm32 with `simd128`, `Scalar`/`Auto` rows use `dti_row_simd128` instead. When
/// `parallel` is true and the target is not wasm32, rows are filled with rayon's
/// `par_chunks_mut`; otherwise they are filled sequentially.
///
/// # Errors
/// - `InvalidRange` if the sweep expands to no combinations (defensive; `expand_grid`
///   always yields at least one) or if `rows * cols` overflows.
/// - `LengthMismatch` if `high` and `low` differ in length.
/// - `AllValuesNaN` if no index has both `high` and `low` non-NaN, including empty input.
/// - `NotEnoughValidData` if fewer values remain from the first valid index than the
///   largest single period across all combinations.
#[inline(always)]
fn dti_batch_inner(
    high: &[f64],
    low: &[f64],
    sweep: &DtiBatchRange,
    kern: Kernel,
    parallel: bool,
) -> Result<DtiBatchOutput, DtiError> {
    let combos = expand_grid(sweep);
    if combos.is_empty() {
        return Err(DtiError::InvalidRange {
            start: 0,
            end: 0,
            step: 0,
        });
    }
    let len = high.len();
    if low.len() != len {
        return Err(DtiError::LengthMismatch {
            high: high.len(),
            low: low.len(),
        });
    }
    // First index where both series are non-NaN; all rows start from here.
    let first_valid = (0..len)
        .find(|&i| !high[i].is_nan() && !low[i].is_nan())
        .ok_or(DtiError::AllValuesNaN)?;
    // Largest of r/s/u over every combination, used as the minimum valid length.
    let max_p = combos
        .iter()
        .map(|c| c.r.unwrap().max(c.s.unwrap()).max(c.u.unwrap()))
        .max()
        .unwrap();
    if len - first_valid < max_p {
        return Err(DtiError::NotEnoughValidData {
            needed: max_p,
            valid: len - first_valid,
        });
    }
    let rows = combos.len();
    let cols = len;
    let expected = rows.checked_mul(cols).ok_or(DtiError::InvalidRange {
        start: 0,
        end: 0,
        step: 0,
    })?;

    // Every row gets the same prefix length, first_valid + 1. init_matrix_prefixes
    // presumably writes NaN there (judging by the helper names); confirm in helpers.
    let warmup_periods: Vec<usize> = combos.iter().map(|_| first_valid + 1).collect();

    let mut buf_mu = make_uninit_matrix(rows, cols);
    init_matrix_prefixes(&mut buf_mu, cols, &warmup_periods);

    // NOTE(review): this `&mut [f64]` view covers memory that is not fully initialized
    // yet. Soundness relies on the row kernels writing every index from `start` onward
    // before anything reads it.
    let uninit_ptr = buf_mu.as_mut_ptr();
    let values = unsafe { std::slice::from_raw_parts_mut(uninit_ptr as *mut f64, expected) };

    // Per-series bases computed once and shared by every row. The names suggest a raw
    // and an absolute (x / |x|) term; verify in dti_precompute_base.
    let start = first_valid + 1;
    let (x_base, ax_base) = dti_precompute_base(high, low, start);

    // Fills one output row for combination `row`. The body is `unsafe`, presumably
    // because the row kernels are `unsafe fn`s.
    let do_row = |row: usize, out_row: &mut [f64]| unsafe {
        let prm = &combos[row];
        // wasm simd128 path: works on the raw series and first_valid, not the bases.
        #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
        {
            if matches!(kern, Kernel::Scalar | Kernel::Auto) {
                dti_row_simd128(
                    high,
                    low,
                    prm.r.unwrap(),
                    prm.s.unwrap(),
                    prm.u.unwrap(),
                    first_valid,
                    out_row,
                );
                return;
            }
        }

        match kern {
            Kernel::Scalar | Kernel::Auto => dti_row_scalar_from_base(
                &x_base,
                &ax_base,
                prm.r.unwrap(),
                prm.s.unwrap(),
                prm.u.unwrap(),
                start,
                out_row,
            ),
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx2 => dti_row_avx2_from_base(
                &x_base,
                &ax_base,
                prm.r.unwrap(),
                prm.s.unwrap(),
                prm.u.unwrap(),
                start,
                out_row,
            ),
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx512 => dti_row_avx512_from_base(
                &x_base,
                &ax_base,
                prm.r.unwrap(),
                prm.s.unwrap(),
                prm.u.unwrap(),
                start,
                out_row,
            ),
            // Any other kernel (including SIMD kernels when not compiled in) uses scalar.
            _ => dti_row_scalar_from_base(
                &x_base,
                &ax_base,
                prm.r.unwrap(),
                prm.s.unwrap(),
                prm.u.unwrap(),
                start,
                out_row,
            ),
        }
    };
    if parallel {
        #[cfg(not(target_arch = "wasm32"))]
        {
            values
                .par_chunks_mut(cols)
                .enumerate()
                .for_each(|(row, slice)| do_row(row, slice));
        }

        // wasm32 has no rayon here, so "parallel" degrades to sequential.
        #[cfg(target_arch = "wasm32")]
        {
            for (row, slice) in values.chunks_mut(cols).enumerate() {
                do_row(row, slice);
            }
        }
    } else {
        for (row, slice) in values.chunks_mut(cols).enumerate() {
            do_row(row, slice);
        }
    }

    // Re-own the filled buffer as Vec<f64> and forget the original handle, so the
    // allocation is freed exactly once. NOTE(review): Vec::from_raw_parts requires the
    // original capacity. This assumes make_uninit_matrix returns a Vec whose capacity
    // is exactly rows * cols; confirm in helpers.
    let values = unsafe { Vec::from_raw_parts(uninit_ptr as *mut f64, expected, expected) };
    std::mem::forget(buf_mu);

    Ok(DtiBatchOutput {
        values,
        combos,
        rows,
        cols,
    })
}
1646
/// Batch DTI written into a caller-provided buffer `out` of length `rows * cols`.
///
/// Returns the expanded parameter combinations in row order. Rows follow
/// `expand_grid(sweep)` and each row is `high.len()` long.
///
/// Unlike `dti_batch_inner`, the rows are computed with the `dti_row_*` kernels directly
/// on `high`/`low` and `first_valid`, not with precomputed bases. Kernel selection and
/// `parallel` behave as in `dti_batch_inner`. All validation happens before `out` is
/// touched, so on error `out` is left unmodified.
///
/// # Errors
/// Same as `dti_batch_inner`, plus `OutputLengthMismatch` when `out.len()` is not
/// `rows * cols`.
#[inline(always)]
pub fn dti_batch_inner_into(
    high: &[f64],
    low: &[f64],
    sweep: &DtiBatchRange,
    kern: Kernel,
    parallel: bool,
    out: &mut [f64],
) -> Result<Vec<DtiParams>, DtiError> {
    let combos = expand_grid(sweep);
    if combos.is_empty() {
        return Err(DtiError::InvalidRange {
            start: 0,
            end: 0,
            step: 0,
        });
    }
    let len = high.len();
    if low.len() != len {
        return Err(DtiError::LengthMismatch {
            high: high.len(),
            low: low.len(),
        });
    }
    // First index where both series are non-NaN.
    let first_valid = (0..len)
        .find(|&i| !high[i].is_nan() && !low[i].is_nan())
        .ok_or(DtiError::AllValuesNaN)?;
    // Largest single period across all combinations.
    let max_p = combos
        .iter()
        .map(|c| c.r.unwrap().max(c.s.unwrap()).max(c.u.unwrap()))
        .max()
        .unwrap();
    if len - first_valid < max_p {
        return Err(DtiError::NotEnoughValidData {
            needed: max_p,
            valid: len - first_valid,
        });
    }

    let rows = combos.len();
    let cols = len;
    let expected = rows.checked_mul(cols).ok_or(DtiError::InvalidRange {
        start: 0,
        end: 0,
        step: 0,
    })?;
    if out.len() != expected {
        return Err(DtiError::OutputLengthMismatch {
            expected,
            got: out.len(),
        });
    }

    // Same warmup prefix for every row, written by init_matrix_prefixes, which takes a
    // MaybeUninit slice, so `out` is viewed through that type temporarily.
    let warm: Vec<usize> = vec![first_valid + 1; rows];
    let out_mu: &mut [std::mem::MaybeUninit<f64>] = unsafe {
        std::slice::from_raw_parts_mut(
            out.as_mut_ptr() as *mut std::mem::MaybeUninit<f64>,
            out.len(),
        )
    };
    init_matrix_prefixes(out_mu, cols, &warm);

    let values: &mut [f64] =
        unsafe { std::slice::from_raw_parts_mut(out_mu.as_mut_ptr() as *mut f64, out_mu.len()) };

    // Fills one output row for combination `row`. The body is `unsafe`, presumably
    // because the row kernels are `unsafe fn`s.
    let do_row = |row: usize, row_slice: &mut [f64]| unsafe {
        let prm = &combos[row];
        #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
        {
            if matches!(kern, Kernel::Scalar | Kernel::Auto) {
                dti_row_simd128(
                    high,
                    low,
                    prm.r.unwrap(),
                    prm.s.unwrap(),
                    prm.u.unwrap(),
                    first_valid,
                    row_slice,
                );
                return;
            }
        }
        match kern {
            Kernel::Scalar | Kernel::Auto => dti_row_scalar(
                high,
                low,
                prm.r.unwrap(),
                prm.s.unwrap(),
                prm.u.unwrap(),
                first_valid,
                row_slice,
            ),
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx2 => dti_row_avx2(
                high,
                low,
                prm.r.unwrap(),
                prm.s.unwrap(),
                prm.u.unwrap(),
                first_valid,
                row_slice,
            ),
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx512 => dti_row_avx512(
                high,
                low,
                prm.r.unwrap(),
                prm.s.unwrap(),
                prm.u.unwrap(),
                first_valid,
                row_slice,
            ),
            // Any other kernel (including SIMD kernels when not compiled in) uses scalar.
            _ => dti_row_scalar(
                high,
                low,
                prm.r.unwrap(),
                prm.s.unwrap(),
                prm.u.unwrap(),
                first_valid,
                row_slice,
            ),
        }
    };

    if parallel {
        #[cfg(not(target_arch = "wasm32"))]
        values
            .par_chunks_mut(cols)
            .enumerate()
            .for_each(|(row, sl)| do_row(row, sl));
        // wasm32: no rayon, fall back to sequential rows.
        #[cfg(target_arch = "wasm32")]
        for (row, sl) in values.chunks_mut(cols).enumerate() {
            do_row(row, sl);
        }
    } else {
        for (row, sl) in values.chunks_mut(cols).enumerate() {
            do_row(row, sl);
        }
    }

    Ok(combos)
}
1789
1790#[inline(always)]
1791pub fn expand_grid_dti(r: &DtiBatchRange) -> Vec<DtiParams> {
1792    expand_grid(r)
1793}
1794
/// wasm entry point: computes DTI over `high`/`low` with periods `r`, `s`, `u` using
/// `Kernel::Auto`, returning the values or the error message as a JS string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dti_js(
    high: &[f64],
    low: &[f64],
    r: usize,
    s: usize,
    u: usize,
) -> Result<Vec<f64>, JsValue> {
    let input = DtiInput {
        data: DtiData::Slices { high, low },
        params: DtiParams {
            r: Some(r),
            s: Some(s),
            u: Some(u),
        },
    };
    dti_with_kernel(&input, Kernel::Auto)
        .map(|computed| computed.values)
        .map_err(|e| JsValue::from_str(&e.to_string()))
}
1817
/// wasm entry point: computes DTI for `len` samples read from `high_ptr`/`low_ptr`
/// and writes the result to `out_ptr`.
///
/// If `out_ptr` equals either input pointer, the result is computed into a scratch
/// buffer first and then copied out. The caller must ensure all three pointers address
/// at least `len` valid `f64`s.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dti_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    r: usize,
    s: usize,
    u: usize,
) -> Result<(), JsValue> {
    let any_null = high_ptr.is_null() || low_ptr.is_null() || out_ptr.is_null();
    if any_null {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    let input = DtiInput {
        data: DtiData::Slices {
            high: unsafe { std::slice::from_raw_parts(high_ptr, len) },
            low: unsafe { std::slice::from_raw_parts(low_ptr, len) },
        },
        params: DtiParams {
            r: Some(r),
            s: Some(s),
            u: Some(u),
        },
    };

    let aliased = out_ptr as *const f64 == high_ptr || out_ptr as *const f64 == low_ptr;
    if aliased {
        let mut scratch = vec![0.0; len];
        dti_into_slice(&mut scratch, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // The inputs are no longer read, so the output view can be created now.
        unsafe { std::slice::from_raw_parts_mut(out_ptr, len) }.copy_from_slice(&scratch);
    } else {
        let dst = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        dti_into_slice(dst, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }
    Ok(())
}
1858
/// Reserves room for `len` `f64`s in wasm memory and returns the (uninitialized)
/// buffer's pointer. Release it with `dti_free(ptr, len)`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dti_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop keeps the allocation alive after this function returns.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1867
/// Releases a buffer obtained from `dti_alloc(len)`. A null pointer is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dti_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // Rebuild the Vec so dropping it frees the allocation. NOTE(review): this assumes
    // dti_alloc's Vec ended up with capacity exactly `len`.
    unsafe { drop(Vec::from_raw_parts(ptr, len, len)) };
}
1877
/// Configuration object accepted by the JS `dti_batch` function, deserialized with
/// `serde_wasm_bindgen`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct DtiBatchConfig {
    /// `(start, end, step)` sweep for `r`.
    pub r_range: (usize, usize, usize),
    /// `(start, end, step)` sweep for `s`.
    pub s_range: (usize, usize, usize),
    /// `(start, end, step)` sweep for `u`.
    pub u_range: (usize, usize, usize),
}
1885
/// Serializable result returned to JS by `dti_batch`; mirrors `DtiBatchOutput`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct DtiBatchJsOutput {
    /// Flattened row-major `rows x cols` values.
    pub values: Vec<f64>,
    /// Parameter combination of each row.
    pub combos: Vec<DtiParams>,
    /// Number of parameter combinations.
    pub rows: usize,
    /// Input series length.
    pub cols: usize,
}
1894
/// JS `dti_batch`: expands the sweep described by `config` (a `DtiBatchConfig`),
/// computes it sequentially, and returns a serialized `DtiBatchJsOutput`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = dti_batch)]
pub fn dti_batch_js(high: &[f64], low: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let cfg: DtiBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = DtiBatchRange {
        r: cfg.r_range,
        s: cfg.s_range,
        u: cfg.u_range,
    };

    let DtiBatchOutput {
        values,
        combos,
        rows,
        cols,
    } = dti_batch_inner(high, low, &sweep, detect_best_kernel(), false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let payload = DtiBatchJsOutput {
        values,
        combos,
        rows,
        cols,
    };
    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
1919
/// wasm entry point: computes the full `(r, s, u)` sweep over `len` samples read from
/// `high_ptr`/`low_ptr` and writes the row-major matrix to `out_ptr`.
///
/// `out_ptr` must have room for `rows * len` `f64`s. Returns the number of rows, i.e.
/// parameter combinations. If `out_ptr` equals either input pointer, results go through
/// a scratch buffer first.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dti_batch_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    r_start: usize,
    r_end: usize,
    r_step: usize,
    s_start: usize,
    s_end: usize,
    s_step: usize,
    u_start: usize,
    u_end: usize,
    u_step: usize,
) -> Result<usize, JsValue> {
    if high_ptr.is_null() || low_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided to dti_batch_into"));
    }

    /// Number of values `expand_grid` produces for one `(start, end, step)` axis.
    fn axis_count(start: usize, end: usize, step: usize) -> usize {
        if step == 0 || start == end {
            return 1;
        }
        if start < end {
            ((end - start) / step) + 1
        } else {
            ((start - end) / step) + 1
        }
    }

    // SAFETY: the caller guarantees both input pointers address `len` readable f64s.
    let high = unsafe { std::slice::from_raw_parts(high_ptr, len) };
    let low = unsafe { std::slice::from_raw_parts(low_ptr, len) };

    let sweep = DtiBatchRange {
        r: (r_start, r_end, r_step),
        s: (s_start, s_end, s_step),
        u: (u_start, u_end, u_step),
    };

    let r_count = axis_count(r_start, r_end, r_step);
    let s_count = axis_count(s_start, s_end, s_step);
    let u_count = axis_count(u_start, u_end, u_step);
    // Error values are built lazily so nothing is allocated on the success path.
    let total_rows = r_count
        .checked_mul(s_count)
        .and_then(|x| x.checked_mul(u_count))
        .ok_or_else(|| JsValue::from_str("range expansion overflow in dti_batch_into"))?;
    let total_len = total_rows
        .checked_mul(len)
        .ok_or_else(|| JsValue::from_str("size overflow in dti_batch_into"))?;

    if out_ptr as *const f64 == high_ptr || out_ptr as *const f64 == low_ptr {
        // Aliased output: compute into scratch. The `&mut` view of `out` is created
        // only after the last read of `high`/`low`, so it never coexists with the
        // shared input slices, as `from_raw_parts_mut` requires.
        let mut temp_values = vec![0.0; total_len];
        dti_batch_inner_into(high, low, &sweep, Kernel::Auto, false, &mut temp_values)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // SAFETY: the caller guarantees `out_ptr` addresses `total_len` writable f64s.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total_len) };
        out.copy_from_slice(&temp_values);
    } else {
        // SAFETY: non-aliased output region of `total_len` writable f64s (caller contract).
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total_len) };
        dti_batch_inner_into(high, low, &sweep, Kernel::Auto, false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(total_rows)
}
1990
1991#[cfg(test)]
1992mod tests {
1993    use super::*;
1994    use crate::skip_if_unsupported;
1995    use crate::utilities::data_loader::read_candles_from_csv;
1996    fn check_dti_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1997        skip_if_unsupported!(kernel, test_name);
1998        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1999        let candles = read_candles_from_csv(file_path)?;
2000        let default_params = DtiParams {
2001            r: None,
2002            s: None,
2003            u: None,
2004        };
2005        let input = DtiInput::from_candles(&candles, default_params);
2006        let output = dti_with_kernel(&input, kernel)?;
2007        assert_eq!(output.values.len(), candles.close.len());
2008        Ok(())
2009    }
2010    fn check_dti_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2011        skip_if_unsupported!(kernel, test_name);
2012        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2013        let candles = read_candles_from_csv(file_path)?;
2014        let input = DtiInput::with_default_candles(&candles);
2015        let result = dti_with_kernel(&input, kernel)?;
2016        let expected_last_five = [
2017            -39.0091620347991,
2018            -39.75219264093014,
2019            -40.53941417932286,
2020            -41.2787749205189,
2021            -42.93758699380749,
2022        ];
2023        let start = result.values.len().saturating_sub(5);
2024        for (i, &val) in result.values[start..].iter().enumerate() {
2025            let diff = (val - expected_last_five[i]).abs();
2026            assert!(
2027                diff < 1e-6,
2028                "[{}] DTI {:?} mismatch at idx {}: got {}, expected {}",
2029                test_name,
2030                kernel,
2031                i,
2032                val,
2033                expected_last_five[i]
2034            );
2035        }
2036        Ok(())
2037    }
2038    fn check_dti_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2039        skip_if_unsupported!(kernel, test_name);
2040        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2041        let candles = read_candles_from_csv(file_path)?;
2042        let input = DtiInput::with_default_candles(&candles);
2043        let output = dti_with_kernel(&input, kernel)?;
2044        assert_eq!(output.values.len(), candles.close.len());
2045        Ok(())
2046    }
2047    fn check_dti_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2048        skip_if_unsupported!(kernel, test_name);
2049        let high = [10.0, 11.0, 12.0];
2050        let low = [9.0, 10.0, 11.0];
2051        let params = DtiParams {
2052            r: Some(0),
2053            s: Some(10),
2054            u: Some(5),
2055        };
2056        let input = DtiInput::from_slices(&high, &low, params);
2057        let res = dti_with_kernel(&input, kernel);
2058        assert!(
2059            res.is_err(),
2060            "[{}] DTI should fail with zero period",
2061            test_name
2062        );
2063        Ok(())
2064    }
2065    fn check_dti_period_exceeds_length(
2066        test_name: &str,
2067        kernel: Kernel,
2068    ) -> Result<(), Box<dyn Error>> {
2069        skip_if_unsupported!(kernel, test_name);
2070        let high = [10.0, 11.0];
2071        let low = [9.0, 10.0];
2072        let params = DtiParams {
2073            r: Some(14),
2074            s: Some(10),
2075            u: Some(5),
2076        };
2077        let input = DtiInput::from_slices(&high, &low, params);
2078        let res = dti_with_kernel(&input, kernel);
2079        assert!(
2080            res.is_err(),
2081            "[{}] DTI should fail with period exceeding length",
2082            test_name
2083        );
2084        Ok(())
2085    }
2086    fn check_dti_all_nan(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2087        skip_if_unsupported!(kernel, test_name);
2088        let high = [f64::NAN, f64::NAN, f64::NAN];
2089        let low = [f64::NAN, f64::NAN, f64::NAN];
2090        let params = DtiParams::default();
2091        let input = DtiInput::from_slices(&high, &low, params);
2092        let res = dti_with_kernel(&input, kernel);
2093        assert!(res.is_err(), "[{}] DTI should fail with all NaN", test_name);
2094        Ok(())
2095    }
2096    fn check_dti_empty_data(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2097        skip_if_unsupported!(kernel, test_name);
2098        let high: [f64; 0] = [];
2099        let low: [f64; 0] = [];
2100        let params = DtiParams::default();
2101        let input = DtiInput::from_slices(&high, &low, params);
2102        let res = dti_with_kernel(&input, kernel);
2103        assert!(
2104            res.is_err(),
2105            "[{}] DTI should fail with empty data",
2106            test_name
2107        );
2108        Ok(())
2109    }
2110    fn check_dti_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2111        skip_if_unsupported!(kernel, test_name);
2112        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2113        let candles = read_candles_from_csv(file_path)?;
2114        let high = candles.select_candle_field("high")?;
2115        let low = candles.select_candle_field("low")?;
2116        let params = DtiParams::default();
2117        let input = DtiInput::from_slices(high, low, params.clone());
2118        let batch_output = dti_with_kernel(&input, kernel)?.values;
2119        let mut stream = DtiStream::try_new(params)?;
2120        let mut stream_values = Vec::with_capacity(high.len());
2121        for (&h, &l) in high.iter().zip(low.iter()) {
2122            match stream.update(h, l) {
2123                Some(val) => stream_values.push(val),
2124                None => stream_values.push(f64::NAN),
2125            }
2126        }
2127        assert_eq!(batch_output.len(), stream_values.len());
2128        for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
2129            if b.is_nan() && s.is_nan() {
2130                continue;
2131            }
2132            let diff = (b - s).abs();
2133            assert!(
2134                diff < 1e-9,
2135                "[{}] DTI streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
2136                test_name,
2137                i,
2138                b,
2139                s,
2140                diff
2141            );
2142        }
2143        Ok(())
2144    }
2145
2146    fn check_dti_length_mismatch(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2147        skip_if_unsupported!(kernel, test_name);
2148
2149        let high = vec![10.0, 11.0, 12.0];
2150        let low = vec![9.0, 10.0];
2151        let params = DtiParams::default();
2152        let input = DtiInput::from_slices(&high, &low, params);
2153
2154        let result = dti_with_kernel(&input, kernel);
2155        assert!(
2156            result.is_err(),
2157            "[{}] Should fail with mismatched lengths",
2158            test_name
2159        );
2160
2161        match result.unwrap_err() {
2162            DtiError::LengthMismatch { high: h, low: l } => {
2163                assert_eq!(h, 3, "[{}] High length should be 3", test_name);
2164                assert_eq!(l, 2, "[{}] Low length should be 2", test_name);
2165            }
2166            e => panic!(
2167                "[{}] Expected LengthMismatch error, got: {:?}",
2168                test_name, e
2169            ),
2170        }
2171
2172        let mut out = vec![0.0; high.len()];
2173        let result = dti_into_slice(&mut out, &input, kernel);
2174        assert!(
2175            result.is_err(),
2176            "[{}] dti_into_slice should fail with mismatched lengths",
2177            test_name
2178        );
2179
2180        match result.unwrap_err() {
2181            DtiError::LengthMismatch { high: h, low: l } => {
2182                assert_eq!(h, 3, "[{}] High length should be 3", test_name);
2183                assert_eq!(l, 2, "[{}] Low length should be 2", test_name);
2184            }
2185            e => panic!(
2186                "[{}] Expected LengthMismatch error from dti_into_slice, got: {:?}",
2187                test_name, e
2188            ),
2189        }
2190
2191        let sweep = DtiBatchRange::default();
2192        let result = dti_batch_inner(&high, &low, &sweep, kernel, false);
2193        assert!(
2194            result.is_err(),
2195            "[{}] Batch should fail with mismatched lengths",
2196            test_name
2197        );
2198
2199        match result.unwrap_err() {
2200            DtiError::LengthMismatch { high: h, low: l } => {
2201                assert_eq!(h, 3, "[{}] Batch high length should be 3", test_name);
2202                assert_eq!(l, 2, "[{}] Batch low length should be 2", test_name);
2203            }
2204            e => panic!(
2205                "[{}] Expected LengthMismatch error from batch, got: {:?}",
2206                test_name, e
2207            ),
2208        }
2209
2210        Ok(())
2211    }
2212
2213    #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
2214    fn check_dti_wasm_kernel_fallback(
2215        test_name: &str,
2216        _kernel: Kernel,
2217    ) -> Result<(), Box<dyn Error>> {
2218        let high = vec![10.0, 11.0, 12.0, 13.0, 14.0];
2219        let low = vec![9.0, 10.0, 11.0, 12.0, 13.0];
2220        let params = DtiParams::default();
2221        let input = DtiInput::from_slices(&high, &low, params);
2222
2223        let scalar_result = dti_with_kernel(&input, Kernel::Scalar)?;
2224
2225        let avx2_result = dti_with_kernel(&input, Kernel::Avx2)?;
2226        assert_eq!(
2227            scalar_result.values.len(),
2228            avx2_result.values.len(),
2229            "[{}] AVX2 fallback length mismatch",
2230            test_name
2231        );
2232        for (i, (s, a)) in scalar_result
2233            .values
2234            .iter()
2235            .zip(avx2_result.values.iter())
2236            .enumerate()
2237        {
2238            if s.is_nan() && a.is_nan() {
2239                continue;
2240            }
2241            assert!(
2242                (s - a).abs() < 1e-10,
2243                "[{}] AVX2 fallback mismatch at {}: scalar={}, avx2={}",
2244                test_name,
2245                i,
2246                s,
2247                a
2248            );
2249        }
2250
2251        let avx512_result = dti_with_kernel(&input, Kernel::Avx512)?;
2252        assert_eq!(
2253            scalar_result.values.len(),
2254            avx512_result.values.len(),
2255            "[{}] AVX512 fallback length mismatch",
2256            test_name
2257        );
2258        for (i, (s, a)) in scalar_result
2259            .values
2260            .iter()
2261            .zip(avx512_result.values.iter())
2262            .enumerate()
2263        {
2264            if s.is_nan() && a.is_nan() {
2265                continue;
2266            }
2267            assert!(
2268                (s - a).abs() < 1e-10,
2269                "[{}] AVX512 fallback mismatch at {}: scalar={}, avx512={}",
2270                test_name,
2271                i,
2272                s,
2273                a
2274            );
2275        }
2276
2277        let mut out_scalar = vec![0.0; high.len()];
2278        let mut out_avx = vec![0.0; high.len()];
2279
2280        dti_into_slice(&mut out_scalar, &input, Kernel::Scalar)?;
2281        dti_into_slice(&mut out_avx, &input, Kernel::Avx2)?;
2282
2283        for (i, (s, a)) in out_scalar.iter().zip(out_avx.iter()).enumerate() {
2284            if s.is_nan() && a.is_nan() {
2285                continue;
2286            }
2287            assert!(
2288                (s - a).abs() < 1e-10,
2289                "[{}] dti_into_slice AVX2 fallback mismatch at {}",
2290                test_name,
2291                i
2292            );
2293        }
2294
2295        Ok(())
2296    }
2297
2298    #[cfg(not(all(target_arch = "wasm32", target_feature = "simd128")))]
2299    fn check_dti_wasm_kernel_fallback(
2300        _test_name: &str,
2301        _kernel: Kernel,
2302    ) -> Result<(), Box<dyn Error>> {
2303        Ok(())
2304    }
2305
2306    #[cfg(debug_assertions)]
2307    fn check_dti_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2308        skip_if_unsupported!(kernel, test_name);
2309
2310        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2311        let candles = read_candles_from_csv(file_path)?;
2312
2313        let test_params = vec![
2314            DtiParams::default(),
2315            DtiParams {
2316                r: Some(1),
2317                s: Some(1),
2318                u: Some(1),
2319            },
2320            DtiParams {
2321                r: Some(5),
2322                s: Some(5),
2323                u: Some(5),
2324            },
2325            DtiParams {
2326                r: Some(20),
2327                s: Some(15),
2328                u: Some(10),
2329            },
2330            DtiParams {
2331                r: Some(50),
2332                s: Some(30),
2333                u: Some(20),
2334            },
2335            DtiParams {
2336                r: Some(100),
2337                s: Some(50),
2338                u: Some(25),
2339            },
2340            DtiParams {
2341                r: Some(14),
2342                s: Some(5),
2343                u: Some(20),
2344            },
2345            DtiParams {
2346                r: Some(30),
2347                s: Some(10),
2348                u: Some(5),
2349            },
2350            DtiParams {
2351                r: Some(10),
2352                s: Some(20),
2353                u: Some(15),
2354            },
2355            DtiParams {
2356                r: Some(2),
2357                s: Some(10),
2358                u: Some(5),
2359            },
2360        ];
2361
2362        for (param_idx, params) in test_params.iter().enumerate() {
2363            let input = DtiInput::from_candles(&candles, params.clone());
2364            let output = dti_with_kernel(&input, kernel)?;
2365
2366            for (i, &val) in output.values.iter().enumerate() {
2367                if val.is_nan() {
2368                    continue;
2369                }
2370
2371                let bits = val.to_bits();
2372
2373                if bits == 0x11111111_11111111 {
2374                    panic!(
2375                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
2376						 with params: r={}, s={}, u={} (param set {})",
2377                        test_name,
2378                        val,
2379                        bits,
2380                        i,
2381                        params.r.unwrap_or(14),
2382                        params.s.unwrap_or(10),
2383                        params.u.unwrap_or(5),
2384                        param_idx
2385                    );
2386                }
2387
2388                if bits == 0x22222222_22222222 {
2389                    panic!(
2390                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
2391						 with params: r={}, s={}, u={} (param set {})",
2392                        test_name,
2393                        val,
2394                        bits,
2395                        i,
2396                        params.r.unwrap_or(14),
2397                        params.s.unwrap_or(10),
2398                        params.u.unwrap_or(5),
2399                        param_idx
2400                    );
2401                }
2402
2403                if bits == 0x33333333_33333333 {
2404                    panic!(
2405                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
2406						 with params: r={}, s={}, u={} (param set {})",
2407                        test_name,
2408                        val,
2409                        bits,
2410                        i,
2411                        params.r.unwrap_or(14),
2412                        params.s.unwrap_or(10),
2413                        params.u.unwrap_or(5),
2414                        param_idx
2415                    );
2416                }
2417            }
2418        }
2419
2420        Ok(())
2421    }
2422
2423    #[cfg(not(debug_assertions))]
2424    fn check_dti_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2425        Ok(())
2426    }
2427
    /// Property test over synthetic random-walk high/low series.
    ///
    /// Generator: `max_period` is drawn from 2..=50, each of `r`, `s`, `u` from
    /// 1..=max_period, and the series length from `3 * max_period..400`. Prices
    /// follow a multiplicative random walk floored at 10.0. Each bar's range is
    /// derived from the drawn `volatility`, and lows are floored at 1.0.
    ///
    /// For each generated case the selected kernel must:
    /// * succeed and return one value per input bar, with a NaN first value;
    /// * keep finite outputs within [-100, 100] (plus 1e-4 slack);
    /// * agree with `Kernel::Scalar` (non-finite values bit-identical, finite values
    ///   within 1e-9 absolute or 10 ULPs);
    /// * satisfy the trend / stability heuristics described inline below.
    #[cfg(feature = "proptest")]
    fn check_dti_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        let strat = (2usize..=50)
            .prop_flat_map(|max_period| {
                // Guarantee at least three times the largest period of data.
                let min_len = max_period * 3;
                (
                    (100.0f64..5000.0f64, 0.01f64..0.1f64),
                    (
                        1usize..=max_period,
                        1usize..=max_period,
                        1usize..=max_period,
                    ),
                    min_len..400,
                )
            })
            .prop_flat_map(|((base_price, volatility), (r, s, u), len)| {
                // One change factor in [-1, 1) per bar.
                let price_changes = prop::collection::vec((-1.0f64..1.0f64), len);

                (
                    Just((base_price, volatility)),
                    Just((r, s, u)),
                    price_changes,
                )
            })
            .prop_map(|((base_price, volatility), (r, s, u), changes)| {
                let len = changes.len();
                let mut high = Vec::with_capacity(len);
                let mut low = Vec::with_capacity(len);
                let mut current_price = base_price;

                for change_factor in changes {
                    // Multiplicative step, floored so prices stay >= 10.0.
                    let change = change_factor * volatility * current_price;
                    current_price = (current_price + change).max(10.0);

                    // Bar range grows with |change_factor|; the bar centre is nudged
                    // in the direction of the move.
                    let daily_range = current_price * volatility * (0.5 + change_factor.abs());
                    let mid_adjustment = change_factor * daily_range * 0.25;

                    let daily_high = current_price + daily_range / 2.0 + mid_adjustment;
                    let daily_low = current_price - daily_range / 2.0 + mid_adjustment;

                    high.push(daily_high);
                    low.push(daily_low.max(1.0));
                }

                (high, low, r, s, u)
            });

        proptest::test_runner::TestRunner::default()
            .run(&strat, |(high, low, r, s, u)| {
                let params = DtiParams {
                    r: Some(r),
                    s: Some(s),
                    u: Some(u),
                };
                let input = DtiInput::from_slices(&high, &low, params);

                let result = dti_with_kernel(&input, kernel);
                prop_assert!(result.is_ok(), "DTI computation failed: {:?}", result);
                let DtiOutput { values: out } = result.unwrap();

                // Scalar-kernel output used as the cross-kernel reference.
                let DtiOutput { values: ref_out } =
                    dti_with_kernel(&input, Kernel::Scalar).unwrap();

                prop_assert_eq!(out.len(), high.len(), "Output length mismatch");

                prop_assert!(out[0].is_nan(), "First value should be NaN");

                // Range check over all finite outputs; the 1e-4 slack absorbs rounding.
                let finite_values: Vec<f64> =
                    out.iter().copied().filter(|v| v.is_finite()).collect();
                if !finite_values.is_empty() {
                    let max_val = finite_values
                        .iter()
                        .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
                    let min_val = finite_values.iter().fold(f64::INFINITY, |a, &b| a.min(b));

                    prop_assert!(
                        max_val <= 100.0001 && min_val >= -100.0001,
                        "DTI values exceed mathematical bounds: [{:.6}, {:.6}]",
                        min_val,
                        max_val
                    );
                }

                // Element-wise agreement with the scalar reference. Note: `r` below
                // shadows the period `r` only inside this loop body.
                for i in 0..out.len() {
                    let y = out[i];
                    let r = ref_out[i];

                    if !y.is_finite() || !r.is_finite() {
                        prop_assert_eq!(
                            y.to_bits(),
                            r.to_bits(),
                            "NaN/finite mismatch at idx {}: {} vs {}",
                            i,
                            y,
                            r
                        );
                    } else {
                        let diff = (y - r).abs();
                        let ulp_diff = y.to_bits().abs_diff(r.to_bits());
                        prop_assert!(
                            diff <= 1e-9 || ulp_diff <= 10,
                            "Kernel mismatch at idx {}: {} vs {} (diff={}, ulp={})",
                            i,
                            y,
                            r,
                            diff,
                            ulp_diff
                        );
                    }
                }

                // Strict trends: every consecutive pair (covered by the overlapping
                // 5-wide windows) moves by more than 0.1% in the same direction.
                let is_strong_uptrend = high
                    .windows(5)
                    .all(|w| w.windows(2).all(|pair| pair[1] > pair[0] * 1.001))
                    && low
                        .windows(5)
                        .all(|w| w.windows(2).all(|pair| pair[1] > pair[0] * 1.001));

                let is_strong_downtrend = high
                    .windows(5)
                    .all(|w| w.windows(2).all(|pair| pair[1] < pair[0] * 0.999))
                    && low
                        .windows(5)
                        .all(|w| w.windows(2).all(|pair| pair[1] < pair[0] * 0.999));

                // In a strict trend, the mean of the finite values among the last 10
                // outputs must carry the trend's sign.
                if (is_strong_uptrend || is_strong_downtrend) && out.len() > r + s + u + 10 {
                    let later_values: Vec<f64> = out[out.len() - 10..]
                        .iter()
                        .copied()
                        .filter(|v| v.is_finite())
                        .collect();
                    if later_values.len() >= 5 {
                        let avg = later_values.iter().sum::<f64>() / later_values.len() as f64;
                        if is_strong_uptrend {
                            prop_assert!(
                                avg > 0.0,
                                "DTI should be positive in strong uptrend: avg={:.2}",
                                avg
                            );
                        }
                        if is_strong_downtrend {
                            prop_assert!(
                                avg < 0.0,
                                "DTI should be negative in strong downtrend: avg={:.2}",
                                avg
                            );
                        }
                    }
                }

                // Debug builds only: no non-NaN output may equal a buffer-poison sentinel.
                #[cfg(debug_assertions)]
                for (i, &val) in out.iter().enumerate() {
                    if val.is_nan() {
                        continue;
                    }
                    let bits = val.to_bits();
                    prop_assert!(
                        bits != 0x11111111_11111111
                            && bits != 0x22222222_22222222
                            && bits != 0x33333333_33333333,
                        "Found poison value at index {}: {} (0x{:016X})",
                        i,
                        val,
                        bits
                    );
                }

                // With every period at 1, at least one output in indices 2..10 must be
                // finite and non-zero.
                if r == 1 && s == 1 && u == 1 && out.len() > 10 {
                    let responsive_values: Vec<f64> = out[2..10]
                        .iter()
                        .copied()
                        .filter(|v| v.is_finite() && v.abs() > 0.0)
                        .collect();
                    prop_assert!(
                        !responsive_values.is_empty(),
                        "DTI with period=1 should produce non-zero values quickly"
                    );
                }

                // A degenerate high == low series must give (near-)zero output.
                // NOTE(review): the generator above always produces daily_range > 0
                // (current_price >= 10, volatility >= 0.01), so this branch appears
                // unreachable with the current strategy. Confirm whether intended.
                let is_zero_volatility = high
                    .iter()
                    .zip(low.iter())
                    .all(|(h, l)| (h - l).abs() < 1e-10);
                if is_zero_volatility && out.len() > 2 {
                    for i in 2..out.len() {
                        if out[i].is_finite() {
                            prop_assert!(
                                out[i].abs() < 1e-10,
                                "DTI should be 0 with zero volatility at index {}: {}",
                                i,
                                out[i]
                            );
                        }
                    }
                }

                if high.len() >= 10 {
                    // "Constant spread": every bar's high-low is within 1% of the first bar's.
                    let spreads: Vec<f64> =
                        high.iter().zip(low.iter()).map(|(h, l)| h - l).collect();
                    let first_spread = spreads[0];
                    let is_constant_spread = spreads
                        .iter()
                        .all(|&s| (s - first_spread).abs() < first_spread * 0.01);

                    // "Stable": the mean absolute bar-to-bar change is below 0.1% of the
                    // first value, for both highs and lows.
                    let high_changes: Vec<f64> =
                        high.windows(2).map(|w| (w[1] - w[0]).abs()).collect();
                    let low_changes: Vec<f64> =
                        low.windows(2).map(|w| (w[1] - w[0]).abs()).collect();
                    let avg_high_change =
                        high_changes.iter().sum::<f64>() / high_changes.len() as f64;
                    let avg_low_change = low_changes.iter().sum::<f64>() / low_changes.len() as f64;
                    let is_stable =
                        avg_high_change < high[0] * 0.001 && avg_low_change < low[0] * 0.001;

                    // Under both conditions, the last few outputs should stay close to zero.
                    if is_constant_spread && is_stable && out.len() > r + s + u + 5 {
                        let last_values: Vec<f64> = out[out.len() - 5..]
                            .iter()
                            .copied()
                            .filter(|v| v.is_finite())
                            .collect();
                        if last_values.len() >= 3 {
                            let avg_abs = last_values.iter().map(|v| v.abs()).sum::<f64>()
                                / last_values.len() as f64;
                            prop_assert!(
                                avg_abs < 10.0,
                                "DTI should converge near 0 with constant spread and stable prices: avg_abs={:.2}",
                                avg_abs
                            );
                        }
                    }
                }

                // Non-strict monotonic series: at least half of the finite outputs from
                // the midpoint onward must carry the direction's sign.
                if high.len() >= 10 {
                    let high_rising = high
                        .windows(5)
                        .all(|w| w.windows(2).all(|pair| pair[1] >= pair[0]));
                    let low_rising = low
                        .windows(5)
                        .all(|w| w.windows(2).all(|pair| pair[1] >= pair[0]));
                    let high_falling = high
                        .windows(5)
                        .all(|w| w.windows(2).all(|pair| pair[1] <= pair[0]));
                    let low_falling = low
                        .windows(5)
                        .all(|w| w.windows(2).all(|pair| pair[1] <= pair[0]));

                    if (high_rising && low_rising) && out.len() > 10 {
                        let mid_to_end: Vec<f64> = out[out.len() / 2..]
                            .iter()
                            .copied()
                            .filter(|v| v.is_finite())
                            .collect();
                        if mid_to_end.len() >= 3 {
                            let positive_count = mid_to_end.iter().filter(|&&v| v > 0.0).count();
                            prop_assert!(
                                positive_count >= mid_to_end.len() / 2,
                                "DTI should be mostly positive when prices consistently rise"
                            );
                        }
                    }

                    if (high_falling && low_falling) && out.len() > 10 {
                        let mid_to_end: Vec<f64> = out[out.len() / 2..]
                            .iter()
                            .copied()
                            .filter(|v| v.is_finite())
                            .collect();
                        if mid_to_end.len() >= 3 {
                            let negative_count = mid_to_end.iter().filter(|&&v| v < 0.0).count();
                            prop_assert!(
                                negative_count >= mid_to_end.len() / 2,
                                "DTI should be mostly negative when prices consistently fall"
                            );
                        }
                    }
                }

                Ok(())
            })
            .unwrap();

        Ok(())
    }

    /// Without the `proptest` feature the property test is a no-op (after the
    /// usual kernel-support gate).
    #[cfg(not(feature = "proptest"))]
    fn check_dti_property(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        Ok(())
    }
2723
2724    macro_rules! generate_all_dti_tests {
2725        ($($test_fn:ident),*) => {
2726            paste::paste! {
2727                $(
2728                    #[test]
2729                    fn [<$test_fn _scalar_f64>]() {
2730                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
2731                    }
2732                )*
2733                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2734                $(
2735                    #[test]
2736                    fn [<$test_fn _avx2_f64>]() {
2737                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
2738                    }
2739                    #[test]
2740                    fn [<$test_fn _avx512_f64>]() {
2741                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
2742                    }
2743                )*
2744                #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
2745                $(
2746                    #[test]
2747                    fn [<$test_fn _simd128_f64>]() {
2748                        let _ = $test_fn(stringify!([<$test_fn _simd128_f64>]), Kernel::Scalar);
2749                    }
2750                )*
2751            }
2752        }
2753    }
2754    generate_all_dti_tests!(
2755        check_dti_partial_params,
2756        check_dti_accuracy,
2757        check_dti_default_candles,
2758        check_dti_zero_period,
2759        check_dti_period_exceeds_length,
2760        check_dti_all_nan,
2761        check_dti_empty_data,
2762        check_dti_streaming,
2763        check_dti_length_mismatch,
2764        check_dti_wasm_kernel_fallback,
2765        check_dti_no_poison,
2766        check_dti_property
2767    );
2768    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2769        skip_if_unsupported!(kernel, test);
2770        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2771        let c = read_candles_from_csv(file)?;
2772        let output = DtiBatchBuilder::new().kernel(kernel).apply_candles(&c)?;
2773        let def = DtiParams::default();
2774        let row = output.values_for(&def).expect("default row missing");
2775        assert_eq!(row.len(), c.close.len());
2776        let expected = [
2777            -39.0091620347991,
2778            -39.75219264093014,
2779            -40.53941417932286,
2780            -41.2787749205189,
2781            -42.93758699380749,
2782        ];
2783        let start = row.len() - 5;
2784        for (i, &v) in row[start..].iter().enumerate() {
2785            assert!(
2786                (v - expected[i]).abs() < 1e-6,
2787                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
2788            );
2789        }
2790        Ok(())
2791    }
2792
2793    #[cfg(debug_assertions)]
2794    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2795        skip_if_unsupported!(kernel, test);
2796
2797        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2798        let c = read_candles_from_csv(file)?;
2799
2800        let test_configs = vec![
2801            (1, 5, 1, 1, 5, 1, 1, 5, 1),
2802            (5, 15, 5, 5, 15, 5, 5, 15, 5),
2803            (10, 30, 10, 10, 20, 10, 5, 15, 5),
2804            (14, 14, 0, 10, 10, 0, 1, 10, 1),
2805            (1, 20, 5, 10, 10, 0, 5, 5, 0),
2806            (20, 50, 15, 15, 30, 15, 10, 20, 10),
2807            (2, 6, 2, 8, 12, 2, 3, 9, 3),
2808            (50, 100, 25, 30, 60, 30, 20, 40, 20),
2809            (14, 14, 0, 5, 20, 5, 5, 20, 5),
2810            (5, 5, 0, 5, 5, 0, 1, 10, 1),
2811        ];
2812
2813        for (cfg_idx, &(r_start, r_end, r_step, s_start, s_end, s_step, u_start, u_end, u_step)) in
2814            test_configs.iter().enumerate()
2815        {
2816            let output = DtiBatchBuilder::new()
2817                .kernel(kernel)
2818                .r_range(r_start, r_end, r_step)
2819                .s_range(s_start, s_end, s_step)
2820                .u_range(u_start, u_end, u_step)
2821                .apply_candles(&c)?;
2822
2823            for (idx, &val) in output.values.iter().enumerate() {
2824                if val.is_nan() {
2825                    continue;
2826                }
2827
2828                let bits = val.to_bits();
2829                let row = idx / output.cols;
2830                let col = idx % output.cols;
2831                let combo = &output.combos[row];
2832
2833                if bits == 0x11111111_11111111 {
2834                    panic!(
2835                        "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
2836						 at row {} col {} (flat index {}) with params: r={}, s={}, u={}",
2837                        test,
2838                        cfg_idx,
2839                        val,
2840                        bits,
2841                        row,
2842                        col,
2843                        idx,
2844                        combo.r.unwrap_or(14),
2845                        combo.s.unwrap_or(10),
2846                        combo.u.unwrap_or(5)
2847                    );
2848                }
2849
2850                if bits == 0x22222222_22222222 {
2851                    panic!(
2852                        "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
2853						 at row {} col {} (flat index {}) with params: r={}, s={}, u={}",
2854                        test,
2855                        cfg_idx,
2856                        val,
2857                        bits,
2858                        row,
2859                        col,
2860                        idx,
2861                        combo.r.unwrap_or(14),
2862                        combo.s.unwrap_or(10),
2863                        combo.u.unwrap_or(5)
2864                    );
2865                }
2866
2867                if bits == 0x33333333_33333333 {
2868                    panic!(
2869                        "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
2870						 at row {} col {} (flat index {}) with params: r={}, s={}, u={}",
2871                        test,
2872                        cfg_idx,
2873                        val,
2874                        bits,
2875                        row,
2876                        col,
2877                        idx,
2878                        combo.r.unwrap_or(14),
2879                        combo.s.unwrap_or(10),
2880                        combo.u.unwrap_or(5)
2881                    );
2882                }
2883            }
2884        }
2885
2886        Ok(())
2887    }
2888
2889    #[cfg(not(debug_assertions))]
2890    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2891        Ok(())
2892    }
2893
2894    macro_rules! gen_batch_tests {
2895        ($fn_name:ident) => {
2896            paste::paste! {
2897                #[test] fn [<$fn_name _scalar>]()      {
2898                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
2899                }
2900                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2901                #[test] fn [<$fn_name _avx2>]()        {
2902                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
2903                }
2904                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2905                #[test] fn [<$fn_name _avx512>]()      {
2906                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
2907                }
2908                #[test] fn [<$fn_name _auto_detect>]() {
2909                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]),
2910                                     Kernel::Auto);
2911                }
2912            }
2913        };
2914    }
2915    gen_batch_tests!(check_batch_default_row);
2916    gen_batch_tests!(check_batch_no_poison);
2917
2918    #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
2919    #[test]
2920    fn test_dti_into_matches_api() -> Result<(), Box<dyn Error>> {
2921        fn eq_or_both_nan(a: f64, b: f64) -> bool {
2922            (a.is_nan() && b.is_nan()) || (a == b)
2923        }
2924
2925        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2926        let candles = read_candles_from_csv(file_path)?;
2927        let input = DtiInput::with_default_candles(&candles);
2928
2929        let baseline = dti(&input)?;
2930
2931        let mut out = vec![0.0f64; candles.close.len()];
2932        dti_into(&input, &mut out)?;
2933
2934        assert_eq!(baseline.values.len(), out.len());
2935        for i in 0..out.len() {
2936            assert!(
2937                eq_or_both_nan(baseline.values[i], out[i]),
2938                "Mismatch at index {}: api={} into={}",
2939                i,
2940                baseline.values[i],
2941                out[i]
2942            );
2943        }
2944        Ok(())
2945    }
2946}