1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::moving_averages::DeviceArrayF32;
3#[cfg(all(feature = "python", feature = "cuda"))]
4use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
5#[cfg(all(feature = "python", feature = "cuda"))]
6use cust::context::Context;
7#[cfg(all(feature = "python", feature = "cuda"))]
8use cust::memory::DeviceBuffer;
9#[cfg(feature = "python")]
10use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
11#[cfg(feature = "python")]
12use pyo3::exceptions::PyValueError;
13#[cfg(feature = "python")]
14use pyo3::prelude::*;
15#[cfg(feature = "python")]
16use pyo3::types::PyDict;
17#[cfg(all(feature = "python", feature = "cuda"))]
18use std::sync::Arc;
19
20#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
21use serde::{Deserialize, Serialize};
22#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
23use wasm_bindgen::prelude::*;
24
25use crate::indicators::rsi::{rsi, RsiError, RsiInput, RsiOutput, RsiParams};
26use crate::indicators::stoch::{stoch, StochError, StochInput, StochOutput, StochParams};
27use crate::utilities::data_loader::{source_type, Candles};
28use crate::utilities::enums::Kernel;
29use crate::utilities::helpers::{
30 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
31 make_uninit_matrix,
32};
33#[cfg(feature = "python")]
34use crate::utilities::kernel_validation::validate_kernel;
35use aligned_vec::{AVec, CACHELINE_ALIGN};
36#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
37use core::arch::x86_64::*;
38#[cfg(not(target_arch = "wasm32"))]
39use rayon::prelude::*;
40use std::collections::VecDeque;
41use std::convert::AsRef;
42use std::error::Error;
43use std::mem::MaybeUninit;
44use thiserror::Error;
45
46impl<'a> AsRef<[f64]> for SrsiInput<'a> {
47 #[inline(always)]
48 fn as_ref(&self) -> &[f64] {
49 match &self.data {
50 SrsiData::Slice(slice) => slice,
51 SrsiData::Candles { candles, source } => source_type(candles, source),
52 }
53 }
54}
55
56#[derive(Debug, Clone)]
57pub enum SrsiData<'a> {
58 Candles {
59 candles: &'a Candles,
60 source: &'a str,
61 },
62 Slice(&'a [f64]),
63}
64
/// Result of an SRSI computation: the two output lines, one value per input sample.
#[derive(Clone, Debug)]
pub struct SrsiOutput {
    /// Smoothed %K line.
    pub k: Vec<f64>,
    /// %D line (moving average of `k`).
    pub d: Vec<f64>,
}
70
/// Tunable SRSI parameters.
///
/// Every field is optional; the accessors on `SrsiInput` substitute
/// 14 / 14 / 3 / 3 / `"close"` for fields left as `None`.
#[derive(Clone, Debug)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct SrsiParams {
    /// Look-back of the RSI stage.
    pub rsi_period: Option<usize>,
    /// Window of the stochastic (min/max) stage applied to the RSI values.
    pub stoch_period: Option<usize>,
    /// Length of the moving average producing the `k` line.
    pub k: Option<usize>,
    /// Length of the moving average producing the `d` line.
    pub d: Option<usize>,
    /// Name of the candle field to read.
    pub source: Option<String>,
}
83
84impl Default for SrsiParams {
85 fn default() -> Self {
86 Self {
87 rsi_period: Some(14),
88 stoch_period: Some(14),
89 k: Some(3),
90 d: Some(3),
91 source: Some("close".to_string()),
92 }
93 }
94}
95
96#[derive(Debug, Clone)]
97pub struct SrsiInput<'a> {
98 pub data: SrsiData<'a>,
99 pub params: SrsiParams,
100}
101
102impl<'a> SrsiInput<'a> {
103 #[inline]
104 pub fn from_candles(c: &'a Candles, s: &'a str, p: SrsiParams) -> Self {
105 Self {
106 data: SrsiData::Candles {
107 candles: c,
108 source: s,
109 },
110 params: p,
111 }
112 }
113 #[inline]
114 pub fn from_slice(sl: &'a [f64], p: SrsiParams) -> Self {
115 Self {
116 data: SrsiData::Slice(sl),
117 params: p,
118 }
119 }
120 #[inline]
121 pub fn with_default_candles(c: &'a Candles) -> Self {
122 Self::from_candles(c, "close", SrsiParams::default())
123 }
124 #[inline]
125 pub fn get_rsi_period(&self) -> usize {
126 self.params.rsi_period.unwrap_or(14)
127 }
128 #[inline]
129 pub fn get_stoch_period(&self) -> usize {
130 self.params.stoch_period.unwrap_or(14)
131 }
132 #[inline]
133 pub fn get_k(&self) -> usize {
134 self.params.k.unwrap_or(3)
135 }
136 #[inline]
137 pub fn get_d(&self) -> usize {
138 self.params.d.unwrap_or(3)
139 }
140 #[inline]
141 pub fn get_source(&self) -> &str {
142 self.params.source.as_deref().unwrap_or("close")
143 }
144}
145
146#[derive(Clone, Debug)]
147pub struct SrsiBuilder {
148 rsi_period: Option<usize>,
149 stoch_period: Option<usize>,
150 k: Option<usize>,
151 d: Option<usize>,
152 source: Option<String>,
153 kernel: Kernel,
154}
155
156impl Default for SrsiBuilder {
157 fn default() -> Self {
158 Self {
159 rsi_period: None,
160 stoch_period: None,
161 k: None,
162 d: None,
163 source: None,
164 kernel: Kernel::Auto,
165 }
166 }
167}
168
169impl SrsiBuilder {
170 #[inline(always)]
171 pub fn new() -> Self {
172 Self::default()
173 }
174 #[inline(always)]
175 pub fn rsi_period(mut self, n: usize) -> Self {
176 self.rsi_period = Some(n);
177 self
178 }
179 #[inline(always)]
180 pub fn stoch_period(mut self, n: usize) -> Self {
181 self.stoch_period = Some(n);
182 self
183 }
184 #[inline(always)]
185 pub fn k(mut self, n: usize) -> Self {
186 self.k = Some(n);
187 self
188 }
189 #[inline(always)]
190 pub fn d(mut self, n: usize) -> Self {
191 self.d = Some(n);
192 self
193 }
194 #[inline(always)]
195 pub fn source<S: Into<String>>(mut self, s: S) -> Self {
196 self.source = Some(s.into());
197 self
198 }
199 #[inline(always)]
200 pub fn kernel(mut self, k: Kernel) -> Self {
201 self.kernel = k;
202 self
203 }
204
205 #[inline(always)]
206 pub fn apply(self, c: &Candles) -> Result<SrsiOutput, SrsiError> {
207 let p = SrsiParams {
208 rsi_period: self.rsi_period,
209 stoch_period: self.stoch_period,
210 k: self.k,
211 d: self.d,
212 source: self.source.clone(),
213 };
214 let i = SrsiInput::from_candles(c, self.source.as_deref().unwrap_or("close"), p);
215 srsi_with_kernel(&i, self.kernel)
216 }
217
218 #[inline(always)]
219 pub fn apply_slice(self, d: &[f64]) -> Result<SrsiOutput, SrsiError> {
220 let p = SrsiParams {
221 rsi_period: self.rsi_period,
222 stoch_period: self.stoch_period,
223 k: self.k,
224 d: self.d,
225 source: self.source.clone(),
226 };
227 let i = SrsiInput::from_slice(d, p);
228 srsi_with_kernel(&i, self.kernel)
229 }
230}
231
232#[derive(Debug, Error)]
233pub enum SrsiError {
234 #[error("srsi: Error from RSI calculation: {0}")]
235 RsiError(#[from] RsiError),
236 #[error("srsi: Error from Stochastic calculation: {0}")]
237 StochError(#[from] StochError),
238 #[error("srsi: Input data is empty.")]
239 EmptyInputData,
240 #[error("srsi: All input data values are NaN.")]
241 AllValuesNaN,
242 #[error("srsi: Invalid period {period} for data length {data_len}.")]
243 InvalidPeriod { period: usize, data_len: usize },
244 #[error(
245 "srsi: Not enough valid data for the requested period. needed={needed}, valid={valid}"
246 )]
247 NotEnoughValidData { needed: usize, valid: usize },
248 #[error("srsi: Output length mismatch - destination buffers must match input data length. Expected {expected}, got k={k_len}, d={d_len}")]
249 OutputLengthMismatch {
250 expected: usize,
251 k_len: usize,
252 d_len: usize,
253 },
254 #[error("srsi: Invalid range: start={start}, end={end}, step={step}")]
255 InvalidRange {
256 start: String,
257 end: String,
258 step: String,
259 },
260 #[error("srsi: Invalid kernel for batch: {0:?}")]
261 InvalidKernelForBatch(Kernel),
262}
263
264#[inline]
265pub fn srsi(input: &SrsiInput) -> Result<SrsiOutput, SrsiError> {
266 srsi_with_kernel(input, Kernel::Auto)
267}
268
269pub fn srsi_with_kernel(input: &SrsiInput, kernel: Kernel) -> Result<SrsiOutput, SrsiError> {
270 let data: &[f64] = match &input.data {
271 SrsiData::Candles { candles, source } => source_type(candles, source),
272 SrsiData::Slice(sl) => sl,
273 };
274
275 if data.is_empty() {
276 return Err(SrsiError::EmptyInputData);
277 }
278
279 let first = data
280 .iter()
281 .position(|x| !x.is_nan())
282 .ok_or(SrsiError::AllValuesNaN)?;
283 let len = data.len();
284 let rsi_period = input.get_rsi_period();
285 let stoch_period = input.get_stoch_period();
286 let k_len = input.get_k();
287 let d_len = input.get_d();
288
289 let needed = rsi_period.max(stoch_period).max(k_len).max(d_len);
290 let valid = len - first;
291 if valid < needed {
292 return Err(SrsiError::NotEnoughValidData { needed, valid });
293 }
294
295 let chosen = match kernel {
296 Kernel::Auto => Kernel::Scalar,
297 other => other,
298 };
299
300 unsafe {
301 match chosen {
302 Kernel::Scalar | Kernel::ScalarBatch => {
303 srsi_scalar(data, rsi_period, stoch_period, k_len, d_len)
304 }
305 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
306 Kernel::Avx2 | Kernel::Avx2Batch => {
307 srsi_avx2(data, rsi_period, stoch_period, k_len, d_len)
308 }
309 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
310 Kernel::Avx512 | Kernel::Avx512Batch => {
311 srsi_avx512(data, rsi_period, stoch_period, k_len, d_len)
312 }
313 _ => unreachable!(),
314 }
315 }
316}
317
318#[inline]
319pub unsafe fn srsi_scalar(
320 data: &[f64],
321 rsi_period: usize,
322 stoch_period: usize,
323 k_period: usize,
324 d_period: usize,
325) -> Result<SrsiOutput, SrsiError> {
326 if rsi_period == 14 && stoch_period == 14 && k_period == 3 && d_period == 3 {
327 return srsi_scalar_classic(data, rsi_period, stoch_period, k_period, d_period);
328 }
329
330 let n = data.len();
331 if n == 0 {
332 return Err(SrsiError::EmptyInputData);
333 }
334 if rsi_period == 0 || stoch_period == 0 || k_period == 0 || d_period == 0 {
335 return Err(SrsiError::InvalidPeriod {
336 period: rsi_period.max(stoch_period).max(k_period).max(d_period),
337 data_len: n,
338 });
339 }
340 let first = data
341 .iter()
342 .position(|x| !x.is_nan())
343 .ok_or(SrsiError::AllValuesNaN)?;
344 let max_need = rsi_period.max(stoch_period).max(k_period).max(d_period);
345 if n - first < max_need {
346 return Err(SrsiError::NotEnoughValidData {
347 needed: max_need,
348 valid: n - first,
349 });
350 }
351
352 let rsi_warmup = first + rsi_period;
353 let stoch_warmup = rsi_warmup + stoch_period - 1;
354 let k_warmup = stoch_warmup + k_period - 1;
355 let d_warmup = k_warmup + d_period - 1;
356
357 if n <= d_warmup {
358 return Err(SrsiError::NotEnoughValidData {
359 needed: d_warmup + 1,
360 valid: n,
361 });
362 }
363
364 let mut rsi_vals = alloc_with_nan_prefix(n, rsi_warmup);
365 let mut k_out = alloc_with_nan_prefix(n, k_warmup);
366 let mut d_out = alloc_with_nan_prefix(n, d_warmup);
367
368 let mut avg_gain = 0.0f64;
369 let mut avg_loss = 0.0f64;
370 let mut prev = *data.get_unchecked(first);
371 let end_init = (first + rsi_period).min(n.saturating_sub(1));
372 for i in (first + 1)..=end_init {
373 let cur = *data.get_unchecked(i);
374 if cur.is_finite() && prev.is_finite() {
375 let ch = cur - prev;
376 if ch > 0.0 {
377 avg_gain += ch;
378 } else {
379 avg_loss += -ch;
380 }
381 }
382 prev = cur;
383 }
384
385 let rp = rsi_period as f64;
386 avg_gain /= rp;
387 avg_loss /= rp;
388 let alpha = 1.0f64 / rp;
389
390 if rsi_warmup < n {
391 rsi_vals[rsi_warmup] = if avg_loss == 0.0 {
392 100.0
393 } else {
394 let rs = avg_gain / avg_loss;
395 100.0 - 100.0 / (1.0 + rs)
396 };
397 }
398
399 prev = *data.get_unchecked(rsi_warmup);
400 for i in (rsi_warmup + 1)..n {
401 let cur = *data.get_unchecked(i);
402 if cur.is_finite() && prev.is_finite() {
403 let ch = cur - prev;
404 let gain = if ch > 0.0 { ch } else { 0.0 };
405 let loss = if ch < 0.0 { -ch } else { 0.0 };
406 avg_gain = (gain - avg_gain).mul_add(alpha, avg_gain);
407 avg_loss = (loss - avg_loss).mul_add(alpha, avg_loss);
408 rsi_vals[i] = if avg_loss == 0.0 {
409 100.0
410 } else {
411 let rs = avg_gain / avg_loss;
412 100.0 - 100.0 / (1.0 + rs)
413 };
414 }
415 prev = cur;
416 }
417
418 let sp = stoch_period;
419 let kp = k_period;
420 let dp = d_period;
421 if rsi_warmup < n {
422 let m = n - rsi_warmup;
423 let base = rsi_warmup;
424
425 let mut pref_max = vec![0.0f64; m];
426 let mut suff_max = vec![0.0f64; m];
427 let mut pref_min = vec![0.0f64; m];
428 let mut suff_min = vec![0.0f64; m];
429
430 let blocks = (m + sp - 1) / sp;
431 let p_pref_max = pref_max.as_mut_ptr();
432 let p_pref_min = pref_min.as_mut_ptr();
433 let p_rsi = rsi_vals.as_ptr().add(base);
434 for b in 0..blocks {
435 let start = b * sp;
436 let end = core::cmp::min(start + sp, m);
437 if start >= end {
438 break;
439 }
440 unsafe {
441 let v0 = *p_rsi.add(start);
442 *p_pref_max.add(start) = v0;
443 *p_pref_min.add(start) = v0;
444 let mut i = start + 1;
445 while i < end {
446 let v = *p_rsi.add(i);
447 let pmx = *p_pref_max.add(i - 1);
448 let pmn = *p_pref_min.add(i - 1);
449 *p_pref_max.add(i) = if v > pmx { v } else { pmx };
450 *p_pref_min.add(i) = if v < pmn { v } else { pmn };
451 i += 1;
452 }
453 }
454 }
455
456 let p_suff_max = suff_max.as_mut_ptr();
457 let p_suff_min = suff_min.as_mut_ptr();
458 for b in 0..blocks {
459 let block_end_excl = core::cmp::min((b + 1) * sp, m);
460 if block_end_excl == 0 {
461 break;
462 }
463 let block_start = block_end_excl - core::cmp::min(sp, block_end_excl);
464 unsafe {
465 let last = block_end_excl - 1;
466 let v_last = *p_rsi.add(last);
467 *p_suff_max.add(last) = v_last;
468 *p_suff_min.add(last) = v_last;
469 let mut i = last;
470 while i > block_start {
471 let prev = i - 1;
472 let v = *p_rsi.add(prev);
473 let smx = *p_suff_max.add(i);
474 let smn = *p_suff_min.add(i);
475 *p_suff_max.add(prev) = if v > smx { v } else { smx };
476 *p_suff_min.add(prev) = if v < smn { v } else { smn };
477 i = prev;
478 }
479 }
480 }
481
482 let mut sum_k = 0.0f64;
483 let mut sum_d = 0.0f64;
484 let mut fk_ring = vec![0.0f64; kp];
485 let mut sk_ring = vec![0.0f64; dp];
486 let mut fk_pos = 0usize;
487 let mut sk_pos = 0usize;
488
489 let i0 = stoch_warmup;
490 let mut i = i0;
491 while i < n {
492 let t = i - base;
493 let t_start = t + 1 - sp;
494 let hi_l = suff_max[t_start];
495 let hi_r = pref_max[t];
496 let lo_l = suff_min[t_start];
497 let lo_r = pref_min[t];
498 let hi = if hi_l > hi_r { hi_l } else { hi_r };
499 let lo = if lo_l < lo_r { lo_l } else { lo_r };
500 let x = *rsi_vals.get_unchecked(i);
501 let fk = if hi > lo {
502 ((x - lo) * 100.0) / (hi - lo)
503 } else {
504 50.0
505 };
506
507 sum_k += fk;
508 if i >= i0 + kp {
509 sum_k -= *fk_ring.get_unchecked(fk_pos);
510 }
511 *fk_ring.get_unchecked_mut(fk_pos) = fk;
512 fk_pos += 1;
513 if fk_pos == kp {
514 fk_pos = 0;
515 }
516
517 if i >= k_warmup {
518 let sk = sum_k / (kp as f64);
519 *k_out.get_unchecked_mut(i) = sk;
520
521 sum_d += sk;
522 if i >= k_warmup + dp {
523 sum_d -= *sk_ring.get_unchecked(sk_pos);
524 }
525 *sk_ring.get_unchecked_mut(sk_pos) = sk;
526 sk_pos += 1;
527 if sk_pos == dp {
528 sk_pos = 0;
529 }
530
531 if i >= d_warmup {
532 *d_out.get_unchecked_mut(i) = sum_d / (dp as f64);
533 }
534 }
535 i += 1;
536 }
537 }
538
539 Ok(SrsiOutput { k: k_out, d: d_out })
540}
541
/// AVX2 entry point; currently forwards to the scalar kernel unchanged.
///
/// # Safety
///
/// Same contract as [`srsi_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn srsi_avx2(
    data: &[f64],
    rsi_period: usize,
    stoch_period: usize,
    k: usize,
    d: usize,
) -> Result<SrsiOutput, SrsiError> {
    let result = srsi_scalar(data, rsi_period, stoch_period, k, d);
    result
}
553
/// AVX-512 entry point: picks the short variant for stochastic windows of at
/// most 32 samples and the long variant otherwise.
///
/// # Safety
///
/// Same contract as the variants it forwards to.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn srsi_avx512(
    data: &[f64],
    rsi_period: usize,
    stoch_period: usize,
    k: usize,
    d: usize,
) -> Result<SrsiOutput, SrsiError> {
    match stoch_period {
        0..=32 => srsi_avx512_short(data, rsi_period, stoch_period, k, d),
        _ => srsi_avx512_long(data, rsi_period, stoch_period, k, d),
    }
}
569
/// AVX-512 variant for short stochastic windows; currently forwards to the
/// scalar kernel unchanged.
///
/// # Safety
///
/// Same contract as [`srsi_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn srsi_avx512_short(
    data: &[f64],
    rsi_period: usize,
    stoch_period: usize,
    k: usize,
    d: usize,
) -> Result<SrsiOutput, SrsiError> {
    let result = srsi_scalar(data, rsi_period, stoch_period, k, d);
    result
}
581
/// AVX-512 variant for long stochastic windows; currently forwards to the
/// scalar kernel unchanged.
///
/// # Safety
///
/// Same contract as [`srsi_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn srsi_avx512_long(
    data: &[f64],
    rsi_period: usize,
    stoch_period: usize,
    k: usize,
    d: usize,
) -> Result<SrsiOutput, SrsiError> {
    let result = srsi_scalar(data, rsi_period, stoch_period, k, d);
    result
}
593
594#[inline]
595pub unsafe fn srsi_scalar_classic(
596 data: &[f64],
597 rsi_period: usize,
598 stoch_period: usize,
599 k_period: usize,
600 d_period: usize,
601) -> Result<SrsiOutput, SrsiError> {
602 let n = data.len();
603 if n == 0 {
604 return Err(SrsiError::EmptyInputData);
605 }
606 let first = data
607 .iter()
608 .position(|x| !x.is_nan())
609 .ok_or(SrsiError::AllValuesNaN)?;
610
611 let rsi_warmup = first + rsi_period;
612 let stoch_warmup = rsi_warmup + stoch_period - 1;
613 let k_warmup = stoch_warmup + k_period - 1;
614 let d_warmup = k_warmup + d_period - 1;
615
616 if n <= d_warmup {
617 return Err(SrsiError::NotEnoughValidData {
618 needed: d_warmup + 1,
619 valid: n,
620 });
621 }
622
623 let mut rsi_values = alloc_with_nan_prefix(n, rsi_warmup);
624
625 let mut avg_gain = 0.0;
626 let mut avg_loss = 0.0;
627 let mut prev = data[first];
628
629 for i in (first + 1)..(first + rsi_period + 1).min(n) {
630 if data[i].is_finite() && prev.is_finite() {
631 let change = data[i] - prev;
632 if change > 0.0 {
633 avg_gain += change;
634 } else {
635 avg_loss += -change;
636 }
637 prev = data[i];
638 }
639 }
640
641 avg_gain /= rsi_period as f64;
642 avg_loss /= rsi_period as f64;
643
644 let alpha = 1.0 / rsi_period as f64;
645 let alpha_1minus = 1.0 - alpha;
646
647 if first + rsi_period < n {
648 rsi_values[first + rsi_period] = if avg_loss == 0.0 {
649 100.0
650 } else {
651 100.0 - (100.0 / (1.0 + avg_gain / avg_loss))
652 };
653
654 prev = data[first + rsi_period];
655 }
656
657 for i in (first + rsi_period + 1)..n {
658 if data[i].is_finite() && prev.is_finite() {
659 let change = data[i] - prev;
660 let (gain, loss) = if change > 0.0 {
661 (change, 0.0)
662 } else {
663 (0.0, -change)
664 };
665
666 avg_gain = alpha * gain + alpha_1minus * avg_gain;
667 avg_loss = alpha * loss + alpha_1minus * avg_loss;
668
669 rsi_values[i] = if avg_loss == 0.0 {
670 100.0
671 } else {
672 100.0 - (100.0 / (1.0 + avg_gain / avg_loss))
673 };
674
675 prev = data[i];
676 }
677 }
678
679 let mut fast_k = alloc_with_nan_prefix(n, stoch_warmup);
680
681 for i in stoch_warmup..n {
682 let start = i + 1 - stoch_period;
683 let mut min_rsi = f64::MAX;
684 let mut max_rsi = f64::MIN;
685
686 for j in start..=i {
687 if rsi_values[j].is_finite() {
688 min_rsi = min_rsi.min(rsi_values[j]);
689 max_rsi = max_rsi.max(rsi_values[j]);
690 }
691 }
692
693 if max_rsi > min_rsi {
694 fast_k[i] = 100.0 * (rsi_values[i] - min_rsi) / (max_rsi - min_rsi);
695 } else {
696 fast_k[i] = 50.0;
697 }
698 }
699
700 let mut slow_k = alloc_with_nan_prefix(n, k_warmup);
701
702 let mut k_sum = 0.0;
703 for i in stoch_warmup..(stoch_warmup + k_period).min(n) {
704 if fast_k[i].is_finite() {
705 k_sum += fast_k[i];
706 }
707 }
708
709 if stoch_warmup + k_period <= n {
710 slow_k[stoch_warmup + k_period - 1] = k_sum / k_period as f64;
711
712 for i in (stoch_warmup + k_period)..n {
713 if fast_k[i].is_finite() && fast_k[i - k_period].is_finite() {
714 k_sum += fast_k[i] - fast_k[i - k_period];
715 slow_k[i] = k_sum / k_period as f64;
716 }
717 }
718 }
719
720 let mut slow_d = alloc_with_nan_prefix(n, d_warmup);
721
722 let mut d_sum = 0.0;
723 for i in k_warmup..(k_warmup + d_period).min(n) {
724 if slow_k[i].is_finite() {
725 d_sum += slow_k[i];
726 }
727 }
728
729 if k_warmup + d_period <= n {
730 slow_d[k_warmup + d_period - 1] = d_sum / d_period as f64;
731
732 for i in (k_warmup + d_period)..n {
733 if slow_k[i].is_finite() && slow_k[i - d_period].is_finite() {
734 d_sum += slow_k[i] - slow_k[i - d_period];
735 slow_d[i] = d_sum / d_period as f64;
736 }
737 }
738 }
739
740 Ok(SrsiOutput {
741 k: slow_k,
742 d: slow_d,
743 })
744}
745
746#[inline(always)]
747pub fn srsi_row_scalar(
748 data: &[f64],
749 rsi_period: usize,
750 stoch_period: usize,
751 k: usize,
752 d: usize,
753 k_out: &mut [f64],
754 d_out: &mut [f64],
755) {
756 if let Ok(res) = unsafe { srsi_scalar(data, rsi_period, stoch_period, k, d) } {
757 k_out.copy_from_slice(&res.k);
758 d_out.copy_from_slice(&res.d);
759 }
760}
761
/// AVX2 row entry point; forwards to [`srsi_row_scalar`] unchanged.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn srsi_row_avx2(
    data: &[f64],
    rsi_period: usize,
    stoch_period: usize,
    k: usize,
    d: usize,
    k_out: &mut [f64],
    d_out: &mut [f64],
) {
    srsi_row_scalar(data, rsi_period, stoch_period, k, d, k_out, d_out);
}
775
/// AVX-512 row entry point; forwards to [`srsi_row_scalar`] unchanged.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn srsi_row_avx512(
    data: &[f64],
    rsi_period: usize,
    stoch_period: usize,
    k: usize,
    d: usize,
    k_out: &mut [f64],
    d_out: &mut [f64],
) {
    srsi_row_scalar(data, rsi_period, stoch_period, k, d, k_out, d_out);
}
789
/// AVX-512 short-window row entry point; forwards to [`srsi_row_scalar`]
/// unchanged.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn srsi_row_avx512_short(
    data: &[f64],
    rsi_period: usize,
    stoch_period: usize,
    k: usize,
    d: usize,
    k_out: &mut [f64],
    d_out: &mut [f64],
) {
    srsi_row_scalar(data, rsi_period, stoch_period, k, d, k_out, d_out);
}
803
/// AVX-512 long-window row entry point; forwards to [`srsi_row_scalar`]
/// unchanged.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn srsi_row_avx512_long(
    data: &[f64],
    rsi_period: usize,
    stoch_period: usize,
    k: usize,
    d: usize,
    k_out: &mut [f64],
    d_out: &mut [f64],
) {
    srsi_row_scalar(data, rsi_period, stoch_period, k, d, k_out, d_out);
}
817
/// Incremental SRSI calculator fed one sample at a time.
#[derive(Clone, Debug)]
pub struct SrsiStream {
    // --- configuration ---
    /// RSI look-back.
    rsi_period: usize,
    /// Stochastic (min/max) window over RSI values.
    stoch_period: usize,
    /// Length of the %K moving average.
    k_period: usize,
    /// Length of the %D moving average.
    d_period: usize,

