1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
3#[cfg(feature = "python")]
4use pyo3::types::PyDict;
5#[cfg(feature = "python")]
6use pyo3::{exceptions::PyValueError, prelude::*};
7
8#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
9use serde::{Deserialize, Serialize};
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use wasm_bindgen::prelude::*;
12
13use crate::indicators::moving_averages::ma::{ma, MaData};
14use crate::indicators::utility_functions::{max_rolling, min_rolling};
15use crate::utilities::data_loader::{source_type, Candles};
16use crate::utilities::enums::Kernel;
17use crate::utilities::helpers::{
18 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
19 make_uninit_matrix,
20};
21#[cfg(feature = "python")]
22use crate::utilities::kernel_validation::validate_kernel;
23use aligned_vec::{AVec, CACHELINE_ALIGN};
24#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
25use core::arch::x86_64::*;
26#[cfg(not(target_arch = "wasm32"))]
27use rayon::prelude::*;
28use std::collections::VecDeque;
29use std::convert::AsRef;
30use std::error::Error;
31use thiserror::Error;
32
33#[derive(Debug, Clone)]
34pub enum StochData<'a> {
35 Candles {
36 candles: &'a Candles,
37 },
38 Slices {
39 high: &'a [f64],
40 low: &'a [f64],
41 close: &'a [f64],
42 },
43}
44
/// Output of a stochastic computation.
///
/// `k` is the smoothed (slow) %K line and `d` is its %D signal line. Both are indexed like the
/// input bars.
#[derive(Debug, Clone)]
pub struct StochOutput {
    /// Slow %K values.
    pub k: Vec<f64>,
    /// %D values (moving average of slow %K).
    pub d: Vec<f64>,
}
50
/// Tunable parameters of the stochastic oscillator.
///
/// A `None` field means "use the default": the `StochInput::get_*` accessors fall back to
/// 14 / 3 / "sma" / 3 / "sma".
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct StochParams {
    /// Look-back window for the highest high / lowest low of raw %K.
    pub fastk_period: Option<usize>,
    /// Smoothing window turning raw %K into slow %K.
    pub slowk_period: Option<usize>,
    /// Moving-average type name used for slow %K.
    pub slowk_ma_type: Option<String>,
    /// Smoothing window turning slow %K into %D.
    pub slowd_period: Option<usize>,
    /// Moving-average type name used for %D.
    pub slowd_ma_type: Option<String>,
}
63
64impl Default for StochParams {
65 fn default() -> Self {
66 Self {
67 fastk_period: Some(14),
68 slowk_period: Some(3),
69 slowk_ma_type: Some("sma".to_string()),
70 slowd_period: Some(3),
71 slowd_ma_type: Some("sma".to_string()),
72 }
73 }
74}
75
76#[derive(Debug, Clone)]
77pub struct StochInput<'a> {
78 pub data: StochData<'a>,
79 pub params: StochParams,
80}
81
82impl<'a> StochInput<'a> {
83 #[inline]
84 pub fn from_candles(c: &'a Candles, p: StochParams) -> Self {
85 Self {
86 data: StochData::Candles { candles: c },
87 params: p,
88 }
89 }
90 #[inline]
91 pub fn from_slices(high: &'a [f64], low: &'a [f64], close: &'a [f64], p: StochParams) -> Self {
92 Self {
93 data: StochData::Slices { high, low, close },
94 params: p,
95 }
96 }
97 #[inline]
98 pub fn with_default_candles(c: &'a Candles) -> Self {
99 Self::from_candles(c, StochParams::default())
100 }
101 #[inline]
102 pub fn get_fastk_period(&self) -> usize {
103 self.params.fastk_period.unwrap_or(14)
104 }
105 #[inline]
106 pub fn get_slowk_period(&self) -> usize {
107 self.params.slowk_period.unwrap_or(3)
108 }
109 #[inline]
110 pub fn get_slowk_ma_type(&self) -> String {
111 self.params
112 .slowk_ma_type
113 .clone()
114 .unwrap_or_else(|| "sma".to_string())
115 }
116 #[inline]
117 pub fn get_slowd_period(&self) -> usize {
118 self.params.slowd_period.unwrap_or(3)
119 }
120 #[inline]
121 pub fn get_slowd_ma_type(&self) -> String {
122 self.params
123 .slowd_ma_type
124 .clone()
125 .unwrap_or_else(|| "sma".to_string())
126 }
127}
128
129#[derive(Copy, Clone, Debug)]
130pub struct StochBuilder {
131 fastk_period: Option<usize>,
132 slowk_period: Option<usize>,
133 slowk_ma_type: Option<&'static str>,
134 slowd_period: Option<usize>,
135 slowd_ma_type: Option<&'static str>,
136 kernel: Kernel,
137}
138
139impl Default for StochBuilder {
140 fn default() -> Self {
141 Self {
142 fastk_period: None,
143 slowk_period: None,
144 slowk_ma_type: None,
145 slowd_period: None,
146 slowd_ma_type: None,
147 kernel: Kernel::Auto,
148 }
149 }
150}
151
152impl StochBuilder {
153 #[inline(always)]
154 pub fn new() -> Self {
155 Self::default()
156 }
157 #[inline(always)]
158 pub fn fastk_period(mut self, n: usize) -> Self {
159 self.fastk_period = Some(n);
160 self
161 }
162 #[inline(always)]
163 pub fn slowk_period(mut self, n: usize) -> Self {
164 self.slowk_period = Some(n);
165 self
166 }
167 #[inline(always)]
168 pub fn slowk_ma_type(mut self, t: &'static str) -> Self {
169 self.slowk_ma_type = Some(t);
170 self
171 }
172 #[inline(always)]
173 pub fn slowd_period(mut self, n: usize) -> Self {
174 self.slowd_period = Some(n);
175 self
176 }
177 #[inline(always)]
178 pub fn slowd_ma_type(mut self, t: &'static str) -> Self {
179 self.slowd_ma_type = Some(t);
180 self
181 }
182 #[inline(always)]
183 pub fn kernel(mut self, k: Kernel) -> Self {
184 self.kernel = k;
185 self
186 }
187
188 #[inline(always)]
189 pub fn apply(self, c: &Candles) -> Result<StochOutput, StochError> {
190 let p = StochParams {
191 fastk_period: self.fastk_period,
192 slowk_period: self.slowk_period,
193 slowk_ma_type: self.slowk_ma_type.map(|s| s.to_string()),
194 slowd_period: self.slowd_period,
195 slowd_ma_type: self.slowd_ma_type.map(|s| s.to_string()),
196 };
197 let i = StochInput::from_candles(c, p);
198 stoch_with_kernel(&i, self.kernel)
199 }
200 #[inline(always)]
201 pub fn apply_slices(
202 self,
203 high: &[f64],
204 low: &[f64],
205 close: &[f64],
206 ) -> Result<StochOutput, StochError> {
207 let p = StochParams {
208 fastk_period: self.fastk_period,
209 slowk_period: self.slowk_period,
210 slowk_ma_type: self.slowk_ma_type.map(|s| s.to_string()),
211 slowd_period: self.slowd_period,
212 slowd_ma_type: self.slowd_ma_type.map(|s| s.to_string()),
213 };
214 let i = StochInput::from_slices(high, low, close, p);
215 stoch_with_kernel(&i, self.kernel)
216 }
217 #[inline(always)]
218 pub fn into_stream(self) -> Result<StochStream, StochError> {
219 let p = StochParams {
220 fastk_period: self.fastk_period,
221 slowk_period: self.slowk_period,
222 slowk_ma_type: self.slowk_ma_type.map(|s| s.to_string()),
223 slowd_period: self.slowd_period,
224 slowd_ma_type: self.slowd_ma_type.map(|s| s.to_string()),
225 };
226 StochStream::try_new(p)
227 }
228}
229
230#[derive(Debug, Error)]
231pub enum StochError {
232 #[error("stoch: Empty data provided.")]
233 EmptyInputData,
234 #[error("stoch: Mismatched length.")]
235 MismatchedLength,
236 #[error("stoch: Invalid period: period = {period}, data length = {data_len}")]
237 InvalidPeriod { period: usize, data_len: usize },
238 #[error("stoch: Not enough valid data: needed = {needed}, valid = {valid}")]
239 NotEnoughValidData { needed: usize, valid: usize },
240 #[error("stoch: All values are NaN.")]
241 AllValuesNaN,
242 #[error("stoch: Output length mismatch: expected {expected}, got {got}")]
243 OutputLengthMismatch { expected: usize, got: usize },
244 #[error("stoch: Invalid range: start={start}, end={end}, step={step}")]
245 InvalidRange {
246 start: usize,
247 end: usize,
248 step: usize,
249 },
250 #[error("stoch: Invalid kernel for batch: {0:?}")]
251 InvalidKernelForBatch(Kernel),
252 #[error("stoch: {0}")]
253 Other(String),
254}
255
256#[inline]
257pub fn stoch(input: &StochInput) -> Result<StochOutput, StochError> {
258 stoch_with_kernel(input, Kernel::Auto)
259}
260
261pub fn stoch_with_kernel(input: &StochInput, kernel: Kernel) -> Result<StochOutput, StochError> {
262 let (high, low, close) = match &input.data {
263 StochData::Candles { candles } => {
264 let high = candles
265 .select_candle_field("high")
266 .map_err(|e| StochError::Other(e.to_string()))?;
267 let low = candles
268 .select_candle_field("low")
269 .map_err(|e| StochError::Other(e.to_string()))?;
270 let close = candles
271 .select_candle_field("close")
272 .map_err(|e| StochError::Other(e.to_string()))?;
273 (high, low, close)
274 }
275 StochData::Slices { high, low, close } => (*high, *low, *close),
276 };
277
278 let data_len = high.len();
279 if data_len == 0 || low.is_empty() || close.is_empty() {
280 return Err(StochError::EmptyInputData);
281 }
282 if data_len != low.len() || data_len != close.len() {
283 return Err(StochError::MismatchedLength);
284 }
285
286 let fastk_period = input.get_fastk_period();
287 let slowk_period = input.get_slowk_period();
288 let slowd_period = input.get_slowd_period();
289
290 if fastk_period == 0 || fastk_period > data_len {
291 return Err(StochError::InvalidPeriod {
292 period: fastk_period,
293 data_len,
294 });
295 }
296 if slowk_period == 0 || slowk_period > data_len {
297 return Err(StochError::InvalidPeriod {
298 period: slowk_period,
299 data_len,
300 });
301 }
302 if slowd_period == 0 || slowd_period > data_len {
303 return Err(StochError::InvalidPeriod {
304 period: slowd_period,
305 data_len,
306 });
307 }
308
309 let first_valid_idx = high
310 .iter()
311 .zip(low.iter())
312 .zip(close.iter())
313 .position(|((h, l), c)| !h.is_nan() && !l.is_nan() && !c.is_nan())
314 .ok_or(StochError::AllValuesNaN)?;
315
316 if (data_len - first_valid_idx) < fastk_period {
317 return Err(StochError::NotEnoughValidData {
318 needed: fastk_period,
319 valid: data_len - first_valid_idx,
320 });
321 }
322
323 let mut hh = alloc_with_nan_prefix(data_len, first_valid_idx + fastk_period - 1);
324 let mut ll = alloc_with_nan_prefix(data_len, first_valid_idx + fastk_period - 1);
325
326 let max_vals = max_rolling(&high[first_valid_idx..], fastk_period)
327 .map_err(|e| StochError::Other(e.to_string()))?;
328 let min_vals = min_rolling(&low[first_valid_idx..], fastk_period)
329 .map_err(|e| StochError::Other(e.to_string()))?;
330
331 for (i, &val) in max_vals.iter().enumerate() {
332 hh[i + first_valid_idx] = val;
333 }
334 for (i, &val) in min_vals.iter().enumerate() {
335 ll[i + first_valid_idx] = val;
336 }
337
338 let mut k_raw = alloc_with_nan_prefix(data_len, first_valid_idx + fastk_period - 1);
339
340 let chosen = match kernel {
341 Kernel::Auto => Kernel::Scalar,
342 other => other,
343 };
344 unsafe {
345 match chosen {
346 Kernel::Scalar | Kernel::ScalarBatch => stoch_scalar(
347 high,
348 low,
349 close,
350 &hh,
351 &ll,
352 fastk_period,
353 first_valid_idx,
354 &mut k_raw,
355 ),
356 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
357 Kernel::Avx2 | Kernel::Avx2Batch => stoch_avx2(
358 high,
359 low,
360 close,
361 &hh,
362 &ll,
363 fastk_period,
364 first_valid_idx,
365 &mut k_raw,
366 ),
367 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
368 Kernel::Avx512 | Kernel::Avx512Batch => stoch_avx512(
369 high,
370 low,
371 close,
372 &hh,
373 &ll,
374 fastk_period,
375 first_valid_idx,
376 &mut k_raw,
377 ),
378 _ => unreachable!(),
379 }
380 }
381
382 let slowk_ma_type = input.get_slowk_ma_type();
383 let slowd_ma_type = input.get_slowd_ma_type();
384
385 let k_first_valid = first_valid_idx + fastk_period - 1;
386 if (slowk_ma_type == "sma" || slowk_ma_type == "SMA")
387 && (slowd_ma_type == "sma" || slowd_ma_type == "SMA")
388 {
389 return stoch_classic_sma(&k_raw, slowk_period, slowd_period, k_first_valid);
390 } else if (slowk_ma_type == "ema" || slowk_ma_type == "EMA")
391 && (slowd_ma_type == "ema" || slowd_ma_type == "EMA")
392 {
393 return stoch_classic_ema(&k_raw, slowk_period, slowd_period, k_first_valid);
394 }
395
396 let k_vec = ma(&slowk_ma_type, MaData::Slice(&k_raw), slowk_period)
397 .map_err(|e| StochError::Other(e.to_string()))?;
398 let d_vec = ma(&slowd_ma_type, MaData::Slice(&k_vec), slowd_period)
399 .map_err(|e| StochError::Other(e.to_string()))?;
400 Ok(StochOutput { k: k_vec, d: d_vec })
401}
402
403pub fn stoch_into_slices(
404 out_k: &mut [f64],
405 out_d: &mut [f64],
406 input: &StochInput,
407 kernel: Kernel,
408) -> Result<(), StochError> {
409 let StochOutput { k, d } = stoch_with_kernel(input, kernel)?;
410 if out_k.len() != k.len() {
411 return Err(StochError::OutputLengthMismatch {
412 expected: k.len(),
413 got: out_k.len(),
414 });
415 }
416 if out_d.len() != d.len() {
417 return Err(StochError::OutputLengthMismatch {
418 expected: d.len(),
419 got: out_d.len(),
420 });
421 }
422 out_k.copy_from_slice(&k);
423 out_d.copy_from_slice(&d);
424 Ok(())
425}
426
/// Overwrites the first `warm` elements of `dst` (clamped to its length) with a quiet NaN.
///
/// The NaN is built from the explicit bit pattern `0x7ff8_0000_0000_0000`, so every warm-up
/// slot carries identical bits.
#[inline]
fn prefill_nan_prefix(dst: &mut [f64], warm: usize) {
    let quiet_nan = f64::from_bits(0x7ff8_0000_0000_0000);
    let end = dst.len().min(warm);
    dst[..end].fill(quiet_nan);
}
434
435#[inline]
436fn stoch_compute_into(
437 input: &StochInput,
438 out_k: &mut [f64],
439 out_d: &mut [f64],
440 kernel: Kernel,
441) -> Result<(), StochError> {
442 let (high, low, close) = match &input.data {
443 StochData::Candles { candles } => {
444 let high = candles
445 .select_candle_field("high")
446 .map_err(|e| StochError::Other(e.to_string()))?;
447 let low = candles
448 .select_candle_field("low")
449 .map_err(|e| StochError::Other(e.to_string()))?;
450 let close = candles
451 .select_candle_field("close")
452 .map_err(|e| StochError::Other(e.to_string()))?;
453 (high, low, close)
454 }
455 StochData::Slices { high, low, close } => (*high, *low, *close),
456 };
457
458 let len = high.len();
459 if len == 0 || low.is_empty() || close.is_empty() {
460 return Err(StochError::EmptyInputData);
461 }
462 if len != low.len() || len != close.len() {
463 return Err(StochError::MismatchedLength);
464 }
465 if out_k.len() != len {
466 return Err(StochError::OutputLengthMismatch {
467 expected: len,
468 got: out_k.len(),
469 });
470 }
471 if out_d.len() != len {
472 return Err(StochError::OutputLengthMismatch {
473 expected: len,
474 got: out_d.len(),
475 });
476 }
477
478 let fastk_period = input.get_fastk_period();
479 let slowk_period = input.get_slowk_period();
480 let slowd_period = input.get_slowd_period();
481
482 if fastk_period == 0 || fastk_period > len {
483 return Err(StochError::InvalidPeriod {
484 period: fastk_period,
485 data_len: len,
486 });
487 }
488 if slowk_period == 0 || slowk_period > len {
489 return Err(StochError::InvalidPeriod {
490 period: slowk_period,
491 data_len: len,
492 });
493 }
494 if slowd_period == 0 || slowd_period > len {
495 return Err(StochError::InvalidPeriod {
496 period: slowd_period,
497 data_len: len,
498 });
499 }
500
501 let first = high
502 .iter()
503 .zip(low.iter())
504 .zip(close.iter())
505 .position(|((h, l), c)| !h.is_nan() && !l.is_nan() && !c.is_nan())
506 .ok_or(StochError::AllValuesNaN)?;
507
508 if (len - first) < fastk_period {
509 return Err(StochError::NotEnoughValidData {
510 needed: fastk_period,
511 valid: len - first,
512 });
513 }
514
515 let slowk_ma_type = input.get_slowk_ma_type();
516 let slowd_ma_type = input.get_slowd_ma_type();
517 let chosen = match kernel {
518 Kernel::Auto => Kernel::Scalar,
519 other => other,
520 };
521
522 if (slowk_ma_type == "sma" || slowk_ma_type == "SMA")
523 && (slowd_ma_type == "sma" || slowd_ma_type == "SMA")
524 && matches!(chosen, Kernel::Scalar | Kernel::ScalarBatch)
525 {
526 return stoch_classic_sma_into_single_pass(
527 high,
528 low,
529 close,
530 fastk_period,
531 slowk_period,
532 slowd_period,
533 first,
534 out_k,
535 out_d,
536 );
537 }
538
539 let mut hh = alloc_with_nan_prefix(len, first + fastk_period - 1);
540 let mut ll = alloc_with_nan_prefix(len, first + fastk_period - 1);
541 let highs =
542 max_rolling(&high[first..], fastk_period).map_err(|e| StochError::Other(e.to_string()))?;
543 let lows =
544 min_rolling(&low[first..], fastk_period).map_err(|e| StochError::Other(e.to_string()))?;
545 for (i, &v) in highs.iter().enumerate() {
546 hh[first + i] = v;
547 }
548 for (i, &v) in lows.iter().enumerate() {
549 ll[first + i] = v;
550 }
551
552 let mut k_raw = alloc_with_nan_prefix(len, first + fastk_period - 1);
553
554 unsafe {
555 match chosen {
556 Kernel::Scalar | Kernel::ScalarBatch => {
557 stoch_scalar(high, low, close, &hh, &ll, fastk_period, first, &mut k_raw)
558 }
559 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
560 Kernel::Avx2 | Kernel::Avx2Batch => {
561 stoch_avx2(high, low, close, &hh, &ll, fastk_period, first, &mut k_raw)
562 }
563 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
564 Kernel::Avx512 | Kernel::Avx512Batch => {
565 stoch_avx512(high, low, close, &hh, &ll, fastk_period, first, &mut k_raw)
566 }
567 _ => unreachable!(),
568 }
569 }
570
571 let k_first_valid = first + fastk_period - 1;
572
573 if (slowk_ma_type == "sma" || slowk_ma_type == "SMA")
574 && (slowd_ma_type == "sma" || slowd_ma_type == "SMA")
575 {
576 prefill_nan_prefix(out_k, k_first_valid + slowk_period - 1);
577 prefill_nan_prefix(out_d, k_first_valid + slowk_period + slowd_period - 2);
578
579 let mut sum_k = 0.0;
580 let k_start = k_first_valid;
581 for i in k_start..(k_start + slowk_period).min(len) {
582 if !k_raw[i].is_nan() {
583 sum_k += k_raw[i];
584 }
585 }
586 if k_start + slowk_period - 1 < len {
587 out_k[k_start + slowk_period - 1] = sum_k / slowk_period as f64;
588 }
589 for i in (k_start + slowk_period)..len {
590 let old = k_raw[i - slowk_period];
591 let newv = k_raw[i];
592 if !old.is_nan() {
593 sum_k -= old;
594 }
595 if !newv.is_nan() {
596 sum_k += newv;
597 }
598 out_k[i] = sum_k / slowk_period as f64;
599 }
600
601 let mut sum_d = 0.0;
602 let d_start = k_first_valid + slowk_period - 1;
603 for i in d_start..(d_start + slowd_period).min(len) {
604 if !out_k[i].is_nan() {
605 sum_d += out_k[i];
606 }
607 }
608 if d_start + slowd_period - 1 < len {
609 out_d[d_start + slowd_period - 1] = sum_d / slowd_period as f64;
610 }
611 for i in (d_start + slowd_period)..len {
612 let old = out_k[i - slowd_period];
613 let newv = out_k[i];
614 if !old.is_nan() {
615 sum_d -= old;
616 }
617 if !newv.is_nan() {
618 sum_d += newv;
619 }
620 out_d[i] = sum_d / slowd_period as f64;
621 }
622 return Ok(());
623 }
624
625 if (slowk_ma_type == "ema" || slowk_ma_type == "EMA")
626 && (slowd_ma_type == "ema" || slowd_ma_type == "EMA")
627 {
628 prefill_nan_prefix(out_k, k_first_valid + slowk_period - 1);
629 prefill_nan_prefix(out_d, k_first_valid + slowk_period + slowd_period - 2);
630
631 let alpha_k = 2.0 / (slowk_period as f64 + 1.0);
632 let one_minus_alpha_k = 1.0 - alpha_k;
633 let k_warm = k_first_valid + slowk_period - 1;
634 let mut sum_k = 0.0;
635 let mut cnt_k = 0;
636 for i in k_first_valid..(k_first_valid + slowk_period).min(len) {
637 if !k_raw[i].is_nan() {
638 sum_k += k_raw[i];
639 cnt_k += 1;
640 }
641 }
642 if cnt_k > 0 && k_warm < len {
643 let mut ema_k = sum_k / cnt_k as f64;
644 out_k[k_warm] = ema_k;
645 for i in (k_warm + 1)..len {
646 if !k_raw[i].is_nan() {
647 ema_k = alpha_k * k_raw[i] + one_minus_alpha_k * ema_k;
648 }
649 out_k[i] = ema_k;
650 }
651 } else {
652 for i in k_warm..len {
653 out_k[i] = f64::from_bits(0x7ff8_0000_0000_0000);
654 }
655 }
656
657 let alpha_d = 2.0 / (slowd_period as f64 + 1.0);
658 let one_minus_alpha_d = 1.0 - alpha_d;
659 let d_warm = k_first_valid + slowk_period + slowd_period - 2;
660 let d_start = k_first_valid + slowk_period - 1;
661 let mut sum_d = 0.0;
662 let mut cnt_d = 0;
663 for i in d_start..(d_start + slowd_period).min(len) {
664 if !out_k[i].is_nan() {
665 sum_d += out_k[i];
666 cnt_d += 1;
667 }
668 }
669 if cnt_d > 0 && d_warm < len {
670 let mut ema_d = sum_d / cnt_d as f64;
671 out_d[d_warm] = ema_d;
672 for i in (d_warm + 1)..len {
673 if !out_k[i].is_nan() {
674 ema_d = alpha_d * out_k[i] + one_minus_alpha_d * ema_d;
675 }
676 out_d[i] = ema_d;
677 }
678 } else {
679 for i in d_warm..len {
680 out_d[i] = f64::from_bits(0x7ff8_0000_0000_0000);
681 }
682 }
683 return Ok(());
684 }
685
686 let k_vec = ma(&slowk_ma_type, MaData::Slice(&k_raw), slowk_period)
687 .map_err(|e| StochError::Other(e.to_string()))?;
688 let d_vec = ma(&slowd_ma_type, MaData::Slice(&k_vec), slowd_period)
689 .map_err(|e| StochError::Other(e.to_string()))?;
690 out_k.copy_from_slice(&k_vec);
691 out_d.copy_from_slice(&d_vec);
692 Ok(())
693}
694
695#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
696pub fn stoch_into(
697 input: &StochInput,
698 out_k: &mut [f64],
699 out_d: &mut [f64],
700) -> Result<(), StochError> {
701 #[cfg(test)]
702 {
703 stoch_into_slices(out_k, out_d, input, Kernel::Auto)
704 }
705 #[cfg(not(test))]
706 {
707 stoch_compute_into(input, out_k, out_d, Kernel::Auto)
708 }
709}
710
711fn stoch_classic_sma_into_single_pass(
712 high: &[f64],
713 low: &[f64],
714 close: &[f64],
715 fastk_period: usize,
716 slowk_period: usize,
717 slowd_period: usize,
718 first: usize,
719 out_k: &mut [f64],
720 out_d: &mut [f64],
721) -> Result<(), StochError> {
722 let len = close.len();
723
724 let k_first_valid = first + fastk_period - 1;
725 let k_warm = k_first_valid + slowk_period - 1;
726 let d_warm = k_first_valid + slowk_period + slowd_period - 2;
727
728 prefill_nan_prefix(out_k, k_warm);
729 prefill_nan_prefix(out_d, d_warm);
730
731 let mut trail = first;
732 let mut maxi = first;
733 let mut mini = first;
734 let mut max = high[first];
735 let mut min = low[first];
736
737 let mut k_buf = vec![0.0f64; slowk_period];
738 let mut k_pos: usize = 0;
739 let mut k_sum = 0.0f64;
740 let mut k_count: usize = 0;
741
742 let mut d_buf = vec![0.0f64; slowd_period];
743 let mut d_pos: usize = 0;
744 let mut d_sum = 0.0f64;
745 let mut d_count: usize = 0;
746
747 const SCALE: f64 = 100.0;
748 const EPS: f64 = f64::EPSILON;
749
750 for i in first..len {
751 if i >= first + fastk_period {
752 trail += 1;
753 }
754
755 let bar_h = high[i];
756 if maxi < trail {
757 maxi = trail;
758 max = high[maxi];
759 let mut j = trail;
760 while j < i {
761 j += 1;
762 let v = high[j];
763 if v >= max {
764 max = v;
765 maxi = j;
766 }
767 }
768 } else if bar_h >= max {
769 maxi = i;
770 max = bar_h;
771 }
772
773 let bar_l = low[i];
774 if mini < trail {
775 mini = trail;
776 min = low[mini];
777 let mut j = trail;
778 while j < i {
779 j += 1;
780 let v = low[j];
781 if v <= min {
782 min = v;
783 mini = j;
784 }
785 }
786 } else if bar_l <= min {
787 mini = i;
788 min = bar_l;
789 }
790
791 if i < k_first_valid {
792 continue;
793 }
794
795 let c = close[i];
796 let denom = max - min;
797 let k_raw = if denom.abs() < EPS {
798 50.0
799 } else {
800 (c - min).mul_add(SCALE / denom, 0.0)
801 };
802
803 if k_count >= slowk_period {
804 k_sum -= k_buf[k_pos];
805 }
806 k_buf[k_pos] = k_raw;
807 k_sum += k_raw;
808 k_count += 1;
809 k_pos += 1;
810 if k_pos == slowk_period {
811 k_pos = 0;
812 }
813
814 if i >= k_warm {
815 let k_sma = k_sum / slowk_period as f64;
816 out_k[i] = k_sma;
817
818 if d_count >= slowd_period {
819 d_sum -= d_buf[d_pos];
820 }
821 d_buf[d_pos] = k_sma;
822 d_sum += k_sma;
823 d_count += 1;
824 d_pos += 1;
825 if d_pos == slowd_period {
826 d_pos = 0;
827 }
828
829 if i >= d_warm {
830 out_d[i] = d_sum / slowd_period as f64;
831 }
832 }
833 }
834
835 Ok(())
836}
837
/// AVX-512 raw %K kernel entry point.
///
/// Dispatches on `fastk_period`: `<= 32` goes to `stoch_avx512_short`, larger periods go to
/// `stoch_avx512_long`. Both currently forward to the same `stoch_avx512_impl`.
///
/// Writes `out[first_valid + fastk_period - 1 ..]` from `close`, `hh` and `ll`. `high` and `low`
/// are forwarded but unused by the implementation.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn stoch_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    hh: &[f64],
    ll: &[f64],
    fastk_period: usize,
    first_valid: usize,
    out: &mut [f64],
) {
    // Both targets are safe functions, so no `unsafe` block is required here.
    if fastk_period <= 32 {
        stoch_avx512_short(high, low, close, hh, ll, fastk_period, first_valid, out)
    } else {
        stoch_avx512_long(high, low, close, hh, ll, fastk_period, first_valid, out)
    }
}
856
/// Portable raw %K kernel.
///
/// For every index `j` from `first_val + fastk_period - 1` to the end of `close`, writes
/// `out[j] = (close[j] - ll[j]) * (100 / (hh[j] - ll[j]))`, or `50.0` when
/// `|hh[j] - ll[j]| < f64::EPSILON`. Earlier entries of `out` are left untouched.
///
/// `_high` / `_low` are unused; the rolling extremes come in through `hh` / `ll`. Iteration
/// stops at the shortest of `out`, `close`, `hh` and `ll` past the start index.
///
/// The arithmetic matches the AVX2/AVX-512 kernels (a plain multiply by the reciprocal range).
#[inline]
pub fn stoch_scalar(
    _high: &[f64],
    _low: &[f64],
    close: &[f64],
    hh: &[f64],
    ll: &[f64],
    fastk_period: usize,
    first_val: usize,
    out: &mut [f64],
) {
    const SCALE: f64 = 100.0;
    const EPS: f64 = f64::EPSILON;

    // First index with a full %K window. `checked_sub` makes the degenerate
    // `first_val == 0 && fastk_period == 0` case a no-op instead of a debug-build overflow panic.
    let start = match (first_val + fastk_period).checked_sub(1) {
        Some(s) if s < close.len() => s,
        _ => return,
    };

    let closes = &close[start..];
    let highs = &hh[start..];
    let lows = &ll[start..];

    for (o, ((&cv, &hv), &lv)) in out[start..]
        .iter_mut()
        .zip(closes.iter().zip(highs).zip(lows))
    {
        let range = hv - lv;
        *o = if range.abs() < EPS {
            50.0
        } else {
            // Plain multiply: `x.mul_add(y, 0.0)` rounds the same way, but without a hardware
            // FMA target feature it lowers to a libm `fma` call per element and blocks
            // auto-vectorization.
            (cv - lv) * (SCALE / range)
        };
    }
}
890
/// AVX2 raw %K kernel (safe wrapper around `stoch_avx2_impl`).
///
/// Writes `out[first_valid + fastk_period - 1 ..]` from `close`, `hh` and `ll`. `high` and `low`
/// are forwarded but unused by the implementation.
///
/// # Panics
/// Panics if `hh`, `ll` or `out` is shorter than `close`. The implementation accesses all four
/// through raw pointers up to `close.len()`, so shorter buffers would otherwise be
/// out-of-bounds memory access.
///
/// The caller must only select this kernel on CPUs with AVX2. Debug builds verify this at runtime.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn stoch_avx2(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    hh: &[f64],
    ll: &[f64],
    fastk_period: usize,
    first_valid: usize,
    out: &mut [f64],
) {
    assert!(
        hh.len() >= close.len() && ll.len() >= close.len() && out.len() >= close.len(),
        "stoch_avx2: hh, ll and out must be at least as long as close"
    );
    debug_assert!(
        is_x86_feature_detected!("avx2"),
        "stoch_avx2: CPU does not support AVX2"
    );
    // SAFETY: buffer lengths were checked above, so every pointer offset the impl computes
    // (bounded by `close.len()`) is in bounds. AVX2 availability is the caller's kernel-selection
    // responsibility and is verified in debug builds.
    unsafe { stoch_avx2_impl(high, low, close, hh, ll, fastk_period, first_valid, out) }
}
905
/// AVX2 implementation of the raw %K stage.
///
/// For every index `j` in `first_valid + fastk_period - 1 .. close.len()`, writes
/// `out[j] = (close[j] - ll[j]) * (100 / (hh[j] - ll[j]))`, or `50.0` when
/// `|hh[j] - ll[j]| < f64::EPSILON`.
///
/// Works 4 lanes at a time, unrolled to 8 elements per iteration when two full vectors fit.
/// The remainder is finished with a scalar loop. `_high` / `_low` are unused.
///
/// # Safety
/// - The CPU must support AVX2.
/// - `hh`, `ll` and `out` must each hold at least `close.len()` elements. All four are accessed
///   through raw pointers up to that length with no bounds checks.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn stoch_avx2_impl(
    _high: &[f64],
    _low: &[f64],
    close: &[f64],
    hh: &[f64],
    ll: &[f64],
    fastk_period: usize,
    first_valid: usize,
    out: &mut [f64],
) {
    // First index with a complete %K window; nothing to do if it is past the end.
    let start = first_valid + fastk_period - 1;
    if start >= close.len() {
        return;
    }

    // Number of outputs to produce, counted from `start`.
    let n = close.len() - start;
    let mut i = 0usize;

    let c_ptr = close.as_ptr().add(start);
    let h_ptr = hh.as_ptr().add(start);
    let l_ptr = ll.as_ptr().add(start);
    let o_ptr = out.as_mut_ptr().add(start);

    const STEP: usize = 4;
    // Largest multiple of STEP not exceeding n; the scalar tail handles the rest.
    let vec_end = n & !(STEP - 1);

    let scale = _mm256_set1_pd(100.0);
    let fifty = _mm256_set1_pd(50.0);
    let eps = _mm256_set1_pd(f64::EPSILON);
    // -0.0 has only the sign bit set, so andnot(sign_mask, x) clears the sign bit: |x|.
    let sign_mask = _mm256_set1_pd(-0.0);

    while i + STEP <= vec_end {
        let c0 = _mm256_loadu_pd(c_ptr.add(i));
        let h0 = _mm256_loadu_pd(h_ptr.add(i));
        let l0 = _mm256_loadu_pd(l_ptr.add(i));
        let d0 = _mm256_sub_pd(h0, l0);
        let n0 = _mm256_sub_pd(c0, l0);
        let a0 = _mm256_andnot_pd(sign_mask, d0);
        // Ordered compare: NaN ranges yield a false mask and keep the (NaN) quotient.
        let m0 = _mm256_cmp_pd(a0, eps, _CMP_LT_OQ);
        let inv0 = _mm256_div_pd(scale, d0);
        let v0 = _mm256_mul_pd(n0, inv0);
        // Lanes whose mask is set (flat range) take 50.0.
        let o0 = _mm256_blendv_pd(v0, fifty, m0);

        // Second vector of the 2x unroll, only when another full STEP fits before vec_end.
        if i + 2 * STEP <= vec_end {
            let c1 = _mm256_loadu_pd(c_ptr.add(i + STEP));
            let h1 = _mm256_loadu_pd(h_ptr.add(i + STEP));
            let l1 = _mm256_loadu_pd(l_ptr.add(i + STEP));
            let d1 = _mm256_sub_pd(h1, l1);
            let n1 = _mm256_sub_pd(c1, l1);
            let a1 = _mm256_andnot_pd(sign_mask, d1);
            let m1 = _mm256_cmp_pd(a1, eps, _CMP_LT_OQ);
            let inv1 = _mm256_div_pd(scale, d1);
            let v1 = _mm256_mul_pd(n1, inv1);
            let o1 = _mm256_blendv_pd(v1, fifty, m1);

            _mm256_storeu_pd(o_ptr.add(i), o0);
            _mm256_storeu_pd(o_ptr.add(i + STEP), o1);
            i += 2 * STEP;
        } else {
            _mm256_storeu_pd(o_ptr.add(i), o0);
            i += STEP;
        }
    }

    // Scalar remainder (fewer than STEP elements), same formula as the vector lanes.
    while i < n {
        let c = *c_ptr.add(i);
        let l = *l_ptr.add(i);
        let d = *h_ptr.add(i) - l;
        *o_ptr.add(i) = if d.abs() < f64::EPSILON {
            50.0
        } else {
            (c - l) * (100.0 / d)
        };
        i += 1;
    }
}
984
/// AVX-512 raw %K kernel used for `fastk_period <= 32` (safe wrapper around
/// `stoch_avx512_impl`).
///
/// # Panics
/// Panics if `hh`, `ll` or `out` is shorter than `close`. The implementation accesses all four
/// through raw pointers up to `close.len()`.
///
/// The caller must only select this kernel on CPUs with AVX-512F. Debug builds verify this at
/// runtime.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn stoch_avx512_short(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    hh: &[f64],
    ll: &[f64],
    fastk_period: usize,
    first_valid: usize,
    out: &mut [f64],
) {
    assert!(
        hh.len() >= close.len() && ll.len() >= close.len() && out.len() >= close.len(),
        "stoch_avx512_short: hh, ll and out must be at least as long as close"
    );
    debug_assert!(
        is_x86_feature_detected!("avx512f"),
        "stoch_avx512_short: CPU does not support AVX-512F"
    );
    // SAFETY: buffer lengths were checked above, so all pointer offsets (bounded by
    // `close.len()`) are in bounds. AVX-512F availability is the caller's responsibility and is
    // verified in debug builds.
    unsafe { stoch_avx512_impl(high, low, close, hh, ll, fastk_period, first_valid, out) }
}
999
/// AVX-512 raw %K kernel used for `fastk_period > 32` (safe wrapper around
/// `stoch_avx512_impl`; currently the same implementation as the short variant).
///
/// # Panics
/// Panics if `hh`, `ll` or `out` is shorter than `close`. The implementation accesses all four
/// through raw pointers up to `close.len()`.
///
/// The caller must only select this kernel on CPUs with AVX-512F. Debug builds verify this at
/// runtime.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn stoch_avx512_long(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    hh: &[f64],
    ll: &[f64],
    fastk_period: usize,
    first_valid: usize,
    out: &mut [f64],
) {
    assert!(
        hh.len() >= close.len() && ll.len() >= close.len() && out.len() >= close.len(),
        "stoch_avx512_long: hh, ll and out must be at least as long as close"
    );
    debug_assert!(
        is_x86_feature_detected!("avx512f"),
        "stoch_avx512_long: CPU does not support AVX-512F"
    );
    // SAFETY: buffer lengths were checked above, so all pointer offsets (bounded by
    // `close.len()`) are in bounds. AVX-512F availability is the caller's responsibility and is
    // verified in debug builds.
    unsafe { stoch_avx512_impl(high, low, close, hh, ll, fastk_period, first_valid, out) }
}
1014
/// AVX-512 implementation of the raw %K stage.
///
/// For every index `j` in `first_valid + fastk_period - 1 .. close.len()`, writes
/// `out[j] = (close[j] - ll[j]) * (100 / (hh[j] - ll[j]))`, or `50.0` when
/// `|hh[j] - ll[j]| < f64::EPSILON`.
///
/// Works 8 lanes at a time, unrolled to 16 elements per iteration when two full vectors fit.
/// The remainder is finished with a scalar loop. `_high` / `_low` are unused.
///
/// # Safety
/// - The CPU must support AVX-512F.
/// - `hh`, `ll` and `out` must each hold at least `close.len()` elements. All four are accessed
///   through raw pointers up to that length with no bounds checks.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
unsafe fn stoch_avx512_impl(
    _high: &[f64],
    _low: &[f64],
    close: &[f64],
    hh: &[f64],
    ll: &[f64],
    fastk_period: usize,
    first_valid: usize,
    out: &mut [f64],
) {
    // First index with a complete %K window; nothing to do if it is past the end.
    let start = first_valid + fastk_period - 1;
    if start >= close.len() {
        return;
    }

    // Number of outputs to produce, counted from `start`.
    let n = close.len() - start;

    let c_ptr = close.as_ptr().add(start);
    let h_ptr = hh.as_ptr().add(start);
    let l_ptr = ll.as_ptr().add(start);
    let o_ptr = out.as_mut_ptr().add(start);

    const STEP: usize = 8;
    // Largest multiple of STEP not exceeding n; the scalar tail handles the rest.
    let vec_end = n & !(STEP - 1);

    let scale = _mm512_set1_pd(100.0);
    let fifty = _mm512_set1_pd(50.0);
    let eps = _mm512_set1_pd(f64::EPSILON);
    // -0.0 has only the sign bit set, so andnot(sign_mask, x) clears the sign bit: |x|.
    let sign_mask = _mm512_set1_pd(-0.0);

    let mut i = 0usize;
    while i + STEP <= vec_end {
        let c0 = _mm512_loadu_pd(c_ptr.add(i));
        let h0 = _mm512_loadu_pd(h_ptr.add(i));
        let l0 = _mm512_loadu_pd(l_ptr.add(i));
        let d0 = _mm512_sub_pd(h0, l0);
        let n0 = _mm512_sub_pd(c0, l0);
        let a0 = _mm512_andnot_pd(sign_mask, d0);
        // Ordered compare: NaN ranges leave the mask bit clear and keep the (NaN) quotient.
        let m0: __mmask8 = _mm512_cmp_pd_mask(a0, eps, _CMP_LT_OQ);
        let inv0 = _mm512_div_pd(scale, d0);
        let v0 = _mm512_mul_pd(n0, inv0);
        // Lanes with their mask bit set (flat range) take 50.0.
        let o0 = _mm512_mask_blend_pd(m0, v0, fifty);

        // Second vector of the 2x unroll, only when another full STEP fits before vec_end.
        if i + 2 * STEP <= vec_end {
            let c1 = _mm512_loadu_pd(c_ptr.add(i + STEP));
            let h1 = _mm512_loadu_pd(h_ptr.add(i + STEP));
            let l1 = _mm512_loadu_pd(l_ptr.add(i + STEP));
            let d1 = _mm512_sub_pd(h1, l1);
            let n1 = _mm512_sub_pd(c1, l1);
            let a1 = _mm512_andnot_pd(sign_mask, d1);
            let m1: __mmask8 = _mm512_cmp_pd_mask(a1, eps, _CMP_LT_OQ);
            let inv1 = _mm512_div_pd(scale, d1);
            let v1 = _mm512_mul_pd(n1, inv1);
            let o1 = _mm512_mask_blend_pd(m1, v1, fifty);

            _mm512_storeu_pd(o_ptr.add(i), o0);
            _mm512_storeu_pd(o_ptr.add(i + STEP), o1);
            i += 2 * STEP;
        } else {
            _mm512_storeu_pd(o_ptr.add(i), o0);
            i += STEP;
        }
    }

    // Scalar remainder (fewer than STEP elements), same formula as the vector lanes.
    while i < n {
        let c = *c_ptr.add(i);
        let l = *l_ptr.add(i);
        let d = *h_ptr.add(i) - l;
        *o_ptr.add(i) = if d.abs() < f64::EPSILON {
            50.0
        } else {
            (c - l) * (100.0 / d)
        };
        i += 1;
    }
}
1093
/// Parameter sweep description for batch stochastic runs.
///
/// Period axes are `(start, end, step)`. MA-type axes are `(start, end, f64)`. The builder's
/// `*_ma_type_static` methods set both names to the same value and the third element to `0.0`;
/// its meaning is not visible here.
#[derive(Clone, Debug)]
pub struct StochBatchRange {
    /// %K look-back window sweep.
    pub fastk_period: (usize, usize, usize),
    /// Slow-%K smoothing window sweep.
    pub slowk_period: (usize, usize, usize),
    /// Slow-%K moving-average type axis.
    pub slowk_ma_type: (String, String, f64),
    /// %D smoothing window sweep.
    pub slowd_period: (usize, usize, usize),
    /// %D moving-average type axis.
    pub slowd_ma_type: (String, String, f64),
}
1102
1103impl Default for StochBatchRange {
1104 fn default() -> Self {
1105 Self {
1106 fastk_period: (14, 263, 1),
1107 slowk_period: (3, 3, 0),
1108 slowk_ma_type: ("sma".to_string(), "sma".to_string(), 0.0),
1109 slowd_period: (3, 3, 0),
1110 slowd_ma_type: ("sma".to_string(), "sma".to_string(), 0.0),
1111 }
1112 }
1113}
1114
1115#[derive(Clone, Debug, Default)]
1116pub struct StochBatchBuilder {
1117 range: StochBatchRange,
1118 kernel: Kernel,
1119}
1120
1121impl StochBatchBuilder {
1122 pub fn new() -> Self {
1123 Self::default()
1124 }
1125 pub fn kernel(mut self, k: Kernel) -> Self {
1126 self.kernel = k;
1127 self
1128 }
1129 pub fn fastk_period_range(mut self, start: usize, end: usize, step: usize) -> Self {
1130 self.range.fastk_period = (start, end, step);
1131 self
1132 }
1133 pub fn fastk_period_static(mut self, p: usize) -> Self {
1134 self.range.fastk_period = (p, p, 0);
1135 self
1136 }
1137 pub fn slowk_period_range(mut self, start: usize, end: usize, step: usize) -> Self {
1138 self.range.slowk_period = (start, end, step);
1139 self
1140 }
1141 pub fn slowk_period_static(mut self, p: usize) -> Self {
1142 self.range.slowk_period = (p, p, 0);
1143 self
1144 }
1145 pub fn slowk_ma_type_static(mut self, t: &str) -> Self {
1146 self.range.slowk_ma_type = (t.to_string(), t.to_string(), 0.0);
1147 self
1148 }
1149 pub fn slowd_period_range(mut self, start: usize, end: usize, step: usize) -> Self {
1150 self.range.slowd_period = (start, end, step);
1151 self
1152 }
1153 pub fn slowd_period_static(mut self, p: usize) -> Self {
1154 self.range.slowd_period = (p, p, 0);
1155 self
1156 }
1157 pub fn slowd_ma_type_static(mut self, t: &str) -> Self {
1158 self.range.slowd_ma_type = (t.to_string(), t.to_string(), 0.0);
1159 self
1160 }
1161
1162 pub fn apply_slices(
1163 self,
1164 high: &[f64],
1165 low: &[f64],
1166 close: &[f64],
1167 ) -> Result<StochBatchOutput, StochError> {
1168 stoch_batch_with_kernel(high, low, close, &self.range, self.kernel)
1169 }
1170 pub fn apply_candles(self, c: &Candles) -> Result<StochBatchOutput, StochError> {
1171 let high = source_type(c, "high");
1172 let low = source_type(c, "low");
1173 let close = source_type(c, "close");
1174 self.apply_slices(high, low, close)
1175 }
1176}
1177
1178pub fn stoch_batch_with_kernel(
1179 high: &[f64],
1180 low: &[f64],
1181 close: &[f64],
1182 sweep: &StochBatchRange,
1183 k: Kernel,
1184) -> Result<StochBatchOutput, StochError> {
1185 let kernel = match k {
1186 Kernel::Auto => Kernel::ScalarBatch,
1187 other if other.is_batch() => other,
1188 other => return Err(StochError::InvalidKernelForBatch(other)),
1189 };
1190 let simd = match kernel {
1191 Kernel::Avx512Batch => Kernel::Avx512,
1192 Kernel::Avx2Batch => Kernel::Avx2,
1193 Kernel::ScalarBatch => Kernel::Scalar,
1194 _ => unreachable!(),
1195 };
1196 stoch_batch_par_slice(high, low, close, sweep, simd)
1197}
1198
/// Result of a Stochastic parameter sweep.
///
/// `k` and `d` are row-major `rows x cols` matrices. Row `i` belongs to `combos[i]`, and
/// `cols` is the input length.
#[derive(Clone, Debug)]
pub struct StochBatchOutput {
    /// Slow %K values, `rows * cols` entries, one row per combination.
    pub k: Vec<f64>,
    /// Slow %D values, laid out like `k`.
    pub d: Vec<f64>,
    /// Parameter combination of each row, in row order.
    pub combos: Vec<StochParams>,
    /// Number of combinations (rows).
    pub rows: usize,
    /// Number of samples per row (input length).
    pub cols: usize,
}
1207impl StochBatchOutput {
1208 pub fn row_for_params(&self, p: &StochParams) -> Option<usize> {
1209 self.combos.iter().position(|c| {
1210 c.fastk_period == p.fastk_period
1211 && c.slowk_period == p.slowk_period
1212 && c.slowk_ma_type == p.slowk_ma_type
1213 && c.slowd_period == p.slowd_period
1214 && c.slowd_ma_type == p.slowd_ma_type
1215 })
1216 }
1217 pub fn values_for(&self, p: &StochParams) -> Option<(&[f64], &[f64])> {
1218 self.row_for_params(p).map(|row| {
1219 let start = row * self.cols;
1220 (
1221 &self.k[start..start + self.cols],
1222 &self.d[start..start + self.cols],
1223 )
1224 })
1225 }
1226}
1227
1228#[inline(always)]
1229fn expand_grid(r: &StochBatchRange) -> Result<Vec<StochParams>, StochError> {
1230 fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, StochError> {
1231 if step == 0 || start == end {
1232 return Ok(vec![start]);
1233 }
1234
1235 let mut v = Vec::new();
1236 if start < end {
1237 let mut x = start;
1238 loop {
1239 v.push(x);
1240 match x.checked_add(step) {
1241 Some(next) if next <= end => x = next,
1242 Some(_) | None => break,
1243 }
1244 }
1245 } else {
1246 let mut x = start;
1247 loop {
1248 v.push(x);
1249 match x.checked_sub(step) {
1250 Some(next) if next >= end => x = next,
1251 Some(_) | None => break,
1252 }
1253 }
1254 }
1255
1256 if v.is_empty() {
1257 Err(StochError::InvalidRange { start, end, step })
1258 } else {
1259 Ok(v)
1260 }
1261 }
1262 fn axis_str((start, end, _): (String, String, f64)) -> Vec<String> {
1263 if start == end {
1264 vec![start]
1265 } else {
1266 vec![start, end]
1267 }
1268 }
1269 let fastk_periods = axis_usize(r.fastk_period)?;
1270 let slowk_periods = axis_usize(r.slowk_period)?;
1271 let slowk_types = axis_str(r.slowk_ma_type.clone());
1272 let slowd_periods = axis_usize(r.slowd_period)?;
1273 let slowd_types = axis_str(r.slowd_ma_type.clone());
1274
1275 let combos_len = fastk_periods
1276 .len()
1277 .checked_mul(slowk_periods.len())
1278 .and_then(|v| v.checked_mul(slowk_types.len()))
1279 .and_then(|v| v.checked_mul(slowd_periods.len()))
1280 .and_then(|v| v.checked_mul(slowd_types.len()))
1281 .ok_or(StochError::InvalidRange {
1282 start: r.fastk_period.0,
1283 end: r.fastk_period.1,
1284 step: r.fastk_period.2,
1285 })?;
1286
1287 let mut out = Vec::with_capacity(combos_len);
1288 for &fkp in &fastk_periods {
1289 for &skp in &slowk_periods {
1290 for skt in &slowk_types {
1291 for &sdp in &slowd_periods {
1292 for sdt in &slowd_types {
1293 out.push(StochParams {
1294 fastk_period: Some(fkp),
1295 slowk_period: Some(skp),
1296 slowk_ma_type: Some(skt.clone()),
1297 slowd_period: Some(sdp),
1298 slowd_ma_type: Some(sdt.clone()),
1299 });
1300 }
1301 }
1302 }
1303 }
1304 }
1305 Ok(out)
1306}
1307
1308#[inline(always)]
1309pub fn stoch_batch_slice(
1310 high: &[f64],
1311 low: &[f64],
1312 close: &[f64],
1313 sweep: &StochBatchRange,
1314 kern: Kernel,
1315) -> Result<StochBatchOutput, StochError> {
1316 stoch_batch_inner(high, low, close, sweep, kern, false)
1317}
1318
1319#[inline(always)]
1320pub fn stoch_batch_par_slice(
1321 high: &[f64],
1322 low: &[f64],
1323 close: &[f64],
1324 sweep: &StochBatchRange,
1325 kern: Kernel,
1326) -> Result<StochBatchOutput, StochError> {
1327 stoch_batch_inner(high, low, close, sweep, kern, true)
1328}
1329
1330#[inline(always)]
1331fn stoch_batch_inner(
1332 high: &[f64],
1333 low: &[f64],
1334 close: &[f64],
1335 sweep: &StochBatchRange,
1336 kern: Kernel,
1337 parallel: bool,
1338) -> Result<StochBatchOutput, StochError> {
1339 let combos = expand_grid(sweep)?;
1340
1341 let n = high.len();
1342 if n == 0 || low.len() != n || close.len() != n {
1343 return Err(StochError::MismatchedLength);
1344 }
1345
1346 let first = high
1347 .iter()
1348 .zip(low.iter())
1349 .zip(close.iter())
1350 .position(|((h, l), c)| !h.is_nan() && !l.is_nan() && !c.is_nan())
1351 .ok_or(StochError::AllValuesNaN)?;
1352 let max_fkp = combos
1353 .iter()
1354 .map(|c| c.fastk_period.unwrap())
1355 .max()
1356 .unwrap();
1357 if n - first < max_fkp {
1358 return Err(StochError::NotEnoughValidData {
1359 needed: max_fkp,
1360 valid: n - first,
1361 });
1362 }
1363
1364 let rows = combos.len();
1365 let cols = n;
1366
1367 rows.checked_mul(cols).ok_or(StochError::InvalidRange {
1368 start: rows,
1369 end: cols,
1370 step: 0,
1371 })?;
1372
1373 let mut k_mu = make_uninit_matrix(rows, cols);
1374 let mut d_mu = make_uninit_matrix(rows, cols);
1375
1376 let warm_k: Vec<usize> = combos
1377 .iter()
1378 .map(|c| first + c.fastk_period.unwrap() - 1)
1379 .collect();
1380 init_matrix_prefixes(&mut k_mu, cols, &warm_k);
1381 init_matrix_prefixes(&mut d_mu, cols, &warm_k);
1382
1383 let mut k_guard = core::mem::ManuallyDrop::new(k_mu);
1384 let mut d_guard = core::mem::ManuallyDrop::new(d_mu);
1385 let k_mat: &mut [f64] =
1386 unsafe { core::slice::from_raw_parts_mut(k_guard.as_mut_ptr() as *mut f64, k_guard.len()) };
1387 let d_mat: &mut [f64] =
1388 unsafe { core::slice::from_raw_parts_mut(d_guard.as_mut_ptr() as *mut f64, d_guard.len()) };
1389
1390 use std::collections::HashMap;
1391 let mut groups: HashMap<usize, Vec<usize>> = HashMap::new();
1392 for (row, prm) in combos.iter().enumerate() {
1393 groups
1394 .entry(prm.fastk_period.unwrap())
1395 .or_default()
1396 .push(row);
1397 }
1398
1399 let mut compute_k_raw = |fkp: usize| -> Vec<f64> {
1400 let mut hh = alloc_with_nan_prefix(cols, first + fkp - 1);
1401 let mut ll = alloc_with_nan_prefix(cols, first + fkp - 1);
1402 let highs = max_rolling(&high[first..], fkp).unwrap();
1403 let lows = min_rolling(&low[first..], fkp).unwrap();
1404 for (i, &v) in highs.iter().enumerate() {
1405 hh[first + i] = v;
1406 }
1407 for (i, &v) in lows.iter().enumerate() {
1408 ll[first + i] = v;
1409 }
1410
1411 let mut k_raw = alloc_with_nan_prefix(cols, first + fkp - 1);
1412 unsafe {
1413 match kern {
1414 Kernel::Scalar => {
1415 stoch_row_scalar(high, low, close, &hh, &ll, fkp, first, &mut k_raw)
1416 }
1417 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1418 Kernel::Avx2 => stoch_row_avx2(high, low, close, &hh, &ll, fkp, first, &mut k_raw),
1419 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1420 Kernel::Avx512 => {
1421 stoch_row_avx512(high, low, close, &hh, &ll, fkp, first, &mut k_raw)
1422 }
1423 _ => unreachable!(),
1424 }
1425 }
1426 k_raw
1427 };
1428
1429 for (fkp, rows_in_group) in groups {
1430 let k_raw = compute_k_raw(fkp);
1431 for &row in &rows_in_group {
1432 let prm = &combos[row];
1433 let k_vec = ma(
1434 prm.slowk_ma_type.as_ref().unwrap(),
1435 MaData::Slice(&k_raw),
1436 prm.slowk_period.unwrap(),
1437 )
1438 .unwrap();
1439 let d_vec = ma(
1440 prm.slowd_ma_type.as_ref().unwrap(),
1441 MaData::Slice(&k_vec),
1442 prm.slowd_period.unwrap(),
1443 )
1444 .unwrap();
1445 let start = row * cols;
1446 let dst_k = &mut k_mat[start..start + cols];
1447 let dst_d = &mut d_mat[start..start + cols];
1448 dst_k.copy_from_slice(&k_vec);
1449 dst_d.copy_from_slice(&d_vec);
1450 }
1451 }
1452
1453 let k = unsafe {
1454 Vec::from_raw_parts(
1455 k_guard.as_mut_ptr() as *mut f64,
1456 k_guard.len(),
1457 k_guard.capacity(),
1458 )
1459 };
1460 let d = unsafe {
1461 Vec::from_raw_parts(
1462 d_guard.as_mut_ptr() as *mut f64,
1463 d_guard.len(),
1464 d_guard.capacity(),
1465 )
1466 };
1467 core::mem::forget(k_guard);
1468 core::mem::forget(d_guard);
1469
1470 Ok(StochBatchOutput {
1471 k,
1472 d,
1473 combos,
1474 rows,
1475 cols,
1476 })
1477}
1478
/// Scalar row kernel: writes raw %K into `out` from index `first + fastk_period - 1` on.
///
/// Each value is `(close - ll) * 100 / (hh - ll)`, or `50.0` when
/// `|hh - ll| < f64::EPSILON`. `hh`/`ll` hold the rolling high/low. `_high`/`_low` are
/// unused and exist so every row kernel has the same signature. Does nothing if the
/// warm-up index is past the end of `close`.
///
/// # Safety
/// Contains no unsafe operations. It is `unsafe fn` only to match the SIMD row kernels.
#[inline(always)]
unsafe fn stoch_row_scalar(
    _high: &[f64],
    _low: &[f64],
    close: &[f64],
    hh: &[f64],
    ll: &[f64],
    fastk_period: usize,
    first: usize,
    out: &mut [f64],
) {
    let warm = first + fastk_period - 1;
    if warm >= close.len() {
        return;
    }

    let inputs = close[warm..]
        .iter()
        .zip(hh[warm..].iter().zip(ll[warm..].iter()));
    for (dst, (&cl, (&hi, &lo))) in out[warm..].iter_mut().zip(inputs) {
        let range = hi - lo;
        // Keep the `<` comparison so a NaN range still takes the formula branch.
        *dst = if range.abs() < f64::EPSILON {
            50.0
        } else {
            (cl - lo).mul_add(100.0 / range, 0.0)
        };
    }
}
/// Row kernel for `Kernel::Avx2`. Forwards every argument to `stoch_row_avx2_impl`.
///
/// # Safety
/// AVX2 must be available on the running CPU, and `hh`, `ll` and `out` must be at least
/// as long as `close`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn stoch_row_avx2(
    hi: &[f64],
    lo: &[f64],
    cl: &[f64],
    hh: &[f64],
    ll: &[f64],
    period: usize,
    first_valid: usize,
    dst: &mut [f64],
) {
    stoch_row_avx2_impl(hi, lo, cl, hh, ll, period, first_valid, dst)
}
1526
/// AVX2 body of the raw-%K row kernel. Writes `out[first + fastk_period - 1..]`.
///
/// Each element is `(close - ll) * (100 / (hh - ll))`, or `50.0` when
/// `|hh - ll| < f64::EPSILON`. Four lanes are processed per vector, two vectors per loop
/// iteration while enough elements remain, then a scalar tail uses the same formula.
/// `_high`/`_low` are unused; the rolling extrema come from `hh`/`ll`.
///
/// # Safety
/// AVX2 must be available. `hh`, `ll` and `out` must have at least `close.len()`
/// elements, because the pointers are offset without bounds checks.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn stoch_row_avx2_impl(
    _high: &[f64],
    _low: &[f64],
    close: &[f64],
    hh: &[f64],
    ll: &[f64],
    fastk_period: usize,
    first: usize,
    out: &mut [f64],
) {
    // First index with a complete fast-K window; nothing to do if it is past the end.
    let start = first + fastk_period - 1;
    if start >= close.len() {
        return;
    }
    let n = close.len() - start;

    let mut i = 0usize;
    let c_ptr = close.as_ptr().add(start);
    let h_ptr = hh.as_ptr().add(start);
    let l_ptr = ll.as_ptr().add(start);
    let o_ptr = out.as_mut_ptr().add(start);

    const STEP: usize = 4;
    // Largest multiple of STEP not exceeding n; the remainder goes to the scalar tail.
    let vec_end = n & !(STEP - 1);

    let scale = _mm256_set1_pd(100.0);
    let fifty = _mm256_set1_pd(50.0);
    let eps = _mm256_set1_pd(f64::EPSILON);
    // -0.0 has only the sign bit set; `andnot` with it clears the sign bit, giving |x|.
    let sign_mask = _mm256_set1_pd(-0.0);

    while i + STEP <= vec_end {
        let c0 = _mm256_loadu_pd(c_ptr.add(i));
        let h0 = _mm256_loadu_pd(h_ptr.add(i));
        let l0 = _mm256_loadu_pd(l_ptr.add(i));
        let d0 = _mm256_sub_pd(h0, l0);
        let n0 = _mm256_sub_pd(c0, l0);
        let a0 = _mm256_andnot_pd(sign_mask, d0);
        // Ordered compare: NaN ranges are not flagged, so they keep the computed value.
        let m0 = _mm256_cmp_pd(a0, eps, _CMP_LT_OQ);
        // The division also runs for near-zero ranges; those lanes are replaced by the blend.
        let inv0 = _mm256_div_pd(scale, d0);
        let v0 = _mm256_mul_pd(n0, inv0);
        let o0 = _mm256_blendv_pd(v0, fifty, m0);

        // Unroll by two vectors when a second full vector fits.
        if i + 2 * STEP <= vec_end {
            let c1 = _mm256_loadu_pd(c_ptr.add(i + STEP));
            let h1 = _mm256_loadu_pd(h_ptr.add(i + STEP));
            let l1 = _mm256_loadu_pd(l_ptr.add(i + STEP));
            let d1 = _mm256_sub_pd(h1, l1);
            let n1 = _mm256_sub_pd(c1, l1);
            let a1 = _mm256_andnot_pd(sign_mask, d1);
            let m1 = _mm256_cmp_pd(a1, eps, _CMP_LT_OQ);
            let inv1 = _mm256_div_pd(scale, d1);
            let v1 = _mm256_mul_pd(n1, inv1);
            let o1 = _mm256_blendv_pd(v1, fifty, m1);

            _mm256_storeu_pd(o_ptr.add(i), o0);
            _mm256_storeu_pd(o_ptr.add(i + STEP), o1);
            i += 2 * STEP;
        } else {
            _mm256_storeu_pd(o_ptr.add(i), o0);
            i += STEP;
        }
    }

    // Scalar tail for the last n % STEP elements.
    while i < n {
        let c = *c_ptr.add(i);
        let l = *l_ptr.add(i);
        let d = *h_ptr.add(i) - l;
        *o_ptr.add(i) = if d.abs() < f64::EPSILON {
            50.0
        } else {
            (c - l) * (100.0 / d)
        };
        i += 1;
    }
}
/// Row kernel for `Kernel::Avx512`.
///
/// Dispatches to `stoch_row_avx512_long` when `fastk_period > 32`, and to
/// `stoch_row_avx512_short` otherwise.
///
/// # Safety
/// AVX-512F must be available, and `hh`, `ll` and `out` must be at least as long as `close`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn stoch_row_avx512(
    hi: &[f64],
    lo: &[f64],
    cl: &[f64],
    hh: &[f64],
    ll: &[f64],
    period: usize,
    first_valid: usize,
    dst: &mut [f64],
) {
    if period > 32 {
        stoch_row_avx512_long(hi, lo, cl, hh, ll, period, first_valid, dst)
    } else {
        stoch_row_avx512_short(hi, lo, cl, hh, ll, period, first_valid, dst)
    }
}
/// AVX-512 row kernel variant for short fast-K windows. Forwards to `stoch_row_avx512_impl`.
///
/// # Safety
/// AVX-512F must be available, and `hh`, `ll` and `out` must be at least as long as `close`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn stoch_row_avx512_short(
    hi: &[f64],
    lo: &[f64],
    cl: &[f64],
    hh: &[f64],
    ll: &[f64],
    period: usize,
    first_valid: usize,
    dst: &mut [f64],
) {
    stoch_row_avx512_impl(hi, lo, cl, hh, ll, period, first_valid, dst)
}
/// AVX-512 row kernel variant for long fast-K windows. Forwards to `stoch_row_avx512_impl`.
///
/// # Safety
/// AVX-512F must be available, and `hh`, `ll` and `out` must be at least as long as `close`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn stoch_row_avx512_long(
    hi: &[f64],
    lo: &[f64],
    cl: &[f64],
    hh: &[f64],
    ll: &[f64],
    period: usize,
    first_valid: usize,
    dst: &mut [f64],
) {
    stoch_row_avx512_impl(hi, lo, cl, hh, ll, period, first_valid, dst)
}
1650
/// AVX-512 body of the raw-%K row kernel. Writes `out[first + fastk_period - 1..]`.
///
/// Each element is `(close - ll) * (100 / (hh - ll))`, or `50.0` when
/// `|hh - ll| < f64::EPSILON`. Eight lanes are processed per vector, two vectors per loop
/// iteration while enough elements remain, then a scalar tail uses the same formula.
/// `_high`/`_low` are unused; the rolling extrema come from `hh`/`ll`.
///
/// The absolute value uses `_mm512_abs_pd` (AVX-512F). `_mm512_andnot_pd` requires
/// AVX-512DQ, which `#[target_feature(enable = "avx512f")]` does not enable. Both simply
/// clear the sign bit, so results are identical.
///
/// # Safety
/// AVX-512F must be available. `hh`, `ll` and `out` must have at least `close.len()`
/// elements, because the pointers are offset without bounds checks.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
unsafe fn stoch_row_avx512_impl(
    _high: &[f64],
    _low: &[f64],
    close: &[f64],
    hh: &[f64],
    ll: &[f64],
    fastk_period: usize,
    first: usize,
    out: &mut [f64],
) {
    // First index with a complete fast-K window; nothing to do if it is past the end.
    let start = first + fastk_period - 1;
    if start >= close.len() {
        return;
    }
    let n = close.len() - start;

    let c_ptr = close.as_ptr().add(start);
    let h_ptr = hh.as_ptr().add(start);
    let l_ptr = ll.as_ptr().add(start);
    let o_ptr = out.as_mut_ptr().add(start);

    const STEP: usize = 8;
    // Largest multiple of STEP not exceeding n; the remainder goes to the scalar tail.
    let vec_end = n & !(STEP - 1);

    let scale = _mm512_set1_pd(100.0);
    let fifty = _mm512_set1_pd(50.0);
    let eps = _mm512_set1_pd(f64::EPSILON);

    let mut i = 0usize;
    while i + STEP <= vec_end {
        let c0 = _mm512_loadu_pd(c_ptr.add(i));
        let h0 = _mm512_loadu_pd(h_ptr.add(i));
        let l0 = _mm512_loadu_pd(l_ptr.add(i));
        let d0 = _mm512_sub_pd(h0, l0);
        let n0 = _mm512_sub_pd(c0, l0);
        let a0 = _mm512_abs_pd(d0);
        // Ordered compare: NaN ranges are not flagged, so they keep the computed value.
        let m0: __mmask8 = _mm512_cmp_pd_mask(a0, eps, _CMP_LT_OQ);
        // The division also runs for near-zero ranges; those lanes are replaced by the blend.
        let inv0 = _mm512_div_pd(scale, d0);
        let v0 = _mm512_mul_pd(n0, inv0);
        let o0 = _mm512_mask_blend_pd(m0, v0, fifty);

        // Unroll by two vectors when a second full vector fits.
        if i + 2 * STEP <= vec_end {
            let c1 = _mm512_loadu_pd(c_ptr.add(i + STEP));
            let h1 = _mm512_loadu_pd(h_ptr.add(i + STEP));
            let l1 = _mm512_loadu_pd(l_ptr.add(i + STEP));
            let d1 = _mm512_sub_pd(h1, l1);
            let n1 = _mm512_sub_pd(c1, l1);
            let a1 = _mm512_abs_pd(d1);
            let m1: __mmask8 = _mm512_cmp_pd_mask(a1, eps, _CMP_LT_OQ);
            let inv1 = _mm512_div_pd(scale, d1);
            let v1 = _mm512_mul_pd(n1, inv1);
            let o1 = _mm512_mask_blend_pd(m1, v1, fifty);

            _mm512_storeu_pd(o_ptr.add(i), o0);
            _mm512_storeu_pd(o_ptr.add(i + STEP), o1);
            i += 2 * STEP;
        } else {
            _mm512_storeu_pd(o_ptr.add(i), o0);
            i += STEP;
        }
    }

    // Scalar tail for the last n % STEP elements.
    while i < n {
        let c = *c_ptr.add(i);
        let l = *l_ptr.add(i);
        let d = *h_ptr.add(i) - l;
        *o_ptr.add(i) = if d.abs() < f64::EPSILON {
            50.0
        } else {
            (c - l) * (100.0 / d)
        };
        i += 1;
    }
}
1728
/// Entry of the monotonic max/min deques used by `StochStream`.
#[derive(Debug, Clone)]
struct DeqEntry {
    /// The high (max deque) or low (min deque) value of the bar.
    val: f64,
    /// Stream index of the bar, used to evict entries that left the fast-K window.
    idx: usize,
}
1734
/// Incremental (bar-by-bar) Stochastic oscillator.
///
/// Tracks the rolling high/low of the last `fastk_period` bars with monotonic deques.
/// Raw %K is smoothed into slow %K, and slow %K into slow %D. "sma" and "ema" are handled
/// with dedicated incremental state; other MA types use a sliding window passed to `ma`.
#[derive(Debug, Clone)]
pub struct StochStream {
    /// Window length for the rolling high/low.
    fastk_period: usize,
    /// Smoothing period for slow %K.
    slowk_period: usize,
    /// MA type for slow %K (matched case-insensitively against "sma"/"ema").
    slowk_ma_type: String,
    /// Smoothing period for slow %D.
    slowd_period: usize,
    /// MA type for slow %D (matched case-insensitively against "sma"/"ema").
    slowd_ma_type: String,

