Skip to main content

vector_ta/indicators/
rsi.rs

1use crate::utilities::data_loader::{source_type, Candles};
2#[cfg(all(feature = "python", feature = "cuda"))]
3use crate::utilities::dlpack_cuda::{make_device_array_py, DeviceArrayF32Py};
4use crate::utilities::enums::Kernel;
5use crate::utilities::helpers::{
6    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
7    make_uninit_matrix,
8};
9#[cfg(feature = "python")]
10use crate::utilities::kernel_validation::validate_kernel;
11use aligned_vec::{AVec, CACHELINE_ALIGN};
12#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
13use core::arch::x86_64::*;
14#[cfg(feature = "python")]
15use numpy::{IntoPyArray, PyArray1};
16use paste::paste;
17#[cfg(feature = "python")]
18use pyo3::exceptions::PyValueError;
19#[cfg(feature = "python")]
20use pyo3::prelude::*;
21#[cfg(feature = "python")]
22use pyo3::types::PyDict;
23#[cfg(not(target_arch = "wasm32"))]
24use rayon::prelude::*;
25#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
26use serde::{Deserialize, Serialize};
27use std::convert::AsRef;
28use std::error::Error;
29use std::mem::MaybeUninit;
30use thiserror::Error;
31#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
32use wasm_bindgen::prelude::*;
33
34impl<'a> AsRef<[f64]> for RsiInput<'a> {
35    #[inline(always)]
36    fn as_ref(&self) -> &[f64] {
37        match &self.data {
38            RsiData::Slice(slice) => slice,
39            RsiData::Candles { candles, source } => source_type(candles, source),
40        }
41    }
42}
43
/// Where the input series for an RSI computation comes from.
#[derive(Debug, Clone)]
pub enum RsiData<'a> {
    /// A borrowed candle set plus the name of the field to read; the series is
    /// obtained with `source_type(candles, source)`.
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A borrowed slice of values that is used as-is.
    Slice(&'a [f64]),
}
52
/// Result of a single-series RSI computation.
#[derive(Debug, Clone)]
pub struct RsiOutput {
    /// Output series; `rsi_with_kernel` allocates it with the input length via
    /// `alloc_with_nan_prefix(len, first + period)`.
    pub values: Vec<f64>,
}
57
/// Parameters of the RSI computation.
///
/// Derives `Serialize`/`Deserialize` only when built for `wasm32` with the
/// `wasm` feature.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct RsiParams {
    /// Smoothing period; `None` is read as 14 by `RsiInput::get_period`.
    pub period: Option<usize>,
}
66
impl Default for RsiParams {
    /// Returns parameters with `period` set to `Some(14)`.
    fn default() -> Self {
        Self { period: Some(14) }
    }
}
72
/// Bundles the input series with the parameters for one RSI computation.
#[derive(Debug, Clone)]
pub struct RsiInput<'a> {
    /// The series source (candles + field name, or a raw slice).
    pub data: RsiData<'a>,
    /// The RSI parameters to apply to `data`.
    pub params: RsiParams,
}
78
79impl<'a> RsiInput<'a> {
80    #[inline]
81    pub fn from_candles(c: &'a Candles, s: &'a str, p: RsiParams) -> Self {
82        Self {
83            data: RsiData::Candles {
84                candles: c,
85                source: s,
86            },
87            params: p,
88        }
89    }
90    #[inline]
91    pub fn from_slice(sl: &'a [f64], p: RsiParams) -> Self {
92        Self {
93            data: RsiData::Slice(sl),
94            params: p,
95        }
96    }
97    #[inline]
98    pub fn with_default_candles(c: &'a Candles) -> Self {
99        Self::from_candles(c, "close", RsiParams::default())
100    }
101    #[inline]
102    pub fn get_period(&self) -> usize {
103        self.params.period.unwrap_or(14)
104    }
105}
106
/// Fluent builder for single-series RSI runs and for `RsiStream` creation.
#[derive(Copy, Clone, Debug)]
pub struct RsiBuilder {
    /// Period to use; `None` leaves the choice to the parameter defaults.
    period: Option<usize>,
    /// Kernel passed to `rsi_with_kernel` by `apply` / `apply_slice`.
    kernel: Kernel,
}
112
impl Default for RsiBuilder {
    /// Returns a builder with no explicit period and `Kernel::Auto`.
    fn default() -> Self {
        Self {
            period: None,
            kernel: Kernel::Auto,
        }
    }
}
121
122impl RsiBuilder {
123    #[inline(always)]
124    pub fn new() -> Self {
125        Self::default()
126    }
127    #[inline(always)]
128    pub fn period(mut self, n: usize) -> Self {
129        self.period = Some(n);
130        self
131    }
132    #[inline(always)]
133    pub fn kernel(mut self, k: Kernel) -> Self {
134        self.kernel = k;
135        self
136    }
137    #[inline(always)]
138    pub fn apply(self, c: &Candles) -> Result<RsiOutput, RsiError> {
139        let p = RsiParams {
140            period: self.period,
141        };
142        let i = RsiInput::from_candles(c, "close", p);
143        rsi_with_kernel(&i, self.kernel)
144    }
145    #[inline(always)]
146    pub fn apply_slice(self, d: &[f64]) -> Result<RsiOutput, RsiError> {
147        let p = RsiParams {
148            period: self.period,
149        };
150        let i = RsiInput::from_slice(d, p);
151        rsi_with_kernel(&i, self.kernel)
152    }
153    #[inline(always)]
154    pub fn into_stream(self) -> Result<RsiStream, RsiError> {
155        let p = RsiParams {
156            period: self.period,
157        };
158        RsiStream::try_new(p)
159    }
160}
161
/// Errors reported by the RSI entry points in this module.
#[derive(Debug, Error)]
pub enum RsiError {
    /// The input series has length zero.
    #[error("rsi: Input data slice is empty.")]
    EmptyInputData,
    /// No non-NaN element was found in the input series.
    #[error("rsi: All values are NaN.")]
    AllValuesNaN,
    /// The period is zero or larger than the input length.
    #[error("rsi: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },
    /// Fewer elements remain after the first non-NaN value than the period needs.
    #[error("rsi: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// A caller-supplied output buffer does not have the required length.
    #[error("rsi: Output length mismatch: expected {expected}, got {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// A period sweep could not be expanded, or `rows * cols` overflowed.
    #[error("rsi: Invalid range: start = {start}, end = {end}, step = {step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },
    /// An explicit non-batch kernel was passed to a batch entry point.
    #[error("rsi: Invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(Kernel),
}
183
/// Computes RSI for `input` with `Kernel::Auto`.
///
/// Thin wrapper: forwards to `rsi_with_kernel` and returns its result unchanged.
#[inline]
pub fn rsi(input: &RsiInput) -> Result<RsiOutput, RsiError> {
    rsi_with_kernel(input, Kernel::Auto)
}
188
189pub fn rsi_with_kernel(input: &RsiInput, kernel: Kernel) -> Result<RsiOutput, RsiError> {
190    let data: &[f64] = match &input.data {
191        RsiData::Candles { candles, source } => source_type(candles, source),
192        RsiData::Slice(sl) => sl,
193    };
194
195    let len = data.len();
196    if len == 0 {
197        return Err(RsiError::EmptyInputData);
198    }
199
200    let first = data
201        .iter()
202        .position(|x| !x.is_nan())
203        .ok_or(RsiError::AllValuesNaN)?;
204    let period = input.get_period();
205
206    if period == 0 || period > len {
207        return Err(RsiError::InvalidPeriod {
208            period,
209            data_len: len,
210        });
211    }
212    if (len - first) < period {
213        return Err(RsiError::NotEnoughValidData {
214            needed: period,
215            valid: len - first,
216        });
217    }
218
219    let chosen = match kernel {
220        Kernel::Auto => Kernel::Scalar,
221        other => other,
222    };
223
224    let mut out = alloc_with_nan_prefix(len, first + period);
225    rsi_compute_into(data, period, first, chosen, &mut out);
226    Ok(RsiOutput { values: out })
227}
228
229#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
230pub fn rsi_into(input: &RsiInput, out: &mut [f64]) -> Result<(), RsiError> {
231    rsi_into_slice(out, input, Kernel::Auto)?;
232
233    let data: &[f64] = match &input.data {
234        RsiData::Candles { candles, source } => source_type(candles, source),
235        RsiData::Slice(sl) => sl,
236    };
237    let first = data
238        .iter()
239        .position(|x| !x.is_nan())
240        .ok_or(RsiError::AllValuesNaN)?;
241    let warmup_end = (first + input.get_period()).min(out.len());
242    for v in &mut out[..warmup_end] {
243        *v = f64::from_bits(0x7ff8_0000_0000_0000);
244    }
245
246    Ok(())
247}
248
249#[inline(always)]
250fn rsi_compute_into(data: &[f64], period: usize, first: usize, kernel: Kernel, out: &mut [f64]) {
251    unsafe {
252        match kernel {
253            Kernel::Scalar | Kernel::ScalarBatch => {
254                rsi_compute_into_scalar(data, period, first, out)
255            }
256            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
257            Kernel::Avx2 | Kernel::Avx2Batch => rsi_compute_into_scalar(data, period, first, out),
258            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
259            Kernel::Avx2 | Kernel::Avx2Batch => rsi_compute_into_scalar(data, period, first, out),
260            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
261            Kernel::Avx512 | Kernel::Avx512Batch => {
262                rsi_compute_into_scalar(data, period, first, out)
263            }
264            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
265            Kernel::Avx512 | Kernel::Avx512Batch => {
266                rsi_compute_into_scalar(data, period, first, out)
267            }
268            _ => unreachable!(),
269        }
270    }
271}
272
273#[inline]
274pub fn rsi_into_slice(dst: &mut [f64], input: &RsiInput, kern: Kernel) -> Result<(), RsiError> {
275    let data: &[f64] = match &input.data {
276        RsiData::Candles { candles, source } => source_type(candles, source),
277        RsiData::Slice(sl) => sl,
278    };
279
280    let len = data.len();
281    if len == 0 {
282        return Err(RsiError::EmptyInputData);
283    }
284
285    let first = data
286        .iter()
287        .position(|x| !x.is_nan())
288        .ok_or(RsiError::AllValuesNaN)?;
289    let period = input.get_period();
290
291    if period == 0 || period > len {
292        return Err(RsiError::InvalidPeriod {
293            period,
294            data_len: len,
295        });
296    }
297    if (len - first) < period {
298        return Err(RsiError::NotEnoughValidData {
299            needed: period,
300            valid: len - first,
301        });
302    }
303
304    if dst.len() != data.len() {
305        return Err(RsiError::OutputLengthMismatch {
306            expected: data.len(),
307            got: dst.len(),
308        });
309    }
310
311    let chosen = match kern {
312        Kernel::Auto => Kernel::Scalar,
313        other => other,
314    };
315
316    rsi_compute_into(data, period, first, chosen, dst);
317
318    let warmup_end = first + period;
319    for v in &mut dst[..warmup_end] {
320        *v = f64::NAN;
321    }
322
323    Ok(())
324}
325
/// Scalar RSI kernel: writes `out[first + period..]` and leaves earlier slots untouched.
///
/// The seed window sums the gains and losses of the `period` deltas that follow
/// `first`; afterwards both averages are updated recursively with weight
/// `1 / period`. A non-finite delta inside the seed window turns the seed and
/// every later output into NaN. After the seed window a NaN delta contributes
/// neither gain nor loss, because both comparisons against 0.0 are false.
///
/// # Safety
/// The visible body only uses bounds-checked indexing, so it panics rather than
/// reading out of range; `out` needs at least `data.len()` elements to avoid
/// that panic.
#[inline(always)]
unsafe fn rsi_compute_into_scalar(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    // RSI from the two averages; a window with no movement at all maps to 50.
    #[inline(always)]
    fn strength(up: f64, down: f64) -> f64 {
        let total = up + down;
        if total == 0.0 {
            50.0
        } else {
            100.0 * up / total
        }
    }

    let n = data.len();
    let alpha = 1.0 / (period as f64);
    let decay = 1.0 - alpha;

    // Seed window: accumulate raw gains/losses, stopping at the first
    // non-finite delta.
    let seed_last = core::cmp::min(first + period, n.saturating_sub(1));
    let mut gain_sum = 0.0f64;
    let mut loss_sum = 0.0f64;
    let mut seed_ok = true;
    for k in (first + 1)..=seed_last {
        let change = data[k] - data[k - 1];
        if !change.is_finite() {
            seed_ok = false;
            break;
        }
        if change > 0.0 {
            gain_sum += change;
        } else if change < 0.0 {
            loss_sum -= change;
        }
    }

    let seed_idx = first + period;
    let (mut up, mut down) = if seed_ok {
        (gain_sum * alpha, loss_sum * alpha)
    } else {
        (f64::NAN, f64::NAN)
    };
    if seed_idx < n {
        out[seed_idx] = if seed_ok {
            strength(up, down)
        } else {
            f64::NAN
        };
    }

    // Recursive smoothing for every element after the seed index.
    for j in (seed_idx + 1)..n {
        let change = data[j] - data[j - 1];
        let gain = if change > 0.0 { change } else { 0.0 };
        let loss = if change < 0.0 { -change } else { 0.0 };
        up = up.mul_add(decay, alpha * gain);
        down = down.mul_add(decay, alpha * loss);
        out[j] = strength(up, down);
    }
}
415
/// Period sweep description for batch RSI.
#[derive(Clone, Debug)]
pub struct RsiBatchRange {
    /// `(start, end, step)` triple expanded by `expand_grid`; a step of 0 or
    /// `start == end` yields the single period `start`.
    pub period: (usize, usize, usize),
}
impl Default for RsiBatchRange {
    /// Returns the sweep `(14, 263, 1)`, i.e. periods 14 through 263 in steps of 1.
    fn default() -> Self {
        Self {
            period: (14, 263, 1),
        }
    }
}
/// Fluent builder for batch RSI runs over a period sweep.
#[derive(Clone, Debug, Default)]
pub struct RsiBatchBuilder {
    /// Period sweep handed to `rsi_batch_with_kernel`.
    range: RsiBatchRange,
    /// Kernel handed to `rsi_batch_with_kernel`.
    kernel: Kernel,
}
432impl RsiBatchBuilder {
433    pub fn new() -> Self {
434        Self::default()
435    }
436    pub fn kernel(mut self, k: Kernel) -> Self {
437        self.kernel = k;
438        self
439    }
440    #[inline]
441    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
442        self.range.period = (start, end, step);
443        self
444    }
445    #[inline]
446    pub fn period_static(mut self, p: usize) -> Self {
447        self.range.period = (p, p, 0);
448        self
449    }
450    pub fn apply_slice(self, data: &[f64]) -> Result<RsiBatchOutput, RsiError> {
451        rsi_batch_with_kernel(data, &self.range, self.kernel)
452    }
453    pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<RsiBatchOutput, RsiError> {
454        RsiBatchBuilder::new().kernel(k).apply_slice(data)
455    }
456    pub fn apply_candles(self, c: &Candles, src: &str) -> Result<RsiBatchOutput, RsiError> {
457        let slice = source_type(c, src);
458        self.apply_slice(slice)
459    }
460    pub fn with_default_candles(c: &Candles) -> Result<RsiBatchOutput, RsiError> {
461        RsiBatchBuilder::new()
462            .kernel(Kernel::Auto)
463            .apply_candles(c, "close")
464    }
465}
466
467pub fn rsi_batch_with_kernel(
468    data: &[f64],
469    sweep: &RsiBatchRange,
470    k: Kernel,
471) -> Result<RsiBatchOutput, RsiError> {
472    let kernel = match k {
473        Kernel::Auto => detect_best_batch_kernel(),
474        other => {
475            if other.is_batch() {
476                other
477            } else {
478                return Err(RsiError::InvalidKernelForBatch(other));
479            }
480        }
481    };
482    let simd = match kernel {
483        Kernel::Avx512Batch => Kernel::Avx512,
484        Kernel::Avx2Batch => Kernel::Avx2,
485        Kernel::ScalarBatch => Kernel::Scalar,
486        _ => unreachable!(),
487    };
488    rsi_batch_par_slice(data, sweep, simd)
489}
490
/// Result of a batch RSI sweep: one output row per parameter combination.
#[derive(Clone, Debug)]
pub struct RsiBatchOutput {
    /// Row-major matrix of `rows * cols` values; row `r` occupies
    /// `values[r * cols..(r + 1) * cols]`.
    pub values: Vec<f64>,
    /// Parameter combination of each row, in row order.
    pub combos: Vec<RsiParams>,
    /// Number of rows (`combos.len()` as set by `rsi_batch_inner`).
    pub rows: usize,
    /// Number of columns (the input length as set by `rsi_batch_inner`).
    pub cols: usize,
}
498impl RsiBatchOutput {
499    pub fn row_for_params(&self, p: &RsiParams) -> Option<usize> {
500        self.combos
501            .iter()
502            .position(|c| c.period.unwrap_or(14) == p.period.unwrap_or(14))
503    }
504    pub fn values_for(&self, p: &RsiParams) -> Option<&[f64]> {
505        self.row_for_params(p).map(|row| {
506            let start = row * self.cols;
507            &self.values[start..start + self.cols]
508        })
509    }
510}
511
512#[inline(always)]
513fn expand_grid(r: &RsiBatchRange) -> Result<Vec<RsiParams>, RsiError> {
514    fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, RsiError> {
515        if step == 0 || start == end {
516            return Ok(vec![start]);
517        }
518        let (lo, hi) = if start <= end {
519            (start, end)
520        } else {
521            (end, start)
522        };
523        let mut out = Vec::new();
524        let mut v = lo;
525        loop {
526            out.push(v);
527            if v == hi {
528                break;
529            }
530            v = match v.checked_add(step) {
531                Some(next) => next,
532                None => return Err(RsiError::InvalidRange { start, end, step }),
533            };
534            if v > hi {
535                break;
536            }
537        }
538        if out.is_empty() {
539            return Err(RsiError::InvalidRange { start, end, step });
540        }
541        Ok(out)
542    }
543    let periods = axis_usize(r.period)?;
544    let mut out = Vec::with_capacity(periods.len());
545    for &p in &periods {
546        out.push(RsiParams { period: Some(p) });
547    }
548    Ok(out)
549}
550
/// Runs the batch RSI sweep over `data`, processing rows one after another.
///
/// Thin wrapper: calls `rsi_batch_inner` with `parallel = false`. `kern` is
/// passed through unchanged, so it must be a per-row kernel that
/// `rsi_batch_inner` accepts.
#[inline(always)]
pub fn rsi_batch_slice(
    data: &[f64],
    sweep: &RsiBatchRange,
    kern: Kernel,
) -> Result<RsiBatchOutput, RsiError> {
    rsi_batch_inner(data, sweep, kern, false)
}
/// Runs the batch RSI sweep over `data` with the parallel flag set.
///
/// Thin wrapper: calls `rsi_batch_inner` with `parallel = true`. `kern` is
/// passed through unchanged, so it must be a per-row kernel that
/// `rsi_batch_inner` accepts.
#[inline(always)]
pub fn rsi_batch_par_slice(
    data: &[f64],
    sweep: &RsiBatchRange,
    kern: Kernel,
) -> Result<RsiBatchOutput, RsiError> {
    rsi_batch_inner(data, sweep, kern, true)
}
567
568#[inline(always)]
569fn rsi_batch_inner(
570    data: &[f64],
571    sweep: &RsiBatchRange,
572    kern: Kernel,
573    parallel: bool,
574) -> Result<RsiBatchOutput, RsiError> {
575    if data.is_empty() {
576        return Err(RsiError::EmptyInputData);
577    }
578    let combos = expand_grid(sweep)?;
579    let first = data
580        .iter()
581        .position(|x| !x.is_nan())
582        .ok_or(RsiError::AllValuesNaN)?;
583    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
584    if data.len() - first < max_p {
585        return Err(RsiError::NotEnoughValidData {
586            needed: max_p,
587            valid: data.len() - first,
588        });
589    }
590    let rows = combos.len();
591    let cols = data.len();
592    let _expected = rows.checked_mul(cols).ok_or(RsiError::InvalidRange {
593        start: rows,
594        end: cols,
595        step: 1,
596    })?;
597
598    let mut buf_mu = make_uninit_matrix(rows, cols);
599
600    let warmup_periods: Vec<usize> = combos.iter().map(|c| first + c.period.unwrap()).collect();
601    init_matrix_prefixes(&mut buf_mu, cols, &warmup_periods);
602
603    let mut buf_guard = core::mem::ManuallyDrop::new(buf_mu);
604    let values: &mut [f64] = unsafe {
605        core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
606    };
607
608    let mut gains = vec![0.0f64; cols];
609    let mut losses = vec![0.0f64; cols];
610    for i in (first + 1)..cols {
611        let d = data[i] - data[i - 1];
612        if d.is_finite() {
613            if d > 0.0 {
614                gains[i] = d;
615            } else if d < 0.0 {
616                losses[i] = -d;
617            }
618        } else {
619            gains[i] = f64::NAN;
620            losses[i] = f64::NAN;
621        }
622    }
623    let mut pg = vec![0.0f64; cols];
624    let mut pl = vec![0.0f64; cols];
625    for i in 1..cols {
626        pg[i] = pg[i - 1] + gains[i];
627        pl[i] = pl[i - 1] + losses[i];
628    }
629
630    let do_row = |row: usize, out_row: &mut [f64]| unsafe {
631        let period = combos[row].period.unwrap();
632        match kern {
633            Kernel::Scalar | Kernel::Avx2 | Kernel::Avx512 => {
634                let inv_p = 1.0 / (period as f64);
635                let beta = 1.0 - inv_p;
636                let idx0 = first + period;
637                if idx0 < cols {
638                    let sum_g = pg[idx0] - pg[first];
639                    let sum_l = pl[idx0] - pl[first];
640                    let mut avg_g = sum_g * inv_p;
641                    let mut avg_l = sum_l * inv_p;
642                    if sum_g.is_nan() || sum_l.is_nan() {
643                        avg_g = f64::NAN;
644                        avg_l = f64::NAN;
645                        out_row[idx0] = f64::NAN;
646                    } else {
647                        let denom = avg_g + avg_l;
648                        out_row[idx0] = if denom == 0.0 {
649                            50.0
650                        } else {
651                            100.0 * avg_g / denom
652                        };
653                    }
654                    let mut j = idx0 + 1;
655                    while j + 1 < cols {
656                        let g1 = gains[j];
657                        let l1 = losses[j];
658                        avg_g = avg_g.mul_add(beta, inv_p * g1);
659                        avg_l = avg_l.mul_add(beta, inv_p * l1);
660                        let denom1 = avg_g + avg_l;
661                        out_row[j] = if denom1 == 0.0 {
662                            50.0
663                        } else {
664                            100.0 * avg_g / denom1
665                        };
666
667                        let g2 = gains[j + 1];
668                        let l2 = losses[j + 1];
669                        avg_g = avg_g.mul_add(beta, inv_p * g2);
670                        avg_l = avg_l.mul_add(beta, inv_p * l2);
671                        let denom2 = avg_g + avg_l;
672                        out_row[j + 1] = if denom2 == 0.0 {
673                            50.0
674                        } else {
675                            100.0 * avg_g / denom2
676                        };
677                        j += 2;
678                    }
679                    if j < cols {
680                        let g = gains[j];
681                        let l = losses[j];
682                        avg_g = avg_g.mul_add(beta, inv_p * g);
683                        avg_l = avg_l.mul_add(beta, inv_p * l);
684                        let denom = avg_g + avg_l;
685                        out_row[j] = if denom == 0.0 {
686                            50.0
687                        } else {
688                            100.0 * avg_g / denom
689                        };
690                    }
691                }
692            }
693            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
694            Kernel::Avx2 => rsi_row_avx2(data, first, period, out_row),
695            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
696            Kernel::Avx512 => rsi_row_avx512(data, first, period, out_row),
697            _ => unreachable!(),
698        }
699    };
700
701    if parallel {
702        #[cfg(not(target_arch = "wasm32"))]
703        {
704            values
705                .par_chunks_mut(cols)
706                .enumerate()
707                .for_each(|(row, slice)| do_row(row, slice));
708        }
709
710        #[cfg(target_arch = "wasm32")]
711        {
712            for (row, slice) in values.chunks_mut(cols).enumerate() {
713                do_row(row, slice);
714            }
715        }
716    } else {
717        for (row, slice) in values.chunks_mut(cols).enumerate() {
718            do_row(row, slice);
719        }
720    }
721
722    let values = unsafe {
723        Vec::from_raw_parts(
724            buf_guard.as_mut_ptr() as *mut f64,
725            buf_guard.len(),
726            buf_guard.capacity(),
727        )
728    };
729
730    Ok(RsiBatchOutput {
731        values,
732        combos,
733        rows,
734        cols,
735    })
736}
737
738#[inline(always)]
739pub fn rsi_batch_inner_into(
740    data: &[f64],
741    sweep: &RsiBatchRange,
742    kern: Kernel,
743    parallel: bool,
744    out: &mut [f64],
745) -> Result<Vec<RsiParams>, RsiError> {
746    if data.is_empty() {
747        return Err(RsiError::EmptyInputData);
748    }
749    let combos = expand_grid(sweep)?;
750    let first = data
751        .iter()
752        .position(|x| !x.is_nan())
753        .ok_or(RsiError::AllValuesNaN)?;
754    let rows = combos.len();
755    let cols = data.len();
756
757    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
758    if cols - first < max_p {
759        return Err(RsiError::NotEnoughValidData {
760            needed: max_p,
761            valid: cols - first,
762        });
763    }
764    let expected = rows.checked_mul(cols).ok_or(RsiError::InvalidRange {
765        start: rows,
766        end: cols,
767        step: 1,
768    })?;
769    if out.len() != expected {
770        return Err(RsiError::OutputLengthMismatch {
771            expected,
772            got: out.len(),
773        });
774    }
775
776    let out_mu: &mut [MaybeUninit<f64>] = unsafe {
777        core::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
778    };
779    let warm: Vec<usize> = combos.iter().map(|c| first + c.period.unwrap()).collect();
780    init_matrix_prefixes(out_mu, cols, &warm);
781
782    let values: &mut [f64] =
783        unsafe { core::slice::from_raw_parts_mut(out_mu.as_mut_ptr() as *mut f64, out_mu.len()) };
784
785    let mut gains = vec![0.0f64; cols];
786    let mut losses = vec![0.0f64; cols];
787    for i in (first + 1)..cols {
788        let d = data[i] - data[i - 1];
789        if d.is_finite() {
790            if d > 0.0 {
791                gains[i] = d;
792            } else if d < 0.0 {
793                losses[i] = -d;
794            }
795        } else {
796            gains[i] = f64::NAN;
797            losses[i] = f64::NAN;
798        }
799    }
800    let mut pg = vec![0.0f64; cols];
801    let mut pl = vec![0.0f64; cols];
802    for i in 1..cols {
803        pg[i] = pg[i - 1] + gains[i];
804        pl[i] = pl[i - 1] + losses[i];
805    }
806
807    let do_row = |row: usize, out_row: &mut [f64]| unsafe {
808        let period = combos[row].period.unwrap();
809        match kern {
810            Kernel::Scalar | Kernel::Avx2 | Kernel::Avx512 => {
811                let inv_p = 1.0 / (period as f64);
812                let beta = 1.0 - inv_p;
813                let idx0 = first + period;
814                if idx0 < cols {
815                    let sum_g = pg[idx0] - pg[first];
816                    let sum_l = pl[idx0] - pl[first];
817                    let mut avg_g = sum_g * inv_p;
818                    let mut avg_l = sum_l * inv_p;
819                    if sum_g.is_nan() || sum_l.is_nan() {
820                        avg_g = f64::NAN;
821                        avg_l = f64::NAN;
822                        out_row[idx0] = f64::NAN;
823                    } else {
824                        let denom = avg_g + avg_l;
825                        out_row[idx0] = if denom == 0.0 {
826                            50.0
827                        } else {
828                            100.0 * avg_g / denom
829                        };
830                    }
831                    let mut j = idx0 + 1;
832                    while j + 1 < cols {
833                        let g1 = gains[j];
834                        let l1 = losses[j];
835                        avg_g = avg_g.mul_add(beta, inv_p * g1);
836                        avg_l = avg_l.mul_add(beta, inv_p * l1);
837                        let denom1 = avg_g + avg_l;
838                        out_row[j] = if denom1 == 0.0 {
839                            50.0
840                        } else {
841                            100.0 * avg_g / denom1
842                        };
843
844                        let g2 = gains[j + 1];
845                        let l2 = losses[j + 1];
846                        avg_g = avg_g.mul_add(beta, inv_p * g2);
847                        avg_l = avg_l.mul_add(beta, inv_p * l2);
848                        let denom2 = avg_g + avg_l;
849                        out_row[j + 1] = if denom2 == 0.0 {
850                            50.0
851                        } else {
852                            100.0 * avg_g / denom2
853                        };
854                        j += 2;
855                    }
856                    if j < cols {
857                        let g = gains[j];
858                        let l = losses[j];
859                        avg_g = avg_g.mul_add(beta, inv_p * g);
860                        avg_l = avg_l.mul_add(beta, inv_p * l);
861                        let denom = avg_g + avg_l;
862                        out_row[j] = if denom == 0.0 {
863                            50.0
864                        } else {
865                            100.0 * avg_g / denom
866                        };
867                    }
868                }
869            }
870            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
871            Kernel::Avx2 => rsi_row_avx2(data, first, period, out_row),
872            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
873            Kernel::Avx512 => rsi_row_avx512(data, first, period, out_row),
874            _ => unreachable!(),
875        }
876    };
877
878    if parallel {
879        #[cfg(not(target_arch = "wasm32"))]
880        {
881            values
882                .par_chunks_mut(cols)
883                .enumerate()
884                .for_each(|(row, slice)| do_row(row, slice));
885        }
886
887        #[cfg(target_arch = "wasm32")]
888        {
889            for (row, slice) in values.chunks_mut(cols).enumerate() {
890                do_row(row, slice);
891            }
892        }
893    } else {
894        for (row, slice) in values.chunks_mut(cols).enumerate() {
895            do_row(row, slice);
896        }
897    }
898
899    Ok(combos)
900}
901
/// Scalar row kernel: computes RSI for one period straight from `data`.
///
/// Writes `out[first + period..]` and leaves earlier slots untouched. The seed
/// window sums the gains and losses of the `period` deltas that follow `first`;
/// afterwards both averages are smoothed recursively with weight `1 / period`.
/// A non-finite delta inside the seed window makes the seed and all later
/// outputs NaN. After the seed window a NaN delta adds neither gain nor loss,
/// because both comparisons against 0.0 are false.
///
/// # Safety
/// The visible body only uses bounds-checked indexing, so it panics rather than
/// reading out of range; `out` needs at least `data.len()` elements to avoid
/// that panic.
#[inline(always)]
unsafe fn rsi_row_scalar(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    // Converts the two averages to an RSI value; no movement at all maps to 50.
    #[inline(always)]
    fn to_rsi(avg_up: f64, avg_down: f64) -> f64 {
        let total = avg_up + avg_down;
        if total == 0.0 {
            50.0
        } else {
            100.0 * avg_up / total
        }
    }

    let cols = data.len();
    let weight = 1.0 / (period as f64);
    let keep = 1.0 - weight;

    // Seed window: raw sums, aborted at the first non-finite delta.
    let last_seed = core::cmp::min(first + period, cols.saturating_sub(1));
    let mut ups = 0.0f64;
    let mut downs = 0.0f64;
    let mut poisoned = false;
    for idx in (first + 1)..=last_seed {
        let diff = data[idx] - data[idx - 1];
        if !diff.is_finite() {
            poisoned = true;
            break;
        }
        if diff > 0.0 {
            ups += diff;
        } else if diff < 0.0 {
            downs -= diff;
        }
    }

    let start = first + period;
    let mut avg_up;
    let mut avg_down;
    if poisoned {
        avg_up = f64::NAN;
        avg_down = f64::NAN;
        if start < cols {
            out[start] = f64::NAN;
        }
    } else {
        avg_up = ups * weight;
        avg_down = downs * weight;
        if start < cols {
            out[start] = to_rsi(avg_up, avg_down);
        }
    }

    // Recursive smoothing for every element after the seed index.
    for idx in (start + 1)..cols {
        let diff = data[idx] - data[idx - 1];
        let up = if diff > 0.0 { diff } else { 0.0 };
        let down = if diff < 0.0 { -diff } else { 0.0 };
        avg_up = avg_up.mul_add(keep, weight * up);
        avg_down = avg_down.mul_add(keep, weight * down);
        out[idx] = to_rsi(avg_up, avg_down);
    }
}
/// AVX2 entry point for computing one RSI row.
///
/// No AVX2 intrinsics are used here: the call is forwarded unchanged to
/// `rsi_row_scalar`, so results are identical to the scalar kernel.
///
/// # Safety
/// The body only forwards to `rsi_row_scalar`; callers must satisfy whatever
/// requirements that function places on `data`, `first`, `period` and `out`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn rsi_row_avx2(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    rsi_row_scalar(data, first, period, out);
}
/// AVX-512 entry point for computing one RSI row.
///
/// Dispatches on the period: periods up to and including 32 go to
/// `rsi_row_avx512_short`, everything else to `rsi_row_avx512_long`.
///
/// # Safety
/// Forwards to other `unsafe` row kernels; callers must uphold their
/// requirements on `data`, `first`, `period` and `out`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn rsi_row_avx512(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    match period {
        0..=32 => rsi_row_avx512_short(data, first, period, out),
        _ => rsi_row_avx512_long(data, first, period, out),
    }
}
/// AVX-512 row kernel selected for short periods (see `rsi_row_avx512`).
///
/// Forwards unchanged to `rsi_row_scalar`; no AVX-512 intrinsics are used here.
///
/// # Safety
/// The body only forwards to `rsi_row_scalar`; callers must satisfy that
/// function's requirements on the arguments.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn rsi_row_avx512_short(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    rsi_row_scalar(data, first, period, out);
}
/// AVX-512 row kernel selected for long periods (see `rsi_row_avx512`).
///
/// Forwards unchanged to `rsi_row_scalar`; no AVX-512 intrinsics are used here.
///
/// # Safety
/// The body only forwards to `rsi_row_scalar`; callers must satisfy that
/// function's requirements on the arguments.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn rsi_row_avx512_long(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    rsi_row_scalar(data, first, period, out);
}
1014
/// Incremental RSI state that consumes one value at a time via `update`.
///
/// The stream first accumulates `period` price deltas (the seed window) and
/// then switches to recursive smoothing of the average gain and loss.
#[derive(Debug, Clone)]
pub struct RsiStream {
    // Number of deltas in the seed window.
    period: usize,
    // Cached `1.0 / period`.
    inv_p: f64,
    // Cached `1.0 - inv_p`; weight applied to the previous average on each step.
    beta: f64,

