1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::cuda_available;
3#[cfg(all(feature = "python", feature = "cuda"))]
4use crate::cuda::moving_averages::CudaTradjema;
5#[cfg(all(feature = "python", feature = "cuda"))]
6use crate::cuda::moving_averages::DeviceArrayF32;
7#[cfg(all(feature = "python", feature = "cuda"))]
8use cust::context::Context;
9#[cfg(all(feature = "python", feature = "cuda"))]
10use numpy::PyUntypedArrayMethods;
11#[cfg(feature = "python")]
12use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1, PyReadonlyArray2};
13#[cfg(feature = "python")]
14use pyo3::exceptions::PyValueError;
15#[cfg(feature = "python")]
16use pyo3::prelude::*;
17#[cfg(feature = "python")]
18use pyo3::types::{PyDict, PyList};
19#[cfg(all(feature = "python", feature = "cuda"))]
20use std::sync::Arc;
21
22#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
23use serde::{Deserialize, Serialize};
24#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
25use wasm_bindgen::prelude::*;
26
27use crate::utilities::data_loader::Candles;
28use crate::utilities::enums::Kernel;
29use crate::utilities::helpers::{
30 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
31 make_uninit_matrix,
32};
33#[cfg(feature = "python")]
34use crate::utilities::kernel_validation::validate_kernel;
35
36#[cfg(not(target_arch = "wasm32"))]
37use rayon::prelude::*;
38
39use std::convert::AsRef;
40use std::error::Error;
41use std::mem::MaybeUninit;
42use thiserror::Error;
43
/// Input source for TRADJEMA: either a candle set or explicit high/low/close slices.
#[derive(Debug, Clone)]
pub enum TradjemaData<'a> {
    /// Borrowed candle set; `high`, `low` and `close` are read from it with
    /// `select_candle_field` during input preparation.
    Candles {
        candles: &'a Candles,
    },
    /// Explicit price series. All three slices must have the same length,
    /// otherwise preparation fails with `TradjemaError::MissingData`.
    Slices {
        high: &'a [f64],
        low: &'a [f64],
        close: &'a [f64],
    },
}
55
/// Result of a TRADJEMA computation.
#[derive(Debug, Clone)]
pub struct TradjemaOutput {
    /// One value per input bar. The warm-up prefix (`first_valid_close + length - 1`
    /// bars) holds NaN.
    pub values: Vec<f64>,
}
60
/// Parameters for TRADJEMA. A `None` field falls back to its default
/// (`length = 40`, `mult = 10.0`).
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct TradjemaParams {
    /// Window length. It sets the EMA smoothing factor `2 / (length + 1)` and the
    /// rolling true-range window. Must be `>= 2` and no longer than the data.
    pub length: Option<usize>,
    /// Multiplier applied to the normalised true range when adapting the
    /// smoothing factor. Must be finite and `> 0`.
    pub mult: Option<f64>,
}
70
71impl Default for TradjemaParams {
72 fn default() -> Self {
73 Self {
74 length: Some(40),
75 mult: Some(10.0),
76 }
77 }
78}
79
/// A TRADJEMA request: the price data plus the parameters to apply to it.
#[derive(Debug, Clone)]
pub struct TradjemaInput<'a> {
    /// Where the high/low/close series come from.
    pub data: TradjemaData<'a>,
    /// Window length and multiplier (unset fields use the defaults).
    pub params: TradjemaParams,
}
85
86impl<'a> TradjemaInput<'a> {
87 #[inline]
88 pub fn from_candles(candles: &'a Candles, params: TradjemaParams) -> Self {
89 Self {
90 data: TradjemaData::Candles { candles },
91 params,
92 }
93 }
94
95 #[inline]
96 pub fn from_slices(
97 high: &'a [f64],
98 low: &'a [f64],
99 close: &'a [f64],
100 params: TradjemaParams,
101 ) -> Self {
102 Self {
103 data: TradjemaData::Slices { high, low, close },
104 params,
105 }
106 }
107
108 #[inline]
109 pub fn with_default_candles(candles: &'a Candles) -> Self {
110 Self::from_candles(candles, TradjemaParams::default())
111 }
112
113 #[inline]
114 pub fn get_length(&self) -> usize {
115 self.params.length.unwrap_or(40)
116 }
117
118 #[inline]
119 pub fn get_mult(&self) -> f64 {
120 self.params.mult.unwrap_or(10.0)
121 }
122}
123
/// Fluent builder for one-shot TRADJEMA calls or a [`TradjemaStream`].
#[derive(Copy, Clone, Debug)]
pub struct TradjemaBuilder {
    /// Window length; `None` uses the default.
    length: Option<usize>,
    /// Adaptation multiplier; `None` uses the default.
    mult: Option<f64>,
    /// Kernel requested for batch-style (non-stream) computation.
    kernel: Kernel,
}
130
131impl Default for TradjemaBuilder {
132 fn default() -> Self {
133 Self {
134 length: None,
135 mult: None,
136 kernel: Kernel::Auto,
137 }
138 }
139}
140
141impl TradjemaBuilder {
142 #[inline(always)]
143 pub fn new() -> Self {
144 Self::default()
145 }
146
147 #[inline(always)]
148 pub fn length(mut self, n: usize) -> Self {
149 self.length = Some(n);
150 self
151 }
152
153 #[inline(always)]
154 pub fn mult(mut self, m: f64) -> Self {
155 self.mult = Some(m);
156 self
157 }
158
159 #[inline(always)]
160 pub fn kernel(mut self, k: Kernel) -> Self {
161 self.kernel = k;
162 self
163 }
164
165 #[inline(always)]
166 pub fn apply(self, c: &Candles) -> Result<TradjemaOutput, TradjemaError> {
167 let p = TradjemaParams {
168 length: self.length,
169 mult: self.mult,
170 };
171 let i = TradjemaInput::from_candles(c, p);
172 tradjema_with_kernel(&i, self.kernel)
173 }
174
175 #[inline(always)]
176 pub fn apply_slices(
177 self,
178 high: &[f64],
179 low: &[f64],
180 close: &[f64],
181 ) -> Result<TradjemaOutput, TradjemaError> {
182 let p = TradjemaParams {
183 length: self.length,
184 mult: self.mult,
185 };
186 let i = TradjemaInput::from_slices(high, low, close, p);
187 tradjema_with_kernel(&i, self.kernel)
188 }
189
190 #[inline(always)]
191 pub fn into_stream(self) -> Result<TradjemaStream, TradjemaError> {
192 let p = TradjemaParams {
193 length: self.length,
194 mult: self.mult,
195 };
196 TradjemaStream::try_new(p)
197 }
198}
199
/// Errors produced by the TRADJEMA API.
#[derive(Debug, Error)]
pub enum TradjemaError {
    /// The series is empty, or a candle field could not be selected.
    #[error("tradjema: Input data slice is empty.")]
    EmptyInputData,

    /// Every `close` value is NaN.
    #[error("tradjema: All values are NaN.")]
    AllValuesNaN,

    /// `length < 2` or `length` exceeds the data length. The streaming
    /// constructor reports `data_len = 0`.
    #[error("tradjema: Invalid length: length = {length}, data_len = {data_len}")]
    InvalidLength { length: usize, data_len: usize },

    /// Fewer than `length` bars follow the first non-NaN close.
    #[error("tradjema: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },

    /// The high, low and close series differ in length.
    #[error("tradjema: OHLC data length mismatch")]
    MissingData,

    /// `mult` is not finite or is `<= 0`.
    #[error("tradjema: Invalid multiplier: {mult}")]
    InvalidMult { mult: f64 },

    /// A caller-supplied output buffer does not match the input length.
    #[error("tradjema: Output length mismatch: expected {expected}, got {got}")]
    OutputLengthMismatch { expected: usize, got: usize },

    /// Invalid `length` sweep for batch computation (presumably raised by the
    /// batch grid code outside this section — confirm).
    #[error("tradjema: Invalid length range (start={start}, end={end}, step={step})")]
    InvalidLengthRange {
        start: usize,
        end: usize,
        step: usize,
    },

    /// Invalid `mult` sweep for batch computation (presumably raised by the
    /// batch grid code outside this section — confirm).
    #[error("tradjema: Invalid mult range (start={start}, end={end}, step={step})")]
    InvalidMultRange { start: f64, end: f64, step: f64 },

