1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{
4 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
5 make_uninit_matrix,
6};
7use aligned_vec::{AVec, CACHELINE_ALIGN};
8#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
9use core::arch::x86_64::*;
10#[cfg(not(target_arch = "wasm32"))]
11use rayon::prelude::*;
12use std::mem::MaybeUninit;
13use thiserror::Error;
14
15#[cfg(feature = "python")]
16use crate::utilities::kernel_validation::validate_kernel;
17#[cfg(feature = "python")]
18use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
19#[cfg(feature = "python")]
20use pyo3::exceptions::PyValueError;
21#[cfg(feature = "python")]
22use pyo3::prelude::*;
23#[cfg(feature = "python")]
24use pyo3::types::PyDict;
25
26#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
27use serde::{Deserialize, Serialize};
28#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
29use serde_wasm_bindgen;
30#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
31use wasm_bindgen::prelude::*;
32
33impl<'a> AsRef<[f64]> for RsxInput<'a> {
34 #[inline(always)]
35 fn as_ref(&self) -> &[f64] {
36 match &self.data {
37 RsxData::Slice(slice) => slice,
38 RsxData::Candles { candles, source } => source_type(candles, source),
39 }
40 }
41}
42
/// Where an [`RsxInput`] takes its price series from.
#[derive(Debug, Clone)]
pub enum RsxData<'a> {
    /// A candle set plus the name of the field to read; resolved with
    /// `source_type(candles, source)` when the input is converted to a slice.
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A caller-provided series, used as-is.
    Slice(&'a [f64]),
}
51
/// Result of a single-series RSX computation.
#[derive(Debug, Clone)]
pub struct RsxOutput {
    /// One value per input element. The warm-up region (every index up to and
    /// including `first_valid + period - 1`) holds NaN.
    pub values: Vec<f64>,
}
56
/// Tunable RSX parameters.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct RsxParams {
    /// Smoothing length. `None` falls back to 14 wherever the period is read.
    pub period: Option<usize>,
}
65
66impl Default for RsxParams {
67 fn default() -> Self {
68 Self { period: Some(14) }
69 }
70}
71
/// Input to the RSX entry points: a data source plus parameters.
#[derive(Debug, Clone)]
pub struct RsxInput<'a> {
    /// Where the price series comes from.
    pub data: RsxData<'a>,
    /// Indicator parameters (period).
    pub params: RsxParams,
}
77
78impl<'a> RsxInput<'a> {
79 #[inline]
80 pub fn from_candles(c: &'a Candles, s: &'a str, p: RsxParams) -> Self {
81 Self {
82 data: RsxData::Candles {
83 candles: c,
84 source: s,
85 },
86 params: p,
87 }
88 }
89 #[inline]
90 pub fn from_slice(sl: &'a [f64], p: RsxParams) -> Self {
91 Self {
92 data: RsxData::Slice(sl),
93 params: p,
94 }
95 }
96 #[inline]
97 pub fn with_default_candles(c: &'a Candles) -> Self {
98 Self::from_candles(c, "close", RsxParams::default())
99 }
100 #[inline]
101 pub fn get_period(&self) -> usize {
102 self.params.period.unwrap_or(14)
103 }
104}
105
106#[derive(Copy, Clone, Debug)]
107pub struct RsxBuilder {
108 period: Option<usize>,
109 kernel: Kernel,
110}
111
112impl Default for RsxBuilder {
113 fn default() -> Self {
114 Self {
115 period: None,
116 kernel: Kernel::Auto,
117 }
118 }
119}
120
121impl RsxBuilder {
122 #[inline(always)]
123 pub fn new() -> Self {
124 Self::default()
125 }
126 #[inline(always)]
127 pub fn period(mut self, n: usize) -> Self {
128 self.period = Some(n);
129 self
130 }
131 #[inline(always)]
132 pub fn kernel(mut self, k: Kernel) -> Self {
133 self.kernel = k;
134 self
135 }
136 #[inline(always)]
137 pub fn apply(self, c: &Candles) -> Result<RsxOutput, RsxError> {
138 let p = RsxParams {
139 period: self.period,
140 };
141 let i = RsxInput::from_candles(c, "close", p);
142 rsx_with_kernel(&i, self.kernel)
143 }
144 #[inline(always)]
145 pub fn apply_slice(self, d: &[f64]) -> Result<RsxOutput, RsxError> {
146 let p = RsxParams {
147 period: self.period,
148 };
149 let i = RsxInput::from_slice(d, p);
150 rsx_with_kernel(&i, self.kernel)
151 }
152 #[inline(always)]
153 pub fn into_stream(self) -> Result<RsxStream, RsxError> {
154 let p = RsxParams {
155 period: self.period,
156 };
157 RsxStream::try_new(p)
158 }
159}
160
/// Errors reported by the RSX entry points in this module.
#[derive(Debug, Error)]
pub enum RsxError {
    /// The input series has no elements.
    #[error("rsx: Input data slice is empty.")]
    EmptyInputData,

    /// Every element of the input series is NaN.
    #[error("rsx: All values are NaN.")]
    AllValuesNaN,

    /// `period` is zero or exceeds the input length
    /// (`RsxStream::try_new` reports `data_len: 0`).
    #[error("rsx: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },

    /// Fewer than `needed` values remain from the first non-NaN element onward.
    #[error("rsx: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },

    /// A destination slice does not have the same length as the input.
    #[error("rsx: Output length mismatch: expected {expected}, got {got}")]
    OutputLengthMismatch { expected: usize, got: usize },