    /// Deque whose front is the maximum high of the current window.
    maxq: VecDeque<DeqEntry>,
    /// Deque whose front is the minimum low of the current window.
    minq: VecDeque<DeqEntry>,
    /// Number of accepted (finite) bars; also the index given to the next bar.
    t: usize,
    /// Set once `fastk_period` bars have been accepted.
    have_window: bool,

    /// Ring buffer of the last `slowk_period` raw %K values (SMA path).
    k_sma_buf: Vec<f64>,
    /// Running sum of `k_sma_buf`.
    k_sma_sum: f64,
    /// Next write position in `k_sma_buf`.
    k_sma_head: usize,
    /// Number of values written into `k_sma_buf` during warm-up (stops at `slowk_period`).
    k_sma_count: usize,

    /// Current slow-%K EMA, `None` until seeded (EMA path).
    k_ema: Option<f64>,
    /// Sum of the first `slowk_period` raw %K values, averaged to seed the EMA.
    k_ema_seed_sum: f64,
    /// Number of values accumulated into `k_ema_seed_sum`.
    k_ema_seed_count: usize,
    /// EMA smoothing factor `2 / (slowk_period + 1)`.
    alpha_k: f64,

    /// Ring buffer of the last `slowd_period` slow-%K values (SMA path).
    d_sma_buf: Vec<f64>,
    /// Running sum of `d_sma_buf`.
    d_sma_sum: f64,
    /// Next write position in `d_sma_buf`.
    d_sma_head: usize,
    /// Number of values written into `d_sma_buf` during warm-up (stops at `slowd_period`).
    d_sma_count: usize,

