Skip to main content

vector_ta/indicators/moving_averages/
tradjema.rs

1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::cuda_available;
3#[cfg(all(feature = "python", feature = "cuda"))]
4use crate::cuda::moving_averages::CudaTradjema;
5#[cfg(all(feature = "python", feature = "cuda"))]
6use crate::cuda::moving_averages::DeviceArrayF32;
7#[cfg(all(feature = "python", feature = "cuda"))]
8use cust::context::Context;
9#[cfg(all(feature = "python", feature = "cuda"))]
10use numpy::PyUntypedArrayMethods;
11#[cfg(feature = "python")]
12use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1, PyReadonlyArray2};
13#[cfg(feature = "python")]
14use pyo3::exceptions::PyValueError;
15#[cfg(feature = "python")]
16use pyo3::prelude::*;
17#[cfg(feature = "python")]
18use pyo3::types::{PyDict, PyList};
19#[cfg(all(feature = "python", feature = "cuda"))]
20use std::sync::Arc;
21
22#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
23use serde::{Deserialize, Serialize};
24#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
25use wasm_bindgen::prelude::*;
26
27use crate::utilities::data_loader::Candles;
28use crate::utilities::enums::Kernel;
29use crate::utilities::helpers::{
30    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
31    make_uninit_matrix,
32};
33#[cfg(feature = "python")]
34use crate::utilities::kernel_validation::validate_kernel;
35
36#[cfg(not(target_arch = "wasm32"))]
37use rayon::prelude::*;
38
39use std::convert::AsRef;
40use std::error::Error;
41use std::mem::MaybeUninit;
42use thiserror::Error;
43
44#[derive(Debug, Clone)]
45pub enum TradjemaData<'a> {
46    Candles {
47        candles: &'a Candles,
48    },
49    Slices {
50        high: &'a [f64],
51        low: &'a [f64],
52        close: &'a [f64],
53    },
54}
55
/// Result of a TRADJEMA computation.
#[derive(Clone, Debug)]
pub struct TradjemaOutput {
    /// One value per input bar; the warm-up prefix is NaN.
    pub values: Vec<f64>,
}
60
/// Tunable TRADJEMA parameters.
///
/// A `None` field means "use the default" (40 for `length`, 10.0 for `mult`).
#[derive(Clone, Debug)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct TradjemaParams {
    /// Look-back window in bars.
    pub length: Option<usize>,
    /// Scale applied to the normalised true range when adapting the smoothing factor.
    pub mult: Option<f64>,
}
70
71impl Default for TradjemaParams {
72    fn default() -> Self {
73        Self {
74            length: Some(40),
75            mult: Some(10.0),
76        }
77    }
78}
79
80#[derive(Debug, Clone)]
81pub struct TradjemaInput<'a> {
82    pub data: TradjemaData<'a>,
83    pub params: TradjemaParams,
84}
85
86impl<'a> TradjemaInput<'a> {
87    #[inline]
88    pub fn from_candles(candles: &'a Candles, params: TradjemaParams) -> Self {
89        Self {
90            data: TradjemaData::Candles { candles },
91            params,
92        }
93    }
94
95    #[inline]
96    pub fn from_slices(
97        high: &'a [f64],
98        low: &'a [f64],
99        close: &'a [f64],
100        params: TradjemaParams,
101    ) -> Self {
102        Self {
103            data: TradjemaData::Slices { high, low, close },
104            params,
105        }
106    }
107
108    #[inline]
109    pub fn with_default_candles(candles: &'a Candles) -> Self {
110        Self::from_candles(candles, TradjemaParams::default())
111    }
112
113    #[inline]
114    pub fn get_length(&self) -> usize {
115        self.params.length.unwrap_or(40)
116    }
117
118    #[inline]
119    pub fn get_mult(&self) -> f64 {
120        self.params.mult.unwrap_or(10.0)
121    }
122}
123
124#[derive(Copy, Clone, Debug)]
125pub struct TradjemaBuilder {
126    length: Option<usize>,
127    mult: Option<f64>,
128    kernel: Kernel,
129}
130
131impl Default for TradjemaBuilder {
132    fn default() -> Self {
133        Self {
134            length: None,
135            mult: None,
136            kernel: Kernel::Auto,
137        }
138    }
139}
140
141impl TradjemaBuilder {
142    #[inline(always)]
143    pub fn new() -> Self {
144        Self::default()
145    }
146
147    #[inline(always)]
148    pub fn length(mut self, n: usize) -> Self {
149        self.length = Some(n);
150        self
151    }
152
153    #[inline(always)]
154    pub fn mult(mut self, m: f64) -> Self {
155        self.mult = Some(m);
156        self
157    }
158
159    #[inline(always)]
160    pub fn kernel(mut self, k: Kernel) -> Self {
161        self.kernel = k;
162        self
163    }
164
165    #[inline(always)]
166    pub fn apply(self, c: &Candles) -> Result<TradjemaOutput, TradjemaError> {
167        let p = TradjemaParams {
168            length: self.length,
169            mult: self.mult,
170        };
171        let i = TradjemaInput::from_candles(c, p);
172        tradjema_with_kernel(&i, self.kernel)
173    }
174
175    #[inline(always)]
176    pub fn apply_slices(
177        self,
178        high: &[f64],
179        low: &[f64],
180        close: &[f64],
181    ) -> Result<TradjemaOutput, TradjemaError> {
182        let p = TradjemaParams {
183            length: self.length,
184            mult: self.mult,
185        };
186        let i = TradjemaInput::from_slices(high, low, close, p);
187        tradjema_with_kernel(&i, self.kernel)
188    }
189
190    #[inline(always)]
191    pub fn into_stream(self) -> Result<TradjemaStream, TradjemaError> {
192        let p = TradjemaParams {
193            length: self.length,
194            mult: self.mult,
195        };
196        TradjemaStream::try_new(p)
197    }
198}
199
200#[derive(Debug, Error)]
201pub enum TradjemaError {
202    #[error("tradjema: Input data slice is empty.")]
203    EmptyInputData,
204
205    #[error("tradjema: All values are NaN.")]
206    AllValuesNaN,
207
208    #[error("tradjema: Invalid length: length = {length}, data length = {data_len}")]
209    InvalidLength { length: usize, data_len: usize },
210
211    #[error("tradjema: Not enough valid data: needed = {needed}, valid = {valid}")]
212    NotEnoughValidData { needed: usize, valid: usize },
213
214    #[error("tradjema: OHLC data length mismatch")]
215    MissingData,
216
217    #[error("tradjema: Invalid multiplier: {mult}")]
218    InvalidMult { mult: f64 },
219
220    #[error("tradjema: Output length mismatch: expected {expected}, got {got}")]
221    OutputLengthMismatch { expected: usize, got: usize },
222
223    #[error("tradjema: Invalid length range (start={start}, end={end}, step={step})")]
224    InvalidLengthRange {
225        start: usize,
226        end: usize,
227        step: usize,
228    },
229
230    #[error("tradjema: Invalid mult range (start={start}, end={end}, step={step})")]
231    InvalidMultRange { start: f64, end: f64, step: f64 },
232
233    #[error("tradjema: non-batch kernel passed to batch path: {0:?}")]
234    InvalidKernelForBatch(Kernel),
235}
236
237#[inline(always)]
238fn tradjema_prepare<'a>(
239    input: &'a TradjemaInput,
240    kernel: Kernel,
241) -> Result<(&'a [f64], &'a [f64], &'a [f64], usize, usize, f64, Kernel), TradjemaError> {
242    let (high, low, close) = match &input.data {
243        TradjemaData::Candles { candles } => {
244            let h = candles
245                .select_candle_field("high")
246                .map_err(|_| TradjemaError::EmptyInputData)?;
247            let l = candles
248                .select_candle_field("low")
249                .map_err(|_| TradjemaError::EmptyInputData)?;
250            let c = candles
251                .select_candle_field("close")
252                .map_err(|_| TradjemaError::EmptyInputData)?;
253            (h, l, c)
254        }
255        TradjemaData::Slices { high, low, close } => {
256            if high.len() != low.len() || low.len() != close.len() {
257                return Err(TradjemaError::MissingData);
258            }
259            (*high, *low, *close)
260        }
261    };
262
263    let len = close.len();
264    if len == 0 {
265        return Err(TradjemaError::EmptyInputData);
266    }
267
268    let first = close
269        .iter()
270        .position(|v| !v.is_nan())
271        .ok_or(TradjemaError::AllValuesNaN)?;
272    let length = input.get_length();
273    if length < 2 || length > len {
274        return Err(TradjemaError::InvalidLength {
275            length,
276            data_len: len,
277        });
278    }
279    if len - first < length {
280        return Err(TradjemaError::NotEnoughValidData {
281            needed: length,
282            valid: len - first,
283        });
284    }
285
286    let mult = input.get_mult();
287    if mult <= 0.0 || !mult.is_finite() {
288        return Err(TradjemaError::InvalidMult { mult });
289    }
290
291    let chosen = match kernel {
292        Kernel::Auto => Kernel::Scalar,
293        k => k,
294    };
295    Ok((high, low, close, length, first, mult, chosen))
296}
297
298#[inline]
299pub fn tradjema(input: &TradjemaInput) -> Result<TradjemaOutput, TradjemaError> {
300    tradjema_with_kernel(input, Kernel::Auto)
301}
302
303pub fn tradjema_with_kernel(
304    input: &TradjemaInput,
305    kernel: Kernel,
306) -> Result<TradjemaOutput, TradjemaError> {
307    let (h, l, c, length, first, mult, chosen) = tradjema_prepare(input, kernel)?;
308    let warm = first + length - 1;
309    let mut out = alloc_with_nan_prefix(c.len(), warm);
310    tradjema_compute_into(h, l, c, length, mult, first, chosen, &mut out);
311    Ok(TradjemaOutput { values: out })
312}
313
314#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
315#[inline]
316pub fn tradjema_into(input: &TradjemaInput, out: &mut [f64]) -> Result<(), TradjemaError> {
317    let (h, l, c, length, first, mult, chosen) = tradjema_prepare(input, Kernel::Auto)?;
318    if out.len() != c.len() {
319        return Err(TradjemaError::OutputLengthMismatch {
320            expected: c.len(),
321            got: out.len(),
322        });
323    }
324
325    let warm = first + length - 1;
326    let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
327    let w = warm.min(out.len());
328    for v in &mut out[..w] {
329        *v = qnan;
330    }
331
332    tradjema_compute_into(h, l, c, length, mult, first, chosen, out);
333    Ok(())
334}
335
336#[inline]
337pub fn tradjema_into_slice(
338    dst: &mut [f64],
339    input: &TradjemaInput,
340    kern: Kernel,
341) -> Result<(), TradjemaError> {
342    let (h, l, c, length, first, mult, chosen) = tradjema_prepare(input, kern)?;
343    if dst.len() != c.len() {
344        return Err(TradjemaError::OutputLengthMismatch {
345            expected: c.len(),
346            got: dst.len(),
347        });
348    }
349    tradjema_compute_into(h, l, c, length, mult, first, chosen, dst);
350
351    let warm = first + length - 1;
352    for v in &mut dst[..warm] {
353        *v = f64::NAN;
354    }
355    Ok(())
356}
357
/// Scalar TRADJEMA kernel.
///
/// For every bar `i >= warm` (with `warm = first + length - 1`) it
/// 1. computes the true range `tr[i]` (`high - low` for the first bar,
///    otherwise the maximum of `high - low`, `|high - prev_close|` and
///    `|low - prev_close|`),
/// 2. normalises it against the rolling min/max of the last `length` true
///    ranges (`0.0` when the window is flat),
/// 3. scales the base smoothing factor `2 / (length + 1)` by
///    `1 + normalised * mult`, and
/// 4. advances an EMA of the *previous* close with that factor.
///
/// `out[..warm]` is left untouched; nothing is written when `warm >= out.len()`.
///
/// The rolling min/max use monotonic deques stored in ring buffers. A window
/// can hold up to `length` live entries, so the buffers have `length + 1`
/// slots: with only `length` slots a full deque has `head == tail` and is
/// indistinguishable from an empty one, which loses entries whenever the
/// true range is monotone over a whole window.
///
/// Callers must pass `high`, `low`, `close` and `out` of equal length, with
/// `length >= 2` (checked with `debug_assert!` only).
#[inline(always)]
fn tradjema_compute_into_scalar(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    length: usize,
    mult: f64,
    first: usize,
    out: &mut [f64],
) {
    debug_assert_eq!(high.len(), low.len());
    debug_assert_eq!(low.len(), close.len());
    debug_assert_eq!(close.len(), out.len());
    debug_assert!(length >= 2);

    let n = out.len();
    let warm = first + length - 1;
    if warm >= n {
        return;
    }

    let alpha = 2.0 / (length as f64 + 1.0);

    // One spare slot so that `head == tail` always means "empty".
    let cap = length + 1;
    let mut min_vals = vec![0.0f64; cap];
    let mut min_idx = vec![0usize; cap];
    let mut max_vals = vec![0.0f64; cap];
    let mut max_idx = vec![0usize; cap];
    let (mut min_head, mut min_tail) = (0usize, 0usize);
    let (mut max_head, mut max_tail) = (0usize, 0usize);

    /// Advances a ring index by one slot.
    #[inline(always)]
    fn inc(i: &mut usize, cap: usize) {
        *i += 1;
        if *i == cap {
            *i = 0;
        }
    }
    /// Ring index one slot before `i`.
    #[inline(always)]
    fn dec(i: usize, cap: usize) -> usize {
        if i == 0 {
            cap - 1
        } else {
            i - 1
        }
    }
    /// Pushes `v` onto the min-deque, dropping larger values from the back.
    #[inline(always)]
    fn minq_push(
        v: f64,
        idx: usize,
        vals: &mut [f64],
        id: &mut [usize],
        head: &mut usize,
        tail: &mut usize,
        cap: usize,
    ) {
        let mut back = dec(*tail, cap);
        // SAFETY: `back` and `*tail` are ring indices in `0..cap` and both
        // buffers have exactly `cap` elements.
        while *tail != *head && unsafe { *vals.get_unchecked(back) } > v {
            *tail = back;
            back = dec(*tail, cap);
        }
        unsafe {
            *vals.get_unchecked_mut(*tail) = v;
            *id.get_unchecked_mut(*tail) = idx;
        }
        inc(tail, cap);
    }
    /// Pushes `v` onto the max-deque, dropping smaller values from the back.
    #[inline(always)]
    fn maxq_push(
        v: f64,
        idx: usize,
        vals: &mut [f64],
        id: &mut [usize],
        head: &mut usize,
        tail: &mut usize,
        cap: usize,
    ) {
        let mut back = dec(*tail, cap);
        // SAFETY: same ring-index invariant as in `minq_push`.
        while *tail != *head && unsafe { *vals.get_unchecked(back) } < v {
            *tail = back;
            back = dec(*tail, cap);
        }
        unsafe {
            *vals.get_unchecked_mut(*tail) = v;
            *id.get_unchecked_mut(*tail) = idx;
        }
        inc(tail, cap);
    }
    /// Drops front entries whose bar index has left the window ending at `cur`.
    #[inline(always)]
    fn q_expire(
        cur: usize,
        len: usize,
        id: &mut [usize],
        head: &mut usize,
        tail: &mut usize,
        cap: usize,
    ) {
        let lim = cur.saturating_sub(len);
        // SAFETY: `*head` is a ring index in `0..cap == id.len()`.
        while *head != *tail && unsafe { *id.get_unchecked(*head) } <= lim {
            inc(head, cap);
        }
    }

    /// Maximum of three values using plain `>` comparisons.
    #[inline(always)]
    fn max3(a: f64, b: f64, c: f64) -> f64 {
        let m = if a > b { a } else { b };
        if m > c {
            m
        } else {
            c
        }
    }

    // SAFETY (all series accesses below): indices stay in `first..n`, and the
    // caller guarantees `high`, `low`, `close` and `out` all have length `n`.

    // The first valid bar has no previous close: its true range is high - low.
    let tr0 = unsafe { *high.get_unchecked(first) - *low.get_unchecked(first) };
    minq_push(
        tr0,
        first,
        &mut min_vals,
        &mut min_idx,
        &mut min_head,
        &mut min_tail,
        cap,
    );
    maxq_push(
        tr0,
        first,
        &mut max_vals,
        &mut max_idx,
        &mut max_head,
        &mut max_tail,
        cap,
    );
    let mut last_tr = tr0;

    // Fill the first full window (bars first..=warm) without expiring anything.
    let mut i = first + 1;
    while i <= warm {
        let hi = unsafe { *high.get_unchecked(i) };
        let lo = unsafe { *low.get_unchecked(i) };
        let pc1 = unsafe { *close.get_unchecked(i - 1) };
        let tr = max3(hi - lo, (hi - pc1).abs(), (lo - pc1).abs());
        minq_push(
            tr,
            i,
            &mut min_vals,
            &mut min_idx,
            &mut min_head,
            &mut min_tail,
            cap,
        );
        maxq_push(
            tr,
            i,
            &mut max_vals,
            &mut max_idx,
            &mut max_head,
            &mut max_tail,
            cap,
        );
        last_tr = tr;
        i += 1;
    }

    // Seed value at `warm`: the previous close scaled by the adaptive factor.
    let tr_low = unsafe { *min_vals.get_unchecked(min_head) };
    let tr_high = unsafe { *max_vals.get_unchecked(max_head) };
    let denom = tr_high - tr_low;
    let tr_adj0 = if denom != 0.0 {
        (last_tr - tr_low) / denom
    } else {
        0.0
    };
    let a0 = alpha * (1.0 + tr_adj0 * mult);
    let src0 = unsafe { *close.get_unchecked(warm - 1) };
    let mut y = src0.mul_add(a0, 0.0);
    unsafe {
        *out.get_unchecked_mut(warm) = y;
    }

    i = warm + 1;
    while i < n {
        // Expire before pushing so that at most `length` entries are ever live.
        q_expire(i, length, &mut min_idx, &mut min_head, &mut min_tail, cap);
        q_expire(i, length, &mut max_idx, &mut max_head, &mut max_tail, cap);

        let hi = unsafe { *high.get_unchecked(i) };
        let lo = unsafe { *low.get_unchecked(i) };
        let pc1 = unsafe { *close.get_unchecked(i - 1) };
        let tr = max3(hi - lo, (hi - pc1).abs(), (lo - pc1).abs());
        minq_push(
            tr,
            i,
            &mut min_vals,
            &mut min_idx,
            &mut min_head,
            &mut min_tail,
            cap,
        );
        maxq_push(
            tr,
            i,
            &mut max_vals,
            &mut max_idx,
            &mut max_head,
            &mut max_tail,
            cap,
        );

        let lo_tr = unsafe { *min_vals.get_unchecked(min_head) };
        let hi_tr = unsafe { *max_vals.get_unchecked(max_head) };
        let den = hi_tr - lo_tr;
        let tr_adj = if den != 0.0 { (tr - lo_tr) / den } else { 0.0 };
        let a = alpha * (1.0 + tr_adj * mult);
        // The smoothed source is the previous bar's close.
        let src = pc1;
        y = (src - y).mul_add(a, y);
        unsafe {
            *out.get_unchecked_mut(i) = y;
        }

        i += 1;
    }
}
577
578#[inline(always)]
579fn tradjema_compute_into(
580    high: &[f64],
581    low: &[f64],
582    close: &[f64],
583    length: usize,
584    mult: f64,
585    first: usize,
586    kern: Kernel,
587    out: &mut [f64],
588) {
589    unsafe {
590        #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
591        {
592            if matches!(kern, Kernel::Scalar | Kernel::ScalarBatch) {
593                tradjema_compute_into_scalar(high, low, close, length, mult, first, out);
594                return;
595            }
596        }
597        match kern {
598            Kernel::Scalar | Kernel::ScalarBatch => {
599                tradjema_compute_into_scalar(high, low, close, length, mult, first, out)
600            }
601            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
602            Kernel::Avx2 | Kernel::Avx2Batch => {
603                tradjema_compute_into_avx2(high, low, close, length, mult, first, out)
604            }
605            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
606            Kernel::Avx512 | Kernel::Avx512Batch => {
607                tradjema_compute_into_avx512(high, low, close, length, mult, first, out)
608            }
609            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
610            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
611                tradjema_compute_into_scalar(high, low, close, length, mult, first, out)
612            }
613            _ => unreachable!(),
614        }
615    }
616}
617
/// AVX2 entry point; it currently forwards to the scalar routine, compiled
/// with the `avx2` target feature enabled.
///
/// # Safety
/// The executing CPU must support AVX2.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn tradjema_compute_into_avx2(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    length: usize,
    mult: f64,
    first: usize,
    out: &mut [f64],
) {
    tradjema_compute_into_scalar(high, low, close, length, mult, first, out)
}
631
/// AVX-512 entry point; it currently forwards to the scalar routine,
/// compiled with the `avx512f` target feature enabled.
///
/// # Safety
/// The executing CPU must support AVX-512F.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
unsafe fn tradjema_compute_into_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    length: usize,
    mult: f64,
    first: usize,
    out: &mut [f64],
) {
    tradjema_compute_into_scalar(high, low, close, length, mult, first, out)
}
645
/// Incremental TRADJEMA state, fed one bar at a time through `update`.
#[derive(Clone, Debug)]
pub struct TradjemaStream {
    /// Window length in bars.
    length: usize,
    /// Multiplier applied to the normalised true range.
    mult: f64,
    /// Base smoothing factor `2 / (length + 1)`.
    alpha: f64,

