1use crate::utilities::data_loader::{source_type, Candles};
2#[cfg(all(feature = "python", feature = "cuda"))]
3use crate::utilities::dlpack_cuda::{make_device_array_py, DeviceArrayF32Py};
4use crate::utilities::enums::Kernel;
5use crate::utilities::helpers::{
6 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
7 make_uninit_matrix,
8};
9#[cfg(feature = "python")]
10use crate::utilities::kernel_validation::validate_kernel;
11use aligned_vec::{AVec, CACHELINE_ALIGN};
12#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
13use core::arch::x86_64::*;
14#[cfg(feature = "python")]
15use numpy::{IntoPyArray, PyArray1};
16use paste::paste;
17#[cfg(feature = "python")]
18use pyo3::exceptions::PyValueError;
19#[cfg(feature = "python")]
20use pyo3::prelude::*;
21#[cfg(feature = "python")]
22use pyo3::types::PyDict;
23#[cfg(not(target_arch = "wasm32"))]
24use rayon::prelude::*;
25#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
26use serde::{Deserialize, Serialize};
27use std::convert::AsRef;
28use std::error::Error;
29use std::mem::MaybeUninit;
30use thiserror::Error;
31#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
32use wasm_bindgen::prelude::*;
33
34impl<'a> AsRef<[f64]> for RsiInput<'a> {
35 #[inline(always)]
36 fn as_ref(&self) -> &[f64] {
37 match &self.data {
38 RsiData::Slice(slice) => slice,
39 RsiData::Candles { candles, source } => source_type(candles, source),
40 }
41 }
42}
43
44#[derive(Debug, Clone)]
45pub enum RsiData<'a> {
46 Candles {
47 candles: &'a Candles,
48 source: &'a str,
49 },
50 Slice(&'a [f64]),
51}
52
/// Result of a single-series RSI run.
///
/// `values` has one entry per input sample; the warm-up region at the front
/// (leading NaN inputs plus `period` samples) is NaN.
#[derive(Clone, Debug)]
pub struct RsiOutput {
    /// RSI values aligned index-for-index with the input series.
    pub values: Vec<f64>,
}
57
/// Tunable RSI parameters.
///
/// `period` is the smoothing length; `None` means "use the default of 14".
#[derive(Clone, Debug)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct RsiParams {
    pub period: Option<usize>,
}

impl Default for RsiParams {
    /// A 14-bar period.
    fn default() -> Self {
        RsiParams { period: Some(14) }
    }
}
72
73#[derive(Debug, Clone)]
74pub struct RsiInput<'a> {
75 pub data: RsiData<'a>,
76 pub params: RsiParams,
77}
78
79impl<'a> RsiInput<'a> {
80 #[inline]
81 pub fn from_candles(c: &'a Candles, s: &'a str, p: RsiParams) -> Self {
82 Self {
83 data: RsiData::Candles {
84 candles: c,
85 source: s,
86 },
87 params: p,
88 }
89 }
90 #[inline]
91 pub fn from_slice(sl: &'a [f64], p: RsiParams) -> Self {
92 Self {
93 data: RsiData::Slice(sl),
94 params: p,
95 }
96 }
97 #[inline]
98 pub fn with_default_candles(c: &'a Candles) -> Self {
99 Self::from_candles(c, "close", RsiParams::default())
100 }
101 #[inline]
102 pub fn get_period(&self) -> usize {
103 self.params.period.unwrap_or(14)
104 }
105}
106
107#[derive(Copy, Clone, Debug)]
108pub struct RsiBuilder {
109 period: Option<usize>,
110 kernel: Kernel,
111}
112
113impl Default for RsiBuilder {
114 fn default() -> Self {
115 Self {
116 period: None,
117 kernel: Kernel::Auto,
118 }
119 }
120}
121
122impl RsiBuilder {
123 #[inline(always)]
124 pub fn new() -> Self {
125 Self::default()
126 }
127 #[inline(always)]
128 pub fn period(mut self, n: usize) -> Self {
129 self.period = Some(n);
130 self
131 }
132 #[inline(always)]
133 pub fn kernel(mut self, k: Kernel) -> Self {
134 self.kernel = k;
135 self
136 }
137 #[inline(always)]
138 pub fn apply(self, c: &Candles) -> Result<RsiOutput, RsiError> {
139 let p = RsiParams {
140 period: self.period,
141 };
142 let i = RsiInput::from_candles(c, "close", p);
143 rsi_with_kernel(&i, self.kernel)
144 }
145 #[inline(always)]
146 pub fn apply_slice(self, d: &[f64]) -> Result<RsiOutput, RsiError> {
147 let p = RsiParams {
148 period: self.period,
149 };
150 let i = RsiInput::from_slice(d, p);
151 rsi_with_kernel(&i, self.kernel)
152 }
153 #[inline(always)]
154 pub fn into_stream(self) -> Result<RsiStream, RsiError> {
155 let p = RsiParams {
156 period: self.period,
157 };
158 RsiStream::try_new(p)
159 }
160}
161
162#[derive(Debug, Error)]
163pub enum RsiError {
164 #[error("rsi: Input data slice is empty.")]
165 EmptyInputData,
166 #[error("rsi: All values are NaN.")]
167 AllValuesNaN,
168 #[error("rsi: Invalid period: period = {period}, data length = {data_len}")]
169 InvalidPeriod { period: usize, data_len: usize },
170 #[error("rsi: Not enough valid data: needed = {needed}, valid = {valid}")]
171 NotEnoughValidData { needed: usize, valid: usize },
172 #[error("rsi: Output length mismatch: expected {expected}, got {got}")]
173 OutputLengthMismatch { expected: usize, got: usize },
174 #[error("rsi: Invalid range: start = {start}, end = {end}, step = {step}")]
175 InvalidRange {
176 start: usize,
177 end: usize,
178 step: usize,
179 },
180 #[error("rsi: Invalid kernel for batch: {0:?}")]
181 InvalidKernelForBatch(Kernel),
182}
183
184#[inline]
185pub fn rsi(input: &RsiInput) -> Result<RsiOutput, RsiError> {
186 rsi_with_kernel(input, Kernel::Auto)
187}
188
189pub fn rsi_with_kernel(input: &RsiInput, kernel: Kernel) -> Result<RsiOutput, RsiError> {
190 let data: &[f64] = match &input.data {
191 RsiData::Candles { candles, source } => source_type(candles, source),
192 RsiData::Slice(sl) => sl,
193 };
194
195 let len = data.len();
196 if len == 0 {
197 return Err(RsiError::EmptyInputData);
198 }
199
200 let first = data
201 .iter()
202 .position(|x| !x.is_nan())
203 .ok_or(RsiError::AllValuesNaN)?;
204 let period = input.get_period();
205
206 if period == 0 || period > len {
207 return Err(RsiError::InvalidPeriod {
208 period,
209 data_len: len,
210 });
211 }
212 if (len - first) < period {
213 return Err(RsiError::NotEnoughValidData {
214 needed: period,
215 valid: len - first,
216 });
217 }
218
219 let chosen = match kernel {
220 Kernel::Auto => Kernel::Scalar,
221 other => other,
222 };
223
224 let mut out = alloc_with_nan_prefix(len, first + period);
225 rsi_compute_into(data, period, first, chosen, &mut out);
226 Ok(RsiOutput { values: out })
227}
228
229#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
230pub fn rsi_into(input: &RsiInput, out: &mut [f64]) -> Result<(), RsiError> {
231 rsi_into_slice(out, input, Kernel::Auto)?;
232
233 let data: &[f64] = match &input.data {
234 RsiData::Candles { candles, source } => source_type(candles, source),
235 RsiData::Slice(sl) => sl,
236 };
237 let first = data
238 .iter()
239 .position(|x| !x.is_nan())
240 .ok_or(RsiError::AllValuesNaN)?;
241 let warmup_end = (first + input.get_period()).min(out.len());
242 for v in &mut out[..warmup_end] {
243 *v = f64::from_bits(0x7ff8_0000_0000_0000);
244 }
245
246 Ok(())
247}
248
249#[inline(always)]
250fn rsi_compute_into(data: &[f64], period: usize, first: usize, kernel: Kernel, out: &mut [f64]) {
251 unsafe {
252 match kernel {
253 Kernel::Scalar | Kernel::ScalarBatch => {
254 rsi_compute_into_scalar(data, period, first, out)
255 }
256 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
257 Kernel::Avx2 | Kernel::Avx2Batch => rsi_compute_into_scalar(data, period, first, out),
258 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
259 Kernel::Avx2 | Kernel::Avx2Batch => rsi_compute_into_scalar(data, period, first, out),
260 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
261 Kernel::Avx512 | Kernel::Avx512Batch => {
262 rsi_compute_into_scalar(data, period, first, out)
263 }
264 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
265 Kernel::Avx512 | Kernel::Avx512Batch => {
266 rsi_compute_into_scalar(data, period, first, out)
267 }
268 _ => unreachable!(),
269 }
270 }
271}
272
273#[inline]
274pub fn rsi_into_slice(dst: &mut [f64], input: &RsiInput, kern: Kernel) -> Result<(), RsiError> {
275 let data: &[f64] = match &input.data {
276 RsiData::Candles { candles, source } => source_type(candles, source),
277 RsiData::Slice(sl) => sl,
278 };
279
280 let len = data.len();
281 if len == 0 {
282 return Err(RsiError::EmptyInputData);
283 }
284
285 let first = data
286 .iter()
287 .position(|x| !x.is_nan())
288 .ok_or(RsiError::AllValuesNaN)?;
289 let period = input.get_period();
290
291 if period == 0 || period > len {
292 return Err(RsiError::InvalidPeriod {
293 period,
294 data_len: len,
295 });
296 }
297 if (len - first) < period {
298 return Err(RsiError::NotEnoughValidData {
299 needed: period,
300 valid: len - first,
301 });
302 }
303
304 if dst.len() != data.len() {
305 return Err(RsiError::OutputLengthMismatch {
306 expected: data.len(),
307 got: dst.len(),
308 });
309 }
310
311 let chosen = match kern {
312 Kernel::Auto => Kernel::Scalar,
313 other => other,
314 };
315
316 rsi_compute_into(data, period, first, chosen, dst);
317
318 let warmup_end = first + period;
319 for v in &mut dst[..warmup_end] {
320 *v = f64::NAN;
321 }
322
323 Ok(())
324}
325
/// Scalar single-series RSI kernel using Wilder smoothing with weight `1 / period`.
///
/// * Seed: the simple mean of gains and losses over the `period` price changes
///   ending at index `first + period`; the first RSI is written at that index.
/// * Recursion: for every later index, `avg = avg * (1 - 1/period) + part / period`
///   (computed with `f64::mul_add`). `part` is the positive part of the change
///   for gains and the magnitude of the negative part for losses.
/// * RSI is `100 * avg_gain / (avg_gain + avg_loss)`, or 50.0 when both are zero.
///
/// Cells of `out` below `first + period` are never written; callers fill them.
/// A non-finite change inside the seed window makes the seed, and therefore every
/// later output, NaN. After the seed window a NaN change compares false against
/// 0.0, so it counts as neither gain nor loss.
/// NOTE(review): the batch path propagates NaN for such changes instead — confirm
/// which behaviour is intended.
///
/// # Safety
/// The body only uses bounds-checked indexing, so misuse panics rather than causing
/// UB. Callers are expected to pass `out` at least as long as `data`.
#[inline(always)]
unsafe fn rsi_compute_into_scalar(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    let n = data.len();
    let alpha = 1.0 / (period as f64);
    let decay = 1.0 - alpha;

    // RSI from smoothed averages; a window with no movement maps to 50.
    let rsi_of = |gain: f64, loss: f64| -> f64 {
        let total = gain + loss;
        if total == 0.0 {
            50.0
        } else {
            100.0 * gain / total
        }
    };

    // Seed window: price changes at indices first+1 ..= first+period (clamped to the data).
    let seed_idx = first + period;
    let seed_last = seed_idx.min(n.saturating_sub(1));
    let mut sum_up = 0.0f64;
    let mut sum_down = 0.0f64;
    let mut poisoned = false;
    for k in (first + 1)..=seed_last {
        let change = data[k] - data[k - 1];
        if !change.is_finite() {
            poisoned = true;
            break;
        }
        if change > 0.0 {
            sum_up += change;
        } else if change < 0.0 {
            sum_down -= change;
        }
    }

    let (mut avg_up, mut avg_down) = if poisoned {
        (f64::NAN, f64::NAN)
    } else {
        (sum_up * alpha, sum_down * alpha)
    };
    if seed_idx < n {
        out[seed_idx] = if poisoned {
            f64::NAN
        } else {
            rsi_of(avg_up, avg_down)
        };
    }

    // Wilder recursion for every index after the seed.
    for k in (seed_idx + 1)..n {
        let change = data[k] - data[k - 1];
        let up = if change > 0.0 { change } else { 0.0 };
        let down = if change < 0.0 { -change } else { 0.0 };
        avg_up = avg_up.mul_add(decay, alpha * up);
        avg_down = avg_down.mul_add(decay, alpha * down);
        out[k] = rsi_of(avg_up, avg_down);
    }
}
415
/// Inclusive sweep of RSI periods for batch runs, given as `(start, end, step)`.
///
/// A `step` of 0 (or `start == end`) selects the single period `start`.
#[derive(Clone, Debug)]
pub struct RsiBatchRange {
    pub period: (usize, usize, usize),
}

