1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::indicators::moving_averages::alma::DeviceArrayF32Py;
3#[cfg(feature = "python")]
4use numpy::{IntoPyArray, PyArray1};
5#[cfg(feature = "python")]
6use pyo3::exceptions::PyValueError;
7#[cfg(feature = "python")]
8use pyo3::prelude::*;
9#[cfg(feature = "python")]
10use pyo3::types::PyDict;
11
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use serde::{Deserialize, Serialize};
14#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
15use wasm_bindgen::prelude::*;
16
17use crate::utilities::data_loader::{source_type, Candles};
18use crate::utilities::enums::Kernel;
19use crate::utilities::helpers::{
20 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
21 make_uninit_matrix,
22};
23#[cfg(feature = "python")]
24use crate::utilities::kernel_validation::validate_kernel;
25use aligned_vec::{AVec, CACHELINE_ALIGN};
26#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
27use core::arch::x86_64::*;
28#[cfg(not(target_arch = "wasm32"))]
29use rayon::prelude::*;
30use std::convert::AsRef;
31use std::error::Error;
32use std::mem::MaybeUninit;
33use thiserror::Error;
34
35#[derive(Debug, Clone)]
36pub enum StochfData<'a> {
37 Candles {
38 candles: &'a Candles,
39 },
40 Slices {
41 high: &'a [f64],
42 low: &'a [f64],
43 close: &'a [f64],
44 },
45}
46
47#[derive(Debug, Clone)]
48pub struct StochfOutput {
49 pub k: Vec<f64>,
50 pub d: Vec<f64>,
51}
52
53#[derive(Debug, Clone)]
54#[cfg_attr(
55 all(target_arch = "wasm32", feature = "wasm"),
56 derive(Serialize, Deserialize)
57)]
58pub struct StochfParams {
59 pub fastk_period: Option<usize>,
60 pub fastd_period: Option<usize>,
61 pub fastd_matype: Option<usize>,
62}
63
64impl Default for StochfParams {
65 fn default() -> Self {
66 Self {
67 fastk_period: Some(5),
68 fastd_period: Some(3),
69 fastd_matype: Some(0),
70 }
71 }
72}
73
74#[derive(Debug, Clone)]
75pub struct StochfInput<'a> {
76 pub data: StochfData<'a>,
77 pub params: StochfParams,
78}
79
80impl<'a> StochfInput<'a> {
81 #[inline]
82 pub fn from_candles(candles: &'a Candles, params: StochfParams) -> Self {
83 Self {
84 data: StochfData::Candles { candles },
85 params,
86 }
87 }
88 #[inline]
89 pub fn from_slices(
90 high: &'a [f64],
91 low: &'a [f64],
92 close: &'a [f64],
93 params: StochfParams,
94 ) -> Self {
95 Self {
96 data: StochfData::Slices { high, low, close },
97 params,
98 }
99 }
100 #[inline]
101 pub fn with_default_candles(candles: &'a Candles) -> Self {
102 Self::from_candles(candles, StochfParams::default())
103 }
104 #[inline]
105 pub fn get_fastk_period(&self) -> usize {
106 self.params.fastk_period.unwrap_or(5)
107 }
108 #[inline]
109 pub fn get_fastd_period(&self) -> usize {
110 self.params.fastd_period.unwrap_or(3)
111 }
112 #[inline]
113 pub fn get_fastd_matype(&self) -> usize {
114 self.params.fastd_matype.unwrap_or(0)
115 }
116}
117
118#[derive(Copy, Clone, Debug)]
119pub struct StochfBuilder {
120 fastk_period: Option<usize>,
121 fastd_period: Option<usize>,
122 fastd_matype: Option<usize>,
123 kernel: Kernel,
124}
125
126impl Default for StochfBuilder {
127 fn default() -> Self {
128 Self {
129 fastk_period: None,
130 fastd_period: None,
131 fastd_matype: None,
132 kernel: Kernel::Auto,
133 }
134 }
135}
136
137impl StochfBuilder {
138 #[inline]
139 pub fn new() -> Self {
140 Self::default()
141 }
142 #[inline]
143 pub fn fastk_period(mut self, n: usize) -> Self {
144 self.fastk_period = Some(n);
145 self
146 }
147 #[inline]
148 pub fn fastd_period(mut self, n: usize) -> Self {
149 self.fastd_period = Some(n);
150 self
151 }
152 #[inline]
153 pub fn fastd_matype(mut self, t: usize) -> Self {
154 self.fastd_matype = Some(t);
155 self
156 }
157 #[inline]
158 pub fn kernel(mut self, k: Kernel) -> Self {
159 self.kernel = k;
160 self
161 }
162 #[inline]
163 pub fn apply(self, candles: &Candles) -> Result<StochfOutput, StochfError> {
164 let p = StochfParams {
165 fastk_period: self.fastk_period,
166 fastd_period: self.fastd_period,
167 fastd_matype: self.fastd_matype,
168 };
169 let i = StochfInput::from_candles(candles, p);
170 stochf_with_kernel(&i, self.kernel)
171 }
172 #[inline]
173 pub fn apply_slices(
174 self,
175 high: &[f64],
176 low: &[f64],
177 close: &[f64],
178 ) -> Result<StochfOutput, StochfError> {
179 let p = StochfParams {
180 fastk_period: self.fastk_period,
181 fastd_period: self.fastd_period,
182 fastd_matype: self.fastd_matype,
183 };
184 let i = StochfInput::from_slices(high, low, close, p);
185 stochf_with_kernel(&i, self.kernel)
186 }
187 #[inline]
188 pub fn into_stream(self) -> Result<StochfStream, StochfError> {
189 let p = StochfParams {
190 fastk_period: self.fastk_period,
191 fastd_period: self.fastd_period,
192 fastd_matype: self.fastd_matype,
193 };
194 StochfStream::try_new(p)
195 }
196}
197
198#[derive(Debug, Error)]
199pub enum StochfError {
200 #[error("stochf: Empty data provided.")]
201 EmptyInputData,
202 #[error("stochf: Invalid period (fastk={fastk}, fastd={fastd}), data length={data_len}.")]
203 InvalidPeriod {
204 fastk: usize,
205 fastd: usize,
206 data_len: usize,
207 },
208 #[error("stochf: All values are NaN.")]
209 AllValuesNaN,
210 #[error(
211 "stochf: Not enough valid data after first valid index (needed={needed}, valid={valid})."
212 )]
213 NotEnoughValidData { needed: usize, valid: usize },
214 #[error("stochf: Invalid output size (expected={expected}, k_got={k_got}, d_got={d_got}).")]
215 OutputLengthMismatch {
216 expected: usize,
217 k_got: usize,
218 d_got: usize,
219 },
220 #[error("stochf: invalid range: start={start}, end={end}, step={step}")]
221 InvalidRange {
222 start: usize,
223 end: usize,
224 step: usize,
225 },
226 #[error("stochf: invalid kernel for batch: {0:?}")]
227 InvalidKernelForBatch(Kernel),
228}
229
230#[inline]
231pub fn stochf(input: &StochfInput) -> Result<StochfOutput, StochfError> {
232 stochf_with_kernel(input, Kernel::Auto)
233}
234
235#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
236#[inline]
237pub fn stochf_into(
238 input: &StochfInput,
239 out_k: &mut [f64],
240 out_d: &mut [f64],
241) -> Result<(), StochfError> {
242 let (high, low, close) = match &input.data {
243 StochfData::Candles { candles } => {
244 let high = candles
245 .select_candle_field("high")
246 .map_err(|_| StochfError::EmptyInputData)?;
247 let low = candles
248 .select_candle_field("low")
249 .map_err(|_| StochfError::EmptyInputData)?;
250 let close = candles
251 .select_candle_field("close")
252 .map_err(|_| StochfError::EmptyInputData)?;
253 (high, low, close)
254 }
255 StochfData::Slices { high, low, close } => (*high, *low, *close),
256 };
257
258 if high.is_empty() || low.is_empty() || close.is_empty() {
259 return Err(StochfError::EmptyInputData);
260 }
261 let len = high.len();
262 if low.len() != len || close.len() != len {
263 return Err(StochfError::EmptyInputData);
264 }
265 if out_k.len() != len || out_d.len() != len {
266 return Err(StochfError::OutputLengthMismatch {
267 expected: len,
268 k_got: out_k.len(),
269 d_got: out_d.len(),
270 });
271 }
272
273 let fastk = input.get_fastk_period();
274 let fastd = input.get_fastd_period();
275 let _matype = input.get_fastd_matype();
276 if fastk == 0 || fastd == 0 || fastk > len || fastd > len {
277 return Err(StochfError::InvalidPeriod {
278 fastk,
279 fastd,
280 data_len: len,
281 });
282 }
283
284 let first = (0..len)
285 .find(|&i| !high[i].is_nan() && !low[i].is_nan() && !close[i].is_nan())
286 .ok_or(StochfError::AllValuesNaN)?;
287 if (len - first) < fastk {
288 return Err(StochfError::NotEnoughValidData {
289 needed: fastk,
290 valid: len - first,
291 });
292 }
293
294 let k_warm = first + fastk - 1;
295 let d_warm = first + fastk + fastd - 2;
296 let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
297 for v in &mut out_k[..k_warm.min(len)] {
298 *v = qnan;
299 }
300 for v in &mut out_d[..d_warm.min(len)] {
301 *v = qnan;
302 }
303
304 stochf_into_slice(out_k, out_d, input, Kernel::Auto)
305}
306
307#[inline]
308pub fn stochf_into_slice(
309 dst_k: &mut [f64],
310 dst_d: &mut [f64],
311 input: &StochfInput,
312 kernel: Kernel,
313) -> Result<(), StochfError> {
314 let (high, low, close) = match &input.data {
315 StochfData::Candles { candles } => {
316 let high = candles
317 .select_candle_field("high")
318 .map_err(|_| StochfError::EmptyInputData)?;
319 let low = candles
320 .select_candle_field("low")
321 .map_err(|_| StochfError::EmptyInputData)?;
322 let close = candles
323 .select_candle_field("close")
324 .map_err(|_| StochfError::EmptyInputData)?;
325 (high, low, close)
326 }
327 StochfData::Slices { high, low, close } => (*high, *low, *close),
328 };
329
330 if high.is_empty() || low.is_empty() || close.is_empty() {
331 return Err(StochfError::EmptyInputData);
332 }
333 let len = high.len();
334 if low.len() != len || close.len() != len {
335 return Err(StochfError::EmptyInputData);
336 }
337 if dst_k.len() != len || dst_d.len() != len {
338 return Err(StochfError::OutputLengthMismatch {
339 expected: len,
340 k_got: dst_k.len(),
341 d_got: dst_d.len(),
342 });
343 }
344
345 let fastk_period = input.get_fastk_period();
346 let fastd_period = input.get_fastd_period();
347 let matype = input.get_fastd_matype();
348
349 if fastk_period == 0 || fastd_period == 0 || fastk_period > len || fastd_period > len {
350 return Err(StochfError::InvalidPeriod {
351 fastk: fastk_period,
352 fastd: fastd_period,
353 data_len: len,
354 });
355 }
356 let first_valid_idx = (0..len)
357 .find(|&i| !high[i].is_nan() && !low[i].is_nan() && !close[i].is_nan())
358 .ok_or(StochfError::AllValuesNaN)?;
359 if (len - first_valid_idx) < fastk_period {
360 return Err(StochfError::NotEnoughValidData {
361 needed: fastk_period,
362 valid: len - first_valid_idx,
363 });
364 }
365
366 let chosen = match kernel {
367 Kernel::Auto => Kernel::Scalar,
368 other => other,
369 };
370
371 unsafe {
372 match chosen {
373 Kernel::Scalar | Kernel::ScalarBatch => stochf_scalar(
374 high,
375 low,
376 close,
377 fastk_period,
378 fastd_period,
379 matype,
380 first_valid_idx,
381 dst_k,
382 dst_d,
383 ),
384 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
385 Kernel::Avx2 | Kernel::Avx2Batch => stochf_avx2(
386 high,
387 low,
388 close,
389 fastk_period,
390 fastd_period,
391 matype,
392 first_valid_idx,
393 dst_k,
394 dst_d,
395 ),
396 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
397 Kernel::Avx512 | Kernel::Avx512Batch => stochf_avx512(
398 high,
399 low,
400 close,
401 fastk_period,
402 fastd_period,
403 matype,
404 first_valid_idx,
405 dst_k,
406 dst_d,
407 ),
408 _ => unreachable!(),
409 }
410 }
411
412 let k_warmup = (first_valid_idx + fastk_period - 1).min(len);
413 let d_warmup = (first_valid_idx + fastk_period + fastd_period - 2).min(len);
414 for v in &mut dst_k[..k_warmup] {
415 *v = f64::NAN;
416 }
417 for v in &mut dst_d[..d_warmup] {
418 *v = f64::NAN;
419 }
420
421 Ok(())
422}
423
424pub fn stochf_with_kernel(
425 input: &StochfInput,
426 kernel: Kernel,
427) -> Result<StochfOutput, StochfError> {
428 let (high, low, close) = match &input.data {
429 StochfData::Candles { candles } => {
430 let high = candles
431 .select_candle_field("high")
432 .map_err(|_| StochfError::EmptyInputData)?;
433 let low = candles
434 .select_candle_field("low")
435 .map_err(|_| StochfError::EmptyInputData)?;
436 let close = candles
437 .select_candle_field("close")
438 .map_err(|_| StochfError::EmptyInputData)?;
439 (high, low, close)
440 }
441 StochfData::Slices { high, low, close } => (*high, *low, *close),
442 };
443
444 if high.is_empty() || low.is_empty() || close.is_empty() {
445 return Err(StochfError::EmptyInputData);
446 }
447 let len = high.len();
448 if low.len() != len || close.len() != len {
449 return Err(StochfError::EmptyInputData);
450 }
451
452 let fastk_period = input.get_fastk_period();
453 let fastd_period = input.get_fastd_period();
454 let matype = input.get_fastd_matype();
455
456 if fastk_period == 0 || fastd_period == 0 || fastk_period > len || fastd_period > len {
457 return Err(StochfError::InvalidPeriod {
458 fastk: fastk_period,
459 fastd: fastd_period,
460 data_len: len,
461 });
462 }
463 let first_valid_idx = (0..len)
464 .find(|&i| !high[i].is_nan() && !low[i].is_nan() && !close[i].is_nan())
465 .ok_or(StochfError::AllValuesNaN)?;
466 if (len - first_valid_idx) < fastk_period {
467 return Err(StochfError::NotEnoughValidData {
468 needed: fastk_period,
469 valid: len - first_valid_idx,
470 });
471 }
472
473 let k_warmup = first_valid_idx + fastk_period - 1;
474 let d_warmup = first_valid_idx + fastk_period + fastd_period - 2;
475 let mut k_vals = alloc_with_nan_prefix(len, k_warmup.min(len));
476 let mut d_vals = alloc_with_nan_prefix(len, d_warmup.min(len));
477
478 let chosen = match kernel {
479 Kernel::Auto => Kernel::Scalar,
480 other => other,
481 };
482
483 unsafe {
484 match chosen {
485 Kernel::Scalar | Kernel::ScalarBatch => stochf_scalar(
486 high,
487 low,
488 close,
489 fastk_period,
490 fastd_period,
491 matype,
492 first_valid_idx,
493 &mut k_vals,
494 &mut d_vals,
495 ),
496 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
497 Kernel::Avx2 | Kernel::Avx2Batch => stochf_avx2(
498 high,
499 low,
500 close,
501 fastk_period,
502 fastd_period,
503 matype,
504 first_valid_idx,
505 &mut k_vals,
506 &mut d_vals,
507 ),
508 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
509 Kernel::Avx512 | Kernel::Avx512Batch => stochf_avx512(
510 high,
511 low,
512 close,
513 fastk_period,
514 fastd_period,
515 matype,
516 first_valid_idx,
517 &mut k_vals,
518 &mut d_vals,
519 ),
520 _ => unreachable!(),
521 }
522 }
523
524 Ok(StochfOutput {
525 k: k_vals,
526 d: d_vals,
527 })
528}
529
/// Scalar reference kernel for the fast stochastic oscillator.
///
/// For every index `i >= first_valid_idx + fastk_period - 1` this writes:
///
/// * `k_vals[i] = 100 * (close[i] - LL) / (HH - LL)`. `HH`/`LL` are the highest high and
///   lowest low over the `fastk_period` bars ending at `i`, with NaN highs/lows skipped.
///   When `HH == LL` the value is `100` if `close[i] == HH`, otherwise `0`.
/// * `d_vals[i]`, when `matype == 0`: the running mean of the last `fastd_period` %K values.
///   It is NaN until `fastd_period` non-NaN %K values have been accumulated, and NaN wherever
///   %K is NaN.
///
/// When `matype != 0` the whole of `d_vals` is filled with NaN. Entries before the first %K
/// index are not written; callers pre-fill them.
///
/// Windows of up to 16 bars are rescanned directly. Longer windows use monotonic index deques
/// (max-deque for highs, min-deque for lows), so each bar is pushed and popped at most once.
///
/// NOTE(review): in the steady state the %D sum subtracts `k_vals[i - fastd_period]`. A NaN
/// %K that leaves the %D window therefore turns every later %D into NaN.
///
/// # Safety
///
/// `low`, `close`, `k_vals` and `d_vals` must be at least `high.len()` long. `fastk_period` and
/// `fastd_period` must be non-zero, and `first_valid_idx + fastk_period - 1` must not overflow.
/// Reads and writes use unchecked pointer/slice access based on these guarantees.
#[inline]
pub unsafe fn stochf_scalar(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    fastk_period: usize,
    fastd_period: usize,
    matype: usize,
    first_valid_idx: usize,
    k_vals: &mut [f64],
    d_vals: &mut [f64],
) {
    debug_assert_eq!(high.len(), low.len());
    debug_assert_eq!(high.len(), close.len());
    debug_assert_eq!(high.len(), k_vals.len());
    debug_assert_eq!(k_vals.len(), d_vals.len());

    let len = high.len();
    if len == 0 {
        return;
    }

    let hp = high.as_ptr();
    let lp = low.as_ptr();
    let cp = close.as_ptr();

    // First index whose %K window is complete.
    let k_start = first_valid_idx + fastk_period - 1;

    if fastk_period <= 16 {
        // Short windows: rescanning the window each bar beats deque bookkeeping.
        let use_sma_d = matype == 0;
        let mut d_sum: f64 = 0.0;
        let mut d_cnt: usize = 0;

        let mut i = k_start;
        while i < len {
            let start = i + 1 - fastk_period;
            let end = i + 1;

            // A NaN high/low fails both comparisons below and is therefore ignored.
            let mut hh = f64::NEG_INFINITY;
            let mut ll = f64::INFINITY;

            let mut j = start;
            // Unrolled by 4; the tail loop handles the remaining `(end - start) % 4` bars.
            let unroll_end = end - ((end - j) & 3);
            while j < unroll_end {
                let h0 = *hp.add(j);
                let l0 = *lp.add(j);
                if h0 > hh {
                    hh = h0;
                }
                if l0 < ll {
                    ll = l0;
                }

                let h1 = *hp.add(j + 1);
                let l1 = *lp.add(j + 1);
                if h1 > hh {
                    hh = h1;
                }
                if l1 < ll {
                    ll = l1;
                }

                let h2 = *hp.add(j + 2);
                let l2 = *lp.add(j + 2);
                if h2 > hh {
                    hh = h2;
                }
                if l2 < ll {
                    ll = l2;
                }

                let h3 = *hp.add(j + 3);
                let l3 = *lp.add(j + 3);
                if h3 > hh {
                    hh = h3;
                }
                if l3 < ll {
                    ll = l3;
                }

                j += 4;
            }
            while j < end {
                let h = *hp.add(j);
                let l = *lp.add(j);
                if h > hh {
                    hh = h;
                }
                if l < ll {
                    ll = l;
                }
                j += 1;
            }

            let c = *cp.add(i);
            let denom = hh - ll;
            let kv = if denom == 0.0 {
                // Flat window: report the close as at the top (100) or bottom (0) of the range.
                if c == hh {
                    100.0
                } else {
                    0.0
                }
            } else {
                let inv = 100.0 / denom;
                c.mul_add(inv, (-ll) * inv)
            };
            *k_vals.get_unchecked_mut(i) = kv;

            if use_sma_d {
                if kv.is_nan() {
                    *d_vals.get_unchecked_mut(i) = f64::NAN;
                } else if d_cnt < fastd_period {
                    // Warm-up: accumulate until the first full %D window.
                    d_sum += kv;
                    d_cnt += 1;
                    if d_cnt == fastd_period {
                        *d_vals.get_unchecked_mut(i) = d_sum / (fastd_period as f64);
                    } else {
                        *d_vals.get_unchecked_mut(i) = f64::NAN;
                    }
                } else {
                    // Steady state: slide the window by one %K value.
                    d_sum += kv - *k_vals.get_unchecked(i - fastd_period);
                    *d_vals.get_unchecked_mut(i) = d_sum / (fastd_period as f64);
                }
            }

            i += 1;
        }

        if matype != 0 {
            d_vals.fill(f64::NAN);
        }
        return;
    }

    // Ring buffers of bar indices for the monotonic max/min deques. `head == tail` means
    // "empty", so a ring with `cap` slots holds at most `cap - 1` entries. A window spans
    // `fastk_period` bars, and with strictly falling highs (or strictly rising lows) every one
    // of them stays queued. One spare slot is therefore required. With `cap == fastk_period`
    // the full ring wrapped to `head == tail`, was read as empty (HH = -inf / LL = +inf), and
    // its live entries were then overwritten.
    let cap = fastk_period + 1;
    let mut qh = vec![0usize; cap];
    let mut ql = vec![0usize; cap];
    let mut qh_head = 0usize;
    let mut qh_tail = 0usize;
    let mut ql_head = 0usize;
    let mut ql_tail = 0usize;

    let use_sma_d = matype == 0;
    let mut d_sum: f64 = 0.0;
    let mut d_cnt: usize = 0;

    let mut i = first_valid_idx;
    while i < len {
        // Expire indices that fall before the window ending at `i`.
        if i + 1 >= fastk_period {
            let win_start = i + 1 - fastk_period;
            while qh_head != qh_tail {
                let idx = *qh.get_unchecked(qh_head);
                if idx >= win_start {
                    break;
                }
                qh_head += 1;
                if qh_head == cap {
                    qh_head = 0;
                }
            }
            while ql_head != ql_tail {
                let idx = *ql.get_unchecked(ql_head);
                if idx >= win_start {
                    break;
                }
                ql_head += 1;
                if ql_head == cap {
                    ql_head = 0;
                }
            }
        }

        // Push a non-NaN high, dropping queued highs it dominates (`h_i == h_i` is a NaN test).
        let h_i = *hp.add(i);
        if h_i == h_i {
            while qh_head != qh_tail {
                let back = if qh_tail == 0 { cap - 1 } else { qh_tail - 1 };
                let back_idx = *qh.get_unchecked(back);
                if *hp.add(back_idx) <= h_i {
                    qh_tail = back;
                } else {
                    break;
                }
            }
            *qh.get_unchecked_mut(qh_tail) = i;
            qh_tail += 1;
            if qh_tail == cap {
                qh_tail = 0;
            }
        }

        // Push a non-NaN low, dropping queued lows it dominates.
        let l_i = *lp.add(i);
        if l_i == l_i {
            while ql_head != ql_tail {
                let back = if ql_tail == 0 { cap - 1 } else { ql_tail - 1 };
                let back_idx = *ql.get_unchecked(back);
                if *lp.add(back_idx) >= l_i {
                    ql_tail = back;
                } else {
                    break;
                }
            }
            *ql.get_unchecked_mut(ql_tail) = i;
            ql_tail += 1;
            if ql_tail == cap {
                ql_tail = 0;
            }
        }

        if i >= k_start {
            // The deque fronts hold the window extremes; an empty deque means an all-NaN window.
            let hh = if qh_head != qh_tail {
                *hp.add(*qh.get_unchecked(qh_head))
            } else {
                f64::NEG_INFINITY
            };
            let ll = if ql_head != ql_tail {
                *lp.add(*ql.get_unchecked(ql_head))
            } else {
                f64::INFINITY
            };
            let c = *cp.add(i);
            let denom = hh - ll;
            let kv = if denom == 0.0 {
                if c == hh {
                    100.0
                } else {
                    0.0
                }
            } else {
                let inv = 100.0 / denom;
                c.mul_add(inv, (-ll) * inv)
            };
            *k_vals.get_unchecked_mut(i) = kv;

            if use_sma_d {
                if kv.is_nan() {
                    *d_vals.get_unchecked_mut(i) = f64::NAN;
                } else if d_cnt < fastd_period {
                    d_sum += kv;
                    d_cnt += 1;
                    if d_cnt == fastd_period {
                        *d_vals.get_unchecked_mut(i) = d_sum / (fastd_period as f64);
                    } else {
                        *d_vals.get_unchecked_mut(i) = f64::NAN;
                    }
                } else {
                    d_sum += kv - *k_vals.get_unchecked(i - fastd_period);
                    *d_vals.get_unchecked_mut(i) = d_sum / (fastd_period as f64);
                }
            }
        }

        i += 1;
    }

    if !use_sma_d {
        d_vals.fill(f64::NAN);
    }
}
788
/// AVX2 variant of the fast stochastic kernel.
///
/// For `fastk_period <= 32`, each window's highest high and lowest low come from 4-wide
/// `max`/`min` accumulators. NaN lanes are replaced by -inf/+inf so they never win. A
/// horizontal reduction and a scalar tail loop cover the remaining `< 4` bars. %K and %D then
/// use the same formulas as `stochf_scalar`. Longer windows delegate to `stochf_scalar`.
///
/// # Safety
///
/// Same slice and period requirements as `stochf_scalar`. This function calls AVX2 intrinsics
/// but carries no `#[target_feature]` attribute, so the caller must ensure the CPU supports
/// AVX2. NOTE(review): without `#[target_feature(enable = "avx2")]` the intrinsics may not be
/// inlined here; confirm whether the crate enables AVX2 globally for this build.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn stochf_avx2(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    fastk_period: usize,
    fastd_period: usize,
    matype: usize,
    first_valid_idx: usize,
    k_vals: &mut [f64],
    d_vals: &mut [f64],
) {
    if fastk_period <= 32 {
        let len = high.len();
        // First index whose %K window is complete.
        let start_i = first_valid_idx + fastk_period - 1;
        let neg_inf = _mm256_set1_pd(f64::NEG_INFINITY);
        let pos_inf = _mm256_set1_pd(f64::INFINITY);

        let use_sma_d = matype == 0;
        let mut d_sum = 0.0f64;
        let mut d_cnt: usize = 0;

        for i in start_i..len {
            let start = i + 1 - fastk_period;
            let end = i + 1;

            // Lane-wise running max of highs / min of lows over the window.
            let mut vmax = neg_inf;
            let mut vmin = pos_inf;
            let mut j = start;
            while j + 4 <= end {
                let vh = _mm256_loadu_pd(high.as_ptr().add(j));
                let vl = _mm256_loadu_pd(low.as_ptr().add(j));

                // Ordered-compare with itself is all-ones for non-NaN lanes; blend NaN lanes
                // to the identity value of the reduction so they are ignored.
                let mask_h = _mm256_cmp_pd(vh, vh, _CMP_ORD_Q);
                let mask_l = _mm256_cmp_pd(vl, vl, _CMP_ORD_Q);
                let vh_nnan = _mm256_blendv_pd(neg_inf, vh, mask_h);
                let vl_nnan = _mm256_blendv_pd(pos_inf, vl, mask_l);

                vmax = _mm256_max_pd(vmax, vh_nnan);
                vmin = _mm256_min_pd(vmin, vl_nnan);
                j += 4;
            }

            // Horizontal reduction: 256 -> 128 bits, then combine the two remaining lanes.
            let vmax_lo = _mm256_castpd256_pd128(vmax);
            let vmax_hi = _mm256_extractf128_pd(vmax, 1);
            let vmax_128 = _mm_max_pd(vmax_lo, vmax_hi);
            let vmax_hi64 = _mm_unpackhi_pd(vmax_128, vmax_128);
            let mut hh = f64::max(_mm_cvtsd_f64(vmax_128), _mm_cvtsd_f64(vmax_hi64));

            let vmin_lo = _mm256_castpd256_pd128(vmin);
            let vmin_hi = _mm256_extractf128_pd(vmin, 1);
            let vmin_128 = _mm_min_pd(vmin_lo, vmin_hi);
            let vmin_hi64 = _mm_unpackhi_pd(vmin_128, vmin_128);
            let mut ll = f64::min(_mm_cvtsd_f64(vmin_128), _mm_cvtsd_f64(vmin_hi64));

            // Scalar tail for the last `< 4` bars of the window, skipping NaNs.
            while j < end {
                let h = *high.get_unchecked(j);
                let l = *low.get_unchecked(j);
                if h == h && h > hh {
                    hh = h;
                }
                if l == l && l < ll {
                    ll = l;
                }
                j += 1;
            }

            let c = *close.get_unchecked(i);
            let denom = hh - ll;
            let kv = if denom == 0.0 {
                // Flat window: 100 if the close sits at the (single) level, otherwise 0.
                if c == hh {
                    100.0
                } else {
                    0.0
                }
            } else {
                let inv = 100.0 / denom;
                c.mul_add(inv, (-ll) * inv)
            };
            *k_vals.get_unchecked_mut(i) = kv;

            // %D as an SMA of %K: accumulate during warm-up, then slide by one value per bar.
            if use_sma_d {
                if kv.is_nan() {
                    *d_vals.get_unchecked_mut(i) = f64::NAN;
                } else if d_cnt < fastd_period {
                    d_sum += kv;
                    d_cnt += 1;
                    if d_cnt == fastd_period {
                        *d_vals.get_unchecked_mut(i) = d_sum / (fastd_period as f64);
                    } else {
                        *d_vals.get_unchecked_mut(i) = f64::NAN;
                    }
                } else {
                    d_sum += kv - *k_vals.get_unchecked(i - fastd_period);
                    *d_vals.get_unchecked_mut(i) = d_sum / (fastd_period as f64);
                }
            }
        }

        // Only matype 0 (SMA) produces %D.
        if !use_sma_d {
            d_vals.fill(f64::NAN);
        }
    } else {
        stochf_scalar(
            high,
            low,
            close,
            fastk_period,
            fastd_period,
            matype,
            first_valid_idx,
            k_vals,
            d_vals,
        );
    }
}
906
/// AVX-512 entry point for the fast stochastic kernel.
///
/// Routes windows of up to 32 bars to [`stochf_avx512_short`] and longer ones to
/// [`stochf_avx512_long`]. Both currently run the scalar reference kernel.
///
/// # Safety
///
/// Same requirements as `stochf_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn stochf_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    fastk_period: usize,
    fastd_period: usize,
    matype: usize,
    first_valid_idx: usize,
    k_vals: &mut [f64],
    d_vals: &mut [f64],
) {
    match fastk_period {
        0..=32 => stochf_avx512_short(
            high,
            low,
            close,
            fastk_period,
            fastd_period,
            matype,
            first_valid_idx,
            k_vals,
            d_vals,
        ),
        _ => stochf_avx512_long(
            high,
            low,
            close,
            fastk_period,
            fastd_period,
            matype,
            first_valid_idx,
            k_vals,
            d_vals,
        ),
    }
}

