Skip to main content

vector_ta/indicators/
rsx.rs

1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{
4    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
5    make_uninit_matrix,
6};
7use aligned_vec::{AVec, CACHELINE_ALIGN};
8#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
9use core::arch::x86_64::*;
10#[cfg(not(target_arch = "wasm32"))]
11use rayon::prelude::*;
12use std::mem::MaybeUninit;
13use thiserror::Error;
14
15#[cfg(feature = "python")]
16use crate::utilities::kernel_validation::validate_kernel;
17#[cfg(feature = "python")]
18use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
19#[cfg(feature = "python")]
20use pyo3::exceptions::PyValueError;
21#[cfg(feature = "python")]
22use pyo3::prelude::*;
23#[cfg(feature = "python")]
24use pyo3::types::PyDict;
25
26#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
27use serde::{Deserialize, Serialize};
28#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
29use serde_wasm_bindgen;
30#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
31use wasm_bindgen::prelude::*;
32
33impl<'a> AsRef<[f64]> for RsxInput<'a> {
34    #[inline(always)]
35    fn as_ref(&self) -> &[f64] {
36        match &self.data {
37            RsxData::Slice(slice) => slice,
38            RsxData::Candles { candles, source } => source_type(candles, source),
39        }
40    }
41}
42
/// Where the RSX price series comes from.
#[derive(Debug, Clone)]
pub enum RsxData<'a> {
    /// A column of a candle set, selected by name through `source_type`
    /// (the default constructors use `"close"`).
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A caller-provided price slice, used as-is.
    Slice(&'a [f64]),
}

/// Result of a single-series RSX computation.
#[derive(Debug, Clone)]
pub struct RsxOutput {
    /// One value per input sample; the warmup prefix is NaN.
    pub values: Vec<f64>,
}
56
/// Tunable RSX parameters.
///
/// `period` is the smoothing length; `None` is treated as 14 by the
/// consumers in this module.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct RsxParams {
    pub period: Option<usize>,
}

impl Default for RsxParams {
    /// Parameters with an explicit period of 14.
    fn default() -> Self {
        const DEFAULT_PERIOD: usize = 14;
        RsxParams {
            period: Some(DEFAULT_PERIOD),
        }
    }
}
71
72#[derive(Debug, Clone)]
73pub struct RsxInput<'a> {
74    pub data: RsxData<'a>,
75    pub params: RsxParams,
76}
77
78impl<'a> RsxInput<'a> {
79    #[inline]
80    pub fn from_candles(c: &'a Candles, s: &'a str, p: RsxParams) -> Self {
81        Self {
82            data: RsxData::Candles {
83                candles: c,
84                source: s,
85            },
86            params: p,
87        }
88    }
89    #[inline]
90    pub fn from_slice(sl: &'a [f64], p: RsxParams) -> Self {
91        Self {
92            data: RsxData::Slice(sl),
93            params: p,
94        }
95    }
96    #[inline]
97    pub fn with_default_candles(c: &'a Candles) -> Self {
98        Self::from_candles(c, "close", RsxParams::default())
99    }
100    #[inline]
101    pub fn get_period(&self) -> usize {
102        self.params.period.unwrap_or(14)
103    }
104}
105
106#[derive(Copy, Clone, Debug)]
107pub struct RsxBuilder {
108    period: Option<usize>,
109    kernel: Kernel,
110}
111
112impl Default for RsxBuilder {
113    fn default() -> Self {
114        Self {
115            period: None,
116            kernel: Kernel::Auto,
117        }
118    }
119}
120
121impl RsxBuilder {
122    #[inline(always)]
123    pub fn new() -> Self {
124        Self::default()
125    }
126    #[inline(always)]
127    pub fn period(mut self, n: usize) -> Self {
128        self.period = Some(n);
129        self
130    }
131    #[inline(always)]
132    pub fn kernel(mut self, k: Kernel) -> Self {
133        self.kernel = k;
134        self
135    }
136    #[inline(always)]
137    pub fn apply(self, c: &Candles) -> Result<RsxOutput, RsxError> {
138        let p = RsxParams {
139            period: self.period,
140        };
141        let i = RsxInput::from_candles(c, "close", p);
142        rsx_with_kernel(&i, self.kernel)
143    }
144    #[inline(always)]
145    pub fn apply_slice(self, d: &[f64]) -> Result<RsxOutput, RsxError> {
146        let p = RsxParams {
147            period: self.period,
148        };
149        let i = RsxInput::from_slice(d, p);
150        rsx_with_kernel(&i, self.kernel)
151    }
152    #[inline(always)]
153    pub fn into_stream(self) -> Result<RsxStream, RsxError> {
154        let p = RsxParams {
155            period: self.period,
156        };
157        RsxStream::try_new(p)
158    }
159}
160
/// Errors returned by the RSX entry points in this module.
#[derive(Debug, Error)]
pub enum RsxError {
    /// The input series has no elements.
    #[error("rsx: Input data slice is empty.")]
    EmptyInputData,

    /// Every element of the input series is NaN.
    #[error("rsx: All values are NaN.")]
    AllValuesNaN,

    /// `period` is zero or larger than the input length `data_len`.
    /// `RsxStream::try_new` reports a zero period with `data_len: 0`.
    #[error("rsx: Invalid period: period = {period}, data_len = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },

    /// Fewer than `needed` samples remain from the first non-NaN value onward.
    #[error("rsx: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },

    /// The destination slice given to `rsx_into_slice` is not the input's length.
    #[error("rsx: Output length mismatch: expected {expected}, got {got}")]
    OutputLengthMismatch { expected: usize, got: usize },

