Skip to main content

vector_ta/indicators/
stochf.rs

1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::indicators::moving_averages::alma::DeviceArrayF32Py;
3#[cfg(feature = "python")]
4use numpy::{IntoPyArray, PyArray1};
5#[cfg(feature = "python")]
6use pyo3::exceptions::PyValueError;
7#[cfg(feature = "python")]
8use pyo3::prelude::*;
9#[cfg(feature = "python")]
10use pyo3::types::PyDict;
11
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use serde::{Deserialize, Serialize};
14#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
15use wasm_bindgen::prelude::*;
16
17use crate::utilities::data_loader::{source_type, Candles};
18use crate::utilities::enums::Kernel;
19use crate::utilities::helpers::{
20    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
21    make_uninit_matrix,
22};
23#[cfg(feature = "python")]
24use crate::utilities::kernel_validation::validate_kernel;
25use aligned_vec::{AVec, CACHELINE_ALIGN};
26#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
27use core::arch::x86_64::*;
28#[cfg(not(target_arch = "wasm32"))]
29use rayon::prelude::*;
30use std::convert::AsRef;
31use std::error::Error;
32use std::mem::MaybeUninit;
33use thiserror::Error;
34
/// Price series consumed by the Fast Stochastic (STOCHF) indicator.
///
/// Either a borrowed [`Candles`] set, from which the `"high"`, `"low"` and
/// `"close"` fields are selected, or three caller-provided slices.
#[derive(Debug, Clone)]
pub enum StochfData<'a> {
    /// Read high/low/close from a candle set.
    Candles {
        candles: &'a Candles,
    },
    /// Explicit high/low/close series. The compute functions reject slices of
    /// unequal length (reported as `StochfError::EmptyInputData`).
    Slices {
        high: &'a [f64],
        low: &'a [f64],
        close: &'a [f64],
    },
}
46
/// Result of a STOCHF computation.
///
/// Both vectors have the same length as the input series; positions inside
/// each line's warm-up period are `NaN`.
#[derive(Debug, Clone)]
pub struct StochfOutput {
    /// Fast %K: `100 * (close - lowest_low) / (highest_high - lowest_low)` over
    /// the last `fastk_period` bars. When the window is flat (`high == low`)
    /// the value is `100` if `close` equals the high, otherwise `0`.
    pub k: Vec<f64>,
    /// Fast %D: a `fastd_period` simple moving average of `k` when
    /// `fastd_matype == 0`; all `NaN` for any other `fastd_matype`.
    pub d: Vec<f64>,
}
52
/// Tunable parameters for STOCHF.
///
/// `None` fields fall back to defaults when the input is evaluated:
/// `fastk_period = 5`, `fastd_period = 3`, `fastd_matype = 0`.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct StochfParams {
    /// Look-back window (bars) for the highest high / lowest low used by %K.
    pub fastk_period: Option<usize>,
    /// Smoothing window (bars) applied to %K to produce %D.
    pub fastd_period: Option<usize>,
    /// Moving-average type for %D. The slice kernels in this file
    /// (`stochf_scalar`, `stochf_avx2`) only compute `0` (simple moving
    /// average); any other value yields an all-`NaN` %D from them.
    pub fastd_matype: Option<usize>,
}
63
64impl Default for StochfParams {
65    fn default() -> Self {
66        Self {
67            fastk_period: Some(5),
68            fastd_period: Some(3),
69            fastd_matype: Some(0),
70        }
71    }
72}
73
/// Complete input for a STOCHF computation: price data plus parameters.
#[derive(Debug, Clone)]
pub struct StochfInput<'a> {
    /// Where the high/low/close series come from.
    pub data: StochfData<'a>,
    /// Indicator parameters; unset fields are resolved to defaults by the
    /// `get_*` accessors.
    pub params: StochfParams,
}
79
80impl<'a> StochfInput<'a> {
81    #[inline]
82    pub fn from_candles(candles: &'a Candles, params: StochfParams) -> Self {
83        Self {
84            data: StochfData::Candles { candles },
85            params,
86        }
87    }
88    #[inline]
89    pub fn from_slices(
90        high: &'a [f64],
91        low: &'a [f64],
92        close: &'a [f64],
93        params: StochfParams,
94    ) -> Self {
95        Self {
96            data: StochfData::Slices { high, low, close },
97            params,
98        }
99    }
100    #[inline]
101    pub fn with_default_candles(candles: &'a Candles) -> Self {
102        Self::from_candles(candles, StochfParams::default())
103    }
104    #[inline]
105    pub fn get_fastk_period(&self) -> usize {
106        self.params.fastk_period.unwrap_or(5)
107    }
108    #[inline]
109    pub fn get_fastd_period(&self) -> usize {
110        self.params.fastd_period.unwrap_or(3)
111    }
112    #[inline]
113    pub fn get_fastd_matype(&self) -> usize {
114        self.params.fastd_matype.unwrap_or(0)
115    }
116}
117
/// Fluent builder for STOCHF computations and streams.
///
/// Unset parameters stay `None` and are resolved to their defaults only when
/// the computation (or stream) is created.
#[derive(Copy, Clone, Debug)]
pub struct StochfBuilder {
    fastk_period: Option<usize>,
    fastd_period: Option<usize>,
    fastd_matype: Option<usize>,
    /// Kernel handed to `stochf_with_kernel` by `apply` / `apply_slices`.
    kernel: Kernel,
}
125
126impl Default for StochfBuilder {
127    fn default() -> Self {
128        Self {
129            fastk_period: None,
130            fastd_period: None,
131            fastd_matype: None,
132            kernel: Kernel::Auto,
133        }
134    }
135}
136
137impl StochfBuilder {
138    #[inline]
139    pub fn new() -> Self {
140        Self::default()
141    }
142    #[inline]
143    pub fn fastk_period(mut self, n: usize) -> Self {
144        self.fastk_period = Some(n);
145        self
146    }
147    #[inline]
148    pub fn fastd_period(mut self, n: usize) -> Self {
149        self.fastd_period = Some(n);
150        self
151    }
152    #[inline]
153    pub fn fastd_matype(mut self, t: usize) -> Self {
154        self.fastd_matype = Some(t);
155        self
156    }
157    #[inline]
158    pub fn kernel(mut self, k: Kernel) -> Self {
159        self.kernel = k;
160        self
161    }
162    #[inline]
163    pub fn apply(self, candles: &Candles) -> Result<StochfOutput, StochfError> {
164        let p = StochfParams {
165            fastk_period: self.fastk_period,
166            fastd_period: self.fastd_period,
167            fastd_matype: self.fastd_matype,
168        };
169        let i = StochfInput::from_candles(candles, p);
170        stochf_with_kernel(&i, self.kernel)
171    }
172    #[inline]
173    pub fn apply_slices(
174        self,
175        high: &[f64],
176        low: &[f64],
177        close: &[f64],
178    ) -> Result<StochfOutput, StochfError> {
179        let p = StochfParams {
180            fastk_period: self.fastk_period,
181            fastd_period: self.fastd_period,
182            fastd_matype: self.fastd_matype,
183        };
184        let i = StochfInput::from_slices(high, low, close, p);
185        stochf_with_kernel(&i, self.kernel)
186    }
187    #[inline]
188    pub fn into_stream(self) -> Result<StochfStream, StochfError> {
189        let p = StochfParams {
190            fastk_period: self.fastk_period,
191            fastd_period: self.fastd_period,
192            fastd_matype: self.fastd_matype,
193        };
194        StochfStream::try_new(p)
195    }
196}
197
/// Errors reported by the STOCHF functions.
#[derive(Debug, Error)]
pub enum StochfError {
    /// A price series is empty. Also returned when the high/low/close series
    /// differ in length, or when a candle field cannot be selected.
    #[error("stochf: Empty data provided.")]
    EmptyInputData,
    /// `fastk` or `fastd` is zero or longer than the data. The streaming
    /// constructor reports `data_len: 0` because it has no data yet.
    #[error("stochf: Invalid period (fastk={fastk}, fastd={fastd}), data length={data_len}.")]
    InvalidPeriod {
        fastk: usize,
        fastd: usize,
        data_len: usize,
    },
    /// No bar has a non-NaN high, low and close at the same time.
    #[error("stochf: All values are NaN.")]
    AllValuesNaN,
    /// Fewer than `needed` (= `fastk_period`) bars remain from the first valid bar.
    #[error(
        "stochf: Not enough valid data after first valid index (needed={needed}, valid={valid})."
    )]
    NotEnoughValidData { needed: usize, valid: usize },
    /// A caller-provided output buffer does not match the input length.
    #[error("stochf: Invalid output size (expected={expected}, k_got={k_got}, d_got={d_got}).")]
    OutputLengthMismatch {
        expected: usize,
        k_got: usize,
        d_got: usize,
    },
    /// NOTE(review): not produced by the code visible here; presumably used by
    /// a batch/parameter-sweep API for malformed ranges — confirm.
    #[error("stochf: invalid range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },
    /// NOTE(review): not produced by the code visible here; presumably returned
    /// when a batch API receives a non-batch kernel — confirm.
    #[error("stochf: invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(Kernel),
}
229
230#[inline]
231pub fn stochf(input: &StochfInput) -> Result<StochfOutput, StochfError> {
232    stochf_with_kernel(input, Kernel::Auto)
233}
234
235#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
236#[inline]
237pub fn stochf_into(
238    input: &StochfInput,
239    out_k: &mut [f64],
240    out_d: &mut [f64],
241) -> Result<(), StochfError> {
242    let (high, low, close) = match &input.data {
243        StochfData::Candles { candles } => {
244            let high = candles
245                .select_candle_field("high")
246                .map_err(|_| StochfError::EmptyInputData)?;
247            let low = candles
248                .select_candle_field("low")
249                .map_err(|_| StochfError::EmptyInputData)?;
250            let close = candles
251                .select_candle_field("close")
252                .map_err(|_| StochfError::EmptyInputData)?;
253            (high, low, close)
254        }
255        StochfData::Slices { high, low, close } => (*high, *low, *close),
256    };
257
258    if high.is_empty() || low.is_empty() || close.is_empty() {
259        return Err(StochfError::EmptyInputData);
260    }
261    let len = high.len();
262    if low.len() != len || close.len() != len {
263        return Err(StochfError::EmptyInputData);
264    }
265    if out_k.len() != len || out_d.len() != len {
266        return Err(StochfError::OutputLengthMismatch {
267            expected: len,
268            k_got: out_k.len(),
269            d_got: out_d.len(),
270        });
271    }
272
273    let fastk = input.get_fastk_period();
274    let fastd = input.get_fastd_period();
275    let _matype = input.get_fastd_matype();
276    if fastk == 0 || fastd == 0 || fastk > len || fastd > len {
277        return Err(StochfError::InvalidPeriod {
278            fastk,
279            fastd,
280            data_len: len,
281        });
282    }
283
284    let first = (0..len)
285        .find(|&i| !high[i].is_nan() && !low[i].is_nan() && !close[i].is_nan())
286        .ok_or(StochfError::AllValuesNaN)?;
287    if (len - first) < fastk {
288        return Err(StochfError::NotEnoughValidData {
289            needed: fastk,
290            valid: len - first,
291        });
292    }
293
294    let k_warm = first + fastk - 1;
295    let d_warm = first + fastk + fastd - 2;
296    let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
297    for v in &mut out_k[..k_warm.min(len)] {
298        *v = qnan;
299    }
300    for v in &mut out_d[..d_warm.min(len)] {
301        *v = qnan;
302    }
303
304    stochf_into_slice(out_k, out_d, input, Kernel::Auto)
305}
306
307#[inline]
308pub fn stochf_into_slice(
309    dst_k: &mut [f64],
310    dst_d: &mut [f64],
311    input: &StochfInput,
312    kernel: Kernel,
313) -> Result<(), StochfError> {
314    let (high, low, close) = match &input.data {
315        StochfData::Candles { candles } => {
316            let high = candles
317                .select_candle_field("high")
318                .map_err(|_| StochfError::EmptyInputData)?;
319            let low = candles
320                .select_candle_field("low")
321                .map_err(|_| StochfError::EmptyInputData)?;
322            let close = candles
323                .select_candle_field("close")
324                .map_err(|_| StochfError::EmptyInputData)?;
325            (high, low, close)
326        }
327        StochfData::Slices { high, low, close } => (*high, *low, *close),
328    };
329
330    if high.is_empty() || low.is_empty() || close.is_empty() {
331        return Err(StochfError::EmptyInputData);
332    }
333    let len = high.len();
334    if low.len() != len || close.len() != len {
335        return Err(StochfError::EmptyInputData);
336    }
337    if dst_k.len() != len || dst_d.len() != len {
338        return Err(StochfError::OutputLengthMismatch {
339            expected: len,
340            k_got: dst_k.len(),
341            d_got: dst_d.len(),
342        });
343    }
344
345    let fastk_period = input.get_fastk_period();
346    let fastd_period = input.get_fastd_period();
347    let matype = input.get_fastd_matype();
348
349    if fastk_period == 0 || fastd_period == 0 || fastk_period > len || fastd_period > len {
350        return Err(StochfError::InvalidPeriod {
351            fastk: fastk_period,
352            fastd: fastd_period,
353            data_len: len,
354        });
355    }
356    let first_valid_idx = (0..len)
357        .find(|&i| !high[i].is_nan() && !low[i].is_nan() && !close[i].is_nan())
358        .ok_or(StochfError::AllValuesNaN)?;
359    if (len - first_valid_idx) < fastk_period {
360        return Err(StochfError::NotEnoughValidData {
361            needed: fastk_period,
362            valid: len - first_valid_idx,
363        });
364    }
365
366    let chosen = match kernel {
367        Kernel::Auto => Kernel::Scalar,
368        other => other,
369    };
370
371    unsafe {
372        match chosen {
373            Kernel::Scalar | Kernel::ScalarBatch => stochf_scalar(
374                high,
375                low,
376                close,
377                fastk_period,
378                fastd_period,
379                matype,
380                first_valid_idx,
381                dst_k,
382                dst_d,
383            ),
384            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
385            Kernel::Avx2 | Kernel::Avx2Batch => stochf_avx2(
386                high,
387                low,
388                close,
389                fastk_period,
390                fastd_period,
391                matype,
392                first_valid_idx,
393                dst_k,
394                dst_d,
395            ),
396            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
397            Kernel::Avx512 | Kernel::Avx512Batch => stochf_avx512(
398                high,
399                low,
400                close,
401                fastk_period,
402                fastd_period,
403                matype,
404                first_valid_idx,
405                dst_k,
406                dst_d,
407            ),
408            _ => unreachable!(),
409        }
410    }
411
412    let k_warmup = (first_valid_idx + fastk_period - 1).min(len);
413    let d_warmup = (first_valid_idx + fastk_period + fastd_period - 2).min(len);
414    for v in &mut dst_k[..k_warmup] {
415        *v = f64::NAN;
416    }
417    for v in &mut dst_d[..d_warmup] {
418        *v = f64::NAN;
419    }
420
421    Ok(())
422}
423
424pub fn stochf_with_kernel(
425    input: &StochfInput,
426    kernel: Kernel,
427) -> Result<StochfOutput, StochfError> {
428    let (high, low, close) = match &input.data {
429        StochfData::Candles { candles } => {
430            let high = candles
431                .select_candle_field("high")
432                .map_err(|_| StochfError::EmptyInputData)?;
433            let low = candles
434                .select_candle_field("low")
435                .map_err(|_| StochfError::EmptyInputData)?;
436            let close = candles
437                .select_candle_field("close")
438                .map_err(|_| StochfError::EmptyInputData)?;
439            (high, low, close)
440        }
441        StochfData::Slices { high, low, close } => (*high, *low, *close),
442    };
443
444    if high.is_empty() || low.is_empty() || close.is_empty() {
445        return Err(StochfError::EmptyInputData);
446    }
447    let len = high.len();
448    if low.len() != len || close.len() != len {
449        return Err(StochfError::EmptyInputData);
450    }
451
452    let fastk_period = input.get_fastk_period();
453    let fastd_period = input.get_fastd_period();
454    let matype = input.get_fastd_matype();
455
456    if fastk_period == 0 || fastd_period == 0 || fastk_period > len || fastd_period > len {
457        return Err(StochfError::InvalidPeriod {
458            fastk: fastk_period,
459            fastd: fastd_period,
460            data_len: len,
461        });
462    }
463    let first_valid_idx = (0..len)
464        .find(|&i| !high[i].is_nan() && !low[i].is_nan() && !close[i].is_nan())
465        .ok_or(StochfError::AllValuesNaN)?;
466    if (len - first_valid_idx) < fastk_period {
467        return Err(StochfError::NotEnoughValidData {
468            needed: fastk_period,
469            valid: len - first_valid_idx,
470        });
471    }
472
473    let k_warmup = first_valid_idx + fastk_period - 1;
474    let d_warmup = first_valid_idx + fastk_period + fastd_period - 2;
475    let mut k_vals = alloc_with_nan_prefix(len, k_warmup.min(len));
476    let mut d_vals = alloc_with_nan_prefix(len, d_warmup.min(len));
477
478    let chosen = match kernel {
479        Kernel::Auto => Kernel::Scalar,
480        other => other,
481    };
482
483    unsafe {
484        match chosen {
485            Kernel::Scalar | Kernel::ScalarBatch => stochf_scalar(
486                high,
487                low,
488                close,
489                fastk_period,
490                fastd_period,
491                matype,
492                first_valid_idx,
493                &mut k_vals,
494                &mut d_vals,
495            ),
496            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
497            Kernel::Avx2 | Kernel::Avx2Batch => stochf_avx2(
498                high,
499                low,
500                close,
501                fastk_period,
502                fastd_period,
503                matype,
504                first_valid_idx,
505                &mut k_vals,
506                &mut d_vals,
507            ),
508            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
509            Kernel::Avx512 | Kernel::Avx512Batch => stochf_avx512(
510                high,
511                low,
512                close,
513                fastk_period,
514                fastd_period,
515                matype,
516                first_valid_idx,
517                &mut k_vals,
518                &mut d_vals,
519            ),
520            _ => unreachable!(),
521        }
522    }
523
524    Ok(StochfOutput {
525        k: k_vals,
526        d: d_vals,
527    })
528}
529
/// Portable scalar STOCHF kernel.
///
/// For every `i >= first_valid_idx + fastk_period - 1` it writes
/// `k_vals[i] = 100 * (close[i] - LL) / (HH - LL)`. Here `HH` / `LL` are the
/// highest high and lowest low over bars `[i + 1 - fastk_period, i]`; NaN
/// highs and lows are ignored. When `HH == LL` the value is `100` if
/// `close[i] == HH`, otherwise `0`.
///
/// When `matype == 0`, `d_vals[i]` is the `fastd_period` simple moving average
/// of %K. It stays NaN until `fastd_period` non-NaN %K values have been seen.
/// For any other `matype` the whole of `d_vals` is filled with NaN.
///
/// Entries before the first %K index are left untouched; callers pre-fill them
/// with NaN.
///
/// Windows of up to 16 bars are rescanned directly for every output. Longer
/// windows use two monotonic index deques kept in ring buffers, so each bar is
/// pushed and popped at most once.
///
/// # Safety
/// `high`, `low`, `close`, `k_vals` and `d_vals` must all have the same length.
/// `fastk_period` and `fastd_period` must be at least 1. Indices are not
/// bounds-checked beyond debug assertions.
#[inline]
pub unsafe fn stochf_scalar(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    fastk_period: usize,
    fastd_period: usize,
    matype: usize,
    first_valid_idx: usize,
    k_vals: &mut [f64],
    d_vals: &mut [f64],
) {
    debug_assert_eq!(high.len(), low.len());
    debug_assert_eq!(high.len(), close.len());
    debug_assert_eq!(high.len(), k_vals.len());
    debug_assert_eq!(k_vals.len(), d_vals.len());

    let len = high.len();
    if len == 0 {
        return;
    }

    let hp = high.as_ptr();
    let lp = low.as_ptr();
    let cp = close.as_ptr();

    // First bar whose %K window starts at or after the first valid bar.
    let k_start = first_valid_idx + fastk_period - 1;

    if fastk_period <= 16 {
        let use_sma_d = matype == 0;
        let mut d_sum: f64 = 0.0;
        let mut d_cnt: usize = 0;

        let mut i = k_start;
        while i < len {
            let start = i + 1 - fastk_period;
            let end = i + 1;

            // `>` / `<` are false for NaN, so NaN bars never become HH / LL.
            let mut hh = f64::NEG_INFINITY;
            let mut ll = f64::INFINITY;

            let mut j = start;
            let unroll_end = end - ((end - j) & 3);
            while j < unroll_end {
                let h0 = *hp.add(j);
                let l0 = *lp.add(j);
                if h0 > hh {
                    hh = h0;
                }
                if l0 < ll {
                    ll = l0;
                }

                let h1 = *hp.add(j + 1);
                let l1 = *lp.add(j + 1);
                if h1 > hh {
                    hh = h1;
                }
                if l1 < ll {
                    ll = l1;
                }

                let h2 = *hp.add(j + 2);
                let l2 = *lp.add(j + 2);
                if h2 > hh {
                    hh = h2;
                }
                if l2 < ll {
                    ll = l2;
                }

                let h3 = *hp.add(j + 3);
                let l3 = *lp.add(j + 3);
                if h3 > hh {
                    hh = h3;
                }
                if l3 < ll {
                    ll = l3;
                }

                j += 4;
            }
            while j < end {
                let h = *hp.add(j);
                let l = *lp.add(j);
                if h > hh {
                    hh = h;
                }
                if l < ll {
                    ll = l;
                }
                j += 1;
            }

            let c = *cp.add(i);
            let denom = hh - ll;
            let kv = if denom == 0.0 {
                if c == hh {
                    100.0
                } else {
                    0.0
                }
            } else {
                let inv = 100.0 / denom;
                c.mul_add(inv, (-ll) * inv)
            };
            *k_vals.get_unchecked_mut(i) = kv;

            if use_sma_d {
                if kv.is_nan() {
                    *d_vals.get_unchecked_mut(i) = f64::NAN;
                } else if d_cnt < fastd_period {
                    d_sum += kv;
                    d_cnt += 1;
                    if d_cnt == fastd_period {
                        *d_vals.get_unchecked_mut(i) = d_sum / (fastd_period as f64);
                    } else {
                        *d_vals.get_unchecked_mut(i) = f64::NAN;
                    }
                } else {
                    d_sum += kv - *k_vals.get_unchecked(i - fastd_period);
                    *d_vals.get_unchecked_mut(i) = d_sum / (fastd_period as f64);
                }
            }

            i += 1;
        }

        if matype != 0 {
            d_vals.fill(f64::NAN);
        }
        return;
    }

    // Ring-buffer capacity for the monotonic deques. Right after a push a
    // deque can hold one index for every bar in the window, i.e.
    // `fastk_period` entries (strictly falling highs, or strictly rising lows).
    // `head == tail` is the only "empty" test, so one spare slot is required.
    // With `cap == fastk_period` a full deque wraps `tail` onto `head`, looks
    // empty, and HH/LL come out as -inf/+inf or from a stale window.
    let cap = fastk_period + 1;
    let mut qh = vec![0usize; cap];
    let mut ql = vec![0usize; cap];
    let mut qh_head = 0usize;
    let mut qh_tail = 0usize;
    let mut ql_head = 0usize;
    let mut ql_tail = 0usize;

    let use_sma_d = matype == 0;
    let mut d_sum: f64 = 0.0;
    let mut d_cnt: usize = 0;

    let mut i = first_valid_idx;
    while i < len {
        // Drop indices that have slid out of the window [i + 1 - fastk, i].
        if i + 1 >= fastk_period {
            let win_start = i + 1 - fastk_period;
            while qh_head != qh_tail {
                let idx = *qh.get_unchecked(qh_head);
                if idx >= win_start {
                    break;
                }
                qh_head += 1;
                if qh_head == cap {
                    qh_head = 0;
                }
            }
            while ql_head != ql_tail {
                let idx = *ql.get_unchecked(ql_head);
                if idx >= win_start {
                    break;
                }
                ql_head += 1;
                if ql_head == cap {
                    ql_head = 0;
                }
            }
        }

        // Push bar i onto the max-deque after popping every high <= it (NaN skipped).
        let h_i = *hp.add(i);
        if h_i == h_i {
            while qh_head != qh_tail {
                let back = if qh_tail == 0 { cap - 1 } else { qh_tail - 1 };
                let back_idx = *qh.get_unchecked(back);
                if *hp.add(back_idx) <= h_i {
                    qh_tail = back;
                } else {
                    break;
                }
            }
            *qh.get_unchecked_mut(qh_tail) = i;
            qh_tail += 1;
            if qh_tail == cap {
                qh_tail = 0;
            }
        }

        // Push bar i onto the min-deque after popping every low >= it (NaN skipped).
        let l_i = *lp.add(i);
        if l_i == l_i {
            while ql_head != ql_tail {
                let back = if ql_tail == 0 { cap - 1 } else { ql_tail - 1 };
                let back_idx = *ql.get_unchecked(back);
                if *lp.add(back_idx) >= l_i {
                    ql_tail = back;
                } else {
                    break;
                }
            }
            *ql.get_unchecked_mut(ql_tail) = i;
            ql_tail += 1;
            if ql_tail == cap {
                ql_tail = 0;
            }
        }

        if i >= k_start {
            // Deque fronts are the window extremes.
            let hh = if qh_head != qh_tail {
                *hp.add(*qh.get_unchecked(qh_head))
            } else {
                f64::NEG_INFINITY
            };
            let ll = if ql_head != ql_tail {
                *lp.add(*ql.get_unchecked(ql_head))
            } else {
                f64::INFINITY
            };
            let c = *cp.add(i);
            let denom = hh - ll;
            let kv = if denom == 0.0 {
                if c == hh {
                    100.0
                } else {
                    0.0
                }
            } else {
                let inv = 100.0 / denom;
                c.mul_add(inv, (-ll) * inv)
            };
            *k_vals.get_unchecked_mut(i) = kv;

            if use_sma_d {
                if kv.is_nan() {
                    *d_vals.get_unchecked_mut(i) = f64::NAN;
                } else if d_cnt < fastd_period {
                    d_sum += kv;
                    d_cnt += 1;
                    if d_cnt == fastd_period {
                        *d_vals.get_unchecked_mut(i) = d_sum / (fastd_period as f64);
                    } else {
                        *d_vals.get_unchecked_mut(i) = f64::NAN;
                    }
                } else {
                    d_sum += kv - *k_vals.get_unchecked(i - fastd_period);
                    *d_vals.get_unchecked_mut(i) = d_sum / (fastd_period as f64);
                }
            }
        }

        i += 1;
    }

    if !use_sma_d {
        d_vals.fill(f64::NAN);
    }
}
788
/// 256-bit SIMD STOCHF kernel.
///
/// For `fastk_period <= 32` it recomputes each window's highest high / lowest
/// low with 4-lane `max` / `min`. NaN lanes are replaced by -inf / +inf so they
/// never win, and the last `< 4` elements are finished scalar-wise. It then
/// applies the same %K formula and %D running-SMA update as `stochf_scalar`.
/// Longer windows are delegated to `stochf_scalar`.
///
/// # Safety
/// Same slice/index contract as `stochf_scalar`. In addition, the CPU must
/// support the 256-bit AVX instructions used here.
///
/// NOTE(review): this function has no `#[target_feature(enable = ...)]`
/// attribute. Unless AVX is enabled crate-wide, the intrinsics may be emitted
/// as out-of-line calls; confirm the build flags.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn stochf_avx2(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    fastk_period: usize,
    fastd_period: usize,
    matype: usize,
    first_valid_idx: usize,
    k_vals: &mut [f64],
    d_vals: &mut [f64],
) {
    if fastk_period <= 32 {
        let len = high.len();
        // First bar with a full %K window.
        let start_i = first_valid_idx + fastk_period - 1;
        let neg_inf = _mm256_set1_pd(f64::NEG_INFINITY);
        let pos_inf = _mm256_set1_pd(f64::INFINITY);

        let use_sma_d = matype == 0;
        let mut d_sum = 0.0f64;
        let mut d_cnt: usize = 0;

        // The whole window is rescanned for every output bar.
        for i in start_i..len {
            let start = i + 1 - fastk_period;
            let end = i + 1;

            let mut vmax = neg_inf;
            let mut vmin = pos_inf;
            let mut j = start;
            while j + 4 <= end {
                let vh = _mm256_loadu_pd(high.as_ptr().add(j));
                let vl = _mm256_loadu_pd(low.as_ptr().add(j));

                // Ordered self-compare is false only in NaN lanes; blend those
                // lanes to -inf / +inf so they cannot become the max / min.
                let mask_h = _mm256_cmp_pd(vh, vh, _CMP_ORD_Q);
                let mask_l = _mm256_cmp_pd(vl, vl, _CMP_ORD_Q);
                let vh_nnan = _mm256_blendv_pd(neg_inf, vh, mask_h);
                let vl_nnan = _mm256_blendv_pd(pos_inf, vl, mask_l);

                vmax = _mm256_max_pd(vmax, vh_nnan);
                vmin = _mm256_min_pd(vmin, vl_nnan);
                j += 4;
            }

            // Horizontal reduction: 4 lanes -> 2 lanes -> scalar.
            let vmax_lo = _mm256_castpd256_pd128(vmax);
            let vmax_hi = _mm256_extractf128_pd(vmax, 1);
            let vmax_128 = _mm_max_pd(vmax_lo, vmax_hi);
            let vmax_hi64 = _mm_unpackhi_pd(vmax_128, vmax_128);
            let mut hh = f64::max(_mm_cvtsd_f64(vmax_128), _mm_cvtsd_f64(vmax_hi64));

            let vmin_lo = _mm256_castpd256_pd128(vmin);
            let vmin_hi = _mm256_extractf128_pd(vmin, 1);
            let vmin_128 = _mm_min_pd(vmin_lo, vmin_hi);
            let vmin_hi64 = _mm_unpackhi_pd(vmin_128, vmin_128);
            let mut ll = f64::min(_mm_cvtsd_f64(vmin_128), _mm_cvtsd_f64(vmin_hi64));

            // Scalar tail for the remaining (< 4) bars, again skipping NaN.
            while j < end {
                let h = *high.get_unchecked(j);
                let l = *low.get_unchecked(j);
                if h == h && h > hh {
                    hh = h;
                }
                if l == l && l < ll {
                    ll = l;
                }
                j += 1;
            }

            // %K; a flat window yields 100 when close equals the high, else 0.
            let c = *close.get_unchecked(i);
            let denom = hh - ll;
            let kv = if denom == 0.0 {
                if c == hh {
                    100.0
                } else {
                    0.0
                }
            } else {
                let inv = 100.0 / denom;
                c.mul_add(inv, (-ll) * inv)
            };
            *k_vals.get_unchecked_mut(i) = kv;

            // %D as a running simple moving average of %K (matype 0 only).
            if use_sma_d {
                if kv.is_nan() {
                    *d_vals.get_unchecked_mut(i) = f64::NAN;
                } else if d_cnt < fastd_period {
                    d_sum += kv;
                    d_cnt += 1;
                    if d_cnt == fastd_period {
                        *d_vals.get_unchecked_mut(i) = d_sum / (fastd_period as f64);
                    } else {
                        *d_vals.get_unchecked_mut(i) = f64::NAN;
                    }
                } else {
                    d_sum += kv - *k_vals.get_unchecked(i - fastd_period);
                    *d_vals.get_unchecked_mut(i) = d_sum / (fastd_period as f64);
                }
            }
        }

        // Only matype 0 is implemented; other types produce an all-NaN %D.
        if !use_sma_d {
            d_vals.fill(f64::NAN);
        }
    } else {
        stochf_scalar(
            high,
            low,
            close,
            fastk_period,
            fastd_period,
            matype,
            first_valid_idx,
            k_vals,
            d_vals,
        );
    }
}
906
/// AVX-512 STOCHF entry point.
///
/// Dispatches to `stochf_avx512_long` when `fastk_period` exceeds 32 bars and
/// to `stochf_avx512_short` otherwise.
///
/// # Safety
/// Same slice/index contract as `stochf_scalar`; an AVX-512 capable CPU is
/// assumed by the caller that selected this kernel.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn stochf_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    fastk_period: usize,
    fastd_period: usize,
    matype: usize,
    first_valid_idx: usize,
    k_vals: &mut [f64],
    d_vals: &mut [f64],
) {
    let long_window = fastk_period > 32;
    if long_window {
        stochf_avx512_long(
            high,
            low,
            close,
            fastk_period,
            fastd_period,
            matype,
            first_valid_idx,
            k_vals,
            d_vals,
        );
    } else {
        stochf_avx512_short(
            high,
            low,
            close,
            fastk_period,
            fastd_period,
            matype,
            first_valid_idx,
            k_vals,
            d_vals,
        );
    }
}
946
/// Short-window (`fastk_period <= 32`) AVX-512 STOCHF kernel.
///
/// No dedicated AVX-512 code exists yet; every argument is forwarded unchanged
/// to `stochf_scalar`.
///
/// # Safety
/// Same requirements as `stochf_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn stochf_avx512_short(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    fastk_period: usize,
    fastd_period: usize,
    matype: usize,
    first_valid_idx: usize,
    k_vals: &mut [f64],
    d_vals: &mut [f64],
) {
    // SAFETY: the caller upholds `stochf_scalar`'s contract; arguments are
    // forwarded verbatim.
    unsafe {
        stochf_scalar(
            high,
            low,
            close,
            fastk_period,
            fastd_period,
            matype,
            first_valid_idx,
            k_vals,
            d_vals,
        )
    }
}
972
/// Long-window (`fastk_period > 32`) AVX-512 STOCHF kernel.
///
/// No dedicated AVX-512 code exists yet; every argument is forwarded unchanged
/// to `stochf_scalar`.
///
/// # Safety
/// Same requirements as `stochf_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn stochf_avx512_long(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    fastk_period: usize,
    fastd_period: usize,
    matype: usize,
    first_valid_idx: usize,
    k_vals: &mut [f64],
    d_vals: &mut [f64],
) {
    // SAFETY: the caller upholds `stochf_scalar`'s contract; arguments are
    // forwarded verbatim.
    unsafe {
        stochf_scalar(
            high,
            low,
            close,
            fastk_period,
            fastd_period,
            matype,
            first_valid_idx,
            k_vals,
            d_vals,
        )
    }
}
998
/// Incremental (bar-by-bar) STOCHF calculator, created via
/// `StochfStream::try_new` or `StochfBuilder::into_stream`.
///
/// NOTE(review): field notes marked "presumably" are inferred from names and
/// from the constructor, not from the update logic; confirm them against it.
#[derive(Debug, Clone)]
pub struct StochfStream {
    /// %K look-back window.
    fastk_period: usize,
    /// %D smoothing window.
    fastd_period: usize,
    /// %D moving-average type.
    fastd_matype: usize,

