Skip to main content

vector_ta/indicators/
ift_rsi.rs

1#[cfg(feature = "python")]
2use crate::utilities::kernel_validation::validate_kernel;
3#[cfg(feature = "python")]
4use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
5#[cfg(feature = "python")]
6use pyo3::exceptions::PyValueError;
7#[cfg(feature = "python")]
8use pyo3::prelude::*;
9#[cfg(feature = "python")]
10use pyo3::types::PyDict;
11
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use serde::{Deserialize, Serialize};
14#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
15use wasm_bindgen::prelude::*;
16
17use crate::indicators::rsi::{rsi, RsiError, RsiInput, RsiParams};
18use crate::indicators::wma::{wma, WmaError, WmaInput, WmaParams};
19use crate::utilities::data_loader::{source_type, Candles};
20#[cfg(all(feature = "python", feature = "cuda"))]
21use crate::utilities::dlpack_cuda::DeviceArrayF32Py;
22use crate::utilities::enums::Kernel;
23use crate::utilities::helpers::{
24    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
25    make_uninit_matrix,
26};
27use aligned_vec::{AVec, CACHELINE_ALIGN};
28#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
29use core::arch::x86_64::*;
30#[cfg(not(target_arch = "wasm32"))]
31use rayon::prelude::*;
32use std::convert::AsRef;
33use std::error::Error;
34use thiserror::Error;
35
36impl<'a> AsRef<[f64]> for IftRsiInput<'a> {
37    #[inline(always)]
38    fn as_ref(&self) -> &[f64] {
39        match &self.data {
40            IftRsiData::Slice(slice) => slice,
41            IftRsiData::Candles { candles, source } => source_type(candles, source),
42        }
43    }
44}
45
46#[derive(Debug, Clone)]
47pub enum IftRsiData<'a> {
48    Candles {
49        candles: &'a Candles,
50        source: &'a str,
51    },
52    Slice(&'a [f64]),
53}
54
/// Result of a single-series IFT-RSI computation.
#[derive(Debug, Clone)]
pub struct IftRsiOutput {
    /// One value per input sample; the warm-up prefix is `NaN`.
    pub values: Vec<f64>,
}
59
/// Periods for the IFT-RSI indicator.
///
/// A `None` field falls back to the default resolved by [`IftRsiInput`]
/// (RSI 5, WMA 9).
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct IftRsiParams {
    /// Look-back of the underlying RSI.
    pub rsi_period: Option<usize>,
    /// Length of the weighted moving average applied to the rescaled RSI.
    pub wma_period: Option<usize>,
}
69
70impl Default for IftRsiParams {
71    fn default() -> Self {
72        Self {
73            rsi_period: Some(5),
74            wma_period: Some(9),
75        }
76    }
77}
78
79#[derive(Debug, Clone)]
80pub struct IftRsiInput<'a> {
81    pub data: IftRsiData<'a>,
82    pub params: IftRsiParams,
83}
84
85impl<'a> IftRsiInput<'a> {
86    #[inline]
87    pub fn from_candles(c: &'a Candles, s: &'a str, p: IftRsiParams) -> Self {
88        Self {
89            data: IftRsiData::Candles {
90                candles: c,
91                source: s,
92            },
93            params: p,
94        }
95    }
96    #[inline]
97    pub fn from_slice(sl: &'a [f64], p: IftRsiParams) -> Self {
98        Self {
99            data: IftRsiData::Slice(sl),
100            params: p,
101        }
102    }
103    #[inline]
104    pub fn with_default_candles(c: &'a Candles) -> Self {
105        Self::from_candles(c, "close", IftRsiParams::default())
106    }
107    #[inline]
108    pub fn get_rsi_period(&self) -> usize {
109        self.params.rsi_period.unwrap_or(5)
110    }
111    #[inline]
112    pub fn get_wma_period(&self) -> usize {
113        self.params.wma_period.unwrap_or(9)
114    }
115}
116
117#[derive(Copy, Clone, Debug)]
118pub struct IftRsiBuilder {
119    rsi_period: Option<usize>,
120    wma_period: Option<usize>,
121    kernel: Kernel,
122}
123
124impl Default for IftRsiBuilder {
125    fn default() -> Self {
126        Self {
127            rsi_period: None,
128            wma_period: None,
129            kernel: Kernel::Auto,
130        }
131    }
132}
133
134impl IftRsiBuilder {
135    #[inline(always)]
136    pub fn new() -> Self {
137        Self::default()
138    }
139    #[inline(always)]
140    pub fn rsi_period(mut self, n: usize) -> Self {
141        self.rsi_period = Some(n);
142        self
143    }
144    #[inline(always)]
145    pub fn wma_period(mut self, n: usize) -> Self {
146        self.wma_period = Some(n);
147        self
148    }
149    #[inline(always)]
150    pub fn kernel(mut self, k: Kernel) -> Self {
151        self.kernel = k;
152        self
153    }
154
155    #[inline(always)]
156    pub fn apply(self, c: &Candles) -> Result<IftRsiOutput, IftRsiError> {
157        let p = IftRsiParams {
158            rsi_period: self.rsi_period,
159            wma_period: self.wma_period,
160        };
161        let i = IftRsiInput::from_candles(c, "close", p);
162        ift_rsi_with_kernel(&i, self.kernel)
163    }
164
165    #[inline(always)]
166    pub fn apply_slice(self, d: &[f64]) -> Result<IftRsiOutput, IftRsiError> {
167        let p = IftRsiParams {
168            rsi_period: self.rsi_period,
169            wma_period: self.wma_period,
170        };
171        let i = IftRsiInput::from_slice(d, p);
172        ift_rsi_with_kernel(&i, self.kernel)
173    }
174
175    #[inline(always)]
176    pub fn into_stream(self) -> Result<IftRsiStream, IftRsiError> {
177        let p = IftRsiParams {
178            rsi_period: self.rsi_period,
179            wma_period: self.wma_period,
180        };
181        IftRsiStream::try_new(p)
182    }
183}
184
185#[derive(Debug, Error)]
186pub enum IftRsiError {
187    #[error("ift_rsi: Input data slice is empty.")]
188    EmptyData,
189    #[error("ift_rsi: All values are NaN.")]
190    AllValuesNaN,
191    #[error("ift_rsi: Invalid RSI period {rsi_period} or WMA period {wma_period}, data length = {data_len}.")]
192    InvalidPeriod {
193        rsi_period: usize,
194        wma_period: usize,
195        data_len: usize,
196    },
197    #[error("ift_rsi: Not enough valid data: needed = {needed}, valid = {valid}")]
198    NotEnoughValidData { needed: usize, valid: usize },
199    #[error("ift_rsi: RSI calculation error: {0}")]
200    RsiCalculationError(String),
201    #[error("ift_rsi: WMA calculation error: {0}")]
202    WmaCalculationError(String),
203    #[error("ift_rsi: Output length mismatch: expected = {expected}, got = {got}")]
204    OutputLengthMismatch { expected: usize, got: usize },
205    #[error("ift_rsi: Wrong kernel for batch operation. Use a batch kernel variant.")]
206    WrongKernelForBatch,
207    #[error("ift_rsi: Invalid kernel for batch: {0:?}")]
208    InvalidKernelForBatch(crate::utilities::enums::Kernel),
209    #[error("ift_rsi: Invalid range: start={start}, end={end}, step={step}")]
210    InvalidRange {
211        start: usize,
212        end: usize,
213        step: usize,
214    },
215}
216
217#[inline]
218pub fn ift_rsi(input: &IftRsiInput) -> Result<IftRsiOutput, IftRsiError> {
219    ift_rsi_with_kernel(input, Kernel::Auto)
220}
221
222pub fn ift_rsi_with_kernel(
223    input: &IftRsiInput,
224    kernel: Kernel,
225) -> Result<IftRsiOutput, IftRsiError> {
226    let data: &[f64] = match &input.data {
227        IftRsiData::Candles { candles, source } => source_type(candles, source),
228        IftRsiData::Slice(sl) => sl,
229    };
230
231    if data.is_empty() {
232        return Err(IftRsiError::EmptyData);
233    }
234    let first = data
235        .iter()
236        .position(|x| !x.is_nan())
237        .ok_or(IftRsiError::AllValuesNaN)?;
238    let len = data.len();
239    let rsi_period = input.get_rsi_period();
240    let wma_period = input.get_wma_period();
241    if rsi_period == 0 || wma_period == 0 || rsi_period > len || wma_period > len {
242        return Err(IftRsiError::InvalidPeriod {
243            rsi_period,
244            wma_period,
245            data_len: len,
246        });
247    }
248    let needed = rsi_period.max(wma_period);
249    if (len - first) < needed {
250        return Err(IftRsiError::NotEnoughValidData {
251            needed,
252            valid: len - first,
253        });
254    }
255
256    if kernel.is_batch() {
257        return Err(IftRsiError::WrongKernelForBatch);
258    }
259
260    let warmup_period = first + rsi_period + wma_period - 1;
261    let mut out = alloc_with_nan_prefix(len, warmup_period);
262
263    unsafe {
264        ift_rsi_scalar_classic(data, rsi_period, wma_period, first, &mut out)?;
265    }
266
267    Ok(IftRsiOutput { values: out })
268}
269
270#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
271#[inline]
272pub fn ift_rsi_into(input: &IftRsiInput, out: &mut [f64]) -> Result<(), IftRsiError> {
273    let data: &[f64] = match &input.data {
274        IftRsiData::Candles { candles, source } => source_type(candles, source),
275        IftRsiData::Slice(sl) => sl,
276    };
277
278    if out.len() != data.len() {
279        return Err(IftRsiError::OutputLengthMismatch {
280            expected: data.len(),
281            got: out.len(),
282        });
283    }
284
285    let kern = Kernel::Auto;
286    ift_rsi_into_slice(out, input, kern)
287}
288
289#[inline(always)]
290fn ift_rsi_compute_into(
291    data: &[f64],
292    rsi_period: usize,
293    wma_period: usize,
294    first_valid: usize,
295    out: &mut [f64],
296) -> Result<(), IftRsiError> {
297    let sliced = &data[first_valid..];
298    let mut rsi_values = rsi(&RsiInput::from_slice(
299        sliced,
300        RsiParams {
301            period: Some(rsi_period),
302        },
303    ))
304    .map_err(|e| IftRsiError::RsiCalculationError(e.to_string()))?
305    .values;
306
307    for val in rsi_values.iter_mut() {
308        if !val.is_nan() {
309            *val = 0.1 * (*val - 50.0);
310        }
311    }
312
313    let wma_values = wma(&WmaInput::from_slice(
314        &rsi_values,
315        WmaParams {
316            period: Some(wma_period),
317        },
318    ))
319    .map_err(|e| IftRsiError::WmaCalculationError(e.to_string()))?
320    .values;
321
322    for (i, &w) in wma_values.iter().enumerate() {
323        if !w.is_nan() {
324            out[first_valid + i] = w.tanh();
325        }
326    }
327    Ok(())
328}
329
330#[inline]
331pub fn ift_rsi_scalar(
332    data: &[f64],
333    rsi_period: usize,
334    wma_period: usize,
335    first_valid: usize,
336    out: &mut [f64],
337) -> Result<(), IftRsiError> {
338    ift_rsi_compute_into(data, rsi_period, wma_period, first_valid, out)
339}
340
341pub fn ift_rsi_into_slice(
342    dst: &mut [f64],
343    input: &IftRsiInput,
344    kern: Kernel,
345) -> Result<(), IftRsiError> {
346    let data: &[f64] = match &input.data {
347        IftRsiData::Candles { candles, source } => source_type(candles, source),
348        IftRsiData::Slice(sl) => sl,
349    };
350
351    if data.is_empty() {
352        return Err(IftRsiError::EmptyData);
353    }
354
355    if dst.len() != data.len() {
356        return Err(IftRsiError::OutputLengthMismatch {
357            expected: data.len(),
358            got: dst.len(),
359        });
360    }
361
362    let first = data
363        .iter()
364        .position(|x| !x.is_nan())
365        .ok_or(IftRsiError::AllValuesNaN)?;
366    let rsi_period = input.get_rsi_period();
367    let wma_period = input.get_wma_period();
368
369    if rsi_period == 0 || wma_period == 0 || rsi_period > data.len() || wma_period > data.len() {
370        return Err(IftRsiError::InvalidPeriod {
371            rsi_period,
372            wma_period,
373            data_len: data.len(),
374        });
375    }
376
377    let warmup_period = (first + rsi_period + wma_period - 1).min(dst.len());
378    for v in &mut dst[..warmup_period] {
379        *v = f64::NAN;
380    }
381
382    unsafe {
383        return ift_rsi_scalar_classic(data, rsi_period, wma_period, first, dst);
384    }
385
386    Ok(())
387}
388
/// AVX-512 single-series entry point.
///
/// There is no vectorised implementation yet; this forwards to
/// `ift_rsi_scalar_classic` and returns its result.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn ift_rsi_avx512(
    data: &[f64],
    rsi_period: usize,
    wma_period: usize,
    first_valid: usize,
    out: &mut [f64],
) -> Result<(), IftRsiError> {
    // NOTE(review): this safe wrapper does not validate `out.len()` or the periods
    // before the unsafe call; confirm `ift_rsi_scalar_classic` tolerates that.
    unsafe { ift_rsi_scalar_classic(data, rsi_period, wma_period, first_valid, out)? };
    Ok(())
}
400
/// AVX2 single-series entry point.
///
/// There is no vectorised implementation yet; this forwards to
/// `ift_rsi_scalar_classic` and returns its result.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn ift_rsi_avx2(
    data: &[f64],
    rsi_period: usize,
    wma_period: usize,
    first_valid: usize,
    out: &mut [f64],
) -> Result<(), IftRsiError> {
    // NOTE(review): this safe wrapper does not validate `out.len()` or the periods
    // before the unsafe call; confirm `ift_rsi_scalar_classic` tolerates that.
    unsafe { ift_rsi_scalar_classic(data, rsi_period, wma_period, first_valid, out)? };
    Ok(())
}
412
/// AVX-512 entry point for short periods.
///
/// Short and long variants currently share one implementation: [`ift_rsi_avx512`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn ift_rsi_avx512_short(
    data: &[f64],
    rsi_period: usize,
    wma_period: usize,
    first_valid: usize,
    out: &mut [f64],
) -> Result<(), IftRsiError> {
    ift_rsi_avx512(data, rsi_period, wma_period, first_valid, out)?;
    Ok(())
}
424
/// AVX-512 entry point for long periods.
///
/// Short and long variants currently share one implementation: [`ift_rsi_avx512`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn ift_rsi_avx512_long(
    data: &[f64],
    rsi_period: usize,
    wma_period: usize,
    first_valid: usize,
    out: &mut [f64],
) -> Result<(), IftRsiError> {
    ift_rsi_avx512(data, rsi_period, wma_period, first_valid, out)?;
    Ok(())
}
436
437#[inline]
438pub fn ift_rsi_batch_with_kernel(
439    data: &[f64],
440    sweep: &IftRsiBatchRange,
441    k: Kernel,
442) -> Result<IftRsiBatchOutput, IftRsiError> {
443    let kernel = match k {
444        Kernel::Auto => detect_best_batch_kernel(),
445        other if other.is_batch() => other,
446        other => return Err(IftRsiError::InvalidKernelForBatch(other)),
447    };
448    let simd = match kernel {
449        Kernel::Avx512Batch => Kernel::Avx512,
450        Kernel::Avx2Batch => Kernel::Avx2,
451        Kernel::ScalarBatch => Kernel::Scalar,
452        _ => unreachable!(),
453    };
454    ift_rsi_batch_par_slice(data, sweep, simd)
455}
456
/// Sweep definition for batch IFT-RSI. Each axis is `(start, end, step)`.
///
/// A `step` of 0, or `start == end`, pins the axis to `start`.
#[derive(Clone, Debug)]
pub struct IftRsiBatchRange {
    /// RSI period axis.
    pub rsi_period: (usize, usize, usize),
    /// WMA period axis.
    pub wma_period: (usize, usize, usize),
}
462
463impl Default for IftRsiBatchRange {
464    fn default() -> Self {
465        Self {
466            rsi_period: (5, 5, 0),
467            wma_period: (9, 258, 1),
468        }
469    }
470}
471
472#[derive(Clone, Debug, Default)]
473pub struct IftRsiBatchBuilder {
474    range: IftRsiBatchRange,
475    kernel: Kernel,
476}
477
478impl IftRsiBatchBuilder {
479    pub fn new() -> Self {
480        Self::default()
481    }
482    pub fn kernel(mut self, k: Kernel) -> Self {
483        self.kernel = k;
484        self
485    }
486    #[inline]
487    pub fn rsi_period_range(mut self, start: usize, end: usize, step: usize) -> Self {
488        self.range.rsi_period = (start, end, step);
489        self
490    }
491    #[inline]
492    pub fn rsi_period_static(mut self, p: usize) -> Self {
493        self.range.rsi_period = (p, p, 0);
494        self
495    }
496    #[inline]
497    pub fn wma_period_range(mut self, start: usize, end: usize, step: usize) -> Self {
498        self.range.wma_period = (start, end, step);
499        self
500    }
501    #[inline]
502    pub fn wma_period_static(mut self, n: usize) -> Self {
503        self.range.wma_period = (n, n, 0);
504        self
505    }
506    pub fn apply_slice(self, data: &[f64]) -> Result<IftRsiBatchOutput, IftRsiError> {
507        ift_rsi_batch_with_kernel(data, &self.range, self.kernel)
508    }
509    pub fn apply_candles(self, c: &Candles, src: &str) -> Result<IftRsiBatchOutput, IftRsiError> {
510        let slice = source_type(c, src);
511        self.apply_slice(slice)
512    }
513    pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<IftRsiBatchOutput, IftRsiError> {
514        IftRsiBatchBuilder::new().kernel(k).apply_slice(data)
515    }
516    pub fn with_default_candles(c: &Candles) -> Result<IftRsiBatchOutput, IftRsiError> {
517        IftRsiBatchBuilder::new()
518            .kernel(Kernel::Auto)
519            .apply_candles(c, "close")
520    }
521}
522
523#[derive(Clone, Debug)]
524pub struct IftRsiBatchOutput {
525    pub values: Vec<f64>,
526    pub combos: Vec<IftRsiParams>,
527    pub rows: usize,
528    pub cols: usize,
529}
530
531impl IftRsiBatchOutput {
532    pub fn row_for_params(&self, p: &IftRsiParams) -> Option<usize> {
533        self.combos.iter().position(|c| {
534            c.rsi_period.unwrap_or(5) == p.rsi_period.unwrap_or(5)
535                && c.wma_period.unwrap_or(9) == p.wma_period.unwrap_or(9)
536        })
537    }
538
539    pub fn values_for(&self, p: &IftRsiParams) -> Option<&[f64]> {
540        self.row_for_params(p).map(|row| {
541            let start = row * self.cols;
542            &self.values[start..start + self.cols]
543        })
544    }
545}
546
547#[inline(always)]
548fn expand_grid(r: &IftRsiBatchRange) -> Result<Vec<IftRsiParams>, IftRsiError> {
549    fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, IftRsiError> {
550        if step == 0 || start == end {
551            return Ok(vec![start]);
552        }
553        let (lo, hi) = if start <= end {
554            (start, end)
555        } else {
556            (end, start)
557        };
558        let vals: Vec<usize> = (lo..=hi).step_by(step).collect();
559        if vals.is_empty() {
560            return Err(IftRsiError::InvalidRange { start, end, step });
561        }
562        Ok(vals)
563    }
564    let rsi_periods = axis_usize(r.rsi_period)?;
565    let wma_periods = axis_usize(r.wma_period)?;
566    let cap =
567        rsi_periods
568            .len()
569            .checked_mul(wma_periods.len())
570            .ok_or(IftRsiError::InvalidRange {
571                start: r.rsi_period.0,
572                end: r.rsi_period.1,
573                step: r.rsi_period.2,
574            })?;
575    let mut out = Vec::with_capacity(cap);
576    for &rsi_p in &rsi_periods {
577        for &wma_p in &wma_periods {
578            out.push(IftRsiParams {
579                rsi_period: Some(rsi_p),
580                wma_period: Some(wma_p),
581            });
582        }
583    }
584    Ok(out)
585}
586
587#[inline(always)]
588pub fn ift_rsi_batch_slice(
589    data: &[f64],
590    sweep: &IftRsiBatchRange,
591    kern: Kernel,
592) -> Result<IftRsiBatchOutput, IftRsiError> {
593    ift_rsi_batch_inner(data, sweep, kern, false)
594}
595
596#[inline(always)]
597pub fn ift_rsi_batch_par_slice(
598    data: &[f64],
599    sweep: &IftRsiBatchRange,
600    kern: Kernel,
601) -> Result<IftRsiBatchOutput, IftRsiError> {
602    ift_rsi_batch_inner(data, sweep, kern, true)
603}
604
605#[inline(always)]
606fn ift_rsi_batch_inner(
607    data: &[f64],
608    sweep: &IftRsiBatchRange,
609    kern: Kernel,
610    parallel: bool,
611) -> Result<IftRsiBatchOutput, IftRsiError> {
612    let combos = expand_grid(sweep)?;
613    let first = data
614        .iter()
615        .position(|x| !x.is_nan())
616        .ok_or(IftRsiError::AllValuesNaN)?;
617    let max_rsi = combos.iter().map(|c| c.rsi_period.unwrap()).max().unwrap();
618    let max_wma = combos.iter().map(|c| c.wma_period.unwrap()).max().unwrap();
619    let max_p = max_rsi.max(max_wma);
620    if data.len() - first < max_p {
621        return Err(IftRsiError::NotEnoughValidData {
622            needed: max_p,
623            valid: data.len() - first,
624        });
625    }
626
627    let rows = combos.len();
628    let cols = data.len();
629    rows.checked_mul(cols).ok_or(IftRsiError::InvalidRange {
630        start: sweep.rsi_period.0,
631        end: sweep.rsi_period.1,
632        step: sweep.rsi_period.2,
633    })?;
634
635    let warmup_periods: Vec<usize> = combos
636        .iter()
637        .map(|c| first + c.rsi_period.unwrap() + c.wma_period.unwrap() - 1)
638        .collect();
639
640    let mut buf_mu = make_uninit_matrix(rows, cols);
641    init_matrix_prefixes(&mut buf_mu, cols, &warmup_periods);
642
643    let mut buf_guard = core::mem::ManuallyDrop::new(buf_mu);
644    let values: &mut [f64] = unsafe {
645        core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
646    };
647
648    let sliced = &data[first..];
649    let n = sliced.len();
650    let mut gains = Vec::with_capacity(n.saturating_sub(1));
651    let mut losses = Vec::with_capacity(n.saturating_sub(1));
652    for i in 1..n {
653        let d = sliced[i] - sliced[i - 1];
654        if d > 0.0 {
655            gains.push(d);
656            losses.push(0.0);
657        } else {
658            gains.push(0.0);
659            losses.push(-d);
660        }
661    }
662
663    let n1 = gains.len();
664    let mut pg = Vec::with_capacity(n1 + 1);
665    let mut pl = Vec::with_capacity(n1 + 1);
666    pg.push(0.0);
667    pl.push(0.0);
668    for i in 0..n1 {
669        pg.push(pg[i] + gains[i]);
670        pl.push(pl[i] + losses[i]);
671    }
672
673    let n1 = gains.len();
674    let mut pg = Vec::with_capacity(n1 + 1);
675    let mut pl = Vec::with_capacity(n1 + 1);
676    pg.push(0.0);
677    pl.push(0.0);
678    for i in 0..n1 {
679        pg.push(pg[i] + gains[i]);
680        pl.push(pl[i] + losses[i]);
681    }
682
683    let n1 = gains.len();
684    let mut pg = Vec::with_capacity(n1 + 1);
685    let mut pl = Vec::with_capacity(n1 + 1);
686    pg.push(0.0);
687    pl.push(0.0);
688    for i in 0..n1 {
689        pg.push(pg[i] + gains[i]);
690        pl.push(pl[i] + losses[i]);
691    }
692
693    let n1 = gains.len();
694    let mut pg = Vec::with_capacity(n1 + 1);
695    let mut pl = Vec::with_capacity(n1 + 1);
696    pg.push(0.0);
697    pl.push(0.0);
698    for i in 0..n1 {
699        pg.push(pg[i] + gains[i]);
700        pl.push(pl[i] + losses[i]);
701    }
702
703    let do_row = |row: usize, out_row: &mut [f64]| unsafe {
704        let rsi_p = combos[row].rsi_period.unwrap();
705        let wma_p = combos[row].wma_period.unwrap();
706        match kern {
707            Kernel::Scalar => ift_rsi_row_scalar_precomputed_ps(
708                &gains, &losses, &pg, &pl, rsi_p, wma_p, first, out_row,
709            ),
710            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
711            Kernel::Avx2 => ift_rsi_row_avx2(data, first, rsi_p, wma_p, out_row),
712            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
713            Kernel::Avx512 => ift_rsi_row_avx512(data, first, rsi_p, wma_p, out_row),
714            _ => unreachable!(),
715        }
716    };
717
718    if parallel {
719        #[cfg(not(target_arch = "wasm32"))]
720        {
721            values
722                .par_chunks_mut(cols)
723                .enumerate()
724                .for_each(|(row, slice)| do_row(row, slice));
725        }
726
727        #[cfg(target_arch = "wasm32")]
728        {
729            for (row, slice) in values.chunks_mut(cols).enumerate() {
730                do_row(row, slice);
731            }
732        }
733    } else {
734        for (row, slice) in values.chunks_mut(cols).enumerate() {
735            do_row(row, slice);
736        }
737    }
738
739    let values = unsafe {
740        Vec::from_raw_parts(
741            buf_guard.as_mut_ptr() as *mut f64,
742            buf_guard.len(),
743            buf_guard.capacity(),
744        )
745    };
746
747    Ok(IftRsiBatchOutput {
748        values,
749        combos,
750        rows,
751        cols,
752    })
753}
754
755#[inline(always)]
756fn ift_rsi_batch_inner_into(
757    data: &[f64],
758    sweep: &IftRsiBatchRange,
759    kern: Kernel,
760    parallel: bool,
761    out: &mut [f64],
762) -> Result<Vec<IftRsiParams>, IftRsiError> {
763    let combos = expand_grid(sweep)?;
764    let rows = combos.len();
765    let cols = data.len();
766    rows.checked_mul(cols).ok_or(IftRsiError::InvalidRange {
767        start: sweep.rsi_period.0,
768        end: sweep.rsi_period.1,
769        step: sweep.rsi_period.2,
770    })?;
771    let first = data
772        .iter()
773        .position(|x| !x.is_nan())
774        .ok_or(IftRsiError::AllValuesNaN)?;
775    let max_rsi = combos.iter().map(|c| c.rsi_period.unwrap()).max().unwrap();
776    let max_wma = combos.iter().map(|c| c.wma_period.unwrap()).max().unwrap();
777    let max_p = max_rsi.max(max_wma);
778    if data.len() - first < max_p {
779        return Err(IftRsiError::NotEnoughValidData {
780            needed: max_p,
781            valid: data.len() - first,
782        });
783    }
784
785    let rows = combos.len();
786    let cols = data.len();
787
788    for (row, combo) in combos.iter().enumerate() {
789        let warmup = (first + combo.rsi_period.unwrap() + combo.wma_period.unwrap() - 1).min(cols);
790        let row_start = row * cols;
791        for i in 0..warmup {
792            out[row_start + i] = f64::NAN;
793        }
794    }
795
796    let sliced = &data[first..];
797    let n = sliced.len();
798    let mut gains = Vec::with_capacity(n.saturating_sub(1));
799    let mut losses = Vec::with_capacity(n.saturating_sub(1));
800    for i in 1..n {
801        let d = sliced[i] - sliced[i - 1];
802        if d > 0.0 {
803            gains.push(d);
804            losses.push(0.0);
805        } else {
806            gains.push(0.0);
807            losses.push(-d);
808        }
809    }
810
811    let n1 = gains.len();
812    let mut pg = Vec::with_capacity(n1 + 1);
813    let mut pl = Vec::with_capacity(n1 + 1);
814    pg.push(0.0);
815    pl.push(0.0);
816    for i in 0..n1 {
817        pg.push(pg[i] + gains[i]);
818        pl.push(pl[i] + losses[i]);
819    }
820
821    let do_row = |row: usize, out_row: &mut [f64]| unsafe {
822        let rsi_p = combos[row].rsi_period.unwrap();
823        let wma_p = combos[row].wma_period.unwrap();
824        match kern {
825            Kernel::Scalar => ift_rsi_row_scalar_precomputed_ps(
826                &gains, &losses, &pg, &pl, rsi_p, wma_p, first, out_row,
827            ),
828            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
829            Kernel::Avx2 => ift_rsi_row_avx2(data, first, rsi_p, wma_p, out_row),
830            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
831            Kernel::Avx512 => ift_rsi_row_avx512(data, first, rsi_p, wma_p, out_row),
832            _ => unreachable!(),
833        }
834    };
835
836    if parallel {
837        #[cfg(not(target_arch = "wasm32"))]
838        {
839            out.par_chunks_mut(cols)
840                .enumerate()
841                .for_each(|(row, slice)| do_row(row, slice));
842        }
843
844        #[cfg(target_arch = "wasm32")]
845        {
846            for (row, slice) in out.chunks_mut(cols).enumerate() {
847                do_row(row, slice);
848            }
849        }
850    } else {
851        for (row, slice) in out.chunks_mut(cols).enumerate() {
852            do_row(row, slice);
853        }
854    }
855
856    Ok(combos)
857}
858
859#[inline(always)]
860unsafe fn ift_rsi_row_scalar(
861    data: &[f64],
862    first: usize,
863    rsi_period: usize,
864    wma_period: usize,
865    out: &mut [f64],
866) {
867    let sliced = &data[first..];
868    let n = sliced.len();
869    if n == 0 {
870        return;
871    }
872    let mut gains = Vec::with_capacity(n.saturating_sub(1));
873    let mut losses = Vec::with_capacity(n.saturating_sub(1));
874    for i in 1..n {
875        let d = sliced[i] - sliced[i - 1];
876        if d > 0.0 {
877            gains.push(d);
878            losses.push(0.0);
879        } else {
880            gains.push(0.0);
881            losses.push(-d);
882        }
883    }
884    ift_rsi_row_scalar_precomputed(&gains, &losses, rsi_period, wma_period, first, out);
885}
886
/// Single-row IFT-RSI kernel over precomputed per-step gains and losses.
///
/// 1. Seeds the average gain and loss with the simple mean of the first `rsi_period`
///    steps.
/// 2. Updates them with Wilder smoothing (`avg = avg * (1 - 1/p) + x / p`).
/// 3. Rescales each RSI to `0.1 * (rsi - 50)` and feeds it into an O(1)-per-step
///    sliding WMA of length `wma_period`.
/// 4. Once the window is full, writes `tanh(wma)` to `out[first + i]`, where `i`
///    counts gain/loss steps. Outputs therefore start at
///    `first + rsi_period + wma_period - 1`.
///
/// Nothing is written when either period is zero or there are too few steps to fill
/// the window.
///
/// # Safety
/// `out` must have at least `first + gains.len() + 1` elements, and `losses` at least
/// `gains.len()`. Indexing is unchecked.
#[inline(always)]
unsafe fn ift_rsi_row_scalar_precomputed(
    gains: &[f64],
    losses: &[f64],
    rsi_period: usize,
    wma_period: usize,
    first: usize,
    out: &mut [f64],
) {
    let steps = gains.len();
    // `||` short-circuits, so `rsi_period + wma_period - 1` never underflows.
    if rsi_period == 0 || wma_period == 0 || rsi_period + wma_period - 1 > steps {
        return;
    }

    // Seed: simple mean of the first `rsi_period` gains and losses.
    let mut up = 0.0f64;
    let mut down = 0.0f64;
    for k in 0..rsi_period {
        up += *gains.get_unchecked(k);
        down += *losses.get_unchecked(k);
    }
    let period = rsi_period as f64;
    up /= period;
    down /= period;
    let alpha = 1.0f64 / period;
    let keep = 1.0f64 - alpha;

    // Sliding WMA state: a ring of the last `window` inputs, their plain sum, and the
    // weighted numerator (newest weight `window`, oldest weight 1).
    let window = wma_period;
    let window_f = window as f64;
    let weight_total = 0.5f64 * window_f * (window_f + 1.0);
    let inv_weight_total = 1.0f64 / weight_total;
    let mut ring: Vec<f64> = vec![0.0; window];
    let mut slot = 0usize;
    let mut seen = 0usize;
    let mut plain = 0.0f64;
    let mut weighted = 0.0f64;

    for i in rsi_period..=steps {
        if i != rsi_period {
            let g = *gains.get_unchecked(i - 1);
            let l = *losses.get_unchecked(i - 1);
            up = f64::mul_add(up, keep, alpha * g);
            down = f64::mul_add(down, keep, alpha * l);
        }

        // NOTE(review): a zero average loss maps to `rs = 100`, i.e. RSI ~ 99.01 rather
        // than 100. Presumably chosen to match other code paths; confirm before changing.
        let rs = if down == 0.0 { 100.0 } else { up / down };
        let rsi = 100.0 - 100.0 / (1.0 + rs);
        let x = 0.1f64 * (rsi - 50.0);

        let ready = if seen < window {
            // Filling: the k-th input (1-based) gets weight k.
            plain += x;
            weighted = f64::mul_add((seen as f64) + 1.0, x, weighted);
            *ring.get_unchecked_mut(slot) = x;
            seen += 1;
            seen == window
        } else {
            // Sliding: every weight drops by one, the oldest falls out, the newest
            // enters with weight `window`.
            let evicted = *ring.get_unchecked(slot);
            *ring.get_unchecked_mut(slot) = x;
            let prev_plain = plain;
            weighted = f64::mul_add(window_f, x, weighted) - prev_plain;
            plain = prev_plain + x - evicted;
            true
        };

        slot += 1;
        if slot == window {
            slot = 0;
        }

        if ready {
            *out.get_unchecked_mut(first + i) = (weighted * inv_weight_total).tanh();
        }
    }
}
976
/// Single-row IFT-RSI kernel that seeds the RSI averages from prefix sums.
///
/// Identical to `ift_rsi_row_scalar_precomputed` except for the seed. The initial
/// average gain/loss is `(pg[rsi_period] - pg[0]) / rsi_period` (resp. `pl`), where
/// `pg`/`pl` are prefix sums of `gains`/`losses`.
///
/// The rest of the pipeline is the same:
/// 1. Wilder smoothing of the averages.
/// 2. Rescaling of each RSI to `0.1 * (rsi - 50)`.
/// 3. An O(1)-per-step sliding WMA of length `wma_period`.
/// 4. `tanh(wma)` written to `out[first + i]` once the window is full.
///
/// Nothing is written when either period is zero or there are too few steps to fill
/// the window.
///
/// # Safety
/// `out` must have at least `first + gains.len() + 1` elements, `losses` at least
/// `gains.len()`, and `pg`/`pl` at least `rsi_period + 1`. Indexing is unchecked.
#[inline(always)]
unsafe fn ift_rsi_row_scalar_precomputed_ps(
    gains: &[f64],
    losses: &[f64],
    pg: &[f64],
    pl: &[f64],
    rsi_period: usize,
    wma_period: usize,
    first: usize,
    out: &mut [f64],
) {
    let steps = gains.len();
    // `||` short-circuits, so `rsi_period + wma_period - 1` never underflows.
    if rsi_period == 0 || wma_period == 0 || rsi_period + wma_period - 1 > steps {
        return;
    }

    // Seed: window sums read from the prefix tables.
    let period = rsi_period as f64;
    let mut up = (*pg.get_unchecked(rsi_period) - *pg.get_unchecked(0)) / period;
    let mut down = (*pl.get_unchecked(rsi_period) - *pl.get_unchecked(0)) / period;
    let alpha = 1.0f64 / period;
    let keep = 1.0f64 - alpha;

    // Sliding WMA state: a ring of the last `window` inputs, their plain sum, and the
    // weighted numerator (newest weight `window`, oldest weight 1).
    let window = wma_period;
    let window_f = window as f64;
    let weight_total = 0.5f64 * window_f * (window_f + 1.0);
    let inv_weight_total = 1.0f64 / weight_total;
    let mut ring: Vec<f64> = vec![0.0; window];
    let mut slot = 0usize;
    let mut seen = 0usize;
    let mut plain = 0.0f64;
    let mut weighted = 0.0f64;

    for i in rsi_period..=steps {
        if i != rsi_period {
            let g = *gains.get_unchecked(i - 1);
            let l = *losses.get_unchecked(i - 1);
            up = f64::mul_add(up, keep, alpha * g);
            down = f64::mul_add(down, keep, alpha * l);
        }

        // NOTE(review): a zero average loss maps to `rs = 100`, i.e. RSI ~ 99.01 rather
        // than 100. Presumably chosen to match other code paths; confirm before changing.
        let rs = if down == 0.0 { 100.0 } else { up / down };
        let rsi = 100.0 - 100.0 / (1.0 + rs);
        let x = 0.1f64 * (rsi - 50.0);

        let ready = if seen < window {
            // Filling: the k-th input (1-based) gets weight k.
            plain += x;
            weighted = f64::mul_add((seen as f64) + 1.0, x, weighted);
            *ring.get_unchecked_mut(slot) = x;
            seen += 1;
            seen == window
        } else {
            // Sliding: every weight drops by one, the oldest falls out, the newest
            // enters with weight `window`.
            let evicted = *ring.get_unchecked(slot);
            *ring.get_unchecked_mut(slot) = x;
            let prev_plain = plain;
            weighted = f64::mul_add(window_f, x, weighted) - prev_plain;
            plain = prev_plain + x - evicted;
            true
        };

        slot += 1;
        if slot == window {
            slot = 0;
        }

        if ready {
            *out.get_unchecked_mut(first + i) = (weighted * inv_weight_total).tanh();
        }
    }
}
1064
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
/// AVX2 entry point for a single IFT-RSI row.
///
/// There is no AVX2-specific code here yet: the work is forwarded unchanged to
/// `ift_rsi_row_scalar`, so results are identical to the scalar kernel.
unsafe fn ift_rsi_row_avx2(
    data: &[f64],
    first: usize,
    rsi_period: usize,
    wma_period: usize,
    out: &mut [f64],
) {
    ift_rsi_row_scalar(data, first, rsi_period, wma_period, out);
}
1076
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
/// AVX-512 entry point for a single IFT-RSI row.
///
/// Routing depends on the longer of the two periods:
/// - 32 or less goes to `ift_rsi_row_avx512_short`;
/// - anything longer goes to `ift_rsi_row_avx512_long`.
///
/// Both variants currently forward to the scalar kernel.
unsafe fn ift_rsi_row_avx512(
    data: &[f64],
    first: usize,
    rsi_period: usize,
    wma_period: usize,
    out: &mut [f64],
) {
    let longest = if rsi_period > wma_period {
        rsi_period
    } else {
        wma_period
    };
    match longest {
        0..=32 => ift_rsi_row_avx512_short(data, first, rsi_period, wma_period, out),
        _ => ift_rsi_row_avx512_long(data, first, rsi_period, wma_period, out),
    }
}
1092
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
/// AVX-512 path for short periods (longest period of 32 or less).
///
/// This is currently a pass-through to `ift_rsi_row_scalar`; it contains no
/// vector code.
unsafe fn ift_rsi_row_avx512_short(
    data: &[f64],
    first: usize,
    rsi_period: usize,
    wma_period: usize,
    out: &mut [f64],
) {
    ift_rsi_row_scalar(data, first, rsi_period, wma_period, out)
}
1104
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
/// AVX-512 path for long periods (longest period above 32).
///
/// This is currently a pass-through to `ift_rsi_row_scalar`; it contains no
/// vector code.
unsafe fn ift_rsi_row_avx512_long(
    data: &[f64],
    first: usize,
    rsi_period: usize,
    wma_period: usize,
    out: &mut [f64],
) {
    ift_rsi_row_scalar(data, first, rsi_period, wma_period, out)
}
1116
/// Incremental (one sample at a time) IFT-RSI calculator.
///
/// Each new finite price is processed in four stages:
/// 1. RSI: a smoothed RSI, seeded with the simple average of the first
///    `rsi_period` price changes and then updated with `alpha = 1 / rsi_period`.
/// 2. Rescale: the RSI is mapped to `0.1 * (rsi - 50)`.
/// 3. WMA: a linearly weighted moving average of length `wma_period` is applied.
/// 4. Squash: the result is passed through `tanh`.
#[derive(Debug, Clone)]
pub struct IftRsiStream {
    /// Number of price changes used to seed and smooth the RSI stage.
    rsi_period: usize,
    /// Length of the weighted moving average applied to the rescaled RSI.
    wma_period: usize,

