1#[cfg(feature = "python")]
2use crate::utilities::kernel_validation::validate_kernel;
3#[cfg(feature = "python")]
4use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
5#[cfg(feature = "python")]
6use pyo3::exceptions::PyValueError;
7#[cfg(feature = "python")]
8use pyo3::prelude::*;
9#[cfg(feature = "python")]
10use pyo3::types::PyDict;
11
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use serde::{Deserialize, Serialize};
14#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
15use wasm_bindgen::prelude::*;
16
17use crate::indicators::rsi::{rsi, RsiError, RsiInput, RsiParams};
18use crate::indicators::wma::{wma, WmaError, WmaInput, WmaParams};
19use crate::utilities::data_loader::{source_type, Candles};
20#[cfg(all(feature = "python", feature = "cuda"))]
21use crate::utilities::dlpack_cuda::DeviceArrayF32Py;
22use crate::utilities::enums::Kernel;
23use crate::utilities::helpers::{
24 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
25 make_uninit_matrix,
26};
27use aligned_vec::{AVec, CACHELINE_ALIGN};
28#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
29use core::arch::x86_64::*;
30#[cfg(not(target_arch = "wasm32"))]
31use rayon::prelude::*;
32use std::convert::AsRef;
33use std::error::Error;
34use thiserror::Error;
35
36impl<'a> AsRef<[f64]> for IftRsiInput<'a> {
37 #[inline(always)]
38 fn as_ref(&self) -> &[f64] {
39 match &self.data {
40 IftRsiData::Slice(slice) => slice,
41 IftRsiData::Candles { candles, source } => source_type(candles, source),
42 }
43 }
44}
45
/// Where the input price series for the IFT-RSI computation comes from.
#[derive(Debug, Clone)]
pub enum IftRsiData<'a> {
    /// A candle set plus the name of the source series to extract from it
    /// (resolved with `source_type` wherever the data is read).
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A caller-provided series used as-is.
    Slice(&'a [f64]),
}
54
/// Result of a single-series IFT-RSI computation.
#[derive(Debug, Clone)]
pub struct IftRsiOutput {
    /// One value per input element; the warm-up prefix is NaN.
    pub values: Vec<f64>,
}
59
/// Tunable periods for the IFT-RSI indicator.
///
/// A `None` field is replaced by the indicator default wherever the parameters are
/// read (5 for the RSI period, 9 for the WMA period).
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct IftRsiParams {
    /// Look-back of the RSI stage.
    pub rsi_period: Option<usize>,
    /// Look-back of the weighted moving average applied to the rescaled RSI.
    pub wma_period: Option<usize>,
}

impl Default for IftRsiParams {
    /// Explicit defaults: RSI period 5, WMA period 9.
    fn default() -> Self {
        let (rsi_period, wma_period) = (Some(5), Some(9));
        Self {
            rsi_period,
            wma_period,
        }
    }
}
78
/// Bundles the data source and the parameters for one IFT-RSI computation.
#[derive(Debug, Clone)]
pub struct IftRsiInput<'a> {
    /// Series to process (candles + source name, or a raw slice).
    pub data: IftRsiData<'a>,
    /// RSI / WMA periods; unset fields fall back to the defaults in the getters.
    pub params: IftRsiParams,
}
84
85impl<'a> IftRsiInput<'a> {
86 #[inline]
87 pub fn from_candles(c: &'a Candles, s: &'a str, p: IftRsiParams) -> Self {
88 Self {
89 data: IftRsiData::Candles {
90 candles: c,
91 source: s,
92 },
93 params: p,
94 }
95 }
96 #[inline]
97 pub fn from_slice(sl: &'a [f64], p: IftRsiParams) -> Self {
98 Self {
99 data: IftRsiData::Slice(sl),
100 params: p,
101 }
102 }
103 #[inline]
104 pub fn with_default_candles(c: &'a Candles) -> Self {
105 Self::from_candles(c, "close", IftRsiParams::default())
106 }
107 #[inline]
108 pub fn get_rsi_period(&self) -> usize {
109 self.params.rsi_period.unwrap_or(5)
110 }
111 #[inline]
112 pub fn get_wma_period(&self) -> usize {
113 self.params.wma_period.unwrap_or(9)
114 }
115}
116
/// Fluent builder for single-series IFT-RSI runs and for streaming instances.
#[derive(Copy, Clone, Debug)]
pub struct IftRsiBuilder {
    /// RSI period override; `None` leaves the default in effect.
    rsi_period: Option<usize>,
    /// WMA period override; `None` leaves the default in effect.
    wma_period: Option<usize>,
    /// Kernel forwarded to `ift_rsi_with_kernel` by the `apply*` methods.
    kernel: Kernel,
}
123
124impl Default for IftRsiBuilder {
125 fn default() -> Self {
126 Self {
127 rsi_period: None,
128 wma_period: None,
129 kernel: Kernel::Auto,
130 }
131 }
132}
133
134impl IftRsiBuilder {
135 #[inline(always)]
136 pub fn new() -> Self {
137 Self::default()
138 }
139 #[inline(always)]
140 pub fn rsi_period(mut self, n: usize) -> Self {
141 self.rsi_period = Some(n);
142 self
143 }
144 #[inline(always)]
145 pub fn wma_period(mut self, n: usize) -> Self {
146 self.wma_period = Some(n);
147 self
148 }
149 #[inline(always)]
150 pub fn kernel(mut self, k: Kernel) -> Self {
151 self.kernel = k;
152 self
153 }
154
155 #[inline(always)]
156 pub fn apply(self, c: &Candles) -> Result<IftRsiOutput, IftRsiError> {
157 let p = IftRsiParams {
158 rsi_period: self.rsi_period,
159 wma_period: self.wma_period,
160 };
161 let i = IftRsiInput::from_candles(c, "close", p);
162 ift_rsi_with_kernel(&i, self.kernel)
163 }
164
165 #[inline(always)]
166 pub fn apply_slice(self, d: &[f64]) -> Result<IftRsiOutput, IftRsiError> {
167 let p = IftRsiParams {
168 rsi_period: self.rsi_period,
169 wma_period: self.wma_period,
170 };
171 let i = IftRsiInput::from_slice(d, p);
172 ift_rsi_with_kernel(&i, self.kernel)
173 }
174
175 #[inline(always)]
176 pub fn into_stream(self) -> Result<IftRsiStream, IftRsiError> {
177 let p = IftRsiParams {
178 rsi_period: self.rsi_period,
179 wma_period: self.wma_period,
180 };
181 IftRsiStream::try_new(p)
182 }
183}
184
/// Errors reported by the IFT-RSI single-series, batch and streaming entry points.
#[derive(Debug, Error)]
pub enum IftRsiError {
    /// The input series has no elements.
    #[error("ift_rsi: Input data slice is empty.")]
    EmptyData,
    /// No non-NaN element was found in the input series.
    #[error("ift_rsi: All values are NaN.")]
    AllValuesNaN,
    /// A period is zero or larger than the input length.
    #[error("ift_rsi: Invalid RSI period {rsi_period} or WMA period {wma_period}, data length = {data_len}.")]
    InvalidPeriod {
        rsi_period: usize,
        wma_period: usize,
        data_len: usize,
    },
    /// Fewer elements remain after the first non-NaN value than the largest period requires.
    #[error("ift_rsi: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// The underlying `rsi` call failed; carries its error text.
    #[error("ift_rsi: RSI calculation error: {0}")]
    RsiCalculationError(String),
    /// The underlying `wma` call failed; carries its error text.
    #[error("ift_rsi: WMA calculation error: {0}")]
    WmaCalculationError(String),
    /// A caller-provided output buffer does not match the input length.
    #[error("ift_rsi: Output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// Returned by `ift_rsi_with_kernel` when it is handed a batch kernel.
    #[error("ift_rsi: Wrong kernel for batch operation. Use a batch kernel variant.")]
    WrongKernelForBatch,
    /// Returned by the batch entry point when the kernel is not a batch kernel.
    #[error("ift_rsi: Invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(crate::utilities::enums::Kernel),
    /// A sweep range could not be expanded, or the resulting grid size overflowed.
    #[error("ift_rsi: Invalid range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },
}
216
217#[inline]
218pub fn ift_rsi(input: &IftRsiInput) -> Result<IftRsiOutput, IftRsiError> {
219 ift_rsi_with_kernel(input, Kernel::Auto)
220}
221
222pub fn ift_rsi_with_kernel(
223 input: &IftRsiInput,
224 kernel: Kernel,
225) -> Result<IftRsiOutput, IftRsiError> {
226 let data: &[f64] = match &input.data {
227 IftRsiData::Candles { candles, source } => source_type(candles, source),
228 IftRsiData::Slice(sl) => sl,
229 };
230
231 if data.is_empty() {
232 return Err(IftRsiError::EmptyData);
233 }
234 let first = data
235 .iter()
236 .position(|x| !x.is_nan())
237 .ok_or(IftRsiError::AllValuesNaN)?;
238 let len = data.len();
239 let rsi_period = input.get_rsi_period();
240 let wma_period = input.get_wma_period();
241 if rsi_period == 0 || wma_period == 0 || rsi_period > len || wma_period > len {
242 return Err(IftRsiError::InvalidPeriod {
243 rsi_period,
244 wma_period,
245 data_len: len,
246 });
247 }
248 let needed = rsi_period.max(wma_period);
249 if (len - first) < needed {
250 return Err(IftRsiError::NotEnoughValidData {
251 needed,
252 valid: len - first,
253 });
254 }
255
256 if kernel.is_batch() {
257 return Err(IftRsiError::WrongKernelForBatch);
258 }
259
260 let warmup_period = first + rsi_period + wma_period - 1;
261 let mut out = alloc_with_nan_prefix(len, warmup_period);
262
263 unsafe {
264 ift_rsi_scalar_classic(data, rsi_period, wma_period, first, &mut out)?;
265 }
266
267 Ok(IftRsiOutput { values: out })
268}
269
270#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
271#[inline]
272pub fn ift_rsi_into(input: &IftRsiInput, out: &mut [f64]) -> Result<(), IftRsiError> {
273 let data: &[f64] = match &input.data {
274 IftRsiData::Candles { candles, source } => source_type(candles, source),
275 IftRsiData::Slice(sl) => sl,
276 };
277
278 if out.len() != data.len() {
279 return Err(IftRsiError::OutputLengthMismatch {
280 expected: data.len(),
281 got: out.len(),
282 });
283 }
284
285 let kern = Kernel::Auto;
286 ift_rsi_into_slice(out, input, kern)
287}
288
289#[inline(always)]
290fn ift_rsi_compute_into(
291 data: &[f64],
292 rsi_period: usize,
293 wma_period: usize,
294 first_valid: usize,
295 out: &mut [f64],
296) -> Result<(), IftRsiError> {
297 let sliced = &data[first_valid..];
298 let mut rsi_values = rsi(&RsiInput::from_slice(
299 sliced,
300 RsiParams {
301 period: Some(rsi_period),
302 },
303 ))
304 .map_err(|e| IftRsiError::RsiCalculationError(e.to_string()))?
305 .values;
306
307 for val in rsi_values.iter_mut() {
308 if !val.is_nan() {
309 *val = 0.1 * (*val - 50.0);
310 }
311 }
312
313 let wma_values = wma(&WmaInput::from_slice(
314 &rsi_values,
315 WmaParams {
316 period: Some(wma_period),
317 },
318 ))
319 .map_err(|e| IftRsiError::WmaCalculationError(e.to_string()))?
320 .values;
321
322 for (i, &w) in wma_values.iter().enumerate() {
323 if !w.is_nan() {
324 out[first_valid + i] = w.tanh();
325 }
326 }
327 Ok(())
328}
329
/// Public scalar entry point; forwards unchanged to `ift_rsi_compute_into`,
/// which chains the library `rsi` and `wma` indicators.
///
/// `out` positions for which no value is produced are left untouched.
#[inline]
pub fn ift_rsi_scalar(
    data: &[f64],
    rsi_period: usize,
    wma_period: usize,
    first_valid: usize,
    out: &mut [f64],
) -> Result<(), IftRsiError> {
    ift_rsi_compute_into(data, rsi_period, wma_period, first_valid, out)
}
340
341pub fn ift_rsi_into_slice(
342 dst: &mut [f64],
343 input: &IftRsiInput,
344 kern: Kernel,
345) -> Result<(), IftRsiError> {
346 let data: &[f64] = match &input.data {
347 IftRsiData::Candles { candles, source } => source_type(candles, source),
348 IftRsiData::Slice(sl) => sl,
349 };
350
351 if data.is_empty() {
352 return Err(IftRsiError::EmptyData);
353 }
354
355 if dst.len() != data.len() {
356 return Err(IftRsiError::OutputLengthMismatch {
357 expected: data.len(),
358 got: dst.len(),
359 });
360 }
361
362 let first = data
363 .iter()
364 .position(|x| !x.is_nan())
365 .ok_or(IftRsiError::AllValuesNaN)?;
366 let rsi_period = input.get_rsi_period();
367 let wma_period = input.get_wma_period();
368
369 if rsi_period == 0 || wma_period == 0 || rsi_period > data.len() || wma_period > data.len() {
370 return Err(IftRsiError::InvalidPeriod {
371 rsi_period,
372 wma_period,
373 data_len: data.len(),
374 });
375 }
376
377 let warmup_period = (first + rsi_period + wma_period - 1).min(dst.len());
378 for v in &mut dst[..warmup_period] {
379 *v = f64::NAN;
380 }
381
382 unsafe {
383 return ift_rsi_scalar_classic(data, rsi_period, wma_period, first, dst);
384 }
385
386 Ok(())
387}
388
/// AVX-512 entry point. No vectorised implementation is present in this block;
/// the call forwards to `ift_rsi_scalar_classic`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn ift_rsi_avx512(
    data: &[f64],
    rsi_period: usize,
    wma_period: usize,
    first_valid: usize,
    out: &mut [f64],
) -> Result<(), IftRsiError> {
    // NOTE(review): the callee is `unsafe`; its preconditions are not visible here,
    // so callers must pass the same validated arguments the safe entry points use.
    unsafe { ift_rsi_scalar_classic(data, rsi_period, wma_period, first_valid, out) }
}
400
/// AVX2 entry point. No vectorised implementation is present in this block;
/// the call forwards to `ift_rsi_scalar_classic`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn ift_rsi_avx2(
    data: &[f64],
    rsi_period: usize,
    wma_period: usize,
    first_valid: usize,
    out: &mut [f64],
) -> Result<(), IftRsiError> {
    // NOTE(review): the callee is `unsafe`; its preconditions are not visible here,
    // so callers must pass the same validated arguments the safe entry points use.
    unsafe { ift_rsi_scalar_classic(data, rsi_period, wma_period, first_valid, out) }
}
412
/// Short-period AVX-512 variant; currently identical to [`ift_rsi_avx512`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn ift_rsi_avx512_short(
    data: &[f64],
    rsi_period: usize,
    wma_period: usize,
    first_valid: usize,
    out: &mut [f64],
) -> Result<(), IftRsiError> {
    ift_rsi_avx512(data, rsi_period, wma_period, first_valid, out)
}
424
/// Long-period AVX-512 variant; currently identical to [`ift_rsi_avx512`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn ift_rsi_avx512_long(
    data: &[f64],
    rsi_period: usize,
    wma_period: usize,
    first_valid: usize,
    out: &mut [f64],
) -> Result<(), IftRsiError> {
    ift_rsi_avx512(data, rsi_period, wma_period, first_valid, out)
}
436
437#[inline]
438pub fn ift_rsi_batch_with_kernel(
439 data: &[f64],
440 sweep: &IftRsiBatchRange,
441 k: Kernel,
442) -> Result<IftRsiBatchOutput, IftRsiError> {
443 let kernel = match k {
444 Kernel::Auto => detect_best_batch_kernel(),
445 other if other.is_batch() => other,
446 other => return Err(IftRsiError::InvalidKernelForBatch(other)),
447 };
448 let simd = match kernel {
449 Kernel::Avx512Batch => Kernel::Avx512,
450 Kernel::Avx2Batch => Kernel::Avx2,
451 Kernel::ScalarBatch => Kernel::Scalar,
452 _ => unreachable!(),
453 };
454 ift_rsi_batch_par_slice(data, sweep, simd)
455}
456
/// Sweep definition for the batch API: one `(start, end, step)` triple per parameter.
///
/// A `step` of 0, or `start == end`, selects the single value `start`.
#[derive(Clone, Debug)]
pub struct IftRsiBatchRange {
    /// `(start, end, step)` for the RSI period.
    pub rsi_period: (usize, usize, usize),
    /// `(start, end, step)` for the WMA period.
    pub wma_period: (usize, usize, usize),
}

