1use crate::indicators::moving_averages::ma::{ma, MaData};
2use crate::indicators::utility_functions::RollingError;
3use crate::utilities::data_loader::{source_type, Candles};
4use crate::utilities::enums::Kernel;
5use crate::utilities::helpers::{
6 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
7 make_uninit_matrix,
8};
9#[cfg(feature = "python")]
10use crate::utilities::kernel_validation::validate_kernel;
11
12#[cfg(feature = "python")]
13use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
14#[cfg(feature = "python")]
15use pyo3::exceptions::PyValueError;
16#[cfg(feature = "python")]
17use pyo3::prelude::*;
18#[cfg(feature = "python")]
19use pyo3::types::{PyDict, PyList};
20#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
21use serde::{Deserialize, Serialize};
22#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
23use wasm_bindgen::prelude::*;
24
25use std::convert::AsRef;
26use std::error::Error;
27use std::mem::MaybeUninit;
28use thiserror::Error;
29
30#[derive(Debug, Clone)]
31pub enum KdjData<'a> {
32 Candles {
33 candles: &'a Candles,
34 },
35 Slices {
36 high: &'a [f64],
37 low: &'a [f64],
38 close: &'a [f64],
39 },
40}
41
42#[derive(Debug, Clone)]
43pub struct KdjInput<'a> {
44 pub data: KdjData<'a>,
45 pub params: KdjParams,
46}
47
48impl<'a> KdjInput<'a> {
49 #[inline]
50 pub fn from_candles(candles: &'a Candles, params: KdjParams) -> Self {
51 Self {
52 data: KdjData::Candles { candles },
53 params,
54 }
55 }
56 #[inline]
57 pub fn from_slices(
58 high: &'a [f64],
59 low: &'a [f64],
60 close: &'a [f64],
61 params: KdjParams,
62 ) -> Self {
63 Self {
64 data: KdjData::Slices { high, low, close },
65 params,
66 }
67 }
68 #[inline]
69 pub fn with_default_candles(candles: &'a Candles) -> Self {
70 Self::from_candles(candles, KdjParams::default())
71 }
72 #[inline]
73 pub fn get_fast_k_period(&self) -> usize {
74 self.params.fast_k_period.unwrap_or(9)
75 }
76 #[inline]
77 pub fn get_slow_k_period(&self) -> usize {
78 self.params.slow_k_period.unwrap_or(3)
79 }
80 #[inline]
81 pub fn get_slow_k_ma_type(&self) -> &str {
82 self.params.slow_k_ma_type.as_deref().unwrap_or("sma")
83 }
84 #[inline]
85 pub fn get_slow_d_period(&self) -> usize {
86 self.params.slow_d_period.unwrap_or(3)
87 }
88 #[inline]
89 pub fn get_slow_d_ma_type(&self) -> &str {
90 self.params.slow_d_ma_type.as_deref().unwrap_or("sma")
91 }
92}
93
/// Tunable KDJ parameters.
///
/// Every field is optional. `None` is resolved to a default by the consumers
/// of this struct (9 / 3 / `"sma"` / 3 / `"sma"` in the input getters).
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct KdjParams {
    /// Window of the raw stochastic.
    pub fast_k_period: Option<usize>,
    /// Smoothing window for `%K`.
    pub slow_k_period: Option<usize>,
    /// Moving-average type used for `%K`.
    pub slow_k_ma_type: Option<String>,
    /// Smoothing window for `%D`.
    pub slow_d_period: Option<usize>,
    /// Moving-average type used for `%D`.
    pub slow_d_ma_type: Option<String>,
}
106
107impl Default for KdjParams {
108 fn default() -> Self {
109 Self {
110 fast_k_period: Some(9),
111 slow_k_period: Some(3),
112 slow_k_ma_type: Some("sma".to_string()),
113 slow_d_period: Some(3),
114 slow_d_ma_type: Some("sma".to_string()),
115 }
116 }
117}
118
/// The three KDJ lines, each aligned index-for-index with the input series.
#[derive(Debug, Clone)]
pub struct KdjOutput {
    /// Smoothed stochastic `%K`.
    pub k: Vec<f64>,
    /// Signal line `%D` (smoothed `%K`).
    pub d: Vec<f64>,
    /// `J = 3·K − 2·D`.
    pub j: Vec<f64>,
}
125
126#[derive(Debug, Error)]
127pub enum KdjError {
128 #[error("kdj: Empty data provided.")]
129 EmptyInputData,
130 #[error("kdj: Empty data provided.")]
131 EmptyData,
132 #[error("kdj: Invalid period: period = {period}, data length = {data_len}")]
133 InvalidPeriod { period: usize, data_len: usize },
134 #[error("kdj: Not enough valid data: needed = {needed}, valid = {valid}")]
135 NotEnoughValidData { needed: usize, valid: usize },
136 #[error("kdj: All values are NaN.")]
137 AllValuesNaN,
138 #[error("kdj: Output slice length mismatch: expected = {expected}, got = {got}")]
139 OutputLengthMismatch { expected: usize, got: usize },
140 #[error("kdj: Buffer size mismatch: expected = {expected}, got = {got}")]
141 BufferSizeMismatch { expected: usize, got: usize },
142 #[error("kdj: Invalid range: start = {start}, end = {end}, step = {step}")]
143 InvalidRange {
144 start: usize,
145 end: usize,
146 step: usize,
147 },
148 #[error("kdj: Invalid kernel type for batch operation: {0:?}")]
149 InvalidKernelForBatch(Kernel),
150 #[error("kdj: Rolling error {0}")]
151 RollingError(#[from] RollingError),
152 #[error("kdj: MA error {0}")]
153 MaError(#[from] Box<dyn Error + Send + Sync>),
154}
155
156#[inline]
157pub fn kdj(input: &KdjInput) -> Result<KdjOutput, KdjError> {
158 kdj_with_kernel(input, Kernel::Auto)
159}
160
161pub fn kdj_with_kernel(input: &KdjInput, kernel: Kernel) -> Result<KdjOutput, KdjError> {
162 let (high, low, close): (&[f64], &[f64], &[f64]) = match &input.data {
163 KdjData::Candles { candles } => (
164 source_type(candles, "high"),
165 source_type(candles, "low"),
166 source_type(candles, "close"),
167 ),
168 KdjData::Slices { high, low, close } => (high, low, close),
169 };
170
171 if high.is_empty() || low.is_empty() || close.is_empty() {
172 return Err(KdjError::EmptyInputData);
173 }
174
175 let fast_k_period = input.get_fast_k_period();
176 let slow_k_period = input.get_slow_k_period();
177 let slow_k_ma_type = input.get_slow_k_ma_type();
178 let slow_d_period = input.get_slow_d_period();
179 let slow_d_ma_type = input.get_slow_d_ma_type();
180
181 if fast_k_period == 0 || fast_k_period > high.len() {
182 return Err(KdjError::InvalidPeriod {
183 period: fast_k_period,
184 data_len: high.len(),
185 });
186 }
187 if slow_k_period == 0 {
188 return Err(KdjError::InvalidPeriod {
189 period: slow_k_period,
190 data_len: high.len(),
191 });
192 }
193 if slow_d_period == 0 {
194 return Err(KdjError::InvalidPeriod {
195 period: slow_d_period,
196 data_len: high.len(),
197 });
198 }
199
200 let first_valid_idx = high
201 .iter()
202 .zip(low.iter())
203 .zip(close.iter())
204 .position(|((&h, &l), &c)| !h.is_nan() && !l.is_nan() && !c.is_nan())
205 .ok_or(KdjError::AllValuesNaN)?;
206
207 if (high.len() - first_valid_idx) < fast_k_period {
208 return Err(KdjError::NotEnoughValidData {
209 needed: fast_k_period,
210 valid: high.len() - first_valid_idx,
211 });
212 }
213
214 let mut chosen = match kernel {
215 Kernel::Auto => detect_best_kernel(),
216 other => other,
217 };
218
219 if matches!(kernel, Kernel::Auto)
220 && fast_k_period == 9
221 && slow_k_period == 3
222 && slow_d_period == 3
223 && slow_k_ma_type.eq_ignore_ascii_case("sma")
224 && slow_d_ma_type.eq_ignore_ascii_case("sma")
225 {
226 chosen = Kernel::Scalar;
227 }
228
229 unsafe {
230 #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
231 {
232 if matches!(chosen, Kernel::Scalar | Kernel::ScalarBatch) {
233 return kdj_simd128(
234 high,
235 low,
236 close,
237 fast_k_period,
238 slow_k_period,
239 slow_k_ma_type,
240 slow_d_period,
241 slow_d_ma_type,
242 first_valid_idx,
243 );
244 }
245 }
246
247 match chosen {
248 Kernel::Scalar | Kernel::ScalarBatch => kdj_scalar(
249 high,
250 low,
251 close,
252 fast_k_period,
253 slow_k_period,
254 slow_k_ma_type,
255 slow_d_period,
256 slow_d_ma_type,
257 first_valid_idx,
258 ),
259 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
260 Kernel::Avx2 | Kernel::Avx2Batch => kdj_avx2(
261 high,
262 low,
263 close,
264 fast_k_period,
265 slow_k_period,
266 slow_k_ma_type,
267 slow_d_period,
268 slow_d_ma_type,
269 first_valid_idx,
270 ),
271 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
272 Kernel::Avx512 | Kernel::Avx512Batch => kdj_avx512(
273 high,
274 low,
275 close,
276 fast_k_period,
277 slow_k_period,
278 slow_k_ma_type,
279 slow_d_period,
280 slow_d_ma_type,
281 first_valid_idx,
282 ),
283 _ => unreachable!(),
284 }
285 }
286}
287
288#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
289#[inline]
290pub fn kdj_into(
291 input: &KdjInput,
292 k_out: &mut [f64],
293 d_out: &mut [f64],
294 j_out: &mut [f64],
295) -> Result<(), KdjError> {
296 kdj_into_slices(k_out, d_out, j_out, input, Kernel::Auto)
297}
298
299#[inline]
300pub fn kdj_scalar(
301 high: &[f64],
302 low: &[f64],
303 close: &[f64],
304 fast_k_period: usize,
305 slow_k_period: usize,
306 slow_k_ma_type: &str,
307 slow_d_period: usize,
308 slow_d_ma_type: &str,
309 first_valid_idx: usize,
310) -> Result<KdjOutput, KdjError> {
311 let len = high.len();
312 let mut k: Vec<f64> = Vec::with_capacity(len);
313 let mut d: Vec<f64> = Vec::with_capacity(len);
314 let mut j: Vec<f64> = Vec::with_capacity(len);
315 unsafe {
316 k.set_len(len);
317 d.set_len(len);
318 j.set_len(len);
319 }
320
321 kdj_compute_into_scalar(
322 high,
323 low,
324 close,
325 first_valid_idx,
326 fast_k_period,
327 slow_k_period,
328 slow_k_ma_type,
329 slow_d_period,
330 slow_d_ma_type,
331 &mut k,
332 &mut d,
333 &mut j,
334 )?;
335
336 Ok(KdjOutput { k, d, j })
337}
338
339#[inline]
340pub fn kdj_into_slices(
341 k_out: &mut [f64],
342 d_out: &mut [f64],
343 j_out: &mut [f64],
344 input: &KdjInput,
345 kern: Kernel,
346) -> Result<(), KdjError> {
347 let (high, low, close) = match &input.data {
348 KdjData::Candles { candles } => (
349 source_type(candles, "high"),
350 source_type(candles, "low"),
351 source_type(candles, "close"),
352 ),
353 KdjData::Slices { high, low, close } => (*high, *low, *close),
354 };
355 if high.is_empty() || low.is_empty() || close.is_empty() {
356 return Err(KdjError::EmptyInputData);
357 }
358 let len = high.len();
359 if k_out.len() != len {
360 return Err(KdjError::OutputLengthMismatch {
361 expected: len,
362 got: k_out.len(),
363 });
364 }
365 if d_out.len() != len {
366 return Err(KdjError::OutputLengthMismatch {
367 expected: len,
368 got: d_out.len(),
369 });
370 }
371 if j_out.len() != len {
372 return Err(KdjError::OutputLengthMismatch {
373 expected: len,
374 got: j_out.len(),
375 });
376 }
377
378 let fast_k = input.get_fast_k_period();
379 if fast_k == 0 || fast_k > len {
380 return Err(KdjError::InvalidPeriod {
381 period: fast_k,
382 data_len: len,
383 });
384 }
385
386 let first = high
387 .iter()
388 .zip(low.iter())
389 .zip(close.iter())
390 .position(|((&h, &l), &c)| !h.is_nan() && !l.is_nan() && !c.is_nan())
391 .ok_or(KdjError::AllValuesNaN)?;
392
393 if len - first < fast_k {
394 return Err(KdjError::NotEnoughValidData {
395 needed: fast_k,
396 valid: len - first,
397 });
398 }
399
400 let slow_k = input.get_slow_k_period();
401 let slow_d = input.get_slow_d_period();
402 let slow_k_ma = input.get_slow_k_ma_type();
403 let slow_d_ma = input.get_slow_d_ma_type();
404 if slow_k == 0 {
405 return Err(KdjError::InvalidPeriod {
406 period: slow_k,
407 data_len: len,
408 });
409 }
410 if slow_d == 0 {
411 return Err(KdjError::InvalidPeriod {
412 period: slow_d,
413 data_len: len,
414 });
415 }
416
417 let chosen = match kern {
418 Kernel::Auto => detect_best_kernel(),
419 k => k,
420 };
421
422 match chosen {
423 Kernel::Scalar | Kernel::ScalarBatch => kdj_compute_into_scalar(
424 high, low, close, first, fast_k, slow_k, slow_k_ma, slow_d, slow_d_ma, k_out, d_out,
425 j_out,
426 ),
427 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
428 Kernel::Avx2 | Kernel::Avx2Batch => kdj_compute_into_scalar(
429 high, low, close, first, fast_k, slow_k, slow_k_ma, slow_d, slow_d_ma, k_out, d_out,
430 j_out,
431 ),
432 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
433 Kernel::Avx512 | Kernel::Avx512Batch => kdj_compute_into_scalar(
434 high, low, close, first, fast_k, slow_k, slow_k_ma, slow_d, slow_d_ma, k_out, d_out,
435 j_out,
436 ),
437 _ => kdj_compute_into_scalar(
438 high, low, close, first, fast_k, slow_k, slow_k_ma, slow_d, slow_d_ma, k_out, d_out,
439 j_out,
440 ),
441 }
442}
443
444#[inline]
445fn kdj_compute_into_scalar(
446 high: &[f64],
447 low: &[f64],
448 close: &[f64],
449 first: usize,
450 fast_k: usize,
451 slow_k: usize,
452 slow_k_ma: &str,
453 slow_d: usize,
454 slow_d_ma: &str,
455 k_out: &mut [f64],
456 d_out: &mut [f64],
457 j_out: &mut [f64],
458) -> Result<(), KdjError> {
459 use std::collections::VecDeque;
460
461 let len = high.len();
462 if len == 0 {
463 return Err(KdjError::EmptyInputData);
464 }
465
466 let stoch_warm = first + fast_k - 1;
467 let k_warm = stoch_warm + slow_k - 1;
468 let d_warm = k_warm + slow_d - 1;
469
470 let sma_k = slow_k_ma.eq_ignore_ascii_case("sma");
471 let sma_d = slow_d_ma.eq_ignore_ascii_case("sma");
472 if sma_k && sma_d {
473 for i in 0..k_warm.min(len) {
474 k_out[i] = f64::NAN;
475 }
476 for i in 0..d_warm.min(len) {
477 d_out[i] = f64::NAN;
478 j_out[i] = f64::NAN;
479 }
480
481 let cap = fast_k + 1;
482 let mut max_idx = vec![0usize; cap];
483 let mut max_val = vec![0.0f64; cap];
484 let mut min_idx = vec![0usize; cap];
485 let mut min_val = vec![0.0f64; cap];
486 let (mut max_head, mut max_tail, mut max_cnt) = (0usize, 0usize, 0usize);
487 let (mut min_head, mut min_tail, mut min_cnt) = (0usize, 0usize, 0usize);
488 #[inline(always)]
489 fn inc(i: usize, cap: usize) -> usize {
490 let j = i + 1;
491 if j == cap {
492 0
493 } else {
494 j
495 }
496 }
497 #[inline(always)]
498 fn dec(i: usize, cap: usize) -> usize {
499 if i == 0 {
500 cap - 1
501 } else {
502 i - 1
503 }
504 }
505
506 let mut stoch_ring = vec![f64::NAN; slow_k];
507 let mut sum_k = 0.0f64;
508 let mut cnt_k: usize = 0;
509
510 let mut k_ring = vec![f64::NAN; slow_d];
511 let mut sum_d = 0.0f64;
512 let mut cnt_d: usize = 0;
513
514 let mut pos_k = stoch_warm % slow_k;
515 let mut pos_d = k_warm % slow_d;
516
517 for i in first..len {
518 let hi = unsafe { *high.get_unchecked(i) };
519 while max_cnt > 0 {
520 let back = dec(max_tail, cap);
521 if max_val[back] <= hi {
522 max_tail = back;
523 max_cnt -= 1;
524 } else {
525 break;
526 }
527 }
528 max_val[max_tail] = hi;
529 max_idx[max_tail] = i;
530 max_tail = inc(max_tail, cap);
531 max_cnt += 1;
532 while max_cnt > 0 && max_idx[max_head] + fast_k <= i {
533 max_head = inc(max_head, cap);
534 max_cnt -= 1;
535 }
536
537 let lo = unsafe { *low.get_unchecked(i) };
538 while min_cnt > 0 {
539 let back = dec(min_tail, cap);
540 if min_val[back] >= lo {
541 min_tail = back;
542 min_cnt -= 1;
543 } else {
544 break;
545 }
546 }
547 min_val[min_tail] = lo;
548 min_idx[min_tail] = i;
549 min_tail = inc(min_tail, cap);
550 min_cnt += 1;
551 while min_cnt > 0 && min_idx[min_head] + fast_k <= i {
552 min_head = inc(min_head, cap);
553 min_cnt -= 1;
554 }
555
556 if i < stoch_warm {
557 continue;
558 }
559
560 let hh = max_val[max_head];
561 let ll = min_val[min_head];
562 let denom = hh - ll;
563 let stoch_i = if denom == 0.0 || denom.is_nan() {
564 f64::NAN
565 } else {
566 let c = unsafe { *close.get_unchecked(i) };
567 100.0 * ((c - ll) / denom)
568 };
569
570 let old_st = stoch_ring[pos_k];
571 if !old_st.is_nan() {
572 sum_k -= old_st;
573 cnt_k -= 1;
574 }
575 stoch_ring[pos_k] = stoch_i;
576 if !stoch_i.is_nan() {
577 sum_k += stoch_i;
578 cnt_k += 1;
579 }
580 pos_k += 1;
581 if pos_k == slow_k {
582 pos_k = 0;
583 }
584
585 if i >= k_warm {
586 let k_val = if cnt_k > 0 {
587 sum_k / (cnt_k as f64)
588 } else {
589 f64::NAN
590 };
591 unsafe { *k_out.get_unchecked_mut(i) = k_val };
592
593 let old_k = k_ring[pos_d];
594 if !old_k.is_nan() {
595 sum_d -= old_k;
596 cnt_d -= 1;
597 }
598 k_ring[pos_d] = k_val;
599 if !k_val.is_nan() {
600 sum_d += k_val;
601 cnt_d += 1;
602 }
603 pos_d += 1;
604 if pos_d == slow_d {
605 pos_d = 0;
606 }
607
608 if i >= d_warm {
609 let d_val = if cnt_d > 0 {
610 sum_d / (cnt_d as f64)
611 } else {
612 f64::NAN
613 };
614 unsafe {
615 *d_out.get_unchecked_mut(i) = d_val;
616 *j_out.get_unchecked_mut(i) = if k_val.is_nan() || d_val.is_nan() {
617 f64::NAN
618 } else {
619 3.0 * k_val - 2.0 * d_val
620 };
621 }
622 }
623 }
624 }
625 return Ok(());
626 }
627
628 let ema_k = slow_k_ma.eq_ignore_ascii_case("ema");
629 let ema_d = slow_d_ma.eq_ignore_ascii_case("ema");
630 if ema_k && ema_d {
631 for i in 0..k_warm.min(len) {
632 k_out[i] = f64::NAN;
633 }
634 for i in 0..d_warm.min(len) {
635 d_out[i] = f64::NAN;
636 j_out[i] = f64::NAN;
637 }
638
639 let mut maxdq: VecDeque<usize> = VecDeque::with_capacity(fast_k + 1);
640 let mut mindq: VecDeque<usize> = VecDeque::with_capacity(fast_k + 1);
641
642 let alpha_k = 2.0 / (slow_k as f64 + 1.0);
643 let om_alpha_k = 1.0 - alpha_k;
644 let alpha_d = 2.0 / (slow_d as f64 + 1.0);
645 let om_alpha_d = 1.0 - alpha_d;
646
647 let mut sum_init_k = 0.0f64;
648 let mut cnt_init_k: usize = 0;
649 let mut ema_kv = f64::NAN;
650
651 let mut sum_init_d = 0.0f64;
652 let mut cnt_init_d: usize = 0;
653 let mut ema_dv = f64::NAN;
654
655 for i in first..len {
656 let hi = unsafe { *high.get_unchecked(i) };
657 while let Some(&idx) = maxdq.back() {
658 if unsafe { *high.get_unchecked(idx) } <= hi {
659 maxdq.pop_back();
660 } else {
661 break;
662 }
663 }
664 maxdq.push_back(i);
665 while let Some(&idx) = maxdq.front() {
666 if idx + fast_k <= i {
667 maxdq.pop_front();
668 } else {
669 break;
670 }
671 }
672
673 let lo = unsafe { *low.get_unchecked(i) };
674 while let Some(&idx) = mindq.back() {
675 if unsafe { *low.get_unchecked(idx) } >= lo {
676 mindq.pop_back();
677 } else {
678 break;
679 }
680 }
681 mindq.push_back(i);
682 while let Some(&idx) = mindq.front() {
683 if idx + fast_k <= i {
684 mindq.pop_front();
685 } else {
686 break;
687 }
688 }
689
690 if i < stoch_warm {
691 continue;
692 }
693
694 let hh = unsafe { *high.get_unchecked(*maxdq.front().unwrap()) };
695 let ll = unsafe { *low.get_unchecked(*mindq.front().unwrap()) };
696 let denom = hh - ll;
697 let stoch_i = if denom == 0.0 || denom.is_nan() {
698 f64::NAN
699 } else {
700 let c = unsafe { *close.get_unchecked(i) };
701 100.0 * ((c - ll) / denom)
702 };
703
704 if i <= k_warm {
705 if !stoch_i.is_nan() {
706 sum_init_k += stoch_i;
707 cnt_init_k += 1;
708 }
709 if i == k_warm {
710 ema_kv = if cnt_init_k > 0 {
711 sum_init_k / (cnt_init_k as f64)
712 } else {
713 f64::NAN
714 };
715 unsafe { *k_out.get_unchecked_mut(i) = ema_kv };
716 if !ema_kv.is_nan() {
717 sum_init_d += ema_kv;
718 cnt_init_d += 1;
719 }
720 }
721 continue;
722 }
723
724 if !stoch_i.is_nan() && !ema_kv.is_nan() {
725 ema_kv = stoch_i.mul_add(alpha_k, om_alpha_k * ema_kv);
726 } else if !stoch_i.is_nan() && ema_kv.is_nan() {
727 ema_kv = stoch_i;
728 }
729 unsafe { *k_out.get_unchecked_mut(i) = ema_kv };
730
731 if i <= d_warm {
732 if !ema_kv.is_nan() {
733 sum_init_d += ema_kv;
734 cnt_init_d += 1;
735 }
736 if i == d_warm {
737 ema_dv = if cnt_init_d > 0 {
738 sum_init_d / (cnt_init_d as f64)
739 } else {
740 f64::NAN
741 };
742 unsafe {
743 *d_out.get_unchecked_mut(i) = ema_dv;
744 *j_out.get_unchecked_mut(i) = if ema_kv.is_nan() || ema_dv.is_nan() {
745 f64::NAN
746 } else {
747 3.0 * ema_kv - 2.0 * ema_dv
748 };
749 }
750 }
751 continue;
752 }
753
754 if !ema_kv.is_nan() && !ema_dv.is_nan() {
755 ema_dv = ema_kv.mul_add(alpha_d, om_alpha_d * ema_dv);
756 } else if !ema_kv.is_nan() && ema_dv.is_nan() {
757 ema_dv = ema_kv;
758 }
759 unsafe {
760 *d_out.get_unchecked_mut(i) = ema_dv;
761 *j_out.get_unchecked_mut(i) = if ema_kv.is_nan() || ema_dv.is_nan() {
762 f64::NAN
763 } else {
764 3.0 * ema_kv - 2.0 * ema_dv
765 };
766 }
767 }
768 return Ok(());
769 }
770
771 let mut stoch = alloc_with_nan_prefix(len, stoch_warm);
772
773 let mut maxdq: VecDeque<usize> = VecDeque::with_capacity(fast_k + 1);
774 let mut mindq: VecDeque<usize> = VecDeque::with_capacity(fast_k + 1);
775
776 for i in first..len {
777 let hi = unsafe { *high.get_unchecked(i) };
778 while let Some(&idx) = maxdq.back() {
779 if unsafe { *high.get_unchecked(idx) } <= hi {
780 maxdq.pop_back();
781 } else {
782 break;
783 }
784 }
785 maxdq.push_back(i);
786 while let Some(&idx) = maxdq.front() {
787 if idx + fast_k <= i {
788 maxdq.pop_front();
789 } else {
790 break;
791 }
792 }
793
794 let lo = unsafe { *low.get_unchecked(i) };
795 while let Some(&idx) = mindq.back() {
796 if unsafe { *low.get_unchecked(idx) } >= lo {
797 mindq.pop_back();
798 } else {
799 break;
800 }
801 }
802 mindq.push_back(i);
803 while let Some(&idx) = mindq.front() {
804 if idx + fast_k <= i {
805 mindq.pop_front();
806 } else {
807 break;
808 }
809 }
810
811 if i < stoch_warm {
812 continue;
813 }
814
815 let hh = unsafe { *high.get_unchecked(*maxdq.front().unwrap()) };
816 let ll = unsafe { *low.get_unchecked(*mindq.front().unwrap()) };
817 let denom = hh - ll;
818 let val = if denom == 0.0 || denom.is_nan() {
819 f64::NAN
820 } else {
821 let c = unsafe { *close.get_unchecked(i) };
822 100.0 * ((c - ll) / denom)
823 };
824 unsafe { *stoch.get_unchecked_mut(i) = val };
825 }
826
827 let k_vec = ma(slow_k_ma, MaData::Slice(&stoch), slow_k)
828 .map_err(|e| KdjError::MaError(e.to_string().into()))?;
829 let d_vec = ma(slow_d_ma, MaData::Slice(&k_vec), slow_d)
830 .map_err(|e| KdjError::MaError(e.to_string().into()))?;
831
832 k_out.copy_from_slice(&k_vec);
833 d_out.copy_from_slice(&d_vec);
834
835 let j_warm = stoch_warm + slow_k - 1 + slow_d - 1;
836 for i in 0..j_warm.min(j_out.len()) {
837 j_out[i] = f64::NAN;
838 }
839 for i in j_warm..j_out.len() {
840 j_out[i] = if k_out[i].is_nan() || d_out[i].is_nan() {
841 f64::NAN
842 } else {
843 3.0 * k_out[i] - 2.0 * d_out[i]
844 };
845 }
846 Ok(())
847}
848
/// AVX-512 entry point for KDJ.
///
/// Chooses the short-window variant for `fast_k_period <= 32` and the
/// long-window variant otherwise. Both currently delegate to the scalar kernel.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn kdj_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    fast_k_period: usize,
    slow_k_period: usize,
    slow_k_ma_type: &str,
    slow_d_period: usize,
    slow_d_ma_type: &str,
    first_valid_idx: usize,
) -> Result<KdjOutput, KdjError> {
    let short_window = fast_k_period <= 32;
    // SAFETY: both variants only forward to the safe scalar kernel and impose
    // no preconditions beyond those of their (safe) slice arguments.
    unsafe {
        if short_window {
            kdj_avx512_short(
                high,
                low,
                close,
                fast_k_period,
                slow_k_period,
                slow_k_ma_type,
                slow_d_period,
                slow_d_ma_type,
                first_valid_idx,
            )
        } else {
            kdj_avx512_long(
                high,
                low,
                close,
                fast_k_period,
                slow_k_period,
                slow_k_ma_type,
                slow_d_period,
                slow_d_ma_type,
                first_valid_idx,
            )
        }
    }
}
892
893#[inline]
894pub fn kdj_avx2(
895 high: &[f64],
896 low: &[f64],
897 close: &[f64],
898 fast_k_period: usize,
899 slow_k_period: usize,
900 slow_k_ma_type: &str,
901 slow_d_period: usize,
902 slow_d_ma_type: &str,
903 first_valid_idx: usize,
904) -> Result<KdjOutput, KdjError> {
905 kdj_scalar(
906 high,
907 low,
908 close,
909 fast_k_period,
910 slow_k_period,
911 slow_k_ma_type,
912 slow_d_period,
913 slow_d_ma_type,
914 first_valid_idx,
915 )
916}
917
/// AVX-512 variant for short fast-%K windows; currently forwards to the
/// scalar kernel.
///
/// # Safety
///
/// No requirements beyond those of the safe slice arguments. The function is
/// marked `unsafe` only for signature parity with other SIMD entry points.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn kdj_avx512_short(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    fast_k_period: usize,
    slow_k_period: usize,
    slow_k_ma_type: &str,
    slow_d_period: usize,
    slow_d_ma_type: &str,
    first_valid_idx: usize,
) -> Result<KdjOutput, KdjError> {
    let (fast, slow_k, slow_d) = (fast_k_period, slow_k_period, slow_d_period);
    kdj_scalar(
        high,
        low,
        close,
        fast,
        slow_k,
        slow_k_ma_type,
        slow_d,
        slow_d_ma_type,
        first_valid_idx,
    )
}
943
/// AVX-512 variant for long fast-%K windows; currently forwards to the
/// scalar kernel.
///
/// # Safety
///
/// No requirements beyond those of the safe slice arguments. The function is
/// marked `unsafe` only for signature parity with other SIMD entry points.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn kdj_avx512_long(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    fast_k_period: usize,
    slow_k_period: usize,
    slow_k_ma_type: &str,
    slow_d_period: usize,
    slow_d_ma_type: &str,
    first_valid_idx: usize,
) -> Result<KdjOutput, KdjError> {
    let (fast, slow_k, slow_d) = (fast_k_period, slow_k_period, slow_d_period);
    kdj_scalar(
        high,
        low,
        close,
        fast,
        slow_k,
        slow_k_ma_type,
        slow_d,
        slow_d_ma_type,
        first_valid_idx,
    )
}
969
/// WebAssembly SIMD128 entry point for KDJ.
///
/// Currently forwards to the scalar kernel with identical arguments.
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
#[inline]
pub fn kdj_simd128(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    fast_k_period: usize,
    slow_k_period: usize,
    slow_k_ma_type: &str,
    slow_d_period: usize,
    slow_d_ma_type: &str,
    first_valid_idx: usize,
) -> Result<KdjOutput, KdjError> {
    let (fast, slow_k, slow_d) = (fast_k_period, slow_k_period, slow_d_period);
    kdj_scalar(
        high,
        low,
        close,
        fast,
        slow_k,
        slow_k_ma_type,
        slow_d,
        slow_d_ma_type,
        first_valid_idx,
    )
}
995
996#[derive(Clone, Debug)]
997pub struct KdjBuilder {
998 fast_k_period: Option<usize>,
999 slow_k_period: Option<usize>,
1000 slow_k_ma_type: Option<String>,
1001 slow_d_period: Option<usize>,
1002 slow_d_ma_type: Option<String>,
1003 kernel: Kernel,
1004}
1005
1006impl Default for KdjBuilder {
1007 fn default() -> Self {
1008 Self {
1009 fast_k_period: None,
1010 slow_k_period: None,
1011 slow_k_ma_type: None,
1012 slow_d_period: None,
1013 slow_d_ma_type: None,
1014 kernel: Kernel::Auto,
1015 }
1016 }
1017}
1018
1019impl KdjBuilder {
1020 #[inline(always)]
1021 pub fn new() -> Self {
1022 Self::default()
1023 }
1024 #[inline(always)]
1025 pub fn fast_k_period(mut self, n: usize) -> Self {
1026 self.fast_k_period = Some(n);
1027 self
1028 }
1029 #[inline(always)]
1030 pub fn slow_k_period(mut self, n: usize) -> Self {
1031 self.slow_k_period = Some(n);
1032 self
1033 }
1034 #[inline(always)]
1035 pub fn slow_k_ma_type<S: Into<String>>(mut self, t: S) -> Self {
1036 self.slow_k_ma_type = Some(t.into());
1037 self
1038 }
1039 #[inline(always)]
1040 pub fn slow_d_period(mut self, n: usize) -> Self {
1041 self.slow_d_period = Some(n);
1042 self
1043 }
1044 #[inline(always)]
1045 pub fn slow_d_ma_type<S: Into<String>>(mut self, t: S) -> Self {
1046 self.slow_d_ma_type = Some(t.into());
1047 self
1048 }
1049 #[inline(always)]
1050 pub fn kernel(mut self, k: Kernel) -> Self {
1051 self.kernel = k;
1052 self
1053 }
1054
1055 #[inline(always)]
1056 pub fn apply(self, c: &Candles) -> Result<KdjOutput, KdjError> {
1057 let p = KdjParams {
1058 fast_k_period: self.fast_k_period,
1059 slow_k_period: self.slow_k_period,
1060 slow_k_ma_type: self.slow_k_ma_type,
1061 slow_d_period: self.slow_d_period,
1062 slow_d_ma_type: self.slow_d_ma_type,
1063 };
1064 let i = KdjInput::from_candles(c, p);
1065 kdj_with_kernel(&i, self.kernel)
1066 }
1067
1068 #[inline(always)]
1069 pub fn apply_slices(
1070 self,
1071 high: &[f64],
1072 low: &[f64],
1073 close: &[f64],
1074 ) -> Result<KdjOutput, KdjError> {
1075 let p = KdjParams {
1076 fast_k_period: self.fast_k_period,
1077 slow_k_period: self.slow_k_period,
1078 slow_k_ma_type: self.slow_k_ma_type,
1079 slow_d_period: self.slow_d_period,
1080 slow_d_ma_type: self.slow_d_ma_type,
1081 };
1082 let i = KdjInput::from_slices(high, low, close, p);
1083 kdj_with_kernel(&i, self.kernel)
1084 }
1085
1086 #[inline(always)]
1087 pub fn into_stream(self) -> Result<KdjStream, KdjError> {
1088 let p = KdjParams {
1089 fast_k_period: self.fast_k_period,
1090 slow_k_period: self.slow_k_period,
1091 slow_k_ma_type: self.slow_k_ma_type,
1092 slow_d_period: self.slow_d_period,
1093 slow_d_ma_type: self.slow_d_ma_type,
1094 };
1095 KdjStream::try_new(p)
1096 }
1097}
1098
1099use std::collections::VecDeque;
1100
/// Incremental KDJ state, fed one bar at a time.
///
/// Field order is significant only for the derived `Debug` output.
#[derive(Debug, Clone)]
pub struct KdjStream {
    // Resolved periods.
    fast_k_period: usize,
    slow_k_period: usize,
    slow_d_period: usize,