    /// Most recent finite input, used to form the next price change.
    prev: f64,
    /// Whether `prev` holds a real sample (false after construction or a reset).
    have_prev: bool,
    /// Running sum of gains while the RSI is still being seeded.
    seed_g: f64,
    /// Running sum of losses while the RSI is still being seeded.
    seed_l: f64,
    /// Number of price changes accumulated into the seed sums so far.
    seed_cnt: usize,
    /// Smoothed average gain (meaningful once `seeded` is true).
    avg_gain: f64,
    /// Smoothed average loss (meaningful once `seeded` is true).
    avg_loss: f64,
    /// True once `rsi_period` changes have been seen and the averages initialised.
    seeded: bool,
    /// Smoothing factor `1 / rsi_period`.
    alpha: f64,
    /// Complementary factor `1 - 1 / rsi_period`.
    beta: f64,

    /// Ring buffer of the last `wma_period` rescaled RSI values.
    buf: Vec<f64>,
    /// Next slot of `buf` to write (holds the oldest value once the buffer is full).
    head: usize,
    /// Number of values pushed into `buf`, capped at `wma_period`.
    filled: usize,
    /// Plain sum of the values currently in the window.
    sum: f64,
    /// Weighted sum of the window (oldest weight 1, newest weight `wma_period`).
    num: f64,
    /// `wma_period` as `f64`.
    wp_f: f64,
    /// Reciprocal of the weight total `wma_period * (wma_period + 1) / 2`.
    denom_rcp: f64,
}
1141
1142impl IftRsiStream {
1143    pub fn try_new(params: IftRsiParams) -> Result<Self, IftRsiError> {
1144        let rsi_period = params.rsi_period.unwrap_or(5);
1145        let wma_period = params.wma_period.unwrap_or(9);
1146        if rsi_period == 0 || wma_period == 0 {
1147            return Err(IftRsiError::InvalidPeriod {
1148                rsi_period,
1149                wma_period,
1150                data_len: 0,
1151            });
1152        }
1153        let wp_f = wma_period as f64;
1154        let denom = 0.5 * wp_f * (wp_f + 1.0);
1155        Ok(Self {
1156            rsi_period,
1157            wma_period,
1158
1159            prev: 0.0,
1160            have_prev: false,
1161            seed_g: 0.0,
1162            seed_l: 0.0,
1163            seed_cnt: 0,
1164            avg_gain: 0.0,
1165            avg_loss: 0.0,
1166            seeded: false,
1167            alpha: 1.0 / (rsi_period as f64),
1168            beta: 1.0 - 1.0 / (rsi_period as f64),
1169
1170            buf: vec![0.0; wma_period],
1171            head: 0,
1172            filled: 0,
1173            sum: 0.0,
1174            num: 0.0,
1175            wp_f,
1176            denom_rcp: 1.0 / denom,
1177        })
1178    }
1179
1180    #[inline]
1181    pub fn update(&mut self, value: f64) -> Option<f64> {
1182        if !value.is_finite() {
1183            self.reset_soft();
1184            return None;
1185        }
1186
1187        if !self.have_prev {
1188            self.prev = value;
1189            self.have_prev = true;
1190            return None;
1191        }
1192
1193        let d = value - self.prev;
1194        self.prev = value;
1195        let gain = if d > 0.0 { d } else { 0.0 };
1196        let loss = if d < 0.0 { -d } else { 0.0 };
1197
1198        if !self.seeded {
1199            self.seed_g += gain;
1200            self.seed_l += loss;
1201            self.seed_cnt += 1;
1202            if self.seed_cnt < self.rsi_period {
1203                return None;
1204            }
1205
1206            self.avg_gain = self.seed_g / (self.rsi_period as f64);
1207            self.avg_loss = self.seed_l / (self.rsi_period as f64);
1208            self.seeded = true;
1209        } else {
1210            self.avg_gain = f64::mul_add(self.avg_gain, self.beta, self.alpha * gain);
1211            self.avg_loss = f64::mul_add(self.avg_loss, self.beta, self.alpha * loss);
1212        }
1213
1214        let rs = if self.avg_loss != 0.0 {
1215            self.avg_gain / self.avg_loss
1216        } else {
1217            100.0
1218        };
1219        let rsi = 100.0 - 100.0 / (1.0 + rs);
1220        let x = 0.1 * (rsi - 50.0);
1221
1222        if self.filled < self.wma_period {
1223            self.sum += x;
1224            self.num = f64::mul_add((self.filled as f64) + 1.0, x, self.num);
1225            self.buf[self.head] = x;
1226            self.head += 1;
1227            if self.head == self.wma_period {
1228                self.head = 0;
1229            }
1230            self.filled += 1;
1231
1232            if self.filled == self.wma_period {
1233                let wma = self.num * self.denom_rcp;
1234                return Some(tanh_kernel(wma));
1235            }
1236            return None;
1237        } else {
1238            let x_old = self.buf[self.head];
1239            self.buf[self.head] = x;
1240            self.head += 1;
1241            if self.head == self.wma_period {
1242                self.head = 0;
1243            }
1244
1245            let sum_prev = self.sum;
1246            self.num = f64::mul_add(self.wp_f, x, self.num) - sum_prev;
1247            self.sum = sum_prev + x - x_old;
1248
1249            let wma = self.num * self.denom_rcp;
1250            return Some(tanh_kernel(wma));
1251        }
1252    }
1253
1254    #[inline]
1255    fn reset_soft(&mut self) {
1256        self.have_prev = false;
1257        self.seed_g = 0.0;
1258        self.seed_l = 0.0;
1259        self.seed_cnt = 0;
1260        self.avg_gain = 0.0;
1261        self.avg_loss = 0.0;
1262        self.seeded = false;
1263
1264        self.head = 0;
1265        self.filled = 0;
1266        self.sum = 0.0;
1267        self.num = 0.0;
1268        for v in &mut self.buf {
1269            *v = 0.0;
1270        }
1271    }
1272}
1273
/// Final "inverse Fisher transform" step: maps any real input into (-1, 1) via
/// the hyperbolic tangent. NaN passes through as NaN.
#[inline(always)]
fn tanh_kernel(x: f64) -> f64 {
    f64::tanh(x)
}
1278
1279#[inline]
1280pub unsafe fn ift_rsi_scalar_classic(
1281    data: &[f64],
1282    rsi_period: usize,
1283    wma_period: usize,
1284    first_valid: usize,
1285    out: &mut [f64],
1286) -> Result<(), IftRsiError> {
1287    debug_assert!(rsi_period > 0 && wma_period > 0);
1288    let len = data.len();
1289    if first_valid >= len {
1290        return Ok(());
1291    }
1292    let sliced = data.get_unchecked(first_valid..);
1293    let n = sliced.len();
1294    if n == 0 {
1295        return Ok(());
1296    }
1297
1298    if rsi_period + wma_period - 1 >= n {
1299        return Ok(());
1300    }
1301
1302    let rp = rsi_period;
1303    let rp_f = rp as f64;
1304    let alpha = 1.0f64 / rp_f;
1305    let beta = 1.0f64 - alpha;
1306
1307    let mut avg_gain = 0.0f64;
1308    let mut avg_loss = 0.0f64;
1309    {
1310        let mut i = 1usize;
1311        while i <= rp {
1312            let d = *sliced.get_unchecked(i) - *sliced.get_unchecked(i - 1);
1313            if d > 0.0 {
1314                avg_gain += d;
1315            } else {
1316                avg_loss -= d;
1317            }
1318            i += 1;
1319        }
1320        avg_gain /= rp_f;
1321        avg_loss /= rp_f;
1322    }
1323
1324    let wp = wma_period;
1325    let wp_f = wp as f64;
1326    let denom = 0.5f64 * wp_f * (wp_f + 1.0);
1327    let denom_rcp = 1.0f64 / denom;
1328
1329    let mut buf: Vec<f64> = vec![0.0; wp];
1330    let mut head: usize = 0;
1331    let mut filled: usize = 0;
1332
1333    let mut sum = 0.0f64;
1334    let mut num = 0.0f64;
1335
1336    let mut i = rp;
1337    while i < n {
1338        if i > rp {
1339            let d = *sliced.get_unchecked(i) - *sliced.get_unchecked(i - 1);
1340            let gain = if d > 0.0 { d } else { 0.0 };
1341            let loss = if d < 0.0 { -d } else { 0.0 };
1342            avg_gain = f64::mul_add(avg_gain, beta, alpha * gain);
1343            avg_loss = f64::mul_add(avg_loss, beta, alpha * loss);
1344        }
1345
1346        let rs = if avg_loss != 0.0 {
1347            avg_gain / avg_loss
1348        } else {
1349            100.0
1350        };
1351        let rsi = 100.0 - 100.0 / (1.0 + rs);
1352        let x = 0.1f64 * (rsi - 50.0);
1353
1354        if filled < wp {
1355            sum += x;
1356            num = f64::mul_add((filled as f64) + 1.0, x, num);
1357            *buf.get_unchecked_mut(head) = x;
1358            head += 1;
1359            if head == wp {
1360                head = 0;
1361            }
1362            filled += 1;
1363
1364            if filled == wp {
1365                let wma = num * denom_rcp;
1366                *out.get_unchecked_mut(first_valid + i) = wma.tanh();
1367            }
1368        } else {
1369            let x_old = *buf.get_unchecked(head);
1370            *buf.get_unchecked_mut(head) = x;
1371            head += 1;
1372            if head == wp {
1373                head = 0;
1374            }
1375
1376            let sum_t = sum;
1377            num = f64::mul_add(wp_f, x, num) - sum_t;
1378            sum = sum_t + x - x_old;
1379
1380            let wma = num * denom_rcp;
1381            *out.get_unchecked_mut(first_valid + i) = wma.tanh();
1382        }
1383
1384        i += 1;
1385    }
1386
1387    Ok(())
1388}
1389
1390#[cfg(test)]
1391mod tests {
1392    use super::*;
1393    use crate::skip_if_unsupported;
1394    use crate::utilities::data_loader::read_candles_from_csv;
1395
1396    fn check_ift_rsi_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1397        skip_if_unsupported!(kernel, test_name);
1398        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1399        let candles = read_candles_from_csv(file_path)?;
1400        let default_params = IftRsiParams {
1401            rsi_period: None,
1402            wma_period: None,
1403        };
1404        let input = IftRsiInput::from_candles(&candles, "close", default_params);
1405        let output = ift_rsi_with_kernel(&input, kernel)?;
1406        assert_eq!(output.values.len(), candles.close.len());
1407        Ok(())
1408    }
1409
1410    #[test]
1411    fn test_ift_rsi_into_matches_api() -> Result<(), Box<dyn Error>> {
1412        let n = 256usize;
1413        let mut data = Vec::with_capacity(n);
1414        for i in 0..n {
1415            if i < 3 {
1416                data.push(f64::NAN);
1417            } else {
1418                let x = (i as f64).sin() * 5.0 + 100.0 + ((i % 7) as f64);
1419                data.push(x);
1420            }
1421        }
1422
1423        let input = IftRsiInput::from_slice(&data, IftRsiParams::default());
1424
1425        let baseline = ift_rsi(&input)?.values;
1426
1427        let mut out = vec![0.0; data.len()];
1428        ift_rsi_into(&input, &mut out)?;
1429
1430        assert_eq!(baseline.len(), out.len());
1431        for i in 0..out.len() {
1432            let a = baseline[i];
1433            let b = out[i];
1434            let equal = (a.is_nan() && b.is_nan()) || (a == b);
1435            assert!(equal, "Mismatch at {}: baseline={}, into={}", i, a, b);
1436        }
1437        Ok(())
1438    }
1439
1440    fn check_ift_rsi_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1441        skip_if_unsupported!(kernel, test_name);
1442        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1443        let candles = read_candles_from_csv(file_path)?;
1444        let input = IftRsiInput::from_candles(&candles, "close", IftRsiParams::default());
1445        let result = ift_rsi_with_kernel(&input, kernel)?;
1446
1447        let expected_last_five = [
1448            -0.35919800205778424,
1449            -0.3275464113984847,
1450            -0.39970276998138216,
1451            -0.36321812798797737,
1452            -0.5843346528346959,
1453        ];
1454        let start = result.values.len().saturating_sub(5);
1455        for (i, &val) in result.values[start..].iter().enumerate() {
1456            let diff = (val - expected_last_five[i]).abs();
1457            assert!(
1458                diff < 1e-8,
1459                "[{}] IFT_RSI {:?} mismatch at idx {}: got {}, expected {}",
1460                test_name,
1461                kernel,
1462                i,
1463                val,
1464                expected_last_five[i]
1465            );
1466        }
1467        Ok(())
1468    }
1469
1470    fn check_ift_rsi_default_candles(
1471        test_name: &str,
1472        kernel: Kernel,
1473    ) -> Result<(), Box<dyn Error>> {
1474        skip_if_unsupported!(kernel, test_name);
1475        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1476        let candles = read_candles_from_csv(file_path)?;
1477        let input = IftRsiInput::with_default_candles(&candles);
1478        let output = ift_rsi_with_kernel(&input, kernel)?;
1479        assert_eq!(output.values.len(), candles.close.len());
1480        Ok(())
1481    }
1482
1483    fn check_ift_rsi_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1484        skip_if_unsupported!(kernel, test_name);
1485        let input_data = [10.0, 20.0, 30.0];
1486        let params = IftRsiParams {
1487            rsi_period: Some(0),
1488            wma_period: Some(9),
1489        };
1490        let input = IftRsiInput::from_slice(&input_data, params);
1491        let res = ift_rsi_with_kernel(&input, kernel);
1492        assert!(
1493            res.is_err(),
1494            "[{}] IFT_RSI should fail with zero period",
1495            test_name
1496        );
1497        Ok(())
1498    }
1499
1500    fn check_ift_rsi_period_exceeds_length(
1501        test_name: &str,
1502        kernel: Kernel,
1503    ) -> Result<(), Box<dyn Error>> {
1504        skip_if_unsupported!(kernel, test_name);
1505        let data_small = [10.0, 20.0, 30.0];
1506        let params = IftRsiParams {
1507            rsi_period: Some(10),
1508            wma_period: Some(9),
1509        };
1510        let input = IftRsiInput::from_slice(&data_small, params);
1511        let res = ift_rsi_with_kernel(&input, kernel);
1512        assert!(
1513            res.is_err(),
1514            "[{}] IFT_RSI should fail with period exceeding length",
1515            test_name
1516        );
1517        Ok(())
1518    }
1519
1520    fn check_ift_rsi_very_small_dataset(
1521        test_name: &str,
1522        kernel: Kernel,
1523    ) -> Result<(), Box<dyn Error>> {
1524        skip_if_unsupported!(kernel, test_name);
1525        let single_point = [42.0];
1526        let params = IftRsiParams {
1527            rsi_period: Some(5),
1528            wma_period: Some(9),
1529        };
1530        let input = IftRsiInput::from_slice(&single_point, params);
1531        let res = ift_rsi_with_kernel(&input, kernel);
1532        assert!(
1533            res.is_err(),
1534            "[{}] IFT_RSI should fail with insufficient data",
1535            test_name
1536        );
1537        Ok(())
1538    }
1539
1540    fn check_ift_rsi_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1541        skip_if_unsupported!(kernel, test_name);
1542        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1543        let candles = read_candles_from_csv(file_path)?;
1544        let first_params = IftRsiParams {
1545            rsi_period: Some(5),
1546            wma_period: Some(9),
1547        };
1548        let first_input = IftRsiInput::from_candles(&candles, "close", first_params);
1549        let first_result = ift_rsi_with_kernel(&first_input, kernel)?;
1550        let second_params = IftRsiParams {
1551            rsi_period: Some(5),
1552            wma_period: Some(9),
1553        };
1554        let second_input = IftRsiInput::from_slice(&first_result.values, second_params);
1555        let second_result = ift_rsi_with_kernel(&second_input, kernel)?;
1556        assert_eq!(second_result.values.len(), first_result.values.len());
1557        Ok(())
1558    }
1559
1560    fn check_ift_rsi_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1561        skip_if_unsupported!(kernel, test_name);
1562        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1563        let candles = read_candles_from_csv(file_path)?;
1564        let input = IftRsiInput::from_candles(
1565            &candles,
1566            "close",
1567            IftRsiParams {
1568                rsi_period: Some(5),
1569                wma_period: Some(9),
1570            },
1571        );
1572        let res = ift_rsi_with_kernel(&input, kernel)?;
1573        assert_eq!(res.values.len(), candles.close.len());
1574        if res.values.len() > 240 {
1575            for (i, &val) in res.values[240..].iter().enumerate() {
1576                assert!(
1577                    !val.is_nan(),
1578                    "[{}] Found unexpected NaN at out-index {}",
1579                    test_name,
1580                    240 + i
1581                );
1582            }
1583        }
1584        Ok(())
1585    }
1586
1587    #[cfg(debug_assertions)]
1588    fn check_ift_rsi_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1589        skip_if_unsupported!(kernel, test_name);
1590
1591        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1592        let candles = read_candles_from_csv(file_path)?;
1593
1594        let test_params = vec![
1595            IftRsiParams::default(),
1596            IftRsiParams {
1597                rsi_period: Some(2),
1598                wma_period: Some(2),
1599            },
1600            IftRsiParams {
1601                rsi_period: Some(3),
1602                wma_period: Some(5),
1603            },
1604            IftRsiParams {
1605                rsi_period: Some(7),
1606                wma_period: Some(14),
1607            },
1608            IftRsiParams {
1609                rsi_period: Some(14),
1610                wma_period: Some(21),
1611            },
1612            IftRsiParams {
1613                rsi_period: Some(21),
1614                wma_period: Some(9),
1615            },
1616            IftRsiParams {
1617                rsi_period: Some(50),
1618                wma_period: Some(50),
1619            },
1620            IftRsiParams {
1621                rsi_period: Some(100),
1622                wma_period: Some(100),
1623            },
1624            IftRsiParams {
1625                rsi_period: Some(2),
1626                wma_period: Some(50),
1627            },
1628            IftRsiParams {
1629                rsi_period: Some(50),
1630                wma_period: Some(2),
1631            },
1632            IftRsiParams {
1633                rsi_period: Some(9),
1634                wma_period: Some(21),
1635            },
1636            IftRsiParams {
1637                rsi_period: Some(25),
1638                wma_period: Some(10),
1639            },
1640        ];
1641
1642        for (param_idx, params) in test_params.iter().enumerate() {
1643            let input = IftRsiInput::from_candles(&candles, "close", params.clone());
1644            let output = ift_rsi_with_kernel(&input, kernel)?;
1645
1646            for (i, &val) in output.values.iter().enumerate() {
1647                if val.is_nan() {
1648                    continue;
1649                }
1650
1651                let bits = val.to_bits();
1652
1653                if bits == 0x11111111_11111111 {
1654                    panic!(
1655                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1656						 with params: rsi_period={}, wma_period={} (param set {})",
1657                        test_name,
1658                        val,
1659                        bits,
1660                        i,
1661                        params.rsi_period.unwrap_or(5),
1662                        params.wma_period.unwrap_or(9),
1663                        param_idx
1664                    );
1665                }
1666
1667                if bits == 0x22222222_22222222 {
1668                    panic!(
1669                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1670						 with params: rsi_period={}, wma_period={} (param set {})",
1671                        test_name,
1672                        val,
1673                        bits,
1674                        i,
1675                        params.rsi_period.unwrap_or(5),
1676                        params.wma_period.unwrap_or(9),
1677                        param_idx
1678                    );
1679                }
1680
1681                if bits == 0x33333333_33333333 {
1682                    panic!(
1683                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1684						 with params: rsi_period={}, wma_period={} (param set {})",
1685                        test_name,
1686                        val,
1687                        bits,
1688                        i,
1689                        params.rsi_period.unwrap_or(5),
1690                        params.wma_period.unwrap_or(9),
1691                        param_idx
1692                    );
1693                }
1694            }
1695        }
1696
1697        Ok(())
1698    }
1699
1700    #[cfg(not(debug_assertions))]
1701    fn check_ift_rsi_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1702        Ok(())
1703    }
1704
1705    #[cfg(feature = "proptest")]
1706    fn check_ift_rsi_property(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1707        use proptest::prelude::*;
1708        skip_if_unsupported!(kernel, test_name);
1709
1710        let strat = (2usize..=50, 2usize..=50)
1711            .prop_flat_map(|(rsi_period, wma_period)| {
1712                let min_len = (rsi_period + wma_period) * 2;
1713                (
1714                    (100.0f64..5000.0f64, 0.01f64..0.1f64),
1715                    -0.02f64..0.02f64,
1716                    Just(rsi_period),
1717                    Just(wma_period),
1718                    min_len..400,
1719                )
1720            })
1721            .prop_map(
1722                |((base_price, volatility), trend, rsi_period, wma_period, len)| {
1723                    let mut prices = Vec::with_capacity(len);
1724                    let mut current_price = base_price;
1725
1726                    for i in 0..len {
1727                        current_price *= 1.0 + trend;
1728
1729                        let noise = 1.0 + (i as f64 * 0.1).sin() * volatility;
1730                        prices.push(current_price * noise);
1731                    }
1732
1733                    (prices, rsi_period, wma_period)
1734                },
1735            );
1736
1737        proptest::test_runner::TestRunner::default().run(
1738            &strat,
1739            |(data, rsi_period, wma_period)| {
1740                let params = IftRsiParams {
1741                    rsi_period: Some(rsi_period),
1742                    wma_period: Some(wma_period),
1743                };
1744                let input = IftRsiInput::from_slice(&data, params);
1745
1746                let IftRsiOutput { values: out } = ift_rsi_with_kernel(&input, kernel)?;
1747
1748                let IftRsiOutput { values: ref_out } = ift_rsi_with_kernel(&input, Kernel::Scalar)?;
1749
1750                prop_assert_eq!(out.len(), data.len(), "Output length mismatch");
1751
1752                let warmup_period = rsi_period + wma_period - 1;
1753                for i in 0..warmup_period.min(data.len()) {
1754                    prop_assert!(
1755                        out[i].is_nan(),
1756                        "Expected NaN during warmup at index {}, got {}",
1757                        i,
1758                        out[i]
1759                    );
1760                }
1761
1762                for i in warmup_period..data.len() {
1763                    let y = out[i];
1764                    let r = ref_out[i];
1765
1766                    if y.is_finite() {
1767                        prop_assert!(
1768                            y >= -1.0 - 1e-9 && y <= 1.0 + 1e-9,
1769                            "IFT RSI value {} at index {} outside [-1, 1] bounds",
1770                            y,
1771                            i
1772                        );
1773                    }
1774
1775                    if !y.is_finite() || !r.is_finite() {
1776                        prop_assert_eq!(
1777                            y.to_bits(),
1778                            r.to_bits(),
1779                            "NaN/Inf mismatch at index {}: {} vs {}",
1780                            i,
1781                            y,
1782                            r
1783                        );
1784                    } else {
1785                        let ulp_diff = y.to_bits().abs_diff(r.to_bits());
1786                        prop_assert!(
1787                            (y - r).abs() <= 1e-9 || ulp_diff <= 4,
1788                            "Kernel mismatch at index {}: {} vs {} (ULP={})",
1789                            i,
1790                            y,
1791                            r,
1792                            ulp_diff
1793                        );
1794                    }
1795
1796                    if i >= warmup_period + 10 {
1797                        let lookback = 10;
1798                        let recent_prices = &data[i - lookback..=i];
1799                        let price_change =
1800                            (recent_prices[lookback] - recent_prices[0]) / recent_prices[0];
1801
1802                        if price_change > 0.05 && y.is_finite() {
1803                            prop_assert!(
1804								y > 0.2,
1805								"Strong uptrend should produce positive IFT RSI > 0.2, got {} at index {}",
1806								y,
1807								i
1808							);
1809                        }
1810
1811                        if price_change < -0.05 && y.is_finite() {
1812                            prop_assert!(
1813								y < -0.2,
1814								"Strong downtrend should produce negative IFT RSI < -0.2, got {} at index {}",
1815								y,
1816								i
1817							);
1818                        }
1819                    }
1820
1821                    if !data[..=i].iter().any(|x| x.is_nan()) {
1822                        prop_assert!(!y.is_nan(), "Unexpected NaN at index {} after warmup", i);
1823                    }
1824                }
1825
1826                if data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10)
1827                    && data.len() > warmup_period
1828                {
1829                    for i in warmup_period..out.len() {
1830                        if out[i].is_finite() {
1831                            prop_assert!(
1832                                (out[i] - (-1.0)).abs() < 1e-6,
1833                                "Constant prices should yield IFT RSI = -1, got {} at index {}",
1834                                out[i],
1835                                i
1836                            );
1837                        }
1838                    }
1839                }
1840
1841                let volatility = if data.len() > 2 {
1842                    let returns: Vec<f64> = data.windows(2).map(|w| (w[1] - w[0]) / w[0]).collect();
1843                    let mean_return = returns.iter().sum::<f64>() / returns.len() as f64;
1844                    let variance = returns
1845                        .iter()
1846                        .map(|r| (r - mean_return).powi(2))
1847                        .sum::<f64>()
1848                        / returns.len() as f64;
1849                    variance.sqrt()
1850                } else {
1851                    0.0
1852                };
1853
1854                if volatility > 0.1 {
1855                    for &val in out.iter() {
1856                        if val.is_finite() {
1857                            prop_assert!(
1858                                val >= -1.0 && val <= 1.0,
1859                                "Even with extreme volatility, IFT RSI must be bounded: {}",
1860                                val
1861                            );
1862                        }
1863                    }
1864                }
1865
1866                if data.len() > warmup_period + 20 {
1867                    for check_idx in (warmup_period + 10..data.len()).step_by(20) {
1868                        if check_idx + 5 >= data.len() {
1869                            break;
1870                        }
1871
1872                        let recent_window = &data[check_idx - 5..=check_idx];
1873                        let gains: f64 = recent_window
1874                            .windows(2)
1875                            .map(|w| (w[1] - w[0]).max(0.0))
1876                            .sum();
1877                        let losses: f64 = recent_window
1878                            .windows(2)
1879                            .map(|w| (w[0] - w[1]).max(0.0))
1880                            .sum();
1881
1882                        if gains > losses * 1.5 && out[check_idx].is_finite() {
1883                            prop_assert!(
1884								out[check_idx] > -0.1,
1885								"Bullish momentum (gains > losses*1.5) should yield IFT RSI > -0.1, got {} at index {}",
1886								out[check_idx],
1887								check_idx
1888							);
1889                        }
1890
1891                        if losses > gains * 1.5 && out[check_idx].is_finite() {
1892                            prop_assert!(
1893								out[check_idx] < 0.1,
1894								"Bearish momentum (losses > gains*1.5) should yield IFT RSI < 0.1, got {} at index {}",
1895								out[check_idx],
1896								check_idx
1897							);
1898                        }
1899                    }
1900                }
1901
1902                Ok(())
1903            },
1904        )?;
1905
1906        Ok(())
1907    }
1908
1909    #[cfg(not(feature = "proptest"))]
1910    fn check_ift_rsi_property(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1911        skip_if_unsupported!(kernel, test_name);
1912        Ok(())
1913    }
1914
1915    macro_rules! generate_all_ift_rsi_tests {
1916        ($($test_fn:ident),*) => {
1917            paste::paste! {
1918                $(
1919                    #[test]
1920                    fn [<$test_fn _scalar_f64>]() {
1921                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1922                    }
1923                )*
1924                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1925                $(
1926                    #[test]
1927                    fn [<$test_fn _avx2_f64>]() {
1928                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1929                    }
1930                    #[test]
1931                    fn [<$test_fn _avx512_f64>]() {
1932                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1933                    }
1934                )*
1935            }
1936        }
1937    }
1938
1939    generate_all_ift_rsi_tests!(
1940        check_ift_rsi_partial_params,
1941        check_ift_rsi_accuracy,
1942        check_ift_rsi_default_candles,
1943        check_ift_rsi_zero_period,
1944        check_ift_rsi_period_exceeds_length,
1945        check_ift_rsi_very_small_dataset,
1946        check_ift_rsi_reinput,
1947        check_ift_rsi_nan_handling,
1948        check_ift_rsi_no_poison,
1949        check_ift_rsi_property
1950    );
1951
1952    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1953        skip_if_unsupported!(kernel, test);
1954        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1955        let c = read_candles_from_csv(file)?;
1956        let output = IftRsiBatchBuilder::new()
1957            .kernel(kernel)
1958            .apply_candles(&c, "close")?;
1959        let def = IftRsiParams::default();
1960        let row = output.values_for(&def).expect("default row missing");
1961        assert_eq!(row.len(), c.close.len());
1962        Ok(())
1963    }
1964
1965    #[cfg(debug_assertions)]
1966    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1967        skip_if_unsupported!(kernel, test);
1968
1969        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1970        let c = read_candles_from_csv(file)?;
1971
1972        let test_configs = vec![
1973            (2, 10, 2, 2, 10, 2),
1974            (5, 25, 5, 5, 25, 5),
1975            (30, 60, 15, 30, 60, 15),
1976            (2, 5, 1, 2, 5, 1),
1977            (9, 15, 3, 9, 15, 3),
1978            (2, 2, 0, 2, 20, 2),
1979            (2, 20, 2, 9, 9, 0),
1980        ];
1981
1982        for (cfg_idx, &(rsi_start, rsi_end, rsi_step, wma_start, wma_end, wma_step)) in
1983            test_configs.iter().enumerate()
1984        {
1985            let output = IftRsiBatchBuilder::new()
1986                .kernel(kernel)
1987                .rsi_period_range(rsi_start, rsi_end, rsi_step)
1988                .wma_period_range(wma_start, wma_end, wma_step)
1989                .apply_candles(&c, "close")?;
1990
1991            for (idx, &val) in output.values.iter().enumerate() {
1992                if val.is_nan() {
1993                    continue;
1994                }
1995
1996                let bits = val.to_bits();
1997                let row = idx / output.cols;
1998                let col = idx % output.cols;
1999                let combo = &output.combos[row];
2000
2001                if bits == 0x11111111_11111111 {
2002                    panic!(
2003                        "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
2004						 at row {} col {} (flat index {}) with params: rsi_period={}, wma_period={}",
2005                        test,
2006                        cfg_idx,
2007                        val,
2008                        bits,
2009                        row,
2010                        col,
2011                        idx,
2012                        combo.rsi_period.unwrap_or(5),
2013                        combo.wma_period.unwrap_or(9)
2014                    );
2015                }
2016
2017                if bits == 0x22222222_22222222 {
2018                    panic!(
2019                        "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
2020						 at row {} col {} (flat index {}) with params: rsi_period={}, wma_period={}",
2021                        test,
2022                        cfg_idx,
2023                        val,
2024                        bits,
2025                        row,
2026                        col,
2027                        idx,
2028                        combo.rsi_period.unwrap_or(5),
2029                        combo.wma_period.unwrap_or(9)
2030                    );
2031                }
2032
2033                if bits == 0x33333333_33333333 {
2034                    panic!(
2035                        "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
2036						 at row {} col {} (flat index {}) with params: rsi_period={}, wma_period={}",
2037                        test,
2038                        cfg_idx,
2039                        val,
2040                        bits,
2041                        row,
2042                        col,
2043                        idx,
2044                        combo.rsi_period.unwrap_or(5),
2045                        combo.wma_period.unwrap_or(9)
2046                    );
2047                }
2048            }
2049        }
2050
2051        Ok(())
2052    }
2053
    // Release-build stub: the poison scan is compiled only under `debug_assertions`
    // (presumably because the allocation helpers write sentinel patterns only in debug
    // builds — confirm in `utilities::helpers`). Always succeeds.
    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