    // True once `prev` holds a value supplied through `update`.
    has_prev: bool,
    // Most recent input value, used to form the next delta.
    prev: f64,

    // Number of deltas accumulated so far while seeding.
    seed_count: usize,
    // Running sums of positive deltas / magnitudes of negative deltas while seeding.
    sum_gain: f64,
    sum_loss: f64,
    // Set when a non-finite delta is observed while seeding.
    poisoned: bool,

    // Smoothed average gain / loss, valid once `seeded` is true.
    avg_gain: f64,
    avg_loss: f64,
    // True once `seed_count` has reached `period`.
    seeded: bool,
}
1033impl RsiStream {
1034    #[inline(always)]
1035    pub fn try_new(params: RsiParams) -> Result<Self, RsiError> {
1036        let period = params.period.unwrap_or(14);
1037        if period == 0 {
1038            return Err(RsiError::InvalidPeriod {
1039                period,
1040                data_len: 0,
1041            });
1042        }
1043        let inv_p = 1.0 / (period as f64);
1044        Ok(Self {
1045            period,
1046            inv_p,
1047            beta: 1.0 - inv_p,
1048
1049            has_prev: false,
1050            prev: f64::NAN,
1051
1052            seed_count: 0,
1053            sum_gain: 0.0,
1054            sum_loss: 0.0,
1055            poisoned: false,
1056
1057            avg_gain: 0.0,
1058            avg_loss: 0.0,
1059            seeded: false,
1060        })
1061    }
1062
1063    #[inline(always)]
1064    pub fn update(&mut self, value: f64) -> Option<f64> {
1065        if !self.has_prev {
1066            self.prev = value;
1067            self.has_prev = true;
1068            return None;
1069        }
1070
1071        let delta = value - self.prev;
1072        self.prev = value;
1073
1074        if !self.seeded {
1075            if !delta.is_finite() {
1076                self.poisoned = true;
1077            }
1078
1079            let gain = delta.max(0.0);
1080            let loss = (-delta).max(0.0);
1081
1082            self.sum_gain += gain;
1083            self.sum_loss += loss;
1084            self.seed_count += 1;
1085
1086            if self.seed_count == self.period {
1087                self.seeded = true;
1088                if self.poisoned {
1089                    self.avg_gain = f64::NAN;
1090                    self.avg_loss = f64::NAN;
1091                    return Some(f64::NAN);
1092                } else {
1093                    self.avg_gain = self.sum_gain * self.inv_p;
1094                    self.avg_loss = self.sum_loss * self.inv_p;
1095                    let denom = self.avg_gain + self.avg_loss;
1096                    let rsi = if denom == 0.0 {
1097                        50.0
1098                    } else {
1099                        100.0 * self.avg_gain / denom
1100                    };
1101                    return Some(rsi);
1102                }
1103            } else {
1104                return None;
1105            }
1106        }
1107
1108        let gain = delta.max(0.0);
1109        let loss = (-delta).max(0.0);
1110
1111        self.avg_gain = self.avg_gain.mul_add(self.beta, self.inv_p * gain);
1112        self.avg_loss = self.avg_loss.mul_add(self.beta, self.inv_p * loss);
1113        let denom = self.avg_gain + self.avg_loss;
1114        Some(if denom == 0.0 {
1115            50.0
1116        } else {
1117            100.0 * self.avg_gain / denom
1118        })
1119    }
1120}
1121
1122#[cfg(test)]
1123mod tests {
1124    use super::*;
1125    use crate::skip_if_unsupported;
1126    use crate::utilities::data_loader::read_candles_from_csv;
1127
1128    fn check_rsi_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1129        skip_if_unsupported!(kernel, test_name);
1130        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1131        let candles = read_candles_from_csv(file_path)?;
1132        let partial_params = RsiParams { period: None };
1133        let input = RsiInput::from_candles(&candles, "close", partial_params);
1134        let result = rsi_with_kernel(&input, kernel)?;
1135        assert_eq!(result.values.len(), candles.close.len());
1136        Ok(())
1137    }
1138    fn check_rsi_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1139        skip_if_unsupported!(kernel, test_name);
1140        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1141        let candles = read_candles_from_csv(file_path)?;
1142        let input = RsiInput::from_candles(&candles, "close", RsiParams { period: Some(14) });
1143        let result = rsi_with_kernel(&input, kernel)?;
1144        let expected_last_five = [43.42, 42.68, 41.62, 42.86, 39.01];
1145        let start = result.values.len().saturating_sub(5);
1146        for (i, &val) in result.values[start..].iter().enumerate() {
1147            let diff = (val - expected_last_five[i]).abs();
1148            assert!(
1149                diff < 1e-2,
1150                "[{}] RSI {:?} mismatch at idx {}: got {}, expected {}",
1151                test_name,
1152                kernel,
1153                i,
1154                val,
1155                expected_last_five[i]
1156            );
1157        }
1158        Ok(())
1159    }
1160    fn check_rsi_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1161        skip_if_unsupported!(kernel, test_name);
1162        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1163        let candles = read_candles_from_csv(file_path)?;
1164        let input = RsiInput::with_default_candles(&candles);
1165        match input.data {
1166            RsiData::Candles { source, .. } => assert_eq!(source, "close"),
1167            _ => panic!("Expected RsiData::Candles"),
1168        }
1169        let output = rsi_with_kernel(&input, kernel)?;
1170        assert_eq!(output.values.len(), candles.close.len());
1171        Ok(())
1172    }
1173    fn check_rsi_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1174        skip_if_unsupported!(kernel, test_name);
1175        let input_data = [10.0, 20.0, 30.0];
1176        let params = RsiParams { period: Some(0) };
1177        let input = RsiInput::from_slice(&input_data, params);
1178        let res = rsi_with_kernel(&input, kernel);
1179        assert!(
1180            res.is_err(),
1181            "[{}] RSI should fail with zero period",
1182            test_name
1183        );
1184        Ok(())
1185    }
1186    fn check_rsi_period_exceeds_length(
1187        test_name: &str,
1188        kernel: Kernel,
1189    ) -> Result<(), Box<dyn Error>> {
1190        skip_if_unsupported!(kernel, test_name);
1191        let data_small = [10.0, 20.0, 30.0];
1192        let params = RsiParams { period: Some(10) };
1193        let input = RsiInput::from_slice(&data_small, params);
1194        let res = rsi_with_kernel(&input, kernel);
1195        assert!(
1196            res.is_err(),
1197            "[{}] RSI should fail with period exceeding length",
1198            test_name
1199        );
1200        Ok(())
1201    }
1202    fn check_rsi_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1203        skip_if_unsupported!(kernel, test_name);
1204        let single_point = [42.0];
1205        let params = RsiParams { period: Some(14) };
1206        let input = RsiInput::from_slice(&single_point, params);
1207        let res = rsi_with_kernel(&input, kernel);
1208        assert!(
1209            res.is_err(),
1210            "[{}] RSI should fail with insufficient data",
1211            test_name
1212        );
1213        Ok(())
1214    }
1215    fn check_rsi_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1216        skip_if_unsupported!(kernel, test_name);
1217        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1218        let candles = read_candles_from_csv(file_path)?;
1219        let first_params = RsiParams { period: Some(14) };
1220        let first_input = RsiInput::from_candles(&candles, "close", first_params);
1221        let first_result = rsi_with_kernel(&first_input, kernel)?;
1222        let second_params = RsiParams { period: Some(5) };
1223        let second_input = RsiInput::from_slice(&first_result.values, second_params);
1224        let second_result = rsi_with_kernel(&second_input, kernel)?;
1225        assert_eq!(second_result.values.len(), first_result.values.len());
1226        if second_result.values.len() > 240 {
1227            for i in 240..second_result.values.len() {
1228                assert!(
1229                    !second_result.values[i].is_nan(),
1230                    "Found NaN in RSI at {}",
1231                    i
1232                );
1233            }
1234        }
1235        Ok(())
1236    }
1237    fn check_rsi_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1238        skip_if_unsupported!(kernel, test_name);
1239        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1240        let candles = read_candles_from_csv(file_path)?;
1241        let input = RsiInput::from_candles(&candles, "close", RsiParams { period: Some(14) });
1242        let res = rsi_with_kernel(&input, kernel)?;
1243        assert_eq!(res.values.len(), candles.close.len());
1244        if res.values.len() > 240 {
1245            for (i, &val) in res.values[240..].iter().enumerate() {
1246                assert!(
1247                    !val.is_nan(),
1248                    "[{}] Found unexpected NaN at out-index {}",
1249                    test_name,
1250                    240 + i
1251                );
1252            }
1253        }
1254        Ok(())
1255    }
1256    fn check_rsi_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1257        skip_if_unsupported!(kernel, test_name);
1258        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1259        let candles = read_candles_from_csv(file_path)?;
1260        let period = 14;
1261        let input = RsiInput::from_candles(
1262            &candles,
1263            "close",
1264            RsiParams {
1265                period: Some(period),
1266            },
1267        );
1268        let batch_output = rsi_with_kernel(&input, kernel)?.values;
1269
1270        let mut stream = RsiStream::try_new(RsiParams {
1271            period: Some(period),
1272        })?;
1273        let mut stream_values = Vec::with_capacity(candles.close.len());
1274        for &price in &candles.close {
1275            match stream.update(price) {
1276                Some(rsi_val) => stream_values.push(rsi_val),
1277                None => stream_values.push(f64::NAN),
1278            }
1279        }
1280        assert_eq!(batch_output.len(), stream_values.len());
1281        for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
1282            if b.is_nan() && s.is_nan() {
1283                continue;
1284            }
1285            let diff = (b - s).abs();
1286            assert!(
1287                diff < 1e-6,
1288                "[{}] RSI streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
1289                test_name,
1290                i,
1291                b,
1292                s,
1293                diff
1294            );
1295        }
1296        Ok(())
1297    }
1298
1299    #[cfg(debug_assertions)]
1300    fn check_rsi_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1301        skip_if_unsupported!(kernel, test_name);
1302
1303        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1304        let candles = read_candles_from_csv(file_path)?;
1305
1306        let test_params = vec![
1307            RsiParams::default(),
1308            RsiParams { period: Some(2) },
1309            RsiParams { period: Some(5) },
1310            RsiParams { period: Some(7) },
1311            RsiParams { period: Some(10) },
1312            RsiParams { period: Some(14) },
1313            RsiParams { period: Some(20) },
1314            RsiParams { period: Some(30) },
1315            RsiParams { period: Some(50) },
1316            RsiParams { period: Some(100) },
1317            RsiParams { period: Some(200) },
1318        ];
1319
1320        for (param_idx, params) in test_params.iter().enumerate() {
1321            let input = RsiInput::from_candles(&candles, "close", params.clone());
1322            let output = rsi_with_kernel(&input, kernel)?;
1323
1324            for (i, &val) in output.values.iter().enumerate() {
1325                if val.is_nan() {
1326                    continue;
1327                }
1328
1329                let bits = val.to_bits();
1330
1331                if bits == 0x11111111_11111111 {
1332                    panic!(
1333                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1334						 with params: period={} (param set {})",
1335                        test_name,
1336                        val,
1337                        bits,
1338                        i,
1339                        params.period.unwrap_or(14),
1340                        param_idx
1341                    );
1342                }
1343
1344                if bits == 0x22222222_22222222 {
1345                    panic!(
1346                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1347						 with params: period={} (param set {})",
1348                        test_name,
1349                        val,
1350                        bits,
1351                        i,
1352                        params.period.unwrap_or(14),
1353                        param_idx
1354                    );
1355                }
1356
1357                if bits == 0x33333333_33333333 {
1358                    panic!(
1359                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1360						 with params: period={} (param set {})",
1361                        test_name,
1362                        val,
1363                        bits,
1364                        i,
1365                        params.period.unwrap_or(14),
1366                        param_idx
1367                    );
1368                }
1369            }
1370        }
1371
1372        Ok(())
1373    }
1374
1375    #[cfg(not(debug_assertions))]
1376    fn check_rsi_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1377        Ok(())
1378    }
1379
    /// Property-based checks for RSI on randomly generated price series.
    ///
    /// For periods in `2..=100` and series of `period + 10 .. 400` prices it
    /// verifies range, warm-up NaNs, agreement with the scalar kernel, several
    /// shape-specific expectations (constant, monotonic, oscillating input)
    /// and an independent recomputation of one output value.
    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_rsi_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        // Strategy: pick a period, then a price vector long enough to leave at
        // least ten values beyond the warm-up window.
        let strat = (2usize..=100).prop_flat_map(|period| {
            (
                prop::collection::vec(
                    (-1e6f64..1e6f64)
                        .prop_filter("finite price", |x| x.is_finite() && x.abs() > 1e-10),
                    period + 10..400,
                ),
                Just(period),
            )
        });

        proptest::test_runner::TestRunner::default().run(&strat, |(data, period)| {
            let params = RsiParams {
                period: Some(period),
            };
            let input = RsiInput::from_slice(&data, params);

            // Output of the kernel under test.
            let RsiOutput { values: out } = rsi_with_kernel(&input, kernel)?;

            // Scalar output used as the reference for cross-kernel comparison.
            let RsiOutput { values: ref_out } = rsi_with_kernel(&input, Kernel::Scalar)?;

            // Index of the first output expected to be non-NaN.
            let first_valid = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
            let warmup_end = first_valid + period;

            // Every defined RSI value must lie in [0, 100].
            for (i, &val) in out.iter().enumerate() {
                if !val.is_nan() {
                    prop_assert!(
                        val >= 0.0 && val <= 100.0,
                        "[{}] RSI value {} at index {} is out of range [0, 100]",
                        test_name,
                        val,
                        i
                    );
                }
            }

            // Everything before `warmup_end` must be NaN.
            for i in 0..warmup_end.min(out.len()) {
                prop_assert!(
                    out[i].is_nan(),
                    "[{}] Expected NaN during warmup at index {}, got {}",
                    test_name,
                    i,
                    out[i]
                );
            }

            // The value at `warmup_end` itself must be defined.
            if warmup_end < out.len() {
                prop_assert!(
                    !out[warmup_end].is_nan(),
                    "[{}] Expected non-NaN at index {} (warmup_end), got NaN",
                    test_name,
                    warmup_end
                );
            }

            // Kernel under test must agree with the scalar reference to 1e-9.
            for i in 0..out.len() {
                let y = out[i];
                let r = ref_out[i];

                if y.is_nan() && r.is_nan() {
                    continue;
                }

                prop_assert!(
                    (y - r).abs() < 1e-9,
                    "[{}] Kernel mismatch at index {}: {} vs {} (diff: {})",
                    test_name,
                    i,
                    y,
                    r,
                    (y - r).abs()
                );
            }

            // (Near-)constant prices: every defined value must be 50.
            if data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-12) && warmup_end < out.len() {
                for i in warmup_end..out.len() {
                    prop_assert!(
                        (out[i] - 50.0).abs() < 1e-9,
                        "[{}] Constant prices should yield RSI=50, got {} at index {}",
                        test_name,
                        out[i],
                        i
                    );
                }
            }

            // Strictly increasing prices: the last value must exceed a
            // period-dependent threshold.
            let strictly_increasing = data.windows(2).all(|w| w[1] > w[0] + 1e-10);
            if strictly_increasing && out.len() > warmup_end + 10 {
                let last_rsi = out[out.len() - 1];
                let high_threshold = if period <= 5 {
                    60.0
                } else if period <= 20 {
                    65.0
                } else {
                    70.0
                };
                prop_assert!(
                    last_rsi > high_threshold,
                    "[{}] Strictly increasing prices should yield RSI > {} (period={}), got {}",
                    test_name,
                    high_threshold,
                    period,
                    last_rsi
                );
            }

            // Strictly decreasing prices: the last value must be below a
            // period-dependent threshold.
            let strictly_decreasing = data.windows(2).all(|w| w[1] < w[0] - 1e-10);
            if strictly_decreasing && out.len() > warmup_end + 10 {
                let last_rsi = out[out.len() - 1];
                let low_threshold = if period <= 5 {
                    40.0
                } else if period <= 20 {
                    35.0
                } else {
                    30.0
                };
                prop_assert!(
                    last_rsi < low_threshold,
                    "[{}] Strictly decreasing prices should yield RSI < {} (period={}), got {}",
                    test_name,
                    low_threshold,
                    period,
                    last_rsi
                );
            }

            // Debug builds only: no defined output may carry a poison bit pattern.
            #[cfg(debug_assertions)]
            {
                for (i, &val) in out.iter().enumerate() {
                    if val.is_nan() {
                        continue;
                    }

                    let bits = val.to_bits();
                    prop_assert!(
                        bits != 0x11111111_11111111
                            && bits != 0x22222222_22222222
                            && bits != 0x33333333_33333333,
                        "[{}] Found poison value {} (0x{:016X}) at index {}",
                        test_name,
                        val,
                        bits,
                        i
                    );
                }
            }

            // Detect a series whose consecutive non-zero deltas always flip sign.
            let mut oscillating = true;
            let mut prev_delta = 0.0;
            for window in data.windows(2) {
                let delta = window[1] - window[0];
                if prev_delta != 0.0 && delta != 0.0 {
                    if (delta > 0.0 && prev_delta > 0.0) || (delta < 0.0 && prev_delta < 0.0) {
                        oscillating = false;
                        break;
                    }
                }
                prev_delta = delta;
            }

            // NOTE(review): this only constrains the sign pattern, not the
            // magnitudes of the up/down moves — confirm the [35, 65] bound is
            // intended to hold for asymmetric oscillations.
            if oscillating && out.len() > warmup_end + 10 && prev_delta != 0.0 {
                let last_quarter_start = out.len() - (out.len() - warmup_end) / 4;
                for i in last_quarter_start..out.len() {
                    if !out[i].is_nan() {
                        prop_assert!(
								out[i] >= 35.0 && out[i] <= 65.0,
								"[{}] Oscillating prices should keep RSI in [35, 65] range, got {} at index {}",
								test_name, out[i], i
							);
                    }
                }
            }

            // Independently recompute the value at `warmup_end + 3` and compare.
            if warmup_end + 5 < out.len() {
                let idx = warmup_end + 3;
                let mut avg_gain = 0.0;
                let mut avg_loss = 0.0;

                // Seed: plain average of the first `period` gains and losses.
                for j in (first_valid + 1)..=(first_valid + period) {
                    let delta = data[j] - data[j - 1];
                    if delta > 0.0 {
                        avg_gain += delta;
                    } else {
                        avg_loss += -delta;
                    }
                }
                avg_gain /= period as f64;
                avg_loss /= period as f64;

                // Recursive smoothing up to and including `idx`.
                let inv_period = 1.0 / period as f64;
                let beta = 1.0 - inv_period;
                for j in (first_valid + period + 1)..=idx {
                    let delta = data[j] - data[j - 1];
                    let gain = if delta > 0.0 { delta } else { 0.0 };
                    let loss = if delta < 0.0 { -delta } else { 0.0 };
                    avg_gain = inv_period * gain + beta * avg_gain;
                    avg_loss = inv_period * loss + beta * avg_loss;
                }

                let expected_rsi = if avg_gain + avg_loss == 0.0 {
                    50.0
                } else {
                    100.0 * avg_gain / (avg_gain + avg_loss)
                };

                prop_assert!(
                    (out[idx] - expected_rsi).abs() < 1e-9,
                    "[{}] RSI calculation mismatch at index {}: got {}, expected {}",
                    test_name,
                    idx,
                    out[idx],
                    expected_rsi
                );
            }

            Ok(())
        })?;

        Ok(())
    }