    /// Index of the next bar to be processed.
    i: usize,
    /// Set once the first full window has been seen.
    filled: bool,

    /// Close of the previous bar (NaN before the first bar).
    prev_close: f64,
    /// Latest indicator value (NaN until the stream is filled).
    tradjema: f64,

    // Ring-buffer backed monotonic deques for the rolling min/max true range.
    min_vals: Vec<f64>,
    min_idx: Vec<usize>,
    max_vals: Vec<f64>,
    max_idx: Vec<usize>,
    min_head: usize,
    min_tail: usize,
    max_head: usize,
    max_tail: usize,
}
667
668impl TradjemaStream {
669    pub fn try_new(params: TradjemaParams) -> Result<Self, TradjemaError> {
670        let length = params.length.unwrap_or(40);
671        let mult = params.mult.unwrap_or(10.0);
672
673        if length < 2 {
674            return Err(TradjemaError::InvalidLength {
675                length,
676                data_len: 0,
677            });
678        }
679        if mult <= 0.0 || !mult.is_finite() {
680            return Err(TradjemaError::InvalidMult { mult });
681        }
682        let cap = length;
683
684        Ok(Self {
685            length,
686            mult,
687            alpha: 2.0 / (length as f64 + 1.0),
688
689            i: 0,
690            filled: false,
691
692            prev_close: f64::NAN,
693            tradjema: f64::NAN,
694
695            min_vals: vec![0.0; cap],
696            min_idx: vec![0; cap],
697            max_vals: vec![0.0; cap],
698            max_idx: vec![0; cap],
699            min_head: 0,
700            min_tail: 0,
701            max_head: 0,
702            max_tail: 0,
703        })
704    }
705
706    #[inline(always)]
707    fn inc(i: &mut usize, cap: usize) {
708        *i += 1;
709        if *i == cap {
710            *i = 0;
711        }
712    }
713    #[inline(always)]
714    fn dec(i: usize, cap: usize) -> usize {
715        if i == 0 {
716            cap - 1
717        } else {
718            i - 1
719        }
720    }
721    #[inline(always)]
722    fn minq_push(&mut self, v: f64, idx: usize) {
723        let cap = self.length;
724        let mut back = Self::dec(self.min_tail, cap);
725
726        while self.min_tail != self.min_head && self.min_vals[back] > v {
727            self.min_tail = back;
728            back = Self::dec(self.min_tail, cap);
729        }
730        self.min_vals[self.min_tail] = v;
731        self.min_idx[self.min_tail] = idx;
732        Self::inc(&mut self.min_tail, cap);
733    }
734    #[inline(always)]
735    fn maxq_push(&mut self, v: f64, idx: usize) {
736        let cap = self.length;
737        let mut back = Self::dec(self.max_tail, cap);
738
739        while self.max_tail != self.max_head && self.max_vals[back] < v {
740            self.max_tail = back;
741            back = Self::dec(self.max_tail, cap);
742        }
743        self.max_vals[self.max_tail] = v;
744        self.max_idx[self.max_tail] = idx;
745        Self::inc(&mut self.max_tail, cap);
746    }
747    #[inline(always)]
748    fn q_expire(
749        head: &mut usize,
750        tail: &mut usize,
751        id: &mut [usize],
752        cur: usize,
753        len: usize,
754        cap: usize,
755    ) {
756        let lim = cur.saturating_sub(len);
757        while *head != *tail && id[*head] <= lim {
758            Self::inc(head, cap);
759        }
760    }
761
762    #[inline(always)]
763    fn max3(a: f64, b: f64, c: f64) -> f64 {
764        let m = if a > b { a } else { b };
765        if m > c {
766            m
767        } else {
768            c
769        }
770    }
771
772    #[inline(always)]
773    pub fn update(&mut self, high: f64, low: f64, close: f64) -> Option<f64> {
774        let tr = if self.prev_close.is_nan() {
775            high - low
776        } else {
777            let hl = high - low;
778            let hc = (high - self.prev_close).abs();
779            let lc = (low - self.prev_close).abs();
780            Self::max3(hl, hc, lc)
781        };
782
783        if !self.filled {
784            self.minq_push(tr, self.i);
785            self.maxq_push(tr, self.i);
786
787            if self.i + 1 < self.length {
788                self.prev_close = close;
789                self.i += 1;
790                return None;
791            }
792
793            let lo = self.min_vals[self.min_head];
794            let hi = self.max_vals[self.max_head];
795            let den = hi - lo;
796            let tr_adj = if den != 0.0 { (tr - lo) / den } else { 0.0 };
797            let a0 = self.alpha * (1.0 + tr_adj * self.mult);
798
799            let src = self.prev_close;
800            self.tradjema = src.mul_add(a0, 0.0);
801
802            self.prev_close = close;
803            self.filled = true;
804            self.i += 1;
805            return Some(self.tradjema);
806        }
807
808        let cap = self.length;
809        Self::q_expire(
810            &mut self.min_head,
811            &mut self.min_tail,
812            &mut self.min_idx,
813            self.i,
814            self.length,
815            cap,
816        );
817        Self::q_expire(
818            &mut self.max_head,
819            &mut self.max_tail,
820            &mut self.max_idx,
821            self.i,
822            self.length,
823            cap,
824        );
825
826        self.minq_push(tr, self.i);
827        self.maxq_push(tr, self.i);
828
829        let lo = self.min_vals[self.min_head];
830        let hi = self.max_vals[self.max_head];
831        let den = hi - lo;
832        let tr_adj = if den != 0.0 { (tr - lo) / den } else { 0.0 };
833        let a = self.alpha * (1.0 + tr_adj * self.mult);
834
835        let src = self.prev_close;
836        self.tradjema = (src - self.tradjema).mul_add(a, self.tradjema);
837
838        self.prev_close = close;
839        self.i += 1;
840
841        Some(self.tradjema)
842    }
843}
844
// Python binding `tradjema(high, low, close, length, mult, kernel=None)`.
// Validates that the three arrays have equal length, resolves the kernel
// name, runs the computation with the GIL released and returns a new
// 1-D float64 array. All failures surface as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "tradjema")]
#[pyo3(signature = (high, low, close, length, mult, kernel=None))]
pub fn tradjema_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    length: usize,
    mult: f64,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let h = high.as_slice()?;
    let l = low.as_slice()?;
    let c = close.as_slice()?;

    let same_len = h.len() == l.len() && l.len() == c.len();
    if !same_len {
        return Err(PyValueError::new_err(
            "All OHLC arrays must have the same length",
        ));
    }

    let kern = validate_kernel(kernel, false)?;
    let params = TradjemaParams {
        length: Some(length),
        mult: Some(mult),
    };
    let input = TradjemaInput::from_slices(h, l, c, params);

    // Only the numeric work runs without the GIL.
    let computed = py.allow_threads(|| tradjema_with_kernel(&input, kern));
    match computed {
        Ok(output) => Ok(output.values.into_pyarray(py)),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
879
// Python binding `tradjema_batch(high, low, close, length_range, mult_range,
// kernel=None)`.
//
// Expands the (start, end, step) ranges into a parameter grid, computes one
// output row per combination with the GIL released and returns a dict with
//   "values"  - 2-D float64 array of shape (rows, cols),
//   "lengths" - the length used for each row,
//   "mults"   - the multiplier used for each row.
// All failures surface as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "tradjema_batch")]
#[pyo3(signature = (high, low, close, length_range, mult_range, kernel=None))]
pub fn tradjema_batch_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    length_range: (usize, usize, usize),
    mult_range: (f64, f64, f64),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::PyArray1;

    let (h, l, c) = (high.as_slice()?, low.as_slice()?, close.as_slice()?);
    if h.len() != l.len() || l.len() != c.len() {
        return Err(PyValueError::new_err(
            "All OHLC arrays must have the same length",
        ));
    }

    let sweep = TradjemaBatchRange {
        length: length_range,
        mult: mult_range,
    };
    let combos = expand_grid(&sweep);
    if combos.is_empty() {
        return Err(PyValueError::new_err("Empty parameter grid"));
    }
    let rows = combos.len();
    let cols = c.len();

    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;

    // The output is allocated uninitialised; the warm-up prefixes are filled
    // below and the rest is written by the batch kernel.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let first = c
        .iter()
        .position(|v| !v.is_nan())
        .ok_or_else(|| PyValueError::new_err("All values are NaN"))?;
    for (row, prm) in combos.iter().enumerate() {
        let length = prm.length.unwrap_or(40);
        // The grid has not been validated yet: `length` may be 0 or exceed
        // the series length. Saturate and clamp so that a bad parameter
        // cannot underflow or index past the row (which would panic here);
        // such rows are simply filled with NaN in full.
        let warm = first.saturating_add(length).saturating_sub(1).min(cols);
        let row_slice = &mut slice_out[row * cols..(row + 1) * cols];
        for v in &mut row_slice[..warm] {
            *v = f64::NAN;
        }
    }

    let kern = validate_kernel(kernel, true)?;
    let simd = match kern {
        Kernel::Auto => detect_best_batch_kernel(),
        Kernel::Avx512Batch => Kernel::Avx512,
        Kernel::Avx2Batch => Kernel::Avx2,
        Kernel::ScalarBatch => Kernel::Scalar,
        _ => Kernel::Scalar,
    };

    let combos = py
        .allow_threads(|| tradjema_batch_inner_into(h, l, c, &sweep, simd, true, slice_out))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "lengths",
        combos
            .iter()
            .map(|p| p.length.unwrap_or(40) as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "mults",
        combos
            .iter()
            .map(|p| p.mult.unwrap_or(10.0))
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    Ok(dict)
}
965
// Python binding `tradjema_cuda_batch_dev(high_f32, low_f32, close_f32,
// length_range, mult_range, device_id=0)`.
// Runs the parameter sweep on a CUDA device and returns a device-array
// handle. All failures surface as `ValueError`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "tradjema_cuda_batch_dev")]
#[pyo3(signature = (high_f32, low_f32, close_f32, length_range, mult_range, device_id=0))]
pub fn tradjema_cuda_batch_dev_py(
    py: Python<'_>,
    high_f32: PyReadonlyArray1<'_, f32>,
    low_f32: PyReadonlyArray1<'_, f32>,
    close_f32: PyReadonlyArray1<'_, f32>,
    length_range: (usize, usize, usize),
    mult_range: (f64, f64, f64),
    device_id: usize,
) -> PyResult<DeviceArrayF32TradjemaPy> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let high = high_f32.as_slice()?;
    let low = low_f32.as_slice()?;
    let close = close_f32.as_slice()?;

    let same_len = high.len() == low.len() && low.len() == close.len();
    if !same_len {
        return Err(PyValueError::new_err(
            "All OHLC arrays must have the same length",
        ));
    }

    let sweep = TradjemaBatchRange {
        length: length_range,
        mult: mult_range,
    };

    // Device setup and the kernel launch both run with the GIL released.
    let (inner, ctx, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let cuda =
            CudaTradjema::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id();
        let arr = cuda
            .tradjema_batch_dev(high, low, close, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((arr, ctx, dev_id))
    })?;

    Ok(DeviceArrayF32TradjemaPy::new(inner, ctx, dev_id))
}
1010
// Python binding `tradjema_cuda_many_series_one_param_dev(high_tm_f32,
// low_tm_f32, close_tm_f32, length, mult, device_id=0)`.
// Applies a single (length, mult) pair to 2-D inputs on a CUDA device and
// returns a device-array handle. The first axis is passed to the CUDA
// routine as `rows` and the second as `cols`. All failures surface as
// `ValueError`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "tradjema_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm_f32, low_tm_f32, close_tm_f32, length, mult, device_id=0))]
pub fn tradjema_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    high_tm_f32: PyReadonlyArray2<'_, f32>,
    low_tm_f32: PyReadonlyArray2<'_, f32>,
    close_tm_f32: PyReadonlyArray2<'_, f32>,
    length: usize,
    mult: f64,
    device_id: usize,
) -> PyResult<DeviceArrayF32TradjemaPy> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let shape = high_tm_f32.shape();
    let shapes_match = shape == low_tm_f32.shape() && shape == close_tm_f32.shape();
    if !shapes_match {
        return Err(PyValueError::new_err(
            "OHLC tensors must share the same shape",
        ));
    }
    if shape.len() != 2 {
        return Err(PyValueError::new_err("expected 2D arrays (time, series)"));
    }
    let (rows, cols) = (shape[0], shape[1]);

    let high = high_tm_f32.as_slice()?;
    let low = low_tm_f32.as_slice()?;
    let close = close_tm_f32.as_slice()?;

    let params = TradjemaParams {
        length: Some(length),
        mult: Some(mult),
    };

    // Device setup and the kernel launch both run with the GIL released.
    let (inner, ctx, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let cuda =
            CudaTradjema::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id();
        let arr = cuda
            .tradjema_many_series_one_param_time_major_dev(high, low, close, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((arr, ctx, dev_id))
    })?;

    Ok(DeviceArrayF32TradjemaPy::new(inner, ctx, dev_id))
}
1061
// Python-visible handle to a 2-D `f32` array held in CUDA device memory.
//
// Field order is significant: Rust drops fields in declaration order, so the
// device buffer in `inner` is released before the context reference.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", name = "DeviceArrayF32Tradjema", unsendable)]
pub struct DeviceArrayF32TradjemaPy {
    // Device buffer together with its `rows` x `cols` shape.
    pub(crate) inner: DeviceArrayF32,
    // Never read in the visible code; presumably held so the CUDA context
    // outlives the buffer (TODO confirm against the CUDA wrapper).
    _ctx_guard: Arc<Context>,
    // Device ordinal reported by `__dlpack_device__`.
    _device_id: u32,
}
1073
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl DeviceArrayF32TradjemaPy {
    // Direct construction from Python is always rejected with `TypeError`;
    // instances are produced by the CUDA entry points.
    #[new]
    fn py_new() -> PyResult<Self> {
        let reason = "use factory methods from CUDA functions";
        Err(pyo3::exceptions::PyTypeError::new_err(reason))
    }

    // Builds the `__cuda_array_interface__` dictionary (version 3) describing
    // the buffer as a row-major `<f4` array with byte strides. An empty array
    // reports a data pointer of 0 instead of touching the device buffer.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let (rows, cols) = (self.inner.rows, self.inner.cols);
        let elem_bytes = std::mem::size_of::<f32>();
        let device_addr: usize = match rows.saturating_mul(cols) {
            0 => 0,
            _ => self.inner.buf.as_device_ptr().as_raw() as usize,
        };

        let iface = PyDict::new(py);
        iface.set_item("shape", (rows, cols))?;
        iface.set_item("typestr", "<f4")?;
        iface.set_item("strides", (cols * elem_bytes, elem_bytes))?;
        // The second tuple element is the interface's read-only flag.
        iface.set_item("data", (device_addr, false))?;
        iface.set_item("version", 3)?;
        Ok(iface)
    }

    // Returns `(device_type, device_id)`; device type 2 is the DLPack code for
    // CUDA device memory.
    fn __dlpack_device__(&self) -> (i32, i32) {
        const KDL_CUDA: i32 = 2;
        (KDL_CUDA, self._device_id as i32)
    }

    // Exports the buffer as a DLPack capsule.
    //
    // The device buffer is moved out of `self`, which is left holding an empty
    // 0 x 0 placeholder. `stream` is ignored. If `dl_device` parses as an
    // `(i32, i32)` pair that differs from this array's device, a `ValueError`
    // is raised (with a distinct message when `copy` is truthy); a `dl_device`
    // that does not parse is ignored.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<PyObject>,
        max_version: Option<PyObject>,
        dl_device: Option<PyObject>,
        copy: Option<PyObject>,
    ) -> PyResult<PyObject> {
        use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
        use cust::memory::DeviceBuffer;

        let (kdl, alloc_dev) = self.__dlpack_device__();

        let requested = dl_device
            .as_ref()
            .and_then(|obj| obj.extract::<(i32, i32)>(py).ok());
        if let Some((dev_ty, dev_id)) = requested {
            let same_device = dev_ty == kdl && dev_id == alloc_dev;
            if !same_device {
                let wants_copy = copy
                    .as_ref()
                    .and_then(|c| c.extract::<bool>(py).ok())
                    .unwrap_or(false);
                let message = if wants_copy {
                    "device copy not implemented for __dlpack__"
                } else {
                    "dl_device mismatch for __dlpack__"
                };
                return Err(PyValueError::new_err(message));
            }
        }
        let _ = stream;

        // Swap an empty buffer into `self` so ownership of the real one can be
        // handed to the exporter.
        let placeholder =
            DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let DeviceArrayF32 { buf, rows, cols } = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32 {
                buf: placeholder,
                rows: 0,
                cols: 0,
            },
        );

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
1158
#[cfg(all(feature = "python", feature = "cuda"))]
impl DeviceArrayF32TradjemaPy {
    /// Wraps a device array together with a reference to its CUDA context and
    /// the ordinal of the device it was allocated on.
    pub fn new(inner: DeviceArrayF32, ctx_guard: Arc<Context>, device_id: u32) -> Self {
        let (_ctx_guard, _device_id) = (ctx_guard, device_id);
        Self {
            inner,
            _ctx_guard,
            _device_id,
        }
    }
}
1169
// Python wrapper around the incremental TRADJEMA calculator.
#[cfg(feature = "python")]
#[pyclass(name = "TradjemaStream")]
pub struct TradjemaStreamPy {
    inner: TradjemaStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl TradjemaStreamPy {
    // Creates a stream for the given parameters; any error reported by
    // `TradjemaStream::try_new` is surfaced to Python as `ValueError`.
    #[new]
    fn new(length: usize, mult: f64) -> PyResult<Self> {
        let params = TradjemaParams {
            length: Some(length),
            mult: Some(mult),
        };
        match TradjemaStream::try_new(params) {
            Ok(inner) => Ok(Self { inner }),
            Err(e) => Err(PyValueError::new_err(e.to_string())),
        }
    }

    // Feeds one bar to the underlying stream and returns whatever it yields
    // (`None` maps to Python `None`).
    fn update(&mut self, high: f64, low: f64, close: f64) -> Option<f64> {
        let stream = &mut self.inner;
        stream.update(high, low, close)
    }
}
1192
/// Computes TRADJEMA for `len` bars read from raw pointers and writes the
/// result to `out_ptr`.
///
/// The output range may coincide with, or partially overlap, any input range.
/// Whenever the byte ranges intersect, the result is computed into a temporary
/// buffer and copied out afterwards, so no mutable slice is ever created over
/// memory that an input slice is still reading.
///
/// # Errors
/// Returns `"null pointer"` if any pointer is null, or the stringified
/// indicator error if the computation fails.
///
/// # Safety contract (for the JS caller)
/// Every pointer must reference at least `len` readable (`out_ptr`: writable)
/// `f64` values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn tradjema_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    length: usize,
    mult: f64,
) -> Result<(), JsValue> {
    let out_const = out_ptr as *const f64;
    if [high_ptr, low_ptr, close_ptr, out_const]
        .iter()
        .any(|p| p.is_null())
    {
        return Err(JsValue::from_str("null pointer"));
    }

    // Byte-range intersection test. Comparing base pointers alone would miss an
    // output buffer that starts part-way through an input buffer.
    let span = len.saturating_mul(std::mem::size_of::<f64>());
    let out_lo = out_ptr as usize;
    let out_hi = out_lo.saturating_add(span);
    let overlaps_out = |p: *const f64| {
        let lo = p as usize;
        let hi = lo.saturating_add(span);
        p == out_const || (lo < out_hi && out_lo < hi)
    };
    let aliased = overlaps_out(high_ptr) || overlaps_out(low_ptr) || overlaps_out(close_ptr);

    // SAFETY: the pointers are non-null and the caller guarantees each spans
    // `len` elements. The mutable output slice is only created directly when it
    // does not intersect any input; otherwise it is created after the
    // computation has finished reading the inputs.
    unsafe {
        let h = std::slice::from_raw_parts(high_ptr, len);
        let l = std::slice::from_raw_parts(low_ptr, len);
        let c = std::slice::from_raw_parts(close_ptr, len);

        let params = TradjemaParams {
            length: Some(length),
            mult: Some(mult),
        };
        let input = TradjemaInput::from_slices(h, l, c, params);

        if aliased {
            let mut tmp = vec![f64::NAN; len];
            tradjema_into_slice(&mut tmp, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            std::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&tmp);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            tradjema_into_slice(out, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
        Ok(())
    }
}
1237
/// Runs a TRADJEMA parameter sweep over raw-pointer inputs and writes a
/// row-major `rows x len` matrix to `out_ptr`, returning the number of rows
/// (one per `(length, mult)` combination).
///
/// Before the kernel runs, every row gets its NaN warm-up prefix. Rows the
/// kernel will not fill at all (length below 2, or a warm-up that reaches the
/// end of the data) are set to NaN in full, so the caller never reads stale
/// memory from the output buffer.
///
/// # Errors
/// Returns a string error for null pointers, an empty parameter grid, a
/// `rows * len` size that overflows, all-NaN `close` data, or any error
/// reported by the batch kernel.
///
/// # Safety contract (for the JS caller)
/// The input pointers must reference `len` readable `f64` values each and
/// `out_ptr` must reference `rows * len` writable `f64` values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn tradjema_batch_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    len: usize,
    length_start: usize,
    length_end: usize,
    length_step: usize,
    mult_start: f64,
    mult_end: f64,
    mult_step: f64,
    out_ptr: *mut f64,
) -> Result<usize, JsValue> {
    if [high_ptr, low_ptr, close_ptr].iter().any(|p| p.is_null()) || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer passed"));
    }

    // SAFETY: pointers are non-null; the caller guarantees the element counts
    // described in the function documentation.
    unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);

        let sweep = TradjemaBatchRange {
            length: (length_start, length_end, length_step),
            mult: (mult_start, mult_end, mult_step),
        };
        let combos = expand_grid(&sweep);
        if combos.is_empty() {
            return Err(JsValue::from_str("Empty parameter grid"));
        }
        let rows = combos.len();
        let cols = len;

        // Refuse sizes that cannot be represented instead of wrapping.
        let total = rows
            .checked_mul(cols)
            .ok_or_else(|| JsValue::from_str("rows * cols overflows usize"))?;

        let first = close
            .iter()
            .position(|v| !v.is_nan())
            .ok_or_else(|| JsValue::from_str("All values are NaN"))?;

        // `first` exists, so `cols >= 1` and chunking by `cols` is valid.
        let out = std::slice::from_raw_parts_mut(out_ptr, total);
        for (row_slice, prm) in out.chunks_exact_mut(cols).zip(combos.iter()) {
            let length = prm.length.unwrap_or(40);
            // The kernel skips rows with `length < 2` and rows whose warm-up
            // index is not inside the data; those rows become all-NaN. The
            // prefix is clamped so slicing can never go out of bounds.
            let nan_prefix = if length < 2 {
                cols
            } else {
                first.saturating_add(length - 1).min(cols)
            };
            row_slice[..nan_prefix].fill(f64::NAN);
        }

        let simd = match detect_best_batch_kernel() {
            Kernel::Avx512Batch => Kernel::Avx512,
            Kernel::Avx2Batch => Kernel::Avx2,
            _ => Kernel::Scalar,
        };
        tradjema_batch_inner_into(high, low, close, &sweep, simd, false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

        Ok(rows)
    }
}
1299
/// Reserves room for `len` `f64` values and returns the raw pointer to it.
///
/// The backing vector is deliberately never dropped here; the memory is
/// uninitialised and stays allocated until the caller releases it.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn tradjema_alloc(len: usize) -> *mut f64 {
    let mut storage = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    storage.as_mut_ptr()
}
1308
/// Releases a buffer of `len` `f64` values given its raw pointer.
///
/// A null pointer or a zero length is a no-op: a zero-capacity vector owns no
/// allocation, and a null pointer must never reach `Vec::from_raw_parts`.
///
/// # Safety contract (for the JS caller)
/// `ptr` must be the pointer of a `Vec<f64>` allocation whose capacity is
/// exactly `len`, and it must not be freed twice.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn tradjema_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() || len == 0 {
        return;
    }
    // SAFETY: per the contract above `ptr` owns an allocation of capacity
    // `len`. The vector is rebuilt with length 0 because the buffer may never
    // have been fully initialised, and `f64` needs no per-element drop.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