    /// A batch sweep could not be expanded, or `rows * cols` overflowed.
    #[error("rsx: Invalid range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: String,
        end: String,
        step: String,
    },

    /// NOTE(review): not constructed in the visible part of this file; the batch
    /// path reports `InvalidKernelForBatch` instead.
    #[error("rsx: Invalid kernel: expected batch kernel, got {kernel:?}")]
    InvalidKernel { kernel: Kernel },

    /// A non-batch kernel was passed to a batch entry point.
    #[error("rsx: Invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(Kernel),
}
191
192#[inline(always)]
193fn rsx_prepare<'a>(
194 input: &'a RsxInput,
195 kernel: Kernel,
196) -> Result<(&'a [f64], usize, usize, Kernel), RsxError> {
197 let data: &[f64] = input.as_ref();
198 let len = data.len();
199 if len == 0 {
200 return Err(RsxError::EmptyInputData);
201 }
202 let first = data
203 .iter()
204 .position(|x| !x.is_nan())
205 .ok_or(RsxError::AllValuesNaN)?;
206 let period = input.get_period();
207
208 if period == 0 || period > len {
209 return Err(RsxError::InvalidPeriod {
210 period,
211 data_len: len,
212 });
213 }
214 if len - first < period {
215 return Err(RsxError::NotEnoughValidData {
216 needed: period,
217 valid: len - first,
218 });
219 }
220
221 let chosen = match kernel {
222 Kernel::Auto => Kernel::Scalar,
223 k => k,
224 };
225 Ok((data, period, first, chosen))
226}
227
/// Computes RSX over `input` with automatic kernel selection.
///
/// Equivalent to `rsx_with_kernel(input, Kernel::Auto)`.
#[inline]
pub fn rsx(input: &RsxInput) -> Result<RsxOutput, RsxError> {
    rsx_with_kernel(input, Kernel::Auto)
}
232
233pub fn rsx_with_kernel(input: &RsxInput, kernel: Kernel) -> Result<RsxOutput, RsxError> {
234 let (data, period, first, chosen) = rsx_prepare(input, kernel)?;
235 let mut out = alloc_with_nan_prefix(data.len(), first + period - 1);
236
237 unsafe {
238 match chosen {
239 #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
240 Kernel::Scalar | Kernel::ScalarBatch => rsx_simd128(data, period, first, &mut out),
241 #[cfg(not(all(target_arch = "wasm32", target_feature = "simd128")))]
242 Kernel::Scalar | Kernel::ScalarBatch => rsx_scalar(data, period, first, &mut out),
243
244 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
245 Kernel::Avx2 | Kernel::Avx2Batch => rsx_avx2(data, period, first, &mut out),
246 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
247 Kernel::Avx512 | Kernel::Avx512Batch => rsx_avx512(data, period, first, &mut out),
248
249 _ => rsx_scalar(data, period, first, &mut out),
250 }
251 }
252 Ok(RsxOutput { values: out })
253}
254
255#[inline]
256pub fn rsx_into_slice(dst: &mut [f64], input: &RsxInput, kern: Kernel) -> Result<(), RsxError> {
257 let (data, period, first, chosen) = rsx_prepare(input, kern)?;
258 if dst.len() != data.len() {
259 return Err(RsxError::OutputLengthMismatch {
260 expected: data.len(),
261 got: dst.len(),
262 });
263 }
264
265 unsafe {
266 match chosen {
267 #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
268 Kernel::Scalar | Kernel::ScalarBatch => rsx_simd128(data, period, first, dst),
269 #[cfg(not(all(target_arch = "wasm32", target_feature = "simd128")))]
270 Kernel::Scalar | Kernel::ScalarBatch => rsx_scalar(data, period, first, dst),
271
272 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
273 Kernel::Avx2 | Kernel::Avx2Batch => rsx_avx2(data, period, first, dst),
274 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
275 Kernel::Avx512 | Kernel::Avx512Batch => rsx_avx512(data, period, first, dst),
276
277 _ => rsx_scalar(data, period, first, dst),
278 }
279 }
280
281 let warmup_end = first + period - 1;
282 for v in &mut dst[..warmup_end] {
283 *v = f64::NAN;
284 }
285 Ok(())
286}
287
/// Writes RSX for `input` into `out` with automatic kernel selection.
///
/// Thin wrapper over `rsx_into_slice(out, input, Kernel::Auto)`, so `out` must
/// have the same length as the input series. Not compiled for wasm32 builds with
/// the `wasm` feature.
#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
#[inline]
pub fn rsx_into(input: &RsxInput, out: &mut [f64]) -> Result<(), RsxError> {
    rsx_into_slice(out, input, Kernel::Auto)
}
293
/// AVX-512 RSX entry point.
///
/// Picks the short (`period <= 32`) or long variant. Both currently forward to
/// `rsx_avx2`, so results match the AVX2 path.
///
/// # Panics
/// Panics if `out` is shorter than `data`. The forwarded kernels write through
/// unchecked indexing, so this *safe* function must enforce that bound itself.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn rsx_avx512(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    assert!(
        out.len() >= data.len(),
        "rsx_avx512: output length {} is shorter than input length {}",
        out.len(),
        data.len()
    );
    // SAFETY: `out` covers every index of `data` (checked above). That is the
    // only requirement of the forwarded kernels, which index `data` below its length.
    unsafe {
        if period <= 32 {
            rsx_avx512_short(data, period, first_valid, out)
        } else {
            rsx_avx512_long(data, period, first_valid, out)
        }
    }
}
303
/// Scalar RSX kernel (Jurik-style smoothed relative strength).
///
/// * `data`   – input series.
/// * `period` – smoothing length; must be at least 1.
/// * `first`  – index of the first valid (non-NaN) sample, as computed by callers.
/// * `out`    – destination, at least as long as `data`.
///
/// The seed bar `first + period - 1` receives NaN. Every later index receives a
/// value clamped to `[0, 100]`, or 50 while warming up or when the smoothed
/// absolute momentum is negligible. Indices before the seed bar are left
/// untouched so the caller controls the warm-up prefix. Nothing is written when
/// the seed bar lies beyond the end of `data`.
///
/// # Panics
/// Panics if `out.len() < data.len()` while there is at least one bar to write.
#[inline]
pub fn rsx_scalar(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    let start = first + period - 1;
    if start >= data.len() {
        return;
    }
    // One up-front length check instead of a bounds check per iteration.
    let out = &mut out[..data.len()];

    // Smoothing factor and its complement.
    let f18 = 3.0 / (period as f64 + 2.0);
    let f20 = 1.0 - f18;
    // Warm-up length in bars, never shorter than 5.
    let f88 = if period >= 6 {
        (period - 1) as f64
    } else {
        5.0
    };

    // Warm-up bookkeeping: bar counter and "price has moved" flag.
    let mut f90 = 1.0;
    let mut f0 = 0.0;
    // Previous price scaled by 100, seeded from the seed bar.
    let mut f8 = 100.0 * data[start];

    // Signed-momentum chain: three cascaded double-EMA stages.
    let mut f28 = 0.0;
    let mut f30 = 0.0;
    let mut f38 = 0.0;
    let mut f40 = 0.0;
    let mut f48 = 0.0;
    let mut f50 = 0.0;
    // Absolute-momentum chain with the same structure.
    let mut f58 = 0.0;
    let mut f60 = 0.0;
    let mut f68 = 0.0;
    let mut f70 = 0.0;
    let mut f78 = 0.0;
    let mut f80 = 0.0;

    out[start] = f64::NAN;

    for i in (start + 1)..data.len() {
        f90 = if f88 <= f90 { f88 + 1.0 } else { f90 + 1.0 };

        let prev = f8;
        f8 = 100.0 * data[i];
        let v8 = f8 - prev;

        f28 = f20 * f28 + f18 * v8;
        f30 = f18 * f28 + f20 * f30;
        let v_c = f28 * 1.5 - f30 * 0.5;

        f38 = f20 * f38 + f18 * v_c;
        f40 = f18 * f38 + f20 * f40;
        let v10 = f38 * 1.5 - f40 * 0.5;

        f48 = f20 * f48 + f18 * v10;
        f50 = f18 * f48 + f20 * f50;
        let v14 = f48 * 1.5 - f50 * 0.5;

        let av = v8.abs();
        f58 = f20 * f58 + f18 * av;
        f60 = f18 * f58 + f20 * f60;
        let v18 = f58 * 1.5 - f60 * 0.5;

        f68 = f20 * f68 + f18 * v18;
        f70 = f18 * f68 + f20 * f70;
        let v1c = f68 * 1.5 - f70 * 0.5;

        f78 = f20 * f78 + f18 * v1c;
        f80 = f18 * f78 + f20 * f80;
        let v20_ = f78 * 1.5 - f80 * 0.5;

        if f88 >= f90 && f8 != prev {
            f0 = 1.0;
        }
        // No price movement seen during the warm-up: restart the bar counter.
        if (f88 - f90).abs() < f64::EPSILON && f0 == 0.0 {
            f90 = 0.0;
        }

        out[i] = if f88 < f90 && v20_ > 1e-10 {
            let v4 = (v14 / v20_ + 1.0) * 50.0;
            if v4 > 100.0 {
                100.0
            } else if v4 < 0.0 {
                0.0
            } else {
                v4
            }
        } else {
            50.0
        };
    }
}
395
/// FMA-flavoured RSX kernel used by the AVX2 (and currently AVX-512) paths.
///
/// Runs the same recurrence as `rsx_scalar`, but evaluates the smoothing steps
/// with `f64::mul_add`, so results can differ from the scalar kernel by rounding.
/// Writes NaN at `first + period - 1` and RSX values after it; earlier indices
/// are left untouched. Nothing is written when the seed bar lies beyond `data`.
///
/// # Safety
/// `out.len()` must be at least `data.len()`. The loop writes through
/// `get_unchecked_mut` for every index from `first + period - 1` up to
/// `data.len() - 1`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn rsx_avx2(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    let len = data.len();
    let start = first + period - 1;
    if start >= len {
        return;
    }
    debug_assert!(out.len() >= len, "rsx_avx2: output shorter than input");

    let mut f0 = 0.0;
    let mut f8 = 100.0 * *data.get_unchecked(start);
    let f18 = 3.0 / (period as f64 + 2.0);
    let f20 = 1.0 - f18;
    let mut f28 = 0.0;
    let mut f30 = 0.0;
    let mut f38 = 0.0;
    let mut f40 = 0.0;
    let mut f48 = 0.0;
    let mut f50 = 0.0;
    let mut f58 = 0.0;
    let mut f60 = 0.0;
    let mut f68 = 0.0;
    let mut f70 = 0.0;
    let mut f78 = 0.0;
    let mut f80 = 0.0;
    // Warm-up length in bars (never below 5); constant for the whole run.
    let f88 = if period >= 6 {
        (period - 1) as f64
    } else {
        5.0
    };
    let mut f90 = 1.0;

    *out.get_unchecked_mut(start) = f64::NAN;

    for i in (start + 1)..len {
        f90 = if f88 <= f90 { f88 + 1.0 } else { f90 + 1.0 };

        let prev = f8;
        f8 = 100.0 * *data.get_unchecked(i);
        let v8 = f8 - prev;

        f28 = f18.mul_add(v8, f20 * f28);
        f30 = f18.mul_add(f28, f20 * f30);
        let v_c = 1.5f64.mul_add(f28, -0.5 * f30);

        f38 = f20.mul_add(f38, f18 * v_c);
        f40 = f18.mul_add(f38, f20 * f40);
        let v10 = 1.5f64.mul_add(f38, -0.5 * f40);

        f48 = f20.mul_add(f48, f18 * v10);
        f50 = f18.mul_add(f48, f20 * f50);
        let v14 = 1.5f64.mul_add(f48, -0.5 * f50);

        let av = v8.abs();
        f58 = f20.mul_add(f58, f18 * av);
        f60 = f18.mul_add(f58, f20 * f60);
        let v18 = 1.5f64.mul_add(f58, -0.5 * f60);

        f68 = f20.mul_add(f68, f18 * v18);
        f70 = f18.mul_add(f68, f20 * f70);
        let v1c = 1.5f64.mul_add(f68, -0.5 * f70);

        f78 = f20.mul_add(f78, f18 * v1c);
        f80 = f18.mul_add(f78, f20 * f80);
        let v20_ = 1.5f64.mul_add(f78, -0.5 * f80);

        if f88 >= f90 && f8 != prev {
            f0 = 1.0;
        }
        // No price movement seen during the warm-up: restart the bar counter.
        if (f88 - f90).abs() < f64::EPSILON && f0 == 0.0 {
            f90 = 0.0;
        }

        let y = if f88 < f90 && v20_ > 1e-10 {
            let v4 = (v14 / v20_ + 1.0) * 50.0;
            v4.max(0.0).min(100.0)
        } else {
            50.0
        };

        *out.get_unchecked_mut(i) = y;
    }
}
479
/// AVX-512 kernel selected by `rsx_avx512` for `period <= 32`.
///
/// Currently forwards to `rsx_avx2`; there is no dedicated AVX-512 code yet.
///
/// # Safety
/// `out.len()` must be at least `data.len()`, because `rsx_avx2` writes through
/// unchecked indexing up to `data.len() - 1`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn rsx_avx512_short(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    rsx_avx2(data, period, first, out)
}
485
/// AVX-512 kernel selected by `rsx_avx512` for `period > 32`.
///
/// Currently forwards to `rsx_avx2`; there is no dedicated AVX-512 code yet.
///
/// # Safety
/// `out.len()` must be at least `data.len()`, because `rsx_avx2` writes through
/// unchecked indexing up to `data.len() - 1`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn rsx_avx512_long(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    rsx_avx2(data, period, first, out)
}
491
/// WASM `simd128` kernel slot; currently forwards to the scalar kernel.
///
/// Declared `unsafe` for parity with the other SIMD kernels. The body only calls
/// the safe `rsx_scalar`, so there are no additional safety requirements.
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
#[inline]
unsafe fn rsx_simd128(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    rsx_scalar(data, period, first, out)
}
497
498#[derive(Debug, Clone)]
499pub struct RsxStream {
500 period: usize,
501
502 seen: usize,
503 init_done: bool,
504
505 alpha: f64,
506 beta: f64,
507 f88: f64,
508 f90: f64,
509
510 f0: f64,
511 f8: f64,
512
513 f28: f64,
514 f30: f64,
515 f38: f64,
516 f40: f64,
517 f48: f64,
518 f50: f64,
519 f58: f64,
520 f60: f64,
521 f68: f64,
522 f70: f64,
523 f78: f64,
524 f80: f64,
525}
526
527impl RsxStream {
528 pub fn try_new(params: RsxParams) -> Result<Self, RsxError> {
529 let period = params.period.unwrap_or(14);
530 if period == 0 {
531 return Err(RsxError::InvalidPeriod {
532 period,
533 data_len: 0,
534 });
535 }
536
537 let alpha = 3.0 / (period as f64 + 2.0);
538 let beta = 1.0 - alpha;
539
540 Ok(Self {
541 period,
542 seen: 0,
543 init_done: false,
544 alpha,
545 beta,
546 f88: if period >= 6 {
547 (period - 1) as f64
548 } else {
549 5.0
550 },
551 f90: 1.0,
552 f0: 0.0,
553 f8: 0.0,
554 f28: 0.0,
555 f30: 0.0,
556 f38: 0.0,
557 f40: 0.0,
558 f48: 0.0,
559 f50: 0.0,
560 f58: 0.0,
561 f60: 0.0,
562 f68: 0.0,
563 f70: 0.0,
564 f78: 0.0,
565 f80: 0.0,
566 })
567 }
568
569 #[inline(always)]
570 pub fn update(&mut self, value: f64) -> Option<f64> {
571 self.seen += 1;
572
573 if !self.init_done && self.seen < self.period {
574 return None;
575 }
576
577 if !self.init_done {
578 self.init_done = true;
579 self.f0 = 0.0;
580 self.f8 = 100.0 * value;
581 return Some(f64::NAN);
582 }
583
584 self.f90 = if self.f88 <= self.f90 {
585 self.f88 + 1.0
586 } else {
587 self.f90 + 1.0
588 };
589
590 let prev = self.f8;
591 self.f8 = 100.0 * value;
592 let v8 = self.f8 - prev;
593
594 self.f28 = self.beta * self.f28 + self.alpha * v8;
595 self.f30 = self.alpha * self.f28 + self.beta * self.f30;
596 let v_c = self.f28 * 1.5 - self.f30 * 0.5;
597
598 self.f38 = self.beta * self.f38 + self.alpha * v_c;
599 self.f40 = self.alpha * self.f38 + self.beta * self.f40;
600 let v10 = self.f38 * 1.5 - self.f40 * 0.5;
601
602 self.f48 = self.beta * self.f48 + self.alpha * v10;
603 self.f50 = self.alpha * self.f48 + self.beta * self.f50;
604 let v14 = self.f48 * 1.5 - self.f50 * 0.5;
605
606 let av = v8.abs();
607 self.f58 = self.beta * self.f58 + self.alpha * av;
608 self.f60 = self.alpha * self.f58 + self.beta * self.f60;
609 let v18 = self.f58 * 1.5 - self.f60 * 0.5;
610
611 self.f68 = self.beta * self.f68 + self.alpha * v18;
612 self.f70 = self.alpha * self.f68 + self.beta * self.f70;
613 let v1c = self.f68 * 1.5 - self.f70 * 0.5;
614
615 self.f78 = self.beta * self.f78 + self.alpha * v1c;
616 self.f80 = self.alpha * self.f78 + self.beta * self.f80;
617 let v20_ = self.f78 * 1.5 - self.f80 * 0.5;
618
619 if self.f88 >= self.f90 && self.f8 != prev {
620 self.f0 = 1.0;
621 }
622 if (self.f88 - self.f90).abs() < f64::EPSILON && self.f0 == 0.0 {
623 self.f90 = 0.0;
624 }
625
626 let y = if self.f88 < self.f90 && v20_ > 1e-10 {
627 let v4 = (v14 / v20_ + 1.0) * 50.0;
628 v4.max(0.0).min(100.0)
629 } else {
630 50.0
631 };
632
633 Some(y)
634 }
635}
636
/// Period sweep for batch RSX, given as `(start, end, step)`.
///
/// `step == 0` (or `start == end`) yields a single period; `end < start`
/// sweeps downwards (see `expand_grid`).
#[derive(Clone, Debug)]
pub struct RsxBatchRange {
    pub period: (usize, usize, usize),
}

