1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::PyDict;
9
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use serde::{Deserialize, Serialize};
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use serde_wasm_bindgen;
14#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
15use wasm_bindgen::prelude::*;
16
17use crate::utilities::data_loader::{source_type, Candles};
18use crate::utilities::enums::Kernel;
19use crate::utilities::helpers::{
20 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
21 make_uninit_matrix,
22};
23#[cfg(feature = "python")]
24use crate::utilities::kernel_validation::validate_kernel;
25use aligned_vec::{AVec, CACHELINE_ALIGN};
26#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
27use core::arch::x86_64::*;
28#[cfg(not(target_arch = "wasm32"))]
29use rayon::prelude::*;
30use std::error::Error;
31use thiserror::Error;
32
33use crate::indicators::moving_averages::ma::{ma, MaData};
34
/// The two price series RSMK compares: a "main" series and a "compare" series.
///
/// `rsmk_with_kernel` rejects inputs whose two series differ in length.
#[derive(Debug, Clone)]
pub enum RsmkData<'a> {
    /// Two candle sets; the same `source` field (e.g. `"close"`) is read from each.
    Candles {
        candles: &'a Candles,
        candles_compare: &'a Candles,
        source: &'a str,
    },
    /// Two raw price slices.
    Slices {
        main: &'a [f64],
        compare: &'a [f64],
    },
}
47
/// Result of an RSMK computation.
///
/// In the built-in EMA/SMA kernels both vectors have the input length and
/// start with a NaN warmup prefix; other MA types are produced by the generic
/// `ma` dispatcher.
#[derive(Debug, Clone)]
pub struct RsmkOutput {
    /// MA of the log-ratio momentum over `period`, scaled by 100.
    pub indicator: Vec<f64>,
    /// MA of `indicator` over `signal_period`.
    pub signal: Vec<f64>,
}
53
/// RSMK parameters; `None` fields fall back to the defaults shown per field.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct RsmkParams {
    /// Bars between the two log-ratio values differenced into momentum (default 90).
    pub lookback: Option<usize>,
    /// Length of the MA applied to momentum to form the indicator (default 3).
    pub period: Option<usize>,
    /// Length of the MA applied to the indicator to form the signal (default 20).
    pub signal_period: Option<usize>,
    /// MA type name for the indicator, matched case-insensitively (default `"ema"`).
    pub matype: Option<String>,
    /// MA type name for the signal, matched case-insensitively (default `"ema"`).
    pub signal_matype: Option<String>,
}
66
67impl Default for RsmkParams {
68 fn default() -> Self {
69 Self {
70 lookback: Some(90),
71 period: Some(3),
72 signal_period: Some(20),
73 matype: Some("ema".to_string()),
74 signal_matype: Some("ema".to_string()),
75 }
76 }
77}
78
/// Complete RSMK input: the two price series plus the parameters to use.
#[derive(Debug, Clone)]
pub struct RsmkInput<'a> {
    /// The main and comparison series.
    pub data: RsmkData<'a>,
    /// Window lengths and MA types; unset fields resolve via the `get_*` accessors.
    pub params: RsmkParams,
}
84
85impl<'a> RsmkInput<'a> {
86 #[inline]
87 pub fn from_candles(
88 candles: &'a Candles,
89 candles_compare: &'a Candles,
90 source: &'a str,
91 params: RsmkParams,
92 ) -> Self {
93 Self {
94 data: RsmkData::Candles {
95 candles,
96 candles_compare,
97 source,
98 },
99 params,
100 }
101 }
102
103 #[inline]
104 pub fn from_slices(main: &'a [f64], compare: &'a [f64], params: RsmkParams) -> Self {
105 Self {
106 data: RsmkData::Slices { main, compare },
107 params,
108 }
109 }
110
111 #[inline]
112 pub fn with_default_candles(candles: &'a Candles, candles_compare: &'a Candles) -> Self {
113 Self::from_candles(candles, candles_compare, "close", RsmkParams::default())
114 }
115
116 #[inline]
117 pub fn get_lookback(&self) -> usize {
118 self.params.lookback.unwrap_or(90)
119 }
120 #[inline]
121 pub fn get_period(&self) -> usize {
122 self.params.period.unwrap_or(3)
123 }
124 #[inline]
125 pub fn get_signal_period(&self) -> usize {
126 self.params.signal_period.unwrap_or(20)
127 }
128 #[inline]
129 pub fn get_ma_type(&self) -> &str {
130 self.params.matype.as_deref().unwrap_or("ema")
131 }
132 #[inline]
133 pub fn get_signal_ma_type(&self) -> &str {
134 self.params.signal_matype.as_deref().unwrap_or("ema")
135 }
136}
137
/// Fluent builder for RSMK parameters and kernel choice.
///
/// Unset parameters are left as `None` and resolved to their defaults when
/// the computation runs.
#[derive(Clone, Debug)]
pub struct RsmkBuilder {
    lookback: Option<usize>,
    period: Option<usize>,
    signal_period: Option<usize>,
    matype: Option<String>,
    signal_matype: Option<String>,
    // Kernel passed to `rsmk_with_kernel` by `apply` / `apply_slices`.
    kernel: Kernel,
}
147
148impl Default for RsmkBuilder {
149 fn default() -> Self {
150 Self {
151 lookback: None,
152 period: None,
153 signal_period: None,
154 matype: None,
155 signal_matype: None,
156 kernel: Kernel::Auto,
157 }
158 }
159}
160
161impl RsmkBuilder {
162 #[inline(always)]
163 pub fn new() -> Self {
164 Self::default()
165 }
166 #[inline(always)]
167 pub fn lookback(mut self, n: usize) -> Self {
168 self.lookback = Some(n);
169 self
170 }
171 #[inline(always)]
172 pub fn period(mut self, n: usize) -> Self {
173 self.period = Some(n);
174 self
175 }
176 #[inline(always)]
177 pub fn signal_period(mut self, n: usize) -> Self {
178 self.signal_period = Some(n);
179 self
180 }
181 #[inline(always)]
182 pub fn matype<S: Into<String>>(mut self, s: S) -> Self {
183 self.matype = Some(s.into());
184 self
185 }
186 #[inline(always)]
187 pub fn signal_matype<S: Into<String>>(mut self, s: S) -> Self {
188 self.signal_matype = Some(s.into());
189 self
190 }
191 #[inline(always)]
192 pub fn kernel(mut self, k: Kernel) -> Self {
193 self.kernel = k;
194 self
195 }
196
197 #[inline(always)]
198 pub fn apply(
199 self,
200 candles: &Candles,
201 candles_compare: &Candles,
202 ) -> Result<RsmkOutput, RsmkError> {
203 let params = RsmkParams {
204 lookback: self.lookback,
205 period: self.period,
206 signal_period: self.signal_period,
207 matype: self.matype.clone(),
208 signal_matype: self.signal_matype.clone(),
209 };
210 let input = RsmkInput::from_candles(candles, candles_compare, "close", params);
211 rsmk_with_kernel(&input, self.kernel)
212 }
213
214 #[inline(always)]
215 pub fn apply_slices(self, main: &[f64], compare: &[f64]) -> Result<RsmkOutput, RsmkError> {
216 let params = RsmkParams {
217 lookback: self.lookback,
218 period: self.period,
219 signal_period: self.signal_period,
220 matype: self.matype.clone(),
221 signal_matype: self.signal_matype.clone(),
222 };
223 let input = RsmkInput::from_slices(main, compare, params);
224 rsmk_with_kernel(&input, self.kernel)
225 }
226
227 #[inline(always)]
228 pub fn into_stream(self) -> Result<RsmkStream, RsmkError> {
229 let params = RsmkParams {
230 lookback: self.lookback,
231 period: self.period,
232 signal_period: self.signal_period,
233 matype: self.matype,
234 signal_matype: self.signal_matype,
235 };
236 RsmkStream::try_new(params)
237 }
238}
239
/// Errors produced by the RSMK functions.
#[derive(Debug, Error)]
pub enum RsmkError {
    /// The main or comparison series is empty.
    #[error("rsmk: Input data slice is empty.")]
    EmptyInputData,
    /// A window parameter is zero or too large for the data. Also returned
    /// (with `period: 0`) when the two input series differ in length.
    #[error("rsmk: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },
    /// Too few bars remain after the first non-NaN log-ratio.
    #[error("rsmk: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// Every log-ratio value is NaN.
    #[error("rsmk: All values are NaN.")]
    AllValuesNaN,
    /// A caller-supplied output buffer does not match the input length.
    #[error("rsmk: Output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// An invalid parameter range. Not produced by the single-series
    /// functions; presumably used by batch/sweep APIs.
    #[error("rsmk: Invalid range: start = {start}, end = {end}, step = {step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },
    /// A kernel unsuitable for batch computation; presumably raised by batch APIs.
    #[error("rsmk: Invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(Kernel),
    /// Error text forwarded from the generic `ma` dispatcher.
    #[error("rsmk: Error from MA function: {0}")]
    MaError(String),
}
263
/// Surfaces RSMK errors to JavaScript as their display message.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
impl From<RsmkError> for JsValue {
    fn from(err: RsmkError) -> Self {
        let message = err.to_string();
        JsValue::from_str(message.as_str())
    }
}
270
271#[inline]
272pub fn rsmk(input: &RsmkInput) -> Result<RsmkOutput, RsmkError> {
273 rsmk_with_kernel(input, Kernel::Auto)
274}
275
276pub fn rsmk_with_kernel(input: &RsmkInput, kernel: Kernel) -> Result<RsmkOutput, RsmkError> {
277 let (main, compare) = match &input.data {
278 RsmkData::Candles {
279 candles,
280 candles_compare,
281 source,
282 } => (
283 source_type(candles, source),
284 source_type(candles_compare, source),
285 ),
286 RsmkData::Slices { main, compare } => (*main, *compare),
287 };
288 if main.is_empty() || compare.is_empty() {
289 return Err(RsmkError::EmptyInputData);
290 }
291 if main.len() != compare.len() {
292 return Err(RsmkError::InvalidPeriod {
293 period: 0,
294 data_len: main.len().min(compare.len()),
295 });
296 }
297
298 let lookback = input.get_lookback();
299 let period = input.get_period();
300 let signal_period = input.get_signal_period();
301 if lookback == 0
302 || period == 0
303 || signal_period == 0
304 || period > main.len()
305 || signal_period > main.len()
306 || lookback >= main.len()
307 {
308 return Err(RsmkError::InvalidPeriod {
309 period: lookback.max(period).max(signal_period),
310 data_len: main.len(),
311 });
312 }
313
314 let mut lr = Vec::with_capacity(main.len());
315 unsafe {
316 lr.set_len(main.len());
317 }
318 for i in 0..main.len() {
319 let m = main[i];
320 let c = compare[i];
321
322 unsafe {
323 *lr.get_unchecked_mut(i) = if m.is_nan() || c.is_nan() || c == 0.0 {
324 f64::NAN
325 } else {
326 (m / c).ln()
327 };
328 }
329 }
330
331 let first_valid = lr
332 .iter()
333 .position(|&x| !x.is_nan())
334 .ok_or(RsmkError::AllValuesNaN)?;
335
336 let needed = lookback + period.max(signal_period);
337 if lr.len() - first_valid < needed {
338 return Err(RsmkError::NotEnoughValidData {
339 needed,
340 valid: lr.len() - first_valid,
341 });
342 }
343
344 let mut mom = alloc_with_nan_prefix(lr.len(), first_valid + lookback);
345 let ksel = match kernel {
346 Kernel::Auto => detect_best_kernel(),
347 k if k.is_batch() => match k {
348 Kernel::Avx512Batch => Kernel::Avx512,
349 Kernel::Avx2Batch => Kernel::Avx2,
350 Kernel::ScalarBatch => Kernel::Scalar,
351 _ => Kernel::Scalar,
352 },
353 k => k,
354 };
355 return match ksel {
356 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
357 Kernel::Avx512 => unsafe {
358 rsmk_avx512(
359 &lr,
360 lookback,
361 period,
362 signal_period,
363 input,
364 first_valid,
365 &mut mom,
366 )
367 },
368 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
369 Kernel::Avx2 => unsafe {
370 rsmk_avx2(
371 &lr,
372 lookback,
373 period,
374 signal_period,
375 input,
376 first_valid,
377 &mut mom,
378 )
379 },
380 _ => rsmk_scalar(
381 &lr,
382 lookback,
383 period,
384 signal_period,
385 input,
386 first_valid,
387 &mut mom,
388 ),
389 };
390}
391
392pub fn rsmk_scalar(
393 lr: &[f64],
394 lookback: usize,
395 period: usize,
396 signal_period: usize,
397 input: &RsmkInput,
398 first_valid: usize,
399 mom: &mut [f64],
400) -> Result<RsmkOutput, RsmkError> {
401 let len = lr.len();
402 let mom_fv = first_valid + lookback;
403
404 unsafe {
405 for i in mom_fv..len {
406 let a = *lr.get_unchecked(i);
407 let b = *lr.get_unchecked(i - lookback);
408 *mom.get_unchecked_mut(i) = if a.is_nan() || b.is_nan() {
409 f64::NAN
410 } else {
411 a - b
412 };
413 }
414 }
415
416 #[inline(always)]
417 fn is_ema(s: &str) -> bool {
418 s.eq_ignore_ascii_case("ema")
419 }
420 #[inline(always)]
421 fn is_sma(s: &str) -> bool {
422 s.eq_ignore_ascii_case("sma")
423 }
424
425 let matype = input.get_ma_type();
426 let sigtype = input.get_signal_ma_type();
427
428 let ind_warmup = mom_fv.saturating_add(period.saturating_sub(1));
429 let sig_warmup = ind_warmup.saturating_add(signal_period.saturating_sub(1));
430
431 if is_ema(matype) && is_ema(sigtype) {
432 let mut indicator = alloc_with_nan_prefix(len, ind_warmup);
433 let mut signal = alloc_with_nan_prefix(len, sig_warmup);
434
435 if ind_warmup < len {
436 let mut sum = 0.0;
437 let mut cnt = 0usize;
438 let init_end = (mom_fv + period).min(len);
439 unsafe {
440 for i in mom_fv..init_end {
441 let v = *mom.get_unchecked(i);
442 if !v.is_nan() {
443 sum += v;
444 cnt += 1;
445 }
446 }
447 }
448
449 if cnt > 0 {
450 let alpha_ind = 2.0 / (period as f64 + 1.0);
451 let alpha_sig = 2.0 / (signal_period as f64 + 1.0);
452
453 let mut ema_ind = (sum / cnt as f64) * 100.0;
454 unsafe {
455 *indicator.get_unchecked_mut(ind_warmup) = ema_ind;
456 }
457
458 let mut ema_sig = 0.0f64;
459 let mut acc_sig = ema_ind;
460 let mut cnt_sig = 1usize;
461
462 if sig_warmup == ind_warmup {
463 ema_sig = acc_sig / (cnt_sig as f64);
464 unsafe {
465 *signal.get_unchecked_mut(sig_warmup) = ema_sig;
466 }
467 }
468
469 unsafe {
470 for i in (ind_warmup + 1)..len {
471 let mv = *mom.get_unchecked(i);
472 if !mv.is_nan() {
473 let src100 = mv * 100.0;
474
475 ema_ind = (src100 - ema_ind).mul_add(alpha_ind, ema_ind);
476 }
477 *indicator.get_unchecked_mut(i) = ema_ind;
478
479 if i < sig_warmup {
480 acc_sig += ema_ind;
481 cnt_sig += 1;
482 } else if i == sig_warmup {
483 ema_sig = acc_sig / (cnt_sig as f64);
484 *signal.get_unchecked_mut(i) = ema_sig;
485 } else {
486 ema_sig = (ema_ind - ema_sig).mul_add(alpha_sig, ema_sig);
487 *signal.get_unchecked_mut(i) = ema_sig;
488 }
489 }
490 }
491 } else {
492 for i in ind_warmup..len {
493 indicator[i] = f64::NAN;
494 }
495 for i in sig_warmup..len {
496 signal[i] = f64::NAN;
497 }
498 }
499 }
500
501 return Ok(RsmkOutput { indicator, signal });
502 }
503
504 if is_sma(matype) && is_sma(sigtype) {
505 let mut indicator = alloc_with_nan_prefix(len, ind_warmup);
506 let mut signal = alloc_with_nan_prefix(len, sig_warmup);
507
508 let mut sum_ind = 0.0;
509 let mut cnt_ind = 0usize;
510
511 let mut sum_sig = 0.0;
512 let mut cnt_sig = 0usize;
513
514 unsafe {
515 for i in mom_fv..len {
516 let v_new = *mom.get_unchecked(i);
517 if !v_new.is_nan() {
518 sum_ind += v_new;
519 cnt_ind += 1;
520 }
521
522 if i >= mom_fv + period {
523 let v_old = *mom.get_unchecked(i - period);
524 if !v_old.is_nan() {
525 sum_ind -= v_old;
526 cnt_ind -= 1;
527 }
528 }
529
530 if i >= ind_warmup {
531 let ind_val = if cnt_ind > 0 {
532 (sum_ind / cnt_ind as f64) * 100.0
533 } else {
534 f64::NAN
535 };
536 *indicator.get_unchecked_mut(i) = ind_val;
537
538 if !ind_val.is_nan() {
539 sum_sig += ind_val;
540 cnt_sig += 1;
541 }
542
543 if i >= sig_warmup {
544 let old_idx = i - signal_period;
545 let old_ind = *indicator.get_unchecked(old_idx);
546 if !old_ind.is_nan() {
547 sum_sig -= old_ind;
548 cnt_sig -= 1;
549 }
550
551 *signal.get_unchecked_mut(i) = if cnt_sig > 0 {
552 sum_sig / cnt_sig as f64
553 } else {
554 f64::NAN
555 };
556 }
557 }
558 }
559 }
560
561 return Ok(RsmkOutput { indicator, signal });
562 }
563
564 if is_ema(matype) && is_sma(sigtype) {
565 let mut indicator = alloc_with_nan_prefix(len, ind_warmup);
566 let mut signal = alloc_with_nan_prefix(len, sig_warmup);
567
568 if ind_warmup < len {
569 let mut sum = 0.0;
570 let mut cnt = 0usize;
571 let init_end = (mom_fv + period).min(len);
572 unsafe {
573 for i in mom_fv..init_end {
574 let v = *mom.get_unchecked(i);
575 if !v.is_nan() {
576 sum += v;
577 cnt += 1;
578 }
579 }
580 }
581
582 if cnt > 0 {
583 let alpha_ind = 2.0 / (period as f64 + 1.0);
584 let mut ema_ind = (sum / cnt as f64) * 100.0;
585
586 let mut sum_sig = 0.0;
587 let mut cnt_sig = 0usize;
588
589 unsafe {
590 *indicator.get_unchecked_mut(ind_warmup) = ema_ind;
591
592 sum_sig += ema_ind;
593 cnt_sig += 1;
594
595 if sig_warmup == ind_warmup {
596 *signal.get_unchecked_mut(sig_warmup) = sum_sig / cnt_sig as f64;
597 }
598
599 for i in (ind_warmup + 1)..len {
600 let mv = *mom.get_unchecked(i);
601 if !mv.is_nan() {
602 let src100 = mv * 100.0;
603 ema_ind = (src100 - ema_ind).mul_add(alpha_ind, ema_ind);
604 }
605 *indicator.get_unchecked_mut(i) = ema_ind;
606
607 if !ema_ind.is_nan() {
608 sum_sig += ema_ind;
609 cnt_sig += 1;
610 }
611
612 if i >= sig_warmup {
613 let old_idx = i - signal_period;
614 let old_ind = *indicator.get_unchecked(old_idx);
615 if !old_ind.is_nan() {
616 sum_sig -= old_ind;
617 cnt_sig -= 1;
618 }
619
620 *signal.get_unchecked_mut(i) = if cnt_sig > 0 {
621 sum_sig / cnt_sig as f64
622 } else {
623 f64::NAN
624 };
625 }
626 }
627 }
628 } else {
629 for i in ind_warmup..len {
630 indicator[i] = f64::NAN;
631 }
632 for i in sig_warmup..len {
633 signal[i] = f64::NAN;
634 }
635 }
636 }
637
638 return Ok(RsmkOutput { indicator, signal });
639 }
640
641 if is_sma(matype) && is_ema(sigtype) {
642 let mut indicator = alloc_with_nan_prefix(len, ind_warmup);
643 let mut signal = alloc_with_nan_prefix(len, sig_warmup);
644
645 let mut sum_ind = 0.0;
646 let mut cnt_ind = 0usize;
647
648 let alpha_sig = 2.0 / (signal_period as f64 + 1.0);
649 let mut acc_sig = 0.0;
650 let mut cnt_sig = 0usize;
651 let mut ema_sig = 0.0f64;
652
653 unsafe {
654 for i in mom_fv..len {
655 let v_new = *mom.get_unchecked(i);
656 if !v_new.is_nan() {
657 sum_ind += v_new;
658 cnt_ind += 1;
659 }
660
661 if i >= mom_fv + period {
662 let v_old = *mom.get_unchecked(i - period);
663 if !v_old.is_nan() {
664 sum_ind -= v_old;
665 cnt_ind -= 1;
666 }
667 }
668
669 if i >= ind_warmup {
670 let ind_val = if cnt_ind > 0 {
671 (sum_ind / cnt_ind as f64) * 100.0
672 } else {
673 f64::NAN
674 };
675 *indicator.get_unchecked_mut(i) = ind_val;
676
677 if i < sig_warmup {
678 if !ind_val.is_nan() {
679 acc_sig += ind_val;
680 cnt_sig += 1;
681 }
682 } else if i == sig_warmup {
683 ema_sig = if cnt_sig > 0 {
684 acc_sig / cnt_sig as f64
685 } else {
686 f64::NAN
687 };
688 *signal.get_unchecked_mut(i) = ema_sig;
689 } else {
690 if !ind_val.is_nan() && !ema_sig.is_nan() {
691 ema_sig = (ind_val - ema_sig).mul_add(alpha_sig, ema_sig);
692 } else if !ind_val.is_nan() && ema_sig.is_nan() {
693 ema_sig = ind_val;
694 }
695 *signal.get_unchecked_mut(i) = ema_sig;
696 }
697 }
698 }
699 }
700
701 return Ok(RsmkOutput { indicator, signal });
702 }
703
704 let matype = input.get_ma_type();
705 let sigmatype = input.get_signal_ma_type();
706
707 let mut indicator =
708 ma(matype, MaData::Slice(mom), period).map_err(|e| RsmkError::MaError(e.to_string()))?;
709 for v in &mut indicator {
710 *v *= 100.0;
711 }
712
713 let signal = ma(sigmatype, MaData::Slice(&indicator), signal_period)
714 .map_err(|e| RsmkError::MaError(e.to_string()))?;
715
716 Ok(RsmkOutput { indicator, signal })
717}
718
719#[inline]
720pub fn rsmk_into_slice(
721 dst_indicator: &mut [f64],
722 dst_signal: &mut [f64],
723 input: &RsmkInput,
724 _kern: Kernel,
725) -> Result<(), RsmkError> {
726 let (main, compare) = match &input.data {
727 RsmkData::Candles {
728 candles,
729 candles_compare,
730 source,
731 } => (
732 source_type(candles, source),
733 source_type(candles_compare, source),
734 ),
735 RsmkData::Slices { main, compare } => (*main, *compare),
736 };
737 if main.len() == 0 || compare.len() == 0 {
738 return Err(RsmkError::EmptyInputData);
739 }
740 if main.len() != compare.len() {
741 return Err(RsmkError::InvalidPeriod {
742 period: 0,
743 data_len: main.len(),
744 });
745 }
746 if dst_indicator.len() != main.len() {
747 return Err(RsmkError::OutputLengthMismatch {
748 expected: main.len(),
749 got: dst_indicator.len(),
750 });
751 }
752 if dst_signal.len() != main.len() {
753 return Err(RsmkError::OutputLengthMismatch {
754 expected: main.len(),
755 got: dst_signal.len(),
756 });
757 }
758
759 let p = &input.params;
760 let lookback = p.lookback.unwrap_or(90);
761 let period = p.period.unwrap_or(3);
762 let signal_period = p.signal_period.unwrap_or(20);
763 if lookback == 0 || period == 0 || signal_period == 0 {
764 return Err(RsmkError::InvalidPeriod {
765 period: 0,
766 data_len: main.len(),
767 });
768 }
769
770 let mut lr = Vec::with_capacity(main.len());
771 unsafe {
772 lr.set_len(main.len());
773 }
774 for i in 0..main.len() {
775 let m = main[i];
776 let c = compare[i];
777 unsafe {
778 *lr.get_unchecked_mut(i) = if m.is_nan() || c.is_nan() || c == 0.0 {
779 f64::NAN
780 } else {
781 (m / c).ln()
782 };
783 }
784 }
785 let first = lr
786 .iter()
787 .position(|x| !x.is_nan())
788 .ok_or(RsmkError::AllValuesNaN)?;
789
790 let mut mom = alloc_with_nan_prefix(lr.len(), first + lookback);
791 for i in (first + lookback)..lr.len() {
792 let a = lr[i];
793 let b = lr[i - lookback];
794 mom[i] = if a.is_nan() || b.is_nan() {
795 f64::NAN
796 } else {
797 a - b
798 };
799 }
800
801 let matype = p.matype.as_deref().unwrap_or("ema");
802 let mut ind =
803 ma(matype, MaData::Slice(&mom), period).map_err(|e| RsmkError::MaError(e.to_string()))?;
804 for v in &mut ind {
805 *v *= 100.0;
806 }
807
808 let sigtype = p.signal_matype.as_deref().unwrap_or("ema");
809 let sig = ma(sigtype, MaData::Slice(&ind), signal_period)
810 .map_err(|e| RsmkError::MaError(e.to_string()))?;
811
812 dst_indicator.copy_from_slice(&ind);
813 dst_signal.copy_from_slice(&sig);
814
815 Ok(())
816}
817
818#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
819#[inline]
820pub fn rsmk_into(
821 input: &RsmkInput,
822 indicator_out: &mut [f64],
823 signal_out: &mut [f64],
824) -> Result<(), RsmkError> {
825 let (main, compare) = match &input.data {
826 RsmkData::Candles {
827 candles,
828 candles_compare,
829 source,
830 } => (
831 source_type(candles, source),
832 source_type(candles_compare, source),
833 ),
834 RsmkData::Slices { main, compare } => (*main, *compare),
835 };
836 let len = main.len();
837 if len == 0 || compare.len() == 0 {
838 return Err(RsmkError::EmptyInputData);
839 }
840 if main.len() != compare.len() {
841 return Err(RsmkError::InvalidPeriod {
842 period: 0,
843 data_len: len,
844 });
845 }
846 if indicator_out.len() != len {
847 return Err(RsmkError::OutputLengthMismatch {
848 expected: len,
849 got: indicator_out.len(),
850 });
851 }
852 if signal_out.len() != len {
853 return Err(RsmkError::OutputLengthMismatch {
854 expected: len,
855 got: signal_out.len(),
856 });
857 }
858
859 let lookback = input.get_lookback();
860 let period = input.get_period();
861 let signal_period = input.get_signal_period();
862 if lookback == 0
863 || period == 0
864 || signal_period == 0
865 || period > len
866 || signal_period > len
867 || lookback >= len
868 {
869 return Err(RsmkError::InvalidPeriod {
870 period: lookback.max(period).max(signal_period),
871 data_len: len,
872 });
873 }
874
875 let mut lr = Vec::with_capacity(len);
876 unsafe {
877 lr.set_len(len);
878 }
879 for i in 0..len {
880 let m = main[i];
881 let c = compare[i];
882 unsafe {
883 *lr.get_unchecked_mut(i) = if m.is_nan() || c.is_nan() || c == 0.0 {
884 f64::NAN
885 } else {
886 (m / c).ln()
887 };
888 }
889 }
890 let first_valid = lr
891 .iter()
892 .position(|x| !x.is_nan())
893 .ok_or(RsmkError::AllValuesNaN)?;
894
895 let mut mom = alloc_with_nan_prefix(len, first_valid + lookback);
896 let mom_fv = first_valid + lookback;
897 unsafe {
898 for i in mom_fv..len {
899 let a = *lr.get_unchecked(i);
900 let b = *lr.get_unchecked(i - lookback);
901 *mom.get_unchecked_mut(i) = if a.is_nan() || b.is_nan() {
902 f64::NAN
903 } else {
904 a - b
905 };
906 }
907 }
908
909 #[inline(always)]
910 fn is_ema(s: &str) -> bool {
911 s.eq_ignore_ascii_case("ema")
912 }
913 #[inline(always)]
914 fn is_sma(s: &str) -> bool {
915 s.eq_ignore_ascii_case("sma")
916 }
917
918 let matype = input.get_ma_type();
919 let sigtype = input.get_signal_ma_type();
920
921 let ind_warmup = mom_fv.saturating_add(period.saturating_sub(1));
922 let sig_warmup = ind_warmup.saturating_add(signal_period.saturating_sub(1));
923
924 for i in 0..ind_warmup.min(len) {
925 indicator_out[i] = f64::NAN;
926 }
927 for i in 0..sig_warmup.min(len) {
928 signal_out[i] = f64::NAN;
929 }
930
931 if is_ema(matype) && is_ema(sigtype) {
932 if ind_warmup < len {
933 let mut sum = 0.0;
934 let mut cnt = 0usize;
935 let init_end = (mom_fv + period).min(len);
936 unsafe {
937 for i in mom_fv..init_end {
938 let v = *mom.get_unchecked(i);
939 if !v.is_nan() {
940 sum += v;
941 cnt += 1;
942 }
943 }
944 }
945 if cnt > 0 {
946 let alpha_ind = 2.0 / (period as f64 + 1.0);
947 let alpha_sig = 2.0 / (signal_period as f64 + 1.0);
948
949 let mut ema_ind = (sum / cnt as f64) * 100.0;
950 indicator_out[ind_warmup] = ema_ind;
951
952 let mut ema_sig = 0.0f64;
953 let mut acc_sig = ema_ind;
954 let mut cnt_sig = 1usize;
955 if sig_warmup == ind_warmup {
956 ema_sig = acc_sig / (cnt_sig as f64);
957 signal_out[sig_warmup] = ema_sig;
958 }
959
960 unsafe {
961 for i in (ind_warmup + 1)..len {
962 let mv = *mom.get_unchecked(i);
963 if mv.is_finite() {
964 let src100 = mv * 100.0;
965 ema_ind = (src100 - ema_ind).mul_add(alpha_ind, ema_ind);
966 }
967 *indicator_out.get_unchecked_mut(i) = ema_ind;
968
969 if i < sig_warmup {
970 acc_sig += ema_ind;
971 cnt_sig += 1;
972 } else if i == sig_warmup {
973 ema_sig = acc_sig / (cnt_sig as f64);
974 *signal_out.get_unchecked_mut(i) = ema_sig;
975 } else {
976 ema_sig = (ema_ind - ema_sig).mul_add(alpha_sig, ema_sig);
977 *signal_out.get_unchecked_mut(i) = ema_sig;
978 }
979 }
980 }
981 } else {
982 for i in ind_warmup..len {
983 indicator_out[i] = f64::NAN;
984 }
985 for i in sig_warmup..len {
986 signal_out[i] = f64::NAN;
987 }
988 }
989 }
990 return Ok(());
991 }
992
993 if is_sma(matype) && is_sma(sigtype) {
994 let mut sum_ind = 0.0;
995 let mut cnt_ind = 0usize;
996 let mut sum_sig = 0.0;
997 let mut cnt_sig = 0usize;
998 unsafe {
999 for i in mom_fv..len {
1000 let v_new = *mom.get_unchecked(i);
1001 if !v_new.is_nan() {
1002 sum_ind += v_new;
1003 cnt_ind += 1;
1004 }
1005 if i >= mom_fv + period {
1006 let v_old = *mom.get_unchecked(i - period);
1007 if !v_old.is_nan() {
1008 sum_ind -= v_old;
1009 cnt_ind -= 1;
1010 }
1011 }
1012 if i >= ind_warmup {
1013 let ind_val = if cnt_ind > 0 {
1014 (sum_ind / cnt_ind as f64) * 100.0
1015 } else {
1016 f64::NAN
1017 };
1018 *indicator_out.get_unchecked_mut(i) = ind_val;
1019 if !ind_val.is_nan() {
1020 sum_sig += ind_val;
1021 cnt_sig += 1;
1022 }
1023 if i >= sig_warmup {
1024 let old_idx = i - signal_period;
1025 let old_ind = *indicator_out.get_unchecked(old_idx);
1026 if !old_ind.is_nan() {
1027 sum_sig -= old_ind;
1028 cnt_sig -= 1;
1029 }
1030 *signal_out.get_unchecked_mut(i) = if cnt_sig > 0 {
1031 sum_sig / cnt_sig as f64
1032 } else {
1033 f64::NAN
1034 };
1035 }
1036 }
1037 }
1038 }
1039 return Ok(());
1040 }
1041
1042 if is_ema(matype) && is_sma(sigtype) {
1043 if ind_warmup < len {
1044 let mut sum = 0.0;
1045 let mut cnt = 0usize;
1046 let init_end = (mom_fv + period).min(len);
1047 unsafe {
1048 for i in mom_fv..init_end {
1049 let v = *mom.get_unchecked(i);
1050 if !v.is_nan() {
1051 sum += v;
1052 cnt += 1;
1053 }
1054 }
1055 }
1056 if cnt > 0 {
1057 let alpha_ind = 2.0 / (period as f64 + 1.0);
1058 let mut ema_ind = (sum / cnt as f64) * 100.0;
1059 let mut sum_sig = 0.0;
1060 let mut cnt_sig = 0usize;
1061 unsafe {
1062 *indicator_out.get_unchecked_mut(ind_warmup) = ema_ind;
1063 sum_sig += ema_ind;
1064 cnt_sig += 1;
1065 if sig_warmup == ind_warmup {
1066 *signal_out.get_unchecked_mut(sig_warmup) = sum_sig / cnt_sig as f64;
1067 }
1068 for i in (ind_warmup + 1)..len {
1069 let mv = *mom.get_unchecked(i);
1070 if !mv.is_nan() {
1071 let src100 = mv * 100.0;
1072 ema_ind = (src100 - ema_ind).mul_add(alpha_ind, ema_ind);
1073 }
1074 *indicator_out.get_unchecked_mut(i) = ema_ind;
1075 if i < sig_warmup {
1076 sum_sig += ema_ind;
1077 cnt_sig += 1;
1078 } else if i == sig_warmup {
1079 *signal_out.get_unchecked_mut(i) = sum_sig / cnt_sig as f64;
1080 } else {
1081 let old_idx = i - signal_period;
1082 let old_ind = *indicator_out.get_unchecked(old_idx);
1083 if !old_ind.is_nan() {
1084 sum_sig -= old_ind;
1085 cnt_sig -= 1;
1086 }
1087 sum_sig += ema_ind;
1088 cnt_sig += 1;
1089 *signal_out.get_unchecked_mut(i) = if cnt_sig > 0 {
1090 sum_sig / cnt_sig as f64
1091 } else {
1092 f64::NAN
1093 };
1094 }
1095 }
1096 }
1097 } else {
1098 for i in ind_warmup..len {
1099 indicator_out[i] = f64::NAN;
1100 }
1101 for i in sig_warmup..len {
1102 signal_out[i] = f64::NAN;
1103 }
1104 }
1105 }
1106 return Ok(());
1107 }
1108
1109 if is_sma(matype) && is_ema(sigtype) {
1110 let mut sum_ind = 0.0;
1111 let mut cnt_ind = 0usize;
1112 let alpha_sig = 2.0 / (signal_period as f64 + 1.0);
1113 let mut ema_sig = 0.0f64;
1114 let mut seeded_sig = false;
1115 let mut acc_sig = 0.0f64;
1116 let mut cnt_sig = 0usize;
1117 unsafe {
1118 for i in mom_fv..len {
1119 let v_new = *mom.get_unchecked(i);
1120 if !v_new.is_nan() {
1121 sum_ind += v_new;
1122 cnt_ind += 1;
1123 }
1124 if i >= mom_fv + period {
1125 let v_old = *mom.get_unchecked(i - period);
1126 if !v_old.is_nan() {
1127 sum_ind -= v_old;
1128 cnt_ind -= 1;
1129 }
1130 }
1131 if i >= ind_warmup {
1132 let ind_val = if cnt_ind > 0 {
1133 (sum_ind / cnt_ind as f64) * 100.0
1134 } else {
1135 f64::NAN
1136 };
1137 *indicator_out.get_unchecked_mut(i) = ind_val;
1138 if !seeded_sig {
1139 acc_sig += ind_val;
1140 cnt_sig += 1;
1141 if i == sig_warmup {
1142 ema_sig = acc_sig / (cnt_sig as f64);
1143 seeded_sig = true;
1144 *signal_out.get_unchecked_mut(i) = ema_sig;
1145 }
1146 } else {
1147 if ind_val.is_finite() {
1148 ema_sig = (ind_val - ema_sig).mul_add(alpha_sig, ema_sig);
1149 }
1150 *signal_out.get_unchecked_mut(i) = ema_sig;
1151 }
1152 }
1153 }
1154 }
1155 return Ok(());
1156 }
1157
1158 let mut indicator =
1159 ma(matype, MaData::Slice(&mom), period).map_err(|e| RsmkError::MaError(e.to_string()))?;
1160 for v in &mut indicator {
1161 *v *= 100.0;
1162 }
1163 let signal = ma(sigtype, MaData::Slice(&indicator), signal_period)
1164 .map_err(|e| RsmkError::MaError(e.to_string()))?;
1165 indicator_out.copy_from_slice(&indicator);
1166 signal_out.copy_from_slice(&signal);
1167 Ok(())
1168}
1169
/// AVX2 entry point for the RSMK kernel.
///
/// There is no AVX2-specific implementation yet; this forwards unchanged to
/// [`rsmk_scalar`] (compiled here with `avx2` enabled).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx2")]
pub fn rsmk_avx2(
    lr: &[f64],
    lookback: usize,
    period: usize,
    signal_period: usize,
    input: &RsmkInput,
    first_valid: usize,
    mom: &mut [f64],
) -> Result<RsmkOutput, RsmkError> {
    rsmk_scalar(lr, lookback, period, signal_period, input, first_valid, mom)
}
1184
/// AVX-512 entry point for the RSMK kernel.
///
/// Periods up to 32 go to the short variant and longer ones to the long
/// variant. Both currently end up in [`rsmk_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f")]
pub fn rsmk_avx512(
    lr: &[f64],
    lookback: usize,
    period: usize,
    signal_period: usize,
    input: &RsmkInput,
    first_valid: usize,
    mom: &mut [f64],
) -> Result<RsmkOutput, RsmkError> {
    match period {
        // SAFETY: both variants only require `avx512f`, which this function
        // itself is compiled with.
        0..=32 => unsafe {
            rsmk_avx512_short(lr, lookback, period, signal_period, input, first_valid, mom)
        },
        _ => unsafe {
            rsmk_avx512_long(lr, lookback, period, signal_period, input, first_valid, mom)
        },
    }
}
1203
/// AVX-512 RSMK variant chosen by `rsmk_avx512` for `period <= 32`.
///
/// Currently forwards unchanged to [`rsmk_scalar`].
///
/// # Safety
/// Compiled with `avx512f` enabled; the caller must ensure the running CPU
/// supports AVX-512F.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn rsmk_avx512_short(
    lr: &[f64],
    lookback: usize,
    period: usize,
    signal_period: usize,
    input: &RsmkInput,
    first_valid: usize,
    mom: &mut [f64],
) -> Result<RsmkOutput, RsmkError> {
    rsmk_scalar(lr, lookback, period, signal_period, input, first_valid, mom)
}
1218
/// AVX-512 RSMK variant chosen by `rsmk_avx512` for `period > 32`.
///
/// Currently forwards to [`rsmk_avx512_short`].
///
/// # Safety
/// Compiled with `avx512f` enabled; the caller must ensure the running CPU
/// supports AVX-512F.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn rsmk_avx512_long(
    lr: &[f64],
    lookback: usize,
    period: usize,
    signal_period: usize,
    input: &RsmkInput,
    first_valid: usize,
    mom: &mut [f64],
) -> Result<RsmkOutput, RsmkError> {
    rsmk_avx512_short(lr, lookback, period, signal_period, input, first_valid, mom)
}
1233
/// Incremental (bar-by-bar) RSMK calculator, created via `RsmkStream::try_new`
/// or `RsmkBuilder::into_stream`.
///
/// Field notes describe what `try_new` initialises. The update logic is
/// defined separately, so roles marked "presumably" should be confirmed there.
#[derive(Debug, Clone)]
pub struct RsmkStream {
    // Resolved window lengths (defaults 90 / 3 / 20).
    lookback: usize,
    period: usize,
    signal_period: usize,