    // Ring-buffer deque for window highs: parallel arrays of bar indices and
    // values (each `cap_k` long), with `qh_head` / `qh_tail` as cursors.
    qh_idx: Vec<usize>,
    qh_val: Vec<f64>,
    qh_head: usize,
    qh_tail: usize,

    // Same layout for window lows.
    ql_idx: Vec<usize>,
    ql_val: Vec<f64>,
    ql_head: usize,
    ql_tail: usize,

    /// Capacity of the high/low ring buffers (`fastk_period + 1` in `try_new`).
    cap_k: usize,

    // Disambiguate `head == tail`: when set, the deque is non-empty (the expiry
    // loop treats it as such and clears it). Presumably set on push when the
    // buffer fills — confirm.
    qh_full: bool,
    ql_full: bool,

    /// Presumably the index of the next bar to be processed — confirm.
    t: usize,

    /// Ring of recent %K values, allocated with `fastd_period` slots.
    k_ring: Vec<f64>,
    /// Presumably the write cursor into `k_ring` — confirm.
    k_head: usize,
    /// Presumably the number of %K values currently stored — confirm.
    k_count: usize,
    /// Presumably the running sum of `k_ring` for the %D SMA — confirm.
    d_sma_sum: f64,
}
1027
impl StochfStream {
    /// Creates a streaming StochF calculator.
    ///
    /// `None` parameters fall back to `fastk_period = 5`, `fastd_period = 3` and
    /// `fastd_matype = 0`.
    ///
    /// # Errors
    /// Returns `StochfError::InvalidPeriod` (with `data_len: 0`) when either period is 0.
    pub fn try_new(params: StochfParams) -> Result<Self, StochfError> {
        let fastk_period = params.fastk_period.unwrap_or(5);
        let fastd_period = params.fastd_period.unwrap_or(3);
        let fastd_matype = params.fastd_matype.unwrap_or(0);

        if fastk_period == 0 || fastd_period == 0 {
            return Err(StochfError::InvalidPeriod {
                fastk: fastk_period,
                fastd: fastd_period,
                data_len: 0,
            });
        }

        // After a push and before expiry, a deque can hold `fastk_period + 1`
        // indices. The `*_full` flags distinguish a full buffer from an empty one
        // when `head == tail`.
        let cap_k = fastk_period + 1;

        Ok(Self {
            fastk_period,
            fastd_period,
            fastd_matype,

            qh_idx: vec![0; cap_k],
            qh_val: vec![0.0; cap_k],
            qh_head: 0,
            qh_tail: 0,
            qh_full: false,

            ql_idx: vec![0; cap_k],
            ql_val: vec![0.0; cap_k],
            ql_head: 0,
            ql_tail: 0,
            ql_full: false,

            cap_k,

            t: 0,

            k_ring: vec![0.0; fastd_period],
            k_head: 0,
            k_count: 0,
            d_sma_sum: 0.0,
        })
    }