impl Default for RsxBatchRange {
    /// Sweeps periods 14 through 263 in steps of 1.
    fn default() -> Self {
        let (start, end, step) = (14, 263, 1);
        RsxBatchRange {
            period: (start, end, step),
        }
    }
}
649
650#[derive(Clone, Debug, Default)]
651pub struct RsxBatchBuilder {
652 range: RsxBatchRange,
653 kernel: Kernel,
654}
655
656impl RsxBatchBuilder {
657 pub fn new() -> Self {
658 Self::default()
659 }
660 pub fn kernel(mut self, k: Kernel) -> Self {
661 self.kernel = k;
662 self
663 }
664
665 #[inline]
666 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
667 self.range.period = (start, end, step);
668 self
669 }
670 #[inline]
671 pub fn period_static(mut self, p: usize) -> Self {
672 self.range.period = (p, p, 0);
673 self
674 }
675
676 pub fn apply_slice(self, data: &[f64]) -> Result<RsxBatchOutput, RsxError> {
677 rsx_batch_with_kernel(data, &self.range, self.kernel)
678 }
679
680 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<RsxBatchOutput, RsxError> {
681 RsxBatchBuilder::new().kernel(k).apply_slice(data)
682 }
683
684 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<RsxBatchOutput, RsxError> {
685 let slice = source_type(c, src);
686 self.apply_slice(slice)
687 }
688
689 pub fn with_default_candles(c: &Candles) -> Result<RsxBatchOutput, RsxError> {
690 RsxBatchBuilder::new()
691 .kernel(Kernel::Auto)
692 .apply_candles(c, "close")
693 }
694}
695
696pub fn rsx_batch_with_kernel(
697 data: &[f64],
698 sweep: &RsxBatchRange,
699 k: Kernel,
700) -> Result<RsxBatchOutput, RsxError> {
701 let kernel = match k {
702 Kernel::Auto => detect_best_batch_kernel(),
703 other if other.is_batch() => other,
704 other => return Err(RsxError::InvalidKernelForBatch(other)),
705 };
706
707 let simd = match kernel {
708 Kernel::Avx512Batch => Kernel::Avx512,
709 Kernel::Avx2Batch => Kernel::Avx2,
710 Kernel::ScalarBatch => Kernel::Scalar,
711 _ => unreachable!(),
712 };
713 rsx_batch_par_slice(data, sweep, simd)
714}
715
716#[derive(Clone, Debug)]
717pub struct RsxBatchOutput {
718 pub values: Vec<f64>,
719 pub combos: Vec<RsxParams>,
720 pub rows: usize,
721 pub cols: usize,
722}
723impl RsxBatchOutput {
724 pub fn row_for_params(&self, p: &RsxParams) -> Option<usize> {
725 self.combos
726 .iter()
727 .position(|c| c.period.unwrap_or(14) == p.period.unwrap_or(14))
728 }
729
730 pub fn values_for(&self, p: &RsxParams) -> Option<&[f64]> {
731 self.row_for_params(p).map(|row| {
732 let start = row * self.cols;
733 &self.values[start..start + self.cols]
734 })
735 }
736}
737
738#[inline(always)]
739fn expand_grid(r: &RsxBatchRange) -> Result<Vec<RsxParams>, RsxError> {
740 fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, RsxError> {
741 if step == 0 || start == end {
742 return Ok(vec![start]);
743 }
744 if start < end {
745 let st = step.max(1);
746 let v: Vec<usize> = (start..=end).step_by(st).collect();
747 if v.is_empty() {
748 return Err(RsxError::InvalidRange {
749 start: start.to_string(),
750 end: end.to_string(),
751 step: step.to_string(),
752 });
753 }
754 return Ok(v);
755 }
756
757 let mut v = Vec::new();
758 let mut x = start as isize;
759 let end_i = end as isize;
760 let st = (step as isize).max(1);
761 while x >= end_i {
762 v.push(x as usize);
763 x -= st;
764 }
765 if v.is_empty() {
766 return Err(RsxError::InvalidRange {
767 start: start.to_string(),
768 end: end.to_string(),
769 step: step.to_string(),
770 });
771 }
772 Ok(v)
773 }
774 let periods = axis_usize(r.period)?;
775 if periods.is_empty() {
776 return Err(RsxError::InvalidRange {
777 start: r.period.0.to_string(),
778 end: r.period.1.to_string(),
779 step: r.period.2.to_string(),
780 });
781 }
782 let mut out = Vec::with_capacity(periods.len());
783 for &p in &periods {
784 out.push(RsxParams { period: Some(p) });
785 }
786 Ok(out)
787}
788
/// Batch RSX over every period in `sweep`, computing the rows sequentially.
///
/// `kern` is the per-row kernel; `Kernel::Auto` resolves to the scalar kernel.
#[inline(always)]
pub fn rsx_batch_slice(
    data: &[f64],
    sweep: &RsxBatchRange,
    kern: Kernel,
) -> Result<RsxBatchOutput, RsxError> {
    rsx_batch_inner(data, sweep, kern, false)
}
797
/// Batch RSX over every period in `sweep`, computing rows in parallel with rayon
/// (sequentially on wasm32).
///
/// `kern` is the per-row kernel; `Kernel::Auto` resolves to the scalar kernel.
#[inline(always)]
pub fn rsx_batch_par_slice(
    data: &[f64],
    sweep: &RsxBatchRange,
    kern: Kernel,
) -> Result<RsxBatchOutput, RsxError> {
    rsx_batch_inner(data, sweep, kern, true)
}
806
807#[inline(always)]
808fn rsx_batch_inner(
809 data: &[f64],
810 sweep: &RsxBatchRange,
811 kern: Kernel,
812 parallel: bool,
813) -> Result<RsxBatchOutput, RsxError> {
814 let combos = expand_grid(sweep)?;
815 if combos.is_empty() {
816 return Err(RsxError::InvalidRange {
817 start: "range".into(),
818 end: "range".into(),
819 step: "empty".into(),
820 });
821 }
822
823 let first = data
824 .iter()
825 .position(|x| !x.is_nan())
826 .ok_or(RsxError::AllValuesNaN)?;
827 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
828 if data.len() - first < max_p {
829 return Err(RsxError::NotEnoughValidData {
830 needed: max_p,
831 valid: data.len() - first,
832 });
833 }
834
835 let rows = combos.len();
836 let cols = data.len();
837
838 let _ = rows
839 .checked_mul(cols)
840 .ok_or_else(|| RsxError::InvalidRange {
841 start: rows.to_string(),
842 end: cols.to_string(),
843 step: "rows*cols".into(),
844 })?;
845
846 let actual_kernel = match kern {
847 Kernel::Auto => Kernel::Scalar,
848 k => k,
849 };
850
851 let mut buf_mu = make_uninit_matrix(rows, cols);
852
853 let warmup_periods: Vec<usize> = combos
854 .iter()
855 .map(|c| first + c.period.unwrap() - 1)
856 .collect();
857
858 init_matrix_prefixes(&mut buf_mu, cols, &warmup_periods);
859
860 let mut buf_guard = std::mem::ManuallyDrop::new(buf_mu);
861 let values_slice: &mut [f64] =
862 unsafe { std::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, rows * cols) };
863
864 let do_row = |row: usize, out_row: &mut [f64]| unsafe {
865 let period = combos[row].period.unwrap();
866 match actual_kernel {
867 Kernel::Scalar | Kernel::ScalarBatch => rsx_row_scalar(data, first, period, out_row),
868 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
869 Kernel::Avx2 | Kernel::Avx2Batch => rsx_row_avx2(data, first, period, out_row),
870 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
871 Kernel::Avx512 | Kernel::Avx512Batch => rsx_row_avx512(data, first, period, out_row),
872 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
873 Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
874 rsx_row_scalar(data, first, period, out_row)
875 }
876 Kernel::Auto => unreachable!("Auto kernel should have been resolved"),
877 }
878 };
879
880 if parallel {
881 #[cfg(not(target_arch = "wasm32"))]
882 {
883 values_slice
884 .par_chunks_mut(cols)
885 .enumerate()
886 .for_each(|(row, slice)| do_row(row, slice));
887 }
888
889 #[cfg(target_arch = "wasm32")]
890 {
891 for (row, slice) in values_slice.chunks_mut(cols).enumerate() {
892 do_row(row, slice);
893 }
894 }
895 } else {
896 for (row, slice) in values_slice.chunks_mut(cols).enumerate() {
897 do_row(row, slice);
898 }
899 }
900
901 let values = unsafe {
902 Vec::from_raw_parts(buf_guard.as_mut_ptr() as *mut f64, rows * cols, rows * cols)
903 };
904
905 Ok(RsxBatchOutput {
906 values,
907 combos,
908 rows,
909 cols,
910 })
911}
912
/// Computes one batch row with the scalar kernel.
///
/// Note the argument order `(data, first, period, out)`, which differs from the
/// kernels' `(data, period, first, out)`. Declared `unsafe` for uniformity with
/// the SIMD row helpers; the body only calls the safe `rsx_scalar`.
#[inline(always)]
unsafe fn rsx_row_scalar(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    rsx_scalar(data, period, first, out);
}
917
/// Computes one batch row with `rsx_avx2`.
///
/// # Safety
/// `out.len()` must be at least `data.len()` (the kernel uses unchecked writes).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn rsx_row_avx2(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    rsx_avx2(data, period, first, out);
}
923
/// Computes one batch row with `rsx_avx512`.
///
/// NOTE(review): `rsx_avx512` is a safe fn, so this `unsafe` marker is for
/// uniformity with the other row helpers only.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn rsx_row_avx512(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    rsx_avx512(data, period, first, out);
}
929
/// Computes one batch row with `rsx_avx512_short`.
///
/// NOTE(review): not referenced anywhere in the visible part of this file (the
/// batch dispatcher uses `rsx_row_avx512`); candidate for removal.
///
/// # Safety
/// `out.len()` must be at least `data.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn rsx_row_avx512_short(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    rsx_avx512_short(data, period, first, out);
}
935
/// Computes one batch row with `rsx_avx512_long`.
///
/// NOTE(review): not referenced anywhere in the visible part of this file (the
/// batch dispatcher uses `rsx_row_avx512`); candidate for removal.
///
/// # Safety
/// `out.len()` must be at least `data.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn rsx_row_avx512_long(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    rsx_avx512_long(data, period, first, out);
}
941
942#[cfg(test)]
943mod tests {
944 use super::*;
945 use crate::skip_if_unsupported;
946 use crate::utilities::data_loader::read_candles_from_csv;
947
948 #[test]
949 fn test_rsx_into_matches_api() -> Result<(), Box<dyn std::error::Error>> {
950 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
951 let candles = read_candles_from_csv(file)?;
952
953 let input = RsxInput::with_default_candles(&candles);
954
955 let RsxOutput { values: expected } = rsx(&input)?;
956
957 let mut out = vec![0.0; candles.close.len()];
958
959 #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
960 {
961 rsx_into(&input, &mut out)?;
962 }
963 #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
964 {
965 rsx_into_slice(&mut out, &input, Kernel::Auto)?;
966 }
967
968 assert_eq!(expected.len(), out.len());
969
970 for i in 0..out.len() {
971 let a = expected[i];
972 let b = out[i];
973 if a.is_nan() || b.is_nan() {
974 assert!(
975 a.is_nan() && b.is_nan(),
976 "NaN mismatch at {i}: {a:?} vs {b:?}"
977 );
978 } else {
979 assert!(a == b, "Mismatch at {i}: {a} vs {b}");
980 }
981 }
982
983 Ok(())
984 }
985
986 fn check_rsx_partial_params(
987 test_name: &str,
988 kernel: Kernel,
989 ) -> Result<(), Box<dyn std::error::Error>> {
990 skip_if_unsupported!(kernel, test_name);
991 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
992 let candles = read_candles_from_csv(file_path)?;
993
994 let default_params = RsxParams { period: None };
995 let input_default = RsxInput::from_candles(&candles, "close", default_params);
996 let output_default = rsx_with_kernel(&input_default, kernel)?;
997 assert_eq!(output_default.values.len(), candles.close.len());
998
999 let params_period_10 = RsxParams { period: Some(10) };
1000 let input_period_10 = RsxInput::from_candles(&candles, "hl2", params_period_10);
1001 let output_period_10 = rsx_with_kernel(&input_period_10, kernel)?;
1002 assert_eq!(output_period_10.values.len(), candles.close.len());
1003
1004 let params_custom = RsxParams { period: Some(20) };
1005 let input_custom = RsxInput::from_candles(&candles, "hlc3", params_custom);
1006 let output_custom = rsx_with_kernel(&input_custom, kernel)?;
1007 assert_eq!(output_custom.values.len(), candles.close.len());
1008
1009 Ok(())
1010 }
1011
1012 fn check_rsx_accuracy(
1013 test_name: &str,
1014 kernel: Kernel,
1015 ) -> Result<(), Box<dyn std::error::Error>> {
1016 skip_if_unsupported!(kernel, test_name);
1017 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1018 let candles = read_candles_from_csv(file_path)?;
1019 let params = RsxParams { period: Some(14) };
1020 let input = RsxInput::from_candles(&candles, "close", params);
1021 let rsx_result = rsx_with_kernel(&input, kernel)?;
1022
1023 let expected_last_five_rsx = [
1024 46.11486311289701,
1025 46.88048640321688,
1026 47.174443049619995,
1027 47.48751360654475,
1028 46.552886446171684,
1029 ];
1030 let start_index = rsx_result.values.len() - 5;
1031 let result_last_five_rsx = &rsx_result.values[start_index..];
1032 for (i, &value) in result_last_five_rsx.iter().enumerate() {
1033 let expected_value = expected_last_five_rsx[i];
1034 assert!(
1035 (value - expected_value).abs() < 1e-1,
1036 "RSX mismatch at index {}: expected {}, got {}",
1037 i,
1038 expected_value,
1039 value
1040 );
1041 }
1042 Ok(())
1043 }
1044
1045 fn check_rsx_default_candles(
1046 test_name: &str,
1047 kernel: Kernel,
1048 ) -> Result<(), Box<dyn std::error::Error>> {
1049 skip_if_unsupported!(kernel, test_name);
1050 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1051 let candles = read_candles_from_csv(file_path)?;
1052 let input = RsxInput::with_default_candles(&candles);
1053 match input.data {
1054 RsxData::Candles { source, .. } => assert_eq!(source, "close"),
1055 _ => panic!("Expected RsxData::Candles"),
1056 }
1057 let output = rsx_with_kernel(&input, kernel)?;
1058 assert_eq!(output.values.len(), candles.close.len());
1059 Ok(())
1060 }
1061
1062 fn check_rsx_zero_period(
1063 test_name: &str,
1064 kernel: Kernel,
1065 ) -> Result<(), Box<dyn std::error::Error>> {
1066 skip_if_unsupported!(kernel, test_name);
1067 let input_data = [10.0, 20.0, 30.0];
1068 let params = RsxParams { period: Some(0) };
1069 let input = RsxInput::from_slice(&input_data, params);
1070 let res = rsx_with_kernel(&input, kernel);
1071 assert!(
1072 res.is_err(),
1073 "[{}] RSX should fail with zero period",
1074 test_name
1075 );
1076 Ok(())
1077 }
1078
1079 fn check_rsx_period_exceeds_length(
1080 test_name: &str,
1081 kernel: Kernel,
1082 ) -> Result<(), Box<dyn std::error::Error>> {
1083 skip_if_unsupported!(kernel, test_name);
1084 let data_small = [10.0, 20.0, 30.0];
1085 let params = RsxParams { period: Some(10) };
1086 let input = RsxInput::from_slice(&data_small, params);
1087 let res = rsx_with_kernel(&input, kernel);
1088 assert!(
1089 res.is_err(),
1090 "[{}] RSX should fail with period exceeding length",
1091 test_name
1092 );
1093 Ok(())
1094 }
1095
1096 fn check_rsx_very_small_dataset(
1097 test_name: &str,
1098 kernel: Kernel,
1099 ) -> Result<(), Box<dyn std::error::Error>> {
1100 skip_if_unsupported!(kernel, test_name);
1101 let single_point = [42.0];
1102 let params = RsxParams { period: Some(14) };
1103 let input = RsxInput::from_slice(&single_point, params);
1104 let res = rsx_with_kernel(&input, kernel);
1105 assert!(
1106 res.is_err(),
1107 "[{}] RSX should fail with insufficient data",
1108 test_name
1109 );
1110 Ok(())
1111 }
1112
1113 fn check_rsx_reinput(
1114 test_name: &str,
1115 kernel: Kernel,
1116 ) -> Result<(), Box<dyn std::error::Error>> {
1117 skip_if_unsupported!(kernel, test_name);
1118 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1119 let candles = read_candles_from_csv(file_path)?;
1120
1121 let first_params = RsxParams { period: Some(14) };
1122 let first_input = RsxInput::from_candles(&candles, "close", first_params);
1123 let first_result = rsx_with_kernel(&first_input, kernel)?;
1124
1125 let second_params = RsxParams { period: Some(14) };
1126 let second_input = RsxInput::from_slice(&first_result.values, second_params);
1127 let second_result = rsx_with_kernel(&second_input, kernel)?;
1128
1129 assert_eq!(second_result.values.len(), first_result.values.len());
1130 Ok(())
1131 }
1132
1133 fn check_rsx_nan_handling(
1134 test_name: &str,
1135 kernel: Kernel,
1136 ) -> Result<(), Box<dyn std::error::Error>> {
1137 skip_if_unsupported!(kernel, test_name);
1138 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1139 let candles = read_candles_from_csv(file_path)?;
1140
1141 let input = RsxInput::from_candles(&candles, "close", RsxParams { period: Some(14) });
1142 let res = rsx_with_kernel(&input, kernel)?;
1143 assert_eq!(res.values.len(), candles.close.len());
1144 if res.values.len() > 50 {
1145 for (i, &val) in res.values[50..].iter().enumerate() {
1146 assert!(
1147 !val.is_nan(),
1148 "[{}] Found unexpected NaN at out-index {}",
1149 test_name,
1150 50 + i
1151 );
1152 }
1153 }
1154 Ok(())
1155 }
1156
1157 fn check_rsx_streaming(
1158 test_name: &str,
1159 kernel: Kernel,
1160 ) -> Result<(), Box<dyn std::error::Error>> {
1161 skip_if_unsupported!(kernel, test_name);
1162 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1163 let candles = read_candles_from_csv(file_path)?;
1164
1165 let period = 14;
1166
1167 let input = RsxInput::from_candles(
1168 &candles,
1169 "close",
1170 RsxParams {
1171 period: Some(period),
1172 },
1173 );
1174 let batch_output = rsx_with_kernel(&input, kernel)?.values;
1175
1176 let mut stream = RsxStream::try_new(RsxParams {
1177 period: Some(period),
1178 })?;
1179
1180 let mut stream_values = Vec::with_capacity(candles.close.len());
1181 for &price in &candles.close {
1182 match stream.update(price) {
1183 Some(rsx_val) => stream_values.push(rsx_val),
1184 None => stream_values.push(f64::NAN),
1185 }
1186 }
1187
1188 assert_eq!(batch_output.len(), stream_values.len());
1189 for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
1190 if b.is_nan() && s.is_nan() {
1191 continue;
1192 }
1193 let diff = (b - s).abs();
1194 assert!(
1195 diff < 1e-9,
1196 "[{}] RSX streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
1197 test_name,
1198 i,
1199 b,
1200 s,
1201 diff
1202 );
1203 }
1204 Ok(())
1205 }
1206
1207 macro_rules! generate_all_rsx_tests {
1208 ($($test_fn:ident),*) => {
1209 paste::paste! {
1210 $(
1211 #[test]
1212 fn [<$test_fn _scalar_f64>]() {
1213 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1214 }
1215 )*
1216 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1217 $(
1218 #[test]
1219 fn [<$test_fn _avx2_f64>]() {
1220 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1221 }
1222 #[test]
1223 fn [<$test_fn _avx512_f64>]() {
1224 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1225 }
1226 )*
1227 #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
1228 $(
1229 #[test]
1230 fn [<$test_fn _simd128_f64>]() {
1231 let _ = $test_fn(stringify!([<$test_fn _simd128_f64>]), Kernel::Scalar);
1232 }
1233 )*
1234 }
1235 }
1236 }
1237
1238 #[cfg(debug_assertions)]
1239 fn check_rsx_no_poison(
1240 test_name: &str,
1241 kernel: Kernel,
1242 ) -> Result<(), Box<dyn std::error::Error>> {
1243 skip_if_unsupported!(kernel, test_name);
1244
1245 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1246 let candles = read_candles_from_csv(file_path)?;
1247
1248 let test_params = vec![
1249 RsxParams::default(),
1250 RsxParams { period: Some(2) },
1251 RsxParams { period: Some(7) },
1252 RsxParams { period: Some(21) },
1253 RsxParams { period: Some(50) },
1254 RsxParams { period: Some(100) },
1255 RsxParams { period: Some(200) },
1256 RsxParams { period: Some(5) },
1257 RsxParams { period: Some(10) },
1258 RsxParams { period: Some(20) },
1259 RsxParams { period: Some(30) },
1260 RsxParams { period: Some(40) },
1261 ];
1262
1263 for (param_idx, params) in test_params.iter().enumerate() {
1264 let input = RsxInput::from_candles(&candles, "close", params.clone());
1265 let output = rsx_with_kernel(&input, kernel)?;
1266
1267 for (i, &val) in output.values.iter().enumerate() {
1268 if val.is_nan() {
1269 continue;
1270 }
1271
1272 let bits = val.to_bits();
1273
1274 if bits == 0x11111111_11111111 {
1275 panic!(
1276 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1277 with params: period={} (param set {})",
1278 test_name,
1279 val,
1280 bits,
1281 i,
1282 params.period.unwrap_or(14),
1283 param_idx
1284 );
1285 }
1286
1287 if bits == 0x22222222_22222222 {
1288 panic!(
1289 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1290 with params: period={} (param set {})",
1291 test_name,
1292 val,
1293 bits,
1294 i,
1295 params.period.unwrap_or(14),
1296 param_idx
1297 );
1298 }
1299
1300 if bits == 0x33333333_33333333 {
1301 panic!(
1302 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1303 with params: period={} (param set {})",
1304 test_name,
1305 val,
1306 bits,
1307 i,
1308 params.period.unwrap_or(14),
1309 param_idx
1310 );
1311 }
1312 }
1313 }
1314
1315 Ok(())
1316 }
1317
1318 #[cfg(not(debug_assertions))]
1319 fn check_rsx_no_poison(
1320 _test_name: &str,
1321 _kernel: Kernel,
1322 ) -> Result<(), Box<dyn std::error::Error>> {
1323 Ok(())
1324 }
1325
 // Property-based check (only with the `proptest` feature): for random finite series
 // and a random period in 2..=64, verifies
 //   * every non-NaN output from index `period` on lies in [0, 100],
 //   * indices below `period` are NaN and index `period` is not,
 //   * constant / strictly monotonic inputs land at / above / below 50,
 //   * the selected kernel matches the scalar kernel within 1e-9 or 4 ULP,
 //   * repeated runs are bit-for-bit deterministic.
 // Inputs contain no NaN, so the first valid sample is index 0 and the warm-up
 // expectations are expressed purely in terms of `period`.
 #[cfg(feature = "proptest")]
 #[allow(clippy::float_cmp)]
 fn check_rsx_property(
 test_name: &str,
 kernel: Kernel,
 ) -> Result<(), Box<dyn std::error::Error>> {
 use proptest::prelude::*;
 skip_if_unsupported!(kernel, test_name);

 // Draw the period first so the series length can be bounded below by it.
 let strat = (2usize..=64).prop_flat_map(|period| {
 (
 prop::collection::vec(
 (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
 period..400,
 ),
 Just(period),
 )
 });

 proptest::test_runner::TestRunner::default()
 .run(&strat, |(data, period)| {
 let params = RsxParams {
 period: Some(period),
 };
 let input = RsxInput::from_slice(&data, params);

 // Output under test plus the scalar reference it is compared against.
 let RsxOutput { values: out } = rsx_with_kernel(&input, kernel).unwrap();
 let RsxOutput { values: ref_out } =
 rsx_with_kernel(&input, Kernel::Scalar).unwrap();

 // Range check on every produced value after warm-up.
 for i in period..data.len() {
 let y = out[i];
 if !y.is_nan() {
 prop_assert!(
 y >= 0.0 && y <= 100.0,
 "idx {i}: RSX value {y} outside [0, 100] bounds"
 );
 }
 }

 // Warm-up region must be NaN.
 for i in 0..period.min(data.len()) {
 prop_assert!(
 out[i].is_nan(),
 "idx {i}: Expected NaN during warmup, got {}",
 out[i]
 );
 }

 // First post-warm-up index must hold a value.
 if data.len() > period {
 prop_assert!(
 !out[period].is_nan(),
 "idx {}: Expected valid RSX value after warmup, got NaN",
 period
 );
 }

 // Constant input -> neutral 50. With continuous random draws this branch is
 // rarely, if ever, taken.
 if data.windows(2).all(|w| (w[0] - w[1]).abs() < f64::EPSILON)
 && data.len() > period
 {
 for i in period..data.len() {
 prop_assert!(
 (out[i] - 50.0).abs() <= 1e-9,
 "idx {i}: Constant data should produce RSX=50.0, got {}",
 out[i]
 );
 }
 }

 // Strictly rising series: after 10 extra bars the value should be >= 50.
 let is_strictly_increasing = data.windows(2).all(|w| w[1] > w[0]);
 if is_strictly_increasing && data.len() >= period + 10 {
 for i in (period + 10)..data.len() {
 prop_assert!(
 out[i] > 50.0 || (out[i] - 50.0).abs() < 1e-9,
 "idx {i}: Strictly increasing prices should produce RSX > 50, got {}",
 out[i]
 );
 }
 }

 // Strictly falling series: after 10 extra bars the value should be <= 50.
 let is_strictly_decreasing = data.windows(2).all(|w| w[1] < w[0]);
 if is_strictly_decreasing && data.len() >= period + 10 {
 for i in (period + 10)..data.len() {
 prop_assert!(
 out[i] < 50.0 || (out[i] - 50.0).abs() < 1e-9,
 "idx {i}: Strictly decreasing prices should produce RSX < 50, got {}",
 out[i]
 );
 }
 }

 // Kernel vs scalar reference: non-finite values must match bit-for-bit,
 // finite ones within 1e-9 absolute or 4 ULP.
 for i in 0..data.len() {
 let y = out[i];
 let r = ref_out[i];

 if !y.is_finite() || !r.is_finite() {
 prop_assert!(
 y.to_bits() == r.to_bits(),
 "finite/NaN mismatch idx {i}: {y} vs {r}"
 );
 continue;
 }

 let y_bits = y.to_bits();
 let r_bits = r.to_bits();
 let ulp_diff: u64 = y_bits.abs_diff(r_bits);

 prop_assert!(
 (y - r).abs() <= 1e-9 || ulp_diff <= 4,
 "kernel mismatch idx {i}: {y} vs {r} (ULP={ulp_diff})"
 );
 }

 // Determinism: a second run with the same kernel must be bit-identical.
 let RsxOutput { values: out2 } = rsx_with_kernel(&input, kernel).unwrap();
 for i in 0..data.len() {
 prop_assert!(
 out[i].to_bits() == out2[i].to_bits(),
 "determinism failed at idx {i}: first={}, second={}",
 out[i],
 out2[i]
 );
 }

 Ok(())
 })
 .unwrap();

 Ok(())
 }
1454
 // Instantiate the per-kernel #[test] wrappers for the single-series checks.
 generate_all_rsx_tests!(
 check_rsx_partial_params,
 check_rsx_accuracy,
 check_rsx_default_candles,
 check_rsx_zero_period,
 check_rsx_period_exceeds_length,
 check_rsx_very_small_dataset,
 check_rsx_reinput,
 check_rsx_nan_handling,
 check_rsx_streaming,
 check_rsx_no_poison
 );

 // The property-based check only exists when the `proptest` feature is enabled.
 #[cfg(feature = "proptest")]
 generate_all_rsx_tests!(check_rsx_property);
1470
1471 fn check_batch_default_row(
1472 test: &str,
1473 kernel: Kernel,
1474 ) -> Result<(), Box<dyn std::error::Error>> {
1475 skip_if_unsupported!(kernel, test);
1476
1477 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1478 let c = read_candles_from_csv(file)?;
1479
1480 let output = RsxBatchBuilder::new()
1481 .kernel(kernel)
1482 .apply_candles(&c, "close")?;
1483
1484 let def = RsxParams::default();
1485 let row = output.values_for(&def).expect("default row missing");
1486
1487 assert_eq!(row.len(), c.close.len());
1488
1489 let expected = [
1490 46.11486311289701,
1491 46.88048640321688,
1492 47.174443049619995,
1493 47.48751360654475,
1494 46.552886446171684,
1495 ];
1496 let start = row.len() - 5;
1497 for (i, &v) in row[start..].iter().enumerate() {
1498 assert!(
1499 (v - expected[i]).abs() < 1e-1,
1500 "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
1501 );
1502 }
1503 Ok(())
1504 }
1505
1506 macro_rules! gen_batch_tests {
1507 ($fn_name:ident) => {
1508 paste::paste! {
1509 #[test] fn [<$fn_name _scalar>]() {
1510 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
1511 }
1512 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1513 #[test] fn [<$fn_name _avx2>]() {
1514 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
1515 }
1516 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1517 #[test] fn [<$fn_name _avx512>]() {
1518 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
1519 }
1520 #[test] fn [<$fn_name _auto_detect>]() {
1521 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
1522 }
1523 }
1524 };
1525 }
1526 #[cfg(debug_assertions)]
1527 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn std::error::Error>> {
1528 skip_if_unsupported!(kernel, test);
1529
1530 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1531 let c = read_candles_from_csv(file)?;
1532
1533 let test_configs = vec![
1534 (2, 10, 2),
1535 (5, 25, 5),
1536 (30, 60, 15),
1537 (2, 5, 1),
1538 (10, 100, 10),
1539 (14, 28, 7),
1540 (3, 9, 3),
1541 (50, 150, 25),
1542 ];
1543
1544 for (cfg_idx, &(period_start, period_end, period_step)) in test_configs.iter().enumerate() {
1545 let output = RsxBatchBuilder::new()
1546 .kernel(kernel)
1547 .period_range(period_start, period_end, period_step)
1548 .apply_candles(&c, "close")?;
1549
1550 for (idx, &val) in output.values.iter().enumerate() {
1551 if val.is_nan() {
1552 continue;
1553 }
1554
1555 let bits = val.to_bits();
1556 let row = idx / output.cols;
1557 let col = idx % output.cols;
1558 let combo = &output.combos[row];
1559
1560 if bits == 0x11111111_11111111 {
1561 panic!(
1562 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
1563 at row {} col {} (flat index {}) with params: period={}",
1564 test,
1565 cfg_idx,
1566 val,
1567 bits,
1568 row,
1569 col,
1570 idx,
1571 combo.period.unwrap_or(14)
1572 );
1573 }
1574
1575 if bits == 0x22222222_22222222 {
1576 panic!(
1577 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
1578 at row {} col {} (flat index {}) with params: period={}",
1579 test,
1580 cfg_idx,
1581 val,
1582 bits,
1583 row,
1584 col,
1585 idx,
1586 combo.period.unwrap_or(14)
1587 );
1588 }
1589
1590 if bits == 0x33333333_33333333 {
1591 panic!(
1592 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
1593 at row {} col {} (flat index {}) with params: period={}",
1594 test,
1595 cfg_idx,
1596 val,
1597 bits,
1598 row,
1599 col,
1600 idx,
1601 combo.period.unwrap_or(14)
1602 );
1603 }
1604 }
1605 }
1606
1607 Ok(())
1608 }
1609
1610 #[cfg(not(debug_assertions))]
1611 fn check_batch_no_poison(
1612 _test: &str,
1613 _kernel: Kernel,
1614 ) -> Result<(), Box<dyn std::error::Error>> {
1615 Ok(())
1616 }
1617
 // Per-kernel #[test] wrappers for the batch checks.
 gen_batch_tests!(check_batch_default_row);
 gen_batch_tests!(check_batch_no_poison);