impl Default for RsiBatchRange {
    /// Periods 14 through 263 in steps of 1.
    fn default() -> Self {
        let period = (14, 263, 1);
        RsiBatchRange { period }
    }
}
427#[derive(Clone, Debug, Default)]
428pub struct RsiBatchBuilder {
429 range: RsiBatchRange,
430 kernel: Kernel,
431}
432impl RsiBatchBuilder {
433 pub fn new() -> Self {
434 Self::default()
435 }
436 pub fn kernel(mut self, k: Kernel) -> Self {
437 self.kernel = k;
438 self
439 }
440 #[inline]
441 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
442 self.range.period = (start, end, step);
443 self
444 }
445 #[inline]
446 pub fn period_static(mut self, p: usize) -> Self {
447 self.range.period = (p, p, 0);
448 self
449 }
450 pub fn apply_slice(self, data: &[f64]) -> Result<RsiBatchOutput, RsiError> {
451 rsi_batch_with_kernel(data, &self.range, self.kernel)
452 }
453 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<RsiBatchOutput, RsiError> {
454 RsiBatchBuilder::new().kernel(k).apply_slice(data)
455 }
456 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<RsiBatchOutput, RsiError> {
457 let slice = source_type(c, src);
458 self.apply_slice(slice)
459 }
460 pub fn with_default_candles(c: &Candles) -> Result<RsiBatchOutput, RsiError> {
461 RsiBatchBuilder::new()
462 .kernel(Kernel::Auto)
463 .apply_candles(c, "close")
464 }
465}
466
467pub fn rsi_batch_with_kernel(
468 data: &[f64],
469 sweep: &RsiBatchRange,
470 k: Kernel,
471) -> Result<RsiBatchOutput, RsiError> {
472 let kernel = match k {
473 Kernel::Auto => detect_best_batch_kernel(),
474 other => {
475 if other.is_batch() {
476 other
477 } else {
478 return Err(RsiError::InvalidKernelForBatch(other));
479 }
480 }
481 };
482 let simd = match kernel {
483 Kernel::Avx512Batch => Kernel::Avx512,
484 Kernel::Avx2Batch => Kernel::Avx2,
485 Kernel::ScalarBatch => Kernel::Scalar,
486 _ => unreachable!(),
487 };
488 rsi_batch_par_slice(data, sweep, simd)
489}
490
491#[derive(Clone, Debug)]
492pub struct RsiBatchOutput {
493 pub values: Vec<f64>,
494 pub combos: Vec<RsiParams>,
495 pub rows: usize,
496 pub cols: usize,
497}
498impl RsiBatchOutput {
499 pub fn row_for_params(&self, p: &RsiParams) -> Option<usize> {
500 self.combos
501 .iter()
502 .position(|c| c.period.unwrap_or(14) == p.period.unwrap_or(14))
503 }
504 pub fn values_for(&self, p: &RsiParams) -> Option<&[f64]> {
505 self.row_for_params(p).map(|row| {
506 let start = row * self.cols;
507 &self.values[start..start + self.cols]
508 })
509 }
510}
511
512#[inline(always)]
513fn expand_grid(r: &RsiBatchRange) -> Result<Vec<RsiParams>, RsiError> {
514 fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, RsiError> {
515 if step == 0 || start == end {
516 return Ok(vec![start]);
517 }
518 let (lo, hi) = if start <= end {
519 (start, end)
520 } else {
521 (end, start)
522 };
523 let mut out = Vec::new();
524 let mut v = lo;
525 loop {
526 out.push(v);
527 if v == hi {
528 break;
529 }
530 v = match v.checked_add(step) {
531 Some(next) => next,
532 None => return Err(RsiError::InvalidRange { start, end, step }),
533 };
534 if v > hi {
535 break;
536 }
537 }
538 if out.is_empty() {
539 return Err(RsiError::InvalidRange { start, end, step });
540 }
541 Ok(out)
542 }
543 let periods = axis_usize(r.period)?;
544 let mut out = Vec::with_capacity(periods.len());
545 for &p in &periods {
546 out.push(RsiParams { period: Some(p) });
547 }
548 Ok(out)
549}
550
551#[inline(always)]
552pub fn rsi_batch_slice(
553 data: &[f64],
554 sweep: &RsiBatchRange,
555 kern: Kernel,
556) -> Result<RsiBatchOutput, RsiError> {
557 rsi_batch_inner(data, sweep, kern, false)
558}
559#[inline(always)]
560pub fn rsi_batch_par_slice(
561 data: &[f64],
562 sweep: &RsiBatchRange,
563 kern: Kernel,
564) -> Result<RsiBatchOutput, RsiError> {
565 rsi_batch_inner(data, sweep, kern, true)
566}
567
568#[inline(always)]
569fn rsi_batch_inner(
570 data: &[f64],
571 sweep: &RsiBatchRange,
572 kern: Kernel,
573 parallel: bool,
574) -> Result<RsiBatchOutput, RsiError> {
575 if data.is_empty() {
576 return Err(RsiError::EmptyInputData);
577 }
578 let combos = expand_grid(sweep)?;
579 let first = data
580 .iter()
581 .position(|x| !x.is_nan())
582 .ok_or(RsiError::AllValuesNaN)?;
583 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
584 if data.len() - first < max_p {
585 return Err(RsiError::NotEnoughValidData {
586 needed: max_p,
587 valid: data.len() - first,
588 });
589 }
590 let rows = combos.len();
591 let cols = data.len();
592 let _expected = rows.checked_mul(cols).ok_or(RsiError::InvalidRange {
593 start: rows,
594 end: cols,
595 step: 1,
596 })?;
597
598 let mut buf_mu = make_uninit_matrix(rows, cols);
599
600 let warmup_periods: Vec<usize> = combos.iter().map(|c| first + c.period.unwrap()).collect();
601 init_matrix_prefixes(&mut buf_mu, cols, &warmup_periods);
602
603 let mut buf_guard = core::mem::ManuallyDrop::new(buf_mu);
604 let values: &mut [f64] = unsafe {
605 core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
606 };
607
608 let mut gains = vec![0.0f64; cols];
609 let mut losses = vec![0.0f64; cols];
610 for i in (first + 1)..cols {
611 let d = data[i] - data[i - 1];
612 if d.is_finite() {
613 if d > 0.0 {
614 gains[i] = d;
615 } else if d < 0.0 {
616 losses[i] = -d;
617 }
618 } else {
619 gains[i] = f64::NAN;
620 losses[i] = f64::NAN;
621 }
622 }
623 let mut pg = vec![0.0f64; cols];
624 let mut pl = vec![0.0f64; cols];
625 for i in 1..cols {
626 pg[i] = pg[i - 1] + gains[i];
627 pl[i] = pl[i - 1] + losses[i];
628 }
629
630 let do_row = |row: usize, out_row: &mut [f64]| unsafe {
631 let period = combos[row].period.unwrap();
632 match kern {
633 Kernel::Scalar | Kernel::Avx2 | Kernel::Avx512 => {
634 let inv_p = 1.0 / (period as f64);
635 let beta = 1.0 - inv_p;
636 let idx0 = first + period;
637 if idx0 < cols {
638 let sum_g = pg[idx0] - pg[first];
639 let sum_l = pl[idx0] - pl[first];
640 let mut avg_g = sum_g * inv_p;
641 let mut avg_l = sum_l * inv_p;
642 if sum_g.is_nan() || sum_l.is_nan() {
643 avg_g = f64::NAN;
644 avg_l = f64::NAN;
645 out_row[idx0] = f64::NAN;
646 } else {
647 let denom = avg_g + avg_l;
648 out_row[idx0] = if denom == 0.0 {
649 50.0
650 } else {
651 100.0 * avg_g / denom
652 };
653 }
654 let mut j = idx0 + 1;
655 while j + 1 < cols {
656 let g1 = gains[j];
657 let l1 = losses[j];
658 avg_g = avg_g.mul_add(beta, inv_p * g1);
659 avg_l = avg_l.mul_add(beta, inv_p * l1);
660 let denom1 = avg_g + avg_l;
661 out_row[j] = if denom1 == 0.0 {
662 50.0
663 } else {
664 100.0 * avg_g / denom1
665 };
666
667 let g2 = gains[j + 1];
668 let l2 = losses[j + 1];
669 avg_g = avg_g.mul_add(beta, inv_p * g2);
670 avg_l = avg_l.mul_add(beta, inv_p * l2);
671 let denom2 = avg_g + avg_l;
672 out_row[j + 1] = if denom2 == 0.0 {
673 50.0
674 } else {
675 100.0 * avg_g / denom2
676 };
677 j += 2;
678 }
679 if j < cols {
680 let g = gains[j];
681 let l = losses[j];
682 avg_g = avg_g.mul_add(beta, inv_p * g);
683 avg_l = avg_l.mul_add(beta, inv_p * l);
684 let denom = avg_g + avg_l;
685 out_row[j] = if denom == 0.0 {
686 50.0
687 } else {
688 100.0 * avg_g / denom
689 };
690 }
691 }
692 }
693 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
694 Kernel::Avx2 => rsi_row_avx2(data, first, period, out_row),
695 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
696 Kernel::Avx512 => rsi_row_avx512(data, first, period, out_row),
697 _ => unreachable!(),
698 }
699 };
700
701 if parallel {
702 #[cfg(not(target_arch = "wasm32"))]
703 {
704 values
705 .par_chunks_mut(cols)
706 .enumerate()
707 .for_each(|(row, slice)| do_row(row, slice));
708 }
709
710 #[cfg(target_arch = "wasm32")]
711 {
712 for (row, slice) in values.chunks_mut(cols).enumerate() {
713 do_row(row, slice);
714 }
715 }
716 } else {
717 for (row, slice) in values.chunks_mut(cols).enumerate() {
718 do_row(row, slice);
719 }
720 }
721
722 let values = unsafe {
723 Vec::from_raw_parts(
724 buf_guard.as_mut_ptr() as *mut f64,
725 buf_guard.len(),
726 buf_guard.capacity(),
727 )
728 };
729
730 Ok(RsiBatchOutput {
731 values,
732 combos,
733 rows,
734 cols,
735 })
736}
737
738#[inline(always)]
739pub fn rsi_batch_inner_into(
740 data: &[f64],
741 sweep: &RsiBatchRange,
742 kern: Kernel,
743 parallel: bool,
744 out: &mut [f64],
745) -> Result<Vec<RsiParams>, RsiError> {
746 if data.is_empty() {
747 return Err(RsiError::EmptyInputData);
748 }
749 let combos = expand_grid(sweep)?;
750 let first = data
751 .iter()
752 .position(|x| !x.is_nan())
753 .ok_or(RsiError::AllValuesNaN)?;
754 let rows = combos.len();
755 let cols = data.len();
756
757 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
758 if cols - first < max_p {
759 return Err(RsiError::NotEnoughValidData {
760 needed: max_p,
761 valid: cols - first,
762 });
763 }
764 let expected = rows.checked_mul(cols).ok_or(RsiError::InvalidRange {
765 start: rows,
766 end: cols,
767 step: 1,
768 })?;
769 if out.len() != expected {
770 return Err(RsiError::OutputLengthMismatch {
771 expected,
772 got: out.len(),
773 });
774 }
775
776 let out_mu: &mut [MaybeUninit<f64>] = unsafe {
777 core::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
778 };
779 let warm: Vec<usize> = combos.iter().map(|c| first + c.period.unwrap()).collect();
780 init_matrix_prefixes(out_mu, cols, &warm);
781
782 let values: &mut [f64] =
783 unsafe { core::slice::from_raw_parts_mut(out_mu.as_mut_ptr() as *mut f64, out_mu.len()) };
784
785 let mut gains = vec![0.0f64; cols];
786 let mut losses = vec![0.0f64; cols];
787 for i in (first + 1)..cols {
788 let d = data[i] - data[i - 1];
789 if d.is_finite() {
790 if d > 0.0 {
791 gains[i] = d;
792 } else if d < 0.0 {
793 losses[i] = -d;
794 }
795 } else {
796 gains[i] = f64::NAN;
797 losses[i] = f64::NAN;
798 }
799 }
800 let mut pg = vec![0.0f64; cols];
801 let mut pl = vec![0.0f64; cols];
802 for i in 1..cols {
803 pg[i] = pg[i - 1] + gains[i];
804 pl[i] = pl[i - 1] + losses[i];
805 }
806
807 let do_row = |row: usize, out_row: &mut [f64]| unsafe {
808 let period = combos[row].period.unwrap();
809 match kern {
810 Kernel::Scalar | Kernel::Avx2 | Kernel::Avx512 => {
811 let inv_p = 1.0 / (period as f64);
812 let beta = 1.0 - inv_p;
813 let idx0 = first + period;
814 if idx0 < cols {
815 let sum_g = pg[idx0] - pg[first];
816 let sum_l = pl[idx0] - pl[first];
817 let mut avg_g = sum_g * inv_p;
818 let mut avg_l = sum_l * inv_p;
819 if sum_g.is_nan() || sum_l.is_nan() {
820 avg_g = f64::NAN;
821 avg_l = f64::NAN;
822 out_row[idx0] = f64::NAN;
823 } else {
824 let denom = avg_g + avg_l;
825 out_row[idx0] = if denom == 0.0 {
826 50.0
827 } else {
828 100.0 * avg_g / denom
829 };
830 }
831 let mut j = idx0 + 1;
832 while j + 1 < cols {
833 let g1 = gains[j];
834 let l1 = losses[j];
835 avg_g = avg_g.mul_add(beta, inv_p * g1);
836 avg_l = avg_l.mul_add(beta, inv_p * l1);
837 let denom1 = avg_g + avg_l;
838 out_row[j] = if denom1 == 0.0 {
839 50.0
840 } else {
841 100.0 * avg_g / denom1
842 };
843
844 let g2 = gains[j + 1];
845 let l2 = losses[j + 1];
846 avg_g = avg_g.mul_add(beta, inv_p * g2);
847 avg_l = avg_l.mul_add(beta, inv_p * l2);
848 let denom2 = avg_g + avg_l;
849 out_row[j + 1] = if denom2 == 0.0 {
850 50.0
851 } else {
852 100.0 * avg_g / denom2
853 };
854 j += 2;
855 }
856 if j < cols {
857 let g = gains[j];
858 let l = losses[j];
859 avg_g = avg_g.mul_add(beta, inv_p * g);
860 avg_l = avg_l.mul_add(beta, inv_p * l);
861 let denom = avg_g + avg_l;
862 out_row[j] = if denom == 0.0 {
863 50.0
864 } else {
865 100.0 * avg_g / denom
866 };
867 }
868 }
869 }
870 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
871 Kernel::Avx2 => rsi_row_avx2(data, first, period, out_row),
872 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
873 Kernel::Avx512 => rsi_row_avx512(data, first, period, out_row),
874 _ => unreachable!(),
875 }
876 };
877
878 if parallel {
879 #[cfg(not(target_arch = "wasm32"))]
880 {
881 values
882 .par_chunks_mut(cols)
883 .enumerate()
884 .for_each(|(row, slice)| do_row(row, slice));
885 }
886
887 #[cfg(target_arch = "wasm32")]
888 {
889 for (row, slice) in values.chunks_mut(cols).enumerate() {
890 do_row(row, slice);
891 }
892 }
893 } else {
894 for (row, slice) in values.chunks_mut(cols).enumerate() {
895 do_row(row, slice);
896 }
897 }
898
899 Ok(combos)
900}
901
/// Scalar batch-row RSI kernel (note the argument order: `first`, then `period`).
///
/// Seeds the average gain/loss with the mean of the `period` price changes ending
/// at `first + period` and writes the first RSI there. It then applies Wilder
/// smoothing (`avg.mul_add(1 - 1/period, part / period)`) at every later index.
/// Cells below `first + period` are not written. A non-finite change in the seed
/// window turns the seed and every later output into NaN. After the seed, a NaN
/// change counts as zero gain and zero loss. A zero combined average yields 50.0.
///
/// # Safety
/// Only bounds-checked indexing is used, so out-of-range access panics rather than
/// causing UB.
#[inline(always)]
unsafe fn rsi_row_scalar(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    fn rsi_from(up: f64, down: f64) -> f64 {
        let both = up + down;
        if both == 0.0 {
            50.0
        } else {
            100.0 * up / both
        }
    }

    let len = data.len();
    let weight = 1.0 / (period as f64);
    let keep = 1.0 - weight;
    let anchor = first + period;

    // Seed window: changes at indices first+1 ..= anchor, clamped to the data.
    let last_seed = anchor.min(len.saturating_sub(1));
    let mut up_total = 0.0f64;
    let mut down_total = 0.0f64;
    let mut clean = true;
    let mut idx = first + 1;
    while clean && idx <= last_seed {
        let change = data[idx] - data[idx - 1];
        if change.is_finite() {
            if change > 0.0 {
                up_total += change;
            } else if change < 0.0 {
                down_total -= change;
            }
            idx += 1;
        } else {
            clean = false;
        }
    }

    let (mut avg_up, mut avg_down) = if clean {
        (up_total * weight, down_total * weight)
    } else {
        (f64::NAN, f64::NAN)
    };
    if anchor < len {
        out[anchor] = if clean {
            rsi_from(avg_up, avg_down)
        } else {
            f64::NAN
        };
    }

    // Each window (data[k - 1], data[k]) drives output index k, for k in anchor+1..len.
    if anchor + 1 < len {
        for (offset, pair) in data[anchor..].windows(2).enumerate() {
            let k = anchor + 1 + offset;
            let change = pair[1] - pair[0];
            let (up, down) = if change > 0.0 {
                (change, 0.0)
            } else if change < 0.0 {
                (0.0, -change)
            } else {
                (0.0, 0.0)
            };
            avg_up = avg_up.mul_add(keep, weight * up);
            avg_down = avg_down.mul_add(keep, weight * down);
            out[k] = rsi_from(avg_up, avg_down);
        }
    }
}
/// AVX2 batch-row RSI kernel; currently delegates to `rsi_row_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn rsi_row_avx2(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    rsi_row_scalar(data, first, period, out);
}