    /// Advances a ring-buffer index by one, wrapping to 0 at `cap`.
    #[inline(always)]
    fn inc(idx: &mut usize, cap: usize) {
        *idx += 1;
        if *idx == cap {
            *idx = 0;
        }
    }

    /// Moves a ring-buffer index back by one, wrapping from 0 to `cap - 1`.
    #[inline(always)]
    fn dec(idx: &mut usize, cap: usize) {
        if *idx == 0 {
            *idx = cap - 1;
        } else {
            *idx -= 1;
        }
    }

    /// Drops entries from the front of the high deque whose bar index is older than `win_start`.
    #[inline(always)]
    fn qh_expire(&mut self, win_start: usize) {
        while (self.qh_head != self.qh_tail || self.qh_full)
            && self.qh_idx[self.qh_head] < win_start
        {
            Self::inc(&mut self.qh_head, self.cap_k);

            // Removing any element means the buffer can no longer be full.
            self.qh_full = false;
        }
    }
    /// Drops entries from the front of the low deque whose bar index is older than `win_start`.
    #[inline(always)]
    fn ql_expire(&mut self, win_start: usize) {
        while (self.ql_head != self.ql_tail || self.ql_full)
            && self.ql_idx[self.ql_head] < win_start
        {
            Self::inc(&mut self.ql_head, self.cap_k);
            self.ql_full = false;
        }
    }

    /// Appends bar `idx` with high `val`.
    ///
    /// Back entries whose high is `<= val` are popped first, which keeps the stored
    /// highs non-increasing from front to back.
    #[inline(always)]
    fn qh_push(&mut self, idx: usize, val: f64) {
        while self.qh_head != self.qh_tail || self.qh_full {
            let mut back = self.qh_tail;
            Self::dec(&mut back, self.cap_k);
            if self.qh_val[back] <= val {
                self.qh_tail = back;

                self.qh_full = false;
            } else {
                break;
            }
        }
        self.qh_idx[self.qh_tail] = idx;
        self.qh_val[self.qh_tail] = val;
        Self::inc(&mut self.qh_tail, self.cap_k);

        // The tail caught up with the head: the buffer is now full, not empty.
        if self.qh_tail == self.qh_head {
            self.qh_full = true;
        }
    }

    /// Appends bar `idx` with low `val`.
    ///
    /// Back entries whose low is `>= val` are popped first, which keeps the stored
    /// lows non-decreasing from front to back.
    #[inline(always)]
    fn ql_push(&mut self, idx: usize, val: f64) {
        while self.ql_head != self.ql_tail || self.ql_full {
            let mut back = self.ql_tail;
            Self::dec(&mut back, self.cap_k);
            if self.ql_val[back] >= val {
                self.ql_tail = back;
                self.ql_full = false;
            } else {
                break;
            }
        }
        self.ql_idx[self.ql_tail] = idx;
        self.ql_val[self.ql_tail] = val;
        Self::inc(&mut self.ql_tail, self.cap_k);
        if self.ql_tail == self.ql_head {
            self.ql_full = true;
        }
    }