    /// A non-batch kernel was passed to a batch entry point.
    #[error("tradjema: non-batch kernel passed to batch path: {0:?}")]
    InvalidKernelForBatch(Kernel),
}
236
237#[inline(always)]
238fn tradjema_prepare<'a>(
239 input: &'a TradjemaInput,
240 kernel: Kernel,
241) -> Result<(&'a [f64], &'a [f64], &'a [f64], usize, usize, f64, Kernel), TradjemaError> {
242 let (high, low, close) = match &input.data {
243 TradjemaData::Candles { candles } => {
244 let h = candles
245 .select_candle_field("high")
246 .map_err(|_| TradjemaError::EmptyInputData)?;
247 let l = candles
248 .select_candle_field("low")
249 .map_err(|_| TradjemaError::EmptyInputData)?;
250 let c = candles
251 .select_candle_field("close")
252 .map_err(|_| TradjemaError::EmptyInputData)?;
253 (h, l, c)
254 }
255 TradjemaData::Slices { high, low, close } => {
256 if high.len() != low.len() || low.len() != close.len() {
257 return Err(TradjemaError::MissingData);
258 }
259 (*high, *low, *close)
260 }
261 };
262
263 let len = close.len();
264 if len == 0 {
265 return Err(TradjemaError::EmptyInputData);
266 }
267
268 let first = close
269 .iter()
270 .position(|v| !v.is_nan())
271 .ok_or(TradjemaError::AllValuesNaN)?;
272 let length = input.get_length();
273 if length < 2 || length > len {
274 return Err(TradjemaError::InvalidLength {
275 length,
276 data_len: len,
277 });
278 }
279 if len - first < length {
280 return Err(TradjemaError::NotEnoughValidData {
281 needed: length,
282 valid: len - first,
283 });
284 }
285
286 let mult = input.get_mult();
287 if mult <= 0.0 || !mult.is_finite() {
288 return Err(TradjemaError::InvalidMult { mult });
289 }
290
291 let chosen = match kernel {
292 Kernel::Auto => Kernel::Scalar,
293 k => k,
294 };
295 Ok((high, low, close, length, first, mult, chosen))
296}
297
298#[inline]
299pub fn tradjema(input: &TradjemaInput) -> Result<TradjemaOutput, TradjemaError> {
300 tradjema_with_kernel(input, Kernel::Auto)
301}
302
303pub fn tradjema_with_kernel(
304 input: &TradjemaInput,
305 kernel: Kernel,
306) -> Result<TradjemaOutput, TradjemaError> {
307 let (h, l, c, length, first, mult, chosen) = tradjema_prepare(input, kernel)?;
308 let warm = first + length - 1;
309 let mut out = alloc_with_nan_prefix(c.len(), warm);
310 tradjema_compute_into(h, l, c, length, mult, first, chosen, &mut out);
311 Ok(TradjemaOutput { values: out })
312}
313
314#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
315#[inline]
316pub fn tradjema_into(input: &TradjemaInput, out: &mut [f64]) -> Result<(), TradjemaError> {
317 let (h, l, c, length, first, mult, chosen) = tradjema_prepare(input, Kernel::Auto)?;
318 if out.len() != c.len() {
319 return Err(TradjemaError::OutputLengthMismatch {
320 expected: c.len(),
321 got: out.len(),
322 });
323 }
324
325 let warm = first + length - 1;
326 let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
327 let w = warm.min(out.len());
328 for v in &mut out[..w] {
329 *v = qnan;
330 }
331
332 tradjema_compute_into(h, l, c, length, mult, first, chosen, out);
333 Ok(())
334}
335
336#[inline]
337pub fn tradjema_into_slice(
338 dst: &mut [f64],
339 input: &TradjemaInput,
340 kern: Kernel,
341) -> Result<(), TradjemaError> {
342 let (h, l, c, length, first, mult, chosen) = tradjema_prepare(input, kern)?;
343 if dst.len() != c.len() {
344 return Err(TradjemaError::OutputLengthMismatch {
345 expected: c.len(),
346 got: dst.len(),
347 });
348 }
349 tradjema_compute_into(h, l, c, length, mult, first, chosen, dst);
350
351 let warm = first + length - 1;
352 for v in &mut dst[..warm] {
353 *v = f64::NAN;
354 }
355 Ok(())
356}
357
/// Scalar TRADJEMA kernel.
///
/// Computes the true range (TR) of each bar from `first` onward. The bar at
/// `first` uses `high - low`; later bars use the classic three-way max against
/// the previous close. The rolling minimum and maximum TR over the last `length`
/// bars are tracked with two monotonic deques. The normalised TR then adapts the
/// smoothing factor of an EMA applied to the *previous* close:
///
/// ```text
/// alpha   = 2 / (length + 1)
/// tr_adj  = (tr - min_tr) / (max_tr - min_tr)   (0 when the window is flat)
/// a       = alpha * (1 + tr_adj * mult)
/// y[warm] = close[warm - 1] * a
/// y[i]    = y[i-1] + a * (close[i-1] - y[i-1])
/// ```
///
/// where `warm = first + length - 1`. Only `out[warm..]` is written; the caller
/// owns the warm-up prefix. Nothing is written when `warm >= out.len()`.
///
/// # Panics
/// If `high`, `low`, `close` and `out` differ in length, or if `length < 2`.
#[inline(always)]
fn tradjema_compute_into_scalar(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    length: usize,
    mult: f64,
    first: usize,
    out: &mut [f64],
) {
    let n = out.len();
    // Hard (not debug-only) checks: all unchecked indexing below depends on them.
    // They run once per call, so the cost is negligible.
    assert!(
        high.len() == n && low.len() == n && close.len() == n,
        "tradjema: high/low/close/out length mismatch"
    );
    assert!(length >= 2, "tradjema: length must be >= 2");

    let warm = first + length - 1;
    if warm >= n {
        return;
    }

    let alpha = 2.0 / (length as f64 + 1.0);

    // A window of `length` bars can leave `length` live entries in a deque (for
    // example, a strictly increasing TR keeps every entry in the min deque). These
    // ring buffers treat `head == tail` as "empty", so a capacity of exactly
    // `length` would make a full deque look empty. One spare slot prevents that.
    let cap = length + 1;
    let mut min_vals = vec![0.0f64; cap];
    let mut min_idx = vec![0usize; cap];
    let mut max_vals = vec![0.0f64; cap];
    let mut max_idx = vec![0usize; cap];
    let (mut min_head, mut min_tail) = (0usize, 0usize);
    let (mut max_head, mut max_tail) = (0usize, 0usize);

    /// Advances a ring-buffer cursor, wrapping at `cap`.
    #[inline(always)]
    fn inc(i: &mut usize, cap: usize) {
        *i += 1;
        if *i == cap {
            *i = 0;
        }
    }
    /// Returns the ring-buffer slot before `i`, wrapping at `cap`.
    #[inline(always)]
    fn dec(i: usize, cap: usize) -> usize {
        if i == 0 {
            cap - 1
        } else {
            i - 1
        }
    }
    /// Pushes onto the rolling-minimum deque, first dropping entries greater than `v`.
    #[inline(always)]
    fn minq_push(
        v: f64,
        idx: usize,
        vals: &mut [f64],
        id: &mut [usize],
        head: &mut usize,
        tail: &mut usize,
        cap: usize,
    ) {
        let mut back = dec(*tail, cap);
        // SAFETY: `back`, `*tail` are always produced by `inc`/`dec` with the same
        // `cap`, so they are `< cap == vals.len() == id.len()`.
        while *tail != *head && unsafe { *vals.get_unchecked(back) } > v {
            *tail = back;
            back = dec(*tail, cap);
        }
        unsafe {
            *vals.get_unchecked_mut(*tail) = v;
            *id.get_unchecked_mut(*tail) = idx;
        }
        inc(tail, cap);
    }
    /// Pushes onto the rolling-maximum deque, first dropping entries smaller than `v`.
    #[inline(always)]
    fn maxq_push(
        v: f64,
        idx: usize,
        vals: &mut [f64],
        id: &mut [usize],
        head: &mut usize,
        tail: &mut usize,
        cap: usize,
    ) {
        let mut back = dec(*tail, cap);
        // SAFETY: same cursor invariant as `minq_push`.
        while *tail != *head && unsafe { *vals.get_unchecked(back) } < v {
            *tail = back;
            back = dec(*tail, cap);
        }
        unsafe {
            *vals.get_unchecked_mut(*tail) = v;
            *id.get_unchecked_mut(*tail) = idx;
        }
        inc(tail, cap);
    }
    /// Drops front entries whose bar index has left the window `(cur - len, cur]`.
    #[inline(always)]
    fn q_expire(
        cur: usize,
        len: usize,
        id: &mut [usize],
        head: &mut usize,
        tail: &mut usize,
        cap: usize,
    ) {
        let lim = cur.saturating_sub(len);
        // SAFETY: `*head < cap == id.len()` by the cursor invariant.
        while *head != *tail && unsafe { *id.get_unchecked(*head) } <= lim {
            inc(head, cap);
        }
    }

    /// Maximum of three values (NaN-propagation follows plain `>` comparisons).
    #[inline(always)]
    fn max3(a: f64, b: f64, c: f64) -> f64 {
        let m = if a > b { a } else { b };
        if m > c {
            m
        } else {
            c
        }
    }

    // SAFETY (all price reads below): indices are `<= i < n`, and the asserts above
    // guarantee `high.len() == low.len() == close.len() == out.len() == n`.
    let tr0 = unsafe { *high.get_unchecked(first) - *low.get_unchecked(first) };
    minq_push(
        tr0,
        first,
        &mut min_vals,
        &mut min_idx,
        &mut min_head,
        &mut min_tail,
        cap,
    );
    maxq_push(
        tr0,
        first,
        &mut max_vals,
        &mut max_idx,
        &mut max_head,
        &mut max_tail,
        cap,
    );
    let mut last_tr = tr0;

    // Warm-up: fill the first full window `first..=warm` (exactly `length` bars,
    // so nothing needs to expire yet).
    let mut i = first + 1;
    while i <= warm {
        let hi = unsafe { *high.get_unchecked(i) };
        let lo = unsafe { *low.get_unchecked(i) };
        let pc1 = unsafe { *close.get_unchecked(i - 1) };
        let tr = max3(hi - lo, (hi - pc1).abs(), (lo - pc1).abs());
        minq_push(
            tr,
            i,
            &mut min_vals,
            &mut min_idx,
            &mut min_head,
            &mut min_tail,
            cap,
        );
        maxq_push(
            tr,
            i,
            &mut max_vals,
            &mut max_idx,
            &mut max_head,
            &mut max_tail,
            cap,
        );
        last_tr = tr;
        i += 1;
    }

    let tr_low = unsafe { *min_vals.get_unchecked(min_head) };
    let tr_high = unsafe { *max_vals.get_unchecked(max_head) };
    let denom = tr_high - tr_low;
    let tr_adj0 = if denom != 0.0 {
        (last_tr - tr_low) / denom
    } else {
        0.0
    };
    let a0 = alpha * (1.0 + tr_adj0 * mult);
    // The EMA tracks the previous close; the seed value starts from zero.
    let src0 = unsafe { *close.get_unchecked(warm - 1) };
    let mut y = src0.mul_add(a0, 0.0);
    unsafe {
        *out.get_unchecked_mut(warm) = y;
    }

    i = warm + 1;
    while i < n {
        q_expire(i, length, &mut min_idx, &mut min_head, &mut min_tail, cap);
        q_expire(i, length, &mut max_idx, &mut max_head, &mut max_tail, cap);

        let hi = unsafe { *high.get_unchecked(i) };
        let lo = unsafe { *low.get_unchecked(i) };
        let pc1 = unsafe { *close.get_unchecked(i - 1) };
        let tr = max3(hi - lo, (hi - pc1).abs(), (lo - pc1).abs());
        minq_push(
            tr,
            i,
            &mut min_vals,
            &mut min_idx,
            &mut min_head,
            &mut min_tail,
            cap,
        );
        maxq_push(
            tr,
            i,
            &mut max_vals,
            &mut max_idx,
            &mut max_head,
            &mut max_tail,
            cap,
        );

        let lo_tr = unsafe { *min_vals.get_unchecked(min_head) };
        let hi_tr = unsafe { *max_vals.get_unchecked(max_head) };
        let den = hi_tr - lo_tr;
        let tr_adj = if den != 0.0 { (tr - lo_tr) / den } else { 0.0 };
        let a = alpha * (1.0 + tr_adj * mult);
        let src = pc1;
        y = (src - y).mul_add(a, y);
        unsafe {
            *out.get_unchecked_mut(i) = y;
        }

        i += 1;
    }
}
577
578#[inline(always)]
579fn tradjema_compute_into(
580 high: &[f64],
581 low: &[f64],
582 close: &[f64],
583 length: usize,
584 mult: f64,
585 first: usize,
586 kern: Kernel,
587 out: &mut [f64],
588) {
589 unsafe {
590 #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
591 {
592 if matches!(kern, Kernel::Scalar | Kernel::ScalarBatch) {
593 tradjema_compute_into_scalar(high, low, close, length, mult, first, out);
594 return;
595 }
596 }
597 match kern {
598 Kernel::Scalar | Kernel::ScalarBatch => {
599 tradjema_compute_into_scalar(high, low, close, length, mult, first, out)
600 }
601 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
602 Kernel::Avx2 | Kernel::Avx2Batch => {
603 tradjema_compute_into_avx2(high, low, close, length, mult, first, out)
604 }
605 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
606 Kernel::Avx512 | Kernel::Avx512Batch => {
607 tradjema_compute_into_avx512(high, low, close, length, mult, first, out)
608 }
609 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
610 Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
611 tradjema_compute_into_scalar(high, low, close, length, mult, first, out)
612 }
613 _ => unreachable!(),
614 }
615 }
616}
617
/// AVX2 entry point for the TRADJEMA kernel.
///
/// It currently delegates to the scalar kernel. Because that kernel is
/// `#[inline(always)]`, it is compiled here with AVX2 enabled, so the compiler
/// may use AVX2 instructions for it.
///
/// # Safety
/// The caller must ensure the running CPU supports AVX2.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn tradjema_compute_into_avx2(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    length: usize,
    mult: f64,
    first: usize,
    out: &mut [f64],
) {
    tradjema_compute_into_scalar(high, low, close, length, mult, first, out);
}
631
/// AVX-512 entry point for the TRADJEMA kernel.
///
/// It currently delegates to the scalar kernel. Because that kernel is
/// `#[inline(always)]`, it is compiled here with AVX-512F enabled, so the compiler
/// may use AVX-512 instructions for it.
///
/// # Safety
/// The caller must ensure the running CPU supports AVX-512F.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
unsafe fn tradjema_compute_into_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    length: usize,
    mult: f64,
    first: usize,
    out: &mut [f64],
) {
    tradjema_compute_into_scalar(high, low, close, length, mult, first, out);
}
645
/// Incremental (bar-by-bar) TRADJEMA calculator.
///
/// Feed bars with `update`. It returns `None` until a full window of `length`
/// bars has been seen, then one value per bar.
#[derive(Debug, Clone)]
pub struct TradjemaStream {
    /// Window length.
    length: usize,
    /// Multiplier applied to the normalised true range.
    mult: f64,
    /// Base smoothing factor `2 / (length + 1)`.
    alpha: f64,

    /// Index of the next bar (number of bars consumed so far).
    i: usize,
    /// Whether the first full window has been completed.
    filled: bool,

    /// Close of the previous bar (NaN before the first bar). It is both the EMA
    /// source and the reference for the true range.
    prev_close: f64,
    /// Current indicator value (NaN until filled).
    tradjema: f64,