    /// A batch period sweep expanded to nothing, or the batch matrix size
    /// (`rows * cols`) overflowed; fields hold a textual description.
    #[error("rsx: Invalid range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: String,
        end: String,
        step: String,
    },

    /// Kernel mismatch, reported with the offending kernel.
    // NOTE(review): not constructed anywhere in the visible part of this file.
    #[error("rsx: Invalid kernel: expected batch kernel, got {kernel:?}")]
    InvalidKernel { kernel: Kernel },

    /// A non-batch kernel was passed to `rsx_batch_with_kernel`.
    #[error("rsx: Invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(Kernel),
}
191
192#[inline(always)]
193fn rsx_prepare<'a>(
194    input: &'a RsxInput,
195    kernel: Kernel,
196) -> Result<(&'a [f64], usize, usize, Kernel), RsxError> {
197    let data: &[f64] = input.as_ref();
198    let len = data.len();
199    if len == 0 {
200        return Err(RsxError::EmptyInputData);
201    }
202    let first = data
203        .iter()
204        .position(|x| !x.is_nan())
205        .ok_or(RsxError::AllValuesNaN)?;
206    let period = input.get_period();
207
208    if period == 0 || period > len {
209        return Err(RsxError::InvalidPeriod {
210            period,
211            data_len: len,
212        });
213    }
214    if len - first < period {
215        return Err(RsxError::NotEnoughValidData {
216            needed: period,
217            valid: len - first,
218        });
219    }
220
221    let chosen = match kernel {
222        Kernel::Auto => Kernel::Scalar,
223        k => k,
224    };
225    Ok((data, period, first, chosen))
226}
227
228#[inline]
229pub fn rsx(input: &RsxInput) -> Result<RsxOutput, RsxError> {
230    rsx_with_kernel(input, Kernel::Auto)
231}
232
233pub fn rsx_with_kernel(input: &RsxInput, kernel: Kernel) -> Result<RsxOutput, RsxError> {
234    let (data, period, first, chosen) = rsx_prepare(input, kernel)?;
235    let mut out = alloc_with_nan_prefix(data.len(), first + period - 1);
236
237    unsafe {
238        match chosen {
239            #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
240            Kernel::Scalar | Kernel::ScalarBatch => rsx_simd128(data, period, first, &mut out),
241            #[cfg(not(all(target_arch = "wasm32", target_feature = "simd128")))]
242            Kernel::Scalar | Kernel::ScalarBatch => rsx_scalar(data, period, first, &mut out),
243
244            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
245            Kernel::Avx2 | Kernel::Avx2Batch => rsx_avx2(data, period, first, &mut out),
246            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
247            Kernel::Avx512 | Kernel::Avx512Batch => rsx_avx512(data, period, first, &mut out),
248
249            _ => rsx_scalar(data, period, first, &mut out),
250        }
251    }
252    Ok(RsxOutput { values: out })
253}
254
255#[inline]
256pub fn rsx_into_slice(dst: &mut [f64], input: &RsxInput, kern: Kernel) -> Result<(), RsxError> {
257    let (data, period, first, chosen) = rsx_prepare(input, kern)?;
258    if dst.len() != data.len() {
259        return Err(RsxError::OutputLengthMismatch {
260            expected: data.len(),
261            got: dst.len(),
262        });
263    }
264
265    unsafe {
266        match chosen {
267            #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
268            Kernel::Scalar | Kernel::ScalarBatch => rsx_simd128(data, period, first, dst),
269            #[cfg(not(all(target_arch = "wasm32", target_feature = "simd128")))]
270            Kernel::Scalar | Kernel::ScalarBatch => rsx_scalar(data, period, first, dst),
271
272            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
273            Kernel::Avx2 | Kernel::Avx2Batch => rsx_avx2(data, period, first, dst),
274            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
275            Kernel::Avx512 | Kernel::Avx512Batch => rsx_avx512(data, period, first, dst),
276
277            _ => rsx_scalar(data, period, first, dst),
278        }
279    }
280
281    let warmup_end = first + period - 1;
282    for v in &mut dst[..warmup_end] {
283        *v = f64::NAN;
284    }
285    Ok(())
286}
287
288#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
289#[inline]
290pub fn rsx_into(input: &RsxInput, out: &mut [f64]) -> Result<(), RsxError> {
291    rsx_into_slice(out, input, Kernel::Auto)
292}
293
/// RSX kernel used for `Kernel::Avx512`.
///
/// Dispatches on `period` to `rsx_avx512_short` (`period <= 32`) or
/// `rsx_avx512_long`.
///
/// # Panics
/// Panics if `out.len() < data.len()`. The forwarded kernels index `out`
/// without bounds checks up to `data.len()`. Because this function is safe,
/// a short `out` would otherwise be undefined behavior instead of a panic.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn rsx_avx512(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    assert!(
        out.len() >= data.len(),
        "rsx_avx512: output slice shorter than input ({} < {})",
        out.len(),
        data.len()
    );
    // SAFETY: the assertion above guarantees every index the kernels write
    // (always < data.len()) is in bounds for `out`.
    unsafe {
        if period <= 32 {
            rsx_avx512_short(data, period, first_valid, out)
        } else {
            rsx_avx512_long(data, period, first_valid, out)
        }
    }
}
303
/// Scalar RSX kernel.
///
/// Let `start = first + period - 1`. The price at `data[start]` seeds the
/// filters and `out[start]` is set to NaN. Every later index `i` receives one of:
/// - `50.0`, while the internal bar counter is still ramping up or the smoothed
///   absolute momentum is `<= 1e-10`;
/// - otherwise, a value clamped to `[0, 100]`.
///
/// Indices before `start` are never written, so the caller owns the warmup prefix.
///
/// Returns without writing anything if `period == 0` or `start >= data.len()`.
///
/// # Panics
/// Panics on out-of-bounds indexing if `out.len() < data.len()` while there is
/// at least one index to write.
#[inline]
pub fn rsx_scalar(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    // `first + period - 1` underflows for a zero period (a panic in debug
    // builds); a zero period is rejected by every validated entry point anyway.
    if period == 0 {
        return;
    }
    let start = first + period - 1;
    if start >= data.len() {
        return;
    }

    // Coefficients shared by every smoothing stage.
    let f18 = 3.0 / (period as f64 + 2.0);
    let f20 = 1.0 - f18;
    // Bar-counter limit (`period - 1`, never below 5) and the counter itself.
    let f88 = if period >= 6 {
        (period - 1) as f64
    } else {
        5.0
    };
    let mut f90 = 1.0;
    // Latched to 1.0 once the price moves while the counter is still <= f88.
    let mut f0 = 0.0;
    // Previous price, scaled by 100.
    let mut f8 = 100.0 * data[start];

    // Smoothing state: f28..f50 follow signed momentum, f58..f80 its absolute value.
    let mut f28 = 0.0;
    let mut f30 = 0.0;
    let mut f38 = 0.0;
    let mut f40 = 0.0;
    let mut f48 = 0.0;
    let mut f50 = 0.0;
    let mut f58 = 0.0;
    let mut f60 = 0.0;
    let mut f68 = 0.0;
    let mut f70 = 0.0;
    let mut f78 = 0.0;
    let mut f80 = 0.0;

    out[start] = f64::NAN;

    for i in (start + 1)..data.len() {
        f90 = if f88 <= f90 { f88 + 1.0 } else { f90 + 1.0 };

        let prev = f8;
        f8 = 100.0 * data[i];
        let v8 = f8 - prev;

        f28 = f20 * f28 + f18 * v8;
        f30 = f18 * f28 + f20 * f30;
        let v_c = f28 * 1.5 - f30 * 0.5;

        f38 = f20 * f38 + f18 * v_c;
        f40 = f18 * f38 + f20 * f40;
        let v10 = f38 * 1.5 - f40 * 0.5;

        f48 = f20 * f48 + f18 * v10;
        f50 = f18 * f48 + f20 * f50;
        let v14 = f48 * 1.5 - f50 * 0.5;

        let av = v8.abs();
        f58 = f20 * f58 + f18 * av;
        f60 = f18 * f58 + f20 * f60;
        let v18 = f58 * 1.5 - f60 * 0.5;

        f68 = f20 * f68 + f18 * v18;
        f70 = f18 * f68 + f20 * f70;
        let v1c = f68 * 1.5 - f70 * 0.5;

        f78 = f20 * f78 + f18 * v1c;
        f80 = f18 * f78 + f20 * f80;
        let v20_ = f78 * 1.5 - f80 * 0.5;

        if f88 >= f90 && f8 != prev {
            f0 = 1.0;
        }
        // No price movement during the whole ramp: restart the counter.
        if (f88 - f90).abs() < f64::EPSILON && f0 == 0.0 {
            f90 = 0.0;
        }

        if f88 < f90 && v20_ > 1e-10 {
            let mut v4 = (v14 / v20_ + 1.0) * 50.0;
            if v4 > 100.0 {
                v4 = 100.0;
            }
            if v4 < 0.0 {
                v4 = 0.0;
            }
            out[i] = v4;
        } else {
            out[i] = 50.0;
        }
    }
}
395
/// RSX kernel used for `Kernel::Avx2`.
///
/// It runs the same recurrence as `rsx_scalar`, but each filter update goes
/// through `f64::mul_add`. Results may therefore differ from the scalar kernel
/// in the last bits. The body contains no explicit SIMD intrinsics.
///
/// Writes `out[start]` (NaN) and `out[start + 1..data.len()]`, where
/// `start = first + period - 1`. Earlier indices are left untouched. Does
/// nothing when `period == 0` or `start >= data.len()`.
///
/// # Safety
/// The loop uses unchecked indexing into `data` and `out` below `data.len()`.
/// An up-front assertion guarantees `out.len() >= data.len()`, so no extra
/// caller obligations remain. The function stays `unsafe` for API compatibility.
///
/// # Panics
/// Panics if `out.len() < data.len()` when there is at least one index to write.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn rsx_avx2(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    let len = data.len();
    // A zero period would underflow `first + period - 1`.
    if period == 0 {
        return;
    }
    let start = first + period - 1;
    if start >= len {
        return;
    }
    // Turns what would be an out-of-bounds unchecked write into a panic.
    assert!(
        out.len() >= len,
        "rsx_avx2: output slice shorter than input ({} < {})",
        out.len(),
        len
    );

    let mut f0 = 0.0;
    // SAFETY: start < len (checked above).
    let mut f8 = 100.0 * *data.get_unchecked(start);
    let f18 = 3.0 / (period as f64 + 2.0);
    let f20 = 1.0 - f18;
    let mut f28 = 0.0;
    let mut f30 = 0.0;
    let mut f38 = 0.0;
    let mut f40 = 0.0;
    let mut f48 = 0.0;
    let mut f50 = 0.0;
    let mut f58 = 0.0;
    let mut f60 = 0.0;
    let mut f68 = 0.0;
    let mut f70 = 0.0;
    let mut f78 = 0.0;
    let mut f80 = 0.0;
    // Bar-counter limit (never below 5) and the counter itself.
    let f88 = if period >= 6 {
        (period - 1) as f64
    } else {
        5.0
    };
    let mut f90 = 1.0;

    // SAFETY: start < len <= out.len().
    *out.get_unchecked_mut(start) = f64::NAN;

    for i in (start + 1)..len {
        f90 = if f88 <= f90 { f88 + 1.0 } else { f90 + 1.0 };

        let prev = f8;
        // SAFETY: i < len.
        f8 = 100.0 * *data.get_unchecked(i);
        let v8 = f8 - prev;

        f28 = f18.mul_add(v8, f20 * f28);
        f30 = f18.mul_add(f28, f20 * f30);
        let v_c = 1.5f64.mul_add(f28, -0.5 * f30);

        f38 = f20.mul_add(f38, f18 * v_c);
        f40 = f18.mul_add(f38, f20 * f40);
        let v10 = 1.5f64.mul_add(f38, -0.5 * f40);

        f48 = f20.mul_add(f48, f18 * v10);
        f50 = f18.mul_add(f48, f20 * f50);
        let v14 = 1.5f64.mul_add(f48, -0.5 * f50);

        let av = v8.abs();
        f58 = f20.mul_add(f58, f18 * av);
        f60 = f18.mul_add(f58, f20 * f60);
        let v18 = 1.5f64.mul_add(f58, -0.5 * f60);

        f68 = f20.mul_add(f68, f18 * v18);
        f70 = f18.mul_add(f68, f20 * f70);
        let v1c = 1.5f64.mul_add(f68, -0.5 * f70);

        f78 = f20.mul_add(f78, f18 * v1c);
        f80 = f18.mul_add(f78, f20 * f80);
        let v20_ = 1.5f64.mul_add(f78, -0.5 * f80);

        if f88 >= f90 && f8 != prev {
            f0 = 1.0;
        }
        // No price movement during the whole ramp: restart the counter.
        if (f88 - f90).abs() < f64::EPSILON && f0 == 0.0 {
            f90 = 0.0;
        }

        let y = if f88 < f90 && v20_ > 1e-10 {
            let v4 = (v14 / v20_ + 1.0) * 50.0;
            v4.max(0.0).min(100.0)
        } else {
            50.0
        };

        // SAFETY: i < len <= out.len().
        *out.get_unchecked_mut(i) = y;
    }
}
479
/// `Kernel::Avx512` path for `period <= 32`; it shares `rsx_avx2`'s implementation.
///
/// # Safety
/// Same contract as `rsx_avx2`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn rsx_avx512_short(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    // SAFETY: arguments are forwarded unchanged; the caller upholds `rsx_avx2`'s contract.
    unsafe { rsx_avx2(data, period, first, out) }
}

