Skip to main content

vector_ta/indicators/
stoch.rs

1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
3#[cfg(feature = "python")]
4use pyo3::types::PyDict;
5#[cfg(feature = "python")]
6use pyo3::{exceptions::PyValueError, prelude::*};
7
8#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
9use serde::{Deserialize, Serialize};
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use wasm_bindgen::prelude::*;
12
13use crate::indicators::moving_averages::ma::{ma, MaData};
14use crate::indicators::utility_functions::{max_rolling, min_rolling};
15use crate::utilities::data_loader::{source_type, Candles};
16use crate::utilities::enums::Kernel;
17use crate::utilities::helpers::{
18    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
19    make_uninit_matrix,
20};
21#[cfg(feature = "python")]
22use crate::utilities::kernel_validation::validate_kernel;
23use aligned_vec::{AVec, CACHELINE_ALIGN};
24#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
25use core::arch::x86_64::*;
26#[cfg(not(target_arch = "wasm32"))]
27use rayon::prelude::*;
28use std::collections::VecDeque;
29use std::convert::AsRef;
30use std::error::Error;
31use thiserror::Error;
32
33#[derive(Debug, Clone)]
34pub enum StochData<'a> {
35    Candles {
36        candles: &'a Candles,
37    },
38    Slices {
39        high: &'a [f64],
40        low: &'a [f64],
41        close: &'a [f64],
42    },
43}
44
/// The two lines produced by a stochastic run.
#[derive(Debug, Clone)]
pub struct StochOutput {
    /// Raw %K smoothed with the `slowk` moving average.
    pub k: Vec<f64>,
    /// `k` smoothed again with the `slowd` moving average.
    pub d: Vec<f64>,
}
50
/// Tunable parameters of the stochastic oscillator.
///
/// Every field is optional. A `None` is replaced by the same value that
/// [`Default`] supplies when the parameter is read through the `StochInput`
/// getters (14 / 3 / `"sma"` / 3 / `"sma"`).
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct StochParams {
    /// Look-back window for the rolling highest-high / lowest-low.
    pub fastk_period: Option<usize>,
    /// Window of the moving average that turns raw %K into the `k` line.
    pub slowk_period: Option<usize>,
    /// Name of the moving average used for the `k` line.
    pub slowk_ma_type: Option<String>,
    /// Window of the moving average that turns `k` into the `d` line.
    pub slowd_period: Option<usize>,
    /// Name of the moving average used for the `d` line.
    pub slowd_ma_type: Option<String>,
}