    // Ring buffers backing two monotonic deques of (true range, bar index):
    // `min_*` keeps the rolling minimum TR at `min_head`, `max_*` the rolling
    // maximum at `max_head`. `head == tail` means the deque is empty.
    min_vals: Vec<f64>,
    min_idx: Vec<usize>,
    max_vals: Vec<f64>,
    max_idx: Vec<usize>,
    min_head: usize,
    min_tail: usize,
    max_head: usize,
    max_tail: usize,
}
667
668impl TradjemaStream {
669 pub fn try_new(params: TradjemaParams) -> Result<Self, TradjemaError> {
670 let length = params.length.unwrap_or(40);
671 let mult = params.mult.unwrap_or(10.0);
672
673 if length < 2 {
674 return Err(TradjemaError::InvalidLength {
675 length,
676 data_len: 0,
677 });
678 }
679 if mult <= 0.0 || !mult.is_finite() {
680 return Err(TradjemaError::InvalidMult { mult });
681 }
682 let cap = length;
683
684 Ok(Self {
685 length,
686 mult,
687 alpha: 2.0 / (length as f64 + 1.0),
688
689 i: 0,
690 filled: false,
691
692 prev_close: f64::NAN,
693 tradjema: f64::NAN,
694
695 min_vals: vec![0.0; cap],
696 min_idx: vec![0; cap],
697 max_vals: vec![0.0; cap],
698 max_idx: vec![0; cap],
699 min_head: 0,
700 min_tail: 0,
701 max_head: 0,
702 max_tail: 0,
703 })
704 }
705
706 #[inline(always)]
707 fn inc(i: &mut usize, cap: usize) {
708 *i += 1;
709 if *i == cap {
710 *i = 0;
711 }
712 }
713 #[inline(always)]
714 fn dec(i: usize, cap: usize) -> usize {
715 if i == 0 {
716 cap - 1
717 } else {
718 i - 1
719 }
720 }
721 #[inline(always)]
722 fn minq_push(&mut self, v: f64, idx: usize) {
723 let cap = self.length;
724 let mut back = Self::dec(self.min_tail, cap);
725
726 while self.min_tail != self.min_head && self.min_vals[back] > v {
727 self.min_tail = back;
728 back = Self::dec(self.min_tail, cap);
729 }
730 self.min_vals[self.min_tail] = v;
731 self.min_idx[self.min_tail] = idx;
732 Self::inc(&mut self.min_tail, cap);
733 }
734 #[inline(always)]
735 fn maxq_push(&mut self, v: f64, idx: usize) {
736 let cap = self.length;
737 let mut back = Self::dec(self.max_tail, cap);
738
739 while self.max_tail != self.max_head && self.max_vals[back] < v {
740 self.max_tail = back;
741 back = Self::dec(self.max_tail, cap);
742 }
743 self.max_vals[self.max_tail] = v;
744 self.max_idx[self.max_tail] = idx;
745 Self::inc(&mut self.max_tail, cap);
746 }
747 #[inline(always)]
748 fn q_expire(
749 head: &mut usize,
750 tail: &mut usize,
751 id: &mut [usize],
752 cur: usize,
753 len: usize,
754 cap: usize,
755 ) {
756 let lim = cur.saturating_sub(len);
757 while *head != *tail && id[*head] <= lim {
758 Self::inc(head, cap);
759 }
760 }
761
762 #[inline(always)]
763 fn max3(a: f64, b: f64, c: f64) -> f64 {
764 let m = if a > b { a } else { b };
765 if m > c {
766 m
767 } else {
768 c
769 }
770 }
771
772 #[inline(always)]
773 pub fn update(&mut self, high: f64, low: f64, close: f64) -> Option<f64> {
774 let tr = if self.prev_close.is_nan() {
775 high - low
776 } else {
777 let hl = high - low;
778 let hc = (high - self.prev_close).abs();
779 let lc = (low - self.prev_close).abs();
780 Self::max3(hl, hc, lc)
781 };
782
783 if !self.filled {
784 self.minq_push(tr, self.i);
785 self.maxq_push(tr, self.i);
786
787 if self.i + 1 < self.length {
788 self.prev_close = close;
789 self.i += 1;
790 return None;
791 }
792
793 let lo = self.min_vals[self.min_head];
794 let hi = self.max_vals[self.max_head];
795 let den = hi - lo;
796 let tr_adj = if den != 0.0 { (tr - lo) / den } else { 0.0 };
797 let a0 = self.alpha * (1.0 + tr_adj * self.mult);
798
799 let src = self.prev_close;
800 self.tradjema = src.mul_add(a0, 0.0);
801
802 self.prev_close = close;
803 self.filled = true;
804 self.i += 1;
805 return Some(self.tradjema);
806 }
807
808 let cap = self.length;
809 Self::q_expire(
810 &mut self.min_head,
811 &mut self.min_tail,
812 &mut self.min_idx,
813 self.i,
814 self.length,
815 cap,
816 );
817 Self::q_expire(
818 &mut self.max_head,
819 &mut self.max_tail,
820 &mut self.max_idx,
821 self.i,
822 self.length,
823 cap,
824 );
825
826 self.minq_push(tr, self.i);
827 self.maxq_push(tr, self.i);
828
829 let lo = self.min_vals[self.min_head];
830 let hi = self.max_vals[self.max_head];
831 let den = hi - lo;
832 let tr_adj = if den != 0.0 { (tr - lo) / den } else { 0.0 };
833 let a = self.alpha * (1.0 + tr_adj * self.mult);
834
835 let src = self.prev_close;
836 self.tradjema = (src - self.tradjema).mul_add(a, self.tradjema);
837
838 self.prev_close = close;
839 self.i += 1;
840
841 Some(self.tradjema)
842 }
843}
844
#[cfg(feature = "python")]
#[pyfunction(name = "tradjema")]
#[pyo3(signature = (high, low, close, length, mult, kernel=None))]
pub fn tradjema_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    length: usize,
    mult: f64,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    // Python binding for a single TRADJEMA run. The arrays must be contiguous,
    // and the computation runs with the GIL released.
    let h = high.as_slice()?;
    let l = low.as_slice()?;
    let c = close.as_slice()?;
    let same_len = h.len() == l.len() && l.len() == c.len();
    if !same_len {
        return Err(PyValueError::new_err(
            "All OHLC arrays must have the same length",
        ));
    }

    let kern = validate_kernel(kernel, false)?;
    let params = TradjemaParams {
        length: Some(length),
        mult: Some(mult),
    };
    let input = TradjemaInput::from_slices(h, l, c, params);

    let computed = py.allow_threads(|| tradjema_with_kernel(&input, kern));
    let output = computed.map_err(|e| PyValueError::new_err(e.to_string()))?;
    Ok(output.values.into_pyarray(py))
}
879
#[cfg(feature = "python")]
#[pyfunction(name = "tradjema_batch")]
#[pyo3(signature = (high, low, close, length_range, mult_range, kernel=None))]
pub fn tradjema_batch_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    length_range: (usize, usize, usize),
    mult_range: (f64, f64, f64),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    // Python binding for a (length, mult) parameter sweep. It returns a dict with
    // "values" (rows = combinations, cols = bars), "lengths" and "mults".
    let (h, l, c) = (high.as_slice()?, low.as_slice()?, close.as_slice()?);
    if h.len() != l.len() || l.len() != c.len() {
        return Err(PyValueError::new_err(
            "All OHLC arrays must have the same length",
        ));
    }

    let sweep = TradjemaBatchRange {
        length: length_range,
        mult: mult_range,
    };
    let combos = expand_grid(&sweep);
    if combos.is_empty() {
        return Err(PyValueError::new_err("Empty parameter grid"));
    }
    let rows = combos.len();
    let cols = c.len();

    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;

    // SAFETY: the array starts uninitialised. The warm-up prefix of each row is
    // written below; NOTE(review) this assumes `tradjema_batch_inner_into` writes
    // every remaining cell (or errors, in which case the array is dropped) — confirm.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let first = c
        .iter()
        .position(|v| !v.is_nan())
        .ok_or_else(|| PyValueError::new_err("All values are NaN"))?;
    for (row, prm) in combos.iter().enumerate() {
        let length = prm.length.unwrap_or(40);
        // Saturate and clamp so a zero or oversized swept length cannot underflow or
        // slice past the row. Such lengths are presumably rejected by
        // `tradjema_batch_inner_into` — confirm.
        let warm = (first + length).saturating_sub(1).min(cols);
        let start = row * cols;
        slice_out[start..start + warm].fill(f64::NAN);
    }

    // Resolve `Auto` to a concrete batch kernel, then map every batch kernel to its
    // single-series counterpart, so all branches pass the same kind of kernel.
    let kern = validate_kernel(kernel, true)?;
    let batch_kern = match kern {
        Kernel::Auto => detect_best_batch_kernel(),
        k => k,
    };
    let simd = match batch_kern {
        Kernel::Avx512Batch => Kernel::Avx512,
        Kernel::Avx2Batch => Kernel::Avx2,
        _ => Kernel::Scalar,
    };

    let combos = py
        .allow_threads(|| tradjema_batch_inner_into(h, l, c, &sweep, simd, true, slice_out))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "lengths",
        combos
            .iter()
            .map(|p| p.length.unwrap_or(40) as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "mults",
        combos
            .iter()
            .map(|p| p.mult.unwrap_or(10.0))
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    Ok(dict)
}
965
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "tradjema_cuda_batch_dev")]
#[pyo3(signature = (high_f32, low_f32, close_f32, length_range, mult_range, device_id=0))]
pub fn tradjema_cuda_batch_dev_py(
    py: Python<'_>,
    high_f32: PyReadonlyArray1<'_, f32>,
    low_f32: PyReadonlyArray1<'_, f32>,
    close_f32: PyReadonlyArray1<'_, f32>,
    length_range: (usize, usize, usize),
    mult_range: (f64, f64, f64),
    device_id: usize,
) -> PyResult<DeviceArrayF32TradjemaPy> {
    // Runs a (length, mult) sweep on the GPU for one f32 OHLC series and wraps the
    // device-resident result together with its CUDA context.
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let (high, low, close) = (
        high_f32.as_slice()?,
        low_f32.as_slice()?,
        close_f32.as_slice()?,
    );
    let same_len = high.len() == low.len() && low.len() == close.len();
    if !same_len {
        return Err(PyValueError::new_err(
            "All OHLC arrays must have the same length",
        ));
    }

    let sweep = TradjemaBatchRange {
        length: length_range,
        mult: mult_range,
    };

    // The GIL is released while the CUDA work runs.
    let launched = py.allow_threads(|| -> PyResult<_> {
        let cuda =
            CudaTradjema::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id();
        let arr = cuda
            .tradjema_batch_dev(high, low, close, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((arr, ctx, dev_id))
    });
    let (inner, ctx, dev_id) = launched?;

    Ok(DeviceArrayF32TradjemaPy::new(inner, ctx, dev_id))
}
1010
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "tradjema_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm_f32, low_tm_f32, close_tm_f32, length, mult, device_id=0))]
pub fn tradjema_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    high_tm_f32: PyReadonlyArray2<'_, f32>,
    low_tm_f32: PyReadonlyArray2<'_, f32>,
    close_tm_f32: PyReadonlyArray2<'_, f32>,
    length: usize,
    mult: f64,
    device_id: usize,
) -> PyResult<DeviceArrayF32TradjemaPy> {
    // Runs one (length, mult) pair over many f32 series laid out time-major,
    // shape (time, series): rows = bars, cols = series. The result stays on the device.
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let shape = high_tm_f32.shape();
    if shape != low_tm_f32.shape() || shape != close_tm_f32.shape() {
        return Err(PyValueError::new_err(
            "OHLC tensors must share the same shape",
        ));
    }
    if shape.len() != 2 {
        return Err(PyValueError::new_err("expected 2D arrays (time, series)"));
    }
    let rows = shape[0];
    let cols = shape[1];

    // Requires C-contiguous inputs; `as_slice` errors otherwise.
    let high = high_tm_f32.as_slice()?;
    let low = low_tm_f32.as_slice()?;
    let close = close_tm_f32.as_slice()?;

    let params = TradjemaParams {
        length: Some(length),
        mult: Some(mult),
    };

    let (inner, ctx, dev_id) = py.allow_threads(|| {
        let cuda =
            CudaTradjema::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id();
        let arr = cuda
            .tradjema_many_series_one_param_time_major_dev(high, low, close, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, pyo3::PyErr>((arr, ctx, dev_id))
    })?;

    Ok(DeviceArrayF32TradjemaPy::new(inner, ctx, dev_id))
}
1061
// Python handle to a TRADJEMA result held in CUDA device memory. It exposes
// `__cuda_array_interface__` and DLPack export. `unsendable`: Python may only use
// it from the thread that created it.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(
    module = "ta_indicators.cuda",
    name = "DeviceArrayF32Tradjema",
    unsendable
)]
pub struct DeviceArrayF32TradjemaPy {
    // Device buffer plus its (rows, cols) shape.
    pub(crate) inner: DeviceArrayF32,
    // Holds a reference to the CUDA context the buffer was allocated in.
    _ctx_guard: Arc<Context>,
    // Device ordinal reported by `__dlpack_device__`.
    _device_id: u32,
}
1073
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl DeviceArrayF32TradjemaPy {
    // Direct construction from Python is rejected; instances come only from the
    // `tradjema_cuda_*` functions.
    #[new]
    fn py_new() -> PyResult<Self> {
        Err(pyo3::exceptions::PyTypeError::new_err(
            "use factory methods from CUDA functions",
        ))
    }

    // CUDA Array Interface (version 3) for a row-major (rows, cols) float32 array.
    // `data` is (pointer, read_only=false). The pointer is 0 when the array is
    // empty, e.g. after the buffer was handed off through `__dlpack__`.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let d = PyDict::new(py);
        let itemsize = std::mem::size_of::<f32>();
        d.set_item("shape", (self.inner.rows, self.inner.cols))?;
        d.set_item("typestr", "<f4")?;
        // Byte strides: one full row of `cols` elements, then one element.
        d.set_item("strides", (self.inner.cols * itemsize, itemsize))?;
        let size = self.inner.rows.saturating_mul(self.inner.cols);
        let ptr_val: usize = if size == 0 {
            0
        } else {
            self.inner.buf.as_device_ptr().as_raw() as usize
        };
        d.set_item("data", (ptr_val, false))?;
        d.set_item("version", 3)?;
        Ok(d)
    }

    // DLPack device tuple (device_type, device_id); device type 2 is kDLCUDA.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self._device_id as i32)
    }

    // Exports the buffer as a DLPack capsule and moves ownership out of this handle.
    // Afterwards the handle holds an empty 0x0 buffer.
    // A `dl_device` that differs from ours is rejected (cross-device copy is not
    // implemented); a `dl_device` that is not an (int, int) tuple is ignored.
    // NOTE(review): `stream` is ignored here, so no stream synchronization happens
    // in this method — confirm whether the export helper handles it.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<PyObject>,
        max_version: Option<PyObject>,
        dl_device: Option<PyObject>,
        copy: Option<PyObject>,
    ) -> PyResult<PyObject> {
        use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
        use cust::memory::DeviceBuffer;

        let (kdl, alloc_dev) = self.__dlpack_device__();
        if let Some(dev_obj) = dl_device.as_ref() {
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(PyValueError::new_err(
                            "device copy not implemented for __dlpack__",
                        ));
                    } else {
                        return Err(PyValueError::new_err("dl_device mismatch for __dlpack__"));
                    }
                }
            }
        }
        let _ = stream;

        // Swap in an empty buffer so the real one can be moved into the capsule
        // while `self` stays valid.
        let dummy =
            DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let inner = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32 {
                buf: dummy,
                rows: 0,
                cols: 0,
            },
        );

        let rows = inner.rows;
        let cols = inner.cols;
        let buf = inner.buf;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