/// `Kernel::Avx512` path for `period > 32`; it shares `rsx_avx2`'s implementation.
///
/// # Safety
/// Same contract as `rsx_avx2`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn rsx_avx512_long(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    // SAFETY: arguments are forwarded unchanged; the caller upholds `rsx_avx2`'s contract.
    unsafe { rsx_avx2(data, period, first, out) }
}
491
/// wasm32 `simd128` entry point for the scalar kernels.
///
/// It contains no SIMD-specific code and delegates to `rsx_scalar`.
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
#[inline]
unsafe fn rsx_simd128(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    rsx_scalar(data, period, first, out);
}
497
498#[derive(Debug, Clone)]
499pub struct RsxStream {
500    period: usize,
501
502    seen: usize,
503    init_done: bool,
504
505    alpha: f64,
506    beta: f64,
507    f88: f64,
508    f90: f64,
509
510    f0: f64,
511    f8: f64,
512
513    f28: f64,
514    f30: f64,
515    f38: f64,
516    f40: f64,
517    f48: f64,
518    f50: f64,
519    f58: f64,
520    f60: f64,
521    f68: f64,
522    f70: f64,
523    f78: f64,
524    f80: f64,
525}
526
527impl RsxStream {
528    pub fn try_new(params: RsxParams) -> Result<Self, RsxError> {
529        let period = params.period.unwrap_or(14);
530        if period == 0 {
531            return Err(RsxError::InvalidPeriod {
532                period,
533                data_len: 0,
534            });
535        }
536
537        let alpha = 3.0 / (period as f64 + 2.0);
538        let beta = 1.0 - alpha;
539
540        Ok(Self {
541            period,
542            seen: 0,
543            init_done: false,
544            alpha,
545            beta,
546            f88: if period >= 6 {
547                (period - 1) as f64
548            } else {
549                5.0
550            },
551            f90: 1.0,
552            f0: 0.0,
553            f8: 0.0,
554            f28: 0.0,
555            f30: 0.0,
556            f38: 0.0,
557            f40: 0.0,
558            f48: 0.0,
559            f50: 0.0,
560            f58: 0.0,
561            f60: 0.0,
562            f68: 0.0,
563            f70: 0.0,
564            f78: 0.0,
565            f80: 0.0,
566        })
567    }
568
569    #[inline(always)]
570    pub fn update(&mut self, value: f64) -> Option<f64> {
571        self.seen += 1;
572
573        if !self.init_done && self.seen < self.period {
574            return None;
575        }
576
577        if !self.init_done {
578            self.init_done = true;
579            self.f0 = 0.0;
580            self.f8 = 100.0 * value;
581            return Some(f64::NAN);
582        }
583
584        self.f90 = if self.f88 <= self.f90 {
585            self.f88 + 1.0
586        } else {
587            self.f90 + 1.0
588        };
589
590        let prev = self.f8;
591        self.f8 = 100.0 * value;
592        let v8 = self.f8 - prev;
593
594        self.f28 = self.beta * self.f28 + self.alpha * v8;
595        self.f30 = self.alpha * self.f28 + self.beta * self.f30;
596        let v_c = self.f28 * 1.5 - self.f30 * 0.5;
597
598        self.f38 = self.beta * self.f38 + self.alpha * v_c;
599        self.f40 = self.alpha * self.f38 + self.beta * self.f40;
600        let v10 = self.f38 * 1.5 - self.f40 * 0.5;
601
602        self.f48 = self.beta * self.f48 + self.alpha * v10;
603        self.f50 = self.alpha * self.f48 + self.beta * self.f50;
604        let v14 = self.f48 * 1.5 - self.f50 * 0.5;
605
606        let av = v8.abs();
607        self.f58 = self.beta * self.f58 + self.alpha * av;
608        self.f60 = self.alpha * self.f58 + self.beta * self.f60;
609        let v18 = self.f58 * 1.5 - self.f60 * 0.5;
610
611        self.f68 = self.beta * self.f68 + self.alpha * v18;
612        self.f70 = self.alpha * self.f68 + self.beta * self.f70;
613        let v1c = self.f68 * 1.5 - self.f70 * 0.5;
614
615        self.f78 = self.beta * self.f78 + self.alpha * v1c;
616        self.f80 = self.alpha * self.f78 + self.beta * self.f80;
617        let v20_ = self.f78 * 1.5 - self.f80 * 0.5;
618
619        if self.f88 >= self.f90 && self.f8 != prev {
620            self.f0 = 1.0;
621        }
622        if (self.f88 - self.f90).abs() < f64::EPSILON && self.f0 == 0.0 {
623            self.f90 = 0.0;
624        }
625
626        let y = if self.f88 < self.f90 && v20_ > 1e-10 {
627            let v4 = (v14 / v20_ + 1.0) * 50.0;
628            v4.max(0.0).min(100.0)
629        } else {
630            50.0
631        };
632
633        Some(y)
634    }
635}
636
/// Period sweep for batch RSX, given as `(start, end, step)`.
#[derive(Clone, Debug)]
pub struct RsxBatchRange {
    pub period: (usize, usize, usize),
}