    // Smoothing-type flags derived from the configured MA names.
    k_is_sma: bool,
    k_is_ema: bool,
    d_is_sma: bool,
    d_is_ema: bool,

    /// Index the next bar will receive.
    i: usize,
    /// Monotonic deque of `(bar index, high)` for the rolling highest high.
    maxdq: VecDeque<(usize, f64)>,
    /// Monotonic deque of `(bar index, low)` for the rolling lowest low.
    mindq: VecDeque<(usize, f64)>,

    /// Set once at least `fast_k_period` bars have been seen.
    have_fast: bool,
    /// Number of raw stochastic values produced so far.
    stoch_samples: usize,
    /// Counter used while the `%D` EMA is warming up.
    k_samples: usize,

    // `%K` SMA ring: recent stochastic values with a NaN-skipping running sum.
    stoch_ring: Vec<f64>,
    stoch_pos: usize,
    sum_k: f64,
    cnt_k: usize,
    stoch_filled: bool,

    // `%D` SMA ring: recent `%K` values with a NaN-skipping running sum.
    k_ring: Vec<f64>,
    k_pos: usize,
    sum_d: f64,
    cnt_d: usize,
    k_filled: bool,

    // `%K` EMA state, seeded from the mean of the first window.
    alpha_k: f64,
    om_alpha_k: f64,
    ema_k: f64,
    k_ema_inited: bool,
    init_sum_k: f64,
    init_cnt_k: usize,

