Skip to main content

vector_ta/indicators/
srsi.rs

1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::moving_averages::DeviceArrayF32;
3#[cfg(all(feature = "python", feature = "cuda"))]
4use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
5#[cfg(all(feature = "python", feature = "cuda"))]
6use cust::context::Context;
7#[cfg(all(feature = "python", feature = "cuda"))]
8use cust::memory::DeviceBuffer;
9#[cfg(feature = "python")]
10use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
11#[cfg(feature = "python")]
12use pyo3::exceptions::PyValueError;
13#[cfg(feature = "python")]
14use pyo3::prelude::*;
15#[cfg(feature = "python")]
16use pyo3::types::PyDict;
17#[cfg(all(feature = "python", feature = "cuda"))]
18use std::sync::Arc;
19
20#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
21use serde::{Deserialize, Serialize};
22#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
23use wasm_bindgen::prelude::*;
24
25use crate::indicators::rsi::{rsi, RsiError, RsiInput, RsiOutput, RsiParams};
26use crate::indicators::stoch::{stoch, StochError, StochInput, StochOutput, StochParams};
27use crate::utilities::data_loader::{source_type, Candles};
28use crate::utilities::enums::Kernel;
29use crate::utilities::helpers::{
30    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
31    make_uninit_matrix,
32};
33#[cfg(feature = "python")]
34use crate::utilities::kernel_validation::validate_kernel;
35use aligned_vec::{AVec, CACHELINE_ALIGN};
36#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
37use core::arch::x86_64::*;
38#[cfg(not(target_arch = "wasm32"))]
39use rayon::prelude::*;
40use std::collections::VecDeque;
41use std::convert::AsRef;
42use std::error::Error;
43use std::mem::MaybeUninit;
44use thiserror::Error;
45
46impl<'a> AsRef<[f64]> for SrsiInput<'a> {
47    #[inline(always)]
48    fn as_ref(&self) -> &[f64] {
49        match &self.data {
50            SrsiData::Slice(slice) => slice,
51            SrsiData::Candles { candles, source } => source_type(candles, source),
52        }
53    }
54}
55
56#[derive(Debug, Clone)]
57pub enum SrsiData<'a> {
58    Candles {
59        candles: &'a Candles,
60        source: &'a str,
61    },
62    Slice(&'a [f64]),
63}
64
/// Result of a Stochastic RSI computation: the two output series.
#[derive(Clone, Debug)]
pub struct SrsiOutput {
    /// Smoothed %K line (SMA of the fast stochastic of the RSI).
    pub k: Vec<f64>,
    /// %D line (SMA of `k`).
    pub d: Vec<f64>,
}
70
/// Tunable parameters of the Stochastic RSI.
///
/// Every field is optional; the accessors on `SrsiInput` substitute
/// 14 / 14 / 3 / 3 / `"close"` for missing values.
#[derive(Clone, Debug)]
#[cfg_attr(all(target_arch = "wasm32", feature = "wasm"), derive(Serialize, Deserialize))]
pub struct SrsiParams {
    /// Look-back of the RSI stage.
    pub rsi_period: Option<usize>,
    /// Window over which the RSI's min/max are taken.
    pub stoch_period: Option<usize>,
    /// Length of the moving average that produces %K.
    pub k: Option<usize>,
    /// Length of the moving average that produces %D.
    pub d: Option<usize>,
    /// Name of the candle field to read.
    pub source: Option<String>,
}
83
84impl Default for SrsiParams {
85    fn default() -> Self {
86        Self {
87            rsi_period: Some(14),
88            stoch_period: Some(14),
89            k: Some(3),
90            d: Some(3),
91            source: Some("close".to_string()),
92        }
93    }
94}
95
96#[derive(Debug, Clone)]
97pub struct SrsiInput<'a> {
98    pub data: SrsiData<'a>,
99    pub params: SrsiParams,
100}
101
102impl<'a> SrsiInput<'a> {
103    #[inline]
104    pub fn from_candles(c: &'a Candles, s: &'a str, p: SrsiParams) -> Self {
105        Self {
106            data: SrsiData::Candles {
107                candles: c,
108                source: s,
109            },
110            params: p,
111        }
112    }
113    #[inline]
114    pub fn from_slice(sl: &'a [f64], p: SrsiParams) -> Self {
115        Self {
116            data: SrsiData::Slice(sl),
117            params: p,
118        }
119    }
120    #[inline]
121    pub fn with_default_candles(c: &'a Candles) -> Self {
122        Self::from_candles(c, "close", SrsiParams::default())
123    }
124    #[inline]
125    pub fn get_rsi_period(&self) -> usize {
126        self.params.rsi_period.unwrap_or(14)
127    }
128    #[inline]
129    pub fn get_stoch_period(&self) -> usize {
130        self.params.stoch_period.unwrap_or(14)
131    }
132    #[inline]
133    pub fn get_k(&self) -> usize {
134        self.params.k.unwrap_or(3)
135    }
136    #[inline]
137    pub fn get_d(&self) -> usize {
138        self.params.d.unwrap_or(3)
139    }
140    #[inline]
141    pub fn get_source(&self) -> &str {
142        self.params.source.as_deref().unwrap_or("close")
143    }
144}
145
146#[derive(Clone, Debug)]
147pub struct SrsiBuilder {
148    rsi_period: Option<usize>,
149    stoch_period: Option<usize>,
150    k: Option<usize>,
151    d: Option<usize>,
152    source: Option<String>,
153    kernel: Kernel,
154}
155
156impl Default for SrsiBuilder {
157    fn default() -> Self {
158        Self {
159            rsi_period: None,
160            stoch_period: None,
161            k: None,
162            d: None,
163            source: None,
164            kernel: Kernel::Auto,
165        }
166    }
167}
168
169impl SrsiBuilder {
170    #[inline(always)]
171    pub fn new() -> Self {
172        Self::default()
173    }
174    #[inline(always)]
175    pub fn rsi_period(mut self, n: usize) -> Self {
176        self.rsi_period = Some(n);
177        self
178    }
179    #[inline(always)]
180    pub fn stoch_period(mut self, n: usize) -> Self {
181        self.stoch_period = Some(n);
182        self
183    }
184    #[inline(always)]
185    pub fn k(mut self, n: usize) -> Self {
186        self.k = Some(n);
187        self
188    }
189    #[inline(always)]
190    pub fn d(mut self, n: usize) -> Self {
191        self.d = Some(n);
192        self
193    }
194    #[inline(always)]
195    pub fn source<S: Into<String>>(mut self, s: S) -> Self {
196        self.source = Some(s.into());
197        self
198    }
199    #[inline(always)]
200    pub fn kernel(mut self, k: Kernel) -> Self {
201        self.kernel = k;
202        self
203    }
204
205    #[inline(always)]
206    pub fn apply(self, c: &Candles) -> Result<SrsiOutput, SrsiError> {
207        let p = SrsiParams {
208            rsi_period: self.rsi_period,
209            stoch_period: self.stoch_period,
210            k: self.k,
211            d: self.d,
212            source: self.source.clone(),
213        };
214        let i = SrsiInput::from_candles(c, self.source.as_deref().unwrap_or("close"), p);
215        srsi_with_kernel(&i, self.kernel)
216    }
217
218    #[inline(always)]
219    pub fn apply_slice(self, d: &[f64]) -> Result<SrsiOutput, SrsiError> {
220        let p = SrsiParams {
221            rsi_period: self.rsi_period,
222            stoch_period: self.stoch_period,
223            k: self.k,
224            d: self.d,
225            source: self.source.clone(),
226        };
227        let i = SrsiInput::from_slice(d, p);
228        srsi_with_kernel(&i, self.kernel)
229    }
230}
231
232#[derive(Debug, Error)]
233pub enum SrsiError {
234    #[error("srsi: Error from RSI calculation: {0}")]
235    RsiError(#[from] RsiError),
236    #[error("srsi: Error from Stochastic calculation: {0}")]
237    StochError(#[from] StochError),
238    #[error("srsi: Input data is empty.")]
239    EmptyInputData,
240    #[error("srsi: All input data values are NaN.")]
241    AllValuesNaN,
242    #[error("srsi: Invalid period {period} for data length {data_len}.")]
243    InvalidPeriod { period: usize, data_len: usize },
244    #[error(
245        "srsi: Not enough valid data for the requested period. needed={needed}, valid={valid}"
246    )]
247    NotEnoughValidData { needed: usize, valid: usize },
248    #[error("srsi: Output length mismatch - destination buffers must match input data length. Expected {expected}, got k={k_len}, d={d_len}")]
249    OutputLengthMismatch {
250        expected: usize,
251        k_len: usize,
252        d_len: usize,
253    },
254    #[error("srsi: Invalid range: start={start}, end={end}, step={step}")]
255    InvalidRange {
256        start: String,
257        end: String,
258        step: String,
259    },
260    #[error("srsi: Invalid kernel for batch: {0:?}")]
261    InvalidKernelForBatch(Kernel),
262}
263
264#[inline]
265pub fn srsi(input: &SrsiInput) -> Result<SrsiOutput, SrsiError> {
266    srsi_with_kernel(input, Kernel::Auto)
267}
268
269pub fn srsi_with_kernel(input: &SrsiInput, kernel: Kernel) -> Result<SrsiOutput, SrsiError> {
270    let data: &[f64] = match &input.data {
271        SrsiData::Candles { candles, source } => source_type(candles, source),
272        SrsiData::Slice(sl) => sl,
273    };
274
275    if data.is_empty() {
276        return Err(SrsiError::EmptyInputData);
277    }
278
279    let first = data
280        .iter()
281        .position(|x| !x.is_nan())
282        .ok_or(SrsiError::AllValuesNaN)?;
283    let len = data.len();
284    let rsi_period = input.get_rsi_period();
285    let stoch_period = input.get_stoch_period();
286    let k_len = input.get_k();
287    let d_len = input.get_d();
288
289    let needed = rsi_period.max(stoch_period).max(k_len).max(d_len);
290    let valid = len - first;
291    if valid < needed {
292        return Err(SrsiError::NotEnoughValidData { needed, valid });
293    }
294
295    let chosen = match kernel {
296        Kernel::Auto => Kernel::Scalar,
297        other => other,
298    };
299
300    unsafe {
301        match chosen {
302            Kernel::Scalar | Kernel::ScalarBatch => {
303                srsi_scalar(data, rsi_period, stoch_period, k_len, d_len)
304            }
305            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
306            Kernel::Avx2 | Kernel::Avx2Batch => {
307                srsi_avx2(data, rsi_period, stoch_period, k_len, d_len)
308            }
309            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
310            Kernel::Avx512 | Kernel::Avx512Batch => {
311                srsi_avx512(data, rsi_period, stoch_period, k_len, d_len)
312            }
313            _ => unreachable!(),
314        }
315    }
316}
317
318#[inline]
319pub unsafe fn srsi_scalar(
320    data: &[f64],
321    rsi_period: usize,
322    stoch_period: usize,
323    k_period: usize,
324    d_period: usize,
325) -> Result<SrsiOutput, SrsiError> {
326    if rsi_period == 14 && stoch_period == 14 && k_period == 3 && d_period == 3 {
327        return srsi_scalar_classic(data, rsi_period, stoch_period, k_period, d_period);
328    }
329
330    let n = data.len();
331    if n == 0 {
332        return Err(SrsiError::EmptyInputData);
333    }
334    if rsi_period == 0 || stoch_period == 0 || k_period == 0 || d_period == 0 {
335        return Err(SrsiError::InvalidPeriod {
336            period: rsi_period.max(stoch_period).max(k_period).max(d_period),
337            data_len: n,
338        });
339    }
340    let first = data
341        .iter()
342        .position(|x| !x.is_nan())
343        .ok_or(SrsiError::AllValuesNaN)?;
344    let max_need = rsi_period.max(stoch_period).max(k_period).max(d_period);
345    if n - first < max_need {
346        return Err(SrsiError::NotEnoughValidData {
347            needed: max_need,
348            valid: n - first,
349        });
350    }
351
352    let rsi_warmup = first + rsi_period;
353    let stoch_warmup = rsi_warmup + stoch_period - 1;
354    let k_warmup = stoch_warmup + k_period - 1;
355    let d_warmup = k_warmup + d_period - 1;
356
357    if n <= d_warmup {
358        return Err(SrsiError::NotEnoughValidData {
359            needed: d_warmup + 1,
360            valid: n,
361        });
362    }
363
364    let mut rsi_vals = alloc_with_nan_prefix(n, rsi_warmup);
365    let mut k_out = alloc_with_nan_prefix(n, k_warmup);
366    let mut d_out = alloc_with_nan_prefix(n, d_warmup);
367
368    let mut avg_gain = 0.0f64;
369    let mut avg_loss = 0.0f64;
370    let mut prev = *data.get_unchecked(first);
371    let end_init = (first + rsi_period).min(n.saturating_sub(1));
372    for i in (first + 1)..=end_init {
373        let cur = *data.get_unchecked(i);
374        if cur.is_finite() && prev.is_finite() {
375            let ch = cur - prev;
376            if ch > 0.0 {
377                avg_gain += ch;
378            } else {
379                avg_loss += -ch;
380            }
381        }
382        prev = cur;
383    }
384
385    let rp = rsi_period as f64;
386    avg_gain /= rp;
387    avg_loss /= rp;
388    let alpha = 1.0f64 / rp;
389
390    if rsi_warmup < n {
391        rsi_vals[rsi_warmup] = if avg_loss == 0.0 {
392            100.0
393        } else {
394            let rs = avg_gain / avg_loss;
395            100.0 - 100.0 / (1.0 + rs)
396        };
397    }
398
399    prev = *data.get_unchecked(rsi_warmup);
400    for i in (rsi_warmup + 1)..n {
401        let cur = *data.get_unchecked(i);
402        if cur.is_finite() && prev.is_finite() {
403            let ch = cur - prev;
404            let gain = if ch > 0.0 { ch } else { 0.0 };
405            let loss = if ch < 0.0 { -ch } else { 0.0 };
406            avg_gain = (gain - avg_gain).mul_add(alpha, avg_gain);
407            avg_loss = (loss - avg_loss).mul_add(alpha, avg_loss);
408            rsi_vals[i] = if avg_loss == 0.0 {
409                100.0
410            } else {
411                let rs = avg_gain / avg_loss;
412                100.0 - 100.0 / (1.0 + rs)
413            };
414        }
415        prev = cur;
416    }
417
418    let sp = stoch_period;
419    let kp = k_period;
420    let dp = d_period;
421    if rsi_warmup < n {
422        let m = n - rsi_warmup;
423        let base = rsi_warmup;
424
425        let mut pref_max = vec![0.0f64; m];
426        let mut suff_max = vec![0.0f64; m];
427        let mut pref_min = vec![0.0f64; m];
428        let mut suff_min = vec![0.0f64; m];
429
430        let blocks = (m + sp - 1) / sp;
431        let p_pref_max = pref_max.as_mut_ptr();
432        let p_pref_min = pref_min.as_mut_ptr();
433        let p_rsi = rsi_vals.as_ptr().add(base);
434        for b in 0..blocks {
435            let start = b * sp;
436            let end = core::cmp::min(start + sp, m);
437            if start >= end {
438                break;
439            }
440            unsafe {
441                let v0 = *p_rsi.add(start);
442                *p_pref_max.add(start) = v0;
443                *p_pref_min.add(start) = v0;
444                let mut i = start + 1;
445                while i < end {
446                    let v = *p_rsi.add(i);
447                    let pmx = *p_pref_max.add(i - 1);
448                    let pmn = *p_pref_min.add(i - 1);
449                    *p_pref_max.add(i) = if v > pmx { v } else { pmx };
450                    *p_pref_min.add(i) = if v < pmn { v } else { pmn };
451                    i += 1;
452                }
453            }
454        }
455
456        let p_suff_max = suff_max.as_mut_ptr();
457        let p_suff_min = suff_min.as_mut_ptr();
458        for b in 0..blocks {
459            let block_end_excl = core::cmp::min((b + 1) * sp, m);
460            if block_end_excl == 0 {
461                break;
462            }
463            let block_start = block_end_excl - core::cmp::min(sp, block_end_excl);
464            unsafe {
465                let last = block_end_excl - 1;
466                let v_last = *p_rsi.add(last);
467                *p_suff_max.add(last) = v_last;
468                *p_suff_min.add(last) = v_last;
469                let mut i = last;
470                while i > block_start {
471                    let prev = i - 1;
472                    let v = *p_rsi.add(prev);
473                    let smx = *p_suff_max.add(i);
474                    let smn = *p_suff_min.add(i);
475                    *p_suff_max.add(prev) = if v > smx { v } else { smx };
476                    *p_suff_min.add(prev) = if v < smn { v } else { smn };
477                    i = prev;
478                }
479            }
480        }
481
482        let mut sum_k = 0.0f64;
483        let mut sum_d = 0.0f64;
484        let mut fk_ring = vec![0.0f64; kp];
485        let mut sk_ring = vec![0.0f64; dp];
486        let mut fk_pos = 0usize;
487        let mut sk_pos = 0usize;
488
489        let i0 = stoch_warmup;
490        let mut i = i0;
491        while i < n {
492            let t = i - base;
493            let t_start = t + 1 - sp;
494            let hi_l = suff_max[t_start];
495            let hi_r = pref_max[t];
496            let lo_l = suff_min[t_start];
497            let lo_r = pref_min[t];
498            let hi = if hi_l > hi_r { hi_l } else { hi_r };
499            let lo = if lo_l < lo_r { lo_l } else { lo_r };
500            let x = *rsi_vals.get_unchecked(i);
501            let fk = if hi > lo {
502                ((x - lo) * 100.0) / (hi - lo)
503            } else {
504                50.0
505            };
506
507            sum_k += fk;
508            if i >= i0 + kp {
509                sum_k -= *fk_ring.get_unchecked(fk_pos);
510            }
511            *fk_ring.get_unchecked_mut(fk_pos) = fk;
512            fk_pos += 1;
513            if fk_pos == kp {
514                fk_pos = 0;
515            }
516
517            if i >= k_warmup {
518                let sk = sum_k / (kp as f64);
519                *k_out.get_unchecked_mut(i) = sk;
520
521                sum_d += sk;
522                if i >= k_warmup + dp {
523                    sum_d -= *sk_ring.get_unchecked(sk_pos);
524                }
525                *sk_ring.get_unchecked_mut(sk_pos) = sk;
526                sk_pos += 1;
527                if sk_pos == dp {
528                    sk_pos = 0;
529                }
530
531                if i >= d_warmup {
532                    *d_out.get_unchecked_mut(i) = sum_d / (dp as f64);
533                }
534            }
535            i += 1;
536        }
537    }
538
539    Ok(SrsiOutput { k: k_out, d: d_out })
540}
541
/// AVX2 entry point; currently a thin forwarder to [`srsi_scalar`].
///
/// # Safety
/// Same contract as `srsi_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn srsi_avx2(
    data: &[f64],
    rsi_period: usize,
    stoch_period: usize,
    k: usize,
    d: usize,
) -> Result<SrsiOutput, SrsiError> {
    let result = srsi_scalar(data, rsi_period, stoch_period, k, d);
    result
}
553
/// AVX-512 entry point: picks the short-window variant for
/// `stoch_period <= 32` and the long-window variant otherwise.
///
/// # Safety
/// Same contract as the variant it forwards to.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn srsi_avx512(
    data: &[f64],
    rsi_period: usize,
    stoch_period: usize,
    k: usize,
    d: usize,
) -> Result<SrsiOutput, SrsiError> {
    match stoch_period {
        0..=32 => srsi_avx512_short(data, rsi_period, stoch_period, k, d),
        _ => srsi_avx512_long(data, rsi_period, stoch_period, k, d),
    }
}
569
/// AVX-512 short-window variant; currently a thin forwarder to [`srsi_scalar`].
///
/// # Safety
/// Same contract as `srsi_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn srsi_avx512_short(
    data: &[f64],
    rsi_period: usize,
    stoch_period: usize,
    k: usize,
    d: usize,
) -> Result<SrsiOutput, SrsiError> {
    let result = srsi_scalar(data, rsi_period, stoch_period, k, d);
    result
}
581
/// AVX-512 long-window variant; currently a thin forwarder to [`srsi_scalar`].
///
/// # Safety
/// Same contract as `srsi_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn srsi_avx512_long(
    data: &[f64],
    rsi_period: usize,
    stoch_period: usize,
    k: usize,
    d: usize,
) -> Result<SrsiOutput, SrsiError> {
    let result = srsi_scalar(data, rsi_period, stoch_period, k, d);
    result
}
593
594#[inline]
595pub unsafe fn srsi_scalar_classic(
596    data: &[f64],
597    rsi_period: usize,
598    stoch_period: usize,
599    k_period: usize,
600    d_period: usize,
601) -> Result<SrsiOutput, SrsiError> {
602    let n = data.len();
603    if n == 0 {
604        return Err(SrsiError::EmptyInputData);
605    }
606    let first = data
607        .iter()
608        .position(|x| !x.is_nan())
609        .ok_or(SrsiError::AllValuesNaN)?;
610
611    let rsi_warmup = first + rsi_period;
612    let stoch_warmup = rsi_warmup + stoch_period - 1;
613    let k_warmup = stoch_warmup + k_period - 1;
614    let d_warmup = k_warmup + d_period - 1;
615
616    if n <= d_warmup {
617        return Err(SrsiError::NotEnoughValidData {
618            needed: d_warmup + 1,
619            valid: n,
620        });
621    }
622
623    let mut rsi_values = alloc_with_nan_prefix(n, rsi_warmup);
624
625    let mut avg_gain = 0.0;
626    let mut avg_loss = 0.0;
627    let mut prev = data[first];
628
629    for i in (first + 1)..(first + rsi_period + 1).min(n) {
630        if data[i].is_finite() && prev.is_finite() {
631            let change = data[i] - prev;
632            if change > 0.0 {
633                avg_gain += change;
634            } else {
635                avg_loss += -change;
636            }
637            prev = data[i];
638        }
639    }
640
641    avg_gain /= rsi_period as f64;
642    avg_loss /= rsi_period as f64;
643
644    let alpha = 1.0 / rsi_period as f64;
645    let alpha_1minus = 1.0 - alpha;
646
647    if first + rsi_period < n {
648        rsi_values[first + rsi_period] = if avg_loss == 0.0 {
649            100.0
650        } else {
651            100.0 - (100.0 / (1.0 + avg_gain / avg_loss))
652        };
653
654        prev = data[first + rsi_period];
655    }
656
657    for i in (first + rsi_period + 1)..n {
658        if data[i].is_finite() && prev.is_finite() {
659            let change = data[i] - prev;
660            let (gain, loss) = if change > 0.0 {
661                (change, 0.0)
662            } else {
663                (0.0, -change)
664            };
665
666            avg_gain = alpha * gain + alpha_1minus * avg_gain;
667            avg_loss = alpha * loss + alpha_1minus * avg_loss;
668
669            rsi_values[i] = if avg_loss == 0.0 {
670                100.0
671            } else {
672                100.0 - (100.0 / (1.0 + avg_gain / avg_loss))
673            };
674
675            prev = data[i];
676        }
677    }
678
679    let mut fast_k = alloc_with_nan_prefix(n, stoch_warmup);
680
681    for i in stoch_warmup..n {
682        let start = i + 1 - stoch_period;
683        let mut min_rsi = f64::MAX;
684        let mut max_rsi = f64::MIN;
685
686        for j in start..=i {
687            if rsi_values[j].is_finite() {
688                min_rsi = min_rsi.min(rsi_values[j]);
689                max_rsi = max_rsi.max(rsi_values[j]);
690            }
691        }
692
693        if max_rsi > min_rsi {
694            fast_k[i] = 100.0 * (rsi_values[i] - min_rsi) / (max_rsi - min_rsi);
695        } else {
696            fast_k[i] = 50.0;
697        }
698    }
699
700    let mut slow_k = alloc_with_nan_prefix(n, k_warmup);
701
702    let mut k_sum = 0.0;
703    for i in stoch_warmup..(stoch_warmup + k_period).min(n) {
704        if fast_k[i].is_finite() {
705            k_sum += fast_k[i];
706        }
707    }
708
709    if stoch_warmup + k_period <= n {
710        slow_k[stoch_warmup + k_period - 1] = k_sum / k_period as f64;
711
712        for i in (stoch_warmup + k_period)..n {
713            if fast_k[i].is_finite() && fast_k[i - k_period].is_finite() {
714                k_sum += fast_k[i] - fast_k[i - k_period];
715                slow_k[i] = k_sum / k_period as f64;
716            }
717        }
718    }
719
720    let mut slow_d = alloc_with_nan_prefix(n, d_warmup);
721
722    let mut d_sum = 0.0;
723    for i in k_warmup..(k_warmup + d_period).min(n) {
724        if slow_k[i].is_finite() {
725            d_sum += slow_k[i];
726        }
727    }
728
729    if k_warmup + d_period <= n {
730        slow_d[k_warmup + d_period - 1] = d_sum / d_period as f64;
731
732        for i in (k_warmup + d_period)..n {
733            if slow_k[i].is_finite() && slow_k[i - d_period].is_finite() {
734                d_sum += slow_k[i] - slow_k[i - d_period];
735                slow_d[i] = d_sum / d_period as f64;
736            }
737        }
738    }
739
740    Ok(SrsiOutput {
741        k: slow_k,
742        d: slow_d,
743    })
744}
745
746#[inline(always)]
747pub fn srsi_row_scalar(
748    data: &[f64],
749    rsi_period: usize,
750    stoch_period: usize,
751    k: usize,
752    d: usize,
753    k_out: &mut [f64],
754    d_out: &mut [f64],
755) {
756    if let Ok(res) = unsafe { srsi_scalar(data, rsi_period, stoch_period, k, d) } {
757        k_out.copy_from_slice(&res.k);
758        d_out.copy_from_slice(&res.d);
759    }
760}
761
/// AVX2 row entry point; currently forwards unchanged to [`srsi_row_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn srsi_row_avx2(
    data: &[f64],
    rsi_period: usize,
    stoch_period: usize,
    k: usize,
    d: usize,
    k_out: &mut [f64],
    d_out: &mut [f64],
) {
    srsi_row_scalar(data, rsi_period, stoch_period, k, d, k_out, d_out);
}
775
/// AVX-512 row entry point; currently forwards unchanged to [`srsi_row_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn srsi_row_avx512(
    data: &[f64],
    rsi_period: usize,
    stoch_period: usize,
    k: usize,
    d: usize,
    k_out: &mut [f64],
    d_out: &mut [f64],
) {
    srsi_row_scalar(data, rsi_period, stoch_period, k, d, k_out, d_out);
}
789
/// AVX-512 short-window row entry point; currently forwards unchanged to
/// [`srsi_row_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn srsi_row_avx512_short(
    data: &[f64],
    rsi_period: usize,
    stoch_period: usize,
    k: usize,
    d: usize,
    k_out: &mut [f64],
    d_out: &mut [f64],
) {
    srsi_row_scalar(data, rsi_period, stoch_period, k, d, k_out, d_out);
}
803
/// AVX-512 long-window row entry point; currently forwards unchanged to
/// [`srsi_row_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn srsi_row_avx512_long(
    data: &[f64],
    rsi_period: usize,
    stoch_period: usize,
    k: usize,
    d: usize,
    k_out: &mut [f64],
    d_out: &mut [f64],
) {
    srsi_row_scalar(data, rsi_period, stoch_period, k, d, k_out, d_out);
}
817
/// Incremental (one sample at a time) Stochastic RSI state.
#[derive(Clone, Debug)]
pub struct SrsiStream {
    // --- configuration ---
    /// RSI look-back.
    rsi_period: usize,
    /// Window over which the RSI's min/max are tracked.
    stoch_period: usize,
    /// Length of the %K moving average.
    k_period: usize,
    /// Length of the %D moving average.
    d_period: usize,