impl Default for RsxBatchRange {
    /// Periods 14 through 263 inclusive, step 1.
    fn default() -> Self {
        let (start, end, step) = (14, 263, 1);
        RsxBatchRange {
            period: (start, end, step),
        }
    }
}
649
650#[derive(Clone, Debug, Default)]
651pub struct RsxBatchBuilder {
652    range: RsxBatchRange,
653    kernel: Kernel,
654}
655
656impl RsxBatchBuilder {
657    pub fn new() -> Self {
658        Self::default()
659    }
660    pub fn kernel(mut self, k: Kernel) -> Self {
661        self.kernel = k;
662        self
663    }
664
665    #[inline]
666    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
667        self.range.period = (start, end, step);
668        self
669    }
670    #[inline]
671    pub fn period_static(mut self, p: usize) -> Self {
672        self.range.period = (p, p, 0);
673        self
674    }
675
676    pub fn apply_slice(self, data: &[f64]) -> Result<RsxBatchOutput, RsxError> {
677        rsx_batch_with_kernel(data, &self.range, self.kernel)
678    }
679
680    pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<RsxBatchOutput, RsxError> {
681        RsxBatchBuilder::new().kernel(k).apply_slice(data)
682    }
683
684    pub fn apply_candles(self, c: &Candles, src: &str) -> Result<RsxBatchOutput, RsxError> {
685        let slice = source_type(c, src);
686        self.apply_slice(slice)
687    }
688
689    pub fn with_default_candles(c: &Candles) -> Result<RsxBatchOutput, RsxError> {
690        RsxBatchBuilder::new()
691            .kernel(Kernel::Auto)
692            .apply_candles(c, "close")
693    }
694}
695
696pub fn rsx_batch_with_kernel(
697    data: &[f64],
698    sweep: &RsxBatchRange,
699    k: Kernel,
700) -> Result<RsxBatchOutput, RsxError> {
701    let kernel = match k {
702        Kernel::Auto => detect_best_batch_kernel(),
703        other if other.is_batch() => other,
704        other => return Err(RsxError::InvalidKernelForBatch(other)),
705    };
706
707    let simd = match kernel {
708        Kernel::Avx512Batch => Kernel::Avx512,
709        Kernel::Avx2Batch => Kernel::Avx2,
710        Kernel::ScalarBatch => Kernel::Scalar,
711        _ => unreachable!(),
712    };
713    rsx_batch_par_slice(data, sweep, simd)
714}
715
716#[derive(Clone, Debug)]
717pub struct RsxBatchOutput {
718    pub values: Vec<f64>,
719    pub combos: Vec<RsxParams>,
720    pub rows: usize,
721    pub cols: usize,
722}
723impl RsxBatchOutput {
724    pub fn row_for_params(&self, p: &RsxParams) -> Option<usize> {
725        self.combos
726            .iter()
727            .position(|c| c.period.unwrap_or(14) == p.period.unwrap_or(14))
728    }
729
730    pub fn values_for(&self, p: &RsxParams) -> Option<&[f64]> {
731        self.row_for_params(p).map(|row| {
732            let start = row * self.cols;
733            &self.values[start..start + self.cols]
734        })
735    }
736}
737
738#[inline(always)]
739fn expand_grid(r: &RsxBatchRange) -> Result<Vec<RsxParams>, RsxError> {
740    fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, RsxError> {
741        if step == 0 || start == end {
742            return Ok(vec![start]);
743        }
744        if start < end {
745            let st = step.max(1);
746            let v: Vec<usize> = (start..=end).step_by(st).collect();
747            if v.is_empty() {
748                return Err(RsxError::InvalidRange {
749                    start: start.to_string(),
750                    end: end.to_string(),
751                    step: step.to_string(),
752                });
753            }
754            return Ok(v);
755        }
756
757        let mut v = Vec::new();
758        let mut x = start as isize;
759        let end_i = end as isize;
760        let st = (step as isize).max(1);
761        while x >= end_i {
762            v.push(x as usize);
763            x -= st;
764        }
765        if v.is_empty() {
766            return Err(RsxError::InvalidRange {
767                start: start.to_string(),
768                end: end.to_string(),
769                step: step.to_string(),
770            });
771        }
772        Ok(v)
773    }
774    let periods = axis_usize(r.period)?;
775    if periods.is_empty() {
776        return Err(RsxError::InvalidRange {
777            start: r.period.0.to_string(),
778            end: r.period.1.to_string(),
779            step: r.period.2.to_string(),
780        });
781    }
782    let mut out = Vec::with_capacity(periods.len());
783    for &p in &periods {
784        out.push(RsxParams { period: Some(p) });
785    }
786    Ok(out)
787}
788
789#[inline(always)]
790pub fn rsx_batch_slice(
791    data: &[f64],
792    sweep: &RsxBatchRange,
793    kern: Kernel,
794) -> Result<RsxBatchOutput, RsxError> {
795    rsx_batch_inner(data, sweep, kern, false)
796}
797
798#[inline(always)]
799pub fn rsx_batch_par_slice(
800    data: &[f64],
801    sweep: &RsxBatchRange,
802    kern: Kernel,
803) -> Result<RsxBatchOutput, RsxError> {
804    rsx_batch_inner(data, sweep, kern, true)
805}
806
807#[inline(always)]
808fn rsx_batch_inner(
809    data: &[f64],
810    sweep: &RsxBatchRange,
811    kern: Kernel,
812    parallel: bool,
813) -> Result<RsxBatchOutput, RsxError> {
814    let combos = expand_grid(sweep)?;
815    if combos.is_empty() {
816        return Err(RsxError::InvalidRange {
817            start: "range".into(),
818            end: "range".into(),
819            step: "empty".into(),
820        });
821    }
822
823    let first = data
824        .iter()
825        .position(|x| !x.is_nan())
826        .ok_or(RsxError::AllValuesNaN)?;
827    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
828    if data.len() - first < max_p {
829        return Err(RsxError::NotEnoughValidData {
830            needed: max_p,
831            valid: data.len() - first,
832        });
833    }
834
835    let rows = combos.len();
836    let cols = data.len();
837
838    let _ = rows
839        .checked_mul(cols)
840        .ok_or_else(|| RsxError::InvalidRange {
841            start: rows.to_string(),
842            end: cols.to_string(),
843            step: "rows*cols".into(),
844        })?;
845
846    let actual_kernel = match kern {
847        Kernel::Auto => Kernel::Scalar,
848        k => k,
849    };
850
851    let mut buf_mu = make_uninit_matrix(rows, cols);
852
853    let warmup_periods: Vec<usize> = combos
854        .iter()
855        .map(|c| first + c.period.unwrap() - 1)
856        .collect();
857
858    init_matrix_prefixes(&mut buf_mu, cols, &warmup_periods);
859
860    let mut buf_guard = std::mem::ManuallyDrop::new(buf_mu);
861    let values_slice: &mut [f64] =
862        unsafe { std::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, rows * cols) };
863
864    let do_row = |row: usize, out_row: &mut [f64]| unsafe {
865        let period = combos[row].period.unwrap();
866        match actual_kernel {
867            Kernel::Scalar | Kernel::ScalarBatch => rsx_row_scalar(data, first, period, out_row),
868            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
869            Kernel::Avx2 | Kernel::Avx2Batch => rsx_row_avx2(data, first, period, out_row),
870            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
871            Kernel::Avx512 | Kernel::Avx512Batch => rsx_row_avx512(data, first, period, out_row),
872            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
873            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
874                rsx_row_scalar(data, first, period, out_row)
875            }
876            Kernel::Auto => unreachable!("Auto kernel should have been resolved"),
877        }
878    };
879
880    if parallel {
881        #[cfg(not(target_arch = "wasm32"))]
882        {
883            values_slice
884                .par_chunks_mut(cols)
885                .enumerate()
886                .for_each(|(row, slice)| do_row(row, slice));
887        }
888
889        #[cfg(target_arch = "wasm32")]
890        {
891            for (row, slice) in values_slice.chunks_mut(cols).enumerate() {
892                do_row(row, slice);
893            }
894        }
895    } else {
896        for (row, slice) in values_slice.chunks_mut(cols).enumerate() {
897            do_row(row, slice);
898        }
899    }
900
901    let values = unsafe {
902        Vec::from_raw_parts(buf_guard.as_mut_ptr() as *mut f64, rows * cols, rows * cols)
903    };
904
905    Ok(RsxBatchOutput {
906        values,
907        combos,
908        rows,
909        cols,
910    })
911}
912
913#[inline(always)]
914unsafe fn rsx_row_scalar(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
915    rsx_scalar(data, period, first, out);
916}
917
918#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
919#[inline(always)]
920unsafe fn rsx_row_avx2(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
921    rsx_avx2(data, period, first, out);
922}
923
924#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
925#[inline(always)]
926unsafe fn rsx_row_avx512(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
927    rsx_avx512(data, period, first, out);
928}
929
930#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
931#[inline(always)]
932unsafe fn rsx_row_avx512_short(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
933    rsx_avx512_short(data, period, first, out);
934}
935
936#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
937#[inline(always)]
938unsafe fn rsx_row_avx512_long(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
939    rsx_avx512_long(data, period, first, out);
940}
941
942#[cfg(test)]
943mod tests {
944    use super::*;
945    use crate::skip_if_unsupported;
946    use crate::utilities::data_loader::read_candles_from_csv;
947
948    #[test]
949    fn test_rsx_into_matches_api() -> Result<(), Box<dyn std::error::Error>> {
950        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
951        let candles = read_candles_from_csv(file)?;
952
953        let input = RsxInput::with_default_candles(&candles);
954
955        let RsxOutput { values: expected } = rsx(&input)?;
956
957        let mut out = vec![0.0; candles.close.len()];
958
959        #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
960        {
961            rsx_into(&input, &mut out)?;
962        }
963        #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
964        {
965            rsx_into_slice(&mut out, &input, Kernel::Auto)?;
966        }
967
968        assert_eq!(expected.len(), out.len());
969
970        for i in 0..out.len() {
971            let a = expected[i];
972            let b = out[i];
973            if a.is_nan() || b.is_nan() {
974                assert!(
975                    a.is_nan() && b.is_nan(),
976                    "NaN mismatch at {i}: {a:?} vs {b:?}"
977                );
978            } else {
979                assert!(a == b, "Mismatch at {i}: {a} vs {b}");
980            }
981        }
982
983        Ok(())
984    }
985
986    fn check_rsx_partial_params(
987        test_name: &str,
988        kernel: Kernel,
989    ) -> Result<(), Box<dyn std::error::Error>> {
990        skip_if_unsupported!(kernel, test_name);
991        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
992        let candles = read_candles_from_csv(file_path)?;
993
994        let default_params = RsxParams { period: None };
995        let input_default = RsxInput::from_candles(&candles, "close", default_params);
996        let output_default = rsx_with_kernel(&input_default, kernel)?;
997        assert_eq!(output_default.values.len(), candles.close.len());
998
999        let params_period_10 = RsxParams { period: Some(10) };
1000        let input_period_10 = RsxInput::from_candles(&candles, "hl2", params_period_10);
1001        let output_period_10 = rsx_with_kernel(&input_period_10, kernel)?;
1002        assert_eq!(output_period_10.values.len(), candles.close.len());
1003
1004        let params_custom = RsxParams { period: Some(20) };
1005        let input_custom = RsxInput::from_candles(&candles, "hlc3", params_custom);
1006        let output_custom = rsx_with_kernel(&input_custom, kernel)?;
1007        assert_eq!(output_custom.values.len(), candles.close.len());
1008
1009        Ok(())
1010    }
1011
1012    fn check_rsx_accuracy(
1013        test_name: &str,
1014        kernel: Kernel,
1015    ) -> Result<(), Box<dyn std::error::Error>> {
1016        skip_if_unsupported!(kernel, test_name);
1017        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1018        let candles = read_candles_from_csv(file_path)?;
1019        let params = RsxParams { period: Some(14) };
1020        let input = RsxInput::from_candles(&candles, "close", params);
1021        let rsx_result = rsx_with_kernel(&input, kernel)?;
1022
1023        let expected_last_five_rsx = [
1024            46.11486311289701,
1025            46.88048640321688,
1026            47.174443049619995,
1027            47.48751360654475,
1028            46.552886446171684,
1029        ];
1030        let start_index = rsx_result.values.len() - 5;
1031        let result_last_five_rsx = &rsx_result.values[start_index..];
1032        for (i, &value) in result_last_five_rsx.iter().enumerate() {
1033            let expected_value = expected_last_five_rsx[i];
1034            assert!(
1035                (value - expected_value).abs() < 1e-1,
1036                "RSX mismatch at index {}: expected {}, got {}",
1037                i,
1038                expected_value,
1039                value
1040            );
1041        }
1042        Ok(())
1043    }
1044
1045    fn check_rsx_default_candles(
1046        test_name: &str,
1047        kernel: Kernel,
1048    ) -> Result<(), Box<dyn std::error::Error>> {
1049        skip_if_unsupported!(kernel, test_name);
1050        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1051        let candles = read_candles_from_csv(file_path)?;
1052        let input = RsxInput::with_default_candles(&candles);
1053        match input.data {
1054            RsxData::Candles { source, .. } => assert_eq!(source, "close"),
1055            _ => panic!("Expected RsxData::Candles"),
1056        }
1057        let output = rsx_with_kernel(&input, kernel)?;
1058        assert_eq!(output.values.len(), candles.close.len());
1059        Ok(())
1060    }
1061
1062    fn check_rsx_zero_period(
1063        test_name: &str,
1064        kernel: Kernel,
1065    ) -> Result<(), Box<dyn std::error::Error>> {
1066        skip_if_unsupported!(kernel, test_name);
1067        let input_data = [10.0, 20.0, 30.0];
1068        let params = RsxParams { period: Some(0) };
1069        let input = RsxInput::from_slice(&input_data, params);
1070        let res = rsx_with_kernel(&input, kernel);
1071        assert!(
1072            res.is_err(),
1073            "[{}] RSX should fail with zero period",
1074            test_name
1075        );
1076        Ok(())
1077    }
1078
1079    fn check_rsx_period_exceeds_length(
1080        test_name: &str,
1081        kernel: Kernel,
1082    ) -> Result<(), Box<dyn std::error::Error>> {
1083        skip_if_unsupported!(kernel, test_name);
1084        let data_small = [10.0, 20.0, 30.0];
1085        let params = RsxParams { period: Some(10) };
1086        let input = RsxInput::from_slice(&data_small, params);
1087        let res = rsx_with_kernel(&input, kernel);
1088        assert!(
1089            res.is_err(),
1090            "[{}] RSX should fail with period exceeding length",
1091            test_name
1092        );
1093        Ok(())
1094    }
1095
1096    fn check_rsx_very_small_dataset(
1097        test_name: &str,
1098        kernel: Kernel,
1099    ) -> Result<(), Box<dyn std::error::Error>> {
1100        skip_if_unsupported!(kernel, test_name);
1101        let single_point = [42.0];
1102        let params = RsxParams { period: Some(14) };
1103        let input = RsxInput::from_slice(&single_point, params);
1104        let res = rsx_with_kernel(&input, kernel);
1105        assert!(
1106            res.is_err(),
1107            "[{}] RSX should fail with insufficient data",
1108            test_name
1109        );
1110        Ok(())
1111    }
1112
1113    fn check_rsx_reinput(
1114        test_name: &str,
1115        kernel: Kernel,
1116    ) -> Result<(), Box<dyn std::error::Error>> {
1117        skip_if_unsupported!(kernel, test_name);
1118        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1119        let candles = read_candles_from_csv(file_path)?;
1120
1121        let first_params = RsxParams { period: Some(14) };
1122        let first_input = RsxInput::from_candles(&candles, "close", first_params);
1123        let first_result = rsx_with_kernel(&first_input, kernel)?;
1124
1125        let second_params = RsxParams { period: Some(14) };
1126        let second_input = RsxInput::from_slice(&first_result.values, second_params);
1127        let second_result = rsx_with_kernel(&second_input, kernel)?;
1128
1129        assert_eq!(second_result.values.len(), first_result.values.len());
1130        Ok(())
1131    }
1132
1133    fn check_rsx_nan_handling(
1134        test_name: &str,
1135        kernel: Kernel,
1136    ) -> Result<(), Box<dyn std::error::Error>> {
1137        skip_if_unsupported!(kernel, test_name);
1138        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1139        let candles = read_candles_from_csv(file_path)?;
1140
1141        let input = RsxInput::from_candles(&candles, "close", RsxParams { period: Some(14) });
1142        let res = rsx_with_kernel(&input, kernel)?;
1143        assert_eq!(res.values.len(), candles.close.len());
1144        if res.values.len() > 50 {
1145            for (i, &val) in res.values[50..].iter().enumerate() {
1146                assert!(
1147                    !val.is_nan(),
1148                    "[{}] Found unexpected NaN at out-index {}",
1149                    test_name,
1150                    50 + i
1151                );
1152            }
1153        }
1154        Ok(())
1155    }
1156
1157    fn check_rsx_streaming(
1158        test_name: &str,
1159        kernel: Kernel,
1160    ) -> Result<(), Box<dyn std::error::Error>> {
1161        skip_if_unsupported!(kernel, test_name);
1162        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1163        let candles = read_candles_from_csv(file_path)?;
1164
1165        let period = 14;
1166
1167        let input = RsxInput::from_candles(
1168            &candles,
1169            "close",
1170            RsxParams {
1171                period: Some(period),
1172            },
1173        );
1174        let batch_output = rsx_with_kernel(&input, kernel)?.values;
1175
1176        let mut stream = RsxStream::try_new(RsxParams {
1177            period: Some(period),
1178        })?;
1179
1180        let mut stream_values = Vec::with_capacity(candles.close.len());
1181        for &price in &candles.close {
1182            match stream.update(price) {
1183                Some(rsx_val) => stream_values.push(rsx_val),
1184                None => stream_values.push(f64::NAN),
1185            }
1186        }
1187
1188        assert_eq!(batch_output.len(), stream_values.len());
1189        for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
1190            if b.is_nan() && s.is_nan() {
1191                continue;
1192            }
1193            let diff = (b - s).abs();
1194            assert!(
1195                diff < 1e-9,
1196                "[{}] RSX streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
1197                test_name,
1198                i,
1199                b,
1200                s,
1201                diff
1202            );
1203        }
1204        Ok(())
1205    }
1206
1207    macro_rules! generate_all_rsx_tests {
1208        ($($test_fn:ident),*) => {
1209            paste::paste! {
1210                $(
1211                    #[test]
1212                    fn [<$test_fn _scalar_f64>]() {
1213                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1214                    }
1215                )*
1216                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1217                $(
1218                    #[test]
1219                    fn [<$test_fn _avx2_f64>]() {
1220                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1221                    }
1222                    #[test]
1223                    fn [<$test_fn _avx512_f64>]() {
1224                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1225                    }
1226                )*
1227                #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
1228                $(
1229                    #[test]
1230                    fn [<$test_fn _simd128_f64>]() {
1231                        let _ = $test_fn(stringify!([<$test_fn _simd128_f64>]), Kernel::Scalar);
1232                    }
1233                )*
1234            }
1235        }
1236    }
1237
1238    #[cfg(debug_assertions)]
1239    fn check_rsx_no_poison(
1240        test_name: &str,
1241        kernel: Kernel,
1242    ) -> Result<(), Box<dyn std::error::Error>> {
1243        skip_if_unsupported!(kernel, test_name);
1244
1245        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1246        let candles = read_candles_from_csv(file_path)?;
1247
1248        let test_params = vec![
1249            RsxParams::default(),
1250            RsxParams { period: Some(2) },
1251            RsxParams { period: Some(7) },
1252            RsxParams { period: Some(21) },
1253            RsxParams { period: Some(50) },
1254            RsxParams { period: Some(100) },
1255            RsxParams { period: Some(200) },
1256            RsxParams { period: Some(5) },
1257            RsxParams { period: Some(10) },
1258            RsxParams { period: Some(20) },
1259            RsxParams { period: Some(30) },
1260            RsxParams { period: Some(40) },
1261        ];
1262
1263        for (param_idx, params) in test_params.iter().enumerate() {
1264            let input = RsxInput::from_candles(&candles, "close", params.clone());
1265            let output = rsx_with_kernel(&input, kernel)?;
1266
1267            for (i, &val) in output.values.iter().enumerate() {
1268                if val.is_nan() {
1269                    continue;
1270                }
1271
1272                let bits = val.to_bits();
1273
1274                if bits == 0x11111111_11111111 {
1275                    panic!(
1276                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1277						 with params: period={} (param set {})",
1278                        test_name,
1279                        val,
1280                        bits,
1281                        i,
1282                        params.period.unwrap_or(14),
1283                        param_idx
1284                    );
1285                }
1286
1287                if bits == 0x22222222_22222222 {
1288                    panic!(
1289                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1290						 with params: period={} (param set {})",
1291                        test_name,
1292                        val,
1293                        bits,
1294                        i,
1295                        params.period.unwrap_or(14),
1296                        param_idx
1297                    );
1298                }
1299
1300                if bits == 0x33333333_33333333 {
1301                    panic!(
1302                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1303						 with params: period={} (param set {})",
1304                        test_name,
1305                        val,
1306                        bits,
1307                        i,
1308                        params.period.unwrap_or(14),
1309                        param_idx
1310                    );
1311                }
1312            }
1313        }
1314
1315        Ok(())
1316    }
1317
    // Release-build stand-in: the poison scan is only compiled under `debug_assertions`,
    // so this variant always succeeds. It exists so the unconditional test registration
    // of `check_rsx_no_poison` still resolves in release builds.
    #[cfg(not(debug_assertions))]
    fn check_rsx_no_poison(
        _test_name: &str,
        _kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        Ok(())
    }