    // --- RSI state ---
    /// Previous input sample.
    prev: f64,
    /// Whether `prev` holds a real sample yet.
    has_prev: bool,
    /// Number of price changes accumulated while seeding the RSI.
    init_count: usize,
    /// Sum of gains during seeding.
    sum_gain: f64,
    /// Sum of losses during seeding.
    sum_loss: f64,
    /// Smoothed average gain.
    avg_gain: f64,
    /// Smoothed average loss.
    avg_loss: f64,
    /// Smoothing factor, `1 / rsi_period`.
    alpha: f64,
    /// Set once the seed RSI has been produced.
    rsi_ready: bool,
    /// Running index of the latest RSI value (0 for the seed value).
    rsi_index: usize,
    /// Most recent RSI value.
    last_rsi: f64,

    // --- rolling extrema of the RSI ---
    /// `(index, value)` candidates for the window maximum.
    max_q: VecDeque<(usize, f64)>,
    /// `(index, value)` candidates for the window minimum.
    min_q: VecDeque<(usize, f64)>,

    // --- %K moving average ---
    /// Ring buffer of recent fast %K values.
    fk_ring: Vec<f64>,
    /// Running sum of `fk_ring`.
    fk_sum: f64,
    /// Next write position in `fk_ring`.
    fk_pos: usize,
    /// Number of values pushed so far, capped at `k_period`.
    fk_count: usize,
    /// `1 / k_period`.
    inv_k: f64,