/// AVX-512 batch-row RSI kernel: periods up to 32 go to the "short" variant and
/// longer ones to the "long" variant (both currently delegate to the scalar kernel).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn rsi_row_avx512(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    match period {
        0..=32 => rsi_row_avx512_short(data, first, period, out),
        _ => rsi_row_avx512_long(data, first, period, out),
    }
}

/// AVX-512 short-period variant; currently delegates to `rsi_row_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn rsi_row_avx512_short(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    rsi_row_scalar(data, first, period, out);
}

/// AVX-512 long-period variant; currently delegates to `rsi_row_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn rsi_row_avx512_long(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    rsi_row_scalar(data, first, period, out);
}
1014
1015#[derive(Debug, Clone)]
1016pub struct RsiStream {
1017 period: usize,
1018 inv_p: f64,
1019 beta: f64,
1020
1021 has_prev: bool,
1022 prev: f64,
1023
1024 seed_count: usize,
1025 sum_gain: f64,
1026 sum_loss: f64,
1027 poisoned: bool,
1028
1029 avg_gain: f64,
1030 avg_loss: f64,
1031 seeded: bool,
1032}
1033impl RsiStream {
1034 #[inline(always)]
1035 pub fn try_new(params: RsiParams) -> Result<Self, RsiError> {
1036 let period = params.period.unwrap_or(14);
1037 if period == 0 {
1038 return Err(RsiError::InvalidPeriod {
1039 period,
1040 data_len: 0,
1041 });
1042 }
1043 let inv_p = 1.0 / (period as f64);
1044 Ok(Self {
1045 period,
1046 inv_p,
1047 beta: 1.0 - inv_p,
1048
1049 has_prev: false,
1050 prev: f64::NAN,
1051
1052 seed_count: 0,
1053 sum_gain: 0.0,
1054 sum_loss: 0.0,
1055 poisoned: false,
1056
1057 avg_gain: 0.0,
1058 avg_loss: 0.0,
1059 seeded: false,
1060 })
1061 }
1062
1063 #[inline(always)]
1064 pub fn update(&mut self, value: f64) -> Option<f64> {
1065 if !self.has_prev {
1066 self.prev = value;
1067 self.has_prev = true;
1068 return None;
1069 }
1070
1071 let delta = value - self.prev;
1072 self.prev = value;
1073
1074 if !self.seeded {
1075 if !delta.is_finite() {
1076 self.poisoned = true;
1077 }
1078
1079 let gain = delta.max(0.0);
1080 let loss = (-delta).max(0.0);
1081
1082 self.sum_gain += gain;
1083 self.sum_loss += loss;
1084 self.seed_count += 1;
1085
1086 if self.seed_count == self.period {
1087 self.seeded = true;
1088 if self.poisoned {
1089 self.avg_gain = f64::NAN;
1090 self.avg_loss = f64::NAN;
1091 return Some(f64::NAN);
1092 } else {
1093 self.avg_gain = self.sum_gain * self.inv_p;
1094 self.avg_loss = self.sum_loss * self.inv_p;
1095 let denom = self.avg_gain + self.avg_loss;
1096 let rsi = if denom == 0.0 {
1097 50.0
1098 } else {
1099 100.0 * self.avg_gain / denom
1100 };
1101 return Some(rsi);
1102 }
1103 } else {
1104 return None;
1105 }
1106 }
1107
1108 let gain = delta.max(0.0);
1109 let loss = (-delta).max(0.0);
1110
1111 self.avg_gain = self.avg_gain.mul_add(self.beta, self.inv_p * gain);
1112 self.avg_loss = self.avg_loss.mul_add(self.beta, self.inv_p * loss);
1113 let denom = self.avg_gain + self.avg_loss;
1114 Some(if denom == 0.0 {
1115 50.0
1116 } else {
1117 100.0 * self.avg_gain / denom
1118 })
1119 }
1120}
1121
1122#[cfg(test)]
1123mod tests {
1124 use super::*;
1125 use crate::skip_if_unsupported;
1126 use crate::utilities::data_loader::read_candles_from_csv;
1127
1128 fn check_rsi_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1129 skip_if_unsupported!(kernel, test_name);
1130 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1131 let candles = read_candles_from_csv(file_path)?;
1132 let partial_params = RsiParams { period: None };
1133 let input = RsiInput::from_candles(&candles, "close", partial_params);
1134 let result = rsi_with_kernel(&input, kernel)?;
1135 assert_eq!(result.values.len(), candles.close.len());
1136 Ok(())
1137 }
1138 fn check_rsi_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1139 skip_if_unsupported!(kernel, test_name);
1140 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1141 let candles = read_candles_from_csv(file_path)?;
1142 let input = RsiInput::from_candles(&candles, "close", RsiParams { period: Some(14) });
1143 let result = rsi_with_kernel(&input, kernel)?;
1144 let expected_last_five = [43.42, 42.68, 41.62, 42.86, 39.01];
1145 let start = result.values.len().saturating_sub(5);
1146 for (i, &val) in result.values[start..].iter().enumerate() {
1147 let diff = (val - expected_last_five[i]).abs();
1148 assert!(
1149 diff < 1e-2,
1150 "[{}] RSI {:?} mismatch at idx {}: got {}, expected {}",
1151 test_name,
1152 kernel,
1153 i,
1154 val,
1155 expected_last_five[i]
1156 );
1157 }
1158 Ok(())
1159 }
1160 fn check_rsi_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1161 skip_if_unsupported!(kernel, test_name);
1162 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1163 let candles = read_candles_from_csv(file_path)?;
1164 let input = RsiInput::with_default_candles(&candles);
1165 match input.data {
1166 RsiData::Candles { source, .. } => assert_eq!(source, "close"),
1167 _ => panic!("Expected RsiData::Candles"),
1168 }
1169 let output = rsi_with_kernel(&input, kernel)?;
1170 assert_eq!(output.values.len(), candles.close.len());
1171 Ok(())
1172 }
1173 fn check_rsi_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1174 skip_if_unsupported!(kernel, test_name);
1175 let input_data = [10.0, 20.0, 30.0];
1176 let params = RsiParams { period: Some(0) };
1177 let input = RsiInput::from_slice(&input_data, params);
1178 let res = rsi_with_kernel(&input, kernel);
1179 assert!(
1180 res.is_err(),
1181 "[{}] RSI should fail with zero period",
1182 test_name
1183 );
1184 Ok(())
1185 }
1186 fn check_rsi_period_exceeds_length(
1187 test_name: &str,
1188 kernel: Kernel,
1189 ) -> Result<(), Box<dyn Error>> {
1190 skip_if_unsupported!(kernel, test_name);
1191 let data_small = [10.0, 20.0, 30.0];
1192 let params = RsiParams { period: Some(10) };
1193 let input = RsiInput::from_slice(&data_small, params);
1194 let res = rsi_with_kernel(&input, kernel);
1195 assert!(
1196 res.is_err(),
1197 "[{}] RSI should fail with period exceeding length",
1198 test_name
1199 );
1200 Ok(())
1201 }
1202 fn check_rsi_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1203 skip_if_unsupported!(kernel, test_name);
1204 let single_point = [42.0];
1205 let params = RsiParams { period: Some(14) };
1206 let input = RsiInput::from_slice(&single_point, params);
1207 let res = rsi_with_kernel(&input, kernel);
1208 assert!(
1209 res.is_err(),
1210 "[{}] RSI should fail with insufficient data",
1211 test_name
1212 );
1213 Ok(())
1214 }
1215 fn check_rsi_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1216 skip_if_unsupported!(kernel, test_name);
1217 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1218 let candles = read_candles_from_csv(file_path)?;
1219 let first_params = RsiParams { period: Some(14) };
1220 let first_input = RsiInput::from_candles(&candles, "close", first_params);
1221 let first_result = rsi_with_kernel(&first_input, kernel)?;
1222 let second_params = RsiParams { period: Some(5) };
1223 let second_input = RsiInput::from_slice(&first_result.values, second_params);
1224 let second_result = rsi_with_kernel(&second_input, kernel)?;
1225 assert_eq!(second_result.values.len(), first_result.values.len());
1226 if second_result.values.len() > 240 {
1227 for i in 240..second_result.values.len() {
1228 assert!(
1229 !second_result.values[i].is_nan(),
1230 "Found NaN in RSI at {}",
1231 i
1232 );
1233 }
1234 }
1235 Ok(())
1236 }
1237 fn check_rsi_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1238 skip_if_unsupported!(kernel, test_name);
1239 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1240 let candles = read_candles_from_csv(file_path)?;
1241 let input = RsiInput::from_candles(&candles, "close", RsiParams { period: Some(14) });
1242 let res = rsi_with_kernel(&input, kernel)?;
1243 assert_eq!(res.values.len(), candles.close.len());
1244 if res.values.len() > 240 {
1245 for (i, &val) in res.values[240..].iter().enumerate() {
1246 assert!(
1247 !val.is_nan(),
1248 "[{}] Found unexpected NaN at out-index {}",
1249 test_name,
1250 240 + i
1251 );
1252 }
1253 }
1254 Ok(())
1255 }
1256 fn check_rsi_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1257 skip_if_unsupported!(kernel, test_name);
1258 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1259 let candles = read_candles_from_csv(file_path)?;
1260 let period = 14;
1261 let input = RsiInput::from_candles(
1262 &candles,
1263 "close",
1264 RsiParams {
1265 period: Some(period),
1266 },
1267 );
1268 let batch_output = rsi_with_kernel(&input, kernel)?.values;
1269
1270 let mut stream = RsiStream::try_new(RsiParams {
1271 period: Some(period),
1272 })?;
1273 let mut stream_values = Vec::with_capacity(candles.close.len());
1274 for &price in &candles.close {
1275 match stream.update(price) {
1276 Some(rsi_val) => stream_values.push(rsi_val),
1277 None => stream_values.push(f64::NAN),
1278 }
1279 }
1280 assert_eq!(batch_output.len(), stream_values.len());
1281 for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
1282 if b.is_nan() && s.is_nan() {
1283 continue;
1284 }
1285 let diff = (b - s).abs();
1286 assert!(
1287 diff < 1e-6,
1288 "[{}] RSI streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
1289 test_name,
1290 i,
1291 b,
1292 s,
1293 diff
1294 );
1295 }
1296 Ok(())
1297 }
1298
1299 #[cfg(debug_assertions)]
1300 fn check_rsi_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1301 skip_if_unsupported!(kernel, test_name);
1302
1303 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1304 let candles = read_candles_from_csv(file_path)?;
1305
1306 let test_params = vec![
1307 RsiParams::default(),
1308 RsiParams { period: Some(2) },
1309 RsiParams { period: Some(5) },
1310 RsiParams { period: Some(7) },
1311 RsiParams { period: Some(10) },
1312 RsiParams { period: Some(14) },
1313 RsiParams { period: Some(20) },
1314 RsiParams { period: Some(30) },
1315 RsiParams { period: Some(50) },
1316 RsiParams { period: Some(100) },
1317 RsiParams { period: Some(200) },
1318 ];
1319
1320 for (param_idx, params) in test_params.iter().enumerate() {
1321 let input = RsiInput::from_candles(&candles, "close", params.clone());
1322 let output = rsi_with_kernel(&input, kernel)?;
1323
1324 for (i, &val) in output.values.iter().enumerate() {
1325 if val.is_nan() {
1326 continue;
1327 }
1328
1329 let bits = val.to_bits();
1330
1331 if bits == 0x11111111_11111111 {
1332 panic!(
1333 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1334 with params: period={} (param set {})",
1335 test_name,
1336 val,
1337 bits,
1338 i,
1339 params.period.unwrap_or(14),
1340 param_idx
1341 );
1342 }
1343
1344 if bits == 0x22222222_22222222 {
1345 panic!(
1346 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1347 with params: period={} (param set {})",
1348 test_name,
1349 val,
1350 bits,
1351 i,
1352 params.period.unwrap_or(14),
1353 param_idx
1354 );
1355 }
1356
1357 if bits == 0x33333333_33333333 {
1358 panic!(
1359 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1360 with params: period={} (param set {})",
1361 test_name,
1362 val,
1363 bits,
1364 i,
1365 params.period.unwrap_or(14),
1366 param_idx
1367 );
1368 }
1369 }
1370 }
1371
1372 Ok(())
1373 }
1374
1375 #[cfg(not(debug_assertions))]
1376 fn check_rsi_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1377 Ok(())
1378 }
1379
    /// Property test for RSI over random price series.
    ///
    /// The strategy draws a period in `2..=100` and a series of `period + 10..400` finite
    /// prices from `(-1e6, 1e6)` with `|x| > 1e-10`. For each case it checks that:
    /// - every non-NaN output lies in `[0, 100]`;
    /// - indices `0..first_valid + period` are NaN and index `first_valid + period` is not;
    /// - the kernel under test agrees with `Kernel::Scalar` within `1e-9`;
    /// - constant input gives 50, and strictly monotone input gives an RSI on the matching
    ///   extreme side;
    /// - (debug builds) no allocation-poison bit patterns leak into the output;
    /// - sign-alternating input keeps the last quarter of outputs in `[35, 65]`;
    /// - the value at `warmup_end + 3` matches a hand-rolled Wilder-smoothed RSI.
    ///
    /// Failures are reported through `prop_assert!`, i.e. as an `Err` returned from
    /// `TestRunner::run` and propagated by `?`, not as a panic. Callers must therefore not
    /// discard the returned `Result`.
    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_rsi_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        // Prices are filtered to be finite, so `first_valid` below is always 0.
        let strat = (2usize..=100).prop_flat_map(|period| {
            (
                prop::collection::vec(
                    (-1e6f64..1e6f64)
                        .prop_filter("finite price", |x| x.is_finite() && x.abs() > 1e-10),
                    period + 10..400,
                ),
                Just(period),
            )
        });

        proptest::test_runner::TestRunner::default().run(&strat, |(data, period)| {
            let params = RsiParams {
                period: Some(period),
            };
            let input = RsiInput::from_slice(&data, params);

            let RsiOutput { values: out } = rsi_with_kernel(&input, kernel)?;

            // Scalar output is the reference for the cross-kernel comparison below.
            let RsiOutput { values: ref_out } = rsi_with_kernel(&input, Kernel::Scalar)?;

            let first_valid = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
            let warmup_end = first_valid + period;

            // Range: RSI is a percentage.
            for (i, &val) in out.iter().enumerate() {
                if !val.is_nan() {
                    prop_assert!(
                        val >= 0.0 && val <= 100.0,
                        "[{}] RSI value {} at index {} is out of range [0, 100]",
                        test_name,
                        val,
                        i
                    );
                }
            }

            // Warmup: everything before `warmup_end` is NaN ...
            for i in 0..warmup_end.min(out.len()) {
                prop_assert!(
                    out[i].is_nan(),
                    "[{}] Expected NaN during warmup at index {}, got {}",
                    test_name,
                    i,
                    out[i]
                );
            }

            // ... and the first value is emitted exactly at `warmup_end`.
            if warmup_end < out.len() {
                prop_assert!(
                    !out[warmup_end].is_nan(),
                    "[{}] Expected non-NaN at index {} (warmup_end), got NaN",
                    test_name,
                    warmup_end
                );
            }

            // Cross-kernel parity (both-NaN positions are treated as equal).
            for i in 0..out.len() {
                let y = out[i];
                let r = ref_out[i];

                if y.is_nan() && r.is_nan() {
                    continue;
                }

                prop_assert!(
                    (y - r).abs() < 1e-9,
                    "[{}] Kernel mismatch at index {}: {} vs {} (diff: {})",
                    test_name,
                    i,
                    y,
                    r,
                    (y - r).abs()
                );
            }

            // Constant input: only reached when all successive differences are below 1e-12,
            // which is rare under this strategy.
            if data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-12) && warmup_end < out.len() {
                for i in warmup_end..out.len() {
                    prop_assert!(
                        (out[i] - 50.0).abs() < 1e-9,
                        "[{}] Constant prices should yield RSI=50, got {} at index {}",
                        test_name,
                        out[i],
                        i
                    );
                }
            }

            // Monotone input: heuristic thresholds, lower (looser) for short periods.
            let strictly_increasing = data.windows(2).all(|w| w[1] > w[0] + 1e-10);
            if strictly_increasing && out.len() > warmup_end + 10 {
                let last_rsi = out[out.len() - 1];
                let high_threshold = if period <= 5 {
                    60.0
                } else if period <= 20 {
                    65.0
                } else {
                    70.0
                };
                prop_assert!(
                    last_rsi > high_threshold,
                    "[{}] Strictly increasing prices should yield RSI > {} (period={}), got {}",
                    test_name,
                    high_threshold,
                    period,
                    last_rsi
                );
            }

            let strictly_decreasing = data.windows(2).all(|w| w[1] < w[0] - 1e-10);
            if strictly_decreasing && out.len() > warmup_end + 10 {
                let last_rsi = out[out.len() - 1];
                let low_threshold = if period <= 5 {
                    40.0
                } else if period <= 20 {
                    35.0
                } else {
                    30.0
                };
                prop_assert!(
                    last_rsi < low_threshold,
                    "[{}] Strictly decreasing prices should yield RSI < {} (period={}), got {}",
                    test_name,
                    low_threshold,
                    period,
                    last_rsi
                );
            }

            // Debug builds: no allocation-helper poison patterns may survive into the output.
            #[cfg(debug_assertions)]
            {
                for (i, &val) in out.iter().enumerate() {
                    if val.is_nan() {
                        continue;
                    }

                    let bits = val.to_bits();
                    prop_assert!(
                        bits != 0x11111111_11111111
                            && bits != 0x22222222_22222222
                            && bits != 0x33333333_33333333,
                        "[{}] Found poison value {} (0x{:016X}) at index {}",
                        test_name,
                        val,
                        bits,
                        i
                    );
                }
            }

            // Detect strictly sign-alternating deltas. A zero delta skips the comparison,
            // so the step after a flat step is not compared against it.
            let mut oscillating = true;
            let mut prev_delta = 0.0;
            for window in data.windows(2) {
                let delta = window[1] - window[0];
                if prev_delta != 0.0 && delta != 0.0 {
                    if (delta > 0.0 && prev_delta > 0.0) || (delta < 0.0 && prev_delta < 0.0) {
                        oscillating = false;
                        break;
                    }
                }
                prev_delta = delta;
            }

            // NOTE(review): the [35, 65] band is a heuristic, not a guaranteed bound. With
            // period 2 and equal-magnitude alternating moves, the smoothing recurrence used
            // further below settles at 100 * (2/3) ≈ 66.7 right after each up-move, and random
            // magnitudes can push it further out. The check likely fires only rarely, because
            // sign-alternating random series are uncommon at these lengths.
            if oscillating && out.len() > warmup_end + 10 && prev_delta != 0.0 {
                let last_quarter_start = out.len() - (out.len() - warmup_end) / 4;
                for i in last_quarter_start..out.len() {
                    if !out[i].is_nan() {
                        prop_assert!(
                            out[i] >= 35.0 && out[i] <= 65.0,
                            "[{}] Oscillating prices should keep RSI in [35, 65] range, got {} at index {}",
                            test_name, out[i], i
                        );
                    }
                }
            }

            // Reference RSI at `warmup_end + 3`. The averages are seeded with the simple mean
            // of the first `period` gains/losses, then smoothed with alpha = 1/period.
            if warmup_end + 5 < out.len() {
                let idx = warmup_end + 3;
                let mut avg_gain = 0.0;
                let mut avg_loss = 0.0;

                for j in (first_valid + 1)..=(first_valid + period) {
                    let delta = data[j] - data[j - 1];
                    if delta > 0.0 {
                        avg_gain += delta;
                    } else {
                        avg_loss += -delta;
                    }
                }
                avg_gain /= period as f64;
                avg_loss /= period as f64;

                let inv_period = 1.0 / period as f64;
                let beta = 1.0 - inv_period;
                for j in (first_valid + period + 1)..=idx {
                    let delta = data[j] - data[j - 1];
                    let gain = if delta > 0.0 { delta } else { 0.0 };
                    let loss = if delta < 0.0 { -delta } else { 0.0 };
                    avg_gain = inv_period * gain + beta * avg_gain;
                    avg_loss = inv_period * loss + beta * avg_loss;
                }

                // No movement at all maps to the neutral value 50.
                let expected_rsi = if avg_gain + avg_loss == 0.0 {
                    50.0
                } else {
                    100.0 * avg_gain / (avg_gain + avg_loss)
                };

                prop_assert!(
                    (out[idx] - expected_rsi).abs() < 1e-9,
                    "[{}] RSI calculation mismatch at index {}: got {}, expected {}",
                    test_name,
                    idx,
                    out[idx],
                    expected_rsi
                );
            }

            Ok(())
        })?;

        Ok(())
    }