    /// Current slow-%D EMA, `None` until seeded (EMA path).
    d_ema: Option<f64>,
    /// Sum of the first `slowd_period` finite slow-%K values, averaged to seed the EMA.
    d_ema_seed_sum: f64,
    /// Number of values accumulated into `d_ema_seed_sum`.
    d_ema_seed_count: usize,
    /// EMA smoothing factor `2 / (slowd_period + 1)`.
    alpha_d: f64,

    /// Sliding window of the last `slowk_period` raw %K values, for other MA types.
    k_stream: Option<Vec<f64>>,
    /// Sliding window of the last `slowd_period` slow-%K values, for other MA types.
    d_stream: Option<Vec<f64>>,
}
1771
1772impl StochStream {
1773 pub fn try_new(params: StochParams) -> Result<Self, StochError> {
1774 let fastk_period = params.fastk_period.unwrap_or(14);
1775 let slowk_period = params.slowk_period.unwrap_or(3);
1776 let slowd_period = params.slowd_period.unwrap_or(3);
1777 if fastk_period == 0 || slowk_period == 0 || slowd_period == 0 {
1778 return Err(StochError::InvalidPeriod {
1779 period: 0,
1780 data_len: 0,
1781 });
1782 }
1783
1784 let slowk_ma_type = params.slowk_ma_type.unwrap_or_else(|| "sma".to_string());
1785 let slowd_ma_type = params.slowd_ma_type.unwrap_or_else(|| "sma".to_string());
1786
1787 let alpha_k = 2.0 / (slowk_period as f64 + 1.0);
1788 let alpha_d = 2.0 / (slowd_period as f64 + 1.0);
1789
1790 Ok(Self {
1791 fastk_period,
1792 slowk_period,
1793 slowk_ma_type,
1794 slowd_period,
1795 slowd_ma_type,
1796
1797 maxq: VecDeque::with_capacity(fastk_period),
1798 minq: VecDeque::with_capacity(fastk_period),
1799 t: 0,
1800 have_window: false,
1801
1802 k_sma_buf: vec![f64::NAN; slowk_period.max(1)],
1803 k_sma_sum: 0.0,
1804 k_sma_head: 0,
1805 k_sma_count: 0,
1806
1807 k_ema: None,
1808 k_ema_seed_sum: 0.0,
1809 k_ema_seed_count: 0,
1810 alpha_k,
1811
1812 d_sma_buf: vec![f64::NAN; slowd_period.max(1)],
1813 d_sma_sum: 0.0,
1814 d_sma_head: 0,
1815 d_sma_count: 0,
1816
1817 d_ema: None,
1818 d_ema_seed_sum: 0.0,
1819 d_ema_seed_count: 0,
1820 alpha_d,
1821
1822 k_stream: None,
1823 d_stream: None,
1824 })
1825 }
1826
1827 #[inline(always)]
1828 fn evict_older_than(dq: &mut VecDeque<DeqEntry>, min_idx: usize) {
1829 while let Some(front) = dq.front() {
1830 if front.idx < min_idx {
1831 dq.pop_front();
1832 } else {
1833 break;
1834 }
1835 }
1836 }
1837
1838 #[inline(always)]
1839 fn push_maxq(&mut self, val: f64, idx: usize) {
1840 while let Some(back) = self.maxq.back() {
1841 if back.val <= val {
1842 self.maxq.pop_back();
1843 } else {
1844 break;
1845 }
1846 }
1847 self.maxq.push_back(DeqEntry { val, idx });
1848 }
1849
1850 #[inline(always)]
1851 fn push_minq(&mut self, val: f64, idx: usize) {
1852 while let Some(back) = self.minq.back() {
1853 if back.val >= val {
1854 self.minq.pop_back();
1855 } else {
1856 break;
1857 }
1858 }
1859 self.minq.push_back(DeqEntry { val, idx });
1860 }
1861
1862 pub fn update(&mut self, high: f64, low: f64, close: f64) -> Option<(f64, f64)> {
1863 if !high.is_finite() || !low.is_finite() || !close.is_finite() {
1864 return None;
1865 }
1866
1867 let idx = self.t;
1868 self.t = self.t.wrapping_add(1);
1869
1870 self.push_maxq(high, idx);
1871 self.push_minq(low, idx);
1872
1873 let seen = idx + 1;
1874 if seen >= self.fastk_period {
1875 let window_start = seen - self.fastk_period;
1876 Self::evict_older_than(&mut self.maxq, window_start);
1877 Self::evict_older_than(&mut self.minq, window_start);
1878 self.have_window = true;
1879 }
1880
1881 if !self.have_window {
1882 return None;
1883 }
1884
1885 debug_assert!(!self.maxq.is_empty() && !self.minq.is_empty());
1886 let hh = self.maxq.front().unwrap().val;
1887 let ll = self.minq.front().unwrap().val;
1888
1889 const SCALE: f64 = 100.0;
1890 const EPS: f64 = f64::EPSILON;
1891
1892 let denom = hh - ll;
1893 let k_raw = if denom.abs() < EPS {
1894 50.0
1895 } else {
1896 (close - ll).mul_add(SCALE / denom, 0.0)
1897 };
1898
1899 let k_last = if self.slowk_ma_type.eq_ignore_ascii_case("sma") {
1900 if self.slowk_period == 1 {
1901 k_raw
1902 } else if self.k_sma_count < self.slowk_period {
1903 self.k_sma_sum += k_raw;
1904 self.k_sma_buf[self.k_sma_head] = k_raw;
1905 self.k_sma_head = (self.k_sma_head + 1) % self.slowk_period;
1906 self.k_sma_count += 1;
1907 if self.k_sma_count == self.slowk_period {
1908 self.k_sma_sum / self.slowk_period as f64
1909 } else {
1910 f64::NAN
1911 }
1912 } else {
1913 let old = self.k_sma_buf[self.k_sma_head];
1914 self.k_sma_sum += k_raw - old;
1915 self.k_sma_buf[self.k_sma_head] = k_raw;
1916 self.k_sma_head = (self.k_sma_head + 1) % self.slowk_period;
1917 self.k_sma_sum / self.slowk_period as f64
1918 }
1919 } else if self.slowk_ma_type.eq_ignore_ascii_case("ema") {
1920 if self.slowk_period == 1 {
1921 self.k_ema = Some(k_raw);
1922 k_raw
1923 } else if self.k_ema.is_none() {
1924 self.k_ema_seed_sum += k_raw;
1925 self.k_ema_seed_count += 1;
1926 if self.k_ema_seed_count == self.slowk_period {
1927 let seed = self.k_ema_seed_sum / self.slowk_period as f64;
1928 self.k_ema = Some(seed);
1929 seed
1930 } else {
1931 f64::NAN
1932 }
1933 } else {
1934 let prev = self.k_ema.unwrap();
1935 let ema = prev + self.alpha_k * (k_raw - prev);
1936 self.k_ema = Some(ema);
1937 ema
1938 }
1939 } else {
1940 let mut k_vec = self
1941 .k_stream
1942 .take()
1943 .unwrap_or_else(|| vec![f64::NAN; self.slowk_period]);
1944 k_vec.remove(0);
1945 k_vec.push(k_raw);
1946 self.k_stream = Some(k_vec.clone());
1947
1948 match ma(
1949 &self.slowk_ma_type,
1950 MaData::Slice(&k_vec),
1951 self.slowk_period,
1952 ) {
1953 Ok(slowk) => *slowk.last().unwrap_or(&f64::NAN),
1954 Err(_) => k_raw,
1955 }
1956 };
1957
1958 let d_last = if self.slowd_ma_type.eq_ignore_ascii_case("sma") {
1959 if self.slowd_period == 1 {
1960 k_last
1961 } else if !k_last.is_finite() {
1962 f64::NAN
1963 } else if self.d_sma_count < self.slowd_period {
1964 self.d_sma_sum += k_last;
1965 self.d_sma_buf[self.d_sma_head] = k_last;
1966 self.d_sma_head = (self.d_sma_head + 1) % self.slowd_period;
1967 self.d_sma_count += 1;
1968 if self.d_sma_count == self.slowd_period {
1969 self.d_sma_sum / self.slowd_period as f64
1970 } else {
1971 f64::NAN
1972 }
1973 } else {
1974 let old = self.d_sma_buf[self.d_sma_head];
1975 self.d_sma_sum += k_last - old;
1976 self.d_sma_buf[self.d_sma_head] = k_last;
1977 self.d_sma_head = (self.d_sma_head + 1) % self.slowd_period;
1978 self.d_sma_sum / self.slowd_period as f64
1979 }
1980 } else if self.slowd_ma_type.eq_ignore_ascii_case("ema") {
1981 if self.slowd_period == 1 {
1982 self.d_ema = Some(k_last);
1983 k_last
1984 } else if !k_last.is_finite() {
1985 f64::NAN
1986 } else if self.d_ema.is_none() {
1987 self.d_ema_seed_sum += k_last;
1988 self.d_ema_seed_count += 1;
1989 if self.d_ema_seed_count == self.slowd_period {
1990 let seed = self.d_ema_seed_sum / self.slowd_period as f64;
1991 self.d_ema = Some(seed);
1992 seed
1993 } else {
1994 f64::NAN
1995 }
1996 } else {
1997 let prev = self.d_ema.unwrap();
1998 let ema = prev + self.alpha_d * (k_last - prev);
1999 self.d_ema = Some(ema);
2000 ema
2001 }
2002 } else {
2003 let mut d_vec = self
2004 .d_stream
2005 .take()
2006 .unwrap_or_else(|| vec![f64::NAN; self.slowd_period]);
2007 d_vec.remove(0);
2008 d_vec.push(k_last);
2009 self.d_stream = Some(d_vec.clone());
2010
2011 match ma(
2012 &self.slowd_ma_type,
2013 MaData::Slice(&d_vec),
2014 self.slowd_period,
2015 ) {
2016 Ok(slowd) => *slowd.last().unwrap_or(&f64::NAN),
2017 Err(_) => k_last,
2018 }
2019 };
2020
2021 Some((k_last, d_last))
2022 }
2023}
2024
2025#[cfg(all(feature = "python", feature = "cuda"))]
2026use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
2027#[cfg(all(feature = "python", feature = "cuda"))]
2028use cust::context::Context;
2029#[cfg(all(feature = "python", feature = "cuda"))]
2030use cust::memory::DeviceBuffer;
2031#[cfg(all(feature = "python", feature = "cuda"))]
2032use std::sync::Arc;
2033
/// Python handle to a row-major `rows x cols` f32 matrix in CUDA device memory.
///
/// The buffer can be consumed once through `__dlpack__`. The handle also keeps its CUDA
/// context (`_ctx`) alive for as long as it exists.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(
    module = "ta_indicators.cuda",
    name = "StochDeviceArrayF32",
    unsendable
)]
pub struct StochDeviceArrayF32Py {
    /// Device buffer; set to `None` once it has been handed out via `__dlpack__`.
    pub(crate) buf: Option<DeviceBuffer<f32>>,
    /// Number of rows of the matrix.
    pub(crate) rows: usize,
    /// Number of columns of the matrix (row stride in elements).
    pub(crate) cols: usize,
    /// Owning reference to the CUDA context the buffer was allocated in.
    pub(crate) _ctx: Arc<Context>,
    /// CUDA device ordinal reported by `__dlpack_device__`.
    pub(crate) device_id: u32,
}
2047
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl StochDeviceArrayF32Py {
    /// CUDA Array Interface (version 3) describing the buffer.
    ///
    /// The description is shape `(rows, cols)`, little-endian f32, C-contiguous byte
    /// strides, and a writable (`readonly = false`) device pointer. Raises `ValueError`
    /// once the buffer has been exported through `__dlpack__`.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let d = PyDict::new(py);
        d.set_item("shape", (self.rows, self.cols))?;
        d.set_item("typestr", "<f4")?;
        d.set_item(
            "strides",
            (
                self.cols * std::mem::size_of::<f32>(),
                std::mem::size_of::<f32>(),
            ),
        )?;
        let buf = self
            .buf
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("buffer already exported via __dlpack__"))?;
        let ptr = buf.as_device_ptr().as_raw() as usize;
        d.set_item("data", (ptr, false))?;
        d.set_item("version", 3)?;
        Ok(d)
    }

    /// DLPack device tuple `(device_type, device_id)`.
    ///
    /// The device type is 2, which is the CUDA device code in the DLPack spec.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.device_id as i32)
    }

    /// Exports the buffer as a DLPack capsule; may only succeed once.
    ///
    /// All argument validation runs before the buffer is taken, so a rejected call leaves
    /// the handle usable. `stream` is accepted but ignored. A `dl_device` that differs from
    /// this buffer's device is rejected, as is `copy=True`. `max_version` is forwarded to
    /// `export_f32_cuda_dlpack_2d`.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        let (kdl, alloc_dev) = self.__dlpack_device__();
        // A `dl_device` that is not an `(int, int)` tuple is silently ignored.
        if let Some(dev_obj) = dl_device.as_ref() {
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(PyValueError::new_err(
                            "stoch: device copy not implemented for __dlpack__",
                        ));
                    } else {
                        return Err(PyValueError::new_err(
                            "stoch: requested dl_device does not match buffer device",
                        ));
                    }
                }
            }
        }
        // The stream hint is not used for synchronisation here.
        let _ = stream;

        if let Some(copy_obj) = copy.as_ref() {
            let do_copy: bool = copy_obj.extract(py)?;
            if do_copy {
                return Err(PyValueError::new_err(
                    "stoch: __dlpack__(copy=True) not supported",
                ));
            }
        }

        // Taking the buffer makes later calls (and `__cuda_array_interface__`) fail.
        let buf = self
            .buf
            .take()
            .ok_or_else(|| PyValueError::new_err("__dlpack__ may only be called once"))?;

        let rows = self.rows;
        let cols = self.cols;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