    /// Feeds one bar and returns `Some((fast_k, fast_d))` once `fastk_period` bars
    /// have been fed. NaN bars count toward that total. Earlier calls return `None`.
    ///
    /// NaN highs and lows are not pushed, so they never become window extremes.
    /// `fast_d` is NaN while fewer than `fastd_period` %K values have been
    /// collected, and always NaN when `fastd_matype != 0`.
    ///
    /// NOTE(review): a NaN %K (for example from a NaN close) is added into
    /// `d_sma_sum`, so %D stays NaN for every later bar. The batch row kernel
    /// (`stochf_row_scalar`) treats NaN %K differently; confirm which behavior is
    /// intended.
    pub fn update(&mut self, high: f64, low: f64, close: f64) -> Option<(f64, f64)> {
        let i = self.t;

        self.t = self.t.wrapping_add(1);

        // `x == x` is false only for NaN.
        if high == high {
            self.qh_push(i, high);
        }
        if low == low {
            self.ql_push(i, low);
        }

        let have_k_window = (i + 1) >= self.fastk_period;
        if have_k_window {
            // Keep only bars i + 1 - fastk_period ..= i.
            let win_start = i + 1 - self.fastk_period;
            self.qh_expire(win_start);
            self.ql_expire(win_start);
        } else {
            return None;
        }

        // An empty deque (every bar in the window was NaN) yields ±inf extremes.
        let hh = if self.qh_head != self.qh_tail || self.qh_full {
            self.qh_val[self.qh_head]
        } else {
            f64::NEG_INFINITY
        };
        let ll = if self.ql_head != self.ql_tail || self.ql_full {
            self.ql_val[self.ql_head]
        } else {
            f64::INFINITY
        };

        let denom = hh - ll;
        let k = if denom == 0.0 {
            // Flat window: 100 when the close sits on that single level, else 0.
            if close == hh {
                100.0
            } else {
                0.0
            }
        } else {
            // close * scale - ll * scale == 100 * (close - ll) / (hh - ll).
            let scale = 100.0 / denom;
            close.mul_add(scale, (-ll) * scale)
        };

        let d = if self.fastd_matype != 0 {
            f64::NAN
        } else if self.k_count < self.fastd_period {
            // Warm-up: fill the %K ring and report NaN until it is full.
            self.k_ring[self.k_head] = k;
            self.d_sma_sum += k;
            self.k_count += 1;
            StochfStream::inc(&mut self.k_head, self.fastd_period);

            if self.k_count == self.fastd_period {
                self.d_sma_sum / (self.fastd_period as f64)
            } else {
                f64::NAN
            }
        } else {
            // Steady state: replace the oldest %K and update the running sum.
            let old = self.k_ring[self.k_head];
            self.k_ring[self.k_head] = k;
            StochfStream::inc(&mut self.k_head, self.fastd_period);

            self.d_sma_sum += k - old;
            self.d_sma_sum / (self.fastd_period as f64)
        };

        Some((k, d))
    }
}
1220
/// Parameter sweep for batch StochF.
///
/// Each axis is a `(start, end, step)` triple expanded by `expand_grid`:
/// - a zero `step` or `start == end` yields only `start`;
/// - otherwise values run from `start` toward `end` inclusive, counting downward
///   when `start > end`.
#[derive(Clone, Debug)]
pub struct StochfBatchRange {
    /// Sweep over `fastk_period`, the %K highest-high / lowest-low lookback.
    pub fastk_period: (usize, usize, usize),
    /// Sweep over `fastd_period`, the %D smoothing length.
    pub fastd_period: (usize, usize, usize),
}
1226
1227impl Default for StochfBatchRange {
1228    fn default() -> Self {
1229        Self {
1230            fastk_period: (5, 254, 1),
1231            fastd_period: (3, 3, 0),
1232        }
1233    }
1234}
1235
/// Fluent builder for batch StochF runs.
///
/// Configure the parameter sweep and kernel, then call `apply_slices` or `apply_candles`.
#[derive(Clone, Debug, Default)]
pub struct StochfBatchBuilder {
    /// Parameter sweep forwarded to `stochf_batch_with_kernel`.
    range: StochfBatchRange,
    /// Requested kernel; `stochf_batch_with_kernel` accepts only `Auto` or a batch kernel.
    kernel: Kernel,
}
1241
1242impl StochfBatchBuilder {
1243    pub fn new() -> Self {
1244        Self::default()
1245    }
1246    pub fn kernel(mut self, k: Kernel) -> Self {
1247        self.kernel = k;
1248        self
1249    }
1250    pub fn fastk_range(mut self, start: usize, end: usize, step: usize) -> Self {
1251        self.range.fastk_period = (start, end, step);
1252        self
1253    }
1254    pub fn fastk_static(mut self, p: usize) -> Self {
1255        self.range.fastk_period = (p, p, 0);
1256        self
1257    }
1258    pub fn fastd_range(mut self, start: usize, end: usize, step: usize) -> Self {
1259        self.range.fastd_period = (start, end, step);
1260        self
1261    }
1262    pub fn fastd_static(mut self, p: usize) -> Self {
1263        self.range.fastd_period = (p, p, 0);
1264        self
1265    }
1266    pub fn apply_slices(
1267        self,
1268        high: &[f64],
1269        low: &[f64],
1270        close: &[f64],
1271    ) -> Result<StochfBatchOutput, StochfError> {
1272        stochf_batch_with_kernel(high, low, close, &self.range, self.kernel)
1273    }
1274    pub fn with_default_slices(
1275        high: &[f64],
1276        low: &[f64],
1277        close: &[f64],
1278        k: Kernel,
1279    ) -> Result<StochfBatchOutput, StochfError> {
1280        StochfBatchBuilder::new()
1281            .kernel(k)
1282            .apply_slices(high, low, close)
1283    }
1284    pub fn apply_candles(self, c: &Candles) -> Result<StochfBatchOutput, StochfError> {
1285        let high = source_type(c, "high");
1286        let low = source_type(c, "low");
1287        let close = source_type(c, "close");
1288        self.apply_slices(high, low, close)
1289    }
1290    pub fn with_default_candles(c: &Candles) -> Result<StochfBatchOutput, StochfError> {
1291        StochfBatchBuilder::new()
1292            .kernel(Kernel::Auto)
1293            .apply_candles(c)
1294    }
1295}
1296
1297pub fn stochf_batch_with_kernel(
1298    high: &[f64],
1299    low: &[f64],
1300    close: &[f64],
1301    sweep: &StochfBatchRange,
1302    k: Kernel,
1303) -> Result<StochfBatchOutput, StochfError> {
1304    let kernel = match k {
1305        Kernel::Auto => detect_best_batch_kernel(),
1306        other if other.is_batch() => other,
1307        _ => return Err(StochfError::InvalidKernelForBatch(k)),
1308    };
1309    let simd = match kernel {
1310        Kernel::Avx512Batch => Kernel::Avx512,
1311        Kernel::Avx2Batch => Kernel::Avx2,
1312        Kernel::ScalarBatch => Kernel::Scalar,
1313        _ => unreachable!(),
1314    };
1315    stochf_batch_par_slice(high, low, close, sweep, simd)
1316}
1317
/// Result of a batch StochF sweep.
///
/// `k` and `d` are row-major `rows × cols` matrices: row `r` occupies
/// `r * cols .. (r + 1) * cols` and holds the outputs for `combos[r]`.
#[derive(Clone, Debug)]
pub struct StochfBatchOutput {
    /// Fast %K values, one row per parameter combination.
    pub k: Vec<f64>,
    /// Fast %D values, laid out like `k`.
    pub d: Vec<f64>,
    /// Parameter combination for each row, in row order.
    pub combos: Vec<StochfParams>,
    /// Number of rows (equals `combos.len()`).
    pub rows: usize,
    /// Number of columns (input series length).
    pub cols: usize,
}
1326impl StochfBatchOutput {
1327    pub fn row_for_params(&self, p: &StochfParams) -> Option<usize> {
1328        self.combos.iter().position(|c| {
1329            c.fastk_period.unwrap_or(5) == p.fastk_period.unwrap_or(5)
1330                && c.fastd_period.unwrap_or(3) == p.fastd_period.unwrap_or(3)
1331                && c.fastd_matype.unwrap_or(0) == p.fastd_matype.unwrap_or(0)
1332        })
1333    }
1334    pub fn k_for(&self, p: &StochfParams) -> Option<&[f64]> {
1335        self.row_for_params(p).map(|row| {
1336            let start = row * self.cols;
1337            &self.k[start..start + self.cols]
1338        })
1339    }
1340    pub fn d_for(&self, p: &StochfParams) -> Option<&[f64]> {
1341        self.row_for_params(p).map(|row| {
1342            let start = row * self.cols;
1343            &self.d[start..start + self.cols]
1344        })
1345    }
1346}
1347
1348#[inline(always)]
1349fn expand_grid(r: &StochfBatchRange) -> Vec<StochfParams> {
1350    fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, StochfError> {
1351        if step == 0 || start == end {
1352            return Ok(vec![start]);
1353        }
1354        if start < end {
1355            let mut v = Vec::new();
1356            let st = step.max(1);
1357            let mut x = start;
1358            while x <= end {
1359                v.push(x);
1360                x = match x.checked_add(st) {
1361                    Some(next) => next,
1362                    None => break,
1363                };
1364            }
1365            if v.is_empty() {
1366                return Err(StochfError::InvalidRange { start, end, step });
1367            }
1368            return Ok(v);
1369        }
1370
1371        let mut v = Vec::new();
1372        let st = step.max(1) as isize;
1373        let mut x = start as isize;
1374        let end_i = end as isize;
1375        while x >= end_i {
1376            v.push(x as usize);
1377            x -= st;
1378        }
1379        if v.is_empty() {
1380            return Err(StochfError::InvalidRange { start, end, step });
1381        }
1382        Ok(v)
1383    }
1384    let fastk = axis_usize(r.fastk_period).unwrap_or_else(|_| Vec::new());
1385    let fastd = axis_usize(r.fastd_period).unwrap_or_else(|_| Vec::new());
1386    let mut out = Vec::with_capacity(fastk.len().saturating_mul(fastd.len()));
1387    for &k in &fastk {
1388        for &d in &fastd {
1389            out.push(StochfParams {
1390                fastk_period: Some(k),
1391                fastd_period: Some(d),
1392                fastd_matype: Some(0),
1393            });
1394        }
1395    }
1396    out
1397}
1398
1399#[inline(always)]
1400pub fn stochf_batch_slice(
1401    high: &[f64],
1402    low: &[f64],
1403    close: &[f64],
1404    sweep: &StochfBatchRange,
1405    kern: Kernel,
1406) -> Result<StochfBatchOutput, StochfError> {
1407    stochf_batch_inner(high, low, close, sweep, kern, false)
1408}
1409
1410#[inline(always)]
1411pub fn stochf_batch_par_slice(
1412    high: &[f64],
1413    low: &[f64],
1414    close: &[f64],
1415    sweep: &StochfBatchRange,
1416    kern: Kernel,
1417) -> Result<StochfBatchOutput, StochfError> {
1418    stochf_batch_inner(high, low, close, sweep, kern, true)
1419}
1420
1421#[inline(always)]
1422pub fn stochf_batch_inner_into(
1423    high: &[f64],
1424    low: &[f64],
1425    close: &[f64],
1426    sweep: &StochfBatchRange,
1427    kern: Kernel,
1428    parallel: bool,
1429    k_out: &mut [f64],
1430    d_out: &mut [f64],
1431) -> Result<Vec<StochfParams>, StochfError> {
1432    let combos = expand_grid(sweep);
1433    if combos.is_empty() {
1434        return Err(StochfError::InvalidRange {
1435            start: sweep.fastk_period.0,
1436            end: sweep.fastk_period.1,
1437            step: sweep.fastk_period.2,
1438        });
1439    }
1440    let first = (0..high.len())
1441        .find(|&i| !high[i].is_nan() && !low[i].is_nan() && !close[i].is_nan())
1442        .ok_or(StochfError::AllValuesNaN)?;
1443    let max_k = combos
1444        .iter()
1445        .map(|c| c.fastk_period.unwrap())
1446        .max()
1447        .unwrap();
1448    if high.len() - first < max_k {
1449        return Err(StochfError::NotEnoughValidData {
1450            needed: max_k,
1451            valid: high.len() - first,
1452        });
1453    }
1454    let rows = combos.len();
1455    let cols = high.len();
1456
1457    let expected_size = rows.checked_mul(cols).ok_or(StochfError::InvalidRange {
1458        start: sweep.fastk_period.0,
1459        end: sweep.fastk_period.1,
1460        step: sweep.fastk_period.2,
1461    })?;
1462    if k_out.len() != expected_size || d_out.len() != expected_size {
1463        return Err(StochfError::OutputLengthMismatch {
1464            expected: expected_size,
1465            k_got: k_out.len(),
1466            d_got: d_out.len(),
1467        });
1468    }
1469
1470    for (row, combo) in combos.iter().enumerate() {
1471        let k_warmup = (first + combo.fastk_period.unwrap() - 1).min(cols);
1472        let d_warmup =
1473            (first + combo.fastk_period.unwrap() + combo.fastd_period.unwrap() - 2).min(cols);
1474        let row_start = row * cols;
1475
1476        for i in 0..k_warmup {
1477            k_out[row_start + i] = f64::NAN;
1478        }
1479
1480        for i in 0..d_warmup {
1481            d_out[row_start + i] = f64::NAN;
1482        }
1483    }
1484
1485    let do_row = |row: usize, kout: &mut [f64], dout: &mut [f64]| unsafe {
1486        let fastk_period = combos[row].fastk_period.unwrap();
1487        let fastd_period = combos[row].fastd_period.unwrap();
1488        let matype = combos[row].fastd_matype.unwrap();
1489        match kern {
1490            Kernel::Scalar => stochf_row_scalar(
1491                high,
1492                low,
1493                close,
1494                first,
1495                fastk_period,
1496                fastd_period,
1497                matype,
1498                kout,
1499                dout,
1500            ),
1501            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1502            Kernel::Avx2 => stochf_row_avx2(
1503                high,
1504                low,
1505                close,
1506                first,
1507                fastk_period,
1508                fastd_period,
1509                matype,
1510                kout,
1511                dout,
1512            ),
1513            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1514            Kernel::Avx512 => stochf_row_avx512(
1515                high,
1516                low,
1517                close,
1518                first,
1519                fastk_period,
1520                fastd_period,
1521                matype,
1522                kout,
1523                dout,
1524            ),
1525            _ => unreachable!(),
1526        }
1527    };
1528
1529    if parallel {
1530        #[cfg(not(target_arch = "wasm32"))]
1531        {
1532            k_out
1533                .par_chunks_mut(cols)
1534                .zip(d_out.par_chunks_mut(cols))
1535                .enumerate()
1536                .for_each(|(row, (k, d))| do_row(row, k, d));
1537        }
1538
1539        #[cfg(target_arch = "wasm32")]
1540        {
1541            for (row, (k, d)) in k_out
1542                .chunks_mut(cols)
1543                .zip(d_out.chunks_mut(cols))
1544                .enumerate()
1545            {
1546                do_row(row, k, d);
1547            }
1548        }
1549    } else {
1550        for (row, (k, d)) in k_out
1551            .chunks_mut(cols)
1552            .zip(d_out.chunks_mut(cols))
1553            .enumerate()
1554        {
1555            do_row(row, k, d);
1556        }
1557    }
1558
1559    Ok(combos)
1560}
1561
1562#[inline(always)]
1563fn stochf_batch_inner(
1564    high: &[f64],
1565    low: &[f64],
1566    close: &[f64],
1567    sweep: &StochfBatchRange,
1568    kern: Kernel,
1569    parallel: bool,
1570) -> Result<StochfBatchOutput, StochfError> {
1571    let combos = expand_grid(sweep);
1572    if combos.is_empty() {
1573        return Err(StochfError::InvalidRange {
1574            start: sweep.fastk_period.0,
1575            end: sweep.fastk_period.1,
1576            step: sweep.fastk_period.2,
1577        });
1578    }
1579    let first = (0..high.len())
1580        .find(|&i| !high[i].is_nan() && !low[i].is_nan() && !close[i].is_nan())
1581        .ok_or(StochfError::AllValuesNaN)?;
1582    let max_k = combos
1583        .iter()
1584        .map(|c| c.fastk_period.unwrap())
1585        .max()
1586        .unwrap();
1587    if high.len() - first < max_k {
1588        return Err(StochfError::NotEnoughValidData {
1589            needed: max_k,
1590            valid: high.len() - first,
1591        });
1592    }
1593    let rows = combos.len();
1594    let cols = high.len();
1595
1596    let _total = rows.checked_mul(cols).ok_or(StochfError::InvalidRange {
1597        start: sweep.fastk_period.0,
1598        end: sweep.fastk_period.1,
1599        step: sweep.fastk_period.2,
1600    })?;
1601
1602    let mut k_buf = make_uninit_matrix(rows, cols);
1603    let mut d_buf = make_uninit_matrix(rows, cols);
1604
1605    let k_warmups: Vec<usize> = combos
1606        .iter()
1607        .map(|c| (first + c.fastk_period.unwrap() - 1).min(cols))
1608        .collect();
1609    let d_warmups: Vec<usize> = combos
1610        .iter()
1611        .map(|c| (first + c.fastk_period.unwrap() + c.fastd_period.unwrap() - 2).min(cols))
1612        .collect();
1613
1614    init_matrix_prefixes(&mut k_buf, cols, &k_warmups);
1615    init_matrix_prefixes(&mut d_buf, cols, &d_warmups);
1616
1617    let k_buf_len = k_buf.len();
1618    let d_buf_len = d_buf.len();
1619    let k_buf_cap = k_buf.capacity();
1620    let d_buf_cap = d_buf.capacity();
1621    let k_ptr = k_buf.as_mut_ptr();
1622    let d_ptr = d_buf.as_mut_ptr();
1623    std::mem::forget(k_buf);
1624    std::mem::forget(d_buf);
1625    let k_out = unsafe { std::slice::from_raw_parts_mut(k_ptr as *mut f64, k_buf_len) };
1626    let d_out = unsafe { std::slice::from_raw_parts_mut(d_ptr as *mut f64, d_buf_len) };
1627
1628    let do_row = |row: usize, kout: &mut [f64], dout: &mut [f64]| unsafe {
1629        let fastk_period = combos[row].fastk_period.unwrap();
1630        let fastd_period = combos[row].fastd_period.unwrap();
1631        let matype = combos[row].fastd_matype.unwrap();
1632        match kern {
1633            Kernel::Scalar => stochf_row_scalar(
1634                high,
1635                low,
1636                close,
1637                first,
1638                fastk_period,
1639                fastd_period,
1640                matype,
1641                kout,
1642                dout,
1643            ),
1644            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1645            Kernel::Avx2 => stochf_row_avx2(
1646                high,
1647                low,
1648                close,
1649                first,
1650                fastk_period,
1651                fastd_period,
1652                matype,
1653                kout,
1654                dout,
1655            ),
1656            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1657            Kernel::Avx512 => stochf_row_avx512(
1658                high,
1659                low,
1660                close,
1661                first,
1662                fastk_period,
1663                fastd_period,
1664                matype,
1665                kout,
1666                dout,
1667            ),
1668            _ => unreachable!(),
1669        }
1670    };
1671    if parallel {
1672        #[cfg(not(target_arch = "wasm32"))]
1673        {
1674            k_out
1675                .par_chunks_mut(cols)
1676                .zip(d_out.par_chunks_mut(cols))
1677                .enumerate()
1678                .for_each(|(row, (k, d))| do_row(row, k, d));
1679        }
1680
1681        #[cfg(target_arch = "wasm32")]
1682        {
1683            for (row, (k, d)) in k_out
1684                .chunks_mut(cols)
1685                .zip(d_out.chunks_mut(cols))
1686                .enumerate()
1687            {
1688                do_row(row, k, d);
1689            }
1690        }
1691    } else {
1692        for (row, (k, d)) in k_out
1693            .chunks_mut(cols)
1694            .zip(d_out.chunks_mut(cols))
1695            .enumerate()
1696        {
1697            do_row(row, k, d);
1698        }
1699    }
1700
1701    let k_vec = unsafe { Vec::from_raw_parts(k_ptr as *mut f64, k_buf_len, k_buf_cap) };
1702    let d_vec = unsafe { Vec::from_raw_parts(d_ptr as *mut f64, d_buf_len, d_buf_cap) };
1703
1704    Ok(StochfBatchOutput {
1705        k: k_vec,
1706        d: d_vec,
1707        combos,
1708        rows,
1709        cols,
1710    })
1711}
1712
/// Computes one batch row (one parameter combination) of Fast Stochastic.
///
/// Fast %K at bar `i` is `100 * (close[i] - LL) / (HH - LL)`, where `HH`/`LL` are the
/// highest high and lowest low over bars `i + 1 - fastk_period ..= i`. NaN highs and
/// lows never become window extremes. When `HH == LL`, %K is 100 if the close equals
/// `HH` and 0 otherwise.
///
/// %K is written for every `i >= first + fastk_period - 1`. Earlier slots of `k_out`,
/// and of `d_out` when `matype == 0`, are left untouched for the caller to pre-fill.
///
/// Fast %D is a `fastd_period` SMA of %K when `matype == 0`. It is NaN until
/// `fastd_period` non-NaN %K values have been seen, and a NaN %K gives a NaN %D at that
/// bar. Any other `matype` fills all of `d_out` with NaN.
///
/// Windows of up to 16 bars are rescanned at every bar. Longer windows use monotone
/// index deques, so each bar is pushed and popped at most once.
///
/// # Safety
/// * `low`, `close`, `k_out` and `d_out` must each be at least `high.len()` long.
/// * `fastk_period >= 1`, `fastd_period >= 1`, and `first + fastk_period - 1` must
///   not overflow.
#[inline(always)]
unsafe fn stochf_row_scalar(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    fastk_period: usize,
    fastd_period: usize,
    matype: usize,
    k_out: &mut [f64],
    d_out: &mut [f64],
) {
    /// Fast %K for close `c` against window extremes `hh` / `ll`.
    #[inline(always)]
    fn fast_k(c: f64, hh: f64, ll: f64) -> f64 {
        let denom = hh - ll;
        if denom == 0.0 {
            // Flat window: 100 when the close sits on that single level, else 0.
            if c == hh {
                100.0
            } else {
                0.0
            }
        } else {
            let inv = 100.0 / denom;
            c.mul_add(inv, (-ll) * inv)
        }
    }

    /// Feeds `kv` (already stored at `k_out[i]`) into the running %K SMA and writes
    /// `d_out[i]`.
    ///
    /// # Safety
    /// `i` must be in bounds for `k_out` and `d_out`, and `i >= fastd_period` once
    /// `*d_cnt == fastd_period`.
    #[inline(always)]
    unsafe fn push_sma_d(
        kv: f64,
        i: usize,
        fastd_period: usize,
        k_out: &[f64],
        d_out: &mut [f64],
        d_sum: &mut f64,
        d_cnt: &mut usize,
    ) {
        let dv = if kv.is_nan() {
            f64::NAN
        } else if *d_cnt < fastd_period {
            *d_sum += kv;
            *d_cnt += 1;
            if *d_cnt == fastd_period {
                *d_sum / (fastd_period as f64)
            } else {
                f64::NAN
            }
        } else {
            // Rolling update: drop the %K value that just left the %D window.
            *d_sum += kv - *k_out.get_unchecked(i - fastd_period);
            *d_sum / (fastd_period as f64)
        };
        *d_out.get_unchecked_mut(i) = dv;
    }

    let len = high.len();

    let hp = high.as_ptr();
    let lp = low.as_ptr();
    let cp = close.as_ptr();

    // First bar whose %K window lies entirely at or after `first`.
    let k_start = first + fastk_period - 1;

    let use_sma_d = matype == 0;
    let mut d_sum: f64 = 0.0;
    let mut d_cnt: usize = 0;

    if fastk_period <= 16 {
        // Short windows: a direct rescan is cheaper than maintaining deques.
        for i in k_start..len {
            let end = i + 1;
            let mut j = end - fastk_period;

            // `x > hh` / `x < ll` are false for NaN, so NaN bars are skipped.
            let mut hh = f64::NEG_INFINITY;
            let mut ll = f64::INFINITY;

            let unroll_end = end - ((end - j) & 3);
            while j < unroll_end {
                for off in 0..4 {
                    let h = *hp.add(j + off);
                    let l = *lp.add(j + off);
                    if h > hh {
                        hh = h;
                    }
                    if l < ll {
                        ll = l;
                    }
                }
                j += 4;
            }
            while j < end {
                let h = *hp.add(j);
                let l = *lp.add(j);
                if h > hh {
                    hh = h;
                }
                if l < ll {
                    ll = l;
                }
                j += 1;
            }

            let kv = fast_k(*cp.add(i), hh, ll);
            *k_out.get_unchecked_mut(i) = kv;
            if use_sma_d {
                push_sma_d(kv, i, fastd_period, k_out, d_out, &mut d_sum, &mut d_cnt);
            }
        }
    } else {
        // Long windows: monotone deques of bar indices stored in ring buffers.
        // `qh` keeps highs non-increasing (front = window max). `ql` keeps lows
        // non-decreasing (front = window min).
        //
        // Expiry runs before the push, so right after a push a deque can hold
        // `fastk_period` live indices (for example on a strictly falling run of
        // highs). A ring buffer that encodes "empty" as `head == tail` stores at most
        // `cap - 1` entries, hence the extra slot. With `cap == fastk_period` the tail
        // would wrap onto the head and the deque would silently empty itself.
        let cap = fastk_period + 1;
        let mut qh = vec![0usize; cap];
        let mut ql = vec![0usize; cap];
        let (mut qh_head, mut qh_tail) = (0usize, 0usize);
        let (mut ql_head, mut ql_tail) = (0usize, 0usize);

        for i in first..len {
            if i + 1 >= fastk_period {
                // Drop indices that fell out of bars i + 1 - fastk_period ..= i.
                let win_start = i + 1 - fastk_period;
                while qh_head != qh_tail && *qh.get_unchecked(qh_head) < win_start {
                    qh_head += 1;
                    if qh_head == cap {
                        qh_head = 0;
                    }
                }
                while ql_head != ql_tail && *ql.get_unchecked(ql_head) < win_start {
                    ql_head += 1;
                    if ql_head == cap {
                        ql_head = 0;
                    }
                }
            }

            let h_i = *hp.add(i);
            if !h_i.is_nan() {
                while qh_head != qh_tail {
                    let back = if qh_tail == 0 { cap - 1 } else { qh_tail - 1 };
                    if *hp.add(*qh.get_unchecked(back)) <= h_i {
                        qh_tail = back;
                    } else {
                        break;
                    }
                }
                *qh.get_unchecked_mut(qh_tail) = i;
                qh_tail += 1;
                if qh_tail == cap {
                    qh_tail = 0;
                }
            }

            let l_i = *lp.add(i);
            if !l_i.is_nan() {
                while ql_head != ql_tail {
                    let back = if ql_tail == 0 { cap - 1 } else { ql_tail - 1 };
                    if *lp.add(*ql.get_unchecked(back)) >= l_i {
                        ql_tail = back;
                    } else {
                        break;
                    }
                }
                *ql.get_unchecked_mut(ql_tail) = i;
                ql_tail += 1;
                if ql_tail == cap {
                    ql_tail = 0;
                }
            }

            if i >= k_start {
                // An empty deque (every bar in the window NaN) yields ±inf extremes.
                let hh = if qh_head != qh_tail {
                    *hp.add(*qh.get_unchecked(qh_head))
                } else {
                    f64::NEG_INFINITY
                };
                let ll = if ql_head != ql_tail {
                    *lp.add(*ql.get_unchecked(ql_head))
                } else {
                    f64::INFINITY
                };

                let kv = fast_k(*cp.add(i), hh, ll);
                *k_out.get_unchecked_mut(i) = kv;
                if use_sma_d {
                    push_sma_d(kv, i, fastd_period, k_out, d_out, &mut d_sum, &mut d_cnt);
                }
            }
        }
    }

    if !use_sma_d {
        d_out.fill(f64::NAN);
    }
}
1963
/// AVX2 batch-row entry point.
///
/// There is no dedicated AVX2 code path yet: the call is forwarded unchanged to
/// `stochf_row_scalar`.
///
/// # Safety
/// Same requirements as `stochf_row_scalar`: `low`, `close` and both output slices
/// must be at least `high.len()` long, and both periods must be at least 1.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
unsafe fn stochf_row_avx2(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    fastk_period: usize,
    fastd_period: usize,
    matype: usize,
    k_out: &mut [f64],
    d_out: &mut [f64],
) {
    stochf_row_scalar(
        high, low, close,
        first, fastk_period, fastd_period, matype,
        k_out, d_out,
    )
}
1989
/// AVX-512 batch-row dispatcher.
///
/// Windows of up to 32 bars go to `stochf_row_avx512_short`; longer ones go to
/// `stochf_row_avx512_long`.
///
/// # Safety
/// Same requirements as `stochf_row_scalar`: `low`, `close` and both output slices
/// must be at least `high.len()` long, and both periods must be at least 1.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn stochf_row_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    fastk_period: usize,
    fastd_period: usize,
    matype: usize,
    k_out: &mut [f64],
    d_out: &mut [f64],
) {
    match fastk_period {
        0..=32 => stochf_row_avx512_short(
            high, low, close,
            first, fastk_period, fastd_period, matype,
            k_out, d_out,
        ),
        _ => stochf_row_avx512_long(
            high, low, close,
            first, fastk_period, fastd_period, matype,
            k_out, d_out,
        ),
    }
}
2029
/// AVX-512 batch-row path for windows of up to 32 bars.
///
/// There is no dedicated AVX-512 code path yet: the call is forwarded unchanged to
/// `stochf_row_scalar`.
///
/// # Safety
/// Same requirements as `stochf_row_scalar`: `low`, `close` and both output slices
/// must be at least `high.len()` long, and both periods must be at least 1.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn stochf_row_avx512_short(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    fastk_period: usize,
    fastd_period: usize,
    matype: usize,
    k_out: &mut [f64],
    d_out: &mut [f64],
) {
    stochf_row_scalar(
        high, low, close,
        first, fastk_period, fastd_period, matype,
        k_out, d_out,
    )
}
2055
/// AVX-512 STOCHF row kernel for long `%K` windows (selected by `stochf_row_avx512` when
/// `fastk_period > 32`).
///
/// Placeholder: the call is passed straight through to `stochf_row_scalar`, so output is
/// identical to the scalar kernel.
///
/// # Safety
/// Same contract as `stochf_row_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn stochf_row_avx512_long(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    fastk_period: usize,
    fastd_period: usize,
    matype: usize,
    k_out: &mut [f64],
    d_out: &mut [f64],
) {
    stochf_row_scalar(high, low, close, first, fastk_period, fastd_period, matype, k_out, d_out)
}
2081
/// Python binding: `stochf(high, low, close, fastk_period=None, fastd_period=None,
/// fastd_matype=None, kernel=None)`.
///
/// Builds a `StochfParams` from the optional arguments (a `None` is passed through
/// unchanged) and runs `stochf_with_kernel` with the GIL released. Returns `(k, d)` as two
/// new float64 NumPy arrays.
///
/// # Errors
/// - Propagates the `as_slice` error for arrays that cannot be viewed as a slice.
/// - Propagates the error from `validate_kernel` for a rejected `kernel`.
/// - Raises `ValueError` when the three arrays differ in length.
/// - Raises `ValueError` carrying the error message when the computation fails.
#[cfg(feature = "python")]
#[pyfunction(name = "stochf")]
#[pyo3(signature = (high, low, close, fastk_period=None, fastd_period=None, fastd_matype=None, kernel=None))]
pub fn stochf_py<'py>(
    py: Python<'py>,
    high: numpy::PyReadonlyArray1<'py, f64>,
    low: numpy::PyReadonlyArray1<'py, f64>,
    close: numpy::PyReadonlyArray1<'py, f64>,
    fastk_period: Option<usize>,
    fastd_period: Option<usize>,
    fastd_matype: Option<usize>,
    kernel: Option<&str>,
) -> PyResult<(
    Bound<'py, numpy::PyArray1<f64>>,
    Bound<'py, numpy::PyArray1<f64>>,
)> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let (h, l, c) = (high.as_slice()?, low.as_slice()?, close.as_slice()?);
    let kern = validate_kernel(kernel, false)?;

    let n = h.len();
    if l.len() != n || c.len() != n {
        return Err(PyValueError::new_err(
            "Input arrays must have the same length",
        ));
    }

    let input = StochfInput::from_slices(
        h,
        l,
        c,
        StochfParams {
            fastk_period,
            fastd_period,
            fastd_matype,
        },
    );

    // Heavy lifting happens without the GIL; only the conversion back to NumPy needs it.
    let computed = py.allow_threads(|| stochf_with_kernel(&input, kern));
    let StochfOutput { k, d } = computed.map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok((k.into_pyarray(py), d.into_pyarray(py)))
}
2124
/// Python class `StochfStream`: a thin wrapper over the Rust `StochfStream` incremental
/// calculator.
#[cfg(feature = "python")]
#[pyclass(name = "StochfStream")]
pub struct StochfStreamPy {
    // Wrapped Rust streaming state; every Python call is forwarded to it.
    stream: StochfStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl StochfStreamPy {
    /// Create a stream. All three parameters are required from Python.
    ///
    /// Raises `ValueError` with the message of the error from `StochfStream::try_new`.
    #[new]
    fn new(fastk_period: usize, fastd_period: usize, fastd_matype: usize) -> PyResult<Self> {
        StochfStream::try_new(StochfParams {
            fastk_period: Some(fastk_period),
            fastd_period: Some(fastd_period),
            fastd_matype: Some(fastd_matype),
        })
        .map(|stream| StochfStreamPy { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feed one bar and return whatever `StochfStream::update` yields: `Some((k, d))` or
    /// `None` (presumably while the stream is still warming up).
    fn update(&mut self, high: f64, low: f64, close: f64) -> Option<(f64, f64)> {
        self.stream.update(high, low, close)
    }
}
2150
/// Python binding: `stochf_batch(high, low, close, fastk_range, fastd_range, kernel=None)`.
///
/// Evaluates STOCHF for every `(fastk_period, fastd_period)` combination that `expand_grid`
/// produces from the two `(start, end, step)` ranges. Results are written straight into
/// freshly allocated NumPy buffers.
///
/// Returns a dict with:
/// - `"k_values"` / `"d_values"`: float64 arrays reshaped to `(rows, cols)`, with one row
///   per parameter combination and `cols` equal to the input length;
/// - `"fastk_periods"` / `"fastd_periods"`: uint64 arrays with each row's periods.
///
/// # Errors
/// Raises `ValueError` when:
/// - the three inputs differ in length;
/// - the ranges expand to no combinations;
/// - `rows * cols` overflows;
/// - `kernel` is rejected by `validate_kernel` or does not resolve to a kernel this binding
///   can dispatch;
/// - the batch computation fails.
#[cfg(feature = "python")]
#[pyfunction(name = "stochf_batch")]
#[pyo3(signature = (high, low, close, fastk_range, fastd_range, kernel=None))]
pub fn stochf_batch_py<'py>(
    py: Python<'py>,
    high: numpy::PyReadonlyArray1<'py, f64>,
    low: numpy::PyReadonlyArray1<'py, f64>,
    close: numpy::PyReadonlyArray1<'py, f64>,
    fastk_range: (usize, usize, usize),
    fastd_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let high_slice = high.as_slice()?;
    let low_slice = low.as_slice()?;
    let close_slice = close.as_slice()?;

    if high_slice.len() != low_slice.len() || high_slice.len() != close_slice.len() {
        return Err(PyValueError::new_err(
            "Input arrays must have the same length",
        ));
    }

    let sweep = StochfBatchRange {
        fastk_period: fastk_range,
        fastd_period: fastd_range,
    };

    // Expanded here only to size the output buffers. The combo list reported back to Python
    // is the one returned by `stochf_batch_inner_into` below.
    let combos = expand_grid(&sweep);
    let rows = combos.len();
    if rows == 0 {
        return Err(PyValueError::new_err(
            "stochf: invalid range (empty expansion)",
        ));
    }
    let cols = high_slice.len();
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("stochf: rows*cols overflow"))?;

    // Resolve the kernel *before* allocating 2 * `total` floats.
    // - A bad `kernel` string now fails without paying for the allocation.
    // - An unexpected kernel becomes a `ValueError` here, instead of an `unreachable!()`
    //   panic inside the GIL-released section.
    // - Plain (non-batch) SIMD kernels are accepted and map to themselves.
    let kern = validate_kernel(kernel, true)?;
    let resolved = match kern {
        Kernel::Auto => detect_best_batch_kernel(),
        k => k,
    };
    let simd = match resolved {
        Kernel::Avx512Batch | Kernel::Avx512 => Kernel::Avx512,
        Kernel::Avx2Batch | Kernel::Avx2 => Kernel::Avx2,
        Kernel::ScalarBatch | Kernel::Scalar => Kernel::Scalar,
        _ => {
            return Err(PyValueError::new_err(
                "stochf: unsupported kernel for batch evaluation",
            ))
        }
    };

    // SAFETY: `PyArray1::new(.., false)` hands back uninitialised storage. Both arrays were
    // just created and are not yet reachable from Python code, so taking exclusive mutable
    // slices is sound.
    // NOTE(review): correctness relies on `stochf_batch_inner_into(.., true, ..)` writing
    // every element (warm-up prefixes included) before the arrays are returned — confirm.
    let k_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let d_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let k_slice = unsafe { k_arr.as_slice_mut()? };
    let d_slice = unsafe { d_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            stochf_batch_inner_into(
                high_slice,
                low_slice,
                close_slice,
                &sweep,
                simd,
                true,
                k_slice,
                d_slice,
            )
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("k_values", k_arr.reshape((rows, cols))?)?;
    dict.set_item("d_values", d_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "fastk_periods",
        combos
            .iter()
            .map(|p| p.fastk_period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "fastd_periods",
        combos
            .iter()
            .map(|p| p.fastd_period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;

    Ok(dict)
}
2247
2248#[cfg(all(feature = "python", feature = "cuda"))]
2249use crate::cuda::{cuda_available, CudaStochf};
2250#[cfg(all(feature = "python", feature = "cuda"))]
2251use numpy::PyReadonlyArray1;
2252#[cfg(all(feature = "python", feature = "cuda"))]
2253use pyo3::exceptions::PyValueError as PyErrValue;
2254#[cfg(all(feature = "python", feature = "cuda"))]
2255use pyo3::PyErr;
2256
/// CUDA binding: `stochf_cuda_batch_dev(high_f32, low_f32, close_f32, fastk_range,
/// fastd_range, device_id=0)`.
///
/// Runs the `(fastk, fastd)` parameter sweep on device `device_id` through `CudaStochf`, with
/// the GIL released. Returns the two device arrays of the result pair, `pair.a` first and
/// `pair.b` second (presumably K then D — confirm against `CudaStochf::stochf_batch_dev`).
/// Each returned handle holds a clone of the context from `context_arc()` plus the device id.
///
/// # Errors
/// `ValueError` if CUDA is unavailable, the inputs differ in length, or creating `CudaStochf`
/// or running the sweep fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "stochf_cuda_batch_dev")]
#[pyo3(signature = (high_f32, low_f32, close_f32, fastk_range, fastd_range, device_id=0))]
pub fn stochf_cuda_batch_dev_py(
    py: Python<'_>,
    high_f32: PyReadonlyArray1<'_, f32>,
    low_f32: PyReadonlyArray1<'_, f32>,
    close_f32: PyReadonlyArray1<'_, f32>,
    fastk_range: (usize, usize, usize),
    fastd_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<(DeviceArrayF32Py, DeviceArrayF32Py)> {
    if !cuda_available() {
        return Err(PyErrValue::new_err("CUDA not available"));
    }

    let (h, l, c) = (high_f32.as_slice()?, low_f32.as_slice()?, close_f32.as_slice()?);
    let n = h.len();
    if l.len() != n || c.len() != n {
        return Err(PyErrValue::new_err("mismatched input lengths"));
    }

    let sweep = StochfBatchRange {
        fastk_period: fastk_range,
        fastd_period: fastd_range,
    };

    let (pair, ctx, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaStochf::new(device_id).map_err(|e| PyErrValue::new_err(e.to_string()))?;
        // Grab the context handle and device id before launching, as the handles need both.
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id();
        let pair = cuda
            .stochf_batch_dev(h, l, c, &sweep)
            .map(|(pair, _combos)| pair)
            .map_err(|e| PyErrValue::new_err(e.to_string()))?;
        Ok((pair, ctx, dev_id))
    })?;

    let first = DeviceArrayF32Py {
        inner: pair.a,
        _ctx: Some(ctx.clone()),
        device_id: Some(dev_id),
    };
    let second = DeviceArrayF32Py {
        inner: pair.b,
        _ctx: Some(ctx),
        device_id: Some(dev_id),
    };
    Ok((first, second))
}
2304
/// CUDA binding: `stochf_cuda_many_series_one_param_dev(high_tm_f32, low_tm_f32,
/// close_tm_f32, cols, rows, fastk, fastd, fastd_matype=0, device_id=0)`.
///
/// Evaluates one STOCHF parameter set over many series laid out time-major in flat f32
/// buffers. The work runs via `CudaStochf::stochf_many_series_one_param_time_major_dev`
/// with the GIL released. Returns `(k, d)` as device-resident arrays; each handle keeps a
/// clone of the CUDA context and records the device id.
///
/// # Errors
/// `ValueError` when:
/// - CUDA is unavailable;
/// - the three inputs differ in length;
/// - `cols * rows` overflows or does not equal the input length;
/// - creating `CudaStochf` or running the kernel fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "stochf_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm_f32, low_tm_f32, close_tm_f32, cols, rows, fastk, fastd, fastd_matype=0, device_id=0))]
pub fn stochf_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    high_tm_f32: PyReadonlyArray1<'_, f32>,
    low_tm_f32: PyReadonlyArray1<'_, f32>,
    close_tm_f32: PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    fastk: usize,
    fastd: usize,
    fastd_matype: usize,
    device_id: usize,
) -> PyResult<(DeviceArrayF32Py, DeviceArrayF32Py)> {
    if !cuda_available() {
        return Err(PyErrValue::new_err("CUDA not available"));
    }
    let htm = high_tm_f32.as_slice()?;
    let ltm = low_tm_f32.as_slice()?;
    let ctm = close_tm_f32.as_slice()?;

    // Same guard as `stochf_cuda_batch_dev`: the three buffers must agree in size...
    if htm.len() != ltm.len() || htm.len() != ctm.len() {
        return Err(PyErrValue::new_err("mismatched input lengths"));
    }
    // ...and a `cols` x `rows` matrix flattened into one buffer holds exactly `cols * rows`
    // values. Rejecting mismatches here gives a clear Python error before any device work.
    let expected = cols
        .checked_mul(rows)
        .ok_or_else(|| PyErrValue::new_err("stochf: cols*rows overflow"))?;
    if htm.len() != expected {
        return Err(PyErrValue::new_err(format!(
            "stochf: expected cols*rows = {} elements per input, got {}",
            expected,
            htm.len()
        )));
    }

    let params = StochfParams {
        fastk_period: Some(fastk),
        fastd_period: Some(fastd),
        fastd_matype: Some(fastd_matype),
    };
    let (k, d, ctx, dev_id) = py.allow_threads(|| {
        let cuda = CudaStochf::new(device_id).map_err(|e| PyErrValue::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id();
        let (k, d) = cuda
            .stochf_many_series_one_param_time_major_dev(htm, ltm, ctm, cols, rows, &params)
            .map_err(|e| PyErrValue::new_err(e.to_string()))?;
        Ok::<_, PyErr>((k, d, ctx, dev_id))
    })?;
    Ok((
        DeviceArrayF32Py {
            inner: k,
            _ctx: Some(ctx.clone()),
            device_id: Some(dev_id),
        },
        DeviceArrayF32Py {
            inner: d,
            _ctx: Some(ctx),
            device_id: Some(dev_id),
        },
    ))
}
2353
2354#[cfg(test)]
2355mod tests {
2356    use super::*;
2357    use crate::skip_if_unsupported;
2358    use crate::utilities::data_loader::read_candles_from_csv;
2359
2360    fn check_stochf_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2361        skip_if_unsupported!(kernel, test_name);
2362        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2363        let candles = read_candles_from_csv(file_path)?;
2364        let params = StochfParams {
2365            fastk_period: None,
2366            fastd_period: None,
2367            fastd_matype: None,
2368        };
2369        let input = StochfInput::from_candles(&candles, params);
2370        let output = stochf_with_kernel(&input, kernel)?;
2371        assert_eq!(output.k.len(), candles.close.len());
2372        assert_eq!(output.d.len(), candles.close.len());
2373        Ok(())
2374    }
2375
2376    fn check_stochf_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2377        skip_if_unsupported!(kernel, test_name);
2378        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2379        let candles = read_candles_from_csv(file_path)?;
2380        let params = StochfParams {
2381            fastk_period: Some(5),
2382            fastd_period: Some(3),
2383            fastd_matype: Some(0),
2384        };
2385        let input = StochfInput::from_candles(&candles, params);
2386        let output = stochf_with_kernel(&input, kernel)?;
2387        let expected_k = [
2388            80.6987399770905,
2389            40.88471849865952,
2390            15.507246376811594,
2391            36.920529801324506,
2392            32.1880650994575,
2393        ];
2394        let expected_d = [
2395            70.99960994145033,
2396            61.44725644908976,
2397            45.696901617520815,
2398            31.104164892265487,
2399            28.205280425864817,
2400        ];
2401        let k_slice = &output.k[output.k.len() - 5..];
2402        let d_slice = &output.d[output.d.len() - 5..];
2403        for i in 0..5 {
2404            assert!(
2405                (k_slice[i] - expected_k[i]).abs() < 1e-4,
2406                "K mismatch at idx {}",
2407                i
2408            );
2409            assert!(
2410                (d_slice[i] - expected_d[i]).abs() < 1e-4,
2411                "D mismatch at idx {}",
2412                i
2413            );
2414        }
2415        Ok(())
2416    }
2417
2418    fn check_stochf_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2419        skip_if_unsupported!(kernel, test_name);
2420        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2421        let candles = read_candles_from_csv(file_path)?;
2422        let input = StochfInput::with_default_candles(&candles);
2423        let output = stochf_with_kernel(&input, kernel)?;
2424        assert_eq!(output.k.len(), candles.close.len());
2425        assert_eq!(output.d.len(), candles.close.len());
2426        Ok(())
2427    }
2428
2429    fn check_stochf_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2430        skip_if_unsupported!(kernel, test_name);
2431        let data = [10.0, 20.0, 30.0, 40.0, 50.0];
2432        let params = StochfParams {
2433            fastk_period: Some(0),
2434            fastd_period: Some(3),
2435            fastd_matype: Some(0),
2436        };
2437        let input = StochfInput::from_slices(&data, &data, &data, params);
2438        let res = stochf_with_kernel(&input, kernel);
2439        assert!(res.is_err());
2440        Ok(())
2441    }
2442
2443    fn check_stochf_period_exceeds_length(
2444        test_name: &str,
2445        kernel: Kernel,
2446    ) -> Result<(), Box<dyn Error>> {
2447        skip_if_unsupported!(kernel, test_name);
2448        let data = [10.0, 20.0, 30.0];
2449        let params = StochfParams {
2450            fastk_period: Some(10),
2451            fastd_period: Some(3),
2452            fastd_matype: Some(0),
2453        };
2454        let input = StochfInput::from_slices(&data, &data, &data, params);
2455        let res = stochf_with_kernel(&input, kernel);
2456        assert!(res.is_err());
2457        Ok(())
2458    }
2459
2460    fn check_stochf_very_small_dataset(
2461        test_name: &str,
2462        kernel: Kernel,
2463    ) -> Result<(), Box<dyn Error>> {
2464        skip_if_unsupported!(kernel, test_name);
2465        let data = [42.0];
2466        let params = StochfParams {
2467            fastk_period: Some(9),
2468            fastd_period: Some(3),
2469            fastd_matype: Some(0),
2470        };
2471        let input = StochfInput::from_slices(&data, &data, &data, params);
2472        let res = stochf_with_kernel(&input, kernel);
2473        assert!(res.is_err());
2474        Ok(())
2475    }
2476
2477    fn check_stochf_slice_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2478        skip_if_unsupported!(kernel, test_name);
2479        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2480        let candles = read_candles_from_csv(file_path)?;
2481        let params = StochfParams {
2482            fastk_period: Some(5),
2483            fastd_period: Some(3),
2484            fastd_matype: Some(0),
2485        };
2486        let input1 = StochfInput::from_candles(&candles, params.clone());
2487        let res1 = stochf_with_kernel(&input1, kernel)?;
2488        let input2 = StochfInput::from_slices(&res1.k, &res1.k, &res1.k, params);
2489        let res2 = stochf_with_kernel(&input2, kernel)?;
2490        assert_eq!(res2.k.len(), res1.k.len());
2491        assert_eq!(res2.d.len(), res1.d.len());
2492        Ok(())
2493    }
2494
2495    #[cfg(debug_assertions)]
2496    fn check_stochf_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2497        skip_if_unsupported!(kernel, test_name);
2498
2499        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2500        let candles = read_candles_from_csv(file_path)?;
2501
2502        let test_params = vec![
2503            StochfParams::default(),
2504            StochfParams {
2505                fastk_period: Some(2),
2506                fastd_period: Some(1),
2507                fastd_matype: Some(0),
2508            },
2509            StochfParams {
2510                fastk_period: Some(3),
2511                fastd_period: Some(2),
2512                fastd_matype: Some(0),
2513            },
2514            StochfParams {
2515                fastk_period: Some(5),
2516                fastd_period: Some(5),
2517                fastd_matype: Some(0),
2518            },
2519            StochfParams {
2520                fastk_period: Some(10),
2521                fastd_period: Some(3),
2522                fastd_matype: Some(0),
2523            },
2524            StochfParams {
2525                fastk_period: Some(14),
2526                fastd_period: Some(3),
2527                fastd_matype: Some(0),
2528            },
2529            StochfParams {
2530                fastk_period: Some(20),
2531                fastd_period: Some(5),
2532                fastd_matype: Some(0),
2533            },
2534            StochfParams {
2535                fastk_period: Some(50),
2536                fastd_period: Some(10),
2537                fastd_matype: Some(0),
2538            },
2539            StochfParams {
2540                fastk_period: Some(100),
2541                fastd_period: Some(20),
2542                fastd_matype: Some(0),
2543            },
2544            StochfParams {
2545                fastk_period: Some(8),
2546                fastd_period: Some(7),
2547                fastd_matype: Some(0),
2548            },
2549        ];
2550
2551        for (param_idx, params) in test_params.iter().enumerate() {
2552            let input = StochfInput::from_candles(&candles, params.clone());
2553            let output = stochf_with_kernel(&input, kernel)?;
2554
2555            for (i, &val) in output.k.iter().enumerate() {
2556                if val.is_nan() {
2557                    continue;
2558                }
2559
2560                let bits = val.to_bits();
2561
2562                if bits == 0x11111111_11111111 {
2563                    panic!(
2564						"[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at K index {} \
2565						 with params: fastk_period={}, fastd_period={}, fastd_matype={} (param set {})",
2566						test_name, val, bits, i,
2567						params.fastk_period.unwrap_or(5),
2568						params.fastd_period.unwrap_or(3),
2569						params.fastd_matype.unwrap_or(0),
2570						param_idx
2571					);
2572                }
2573
2574                if bits == 0x22222222_22222222 {
2575                    panic!(
2576						"[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at K index {} \
2577						 with params: fastk_period={}, fastd_period={}, fastd_matype={} (param set {})",
2578						test_name, val, bits, i,
2579						params.fastk_period.unwrap_or(5),
2580						params.fastd_period.unwrap_or(3),
2581						params.fastd_matype.unwrap_or(0),
2582						param_idx
2583					);
2584                }
2585
2586                if bits == 0x33333333_33333333 {
2587                    panic!(
2588                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at K index {} \
2589						 with params: fastk_period={}, fastd_period={}, fastd_matype={} (param set {})",
2590                        test_name,
2591                        val,
2592                        bits,
2593                        i,
2594                        params.fastk_period.unwrap_or(5),
2595                        params.fastd_period.unwrap_or(3),
2596                        params.fastd_matype.unwrap_or(0),
2597                        param_idx
2598                    );
2599                }
2600            }
2601
2602            for (i, &val) in output.d.iter().enumerate() {
2603                if val.is_nan() {
2604                    continue;
2605                }
2606
2607                let bits = val.to_bits();
2608
2609                if bits == 0x11111111_11111111 {
2610                    panic!(
2611						"[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at D index {} \
2612						 with params: fastk_period={}, fastd_period={}, fastd_matype={} (param set {})",
2613						test_name, val, bits, i,
2614						params.fastk_period.unwrap_or(5),
2615						params.fastd_period.unwrap_or(3),
2616						params.fastd_matype.unwrap_or(0),
2617						param_idx
2618					);
2619                }
2620
2621                if bits == 0x22222222_22222222 {
2622                    panic!(
2623						"[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at D index {} \
2624						 with params: fastk_period={}, fastd_period={}, fastd_matype={} (param set {})",
2625						test_name, val, bits, i,
2626						params.fastk_period.unwrap_or(5),
2627						params.fastd_period.unwrap_or(3),
2628						params.fastd_matype.unwrap_or(0),
2629						param_idx
2630					);
2631                }
2632
2633                if bits == 0x33333333_33333333 {
2634                    panic!(
2635                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at D index {} \
2636						 with params: fastk_period={}, fastd_period={}, fastd_matype={} (param set {})",
2637                        test_name,
2638                        val,
2639                        bits,
2640                        i,
2641                        params.fastk_period.unwrap_or(5),
2642                        params.fastd_period.unwrap_or(3),
2643                        params.fastd_matype.unwrap_or(0),
2644                        param_idx
2645                    );
2646                }
2647            }
2648        }
2649
2650        Ok(())
2651    }
2652
2653    #[cfg(not(debug_assertions))]
2654    fn check_stochf_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2655        Ok(())
2656    }
2657
2658    macro_rules! generate_all_stochf_tests {
2659        ($($test_fn:ident),*) => {
2660            paste::paste! {
2661                $(
2662                    #[test]
2663                    fn [<$test_fn _scalar_f64>]() {
2664                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
2665                    }
2666                )*
2667                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2668                $(
2669                    #[test]
2670                    fn [<$test_fn _avx2_f64>]() {
2671                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
2672                    }
2673                    #[test]
2674                    fn [<$test_fn _avx512_f64>]() {
2675                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
2676                    }
2677                )*
2678            }
2679        }
2680    }
2681
2682    generate_all_stochf_tests!(
2683        check_stochf_partial_params,
2684        check_stochf_accuracy,
2685        check_stochf_default_candles,
2686        check_stochf_zero_period,
2687        check_stochf_period_exceeds_length,
2688        check_stochf_very_small_dataset,
2689        check_stochf_slice_reinput,
2690        check_stochf_no_poison
2691    );
2692
2693    #[cfg(feature = "proptest")]
2694    generate_all_stochf_tests!(check_stochf_property);
2695
2696    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2697        skip_if_unsupported!(kernel, test);
2698        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2699        let c = read_candles_from_csv(file)?;
2700        let output = StochfBatchBuilder::new().kernel(kernel).apply_candles(&c)?;
2701        let def = StochfParams::default();
2702        let krow = output.k_for(&def).expect("default row missing");
2703        let drow = output.d_for(&def).expect("default row missing");
2704        assert_eq!(krow.len(), c.close.len());
2705        assert_eq!(drow.len(), c.close.len());
2706        Ok(())
2707    }
2708
2709    #[cfg(debug_assertions)]
2710    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2711        skip_if_unsupported!(kernel, test);
2712
2713        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2714        let c = read_candles_from_csv(file)?;
2715
2716        let test_configs = vec![
2717            (2, 10, 2, 1, 5, 1),
2718            (5, 25, 5, 3, 3, 0),
2719            (30, 60, 15, 5, 15, 5),
2720            (2, 5, 1, 1, 3, 1),
2721            (10, 20, 2, 3, 9, 3),
2722            (14, 14, 0, 1, 7, 2),
2723            (3, 12, 3, 2, 2, 0),
2724            (50, 100, 25, 10, 20, 10),
2725        ];
2726
2727        for (cfg_idx, &(fk_start, fk_end, fk_step, fd_start, fd_end, fd_step)) in
2728            test_configs.iter().enumerate()
2729        {
2730            let output = StochfBatchBuilder::new()
2731                .kernel(kernel)
2732                .fastk_range(fk_start, fk_end, fk_step)
2733                .fastd_range(fd_start, fd_end, fd_step)
2734                .apply_candles(&c)?;
2735
2736            for (idx, &val) in output.k.iter().enumerate() {
2737                if val.is_nan() {
2738                    continue;
2739                }
2740
2741                let bits = val.to_bits();
2742                let row = idx / output.cols;
2743                let col = idx % output.cols;
2744                let combo = &output.combos[row];
2745
2746                if bits == 0x11111111_11111111 {
2747                    panic!(
2748                        "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
2749						 at K row {} col {} (flat index {}) with params: fastk={}, fastd={}, matype={}",
2750                        test,
2751                        cfg_idx,
2752                        val,
2753                        bits,
2754                        row,
2755                        col,
2756                        idx,
2757                        combo.fastk_period.unwrap_or(5),
2758                        combo.fastd_period.unwrap_or(3),
2759                        combo.fastd_matype.unwrap_or(0)
2760                    );
2761                }
2762
2763                if bits == 0x22222222_22222222 {
2764                    panic!(
2765                        "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
2766						 at K row {} col {} (flat index {}) with params: fastk={}, fastd={}, matype={}",
2767                        test,
2768                        cfg_idx,
2769                        val,
2770                        bits,
2771                        row,
2772                        col,
2773                        idx,
2774                        combo.fastk_period.unwrap_or(5),
2775                        combo.fastd_period.unwrap_or(3),
2776                        combo.fastd_matype.unwrap_or(0)
2777                    );
2778                }
2779
2780                if bits == 0x33333333_33333333 {
2781                    panic!(
2782                        "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
2783						 at K row {} col {} (flat index {}) with params: fastk={}, fastd={}, matype={}",
2784                        test,
2785                        cfg_idx,
2786                        val,
2787                        bits,
2788                        row,
2789                        col,
2790                        idx,
2791                        combo.fastk_period.unwrap_or(5),
2792                        combo.fastd_period.unwrap_or(3),
2793                        combo.fastd_matype.unwrap_or(0)
2794                    );
2795                }
2796            }
2797
2798            for (idx, &val) in output.d.iter().enumerate() {
2799                if val.is_nan() {
2800                    continue;
2801                }
2802
2803                let bits = val.to_bits();
2804                let row = idx / output.cols;
2805                let col = idx % output.cols;
2806                let combo = &output.combos[row];
2807
2808                if bits == 0x11111111_11111111 {
2809                    panic!(
2810                        "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
2811						 at D row {} col {} (flat index {}) with params: fastk={}, fastd={}, matype={}",
2812                        test,
2813                        cfg_idx,
2814                        val,
2815                        bits,
2816                        row,
2817                        col,
2818                        idx,
2819                        combo.fastk_period.unwrap_or(5),
2820                        combo.fastd_period.unwrap_or(3),
2821                        combo.fastd_matype.unwrap_or(0)
2822                    );
2823                }
2824
2825                if bits == 0x22222222_22222222 {
2826                    panic!(
2827                        "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
2828						 at D row {} col {} (flat index {}) with params: fastk={}, fastd={}, matype={}",
2829                        test,
2830                        cfg_idx,
2831                        val,
2832                        bits,
2833                        row,
2834                        col,
2835                        idx,
2836                        combo.fastk_period.unwrap_or(5),
2837                        combo.fastd_period.unwrap_or(3),
2838                        combo.fastd_matype.unwrap_or(0)
2839                    );
2840                }
2841
2842                if bits == 0x33333333_33333333 {
2843                    panic!(
2844                        "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
2845						 at D row {} col {} (flat index {}) with params: fastk={}, fastd={}, matype={}",
2846                        test,
2847                        cfg_idx,
2848                        val,
2849                        bits,
2850                        row,
2851                        col,
2852                        idx,
2853                        combo.fastk_period.unwrap_or(5),
2854                        combo.fastd_period.unwrap_or(3),
2855                        combo.fastd_matype.unwrap_or(0)
2856                    );
2857                }
2858            }
2859        }
2860
2861        Ok(())
2862    }
2863
2864    #[cfg(not(debug_assertions))]
2865    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2866        Ok(())
2867    }
2868
    /// Property test for single-series StochF on randomly generated, well-formed OHLC bars.
    ///
    /// For each generated case it checks:
    /// - every non-NaN K and D lies in `[0, 100]` (with `1e-9` slack);
    /// - K is NaN for the first `fastk_period - 1` bars, and D is NaN for the first
    ///   `fastk_period - 1 + fastd_period - 1` bars;
    /// - `kernel` agrees with `Kernel::Scalar` to `1e-9`;
    /// - K matches a brute-force `100 * (close - LL) / (HH - LL)`;
    /// - D matches a simple average of the last `fastd_period` K values (`fastd_matype = 0`);
    /// - the hand-built edge cases hold: constant price gives K = 100, close == low gives
    ///   K = 0, and close == high gives K = 100;
    /// - in debug builds, no allocator poison bit patterns leak into K or D.
    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_stochf_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        // Strategy: fastk in [2, 50]; base prices and close positions (fraction of the bar's
        // range) each get at least `fastk_period + 50` samples. fastd is in [1, 10], and
        // volatility sets the bar spread as a fraction of the base price.
        let strat = (2usize..=50).prop_flat_map(|fastk_period| {
            (
                prop::collection::vec(
                    (1.0f64..1000.0f64).prop_filter("finite", |x| x.is_finite()),
                    fastk_period + 50..400,
                ),
                prop::collection::vec(0.0f64..=1.0f64, fastk_period + 50..400),
                Just(fastk_period),
                1usize..=10,
                0.001f64..0.1f64,
            )
        });

        proptest::test_runner::TestRunner::default()
            .run(
                &strat,
                |(base_prices, close_positions, fastk_period, fastd_period, volatility)| {
                    // The two vectors are drawn independently, so use the shorter length.
                    let len = base_prices.len().min(close_positions.len());
                    let mut high = Vec::with_capacity(len);
                    let mut low = Vec::with_capacity(len);
                    let mut close = Vec::with_capacity(len);

                    // Build bars centred on `base`, with close interpolated inside [low, high],
                    // so that low <= close <= high always holds.
                    for i in 0..len {
                        let base = base_prices[i];
                        let spread = base * volatility;
                        let h = base + spread * 0.5;
                        let l = base - spread * 0.5;

                        let c = l + (h - l) * close_positions[i];

                        high.push(h);
                        low.push(l);
                        close.push(c);
                    }

                    let params = StochfParams {
                        fastk_period: Some(fastk_period),
                        fastd_period: Some(fastd_period),
                        fastd_matype: Some(0),
                    };
                    let input = StochfInput::from_slices(&high, &low, &close, params.clone());

                    let output = stochf_with_kernel(&input, kernel).unwrap();
                    let ref_output = stochf_with_kernel(&input, Kernel::Scalar).unwrap();

                    // Range invariant for K.
                    for (i, &k_val) in output.k.iter().enumerate() {
                        if !k_val.is_nan() {
                            prop_assert!(
                                k_val >= -1e-9 && k_val <= 100.0 + 1e-9,
                                "K value out of range at idx {}: {} (should be in [0, 100])",
                                i,
                                k_val
                            );
                        }
                    }

                    // Range invariant for D.
                    for (i, &d_val) in output.d.iter().enumerate() {
                        if !d_val.is_nan() {
                            prop_assert!(
                                d_val >= -1e-9 && d_val <= 100.0 + 1e-9,
                                "D value out of range at idx {}: {} (should be in [0, 100])",
                                i,
                                d_val
                            );
                        }
                    }

                    // Expected warmup lengths: K needs a full fastk window; D additionally
                    // needs fastd - 1 further K values.
                    let k_warmup = fastk_period - 1;
                    let d_warmup = fastk_period - 1 + fastd_period - 1;

                    for i in 0..k_warmup.min(len) {
                        prop_assert!(
                            output.k[i].is_nan(),
                            "K value should be NaN during warmup at idx {}: {}",
                            i,
                            output.k[i]
                        );
                    }

                    for i in 0..d_warmup.min(len) {
                        prop_assert!(
                            output.d[i].is_nan(),
                            "D value should be NaN during warmup at idx {}: {}",
                            i,
                            output.d[i]
                        );
                    }

                    // Kernel parity against the scalar reference (only where both are finite).
                    for i in 0..len {
                        let k_val = output.k[i];
                        let k_ref = ref_output.k[i];
                        let d_val = output.d[i];
                        let d_ref = ref_output.d[i];

                        if !k_val.is_nan() && !k_ref.is_nan() {
                            prop_assert!(
                                (k_val - k_ref).abs() <= 1e-9,
                                "K kernel mismatch at idx {}: {} vs {} (diff: {})",
                                i,
                                k_val,
                                k_ref,
                                (k_val - k_ref).abs()
                            );
                        }

                        if !d_val.is_nan() && !d_ref.is_nan() {
                            prop_assert!(
                                (d_val - d_ref).abs() <= 1e-9,
                                "D kernel mismatch at idx {}: {} vs {} (diff: {})",
                                i,
                                d_val,
                                d_ref,
                                (d_val - d_ref).abs()
                            );
                        }
                    }

                    // Brute-force K: O(fastk) window scan per bar. The hh == ll branch mirrors
                    // the flat-window convention (100 if close sits on the level, else 0).
                    for i in k_warmup..len {
                        let start = i + 1 - fastk_period;
                        let window_high = &high[start..=i];
                        let window_low = &low[start..=i];

                        let hh = window_high
                            .iter()
                            .cloned()
                            .fold(f64::NEG_INFINITY, f64::max);
                        let ll = window_low.iter().cloned().fold(f64::INFINITY, f64::min);

                        let expected_k = if hh == ll {
                            if close[i] == hh {
                                100.0
                            } else {
                                0.0
                            }
                        } else {
                            100.0 * (close[i] - ll) / (hh - ll)
                        };

                        let actual_k = output.k[i];
                        prop_assert!(
                            (actual_k - expected_k).abs() <= 1e-9,
                            "K formula mismatch at idx {}: actual {} vs expected {} (diff: {})",
                            i,
                            actual_k,
                            expected_k,
                            (actual_k - expected_k).abs()
                        );
                    }

                    // With fastd_matype = 0, D is expected to be the plain mean of the last
                    // fastd_period K values.
                    for i in d_warmup..len {
                        let start = i + 1 - fastd_period;
                        let k_window = &output.k[start..=i];
                        let expected_d = k_window.iter().sum::<f64>() / (fastd_period as f64);
                        let actual_d = output.d[i];

                        prop_assert!(
                            (actual_d - expected_d).abs() <= 1e-9,
                            "D SMA mismatch at idx {}: actual {} vs expected {} (diff: {})",
                            i,
                            actual_d,
                            expected_d,
                            (actual_d - expected_d).abs()
                        );
                    }

                    // Edge case: a perfectly flat series. The synthetic series does not depend
                    // on `len`; the guard only limits when this extra case runs.
                    let const_len = (fastk_period + fastd_period) * 2;
                    if len > const_len {
                        let const_price = 100.0;
                        let const_high = vec![const_price; const_len];
                        let const_low = vec![const_price; const_len];
                        let const_close = vec![const_price; const_len];

                        let const_input = StochfInput::from_slices(
                            &const_high,
                            &const_low,
                            &const_close,
                            params.clone(),
                        );
                        let const_output = stochf_with_kernel(&const_input, kernel).unwrap();

                        for i in k_warmup..const_high.len() {
                            prop_assert!(
                                (const_output.k[i] - 100.0).abs() <= 1e-9,
                                "Constant price K should be 100 at idx {}: {}",
                                i,
                                const_output.k[i]
                            );
                        }
                    }

                    // Edge cases: close pinned to the low (K = 0) and to the high (K = 100).
                    let extreme_len = (fastk_period + fastd_period) * 2;
                    if len > extreme_len {
                        let low_close_high = vec![100.0; extreme_len];
                        let low_close_low = vec![90.0; extreme_len];
                        let low_close_close = vec![90.0; extreme_len];

                        let low_input = StochfInput::from_slices(
                            &low_close_high,
                            &low_close_low,
                            &low_close_close,
                            params.clone(),
                        );
                        let low_output = stochf_with_kernel(&low_input, kernel).unwrap();

                        for i in k_warmup..low_close_high.len() {
                            prop_assert!(
                                low_output.k[i].abs() <= 1e-9,
                                "When close == low, K should be 0 at idx {}: {}",
                                i,
                                low_output.k[i]
                            );
                        }

                        let high_close_high = vec![100.0; extreme_len];
                        let high_close_low = vec![90.0; extreme_len];
                        let high_close_close = vec![100.0; extreme_len];

                        let high_input = StochfInput::from_slices(
                            &high_close_high,
                            &high_close_low,
                            &high_close_close,
                            params.clone(),
                        );
                        let high_output = stochf_with_kernel(&high_input, kernel).unwrap();

                        for i in k_warmup..high_close_high.len() {
                            prop_assert!(
                                (high_output.k[i] - 100.0).abs() <= 1e-9,
                                "When close == high, K should be 100 at idx {}: {}",
                                i,
                                high_output.k[i]
                            );
                        }
                    }

                    // Debug builds only: make sure none of the 0x11.., 0x22.., 0x33.. poison
                    // patterns survived into the outputs.
                    #[cfg(debug_assertions)]
                    {
                        for (i, &val) in output.k.iter().enumerate() {
                            if !val.is_nan() {
                                let bits = val.to_bits();
                                prop_assert!(
                                    bits != 0x11111111_11111111
                                        && bits != 0x22222222_22222222
                                        && bits != 0x33333333_33333333,
                                    "Found poison value in K at idx {}: {} (0x{:016X})",
                                    i,
                                    val,
                                    bits
                                );
                            }
                        }

                        for (i, &val) in output.d.iter().enumerate() {
                            if !val.is_nan() {
                                let bits = val.to_bits();
                                prop_assert!(
                                    bits != 0x11111111_11111111
                                        && bits != 0x22222222_22222222
                                        && bits != 0x33333333_33333333,
                                    "Found poison value in D at idx {}: {} (0x{:016X})",
                                    i,
                                    val,
                                    bits
                                );
                            }
                        }
                    }

                    Ok(())
                },
            )
            .unwrap();

        Ok(())
    }
