1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::indicators::moving_averages::alma::DeviceArrayF32Py;
3#[cfg(feature = "python")]
4use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
5#[cfg(feature = "python")]
6use pyo3::exceptions::PyValueError;
7#[cfg(feature = "python")]
8use pyo3::prelude::*;
9#[cfg(feature = "python")]
10use pyo3::types::PyDict;
11
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use serde::{Deserialize, Serialize};
14#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
15use wasm_bindgen::prelude::*;
16
17use crate::utilities::data_loader::{source_type, Candles};
18use crate::utilities::enums::Kernel;
19use crate::utilities::helpers::{
20 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
21 make_uninit_matrix,
22};
23#[cfg(feature = "python")]
24use crate::utilities::kernel_validation::validate_kernel;
25
26#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
27use core::arch::x86_64::*;
28
29#[cfg(not(target_arch = "wasm32"))]
30use rayon::prelude::*;
31
32use std::convert::AsRef;
33use std::error::Error;
34use std::mem::MaybeUninit;
35use thiserror::Error;
36
37impl<'a> AsRef<[f64]> for NetMyrsiInput<'a> {
38 #[inline(always)]
39 fn as_ref(&self) -> &[f64] {
40 match &self.data {
41 NetMyrsiData::Slice(slice) => slice,
42 NetMyrsiData::Candles { candles, source } => source_type(candles, source),
43 }
44 }
45}
46
/// Where a `NetMyrsiInput` reads its prices from.
#[derive(Debug, Clone)]
pub enum NetMyrsiData<'a> {
    /// A named column of a candle set (for example `"close"`), resolved through `source_type`.
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A plain price slice, used as-is.
    Slice(&'a [f64]),
}

/// Result of a single-series NET MyRSI run.
#[derive(Debug, Clone)]
pub struct NetMyrsiOutput {
    /// One value per input element; `net_myrsi_with_kernel` allocates this with a NaN
    /// warmup prefix of `first + period - 1` entries.
    pub values: Vec<f64>,
}
60
/// Parameters for NET MyRSI.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct NetMyrsiParams {
    /// Lookback length, used both for the MyRSI window of price changes and for the
    /// window of MyRSI values that NET ranks. `None` falls back to 14.
    pub period: Option<usize>,
}

impl Default for NetMyrsiParams {
    /// `period = Some(14)`.
    fn default() -> Self {
        Self { period: Some(14) }
    }
}
75
76#[derive(Debug, Clone)]
77pub struct NetMyrsiInput<'a> {
78 pub data: NetMyrsiData<'a>,
79 pub params: NetMyrsiParams,
80}
81
82impl<'a> NetMyrsiInput<'a> {
83 #[inline]
84 pub fn from_candles(c: &'a Candles, s: &'a str, p: NetMyrsiParams) -> Self {
85 Self {
86 data: NetMyrsiData::Candles {
87 candles: c,
88 source: s,
89 },
90 params: p,
91 }
92 }
93
94 #[inline]
95 pub fn from_slice(sl: &'a [f64], p: NetMyrsiParams) -> Self {
96 Self {
97 data: NetMyrsiData::Slice(sl),
98 params: p,
99 }
100 }
101
102 #[inline]
103 pub fn with_default_candles(c: &'a Candles) -> Self {
104 Self::from_candles(c, "close", NetMyrsiParams::default())
105 }
106
107 #[inline]
108 pub fn get_period(&self) -> usize {
109 self.params.period.unwrap_or(14)
110 }
111}
112
113#[derive(Copy, Clone, Debug)]
114pub struct NetMyrsiBuilder {
115 period: Option<usize>,
116 kernel: Kernel,
117}
118
119impl Default for NetMyrsiBuilder {
120 fn default() -> Self {
121 Self {
122 period: None,
123 kernel: Kernel::Auto,
124 }
125 }
126}
127
128impl NetMyrsiBuilder {
129 #[inline(always)]
130 pub fn new() -> Self {
131 Self::default()
132 }
133
134 #[inline(always)]
135 pub fn period(mut self, val: usize) -> Self {
136 self.period = Some(val);
137 self
138 }
139
140 #[inline(always)]
141 pub fn kernel(mut self, k: Kernel) -> Self {
142 self.kernel = k;
143 self
144 }
145
146 #[inline(always)]
147 pub fn apply(self, c: &Candles) -> Result<NetMyrsiOutput, NetMyrsiError> {
148 let p = NetMyrsiParams {
149 period: self.period,
150 };
151 let i = NetMyrsiInput::from_candles(c, "close", p);
152 net_myrsi_with_kernel(&i, self.kernel)
153 }
154
155 #[inline(always)]
156 pub fn apply_slice(self, d: &[f64]) -> Result<NetMyrsiOutput, NetMyrsiError> {
157 let p = NetMyrsiParams {
158 period: self.period,
159 };
160 let i = NetMyrsiInput::from_slice(d, p);
161 net_myrsi_with_kernel(&i, self.kernel)
162 }
163
164 #[inline(always)]
165 pub fn into_stream(self) -> Result<NetMyrsiStream, NetMyrsiError> {
166 let p = NetMyrsiParams {
167 period: self.period,
168 };
169 NetMyrsiStream::try_new(p)
170 }
171}
172
/// Errors returned by the NET MyRSI entry points.
#[derive(Debug, Error)]
pub enum NetMyrsiError {
    /// The input series has no elements.
    #[error("net_myrsi: Input data slice is empty.")]
    EmptyInputData,

    /// Every input value is NaN, so there is no first valid index.
    #[error("net_myrsi: All values are NaN.")]
    AllValuesNaN,

    /// `period` is zero or longer than the input (the stream reports `data_len = 0`).
    #[error("net_myrsi: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },

    /// Fewer than `needed` (= `period + 1`) values remain after the leading NaNs.
    #[error("net_myrsi: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },

    /// A caller-supplied output buffer does not have the expected length.
    #[error("net_myrsi: Output length mismatch: expected {expected}, got {got}")]
    OutputLengthMismatch { expected: usize, got: usize },