    // `%D` EMA state, seeded from the mean of the first window.
    alpha_d: f64,
    om_alpha_d: f64,
    ema_d: f64,
    d_ema_inited: bool,
    init_sum_d: f64,
    init_cnt_d: usize,

    /// `1 / n` lookup for `n` in `1..=slow_k_period` (index 0 unused).
    inv_cnt_k: Vec<f64>,
    /// `1 / n` lookup for `n` in `1..=slow_d_period` (index 0 unused).
    inv_cnt_d: Vec<f64>,
}
1149
1150impl KdjStream {
1151 pub fn try_new(params: KdjParams) -> Result<Self, KdjError> {
1152 let fast_k_period = params.fast_k_period.unwrap_or(9);
1153 let slow_k_period = params.slow_k_period.unwrap_or(3);
1154 let slow_k_ma_type = params.slow_k_ma_type.unwrap_or_else(|| "sma".to_string());
1155 let slow_d_period = params.slow_d_period.unwrap_or(3);
1156 let slow_d_ma_type = params.slow_d_ma_type.unwrap_or_else(|| "sma".to_string());
1157
1158 if fast_k_period == 0 {
1159 return Err(KdjError::InvalidPeriod {
1160 period: fast_k_period,
1161 data_len: 0,
1162 });
1163 }
1164 if slow_k_period == 0 {
1165 return Err(KdjError::InvalidPeriod {
1166 period: slow_k_period,
1167 data_len: 0,
1168 });
1169 }
1170 if slow_d_period == 0 {
1171 return Err(KdjError::InvalidPeriod {
1172 period: slow_d_period,
1173 data_len: 0,
1174 });
1175 }
1176
1177 let k_is_sma = slow_k_ma_type.eq_ignore_ascii_case("sma");
1178 let k_is_ema = slow_k_ma_type.eq_ignore_ascii_case("ema");
1179 let d_is_sma = slow_d_ma_type.eq_ignore_ascii_case("sma");
1180 let d_is_ema = slow_d_ma_type.eq_ignore_ascii_case("ema");
1181
1182 let alpha_k = 2.0 / (slow_k_period as f64 + 1.0);
1183 let om_alpha_k = 1.0 - alpha_k;
1184 let alpha_d = 2.0 / (slow_d_period as f64 + 1.0);
1185 let om_alpha_d = 1.0 - alpha_d;
1186
1187 fn build_inv(n: usize) -> Vec<f64> {
1188 let mut v = vec![f64::NAN; n + 1];
1189 for c in 1..=n {
1190 v[c] = 1.0 / (c as f64);
1191 }
1192 v
1193 }
1194
1195 Ok(Self {
1196 fast_k_period,
1197 slow_k_period,
1198 slow_d_period,
1199 k_is_sma,
1200 k_is_ema,
1201 d_is_sma,
1202 d_is_ema,
1203
1204 i: 0,
1205 maxdq: VecDeque::with_capacity(fast_k_period + 1),
1206 mindq: VecDeque::with_capacity(fast_k_period + 1),
1207
1208 have_fast: false,
1209 stoch_samples: 0,
1210 k_samples: 0,
1211
1212 stoch_ring: vec![f64::NAN; slow_k_period],
1213 stoch_pos: 0,
1214 sum_k: 0.0,
1215 cnt_k: 0,
1216 stoch_filled: false,
1217
1218 k_ring: vec![f64::NAN; slow_d_period],
1219 k_pos: 0,
1220 sum_d: 0.0,
1221 cnt_d: 0,
1222 k_filled: false,
1223
1224 alpha_k,
1225 om_alpha_k,
1226 ema_k: f64::NAN,
1227 k_ema_inited: false,
1228 init_sum_k: 0.0,
1229 init_cnt_k: 0,
1230
1231 alpha_d,
1232 om_alpha_d,
1233 ema_d: f64::NAN,
1234 d_ema_inited: false,
1235 init_sum_d: 0.0,
1236 init_cnt_d: 0,
1237
1238 inv_cnt_k: build_inv(slow_k_period),
1239 inv_cnt_d: build_inv(slow_d_period),
1240 })
1241 }
1242
1243 #[inline(always)]
1244 pub fn update(&mut self, high: f64, low: f64, close: f64) -> Option<(f64, f64, f64)> {
1245 let idx = self.i;
1246 self.i = idx + 1;
1247
1248 if !high.is_nan() {
1249 while let Some(&(_, v)) = self.maxdq.back() {
1250 if v <= high {
1251 self.maxdq.pop_back();
1252 } else {
1253 break;
1254 }
1255 }
1256 self.maxdq.push_back((idx, high));
1257 }
1258 if !low.is_nan() {
1259 while let Some(&(_, v)) = self.mindq.back() {
1260 if v >= low {
1261 self.mindq.pop_back();
1262 } else {
1263 break;
1264 }
1265 }
1266 self.mindq.push_back((idx, low));
1267 }
1268
1269 let expire_before = idx + 1 - self.fast_k_period;
1270 while let Some(&(j, _)) = self.maxdq.front() {
1271 if j < expire_before {
1272 self.maxdq.pop_front();
1273 } else {
1274 break;
1275 }
1276 }
1277 while let Some(&(j, _)) = self.mindq.front() {
1278 if j < expire_before {
1279 self.mindq.pop_front();
1280 } else {
1281 break;
1282 }
1283 }
1284
1285 if !self.have_fast && (idx + 1) >= self.fast_k_period {
1286 self.have_fast = true;
1287 }
1288 if !self.have_fast {
1289 return None;
1290 }
1291
1292 let stoch = if close.is_nan() || self.maxdq.is_empty() || self.mindq.is_empty() {
1293 f64::NAN
1294 } else {
1295 let hh = self.maxdq.front().unwrap().1;
1296 let ll = self.mindq.front().unwrap().1;
1297 let denom = hh - ll;
1298 if denom == 0.0 || denom.is_nan() {
1299 f64::NAN
1300 } else {
1301 let inv = 1.0 / denom;
1302 (close - ll) * (100.0 * inv)
1303 }
1304 };
1305 self.stoch_samples += 1;
1306
1307 let mut k_val = f64::NAN;
1308 let k_now_available: bool;
1309
1310 if self.k_is_sma || (!self.k_is_ema && !self.k_is_sma) {
1311 let pos = self.stoch_pos;
1312 let old = self.stoch_ring[pos];
1313 if !old.is_nan() {
1314 self.sum_k -= old;
1315 self.cnt_k -= 1;
1316 }
1317 self.stoch_ring[pos] = stoch;
1318 self.stoch_pos = (pos + 1) % self.slow_k_period;
1319 if !stoch.is_nan() {
1320 self.sum_k += stoch;
1321 self.cnt_k += 1;
1322 }
1323 if !self.stoch_filled && self.stoch_pos == 0 {
1324 self.stoch_filled = true;
1325 }
1326
1327 if self.stoch_filled {
1328 k_val = if self.cnt_k > 0 {
1329 self.sum_k * self.inv_cnt_k[self.cnt_k]
1330 } else {
1331 f64::NAN
1332 };
1333 k_now_available = true;
1334 } else {
1335 k_now_available = false;
1336 }
1337 } else {
1338 if !self.k_ema_inited {
1339 if !stoch.is_nan() {
1340 self.init_sum_k += stoch;
1341 self.init_cnt_k += 1;
1342 }
1343 if self.stoch_samples == self.slow_k_period {
1344 self.ema_k = if self.init_cnt_k > 0 {
1345 self.init_sum_k * self.inv_cnt_k[self.init_cnt_k]
1346 } else {
1347 f64::NAN
1348 };
1349 self.k_ema_inited = true;
1350 k_val = self.ema_k;
1351 k_now_available = true;
1352 } else {
1353 k_now_available = false;
1354 }
1355 } else {
1356 if !stoch.is_nan() && !self.ema_k.is_nan() {
1357 self.ema_k = stoch.mul_add(self.alpha_k, self.om_alpha_k * self.ema_k);
1358 } else if !stoch.is_nan() && self.ema_k.is_nan() {
1359 self.ema_k = stoch;
1360 }
1361 k_val = self.ema_k;
1362 k_now_available = true;
1363 }
1364 }
1365
1366 if !k_now_available {
1367 return None;
1368 }
1369
1370 let mut d_val = f64::NAN;
1371 let d_now_available: bool;
1372
1373 if self.d_is_sma || (!self.d_is_ema && !self.d_is_sma) {
1374 let pos = self.k_pos;
1375 let old_k = self.k_ring[pos];
1376 if !old_k.is_nan() {
1377 self.sum_d -= old_k;
1378 self.cnt_d -= 1;
1379 }
1380 self.k_ring[pos] = k_val;
1381 self.k_pos = (pos + 1) % self.slow_d_period;
1382 if !k_val.is_nan() {
1383 self.sum_d += k_val;
1384 self.cnt_d += 1;
1385 }
1386 if !self.k_filled && self.k_pos == 0 {
1387 self.k_filled = true;
1388 }
1389
1390 if self.k_filled {
1391 d_val = if self.cnt_d > 0 {
1392 self.sum_d * self.inv_cnt_d[self.cnt_d]
1393 } else {
1394 f64::NAN
1395 };
1396 d_now_available = true;
1397 } else {
1398 d_now_available = false;
1399 }
1400 } else {
1401 if !self.d_ema_inited {
1402 self.k_samples += 1;
1403 if !k_val.is_nan() {
1404 self.init_sum_d += k_val;
1405 self.init_cnt_d += 1;
1406 }
1407 if self.k_samples == self.slow_d_period {
1408 self.ema_d = if self.init_cnt_d > 0 {
1409 self.init_sum_d * self.inv_cnt_d[self.init_cnt_d]
1410 } else {
1411 f64::NAN
1412 };
1413 self.d_ema_inited = true;
1414 d_val = self.ema_d;
1415 d_now_available = true;
1416 } else {
1417 d_now_available = false;
1418 }
1419 } else {
1420 if !k_val.is_nan() && !self.ema_d.is_nan() {
1421 self.ema_d = k_val.mul_add(self.alpha_d, self.om_alpha_d * self.ema_d);
1422 } else if !k_val.is_nan() && self.ema_d.is_nan() {
1423 self.ema_d = k_val;
1424 }
1425 d_val = self.ema_d;
1426 d_now_available = true;
1427 }
1428 }
1429
1430 if !self.d_is_ema {
1431 self.k_samples = self.k_samples.saturating_add(1);
1432 }
1433
1434 if !d_now_available {
1435 return None;
1436 }
1437
1438 let j_val = if k_val.is_nan() || d_val.is_nan() {
1439 f64::NAN
1440 } else {
1441 k_val.mul_add(3.0, -2.0 * d_val)
1442 };
1443
1444 Some((k_val, d_val, j_val))
1445 }
1446}
1447
/// Parameter sweep describing every KDJ configuration a batch run evaluates.
///
/// Numeric axes are `(start, end, step)` triples. A `step` of `0`, or
/// `start == end`, selects the single value `start`. MA-type axes are
/// `(start, end, _)` triples: the grid uses `start` and, if different, `end`.
/// The third element is not read by the grid expansion.
#[derive(Clone, Debug)]
pub struct KdjBatchRange {
    /// Look-back window of the raw stochastic (highest high / lowest low).
    pub fast_k_period: (usize, usize, usize),
    /// Smoothing window applied to the raw stochastic to form %K.
    pub slow_k_period: (usize, usize, usize),
    /// Moving-average type(s) used for %K smoothing.
    pub slow_k_ma_type: (String, String, String),
    /// Smoothing window applied to %K to form %D.
    pub slow_d_period: (usize, usize, usize),
    /// Moving-average type(s) used for %D smoothing.
    pub slow_d_ma_type: (String, String, String),
}