1325
    // Property test over random finite series (length in `period..400`, period in 2..=64).
    //
    // Checked properties:
    // - outputs lie in [0, 100];
    // - the first `period` outputs are NaN and `out[period]` is not;
    // - constant input gives 50.0;
    // - strictly monotone input stays on the matching side of 50 after `period + 10`;
    // - `kernel` agrees with the scalar reference to 1e-9 or 4 ULP;
    // - repeated runs are bit-identical.
    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_rsx_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        // Pick the period first so the generated series is always at least `period` long.
        let strat = (2usize..=64).prop_flat_map(|period| {
            (
                prop::collection::vec(
                    (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
                    period..400,
                ),
                Just(period),
            )
        });

        proptest::test_runner::TestRunner::default()
            .run(&strat, |(data, period)| {
                let params = RsxParams {
                    period: Some(period),
                };
                let input = RsxInput::from_slice(&data, params);

                // `out` is the kernel under test; `ref_out` is the scalar reference.
                let RsxOutput { values: out } = rsx_with_kernel(&input, kernel).unwrap();
                let RsxOutput { values: ref_out } =
                    rsx_with_kernel(&input, Kernel::Scalar).unwrap();

                // Bounds: every non-NaN value past the warm-up is an oscillator reading in [0, 100].
                for i in period..data.len() {
                    let y = out[i];
                    if !y.is_nan() {
                        prop_assert!(
                            y >= 0.0 && y <= 100.0,
                            "idx {i}: RSX value {y} outside [0, 100] bounds"
                        );
                    }
                }

                // Warm-up: the first `period` outputs must be NaN.
                for i in 0..period.min(data.len()) {
                    prop_assert!(
                        out[i].is_nan(),
                        "idx {i}: Expected NaN during warmup, got {}",
                        out[i]
                    );
                }

                // The first post-warm-up output must be a real value.
                if data.len() > period {
                    prop_assert!(
                        !out[period].is_nan(),
                        "idx {}: Expected valid RSX value after warmup, got NaN",
                        period
                    );
                }

                // Flat input has no momentum either way, so RSX must sit at the 50 midpoint.
                if data.windows(2).all(|w| (w[0] - w[1]).abs() < f64::EPSILON)
                    && data.len() > period
                {
                    for i in period..data.len() {
                        prop_assert!(
                            (out[i] - 50.0).abs() <= 1e-9,
                            "idx {i}: Constant data should produce RSX=50.0, got {}",
                            out[i]
                        );
                    }
                }

                // Monotone rising input: after 10 extra settling bars RSX must be >= 50.
                let is_strictly_increasing = data.windows(2).all(|w| w[1] > w[0]);
                if is_strictly_increasing && data.len() >= period + 10 {
                    for i in (period + 10)..data.len() {
                        prop_assert!(
                            out[i] > 50.0 || (out[i] - 50.0).abs() < 1e-9,
                            "idx {i}: Strictly increasing prices should produce RSX > 50, got {}",
                            out[i]
                        );
                    }
                }

                // Monotone falling input: after 10 extra settling bars RSX must be <= 50.
                let is_strictly_decreasing = data.windows(2).all(|w| w[1] < w[0]);
                if is_strictly_decreasing && data.len() >= period + 10 {
                    for i in (period + 10)..data.len() {
                        prop_assert!(
                            out[i] < 50.0 || (out[i] - 50.0).abs() < 1e-9,
                            "idx {i}: Strictly decreasing prices should produce RSX < 50, got {}",
                            out[i]
                        );
                    }
                }

                // Cross-kernel agreement. Non-finite values must match bit-for-bit; finite
                // values may differ by 1e-9 absolute or 4 ULP.
                for i in 0..data.len() {
                    let y = out[i];
                    let r = ref_out[i];

                    if !y.is_finite() || !r.is_finite() {
                        prop_assert!(
                            y.to_bits() == r.to_bits(),
                            "finite/NaN mismatch idx {i}: {y} vs {r}"
                        );
                        continue;
                    }

                    // Bit-pattern distance approximates ULP distance for same-sign finite values.
                    let y_bits = y.to_bits();
                    let r_bits = r.to_bits();
                    let ulp_diff: u64 = y_bits.abs_diff(r_bits);

                    prop_assert!(
                        (y - r).abs() <= 1e-9 || ulp_diff <= 4,
                        "kernel mismatch idx {i}: {y} vs {r} (ULP={ulp_diff})"
                    );
                }

                // Determinism: a second run with the same kernel is bit-identical.
                let RsxOutput { values: out2 } = rsx_with_kernel(&input, kernel).unwrap();
                for i in 0..data.len() {
                    prop_assert!(
                        out[i].to_bits() == out2[i].to_bits(),
                        "determinism failed at idx {i}: first={}, second={}",
                        out[i],
                        out2[i]
                    );
                }

                Ok(())
            })
            .unwrap();

        Ok(())
    }