impl Default for IftRsiBatchRange {
    /// Fixed RSI period 5; WMA period swept from 9 to 258 in steps of 1.
    fn default() -> Self {
        let rsi_period = (5, 5, 0);
        let wma_period = (9, 258, 1);
        Self {
            rsi_period,
            wma_period,
        }
    }
}
471
/// Fluent builder for batch (parameter-sweep) IFT-RSI runs.
#[derive(Clone, Debug, Default)]
pub struct IftRsiBatchBuilder {
    /// Parameter ranges to expand into the sweep grid.
    range: IftRsiBatchRange,
    /// Kernel handed to `ift_rsi_batch_with_kernel`; starts at `Kernel`'s `Default` value.
    kernel: Kernel,
}
477
478impl IftRsiBatchBuilder {
479 pub fn new() -> Self {
480 Self::default()
481 }
482 pub fn kernel(mut self, k: Kernel) -> Self {
483 self.kernel = k;
484 self
485 }
486 #[inline]
487 pub fn rsi_period_range(mut self, start: usize, end: usize, step: usize) -> Self {
488 self.range.rsi_period = (start, end, step);
489 self
490 }
491 #[inline]
492 pub fn rsi_period_static(mut self, p: usize) -> Self {
493 self.range.rsi_period = (p, p, 0);
494 self
495 }
496 #[inline]
497 pub fn wma_period_range(mut self, start: usize, end: usize, step: usize) -> Self {
498 self.range.wma_period = (start, end, step);
499 self
500 }
501 #[inline]
502 pub fn wma_period_static(mut self, n: usize) -> Self {
503 self.range.wma_period = (n, n, 0);
504 self
505 }
506 pub fn apply_slice(self, data: &[f64]) -> Result<IftRsiBatchOutput, IftRsiError> {
507 ift_rsi_batch_with_kernel(data, &self.range, self.kernel)
508 }
509 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<IftRsiBatchOutput, IftRsiError> {
510 let slice = source_type(c, src);
511 self.apply_slice(slice)
512 }
513 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<IftRsiBatchOutput, IftRsiError> {
514 IftRsiBatchBuilder::new().kernel(k).apply_slice(data)
515 }
516 pub fn with_default_candles(c: &Candles) -> Result<IftRsiBatchOutput, IftRsiError> {
517 IftRsiBatchBuilder::new()
518 .kernel(Kernel::Auto)
519 .apply_candles(c, "close")
520 }
521}
522
/// Result of a batch sweep: a row-major `rows x cols` matrix plus the parameters of each row.
#[derive(Clone, Debug)]
pub struct IftRsiBatchOutput {
    /// Flattened matrix; row `r` occupies `values[r * cols..(r + 1) * cols]`.
    pub values: Vec<f64>,
    /// Parameter combination that produced each row, in row order.
    pub combos: Vec<IftRsiParams>,
    /// Number of parameter combinations.
    pub rows: usize,
    /// Length of the input series.
    pub cols: usize,
}
530
531impl IftRsiBatchOutput {
532 pub fn row_for_params(&self, p: &IftRsiParams) -> Option<usize> {
533 self.combos.iter().position(|c| {
534 c.rsi_period.unwrap_or(5) == p.rsi_period.unwrap_or(5)
535 && c.wma_period.unwrap_or(9) == p.wma_period.unwrap_or(9)
536 })
537 }
538
539 pub fn values_for(&self, p: &IftRsiParams) -> Option<&[f64]> {
540 self.row_for_params(p).map(|row| {
541 let start = row * self.cols;
542 &self.values[start..start + self.cols]
543 })
544 }
545}
546
547#[inline(always)]
548fn expand_grid(r: &IftRsiBatchRange) -> Result<Vec<IftRsiParams>, IftRsiError> {
549 fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, IftRsiError> {
550 if step == 0 || start == end {
551 return Ok(vec![start]);
552 }
553 let (lo, hi) = if start <= end {
554 (start, end)
555 } else {
556 (end, start)
557 };
558 let vals: Vec<usize> = (lo..=hi).step_by(step).collect();
559 if vals.is_empty() {
560 return Err(IftRsiError::InvalidRange { start, end, step });
561 }
562 Ok(vals)
563 }
564 let rsi_periods = axis_usize(r.rsi_period)?;
565 let wma_periods = axis_usize(r.wma_period)?;
566 let cap =
567 rsi_periods
568 .len()
569 .checked_mul(wma_periods.len())
570 .ok_or(IftRsiError::InvalidRange {
571 start: r.rsi_period.0,
572 end: r.rsi_period.1,
573 step: r.rsi_period.2,
574 })?;
575 let mut out = Vec::with_capacity(cap);
576 for &rsi_p in &rsi_periods {
577 for &wma_p in &wma_periods {
578 out.push(IftRsiParams {
579 rsi_period: Some(rsi_p),
580 wma_period: Some(wma_p),
581 });
582 }
583 }
584 Ok(out)
585}
586
/// Runs the sweep row by row on the calling thread (`parallel = false`).
///
/// `kern` is passed straight to `ift_rsi_batch_inner`, which dispatches on the
/// per-row kernels (`Kernel::Scalar`, `Kernel::Avx2`, `Kernel::Avx512`).
#[inline(always)]
pub fn ift_rsi_batch_slice(
    data: &[f64],
    sweep: &IftRsiBatchRange,
    kern: Kernel,
) -> Result<IftRsiBatchOutput, IftRsiError> {
    ift_rsi_batch_inner(data, sweep, kern, false)
}
595
/// Runs the sweep with rows processed in parallel (`parallel = true`); on `wasm32`
/// the inner routine processes the rows sequentially instead.
///
/// `kern` is passed straight to `ift_rsi_batch_inner`, which dispatches on the
/// per-row kernels (`Kernel::Scalar`, `Kernel::Avx2`, `Kernel::Avx512`).
#[inline(always)]
pub fn ift_rsi_batch_par_slice(
    data: &[f64],
    sweep: &IftRsiBatchRange,
    kern: Kernel,
) -> Result<IftRsiBatchOutput, IftRsiError> {
    ift_rsi_batch_inner(data, sweep, kern, true)
}
604
605#[inline(always)]
606fn ift_rsi_batch_inner(
607 data: &[f64],
608 sweep: &IftRsiBatchRange,
609 kern: Kernel,
610 parallel: bool,
611) -> Result<IftRsiBatchOutput, IftRsiError> {
612 let combos = expand_grid(sweep)?;
613 let first = data
614 .iter()
615 .position(|x| !x.is_nan())
616 .ok_or(IftRsiError::AllValuesNaN)?;
617 let max_rsi = combos.iter().map(|c| c.rsi_period.unwrap()).max().unwrap();
618 let max_wma = combos.iter().map(|c| c.wma_period.unwrap()).max().unwrap();
619 let max_p = max_rsi.max(max_wma);
620 if data.len() - first < max_p {
621 return Err(IftRsiError::NotEnoughValidData {
622 needed: max_p,
623 valid: data.len() - first,
624 });
625 }
626
627 let rows = combos.len();
628 let cols = data.len();
629 rows.checked_mul(cols).ok_or(IftRsiError::InvalidRange {
630 start: sweep.rsi_period.0,
631 end: sweep.rsi_period.1,
632 step: sweep.rsi_period.2,
633 })?;
634
635 let warmup_periods: Vec<usize> = combos
636 .iter()
637 .map(|c| first + c.rsi_period.unwrap() + c.wma_period.unwrap() - 1)
638 .collect();
639
640 let mut buf_mu = make_uninit_matrix(rows, cols);
641 init_matrix_prefixes(&mut buf_mu, cols, &warmup_periods);
642
643 let mut buf_guard = core::mem::ManuallyDrop::new(buf_mu);
644 let values: &mut [f64] = unsafe {
645 core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
646 };
647
648 let sliced = &data[first..];
649 let n = sliced.len();
650 let mut gains = Vec::with_capacity(n.saturating_sub(1));
651 let mut losses = Vec::with_capacity(n.saturating_sub(1));
652 for i in 1..n {
653 let d = sliced[i] - sliced[i - 1];
654 if d > 0.0 {
655 gains.push(d);
656 losses.push(0.0);
657 } else {
658 gains.push(0.0);
659 losses.push(-d);
660 }
661 }
662
663 let n1 = gains.len();
664 let mut pg = Vec::with_capacity(n1 + 1);
665 let mut pl = Vec::with_capacity(n1 + 1);
666 pg.push(0.0);
667 pl.push(0.0);
668 for i in 0..n1 {
669 pg.push(pg[i] + gains[i]);
670 pl.push(pl[i] + losses[i]);
671 }
672
673 let n1 = gains.len();
674 let mut pg = Vec::with_capacity(n1 + 1);
675 let mut pl = Vec::with_capacity(n1 + 1);
676 pg.push(0.0);
677 pl.push(0.0);
678 for i in 0..n1 {
679 pg.push(pg[i] + gains[i]);
680 pl.push(pl[i] + losses[i]);
681 }
682
683 let n1 = gains.len();
684 let mut pg = Vec::with_capacity(n1 + 1);
685 let mut pl = Vec::with_capacity(n1 + 1);
686 pg.push(0.0);
687 pl.push(0.0);
688 for i in 0..n1 {
689 pg.push(pg[i] + gains[i]);
690 pl.push(pl[i] + losses[i]);
691 }
692
693 let n1 = gains.len();
694 let mut pg = Vec::with_capacity(n1 + 1);
695 let mut pl = Vec::with_capacity(n1 + 1);
696 pg.push(0.0);
697 pl.push(0.0);
698 for i in 0..n1 {
699 pg.push(pg[i] + gains[i]);
700 pl.push(pl[i] + losses[i]);
701 }
702
703 let do_row = |row: usize, out_row: &mut [f64]| unsafe {
704 let rsi_p = combos[row].rsi_period.unwrap();
705 let wma_p = combos[row].wma_period.unwrap();
706 match kern {
707 Kernel::Scalar => ift_rsi_row_scalar_precomputed_ps(
708 &gains, &losses, &pg, &pl, rsi_p, wma_p, first, out_row,
709 ),
710 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
711 Kernel::Avx2 => ift_rsi_row_avx2(data, first, rsi_p, wma_p, out_row),
712 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
713 Kernel::Avx512 => ift_rsi_row_avx512(data, first, rsi_p, wma_p, out_row),
714 _ => unreachable!(),
715 }
716 };
717
718 if parallel {
719 #[cfg(not(target_arch = "wasm32"))]
720 {
721 values
722 .par_chunks_mut(cols)
723 .enumerate()
724 .for_each(|(row, slice)| do_row(row, slice));
725 }
726
727 #[cfg(target_arch = "wasm32")]
728 {
729 for (row, slice) in values.chunks_mut(cols).enumerate() {
730 do_row(row, slice);
731 }
732 }
733 } else {
734 for (row, slice) in values.chunks_mut(cols).enumerate() {
735 do_row(row, slice);
736 }
737 }
738
739 let values = unsafe {
740 Vec::from_raw_parts(
741 buf_guard.as_mut_ptr() as *mut f64,
742 buf_guard.len(),
743 buf_guard.capacity(),
744 )
745 };
746
747 Ok(IftRsiBatchOutput {
748 values,
749 combos,
750 rows,
751 cols,
752 })
753}
754
755#[inline(always)]
756fn ift_rsi_batch_inner_into(
757 data: &[f64],
758 sweep: &IftRsiBatchRange,
759 kern: Kernel,
760 parallel: bool,
761 out: &mut [f64],
762) -> Result<Vec<IftRsiParams>, IftRsiError> {
763 let combos = expand_grid(sweep)?;
764 let rows = combos.len();
765 let cols = data.len();
766 rows.checked_mul(cols).ok_or(IftRsiError::InvalidRange {
767 start: sweep.rsi_period.0,
768 end: sweep.rsi_period.1,
769 step: sweep.rsi_period.2,
770 })?;
771 let first = data
772 .iter()
773 .position(|x| !x.is_nan())
774 .ok_or(IftRsiError::AllValuesNaN)?;
775 let max_rsi = combos.iter().map(|c| c.rsi_period.unwrap()).max().unwrap();
776 let max_wma = combos.iter().map(|c| c.wma_period.unwrap()).max().unwrap();
777 let max_p = max_rsi.max(max_wma);
778 if data.len() - first < max_p {
779 return Err(IftRsiError::NotEnoughValidData {
780 needed: max_p,
781 valid: data.len() - first,
782 });
783 }
784
785 let rows = combos.len();
786 let cols = data.len();
787
788 for (row, combo) in combos.iter().enumerate() {
789 let warmup = (first + combo.rsi_period.unwrap() + combo.wma_period.unwrap() - 1).min(cols);
790 let row_start = row * cols;
791 for i in 0..warmup {
792 out[row_start + i] = f64::NAN;
793 }
794 }
795
796 let sliced = &data[first..];
797 let n = sliced.len();
798 let mut gains = Vec::with_capacity(n.saturating_sub(1));
799 let mut losses = Vec::with_capacity(n.saturating_sub(1));
800 for i in 1..n {
801 let d = sliced[i] - sliced[i - 1];
802 if d > 0.0 {
803 gains.push(d);
804 losses.push(0.0);
805 } else {
806 gains.push(0.0);
807 losses.push(-d);
808 }
809 }
810
811 let n1 = gains.len();
812 let mut pg = Vec::with_capacity(n1 + 1);
813 let mut pl = Vec::with_capacity(n1 + 1);
814 pg.push(0.0);
815 pl.push(0.0);
816 for i in 0..n1 {
817 pg.push(pg[i] + gains[i]);
818 pl.push(pl[i] + losses[i]);
819 }
820
821 let do_row = |row: usize, out_row: &mut [f64]| unsafe {
822 let rsi_p = combos[row].rsi_period.unwrap();
823 let wma_p = combos[row].wma_period.unwrap();
824 match kern {
825 Kernel::Scalar => ift_rsi_row_scalar_precomputed_ps(
826 &gains, &losses, &pg, &pl, rsi_p, wma_p, first, out_row,
827 ),
828 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
829 Kernel::Avx2 => ift_rsi_row_avx2(data, first, rsi_p, wma_p, out_row),
830 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
831 Kernel::Avx512 => ift_rsi_row_avx512(data, first, rsi_p, wma_p, out_row),
832 _ => unreachable!(),
833 }
834 };
835
836 if parallel {
837 #[cfg(not(target_arch = "wasm32"))]
838 {
839 out.par_chunks_mut(cols)
840 .enumerate()
841 .for_each(|(row, slice)| do_row(row, slice));
842 }
843
844 #[cfg(target_arch = "wasm32")]
845 {
846 for (row, slice) in out.chunks_mut(cols).enumerate() {
847 do_row(row, slice);
848 }
849 }
850 } else {
851 for (row, slice) in out.chunks_mut(cols).enumerate() {
852 do_row(row, slice);
853 }
854 }
855
856 Ok(combos)
857}
858
859#[inline(always)]
860unsafe fn ift_rsi_row_scalar(
861 data: &[f64],
862 first: usize,
863 rsi_period: usize,
864 wma_period: usize,
865 out: &mut [f64],
866) {
867 let sliced = &data[first..];
868 let n = sliced.len();
869 if n == 0 {
870 return;
871 }
872 let mut gains = Vec::with_capacity(n.saturating_sub(1));
873 let mut losses = Vec::with_capacity(n.saturating_sub(1));
874 for i in 1..n {
875 let d = sliced[i] - sliced[i - 1];
876 if d > 0.0 {
877 gains.push(d);
878 losses.push(0.0);
879 } else {
880 gains.push(0.0);
881 losses.push(-d);
882 }
883 }
884 ift_rsi_row_scalar_precomputed(&gains, &losses, rsi_period, wma_period, first, out);
885}
886
/// Computes one batch row from precomputed per-step `gains` and `losses`.
///
/// Pipeline per step `i` (1-based over the gains): smoothed average gain/loss ->
/// RSI -> `0.1 * (rsi - 50)` -> rolling linearly weighted average over
/// `wma_period` values -> `tanh`, written to `out[first + i]`. The first written
/// index is `first + rsi_period + wma_period - 1`; earlier entries are untouched.
///
/// Returns without writing anything when a period is zero or when
/// `rsi_period + wma_period - 1` exceeds the number of gains.
///
/// # Safety
/// `losses.len()` must be at least `gains.len()` and `out.len()` must be greater
/// than `first + gains.len()`; all accesses are unchecked.
#[inline(always)]
unsafe fn ift_rsi_row_scalar_precomputed(
    gains: &[f64],
    losses: &[f64],
    rsi_period: usize,
    wma_period: usize,
    first: usize,
    out: &mut [f64],
) {
    if rsi_period == 0 || wma_period == 0 {
        return;
    }
    let n1 = gains.len();
    if rsi_period + wma_period - 1 >= n1 + 1 {
        return;
    }

    // Seed: plain mean of the first `rsi_period` gains / losses.
    let rp_f = rsi_period as f64;
    let mut avg_gain = gains
        .get_unchecked(..rsi_period)
        .iter()
        .fold(0.0f64, |acc, &g| acc + g)
        / rp_f;
    let mut avg_loss = losses
        .get_unchecked(..rsi_period)
        .iter()
        .fold(0.0f64, |acc, &l| acc + l)
        / rp_f;
    let alpha = 1.0f64 / rp_f;
    let beta = 1.0f64 - alpha;

    // Rolling weighted-average state: ring buffer, plain sum and weighted numerator.
    let wp = wma_period;
    let wp_f = wp as f64;
    let denom_rcp = 1.0f64 / (0.5f64 * wp_f * (wp_f + 1.0));
    let mut ring: Vec<f64> = vec![0.0; wp];
    let mut head = 0usize;
    let mut filled = 0usize;
    let mut sum = 0.0f64;
    let mut num = 0.0f64;

    for idx in rsi_period..=n1 {
        // The seed step uses the plain mean; later steps apply the recursive smoothing.
        if idx != rsi_period {
            let g = *gains.get_unchecked(idx - 1);
            let l = *losses.get_unchecked(idx - 1);
            avg_gain = f64::mul_add(avg_gain, beta, alpha * g);
            avg_loss = f64::mul_add(avg_loss, beta, alpha * l);
        }

        let rs = if avg_loss == 0.0 {
            100.0
        } else {
            avg_gain / avg_loss
        };
        let rsi = 100.0 - 100.0 / (1.0 + rs);
        let x = 0.1f64 * (rsi - 50.0);

        // Store the new sample, remembering the one it replaces.
        let slot = ring.get_unchecked_mut(head);
        let evicted = *slot;
        *slot = x;
        head += 1;
        if head == wp {
            head = 0;
        }

        let ready = if filled < wp {
            // Still filling the window: weights grow 1, 2, ..., wp.
            sum += x;
            num = f64::mul_add((filled as f64) + 1.0, x, num);
            filled += 1;
            filled == wp
        } else {
            // Slide the window: every old weight drops by one, the new sample gets `wp`.
            let sum_prev = sum;
            num = f64::mul_add(wp_f, x, num) - sum_prev;
            sum = sum_prev + x - evicted;
            true
        };

        if ready {
            let wma = num * denom_rcp;
            *out.get_unchecked_mut(first + idx) = wma.tanh();
        }
    }
}
976
/// Computes one batch row like `ift_rsi_row_scalar_precomputed`, but seeds the
/// average gain / loss from the prefix sums `pg` / `pl` instead of summing the
/// first `rsi_period` entries.
///
/// `pg[k]` / `pl[k]` are read as the sums of the first `k` gains / losses, so
/// both slices need `gains.len() + 1` entries. Output is written to
/// `out[first + i]` starting at `i = rsi_period + wma_period - 1`.
///
/// Returns without writing anything when a period is zero or when
/// `rsi_period + wma_period - 1` exceeds the number of gains.
///
/// # Safety
/// `losses.len() >= gains.len()`, `pg.len()` and `pl.len()` greater than
/// `rsi_period`, and `out.len()` greater than `first + gains.len()`; all accesses
/// are unchecked.
#[inline(always)]
unsafe fn ift_rsi_row_scalar_precomputed_ps(
    gains: &[f64],
    losses: &[f64],
    pg: &[f64],
    pl: &[f64],
    rsi_period: usize,
    wma_period: usize,
    first: usize,
    out: &mut [f64],
) {
    if rsi_period == 0 || wma_period == 0 {
        return;
    }
    let n1 = gains.len();
    if rsi_period + wma_period - 1 >= n1 + 1 {
        return;
    }

    // Seed: mean of the first `rsi_period` gains / losses via prefix-sum difference.
    let rp_f = rsi_period as f64;
    let seed_gain = *pg.get_unchecked(rsi_period) - *pg.get_unchecked(0);
    let seed_loss = *pl.get_unchecked(rsi_period) - *pl.get_unchecked(0);
    let mut avg_gain = seed_gain / rp_f;
    let mut avg_loss = seed_loss / rp_f;
    let alpha = 1.0f64 / rp_f;
    let beta = 1.0f64 - alpha;

    // Rolling weighted-average state: ring buffer, plain sum and weighted numerator.
    let wp = wma_period;
    let wp_f = wp as f64;
    let denom_rcp = 1.0f64 / (0.5f64 * wp_f * (wp_f + 1.0));
    let mut ring: Vec<f64> = vec![0.0; wp];
    let mut head = 0usize;
    let mut filled = 0usize;
    let mut sum = 0.0f64;
    let mut num = 0.0f64;

    for idx in rsi_period..=n1 {
        // The seed step uses the plain mean; later steps apply the recursive smoothing.
        if idx != rsi_period {
            let g = *gains.get_unchecked(idx - 1);
            let l = *losses.get_unchecked(idx - 1);
            avg_gain = f64::mul_add(avg_gain, beta, alpha * g);
            avg_loss = f64::mul_add(avg_loss, beta, alpha * l);
        }

        let rs = if avg_loss == 0.0 {
            100.0
        } else {
            avg_gain / avg_loss
        };
        let rsi = 100.0 - 100.0 / (1.0 + rs);
        let x = 0.1f64 * (rsi - 50.0);

        // Store the new sample, remembering the one it replaces.
        let slot = ring.get_unchecked_mut(head);
        let evicted = *slot;
        *slot = x;
        head += 1;
        if head == wp {
            head = 0;
        }

        let ready = if filled < wp {
            // Still filling the window: weights grow 1, 2, ..., wp.
            sum += x;
            num = f64::mul_add((filled as f64) + 1.0, x, num);
            filled += 1;
            filled == wp
        } else {
            // Slide the window: every old weight drops by one, the new sample gets `wp`.
            let sum_prev = sum;
            num = f64::mul_add(wp_f, x, num) - sum_prev;
            sum = sum_prev + x - evicted;
            true
        };

        if ready {
            let wma = num * denom_rcp;
            *out.get_unchecked_mut(first + idx) = wma.tanh();
        }
    }
}
1064
/// AVX2 row routine for the batch path. No vectorised implementation is present;
/// it forwards to `ift_rsi_row_scalar`.
///
/// # Safety
/// Same requirements as `ift_rsi_row_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn ift_rsi_row_avx2(
    data: &[f64],
    first: usize,
    rsi_period: usize,
    wma_period: usize,
    out: &mut [f64],
) {
    ift_rsi_row_scalar(data, first, rsi_period, wma_period, out)
}
1076
/// AVX-512 row routine for the batch path: picks the "short" variant when the
/// larger of the two periods is at most 32, the "long" variant otherwise.
///
/// # Safety
/// Same requirements as the variant it dispatches to.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn ift_rsi_row_avx512(
    data: &[f64],
    first: usize,
    rsi_period: usize,
    wma_period: usize,
    out: &mut [f64],
) {
    if rsi_period.max(wma_period) <= 32 {
        ift_rsi_row_avx512_short(data, first, rsi_period, wma_period, out);
    } else {
        ift_rsi_row_avx512_long(data, first, rsi_period, wma_period, out);
    }
}
1092
/// Short-period AVX-512 row routine; currently forwards to `ift_rsi_row_scalar`.
///
/// # Safety
/// Same requirements as `ift_rsi_row_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn ift_rsi_row_avx512_short(
    data: &[f64],
    first: usize,
    rsi_period: usize,
    wma_period: usize,
    out: &mut [f64],
) {
    ift_rsi_row_scalar(data, first, rsi_period, wma_period, out);
}
1104
/// Long-period AVX-512 row routine; currently forwards to `ift_rsi_row_scalar`.
///
/// # Safety
/// Same requirements as `ift_rsi_row_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn ift_rsi_row_avx512_long(
    data: &[f64],
    first: usize,
    rsi_period: usize,
    wma_period: usize,
    out: &mut [f64],
) {
    ift_rsi_row_scalar(data, first, rsi_period, wma_period, out);
}
1116
/// Incremental IFT-RSI calculator fed one value at a time through `update`.
#[derive(Debug, Clone)]
pub struct IftRsiStream {
    /// RSI look-back (number of price changes used for the seed average).
    rsi_period: usize,
    /// Window length of the weighted moving average.
    wma_period: usize,