    // --- %D moving average ---
    /// Ring buffer of recent %K values.
    sk_ring: Vec<f64>,
    /// Running sum of `sk_ring`.
    sk_sum: f64,
    /// Next write position in `sk_ring`.
    sk_pos: usize,
    /// Number of values pushed so far, capped at `d_period`.
    sk_count: usize,
    /// `1 / d_period`.
    inv_d: f64,
}
852
853impl SrsiStream {
854 pub fn try_new(params: SrsiParams) -> Result<Self, SrsiError> {
855 let rsi_period = params.rsi_period.unwrap_or(14);
856 let stoch_period = params.stoch_period.unwrap_or(14);
857 let k_period = params.k.unwrap_or(3);
858 let d_period = params.d.unwrap_or(3);
859
860 if rsi_period == 0 || stoch_period == 0 || k_period == 0 || d_period == 0 {
861 return Err(SrsiError::InvalidPeriod {
862 period: rsi_period.max(stoch_period).max(k_period).max(d_period),
863 data_len: 0,
864 });
865 }
866
867 Ok(Self {
868 rsi_period,
869 stoch_period,
870 k_period,
871 d_period,
872
873 prev: f64::NAN,
874 has_prev: false,
875 init_count: 0,
876 sum_gain: 0.0,
877 sum_loss: 0.0,
878 avg_gain: 0.0,
879 avg_loss: 0.0,
880 alpha: 1.0 / (rsi_period as f64),
881 rsi_ready: false,
882 rsi_index: 0,
883 last_rsi: f64::NAN,
884
885 max_q: VecDeque::with_capacity(stoch_period),
886 min_q: VecDeque::with_capacity(stoch_period),
887
888 fk_ring: vec![0.0; k_period],
889 fk_sum: 0.0,
890 fk_pos: 0,
891 fk_count: 0,
892 inv_k: 1.0 / (k_period as f64),
893
894 sk_ring: vec![0.0; d_period],
895 sk_sum: 0.0,
896 sk_pos: 0,
897 sk_count: 0,
898 inv_d: 1.0 / (d_period as f64),
899 })
900 }
901
902 #[inline]
903 pub fn reset(&mut self) {
904 self.prev = f64::NAN;
905 self.has_prev = false;
906 self.init_count = 0;
907 self.sum_gain = 0.0;
908 self.sum_loss = 0.0;
909 self.avg_gain = 0.0;
910 self.avg_loss = 0.0;
911 self.rsi_ready = false;
912 self.rsi_index = 0;
913 self.last_rsi = f64::NAN;
914 self.max_q.clear();
915 self.min_q.clear();
916 self.fk_ring.fill(0.0);
917 self.fk_sum = 0.0;
918 self.fk_pos = 0;
919 self.fk_count = 0;
920 self.sk_ring.fill(0.0);
921 self.sk_sum = 0.0;
922 self.sk_pos = 0;
923 self.sk_count = 0;
924 }
925
926 pub fn update(&mut self, v: f64) -> Option<(f64, f64)> {
927 if !v.is_finite() {
928 self.reset();
929 return None;
930 }
931
932 if !self.has_prev {
933 self.prev = v;
934 self.has_prev = true;
935 return None;
936 }
937
938 let ch = v - self.prev;
939 self.prev = v;
940
941 if !self.rsi_ready {
942 if ch > 0.0 {
943 self.sum_gain += ch;
944 } else {
945 self.sum_loss += -ch;
946 }
947 self.init_count += 1;
948
949 if self.init_count < self.rsi_period {
950 return None;
951 }
952
953 self.avg_gain = self.sum_gain / (self.rsi_period as f64);
954 self.avg_loss = self.sum_loss / (self.rsi_period as f64);
955
956 let rsi = if self.avg_loss == 0.0 {
957 100.0
958 } else {
959 let rs = self.avg_gain / self.avg_loss;
960 100.0 - 100.0 / (1.0 + rs)
961 };
962 self.last_rsi = rsi;
963 self.rsi_ready = true;
964 self.rsi_index = 0;
965
966 self.push_rsi_to_deques(self.rsi_index, rsi);
967
968 return None;
969 }
970
971 let gain = if ch > 0.0 { ch } else { 0.0 };
972 let loss = if ch < 0.0 { -ch } else { 0.0 };
973 self.avg_gain = (gain - self.avg_gain).mul_add(self.alpha, self.avg_gain);
974 self.avg_loss = (loss - self.avg_loss).mul_add(self.alpha, self.avg_loss);
975
976 let rsi = if self.avg_loss == 0.0 {
977 100.0
978 } else {
979 let rs = self.avg_gain / self.avg_loss;
980 100.0 - 100.0 / (1.0 + rs)
981 };
982 self.last_rsi = rsi;
983
984 self.rsi_index += 1;
985 self.push_rsi_to_deques(self.rsi_index, rsi);
986
987 if self.rsi_index + 1 < self.stoch_period {
988 return None;
989 }
990
991 let start = self.rsi_index + 1 - self.stoch_period;
992
993 while let Some(&(j, _)) = self.max_q.front() {
994 if j < start {
995 self.max_q.pop_front();
996 } else {
997 break;
998 }
999 }
1000 while let Some(&(j, _)) = self.min_q.front() {
1001 if j < start {
1002 self.min_q.pop_front();
1003 } else {
1004 break;
1005 }
1006 }
1007
1008 debug_assert!(!self.max_q.is_empty() && !self.min_q.is_empty());
1009 let hi = self.max_q.front().unwrap().1;
1010 let lo = self.min_q.front().unwrap().1;
1011
1012 let fast_k = if hi > lo {
1013 let range = hi - lo;
1014 ((rsi - lo) * 100.0) / range
1015 } else {
1016 50.0
1017 };
1018
1019 let slow_k_opt = Self::push_sma(
1020 fast_k,
1021 &mut self.fk_ring,
1022 &mut self.fk_sum,
1023 &mut self.fk_pos,
1024 &mut self.fk_count,
1025 self.k_period,
1026 self.inv_k,
1027 );
1028
1029 let slow_k = match slow_k_opt {
1030 None => return None,
1031 Some(v) => v,
1032 };
1033
1034 let slow_d_opt = Self::push_sma(
1035 slow_k,
1036 &mut self.sk_ring,
1037 &mut self.sk_sum,
1038 &mut self.sk_pos,
1039 &mut self.sk_count,
1040 self.d_period,
1041 self.inv_d,
1042 );
1043
1044 slow_d_opt.map(|d| (slow_k, d))
1045 }
1046
1047 #[inline(always)]
1048 fn push_rsi_to_deques(&mut self, idx: usize, rsi: f64) {
1049 while let Some(&(_, v)) = self.max_q.back() {
1050 if v <= rsi {
1051 self.max_q.pop_back();
1052 } else {
1053 break;
1054 }
1055 }
1056 if self.max_q.len() == self.stoch_period {
1057 self.max_q.pop_front();
1058 }
1059 self.max_q.push_back((idx, rsi));
1060
1061 while let Some(&(_, v)) = self.min_q.back() {
1062 if v >= rsi {
1063 self.min_q.pop_back();
1064 } else {
1065 break;
1066 }
1067 }
1068 if self.min_q.len() == self.stoch_period {
1069 self.min_q.pop_front();
1070 }
1071 self.min_q.push_back((idx, rsi));
1072 }
1073
1074 #[inline(always)]
1075 fn push_sma(
1076 new_val: f64,
1077 ring: &mut [f64],
1078 sum: &mut f64,
1079 pos: &mut usize,
1080 count: &mut usize,
1081 period: usize,
1082 inv_period: f64,
1083 ) -> Option<f64> {
1084 if *count < period {
1085 *sum += new_val;
1086 ring[*pos] = new_val;
1087 *pos += 1;
1088 if *pos == period {
1089 *pos = 0;
1090 }
1091 *count += 1;
1092 if *count == period {
1093 Some(*sum * inv_period)
1094 } else {
1095 None
1096 }
1097 } else {
1098 *sum += new_val - ring[*pos];
1099 ring[*pos] = new_val;
1100 *pos += 1;
1101 if *pos == period {
1102 *pos = 0;
1103 }
1104 Some(*sum * inv_period)
1105 }
1106 }
1107}
1108
/// Parameter sweep for a batch run; each axis is a `(start, end, step)` triple.
#[derive(Debug, Clone)]
pub struct SrsiBatchRange {
    /// Sweep over the RSI look-back.
    pub rsi_period: (usize, usize, usize),
    /// Sweep over the stochastic window.
    pub stoch_period: (usize, usize, usize),
    /// Sweep over the %K smoothing length.
    pub k: (usize, usize, usize),
    /// Sweep over the %D smoothing length.
    pub d: (usize, usize, usize),
}
1116
1117impl Default for SrsiBatchRange {
1118 fn default() -> Self {
1119 Self {
1120 rsi_period: (14, 263, 1),
1121 stoch_period: (14, 14, 0),
1122 k: (3, 3, 0),
1123 d: (3, 3, 0),
1124 }
1125 }
1126}
1127
1128#[derive(Clone, Debug, Default)]
1129pub struct SrsiBatchBuilder {
1130 range: SrsiBatchRange,
1131 kernel: Kernel,
1132}
1133
1134impl SrsiBatchBuilder {
1135 pub fn new() -> Self {
1136 Self::default()
1137 }
1138 pub fn kernel(mut self, k: Kernel) -> Self {
1139 self.kernel = k;
1140 self
1141 }
1142 #[inline]
1143 pub fn rsi_period_range(mut self, start: usize, end: usize, step: usize) -> Self {
1144 self.range.rsi_period = (start, end, step);
1145 self
1146 }
1147 #[inline]
1148 pub fn stoch_period_range(mut self, start: usize, end: usize, step: usize) -> Self {
1149 self.range.stoch_period = (start, end, step);
1150 self
1151 }
1152 #[inline]
1153 pub fn k_range(mut self, start: usize, end: usize, step: usize) -> Self {
1154 self.range.k = (start, end, step);
1155 self
1156 }
1157 #[inline]
1158 pub fn d_range(mut self, start: usize, end: usize, step: usize) -> Self {
1159 self.range.d = (start, end, step);
1160 self
1161 }
1162 pub fn apply_slice(self, data: &[f64]) -> Result<SrsiBatchOutput, SrsiError> {
1163 srsi_batch_with_kernel(data, &self.range, self.kernel)
1164 }
1165}
1166
1167#[derive(Clone, Debug)]
1168pub struct SrsiBatchOutput {
1169 pub k: Vec<f64>,
1170 pub d: Vec<f64>,
1171 pub combos: Vec<SrsiParams>,
1172 pub rows: usize,
1173 pub cols: usize,
1174}
1175impl SrsiBatchOutput {
1176 pub fn row_for_params(&self, p: &SrsiParams) -> Option<usize> {
1177 self.combos.iter().position(|c| {
1178 c.rsi_period.unwrap_or(14) == p.rsi_period.unwrap_or(14)
1179 && c.stoch_period.unwrap_or(14) == p.stoch_period.unwrap_or(14)
1180 && c.k.unwrap_or(3) == p.k.unwrap_or(3)
1181 && c.d.unwrap_or(3) == p.d.unwrap_or(3)
1182 })
1183 }
1184 pub fn k_for(&self, p: &SrsiParams) -> Option<&[f64]> {
1185 self.row_for_params(p).map(|row| {
1186 let start = row * self.cols;
1187 &self.k[start..start + self.cols]
1188 })
1189 }
1190 pub fn d_for(&self, p: &SrsiParams) -> Option<&[f64]> {
1191 self.row_for_params(p).map(|row| {
1192 let start = row * self.cols;
1193 &self.d[start..start + self.cols]
1194 })
1195 }
1196}
1197
1198#[inline(always)]
1199fn expand_grid(r: &SrsiBatchRange) -> Result<Vec<SrsiParams>, SrsiError> {
1200 fn axis((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, SrsiError> {
1201 if step == 0 || start == end {
1202 return Ok(vec![start]);
1203 }
1204 if start < end {
1205 let st = step.max(1);
1206 let v: Vec<usize> = (start..=end).step_by(st).collect();
1207 if v.is_empty() {
1208 return Err(SrsiError::InvalidRange {
1209 start: start.to_string(),
1210 end: end.to_string(),
1211 step: step.to_string(),
1212 });
1213 }
1214 return Ok(v);
1215 }
1216 let mut v = Vec::new();
1217 let mut x = start as isize;
1218 let end_i = end as isize;
1219 let st = (step as isize).max(1);
1220 while x >= end_i {
1221 v.push(x as usize);
1222 x -= st;
1223 }
1224 if v.is_empty() {
1225 return Err(SrsiError::InvalidRange {
1226 start: start.to_string(),
1227 end: end.to_string(),
1228 step: step.to_string(),
1229 });
1230 }
1231 Ok(v)
1232 }
1233 let rsi_periods = axis(r.rsi_period)?;
1234 let stoch_periods = axis(r.stoch_period)?;
1235 let ks = axis(r.k)?;
1236 let ds = axis(r.d)?;
1237
1238 if rsi_periods.is_empty() || stoch_periods.is_empty() || ks.is_empty() || ds.is_empty() {
1239 return Err(SrsiError::InvalidRange {
1240 start: r.rsi_period.0.to_string(),
1241 end: r.rsi_period.1.to_string(),
1242 step: r.rsi_period.2.to_string(),
1243 });
1244 }
1245
1246 let mut out = Vec::with_capacity(rsi_periods.len() * stoch_periods.len() * ks.len() * ds.len());
1247 for &rsi_p in &rsi_periods {
1248 for &stoch_p in &stoch_periods {
1249 for &k in &ks {
1250 for &d in &ds {
1251 out.push(SrsiParams {
1252 rsi_period: Some(rsi_p),
1253 stoch_period: Some(stoch_p),
1254 k: Some(k),
1255 d: Some(d),
1256 source: None,
1257 });
1258 }
1259 }
1260 }
1261 }
1262 Ok(out)
1263}
1264
1265#[inline(always)]
1266pub fn srsi_batch_with_kernel(
1267 data: &[f64],
1268 sweep: &SrsiBatchRange,
1269 k: Kernel,
1270) -> Result<SrsiBatchOutput, SrsiError> {
1271 let kernel = match k {
1272 Kernel::Auto => detect_best_batch_kernel(),
1273 other if other.is_batch() => other,
1274 _ => return Err(SrsiError::InvalidKernelForBatch(k)),
1275 };
1276 let simd = match kernel {
1277 Kernel::Avx512Batch => Kernel::Avx512,
1278 Kernel::Avx2Batch => Kernel::Avx2,
1279 Kernel::ScalarBatch => Kernel::Scalar,
1280 _ => unreachable!(),
1281 };
1282 srsi_batch_par_slice(data, sweep, simd)
1283}
1284
1285#[inline(always)]
1286pub fn srsi_batch_slice(
1287 data: &[f64],
1288 sweep: &SrsiBatchRange,
1289 kern: Kernel,
1290) -> Result<SrsiBatchOutput, SrsiError> {
1291 srsi_batch_inner(data, sweep, kern, false)
1292}
1293
1294#[inline(always)]
1295pub fn srsi_batch_par_slice(
1296 data: &[f64],
1297 sweep: &SrsiBatchRange,
1298 kern: Kernel,
1299) -> Result<SrsiBatchOutput, SrsiError> {
1300 srsi_batch_inner(data, sweep, kern, true)
1301}
1302
1303#[inline(always)]
1304fn srsi_batch_inner_into(
1305 data: &[f64],
1306 sweep: &SrsiBatchRange,
1307 kern: Kernel,
1308 parallel: bool,
1309 k_out: &mut [f64],
1310 d_out: &mut [f64],
1311) -> Result<Vec<SrsiParams>, SrsiError> {
1312 let combos = expand_grid(sweep)?;
1313 if combos.is_empty() {
1314 return Err(SrsiError::InvalidRange {
1315 start: sweep.rsi_period.0.to_string(),
1316 end: sweep.rsi_period.1.to_string(),
1317 step: sweep.rsi_period.2.to_string(),
1318 });
1319 }
1320
1321 let first = data
1322 .iter()
1323 .position(|x| !x.is_nan())
1324 .ok_or(SrsiError::AllValuesNaN)?;
1325
1326 use std::collections::{BTreeMap, BTreeSet};
1327 let mut rsi_cache: BTreeMap<usize, Vec<f64>> = BTreeMap::new();
1328 let uniq_rsi: BTreeSet<usize> = combos.iter().map(|c| c.rsi_period.unwrap()).collect();
1329 for rp in uniq_rsi {
1330 let rsi_in = RsiInput::from_slice(data, RsiParams { period: Some(rp) });
1331 let rsi_out = rsi(&rsi_in)?;
1332 rsi_cache.insert(rp, rsi_out.values);
1333 }
1334
1335 let max_period = combos
1336 .iter()
1337 .map(|c| {
1338 c.rsi_period
1339 .unwrap()
1340 .max(c.stoch_period.unwrap())
1341 .max(c.k.unwrap())
1342 .max(c.d.unwrap())
1343 })
1344 .max()
1345 .unwrap();
1346
1347 if data.len() - first < max_period {
1348 return Err(SrsiError::NotEnoughValidData {
1349 needed: max_period,
1350 valid: data.len() - first,
1351 });
1352 }
1353
1354 let cols = data.len();
1355
1356 let do_row = |row: usize, k_row: &mut [f64], d_row: &mut [f64]| -> Result<(), SrsiError> {
1357 let prm = &combos[row];
1358 let rsi_vals = rsi_cache.get(&prm.rsi_period.unwrap()).expect("cached rsi");
1359 let st_in = StochInput {
1360 data: crate::indicators::stoch::StochData::Slices {
1361 high: rsi_vals,
1362 low: rsi_vals,
1363 close: rsi_vals,
1364 },
1365 params: StochParams {
1366 fastk_period: prm.stoch_period,
1367 slowk_period: prm.k,
1368 slowk_ma_type: Some("sma".to_string()),
1369 slowd_period: prm.d,
1370 slowd_ma_type: Some("sma".to_string()),
1371 },
1372 };
1373 let st = stoch(&st_in)?;
1374 k_row.copy_from_slice(&st.k);
1375 d_row.copy_from_slice(&st.d);
1376 Ok(())
1377 };
1378
1379 if parallel {
1380 #[cfg(not(target_arch = "wasm32"))]
1381 {
1382 k_out
1383 .par_chunks_mut(cols)
1384 .zip(d_out.par_chunks_mut(cols))
1385 .enumerate()
1386 .try_for_each(|(row, (k_row, d_row))| do_row(row, k_row, d_row))?;
1387 }
1388
1389 #[cfg(target_arch = "wasm32")]
1390 {
1391 for (row, (k_row, d_row)) in k_out
1392 .chunks_mut(cols)
1393 .zip(d_out.chunks_mut(cols))
1394 .enumerate()
1395 {
1396 do_row(row, k_row, d_row)?;
1397 }
1398 }
1399 } else {
1400 for (row, (k_row, d_row)) in k_out
1401 .chunks_mut(cols)
1402 .zip(d_out.chunks_mut(cols))
1403 .enumerate()
1404 {
1405 do_row(row, k_row, d_row)?;
1406 }
1407 }
1408
1409 Ok(combos)
1410}
1411
1412#[inline(always)]
1413fn srsi_batch_inner(
1414 data: &[f64],
1415 sweep: &SrsiBatchRange,
1416 kern: Kernel,
1417 parallel: bool,
1418) -> Result<SrsiBatchOutput, SrsiError> {
1419 let combos = expand_grid(sweep)?;
1420 if combos.is_empty() {
1421 return Err(SrsiError::InvalidRange {
1422 start: sweep.rsi_period.0.to_string(),
1423 end: sweep.rsi_period.1.to_string(),
1424 step: sweep.rsi_period.2.to_string(),
1425 });
1426 }
1427 let first = data
1428 .iter()
1429 .position(|x| !x.is_nan())
1430 .ok_or(SrsiError::AllValuesNaN)?;
1431 let max_period = combos
1432 .iter()
1433 .map(|c| {
1434 c.rsi_period
1435 .unwrap()
1436 .max(c.stoch_period.unwrap())
1437 .max(c.k.unwrap())
1438 .max(c.d.unwrap())
1439 })
1440 .max()
1441 .unwrap();
1442 if data.len() - first < max_period {
1443 return Err(SrsiError::NotEnoughValidData {
1444 needed: max_period,
1445 valid: data.len() - first,
1446 });
1447 }
1448 let rows = combos.len();
1449 let cols = data.len();
1450 let _ = rows
1451 .checked_mul(cols)
1452 .ok_or_else(|| SrsiError::InvalidRange {
1453 start: sweep.rsi_period.0.to_string(),
1454 end: sweep.rsi_period.1.to_string(),
1455 step: sweep.rsi_period.2.to_string(),
1456 })?;
1457
1458 if rows == 1 {
1459 let prm = &combos[0];
1460 let res = unsafe {
1461 match kern {
1462 Kernel::Avx512 | Kernel::Avx512Batch => {
1463 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1464 {
1465 srsi_avx512(
1466 data,
1467 prm.rsi_period.unwrap(),
1468 prm.stoch_period.unwrap(),
1469 prm.k.unwrap(),
1470 prm.d.unwrap(),
1471 )
1472 }
1473 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
1474 {
1475 srsi_scalar(
1476 data,
1477 prm.rsi_period.unwrap(),
1478 prm.stoch_period.unwrap(),
1479 prm.k.unwrap(),
1480 prm.d.unwrap(),
1481 )
1482 }
1483 }
1484 Kernel::Avx2 | Kernel::Avx2Batch => {
1485 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1486 {
1487 srsi_avx2(
1488 data,
1489 prm.rsi_period.unwrap(),
1490 prm.stoch_period.unwrap(),
1491 prm.k.unwrap(),
1492 prm.d.unwrap(),
1493 )
1494 }
1495 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
1496 {
1497 srsi_scalar(
1498 data,
1499 prm.rsi_period.unwrap(),
1500 prm.stoch_period.unwrap(),
1501 prm.k.unwrap(),
1502 prm.d.unwrap(),
1503 )
1504 }
1505 }
1506 _ => srsi_scalar(
1507 data,
1508 prm.rsi_period.unwrap(),
1509 prm.stoch_period.unwrap(),
1510 prm.k.unwrap(),
1511 prm.d.unwrap(),
1512 ),
1513 }
1514 }?;
1515 return Ok(SrsiBatchOutput {
1516 k: res.k,
1517 d: res.d,
1518 combos,
1519 rows: 1,
1520 cols,
1521 });
1522 }
1523 let mut k_vals = make_uninit_matrix(rows, cols);
1524 let mut d_vals = make_uninit_matrix(rows, cols);
1525
1526 fn warm_for(c: &SrsiParams, first: usize) -> usize {
1527 let rp = c.rsi_period.unwrap();
1528 let sp = c.stoch_period.unwrap();
1529 let kp = c.k.unwrap();
1530 let dp = c.d.unwrap();
1531
1532 first + rp - 1 + sp - 1 + kp.max(dp) - 1
1533 }
1534
1535 let warmup_periods: Vec<usize> = combos
1536 .iter()
1537 .map(|c| warm_for(c, first).min(cols))
1538 .collect();
1539
1540 init_matrix_prefixes(&mut k_vals, cols, &warmup_periods);
1541 init_matrix_prefixes(&mut d_vals, cols, &warmup_periods);
1542
1543 let mut k_guard = core::mem::ManuallyDrop::new(k_vals);
1544 let mut d_guard = core::mem::ManuallyDrop::new(d_vals);
1545 let k_out: &mut [f64] =
1546 unsafe { core::slice::from_raw_parts_mut(k_guard.as_mut_ptr() as *mut f64, k_guard.len()) };
1547 let d_out: &mut [f64] =
1548 unsafe { core::slice::from_raw_parts_mut(d_guard.as_mut_ptr() as *mut f64, d_guard.len()) };
1549
1550 let combos = srsi_batch_inner_into(data, sweep, kern, parallel, k_out, d_out)?;
1551
1552 let k_values = unsafe {
1553 Vec::from_raw_parts(
1554 k_guard.as_mut_ptr() as *mut f64,
1555 k_guard.len(),
1556 k_guard.capacity(),
1557 )
1558 };
1559
1560 let d_values = unsafe {
1561 Vec::from_raw_parts(
1562 d_guard.as_mut_ptr() as *mut f64,
1563 d_guard.len(),
1564 d_guard.capacity(),
1565 )
1566 };
1567
1568 Ok(SrsiBatchOutput {
1569 k: k_values,
1570 d: d_values,
1571 combos,
1572 rows,
1573 cols,
1574 })
1575}
1576
1577#[inline(always)]
1578pub fn expand_grid_srsi(r: &SrsiBatchRange) -> Result<Vec<SrsiParams>, SrsiError> {
1579 expand_grid(r)
1580}
1581
/// Python-visible handle to a 2-D `f32` array that lives in CUDA device memory.
///
/// Exposed to Python as `ta_indicators.cuda.SrsiDeviceArrayF32`; declared
/// `unsendable`, so it must stay on the thread that created it.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", name = "SrsiDeviceArrayF32", unsendable)]
pub struct SrsiDeviceArrayF32Py {
    /// The device buffer together with its `rows` / `cols` dimensions.
    pub(crate) inner: DeviceArrayF32,
    /// Shared CUDA context handle. Stored but never read here.
    /// NOTE(review): presumably held to keep the context alive as long as the buffer.
    pub(crate) _ctx: Arc<Context>,
    /// Ordinal of the CUDA device, reported through `__dlpack_device__`.
    pub(crate) device_id: u32,
}
1589
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl SrsiDeviceArrayF32Py {
    /// Builds the `__cuda_array_interface__` dictionary (version 3) describing the
    /// buffer as a row-major little-endian `f32` matrix; the `data` entry carries
    /// the device pointer with the read-only flag set to `false`.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let elem_bytes = std::mem::size_of::<f32>();
        let (rows, cols) = (self.inner.rows, self.inner.cols);

        let iface = PyDict::new(py);
        iface.set_item("shape", (rows, cols))?;
        iface.set_item("typestr", "<f4")?;
        iface.set_item("strides", (cols * elem_bytes, elem_bytes))?;
        iface.set_item("data", (self.inner.device_ptr() as usize, false))?;

        iface.set_item("version", 3)?;
        Ok(iface)
    }

    /// Returns the DLPack `(device_type, device_id)` pair. The device type is the
    /// constant `2` (the DLPack code for CUDA devices).
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.device_id as i32)
    }

    /// Exports the buffer through DLPack, transferring ownership to the consumer.
    ///
    /// After a successful call `self.inner` is replaced by an empty `0 x 0`
    /// placeholder, so the array can only be exported once.
    ///
    /// Raises `ValueError` when `stream` is the integer `0`, when `dl_device` names a
    /// different device than the one holding the data, or when `copy=True`.
    #[pyo3(signature=(stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<PyObject>,
        max_version: Option<PyObject>,
        dl_device: Option<PyObject>,
        copy: Option<PyObject>,
    ) -> PyResult<PyObject> {
        // Only an integer stream equal to 0 is rejected; non-integer values are ignored.
        let stream_value = stream.as_ref().and_then(|s| s.extract::<usize>(py).ok());
        if stream_value == Some(0) {
            return Err(PyValueError::new_err(
                "__dlpack__ stream=0 is invalid for CUDA",
            ));
        }

        let (kdl, alloc_dev) = self.__dlpack_device__();
        let requested = dl_device
            .as_ref()
            .and_then(|dev| dev.extract::<(i32, i32)>(py).ok());
        if let Some((dev_ty, dev_id)) = requested {
            if dev_ty != kdl || dev_id != alloc_dev {
                let wants_copy = copy
                    .as_ref()
                    .and_then(|c| c.extract::<bool>(py).ok())
                    .unwrap_or(false);
                let message = if wants_copy {
                    "device copy not implemented for __dlpack__"
                } else {
                    "dl_device mismatch for __dlpack__ on SrsiDeviceArrayF32"
                };
                return Err(PyValueError::new_err(message));
            }
        }

        if let Some(copy_obj) = copy.as_ref() {
            if copy_obj.extract::<bool>(py)? {
                return Err(PyValueError::new_err(
                    "copy=True not supported for SrsiDeviceArrayF32",
                ));
            }
        }

        // Swap the real buffer out for an empty one so it can be moved into the capsule.
        let placeholder =
            DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let (rows, cols) = (self.inner.rows, self.inner.cols);
        let taken = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32 {
                buf: placeholder,
                rows: 0,
                cols: 0,
            },
        );

        export_f32_cuda_dlpack_2d(
            py,
            taken.buf,
            rows,
            cols,
            alloc_dev,
            max_version.map(|obj| obj.into_bound(py)),
        )
    }
}
1678
#[cfg(all(feature = "python", feature = "cuda"))]
impl SrsiDeviceArrayF32Py {
    /// Wraps a device array produced on the Rust side, storing the context handle
    /// and the device ordinal alongside it.
    pub fn new_from_rust(inner: DeviceArrayF32, ctx_guard: Arc<Context>, device_id: u32) -> Self {
        let _ctx = ctx_guard;
        Self {
            inner,
            _ctx,
            device_id,
        }
    }
}
1689
/// Python `srsi_cuda_batch_dev`: runs the SRSI parameter sweep on a CUDA device and
/// returns the results as device-resident arrays.
///
/// The returned dict contains `k` and `d` (`SrsiDeviceArrayF32` handles), `rows`
/// (number of combinations), `cols` (input length) and the `u64` arrays
/// `rsi_periods`, `stoch_periods`, `k_periods`, `d_periods` listing the parameters
/// of each row.
///
/// Raises `ValueError` if CUDA is unavailable or the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "srsi_cuda_batch_dev")]
#[pyo3(signature = (data_f32, rsi_range, stoch_range, k_range, d_range, device_id=0))]
pub fn srsi_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    data_f32: numpy::PyReadonlyArray1<'py, f32>,
    rsi_range: (usize, usize, usize),
    stoch_range: (usize, usize, usize),
    k_range: (usize, usize, usize),
    d_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use crate::cuda::cuda_available;
    use crate::cuda::oscillators::CudaSrsi;
    use numpy::IntoPyArray;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let slice = data_f32.as_slice()?;
    let sweep = SrsiBatchRange {
        rsi_period: rsi_range,
        stoch_period: stoch_range,
        k: k_range,
        d: d_range,
    };

    // The GPU work runs with the GIL released.
    let ((pair, combos), ctx, dev_id) = py.allow_threads(|| {
        let cuda = CudaSrsi::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id();
        let res = cuda
            .srsi_batch_dev(slice, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, pyo3::PyErr>((res, ctx, dev_id))
    })?;

    let dict = pyo3::types::PyDict::new(py);
    let k_handle = SrsiDeviceArrayF32Py::new_from_rust(pair.k, ctx.clone(), dev_id);
    dict.set_item("k", k_handle)?;
    let d_handle = SrsiDeviceArrayF32Py::new_from_rust(pair.d, ctx, dev_id);
    dict.set_item("d", d_handle)?;
    dict.set_item("rows", combos.len())?;
    dict.set_item("cols", slice.len())?;

    // One metadata column per swept parameter, in row order.
    let columns: [(&str, Vec<u64>); 4] = [
        (
            "rsi_periods",
            combos
                .iter()
                .map(|p| p.rsi_period.unwrap() as u64)
                .collect(),
        ),
        (
            "stoch_periods",
            combos
                .iter()
                .map(|p| p.stoch_period.unwrap() as u64)
                .collect(),
        ),
        (
            "k_periods",
            combos.iter().map(|p| p.k.unwrap() as u64).collect(),
        ),
        (
            "d_periods",
            combos.iter().map(|p| p.d.unwrap() as u64).collect(),
        ),
    ];
    for (key, column) in columns {
        dict.set_item(key, column.into_pyarray(py))?;
    }

    Ok(dict)
}
1769
/// Python `srsi_cuda_many_series_one_param_dev`: computes SRSI with one parameter
/// set for a whole 2-D input on a CUDA device.
///
/// `data_tm_f32` must be a contiguous 2-D `f32` array; `shape[0]` is reported back
/// as `rows` and `shape[1]` as `cols`.
/// NOTE(review): the name suggests a time-major layout (rows = time steps,
/// columns = series); confirm against `CudaSrsi`.
///
/// The returned dict contains `k` and `d` (`SrsiDeviceArrayF32` handles), `rows`,
/// `cols` and the scalar parameters `rsi_period`, `stoch_period`, `k_period`,
/// `d_period`.
///
/// Raises `ValueError` if CUDA is unavailable, the input is not 2-D, its flattened
/// length does not equal `rows * cols`, or the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "srsi_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, rsi_period=14, stoch_period=14, k=3, d=3, device_id=0))]
pub fn srsi_cuda_many_series_one_param_dev_py<'py>(
    py: Python<'py>,
    data_tm_f32: numpy::PyReadonlyArray2<'py, f32>,
    rsi_period: usize,
    stoch_period: usize,
    k: usize,
    d: usize,
    device_id: usize,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use crate::cuda::cuda_available;
    use crate::cuda::oscillators::CudaSrsi;
    use numpy::PyUntypedArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let shape = data_tm_f32.shape();
    if shape.len() != 2 {
        return Err(PyValueError::new_err("expected 2D array"));
    }
    let rows = shape[0];
    let cols = shape[1];

    let flat = data_tm_f32.as_slice()?;
    let expected = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;
    if flat.len() != expected {
        return Err(PyValueError::new_err("time-major input length mismatch"));
    }

    let params = SrsiParams {
        rsi_period: Some(rsi_period),
        stoch_period: Some(stoch_period),
        k: Some(k),
        d: Some(d),
        source: None,
    };

    // The GPU work runs with the GIL released.
    let (pair, ctx, dev_id) = py.allow_threads(|| {
        let cuda = CudaSrsi::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id();
        let res = cuda
            .srsi_many_series_one_param_time_major_dev(flat, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, pyo3::PyErr>((res, ctx, dev_id))
    })?;

    let dict = pyo3::types::PyDict::new(py);
    dict.set_item(
        "k",
        SrsiDeviceArrayF32Py::new_from_rust(pair.k, ctx.clone(), dev_id),
    )?;
    dict.set_item(
        "d",
        SrsiDeviceArrayF32Py::new_from_rust(pair.d, ctx, dev_id),
    )?;
    dict.set_item("rows", rows)?;
    dict.set_item("cols", cols)?;
    dict.set_item("rsi_period", rsi_period)?;
    dict.set_item("stoch_period", stoch_period)?;
    dict.set_item("k_period", k)?;
    dict.set_item("d_period", d)?;
    Ok(dict)
}
1834
/// Python `srsi`: computes Stochastic RSI for a 1-D `f64` array and returns the
/// `(k, d)` pair as NumPy arrays.
///
/// Any of `rsi_period`, `stoch_period`, `k`, `d` may be omitted (`None`), in which
/// case the value is left unset in `SrsiParams`. `source` is stored in the
/// parameters and `kernel` is checked by `validate_kernel`.
///
/// Raises `ValueError` if the array is not contiguous, the kernel name is rejected,
/// any supplied period is `0`, or the computation itself fails.
#[cfg(feature = "python")]
#[pyfunction(name = "srsi")]
#[pyo3(signature = (data, rsi_period=None, stoch_period=None, k=None, d=None, source=None, kernel=None))]
pub fn srsi_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    rsi_period: Option<usize>,
    stoch_period: Option<usize>,
    k: Option<usize>,
    d: Option<usize>,
    source: Option<&str>,
    kernel: Option<&str>,
) -> PyResult<(Bound<'py, PyArray1<f64>>, Bound<'py, PyArray1<f64>>)> {
    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, false)?;

    // Zero periods are rejected here, so the computation below never sees one and
    // its errors can be forwarded verbatim.
    let has_zero_period = [rsi_period, stoch_period, k, d].contains(&Some(0));
    if has_zero_period {
        return Err(PyValueError::new_err("Invalid period: values must be > 0"));
    }

    let params = SrsiParams {
        rsi_period,
        stoch_period,
        k,
        d,
        source: source.map(|s| s.to_string()),
    };
    let input = SrsiInput::from_slice(slice_in, params);

    // The computation runs with the GIL released.
    let (k_vec, d_vec) = py
        .allow_threads(|| srsi_with_kernel(&input, kern).map(|o| (o.k, o.d)))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok((k_vec.into_pyarray(py), d_vec.into_pyarray(py)))
}
1886
/// Python-visible wrapper (`SrsiStream`) around the Rust streaming SRSI calculator.
#[cfg(feature = "python")]
#[pyclass(name = "SrsiStream")]
pub struct SrsiStreamPy {
    /// The wrapped stream; every `update` call is forwarded to it.
    stream: SrsiStream,
}
1892
#[cfg(feature = "python")]
#[pymethods]
impl SrsiStreamPy {
    /// Creates a stream from optional parameters; each omitted argument is passed to
    /// `SrsiStream::try_new` as `None`.
    ///
    /// The explicit signature gives every argument a `None` default, matching the
    /// `srsi` function, so the constructor does not rely on implicit defaults for
    /// `Option` arguments.
    ///
    /// Raises `ValueError` if `SrsiStream::try_new` rejects the parameters.
    #[new]
    #[pyo3(signature = (rsi_period=None, stoch_period=None, k=None, d=None))]
    fn new(
        rsi_period: Option<usize>,
        stoch_period: Option<usize>,
        k: Option<usize>,
        d: Option<usize>,
    ) -> PyResult<Self> {
        let params = SrsiParams {
            rsi_period,
            stoch_period,
            k,
            d,
            source: None,
        };
        let stream =
            SrsiStream::try_new(params).map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok(SrsiStreamPy { stream })
    }