impl Default for KdjBatchRange {
    /// Sweeps `fast_k_period` over 9..=258 in steps of 1. Both smoothing
    /// stages are fixed at period 3 with the "sma" type.
    fn default() -> Self {
        let fixed_three = (3, 3, 0);
        let sma_only = || (String::from("sma"), String::from("sma"), String::new());
        Self {
            fast_k_period: (9, 258, 1),
            slow_k_period: fixed_three,
            slow_k_ma_type: sma_only(),
            slow_d_period: fixed_three,
            slow_d_ma_type: sma_only(),
        }
    }
}
1468
1469#[derive(Clone, Debug, Default)]
1470pub struct KdjBatchBuilder {
1471 range: KdjBatchRange,
1472 kernel: Kernel,
1473}
1474
1475impl KdjBatchBuilder {
1476 pub fn new() -> Self {
1477 Self::default()
1478 }
1479 pub fn kernel(mut self, k: Kernel) -> Self {
1480 self.kernel = k;
1481 self
1482 }
1483
1484 #[inline]
1485 pub fn fast_k_period_range(mut self, start: usize, end: usize, step: usize) -> Self {
1486 self.range.fast_k_period = (start, end, step);
1487 self
1488 }
1489 #[inline]
1490 pub fn slow_k_period_range(mut self, start: usize, end: usize, step: usize) -> Self {
1491 self.range.slow_k_period = (start, end, step);
1492 self
1493 }
1494 #[inline]
1495 pub fn slow_k_ma_type_static<S: Into<String>>(mut self, s: S) -> Self {
1496 let v = s.into();
1497 self.range.slow_k_ma_type = (v.clone(), v, "".to_string());
1498 self
1499 }
1500 #[inline]
1501 pub fn slow_d_period_range(mut self, start: usize, end: usize, step: usize) -> Self {
1502 self.range.slow_d_period = (start, end, step);
1503 self
1504 }
1505 #[inline]
1506 pub fn slow_d_ma_type_static<S: Into<String>>(mut self, s: S) -> Self {
1507 let v = s.into();
1508 self.range.slow_d_ma_type = (v.clone(), v, "".to_string());
1509 self
1510 }
1511
1512 pub fn apply_slices(
1513 self,
1514 high: &[f64],
1515 low: &[f64],
1516 close: &[f64],
1517 ) -> Result<KdjBatchOutput, KdjError> {
1518 kdj_batch_with_kernel(high, low, close, &self.range, self.kernel)
1519 }
1520
1521 pub fn apply_candles(self, c: &Candles) -> Result<KdjBatchOutput, KdjError> {
1522 let high = source_type(c, "high");
1523 let low = source_type(c, "low");
1524 let close = source_type(c, "close");
1525 self.apply_slices(high, low, close)
1526 }
1527}
1528
1529pub fn kdj_batch_with_kernel(
1530 high: &[f64],
1531 low: &[f64],
1532 close: &[f64],
1533 sweep: &KdjBatchRange,
1534 k: Kernel,
1535) -> Result<KdjBatchOutput, KdjError> {
1536 let kernel = match k {
1537 Kernel::Auto => detect_best_batch_kernel(),
1538 other if other.is_batch() => other,
1539 other => {
1540 return Err(KdjError::InvalidKernelForBatch(other));
1541 }
1542 };
1543
1544 let simd = match kernel {
1545 Kernel::Avx512Batch => Kernel::Avx512,
1546 Kernel::Avx2Batch => Kernel::Avx2,
1547 Kernel::ScalarBatch => Kernel::Scalar,
1548 _ => unreachable!(),
1549 };
1550 kdj_batch_par_slice(high, low, close, sweep, simd)
1551}
1552
1553#[derive(Clone, Debug)]
1554pub struct KdjBatchOutput {
1555 pub k: Vec<f64>,
1556 pub d: Vec<f64>,
1557 pub j: Vec<f64>,
1558 pub combos: Vec<KdjParams>,
1559 pub rows: usize,
1560 pub cols: usize,
1561}
1562impl KdjBatchOutput {
1563 pub fn row_for_params(&self, p: &KdjParams) -> Option<usize> {
1564 self.combos.iter().position(|c| {
1565 c.fast_k_period.unwrap_or(9) == p.fast_k_period.unwrap_or(9)
1566 && c.slow_k_period.unwrap_or(3) == p.slow_k_period.unwrap_or(3)
1567 && c.slow_k_ma_type.as_deref().unwrap_or("sma")
1568 == p.slow_k_ma_type.as_deref().unwrap_or("sma")
1569 && c.slow_d_period.unwrap_or(3) == p.slow_d_period.unwrap_or(3)
1570 && c.slow_d_ma_type.as_deref().unwrap_or("sma")
1571 == p.slow_d_ma_type.as_deref().unwrap_or("sma")
1572 })
1573 }
1574 pub fn k_for(&self, p: &KdjParams) -> Option<&[f64]> {
1575 self.row_for_params(p)
1576 .map(|row| &self.k[row * self.cols..][..self.cols])
1577 }
1578 pub fn d_for(&self, p: &KdjParams) -> Option<&[f64]> {
1579 self.row_for_params(p)
1580 .map(|row| &self.d[row * self.cols..][..self.cols])
1581 }
1582 pub fn j_for(&self, p: &KdjParams) -> Option<&[f64]> {
1583 self.row_for_params(p)
1584 .map(|row| &self.j[row * self.cols..][..self.cols])
1585 }
1586}
1587
1588#[inline(always)]
1589fn expand_grid(r: &KdjBatchRange) -> Result<Vec<KdjParams>, KdjError> {
1590 fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, KdjError> {
1591 if step == 0 || start == end {
1592 return Ok(vec![start]);
1593 }
1594 let mut v = Vec::new();
1595 if start < end {
1596 let mut cur = start;
1597 while cur <= end {
1598 v.push(cur);
1599 let next = cur.saturating_add(step);
1600 if next == cur {
1601 break;
1602 }
1603 cur = next;
1604 }
1605 } else {
1606 let mut cur = start;
1607 while cur >= end {
1608 v.push(cur);
1609 let next = cur.saturating_sub(step);
1610 if next == cur {
1611 break;
1612 }
1613 cur = next;
1614 if cur == 0 && end > 0 {
1615 break;
1616 }
1617 }
1618 }
1619 if v.is_empty() {
1620 return Err(KdjError::InvalidRange { start, end, step });
1621 }
1622 Ok(v)
1623 }
1624 fn axis_str((start, end, _): (String, String, String)) -> Vec<String> {
1625 if start == end {
1626 vec![start]
1627 } else {
1628 vec![start, end]
1629 }
1630 }
1631 let fast_k_periods = axis_usize(r.fast_k_period)?;
1632 let slow_k_periods = axis_usize(r.slow_k_period)?;
1633 let slow_k_ma_types = axis_str(r.slow_k_ma_type.clone());
1634 let slow_d_periods = axis_usize(r.slow_d_period)?;
1635 let slow_d_ma_types = axis_str(r.slow_d_ma_type.clone());
1636 let mut out = Vec::new();
1637 for &fkp in &fast_k_periods {
1638 for &skp in &slow_k_periods {
1639 for skmt in &slow_k_ma_types {
1640 for &sdp in &slow_d_periods {
1641 for sdmt in &slow_d_ma_types {
1642 out.push(KdjParams {
1643 fast_k_period: Some(fkp),
1644 slow_k_period: Some(skp),
1645 slow_k_ma_type: Some(skmt.clone()),
1646 slow_d_period: Some(sdp),
1647 slow_d_ma_type: Some(sdmt.clone()),
1648 });
1649 }
1650 }
1651 }
1652 }
1653 }
1654 Ok(out)
1655}
1656
1657#[inline(always)]
1658pub fn kdj_batch_slice(
1659 high: &[f64],
1660 low: &[f64],
1661 close: &[f64],
1662 sweep: &KdjBatchRange,
1663 kern: Kernel,
1664) -> Result<KdjBatchOutput, KdjError> {
1665 kdj_batch_inner(high, low, close, sweep, kern, false)
1666}
1667
1668#[inline(always)]
1669pub fn kdj_batch_par_slice(
1670 high: &[f64],
1671 low: &[f64],
1672 close: &[f64],
1673 sweep: &KdjBatchRange,
1674 kern: Kernel,
1675) -> Result<KdjBatchOutput, KdjError> {
1676 kdj_batch_inner(high, low, close, sweep, kern, true)
1677}
1678
1679#[inline(always)]
1680fn kdj_batch_inner(
1681 high: &[f64],
1682 low: &[f64],
1683 close: &[f64],
1684 sweep: &KdjBatchRange,
1685 kern: Kernel,
1686 parallel: bool,
1687) -> Result<KdjBatchOutput, KdjError> {
1688 if high.is_empty() || low.is_empty() || close.is_empty() {
1689 return Err(KdjError::EmptyInputData);
1690 }
1691 let cols = high.len();
1692
1693 let combos = expand_grid(sweep)?;
1694 for c in &combos {
1695 let fk = c.fast_k_period.unwrap();
1696 let sk = c.slow_k_period.unwrap();
1697 let sd = c.slow_d_period.unwrap();
1698 if fk == 0 || sk == 0 || sd == 0 {
1699 return Err(KdjError::InvalidPeriod {
1700 period: 0,
1701 data_len: cols,
1702 });
1703 }
1704 }
1705 let first = high
1706 .iter()
1707 .zip(low.iter())
1708 .zip(close.iter())
1709 .position(|((&h, &l), &c)| !h.is_nan() && !l.is_nan() && !c.is_nan())
1710 .ok_or(KdjError::AllValuesNaN)?;
1711
1712 let max_p = combos
1713 .iter()
1714 .map(|c| c.fast_k_period.unwrap())
1715 .max()
1716 .unwrap();
1717 if high.len() - first < max_p {
1718 return Err(KdjError::NotEnoughValidData {
1719 needed: max_p,
1720 valid: high.len() - first,
1721 });
1722 }
1723 let rows = combos.len();
1724 let _ = rows.checked_mul(cols).ok_or(KdjError::InvalidRange {
1725 start: sweep.fast_k_period.0,
1726 end: sweep.fast_k_period.1,
1727 step: sweep.fast_k_period.2,
1728 })?;
1729
1730 let mut k_mu = make_uninit_matrix(rows, cols);
1731 let mut d_mu = make_uninit_matrix(rows, cols);
1732 let mut j_mu = make_uninit_matrix(rows, cols);
1733
1734 let warmup_periods: Vec<usize> = combos
1735 .iter()
1736 .map(|c| {
1737 let fast_k = c.fast_k_period.unwrap();
1738 let slow_k = c.slow_k_period.unwrap();
1739 let slow_d = c.slow_d_period.unwrap();
1740 first
1741 .checked_add(fast_k)
1742 .and_then(|v| v.checked_add(slow_k))
1743 .and_then(|v| v.checked_add(slow_d))
1744 .and_then(|v| v.checked_sub(3))
1745 .ok_or(KdjError::InvalidRange {
1746 start: sweep.fast_k_period.0,
1747 end: sweep.fast_k_period.1,
1748 step: sweep.fast_k_period.2,
1749 })
1750 })
1751 .collect::<Result<Vec<usize>, KdjError>>()?;
1752
1753 init_matrix_prefixes(&mut k_mu, cols, &warmup_periods);
1754 init_matrix_prefixes(&mut d_mu, cols, &warmup_periods);
1755 init_matrix_prefixes(&mut j_mu, cols, &warmup_periods);
1756
1757 let mut k_guard = core::mem::ManuallyDrop::new(k_mu);
1758 let mut d_guard = core::mem::ManuallyDrop::new(d_mu);
1759 let mut j_guard = core::mem::ManuallyDrop::new(j_mu);
1760
1761 let k_vals: &mut [f64] =
1762 unsafe { core::slice::from_raw_parts_mut(k_guard.as_mut_ptr() as *mut f64, k_guard.len()) };
1763 let d_vals: &mut [f64] =
1764 unsafe { core::slice::from_raw_parts_mut(d_guard.as_mut_ptr() as *mut f64, d_guard.len()) };
1765 let j_vals: &mut [f64] =
1766 unsafe { core::slice::from_raw_parts_mut(j_guard.as_mut_ptr() as *mut f64, j_guard.len()) };
1767
1768 let chosen = match kern {
1769 Kernel::Auto => detect_best_batch_kernel(),
1770 k => k,
1771 };
1772
1773 let unique_fast: std::collections::BTreeSet<usize> =
1774 combos.iter().map(|c| c.fast_k_period.unwrap()).collect();
1775
1776 let use_stoch_cache = unique_fast.len() < combos.len();
1777 let mut stoch_cache: std::collections::HashMap<usize, Vec<f64>> =
1778 std::collections::HashMap::new();
1779 if use_stoch_cache {
1780 use std::collections::VecDeque;
1781 for &fk in &unique_fast {
1782 let stoch_warm = first + fk - 1;
1783 let mut stoch = alloc_with_nan_prefix(cols, stoch_warm);
1784 let mut maxdq: VecDeque<usize> = VecDeque::with_capacity(fk + 1);
1785 let mut mindq: VecDeque<usize> = VecDeque::with_capacity(fk + 1);
1786 for i in first..cols {
1787 let hi = unsafe { *high.get_unchecked(i) };
1788 while let Some(&idx) = maxdq.back() {
1789 if unsafe { *high.get_unchecked(idx) } <= hi {
1790 maxdq.pop_back();
1791 } else {
1792 break;
1793 }
1794 }
1795 maxdq.push_back(i);
1796 while let Some(&idx) = maxdq.front() {
1797 if idx + fk <= i {
1798 maxdq.pop_front();
1799 } else {
1800 break;
1801 }
1802 }
1803
1804 let lo = unsafe { *low.get_unchecked(i) };
1805 while let Some(&idx) = mindq.back() {
1806 if unsafe { *low.get_unchecked(idx) } >= lo {
1807 mindq.pop_back();
1808 } else {
1809 break;
1810 }
1811 }
1812 mindq.push_back(i);
1813 while let Some(&idx) = mindq.front() {
1814 if idx + fk <= i {
1815 mindq.pop_front();
1816 } else {
1817 break;
1818 }
1819 }
1820
1821 if i < stoch_warm {
1822 continue;
1823 }
1824
1825 let hh = unsafe { *high.get_unchecked(*maxdq.front().unwrap()) };
1826 let ll = unsafe { *low.get_unchecked(*mindq.front().unwrap()) };
1827 let denom = hh - ll;
1828 let val = if denom == 0.0 || denom.is_nan() {
1829 f64::NAN
1830 } else {
1831 let c = unsafe { *close.get_unchecked(i) };
1832 100.0 * ((c - ll) / denom)
1833 };
1834 unsafe { *stoch.get_unchecked_mut(i) = val };
1835 }
1836 stoch_cache.insert(fk, stoch);
1837 }
1838 }
1839
1840 let do_row = |row: usize,
1841 out_k: &mut [f64],
1842 out_d: &mut [f64],
1843 out_j: &mut [f64]|
1844 -> Result<(), KdjError> {
1845 let prm = &combos[row];
1846 let fast_k = prm.fast_k_period.unwrap();
1847 let slow_k = prm.slow_k_period.unwrap();
1848 let slow_k_ma = prm.slow_k_ma_type.as_deref().unwrap_or("sma");
1849 let slow_d = prm.slow_d_period.unwrap();
1850 let slow_d_ma = prm.slow_d_ma_type.as_deref().unwrap_or("sma");
1851
1852 if use_stoch_cache {
1853 let stoch = stoch_cache
1854 .get(&fast_k)
1855 .expect("stoch cache missing fast_k");
1856 let stoch_warm = first + fast_k - 1;
1857
1858 if slow_k_ma.eq_ignore_ascii_case("sma") && slow_d_ma.eq_ignore_ascii_case("sma") {
1859 return kdj_classic_sma(stoch, slow_k, slow_d, stoch_warm, out_k, out_d, out_j);
1860 }
1861 if slow_k_ma.eq_ignore_ascii_case("ema") && slow_d_ma.eq_ignore_ascii_case("ema") {
1862 return kdj_classic_ema(stoch, slow_k, slow_d, stoch_warm, out_k, out_d, out_j);
1863 }
1864
1865 let k_vec = ma(slow_k_ma, MaData::Slice(stoch), slow_k)
1866 .map_err(|e| KdjError::MaError(e.to_string().into()))?;
1867 let d_vec = ma(slow_d_ma, MaData::Slice(&k_vec), slow_d)
1868 .map_err(|e| KdjError::MaError(e.to_string().into()))?;
1869 out_k.copy_from_slice(&k_vec);
1870 out_d.copy_from_slice(&d_vec);
1871 let j_warm = stoch_warm + slow_k - 1 + slow_d - 1;
1872 for i in 0..j_warm.min(cols) {
1873 out_j[i] = f64::NAN;
1874 }
1875 for i in j_warm..cols {
1876 out_j[i] = if out_k[i].is_nan() || out_d[i].is_nan() {
1877 f64::NAN
1878 } else {
1879 3.0 * out_k[i] - 2.0 * out_d[i]
1880 };
1881 }
1882 return Ok(());
1883 }
1884
1885 match chosen {
1886 Kernel::Scalar | Kernel::ScalarBatch => kdj_row_scalar(
1887 high, low, close, first, fast_k, slow_k, slow_k_ma, slow_d, slow_d_ma, out_k,
1888 out_d, out_j,
1889 ),
1890 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1891 Kernel::Avx2 | Kernel::Avx2Batch => unsafe {
1892 kdj_row_avx2(
1893 high, low, close, first, fast_k, slow_k, slow_k_ma, slow_d, slow_d_ma, out_k,
1894 out_d, out_j,
1895 )
1896 },
1897 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1898 Kernel::Avx512 | Kernel::Avx512Batch => unsafe {
1899 kdj_row_avx512(
1900 high, low, close, first, fast_k, slow_k, slow_k_ma, slow_d, slow_d_ma, out_k,
1901 out_d, out_j,
1902 )
1903 },
1904 _ => kdj_row_scalar(
1905 high, low, close, first, fast_k, slow_k, slow_k_ma, slow_d, slow_d_ma, out_k,
1906 out_d, out_j,
1907 ),
1908 }
1909 };
1910 if parallel {
1911 #[cfg(not(target_arch = "wasm32"))]
1912 {
1913 use rayon::prelude::*;
1914 k_vals
1915 .par_chunks_mut(cols)
1916 .zip(d_vals.par_chunks_mut(cols))
1917 .zip(j_vals.par_chunks_mut(cols))
1918 .enumerate()
1919 .try_for_each(|(row, ((ok, od), oj))| do_row(row, ok, od, oj))?;
1920 }
1921 #[cfg(target_arch = "wasm32")]
1922 {
1923 for (row, ((ok, od), oj)) in k_vals
1924 .chunks_mut(cols)
1925 .zip(d_vals.chunks_mut(cols))
1926 .zip(j_vals.chunks_mut(cols))
1927 .enumerate()
1928 {
1929 do_row(row, ok, od, oj)?;
1930 }
1931 }
1932 } else {
1933 for (row, ((ok, od), oj)) in k_vals
1934 .chunks_mut(cols)
1935 .zip(d_vals.chunks_mut(cols))
1936 .zip(j_vals.chunks_mut(cols))
1937 .enumerate()
1938 {
1939 do_row(row, ok, od, oj)?;
1940 }
1941 }
1942
1943 let k_vec = unsafe {
1944 Vec::from_raw_parts(
1945 k_guard.as_mut_ptr() as *mut f64,
1946 k_guard.len(),
1947 k_guard.capacity(),
1948 )
1949 };
1950 let d_vec = unsafe {
1951 Vec::from_raw_parts(
1952 d_guard.as_mut_ptr() as *mut f64,
1953 d_guard.len(),
1954 d_guard.capacity(),
1955 )
1956 };
1957 let j_vec = unsafe {
1958 Vec::from_raw_parts(
1959 j_guard.as_mut_ptr() as *mut f64,
1960 j_guard.len(),
1961 j_guard.capacity(),
1962 )
1963 };
1964
1965 Ok(KdjBatchOutput {
1966 k: k_vec,
1967 d: d_vec,
1968 j: j_vec,
1969 combos,
1970 rows,
1971 cols,
1972 })
1973}
1974
1975#[inline(always)]
1976fn kdj_row_scalar(
1977 high: &[f64],
1978 low: &[f64],
1979 close: &[f64],
1980 first: usize,
1981 fast_k_period: usize,
1982 slow_k_period: usize,
1983 slow_k_ma_type: &str,
1984 slow_d_period: usize,
1985 slow_d_ma_type: &str,
1986 out_k: &mut [f64],
1987 out_d: &mut [f64],
1988 out_j: &mut [f64],
1989) -> Result<(), KdjError> {
1990 kdj_compute_into_scalar(
1991 high,
1992 low,
1993 close,
1994 first,
1995 fast_k_period,
1996 slow_k_period,
1997 slow_k_ma_type,
1998 slow_d_period,
1999 slow_d_ma_type,
2000 out_k,
2001 out_d,
2002 out_j,
2003 )
2004}
2005
/// AVX2 row kernel for the batch path.
///
/// Currently a placeholder that forwards to `kdj_row_scalar`. No AVX2
/// intrinsics are used.
///
/// # Safety
/// Declared `unsafe` for parity with the other SIMD row kernels. The present
/// body performs no unsafe operations.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn kdj_row_avx2(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    fast_k_period: usize,
    slow_k_period: usize,
    slow_k_ma_type: &str,
    slow_d_period: usize,
    slow_d_ma_type: &str,
    out_k: &mut [f64],
    out_d: &mut [f64],
    out_j: &mut [f64],
) -> Result<(), KdjError> {
    kdj_row_scalar(
        high, low, close, first, fast_k_period, slow_k_period, slow_k_ma_type, slow_d_period,
        slow_d_ma_type, out_k, out_d, out_j,
    )
}
2037
/// AVX-512 row kernel for the batch path.
///
/// Dispatches on the fast window length: windows of at most 32 bars go to
/// `kdj_row_avx512_short`, longer ones to `kdj_row_avx512_long`.
///
/// # Safety
/// Callers must uphold whatever contract the short/long variants require. Both
/// currently forward to the safe scalar kernel.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn kdj_row_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    fast_k_period: usize,
    slow_k_period: usize,
    slow_k_ma_type: &str,
    slow_d_period: usize,
    slow_d_ma_type: &str,
    out_k: &mut [f64],
    out_d: &mut [f64],
    out_j: &mut [f64],
) -> Result<(), KdjError> {
    const SHORT_WINDOW_MAX: usize = 32;
    match fast_k_period {
        0..=SHORT_WINDOW_MAX => kdj_row_avx512_short(
            high, low, close, first, fast_k_period, slow_k_period, slow_k_ma_type, slow_d_period,
            slow_d_ma_type, out_k, out_d, out_j,
        ),
        _ => kdj_row_avx512_long(
            high, low, close, first, fast_k_period, slow_k_period, slow_k_ma_type, slow_d_period,
            slow_d_ma_type, out_k, out_d, out_j,
        ),
    }
}
2086
/// AVX-512 row kernel for short fast windows (at most 32 bars).
///
/// Currently a placeholder that forwards to `kdj_row_scalar`.
///
/// # Safety
/// Declared `unsafe` for parity with the other SIMD row kernels. The present
/// body performs no unsafe operations.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn kdj_row_avx512_short(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    fast_k_period: usize,
    slow_k_period: usize,
    slow_k_ma_type: &str,
    slow_d_period: usize,
    slow_d_ma_type: &str,
    out_k: &mut [f64],
    out_d: &mut [f64],
    out_j: &mut [f64],
) -> Result<(), KdjError> {
    kdj_row_scalar(
        high, low, close, first, fast_k_period, slow_k_period, slow_k_ma_type, slow_d_period,
        slow_d_ma_type, out_k, out_d, out_j,
    )
}
2118
/// AVX-512 row kernel for long fast windows (more than 32 bars).
///
/// Currently a placeholder that forwards to `kdj_row_scalar`.
///
/// # Safety
/// Declared `unsafe` for parity with the other SIMD row kernels. The present
/// body performs no unsafe operations.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn kdj_row_avx512_long(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    fast_k_period: usize,
    slow_k_period: usize,
    slow_k_ma_type: &str,
    slow_d_period: usize,
    slow_d_ma_type: &str,
    out_k: &mut [f64],
    out_d: &mut [f64],
    out_j: &mut [f64],
) -> Result<(), KdjError> {
    kdj_row_scalar(
        high, low, close, first, fast_k_period, slow_k_period, slow_k_ma_type, slow_d_period,
        slow_d_ma_type, out_k, out_d, out_j,
    )
}
2150
2151#[cfg(test)]
2152mod tests {
2153 use super::*;
2154 use crate::skip_if_unsupported;
2155 use crate::utilities::data_loader::read_candles_from_csv;
2156 use crate::utilities::enums::Kernel;
2157
2158 #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
2159 #[test]
2160 fn test_kdj_into_matches_api() -> Result<(), Box<dyn Error>> {
2161 let n = 256usize;
2162 let mut high = Vec::with_capacity(n);
2163 let mut low = Vec::with_capacity(n);
2164 let mut close = Vec::with_capacity(n);
2165
2166 for _ in 0..4 {
2167 high.push(f64::NAN);
2168 low.push(f64::NAN);
2169 close.push(f64::NAN);
2170 }
2171 for i in 0..(n - 4) {
2172 let i_f = i as f64;
2173 let base = 100.0 + 0.1 * i_f + ((i % 7) as f64) * 0.5;
2174 close.push(base);
2175 high.push(base + 1.0 + ((i % 5) as f64) * 0.1);
2176 low.push(base - 1.0 - ((i % 7) as f64) * 0.1);
2177 }
2178
2179 let params = KdjParams::default();
2180 let input = KdjInput::from_slices(&high, &low, &close, params);
2181
2182 let baseline = kdj(&input)?;
2183
2184 let mut k = vec![0.0; close.len()];
2185 let mut d = vec![0.0; close.len()];
2186 let mut j = vec![0.0; close.len()];
2187 kdj_into(&input, &mut k, &mut d, &mut j)?;
2188
2189 assert_eq!(baseline.k.len(), k.len());
2190 assert_eq!(baseline.d.len(), d.len());
2191 assert_eq!(baseline.j.len(), j.len());
2192 for idx in 0..n {
2193 let a = baseline.k[idx];
2194 let b = k[idx];
2195 let ok = (a.is_nan() && b.is_nan()) || (a == b);
2196 assert!(ok, "K mismatch at {idx}: api={a} into={b}");
2197
2198 let a = baseline.d[idx];
2199 let b = d[idx];
2200 let ok = (a.is_nan() && b.is_nan()) || (a == b);
2201 assert!(ok, "D mismatch at {idx}: api={a} into={b}");
2202
2203 let a = baseline.j[idx];
2204 let b = j[idx];
2205 let ok = (a.is_nan() && b.is_nan()) || (a == b);
2206 assert!(ok, "J mismatch at {idx}: api={a} into={b}");
2207 }
2208 Ok(())
2209 }
2210
2211 fn check_kdj_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2212 skip_if_unsupported!(kernel, test_name);
2213 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2214 let candles = read_candles_from_csv(file_path)?;
2215 let partial_params = KdjParams {
2216 fast_k_period: None,
2217 slow_k_period: Some(4),
2218 slow_k_ma_type: None,
2219 slow_d_period: None,
2220 slow_d_ma_type: None,
2221 };
2222 let input = KdjInput::from_candles(&candles, partial_params);
2223 let output = kdj_with_kernel(&input, kernel)?;
2224 assert_eq!(output.k.len(), candles.close.len());
2225 assert_eq!(output.d.len(), candles.close.len());
2226 assert_eq!(output.j.len(), candles.close.len());
2227 Ok(())
2228 }
2229
2230 fn check_kdj_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2231 skip_if_unsupported!(kernel, test_name);
2232 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2233 let candles = read_candles_from_csv(file_path)?;
2234 let params = KdjParams::default();
2235 let input = KdjInput::from_candles(&candles, params);
2236 let result = kdj_with_kernel(&input, kernel)?;
2237 let expected_k = [
2238 58.04341315415984,
2239 61.56034740940419,
2240 58.056304282719545,
2241 56.10961365678364,
2242 51.43992326447119,
2243 ];
2244 let expected_d = [
2245 49.57659409278555,
2246 56.81719223571944,
2247 59.22002161542779,
2248 58.57542178296905,
2249 55.20194706799139,
2250 ];
2251 let expected_j = [
2252 74.97705127690843,
2253 71.04665775677368,
2254 55.72886961730306,
2255 51.17799740441281,
2256 43.91587565743079,
2257 ];
2258 let len = result.k.len();
2259 let start_idx = len - 5;
2260 for i in 0..5 {
2261 let k_val = result.k[start_idx + i];
2262 let d_val = result.d[start_idx + i];
2263 let j_val = result.j[start_idx + i];
2264 assert!(
2265 (k_val - expected_k[i]).abs() < 1e-4,
2266 "Mismatch in K at index {}: expected {}, got {}",
2267 i,
2268 expected_k[i],
2269 k_val
2270 );
2271 assert!(
2272 (d_val - expected_d[i]).abs() < 1e-4,
2273 "Mismatch in D at index {}: expected {}, got {}",
2274 i,
2275 expected_d[i],
2276 d_val
2277 );
2278 assert!(
2279 (j_val - expected_j[i]).abs() < 1e-4,
2280 "Mismatch in J at index {}: expected {}, got {}",
2281 i,
2282 expected_j[i],
2283 j_val
2284 );
2285 }
2286 Ok(())
2287 }
2288
2289 fn check_kdj_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2290 skip_if_unsupported!(kernel, test_name);
2291 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2292 let candles = read_candles_from_csv(file_path)?;
2293 let input = KdjInput::with_default_candles(&candles);
2294 match input.data {
2295 KdjData::Candles { .. } => {}
2296 _ => panic!("Expected KdjData::Candles variant"),
2297 }
2298 let output = kdj_with_kernel(&input, kernel)?;
2299 assert_eq!(output.k.len(), candles.close.len());
2300 Ok(())
2301 }
2302
2303 fn check_kdj_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2304 skip_if_unsupported!(kernel, test_name);
2305 let input_data = [10.0, 20.0, 30.0];
2306 let params = KdjParams {
2307 fast_k_period: Some(0),
2308 ..Default::default()
2309 };
2310 let input = KdjInput::from_slices(&input_data, &input_data, &input_data, params);
2311 let result = kdj_with_kernel(&input, kernel);
2312 assert!(
2313 result.is_err(),
2314 "[{}] KDJ should fail with zero period",
2315 test_name
2316 );
2317 Ok(())
2318 }
2319
2320 fn check_kdj_period_exceeds_length(
2321 test_name: &str,
2322 kernel: Kernel,
2323 ) -> Result<(), Box<dyn Error>> {
2324 skip_if_unsupported!(kernel, test_name);
2325 let input_data = [10.0, 20.0, 30.0];
2326 let params = KdjParams {
2327 fast_k_period: Some(10),
2328 ..Default::default()
2329 };
2330 let input = KdjInput::from_slices(&input_data, &input_data, &input_data, params);
2331 let result = kdj_with_kernel(&input, kernel);
2332 assert!(
2333 result.is_err(),
2334 "[{}] KDJ should fail with period exceeding length",
2335 test_name
2336 );
2337 Ok(())
2338 }
2339
2340 fn check_kdj_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2341 skip_if_unsupported!(kernel, test_name);
2342 let single_point = [42.0];
2343 let params = KdjParams {
2344 fast_k_period: Some(9),
2345 ..Default::default()
2346 };
2347 let input = KdjInput::from_slices(&single_point, &single_point, &single_point, params);
2348 let result = kdj_with_kernel(&input, kernel);
2349 assert!(
2350 result.is_err(),
2351 "[{}] KDJ should fail with insufficient data",
2352 test_name
2353 );
2354 Ok(())
2355 }
2356
2357 fn check_kdj_all_nan(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2358 skip_if_unsupported!(kernel, test_name);
2359 let input_data = [f64::NAN, f64::NAN, f64::NAN];
2360 let params = KdjParams::default();
2361 let input = KdjInput::from_slices(&input_data, &input_data, &input_data, params);
2362 let result = kdj_with_kernel(&input, kernel);
2363 assert!(
2364 result.is_err(),
2365 "[{}] KDJ should fail with all-NaN data",
2366 test_name
2367 );
2368 Ok(())
2369 }
2370
2371 fn check_kdj_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2372 skip_if_unsupported!(kernel, test_name);
2373 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2374 let candles = read_candles_from_csv(file_path)?;
2375 let first_params = KdjParams {
2376 fast_k_period: Some(9),
2377 slow_k_period: Some(3),
2378 slow_k_ma_type: Some("sma".to_string()),
2379 slow_d_period: Some(3),
2380 slow_d_ma_type: Some("sma".to_string()),
2381 };
2382 let first_input = KdjInput::from_candles(&candles, first_params);
2383 let first_result = kdj_with_kernel(&first_input, kernel)?;
2384 assert_eq!(first_result.k.len(), candles.close.len());
2385
2386 let second_params = KdjParams {
2387 fast_k_period: Some(9),
2388 slow_k_period: Some(3),
2389 slow_k_ma_type: Some("sma".to_string()),
2390 slow_d_period: Some(3),
2391 slow_d_ma_type: Some("sma".to_string()),
2392 };
2393 let second_input = KdjInput::from_slices(
2394 &first_result.k,
2395 &first_result.k,
2396 &first_result.k,
2397 second_params,
2398 );
2399 let second_result = kdj_with_kernel(&second_input, kernel)?;
2400 assert_eq!(second_result.k.len(), first_result.k.len());
2401 for i in 50..second_result.k.len() {
2402 assert!(
2403 !second_result.k[i].is_nan(),
2404 "[{}] Expected no NaN in second KDJ at {}",
2405 test_name,
2406 i
2407 );
2408 }
2409 Ok(())
2410 }
2411
2412 fn check_kdj_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2413 skip_if_unsupported!(kernel, test_name);
2414 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2415 let candles = read_candles_from_csv(file_path)?;
2416 let params = KdjParams::default();
2417 let input = KdjInput::from_candles(&candles, params);
2418 let result = kdj_with_kernel(&input, kernel)?;
2419 if result.k.len() > 50 {
2420 for i in 50..result.k.len() {
2421 assert!(
2422 !result.k[i].is_nan(),
2423 "[{}] Expected no NaN in K after index 50 at {}",
2424 test_name,
2425 i
2426 );
2427 }
2428 }
2429 Ok(())
2430 }
2431
2432 #[cfg(debug_assertions)]
2433 fn check_kdj_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2434 skip_if_unsupported!(kernel, test_name);
2435
2436 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2437 let candles = read_candles_from_csv(file_path)?;
2438
2439 let test_params = vec![
2440 KdjParams::default(),
2441 KdjParams {
2442 fast_k_period: Some(2),
2443 slow_k_period: Some(2),
2444 slow_k_ma_type: Some("sma".to_string()),
2445 slow_d_period: Some(2),
2446 slow_d_ma_type: Some("sma".to_string()),
2447 },
2448 KdjParams {
2449 fast_k_period: Some(3),
2450 slow_k_period: Some(3),
2451 slow_k_ma_type: Some("ema".to_string()),
2452 slow_d_period: Some(3),
2453 slow_d_ma_type: Some("ema".to_string()),
2454 },
2455 KdjParams {
2456 fast_k_period: Some(5),
2457 slow_k_period: Some(2),
2458 slow_k_ma_type: Some("sma".to_string()),
2459 slow_d_period: Some(2),
2460 slow_d_ma_type: Some("sma".to_string()),
2461 },
2462 KdjParams {
2463 fast_k_period: Some(10),
2464 slow_k_period: Some(5),
2465 slow_k_ma_type: Some("wma".to_string()),
2466 slow_d_period: Some(5),
2467 slow_d_ma_type: Some("wma".to_string()),
2468 },
2469 KdjParams {
2470 fast_k_period: Some(14),
2471 slow_k_period: Some(3),
2472 slow_k_ma_type: Some("sma".to_string()),
2473 slow_d_period: Some(3),
2474 slow_d_ma_type: Some("sma".to_string()),
2475 },
2476 KdjParams {
2477 fast_k_period: Some(20),
2478 slow_k_period: Some(4),
2479 slow_k_ma_type: Some("ema".to_string()),
2480 slow_d_period: Some(6),
2481 slow_d_ma_type: Some("ema".to_string()),
2482 },
2483 KdjParams {
2484 fast_k_period: Some(30),
2485 slow_k_period: Some(6),
2486 slow_k_ma_type: Some("hma".to_string()),
2487 slow_d_period: Some(8),
2488 slow_d_ma_type: Some("hma".to_string()),
2489 },
2490 KdjParams {
2491 fast_k_period: Some(50),
2492 slow_k_period: Some(10),
2493 slow_k_ma_type: Some("sma".to_string()),
2494 slow_d_period: Some(10),
2495 slow_d_ma_type: Some("sma".to_string()),
2496 },
2497 KdjParams {
2498 fast_k_period: Some(100),
2499 slow_k_period: Some(20),
2500 slow_k_ma_type: Some("ema".to_string()),
2501 slow_d_period: Some(20),
2502 slow_d_ma_type: Some("ema".to_string()),
2503 },
2504 KdjParams {
2505 fast_k_period: Some(200),
2506 slow_k_period: Some(30),
2507 slow_k_ma_type: Some("sma".to_string()),
2508 slow_d_period: Some(30),
2509 slow_d_ma_type: Some("sma".to_string()),
2510 },
2511 ];
2512
2513 for (param_idx, params) in test_params.iter().enumerate() {
2514 let input = KdjInput::from_candles(&candles, params.clone());
2515 let output = kdj_with_kernel(&input, kernel)?;
2516
2517 for (i, &val) in output.k.iter().enumerate() {
2518 if val.is_nan() {
2519 continue;
2520 }
2521
2522 let bits = val.to_bits();
2523
2524 if bits == 0x11111111_11111111 {
2525 panic!(
2526 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} in K \
2527 with params: fast_k_period={}, slow_k_period={}, slow_k_ma_type={}, slow_d_period={}, slow_d_ma_type={} (param set {})",
2528 test_name,
2529 val,
2530 bits,
2531 i,
2532 params.fast_k_period.unwrap_or(9),
2533 params.slow_k_period.unwrap_or(3),
2534 params.slow_k_ma_type.as_deref().unwrap_or("sma"),
2535 params.slow_d_period.unwrap_or(3),
2536 params.slow_d_ma_type.as_deref().unwrap_or("sma"),
2537 param_idx
2538 );
2539 }
2540
2541 if bits == 0x22222222_22222222 {
2542 panic!(
2543 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} in K \
2544 with params: fast_k_period={}, slow_k_period={}, slow_k_ma_type={}, slow_d_period={}, slow_d_ma_type={} (param set {})",
2545 test_name,
2546 val,
2547 bits,
2548 i,
2549 params.fast_k_period.unwrap_or(9),
2550 params.slow_k_period.unwrap_or(3),
2551 params.slow_k_ma_type.as_deref().unwrap_or("sma"),
2552 params.slow_d_period.unwrap_or(3),
2553 params.slow_d_ma_type.as_deref().unwrap_or("sma"),
2554 param_idx
2555 );
2556 }
2557
2558 if bits == 0x33333333_33333333 {
2559 panic!(
2560 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} in K \
2561 with params: fast_k_period={}, slow_k_period={}, slow_k_ma_type={}, slow_d_period={}, slow_d_ma_type={} (param set {})",
2562 test_name,
2563 val,
2564 bits,
2565 i,
2566 params.fast_k_period.unwrap_or(9),
2567 params.slow_k_period.unwrap_or(3),
2568 params.slow_k_ma_type.as_deref().unwrap_or("sma"),
2569 params.slow_d_period.unwrap_or(3),
2570 params.slow_d_ma_type.as_deref().unwrap_or("sma"),
2571 param_idx
2572 );
2573 }
2574 }
2575
2576 for (i, &val) in output.d.iter().enumerate() {
2577 if val.is_nan() {
2578 continue;
2579 }
2580
2581 let bits = val.to_bits();
2582
2583 if bits == 0x11111111_11111111 {
2584 panic!(
2585 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} in D \
2586 with params: fast_k_period={}, slow_k_period={}, slow_k_ma_type={}, slow_d_period={}, slow_d_ma_type={} (param set {})",
2587 test_name,
2588 val,
2589 bits,
2590 i,
2591 params.fast_k_period.unwrap_or(9),
2592 params.slow_k_period.unwrap_or(3),
2593 params.slow_k_ma_type.as_deref().unwrap_or("sma"),
2594 params.slow_d_period.unwrap_or(3),
2595 params.slow_d_ma_type.as_deref().unwrap_or("sma"),
2596 param_idx
2597 );
2598 }
2599
2600 if bits == 0x22222222_22222222 {
2601 panic!(
2602 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} in D \
2603 with params: fast_k_period={}, slow_k_period={}, slow_k_ma_type={}, slow_d_period={}, slow_d_ma_type={} (param set {})",
2604 test_name,
2605 val,
2606 bits,
2607 i,
2608 params.fast_k_period.unwrap_or(9),
2609 params.slow_k_period.unwrap_or(3),
2610 params.slow_k_ma_type.as_deref().unwrap_or("sma"),
2611 params.slow_d_period.unwrap_or(3),
2612 params.slow_d_ma_type.as_deref().unwrap_or("sma"),
2613 param_idx
2614 );
2615 }
2616
2617 if bits == 0x33333333_33333333 {
2618 panic!(
2619 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} in D \
2620 with params: fast_k_period={}, slow_k_period={}, slow_k_ma_type={}, slow_d_period={}, slow_d_ma_type={} (param set {})",
2621 test_name,
2622 val,
2623 bits,
2624 i,
2625 params.fast_k_period.unwrap_or(9),
2626 params.slow_k_period.unwrap_or(3),
2627 params.slow_k_ma_type.as_deref().unwrap_or("sma"),
2628 params.slow_d_period.unwrap_or(3),
2629 params.slow_d_ma_type.as_deref().unwrap_or("sma"),
2630 param_idx
2631 );
2632 }
2633 }
2634
2635 for (i, &val) in output.j.iter().enumerate() {
2636 if val.is_nan() {
2637 continue;
2638 }
2639
2640 let bits = val.to_bits();
2641
2642 if bits == 0x11111111_11111111 {
2643 panic!(
2644 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} in J \
2645 with params: fast_k_period={}, slow_k_period={}, slow_k_ma_type={}, slow_d_period={}, slow_d_ma_type={} (param set {})",
2646 test_name,
2647 val,
2648 bits,
2649 i,
2650 params.fast_k_period.unwrap_or(9),
2651 params.slow_k_period.unwrap_or(3),
2652 params.slow_k_ma_type.as_deref().unwrap_or("sma"),
2653 params.slow_d_period.unwrap_or(3),
2654 params.slow_d_ma_type.as_deref().unwrap_or("sma"),
2655 param_idx
2656 );
2657 }
2658
2659 if bits == 0x22222222_22222222 {
2660 panic!(
2661 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} in J \
2662 with params: fast_k_period={}, slow_k_period={}, slow_k_ma_type={}, slow_d_period={}, slow_d_ma_type={} (param set {})",
2663 test_name,
2664 val,
2665 bits,
2666 i,
2667 params.fast_k_period.unwrap_or(9),
2668 params.slow_k_period.unwrap_or(3),
2669 params.slow_k_ma_type.as_deref().unwrap_or("sma"),
2670 params.slow_d_period.unwrap_or(3),
2671 params.slow_d_ma_type.as_deref().unwrap_or("sma"),
2672 param_idx
2673 );
2674 }
2675
2676 if bits == 0x33333333_33333333 {
2677 panic!(
2678 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} in J \
2679 with params: fast_k_period={}, slow_k_period={}, slow_k_ma_type={}, slow_d_period={}, slow_d_ma_type={} (param set {})",
2680 test_name,
2681 val,
2682 bits,
2683 i,
2684 params.fast_k_period.unwrap_or(9),
2685 params.slow_k_period.unwrap_or(3),
2686 params.slow_k_ma_type.as_deref().unwrap_or("sma"),
2687 params.slow_d_period.unwrap_or(3),
2688 params.slow_d_ma_type.as_deref().unwrap_or("sma"),
2689 param_idx
2690 );
2691 }
2692 }
2693 }
2694
2695 Ok(())
2696 }
2697
2698 #[cfg(not(debug_assertions))]
2699 fn check_kdj_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2700 Ok(())
2701 }
2702
2703 #[cfg(feature = "proptest")]
2704 #[allow(clippy::float_cmp)]
2705 fn check_kdj_property(
2706 test_name: &str,
2707 kernel: Kernel,
2708 ) -> Result<(), Box<dyn std::error::Error>> {
2709 use proptest::prelude::*;
2710 skip_if_unsupported!(kernel, test_name);
2711
2712 let strat = (5usize..=21, 2usize..=5, 2usize..=5).prop_flat_map(
2713 |(fast_k_period, slow_k_period, slow_d_period)| {
2714 (
2715 (
2716 100f64..10000f64,
2717 0.01f64..0.05f64,
2718 fast_k_period + 10..400,
2719 0u8..100u8,
2720 )
2721 .prop_flat_map(move |(base_price, volatility, data_len, scenario_type)| {
2722 (
2723 Just(base_price),
2724 Just(volatility),
2725 Just(data_len),
2726 Just(scenario_type),
2727 prop::collection::vec((-1f64..1f64), data_len),
2728 prop::collection::vec((0.001f64..0.02f64), data_len),
2729 prop::collection::vec(prop::bool::ANY, data_len),
2730 )
2731 })
2732 .prop_map(
2733 move |(
2734 base_price,
2735 volatility,
2736 data_len,
2737 scenario_type,
2738 price_changes,
2739 spread_factors,
2740 zero_spread_flags,
2741 )| {
2742 let mut high = Vec::with_capacity(data_len);
2743 let mut low = Vec::with_capacity(data_len);
2744 let mut close = Vec::with_capacity(data_len);
2745 let mut current_price = base_price;
2746
2747 for i in 0..data_len {
2748 let (h, l, c) = if scenario_type >= 95 && i > fast_k_period {
2749 (current_price, current_price, current_price)
2750 } else if scenario_type >= 85 && scenario_type < 95 {
2751 current_price = (current_price * 0.99).max(10.0);
2752 let spread = current_price * spread_factors[i] * 0.5;
2753 (
2754 current_price + spread * 0.3,
2755 current_price - spread,
2756 current_price - spread * 0.7,
2757 )
2758 } else if scenario_type >= 70 && scenario_type < 85 {
2759 current_price = current_price * 1.01;
2760 let spread = current_price * spread_factors[i] * 0.5;
2761 (
2762 current_price + spread,
2763 current_price - spread * 0.3,
2764 current_price + spread * 0.7,
2765 )
2766 } else {
2767 let change = price_changes[i] * volatility * current_price;
2768 current_price = (current_price + change).max(10.0);
2769
2770 if zero_spread_flags[i] && i % 10 == 0 {
2771 (current_price, current_price, current_price)
2772 } else {
2773 let spread = current_price * spread_factors[i];
2774 let half_spread = spread / 2.0;
2775 let close_position = (price_changes[i] + 1.0) / 2.0;
2776 let c = current_price - half_spread
2777 + spread * close_position;
2778 (
2779 (current_price + half_spread).max(c),
2780 (current_price - half_spread).min(c),
2781 c,
2782 )
2783 }
2784 };
2785
2786 high.push(h);
2787 low.push(l);
2788 close.push(c);
2789 }
2790
2791 (high, low, close)
2792 },
2793 ),
2794 Just(fast_k_period),
2795 Just(slow_k_period),
2796 Just(slow_d_period),
2797 )
2798 },
2799 );
2800
2801 proptest::test_runner::TestRunner::default().run(
2802 &strat,
2803 |((high, low, close), fast_k_period, slow_k_period, slow_d_period)| {
2804 let params = KdjParams {
2805 fast_k_period: Some(fast_k_period),
2806 slow_k_period: Some(slow_k_period),
2807 slow_k_ma_type: Some("sma".to_string()),
2808 slow_d_period: Some(slow_d_period),
2809 slow_d_ma_type: Some("sma".to_string()),
2810 };
2811 let input = KdjInput::from_slices(&high, &low, &close, params.clone());
2812
2813 let KdjOutput { k, d, j } = kdj_with_kernel(&input, kernel).unwrap();
2814 let KdjOutput {
2815 k: ref_k,
2816 d: ref_d,
2817 j: ref_j,
2818 } = kdj_with_kernel(&input, Kernel::Scalar).unwrap();
2819
2820 prop_assert_eq!(k.len(), high.len(), "K length mismatch");
2821 prop_assert_eq!(d.len(), high.len(), "D length mismatch");
2822 prop_assert_eq!(j.len(), high.len(), "J length mismatch");
2823
2824 let first_valid_idx = high
2825 .iter()
2826 .zip(low.iter())
2827 .zip(close.iter())
2828 .position(|((&h, &l), &c)| !h.is_nan() && !l.is_nan() && !c.is_nan())
2829 .unwrap_or(0);
2830 let stoch_warmup = first_valid_idx + fast_k_period - 1;
2831 let k_warmup = stoch_warmup + slow_k_period - 1;
2832 let d_warmup = k_warmup + slow_d_period - 1;
2833
2834 for i in 0..k_warmup.min(k.len()) {
2835 prop_assert!(
2836 k[i].is_nan(),
2837 "K[{}] should be NaN during warmup but was {}",
2838 i,
2839 k[i]
2840 );
2841 }
2842
2843 for i in 0..d_warmup.min(d.len()) {
2844 prop_assert!(
2845 d[i].is_nan(),
2846 "D[{}] should be NaN during warmup but was {}",
2847 i,
2848 d[i]
2849 );
2850 }
2851
2852 for i in 0..d_warmup.min(j.len()) {
2853 prop_assert!(
2854 j[i].is_nan(),
2855 "J[{}] should be NaN during warmup but was {}",
2856 i,
2857 j[i]
2858 );
2859 }
2860
2861 for i in k_warmup..k.len() {
2862 if !k[i].is_nan() {
2863 prop_assert!(
2864 k[i] >= -1e-9 && k[i] <= 100.0 + 1e-9,
2865 "K[{}] = {} is outside [0, 100] range",
2866 i,
2867 k[i]
2868 );
2869 }
2870 }
2871 for i in d_warmup..d.len() {
2872 if !d[i].is_nan() {
2873 prop_assert!(
2874 d[i] >= -1e-9 && d[i] <= 100.0 + 1e-9,
2875 "D[{}] = {} is outside [0, 100] range",
2876 i,
2877 d[i]
2878 );
2879 }
2880 }
2881
2882 for i in d_warmup..j.len() {
2883 if !k[i].is_nan() && !d[i].is_nan() && !j[i].is_nan() {
2884 let expected_j = 3.0 * k[i] - 2.0 * d[i];
2885 prop_assert!(
2886 (j[i] - expected_j).abs() <= 1e-9,
2887 "J[{}] = {} but expected {} (3*K - 2*D = 3*{} - 2*{})",
2888 i,
2889 j[i],
2890 expected_j,
2891 k[i],
2892 d[i]
2893 );
2894 }
2895 }
2896
2897 for i in stoch_warmup..high.len().min(stoch_warmup + fast_k_period * 2) {
2898 if i >= fast_k_period {
2899 let window_start = i + 1 - fast_k_period;
2900 let all_zero_spread =
2901 (window_start..=i).all(|j| (high[j] - low[j]).abs() < 1e-10);
2902
2903 if all_zero_spread && i >= k_warmup {
2904 prop_assert!(
2905 k[i].is_nan(),
2906 "K[{}] should be NaN when high=low in window, but was {}",
2907 i,
2908 k[i]
2909 );
2910 }
2911 }
2912 }
2913
2914 let mut j_outside_bounds_found = false;
2915 for i in d_warmup..j.len() {
2916 if !j[i].is_nan() {
2917 if j[i] < -1e-9 || j[i] > 100.0 + 1e-9 {
2918 j_outside_bounds_found = true;
2919
2920 let expected_j = 3.0 * k[i] - 2.0 * d[i];
2921 prop_assert!(
2922 (j[i] - expected_j).abs() <= 1e-9,
2923 "J[{}] = {} is outside [0,100] but doesn't match formula 3*K - 2*D",
2924 i,
2925 j[i]
2926 );
2927 }
2928 }
2929 }
2930
2931 let mut trend_sum = 0.0;
2932 for i in 1..high.len().min(50) {
2933 trend_sum += close[i] - close[i - 1];
2934 }
2935
2936 if high.len() > d_warmup + 20 {
2937 let avg_change = trend_sum / (high.len().min(50) - 1) as f64;
2938 let first_close = close[0];
2939
2940 if avg_change > first_close * 0.005 {
2941 let last_valid_k = k
2942 .iter()
2943 .rev()
2944 .find(|&&x| !x.is_nan())
2945 .copied()
2946 .unwrap_or(0.0);
2947
2948 if last_valid_k < 30.0 {}
2949 }
2950
2951 if avg_change < -first_close * 0.005 {
2952 let last_valid_k = k
2953 .iter()
2954 .rev()
2955 .find(|&&x| !x.is_nan())
2956 .copied()
2957 .unwrap_or(100.0);
2958
2959 if last_valid_k > 70.0 {}
2960 }
2961 }
2962
2963 for i in 0..k.len() {
2964 let k_bits = k[i].to_bits();
2965 let ref_k_bits = ref_k[i].to_bits();
2966 let d_bits = d[i].to_bits();
2967 let ref_d_bits = ref_d[i].to_bits();
2968 let j_bits = j[i].to_bits();
2969 let ref_j_bits = ref_j[i].to_bits();
2970
2971 if k[i].is_nan() && ref_k[i].is_nan() {
2972 } else if !k[i].is_nan() && !ref_k[i].is_nan() {
2973 let ulp_diff = if k_bits > ref_k_bits {
2974 k_bits - ref_k_bits
2975 } else {
2976 ref_k_bits - k_bits
2977 };
2978 prop_assert!(
2979 ulp_diff <= 5,
2980 "K[{}]: kernel {} gives {} but scalar gives {} (ULP diff: {})",
2981 i,
2982 kernel as u8,
2983 k[i],
2984 ref_k[i],
2985 ulp_diff
2986 );
2987 } else {
2988 prop_assert!(false, "K[{}]: NaN mismatch between kernels", i);
2989 }
2990
2991 if d[i].is_nan() && ref_d[i].is_nan() {
2992 } else if !d[i].is_nan() && !ref_d[i].is_nan() {
2993 let ulp_diff = if d_bits > ref_d_bits {
2994 d_bits - ref_d_bits
2995 } else {
2996 ref_d_bits - d_bits
2997 };
2998 prop_assert!(
2999 ulp_diff <= 5,
3000 "D[{}]: kernel {} gives {} but scalar gives {} (ULP diff: {})",
3001 i,
3002 kernel as u8,
3003 d[i],
3004 ref_d[i],
3005 ulp_diff
3006 );
3007 } else {
3008 prop_assert!(false, "D[{}]: NaN mismatch between kernels", i);
3009 }
3010
3011 if j[i].is_nan() && ref_j[i].is_nan() {
3012 } else if !j[i].is_nan() && !ref_j[i].is_nan() {
3013 let ulp_diff = if j_bits > ref_j_bits {
3014 j_bits - ref_j_bits
3015 } else {
3016 ref_j_bits - j_bits
3017 };
3018 prop_assert!(
3019 ulp_diff <= 10,
3020 "J[{}]: kernel {} gives {} but scalar gives {} (ULP diff: {})",
3021 i,
3022 kernel as u8,
3023 j[i],
3024 ref_j[i],
3025 ulp_diff
3026 );
3027 } else {
3028 prop_assert!(false, "J[{}]: NaN mismatch between kernels", i);
3029 }
3030 }
3031
3032 Ok(())
3033 },
3034 )?;
3035
3036 Ok(())
3037 }
3038
3039 macro_rules! generate_all_kdj_tests {
3040 ($($test_fn:ident),*) => {
3041 paste::paste! {
3042 $(
3043 #[test]
3044 fn [<$test_fn _scalar_f64>]() {
3045 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
3046 }
3047 )*
3048 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3049 $(
3050 #[test]
3051 fn [<$test_fn _avx2_f64>]() {
3052 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
3053 }
3054 #[test]
3055 fn [<$test_fn _avx512_f64>]() {
3056 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
3057 }
3058 )*
3059 #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
3060 $(
3061 #[test]
3062 fn [<$test_fn _simd128_f64>]() {
3063 let _ = $test_fn(stringify!([<$test_fn _simd128_f64>]), Kernel::Scalar);
3064 }
3065 )*
3066 }
3067 }
3068 }
3069
3070 generate_all_kdj_tests!(
3071 check_kdj_partial_params,
3072 check_kdj_accuracy,
3073 check_kdj_default_candles,
3074 check_kdj_zero_period,
3075 check_kdj_period_exceeds_length,
3076 check_kdj_very_small_dataset,
3077 check_kdj_all_nan,
3078 check_kdj_reinput,
3079 check_kdj_nan_handling,
3080 check_kdj_no_poison
3081 );
3082
3083 #[cfg(feature = "proptest")]
3084 generate_all_kdj_tests!(check_kdj_property);
3085
3086 fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
3087 skip_if_unsupported!(kernel, test);
3088
3089 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3090 let c = read_candles_from_csv(file)?;
3091
3092 let output = KdjBatchBuilder::new().kernel(kernel).apply_candles(&c)?;
3093
3094 let def = KdjParams::default();
3095 let row = output.k_for(&def).expect("default row missing");
3096
3097 assert_eq!(row.len(), c.close.len());
3098
3099 for &v in &row[row.len().saturating_sub(5)..] {
3100 assert!(!v.is_nan(), "[{test}] default-row unexpected NaN at tail");
3101 }
3102 Ok(())
3103 }
3104
3105 #[cfg(debug_assertions)]
3106 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
3107 skip_if_unsupported!(kernel, test);
3108
3109 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3110 let c = read_candles_from_csv(file)?;
3111
3112 let test_configs = vec![
3113 (2, 10, 2, 2, 6, 2, 2, 6, 2, "sma", "sma"),
3114 (5, 25, 5, 3, 9, 3, 3, 9, 3, "ema", "ema"),
3115 (30, 60, 15, 5, 15, 5, 5, 15, 5, "sma", "ema"),
3116 (2, 5, 1, 2, 4, 1, 2, 4, 1, "wma", "wma"),
3117 (2, 2, 0, 2, 2, 0, 2, 2, 0, "sma", "sma"),
3118 (9, 15, 3, 3, 6, 3, 3, 6, 3, "sma", "sma"),
3119 (50, 100, 25, 10, 20, 10, 10, 20, 10, "hma", "hma"),
3120 ];
3121
3122 for (
3123 cfg_idx,
3124 &(
3125 fk_start,
3126 fk_end,
3127 fk_step,
3128 sk_start,
3129 sk_end,
3130 sk_step,
3131 sd_start,
3132 sd_end,
3133 sd_step,
3134 sk_ma,
3135 sd_ma,
3136 ),
3137 ) in test_configs.iter().enumerate()
3138 {
3139 let output = KdjBatchBuilder::new()
3140 .kernel(kernel)
3141 .fast_k_period_range(fk_start, fk_end, fk_step)
3142 .slow_k_period_range(sk_start, sk_end, sk_step)
3143 .slow_k_ma_type_static(sk_ma)
3144 .slow_d_period_range(sd_start, sd_end, sd_step)
3145 .slow_d_ma_type_static(sd_ma)
3146 .apply_candles(&c)?;
3147
3148 for (idx, &val) in output.k.iter().enumerate() {
3149 if val.is_nan() {
3150 continue;
3151 }
3152
3153 let bits = val.to_bits();
3154 let row = idx / output.cols;
3155 let col = idx % output.cols;
3156 let combo = &output.combos[row];
3157
3158 if bits == 0x11111111_11111111 {
3159 panic!(
3160 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
3161 at row {} col {} (flat index {}) in K with params: fast_k_period={}, \
3162 slow_k_period={}, slow_k_ma_type={}, slow_d_period={}, slow_d_ma_type={}",
3163 test,
3164 cfg_idx,
3165 val,
3166 bits,
3167 row,
3168 col,
3169 idx,
3170 combo.fast_k_period.unwrap_or(9),
3171 combo.slow_k_period.unwrap_or(3),
3172 combo.slow_k_ma_type.as_deref().unwrap_or("sma"),
3173 combo.slow_d_period.unwrap_or(3),
3174 combo.slow_d_ma_type.as_deref().unwrap_or("sma")
3175 );
3176 }
3177
3178 if bits == 0x22222222_22222222 {
3179 panic!(
3180 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
3181 at row {} col {} (flat index {}) in K with params: fast_k_period={}, \
3182 slow_k_period={}, slow_k_ma_type={}, slow_d_period={}, slow_d_ma_type={}",
3183 test,
3184 cfg_idx,
3185 val,
3186 bits,
3187 row,
3188 col,
3189 idx,
3190 combo.fast_k_period.unwrap_or(9),
3191 combo.slow_k_period.unwrap_or(3),
3192 combo.slow_k_ma_type.as_deref().unwrap_or("sma"),
3193 combo.slow_d_period.unwrap_or(3),
3194 combo.slow_d_ma_type.as_deref().unwrap_or("sma")
3195 );
3196 }
3197
3198 if bits == 0x33333333_33333333 {
3199 panic!(
3200 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
3201 at row {} col {} (flat index {}) in K with params: fast_k_period={}, \
3202 slow_k_period={}, slow_k_ma_type={}, slow_d_period={}, slow_d_ma_type={}",
3203 test,
3204 cfg_idx,
3205 val,
3206 bits,
3207 row,
3208 col,
3209 idx,
3210 combo.fast_k_period.unwrap_or(9),
3211 combo.slow_k_period.unwrap_or(3),
3212 combo.slow_k_ma_type.as_deref().unwrap_or("sma"),
3213 combo.slow_d_period.unwrap_or(3),
3214 combo.slow_d_ma_type.as_deref().unwrap_or("sma")
3215 );
3216 }
3217 }
3218
3219 for (idx, &val) in output.d.iter().enumerate() {
3220 if val.is_nan() {
3221 continue;
3222 }
3223
3224 let bits = val.to_bits();
3225 let row = idx / output.cols;
3226 let col = idx % output.cols;
3227 let combo = &output.combos[row];
3228
3229 if bits == 0x11111111_11111111 {
3230 panic!(
3231 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
3232 at row {} col {} (flat index {}) in D with params: fast_k_period={}, \
3233 slow_k_period={}, slow_k_ma_type={}, slow_d_period={}, slow_d_ma_type={}",
3234 test,
3235 cfg_idx,
3236 val,
3237 bits,
3238 row,
3239 col,
3240 idx,
3241 combo.fast_k_period.unwrap_or(9),
3242 combo.slow_k_period.unwrap_or(3),
3243 combo.slow_k_ma_type.as_deref().unwrap_or("sma"),
3244 combo.slow_d_period.unwrap_or(3),
3245 combo.slow_d_ma_type.as_deref().unwrap_or("sma")
3246 );
3247 }
3248
3249 if bits == 0x22222222_22222222 {
3250 panic!(
3251 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
3252 at row {} col {} (flat index {}) in D with params: fast_k_period={}, \
3253 slow_k_period={}, slow_k_ma_type={}, slow_d_period={}, slow_d_ma_type={}",
3254 test,
3255 cfg_idx,
3256 val,
3257 bits,
3258 row,
3259 col,
3260 idx,
3261 combo.fast_k_period.unwrap_or(9),
3262 combo.slow_k_period.unwrap_or(3),
3263 combo.slow_k_ma_type.as_deref().unwrap_or("sma"),
3264 combo.slow_d_period.unwrap_or(3),
3265 combo.slow_d_ma_type.as_deref().unwrap_or("sma")
3266 );
3267 }
3268
3269 if bits == 0x33333333_33333333 {
3270 panic!(
3271 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
3272 at row {} col {} (flat index {}) in D with params: fast_k_period={}, \
3273 slow_k_period={}, slow_k_ma_type={}, slow_d_period={}, slow_d_ma_type={}",
3274 test,
3275 cfg_idx,
3276 val,
3277 bits,
3278 row,
3279 col,
3280 idx,
3281 combo.fast_k_period.unwrap_or(9),
3282 combo.slow_k_period.unwrap_or(3),
3283 combo.slow_k_ma_type.as_deref().unwrap_or("sma"),
3284 combo.slow_d_period.unwrap_or(3),
3285 combo.slow_d_ma_type.as_deref().unwrap_or("sma")
3286 );
3287 }
3288 }
3289
3290 for (idx, &val) in output.j.iter().enumerate() {
3291 if val.is_nan() {
3292 continue;
3293 }
3294
3295 let bits = val.to_bits();
3296 let row = idx / output.cols;
3297 let col = idx % output.cols;
3298 let combo = &output.combos[row];
3299
3300 if bits == 0x11111111_11111111 {
3301 panic!(
3302 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
3303 at row {} col {} (flat index {}) in J with params: fast_k_period={}, \
3304 slow_k_period={}, slow_k_ma_type={}, slow_d_period={}, slow_d_ma_type={}",
3305 test,
3306 cfg_idx,
3307 val,
3308 bits,
3309 row,
3310 col,
3311 idx,
3312 combo.fast_k_period.unwrap_or(9),
3313 combo.slow_k_period.unwrap_or(3),
3314 combo.slow_k_ma_type.as_deref().unwrap_or("sma"),
3315 combo.slow_d_period.unwrap_or(3),
3316 combo.slow_d_ma_type.as_deref().unwrap_or("sma")
3317 );
3318 }
3319
3320 if bits == 0x22222222_22222222 {
3321 panic!(
3322 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
3323 at row {} col {} (flat index {}) in J with params: fast_k_period={}, \
3324 slow_k_period={}, slow_k_ma_type={}, slow_d_period={}, slow_d_ma_type={}",
3325 test,
3326 cfg_idx,
3327 val,
3328 bits,
3329 row,
3330 col,
3331 idx,
3332 combo.fast_k_period.unwrap_or(9),
3333 combo.slow_k_period.unwrap_or(3),
3334 combo.slow_k_ma_type.as_deref().unwrap_or("sma"),
3335 combo.slow_d_period.unwrap_or(3),
3336 combo.slow_d_ma_type.as_deref().unwrap_or("sma")
3337 );
3338 }
3339
3340 if bits == 0x33333333_33333333 {
3341 panic!(
3342 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
3343 at row {} col {} (flat index {}) in J with params: fast_k_period={}, \
3344 slow_k_period={}, slow_k_ma_type={}, slow_d_period={}, slow_d_ma_type={}",
3345 test,
3346 cfg_idx,
3347 val,
3348 bits,
3349 row,
3350 col,
3351 idx,
3352 combo.fast_k_period.unwrap_or(9),
3353 combo.slow_k_period.unwrap_or(3),
3354 combo.slow_k_ma_type.as_deref().unwrap_or("sma"),
3355 combo.slow_d_period.unwrap_or(3),
3356 combo.slow_d_ma_type.as_deref().unwrap_or("sma")
3357 );
3358 }
3359 }
3360 }
3361
3362 Ok(())
3363 }
3364
3365 #[cfg(not(debug_assertions))]
3366 fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
3367 Ok(())
3368 }
3369
3370 macro_rules! gen_batch_tests {
3371 ($fn_name:ident) => {
3372 paste::paste! {
3373 #[test] fn [<$fn_name _scalar>]() {
3374 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
3375 }
3376 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3377 #[test] fn [<$fn_name _avx2>]() {
3378 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
3379 }
3380 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3381 #[test] fn [<$fn_name _avx512>]() {
3382 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
3383 }
3384 #[test] fn [<$fn_name _auto_detect>]() {
3385 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
3386 }
3387 }
3388 };
3389 }
3390 gen_batch_tests!(check_batch_default_row);
3391 gen_batch_tests!(check_batch_no_poison);
3392}
3393
/// Python entry point `kdj`: KDJ over NumPy `high`, `low` and `close` arrays.
///
/// Returns `(k, d, j)` as three 1-D float64 arrays, each `len(close)` long.
/// `kernel` is an optional kernel name checked by `validate_kernel`; errors
/// from the KDJ computation itself are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "kdj")]
#[pyo3(signature = (high, low, close, fast_k_period=9, slow_k_period=3, slow_k_ma_type="sma", slow_d_period=3, slow_d_ma_type="sma", kernel=None))]
pub fn kdj_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    fast_k_period: usize,
    slow_k_period: usize,
    slow_k_ma_type: &str,
    slow_d_period: usize,
    slow_d_ma_type: &str,
    kernel: Option<&str>,
) -> PyResult<(
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
)> {
    let high_view = high.as_slice()?;
    let low_view = low.as_slice()?;
    let close_view = close.as_slice()?;

    let input = KdjInput::from_slices(
        high_view,
        low_view,
        close_view,
        KdjParams {
            fast_k_period: Some(fast_k_period),
            slow_k_period: Some(slow_k_period),
            slow_k_ma_type: Some(slow_k_ma_type.to_string()),
            slow_d_period: Some(slow_d_period),
            slow_d_ma_type: Some(slow_d_ma_type.to_string()),
        },
    );
    let chosen_kernel = validate_kernel(kernel, false)?;

    // Every output is a single row sized to `close`.
    let n = close_view.len();

    // SAFETY: `PyArray1::new` returns arrays with uninitialised contents; their
    // buffers are handed straight to `kdj_into_slices` below to be filled.
    // NOTE(review): this relies on kdj_into_slices writing every element
    // (warmup prefix included) on success — confirm.
    let (k_out, d_out, j_out) = unsafe {
        (
            PyArray1::<f64>::new(py, [n], false),
            PyArray1::<f64>::new(py, [n], false),
            PyArray1::<f64>::new(py, [n], false),
        )
    };

    // SAFETY: the arrays were created just above and have not been exposed to
    // Python yet, so these mutable views are the only live references.
    let (k_buf, d_buf, j_buf) =
        unsafe { (k_out.as_slice_mut()?, d_out.as_slice_mut()?, j_out.as_slice_mut()?) };

    py.allow_threads(|| kdj_into_slices(k_buf, d_buf, j_buf, &input, chosen_kernel))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok((k_out, d_out, j_out))
}
3442
/// Python class `KdjStream`: incremental KDJ fed one bar at a time, wrapping
/// the Rust `KdjStream`.
#[cfg(feature = "python")]
#[pyclass(name = "KdjStream")]
pub struct KdjStreamPy {
    // Underlying Rust streaming state.
    stream: KdjStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl KdjStreamPy {
    /// Creates a stream for the given KDJ parameters.
    ///
    /// Every argument now has a default (9, 3, "sma", 3, "sma"), matching the
    /// one-shot `kdj` function, so `KdjStream()` works. Positional calls that
    /// pass all five values behave exactly as before. Raises `ValueError` if
    /// `KdjStream::try_new` rejects the parameters.
    #[new]
    #[pyo3(signature = (fast_k_period=9, slow_k_period=3, slow_k_ma_type="sma", slow_d_period=3, slow_d_ma_type="sma"))]
    fn new(
        fast_k_period: usize,
        slow_k_period: usize,
        slow_k_ma_type: &str,
        slow_d_period: usize,
        slow_d_ma_type: &str,
    ) -> PyResult<Self> {
        let params = KdjParams {
            fast_k_period: Some(fast_k_period),
            slow_k_period: Some(slow_k_period),
            slow_k_ma_type: Some(slow_k_ma_type.to_string()),
            slow_d_period: Some(slow_d_period),
            slow_d_ma_type: Some(slow_d_ma_type.to_string()),
        };
        let stream =
            KdjStream::try_new(params).map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok(KdjStreamPy { stream })
    }

    /// Pushes one `(high, low, close)` bar and returns `KdjStream::update`'s
    /// result: `None`, or a `(k, d, j)` tuple.
    fn update(&mut self, high: f64, low: f64, close: f64) -> Option<(f64, f64, f64)> {
        self.stream.update(high, low, close)
    }
}
3476
/// Python entry point `kdj_batch`: KDJ parameter sweep over NumPy arrays.
///
/// Each `*_range` is a `(start, end, step)` period triple, and the two MA-type
/// strings are fixed for the whole sweep. Returns a dict with:
/// * `k`, `d`, `j`: `(rows, len(close))` float64 arrays, one row per combo,
/// * `fast_k_periods`, `slow_k_periods`, `slow_d_periods`: per-row periods,
/// * `combos`: list of per-row parameter dicts.
///
/// Errors from the batch computation are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "kdj_batch")]
#[pyo3(signature = (high, low, close,
                    fast_k_range, slow_k_range, slow_k_ma_type,
                    slow_d_range, slow_d_ma_type, kernel=None))]