    /// Previous input value; only meaningful while `have_prev` is true.
    prev: f64,
    /// Whether `prev` holds a value yet.
    have_prev: bool,
    /// Running sum of gains collected while seeding.
    seed_g: f64,
    /// Running sum of losses collected while seeding.
    seed_l: f64,
    /// Number of price changes collected while seeding.
    seed_cnt: usize,
    /// Smoothed average gain.
    avg_gain: f64,
    /// Smoothed average loss.
    avg_loss: f64,
    /// True once `rsi_period` changes were collected and the averages are initialised.
    seeded: bool,
    /// Smoothing weight of the newest change, `1 / rsi_period`.
    alpha: f64,
    /// Smoothing weight of the previous average, `1 - alpha`.
    beta: f64,

    /// Ring buffer holding the last `wma_period` rescaled RSI values.
    buf: Vec<f64>,
    /// Next write position in `buf`.
    head: usize,
    /// Number of values stored so far, capped at `wma_period`.
    filled: usize,
    /// Plain sum of the values in the window.
    sum: f64,
    /// Weighted sum of the values in the window (weights 1..=wma_period).
    num: f64,
    /// `wma_period` as a float.
    wp_f: f64,
    /// Reciprocal of the weight total `wma_period * (wma_period + 1) / 2`.
    denom_rcp: f64,
}
1141
1142impl IftRsiStream {
1143 pub fn try_new(params: IftRsiParams) -> Result<Self, IftRsiError> {
1144 let rsi_period = params.rsi_period.unwrap_or(5);
1145 let wma_period = params.wma_period.unwrap_or(9);
1146 if rsi_period == 0 || wma_period == 0 {
1147 return Err(IftRsiError::InvalidPeriod {
1148 rsi_period,
1149 wma_period,
1150 data_len: 0,
1151 });
1152 }
1153 let wp_f = wma_period as f64;
1154 let denom = 0.5 * wp_f * (wp_f + 1.0);
1155 Ok(Self {
1156 rsi_period,
1157 wma_period,
1158
1159 prev: 0.0,
1160 have_prev: false,
1161 seed_g: 0.0,
1162 seed_l: 0.0,
1163 seed_cnt: 0,
1164 avg_gain: 0.0,
1165 avg_loss: 0.0,
1166 seeded: false,
1167 alpha: 1.0 / (rsi_period as f64),
1168 beta: 1.0 - 1.0 / (rsi_period as f64),
1169
1170 buf: vec![0.0; wma_period],
1171 head: 0,
1172 filled: 0,
1173 sum: 0.0,
1174 num: 0.0,
1175 wp_f,
1176 denom_rcp: 1.0 / denom,
1177 })
1178 }
1179
1180 #[inline]
1181 pub fn update(&mut self, value: f64) -> Option<f64> {
1182 if !value.is_finite() {
1183 self.reset_soft();
1184 return None;
1185 }
1186
1187 if !self.have_prev {
1188 self.prev = value;
1189 self.have_prev = true;
1190 return None;
1191 }
1192
1193 let d = value - self.prev;
1194 self.prev = value;
1195 let gain = if d > 0.0 { d } else { 0.0 };
1196 let loss = if d < 0.0 { -d } else { 0.0 };
1197
1198 if !self.seeded {
1199 self.seed_g += gain;
1200 self.seed_l += loss;
1201 self.seed_cnt += 1;
1202 if self.seed_cnt < self.rsi_period {
1203 return None;
1204 }
1205
1206 self.avg_gain = self.seed_g / (self.rsi_period as f64);
1207 self.avg_loss = self.seed_l / (self.rsi_period as f64);
1208 self.seeded = true;
1209 } else {
1210 self.avg_gain = f64::mul_add(self.avg_gain, self.beta, self.alpha * gain);
1211 self.avg_loss = f64::mul_add(self.avg_loss, self.beta, self.alpha * loss);
1212 }
1213
1214 let rs = if self.avg_loss != 0.0 {
1215 self.avg_gain / self.avg_loss
1216 } else {
1217 100.0
1218 };
1219 let rsi = 100.0 - 100.0 / (1.0 + rs);
1220 let x = 0.1 * (rsi - 50.0);
1221
1222 if self.filled < self.wma_period {
1223 self.sum += x;
1224 self.num = f64::mul_add((self.filled as f64) + 1.0, x, self.num);
1225 self.buf[self.head] = x;
1226 self.head += 1;
1227 if self.head == self.wma_period {
1228 self.head = 0;
1229 }
1230 self.filled += 1;
1231
1232 if self.filled == self.wma_period {
1233 let wma = self.num * self.denom_rcp;
1234 return Some(tanh_kernel(wma));
1235 }
1236 return None;
1237 } else {
1238 let x_old = self.buf[self.head];
1239 self.buf[self.head] = x;
1240 self.head += 1;
1241 if self.head == self.wma_period {
1242 self.head = 0;
1243 }
1244
1245 let sum_prev = self.sum;
1246 self.num = f64::mul_add(self.wp_f, x, self.num) - sum_prev;
1247 self.sum = sum_prev + x - x_old;
1248
1249 let wma = self.num * self.denom_rcp;
1250 return Some(tanh_kernel(wma));
1251 }
1252 }
1253
1254 #[inline]
1255 fn reset_soft(&mut self) {
1256 self.have_prev = false;
1257 self.seed_g = 0.0;
1258 self.seed_l = 0.0;
1259 self.seed_cnt = 0;
1260 self.avg_gain = 0.0;
1261 self.avg_loss = 0.0;
1262 self.seeded = false;
1263
1264 self.head = 0;
1265 self.filled = 0;
1266 self.sum = 0.0;
1267 self.num = 0.0;
1268 for v in &mut self.buf {
1269 *v = 0.0;
1270 }
1271 }
1272}
1273
/// Hyperbolic tangent of `x`; the final squashing step of the transform.
#[inline(always)]
fn tanh_kernel(x: f64) -> f64 {
    f64::tanh(x)
}
1278
1279#[inline]
1280pub unsafe fn ift_rsi_scalar_classic(
1281 data: &[f64],
1282 rsi_period: usize,
1283 wma_period: usize,
1284 first_valid: usize,
1285 out: &mut [f64],
1286) -> Result<(), IftRsiError> {
1287 debug_assert!(rsi_period > 0 && wma_period > 0);
1288 let len = data.len();
1289 if first_valid >= len {
1290 return Ok(());
1291 }
1292 let sliced = data.get_unchecked(first_valid..);
1293 let n = sliced.len();
1294 if n == 0 {
1295 return Ok(());
1296 }
1297
1298 if rsi_period + wma_period - 1 >= n {
1299 return Ok(());
1300 }
1301
1302 let rp = rsi_period;
1303 let rp_f = rp as f64;
1304 let alpha = 1.0f64 / rp_f;
1305 let beta = 1.0f64 - alpha;
1306
1307 let mut avg_gain = 0.0f64;
1308 let mut avg_loss = 0.0f64;
1309 {
1310 let mut i = 1usize;
1311 while i <= rp {
1312 let d = *sliced.get_unchecked(i) - *sliced.get_unchecked(i - 1);
1313 if d > 0.0 {
1314 avg_gain += d;
1315 } else {
1316 avg_loss -= d;
1317 }
1318 i += 1;
1319 }
1320 avg_gain /= rp_f;
1321 avg_loss /= rp_f;
1322 }
1323
1324 let wp = wma_period;
1325 let wp_f = wp as f64;
1326 let denom = 0.5f64 * wp_f * (wp_f + 1.0);
1327 let denom_rcp = 1.0f64 / denom;
1328
1329 let mut buf: Vec<f64> = vec![0.0; wp];
1330 let mut head: usize = 0;
1331 let mut filled: usize = 0;
1332
1333 let mut sum = 0.0f64;
1334 let mut num = 0.0f64;
1335
1336 let mut i = rp;
1337 while i < n {
1338 if i > rp {
1339 let d = *sliced.get_unchecked(i) - *sliced.get_unchecked(i - 1);
1340 let gain = if d > 0.0 { d } else { 0.0 };
1341 let loss = if d < 0.0 { -d } else { 0.0 };
1342 avg_gain = f64::mul_add(avg_gain, beta, alpha * gain);
1343 avg_loss = f64::mul_add(avg_loss, beta, alpha * loss);
1344 }
1345
1346 let rs = if avg_loss != 0.0 {
1347 avg_gain / avg_loss
1348 } else {
1349 100.0
1350 };
1351 let rsi = 100.0 - 100.0 / (1.0 + rs);
1352 let x = 0.1f64 * (rsi - 50.0);
1353
1354 if filled < wp {
1355 sum += x;
1356 num = f64::mul_add((filled as f64) + 1.0, x, num);
1357 *buf.get_unchecked_mut(head) = x;
1358 head += 1;
1359 if head == wp {
1360 head = 0;
1361 }
1362 filled += 1;
1363
1364 if filled == wp {
1365 let wma = num * denom_rcp;
1366 *out.get_unchecked_mut(first_valid + i) = wma.tanh();
1367 }
1368 } else {
1369 let x_old = *buf.get_unchecked(head);
1370 *buf.get_unchecked_mut(head) = x;
1371 head += 1;
1372 if head == wp {
1373 head = 0;
1374 }
1375
1376 let sum_t = sum;
1377 num = f64::mul_add(wp_f, x, num) - sum_t;
1378 sum = sum_t + x - x_old;
1379
1380 let wma = num * denom_rcp;
1381 *out.get_unchecked_mut(first_valid + i) = wma.tanh();
1382 }
1383
1384 i += 1;
1385 }
1386
1387 Ok(())
1388}
1389
1390#[cfg(test)]
1391mod tests {
1392 use super::*;
1393 use crate::skip_if_unsupported;
1394 use crate::utilities::data_loader::read_candles_from_csv;
1395
1396 fn check_ift_rsi_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1397 skip_if_unsupported!(kernel, test_name);
1398 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1399 let candles = read_candles_from_csv(file_path)?;
1400 let default_params = IftRsiParams {
1401 rsi_period: None,
1402 wma_period: None,
1403 };
1404 let input = IftRsiInput::from_candles(&candles, "close", default_params);
1405 let output = ift_rsi_with_kernel(&input, kernel)?;
1406 assert_eq!(output.values.len(), candles.close.len());
1407 Ok(())
1408 }
1409
/// The in-place `ift_rsi_into` path must reproduce the allocating `ift_rsi`
/// path element by element (NaN positions included).
#[test]
fn test_ift_rsi_into_matches_api() -> Result<(), Box<dyn Error>> {
    // Three leading NaNs followed by a deterministic wavy series.
    let data: Vec<f64> = (0..256usize)
        .map(|i| {
            if i < 3 {
                f64::NAN
            } else {
                (i as f64).sin() * 5.0 + 100.0 + ((i % 7) as f64)
            }
        })
        .collect();

    let input = IftRsiInput::from_slice(&data, IftRsiParams::default());
    let baseline = ift_rsi(&input)?.values;

    let mut out = vec![0.0; data.len()];
    ift_rsi_into(&input, &mut out)?;

    assert_eq!(baseline.len(), out.len());
    for (i, (&a, &b)) in baseline.iter().zip(out.iter()).enumerate() {
        let equal = (a.is_nan() && b.is_nan()) || (a == b);
        assert!(equal, "Mismatch at {}: baseline={}, into={}", i, a, b);
    }
    Ok(())
}
1439
1440 fn check_ift_rsi_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1441 skip_if_unsupported!(kernel, test_name);
1442 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1443 let candles = read_candles_from_csv(file_path)?;
1444 let input = IftRsiInput::from_candles(&candles, "close", IftRsiParams::default());
1445 let result = ift_rsi_with_kernel(&input, kernel)?;
1446
1447 let expected_last_five = [
1448 -0.35919800205778424,
1449 -0.3275464113984847,
1450 -0.39970276998138216,
1451 -0.36321812798797737,
1452 -0.5843346528346959,
1453 ];
1454 let start = result.values.len().saturating_sub(5);
1455 for (i, &val) in result.values[start..].iter().enumerate() {
1456 let diff = (val - expected_last_five[i]).abs();
1457 assert!(
1458 diff < 1e-8,
1459 "[{}] IFT_RSI {:?} mismatch at idx {}: got {}, expected {}",
1460 test_name,
1461 kernel,
1462 i,
1463 val,
1464 expected_last_five[i]
1465 );
1466 }
1467 Ok(())
1468 }
1469
1470 fn check_ift_rsi_default_candles(
1471 test_name: &str,
1472 kernel: Kernel,
1473 ) -> Result<(), Box<dyn Error>> {
1474 skip_if_unsupported!(kernel, test_name);
1475 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1476 let candles = read_candles_from_csv(file_path)?;
1477 let input = IftRsiInput::with_default_candles(&candles);
1478 let output = ift_rsi_with_kernel(&input, kernel)?;
1479 assert_eq!(output.values.len(), candles.close.len());
1480 Ok(())
1481 }
1482
1483 fn check_ift_rsi_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1484 skip_if_unsupported!(kernel, test_name);
1485 let input_data = [10.0, 20.0, 30.0];
1486 let params = IftRsiParams {
1487 rsi_period: Some(0),
1488 wma_period: Some(9),
1489 };
1490 let input = IftRsiInput::from_slice(&input_data, params);
1491 let res = ift_rsi_with_kernel(&input, kernel);
1492 assert!(
1493 res.is_err(),
1494 "[{}] IFT_RSI should fail with zero period",
1495 test_name
1496 );
1497 Ok(())
1498 }
1499
1500 fn check_ift_rsi_period_exceeds_length(
1501 test_name: &str,
1502 kernel: Kernel,
1503 ) -> Result<(), Box<dyn Error>> {
1504 skip_if_unsupported!(kernel, test_name);
1505 let data_small = [10.0, 20.0, 30.0];
1506 let params = IftRsiParams {
1507 rsi_period: Some(10),
1508 wma_period: Some(9),
1509 };
1510 let input = IftRsiInput::from_slice(&data_small, params);
1511 let res = ift_rsi_with_kernel(&input, kernel);
1512 assert!(
1513 res.is_err(),
1514 "[{}] IFT_RSI should fail with period exceeding length",
1515 test_name
1516 );
1517 Ok(())
1518 }
1519
1520 fn check_ift_rsi_very_small_dataset(
1521 test_name: &str,
1522 kernel: Kernel,
1523 ) -> Result<(), Box<dyn Error>> {
1524 skip_if_unsupported!(kernel, test_name);
1525 let single_point = [42.0];
1526 let params = IftRsiParams {
1527 rsi_period: Some(5),
1528 wma_period: Some(9),
1529 };
1530 let input = IftRsiInput::from_slice(&single_point, params);
1531 let res = ift_rsi_with_kernel(&input, kernel);
1532 assert!(
1533 res.is_err(),
1534 "[{}] IFT_RSI should fail with insufficient data",
1535 test_name
1536 );
1537 Ok(())
1538 }
1539
1540 fn check_ift_rsi_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1541 skip_if_unsupported!(kernel, test_name);
1542 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1543 let candles = read_candles_from_csv(file_path)?;
1544 let first_params = IftRsiParams {
1545 rsi_period: Some(5),
1546 wma_period: Some(9),
1547 };
1548 let first_input = IftRsiInput::from_candles(&candles, "close", first_params);
1549 let first_result = ift_rsi_with_kernel(&first_input, kernel)?;
1550 let second_params = IftRsiParams {
1551 rsi_period: Some(5),
1552 wma_period: Some(9),
1553 };
1554 let second_input = IftRsiInput::from_slice(&first_result.values, second_params);
1555 let second_result = ift_rsi_with_kernel(&second_input, kernel)?;
1556 assert_eq!(second_result.values.len(), first_result.values.len());
1557 Ok(())
1558 }
1559
1560 fn check_ift_rsi_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1561 skip_if_unsupported!(kernel, test_name);
1562 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1563 let candles = read_candles_from_csv(file_path)?;
1564 let input = IftRsiInput::from_candles(
1565 &candles,
1566 "close",
1567 IftRsiParams {
1568 rsi_period: Some(5),
1569 wma_period: Some(9),
1570 },
1571 );
1572 let res = ift_rsi_with_kernel(&input, kernel)?;
1573 assert_eq!(res.values.len(), candles.close.len());
1574 if res.values.len() > 240 {
1575 for (i, &val) in res.values[240..].iter().enumerate() {
1576 assert!(
1577 !val.is_nan(),
1578 "[{}] Found unexpected NaN at out-index {}",
1579 test_name,
1580 240 + i
1581 );
1582 }
1583 }
1584 Ok(())
1585 }
1586
1587 #[cfg(debug_assertions)]
1588 fn check_ift_rsi_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1589 skip_if_unsupported!(kernel, test_name);
1590
1591 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1592 let candles = read_candles_from_csv(file_path)?;
1593
1594 let test_params = vec![
1595 IftRsiParams::default(),
1596 IftRsiParams {
1597 rsi_period: Some(2),
1598 wma_period: Some(2),
1599 },
1600 IftRsiParams {
1601 rsi_period: Some(3),
1602 wma_period: Some(5),
1603 },
1604 IftRsiParams {
1605 rsi_period: Some(7),
1606 wma_period: Some(14),
1607 },
1608 IftRsiParams {
1609 rsi_period: Some(14),
1610 wma_period: Some(21),
1611 },
1612 IftRsiParams {
1613 rsi_period: Some(21),
1614 wma_period: Some(9),
1615 },
1616 IftRsiParams {
1617 rsi_period: Some(50),
1618 wma_period: Some(50),
1619 },
1620 IftRsiParams {
1621 rsi_period: Some(100),
1622 wma_period: Some(100),
1623 },
1624 IftRsiParams {
1625 rsi_period: Some(2),
1626 wma_period: Some(50),
1627 },
1628 IftRsiParams {
1629 rsi_period: Some(50),
1630 wma_period: Some(2),
1631 },
1632 IftRsiParams {
1633 rsi_period: Some(9),
1634 wma_period: Some(21),
1635 },
1636 IftRsiParams {
1637 rsi_period: Some(25),
1638 wma_period: Some(10),
1639 },
1640 ];
1641
1642 for (param_idx, params) in test_params.iter().enumerate() {
1643 let input = IftRsiInput::from_candles(&candles, "close", params.clone());
1644 let output = ift_rsi_with_kernel(&input, kernel)?;
1645
1646 for (i, &val) in output.values.iter().enumerate() {
1647 if val.is_nan() {
1648 continue;
1649 }
1650
1651 let bits = val.to_bits();
1652
1653 if bits == 0x11111111_11111111 {
1654 panic!(
1655 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1656 with params: rsi_period={}, wma_period={} (param set {})",
1657 test_name,
1658 val,
1659 bits,
1660 i,
1661 params.rsi_period.unwrap_or(5),
1662 params.wma_period.unwrap_or(9),
1663 param_idx
1664 );
1665 }
1666
1667 if bits == 0x22222222_22222222 {
1668 panic!(
1669 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1670 with params: rsi_period={}, wma_period={} (param set {})",
1671 test_name,
1672 val,
1673 bits,
1674 i,
1675 params.rsi_period.unwrap_or(5),
1676 params.wma_period.unwrap_or(9),
1677 param_idx
1678 );
1679 }
1680
1681 if bits == 0x33333333_33333333 {
1682 panic!(
1683 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1684 with params: rsi_period={}, wma_period={} (param set {})",
1685 test_name,
1686 val,
1687 bits,
1688 i,
1689 params.rsi_period.unwrap_or(5),
1690 params.wma_period.unwrap_or(9),
1691 param_idx
1692 );
1693 }
1694 }
1695 }
1696
1697 Ok(())
1698 }
1699
1700 #[cfg(not(debug_assertions))]
1701 fn check_ift_rsi_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1702 Ok(())
1703 }
1704
1705 #[cfg(feature = "proptest")]
1706 fn check_ift_rsi_property(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1707 use proptest::prelude::*;
1708 skip_if_unsupported!(kernel, test_name);
1709
1710 let strat = (2usize..=50, 2usize..=50)
1711 .prop_flat_map(|(rsi_period, wma_period)| {
1712 let min_len = (rsi_period + wma_period) * 2;
1713 (
1714 (100.0f64..5000.0f64, 0.01f64..0.1f64),
1715 -0.02f64..0.02f64,
1716 Just(rsi_period),
1717 Just(wma_period),
1718 min_len..400,
1719 )
1720 })
1721 .prop_map(
1722 |((base_price, volatility), trend, rsi_period, wma_period, len)| {
1723 let mut prices = Vec::with_capacity(len);
1724 let mut current_price = base_price;
1725
1726 for i in 0..len {
1727 current_price *= 1.0 + trend;
1728
1729 let noise = 1.0 + (i as f64 * 0.1).sin() * volatility;
1730 prices.push(current_price * noise);
1731 }
1732
1733 (prices, rsi_period, wma_period)
1734 },
1735 );
1736
1737 proptest::test_runner::TestRunner::default().run(
1738 &strat,
1739 |(data, rsi_period, wma_period)| {
1740 let params = IftRsiParams {
1741 rsi_period: Some(rsi_period),
1742 wma_period: Some(wma_period),
1743 };
1744 let input = IftRsiInput::from_slice(&data, params);
1745
1746 let IftRsiOutput { values: out } = ift_rsi_with_kernel(&input, kernel)?;
1747
1748 let IftRsiOutput { values: ref_out } = ift_rsi_with_kernel(&input, Kernel::Scalar)?;
1749
1750 prop_assert_eq!(out.len(), data.len(), "Output length mismatch");
1751
1752 let warmup_period = rsi_period + wma_period - 1;
1753 for i in 0..warmup_period.min(data.len()) {
1754 prop_assert!(
1755 out[i].is_nan(),
1756 "Expected NaN during warmup at index {}, got {}",
1757 i,
1758 out[i]
1759 );
1760 }
1761
1762 for i in warmup_period..data.len() {
1763 let y = out[i];
1764 let r = ref_out[i];
1765
1766 if y.is_finite() {
1767 prop_assert!(
1768 y >= -1.0 - 1e-9 && y <= 1.0 + 1e-9,
1769 "IFT RSI value {} at index {} outside [-1, 1] bounds",
1770 y,
1771 i
1772 );
1773 }
1774
1775 if !y.is_finite() || !r.is_finite() {
1776 prop_assert_eq!(
1777 y.to_bits(),
1778 r.to_bits(),
1779 "NaN/Inf mismatch at index {}: {} vs {}",
1780 i,
1781 y,
1782 r
1783 );
1784 } else {
1785 let ulp_diff = y.to_bits().abs_diff(r.to_bits());
1786 prop_assert!(
1787 (y - r).abs() <= 1e-9 || ulp_diff <= 4,
1788 "Kernel mismatch at index {}: {} vs {} (ULP={})",
1789 i,
1790 y,
1791 r,
1792 ulp_diff
1793 );
1794 }
1795
1796 if i >= warmup_period + 10 {
1797 let lookback = 10;
1798 let recent_prices = &data[i - lookback..=i];
1799 let price_change =
1800 (recent_prices[lookback] - recent_prices[0]) / recent_prices[0];
1801
1802 if price_change > 0.05 && y.is_finite() {
1803 prop_assert!(
1804 y > 0.2,
1805 "Strong uptrend should produce positive IFT RSI > 0.2, got {} at index {}",
1806 y,
1807 i
1808 );
1809 }
1810
1811 if price_change < -0.05 && y.is_finite() {
1812 prop_assert!(
1813 y < -0.2,
1814 "Strong downtrend should produce negative IFT RSI < -0.2, got {} at index {}",
1815 y,
1816 i
1817 );
1818 }
1819 }
1820
1821 if !data[..=i].iter().any(|x| x.is_nan()) {
1822 prop_assert!(!y.is_nan(), "Unexpected NaN at index {} after warmup", i);
1823 }
1824 }
1825
1826 if data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10)
1827 && data.len() > warmup_period
1828 {
1829 for i in warmup_period..out.len() {
1830 if out[i].is_finite() {
1831 prop_assert!(
1832 (out[i] - (-1.0)).abs() < 1e-6,
1833 "Constant prices should yield IFT RSI = -1, got {} at index {}",
1834 out[i],
1835 i
1836 );
1837 }
1838 }
1839 }
1840
1841 let volatility = if data.len() > 2 {
1842 let returns: Vec<f64> = data.windows(2).map(|w| (w[1] - w[0]) / w[0]).collect();
1843 let mean_return = returns.iter().sum::<f64>() / returns.len() as f64;
1844 let variance = returns
1845 .iter()
1846 .map(|r| (r - mean_return).powi(2))
1847 .sum::<f64>()
1848 / returns.len() as f64;
1849 variance.sqrt()
1850 } else {
1851 0.0
1852 };
1853
1854 if volatility > 0.1 {
1855 for &val in out.iter() {
1856 if val.is_finite() {
1857 prop_assert!(
1858 val >= -1.0 && val <= 1.0,
1859 "Even with extreme volatility, IFT RSI must be bounded: {}",
1860 val
1861 );
1862 }
1863 }
1864 }
1865
1866 if data.len() > warmup_period + 20 {
1867 for check_idx in (warmup_period + 10..data.len()).step_by(20) {
1868 if check_idx + 5 >= data.len() {
1869 break;
1870 }
1871
1872 let recent_window = &data[check_idx - 5..=check_idx];
1873 let gains: f64 = recent_window
1874 .windows(2)
1875 .map(|w| (w[1] - w[0]).max(0.0))
1876 .sum();
1877 let losses: f64 = recent_window
1878 .windows(2)
1879 .map(|w| (w[0] - w[1]).max(0.0))
1880 .sum();
1881
1882 if gains > losses * 1.5 && out[check_idx].is_finite() {
1883 prop_assert!(
1884 out[check_idx] > -0.1,
1885 "Bullish momentum (gains > losses*1.5) should yield IFT RSI > -0.1, got {} at index {}",
1886 out[check_idx],
1887 check_idx
1888 );
1889 }
1890
1891 if losses > gains * 1.5 && out[check_idx].is_finite() {
1892 prop_assert!(
1893 out[check_idx] < 0.1,
1894 "Bearish momentum (losses > gains*1.5) should yield IFT RSI < 0.1, got {} at index {}",
1895 out[check_idx],
1896 check_idx
1897 );
1898 }
1899 }
1900 }
1901
1902 Ok(())
1903 },
1904 )?;
1905
1906 Ok(())
1907 }
1908
1909 #[cfg(not(feature = "proptest"))]
1910 fn check_ift_rsi_property(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1911 skip_if_unsupported!(kernel, test_name);
1912 Ok(())
1913 }
1914
1915 macro_rules! generate_all_ift_rsi_tests {
1916 ($($test_fn:ident),*) => {
1917 paste::paste! {
1918 $(
1919 #[test]
1920 fn [<$test_fn _scalar_f64>]() {
1921 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1922 }
1923 )*
1924 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1925 $(
1926 #[test]
1927 fn [<$test_fn _avx2_f64>]() {
1928 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1929 }
1930 #[test]
1931 fn [<$test_fn _avx512_f64>]() {
1932 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1933 }
1934 )*
1935 }
1936 }
1937 }
1938
1939 generate_all_ift_rsi_tests!(
1940 check_ift_rsi_partial_params,
1941 check_ift_rsi_accuracy,
1942 check_ift_rsi_default_candles,
1943 check_ift_rsi_zero_period,
1944 check_ift_rsi_period_exceeds_length,
1945 check_ift_rsi_very_small_dataset,
1946 check_ift_rsi_reinput,
1947 check_ift_rsi_nan_handling,
1948 check_ift_rsi_no_poison,
1949 check_ift_rsi_property
1950 );
1951
1952 fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1953 skip_if_unsupported!(kernel, test);
1954 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1955 let c = read_candles_from_csv(file)?;
1956 let output = IftRsiBatchBuilder::new()
1957 .kernel(kernel)
1958 .apply_candles(&c, "close")?;
1959 let def = IftRsiParams::default();
1960 let row = output.values_for(&def).expect("default row missing");
1961 assert_eq!(row.len(), c.close.len());
1962 Ok(())
1963 }
1964
1965 #[cfg(debug_assertions)]
1966 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1967 skip_if_unsupported!(kernel, test);
1968
1969 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1970 let c = read_candles_from_csv(file)?;
1971
1972 let test_configs = vec![
1973 (2, 10, 2, 2, 10, 2),
1974 (5, 25, 5, 5, 25, 5),
1975 (30, 60, 15, 30, 60, 15),
1976 (2, 5, 1, 2, 5, 1),
1977 (9, 15, 3, 9, 15, 3),
1978 (2, 2, 0, 2, 20, 2),
1979 (2, 20, 2, 9, 9, 0),
1980 ];
1981
1982 for (cfg_idx, &(rsi_start, rsi_end, rsi_step, wma_start, wma_end, wma_step)) in
1983 test_configs.iter().enumerate()
1984 {
1985 let output = IftRsiBatchBuilder::new()
1986 .kernel(kernel)
1987 .rsi_period_range(rsi_start, rsi_end, rsi_step)
1988 .wma_period_range(wma_start, wma_end, wma_step)
1989 .apply_candles(&c, "close")?;
1990
1991 for (idx, &val) in output.values.iter().enumerate() {
1992 if val.is_nan() {
1993 continue;
1994 }
1995
1996 let bits = val.to_bits();
1997 let row = idx / output.cols;
1998 let col = idx % output.cols;
1999 let combo = &output.combos[row];
2000
2001 if bits == 0x11111111_11111111 {
2002 panic!(
2003 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
2004 at row {} col {} (flat index {}) with params: rsi_period={}, wma_period={}",
2005 test,
2006 cfg_idx,
2007 val,
2008 bits,
2009 row,
2010 col,
2011 idx,
2012 combo.rsi_period.unwrap_or(5),
2013 combo.wma_period.unwrap_or(9)
2014 );
2015 }
2016
2017 if bits == 0x22222222_22222222 {
2018 panic!(
2019 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
2020 at row {} col {} (flat index {}) with params: rsi_period={}, wma_period={}",
2021 test,
2022 cfg_idx,
2023 val,
2024 bits,
2025 row,
2026 col,
2027 idx,
2028 combo.rsi_period.unwrap_or(5),
2029 combo.wma_period.unwrap_or(9)
2030 );
2031 }
2032
2033 if bits == 0x33333333_33333333 {
2034 panic!(
2035 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
2036 at row {} col {} (flat index {}) with params: rsi_period={}, wma_period={}",
2037 test,
2038 cfg_idx,
2039 val,
2040 bits,
2041 row,
2042 col,
2043 idx,
2044 combo.rsi_period.unwrap_or(5),
2045 combo.wma_period.unwrap_or(9)
2046 );
2047 }
2048 }
2049 }
2050
2051 Ok(())
2052 }
2053
2054 #[cfg(not(debug_assertions))]
2055 fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2056 Ok(())
2057 }
2058
2059 macro_rules! gen_batch_tests {
2060 ($fn_name:ident) => {
2061 paste::paste! {
2062 #[test] fn [<$fn_name _scalar>]() {
2063 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
2064 }
2065 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2066 #[test] fn [<$fn_name _avx2>]() {
2067 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
2068 }
2069 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2070 #[test] fn [<$fn_name _avx512>]() {
2071 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
2072 }
2073 #[test] fn [<$fn_name _auto_detect>]() {
2074 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
2075 }
2076 }
2077 };
2078 }
2079 gen_batch_tests!(check_batch_default_row);
2080 gen_batch_tests!(check_batch_no_poison);
2081}
2082
// Python entry point `ift_rsi(data, rsi_period, wma_period, kernel=None)`.
// Computes the indicator over a 1-D float64 array with the GIL released and
// returns a new NumPy array; any indicator error is raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "ift_rsi")]
#[pyo3(signature = (data, rsi_period, wma_period, kernel=None))]
pub fn ift_rsi_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    rsi_period: usize,
    wma_period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let prices = data.as_slice()?;
    let kern = validate_kernel(kernel, false)?;

    let input = IftRsiInput::from_slice(
        prices,
        IftRsiParams {
            rsi_period: Some(rsi_period),
            wma_period: Some(wma_period),
        },
    );

    // Run the numeric work without holding the GIL.
    let computed = py.allow_threads(|| ift_rsi_with_kernel(&input, kern).map(|o| o.values));

    match computed {
        Ok(values) => Ok(values.into_pyarray(py)),
        Err(err) => Err(PyValueError::new_err(err.to_string())),
    }
}
2110
// Python-visible wrapper (`IftRsiStream`) around the Rust streaming state.
#[cfg(feature = "python")]
#[pyclass(name = "IftRsiStream")]
pub struct IftRsiStreamPy {
    stream: IftRsiStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl IftRsiStreamPy {
    // Builds the stream; an invalid configuration is raised as `ValueError`.
    #[new]
    fn new(rsi_period: usize, wma_period: usize) -> PyResult<Self> {
        let params = IftRsiParams {
            rsi_period: Some(rsi_period),
            wma_period: Some(wma_period),
        };
        match IftRsiStream::try_new(params) {
            Ok(stream) => Ok(Self { stream }),
            Err(err) => Err(PyValueError::new_err(err.to_string())),
        }
    }

    // Forwards one sample to the inner stream; `None` maps to Python `None`.
    fn update(&mut self, value: f64) -> Option<f64> {
        let stream = &mut self.stream;
        stream.update(value)
    }
}
2135
// Python entry point `ift_rsi_batch(data, rsi_period_range, wma_period_range, kernel=None)`.
//
// Expands both ranges into a parameter grid, computes every combination into a
// single flat buffer with the GIL released, and returns a dict with:
//   "values"      - 2-D array of shape (rows, cols)
//   "rsi_periods" - RSI period per row
//   "wma_periods" - WMA period per row
//   "rows", "cols"
// Grid-expansion and indicator errors are raised as `ValueError`, as is an
// overflow of `rows * cols`.
#[cfg(feature = "python")]
#[pyfunction(name = "ift_rsi_batch")]
#[pyo3(signature = (data, rsi_period_range, wma_period_range, kernel=None))]
pub fn ift_rsi_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    rsi_period_range: (usize, usize, usize),
    wma_period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, true)?;

    let sweep = IftRsiBatchRange {
        rsi_period: rsi_period_range,
        wma_period: wma_period_range,
    };

    let combos = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let rows = combos.len();
    let cols = slice_in.len();

    // A wrapped product would allocate an undersized buffer; fail loudly instead.
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;

    // The array is created uninitialised and handed to the batch kernel as its
    // output buffer.
    // NOTE(review): relies on `ift_rsi_batch_inner_into` writing every element
    // before the array is exposed to Python - confirm against its definition.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            let kernel = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k => k,
            };
            // Translate the batch kernel into the per-row kernel the inner routine takes.
            let simd = match kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                // Presumably `validate_kernel(_, true)` only yields batch kernels.
                _ => unreachable!(),
            };
            ift_rsi_batch_inner_into(slice_in, &sweep, simd, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "rsi_periods",
        combos
            .iter()
            .map(|p| p.rsi_period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "wma_periods",
        combos
            .iter()
            .map(|p| p.wma_period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item("rows", rows)?;
    dict.set_item("cols", cols)?;

    Ok(dict)
}
2203
// Python entry point `ift_rsi_cuda_batch_dev(data_f32, rsi_range, wma_range, device_id=0)`.
// Runs the parameter sweep on the selected CUDA device and returns a handle to
// the device-resident result. Raises `ValueError` when CUDA is unavailable or
// any CUDA step fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "ift_rsi_cuda_batch_dev")]
#[pyo3(signature = (data_f32, rsi_range, wma_range, device_id=0))]
pub fn ift_rsi_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: numpy::PyReadonlyArray1<'_, f32>,
    rsi_range: (usize, usize, usize),
    wma_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    use crate::cuda::cuda_available;
    use crate::cuda::oscillators::CudaIftRsi;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let host_prices: &[f32] = data_f32.as_slice()?;
    let sweep = IftRsiBatchRange {
        rsi_period: rsi_range,
        wma_period: wma_range,
    };

    // Device work runs with the GIL released; the kernel is synchronised
    // before the handle is handed back.
    let (inner, dev_id, ctx) = py.allow_threads(|| {
        let cuda =
            CudaIftRsi::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let dev_id = cuda.device_id();
        let ctx = cuda.context_arc();

        let (dev, _combos) = cuda
            .ift_rsi_batch_dev(host_prices, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        cuda.synchronize()
            .map_err(|e| PyValueError::new_err(e.to_string()))?;

        Ok::<_, PyErr>((dev, dev_id, ctx))
    })?;

    // The context is stored alongside the device buffer in the returned handle.
    Ok(DeviceArrayF32Py {
        inner,
        _ctx: Some(ctx),
        device_id: Some(dev_id),
    })
}
2242
// Python entry point
// `ift_rsi_cuda_many_series_one_param_dev(data_tm_f32, rsi_period, wma_period, device_id=0)`.
// Applies one parameter set to many series held in a 2-D float32 array and
// returns a handle to the device-resident result. Raises `ValueError` when
// CUDA is unavailable or any CUDA step fails.
//
// NOTE(review): the array is presumably time-major (shape[0] = time steps,
// shape[1] = series), going by the `_tm_` name and the `(cols, rows)` argument
// order passed to the CUDA wrapper - confirm against `CudaIftRsi`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "ift_rsi_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, rsi_period, wma_period, device_id=0))]
pub fn ift_rsi_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    rsi_period: usize,
    wma_period: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    use crate::cuda::cuda_available;
    use crate::cuda::oscillators::CudaIftRsi;
    use numpy::PyUntypedArrayMethods;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let flat_in: &[f32] = data_tm_f32.as_slice()?;
    let rows = data_tm_f32.shape()[0];
    let cols = data_tm_f32.shape()[1];
    let params = IftRsiParams {
        rsi_period: Some(rsi_period),
        wma_period: Some(wma_period),
    };
    // Device work runs with the GIL released; the kernel is synchronised
    // before the handle is handed back.
    let (inner, dev_id, ctx) = py.allow_threads(|| {
        let cuda = CudaIftRsi::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let dev_id = cuda.device_id();
        let ctx = cuda.context_arc();
        // Fixed: the last argument was the mis-encoded token `¶ms`; it is a
        // borrow of `params`.
        let dev = cuda
            .ift_rsi_many_series_one_param_time_major_dev(flat_in, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        cuda.synchronize()
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, PyErr>((dev, dev_id, ctx))
    })?;
    let handle = DeviceArrayF32Py {
        inner,
        _ctx: Some(ctx),
        device_id: Some(dev_id),
    };
    Ok(handle)
}
2284
// wasm entry point: computes the indicator over `data` with the scalar kernel
// and returns a freshly allocated vector of the same length. Indicator errors
// are converted to a JS string value.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn ift_rsi_js(data: &[f64], rsi_period: usize, wma_period: usize) -> Result<Vec<f64>, JsValue> {
    let input = IftRsiInput::from_slice(
        data,
        IftRsiParams {
            rsi_period: Some(rsi_period),
            wma_period: Some(wma_period),
        },
    );

    let mut output = vec![0.0; data.len()];
    match ift_rsi_into_slice(&mut output, &input, Kernel::Scalar) {
        Ok(_) => Ok(output),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
2303
// wasm entry point: computes the indicator from `len` f64 values at `in_ptr`
// into `len` f64 slots at `out_ptr`, using the scalar kernel.
//
// Null pointers are rejected with an error. If the input and output ranges
// overlap in any way (identical pointers or a partial overlap), the result is
// computed into a temporary buffer and copied out afterwards, so a shared
// `&[f64]` and an exclusive `&mut [f64]` never alias during the computation.
// Previously only exactly equal pointers took that path.
//
// The caller must guarantee that both pointers are valid and aligned for
// `len` f64 values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn ift_rsi_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    rsi_period: usize,
    wma_period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to ift_rsi_into"));
    }

    // Byte-range overlap test; saturating arithmetic keeps it well-defined for
    // absurd `len` values.
    let span = len.saturating_mul(std::mem::size_of::<f64>());
    let src = in_ptr as usize;
    let dst = out_ptr as usize;
    let ranges_overlap = src < dst.saturating_add(span) && dst < src.saturating_add(span);
    let aliased = in_ptr == out_ptr as *const f64 || ranges_overlap;

    // SAFETY: both pointers are non-null (checked above); validity for `len`
    // elements is the caller's obligation. The mutable output slice is only
    // created while the input slice is live when the ranges are disjoint.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let params = IftRsiParams {
            rsi_period: Some(rsi_period),
            wma_period: Some(wma_period),
        };
        let input = IftRsiInput::from_slice(data, params);

        let kernel = Kernel::Scalar;

        if aliased {
            let mut temp = vec![0.0; len];
            ift_rsi_into_slice(&mut temp, &input, kernel)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            // `data` is not read again past this point.
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            out.copy_from_slice(&temp);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            ift_rsi_into_slice(out, &input, kernel)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
        Ok(())
    }
}
2341
// wasm helper: allocates room for `len` f64 values and leaks the allocation,
// returning its pointer. The memory is uninitialised; release it with
// `ift_rsi_free(ptr, len)`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn ift_rsi_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` suppresses the destructor, so the buffer outlives this call.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
2350
// wasm helper: releases a buffer previously obtained from `ift_rsi_alloc(len)`.
// A null pointer is ignored.
//
// The vector is rebuilt with length 0 and capacity `len`. That frees the same
// allocation without claiming that `len` elements are initialised, which the
// buffer may not be if the caller never wrote to it.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn ift_rsi_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: the caller passes a pointer returned by `ift_rsi_alloc` together
    // with the same `len`, so the capacity matches the original allocation.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