1454
    // Register per-kernel `#[test]` wrappers for every deterministic RSX check.
    generate_all_rsx_tests!(
        check_rsx_partial_params,
        check_rsx_accuracy,
        check_rsx_default_candles,
        check_rsx_zero_period,
        check_rsx_period_exceeds_length,
        check_rsx_very_small_dataset,
        check_rsx_reinput,
        check_rsx_nan_handling,
        check_rsx_streaming,
        check_rsx_no_poison
    );

    // The property test is only registered when the `proptest` feature is enabled.
    #[cfg(feature = "proptest")]
    generate_all_rsx_tests!(check_rsx_property);
1470
1471    fn check_batch_default_row(
1472        test: &str,
1473        kernel: Kernel,
1474    ) -> Result<(), Box<dyn std::error::Error>> {
1475        skip_if_unsupported!(kernel, test);
1476
1477        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1478        let c = read_candles_from_csv(file)?;
1479
1480        let output = RsxBatchBuilder::new()
1481            .kernel(kernel)
1482            .apply_candles(&c, "close")?;
1483
1484        let def = RsxParams::default();
1485        let row = output.values_for(&def).expect("default row missing");
1486
1487        assert_eq!(row.len(), c.close.len());
1488
1489        let expected = [
1490            46.11486311289701,
1491            46.88048640321688,
1492            47.174443049619995,
1493            47.48751360654475,
1494            46.552886446171684,
1495        ];
1496        let start = row.len() - 5;
1497        for (i, &v) in row[start..].iter().enumerate() {
1498            assert!(
1499                (v - expected[i]).abs() < 1e-1,
1500                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
1501            );
1502        }
1503        Ok(())
1504    }
1505
1506    macro_rules! gen_batch_tests {
1507        ($fn_name:ident) => {
1508            paste::paste! {
1509                #[test] fn [<$fn_name _scalar>]()      {
1510                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
1511                }
1512                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1513                #[test] fn [<$fn_name _avx2>]()        {
1514                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
1515                }
1516                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1517                #[test] fn [<$fn_name _avx512>]()      {
1518                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
1519                }
1520                #[test] fn [<$fn_name _auto_detect>]() {
1521                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
1522                }
1523            }
1524        };
1525    }
1526    #[cfg(debug_assertions)]
1527    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn std::error::Error>> {
1528        skip_if_unsupported!(kernel, test);
1529
1530        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1531        let c = read_candles_from_csv(file)?;
1532
1533        let test_configs = vec![
1534            (2, 10, 2),
1535            (5, 25, 5),
1536            (30, 60, 15),
1537            (2, 5, 1),
1538            (10, 100, 10),
1539            (14, 28, 7),
1540            (3, 9, 3),
1541            (50, 150, 25),
1542        ];
1543
1544        for (cfg_idx, &(period_start, period_end, period_step)) in test_configs.iter().enumerate() {
1545            let output = RsxBatchBuilder::new()
1546                .kernel(kernel)
1547                .period_range(period_start, period_end, period_step)
1548                .apply_candles(&c, "close")?;
1549
1550            for (idx, &val) in output.values.iter().enumerate() {
1551                if val.is_nan() {
1552                    continue;
1553                }
1554
1555                let bits = val.to_bits();
1556                let row = idx / output.cols;
1557                let col = idx % output.cols;
1558                let combo = &output.combos[row];
1559
1560                if bits == 0x11111111_11111111 {
1561                    panic!(
1562                        "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
1563						 at row {} col {} (flat index {}) with params: period={}",
1564                        test,
1565                        cfg_idx,
1566                        val,
1567                        bits,
1568                        row,
1569                        col,
1570                        idx,
1571                        combo.period.unwrap_or(14)
1572                    );
1573                }
1574
1575                if bits == 0x22222222_22222222 {
1576                    panic!(
1577                        "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
1578						 at row {} col {} (flat index {}) with params: period={}",
1579                        test,
1580                        cfg_idx,
1581                        val,
1582                        bits,
1583                        row,
1584                        col,
1585                        idx,
1586                        combo.period.unwrap_or(14)
1587                    );
1588                }
1589
1590                if bits == 0x33333333_33333333 {
1591                    panic!(
1592                        "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
1593						 at row {} col {} (flat index {}) with params: period={}",
1594                        test,
1595                        cfg_idx,
1596                        val,
1597                        bits,
1598                        row,
1599                        col,
1600                        idx,
1601                        combo.period.unwrap_or(14)
1602                    );
1603                }
1604            }
1605        }
1606
1607        Ok(())
1608    }
1609
    // Release-build stand-in: the batch poison scan is only compiled under
    // `debug_assertions`, so this variant always succeeds. It keeps the unconditional
    // `gen_batch_tests!(check_batch_no_poison)` registration valid in release builds.
    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(
        _test: &str,
        _kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        Ok(())
    }