1316
/// Computes TRADJEMA for one series and returns a new vector of the same
/// length as `close`, with NaN up to index `first + length - 1` (where `first`
/// is the first non-NaN close).
///
/// # Errors
/// Returns a string error, checked in this order: empty `close`, mismatched
/// slice lengths, `length` outside `2..=close.len()`, non-finite or
/// non-positive `mult`, all-NaN `close`, and fewer than `length` bars from the
/// first non-NaN close onward.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn tradjema_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    length: usize,
    mult: f64,
) -> Result<Vec<f64>, JsValue> {
    let n = close.len();
    if n == 0 {
        return Err(JsValue::from_str("Input data slice is empty"));
    }
    let same_len = high.len() == low.len() && low.len() == n;
    if !same_len {
        return Err(JsValue::from_str("length mismatch"));
    }

    if !(2..=n).contains(&length) {
        return Err(JsValue::from_str("Invalid length"));
    }
    let mult_ok = mult.is_finite() && mult > 0.0;
    if !mult_ok {
        return Err(JsValue::from_str("Invalid mult"));
    }

    let first = match close.iter().position(|v| !v.is_nan()) {
        Some(idx) => idx,
        None => return Err(JsValue::from_str("All values are NaN")),
    };
    let usable = n - first;
    if usable < length {
        return Err(JsValue::from_str("Not enough valid data"));
    }

    // Index of the first value the kernel writes; everything before is NaN.
    let warm = first + length - 1;
    let mut out = alloc_with_nan_prefix(n, warm);

    tradjema_compute_into(
        high,
        low,
        close,
        length,
        mult,
        first,
        Kernel::Scalar,
        &mut out,
    );

    Ok(out)
}
1363
/// Sweep description deserialised from the JS `config` object passed to
/// `tradjema_batch`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct TradjemaBatchConfig {
    /// `(start, end, step)` for the length axis.
    pub length_range: (usize, usize, usize),
    /// `(start, end, step)` for the multiplier axis.
    pub mult_range: (f64, f64, f64),
}
1370
/// Batch result serialised back to JS by `tradjema_batch`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct TradjemaBatchJsOutput {
    /// Flattened result values copied from the batch output.
    pub values: Vec<f64>,
    /// Parameter combinations reported by the batch computation.
    pub combos: Vec<TradjemaParams>,
    /// Row count reported by the batch computation.
    pub rows: usize,
    /// Column count reported by the batch computation.
    pub cols: usize,
}
1379
/// JS entry point (`tradjema_batch`): runs a parameter sweep described by
/// `config` and returns the serialised [`TradjemaBatchJsOutput`].
///
/// # Errors
/// Returns a string error when `config` cannot be deserialised, when any input
/// slice is empty, when the batch computation fails, or when the result cannot
/// be serialised.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "tradjema_batch")]
pub fn tradjema_batch_unified_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    // The config is parsed first so a malformed config wins over empty inputs.
    let cfg: TradjemaBatchConfig = match serde_wasm_bindgen::from_value(config) {
        Ok(parsed) => parsed,
        Err(e) => return Err(JsValue::from_str(&format!("Invalid config: {e}"))),
    };

    let any_empty = [high, low, close].iter().any(|series| series.is_empty());
    if any_empty {
        return Err(JsValue::from_str("Input arrays are empty"));
    }

    let TradjemaBatchConfig {
        length_range,
        mult_range,
    } = cfg;
    let sweep = TradjemaBatchRange {
        length: length_range,
        mult: mult_range,
    };

    let TradjemaBatchOutput {
        values,
        combos,
        rows,
        cols,
    } = tradjema_batch_with_kernel(high, low, close, &sweep, Kernel::Auto)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let js = TradjemaBatchJsOutput {
        values,
        combos,
        rows,
        cols,
    };
    serde_wasm_bindgen::to_value(&js)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {e}")))
}
1411
/// Parameter sweep for a TRADJEMA batch run: one `(start, end, step)` triple
/// per axis.
#[derive(Clone, Debug)]
pub struct TradjemaBatchRange {
    /// `(start, end, step)` for the length axis.
    pub length: (usize, usize, usize),
    /// `(start, end, step)` for the multiplier axis.
    pub mult: (f64, f64, f64),
}