    // MA type names (default "ema") and cached case-insensitive "ema" checks.
    matype: String,
    signal_matype: String,
    main_is_ema: bool,
    signal_is_ema: bool,

    // Ring buffer with capacity `lookback`, initialised to NaN; presumably holds
    // recent log-ratio values for the momentum difference.
    lr_buf: Vec<f64>,
    lr_head: usize,
    lr_len: usize,

    // Presumably set once the first finite momentum value has been observed.
    saw_first_finite_mom: bool,

    // Presumably the seed accumulator and running state of the indicator EMA.
    ema_seed_pos: usize,
    ema_seed_sum: f64,
    ema_seed_cnt: usize,
    ema_ind: f64,
    ema_ind_seeded: bool,
    ema_ind_dead: bool,

    // Ring buffer with capacity `period` plus running sum/count; presumably the
    // SMA window for the indicator.
    ind_win_buf: Vec<f64>,
    ind_win_head: usize,
    ind_win_len: usize,
    ind_win_sum: f64,
    ind_win_cnt: usize,

    indicator_started: bool,

    // Signal EMA smoothing factor `2 / (signal_period + 1)` and presumably its
    // seed/running state.
    alpha_sig: f64,
    ema_sig: f64,
    ema_sig_seeded: bool,
    ema_sig_seed_pos: usize,
    ema_sig_seed_sum: f64,
    ema_sig_seed_cnt: usize,