2130
/// Python binding: batch Stochastic over a parameter sweep, computed on a CUDA device.
///
/// Period arguments are `(start, end, step)` ranges; the MA types are fixed strings.
/// Returns device-resident `(k, d)` matrices that share the CUDA context. Raises
/// `ValueError` if CUDA is unavailable or the CUDA computation fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "stoch_cuda_batch_dev")]
#[pyo3(signature = (high_f32, low_f32, close_f32, fastk_period=(14,14,0), slowk_period=(3,3,0), slowd_period=(3,3,0), slowk_ma_type="sma", slowd_ma_type="sma", device_id=0))]
pub fn stoch_cuda_batch_dev_py(
    py: Python<'_>,
    high_f32: numpy::PyReadonlyArray1<'_, f32>,
    low_f32: numpy::PyReadonlyArray1<'_, f32>,
    close_f32: numpy::PyReadonlyArray1<'_, f32>,
    fastk_period: (usize, usize, usize),
    slowk_period: (usize, usize, usize),
    slowd_period: (usize, usize, usize),
    slowk_ma_type: &str,
    slowd_ma_type: &str,
    device_id: usize,
) -> PyResult<(StochDeviceArrayF32Py, StochDeviceArrayF32Py)> {
    if !crate::cuda::cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let (h, l, c) = (
        high_f32.as_slice()?,
        low_f32.as_slice()?,
        close_f32.as_slice()?,
    );
    let sweep = StochBatchRange {
        fastk_period,
        slowk_period,
        slowk_ma_type: (slowk_ma_type.to_string(), slowk_ma_type.to_string(), 0.0),
        slowd_period,
        slowd_ma_type: (slowd_ma_type.to_string(), slowd_ma_type.to_string(), 0.0),
    };

    // Release the GIL for device setup and the kernel launch.
    let (k_buf, d_buf, rows, cols, ctx, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let cuda = crate::cuda::oscillators::CudaStoch::new(device_id)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        let batch = cuda
            .stoch_batch_dev(h, l, c, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        Ok((
            batch.k.buf,
            batch.d.buf,
            batch.k.rows,
            batch.k.cols,
            ctx,
            cuda.device_id(),
        ))
    })?;

    let wrap = |buf: DeviceBuffer<f32>, ctx: Arc<Context>| StochDeviceArrayF32Py {
        buf: Some(buf),
        rows,
        cols,
        _ctx: ctx,
        device_id: dev_id,
    };
    Ok((wrap(k_buf, ctx.clone()), wrap(d_buf, ctx)))
}
2193
/// Python binding: Stochastic for many series with one parameter set, on a CUDA device.
///
/// The flat f32 inputs are expected in time-major layout, as the `_tm_` argument names and
/// the `time_major` kernel entry point indicate (TODO confirm against the CUDA wrapper).
/// `cols` and `rows` are forwarded to the kernel in that order. The returned `(k, d)`
/// handles report shape `(rows, cols)` and share the CUDA context. Raises `ValueError` if
/// CUDA is unavailable or the CUDA computation fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "stoch_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm_f32, low_tm_f32, close_tm_f32, cols, rows, fastk_period=14, slowk_period=3, slowd_period=3, slowk_ma_type="sma", slowd_ma_type="sma", device_id=0))]
pub fn stoch_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    high_tm_f32: numpy::PyReadonlyArray1<'_, f32>,
    low_tm_f32: numpy::PyReadonlyArray1<'_, f32>,
    close_tm_f32: numpy::PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    fastk_period: usize,
    slowk_period: usize,
    slowd_period: usize,
    slowk_ma_type: &str,
    slowd_ma_type: &str,
    device_id: usize,
) -> PyResult<(StochDeviceArrayF32Py, StochDeviceArrayF32Py)> {
    use crate::cuda::cuda_available;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let h = high_tm_f32.as_slice()?;
    let l = low_tm_f32.as_slice()?;
    let c = close_tm_f32.as_slice()?;
    let params = StochParams {
        fastk_period: Some(fastk_period),
        slowk_period: Some(slowk_period),
        slowk_ma_type: Some(slowk_ma_type.to_string()),
        slowd_period: Some(slowd_period),
        slowd_ma_type: Some(slowd_ma_type.to_string()),
    };
    // Release the GIL for device setup and the kernel launch.
    let (k_dev, d_dev, ctx, dev_id) = py.allow_threads(|| {
        let cuda = crate::cuda::oscillators::CudaStoch::new(device_id)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        let (k, d) = cuda
            .stoch_many_series_one_param_time_major_dev(h, l, c, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        Ok::<_, PyErr>((k.buf, d.buf, ctx, cuda.device_id()))
    })?;
    Ok((
        StochDeviceArrayF32Py {
            buf: Some(k_dev),
            rows,
            cols,
            _ctx: ctx.clone(),
            device_id: dev_id,
        },
        StochDeviceArrayF32Py {
            buf: Some(d_dev),
            rows,
            cols,
            _ctx: ctx,
            device_id: dev_id,
        },
    ))
}
2251
/// Python binding: slow %K and slow %D for float64 `high`, `low` and `close` arrays.
///
/// `kernel` is checked with `validate_kernel(kernel, false)`. The computation runs through
/// `stoch_with_kernel` with the GIL released, and its errors are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "stoch")]
#[pyo3(signature = (high, low, close, fastk_period=14, slowk_period=3, slowk_ma_type="sma", slowd_period=3, slowd_ma_type="sma", kernel=None))]
pub fn stoch_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    fastk_period: usize,
    slowk_period: usize,
    slowk_ma_type: &str,
    slowd_period: usize,
    slowd_ma_type: &str,
    kernel: Option<&str>,
) -> PyResult<(Bound<'py, PyArray1<f64>>, Bound<'py, PyArray1<f64>>)> {
    let (hi, lo, cl) = (high.as_slice()?, low.as_slice()?, close.as_slice()?);
    let kern = validate_kernel(kernel, false)?;
    let input = StochInput::from_slices(
        hi,
        lo,
        cl,
        StochParams {
            fastk_period: Some(fastk_period),
            slowk_period: Some(slowk_period),
            slowk_ma_type: Some(slowk_ma_type.to_string()),
            slowd_period: Some(slowd_period),
            slowd_ma_type: Some(slowd_ma_type.to_string()),
        },
    );

    let StochOutput { k, d } = py
        .allow_threads(|| stoch_with_kernel(&input, kern))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;
    Ok((k.into_pyarray(py), d.into_pyarray(py)))
}
2284
/// Python binding: sweeps Stochastic parameters over `(start, end, step)` ranges.
///
/// The slow-%K and slow-%D MA types are held fixed for the whole sweep. Returns a dict with:
/// - `k`, `d`: `(rows, cols)` float64 arrays, one row per parameter combination;
/// - `fastk_periods`, `slowk_periods`, `slowd_periods`: per-row uint64 arrays;
/// - `slowk_types`, `slowd_types`: per-row lists of MA type names.
///
/// Calculation failures, an inconsistent result shape, and a combination missing a
/// period are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "stoch_batch")]
#[pyo3(signature = (high, low, close, fastk_range, slowk_range, slowk_ma_type, slowd_range, slowd_ma_type, kernel=None))]
pub fn stoch_batch_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    fastk_range: (usize, usize, usize),
    slowk_range: (usize, usize, usize),
    slowk_ma_type: &str,
    slowd_range: (usize, usize, usize),
    slowd_ma_type: &str,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    /// Collects one period column from the combos. A missing period becomes a
    /// `ValueError` instead of a panic.
    fn period_column(
        combos: &[StochParams],
        name: &str,
        pick: fn(&StochParams) -> Option<usize>,
    ) -> PyResult<Vec<u64>> {
        combos
            .iter()
            .map(|p| {
                pick(p).map(|v| v as u64).ok_or_else(|| {
                    PyValueError::new_err(format!("stoch_batch: combo is missing {name}"))
                })
            })
            .collect()
    }

    let hi = high.as_slice()?;
    let lo = low.as_slice()?;
    let cl = close.as_slice()?;

    // A (name, name, 0.0) triple pins each MA type to a single value across the sweep.
    let sweep = StochBatchRange {
        fastk_period: fastk_range,
        slowk_period: slowk_range,
        slowk_ma_type: (slowk_ma_type.to_string(), slowk_ma_type.to_string(), 0.0),
        slowd_period: slowd_range,
        slowd_ma_type: (slowd_ma_type.to_string(), slowd_ma_type.to_string(), 0.0),
    };

    let kern = validate_kernel(kernel, true)?;
    let out = py
        .allow_threads(|| stoch_batch_with_kernel(hi, lo, cl, &sweep, kern))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let rows = out.rows;
    let cols = out.cols;
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("stoch_batch: size overflow in rows*cols"))?;
    // Report a shape mismatch as a Python error rather than panicking mid-copy.
    if out.k.len() != total || out.d.len() != total {
        return Err(PyValueError::new_err(
            "stoch_batch: output length does not match rows*cols",
        ));
    }

    let dict = PyDict::new(py);

    // Hand the computed buffers to NumPy without an extra copy, viewed as (rows, cols).
    dict.set_item("k", out.k.into_pyarray(py).reshape((rows, cols))?)?;
    dict.set_item("d", out.d.into_pyarray(py).reshape((rows, cols))?)?;

    let combos = out.combos;
    dict.set_item(
        "fastk_periods",
        period_column(&combos, "fastk_period", |p| p.fastk_period)?.into_pyarray(py),
    )?;
    dict.set_item(
        "slowk_periods",
        period_column(&combos, "slowk_period", |p| p.slowk_period)?.into_pyarray(py),
    )?;
    dict.set_item(
        "slowk_types",
        combos
            .iter()
            .map(|p| p.slowk_ma_type.as_deref().unwrap_or("sma"))
            .collect::<Vec<_>>(),
    )?;
    dict.set_item(
        "slowd_periods",
        period_column(&combos, "slowd_period", |p| p.slowd_period)?.into_pyarray(py),
    )?;
    dict.set_item(
        "slowd_types",
        combos
            .iter()
            .map(|p| p.slowd_ma_type.as_deref().unwrap_or("sma"))
            .collect::<Vec<_>>(),
    )?;

    Ok(dict)
}
2373
/// Python-facing streaming wrapper around `StochStream`, exposed to Python as `StochStream`.
#[cfg(feature = "python")]
#[pyclass(name = "StochStream")]
pub struct StochStreamPy {
    stream: StochStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl StochStreamPy {
    /// Builds a stream from explicit periods and MA type names.
    /// Errors from `StochStream::try_new` are raised as `ValueError`.
    #[new]
    fn new(
        fastk_period: usize,
        slowk_period: usize,
        slowk_ma_type: &str,
        slowd_period: usize,
        slowd_ma_type: &str,
    ) -> PyResult<Self> {
        let params = StochParams {
            fastk_period: Some(fastk_period),
            slowk_period: Some(slowk_period),
            slowk_ma_type: Some(slowk_ma_type.to_owned()),
            slowd_period: Some(slowd_period),
            slowd_ma_type: Some(slowd_ma_type.to_owned()),
        };
        StochStream::try_new(params)
            .map(|stream| Self { stream })
            .map_err(|err| PyValueError::new_err(err.to_string()))
    }

    /// Feeds one bar into the stream and returns whatever `StochStream::update` yields:
    /// `Some((k, d))` or `None`.
    fn update(&mut self, high: f64, low: f64, close: f64) -> Option<(f64, f64)> {
        self.stream.update(high, low, close)
    }
}
2407
/// Serialized result of the single-series `stoch` wasm binding.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct StochResult {
    /// Flattened output: the %K series followed by the %D series.
    pub values: Vec<f64>,
    /// Number of series stored in `values` (the `stoch` binding sets this to 2).
    pub rows: usize,
    /// Length of each series in `values`.
    pub cols: usize,
}
2415
/// wasm binding: computes the Stochastic oscillator for one parameter set.
///
/// Returns a serialized `StochResult` whose `values` holds the %K series followed by the
/// %D series (`rows == 2`), each `cols` long. Calculation and serialization failures are
/// returned as string `JsValue` errors.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = stoch)]
pub fn stoch_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    fastk_period: usize,
    slowk_period: usize,
    slowk_ma_type: &str,
    slowd_period: usize,
    slowd_ma_type: &str,
) -> Result<JsValue, JsValue> {
    let params = StochParams {
        fastk_period: Some(fastk_period),
        slowk_period: Some(slowk_period),
        slowk_ma_type: Some(slowk_ma_type.to_string()),
        slowd_period: Some(slowd_period),
        slowd_ma_type: Some(slowd_ma_type.to_string()),
    };
    let input = StochInput::from_slices(high, low, close, params);
    let StochOutput { k, d } = stoch_with_kernel(&input, detect_best_kernel())
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    // Take the row width from the computed series so `cols` always describes `values`,
    // instead of assuming it equals `high.len()`.
    let cols = k.len();
    let mut values = k;
    values.reserve_exact(d.len());
    values.extend_from_slice(&d);

    serde_wasm_bindgen::to_value(&StochResult {
        values,
        rows: 2,
        cols,
    })
    .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2447
/// Serialized result of the `stoch_batch` wasm binding.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct StochBatchJsOutput {
    /// The whole batch %K matrix followed by the whole batch %D matrix.
    /// NOTE(review): this is block-ordered (all K rows, then all D rows), not K/D
    /// interleaved per combo, despite `rows_per_combo`. Confirm consumers expect this.
    pub values: Vec<f64>,
    /// Parameter set for each row, in row order.
    pub combos: Vec<StochParams>,
    /// Number of output series per combo (set to 2: %K and %D).
    pub rows_per_combo: usize,
    /// Length of every row in `values`.
    pub cols: usize,
}
2456
/// wasm binding: sweeps Stochastic parameters and returns a serialized `StochBatchJsOutput`.
///
/// The MA types are held fixed across the sweep. `values` holds the batch %K matrix
/// followed by the batch %D matrix. Each matrix is row-major, `cols` wide, presumably one
/// row per entry of `combos`. Errors are returned as string `JsValue`s.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = stoch_batch)]
pub fn stoch_batch_unified_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    fastk_start: usize,
    fastk_end: usize,
    fastk_step: usize,
    slowk_start: usize,
    slowk_end: usize,
    slowk_step: usize,
    slowk_ma_type: &str,
    slowd_start: usize,
    slowd_end: usize,
    slowd_step: usize,
    slowd_ma_type: &str,
) -> Result<JsValue, JsValue> {
    // A (name, name, 0.0) triple pins each MA type to a single value.
    let sweep = StochBatchRange {
        fastk_period: (fastk_start, fastk_end, fastk_step),
        slowk_period: (slowk_start, slowk_end, slowk_step),
        slowk_ma_type: (slowk_ma_type.to_string(), slowk_ma_type.to_string(), 0.0),
        slowd_period: (slowd_start, slowd_end, slowd_step),
        slowd_ma_type: (slowd_ma_type.to_string(), slowd_ma_type.to_string(), 0.0),
    };
    let out = stoch_batch_inner(high, low, close, &sweep, detect_best_kernel(), false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    // Reuse the K buffer (no full clone) and grow it once to append D.
    let mut values = out.k;
    values.reserve_exact(out.d.len());
    values.extend_from_slice(&out.d);

    let js = StochBatchJsOutput {
        values,
        combos: out.combos,
        rows_per_combo: 2,
        cols: out.cols,
    };
    serde_wasm_bindgen::to_value(&js)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2495
/// Reserves room for `len` f64 values in wasm linear memory and returns the address.
///
/// The memory is not initialised. Ownership passes to the caller, who should release it
/// with `stoch_free(ptr, len)` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn stoch_alloc(len: usize) -> *mut f64 {
    // Wrapping in ManuallyDrop leaks the Vec so its allocation outlives this call.
    core::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len)).as_mut_ptr()
}
2504
/// Releases a buffer previously returned by `stoch_alloc(len)`.
///
/// `len` must be the same value that was passed to `stoch_alloc`. A null pointer is
/// ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn stoch_free(ptr: *mut f64, len: usize) {
    // `Vec::from_raw_parts` requires a non-null pointer.
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` is assumed to come from `stoch_alloc(len)`, i.e. a `Vec<f64>`
    // allocation of capacity `len`. The length is 0 so no element is claimed to be
    // initialised (the caller may never have written the buffer); dropping only
    // deallocates.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
2512
/// wasm binding: computes Stochastic %K/%D from raw pointers into caller-provided buffers.
///
/// Each pointer is expected to address `len` f64 values in wasm linear memory (caller's
/// responsibility). Outputs may overlap inputs (in-place use): the result is then computed
/// into temporary buffers and copied out. Errors are returned as string `JsValue`s for:
/// - null pointers;
/// - `out_k_ptr`/`out_d_ptr` overlapping each other;
/// - calculation failures.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = stoch_into)]
pub fn stoch_into_js(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    len: usize,
    fastk_period: usize,
    slowk_period: usize,
    slowk_ma_type: &str,
    slowd_period: usize,
    slowd_ma_type: &str,
    out_k_ptr: *mut f64,
    out_d_ptr: *mut f64,
) -> Result<(), JsValue> {
    /// True when the `len`-element f64 buffers starting at `a` and `b` share any byte.
    fn overlaps(a: *const f64, b: *const f64, len: usize) -> bool {
        if len == 0 {
            return false;
        }
        let bytes = len.saturating_mul(core::mem::size_of::<f64>());
        let (a, b) = (a as usize, b as usize);
        a < b.saturating_add(bytes) && b < a.saturating_add(bytes)
    }

    if [high_ptr, low_ptr, close_ptr, out_k_ptr, out_d_ptr]
        .iter()
        .any(|p| p.is_null())
    {
        return Err(JsValue::from_str("null pointer"));
    }
    // Two `&mut` slices over the same memory would be undefined behaviour.
    if overlaps(out_k_ptr, out_d_ptr, len) {
        return Err(JsValue::from_str(
            "out_k_ptr and out_d_ptr must not overlap",
        ));
    }

    let params = StochParams {
        fastk_period: Some(fastk_period),
        slowk_period: Some(slowk_period),
        slowk_ma_type: Some(slowk_ma_type.to_string()),
        slowd_period: Some(slowd_period),
        slowd_ma_type: Some(slowd_ma_type.to_string()),
    };
    let kernel = detect_best_kernel();

    let outputs_alias_inputs = [high_ptr, low_ptr, close_ptr]
        .iter()
        .any(|&src| overlaps(src, out_k_ptr, len) || overlaps(src, out_d_ptr, len));

    if outputs_alias_inputs {
        // In-place call: compute into owned buffers while only shared borrows of the
        // inputs are live, then copy into the destinations afterwards.
        let out = {
            // SAFETY: pointers are non-null (checked above) and assumed to address `len`
            // readable f64s; no mutable reference exists while these slices live.
            let (hi, lo, cl) = unsafe {
                (
                    core::slice::from_raw_parts(high_ptr, len),
                    core::slice::from_raw_parts(low_ptr, len),
                    core::slice::from_raw_parts(close_ptr, len),
                )
            };
            let input = StochInput::from_slices(hi, lo, cl, params);
            stoch_with_kernel(&input, kernel).map_err(|e| JsValue::from_str(&e.to_string()))?
        };
        if out.k.len() != len || out.d.len() != len {
            return Err(JsValue::from_str("stoch_into: output length mismatch"));
        }
        // SAFETY: the input slices above are no longer live, and the two outputs were
        // checked not to overlap, so each mutable slice is the only live reference to it.
        unsafe {
            core::slice::from_raw_parts_mut(out_k_ptr, len).copy_from_slice(&out.k);
            core::slice::from_raw_parts_mut(out_d_ptr, len).copy_from_slice(&out.d);
        }
        Ok(())
    } else {
        // SAFETY: pointers are non-null and assumed to address `len` f64s. Neither output
        // overlaps an input or the other output, so the shared and mutable slices are
        // disjoint.
        unsafe {
            let hi = core::slice::from_raw_parts(high_ptr, len);
            let lo = core::slice::from_raw_parts(low_ptr, len);
            let cl = core::slice::from_raw_parts(close_ptr, len);
            let ok = core::slice::from_raw_parts_mut(out_k_ptr, len);
            let od = core::slice::from_raw_parts_mut(out_d_ptr, len);
            let input = StochInput::from_slices(hi, lo, cl, params);
            stoch_into_slices(ok, od, &input, kernel)
                .map_err(|e| JsValue::from_str(&e.to_string()))
        }
    }
}
2552
2553#[cfg(test)]
2554mod tests {
2555 use super::*;
2556 use crate::skip_if_unsupported;
2557 use crate::utilities::data_loader::read_candles_from_csv;
2558 use crate::utilities::enums::Kernel;
2559
2560 fn check_stoch_partial_params(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2561 skip_if_unsupported!(kernel, test);
2562 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2563 let candles = read_candles_from_csv(file_path)?;
2564 let default_params = StochParams::default();
2565 let input = StochInput::from_candles(&candles, default_params);
2566 let output = stoch_with_kernel(&input, kernel)?;
2567 assert_eq!(output.k.len(), candles.close.len());
2568 assert_eq!(output.d.len(), candles.close.len());
2569 Ok(())
2570 }
2571 fn check_stoch_accuracy(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2572 skip_if_unsupported!(kernel, test);
2573 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2574 let candles = read_candles_from_csv(file_path)?;
2575 let input = StochInput::from_candles(&candles, StochParams::default());
2576 let result = stoch_with_kernel(&input, kernel)?;
2577 assert_eq!(result.k.len(), candles.close.len());
2578 assert_eq!(result.d.len(), candles.close.len());
2579 let last_five_k = [
2580 42.51122827572717,
2581 40.13864479593807,
2582 37.853934778363374,
2583 37.337021714266086,
2584 36.26053890551548,
2585 ];
2586 let last_five_d = [
2587 41.36561869426493,
2588 41.7691857059163,
2589 40.16793595000925,
2590 38.44320042952222,
2591 37.15049846604803,
2592 ];
2593 let k_slice = &result.k[result.k.len() - 5..];
2594 let d_slice = &result.d[result.d.len() - 5..];
2595 for i in 0..5 {
2596 assert!(
2597 (k_slice[i] - last_five_k[i]).abs() < 1e-6,
2598 "Mismatch in K at {}: got {}, expected {}",
2599 i,
2600 k_slice[i],
2601 last_five_k[i]
2602 );
2603 assert!(
2604 (d_slice[i] - last_five_d[i]).abs() < 1e-6,
2605 "Mismatch in D at {}: got {}, expected {}",
2606 i,
2607 d_slice[i],
2608 last_five_d[i]
2609 );
2610 }
2611 Ok(())
2612 }
2613 fn check_stoch_default_candles(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2614 skip_if_unsupported!(kernel, test);
2615 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2616 let candles = read_candles_from_csv(file_path)?;
2617 let input = StochInput::with_default_candles(&candles);
2618 let output = stoch_with_kernel(&input, kernel)?;
2619 assert_eq!(output.k.len(), candles.close.len());
2620 assert_eq!(output.d.len(), candles.close.len());
2621 Ok(())
2622 }
2623 fn check_stoch_zero_period(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2624 skip_if_unsupported!(kernel, test);
2625 let high = [10.0, 11.0, 12.0];
2626 let low = [9.0, 9.5, 10.5];
2627 let close = [9.5, 10.6, 11.5];
2628 let params = StochParams {
2629 fastk_period: Some(0),
2630 ..Default::default()
2631 };
2632 let input = StochInput::from_slices(&high, &low, &close, params);
2633 let result = stoch_with_kernel(&input, kernel);
2634 assert!(result.is_err());
2635 Ok(())
2636 }
2637 fn check_stoch_period_exceeds_length(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2638 skip_if_unsupported!(kernel, test);
2639 let high = [10.0, 11.0, 12.0];
2640 let low = [9.0, 9.5, 10.5];
2641 let close = [9.5, 10.6, 11.5];
2642 let params = StochParams {
2643 fastk_period: Some(10),
2644 ..Default::default()
2645 };
2646 let input = StochInput::from_slices(&high, &low, &close, params);
2647 let result = stoch_with_kernel(&input, kernel);
2648 assert!(result.is_err());
2649 Ok(())
2650 }
2651 fn check_stoch_all_nan(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2652 skip_if_unsupported!(kernel, test);
2653 let nan_data = [f64::NAN, f64::NAN, f64::NAN];
2654 let params = StochParams::default();
2655 let input = StochInput::from_slices(&nan_data, &nan_data, &nan_data, params);
2656 let result = stoch_with_kernel(&input, kernel);
2657 assert!(result.is_err());
2658 Ok(())
2659 }
2660
2661 #[cfg(debug_assertions)]
2662 fn check_stoch_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2663 skip_if_unsupported!(kernel, test_name);
2664
2665 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2666 let candles = read_candles_from_csv(file_path)?;
2667
2668 let test_params = vec![
2669 StochParams::default(),
2670 StochParams {
2671 fastk_period: Some(2),
2672 slowk_period: Some(1),
2673 slowd_period: Some(1),
2674 slowk_ma_type: Some("sma".to_string()),
2675 slowd_ma_type: Some("sma".to_string()),
2676 },
2677 StochParams {
2678 fastk_period: Some(5),
2679 slowk_period: Some(2),
2680 slowd_period: Some(2),
2681 slowk_ma_type: Some("sma".to_string()),
2682 slowd_ma_type: Some("sma".to_string()),
2683 },
2684 StochParams {
2685 fastk_period: Some(10),
2686 slowk_period: Some(5),
2687 slowd_period: Some(3),
2688 slowk_ma_type: Some("ema".to_string()),
2689 slowd_ma_type: Some("ema".to_string()),
2690 },
2691 StochParams {
2692 fastk_period: Some(14),
2693 slowk_period: Some(5),
2694 slowd_period: Some(5),
2695 slowk_ma_type: Some("sma".to_string()),
2696 slowd_ma_type: Some("ema".to_string()),
2697 },
2698 StochParams {
2699 fastk_period: Some(20),
2700 slowk_period: Some(3),
2701 slowd_period: Some(3),
2702 slowk_ma_type: Some("sma".to_string()),
2703 slowd_ma_type: Some("sma".to_string()),
2704 },
2705 StochParams {
2706 fastk_period: Some(50),
2707 slowk_period: Some(10),
2708 slowd_period: Some(10),
2709 slowk_ma_type: Some("ema".to_string()),
2710 slowd_ma_type: Some("sma".to_string()),
2711 },
2712 StochParams {
2713 fastk_period: Some(100),
2714 slowk_period: Some(20),
2715 slowd_period: Some(15),
2716 slowk_ma_type: Some("sma".to_string()),
2717 slowd_ma_type: Some("sma".to_string()),
2718 },
2719 StochParams {
2720 fastk_period: Some(7),
2721 slowk_period: Some(1),
2722 slowd_period: Some(7),
2723 slowk_ma_type: Some("sma".to_string()),
2724 slowd_ma_type: Some("ema".to_string()),
2725 },
2726 StochParams {
2727 fastk_period: Some(3),
2728 slowk_period: Some(3),
2729 slowd_period: Some(1),
2730 slowk_ma_type: Some("ema".to_string()),
2731 slowd_ma_type: Some("sma".to_string()),
2732 },
2733 ];
2734
2735 for (param_idx, params) in test_params.iter().enumerate() {
2736 let input = StochInput::from_candles(&candles, params.clone());
2737 let output = stoch_with_kernel(&input, kernel)?;
2738
2739 for (i, &val) in output.k.iter().enumerate() {
2740 if val.is_nan() {
2741 continue;
2742 }
2743
2744 let bits = val.to_bits();
2745
2746 if bits == 0x11111111_11111111 {
2747 panic!(
2748 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} in K values \
2749 with params: fastk_period={}, slowk_period={}, slowd_period={}, \
2750 slowk_ma_type={}, slowd_ma_type={} (param set {})",
2751 test_name, val, bits, i,
2752 params.fastk_period.unwrap_or(14),
2753 params.slowk_period.unwrap_or(3),
2754 params.slowd_period.unwrap_or(3),
2755 params.slowk_ma_type.as_deref().unwrap_or("sma"),
2756 params.slowd_ma_type.as_deref().unwrap_or("sma"),
2757 param_idx
2758 );
2759 }
2760
2761 if bits == 0x22222222_22222222 {
2762 panic!(
2763 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} in K values \
2764 with params: fastk_period={}, slowk_period={}, slowd_period={}, \
2765 slowk_ma_type={}, slowd_ma_type={} (param set {})",
2766 test_name, val, bits, i,
2767 params.fastk_period.unwrap_or(14),
2768 params.slowk_period.unwrap_or(3),
2769 params.slowd_period.unwrap_or(3),
2770 params.slowk_ma_type.as_deref().unwrap_or("sma"),
2771 params.slowd_ma_type.as_deref().unwrap_or("sma"),
2772 param_idx
2773 );
2774 }
2775
2776 if bits == 0x33333333_33333333 {
2777 panic!(
2778 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} in K values \
2779 with params: fastk_period={}, slowk_period={}, slowd_period={}, \
2780 slowk_ma_type={}, slowd_ma_type={} (param set {})",
2781 test_name, val, bits, i,
2782 params.fastk_period.unwrap_or(14),
2783 params.slowk_period.unwrap_or(3),
2784 params.slowd_period.unwrap_or(3),
2785 params.slowk_ma_type.as_deref().unwrap_or("sma"),
2786 params.slowd_ma_type.as_deref().unwrap_or("sma"),
2787 param_idx
2788 );
2789 }
2790 }
2791
2792 for (i, &val) in output.d.iter().enumerate() {
2793 if val.is_nan() {
2794 continue;
2795 }
2796
2797 let bits = val.to_bits();
2798
2799 if bits == 0x11111111_11111111 {
2800 panic!(
2801 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} in D values \
2802 with params: fastk_period={}, slowk_period={}, slowd_period={}, \
2803 slowk_ma_type={}, slowd_ma_type={} (param set {})",
2804 test_name, val, bits, i,
2805 params.fastk_period.unwrap_or(14),
2806 params.slowk_period.unwrap_or(3),
2807 params.slowd_period.unwrap_or(3),
2808 params.slowk_ma_type.as_deref().unwrap_or("sma"),
2809 params.slowd_ma_type.as_deref().unwrap_or("sma"),
2810 param_idx
2811 );
2812 }
2813
2814 if bits == 0x22222222_22222222 {
2815 panic!(
2816 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} in D values \
2817 with params: fastk_period={}, slowk_period={}, slowd_period={}, \
2818 slowk_ma_type={}, slowd_ma_type={} (param set {})",
2819 test_name, val, bits, i,
2820 params.fastk_period.unwrap_or(14),
2821 params.slowk_period.unwrap_or(3),
2822 params.slowd_period.unwrap_or(3),
2823 params.slowk_ma_type.as_deref().unwrap_or("sma"),
2824 params.slowd_ma_type.as_deref().unwrap_or("sma"),
2825 param_idx
2826 );
2827 }
2828
2829 if bits == 0x33333333_33333333 {
2830 panic!(
2831 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} in D values \
2832 with params: fastk_period={}, slowk_period={}, slowd_period={}, \
2833 slowk_ma_type={}, slowd_ma_type={} (param set {})",
2834 test_name, val, bits, i,
2835 params.fastk_period.unwrap_or(14),
2836 params.slowk_period.unwrap_or(3),
2837 params.slowd_period.unwrap_or(3),
2838 params.slowk_ma_type.as_deref().unwrap_or("sma"),
2839 params.slowd_ma_type.as_deref().unwrap_or("sma"),
2840 param_idx
2841 );
2842 }
2843 }
2844 }
2845
2846 Ok(())
2847 }
2848
    /// Release-build stub: the poison-value scan is only compiled under
    /// `debug_assertions`, so this variant always returns `Ok(())`.
    #[cfg(not(debug_assertions))]
    fn check_stoch_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
2853
    /// Property-based check of the Stochastic oscillator.
    ///
    /// Generates random high/low/close series (optionally with a compounded trend, or
    /// completely flat). Runs `stoch_with_kernel` with `kernel` and with `Kernel::Scalar`,
    /// then checks:
    /// - output lengths and the initial NaN warmup;
    /// - the [0, 100] range and agreement with the scalar kernel;
    /// - %K behaviour in flat or extreme windows;
    /// - %D smoothing and identity properties;
    /// - that SMA and EMA smoothing give different results.
    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_stoch_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        // Stage 1 draws:
        // - fastk_period in [2, 50];
        // - (base_price, volatility) pairs, at least max(fastk_period, 10) and under 400;
        // - slowk/slowd periods in [1, 10];
        // - an EMA-vs-SMA flag, a per-bar trend factor, and a flat-market flag.
        let strat = (2usize..=50)
            .prop_flat_map(|fastk_period| {
                (
                    prop::collection::vec(
                        (1.0f64..1000.0f64, 0.001f64..0.1f64),
                        fastk_period.max(10)..400,
                    ),
                    Just(fastk_period),
                    1usize..=10,
                    1usize..=10,
                    prop::bool::ANY,
                    -0.01f64..0.01f64,
                    prop::bool::ANY,
                )
            })
            // Stage 2: one close factor in [-1, 1) and one beta value in [0, 1) per bar.
            .prop_flat_map(
                |(
                    price_vol_pairs,
                    fastk_period,
                    slowk_period,
                    slowd_period,
                    use_ema,
                    trend,
                    is_flat,
                )| {
                    let len = price_vol_pairs.len();
                    (
                        Just((
                            price_vol_pairs,
                            fastk_period,
                            slowk_period,
                            slowd_period,
                            use_ema,
                            trend,
                            is_flat,
                        )),
                        prop::collection::vec(-1.0f64..1.0f64, len),
                        prop::collection::vec(0.0f64..1.0f64, len),
                    )
                },
            )
            // Stage 3: materialise the high/low/close series.
            .prop_map(
                |(
                    (
                        price_vol_pairs,
                        fastk_period,
                        slowk_period,
                        slowd_period,
                        use_ema,
                        trend,
                        is_flat,
                    ),
                    close_factors,
                    beta_params,
                )| {
                    let mut high = Vec::with_capacity(price_vol_pairs.len());
                    let mut low = Vec::with_capacity(price_vol_pairs.len());
                    let mut close = Vec::with_capacity(price_vol_pairs.len());

                    // Compounded trend multiplier applied to each bar's base price.
                    let mut cumulative_trend = 1.0;

                    for (i, ((base_price, volatility), (close_factor, beta))) in price_vol_pairs
                        .into_iter()
                        .zip(close_factors.into_iter().zip(beta_params))
                        .enumerate()
                    {
                        cumulative_trend *= 1.0 + trend;
                        let trended_price = base_price * cumulative_trend;

                        if is_flat {
                            // Flat market: every bar repeats the first bar's base price,
                            // so high == low == close throughout.
                            let flat_price = if i == 0 { base_price } else { high[0] };
                            high.push(flat_price);
                            low.push(flat_price);
                            close.push(flat_price);
                        } else {
                            // Symmetric spread around the trended price; low is floored at 0.01.
                            let spread = trended_price * volatility;
                            let h = trended_price + spread;
                            let l = (trended_price - spread).max(0.01);

                            // Piecewise-quadratic remap of beta from [0, 1) onto [0, 1).
                            let beta_factor = if beta < 0.5 {
                                2.0 * beta * beta
                            } else {
                                1.0 - 2.0 * (1.0 - beta) * (1.0 - beta)
                            };

                            // close_position lies in [-0.5, 1.0), so c lies in
                            // [l + 25% of (h - l), h) before the clamp.
                            let close_position = close_factor * 0.5 + beta_factor * 0.5;
                            let c = l + (h - l) * ((close_position + 1.0) / 2.0);

                            high.push(h);
                            low.push(l);
                            close.push(c.clamp(l, h));
                        }
                    }

                    // One MA type is used for both the slow %K and the %D smoothing.
                    let ma_type = if use_ema { "ema" } else { "sma" };

                    (
                        high,
                        low,
                        close,
                        fastk_period,
                        slowk_period,
                        slowd_period,
                        ma_type.to_string(),
                        is_flat,
                    )
                },
            );

        proptest::test_runner::TestRunner::default().run(
            &strat,
            |(high, low, close, fastk_period, slowk_period, slowd_period, ma_type, is_flat)| {
                let params = StochParams {
                    fastk_period: Some(fastk_period),
                    slowk_period: Some(slowk_period),
                    slowk_ma_type: Some(ma_type.clone()),
                    slowd_period: Some(slowd_period),
                    slowd_ma_type: Some(ma_type.clone()),
                };

                let input = StochInput::from_slices(&high, &low, &close, params.clone());

                // Kernel under test.
                let result = stoch_with_kernel(&input, kernel)?;

                // Scalar kernel as the reference implementation.
                let ref_result = stoch_with_kernel(&input, Kernel::Scalar)?;

                prop_assert_eq!(result.k.len(), high.len());
                prop_assert_eq!(result.d.len(), high.len());

                // Warmup model used by this test:
                // - raw %K needs fastk_period - 1 bars;
                // - each SMA stage adds period - 1 more;
                // - EMA stages are assumed to add none.
                // NOTE(review): this encodes an assumption about MA warmup; confirm it
                // against the `ma` implementation.
                let warmup_k = fastk_period - 1;
                let warmup_slowk = if ma_type == "ema" {
                    0
                } else {
                    slowk_period - 1
                };
                let warmup_slowd = if ma_type == "ema" {
                    0
                } else {
                    slowd_period - 1
                };
                let expected_warmup = warmup_k
                    .max(warmup_k + warmup_slowk)
                    .max(warmup_k + warmup_slowk + warmup_slowd);

                // Only the raw %K warmup prefix is required to be NaN in both outputs.
                for i in 0..warmup_k.min(high.len()) {
                    prop_assert!(
                        result.k[i].is_nan(),
                        "K[{}] should be NaN during initial warmup but was {}",
                        i,
                        result.k[i]
                    );
                    prop_assert!(
                        result.d[i].is_nan(),
                        "D[{}] should be NaN during initial warmup but was {}",
                        i,
                        result.d[i]
                    );
                }

                for i in expected_warmup..high.len() {
                    let k_val = result.k[i];
                    let d_val = result.d[i];
                    let ref_k = ref_result.k[i];
                    let ref_d = ref_result.d[i];

                    // Range check: any non-NaN value must lie in [0, 100] (±1e-9).
                    if !k_val.is_nan() {
                        prop_assert!(
                            k_val >= -1e-9 && k_val <= 100.0 + 1e-9,
                            "K[{}] = {} is outside [0, 100] range",
                            i,
                            k_val
                        );
                    }

                    if !d_val.is_nan() {
                        prop_assert!(
                            d_val >= -1e-9 && d_val <= 100.0 + 1e-9,
                            "D[{}] = {} is outside [0, 100] range",
                            i,
                            d_val
                        );
                    }

                    // Cross-kernel agreement: within 1e-9 absolute or 4 ULPs of the scalar result.
                    if k_val.is_finite() && ref_k.is_finite() {
                        let k_diff = (k_val - ref_k).abs();
                        let k_ulp_diff = k_val.to_bits().abs_diff(ref_k.to_bits());
                        prop_assert!(
                            k_diff <= 1e-9 || k_ulp_diff <= 4,
                            "K mismatch at [{}]: {} vs {} (diff={}, ULP={})",
                            i,
                            k_val,
                            ref_k,
                            k_diff,
                            k_ulp_diff
                        );
                    }

                    if d_val.is_finite() && ref_d.is_finite() {
                        let d_diff = (d_val - ref_d).abs();
                        let d_ulp_diff = d_val.to_bits().abs_diff(ref_d.to_bits());
                        prop_assert!(
                            d_diff <= 1e-9 || d_ulp_diff <= 4,
                            "D mismatch at [{}]: {} vs {} (diff={}, ULP={})",
                            i,
                            d_val,
                            ref_d,
                            d_diff,
                            d_ulp_diff
                        );
                    }

                    // Window-level %K checks. Since i >= expected_warmup >= fastk_period - 1
                    // here, the first condition always holds.
                    if i >= fastk_period - 1 && !k_val.is_nan() {
                        let window_start = i + 1 - fastk_period;
                        let window_high = &high[window_start..=i];
                        let window_low = &low[window_start..=i];

                        let max_h = window_high
                            .iter()
                            .cloned()
                            .fold(f64::NEG_INFINITY, f64::max);
                        let min_l = window_low.iter().cloned().fold(f64::INFINITY, f64::min);

                        if is_flat || (max_h - min_l).abs() < f64::EPSILON {
                            // Zero-range window: the test expects %K to be reported as 50.
                            prop_assert!(
                                (k_val - 50.0).abs() < 1e-6,
                                "K[{}] = {} should be 50 in flat market",
                                i,
                                k_val
                            );
                        } else {
                            // NOTE(review): the generator places close in
                            // [low + 25% of range, high), so these equality branches look
                            // effectively unreachable. Also, for slowk_period > 1 the
                            // smoothed %K blends earlier raw values, so the 85 / 15 bounds
                            // are not implied by the indicator's definition. Confirm both
                            // before relying on these assertions.
                            if (close[i] - max_h).abs() < 1e-10 {
                                let expected_min = if slowk_period == 1 { 99.0 } else { 85.0 };
                                prop_assert!(
                                    k_val >= expected_min,
                                    "K[{}] = {} should be >= {} when close equals highest high (slowk_period={})",
                                    i, k_val, expected_min, slowk_period
                                );
                            }

                            if (close[i] - min_l).abs() < 1e-10 {
                                let expected_max = if slowk_period == 1 { 1.0 } else { 15.0 };
                                prop_assert!(
                                    k_val <= expected_max,
                                    "K[{}] = {} should be <= {} when close equals lowest low (slowk_period={})",
                                    i, k_val, expected_max, slowk_period
                                );
                            }
                        }
                    }
                }

                // Aggregate checks over all finite output values.
                let k_valid: Vec<f64> =
                    result.k.iter().filter(|x| x.is_finite()).copied().collect();
                let d_valid: Vec<f64> =
                    result.d.iter().filter(|x| x.is_finite()).copied().collect();

                if k_valid.len() > 10 && d_valid.len() > 10 && !is_flat {
                    let k_mean = k_valid.iter().sum::<f64>() / k_valid.len() as f64;
                    let d_mean = d_valid.iter().sum::<f64>() / d_valid.len() as f64;

                    // Population variances.
                    let k_var = k_valid.iter().map(|x| (x - k_mean).powi(2)).sum::<f64>()
                        / k_valid.len() as f64;
                    let d_var = d_valid.iter().map(|x| (x - d_mean).powi(2)).sum::<f64>()
                        / d_valid.len() as f64;

                    // Smoothing %K into %D should not increase dispersion (1% tolerance).
                    if slowd_period > 1 && k_var > 1e-6 {
                        prop_assert!(
                            d_var <= k_var * 1.01,
                            "D variance {} should be <= K variance {} (smoothing effect with slowd_period={})",
                            d_var, k_var, slowd_period
                        );
                    }

                    // With a one-bar %D smoothing period, %D is expected to equal %K.
                    if slowd_period == 1 {
                        for i in expected_warmup..result.k.len() {
                            if result.k[i].is_finite() && result.d[i].is_finite() {
                                prop_assert!(
                                    (result.k[i] - result.d[i]).abs() < 1e-9,
                                    "When slowd_period=1, D[{}]={} should equal K[{}]={}",
                                    i,
                                    result.d[i],
                                    i,
                                    result.k[i]
                                );
                            }
                        }
                    }
                }

                // Swapping SMA <-> EMA (with slowk_period > 1) should change at least 80%
                // of the comparable %K values.
                if !is_flat && high.len() > fastk_period + 10 {
                    let opposite_ma_type = if ma_type == "sma" { "ema" } else { "sma" };
                    let opposite_params = StochParams {
                        fastk_period: Some(fastk_period),
                        slowk_period: Some(slowk_period),
                        slowk_ma_type: Some(opposite_ma_type.to_string()),
                        slowd_period: Some(slowd_period),
                        slowd_ma_type: Some(opposite_ma_type.to_string()),
                    };

                    let opposite_input =
                        StochInput::from_slices(&high, &low, &close, opposite_params);
                    let opposite_result = stoch_with_kernel(&opposite_input, kernel)?;

                    let mut diff_count = 0;
                    let mut total_valid = 0;
                    for i in expected_warmup..result.k.len() {
                        if result.k[i].is_finite() && opposite_result.k[i].is_finite() {
                            total_valid += 1;
                            if (result.k[i] - opposite_result.k[i]).abs() > 1e-6 {
                                diff_count += 1;
                            }
                        }
                    }

                    if total_valid > 10 && slowk_period > 1 {
                        let diff_ratio = diff_count as f64 / total_valid as f64;
                        prop_assert!(
                            diff_ratio >= 0.8,
                            "SMA and EMA should produce different results: only {}/{} values differ ({}%)",
                            diff_count, total_valid, (diff_ratio * 100.0) as i32
                        );
                    }
                }

                Ok(())
            },
        )?;

        Ok(())
    }
3193
3194 macro_rules! generate_all_stoch_tests {
3195 ($($test_fn:ident),*) => {
3196 paste::paste! {
3197 $( #[test] fn [<$test_fn _scalar_f64>]() { let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar); } )*
3198 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3199 $( #[test] fn [<$test_fn _avx2_f64>]() { let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2); } )*
3200 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3201 $( #[test] fn [<$test_fn _avx512_f64>]() { let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512); } )*
3202 }
3203 }
3204 }
    // Instantiate per-kernel tests for the single-series checks defined above.
    generate_all_stoch_tests!(
        check_stoch_partial_params,
        check_stoch_accuracy,
        check_stoch_default_candles,
        check_stoch_zero_period,
        check_stoch_period_exceeds_length,
        check_stoch_all_nan,
        check_stoch_no_poison
    );

    // The property test is only built when the `proptest` feature is enabled.
    #[cfg(feature = "proptest")]
    generate_all_stoch_tests!(check_stoch_property);
3217 fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
3218 skip_if_unsupported!(kernel, test);
3219
3220 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3221 let c = read_candles_from_csv(file)?;
3222
3223 let output = StochBatchBuilder::new().kernel(kernel).apply_candles(&c)?;
3224
3225 let def = StochParams::default();
3226 let (row_k, row_d) = output.values_for(&def).expect("default row missing");
3227
3228 assert_eq!(row_k.len(), c.close.len());
3229 assert_eq!(row_d.len(), c.close.len());
3230
3231 let expected_k = [
3232 42.51122827572717,
3233 40.13864479593807,
3234 37.853934778363374,
3235 37.337021714266086,
3236 36.26053890551548,
3237 ];
3238 let expected_d = [
3239 41.36561869426493,
3240 41.7691857059163,
3241 40.16793595000925,
3242 38.44320042952222,
3243 37.15049846604803,
3244 ];
3245 let start = row_k.len() - 5;
3246 for (i, &v) in row_k[start..].iter().enumerate() {
3247 assert!(
3248 (v - expected_k[i]).abs() < 1e-6,
3249 "[{test}] default-row K mismatch at idx {i}: {v} vs {expected_k:?}"
3250 );
3251 }
3252 for (i, &v) in row_d[start..].iter().enumerate() {
3253 assert!(
3254 (v - expected_d[i]).abs() < 1e-6,
3255 "[{test}] default-row D mismatch at idx {i}: {v} vs {expected_d:?}"
3256 );
3257 }
3258 Ok(())
3259 }
3260
3261 #[cfg(debug_assertions)]
3262 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
3263 skip_if_unsupported!(kernel, test);
3264
3265 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3266 let c = read_candles_from_csv(file)?;
3267
3268 let test_configs = vec![
3269 (2, 10, 2, 1, 5, 1, 1, 5, 1),
3270 (5, 25, 5, 2, 10, 2, 2, 10, 2),
3271 (10, 30, 10, 3, 9, 3, 3, 9, 3),
3272 (14, 14, 0, 1, 5, 1, 1, 5, 1),
3273 (2, 5, 1, 3, 3, 0, 3, 3, 0),
3274 (20, 50, 15, 5, 15, 5, 5, 15, 5),
3275 (7, 21, 7, 2, 6, 2, 2, 6, 2),
3276 (3, 12, 3, 1, 3, 1, 1, 3, 1),
3277 ];
3278
3279 for (
3280 cfg_idx,
3281 &(fk_start, fk_end, fk_step, sk_start, sk_end, sk_step, sd_start, sd_end, sd_step),
3282 ) in test_configs.iter().enumerate()
3283 {
3284 let output = StochBatchBuilder::new()
3285 .kernel(kernel)
3286 .fastk_period_range(fk_start, fk_end, fk_step)
3287 .slowk_period_range(sk_start, sk_end, sk_step)
3288 .slowd_period_range(sd_start, sd_end, sd_step)
3289 .slowk_ma_type_static("sma")
3290 .slowd_ma_type_static("sma")
3291 .apply_candles(&c)?;
3292
3293 for (idx, &val) in output.k.iter().enumerate() {
3294 if val.is_nan() {
3295 continue;
3296 }
3297
3298 let bits = val.to_bits();
3299 let row = idx / output.cols;
3300 let col = idx % output.cols;
3301 let combo = &output.combos[row];
3302
3303 if bits == 0x11111111_11111111 {
3304 panic!(
3305 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
3306 at row {} col {} (flat index {}) in K values with params: \
3307 fastk_period={}, slowk_period={}, slowd_period={}, \
3308 slowk_ma_type={}, slowd_ma_type={}",
3309 test,
3310 cfg_idx,
3311 val,
3312 bits,
3313 row,
3314 col,
3315 idx,
3316 combo.fastk_period.unwrap_or(14),
3317 combo.slowk_period.unwrap_or(3),
3318 combo.slowd_period.unwrap_or(3),
3319 combo.slowk_ma_type.as_deref().unwrap_or("sma"),
3320 combo.slowd_ma_type.as_deref().unwrap_or("sma")
3321 );
3322 }
3323
3324 if bits == 0x22222222_22222222 {
3325 panic!(
3326 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
3327 at row {} col {} (flat index {}) in K values with params: \
3328 fastk_period={}, slowk_period={}, slowd_period={}, \
3329 slowk_ma_type={}, slowd_ma_type={}",
3330 test,
3331 cfg_idx,
3332 val,
3333 bits,
3334 row,
3335 col,
3336 idx,
3337 combo.fastk_period.unwrap_or(14),
3338 combo.slowk_period.unwrap_or(3),
3339 combo.slowd_period.unwrap_or(3),
3340 combo.slowk_ma_type.as_deref().unwrap_or("sma"),
3341 combo.slowd_ma_type.as_deref().unwrap_or("sma")
3342 );
3343 }
3344
3345 if bits == 0x33333333_33333333 {
3346 panic!(
3347 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
3348 at row {} col {} (flat index {}) in K values with params: \
3349 fastk_period={}, slowk_period={}, slowd_period={}, \
3350 slowk_ma_type={}, slowd_ma_type={}",
3351 test,
3352 cfg_idx,
3353 val,
3354 bits,
3355 row,
3356 col,
3357 idx,
3358 combo.fastk_period.unwrap_or(14),
3359 combo.slowk_period.unwrap_or(3),
3360 combo.slowd_period.unwrap_or(3),
3361 combo.slowk_ma_type.as_deref().unwrap_or("sma"),
3362 combo.slowd_ma_type.as_deref().unwrap_or("sma")
3363 );
3364 }
3365 }
3366
3367 for (idx, &val) in output.d.iter().enumerate() {
3368 if val.is_nan() {
3369 continue;
3370 }
3371
3372 let bits = val.to_bits();
3373 let row = idx / output.cols;
3374 let col = idx % output.cols;
3375 let combo = &output.combos[row];
3376
3377 if bits == 0x11111111_11111111 {
3378 panic!(
3379 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
3380 at row {} col {} (flat index {}) in D values with params: \
3381 fastk_period={}, slowk_period={}, slowd_period={}, \
3382 slowk_ma_type={}, slowd_ma_type={}",
3383 test,
3384 cfg_idx,
3385 val,
3386 bits,
3387 row,
3388 col,
3389 idx,
3390 combo.fastk_period.unwrap_or(14),
3391 combo.slowk_period.unwrap_or(3),
3392 combo.slowd_period.unwrap_or(3),
3393 combo.slowk_ma_type.as_deref().unwrap_or("sma"),
3394 combo.slowd_ma_type.as_deref().unwrap_or("sma")
3395 );
3396 }
3397
3398 if bits == 0x22222222_22222222 {
3399 panic!(
3400 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
3401 at row {} col {} (flat index {}) in D values with params: \
3402 fastk_period={}, slowk_period={}, slowd_period={}, \
3403 slowk_ma_type={}, slowd_ma_type={}",
3404 test,
3405 cfg_idx,
3406 val,
3407 bits,
3408 row,
3409 col,
3410 idx,
3411 combo.fastk_period.unwrap_or(14),
3412 combo.slowk_period.unwrap_or(3),
3413 combo.slowd_period.unwrap_or(3),
3414 combo.slowk_ma_type.as_deref().unwrap_or("sma"),
3415 combo.slowd_ma_type.as_deref().unwrap_or("sma")
3416 );
3417 }
3418
3419 if bits == 0x33333333_33333333 {
3420 panic!(
3421 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
3422 at row {} col {} (flat index {}) in D values with params: \
3423 fastk_period={}, slowk_period={}, slowd_period={}, \
3424 slowk_ma_type={}, slowd_ma_type={}",
3425 test,
3426 cfg_idx,
3427 val,
3428 bits,
3429 row,
3430 col,
3431 idx,
3432 combo.fastk_period.unwrap_or(14),
3433 combo.slowk_period.unwrap_or(3),
3434 combo.slowd_period.unwrap_or(3),
3435 combo.slowk_ma_type.as_deref().unwrap_or("sma"),
3436 combo.slowd_ma_type.as_deref().unwrap_or("sma")
3437 );
3438 }
3439 }
3440 }
3441
3442 Ok(())
3443 }
3444
3445 #[cfg(not(debug_assertions))]
3446 fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
3447 Ok(())
3448 }
3449
3450 macro_rules! gen_batch_tests {
3451 ($fn_name:ident) => {
3452 paste::paste! {
3453 #[test] fn [<$fn_name _scalar>]() {
3454 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
3455 }
3456 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3457 #[test] fn [<$fn_name _avx2>]() {
3458 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
3459 }
3460 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3461 #[test] fn [<$fn_name _avx512>]() {
3462 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
3463 }
3464 #[test] fn [<$fn_name _auto_detect>]() {
3465 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
3466 }
3467 }
3468 };
3469 }
3470
3471 gen_batch_tests!(check_batch_default_row);
3472 gen_batch_tests!(check_batch_no_poison);
3473
3474 fn eq_or_both_nan(a: f64, b: f64) -> bool {
3475 (a.is_nan() && b.is_nan()) || (a == b)
3476 }
3477
3478 #[test]
3479 fn test_stoch_into_matches_api() -> Result<(), Box<dyn Error>> {
3480 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3481 let candles = read_candles_from_csv(file_path)?;
3482 let input = StochInput::with_default_candles(&candles);
3483
3484 let baseline = stoch(&input)?;
3485
3486 let mut out_k = vec![0.0; baseline.k.len()];
3487 let mut out_d = vec![0.0; baseline.d.len()];
3488
3489 #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
3490 {
3491 stoch_into(&input, &mut out_k, &mut out_d)?;
3492 }
3493 #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
3494 {
3495 stoch_into_slices(&mut out_k, &mut out_d, &input, detect_best_kernel())?;
3496 }
3497
3498 assert_eq!(out_k.len(), baseline.k.len());
3499 assert_eq!(out_d.len(), baseline.d.len());
3500 for i in 0..out_k.len() {
3501 assert!(
3502 eq_or_both_nan(out_k[i], baseline.k[i]),
3503 "K mismatch at {}: got {}, expected {}",
3504 i,
3505 out_k[i],
3506 baseline.k[i]
3507 );
3508 assert!(
3509 eq_or_both_nan(out_d[i], baseline.d[i]),
3510 "D mismatch at {}: got {}, expected {}",
3511 i,
3512 out_d[i],
3513 baseline.d[i]
3514 );
3515 }
3516 Ok(())
3517 }
3518
3519 #[test]
3520 fn test_stoch_compute_into_matches_api() -> Result<(), Box<dyn Error>> {
3521 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3522 let candles = read_candles_from_csv(file_path)?;
3523 let input = StochInput::with_default_candles(&candles);
3524
3525 let baseline = stoch(&input)?;
3526
3527 let mut out_k = vec![0.0; baseline.k.len()];
3528 let mut out_d = vec![0.0; baseline.d.len()];
3529 stoch_compute_into(&input, &mut out_k, &mut out_d, Kernel::Auto)?;
3530
3531 assert_eq!(out_k.len(), baseline.k.len());
3532 assert_eq!(out_d.len(), baseline.d.len());
3533 for i in 0..out_k.len() {
3534 assert!(
3535 eq_or_both_nan(out_k[i], baseline.k[i]),
3536 "K mismatch at {}: got {}, expected {}",
3537 i,
3538 out_k[i],
3539 baseline.k[i]
3540 );
3541 assert!(
3542 eq_or_both_nan(out_d[i], baseline.d[i]),
3543 "D mismatch at {}: got {}, expected {}",
3544 i,
3545 out_d[i],
3546 baseline.d[i]
3547 );
3548 }
3549 Ok(())
3550 }
3551}
3552
3553#[inline]
3554fn stoch_classic_sma(
3555 k_raw: &[f64],
3556 slowk_period: usize,
3557 slowd_period: usize,
3558 first_valid_idx: usize,
3559) -> Result<StochOutput, StochError> {
3560 let len = k_raw.len();
3561 let mut k_vec = alloc_with_nan_prefix(len, first_valid_idx + slowk_period - 1);
3562 let mut d_vec = alloc_with_nan_prefix(len, first_valid_idx + slowk_period + slowd_period - 2);
3563
3564 let mut sum_k = 0.0;
3565 let k_start = first_valid_idx;
3566
3567 for i in k_start..(k_start + slowk_period).min(len) {
3568 if !k_raw[i].is_nan() {
3569 sum_k += k_raw[i];
3570 }
3571 }
3572 if k_start + slowk_period - 1 < len {
3573 k_vec[k_start + slowk_period - 1] = sum_k / slowk_period as f64;
3574 }
3575
3576 for i in (k_start + slowk_period)..len {
3577 let old_val = k_raw[i - slowk_period];
3578 let new_val = k_raw[i];
3579 if !old_val.is_nan() {
3580 sum_k -= old_val;
3581 }
3582 if !new_val.is_nan() {
3583 sum_k += new_val;
3584 }
3585 k_vec[i] = sum_k / slowk_period as f64;
3586 }
3587
3588 let mut sum_d = 0.0;
3589 let d_start = first_valid_idx + slowk_period - 1;
3590
3591 for i in d_start..(d_start + slowd_period).min(len) {
3592 if !k_vec[i].is_nan() {
3593 sum_d += k_vec[i];
3594 }
3595 }
3596 if d_start + slowd_period - 1 < len {
3597 d_vec[d_start + slowd_period - 1] = sum_d / slowd_period as f64;
3598 }
3599
3600 for i in (d_start + slowd_period)..len {
3601 let old_val = k_vec[i - slowd_period];
3602 let new_val = k_vec[i];
3603 if !old_val.is_nan() {
3604 sum_d -= old_val;
3605 }
3606 if !new_val.is_nan() {
3607 sum_d += new_val;
3608 }
3609 d_vec[i] = sum_d / slowd_period as f64;
3610 }
3611
3612 Ok(StochOutput { k: k_vec, d: d_vec })
3613}
3614
3615#[inline]
3616fn stoch_classic_ema(
3617 k_raw: &[f64],
3618 slowk_period: usize,
3619 slowd_period: usize,
3620 first_valid_idx: usize,
3621) -> Result<StochOutput, StochError> {
3622 let len = k_raw.len();
3623 let mut k_vec = alloc_with_nan_prefix(len, first_valid_idx + slowk_period - 1);
3624 let mut d_vec = alloc_with_nan_prefix(len, first_valid_idx + slowk_period + slowd_period - 2);
3625
3626 let alpha_k = 2.0 / (slowk_period as f64 + 1.0);
3627 let one_minus_alpha_k = 1.0 - alpha_k;
3628
3629 let k_warmup = first_valid_idx + slowk_period - 1;
3630 let mut sum_k = 0.0;
3631 let mut count_k = 0;
3632 for i in first_valid_idx..(first_valid_idx + slowk_period).min(len) {
3633 if !k_raw[i].is_nan() {
3634 sum_k += k_raw[i];
3635 count_k += 1;
3636 }
3637 }
3638
3639 if count_k > 0 && k_warmup < len {
3640 let mut ema_k = sum_k / count_k as f64;
3641 k_vec[k_warmup] = ema_k;
3642
3643 for i in (k_warmup + 1)..len {
3644 if !k_raw[i].is_nan() {
3645 ema_k = alpha_k * k_raw[i] + one_minus_alpha_k * ema_k;
3646 }
3647 k_vec[i] = ema_k;
3648 }
3649 } else {
3650 for i in k_warmup..len {
3651 k_vec[i] = f64::NAN;
3652 }
3653 }
3654
3655 let alpha_d = 2.0 / (slowd_period as f64 + 1.0);
3656 let one_minus_alpha_d = 1.0 - alpha_d;
3657
3658 let d_warmup = first_valid_idx + slowk_period + slowd_period - 2;
3659 let d_start = first_valid_idx + slowk_period - 1;
3660 let mut sum_d = 0.0;
3661 let mut count_d = 0;
3662 for i in d_start..(d_start + slowd_period).min(len) {
3663 if !k_vec[i].is_nan() {
3664 sum_d += k_vec[i];
3665 count_d += 1;
3666 }
3667 }
3668
3669 if count_d > 0 && d_warmup < len {
3670 let mut ema_d = sum_d / count_d as f64;
3671 d_vec[d_warmup] = ema_d;
3672
3673 for i in (d_warmup + 1)..len {
3674 if !k_vec[i].is_nan() {
3675 ema_d = alpha_d * k_vec[i] + one_minus_alpha_d * ema_d;
3676 }
3677 d_vec[i] = ema_d;
3678 }
3679 } else {
3680 for i in d_warmup..len {
3681 d_vec[i] = f64::NAN;
3682 }
3683 }
3684
3685 Ok(StochOutput { k: k_vec, d: d_vec })
3686}