3153
3154    macro_rules! gen_batch_tests {
3155        ($fn_name:ident) => {
3156            paste::paste! {
3157                #[test] fn [<$fn_name _scalar>]()      {
3158                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
3159                }
3160                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3161                #[test] fn [<$fn_name _avx2>]()        {
3162                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
3163                }
3164                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3165                #[test] fn [<$fn_name _avx512>]()      {
3166                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
3167                }
3168                #[test] fn [<$fn_name _auto_detect>]() {
3169                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
3170                }
3171            }
3172        };
3173    }
3174    #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
3175    #[test]
3176    fn test_wasm_batch_warmup_initialization() {
3177        let high = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
3178        let low = vec![0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5];
3179        let close = vec![0.8, 1.8, 2.8, 3.8, 4.8, 5.8, 6.8, 7.8, 8.8, 9.8];
3180
3181        let mut k_out = vec![999.0; 10];
3182        let mut d_out = vec![999.0; 10];
3183
3184        let result = unsafe {
3185            stochf_batch_into(
3186                high.as_ptr(),
3187                low.as_ptr(),
3188                close.as_ptr(),
3189                k_out.as_mut_ptr(),
3190                d_out.as_mut_ptr(),
3191                10,
3192                3,
3193                3,
3194                0,
3195                2,
3196                2,
3197                0,
3198                0,
3199            )
3200        };
3201
3202        assert!(result.is_ok());
3203        assert_eq!(result.unwrap(), 1);
3204
3205        assert!(k_out[0].is_nan(), "K[0] should be NaN");
3206        assert!(k_out[1].is_nan(), "K[1] should be NaN");
3207        assert!(!k_out[2].is_nan(), "K[2] should have a value");
3208
3209        assert!(d_out[0].is_nan(), "D[0] should be NaN");
3210        assert!(d_out[1].is_nan(), "D[1] should be NaN");
3211        assert!(d_out[2].is_nan(), "D[2] should be NaN");
3212        assert!(!d_out[3].is_nan(), "D[3] should have a value");
3213    }
3214
3215    #[test]
3216    fn test_batch_invalid_output_size() {
3217        let high = vec![10.0, 20.0, 30.0, 40.0, 50.0];
3218        let low = vec![5.0, 15.0, 25.0, 35.0, 45.0];
3219        let close = vec![7.0, 17.0, 27.0, 37.0, 47.0];
3220
3221        let sweep = StochfBatchRange {
3222            fastk_period: (3, 4, 1),
3223            fastd_period: (2, 2, 0),
3224        };
3225
3226        let mut k_out = vec![0.0; 5];
3227        let mut d_out = vec![0.0; 5];
3228
3229        let result = stochf_batch_inner_into(
3230            &high,
3231            &low,
3232            &close,
3233            &sweep,
3234            Kernel::Scalar,
3235            false,
3236            &mut k_out,
3237            &mut d_out,
3238        );
3239
3240        assert!(matches!(
3241            result,
3242            Err(StochfError::OutputLengthMismatch {
3243                expected: 10,
3244                k_got: 5,
3245                d_got: 5
3246            })
3247        ));
3248
3249        let mut k_out = vec![0.0; 10];
3250        let mut d_out = vec![0.0; 8];
3251
3252        let result = stochf_batch_inner_into(
3253            &high,
3254            &low,
3255            &close,
3256            &sweep,
3257            Kernel::Scalar,
3258            false,
3259            &mut k_out,
3260            &mut d_out,
3261        );
3262
3263        assert!(matches!(
3264            result,
3265            Err(StochfError::OutputLengthMismatch {
3266                expected: 10,
3267                k_got: 10,
3268                d_got: 8
3269            })
3270        ));
3271
3272        let mut k_out = vec![0.0; 10];
3273        let mut d_out = vec![0.0; 10];
3274
3275        let result = stochf_batch_inner_into(
3276            &high,
3277            &low,
3278            &close,
3279            &sweep,
3280            Kernel::Scalar,
3281            false,
3282            &mut k_out,
3283            &mut d_out,
3284        );
3285
3286        assert!(result.is_ok());
3287    }
3288
3289    #[test]
3290    fn test_stochf_into_matches_api() {
3291        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3292        let candles = read_candles_from_csv(file_path).expect("failed to read csv");
3293        let input = StochfInput::with_default_candles(&candles);
3294
3295        let base = stochf(&input).expect("baseline stochf failed");
3296
3297        let len = candles.close.len();
3298        let mut k_out = vec![0.0f64; len];
3299        let mut d_out = vec![0.0f64; len];
3300        stochf_into(&input, &mut k_out, &mut d_out).expect("stochf_into failed");
3301
3302        assert_eq!(base.k.len(), k_out.len());
3303        assert_eq!(base.d.len(), d_out.len());
3304
3305        fn eq_or_nan(a: f64, b: f64) -> bool {
3306            (a.is_nan() && b.is_nan()) || (a == b)
3307        }
3308
3309        for i in 0..len {
3310            assert!(
3311                eq_or_nan(base.k[i], k_out[i]),
3312                "K mismatch at {}: base={:?} into={:?}",
3313                i,
3314                base.k[i],
3315                k_out[i]
3316            );
3317            assert!(
3318                eq_or_nan(base.d[i], d_out[i]),
3319                "D mismatch at {}: base={:?} into={:?}",
3320                i,
3321                base.d[i],
3322                d_out[i]
3323            );
3324        }
3325    }
3326
    // Expand per-kernel #[test] wrappers (scalar / AVX2 / AVX512 / auto) for each batch check.
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
3329}
3330
/// JS entry point: computes Fast Stochastic K and D for one parameter set.
///
/// The result is flattened into a single array of length `2 * close.len()`: all K values
/// first, followed by all D values. Kernel selection is automatic (`Kernel::Auto`).
///
/// # Errors
/// Any `StochfError` is surfaced to JS as a string `JsValue`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn stochf_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    fastk_period: usize,
    fastd_period: usize,
    fastd_matype: usize,
) -> Result<Vec<f64>, JsValue> {
    let input = StochfInput::from_slices(
        high,
        low,
        close,
        StochfParams {
            fastk_period: Some(fastk_period),
            fastd_period: Some(fastd_period),
            fastd_matype: Some(fastd_matype),
        },
    );

    let StochfOutput { k, d } = stochf_with_kernel(&input, Kernel::Auto)
        .map_err(|err| JsValue::from_str(&err.to_string()))?;

    // [K..., D...]
    Ok([k, d].concat())
}
3356
/// Allocates an uninitialised buffer with room for `len` `f64` values and leaks it to the
/// caller.
///
/// The returned pointer is meant to be released with `stochf_free(ptr, len)` using the same
/// `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn stochf_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop suppresses the destructor, so ownership of the allocation passes to the caller.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
3365
/// Releases a buffer previously returned by `stochf_alloc(len)`.
///
/// A null `ptr` is ignored. Passing a pointer that did not come from `stochf_alloc` with
/// the same `len` is undefined behaviour.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn stochf_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` was produced by `Vec::<f64>::with_capacity(len)` in `stochf_alloc`, so
    // rebuilding a Vec with capacity `len` matches the original allocation. The length is 0
    // because the buffer may never have been written. Claiming `len` initialised elements
    // would break `Vec`'s invariant, and `f64` has no destructor, so only the allocation needs
    // releasing.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