1609
1610 macro_rules! generate_all_rsi_tests {
1611 ($($test_fn:ident),*) => {
1612 paste! {
1613 $(
1614 #[test]
1615 fn [<$test_fn _scalar_f64>]() {
1616 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1617 }
1618 )*
1619 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1620 $(
1621 #[test]
1622 fn [<$test_fn _avx2_f64>]() {
1623 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1624 }
1625 #[test]
1626 fn [<$test_fn _avx512_f64>]() {
1627 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1628 }
1629 )*
1630 }
1631 }
1632 }
1633
1634 fn check_rsi_error_variants(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1635 skip_if_unsupported!(kernel, test_name);
1636
1637 let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
1638 let mut dst = vec![0.0; 3];
1639 let params = RsiParams { period: Some(2) };
1640 let input = RsiInput::from_slice(&data, params);
1641
1642 match rsi_into_slice(&mut dst, &input, kernel) {
1643 Err(RsiError::OutputLengthMismatch {
1644 expected: 5,
1645 got: 3,
1646 }) => {}
1647 other => panic!(
1648 "[{}] Expected OutputLengthMismatch error, got {:?}",
1649 test_name, other
1650 ),
1651 }
1652
1653 let sweep = RsiBatchRange {
1654 period: (14, 14, 0),
1655 };
1656 match rsi_batch_with_kernel(&data, &sweep, Kernel::Scalar) {
1657 Err(RsiError::InvalidKernelForBatch(Kernel::Scalar)) => {}
1658 other => panic!(
1659 "[{}] Expected InvalidKernelForBatch error, got {:?}",
1660 test_name, other
1661 ),
1662 }
1663
1664 Ok(())
1665 }
1666
    // Per-kernel `#[test]` wrappers (via `generate_all_rsi_tests`) for each check below.
    generate_all_rsi_tests!(
        check_rsi_partial_params,
        check_rsi_accuracy,
        check_rsi_default_candles,
        check_rsi_zero_period,
        check_rsi_period_exceeds_length,
        check_rsi_very_small_dataset,
        check_rsi_reinput,
        check_rsi_nan_handling,
        check_rsi_streaming,
        check_rsi_no_poison,
        check_rsi_error_variants
    );

    // The property test is only compiled when the optional `proptest` feature is enabled.
    #[cfg(feature = "proptest")]
    generate_all_rsi_tests!(check_rsi_property);