/// Short-window AVX-512 path. There is no dedicated implementation yet, so it defers to the
/// scalar reference kernel.
///
/// # Safety
///
/// Same requirements as `stochf_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn stochf_avx512_short(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    fastk_period: usize,
    fastd_period: usize,
    matype: usize,
    first_valid_idx: usize,
    k_vals: &mut [f64],
    d_vals: &mut [f64],
) {
    let (hi, lo, cl) = (high, low, close);
    stochf_scalar(
        hi,
        lo,
        cl,
        fastk_period,
        fastd_period,
        matype,
        first_valid_idx,
        k_vals,
        d_vals,
    )
}

/// Long-window AVX-512 path. There is no dedicated implementation yet, so it defers to the
/// scalar reference kernel (which uses monotonic deques for long windows).
///
/// # Safety
///
/// Same requirements as `stochf_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn stochf_avx512_long(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    fastk_period: usize,
    fastd_period: usize,
    matype: usize,
    first_valid_idx: usize,
    k_vals: &mut [f64],
    d_vals: &mut [f64],
) {
    let (hi, lo, cl) = (high, low, close);
    stochf_scalar(
        hi,
        lo,
        cl,
        fastk_period,
        fastd_period,
        matype,
        first_valid_idx,
        k_vals,
        d_vals,
    )
}
998
/// Incremental (bar-by-bar) fast stochastic calculator.
///
/// Keeps a max-deque of recent highs and a min-deque of recent lows over the `fastk_period`
/// window, plus a ring of the last `fastd_period` %K values for the SMA-based %D.
#[derive(Debug, Clone)]
pub struct StochfStream {
    fastk_period: usize,
    fastd_period: usize,
    // Only 0 produces %D; any other value makes `update` return NaN for %D.
    fastd_matype: usize,