impl Default for StochParams {
    /// Builds the 14 / 3 (SMA) / 3 (SMA) configuration.
    fn default() -> Self {
        let sma = || Some(String::from("sma"));
        Self {
            fastk_period: Some(14),
            slowk_period: Some(3),
            slowk_ma_type: sma(),
            slowd_period: Some(3),
            slowd_ma_type: sma(),
        }
    }
}
75
76#[derive(Debug, Clone)]
77pub struct StochInput<'a> {
78    pub data: StochData<'a>,
79    pub params: StochParams,
80}
81
82impl<'a> StochInput<'a> {
83    #[inline]
84    pub fn from_candles(c: &'a Candles, p: StochParams) -> Self {
85        Self {
86            data: StochData::Candles { candles: c },
87            params: p,
88        }
89    }
90    #[inline]
91    pub fn from_slices(high: &'a [f64], low: &'a [f64], close: &'a [f64], p: StochParams) -> Self {
92        Self {
93            data: StochData::Slices { high, low, close },
94            params: p,
95        }
96    }
97    #[inline]
98    pub fn with_default_candles(c: &'a Candles) -> Self {
99        Self::from_candles(c, StochParams::default())
100    }
101    #[inline]
102    pub fn get_fastk_period(&self) -> usize {
103        self.params.fastk_period.unwrap_or(14)
104    }
105    #[inline]
106    pub fn get_slowk_period(&self) -> usize {
107        self.params.slowk_period.unwrap_or(3)
108    }
109    #[inline]
110    pub fn get_slowk_ma_type(&self) -> String {
111        self.params
112            .slowk_ma_type
113            .clone()
114            .unwrap_or_else(|| "sma".to_string())
115    }
116    #[inline]
117    pub fn get_slowd_period(&self) -> usize {
118        self.params.slowd_period.unwrap_or(3)
119    }
120    #[inline]
121    pub fn get_slowd_ma_type(&self) -> String {
122        self.params
123            .slowd_ma_type
124            .clone()
125            .unwrap_or_else(|| "sma".to_string())
126    }
127}
128
129#[derive(Copy, Clone, Debug)]
130pub struct StochBuilder {
131    fastk_period: Option<usize>,
132    slowk_period: Option<usize>,
133    slowk_ma_type: Option<&'static str>,
134    slowd_period: Option<usize>,
135    slowd_ma_type: Option<&'static str>,
136    kernel: Kernel,
137}
138
139impl Default for StochBuilder {
140    fn default() -> Self {
141        Self {
142            fastk_period: None,
143            slowk_period: None,
144            slowk_ma_type: None,
145            slowd_period: None,
146            slowd_ma_type: None,
147            kernel: Kernel::Auto,
148        }
149    }
150}
151
152impl StochBuilder {
153    #[inline(always)]
154    pub fn new() -> Self {
155        Self::default()
156    }
157    #[inline(always)]
158    pub fn fastk_period(mut self, n: usize) -> Self {
159        self.fastk_period = Some(n);
160        self
161    }
162    #[inline(always)]
163    pub fn slowk_period(mut self, n: usize) -> Self {
164        self.slowk_period = Some(n);
165        self
166    }
167    #[inline(always)]
168    pub fn slowk_ma_type(mut self, t: &'static str) -> Self {
169        self.slowk_ma_type = Some(t);
170        self
171    }
172    #[inline(always)]
173    pub fn slowd_period(mut self, n: usize) -> Self {
174        self.slowd_period = Some(n);
175        self
176    }
177    #[inline(always)]
178    pub fn slowd_ma_type(mut self, t: &'static str) -> Self {
179        self.slowd_ma_type = Some(t);
180        self
181    }
182    #[inline(always)]
183    pub fn kernel(mut self, k: Kernel) -> Self {
184        self.kernel = k;
185        self
186    }
187
188    #[inline(always)]
189    pub fn apply(self, c: &Candles) -> Result<StochOutput, StochError> {
190        let p = StochParams {
191            fastk_period: self.fastk_period,
192            slowk_period: self.slowk_period,
193            slowk_ma_type: self.slowk_ma_type.map(|s| s.to_string()),
194            slowd_period: self.slowd_period,
195            slowd_ma_type: self.slowd_ma_type.map(|s| s.to_string()),
196        };
197        let i = StochInput::from_candles(c, p);
198        stoch_with_kernel(&i, self.kernel)
199    }
200    #[inline(always)]
201    pub fn apply_slices(
202        self,
203        high: &[f64],
204        low: &[f64],
205        close: &[f64],
206    ) -> Result<StochOutput, StochError> {
207        let p = StochParams {
208            fastk_period: self.fastk_period,
209            slowk_period: self.slowk_period,
210            slowk_ma_type: self.slowk_ma_type.map(|s| s.to_string()),
211            slowd_period: self.slowd_period,
212            slowd_ma_type: self.slowd_ma_type.map(|s| s.to_string()),
213        };
214        let i = StochInput::from_slices(high, low, close, p);
215        stoch_with_kernel(&i, self.kernel)
216    }
217    #[inline(always)]
218    pub fn into_stream(self) -> Result<StochStream, StochError> {
219        let p = StochParams {
220            fastk_period: self.fastk_period,
221            slowk_period: self.slowk_period,
222            slowk_ma_type: self.slowk_ma_type.map(|s| s.to_string()),
223            slowd_period: self.slowd_period,
224            slowd_ma_type: self.slowd_ma_type.map(|s| s.to_string()),
225        };
226        StochStream::try_new(p)
227    }
228}
229
230#[derive(Debug, Error)]
231pub enum StochError {
232    #[error("stoch: Empty data provided.")]
233    EmptyInputData,
234    #[error("stoch: Mismatched length.")]
235    MismatchedLength,
236    #[error("stoch: Invalid period: period = {period}, data length = {data_len}")]
237    InvalidPeriod { period: usize, data_len: usize },
238    #[error("stoch: Not enough valid data: needed = {needed}, valid = {valid}")]
239    NotEnoughValidData { needed: usize, valid: usize },
240    #[error("stoch: All values are NaN.")]
241    AllValuesNaN,
242    #[error("stoch: Output length mismatch: expected {expected}, got {got}")]
243    OutputLengthMismatch { expected: usize, got: usize },
244    #[error("stoch: Invalid range: start={start}, end={end}, step={step}")]
245    InvalidRange {
246        start: usize,
247        end: usize,
248        step: usize,
249    },
250    #[error("stoch: Invalid kernel for batch: {0:?}")]
251    InvalidKernelForBatch(Kernel),
252    #[error("stoch: {0}")]
253    Other(String),
254}
255
256#[inline]
257pub fn stoch(input: &StochInput) -> Result<StochOutput, StochError> {
258    stoch_with_kernel(input, Kernel::Auto)
259}
260
261pub fn stoch_with_kernel(input: &StochInput, kernel: Kernel) -> Result<StochOutput, StochError> {
262    let (high, low, close) = match &input.data {
263        StochData::Candles { candles } => {
264            let high = candles
265                .select_candle_field("high")
266                .map_err(|e| StochError::Other(e.to_string()))?;
267            let low = candles
268                .select_candle_field("low")
269                .map_err(|e| StochError::Other(e.to_string()))?;
270            let close = candles
271                .select_candle_field("close")
272                .map_err(|e| StochError::Other(e.to_string()))?;
273            (high, low, close)
274        }
275        StochData::Slices { high, low, close } => (*high, *low, *close),
276    };
277
278    let data_len = high.len();
279    if data_len == 0 || low.is_empty() || close.is_empty() {
280        return Err(StochError::EmptyInputData);
281    }
282    if data_len != low.len() || data_len != close.len() {
283        return Err(StochError::MismatchedLength);
284    }
285
286    let fastk_period = input.get_fastk_period();
287    let slowk_period = input.get_slowk_period();
288    let slowd_period = input.get_slowd_period();
289
290    if fastk_period == 0 || fastk_period > data_len {
291        return Err(StochError::InvalidPeriod {
292            period: fastk_period,
293            data_len,
294        });
295    }
296    if slowk_period == 0 || slowk_period > data_len {
297        return Err(StochError::InvalidPeriod {
298            period: slowk_period,
299            data_len,
300        });
301    }
302    if slowd_period == 0 || slowd_period > data_len {
303        return Err(StochError::InvalidPeriod {
304            period: slowd_period,
305            data_len,
306        });
307    }
308
309    let first_valid_idx = high
310        .iter()
311        .zip(low.iter())
312        .zip(close.iter())
313        .position(|((h, l), c)| !h.is_nan() && !l.is_nan() && !c.is_nan())
314        .ok_or(StochError::AllValuesNaN)?;
315
316    if (data_len - first_valid_idx) < fastk_period {
317        return Err(StochError::NotEnoughValidData {
318            needed: fastk_period,
319            valid: data_len - first_valid_idx,
320        });
321    }
322
323    let mut hh = alloc_with_nan_prefix(data_len, first_valid_idx + fastk_period - 1);
324    let mut ll = alloc_with_nan_prefix(data_len, first_valid_idx + fastk_period - 1);
325
326    let max_vals = max_rolling(&high[first_valid_idx..], fastk_period)
327        .map_err(|e| StochError::Other(e.to_string()))?;
328    let min_vals = min_rolling(&low[first_valid_idx..], fastk_period)
329        .map_err(|e| StochError::Other(e.to_string()))?;
330
331    for (i, &val) in max_vals.iter().enumerate() {
332        hh[i + first_valid_idx] = val;
333    }
334    for (i, &val) in min_vals.iter().enumerate() {
335        ll[i + first_valid_idx] = val;
336    }
337
338    let mut k_raw = alloc_with_nan_prefix(data_len, first_valid_idx + fastk_period - 1);
339
340    let chosen = match kernel {
341        Kernel::Auto => Kernel::Scalar,
342        other => other,
343    };
344    unsafe {
345        match chosen {
346            Kernel::Scalar | Kernel::ScalarBatch => stoch_scalar(
347                high,
348                low,
349                close,
350                &hh,
351                &ll,
352                fastk_period,
353                first_valid_idx,
354                &mut k_raw,
355            ),
356            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
357            Kernel::Avx2 | Kernel::Avx2Batch => stoch_avx2(
358                high,
359                low,
360                close,
361                &hh,
362                &ll,
363                fastk_period,
364                first_valid_idx,
365                &mut k_raw,
366            ),
367            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
368            Kernel::Avx512 | Kernel::Avx512Batch => stoch_avx512(
369                high,
370                low,
371                close,
372                &hh,
373                &ll,
374                fastk_period,
375                first_valid_idx,
376                &mut k_raw,
377            ),
378            _ => unreachable!(),
379        }
380    }
381
382    let slowk_ma_type = input.get_slowk_ma_type();
383    let slowd_ma_type = input.get_slowd_ma_type();
384
385    let k_first_valid = first_valid_idx + fastk_period - 1;
386    if (slowk_ma_type == "sma" || slowk_ma_type == "SMA")
387        && (slowd_ma_type == "sma" || slowd_ma_type == "SMA")
388    {
389        return stoch_classic_sma(&k_raw, slowk_period, slowd_period, k_first_valid);
390    } else if (slowk_ma_type == "ema" || slowk_ma_type == "EMA")
391        && (slowd_ma_type == "ema" || slowd_ma_type == "EMA")
392    {
393        return stoch_classic_ema(&k_raw, slowk_period, slowd_period, k_first_valid);
394    }
395
396    let k_vec = ma(&slowk_ma_type, MaData::Slice(&k_raw), slowk_period)
397        .map_err(|e| StochError::Other(e.to_string()))?;
398    let d_vec = ma(&slowd_ma_type, MaData::Slice(&k_vec), slowd_period)
399        .map_err(|e| StochError::Other(e.to_string()))?;
400    Ok(StochOutput { k: k_vec, d: d_vec })
401}
402
403pub fn stoch_into_slices(
404    out_k: &mut [f64],
405    out_d: &mut [f64],
406    input: &StochInput,
407    kernel: Kernel,
408) -> Result<(), StochError> {
409    let StochOutput { k, d } = stoch_with_kernel(input, kernel)?;
410    if out_k.len() != k.len() {
411        return Err(StochError::OutputLengthMismatch {
412            expected: k.len(),
413            got: out_k.len(),
414        });
415    }
416    if out_d.len() != d.len() {
417        return Err(StochError::OutputLengthMismatch {
418            expected: d.len(),
419            got: out_d.len(),
420        });
421    }
422    out_k.copy_from_slice(&k);
423    out_d.copy_from_slice(&d);
424    Ok(())
425}
426
/// Overwrites the first `warm` elements of `dst` (or all of it, if shorter)
/// with the quiet-NaN bit pattern `0x7ff8_0000_0000_0000`.
#[inline]
fn prefill_nan_prefix(dst: &mut [f64], warm: usize) {
    let quiet_nan = f64::from_bits(0x7ff8_0000_0000_0000);
    // `take` stops at the end of the slice, so an oversized `warm` is harmless.
    dst.iter_mut().take(warm).for_each(|slot| *slot = quiet_nan);
}
434
435#[inline]
436fn stoch_compute_into(
437    input: &StochInput,
438    out_k: &mut [f64],
439    out_d: &mut [f64],
440    kernel: Kernel,
441) -> Result<(), StochError> {
442    let (high, low, close) = match &input.data {
443        StochData::Candles { candles } => {
444            let high = candles
445                .select_candle_field("high")
446                .map_err(|e| StochError::Other(e.to_string()))?;
447            let low = candles
448                .select_candle_field("low")
449                .map_err(|e| StochError::Other(e.to_string()))?;
450            let close = candles
451                .select_candle_field("close")
452                .map_err(|e| StochError::Other(e.to_string()))?;
453            (high, low, close)
454        }
455        StochData::Slices { high, low, close } => (*high, *low, *close),
456    };
457
458    let len = high.len();
459    if len == 0 || low.is_empty() || close.is_empty() {
460        return Err(StochError::EmptyInputData);
461    }
462    if len != low.len() || len != close.len() {
463        return Err(StochError::MismatchedLength);
464    }
465    if out_k.len() != len {
466        return Err(StochError::OutputLengthMismatch {
467            expected: len,
468            got: out_k.len(),
469        });
470    }
471    if out_d.len() != len {
472        return Err(StochError::OutputLengthMismatch {
473            expected: len,
474            got: out_d.len(),
475        });
476    }
477
478    let fastk_period = input.get_fastk_period();
479    let slowk_period = input.get_slowk_period();
480    let slowd_period = input.get_slowd_period();
481
482    if fastk_period == 0 || fastk_period > len {
483        return Err(StochError::InvalidPeriod {
484            period: fastk_period,
485            data_len: len,
486        });
487    }
488    if slowk_period == 0 || slowk_period > len {
489        return Err(StochError::InvalidPeriod {
490            period: slowk_period,
491            data_len: len,
492        });
493    }
494    if slowd_period == 0 || slowd_period > len {
495        return Err(StochError::InvalidPeriod {
496            period: slowd_period,
497            data_len: len,
498        });
499    }
500
501    let first = high
502        .iter()
503        .zip(low.iter())
504        .zip(close.iter())
505        .position(|((h, l), c)| !h.is_nan() && !l.is_nan() && !c.is_nan())
506        .ok_or(StochError::AllValuesNaN)?;
507
508    if (len - first) < fastk_period {
509        return Err(StochError::NotEnoughValidData {
510            needed: fastk_period,
511            valid: len - first,
512        });
513    }
514
515    let slowk_ma_type = input.get_slowk_ma_type();
516    let slowd_ma_type = input.get_slowd_ma_type();
517    let chosen = match kernel {
518        Kernel::Auto => Kernel::Scalar,
519        other => other,
520    };
521
522    if (slowk_ma_type == "sma" || slowk_ma_type == "SMA")
523        && (slowd_ma_type == "sma" || slowd_ma_type == "SMA")
524        && matches!(chosen, Kernel::Scalar | Kernel::ScalarBatch)
525    {
526        return stoch_classic_sma_into_single_pass(
527            high,
528            low,
529            close,
530            fastk_period,
531            slowk_period,
532            slowd_period,
533            first,
534            out_k,
535            out_d,
536        );
537    }
538
539    let mut hh = alloc_with_nan_prefix(len, first + fastk_period - 1);
540    let mut ll = alloc_with_nan_prefix(len, first + fastk_period - 1);
541    let highs =
542        max_rolling(&high[first..], fastk_period).map_err(|e| StochError::Other(e.to_string()))?;
543    let lows =
544        min_rolling(&low[first..], fastk_period).map_err(|e| StochError::Other(e.to_string()))?;
545    for (i, &v) in highs.iter().enumerate() {
546        hh[first + i] = v;
547    }
548    for (i, &v) in lows.iter().enumerate() {
549        ll[first + i] = v;
550    }
551
552    let mut k_raw = alloc_with_nan_prefix(len, first + fastk_period - 1);
553
554    unsafe {
555        match chosen {
556            Kernel::Scalar | Kernel::ScalarBatch => {
557                stoch_scalar(high, low, close, &hh, &ll, fastk_period, first, &mut k_raw)
558            }
559            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
560            Kernel::Avx2 | Kernel::Avx2Batch => {
561                stoch_avx2(high, low, close, &hh, &ll, fastk_period, first, &mut k_raw)
562            }
563            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
564            Kernel::Avx512 | Kernel::Avx512Batch => {
565                stoch_avx512(high, low, close, &hh, &ll, fastk_period, first, &mut k_raw)
566            }
567            _ => unreachable!(),
568        }
569    }
570
571    let k_first_valid = first + fastk_period - 1;
572
573    if (slowk_ma_type == "sma" || slowk_ma_type == "SMA")
574        && (slowd_ma_type == "sma" || slowd_ma_type == "SMA")
575    {
576        prefill_nan_prefix(out_k, k_first_valid + slowk_period - 1);
577        prefill_nan_prefix(out_d, k_first_valid + slowk_period + slowd_period - 2);
578
579        let mut sum_k = 0.0;
580        let k_start = k_first_valid;
581        for i in k_start..(k_start + slowk_period).min(len) {
582            if !k_raw[i].is_nan() {
583                sum_k += k_raw[i];
584            }
585        }
586        if k_start + slowk_period - 1 < len {
587            out_k[k_start + slowk_period - 1] = sum_k / slowk_period as f64;
588        }
589        for i in (k_start + slowk_period)..len {
590            let old = k_raw[i - slowk_period];
591            let newv = k_raw[i];
592            if !old.is_nan() {
593                sum_k -= old;
594            }
595            if !newv.is_nan() {
596                sum_k += newv;
597            }
598            out_k[i] = sum_k / slowk_period as f64;
599        }
600
601        let mut sum_d = 0.0;
602        let d_start = k_first_valid + slowk_period - 1;
603        for i in d_start..(d_start + slowd_period).min(len) {
604            if !out_k[i].is_nan() {
605                sum_d += out_k[i];
606            }
607        }
608        if d_start + slowd_period - 1 < len {
609            out_d[d_start + slowd_period - 1] = sum_d / slowd_period as f64;
610        }
611        for i in (d_start + slowd_period)..len {
612            let old = out_k[i - slowd_period];
613            let newv = out_k[i];
614            if !old.is_nan() {
615                sum_d -= old;
616            }
617            if !newv.is_nan() {
618                sum_d += newv;
619            }
620            out_d[i] = sum_d / slowd_period as f64;
621        }
622        return Ok(());
623    }
624
625    if (slowk_ma_type == "ema" || slowk_ma_type == "EMA")
626        && (slowd_ma_type == "ema" || slowd_ma_type == "EMA")
627    {
628        prefill_nan_prefix(out_k, k_first_valid + slowk_period - 1);
629        prefill_nan_prefix(out_d, k_first_valid + slowk_period + slowd_period - 2);
630
631        let alpha_k = 2.0 / (slowk_period as f64 + 1.0);
632        let one_minus_alpha_k = 1.0 - alpha_k;
633        let k_warm = k_first_valid + slowk_period - 1;
634        let mut sum_k = 0.0;
635        let mut cnt_k = 0;
636        for i in k_first_valid..(k_first_valid + slowk_period).min(len) {
637            if !k_raw[i].is_nan() {
638                sum_k += k_raw[i];
639                cnt_k += 1;
640            }
641        }
642        if cnt_k > 0 && k_warm < len {
643            let mut ema_k = sum_k / cnt_k as f64;
644            out_k[k_warm] = ema_k;
645            for i in (k_warm + 1)..len {
646                if !k_raw[i].is_nan() {
647                    ema_k = alpha_k * k_raw[i] + one_minus_alpha_k * ema_k;
648                }
649                out_k[i] = ema_k;
650            }
651        } else {
652            for i in k_warm..len {
653                out_k[i] = f64::from_bits(0x7ff8_0000_0000_0000);
654            }
655        }
656
657        let alpha_d = 2.0 / (slowd_period as f64 + 1.0);
658        let one_minus_alpha_d = 1.0 - alpha_d;
659        let d_warm = k_first_valid + slowk_period + slowd_period - 2;
660        let d_start = k_first_valid + slowk_period - 1;
661        let mut sum_d = 0.0;
662        let mut cnt_d = 0;
663        for i in d_start..(d_start + slowd_period).min(len) {
664            if !out_k[i].is_nan() {
665                sum_d += out_k[i];
666                cnt_d += 1;
667            }
668        }
669        if cnt_d > 0 && d_warm < len {
670            let mut ema_d = sum_d / cnt_d as f64;
671            out_d[d_warm] = ema_d;
672            for i in (d_warm + 1)..len {
673                if !out_k[i].is_nan() {
674                    ema_d = alpha_d * out_k[i] + one_minus_alpha_d * ema_d;
675                }
676                out_d[i] = ema_d;
677            }
678        } else {
679            for i in d_warm..len {
680                out_d[i] = f64::from_bits(0x7ff8_0000_0000_0000);
681            }
682        }
683        return Ok(());
684    }
685
686    let k_vec = ma(&slowk_ma_type, MaData::Slice(&k_raw), slowk_period)
687        .map_err(|e| StochError::Other(e.to_string()))?;
688    let d_vec = ma(&slowd_ma_type, MaData::Slice(&k_vec), slowd_period)
689        .map_err(|e| StochError::Other(e.to_string()))?;
690    out_k.copy_from_slice(&k_vec);
691    out_d.copy_from_slice(&d_vec);
692    Ok(())
693}
694
695#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
696pub fn stoch_into(
697    input: &StochInput,
698    out_k: &mut [f64],
699    out_d: &mut [f64],
700) -> Result<(), StochError> {
701    #[cfg(test)]
702    {
703        stoch_into_slices(out_k, out_d, input, Kernel::Auto)
704    }
705    #[cfg(not(test))]
706    {
707        stoch_compute_into(input, out_k, out_d, Kernel::Auto)
708    }
709}
710
711fn stoch_classic_sma_into_single_pass(
712    high: &[f64],
713    low: &[f64],
714    close: &[f64],
715    fastk_period: usize,
716    slowk_period: usize,
717    slowd_period: usize,
718    first: usize,
719    out_k: &mut [f64],
720    out_d: &mut [f64],
721) -> Result<(), StochError> {
722    let len = close.len();
723
724    let k_first_valid = first + fastk_period - 1;
725    let k_warm = k_first_valid + slowk_period - 1;
726    let d_warm = k_first_valid + slowk_period + slowd_period - 2;
727
728    prefill_nan_prefix(out_k, k_warm);
729    prefill_nan_prefix(out_d, d_warm);
730
731    let mut trail = first;
732    let mut maxi = first;
733    let mut mini = first;
734    let mut max = high[first];
735    let mut min = low[first];
736
737    let mut k_buf = vec![0.0f64; slowk_period];
738    let mut k_pos: usize = 0;
739    let mut k_sum = 0.0f64;
740    let mut k_count: usize = 0;
741
742    let mut d_buf = vec![0.0f64; slowd_period];
743    let mut d_pos: usize = 0;
744    let mut d_sum = 0.0f64;
745    let mut d_count: usize = 0;
746
747    const SCALE: f64 = 100.0;
748    const EPS: f64 = f64::EPSILON;
749
750    for i in first..len {
751        if i >= first + fastk_period {
752            trail += 1;
753        }
754
755        let bar_h = high[i];
756        if maxi < trail {
757            maxi = trail;
758            max = high[maxi];
759            let mut j = trail;
760            while j < i {
761                j += 1;
762                let v = high[j];
763                if v >= max {
764                    max = v;
765                    maxi = j;
766                }
767            }
768        } else if bar_h >= max {
769            maxi = i;
770            max = bar_h;
771        }
772
773        let bar_l = low[i];
774        if mini < trail {
775            mini = trail;
776            min = low[mini];
777            let mut j = trail;
778            while j < i {
779                j += 1;
780                let v = low[j];
781                if v <= min {
782                    min = v;
783                    mini = j;
784                }
785            }
786        } else if bar_l <= min {
787            mini = i;
788            min = bar_l;
789        }
790
791        if i < k_first_valid {
792            continue;
793        }
794
795        let c = close[i];
796        let denom = max - min;
797        let k_raw = if denom.abs() < EPS {
798            50.0
799        } else {
800            (c - min).mul_add(SCALE / denom, 0.0)
801        };
802
803        if k_count >= slowk_period {
804            k_sum -= k_buf[k_pos];
805        }
806        k_buf[k_pos] = k_raw;
807        k_sum += k_raw;
808        k_count += 1;
809        k_pos += 1;
810        if k_pos == slowk_period {
811            k_pos = 0;
812        }
813
814        if i >= k_warm {
815            let k_sma = k_sum / slowk_period as f64;
816            out_k[i] = k_sma;
817
818            if d_count >= slowd_period {
819                d_sum -= d_buf[d_pos];
820            }
821            d_buf[d_pos] = k_sma;
822            d_sum += k_sma;
823            d_count += 1;
824            d_pos += 1;
825            if d_pos == slowd_period {
826                d_pos = 0;
827            }
828
829            if i >= d_warm {
830                out_d[i] = d_sum / slowd_period as f64;
831            }
832        }
833    }
834
835    Ok(())
836}
837
/// AVX-512 raw-%K entry point: periods up to 32 go to `stoch_avx512_short`,
/// longer ones to `stoch_avx512_long`.
///
/// Results are written to `out` from index `first_valid + fastk_period - 1`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn stoch_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    hh: &[f64],
    ll: &[f64],
    fastk_period: usize,
    first_valid: usize,
    out: &mut [f64],
) {
    match fastk_period {
        0..=32 => stoch_avx512_short(high, low, close, hh, ll, fastk_period, first_valid, out),
        _ => stoch_avx512_long(high, low, close, hh, ll, fastk_period, first_valid, out),
    }
}
856
/// Scalar raw-%K kernel.
///
/// For every index `i >= first_val + fastk_period - 1` writes
/// `100 * (close[i] - ll[i]) / (hh[i] - ll[i])` into `out[i]`, or `50.0` when
/// the range `hh[i] - ll[i]` is smaller than `f64::EPSILON` in magnitude.
/// Earlier slots of `out` are left untouched.
///
/// * `_high`, `_low` – unused; kept so all kernels share one signature.
/// * `hh`, `ll` – rolling highest-high / lowest-low aligned with `close`.
///
/// The call is a no-op when the start index cannot be formed
/// (`first_val + fastk_period` is zero or overflows) or lies past the end of
/// any buffer. Buffers of unequal length are processed up to the shortest.
#[inline]
pub fn stoch_scalar(
    _high: &[f64],
    _low: &[f64],
    close: &[f64],
    hh: &[f64],
    ll: &[f64],
    fastk_period: usize,
    first_val: usize,
    out: &mut [f64],
) {
    const SCALE: f64 = 100.0;
    const EPS: f64 = f64::EPSILON;

    // Checked arithmetic: a zero period with `first_val == 0` used to
    // underflow (panic in debug builds, silent wrap in release builds).
    let start = match first_val
        .checked_add(fastk_period)
        .and_then(|end| end.checked_sub(1))
    {
        Some(start) => start,
        None => return,
    };

    // `get` instead of indexing so a short buffer cannot panic.
    let (c, h, l, outv) = match (
        close.get(start..),
        hh.get(start..),
        ll.get(start..),
        out.get_mut(start..),
    ) {
        (Some(c), Some(h), Some(l), Some(o)) => (c, h, l, o),
        _ => return,
    };

    for (o, (&cv, (&hv, &lv))) in outv.iter_mut().zip(c.iter().zip(h.iter().zip(l.iter()))) {
        let d = hv - lv;
        *o = if d.abs() < EPS {
            50.0
        } else {
            (cv - lv).mul_add(SCALE / d, 0.0)
        };
    }
}
890
/// AVX2 raw-%K kernel with a safe fallback.
///
/// Writes raw %K into `out` from index `first_valid + fastk_period - 1`. The
/// vector path is only taken when the running CPU reports AVX2 and `hh`, `ll`
/// and `out` are at least as long as `close`; otherwise the work is handed to
/// `stoch_scalar`, so this safe function can never execute unsupported
/// instructions or touch memory outside the supplied buffers.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn stoch_avx2(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    hh: &[f64],
    ll: &[f64],
    fastk_period: usize,
    first_valid: usize,
    out: &mut [f64],
) {
    let n = close.len();
    let buffers_cover_close = hh.len() >= n && ll.len() >= n && out.len() >= n;
    if !(buffers_cover_close && std::arch::is_x86_feature_detected!("avx2")) {
        return stoch_scalar(high, low, close, hh, ll, fastk_period, first_valid, out);
    }

    // SAFETY: AVX2 support was verified at run time just above, and every
    // buffer the kernel addresses via raw pointers spans at least
    // `close.len()` elements, which is the upper bound of its index range.
    unsafe { stoch_avx2_impl(high, low, close, hh, ll, fastk_period, first_valid, out) }
}
905
/// AVX2 implementation of the raw-%K kernel (4 lanes, unrolled twice).
///
/// For every index `i >= first_valid + fastk_period - 1` writes
/// `(close[i] - ll[i]) * (100 / (hh[i] - ll[i]))` into `out[i]`, or `50.0`
/// when the range is smaller than `f64::EPSILON` in magnitude. The processed
/// range is clamped to the shortest of `close`, `hh`, `ll` and `out`, so no
/// pointer leaves its slice; a start index that cannot be formed or lies past
/// that bound makes the call a no-op.
///
/// # Safety
/// The caller must ensure the running CPU supports AVX2.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn stoch_avx2_impl(
    _high: &[f64],
    _low: &[f64],
    close: &[f64],
    hh: &[f64],
    ll: &[f64],
    fastk_period: usize,
    first_valid: usize,
    out: &mut [f64],
) {
    // Checked arithmetic: `first_valid + fastk_period - 1` must not wrap.
    let start = match first_valid
        .checked_add(fastk_period)
        .and_then(|end| end.checked_sub(1))
    {
        Some(start) => start,
        None => return,
    };

    // All four buffers are addressed through raw pointers below, so the
    // element count has to respect the shortest one, not just `close`.
    let usable = close.len().min(hh.len()).min(ll.len()).min(out.len());
    if start >= usable {
        return;
    }

    let n = usable - start;
    let mut i = 0usize;

    // SAFETY (pointer setup): `start < usable <= len` of every slice.
    let c_ptr = close.as_ptr().add(start);
    let h_ptr = hh.as_ptr().add(start);
    let l_ptr = ll.as_ptr().add(start);
    let o_ptr = out.as_mut_ptr().add(start);

    const STEP: usize = 4;
    // Largest multiple of STEP not exceeding n; the remainder is done scalar.
    let vec_end = n & !(STEP - 1);

    let scale = _mm256_set1_pd(100.0);
    let fifty = _mm256_set1_pd(50.0);
    let eps = _mm256_set1_pd(f64::EPSILON);
    // -0.0 has only the sign bit set; and-not with it yields |x|.
    let sign_mask = _mm256_set1_pd(-0.0);

    while i + STEP <= vec_end {
        let c0 = _mm256_loadu_pd(c_ptr.add(i));
        let h0 = _mm256_loadu_pd(h_ptr.add(i));
        let l0 = _mm256_loadu_pd(l_ptr.add(i));
        let d0 = _mm256_sub_pd(h0, l0);
        let n0 = _mm256_sub_pd(c0, l0);
        let a0 = _mm256_andnot_pd(sign_mask, d0);
        // Lanes whose range is below EPSILON are replaced by 50.
        let m0 = _mm256_cmp_pd(a0, eps, _CMP_LT_OQ);
        let inv0 = _mm256_div_pd(scale, d0);
        let v0 = _mm256_mul_pd(n0, inv0);
        let o0 = _mm256_blendv_pd(v0, fifty, m0);

        if i + 2 * STEP <= vec_end {
            // Second, independent group of four lanes.
            let c1 = _mm256_loadu_pd(c_ptr.add(i + STEP));
            let h1 = _mm256_loadu_pd(h_ptr.add(i + STEP));
            let l1 = _mm256_loadu_pd(l_ptr.add(i + STEP));
            let d1 = _mm256_sub_pd(h1, l1);
            let n1 = _mm256_sub_pd(c1, l1);
            let a1 = _mm256_andnot_pd(sign_mask, d1);
            let m1 = _mm256_cmp_pd(a1, eps, _CMP_LT_OQ);
            let inv1 = _mm256_div_pd(scale, d1);
            let v1 = _mm256_mul_pd(n1, inv1);
            let o1 = _mm256_blendv_pd(v1, fifty, m1);

            _mm256_storeu_pd(o_ptr.add(i), o0);
            _mm256_storeu_pd(o_ptr.add(i + STEP), o1);
            i += 2 * STEP;
        } else {
            _mm256_storeu_pd(o_ptr.add(i), o0);
            i += STEP;
        }
    }

    // Scalar tail for the last `n % STEP` elements.
    while i < n {
        let c = *c_ptr.add(i);
        let l = *l_ptr.add(i);
        let d = *h_ptr.add(i) - l;
        *o_ptr.add(i) = if d.abs() < f64::EPSILON {
            50.0
        } else {
            (c - l) * (100.0 / d)
        };
        i += 1;
    }
}
984
/// AVX-512 entry point for the raw %K computation (short-period variant).
///
/// Forwards every argument unchanged to `stoch_avx512_impl`; in this file the
/// "short" and "long" entry points share that single implementation.
///
/// NOTE(review): this is a safe `pub fn`, but the callee is compiled with
/// `#[target_feature(enable = "avx512f")]`. No CPU-feature check is performed
/// here — presumably the kernel dispatcher only selects this path on capable
/// CPUs; confirm against the caller.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn stoch_avx512_short(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    hh: &[f64],
    ll: &[f64],
    fastk_period: usize,
    first_valid: usize,
    out: &mut [f64],
) {
    unsafe { stoch_avx512_impl(high, low, close, hh, ll, fastk_period, first_valid, out) }
}
999
/// AVX-512 entry point for the raw %K computation (long-period variant).
///
/// Forwards every argument unchanged to `stoch_avx512_impl`; in this file the
/// "short" and "long" entry points share that single implementation.
///
/// NOTE(review): this is a safe `pub fn`, but the callee is compiled with
/// `#[target_feature(enable = "avx512f")]`. No CPU-feature check is performed
/// here — presumably the kernel dispatcher only selects this path on capable
/// CPUs; confirm against the caller.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn stoch_avx512_long(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    hh: &[f64],
    ll: &[f64],
    fastk_period: usize,
    first_valid: usize,
    out: &mut [f64],
) {
    unsafe { stoch_avx512_impl(high, low, close, hh, ll, fastk_period, first_valid, out) }
}
1014
/// AVX-512 kernel for the raw (unsmoothed) %K line.
///
/// With `start = first_valid + fastk_period - 1`, every index `t` in
/// `start..close.len()` receives
/// `out[t] = (close[t] - ll[t]) * (100 / (hh[t] - ll[t]))`, or `50.0` when
/// `|hh[t] - ll[t]| < f64::EPSILON`. Elements of `out` before `start` are left
/// untouched, and nothing is written at all when `start >= close.len()`.
/// `_high` and `_low` are accepted for signature parity but never read.
///
/// # Panics
/// Panics if there is work to do and `hh`, `ll` or `out` is shorter than
/// `close`. The loops below use raw pointers, so without this check a short
/// slice would be read or written out of bounds.
///
/// # Safety
/// The executing CPU must support AVX-512F.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
unsafe fn stoch_avx512_impl(
    _high: &[f64],
    _low: &[f64],
    close: &[f64],
    hh: &[f64],
    ll: &[f64],
    fastk_period: usize,
    first_valid: usize,
    out: &mut [f64],
) {
    let start = first_valid + fastk_period - 1;
    if start >= close.len() {
        return;
    }

    // One-time bounds validation: everything below indexes hh/ll/out through
    // raw pointers up to `close.len()`, and this function is reachable from
    // safe public wrappers.
    assert!(
        hh.len() >= close.len() && ll.len() >= close.len() && out.len() >= close.len(),
        "stoch_avx512_impl: hh, ll and out must be at least as long as close"
    );

    let n = close.len() - start;

    let c_ptr = close.as_ptr().add(start);
    let h_ptr = hh.as_ptr().add(start);
    let l_ptr = ll.as_ptr().add(start);
    let o_ptr = out.as_mut_ptr().add(start);

    const STEP: usize = 8;
    // Largest multiple of STEP that fits in n; the remainder is handled by the scalar tail.
    let vec_end = n & !(STEP - 1);

    let scale = _mm512_set1_pd(100.0);
    let fifty = _mm512_set1_pd(50.0);
    let eps = _mm512_set1_pd(f64::EPSILON);
    // Only the sign bit is set in -0.0, so `andnot(sign_mask, x)` yields |x|.
    let sign_mask = _mm512_set1_pd(-0.0);

    let mut i = 0usize;
    while i + STEP <= vec_end {
        let c0 = _mm512_loadu_pd(c_ptr.add(i));
        let h0 = _mm512_loadu_pd(h_ptr.add(i));
        let l0 = _mm512_loadu_pd(l_ptr.add(i));
        let d0 = _mm512_sub_pd(h0, l0);
        let n0 = _mm512_sub_pd(c0, l0);
        let a0 = _mm512_andnot_pd(sign_mask, d0);
        // Ordered compare: NaN lanes are not flagged and keep the computed (NaN) value.
        let m0: __mmask8 = _mm512_cmp_pd_mask(a0, eps, _CMP_LT_OQ);
        let inv0 = _mm512_div_pd(scale, d0);
        let v0 = _mm512_mul_pd(n0, inv0);
        // Lanes with a (near-)zero range get the neutral 50.0.
        let o0 = _mm512_mask_blend_pd(m0, v0, fifty);

        // Manual 2x unroll: process a second vector when one is available.
        if i + 2 * STEP <= vec_end {
            let c1 = _mm512_loadu_pd(c_ptr.add(i + STEP));
            let h1 = _mm512_loadu_pd(h_ptr.add(i + STEP));
            let l1 = _mm512_loadu_pd(l_ptr.add(i + STEP));
            let d1 = _mm512_sub_pd(h1, l1);
            let n1 = _mm512_sub_pd(c1, l1);
            let a1 = _mm512_andnot_pd(sign_mask, d1);
            let m1: __mmask8 = _mm512_cmp_pd_mask(a1, eps, _CMP_LT_OQ);
            let inv1 = _mm512_div_pd(scale, d1);
            let v1 = _mm512_mul_pd(n1, inv1);
            let o1 = _mm512_mask_blend_pd(m1, v1, fifty);

            _mm512_storeu_pd(o_ptr.add(i), o0);
            _mm512_storeu_pd(o_ptr.add(i + STEP), o1);
            i += 2 * STEP;
        } else {
            _mm512_storeu_pd(o_ptr.add(i), o0);
            i += STEP;
        }
    }

    // Scalar tail for the last `n % STEP` elements.
    while i < n {
        let c = *c_ptr.add(i);
        let l = *l_ptr.add(i);
        let d = *h_ptr.add(i) - l;
        *o_ptr.add(i) = if d.abs() < f64::EPSILON {
            50.0
        } else {
            (c - l) * (100.0 / d)
        };
        i += 1;
    }
}
1093
/// Parameter sweep description for a batch Stochastic run.
///
/// Each numeric axis is a `(start, end, step)` triple that `expand_grid`
/// turns into a list of values; a `step` of `0` (or `start == end`) yields the
/// single value `start`. For the moving-average axes `expand_grid` ignores the
/// third tuple element and uses the first string, plus the second one when it
/// differs from the first.
#[derive(Clone, Debug)]
pub struct StochBatchRange {
    /// `(start, end, step)` for the fast %K look-back period.
    pub fastk_period: (usize, usize, usize),
    /// `(start, end, step)` for the slow %K smoothing period.
    pub slowk_period: (usize, usize, usize),
    /// Moving-average names for slow %K; the `f64` is not read by `expand_grid`.
    pub slowk_ma_type: (String, String, f64),
    /// `(start, end, step)` for the slow %D smoothing period.
    pub slowd_period: (usize, usize, usize),
    /// Moving-average names for slow %D; the `f64` is not read by `expand_grid`.
    pub slowd_ma_type: (String, String, f64),
}
1102
/// Default sweep: fast %K period from 14 to 263 in steps of 1, with both
/// smoothing periods fixed at 3 and both moving averages fixed to `"sma"`.
impl Default for StochBatchRange {
    fn default() -> Self {
        Self {
            fastk_period: (14, 263, 1),
            // A step of 0 marks a static (single-value) axis.
            slowk_period: (3, 3, 0),
            slowk_ma_type: ("sma".to_string(), "sma".to_string(), 0.0),
            slowd_period: (3, 3, 0),
            slowd_ma_type: ("sma".to_string(), "sma".to_string(), 0.0),
        }
    }
}
1114
/// Fluent builder for a batch Stochastic computation.
///
/// Collects a [`StochBatchRange`] and a [`Kernel`], then runs the batch through
/// `apply_slices` / `apply_candles`.
#[derive(Clone, Debug, Default)]
pub struct StochBatchBuilder {
    /// Parameter sweep to expand into individual combinations.
    range: StochBatchRange,
    /// Kernel selection forwarded to `stoch_batch_with_kernel`.
    kernel: Kernel,
}
1120
1121impl StochBatchBuilder {
1122    pub fn new() -> Self {
1123        Self::default()
1124    }
1125    pub fn kernel(mut self, k: Kernel) -> Self {
1126        self.kernel = k;
1127        self
1128    }
1129    pub fn fastk_period_range(mut self, start: usize, end: usize, step: usize) -> Self {
1130        self.range.fastk_period = (start, end, step);
1131        self
1132    }
1133    pub fn fastk_period_static(mut self, p: usize) -> Self {
1134        self.range.fastk_period = (p, p, 0);
1135        self
1136    }
1137    pub fn slowk_period_range(mut self, start: usize, end: usize, step: usize) -> Self {
1138        self.range.slowk_period = (start, end, step);
1139        self
1140    }
1141    pub fn slowk_period_static(mut self, p: usize) -> Self {
1142        self.range.slowk_period = (p, p, 0);
1143        self
1144    }
1145    pub fn slowk_ma_type_static(mut self, t: &str) -> Self {
1146        self.range.slowk_ma_type = (t.to_string(), t.to_string(), 0.0);
1147        self
1148    }
1149    pub fn slowd_period_range(mut self, start: usize, end: usize, step: usize) -> Self {
1150        self.range.slowd_period = (start, end, step);
1151        self
1152    }
1153    pub fn slowd_period_static(mut self, p: usize) -> Self {
1154        self.range.slowd_period = (p, p, 0);
1155        self
1156    }
1157    pub fn slowd_ma_type_static(mut self, t: &str) -> Self {
1158        self.range.slowd_ma_type = (t.to_string(), t.to_string(), 0.0);
1159        self
1160    }
1161
1162    pub fn apply_slices(
1163        self,
1164        high: &[f64],
1165        low: &[f64],
1166        close: &[f64],
1167    ) -> Result<StochBatchOutput, StochError> {
1168        stoch_batch_with_kernel(high, low, close, &self.range, self.kernel)
1169    }
1170    pub fn apply_candles(self, c: &Candles) -> Result<StochBatchOutput, StochError> {
1171        let high = source_type(c, "high");
1172        let low = source_type(c, "low");
1173        let close = source_type(c, "close");
1174        self.apply_slices(high, low, close)
1175    }
1176}
1177
1178pub fn stoch_batch_with_kernel(
1179    high: &[f64],
1180    low: &[f64],
1181    close: &[f64],
1182    sweep: &StochBatchRange,
1183    k: Kernel,
1184) -> Result<StochBatchOutput, StochError> {
1185    let kernel = match k {
1186        Kernel::Auto => Kernel::ScalarBatch,
1187        other if other.is_batch() => other,
1188        other => return Err(StochError::InvalidKernelForBatch(other)),
1189    };
1190    let simd = match kernel {
1191        Kernel::Avx512Batch => Kernel::Avx512,
1192        Kernel::Avx2Batch => Kernel::Avx2,
1193        Kernel::ScalarBatch => Kernel::Scalar,
1194        _ => unreachable!(),
1195    };
1196    stoch_batch_par_slice(high, low, close, sweep, simd)
1197}
1198
/// Result of a batch Stochastic run, stored as two row-major matrices.
///
/// Row `r` occupies `r * cols .. (r + 1) * cols` in both `k` and `d` and
/// corresponds to `combos[r]`.
#[derive(Clone, Debug)]
pub struct StochBatchOutput {
    /// Flattened `rows x cols` matrix of slow %K values.
    pub k: Vec<f64>,
    /// Flattened `rows x cols` matrix of slow %D values.
    pub d: Vec<f64>,
    /// Parameter combination used for each row.
    pub combos: Vec<StochParams>,
    /// Number of parameter combinations.
    pub rows: usize,
    /// Number of values per row (the input series length).
    pub cols: usize,
}
1207impl StochBatchOutput {
1208    pub fn row_for_params(&self, p: &StochParams) -> Option<usize> {
1209        self.combos.iter().position(|c| {
1210            c.fastk_period == p.fastk_period
1211                && c.slowk_period == p.slowk_period
1212                && c.slowk_ma_type == p.slowk_ma_type
1213                && c.slowd_period == p.slowd_period
1214                && c.slowd_ma_type == p.slowd_ma_type
1215        })
1216    }
1217    pub fn values_for(&self, p: &StochParams) -> Option<(&[f64], &[f64])> {
1218        self.row_for_params(p).map(|row| {
1219            let start = row * self.cols;
1220            (
1221                &self.k[start..start + self.cols],
1222                &self.d[start..start + self.cols],
1223            )
1224        })
1225    }
1226}
1227
1228#[inline(always)]
1229fn expand_grid(r: &StochBatchRange) -> Result<Vec<StochParams>, StochError> {
1230    fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, StochError> {
1231        if step == 0 || start == end {
1232            return Ok(vec![start]);
1233        }
1234
1235        let mut v = Vec::new();
1236        if start < end {
1237            let mut x = start;
1238            loop {
1239                v.push(x);
1240                match x.checked_add(step) {
1241                    Some(next) if next <= end => x = next,
1242                    Some(_) | None => break,
1243                }
1244            }
1245        } else {
1246            let mut x = start;
1247            loop {
1248                v.push(x);
1249                match x.checked_sub(step) {
1250                    Some(next) if next >= end => x = next,
1251                    Some(_) | None => break,
1252                }
1253            }
1254        }
1255
1256        if v.is_empty() {
1257            Err(StochError::InvalidRange { start, end, step })
1258        } else {
1259            Ok(v)
1260        }
1261    }
1262    fn axis_str((start, end, _): (String, String, f64)) -> Vec<String> {
1263        if start == end {
1264            vec![start]
1265        } else {
1266            vec![start, end]
1267        }
1268    }
1269    let fastk_periods = axis_usize(r.fastk_period)?;
1270    let slowk_periods = axis_usize(r.slowk_period)?;
1271    let slowk_types = axis_str(r.slowk_ma_type.clone());
1272    let slowd_periods = axis_usize(r.slowd_period)?;
1273    let slowd_types = axis_str(r.slowd_ma_type.clone());
1274
1275    let combos_len = fastk_periods
1276        .len()
1277        .checked_mul(slowk_periods.len())
1278        .and_then(|v| v.checked_mul(slowk_types.len()))
1279        .and_then(|v| v.checked_mul(slowd_periods.len()))
1280        .and_then(|v| v.checked_mul(slowd_types.len()))
1281        .ok_or(StochError::InvalidRange {
1282            start: r.fastk_period.0,
1283            end: r.fastk_period.1,
1284            step: r.fastk_period.2,
1285        })?;
1286
1287    let mut out = Vec::with_capacity(combos_len);
1288    for &fkp in &fastk_periods {
1289        for &skp in &slowk_periods {
1290            for skt in &slowk_types {
1291                for &sdp in &slowd_periods {
1292                    for sdt in &slowd_types {
1293                        out.push(StochParams {
1294                            fastk_period: Some(fkp),
1295                            slowk_period: Some(skp),
1296                            slowk_ma_type: Some(skt.clone()),
1297                            slowd_period: Some(sdp),
1298                            slowd_ma_type: Some(sdt.clone()),
1299                        });
1300                    }
1301                }
1302            }
1303        }
1304    }
1305    Ok(out)
1306}
1307
/// Runs the batch sweep by calling `stoch_batch_inner` with `parallel = false`.
///
/// `kern` is passed through as-is, so it is expected to be a per-row kernel
/// (`Kernel::Scalar`, `Kernel::Avx2`, `Kernel::Avx512`) rather than a batch
/// kernel.
///
/// NOTE(review): the visible body of `stoch_batch_inner` never reads the
/// `parallel` flag, so this appears to run the same code path as
/// `stoch_batch_par_slice` — confirm whether that is intended.
#[inline(always)]
pub fn stoch_batch_slice(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    sweep: &StochBatchRange,
    kern: Kernel,
) -> Result<StochBatchOutput, StochError> {
    stoch_batch_inner(high, low, close, sweep, kern, false)
}
1318
/// Runs the batch sweep by calling `stoch_batch_inner` with `parallel = true`.
///
/// `kern` is passed through as-is, so it is expected to be a per-row kernel
/// (`Kernel::Scalar`, `Kernel::Avx2`, `Kernel::Avx512`) rather than a batch
/// kernel.
///
/// NOTE(review): the visible body of `stoch_batch_inner` never reads the
/// `parallel` flag, so this appears to run the same code path as
/// `stoch_batch_slice` — confirm whether that is intended.
#[inline(always)]
pub fn stoch_batch_par_slice(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    sweep: &StochBatchRange,
    kern: Kernel,
) -> Result<StochBatchOutput, StochError> {
    stoch_batch_inner(high, low, close, sweep, kern, true)
}
1329
1330#[inline(always)]
1331fn stoch_batch_inner(
1332    high: &[f64],
1333    low: &[f64],
1334    close: &[f64],
1335    sweep: &StochBatchRange,
1336    kern: Kernel,
1337    parallel: bool,
1338) -> Result<StochBatchOutput, StochError> {
1339    let combos = expand_grid(sweep)?;
1340
1341    let n = high.len();
1342    if n == 0 || low.len() != n || close.len() != n {
1343        return Err(StochError::MismatchedLength);
1344    }
1345
1346    let first = high
1347        .iter()
1348        .zip(low.iter())
1349        .zip(close.iter())
1350        .position(|((h, l), c)| !h.is_nan() && !l.is_nan() && !c.is_nan())
1351        .ok_or(StochError::AllValuesNaN)?;
1352    let max_fkp = combos
1353        .iter()
1354        .map(|c| c.fastk_period.unwrap())
1355        .max()
1356        .unwrap();
1357    if n - first < max_fkp {
1358        return Err(StochError::NotEnoughValidData {
1359            needed: max_fkp,
1360            valid: n - first,
1361        });
1362    }
1363
1364    let rows = combos.len();
1365    let cols = n;
1366
1367    rows.checked_mul(cols).ok_or(StochError::InvalidRange {
1368        start: rows,
1369        end: cols,
1370        step: 0,
1371    })?;
1372
1373    let mut k_mu = make_uninit_matrix(rows, cols);
1374    let mut d_mu = make_uninit_matrix(rows, cols);
1375
1376    let warm_k: Vec<usize> = combos
1377        .iter()
1378        .map(|c| first + c.fastk_period.unwrap() - 1)
1379        .collect();
1380    init_matrix_prefixes(&mut k_mu, cols, &warm_k);
1381    init_matrix_prefixes(&mut d_mu, cols, &warm_k);
1382
1383    let mut k_guard = core::mem::ManuallyDrop::new(k_mu);
1384    let mut d_guard = core::mem::ManuallyDrop::new(d_mu);
1385    let k_mat: &mut [f64] =
1386        unsafe { core::slice::from_raw_parts_mut(k_guard.as_mut_ptr() as *mut f64, k_guard.len()) };
1387    let d_mat: &mut [f64] =
1388        unsafe { core::slice::from_raw_parts_mut(d_guard.as_mut_ptr() as *mut f64, d_guard.len()) };
1389
1390    use std::collections::HashMap;
1391    let mut groups: HashMap<usize, Vec<usize>> = HashMap::new();
1392    for (row, prm) in combos.iter().enumerate() {
1393        groups
1394            .entry(prm.fastk_period.unwrap())
1395            .or_default()
1396            .push(row);
1397    }
1398
1399    let mut compute_k_raw = |fkp: usize| -> Vec<f64> {
1400        let mut hh = alloc_with_nan_prefix(cols, first + fkp - 1);
1401        let mut ll = alloc_with_nan_prefix(cols, first + fkp - 1);
1402        let highs = max_rolling(&high[first..], fkp).unwrap();
1403        let lows = min_rolling(&low[first..], fkp).unwrap();
1404        for (i, &v) in highs.iter().enumerate() {
1405            hh[first + i] = v;
1406        }
1407        for (i, &v) in lows.iter().enumerate() {
1408            ll[first + i] = v;
1409        }
1410
1411        let mut k_raw = alloc_with_nan_prefix(cols, first + fkp - 1);
1412        unsafe {
1413            match kern {
1414                Kernel::Scalar => {
1415                    stoch_row_scalar(high, low, close, &hh, &ll, fkp, first, &mut k_raw)
1416                }
1417                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1418                Kernel::Avx2 => stoch_row_avx2(high, low, close, &hh, &ll, fkp, first, &mut k_raw),
1419                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1420                Kernel::Avx512 => {
1421                    stoch_row_avx512(high, low, close, &hh, &ll, fkp, first, &mut k_raw)
1422                }
1423                _ => unreachable!(),
1424            }
1425        }
1426        k_raw
1427    };
1428
1429    for (fkp, rows_in_group) in groups {
1430        let k_raw = compute_k_raw(fkp);
1431        for &row in &rows_in_group {
1432            let prm = &combos[row];
1433            let k_vec = ma(
1434                prm.slowk_ma_type.as_ref().unwrap(),
1435                MaData::Slice(&k_raw),
1436                prm.slowk_period.unwrap(),
1437            )
1438            .unwrap();
1439            let d_vec = ma(
1440                prm.slowd_ma_type.as_ref().unwrap(),
1441                MaData::Slice(&k_vec),
1442                prm.slowd_period.unwrap(),
1443            )
1444            .unwrap();
1445            let start = row * cols;
1446            let dst_k = &mut k_mat[start..start + cols];
1447            let dst_d = &mut d_mat[start..start + cols];
1448            dst_k.copy_from_slice(&k_vec);
1449            dst_d.copy_from_slice(&d_vec);
1450        }
1451    }
1452
1453    let k = unsafe {
1454        Vec::from_raw_parts(
1455            k_guard.as_mut_ptr() as *mut f64,
1456            k_guard.len(),
1457            k_guard.capacity(),
1458        )
1459    };
1460    let d = unsafe {
1461        Vec::from_raw_parts(
1462            d_guard.as_mut_ptr() as *mut f64,
1463            d_guard.len(),
1464            d_guard.capacity(),
1465        )
1466    };
1467    core::mem::forget(k_guard);
1468    core::mem::forget(d_guard);
1469
1470    Ok(StochBatchOutput {
1471        k,
1472        d,
1473        combos,
1474        rows,
1475        cols,
1476    })
1477}
1478
/// Scalar kernel for one row of raw (unsmoothed) %K.
///
/// With `start = first + fastk_period - 1`, each position `t >= start` gets
/// `(close[t] - ll[t]) * (100 / (hh[t] - ll[t]))`, or `50.0` when the range
/// `|hh[t] - ll[t]|` is below `f64::EPSILON`. Elements before `start` are left
/// untouched, and nothing is written when `start >= close.len()`.
/// `_high` and `_low` are not read.
///
/// # Safety
/// The body performs no unsafe operation; the function is `unsafe` to share
/// its signature with the SIMD row kernels.
#[inline(always)]
unsafe fn stoch_row_scalar(
    _high: &[f64],
    _low: &[f64],
    close: &[f64],
    hh: &[f64],
    ll: &[f64],
    fastk_period: usize,
    first: usize,
    out: &mut [f64],
) {
    let begin = first + fastk_period - 1;
    if begin >= close.len() {
        return;
    }

    let closes = &close[begin..];
    let highs = &hh[begin..];
    let lows = &ll[begin..];
    let dest = &mut out[begin..];

    dest.iter_mut()
        .zip(closes)
        .zip(highs)
        .zip(lows)
        .for_each(|(((slot, &cv), &hv), &lv)| {
            let range = hv - lv;
            *slot = if range.abs() < f64::EPSILON {
                // Flat window: report the neutral midpoint.
                50.0
            } else {
                (cv - lv).mul_add(100.0 / range, 0.0)
            };
        });
}
/// AVX2 row kernel entry point; forwards all arguments unchanged to
/// `stoch_row_avx2_impl`.
///
/// # Safety
/// The callee is compiled with `#[target_feature(enable = "avx2")]` and
/// indexes `hh`, `ll` and `out` through raw pointers up to `close.len()`, so
/// the CPU must support AVX2 and those slices must be at least as long as
/// `close`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn stoch_row_avx2(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    hh: &[f64],
    ll: &[f64],
    fastk_period: usize,
    first: usize,
    out: &mut [f64],
) {
    stoch_row_avx2_impl(high, low, close, hh, ll, fastk_period, first, out)
}
1526
/// AVX2 kernel for one row of raw (unsmoothed) %K.
///
/// With `start = first + fastk_period - 1`, every index `t` in
/// `start..close.len()` receives
/// `out[t] = (close[t] - ll[t]) * (100 / (hh[t] - ll[t]))`, or `50.0` when
/// `|hh[t] - ll[t]| < f64::EPSILON`. Nothing is written when
/// `start >= close.len()`. `_high` and `_low` are not read.
///
/// # Safety
/// * The CPU must support AVX2.
/// * `hh`, `ll` and `out` must be at least as long as `close`: they are
///   accessed through raw pointers without bounds checks.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn stoch_row_avx2_impl(
    _high: &[f64],
    _low: &[f64],
    close: &[f64],
    hh: &[f64],
    ll: &[f64],
    fastk_period: usize,
    first: usize,
    out: &mut [f64],
) {
    let start = first + fastk_period - 1;
    if start >= close.len() {
        return;
    }
    let n = close.len() - start;

    let mut i = 0usize;
    let c_ptr = close.as_ptr().add(start);
    let h_ptr = hh.as_ptr().add(start);
    let l_ptr = ll.as_ptr().add(start);
    let o_ptr = out.as_mut_ptr().add(start);

    const STEP: usize = 4;
    // Largest multiple of STEP that fits in n; the rest goes to the scalar tail.
    let vec_end = n & !(STEP - 1);

    let scale = _mm256_set1_pd(100.0);
    let fifty = _mm256_set1_pd(50.0);
    let eps = _mm256_set1_pd(f64::EPSILON);
    // Only the sign bit is set in -0.0, so `andnot(sign_mask, x)` yields |x|.
    let sign_mask = _mm256_set1_pd(-0.0);

    while i + STEP <= vec_end {
        let c0 = _mm256_loadu_pd(c_ptr.add(i));
        let h0 = _mm256_loadu_pd(h_ptr.add(i));
        let l0 = _mm256_loadu_pd(l_ptr.add(i));
        let d0 = _mm256_sub_pd(h0, l0);
        let n0 = _mm256_sub_pd(c0, l0);
        let a0 = _mm256_andnot_pd(sign_mask, d0);
        // Ordered compare: NaN lanes are not flagged and keep the computed (NaN) value.
        let m0 = _mm256_cmp_pd(a0, eps, _CMP_LT_OQ);
        let inv0 = _mm256_div_pd(scale, d0);
        let v0 = _mm256_mul_pd(n0, inv0);
        // Lanes with a (near-)zero range get the neutral 50.0.
        let o0 = _mm256_blendv_pd(v0, fifty, m0);

        // Manual 2x unroll: process a second vector when one is available.
        if i + 2 * STEP <= vec_end {
            let c1 = _mm256_loadu_pd(c_ptr.add(i + STEP));
            let h1 = _mm256_loadu_pd(h_ptr.add(i + STEP));
            let l1 = _mm256_loadu_pd(l_ptr.add(i + STEP));
            let d1 = _mm256_sub_pd(h1, l1);
            let n1 = _mm256_sub_pd(c1, l1);
            let a1 = _mm256_andnot_pd(sign_mask, d1);
            let m1 = _mm256_cmp_pd(a1, eps, _CMP_LT_OQ);
            let inv1 = _mm256_div_pd(scale, d1);
            let v1 = _mm256_mul_pd(n1, inv1);
            let o1 = _mm256_blendv_pd(v1, fifty, m1);

            _mm256_storeu_pd(o_ptr.add(i), o0);
            _mm256_storeu_pd(o_ptr.add(i + STEP), o1);
            i += 2 * STEP;
        } else {
            _mm256_storeu_pd(o_ptr.add(i), o0);
            i += STEP;
        }
    }

    // Scalar tail for the last `n % STEP` elements.
    while i < n {
        let c = *c_ptr.add(i);
        let l = *l_ptr.add(i);
        let d = *h_ptr.add(i) - l;
        *o_ptr.add(i) = if d.abs() < f64::EPSILON {
            50.0
        } else {
            (c - l) * (100.0 / d)
        };
        i += 1;
    }
}
/// AVX-512 row kernel dispatcher.
///
/// Periods up to 32 go to `stoch_row_avx512_short`, longer ones to
/// `stoch_row_avx512_long`; all arguments are forwarded unchanged.
///
/// # Safety
/// Same contract as the selected callee.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn stoch_row_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    hh: &[f64],
    ll: &[f64],
    fastk_period: usize,
    first: usize,
    out: &mut [f64],
) {
    match fastk_period {
        0..=32 => stoch_row_avx512_short(high, low, close, hh, ll, fastk_period, first, out),
        _ => stoch_row_avx512_long(high, low, close, hh, ll, fastk_period, first, out),
    }
}
/// Short-period AVX-512 row kernel; forwards all arguments unchanged to
/// `stoch_row_avx512_impl`.
///
/// # Safety
/// The callee is compiled with `#[target_feature(enable = "avx512f")]` and
/// indexes `hh`, `ll` and `out` through raw pointers up to `close.len()`, so
/// the CPU must support AVX-512F and those slices must be at least as long as
/// `close`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn stoch_row_avx512_short(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    hh: &[f64],
    ll: &[f64],
    fastk_period: usize,
    first: usize,
    out: &mut [f64],
) {
    stoch_row_avx512_impl(high, low, close, hh, ll, fastk_period, first, out)
}
/// Long-period AVX-512 row kernel; forwards all arguments unchanged to
/// `stoch_row_avx512_impl`.
///
/// # Safety
/// The callee is compiled with `#[target_feature(enable = "avx512f")]` and
/// indexes `hh`, `ll` and `out` through raw pointers up to `close.len()`, so
/// the CPU must support AVX-512F and those slices must be at least as long as
/// `close`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn stoch_row_avx512_long(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    hh: &[f64],
    ll: &[f64],
    fastk_period: usize,
    first: usize,
    out: &mut [f64],
) {
    stoch_row_avx512_impl(high, low, close, hh, ll, fastk_period, first, out)
}
1650
/// AVX-512 kernel for one row of raw (unsmoothed) %K.
///
/// With `start = first + fastk_period - 1`, every index `t` in
/// `start..close.len()` receives
/// `out[t] = (close[t] - ll[t]) * (100 / (hh[t] - ll[t]))`, or `50.0` when
/// `|hh[t] - ll[t]| < f64::EPSILON`. Nothing is written when
/// `start >= close.len()`. `_high` and `_low` are not read.
///
/// # Safety
/// * The CPU must support AVX-512F.
/// * `hh`, `ll` and `out` must be at least as long as `close`: they are
///   accessed through raw pointers without bounds checks.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
unsafe fn stoch_row_avx512_impl(
    _high: &[f64],
    _low: &[f64],
    close: &[f64],
    hh: &[f64],
    ll: &[f64],
    fastk_period: usize,
    first: usize,
    out: &mut [f64],
) {
    let start = first + fastk_period - 1;
    if start >= close.len() {
        return;
    }
    let n = close.len() - start;

    let c_ptr = close.as_ptr().add(start);
    let h_ptr = hh.as_ptr().add(start);
    let l_ptr = ll.as_ptr().add(start);
    let o_ptr = out.as_mut_ptr().add(start);

    const STEP: usize = 8;
    // Largest multiple of STEP that fits in n; the rest goes to the scalar tail.
    let vec_end = n & !(STEP - 1);

    let scale = _mm512_set1_pd(100.0);
    let fifty = _mm512_set1_pd(50.0);
    let eps = _mm512_set1_pd(f64::EPSILON);
    // Only the sign bit is set in -0.0, so `andnot(sign_mask, x)` yields |x|.
    let sign_mask = _mm512_set1_pd(-0.0);

    let mut i = 0usize;
    while i + STEP <= vec_end {
        let c0 = _mm512_loadu_pd(c_ptr.add(i));
        let h0 = _mm512_loadu_pd(h_ptr.add(i));
        let l0 = _mm512_loadu_pd(l_ptr.add(i));
        let d0 = _mm512_sub_pd(h0, l0);
        let n0 = _mm512_sub_pd(c0, l0);
        let a0 = _mm512_andnot_pd(sign_mask, d0);
        // Ordered compare: NaN lanes are not flagged and keep the computed (NaN) value.
        let m0: __mmask8 = _mm512_cmp_pd_mask(a0, eps, _CMP_LT_OQ);
        let inv0 = _mm512_div_pd(scale, d0);
        let v0 = _mm512_mul_pd(n0, inv0);
        // Lanes with a (near-)zero range get the neutral 50.0.
        let o0 = _mm512_mask_blend_pd(m0, v0, fifty);

        // Manual 2x unroll: process a second vector when one is available.
        if i + 2 * STEP <= vec_end {
            let c1 = _mm512_loadu_pd(c_ptr.add(i + STEP));
            let h1 = _mm512_loadu_pd(h_ptr.add(i + STEP));
            let l1 = _mm512_loadu_pd(l_ptr.add(i + STEP));
            let d1 = _mm512_sub_pd(h1, l1);
            let n1 = _mm512_sub_pd(c1, l1);
            let a1 = _mm512_andnot_pd(sign_mask, d1);
            let m1: __mmask8 = _mm512_cmp_pd_mask(a1, eps, _CMP_LT_OQ);
            let inv1 = _mm512_div_pd(scale, d1);
            let v1 = _mm512_mul_pd(n1, inv1);
            let o1 = _mm512_mask_blend_pd(m1, v1, fifty);

            _mm512_storeu_pd(o_ptr.add(i), o0);
            _mm512_storeu_pd(o_ptr.add(i + STEP), o1);
            i += 2 * STEP;
        } else {
            _mm512_storeu_pd(o_ptr.add(i), o0);
            i += STEP;
        }
    }

    // Scalar tail for the last `n % STEP` elements.
    while i < n {
        let c = *c_ptr.add(i);
        let l = *l_ptr.add(i);
        let d = *h_ptr.add(i) - l;
        *o_ptr.add(i) = if d.abs() < f64::EPSILON {
            50.0
        } else {
            (c - l) * (100.0 / d)
        };
        i += 1;
    }
}
1728
/// One element of the rolling max/min deques used by `StochStream`.
#[derive(Debug, Clone)]
struct DeqEntry {
    /// Price value (a high for the max deque, a low for the min deque).
    val: f64,
    /// Index of the bar the value came from; used to evict entries that have
    /// left the look-back window.
    idx: usize,
}
1734
/// Incremental (bar-by-bar) Stochastic oscillator state.
///
/// Holds the rolling high/low window plus the smoothing state for the slow %K
/// and slow %D stages. Dedicated state exists for `"sma"` and `"ema"`
/// smoothing; other moving-average names go through the generic buffers.
#[derive(Debug, Clone)]
pub struct StochStream {
    /// Look-back length of the highest-high / lowest-low window.
    fastk_period: usize,
    /// Smoothing length applied to raw %K.
    slowk_period: usize,
    /// Moving-average name used for slow %K.
    slowk_ma_type: String,
    /// Smoothing length applied to slow %K to obtain %D.
    slowd_period: usize,
    /// Moving-average name used for slow %D.
    slowd_ma_type: String,