    // --- RSI stage ---
    /// Previous accepted sample (NaN until one has been seen).
    prev: f64,
    /// Whether `prev` holds a real sample.
    has_prev: bool,
    /// Number of price changes accumulated while seeding the RSI.
    init_count: usize,
    /// Sum of gains collected during seeding.
    sum_gain: f64,
    /// Sum of losses collected during seeding.
    sum_loss: f64,
    /// Running average gain.
    avg_gain: f64,
    /// Running average loss.
    avg_loss: f64,
    /// Smoothing factor, `1 / rsi_period`.
    alpha: f64,
    /// Set once the RSI seed has been computed.
    rsi_ready: bool,
    /// Zero-based index of the most recent RSI value.
    rsi_index: usize,
    /// Most recent RSI value (NaN before the first one).
    last_rsi: f64,

    // --- rolling extrema of the RSI ---
    /// Monotonic queue of `(rsi_index, rsi)` whose front is the window maximum.
    max_q: VecDeque<(usize, f64)>,
    /// Monotonic queue of `(rsi_index, rsi)` whose front is the window minimum.
    min_q: VecDeque<(usize, f64)>,

    // --- %K moving average over fast %K ---
    /// Ring buffer of the last `k_period` fast %K values.
    fk_ring: Vec<f64>,
    /// Running sum of `fk_ring`.
    fk_sum: f64,
    /// Next write position in `fk_ring`.
    fk_pos: usize,
    /// Number of values pushed so far, capped at `k_period`.
    fk_count: usize,
    /// `1 / k_period`.
    inv_k: f64,