    // Max-deque of highs: bar indices and their values, stored in a ring of `cap_k` slots.
    qh_idx: Vec<usize>,
    qh_val: Vec<f64>,
    qh_head: usize,
    qh_tail: usize,

    // Min-deque of lows, same layout as the max-deque.
    ql_idx: Vec<usize>,
    ql_val: Vec<f64>,
    ql_head: usize,
    ql_tail: usize,

    // Ring capacity shared by both deques (`fastk_period + 1`).
    cap_k: usize,

    // Disambiguate `head == tail`: `true` means the ring is full rather than empty.
    qh_full: bool,
    ql_full: bool,

    // Index assigned to the next bar passed to `update`.
    t: usize,

    // Ring of the most recent %K values (capacity `fastd_period`) and their running sum.
    k_ring: Vec<f64>,
    k_head: usize,
    k_count: usize,
    d_sma_sum: f64,
}

impl StochfStream {
    /// Creates a stream. `None` parameters default to fastk = 5, fastd = 3, matype = 0.
    ///
    /// # Errors
    ///
    /// [`StochfError::InvalidPeriod`] (with `data_len: 0`) when either period is zero.
    pub fn try_new(params: StochfParams) -> Result<Self, StochfError> {
        let fastk_period = params.fastk_period.unwrap_or(5);
        let fastd_period = params.fastd_period.unwrap_or(3);
        let fastd_matype = params.fastd_matype.unwrap_or(0);

        if fastk_period == 0 || fastd_period == 0 {
            return Err(StochfError::InvalidPeriod {
                fastk: fastk_period,
                fastd: fastd_period,
                data_len: 0,
            });
        }

        // `update` pushes the new bar before expiring old ones, so a deque can briefly hold
        // `fastk_period + 1` entries.
        let cap_k = fastk_period + 1;

        Ok(Self {
            fastk_period,
            fastd_period,
            fastd_matype,

            qh_idx: vec![0; cap_k],
            qh_val: vec![0.0; cap_k],
            qh_head: 0,
            qh_tail: 0,
            qh_full: false,

            ql_idx: vec![0; cap_k],
            ql_val: vec![0.0; cap_k],
            ql_head: 0,
            ql_tail: 0,
            ql_full: false,

            cap_k,

            t: 0,

            k_ring: vec![0.0; fastd_period],
            k_head: 0,
            k_count: 0,
            d_sma_sum: 0.0,
        })
    }

    /// Advances a ring index by one, wrapping at `cap`.
    #[inline(always)]
    fn inc(idx: &mut usize, cap: usize) {
        *idx += 1;
        if *idx == cap {
            *idx = 0;
        }
    }

    /// Moves a ring index back by one, wrapping at `cap`.
    #[inline(always)]
    fn dec(idx: &mut usize, cap: usize) {
        if *idx == 0 {
            *idx = cap - 1;
        } else {
            *idx -= 1;
        }
    }

    /// Drops max-deque entries whose bar index is before `win_start`.
    #[inline(always)]
    fn qh_expire(&mut self, win_start: usize) {
        while (self.qh_head != self.qh_tail || self.qh_full)
            && self.qh_idx[self.qh_head] < win_start
        {
            Self::inc(&mut self.qh_head, self.cap_k);

            self.qh_full = false;
        }
    }
    /// Drops min-deque entries whose bar index is before `win_start`.
    #[inline(always)]
    fn ql_expire(&mut self, win_start: usize) {
        while (self.ql_head != self.ql_tail || self.ql_full)
            && self.ql_idx[self.ql_head] < win_start
        {
            Self::inc(&mut self.ql_head, self.cap_k);
            self.ql_full = false;
        }
    }

    /// Appends `(idx, val)` to the max-deque after popping trailing entries with value `<= val`,
    /// so values stay strictly decreasing from head to tail.
    #[inline(always)]
    fn qh_push(&mut self, idx: usize, val: f64) {
        while self.qh_head != self.qh_tail || self.qh_full {
            let mut back = self.qh_tail;
            Self::dec(&mut back, self.cap_k);
            if self.qh_val[back] <= val {
                self.qh_tail = back;

                self.qh_full = false;
            } else {
                break;
            }
        }
        self.qh_idx[self.qh_tail] = idx;
        self.qh_val[self.qh_tail] = val;
        Self::inc(&mut self.qh_tail, self.cap_k);

        // The tail caught up with the head: the ring is now full, not empty.
        if self.qh_tail == self.qh_head {
            self.qh_full = true;
        }
    }

    /// Appends `(idx, val)` to the min-deque after popping trailing entries with value `>= val`,
    /// so values stay strictly increasing from head to tail.
    #[inline(always)]
    fn ql_push(&mut self, idx: usize, val: f64) {
        while self.ql_head != self.ql_tail || self.ql_full {
            let mut back = self.ql_tail;
            Self::dec(&mut back, self.cap_k);
            if self.ql_val[back] >= val {
                self.ql_tail = back;
                self.ql_full = false;
            } else {
                break;
            }
        }
        self.ql_idx[self.ql_tail] = idx;
        self.ql_val[self.ql_tail] = val;
        Self::inc(&mut self.ql_tail, self.cap_k);
        if self.ql_tail == self.ql_head {
            self.ql_full = true;
        }
    }

