Skip to main content

vector_ta/indicators/
net_myrsi.rs

1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::indicators::moving_averages::alma::DeviceArrayF32Py;
3#[cfg(feature = "python")]
4use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
5#[cfg(feature = "python")]
6use pyo3::exceptions::PyValueError;
7#[cfg(feature = "python")]
8use pyo3::prelude::*;
9#[cfg(feature = "python")]
10use pyo3::types::PyDict;
11
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use serde::{Deserialize, Serialize};
14#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
15use wasm_bindgen::prelude::*;
16
17use crate::utilities::data_loader::{source_type, Candles};
18use crate::utilities::enums::Kernel;
19use crate::utilities::helpers::{
20    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
21    make_uninit_matrix,
22};
23#[cfg(feature = "python")]
24use crate::utilities::kernel_validation::validate_kernel;
25
26#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
27use core::arch::x86_64::*;
28
29#[cfg(not(target_arch = "wasm32"))]
30use rayon::prelude::*;
31
32use std::convert::AsRef;
33use std::error::Error;
34use std::mem::MaybeUninit;
35use thiserror::Error;
36
37impl<'a> AsRef<[f64]> for NetMyrsiInput<'a> {
38    #[inline(always)]
39    fn as_ref(&self) -> &[f64] {
40        match &self.data {
41            NetMyrsiData::Slice(slice) => slice,
42            NetMyrsiData::Candles { candles, source } => source_type(candles, source),
43        }
44    }
45}
46
47#[derive(Debug, Clone)]
48pub enum NetMyrsiData<'a> {
49    Candles {
50        candles: &'a Candles,
51        source: &'a str,
52    },
53    Slice(&'a [f64]),
54}
55
/// Result of a single-series computation.
#[derive(Clone, Debug)]
pub struct NetMyrsiOutput {
    /// One value per input element; the warm-up prefix is NaN.
    pub values: Vec<f64>,
}
60
/// Tunable parameters of the indicator.
#[derive(Clone, Debug)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct NetMyrsiParams {
    /// Look-back length; consumers in this file substitute 14 when it is `None`.
    pub period: Option<usize>,
}
69
70impl Default for NetMyrsiParams {
71    fn default() -> Self {
72        Self { period: Some(14) }
73    }
74}
75
76#[derive(Debug, Clone)]
77pub struct NetMyrsiInput<'a> {
78    pub data: NetMyrsiData<'a>,
79    pub params: NetMyrsiParams,
80}
81
82impl<'a> NetMyrsiInput<'a> {
83    #[inline]
84    pub fn from_candles(c: &'a Candles, s: &'a str, p: NetMyrsiParams) -> Self {
85        Self {
86            data: NetMyrsiData::Candles {
87                candles: c,
88                source: s,
89            },
90            params: p,
91        }
92    }
93
94    #[inline]
95    pub fn from_slice(sl: &'a [f64], p: NetMyrsiParams) -> Self {
96        Self {
97            data: NetMyrsiData::Slice(sl),
98            params: p,
99        }
100    }
101
102    #[inline]
103    pub fn with_default_candles(c: &'a Candles) -> Self {
104        Self::from_candles(c, "close", NetMyrsiParams::default())
105    }
106
107    #[inline]
108    pub fn get_period(&self) -> usize {
109        self.params.period.unwrap_or(14)
110    }
111}
112
113#[derive(Copy, Clone, Debug)]
114pub struct NetMyrsiBuilder {
115    period: Option<usize>,
116    kernel: Kernel,
117}
118
119impl Default for NetMyrsiBuilder {
120    fn default() -> Self {
121        Self {
122            period: None,
123            kernel: Kernel::Auto,
124        }
125    }
126}
127
128impl NetMyrsiBuilder {
129    #[inline(always)]
130    pub fn new() -> Self {
131        Self::default()
132    }
133
134    #[inline(always)]
135    pub fn period(mut self, val: usize) -> Self {
136        self.period = Some(val);
137        self
138    }
139
140    #[inline(always)]
141    pub fn kernel(mut self, k: Kernel) -> Self {
142        self.kernel = k;
143        self
144    }
145
146    #[inline(always)]
147    pub fn apply(self, c: &Candles) -> Result<NetMyrsiOutput, NetMyrsiError> {
148        let p = NetMyrsiParams {
149            period: self.period,
150        };
151        let i = NetMyrsiInput::from_candles(c, "close", p);
152        net_myrsi_with_kernel(&i, self.kernel)
153    }
154
155    #[inline(always)]
156    pub fn apply_slice(self, d: &[f64]) -> Result<NetMyrsiOutput, NetMyrsiError> {
157        let p = NetMyrsiParams {
158            period: self.period,
159        };
160        let i = NetMyrsiInput::from_slice(d, p);
161        net_myrsi_with_kernel(&i, self.kernel)
162    }
163
164    #[inline(always)]
165    pub fn into_stream(self) -> Result<NetMyrsiStream, NetMyrsiError> {
166        let p = NetMyrsiParams {
167            period: self.period,
168        };
169        NetMyrsiStream::try_new(p)
170    }
171}
172
173#[derive(Debug, Error)]
174pub enum NetMyrsiError {
175    #[error("net_myrsi: Input data slice is empty.")]
176    EmptyInputData,
177
178    #[error("net_myrsi: All values are NaN.")]
179    AllValuesNaN,
180
181    #[error("net_myrsi: Invalid period: period = {period}, data length = {data_len}")]
182    InvalidPeriod { period: usize, data_len: usize },
183
184    #[error("net_myrsi: Not enough valid data: needed = {needed}, valid = {valid}")]
185    NotEnoughValidData { needed: usize, valid: usize },
186
187    #[error("net_myrsi: Output length mismatch: expected {expected}, got {got}")]
188    OutputLengthMismatch { expected: usize, got: usize },
189
190    #[error("net_myrsi: Invalid range: start={start}, end={end}, step={step}")]
191    InvalidRange {
192        start: String,
193        end: String,
194        step: String,
195    },
196
197    #[error("net_myrsi: Invalid kernel for batch: {0:?}")]
198    InvalidKernelForBatch(Kernel),
199}
200
201#[inline(always)]
202fn net_myrsi_compute_into(
203    data: &[f64],
204    period: usize,
205    first: usize,
206    out: &mut [f64],
207    kernel: Kernel,
208) {
209    let k = match kernel {
210        Kernel::Scalar | Kernel::ScalarBatch => Kernel::Scalar,
211        Kernel::Avx2 | Kernel::Avx2Batch => Kernel::Avx2,
212        Kernel::Avx512 | Kernel::Avx512Batch => Kernel::Avx512,
213        Kernel::Auto => Kernel::Scalar,
214    };
215
216    unsafe {
217        match kernel {
218            Kernel::Scalar | Kernel::ScalarBatch => {
219                net_myrsi_kernel_scalar(data, period, first, out)
220            }
221            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
222            Kernel::Avx2 | Kernel::Avx2Batch => net_myrsi_kernel_avx2(data, period, first, out),
223            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
224            Kernel::Avx512 | Kernel::Avx512Batch => {
225                net_myrsi_kernel_avx512(data, period, first, out)
226            }
227            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
228            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
229                net_myrsi_kernel_scalar(data, period, first, out)
230            }
231            Kernel::Auto => match k {
232                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
233                Kernel::Avx512 => net_myrsi_kernel_avx512(data, period, first, out),
234                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
235                Kernel::Avx2 => net_myrsi_kernel_avx2(data, period, first, out),
236                #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
237                Kernel::Avx2 | Kernel::Avx512 => net_myrsi_kernel_scalar(data, period, first, out),
238                Kernel::Scalar => net_myrsi_kernel_scalar(data, period, first, out),
239                Kernel::Auto | Kernel::ScalarBatch | Kernel::Avx2Batch | Kernel::Avx512Batch => {
240                    unreachable!()
241                }
242            },
243        }
244    }
245}
246
/// Direct (non-rolling) MyRSI: for every index `i >= first + period`, sums the
/// up-moves and down-moves of the last `period` one-step changes ending at `i`
/// and stores `(up - down) / (up + down)` in `out_myrsi[i]` (0.0 when the sum
/// is zero). Indices below `first + period` are left untouched.
#[inline(always)]
fn compute_myrsi_from(data: &[f64], period: usize, first: usize, out_myrsi: &mut [f64]) {
    for end in (first + period)..data.len() {
        let window = &data[end - period..=end];
        let mut gains = 0.0;
        let mut losses = 0.0;

        // Walk adjacent pairs from the newest to the oldest so the
        // floating-point summation order stays newest-first.
        for pair in window.windows(2).rev() {
            let delta = pair[1] - pair[0];
            if delta > 0.0 {
                gains += delta;
            } else if delta < 0.0 {
                losses -= delta;
            }
        }

        let total = gains + losses;
        out_myrsi[end] = if total != 0.0 {
            (gains - losses) / total
        } else {
            0.0
        };
    }
}
269
/// Direct (non-rolling) noise-elimination pass over a MyRSI series.
///
/// For every index `idx >= first + period - 1`, compares each pair inside the
/// trailing window of `period` values: a pair whose older value is greater
/// than the newer one scores -1, the opposite ordering scores +1, ties score
/// 0. The score divided by `period * (period - 1) / 2` is stored in `out[idx]`.
#[inline(always)]
fn compute_net_from(myrsi: &[f64], period: usize, first: usize, out: &mut [f64]) {
    let begin = first + period - 1;
    let end = myrsi.len();
    let pair_count = (period * (period - 1)) as f64 / 2.0;

    for idx in begin..end {
        let mut score = 0.0;
        for older_lag in 1..period {
            let older = myrsi[idx - older_lag];
            for newer_lag in 0..older_lag {
                let gap = older - myrsi[idx - newer_lag];
                if gap > 0.0 {
                    score -= 1.0;
                } else if gap < 0.0 {
                    score += 1.0;
                }
            }
        }
        out[idx] = score / pair_count;
    }
}
293
/// Scalar rolling kernel for NET MyRSI.
///
/// Walks `data` from `first + 1`, maintaining:
/// * a ring of the last `period` one-step differences with running sums of
///   up-moves (`cu`) and down-moves (`cd`), giving MyRSI
///   `r = (cu - cd) / (cu + cd)` (0.0 when the sum is zero);
/// * a ring of the last `period` MyRSI values and an integer pair score `num`
///   that is updated incrementally instead of recomputing every pair.
///
/// `out[first + period - 1]` is set to 0.0 (NaN when `period == 1`). Every
/// index from `first + period` onward receives `num / (period * (period - 1) / 2)`.
/// When that denominator is zero (`period == 1`) NaN is written, so no output
/// slot is ever left holding stale or uninitialised memory.
///
/// `period` must be at least 1 (the public entry points reject 0 before
/// calling), and `out` is indexed with the same indices as `data`.
#[inline(always)]
fn net_myrsi_kernel_scalar(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    /// Number of entries below `pivot` minus the number of entries above it.
    #[inline(always)]
    fn lt_minus_gt_scalar(window: &[f64], pivot: f64) -> i32 {
        window
            .iter()
            .fold(0i32, |acc, &v| acc + (v < pivot) as i32 - (v > pivot) as i32)
    }

    let len = data.len();
    if len <= first + 1 {
        return;
    }

    let mut cu = 0.0f64;
    let mut cd = 0.0f64;
    let mut diffs = vec![0.0f64; period];
    let mut d_head = 0usize;
    let mut d_count = 0usize;

    let mut myr = vec![0.0f64; period];
    let mut m_head = 0usize;
    let mut m_count = 0usize;
    let mut num: i32 = 0;
    let denom = (period * (period - 1)) as f64 * 0.5;

    let warm = first + period - 1;
    if warm < out.len() {
        out[warm] = if period > 1 { 0.0 } else { f64::NAN };
    }

    for i in (first + 1)..len {
        let diff = data[i] - data[i - 1];

        // Kept as multiplications (not branches) so that NaN/inf differences
        // propagate through the running sums exactly as before.
        cu += ((diff > 0.0) as i32 as f64) * diff;
        cd += ((diff < 0.0) as i32 as f64) * (-diff);

        if d_count < period {
            d_count += 1;
        } else {
            // Ring is full: retire the difference that falls out of the window.
            let old = diffs[d_head];
            cu -= ((old > 0.0) as i32 as f64) * old;
            cd -= ((old < 0.0) as i32 as f64) * (-old);
        }
        diffs[d_head] = diff;
        d_head += 1;
        if d_head == period {
            d_head = 0;
        }

        if d_count < period {
            continue;
        }

        let sum = cu + cd;
        let r = if sum != 0.0 { (cu - cd) / sum } else { 0.0 };

        if m_count < period {
            // Still filling: the new value only forms pairs with earlier entries.
            num += lt_minus_gt_scalar(&myr[..m_head], r);
            m_count += 1;
        } else {
            // Slot `m_head` holds the oldest value; every other slot stays in the window.
            let evicted = myr[m_head];
            let after = &myr[m_head + 1..];
            let before = &myr[..m_head];
            // Cancel the pairs in which the evicted value was the older member...
            num += lt_minus_gt_scalar(after, evicted) + lt_minus_gt_scalar(before, evicted);
            // ...and add the pairs in which the new value is the newer member.
            num += lt_minus_gt_scalar(after, r) + lt_minus_gt_scalar(before, r);
        }
        myr[m_head] = r;
        m_head += 1;
        if m_head == period {
            m_head = 0;
        }

        // A zero denominator means there are no pairs to score (period == 1);
        // emit NaN rather than leaving the slot unwritten.
        out[i] = if denom != 0.0 {
            (num as f64) / denom
        } else {
            f64::NAN
        };
    }
}
410
/// AVX2 rolling kernel for NET MyRSI.
///
/// Same algorithm as the scalar kernel: a ring of the last `period` one-step
/// differences yields MyRSI, and a ring of the last `period` MyRSI values is
/// scored incrementally. The "less-than minus greater-than" counts are
/// computed four lanes at a time.
///
/// `out[first + period - 1]` is set to 0.0 (NaN when `period == 1`). Every
/// later index receives `num / (period * (period - 1) / 2)`, or NaN when that
/// denominator is zero (`period == 1`), so no slot is left unwritten.
///
/// # Safety
/// The CPU must support AVX2 (the body executes `_mm256_*` intrinsics).
/// `period` must be at least 1; the rings are indexed unchecked on that basis.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn net_myrsi_kernel_avx2(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    use core::arch::x86_64::*;
    if data.len() <= first + 1 {
        return;
    }

    let mut cu = 0.0f64;
    let mut cd = 0.0f64;
    let mut diffs = vec![0.0f64; period];
    let mut d_head = 0usize;
    let mut d_count = 0usize;

    let mut myr = vec![0.0f64; period];
    let mut m_head = 0usize;
    let mut m_count = 0usize;
    let mut num: i32 = 0;
    let denom = (period * (period - 1)) as f64 * 0.5;

    let warm = first + period - 1;
    if warm < out.len() {
        out[warm] = if period > 1 { 0.0 } else { f64::NAN };
    }

    /// Counts values below `s` minus values above `s` among the `len` doubles at `ptr`.
    ///
    /// Safety: `ptr` must be valid for reading `len` consecutive `f64`s.
    #[inline(always)]
    unsafe fn lt_minus_gt_avx2(ptr: *const f64, len: usize, s: f64) -> i32 {
        let mut lt: i32 = 0;
        let mut gt: i32 = 0;
        let mut p = ptr;
        let mut n = len;
        let vs = _mm256_set1_pd(s);

        while n >= 4 {
            let v = _mm256_loadu_pd(p);
            let mlt = _mm256_cmp_pd(v, vs, _CMP_LT_OQ);
            let mgt = _mm256_cmp_pd(v, vs, _CMP_GT_OQ);
            lt += (_mm256_movemask_pd(mlt) as u32).count_ones() as i32;
            gt += (_mm256_movemask_pd(mgt) as u32).count_ones() as i32;
            p = p.add(4);
            n -= 4;
        }
        // Scalar tail for the remaining (< 4) elements.
        let slice = core::slice::from_raw_parts(p, n);
        let mut i = 0usize;
        while i < slice.len() {
            let v = *slice.get_unchecked(i);
            lt += (v < s) as i32;
            gt += (v > s) as i32;
            i += 1;
        }
        lt - gt
    }

    /// Same count over two disjoint slices.
    #[inline(always)]
    unsafe fn lt_minus_gt_avx2_two(a: &[f64], b: &[f64], s: f64) -> i32 {
        lt_minus_gt_avx2(a.as_ptr(), a.len(), s) + lt_minus_gt_avx2(b.as_ptr(), b.len(), s)
    }

    let len = data.len();
    let mut i = first + 1;
    while i < len {
        let newer = data[i];
        let older = data[i - 1];
        let diff = newer - older;

        cu += ((diff > 0.0) as i32 as f64) * diff;
        cd += ((diff < 0.0) as i32 as f64) * (-diff);

        // `d_head` and `m_head` are always < period == ring length, which is
        // what makes the unchecked ring accesses below in-bounds.
        if d_count < period {
            *diffs.get_unchecked_mut(d_head) = diff;
            d_head += 1;
            if d_head == period {
                d_head = 0;
            }
            d_count += 1;
        } else {
            // Ring is full: retire the difference that leaves the window.
            let old = *diffs.get_unchecked(d_head);
            cu -= ((old > 0.0) as i32 as f64) * old;
            cd -= ((old < 0.0) as i32 as f64) * (-old);
            *diffs.get_unchecked_mut(d_head) = diff;
            d_head += 1;
            if d_head == period {
                d_head = 0;
            }
        }

        if d_count >= period {
            let sum = cu + cd;
            let r = if sum != 0.0 { (cu - cd) / sum } else { 0.0 };

            if m_count < period {
                // Still filling: pair the new value with the entries stored so far.
                let add = lt_minus_gt_avx2(myr.as_ptr(), m_head, r);
                num += add;
                *myr.get_unchecked_mut(m_head) = r;
                m_head += 1;
                if m_head == period {
                    m_head = 0;
                }
                m_count += 1;
            } else {
                // Slot `m_head` holds the oldest value; cancel its pairs first.
                let z = *myr.get_unchecked(m_head);
                let ad_rm = if m_head + 1 < period {
                    lt_minus_gt_avx2_two(&myr[m_head + 1..period], &myr[..m_head], z)
                } else {
                    lt_minus_gt_avx2(myr.as_ptr(), m_head, z)
                };
                num += ad_rm;

                // Then add the pairs formed by the incoming value.
                let ad_new = if m_head + 1 < period {
                    lt_minus_gt_avx2_two(&myr[m_head + 1..period], &myr[..m_head], r)
                } else {
                    lt_minus_gt_avx2(myr.as_ptr(), m_head, r)
                };
                num += ad_new;

                *myr.get_unchecked_mut(m_head) = r;
                m_head += 1;
                if m_head == period {
                    m_head = 0;
                }
            }

            // Zero denominator (period == 1): emit NaN instead of leaving the
            // slot holding whatever the buffer contained.
            out[i] = if denom != 0.0 {
                (num as f64) / denom
            } else {
                f64::NAN
            };
        }

        i += 1;
    }
}
541
/// AVX-512 rolling kernel for NET MyRSI.
///
/// Same algorithm as the scalar kernel: a ring of the last `period` one-step
/// differences yields MyRSI, and a ring of the last `period` MyRSI values is
/// scored incrementally. The "less-than minus greater-than" counts are
/// computed eight lanes at a time, with a four-lane step and a scalar tail
/// for the remainder.
///
/// `out[first + period - 1]` is set to 0.0 (NaN when `period == 1`). Every
/// later index receives `num / (period * (period - 1) / 2)`, or NaN when that
/// denominator is zero (`period == 1`), so no slot is left unwritten.
///
/// # Safety
/// The CPU must support the AVX-512 (`_mm512_*`) and AVX (`_mm256_*`)
/// intrinsics executed in the body. `period` must be at least 1; the rings are
/// indexed unchecked on that basis.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn net_myrsi_kernel_avx512(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    use core::arch::x86_64::*;
    if data.len() <= first + 1 {
        return;
    }

    let mut cu = 0.0f64;
    let mut cd = 0.0f64;
    let mut diffs = vec![0.0f64; period];
    let mut d_head = 0usize;
    let mut d_count = 0usize;

    let mut myr = vec![0.0f64; period];
    let mut m_head = 0usize;
    let mut m_count = 0usize;
    let mut num: i32 = 0;
    let denom = (period * (period - 1)) as f64 * 0.5;

    let warm = first + period - 1;
    if warm < out.len() {
        out[warm] = if period > 1 { 0.0 } else { f64::NAN };
    }

    /// Counts values below `s` minus values above `s` among the `len` doubles at `ptr`.
    ///
    /// Safety: `ptr` must be valid for reading `len` consecutive `f64`s.
    #[inline(always)]
    unsafe fn lt_minus_gt_avx512(ptr: *const f64, len: usize, s: f64) -> i32 {
        let mut lt: i32 = 0;
        let mut gt: i32 = 0;
        let mut p = ptr;
        let mut n = len;
        let vs = _mm512_set1_pd(s);

        while n >= 8 {
            let v = _mm512_loadu_pd(p);
            let mlt: __mmask8 = _mm512_cmp_pd_mask(v, vs, _CMP_LT_OQ);
            let mgt: __mmask8 = _mm512_cmp_pd_mask(v, vs, _CMP_GT_OQ);
            lt += (mlt as u32).count_ones() as i32;
            gt += (mgt as u32).count_ones() as i32;
            p = p.add(8);
            n -= 8;
        }
        // At most one four-lane step can remain after the eight-lane loop.
        if n >= 4 {
            let v = _mm256_loadu_pd(p);
            let vs2 = _mm256_set1_pd(s);
            let mlt = _mm256_cmp_pd(v, vs2, _CMP_LT_OQ);
            let mgt = _mm256_cmp_pd(v, vs2, _CMP_GT_OQ);
            lt += (_mm256_movemask_pd(mlt) as u32).count_ones() as i32;
            gt += (_mm256_movemask_pd(mgt) as u32).count_ones() as i32;
            p = p.add(4);
            n -= 4;
        }
        // Scalar tail for the remaining (< 4) elements.
        let slice = core::slice::from_raw_parts(p, n);
        let mut i = 0usize;
        while i < slice.len() {
            let v = *slice.get_unchecked(i);
            lt += (v < s) as i32;
            gt += (v > s) as i32;
            i += 1;
        }
        lt - gt
    }

    /// Same count over two disjoint slices.
    #[inline(always)]
    unsafe fn lt_minus_gt_avx512_two(a: &[f64], b: &[f64], s: f64) -> i32 {
        lt_minus_gt_avx512(a.as_ptr(), a.len(), s) + lt_minus_gt_avx512(b.as_ptr(), b.len(), s)
    }

    let len = data.len();
    let mut i = first + 1;
    while i < len {
        let newer = data[i];
        let older = data[i - 1];
        let diff = newer - older;

        cu += ((diff > 0.0) as i32 as f64) * diff;
        cd += ((diff < 0.0) as i32 as f64) * (-diff);

        // `d_head` and `m_head` are always < period == ring length, which is
        // what makes the unchecked ring accesses below in-bounds.
        if d_count < period {
            *diffs.get_unchecked_mut(d_head) = diff;
            d_head += 1;
            if d_head == period {
                d_head = 0;
            }
            d_count += 1;
        } else {
            // Ring is full: retire the difference that leaves the window.
            let old = *diffs.get_unchecked(d_head);
            cu -= ((old > 0.0) as i32 as f64) * old;
            cd -= ((old < 0.0) as i32 as f64) * (-old);
            *diffs.get_unchecked_mut(d_head) = diff;
            d_head += 1;
            if d_head == period {
                d_head = 0;
            }
        }

        if d_count >= period {
            let sum = cu + cd;
            let r = if sum != 0.0 { (cu - cd) / sum } else { 0.0 };

            if m_count < period {
                // Still filling: pair the new value with the entries stored so far.
                let add = lt_minus_gt_avx512(myr.as_ptr(), m_head, r);
                num += add;
                *myr.get_unchecked_mut(m_head) = r;
                m_head += 1;
                if m_head == period {
                    m_head = 0;
                }
                m_count += 1;
            } else {
                // Slot `m_head` holds the oldest value; cancel its pairs first.
                let z = *myr.get_unchecked(m_head);
                let ad_rm = if m_head + 1 < period {
                    lt_minus_gt_avx512_two(&myr[m_head + 1..period], &myr[..m_head], z)
                } else {
                    lt_minus_gt_avx512(myr.as_ptr(), m_head, z)
                };
                num += ad_rm;

                // Then add the pairs formed by the incoming value.
                let ad_new = if m_head + 1 < period {
                    lt_minus_gt_avx512_two(&myr[m_head + 1..period], &myr[..m_head], r)
                } else {
                    lt_minus_gt_avx512(myr.as_ptr(), m_head, r)
                };
                num += ad_new;

                *myr.get_unchecked_mut(m_head) = r;
                m_head += 1;
                if m_head == period {
                    m_head = 0;
                }
            }

            // Zero denominator (period == 1): emit NaN instead of leaving the
            // slot holding whatever the buffer contained.
            out[i] = if denom != 0.0 {
                (num as f64) / denom
            } else {
                f64::NAN
            };
        }

        i += 1;
    }
}
682
683#[inline(always)]
684fn net_myrsi_prepare<'a>(
685    input: &'a NetMyrsiInput,
686    kernel: Kernel,
687) -> Result<(&'a [f64], usize, usize, Kernel), NetMyrsiError> {
688    let data: &[f64] = input.as_ref();
689    let len = data.len();
690
691    if len == 0 {
692        return Err(NetMyrsiError::EmptyInputData);
693    }
694
695    let first = data
696        .iter()
697        .position(|x| !x.is_nan())
698        .ok_or(NetMyrsiError::AllValuesNaN)?;
699    let period = input.get_period();
700
701    if period == 0 || period > len {
702        return Err(NetMyrsiError::InvalidPeriod {
703            period,
704            data_len: len,
705        });
706    }
707
708    if len - first < period + 1 {
709        return Err(NetMyrsiError::NotEnoughValidData {
710            needed: period + 1,
711            valid: len - first,
712        });
713    }
714
715    let chosen = if matches!(kernel, Kernel::Auto) {
716        detect_best_kernel()
717    } else {
718        kernel
719    };
720
721    Ok((data, period, first, chosen))
722}
723
724#[inline]
725pub fn net_myrsi(input: &NetMyrsiInput) -> Result<NetMyrsiOutput, NetMyrsiError> {
726    net_myrsi_with_kernel(input, Kernel::Auto)
727}
728
729pub fn net_myrsi_with_kernel(
730    input: &NetMyrsiInput,
731    kernel: Kernel,
732) -> Result<NetMyrsiOutput, NetMyrsiError> {
733    let (data, period, first, chosen) = net_myrsi_prepare(input, kernel)?;
734    let mut out = alloc_with_nan_prefix(data.len(), first + period - 1);
735    net_myrsi_compute_into(data, period, first, &mut out, chosen);
736    Ok(NetMyrsiOutput { values: out })
737}
738
739#[inline]
740pub fn net_myrsi_into_slice(
741    dst: &mut [f64],
742    input: &NetMyrsiInput,
743    kern: Kernel,
744) -> Result<(), NetMyrsiError> {
745    let (data, period, first, chosen) = net_myrsi_prepare(input, kern)?;
746
747    if dst.len() != data.len() {
748        return Err(NetMyrsiError::OutputLengthMismatch {
749            expected: data.len(),
750            got: dst.len(),
751        });
752    }
753
754    net_myrsi_compute_into(data, period, first, dst, chosen);
755
756    let warm = first + period - 1;
757    for v in &mut dst[..warm] {
758        *v = f64::NAN;
759    }
760
761    Ok(())
762}
763
764#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
765#[inline]
766pub fn net_myrsi_into(input: &NetMyrsiInput, out: &mut [f64]) -> Result<(), NetMyrsiError> {
767    let (data, period, first, chosen) = net_myrsi_prepare(input, Kernel::Auto)?;
768
769    if out.len() != data.len() {
770        return Err(NetMyrsiError::OutputLengthMismatch {
771            expected: data.len(),
772            got: out.len(),
773        });
774    }
775
776    let warm = first + period - 1;
777    let warm = warm.min(out.len());
778    for v in &mut out[..warm] {
779        *v = f64::from_bits(0x7ff8_0000_0000_0000);
780    }
781
782    net_myrsi_compute_into(data, period, first, out, chosen);
783    Ok(())
784}
785
/// Incremental NET MyRSI calculator fed one value at a time.
#[derive(Clone, Debug)]
pub struct NetMyrsiStream {
    /// Look-back length.
    period: usize,
    /// Ring of the most recent `period + 1` input values.
    price: Vec<f64>,
    /// Ring of the most recent `period` MyRSI values.
    myrsi: Vec<f64>,
    /// Number of values consumed so far; ring slots are derived from it by modulo.
    head: usize,
    /// Set once `period + 1` inputs have been seen.
    filled_prices: bool,
    /// Set once `2 * period` inputs have been seen.
    filled_myrsi: bool,
}
795
796impl NetMyrsiStream {
797    pub fn try_new(params: NetMyrsiParams) -> Result<Self, NetMyrsiError> {
798        let period = params.period.unwrap_or(14);
799        if period == 0 {
800            return Err(NetMyrsiError::InvalidPeriod {
801                period,
802                data_len: 0,
803            });
804        }
805
806        Ok(Self {
807            period,
808            price: vec![f64::NAN; period + 1],
809            myrsi: vec![f64::NAN; period],
810            head: 0,
811            filled_prices: false,
812            filled_myrsi: false,
813        })
814    }
815
816    #[inline(always)]
817    pub fn update(&mut self, value: f64) -> Option<f64> {
818        self.price[self.head % (self.period + 1)] = value;
819        if !self.filled_prices && self.head + 1 >= self.period + 1 {
820            self.filled_prices = true;
821        }
822
823        if self.filled_prices {
824            let mut cu = 0.0;
825            let mut cd = 0.0;
826            for j in 0..self.period {
827                let newer = self.price[(self.head - j) % (self.period + 1)];
828                let older = self.price[(self.head - j - 1) % (self.period + 1)];
829                let diff = newer - older;
830                if diff > 0.0 {
831                    cu += diff;
832                } else if diff < 0.0 {
833                    cd += -diff;
834                }
835            }
836            let s = cu + cd;
837            let r = if s != 0.0 { (cu - cd) / s } else { 0.0 };
838            self.myrsi[self.head % self.period] = r;
839
840            if !self.filled_myrsi && self.head + 1 >= self.period * 2 {
841                self.filled_myrsi = true;
842            }
843        }
844
845        let out = if self.filled_myrsi {
846            let denom = (self.period * (self.period - 1)) as f64 / 2.0;
847            let mut num = 0.0;
848            for i in 1..self.period {
849                for k in 0..i {
850                    let vi = self.myrsi[(self.head - i) % self.period];
851                    let vk = self.myrsi[(self.head - k) % self.period];
852                    let d = vi - vk;
853                    if d > 0.0 {
854                        num -= 1.0;
855                    } else if d < 0.0 {
856                        num += 1.0;
857                    }
858                }
859            }
860            Some(num / denom)
861        } else {
862            None
863        };
864
865        self.head = self.head.wrapping_add(1);
866        out
867    }
868}
869
/// Parameter sweep for batch runs.
#[derive(Debug, Clone)]
pub struct NetMyrsiBatchRange {
    /// `(start, end, step)` of the period sweep.
    pub period: (usize, usize, usize),
}
874
875impl Default for NetMyrsiBatchRange {
876    fn default() -> Self {
877        Self {
878            period: (14, 263, 1),
879        }
880    }
881}
882
883#[derive(Clone, Debug, Default)]
884pub struct NetMyrsiBatchBuilder {
885    range: NetMyrsiBatchRange,
886    kernel: Kernel,
887}
888
889impl NetMyrsiBatchBuilder {
890    pub fn new() -> Self {
891        Self::default()
892    }
893
894    pub fn kernel(mut self, k: Kernel) -> Self {
895        self.kernel = k;
896        self
897    }
898
899    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
900        self.range.period = (start, end, step);
901        self
902    }
903
904    pub fn period_static(mut self, p: usize) -> Self {
905        self.range.period = (p, p, 0);
906        self
907    }
908
909    pub fn apply_slice(self, data: &[f64]) -> Result<NetMyrsiBatchOutput, NetMyrsiError> {
910        net_myrsi_batch_with_kernel(data, &self.range, self.kernel)
911    }
912
913    pub fn apply_candles(
914        self,
915        c: &Candles,
916        src: &str,
917    ) -> Result<NetMyrsiBatchOutput, NetMyrsiError> {
918        self.apply_slice(source_type(c, src))
919    }
920
921    pub fn with_default_candles(c: &Candles) -> Result<NetMyrsiBatchOutput, NetMyrsiError> {
922        NetMyrsiBatchBuilder::new()
923            .kernel(Kernel::Auto)
924            .apply_candles(c, "close")
925    }
926}
927
928#[derive(Clone, Debug)]
929pub struct NetMyrsiBatchOutput {
930    pub values: Vec<f64>,
931    pub combos: Vec<NetMyrsiParams>,
932    pub rows: usize,
933    pub cols: usize,
934}
935
936impl NetMyrsiBatchOutput {
937    pub fn row_for_params(&self, p: &NetMyrsiParams) -> Option<usize> {
938        self.combos
939            .iter()
940            .position(|c| c.period.unwrap_or(14) == p.period.unwrap_or(14))
941    }
942
943    pub fn values_for(&self, p: &NetMyrsiParams) -> Option<&[f64]> {
944        self.row_for_params(p).and_then(|r| {
945            let start = r.checked_mul(self.cols)?;
946            let end = start.checked_add(self.cols)?;
947            self.values.get(start..end)
948        })
949    }
950}
951
952#[inline(always)]
953fn expand_grid_period(
954    (start, end, step): (usize, usize, usize),
955) -> Result<Vec<usize>, NetMyrsiError> {
956    if step == 0 || start == end {
957        return Ok(vec![start]);
958    }
959
960    if start < end {
961        let mut v = Vec::new();
962        let mut x = start;
963        let st = step.max(1);
964        while x <= end {
965            v.push(x);
966            x = x
967                .checked_add(st)
968                .ok_or_else(|| NetMyrsiError::InvalidRange {
969                    start: start.to_string(),
970                    end: end.to_string(),
971                    step: step.to_string(),
972                })?;
973        }
974        if v.is_empty() {
975            return Err(NetMyrsiError::InvalidRange {
976                start: start.to_string(),
977                end: end.to_string(),
978                step: step.to_string(),
979            });
980        }
981        return Ok(v);
982    }
983
984    let mut v = Vec::new();
985    let mut x = start as isize;
986    let end_i = end as isize;
987    let st = (step as isize).max(1);
988    while x >= end_i {
989        v.push(x as usize);
990        x -= st;
991    }
992    if v.is_empty() {
993        return Err(NetMyrsiError::InvalidRange {
994            start: start.to_string(),
995            end: end.to_string(),
996            step: step.to_string(),
997        });
998    }
999    Ok(v)
1000}
1001
1002pub fn net_myrsi_batch_with_kernel(
1003    data: &[f64],
1004    sweep: &NetMyrsiBatchRange,
1005    k: Kernel,
1006) -> Result<NetMyrsiBatchOutput, NetMyrsiError> {
1007    let kernel = match k {
1008        Kernel::Auto => detect_best_batch_kernel(),
1009        other if other.is_batch() => other,
1010        other => {
1011            return Err(NetMyrsiError::InvalidKernelForBatch(other));
1012        }
1013    };
1014
1015    let periods = expand_grid_period(sweep.period)?;
1016    let combos: Vec<NetMyrsiParams> = periods
1017        .into_iter()
1018        .map(|p| NetMyrsiParams { period: Some(p) })
1019        .collect();
1020    net_myrsi_batch_inner(data, &combos, kernel)
1021}
1022
1023#[inline(always)]
1024fn net_myrsi_batch_inner(
1025    data: &[f64],
1026    combos: &[NetMyrsiParams],
1027    kern: Kernel,
1028) -> Result<NetMyrsiBatchOutput, NetMyrsiError> {
1029    if data.is_empty() {
1030        return Err(NetMyrsiError::EmptyInputData);
1031    }
1032    let first = data
1033        .iter()
1034        .position(|x| !x.is_nan())
1035        .ok_or(NetMyrsiError::AllValuesNaN)?;
1036    let cols = data.len();
1037    let rows = combos.len();
1038
1039    let _ = rows
1040        .checked_mul(cols)
1041        .ok_or_else(|| NetMyrsiError::InvalidRange {
1042            start: rows.to_string(),
1043            end: cols.to_string(),
1044            step: "rows*cols".into(),
1045        })?;
1046
1047    let max_needed = combos
1048        .iter()
1049        .map(|c| c.period.unwrap_or(14) + 1)
1050        .max()
1051        .unwrap();
1052    if cols - first < max_needed {
1053        return Err(NetMyrsiError::NotEnoughValidData {
1054            needed: max_needed,
1055            valid: cols - first,
1056        });
1057    }
1058
1059    let mut buf_mu = make_uninit_matrix(rows, cols);
1060
1061    let warms: Vec<usize> = combos
1062        .iter()
1063        .map(|c| first + c.period.unwrap_or(14) - 1)
1064        .collect();
1065    init_matrix_prefixes(&mut buf_mu, cols, &warms);
1066
1067    let mut guard = core::mem::ManuallyDrop::new(buf_mu);
1068    let out: &mut [f64] =
1069        unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len()) };
1070
1071    for (row, slice) in out.chunks_mut(cols).enumerate() {
1072        let p = combos[row].period.unwrap_or(14);
1073
1074        net_myrsi_compute_into(data, p, first, slice, kern);
1075    }
1076
1077    let values = unsafe {
1078        Vec::from_raw_parts(
1079            guard.as_mut_ptr() as *mut f64,
1080            guard.len(),
1081            guard.capacity(),
1082        )
1083    };
1084
1085    Ok(NetMyrsiBatchOutput {
1086        values,
1087        combos: combos.to_vec(),
1088        rows,
1089        cols,
1090    })
1091}
1092
1093#[inline(always)]
1094fn net_myrsi_batch_inner_into(
1095    data: &[f64],
1096    combos: &[NetMyrsiParams],
1097    kern: Kernel,
1098    out: &mut [f64],
1099) -> Result<(), NetMyrsiError> {
1100    if data.is_empty() {
1101        return Err(NetMyrsiError::EmptyInputData);
1102    }
1103    let first = data
1104        .iter()
1105        .position(|x| !x.is_nan())
1106        .ok_or(NetMyrsiError::AllValuesNaN)?;
1107    let cols = data.len();
1108    let rows = combos.len();
1109
1110    let expected = rows
1111        .checked_mul(cols)
1112        .ok_or_else(|| NetMyrsiError::InvalidRange {
1113            start: rows.to_string(),
1114            end: cols.to_string(),
1115            step: "rows*cols".into(),
1116        })?;
1117    if out.len() != expected {
1118        return Err(NetMyrsiError::OutputLengthMismatch {
1119            expected,
1120            got: out.len(),
1121        });
1122    }
1123
1124    let max_needed = combos
1125        .iter()
1126        .map(|c| c.period.unwrap_or(14) + 1)
1127        .max()
1128        .unwrap();
1129    if cols - first < max_needed {
1130        return Err(NetMyrsiError::NotEnoughValidData {
1131            needed: max_needed,
1132            valid: cols - first,
1133        });
1134    }
1135
1136    let mut tmp_mu = make_uninit_matrix(1, cols);
1137
1138    for (row, dst) in out.chunks_mut(cols).enumerate() {
1139        let p = combos[row].period.unwrap_or(14);
1140        let warm = first + p - 1;
1141
1142        for i in 0..warm {
1143            dst[i] = f64::NAN;
1144        }
1145        init_matrix_prefixes(&mut tmp_mu, cols, &[warm]);
1146        let tmp: &mut [f64] =
1147            unsafe { core::slice::from_raw_parts_mut(tmp_mu.as_mut_ptr() as *mut f64, cols) };
1148        compute_myrsi_from(data, p, first, tmp);
1149        compute_net_from(tmp, p, first, dst);
1150    }
1151
1152    let _ = tmp_mu;
1153    Ok(())
1154}
1155
1156#[cfg(not(target_arch = "wasm32"))]
1157pub fn net_myrsi_batch(
1158    data: &[Vec<f64>],
1159    params: &NetMyrsiParams,
1160    kernel: Option<Kernel>,
1161) -> Result<Vec<Vec<f64>>, NetMyrsiError> {
1162    let kern = kernel.unwrap_or_else(detect_best_batch_kernel);
1163
1164    let results: Result<Vec<_>, _> = data
1165        .par_iter()
1166        .map(|series| {
1167            let input = NetMyrsiInput::from_slice(series, params.clone());
1168            net_myrsi_with_kernel(&input, kern).map(|o| o.values)
1169        })
1170        .collect();
1171
1172    results
1173}
1174
/// Python binding `net_myrsi`: computes NET MyRSI for a 1-D float64 array.
///
/// `kernel` is optional and defaults to `None`, matching the signature of
/// the `net_myrsi_batch` binding. The computation runs inside
/// `allow_threads`; any indicator error is raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "net_myrsi")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn net_myrsi_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, false)?;
    let input = NetMyrsiInput::from_slice(
        slice_in,
        NetMyrsiParams {
            period: Some(period),
        },
    );

    let result_vec: Vec<f64> = py
        .allow_threads(|| net_myrsi_with_kernel(&input, kern).map(|o| o.values))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok(result_vec.into_pyarray(py))
}
1197
/// Python binding `net_myrsi_batch`: evaluates NET MyRSI for every period in
/// `period_range = (start, end, step)`.
///
/// Returns a dict with `"values"` (array reshaped to `(rows, cols)`) and
/// `"periods"` (the period of each row as `u64`). Errors from the range
/// expansion, the size computation and the indicator are raised as
/// `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "net_myrsi_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn net_myrsi_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    let slice_in = data.as_slice()?;

    let mut combos: Vec<NetMyrsiParams> = Vec::new();
    for period in
        expand_grid_period(period_range).map_err(|e| PyValueError::new_err(e.to_string()))?
    {
        combos.push(NetMyrsiParams {
            period: Some(period),
        });
    }
    let (rows, cols) = (combos.len(), slice_in.len());

    let len = match rows.checked_mul(cols) {
        Some(total) => total,
        None => return Err(PyValueError::new_err("rows*cols overflow")),
    };
    // The array is allocated without initialisation and filled in place by
    // `net_myrsi_batch_inner_into` below.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [len], false) };
    let out_slice = unsafe { out_arr.as_slice_mut()? };

    let kern = validate_kernel(kernel, true)?;
    py.allow_threads(|| {
        let actual = if matches!(kern, Kernel::Auto) {
            detect_best_batch_kernel()
        } else {
            kern
        };

        net_myrsi_batch_inner_into(slice_in, &combos, actual, out_slice)
    })
    .map_err(|e: NetMyrsiError| PyValueError::new_err(e.to_string()))?;

    let periods: Vec<u64> = combos.iter().map(|p| p.period.unwrap() as u64).collect();

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item("periods", periods.into_pyarray(py))?;
    Ok(dict)
}
1248
/// Python binding `net_myrsi_cuda_batch_dev`: runs the period sweep through
/// `CudaNetMyrsi::net_myrsi_batch_dev` on device `device_id` and returns the
/// result wrapped in a `DeviceArrayF32Py`.
///
/// Raises `ValueError` when `cuda_available()` is false or when creating the
/// CUDA object or running the batch fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "net_myrsi_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, device_id=0))]
pub fn net_myrsi_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    data_f32: numpy::PyReadonlyArray1<'py, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    use crate::cuda::cuda_available;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let prices = data_f32.as_slice()?;
    let sweep = NetMyrsiBatchRange {
        period: period_range,
    };

    let launched = py.allow_threads(|| {
        let cuda = match crate::cuda::CudaNetMyrsi::new(device_id) {
            Ok(handle) => handle,
            Err(e) => return Err(PyValueError::new_err(e.to_string())),
        };
        // Captured before the launch so they can be returned with the result.
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id();
        match cuda.net_myrsi_batch_dev(prices, &sweep) {
            Ok((inner, _)) => Ok((inner, ctx, dev_id)),
            Err(e) => Err(PyValueError::new_err(e.to_string())),
        }
    });
    let (inner, ctx, dev_id) = launched?;

    Ok(DeviceArrayF32Py {
        inner,
        _ctx: Some(ctx),
        device_id: Some(dev_id),
    })
}
1282
/// Python binding `net_myrsi_cuda_many_series_one_param_dev`: passes a 2-D
/// float32 array, flattened, to
/// `CudaNetMyrsi::net_myrsi_many_series_one_param_time_major_dev` together
/// with `shape[1]`, `shape[0]` and a single `period`.
///
/// Raises `ValueError` when `cuda_available()` is false or when creating the
/// CUDA object or running the computation fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "net_myrsi_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, device_id=0))]
pub fn net_myrsi_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    use crate::cuda::cuda_available;
    use numpy::PyUntypedArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let flat = data_tm_f32.as_slice()?;
    let dims = data_tm_f32.shape();
    let (rows, cols) = (dims[0], dims[1]);
    let params = NetMyrsiParams {
        period: Some(period),
    };

    let launched = py.allow_threads(|| {
        let cuda = match crate::cuda::CudaNetMyrsi::new(device_id) {
            Ok(handle) => handle,
            Err(e) => return Err(PyValueError::new_err(e.to_string())),
        };
        // Captured before the launch so they can be returned with the result.
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id();
        match cuda.net_myrsi_many_series_one_param_time_major_dev(flat, cols, rows, &params) {
            Ok(inner) => Ok((inner, ctx, dev_id)),
            Err(e) => Err(PyValueError::new_err(e.to_string())),
        }
    });
    let (inner, ctx, dev_id) = launched?;

    Ok(DeviceArrayF32Py {
        inner,
        _ctx: Some(ctx),
        device_id: Some(dev_id),
    })
}
1318
/// Python class `NetMyrsiStream`: a thin wrapper around the Rust
/// `NetMyrsiStream`.
#[cfg(feature = "python")]
#[pyclass(name = "NetMyrsiStream")]
pub struct NetMyrsiStreamPy {
    // The wrapped Rust stream that every call is forwarded to.
    stream: NetMyrsiStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl NetMyrsiStreamPy {
    /// Creates a stream for `period`; a failure of `NetMyrsiStream::try_new`
    /// is raised as `ValueError`.
    #[new]
    pub fn new(period: usize) -> PyResult<Self> {
        let params = NetMyrsiParams {
            period: Some(period),
        };
        match NetMyrsiStream::try_new(params) {
            Ok(stream) => Ok(Self { stream }),
            Err(e) => Err(PyValueError::new_err(e.to_string())),
        }
    }

    /// Forwards `value` to the wrapped stream and returns what it returns.
    pub fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
1342
/// wasm binding: computes NET MyRSI for `data` with the given `period` and
/// returns a new vector of the same length.
///
/// The output buffer starts as all-NaN and is filled by
/// `net_myrsi_into_slice`; an indicator error is returned as a JS string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn net_myrsi_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let params = NetMyrsiParams {
        period: Some(period),
    };
    let input = NetMyrsiInput::from_slice(data, params);