1609
1610    macro_rules! generate_all_rsi_tests {
1611        ($($test_fn:ident),*) => {
1612            paste! {
1613                $(
1614                    #[test]
1615                    fn [<$test_fn _scalar_f64>]() {
1616                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1617                    }
1618                )*
1619                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1620                $(
1621                    #[test]
1622                    fn [<$test_fn _avx2_f64>]() {
1623                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1624                    }
1625                    #[test]
1626                    fn [<$test_fn _avx512_f64>]() {
1627                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1628                    }
1629                )*
1630            }
1631        }
1632    }
1633
1634    fn check_rsi_error_variants(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1635        skip_if_unsupported!(kernel, test_name);
1636
1637        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
1638        let mut dst = vec![0.0; 3];
1639        let params = RsiParams { period: Some(2) };
1640        let input = RsiInput::from_slice(&data, params);
1641
1642        match rsi_into_slice(&mut dst, &input, kernel) {
1643            Err(RsiError::OutputLengthMismatch {
1644                expected: 5,
1645                got: 3,
1646            }) => {}
1647            other => panic!(
1648                "[{}] Expected OutputLengthMismatch error, got {:?}",
1649                test_name, other
1650            ),
1651        }
1652
1653        let sweep = RsiBatchRange {
1654            period: (14, 14, 0),
1655        };
1656        match rsi_batch_with_kernel(&data, &sweep, Kernel::Scalar) {
1657            Err(RsiError::InvalidKernelForBatch(Kernel::Scalar)) => {}
1658            other => panic!(
1659                "[{}] Expected InvalidKernelForBatch error, got {:?}",
1660                test_name, other
1661            ),
1662        }
1663
1664        Ok(())
1665    }
1666
    // Instantiate per-kernel #[test] wrappers for every single-series check.
    generate_all_rsi_tests!(
        check_rsi_partial_params,
        check_rsi_accuracy,
        check_rsi_default_candles,
        check_rsi_zero_period,
        check_rsi_period_exceeds_length,
        check_rsi_very_small_dataset,
        check_rsi_reinput,
        check_rsi_nan_handling,
        check_rsi_streaming,
        check_rsi_no_poison,
        check_rsi_error_variants
    );

    // The property test is only defined (and so only instantiated) with the
    // `proptest` feature.
    #[cfg(feature = "proptest")]
    generate_all_rsi_tests!(check_rsi_property);