    /// Feeds one value into the stream and returns whatever the wrapped stream
    /// yields: `None`, or a `(k, d)` tuple.
    /// NOTE(review): presumably `None` is returned while the stream is warming up.
    fn update(&mut self, value: f64) -> Option<(f64, f64)> {
        self.stream.update(value)
    }
}
1919
/// Python `srsi_batch`: evaluates SRSI for every combination of the four
/// `(start, end, step)` ranges and returns the results in a dict.
///
/// The dict contains `k` and `d` as `(rows, cols)` `f64` arrays (one row per
/// combination, `cols == len(data)`) plus the `u64` arrays `rsi_periods`,
/// `stoch_periods`, `k_periods`, `d_periods` listing each row's parameters.
///
/// `source` is accepted but not used by this function. `kernel` is checked by
/// `validate_kernel`; `Auto` is resolved with `detect_best_batch_kernel()`.
///
/// Raises `ValueError` if the input is not contiguous, the kernel is rejected, the
/// ranges are invalid, `rows * cols` overflows, or the computation fails.
#[cfg(feature = "python")]
#[pyfunction(name = "srsi_batch")]
#[pyo3(signature = (data, rsi_period_range, stoch_period_range, k_range, d_range, source=None, kernel=None))]
pub fn srsi_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    rsi_period_range: (usize, usize, usize),
    stoch_period_range: (usize, usize, usize),
    k_range: (usize, usize, usize),
    d_range: (usize, usize, usize),
    source: Option<&str>,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, true)?;

    let sweep = SrsiBatchRange {
        rsi_period: rsi_period_range,
        stoch_period: stoch_period_range,
        k: k_range,
        d: d_range,
    };

    // The grid is expanded here only to size the output arrays.
    let rows = expand_grid(&sweep)
        .map_err(|e| PyValueError::new_err(e.to_string()))?
        .len();
    let cols = slice_in.len();
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;

    // Flat output arrays are allocated without initialisation and filled in place.
    let k_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let d_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let k_slice = unsafe { k_arr.as_slice_mut()? };
    let d_slice = unsafe { d_arr.as_slice_mut()? };

    // The computation runs with the GIL released.
    let combos = py
        .allow_threads(|| {
            let resolved = if matches!(kern, Kernel::Auto) {
                detect_best_batch_kernel()
            } else {
                kern
            };
            srsi_batch_inner_into(slice_in, &sweep, resolved, true, k_slice, d_slice)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("k", k_arr.reshape((rows, cols))?)?;
    dict.set_item("d", d_arr.reshape((rows, cols))?)?;

    // One metadata column per swept parameter, in row order.
    let columns: [(&str, Vec<u64>); 4] = [
        (
            "rsi_periods",
            combos
                .iter()
                .map(|p| p.rsi_period.unwrap() as u64)
                .collect(),
        ),
        (
            "stoch_periods",
            combos
                .iter()
                .map(|p| p.stoch_period.unwrap() as u64)
                .collect(),
        ),
        (
            "k_periods",
            combos.iter().map(|p| p.k.unwrap() as u64).collect(),
        ),
        (
            "d_periods",
            combos.iter().map(|p| p.d.unwrap() as u64).collect(),
        ),
    ];
    for (key, column) in columns {
        dict.set_item(key, column.into_pyarray(py))?;
    }

    Ok(dict)
}
2003
2004pub fn srsi_into_slice(
2005 dst_k: &mut [f64],
2006 dst_d: &mut [f64],
2007 input: &SrsiInput,
2008 kern: Kernel,
2009) -> Result<(), SrsiError> {
2010 let data: &[f64] = input.as_ref();
2011
2012 if dst_k.len() != data.len() || dst_d.len() != data.len() {
2013 return Err(SrsiError::OutputLengthMismatch {
2014 expected: data.len(),
2015 k_len: dst_k.len(),
2016 d_len: dst_d.len(),
2017 });
2018 }
2019
2020 let out = srsi_with_kernel(input, kern)?;
2021 dst_k.copy_from_slice(&out.k);
2022 dst_d.copy_from_slice(&out.d);
2023
2024 Ok(())
2025}
2026
2027#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
2028#[inline]
2029pub fn srsi_into(input: &SrsiInput, out_k: &mut [f64], out_d: &mut [f64]) -> Result<(), SrsiError> {
2030 srsi_into_slice(out_k, out_d, input, Kernel::Auto)
2031}
2032
/// wasm binding: computes SRSI and returns one flat vector holding the %K series
/// followed by the %D series.
///
/// Fails with a `"srsi: ..."` message when `data` is empty, any period is zero, or
/// the computation itself reports an error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn srsi_js(
    data: &[f64],
    rsi_period: usize,
    stoch_period: usize,
    k: usize,
    d: usize,
) -> Result<Vec<f64>, JsValue> {
    if data.is_empty() {
        return Err(JsValue::from_str("srsi: Input data is empty"));
    }

    let any_zero = [rsi_period, stoch_period, k, d].contains(&0);
    if any_zero {
        return Err(JsValue::from_str("srsi: Invalid period"));
    }

    let input = SrsiInput::from_slice(
        data,
        SrsiParams {
            rsi_period: Some(rsi_period),
            stoch_period: Some(stoch_period),
            k: Some(k),
            d: Some(d),
            source: None,
        },
    );

    let out = match srsi_with_kernel(&input, Kernel::Auto) {
        Ok(out) => out,
        Err(e) => return Err(JsValue::from_str(&format!("srsi: {}", e))),
    };

    // %K first, then %D, in a single allocation.
    let mut values = Vec::with_capacity(2 * data.len());
    values.extend(out.k.iter().copied());
    values.extend(out.d.iter().copied());

    Ok(values)
}
2067
/// wasm binding: allocates room for `len` `f64` values and returns the raw pointer.
///
/// Ownership passes to the caller; the allocation is deliberately not freed here
/// and is meant to be released with `srsi_free` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn srsi_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` keeps the buffer alive after this function returns.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
2076
/// wasm binding: releases a buffer of `len` `f64` values by rebuilding the `Vec`
/// from `ptr` and dropping it.
///
/// A null `ptr` is ignored, because `Vec::from_raw_parts` must never be given a
/// null pointer.
///
/// `ptr` and `len` are expected to be exactly the pointer returned by `srsi_alloc`
/// and the length passed to it, and the buffer must not be freed twice.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn srsi_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` is non-null; the caller guarantees it came from `srsi_alloc(len)`
    // and has not been freed yet, so length and capacity `len` describe the allocation.
    unsafe {
        drop(Vec::from_raw_parts(ptr, len, len));
    }
}
2084
/// wasm binding: computes SRSI for the `len` values at `in_ptr` and writes the
/// %K / %D series to `k_ptr` / `d_ptr`.
///
/// All three addresses must point to buffers of at least `len` `f64` values. When
/// any two addresses are equal the results are first computed into temporary
/// buffers and then copied out (%K first, then %D).
///
/// Fails when an address is null, a period is zero, or the computation reports an
/// error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn srsi_into(
    in_ptr: usize,
    k_ptr: usize,
    d_ptr: usize,
    len: usize,
    rsi_period: usize,
    stoch_period: usize,
    k: usize,
    d: usize,
) -> Result<(), JsValue> {
    // Building a slice from a null address is undefined behaviour, so reject it first.
    if in_ptr == 0 || k_ptr == 0 || d_ptr == 0 {
        return Err(JsValue::from_str("null pointer passed to srsi_into"));
    }

    if rsi_period == 0 || stoch_period == 0 || k == 0 || d == 0 {
        return Err(JsValue::from_str("Invalid period"));
    }

    let params = SrsiParams {
        rsi_period: Some(rsi_period),
        stoch_period: Some(stoch_period),
        k: Some(k),
        d: Some(d),
        source: None,
    };

    // SAFETY: the addresses are non-null and the caller guarantees each one refers
    // to `len` valid `f64` values for the duration of the call.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr as *const f64, len);
        let input = SrsiInput::from_slice(data, params);

        let needs_temp = in_ptr == k_ptr || in_ptr == d_ptr || k_ptr == d_ptr;

        if needs_temp {
            let mut temp_k = vec![0.0; len];
            let mut temp_d = vec![0.0; len];
            srsi_into_slice(&mut temp_k, &mut temp_d, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;

            // Copy out one destination at a time so that two mutable slices over
            // the same memory never exist simultaneously (k_ptr may equal d_ptr).
            std::slice::from_raw_parts_mut(k_ptr as *mut f64, len).copy_from_slice(&temp_k);
            std::slice::from_raw_parts_mut(d_ptr as *mut f64, len).copy_from_slice(&temp_d);
        } else {
            let k_out = std::slice::from_raw_parts_mut(k_ptr as *mut f64, len);
            let d_out = std::slice::from_raw_parts_mut(d_ptr as *mut f64, len);
            srsi_into_slice(k_out, d_out, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(())
}
2135
/// Sweep configuration deserialized from the JavaScript object passed to
/// `srsi_batch`. Every field is a `(start, end, step)` triple.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct SrsiBatchConfig {
    /// Range for the RSI period.
    pub rsi_period_range: (usize, usize, usize),
    /// Range for the stochastic period.
    pub stoch_period_range: (usize, usize, usize),
    /// Range for the %K smoothing period.
    pub k_range: (usize, usize, usize),
    /// Range for the %D smoothing period.
    pub d_range: (usize, usize, usize),
}
2144
/// Batch result serialized back to JavaScript by `srsi_batch`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct SrsiBatchJsOutput {
    /// Flat %K values taken from the batch output's `k` buffer.
    pub k_values: Vec<f64>,
    /// Flat %D values taken from the batch output's `d` buffer.
    pub d_values: Vec<f64>,
    /// Number of parameter combinations.
    pub rows: usize,
    /// Length of the input series.
    pub cols: usize,
    /// Parameter set for each row, in row order.
    pub combos: Vec<SrsiParams>,
}
2154
/// wasm binding exported as `srsi_batch`: runs the sequential batch sweep described
/// by `config` over `data` and returns the result as a serialized
/// `SrsiBatchJsOutput`.
///
/// Fails with `"Invalid config: ..."` when `config` cannot be deserialized, with the
/// computation's own message when the sweep fails, or with
/// `"Serialization error: ..."` when the result cannot be converted back.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = srsi_batch)]
pub fn srsi_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let SrsiBatchConfig {
        rsi_period_range,
        stoch_period_range,
        k_range,
        d_range,
    } = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = SrsiBatchRange {
        rsi_period: rsi_period_range,
        stoch_period: stoch_period_range,
        k: k_range,
        d: d_range,
    };

    let SrsiBatchOutput {
        k,
        d,
        rows,
        cols,
        combos,
    } = srsi_batch_inner(data, &sweep, Kernel::Auto, false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let payload = SrsiBatchJsOutput {
        k_values: k,
        d_values: d,
        rows,
        cols,
        combos,
    };

    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2182
/// wasm binding: runs the sequential batch sweep over the `len` values at `in_ptr`
/// and writes the row-major %K / %D matrices to `k_ptr` / `d_ptr`.
///
/// Each output buffer must hold `rows * len` `f64` values, where `rows` is the
/// number of combinations generated by the four `(start, end, step)` ranges.
/// Returns `rows` on success.
///
/// Fails when an address is null, the ranges are invalid, `rows * len` overflows,
/// or the computation reports an error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn srsi_batch_into(
    in_ptr: usize,
    k_ptr: usize,
    d_ptr: usize,
    len: usize,
    rsi_period_start: usize,
    rsi_period_end: usize,
    rsi_period_step: usize,
    stoch_period_start: usize,
    stoch_period_end: usize,
    stoch_period_step: usize,
    k_start: usize,
    k_end: usize,
    k_step: usize,
    d_start: usize,
    d_end: usize,
    d_step: usize,
) -> Result<usize, JsValue> {
    // Building a slice from a null address is undefined behaviour, so reject it first.
    if in_ptr == 0 || k_ptr == 0 || d_ptr == 0 {
        return Err(JsValue::from_str("null pointer passed to srsi_batch_into"));
    }

    let sweep = SrsiBatchRange {
        rsi_period: (rsi_period_start, rsi_period_end, rsi_period_step),
        stoch_period: (stoch_period_start, stoch_period_end, stoch_period_step),
        k: (k_start, k_end, k_step),
        d: (d_start, d_end, d_step),
    };

    // The grid is expanded here only to size the output slices.
    let rows = expand_grid(&sweep)
        .map_err(|e| JsValue::from_str(&e.to_string()))?
        .len();
    let cols = len;
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;

    // SAFETY: the addresses are non-null and the caller guarantees that `in_ptr`
    // refers to `len` values and that `k_ptr` / `d_ptr` each refer to `total`
    // writable values that do not overlap the input or each other.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr as *const f64, len);
        let k_out = std::slice::from_raw_parts_mut(k_ptr as *mut f64, total);
        let d_out = std::slice::from_raw_parts_mut(d_ptr as *mut f64, total);

        srsi_batch_inner_into(data, &sweep, Kernel::Auto, false, k_out, d_out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(rows)
}
2229
2230#[cfg(test)]
2231mod tests {
2232 use super::*;
2233 use crate::skip_if_unsupported;
2234 use crate::utilities::data_loader::read_candles_from_csv;
2235
2236 fn check_srsi_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2237 skip_if_unsupported!(kernel, test_name);
2238 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2239 let candles = read_candles_from_csv(file_path)?;
2240 let default_params = SrsiParams {
2241 rsi_period: None,
2242 stoch_period: None,
2243 k: None,
2244 d: None,
2245 source: None,
2246 };
2247 let input = SrsiInput::from_candles(&candles, "close", default_params);
2248 let output = srsi_with_kernel(&input, kernel)?;
2249 assert_eq!(output.k.len(), candles.close.len());
2250 assert_eq!(output.d.len(), candles.close.len());
2251 Ok(())
2252 }
2253
2254 fn check_srsi_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2255 skip_if_unsupported!(kernel, test_name);
2256 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2257 let candles = read_candles_from_csv(file_path)?;
2258 let params = SrsiParams::default();
2259 let input = SrsiInput::from_candles(&candles, "close", params);
2260 let result = srsi_with_kernel(&input, kernel)?;
2261 assert_eq!(result.k.len(), candles.close.len());
2262 assert_eq!(result.d.len(), candles.close.len());
2263 let last_five_k = [
2264 65.52066633236464,
2265 61.22507053191985,
2266 57.220471530042644,
2267 64.61344854988147,
2268 60.66534359318523,
2269 ];
2270 let last_five_d = [
2271 64.33503158970049,
2272 64.42143544464182,
2273 61.32206946477942,
2274 61.01966353728503,
2275 60.83308789104016,
2276 ];
2277 let k_slice = &result.k[result.k.len() - 5..];
2278 let d_slice = &result.d[result.d.len() - 5..];
2279 for i in 0..5 {
2280 let diff_k = (k_slice[i] - last_five_k[i]).abs();
2281 let diff_d = (d_slice[i] - last_five_d[i]).abs();
2282 assert!(
2283 diff_k < 1e-6,
2284 "Mismatch in SRSI K at index {}: got {}, expected {}",
2285 i,
2286 k_slice[i],
2287 last_five_k[i]
2288 );
2289 assert!(
2290 diff_d < 1e-6,
2291 "Mismatch in SRSI D at index {}: got {}, expected {}",
2292 i,
2293 d_slice[i],
2294 last_five_d[i]
2295 );
2296 }
2297 Ok(())
2298 }
2299
2300 fn check_srsi_from_slice(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2301 skip_if_unsupported!(kernel, test_name);
2302 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2303 let candles = read_candles_from_csv(file_path)?;
2304 let slice_data = candles.close.as_slice();
2305 let params = SrsiParams {
2306 rsi_period: Some(3),
2307 stoch_period: Some(3),
2308 k: Some(2),
2309 d: Some(2),
2310 source: Some("close".to_string()),
2311 };
2312 let input = SrsiInput::from_slice(&slice_data, params);
2313 let output = srsi_with_kernel(&input, kernel)?;
2314 assert_eq!(output.k.len(), slice_data.len());
2315 assert_eq!(output.d.len(), slice_data.len());
2316 Ok(())
2317 }
2318
2319 fn check_srsi_custom_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2320 skip_if_unsupported!(kernel, test_name);
2321 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2322 let candles = read_candles_from_csv(file_path)?;
2323 let params = SrsiParams {
2324 rsi_period: Some(10),
2325 stoch_period: Some(10),
2326 k: Some(4),
2327 d: Some(4),
2328 source: Some("hlc3".to_string()),
2329 };
2330 let input = SrsiInput::from_candles(&candles, "hlc3", params);
2331 let output = srsi_with_kernel(&input, kernel)?;
2332 assert_eq!(output.k.len(), candles.close.len());
2333 assert_eq!(output.d.len(), candles.close.len());
2334 Ok(())
2335 }
2336
2337 #[cfg(debug_assertions)]
2338 fn check_srsi_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2339 skip_if_unsupported!(kernel, test_name);
2340
2341 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2342 let candles = read_candles_from_csv(file_path)?;
2343
2344 let test_params = vec![
2345 SrsiParams::default(),
2346 SrsiParams {
2347 rsi_period: Some(2),
2348 stoch_period: Some(2),
2349 k: Some(2),
2350 d: Some(2),
2351 source: None,
2352 },
2353 SrsiParams {
2354 rsi_period: Some(5),
2355 stoch_period: Some(5),
2356 k: Some(3),
2357 d: Some(3),
2358 source: None,
2359 },
2360 SrsiParams {
2361 rsi_period: Some(10),
2362 stoch_period: Some(10),
2363 k: Some(5),
2364 d: Some(5),
2365 source: None,
2366 },
2367 SrsiParams {
2368 rsi_period: Some(20),
2369 stoch_period: Some(20),
2370 k: Some(7),
2371 d: Some(7),
2372 source: None,
2373 },
2374 SrsiParams {
2375 rsi_period: Some(50),
2376 stoch_period: Some(50),
2377 k: Some(10),
2378 d: Some(10),
2379 source: None,
2380 },
2381 SrsiParams {
2382 rsi_period: Some(7),
2383 stoch_period: Some(14),
2384 k: Some(3),
2385 d: Some(5),
2386 source: None,
2387 },
2388 SrsiParams {
2389 rsi_period: Some(14),
2390 stoch_period: Some(7),
2391 k: Some(5),
2392 d: Some(3),
2393 source: None,
2394 },
2395 SrsiParams {
2396 rsi_period: Some(21),
2397 stoch_period: Some(14),
2398 k: Some(6),
2399 d: Some(4),
2400 source: None,
2401 },
2402 SrsiParams {
2403 rsi_period: Some(100),
2404 stoch_period: Some(100),
2405 k: Some(20),
2406 d: Some(20),
2407 source: None,
2408 },
2409 ];
2410
2411 for (param_idx, params) in test_params.iter().enumerate() {
2412 let input = SrsiInput::from_candles(&candles, "close", params.clone());
2413 let output = srsi_with_kernel(&input, kernel)?;
2414
2415 for (i, &val) in output.k.iter().enumerate() {
2416 if val.is_nan() {
2417 continue;
2418 }
2419
2420 let bits = val.to_bits();
2421
2422 if bits == 0x11111111_11111111 {
2423 panic!(
2424 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} in K output \
2425 with params: rsi_period={}, stoch_period={}, k={}, d={} (param set {})",
2426 test_name, val, bits, i,
2427 params.rsi_period.unwrap_or(14),
2428 params.stoch_period.unwrap_or(14),
2429 params.k.unwrap_or(3),
2430 params.d.unwrap_or(3),
2431 param_idx
2432 );
2433 }
2434
2435 if bits == 0x22222222_22222222 {
2436 panic!(
2437 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} in K output \
2438 with params: rsi_period={}, stoch_period={}, k={}, d={} (param set {})",
2439 test_name, val, bits, i,
2440 params.rsi_period.unwrap_or(14),
2441 params.stoch_period.unwrap_or(14),
2442 params.k.unwrap_or(3),
2443 params.d.unwrap_or(3),
2444 param_idx
2445 );
2446 }
2447
2448 if bits == 0x33333333_33333333 {
2449 panic!(
2450 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} in K output \
2451 with params: rsi_period={}, stoch_period={}, k={}, d={} (param set {})",
2452 test_name, val, bits, i,
2453 params.rsi_period.unwrap_or(14),
2454 params.stoch_period.unwrap_or(14),
2455 params.k.unwrap_or(3),
2456 params.d.unwrap_or(3),
2457 param_idx
2458 );
2459 }
2460 }
2461
2462 for (i, &val) in output.d.iter().enumerate() {
2463 if val.is_nan() {
2464 continue;
2465 }
2466
2467 let bits = val.to_bits();
2468
2469 if bits == 0x11111111_11111111 {
2470 panic!(
2471 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} in D output \
2472 with params: rsi_period={}, stoch_period={}, k={}, d={} (param set {})",
2473 test_name, val, bits, i,
2474 params.rsi_period.unwrap_or(14),
2475 params.stoch_period.unwrap_or(14),
2476 params.k.unwrap_or(3),
2477 params.d.unwrap_or(3),
2478 param_idx
2479 );
2480 }
2481
2482 if bits == 0x22222222_22222222 {
2483 panic!(
2484 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} in D output \
2485 with params: rsi_period={}, stoch_period={}, k={}, d={} (param set {})",
2486 test_name, val, bits, i,
2487 params.rsi_period.unwrap_or(14),
2488 params.stoch_period.unwrap_or(14),
2489 params.k.unwrap_or(3),
2490 params.d.unwrap_or(3),
2491 param_idx
2492 );
2493 }
2494
2495 if bits == 0x33333333_33333333 {
2496 panic!(
2497 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} in D output \
2498 with params: rsi_period={}, stoch_period={}, k={}, d={} (param set {})",
2499 test_name, val, bits, i,
2500 params.rsi_period.unwrap_or(14),
2501 params.stoch_period.unwrap_or(14),
2502 params.k.unwrap_or(3),
2503 params.d.unwrap_or(3),
2504 param_idx
2505 );
2506 }
2507 }
2508 }
2509
2510 Ok(())
2511 }
2512
/// Stand-in used when `debug_assertions` is off: the poison scan is compiled only
/// in debug builds, so this variant ignores its arguments and always succeeds.
#[cfg(not(debug_assertions))]
fn check_srsi_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
    Ok(())
}
2517
/// Property-based check of SRSI for `kernel`; compiled only with the `proptest` feature.
///
/// Generates random finite series together with random `rsi_period`, `stoch_period`,
/// `k` and `d` settings, runs both `kernel` and `Kernel::Scalar` on each case, and
/// asserts output bounds, warmup NaNs, agreement with the scalar result and
/// run-to-run determinism. A failing case panics through the final `unwrap()` on
/// the test runner result.
#[cfg(feature = "proptest")]
#[allow(clippy::float_cmp)]
fn check_srsi_property(
    test_name: &str,
    kernel: Kernel,
) -> Result<(), Box<dyn std::error::Error>> {
    use proptest::prelude::*;
    skip_if_unsupported!(kernel, test_name);

    // Periods are drawn first so the minimum series length can be derived from them.
    let strat = (2usize..=20, 2usize..=20, 2usize..=10, 2usize..=10).prop_flat_map(
        |(rsi_period, stoch_period, k, d)| {
            let min_data_needed = rsi_period + stoch_period.max(k).max(d) + 10;
            (
                prop::collection::vec(
                    (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
                    min_data_needed..400,
                ),
                Just(rsi_period),
                Just(stoch_period),
                Just(k),
                Just(d),
            )
        },
    );

    proptest::test_runner::TestRunner::default()
        .run(&strat, |(data, rsi_period, stoch_period, k, d)| {
            let params = SrsiParams {
                rsi_period: Some(rsi_period),
                stoch_period: Some(stoch_period),
                k: Some(k),
                d: Some(d),
                source: None,
            };
            let input = SrsiInput::from_slice(&data, params.clone());

            // Kernel under test plus the scalar run used as the reference below.
            let output_result = srsi_with_kernel(&input, kernel);
            let ref_output_result = srsi_with_kernel(&input, Kernel::Scalar);

            match (output_result, ref_output_result) {
                (Ok(output), Ok(ref_output)) => {
                    let expected_min_warmup = rsi_period;

                    // Every non-NaN K/D value must lie in [0, 100], with 1e-9 of slack.
                    for i in 0..data.len() {
                        if !output.k[i].is_nan() {
                            prop_assert!(
                                output.k[i] >= -1e-9 && output.k[i] <= 100.0 + 1e-9,
                                "idx {}: K value {} is out of bounds [0, 100]",
                                i,
                                output.k[i]
                            );
                        }
                        if !output.d[i].is_nan() {
                            prop_assert!(
                                output.d[i] >= -1e-9 && output.d[i] <= 100.0 + 1e-9,
                                "idx {}: D value {} is out of bounds [0, 100]",
                                i,
                                output.d[i]
                            );
                        }
                    }

                    // The first `rsi_period` entries of both outputs must be NaN.
                    for i in 0..expected_min_warmup.min(data.len()) {
                        prop_assert!(
                            output.k[i].is_nan(),
                            "idx {}: Expected NaN during early warmup for K, got {}",
                            i,
                            output.k[i]
                        );
                        prop_assert!(
                            output.d[i].is_nan(),
                            "idx {}: Expected NaN during early warmup for D, got {}",
                            i,
                            output.d[i]
                        );
                    }

                    // With more samples than the sum of all four settings, each output
                    // must contain at least one non-NaN value.
                    let has_valid_k = output.k.iter().any(|&x| !x.is_nan());
                    let has_valid_d = output.d.iter().any(|&x| !x.is_nan());
                    if data.len() > rsi_period + stoch_period + k + d {
                        prop_assert!(
                            has_valid_k,
                            "Expected at least one valid K value with sufficient data"
                        );
                        prop_assert!(
                            has_valid_d,
                            "Expected at least one valid D value with sufficient data"
                        );
                    }

                    // (Near-)constant input: the last K/D, when both are defined, must be
                    // within 10 of 50.
                    if data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10) {
                        let last_k = output.k[data.len() - 1];
                        let last_d = output.d[data.len() - 1];
                        if !last_k.is_nan() && !last_d.is_nan() {
                            prop_assert!(
                                (last_k - 50.0).abs() < 10.0,
                                "Constant data should produce K near 50, got {}",
                                last_k
                            );
                            prop_assert!(
                                (last_d - 50.0).abs() < 10.0,
                                "Constant data should produce D near 50, got {}",
                                last_d
                            );
                        }
                    }

                    // Strictly increasing input: the last K, when defined, must exceed 50.
                    let is_increasing = data.windows(2).all(|w| w[1] > w[0]);
                    if is_increasing && has_valid_k {
                        let last_k = output.k[data.len() - 1];
                        if !last_k.is_nan() {
                            prop_assert!(
                                last_k > 50.0,
                                "Strictly increasing prices should produce K > 50, got {}",
                                last_k
                            );
                        }
                    }

                    // Strictly decreasing input: the last K, when defined, must be below 50.
                    let is_decreasing = data.windows(2).all(|w| w[1] < w[0]);
                    if is_decreasing && has_valid_k {
                        let last_k = output.k[data.len() - 1];
                        if !last_k.is_nan() {
                            prop_assert!(
                                last_k < 50.0,
                                "Strictly decreasing prices should produce K < 50, got {}",
                                last_k
                            );
                        }
                    }

                    // Compare against the scalar reference: if either side is non-finite
                    // the bit patterns must match exactly; otherwise values must agree
                    // within 1e-9 absolute or differ by at most 4 in their bit patterns.
                    for i in 0..data.len() {
                        let k_val = output.k[i];
                        let d_val = output.d[i];
                        let ref_k = ref_output.k[i];
                        let ref_d = ref_output.d[i];

                        if !k_val.is_finite() || !ref_k.is_finite() {
                            prop_assert!(
                                k_val.to_bits() == ref_k.to_bits(),
                                "K finite/NaN mismatch idx {}: {} vs {}",
                                i,
                                k_val,
                                ref_k
                            );
                        } else {
                            let k_ulp_diff = k_val.to_bits().abs_diff(ref_k.to_bits());
                            prop_assert!(
                                (k_val - ref_k).abs() <= 1e-9 || k_ulp_diff <= 4,
                                "K mismatch idx {}: {} vs {} (ULP={})",
                                i,
                                k_val,
                                ref_k,
                                k_ulp_diff
                            );
                        }

                        if !d_val.is_finite() || !ref_d.is_finite() {
                            prop_assert!(
                                d_val.to_bits() == ref_d.to_bits(),
                                "D finite/NaN mismatch idx {}: {} vs {}",
                                i,
                                d_val,
                                ref_d
                            );
                        } else {
                            let d_ulp_diff = d_val.to_bits().abs_diff(ref_d.to_bits());
                            prop_assert!(
                                (d_val - ref_d).abs() <= 1e-9 || d_ulp_diff <= 4,
                                "D mismatch idx {}: {} vs {} (ULP={})",
                                i,
                                d_val,
                                ref_d,
                                d_ulp_diff
                            );
                        }
                    }

                    // A second run on the same input must reproduce the output bit-for-bit.
                    let output2 = srsi_with_kernel(&input, kernel).unwrap();
                    for i in 0..data.len() {
                        prop_assert!(
                            output.k[i].to_bits() == output2.k[i].to_bits(),
                            "K determinism failed at idx {}: {} vs {}",
                            i,
                            output.k[i],
                            output2.k[i]
                        );
                        prop_assert!(
                            output.d[i].to_bits() == output2.d[i].to_bits(),
                            "D determinism failed at idx {}: {} vs {}",
                            i,
                            output.d[i],
                            output2.d[i]
                        );
                    }
                }
                // Both runs rejecting the input counts as agreement.
                (Err(_), Err(_)) => {}
                // One run succeeding while the other fails is a mismatch.
                (Ok(_), Err(e)) => {
                    prop_assert!(false, "Kernel succeeded but scalar failed: {:?}", e);
                }
                (Err(e), Ok(_)) => {
                    prop_assert!(false, "Kernel failed but scalar succeeded: {:?}", e);
                }
            }

            Ok(())
        })
        .unwrap();

    Ok(())
}
2729
/// Generates per-kernel `#[test]` wrappers for each listed check function.
///
/// For every `$test_fn` this emits `<name>_scalar_f64`, plus `<name>_avx2_f64` and
/// `<name>_avx512_f64` when the `nightly-avx` feature is enabled on `x86_64`.
/// Each wrapper passes its own name and the matching `Kernel` to the check.
///
/// The `#[cfg]` gate is attached to each AVX wrapper individually: an attribute
/// placed in front of a `$( ... )*` repetition only applies to the first item of
/// the expansion. Wrappers return the check's `Result` so that an `Err` fails the
/// test instead of being discarded.
macro_rules! generate_all_srsi_tests {
    ($($test_fn:ident),* $(,)?) => {
        paste::paste! {
            $(
                #[test]
                fn [<$test_fn _scalar_f64>]() -> Result<(), Box<dyn std::error::Error>> {
                    $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar)
                }

                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$test_fn _avx2_f64>]() -> Result<(), Box<dyn std::error::Error>> {
                    $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2)
                }

                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$test_fn _avx512_f64>]() -> Result<(), Box<dyn std::error::Error>> {
                    $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512)
                }
            )*
        }
    }
}
2753
// Instantiate the per-kernel `#[test]` wrappers for each listed SRSI check function.
generate_all_srsi_tests!(
    check_srsi_partial_params,
    check_srsi_accuracy,
    check_srsi_custom_params,
    check_srsi_from_slice,
    check_srsi_no_poison
);
2761
/// Checks that `srsi_into_slice` rejects output buffers whose length differs from
/// the input length, reporting the expected and actual lengths through
/// `SrsiError::OutputLengthMismatch`, and that it accepts correctly sized buffers.
#[test]
fn test_srsi_into_slice_size_mismatch() {
    let data: Vec<f64> = (1..=50i32).map(f64::from).collect();
    let data_len = data.len();
    let input = SrsiInput::from_slice(&data, SrsiParams::default());

    // (K buffer length, D buffer length, panic text if no mismatch is reported)
    let bad_cases = [
        (
            30,
            data_len,
            "Expected SizeMismatch error with k buffer too small",
        ),
        (
            data_len,
            35,
            "Expected SizeMismatch error with d buffer too small",
        ),
        (
            60,
            70,
            "Expected SizeMismatch error with both buffers wrong size",
        ),
    ];

    for &(k_size, d_size, complaint) in bad_cases.iter() {
        let mut k_buf = vec![0.0; k_size];
        let mut d_buf = vec![0.0; d_size];
        let outcome = srsi_into_slice(&mut k_buf, &mut d_buf, &input, Kernel::Scalar);

        if let Err(SrsiError::OutputLengthMismatch {
            expected,
            k_len,
            d_len,
        }) = outcome
        {
            assert_eq!(expected, data_len);
            assert_eq!(k_len, k_size);
            assert_eq!(d_len, d_size);
        } else {
            panic!("{}", complaint);
        }
    }

    // Correctly sized buffers must be accepted.
    let mut k_ok = vec![0.0; data_len];
    let mut d_ok = vec![0.0; data_len];
    let outcome = srsi_into_slice(&mut k_ok, &mut d_ok, &input, Kernel::Scalar);
    assert!(
        outcome.is_ok(),
        "Should succeed with correct buffer sizes. Error: {:?}",
        outcome
    );
}
2826
// The property-based check is only instantiated when the `proptest` feature is enabled.
#[cfg(feature = "proptest")]
generate_all_srsi_tests!(check_srsi_property);
2829
/// Verifies that writing into caller-provided buffers yields exactly the values
/// returned by `srsi` on the same input: every position must be equal or NaN on
/// both sides. Uses `srsi_into` except on the wasm32 + `wasm` configuration, where
/// `srsi_into_slice` with `Kernel::Auto` is called instead.
#[test]
fn test_srsi_into_matches_api() {
    let candles =
        read_candles_from_csv("src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv").expect("load csv");
    let n = candles.close.len();
    let input = SrsiInput::from_candles(&candles, "close", SrsiParams::default());

    let base = srsi(&input).expect("srsi baseline");

    let mut out_k = vec![0.0; n];
    let mut out_d = vec![0.0; n];

    #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
    {
        srsi_into(&input, &mut out_k, &mut out_d).expect("srsi_into");
    }
    #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
    {
        srsi_into_slice(&mut out_k, &mut out_d, &input, Kernel::Auto).expect("srsi_into_slice");
    }

    assert_eq!(base.k.len(), n);
    assert_eq!(base.d.len(), n);
    assert_eq!(out_k.len(), n);
    assert_eq!(out_d.len(), n);

    // NaN never compares equal to itself, so matching NaNs are accepted explicitly.
    let same = |a: f64, b: f64| a == b || (a.is_nan() && b.is_nan());

    // All four series were just asserted to have length `n`, so zipping visits every index.
    let k_pairs = base.k.iter().zip(out_k.iter());
    let d_pairs = base.d.iter().zip(out_d.iter());
    for (i, ((&base_k, &got_k), (&base_d, &got_d))) in k_pairs.zip(d_pairs).enumerate() {
        assert!(
            same(base_k, got_k),
            "SRSI K mismatch at {i}: {} vs {}",
            base_k,
            got_k
        );
        assert!(
            same(base_d, got_d),
            "SRSI D mismatch at {i}: {} vs {}",
            base_d,
            got_d
        );
    }
}
2874
2875 fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2876 skip_if_unsupported!(kernel, test);
2877 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2878 let c = read_candles_from_csv(file)?;
2879 let output = SrsiBatchBuilder::new()
2880 .kernel(kernel)
2881 .apply_slice(&c.close)?;
2882 let def = SrsiParams::default();
2883 let k_row = output.k_for(&def).expect("default k row missing");
2884 let d_row = output.d_for(&def).expect("default d row missing");
2885 assert_eq!(k_row.len(), c.close.len());
2886 assert_eq!(d_row.len(), c.close.len());
2887 Ok(())
2888 }
2889
2890 #[cfg(debug_assertions)]
2891 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2892 skip_if_unsupported!(kernel, test);
2893
2894 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2895 let c = read_candles_from_csv(file)?;
2896
2897 let test_configs = vec![
2898 ((2, 10, 2), (2, 10, 2), (2, 4, 1), (2, 4, 1)),
2899 ((5, 25, 5), (5, 25, 5), (3, 7, 2), (3, 7, 2)),
2900 ((30, 60, 15), (30, 60, 15), (5, 10, 5), (5, 10, 5)),
2901 ((2, 5, 1), (2, 5, 1), (2, 3, 1), (2, 3, 1)),
2902 ((10, 30, 10), (5, 15, 5), (3, 6, 3), (3, 6, 3)),
2903 ((14, 14, 0), (14, 14, 0), (3, 3, 0), (3, 3, 0)),
2904 ((7, 21, 7), (14, 28, 14), (3, 9, 3), (3, 9, 3)),
2905 ];
2906
2907 for (cfg_idx, &(rsi_range, stoch_range, k_range, d_range)) in
2908 test_configs.iter().enumerate()
2909 {
2910 let output = SrsiBatchBuilder::new()
2911 .kernel(kernel)
2912 .rsi_period_range(rsi_range.0, rsi_range.1, rsi_range.2)
2913 .stoch_period_range(stoch_range.0, stoch_range.1, stoch_range.2)
2914 .k_range(k_range.0, k_range.1, k_range.2)
2915 .d_range(d_range.0, d_range.1, d_range.2)
2916 .apply_slice(&c.close)?;
2917
2918 for (idx, &val) in output.k.iter().enumerate() {
2919 if val.is_nan() {
2920 continue;
2921 }
2922
2923 let bits = val.to_bits();
2924 let row = idx / output.cols;
2925 let col = idx % output.cols;
2926 let combo = &output.combos[row];
2927
2928 if bits == 0x11111111_11111111 {
2929 panic!(
2930 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) in K output \
2931 at row {} col {} (flat index {}) with params: rsi_period={}, stoch_period={}, k={}, d={}",
2932 test, cfg_idx, val, bits, row, col, idx,
2933 combo.rsi_period.unwrap_or(14),
2934 combo.stoch_period.unwrap_or(14),
2935 combo.k.unwrap_or(3),
2936 combo.d.unwrap_or(3)
2937 );
2938 }
2939
2940 if bits == 0x22222222_22222222 {
2941 panic!(
2942 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) in K output \
2943 at row {} col {} (flat index {}) with params: rsi_period={}, stoch_period={}, k={}, d={}",
2944 test, cfg_idx, val, bits, row, col, idx,
2945 combo.rsi_period.unwrap_or(14),
2946 combo.stoch_period.unwrap_or(14),
2947 combo.k.unwrap_or(3),
2948 combo.d.unwrap_or(3)
2949 );
2950 }
2951
2952 if bits == 0x33333333_33333333 {
2953 panic!(
2954 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) in K output \
2955 at row {} col {} (flat index {}) with params: rsi_period={}, stoch_period={}, k={}, d={}",
2956 test, cfg_idx, val, bits, row, col, idx,
2957 combo.rsi_period.unwrap_or(14),
2958 combo.stoch_period.unwrap_or(14),
2959 combo.k.unwrap_or(3),
2960 combo.d.unwrap_or(3)
2961 );
2962 }
2963 }
2964
2965 for (idx, &val) in output.d.iter().enumerate() {
2966 if val.is_nan() {
2967 continue;
2968 }
2969
2970 let bits = val.to_bits();
2971 let row = idx / output.cols;
2972 let col = idx % output.cols;
2973 let combo = &output.combos[row];
2974
2975 if bits == 0x11111111_11111111 {
2976 panic!(
2977 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) in D output \
2978 at row {} col {} (flat index {}) with params: rsi_period={}, stoch_period={}, k={}, d={}",
2979 test, cfg_idx, val, bits, row, col, idx,
2980 combo.rsi_period.unwrap_or(14),
2981 combo.stoch_period.unwrap_or(14),
2982 combo.k.unwrap_or(3),
2983 combo.d.unwrap_or(3)
2984 );
2985 }
2986
2987 if bits == 0x22222222_22222222 {
2988 panic!(
2989 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) in D output \
2990 at row {} col {} (flat index {}) with params: rsi_period={}, stoch_period={}, k={}, d={}",
2991 test, cfg_idx, val, bits, row, col, idx,
2992 combo.rsi_period.unwrap_or(14),
2993 combo.stoch_period.unwrap_or(14),
2994 combo.k.unwrap_or(3),
2995 combo.d.unwrap_or(3)
2996 );
2997 }
2998
2999 if bits == 0x33333333_33333333 {
3000 panic!(
3001 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) in D output \
3002 at row {} col {} (flat index {}) with params: rsi_period={}, stoch_period={}, k={}, d={}",
3003 test, cfg_idx, val, bits, row, col, idx,
3004 combo.rsi_period.unwrap_or(14),
3005 combo.stoch_period.unwrap_or(14),
3006 combo.k.unwrap_or(3),
3007 combo.d.unwrap_or(3)
3008 );
3009 }
3010 }
3011 }
3012
3013 Ok(())
3014 }
3015
/// Stand-in used when `debug_assertions` is off: the batch poison scan is compiled
/// only in debug builds, so this variant ignores its arguments and always succeeds.
#[cfg(not(debug_assertions))]
fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
    Ok(())
}
3020
/// Generates `#[test]` wrappers that run a batch check function once per batch kernel.
///
/// Emits `<name>_scalar` (`Kernel::ScalarBatch`) and `<name>_auto_detect`
/// (`Kernel::Auto`) unconditionally, plus `<name>_avx2` and `<name>_avx512` when the
/// `nightly-avx` feature is enabled on `x86_64`. Each wrapper passes its own name to
/// the check and returns the check's `Result`, so an `Err` fails the test instead of
/// being discarded.
macro_rules! gen_batch_tests {
    ($fn_name:ident) => {
        paste::paste! {
            #[test]
            fn [<$fn_name _scalar>]() -> Result<(), Box<dyn std::error::Error>> {
                $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch)
            }

            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            #[test]
            fn [<$fn_name _avx2>]() -> Result<(), Box<dyn std::error::Error>> {
                $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch)
            }

            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            #[test]
            fn [<$fn_name _avx512>]() -> Result<(), Box<dyn std::error::Error>> {
                $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch)
            }

            #[test]
            fn [<$fn_name _auto_detect>]() -> Result<(), Box<dyn std::error::Error>> {
                $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto)
            }
        }
    };
}
// Instantiate the batch-kernel test wrappers for both batch checks.
gen_batch_tests!(check_batch_default_row);
gen_batch_tests!(check_batch_no_poison);
3043}