    // --- %D moving average over %K ---
    /// Ring buffer of the last `d_period` %K values.
    sk_ring: Vec<f64>,
    /// Running sum of `sk_ring`.
    sk_sum: f64,
    /// Next write position in `sk_ring`.
    sk_pos: usize,
    /// Number of values pushed so far, capped at `d_period`.
    sk_count: usize,
    /// `1 / d_period`.
    inv_d: f64,
}
852
853impl SrsiStream {
854    pub fn try_new(params: SrsiParams) -> Result<Self, SrsiError> {
855        let rsi_period = params.rsi_period.unwrap_or(14);
856        let stoch_period = params.stoch_period.unwrap_or(14);
857        let k_period = params.k.unwrap_or(3);
858        let d_period = params.d.unwrap_or(3);
859
860        if rsi_period == 0 || stoch_period == 0 || k_period == 0 || d_period == 0 {
861            return Err(SrsiError::InvalidPeriod {
862                period: rsi_period.max(stoch_period).max(k_period).max(d_period),
863                data_len: 0,
864            });
865        }
866
867        Ok(Self {
868            rsi_period,
869            stoch_period,
870            k_period,
871            d_period,
872
873            prev: f64::NAN,
874            has_prev: false,
875            init_count: 0,
876            sum_gain: 0.0,
877            sum_loss: 0.0,
878            avg_gain: 0.0,
879            avg_loss: 0.0,
880            alpha: 1.0 / (rsi_period as f64),
881            rsi_ready: false,
882            rsi_index: 0,
883            last_rsi: f64::NAN,
884
885            max_q: VecDeque::with_capacity(stoch_period),
886            min_q: VecDeque::with_capacity(stoch_period),
887
888            fk_ring: vec![0.0; k_period],
889            fk_sum: 0.0,
890            fk_pos: 0,
891            fk_count: 0,
892            inv_k: 1.0 / (k_period as f64),
893
894            sk_ring: vec![0.0; d_period],
895            sk_sum: 0.0,
896            sk_pos: 0,
897            sk_count: 0,
898            inv_d: 1.0 / (d_period as f64),
899        })
900    }
901
902    #[inline]
903    pub fn reset(&mut self) {
904        self.prev = f64::NAN;
905        self.has_prev = false;
906        self.init_count = 0;
907        self.sum_gain = 0.0;
908        self.sum_loss = 0.0;
909        self.avg_gain = 0.0;
910        self.avg_loss = 0.0;
911        self.rsi_ready = false;
912        self.rsi_index = 0;
913        self.last_rsi = f64::NAN;
914        self.max_q.clear();
915        self.min_q.clear();
916        self.fk_ring.fill(0.0);
917        self.fk_sum = 0.0;
918        self.fk_pos = 0;
919        self.fk_count = 0;
920        self.sk_ring.fill(0.0);
921        self.sk_sum = 0.0;
922        self.sk_pos = 0;
923        self.sk_count = 0;
924    }
925
926    pub fn update(&mut self, v: f64) -> Option<(f64, f64)> {
927        if !v.is_finite() {
928            self.reset();
929            return None;
930        }
931
932        if !self.has_prev {
933            self.prev = v;
934            self.has_prev = true;
935            return None;
936        }
937
938        let ch = v - self.prev;
939        self.prev = v;
940
941        if !self.rsi_ready {
942            if ch > 0.0 {
943                self.sum_gain += ch;
944            } else {
945                self.sum_loss += -ch;
946            }
947            self.init_count += 1;
948
949            if self.init_count < self.rsi_period {
950                return None;
951            }
952
953            self.avg_gain = self.sum_gain / (self.rsi_period as f64);
954            self.avg_loss = self.sum_loss / (self.rsi_period as f64);
955
956            let rsi = if self.avg_loss == 0.0 {
957                100.0
958            } else {
959                let rs = self.avg_gain / self.avg_loss;
960                100.0 - 100.0 / (1.0 + rs)
961            };
962            self.last_rsi = rsi;
963            self.rsi_ready = true;
964            self.rsi_index = 0;
965
966            self.push_rsi_to_deques(self.rsi_index, rsi);
967
968            return None;
969        }
970
971        let gain = if ch > 0.0 { ch } else { 0.0 };
972        let loss = if ch < 0.0 { -ch } else { 0.0 };
973        self.avg_gain = (gain - self.avg_gain).mul_add(self.alpha, self.avg_gain);
974        self.avg_loss = (loss - self.avg_loss).mul_add(self.alpha, self.avg_loss);
975
976        let rsi = if self.avg_loss == 0.0 {
977            100.0
978        } else {
979            let rs = self.avg_gain / self.avg_loss;
980            100.0 - 100.0 / (1.0 + rs)
981        };
982        self.last_rsi = rsi;
983
984        self.rsi_index += 1;
985        self.push_rsi_to_deques(self.rsi_index, rsi);
986
987        if self.rsi_index + 1 < self.stoch_period {
988            return None;
989        }
990
991        let start = self.rsi_index + 1 - self.stoch_period;
992
993        while let Some(&(j, _)) = self.max_q.front() {
994            if j < start {
995                self.max_q.pop_front();
996            } else {
997                break;
998            }
999        }
1000        while let Some(&(j, _)) = self.min_q.front() {
1001            if j < start {
1002                self.min_q.pop_front();
1003            } else {
1004                break;
1005            }
1006        }
1007
1008        debug_assert!(!self.max_q.is_empty() && !self.min_q.is_empty());
1009        let hi = self.max_q.front().unwrap().1;
1010        let lo = self.min_q.front().unwrap().1;
1011
1012        let fast_k = if hi > lo {
1013            let range = hi - lo;
1014            ((rsi - lo) * 100.0) / range
1015        } else {
1016            50.0
1017        };
1018
1019        let slow_k_opt = Self::push_sma(
1020            fast_k,
1021            &mut self.fk_ring,
1022            &mut self.fk_sum,
1023            &mut self.fk_pos,
1024            &mut self.fk_count,
1025            self.k_period,
1026            self.inv_k,
1027        );
1028
1029        let slow_k = match slow_k_opt {
1030            None => return None,
1031            Some(v) => v,
1032        };
1033
1034        let slow_d_opt = Self::push_sma(
1035            slow_k,
1036            &mut self.sk_ring,
1037            &mut self.sk_sum,
1038            &mut self.sk_pos,
1039            &mut self.sk_count,
1040            self.d_period,
1041            self.inv_d,
1042        );
1043
1044        slow_d_opt.map(|d| (slow_k, d))
1045    }
1046
1047    #[inline(always)]
1048    fn push_rsi_to_deques(&mut self, idx: usize, rsi: f64) {
1049        while let Some(&(_, v)) = self.max_q.back() {
1050            if v <= rsi {
1051                self.max_q.pop_back();
1052            } else {
1053                break;
1054            }
1055        }
1056        if self.max_q.len() == self.stoch_period {
1057            self.max_q.pop_front();
1058        }
1059        self.max_q.push_back((idx, rsi));
1060
1061        while let Some(&(_, v)) = self.min_q.back() {
1062            if v >= rsi {
1063                self.min_q.pop_back();
1064            } else {
1065                break;
1066            }
1067        }
1068        if self.min_q.len() == self.stoch_period {
1069            self.min_q.pop_front();
1070        }
1071        self.min_q.push_back((idx, rsi));
1072    }
1073
1074    #[inline(always)]
1075    fn push_sma(
1076        new_val: f64,
1077        ring: &mut [f64],
1078        sum: &mut f64,
1079        pos: &mut usize,
1080        count: &mut usize,
1081        period: usize,
1082        inv_period: f64,
1083    ) -> Option<f64> {
1084        if *count < period {
1085            *sum += new_val;
1086            ring[*pos] = new_val;
1087            *pos += 1;
1088            if *pos == period {
1089                *pos = 0;
1090            }
1091            *count += 1;
1092            if *count == period {
1093                Some(*sum * inv_period)
1094            } else {
1095                None
1096            }
1097        } else {
1098            *sum += new_val - ring[*pos];
1099            ring[*pos] = new_val;
1100            *pos += 1;
1101            if *pos == period {
1102                *pos = 0;
1103            }
1104            Some(*sum * inv_period)
1105        }
1106    }
1107}
1108
/// Parameter sweep description for a batch SRSI run.
///
/// Every field is a `(start, end, step)` triple describing one axis of the
/// parameter grid.
#[derive(Clone, Debug)]
pub struct SrsiBatchRange {
    /// Sweep of the RSI look-back period.
    pub rsi_period: (usize, usize, usize),
    /// Sweep of the stochastic window applied to the RSI series.
    pub stoch_period: (usize, usize, usize),
    /// Sweep of the `k` smoothing period.
    pub k: (usize, usize, usize),
    /// Sweep of the `d` smoothing period.
    pub d: (usize, usize, usize),
}