    /// Deque of highs; `push_maxq` drops smaller-or-equal back entries, so the
    /// front is the window maximum.
    maxq: VecDeque<DeqEntry>,
    /// Deque of lows; `push_minq` drops larger-or-equal back entries, so the
    /// front is the window minimum.
    minq: VecDeque<DeqEntry>,
    /// Number of bars accepted so far; also the index of the next bar.
    t: usize,
    /// Set once at least `fastk_period` bars have been seen.
    have_window: bool,

    /// Ring buffer of the last `slowk_period` raw %K values (SMA path).
    k_sma_buf: Vec<f64>,
    /// Running sum of the values in `k_sma_buf`.
    k_sma_sum: f64,
    /// Next write position in `k_sma_buf`.
    k_sma_head: usize,
    /// Number of values stored so far, capped at `slowk_period`.
    k_sma_count: usize,

    /// Current EMA of raw %K; `None` until seeded.
    k_ema: Option<f64>,
    /// Sum of the raw %K values collected for the EMA seed (their mean).
    k_ema_seed_sum: f64,
    /// Number of raw %K values collected for the EMA seed.
    k_ema_seed_count: usize,
    /// EMA smoothing factor `2 / (slowk_period + 1)`.
    alpha_k: f64,

    /// Ring buffer of the last `slowd_period` slow %K values (SMA path).
    d_sma_buf: Vec<f64>,
    /// Running sum of the values in `d_sma_buf`.
    d_sma_sum: f64,
    /// Next write position in `d_sma_buf`.
    d_sma_head: usize,
    /// Number of values stored so far, capped at `slowd_period`.
    d_sma_count: usize,