1683
1684 fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1685 skip_if_unsupported!(kernel, test);
1686 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1687 let c = read_candles_from_csv(file)?;
1688 let output = RsiBatchBuilder::new()
1689 .kernel(kernel)
1690 .apply_candles(&c, "close")?;
1691 let def = RsiParams::default();
1692 let row = output.values_for(&def).expect("default row missing");
1693 assert_eq!(row.len(), c.close.len());
1694 let expected = [43.42, 42.68, 41.62, 42.86, 39.01];
1695 let start = row.len() - 5;
1696 for (i, &v) in row[start..].iter().enumerate() {
1697 assert!(
1698 (v - expected[i]).abs() < 1e-2,
1699 "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
1700 );
1701 }
1702 Ok(())
1703 }
1704
1705 #[cfg(debug_assertions)]
1706 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1707 skip_if_unsupported!(kernel, test);
1708
1709 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1710 let c = read_candles_from_csv(file)?;
1711
1712 let test_configs = vec![
1713 (2, 10, 2),
1714 (5, 25, 5),
1715 (30, 60, 15),
1716 (2, 5, 1),
1717 (7, 21, 7),
1718 (10, 50, 10),
1719 (14, 28, 14),
1720 ];
1721
1722 for (cfg_idx, &(p_start, p_end, p_step)) in test_configs.iter().enumerate() {
1723 let output = RsiBatchBuilder::new()
1724 .kernel(kernel)
1725 .period_range(p_start, p_end, p_step)
1726 .apply_candles(&c, "close")?;
1727
1728 for (idx, &val) in output.values.iter().enumerate() {
1729 if val.is_nan() {
1730 continue;
1731 }
1732
1733 let bits = val.to_bits();
1734 let row = idx / output.cols;
1735 let col = idx % output.cols;
1736 let combo = &output.combos[row];
1737
1738 if bits == 0x11111111_11111111 {
1739 panic!(
1740 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
1741 at row {} col {} (flat index {}) with params: period={}",
1742 test,
1743 cfg_idx,
1744 val,
1745 bits,
1746 row,
1747 col,
1748 idx,
1749 combo.period.unwrap_or(14)
1750 );
1751 }
1752
1753 if bits == 0x22222222_22222222 {
1754 panic!(
1755 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
1756 at row {} col {} (flat index {}) with params: period={}",
1757 test,
1758 cfg_idx,
1759 val,
1760 bits,
1761 row,
1762 col,
1763 idx,
1764 combo.period.unwrap_or(14)
1765 );
1766 }
1767
1768 if bits == 0x33333333_33333333 {
1769 panic!(
1770 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
1771 at row {} col {} (flat index {}) with params: period={}",
1772 test,
1773 cfg_idx,
1774 val,
1775 bits,
1776 row,
1777 col,
1778 idx,
1779 combo.period.unwrap_or(14)
1780 );
1781 }
1782 }
1783 }
1784
1785 Ok(())
1786 }
1787
1788 #[cfg(not(debug_assertions))]
1789 fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1790 Ok(())
1791 }
1792
1793 macro_rules! gen_batch_tests {
1794 ($fn_name:ident) => {
1795 paste! {
1796 #[test] fn [<$fn_name _scalar>]() {
1797 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
1798 }
1799 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1800 #[test] fn [<$fn_name _avx2>]() {
1801 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
1802 }
1803 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1804 #[test] fn [<$fn_name _avx512>]() {
1805 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
1806 }
1807 #[test] fn [<$fn_name _auto_detect>]() {
1808 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
1809 }
1810 }
1811 };
1812 }
1813 gen_batch_tests!(check_batch_default_row);
1814 gen_batch_tests!(check_batch_no_poison);
1815}
1816
/// Checks that the allocation-free `rsi_into` writes the same values as `rsi`.
#[cfg(test)]
#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
mod into_parity_tests {
    use super::*;
    use crate::utilities::data_loader::read_candles_from_csv;
    use std::error::Error;