3375
/// Computes Fast Stochastic K and D from raw `f64` buffers into caller-provided outputs.
///
/// `high_ptr`, `low_ptr` and `close_ptr` must each point to `len` readable values.
/// `k_out_ptr` and `d_out_ptr` must each point to `len` writable values.
///
/// Inputs may overlap the outputs (for example, computing in place). Overlap is detected by
/// byte range, and the result is then staged in temporary buffers before being copied out.
/// This way a shared slice and a mutable slice over the same memory are never used together.
///
/// # Errors
/// Returns a string `JsValue` if any pointer is null, or if `stochf_into_slice` rejects the
/// input or parameters.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn stochf_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    k_out_ptr: *mut f64,
    d_out_ptr: *mut f64,
    len: usize,
    fastk_period: usize,
    fastd_period: usize,
    fastd_matype: usize,
) -> Result<(), JsValue> {
    // True when the byte ranges of two `f64` buffers intersect. Saturating arithmetic keeps
    // the check conservative, so a range that overflows only makes overlap more likely.
    fn ranges_overlap(a: *const f64, a_len: usize, b: *const f64, b_len: usize) -> bool {
        let elem = std::mem::size_of::<f64>();
        let a_start = a as usize;
        let a_end = a_start.saturating_add(a_len.saturating_mul(elem));
        let b_start = b as usize;
        let b_end = b_start.saturating_add(b_len.saturating_mul(elem));
        a_start < b_end && b_start < a_end
    }

    if high_ptr.is_null()
        || low_ptr.is_null()
        || close_ptr.is_null()
        || k_out_ptr.is_null()
        || d_out_ptr.is_null()
    {
        return Err(JsValue::from_str("null pointer passed to stochf_into"));
    }

    let k_out_const = k_out_ptr as *const f64;
    let d_out_const = d_out_ptr as *const f64;
    let aliased = [high_ptr, low_ptr, close_ptr].iter().any(|&src| {
        ranges_overlap(src, len, k_out_const, len) || ranges_overlap(src, len, d_out_const, len)
    });

    let params = StochfParams {
        fastk_period: Some(fastk_period),
        fastd_period: Some(fastd_period),
        fastd_matype: Some(fastd_matype),
    };
    let kernel = detect_best_kernel();

    // SAFETY: the caller guarantees every pointer covers `len` valid f64 values (non-null is
    // checked above). When inputs and outputs overlap, the input slices are only read before
    // any mutable output slice is created.
    unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);
        let input = StochfInput::from_slices(high, low, close, params);

        if aliased {
            let mut tmp_k = vec![0.0; len];
            let mut tmp_d = vec![0.0; len];
            stochf_into_slice(tmp_k.as_mut_slice(), tmp_d.as_mut_slice(), &input, kernel)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;

            // The input slices are not touched past this point.
            std::slice::from_raw_parts_mut(k_out_ptr, len).copy_from_slice(&tmp_k);
            std::slice::from_raw_parts_mut(d_out_ptr, len).copy_from_slice(&tmp_d);
            Ok(())
        } else {
            let k_out = std::slice::from_raw_parts_mut(k_out_ptr, len);
            let d_out = std::slice::from_raw_parts_mut(d_out_ptr, len);
            stochf_into_slice(k_out, d_out, &input, kernel)
                .map_err(|e| JsValue::from_str(&e.to_string()))
        }
    }
}
3415
/// Parameter sweep accepted from JS by `stochf_batch`.
///
/// Each range is a `(start, end, step)` triple, copied verbatim into
/// `StochfBatchRange::fastk_period` / `StochfBatchRange::fastd_period`. There is no
/// `fastd_matype` axis in this config.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct StochfBatchConfig {
    /// `(start, end, step)` for the fast-K lookback period.
    pub fastk_range: (usize, usize, usize),
    /// `(start, end, step)` for the fast-D smoothing period.
    pub fastd_range: (usize, usize, usize),
}
3422
/// Serialisable batch result returned to JS by `stochf_batch`.
///
/// The fields are moved directly from the batch output (`k`, `d`, `rows`, `cols`, `combos`).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct StochfBatchJsOutput {
    /// Flattened K matrix, `rows * cols` values (presumably row-major, one row per combo —
    /// confirm against `stochf_batch_inner`).
    pub k_values: Vec<f64>,
    /// Flattened D matrix, same layout as `k_values`.
    pub d_values: Vec<f64>,
    /// Number of parameter combinations in the sweep.
    pub rows: usize,
    /// Number of bars per row (the input length).
    pub cols: usize,
    /// Parameter set used for each row.
    pub combos: Vec<StochfParams>,
}
3432
/// JS entry point (`stochf_batch`): runs a parameter sweep described by a
/// `StochfBatchConfig` object.
///
/// Returns a serialised `StochfBatchJsOutput`.
///
/// # Errors
/// Returns a string `JsValue` when the config cannot be deserialised, when the batch
/// computation fails, or when the result cannot be serialised.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = stochf_batch)]
pub fn stochf_batch_unified_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let StochfBatchConfig {
        fastk_range,
        fastd_range,
    } = serde_wasm_bindgen::from_value::<StochfBatchConfig>(config)
        .map_err(|err| JsValue::from_str(&format!("Invalid config: {}", err)))?;

    let sweep = StochfBatchRange {
        fastk_period: fastk_range,
        fastd_period: fastd_range,
    };

    let batch = stochf_batch_inner(high, low, close, &sweep, detect_best_kernel(), false)
        .map_err(|err| JsValue::from_str(&err.to_string()))?;

    let payload = StochfBatchJsOutput {
        k_values: batch.k,
        d_values: batch.d,
        rows: batch.rows,
        cols: batch.cols,
        combos: batch.combos,
    };

    match serde_wasm_bindgen::to_value(&payload) {
        Ok(js) => Ok(js),
        Err(err) => Err(JsValue::from_str(&format!("Serialization error: {}", err))),
    }
}
3462
/// Runs a StochF parameter sweep over raw `f64` buffers and writes flattened K/D matrices.
///
/// The inputs must each cover `len` readable values. `out_k_ptr` / `out_d_ptr` must each
/// cover `rows * len` writable values, where `rows` is the number of combinations produced
/// by the `(start, end, step)` ranges.
///
/// Inputs may overlap the outputs. Overlap is detected by byte range, not just by equal
/// base pointers, and the result is then staged in temporary buffers before being copied
/// out.
///
/// `fastd_matype` is kept for ABI compatibility, but it is not forwarded: `StochfBatchRange`
/// has no moving-average-type axis, so the sweep uses whatever matype the batch grid
/// assigns.
///
/// # Returns
/// The number of rows (parameter combinations) written.
///
/// # Errors
/// Returns a string `JsValue` for null pointers, for a `rows * len` size that overflows
/// `usize`, or for any error from `stochf_batch_inner_into`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn stochf_batch_into(
    in_high_ptr: *const f64,
    in_low_ptr: *const f64,
    in_close_ptr: *const f64,
    out_k_ptr: *mut f64,
    out_d_ptr: *mut f64,
    len: usize,
    fastk_start: usize,
    fastk_end: usize,
    fastk_step: usize,
    fastd_start: usize,
    fastd_end: usize,
    fastd_step: usize,
    fastd_matype: usize,
) -> Result<usize, JsValue> {
    // True when the byte ranges of two `f64` buffers intersect. Saturating arithmetic keeps
    // the check conservative, so a range that overflows only makes overlap more likely.
    fn ranges_overlap(a: *const f64, a_len: usize, b: *const f64, b_len: usize) -> bool {
        let elem = std::mem::size_of::<f64>();
        let a_start = a as usize;
        let a_end = a_start.saturating_add(a_len.saturating_mul(elem));
        let b_start = b as usize;
        let b_end = b_start.saturating_add(b_len.saturating_mul(elem));
        a_start < b_end && b_start < a_end
    }

    if in_high_ptr.is_null()
        || in_low_ptr.is_null()
        || in_close_ptr.is_null()
        || out_k_ptr.is_null()
        || out_d_ptr.is_null()
    {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // Not part of the batch sweep (see doc comment); kept only for the exported signature.
    let _ = fastd_matype;

    let sweep = StochfBatchRange {
        fastk_period: (fastk_start, fastk_end, fastk_step),
        fastd_period: (fastd_start, fastd_end, fastd_step),
    };

    let rows = expand_grid(&sweep).len();
    let cols = len;
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("stochf_batch_into: rows * cols overflows usize"))?;

    // The batch driver takes the per-row SIMD kernel that matches the detected batch kernel.
    let kernel = match detect_best_batch_kernel() {
        Kernel::Avx512Batch => Kernel::Avx512,
        Kernel::Avx2Batch => Kernel::Avx2,
        _ => Kernel::Scalar,
    };

    let out_k_const = out_k_ptr as *const f64;
    let out_d_const = out_d_ptr as *const f64;
    let aliased = [in_high_ptr, in_low_ptr, in_close_ptr].iter().any(|&src| {
        ranges_overlap(src, len, out_k_const, total) || ranges_overlap(src, len, out_d_const, total)
    });

    // SAFETY: the caller guarantees the inputs cover `len` values and the outputs cover
    // `rows * len` values (non-null is checked above). When inputs and outputs overlap, the
    // input slices are only read before any mutable output slice is created.
    unsafe {
        let high = std::slice::from_raw_parts(in_high_ptr, len);
        let low = std::slice::from_raw_parts(in_low_ptr, len);
        let close = std::slice::from_raw_parts(in_close_ptr, len);

        if aliased {
            let mut temp_k = vec![0.0; total];
            let mut temp_d = vec![0.0; total];
            stochf_batch_inner_into(
                high,
                low,
                close,
                &sweep,
                kernel,
                false,
                &mut temp_k,
                &mut temp_d,
            )
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

            // The input slices are not touched past this point.
            std::slice::from_raw_parts_mut(out_k_ptr, total).copy_from_slice(&temp_k);
            std::slice::from_raw_parts_mut(out_d_ptr, total).copy_from_slice(&temp_d);
        } else {
            let out_k_slice = std::slice::from_raw_parts_mut(out_k_ptr, total);
            let out_d_slice = std::slice::from_raw_parts_mut(out_d_ptr, total);
            stochf_batch_inner_into(
                high,
                low,
                close,
                &sweep,
                kernel,
                false,
                out_k_slice,
                out_d_slice,
            )
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(rows)
}
3563        }
3564
3565        Ok(rows)
3566    }
3567}