impl Default for TradjemaBatchRange {
    /// Lengths `40..=289` in steps of 1 with the multiplier fixed at `10.0`.
    fn default() -> Self {
        let length = (40, 289, 1);
        let mult = (10.0, 10.0, 0.0);
        TradjemaBatchRange { length, mult }
    }
}
1426
1427#[derive(Clone, Debug, Default)]
1428pub struct TradjemaBatchBuilder {
1429    range: TradjemaBatchRange,
1430    kernel: Kernel,
1431}
1432
1433impl TradjemaBatchBuilder {
1434    pub fn new() -> Self {
1435        Self::default()
1436    }
1437
1438    pub fn with_default_candles(c: &Candles) -> Result<TradjemaBatchOutput, TradjemaError> {
1439        TradjemaBatchBuilder::new()
1440            .kernel(Kernel::Auto)
1441            .apply_candles(c)
1442    }
1443
1444    pub fn kernel(mut self, k: Kernel) -> Self {
1445        self.kernel = k;
1446        self
1447    }
1448
1449    #[inline]
1450    pub fn length_range(mut self, start: usize, end: usize, step: usize) -> Self {
1451        self.range.length = (start, end, step);
1452        self
1453    }
1454
1455    #[inline]
1456    pub fn mult_range(mut self, start: f64, end: f64, step: f64) -> Self {
1457        self.range.mult = (start, end, step);
1458        self
1459    }
1460
1461    pub fn apply_slices(
1462        self,
1463        high: &[f64],
1464        low: &[f64],
1465        close: &[f64],
1466    ) -> Result<TradjemaBatchOutput, TradjemaError> {
1467        tradjema_batch_with_kernel(high, low, close, &self.range, self.kernel)
1468    }
1469
1470    pub fn apply_candles(self, c: &Candles) -> Result<TradjemaBatchOutput, TradjemaError> {
1471        let high = c
1472            .select_candle_field("high")
1473            .map_err(|_| TradjemaError::EmptyInputData)?;
1474        let low = c
1475            .select_candle_field("low")
1476            .map_err(|_| TradjemaError::EmptyInputData)?;
1477        let close = c
1478            .select_candle_field("close")
1479            .map_err(|_| TradjemaError::EmptyInputData)?;
1480        self.apply_slices(high, low, close)
1481    }
1482}
1483
1484#[derive(Clone, Debug)]
1485pub struct TradjemaBatchOutput {
1486    pub values: Vec<f64>,
1487    pub combos: Vec<TradjemaParams>,
1488    pub rows: usize,
1489    pub cols: usize,
1490}
1491
1492impl TradjemaBatchOutput {
1493    pub fn row_for_params(&self, p: &TradjemaParams) -> Option<usize> {
1494        self.combos.iter().position(|c| {
1495            c.length.unwrap_or(40) == p.length.unwrap_or(40)
1496                && (c.mult.unwrap_or(10.0) - p.mult.unwrap_or(10.0)).abs() < 1e-9
1497        })
1498    }
1499
1500    pub fn values_for(&self, p: &TradjemaParams) -> Option<&[f64]> {
1501        self.row_for_params(p).map(|row| {
1502            let start = row * self.cols;
1503            &self.values[start..start + self.cols]
1504        })
1505    }
1506}
1507
1508#[inline(always)]
1509fn expand_grid(r: &TradjemaBatchRange) -> Vec<TradjemaParams> {
1510    let (ls, le, lstep) = r.length;
1511    let (ms, me, mstep) = r.mult;
1512
1513    #[inline]
1514    fn axis_usize(start: usize, end: usize, step: usize) -> Vec<usize> {
1515        if step == 0 {
1516            return vec![start];
1517        }
1518        let mut vals = Vec::new();
1519        if start <= end {
1520            let mut v = start;
1521            while v <= end {
1522                vals.push(v);
1523                match v.checked_add(step) {
1524                    Some(n) if n > v => v = n,
1525                    _ => break,
1526                }
1527            }
1528        } else {
1529            let mut v = start;
1530            loop {
1531                vals.push(v);
1532                if v <= end {
1533                    break;
1534                }
1535                v = v.saturating_sub(step);
1536                if v < end {
1537                    break;
1538                }
1539            }
1540        }
1541        vals
1542    }
1543
1544    #[inline]
1545    fn axis_f64(start: f64, end: f64, step: f64) -> Vec<f64> {
1546        if step == 0.0 {
1547            return vec![start];
1548        }
1549        let mut vals = Vec::new();
1550        if start <= end {
1551            let mut v = start;
1552            while v <= end {
1553                vals.push(v);
1554                v += step;
1555
1556                if !v.is_finite() {
1557                    break;
1558                }
1559                if step.is_sign_negative() {
1560                    break;
1561                }
1562            }
1563        } else {
1564            let mut v = start;
1565            while v >= end {
1566                vals.push(v);
1567                v -= step.abs();
1568                if !v.is_finite() {
1569                    break;
1570                }
1571                if step == 0.0 {
1572                    break;
1573                }
1574            }
1575        }
1576        vals
1577    }
1578
1579    let lengths = axis_usize(ls, le, lstep);
1580    let mults = axis_f64(ms, me, mstep);
1581    if lengths.is_empty() || mults.is_empty() {
1582        return Vec::new();
1583    }
1584    let mut combos = Vec::with_capacity(lengths.len().saturating_mul(mults.len()));
1585    for &l in &lengths {
1586        for &m in &mults {
1587            combos.push(TradjemaParams {
1588                length: Some(l),
1589                mult: Some(m),
1590            });
1591        }
1592    }
1593    combos
1594}
1595
1596pub fn tradjema_batch_with_kernel(
1597    high: &[f64],
1598    low: &[f64],
1599    close: &[f64],
1600    sweep: &TradjemaBatchRange,
1601    k: Kernel,
1602) -> Result<TradjemaBatchOutput, TradjemaError> {
1603    let kernel = match k {
1604        Kernel::Auto => detect_best_batch_kernel(),
1605        other if other.is_batch() => other,
1606        Kernel::Scalar | Kernel::Avx2 | Kernel::Avx512 => {
1607            return Err(TradjemaError::InvalidKernelForBatch(k));
1608        }
1609        _ => detect_best_batch_kernel(),
1610    };
1611
1612    let simd = match kernel {
1613        Kernel::Avx512Batch => Kernel::Avx512,
1614        Kernel::Avx2Batch => Kernel::Avx2,
1615        Kernel::ScalarBatch => Kernel::Scalar,
1616        _ => Kernel::Scalar,
1617    };
1618
1619    let combos = expand_grid(sweep);
1620    if combos.is_empty() {
1621        let (ls, le, lstep) = sweep.length;
1622        let (ms, me, mstep) = sweep.mult;
1623
1624        let length_empty =
1625            (lstep == 0 && ls != le) || (lstep > 0 && ls > le && ls.saturating_sub(le) < lstep);
1626        if length_empty {
1627            return Err(TradjemaError::InvalidLengthRange {
1628                start: ls,
1629                end: le,
1630                step: lstep,
1631            });
1632        } else {
1633            return Err(TradjemaError::InvalidMultRange {
1634                start: ms,
1635                end: me,
1636                step: mstep,
1637            });
1638        }
1639    }
1640    let rows = combos.len();
1641    let cols = close.len();
1642    rows.checked_mul(cols)
1643        .ok_or(TradjemaError::InvalidLengthRange {
1644            start: sweep.length.0,
1645            end: sweep.length.1,
1646            step: sweep.length.2,
1647        })?;
1648
1649    tradjema_batch_inner(high, low, close, sweep, simd, true)
1650}
1651
1652#[inline(always)]
1653fn tradjema_batch_inner_into(
1654    high: &[f64],
1655    low: &[f64],
1656    close: &[f64],
1657    sweep: &TradjemaBatchRange,
1658    kern: Kernel,
1659    parallel: bool,
1660    out: &mut [f64],
1661) -> Result<Vec<TradjemaParams>, TradjemaError> {
1662    let combos = expand_grid(sweep);
1663    if combos.is_empty() {
1664        let (ls, le, lstep) = sweep.length;
1665        let (ms, me, mstep) = sweep.mult;
1666        let length_empty =
1667            (lstep == 0 && ls != le) || (lstep > 0 && ls > le && ls.saturating_sub(le) < lstep);
1668        if length_empty {
1669            return Err(TradjemaError::InvalidLengthRange {
1670                start: ls,
1671                end: le,
1672                step: lstep,
1673            });
1674        } else {
1675            return Err(TradjemaError::InvalidMultRange {
1676                start: ms,
1677                end: me,
1678                step: mstep,
1679            });
1680        }
1681    }
1682    if high.len() != low.len() || low.len() != close.len() {
1683        return Err(TradjemaError::MissingData);
1684    }
1685
1686    let cols = close.len();
1687    let first = close
1688        .iter()
1689        .position(|v| !v.is_nan())
1690        .ok_or(TradjemaError::AllValuesNaN)?;
1691
1692    #[inline(always)]
1693    fn precompute_tr(high: &[f64], low: &[f64], close: &[f64], first: usize) -> Vec<f64> {
1694        let n = close.len();
1695        let mut tr = vec![0.0f64; n];
1696        if first < n {
1697            tr[first] = high[first] - low[first];
1698            let mut i = first + 1;
1699            while i < n {
1700                let hl = high[i] - low[i];
1701                let hc = (high[i] - close[i - 1]).abs();
1702                let lc = (low[i] - close[i - 1]).abs();
1703                tr[i] = hl.max(hc).max(lc);
1704                i += 1;
1705            }
1706        }
1707        tr
1708    }
1709
1710    #[inline(always)]
1711    fn compute_from_tr_into(
1712        tr: &[f64],
1713        close: &[f64],
1714        length: usize,
1715        mult: f64,
1716        first: usize,
1717        out: &mut [f64],
1718    ) {
1719        debug_assert_eq!(tr.len(), close.len());
1720        debug_assert_eq!(close.len(), out.len());
1721
1722        let warm = first + length - 1;
1723        if warm >= out.len() {
1724            return;
1725        }
1726        let alpha = 2.0 / (length as f64 + 1.0);
1727
1728        let cap = length;
1729        let mut min_vals = vec![0.0f64; cap];
1730        let mut min_idx = vec![0usize; cap];
1731        let mut max_vals = vec![0.0f64; cap];
1732        let mut max_idx = vec![0usize; cap];
1733        let (mut min_head, mut min_tail) = (0usize, 0usize);
1734        let (mut max_head, mut max_tail) = (0usize, 0usize);
1735        #[inline(always)]
1736        fn inc(i: &mut usize, cap: usize) {
1737            *i += 1;
1738            if *i == cap {
1739                *i = 0;
1740            }
1741        }
1742        #[inline(always)]
1743        fn dec(i: usize, cap: usize) -> usize {
1744            if i == 0 {
1745                cap - 1
1746            } else {
1747                i - 1
1748            }
1749        }
1750        #[inline(always)]
1751        fn minq_push(
1752            v: f64,
1753            idx: usize,
1754            vals: &mut [f64],
1755            id: &mut [usize],
1756            head: &mut usize,
1757            tail: &mut usize,
1758            cap: usize,
1759        ) {
1760            let mut back = dec(*tail, cap);
1761            while *tail != *head && vals[back] > v {
1762                *tail = back;
1763                back = dec(*tail, cap);
1764            }
1765            vals[*tail] = v;
1766            id[*tail] = idx;
1767            inc(tail, cap);
1768        }
1769        #[inline(always)]
1770        fn maxq_push(
1771            v: f64,
1772            idx: usize,
1773            vals: &mut [f64],
1774            id: &mut [usize],
1775            head: &mut usize,
1776            tail: &mut usize,
1777            cap: usize,
1778        ) {
1779            let mut back = dec(*tail, cap);
1780            while *tail != *head && vals[back] < v {
1781                *tail = back;
1782                back = dec(*tail, cap);
1783            }
1784            vals[*tail] = v;
1785            id[*tail] = idx;
1786            inc(tail, cap);
1787        }
1788        #[inline(always)]
1789        fn q_expire(
1790            cur: usize,
1791            len: usize,
1792            id: &mut [usize],
1793            head: &mut usize,
1794            tail: &mut usize,
1795            cap: usize,
1796        ) {
1797            let lim = cur.saturating_sub(len);
1798            while *head != *tail && id[*head] <= lim {
1799                inc(head, cap);
1800            }
1801        }
1802
1803        let mut i = first;
1804        while i <= warm {
1805            let v = tr[i];
1806            minq_push(
1807                v,
1808                i,
1809                &mut min_vals,
1810                &mut min_idx,
1811                &mut min_head,
1812                &mut min_tail,
1813                cap,
1814            );
1815            maxq_push(
1816                v,
1817                i,
1818                &mut max_vals,
1819                &mut max_idx,
1820                &mut max_head,
1821                &mut max_tail,
1822                cap,
1823            );
1824            i += 1;
1825        }
1826        let lo = min_vals[min_head];
1827        let hi = max_vals[max_head];
1828        let den = hi - lo;
1829        let v = tr[warm];
1830        let tr_adj0 = if den != 0.0 { (v - lo) / den } else { 0.0 };
1831        let a0 = alpha * (1.0 + tr_adj0 * mult);
1832        let mut y = a0 * close[warm - 1];
1833        out[warm] = y;
1834
1835        i = warm + 1;
1836        while i < out.len() {
1837            q_expire(i, length, &mut min_idx, &mut min_head, &mut min_tail, cap);
1838            q_expire(i, length, &mut max_idx, &mut max_head, &mut max_tail, cap);
1839
1840            let v = tr[i];
1841            minq_push(
1842                v,
1843                i,
1844                &mut min_vals,
1845                &mut min_idx,
1846                &mut min_head,
1847                &mut min_tail,
1848                cap,
1849            );
1850            maxq_push(
1851                v,
1852                i,
1853                &mut max_vals,
1854                &mut max_idx,
1855                &mut max_head,
1856                &mut max_tail,
1857                cap,
1858            );
1859
1860            let lo = min_vals[min_head];
1861            let hi = max_vals[max_head];
1862            let den = hi - lo;
1863            let tr_adj = if den != 0.0 { (v - lo) / den } else { 0.0 };
1864            let a = alpha * (1.0 + tr_adj * mult);
1865            let src = close[i - 1];
1866            y += a * (src - y);
1867            out[i] = y;
1868
1869            i += 1;
1870        }
1871    }
1872
1873    let pre_tr = precompute_tr(high, low, close, first);
1874    let do_row = |row: usize, dst: &mut [f64]| {
1875        let p = &combos[row];
1876        let length = p.length.unwrap_or(40);
1877        let mult = p.mult.unwrap_or(10.0);
1878        if length < 2 {
1879            return;
1880        }
1881
1882        let _ = kern;
1883        compute_from_tr_into(&pre_tr, close, length, mult, first, dst);
1884    };
1885
1886    if parallel {
1887        #[cfg(not(target_arch = "wasm32"))]
1888        out.par_chunks_mut(cols)
1889            .enumerate()
1890            .for_each(|(row, slice)| do_row(row, slice));
1891        #[cfg(target_arch = "wasm32")]
1892        for (row, slice) in out.chunks_mut(cols).enumerate() {
1893            do_row(row, slice);
1894        }
1895    } else {
1896        for (row, slice) in out.chunks_mut(cols).enumerate() {
1897            do_row(row, slice);
1898        }
1899    }
1900
1901    Ok(combos)
1902}
1903
1904fn tradjema_batch_inner(
1905    high: &[f64],
1906    low: &[f64],
1907    close: &[f64],
1908    sweep: &TradjemaBatchRange,
1909    kern: Kernel,
1910    parallel: bool,
1911) -> Result<TradjemaBatchOutput, TradjemaError> {
1912    let combos = expand_grid(sweep);
1913    if combos.is_empty() {
1914        return Err(TradjemaError::InvalidLength {
1915            length: 0,
1916            data_len: 0,
1917        });
1918    }
1919    let rows = combos.len();
1920    let cols = close.len();
1921
1922    let mut buf_mu = make_uninit_matrix(rows, cols);
1923    let first = close
1924        .iter()
1925        .position(|v| !v.is_nan())
1926        .ok_or(TradjemaError::AllValuesNaN)?;
1927    let warms: Vec<usize> = combos
1928        .iter()
1929        .map(|p| first + p.length.unwrap_or(40) - 1)
1930        .collect();
1931    init_matrix_prefixes(&mut buf_mu, cols, &warms);
1932
1933    let mut guard = core::mem::ManuallyDrop::new(buf_mu);
1934    let out: &mut [f64] =
1935        unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len()) };
1936
1937    let combos = tradjema_batch_inner_into(high, low, close, sweep, kern, parallel, out)?;
1938
1939    let values = unsafe {
1940        Vec::from_raw_parts(
1941            guard.as_mut_ptr() as *mut f64,
1942            guard.len(),
1943            guard.capacity(),
1944        )
1945    };
1946
1947    Ok(TradjemaBatchOutput {
1948        values,
1949        combos,
1950        rows,
1951        cols,
1952    })
1953}
1954
1955#[cfg(test)]
1956mod tests {
1957    use super::*;
1958    use crate::skip_if_unsupported;
1959    use crate::utilities::data_loader::read_candles_from_csv;
1960    #[cfg(feature = "proptest")]
1961    use proptest::prelude::*;
1962    use std::error::Error;
1963
1964    fn check_tradjema_partial_params(
1965        test_name: &str,
1966        kernel: Kernel,
1967    ) -> Result<(), Box<dyn Error>> {
1968        skip_if_unsupported!(kernel, test_name);
1969        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1970        let candles = read_candles_from_csv(file_path)?;
1971
1972        let high = candles.select_candle_field("high")?;
1973        let low = candles.select_candle_field("low")?;
1974        let close = candles.select_candle_field("close")?;
1975
1976        let default_params = TradjemaParams {
1977            length: None,
1978            mult: None,
1979        };
1980        let input = TradjemaInput::from_slices(high, low, close, default_params);
1981        let output = tradjema_with_kernel(&input, kernel)?;
1982        assert_eq!(output.values.len(), candles.close.len());
1983
1984        Ok(())
1985    }
1986
1987    fn check_tradjema_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1988        skip_if_unsupported!(kernel, test_name);
1989        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1990        let candles = read_candles_from_csv(file_path)?;
1991
1992        let high = candles.select_candle_field("high")?;
1993        let low = candles.select_candle_field("low")?;
1994        let close = candles.select_candle_field("close")?;
1995
1996        let input = TradjemaInput::from_slices(high, low, close, TradjemaParams::default());
1997        let result = tradjema_with_kernel(&input, kernel)?;
1998
1999        assert_eq!(result.values.len(), candles.close.len());
2000
2001        let warmup = 39;
2002        for i in 0..warmup {
2003            assert!(
2004                result.values[i].is_nan(),
2005                "[{}] Expected NaN during warmup at index {}",
2006                test_name,
2007                i
2008            );
2009        }
2010
2011        for i in warmup..result.values.len() {
2012            assert!(
2013                !result.values[i].is_nan(),
2014                "[{}] Expected valid value after warmup at index {}",
2015                test_name,
2016                i
2017            );
2018        }
2019
2020        let expected_last_five = [
2021            59395.39322263,
2022            59388.09683228,
2023            59373.08371503,
2024            59350.75110897,
2025            59323.14225348,
2026        ];
2027
2028        let start = result.values.len().saturating_sub(5);
2029        for (i, &val) in result.values[start..].iter().enumerate() {
2030            let diff = (val - expected_last_five[i]).abs();
2031            assert!(
2032                diff < 1e-8,
2033                "[{}] TRADJEMA accuracy mismatch at last_5[{}]: got {:.8}, expected {:.8}, diff={:.10}",
2034                test_name,
2035                i,
2036                val,
2037                expected_last_five[i],
2038                diff
2039            );
2040        }
2041
2042        Ok(())
2043    }
2044
2045    fn check_tradjema_default_candles(
2046        test_name: &str,
2047        kernel: Kernel,
2048    ) -> Result<(), Box<dyn Error>> {
2049        skip_if_unsupported!(kernel, test_name);
2050        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2051        let candles = read_candles_from_csv(file_path)?;
2052
2053        let input = TradjemaInput::with_default_candles(&candles);
2054        let output = tradjema_with_kernel(&input, kernel)?;
2055        assert_eq!(output.values.len(), candles.close.len());
2056
2057        Ok(())
2058    }
2059
2060    fn check_tradjema_zero_length(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2061        skip_if_unsupported!(kernel, test_name);
2062        let input_data = vec![10.0, 20.0, 30.0];
2063
2064        let params = TradjemaParams {
2065            length: Some(0),
2066            mult: None,
2067        };
2068        let input = TradjemaInput::from_slices(&input_data, &input_data, &input_data, params);
2069        let res = tradjema_with_kernel(&input, kernel);
2070        assert!(
2071            res.is_err(),
2072            "[{}] TRADJEMA should fail with zero length",
2073            test_name
2074        );
2075
2076        let params = TradjemaParams {
2077            length: Some(1),
2078            mult: None,
2079        };
2080        let input = TradjemaInput::from_slices(&input_data, &input_data, &input_data, params);
2081        let res = tradjema_with_kernel(&input, kernel);
2082        assert!(
2083            res.is_err(),
2084            "[{}] TRADJEMA should fail with length=1 (minimum is 2)",
2085            test_name
2086        );
2087
2088        Ok(())
2089    }
2090
2091    fn check_tradjema_length_exceeds_data(
2092        test_name: &str,
2093        kernel: Kernel,
2094    ) -> Result<(), Box<dyn Error>> {
2095        skip_if_unsupported!(kernel, test_name);
2096        let data_small = vec![10.0, 20.0, 30.0];
2097        let params = TradjemaParams {
2098            length: Some(10),
2099            mult: None,
2100        };
2101        let input = TradjemaInput::from_slices(&data_small, &data_small, &data_small, params);
2102        let res = tradjema_with_kernel(&input, kernel);
2103        assert!(
2104            res.is_err(),
2105            "[{}] TRADJEMA should fail with length exceeding data",
2106            test_name
2107        );
2108        Ok(())
2109    }
2110
2111    fn check_tradjema_very_small_dataset(
2112        test_name: &str,
2113        kernel: Kernel,
2114    ) -> Result<(), Box<dyn Error>> {
2115        skip_if_unsupported!(kernel, test_name);
2116        let single_point = vec![42.0];
2117        let params = TradjemaParams {
2118            length: Some(40),
2119            mult: None,
2120        };
2121        let input = TradjemaInput::from_slices(&single_point, &single_point, &single_point, params);
2122        let res = tradjema_with_kernel(&input, kernel);
2123        assert!(
2124            res.is_err(),
2125            "[{}] TRADJEMA should fail with insufficient data",
2126            test_name
2127        );
2128        Ok(())
2129    }
2130
2131    fn check_tradjema_empty_input(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2132        skip_if_unsupported!(kernel, test_name);
2133        let empty: Vec<f64> = vec![];
2134        let input = TradjemaInput::from_slices(&empty, &empty, &empty, TradjemaParams::default());
2135        let res = tradjema_with_kernel(&input, kernel);
2136        assert!(
2137            matches!(res, Err(TradjemaError::EmptyInputData)),
2138            "[{}] TRADJEMA should fail with empty input",
2139            test_name
2140        );
2141        Ok(())
2142    }
2143
2144    fn check_tradjema_invalid_mult(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2145        skip_if_unsupported!(kernel, test_name);
2146        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
2147
2148        let params = TradjemaParams {
2149            length: Some(2),
2150            mult: Some(-10.0),
2151        };
2152        let input = TradjemaInput::from_slices(&data, &data, &data, params);
2153        let res = tradjema_with_kernel(&input, kernel);
2154        assert!(
2155            matches!(res, Err(TradjemaError::InvalidMult { .. })),
2156            "[{}] TRADJEMA should fail with negative mult",
2157            test_name
2158        );
2159
2160        let params = TradjemaParams {
2161            length: Some(2),
2162            mult: Some(f64::NAN),
2163        };
2164        let input = TradjemaInput::from_slices(&data, &data, &data, params);
2165        let res = tradjema_with_kernel(&input, kernel);
2166        assert!(
2167            matches!(res, Err(TradjemaError::InvalidMult { .. })),
2168            "[{}] TRADJEMA should fail with NaN mult",
2169            test_name
2170        );
2171
2172        Ok(())
2173    }
2174
2175    fn check_tradjema_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2176        skip_if_unsupported!(kernel, test_name);
2177        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2178        let candles = read_candles_from_csv(file_path)?;
2179
2180        let high = candles.select_candle_field("high")?;
2181        let low = candles.select_candle_field("low")?;
2182        let close = candles.select_candle_field("close")?;
2183
2184        let first_params = TradjemaParams {
2185            length: Some(20),
2186            mult: Some(5.0),
2187        };
2188        let first_input = TradjemaInput::from_slices(high, low, close, first_params);
2189        let first_result = tradjema_with_kernel(&first_input, kernel)?;
2190
2191        let second_params = TradjemaParams {
2192            length: Some(20),
2193            mult: Some(5.0),
2194        };
2195        let second_input = TradjemaInput::from_slices(
2196            &first_result.values,
2197            &first_result.values,
2198            &first_result.values,
2199            second_params,
2200        );
2201        let second_result = tradjema_with_kernel(&second_input, kernel)?;
2202
2203        assert_eq!(second_result.values.len(), first_result.values.len());
2204        Ok(())
2205    }
2206
2207    fn check_tradjema_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2208        skip_if_unsupported!(kernel, test_name);
2209        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2210        let candles = read_candles_from_csv(file_path)?;
2211
2212        let high = candles.select_candle_field("high")?;
2213        let low = candles.select_candle_field("low")?;
2214        let close = candles.select_candle_field("close")?;
2215
2216        let input = TradjemaInput::from_slices(
2217            high,
2218            low,
2219            close,
2220            TradjemaParams {
2221                length: Some(40),
2222                mult: Some(10.0),
2223            },
2224        );
2225        let res = tradjema_with_kernel(&input, kernel)?;
2226        assert_eq!(res.values.len(), candles.close.len());
2227
2228        if res.values.len() > 50 {
2229            for (i, &val) in res.values[50..].iter().enumerate() {
2230                assert!(
2231                    !val.is_nan(),
2232                    "[{}] Found unexpected NaN at out-index {}",
2233                    test_name,
2234                    50 + i
2235                );
2236            }
2237        }
2238        Ok(())
2239    }
2240
2241    fn check_tradjema_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2242        skip_if_unsupported!(kernel, test_name);
2243
2244        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2245        let candles = read_candles_from_csv(file_path)?;
2246
2247        let high = candles.select_candle_field("high")?;
2248        let low = candles.select_candle_field("low")?;
2249        let close = candles.select_candle_field("close")?;
2250
2251        let length = 40;
2252        let mult = 10.0;
2253
2254        let input = TradjemaInput::from_slices(
2255            high,
2256            low,
2257            close,
2258            TradjemaParams {
2259                length: Some(length),
2260                mult: Some(mult),
2261            },
2262        );
2263        let batch_output = tradjema_with_kernel(&input, kernel)?.values;
2264
2265        let mut stream = TradjemaStream::try_new(TradjemaParams {
2266            length: Some(length),
2267            mult: Some(mult),
2268        })?;
2269
2270        let mut stream_values = Vec::with_capacity(candles.close.len());
2271        for i in 0..candles.close.len() {
2272            match stream.update(high[i], low[i], close[i]) {
2273                Some(val) => stream_values.push(val),
2274                None => stream_values.push(f64::NAN),
2275            }
2276        }
2277
2278        assert_eq!(batch_output.len(), stream_values.len());
2279        for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
2280            if b.is_nan() && s.is_nan() {
2281                continue;
2282            }
2283            let diff = (b - s).abs();
2284            assert!(
2285                diff < 1e-9,
2286                "[{}] TRADJEMA streaming mismatch at idx {}: batch={}, stream={}, diff={}",
2287                test_name,
2288                i,
2289                b,
2290                s,
2291                diff
2292            );
2293        }
2294        Ok(())
2295    }
2296
2297    #[cfg(debug_assertions)]
2298    fn check_tradjema_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2299        skip_if_unsupported!(kernel, test_name);
2300
2301        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2302        let candles = read_candles_from_csv(file_path)?;
2303
2304        let high = candles.select_candle_field("high")?;
2305        let low = candles.select_candle_field("low")?;
2306        let close = candles.select_candle_field("close")?;
2307
2308        let test_params = vec![
2309            TradjemaParams::default(),
2310            TradjemaParams {
2311                length: Some(10),
2312                mult: Some(5.0),
2313            },
2314            TradjemaParams {
2315                length: Some(20),
2316                mult: Some(7.5),
2317            },
2318            TradjemaParams {
2319                length: Some(50),
2320                mult: Some(15.0),
2321            },
2322            TradjemaParams {
2323                length: Some(100),
2324                mult: Some(20.0),
2325            },
2326        ];
2327
2328        for (param_idx, params) in test_params.iter().enumerate() {
2329            let input = TradjemaInput::from_slices(high, low, close, params.clone());
2330            let output = tradjema_with_kernel(&input, kernel)?;
2331
2332            for (i, &val) in output.values.iter().enumerate() {
2333                if val.is_nan() {
2334                    continue;
2335                }
2336
2337                let bits = val.to_bits();
2338
2339                if bits == 0x11111111_11111111 {
2340                    panic!(
2341                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
2342                        with params: length={}, mult={}",
2343                        test_name,
2344                        val,
2345                        bits,
2346                        i,
2347                        params.length.unwrap_or(40),
2348                        params.mult.unwrap_or(10.0)
2349                    );
2350                }
2351
2352                if bits == 0x22222222_22222222 {
2353                    panic!(
2354                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
2355                        with params: length={}, mult={}",
2356                        test_name,
2357                        val,
2358                        bits,
2359                        i,
2360                        params.length.unwrap_or(40),
2361                        params.mult.unwrap_or(10.0)
2362                    );
2363                }
2364
2365                if bits == 0x33333333_33333333 {
2366                    panic!(
2367                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
2368                        with params: length={}, mult={}",
2369                        test_name,
2370                        val,
2371                        bits,
2372                        i,
2373                        params.length.unwrap_or(40),
2374                        params.mult.unwrap_or(10.0)
2375                    );
2376                }
2377            }
2378        }
2379
2380        Ok(())
2381    }
2382
2383    #[cfg(not(debug_assertions))]
2384    fn check_tradjema_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2385        Ok(())
2386    }
2387
2388    #[cfg(feature = "proptest")]
2389    #[allow(clippy::float_cmp)]
2390    fn check_tradjema_property(
2391        test_name: &str,
2392        kernel: Kernel,
2393    ) -> Result<(), Box<dyn std::error::Error>> {
2394        use proptest::prelude::*;
2395        skip_if_unsupported!(kernel, test_name);
2396
2397        let strat = (2usize..=100).prop_flat_map(|length| {
2398            (
2399                prop::collection::vec(
2400                    (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
2401                    length..400,
2402                ),
2403                Just(length),
2404                0.1f64..50.0f64,
2405            )
2406        });
2407
2408        proptest::test_runner::TestRunner::default()
2409            .run(&strat, |(data, length, mult)| {
2410                let params = TradjemaParams {
2411                    length: Some(length),
2412                    mult: Some(mult),
2413                };
2414
2415                let input = TradjemaInput::from_slices(&data, &data, &data, params);
2416
2417                let TradjemaOutput { values: out } = tradjema_with_kernel(&input, kernel).unwrap();
2418                let TradjemaOutput { values: ref_out } =
2419                    tradjema_with_kernel(&input, Kernel::Scalar).unwrap();
2420
2421                for i in (length - 1)..data.len() {
2422                    let y = out[i];
2423                    let r = ref_out[i];
2424
2425                    if !y.is_finite() || !r.is_finite() {
2426                        prop_assert!(
2427                            y.to_bits() == r.to_bits(),
2428                            "finite/NaN mismatch idx {i}: {y} vs {r}"
2429                        );
2430                        continue;
2431                    }
2432
2433                    let ulp_diff: u64 = y.to_bits().abs_diff(r.to_bits());
2434
2435                    prop_assert!(
2436                        (y - r).abs() <= 1e-9 || ulp_diff <= 4,
2437                        "mismatch idx {i}: {y} vs {r} (ULP={ulp_diff})"
2438                    );
2439                }
2440                Ok(())
2441            })
2442            .unwrap();
2443
2444        Ok(())
2445    }
2446
2447    macro_rules! generate_all_tradjema_tests {
2448        ($($test_fn:ident),*) => {
2449            paste::paste! {
2450                $(
2451                    #[test]
2452                    fn [<$test_fn _scalar_f64>]() {
2453                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
2454                    }
2455                )*
2456                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2457                $(
2458                    #[test]
2459                    fn [<$test_fn _avx2_f64>]() {
2460                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
2461                    }
2462                    #[test]
2463                    fn [<$test_fn _avx512_f64>]() {
2464                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
2465                    }
2466                )*
2467                #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
2468                $(
2469                    #[test]
2470                    fn [<$test_fn _simd128_f64>]() {
2471                        let _ = $test_fn(stringify!([<$test_fn _simd128_f64>]), Kernel::Scalar);
2472                    }
2473                )*
2474            }
2475        }
2476    }
2477
2478    generate_all_tradjema_tests!(
2479        check_tradjema_partial_params,
2480        check_tradjema_accuracy,
2481        check_tradjema_default_candles,
2482        check_tradjema_zero_length,
2483        check_tradjema_length_exceeds_data,
2484        check_tradjema_very_small_dataset,
2485        check_tradjema_empty_input,
2486        check_tradjema_invalid_mult,
2487        check_tradjema_reinput,
2488        check_tradjema_nan_handling,
2489        check_tradjema_streaming,
2490        check_tradjema_no_poison
2491    );
2492
2493    fn check_tradjema_into_slice(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2494        skip_if_unsupported!(kernel, test_name);
2495        let f = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2496        let c = read_candles_from_csv(f)?;
2497        let (h, l, cl) = (
2498            c.select_candle_field("high")?,
2499            c.select_candle_field("low")?,
2500            c.select_candle_field("close")?,
2501        );
2502        let input = TradjemaInput::from_slices(h, l, cl, TradjemaParams::default());
2503        let mut dst = vec![0.0; cl.len()];
2504        tradjema_into_slice(&mut dst, &input, kernel)?;
2505        let first = cl.iter().position(|v| !v.is_nan()).unwrap();
2506        let warm = first + input.get_length() - 1;
2507        assert!(
2508            dst[..warm].iter().all(|v| v.is_nan()),
2509            "[{}] warmup prefix must be NaN",
2510            test_name
2511        );
2512        Ok(())
2513    }
2514
2515    generate_all_tradjema_tests!(check_tradjema_into_slice);
2516
2517    #[cfg(feature = "proptest")]
2518    generate_all_tradjema_tests!(check_tradjema_property);
2519
2520    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2521        skip_if_unsupported!(kernel, test);
2522
2523        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2524        let c = read_candles_from_csv(file)?;
2525
2526        let output = TradjemaBatchBuilder::new()
2527            .kernel(kernel)
2528            .apply_candles(&c)?;
2529
2530        let def = TradjemaParams::default();
2531        let row = output.values_for(&def).expect("default row missing");
2532
2533        assert_eq!(row.len(), c.close.len());
2534        Ok(())
2535    }
2536
2537    fn check_batch_sweep(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2538        skip_if_unsupported!(kernel, test);
2539
2540        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2541        let c = read_candles_from_csv(file)?;
2542
2543        let output = TradjemaBatchBuilder::new()
2544            .kernel(kernel)
2545            .length_range(20, 50, 10)
2546            .mult_range(5.0, 15.0, 5.0)
2547            .apply_candles(&c)?;
2548
2549        let expected_combos = 4 * 3;
2550        assert_eq!(output.combos.len(), expected_combos);
2551        assert_eq!(output.rows, expected_combos);
2552        assert_eq!(output.cols, c.close.len());
2553
2554        Ok(())
2555    }
2556
2557    #[cfg(debug_assertions)]
2558    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2559        skip_if_unsupported!(kernel, test);
2560
2561        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2562        let c = read_candles_from_csv(file)?;
2563
2564        let test_configs = vec![
2565            (10, 30, 10, 5.0, 15.0, 5.0),
2566            (40, 40, 0, 10.0, 10.0, 0.0),
2567            (20, 60, 20, 7.5, 12.5, 2.5),
2568        ];
2569
2570        for (cfg_idx, &(l_start, l_end, l_step, m_start, m_end, m_step)) in
2571            test_configs.iter().enumerate()
2572        {
2573            let output = TradjemaBatchBuilder::new()
2574                .kernel(kernel)
2575                .length_range(l_start, l_end, l_step)
2576                .mult_range(m_start, m_end, m_step)
2577                .apply_candles(&c)?;
2578
2579            for (idx, &val) in output.values.iter().enumerate() {
2580                if val.is_nan() {
2581                    continue;
2582                }
2583
2584                let bits = val.to_bits();
2585                let row = idx / output.cols;
2586                let col = idx % output.cols;
2587                let combo = &output.combos[row];
2588
2589                if bits == 0x11111111_11111111
2590                    || bits == 0x22222222_22222222
2591                    || bits == 0x33333333_33333333
2592                {
2593                    panic!(
2594                        "[{}] Config {}: Found poison value {} (0x{:016X}) \
2595                        at row {} col {} (flat index {}) with params: length={}, mult={}",
2596                        test,
2597                        cfg_idx,
2598                        val,
2599                        bits,
2600                        row,
2601                        col,
2602                        idx,
2603                        combo.length.unwrap_or(40),
2604                        combo.mult.unwrap_or(10.0)
2605                    );
2606                }
2607            }
2608        }
2609
2610        Ok(())
2611    }
2612
2613    #[cfg(not(debug_assertions))]
2614    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2615        Ok(())
2616    }
2617
2618    macro_rules! gen_batch_tests {
2619        ($fn_name:ident) => {
2620            paste::paste! {
2621                #[test] fn [<$fn_name _scalar>]()      {
2622                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
2623                }
2624                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2625                #[test] fn [<$fn_name _avx2>]()        {
2626                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
2627                }
2628                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2629                #[test] fn [<$fn_name _avx512>]()      {
2630                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
2631                }
2632                #[test] fn [<$fn_name _auto_detect>]() {
2633                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]),
2634                                    Kernel::Auto);
2635                }
2636            }
2637        };
2638    }
2639
2640    gen_batch_tests!(check_batch_default_row);
2641    gen_batch_tests!(check_batch_sweep);
2642    gen_batch_tests!(check_batch_no_poison);
2643
2644    #[test]
2645    fn test_expand_grid_single_value() {
2646        let range = TradjemaBatchRange {
2647            length: (2, 2, 0),
2648            mult: (10.0, 10.0, 0.0),
2649        };
2650        let combos = expand_grid(&range);
2651        assert!(
2652            !combos.is_empty(),
2653            "expand_grid should not return empty for single value"
2654        );
2655        assert_eq!(combos.len(), 1, "Should have exactly one combo");
2656        assert_eq!(combos[0].length, Some(2));
2657        assert_eq!(combos[0].mult, Some(10.0));
2658
2659        let range2 = TradjemaBatchRange {
2660            length: (40, 40, 0),
2661            mult: (10.0, 10.0, 0.0),
2662        };
2663        let combos2 = expand_grid(&range2);
2664        assert!(
2665            !combos2.is_empty(),
2666            "expand_grid should not return empty for single value (40,40,0)"
2667        );
2668        assert_eq!(
2669            combos2.len(),
2670            1,
2671            "Should have exactly one combo for (40,40,0)"
2672        );
2673        assert_eq!(combos2[0].length, Some(40));
2674        assert_eq!(combos2[0].mult, Some(10.0));
2675    }
2676
2677    #[test]
2678    fn test_tradjema_into_matches_api() -> Result<(), Box<dyn Error>> {
2679        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2680        let c = read_candles_from_csv(file)?;
2681        let (h, l, cl) = (
2682            c.select_candle_field("high")?,
2683            c.select_candle_field("low")?,
2684            c.select_candle_field("close")?,
2685        );
2686
2687        let input = TradjemaInput::from_slices(h, l, cl, TradjemaParams::default());
2688
2689        let base = tradjema(&input)?.values;
2690
2691        let mut out = vec![0.0; cl.len()];
2692        tradjema_into(&input, &mut out)?;
2693
2694        assert_eq!(base.len(), out.len());
2695
2696        for i in 0..out.len() {
2697            let a = base[i];
2698            let b = out[i];
2699            if a.is_nan() || b.is_nan() {
2700                assert!(
2701                    a.is_nan() && b.is_nan(),
2702                    "NaN mismatch at index {}: {:?} vs {:?}",
2703                    i,
2704                    a,
2705                    b
2706                );
2707            } else {
2708                assert_eq!(a, b, "Value mismatch at index {}: {} vs {}", i, a, b);
2709            }
2710        }
2711
2712        Ok(())
2713    }
2714}