    /// A period sweep could not be expanded, or a batch's `rows * cols` overflowed.
    #[error("net_myrsi: Invalid range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: String,
        end: String,
        step: String,
    },

    /// A non-batch kernel was passed to a batch entry point.
    #[error("net_myrsi: Invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(Kernel),
}
200
201#[inline(always)]
202fn net_myrsi_compute_into(
203 data: &[f64],
204 period: usize,
205 first: usize,
206 out: &mut [f64],
207 kernel: Kernel,
208) {
209 let k = match kernel {
210 Kernel::Scalar | Kernel::ScalarBatch => Kernel::Scalar,
211 Kernel::Avx2 | Kernel::Avx2Batch => Kernel::Avx2,
212 Kernel::Avx512 | Kernel::Avx512Batch => Kernel::Avx512,
213 Kernel::Auto => Kernel::Scalar,
214 };
215
216 unsafe {
217 match kernel {
218 Kernel::Scalar | Kernel::ScalarBatch => {
219 net_myrsi_kernel_scalar(data, period, first, out)
220 }
221 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
222 Kernel::Avx2 | Kernel::Avx2Batch => net_myrsi_kernel_avx2(data, period, first, out),
223 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
224 Kernel::Avx512 | Kernel::Avx512Batch => {
225 net_myrsi_kernel_avx512(data, period, first, out)
226 }
227 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
228 Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
229 net_myrsi_kernel_scalar(data, period, first, out)
230 }
231 Kernel::Auto => match k {
232 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
233 Kernel::Avx512 => net_myrsi_kernel_avx512(data, period, first, out),
234 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
235 Kernel::Avx2 => net_myrsi_kernel_avx2(data, period, first, out),
236 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
237 Kernel::Avx2 | Kernel::Avx512 => net_myrsi_kernel_scalar(data, period, first, out),
238 Kernel::Scalar => net_myrsi_kernel_scalar(data, period, first, out),
239 Kernel::Auto | Kernel::ScalarBatch | Kernel::Avx2Batch | Kernel::Avx512Batch => {
240 unreachable!()
241 }
242 },
243 }
244 }
245}
246
/// Direct (non-rolling) MyRSI.
///
/// For every `i >= first + period`, writes `(CU - CD) / (CU + CD)` into `out_myrsi[i]`.
/// CU/CD are the sums of the positive/negative one-bar changes in `data[i - period..=i]`,
/// and the result is 0.0 when both sums are zero. NaN changes count towards neither sum.
/// Entries before `first + period` are left untouched.
// Direct O(n * period) form of the MyRSI the rolling kernels compute; kept as a reference.
#[allow(dead_code)]
#[inline(always)]
fn compute_myrsi_from(data: &[f64], period: usize, first: usize, out_myrsi: &mut [f64]) {
    for i in (first + period)..data.len() {
        // Walk the changes newest-first so the floating-point summation order is fixed.
        let (cu, cd) = data[i - period..=i]
            .windows(2)
            .rev()
            .fold((0.0, 0.0), |(up, down), pair| {
                let diff = pair[1] - pair[0];
                if diff > 0.0 {
                    (up + diff, down)
                } else if diff < 0.0 {
                    (up, down + -diff)
                } else {
                    (up, down)
                }
            });
        let sum = cu + cd;
        out_myrsi[i] = if sum != 0.0 { (cu - cd) / sum } else { 0.0 };
    }
}
269
/// Direct (non-rolling) NET score over MyRSI values.
///
/// For every `idx >= first + period - 1`, ranks all pairs in `myrsi[idx + 1 - period..=idx]`.
/// Each pair whose older value is below the newer one adds 1, and each pair whose older
/// value is above the newer one subtracts 1. The total is divided by
/// `period * (period - 1) / 2`. Pairs involving NaN contribute nothing; `period == 1`
/// yields `0 / 0 = NaN`.
// Direct O(n * period^2) form of the NET score the rolling kernels compute; kept as a reference.
#[allow(dead_code)]
#[inline(always)]
fn compute_net_from(myrsi: &[f64], period: usize, first: usize, out: &mut [f64]) {
    let start = first + period - 1;
    let denom = (period * (period - 1)) as f64 / 2.0;

    for idx in start..myrsi.len() {
        let window = &myrsi[idx + 1 - period..=idx];
        let mut num = 0.0;
        for (pos, &older) in window.iter().enumerate() {
            for &newer in &window[pos + 1..] {
                let d = older - newer;
                if d > 0.0 {
                    num -= 1.0;
                } else if d < 0.0 {
                    num += 1.0;
                }
            }
        }
        out[idx] = num / denom;
    }
}
293
/// Scalar NET MyRSI kernel (rolling, O(period) per bar).
///
/// It writes `out[first + period - 1]` as 0.0, or NaN when `period == 1`. It also writes
/// every `out[i]` for `i >= first + period`. Earlier entries are left untouched so the
/// caller can pre-fill its NaN warmup prefix.
///
/// * MyRSI at bar `i` is `(CU - CD) / (CU + CD)` over the last `period` one-bar changes,
///   or 0.0 when both sums are zero. NaN changes are skipped.
/// * NET at bar `i` is the pair score `Σ sign(newer - older)` over the MyRSI values
///   currently held, divided by `period * (period - 1) / 2`. Until `period` MyRSI values
///   exist, only the pairs seen so far are scored.
///   With `period == 1` there are no pairs and the output is NaN.
#[inline(always)]
fn net_myrsi_kernel_scalar(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    let len = data.len();
    if len <= first + 1 {
        return;
    }

    // Rolling sums of positive (`cu`) and negative (`cd`, as a magnitude) one-bar changes,
    // plus a ring of the last `period` changes so the oldest can be evicted.
    let mut cu = 0.0f64;
    let mut cd = 0.0f64;
    let mut diffs = vec![0.0f64; period];
    let mut d_head = 0usize;
    let mut d_count = 0usize;

    // Ring of the last `period` MyRSI values and the running pair score over it.
    let mut myr = vec![0.0f64; period];
    let mut m_head = 0usize;
    let mut m_count = 0usize;
    // i64: the pair count period * (period - 1) / 2 exceeds i32::MAX for very long periods.
    let mut num: i64 = 0;
    let denom = (period * (period - 1)) as f64 * 0.5;

    let warm = first + period - 1;
    if warm < out.len() {
        out[warm] = if period > 1 { 0.0 } else { f64::NAN };
    }

    /// `count(v < s) - count(v > s)` over `slice`; NaNs count towards neither side.
    #[inline(always)]
    fn lt_minus_gt_scalar(slice: &[f64], s: f64) -> i64 {
        slice
            .iter()
            .map(|&v| i64::from(v < s) - i64::from(v > s))
            .sum()
    }

    for i in (first + 1)..len {
        let diff = data[i] - data[i - 1];

        // Branch instead of `mask * diff`: `0.0 * NaN` is NaN and a single NaN change
        // would otherwise poison `cu`/`cd` for the rest of the series.
        if diff > 0.0 {
            cu += diff;
        } else if diff < 0.0 {
            cd -= diff;
        }

        if d_count < period {
            d_count += 1;
        } else {
            // Evict the change leaving the window (NaN changes were never added).
            let old = diffs[d_head];
            if old > 0.0 {
                cu -= old;
            } else if old < 0.0 {
                cd += old;
            }
        }
        diffs[d_head] = diff;
        d_head = if d_head + 1 == period { 0 } else { d_head + 1 };

        // The first MyRSI value needs `period` changes.
        if d_count < period {
            continue;
        }

        let sum = cu + cd;
        let r = if sum != 0.0 { (cu - cd) / sum } else { 0.0 };

        if m_count < period {
            // Ring still filling: slots [..m_head] hold every older value.
            num += lt_minus_gt_scalar(&myr[..m_head], r);
            m_count += 1;
        } else {
            // Slot m_head holds the oldest value `z`, which leaves as `r` arrives.
            let z = myr[m_head];
            let (before, rest) = myr.split_at(m_head);
            let after = &rest[1..];
            // Remove z's pairs with the survivors (z is older than all of them) ...
            num += lt_minus_gt_scalar(after, z) + lt_minus_gt_scalar(before, z);
            // ... and add r's pairs with them (r is newer than all of them).
            num += lt_minus_gt_scalar(after, r) + lt_minus_gt_scalar(before, r);
        }
        myr[m_head] = r;
        m_head = if m_head + 1 == period { 0 } else { m_head + 1 };

        // period == 1 has no pairs (denom == 0): write NaN rather than leaving `out` stale.
        out[i] = if denom != 0.0 {
            num as f64 / denom
        } else {
            f64::NAN
        };
    }
}
410
/// AVX2 NET MyRSI kernel.
///
/// It uses the same rolling algorithm and output contract as `net_myrsi_kernel_scalar`,
/// but the pair-score counting uses 256-bit compares.
///
/// Output contract:
/// * `out[first + period - 1]` is 0.0, or NaN when `period == 1`.
/// * Every `out[i]` with `i >= first + period` is written, and is NaN when `period == 1`.
/// * NaN price changes are skipped rather than poisoning the rolling sums.
///
/// # Safety
/// The CPU must support the AVX instructions used by the intrinsics below.
/// `out.len()` must be at least `data.len()`.
// NOTE(review): there is no `#[target_feature]` here, so the intrinsics may not inline
// unless the crate is compiled with the feature enabled — confirm with a benchmark.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn net_myrsi_kernel_avx2(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    use core::arch::x86_64::*;
    let len = data.len();
    if len <= first + 1 {
        return;
    }

    // Rolling positive/negative change sums and the ring of changes for eviction.
    let mut cu = 0.0f64;
    let mut cd = 0.0f64;
    let mut diffs = vec![0.0f64; period];
    let mut d_head = 0usize;
    let mut d_count = 0usize;

    // Ring of recent MyRSI values and the running pair score (i64: avoids overflow for
    // very long periods).
    let mut myr = vec![0.0f64; period];
    let mut m_head = 0usize;
    let mut m_count = 0usize;
    let mut num: i64 = 0;
    let denom = (period * (period - 1)) as f64 * 0.5;

    let warm = first + period - 1;
    if warm < out.len() {
        out[warm] = if period > 1 { 0.0 } else { f64::NAN };
    }

    /// `count(v < s) - count(v > s)` over `len` values starting at `ptr`, four lanes at a
    /// time; the ordered-quiet predicates make NaNs count towards neither side.
    ///
    /// # Safety
    /// `ptr` must be valid for `len` consecutive `f64` reads and the CPU must support AVX.
    #[inline(always)]
    unsafe fn lt_minus_gt_avx2(ptr: *const f64, len: usize, s: f64) -> i32 {
        let mut lt: i32 = 0;
        let mut gt: i32 = 0;
        let mut p = ptr;
        let mut n = len;
        let vs = _mm256_set1_pd(s);

        while n >= 4 {
            let v = _mm256_loadu_pd(p);
            let mlt = _mm256_cmp_pd(v, vs, _CMP_LT_OQ);
            let mgt = _mm256_cmp_pd(v, vs, _CMP_GT_OQ);
            lt += (_mm256_movemask_pd(mlt) as u32).count_ones() as i32;
            gt += (_mm256_movemask_pd(mgt) as u32).count_ones() as i32;
            p = p.add(4);
            n -= 4;
        }
        let slice = core::slice::from_raw_parts(p, n);
        let mut i = 0usize;
        while i < slice.len() {
            let v = *slice.get_unchecked(i);
            lt += (v < s) as i32;
            gt += (v > s) as i32;
            i += 1;
        }
        lt - gt
    }

    /// Slice wrapper around `lt_minus_gt_avx2`, widened to `i64`.
    #[inline(always)]
    unsafe fn lt_minus_gt_avx2_slice(xs: &[f64], s: f64) -> i64 {
        i64::from(lt_minus_gt_avx2(xs.as_ptr(), xs.len(), s))
    }

    for i in (first + 1)..len {
        let diff = data[i] - data[i - 1];

        // Branch instead of `mask * diff` so a NaN change cannot poison the sums.
        if diff > 0.0 {
            cu += diff;
        } else if diff < 0.0 {
            cd -= diff;
        }

        if d_count < period {
            d_count += 1;
        } else {
            // SAFETY: d_head is always kept in 0..period == diffs.len().
            let old = *diffs.get_unchecked(d_head);
            if old > 0.0 {
                cu -= old;
            } else if old < 0.0 {
                cd += old;
            }
        }
        // SAFETY: d_head < period == diffs.len().
        *diffs.get_unchecked_mut(d_head) = diff;
        d_head += 1;
        if d_head == period {
            d_head = 0;
        }

        if d_count < period {
            continue;
        }

        let sum = cu + cd;
        let r = if sum != 0.0 { (cu - cd) / sum } else { 0.0 };

        if m_count < period {
            // Ring still filling: slots [..m_head] hold every older value.
            num += lt_minus_gt_avx2_slice(&myr[..m_head], r);
            m_count += 1;
        } else {
            // SAFETY: m_head < period == myr.len().
            let z = *myr.get_unchecked(m_head);
            let (before, rest) = myr.split_at(m_head);
            let after = &rest[1..];
            // Remove the oldest value's pairs, then add the newest value's pairs.
            num += lt_minus_gt_avx2_slice(after, z) + lt_minus_gt_avx2_slice(before, z);
            num += lt_minus_gt_avx2_slice(after, r) + lt_minus_gt_avx2_slice(before, r);
        }
        // SAFETY: m_head < period == myr.len().
        *myr.get_unchecked_mut(m_head) = r;
        m_head += 1;
        if m_head == period {
            m_head = 0;
        }

        // period == 1 has no pairs (denom == 0): write NaN rather than leaving `out` stale.
        out[i] = if denom != 0.0 {
            num as f64 / denom
        } else {
            f64::NAN
        };
    }
}
541
/// AVX-512 NET MyRSI kernel.
///
/// It uses the same rolling algorithm and output contract as `net_myrsi_kernel_scalar`.
/// The pair-score counting uses 512-bit compares with a 256-bit and then a scalar tail.
///
/// Output contract:
/// * `out[first + period - 1]` is 0.0, or NaN when `period == 1`.
/// * Every `out[i]` with `i >= first + period` is written, and is NaN when `period == 1`.
/// * NaN price changes are skipped rather than poisoning the rolling sums.
///
/// # Safety
/// The CPU must support the AVX-512F and AVX instructions used by the intrinsics below.
/// `out.len()` must be at least `data.len()`.
// NOTE(review): there is no `#[target_feature]` here, so the intrinsics may not inline
// unless the crate is compiled with the feature enabled — confirm with a benchmark.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn net_myrsi_kernel_avx512(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    use core::arch::x86_64::*;
    let len = data.len();
    if len <= first + 1 {
        return;
    }

    // Rolling positive/negative change sums and the ring of changes for eviction.
    let mut cu = 0.0f64;
    let mut cd = 0.0f64;
    let mut diffs = vec![0.0f64; period];
    let mut d_head = 0usize;
    let mut d_count = 0usize;

    // Ring of recent MyRSI values and the running pair score (i64: avoids overflow for
    // very long periods).
    let mut myr = vec![0.0f64; period];
    let mut m_head = 0usize;
    let mut m_count = 0usize;
    let mut num: i64 = 0;
    let denom = (period * (period - 1)) as f64 * 0.5;

    let warm = first + period - 1;
    if warm < out.len() {
        out[warm] = if period > 1 { 0.0 } else { f64::NAN };
    }

    /// `count(v < s) - count(v > s)` over `len` values starting at `ptr`; the
    /// ordered-quiet predicates make NaNs count towards neither side.
    ///
    /// # Safety
    /// `ptr` must be valid for `len` consecutive `f64` reads and the CPU must support
    /// AVX-512F and AVX.
    #[inline(always)]
    unsafe fn lt_minus_gt_avx512(ptr: *const f64, len: usize, s: f64) -> i32 {
        let mut lt: i32 = 0;
        let mut gt: i32 = 0;
        let mut p = ptr;
        let mut n = len;
        let vs = _mm512_set1_pd(s);

        while n >= 8 {
            let v = _mm512_loadu_pd(p);
            let mlt: __mmask8 = _mm512_cmp_pd_mask(v, vs, _CMP_LT_OQ);
            let mgt: __mmask8 = _mm512_cmp_pd_mask(v, vs, _CMP_GT_OQ);
            lt += (mlt as u32).count_ones() as i32;
            gt += (mgt as u32).count_ones() as i32;
            p = p.add(8);
            n -= 8;
        }
        if n >= 4 {
            let v = _mm256_loadu_pd(p);
            let vs2 = _mm256_set1_pd(s);
            let mlt = _mm256_cmp_pd(v, vs2, _CMP_LT_OQ);
            let mgt = _mm256_cmp_pd(v, vs2, _CMP_GT_OQ);
            lt += (_mm256_movemask_pd(mlt) as u32).count_ones() as i32;
            gt += (_mm256_movemask_pd(mgt) as u32).count_ones() as i32;
            p = p.add(4);
            n -= 4;
        }
        let slice = core::slice::from_raw_parts(p, n);
        let mut i = 0usize;
        while i < slice.len() {
            let v = *slice.get_unchecked(i);
            lt += (v < s) as i32;
            gt += (v > s) as i32;
            i += 1;
        }
        lt - gt
    }

    /// Slice wrapper around `lt_minus_gt_avx512`, widened to `i64`.
    #[inline(always)]
    unsafe fn lt_minus_gt_avx512_slice(xs: &[f64], s: f64) -> i64 {
        i64::from(lt_minus_gt_avx512(xs.as_ptr(), xs.len(), s))
    }

    for i in (first + 1)..len {
        let diff = data[i] - data[i - 1];

        // Branch instead of `mask * diff` so a NaN change cannot poison the sums.
        if diff > 0.0 {
            cu += diff;
        } else if diff < 0.0 {
            cd -= diff;
        }

        if d_count < period {
            d_count += 1;
        } else {
            // SAFETY: d_head is always kept in 0..period == diffs.len().
            let old = *diffs.get_unchecked(d_head);
            if old > 0.0 {
                cu -= old;
            } else if old < 0.0 {
                cd += old;
            }
        }
        // SAFETY: d_head < period == diffs.len().
        *diffs.get_unchecked_mut(d_head) = diff;
        d_head += 1;
        if d_head == period {
            d_head = 0;
        }

        if d_count < period {
            continue;
        }

        let sum = cu + cd;
        let r = if sum != 0.0 { (cu - cd) / sum } else { 0.0 };

        if m_count < period {
            // Ring still filling: slots [..m_head] hold every older value.
            num += lt_minus_gt_avx512_slice(&myr[..m_head], r);
            m_count += 1;
        } else {
            // SAFETY: m_head < period == myr.len().
            let z = *myr.get_unchecked(m_head);
            let (before, rest) = myr.split_at(m_head);
            let after = &rest[1..];
            // Remove the oldest value's pairs, then add the newest value's pairs.
            num += lt_minus_gt_avx512_slice(after, z) + lt_minus_gt_avx512_slice(before, z);
            num += lt_minus_gt_avx512_slice(after, r) + lt_minus_gt_avx512_slice(before, r);
        }
        // SAFETY: m_head < period == myr.len().
        *myr.get_unchecked_mut(m_head) = r;
        m_head += 1;
        if m_head == period {
            m_head = 0;
        }

        // period == 1 has no pairs (denom == 0): write NaN rather than leaving `out` stale.
        out[i] = if denom != 0.0 {
            num as f64 / denom
        } else {
            f64::NAN
        };
    }
}
682
683#[inline(always)]
684fn net_myrsi_prepare<'a>(
685 input: &'a NetMyrsiInput,
686 kernel: Kernel,
687) -> Result<(&'a [f64], usize, usize, Kernel), NetMyrsiError> {
688 let data: &[f64] = input.as_ref();
689 let len = data.len();
690
691 if len == 0 {
692 return Err(NetMyrsiError::EmptyInputData);
693 }
694
695 let first = data
696 .iter()
697 .position(|x| !x.is_nan())
698 .ok_or(NetMyrsiError::AllValuesNaN)?;
699 let period = input.get_period();
700
701 if period == 0 || period > len {
702 return Err(NetMyrsiError::InvalidPeriod {
703 period,
704 data_len: len,
705 });
706 }
707
708 if len - first < period + 1 {
709 return Err(NetMyrsiError::NotEnoughValidData {
710 needed: period + 1,
711 valid: len - first,
712 });
713 }
714
715 let chosen = if matches!(kernel, Kernel::Auto) {
716 detect_best_kernel()
717 } else {
718 kernel
719 };
720
721 Ok((data, period, first, chosen))
722}
723
724#[inline]
725pub fn net_myrsi(input: &NetMyrsiInput) -> Result<NetMyrsiOutput, NetMyrsiError> {
726 net_myrsi_with_kernel(input, Kernel::Auto)
727}
728
729pub fn net_myrsi_with_kernel(
730 input: &NetMyrsiInput,
731 kernel: Kernel,
732) -> Result<NetMyrsiOutput, NetMyrsiError> {
733 let (data, period, first, chosen) = net_myrsi_prepare(input, kernel)?;
734 let mut out = alloc_with_nan_prefix(data.len(), first + period - 1);
735 net_myrsi_compute_into(data, period, first, &mut out, chosen);
736 Ok(NetMyrsiOutput { values: out })
737}
738
739#[inline]
740pub fn net_myrsi_into_slice(
741 dst: &mut [f64],
742 input: &NetMyrsiInput,
743 kern: Kernel,
744) -> Result<(), NetMyrsiError> {
745 let (data, period, first, chosen) = net_myrsi_prepare(input, kern)?;
746
747 if dst.len() != data.len() {
748 return Err(NetMyrsiError::OutputLengthMismatch {
749 expected: data.len(),
750 got: dst.len(),
751 });
752 }
753
754 net_myrsi_compute_into(data, period, first, dst, chosen);
755
756 let warm = first + period - 1;
757 for v in &mut dst[..warm] {
758 *v = f64::NAN;
759 }
760
761 Ok(())
762}
763
764#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
765#[inline]
766pub fn net_myrsi_into(input: &NetMyrsiInput, out: &mut [f64]) -> Result<(), NetMyrsiError> {
767 let (data, period, first, chosen) = net_myrsi_prepare(input, Kernel::Auto)?;
768
769 if out.len() != data.len() {
770 return Err(NetMyrsiError::OutputLengthMismatch {
771 expected: data.len(),
772 got: out.len(),
773 });
774 }
775
776 let warm = first + period - 1;
777 let warm = warm.min(out.len());
778 for v in &mut out[..warm] {
779 *v = f64::from_bits(0x7ff8_0000_0000_0000);
780 }
781
782 net_myrsi_compute_into(data, period, first, out, chosen);
783 Ok(())
784}
785
786#[derive(Debug, Clone)]
787pub struct NetMyrsiStream {
788 period: usize,
789 price: Vec<f64>,
790 myrsi: Vec<f64>,
791 head: usize,
792 filled_prices: bool,
793 filled_myrsi: bool,
794}
795
796impl NetMyrsiStream {
797 pub fn try_new(params: NetMyrsiParams) -> Result<Self, NetMyrsiError> {
798 let period = params.period.unwrap_or(14);
799 if period == 0 {
800 return Err(NetMyrsiError::InvalidPeriod {
801 period,
802 data_len: 0,
803 });
804 }
805
806 Ok(Self {
807 period,
808 price: vec![f64::NAN; period + 1],
809 myrsi: vec![f64::NAN; period],
810 head: 0,
811 filled_prices: false,
812 filled_myrsi: false,
813 })
814 }
815
816 #[inline(always)]
817 pub fn update(&mut self, value: f64) -> Option<f64> {
818 self.price[self.head % (self.period + 1)] = value;
819 if !self.filled_prices && self.head + 1 >= self.period + 1 {
820 self.filled_prices = true;
821 }
822
823 if self.filled_prices {
824 let mut cu = 0.0;
825 let mut cd = 0.0;
826 for j in 0..self.period {
827 let newer = self.price[(self.head - j) % (self.period + 1)];
828 let older = self.price[(self.head - j - 1) % (self.period + 1)];
829 let diff = newer - older;
830 if diff > 0.0 {
831 cu += diff;
832 } else if diff < 0.0 {
833 cd += -diff;
834 }
835 }
836 let s = cu + cd;
837 let r = if s != 0.0 { (cu - cd) / s } else { 0.0 };
838 self.myrsi[self.head % self.period] = r;
839
840 if !self.filled_myrsi && self.head + 1 >= self.period * 2 {
841 self.filled_myrsi = true;
842 }
843 }
844
845 let out = if self.filled_myrsi {
846 let denom = (self.period * (self.period - 1)) as f64 / 2.0;
847 let mut num = 0.0;
848 for i in 1..self.period {
849 for k in 0..i {
850 let vi = self.myrsi[(self.head - i) % self.period];
851 let vk = self.myrsi[(self.head - k) % self.period];
852 let d = vi - vk;
853 if d > 0.0 {
854 num -= 1.0;
855 } else if d < 0.0 {
856 num += 1.0;
857 }
858 }
859 }
860 Some(num / denom)
861 } else {
862 None
863 };
864
865 self.head = self.head.wrapping_add(1);
866 out
867 }
868}
869
/// Period sweep for batch runs.
#[derive(Clone, Debug)]
pub struct NetMyrsiBatchRange {
    /// `(start, end, step)`, expanded inclusively by `expand_grid_period`; a step of 0
    /// (or `start == end`) yields only `start`.
    pub period: (usize, usize, usize),
}