impl Default for SrsiBatchRange {
    /// Sweeps `rsi_period` over `(14, 263, 1)` while pinning `stoch_period`
    /// to 14 and both `k` and `d` to 3 (a zero step marks a single value).
    fn default() -> Self {
        // A pinned axis has identical start/end and a zero step.
        let pinned = |value: usize| (value, value, 0);
        SrsiBatchRange {
            rsi_period: (14, 263, 1),
            stoch_period: pinned(14),
            k: pinned(3),
            d: pinned(3),
        }
    }
}
1127
1128#[derive(Clone, Debug, Default)]
1129pub struct SrsiBatchBuilder {
1130    range: SrsiBatchRange,
1131    kernel: Kernel,
1132}
1133
1134impl SrsiBatchBuilder {
1135    pub fn new() -> Self {
1136        Self::default()
1137    }
1138    pub fn kernel(mut self, k: Kernel) -> Self {
1139        self.kernel = k;
1140        self
1141    }
1142    #[inline]
1143    pub fn rsi_period_range(mut self, start: usize, end: usize, step: usize) -> Self {
1144        self.range.rsi_period = (start, end, step);
1145        self
1146    }
1147    #[inline]
1148    pub fn stoch_period_range(mut self, start: usize, end: usize, step: usize) -> Self {
1149        self.range.stoch_period = (start, end, step);
1150        self
1151    }
1152    #[inline]
1153    pub fn k_range(mut self, start: usize, end: usize, step: usize) -> Self {
1154        self.range.k = (start, end, step);
1155        self
1156    }
1157    #[inline]
1158    pub fn d_range(mut self, start: usize, end: usize, step: usize) -> Self {
1159        self.range.d = (start, end, step);
1160        self
1161    }
1162    pub fn apply_slice(self, data: &[f64]) -> Result<SrsiBatchOutput, SrsiError> {
1163        srsi_batch_with_kernel(data, &self.range, self.kernel)
1164    }
1165}
1166
1167#[derive(Clone, Debug)]
1168pub struct SrsiBatchOutput {
1169    pub k: Vec<f64>,
1170    pub d: Vec<f64>,
1171    pub combos: Vec<SrsiParams>,
1172    pub rows: usize,
1173    pub cols: usize,
1174}
1175impl SrsiBatchOutput {
1176    pub fn row_for_params(&self, p: &SrsiParams) -> Option<usize> {
1177        self.combos.iter().position(|c| {
1178            c.rsi_period.unwrap_or(14) == p.rsi_period.unwrap_or(14)
1179                && c.stoch_period.unwrap_or(14) == p.stoch_period.unwrap_or(14)
1180                && c.k.unwrap_or(3) == p.k.unwrap_or(3)
1181                && c.d.unwrap_or(3) == p.d.unwrap_or(3)
1182        })
1183    }
1184    pub fn k_for(&self, p: &SrsiParams) -> Option<&[f64]> {
1185        self.row_for_params(p).map(|row| {
1186            let start = row * self.cols;
1187            &self.k[start..start + self.cols]
1188        })
1189    }
1190    pub fn d_for(&self, p: &SrsiParams) -> Option<&[f64]> {
1191        self.row_for_params(p).map(|row| {
1192            let start = row * self.cols;
1193            &self.d[start..start + self.cols]
1194        })
1195    }
1196}
1197
1198#[inline(always)]
1199fn expand_grid(r: &SrsiBatchRange) -> Result<Vec<SrsiParams>, SrsiError> {
1200    fn axis((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, SrsiError> {
1201        if step == 0 || start == end {
1202            return Ok(vec![start]);
1203        }
1204        if start < end {
1205            let st = step.max(1);
1206            let v: Vec<usize> = (start..=end).step_by(st).collect();
1207            if v.is_empty() {
1208                return Err(SrsiError::InvalidRange {
1209                    start: start.to_string(),
1210                    end: end.to_string(),
1211                    step: step.to_string(),
1212                });
1213            }
1214            return Ok(v);
1215        }
1216        let mut v = Vec::new();
1217        let mut x = start as isize;
1218        let end_i = end as isize;
1219        let st = (step as isize).max(1);
1220        while x >= end_i {
1221            v.push(x as usize);
1222            x -= st;
1223        }
1224        if v.is_empty() {
1225            return Err(SrsiError::InvalidRange {
1226                start: start.to_string(),
1227                end: end.to_string(),
1228                step: step.to_string(),
1229            });
1230        }
1231        Ok(v)
1232    }
1233    let rsi_periods = axis(r.rsi_period)?;
1234    let stoch_periods = axis(r.stoch_period)?;
1235    let ks = axis(r.k)?;
1236    let ds = axis(r.d)?;
1237
1238    if rsi_periods.is_empty() || stoch_periods.is_empty() || ks.is_empty() || ds.is_empty() {
1239        return Err(SrsiError::InvalidRange {
1240            start: r.rsi_period.0.to_string(),
1241            end: r.rsi_period.1.to_string(),
1242            step: r.rsi_period.2.to_string(),
1243        });
1244    }
1245
1246    let mut out = Vec::with_capacity(rsi_periods.len() * stoch_periods.len() * ks.len() * ds.len());
1247    for &rsi_p in &rsi_periods {
1248        for &stoch_p in &stoch_periods {
1249            for &k in &ks {
1250                for &d in &ds {
1251                    out.push(SrsiParams {
1252                        rsi_period: Some(rsi_p),
1253                        stoch_period: Some(stoch_p),
1254                        k: Some(k),
1255                        d: Some(d),
1256                        source: None,
1257                    });
1258                }
1259            }
1260        }
1261    }
1262    Ok(out)
1263}
1264
1265#[inline(always)]
1266pub fn srsi_batch_with_kernel(
1267    data: &[f64],
1268    sweep: &SrsiBatchRange,
1269    k: Kernel,
1270) -> Result<SrsiBatchOutput, SrsiError> {
1271    let kernel = match k {
1272        Kernel::Auto => detect_best_batch_kernel(),
1273        other if other.is_batch() => other,
1274        _ => return Err(SrsiError::InvalidKernelForBatch(k)),
1275    };
1276    let simd = match kernel {
1277        Kernel::Avx512Batch => Kernel::Avx512,
1278        Kernel::Avx2Batch => Kernel::Avx2,
1279        Kernel::ScalarBatch => Kernel::Scalar,
1280        _ => unreachable!(),
1281    };
1282    srsi_batch_par_slice(data, sweep, simd)
1283}
1284
1285#[inline(always)]
1286pub fn srsi_batch_slice(
1287    data: &[f64],
1288    sweep: &SrsiBatchRange,
1289    kern: Kernel,
1290) -> Result<SrsiBatchOutput, SrsiError> {
1291    srsi_batch_inner(data, sweep, kern, false)
1292}
1293
1294#[inline(always)]
1295pub fn srsi_batch_par_slice(
1296    data: &[f64],
1297    sweep: &SrsiBatchRange,
1298    kern: Kernel,
1299) -> Result<SrsiBatchOutput, SrsiError> {
1300    srsi_batch_inner(data, sweep, kern, true)
1301}
1302
1303#[inline(always)]
1304fn srsi_batch_inner_into(
1305    data: &[f64],
1306    sweep: &SrsiBatchRange,
1307    kern: Kernel,
1308    parallel: bool,
1309    k_out: &mut [f64],
1310    d_out: &mut [f64],
1311) -> Result<Vec<SrsiParams>, SrsiError> {
1312    let combos = expand_grid(sweep)?;
1313    if combos.is_empty() {
1314        return Err(SrsiError::InvalidRange {
1315            start: sweep.rsi_period.0.to_string(),
1316            end: sweep.rsi_period.1.to_string(),
1317            step: sweep.rsi_period.2.to_string(),
1318        });
1319    }
1320
1321    let first = data
1322        .iter()
1323        .position(|x| !x.is_nan())
1324        .ok_or(SrsiError::AllValuesNaN)?;
1325
1326    use std::collections::{BTreeMap, BTreeSet};
1327    let mut rsi_cache: BTreeMap<usize, Vec<f64>> = BTreeMap::new();
1328    let uniq_rsi: BTreeSet<usize> = combos.iter().map(|c| c.rsi_period.unwrap()).collect();
1329    for rp in uniq_rsi {
1330        let rsi_in = RsiInput::from_slice(data, RsiParams { period: Some(rp) });
1331        let rsi_out = rsi(&rsi_in)?;
1332        rsi_cache.insert(rp, rsi_out.values);
1333    }
1334
1335    let max_period = combos
1336        .iter()
1337        .map(|c| {
1338            c.rsi_period
1339                .unwrap()
1340                .max(c.stoch_period.unwrap())
1341                .max(c.k.unwrap())
1342                .max(c.d.unwrap())
1343        })
1344        .max()
1345        .unwrap();
1346
1347    if data.len() - first < max_period {
1348        return Err(SrsiError::NotEnoughValidData {
1349            needed: max_period,
1350            valid: data.len() - first,
1351        });
1352    }
1353
1354    let cols = data.len();
1355
1356    let do_row = |row: usize, k_row: &mut [f64], d_row: &mut [f64]| -> Result<(), SrsiError> {
1357        let prm = &combos[row];
1358        let rsi_vals = rsi_cache.get(&prm.rsi_period.unwrap()).expect("cached rsi");
1359        let st_in = StochInput {
1360            data: crate::indicators::stoch::StochData::Slices {
1361                high: rsi_vals,
1362                low: rsi_vals,
1363                close: rsi_vals,
1364            },
1365            params: StochParams {
1366                fastk_period: prm.stoch_period,
1367                slowk_period: prm.k,
1368                slowk_ma_type: Some("sma".to_string()),
1369                slowd_period: prm.d,
1370                slowd_ma_type: Some("sma".to_string()),
1371            },
1372        };
1373        let st = stoch(&st_in)?;
1374        k_row.copy_from_slice(&st.k);
1375        d_row.copy_from_slice(&st.d);
1376        Ok(())
1377    };
1378
1379    if parallel {
1380        #[cfg(not(target_arch = "wasm32"))]
1381        {
1382            k_out
1383                .par_chunks_mut(cols)
1384                .zip(d_out.par_chunks_mut(cols))
1385                .enumerate()
1386                .try_for_each(|(row, (k_row, d_row))| do_row(row, k_row, d_row))?;
1387        }
1388
1389        #[cfg(target_arch = "wasm32")]
1390        {
1391            for (row, (k_row, d_row)) in k_out
1392                .chunks_mut(cols)
1393                .zip(d_out.chunks_mut(cols))
1394                .enumerate()
1395            {
1396                do_row(row, k_row, d_row)?;
1397            }
1398        }
1399    } else {
1400        for (row, (k_row, d_row)) in k_out
1401            .chunks_mut(cols)
1402            .zip(d_out.chunks_mut(cols))
1403            .enumerate()
1404        {
1405            do_row(row, k_row, d_row)?;
1406        }
1407    }
1408
1409    Ok(combos)
1410}
1411
1412#[inline(always)]
1413fn srsi_batch_inner(
1414    data: &[f64],
1415    sweep: &SrsiBatchRange,
1416    kern: Kernel,
1417    parallel: bool,
1418) -> Result<SrsiBatchOutput, SrsiError> {
1419    let combos = expand_grid(sweep)?;
1420    if combos.is_empty() {
1421        return Err(SrsiError::InvalidRange {
1422            start: sweep.rsi_period.0.to_string(),
1423            end: sweep.rsi_period.1.to_string(),
1424            step: sweep.rsi_period.2.to_string(),
1425        });
1426    }
1427    let first = data
1428        .iter()
1429        .position(|x| !x.is_nan())
1430        .ok_or(SrsiError::AllValuesNaN)?;
1431    let max_period = combos
1432        .iter()
1433        .map(|c| {
1434            c.rsi_period
1435                .unwrap()
1436                .max(c.stoch_period.unwrap())
1437                .max(c.k.unwrap())
1438                .max(c.d.unwrap())
1439        })
1440        .max()
1441        .unwrap();
1442    if data.len() - first < max_period {
1443        return Err(SrsiError::NotEnoughValidData {
1444            needed: max_period,
1445            valid: data.len() - first,
1446        });
1447    }
1448    let rows = combos.len();
1449    let cols = data.len();
1450    let _ = rows
1451        .checked_mul(cols)
1452        .ok_or_else(|| SrsiError::InvalidRange {
1453            start: sweep.rsi_period.0.to_string(),
1454            end: sweep.rsi_period.1.to_string(),
1455            step: sweep.rsi_period.2.to_string(),
1456        })?;
1457
1458    if rows == 1 {
1459        let prm = &combos[0];
1460        let res = unsafe {
1461            match kern {
1462                Kernel::Avx512 | Kernel::Avx512Batch => {
1463                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1464                    {
1465                        srsi_avx512(
1466                            data,
1467                            prm.rsi_period.unwrap(),
1468                            prm.stoch_period.unwrap(),
1469                            prm.k.unwrap(),
1470                            prm.d.unwrap(),
1471                        )
1472                    }
1473                    #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
1474                    {
1475                        srsi_scalar(
1476                            data,
1477                            prm.rsi_period.unwrap(),
1478                            prm.stoch_period.unwrap(),
1479                            prm.k.unwrap(),
1480                            prm.d.unwrap(),
1481                        )
1482                    }
1483                }
1484                Kernel::Avx2 | Kernel::Avx2Batch => {
1485                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1486                    {
1487                        srsi_avx2(
1488                            data,
1489                            prm.rsi_period.unwrap(),
1490                            prm.stoch_period.unwrap(),
1491                            prm.k.unwrap(),
1492                            prm.d.unwrap(),
1493                        )
1494                    }
1495                    #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
1496                    {
1497                        srsi_scalar(
1498                            data,
1499                            prm.rsi_period.unwrap(),
1500                            prm.stoch_period.unwrap(),
1501                            prm.k.unwrap(),
1502                            prm.d.unwrap(),
1503                        )
1504                    }
1505                }
1506                _ => srsi_scalar(
1507                    data,
1508                    prm.rsi_period.unwrap(),
1509                    prm.stoch_period.unwrap(),
1510                    prm.k.unwrap(),
1511                    prm.d.unwrap(),
1512                ),
1513            }
1514        }?;
1515        return Ok(SrsiBatchOutput {
1516            k: res.k,
1517            d: res.d,
1518            combos,
1519            rows: 1,
1520            cols,
1521        });
1522    }
1523    let mut k_vals = make_uninit_matrix(rows, cols);
1524    let mut d_vals = make_uninit_matrix(rows, cols);
1525
1526    fn warm_for(c: &SrsiParams, first: usize) -> usize {
1527        let rp = c.rsi_period.unwrap();
1528        let sp = c.stoch_period.unwrap();
1529        let kp = c.k.unwrap();
1530        let dp = c.d.unwrap();
1531
1532        first + rp - 1 + sp - 1 + kp.max(dp) - 1
1533    }
1534
1535    let warmup_periods: Vec<usize> = combos
1536        .iter()
1537        .map(|c| warm_for(c, first).min(cols))
1538        .collect();
1539
1540    init_matrix_prefixes(&mut k_vals, cols, &warmup_periods);
1541    init_matrix_prefixes(&mut d_vals, cols, &warmup_periods);
1542
1543    let mut k_guard = core::mem::ManuallyDrop::new(k_vals);
1544    let mut d_guard = core::mem::ManuallyDrop::new(d_vals);
1545    let k_out: &mut [f64] =
1546        unsafe { core::slice::from_raw_parts_mut(k_guard.as_mut_ptr() as *mut f64, k_guard.len()) };
1547    let d_out: &mut [f64] =
1548        unsafe { core::slice::from_raw_parts_mut(d_guard.as_mut_ptr() as *mut f64, d_guard.len()) };
1549
1550    let combos = srsi_batch_inner_into(data, sweep, kern, parallel, k_out, d_out)?;
1551
1552    let k_values = unsafe {
1553        Vec::from_raw_parts(
1554            k_guard.as_mut_ptr() as *mut f64,
1555            k_guard.len(),
1556            k_guard.capacity(),
1557        )
1558    };
1559
1560    let d_values = unsafe {
1561        Vec::from_raw_parts(
1562            d_guard.as_mut_ptr() as *mut f64,
1563            d_guard.len(),
1564            d_guard.capacity(),
1565        )
1566    };
1567
1568    Ok(SrsiBatchOutput {
1569        k: k_values,
1570        d: d_values,
1571        combos,
1572        rows,
1573        cols,
1574    })
1575}
1576
1577#[inline(always)]
1578pub fn expand_grid_srsi(r: &SrsiBatchRange) -> Result<Vec<SrsiParams>, SrsiError> {
1579    expand_grid(r)
1580}
1581
/// Python-visible handle to a 2-D `f32` device array produced by the CUDA
/// SRSI routines.
///
/// Besides the array it stores an `Arc<Context>` and the id of the device the
/// array was allocated on.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", name = "SrsiDeviceArrayF32", unsendable)]
pub struct SrsiDeviceArrayF32Py {
    pub(crate) inner: DeviceArrayF32,
    pub(crate) _ctx: Arc<Context>,
    pub(crate) device_id: u32,
}

#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl SrsiDeviceArrayF32Py {
    /// Builds the `__cuda_array_interface__` dictionary: shape
    /// `(rows, cols)`, type string `"<f4"`, byte strides of a row-major
    /// layout, the device pointer flagged as not read-only, and version 3.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let elem_bytes = std::mem::size_of::<f32>();
        let (rows, cols) = (self.inner.rows, self.inner.cols);

        let iface = PyDict::new(py);
        iface.set_item("shape", (rows, cols))?;
        iface.set_item("typestr", "<f4")?;
        iface.set_item("strides", (cols * elem_bytes, elem_bytes))?;
        iface.set_item("data", (self.inner.device_ptr() as usize, false))?;

        iface.set_item("version", 3)?;
        Ok(iface)
    }

    /// Returns `(2, device_id)`; the 2 is presumably DLPack's CUDA device
    /// type code.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.device_id as i32)
    }

    /// Exports the array through DLPack.
    ///
    /// Rejects `stream == 0`, a `dl_device` that differs from this array's
    /// device, and `copy=True`. On success the device buffer is moved out of
    /// `self`, which is left holding an empty `0 x 0` placeholder.
    #[pyo3(signature=(stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<PyObject>,
        max_version: Option<PyObject>,
        dl_device: Option<PyObject>,
        copy: Option<PyObject>,
    ) -> PyResult<PyObject> {
        // A stream that is not an integer is ignored; integer 0 is refused.
        let stream_value = stream
            .as_ref()
            .and_then(|s| s.extract::<usize>(py).ok());
        if matches!(stream_value, Some(0)) {
            return Err(PyValueError::new_err(
                "__dlpack__ stream=0 is invalid for CUDA",
            ));
        }

        let (kdl, alloc_dev) = self.__dlpack_device__();
        let requested = dl_device
            .as_ref()
            .and_then(|dev| dev.extract::<(i32, i32)>(py).ok());
        if let Some(target) = requested {
            if target != (kdl, alloc_dev) {
                // Only an explicit, well-formed `copy=True` selects the
                // "copy not implemented" message.
                let wants_copy = matches!(
                    copy.as_ref().map(|c| c.extract::<bool>(py)),
                    Some(Ok(true))
                );
                let message = if wants_copy {
                    "device copy not implemented for __dlpack__"
                } else {
                    "dl_device mismatch for __dlpack__ on SrsiDeviceArrayF32"
                };
                return Err(PyValueError::new_err(message));
            }
        }

        if let Some(flag) = copy.as_ref() {
            if flag.extract::<bool>(py)? {
                return Err(PyValueError::new_err(
                    "copy=True not supported for SrsiDeviceArrayF32",
                ));
            }
        }

        // Swap an empty buffer into `self` so the real one can be moved out.
        let placeholder =
            DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let DeviceArrayF32 { buf, rows, cols } = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32 {
                buf: placeholder,
                rows: 0,
                cols: 0,
            },
        );

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}