    let mut out = vec![f64::NAN; data.len()];
    match net_myrsi_into_slice(&mut out, &input, detect_best_kernel()) {
        Ok(_) => Ok(out),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1357
/// wasm binding: reserves room for `len` `f64` values and returns the raw
/// pointer to that buffer.
///
/// The backing `Vec` is deliberately never dropped here; `net_myrsi_free`
/// rebuilds a `Vec` from the pointer and `len` to release it.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn net_myrsi_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` keeps the allocation alive after this function returns.
    let mut buf = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buf.as_mut_ptr()
}
1366
/// wasm binding: releases a buffer of `len` `f64` values by rebuilding a
/// `Vec` from `ptr` and dropping it.
///
/// A null `ptr` is ignored. Otherwise `ptr` and `len` must be exactly what a
/// matching `net_myrsi_alloc(len)` call produced, and the buffer must not be
/// freed twice.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn net_myrsi_free(ptr: *mut f64, len: usize) {
    // Rebuilding a `Vec` from a null pointer is undefined behaviour.
    if ptr.is_null() {
        return;
    }
    // SAFETY: the caller guarantees `ptr` came from `net_myrsi_alloc(len)`,
    // i.e. from a `Vec<f64>` with capacity `len` that was never dropped.
    unsafe {
        drop(Vec::from_raw_parts(ptr, len, len));
    }
}
1374
/// wasm binding: computes NET MyRSI for the `len` values at `in_ptr` and
/// writes `len` results to `out_ptr`.
///
/// When both pointers are equal the result is computed into a temporary
/// buffer first and then copied over the input. Null pointers and indicator
/// errors are returned as JS strings.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn net_myrsi_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer"));
    }
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let params = NetMyrsiParams {
            period: Some(period),
        };
        let input = NetMyrsiInput::from_slice(data, params);

        let aliased = in_ptr == out_ptr;
        if !aliased {
            // Distinct buffers: write straight into the destination.
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            net_myrsi_into_slice(out, &input, detect_best_kernel())
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            return Ok(());
        }

        // In-place request: compute into scratch space, then copy back.
        let mut scratch = vec![0.0; len];
        net_myrsi_into_slice(&mut scratch, &input, detect_best_kernel())
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        std::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&scratch);
        Ok(())
    }
}
1408
/// Configuration object deserialized from the JS `config` argument of the
/// wasm `net_myrsi_batch` binding.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct NetMyrsiBatchConfig {
    /// Period sweep; copied into `NetMyrsiBatchRange::period` by the binding.
    pub period_range: (usize, usize, usize),
}
1414
/// Result object serialized back to JS by the wasm `net_myrsi_batch`
/// binding; its fields are copied from the batch output one to one.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct NetMyrsiBatchJsOutput {
    /// The batch output's `values` buffer.
    pub values: Vec<f64>,
    /// The batch output's parameter combinations.
    pub combos: Vec<NetMyrsiParams>,
    /// The batch output's row count.
    pub rows: usize,
    /// The batch output's column count.
    pub cols: usize,
}
1423
/// wasm binding `net_myrsi_batch`: deserializes `config` into a
/// `NetMyrsiBatchConfig`, runs the period sweep over `data`, and returns the
/// result serialized as a `NetMyrsiBatchJsOutput`.
///
/// Deserialization, indicator and serialization failures are returned as JS
/// strings.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = net_myrsi_batch)]
pub fn net_myrsi_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let cfg: NetMyrsiBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;
    let sweep = NetMyrsiBatchRange {
        period: cfg.period_range,
    };
    // `net_myrsi_batch_with_kernel` returns `InvalidKernelForBatch` for any
    // explicit kernel whose `is_batch()` is false, so `Kernel::Auto` is
    // passed and the batch kernel is chosen there.
    // NOTE(review): `detect_best_kernel()` presumably yields a single-series
    // kernel, which would be rejected here; confirm against `Kernel::is_batch`.
    let out = net_myrsi_batch_with_kernel(data, &sweep, Kernel::Auto)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;
    serde_wasm_bindgen::to_value(&NetMyrsiBatchJsOutput {
        values: out.values,
        combos: out.combos,
        rows: out.rows,
        cols: out.cols,
    })
    .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