1683
1684    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1685        skip_if_unsupported!(kernel, test);
1686        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1687        let c = read_candles_from_csv(file)?;
1688        let output = RsiBatchBuilder::new()
1689            .kernel(kernel)
1690            .apply_candles(&c, "close")?;
1691        let def = RsiParams::default();
1692        let row = output.values_for(&def).expect("default row missing");
1693        assert_eq!(row.len(), c.close.len());
1694        let expected = [43.42, 42.68, 41.62, 42.86, 39.01];
1695        let start = row.len() - 5;
1696        for (i, &v) in row[start..].iter().enumerate() {
1697            assert!(
1698                (v - expected[i]).abs() < 1e-2,
1699                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
1700            );
1701        }
1702        Ok(())
1703    }
1704
1705    #[cfg(debug_assertions)]
1706    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1707        skip_if_unsupported!(kernel, test);
1708
1709        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1710        let c = read_candles_from_csv(file)?;
1711
1712        let test_configs = vec![
1713            (2, 10, 2),
1714            (5, 25, 5),
1715            (30, 60, 15),
1716            (2, 5, 1),
1717            (7, 21, 7),
1718            (10, 50, 10),
1719            (14, 28, 14),
1720        ];
1721
1722        for (cfg_idx, &(p_start, p_end, p_step)) in test_configs.iter().enumerate() {
1723            let output = RsiBatchBuilder::new()
1724                .kernel(kernel)
1725                .period_range(p_start, p_end, p_step)
1726                .apply_candles(&c, "close")?;
1727
1728            for (idx, &val) in output.values.iter().enumerate() {
1729                if val.is_nan() {
1730                    continue;
1731                }
1732
1733                let bits = val.to_bits();
1734                let row = idx / output.cols;
1735                let col = idx % output.cols;
1736                let combo = &output.combos[row];
1737
1738                if bits == 0x11111111_11111111 {
1739                    panic!(
1740                        "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
1741						 at row {} col {} (flat index {}) with params: period={}",
1742                        test,
1743                        cfg_idx,
1744                        val,
1745                        bits,
1746                        row,
1747                        col,
1748                        idx,
1749                        combo.period.unwrap_or(14)
1750                    );
1751                }
1752
1753                if bits == 0x22222222_22222222 {
1754                    panic!(
1755                        "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
1756						 at row {} col {} (flat index {}) with params: period={}",
1757                        test,
1758                        cfg_idx,
1759                        val,
1760                        bits,
1761                        row,
1762                        col,
1763                        idx,
1764                        combo.period.unwrap_or(14)
1765                    );
1766                }
1767
1768                if bits == 0x33333333_33333333 {
1769                    panic!(
1770                        "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
1771						 at row {} col {} (flat index {}) with params: period={}",
1772                        test,
1773                        cfg_idx,
1774                        val,
1775                        bits,
1776                        row,
1777                        col,
1778                        idx,
1779                        combo.period.unwrap_or(14)
1780                    );
1781                }
1782            }
1783        }
1784
1785        Ok(())
1786    }
1787
1788    #[cfg(not(debug_assertions))]
1789    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1790        Ok(())
1791    }
1792
1793    macro_rules! gen_batch_tests {
1794        ($fn_name:ident) => {
1795            paste! {
1796                #[test] fn [<$fn_name _scalar>]()      {
1797                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
1798                }
1799                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1800                #[test] fn [<$fn_name _avx2>]()        {
1801                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
1802                }
1803                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1804                #[test] fn [<$fn_name _avx512>]()      {
1805                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
1806                }
1807                #[test] fn [<$fn_name _auto_detect>]() {
1808                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
1809                }
1810            }
1811        };
1812    }
1813    gen_batch_tests!(check_batch_default_row);
1814    gen_batch_tests!(check_batch_no_poison);
1815}
1816
// Parity check between the allocating `rsi` entry point and the
// write-into-buffer `rsi_into` entry point. Excluded from wasm builds.
#[cfg(test)]
#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
mod into_parity_tests {
    use super::*;
    use crate::utilities::data_loader::read_candles_from_csv;
    use std::error::Error;