#[cfg(all(feature = "python", feature = "cuda"))]
impl SrsiDeviceArrayF32Py {
    /// Wraps a device array together with its context handle and device id.
    pub fn new_from_rust(inner: DeviceArrayF32, ctx_guard: Arc<Context>, device_id: u32) -> Self {
        SrsiDeviceArrayF32Py {
            inner,
            _ctx: ctx_guard,
            device_id,
        }
    }
}
1689
/// Python binding `srsi_cuda_batch_dev`: runs a batch SRSI sweep on a CUDA
/// device and returns the results as device arrays.
///
/// The returned dictionary contains, in this order: `k` and `d`
/// (`SrsiDeviceArrayF32` handles), `rows` (number of combinations), `cols`
/// (input length) and the per-row parameter arrays `rsi_periods`,
/// `stoch_periods`, `k_periods` and `d_periods`.
///
/// # Errors
/// Raises `ValueError` when CUDA is unavailable, the input is not
/// contiguous, or the CUDA routine fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "srsi_cuda_batch_dev")]
#[pyo3(signature = (data_f32, rsi_range, stoch_range, k_range, d_range, device_id=0))]
pub fn srsi_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    data_f32: numpy::PyReadonlyArray1<'py, f32>,
    rsi_range: (usize, usize, usize),
    stoch_range: (usize, usize, usize),
    k_range: (usize, usize, usize),
    d_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use crate::cuda::cuda_available;
    use crate::cuda::oscillators::CudaSrsi;
    use numpy::IntoPyArray;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let slice = data_f32.as_slice()?;
    let sweep = SrsiBatchRange {
        rsi_period: rsi_range,
        stoch_period: stoch_range,
        k: k_range,
        d: d_range,
    };

    // Device work runs with the GIL released.
    let ((pair, combos), ctx, dev_id) = py.allow_threads(|| {
        let cuda = CudaSrsi::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id();
        cuda.srsi_batch_dev(slice, &sweep)
            .map(|res| (res, ctx, dev_id))
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    let out = pyo3::types::PyDict::new(py);
    let k_handle = SrsiDeviceArrayF32Py::new_from_rust(pair.k, ctx.clone(), dev_id);
    out.set_item("k", k_handle)?;
    let d_handle = SrsiDeviceArrayF32Py::new_from_rust(pair.d, ctx, dev_id);
    out.set_item("d", d_handle)?;
    out.set_item("rows", combos.len())?;
    out.set_item("cols", slice.len())?;

    let rsi_periods: Vec<u64> = combos
        .iter()
        .map(|c| c.rsi_period.unwrap() as u64)
        .collect();
    out.set_item("rsi_periods", rsi_periods.into_pyarray(py))?;

    let stoch_periods: Vec<u64> = combos
        .iter()
        .map(|c| c.stoch_period.unwrap() as u64)
        .collect();
    out.set_item("stoch_periods", stoch_periods.into_pyarray(py))?;

    let k_periods: Vec<u64> = combos.iter().map(|c| c.k.unwrap() as u64).collect();
    out.set_item("k_periods", k_periods.into_pyarray(py))?;

    let d_periods: Vec<u64> = combos.iter().map(|c| c.d.unwrap() as u64).collect();
    out.set_item("d_periods", d_periods.into_pyarray(py))?;

    Ok(out)
}
1769
/// Python binding `srsi_cuda_many_series_one_param_dev`: runs SRSI with one
/// parameter set over a 2-D input on a CUDA device.
///
/// The input's two dimensions are read as `(rows, cols)` and its flat data
/// is passed to the time-major CUDA routine. The returned dictionary
/// contains `k` and `d` device-array handles, `rows`, `cols` and the four
/// periods that were used.
///
/// # Errors
/// Raises `ValueError` when CUDA is unavailable, the input is not a
/// contiguous 2-D array of consistent size, or the CUDA routine fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "srsi_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, rsi_period=14, stoch_period=14, k=3, d=3, device_id=0))]
pub fn srsi_cuda_many_series_one_param_dev_py<'py>(
    py: Python<'py>,
    data_tm_f32: numpy::PyReadonlyArray2<'py, f32>,
    rsi_period: usize,
    stoch_period: usize,
    k: usize,
    d: usize,
    device_id: usize,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use crate::cuda::cuda_available;
    use crate::cuda::oscillators::CudaSrsi;
    use numpy::PyUntypedArrayMethods;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let (rows, cols) = match data_tm_f32.shape() {
        [r, c] => (*r, *c),
        _ => return Err(PyValueError::new_err("expected 2D array")),
    };
    let flat = data_tm_f32.as_slice()?;
    let expected = match rows.checked_mul(cols) {
        Some(n) => n,
        None => return Err(PyValueError::new_err("rows*cols overflow")),
    };
    if flat.len() != expected {
        return Err(PyValueError::new_err("time-major input length mismatch"));
    }

    let params = SrsiParams {
        rsi_period: Some(rsi_period),
        stoch_period: Some(stoch_period),
        k: Some(k),
        d: Some(d),
        source: None,
    };

    // Device work runs with the GIL released.
    let (pair, ctx, dev_id) = py.allow_threads(|| {
        let cuda = CudaSrsi::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id();
        cuda.srsi_many_series_one_param_time_major_dev(flat, cols, rows, &params)
            .map(|res| (res, ctx, dev_id))
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    let out = pyo3::types::PyDict::new(py);
    let k_handle = SrsiDeviceArrayF32Py::new_from_rust(pair.k, ctx.clone(), dev_id);
    out.set_item("k", k_handle)?;
    let d_handle = SrsiDeviceArrayF32Py::new_from_rust(pair.d, ctx, dev_id);
    out.set_item("d", d_handle)?;
    out.set_item("rows", rows)?;
    out.set_item("cols", cols)?;
    out.set_item("rsi_period", rsi_period)?;
    out.set_item("stoch_period", stoch_period)?;
    out.set_item("k_period", k)?;
    out.set_item("d_period", d)?;
    Ok(out)
}
1834
/// Python binding `srsi`: computes Stochastic RSI over a 1-D `float64` array
/// and returns the `(k, d)` arrays.
///
/// Periods left as `None` are passed through unset in [`SrsiParams`].
/// `kernel` is validated with `validate_kernel` before use, and the
/// computation itself runs with the GIL released.
///
/// # Errors
/// Raises `ValueError` when the input is not contiguous, the kernel name is
/// rejected, any period is explicitly 0, or the computation fails (the
/// error's display text becomes the message).
#[cfg(feature = "python")]
#[pyfunction(name = "srsi")]
#[pyo3(signature = (data, rsi_period=None, stoch_period=None, k=None, d=None, source=None, kernel=None))]
pub fn srsi_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    rsi_period: Option<usize>,
    stoch_period: Option<usize>,
    k: Option<usize>,
    d: Option<usize>,
    source: Option<&str>,
    kernel: Option<&str>,
) -> PyResult<(Bound<'py, PyArray1<f64>>, Bound<'py, PyArray1<f64>>)> {
    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, false)?;

    // Reject explicit zero periods up front with a dedicated message.
    if [rsi_period, stoch_period, k, d].contains(&Some(0)) {
        return Err(PyValueError::new_err("Invalid period: values must be > 0"));
    }

    let params = SrsiParams {
        rsi_period,
        stoch_period,
        k,
        d,
        source: source.map(|s| s.to_string()),
    };
    let input = SrsiInput::from_slice(slice_in, params);

    // Zero periods cannot reach this point, so the error needs no
    // zero-period special case; its message is forwarded unchanged.
    let (k_vec, d_vec) = py
        .allow_threads(|| srsi_with_kernel(&input, kern).map(|o| (o.k, o.d)))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok((k_vec.into_pyarray(py), d_vec.into_pyarray(py)))
}
1886
/// Python class `SrsiStream`: thin wrapper around the Rust [`SrsiStream`].
#[cfg(feature = "python")]
#[pyclass(name = "SrsiStream")]
pub struct SrsiStreamPy {
    stream: SrsiStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl SrsiStreamPy {
    /// Creates a stream from the given periods; `source` is left unset.
    ///
    /// Raises `ValueError` with the error's display text when
    /// `SrsiStream::try_new` rejects the parameters.
    #[new]
    fn new(
        rsi_period: Option<usize>,
        stoch_period: Option<usize>,
        k: Option<usize>,
        d: Option<usize>,
    ) -> PyResult<Self> {
        SrsiStream::try_new(SrsiParams {
            rsi_period,
            stoch_period,
            k,
            d,
            source: None,
        })
        .map(|stream| SrsiStreamPy { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feeds one value to the stream and returns whatever the wrapped
    /// stream's `update` yields: a pair of values or `None`.
    fn update(&mut self, value: f64) -> Option<(f64, f64)> {
        let inner = &mut self.stream;
        inner.update(value)
    }
}
1919
/// Python binding `srsi_batch`: evaluates a parameter sweep over a 1-D
/// `float64` array.
///
/// The output arrays are allocated as NumPy arrays up front and filled in
/// place with the GIL released. The returned dictionary contains `k` and `d`
/// reshaped to `(rows, cols)` followed by the per-row parameter arrays
/// `rsi_periods`, `stoch_periods`, `k_periods` and `d_periods`.
///
/// # Errors
/// Raises `ValueError` when the input is not contiguous, the kernel name is
/// rejected, the sweep is invalid, `rows*cols` overflows, or the computation
/// fails.
#[cfg(feature = "python")]
#[pyfunction(name = "srsi_batch")]
#[pyo3(signature = (data, rsi_period_range, stoch_period_range, k_range, d_range, source=None, kernel=None))]
pub fn srsi_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    rsi_period_range: (usize, usize, usize),
    stoch_period_range: (usize, usize, usize),
    k_range: (usize, usize, usize),
    d_range: (usize, usize, usize),
    source: Option<&str>,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, true)?;

    let sweep = SrsiBatchRange {
        rsi_period: rsi_period_range,
        stoch_period: stoch_period_range,
        k: k_range,
        d: d_range,
    };

    let to_py_err = |e: SrsiError| PyValueError::new_err(e.to_string());

    // The grid is expanded once here just to size the output arrays.
    let rows = expand_grid(&sweep).map_err(to_py_err)?.len();
    let cols = slice_in.len();
    let total = match rows.checked_mul(cols) {
        Some(n) => n,
        None => return Err(PyValueError::new_err("rows*cols overflow")),
    };

    let k_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let d_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let k_slice = unsafe { k_arr.as_slice_mut()? };
    let d_slice = unsafe { d_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            let resolved = if matches!(kern, Kernel::Auto) {
                detect_best_batch_kernel()
            } else {
                kern
            };
            srsi_batch_inner_into(slice_in, &sweep, resolved, true, k_slice, d_slice)
        })
        .map_err(to_py_err)?;

    let out = PyDict::new(py);
    out.set_item("k", k_arr.reshape((rows, cols))?)?;
    out.set_item("d", d_arr.reshape((rows, cols))?)?;

    let rsi_periods: Vec<u64> = combos
        .iter()
        .map(|c| c.rsi_period.unwrap() as u64)
        .collect();
    out.set_item("rsi_periods", rsi_periods.into_pyarray(py))?;

    let stoch_periods: Vec<u64> = combos
        .iter()
        .map(|c| c.stoch_period.unwrap() as u64)
        .collect();
    out.set_item("stoch_periods", stoch_periods.into_pyarray(py))?;

    let k_periods: Vec<u64> = combos.iter().map(|c| c.k.unwrap() as u64).collect();
    out.set_item("k_periods", k_periods.into_pyarray(py))?;

    let d_periods: Vec<u64> = combos.iter().map(|c| c.d.unwrap() as u64).collect();
    out.set_item("d_periods", d_periods.into_pyarray(py))?;

    Ok(out)
}
2003
2004pub fn srsi_into_slice(
2005    dst_k: &mut [f64],
2006    dst_d: &mut [f64],
2007    input: &SrsiInput,
2008    kern: Kernel,
2009) -> Result<(), SrsiError> {
2010    let data: &[f64] = input.as_ref();
2011
2012    if dst_k.len() != data.len() || dst_d.len() != data.len() {
2013        return Err(SrsiError::OutputLengthMismatch {
2014            expected: data.len(),
2015            k_len: dst_k.len(),
2016            d_len: dst_d.len(),
2017        });
2018    }
2019
2020    let out = srsi_with_kernel(input, kern)?;
2021    dst_k.copy_from_slice(&out.k);
2022    dst_d.copy_from_slice(&out.d);
2023
2024    Ok(())
2025}
2026
2027#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
2028#[inline]
2029pub fn srsi_into(input: &SrsiInput, out_k: &mut [f64], out_d: &mut [f64]) -> Result<(), SrsiError> {
2030    srsi_into_slice(out_k, out_d, input, Kernel::Auto)
2031}
2032
/// WASM entry point: computes Stochastic RSI for `data` and returns a single
/// flat vector holding all K values followed by all D values.
///
/// # Errors
/// Returns a `JsValue` string when `data` is empty, when any period is zero,
/// or when the underlying computation fails (message prefixed with `srsi: `).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn srsi_js(
    data: &[f64],
    rsi_period: usize,
    stoch_period: usize,
    k: usize,
    d: usize,
) -> Result<Vec<f64>, JsValue> {
    if data.is_empty() {
        return Err(JsValue::from_str("srsi: Input data is empty"));
    }
    if [rsi_period, stoch_period, k, d].contains(&0) {
        return Err(JsValue::from_str("srsi: Invalid period"));
    }

    let input = SrsiInput::from_slice(
        data,
        SrsiParams {
            rsi_period: Some(rsi_period),
            stoch_period: Some(stoch_period),
            k: Some(k),
            d: Some(d),
            source: None,
        },
    );

    let out = match srsi_with_kernel(&input, Kernel::Auto) {
        Ok(out) => out,
        Err(e) => return Err(JsValue::from_str(&format!("srsi: {}", e))),
    };

    // Layout: [k_0 .. k_{n-1}, d_0 .. d_{n-1}]
    let mut packed = Vec::with_capacity(2 * data.len());
    packed.extend(out.k.iter().chain(out.d.iter()).copied());

    Ok(packed)
}
2067
/// Reserves room for `len` `f64` values and hands the raw pointer to the
/// caller. The buffer is intentionally leaked here; it must later be released
/// with `srsi_free` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn srsi_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` suppresses the destructor, so the allocation outlives
    // this call.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