    // Ring buffer with capacity `signal_period` plus running sum/count;
    // presumably the SMA window for the signal.
    sig_win_buf: Vec<f64>,
    sig_win_head: usize,
    sig_win_len: usize,
    sig_win_sum: f64,
    sig_win_cnt: usize,

    // Indicator EMA smoothing factor `2 / (period + 1)`.
    alpha_main: f64,
}
1281
1282impl RsmkStream {
1283 pub fn try_new(params: RsmkParams) -> Result<Self, RsmkError> {
1284 let lookback = params.lookback.unwrap_or(90);
1285 let period = params.period.unwrap_or(3);
1286 let signal_period = params.signal_period.unwrap_or(20);
1287 if lookback == 0 || period == 0 || signal_period == 0 {
1288 return Err(RsmkError::InvalidPeriod {
1289 period: lookback.max(period).max(signal_period),
1290 data_len: 0,
1291 });
1292 }
1293 let matype = params.matype.unwrap_or_else(|| "ema".to_string());
1294 let signal_matype = params.signal_matype.unwrap_or_else(|| "ema".to_string());
1295 let main_is_ema = matype.eq_ignore_ascii_case("ema");
1296 let signal_is_ema = signal_matype.eq_ignore_ascii_case("ema");
1297
1298 Ok(Self {
1299 lookback,
1300 period,
1301 signal_period,
1302 matype,
1303 signal_matype,
1304 main_is_ema,
1305 signal_is_ema,
1306
1307 lr_buf: vec![f64::NAN; lookback],
1308 lr_head: 0,
1309 lr_len: 0,
1310
1311 saw_first_finite_mom: false,
1312 ema_seed_pos: 0,
1313 ema_seed_sum: 0.0,
1314 ema_seed_cnt: 0,
1315 ema_ind: f64::NAN,
1316 ema_ind_seeded: false,
1317 ema_ind_dead: false,
1318
1319 ind_win_buf: vec![f64::NAN; period],
1320 ind_win_head: 0,
1321 ind_win_len: 0,
1322 ind_win_sum: 0.0,
1323 ind_win_cnt: 0,
1324
1325 indicator_started: false,
1326
1327 alpha_sig: 2.0 / (signal_period as f64 + 1.0),
1328 ema_sig: f64::NAN,
1329 ema_sig_seeded: false,
1330 ema_sig_seed_pos: 0,
1331 ema_sig_seed_sum: 0.0,
1332 ema_sig_seed_cnt: 0,
1333
1334 sig_win_buf: vec![f64::NAN; signal_period],
1335 sig_win_head: 0,
1336 sig_win_len: 0,
1337 sig_win_sum: 0.0,
1338 sig_win_cnt: 0,
1339
1340 alpha_main: 2.0 / (period as f64 + 1.0),
1341 })
1342 }
1343
1344 #[inline(always)]
1345 fn push_ring(buf: &mut [f64], head: &mut usize, len: &mut usize, v: f64) -> Option<f64> {
1346 let cap = buf.len();
1347 if cap == 0 {
1348 return None;
1349 }
1350 let evicted = if *len < cap {
1351 buf[*head] = v;
1352 *len += 1;
1353 None
1354 } else {
1355 let old = core::mem::replace(&mut buf[*head], v);
1356 Some(old)
1357 };
1358 *head += 1;
1359 if *head == cap {
1360 *head = 0;
1361 }
1362 evicted
1363 }
1364
1365 #[inline(always)]
1366 fn window_push(
1367 buf: &mut [f64],
1368 head: &mut usize,
1369 len: &mut usize,
1370 sum: &mut f64,
1371 cnt: &mut usize,
1372 v: f64,
1373 ) {
1374 let cap = buf.len();
1375 if cap == 0 {
1376 return;
1377 }
1378 if *len < cap {
1379 buf[*head] = v;
1380 if v.is_finite() {
1381 *sum += v;
1382 *cnt += 1;
1383 }
1384 *len += 1;
1385 } else {
1386 let old = core::mem::replace(&mut buf[*head], v);
1387 if old.is_finite() {
1388 *sum -= old;
1389 *cnt -= 1;
1390 }
1391 if v.is_finite() {
1392 *sum += v;
1393 *cnt += 1;
1394 }
1395 }
1396 *head += 1;
1397 if *head == cap {
1398 *head = 0;
1399 }
1400 }
1401
1402 pub fn update(&mut self, main: f64, compare: f64) -> Option<(f64, f64)> {
1403 let lr = if main.is_nan() || compare.is_nan() || compare == 0.0 {
1404 f64::NAN
1405 } else {
1406 (main / compare).ln()
1407 };
1408
1409 let evicted = Self::push_ring(&mut self.lr_buf, &mut self.lr_head, &mut self.lr_len, lr);
1410 let mom = match evicted {
1411 None => f64::NAN,
1412 Some(old_lr) => {
1413 if lr.is_nan() || old_lr.is_nan() {
1414 f64::NAN
1415 } else {
1416 lr - old_lr
1417 }
1418 }
1419 };
1420
1421 if !self.saw_first_finite_mom && mom.is_finite() {
1422 self.saw_first_finite_mom = true;
1423 }
1424
1425 let indicator = if self.main_is_ema {
1426 self.update_indicator_ema(mom)
1427 } else {
1428 self.update_indicator_sma(mom)
1429 };
1430
1431 let signal = if self.signal_is_ema {
1432 self.update_signal_ema(indicator)
1433 } else {
1434 self.update_signal_sma(indicator)
1435 };
1436
1437 Some((indicator, signal))
1438 }
1439
1440 #[inline(always)]
1441 fn update_indicator_ema(&mut self, mom: f64) -> f64 {
1442 if self.ema_ind_dead {
1443 return f64::NAN;
1444 }
1445 if !self.saw_first_finite_mom {
1446 return f64::NAN;
1447 }
1448 if !self.ema_ind_seeded {
1449 self.ema_seed_pos += 1;
1450 if mom.is_finite() {
1451 self.ema_seed_sum += mom;
1452 self.ema_seed_cnt += 1;
1453 }
1454 if self.ema_seed_pos < self.period {
1455 return f64::NAN;
1456 }
1457 if self.ema_seed_cnt == 0 {
1458 self.ema_ind_dead = true;
1459 self.indicator_started = true;
1460 return f64::NAN;
1461 }
1462 self.ema_ind = (self.ema_seed_sum / self.ema_seed_cnt as f64) * 100.0;
1463 self.ema_ind_seeded = true;
1464 self.indicator_started = true;
1465 return self.ema_ind;
1466 }
1467 if mom.is_finite() {
1468 let src100 = mom * 100.0;
1469 self.ema_ind = (src100 - self.ema_ind).mul_add(self.alpha_main, self.ema_ind);
1470 }
1471 self.ema_ind
1472 }
1473
1474 #[inline(always)]
1475 fn update_indicator_sma(&mut self, mom: f64) -> f64 {
1476 if !self.saw_first_finite_mom {
1477 return f64::NAN;
1478 }
1479 Self::window_push(
1480 &mut self.ind_win_buf,
1481 &mut self.ind_win_head,
1482 &mut self.ind_win_len,
1483 &mut self.ind_win_sum,
1484 &mut self.ind_win_cnt,
1485 mom,
1486 );
1487
1488 if self.ind_win_len < self.period {
1489 return f64::NAN;
1490 }
1491 let ind = if self.ind_win_cnt > 0 {
1492 (self.ind_win_sum / self.ind_win_cnt as f64) * 100.0
1493 } else {
1494 f64::NAN
1495 };
1496 if !self.indicator_started {
1497 self.indicator_started = true;
1498 }
1499 ind
1500 }
1501
1502 #[inline(always)]
1503 fn update_signal_ema(&mut self, indicator: f64) -> f64 {
1504 if self.ema_ind_dead {
1505 return f64::NAN;
1506 }
1507 if !self.indicator_started {
1508 return f64::NAN;
1509 }
1510 if !self.ema_sig_seeded {
1511 self.ema_sig_seed_pos += 1;
1512 if indicator.is_finite() {
1513 self.ema_sig_seed_sum += indicator;
1514 self.ema_sig_seed_cnt += 1;
1515 }
1516 if self.ema_sig_seed_pos < self.signal_period {
1517 return f64::NAN;
1518 }
1519 self.ema_sig = if self.ema_sig_seed_cnt > 0 {
1520 self.ema_sig_seed_sum / self.ema_sig_seed_cnt as f64
1521 } else {
1522 f64::NAN
1523 };
1524 self.ema_sig_seeded = true;
1525 return self.ema_sig;
1526 }
1527 if indicator.is_finite() && self.ema_sig.is_finite() {
1528 self.ema_sig = (indicator - self.ema_sig).mul_add(self.alpha_sig, self.ema_sig);
1529 } else if indicator.is_finite() && !self.ema_sig.is_finite() {
1530 if !self.main_is_ema {
1531 self.ema_sig = indicator;
1532 }
1533 }
1534 self.ema_sig
1535 }
1536
1537 #[inline(always)]
1538 fn update_signal_sma(&mut self, indicator: f64) -> f64 {
1539 if self.ema_ind_dead {
1540 return f64::NAN;
1541 }
1542 if !self.indicator_started {
1543 return f64::NAN;
1544 }
1545 Self::window_push(
1546 &mut self.sig_win_buf,
1547 &mut self.sig_win_head,
1548 &mut self.sig_win_len,
1549 &mut self.sig_win_sum,
1550 &mut self.sig_win_cnt,
1551 indicator,
1552 );
1553 if self.sig_win_len < self.signal_period {
1554 return f64::NAN;
1555 }
1556 if self.sig_win_cnt > 0 {
1557 self.sig_win_sum / self.sig_win_cnt as f64
1558 } else {
1559 f64::NAN
1560 }
1561 }
1562}
1563
/// Parameter sweep for batch RSMK.
///
/// Every field is an inclusive `(start, end, step)` range. A `step` of 0, or `start == end`,
/// pins that parameter to `start`.
#[derive(Debug, Clone)]
pub struct RsmkBatchRange {
    /// Lookback range: distance between the log ratios that are differenced into momentum.
    pub lookback: (usize, usize, usize),
    /// Indicator moving-average period range.
    pub period: (usize, usize, usize),
    /// Signal moving-average period range.
    pub signal_period: (usize, usize, usize),
}