    /// Feeds one bar. Returns `None` until `fastk_period` bars have been seen, then
    /// `Some((k, d))`.
    ///
    /// NaN highs/lows are not queued, so they are ignored when finding the window extremes. If
    /// the whole window is NaN the extremes stay at -inf/+inf. `d` is NaN until `fastd_period`
    /// %K values have been produced, and always NaN when `fastd_matype != 0`.
    /// NOTE(review): the batch kernels special-case NaN %K values in their %D sum; this stream
    /// does not, so a single NaN %K makes every later `d` NaN. Confirm that is intended.
    pub fn update(&mut self, high: f64, low: f64, close: f64) -> Option<(f64, f64)> {
        let i = self.t;

        self.t = self.t.wrapping_add(1);

        // `x == x` is false only for NaN.
        if high == high {
            self.qh_push(i, high);
        }
        if low == low {
            self.ql_push(i, low);
        }

        let have_k_window = (i + 1) >= self.fastk_period;
        if have_k_window {
            let win_start = i + 1 - self.fastk_period;
            self.qh_expire(win_start);
            self.ql_expire(win_start);
        } else {
            return None;
        }

        // Deque fronts are the window's highest high / lowest low.
        let hh = if self.qh_head != self.qh_tail || self.qh_full {
            self.qh_val[self.qh_head]
        } else {
            f64::NEG_INFINITY
        };
        let ll = if self.ql_head != self.ql_tail || self.ql_full {
            self.ql_val[self.ql_head]
        } else {
            f64::INFINITY
        };

        let denom = hh - ll;
        let k = if denom == 0.0 {
            // Flat window: 100 if the close sits at the level, otherwise 0.
            if close == hh {
                100.0
            } else {
                0.0
            }
        } else {
            let scale = 100.0 / denom;
            close.mul_add(scale, (-ll) * scale)
        };

        let d = if self.fastd_matype != 0 {
            f64::NAN
        } else if self.k_count < self.fastd_period {
            // Warm-up: fill the %K ring and accumulate the sum.
            self.k_ring[self.k_head] = k;
            self.d_sma_sum += k;
            self.k_count += 1;
            StochfStream::inc(&mut self.k_head, self.fastd_period);

            if self.k_count == self.fastd_period {
                self.d_sma_sum / (self.fastd_period as f64)
            } else {
                f64::NAN
            }
        } else {
            // Steady state: replace the oldest %K and adjust the running sum.
            let old = self.k_ring[self.k_head];
            self.k_ring[self.k_head] = k;
            StochfStream::inc(&mut self.k_head, self.fastd_period);

            self.d_sma_sum += k - old;
            self.d_sma_sum / (self.fastd_period as f64)
        };

        Some((k, d))
    }
}
1220
1221#[derive(Clone, Debug)]
1222pub struct StochfBatchRange {
1223 pub fastk_period: (usize, usize, usize),
1224 pub fastd_period: (usize, usize, usize),
1225}
1226
1227impl Default for StochfBatchRange {
1228 fn default() -> Self {
1229 Self {
1230 fastk_period: (5, 254, 1),
1231 fastd_period: (3, 3, 0),
1232 }
1233 }
1234}
1235
1236#[derive(Clone, Debug, Default)]
1237pub struct StochfBatchBuilder {
1238 range: StochfBatchRange,
1239 kernel: Kernel,
1240}
1241
1242impl StochfBatchBuilder {
1243 pub fn new() -> Self {
1244 Self::default()
1245 }
1246 pub fn kernel(mut self, k: Kernel) -> Self {
1247 self.kernel = k;
1248 self
1249 }
1250 pub fn fastk_range(mut self, start: usize, end: usize, step: usize) -> Self {
1251 self.range.fastk_period = (start, end, step);
1252 self
1253 }
1254 pub fn fastk_static(mut self, p: usize) -> Self {
1255 self.range.fastk_period = (p, p, 0);
1256 self
1257 }
1258 pub fn fastd_range(mut self, start: usize, end: usize, step: usize) -> Self {
1259 self.range.fastd_period = (start, end, step);
1260 self
1261 }
1262 pub fn fastd_static(mut self, p: usize) -> Self {
1263 self.range.fastd_period = (p, p, 0);
1264 self
1265 }
1266 pub fn apply_slices(
1267 self,
1268 high: &[f64],
1269 low: &[f64],
1270 close: &[f64],
1271 ) -> Result<StochfBatchOutput, StochfError> {
1272 stochf_batch_with_kernel(high, low, close, &self.range, self.kernel)
1273 }
1274 pub fn with_default_slices(
1275 high: &[f64],
1276 low: &[f64],
1277 close: &[f64],
1278 k: Kernel,
1279 ) -> Result<StochfBatchOutput, StochfError> {
1280 StochfBatchBuilder::new()
1281 .kernel(k)
1282 .apply_slices(high, low, close)
1283 }
1284 pub fn apply_candles(self, c: &Candles) -> Result<StochfBatchOutput, StochfError> {
1285 let high = source_type(c, "high");
1286 let low = source_type(c, "low");
1287 let close = source_type(c, "close");
1288 self.apply_slices(high, low, close)
1289 }
1290 pub fn with_default_candles(c: &Candles) -> Result<StochfBatchOutput, StochfError> {
1291 StochfBatchBuilder::new()
1292 .kernel(Kernel::Auto)
1293 .apply_candles(c)
1294 }
1295}
1296
1297pub fn stochf_batch_with_kernel(
1298 high: &[f64],
1299 low: &[f64],
1300 close: &[f64],
1301 sweep: &StochfBatchRange,
1302 k: Kernel,
1303) -> Result<StochfBatchOutput, StochfError> {
1304 let kernel = match k {
1305 Kernel::Auto => detect_best_batch_kernel(),
1306 other if other.is_batch() => other,
1307 _ => return Err(StochfError::InvalidKernelForBatch(k)),
1308 };
1309 let simd = match kernel {
1310 Kernel::Avx512Batch => Kernel::Avx512,
1311 Kernel::Avx2Batch => Kernel::Avx2,
1312 Kernel::ScalarBatch => Kernel::Scalar,
1313 _ => unreachable!(),
1314 };
1315 stochf_batch_par_slice(high, low, close, sweep, simd)
1316}
1317
1318#[derive(Clone, Debug)]
1319pub struct StochfBatchOutput {
1320 pub k: Vec<f64>,
1321 pub d: Vec<f64>,
1322 pub combos: Vec<StochfParams>,
1323 pub rows: usize,
1324 pub cols: usize,
1325}
1326impl StochfBatchOutput {
1327 pub fn row_for_params(&self, p: &StochfParams) -> Option<usize> {
1328 self.combos.iter().position(|c| {
1329 c.fastk_period.unwrap_or(5) == p.fastk_period.unwrap_or(5)
1330 && c.fastd_period.unwrap_or(3) == p.fastd_period.unwrap_or(3)
1331 && c.fastd_matype.unwrap_or(0) == p.fastd_matype.unwrap_or(0)
1332 })
1333 }
1334 pub fn k_for(&self, p: &StochfParams) -> Option<&[f64]> {
1335 self.row_for_params(p).map(|row| {
1336 let start = row * self.cols;
1337 &self.k[start..start + self.cols]
1338 })
1339 }
1340 pub fn d_for(&self, p: &StochfParams) -> Option<&[f64]> {
1341 self.row_for_params(p).map(|row| {
1342 let start = row * self.cols;
1343 &self.d[start..start + self.cols]
1344 })
1345 }
1346}
1347
1348#[inline(always)]
1349fn expand_grid(r: &StochfBatchRange) -> Vec<StochfParams> {
1350 fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, StochfError> {
1351 if step == 0 || start == end {
1352 return Ok(vec![start]);
1353 }
1354 if start < end {
1355 let mut v = Vec::new();
1356 let st = step.max(1);
1357 let mut x = start;
1358 while x <= end {
1359 v.push(x);
1360 x = match x.checked_add(st) {
1361 Some(next) => next,
1362 None => break,
1363 };
1364 }
1365 if v.is_empty() {
1366 return Err(StochfError::InvalidRange { start, end, step });
1367 }
1368 return Ok(v);
1369 }
1370
1371 let mut v = Vec::new();
1372 let st = step.max(1) as isize;
1373 let mut x = start as isize;
1374 let end_i = end as isize;
1375 while x >= end_i {
1376 v.push(x as usize);
1377 x -= st;
1378 }
1379 if v.is_empty() {
1380 return Err(StochfError::InvalidRange { start, end, step });
1381 }
1382 Ok(v)
1383 }
1384 let fastk = axis_usize(r.fastk_period).unwrap_or_else(|_| Vec::new());
1385 let fastd = axis_usize(r.fastd_period).unwrap_or_else(|_| Vec::new());
1386 let mut out = Vec::with_capacity(fastk.len().saturating_mul(fastd.len()));
1387 for &k in &fastk {
1388 for &d in &fastd {
1389 out.push(StochfParams {
1390 fastk_period: Some(k),
1391 fastd_period: Some(d),
1392 fastd_matype: Some(0),
1393 });
1394 }
1395 }
1396 out
1397}
1398
1399#[inline(always)]
1400pub fn stochf_batch_slice(
1401 high: &[f64],
1402 low: &[f64],
1403 close: &[f64],
1404 sweep: &StochfBatchRange,
1405 kern: Kernel,
1406) -> Result<StochfBatchOutput, StochfError> {
1407 stochf_batch_inner(high, low, close, sweep, kern, false)
1408}
1409
1410#[inline(always)]
1411pub fn stochf_batch_par_slice(
1412 high: &[f64],
1413 low: &[f64],
1414 close: &[f64],
1415 sweep: &StochfBatchRange,
1416 kern: Kernel,
1417) -> Result<StochfBatchOutput, StochfError> {
1418 stochf_batch_inner(high, low, close, sweep, kern, true)
1419}
1420
1421#[inline(always)]
1422pub fn stochf_batch_inner_into(
1423 high: &[f64],
1424 low: &[f64],
1425 close: &[f64],
1426 sweep: &StochfBatchRange,
1427 kern: Kernel,
1428 parallel: bool,
1429 k_out: &mut [f64],
1430 d_out: &mut [f64],
1431) -> Result<Vec<StochfParams>, StochfError> {
1432 let combos = expand_grid(sweep);
1433 if combos.is_empty() {
1434 return Err(StochfError::InvalidRange {
1435 start: sweep.fastk_period.0,
1436 end: sweep.fastk_period.1,
1437 step: sweep.fastk_period.2,
1438 });
1439 }
1440 let first = (0..high.len())
1441 .find(|&i| !high[i].is_nan() && !low[i].is_nan() && !close[i].is_nan())
1442 .ok_or(StochfError::AllValuesNaN)?;
1443 let max_k = combos
1444 .iter()
1445 .map(|c| c.fastk_period.unwrap())
1446 .max()
1447 .unwrap();
1448 if high.len() - first < max_k {
1449 return Err(StochfError::NotEnoughValidData {
1450 needed: max_k,
1451 valid: high.len() - first,
1452 });
1453 }
1454 let rows = combos.len();
1455 let cols = high.len();
1456
1457 let expected_size = rows.checked_mul(cols).ok_or(StochfError::InvalidRange {
1458 start: sweep.fastk_period.0,
1459 end: sweep.fastk_period.1,
1460 step: sweep.fastk_period.2,
1461 })?;
1462 if k_out.len() != expected_size || d_out.len() != expected_size {
1463 return Err(StochfError::OutputLengthMismatch {
1464 expected: expected_size,
1465 k_got: k_out.len(),
1466 d_got: d_out.len(),
1467 });
1468 }
1469
1470 for (row, combo) in combos.iter().enumerate() {
1471 let k_warmup = (first + combo.fastk_period.unwrap() - 1).min(cols);
1472 let d_warmup =
1473 (first + combo.fastk_period.unwrap() + combo.fastd_period.unwrap() - 2).min(cols);
1474 let row_start = row * cols;
1475
1476 for i in 0..k_warmup {
1477 k_out[row_start + i] = f64::NAN;
1478 }
1479
1480 for i in 0..d_warmup {
1481 d_out[row_start + i] = f64::NAN;
1482 }
1483 }
1484
1485 let do_row = |row: usize, kout: &mut [f64], dout: &mut [f64]| unsafe {
1486 let fastk_period = combos[row].fastk_period.unwrap();
1487 let fastd_period = combos[row].fastd_period.unwrap();
1488 let matype = combos[row].fastd_matype.unwrap();
1489 match kern {
1490 Kernel::Scalar => stochf_row_scalar(
1491 high,
1492 low,
1493 close,
1494 first,
1495 fastk_period,
1496 fastd_period,
1497 matype,
1498 kout,
1499 dout,
1500 ),
1501 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1502 Kernel::Avx2 => stochf_row_avx2(
1503 high,
1504 low,
1505 close,
1506 first,
1507 fastk_period,
1508 fastd_period,
1509 matype,
1510 kout,
1511 dout,
1512 ),
1513 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1514 Kernel::Avx512 => stochf_row_avx512(
1515 high,
1516 low,
1517 close,
1518 first,
1519 fastk_period,
1520 fastd_period,
1521 matype,
1522 kout,
1523 dout,
1524 ),
1525 _ => unreachable!(),
1526 }
1527 };
1528
1529 if parallel {
1530 #[cfg(not(target_arch = "wasm32"))]
1531 {
1532 k_out
1533 .par_chunks_mut(cols)
1534 .zip(d_out.par_chunks_mut(cols))
1535 .enumerate()
1536 .for_each(|(row, (k, d))| do_row(row, k, d));
1537 }
1538
1539 #[cfg(target_arch = "wasm32")]
1540 {
1541 for (row, (k, d)) in k_out
1542 .chunks_mut(cols)
1543 .zip(d_out.chunks_mut(cols))
1544 .enumerate()
1545 {
1546 do_row(row, k, d);
1547 }
1548 }
1549 } else {
1550 for (row, (k, d)) in k_out
1551 .chunks_mut(cols)
1552 .zip(d_out.chunks_mut(cols))
1553 .enumerate()
1554 {
1555 do_row(row, k, d);
1556 }
1557 }
1558
1559 Ok(combos)
1560}
1561
1562#[inline(always)]
1563fn stochf_batch_inner(
1564 high: &[f64],
1565 low: &[f64],
1566 close: &[f64],
1567 sweep: &StochfBatchRange,
1568 kern: Kernel,
1569 parallel: bool,
1570) -> Result<StochfBatchOutput, StochfError> {
1571 let combos = expand_grid(sweep);
1572 if combos.is_empty() {
1573 return Err(StochfError::InvalidRange {
1574 start: sweep.fastk_period.0,
1575 end: sweep.fastk_period.1,
1576 step: sweep.fastk_period.2,
1577 });
1578 }
1579 let first = (0..high.len())
1580 .find(|&i| !high[i].is_nan() && !low[i].is_nan() && !close[i].is_nan())
1581 .ok_or(StochfError::AllValuesNaN)?;
1582 let max_k = combos
1583 .iter()
1584 .map(|c| c.fastk_period.unwrap())
1585 .max()
1586 .unwrap();
1587 if high.len() - first < max_k {
1588 return Err(StochfError::NotEnoughValidData {
1589 needed: max_k,
1590 valid: high.len() - first,
1591 });
1592 }
1593 let rows = combos.len();
1594 let cols = high.len();
1595
1596 let _total = rows.checked_mul(cols).ok_or(StochfError::InvalidRange {
1597 start: sweep.fastk_period.0,
1598 end: sweep.fastk_period.1,
1599 step: sweep.fastk_period.2,
1600 })?;
1601
1602 let mut k_buf = make_uninit_matrix(rows, cols);
1603 let mut d_buf = make_uninit_matrix(rows, cols);
1604
1605 let k_warmups: Vec<usize> = combos
1606 .iter()
1607 .map(|c| (first + c.fastk_period.unwrap() - 1).min(cols))
1608 .collect();
1609 let d_warmups: Vec<usize> = combos
1610 .iter()
1611 .map(|c| (first + c.fastk_period.unwrap() + c.fastd_period.unwrap() - 2).min(cols))
1612 .collect();
1613
1614 init_matrix_prefixes(&mut k_buf, cols, &k_warmups);
1615 init_matrix_prefixes(&mut d_buf, cols, &d_warmups);
1616
1617 let k_buf_len = k_buf.len();
1618 let d_buf_len = d_buf.len();
1619 let k_buf_cap = k_buf.capacity();
1620 let d_buf_cap = d_buf.capacity();
1621 let k_ptr = k_buf.as_mut_ptr();
1622 let d_ptr = d_buf.as_mut_ptr();
1623 std::mem::forget(k_buf);
1624 std::mem::forget(d_buf);
1625 let k_out = unsafe { std::slice::from_raw_parts_mut(k_ptr as *mut f64, k_buf_len) };
1626 let d_out = unsafe { std::slice::from_raw_parts_mut(d_ptr as *mut f64, d_buf_len) };
1627
1628 let do_row = |row: usize, kout: &mut [f64], dout: &mut [f64]| unsafe {
1629 let fastk_period = combos[row].fastk_period.unwrap();
1630 let fastd_period = combos[row].fastd_period.unwrap();
1631 let matype = combos[row].fastd_matype.unwrap();
1632 match kern {
1633 Kernel::Scalar => stochf_row_scalar(
1634 high,
1635 low,
1636 close,
1637 first,
1638 fastk_period,
1639 fastd_period,
1640 matype,
1641 kout,
1642 dout,
1643 ),
1644 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1645 Kernel::Avx2 => stochf_row_avx2(
1646 high,
1647 low,
1648 close,
1649 first,
1650 fastk_period,
1651 fastd_period,
1652 matype,
1653 kout,
1654 dout,
1655 ),
1656 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1657 Kernel::Avx512 => stochf_row_avx512(
1658 high,
1659 low,
1660 close,
1661 first,
1662 fastk_period,
1663 fastd_period,
1664 matype,
1665 kout,
1666 dout,
1667 ),
1668 _ => unreachable!(),
1669 }
1670 };
1671 if parallel {
1672 #[cfg(not(target_arch = "wasm32"))]
1673 {
1674 k_out
1675 .par_chunks_mut(cols)
1676 .zip(d_out.par_chunks_mut(cols))
1677 .enumerate()
1678 .for_each(|(row, (k, d))| do_row(row, k, d));
1679 }
1680
1681 #[cfg(target_arch = "wasm32")]
1682 {
1683 for (row, (k, d)) in k_out
1684 .chunks_mut(cols)
1685 .zip(d_out.chunks_mut(cols))
1686 .enumerate()
1687 {
1688 do_row(row, k, d);
1689 }
1690 }
1691 } else {
1692 for (row, (k, d)) in k_out
1693 .chunks_mut(cols)
1694 .zip(d_out.chunks_mut(cols))
1695 .enumerate()
1696 {
1697 do_row(row, k, d);
1698 }
1699 }
1700
1701 let k_vec = unsafe { Vec::from_raw_parts(k_ptr as *mut f64, k_buf_len, k_buf_cap) };
1702 let d_vec = unsafe { Vec::from_raw_parts(d_ptr as *mut f64, d_buf_len, d_buf_cap) };
1703
1704 Ok(StochfBatchOutput {
1705 k: k_vec,
1706 d: d_vec,
1707 combos,
1708 rows,
1709 cols,
1710 })
1711}
1712
/// Computes one Fast Stochastic row: %K over a `fastk_period` window and, when
/// `matype == 0`, its `fastd_period` simple moving average as %D.
///
/// Only indices `first + fastk_period - 1 ..` are written; the caller owns the
/// NaN warmup prefix. For `matype != 0` the whole `d_out` row is overwritten
/// with NaN, because only the SMA variant of %D is produced here.
///
/// Window extremes come from one of two paths. Periods up to 16 rescan the
/// window with a 4-way unrolled loop. Longer periods maintain monotonic deques
/// of indices in ring buffers. Both paths ignore NaN highs and lows when
/// searching for the extremes.
///
/// # Safety
/// `low`, `close`, `k_out` and `d_out` must each hold at least `high.len()`
/// elements, because they are accessed through raw pointers and unchecked
/// indexing. `first + fastk_period - 1` must not underflow.
#[inline(always)]
unsafe fn stochf_row_scalar(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    fastk_period: usize,
    fastd_period: usize,
    matype: usize,
    k_out: &mut [f64],
    d_out: &mut [f64],
) {
    /// %K for one bar from the window's highest high and lowest low. A flat
    /// window maps to 100 when the close equals the high, and to 0 otherwise.
    #[inline(always)]
    fn fast_k(c: f64, hh: f64, ll: f64) -> f64 {
        let denom = hh - ll;
        if denom == 0.0 {
            if c == hh {
                100.0
            } else {
                0.0
            }
        } else {
            let inv = 100.0 / denom;
            c.mul_add(inv, (-ll) * inv)
        }
    }

    /// Feeds `kv` (already stored at `k_out[i]`) into the running %D SMA and
    /// writes `d_out[i]`. A NaN %K emits a NaN %D and is not accumulated.
    #[inline(always)]
    unsafe fn push_sma_d(
        kv: f64,
        i: usize,
        fastd_period: usize,
        d_sum: &mut f64,
        d_cnt: &mut usize,
        k_out: &[f64],
        d_out: &mut [f64],
    ) {
        let dv = if kv.is_nan() {
            f64::NAN
        } else if *d_cnt < fastd_period {
            *d_sum += kv;
            *d_cnt += 1;
            if *d_cnt == fastd_period {
                *d_sum / (fastd_period as f64)
            } else {
                f64::NAN
            }
        } else {
            // Slide the SMA window: add the newest %K, drop the oldest.
            *d_sum += kv - *k_out.get_unchecked(i - fastd_period);
            *d_sum / (fastd_period as f64)
        };
        *d_out.get_unchecked_mut(i) = dv;
    }

    let len = high.len();
    debug_assert!(
        low.len() >= len && close.len() >= len && k_out.len() >= len && d_out.len() >= len,
        "stochf_row_scalar: inputs/outputs shorter than `high`"
    );

    let hp = high.as_ptr();
    let lp = low.as_ptr();
    let cp = close.as_ptr();

    let k_start = first + fastk_period - 1;
    let use_sma_d = matype == 0;
    let mut d_sum: f64 = 0.0;
    let mut d_cnt: usize = 0;

    if fastk_period <= 16 {
        // Short windows: rescanning the window for every bar is cheaper than
        // maintaining deques.
        let mut i = k_start;
        while i < len {
            let start = i + 1 - fastk_period;
            let end = i + 1;

            let mut hh = f64::NEG_INFINITY;
            let mut ll = f64::INFINITY;

            let mut j = start;
            let unroll_end = end - ((end - j) & 3);
            while j < unroll_end {
                for off in 0..4 {
                    let h = *hp.add(j + off);
                    let l = *lp.add(j + off);
                    if h > hh {
                        hh = h;
                    }
                    if l < ll {
                        ll = l;
                    }
                }
                j += 4;
            }
            while j < end {
                let h = *hp.add(j);
                let l = *lp.add(j);
                if h > hh {
                    hh = h;
                }
                if l < ll {
                    ll = l;
                }
                j += 1;
            }

            let kv = fast_k(*cp.add(i), hh, ll);
            *k_out.get_unchecked_mut(i) = kv;
            if use_sma_d {
                push_sma_d(kv, i, fastd_period, &mut d_sum, &mut d_cnt, k_out, d_out);
            }

            i += 1;
        }
    } else {
        // Long windows: monotonic deques give amortised O(1) window extremes.
        // A window spans `fastk_period` bars, and during a strictly monotonic
        // run every one of them stays queued. One spare slot is therefore
        // required so that a full ring (tail wrapping onto head) is never
        // mistaken for an empty one (`head == tail`).
        let cap = fastk_period + 1;
        let mut qh = vec![0usize; cap];
        let mut ql = vec![0usize; cap];
        let mut qh_head = 0usize;
        let mut qh_tail = 0usize;
        let mut ql_head = 0usize;
        let mut ql_tail = 0usize;

        let mut i = first;
        while i < len {
            if i + 1 >= fastk_period {
                // Evict indices that have slid out of `[win_start, i]`.
                let win_start = i + 1 - fastk_period;
                while qh_head != qh_tail && *qh.get_unchecked(qh_head) < win_start {
                    qh_head += 1;
                    if qh_head == cap {
                        qh_head = 0;
                    }
                }
                while ql_head != ql_tail && *ql.get_unchecked(ql_head) < win_start {
                    ql_head += 1;
                    if ql_head == cap {
                        ql_head = 0;
                    }
                }
            }

            // Max-deque of highs (values non-increasing from head to tail);
            // NaN highs are never queued.
            let h_i = *hp.add(i);
            if !h_i.is_nan() {
                while qh_head != qh_tail {
                    let back = if qh_tail == 0 { cap - 1 } else { qh_tail - 1 };
                    if *hp.add(*qh.get_unchecked(back)) <= h_i {
                        qh_tail = back;
                    } else {
                        break;
                    }
                }
                *qh.get_unchecked_mut(qh_tail) = i;
                qh_tail += 1;
                if qh_tail == cap {
                    qh_tail = 0;
                }
            }

            // Min-deque of lows (values non-decreasing from head to tail);
            // NaN lows are never queued.
            let l_i = *lp.add(i);
            if !l_i.is_nan() {
                while ql_head != ql_tail {
                    let back = if ql_tail == 0 { cap - 1 } else { ql_tail - 1 };
                    if *lp.add(*ql.get_unchecked(back)) >= l_i {
                        ql_tail = back;
                    } else {
                        break;
                    }
                }
                *ql.get_unchecked_mut(ql_tail) = i;
                ql_tail += 1;
                if ql_tail == cap {
                    ql_tail = 0;
                }
            }

            if i >= k_start {
                let hh = if qh_head != qh_tail {
                    *hp.add(*qh.get_unchecked(qh_head))
                } else {
                    f64::NEG_INFINITY
                };
                let ll = if ql_head != ql_tail {
                    *lp.add(*ql.get_unchecked(ql_head))
                } else {
                    f64::INFINITY
                };
                let kv = fast_k(*cp.add(i), hh, ll);
                *k_out.get_unchecked_mut(i) = kv;
                if use_sma_d {
                    push_sma_d(kv, i, fastd_period, &mut d_sum, &mut d_cnt, k_out, d_out);
                }
            }

            i += 1;
        }
    }

    if !use_sma_d {
        d_out.fill(f64::NAN);
    }
}
1963
/// AVX2 row kernel entry point. There is no dedicated vector implementation
/// yet, so the row is computed by `stochf_row_scalar` with the same arguments.
///
/// # Safety
/// Same contract as `stochf_row_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
unsafe fn stochf_row_avx2(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    fastk_period: usize,
    fastd_period: usize,
    matype: usize,
    k_out: &mut [f64],
    d_out: &mut [f64],
) {
    stochf_row_scalar(high, low, close, first, fastk_period, fastd_period, matype, k_out, d_out)
}
1989
/// AVX-512 row kernel entry point. It routes windows of up to 32 bars to the
/// short variant and longer windows to the long variant.
///
/// # Safety
/// Same contract as `stochf_row_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn stochf_row_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    fastk_period: usize,
    fastd_period: usize,
    matype: usize,
    k_out: &mut [f64],
    d_out: &mut [f64],
) {
    match fastk_period {
        0..=32 => stochf_row_avx512_short(
            high,
            low,
            close,
            first,
            fastk_period,
            fastd_period,
            matype,
            k_out,
            d_out,
        ),
        _ => stochf_row_avx512_long(
            high,
            low,
            close,
            first,
            fastk_period,
            fastd_period,
            matype,
            k_out,
            d_out,
        ),
    }
}
2029
/// AVX-512 kernel for short windows (`fastk_period <= 32` when reached through
/// `stochf_row_avx512`). It currently delegates to `stochf_row_scalar`.
///
/// # Safety
/// Same contract as `stochf_row_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn stochf_row_avx512_short(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    fastk_period: usize,
    fastd_period: usize,
    matype: usize,
    k_out: &mut [f64],
    d_out: &mut [f64],
) {
    stochf_row_scalar(high, low, close, first, fastk_period, fastd_period, matype, k_out, d_out)
}
2055
/// AVX-512 kernel for long windows (`fastk_period > 32` when reached through
/// `stochf_row_avx512`). It currently delegates to `stochf_row_scalar`.
///
/// # Safety
/// Same contract as `stochf_row_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn stochf_row_avx512_long(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    fastk_period: usize,
    fastd_period: usize,
    matype: usize,
    k_out: &mut [f64],
    d_out: &mut [f64],
) {
    stochf_row_scalar(high, low, close, first, fastk_period, fastd_period, matype, k_out, d_out)
}
2081
// Python binding for a single Fast Stochastic run over three equal-length
// float64 arrays. Returns `(k, d)` as NumPy arrays. Kernel-selection,
// length and indicator errors surface as exceptions (ValueError for the
// latter two).
#[cfg(feature = "python")]
#[pyfunction(name = "stochf")]
#[pyo3(signature = (high, low, close, fastk_period=None, fastd_period=None, fastd_matype=None, kernel=None))]
pub fn stochf_py<'py>(
    py: Python<'py>,
    high: numpy::PyReadonlyArray1<'py, f64>,
    low: numpy::PyReadonlyArray1<'py, f64>,
    close: numpy::PyReadonlyArray1<'py, f64>,
    fastk_period: Option<usize>,
    fastd_period: Option<usize>,
    fastd_matype: Option<usize>,
    kernel: Option<&str>,
) -> PyResult<(
    Bound<'py, numpy::PyArray1<f64>>,
    Bound<'py, numpy::PyArray1<f64>>,
)> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let (high_slice, low_slice, close_slice) =
        (high.as_slice()?, low.as_slice()?, close.as_slice()?);
    let kern = validate_kernel(kernel, false)?;

    let lengths_match =
        high_slice.len() == low_slice.len() && high_slice.len() == close_slice.len();
    if !lengths_match {
        return Err(PyValueError::new_err(
            "Input arrays must have the same length",
        ));
    }

    let input = StochfInput::from_slices(
        high_slice,
        low_slice,
        close_slice,
        StochfParams {
            fastk_period,
            fastd_period,
            fastd_matype,
        },
    );

    // The computation runs without the GIL held.
    let output = py
        .allow_threads(|| stochf_with_kernel(&input, kern))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok((output.k.into_pyarray(py), output.d.into_pyarray(py)))
}
2124
// Python-facing wrapper around the incremental `StochfStream` calculator.
#[cfg(feature = "python")]
#[pyclass(name = "StochfStream")]
pub struct StochfStreamPy {
    stream: StochfStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl StochfStreamPy {
    // Builds a stream with all three parameters given explicitly; parameters
    // rejected by `StochfStream::try_new` are reported as ValueError.
    #[new]
    fn new(fastk_period: usize, fastd_period: usize, fastd_matype: usize) -> PyResult<Self> {
        StochfStream::try_new(StochfParams {
            fastk_period: Some(fastk_period),
            fastd_period: Some(fastd_period),
            fastd_matype: Some(fastd_matype),
        })
        .map(|stream| StochfStreamPy { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    // Feeds one bar and returns whatever `StochfStream::update` yields:
    // `Some((k, d))` or `None`.
    fn update(&mut self, high: f64, low: f64, close: f64) -> Option<(f64, f64)> {
        self.stream.update(high, low, close)
    }
}
2150
// Python binding for a Fast Stochastic parameter sweep.
//
// `fastk_range` / `fastd_range` are `(start, end, step)` tuples; `fastd_matype`
// is not swept (`expand_grid` fixes it at 0). Returns a dict with
// "k_values" / "d_values" reshaped to `(rows, cols)` and the per-row
// "fastk_periods" / "fastd_periods" as u64 arrays. Errors are raised as
// ValueError (or whatever `validate_kernel` raises).
#[cfg(feature = "python")]
#[pyfunction(name = "stochf_batch")]
#[pyo3(signature = (high, low, close, fastk_range, fastd_range, kernel=None))]
pub fn stochf_batch_py<'py>(
    py: Python<'py>,
    high: numpy::PyReadonlyArray1<'py, f64>,
    low: numpy::PyReadonlyArray1<'py, f64>,
    close: numpy::PyReadonlyArray1<'py, f64>,
    fastk_range: (usize, usize, usize),
    fastd_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let high_slice = high.as_slice()?;
    let low_slice = low.as_slice()?;
    let close_slice = close.as_slice()?;

    // The batch kernel indexes all three series up to `high.len()`.
    if high_slice.len() != low_slice.len() || high_slice.len() != close_slice.len() {
        return Err(PyValueError::new_err(
            "Input arrays must have the same length",
        ));
    }

    let sweep = StochfBatchRange {
        fastk_period: fastk_range,
        fastd_period: fastd_range,
    };

    // Expanded here only to size the output arrays; `stochf_batch_inner_into`
    // expands the sweep again, and its result is what gets reported.
    let combos = expand_grid(&sweep);
    let rows = combos.len();
    if rows == 0 {
        return Err(PyValueError::new_err(
            "stochf: invalid range (empty expansion)",
        ));
    }
    let cols = high_slice.len();
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("stochf: rows*cols overflow"))?;

    // Uninitialised output arrays. `stochf_batch_inner_into` must write every
    // element (NaN warmups plus the row kernels) before Python can observe
    // them; if it returns an error the `?` below drops them unreturned.
    let k_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let d_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let k_slice = unsafe { k_arr.as_slice_mut()? };
    let d_slice = unsafe { d_arr.as_slice_mut()? };

    let kern = validate_kernel(kernel, true)?;

    // The sweep runs with the GIL released, always on the parallel path.
    let combos = py
        .allow_threads(|| {
            // NOTE(review): `k => k` assumes `validate_kernel(_, true)` only
            // yields `Auto` or a `*Batch` kernel; any other value hits the
            // `unreachable!()` below. Confirm against `validate_kernel`.
            let kernel = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k => k,
            };
            // Batch kernel -> per-row kernel used by the row functions.
            let simd = match kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                _ => unreachable!(),
            };
            stochf_batch_inner_into(
                high_slice,
                low_slice,
                close_slice,
                &sweep,
                simd,
                true,
                k_slice,
                d_slice,
            )
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    // Flat row-major buffers are exposed as (rows, cols) matrices; row `r`
    // corresponds to `combos[r]`.
    let dict = PyDict::new(py);
    dict.set_item("k_values", k_arr.reshape((rows, cols))?)?;
    dict.set_item("d_values", d_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "fastk_periods",
        combos
            .iter()
            .map(|p| p.fastk_period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "fastd_periods",
        combos
            .iter()
            .map(|p| p.fastd_period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;

    Ok(dict)
}
2247
2248#[cfg(all(feature = "python", feature = "cuda"))]
2249use crate::cuda::{cuda_available, CudaStochf};
2250#[cfg(all(feature = "python", feature = "cuda"))]
2251use numpy::PyReadonlyArray1;
2252#[cfg(all(feature = "python", feature = "cuda"))]
2253use pyo3::exceptions::PyValueError as PyErrValue;
2254#[cfg(all(feature = "python", feature = "cuda"))]
2255use pyo3::PyErr;
2256
// Python binding: runs a (fastk, fastd) sweep on the GPU for f32 inputs and
// returns the two result matrices as device-resident arrays (`pair.a` first,
// then `pair.b`; presumably %K then %D, to be confirmed against `CudaStochf`).
// Both arrays hold a clone of the CUDA context handle.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "stochf_cuda_batch_dev")]
#[pyo3(signature = (high_f32, low_f32, close_f32, fastk_range, fastd_range, device_id=0))]
pub fn stochf_cuda_batch_dev_py(
    py: Python<'_>,
    high_f32: PyReadonlyArray1<'_, f32>,
    low_f32: PyReadonlyArray1<'_, f32>,
    close_f32: PyReadonlyArray1<'_, f32>,
    fastk_range: (usize, usize, usize),
    fastd_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<(DeviceArrayF32Py, DeviceArrayF32Py)> {
    if !cuda_available() {
        return Err(PyErrValue::new_err("CUDA not available"));
    }
    let (h, l, c) = (high_f32.as_slice()?, low_f32.as_slice()?, close_f32.as_slice()?);
    let same_len = h.len() == l.len() && h.len() == c.len();
    if !same_len {
        return Err(PyErrValue::new_err("mismatched input lengths"));
    }
    let sweep = StochfBatchRange {
        fastk_period: fastk_range,
        fastd_period: fastd_range,
    };
    // Device work happens with the GIL released.
    let (pair, ctx, dev_id) = py.allow_threads(|| {
        let cuda = CudaStochf::new(device_id).map_err(|e| PyErrValue::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id();
        let (pair, _combos) = cuda
            .stochf_batch_dev(h, l, c, &sweep)
            .map_err(|e| PyErrValue::new_err(e.to_string()))?;
        Ok::<_, PyErr>((pair, ctx, dev_id))
    })?;
    let wrap = |inner, ctx| DeviceArrayF32Py {
        inner,
        _ctx: Some(ctx),
        device_id: Some(dev_id),
    };
    let first_dev = wrap(pair.a, ctx.clone());
    let second_dev = wrap(pair.b, ctx);
    Ok((first_dev, second_dev))
}
2304
// Python binding: computes Fast Stochastic for many series with a single
// parameter set on the GPU. Inputs are flat f32 buffers described by
// `cols`/`rows` (time-major per the function name). Returns `(k, d)` as
// device-resident arrays that each hold a clone of the CUDA context handle.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "stochf_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm_f32, low_tm_f32, close_tm_f32, cols, rows, fastk, fastd, fastd_matype=0, device_id=0))]
pub fn stochf_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    high_tm_f32: PyReadonlyArray1<'_, f32>,
    low_tm_f32: PyReadonlyArray1<'_, f32>,
    close_tm_f32: PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    fastk: usize,
    fastd: usize,
    fastd_matype: usize,
    device_id: usize,
) -> PyResult<(DeviceArrayF32Py, DeviceArrayF32Py)> {
    if !cuda_available() {
        return Err(PyErrValue::new_err("CUDA not available"));
    }
    let htm = high_tm_f32.as_slice()?;
    let ltm = low_tm_f32.as_slice()?;
    let ctm = close_tm_f32.as_slice()?;
    // Same guard as the batch entry point: the three buffers must line up.
    if htm.len() != ltm.len() || htm.len() != ctm.len() {
        return Err(PyErrValue::new_err("mismatched input lengths"));
    }
    let params = StochfParams {
        fastk_period: Some(fastk),
        fastd_period: Some(fastd),
        fastd_matype: Some(fastd_matype),
    };
    // Device work happens with the GIL released.
    let (k, d, ctx, dev_id) = py.allow_threads(|| {
        let cuda = CudaStochf::new(device_id).map_err(|e| PyErrValue::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id();
        let (k, d) = cuda
            .stochf_many_series_one_param_time_major_dev(htm, ltm, ctm, cols, rows, &params)
            .map_err(|e| PyErrValue::new_err(e.to_string()))?;
        Ok::<_, PyErr>((k, d, ctx, dev_id))
    })?;
    Ok((
        DeviceArrayF32Py {
            inner: k,
            _ctx: Some(ctx.clone()),
            device_id: Some(dev_id),
        },
        DeviceArrayF32Py {
            inner: d,
            _ctx: Some(ctx),
            device_id: Some(dev_id),
        },
    ))
}
2353
2354#[cfg(test)]
2355mod tests {
2356 use super::*;
2357 use crate::skip_if_unsupported;
2358 use crate::utilities::data_loader::read_candles_from_csv;
2359
2360 fn check_stochf_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2361 skip_if_unsupported!(kernel, test_name);
2362 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2363 let candles = read_candles_from_csv(file_path)?;
2364 let params = StochfParams {
2365 fastk_period: None,
2366 fastd_period: None,
2367 fastd_matype: None,
2368 };
2369 let input = StochfInput::from_candles(&candles, params);
2370 let output = stochf_with_kernel(&input, kernel)?;
2371 assert_eq!(output.k.len(), candles.close.len());
2372 assert_eq!(output.d.len(), candles.close.len());
2373 Ok(())
2374 }
2375
2376 fn check_stochf_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2377 skip_if_unsupported!(kernel, test_name);
2378 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2379 let candles = read_candles_from_csv(file_path)?;
2380 let params = StochfParams {
2381 fastk_period: Some(5),
2382 fastd_period: Some(3),
2383 fastd_matype: Some(0),
2384 };
2385 let input = StochfInput::from_candles(&candles, params);
2386 let output = stochf_with_kernel(&input, kernel)?;
2387 let expected_k = [
2388 80.6987399770905,
2389 40.88471849865952,
2390 15.507246376811594,
2391 36.920529801324506,
2392 32.1880650994575,
2393 ];
2394 let expected_d = [
2395 70.99960994145033,
2396 61.44725644908976,
2397 45.696901617520815,
2398 31.104164892265487,
2399 28.205280425864817,
2400 ];
2401 let k_slice = &output.k[output.k.len() - 5..];
2402 let d_slice = &output.d[output.d.len() - 5..];
2403 for i in 0..5 {
2404 assert!(
2405 (k_slice[i] - expected_k[i]).abs() < 1e-4,
2406 "K mismatch at idx {}",
2407 i
2408 );
2409 assert!(
2410 (d_slice[i] - expected_d[i]).abs() < 1e-4,
2411 "D mismatch at idx {}",
2412 i
2413 );
2414 }
2415 Ok(())
2416 }
2417
2418 fn check_stochf_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2419 skip_if_unsupported!(kernel, test_name);
2420 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2421 let candles = read_candles_from_csv(file_path)?;
2422 let input = StochfInput::with_default_candles(&candles);
2423 let output = stochf_with_kernel(&input, kernel)?;
2424 assert_eq!(output.k.len(), candles.close.len());
2425 assert_eq!(output.d.len(), candles.close.len());
2426 Ok(())
2427 }
2428
2429 fn check_stochf_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2430 skip_if_unsupported!(kernel, test_name);
2431 let data = [10.0, 20.0, 30.0, 40.0, 50.0];
2432 let params = StochfParams {
2433 fastk_period: Some(0),
2434 fastd_period: Some(3),
2435 fastd_matype: Some(0),
2436 };
2437 let input = StochfInput::from_slices(&data, &data, &data, params);
2438 let res = stochf_with_kernel(&input, kernel);
2439 assert!(res.is_err());
2440 Ok(())
2441 }
2442
2443 fn check_stochf_period_exceeds_length(
2444 test_name: &str,
2445 kernel: Kernel,
2446 ) -> Result<(), Box<dyn Error>> {
2447 skip_if_unsupported!(kernel, test_name);
2448 let data = [10.0, 20.0, 30.0];
2449 let params = StochfParams {
2450 fastk_period: Some(10),
2451 fastd_period: Some(3),
2452 fastd_matype: Some(0),
2453 };
2454 let input = StochfInput::from_slices(&data, &data, &data, params);
2455 let res = stochf_with_kernel(&input, kernel);
2456 assert!(res.is_err());
2457 Ok(())
2458 }
2459
2460 fn check_stochf_very_small_dataset(
2461 test_name: &str,
2462 kernel: Kernel,
2463 ) -> Result<(), Box<dyn Error>> {
2464 skip_if_unsupported!(kernel, test_name);
2465 let data = [42.0];
2466 let params = StochfParams {
2467 fastk_period: Some(9),
2468 fastd_period: Some(3),
2469 fastd_matype: Some(0),
2470 };
2471 let input = StochfInput::from_slices(&data, &data, &data, params);
2472 let res = stochf_with_kernel(&input, kernel);
2473 assert!(res.is_err());
2474 Ok(())
2475 }
2476
2477 fn check_stochf_slice_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2478 skip_if_unsupported!(kernel, test_name);
2479 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2480 let candles = read_candles_from_csv(file_path)?;
2481 let params = StochfParams {
2482 fastk_period: Some(5),
2483 fastd_period: Some(3),
2484 fastd_matype: Some(0),
2485 };
2486 let input1 = StochfInput::from_candles(&candles, params.clone());
2487 let res1 = stochf_with_kernel(&input1, kernel)?;
2488 let input2 = StochfInput::from_slices(&res1.k, &res1.k, &res1.k, params);
2489 let res2 = stochf_with_kernel(&input2, kernel)?;
2490 assert_eq!(res2.k.len(), res1.k.len());
2491 assert_eq!(res2.d.len(), res1.d.len());
2492 Ok(())
2493 }
2494
2495 #[cfg(debug_assertions)]
2496 fn check_stochf_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2497 skip_if_unsupported!(kernel, test_name);
2498
2499 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2500 let candles = read_candles_from_csv(file_path)?;
2501
2502 let test_params = vec![
2503 StochfParams::default(),
2504 StochfParams {
2505 fastk_period: Some(2),
2506 fastd_period: Some(1),
2507 fastd_matype: Some(0),
2508 },
2509 StochfParams {
2510 fastk_period: Some(3),
2511 fastd_period: Some(2),
2512 fastd_matype: Some(0),
2513 },
2514 StochfParams {
2515 fastk_period: Some(5),
2516 fastd_period: Some(5),
2517 fastd_matype: Some(0),
2518 },
2519 StochfParams {
2520 fastk_period: Some(10),
2521 fastd_period: Some(3),
2522 fastd_matype: Some(0),
2523 },
2524 StochfParams {
2525 fastk_period: Some(14),
2526 fastd_period: Some(3),
2527 fastd_matype: Some(0),
2528 },
2529 StochfParams {
2530 fastk_period: Some(20),
2531 fastd_period: Some(5),
2532 fastd_matype: Some(0),
2533 },
2534 StochfParams {
2535 fastk_period: Some(50),
2536 fastd_period: Some(10),
2537 fastd_matype: Some(0),
2538 },
2539 StochfParams {
2540 fastk_period: Some(100),
2541 fastd_period: Some(20),
2542 fastd_matype: Some(0),
2543 },
2544 StochfParams {
2545 fastk_period: Some(8),
2546 fastd_period: Some(7),
2547 fastd_matype: Some(0),
2548 },
2549 ];
2550
2551 for (param_idx, params) in test_params.iter().enumerate() {
2552 let input = StochfInput::from_candles(&candles, params.clone());
2553 let output = stochf_with_kernel(&input, kernel)?;
2554
2555 for (i, &val) in output.k.iter().enumerate() {
2556 if val.is_nan() {
2557 continue;
2558 }
2559
2560 let bits = val.to_bits();
2561
2562 if bits == 0x11111111_11111111 {
2563 panic!(
2564 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at K index {} \
2565 with params: fastk_period={}, fastd_period={}, fastd_matype={} (param set {})",
2566 test_name, val, bits, i,
2567 params.fastk_period.unwrap_or(5),
2568 params.fastd_period.unwrap_or(3),
2569 params.fastd_matype.unwrap_or(0),
2570 param_idx
2571 );
2572 }
2573
2574 if bits == 0x22222222_22222222 {
2575 panic!(
2576 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at K index {} \
2577 with params: fastk_period={}, fastd_period={}, fastd_matype={} (param set {})",
2578 test_name, val, bits, i,
2579 params.fastk_period.unwrap_or(5),
2580 params.fastd_period.unwrap_or(3),
2581 params.fastd_matype.unwrap_or(0),
2582 param_idx
2583 );
2584 }
2585
2586 if bits == 0x33333333_33333333 {
2587 panic!(
2588 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at K index {} \
2589 with params: fastk_period={}, fastd_period={}, fastd_matype={} (param set {})",
2590 test_name,
2591 val,
2592 bits,
2593 i,
2594 params.fastk_period.unwrap_or(5),
2595 params.fastd_period.unwrap_or(3),
2596 params.fastd_matype.unwrap_or(0),
2597 param_idx
2598 );
2599 }
2600 }
2601
2602 for (i, &val) in output.d.iter().enumerate() {
2603 if val.is_nan() {
2604 continue;
2605 }
2606
2607 let bits = val.to_bits();
2608
2609 if bits == 0x11111111_11111111 {
2610 panic!(
2611 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at D index {} \
2612 with params: fastk_period={}, fastd_period={}, fastd_matype={} (param set {})",
2613 test_name, val, bits, i,
2614 params.fastk_period.unwrap_or(5),
2615 params.fastd_period.unwrap_or(3),
2616 params.fastd_matype.unwrap_or(0),
2617 param_idx
2618 );
2619 }
2620
2621 if bits == 0x22222222_22222222 {
2622 panic!(
2623 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at D index {} \
2624 with params: fastk_period={}, fastd_period={}, fastd_matype={} (param set {})",
2625 test_name, val, bits, i,
2626 params.fastk_period.unwrap_or(5),
2627 params.fastd_period.unwrap_or(3),
2628 params.fastd_matype.unwrap_or(0),
2629 param_idx
2630 );
2631 }
2632
2633 if bits == 0x33333333_33333333 {
2634 panic!(
2635 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at D index {} \
2636 with params: fastk_period={}, fastd_period={}, fastd_matype={} (param set {})",
2637 test_name,
2638 val,
2639 bits,
2640 i,
2641 params.fastk_period.unwrap_or(5),
2642 params.fastd_period.unwrap_or(3),
2643 params.fastd_matype.unwrap_or(0),
2644 param_idx
2645 );
2646 }
2647 }
2648 }
2649
2650 Ok(())
2651 }
2652
2653 #[cfg(not(debug_assertions))]
2654 fn check_stochf_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2655 Ok(())
2656 }
2657
2658 macro_rules! generate_all_stochf_tests {
2659 ($($test_fn:ident),*) => {
2660 paste::paste! {
2661 $(
2662 #[test]
2663 fn [<$test_fn _scalar_f64>]() {
2664 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
2665 }
2666 )*
2667 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2668 $(
2669 #[test]
2670 fn [<$test_fn _avx2_f64>]() {
2671 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
2672 }
2673 #[test]
2674 fn [<$test_fn _avx512_f64>]() {
2675 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
2676 }
2677 )*
2678 }
2679 }
2680 }
2681
2682 generate_all_stochf_tests!(
2683 check_stochf_partial_params,
2684 check_stochf_accuracy,
2685 check_stochf_default_candles,
2686 check_stochf_zero_period,
2687 check_stochf_period_exceeds_length,
2688 check_stochf_very_small_dataset,
2689 check_stochf_slice_reinput,
2690 check_stochf_no_poison
2691 );
2692
2693 #[cfg(feature = "proptest")]
2694 generate_all_stochf_tests!(check_stochf_property);
2695
2696 fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2697 skip_if_unsupported!(kernel, test);
2698 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2699 let c = read_candles_from_csv(file)?;
2700 let output = StochfBatchBuilder::new().kernel(kernel).apply_candles(&c)?;
2701 let def = StochfParams::default();
2702 let krow = output.k_for(&def).expect("default row missing");
2703 let drow = output.d_for(&def).expect("default row missing");
2704 assert_eq!(krow.len(), c.close.len());
2705 assert_eq!(drow.len(), c.close.len());
2706 Ok(())
2707 }
2708
2709 #[cfg(debug_assertions)]
2710 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2711 skip_if_unsupported!(kernel, test);
2712
2713 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2714 let c = read_candles_from_csv(file)?;
2715
2716 let test_configs = vec![
2717 (2, 10, 2, 1, 5, 1),
2718 (5, 25, 5, 3, 3, 0),
2719 (30, 60, 15, 5, 15, 5),
2720 (2, 5, 1, 1, 3, 1),
2721 (10, 20, 2, 3, 9, 3),
2722 (14, 14, 0, 1, 7, 2),
2723 (3, 12, 3, 2, 2, 0),
2724 (50, 100, 25, 10, 20, 10),
2725 ];
2726
2727 for (cfg_idx, &(fk_start, fk_end, fk_step, fd_start, fd_end, fd_step)) in
2728 test_configs.iter().enumerate()
2729 {
2730 let output = StochfBatchBuilder::new()
2731 .kernel(kernel)
2732 .fastk_range(fk_start, fk_end, fk_step)
2733 .fastd_range(fd_start, fd_end, fd_step)
2734 .apply_candles(&c)?;
2735
2736 for (idx, &val) in output.k.iter().enumerate() {
2737 if val.is_nan() {
2738 continue;
2739 }
2740
2741 let bits = val.to_bits();
2742 let row = idx / output.cols;
2743 let col = idx % output.cols;
2744 let combo = &output.combos[row];
2745
2746 if bits == 0x11111111_11111111 {
2747 panic!(
2748 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
2749 at K row {} col {} (flat index {}) with params: fastk={}, fastd={}, matype={}",
2750 test,
2751 cfg_idx,
2752 val,
2753 bits,
2754 row,
2755 col,
2756 idx,
2757 combo.fastk_period.unwrap_or(5),
2758 combo.fastd_period.unwrap_or(3),
2759 combo.fastd_matype.unwrap_or(0)
2760 );
2761 }
2762
2763 if bits == 0x22222222_22222222 {
2764 panic!(
2765 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
2766 at K row {} col {} (flat index {}) with params: fastk={}, fastd={}, matype={}",
2767 test,
2768 cfg_idx,
2769 val,
2770 bits,
2771 row,
2772 col,
2773 idx,
2774 combo.fastk_period.unwrap_or(5),
2775 combo.fastd_period.unwrap_or(3),
2776 combo.fastd_matype.unwrap_or(0)
2777 );
2778 }
2779
2780 if bits == 0x33333333_33333333 {
2781 panic!(
2782 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
2783 at K row {} col {} (flat index {}) with params: fastk={}, fastd={}, matype={}",
2784 test,
2785 cfg_idx,
2786 val,
2787 bits,
2788 row,
2789 col,
2790 idx,
2791 combo.fastk_period.unwrap_or(5),
2792 combo.fastd_period.unwrap_or(3),
2793 combo.fastd_matype.unwrap_or(0)
2794 );
2795 }
2796 }
2797
2798 for (idx, &val) in output.d.iter().enumerate() {
2799 if val.is_nan() {
2800 continue;
2801 }
2802
2803 let bits = val.to_bits();
2804 let row = idx / output.cols;
2805 let col = idx % output.cols;
2806 let combo = &output.combos[row];
2807
2808 if bits == 0x11111111_11111111 {
2809 panic!(
2810 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
2811 at D row {} col {} (flat index {}) with params: fastk={}, fastd={}, matype={}",
2812 test,
2813 cfg_idx,
2814 val,
2815 bits,
2816 row,
2817 col,
2818 idx,
2819 combo.fastk_period.unwrap_or(5),
2820 combo.fastd_period.unwrap_or(3),
2821 combo.fastd_matype.unwrap_or(0)
2822 );
2823 }
2824
2825 if bits == 0x22222222_22222222 {
2826 panic!(
2827 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
2828 at D row {} col {} (flat index {}) with params: fastk={}, fastd={}, matype={}",
2829 test,
2830 cfg_idx,
2831 val,
2832 bits,
2833 row,
2834 col,
2835 idx,
2836 combo.fastk_period.unwrap_or(5),
2837 combo.fastd_period.unwrap_or(3),
2838 combo.fastd_matype.unwrap_or(0)
2839 );
2840 }
2841
2842 if bits == 0x33333333_33333333 {
2843 panic!(
2844 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
2845 at D row {} col {} (flat index {}) with params: fastk={}, fastd={}, matype={}",
2846 test,
2847 cfg_idx,
2848 val,
2849 bits,
2850 row,
2851 col,
2852 idx,
2853 combo.fastk_period.unwrap_or(5),
2854 combo.fastd_period.unwrap_or(3),
2855 combo.fastd_matype.unwrap_or(0)
2856 );
2857 }
2858 }
2859 }
2860
2861 Ok(())
2862 }
2863
2864 #[cfg(not(debug_assertions))]
2865 fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2866 Ok(())
2867 }
2868
2869 #[cfg(feature = "proptest")]
2870 #[allow(clippy::float_cmp)]
2871 fn check_stochf_property(
2872 test_name: &str,
2873 kernel: Kernel,
2874 ) -> Result<(), Box<dyn std::error::Error>> {
2875 use proptest::prelude::*;
2876 skip_if_unsupported!(kernel, test_name);
2877
2878 let strat = (2usize..=50).prop_flat_map(|fastk_period| {
2879 (
2880 prop::collection::vec(
2881 (1.0f64..1000.0f64).prop_filter("finite", |x| x.is_finite()),
2882 fastk_period + 50..400,
2883 ),
2884 prop::collection::vec(0.0f64..=1.0f64, fastk_period + 50..400),
2885 Just(fastk_period),
2886 1usize..=10,
2887 0.001f64..0.1f64,
2888 )
2889 });
2890
2891 proptest::test_runner::TestRunner::default()
2892 .run(
2893 &strat,
2894 |(base_prices, close_positions, fastk_period, fastd_period, volatility)| {
2895 let len = base_prices.len().min(close_positions.len());
2896 let mut high = Vec::with_capacity(len);
2897 let mut low = Vec::with_capacity(len);
2898 let mut close = Vec::with_capacity(len);
2899
2900 for i in 0..len {
2901 let base = base_prices[i];
2902 let spread = base * volatility;
2903 let h = base + spread * 0.5;
2904 let l = base - spread * 0.5;
2905
2906 let c = l + (h - l) * close_positions[i];
2907
2908 high.push(h);
2909 low.push(l);
2910 close.push(c);
2911 }
2912
2913 let params = StochfParams {
2914 fastk_period: Some(fastk_period),
2915 fastd_period: Some(fastd_period),
2916 fastd_matype: Some(0),
2917 };
2918 let input = StochfInput::from_slices(&high, &low, &close, params.clone());
2919
2920 let output = stochf_with_kernel(&input, kernel).unwrap();
2921 let ref_output = stochf_with_kernel(&input, Kernel::Scalar).unwrap();
2922
2923 for (i, &k_val) in output.k.iter().enumerate() {
2924 if !k_val.is_nan() {
2925 prop_assert!(
2926 k_val >= -1e-9 && k_val <= 100.0 + 1e-9,
2927 "K value out of range at idx {}: {} (should be in [0, 100])",
2928 i,
2929 k_val
2930 );
2931 }
2932 }
2933
2934 for (i, &d_val) in output.d.iter().enumerate() {
2935 if !d_val.is_nan() {
2936 prop_assert!(
2937 d_val >= -1e-9 && d_val <= 100.0 + 1e-9,
2938 "D value out of range at idx {}: {} (should be in [0, 100])",
2939 i,
2940 d_val
2941 );
2942 }
2943 }
2944
2945 let k_warmup = fastk_period - 1;
2946 let d_warmup = fastk_period - 1 + fastd_period - 1;
2947
2948 for i in 0..k_warmup.min(len) {
2949 prop_assert!(
2950 output.k[i].is_nan(),
2951 "K value should be NaN during warmup at idx {}: {}",
2952 i,
2953 output.k[i]
2954 );
2955 }
2956
2957 for i in 0..d_warmup.min(len) {
2958 prop_assert!(
2959 output.d[i].is_nan(),
2960 "D value should be NaN during warmup at idx {}: {}",
2961 i,
2962 output.d[i]
2963 );
2964 }
2965
2966 for i in 0..len {
2967 let k_val = output.k[i];
2968 let k_ref = ref_output.k[i];
2969 let d_val = output.d[i];
2970 let d_ref = ref_output.d[i];
2971
2972 if !k_val.is_nan() && !k_ref.is_nan() {
2973 prop_assert!(
2974 (k_val - k_ref).abs() <= 1e-9,
2975 "K kernel mismatch at idx {}: {} vs {} (diff: {})",
2976 i,
2977 k_val,
2978 k_ref,
2979 (k_val - k_ref).abs()
2980 );
2981 }
2982
2983 if !d_val.is_nan() && !d_ref.is_nan() {
2984 prop_assert!(
2985 (d_val - d_ref).abs() <= 1e-9,
2986 "D kernel mismatch at idx {}: {} vs {} (diff: {})",
2987 i,
2988 d_val,
2989 d_ref,
2990 (d_val - d_ref).abs()
2991 );
2992 }
2993 }
2994
2995 for i in k_warmup..len {
2996 let start = i + 1 - fastk_period;
2997 let window_high = &high[start..=i];
2998 let window_low = &low[start..=i];
2999
3000 let hh = window_high
3001 .iter()
3002 .cloned()
3003 .fold(f64::NEG_INFINITY, f64::max);
3004 let ll = window_low.iter().cloned().fold(f64::INFINITY, f64::min);
3005
3006 let expected_k = if hh == ll {
3007 if close[i] == hh {
3008 100.0
3009 } else {
3010 0.0
3011 }
3012 } else {
3013 100.0 * (close[i] - ll) / (hh - ll)
3014 };
3015
3016 let actual_k = output.k[i];
3017 prop_assert!(
3018 (actual_k - expected_k).abs() <= 1e-9,
3019 "K formula mismatch at idx {}: actual {} vs expected {} (diff: {})",
3020 i,
3021 actual_k,
3022 expected_k,
3023 (actual_k - expected_k).abs()
3024 );
3025 }
3026
3027 for i in d_warmup..len {
3028 let start = i + 1 - fastd_period;
3029 let k_window = &output.k[start..=i];
3030 let expected_d = k_window.iter().sum::<f64>() / (fastd_period as f64);
3031 let actual_d = output.d[i];
3032
3033 prop_assert!(
3034 (actual_d - expected_d).abs() <= 1e-9,
3035 "D SMA mismatch at idx {}: actual {} vs expected {} (diff: {})",
3036 i,
3037 actual_d,
3038 expected_d,
3039 (actual_d - expected_d).abs()
3040 );
3041 }
3042
3043 let const_len = (fastk_period + fastd_period) * 2;
3044 if len > const_len {
3045 let const_price = 100.0;
3046 let const_high = vec![const_price; const_len];
3047 let const_low = vec![const_price; const_len];
3048 let const_close = vec![const_price; const_len];
3049
3050 let const_input = StochfInput::from_slices(
3051 &const_high,
3052 &const_low,
3053 &const_close,
3054 params.clone(),
3055 );
3056 let const_output = stochf_with_kernel(&const_input, kernel).unwrap();
3057
3058 for i in k_warmup..const_high.len() {
3059 prop_assert!(
3060 (const_output.k[i] - 100.0).abs() <= 1e-9,
3061 "Constant price K should be 100 at idx {}: {}",
3062 i,
3063 const_output.k[i]
3064 );
3065 }
3066 }
3067
3068 let extreme_len = (fastk_period + fastd_period) * 2;
3069 if len > extreme_len {
3070 let low_close_high = vec![100.0; extreme_len];
3071 let low_close_low = vec![90.0; extreme_len];
3072 let low_close_close = vec![90.0; extreme_len];
3073
3074 let low_input = StochfInput::from_slices(
3075 &low_close_high,
3076 &low_close_low,
3077 &low_close_close,
3078 params.clone(),
3079 );
3080 let low_output = stochf_with_kernel(&low_input, kernel).unwrap();
3081
3082 for i in k_warmup..low_close_high.len() {
3083 prop_assert!(
3084 low_output.k[i].abs() <= 1e-9,
3085 "When close == low, K should be 0 at idx {}: {}",
3086 i,
3087 low_output.k[i]
3088 );
3089 }
3090
3091 let high_close_high = vec![100.0; extreme_len];
3092 let high_close_low = vec![90.0; extreme_len];
3093 let high_close_close = vec![100.0; extreme_len];
3094
3095 let high_input = StochfInput::from_slices(
3096 &high_close_high,
3097 &high_close_low,
3098 &high_close_close,
3099 params.clone(),
3100 );
3101 let high_output = stochf_with_kernel(&high_input, kernel).unwrap();
3102
3103 for i in k_warmup..high_close_high.len() {
3104 prop_assert!(
3105 (high_output.k[i] - 100.0).abs() <= 1e-9,
3106 "When close == high, K should be 100 at idx {}: {}",
3107 i,
3108 high_output.k[i]
3109 );
3110 }
3111 }
3112
3113 #[cfg(debug_assertions)]
3114 {
3115 for (i, &val) in output.k.iter().enumerate() {
3116 if !val.is_nan() {
3117 let bits = val.to_bits();
3118 prop_assert!(
3119 bits != 0x11111111_11111111
3120 && bits != 0x22222222_22222222
3121 && bits != 0x33333333_33333333,
3122 "Found poison value in K at idx {}: {} (0x{:016X})",
3123 i,
3124 val,
3125 bits
3126 );
3127 }
3128 }
3129
3130 for (i, &val) in output.d.iter().enumerate() {
3131 if !val.is_nan() {
3132 let bits = val.to_bits();
3133 prop_assert!(
3134 bits != 0x11111111_11111111
3135 && bits != 0x22222222_22222222
3136 && bits != 0x33333333_33333333,
3137 "Found poison value in D at idx {}: {} (0x{:016X})",
3138 i,
3139 val,
3140 bits
3141 );
3142 }
3143 }
3144 }
3145
3146 Ok(())
3147 },
3148 )
3149 .unwrap();
3150
3151 Ok(())
3152 }
3153
3154 macro_rules! gen_batch_tests {
3155 ($fn_name:ident) => {
3156 paste::paste! {
3157 #[test] fn [<$fn_name _scalar>]() {
3158 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
3159 }
3160 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3161 #[test] fn [<$fn_name _avx2>]() {
3162 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
3163 }
3164 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3165 #[test] fn [<$fn_name _avx512>]() {
3166 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
3167 }
3168 #[test] fn [<$fn_name _auto_detect>]() {
3169 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
3170 }
3171 }
3172 };
3173 }
3174 #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
3175 #[test]
3176 fn test_wasm_batch_warmup_initialization() {
3177 let high = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
3178 let low = vec![0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5];
3179 let close = vec![0.8, 1.8, 2.8, 3.8, 4.8, 5.8, 6.8, 7.8, 8.8, 9.8];
3180
3181 let mut k_out = vec![999.0; 10];
3182 let mut d_out = vec![999.0; 10];
3183
3184 let result = unsafe {
3185 stochf_batch_into(
3186 high.as_ptr(),
3187 low.as_ptr(),
3188 close.as_ptr(),
3189 k_out.as_mut_ptr(),
3190 d_out.as_mut_ptr(),
3191 10,
3192 3,
3193 3,
3194 0,
3195 2,
3196 2,
3197 0,
3198 0,
3199 )
3200 };
3201
3202 assert!(result.is_ok());
3203 assert_eq!(result.unwrap(), 1);
3204
3205 assert!(k_out[0].is_nan(), "K[0] should be NaN");
3206 assert!(k_out[1].is_nan(), "K[1] should be NaN");
3207 assert!(!k_out[2].is_nan(), "K[2] should have a value");
3208
3209 assert!(d_out[0].is_nan(), "D[0] should be NaN");
3210 assert!(d_out[1].is_nan(), "D[1] should be NaN");
3211 assert!(d_out[2].is_nan(), "D[2] should be NaN");
3212 assert!(!d_out[3].is_nan(), "D[3] should have a value");
3213 }
3214
3215 #[test]
3216 fn test_batch_invalid_output_size() {
3217 let high = vec![10.0, 20.0, 30.0, 40.0, 50.0];
3218 let low = vec![5.0, 15.0, 25.0, 35.0, 45.0];
3219 let close = vec![7.0, 17.0, 27.0, 37.0, 47.0];
3220
3221 let sweep = StochfBatchRange {
3222 fastk_period: (3, 4, 1),
3223 fastd_period: (2, 2, 0),
3224 };
3225
3226 let mut k_out = vec![0.0; 5];
3227 let mut d_out = vec![0.0; 5];
3228
3229 let result = stochf_batch_inner_into(
3230 &high,
3231 &low,
3232 &close,
3233 &sweep,
3234 Kernel::Scalar,
3235 false,
3236 &mut k_out,
3237 &mut d_out,
3238 );
3239
3240 assert!(matches!(
3241 result,
3242 Err(StochfError::OutputLengthMismatch {
3243 expected: 10,
3244 k_got: 5,
3245 d_got: 5
3246 })
3247 ));
3248
3249 let mut k_out = vec![0.0; 10];
3250 let mut d_out = vec![0.0; 8];
3251
3252 let result = stochf_batch_inner_into(
3253 &high,
3254 &low,
3255 &close,
3256 &sweep,
3257 Kernel::Scalar,
3258 false,
3259 &mut k_out,
3260 &mut d_out,
3261 );
3262
3263 assert!(matches!(
3264 result,
3265 Err(StochfError::OutputLengthMismatch {
3266 expected: 10,
3267 k_got: 10,
3268 d_got: 8
3269 })
3270 ));
3271
3272 let mut k_out = vec![0.0; 10];
3273 let mut d_out = vec![0.0; 10];
3274
3275 let result = stochf_batch_inner_into(
3276 &high,
3277 &low,
3278 &close,
3279 &sweep,
3280 Kernel::Scalar,
3281 false,
3282 &mut k_out,
3283 &mut d_out,
3284 );
3285
3286 assert!(result.is_ok());
3287 }
3288
3289 #[test]
3290 fn test_stochf_into_matches_api() {
3291 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3292 let candles = read_candles_from_csv(file_path).expect("failed to read csv");
3293 let input = StochfInput::with_default_candles(&candles);
3294
3295 let base = stochf(&input).expect("baseline stochf failed");
3296
3297 let len = candles.close.len();
3298 let mut k_out = vec![0.0f64; len];
3299 let mut d_out = vec![0.0f64; len];
3300 stochf_into(&input, &mut k_out, &mut d_out).expect("stochf_into failed");
3301
3302 assert_eq!(base.k.len(), k_out.len());
3303 assert_eq!(base.d.len(), d_out.len());
3304
3305 fn eq_or_nan(a: f64, b: f64) -> bool {
3306 (a.is_nan() && b.is_nan()) || (a == b)
3307 }
3308
3309 for i in 0..len {
3310 assert!(
3311 eq_or_nan(base.k[i], k_out[i]),
3312 "K mismatch at {}: base={:?} into={:?}",
3313 i,
3314 base.k[i],
3315 k_out[i]
3316 );
3317 assert!(
3318 eq_or_nan(base.d[i], d_out[i]),
3319 "D mismatch at {}: base={:?} into={:?}",
3320 i,
3321 base.d[i],
3322 d_out[i]
3323 );
3324 }
3325 }
3326
    // Instantiate per-kernel #[test] wrappers (scalar, AVX2/AVX-512 when enabled, auto)
    // for each batch check defined above.
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
3329}
3330
/// wasm entry point: computes Fast Stochastic %K and %D with automatic kernel selection.
///
/// The result is one flat vector, all %K values followed by all %D values, so JS
/// callers split it in half. Any indicator error is returned as its string form.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn stochf_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    fastk_period: usize,
    fastd_period: usize,
    fastd_matype: usize,
) -> Result<Vec<f64>, JsValue> {
    let input = StochfInput::from_slices(
        high,
        low,
        close,
        StochfParams {
            fastk_period: Some(fastk_period),
            fastd_period: Some(fastd_period),
            fastd_matype: Some(fastd_matype),
        },
    );

    match stochf_with_kernel(&input, Kernel::Auto) {
        Ok(StochfOutput { k, d }) => {
            let mut packed = Vec::with_capacity(2 * k.len());
            packed.extend(k.iter().chain(d.iter()).copied());
            Ok(packed)
        }
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
3356
/// Allocates an uninitialized buffer with room for `len` `f64`s and hands the raw
/// pointer to JS. The memory must be released with `stochf_free(ptr, len)`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn stochf_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop keeps the allocation alive after this function returns.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
3365
/// Releases a buffer previously obtained from `stochf_alloc(len)`. Null is ignored.
///
/// The vector is rebuilt with length 0 and capacity `len`. `stochf_alloc` hands out
/// uninitialized memory, and `Vec::from_raw_parts` requires the first `length`
/// elements to be initialized, so claiming `len` elements here would be UB whenever
/// JS never wrote the whole buffer. `f64` has no destructor, so dropping zero
/// elements frees exactly the same allocation.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn stochf_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` came from `stochf_alloc(len)`, i.e. a `Vec<f64>` with capacity
    // `len` that was forgotten. Length 0 makes no claim about initialization.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
3375
/// wasm entry point: computes Fast Stochastic %K/%D from raw buffers in wasm memory.
///
/// Reads `len` values from each of `high_ptr`, `low_ptr`, `close_ptr` and writes `len`
/// values to each of `k_out_ptr` and `d_out_ptr`.
///
/// Inputs may share memory with the outputs, for example when JS reuses a buffer in
/// place. In that case results are computed into temporaries and copied out, so no
/// `&[f64]` and `&mut [f64]` ever cover the same memory at once.
///
/// # Errors
/// Returns a string `JsValue` if any pointer is null, if the two output ranges
/// overlap, or if the indicator computation fails.
///
/// Caller contract: every pointer must reference at least `len` valid `f64`s.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn stochf_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    k_out_ptr: *mut f64,
    d_out_ptr: *mut f64,
    len: usize,
    fastk_period: usize,
    fastd_period: usize,
    fastd_matype: usize,
) -> Result<(), JsValue> {
    /// True when the `f64` ranges `[a, a + a_len)` and `[b, b + b_len)` intersect.
    fn spans_overlap(a: *const f64, a_len: usize, b: *const f64, b_len: usize) -> bool {
        let width = std::mem::size_of::<f64>();
        let a_start = a as usize;
        let a_end = a_start.saturating_add(a_len.saturating_mul(width));
        let b_start = b as usize;
        let b_end = b_start.saturating_add(b_len.saturating_mul(width));
        a_start < b_end && b_start < a_end
    }

    if high_ptr.is_null()
        || low_ptr.is_null()
        || close_ptr.is_null()
        || k_out_ptr.is_null()
        || d_out_ptr.is_null()
    {
        return Err(JsValue::from_str("null pointer passed to stochf_into"));
    }

    let k_dst = k_out_ptr as *const f64;
    let d_dst = d_out_ptr as *const f64;

    // Two overlapping `&mut` slices would be UB, so reject this up front.
    if spans_overlap(k_dst, len, d_dst, len) {
        return Err(JsValue::from_str(
            "k_out_ptr and d_out_ptr must not overlap in stochf_into",
        ));
    }

    let aliased = [high_ptr, low_ptr, close_ptr]
        .iter()
        .any(|&src| spans_overlap(src, len, k_dst, len) || spans_overlap(src, len, d_dst, len));

    let params = StochfParams {
        fastk_period: Some(fastk_period),
        fastd_period: Some(fastd_period),
        fastd_matype: Some(fastd_matype),
    };

    // SAFETY: the caller guarantees each pointer is valid for `len` elements. Output
    // slices are only created when they cannot alias the live input slices, or after
    // the computation that reads the inputs has finished.
    unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);
        let input = StochfInput::from_slices(high, low, close, params);
        let kernel = detect_best_kernel();

        if aliased {
            let mut k_tmp = vec![0.0; len];
            let mut d_tmp = vec![0.0; len];
            stochf_into_slice(&mut k_tmp, &mut d_tmp, &input, kernel)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            std::slice::from_raw_parts_mut(k_out_ptr, len).copy_from_slice(&k_tmp);
            std::slice::from_raw_parts_mut(d_out_ptr, len).copy_from_slice(&d_tmp);
        } else {
            let k_out = std::slice::from_raw_parts_mut(k_out_ptr, len);
            let d_out = std::slice::from_raw_parts_mut(d_out_ptr, len);
            stochf_into_slice(k_out, d_out, &input, kernel)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(())
}
3415
/// Batch sweep configuration deserialized from the JS `config` object passed to
/// `stochf_batch`. Each range is a `(start, end, step)` tuple, forwarded unchanged
/// into `StochfBatchRange`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct StochfBatchConfig {
    /// `(start, end, step)` for the %K lookback period.
    pub fastk_range: (usize, usize, usize),
    /// `(start, end, step)` for the %D smoothing period.
    pub fastd_range: (usize, usize, usize),
}
3422
/// Serialized result of `stochf_batch` returned to JS.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct StochfBatchJsOutput {
    /// Flattened %K matrix, row-major (`rows * cols` values; row index = idx / cols).
    pub k_values: Vec<f64>,
    /// Flattened %D matrix, same layout as `k_values`.
    pub d_values: Vec<f64>,
    /// Number of parameter combinations, one matrix row per entry in `combos`.
    pub rows: usize,
    /// Number of columns per row; presumably the input series length (TODO confirm).
    pub cols: usize,
    /// Parameter set used for each row, in row order.
    pub combos: Vec<StochfParams>,
}
3432
/// wasm entry point exported as `stochf_batch`: sweeps %K/%D periods over one series.
///
/// `config` must deserialize into `StochfBatchConfig`. The flattened matrices, their
/// dimensions and the per-row parameters are returned as a `StochfBatchJsOutput`.
/// Deserialization, computation and serialization failures come back as string errors.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = stochf_batch)]
pub fn stochf_batch_unified_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let StochfBatchConfig {
        fastk_range,
        fastd_range,
    } = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let batch = stochf_batch_inner(
        high,
        low,
        close,
        &StochfBatchRange {
            fastk_period: fastk_range,
            fastd_period: fastd_range,
        },
        detect_best_kernel(),
        false,
    )
    .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let payload = StochfBatchJsOutput {
        k_values: batch.k,
        d_values: batch.d,
        rows: batch.rows,
        cols: batch.cols,
        combos: batch.combos,
    };

    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