impl Default for NetMyrsiBatchRange {
    /// Periods 14 through 263 in steps of 1.
    fn default() -> Self {
        Self {
            period: (14, 263, 1),
        }
    }
}
882
883#[derive(Clone, Debug, Default)]
884pub struct NetMyrsiBatchBuilder {
885 range: NetMyrsiBatchRange,
886 kernel: Kernel,
887}
888
889impl NetMyrsiBatchBuilder {
890 pub fn new() -> Self {
891 Self::default()
892 }
893
894 pub fn kernel(mut self, k: Kernel) -> Self {
895 self.kernel = k;
896 self
897 }
898
899 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
900 self.range.period = (start, end, step);
901 self
902 }
903
904 pub fn period_static(mut self, p: usize) -> Self {
905 self.range.period = (p, p, 0);
906 self
907 }
908
909 pub fn apply_slice(self, data: &[f64]) -> Result<NetMyrsiBatchOutput, NetMyrsiError> {
910 net_myrsi_batch_with_kernel(data, &self.range, self.kernel)
911 }
912
913 pub fn apply_candles(
914 self,
915 c: &Candles,
916 src: &str,
917 ) -> Result<NetMyrsiBatchOutput, NetMyrsiError> {
918 self.apply_slice(source_type(c, src))
919 }
920
921 pub fn with_default_candles(c: &Candles) -> Result<NetMyrsiBatchOutput, NetMyrsiError> {
922 NetMyrsiBatchBuilder::new()
923 .kernel(Kernel::Auto)
924 .apply_candles(c, "close")
925 }
926}
927
928#[derive(Clone, Debug)]
929pub struct NetMyrsiBatchOutput {
930 pub values: Vec<f64>,
931 pub combos: Vec<NetMyrsiParams>,
932 pub rows: usize,
933 pub cols: usize,
934}
935
936impl NetMyrsiBatchOutput {
937 pub fn row_for_params(&self, p: &NetMyrsiParams) -> Option<usize> {
938 self.combos
939 .iter()
940 .position(|c| c.period.unwrap_or(14) == p.period.unwrap_or(14))
941 }
942
943 pub fn values_for(&self, p: &NetMyrsiParams) -> Option<&[f64]> {
944 self.row_for_params(p).and_then(|r| {
945 let start = r.checked_mul(self.cols)?;
946 let end = start.checked_add(self.cols)?;
947 self.values.get(start..end)
948 })
949 }
950}
951
952#[inline(always)]
953fn expand_grid_period(
954 (start, end, step): (usize, usize, usize),
955) -> Result<Vec<usize>, NetMyrsiError> {
956 if step == 0 || start == end {
957 return Ok(vec![start]);
958 }
959
960 if start < end {
961 let mut v = Vec::new();
962 let mut x = start;
963 let st = step.max(1);
964 while x <= end {
965 v.push(x);
966 x = x
967 .checked_add(st)
968 .ok_or_else(|| NetMyrsiError::InvalidRange {
969 start: start.to_string(),
970 end: end.to_string(),
971 step: step.to_string(),
972 })?;
973 }
974 if v.is_empty() {
975 return Err(NetMyrsiError::InvalidRange {
976 start: start.to_string(),
977 end: end.to_string(),
978 step: step.to_string(),
979 });
980 }
981 return Ok(v);
982 }
983
984 let mut v = Vec::new();
985 let mut x = start as isize;
986 let end_i = end as isize;
987 let st = (step as isize).max(1);
988 while x >= end_i {
989 v.push(x as usize);
990 x -= st;
991 }
992 if v.is_empty() {
993 return Err(NetMyrsiError::InvalidRange {
994 start: start.to_string(),
995 end: end.to_string(),
996 step: step.to_string(),
997 });
998 }
999 Ok(v)
1000}
1001
1002pub fn net_myrsi_batch_with_kernel(
1003 data: &[f64],
1004 sweep: &NetMyrsiBatchRange,
1005 k: Kernel,
1006) -> Result<NetMyrsiBatchOutput, NetMyrsiError> {
1007 let kernel = match k {
1008 Kernel::Auto => detect_best_batch_kernel(),
1009 other if other.is_batch() => other,
1010 other => {
1011 return Err(NetMyrsiError::InvalidKernelForBatch(other));
1012 }
1013 };
1014
1015 let periods = expand_grid_period(sweep.period)?;
1016 let combos: Vec<NetMyrsiParams> = periods
1017 .into_iter()
1018 .map(|p| NetMyrsiParams { period: Some(p) })
1019 .collect();
1020 net_myrsi_batch_inner(data, &combos, kernel)
1021}
1022
1023#[inline(always)]
1024fn net_myrsi_batch_inner(
1025 data: &[f64],
1026 combos: &[NetMyrsiParams],
1027 kern: Kernel,
1028) -> Result<NetMyrsiBatchOutput, NetMyrsiError> {
1029 if data.is_empty() {
1030 return Err(NetMyrsiError::EmptyInputData);
1031 }
1032 let first = data
1033 .iter()
1034 .position(|x| !x.is_nan())
1035 .ok_or(NetMyrsiError::AllValuesNaN)?;
1036 let cols = data.len();
1037 let rows = combos.len();
1038
1039 let _ = rows
1040 .checked_mul(cols)
1041 .ok_or_else(|| NetMyrsiError::InvalidRange {
1042 start: rows.to_string(),
1043 end: cols.to_string(),
1044 step: "rows*cols".into(),
1045 })?;
1046
1047 let max_needed = combos
1048 .iter()
1049 .map(|c| c.period.unwrap_or(14) + 1)
1050 .max()
1051 .unwrap();
1052 if cols - first < max_needed {
1053 return Err(NetMyrsiError::NotEnoughValidData {
1054 needed: max_needed,
1055 valid: cols - first,
1056 });
1057 }
1058
1059 let mut buf_mu = make_uninit_matrix(rows, cols);
1060
1061 let warms: Vec<usize> = combos
1062 .iter()
1063 .map(|c| first + c.period.unwrap_or(14) - 1)
1064 .collect();
1065 init_matrix_prefixes(&mut buf_mu, cols, &warms);
1066
1067 let mut guard = core::mem::ManuallyDrop::new(buf_mu);
1068 let out: &mut [f64] =
1069 unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len()) };
1070
1071 for (row, slice) in out.chunks_mut(cols).enumerate() {
1072 let p = combos[row].period.unwrap_or(14);
1073
1074 net_myrsi_compute_into(data, p, first, slice, kern);
1075 }
1076
1077 let values = unsafe {
1078 Vec::from_raw_parts(
1079 guard.as_mut_ptr() as *mut f64,
1080 guard.len(),
1081 guard.capacity(),
1082 )
1083 };
1084
1085 Ok(NetMyrsiBatchOutput {
1086 values,
1087 combos: combos.to_vec(),
1088 rows,
1089 cols,
1090 })
1091}
1092
1093#[inline(always)]
1094fn net_myrsi_batch_inner_into(
1095 data: &[f64],
1096 combos: &[NetMyrsiParams],
1097 kern: Kernel,
1098 out: &mut [f64],
1099) -> Result<(), NetMyrsiError> {
1100 if data.is_empty() {
1101 return Err(NetMyrsiError::EmptyInputData);
1102 }
1103 let first = data
1104 .iter()
1105 .position(|x| !x.is_nan())
1106 .ok_or(NetMyrsiError::AllValuesNaN)?;
1107 let cols = data.len();
1108 let rows = combos.len();
1109
1110 let expected = rows
1111 .checked_mul(cols)
1112 .ok_or_else(|| NetMyrsiError::InvalidRange {
1113 start: rows.to_string(),
1114 end: cols.to_string(),
1115 step: "rows*cols".into(),
1116 })?;
1117 if out.len() != expected {
1118 return Err(NetMyrsiError::OutputLengthMismatch {
1119 expected,
1120 got: out.len(),
1121 });
1122 }
1123
1124 let max_needed = combos
1125 .iter()
1126 .map(|c| c.period.unwrap_or(14) + 1)
1127 .max()
1128 .unwrap();
1129 if cols - first < max_needed {
1130 return Err(NetMyrsiError::NotEnoughValidData {
1131 needed: max_needed,
1132 valid: cols - first,
1133 });
1134 }
1135
1136 let mut tmp_mu = make_uninit_matrix(1, cols);
1137
1138 for (row, dst) in out.chunks_mut(cols).enumerate() {
1139 let p = combos[row].period.unwrap_or(14);
1140 let warm = first + p - 1;
1141
1142 for i in 0..warm {
1143 dst[i] = f64::NAN;
1144 }
1145 init_matrix_prefixes(&mut tmp_mu, cols, &[warm]);
1146 let tmp: &mut [f64] =
1147 unsafe { core::slice::from_raw_parts_mut(tmp_mu.as_mut_ptr() as *mut f64, cols) };
1148 compute_myrsi_from(data, p, first, tmp);
1149 compute_net_from(tmp, p, first, dst);
1150 }
1151
1152 let _ = tmp_mu;
1153 Ok(())
1154}
1155
1156#[cfg(not(target_arch = "wasm32"))]
1157pub fn net_myrsi_batch(
1158 data: &[Vec<f64>],
1159 params: &NetMyrsiParams,
1160 kernel: Option<Kernel>,
1161) -> Result<Vec<Vec<f64>>, NetMyrsiError> {
1162 let kern = kernel.unwrap_or_else(detect_best_batch_kernel);
1163
1164 let results: Result<Vec<_>, _> = data
1165 .par_iter()
1166 .map(|series| {
1167 let input = NetMyrsiInput::from_slice(series, params.clone());
1168 net_myrsi_with_kernel(&input, kern).map(|o| o.values)
1169 })
1170 .collect();
1171
1172 results
1173}
1174
/// Python binding: NET MyRSI over a 1-D float64 array.
///
/// The computation runs with the GIL released. Validation errors are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction]
#[pyo3(name = "net_myrsi")]
pub fn net_myrsi_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let prices = data.as_slice()?;
    let kern = validate_kernel(kernel, false)?;
    let input = NetMyrsiInput::from_slice(
        prices,
        NetMyrsiParams {
            period: Some(period),
        },
    );

    let values = py
        .allow_threads(|| net_myrsi_with_kernel(&input, kern).map(|output| output.values))
        .map_err(|err| PyValueError::new_err(err.to_string()))?;
    Ok(values.into_pyarray(py))
}
1197
/// Python binding: period sweep over a 1-D float64 array.
///
/// Returns a dict with `"values"`, a `(rows, cols)` array with one row per period, and
/// `"periods"`, the periods as unsigned integers. Errors are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "net_myrsi_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn net_myrsi_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    let to_py_err = |e: NetMyrsiError| PyValueError::new_err(e.to_string());

    let prices = data.as_slice()?;
    let combos: Vec<NetMyrsiParams> = expand_grid_period(period_range)
        .map_err(to_py_err)?
        .into_iter()
        .map(|period| NetMyrsiParams {
            period: Some(period),
        })
        .collect();
    let (rows, cols) = (combos.len(), prices.len());

    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;
    // SAFETY: the array starts uninitialized. `net_myrsi_batch_inner_into` is expected to
    // write every element before it is exposed, and on error the array is simply dropped.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let out_slice = unsafe { out_arr.as_slice_mut()? };

    let kern = validate_kernel(kernel, true)?;
    py.allow_threads(|| {
        let resolved = if let Kernel::Auto = kern {
            detect_best_batch_kernel()
        } else {
            kern
        };
        net_myrsi_batch_inner_into(prices, &combos, resolved, out_slice)
    })
    .map_err(to_py_err)?;

    let periods: Vec<u64> = combos.iter().map(|c| c.period.unwrap() as u64).collect();
    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item("periods", periods.into_pyarray(py))?;
    Ok(dict)
}
1248
/// Python binding: runs the NET MyRSI period sweep on a CUDA device.
///
/// The result stays in device memory. The returned wrapper also holds the CUDA context
/// and device id it was created with. Raises `ValueError` when CUDA is unavailable,
/// when `CudaNetMyrsi::new` fails, or when the batch launch fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "net_myrsi_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, device_id=0))]
pub fn net_myrsi_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    data_f32: numpy::PyReadonlyArray1<'py, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    use crate::cuda::cuda_available;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let prices = data_f32.as_slice()?;
    let sweep = NetMyrsiBatchRange {
        period: period_range,
    };

    // The GIL is released while the device work runs.
    let (inner, ctx, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let engine = crate::cuda::CudaNetMyrsi::new(device_id)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        let context = engine.context_arc();
        let device = engine.device_id();
        let (device_values, _) = engine
            .net_myrsi_batch_dev(prices, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((device_values, context, device))
    })?;

    Ok(DeviceArrayF32Py {
        inner,
        _ctx: Some(ctx),
        device_id: Some(dev_id),
    })
}
1282
/// Python binding: runs NET MyRSI with a single `period` over many series on a CUDA device.
///
/// `data_tm_f32` is a 2-D `f32` array. Per the function name it is time-major; presumably
/// rows are time steps and columns are independent series (confirm against `CudaNetMyrsi`).
/// It must be contiguous, otherwise `as_slice` fails. The result stays in device memory.
/// Raises `ValueError` when CUDA is unavailable, when `CudaNetMyrsi::new` fails, or when
/// the kernel launch fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "net_myrsi_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, device_id=0))]
pub fn net_myrsi_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    use crate::cuda::cuda_available;
    use numpy::PyUntypedArrayMethods;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let flat = data_tm_f32.as_slice()?;
    let rows = data_tm_f32.shape()[0];
    let cols = data_tm_f32.shape()[1];
    let params = NetMyrsiParams {
        period: Some(period),
    };
    let (inner, ctx, dev_id) = py.allow_threads(|| {
        let cuda = crate::cuda::CudaNetMyrsi::new(device_id)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id();
        // Argument order (cols, rows) is preserved from the original call; presumably this
        // means (num_series, series_len). Confirm against the CUDA wrapper's signature.
        cuda.net_myrsi_many_series_one_param_time_major_dev(flat, cols, rows, &params)
            .map(|inner| (inner, ctx, dev_id))
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;
    Ok(DeviceArrayF32Py {
        inner,
        _ctx: Some(ctx),
        device_id: Some(dev_id),
    })
}
1318
#[cfg(feature = "python")]
/// Python-facing wrapper, exposed as the `NetMyrsiStream` class, around the Rust
/// `NetMyrsiStream` incremental calculator.
#[pyclass(name = "NetMyrsiStream")]
pub struct NetMyrsiStreamPy {
    /// Underlying Rust stream; the Python `update` method forwards to it.
    stream: NetMyrsiStream,
}
1324
#[cfg(feature = "python")]
#[pymethods]
impl NetMyrsiStreamPy {
    /// Builds a stream for the given `period`.
    ///
    /// Raises `ValueError` if `NetMyrsiStream::try_new` rejects the parameters.
    #[new]
    pub fn new(period: usize) -> PyResult<Self> {
        NetMyrsiStream::try_new(NetMyrsiParams {
            period: Some(period),
        })
        .map(|stream| Self { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Pushes one value into the stream and returns what the stream emits.
    ///
    /// Returns `None` when the underlying stream produces no value for this step.
    pub fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
1342
/// WASM binding: computes NET MyRSI over `data` and returns a newly allocated vector.
///
/// The output has the same length as `data`. It is pre-filled with NaN, then written by
/// `net_myrsi_into_slice` using the best single-series kernel. Errors are returned as JS
/// strings.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn net_myrsi_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let params = NetMyrsiParams {
        period: Some(period),
    };
    let input = NetMyrsiInput::from_slice(data, params);

    let mut values = vec![f64::NAN; data.len()];
    match net_myrsi_into_slice(&mut values, &input, detect_best_kernel()) {
        Ok(_) => Ok(values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1357
/// WASM helper: reserves room for `len` `f64` values and hands the raw pointer to JS.
///
/// The memory is uninitialized and intentionally leaked. Release it with
/// `net_myrsi_free(ptr, len)` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn net_myrsi_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop keeps the Vec's destructor from running, so the allocation outlives us.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1366
/// WASM helper: releases a buffer previously obtained from `net_myrsi_alloc(len)`.
///
/// A null `ptr` is ignored. `len` must be the value that was passed to `net_myrsi_alloc`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn net_myrsi_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` came from `net_myrsi_alloc(len)`, which leaked a `Vec<f64>` created by
    // `Vec::with_capacity(len)`, and `with_capacity` guarantees that exact capacity. The Vec
    // is rebuilt with length 0 because its contents may never have been initialized.
    // `from_raw_parts` requires the first `length` elements to be initialized. Dropping a
    // zero-length Vec only deallocates the buffer.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
1374
/// WASM binding: computes NET MyRSI from `len` values at `in_ptr` into `len` slots at `out_ptr`.
///
/// Both pointers must be non-null and valid for `len` `f64`s. When the caller passes the
/// same buffer for input and output, the result is first computed into a scratch vector
/// and then copied back, so the computation never reads overwritten input.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn net_myrsi_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer"));
    }

    unsafe {
        let source = std::slice::from_raw_parts(in_ptr, len);
        let input = NetMyrsiInput::from_slice(
            source,
            NetMyrsiParams {
                period: Some(period),
            },
        );
        let kernel = detect_best_kernel();

        if in_ptr != out_ptr {
            let dest = std::slice::from_raw_parts_mut(out_ptr, len);
            net_myrsi_into_slice(dest, &input, kernel)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        } else {
            // In-place request: compute into scratch first, then overwrite the shared buffer.
            let mut scratch = vec![0.0; len];
            net_myrsi_into_slice(&mut scratch, &input, kernel)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            std::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&scratch);
        }
    }
    Ok(())
}
1408
/// Configuration object accepted by the JS `net_myrsi_batch` binding, deserialized with serde.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct NetMyrsiBatchConfig {
    /// `(start, end, step)` sweep for the `period` parameter; copied into `NetMyrsiBatchRange`.
    pub period_range: (usize, usize, usize),
}
1414
/// Result object returned to JS by the `net_myrsi_batch` binding, serialized with serde.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct NetMyrsiBatchJsOutput {
    /// Flattened `rows * cols` output matrix. Presumably row-major, one row per entry in
    /// `combos`; confirm against the batch implementation.
    pub values: Vec<f64>,
    /// Parameter set used for each row.
    pub combos: Vec<NetMyrsiParams>,
    /// Number of rows (parameter combinations).
    pub rows: usize,
    /// Number of columns (length of the input series).
    pub cols: usize,
}
1423
/// WASM binding (`net_myrsi_batch`): evaluates NET MyRSI over `data` for every period in
/// `config.period_range` and returns a `NetMyrsiBatchJsOutput` object.
///
/// Returns a JS string error if `config` cannot be deserialized, if the batch computation
/// fails, or if the result cannot be serialized.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = net_myrsi_batch)]
pub fn net_myrsi_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let cfg: NetMyrsiBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;
    let sweep = NetMyrsiBatchRange {
        period: cfg.period_range,
    };
    // Batch entry points take a *batch* kernel. Resolve it the same way the Python batch
    // binding does, rather than passing the single-series `detect_best_kernel()`.
    let out = net_myrsi_batch_with_kernel(data, &sweep, detect_best_batch_kernel())
        .map_err(|e| JsValue::from_str(&e.to_string()))?;
    serde_wasm_bindgen::to_value(&NetMyrsiBatchJsOutput {
        values: out.values,
        combos: out.combos,
        rows: out.rows,
        cols: out.cols,
    })
    .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
1442
/// WASM binding: computes the NET MyRSI period sweep from `len` values at `in_ptr`
/// into a caller-provided `rows * len` buffer at `out_ptr`.
///
/// The sweep is `(period_start, period_end, period_step)`. `out_ptr` must be valid for
/// `rows * len` `f64`s, where `rows` is the number of expanded periods. Returns `rows`.
/// Returns a JS string error for null pointers, an invalid sweep, a size overflow, or a
/// computation failure.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn net_myrsi_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str(
            "null pointer passed to net_myrsi_batch_into",
        ));
    }
    unsafe {
        // If the caller reuses the input buffer as the output, snapshot the input first.
        // Otherwise, writing row 0 would overwrite data still needed by later rows, and the
        // shared and mutable slices would alias. This mirrors the in-place handling in
        // `net_myrsi_into`.
        let owned_input: Vec<f64>;
        let data: &[f64] = if in_ptr == out_ptr {
            owned_input = std::slice::from_raw_parts(in_ptr, len).to_vec();
            &owned_input
        } else {
            std::slice::from_raw_parts(in_ptr, len)
        };

        let periods = expand_grid_period((period_start, period_end, period_step))
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        let combos: Vec<NetMyrsiParams> = periods
            .into_iter()
            .map(|p| NetMyrsiParams { period: Some(p) })
            .collect();
        let rows = combos.len();
        let cols = len;

        // usize is 32-bit on wasm32, so guard the product, as the Python batch binding does.
        let total = rows
            .checked_mul(cols)
            .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;
        let out = std::slice::from_raw_parts_mut(out_ptr, total);

        // Batch kernels are resolved the same way as in the Python batch binding.
        net_myrsi_batch_inner_into(data, &combos, detect_best_batch_kernel(), out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        Ok(rows)
    }
}
1475
#[cfg(test)]
mod tests {
    use super::*;
    use crate::skip_if_unsupported;
    use crate::utilities::data_loader::read_candles_from_csv;
    #[cfg(feature = "proptest")]
    use proptest::prelude::*;
    use std::error::Error;