1617
    // Register each batch check for the scalar, AVX (when available) and auto-detect kernels.
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
1620}
1621
/// Python binding `rsx(data, period, kernel=None)`.
///
/// Computes RSX over a contiguous 1-D float64 array and returns a new array of the same
/// length. `kernel` is checked by `validate_kernel`. Indicator errors are raised as
/// `ValueError`. The computation runs with the GIL released.
#[cfg(feature = "python")]
#[pyfunction(name = "rsx")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn rsx_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let prices = data.as_slice()?;
    let chosen_kernel = validate_kernel(kernel, false)?;

    let input = RsxInput::from_slice(
        prices,
        RsxParams {
            period: Some(period),
        },
    );

    let computed = py.allow_threads(|| match rsx_with_kernel(&input, chosen_kernel) {
        Ok(output) => Ok(output.values),
        Err(err) => Err(err),
    });

    match computed {
        Ok(values) => Ok(values.into_pyarray(py)),
        Err(err) => Err(PyValueError::new_err(err.to_string())),
    }
}
1647
#[cfg(feature = "python")]
#[pyclass(name = "RsxStream")]
/// Python-visible wrapper (exposed as `RsxStream`) around the native streaming RSX state.
pub struct RsxStreamPy {
    // Native streaming calculator that every Python `update` call is forwarded to.
    stream: RsxStream,
}
1653
#[cfg(feature = "python")]
#[pymethods]
impl RsxStreamPy {
    /// Builds a streaming RSX calculator for `period`.
    ///
    /// Raises `ValueError` carrying the message of any error from `RsxStream::try_new`.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        RsxStream::try_new(RsxParams {
            period: Some(period),
        })
        .map(|stream| RsxStreamPy { stream })
        .map_err(|err| PyValueError::new_err(err.to_string()))
    }

    /// Pushes one sample and returns the stream's result.
    ///
    /// Returns `None` whenever the underlying stream yields no value (presumably while it
    /// is still warming up).
    fn update(&mut self, value: f64) -> Option<f64> {
        let RsxStreamPy { stream } = self;
        stream.update(value)
    }
}
1671
/// Python binding `rsx_batch(data, period_range, kernel=None)`.
///
/// Sweeps `period` over `(start, end, step)` and returns a dict with:
/// - `"values"`: a `(rows, cols)` float64 matrix, one row per period;
/// - `"periods"`: the matching uint64 period array.
///
/// Errors (invalid range, size overflow, indicator failures) are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "rsx_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn rsx_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let prices = data.as_slice()?;
    let requested = validate_kernel(kernel, true)?;

    let sweep = RsxBatchRange {
        period: period_range,
    };

    let rows = expand_grid(&sweep)
        .map_err(|e| PyValueError::new_err(e.to_string()))?
        .len();
    let cols = prices.len();
    let total = match rows.checked_mul(cols) {
        Some(n) => n,
        None => return Err(PyValueError::new_err("rows*cols overflow")),
    };

    // SAFETY: the array is created uninitialised. `rsx_batch_inner_into` initialises each
    // row's prefix and runs a kernel over every row before the array is handed to Python.
    // On error the array is dropped without being exposed.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            let batch_kernel = if matches!(requested, Kernel::Auto) {
                detect_best_batch_kernel()
            } else {
                requested
            };

            // The inner routine dispatches on per-row (non-batch) kernel variants.
            let row_kernel = match batch_kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                other => other,
            };

            rsx_batch_inner_into(prices, &sweep, row_kernel, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;

    let periods: Vec<u64> = combos.iter().map(|p| p.period.unwrap() as u64).collect();
    dict.set_item("periods", periods.into_pyarray(py))?;

    Ok(dict)
}
1733
/// Fills a caller-provided buffer with RSX values for every period in `sweep`.
///
/// Shared by the Python and WASM batch bindings. `out` is treated as a row-major
/// `rows * cols` matrix:
/// - one row per combination returned by `expand_grid(sweep)`;
/// - one column per element of `data`.
///
/// Each row's first `first + period - 1` cells (`first` = index of the first non-NaN
/// input) are set by `init_matrix_prefixes` (presumably to NaN). The row itself is handed
/// to the selected kernel.
///
/// `Kernel::Auto` is resolved to `Kernel::Scalar` here; callers wanting SIMD must resolve
/// it themselves. On wasm32 with `simd128`, the scalar arm dispatches to `rsx_simd128`.
/// When `parallel` is true rows are processed with rayon, except on wasm32 where rows are
/// always processed sequentially.
///
/// # Errors
/// - Any error from `expand_grid`.
/// - `InvalidRange` if the sweep expands to no combinations.
/// - `EmptyInputData` for empty `data`.
/// - `AllValuesNaN` if every input is NaN.
/// - `NotEnoughValidData` if fewer than the largest period's worth of values follow `first`.
///
/// # Panics
/// If `out.len() != rows * cols`. Writing a mis-sized buffer through the raw/unsafe paths
/// below is never valid, so this is treated as a caller bug.
#[cfg(any(feature = "python", feature = "wasm"))]
#[inline(always)]
fn rsx_batch_inner_into(
    data: &[f64],
    sweep: &RsxBatchRange,
    kern: Kernel,
    parallel: bool,
    out: &mut [f64],
) -> Result<Vec<RsxParams>, RsxError> {
    let combos = expand_grid(sweep)?;
    if combos.is_empty() {
        return Err(RsxError::InvalidRange {
            start: "range".into(),
            end: "range".into(),
            step: "empty".into(),
        });
    }

    let len = data.len();
    if len == 0 {
        return Err(RsxError::EmptyInputData);
    }
    let first = data
        .iter()
        .position(|x| !x.is_nan())
        .ok_or(RsxError::AllValuesNaN)?;
    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
    if len - first < max_p {
        return Err(RsxError::NotEnoughValidData {
            needed: max_p,
            valid: len - first,
        });
    }

    let rows = combos.len();
    let cols = len;

    // The prefix initialiser writes through a raw `MaybeUninit` view and the per-row
    // kernels run inside `unsafe`. Reject a mis-sized buffer before either touches it.
    assert_eq!(
        Some(out.len()),
        rows.checked_mul(cols),
        "rsx_batch_inner_into: output buffer must hold rows * cols = {rows} * {cols} values"
    );

    // SAFETY: `MaybeUninit<f64>` has the same size and alignment as `f64`, and `out` is a
    // valid, exclusively borrowed slice of `out.len()` elements. Soundness relies on
    // `init_matrix_prefixes` only writing initialised values (it is assumed to write the
    // warm-up prefix of each `cols`-wide row), so `out` remains fully initialised.
    unsafe {
        let out_mu =
            std::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len());
        let warm: Vec<usize> = combos
            .iter()
            .map(|c| first + c.period.unwrap() - 1)
            .collect();
        init_matrix_prefixes(out_mu, cols, &warm);
    }

    let actual = match kern {
        Kernel::Auto => Kernel::Scalar,
        k => k,
    };

    // The kernels are called inside `unsafe`. Each row slice has exactly `cols == data.len()`
    // elements, which the length assert above guarantees.
    let do_row = |row: usize, out_row: &mut [f64]| unsafe {
        let period = combos[row].period.unwrap();
        match actual {
            #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
            Kernel::Scalar | Kernel::ScalarBatch => rsx_simd128(data, period, first, out_row),
            #[cfg(not(all(target_arch = "wasm32", target_feature = "simd128")))]
            Kernel::Scalar | Kernel::ScalarBatch => rsx_scalar(data, period, first, out_row),

            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx2 | Kernel::Avx2Batch => rsx_avx2(data, period, first, out_row),
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx512 | Kernel::Avx512Batch => rsx_avx512(data, period, first, out_row),
            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
                rsx_scalar(data, period, first, out_row)
            }

            _ => rsx_scalar(data, period, first, out_row),
        }
    };

    if parallel {
        #[cfg(not(target_arch = "wasm32"))]
        {
            out.par_chunks_mut(cols)
                .enumerate()
                .for_each(|(row, slice)| do_row(row, slice));
        }

        // No rayon on wasm32: fall back to a sequential row loop.
        #[cfg(target_arch = "wasm32")]
        {
            for (row, slice) in out.chunks_mut(cols).enumerate() {
                do_row(row, slice);
            }
        }
    } else {
        for (row, slice) in out.chunks_mut(cols).enumerate() {
            do_row(row, slice);
        }
    }

    Ok(combos)
}
1829
/// WASM binding: computes RSX over `data` and returns a newly allocated array of the same
/// length.
///
/// Uses `Kernel::Auto`. Indicator errors are surfaced as their string message.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn rsx_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = RsxInput::from_slice(
        data,
        RsxParams {
            period: Some(period),
        },
    );

    let mut result = vec![0.0; data.len()];
    match rsx_into_slice(&mut result, &input, Kernel::Auto) {
        Ok(_) => Ok(result),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1845
/// WASM binding: computes RSX from `len` values at `in_ptr` into `len` slots at `out_ptr`.
///
/// In-place use (`in_ptr == out_ptr`) is supported by computing into a scratch buffer
/// first. Fails on null pointers, on `period == 0` or `period > len`, and on any
/// indicator error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn rsx_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to rsx_into"));
    }
    if period == 0 || period > len {
        return Err(JsValue::from_str("Invalid period"));
    }

    // SAFETY: the caller (JS glue) must pass a non-null pointer to `len` readable f64 values.
    let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };
    let input = RsxInput::from_slice(
        data,
        RsxParams {
            period: Some(period),
        },
    );

    if in_ptr == out_ptr {
        // Input and output alias: finish reading `data` before anything is written back.
        let mut scratch = vec![0.0; len];
        rsx_into_slice(&mut scratch, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // SAFETY: `out_ptr` addresses `len` writable f64 values; `data` is no longer read.
        let dst = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        dst.copy_from_slice(&scratch);
    } else {
        // SAFETY: `out_ptr` addresses `len` writable f64 values that do not start at `in_ptr`.
        let dst = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        rsx_into_slice(dst, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(())
}
1885
/// Reserves uninitialized heap space for `len` `f64` values and returns a raw pointer
/// to it.
///
/// Ownership of the allocation passes to the caller. It is meant to be released with
/// `rsx_free(ptr, len)`, which rebuilds a `Vec<f64>` of capacity `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn rsx_alloc(len: usize) -> *mut f64 {
    // Wrapping in ManuallyDrop suppresses the destructor, so the buffer outlives this call.
    let mut storage = core::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    storage.as_mut_ptr()
}
1894
/// Releases a buffer previously returned by `rsx_alloc(len)`.
///
/// `len` must be the same value that was passed to `rsx_alloc`, because it is used as
/// the allocation's capacity when the memory is handed back to `Vec`. A null `ptr` is
/// ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn rsx_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // Rebuild the Vec with length 0. `Vec::from_raw_parts` requires the first `length`
    // elements to be initialized, and a buffer from `rsx_alloc` may never have been
    // written. `f64` has no destructor, so only the capacity matters for deallocation.
    // SAFETY: `ptr` was produced by `rsx_alloc(len)`, i.e. by a leaked `Vec<f64>` with
    // capacity `len` that has not been freed yet.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
1904
/// Configuration object accepted by the JavaScript `rsx_batch` entry point
/// (deserialized with `serde_wasm_bindgen`).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct RsxBatchConfig {
    /// `(start, end, step)` of the period sweep; forwarded unchanged as
    /// `RsxBatchRange::period`.
    pub period_range: (usize, usize, usize),
}
1910
/// Result object returned to JavaScript by `rsx_batch`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct RsxBatchJsOutput {
    /// Flattened `rows x cols` output matrix, one row per entry of `combos`.
    pub values: Vec<f64>,
    /// Parameter set used for each row, in row order.
    pub combos: Vec<RsxParams>,
    /// Number of rows (`combos.len()`).
    pub rows: usize,
    /// Number of columns (length of the input series).
    pub cols: usize,
}
1919
/// JavaScript `rsx_batch`: runs RSX for every period in `config.period_range` over `data`
/// and returns an `RsxBatchJsOutput` (`{ values, combos, rows, cols }`).
///
/// `values` is a flattened `rows x cols` matrix with one row per entry of `combos` and
/// `cols == data.len()`. Each row's warmup prefix is pre-initialized by
/// `init_matrix_prefixes` before the batch kernel fills in the rest.
///
/// # Errors
/// Returns a string `JsValue` for an undeserializable config, an invalid sweep,
/// `rows * cols` overflow, an all-NaN input, a batch-kernel error, or a serialization
/// failure. On every error path the output matrix is released.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = rsx_batch)]
pub fn rsx_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let config: RsxBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;
    let sweep = RsxBatchRange {
        period: config.period_range,
    };

    let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let cols = data.len();
    let _ = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;

    // Validate the input before allocating the (potentially large) output matrix.
    let first = data
        .iter()
        .position(|x| !x.is_nan())
        .ok_or_else(|| JsValue::from_str("All NaN"))?;

    // Warmup length per row: `first + period - 1`. It is computed without underflow when
    // the period is 0 and clamped to the row width, so it never exceeds `cols`.
    let warm: Vec<usize> = combos
        .iter()
        .map(|c| (first + c.period.unwrap()).saturating_sub(1).min(cols))
        .collect();

    let mut buf_mu = make_uninit_matrix(rows, cols);
    init_matrix_prefixes(&mut buf_mu, cols, &warm);

    // Suppress the destructor so the allocation can later be re-owned as a Vec<f64>.
    let mut guard = core::mem::ManuallyDrop::new(buf_mu);
    // SAFETY: the buffer holds `rows * cols` `MaybeUninit<f64>` slots, which have the same
    // layout as `f64`. The kernel is expected to write every slot not covered by the
    // warmup prefix (assumption about rsx_batch_inner_into — confirm).
    let out_slice: &mut [f64] =
        unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len()) };

    if let Err(e) = rsx_batch_inner_into(data, &sweep, detect_best_kernel(), false, out_slice) {
        // Release the matrix on failure: propagating with `?` while it is still wrapped in
        // `ManuallyDrop` would leak the whole allocation.
        drop(core::mem::ManuallyDrop::into_inner(guard));
        return Err(JsValue::from_str(&e.to_string()));
    }

    // SAFETY: pointer, length and capacity come from the leaked Vec<MaybeUninit<f64>>,
    // whose allocation layout matches Vec<f64>. The contents were written above.
    let values = unsafe {
        Vec::from_raw_parts(
            guard.as_mut_ptr() as *mut f64,
            guard.len(),
            guard.capacity(),
        )
    };

    let js = RsxBatchJsOutput {
        values,
        combos,
        rows,
        cols,
    };
    serde_wasm_bindgen::to_value(&js)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