2058
2059    macro_rules! gen_batch_tests {
2060        ($fn_name:ident) => {
2061            paste::paste! {
2062                #[test] fn [<$fn_name _scalar>]()      {
2063                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
2064                }
2065                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2066                #[test] fn [<$fn_name _avx2>]()        {
2067                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
2068                }
2069                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2070                #[test] fn [<$fn_name _avx512>]()      {
2071                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
2072                }
2073                #[test] fn [<$fn_name _auto_detect>]() {
2074                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
2075                }
2076            }
2077        };
2078    }
2079    gen_batch_tests!(check_batch_default_row);
2080    gen_batch_tests!(check_batch_no_poison);
2081}
2082
#[cfg(feature = "python")]
#[pyfunction(name = "ift_rsi")]
#[pyo3(signature = (data, rsi_period, wma_period, kernel=None))]
/// Python binding for single-series IFT RSI.
///
/// `kernel` is an optional kernel name checked by `validate_kernel` in non-batch mode.
/// The indicator runs with the GIL released, and indicator errors are raised as
/// `ValueError`.
pub fn ift_rsi_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    rsi_period: usize,
    wma_period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let prices = data.as_slice()?;
    let chosen_kernel = validate_kernel(kernel, false)?;

    let input = IftRsiInput::from_slice(
        prices,
        IftRsiParams {
            rsi_period: Some(rsi_period),
            wma_period: Some(wma_period),
        },
    );

    let output = py
        .allow_threads(|| ift_rsi_with_kernel(&input, chosen_kernel))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok(output.values.into_pyarray(py))
}
2110
#[cfg(feature = "python")]
/// Python-facing wrapper, exposed to Python as `IftRsiStream`, around the incremental
/// Rust `IftRsiStream`.
#[pyclass(name = "IftRsiStream")]
pub struct IftRsiStreamPy {
    // The wrapped Rust streaming state; each `update` call forwards to it.
    stream: IftRsiStream,
}
2116
#[cfg(feature = "python")]
#[pymethods]
impl IftRsiStreamPy {
    /// Builds a stream for the given periods.
    ///
    /// Raises `ValueError` if `IftRsiStream::try_new` rejects the parameters.
    #[new]
    fn new(rsi_period: usize, wma_period: usize) -> PyResult<Self> {
        IftRsiStream::try_new(IftRsiParams {
            rsi_period: Some(rsi_period),
            wma_period: Some(wma_period),
        })
        .map(|stream| IftRsiStreamPy { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Pushes one value and returns whatever `IftRsiStream::update` yields.
    ///
    /// The result is `None` when no output is available (presumably during warmup).
    fn update(&mut self, value: f64) -> Option<f64> {
        let inner = &mut self.stream;
        inner.update(value)
    }
}
2135
#[cfg(feature = "python")]
#[pyfunction(name = "ift_rsi_batch")]
#[pyo3(signature = (data, rsi_period_range, wma_period_range, kernel=None))]
/// Python binding for the IFT RSI parameter sweep.
///
/// `rsi_period_range` and `wma_period_range` are `(start, end, step)` tuples expanded by
/// `expand_grid`. The returned dict contains:
/// - `"values"`: a `(rows, cols)` float64 array, one row per parameter combination;
/// - `"rsi_periods"` / `"wma_periods"`: uint64 arrays with each row's parameters;
/// - `"rows"` / `"cols"`: the array dimensions (`cols == len(data)`).
///
/// Raises `ValueError` in any of these cases:
/// - the grid is invalid;
/// - `rows * cols` overflows;
/// - the kernel does not resolve to a batch kernel;
/// - the computation fails.
pub fn ift_rsi_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    rsi_period_range: (usize, usize, usize),
    wma_period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, true)?;

    let sweep = IftRsiBatchRange {
        rsi_period: rsi_period_range,
        wma_period: wma_period_range,
    };

    let combos = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let rows = combos.len();
    let cols = slice_in.len();
    // An unchecked product could wrap and size the output array too small.
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("ift_rsi_batch: rows * cols overflows usize"))?;