    /// 4h Bitfinex spot candle fixture shared by the candle-based tests below.
    const TEST_CSV: &str = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";

    /// A `period: None` parameter set must still compute and yield one output per bar.
    fn check_net_myrsi_partial_params(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(TEST_CSV)?;

        let default_params = NetMyrsiParams { period: None };
        let input = NetMyrsiInput::from_candles(&candles, "close", default_params);
        let output = net_myrsi_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());

        Ok(())
    }

    /// The last five default-parameter values must match the reference values within 1e-6.
    fn check_net_myrsi_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(TEST_CSV)?;

        let input = NetMyrsiInput::from_candles(&candles, "close", NetMyrsiParams::default());
        let result = net_myrsi_with_kernel(&input, kernel)?;

        let expected_last_five = [0.64835165, 0.49450549, 0.29670330, 0.07692308, -0.07692308];

        let start = result.values.len().saturating_sub(5);
        for (i, &val) in result.values[start..].iter().enumerate() {
            let diff = (val - expected_last_five[i]).abs();
            assert!(
                diff < 1e-6,
                "[{}] NET_MYRSI {:?} mismatch at idx {}: got {}, expected {}",
                test_name,
                kernel,
                i,
                val,
                expected_last_five[i]
            );
        }
        Ok(())
    }

    /// `with_default_candles` must select the `close` source and yield a full-length output.
    fn check_net_myrsi_default_candles(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(TEST_CSV)?;

        let input = NetMyrsiInput::with_default_candles(&candles);
        match input.data {
            NetMyrsiData::Candles { source, .. } => assert_eq!(source, "close"),
            _ => panic!("Expected NetMyrsiData::Candles"),
        }
        let output = net_myrsi_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());

        Ok(())
    }

    /// A zero period must be rejected.
    fn check_net_myrsi_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let input_data = [10.0, 20.0, 30.0];
        let params = NetMyrsiParams { period: Some(0) };
        let input = NetMyrsiInput::from_slice(&input_data, params);
        let res = net_myrsi_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] NET_MYRSI should fail with zero period",
            test_name
        );
        Ok(())
    }

    /// A period longer than the input must be rejected.
    fn check_net_myrsi_period_exceeds_length(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let data_small = [10.0, 20.0, 30.0];
        let params = NetMyrsiParams { period: Some(10) };
        let input = NetMyrsiInput::from_slice(&data_small, params);
        let res = net_myrsi_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] NET_MYRSI should fail with period exceeding length",
            test_name
        );
        Ok(())
    }

    /// A tiny but sufficient input must compute and keep the input length.
    fn check_net_myrsi_very_small_dataset(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let data_small = vec![10.0, 20.0, 30.0, 15.0, 25.0];
        let params = NetMyrsiParams { period: Some(3) };
        let input = NetMyrsiInput::from_slice(&data_small, params);
        let result = net_myrsi_with_kernel(&input, kernel)?;
        assert_eq!(result.values.len(), data_small.len());
        Ok(())
    }

    /// Empty input must be rejected.
    fn check_net_myrsi_empty_input(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let data: Vec<f64> = vec![];
        let params = NetMyrsiParams::default();
        let input = NetMyrsiInput::from_slice(&data, params);
        let res = net_myrsi_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] NET_MYRSI should fail with empty input",
            test_name
        );
        Ok(())
    }

    /// An all-NaN input must be rejected.
    fn check_net_myrsi_all_nan(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let data = vec![f64::NAN; 30];
        let params = NetMyrsiParams::default();
        let input = NetMyrsiInput::from_slice(&data, params);
        let res = net_myrsi_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] NET_MYRSI should fail with all NaN values",
            test_name
        );
        Ok(())
    }

    /// Feeding the indicator's own output back in must still produce some finite values.
    fn check_net_myrsi_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(TEST_CSV)?;

        let first_params = NetMyrsiParams { period: Some(14) };
        let first_input = NetMyrsiInput::from_candles(&candles, "close", first_params);
        let first_result = net_myrsi_with_kernel(&first_input, kernel)?;

        let second_params = NetMyrsiParams { period: Some(14) };
        let second_input = NetMyrsiInput::from_slice(&first_result.values, second_params);
        let second_result = net_myrsi_with_kernel(&second_input, kernel)?;

        assert_eq!(second_result.values.len(), first_result.values.len());

        let non_nan_count = second_result.values.iter().filter(|v| !v.is_nan()).count();
        assert!(
            non_nan_count > 0,
            "[{}] Second pass should produce non-NaN values",
            test_name
        );

        Ok(())
    }

    /// The first `first_valid + period - 1` outputs must remain NaN.
    fn check_net_myrsi_warmup_nans(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(TEST_CSV)?;
        let first = candles.close.iter().position(|x| !x.is_nan()).unwrap_or(0);
        let p = 14;
        let out = net_myrsi_with_kernel(
            &NetMyrsiInput::from_candles(&candles, "close", NetMyrsiParams { period: Some(p) }),
            kernel,
        )?
        .values;
        let warm = first + p - 1;
        assert!(
            out[..warm].iter().all(|v| v.is_nan()),
            "[{}] Warmup NaNs should not be overwritten",
            test_name
        );
        Ok(())
    }

    /// A single NaN in the middle of the input must not wipe out the entire output.
    fn check_net_myrsi_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let mut data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        for _ in 0..20 {
            data.push(data[data.len() - 1] + 1.0);
        }

        data[15] = f64::NAN;

        let params = NetMyrsiParams { period: Some(14) };
        let input = NetMyrsiInput::from_slice(&data, params);
        let result = net_myrsi_with_kernel(&input, kernel)?;

        assert_eq!(result.values.len(), data.len());

        let non_nan_count = result.values.iter().filter(|v| !v.is_nan()).count();
        assert!(
            non_nan_count > 0,
            "[{}] Should produce some non-NaN values despite NaN input",
            test_name
        );

        Ok(())
    }

    /// Streaming updates must agree with the batch result wherever both are finite.
    fn check_net_myrsi_streaming(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);

        let c = read_candles_from_csv(TEST_CSV)?;
        let period = 14;

        let batch = NetMyrsiInput::from_candles(
            &c,
            "close",
            NetMyrsiParams {
                period: Some(period),
            },
        );
        let b = net_myrsi_with_kernel(&batch, kernel)?.values;

        let mut s = NetMyrsiStream::try_new(NetMyrsiParams {
            period: Some(period),
        })?;
        let mut stream = Vec::with_capacity(c.close.len());
        for &px in &c.close {
            stream.push(s.update(px).unwrap_or(f64::NAN));
        }

        let first_valid = b.iter().position(|v| !v.is_nan()).unwrap();
        for i in first_valid..b.len() {
            if !b[i].is_nan() && !stream[i].is_nan() {
                assert!(
                    (b[i] - stream[i]).abs() < 1e-10,
                    "[{test}] idx {i}: {} vs {}",
                    b[i],
                    stream[i]
                );
            }
        }

        Ok(())
    }

    /// Debug builds only: no finite output may carry the `alloc_with_nan_prefix` poison pattern.
    #[cfg(debug_assertions)]
    fn check_net_myrsi_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);

        let candles = read_candles_from_csv(TEST_CSV)?;

        let test_params = vec![
            NetMyrsiParams::default(),
            NetMyrsiParams { period: Some(10) },
            NetMyrsiParams { period: Some(20) },
        ];

        for params in test_params {
            let input = NetMyrsiInput::from_candles(&candles, "close", params.clone());
            let output = net_myrsi_with_kernel(&input, kernel)?;

            for (i, &val) in output.values.iter().enumerate() {
                if val.is_nan() {
                    continue;
                }
                let bits = val.to_bits();
                assert_ne!(
                    bits, 0x11111111_11111111,
                    "[{}] alloc_with_nan_prefix poison at {} with params {:?}",
                    test_name, i, params
                );
            }
        }

        Ok(())
    }

    #[cfg(not(debug_assertions))]
    fn check_net_myrsi_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    /// Property test: every kernel must match the scalar reference (NaN layout and values).
    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_net_myrsi_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        let strat = (2usize..=50).prop_flat_map(|period| {
            (
                prop::collection::vec(
                    (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
                    period + 2..400,
                ),
                Just(period),
            )
        });

        proptest::test_runner::TestRunner::default()
            .run(&strat, |(data, period)| {
                let params = NetMyrsiParams {
                    period: Some(period),
                };
                let input = NetMyrsiInput::from_slice(&data, params);

                let NetMyrsiOutput { values: out } = net_myrsi_with_kernel(&input, kernel).unwrap();
                let NetMyrsiOutput { values: ref_out } =
                    net_myrsi_with_kernel(&input, Kernel::Scalar).unwrap();

                assert_eq!(out.len(), ref_out.len());
                assert_eq!(out.len(), data.len());

                let first_valid = ref_out
                    .iter()
                    .position(|v| !v.is_nan())
                    .unwrap_or(ref_out.len());

                for i in 0..first_valid {
                    assert!(out[i].is_nan(), "Expected NaN at index {}", i);
                }

                for i in first_valid..out.len() {
                    if ref_out[i].is_nan() {
                        assert!(out[i].is_nan(), "Expected NaN at index {}", i);
                    } else {
                        let diff = (out[i] - ref_out[i]).abs();
                        let rel_diff = diff / ref_out[i].abs().max(1e-10);
                        assert!(
                            rel_diff < 1e-10,
                            "Mismatch at {}: kernel={}, scalar={}, diff={}",
                            i,
                            out[i],
                            ref_out[i],
                            diff
                        );
                    }
                }

                Ok(())
            })
            .map_err(|e| format!("Property test failed: {}", e).into())
    }

    #[cfg(not(feature = "proptest"))]
    fn check_net_myrsi_property(
        _test_name: &str,
        _kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        Ok(())
    }

    // Generates `<name>_scalar`, `<name>_avx2` and `<name>_avx512` tests for each check.
    // The `#[cfg]` must sit INSIDE each repetition. An outer attribute placed before
    // `$( ... )*` attaches only to the first generated item, which would leave every
    // other SIMD test ungated.
    macro_rules! generate_all_net_myrsi_tests {
        ($($test_fn:ident),*) => {
            paste::paste! {
                $(
                    #[test]
                    fn [<$test_fn _scalar>]() -> Result<(), Box<dyn Error>> {
                        $test_fn(stringify!([<$test_fn _scalar>]), Kernel::Scalar)
                    }
                )*
                $(
                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx2>]() -> Result<(), Box<dyn Error>> {
                        $test_fn(stringify!([<$test_fn _avx2>]), Kernel::Avx2)
                    }
                )*
                $(
                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx512>]() -> Result<(), Box<dyn Error>> {
                        $test_fn(stringify!([<$test_fn _avx512>]), Kernel::Avx512)
                    }
                )*
            }
        };
    }

    generate_all_net_myrsi_tests!(
        check_net_myrsi_partial_params,
        check_net_myrsi_accuracy,
        check_net_myrsi_default_candles,
        check_net_myrsi_zero_period,
        check_net_myrsi_period_exceeds_length,
        check_net_myrsi_very_small_dataset,
        check_net_myrsi_empty_input,
        check_net_myrsi_all_nan,
        check_net_myrsi_reinput,
        check_net_myrsi_warmup_nans,
        check_net_myrsi_nan_handling,
        check_net_myrsi_streaming,
        check_net_myrsi_no_poison,
        check_net_myrsi_property
    );

    /// A sweep covering the default period must expose a full-length row for it.
    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);
        let c = read_candles_from_csv(TEST_CSV)?;
        let out = NetMyrsiBatchBuilder::new()
            .kernel(kernel)
            .period_range(10, 20, 1)
            .apply_candles(&c, "close")?;
        let def = NetMyrsiParams::default();
        let row = out.values_for(&def).expect("default row missing");
        assert_eq!(row.len(), c.close.len());
        Ok(())
    }

    /// Debug builds only: no finite batch output may carry any allocator poison pattern.
    #[cfg(debug_assertions)]
    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);
        let c = read_candles_from_csv(TEST_CSV)?;
        let out = NetMyrsiBatchBuilder::new()
            .kernel(kernel)
            .period_range(10, 20, 5)
            .apply_candles(&c, "close")?;
        for (idx, &v) in out.values.iter().enumerate() {
            if v.is_nan() {
                continue;
            }
            let bits = v.to_bits();
            assert_ne!(
                bits, 0x1111_1111_1111_1111,
                "[{test}] alloc_with_nan_prefix poison at {idx}"
            );
            assert_ne!(
                bits, 0x2222_2222_2222_2222,
                "[{test}] init_matrix_prefixes poison at {idx}"
            );
            assert_ne!(
                bits, 0x3333_3333_3333_3333,
                "[{test}] make_uninit_matrix poison at {idx}"
            );
        }
        Ok(())
    }

    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    /// A (10, 20, 5) sweep must yield exactly the periods 10, 15 and 20 with matching dims.
    fn check_batch_sweep(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);
        let c = read_candles_from_csv(TEST_CSV)?;
        let out = NetMyrsiBatchBuilder::new()
            .kernel(kernel)
            .period_range(10, 20, 5)
            .apply_candles(&c, "close")?;
        assert_eq!(out.rows, 3);
        assert_eq!(out.cols, c.close.len());
        assert_eq!(out.values.len(), out.rows * out.cols);

        assert_eq!(out.combos.len(), 3);
        assert_eq!(out.combos[0].period, Some(10));
        assert_eq!(out.combos[1].period, Some(15));
        assert_eq!(out.combos[2].period, Some(20));
        Ok(())
    }

    // Generates batch-kernel variants of a batch check. Each test returns the check's
    // `Result`, so errors propagated with `?` inside the check fail the test instead of
    // being silently discarded.
    macro_rules! gen_batch_tests {
        ($fn_name:ident) => {
            paste::paste! {
                #[test]
                fn [<$fn_name _scalar>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch)
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$fn_name _avx2>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch)
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$fn_name _avx512>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch)
                }
                #[test]
                fn [<$fn_name _auto_detect>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto)
                }
            }
        };
    }

    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_sweep);
    gen_batch_tests!(check_batch_no_poison);

    /// The caller-buffer API must reproduce `net_myrsi` exactly (NaN positions included).
    #[test]
    fn test_net_myrsi_into_matches_api() -> Result<(), Box<dyn Error>> {
        let mut data = Vec::with_capacity(256);
        data.extend_from_slice(&[f64::NAN, f64::NAN, f64::NAN]);
        for i in 0..(256 - 3) {
            let x = i as f64;
            data.push((x.sin() + (0.1 * x).cos()) * 10.0);
        }

        let input = NetMyrsiInput::from_slice(&data, NetMyrsiParams::default());

        let baseline = net_myrsi(&input)?.values;

        let mut out = vec![0.0; data.len()];

        #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
        {
            net_myrsi_into(&input, &mut out)?;
        }
        #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
        {
            net_myrsi_into_slice(&mut out, &input, Kernel::Auto)?;
        }

        assert_eq!(baseline.len(), out.len());

        fn eq_or_both_nan(a: f64, b: f64) -> bool {
            (a.is_nan() && b.is_nan()) || (a == b)
        }

        for (i, (&a, &b)) in baseline.iter().zip(out.iter()).enumerate() {
            assert!(
                eq_or_both_nan(a, b),
                "value mismatch at {}: baseline={}, into={}",
                i,
                a,
                b
            );
        }
        Ok(())
    }
}