1158
#[cfg(all(feature = "python", feature = "cuda"))]
impl DeviceArrayF32TradjemaPy {
    /// Wraps a device array together with the CUDA context it belongs to and the
    /// device ordinal to report through DLPack.
    pub fn new(inner: DeviceArrayF32, ctx_guard: Arc<Context>, device_id: u32) -> Self {
        Self {
            _device_id: device_id,
            _ctx_guard: ctx_guard,
            inner,
        }
    }
}
1169
// Python wrapper around the bar-by-bar `TradjemaStream` calculator.
#[cfg(feature = "python")]
#[pyclass(name = "TradjemaStream")]
pub struct TradjemaStreamPy {
    // The wrapped Rust streaming state.
    inner: TradjemaStream,
}
1175
#[cfg(feature = "python")]
#[pymethods]
impl TradjemaStreamPy {
    // Builds a stream with explicit length and multiplier; parameter errors become ValueError.
    #[new]
    fn new(length: usize, mult: f64) -> PyResult<Self> {
        let params = TradjemaParams {
            length: Some(length),
            mult: Some(mult),
        };
        match TradjemaStream::try_new(params) {
            Ok(inner) => Ok(Self { inner }),
            Err(e) => Err(PyValueError::new_err(e.to_string())),
        }
    }