    /// Two samples match when both are NaN or when they differ by at most 1e-12.
    #[inline]
    fn eq_or_both_nan(a: f64, b: f64) -> bool {
        if a.is_nan() && b.is_nan() {
            return true;
        }
        (a - b).abs() <= 1e-12
    }

    /// Runs both entry points on the close series with default parameters and
    /// compares the outputs element by element.
    #[test]
    fn test_rsi_into_matches_api() -> Result<(), Box<dyn Error>> {
        let candles = read_candles_from_csv("src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv")?;
        let input = RsiInput::from_candles(&candles, "close", RsiParams::default());

        let expected = rsi(&input)?.values;

        let mut actual = vec![0.0; candles.close.len()];
        rsi_into(&input, &mut actual)?;

        assert_eq!(expected.len(), actual.len());
        for (i, (want, got)) in expected.iter().zip(actual.iter()).enumerate() {
            assert!(
                eq_or_both_nan(*want, *got),
                "rsi_into parity mismatch at {}: {} vs {}",
                i,
                want,
                got
            );
        }

        Ok(())
    }
}
1855
/// Python entry point `rsi(data, period, kernel=None)`.
///
/// Borrows `data` as a slice, validates the optional kernel name through
/// `validate_kernel`, and runs `rsi_with_kernel` with the GIL released.
/// Returns the computed values as a new 1-D NumPy array.
///
/// A computation error is re-raised as `ValueError` carrying the error's
/// display text; failures from `as_slice` and `validate_kernel` are
/// propagated as they are.
#[cfg(feature = "python")]
#[pyfunction(name = "rsi")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn rsi_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let prices = data.as_slice()?;
    let chosen_kernel = validate_kernel(kernel, false)?;
    let input = RsiInput::from_slice(
        prices,
        RsiParams {
            period: Some(period),
        },
    );

    // Only plain Rust data is touched inside the closure, so the GIL can be dropped.
    let computed = py.allow_threads(|| rsi_with_kernel(&input, chosen_kernel).map(|out| out.values));

    match computed {
        Ok(values) => Ok(values.into_pyarray(py)),
        Err(err) => Err(PyValueError::new_err(err.to_string())),
    }
}
1881
/// Python class `RsiStream`: a thin wrapper that owns a Rust `RsiStream`.
#[cfg(feature = "python")]
#[pyclass(name = "RsiStream")]
pub struct RsiStreamPy {
    stream: RsiStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl RsiStreamPy {
    /// Builds a stream for the given `period`.
    ///
    /// Raises `ValueError` with the error's display text when
    /// `RsiStream::try_new` rejects the parameters.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        let params = RsiParams {
            period: Some(period),
        };
        match RsiStream::try_new(params) {
            Ok(stream) => Ok(Self { stream }),
            Err(err) => Err(PyValueError::new_err(err.to_string())),
        }
    }

    /// Forwards `value` to the wrapped stream and returns its result unchanged.
    // NOTE(review): `None` presumably means the stream is still warming up;
    // confirm against `RsiStream::update`.
    fn update(&mut self, value: f64) -> Option<f64> {
        let stream = &mut self.stream;
        stream.update(value)
    }
}
1905
/// Python entry point `rsi_batch(data, period_range, kernel=None)`.
///
/// Expands `period_range` into parameter combinations with `expand_grid`,
/// allocates one flat NumPy buffer of `rows * cols` elements
/// (rows = combinations, cols = input length) and has
/// `rsi_batch_inner_into` write into it with the GIL released.
///
/// Returns a dict with:
/// * `"values"`  - the buffer reshaped to `(rows, cols)`
/// * `"periods"` - the period of each returned combination, as `u64`
///
/// Grid-expansion and computation errors, as well as `rows * cols`
/// overflow, are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "rsi_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn rsi_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let prices = data.as_slice()?;
    let requested = validate_kernel(kernel, true)?;

    let sweep = RsiBatchRange {
        period: period_range,
    };

    // The grid is expanded up front only to size the output buffer.
    let grid = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let rows = grid.len();
    let cols = prices.len();

    let total = match rows.checked_mul(cols) {
        Some(n) => n,
        None => return Err(PyValueError::new_err("rows*cols overflow in rsi_batch_py")),
    };

    // The array was created just above and has not been handed to Python yet.
    // NOTE(review): `rsi_batch_inner_into` is expected to write every element
    // of this buffer - confirm.
    let flat = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let flat_out = unsafe { flat.as_slice_mut()? };

    let computed = py.allow_threads(|| {
        let batch_kernel = match requested {
            Kernel::Auto => detect_best_batch_kernel(),
            explicit => explicit,
        };
        // Batch kernel variants are translated to their single-series counterparts.
        let row_kernel = match batch_kernel {
            Kernel::Avx512Batch => Kernel::Avx512,
            Kernel::Avx2Batch => Kernel::Avx2,
            Kernel::ScalarBatch => Kernel::Scalar,
            other => other,
        };
        rsi_batch_inner_into(prices, &sweep, row_kernel, true, flat_out)
    });
    let combos = match computed {
        Ok(combos) => combos,
        Err(err) => return Err(PyValueError::new_err(err.to_string())),
    };

    let result = PyDict::new(py);
    result.set_item("values", flat.reshape((rows, cols))?)?;

    let period_column: Vec<u64> = combos
        .iter()
        .map(|combo| combo.period.unwrap() as u64)
        .collect();
    result.set_item("periods", period_column.into_pyarray(py))?;

    Ok(result)
}
1964
/// Python entry point `rsi_cuda_batch_dev(data_f32, period_range, device_id=0)`.
///
/// Runs `CudaRsi::rsi_batch_dev` over `data_f32` for the sweep described by
/// `period_range = (start, end, step)` with the GIL released, and returns a
/// device-array handle together with a dict whose `"periods"` entry lists
/// the swept periods as `u64`.
///
/// Raises `ValueError` when CUDA is not available, or when creating the
/// CUDA context or running the batch fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "rsi_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, device_id=0))]
pub fn rsi_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    data_f32: numpy::PyReadonlyArray1<'py, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<(DeviceArrayF32Py, Bound<'py, pyo3::types::PyDict>)> {
    use crate::cuda::cuda_available;
    use crate::cuda::oscillators::rsi_wrapper::CudaRsi;
    use numpy::IntoPyArray;
    use pyo3::types::PyDict;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let prices = data_f32.as_slice()?;
    let sweep = RsiBatchRange {
        period: period_range,
    };

    let inner = py.allow_threads(|| {
        let cuda = CudaRsi::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        cuda.rsi_batch_dev(prices, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    // Rebuild the period axis for the caller. A zero step means a single
    // period. Stepping uses `checked_add`, so the walk stops on overflow;
    // a saturating step would pin at `usize::MAX` and never exceed an `end`
    // of `usize::MAX`, looping forever.
    // NOTE(review): this assumes an ascending sweep (start <= end); confirm
    // that `CudaRsi::rsi_batch_dev` expands `period_range` the same way.
    let (start, end, step) = period_range;
    let periods: Vec<u64> = if step == 0 {
        vec![start as u64]
    } else {
        std::iter::successors(Some(start), |&p| p.checked_add(step))
            .take_while(|&p| p <= end)
            .map(|p| p as u64)
            .collect()
    };

    let dict = PyDict::new(py);
    dict.set_item("periods", periods.into_pyarray(py))?;

    let handle = make_device_array_py(device_id, inner)?;
    Ok((handle, dict))
}
2011
/// Python entry point
/// `rsi_cuda_many_series_one_param_dev(data_tm_f32, period, device_id=0)`.
///
/// Passes the 2-D input as a flat slice, followed by the size of its second
/// dimension, the size of its first dimension and `period`, to
/// `CudaRsi::rsi_many_series_one_param_time_major_dev`, with the GIL
/// released. Returns a device-array handle.
///
/// Raises `ValueError` when CUDA is not available, or when creating the
/// CUDA context or running the kernel fails.
// NOTE(review): the callee's name suggests rows are time steps and columns
// are series - confirm against the CUDA wrapper.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "rsi_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, device_id=0))]
pub fn rsi_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    use crate::cuda::cuda_available;
    use crate::cuda::oscillators::rsi_wrapper::CudaRsi;
    use numpy::PyUntypedArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let flat = data_tm_f32.as_slice()?;
    let dims = data_tm_f32.shape();
    let (rows, cols) = (dims[0], dims[1]);

    let device_result = py.allow_threads(|| {
        let to_py_err = |msg: String| PyValueError::new_err(msg);
        let cuda = CudaRsi::new(device_id).map_err(|e| to_py_err(e.to_string()))?;
        cuda.rsi_many_series_one_param_time_major_dev(flat, cols, rows, period)
            .map_err(|e| to_py_err(e.to_string()))
    });

    make_device_array_py(device_id, device_result?)
}
2038
/// JS entry point: computes RSI over `data` for `period` and returns a
/// freshly allocated vector of the same length as `data`.
///
/// The kernel comes from `detect_best_kernel()`. A failure from
/// `rsi_into_slice` is returned as a JS string holding the error's display
/// text.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn rsi_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = RsiInput::from_slice(
        data,
        RsiParams {
            period: Some(period),
        },
    );

    let mut values = vec![0.0; data.len()];

    match rsi_into_slice(&mut values, &input, detect_best_kernel()) {
        Ok(_) => Ok(values),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
2054
/// Reserves room for `len` `f64` values and returns the raw pointer to that
/// storage.
///
/// The backing `Vec` is deliberately never dropped here, so the allocation
/// outlives this call; it is meant to be released later through
/// [`rsi_free`] with the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn rsi_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` suppresses the Vec destructor, leaking the buffer on purpose.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
2063
/// Releases a buffer previously obtained from [`rsi_alloc`].
///
/// A null `ptr` is ignored.
///
/// # Safety contract (caller side)
///
/// `ptr` must have been returned by `rsi_alloc(len)` with this same `len`,
/// and must not be used or freed again afterwards.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn rsi_free(ptr: *mut f64, len: usize) {
    // `Vec::from_raw_parts` requires a non-null pointer.
    if ptr.is_null() {
        return;
    }

    // SAFETY: per the contract above, `ptr` came from a `Vec<f64>` allocated
    // with capacity `len`. The Vec is rebuilt with length 0 so that no element
    // is asserted to be initialised; `f64` has no destructor, so dropping the
    // Vec only returns the allocation.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
2071
/// Pointer-based JS entry point: reads `len` values from `in_ptr`, computes
/// RSI for `period`, and writes `len` results to `out_ptr`.
///
/// The input and output buffers may be the same buffer or overlap in any
/// way: whenever the two address ranges intersect, the result is computed
/// into a scratch vector first and copied out afterwards.
///
/// # Errors
///
/// * `"null pointer passed to rsi_into"` when either pointer is null.
/// * `"Invalid period"` when `period` is 0 or larger than `len`.
/// * The display text of any error returned by `rsi_into_slice`.
///
/// # Safety contract (caller side)
///
/// Both pointers must be valid for `len` `f64` values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn rsi_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to rsi_into"));
    }

    if period == 0 || period > len {
        return Err(JsValue::from_str("Invalid period"));
    }

    let params = RsiParams {
        period: Some(period),
    };
    let kernel = detect_best_kernel();

    // Treat any intersection of the two byte ranges as aliasing, not just
    // identical start addresses: a shared and a mutable slice over
    // overlapping memory must never be live at the same time.
    let byte_span = len.saturating_mul(std::mem::size_of::<f64>());
    let in_addr = in_ptr as usize;
    let out_addr = out_ptr as usize;
    let overlapping = in_addr < out_addr.saturating_add(byte_span)
        && out_addr < in_addr.saturating_add(byte_span);

    if overlapping {
        let mut scratch = vec![0.0; len];
        {
            // SAFETY: `in_ptr` is non-null and, per the caller contract, valid
            // for `len` reads. No mutable view of this memory exists while
            // `data` is alive.
            let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };
            let input = RsiInput::from_slice(data, params);
            rsi_into_slice(&mut scratch, &input, kernel)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
        // SAFETY: `out_ptr` is non-null and valid for `len` writes; the shared
        // view of the input went out of scope above.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        out.copy_from_slice(&scratch);
    } else {
        // SAFETY: both pointers are non-null and valid for `len` elements, and
        // the ranges were just shown to be disjoint, so the shared and the
        // mutable slice do not alias.
        let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        let input = RsiInput::from_slice(data, params);
        rsi_into_slice(out, &input, kernel).map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(())
}
2111
/// Configuration object accepted by the JS `rsi_batch` function.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct RsiBatchConfig {
    /// Period sweep, forwarded unchanged into `RsiBatchRange::period`.
    pub period_range: (usize, usize, usize),
}