1442
/// wasm binding: runs the period sweep `(period_start, period_end,
/// period_step)` over the `len` values at `in_ptr` and writes the row-major
/// result to `out_ptr`. Returns the number of rows written.
///
/// The caller must make `in_ptr` valid for `len` reads and `out_ptr` valid
/// for `rows * len` writes. When both pointers are equal the input is copied
/// first, because the output rows would otherwise overwrite it while it is
/// still being read.
///
/// Null pointers, a `rows * len` overflow and indicator errors are returned
/// as JS strings.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn net_myrsi_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str(
            "null pointer passed to net_myrsi_batch_into",
        ));
    }

    let periods = expand_grid_period((period_start, period_end, period_step))
        .map_err(|e| JsValue::from_str(&e.to_string()))?;
    let combos: Vec<NetMyrsiParams> = periods
        .into_iter()
        .map(|p| NetMyrsiParams { period: Some(p) })
        .collect();
    let rows = combos.len();
    let cols = len;
    // Checked so the output slice length can never silently wrap.
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;

    // SAFETY: both pointers are non-null; the caller guarantees `in_ptr` is
    // valid for `len` reads and `out_ptr` for `total` writes.
    unsafe {
        if in_ptr == out_ptr as *const f64 {
            // Aliased buffers: take an owned copy of the input before any
            // mutable view of the same memory is created.
            let owned = std::slice::from_raw_parts(in_ptr, len).to_vec();
            let out = std::slice::from_raw_parts_mut(out_ptr, total);
            net_myrsi_batch_inner_into(&owned, &combos, detect_best_kernel(), out)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        } else {
            let data = std::slice::from_raw_parts(in_ptr, len);
            let out = std::slice::from_raw_parts_mut(out_ptr, total);
            net_myrsi_batch_inner_into(data, &combos, detect_best_kernel(), out)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }
    Ok(rows)
}
1475
1476#[cfg(test)]
1477mod tests {
1478    use super::*;
1479    use crate::skip_if_unsupported;
1480    use crate::utilities::data_loader::read_candles_from_csv;
1481    #[cfg(feature = "proptest")]
1482    use proptest::prelude::*;
1483    use std::error::Error;
1484
1485    fn check_net_myrsi_partial_params(
1486        test_name: &str,
1487        kernel: Kernel,
1488    ) -> Result<(), Box<dyn Error>> {
1489        skip_if_unsupported!(kernel, test_name);
1490        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1491        let candles = read_candles_from_csv(file_path)?;
1492
1493        let default_params = NetMyrsiParams { period: None };
1494        let input = NetMyrsiInput::from_candles(&candles, "close", default_params);
1495        let output = net_myrsi_with_kernel(&input, kernel)?;
1496        assert_eq!(output.values.len(), candles.close.len());
1497
1498        Ok(())
1499    }
1500
1501    fn check_net_myrsi_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1502        skip_if_unsupported!(kernel, test_name);
1503        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1504        let candles = read_candles_from_csv(file_path)?;
1505
1506        let input = NetMyrsiInput::from_candles(&candles, "close", NetMyrsiParams::default());
1507        let result = net_myrsi_with_kernel(&input, kernel)?;
1508
1509        let expected_last_five = [0.64835165, 0.49450549, 0.29670330, 0.07692308, -0.07692308];
1510
1511        let start = result.values.len().saturating_sub(5);
1512        for (i, &val) in result.values[start..].iter().enumerate() {
1513            let diff = (val - expected_last_five[i]).abs();
1514            assert!(
1515                diff < 1e-6,
1516                "[{}] NET_MYRSI {:?} mismatch at idx {}: got {}, expected {}",
1517                test_name,
1518                kernel,
1519                i,
1520                val,
1521                expected_last_five[i]
1522            );
1523        }
1524        Ok(())
1525    }
1526
1527    fn check_net_myrsi_default_candles(
1528        test_name: &str,
1529        kernel: Kernel,
1530    ) -> Result<(), Box<dyn Error>> {
1531        skip_if_unsupported!(kernel, test_name);
1532        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1533        let candles = read_candles_from_csv(file_path)?;
1534
1535        let input = NetMyrsiInput::with_default_candles(&candles);
1536        match input.data {
1537            NetMyrsiData::Candles { source, .. } => assert_eq!(source, "close"),
1538            _ => panic!("Expected NetMyrsiData::Candles"),
1539        }
1540        let output = net_myrsi_with_kernel(&input, kernel)?;
1541        assert_eq!(output.values.len(), candles.close.len());
1542
1543        Ok(())
1544    }
1545
1546    fn check_net_myrsi_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1547        skip_if_unsupported!(kernel, test_name);
1548        let input_data = [10.0, 20.0, 30.0];
1549        let params = NetMyrsiParams { period: Some(0) };
1550        let input = NetMyrsiInput::from_slice(&input_data, params);
1551        let res = net_myrsi_with_kernel(&input, kernel);
1552        assert!(
1553            res.is_err(),
1554            "[{}] NET_MYRSI should fail with zero period",
1555            test_name
1556        );
1557        Ok(())
1558    }
1559
1560    fn check_net_myrsi_period_exceeds_length(
1561        test_name: &str,
1562        kernel: Kernel,
1563    ) -> Result<(), Box<dyn Error>> {
1564        skip_if_unsupported!(kernel, test_name);
1565        let data_small = [10.0, 20.0, 30.0];
1566        let params = NetMyrsiParams { period: Some(10) };
1567        let input = NetMyrsiInput::from_slice(&data_small, params);
1568        let res = net_myrsi_with_kernel(&input, kernel);
1569        assert!(
1570            res.is_err(),
1571            "[{}] NET_MYRSI should fail with period exceeding length",
1572            test_name
1573        );
1574        Ok(())
1575    }
1576
1577    fn check_net_myrsi_very_small_dataset(
1578        test_name: &str,
1579        kernel: Kernel,
1580    ) -> Result<(), Box<dyn Error>> {
1581        skip_if_unsupported!(kernel, test_name);
1582        let data_small = vec![10.0, 20.0, 30.0, 15.0, 25.0];
1583        let params = NetMyrsiParams { period: Some(3) };
1584        let input = NetMyrsiInput::from_slice(&data_small, params);
1585        let result = net_myrsi_with_kernel(&input, kernel)?;
1586        assert_eq!(result.values.len(), data_small.len());
1587        Ok(())
1588    }
1589
1590    fn check_net_myrsi_empty_input(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1591        skip_if_unsupported!(kernel, test_name);
1592        let data: Vec<f64> = vec![];
1593        let params = NetMyrsiParams::default();
1594        let input = NetMyrsiInput::from_slice(&data, params);
1595        let res = net_myrsi_with_kernel(&input, kernel);
1596        assert!(
1597            res.is_err(),
1598            "[{}] NET_MYRSI should fail with empty input",
1599            test_name
1600        );
1601        Ok(())
1602    }
1603
1604    fn check_net_myrsi_all_nan(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1605        skip_if_unsupported!(kernel, test_name);
1606        let data = vec![f64::NAN; 30];
1607        let params = NetMyrsiParams::default();
1608        let input = NetMyrsiInput::from_slice(&data, params);
1609        let res = net_myrsi_with_kernel(&input, kernel);
1610        assert!(
1611            res.is_err(),
1612            "[{}] NET_MYRSI should fail with all NaN values",
1613            test_name
1614        );
1615        Ok(())
1616    }
1617
1618    fn check_net_myrsi_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1619        skip_if_unsupported!(kernel, test_name);
1620        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1621        let candles = read_candles_from_csv(file_path)?;
1622
1623        let first_params = NetMyrsiParams { period: Some(14) };
1624        let first_input = NetMyrsiInput::from_candles(&candles, "close", first_params);
1625        let first_result = net_myrsi_with_kernel(&first_input, kernel)?;
1626
1627        let second_params = NetMyrsiParams { period: Some(14) };
1628        let second_input = NetMyrsiInput::from_slice(&first_result.values, second_params);
1629        let second_result = net_myrsi_with_kernel(&second_input, kernel)?;
1630
1631        assert_eq!(second_result.values.len(), first_result.values.len());
1632
1633        let non_nan_count = second_result.values.iter().filter(|v| !v.is_nan()).count();
1634        assert!(
1635            non_nan_count > 0,
1636            "[{}] Second pass should produce non-NaN values",
1637            test_name
1638        );
1639
1640        Ok(())
1641    }
1642
1643    fn check_net_myrsi_warmup_nans(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1644        skip_if_unsupported!(kernel, test_name);
1645        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1646        let candles = read_candles_from_csv(file_path)?;
1647        let first = candles.close.iter().position(|x| !x.is_nan()).unwrap_or(0);
1648        let p = 14;
1649        let out = net_myrsi_with_kernel(
1650            &NetMyrsiInput::from_candles(&candles, "close", NetMyrsiParams { period: Some(p) }),
1651            kernel,
1652        )?
1653        .values;
1654        let warm = first + p - 1;
1655        assert!(
1656            out[..warm].iter().all(|v| v.is_nan()),
1657            "[{}] Warmup NaNs should not be overwritten",
1658            test_name
1659        );
1660        Ok(())
1661    }
1662
1663    fn check_net_myrsi_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1664        skip_if_unsupported!(kernel, test_name);
1665        let mut data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
1666        for _ in 0..20 {
1667            data.push(data[data.len() - 1] + 1.0);
1668        }
1669
1670        data[15] = f64::NAN;
1671
1672        let params = NetMyrsiParams { period: Some(14) };
1673        let input = NetMyrsiInput::from_slice(&data, params);
1674        let result = net_myrsi_with_kernel(&input, kernel)?;
1675
1676        assert_eq!(result.values.len(), data.len());
1677
1678        let non_nan_count = result.values.iter().filter(|v| !v.is_nan()).count();
1679        assert!(
1680            non_nan_count > 0,
1681            "[{}] Should produce some non-NaN values despite NaN input",
1682            test_name
1683        );
1684
1685        Ok(())
1686    }
1687
1688    fn check_net_myrsi_streaming(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1689        skip_if_unsupported!(kernel, test);
1690
1691        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1692        let c = read_candles_from_csv(file)?;
1693        let period = 14;
1694
1695        let batch = NetMyrsiInput::from_candles(
1696            &c,
1697            "close",
1698            NetMyrsiParams {
1699                period: Some(period),
1700            },
1701        );
1702        let b = net_myrsi_with_kernel(&batch, kernel)?.values;
1703
1704        let mut s = NetMyrsiStream::try_new(NetMyrsiParams {
1705            period: Some(period),
1706        })?;
1707        let mut stream = Vec::with_capacity(c.close.len());
1708        for &px in &c.close {
1709            stream.push(s.update(px).unwrap_or(f64::NAN));
1710        }
1711
1712        let first_valid = b.iter().position(|v| !v.is_nan()).unwrap();
1713        for i in first_valid..b.len() {
1714            if !b[i].is_nan() && !stream[i].is_nan() {
1715                assert!(
1716                    (b[i] - stream[i]).abs() < 1e-10,
1717                    "[{test}] idx {i}: {} vs {}",
1718                    b[i],
1719                    stream[i]
1720                );
1721            }
1722        }
1723
1724        Ok(())
1725    }
1726
1727    #[cfg(debug_assertions)]
1728    fn check_net_myrsi_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1729        skip_if_unsupported!(kernel, test_name);
1730
1731        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1732        let candles = read_candles_from_csv(file_path)?;
1733
1734        let test_params = vec![
1735            NetMyrsiParams::default(),
1736            NetMyrsiParams { period: Some(10) },
1737            NetMyrsiParams { period: Some(20) },
1738        ];
1739
1740        for params in test_params {
1741            let input = NetMyrsiInput::from_candles(&candles, "close", params.clone());
1742            let output = net_myrsi_with_kernel(&input, kernel)?;
1743
1744            for (i, &val) in output.values.iter().enumerate() {
1745                if val.is_nan() {
1746                    continue;
1747                }
1748                let bits = val.to_bits();
1749                assert_ne!(
1750                    bits, 0x11111111_11111111,
1751                    "[{}] alloc_with_nan_prefix poison at {} with params {:?}",
1752                    test_name, i, params
1753                );
1754            }
1755        }
1756
1757        Ok(())
1758    }
1759
1760    #[cfg(not(debug_assertions))]
1761    fn check_net_myrsi_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1762        Ok(())
1763    }
1764
1765    #[cfg(feature = "proptest")]
1766    #[allow(clippy::float_cmp)]
1767    fn check_net_myrsi_property(
1768        test_name: &str,
1769        kernel: Kernel,
1770    ) -> Result<(), Box<dyn std::error::Error>> {
1771        use proptest::prelude::*;
1772        skip_if_unsupported!(kernel, test_name);
1773
1774        let strat = (2usize..=50).prop_flat_map(|period| {
1775            (
1776                prop::collection::vec(
1777                    (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
1778                    period + 2..400,
1779                ),
1780                Just(period),
1781            )
1782        });
1783
1784        proptest::test_runner::TestRunner::default()
1785            .run(&strat, |(data, period)| {
1786                let params = NetMyrsiParams {
1787                    period: Some(period),
1788                };
1789                let input = NetMyrsiInput::from_slice(&data, params);
1790
1791                let NetMyrsiOutput { values: out } = net_myrsi_with_kernel(&input, kernel).unwrap();
1792                let NetMyrsiOutput { values: ref_out } =
1793                    net_myrsi_with_kernel(&input, Kernel::Scalar).unwrap();
1794
1795                assert_eq!(out.len(), ref_out.len());
1796                assert_eq!(out.len(), data.len());
1797
1798                let first_valid = ref_out
1799                    .iter()
1800                    .position(|v| !v.is_nan())
1801                    .unwrap_or(ref_out.len());
1802
1803                for i in 0..first_valid {
1804                    assert!(out[i].is_nan(), "Expected NaN at index {}", i);
1805                }
1806
1807                for i in first_valid..out.len() {
1808                    if ref_out[i].is_nan() {
1809                        assert!(out[i].is_nan(), "Expected NaN at index {}", i);
1810                    } else {
1811                        let diff = (out[i] - ref_out[i]).abs();
1812                        let rel_diff = diff / ref_out[i].abs().max(1e-10);
1813                        assert!(
1814                            rel_diff < 1e-10,
1815                            "Mismatch at {}: kernel={}, scalar={}, diff={}",
1816                            i,
1817                            out[i],
1818                            ref_out[i],
1819                            diff
1820                        );
1821                    }
1822                }
1823
1824                Ok(())
1825            })
1826            .map_err(|e| format!("Property test failed: {}", e).into())
1827    }
1828
1829    #[cfg(not(feature = "proptest"))]
1830    fn check_net_myrsi_property(
1831        _test_name: &str,
1832        _kernel: Kernel,
1833    ) -> Result<(), Box<dyn std::error::Error>> {
1834        Ok(())
1835    }
1836
1837    macro_rules! generate_all_net_myrsi_tests {
1838        ($($test_fn:ident),*) => {
1839            paste::paste! {
1840                $(
1841                    #[test]
1842                    fn [<$test_fn _scalar>]() -> Result<(), Box<dyn Error>> {
1843                        $test_fn(stringify!([<$test_fn _scalar>]), Kernel::Scalar)
1844                    }
1845                )*
1846                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1847                $(
1848                    #[test]
1849                    fn [<$test_fn _avx2>]() -> Result<(), Box<dyn Error>> {
1850                        $test_fn(stringify!([<$test_fn _avx2>]), Kernel::Avx2)
1851                    }
1852                )*
1853                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1854                $(
1855                    #[test]
1856                    fn [<$test_fn _avx512>]() -> Result<(), Box<dyn Error>> {
1857                        $test_fn(stringify!([<$test_fn _avx512>]), Kernel::Avx512)
1858                    }
1859                )*
1860            }
1861        };
1862    }
1863
1864    generate_all_net_myrsi_tests!(
1865        check_net_myrsi_partial_params,
1866        check_net_myrsi_accuracy,
1867        check_net_myrsi_default_candles,
1868        check_net_myrsi_zero_period,
1869        check_net_myrsi_period_exceeds_length,
1870        check_net_myrsi_very_small_dataset,
1871        check_net_myrsi_empty_input,
1872        check_net_myrsi_all_nan,
1873        check_net_myrsi_reinput,
1874        check_net_myrsi_warmup_nans,
1875        check_net_myrsi_nan_handling,
1876        check_net_myrsi_streaming,
1877        check_net_myrsi_no_poison,
1878        check_net_myrsi_property
1879    );
1880
1881    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1882        skip_if_unsupported!(kernel, test);
1883        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1884        let c = read_candles_from_csv(file)?;
1885        let out = NetMyrsiBatchBuilder::new()
1886            .kernel(kernel)
1887            .period_range(10, 20, 1)
1888            .apply_candles(&c, "close")?;
1889        let def = NetMyrsiParams::default();
1890        let row = out.values_for(&def).expect("default row missing");
1891        assert_eq!(row.len(), c.close.len());
1892        Ok(())
1893    }
1894
1895    #[cfg(debug_assertions)]
1896    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1897        skip_if_unsupported!(kernel, test);
1898        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1899        let c = read_candles_from_csv(file)?;
1900        let out = NetMyrsiBatchBuilder::new()
1901            .kernel(kernel)
1902            .period_range(10, 20, 5)
1903            .apply_candles(&c, "close")?;
1904        for (idx, &v) in out.values.iter().enumerate() {
1905            if v.is_nan() {
1906                continue;
1907            }
1908            let bits = v.to_bits();
1909            assert_ne!(
1910                bits, 0x1111_1111_1111_1111,
1911                "[{test}] alloc_with_nan_prefix poison at {idx}"
1912            );
1913            assert_ne!(
1914                bits, 0x2222_2222_2222_2222,
1915                "[{test}] init_matrix_prefixes poison at {idx}"
1916            );
1917            assert_ne!(
1918                bits, 0x3333_3333_3333_3333,
1919                "[{test}] make_uninit_matrix poison at {idx}"
1920            );
1921        }
1922        Ok(())
1923    }
1924
1925    #[cfg(not(debug_assertions))]
1926    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1927        Ok(())
1928    }
1929
1930    fn check_batch_sweep(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1931        skip_if_unsupported!(kernel, test);
1932        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1933        let c = read_candles_from_csv(file)?;
1934        let out = NetMyrsiBatchBuilder::new()
1935            .kernel(kernel)
1936            .period_range(10, 20, 5)
1937            .apply_candles(&c, "close")?;
1938        assert_eq!(out.rows, 3);
1939        assert_eq!(out.cols, c.close.len());
1940        assert_eq!(out.values.len(), out.rows * out.cols);
1941
1942        assert_eq!(out.combos.len(), 3);
1943        assert_eq!(out.combos[0].period, Some(10));
1944        assert_eq!(out.combos[1].period, Some(15));
1945        assert_eq!(out.combos[2].period, Some(20));
1946        Ok(())
1947    }
1948
1949    macro_rules! gen_batch_tests {
1950        ($fn_name:ident) => {
1951            paste::paste! {
1952                #[test] fn [<$fn_name _scalar>]() {
1953                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
1954                }
1955                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1956                #[test] fn [<$fn_name _avx2>]() {
1957                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
1958                }
1959                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1960                #[test] fn [<$fn_name _avx512>]() {
1961                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
1962                }
1963                #[test] fn [<$fn_name _auto_detect>]() {
1964                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]),
1965                                    Kernel::Auto);
1966                }
1967            }
1968        };
1969    }
1970
1971    gen_batch_tests!(check_batch_default_row);
1972    gen_batch_tests!(check_batch_sweep);
1973    gen_batch_tests!(check_batch_no_poison);
1974
1975    #[test]
1976    fn test_net_myrsi_into_matches_api() -> Result<(), Box<dyn Error>> {
1977        let mut data = Vec::with_capacity(256);
1978        data.extend_from_slice(&[f64::NAN, f64::NAN, f64::NAN]);
1979        for i in 0..(256 - 3) {
1980            let x = i as f64;
1981            data.push((x.sin() + (0.1 * x).cos()) * 10.0);
1982        }
1983
1984        let input = NetMyrsiInput::from_slice(&data, NetMyrsiParams::default());
1985
1986        let baseline = net_myrsi(&input)?.values;
1987
1988        let mut out = vec![0.0; data.len()];
1989
1990        #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1991        {
1992            net_myrsi_into(&input, &mut out)?;
1993        }
1994        #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1995        {
1996            net_myrsi_into_slice(&mut out, &input, Kernel::Auto)?;
1997        }
1998
1999        assert_eq!(baseline.len(), out.len());
2000
2001        fn eq_or_both_nan(a: f64, b: f64) -> bool {
2002            (a.is_nan() && b.is_nan()) || (a == b)
2003        }
2004
2005        for (i, (&a, &b)) in baseline.iter().zip(out.iter()).enumerate() {
2006            assert!(
2007                eq_or_both_nan(a, b),
2008                "value mismatch at {}: baseline={}, into={}",
2009                i,
2010                a,
2011                b
2012            );
2013        }
2014        Ok(())
2015    }
2016}