impl Default for RsmkBatchRange {
    /// Sweeps lookback over 90..=339 in steps of 1, with period fixed at 3 and signal period at 20.
    fn default() -> Self {
        let lookback = (90, 339, 1);
        let period = (3, 3, 0);
        let signal_period = (20, 20, 0);
        Self {
            lookback,
            period,
            signal_period,
        }
    }
}
1580
1581#[derive(Clone, Debug, Default)]
1582pub struct RsmkBatchBuilder {
1583 range: RsmkBatchRange,
1584 kernel: Kernel,
1585}
1586
1587impl RsmkBatchBuilder {
1588 pub fn new() -> Self {
1589 Self::default()
1590 }
1591 pub fn kernel(mut self, k: Kernel) -> Self {
1592 self.kernel = k;
1593 self
1594 }
1595 #[inline]
1596 pub fn lookback_range(mut self, start: usize, end: usize, step: usize) -> Self {
1597 self.range.lookback = (start, end, step);
1598 self
1599 }
1600 #[inline]
1601 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
1602 self.range.period = (start, end, step);
1603 self
1604 }
1605 #[inline]
1606 pub fn signal_period_range(mut self, start: usize, end: usize, step: usize) -> Self {
1607 self.range.signal_period = (start, end, step);
1608 self
1609 }
1610 pub fn apply_slices(self, main: &[f64], compare: &[f64]) -> Result<RsmkBatchOutput, RsmkError> {
1611 rsmk_batch_with_kernel(main, compare, &self.range, self.kernel)
1612 }
1613}
1614
/// Result of a batch RSMK sweep.
///
/// `indicator` and `signal` are row-major `rows × cols` matrices. Row `r` holds the series
/// for `combos[r]`, and `cols` is the input series length.
#[derive(Clone, Debug)]
pub struct RsmkBatchOutput {
    /// Indicator values, `rows * cols` entries, row-major.
    pub indicator: Vec<f64>,
    /// Signal values, `rows * cols` entries, row-major.
    pub signal: Vec<f64>,
    /// Parameter combination for each row, in row order.
    pub combos: Vec<RsmkParams>,
    /// Number of parameter combinations (rows).
    pub rows: usize,
    /// Length of each output row (the input length).
    pub cols: usize,
}
1623
1624impl RsmkBatchOutput {
1625 pub fn row_for_params(&self, p: &RsmkParams) -> Option<usize> {
1626 self.combos.iter().position(|c| {
1627 c.lookback.unwrap_or(90) == p.lookback.unwrap_or(90)
1628 && c.period.unwrap_or(3) == p.period.unwrap_or(3)
1629 && c.signal_period.unwrap_or(20) == p.signal_period.unwrap_or(20)
1630 })
1631 }
1632
1633 pub fn indicator_for(&self, p: &RsmkParams) -> Option<&[f64]> {
1634 self.row_for_params(p).map(|row| {
1635 let start = row * self.cols;
1636 &self.indicator[start..start + self.cols]
1637 })
1638 }
1639
1640 pub fn signal_for(&self, p: &RsmkParams) -> Option<&[f64]> {
1641 self.row_for_params(p).map(|row| {
1642 let start = row * self.cols;
1643 &self.signal[start..start + self.cols]
1644 })
1645 }
1646}
1647
1648#[inline(always)]
1649fn expand_grid(r: &RsmkBatchRange) -> Vec<RsmkParams> {
1650 fn axis((start, end, step): (usize, usize, usize)) -> Vec<usize> {
1651 if step == 0 || start == end {
1652 return vec![start];
1653 }
1654 let mut vals = Vec::new();
1655 if start <= end {
1656 let st = step.max(1);
1657 for v in (start..=end).step_by(st) {
1658 vals.push(v);
1659 }
1660 } else {
1661 let mut cur = start;
1662 let s = step.max(1);
1663 loop {
1664 vals.push(cur);
1665 if cur <= end {
1666 break;
1667 }
1668 if cur < s {
1669 break;
1670 }
1671 let next = cur - s;
1672 if next == cur {
1673 break;
1674 }
1675 cur = next;
1676 }
1677 }
1678 vals
1679 }
1680 let looks = axis(r.lookback);
1681 let periods = axis(r.period);
1682 let signals = axis(r.signal_period);
1683
1684 let mut out = Vec::with_capacity(looks.len() * periods.len() * signals.len());
1685 for &l in &looks {
1686 for &p in &periods {
1687 for &s in &signals {
1688 out.push(RsmkParams {
1689 lookback: Some(l),
1690 period: Some(p),
1691 signal_period: Some(s),
1692 matype: Some("ema".to_string()),
1693 signal_matype: Some("ema".to_string()),
1694 });
1695 }
1696 }
1697 }
1698 out
1699}
1700
1701pub fn rsmk_batch_with_kernel(
1702 main: &[f64],
1703 compare: &[f64],
1704 sweep: &RsmkBatchRange,
1705 k: Kernel,
1706) -> Result<RsmkBatchOutput, RsmkError> {
1707 let kernel = match k {
1708 Kernel::Auto => detect_best_batch_kernel(),
1709 other if other.is_batch() => other,
1710 other => return Err(RsmkError::InvalidKernelForBatch(other)),
1711 };
1712 let simd = match kernel {
1713 Kernel::Avx512Batch => Kernel::Avx512,
1714 Kernel::Avx2Batch => Kernel::Avx2,
1715 Kernel::ScalarBatch => Kernel::Scalar,
1716 _ => unreachable!(),
1717 };
1718 rsmk_batch_par_slice(main, compare, sweep, simd)
1719}
1720
1721pub fn rsmk_batch_slice(
1722 main: &[f64],
1723 compare: &[f64],
1724 sweep: &RsmkBatchRange,
1725 kern: Kernel,
1726) -> Result<RsmkBatchOutput, RsmkError> {
1727 rsmk_batch_inner(main, compare, sweep, kern, false)
1728}
1729
1730pub fn rsmk_batch_par_slice(
1731 main: &[f64],
1732 compare: &[f64],
1733 sweep: &RsmkBatchRange,
1734 kern: Kernel,
1735) -> Result<RsmkBatchOutput, RsmkError> {
1736 rsmk_batch_inner(main, compare, sweep, kern, true)
1737}
1738
1739#[inline(always)]
1740fn rsmk_batch_inner_into(
1741 main: &[f64],
1742 compare: &[f64],
1743 sweep: &RsmkBatchRange,
1744 kern: Kernel,
1745 parallel: bool,
1746 indicator_out: &mut [f64],
1747 signal_out: &mut [f64],
1748) -> Result<Vec<RsmkParams>, RsmkError> {
1749 let combos = expand_grid(sweep);
1750 if combos.is_empty() {
1751 return Err(RsmkError::InvalidRange {
1752 start: sweep.lookback.0,
1753 end: sweep.lookback.1,
1754 step: sweep.lookback.2,
1755 });
1756 }
1757 let first = main
1758 .iter()
1759 .zip(compare.iter())
1760 .position(|(&m, &c)| m.is_finite() && c.is_finite() && c != 0.0)
1761 .ok_or(RsmkError::AllValuesNaN)?;
1762 let max_p = combos
1763 .iter()
1764 .map(|c| {
1765 c.lookback
1766 .unwrap()
1767 .max(c.period.unwrap())
1768 .max(c.signal_period.unwrap())
1769 })
1770 .max()
1771 .unwrap();
1772
1773 if main.len() - first < max_p {
1774 return Err(RsmkError::NotEnoughValidData {
1775 needed: max_p,
1776 valid: main.len() - first,
1777 });
1778 }
1779
1780 let rows = combos.len();
1781 let cols = main.len();
1782
1783 let expected = rows.checked_mul(cols).ok_or(RsmkError::InvalidRange {
1784 start: rows,
1785 end: cols,
1786 step: 0,
1787 })?;
1788 if indicator_out.len() != expected {
1789 return Err(RsmkError::OutputLengthMismatch {
1790 expected,
1791 got: indicator_out.len(),
1792 });
1793 }
1794 if signal_out.len() != expected {
1795 return Err(RsmkError::OutputLengthMismatch {
1796 expected,
1797 got: signal_out.len(),
1798 });
1799 }
1800
1801 let mut lr = Vec::with_capacity(cols);
1802 unsafe {
1803 lr.set_len(cols);
1804 }
1805 for i in 0..cols {
1806 let m = main[i];
1807 let c = compare[i];
1808 unsafe {
1809 *lr.get_unchecked_mut(i) = if m.is_nan() || c.is_nan() || c == 0.0 {
1810 f64::NAN
1811 } else {
1812 (m / c).ln()
1813 };
1814 }
1815 }
1816
1817 use std::collections::HashMap;
1818 let mut mom_by_lookback: HashMap<usize, Vec<f64>> = HashMap::new();
1819 for &lookback in combos
1820 .iter()
1821 .map(|c| c.lookback.unwrap())
1822 .collect::<std::collections::BTreeSet<_>>()
1823 .iter()
1824 {
1825 let mut m = alloc_with_nan_prefix(cols, first + lookback);
1826 let start = first + lookback;
1827 for i in start..cols {
1828 let a = unsafe { *lr.get_unchecked(i) };
1829 let b = unsafe { *lr.get_unchecked(i - lookback) };
1830 unsafe { *m.get_unchecked_mut(i) = a - b };
1831 }
1832 mom_by_lookback.insert(lookback, m);
1833 }
1834
1835 let do_row = |row: usize, ind_row: &mut [f64], sig_row: &mut [f64]| unsafe {
1836 let prm = &combos[row];
1837 let lookback = prm.lookback.unwrap();
1838 let period = prm.period.unwrap();
1839 let signal_period = prm.signal_period.unwrap();
1840 let mt = prm.matype.as_deref().unwrap_or("ema");
1841 let st = prm.signal_matype.as_deref().unwrap_or("ema");
1842
1843 let mom = mom_by_lookback.get(&lookback).unwrap();
1844
1845 match ma(mt, MaData::Slice(&mom), period) {
1846 Ok(mut v) => {
1847 for x in &mut v {
1848 *x *= 100.0;
1849 }
1850 ind_row.copy_from_slice(&v);
1851 }
1852 Err(_) => {
1853 for x in ind_row.iter_mut() {
1854 *x = f64::NAN;
1855 }
1856 }
1857 }
1858
1859 match ma(st, MaData::Slice(ind_row), signal_period) {
1860 Ok(vs) => {
1861 sig_row.copy_from_slice(&vs);
1862 }
1863 Err(_) => {
1864 for x in sig_row.iter_mut() {
1865 *x = f64::NAN;
1866 }
1867 }
1868 }
1869 };
1870
1871 if parallel {
1872 #[cfg(not(target_arch = "wasm32"))]
1873 {
1874 indicator_out
1875 .par_chunks_mut(cols)
1876 .zip(signal_out.par_chunks_mut(cols))
1877 .enumerate()
1878 .for_each(|(row, (ind_row, sig_row))| do_row(row, ind_row, sig_row));
1879 }
1880
1881 #[cfg(target_arch = "wasm32")]
1882 {
1883 for (row, (ind_row, sig_row)) in indicator_out
1884 .chunks_mut(cols)
1885 .zip(signal_out.chunks_mut(cols))
1886 .enumerate()
1887 {
1888 do_row(row, ind_row, sig_row);
1889 }
1890 }
1891 } else {
1892 for (row, (ind_row, sig_row)) in indicator_out
1893 .chunks_mut(cols)
1894 .zip(signal_out.chunks_mut(cols))
1895 .enumerate()
1896 {
1897 do_row(row, ind_row, sig_row);
1898 }
1899 }
1900
1901 Ok(combos)
1902}
1903
1904fn rsmk_batch_inner(
1905 main: &[f64],
1906 compare: &[f64],
1907 sweep: &RsmkBatchRange,
1908 kern: Kernel,
1909 parallel: bool,
1910) -> Result<RsmkBatchOutput, RsmkError> {
1911 let combos = expand_grid(sweep);
1912 if combos.is_empty() {
1913 return Err(RsmkError::InvalidRange {
1914 start: sweep.lookback.0,
1915 end: sweep.lookback.1,
1916 step: sweep.lookback.2,
1917 });
1918 }
1919 let first = main
1920 .iter()
1921 .zip(compare.iter())
1922 .position(|(&m, &c)| m.is_finite() && c.is_finite() && c != 0.0)
1923 .ok_or(RsmkError::AllValuesNaN)?;
1924 let max_p = combos
1925 .iter()
1926 .map(|c| {
1927 c.lookback
1928 .unwrap()
1929 .max(c.period.unwrap())
1930 .max(c.signal_period.unwrap())
1931 })
1932 .max()
1933 .unwrap();
1934
1935 if main.len() - first < max_p {
1936 return Err(RsmkError::NotEnoughValidData {
1937 needed: max_p,
1938 valid: main.len() - first,
1939 });
1940 }
1941
1942 let rows = combos.len();
1943 let cols = main.len();
1944
1945 let _expected = rows.checked_mul(cols).ok_or(RsmkError::InvalidRange {
1946 start: rows,
1947 end: cols,
1948 step: 0,
1949 })?;
1950
1951 let mut indicators = make_uninit_matrix(rows, cols);
1952 let mut signals = make_uninit_matrix(rows, cols);
1953
1954 let warmup_periods: Vec<usize> = combos
1955 .iter()
1956 .map(|c| {
1957 let lookback = c.lookback.unwrap();
1958 let period = c.period.unwrap();
1959 let signal_period = c.signal_period.unwrap();
1960 first + lookback.max(period).max(signal_period)
1961 })
1962 .collect();
1963
1964 init_matrix_prefixes(&mut indicators, cols, &warmup_periods);
1965 init_matrix_prefixes(&mut signals, cols, &warmup_periods);
1966
1967 let mut indicators = unsafe {
1968 use std::mem::ManuallyDrop;
1969 let mut v = ManuallyDrop::new(indicators);
1970 Vec::from_raw_parts(v.as_mut_ptr() as *mut f64, v.len(), v.capacity())
1971 };
1972 let mut signals = unsafe {
1973 use std::mem::ManuallyDrop;
1974 let mut v = ManuallyDrop::new(signals);
1975 Vec::from_raw_parts(v.as_mut_ptr() as *mut f64, v.len(), v.capacity())
1976 };
1977
1978 let mut lr = Vec::with_capacity(cols);
1979 unsafe {
1980 lr.set_len(cols);
1981 }
1982 for i in 0..cols {
1983 let m = main[i];
1984 let c = compare[i];
1985 unsafe {
1986 *lr.get_unchecked_mut(i) = if m.is_nan() || c.is_nan() || c == 0.0 {
1987 f64::NAN
1988 } else {
1989 (m / c).ln()
1990 };
1991 }
1992 }
1993
1994 use std::collections::HashMap;
1995 let mut mom_by_lookback: HashMap<usize, Vec<f64>> = HashMap::new();
1996 for &lookback in combos
1997 .iter()
1998 .map(|c| c.lookback.unwrap())
1999 .collect::<std::collections::BTreeSet<_>>()
2000 .iter()
2001 {
2002 let mut m = alloc_with_nan_prefix(cols, first + lookback);
2003 let start = first + lookback;
2004 for i in start..cols {
2005 let a = unsafe { *lr.get_unchecked(i) };
2006 let b = unsafe { *lr.get_unchecked(i - lookback) };
2007 unsafe { *m.get_unchecked_mut(i) = a - b };
2008 }
2009 mom_by_lookback.insert(lookback, m);
2010 }
2011
2012 let do_row = |row: usize, ind_row: &mut [f64], sig_row: &mut [f64]| unsafe {
2013 let prm = &combos[row];
2014 let lookback = prm.lookback.unwrap();
2015 let period = prm.period.unwrap();
2016 let signal_period = prm.signal_period.unwrap();
2017 let mt = prm.matype.as_deref().unwrap_or("ema");
2018 let st = prm.signal_matype.as_deref().unwrap_or("ema");
2019
2020 let mom = mom_by_lookback.get(&lookback).unwrap();
2021
2022 match ma(mt, MaData::Slice(&mom), period) {
2023 Ok(mut v) => {
2024 for x in &mut v {
2025 *x *= 100.0;
2026 }
2027 ind_row.copy_from_slice(&v);
2028 }
2029 Err(_) => {
2030 for x in ind_row.iter_mut() {
2031 *x = f64::NAN;
2032 }
2033 }
2034 }
2035
2036 match ma(st, MaData::Slice(ind_row), signal_period) {
2037 Ok(vs) => {
2038 sig_row.copy_from_slice(&vs);
2039 }
2040 Err(_) => {
2041 for x in sig_row.iter_mut() {
2042 *x = f64::NAN;
2043 }
2044 }
2045 }
2046 };
2047
2048 if parallel {
2049 #[cfg(not(target_arch = "wasm32"))]
2050 {
2051 indicators
2052 .par_chunks_mut(cols)
2053 .zip(signals.par_chunks_mut(cols))
2054 .enumerate()
2055 .for_each(|(row, (ind_row, sig_row))| do_row(row, ind_row, sig_row));
2056 }
2057
2058 #[cfg(target_arch = "wasm32")]
2059 {
2060 for (row, (ind_row, sig_row)) in indicators
2061 .chunks_mut(cols)
2062 .zip(signals.chunks_mut(cols))
2063 .enumerate()
2064 {
2065 do_row(row, ind_row, sig_row);
2066 }
2067 }
2068 } else {
2069 for (row, (ind_row, sig_row)) in indicators
2070 .chunks_mut(cols)
2071 .zip(signals.chunks_mut(cols))
2072 .enumerate()
2073 {
2074 do_row(row, ind_row, sig_row);
2075 }
2076 }
2077
2078 Ok(RsmkBatchOutput {
2079 indicator: indicators,
2080 signal: signals,
2081 combos,
2082 rows,
2083 cols,
2084 })
2085}
2086
/// Python binding: single-parameter RSMK over two 1-D float64 arrays.
///
/// Returns `(indicator, signal)` as NumPy arrays. The computation runs with the GIL
/// released. Any `RsmkError` is raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "rsmk")]
#[pyo3(signature = (main, compare, lookback, period, signal_period, matype=None, signal_matype=None, kernel=None))]
pub fn rsmk_py<'py>(
    py: Python<'py>,
    main: PyReadonlyArray1<'py, f64>,
    compare: PyReadonlyArray1<'py, f64>,
    lookback: usize,
    period: usize,
    signal_period: usize,
    matype: Option<&str>,
    signal_matype: Option<&str>,
    kernel: Option<&str>,
) -> PyResult<(Bound<'py, PyArray1<f64>>, Bound<'py, PyArray1<f64>>)> {
    use numpy::{IntoPyArray, PyArrayMethods};

    // Borrow the NumPy buffers as slices and resolve the requested kernel.
    let main_slice = main.as_slice()?;
    let compare_slice = compare.as_slice()?;
    let kern = validate_kernel(kernel, false)?;

    let input = RsmkInput::from_slices(
        main_slice,
        compare_slice,
        RsmkParams {
            lookback: Some(lookback),
            period: Some(period),
            signal_period: Some(signal_period),
            matype: matype.map(str::to_owned),
            signal_matype: signal_matype.map(str::to_owned),
        },
    );

    let computed = py.allow_threads(|| rsmk_with_kernel(&input, kern));
    let RsmkOutput { indicator, signal } = match computed {
        Ok(out) => out,
        Err(e) => return Err(PyValueError::new_err(e.to_string())),
    };

    Ok((indicator.into_pyarray(py), signal.into_pyarray(py)))
}
2125
/// Python binding: RSMK parameter sweep over two 1-D float64 arrays.
///
/// Returns a dict with `indicator` / `signal` reshaped to `(rows, cols)`, plus per-row
/// `lookbacks`, `periods`, `signal_periods`, `matypes` and `signal_matypes`.
///
/// The batch engine (`expand_grid`) only evaluates EMA for both stages. `matype` and
/// `signal_matype` are therefore accepted only when omitted or equal to `"ema"` (ASCII
/// case-insensitive). Any other value raises `ValueError`, instead of silently returning EMA
/// results labelled as another MA type.
#[cfg(feature = "python")]
#[pyfunction(name = "rsmk_batch")]
#[pyo3(signature = (main, compare, lookback_range, period_range, signal_period_range, matype=None, signal_matype=None, kernel=None))]
pub fn rsmk_batch_py<'py>(
    py: Python<'py>,
    main: PyReadonlyArray1<'py, f64>,
    compare: PyReadonlyArray1<'py, f64>,
    lookback_range: (usize, usize, usize),
    period_range: (usize, usize, usize),
    signal_period_range: (usize, usize, usize),
    matype: Option<&str>,
    signal_matype: Option<&str>,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};

    let main_slice = main.as_slice()?;
    let compare_slice = compare.as_slice()?;
    let kern = validate_kernel(kernel, true)?;

    // The batch grid hard-codes EMA, so reject other MA names rather than ignore them.
    for (name, value) in [("matype", matype), ("signal_matype", signal_matype)] {
        if let Some(v) = value {
            if !v.eq_ignore_ascii_case("ema") {
                return Err(PyValueError::new_err(format!(
                    "rsmk_batch: {name} '{v}' is not supported; batch mode only computes 'ema'"
                )));
            }
        }
    }

    let sweep = RsmkBatchRange {
        lookback: lookback_range,
        period: period_range,
        signal_period: signal_period_range,
    };

    let combos = expand_grid(&sweep);
    let rows = combos.len();
    let cols = main_slice.len();

    let total = rows.checked_mul(cols).ok_or_else(|| {
        PyValueError::new_err(format!(
            "rsmk: rows*cols overflow (rows={}, cols={})",
            rows, cols
        ))
    })?;

    // Outputs are allocated uninitialized; `rsmk_batch_inner_into` writes every cell
    // before returning Ok.
    let indicator_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let signal_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let indicator_slice = unsafe { indicator_arr.as_slice_mut()? };
    let signal_slice = unsafe { signal_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            let kernel = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k => k,
            };

            let simd = match kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                _ => kern,
            };

            rsmk_batch_inner_into(
                main_slice,
                compare_slice,
                &sweep,
                simd,
                true,
                indicator_slice,
                signal_slice,
            )
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("indicator", indicator_arr.reshape((rows, cols))?)?;
    dict.set_item("signal", signal_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "lookbacks",
        combos
            .iter()
            .map(|p| p.lookback.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "signal_periods",
        combos
            .iter()
            .map(|p| p.signal_period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    use pyo3::types::PyList;
    dict.set_item(
        "matypes",
        PyList::new(
            py,
            combos.iter().map(|p| p.matype.as_deref().unwrap_or("ema")),
        )?,
    )?;
    dict.set_item(
        "signal_matypes",
        PyList::new(
            py,
            combos
                .iter()
                .map(|p| p.signal_matype.as_deref().unwrap_or("ema")),
        )?,
    )?;

    Ok(dict)
}
2241
/// Python-facing wrapper (exposed to Python as `RsmkStream`) around the Rust `RsmkStream`.
#[cfg(feature = "python")]
#[pyclass(name = "RsmkStream")]
pub struct RsmkStreamPy {
    // The wrapped streaming calculator; all state lives here.
    inner: RsmkStream,
}
2247
#[cfg(feature = "python")]
#[pymethods]
impl RsmkStreamPy {
    /// Creates a streaming RSMK calculator.
    ///
    /// `matype` and `signal_matype` are optional and default to `None`; the Rust side then
    /// uses `"ema"`. The explicit signature replaces PyO3's deprecated implicit defaulting
    /// of trailing `Option` arguments. Raises `ValueError` if any period is zero.
    #[new]
    #[pyo3(signature = (lookback, period, signal_period, matype=None, signal_matype=None))]
    pub fn new(
        lookback: usize,
        period: usize,
        signal_period: usize,
        matype: Option<&str>,
        signal_matype: Option<&str>,
    ) -> PyResult<Self> {
        let params = RsmkParams {
            lookback: Some(lookback),
            period: Some(period),
            signal_period: Some(signal_period),
            matype: matype.map(|s| s.to_string()),
            signal_matype: signal_matype.map(|s| s.to_string()),
        };
        let inner =
            RsmkStream::try_new(params).map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok(RsmkStreamPy { inner })
    }

    /// Feeds one `(main, compare)` bar and returns `(indicator, signal)`; both components are
    /// NaN during warm-up.
    pub fn update(&mut self, main: f64, compare: f64) -> Option<(f64, f64)> {
        self.inner.update(main, compare)
    }
}
2275
/// Serialized return value of `rsmk_js`.
///
/// `values` is a flattened `rows × cols` row-major matrix. `rsmk_js` fills it with the
/// indicator row followed by the signal row (`rows == 2`, `cols` == input length).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct RsmkResult {
    pub values: Vec<f64>,
    pub rows: usize,
    pub cols: usize,
}
2283
/// WASM binding: computes RSMK and returns an `RsmkResult` whose `values` hold the
/// indicator followed by the signal (2 rows × `main.len()` columns).
///
/// `matype` / `signal_matype` default to `"ema"`. Errors are returned as JS strings.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn rsmk_js(
    main: &[f64],
    compare: &[f64],
    lookback: usize,
    period: usize,
    signal_period: usize,
    matype: Option<String>,
    signal_matype: Option<String>,
) -> Result<JsValue, JsValue> {
    let input = RsmkInput::from_slices(
        main,
        compare,
        RsmkParams {
            lookback: Some(lookback),
            period: Some(period),
            signal_period: Some(signal_period),
            matype: Some(matype.unwrap_or_else(|| "ema".into())),
            signal_matype: Some(signal_matype.unwrap_or_else(|| "ema".into())),
        },
    );
    let RsmkOutput { indicator, signal } =
        rsmk(&input).map_err(|e| JsValue::from_str(&e.to_string()))?;

    let n = main.len();
    if indicator.len() != n || signal.len() != n {
        return Err(JsValue::from_str("length mismatch"));
    }

    let values: Vec<f64> = indicator.iter().chain(signal.iter()).copied().collect();
    let res = RsmkResult {
        values,
        rows: 2,
        cols: n,
    };
    serde_wasm_bindgen::to_value(&res)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2319
/// WASM binding: computes RSMK from raw linear-memory buffers into caller-owned outputs.
///
/// Every pointer must address at least `len` `f64`s (e.g. buffers from `rsmk_alloc`).
/// `matype` / `signal_matype` default to `"ema"`.
///
/// If an input byte range overlaps an output range, or the two output ranges overlap, the
/// result is computed into temporary buffers and copied out afterwards: indicator first,
/// then signal. Otherwise results are written in place.
///
/// # Errors
/// A JS string for null pointers, or the message of any `RsmkError` from the computation.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn rsmk_into(
    in_ptr: *const f64,
    indicator_ptr: *mut f64,
    signal_ptr: *mut f64,
    len: usize,
    compare_ptr: *const f64,
    lookback: usize,
    period: usize,
    signal_period: usize,
    matype: Option<String>,
    signal_matype: Option<String>,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || indicator_ptr.is_null() || signal_ptr.is_null() || compare_ptr.is_null()
    {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // Byte-range overlap test. It catches partial overlaps as well as identical start
    // pointers: a `&[f64]` and a `&mut [f64]` over overlapping memory are undefined
    // behaviour even when their start addresses differ.
    let span = len.saturating_mul(std::mem::size_of::<f64>());
    let overlaps = |a: *const f64, b: *const f64| {
        let (a, b) = (a as usize, b as usize);
        a < b.saturating_add(span) && b < a.saturating_add(span)
    };
    let ind_c = indicator_ptr as *const f64;
    let sig_c = signal_ptr as *const f64;
    let in_aliased = overlaps(in_ptr, ind_c) || overlaps(in_ptr, sig_c);
    let compare_aliased = overlaps(compare_ptr, ind_c) || overlaps(compare_ptr, sig_c);
    let outputs_aliased = overlaps(ind_c, sig_c);

    // SAFETY: the caller guarantees each pointer addresses `len` valid f64s. Overlapping
    // input/output ranges are routed through temporaries, so no shared slice is read while a
    // mutable slice over the same bytes is live.
    unsafe {
        let main = std::slice::from_raw_parts(in_ptr, len);
        let compare = std::slice::from_raw_parts(compare_ptr, len);
        let params = RsmkParams {
            lookback: Some(lookback),
            period: Some(period),
            signal_period: Some(signal_period),
            matype: matype.or_else(|| Some("ema".to_string())),
            signal_matype: signal_matype.or_else(|| Some("ema".to_string())),
        };
        let input = RsmkInput::from_slices(main, compare, params);

        if in_aliased || compare_aliased || outputs_aliased {
            let mut temp_indicator = vec![0.0; len];
            let mut temp_signal = vec![0.0; len];

            rsmk_into_slice(&mut temp_indicator, &mut temp_signal, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;

            // Each `&mut` view lives only for its own copy, so two mutable slices over the
            // same memory are never alive at once. Where the outputs overlap, the signal is
            // written last and ends up in the shared region.
            std::slice::from_raw_parts_mut(indicator_ptr, len).copy_from_slice(&temp_indicator);
            std::slice::from_raw_parts_mut(signal_ptr, len).copy_from_slice(&temp_signal);
        } else {
            let indicator_out = std::slice::from_raw_parts_mut(indicator_ptr, len);
            let signal_out = std::slice::from_raw_parts_mut(signal_ptr, len);
            rsmk_into_slice(indicator_out, signal_out, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(())
}
2382
/// Allocates an uninitialized buffer with room for `len` `f64`s and returns its pointer.
/// Release it with `rsmk_free(ptr, len)`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn rsmk_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop hands ownership of the allocation to the caller (same effect as mem::forget).
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
2391
/// Releases a buffer previously returned by `rsmk_alloc(len)`.
///
/// `len` must equal the value passed to `rsmk_alloc`. Null pointers are ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn rsmk_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` came from `rsmk_alloc(len)`, i.e. a `Vec<f64>` with capacity `len`.
    // The length is 0 because the buffer may never have been written, and
    // `Vec::from_raw_parts` requires the first `length` elements to be initialized. `f64`
    // has no destructor, so only the allocation itself needs releasing.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
2401
2402#[cfg(all(feature = "python", feature = "cuda"))]
2403use crate::cuda::{cuda_available, CudaRsmk};
2404#[cfg(all(feature = "python", feature = "cuda"))]
2405use crate::indicators::moving_averages::alma::DeviceArrayF32Py;
2406#[cfg(all(feature = "python", feature = "cuda"))]
2407#[cfg(all(feature = "python", feature = "cuda"))]
2408use numpy::{PyReadonlyArray2, PyUntypedArrayMethods};
2409#[cfg(all(feature = "python", feature = "cuda"))]
2410use pyo3::{pyfunction, PyResult, Python};
2411
/// Python binding: runs the RSMK parameter sweep on a CUDA device.
///
/// Takes float32 input series. Returns `(indicator, signal)` device arrays that keep the CUDA
/// context alive. Raises `ValueError` if CUDA is unavailable or the device call fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "rsmk_cuda_batch_dev")]
#[pyo3(signature = (main_f32, compare_f32, lookback_range, period_range, signal_period_range, device_id=0))]
pub fn rsmk_cuda_batch_dev_py(
    py: Python<'_>,
    main_f32: PyReadonlyArray1<'_, f32>,
    compare_f32: PyReadonlyArray1<'_, f32>,
    lookback_range: (usize, usize, usize),
    period_range: (usize, usize, usize),
    signal_period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<(DeviceArrayF32Py, DeviceArrayF32Py)> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let main = main_f32.as_slice()?;
    let comp = compare_f32.as_slice()?;
    let sweep = RsmkBatchRange {
        lookback: lookback_range,
        period: period_range,
        signal_period: signal_period_range,
    };

    // Device work runs with the GIL released; the context handle is captured before the sweep.
    let (pair, ctx, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaRsmk::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id();
        let (pair, _combos) = cuda
            .rsmk_batch_dev(main, comp, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((pair, ctx, dev_id))
    })?;

    let indicator = DeviceArrayF32Py {
        inner: pair.a,
        _ctx: Some(ctx.clone()),
        device_id: Some(dev_id),
    };
    let signal = DeviceArrayF32Py {
        inner: pair.b,
        _ctx: Some(ctx),
        device_id: Some(dev_id),
    };
    Ok((indicator, signal))
}
2455
/// Python binding: RSMK for many time-major float32 series with one parameter set, on CUDA.
///
/// Both MA stages use `"ema"`. Returns `(indicator, signal)` device arrays that keep the
/// CUDA context alive. Raises `ValueError` if CUDA is unavailable or the device call fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "rsmk_cuda_many_series_one_param_dev")]
#[pyo3(signature = (main_tm_f32, compare_tm_f32, cols, rows, lookback, period, signal_period, device_id=0))]
pub fn rsmk_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    main_tm_f32: PyReadonlyArray2<'_, f32>,
    compare_tm_f32: PyReadonlyArray2<'_, f32>,
    cols: usize,
    rows: usize,
    lookback: usize,
    period: usize,
    signal_period: usize,
    device_id: usize,
) -> PyResult<(DeviceArrayF32Py, DeviceArrayF32Py)> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let main_tm: &[f32] = main_tm_f32.as_slice()?;
    let comp_tm: &[f32] = compare_tm_f32.as_slice()?;
    let params = RsmkParams {
        lookback: Some(lookback),
        period: Some(period),
        signal_period: Some(signal_period),
        matype: Some("ema".into()),
        signal_matype: Some("ema".into()),
    };

    // Device work runs with the GIL released; the context handle is captured before the launch.
    let (pair, ctx, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaRsmk::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id();
        let pair = cuda
            .rsmk_many_series_one_param_time_major_dev(main_tm, comp_tm, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((pair, ctx, dev_id))
    })?;

    let indicator = DeviceArrayF32Py {
        inner: pair.a,
        _ctx: Some(ctx.clone()),
        device_id: Some(dev_id),
    };
    let signal = DeviceArrayF32Py {
        inner: pair.b,
        _ctx: Some(ctx),
        device_id: Some(dev_id),
    };
    Ok((indicator, signal))
}
2503
/// Sweep configuration deserialized from the JS `config` argument of `rsmk_batch`.
///
/// Each `*_range` triple is forwarded element-by-element to the matching
/// `RsmkBatchBuilder::*_range` call. It is presumably `(start, end, step)`; confirm against
/// the builder.
///
/// NOTE(review): `rsmk_batch_js` does not read `matype` or `signal_matype`. They are
/// accepted from JS but currently ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct RsmkBatchConfig {
    /// Lookback sweep triple passed to `RsmkBatchBuilder::lookback_range`.
    pub lookback_range: (usize, usize, usize),
    /// Period sweep triple passed to `RsmkBatchBuilder::period_range`.
    pub period_range: (usize, usize, usize),
    /// Signal-period sweep triple passed to `RsmkBatchBuilder::signal_period_range`.
    pub signal_period_range: (usize, usize, usize),
    /// Requested moving-average type for the indicator (currently unused by `rsmk_batch_js`).
    pub matype: Option<String>,
    /// Requested moving-average type for the signal (currently unused by `rsmk_batch_js`).
    pub signal_matype: Option<String>,
}
2513
/// Result payload serialized back to JS by `rsmk_batch`.
///
/// `indicators` and `signals` are flat row-major `rows x cols` matrices. Row `r` holds the
/// series computed with `combos[r]`, and each row has one value per input sample.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct RsmkBatchJsOutput {
    /// Flattened indicator matrix (row-major, `rows * cols` values).
    pub indicators: Vec<f64>,
    /// Flattened signal matrix (row-major, `rows * cols` values).
    pub signals: Vec<f64>,
    /// Parameter set used for each row, in row order.
    pub combos: Vec<RsmkParams>,
    /// Number of parameter combinations (matrix rows).
    pub rows: usize,
    /// Number of samples per row (matrix columns).
    pub cols: usize,
}
2523
/// JS entry point (`rsmk_batch`): runs the RSMK parameter sweep described by `config` over
/// `main` / `compare` and returns an `RsmkBatchJsOutput` as a JS value.
///
/// # Errors
/// Returns a JS string error when `config` cannot be deserialized ("Invalid config: ..."),
/// when the batch computation fails, or when the result cannot be serialized.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = rsmk_batch)]
pub fn rsmk_batch_js(main: &[f64], compare: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let config: RsmkBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    // NOTE(review): `config.matype` / `config.signal_matype` are deserialized but not
    // forwarded to the builder, so JS callers cannot change the MA types. Wire them through
    // once the builder's MA-type setters are confirmed.
    let output = RsmkBatchBuilder::new()
        .lookback_range(
            config.lookback_range.0,
            config.lookback_range.1,
            config.lookback_range.2,
        )
        .period_range(
            config.period_range.0,
            config.period_range.1,
            config.period_range.2,
        )
        .signal_period_range(
            config.signal_period_range.0,
            config.signal_period_range.1,
            config.signal_period_range.2,
        )
        .apply_slices(main, compare)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    // The batch buffers are already flat row-major (`rows x cols`), so move them directly.
    // Re-chunking them by `cols` was an identity copy and panicked when `cols == 0`.
    let js_output = RsmkBatchJsOutput {
        indicators: output.indicator,
        signals: output.signal,
        combos: output.combos,
        rows: output.rows,
        cols: output.cols,
    };

    serde_wasm_bindgen::to_value(&js_output)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2572
#[cfg(test)]
mod tests {
    use super::*;
    use crate::skip_if_unsupported;
    use crate::utilities::data_loader::read_candles_from_csv;
    #[cfg(feature = "proptest")]
    use proptest::prelude::*;

    /// Candle fixture shared by every file-backed test in this module.
    const CANDLES_CSV: &str = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";

    /// Builds a fully specified parameter set.
    fn params_with(
        lookback: usize,
        period: usize,
        signal_period: usize,
        matype: &str,
        signal_matype: &str,
    ) -> RsmkParams {
        RsmkParams {
            lookback: Some(lookback),
            period: Some(period),
            signal_period: Some(signal_period),
            matype: Some(matype.to_string()),
            signal_matype: Some(signal_matype.to_string()),
        }
    }

    /// Builds a parameter set that uses EMA for both smoothing stages.
    fn ema_params(lookback: usize, period: usize, signal_period: usize) -> RsmkParams {
        params_with(lookback, period, signal_period, "ema", "ema")
    }

    /// Returns the name of the allocation helper whose poison fill pattern equals `bits`,
    /// or `None` when `bits` is not a known poison pattern.
    #[cfg(debug_assertions)]
    fn poison_source(bits: u64) -> Option<&'static str> {
        match bits {
            0x11111111_11111111 => Some("alloc_with_nan_prefix"),
            0x22222222_22222222 => Some("init_matrix_prefixes"),
            0x33333333_33333333 => Some("make_uninit_matrix"),
            _ => None,
        }
    }

    /// Absolute differences between consecutive elements, skipping any pair with a NaN.
    #[cfg(feature = "proptest")]
    fn abs_steps(values: &[f64]) -> Vec<f64> {
        values
            .windows(2)
            .filter(|w| !w[0].is_nan() && !w[1].is_nan())
            .map(|w| (w[1] - w[0]).abs())
            .collect()
    }

    /// Population variance (divides by `len`). Callers must pass a non-empty slice.
    #[cfg(feature = "proptest")]
    fn population_variance(values: &[f64]) -> f64 {
        let n = values.len() as f64;
        let mean = values.iter().sum::<f64>() / n;
        values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n
    }

    #[test]
    fn test_rsmk_into_matches_api() -> Result<(), Box<dyn std::error::Error>> {
        let candles = read_candles_from_csv(CANDLES_CSV)?;

        let input = RsmkInput::with_default_candles(&candles, &candles);

        let baseline = rsmk(&input)?;

        let n = candles.close.len();
        let mut out_ind = vec![0.0f64; n];
        let mut out_sig = vec![0.0f64; n];
        rsmk_into(&input, &mut out_ind, &mut out_sig)?;

        assert_eq!(baseline.indicator.len(), n);
        assert_eq!(baseline.signal.len(), n);
        assert_eq!(out_ind.len(), n);
        assert_eq!(out_sig.len(), n);

        // Equal when both are NaN, bit-equal, or within 1e-12 absolute.
        fn eq_or_both_nan(a: f64, b: f64) -> bool {
            (a.is_nan() && b.is_nan()) || (a == b) || ((a - b).abs() <= 1e-12)
        }

        for i in 0..n {
            assert!(
                eq_or_both_nan(baseline.indicator[i], out_ind[i]),
                "indicator mismatch at {i}: {} vs {}",
                baseline.indicator[i],
                out_ind[i]
            );
            assert!(
                eq_or_both_nan(baseline.signal[i], out_sig[i]),
                "signal mismatch at {i}: {} vs {}",
                baseline.signal[i],
                out_sig[i]
            );
        }

        Ok(())
    }

    fn check_rsmk_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;
        let default_params = RsmkParams {
            lookback: None,
            period: None,
            signal_period: None,
            matype: None,
            signal_matype: None,
        };
        let input_default = RsmkInput::from_candles(&candles, &candles, "close", default_params);
        let output_default = rsmk_with_kernel(&input_default, kernel)?;
        assert_eq!(output_default.indicator.len(), candles.close.len());
        assert_eq!(output_default.signal.len(), candles.close.len());
        Ok(())
    }

    fn check_rsmk_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;
        let input = RsmkInput::from_candles(&candles, &candles, "close", RsmkParams::default());
        let rsmk_result = rsmk_with_kernel(&input, kernel)?;
        assert_eq!(rsmk_result.indicator.len(), candles.close.len());
        assert_eq!(rsmk_result.signal.len(), candles.close.len());
        // Main and compare are the same series, so both outputs should settle near zero.
        let expected_last_five = [0.0; 5];
        let start = rsmk_result.indicator.len() - 5;
        for series in [&rsmk_result.indicator, &rsmk_result.signal] {
            for (&value, &expected_value) in series[start..].iter().zip(&expected_last_five) {
                assert!((value - expected_value).abs() < 1e-1);
            }
        }
        Ok(())
    }

    fn check_rsmk_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;
        let input = RsmkInput::with_default_candles(&candles, &candles);
        let rsmk_result = rsmk_with_kernel(&input, kernel)?;
        assert_eq!(rsmk_result.indicator.len(), candles.close.len());
        assert_eq!(rsmk_result.signal.len(), candles.close.len());
        Ok(())
    }

    fn check_rsmk_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let input_data = [10.0, 11.0, 12.0];
        let input = RsmkInput::from_slices(&input_data, &input_data, ema_params(0, 0, 0));
        let result = rsmk_with_kernel(&input, kernel);
        assert!(result.is_err());
        Ok(())
    }

    fn check_rsmk_very_small_dataset(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let input_data = [42.0];
        let input = RsmkInput::from_slices(&input_data, &input_data, RsmkParams::default());
        let result = rsmk_with_kernel(&input, kernel);
        assert!(result.is_err());
        Ok(())
    }

    fn check_rsmk_all_nan(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let input_data = [f64::NAN, f64::NAN, f64::NAN];
        let input = RsmkInput::from_slices(&input_data, &input_data, RsmkParams::default());
        let result = rsmk_with_kernel(&input, kernel);
        assert!(result.is_err());
        Ok(())
    }

    fn check_rsmk_not_enough_valid_data(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let input_data = [f64::NAN, 10.0, 20.0, 30.0];
        let input = RsmkInput::from_slices(&input_data, &input_data, ema_params(3, 3, 3));
        let result = rsmk_with_kernel(&input, kernel);
        assert!(result.is_err());
        Ok(())
    }

    fn check_rsmk_ma_error(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let input_data = [10.0, 11.0, 12.0, 13.0, 14.0, 15.0];
        let params = params_with(2, 3, 3, "nonexistent_ma", "ema");
        let input = RsmkInput::from_slices(&input_data, &input_data, params);
        let result = rsmk_with_kernel(&input, kernel);
        assert!(result.is_err());
        Ok(())
    }

    /// Debug builds only: no output value may carry one of the allocation helpers' poison
    /// bit patterns, which would indicate an element that was never written.
    #[cfg(debug_assertions)]
    fn check_rsmk_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);

        let candles = read_candles_from_csv(CANDLES_CSV)?;

        let test_params = vec![
            RsmkParams::default(),
            params_with(1, 1, 1, "ema", "ema"),
            params_with(10, 2, 5, "ema", "ema"),
            params_with(50, 10, 15, "sma", "sma"),
            params_with(100, 20, 30, "ema", "sma"),
            params_with(200, 50, 50, "sma", "ema"),
            params_with(5, 20, 10, "ema", "ema"),
        ];

        for (param_idx, params) in test_params.iter().enumerate() {
            let input = RsmkInput::from_candles(&candles, &candles, "close", params.clone());
            let output = rsmk_with_kernel(&input, kernel)?;

            for (series, values) in [("indicator", &output.indicator), ("signal", &output.signal)]
            {
                for (i, &val) in values.iter().enumerate() {
                    if val.is_nan() {
                        continue;
                    }
                    let bits = val.to_bits();
                    if let Some(source) = poison_source(bits) {
                        panic!(
                            "[{}] Found {} poison value {} (0x{:016X}) in {} at index {} \
                             with params: lookback={}, period={}, signal_period={}, matype={}, signal_matype={} (param set {})",
                            test_name, source, val, bits, series, i,
                            params.lookback.unwrap_or(90),
                            params.period.unwrap_or(3),
                            params.signal_period.unwrap_or(20),
                            params.matype.as_deref().unwrap_or("ema"),
                            params.signal_matype.as_deref().unwrap_or("ema"),
                            param_idx
                        );
                    }
                }
            }
        }

        Ok(())
    }

    #[cfg(not(debug_assertions))]
    fn check_rsmk_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    /// Randomized checks of `rsmk_with_kernel` (EMA/EMA) against the scalar kernel and
    /// against inputs with known expected behaviour (identical series, constant ratios).
    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_rsmk_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);

        let strat = (1usize..=100, 1usize..=50, 1usize..=50).prop_flat_map(
            |(lookback, period, signal_period)| {
                let min_len = lookback + period.max(signal_period) + 50;
                (min_len..=500usize).prop_flat_map(move |len| {
                    (
                        prop::collection::vec(
                            (1.0f64..10000.0f64).prop_filter("finite", |x| x.is_finite()),
                            len,
                        ),
                        prop::collection::vec(
                            (1.0f64..10000.0f64).prop_filter("finite", |x| x.is_finite()),
                            len,
                        ),
                        Just(lookback),
                        Just(period),
                        Just(signal_period),
                    )
                })
            },
        );

        proptest::test_runner::TestRunner::default()
            .run(
                &strat,
                |(main, compare, lookback, period, signal_period)| {
                    let input = RsmkInput::from_slices(
                        &main,
                        &compare,
                        ema_params(lookback, period, signal_period),
                    );

                    // Parameter sets the implementation rejects are out of scope here.
                    let output = match rsmk_with_kernel(&input, kernel) {
                        Ok(out) => out,
                        Err(_) => return Ok(()),
                    };
                    let ref_output = match rsmk_with_kernel(&input, Kernel::Scalar) {
                        Ok(out) => out,
                        Err(_) => return Ok(()),
                    };

                    prop_assert_eq!(output.indicator.len(), main.len());
                    prop_assert_eq!(output.signal.len(), main.len());

                    let warmup = lookback.max(period).max(signal_period);

                    let all_equal = main
                        .iter()
                        .zip(compare.iter())
                        .all(|(a, b)| (a - b).abs() < f64::EPSILON);

                    for i in 0..lookback.min(main.len()) {
                        if all_equal {
                            prop_assert!(
                                output.indicator[i].is_nan() || output.indicator[i].abs() < 1e-9,
                                "Expected NaN or 0 during warmup at index {} (before lookback {}), got {}",
                                i, lookback, output.indicator[i]
                            );
                        } else {
                            prop_assert!(
                                output.indicator[i].is_nan(),
                                "Expected NaN during warmup at index {} (before lookback {}), got {}",
                                i, lookback, output.indicator[i]
                            );
                        }
                    }

                    let full_warmup = lookback + period.max(signal_period);
                    if main.len() > full_warmup + 5 {
                        let has_valid =
                            output.indicator[full_warmup..].iter().any(|&x| !x.is_nan());
                        prop_assert!(
                            has_valid,
                            "Expected some non-NaN values after full warmup period ({})",
                            full_warmup
                        );
                    }

                    // Identical main/compare series: the indicator should be zero.
                    let identical_input = RsmkInput::from_slices(
                        &main,
                        &main,
                        ema_params(lookback, period, signal_period),
                    );
                    if let Ok(identical_output) = rsmk_with_kernel(&identical_input, kernel) {
                        for i in warmup..main.len() {
                            if !identical_output.indicator[i].is_nan() {
                                prop_assert!(
                                    identical_output.indicator[i].abs() < 1e-9,
                                    "When main==compare, indicator should be 0 at index {}, got {}",
                                    i,
                                    identical_output.indicator[i]
                                );
                            }
                        }
                    }

                    // Constant ratio main = 2 * compare: the indicator should be near zero.
                    let const_ratio = 2.0;
                    let main_scaled: Vec<f64> = compare.iter().map(|&x| x * const_ratio).collect();
                    let const_input = RsmkInput::from_slices(
                        &main_scaled,
                        &compare,
                        ema_params(lookback, period, signal_period),
                    );
                    if let Ok(const_output) = rsmk_with_kernel(&const_input, kernel) {
                        let check_start = (warmup + 10).min(main.len());
                        for i in check_start..main.len() {
                            if !const_output.indicator[i].is_nan() {
                                prop_assert!(
                                    const_output.indicator[i].abs() < 1e-6,
                                    "For constant ratio, indicator should be near 0 at index {}, got {}",
                                    i, const_output.indicator[i]
                                );
                            }
                        }
                    }

                    if period == 1 {
                        prop_assert!(output.indicator.iter().any(|&x| !x.is_nan()));
                    }

                    // The requested kernel must agree with the scalar reference.
                    for i in warmup..main.len() {
                        let ind = output.indicator[i];
                        let ref_ind = ref_output.indicator[i];
                        if !ind.is_nan() && !ref_ind.is_nan() {
                            let ulp_diff = ind.to_bits().abs_diff(ref_ind.to_bits());
                            prop_assert!(
                                (ind - ref_ind).abs() <= 1e-9 || ulp_diff <= 10,
                                "Indicator mismatch at index {}: {} vs {} (ULP={})",
                                i,
                                ind,
                                ref_ind,
                                ulp_diff
                            );
                        } else {
                            prop_assert_eq!(ind.is_nan(), ref_ind.is_nan());
                        }

                        let sig = output.signal[i];
                        let ref_sig = ref_output.signal[i];
                        if !sig.is_nan() && !ref_sig.is_nan() {
                            let ulp_diff = sig.to_bits().abs_diff(ref_sig.to_bits());
                            prop_assert!(
                                (sig - ref_sig).abs() <= 1e-9 || ulp_diff <= 10,
                                "Signal mismatch at index {}: {} vs {} (ULP={})",
                                i,
                                sig,
                                ref_sig,
                                ulp_diff
                            );
                        } else {
                            prop_assert_eq!(sig.is_nan(), ref_sig.is_nan());
                        }
                    }

                    // The signal (a smoothed indicator) should not be rougher than the indicator.
                    let indicator_diffs = abs_steps(&output.indicator);
                    let signal_diffs = abs_steps(&output.signal);
                    if !signal_diffs.is_empty() && indicator_diffs.len() > 10 {
                        let ind_var = population_variance(&indicator_diffs);
                        let sig_var = population_variance(&signal_diffs);

                        if signal_period > 1 && ind_var > 1e-12 {
                            prop_assert!(
                                sig_var <= ind_var * 1.2 || sig_var < 1e-10,
                                "Signal should be smoother than indicator: sig_var={} ind_var={}",
                                sig_var,
                                ind_var
                            );
                        }
                    }

                    // A zero in the compare series must not change the output length.
                    let mut compare_with_zero = compare.clone();
                    if compare_with_zero.len() > lookback + 5 {
                        compare_with_zero[lookback + 2] = 0.0;
                        let zero_input = RsmkInput::from_slices(
                            &main,
                            &compare_with_zero,
                            ema_params(lookback, period, signal_period),
                        );
                        if let Ok(zero_output) = rsmk_with_kernel(&zero_input, kernel) {
                            prop_assert!(
                                zero_output.indicator.len() == main.len(),
                                "Output length should match input length even with zeros in compare"
                            );
                        }
                    }

                    // A large constant ratio should stay finite and near zero.
                    if main.len() > warmup + 10 {
                        let large_ratio = 10000.0;
                        let main_large: Vec<f64> =
                            compare.iter().map(|&x| x * large_ratio).collect();
                        let large_input = RsmkInput::from_slices(
                            &main_large,
                            &compare,
                            ema_params(lookback, period, signal_period),
                        );
                        if let Ok(large_output) = rsmk_with_kernel(&large_input, kernel) {
                            for i in warmup..main.len() {
                                if !large_output.indicator[i].is_nan() {
                                    prop_assert!(
                                        large_output.indicator[i].is_finite(),
                                        "Indicator should be finite for large ratios at index {}, got {}",
                                        i, large_output.indicator[i]
                                    );

                                    if i > warmup + 10 {
                                        prop_assert!(
                                            large_output.indicator[i].abs() < 1.0,
                                            "Large constant ratio should still have near-zero momentum at index {}, got {}",
                                            i, large_output.indicator[i]
                                        );
                                    }
                                }
                            }
                        }
                    }

                    Ok(())
                },
            )
            .unwrap();

        Ok(())
    }

    // Each generated test unwraps the checker's `Result` so that errors propagated with `?`
    // (missing fixture, unexpected computation failure) fail the test instead of passing
    // silently. The SIMD `cfg` is placed on every generated fn; an attribute written before
    // a `$(...)*` repetition would only attach to the first emitted item.
    macro_rules! generate_all_rsmk_tests {
        ($($test_fn:ident),*) => {
            paste::paste! {
                $(
                    #[test]
                    fn [<$test_fn _scalar_f64>]() {
                        $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar).unwrap();
                    }
                )*
                $(
                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx2_f64>]() {
                        $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2).unwrap();
                    }
                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx512_f64>]() {
                        $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512).unwrap();
                    }
                )*
            }
        }
    }

    generate_all_rsmk_tests!(
        check_rsmk_partial_params,
        check_rsmk_accuracy,
        check_rsmk_default_candles,
        check_rsmk_zero_period,
        check_rsmk_very_small_dataset,
        check_rsmk_all_nan,
        check_rsmk_not_enough_valid_data,
        check_rsmk_ma_error,
        check_rsmk_no_poison
    );

    #[cfg(feature = "proptest")]
    generate_all_rsmk_tests!(check_rsmk_property);

    /// The default builder sweep must contain the default parameter row, and that row must
    /// match the single-series expectations (near-zero tail, NaN warmup prefix).
    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);

        let candles = read_candles_from_csv(CANDLES_CSV)?;
        let main = &candles.close;
        let compare = &candles.close;

        let batch = RsmkBatchBuilder::new()
            .kernel(kernel)
            .apply_slices(main, compare)?;

        let def = RsmkParams::default();

        let default_row = batch
            .combos
            .iter()
            .position(|c| {
                c.lookback.unwrap_or(90) == def.lookback.unwrap_or(90)
                    && c.period.unwrap_or(3) == def.period.unwrap_or(3)
                    && c.signal_period.unwrap_or(20) == def.signal_period.unwrap_or(20)
            })
            .expect("default row missing");

        // Batch outputs are flat row-major matrices with `cols` values per row.
        let start = default_row * batch.cols;
        let ind_row = &batch.indicator[start..start + batch.cols];
        let sig_row = &batch.signal[start..start + batch.cols];

        assert_eq!(ind_row.len(), candles.close.len());
        assert_eq!(sig_row.len(), candles.close.len());

        let expected = [0.0, 0.0, 0.0, 0.0, 0.0];
        let start_idx = ind_row.len() - 5;

        for (i, &v) in ind_row[start_idx..].iter().enumerate() {
            assert!(
                (v - expected[i]).abs() < 1e-1,
                "[{test}] default-indicator mismatch at idx {i}: {v} vs {expected:?}"
            );
        }
        for (i, &v) in sig_row[start_idx..].iter().enumerate() {
            assert!(
                (v - expected[i]).abs() < 1e-1,
                "[{test}] default-signal mismatch at idx {i}: {v} vs {expected:?}"
            );
        }

        let max_period = def
            .lookback
            .unwrap()
            .max(def.period.unwrap())
            .max(def.signal_period.unwrap());
        for i in 0..max_period {
            if i < ind_row.len() {
                assert!(
                    ind_row[i].is_nan(),
                    "Expected indicator NaN at index {i}, got {}",
                    ind_row[i]
                );
            }
            if i < sig_row.len() {
                assert!(
                    sig_row[i].is_nan(),
                    "Expected signal NaN at index {i}, got {}",
                    sig_row[i]
                );
            }
        }
        Ok(())
    }

    /// Debug builds only: no batch output cell may carry an allocation-helper poison pattern.
    #[cfg(debug_assertions)]
    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);

        let candles = read_candles_from_csv(CANDLES_CSV)?;
        let main = &candles.close;
        let compare = &candles.close;

        let test_configs = vec![
            ((10, 20, 5), (2, 5, 1), (5, 10, 5)),
            ((50, 100, 25), (5, 15, 5), (10, 30, 10)),
            ((100, 200, 50), (20, 40, 10), (30, 60, 15)),
            ((1, 5, 1), (1, 3, 1), (1, 5, 1)),
            ((90, 90, 0), (3, 3, 0), (20, 20, 0)),
            ((200, 250, 25), (50, 70, 10), (50, 100, 25)),
        ];

        for (cfg_idx, &(lookback_range, period_range, signal_period_range)) in
            test_configs.iter().enumerate()
        {
            let output = RsmkBatchBuilder::new()
                .kernel(kernel)
                .lookback_range(lookback_range.0, lookback_range.1, lookback_range.2)
                .period_range(period_range.0, period_range.1, period_range.2)
                .signal_period_range(
                    signal_period_range.0,
                    signal_period_range.1,
                    signal_period_range.2,
                )
                .apply_slices(main, compare)?;

            for (series, values) in [("indicator", &output.indicator), ("signal", &output.signal)]
            {
                for (idx, &val) in values.iter().enumerate() {
                    if val.is_nan() {
                        continue;
                    }
                    let bits = val.to_bits();
                    if let Some(source) = poison_source(bits) {
                        let row = idx / output.cols;
                        let col = idx % output.cols;
                        let combo = &output.combos[row];
                        panic!(
                            "[{}] Config {}: Found {} poison value {} (0x{:016X}) in {} \
                             at row {} col {} (flat index {}) with params: lookback={}, period={}, signal_period={}, \
                             matype={}, signal_matype={}",
                            test, cfg_idx, source, val, bits, series, row, col, idx,
                            combo.lookback.unwrap_or(90),
                            combo.period.unwrap_or(3),
                            combo.signal_period.unwrap_or(20),
                            combo.matype.as_deref().unwrap_or("ema"),
                            combo.signal_matype.as_deref().unwrap_or("ema")
                        );
                    }
                }
            }
        }

        Ok(())
    }

    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    // Generated batch tests unwrap the checker's `Result` so propagated errors fail the test.
    macro_rules! gen_batch_tests {
        ($fn_name:ident) => {
            paste::paste! {
                #[test] fn [<$fn_name _scalar>]() {
                    $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch).unwrap();
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx2>]() {
                    $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch).unwrap();
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx512>]() {
                    $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch).unwrap();
                }
                #[test] fn [<$fn_name _auto_detect>]() {
                    $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto).unwrap();
                }
            }
        };
    }
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
}
3492
3493#[inline]
3494fn rsmk_classic_sma(
3495 mom: &[f64],
3496 period: usize,
3497 signal_period: usize,
3498 first_valid: usize,
3499) -> Result<RsmkOutput, RsmkError> {
3500 let len = mom.len();
3501 let ind_warmup = first_valid + period - 1;
3502 let sig_warmup = ind_warmup + signal_period - 1;
3503
3504 let needed = period.max(signal_period);
3505 if len < first_valid || len - first_valid < needed {
3506 return Err(RsmkError::NotEnoughValidData {
3507 needed,
3508 valid: if len >= first_valid {
3509 len - first_valid
3510 } else {
3511 0
3512 },
3513 });
3514 }
3515
3516 let mut indicator = alloc_with_nan_prefix(len, ind_warmup);
3517 let mut signal = alloc_with_nan_prefix(len, sig_warmup);
3518
3519 let mut sum_ind = 0.0;
3520 let mut count_ind = 0;
3521
3522 for i in first_valid..(first_valid + period).min(len) {
3523 if !mom[i].is_nan() {
3524 sum_ind += mom[i];
3525 count_ind += 1;
3526 }
3527 }
3528
3529 if count_ind > 0 && ind_warmup < len {
3530 indicator[ind_warmup] = (sum_ind / count_ind as f64) * 100.0;
3531
3532 for i in (ind_warmup + 1)..len {
3533 let old_val = mom[i - period];
3534 let new_val = mom[i];
3535 if !old_val.is_nan() {
3536 sum_ind -= old_val;
3537 count_ind -= 1;
3538 }
3539 if !new_val.is_nan() {
3540 sum_ind += new_val;
3541 count_ind += 1;
3542 }
3543 indicator[i] = if count_ind > 0 {
3544 (sum_ind / count_ind as f64) * 100.0
3545 } else {
3546 f64::NAN
3547 };
3548 }
3549 }
3550
3551 let mut sum_sig = 0.0;
3552 let mut count_sig = 0;
3553
3554 for i in ind_warmup..(ind_warmup + signal_period).min(len) {
3555 if !indicator[i].is_nan() {
3556 sum_sig += indicator[i];
3557 count_sig += 1;
3558 }
3559 }
3560
3561 if count_sig > 0 && sig_warmup < len {
3562 signal[sig_warmup] = sum_sig / count_sig as f64;
3563
3564 for i in (sig_warmup + 1)..len {
3565 let old_val = indicator[i - signal_period];
3566 let new_val = indicator[i];
3567 if !old_val.is_nan() {
3568 sum_sig -= old_val;
3569 count_sig -= 1;
3570 }
3571 if !new_val.is_nan() {
3572 sum_sig += new_val;
3573 count_sig += 1;
3574 }
3575 signal[i] = if count_sig > 0 {
3576 sum_sig / count_sig as f64
3577 } else {
3578 f64::NAN
3579 };
3580 }
3581 }
3582
3583 Ok(RsmkOutput { indicator, signal })
3584}
3585
3586#[inline]
3587fn rsmk_classic_ema(
3588 mom: &[f64],
3589 period: usize,
3590 signal_period: usize,
3591 first_valid: usize,
3592) -> Result<RsmkOutput, RsmkError> {
3593 let len = mom.len();
3594 let ind_warmup = first_valid + period - 1;
3595 let sig_warmup = ind_warmup + signal_period - 1;
3596
3597 let needed = period.max(signal_period);
3598 if len < first_valid || len - first_valid < needed {
3599 return Err(RsmkError::NotEnoughValidData {
3600 needed,
3601 valid: if len >= first_valid {
3602 len - first_valid
3603 } else {
3604 0
3605 },
3606 });
3607 }
3608
3609 let mut indicator = alloc_with_nan_prefix(len, ind_warmup);
3610 let mut signal = alloc_with_nan_prefix(len, sig_warmup);
3611
3612 let alpha_ind = 2.0 / (period as f64 + 1.0);
3613 let one_minus_alpha_ind = 1.0 - alpha_ind;
3614
3615 let mut sum_ind = 0.0;
3616 let mut count_ind = 0;
3617 for i in first_valid..(first_valid + period).min(len) {
3618 if !mom[i].is_nan() {
3619 sum_ind += mom[i];
3620 count_ind += 1;
3621 }
3622 }
3623
3624 if count_ind > 0 && ind_warmup < len {
3625 let mut ema_ind = (sum_ind / count_ind as f64) * 100.0;
3626 indicator[ind_warmup] = ema_ind;
3627
3628 for i in (ind_warmup + 1)..len {
3629 if !mom[i].is_nan() {
3630 ema_ind = (alpha_ind * mom[i] * 100.0) + (one_minus_alpha_ind * ema_ind);
3631 }
3632 indicator[i] = ema_ind;
3633 }
3634 }
3635
3636 let alpha_sig = 2.0 / (signal_period as f64 + 1.0);
3637 let one_minus_alpha_sig = 1.0 - alpha_sig;
3638
3639 let mut sum_sig = 0.0;
3640 let mut count_sig = 0;
3641 for i in ind_warmup..(ind_warmup + signal_period).min(len) {
3642 if !indicator[i].is_nan() {
3643 sum_sig += indicator[i];
3644 count_sig += 1;
3645 }
3646 }
3647
3648 if count_sig > 0 && sig_warmup < len {
3649 let mut ema_sig = sum_sig / count_sig as f64;
3650 signal[sig_warmup] = ema_sig;
3651
3652 for i in (sig_warmup + 1)..len {
3653 if !indicator[i].is_nan() {
3654 ema_sig = (alpha_sig * indicator[i]) + (one_minus_alpha_sig * ema_sig);
3655 }
3656 signal[i] = ema_sig;
3657 }
3658 }
3659
3660 Ok(RsmkOutput { indicator, signal })
3661}
3662
3663#[inline]
3664fn rsmk_classic_ema_sma(
3665 mom: &[f64],
3666 period: usize,
3667 signal_period: usize,
3668 first_valid: usize,
3669) -> Result<RsmkOutput, RsmkError> {
3670 let len = mom.len();
3671 let ind_warmup = first_valid + period - 1;
3672 let sig_warmup = ind_warmup + signal_period - 1;
3673
3674 let needed = period.max(signal_period);
3675 if len < first_valid || len - first_valid < needed {
3676 return Err(RsmkError::NotEnoughValidData {
3677 needed,
3678 valid: if len >= first_valid {
3679 len - first_valid
3680 } else {
3681 0
3682 },
3683 });
3684 }
3685
3686 let mut indicator = alloc_with_nan_prefix(len, ind_warmup);
3687 let mut signal = alloc_with_nan_prefix(len, sig_warmup);
3688
3689 if ind_warmup < len {
3690 let mut sum = 0.0;
3691 let mut cnt = 0usize;
3692 let init_end = (first_valid + period).min(len);
3693 unsafe {
3694 for i in first_valid..init_end {
3695 let v = *mom.get_unchecked(i);
3696 if !v.is_nan() {
3697 sum += v;
3698 cnt += 1;
3699 }
3700 }
3701 }
3702
3703 if cnt > 0 {
3704 let alpha_ind = 2.0 / (period as f64 + 1.0);
3705 let mut ema_ind = (sum / cnt as f64) * 100.0;
3706
3707 let mut sum_sig = 0.0;
3708 let mut cnt_sig = 0usize;
3709
3710 unsafe {
3711 *indicator.get_unchecked_mut(ind_warmup) = ema_ind;
3712
3713 sum_sig += ema_ind;
3714 cnt_sig += 1;
3715
3716 for i in (ind_warmup + 1)..len {
3717 let mv = *mom.get_unchecked(i);
3718 if !mv.is_nan() {
3719 let src100 = mv * 100.0;
3720 ema_ind = (src100 - ema_ind).mul_add(alpha_ind, ema_ind);
3721 }
3722 *indicator.get_unchecked_mut(i) = ema_ind;
3723
3724 if !ema_ind.is_nan() {
3725 sum_sig += ema_ind;
3726 cnt_sig += 1;
3727 }
3728
3729 if i >= sig_warmup {
3730 let old_idx = i - signal_period;
3731 let old_ind = *indicator.get_unchecked(old_idx);
3732 if !old_ind.is_nan() {
3733 sum_sig -= old_ind;
3734 cnt_sig -= 1;
3735 }
3736
3737 *signal.get_unchecked_mut(i) = if cnt_sig > 0 {
3738 sum_sig / cnt_sig as f64
3739 } else {
3740 f64::NAN
3741 };
3742 }
3743 }
3744 }
3745 } else {
3746 for i in ind_warmup..len {
3747 indicator[i] = f64::NAN;
3748 }
3749 for i in sig_warmup..len {
3750 signal[i] = f64::NAN;
3751 }
3752 }
3753 }
3754
3755 Ok(RsmkOutput { indicator, signal })
3756}
3757
3758#[inline]
3759fn rsmk_classic_sma_ema(
3760 mom: &[f64],
3761 period: usize,
3762 signal_period: usize,
3763 first_valid: usize,
3764) -> Result<RsmkOutput, RsmkError> {
3765 let len = mom.len();
3766 let ind_warmup = first_valid + period - 1;
3767 let sig_warmup = ind_warmup + signal_period - 1;
3768
3769 let needed = period.max(signal_period);
3770 if len < first_valid || len - first_valid < needed {
3771 return Err(RsmkError::NotEnoughValidData {
3772 needed,
3773 valid: if len >= first_valid {
3774 len - first_valid
3775 } else {
3776 0
3777 },
3778 });
3779 }
3780
3781 let mut indicator = alloc_with_nan_prefix(len, ind_warmup);
3782 let mut signal = alloc_with_nan_prefix(len, sig_warmup);
3783
3784 let mut sum_ind = 0.0;
3785 let mut cnt_ind = 0usize;
3786
3787 let alpha_sig = 2.0 / (signal_period as f64 + 1.0);
3788 let mut acc_sig = 0.0;
3789 let mut cnt_sig = 0usize;
3790 let mut ema_sig = 0.0f64;
3791
3792 unsafe {
3793 for i in first_valid..len {
3794 let v_new = *mom.get_unchecked(i);
3795 if !v_new.is_nan() {
3796 sum_ind += v_new;
3797 cnt_ind += 1;
3798 }
3799
3800 if i >= first_valid + period {
3801 let v_old = *mom.get_unchecked(i - period);
3802 if !v_old.is_nan() {
3803 sum_ind -= v_old;
3804 cnt_ind -= 1;
3805 }
3806 }
3807
3808 if i >= ind_warmup {
3809 let ind_val = if cnt_ind > 0 {
3810 (sum_ind / cnt_ind as f64) * 100.0
3811 } else {
3812 f64::NAN
3813 };
3814 *indicator.get_unchecked_mut(i) = ind_val;
3815
3816 if i < sig_warmup {
3817 if !ind_val.is_nan() {
3818 acc_sig += ind_val;
3819 cnt_sig += 1;
3820 }
3821 } else if i == sig_warmup {
3822 ema_sig = if cnt_sig > 0 {
3823 acc_sig / cnt_sig as f64
3824 } else {
3825 f64::NAN
3826 };
3827 *signal.get_unchecked_mut(i) = ema_sig;
3828 } else {
3829 if !ind_val.is_nan() && !ema_sig.is_nan() {
3830 ema_sig = (ind_val - ema_sig).mul_add(alpha_sig, ema_sig);
3831 } else if !ind_val.is_nan() && ema_sig.is_nan() {
3832 ema_sig = ind_val;
3833 }
3834 *signal.get_unchecked_mut(i) = ema_sig;
3835 }
3836 }
3837 }
3838 }
3839
3840 Ok(RsmkOutput { indicator, signal })
3841}