/// Result object returned by the JS `rsi_batch` function; a field-for-field
/// copy of the batch output.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct RsiBatchJsOutput {
    pub values: Vec<f64>,
    pub combos: Vec<RsiParams>,
    pub rows: usize,
    pub cols: usize,
}

/// JS entry point `rsi_batch(data, config)`.
///
/// Deserialises `config` into [`RsiBatchConfig`], runs
/// `rsi_batch_with_kernel` with the kernel from
/// `detect_best_batch_kernel()`, and serialises the result as an
/// [`RsiBatchJsOutput`].
///
/// Errors are returned as JS strings: `"Invalid config: ..."` for a config
/// that fails to deserialise, the error's display text for a failed
/// computation, and `"Serialization error: ..."` when the output cannot be
/// converted back to a JS value.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = rsi_batch)]
pub fn rsi_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let parsed = match serde_wasm_bindgen::from_value::<RsiBatchConfig>(config) {
        Ok(cfg) => cfg,
        Err(err) => return Err(JsValue::from_str(&format!("Invalid config: {}", err))),
    };

    let sweep = RsiBatchRange {
        period: parsed.period_range,
    };

    let batch = match rsi_batch_with_kernel(data, &sweep, detect_best_batch_kernel()) {
        Ok(batch) => batch,
        Err(err) => return Err(JsValue::from_str(&err.to_string())),
    };

    let payload = RsiBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };

    match serde_wasm_bindgen::to_value(&payload) {
        Ok(js) => Ok(js),
        Err(err) => Err(JsValue::from_str(&format!("Serialization error: {}", err))),
    }
}