1620}
1621
// Python entry point `rsx(data, period, kernel=None)`: computes RSX over a 1-D
// float64 array and returns a new float64 array of the same length. Indicator
// errors are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "rsx")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn rsx_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let prices = data.as_slice()?;
    let chosen_kernel = validate_kernel(kernel, false)?;
    let input = RsxInput::from_slice(prices, RsxParams { period: Some(period) });

    // The computation runs with the GIL released.
    let values = py
        .allow_threads(|| rsx_with_kernel(&input, chosen_kernel))
        .map(|output| output.values)
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok(values.into_pyarray(py))
}
1647
// Python class `RsxStream`: thin wrapper over the Rust `RsxStream` for
// value-at-a-time updates.
#[cfg(feature = "python")]
#[pyclass(name = "RsxStream")]
pub struct RsxStreamPy {
    stream: RsxStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl RsxStreamPy {
    // Constructor; invalid parameters surface as `ValueError`.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        RsxStream::try_new(RsxParams {
            period: Some(period),
        })
        .map(|stream| Self { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    // Pushes one value and returns whatever the underlying stream yields for it
    // (`None` maps to Python `None`).
    fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
1671
// Python entry point `rsx_batch(data, period_range, kernel=None)`: computes one RSX
// row per period in `period_range = (start, end, step)` and returns a dict with
// `values` (a rows x len(data) float64 array) and `periods` (the period of each row).
#[cfg(feature = "python")]
#[pyfunction(name = "rsx_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn rsx_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let prices = data.as_slice()?;
    let requested = validate_kernel(kernel, true)?;
    let sweep = RsxBatchRange {
        period: period_range,
    };

    // Size the numpy output up front so the kernels write straight into it.
    let grid = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let (rows, cols) = (grid.len(), prices.len());
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;

    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let out_slice = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            // Resolve Auto to a batch kernel, then use its per-row counterpart.
            let batch_kernel = match requested {
                Kernel::Auto => detect_best_batch_kernel(),
                other => other,
            };
            let row_kernel = match batch_kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                other => other,
            };
            rsx_batch_inner_into(prices, &sweep, row_kernel, true, out_slice)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let periods: Vec<u64> = combos.iter().map(|p| p.period.unwrap() as u64).collect();

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item("periods", periods.into_pyarray(py))?;
    Ok(dict)
}
1733
/// Computes one RSX row per period produced by `expand_grid(sweep)` into `out`.
///
/// `out` is treated as row-major with `combos.len()` rows and `data.len()` columns;
/// the parameter combinations are returned in row order. Before the kernels run,
/// each row's first `first + period - 1` cells (`first` = index of the first non-NaN
/// input) are initialised via `init_matrix_prefixes`.
///
/// `Kernel::Auto` is resolved to `Kernel::Scalar` here, which dispatches to the
/// SIMD128 kernel on `wasm32` with `simd128`. With `parallel = true`, rows are
/// processed with rayon on non-wasm targets; on `wasm32` they always run sequentially.
///
/// # Errors
/// Propagates `expand_grid` errors, and returns `RsxError::InvalidRange` for an empty
/// sweep, `EmptyInputData` for empty `data`, `AllValuesNaN` when every input is NaN,
/// and `NotEnoughValidData` when the longest period exceeds the samples from `first` on.
///
/// # Panics
/// Panics if `out.len() != combos.len() * data.len()`. Callers size `out` from the
/// same sweep, so a mismatch is a programming error; unchecked, a longer buffer would
/// index `combos` out of bounds and a shorter one could leave rows uncomputed.
#[cfg(any(feature = "python", feature = "wasm"))]
#[inline(always)]
fn rsx_batch_inner_into(
    data: &[f64],
    sweep: &RsxBatchRange,
    kern: Kernel,
    parallel: bool,
    out: &mut [f64],
) -> Result<Vec<RsxParams>, RsxError> {
    let combos = expand_grid(sweep)?;
    if combos.is_empty() {
        return Err(RsxError::InvalidRange {
            start: "range".into(),
            end: "range".into(),
            step: "empty".into(),
        });
    }

    let len = data.len();
    if len == 0 {
        return Err(RsxError::EmptyInputData);
    }
    let first = data
        .iter()
        .position(|x| !x.is_nan())
        .ok_or(RsxError::AllValuesNaN)?;
    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
    if len - first < max_p {
        return Err(RsxError::NotEnoughValidData {
            needed: max_p,
            valid: len - first,
        });
    }

    let rows = combos.len();
    let cols = len;
    // Precondition: one full row of `cols` cells per parameter combination.
    assert_eq!(
        rows.checked_mul(cols),
        Some(out.len()),
        "rsx_batch_inner_into: `out` must hold exactly rows * cols values"
    );

    // SAFETY: `MaybeUninit<f64>` has the same layout as `f64`, and `out` is not
    // accessed through any other path while `out_mu` is alive.
    unsafe {
        let out_mu =
            std::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len());
        let warm: Vec<usize> = combos
            .iter()
            .map(|c| first + c.period.unwrap() - 1)
            .collect();
        init_matrix_prefixes(out_mu, cols, &warm);
    }

    let actual = match kern {
        Kernel::Auto => Kernel::Scalar,
        k => k,
    };

    // Row kernels are called inside `unsafe`; the AVX variants are only compiled under
    // `nightly-avx` on x86_64. NOTE(review): assumes callers only pass AVX kernels on
    // CPUs that support them — confirm every caller validates/detects the kernel.
    let do_row = |row: usize, out_row: &mut [f64]| unsafe {
        let period = combos[row].period.unwrap();
        match actual {
            #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
            Kernel::Scalar | Kernel::ScalarBatch => rsx_simd128(data, period, first, out_row),
            #[cfg(not(all(target_arch = "wasm32", target_feature = "simd128")))]
            Kernel::Scalar | Kernel::ScalarBatch => rsx_scalar(data, period, first, out_row),

            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx2 | Kernel::Avx2Batch => rsx_avx2(data, period, first, out_row),
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx512 | Kernel::Avx512Batch => rsx_avx512(data, period, first, out_row),
            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
                rsx_scalar(data, period, first, out_row)
            }

            _ => rsx_scalar(data, period, first, out_row),
        }
    };

    if parallel {
        #[cfg(not(target_arch = "wasm32"))]
        {
            out.par_chunks_mut(cols)
                .enumerate()
                .for_each(|(row, slice)| do_row(row, slice));
        }

        // No rayon on wasm32: fall back to sequential rows.
        #[cfg(target_arch = "wasm32")]
        {
            for (row, slice) in out.chunks_mut(cols).enumerate() {
                do_row(row, slice);
            }
        }
    } else {
        for (row, slice) in out.chunks_mut(cols).enumerate() {
            do_row(row, slice);
        }
    }

    Ok(combos)
}
1829
// WASM export: computes RSX over `data` with the given `period` and returns a new
// array of the same length; indicator errors become a JS string error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn rsx_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = RsxInput::from_slice(data, RsxParams { period: Some(period) });
    // One output slot per input sample.
    let mut output = vec![0.0; data.len()];

    match rsx_into_slice(&mut output, &input, Kernel::Auto) {
        Ok(_) => Ok(output),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1845
/// WASM raw-pointer export: computes RSX over the `len` values at `in_ptr` and writes
/// `len` results to `out_ptr`.
///
/// Input and output may be the same buffer or overlap in any way: overlapping ranges
/// are detected by byte address and routed through a temporary, so a shared and a
/// mutable slice over the same memory are never live together.
///
/// # Errors
/// A JS string error for null pointers, for `period == 0 || period > len`, or for
/// any error reported by the RSX computation.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn rsx_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to rsx_into"));
    }
    if period == 0 || period > len {
        return Err(JsValue::from_str("Invalid period"));
    }

    let params = RsxParams {
        period: Some(period),
    };

    // Byte-range overlap test (covers `in_ptr == out_ptr` as well as partial overlap).
    let bytes = len.saturating_mul(core::mem::size_of::<f64>());
    let in_start = in_ptr as usize;
    let out_start = out_ptr as usize;
    let overlaps =
        in_start < out_start.saturating_add(bytes) && out_start < in_start.saturating_add(bytes);

    // SAFETY: the caller guarantees `in_ptr` and `out_ptr` each address `len` valid
    // f64 values. In the overlapping case the input slice is no longer used once the
    // mutable output slice is created.
    unsafe {
        if overlaps {
            let data = std::slice::from_raw_parts(in_ptr, len);
            let input = RsxInput::from_slice(data, params);
            let mut temp = vec![0.0; len];
            rsx_into_slice(&mut temp, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            out.copy_from_slice(&temp);
        } else {
            let data = std::slice::from_raw_parts(in_ptr, len);
            let input = RsxInput::from_slice(data, params);
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            rsx_into_slice(out, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(())
}
1885
// WASM export: reserves room for `len` f64 values and hands the raw allocation to
// the caller without initialising it; release it with `rsx_free(ptr, len)`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn rsx_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop keeps the Vec from freeing the buffer when it goes out of scope.
    let mut buffer = core::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1894
/// WASM export: frees a buffer obtained from `rsx_alloc(len)`. A null `ptr` is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn rsx_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` must come from `rsx_alloc(len)`, i.e. a `Vec<f64>` allocation with
    // capacity `len`. The Vec is rebuilt with length 0 because the slots may never have
    // been written, and `Vec::from_raw_parts` requires the first `length` elements to be
    // initialised. Dropping it deallocates using the original capacity.
    unsafe {
        drop(Vec::<f64>::from_raw_parts(ptr, 0, len));
    }
}
1904
/// Configuration object accepted by the JS `rsx_batch` export (deserialized with
/// `serde_wasm_bindgen`).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct RsxBatchConfig {
 /// `(start, end, step)` of the period sweep, forwarded as `RsxBatchRange::period`.
 pub period_range: (usize, usize, usize),
}

/// Result object returned to JS by the `rsx_batch` export.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct RsxBatchJsOutput {
 /// Flattened `rows x cols` matrix, one row per entry of `combos`.
 pub values: Vec<f64>,
 /// Parameter set used for each row, in row order.
 pub combos: Vec<RsxParams>,
 /// Number of parameter combinations (rows of `values`).
 pub rows: usize,
 /// Length of the input series (columns of `values`).
 pub cols: usize,
}
1919
/// WASM export `rsx_batch(data, config)`: computes one RSX row per period in
/// `config.period_range` and returns an `RsxBatchJsOutput` (values, combos, rows, cols).
///
/// The output buffer remains an owned `Vec` until the computation succeeds, so a
/// failing run frees it instead of leaking it. Only then is it reinterpreted as
/// `Vec<f64>` without copying.
///
/// # Errors
/// A JS string error for an invalid config, an invalid sweep, a `rows * cols`
/// overflow, all-NaN (or empty) input, any RSX computation error, or a
/// serialization failure.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = rsx_batch)]
pub fn rsx_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let config: RsxBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;
    let sweep = RsxBatchRange {
        period: config.period_range,
    };

    let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let cols = data.len();
    let _ = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;

    let mut buf_mu = make_uninit_matrix(rows, cols);
    let first = data
        .iter()
        .position(|x| !x.is_nan())
        .ok_or_else(|| JsValue::from_str("All NaN"))?;
    let warm: Vec<usize> = combos
        .iter()
        .map(|c| first + c.period.unwrap() - 1)
        .collect();
    init_matrix_prefixes(&mut buf_mu, cols, &warm);

    {
        // SAFETY: `MaybeUninit<f64>` has the same layout as `f64`, and `buf_mu` is not
        // accessed through any other path while this view is in use.
        let out_slice: &mut [f64] = unsafe {
            core::slice::from_raw_parts_mut(buf_mu.as_mut_ptr() as *mut f64, buf_mu.len())
        };
        // On error `buf_mu` is still an ordinary owned Vec, so `?` frees it.
        rsx_batch_inner_into(data, &sweep, detect_best_kernel(), false, out_slice)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    // Hand the allocation over to a `Vec<f64>` without copying; after a successful run
    // every cell is expected to be initialised (NaN prefixes plus kernel output).
    let mut guard = core::mem::ManuallyDrop::new(buf_mu);
    // SAFETY: same allocation, length and capacity; `MaybeUninit<f64>` and `f64` share layout.
    let values = unsafe {
        Vec::from_raw_parts(
            guard.as_mut_ptr() as *mut f64,
            guard.len(),
            guard.capacity(),
        )
    };

    let js = RsxBatchJsOutput {
        values,
        combos,
        rows,
        cols,
    };
    serde_wasm_bindgen::to_value(&js)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
1971
/// WASM raw-pointer export: computes one RSX row per period in
/// `(period_start, period_end, period_step)` over the `len` values at `in_ptr`, and
/// writes the row-major `rows x len` matrix to `out_ptr`. Returns the row count.
///
/// The caller must size `out_ptr` for `rows * len` values. The output is at least as
/// large as the input, so an aliased or overlapping buffer would let the NaN prefixes
/// and kernel writes clobber input not yet read. Overlap is therefore detected by byte
/// address and the computation then runs on a private copy of the input.
///
/// # Errors
/// A JS string error for null pointers, an invalid sweep, a `rows * cols` overflow,
/// or any error reported by the batch computation.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn rsx_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to rsx_batch_into"));
    }

    let sweep = RsxBatchRange {
        period: (period_start, period_end, period_step),
    };

    let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let cols = len;

    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;

    // Byte-range overlap between the input (`len` values) and output (`total` values).
    let f64_bytes = core::mem::size_of::<f64>();
    let in_start = in_ptr as usize;
    let in_end = in_start.saturating_add(len.saturating_mul(f64_bytes));
    let out_start = out_ptr as usize;
    let out_end = out_start.saturating_add(total.saturating_mul(f64_bytes));
    let overlaps = in_start < out_end && out_start < in_end;

    // SAFETY: the caller guarantees `in_ptr` addresses `len` readable f64 values and
    // `out_ptr` addresses `total` writable ones. When the ranges overlap, the input is
    // copied and its slice is no longer used once the output slice is created.
    unsafe {
        if overlaps {
            let input_copy = std::slice::from_raw_parts(in_ptr, len).to_vec();
            let out = std::slice::from_raw_parts_mut(out_ptr, total);
            rsx_batch_inner_into(&input_copy, &sweep, Kernel::Auto, false, out)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        } else {
            let data = std::slice::from_raw_parts(in_ptr, len);
            let out = std::slice::from_raw_parts_mut(out_ptr, total);
            rsx_batch_inner_into(data, &sweep, Kernel::Auto, false, out)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(rows)
}
2008
2009#[cfg(all(feature = "python", feature = "cuda"))]
2010use crate::cuda::cuda_available;
2011#[cfg(all(feature = "python", feature = "cuda"))]
2012use crate::cuda::moving_averages::DeviceArrayF32;
2013#[cfg(all(feature = "python", feature = "cuda"))]
2014use crate::cuda::oscillators::rsx_wrapper::CudaRsx;
2015#[cfg(all(feature = "python", feature = "cuda"))]
2016use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
2017#[cfg(all(feature = "python", feature = "cuda"))]
2018use cust::context::Context;
2019#[cfg(all(feature = "python", feature = "cuda"))]
2020use std::sync::Arc;
2021
// Python-facing owner of an f32 RSX result matrix held in a CUDA `DeviceArrayF32`.
// `_ctx` stores an `Arc<Context>` alongside the buffer (presumably so the CUDA context
// outlives the device allocation — confirm); `device_id` is what `__dlpack_device__`
// reports. Marked `unsendable`, so Python may only use it from the creating thread.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", name = "RsxDeviceArrayF32", unsendable)]
pub struct RsxDeviceArrayF32Py {
 pub(crate) inner: DeviceArrayF32,
 pub(crate) _ctx: Arc<Context>,
 pub(crate) device_id: u32,
}
2029
// Python protocol methods exposing the device matrix to other GPU libraries via
// `__cuda_array_interface__` and DLPack.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl RsxDeviceArrayF32Py {
 // CUDA Array Interface (version 3) description: shape (rows, cols), little-endian
 // float32, row-major strides in bytes, and `data = (device pointer, read_only=False)`.
 #[getter]
 fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
 let d = PyDict::new(py);
 let itemsize = std::mem::size_of::<f32>();
 d.set_item("shape", (self.inner.rows, self.inner.cols))?;
 d.set_item("typestr", "<f4")?;
 d.set_item("strides", (self.inner.cols * itemsize, itemsize))?;
 d.set_item("data", (self.inner.device_ptr() as usize, false))?;

 d.set_item("version", 3)?;
 Ok(d)
 }

 // DLPack device tuple: 2 is kDLCUDA in DLPack's DLDeviceType, paired with the
 // device this buffer was allocated on.
 fn __dlpack_device__(&self) -> (i32, i32) {
 (2, self.device_id as i32)
 }

 // Exports the matrix as a DLPack capsule, moving the device buffer out of `self`.
 // Validation, in order:
 //   * `stream`: if it extracts as an integer, 0 is rejected; the value is not used
 //     otherwise (NOTE(review): no producer/consumer stream sync happens here),
 //   * `dl_device`: if it extracts as (type, id), it must equal `__dlpack_device__()`,
 //   * `copy=True` is rejected (no copying export is implemented).
 // After export `self.inner` is an empty 0x0 placeholder, so a second call exports
 // an empty matrix rather than the original data.
 #[pyo3(signature=(stream=None, max_version=None, dl_device=None, copy=None))]
 fn __dlpack__<'py>(
 &mut self,
 py: Python<'py>,
 stream: Option<pyo3::PyObject>,
 max_version: Option<pyo3::PyObject>,
 dl_device: Option<pyo3::PyObject>,
 copy: Option<pyo3::PyObject>,
 ) -> PyResult<PyObject> {
 if let Some(stream_obj) = stream.as_ref() {
 if let Ok(s) = stream_obj.extract::<usize>(py) {
 if s == 0 {
 return Err(PyValueError::new_err(
 "__dlpack__ stream=0 is invalid for CUDA",
 ));
 }
 }
 }

 let (kdl, alloc_dev) = self.__dlpack_device__();
 if let Some(dev_obj) = dl_device.as_ref() {
 if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
 if dev_ty != kdl || dev_id != alloc_dev {
 return Err(PyValueError::new_err(
 "dl_device mismatch; cross-device copy not supported for RsxDeviceArrayF32",
 ));
 }
 }
 }

 if let Some(copy_obj) = copy.as_ref() {
 let do_copy: bool = copy_obj.extract(py)?;
 if do_copy {
 return Err(PyValueError::new_err(
 "copy=True not supported for RsxDeviceArrayF32",
 ));
 }
 }

 // Swap in an empty device buffer so ownership of the real one can be moved
 // into the DLPack capsule while `self` stays a valid object.
 use cust::memory::DeviceBuffer;
 let dummy = DeviceBuffer::<f32>::from_slice(&[])
 .map_err(|e| PyValueError::new_err(e.to_string()))?;
 let inner = std::mem::replace(
 &mut self.inner,
 DeviceArrayF32 {
 buf: dummy,
 rows: 0,
 cols: 0,
 },
 );

 let rows = inner.rows;
 let cols = inner.cols;
 let buf = inner.buf;

 // Forward the consumer's requested DLPack version (if any) to the exporter.
 let max_version_bound = max_version.map(|obj| obj.into_bound(py));

 export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
 }
}
2110
#[cfg(all(feature = "python", feature = "cuda"))]
impl RsxDeviceArrayF32Py {
    /// Wraps a device array produced on the Rust side. The `Arc<Context>` is stored
    /// next to the buffer, and `device_id` is what `__dlpack_device__` reports.
    pub fn new_from_rust(inner: DeviceArrayF32, ctx_guard: Arc<Context>, device_id: u32) -> Self {
        Self {
            device_id,
            _ctx: ctx_guard,
            inner,
        }
    }
}
2121
// Python entry point `rsx_cuda_batch_dev(data, period_range, device_id=0)`: runs the
// CUDA batch RSX over a float32 series for every period in `period_range` and returns
// the result as a device-resident `RsxDeviceArrayF32`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "rsx_cuda_batch_dev")]
#[pyo3(signature = (data, period_range, device_id=0))]
pub fn rsx_cuda_batch_dev_py(
    py: Python<'_>,
    data: numpy::PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<RsxDeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let prices = data.as_slice()?;
    let sweep = RsxBatchRange {
        period: period_range,
    };

    // All CUDA work runs with the GIL released.
    let launched = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaRsx::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let context = cuda.context_arc();
        let device = cuda.device_id();
        let device_array = cuda
            .rsx_batch_dev(prices, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((device_array, context, device))
    });

    let (inner, ctx, dev_id) = launched?;
    Ok(RsxDeviceArrayF32Py::new_from_rust(inner, ctx, dev_id))
}
2149
// Python entry point `rsx_cuda_many_series_one_param_dev(data_tm, cols, rows, period,
// device_id=0)`: runs CUDA RSX with a single period over many series supplied as one
// flattened time-major float32 array, returning a device-resident `RsxDeviceArrayF32`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "rsx_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm, cols, rows, period, device_id=0))]
pub fn rsx_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm: numpy::PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    period: usize,
    device_id: usize,
) -> PyResult<RsxDeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let prices_tm = data_tm.as_slice()?;

    // The flattened input must contain exactly cols * rows values.
    let expected_len = match cols.checked_mul(rows) {
        Some(n) => n,
        None => return Err(PyValueError::new_err("rows*cols overflow")),
    };
    if prices_tm.len() != expected_len {
        return Err(PyValueError::new_err("time-major input length mismatch"));
    }

    // All CUDA work runs with the GIL released.
    let launched = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaRsx::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let context = cuda.context_arc();
        let device = cuda.device_id();
        let device_array = cuda
            .rsx_many_series_one_param_time_major_dev(prices_tm, cols, rows, period)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((device_array, context, device))
    });

    let (inner, ctx, dev_id) = launched?;
    Ok(RsxDeviceArrayF32Py::new_from_rust(inner, ctx, dev_id))
}