    /// Current EMA of slow %K; `None` until set.
    d_ema: Option<f64>,
    /// NOTE(review): presumably the seed accumulator for the %D EMA, mirroring
    /// `k_ema_seed_sum` — confirm in `update`.
    d_ema_seed_sum: f64,
    /// NOTE(review): presumably the seed sample count for the %D EMA — confirm.
    d_ema_seed_count: usize,
    /// EMA smoothing factor `2 / (slowd_period + 1)`.
    alpha_d: f64,

    /// Rolling window of recent raw %K values fed to the generic `ma` routine
    /// when slow %K is neither `"sma"` nor `"ema"`.
    k_stream: Option<Vec<f64>>,
    /// NOTE(review): presumably the analogous rolling window for slow %D —
    /// confirm in `update`.
    d_stream: Option<Vec<f64>>,
}
1771
1772impl StochStream {
1773    pub fn try_new(params: StochParams) -> Result<Self, StochError> {
1774        let fastk_period = params.fastk_period.unwrap_or(14);
1775        let slowk_period = params.slowk_period.unwrap_or(3);
1776        let slowd_period = params.slowd_period.unwrap_or(3);
1777        if fastk_period == 0 || slowk_period == 0 || slowd_period == 0 {
1778            return Err(StochError::InvalidPeriod {
1779                period: 0,
1780                data_len: 0,
1781            });
1782        }
1783
1784        let slowk_ma_type = params.slowk_ma_type.unwrap_or_else(|| "sma".to_string());
1785        let slowd_ma_type = params.slowd_ma_type.unwrap_or_else(|| "sma".to_string());
1786
1787        let alpha_k = 2.0 / (slowk_period as f64 + 1.0);
1788        let alpha_d = 2.0 / (slowd_period as f64 + 1.0);
1789
1790        Ok(Self {
1791            fastk_period,
1792            slowk_period,
1793            slowk_ma_type,
1794            slowd_period,
1795            slowd_ma_type,
1796
1797            maxq: VecDeque::with_capacity(fastk_period),
1798            minq: VecDeque::with_capacity(fastk_period),
1799            t: 0,
1800            have_window: false,
1801
1802            k_sma_buf: vec![f64::NAN; slowk_period.max(1)],
1803            k_sma_sum: 0.0,
1804            k_sma_head: 0,
1805            k_sma_count: 0,
1806
1807            k_ema: None,
1808            k_ema_seed_sum: 0.0,
1809            k_ema_seed_count: 0,
1810            alpha_k,
1811
1812            d_sma_buf: vec![f64::NAN; slowd_period.max(1)],
1813            d_sma_sum: 0.0,
1814            d_sma_head: 0,
1815            d_sma_count: 0,
1816
1817            d_ema: None,
1818            d_ema_seed_sum: 0.0,
1819            d_ema_seed_count: 0,
1820            alpha_d,
1821
1822            k_stream: None,
1823            d_stream: None,
1824        })
1825    }
1826
1827    #[inline(always)]
1828    fn evict_older_than(dq: &mut VecDeque<DeqEntry>, min_idx: usize) {
1829        while let Some(front) = dq.front() {
1830            if front.idx < min_idx {
1831                dq.pop_front();
1832            } else {
1833                break;
1834            }
1835        }
1836    }
1837
1838    #[inline(always)]
1839    fn push_maxq(&mut self, val: f64, idx: usize) {
1840        while let Some(back) = self.maxq.back() {
1841            if back.val <= val {
1842                self.maxq.pop_back();
1843            } else {
1844                break;
1845            }
1846        }
1847        self.maxq.push_back(DeqEntry { val, idx });
1848    }
1849
1850    #[inline(always)]
1851    fn push_minq(&mut self, val: f64, idx: usize) {
1852        while let Some(back) = self.minq.back() {
1853            if back.val >= val {
1854                self.minq.pop_back();
1855            } else {
1856                break;
1857            }
1858        }
1859        self.minq.push_back(DeqEntry { val, idx });
1860    }
1861
1862    pub fn update(&mut self, high: f64, low: f64, close: f64) -> Option<(f64, f64)> {
1863        if !high.is_finite() || !low.is_finite() || !close.is_finite() {
1864            return None;
1865        }
1866
1867        let idx = self.t;
1868        self.t = self.t.wrapping_add(1);
1869
1870        self.push_maxq(high, idx);
1871        self.push_minq(low, idx);
1872
1873        let seen = idx + 1;
1874        if seen >= self.fastk_period {
1875            let window_start = seen - self.fastk_period;
1876            Self::evict_older_than(&mut self.maxq, window_start);
1877            Self::evict_older_than(&mut self.minq, window_start);
1878            self.have_window = true;
1879        }
1880
1881        if !self.have_window {
1882            return None;
1883        }
1884
1885        debug_assert!(!self.maxq.is_empty() && !self.minq.is_empty());
1886        let hh = self.maxq.front().unwrap().val;
1887        let ll = self.minq.front().unwrap().val;
1888
1889        const SCALE: f64 = 100.0;
1890        const EPS: f64 = f64::EPSILON;
1891
1892        let denom = hh - ll;
1893        let k_raw = if denom.abs() < EPS {
1894            50.0
1895        } else {
1896            (close - ll).mul_add(SCALE / denom, 0.0)
1897        };
1898
1899        let k_last = if self.slowk_ma_type.eq_ignore_ascii_case("sma") {
1900            if self.slowk_period == 1 {
1901                k_raw
1902            } else if self.k_sma_count < self.slowk_period {
1903                self.k_sma_sum += k_raw;
1904                self.k_sma_buf[self.k_sma_head] = k_raw;
1905                self.k_sma_head = (self.k_sma_head + 1) % self.slowk_period;
1906                self.k_sma_count += 1;
1907                if self.k_sma_count == self.slowk_period {
1908                    self.k_sma_sum / self.slowk_period as f64
1909                } else {
1910                    f64::NAN
1911                }
1912            } else {
1913                let old = self.k_sma_buf[self.k_sma_head];
1914                self.k_sma_sum += k_raw - old;
1915                self.k_sma_buf[self.k_sma_head] = k_raw;
1916                self.k_sma_head = (self.k_sma_head + 1) % self.slowk_period;
1917                self.k_sma_sum / self.slowk_period as f64
1918            }
1919        } else if self.slowk_ma_type.eq_ignore_ascii_case("ema") {
1920            if self.slowk_period == 1 {
1921                self.k_ema = Some(k_raw);
1922                k_raw
1923            } else if self.k_ema.is_none() {
1924                self.k_ema_seed_sum += k_raw;
1925                self.k_ema_seed_count += 1;
1926                if self.k_ema_seed_count == self.slowk_period {
1927                    let seed = self.k_ema_seed_sum / self.slowk_period as f64;
1928                    self.k_ema = Some(seed);
1929                    seed
1930                } else {
1931                    f64::NAN
1932                }
1933            } else {
1934                let prev = self.k_ema.unwrap();
1935                let ema = prev + self.alpha_k * (k_raw - prev);
1936                self.k_ema = Some(ema);
1937                ema
1938            }
1939        } else {
1940            let mut k_vec = self
1941                .k_stream
1942                .take()
1943                .unwrap_or_else(|| vec![f64::NAN; self.slowk_period]);
1944            k_vec.remove(0);
1945            k_vec.push(k_raw);
1946            self.k_stream = Some(k_vec.clone());
1947
1948            match ma(
1949                &self.slowk_ma_type,
1950                MaData::Slice(&k_vec),
1951                self.slowk_period,
1952            ) {
1953                Ok(slowk) => *slowk.last().unwrap_or(&f64::NAN),
1954                Err(_) => k_raw,
1955            }
1956        };
1957
1958        let d_last = if self.slowd_ma_type.eq_ignore_ascii_case("sma") {
1959            if self.slowd_period == 1 {
1960                k_last
1961            } else if !k_last.is_finite() {
1962                f64::NAN
1963            } else if self.d_sma_count < self.slowd_period {
1964                self.d_sma_sum += k_last;
1965                self.d_sma_buf[self.d_sma_head] = k_last;
1966                self.d_sma_head = (self.d_sma_head + 1) % self.slowd_period;
1967                self.d_sma_count += 1;
1968                if self.d_sma_count == self.slowd_period {
1969                    self.d_sma_sum / self.slowd_period as f64
1970                } else {
1971                    f64::NAN
1972                }
1973            } else {
1974                let old = self.d_sma_buf[self.d_sma_head];
1975                self.d_sma_sum += k_last - old;
1976                self.d_sma_buf[self.d_sma_head] = k_last;
1977                self.d_sma_head = (self.d_sma_head + 1) % self.slowd_period;
1978                self.d_sma_sum / self.slowd_period as f64
1979            }
1980        } else if self.slowd_ma_type.eq_ignore_ascii_case("ema") {
1981            if self.slowd_period == 1 {
1982                self.d_ema = Some(k_last);
1983                k_last
1984            } else if !k_last.is_finite() {
1985                f64::NAN
1986            } else if self.d_ema.is_none() {
1987                self.d_ema_seed_sum += k_last;
1988                self.d_ema_seed_count += 1;
1989                if self.d_ema_seed_count == self.slowd_period {
1990                    let seed = self.d_ema_seed_sum / self.slowd_period as f64;
1991                    self.d_ema = Some(seed);
1992                    seed
1993                } else {
1994                    f64::NAN
1995                }
1996            } else {
1997                let prev = self.d_ema.unwrap();
1998                let ema = prev + self.alpha_d * (k_last - prev);
1999                self.d_ema = Some(ema);
2000                ema
2001            }
2002        } else {
2003            let mut d_vec = self
2004                .d_stream
2005                .take()
2006                .unwrap_or_else(|| vec![f64::NAN; self.slowd_period]);
2007            d_vec.remove(0);
2008            d_vec.push(k_last);
2009            self.d_stream = Some(d_vec.clone());
2010
2011            match ma(
2012                &self.slowd_ma_type,
2013                MaData::Slice(&d_vec),
2014                self.slowd_period,
2015            ) {
2016                Ok(slowd) => *slowd.last().unwrap_or(&f64::NAN),
2017                Err(_) => k_last,
2018            }
2019        };
2020
2021        Some((k_last, d_last))
2022    }
2023}
2024
2025#[cfg(all(feature = "python", feature = "cuda"))]
2026use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
2027#[cfg(all(feature = "python", feature = "cuda"))]
2028use cust::context::Context;
2029#[cfg(all(feature = "python", feature = "cuda"))]
2030use cust::memory::DeviceBuffer;
2031#[cfg(all(feature = "python", feature = "cuda"))]
2032use std::sync::Arc;
2033
/// Python-visible handle to a 2-D `f32` matrix held in a CUDA device buffer.
///
/// Registered with Python as `ta_indicators.cuda.StochDeviceArrayF32` and
/// marked `unsendable` in the `#[pyclass]` attribute.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(
    module = "ta_indicators.cuda",
    name = "StochDeviceArrayF32",
    unsendable
)]
pub struct StochDeviceArrayF32Py {
    /// Device allocation; becomes `None` after `__dlpack__` takes ownership of it.
    pub(crate) buf: Option<DeviceBuffer<f32>>,
    /// Number of matrix rows, reported as the first `shape` entry.
    pub(crate) rows: usize,
    /// Number of matrix columns, reported as the second `shape` entry.
    pub(crate) cols: usize,
    /// Not read by the methods in this file; presumably kept so the CUDA
    /// context outlives `buf` — TODO confirm.
    pub(crate) _ctx: Arc<Context>,
    /// Device id returned (as `i32`) by `__dlpack_device__`.
    pub(crate) device_id: u32,
}
2047
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl StochDeviceArrayF32Py {
    /// Builds the CUDA Array Interface dictionary (`version` 3) for this matrix.
    ///
    /// The dictionary carries `shape = (rows, cols)`, `typestr = "<f4"`, byte
    /// strides for a row-major layout and `data = (device_pointer, false)`.
    /// Fails with `ValueError` once the buffer has been taken by `__dlpack__`.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let elem_bytes = std::mem::size_of::<f32>();
        let row_bytes = self.cols * elem_bytes;

        let iface = PyDict::new(py);
        iface.set_item("shape", (self.rows, self.cols))?;
        iface.set_item("typestr", "<f4")?;
        iface.set_item("strides", (row_bytes, elem_bytes))?;

        let addr = match self.buf.as_ref() {
            Some(device_buf) => device_buf.as_device_ptr().as_raw() as usize,
            None => {
                return Err(PyValueError::new_err(
                    "buffer already exported via __dlpack__",
                ))
            }
        };
        iface.set_item("data", (addr, false))?;
        iface.set_item("version", 3)?;
        Ok(iface)
    }

    /// Returns the `(device_type, device_id)` pair; the device type is always `2`.
    fn __dlpack_device__(&self) -> (i32, i32) {
        let device_type = 2;
        (device_type, self.device_id as i32)
    }

    /// Exports the device buffer through the DLPack protocol, consuming it.
    ///
    /// * `stream` is accepted but ignored.
    /// * `dl_device`, when it parses as an `(i32, i32)` pair, must equal the
    ///   pair from `__dlpack_device__`; a value that does not parse is ignored.
    /// * `copy=True` is rejected because no copy path exists.
    /// * `max_version` is forwarded to the exporter unchanged.
    ///
    /// The buffer is moved out of `self`, so a second call fails with `ValueError`.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        let own_device = self.__dlpack_device__();

        // Only a request that parses as a device pair is validated.
        let requested = dl_device
            .as_ref()
            .and_then(|obj| obj.extract::<(i32, i32)>(py).ok());
        if let Some(target) = requested {
            if target != own_device {
                // A `copy` value that does not parse as bool counts as "no copy" here.
                let wants_copy = copy
                    .as_ref()
                    .and_then(|flag| flag.extract::<bool>(py).ok())
                    .unwrap_or(false);
                let message = if wants_copy {
                    "stoch: device copy not implemented for __dlpack__"
                } else {
                    "stoch: requested dl_device does not match buffer device"
                };
                return Err(PyValueError::new_err(message));
            }
        }
        let _ = stream;

        // Here a non-bool `copy` surfaces the extraction error to the caller.
        if let Some(flag) = copy.as_ref() {
            if flag.extract::<bool>(py)? {
                return Err(PyValueError::new_err(
                    "stoch: __dlpack__(copy=True) not supported",
                ));
            }
        }

        let device_buf = match self.buf.take() {
            Some(device_buf) => device_buf,
            None => {
                return Err(PyValueError::new_err(
                    "__dlpack__ may only be called once",
                ))
            }
        };

        let version_hint = max_version.map(|obj| obj.into_bound(py));
        let (_, alloc_dev) = own_device;

        export_f32_cuda_dlpack_2d(
            py,
            device_buf,
            self.rows,
            self.cols,
            alloc_dev,
            version_hint,
        )
    }
}
2130
/// Runs a Stochastic parameter sweep on a CUDA device and returns the K and D
/// result matrices as device-resident arrays.
///
/// The three period tuples are forwarded unchanged into `StochBatchRange`
/// (presumably `(start, end, step)` — confirm against `StochBatchRange`). Each
/// moving-average name fills both string slots of its range tuple.
///
/// Raises `ValueError` when CUDA is unavailable, or when creating the CUDA
/// helper or running the batch reports an error. Failures to view an input
/// array as a slice are propagated as well.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "stoch_cuda_batch_dev")]
#[pyo3(signature = (high_f32, low_f32, close_f32, fastk_period=(14,14,0), slowk_period=(3,3,0), slowd_period=(3,3,0), slowk_ma_type="sma", slowd_ma_type="sma", device_id=0))]
pub fn stoch_cuda_batch_dev_py(
    py: Python<'_>,
    high_f32: numpy::PyReadonlyArray1<'_, f32>,
    low_f32: numpy::PyReadonlyArray1<'_, f32>,
    close_f32: numpy::PyReadonlyArray1<'_, f32>,
    fastk_period: (usize, usize, usize),
    slowk_period: (usize, usize, usize),
    slowd_period: (usize, usize, usize),
    slowk_ma_type: &str,
    slowd_ma_type: &str,
    device_id: usize,
) -> PyResult<(StochDeviceArrayF32Py, StochDeviceArrayF32Py)> {
    use crate::cuda::cuda_available;

    // Converts any displayable error into a Python `ValueError`.
    fn value_error<E: ToString>(err: E) -> PyErr {
        PyValueError::new_err(err.to_string())
    }

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let high = high_f32.as_slice()?;
    let low = low_f32.as_slice()?;
    let close = close_f32.as_slice()?;

    let k_ma = slowk_ma_type.to_string();
    let d_ma = slowd_ma_type.to_string();
    let sweep = StochBatchRange {
        fastk_period,
        slowk_period,
        slowk_ma_type: (k_ma.clone(), k_ma, 0.0),
        slowd_period,
        slowd_ma_type: (d_ma.clone(), d_ma, 0.0),
    };

    // The GPU work runs with the GIL released.
    let (k_buf, d_buf, rows, cols, ctx, dev_id) = py.allow_threads(|| {
        let cuda = crate::cuda::oscillators::CudaStoch::new(device_id).map_err(value_error)?;
        let batch = cuda
            .stoch_batch_dev(high, low, close, &sweep)
            .map_err(value_error)?;
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id();
        // Both outputs take their dimensions from the K matrix.
        Ok::<_, PyErr>((
            batch.k.buf,
            batch.d.buf,
            batch.k.rows,
            batch.k.cols,
            ctx,
            dev_id,
        ))
    })?;

    let k_array = StochDeviceArrayF32Py {
        buf: Some(k_buf),
        rows,
        cols,
        _ctx: ctx.clone(),
        device_id: dev_id,
    };
    let d_array = StochDeviceArrayF32Py {
        buf: Some(d_buf),
        rows,
        cols,
        _ctx: ctx,
        device_id: dev_id,
    };
    Ok((k_array, d_array))
}
2193
/// Computes the Stochastic oscillator for many series with one parameter set
/// on a CUDA device, returning the K and D matrices as device-resident arrays.
///
/// The flattened inputs are passed with `cols` and `rows` to
/// `stoch_many_series_one_param_time_major_dev` (the name suggests a time-major
/// layout — confirm against that routine). The returned wrappers report the
/// caller-supplied `rows` and `cols` as their shape.
///
/// Raises `ValueError` when CUDA is unavailable, or when creating the CUDA
/// helper or running the computation reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "stoch_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm_f32, low_tm_f32, close_tm_f32, cols, rows, fastk_period=14, slowk_period=3, slowd_period=3, slowk_ma_type="sma", slowd_ma_type="sma", device_id=0))]
pub fn stoch_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    high_tm_f32: numpy::PyReadonlyArray1<'_, f32>,
    low_tm_f32: numpy::PyReadonlyArray1<'_, f32>,
    close_tm_f32: numpy::PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    fastk_period: usize,
    slowk_period: usize,
    slowd_period: usize,
    slowk_ma_type: &str,
    slowd_ma_type: &str,
    device_id: usize,
) -> PyResult<(StochDeviceArrayF32Py, StochDeviceArrayF32Py)> {
    use crate::cuda::cuda_available;

    // Converts any displayable error into a Python `ValueError`.
    fn value_error<E: ToString>(err: E) -> PyErr {
        PyValueError::new_err(err.to_string())
    }

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let high = high_tm_f32.as_slice()?;
    let low = low_tm_f32.as_slice()?;
    let close = close_tm_f32.as_slice()?;

    let params = StochParams {
        fastk_period: Some(fastk_period),
        slowk_period: Some(slowk_period),
        slowk_ma_type: Some(slowk_ma_type.to_string()),
        slowd_period: Some(slowd_period),
        slowd_ma_type: Some(slowd_ma_type.to_string()),
    };

    // The GPU work runs with the GIL released.
    let (k_dev, d_dev, ctx, dev_id) = py.allow_threads(|| {
        let cuda = crate::cuda::oscillators::CudaStoch::new(device_id).map_err(value_error)?;
        let (k, d) = cuda
            .stoch_many_series_one_param_time_major_dev(high, low, close, cols, rows, &params)
            .map_err(value_error)?;
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id();
        Ok::<_, PyErr>((k.buf, d.buf, ctx, dev_id))
    })?;

    let k_array = StochDeviceArrayF32Py {
        buf: Some(k_dev),
        rows,
        cols,
        _ctx: ctx.clone(),
        device_id: dev_id,
    };
    let d_array = StochDeviceArrayF32Py {
        buf: Some(d_dev),
        rows,
        cols,
        _ctx: ctx,
        device_id: dev_id,
    };
    Ok((k_array, d_array))
}
2251
/// Computes the Stochastic oscillator for one high/low/close series and
/// returns the `(k, d)` arrays as NumPy vectors.
///
/// `kernel` is checked by `validate_kernel`; the computation itself runs with
/// the GIL released. Any computation error is raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "stoch")]
#[pyo3(signature = (high, low, close, fastk_period=14, slowk_period=3, slowk_ma_type="sma", slowd_period=3, slowd_ma_type="sma", kernel=None))]
pub fn stoch_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    fastk_period: usize,
    slowk_period: usize,
    slowk_ma_type: &str,
    slowd_period: usize,
    slowd_ma_type: &str,
    kernel: Option<&str>,
) -> PyResult<(Bound<'py, PyArray1<f64>>, Bound<'py, PyArray1<f64>>)> {
    let high_values = high.as_slice()?;
    let low_values = low.as_slice()?;
    let close_values = close.as_slice()?;

    let chosen_kernel = validate_kernel(kernel, false)?;
    let input = StochInput::from_slices(
        high_values,
        low_values,
        close_values,
        StochParams {
            fastk_period: Some(fastk_period),
            slowk_period: Some(slowk_period),
            slowk_ma_type: Some(slowk_ma_type.to_string()),
            slowd_period: Some(slowd_period),
            slowd_ma_type: Some(slowd_ma_type.to_string()),
        },
    );

    match py.allow_threads(|| stoch_with_kernel(&input, chosen_kernel)) {
        Ok(result) => Ok((result.k.into_pyarray(py), result.d.into_pyarray(py))),
        Err(err) => Err(PyValueError::new_err(err.to_string())),
    }
}
2284
/// Runs a Stochastic parameter sweep over one high/low/close series and
/// returns a dictionary describing the result.
///
/// Returned keys:
/// * `"k"`, `"d"` — `(rows, cols)` `float64` matrices, one row per combination.
/// * `"fastk_periods"`, `"slowk_periods"`, `"slowd_periods"` — `uint64` arrays.
/// * `"slowk_types"`, `"slowd_types"` — lists of moving-average names.
///
/// Each range tuple is forwarded unchanged into `StochBatchRange`; each
/// moving-average name fills both string slots of its range tuple. Computation
/// errors are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "stoch_batch")]
#[pyo3(signature = (high, low, close, fastk_range, slowk_range, slowk_ma_type, slowd_range, slowd_ma_type, kernel=None))]
pub fn stoch_batch_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    fastk_range: (usize, usize, usize),
    slowk_range: (usize, usize, usize),
    slowk_ma_type: &str,
    slowd_range: (usize, usize, usize),
    slowd_ma_type: &str,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let hi = high.as_slice()?;
    let lo = low.as_slice()?;
    let cl = close.as_slice()?;

    let sweep = StochBatchRange {
        fastk_period: fastk_range,
        slowk_period: slowk_range,
        slowk_ma_type: (slowk_ma_type.to_string(), slowk_ma_type.to_string(), 0.0),
        slowd_period: slowd_range,
        slowd_ma_type: (slowd_ma_type.to_string(), slowd_ma_type.to_string(), 0.0),
    };

    let kern = validate_kernel(kernel, true)?;
    let out = py
        .allow_threads(|| stoch_batch_with_kernel(hi, lo, cl, &sweep, kern))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let rows = out.rows;
    let cols = out.cols;
    // Keep the explicit overflow guard; the product itself is not needed below.
    rows.checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("stoch_batch: size overflow in rows*cols"))?;

    let dict = PyDict::new(py);

    // Hand the result buffers to NumPy directly instead of allocating
    // uninitialised arrays and copying into them: no `unsafe`, no extra copy,
    // and a length mismatch surfaces as a Python error from `reshape` rather
    // than a panic in `copy_from_slice`.
    dict.set_item("k", out.k.into_pyarray(py).reshape((rows, cols))?)?;
    dict.set_item("d", out.d.into_pyarray(py).reshape((rows, cols))?)?;

    dict.set_item(
        "fastk_periods",
        out.combos
            .iter()
            .map(|p| p.fastk_period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "slowk_periods",
        out.combos
            .iter()
            .map(|p| p.slowk_period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "slowk_types",
        out.combos
            .iter()
            .map(|p| p.slowk_ma_type.as_deref().unwrap_or("sma"))
            .collect::<Vec<_>>(),
    )?;
    dict.set_item(
        "slowd_periods",
        out.combos
            .iter()
            .map(|p| p.slowd_period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "slowd_types",
        out.combos
            .iter()
            .map(|p| p.slowd_ma_type.as_deref().unwrap_or("sma"))
            .collect::<Vec<_>>(),
    )?;

    Ok(dict)
}
2373
/// Python class `StochStream`: a thin wrapper that owns a Rust `StochStream`.
#[cfg(feature = "python")]
#[pyclass(name = "StochStream")]
pub struct StochStreamPy {
    /// Wrapped streaming state; `update` calls are delegated to it.
    stream: StochStream,
}
2379
#[cfg(feature = "python")]
#[pymethods]
impl StochStreamPy {
    /// Creates a stream from explicit parameters.
    ///
    /// Raises `ValueError` when `StochStream::try_new` rejects the parameters.
    #[new]
    fn new(
        fastk_period: usize,
        slowk_period: usize,
        slowk_ma_type: &str,
        slowd_period: usize,
        slowd_ma_type: &str,
    ) -> PyResult<Self> {
        let params = StochParams {
            fastk_period: Some(fastk_period),
            slowk_period: Some(slowk_period),
            slowk_ma_type: Some(slowk_ma_type.to_string()),
            slowd_period: Some(slowd_period),
            slowd_ma_type: Some(slowd_ma_type.to_string()),
        };
        match StochStream::try_new(params) {
            Ok(stream) => Ok(Self { stream }),
            Err(err) => Err(PyValueError::new_err(err.to_string())),
        }
    }

    /// Feeds one bar to the wrapped stream and returns whatever it yields.
    fn update(&mut self, high: f64, low: f64, close: f64) -> Option<(f64, f64)> {
        let Self { stream } = self;
        stream.update(high, low, close)
    }
}
2407
/// Serialized result of the `stoch` wasm export.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct StochResult {
    /// Flattened output; `stoch_js` stores all K values followed by all D values.
    pub values: Vec<f64>,
    /// Number of rows packed into `values`; `stoch_js` sets this to 2.
    pub rows: usize,
    /// Length of each row; `stoch_js` sets this to the input length.
    pub cols: usize,
}
2415
/// wasm export `stoch`: computes the oscillator and returns a serialized
/// `StochResult` whose `values` hold the K row followed by the D row.
///
/// Computation and serialization failures are returned as string `JsValue`s.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = stoch)]
pub fn stoch_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    fastk_period: usize,
    slowk_period: usize,
    slowk_ma_type: &str,
    slowd_period: usize,
    slowd_ma_type: &str,
) -> Result<JsValue, JsValue> {
    let input = StochInput::from_slices(
        high,
        low,
        close,
        StochParams {
            fastk_period: Some(fastk_period),
            slowk_period: Some(slowk_period),
            slowk_ma_type: Some(slowk_ma_type.to_string()),
            slowd_period: Some(slowd_period),
            slowd_ma_type: Some(slowd_ma_type.to_string()),
        },
    );

    let computed = match stoch_with_kernel(&input, detect_best_kernel()) {
        Ok(computed) => computed,
        Err(err) => return Err(JsValue::from_str(&err.to_string())),
    };

    // Pack K first, then append D, giving a 2 x len row-major layout.
    let mut packed = computed.k;
    packed.extend_from_slice(&computed.d);

    let result = StochResult {
        values: packed,
        rows: 2,
        cols: high.len(),
    };
    serde_wasm_bindgen::to_value(&result)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2447
/// Serialized result of the `stoch_batch` wasm export.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct StochBatchJsOutput {
    /// Flattened output; `stoch_batch_unified_js` stores the whole K matrix
    /// followed by the whole D matrix.
    pub values: Vec<f64>,
    /// Parameter combinations that were evaluated.
    pub combos: Vec<StochParams>,
    /// Set to 2 by `stoch_batch_unified_js` (one K series and one D series).
    pub rows_per_combo: usize,
    /// Column count copied from the batch output.
    pub cols: usize,
}
2456
/// wasm export `stoch_batch`: runs a parameter sweep and returns a serialized
/// `StochBatchJsOutput`.
///
/// `values` holds the whole K matrix followed by the whole D matrix. Each
/// `(start, end, step)` triple is forwarded into `StochBatchRange`; each
/// moving-average name fills both string slots of its range tuple.
///
/// Computation and serialization failures are returned as string `JsValue`s.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = stoch_batch)]
pub fn stoch_batch_unified_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    fastk_start: usize,
    fastk_end: usize,
    fastk_step: usize,
    slowk_start: usize,
    slowk_end: usize,
    slowk_step: usize,
    slowk_ma_type: &str,
    slowd_start: usize,
    slowd_end: usize,
    slowd_step: usize,
    slowd_ma_type: &str,
) -> Result<JsValue, JsValue> {
    let sweep = StochBatchRange {
        fastk_period: (fastk_start, fastk_end, fastk_step),
        slowk_period: (slowk_start, slowk_end, slowk_step),
        slowk_ma_type: (slowk_ma_type.to_string(), slowk_ma_type.to_string(), 0.0),
        slowd_period: (slowd_start, slowd_end, slowd_step),
        slowd_ma_type: (slowd_ma_type.to_string(), slowd_ma_type.to_string(), 0.0),
    };
    let out = stoch_batch_inner(high, low, close, &sweep, detect_best_kernel(), false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    // `out` is owned, so reuse the K buffer instead of cloning the whole matrix.
    let mut values = out.k;
    values.extend_from_slice(&out.d);

    let js = StochBatchJsOutput {
        values,
        combos: out.combos,
        rows_per_combo: 2,
        cols: out.cols,
    };
    serde_wasm_bindgen::to_value(&js)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2495
/// Reserves room for `len` `f64` values and returns the raw pointer to it.
///
/// Ownership of the allocation is released to the caller; the contents are
/// uninitialised. The matching release routine is `stoch_free` with the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn stoch_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` suppresses the destructor so the allocation is not freed here.
    let mut storage = core::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    storage.as_mut_ptr()
}
2504
/// Releases a buffer previously obtained from `stoch_alloc(len)`.
///
/// A null pointer or a zero `len` is a no-op. `ptr` and `len` must otherwise be
/// exactly the values from the matching `stoch_alloc` call, and the buffer must
/// not be used afterwards.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn stoch_free(ptr: *mut f64, len: usize) {
    // Rebuilding a `Vec` from a null pointer is undefined behaviour, and a
    // zero-capacity `Vec` owns no allocation, so there is nothing to release.
    if ptr.is_null() || len == 0 {
        return;
    }
    // SAFETY: the caller guarantees `ptr` came from `stoch_alloc(len)`, i.e. a
    // `Vec<f64>` allocation requested with capacity `len`. The length is 0
    // because the elements were never initialised by Rust; `f64` has no
    // destructor, so only the allocation itself needs to be released.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
2512
/// wasm export `stoch_into`: computes the oscillator from raw input buffers and
/// writes K and D into caller-provided output buffers of `len` elements each.
///
/// An output buffer may overlap an input buffer (in-place use); the results are
/// then staged in temporary vectors and copied out once the inputs are no
/// longer read. The two output buffers must not overlap each other.
///
/// # Errors
/// Returns a string `JsValue` when any pointer is null, when the output buffers
/// overlap, or when the computation fails.
///
/// # Safety contract for callers
/// Every pointer must reference `len` readable (inputs) or writable (outputs)
/// `f64` values for the duration of the call.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = stoch_into)]
pub fn stoch_into_js(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    len: usize,
    fastk_period: usize,
    slowk_period: usize,
    slowk_ma_type: &str,
    slowd_period: usize,
    slowd_ma_type: &str,
    out_k_ptr: *mut f64,
    out_d_ptr: *mut f64,
) -> Result<(), JsValue> {
    if [high_ptr, low_ptr, close_ptr, out_k_ptr, out_d_ptr]
        .iter()
        .any(|p| p.is_null())
    {
        return Err(JsValue::from_str("null pointer"));
    }

    // Byte-range overlap test on addresses; saturating math keeps an absurd
    // `len` from wrapping. With `len == 0` nothing overlaps.
    let span = len.saturating_mul(core::mem::size_of::<f64>());
    let overlaps = |a: *const f64, b: *const f64| -> bool {
        let (a, b) = (a as usize, b as usize);
        a < b.saturating_add(span) && b < a.saturating_add(span)
    };

    let out_k_addr = out_k_ptr as *const f64;
    let out_d_addr = out_d_ptr as *const f64;

    // Two `&mut` slices over shared memory would be undefined behaviour, and
    // the result would be meaningless anyway.
    if overlaps(out_k_addr, out_d_addr) {
        return Err(JsValue::from_str("output buffers must not overlap"));
    }
    let aliases_input = [high_ptr, low_ptr, close_ptr]
        .iter()
        .any(|&src| overlaps(src, out_k_addr) || overlaps(src, out_d_addr));

    let params = StochParams {
        fastk_period: Some(fastk_period),
        slowk_period: Some(slowk_period),
        slowk_ma_type: Some(slowk_ma_type.to_string()),
        slowd_period: Some(slowd_period),
        slowd_ma_type: Some(slowd_ma_type.to_string()),
    };

    if aliases_input {
        // Stage the results so no `&mut` view of an output coexists with a
        // shared view of an input covering the same memory.
        let mut k_tmp = vec![f64::NAN; len];
        let mut d_tmp = vec![f64::NAN; len];
        {
            // SAFETY: pointers are non-null and the caller guarantees `len`
            // readable elements behind each input pointer. The shared slices
            // end with this scope, before any write to the outputs.
            let (hi, lo, cl) = unsafe {
                (
                    core::slice::from_raw_parts(high_ptr, len),
                    core::slice::from_raw_parts(low_ptr, len),
                    core::slice::from_raw_parts(close_ptr, len),
                )
            };
            let input = StochInput::from_slices(hi, lo, cl, params);
            stoch_into_slices(
                k_tmp.as_mut_slice(),
                d_tmp.as_mut_slice(),
                &input,
                detect_best_kernel(),
            )
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
        // SAFETY: the temporaries are fresh allocations, so they cannot overlap
        // the destinations; each destination holds `len` writable elements per
        // the caller contract.
        unsafe {
            core::ptr::copy_nonoverlapping(k_tmp.as_ptr(), out_k_ptr, len);
            core::ptr::copy_nonoverlapping(d_tmp.as_ptr(), out_d_ptr, len);
        }
        Ok(())
    } else {
        // SAFETY: pointers are non-null, each spans `len` elements per the
        // caller contract, and the overlap checks above establish that the
        // outputs are disjoint from each other and from every input.
        unsafe {
            let hi = core::slice::from_raw_parts(high_ptr, len);
            let lo = core::slice::from_raw_parts(low_ptr, len);
            let cl = core::slice::from_raw_parts(close_ptr, len);
            let out_k = core::slice::from_raw_parts_mut(out_k_ptr, len);
            let out_d = core::slice::from_raw_parts_mut(out_d_ptr, len);
            let input = StochInput::from_slices(hi, lo, cl, params);
            stoch_into_slices(out_k, out_d, &input, detect_best_kernel())
                .map_err(|e| JsValue::from_str(&e.to_string()))
        }
    }
}
2552
2553#[cfg(test)]
2554mod tests {
2555    use super::*;
2556    use crate::skip_if_unsupported;
2557    use crate::utilities::data_loader::read_candles_from_csv;
2558    use crate::utilities::enums::Kernel;
2559
2560    fn check_stoch_partial_params(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2561        skip_if_unsupported!(kernel, test);
2562        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2563        let candles = read_candles_from_csv(file_path)?;
2564        let default_params = StochParams::default();
2565        let input = StochInput::from_candles(&candles, default_params);
2566        let output = stoch_with_kernel(&input, kernel)?;
2567        assert_eq!(output.k.len(), candles.close.len());
2568        assert_eq!(output.d.len(), candles.close.len());
2569        Ok(())
2570    }
2571    fn check_stoch_accuracy(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2572        skip_if_unsupported!(kernel, test);
2573        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2574        let candles = read_candles_from_csv(file_path)?;
2575        let input = StochInput::from_candles(&candles, StochParams::default());
2576        let result = stoch_with_kernel(&input, kernel)?;
2577        assert_eq!(result.k.len(), candles.close.len());
2578        assert_eq!(result.d.len(), candles.close.len());
2579        let last_five_k = [
2580            42.51122827572717,
2581            40.13864479593807,
2582            37.853934778363374,
2583            37.337021714266086,
2584            36.26053890551548,
2585        ];
2586        let last_five_d = [
2587            41.36561869426493,
2588            41.7691857059163,
2589            40.16793595000925,
2590            38.44320042952222,
2591            37.15049846604803,
2592        ];
2593        let k_slice = &result.k[result.k.len() - 5..];
2594        let d_slice = &result.d[result.d.len() - 5..];
2595        for i in 0..5 {
2596            assert!(
2597                (k_slice[i] - last_five_k[i]).abs() < 1e-6,
2598                "Mismatch in K at {}: got {}, expected {}",
2599                i,
2600                k_slice[i],
2601                last_five_k[i]
2602            );
2603            assert!(
2604                (d_slice[i] - last_five_d[i]).abs() < 1e-6,
2605                "Mismatch in D at {}: got {}, expected {}",
2606                i,
2607                d_slice[i],
2608                last_five_d[i]
2609            );
2610        }
2611        Ok(())
2612    }
2613    fn check_stoch_default_candles(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2614        skip_if_unsupported!(kernel, test);
2615        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2616        let candles = read_candles_from_csv(file_path)?;
2617        let input = StochInput::with_default_candles(&candles);
2618        let output = stoch_with_kernel(&input, kernel)?;
2619        assert_eq!(output.k.len(), candles.close.len());
2620        assert_eq!(output.d.len(), candles.close.len());
2621        Ok(())
2622    }
2623    fn check_stoch_zero_period(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2624        skip_if_unsupported!(kernel, test);
2625        let high = [10.0, 11.0, 12.0];
2626        let low = [9.0, 9.5, 10.5];
2627        let close = [9.5, 10.6, 11.5];
2628        let params = StochParams {
2629            fastk_period: Some(0),
2630            ..Default::default()
2631        };
2632        let input = StochInput::from_slices(&high, &low, &close, params);
2633        let result = stoch_with_kernel(&input, kernel);
2634        assert!(result.is_err());
2635        Ok(())
2636    }
2637    fn check_stoch_period_exceeds_length(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2638        skip_if_unsupported!(kernel, test);
2639        let high = [10.0, 11.0, 12.0];
2640        let low = [9.0, 9.5, 10.5];
2641        let close = [9.5, 10.6, 11.5];
2642        let params = StochParams {
2643            fastk_period: Some(10),
2644            ..Default::default()
2645        };
2646        let input = StochInput::from_slices(&high, &low, &close, params);
2647        let result = stoch_with_kernel(&input, kernel);
2648        assert!(result.is_err());
2649        Ok(())
2650    }
2651    fn check_stoch_all_nan(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2652        skip_if_unsupported!(kernel, test);
2653        let nan_data = [f64::NAN, f64::NAN, f64::NAN];
2654        let params = StochParams::default();
2655        let input = StochInput::from_slices(&nan_data, &nan_data, &nan_data, params);
2656        let result = stoch_with_kernel(&input, kernel);
2657        assert!(result.is_err());
2658        Ok(())
2659    }
2660
2661    #[cfg(debug_assertions)]
2662    fn check_stoch_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2663        skip_if_unsupported!(kernel, test_name);
2664
2665        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2666        let candles = read_candles_from_csv(file_path)?;
2667
2668        let test_params = vec![
2669            StochParams::default(),
2670            StochParams {
2671                fastk_period: Some(2),
2672                slowk_period: Some(1),
2673                slowd_period: Some(1),
2674                slowk_ma_type: Some("sma".to_string()),
2675                slowd_ma_type: Some("sma".to_string()),
2676            },
2677            StochParams {
2678                fastk_period: Some(5),
2679                slowk_period: Some(2),
2680                slowd_period: Some(2),
2681                slowk_ma_type: Some("sma".to_string()),
2682                slowd_ma_type: Some("sma".to_string()),
2683            },
2684            StochParams {
2685                fastk_period: Some(10),
2686                slowk_period: Some(5),
2687                slowd_period: Some(3),
2688                slowk_ma_type: Some("ema".to_string()),
2689                slowd_ma_type: Some("ema".to_string()),
2690            },
2691            StochParams {
2692                fastk_period: Some(14),
2693                slowk_period: Some(5),
2694                slowd_period: Some(5),
2695                slowk_ma_type: Some("sma".to_string()),
2696                slowd_ma_type: Some("ema".to_string()),
2697            },
2698            StochParams {
2699                fastk_period: Some(20),
2700                slowk_period: Some(3),
2701                slowd_period: Some(3),
2702                slowk_ma_type: Some("sma".to_string()),
2703                slowd_ma_type: Some("sma".to_string()),
2704            },
2705            StochParams {
2706                fastk_period: Some(50),
2707                slowk_period: Some(10),
2708                slowd_period: Some(10),
2709                slowk_ma_type: Some("ema".to_string()),
2710                slowd_ma_type: Some("sma".to_string()),
2711            },
2712            StochParams {
2713                fastk_period: Some(100),
2714                slowk_period: Some(20),
2715                slowd_period: Some(15),
2716                slowk_ma_type: Some("sma".to_string()),
2717                slowd_ma_type: Some("sma".to_string()),
2718            },
2719            StochParams {
2720                fastk_period: Some(7),
2721                slowk_period: Some(1),
2722                slowd_period: Some(7),
2723                slowk_ma_type: Some("sma".to_string()),
2724                slowd_ma_type: Some("ema".to_string()),
2725            },
2726            StochParams {
2727                fastk_period: Some(3),
2728                slowk_period: Some(3),
2729                slowd_period: Some(1),
2730                slowk_ma_type: Some("ema".to_string()),
2731                slowd_ma_type: Some("sma".to_string()),
2732            },
2733        ];
2734
2735        for (param_idx, params) in test_params.iter().enumerate() {
2736            let input = StochInput::from_candles(&candles, params.clone());
2737            let output = stoch_with_kernel(&input, kernel)?;
2738
2739            for (i, &val) in output.k.iter().enumerate() {
2740                if val.is_nan() {
2741                    continue;
2742                }
2743
2744                let bits = val.to_bits();
2745
2746                if bits == 0x11111111_11111111 {
2747                    panic!(
2748						"[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} in K values \
2749						 with params: fastk_period={}, slowk_period={}, slowd_period={}, \
2750						 slowk_ma_type={}, slowd_ma_type={} (param set {})",
2751						test_name, val, bits, i,
2752						params.fastk_period.unwrap_or(14),
2753						params.slowk_period.unwrap_or(3),
2754						params.slowd_period.unwrap_or(3),
2755						params.slowk_ma_type.as_deref().unwrap_or("sma"),
2756						params.slowd_ma_type.as_deref().unwrap_or("sma"),
2757						param_idx
2758					);
2759                }
2760
2761                if bits == 0x22222222_22222222 {
2762                    panic!(
2763						"[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} in K values \
2764						 with params: fastk_period={}, slowk_period={}, slowd_period={}, \
2765						 slowk_ma_type={}, slowd_ma_type={} (param set {})",
2766						test_name, val, bits, i,
2767						params.fastk_period.unwrap_or(14),
2768						params.slowk_period.unwrap_or(3),
2769						params.slowd_period.unwrap_or(3),
2770						params.slowk_ma_type.as_deref().unwrap_or("sma"),
2771						params.slowd_ma_type.as_deref().unwrap_or("sma"),
2772						param_idx
2773					);
2774                }
2775
2776                if bits == 0x33333333_33333333 {
2777                    panic!(
2778						"[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} in K values \
2779						 with params: fastk_period={}, slowk_period={}, slowd_period={}, \
2780						 slowk_ma_type={}, slowd_ma_type={} (param set {})",
2781						test_name, val, bits, i,
2782						params.fastk_period.unwrap_or(14),
2783						params.slowk_period.unwrap_or(3),
2784						params.slowd_period.unwrap_or(3),
2785						params.slowk_ma_type.as_deref().unwrap_or("sma"),
2786						params.slowd_ma_type.as_deref().unwrap_or("sma"),
2787						param_idx
2788					);
2789                }
2790            }
2791
2792            for (i, &val) in output.d.iter().enumerate() {
2793                if val.is_nan() {
2794                    continue;
2795                }
2796
2797                let bits = val.to_bits();
2798
2799                if bits == 0x11111111_11111111 {
2800                    panic!(
2801						"[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} in D values \
2802						 with params: fastk_period={}, slowk_period={}, slowd_period={}, \
2803						 slowk_ma_type={}, slowd_ma_type={} (param set {})",
2804						test_name, val, bits, i,
2805						params.fastk_period.unwrap_or(14),
2806						params.slowk_period.unwrap_or(3),
2807						params.slowd_period.unwrap_or(3),
2808						params.slowk_ma_type.as_deref().unwrap_or("sma"),
2809						params.slowd_ma_type.as_deref().unwrap_or("sma"),
2810						param_idx
2811					);
2812                }
2813
2814                if bits == 0x22222222_22222222 {
2815                    panic!(
2816						"[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} in D values \
2817						 with params: fastk_period={}, slowk_period={}, slowd_period={}, \
2818						 slowk_ma_type={}, slowd_ma_type={} (param set {})",
2819						test_name, val, bits, i,
2820						params.fastk_period.unwrap_or(14),
2821						params.slowk_period.unwrap_or(3),
2822						params.slowd_period.unwrap_or(3),
2823						params.slowk_ma_type.as_deref().unwrap_or("sma"),
2824						params.slowd_ma_type.as_deref().unwrap_or("sma"),
2825						param_idx
2826					);
2827                }
2828
2829                if bits == 0x33333333_33333333 {
2830                    panic!(
2831						"[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} in D values \
2832						 with params: fastk_period={}, slowk_period={}, slowd_period={}, \
2833						 slowk_ma_type={}, slowd_ma_type={} (param set {})",
2834						test_name, val, bits, i,
2835						params.fastk_period.unwrap_or(14),
2836						params.slowk_period.unwrap_or(3),
2837						params.slowd_period.unwrap_or(3),
2838						params.slowk_ma_type.as_deref().unwrap_or("sma"),
2839						params.slowd_ma_type.as_deref().unwrap_or("sma"),
2840						param_idx
2841					);
2842                }
2843            }
2844        }
2845
2846        Ok(())
2847    }
2848
    /// Stand-in used when debug assertions are disabled: performs no checks and
    /// always reports success, so the generated test wrappers still have a
    /// function of this name to call.
    #[cfg(not(debug_assertions))]
    fn check_stoch_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
2853
    /// Property-based consistency check for `stoch_with_kernel`.
    ///
    /// Generates randomized high/low/close series (optionally trending, or
    /// completely flat), runs the indicator with `kernel`, and asserts a series of
    /// invariants: output lengths, a NaN warmup prefix, a [0, 100] value range,
    /// agreement with the `Kernel::Scalar` result, flat-market and extreme-close
    /// behaviour, D being no more variable than K, and SMA/EMA results diverging.
    ///
    /// Returns `Err` when the proptest runner reports a failure (propagated by the
    /// trailing `?` on `run`).
    #[cfg(feature = "proptest")]
    // NOTE(review): no direct float `==` comparison is visible in this body;
    // confirm whether this allow is still needed.
    #[allow(clippy::float_cmp)]
    fn check_stoch_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        // Stage 1: pick `fastk_period` first so the series length can depend on it.
        let strat = (2usize..=50)
            .prop_flat_map(|fastk_period| {
                (
                    // One (base price, relative volatility) pair per bar; at least
                    // `max(fastk_period, 10)` bars and fewer than 400.
                    prop::collection::vec(
                        (1.0f64..1000.0f64, 0.001f64..0.1f64),
                        fastk_period.max(10)..400,
                    ),
                    Just(fastk_period),
                    // slowk_period
                    1usize..=10,
                    // slowd_period
                    1usize..=10,
                    // use_ema: selects "ema" vs "sma" for both smoothing stages
                    prop::bool::ANY,
                    // trend: per-bar multiplicative drift
                    -0.01f64..0.01f64,
                    // is_flat: when set, every bar repeats a single price
                    prop::bool::ANY,
                )
            })
            // Stage 2: now that the series length is known, draw two per-bar
            // vectors of exactly that length.
            .prop_flat_map(
                |(
                    price_vol_pairs,
                    fastk_period,
                    slowk_period,
                    slowd_period,
                    use_ema,
                    trend,
                    is_flat,
                )| {
                    let len = price_vol_pairs.len();
                    (
                        Just((
                            price_vol_pairs,
                            fastk_period,
                            slowk_period,
                            slowd_period,
                            use_ema,
                            trend,
                            is_flat,
                        )),
                        prop::collection::vec(-1.0f64..1.0f64, len),
                        prop::collection::vec(0.0f64..1.0f64, len),
                    )
                },
            )
            // Stage 3: turn the raw draws into high/low/close vectors.
            .prop_map(
                |(
                    (
                        price_vol_pairs,
                        fastk_period,
                        slowk_period,
                        slowd_period,
                        use_ema,
                        trend,
                        is_flat,
                    ),
                    close_factors,
                    beta_params,
                )| {
                    let mut high = Vec::with_capacity(price_vol_pairs.len());
                    let mut low = Vec::with_capacity(price_vol_pairs.len());
                    let mut close = Vec::with_capacity(price_vol_pairs.len());

                    // Compounded drift factor, updated once per bar.
                    let mut cumulative_trend = 1.0;

                    for (i, ((base_price, volatility), (close_factor, beta))) in price_vol_pairs
                        .into_iter()
                        .zip(close_factors.into_iter().zip(beta_params))
                        .enumerate()
                    {
                        cumulative_trend *= 1.0 + trend;
                        let trended_price = base_price * cumulative_trend;

                        if is_flat {
                            // Flat series: every bar copies the first bar's base
                            // price into high, low and close (trend is ignored).
                            let flat_price = if i == 0 { base_price } else { high[0] };
                            high.push(flat_price);
                            low.push(flat_price);
                            close.push(flat_price);
                        } else {
                            let spread = trended_price * volatility;
                            let h = trended_price + spread;
                            // Floor the low at 0.01.
                            let l = (trended_price - spread).max(0.01);

                            // Piecewise-quadratic remap of `beta`; both branches
                            // evaluate to 0.5 at beta = 0.5.
                            let beta_factor = if beta < 0.5 {
                                2.0 * beta * beta
                            } else {
                                1.0 - 2.0 * (1.0 - beta) * (1.0 - beta)
                            };

                            // Blend the two draws equally, then place the close
                            // inside the bar's [low, high] range.
                            let close_position = close_factor * 0.5 + beta_factor * 0.5;
                            let c = l + (h - l) * ((close_position + 1.0) / 2.0);

                            high.push(h);
                            low.push(l);
                            close.push(c.clamp(l, h));
                        }
                    }

                    let ma_type = if use_ema { "ema" } else { "sma" };

                    (
                        high,
                        low,
                        close,
                        fastk_period,
                        slowk_period,
                        slowd_period,
                        ma_type.to_string(),
                        is_flat,
                    )
                },
            );

        proptest::test_runner::TestRunner::default().run(
            &strat,
            |(high, low, close, fastk_period, slowk_period, slowd_period, ma_type, is_flat)| {
                // The same MA type is used for both the slow-K and slow-D stages.
                let params = StochParams {
                    fastk_period: Some(fastk_period),
                    slowk_period: Some(slowk_period),
                    slowk_ma_type: Some(ma_type.clone()),
                    slowd_period: Some(slowd_period),
                    slowd_ma_type: Some(ma_type.clone()),
                };

                let input = StochInput::from_slices(&high, &low, &close, params.clone());

                let result = stoch_with_kernel(&input, kernel)?;

                // Scalar run of the same input, used as the reference below.
                let ref_result = stoch_with_kernel(&input, Kernel::Scalar)?;

                prop_assert_eq!(result.k.len(), high.len());
                prop_assert_eq!(result.d.len(), high.len());

                // Warmup lengths as modelled by this test: the fast stage needs
                // `fastk_period - 1` bars; an SMA stage adds `period - 1`, while
                // an EMA stage is treated as adding none.
                let warmup_k = fastk_period - 1;
                let warmup_slowk = if ma_type == "ema" {
                    0
                } else {
                    slowk_period - 1
                };
                let warmup_slowd = if ma_type == "ema" {
                    0
                } else {
                    slowd_period - 1
                };
                // All terms are unsigned, so the chained `max` evaluates to the
                // last (largest) sum.
                let expected_warmup = warmup_k
                    .max(warmup_k + warmup_slowk)
                    .max(warmup_k + warmup_slowk + warmup_slowd);

                // Only the first `warmup_k` bars are required to be NaN.
                for i in 0..warmup_k.min(high.len()) {
                    prop_assert!(
                        result.k[i].is_nan(),
                        "K[{}] should be NaN during initial warmup but was {}",
                        i,
                        result.k[i]
                    );
                    prop_assert!(
                        result.d[i].is_nan(),
                        "D[{}] should be NaN during initial warmup but was {}",
                        i,
                        result.d[i]
                    );
                }

                // Per-bar checks on everything past the full expected warmup.
                for i in expected_warmup..high.len() {
                    let k_val = result.k[i];
                    let d_val = result.d[i];
                    let ref_k = ref_result.k[i];
                    let ref_d = ref_result.d[i];

                    // Range check with a 1e-9 tolerance on both ends.
                    if !k_val.is_nan() {
                        prop_assert!(
                            k_val >= -1e-9 && k_val <= 100.0 + 1e-9,
                            "K[{}] = {} is outside [0, 100] range",
                            i,
                            k_val
                        );
                    }

                    if !d_val.is_nan() {
                        prop_assert!(
                            d_val >= -1e-9 && d_val <= 100.0 + 1e-9,
                            "D[{}] = {} is outside [0, 100] range",
                            i,
                            d_val
                        );
                    }

                    // Kernel-vs-scalar agreement: pass on either an absolute
                    // difference of at most 1e-9 or a bit-pattern distance of at
                    // most 4.
                    if k_val.is_finite() && ref_k.is_finite() {
                        let k_diff = (k_val - ref_k).abs();
                        let k_ulp_diff = k_val.to_bits().abs_diff(ref_k.to_bits());
                        prop_assert!(
                            k_diff <= 1e-9 || k_ulp_diff <= 4,
                            "K mismatch at [{}]: {} vs {} (diff={}, ULP={})",
                            i,
                            k_val,
                            ref_k,
                            k_diff,
                            k_ulp_diff
                        );
                    }

                    if d_val.is_finite() && ref_d.is_finite() {
                        let d_diff = (d_val - ref_d).abs();
                        let d_ulp_diff = d_val.to_bits().abs_diff(ref_d.to_bits());
                        prop_assert!(
                            d_diff <= 1e-9 || d_ulp_diff <= 4,
                            "D mismatch at [{}]: {} vs {} (diff={}, ULP={})",
                            i,
                            d_val,
                            ref_d,
                            d_diff,
                            d_ulp_diff
                        );
                    }

                    if i >= fastk_period - 1 && !k_val.is_nan() {
                        // Recompute the highest high / lowest low over the last
                        // `fastk_period` bars ending at `i`.
                        let window_start = i + 1 - fastk_period;
                        let window_high = &high[window_start..=i];
                        let window_low = &low[window_start..=i];

                        let max_h = window_high
                            .iter()
                            .cloned()
                            .fold(f64::NEG_INFINITY, f64::max);
                        let min_l = window_low.iter().cloned().fold(f64::INFINITY, f64::min);

                        // A zero-width window is expected to produce K = 50.
                        if is_flat || (max_h - min_l).abs() < f64::EPSILON {
                            prop_assert!(
                                (k_val - 50.0).abs() < 1e-6,
                                "K[{}] = {} should be 50 in flat market",
                                i,
                                k_val
                            );
                        } else {
                            // Close at the window extreme: the bound is tight
                            // (99 / 1) when slowk_period == 1 and looser
                            // (85 / 15) for longer slow-K periods.
                            if (close[i] - max_h).abs() < 1e-10 {
                                let expected_min = if slowk_period == 1 { 99.0 } else { 85.0 };
                                prop_assert!(
									k_val >= expected_min,
									"K[{}] = {} should be >= {} when close equals highest high (slowk_period={})",
									i, k_val, expected_min, slowk_period
								);
                            }

                            if (close[i] - min_l).abs() < 1e-10 {
                                let expected_max = if slowk_period == 1 { 1.0 } else { 15.0 };
                                prop_assert!(
									k_val <= expected_max,
									"K[{}] = {} should be <= {} when close equals lowest low (slowk_period={})",
									i, k_val, expected_max, slowk_period
								);
                            }
                        }
                    }
                }

                // Whole-series statistics over the finite outputs only.
                let k_valid: Vec<f64> =
                    result.k.iter().filter(|x| x.is_finite()).copied().collect();
                let d_valid: Vec<f64> =
                    result.d.iter().filter(|x| x.is_finite()).copied().collect();

                if k_valid.len() > 10 && d_valid.len() > 10 && !is_flat {
                    let k_mean = k_valid.iter().sum::<f64>() / k_valid.len() as f64;
                    let d_mean = d_valid.iter().sum::<f64>() / d_valid.len() as f64;

                    // Population variance (divides by the sample count).
                    let k_var = k_valid.iter().map(|x| (x - k_mean).powi(2)).sum::<f64>()
                        / k_valid.len() as f64;
                    let d_var = d_valid.iter().map(|x| (x - d_mean).powi(2)).sum::<f64>()
                        / d_valid.len() as f64;

                    // D's variance may exceed K's by at most 1%.
                    if slowd_period > 1 && k_var > 1e-6 {
                        prop_assert!(
							d_var <= k_var * 1.01,
							"D variance {} should be <= K variance {} (smoothing effect with slowd_period={})",
							d_var, k_var, slowd_period
						);
                    }

                    // With a slow-D period of 1, D is expected to equal K.
                    if slowd_period == 1 {
                        for i in expected_warmup..result.k.len() {
                            if result.k[i].is_finite() && result.d[i].is_finite() {
                                prop_assert!(
                                    (result.k[i] - result.d[i]).abs() < 1e-9,
                                    "When slowd_period=1, D[{}]={} should equal K[{}]={}",
                                    i,
                                    result.d[i],
                                    i,
                                    result.k[i]
                                );
                            }
                        }
                    }
                }

                // Re-run with the other MA type and require the K outputs to
                // differ on most bars.
                if !is_flat && high.len() > fastk_period + 10 {
                    let opposite_ma_type = if ma_type == "sma" { "ema" } else { "sma" };
                    let opposite_params = StochParams {
                        fastk_period: Some(fastk_period),
                        slowk_period: Some(slowk_period),
                        slowk_ma_type: Some(opposite_ma_type.to_string()),
                        slowd_period: Some(slowd_period),
                        slowd_ma_type: Some(opposite_ma_type.to_string()),
                    };

                    let opposite_input =
                        StochInput::from_slices(&high, &low, &close, opposite_params);
                    let opposite_result = stoch_with_kernel(&opposite_input, kernel)?;

                    let mut diff_count = 0;
                    let mut total_valid = 0;
                    for i in expected_warmup..result.k.len() {
                        if result.k[i].is_finite() && opposite_result.k[i].is_finite() {
                            total_valid += 1;
                            if (result.k[i] - opposite_result.k[i]).abs() > 1e-6 {
                                diff_count += 1;
                            }
                        }
                    }

                    // Only enforced when slow-K actually smooths (period > 1) and
                    // more than 10 bars are comparable; at least 80% must differ.
                    if total_valid > 10 && slowk_period > 1 {
                        let diff_ratio = diff_count as f64 / total_valid as f64;
                        prop_assert!(
							diff_ratio >= 0.8,
							"SMA and EMA should produce different results: only {}/{} values differ ({}%)",
							diff_count, total_valid, (diff_ratio * 100.0) as i32
						);
                    }
                }

                Ok(())
            },
        )?;

        Ok(())
    }
3193
3194    macro_rules! generate_all_stoch_tests {
3195        ($($test_fn:ident),*) => {
3196            paste::paste! {
3197                $( #[test] fn [<$test_fn _scalar_f64>]() { let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar); } )*
3198                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3199                $( #[test] fn [<$test_fn _avx2_f64>]() { let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2); } )*
3200                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3201                $( #[test] fn [<$test_fn _avx512_f64>]() { let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512); } )*
3202            }
3203        }
3204    }
    // Instantiate the per-kernel #[test] wrappers for every single-series check.
    generate_all_stoch_tests!(
        check_stoch_partial_params,
        check_stoch_accuracy,
        check_stoch_default_candles,
        check_stoch_zero_period,
        check_stoch_period_exceeds_length,
        check_stoch_all_nan,
        check_stoch_no_poison
    );

    // The property-based check only exists when the `proptest` feature is on,
    // so its wrappers are generated under the same gate.
    #[cfg(feature = "proptest")]
    generate_all_stoch_tests!(check_stoch_property);
3217    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
3218        skip_if_unsupported!(kernel, test);
3219
3220        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3221        let c = read_candles_from_csv(file)?;
3222
3223        let output = StochBatchBuilder::new().kernel(kernel).apply_candles(&c)?;
3224
3225        let def = StochParams::default();
3226        let (row_k, row_d) = output.values_for(&def).expect("default row missing");
3227
3228        assert_eq!(row_k.len(), c.close.len());
3229        assert_eq!(row_d.len(), c.close.len());
3230
3231        let expected_k = [
3232            42.51122827572717,
3233            40.13864479593807,
3234            37.853934778363374,
3235            37.337021714266086,
3236            36.26053890551548,
3237        ];
3238        let expected_d = [
3239            41.36561869426493,
3240            41.7691857059163,
3241            40.16793595000925,
3242            38.44320042952222,
3243            37.15049846604803,
3244        ];
3245        let start = row_k.len() - 5;
3246        for (i, &v) in row_k[start..].iter().enumerate() {
3247            assert!(
3248                (v - expected_k[i]).abs() < 1e-6,
3249                "[{test}] default-row K mismatch at idx {i}: {v} vs {expected_k:?}"
3250            );
3251        }
3252        for (i, &v) in row_d[start..].iter().enumerate() {
3253            assert!(
3254                (v - expected_d[i]).abs() < 1e-6,
3255                "[{test}] default-row D mismatch at idx {i}: {v} vs {expected_d:?}"
3256            );
3257        }
3258        Ok(())
3259    }
3260
3261    #[cfg(debug_assertions)]
3262    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
3263        skip_if_unsupported!(kernel, test);
3264
3265        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3266        let c = read_candles_from_csv(file)?;
3267
3268        let test_configs = vec![
3269            (2, 10, 2, 1, 5, 1, 1, 5, 1),
3270            (5, 25, 5, 2, 10, 2, 2, 10, 2),
3271            (10, 30, 10, 3, 9, 3, 3, 9, 3),
3272            (14, 14, 0, 1, 5, 1, 1, 5, 1),
3273            (2, 5, 1, 3, 3, 0, 3, 3, 0),
3274            (20, 50, 15, 5, 15, 5, 5, 15, 5),
3275            (7, 21, 7, 2, 6, 2, 2, 6, 2),
3276            (3, 12, 3, 1, 3, 1, 1, 3, 1),
3277        ];
3278
3279        for (
3280            cfg_idx,
3281            &(fk_start, fk_end, fk_step, sk_start, sk_end, sk_step, sd_start, sd_end, sd_step),
3282        ) in test_configs.iter().enumerate()
3283        {
3284            let output = StochBatchBuilder::new()
3285                .kernel(kernel)
3286                .fastk_period_range(fk_start, fk_end, fk_step)
3287                .slowk_period_range(sk_start, sk_end, sk_step)
3288                .slowd_period_range(sd_start, sd_end, sd_step)
3289                .slowk_ma_type_static("sma")
3290                .slowd_ma_type_static("sma")
3291                .apply_candles(&c)?;
3292
3293            for (idx, &val) in output.k.iter().enumerate() {
3294                if val.is_nan() {
3295                    continue;
3296                }
3297
3298                let bits = val.to_bits();
3299                let row = idx / output.cols;
3300                let col = idx % output.cols;
3301                let combo = &output.combos[row];
3302
3303                if bits == 0x11111111_11111111 {
3304                    panic!(
3305                        "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
3306						 at row {} col {} (flat index {}) in K values with params: \
3307						 fastk_period={}, slowk_period={}, slowd_period={}, \
3308						 slowk_ma_type={}, slowd_ma_type={}",
3309                        test,
3310                        cfg_idx,
3311                        val,
3312                        bits,
3313                        row,
3314                        col,
3315                        idx,
3316                        combo.fastk_period.unwrap_or(14),
3317                        combo.slowk_period.unwrap_or(3),
3318                        combo.slowd_period.unwrap_or(3),
3319                        combo.slowk_ma_type.as_deref().unwrap_or("sma"),
3320                        combo.slowd_ma_type.as_deref().unwrap_or("sma")
3321                    );
3322                }
3323
3324                if bits == 0x22222222_22222222 {
3325                    panic!(
3326                        "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
3327						 at row {} col {} (flat index {}) in K values with params: \
3328						 fastk_period={}, slowk_period={}, slowd_period={}, \
3329						 slowk_ma_type={}, slowd_ma_type={}",
3330                        test,
3331                        cfg_idx,
3332                        val,
3333                        bits,
3334                        row,
3335                        col,
3336                        idx,
3337                        combo.fastk_period.unwrap_or(14),
3338                        combo.slowk_period.unwrap_or(3),
3339                        combo.slowd_period.unwrap_or(3),
3340                        combo.slowk_ma_type.as_deref().unwrap_or("sma"),
3341                        combo.slowd_ma_type.as_deref().unwrap_or("sma")
3342                    );
3343                }
3344
3345                if bits == 0x33333333_33333333 {
3346                    panic!(
3347                        "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
3348						 at row {} col {} (flat index {}) in K values with params: \
3349						 fastk_period={}, slowk_period={}, slowd_period={}, \
3350						 slowk_ma_type={}, slowd_ma_type={}",
3351                        test,
3352                        cfg_idx,
3353                        val,
3354                        bits,
3355                        row,
3356                        col,
3357                        idx,
3358                        combo.fastk_period.unwrap_or(14),
3359                        combo.slowk_period.unwrap_or(3),
3360                        combo.slowd_period.unwrap_or(3),
3361                        combo.slowk_ma_type.as_deref().unwrap_or("sma"),
3362                        combo.slowd_ma_type.as_deref().unwrap_or("sma")
3363                    );
3364                }
3365            }
3366
3367            for (idx, &val) in output.d.iter().enumerate() {
3368                if val.is_nan() {
3369                    continue;
3370                }
3371
3372                let bits = val.to_bits();
3373                let row = idx / output.cols;
3374                let col = idx % output.cols;
3375                let combo = &output.combos[row];
3376
3377                if bits == 0x11111111_11111111 {
3378                    panic!(
3379                        "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
3380						 at row {} col {} (flat index {}) in D values with params: \
3381						 fastk_period={}, slowk_period={}, slowd_period={}, \
3382						 slowk_ma_type={}, slowd_ma_type={}",
3383                        test,
3384                        cfg_idx,
3385                        val,
3386                        bits,
3387                        row,
3388                        col,
3389                        idx,
3390                        combo.fastk_period.unwrap_or(14),
3391                        combo.slowk_period.unwrap_or(3),
3392                        combo.slowd_period.unwrap_or(3),
3393                        combo.slowk_ma_type.as_deref().unwrap_or("sma"),
3394                        combo.slowd_ma_type.as_deref().unwrap_or("sma")
3395                    );
3396                }
3397
3398                if bits == 0x22222222_22222222 {
3399                    panic!(
3400                        "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
3401						 at row {} col {} (flat index {}) in D values with params: \
3402						 fastk_period={}, slowk_period={}, slowd_period={}, \
3403						 slowk_ma_type={}, slowd_ma_type={}",
3404                        test,
3405                        cfg_idx,
3406                        val,
3407                        bits,
3408                        row,
3409                        col,
3410                        idx,
3411                        combo.fastk_period.unwrap_or(14),
3412                        combo.slowk_period.unwrap_or(3),
3413                        combo.slowd_period.unwrap_or(3),
3414                        combo.slowk_ma_type.as_deref().unwrap_or("sma"),
3415                        combo.slowd_ma_type.as_deref().unwrap_or("sma")
3416                    );
3417                }
3418
3419                if bits == 0x33333333_33333333 {
3420                    panic!(
3421                        "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
3422						 at row {} col {} (flat index {}) in D values with params: \
3423						 fastk_period={}, slowk_period={}, slowd_period={}, \
3424						 slowk_ma_type={}, slowd_ma_type={}",
3425                        test,
3426                        cfg_idx,
3427                        val,
3428                        bits,
3429                        row,
3430                        col,
3431                        idx,
3432                        combo.fastk_period.unwrap_or(14),
3433                        combo.slowk_period.unwrap_or(3),
3434                        combo.slowd_period.unwrap_or(3),
3435                        combo.slowk_ma_type.as_deref().unwrap_or("sma"),
3436                        combo.slowd_ma_type.as_deref().unwrap_or("sma")
3437                    );
3438                }
3439            }
3440        }
3441
3442        Ok(())
3443    }
3444
    /// Stand-in used when debug assertions are disabled: performs no scan and
    /// always reports success, so the generated batch test wrappers still have a
    /// function of this name to call.
    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
3449
3450    macro_rules! gen_batch_tests {
3451        ($fn_name:ident) => {
3452            paste::paste! {
3453                #[test] fn [<$fn_name _scalar>]()      {
3454                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
3455                }
3456                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3457                #[test] fn [<$fn_name _avx2>]()        {
3458                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
3459                }
3460                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3461                #[test] fn [<$fn_name _avx512>]()      {
3462                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
3463                }
3464                #[test] fn [<$fn_name _auto_detect>]() {
3465                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
3466                }
3467            }
3468        };
3469    }
3470
    // Instantiate the per-kernel #[test] wrappers for both batch checks.
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
3473
3474    fn eq_or_both_nan(a: f64, b: f64) -> bool {
3475        (a.is_nan() && b.is_nan()) || (a == b)
3476    }
3477
3478    #[test]
3479    fn test_stoch_into_matches_api() -> Result<(), Box<dyn Error>> {
3480        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3481        let candles = read_candles_from_csv(file_path)?;
3482        let input = StochInput::with_default_candles(&candles);
3483
3484        let baseline = stoch(&input)?;
3485
3486        let mut out_k = vec![0.0; baseline.k.len()];
3487        let mut out_d = vec![0.0; baseline.d.len()];
3488
3489        #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
3490        {
3491            stoch_into(&input, &mut out_k, &mut out_d)?;
3492        }
3493        #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
3494        {
3495            stoch_into_slices(&mut out_k, &mut out_d, &input, detect_best_kernel())?;
3496        }
3497
3498        assert_eq!(out_k.len(), baseline.k.len());
3499        assert_eq!(out_d.len(), baseline.d.len());
3500        for i in 0..out_k.len() {
3501            assert!(
3502                eq_or_both_nan(out_k[i], baseline.k[i]),
3503                "K mismatch at {}: got {}, expected {}",
3504                i,
3505                out_k[i],
3506                baseline.k[i]
3507            );
3508            assert!(
3509                eq_or_both_nan(out_d[i], baseline.d[i]),
3510                "D mismatch at {}: got {}, expected {}",
3511                i,
3512                out_d[i],
3513                baseline.d[i]
3514            );
3515        }
3516        Ok(())
3517    }
3518
3519    #[test]
3520    fn test_stoch_compute_into_matches_api() -> Result<(), Box<dyn Error>> {
3521        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3522        let candles = read_candles_from_csv(file_path)?;
3523        let input = StochInput::with_default_candles(&candles);
3524
3525        let baseline = stoch(&input)?;
3526
3527        let mut out_k = vec![0.0; baseline.k.len()];
3528        let mut out_d = vec![0.0; baseline.d.len()];
3529        stoch_compute_into(&input, &mut out_k, &mut out_d, Kernel::Auto)?;
3530
3531        assert_eq!(out_k.len(), baseline.k.len());
3532        assert_eq!(out_d.len(), baseline.d.len());
3533        for i in 0..out_k.len() {
3534            assert!(
3535                eq_or_both_nan(out_k[i], baseline.k[i]),
3536                "K mismatch at {}: got {}, expected {}",
3537                i,
3538                out_k[i],
3539                baseline.k[i]
3540            );
3541            assert!(
3542                eq_or_both_nan(out_d[i], baseline.d[i]),
3543                "D mismatch at {}: got {}, expected {}",
3544                i,
3545                out_d[i],
3546                baseline.d[i]
3547            );
3548        }
3549        Ok(())
3550    }
3551}
3552
3553#[inline]
3554fn stoch_classic_sma(
3555    k_raw: &[f64],
3556    slowk_period: usize,
3557    slowd_period: usize,
3558    first_valid_idx: usize,
3559) -> Result<StochOutput, StochError> {
3560    let len = k_raw.len();
3561    let mut k_vec = alloc_with_nan_prefix(len, first_valid_idx + slowk_period - 1);
3562    let mut d_vec = alloc_with_nan_prefix(len, first_valid_idx + slowk_period + slowd_period - 2);
3563
3564    let mut sum_k = 0.0;
3565    let k_start = first_valid_idx;
3566
3567    for i in k_start..(k_start + slowk_period).min(len) {
3568        if !k_raw[i].is_nan() {
3569            sum_k += k_raw[i];
3570        }
3571    }
3572    if k_start + slowk_period - 1 < len {
3573        k_vec[k_start + slowk_period - 1] = sum_k / slowk_period as f64;
3574    }
3575
3576    for i in (k_start + slowk_period)..len {
3577        let old_val = k_raw[i - slowk_period];
3578        let new_val = k_raw[i];
3579        if !old_val.is_nan() {
3580            sum_k -= old_val;
3581        }
3582        if !new_val.is_nan() {
3583            sum_k += new_val;
3584        }
3585        k_vec[i] = sum_k / slowk_period as f64;
3586    }
3587
3588    let mut sum_d = 0.0;
3589    let d_start = first_valid_idx + slowk_period - 1;
3590
3591    for i in d_start..(d_start + slowd_period).min(len) {
3592        if !k_vec[i].is_nan() {
3593            sum_d += k_vec[i];
3594        }
3595    }
3596    if d_start + slowd_period - 1 < len {
3597        d_vec[d_start + slowd_period - 1] = sum_d / slowd_period as f64;
3598    }
3599
3600    for i in (d_start + slowd_period)..len {
3601        let old_val = k_vec[i - slowd_period];
3602        let new_val = k_vec[i];
3603        if !old_val.is_nan() {
3604            sum_d -= old_val;
3605        }
3606        if !new_val.is_nan() {
3607            sum_d += new_val;
3608        }
3609        d_vec[i] = sum_d / slowd_period as f64;
3610    }
3611
3612    Ok(StochOutput { k: k_vec, d: d_vec })
3613}
3614
3615#[inline]
3616fn stoch_classic_ema(
3617    k_raw: &[f64],
3618    slowk_period: usize,
3619    slowd_period: usize,
3620    first_valid_idx: usize,
3621) -> Result<StochOutput, StochError> {
3622    let len = k_raw.len();
3623    let mut k_vec = alloc_with_nan_prefix(len, first_valid_idx + slowk_period - 1);
3624    let mut d_vec = alloc_with_nan_prefix(len, first_valid_idx + slowk_period + slowd_period - 2);
3625
3626    let alpha_k = 2.0 / (slowk_period as f64 + 1.0);
3627    let one_minus_alpha_k = 1.0 - alpha_k;
3628
3629    let k_warmup = first_valid_idx + slowk_period - 1;
3630    let mut sum_k = 0.0;
3631    let mut count_k = 0;
3632    for i in first_valid_idx..(first_valid_idx + slowk_period).min(len) {
3633        if !k_raw[i].is_nan() {
3634            sum_k += k_raw[i];
3635            count_k += 1;
3636        }
3637    }
3638
3639    if count_k > 0 && k_warmup < len {
3640        let mut ema_k = sum_k / count_k as f64;
3641        k_vec[k_warmup] = ema_k;
3642
3643        for i in (k_warmup + 1)..len {
3644            if !k_raw[i].is_nan() {
3645                ema_k = alpha_k * k_raw[i] + one_minus_alpha_k * ema_k;
3646            }
3647            k_vec[i] = ema_k;
3648        }
3649    } else {
3650        for i in k_warmup..len {
3651            k_vec[i] = f64::NAN;
3652        }
3653    }
3654
3655    let alpha_d = 2.0 / (slowd_period as f64 + 1.0);
3656    let one_minus_alpha_d = 1.0 - alpha_d;
3657
3658    let d_warmup = first_valid_idx + slowk_period + slowd_period - 2;
3659    let d_start = first_valid_idx + slowk_period - 1;
3660    let mut sum_d = 0.0;
3661    let mut count_d = 0;
3662    for i in d_start..(d_start + slowd_period).min(len) {
3663        if !k_vec[i].is_nan() {
3664            sum_d += k_vec[i];
3665            count_d += 1;
3666        }
3667    }
3668
3669    if count_d > 0 && d_warmup < len {
3670        let mut ema_d = sum_d / count_d as f64;
3671        d_vec[d_warmup] = ema_d;
3672
3673        for i in (d_warmup + 1)..len {
3674            if !k_vec[i].is_nan() {
3675                ema_d = alpha_d * k_vec[i] + one_minus_alpha_d * ema_d;
3676            }
3677            d_vec[i] = ema_d;
3678        }
3679    } else {
3680        for i in d_warmup..len {
3681            d_vec[i] = f64::NAN;
3682        }
3683    }
3684
3685    Ok(StochOutput { k: k_vec, d: d_vec })
3686}