1971
/// Runs an RSX period sweep from a caller-owned input buffer into a caller-owned output
/// buffer. This is the raw-pointer entry point used from JavaScript.
///
/// * `in_ptr`  – pointer to `len` readable, initialized `f64` values.
/// * `out_ptr` – pointer to at least `rows * len` writable `f64` values, where `rows` is
///   the number of parameter combinations `expand_grid` produces for
///   `(period_start, period_end, period_step)`.
///
/// The output region may overlap the input (for example when JS reuses one buffer). In
/// that case the input is copied first, so the kernel never reads memory it is
/// overwriting.
///
/// Returns `rows`, the number of output rows written.
///
/// # Errors
/// Returns a string `JsValue` for null pointers, an invalid sweep, `rows * len` overflow,
/// or any error from the batch kernel.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn rsx_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to rsx_batch_into"));
    }

    let sweep = RsxBatchRange {
        period: (period_start, period_end, period_step),
    };

    let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let cols = len;

    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;

    // SAFETY: the caller guarantees `in_ptr` points to `len` initialized f64 values that
    // remain valid for the duration of this call.
    let src = unsafe { std::slice::from_raw_parts(in_ptr, len) };

    // Detect overlap between [in_ptr, in_ptr + len) and [out_ptr, out_ptr + total).
    // `wrapping_add` keeps the address arithmetic well-defined.
    let out_start = out_ptr as *const f64;
    let overlaps =
        in_ptr < out_start.wrapping_add(total) && out_start < in_ptr.wrapping_add(len);

    // When the regions overlap, snapshot the input so that no shared view of the input
    // is read while the output slice is being written.
    let snapshot: Vec<f64>;
    let data: &[f64] = if overlaps {
        snapshot = src.to_vec();
        &snapshot
    } else {
        src
    };

    // SAFETY: the caller guarantees `out_ptr` points to `total` writable f64 values. Any
    // overlap with the input was resolved above by copying the input.
    let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };

    rsx_batch_inner_into(data, &sweep, Kernel::Auto, false, out)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    Ok(rows)
}
2008
2009#[cfg(all(feature = "python", feature = "cuda"))]
2010use crate::cuda::cuda_available;
2011#[cfg(all(feature = "python", feature = "cuda"))]
2012use crate::cuda::moving_averages::DeviceArrayF32;
2013#[cfg(all(feature = "python", feature = "cuda"))]
2014use crate::cuda::oscillators::rsx_wrapper::CudaRsx;
2015#[cfg(all(feature = "python", feature = "cuda"))]
2016use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
2017#[cfg(all(feature = "python", feature = "cuda"))]
2018use cust::context::Context;
2019#[cfg(all(feature = "python", feature = "cuda"))]
2020use std::sync::Arc;
2021
/// Python-visible handle to an RSX result (`f32`, `rows x cols`) held in CUDA device
/// memory. It is exposed as `ta_indicators.cuda.RsxDeviceArrayF32` and exports the
/// buffer via `__cuda_array_interface__` and `__dlpack__`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", name = "RsxDeviceArrayF32", unsendable)]
pub struct RsxDeviceArrayF32Py {
    /// The device buffer plus its `rows`/`cols` shape.
    pub(crate) inner: DeviceArrayF32,
    /// CUDA context handle held alongside the buffer. Presumably this keeps the context
    /// alive while the device memory is in use — confirm against CudaRsx::context_arc.
    pub(crate) _ctx: Arc<Context>,
    /// Ordinal of the device the buffer was allocated on; reported by `__dlpack_device__`.
    pub(crate) device_id: u32,
}
2029
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl RsxDeviceArrayF32Py {
    /// CUDA Array Interface (version 3) descriptor for the wrapped buffer.
    ///
    /// Describes a row-major `rows x cols` array of little-endian `f32` (`"<f4"`) with
    /// byte strides `(cols * 4, 4)`. The `data` entry is `(device_ptr, false)`, meaning
    /// the buffer is not read-only. No `"stream"` key is set.
    ///
    /// NOTE(review): once `__dlpack__` has moved the buffer out, `inner` is a 0x0
    /// placeholder, so this then describes an empty array.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let d = PyDict::new(py);
        let itemsize = std::mem::size_of::<f32>();
        d.set_item("shape", (self.inner.rows, self.inner.cols))?;
        d.set_item("typestr", "<f4")?;
        // Strides are in bytes: one full row, then one element.
        d.set_item("strides", (self.inner.cols * itemsize, itemsize))?;
        d.set_item("data", (self.inner.device_ptr() as usize, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    /// DLPack device tuple `(device_type, device_id)`. Device type 2 is DLPack's
    /// `kDLCUDA`, and `device_id` is the ordinal stored at construction.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.device_id as i32)
    }

    /// Exports the device buffer as a DLPack capsule via `export_f32_cuda_dlpack_2d`.
    ///
    /// * `stream`: an integer value of 0 is rejected (0 is ambiguous for CUDA under the
    ///   DLPack protocol). Any other value is accepted but not otherwise used; no stream
    ///   synchronization is performed here.
    /// * `dl_device`: if it extracts as `(i32, i32)` and differs from
    ///   `__dlpack_device__()`, a `ValueError` is raised. Values that do not extract as
    ///   such a tuple are ignored.
    /// * `copy`: `True` is rejected. A value that does not extract as `bool` propagates
    ///   the extraction error.
    /// * `max_version`: forwarded to the exporter.
    ///
    /// The buffer is moved out of `self` and replaced by an empty 0x0 placeholder, so
    /// only the first call exports the real data. Later calls export the empty
    /// placeholder.
    #[pyo3(signature=(stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        // Reject the ambiguous CUDA stream handle 0; non-integer stream objects pass through.
        if let Some(stream_obj) = stream.as_ref() {
            if let Ok(s) = stream_obj.extract::<usize>(py) {
                if s == 0 {
                    return Err(PyValueError::new_err(
                        "__dlpack__ stream=0 is invalid for CUDA",
                    ));
                }
            }
        }

        // Only same-device export is supported; a mismatching requested device is an error.
        let (kdl, alloc_dev) = self.__dlpack_device__();
        if let Some(dev_obj) = dl_device.as_ref() {
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    return Err(PyValueError::new_err(
                        "dl_device mismatch; cross-device copy not supported for RsxDeviceArrayF32",
                    ));
                }
            }
        }

        // Export is always zero-copy; an explicit copy request cannot be honored.
        if let Some(copy_obj) = copy.as_ref() {
            let do_copy: bool = copy_obj.extract(py)?;
            if do_copy {
                return Err(PyValueError::new_err(
                    "copy=True not supported for RsxDeviceArrayF32",
                ));
            }
        }

        // Swap in an empty placeholder so ownership of the real buffer can be transferred
        // to the DLPack capsule while `self` stays a valid (empty) object.
        use cust::memory::DeviceBuffer;
        let dummy = DeviceBuffer::<f32>::from_slice(&[])
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        let inner = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32 {
                buf: dummy,
                rows: 0,
                cols: 0,
            },
        );

        let rows = inner.rows;
        let cols = inner.cols;
        let buf = inner.buf;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
2110
#[cfg(all(feature = "python", feature = "cuda"))]
impl RsxDeviceArrayF32Py {
    /// Wraps a device array produced on the Rust side, together with the CUDA context
    /// handle and the device ordinal it was allocated on, into the Python-facing type.
    pub fn new_from_rust(inner: DeviceArrayF32, ctx_guard: Arc<Context>, device_id: u32) -> Self {
        RsxDeviceArrayF32Py {
            device_id,
            _ctx: ctx_guard,
            inner,
        }
    }
}
2121
/// Runs an RSX period sweep on the GPU over a single `f32` series and returns the
/// `rows x cols` result as a device-resident `RsxDeviceArrayF32`.
///
/// `period_range` is the `(start, end, step)` sweep forwarded as `RsxBatchRange::period`.
/// The CUDA work runs with the GIL released. Raises `ValueError` when CUDA is unavailable
/// or the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "rsx_cuda_batch_dev")]
#[pyo3(signature = (data, period_range, device_id=0))]
pub fn rsx_cuda_batch_dev_py(
    py: Python<'_>,
    data: numpy::PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<RsxDeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let series: &[f32] = data.as_slice()?;
    let range = RsxBatchRange {
        period: period_range,
    };

    let launched = py.allow_threads(|| -> PyResult<_> {
        let engine = match CudaRsx::new(device_id) {
            Ok(engine) => engine,
            Err(err) => return Err(PyValueError::new_err(err.to_string())),
        };
        let context = engine.context_arc();
        let ordinal = engine.device_id();
        match engine.rsx_batch_dev(series, &range) {
            Ok(device_arr) => Ok((device_arr, context, ordinal)),
            Err(err) => Err(PyValueError::new_err(err.to_string())),
        }
    });
    let (device_arr, context, ordinal) = launched?;
    Ok(RsxDeviceArrayF32Py::new_from_rust(device_arr, context, ordinal))
}
2149
/// Runs RSX with a single `period` on the GPU over `cols` series stored time-major in the
/// flat `data_tm` array (`rows * cols` values). Returns a device-resident
/// `RsxDeviceArrayF32`.
///
/// The CUDA work runs with the GIL released. Raises `ValueError` when:
/// * CUDA is unavailable,
/// * `rows * cols` overflows or does not equal `len(data_tm)`, or
/// * the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "rsx_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm, cols, rows, period, device_id=0))]
pub fn rsx_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm: numpy::PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    period: usize,
    device_id: usize,
) -> PyResult<RsxDeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let flat: &[f32] = data_tm.as_slice()?;
    match cols.checked_mul(rows) {
        None => return Err(PyValueError::new_err("rows*cols overflow")),
        Some(expected) if flat.len() != expected => {
            return Err(PyValueError::new_err("time-major input length mismatch"));
        }
        Some(_) => {}
    }

    let launched = py.allow_threads(|| -> PyResult<_> {
        let engine = match CudaRsx::new(device_id) {
            Ok(engine) => engine,
            Err(err) => return Err(PyValueError::new_err(err.to_string())),
        };
        let context = engine.context_arc();
        let ordinal = engine.device_id();
        match engine.rsx_many_series_one_param_time_major_dev(flat, cols, rows, period) {
            Ok(device_arr) => Ok((device_arr, context, ordinal)),
            Err(err) => Err(PyValueError::new_err(err.to_string())),
        }
    });
    let (device_arr, context, ordinal) = launched?;
    Ok(RsxDeviceArrayF32Py::new_from_rust(device_arr, context, ordinal))
}