    // Resolve the batch kernel to its per-row kernel before releasing the GIL. An unexpected
    // kernel then surfaces as a ValueError instead of a panic inside `allow_threads`.
    let batch_kernel = match kern {
        Kernel::Auto => detect_best_batch_kernel(),
        k => k,
    };
    let simd = match batch_kernel {
        Kernel::Avx512Batch => Kernel::Avx512,
        Kernel::Avx2Batch => Kernel::Avx2,
        Kernel::ScalarBatch => Kernel::Scalar,
        _ => {
            return Err(PyValueError::new_err(
                "ift_rsi_batch: kernel did not resolve to a batch kernel",
            ))
        }
    };

    // SAFETY: `PyArray1::new` leaves the buffer uninitialized. The array is only handed to
    // Python after `ift_rsi_batch_inner_into` returns Ok; on error it is dropped unseen.
    // NOTE(review): this relies on `ift_rsi_batch_inner_into` writing every element
    // (including warmup prefixes) — confirm against its implementation.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| ift_rsi_batch_inner_into(slice_in, &sweep, simd, true, slice_out))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "rsi_periods",
        combos
            .iter()
            .map(|p| p.rsi_period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "wma_periods",
        combos
            .iter()
            .map(|p| p.wma_period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item("rows", rows)?;
    dict.set_item("cols", cols)?;

    Ok(dict)
}
2203
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "ift_rsi_cuda_batch_dev")]
#[pyo3(signature = (data_f32, rsi_range, wma_range, device_id=0))]
/// Runs the IFT RSI parameter sweep on a CUDA device and returns the device-resident result.
///
/// `rsi_range` / `wma_range` are `(start, end, step)` tuples. The returned handle keeps the
/// CUDA context alive and records the device id. Raises `ValueError` when CUDA is
/// unavailable or any CUDA step fails. NOTE(review): the output layout (presumably one row
/// per parameter combination) is defined by `CudaIftRsi::ift_rsi_batch_dev` — confirm there.
pub fn ift_rsi_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: numpy::PyReadonlyArray1<'_, f32>,
    rsi_range: (usize, usize, usize),
    wma_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    use crate::cuda::cuda_available;
    use crate::cuda::oscillators::CudaIftRsi;

    /// Converts any CUDA-side error into a Python `ValueError` carrying its message.
    fn to_py_err<E: ToString>(err: E) -> PyErr {
        PyValueError::new_err(err.to_string())
    }

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let prices: &[f32] = data_f32.as_slice()?;
    let sweep = IftRsiBatchRange {
        rsi_period: rsi_range,
        wma_period: wma_range,
    };

    let (inner, dev_id, ctx) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaIftRsi::new(device_id).map_err(to_py_err)?;
        let dev_id = cuda.device_id();
        let ctx = cuda.context_arc();
        let (dev, _combos) = cuda.ift_rsi_batch_dev(prices, &sweep).map_err(to_py_err)?;
        cuda.synchronize().map_err(to_py_err)?;
        Ok((dev, dev_id, ctx))
    })?;

    Ok(DeviceArrayF32Py {
        inner,
        _ctx: Some(ctx),
        device_id: Some(dev_id),
    })
}
2242
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "ift_rsi_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, rsi_period, wma_period, device_id=0))]
/// Runs IFT RSI with one parameter set over many series on a CUDA device.
///
/// `data_tm_f32` is a contiguous 2-D array with shape `(rows, cols)`. It is forwarded as
/// `(cols, rows)` to the time-major CUDA entry point. Presumably rows are time steps and
/// cols are series; confirm with `CudaIftRsi`. Raises `ValueError` when CUDA is
/// unavailable or any CUDA step fails.
pub fn ift_rsi_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    rsi_period: usize,
    wma_period: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    use crate::cuda::cuda_available;
    use crate::cuda::oscillators::CudaIftRsi;
    use numpy::PyUntypedArrayMethods;

    /// Converts any CUDA-side error into a Python `ValueError` carrying its message.
    fn to_py_err<E: ToString>(err: E) -> PyErr {
        PyValueError::new_err(err.to_string())
    }

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let flat_in: &[f32] = data_tm_f32.as_slice()?;
    let (rows, cols) = {
        let shape = data_tm_f32.shape();
        (shape[0], shape[1])
    };
    let params = IftRsiParams {
        rsi_period: Some(rsi_period),
        wma_period: Some(wma_period),
    };

    let (inner, dev_id, ctx) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaIftRsi::new(device_id).map_err(to_py_err)?;
        let dev_id = cuda.device_id();
        let ctx = cuda.context_arc();
        let dev = cuda
            .ift_rsi_many_series_one_param_time_major_dev(flat_in, cols, rows, &params)
            .map_err(to_py_err)?;
        cuda.synchronize().map_err(to_py_err)?;
        Ok((dev, dev_id, ctx))
    })?;

    Ok(DeviceArrayF32Py {
        inner,
        _ctx: Some(ctx),
        device_id: Some(dev_id),
    })
}
2284
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// WASM binding: computes IFT RSI over `data` with the scalar kernel and returns a new
/// vector of the same length. Indicator errors are returned as a JS string value.
pub fn ift_rsi_js(data: &[f64], rsi_period: usize, wma_period: usize) -> Result<Vec<f64>, JsValue> {
    let input = IftRsiInput::from_slice(
        data,
        IftRsiParams {
            rsi_period: Some(rsi_period),
            wma_period: Some(wma_period),
        },
    );

    let mut values = vec![0.0; data.len()];
    match ift_rsi_into_slice(&mut values, &input, Kernel::Scalar) {
        Ok(_) => Ok(values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
2303
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// Raw-pointer WASM entry point: computes IFT RSI (scalar kernel) over the `len` values at
/// `in_ptr` and writes `len` results to `out_ptr`.
///
/// The input and output buffers may be the same buffer or may partially overlap. In that
/// case the result is computed into a scratch buffer first, so the input is never read
/// after it has been overwritten.
///
/// # Errors
/// Returns a JS string value if either pointer is null or the computation fails.
pub fn ift_rsi_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    rsi_period: usize,
    wma_period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to ift_rsi_into"));
    }

    // Half-open ranges [in, in+len) and [out, out+len) overlap iff each starts before the
    // other ends. `wrapping_add` keeps this address arithmetic free of UB.
    let out_start = out_ptr as *const f64;
    let overlaps =
        in_ptr < out_start.wrapping_add(len) && out_start < in_ptr.wrapping_add(len);

    // SAFETY: the caller must pass an `in_ptr` valid for reads of `len` initialized f64s.
    let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };
    let params = IftRsiParams {
        rsi_period: Some(rsi_period),
        wma_period: Some(wma_period),
    };
    let input = IftRsiInput::from_slice(data, params);
    let kernel = Kernel::Scalar;

    if overlaps {
        let mut temp = vec![0.0; len];
        ift_rsi_into_slice(&mut temp, &input, kernel)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // SAFETY: the caller must pass an `out_ptr` valid for writes of `len` f64s. `input`
        // (and with it `data`) is not used past this point, so the overlapping shared slice
        // is never read after this mutable slice exists.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        out.copy_from_slice(&temp);
    } else {
        // SAFETY: valid for writes of `len` f64s (caller contract) and disjoint from `data`.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        ift_rsi_into_slice(out, &input, kernel)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }
    Ok(())
}
2341
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// Allocates an uninitialized buffer with room for `len` f64 values and returns its pointer.
///
/// The buffer is intentionally leaked, and ownership passes to the caller. Release it with
/// `ift_rsi_free(ptr, len)` using the same `len`.
pub fn ift_rsi_alloc(len: usize) -> *mut f64 {
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
2350
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// Releases a buffer previously returned by `ift_rsi_alloc(len)`; null is ignored.
///
/// `len` must equal the value passed to `ift_rsi_alloc`, because it is used as the
/// allocation's capacity.
pub fn ift_rsi_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` came from a `Vec<f64>` with capacity `len` (see `ift_rsi_alloc`).
    // The length is 0 because the contents may never have been initialized; `Vec::from_raw_parts`
    // requires the first `length` elements to be initialized. f64 has no destructor, so only
    // the allocation itself is released.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
2360
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
/// Sweep configuration deserialized from the JS object passed to `ift_rsi_batch`.
#[derive(Serialize, Deserialize)]
pub struct IftRsiBatchConfig {
    // RSI period sweep as (start, end, step).
    pub rsi_period_range: (usize, usize, usize),
    // WMA period sweep as (start, end, step).
    pub wma_period_range: (usize, usize, usize),
}
2367
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
/// Result object serialized back to JS by `ift_rsi_batch`.
#[derive(Serialize, Deserialize)]
pub struct IftRsiBatchJsOutput {
    // Flattened `rows * cols` values in row-major order, one row per entry in `combos`.
    pub values: Vec<f64>,
    // Parameter set used for each row.
    pub combos: Vec<IftRsiParams>,
    // Number of parameter combinations.
    pub rows: usize,
    // Length of the input series.
    pub cols: usize,
}
2376
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = ift_rsi_batch)]
/// WASM binding for the parameter sweep: takes `data` plus a config object with
/// `rsi_period_range` / `wma_period_range` tuples, and returns
/// `{ values, combos, rows, cols }`.
///
/// Config, computation, and serialization failures are returned as JS string values.
pub fn ift_rsi_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let cfg: IftRsiBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;
    let sweep = IftRsiBatchRange {
        rsi_period: cfg.rsi_period_range,
        wma_period: cfg.wma_period_range,
    };

    // This item only exists on wasm32, so the best single-series kernel is used directly.
    let batch = ift_rsi_batch_inner(data, &sweep, detect_best_kernel(), false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let js_output = IftRsiBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };
    serde_wasm_bindgen::to_value(&js_output)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2406
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// Raw-pointer WASM entry point for the parameter sweep.
///
/// Reads `len` values from `in_ptr` and writes a row-major `rows * len` matrix (one row per
/// parameter combination) to `out_ptr`. Returns `rows`. The caller must size the output
/// buffer for `rows * len` f64s. Overlapping input/output buffers are handled by computing
/// into a scratch buffer first.
///
/// # Errors
/// Returns a JS string value in any of these cases:
/// - a pointer is null;
/// - the grid is invalid;
/// - `rows * len` overflows `usize` (plausible on 32-bit wasm);
/// - the computation fails.
pub fn ift_rsi_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    rsi_start: usize,
    rsi_end: usize,
    rsi_step: usize,
    wma_start: usize,
    wma_end: usize,
    wma_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str(
            "null pointer passed to ift_rsi_batch_into",
        ));
    }

    let sweep = IftRsiBatchRange {
        rsi_period: (rsi_start, rsi_end, rsi_step),
        wma_period: (wma_start, wma_end, wma_step),
    };

    let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let cols = len;
    // A wrapped product would build an output slice of the wrong length.
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("ift_rsi_batch_into: rows * cols overflows usize"))?;

    // This item only exists on wasm32, so the best single-series kernel is used directly.
    let kernel = detect_best_kernel();

    // Half-open ranges [in, in+len) and [out, out+total) overlap iff each starts before the
    // other ends. `wrapping_add` keeps this address arithmetic free of UB.
    let out_start = out_ptr as *const f64;
    let overlaps =
        in_ptr < out_start.wrapping_add(total) && out_start < in_ptr.wrapping_add(len);

    // SAFETY: the caller must pass an `in_ptr` valid for reads of `len` initialized f64s.
    let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };

    if overlaps {
        let mut temp = vec![0.0; total];
        ift_rsi_batch_inner_into(data, &sweep, kernel, false, &mut temp)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // SAFETY: the caller must pass an `out_ptr` valid for writes of `total` f64s. `data`
        // is not read after this point, so no shared slice is read while this one exists.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };
        out.copy_from_slice(&temp);
    } else {
        // SAFETY: valid for writes of `total` f64s (caller contract) and disjoint from `data`.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };
        ift_rsi_batch_inner_into(data, &sweep, kernel, false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(rows)
}