2360
/// Sweep configuration for the wasm batch entry point, deserialized from the
/// JS config object and copied field by field into an `IftRsiBatchRange`.
///
/// NOTE(review): each tuple is presumably `(start, end, step)`, matching the
/// `*_start, *_end, *_step` parameter order of `ift_rsi_batch_into` - confirm
/// against `IftRsiBatchRange`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct IftRsiBatchConfig {
    /// Range of RSI periods to sweep.
    pub rsi_period_range: (usize, usize, usize),
    /// Range of WMA periods to sweep.
    pub wma_period_range: (usize, usize, usize),
}
2367
/// Serializable result of the `ift_rsi_batch` export; it is converted to a
/// `JsValue` with `serde_wasm_bindgen::to_value`.
///
/// All four fields are copied one-to-one from the value returned by
/// `ift_rsi_batch_inner`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct IftRsiBatchJsOutput {
    /// Computed values for the whole sweep.
    /// NOTE(review): presumably a flattened row-major `rows * cols` matrix —
    /// confirm against `ift_rsi_batch_inner`.
    pub values: Vec<f64>,
    /// Parameter combinations of the sweep; presumably one per row of `values`
    /// in the same order — confirm against `ift_rsi_batch_inner`.
    pub combos: Vec<IftRsiParams>,
    /// Row count reported by the batch computation.
    pub rows: usize,
    /// Column count reported by the batch computation.
    pub cols: usize,
}
2376
/// Runs an IFT-RSI parameter sweep over `data` and returns the result as a JS
/// value (exported to JS as `ift_rsi_batch`).
///
/// * `data` - input series.
/// * `config` - JS object deserialized into `IftRsiBatchConfig`.
///
/// On success the returned `JsValue` is the serialized `IftRsiBatchJsOutput`
/// (`values`, `combos`, `rows`, `cols`).
///
/// # Errors
///
/// Returns a `JsValue` string when `config` cannot be deserialized
/// (`"Invalid config: ..."`), when `ift_rsi_batch_inner` fails, or when the
/// output cannot be serialized (`"Serialization error: ..."`).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = ift_rsi_batch)]
pub fn ift_rsi_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let config: IftRsiBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = IftRsiBatchRange {
        rsi_period: config.rsi_period_range,
        wma_period: config.wma_period_range,
    };

    // This function only exists on wasm32 (see the `cfg` above), so the kernel
    // is always the detected one; no non-wasm fallback is needed.
    let kernel = detect_best_kernel();

    let output = ift_rsi_batch_inner(data, &sweep, kernel, false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let js_output = IftRsiBatchJsOutput {
        values: output.values,
        combos: output.combos,
        rows: output.rows,
        cols: output.cols,
    };

    serde_wasm_bindgen::to_value(&js_output)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2406
/// Runs an IFT-RSI parameter sweep over a raw input buffer and writes every
/// row of results into a raw output buffer in wasm linear memory.
///
/// * `in_ptr` - start of the input series; must be valid for `len` `f64` reads.
/// * `out_ptr` - start of the output buffer; must be valid for `rows * len`
///   `f64` writes, where `rows` is the number of combinations produced by
///   `expand_grid` for the given ranges.
/// * `len` - number of input elements (the column count of the output).
/// * `rsi_start`, `rsi_end`, `rsi_step` - RSI period sweep.
/// * `wma_start`, `wma_end`, `wma_step` - WMA period sweep.
///
/// Returns the number of rows written. If the input and output regions overlap,
/// the input is copied to an owned buffer first so it is never read through a
/// reference while the output is being written.
///
/// # Errors
///
/// Returns a `JsValue` string when either pointer is null, when `expand_grid`
/// rejects the ranges, when `rows * len` does not fit in `usize`, or when
/// `ift_rsi_batch_inner_into` fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn ift_rsi_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    rsi_start: usize,
    rsi_end: usize,
    rsi_step: usize,
    wma_start: usize,
    wma_end: usize,
    wma_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str(
            "null pointer passed to ift_rsi_batch_into",
        ));
    }

    let sweep = IftRsiBatchRange {
        rsi_period: (rsi_start, rsi_end, rsi_step),
        wma_period: (wma_start, wma_end, wma_step),
    };

    let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let cols = len;

    // `usize` is 32 bits on wasm32; a wrapped product would describe an output
    // slice far smaller than the caller intended.
    let total = rows.checked_mul(cols).ok_or_else(|| {
        JsValue::from_str("rows * cols overflows usize in ift_rsi_batch_into")
    })?;

    // Byte-range overlap test between the input and output regions.
    let elem = std::mem::size_of::<f64>();
    let in_start = in_ptr as usize;
    let out_start = out_ptr as usize;
    let in_end = in_start.saturating_add(len.saturating_mul(elem));
    let out_end = out_start.saturating_add(total.saturating_mul(elem));
    let overlaps = in_start < out_end && out_start < in_end;

    // Holds a private copy of the input only when the regions overlap.
    let owned_input: Vec<f64>;
    let data: &[f64] = if overlaps {
        // SAFETY: `in_ptr` is non-null and the caller guarantees it is valid for
        // `len` reads. The borrowed slice is consumed by `to_vec` immediately,
        // before any mutable view of the output exists.
        owned_input = unsafe { std::slice::from_raw_parts(in_ptr, len) }.to_vec();
        &owned_input
    } else {
        // SAFETY: `in_ptr` is non-null, valid for `len` reads per the caller's
        // contract, and disjoint from the output region checked above.
        unsafe { std::slice::from_raw_parts(in_ptr, len) }
    };

    // SAFETY: `out_ptr` is non-null and the caller guarantees it is valid for
    // `total` writes; `data` either lives in a separate allocation or was
    // shown to be disjoint from this region.
    let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };

    // This function only exists on wasm32 (see the `cfg` above), so the kernel
    // is always the detected one; no non-wasm fallback is needed.
    let kernel = detect_best_kernel();

    ift_rsi_batch_inner_into(data, &sweep, kernel, false, out)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    Ok(rows)
}