    /// Returns `true` when `a` and `b` are exactly equal (which also covers equal
    /// infinities, where `a - b` would be NaN), both NaN, or within `1e-12` of each other.
    #[inline]
    fn eq_or_both_nan(a: f64, b: f64) -> bool {
        a == b || (a.is_nan() && b.is_nan()) || (a - b).abs() <= 1e-12
    }

    /// Default-parameter RSI over the close prices: `rsi_into` must match `rsi` element by
    /// element.
    #[test]
    fn test_rsi_into_matches_api() -> Result<(), Box<dyn Error>> {
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let input = RsiInput::from_candles(&candles, "close", RsiParams::default());

        let baseline = rsi(&input)?.values;

        let mut out = vec![0.0; candles.close.len()];
        rsi_into(&input, &mut out)?;

        assert_eq!(baseline.len(), out.len());
        for (i, (&expected, &actual)) in baseline.iter().zip(out.iter()).enumerate() {
            assert!(
                eq_or_both_nan(expected, actual),
                "rsi_into parity mismatch at {}: {} vs {}",
                i,
                expected,
                actual
            );
        }

        Ok(())
    }
}
1855
#[cfg(feature = "python")]
#[pyfunction(name = "rsi")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn rsi_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    // Single-series RSI for Python. Validates the (non-batch) kernel name, computes with the
    // GIL released, and returns the values as a new 1-D float64 array. Computation errors
    // become `ValueError`.
    let prices = data.as_slice()?;
    let chosen_kernel = validate_kernel(kernel, false)?;
    let input = RsiInput::from_slice(
        prices,
        RsiParams {
            period: Some(period),
        },
    );

    match py.allow_threads(|| rsi_with_kernel(&input, chosen_kernel).map(|o| o.values)) {
        Ok(values) => Ok(values.into_pyarray(py)),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
1881
// Python wrapper (exposed as `RsiStream`) around the incremental `RsiStream`.
#[cfg(feature = "python")]
#[pyclass(name = "RsiStream")]
pub struct RsiStreamPy {
    stream: RsiStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl RsiStreamPy {
    // Builds a stream for `period`; invalid parameters surface as `ValueError`.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        RsiStream::try_new(RsiParams {
            period: Some(period),
        })
        .map(|stream| RsiStreamPy { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    // Pushes one value; returns `None` (Python `None`) while no RSI value is available yet.
    fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
1905
#[cfg(feature = "python")]
#[pyfunction(name = "rsi_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn rsi_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    // Period sweep for Python. Returns a dict with "values" (a rows x cols array, one row
    // per period) and "periods" (the period of each row).
    let prices = data.as_slice()?;
    let requested = validate_kernel(kernel, true)?;
    let sweep = RsiBatchRange {
        period: period_range,
    };

    let rows = expand_grid(&sweep)
        .map_err(|e| PyValueError::new_err(e.to_string()))?
        .len();
    let cols = prices.len();
    let total = match rows.checked_mul(cols) {
        Some(n) => n,
        None => return Err(PyValueError::new_err("rows*cols overflow in rsi_batch_py")),
    };

    // The array is created uninitialized. NOTE(review): this relies on
    // `rsi_batch_inner_into` writing every element before Python can observe it.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let out_slice = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            let batch_kernel = match requested {
                Kernel::Auto => detect_best_batch_kernel(),
                explicit => explicit,
            };
            // Each batch kernel is mapped to its single-row counterpart for the inner loop.
            let row_kernel = match batch_kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                other => other,
            };
            rsi_batch_inner_into(prices, &sweep, row_kernel, true, out_slice)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    let periods: Vec<u64> = combos.iter().map(|p| p.period.unwrap() as u64).collect();
    dict.set_item("periods", periods.into_pyarray(py))?;

    Ok(dict)
}
1964
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "rsi_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, device_id=0))]
pub fn rsi_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    data_f32: numpy::PyReadonlyArray1<'py, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<(DeviceArrayF32Py, Bound<'py, pyo3::types::PyDict>)> {
    use crate::cuda::cuda_available;
    use crate::cuda::oscillators::rsi_wrapper::CudaRsi;
    use numpy::IntoPyArray;
    use pyo3::types::PyDict;

    // Runs the RSI period sweep on the GPU. Returns the device-resident result plus a dict
    // with the "periods" metadata. CUDA unavailability and wrapper errors become `ValueError`.
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let prices = data_f32.as_slice()?;
    let sweep = RsiBatchRange {
        period: period_range,
    };

    let inner = py.allow_threads(|| {
        let cuda = CudaRsi::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        cuda.rsi_batch_dev(prices, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    let dict = PyDict::new(py);
    // Period metadata: just `start` when `step == 0`, otherwise `start, start + step, ...`
    // while `<= end`. `checked_add` ends the walk on overflow. With `saturating_add`, `p`
    // would stick at `usize::MAX` and never exceed `end == usize::MAX`, looping forever.
    // NOTE(review): this expansion is computed independently of the CUDA wrapper; confirm
    // both agree (e.g. for `start > end`).
    let (start, end, step) = period_range;
    let periods: Vec<u64> = if step == 0 {
        vec![start as u64]
    } else {
        std::iter::successors(Some(start), |&p| p.checked_add(step))
            .take_while(|&p| p <= end)
            .map(|p| p as u64)
            .collect()
    };
    dict.set_item("periods", periods.into_pyarray(py))?;

    let handle = make_device_array_py(device_id, inner)?;
    Ok((handle, dict))
}
2011
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "rsi_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, device_id=0))]
pub fn rsi_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    use crate::cuda::cuda_available;
    use crate::cuda::oscillators::rsi_wrapper::CudaRsi;
    use numpy::PyUntypedArrayMethods;

    // Runs one RSI period over many series on the GPU. The input is a contiguous 2-D array,
    // presumably time-major (per the parameter and wrapper names): `shape[0]` is passed as
    // the row count and `shape[1]` as the column count.
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let flat = data_tm_f32.as_slice()?;
    let shape = data_tm_f32.shape();
    let (rows, cols) = (shape[0], shape[1]);

    let inner = py.allow_threads(|| {
        let gpu = CudaRsi::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        gpu.rsi_many_series_one_param_time_major_dev(flat, cols, rows, period)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;
    make_device_array_py(device_id, inner)
}
2038
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn rsi_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    // Computes RSI over `data` with the auto-detected kernel and returns one value per
    // input element. Errors are converted to a JS string.
    let input = RsiInput::from_slice(
        data,
        RsiParams {
            period: Some(period),
        },
    );
    let mut values = vec![0.0; data.len()];

    if let Err(e) = rsi_into_slice(&mut values, &input, detect_best_kernel()) {
        return Err(JsValue::from_str(&e.to_string()));
    }
    Ok(values)
}
2054
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn rsi_alloc(len: usize) -> *mut f64 {
    // Reserves room for `len` f64 values and hands ownership of the buffer to the caller,
    // who releases it with `rsi_free(ptr, len)`. `ManuallyDrop` keeps the Vec from
    // deallocating the buffer when it goes out of scope here.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
2063
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn rsi_free(ptr: *mut f64, len: usize) {
    // Releases a buffer obtained from `rsi_alloc(len)`; `len` must be the same value passed
    // to `rsi_alloc`. A null pointer is ignored rather than handed to
    // `Vec::from_raw_parts`, which requires a non-null pointer.
    if ptr.is_null() {
        return;
    }
    // SAFETY: the caller must pass a pointer returned by `rsi_alloc(len)` that has not been
    // freed yet, i.e. an allocation made by a `Vec<f64>` with capacity `len`. The length is
    // 0 because the contents may never have been initialized. Deallocation only needs the
    // capacity, and `f64` has no destructor to run.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
2071
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn rsi_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    // Computes RSI over the `len` values at `in_ptr` and writes `len` results to `out_ptr`.
    // Both pointers must be valid for `len` f64 values (e.g. buffers from `rsi_alloc`).
    // The buffers may alias or partially overlap. Any overlap is routed through a temporary
    // buffer, so a shared input slice never coexists with a mutable output slice over the
    // same memory; only exact aliasing was handled before.
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to rsi_into"));
    }
    if period == 0 || period > len {
        return Err(JsValue::from_str("Invalid period"));
    }

    let span_bytes = len.saturating_mul(std::mem::size_of::<f64>());
    let in_start = in_ptr as usize;
    let out_start = out_ptr as usize;
    let overlaps = in_start < out_start.saturating_add(span_bytes)
        && out_start < in_start.saturating_add(span_bytes);

    // SAFETY: the caller guarantees `in_ptr` is valid for reads of `len` f64 values.
    let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };
    let params = RsiParams {
        period: Some(period),
    };
    let input = RsiInput::from_slice(data, params);

    if overlaps {
        let mut temp = vec![0.0; len];
        rsi_into_slice(&mut temp, &input, detect_best_kernel())
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // SAFETY: `out_ptr` is valid for writes of `len` f64 values, and `data`/`input` are
        // not used after this point, so no live shared borrow overlaps this mutable slice.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        out.copy_from_slice(&temp);
    } else {
        // SAFETY: `out_ptr` is valid for writes of `len` f64 values, and the range was
        // checked above to be disjoint from the input range.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        rsi_into_slice(out, &input, detect_best_kernel())
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(())
}
2111
/// Batch configuration deserialized from the JS `config` argument of `rsi_batch`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct RsiBatchConfig {
    /// `(start, end, step)` period sweep, forwarded unchanged to `RsiBatchRange::period`.
    pub period_range: (usize, usize, usize),
}
2117
/// Serializable batch result returned to JS by `rsi_batch`; each field is copied from the
/// output of `rsi_batch_with_kernel`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct RsiBatchJsOutput {
    /// Flat output values. Presumably row-major with one row per entry of `combos`
    /// (indexed as `row * cols + col`), as the batch tests assume; confirm against
    /// `rsi_batch_with_kernel`.
    pub values: Vec<f64>,
    /// Parameter set used for each row.
    pub combos: Vec<RsiParams>,
    /// Number of rows (parameter combinations).
    pub rows: usize,
    /// Number of columns per row.
    pub cols: usize,
}
2126
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = rsi_batch)]
pub fn rsi_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    // JS entry point for a period sweep. `config` must deserialize into `RsiBatchConfig`.
    // The result is serialized as `{ values, combos, rows, cols }`.
    let parsed = serde_wasm_bindgen::from_value::<RsiBatchConfig>(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let batch = rsi_batch_with_kernel(
        data,
        &RsiBatchRange {
            period: parsed.period_range,
        },
        detect_best_batch_kernel(),
    )
    .map_err(|e| JsValue::from_str(&e.to_string()))?;

    serde_wasm_bindgen::to_value(&RsiBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    })
    .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}