pub fn kdj_batch_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    fast_k_range: (usize, usize, usize),
    slow_k_range: (usize, usize, usize),
    slow_k_ma_type: &str,
    slow_d_range: (usize, usize, usize),
    slow_d_ma_type: &str,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let h = high.as_slice()?;
    let l = low.as_slice()?;
    let c = close.as_slice()?;
    let cols = c.len();

    // MA types are expressed as a (start, end, step) string triple. A single
    // type is given as start == end with an empty step (presumably "no sweep";
    // confirm against KdjBatchRange expansion).
    let fixed_ma = |name: &str| (name.to_string(), name.to_string(), String::new());
    let range = KdjBatchRange {
        fast_k_period: fast_k_range,
        slow_k_period: slow_k_range,
        slow_k_ma_type: fixed_ma(slow_k_ma_type),
        slow_d_period: slow_d_range,
        slow_d_ma_type: fixed_ma(slow_d_ma_type),
    };

    // Resolve `Auto` to a concrete batch kernel first, then map every batch
    // kernel to its single-series counterpart. `kdj_batch_inner` then receives
    // the same kind of kernel whether the caller passed `kernel=None` or named
    // one explicitly. Any other kernel is forwarded unchanged.
    let batch_kernel = match validate_kernel(kernel, true)? {
        Kernel::Auto => detect_best_batch_kernel(),
        other => other,
    };
    let row_kernel = match batch_kernel {
        Kernel::Avx512Batch => Kernel::Avx512,
        Kernel::Avx2Batch => Kernel::Avx2,
        Kernel::ScalarBatch => Kernel::Scalar,
        other => other,
    };

    let (k_vec, d_vec, j_vec, combos, rows) = py
        .allow_threads(|| {
            // NOTE(review): the trailing `true` flag is forwarded unchanged;
            // presumably it enables parallel rows — confirm in kdj_batch_inner.
            let out = kdj_batch_inner(h, l, c, &range, row_kernel, true)?;
            Ok::<_, KdjError>((out.k, out.d, out.j, out.combos, out.rows))
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let k_arr = k_vec.into_pyarray(py).reshape((rows, cols))?;
    let d_arr = d_vec.into_pyarray(py).reshape((rows, cols))?;
    let j_arr = j_vec.into_pyarray(py).reshape((rows, cols))?;

    let dict = PyDict::new(py);
    dict.set_item("k", k_arr)?;
    dict.set_item("d", d_arr)?;
    dict.set_item("j", j_arr)?;

    // Combos produced by the batch expansion carry every field as `Some`, so
    // the `unwrap`s below guard an internal invariant, not user input.
    let fast_k_periods: Vec<usize> = combos.iter().map(|p| p.fast_k_period.unwrap()).collect();
    let slow_k_periods: Vec<usize> = combos.iter().map(|p| p.slow_k_period.unwrap()).collect();
    let slow_d_periods: Vec<usize> = combos.iter().map(|p| p.slow_d_period.unwrap()).collect();
    dict.set_item("fast_k_periods", fast_k_periods.into_pyarray(py))?;
    dict.set_item("slow_k_periods", slow_k_periods.into_pyarray(py))?;
    dict.set_item("slow_d_periods", slow_d_periods.into_pyarray(py))?;

    // Build the per-row dicts fallibly so a Python-side error is raised as an
    // exception instead of panicking inside the iterator.
    let combo_dicts = combos
        .iter()
        .map(|p| -> PyResult<Bound<'py, PyDict>> {
            let entry = PyDict::new(py);
            entry.set_item("fast_k_period", p.fast_k_period.unwrap())?;
            entry.set_item("slow_k_period", p.slow_k_period.unwrap())?;
            entry.set_item("slow_k_ma_type", p.slow_k_ma_type.as_deref().unwrap())?;
            entry.set_item("slow_d_period", p.slow_d_period.unwrap())?;
            entry.set_item("slow_d_ma_type", p.slow_d_ma_type.as_deref().unwrap())?;
            Ok(entry)
        })
        .collect::<PyResult<Vec<_>>>()?;
    dict.set_item("combos", PyList::new(py, combo_dicts)?)?;

    Ok(dict)
}
3613
3614#[cfg(all(feature = "python", feature = "cuda"))]
3615use crate::cuda::{cuda_available, CudaKdj};
3616#[cfg(all(feature = "python", feature = "cuda"))]
3617use crate::indicators::moving_averages::alma::{make_device_array_py, DeviceArrayF32Py};
#[cfg(all(feature = "python", feature = "cuda"))]
/// Runs a KDJ parameter sweep on the GPU and returns the K, D and J results as
/// device arrays (not copied back to host memory).
///
/// `fast_k_range`, `slow_k_range` and `slow_d_range` are forwarded unchanged as the
/// period fields of `KdjBatchRange`; the two `*_ma_range` string triples become its
/// MA-type fields. Raises `ValueError` when CUDA is unavailable or when `CudaKdj`
/// reports an error (device creation or kernel execution).
#[pyfunction(name = "kdj_cuda_batch_dev")]
#[pyo3(signature = (high_f32, low_f32, close_f32, fast_k_range, slow_k_range, slow_k_ma_range, slow_d_range, slow_d_ma_range, device_id=0))]
pub fn kdj_cuda_batch_dev_py(
    py: Python<'_>,
    high_f32: PyReadonlyArray1<'_, f32>,
    low_f32: PyReadonlyArray1<'_, f32>,
    close_f32: PyReadonlyArray1<'_, f32>,
    fast_k_range: (usize, usize, usize),
    slow_k_range: (usize, usize, usize),
    slow_k_ma_range: (String, String, String),
    slow_d_range: (usize, usize, usize),
    slow_d_ma_range: (String, String, String),
    device_id: usize,
) -> PyResult<(DeviceArrayF32Py, DeviceArrayF32Py, DeviceArrayF32Py)> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    // Borrow the NumPy buffers directly (errors if an array is not contiguous).
    let h = high_f32.as_slice()?;
    let l = low_f32.as_slice()?;
    let c = close_f32.as_slice()?;
    let sweep = KdjBatchRange {
        fast_k_period: fast_k_range,
        slow_k_period: slow_k_range,
        slow_k_ma_type: slow_k_ma_range,
        slow_d_period: slow_d_range,
        slow_d_ma_type: slow_d_ma_range,
    };
    // Release the GIL while the device context is created and the kernels run.
    let (k_dev, d_dev, j_dev) = py.allow_threads(|| {
        let cuda = CudaKdj::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        cuda.kdj_batch_dev(h, l, c, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;
    let k = make_device_array_py(device_id, k_dev)?;
    let d = make_device_array_py(device_id, d_dev)?;
    let j = make_device_array_py(device_id, j_dev)?;
    Ok((k, d, j))
}
3657
#[cfg(all(feature = "python", feature = "cuda"))]
/// Runs KDJ with a single parameter set over many series on the GPU and returns
/// the K, D and J results as device arrays.
///
/// The three flat input arrays, together with `cols` and `rows`, are handed to
/// `CudaKdj::kdj_many_series_one_param_time_major_dev`.
/// NOTE(review): per the `_tm_` naming the inputs are presumably time-major
/// (`rows` time steps x `cols` series); confirm against `CudaKdj`.
/// Raises `ValueError` when CUDA is unavailable or `CudaKdj` reports an error.
#[pyfunction(name = "kdj_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm_f32, low_tm_f32, close_tm_f32, cols, rows, fast_k, slow_k, slow_k_ma, slow_d, slow_d_ma, device_id=0))]
pub fn kdj_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    high_tm_f32: PyReadonlyArray1<'_, f32>,
    low_tm_f32: PyReadonlyArray1<'_, f32>,
    close_tm_f32: PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    fast_k: usize,
    slow_k: usize,
    slow_k_ma: String,
    slow_d: usize,
    slow_d_ma: String,
    device_id: usize,
) -> PyResult<(DeviceArrayF32Py, DeviceArrayF32Py, DeviceArrayF32Py)> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let htm = high_tm_f32.as_slice()?;
    let ltm = low_tm_f32.as_slice()?;
    let ctm = close_tm_f32.as_slice()?;
    let params = KdjParams {
        fast_k_period: Some(fast_k),
        slow_k_period: Some(slow_k),
        slow_k_ma_type: Some(slow_k_ma),
        slow_d_period: Some(slow_d),
        slow_d_ma_type: Some(slow_d_ma),
    };
    // Release the GIL while the device context is created and the kernels run.
    let (k_dev, d_dev, j_dev) = py.allow_threads(|| {
        let cuda = CudaKdj::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        cuda.kdj_many_series_one_param_time_major_dev(htm, ltm, ctm, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;
    let k = make_device_array_py(device_id, k_dev)?;
    let d = make_device_array_py(device_id, d_dev)?;
    let j = make_device_array_py(device_id, j_dev)?;
    Ok((k, d, j))
}
3698
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
/// Result object returned by the `kdj` wasm binding (`kdj_js`).
///
/// `values` stores the K, D and J series back to back (K first, then D, then J),
/// each `cols` long, so the buffer reads as a row-major `rows x cols` matrix.
#[derive(Serialize, Deserialize)]
pub struct KdjJsOutput {
    /// Flattened `[K..., D..., J...]` values, `rows * cols` long.
    pub values: Vec<f64>,
    /// Number of series stored in `values`; `kdj_js` always sets this to 3.
    pub rows: usize,
    /// Length of each series (the length of the `close` input).
    pub cols: usize,
}
3706
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
/// Computes KDJ for a single parameter set and returns a serialized `KdjJsOutput`.
///
/// The result holds the K, D and J series stacked as 3 rows of `close.len()`
/// columns. Errors from `kdj_into_slices` or from serialization are returned as
/// JS string values.
#[wasm_bindgen(js_name = "kdj")]
pub fn kdj_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    fast_k_period: usize,
    slow_k_period: usize,
    slow_k_ma_type: &str,
    slow_d_period: usize,
    slow_d_ma_type: &str,
) -> Result<JsValue, JsValue> {
    let params = KdjParams {
        fast_k_period: Some(fast_k_period),
        slow_k_period: Some(slow_k_period),
        slow_k_ma_type: Some(slow_k_ma_type.to_string()),
        slow_d_period: Some(slow_d_period),
        slow_d_ma_type: Some(slow_d_ma_type.to_string()),
    };
    let input = KdjInput::from_slices(high, low, close, params);

    let len = close.len();
    // One flat `[K | D | J]` buffer. Each series is written in place, so no
    // temporary per-series vectors are needed and nothing is copied afterwards.
    let mut values = vec![0.0; 3 * len];
    let (k, rest) = values.split_at_mut(len);
    let (d, j) = rest.split_at_mut(len);
    kdj_into_slices(k, d, j, &input, detect_best_kernel())
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let result = KdjJsOutput {
        values,
        rows: 3,
        cols: len,
    };
    serde_wasm_bindgen::to_value(&result).map_err(|e| JsValue::from_str(&e.to_string()))
}
3745
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
/// Allocates an uninitialized buffer with room for `len` f64 values and returns
/// its pointer.
///
/// Ownership passes to the caller, who must release the buffer with
/// `kdj_free(ptr, len)` using the same `len`.
#[wasm_bindgen(js_name = "kdj_alloc")]
pub fn kdj_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop keeps the Vec from freeing the allocation when it goes out of scope.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
3754
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
/// Releases a buffer previously returned by `kdj_alloc(len)`.
///
/// `len` must be the same value passed to `kdj_alloc`. A null `ptr` is ignored.
#[wasm_bindgen(js_name = "kdj_free")]
pub fn kdj_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` was produced by `kdj_alloc(len)`, i.e. by
    // `Vec::<f64>::with_capacity(len)`, so (ptr, capacity = len) matches the
    // original allocation. The length is 0 because the caller may never have
    // initialized the contents; f64 has no destructor, so dropping only frees memory.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
3762
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
/// Computes KDJ from caller-owned buffers and writes K, D and J into caller-owned
/// output buffers.
///
/// Every pointer must reference `len` f64 values (for example, buffers from
/// `kdj_alloc`). Output buffers may overlap the inputs or each other, as when
/// computing in place. In that case the result is computed into scratch memory
/// first and then copied out. Returns an error for null pointers or when
/// `kdj_into_slices` fails.
#[wasm_bindgen(js_name = "kdj_into")]
pub fn kdj_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    k_ptr: *mut f64,
    d_ptr: *mut f64,
    j_ptr: *mut f64,
    len: usize,
    fast_k_period: usize,
    slow_k_period: usize,
    slow_k_ma_type: &str,
    slow_d_period: usize,
    slow_d_ma_type: &str,
) -> Result<(), JsValue> {
    if [
        high_ptr as usize,
        low_ptr as usize,
        close_ptr as usize,
        k_ptr as usize,
        d_ptr as usize,
        j_ptr as usize,
    ]
    .iter()
    .any(|&p| p == 0)
    {
        return Err(JsValue::from_str("null pointer passed to kdj_into"));
    }

    let params = KdjParams {
        fast_k_period: Some(fast_k_period),
        slow_k_period: Some(slow_k_period),
        slow_k_ma_type: Some(slow_k_ma_type.to_string()),
        slow_d_period: Some(slow_d_period),
        slow_d_ma_type: Some(slow_d_ma_type.to_string()),
    };

    // Byte span of one `len`-element f64 buffer. An output that overlaps an input
    // or another output cannot be handed out as `&mut [f64]` alongside the other
    // slices without aliasing UB.
    let span = len.saturating_mul(std::mem::size_of::<f64>());
    let overlaps =
        |a: usize, b: usize| span != 0 && a < b.saturating_add(span) && b < a.saturating_add(span);
    let inputs = [high_ptr as usize, low_ptr as usize, close_ptr as usize];
    let outputs = [k_ptr as usize, d_ptr as usize, j_ptr as usize];
    let aliased = outputs.iter().enumerate().any(|(i, &out)| {
        inputs.iter().any(|&inp| overlaps(out, inp))
            || outputs[i + 1..].iter().any(|&other| overlaps(out, other))
    });

    // SAFETY: every pointer is non-null (checked above), and the caller guarantees
    // each input buffer holds `len` initialized f64 values valid for this call.
    let (h, l, c) = unsafe {
        (
            std::slice::from_raw_parts(high_ptr, len),
            std::slice::from_raw_parts(low_ptr, len),
            std::slice::from_raw_parts(close_ptr, len),
        )
    };
    let input = KdjInput::from_slices(h, l, c, params);

    if !aliased {
        // SAFETY: the output buffers are non-null and writable for `len` f64 values
        // (caller contract). They overlap neither each other nor any input, so these
        // exclusive slices alias nothing.
        let (k, d, j) = unsafe {
            (
                std::slice::from_raw_parts_mut(k_ptr, len),
                std::slice::from_raw_parts_mut(d_ptr, len),
                std::slice::from_raw_parts_mut(j_ptr, len),
            )
        };
        return kdj_into_slices(k, d, j, &input, detect_best_kernel())
            .map_err(|e| JsValue::from_str(&e.to_string()));
    }

    // Overlapping buffers: compute into scratch memory so that no `&mut` slice ever
    // coexists with an overlapping live borrow.
    let mut k = vec![0.0; len];
    let mut d = vec![0.0; len];
    let mut j = vec![0.0; len];
    kdj_into_slices(&mut k, &mut d, &mut j, &input, detect_best_kernel())
        .map_err(|e| JsValue::from_str(&e.to_string()))?;
    drop(input);

    // SAFETY: the input slices are no longer used, and each output slice is created,
    // written and dropped before the next one, so no two live references overlap.
    unsafe {
        std::slice::from_raw_parts_mut(k_ptr, len).copy_from_slice(&k);
        std::slice::from_raw_parts_mut(d_ptr, len).copy_from_slice(&d);
        std::slice::from_raw_parts_mut(j_ptr, len).copy_from_slice(&j);
    }
    Ok(())
}
3812
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
/// Parameter sweep accepted by the `kdj_batch` wasm binding, deserialized from the
/// JS `config` object.
///
/// The period triples are forwarded unchanged to `KdjBatchRange`.
/// NOTE(review): they are presumably `(start, end, step)`; confirm against
/// `KdjBatchRange`. Each MA type is a single name. `kdj_batch_unified_js` expands it
/// to `(name, name, "")`, so JS callers cannot sweep MA types.
#[derive(Serialize, Deserialize)]
pub struct KdjBatchConfig {
    /// Fast %K period triple.
    pub fast_k_period: (usize, usize, usize),
    /// Slow K smoothing period triple.
    pub slow_k_period: (usize, usize, usize),
    /// Moving-average type name used for the slow K smoothing.
    pub slow_k_ma_type: String,
    /// D smoothing period triple.
    pub slow_d_period: (usize, usize, usize),
    /// Moving-average type name used for the D smoothing.
    pub slow_d_ma_type: String,
}
3822
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
/// Result object returned by the `kdj_batch` wasm binding (`kdj_batch_unified_js`).
///
/// `values` is a row-major `rows x cols` matrix. Each parameter combination
/// contributes three consecutive rows: K, then D, then J.
/// NOTE(review): row triple `r` is presumably described by `combos[r]`; confirm
/// that `kdj_batch_inner` keeps rows and combos aligned.
#[derive(Serialize, Deserialize)]
pub struct KdjBatchJsOutput {
    /// Flattened series, `rows * cols` long.
    pub values: Vec<f64>,
    /// Parameter sets returned by the batch computation.
    pub combos: Vec<KdjParams>,
    /// Total row count: three rows (K, D, J) per combination.
    pub rows: usize,
    /// Length of each series.
    pub cols: usize,
}
3831
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
/// Runs a KDJ parameter sweep described by a `KdjBatchConfig` JS object and
/// returns a serialized `KdjBatchJsOutput`.
///
/// Each MA type in the config is fixed, not swept. For every combination the
/// output holds a K row, a D row and a J row, each `cols` long. Config, compute
/// and serialization failures are returned as JS string values.
#[wasm_bindgen(js_name = "kdj_batch")]
pub fn kdj_batch_unified_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let KdjBatchConfig {
        fast_k_period,
        slow_k_period,
        slow_k_ma_type,
        slow_d_period,
        slow_d_ma_type,
    } = serde_wasm_bindgen::from_value::<KdjBatchConfig>(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {e}")))?;

    // A single MA type is expressed to the batch range as `(name, name, "")`.
    let fixed_ma = |name: String| (name.clone(), name, String::new());
    let sweep = KdjBatchRange {
        fast_k_period,
        slow_k_period,
        slow_k_ma_type: fixed_ma(slow_k_ma_type),
        slow_d_period,
        slow_d_ma_type: fixed_ma(slow_d_ma_type),
    };

    let out = kdj_batch_inner(high, low, close, &sweep, detect_best_kernel(), false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    // Interleave per combination: K row, D row, J row.
    let cols = out.cols;
    let mut values = Vec::with_capacity(out.rows * 3 * cols);
    for row in 0..out.rows {
        let span = row * cols..(row + 1) * cols;
        for series in [&out.k, &out.d, &out.j] {
            values.extend_from_slice(&series[span.clone()]);
        }
    }

    serde_wasm_bindgen::to_value(&KdjBatchJsOutput {
        values,
        combos: out.combos,
        rows: out.rows * 3,
        cols,
    })
    .map_err(|e| JsValue::from_str(&format!("Serialization error: {e}")))
}
3876
3877#[inline]
3878fn kdj_classic_sma(
3879 stoch: &[f64],
3880 slow_k: usize,
3881 slow_d: usize,
3882 stoch_warm: usize,
3883 k_out: &mut [f64],
3884 d_out: &mut [f64],
3885 j_out: &mut [f64],
3886) -> Result<(), KdjError> {
3887 let len = stoch.len();
3888
3889 let k_warm = stoch_warm + slow_k - 1;
3890 for i in 0..k_warm.min(len) {
3891 k_out[i] = f64::NAN;
3892 }
3893
3894 let mut sum_k = 0.0;
3895 let mut count_k = 0;
3896
3897 for i in stoch_warm..(stoch_warm + slow_k).min(len) {
3898 if !stoch[i].is_nan() {
3899 sum_k += stoch[i];
3900 count_k += 1;
3901 }
3902 }
3903
3904 if k_warm < len {
3905 k_out[k_warm] = if count_k > 0 {
3906 sum_k / count_k as f64
3907 } else {
3908 f64::NAN
3909 };
3910
3911 for i in (k_warm + 1)..len {
3912 let old_val = stoch[i - slow_k];
3913 let new_val = stoch[i];
3914 if !old_val.is_nan() {
3915 sum_k -= old_val;
3916 count_k -= 1;
3917 }
3918 if !new_val.is_nan() {
3919 sum_k += new_val;
3920 count_k += 1;
3921 }
3922 k_out[i] = if count_k > 0 {
3923 sum_k / count_k as f64
3924 } else {
3925 f64::NAN
3926 };
3927 }
3928 }
3929
3930 let d_warm = k_warm + slow_d - 1;
3931 for i in 0..d_warm.min(len) {
3932 d_out[i] = f64::NAN;
3933 }
3934
3935 let mut sum_d = 0.0;
3936 let mut count_d = 0;
3937
3938 for i in k_warm..(k_warm + slow_d).min(len) {
3939 if !k_out[i].is_nan() {
3940 sum_d += k_out[i];
3941 count_d += 1;
3942 }
3943 }
3944
3945 if d_warm < len {
3946 d_out[d_warm] = if count_d > 0 {
3947 sum_d / count_d as f64
3948 } else {
3949 f64::NAN
3950 };
3951
3952 for i in (d_warm + 1)..len {
3953 let old_val = k_out[i - slow_d];
3954 let new_val = k_out[i];
3955 if !old_val.is_nan() {
3956 sum_d -= old_val;
3957 count_d -= 1;
3958 }
3959 if !new_val.is_nan() {
3960 sum_d += new_val;
3961 count_d += 1;
3962 }
3963 d_out[i] = if count_d > 0 {
3964 sum_d / count_d as f64
3965 } else {
3966 f64::NAN
3967 };
3968 }
3969 }
3970
3971 for i in 0..d_warm.min(len) {
3972 j_out[i] = f64::NAN;
3973 }
3974 for i in d_warm..len {
3975 j_out[i] = if k_out[i].is_nan() || d_out[i].is_nan() {
3976 f64::NAN
3977 } else {
3978 3.0 * k_out[i] - 2.0 * d_out[i]
3979 };
3980 }
3981
3982 Ok(())
3983}
3984
3985#[inline]
3986fn kdj_classic_ema(
3987 stoch: &[f64],
3988 slow_k: usize,
3989 slow_d: usize,
3990 stoch_warm: usize,
3991 k_out: &mut [f64],
3992 d_out: &mut [f64],
3993 j_out: &mut [f64],
3994) -> Result<(), KdjError> {
3995 let len = stoch.len();
3996
3997 let k_warm = stoch_warm + slow_k - 1;
3998 for i in 0..k_warm.min(len) {
3999 k_out[i] = f64::NAN;
4000 }
4001
4002 let alpha_k = 2.0 / (slow_k as f64 + 1.0);
4003 let one_minus_alpha_k = 1.0 - alpha_k;
4004
4005 let mut sum_k = 0.0;
4006 let mut count_k = 0;
4007 for i in stoch_warm..(stoch_warm + slow_k).min(len) {
4008 if !stoch[i].is_nan() {
4009 sum_k += stoch[i];
4010 count_k += 1;
4011 }
4012 }
4013
4014 let mut ema_k = f64::NAN;
4015 if k_warm < len {
4016 if count_k > 0 {
4017 ema_k = sum_k / count_k as f64;
4018 }
4019 k_out[k_warm] = ema_k;
4020
4021 for i in (k_warm + 1)..len {
4022 let st = stoch[i];
4023 if !st.is_nan() {
4024 ema_k = if ema_k.is_nan() {
4025 st
4026 } else {
4027 st.mul_add(alpha_k, one_minus_alpha_k * ema_k)
4028 };
4029 }
4030 k_out[i] = ema_k;
4031 }
4032 }
4033
4034 let d_warm = k_warm + slow_d - 1;
4035 for i in 0..d_warm.min(len) {
4036 d_out[i] = f64::NAN;
4037 }
4038
4039 let alpha_d = 2.0 / (slow_d as f64 + 1.0);
4040 let one_minus_alpha_d = 1.0 - alpha_d;
4041
4042 let mut sum_d = 0.0;
4043 let mut count_d = 0;
4044 for i in k_warm..(k_warm + slow_d).min(len) {
4045 if !k_out[i].is_nan() {
4046 sum_d += k_out[i];
4047 count_d += 1;
4048 }
4049 }
4050
4051 let mut ema_d = f64::NAN;
4052 if d_warm < len {
4053 if count_d > 0 {
4054 ema_d = sum_d / count_d as f64;
4055 }
4056 d_out[d_warm] = ema_d;
4057
4058 for i in (d_warm + 1)..len {
4059 let kv = k_out[i];
4060 if !kv.is_nan() {
4061 ema_d = if ema_d.is_nan() {
4062 kv
4063 } else {
4064 kv.mul_add(alpha_d, one_minus_alpha_d * ema_d)
4065 };
4066 }
4067 d_out[i] = ema_d;
4068 }
4069 }
4070
4071 for i in 0..d_warm.min(len) {
4072 j_out[i] = f64::NAN;
4073 }
4074 for i in d_warm..len {
4075 j_out[i] = if k_out[i].is_nan() || d_out[i].is_nan() {
4076 f64::NAN
4077 } else {
4078 3.0 * k_out[i] - 2.0 * d_out[i]
4079 };
4080 }
4081
4082 Ok(())
4083}