3462
/// wasm entry point: batch Fast Stochastic over raw buffers in wasm memory.
///
/// Reads `len` values from each input pointer and sweeps
/// `fastk_start..=fastk_end` (step `fastk_step`) × `fastd_start..=fastd_end`
/// (step `fastd_step`), as expanded by `expand_grid`. It writes a row-major
/// `rows × len` matrix to each output pointer and returns `rows`.
///
/// Inputs may overlap the outputs. That case is detected by address range, not just
/// start pointer, and handled by computing into temporaries and copying out.
///
/// # Errors
/// Returns a string `JsValue` when:
/// - any pointer is null;
/// - `rows * len` overflows;
/// - the two output ranges overlap;
/// - the batch computation fails.
///
/// Caller contract: inputs must hold `len` values and each output `rows * len` values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn stochf_batch_into(
    in_high_ptr: *const f64,
    in_low_ptr: *const f64,
    in_close_ptr: *const f64,
    out_k_ptr: *mut f64,
    out_d_ptr: *mut f64,
    len: usize,
    fastk_start: usize,
    fastk_end: usize,
    fastk_step: usize,
    fastd_start: usize,
    fastd_end: usize,
    fastd_step: usize,
    fastd_matype: usize,
) -> Result<usize, JsValue> {
    /// True when the `f64` ranges `[a, a + a_len)` and `[b, b + b_len)` intersect.
    fn spans_overlap(a: *const f64, a_len: usize, b: *const f64, b_len: usize) -> bool {
        let width = std::mem::size_of::<f64>();
        let a_start = a as usize;
        let a_end = a_start.saturating_add(a_len.saturating_mul(width));
        let b_start = b as usize;
        let b_end = b_start.saturating_add(b_len.saturating_mul(width));
        a_start < b_end && b_start < a_end
    }

    if in_high_ptr.is_null()
        || in_low_ptr.is_null()
        || in_close_ptr.is_null()
        || out_k_ptr.is_null()
        || out_d_ptr.is_null()
    {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // NOTE(review): `fastd_matype` is accepted for signature parity with `stochf_into`
    // but is not forwarded. The `StochfBatchRange` built below carries only the two
    // period ranges. Confirm whether non-zero values should be rejected.
    let _ = fastd_matype;

    let sweep = StochfBatchRange {
        fastk_period: (fastk_start, fastk_end, fastk_step),
        fastd_period: (fastd_start, fastd_end, fastd_step),
    };

    let rows = expand_grid(&sweep).len();
    let cols = len;
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("rows * cols overflows usize"))?;

    let k_dst = out_k_ptr as *const f64;
    let d_dst = out_d_ptr as *const f64;

    // Two overlapping `&mut` output slices would be UB, so reject this up front.
    if spans_overlap(k_dst, total, d_dst, total) {
        return Err(JsValue::from_str("out_k_ptr and out_d_ptr must not overlap"));
    }

    let aliasing = [in_high_ptr, in_low_ptr, in_close_ptr].iter().any(|&src| {
        spans_overlap(src, len, k_dst, total) || spans_overlap(src, len, d_dst, total)
    });

    // Batch detection yields a `*Batch` kernel; the inner writer takes the
    // per-series variant.
    let kernel = match detect_best_batch_kernel() {
        Kernel::Avx512Batch => Kernel::Avx512,
        Kernel::Avx2Batch => Kernel::Avx2,
        _ => Kernel::Scalar,
    };

    // SAFETY: the caller guarantees the input pointers are valid for `len` elements
    // and the output pointers for `total` elements. Output slices are only created
    // when they cannot alias the input slices, or after the computation that reads
    // the inputs has completed.
    unsafe {
        let high = std::slice::from_raw_parts(in_high_ptr, len);
        let low = std::slice::from_raw_parts(in_low_ptr, len);
        let close = std::slice::from_raw_parts(in_close_ptr, len);

        if aliasing {
            let mut temp_k = vec![0.0; total];
            let mut temp_d = vec![0.0; total];

            stochf_batch_inner_into(
                high,
                low,
                close,
                &sweep,
                kernel,
                false,
                &mut temp_k,
                &mut temp_d,
            )
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

            std::slice::from_raw_parts_mut(out_k_ptr, total).copy_from_slice(&temp_k);
            std::slice::from_raw_parts_mut(out_d_ptr, total).copy_from_slice(&temp_d);
        } else {
            let out_k_slice = std::slice::from_raw_parts_mut(out_k_ptr, total);
            let out_d_slice = std::slice::from_raw_parts_mut(out_d_ptr, total);

            stochf_batch_inner_into(
                high,
                low,
                close,
                &sweep,
                kernel,
                false,
                out_k_slice,
                out_d_slice,
            )
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(rows)
}