2076
/// Releases a buffer previously obtained from `srsi_alloc`.
///
/// `len` must be the same value that was passed to `srsi_alloc`. A null
/// `ptr` is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn srsi_free(ptr: *mut f64, len: usize) {
    // `Vec::from_raw_parts` must never be given a null pointer.
    if ptr.is_null() {
        return;
    }
    // SAFETY: the caller promises `ptr` came from `srsi_alloc(len)`, i.e. a
    // `Vec<f64>` allocation with capacity `len`. The vector is rebuilt with
    // length 0 because the buffer contents may never have been initialised;
    // `f64` has no destructor, so only the allocation itself is released.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
2084
/// WASM entry point that computes Stochastic RSI from, and into, raw linear
/// memory.
///
/// `in_ptr`, `k_ptr` and `d_ptr` are addresses of `f64` buffers with `len`
/// elements each (for example obtained from `srsi_alloc`). Output buffers may
/// overlap the input or each other: in that case the result is computed into
/// temporaries first and copied out afterwards (K first, then D).
///
/// # Errors
/// Returns a `JsValue` string when any pointer is null, when any period is
/// zero, or when the computation itself fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn srsi_into(
    in_ptr: usize,
    k_ptr: usize,
    d_ptr: usize,
    len: usize,
    rsi_period: usize,
    stoch_period: usize,
    k: usize,
    d: usize,
) -> Result<(), JsValue> {
    // Building a slice from a null pointer is undefined behaviour, so reject
    // it before touching memory.
    if in_ptr == 0 || k_ptr == 0 || d_ptr == 0 {
        return Err(JsValue::from_str("null pointer passed to srsi_into"));
    }
    if rsi_period == 0 || stoch_period == 0 || k == 0 || d == 0 {
        return Err(JsValue::from_str("Invalid period"));
    }

    // Two buffers alias when their byte ranges intersect, not only when they
    // start at the same address.
    let span = len.saturating_mul(std::mem::size_of::<f64>());
    let overlaps =
        |a: usize, b: usize| a < b.saturating_add(span) && b < a.saturating_add(span);
    let needs_temp =
        overlaps(in_ptr, k_ptr) || overlaps(in_ptr, d_ptr) || overlaps(k_ptr, d_ptr);

    // SAFETY: `in_ptr` is non-null and the caller guarantees it addresses
    // `len` readable, properly aligned `f64` values.
    let data = unsafe { std::slice::from_raw_parts(in_ptr as *const f64, len) };

    let params = SrsiParams {
        rsi_period: Some(rsi_period),
        stoch_period: Some(stoch_period),
        k: Some(k),
        d: Some(d),
        source: None,
    };
    let input = SrsiInput::from_slice(data, params);

    if needs_temp {
        let mut temp_k = vec![0.0; len];
        let mut temp_d = vec![0.0; len];
        srsi_into_slice(&mut temp_k, &mut temp_d, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // The input view is finished with before any aliasing output is written.
        drop(input);

        // SAFETY: both pointers are non-null and the caller guarantees `len`
        // writable `f64` slots behind each. The two mutable views are created
        // and consumed one after the other, so they are never live together
        // even when the buffers overlap.
        unsafe {
            std::slice::from_raw_parts_mut(k_ptr as *mut f64, len).copy_from_slice(&temp_k);
            std::slice::from_raw_parts_mut(d_ptr as *mut f64, len).copy_from_slice(&temp_d);
        }
    } else {
        // SAFETY: pointers are non-null, the caller guarantees `len` writable
        // `f64` slots behind each, and the overlap test above established that
        // the three regions are pairwise disjoint.
        let (k_out, d_out) = unsafe {
            (
                std::slice::from_raw_parts_mut(k_ptr as *mut f64, len),
                std::slice::from_raw_parts_mut(d_ptr as *mut f64, len),
            )
        };
        srsi_into_slice(k_out, d_out, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(())
}
2135
/// Sweep description deserialised from the JS `config` object passed to the
/// `srsi_batch` binding. Each tuple is copied verbatim into the matching
/// field of `SrsiBatchRange`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct SrsiBatchConfig {
    /// Range tuple for the RSI period.
    pub rsi_period_range: (usize, usize, usize),
    /// Range tuple for the stochastic period.
    pub stoch_period_range: (usize, usize, usize),
    /// Range tuple for the K parameter.
    pub k_range: (usize, usize, usize),
    /// Range tuple for the D parameter.
    pub d_range: (usize, usize, usize),
}
2144
/// Result object serialised back to JS by the `srsi_batch` binding.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct SrsiBatchJsOutput {
    /// K output of the batch run, taken as-is from the batch result.
    pub k_values: Vec<f64>,
    /// D output of the batch run, taken as-is from the batch result.
    pub d_values: Vec<f64>,
    /// Row count reported by the batch result.
    pub rows: usize,
    /// Column count reported by the batch result.
    pub cols: usize,
    /// Parameter combinations reported by the batch result.
    pub combos: Vec<SrsiParams>,
}
2154
/// WASM `srsi_batch` binding: runs a parameter sweep described by a JS config
/// object and returns a serialised `SrsiBatchJsOutput`.
///
/// # Errors
/// Returns a `JsValue` string when `config` cannot be deserialised, when the
/// batch computation fails, or when the result cannot be serialised.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = srsi_batch)]
pub fn srsi_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let parsed = match serde_wasm_bindgen::from_value::<SrsiBatchConfig>(config) {
        Ok(cfg) => cfg,
        Err(e) => return Err(JsValue::from_str(&format!("Invalid config: {}", e))),
    };

    let grid = SrsiBatchRange {
        rsi_period: parsed.rsi_period_range,
        stoch_period: parsed.stoch_period_range,
        k: parsed.k_range,
        d: parsed.d_range,
    };

    let batch = srsi_batch_inner(data, &grid, Kernel::Auto, false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let payload = SrsiBatchJsOutput {
        k_values: batch.k,
        d_values: batch.d,
        rows: batch.rows,
        cols: batch.cols,
        combos: batch.combos,
    };

    match serde_wasm_bindgen::to_value(&payload) {
        Ok(js) => Ok(js),
        Err(e) => Err(JsValue::from_str(&format!("Serialization error: {}", e))),
    }
}
2182
/// WASM entry point that runs a Stochastic RSI parameter sweep from, and
/// into, raw linear memory.
///
/// `in_ptr` addresses `len` input samples. `k_ptr` and `d_ptr` must each
/// address `rows * len` writable `f64` slots, where `rows` is the number of
/// combinations produced by `expand_grid` for the given ranges; the caller is
/// responsible for sizing them accordingly and for keeping the three regions
/// disjoint.
///
/// Returns the number of parameter combinations (rows) written.
///
/// # Errors
/// Returns a `JsValue` string when any pointer is null, when the grid cannot
/// be expanded, when `rows * len` overflows, or when the computation fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn srsi_batch_into(
    in_ptr: usize,
    k_ptr: usize,
    d_ptr: usize,
    len: usize,
    rsi_period_start: usize,
    rsi_period_end: usize,
    rsi_period_step: usize,
    stoch_period_start: usize,
    stoch_period_end: usize,
    stoch_period_step: usize,
    k_start: usize,
    k_end: usize,
    k_step: usize,
    d_start: usize,
    d_end: usize,
    d_step: usize,
) -> Result<usize, JsValue> {
    // Building a slice from a null pointer is undefined behaviour, so reject
    // it before touching memory.
    if in_ptr == 0 || k_ptr == 0 || d_ptr == 0 {
        return Err(JsValue::from_str("null pointer passed to srsi_batch_into"));
    }

    let sweep = SrsiBatchRange {
        rsi_period: (rsi_period_start, rsi_period_end, rsi_period_step),
        stoch_period: (stoch_period_start, stoch_period_end, stoch_period_step),
        k: (k_start, k_end, k_step),
        d: (d_start, d_end, d_step),
    };

    // All fallible sizing happens before any raw memory is viewed.
    let rows = expand_grid(&sweep)
        .map_err(|e| JsValue::from_str(&e.to_string()))?
        .len();
    let cols = len;
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;

    // SAFETY: all three pointers are non-null. The caller guarantees that
    // `in_ptr` addresses `len` readable `f64` values, that `k_ptr` and `d_ptr`
    // each address `total` writable `f64` slots, and that the regions do not
    // overlap.
    let (data, k_out, d_out) = unsafe {
        (
            std::slice::from_raw_parts(in_ptr as *const f64, len),
            std::slice::from_raw_parts_mut(k_ptr as *mut f64, total),
            std::slice::from_raw_parts_mut(d_ptr as *mut f64, total),
        )
    };

    srsi_batch_inner_into(data, &sweep, Kernel::Auto, false, k_out, d_out)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    Ok(rows)
}
2229
2230#[cfg(test)]
2231mod tests {
2232    use super::*;
2233    use crate::skip_if_unsupported;
2234    use crate::utilities::data_loader::read_candles_from_csv;
2235
2236    fn check_srsi_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2237        skip_if_unsupported!(kernel, test_name);
2238        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2239        let candles = read_candles_from_csv(file_path)?;
2240        let default_params = SrsiParams {
2241            rsi_period: None,
2242            stoch_period: None,
2243            k: None,
2244            d: None,
2245            source: None,
2246        };
2247        let input = SrsiInput::from_candles(&candles, "close", default_params);
2248        let output = srsi_with_kernel(&input, kernel)?;
2249        assert_eq!(output.k.len(), candles.close.len());
2250        assert_eq!(output.d.len(), candles.close.len());
2251        Ok(())
2252    }
2253
2254    fn check_srsi_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2255        skip_if_unsupported!(kernel, test_name);
2256        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2257        let candles = read_candles_from_csv(file_path)?;
2258        let params = SrsiParams::default();
2259        let input = SrsiInput::from_candles(&candles, "close", params);
2260        let result = srsi_with_kernel(&input, kernel)?;
2261        assert_eq!(result.k.len(), candles.close.len());
2262        assert_eq!(result.d.len(), candles.close.len());
2263        let last_five_k = [
2264            65.52066633236464,
2265            61.22507053191985,
2266            57.220471530042644,
2267            64.61344854988147,
2268            60.66534359318523,
2269        ];
2270        let last_five_d = [
2271            64.33503158970049,
2272            64.42143544464182,
2273            61.32206946477942,
2274            61.01966353728503,
2275            60.83308789104016,
2276        ];
2277        let k_slice = &result.k[result.k.len() - 5..];
2278        let d_slice = &result.d[result.d.len() - 5..];
2279        for i in 0..5 {
2280            let diff_k = (k_slice[i] - last_five_k[i]).abs();
2281            let diff_d = (d_slice[i] - last_five_d[i]).abs();
2282            assert!(
2283                diff_k < 1e-6,
2284                "Mismatch in SRSI K at index {}: got {}, expected {}",
2285                i,
2286                k_slice[i],
2287                last_five_k[i]
2288            );
2289            assert!(
2290                diff_d < 1e-6,
2291                "Mismatch in SRSI D at index {}: got {}, expected {}",
2292                i,
2293                d_slice[i],
2294                last_five_d[i]
2295            );
2296        }
2297        Ok(())
2298    }
2299
2300    fn check_srsi_from_slice(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2301        skip_if_unsupported!(kernel, test_name);
2302        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2303        let candles = read_candles_from_csv(file_path)?;
2304        let slice_data = candles.close.as_slice();
2305        let params = SrsiParams {
2306            rsi_period: Some(3),
2307            stoch_period: Some(3),
2308            k: Some(2),
2309            d: Some(2),
2310            source: Some("close".to_string()),
2311        };
2312        let input = SrsiInput::from_slice(&slice_data, params);
2313        let output = srsi_with_kernel(&input, kernel)?;
2314        assert_eq!(output.k.len(), slice_data.len());
2315        assert_eq!(output.d.len(), slice_data.len());
2316        Ok(())
2317    }
2318
2319    fn check_srsi_custom_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2320        skip_if_unsupported!(kernel, test_name);
2321        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2322        let candles = read_candles_from_csv(file_path)?;
2323        let params = SrsiParams {
2324            rsi_period: Some(10),
2325            stoch_period: Some(10),
2326            k: Some(4),
2327            d: Some(4),
2328            source: Some("hlc3".to_string()),
2329        };
2330        let input = SrsiInput::from_candles(&candles, "hlc3", params);
2331        let output = srsi_with_kernel(&input, kernel)?;
2332        assert_eq!(output.k.len(), candles.close.len());
2333        assert_eq!(output.d.len(), candles.close.len());
2334        Ok(())
2335    }
2336
2337    #[cfg(debug_assertions)]
2338    fn check_srsi_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2339        skip_if_unsupported!(kernel, test_name);
2340
2341        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2342        let candles = read_candles_from_csv(file_path)?;
2343
2344        let test_params = vec![
2345            SrsiParams::default(),
2346            SrsiParams {
2347                rsi_period: Some(2),
2348                stoch_period: Some(2),
2349                k: Some(2),
2350                d: Some(2),
2351                source: None,
2352            },
2353            SrsiParams {
2354                rsi_period: Some(5),
2355                stoch_period: Some(5),
2356                k: Some(3),
2357                d: Some(3),
2358                source: None,
2359            },
2360            SrsiParams {
2361                rsi_period: Some(10),
2362                stoch_period: Some(10),
2363                k: Some(5),
2364                d: Some(5),
2365                source: None,
2366            },
2367            SrsiParams {
2368                rsi_period: Some(20),
2369                stoch_period: Some(20),
2370                k: Some(7),
2371                d: Some(7),
2372                source: None,
2373            },
2374            SrsiParams {
2375                rsi_period: Some(50),
2376                stoch_period: Some(50),
2377                k: Some(10),
2378                d: Some(10),
2379                source: None,
2380            },
2381            SrsiParams {
2382                rsi_period: Some(7),
2383                stoch_period: Some(14),
2384                k: Some(3),
2385                d: Some(5),
2386                source: None,
2387            },
2388            SrsiParams {
2389                rsi_period: Some(14),
2390                stoch_period: Some(7),
2391                k: Some(5),
2392                d: Some(3),
2393                source: None,
2394            },
2395            SrsiParams {
2396                rsi_period: Some(21),
2397                stoch_period: Some(14),
2398                k: Some(6),
2399                d: Some(4),
2400                source: None,
2401            },
2402            SrsiParams {
2403                rsi_period: Some(100),
2404                stoch_period: Some(100),
2405                k: Some(20),
2406                d: Some(20),
2407                source: None,
2408            },
2409        ];
2410
2411        for (param_idx, params) in test_params.iter().enumerate() {
2412            let input = SrsiInput::from_candles(&candles, "close", params.clone());
2413            let output = srsi_with_kernel(&input, kernel)?;
2414
2415            for (i, &val) in output.k.iter().enumerate() {
2416                if val.is_nan() {
2417                    continue;
2418                }
2419
2420                let bits = val.to_bits();
2421
2422                if bits == 0x11111111_11111111 {
2423                    panic!(
2424						"[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} in K output \
2425						 with params: rsi_period={}, stoch_period={}, k={}, d={} (param set {})",
2426						test_name, val, bits, i,
2427						params.rsi_period.unwrap_or(14),
2428						params.stoch_period.unwrap_or(14),
2429						params.k.unwrap_or(3),
2430						params.d.unwrap_or(3),
2431						param_idx
2432					);
2433                }
2434
2435                if bits == 0x22222222_22222222 {
2436                    panic!(
2437						"[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} in K output \
2438						 with params: rsi_period={}, stoch_period={}, k={}, d={} (param set {})",
2439						test_name, val, bits, i,
2440						params.rsi_period.unwrap_or(14),
2441						params.stoch_period.unwrap_or(14),
2442						params.k.unwrap_or(3),
2443						params.d.unwrap_or(3),
2444						param_idx
2445					);
2446                }
2447
2448                if bits == 0x33333333_33333333 {
2449                    panic!(
2450						"[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} in K output \
2451						 with params: rsi_period={}, stoch_period={}, k={}, d={} (param set {})",
2452						test_name, val, bits, i,
2453						params.rsi_period.unwrap_or(14),
2454						params.stoch_period.unwrap_or(14),
2455						params.k.unwrap_or(3),
2456						params.d.unwrap_or(3),
2457						param_idx
2458					);
2459                }
2460            }
2461
2462            for (i, &val) in output.d.iter().enumerate() {
2463                if val.is_nan() {
2464                    continue;
2465                }
2466
2467                let bits = val.to_bits();
2468
2469                if bits == 0x11111111_11111111 {
2470                    panic!(
2471						"[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} in D output \
2472						 with params: rsi_period={}, stoch_period={}, k={}, d={} (param set {})",
2473						test_name, val, bits, i,
2474						params.rsi_period.unwrap_or(14),
2475						params.stoch_period.unwrap_or(14),
2476						params.k.unwrap_or(3),
2477						params.d.unwrap_or(3),
2478						param_idx
2479					);
2480                }
2481
2482                if bits == 0x22222222_22222222 {
2483                    panic!(
2484						"[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} in D output \
2485						 with params: rsi_period={}, stoch_period={}, k={}, d={} (param set {})",
2486						test_name, val, bits, i,
2487						params.rsi_period.unwrap_or(14),
2488						params.stoch_period.unwrap_or(14),
2489						params.k.unwrap_or(3),
2490						params.d.unwrap_or(3),
2491						param_idx
2492					);
2493                }
2494
2495                if bits == 0x33333333_33333333 {
2496                    panic!(
2497						"[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} in D output \
2498						 with params: rsi_period={}, stoch_period={}, k={}, d={} (param set {})",
2499						test_name, val, bits, i,
2500						params.rsi_period.unwrap_or(14),
2501						params.stoch_period.unwrap_or(14),
2502						params.k.unwrap_or(3),
2503						params.d.unwrap_or(3),
2504						param_idx
2505					);
2506                }
2507            }
2508        }
2509
2510        Ok(())
2511    }
2512
2513    #[cfg(not(debug_assertions))]
2514    fn check_srsi_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2515        Ok(())
2516    }
2517
2518    #[cfg(feature = "proptest")]
2519    #[allow(clippy::float_cmp)]
2520    fn check_srsi_property(
2521        test_name: &str,
2522        kernel: Kernel,
2523    ) -> Result<(), Box<dyn std::error::Error>> {
2524        use proptest::prelude::*;
2525        skip_if_unsupported!(kernel, test_name);
2526
2527        let strat = (2usize..=20, 2usize..=20, 2usize..=10, 2usize..=10).prop_flat_map(
2528            |(rsi_period, stoch_period, k, d)| {
2529                let min_data_needed = rsi_period + stoch_period.max(k).max(d) + 10;
2530                (
2531                    prop::collection::vec(
2532                        (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
2533                        min_data_needed..400,
2534                    ),
2535                    Just(rsi_period),
2536                    Just(stoch_period),
2537                    Just(k),
2538                    Just(d),
2539                )
2540            },
2541        );
2542
2543        proptest::test_runner::TestRunner::default()
2544            .run(&strat, |(data, rsi_period, stoch_period, k, d)| {
2545                let params = SrsiParams {
2546                    rsi_period: Some(rsi_period),
2547                    stoch_period: Some(stoch_period),
2548                    k: Some(k),
2549                    d: Some(d),
2550                    source: None,
2551                };
2552                let input = SrsiInput::from_slice(&data, params.clone());
2553
2554                let output_result = srsi_with_kernel(&input, kernel);
2555                let ref_output_result = srsi_with_kernel(&input, Kernel::Scalar);
2556
2557                match (output_result, ref_output_result) {
2558                    (Ok(output), Ok(ref_output)) => {
2559                        let expected_min_warmup = rsi_period;
2560
2561                        for i in 0..data.len() {
2562                            if !output.k[i].is_nan() {
2563                                prop_assert!(
2564                                    output.k[i] >= -1e-9 && output.k[i] <= 100.0 + 1e-9,
2565                                    "idx {}: K value {} is out of bounds [0, 100]",
2566                                    i,
2567                                    output.k[i]
2568                                );
2569                            }
2570                            if !output.d[i].is_nan() {
2571                                prop_assert!(
2572                                    output.d[i] >= -1e-9 && output.d[i] <= 100.0 + 1e-9,
2573                                    "idx {}: D value {} is out of bounds [0, 100]",
2574                                    i,
2575                                    output.d[i]
2576                                );
2577                            }
2578                        }
2579
2580                        for i in 0..expected_min_warmup.min(data.len()) {
2581                            prop_assert!(
2582                                output.k[i].is_nan(),
2583                                "idx {}: Expected NaN during early warmup for K, got {}",
2584                                i,
2585                                output.k[i]
2586                            );
2587                            prop_assert!(
2588                                output.d[i].is_nan(),
2589                                "idx {}: Expected NaN during early warmup for D, got {}",
2590                                i,
2591                                output.d[i]
2592                            );
2593                        }
2594
2595                        let has_valid_k = output.k.iter().any(|&x| !x.is_nan());
2596                        let has_valid_d = output.d.iter().any(|&x| !x.is_nan());
2597                        if data.len() > rsi_period + stoch_period + k + d {
2598                            prop_assert!(
2599                                has_valid_k,
2600                                "Expected at least one valid K value with sufficient data"
2601                            );
2602                            prop_assert!(
2603                                has_valid_d,
2604                                "Expected at least one valid D value with sufficient data"
2605                            );
2606                        }
2607
2608                        if data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10) {
2609                            let last_k = output.k[data.len() - 1];
2610                            let last_d = output.d[data.len() - 1];
2611                            if !last_k.is_nan() && !last_d.is_nan() {
2612                                prop_assert!(
2613                                    (last_k - 50.0).abs() < 10.0,
2614                                    "Constant data should produce K near 50, got {}",
2615                                    last_k
2616                                );
2617                                prop_assert!(
2618                                    (last_d - 50.0).abs() < 10.0,
2619                                    "Constant data should produce D near 50, got {}",
2620                                    last_d
2621                                );
2622                            }
2623                        }
2624
2625                        let is_increasing = data.windows(2).all(|w| w[1] > w[0]);
2626                        if is_increasing && has_valid_k {
2627                            let last_k = output.k[data.len() - 1];
2628                            if !last_k.is_nan() {
2629                                prop_assert!(
2630                                    last_k > 50.0,
2631                                    "Strictly increasing prices should produce K > 50, got {}",
2632                                    last_k
2633                                );
2634                            }
2635                        }
2636
2637                        let is_decreasing = data.windows(2).all(|w| w[1] < w[0]);
2638                        if is_decreasing && has_valid_k {
2639                            let last_k = output.k[data.len() - 1];
2640                            if !last_k.is_nan() {
2641                                prop_assert!(
2642                                    last_k < 50.0,
2643                                    "Strictly decreasing prices should produce K < 50, got {}",
2644                                    last_k
2645                                );
2646                            }
2647                        }
2648
2649                        for i in 0..data.len() {
2650                            let k_val = output.k[i];
2651                            let d_val = output.d[i];
2652                            let ref_k = ref_output.k[i];
2653                            let ref_d = ref_output.d[i];
2654
2655                            if !k_val.is_finite() || !ref_k.is_finite() {
2656                                prop_assert!(
2657                                    k_val.to_bits() == ref_k.to_bits(),
2658                                    "K finite/NaN mismatch idx {}: {} vs {}",
2659                                    i,
2660                                    k_val,
2661                                    ref_k
2662                                );
2663                            } else {
2664                                let k_ulp_diff = k_val.to_bits().abs_diff(ref_k.to_bits());
2665                                prop_assert!(
2666                                    (k_val - ref_k).abs() <= 1e-9 || k_ulp_diff <= 4,
2667                                    "K mismatch idx {}: {} vs {} (ULP={})",
2668                                    i,
2669                                    k_val,
2670                                    ref_k,
2671                                    k_ulp_diff
2672                                );
2673                            }
2674
2675                            if !d_val.is_finite() || !ref_d.is_finite() {
2676                                prop_assert!(
2677                                    d_val.to_bits() == ref_d.to_bits(),
2678                                    "D finite/NaN mismatch idx {}: {} vs {}",
2679                                    i,
2680                                    d_val,
2681                                    ref_d
2682                                );
2683                            } else {
2684                                let d_ulp_diff = d_val.to_bits().abs_diff(ref_d.to_bits());
2685                                prop_assert!(
2686                                    (d_val - ref_d).abs() <= 1e-9 || d_ulp_diff <= 4,
2687                                    "D mismatch idx {}: {} vs {} (ULP={})",
2688                                    i,
2689                                    d_val,
2690                                    ref_d,
2691                                    d_ulp_diff
2692                                );
2693                            }
2694                        }
2695
2696                        let output2 = srsi_with_kernel(&input, kernel).unwrap();
2697                        for i in 0..data.len() {
2698                            prop_assert!(
2699                                output.k[i].to_bits() == output2.k[i].to_bits(),
2700                                "K determinism failed at idx {}: {} vs {}",
2701                                i,
2702                                output.k[i],
2703                                output2.k[i]
2704                            );
2705                            prop_assert!(
2706                                output.d[i].to_bits() == output2.d[i].to_bits(),
2707                                "D determinism failed at idx {}: {} vs {}",
2708                                i,
2709                                output.d[i],
2710                                output2.d[i]
2711                            );
2712                        }
2713                    }
2714                    (Err(_), Err(_)) => {}
2715                    (Ok(_), Err(e)) => {
2716                        prop_assert!(false, "Kernel succeeded but scalar failed: {:?}", e);
2717                    }
2718                    (Err(e), Ok(_)) => {
2719                        prop_assert!(false, "Kernel failed but scalar succeeded: {:?}", e);
2720                    }
2721                }
2722
2723                Ok(())
2724            })
2725            .unwrap();
2726
2727        Ok(())
2728    }
2729
2730    macro_rules! generate_all_srsi_tests {
2731        ($($test_fn:ident),*) => {
2732            paste::paste! {
2733                $(
2734                    #[test]
2735                    fn [<$test_fn _scalar_f64>]() {
2736                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
2737                    }
2738                )*
2739                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2740                $(
2741                    #[test]
2742                    fn [<$test_fn _avx2_f64>]() {
2743                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
2744                    }
2745                    #[test]
2746                    fn [<$test_fn _avx512_f64>]() {
2747                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
2748                    }
2749                )*
2750            }
2751        }
2752    }
2753
2754    generate_all_srsi_tests!(
2755        check_srsi_partial_params,
2756        check_srsi_accuracy,
2757        check_srsi_custom_params,
2758        check_srsi_from_slice,
2759        check_srsi_no_poison
2760    );
2761
2762    #[test]
2763    fn test_srsi_into_slice_size_mismatch() {
2764        let data: Vec<f64> = (1..=50).map(|x| x as f64).collect();
2765        let data_len = data.len();
2766        let params = SrsiParams::default();
2767        let input = SrsiInput::from_slice(&data, params);
2768
2769        let mut k_small = vec![0.0; 30];
2770        let mut d_correct = vec![0.0; data_len];
2771        let result = srsi_into_slice(&mut k_small, &mut d_correct, &input, Kernel::Scalar);
2772        match result {
2773            Err(SrsiError::OutputLengthMismatch {
2774                expected,
2775                k_len,
2776                d_len,
2777            }) => {
2778                assert_eq!(expected, data_len);
2779                assert_eq!(k_len, 30);
2780                assert_eq!(d_len, data_len);
2781            }
2782            _ => panic!("Expected SizeMismatch error with k buffer too small"),
2783        }
2784
2785        let mut k_correct = vec![0.0; data_len];
2786        let mut d_small = vec![0.0; 35];
2787        let result = srsi_into_slice(&mut k_correct, &mut d_small, &input, Kernel::Scalar);
2788        match result {
2789            Err(SrsiError::OutputLengthMismatch {
2790                expected,
2791                k_len,
2792                d_len,
2793            }) => {
2794                assert_eq!(expected, data_len);
2795                assert_eq!(k_len, data_len);
2796                assert_eq!(d_len, 35);
2797            }
2798            _ => panic!("Expected SizeMismatch error with d buffer too small"),
2799        }
2800
2801        let mut k_wrong = vec![0.0; 60];
2802        let mut d_wrong = vec![0.0; 70];
2803        let result = srsi_into_slice(&mut k_wrong, &mut d_wrong, &input, Kernel::Scalar);
2804        match result {
2805            Err(SrsiError::OutputLengthMismatch {
2806                expected,
2807                k_len,
2808                d_len,
2809            }) => {
2810                assert_eq!(expected, data_len);
2811                assert_eq!(k_len, 60);
2812                assert_eq!(d_len, 70);
2813            }
2814            _ => panic!("Expected SizeMismatch error with both buffers wrong size"),
2815        }
2816
2817        let mut k_ok = vec![0.0; data_len];
2818        let mut d_ok = vec![0.0; data_len];
2819        let result = srsi_into_slice(&mut k_ok, &mut d_ok, &input, Kernel::Scalar);
2820        assert!(
2821            result.is_ok(),
2822            "Should succeed with correct buffer sizes. Error: {:?}",
2823            result
2824        );
2825    }
2826
2827    #[cfg(feature = "proptest")]
2828    generate_all_srsi_tests!(check_srsi_property);
2829
2830    #[test]
2831    fn test_srsi_into_matches_api() {
2832        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2833        let c = read_candles_from_csv(file).expect("load csv");
2834        let input = SrsiInput::from_candles(&c, "close", SrsiParams::default());
2835
2836        let base = srsi(&input).expect("srsi baseline");
2837
2838        let mut out_k = vec![0.0; c.close.len()];
2839        let mut out_d = vec![0.0; c.close.len()];
2840
2841        #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
2842        {
2843            srsi_into(&input, &mut out_k, &mut out_d).expect("srsi_into");
2844        }
2845        #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2846        {
2847            srsi_into_slice(&mut out_k, &mut out_d, &input, Kernel::Auto).expect("srsi_into_slice");
2848        }
2849
2850        assert_eq!(base.k.len(), c.close.len());
2851        assert_eq!(base.d.len(), c.close.len());
2852        assert_eq!(out_k.len(), c.close.len());
2853        assert_eq!(out_d.len(), c.close.len());
2854
2855        fn eq_or_both_nan(a: f64, b: f64) -> bool {
2856            (a.is_nan() && b.is_nan()) || (a == b)
2857        }
2858
2859        for i in 0..c.close.len() {
2860            assert!(
2861                eq_or_both_nan(base.k[i], out_k[i]),
2862                "SRSI K mismatch at {i}: {} vs {}",
2863                base.k[i],
2864                out_k[i]
2865            );
2866            assert!(
2867                eq_or_both_nan(base.d[i], out_d[i]),
2868                "SRSI D mismatch at {i}: {} vs {}",
2869                base.d[i],
2870                out_d[i]
2871            );
2872        }
2873    }
2874
2875    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2876        skip_if_unsupported!(kernel, test);
2877        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2878        let c = read_candles_from_csv(file)?;
2879        let output = SrsiBatchBuilder::new()
2880            .kernel(kernel)
2881            .apply_slice(&c.close)?;
2882        let def = SrsiParams::default();
2883        let k_row = output.k_for(&def).expect("default k row missing");
2884        let d_row = output.d_for(&def).expect("default d row missing");
2885        assert_eq!(k_row.len(), c.close.len());
2886        assert_eq!(d_row.len(), c.close.len());
2887        Ok(())
2888    }
2889
2890    #[cfg(debug_assertions)]
2891    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2892        skip_if_unsupported!(kernel, test);
2893
2894        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2895        let c = read_candles_from_csv(file)?;
2896
2897        let test_configs = vec![
2898            ((2, 10, 2), (2, 10, 2), (2, 4, 1), (2, 4, 1)),
2899            ((5, 25, 5), (5, 25, 5), (3, 7, 2), (3, 7, 2)),
2900            ((30, 60, 15), (30, 60, 15), (5, 10, 5), (5, 10, 5)),
2901            ((2, 5, 1), (2, 5, 1), (2, 3, 1), (2, 3, 1)),
2902            ((10, 30, 10), (5, 15, 5), (3, 6, 3), (3, 6, 3)),
2903            ((14, 14, 0), (14, 14, 0), (3, 3, 0), (3, 3, 0)),
2904            ((7, 21, 7), (14, 28, 14), (3, 9, 3), (3, 9, 3)),
2905        ];
2906
2907        for (cfg_idx, &(rsi_range, stoch_range, k_range, d_range)) in
2908            test_configs.iter().enumerate()
2909        {
2910            let output = SrsiBatchBuilder::new()
2911                .kernel(kernel)
2912                .rsi_period_range(rsi_range.0, rsi_range.1, rsi_range.2)
2913                .stoch_period_range(stoch_range.0, stoch_range.1, stoch_range.2)
2914                .k_range(k_range.0, k_range.1, k_range.2)
2915                .d_range(d_range.0, d_range.1, d_range.2)
2916                .apply_slice(&c.close)?;
2917
2918            for (idx, &val) in output.k.iter().enumerate() {
2919                if val.is_nan() {
2920                    continue;
2921                }
2922
2923                let bits = val.to_bits();
2924                let row = idx / output.cols;
2925                let col = idx % output.cols;
2926                let combo = &output.combos[row];
2927
2928                if bits == 0x11111111_11111111 {
2929                    panic!(
2930						"[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) in K output \
2931						 at row {} col {} (flat index {}) with params: rsi_period={}, stoch_period={}, k={}, d={}",
2932						test, cfg_idx, val, bits, row, col, idx,
2933						combo.rsi_period.unwrap_or(14),
2934						combo.stoch_period.unwrap_or(14),
2935						combo.k.unwrap_or(3),
2936						combo.d.unwrap_or(3)
2937					);
2938                }
2939
2940                if bits == 0x22222222_22222222 {
2941                    panic!(
2942						"[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) in K output \
2943						 at row {} col {} (flat index {}) with params: rsi_period={}, stoch_period={}, k={}, d={}",
2944						test, cfg_idx, val, bits, row, col, idx,
2945						combo.rsi_period.unwrap_or(14),
2946						combo.stoch_period.unwrap_or(14),
2947						combo.k.unwrap_or(3),
2948						combo.d.unwrap_or(3)
2949					);
2950                }
2951
2952                if bits == 0x33333333_33333333 {
2953                    panic!(
2954						"[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) in K output \
2955						 at row {} col {} (flat index {}) with params: rsi_period={}, stoch_period={}, k={}, d={}",
2956						test, cfg_idx, val, bits, row, col, idx,
2957						combo.rsi_period.unwrap_or(14),
2958						combo.stoch_period.unwrap_or(14),
2959						combo.k.unwrap_or(3),
2960						combo.d.unwrap_or(3)
2961					);
2962                }
2963            }
2964
2965            for (idx, &val) in output.d.iter().enumerate() {
2966                if val.is_nan() {
2967                    continue;
2968                }
2969
2970                let bits = val.to_bits();
2971                let row = idx / output.cols;
2972                let col = idx % output.cols;
2973                let combo = &output.combos[row];
2974
2975                if bits == 0x11111111_11111111 {
2976                    panic!(
2977						"[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) in D output \
2978						 at row {} col {} (flat index {}) with params: rsi_period={}, stoch_period={}, k={}, d={}",
2979						test, cfg_idx, val, bits, row, col, idx,
2980						combo.rsi_period.unwrap_or(14),
2981						combo.stoch_period.unwrap_or(14),
2982						combo.k.unwrap_or(3),
2983						combo.d.unwrap_or(3)
2984					);
2985                }
2986
2987                if bits == 0x22222222_22222222 {
2988                    panic!(
2989						"[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) in D output \
2990						 at row {} col {} (flat index {}) with params: rsi_period={}, stoch_period={}, k={}, d={}",
2991						test, cfg_idx, val, bits, row, col, idx,
2992						combo.rsi_period.unwrap_or(14),
2993						combo.stoch_period.unwrap_or(14),
2994						combo.k.unwrap_or(3),
2995						combo.d.unwrap_or(3)
2996					);
2997                }
2998
2999                if bits == 0x33333333_33333333 {
3000                    panic!(
3001						"[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) in D output \
3002						 at row {} col {} (flat index {}) with params: rsi_period={}, stoch_period={}, k={}, d={}",
3003						test, cfg_idx, val, bits, row, col, idx,
3004						combo.rsi_period.unwrap_or(14),
3005						combo.stoch_period.unwrap_or(14),
3006						combo.k.unwrap_or(3),
3007						combo.d.unwrap_or(3)
3008					);
3009                }
3010            }
3011        }
3012
3013        Ok(())
3014    }
3015
3016    #[cfg(not(debug_assertions))]
3017    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
3018        Ok(())
3019    }
3020
3021    macro_rules! gen_batch_tests {
3022        ($fn_name:ident) => {
3023            paste::paste! {
3024                #[test] fn [<$fn_name _scalar>]()      {
3025                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
3026                }
3027                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3028                #[test] fn [<$fn_name _avx2>]()        {
3029                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
3030                }
3031                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3032                #[test] fn [<$fn_name _avx512>]()      {
3033                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
3034                }
3035                #[test] fn [<$fn_name _auto_detect>]() {
3036                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
3037                }
3038            }
3039        };
3040    }
3041    gen_batch_tests!(check_batch_default_row);
3042    gen_batch_tests!(check_batch_no_poison);
3043}