    // Feeds one bar. Returns None during warm-up, then the current value.
    fn update(&mut self, high: f64, low: f64, close: f64) -> Option<f64> {
        let stream = &mut self.inner;
        stream.update(high, low, close)
    }
}
1192
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn tradjema_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    length: usize,
    mult: f64,
) -> Result<(), JsValue> {
    // Raw-pointer WASM entry point. It computes TRADJEMA over `len` bars into
    // `out_ptr`. If the output buffer is one of the input buffers, the result goes
    // through a scratch vector so the inputs are not overwritten mid-computation.
    let any_null =
        high_ptr.is_null() || low_ptr.is_null() || close_ptr.is_null() || out_ptr.is_null();
    if any_null {
        return Err(JsValue::from_str("null pointer"));
    }

    let to_js = |e: TradjemaError| JsValue::from_str(&e.to_string());
    let params = TradjemaParams {
        length: Some(length),
        mult: Some(mult),
    };

    // SAFETY: the JS caller guarantees each pointer refers to `len` valid f64 values.
    let (h, l, c) = unsafe {
        (
            std::slice::from_raw_parts(high_ptr, len),
            std::slice::from_raw_parts(low_ptr, len),
            std::slice::from_raw_parts(close_ptr, len),
        )
    };
    let input = TradjemaInput::from_slices(h, l, c, params);

    let out_const = out_ptr as *const f64;
    let aliases_input = out_const == high_ptr || out_const == low_ptr || out_const == close_ptr;
    if aliases_input {
        let mut scratch = vec![f64::NAN; len];
        tradjema_into_slice(&mut scratch, &input, Kernel::Auto).map_err(to_js)?;
        // SAFETY: `out_ptr` is non-null and refers to `len` writable f64 values.
        unsafe { std::slice::from_raw_parts_mut(out_ptr, len) }.copy_from_slice(&scratch);
    } else {
        // SAFETY: as above; `out_ptr` does not coincide with any input pointer.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        tradjema_into_slice(out, &input, Kernel::Auto).map_err(to_js)?;
    }
    Ok(())
}
1237
/// Batch TRADJEMA over a `length × mult` parameter grid, written into
/// caller-owned WASM memory.
///
/// `high_ptr`, `low_ptr` and `close_ptr` must each point to `len` `f64`
/// values, and `out_ptr` to `rows * len` values, where `rows` is the number of
/// parameter combinations in the sweep. Row `r` of the output occupies
/// `out[r * len .. (r + 1) * len]`. Returns `rows` on success.
///
/// `out_ptr` may overlap the inputs; in that case the grid is computed into a
/// temporary buffer and copied out at the end.
///
/// # Errors
/// Returns a JS string error for null pointers, an empty parameter grid, a
/// `rows * len` size overflow, any combination with `length < 2`, an all-NaN
/// `close` series, or a failure reported by the batch kernel.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn tradjema_batch_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    len: usize,
    length_start: usize,
    length_end: usize,
    length_step: usize,
    mult_start: f64,
    mult_end: f64,
    mult_step: f64,
    out_ptr: *mut f64,
) -> Result<usize, JsValue> {
    if [high_ptr, low_ptr, close_ptr].iter().any(|p| p.is_null()) || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer passed"));
    }

    // SAFETY: the input pointers are non-null and, per the contract above,
    // each references `len` initialized `f64` values owned by the caller.
    let (high, low, close) = unsafe {
        (
            std::slice::from_raw_parts(high_ptr, len),
            std::slice::from_raw_parts(low_ptr, len),
            std::slice::from_raw_parts(close_ptr, len),
        )
    };

    let sweep = TradjemaBatchRange {
        length: (length_start, length_end, length_step),
        mult: (mult_start, mult_end, mult_step),
    };
    let combos = expand_grid(&sweep);
    if combos.is_empty() {
        return Err(JsValue::from_str("Empty parameter grid"));
    }
    let rows = combos.len();
    let cols = len;
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("rows * cols overflow"))?;

    // `length < 2` has no valid warm-up (`first + length - 1` can underflow)
    // and such rows are not computed by the kernel, so reject them up front
    // instead of leaving garbage in the caller's buffer.
    if combos.iter().any(|p| p.length.unwrap_or(40) < 2) {
        return Err(JsValue::from_str("Invalid length"));
    }

    let first = close
        .iter()
        .position(|v| !v.is_nan())
        .ok_or_else(|| JsValue::from_str("All values are NaN"))?;

    // Byte-range overlap test between each input buffer and the output.
    let f64_size = std::mem::size_of::<f64>();
    let in_span = cols.saturating_mul(f64_size);
    let out_start = out_ptr as usize;
    let out_end = out_start.saturating_add(total.saturating_mul(f64_size));
    let overlaps_out = |p: *const f64| {
        let start = p as usize;
        start < out_end && out_start < start.saturating_add(in_span)
    };
    let aliased = overlaps_out(high_ptr) || overlaps_out(low_ptr) || overlaps_out(close_ptr);

    let simd = match detect_best_batch_kernel() {
        Kernel::Avx512Batch => Kernel::Avx512,
        Kernel::Avx2Batch => Kernel::Avx2,
        _ => Kernel::Scalar,
    };

    if aliased {
        // The scratch buffer starts all-NaN, so warm-up prefixes need no fill.
        let mut tmp = vec![f64::NAN; total];
        tradjema_batch_inner_into(high, low, close, &sweep, simd, false, &mut tmp)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // SAFETY: `out_ptr` is non-null and valid for `total` writes; the
        // input slices are not used past this point.
        unsafe { std::slice::from_raw_parts_mut(out_ptr, total) }.copy_from_slice(&tmp);
    } else {
        // SAFETY: `out_ptr` is non-null, valid for `total` writes, and does
        // not overlap any input buffer (checked above).
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };
        // NaN-fill each row's warm-up prefix, clamped to the row width so a
        // window longer than the series yields an all-NaN row instead of a
        // slicing panic. `cols > 0` here because `first` was found.
        for (row, prm) in out.chunks_mut(cols).zip(&combos) {
            let warm = first
                .saturating_add(prm.length.unwrap_or(40) - 1)
                .min(cols);
            row[..warm].fill(f64::NAN);
        }
        tradjema_batch_inner_into(high, low, close, &sweep, simd, false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(rows)
}
1299
/// Reserves uninitialized room for `len` `f64` values in WASM linear memory
/// and returns a pointer to it. The allocation is deliberately leaked; pass
/// the same pointer and `len` to `tradjema_free` to release it.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn tradjema_alloc(len: usize) -> *mut f64 {
    let mut storage = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    storage.as_mut_ptr()
}
1308
/// Releases a buffer previously returned by `tradjema_alloc(len)`.
///
/// A null `ptr` is ignored. The allocation is rebuilt as an *empty* `Vec`
/// with capacity `len`: `Vec::from_raw_parts` requires its first `length`
/// elements to be initialized, and the caller may never have written to the
/// buffer. Claiming zero live elements keeps this sound while still freeing
/// the same allocation (`f64` has no destructor, so nothing is skipped).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn tradjema_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` came from `tradjema_alloc(len)`, i.e. a leaked
    // `Vec::<f64>::with_capacity(len)`, so the capacity matches and the
    // allocation is owned by us from here on.
    unsafe {
        drop(Vec::<f64>::from_raw_parts(ptr, 0, len));
    }
}
1316
/// Computes TRADJEMA for JS-provided slices and returns a newly allocated
/// vector with one entry per `close` value.
///
/// # Errors
/// Returns a JS string error when the input is empty, the three slices differ
/// in length, `length` is outside `2..=close.len()`, `mult` is not a finite
/// positive number, `close` is entirely NaN, or fewer than `length` values
/// remain after the first non-NaN `close`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn tradjema_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    length: usize,
    mult: f64,
) -> Result<Vec<f64>, JsValue> {
    let n = close.len();
    if n == 0 {
        return Err(JsValue::from_str("Input data slice is empty"));
    }
    if !(high.len() == low.len() && low.len() == n) {
        return Err(JsValue::from_str("length mismatch"));
    }
    if !(2..=n).contains(&length) {
        return Err(JsValue::from_str("Invalid length"));
    }
    if !(mult.is_finite() && mult > 0.0) {
        return Err(JsValue::from_str("Invalid mult"));
    }

    let first = match close.iter().position(|x| !x.is_nan()) {
        Some(idx) => idx,
        None => return Err(JsValue::from_str("All values are NaN")),
    };
    let usable = n - first;
    if usable < length {
        return Err(JsValue::from_str("Not enough valid data"));
    }

    // NaN-prefixed output up to `warmup`; `tradjema_compute_into` then writes
    // the computed values using the scalar kernel.
    let warmup = first + length - 1;
    let mut values = alloc_with_nan_prefix(n, warmup);
    tradjema_compute_into(
        high,
        low,
        close,
        length,
        mult,
        first,
        Kernel::Scalar,
        &mut values,
    );
    Ok(values)
}
1363
/// Sweep configuration accepted by the JS `tradjema_batch` entry point,
/// deserialized from a JS object via `serde_wasm_bindgen`. The field names
/// below are the object keys JS callers must use.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct TradjemaBatchConfig {
    /// `(start, end, step)` for the `length` axis of the sweep.
    pub length_range: (usize, usize, usize),
    /// `(start, end, step)` for the `mult` axis of the sweep.
    pub mult_range: (f64, f64, f64),
}
1370
/// Result object returned to JS by `tradjema_batch`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct TradjemaBatchJsOutput {
    /// Flattened row-major `rows × cols` matrix of indicator values.
    pub values: Vec<f64>,
    /// Parameter combination used for each row, in row order.
    pub combos: Vec<TradjemaParams>,
    /// Number of rows (parameter combinations).
    pub rows: usize,
    /// Number of columns (input length).
    pub cols: usize,
}
1379
/// JS entry point `tradjema_batch`: evaluates a TRADJEMA parameter sweep and
/// returns `{ values, combos, rows, cols }` as a JS object.
///
/// # Errors
/// Returns a JS string error when `config` cannot be deserialized, any input
/// array is empty, the batch computation fails, or the result cannot be
/// serialized.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "tradjema_batch")]
pub fn tradjema_batch_unified_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let parsed: TradjemaBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {e}")))?;
    let TradjemaBatchConfig {
        length_range,
        mult_range,
    } = parsed;
    let sweep = TradjemaBatchRange {
        length: length_range,
        mult: mult_range,
    };

    let any_empty = [high, low, close].iter().any(|series| series.is_empty());
    if any_empty {
        return Err(JsValue::from_str("Input arrays are empty"));
    }

    let TradjemaBatchOutput {
        values,
        combos,
        rows,
        cols,
    } = tradjema_batch_with_kernel(high, low, close, &sweep, Kernel::Auto)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let payload = TradjemaBatchJsOutput {
        values,
        combos,
        rows,
        cols,
    };
    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {e}")))
}
1411
/// Parameter sweep for the batch API. Each axis is `(start, end, step)` and is
/// expanded by `expand_grid` into the `TradjemaParams` combinations to run.
#[derive(Clone, Debug)]
pub struct TradjemaBatchRange {
    /// `(start, end, step)` for the EMA `length`.
    pub length: (usize, usize, usize),
    /// `(start, end, step)` for the true-range multiplier `mult`.
    pub mult: (f64, f64, f64),
}
1417
1418impl Default for TradjemaBatchRange {
1419 fn default() -> Self {
1420 Self {
1421 length: (40, 289, 1),
1422 mult: (10.0, 10.0, 0.0),
1423 }
1424 }
1425}
1426
/// Fluent builder for TRADJEMA batch runs: holds a parameter sweep and a
/// kernel choice, then evaluates them on slices or candles.
#[derive(Clone, Debug, Default)]
pub struct TradjemaBatchBuilder {
    // Sweep to evaluate; `TradjemaBatchRange::default()` unless overridden.
    range: TradjemaBatchRange,
    // Kernel forwarded to `tradjema_batch_with_kernel`; `Kernel::default()`
    // unless overridden.
    kernel: Kernel,
}
1432
1433impl TradjemaBatchBuilder {
1434 pub fn new() -> Self {
1435 Self::default()
1436 }
1437
1438 pub fn with_default_candles(c: &Candles) -> Result<TradjemaBatchOutput, TradjemaError> {
1439 TradjemaBatchBuilder::new()
1440 .kernel(Kernel::Auto)
1441 .apply_candles(c)
1442 }
1443
1444 pub fn kernel(mut self, k: Kernel) -> Self {
1445 self.kernel = k;
1446 self
1447 }
1448
1449 #[inline]
1450 pub fn length_range(mut self, start: usize, end: usize, step: usize) -> Self {
1451 self.range.length = (start, end, step);
1452 self
1453 }
1454
1455 #[inline]
1456 pub fn mult_range(mut self, start: f64, end: f64, step: f64) -> Self {
1457 self.range.mult = (start, end, step);
1458 self
1459 }
1460
1461 pub fn apply_slices(
1462 self,
1463 high: &[f64],
1464 low: &[f64],
1465 close: &[f64],
1466 ) -> Result<TradjemaBatchOutput, TradjemaError> {
1467 tradjema_batch_with_kernel(high, low, close, &self.range, self.kernel)
1468 }
1469
1470 pub fn apply_candles(self, c: &Candles) -> Result<TradjemaBatchOutput, TradjemaError> {
1471 let high = c
1472 .select_candle_field("high")
1473 .map_err(|_| TradjemaError::EmptyInputData)?;
1474 let low = c
1475 .select_candle_field("low")
1476 .map_err(|_| TradjemaError::EmptyInputData)?;
1477 let close = c
1478 .select_candle_field("close")
1479 .map_err(|_| TradjemaError::EmptyInputData)?;
1480 self.apply_slices(high, low, close)
1481 }
1482}
1483
/// Result of a TRADJEMA batch run.
///
/// `values` is a row-major `rows × cols` matrix: row `r` holds the series
/// computed with `combos[r]`, and `cols` equals the input length.
#[derive(Clone, Debug)]
pub struct TradjemaBatchOutput {
    /// Flattened row-major output matrix.
    pub values: Vec<f64>,
    /// Parameter combination for each row, in row order.
    pub combos: Vec<TradjemaParams>,
    /// Number of rows (`combos.len()`).
    pub rows: usize,
    /// Number of columns (input length).
    pub cols: usize,
}
1491
1492impl TradjemaBatchOutput {
1493 pub fn row_for_params(&self, p: &TradjemaParams) -> Option<usize> {
1494 self.combos.iter().position(|c| {
1495 c.length.unwrap_or(40) == p.length.unwrap_or(40)
1496 && (c.mult.unwrap_or(10.0) - p.mult.unwrap_or(10.0)).abs() < 1e-9
1497 })
1498 }
1499
1500 pub fn values_for(&self, p: &TradjemaParams) -> Option<&[f64]> {
1501 self.row_for_params(p).map(|row| {
1502 let start = row * self.cols;
1503 &self.values[start..start + self.cols]
1504 })
1505 }
1506}
1507
1508#[inline(always)]
1509fn expand_grid(r: &TradjemaBatchRange) -> Vec<TradjemaParams> {
1510 let (ls, le, lstep) = r.length;
1511 let (ms, me, mstep) = r.mult;
1512
1513 #[inline]
1514 fn axis_usize(start: usize, end: usize, step: usize) -> Vec<usize> {
1515 if step == 0 {
1516 return vec![start];
1517 }
1518 let mut vals = Vec::new();
1519 if start <= end {
1520 let mut v = start;
1521 while v <= end {
1522 vals.push(v);
1523 match v.checked_add(step) {
1524 Some(n) if n > v => v = n,
1525 _ => break,
1526 }
1527 }
1528 } else {
1529 let mut v = start;
1530 loop {
1531 vals.push(v);
1532 if v <= end {
1533 break;
1534 }
1535 v = v.saturating_sub(step);
1536 if v < end {
1537 break;
1538 }
1539 }
1540 }
1541 vals
1542 }
1543
1544 #[inline]
1545 fn axis_f64(start: f64, end: f64, step: f64) -> Vec<f64> {
1546 if step == 0.0 {
1547 return vec![start];
1548 }
1549 let mut vals = Vec::new();
1550 if start <= end {
1551 let mut v = start;
1552 while v <= end {
1553 vals.push(v);
1554 v += step;
1555
1556 if !v.is_finite() {
1557 break;
1558 }
1559 if step.is_sign_negative() {
1560 break;
1561 }
1562 }
1563 } else {
1564 let mut v = start;
1565 while v >= end {
1566 vals.push(v);
1567 v -= step.abs();
1568 if !v.is_finite() {
1569 break;
1570 }
1571 if step == 0.0 {
1572 break;
1573 }
1574 }
1575 }
1576 vals
1577 }
1578
1579 let lengths = axis_usize(ls, le, lstep);
1580 let mults = axis_f64(ms, me, mstep);
1581 if lengths.is_empty() || mults.is_empty() {
1582 return Vec::new();
1583 }
1584 let mut combos = Vec::with_capacity(lengths.len().saturating_mul(mults.len()));
1585 for &l in &lengths {
1586 for &m in &mults {
1587 combos.push(TradjemaParams {
1588 length: Some(l),
1589 mult: Some(m),
1590 });
1591 }
1592 }
1593 combos
1594}
1595
1596pub fn tradjema_batch_with_kernel(
1597 high: &[f64],
1598 low: &[f64],
1599 close: &[f64],
1600 sweep: &TradjemaBatchRange,
1601 k: Kernel,
1602) -> Result<TradjemaBatchOutput, TradjemaError> {
1603 let kernel = match k {
1604 Kernel::Auto => detect_best_batch_kernel(),
1605 other if other.is_batch() => other,
1606 Kernel::Scalar | Kernel::Avx2 | Kernel::Avx512 => {
1607 return Err(TradjemaError::InvalidKernelForBatch(k));
1608 }
1609 _ => detect_best_batch_kernel(),
1610 };
1611
1612 let simd = match kernel {
1613 Kernel::Avx512Batch => Kernel::Avx512,
1614 Kernel::Avx2Batch => Kernel::Avx2,
1615 Kernel::ScalarBatch => Kernel::Scalar,
1616 _ => Kernel::Scalar,
1617 };
1618
1619 let combos = expand_grid(sweep);
1620 if combos.is_empty() {
1621 let (ls, le, lstep) = sweep.length;
1622 let (ms, me, mstep) = sweep.mult;
1623
1624 let length_empty =
1625 (lstep == 0 && ls != le) || (lstep > 0 && ls > le && ls.saturating_sub(le) < lstep);
1626 if length_empty {
1627 return Err(TradjemaError::InvalidLengthRange {
1628 start: ls,
1629 end: le,
1630 step: lstep,
1631 });
1632 } else {
1633 return Err(TradjemaError::InvalidMultRange {
1634 start: ms,
1635 end: me,
1636 step: mstep,
1637 });
1638 }
1639 }
1640 let rows = combos.len();
1641 let cols = close.len();
1642 rows.checked_mul(cols)
1643 .ok_or(TradjemaError::InvalidLengthRange {
1644 start: sweep.length.0,
1645 end: sweep.length.1,
1646 step: sweep.length.2,
1647 })?;
1648
1649 tradjema_batch_inner(high, low, close, sweep, simd, true)
1650}
1651
1652#[inline(always)]
1653fn tradjema_batch_inner_into(
1654 high: &[f64],
1655 low: &[f64],
1656 close: &[f64],
1657 sweep: &TradjemaBatchRange,
1658 kern: Kernel,
1659 parallel: bool,
1660 out: &mut [f64],
1661) -> Result<Vec<TradjemaParams>, TradjemaError> {
1662 let combos = expand_grid(sweep);
1663 if combos.is_empty() {
1664 let (ls, le, lstep) = sweep.length;
1665 let (ms, me, mstep) = sweep.mult;
1666 let length_empty =
1667 (lstep == 0 && ls != le) || (lstep > 0 && ls > le && ls.saturating_sub(le) < lstep);
1668 if length_empty {
1669 return Err(TradjemaError::InvalidLengthRange {
1670 start: ls,
1671 end: le,
1672 step: lstep,
1673 });
1674 } else {
1675 return Err(TradjemaError::InvalidMultRange {
1676 start: ms,
1677 end: me,
1678 step: mstep,
1679 });
1680 }
1681 }
1682 if high.len() != low.len() || low.len() != close.len() {
1683 return Err(TradjemaError::MissingData);
1684 }
1685
1686 let cols = close.len();
1687 let first = close
1688 .iter()
1689 .position(|v| !v.is_nan())
1690 .ok_or(TradjemaError::AllValuesNaN)?;
1691
1692 #[inline(always)]
1693 fn precompute_tr(high: &[f64], low: &[f64], close: &[f64], first: usize) -> Vec<f64> {
1694 let n = close.len();
1695 let mut tr = vec![0.0f64; n];
1696 if first < n {
1697 tr[first] = high[first] - low[first];
1698 let mut i = first + 1;
1699 while i < n {
1700 let hl = high[i] - low[i];
1701 let hc = (high[i] - close[i - 1]).abs();
1702 let lc = (low[i] - close[i - 1]).abs();
1703 tr[i] = hl.max(hc).max(lc);
1704 i += 1;
1705 }
1706 }
1707 tr
1708 }
1709
1710 #[inline(always)]
1711 fn compute_from_tr_into(
1712 tr: &[f64],
1713 close: &[f64],
1714 length: usize,
1715 mult: f64,
1716 first: usize,
1717 out: &mut [f64],
1718 ) {
1719 debug_assert_eq!(tr.len(), close.len());
1720 debug_assert_eq!(close.len(), out.len());
1721
1722 let warm = first + length - 1;
1723 if warm >= out.len() {
1724 return;
1725 }
1726 let alpha = 2.0 / (length as f64 + 1.0);
1727
1728 let cap = length;
1729 let mut min_vals = vec![0.0f64; cap];
1730 let mut min_idx = vec![0usize; cap];
1731 let mut max_vals = vec![0.0f64; cap];
1732 let mut max_idx = vec![0usize; cap];
1733 let (mut min_head, mut min_tail) = (0usize, 0usize);
1734 let (mut max_head, mut max_tail) = (0usize, 0usize);
1735 #[inline(always)]
1736 fn inc(i: &mut usize, cap: usize) {
1737 *i += 1;
1738 if *i == cap {
1739 *i = 0;
1740 }
1741 }
1742 #[inline(always)]
1743 fn dec(i: usize, cap: usize) -> usize {
1744 if i == 0 {
1745 cap - 1
1746 } else {
1747 i - 1
1748 }
1749 }
1750 #[inline(always)]
1751 fn minq_push(
1752 v: f64,
1753 idx: usize,
1754 vals: &mut [f64],
1755 id: &mut [usize],
1756 head: &mut usize,
1757 tail: &mut usize,
1758 cap: usize,
1759 ) {
1760 let mut back = dec(*tail, cap);
1761 while *tail != *head && vals[back] > v {
1762 *tail = back;
1763 back = dec(*tail, cap);
1764 }
1765 vals[*tail] = v;
1766 id[*tail] = idx;
1767 inc(tail, cap);
1768 }
1769 #[inline(always)]
1770 fn maxq_push(
1771 v: f64,
1772 idx: usize,
1773 vals: &mut [f64],
1774 id: &mut [usize],
1775 head: &mut usize,
1776 tail: &mut usize,
1777 cap: usize,
1778 ) {
1779 let mut back = dec(*tail, cap);
1780 while *tail != *head && vals[back] < v {
1781 *tail = back;
1782 back = dec(*tail, cap);
1783 }
1784 vals[*tail] = v;
1785 id[*tail] = idx;
1786 inc(tail, cap);
1787 }
1788 #[inline(always)]
1789 fn q_expire(
1790 cur: usize,
1791 len: usize,
1792 id: &mut [usize],
1793 head: &mut usize,
1794 tail: &mut usize,
1795 cap: usize,
1796 ) {
1797 let lim = cur.saturating_sub(len);
1798 while *head != *tail && id[*head] <= lim {
1799 inc(head, cap);
1800 }
1801 }
1802
1803 let mut i = first;
1804 while i <= warm {
1805 let v = tr[i];
1806 minq_push(
1807 v,
1808 i,
1809 &mut min_vals,
1810 &mut min_idx,
1811 &mut min_head,
1812 &mut min_tail,
1813 cap,
1814 );
1815 maxq_push(
1816 v,
1817 i,
1818 &mut max_vals,
1819 &mut max_idx,
1820 &mut max_head,
1821 &mut max_tail,
1822 cap,
1823 );
1824 i += 1;
1825 }
1826 let lo = min_vals[min_head];
1827 let hi = max_vals[max_head];
1828 let den = hi - lo;
1829 let v = tr[warm];
1830 let tr_adj0 = if den != 0.0 { (v - lo) / den } else { 0.0 };
1831 let a0 = alpha * (1.0 + tr_adj0 * mult);
1832 let mut y = a0 * close[warm - 1];
1833 out[warm] = y;
1834
1835 i = warm + 1;
1836 while i < out.len() {
1837 q_expire(i, length, &mut min_idx, &mut min_head, &mut min_tail, cap);
1838 q_expire(i, length, &mut max_idx, &mut max_head, &mut max_tail, cap);
1839
1840 let v = tr[i];
1841 minq_push(
1842 v,
1843 i,
1844 &mut min_vals,
1845 &mut min_idx,
1846 &mut min_head,
1847 &mut min_tail,
1848 cap,
1849 );
1850 maxq_push(
1851 v,
1852 i,
1853 &mut max_vals,
1854 &mut max_idx,
1855 &mut max_head,
1856 &mut max_tail,
1857 cap,
1858 );
1859
1860 let lo = min_vals[min_head];
1861 let hi = max_vals[max_head];
1862 let den = hi - lo;
1863 let tr_adj = if den != 0.0 { (v - lo) / den } else { 0.0 };
1864 let a = alpha * (1.0 + tr_adj * mult);
1865 let src = close[i - 1];
1866 y += a * (src - y);
1867 out[i] = y;
1868
1869 i += 1;
1870 }
1871 }
1872
1873 let pre_tr = precompute_tr(high, low, close, first);
1874 let do_row = |row: usize, dst: &mut [f64]| {
1875 let p = &combos[row];
1876 let length = p.length.unwrap_or(40);
1877 let mult = p.mult.unwrap_or(10.0);
1878 if length < 2 {
1879 return;
1880 }
1881
1882 let _ = kern;
1883 compute_from_tr_into(&pre_tr, close, length, mult, first, dst);
1884 };
1885
1886 if parallel {
1887 #[cfg(not(target_arch = "wasm32"))]
1888 out.par_chunks_mut(cols)
1889 .enumerate()
1890 .for_each(|(row, slice)| do_row(row, slice));
1891 #[cfg(target_arch = "wasm32")]
1892 for (row, slice) in out.chunks_mut(cols).enumerate() {
1893 do_row(row, slice);
1894 }
1895 } else {
1896 for (row, slice) in out.chunks_mut(cols).enumerate() {
1897 do_row(row, slice);
1898 }
1899 }
1900
1901 Ok(combos)
1902}
1903
1904fn tradjema_batch_inner(
1905 high: &[f64],
1906 low: &[f64],
1907 close: &[f64],
1908 sweep: &TradjemaBatchRange,
1909 kern: Kernel,
1910 parallel: bool,
1911) -> Result<TradjemaBatchOutput, TradjemaError> {
1912 let combos = expand_grid(sweep);
1913 if combos.is_empty() {
1914 return Err(TradjemaError::InvalidLength {
1915 length: 0,
1916 data_len: 0,
1917 });
1918 }
1919 let rows = combos.len();
1920 let cols = close.len();
1921
1922 let mut buf_mu = make_uninit_matrix(rows, cols);
1923 let first = close
1924 .iter()
1925 .position(|v| !v.is_nan())
1926 .ok_or(TradjemaError::AllValuesNaN)?;
1927 let warms: Vec<usize> = combos
1928 .iter()
1929 .map(|p| first + p.length.unwrap_or(40) - 1)
1930 .collect();
1931 init_matrix_prefixes(&mut buf_mu, cols, &warms);
1932
1933 let mut guard = core::mem::ManuallyDrop::new(buf_mu);
1934 let out: &mut [f64] =
1935 unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len()) };
1936
1937 let combos = tradjema_batch_inner_into(high, low, close, sweep, kern, parallel, out)?;
1938
1939 let values = unsafe {
1940 Vec::from_raw_parts(
1941 guard.as_mut_ptr() as *mut f64,
1942 guard.len(),
1943 guard.capacity(),
1944 )
1945 };
1946
1947 Ok(TradjemaBatchOutput {
1948 values,
1949 combos,
1950 rows,
1951 cols,
1952 })
1953}
1954
1955#[cfg(test)]
1956mod tests {
1957 use super::*;
1958 use crate::skip_if_unsupported;
1959 use crate::utilities::data_loader::read_candles_from_csv;
1960 #[cfg(feature = "proptest")]
1961 use proptest::prelude::*;
1962 use std::error::Error;
1963
1964 fn check_tradjema_partial_params(
1965 test_name: &str,
1966 kernel: Kernel,
1967 ) -> Result<(), Box<dyn Error>> {
1968 skip_if_unsupported!(kernel, test_name);
1969 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1970 let candles = read_candles_from_csv(file_path)?;
1971
1972 let high = candles.select_candle_field("high")?;
1973 let low = candles.select_candle_field("low")?;
1974 let close = candles.select_candle_field("close")?;
1975
1976 let default_params = TradjemaParams {
1977 length: None,
1978 mult: None,
1979 };
1980 let input = TradjemaInput::from_slices(high, low, close, default_params);
1981 let output = tradjema_with_kernel(&input, kernel)?;
1982 assert_eq!(output.values.len(), candles.close.len());
1983
1984 Ok(())
1985 }
1986
1987 fn check_tradjema_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1988 skip_if_unsupported!(kernel, test_name);
1989 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1990 let candles = read_candles_from_csv(file_path)?;
1991
1992 let high = candles.select_candle_field("high")?;
1993 let low = candles.select_candle_field("low")?;
1994 let close = candles.select_candle_field("close")?;
1995
1996 let input = TradjemaInput::from_slices(high, low, close, TradjemaParams::default());
1997 let result = tradjema_with_kernel(&input, kernel)?;
1998
1999 assert_eq!(result.values.len(), candles.close.len());
2000
2001 let warmup = 39;
2002 for i in 0..warmup {
2003 assert!(
2004 result.values[i].is_nan(),
2005 "[{}] Expected NaN during warmup at index {}",
2006 test_name,
2007 i
2008 );
2009 }
2010
2011 for i in warmup..result.values.len() {
2012 assert!(
2013 !result.values[i].is_nan(),
2014 "[{}] Expected valid value after warmup at index {}",
2015 test_name,
2016 i
2017 );
2018 }
2019
2020 let expected_last_five = [
2021 59395.39322263,
2022 59388.09683228,
2023 59373.08371503,
2024 59350.75110897,
2025 59323.14225348,
2026 ];
2027
2028 let start = result.values.len().saturating_sub(5);
2029 for (i, &val) in result.values[start..].iter().enumerate() {
2030 let diff = (val - expected_last_five[i]).abs();
2031 assert!(
2032 diff < 1e-8,
2033 "[{}] TRADJEMA accuracy mismatch at last_5[{}]: got {:.8}, expected {:.8}, diff={:.10}",
2034 test_name,
2035 i,
2036 val,
2037 expected_last_five[i],
2038 diff
2039 );
2040 }
2041
2042 Ok(())
2043 }
2044
2045 fn check_tradjema_default_candles(
2046 test_name: &str,
2047 kernel: Kernel,
2048 ) -> Result<(), Box<dyn Error>> {
2049 skip_if_unsupported!(kernel, test_name);
2050 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2051 let candles = read_candles_from_csv(file_path)?;
2052
2053 let input = TradjemaInput::with_default_candles(&candles);
2054 let output = tradjema_with_kernel(&input, kernel)?;
2055 assert_eq!(output.values.len(), candles.close.len());
2056
2057 Ok(())
2058 }
2059
2060 fn check_tradjema_zero_length(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2061 skip_if_unsupported!(kernel, test_name);
2062 let input_data = vec![10.0, 20.0, 30.0];
2063
2064 let params = TradjemaParams {
2065 length: Some(0),
2066 mult: None,
2067 };
2068 let input = TradjemaInput::from_slices(&input_data, &input_data, &input_data, params);
2069 let res = tradjema_with_kernel(&input, kernel);
2070 assert!(
2071 res.is_err(),
2072 "[{}] TRADJEMA should fail with zero length",
2073 test_name
2074 );
2075
2076 let params = TradjemaParams {
2077 length: Some(1),
2078 mult: None,
2079 };
2080 let input = TradjemaInput::from_slices(&input_data, &input_data, &input_data, params);
2081 let res = tradjema_with_kernel(&input, kernel);
2082 assert!(
2083 res.is_err(),
2084 "[{}] TRADJEMA should fail with length=1 (minimum is 2)",
2085 test_name
2086 );
2087
2088 Ok(())
2089 }
2090
2091 fn check_tradjema_length_exceeds_data(
2092 test_name: &str,
2093 kernel: Kernel,
2094 ) -> Result<(), Box<dyn Error>> {
2095 skip_if_unsupported!(kernel, test_name);
2096 let data_small = vec![10.0, 20.0, 30.0];
2097 let params = TradjemaParams {
2098 length: Some(10),
2099 mult: None,
2100 };
2101 let input = TradjemaInput::from_slices(&data_small, &data_small, &data_small, params);
2102 let res = tradjema_with_kernel(&input, kernel);
2103 assert!(
2104 res.is_err(),
2105 "[{}] TRADJEMA should fail with length exceeding data",
2106 test_name
2107 );
2108 Ok(())
2109 }
2110
2111 fn check_tradjema_very_small_dataset(
2112 test_name: &str,
2113 kernel: Kernel,
2114 ) -> Result<(), Box<dyn Error>> {
2115 skip_if_unsupported!(kernel, test_name);
2116 let single_point = vec![42.0];
2117 let params = TradjemaParams {
2118 length: Some(40),
2119 mult: None,
2120 };
2121 let input = TradjemaInput::from_slices(&single_point, &single_point, &single_point, params);
2122 let res = tradjema_with_kernel(&input, kernel);
2123 assert!(
2124 res.is_err(),
2125 "[{}] TRADJEMA should fail with insufficient data",
2126 test_name
2127 );
2128 Ok(())
2129 }
2130
2131 fn check_tradjema_empty_input(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2132 skip_if_unsupported!(kernel, test_name);
2133 let empty: Vec<f64> = vec![];
2134 let input = TradjemaInput::from_slices(&empty, &empty, &empty, TradjemaParams::default());
2135 let res = tradjema_with_kernel(&input, kernel);
2136 assert!(
2137 matches!(res, Err(TradjemaError::EmptyInputData)),
2138 "[{}] TRADJEMA should fail with empty input",
2139 test_name
2140 );
2141 Ok(())
2142 }
2143
2144 fn check_tradjema_invalid_mult(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2145 skip_if_unsupported!(kernel, test_name);
2146 let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
2147
2148 let params = TradjemaParams {
2149 length: Some(2),
2150 mult: Some(-10.0),
2151 };
2152 let input = TradjemaInput::from_slices(&data, &data, &data, params);
2153 let res = tradjema_with_kernel(&input, kernel);
2154 assert!(
2155 matches!(res, Err(TradjemaError::InvalidMult { .. })),
2156 "[{}] TRADJEMA should fail with negative mult",
2157 test_name
2158 );
2159
2160 let params = TradjemaParams {
2161 length: Some(2),
2162 mult: Some(f64::NAN),
2163 };
2164 let input = TradjemaInput::from_slices(&data, &data, &data, params);
2165 let res = tradjema_with_kernel(&input, kernel);
2166 assert!(
2167 matches!(res, Err(TradjemaError::InvalidMult { .. })),
2168 "[{}] TRADJEMA should fail with NaN mult",
2169 test_name
2170 );
2171
2172 Ok(())
2173 }
2174
2175 fn check_tradjema_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2176 skip_if_unsupported!(kernel, test_name);
2177 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2178 let candles = read_candles_from_csv(file_path)?;
2179
2180 let high = candles.select_candle_field("high")?;
2181 let low = candles.select_candle_field("low")?;
2182 let close = candles.select_candle_field("close")?;
2183
2184 let first_params = TradjemaParams {
2185 length: Some(20),
2186 mult: Some(5.0),
2187 };
2188 let first_input = TradjemaInput::from_slices(high, low, close, first_params);
2189 let first_result = tradjema_with_kernel(&first_input, kernel)?;
2190
2191 let second_params = TradjemaParams {
2192 length: Some(20),
2193 mult: Some(5.0),
2194 };
2195 let second_input = TradjemaInput::from_slices(
2196 &first_result.values,
2197 &first_result.values,
2198 &first_result.values,
2199 second_params,
2200 );
2201 let second_result = tradjema_with_kernel(&second_input, kernel)?;
2202
2203 assert_eq!(second_result.values.len(), first_result.values.len());
2204 Ok(())
2205 }
2206
2207 fn check_tradjema_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2208 skip_if_unsupported!(kernel, test_name);
2209 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2210 let candles = read_candles_from_csv(file_path)?;
2211
2212 let high = candles.select_candle_field("high")?;
2213 let low = candles.select_candle_field("low")?;
2214 let close = candles.select_candle_field("close")?;
2215
2216 let input = TradjemaInput::from_slices(
2217 high,
2218 low,
2219 close,
2220 TradjemaParams {
2221 length: Some(40),
2222 mult: Some(10.0),
2223 },
2224 );
2225 let res = tradjema_with_kernel(&input, kernel)?;
2226 assert_eq!(res.values.len(), candles.close.len());
2227
2228 if res.values.len() > 50 {
2229 for (i, &val) in res.values[50..].iter().enumerate() {
2230 assert!(
2231 !val.is_nan(),
2232 "[{}] Found unexpected NaN at out-index {}",
2233 test_name,
2234 50 + i
2235 );
2236 }
2237 }
2238 Ok(())
2239 }
2240
2241 fn check_tradjema_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2242 skip_if_unsupported!(kernel, test_name);
2243
2244 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2245 let candles = read_candles_from_csv(file_path)?;
2246
2247 let high = candles.select_candle_field("high")?;
2248 let low = candles.select_candle_field("low")?;
2249 let close = candles.select_candle_field("close")?;
2250
2251 let length = 40;
2252 let mult = 10.0;
2253
2254 let input = TradjemaInput::from_slices(
2255 high,
2256 low,
2257 close,
2258 TradjemaParams {
2259 length: Some(length),
2260 mult: Some(mult),
2261 },
2262 );
2263 let batch_output = tradjema_with_kernel(&input, kernel)?.values;
2264
2265 let mut stream = TradjemaStream::try_new(TradjemaParams {
2266 length: Some(length),
2267 mult: Some(mult),
2268 })?;
2269
2270 let mut stream_values = Vec::with_capacity(candles.close.len());
2271 for i in 0..candles.close.len() {
2272 match stream.update(high[i], low[i], close[i]) {
2273 Some(val) => stream_values.push(val),
2274 None => stream_values.push(f64::NAN),
2275 }
2276 }
2277
2278 assert_eq!(batch_output.len(), stream_values.len());
2279 for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
2280 if b.is_nan() && s.is_nan() {
2281 continue;
2282 }
2283 let diff = (b - s).abs();
2284 assert!(
2285 diff < 1e-9,
2286 "[{}] TRADJEMA streaming mismatch at idx {}: batch={}, stream={}, diff={}",
2287 test_name,
2288 i,
2289 b,
2290 s,
2291 diff
2292 );
2293 }
2294 Ok(())
2295 }
2296
2297 #[cfg(debug_assertions)]
2298 fn check_tradjema_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2299 skip_if_unsupported!(kernel, test_name);
2300
2301 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2302 let candles = read_candles_from_csv(file_path)?;
2303
2304 let high = candles.select_candle_field("high")?;
2305 let low = candles.select_candle_field("low")?;
2306 let close = candles.select_candle_field("close")?;
2307
2308 let test_params = vec![
2309 TradjemaParams::default(),
2310 TradjemaParams {
2311 length: Some(10),
2312 mult: Some(5.0),
2313 },
2314 TradjemaParams {
2315 length: Some(20),
2316 mult: Some(7.5),
2317 },
2318 TradjemaParams {
2319 length: Some(50),
2320 mult: Some(15.0),
2321 },
2322 TradjemaParams {
2323 length: Some(100),
2324 mult: Some(20.0),
2325 },
2326 ];
2327
2328 for (param_idx, params) in test_params.iter().enumerate() {
2329 let input = TradjemaInput::from_slices(high, low, close, params.clone());
2330 let output = tradjema_with_kernel(&input, kernel)?;
2331
2332 for (i, &val) in output.values.iter().enumerate() {
2333 if val.is_nan() {
2334 continue;
2335 }
2336
2337 let bits = val.to_bits();
2338
2339 if bits == 0x11111111_11111111 {
2340 panic!(
2341 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
2342 with params: length={}, mult={}",
2343 test_name,
2344 val,
2345 bits,
2346 i,
2347 params.length.unwrap_or(40),
2348 params.mult.unwrap_or(10.0)
2349 );
2350 }
2351
2352 if bits == 0x22222222_22222222 {
2353 panic!(
2354 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
2355 with params: length={}, mult={}",
2356 test_name,
2357 val,
2358 bits,
2359 i,
2360 params.length.unwrap_or(40),
2361 params.mult.unwrap_or(10.0)
2362 );
2363 }
2364
2365 if bits == 0x33333333_33333333 {
2366 panic!(
2367 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
2368 with params: length={}, mult={}",
2369 test_name,
2370 val,
2371 bits,
2372 i,
2373 params.length.unwrap_or(40),
2374 params.mult.unwrap_or(10.0)
2375 );
2376 }
2377 }
2378 }
2379
2380 Ok(())
2381 }
2382
2383 #[cfg(not(debug_assertions))]
2384 fn check_tradjema_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2385 Ok(())
2386 }
2387
/// Property test: for random finite series, `kernel` must reproduce the `Kernel::Scalar`
/// reference output after the warmup region, within 1e-9 absolute error or 4 ULPs.
///
/// Non-finite outputs are compared by class, not by bit pattern. Two NaNs are a match even
/// if their sign/payload bits differ, because SIMD and scalar code paths may legitimately
/// produce different NaN encodings. Two infinities must be equal (same sign).
#[cfg(feature = "proptest")]
// Exact `==` is intentional below: it is only used to compare matching infinities.
#[allow(clippy::float_cmp)]
fn check_tradjema_property(
    test_name: &str,
    kernel: Kernel,
) -> Result<(), Box<dyn std::error::Error>> {
    use proptest::prelude::*;
    skip_if_unsupported!(kernel, test_name);

    // Pick a length first so the generated series is always at least `length` long.
    let strat = (2usize..=100).prop_flat_map(|length| {
        (
            prop::collection::vec(
                (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
                length..400,
            ),
            Just(length),
            0.1f64..50.0f64,
        )
    });

    proptest::test_runner::TestRunner::default()
        .run(&strat, |(data, length, mult)| {
            let params = TradjemaParams {
                length: Some(length),
                mult: Some(mult),
            };

            // high == low == close: the property only concerns kernel agreement.
            let input = TradjemaInput::from_slices(&data, &data, &data, params);

            let TradjemaOutput { values: out } = tradjema_with_kernel(&input, kernel).unwrap();
            let TradjemaOutput { values: ref_out } =
                tradjema_with_kernel(&input, Kernel::Scalar).unwrap();

            for i in (length - 1)..data.len() {
                let y = out[i];
                let r = ref_out[i];

                if y.is_nan() || r.is_nan() {
                    // NaN bit patterns are not stable across code paths; compare by class.
                    prop_assert!(
                        y.is_nan() && r.is_nan(),
                        "finite/NaN mismatch idx {i}: {y} vs {r}"
                    );
                    continue;
                }

                if y.is_infinite() || r.is_infinite() {
                    prop_assert!(y == r, "non-finite mismatch idx {i}: {y} vs {r}");
                    continue;
                }

                let ulp_diff: u64 = y.to_bits().abs_diff(r.to_bits());

                prop_assert!(
                    (y - r).abs() <= 1e-9 || ulp_diff <= 4,
                    "mismatch idx {i}: {y} vs {r} (ULP={ulp_diff})"
                );
            }
            Ok(())
        })
        .unwrap();

    Ok(())
}
2446
/// Expands each listed single-series `check_*` function into one `#[test]` per kernel:
/// scalar always, AVX2/AVX-512 under `nightly-avx` on x86_64, and a simd128-named scalar
/// run on wasm32 with `simd128`.
///
/// Each generated test returns the check's `Result`. An `Err`, such as a failed `?` while
/// loading the CSV fixture, therefore fails the test instead of being silently discarded.
///
/// The `#[cfg]` attributes sit *inside* each `$(...)*` repetition. An attribute written
/// in front of a repetition only attaches to the first item the repetition expands to.
macro_rules! generate_all_tradjema_tests {
    ($($test_fn:ident),*) => {
        paste::paste! {
            $(
                #[test]
                fn [<$test_fn _scalar_f64>]() -> Result<(), Box<dyn Error>> {
                    $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar)
                }
            )*
            $(
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$test_fn _avx2_f64>]() -> Result<(), Box<dyn Error>> {
                    $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2)
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$test_fn _avx512_f64>]() -> Result<(), Box<dyn Error>> {
                    $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512)
                }
            )*
            $(
                // simd128 builds exercise the scalar kernel under a distinct test name.
                #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
                #[test]
                fn [<$test_fn _simd128_f64>]() -> Result<(), Box<dyn Error>> {
                    $test_fn(stringify!([<$test_fn _simd128_f64>]), Kernel::Scalar)
                }
            )*
        }
    }
}
2477
// Generate the per-kernel `#[test]` wrappers for every single-series TRADJEMA check.
// Each name listed here must be a `fn(&str, Kernel) -> Result<(), Box<dyn Error>>`.
generate_all_tradjema_tests!(
    check_tradjema_partial_params,
    check_tradjema_accuracy,
    check_tradjema_default_candles,
    check_tradjema_zero_length,
    check_tradjema_length_exceeds_data,
    check_tradjema_very_small_dataset,
    check_tradjema_empty_input,
    check_tradjema_invalid_mult,
    check_tradjema_reinput,
    check_tradjema_nan_handling,
    check_tradjema_streaming,
    check_tradjema_no_poison
);
2492
2493 fn check_tradjema_into_slice(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2494 skip_if_unsupported!(kernel, test_name);
2495 let f = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2496 let c = read_candles_from_csv(f)?;
2497 let (h, l, cl) = (
2498 c.select_candle_field("high")?,
2499 c.select_candle_field("low")?,
2500 c.select_candle_field("close")?,
2501 );
2502 let input = TradjemaInput::from_slices(h, l, cl, TradjemaParams::default());
2503 let mut dst = vec![0.0; cl.len()];
2504 tradjema_into_slice(&mut dst, &input, kernel)?;
2505 let first = cl.iter().position(|v| !v.is_nan()).unwrap();
2506 let warm = first + input.get_length() - 1;
2507 assert!(
2508 dst[..warm].iter().all(|v| v.is_nan()),
2509 "[{}] warmup prefix must be NaN",
2510 test_name
2511 );
2512 Ok(())
2513 }
2514
// Per-kernel wrappers for the in-place `tradjema_into_slice` check.
generate_all_tradjema_tests!(check_tradjema_into_slice);
2516
// Property-based kernel-vs-scalar comparison; only built when the `proptest` feature is on,
// matching the `#[cfg]` on `check_tradjema_property` itself.
#[cfg(feature = "proptest")]
generate_all_tradjema_tests!(check_tradjema_property);
2519
2520 fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2521 skip_if_unsupported!(kernel, test);
2522
2523 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2524 let c = read_candles_from_csv(file)?;
2525
2526 let output = TradjemaBatchBuilder::new()
2527 .kernel(kernel)
2528 .apply_candles(&c)?;
2529
2530 let def = TradjemaParams::default();
2531 let row = output.values_for(&def).expect("default row missing");
2532
2533 assert_eq!(row.len(), c.close.len());
2534 Ok(())
2535 }
2536
2537 fn check_batch_sweep(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2538 skip_if_unsupported!(kernel, test);
2539
2540 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2541 let c = read_candles_from_csv(file)?;
2542
2543 let output = TradjemaBatchBuilder::new()
2544 .kernel(kernel)
2545 .length_range(20, 50, 10)
2546 .mult_range(5.0, 15.0, 5.0)
2547 .apply_candles(&c)?;
2548
2549 let expected_combos = 4 * 3;
2550 assert_eq!(output.combos.len(), expected_combos);
2551 assert_eq!(output.rows, expected_combos);
2552 assert_eq!(output.cols, c.close.len());
2553
2554 Ok(())
2555 }
2556
2557 #[cfg(debug_assertions)]
2558 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2559 skip_if_unsupported!(kernel, test);
2560
2561 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2562 let c = read_candles_from_csv(file)?;
2563
2564 let test_configs = vec![
2565 (10, 30, 10, 5.0, 15.0, 5.0),
2566 (40, 40, 0, 10.0, 10.0, 0.0),
2567 (20, 60, 20, 7.5, 12.5, 2.5),
2568 ];
2569
2570 for (cfg_idx, &(l_start, l_end, l_step, m_start, m_end, m_step)) in
2571 test_configs.iter().enumerate()
2572 {
2573 let output = TradjemaBatchBuilder::new()
2574 .kernel(kernel)
2575 .length_range(l_start, l_end, l_step)
2576 .mult_range(m_start, m_end, m_step)
2577 .apply_candles(&c)?;
2578
2579 for (idx, &val) in output.values.iter().enumerate() {
2580 if val.is_nan() {
2581 continue;
2582 }
2583
2584 let bits = val.to_bits();
2585 let row = idx / output.cols;
2586 let col = idx % output.cols;
2587 let combo = &output.combos[row];
2588
2589 if bits == 0x11111111_11111111
2590 || bits == 0x22222222_22222222
2591 || bits == 0x33333333_33333333
2592 {
2593 panic!(
2594 "[{}] Config {}: Found poison value {} (0x{:016X}) \
2595 at row {} col {} (flat index {}) with params: length={}, mult={}",
2596 test,
2597 cfg_idx,
2598 val,
2599 bits,
2600 row,
2601 col,
2602 idx,
2603 combo.length.unwrap_or(40),
2604 combo.mult.unwrap_or(10.0)
2605 );
2606 }
2607 }
2608 }
2609
2610 Ok(())
2611 }
2612
/// Release-build counterpart of the batch poison scan.
///
/// The scan only exists in debug builds; here it is compiled out and always succeeds.
#[cfg(not(debug_assertions))]
fn check_batch_no_poison(_: &str, _: Kernel) -> Result<(), Box<dyn Error>> {
    // Nothing to inspect outside debug builds.
    Ok(())
}
2617
/// Expands a batch `check_*` function into `#[test]` wrappers for:
/// - the scalar batch kernel;
/// - the AVX2 and AVX-512 batch kernels, under `nightly-avx` on x86_64;
/// - `Kernel::Auto`.
///
/// Each wrapper returns the check's `Result`, so an `Err` (e.g. a failed `?` on CSV
/// loading or on `apply_candles`) fails the test instead of being silently discarded.
macro_rules! gen_batch_tests {
    ($fn_name:ident) => {
        paste::paste! {
            #[test]
            fn [<$fn_name _scalar>]() -> Result<(), Box<dyn Error>> {
                $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch)
            }
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            #[test]
            fn [<$fn_name _avx2>]() -> Result<(), Box<dyn Error>> {
                $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch)
            }
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            #[test]
            fn [<$fn_name _avx512>]() -> Result<(), Box<dyn Error>> {
                $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch)
            }
            #[test]
            fn [<$fn_name _auto_detect>]() -> Result<(), Box<dyn Error>> {
                $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto)
            }
        }
    };
}
2639
// Per-kernel wrappers for the batch checks: default row presence, sweep shape, and the
// debug-only poison scan.
gen_batch_tests!(check_batch_default_row);
gen_batch_tests!(check_batch_sweep);
gen_batch_tests!(check_batch_no_poison);
2643
/// A zero-step range whose start equals its end must expand to exactly one combination
/// carrying that single (length, mult) pair. Checked for a small and a default-sized length.
#[test]
fn test_expand_grid_single_value() {
    let cases = [
        (
            2usize,
            "expand_grid should not return empty for single value",
            "Should have exactly one combo",
        ),
        (
            40usize,
            "expand_grid should not return empty for single value (40,40,0)",
            "Should have exactly one combo for (40,40,0)",
        ),
    ];

    for (len, empty_msg, count_msg) in cases {
        let grid = expand_grid(&TradjemaBatchRange {
            length: (len, len, 0),
            mult: (10.0, 10.0, 0.0),
        });

        assert!(!grid.is_empty(), "{}", empty_msg);
        assert_eq!(grid.len(), 1, "{}", count_msg);
        assert_eq!(grid[0].length, Some(len));
        assert_eq!(grid[0].mult, Some(10.0));
    }
}
2676
/// `tradjema_into` must write exactly what the allocating `tradjema` returns on the 4h
/// fixture with default params. Two NaNs at the same index count as equal.
#[test]
fn test_tradjema_into_matches_api() -> Result<(), Box<dyn Error>> {
    let candles = read_candles_from_csv("src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv")?;
    let high = candles.select_candle_field("high")?;
    let low = candles.select_candle_field("low")?;
    let close = candles.select_candle_field("close")?;

    let input = TradjemaInput::from_slices(high, low, close, TradjemaParams::default());

    let expected = tradjema(&input)?.values;

    let mut actual = vec![0.0; close.len()];
    tradjema_into(&input, &mut actual)?;

    assert_eq!(expected.len(), actual.len());

    for (i, (&a, &b)) in expected.iter().zip(actual.iter()).enumerate() {
        if !(a.is_nan() || b.is_nan()) {
            assert_eq!(a, b, "Value mismatch at index {}: {} vs {}", i, a, b);
            continue;
        }
        assert!(
            a.is_nan() && b.is_nan(),
            "NaN mismatch at index {}: {:?} vs {:?}",
            i,
            a,
            b
        );
    }

    Ok(())
}
2714}