1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::indicators::moving_averages::alma::{make_device_array_py, DeviceArrayF32Py};
3#[cfg(feature = "python")]
4use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
5#[cfg(feature = "python")]
6use pyo3::exceptions::PyValueError;
7#[cfg(feature = "python")]
8use pyo3::prelude::*;
9#[cfg(feature = "python")]
10use pyo3::types::PyDict;
11
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use serde::{Deserialize, Serialize};
14#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
15use wasm_bindgen::prelude::*;
16
17use crate::utilities::data_loader::Candles;
18use crate::utilities::enums::Kernel;
19use crate::utilities::helpers::{
20 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
21 make_uninit_matrix,
22};
23#[cfg(feature = "python")]
24use crate::utilities::kernel_validation::validate_kernel;
25use aligned_vec::{AVec, CACHELINE_ALIGN};
26#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
27use core::arch::x86_64::*;
28#[cfg(not(target_arch = "wasm32"))]
29use rayon::prelude::*;
30use std::error::Error;
31use thiserror::Error;
32
/// Input source for an LRSI computation: either a full candle set (high/low
/// fields are extracted internally) or borrowed high/low slices.
#[derive(Debug, Clone)]
pub enum LrsiData<'a> {
    /// Borrow an OHLC candle container; "high"/"low" fields are selected at run time.
    Candles { candles: &'a Candles },
    /// Borrow pre-extracted high/low series (must be equal length).
    Slices { high: &'a [f64], low: &'a [f64] },
}
38
/// Result of a single-series LRSI run: one value per input bar, with NaN in
/// the warm-up prefix (first valid index + 3 bars).
#[derive(Debug, Clone)]
pub struct LrsiOutput {
    pub values: Vec<f64>,
}
43
/// Parameters for the Laguerre RSI.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct LrsiParams {
    // Smoothing factor; must satisfy 0 < alpha < 1. `None` falls back to 0.2.
    pub alpha: Option<f64>,
}
52
impl Default for LrsiParams {
    /// Default smoothing factor of 0.2 (matches `LrsiInput::get_alpha`'s fallback).
    fn default() -> Self {
        Self { alpha: Some(0.2) }
    }
}
58
/// A data source paired with its LRSI parameters; the unit of work for the
/// single-series API (`lrsi`, `lrsi_with_kernel`, `lrsi_into_slice`).
#[derive(Debug, Clone)]
pub struct LrsiInput<'a> {
    pub data: LrsiData<'a>,
    pub params: LrsiParams,
}
64
65impl<'a> LrsiInput<'a> {
66 #[inline]
67 pub fn from_candles(c: &'a Candles, p: LrsiParams) -> Self {
68 Self {
69 data: LrsiData::Candles { candles: c },
70 params: p,
71 }
72 }
73 #[inline]
74 pub fn from_slices(high: &'a [f64], low: &'a [f64], p: LrsiParams) -> Self {
75 Self {
76 data: LrsiData::Slices { high, low },
77 params: p,
78 }
79 }
80 #[inline]
81 pub fn with_default_candles(c: &'a Candles) -> Self {
82 Self::from_candles(c, LrsiParams::default())
83 }
84 #[inline]
85 pub fn get_alpha(&self) -> f64 {
86 self.params.alpha.unwrap_or(0.2)
87 }
88}
89
/// Fluent builder for a single LRSI computation.
#[derive(Copy, Clone, Debug)]
pub struct LrsiBuilder {
    // Smoothing-factor override; `None` uses the `LrsiParams` default.
    alpha: Option<f64>,
    // Compute kernel; `Auto` resolves to the scalar kernel in `lrsi_with_kernel`.
    kernel: Kernel,
}

impl Default for LrsiBuilder {
    fn default() -> Self {
        Self {
            alpha: None,
            kernel: Kernel::Auto,
        }
    }
}
104
impl LrsiBuilder {
    /// Starts a builder with no alpha override and automatic kernel selection.
    #[inline(always)]
    pub fn new() -> Self {
        Self::default()
    }
    /// Sets the smoothing factor (must satisfy 0 < alpha < 1 at apply time).
    #[inline(always)]
    pub fn alpha(mut self, x: f64) -> Self {
        self.alpha = Some(x);
        self
    }
    /// Selects a specific compute kernel.
    #[inline(always)]
    pub fn kernel(mut self, k: Kernel) -> Self {
        self.kernel = k;
        self
    }

    /// Runs LRSI over a candle set with the configured parameters.
    #[inline(always)]
    pub fn apply(self, c: &Candles) -> Result<LrsiOutput, LrsiError> {
        let p = LrsiParams { alpha: self.alpha };
        let i = LrsiInput::from_candles(c, p);
        lrsi_with_kernel(&i, self.kernel)
    }

    /// Runs LRSI over raw high/low slices with the configured parameters.
    #[inline(always)]
    pub fn apply_slices(self, high: &[f64], low: &[f64]) -> Result<LrsiOutput, LrsiError> {
        let p = LrsiParams { alpha: self.alpha };
        let i = LrsiInput::from_slices(high, low, p);
        lrsi_with_kernel(&i, self.kernel)
    }

    /// Converts the builder into an incremental stream (kernel choice does not
    /// apply to streaming; only `alpha` is carried over).
    #[inline(always)]
    pub fn into_stream(self) -> Result<LrsiStream, LrsiError> {
        let p = LrsiParams { alpha: self.alpha };
        LrsiStream::try_new(p)
    }
}
141
/// Errors produced by the LRSI single-series and batch APIs.
#[derive(Debug, Error)]
pub enum LrsiError {
    /// Input was empty, a candle field lookup failed, or high/low lengths differ.
    #[error("lrsi: Empty input data slice.")]
    EmptyInputData,
    /// `alpha` is outside the open interval (0, 1).
    #[error("lrsi: Invalid alpha: alpha = {alpha}. Must be between 0 and 1.")]
    InvalidAlpha { alpha: f64 },
    /// No bar had a non-NaN median price.
    #[error("lrsi: All values are NaN.")]
    AllValuesNaN,
    /// Fewer valid bars than the filter needs (4: seed + 3 warm-up updates).
    #[error("lrsi: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// Caller-provided output buffer has the wrong length.
    #[error("lrsi: Output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// A batch sweep range produced no combinations (or rows*cols overflowed).
    #[error("lrsi: Invalid range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: String,
        end: String,
        step: String,
    },
    /// A batch-only kernel was passed where a single-series kernel is needed,
    /// or vice versa.
    #[error("lrsi: Invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(crate::utilities::enums::Kernel),
}
163
/// Computes the Laguerre RSI with automatic kernel selection.
#[inline]
pub fn lrsi(input: &LrsiInput) -> Result<LrsiOutput, LrsiError> {
    lrsi_with_kernel(input, Kernel::Auto)
}

/// Allocation-free variant: writes results into `out` (must match input length).
/// Not compiled on wasm builds, which export a pointer-based `lrsi_into` instead.
#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
#[inline]
pub fn lrsi_into(input: &LrsiInput, out: &mut [f64]) -> Result<(), LrsiError> {
    lrsi_into_slice(out, input, Kernel::Auto)
}
174
175pub fn lrsi_with_kernel(input: &LrsiInput, kernel: Kernel) -> Result<LrsiOutput, LrsiError> {
176 let (high, low) = match &input.data {
177 LrsiData::Candles { candles } => {
178 let high = candles
179 .select_candle_field("high")
180 .map_err(|_| LrsiError::EmptyInputData)?;
181 let low = candles
182 .select_candle_field("low")
183 .map_err(|_| LrsiError::EmptyInputData)?;
184 if high.len() != low.len() {
185 return Err(LrsiError::EmptyInputData);
186 }
187 (high, low)
188 }
189 LrsiData::Slices { high, low } => (*high, *low),
190 };
191
192 if high.is_empty() || low.is_empty() {
193 return Err(LrsiError::EmptyInputData);
194 }
195
196 let alpha = input.get_alpha();
197 if !(0.0 < alpha && alpha < 1.0) {
198 return Err(LrsiError::InvalidAlpha { alpha });
199 }
200
201 let mut first_valid_idx = None;
202 for i in 0..high.len() {
203 let price = (high[i] + low[i]) / 2.0;
204 if !price.is_nan() {
205 first_valid_idx = Some(i);
206 break;
207 }
208 }
209
210 let first_valid_idx = first_valid_idx.ok_or(LrsiError::AllValuesNaN)?;
211 let n = high.len();
212 if n - first_valid_idx < 4 {
213 return Err(LrsiError::NotEnoughValidData {
214 needed: 4,
215 valid: n - first_valid_idx,
216 });
217 }
218
219 let warmup_period = first_valid_idx + 3;
220 let mut out = alloc_with_nan_prefix(n, warmup_period);
221
222 let chosen = match kernel {
223 Kernel::Auto => Kernel::Scalar,
224
225 Kernel::ScalarBatch | Kernel::Avx2Batch | Kernel::Avx512Batch => {
226 return Err(LrsiError::NotEnoughValidData {
227 needed: 2,
228 valid: 1,
229 });
230 }
231 other => other,
232 };
233
234 unsafe {
235 match chosen {
236 Kernel::Scalar => lrsi_scalar_hl(high, low, alpha, first_valid_idx, &mut out),
237 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
238 Kernel::Avx2 => lrsi_avx2_hl(high, low, alpha, first_valid_idx, &mut out),
239 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
240 Kernel::Avx512 => lrsi_avx512_hl(high, low, alpha, first_valid_idx, &mut out),
241 _ => unreachable!(),
242 }
243 }
244
245 Ok(LrsiOutput { values: out })
246}
247
248#[inline]
249pub fn lrsi_into_slice(dst: &mut [f64], input: &LrsiInput, kern: Kernel) -> Result<(), LrsiError> {
250 let (high, low) = match &input.data {
251 LrsiData::Candles { candles } => {
252 let high = candles
253 .select_candle_field("high")
254 .map_err(|_| LrsiError::EmptyInputData)?;
255 let low = candles
256 .select_candle_field("low")
257 .map_err(|_| LrsiError::EmptyInputData)?;
258 if high.len() != low.len() {
259 return Err(LrsiError::EmptyInputData);
260 }
261 (high, low)
262 }
263 LrsiData::Slices { high, low } => (*high, *low),
264 };
265
266 let alpha = input.get_alpha();
267 if !(0.0 < alpha && alpha < 1.0) {
268 return Err(LrsiError::InvalidAlpha { alpha });
269 }
270
271 let mut first_valid_idx = None;
272 for i in 0..high.len() {
273 let price = (high[i] + low[i]) / 2.0;
274 if !price.is_nan() {
275 first_valid_idx = Some(i);
276 break;
277 }
278 }
279
280 let first_valid_idx = first_valid_idx.ok_or(LrsiError::AllValuesNaN)?;
281 let n = high.len();
282
283 if dst.len() != n {
284 return Err(LrsiError::OutputLengthMismatch {
285 expected: n,
286 got: dst.len(),
287 });
288 }
289
290 if n - first_valid_idx < 4 {
291 return Err(LrsiError::NotEnoughValidData {
292 needed: 4,
293 valid: n - first_valid_idx,
294 });
295 }
296
297 let chosen = match kern {
298 Kernel::Auto => Kernel::Scalar,
299
300 Kernel::ScalarBatch | Kernel::Avx2Batch | Kernel::Avx512Batch => {
301 return Err(LrsiError::InvalidKernelForBatch(kern));
302 }
303 other => other,
304 };
305
306 unsafe {
307 match chosen {
308 Kernel::Scalar => lrsi_scalar_hl(high, low, alpha, first_valid_idx, dst),
309 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
310 Kernel::Avx2 => lrsi_avx2_hl(high, low, alpha, first_valid_idx, dst),
311 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
312 Kernel::Avx512 => lrsi_avx512_hl(high, low, alpha, first_valid_idx, dst),
313 _ => unreachable!(),
314 }
315 }
316
317 let warmup_end = first_valid_idx + 3;
318 for v in &mut dst[..warmup_end] {
319 *v = f64::NAN;
320 }
321
322 Ok(())
323}
324
#[cfg(test)]
mod tests_into_api {
    use super::*;

    /// The `_into` API must match the allocating API exactly, including NaN
    /// placement in the warm-up prefix and at NaN input bars.
    #[test]
    fn test_lrsi_into_matches_api() -> Result<(), Box<dyn std::error::Error>> {
        let n = 256usize;
        let mut high = vec![f64::NAN; n];
        let mut low = vec![f64::NAN; n];

        // Synthetic series: 3-bar NaN prefix, then a drifting sine walk with
        // sporadic NaNs injected into high and low on different cadences.
        let mut v = 100.0f64;
        for i in 3..n {
            high[i] = v + 1.0;
            low[i] = v - 1.0;

            if i % 37 == 0 {
                high[i] = f64::NAN;
            }
            if i % 53 == 0 {
                low[i] = f64::NAN;
            }
            v += (i as f64).sin() * 0.25 + 0.5;
        }

        let input = LrsiInput::from_slices(&high, &low, LrsiParams::default());

        let base = lrsi(&input)?.values;

        // Non-wasm builds exercise the slice wrapper; wasm builds call the
        // underlying `lrsi_into_slice` directly (no wrapper there).
        let mut out = vec![0.0f64; n];
        #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
        {
            lrsi_into(&input, &mut out)?;
        }
        #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
        {
            lrsi_into_slice(&mut out, &input, Kernel::Auto)?;
        }

        assert_eq!(base.len(), out.len());
        for (i, (&a, &b)) in base.iter().zip(out.iter()).enumerate() {
            // NaN == NaN is false, so compare NaN-ness explicitly.
            let equal = (a.is_nan() && b.is_nan()) || (a == b);
            assert!(equal, "mismatch at index {}: base={}, into={}", i, a, b);
        }
        Ok(())
    }
}
371
/// Scalar Laguerre-RSI kernel over the median price `(high + low) / 2`.
///
/// Seeds a 4-stage Laguerre filter cascade at `first` (the first non-NaN bar)
/// and writes indicator values only for indices `>= first + 3`; earlier
/// entries of `out` are left untouched. A NaN bar past warm-up yields a NaN
/// output and leaves the filter state unchanged.
#[inline]
pub fn lrsi_scalar_hl(high: &[f64], low: &[f64], alpha: f64, first: usize, out: &mut [f64]) {
    debug_assert_eq!(high.len(), low.len());

    let n = high.len();
    if n == 0 {
        return;
    }

    // Filter coefficients: gamma damps each stage, mgamma feeds it forward.
    let gamma = 1.0 - alpha;
    let mgamma = -gamma;
    // First index eligible for an output value.
    let warm = first + 3;

    // Seed all four stages with the first valid median price.
    let seed = (high[first] + low[first]) * 0.5;
    let (mut l0, mut l1, mut l2, mut l3) = (seed, seed, seed, seed);

    for i in (first + 1)..n {
        let price = (high[i] + low[i]) * 0.5;

        if price.is_nan() {
            if i >= warm {
                out[i] = f64::NAN;
            }
            continue;
        }

        // One step of the 4-stage cascade, using fused multiply-adds.
        let n0 = (price - l0).mul_add(alpha, l0);
        let n1 = gamma.mul_add(l1, mgamma.mul_add(n0, l0));
        let n2 = gamma.mul_add(l2, mgamma.mul_add(n1, l1));
        let n3 = gamma.mul_add(l3, mgamma.mul_add(n2, l2));

        if i >= warm {
            let d01 = n0 - n1;
            let d12 = n1 - n2;
            let d23 = n2 - n3;

            let a01 = d01.abs();
            let a12 = d12.abs();
            let a23 = d23.abs();

            // 0.5 * (d + |d|) keeps only the positive (upward) deltas.
            let denom = a01 + a12 + a23;
            let cu = 0.5 * (d01 + a01 + d12 + a12 + d23 + a23);

            let ratio = if denom <= f64::EPSILON { 0.0 } else { cu / denom };

            // Clamp to [0, 1] via min/max, preserving the original expression.
            out[i] = ratio.min(1.0).max(0.0);
        }

        l0 = n0;
        l1 = n1;
        l2 = n2;
        l3 = n3;
    }
}
433
/// Scalar Laguerre-RSI kernel over a pre-computed price series.
///
/// Identical to [`lrsi_scalar_hl`] except the caller supplies the price
/// directly instead of high/low pairs. Writes only indices `>= first + 3`.
#[inline]
pub fn lrsi_scalar(price: &[f64], alpha: f64, first: usize, out: &mut [f64]) {
    let n = price.len();
    if n == 0 {
        return;
    }

    let gamma = 1.0 - alpha;
    let mgamma = -gamma;
    // First index eligible for an output value.
    let warm = first + 3;

    // Seed the four Laguerre stages with the first valid price.
    let seed = price[first];
    let (mut l0, mut l1, mut l2, mut l3) = (seed, seed, seed, seed);

    for i in (first + 1)..n {
        let p = price[i];
        if p.is_nan() {
            // NaN bar: NaN output past warm-up, filter state untouched.
            if i >= warm {
                out[i] = f64::NAN;
            }
            continue;
        }

        // One step of the 4-stage cascade, using fused multiply-adds.
        let n0 = (p - l0).mul_add(alpha, l0);
        let n1 = gamma.mul_add(l1, mgamma.mul_add(n0, l0));
        let n2 = gamma.mul_add(l2, mgamma.mul_add(n1, l1));
        let n3 = gamma.mul_add(l3, mgamma.mul_add(n2, l2));

        if i >= warm {
            let d01 = n0 - n1;
            let d12 = n1 - n2;
            let d23 = n2 - n3;
            let (a01, a12, a23) = (d01.abs(), d12.abs(), d23.abs());

            let denom = a01 + a12 + a23;
            // 0.5 * (d + |d|) keeps only the positive (upward) deltas.
            let cu = 0.5 * (d01 + a01 + d12 + a12 + d23 + a23);
            let ratio = if denom <= f64::EPSILON { 0.0 } else { cu / denom };
            out[i] = ratio.min(1.0).max(0.0);
        }

        l0 = n0;
        l1 = n1;
        l2 = n2;
        l3 = n3;
    }
}
491
/// AVX2 high/low kernel. Currently delegates to the scalar implementation.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn lrsi_avx2_hl(high: &[f64], low: &[f64], alpha: f64, first: usize, out: &mut [f64]) {
    lrsi_scalar_hl(high, low, alpha, first, out)
}

/// AVX-512 high/low kernel. Currently delegates to the scalar implementation.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn lrsi_avx512_hl(high: &[f64], low: &[f64], alpha: f64, first: usize, out: &mut [f64]) {
    lrsi_scalar_hl(high, low, alpha, first, out)
}

/// AVX2 price-series kernel. Currently delegates to the scalar implementation.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn lrsi_avx2(price: &[f64], alpha: f64, first: usize, out: &mut [f64]) {
    lrsi_scalar(price, alpha, first, out)
}

/// AVX-512 price-series kernel. Currently delegates to the scalar implementation.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn lrsi_avx512(price: &[f64], alpha: f64, first: usize, out: &mut [f64]) {
    lrsi_scalar(price, alpha, first, out)
}

/// AVX-512 short-input variant. Currently delegates to the scalar implementation.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn lrsi_avx512_short(price: &[f64], alpha: f64, first: usize, out: &mut [f64]) {
    lrsi_scalar(price, alpha, first, out)
}

/// AVX-512 long-input variant. Currently delegates to the scalar implementation.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn lrsi_avx512_long(price: &[f64], alpha: f64, first: usize, out: &mut [f64]) {
    lrsi_scalar(price, alpha, first, out)
}
527
/// Incremental (one-price-at-a-time) LRSI state.
///
/// Mirrors the batch kernels: the first non-NaN price seeds the four filter
/// stages, and output starts after three further updates.
#[derive(Debug, Clone)]
pub struct LrsiStream {
    // Smoothing factor (validated to lie in (0, 1) at construction).
    alpha: f64,
    // Cached 1.0 - alpha.
    gamma: f64,
    // The four Laguerre filter stages; NaN until the seed price arrives.
    l0: f64,
    l1: f64,
    l2: f64,
    l3: f64,
    // Whether the seed price has been observed.
    initialized: bool,
    // Number of filter updates since the seed (output requires >= 3).
    count: usize,
}
539
540impl LrsiStream {
541 pub fn try_new(params: LrsiParams) -> Result<Self, LrsiError> {
542 let alpha = params.alpha.unwrap_or(0.2);
543 if !(0.0 < alpha && alpha < 1.0) {
544 return Err(LrsiError::InvalidAlpha { alpha });
545 }
546 Ok(Self {
547 alpha,
548 gamma: 1.0 - alpha,
549 l0: f64::NAN,
550 l1: f64::NAN,
551 l2: f64::NAN,
552 l3: f64::NAN,
553 initialized: false,
554 count: 0,
555 })
556 }
557 #[inline(always)]
558 pub fn update(&mut self, price: f64) -> Option<f64> {
559 if price.is_nan() {
560 return None;
561 }
562
563 if !self.initialized {
564 self.l0 = price;
565 self.l1 = price;
566 self.l2 = price;
567 self.l3 = price;
568 self.initialized = true;
569 self.count = 0;
570 return None;
571 }
572
573 let gamma = self.gamma;
574 let mgamma = -gamma;
575
576 let l0_prev = self.l0;
577 let l1_prev = self.l1;
578 let l2_prev = self.l2;
579 let l3_prev = self.l3;
580
581 let t0 = (price - l0_prev).mul_add(self.alpha, l0_prev);
582 let t1 = gamma.mul_add(l1_prev, mgamma.mul_add(t0, l0_prev));
583 let t2 = gamma.mul_add(l2_prev, mgamma.mul_add(t1, l1_prev));
584 let t3 = gamma.mul_add(l3_prev, mgamma.mul_add(t2, l2_prev));
585
586 self.l0 = t0;
587 self.l1 = t1;
588 self.l2 = t2;
589 self.l3 = t3;
590 self.count += 1;
591
592 if self.count < 3 {
593 return None;
594 }
595
596 let d01 = t0 - t1;
597 let d12 = t1 - t2;
598 let d23 = t2 - t3;
599
600 let a01 = d01.abs();
601 let a12 = d12.abs();
602 let a23 = d23.abs();
603
604 let sum_abs = a01 + a12 + a23;
605 if sum_abs <= f64::EPSILON {
606 return Some(0.0);
607 }
608
609 let cu = 0.5 * (d01 + a01 + d12 + a12 + d23 + a23);
610
611 let v = cu / sum_abs;
612 Some(v.min(1.0).max(0.0))
613 }
614}
615
/// Sweep specification for a batch run: `(start, end, step)` over alpha.
/// A zero step (or equal endpoints) collapses the axis to a single value.
#[derive(Clone, Debug)]
pub struct LrsiBatchRange {
    pub alpha: (f64, f64, f64),
}

impl Default for LrsiBatchRange {
    /// Default sweep: alpha from 0.2 to 0.449 in steps of 0.001.
    fn default() -> Self {
        Self {
            alpha: (0.2, 0.449, 0.001),
        }
    }
}
628
/// Fluent builder for batch (parameter-sweep) LRSI runs.
#[derive(Clone, Debug, Default)]
pub struct LrsiBatchBuilder {
    // Alpha sweep; defaults to `LrsiBatchRange::default()`.
    range: LrsiBatchRange,
    // Batch kernel; `Kernel::default()` via `#[derive(Default)]`.
    kernel: Kernel,
}
634
impl LrsiBatchBuilder {
    /// Starts a builder with the default sweep and kernel.
    pub fn new() -> Self {
        Self::default()
    }
    /// Selects the batch kernel (`Auto` resolves at run time).
    pub fn kernel(mut self, k: Kernel) -> Self {
        self.kernel = k;
        self
    }

    /// Sweeps alpha over `[start, end]` with the given step.
    #[inline]
    pub fn alpha_range(mut self, start: f64, end: f64, step: f64) -> Self {
        self.range.alpha = (start, end, step);
        self
    }
    /// Pins alpha to a single value (zero-step range).
    #[inline]
    pub fn alpha_static(mut self, x: f64) -> Self {
        self.range.alpha = (x, x, 0.0);
        self
    }

    /// Runs the sweep over raw high/low slices.
    pub fn apply_slices(self, high: &[f64], low: &[f64]) -> Result<LrsiBatchOutput, LrsiError> {
        lrsi_batch_with_kernel(high, low, &self.range, self.kernel)
    }

    /// Convenience: default sweep over slices with an explicit kernel.
    pub fn with_default_slices(
        high: &[f64],
        low: &[f64],
        k: Kernel,
    ) -> Result<LrsiBatchOutput, LrsiError> {
        LrsiBatchBuilder::new().kernel(k).apply_slices(high, low)
    }

    /// Runs the sweep over a candle set's high/low fields.
    pub fn apply_candles(self, c: &Candles) -> Result<LrsiBatchOutput, LrsiError> {
        let high = c
            .select_candle_field("high")
            .map_err(|_| LrsiError::EmptyInputData)?;
        let low = c
            .select_candle_field("low")
            .map_err(|_| LrsiError::EmptyInputData)?;
        if high.len() != low.len() {
            return Err(LrsiError::EmptyInputData);
        }
        self.apply_slices(high, low)
    }

    /// Convenience: default sweep over candles with automatic kernel selection.
    pub fn with_default_candles(c: &Candles) -> Result<LrsiBatchOutput, LrsiError> {
        LrsiBatchBuilder::new()
            .kernel(Kernel::Auto)
            .apply_candles(c)
    }
}
686
/// Runs an LRSI alpha sweep over one high/low series.
///
/// `Kernel::Auto` resolves via `detect_best_batch_kernel`; non-batch kernels
/// are rejected with `InvalidKernelForBatch`. Delegates to the parallel
/// row-per-combo implementation.
pub fn lrsi_batch_with_kernel(
    high: &[f64],
    low: &[f64],
    sweep: &LrsiBatchRange,
    k: Kernel,
) -> Result<LrsiBatchOutput, LrsiError> {
    let kernel = match k {
        Kernel::Auto => detect_best_batch_kernel(),
        other if other.is_batch() => other,
        other => return Err(LrsiError::InvalidKernelForBatch(other)),
    };
    lrsi_batch_par_slice(high, low, sweep, kernel)
}
700
/// Result of a batch sweep: a row-major `rows x cols` matrix where row `r`
/// holds the full LRSI series for `combos[r]`.
#[derive(Clone, Debug)]
pub struct LrsiBatchOutput {
    // Flattened row-major matrix, length rows * cols.
    pub values: Vec<f64>,
    // One parameter set per row, in row order.
    pub combos: Vec<LrsiParams>,
    pub rows: usize,
    pub cols: usize,
}
708impl LrsiBatchOutput {
709 pub fn row_for_params(&self, p: &LrsiParams) -> Option<usize> {
710 self.combos
711 .iter()
712 .position(|c| (c.alpha.unwrap_or(0.2) - p.alpha.unwrap_or(0.2)).abs() < 1e-12)
713 }
714
715 pub fn values_for(&self, p: &LrsiParams) -> Option<&[f64]> {
716 self.row_for_params(p).map(|row| {
717 let start = row * self.cols;
718 &self.values[start..start + self.cols]
719 })
720 }
721}
722
/// Expands the alpha sweep range into the concrete list of parameter combos.
#[inline(always)]
fn expand_grid(r: &LrsiBatchRange) -> Result<Vec<LrsiParams>, LrsiError> {
    // Generates an inclusive, float-stepped axis. A zero step or equal
    // endpoints (within 1e-12) collapse to the single value `start`.
    fn axis_f64((start, end, step): (f64, f64, f64)) -> Result<Vec<f64>, LrsiError> {
        if step.abs() < 1e-12 || (start - end).abs() < 1e-12 {
            return Ok(vec![start]);
        }
        if start < end {
            // Ascending sweep: the 1e-12 slack keeps `end` itself in range
            // despite float accumulation error.
            let mut v = Vec::new();
            let mut x = start;
            let st = step.abs();
            while x <= end + 1e-12 {
                v.push(x);
                x += st;
            }
            // NOTE(review): with start < end the first push always happens,
            // so this empty check looks unreachable — defensive only.
            if v.is_empty() {
                return Err(LrsiError::InvalidRange {
                    start: start.to_string(),
                    end: end.to_string(),
                    step: step.to_string(),
                });
            }
            return Ok(v);
        }

        // Descending sweep (start > end), same inclusive-end slack.
        let mut v = Vec::new();
        let mut x = start;
        let st = step.abs();
        while x + 1e-12 >= end {
            v.push(x);
            x -= st;
        }
        if v.is_empty() {
            return Err(LrsiError::InvalidRange {
                start: start.to_string(),
                end: end.to_string(),
                step: step.to_string(),
            });
        }
        Ok(v)
    }

    let alphas = axis_f64(r.alpha)?;

    let mut out = Vec::with_capacity(alphas.len());
    for &a in &alphas {
        out.push(LrsiParams { alpha: Some(a) });
    }
    Ok(out)
}
772
/// Sequential batch sweep over slices (rows computed one at a time).
#[inline(always)]
pub fn lrsi_batch_slice(
    high: &[f64],
    low: &[f64],
    sweep: &LrsiBatchRange,
    kern: Kernel,
) -> Result<LrsiBatchOutput, LrsiError> {
    lrsi_batch_inner(high, low, sweep, kern, false)
}

/// Parallel batch sweep over slices; the `parallel` flag is forwarded to
/// `lrsi_batch_inner`.
#[inline(always)]
pub fn lrsi_batch_par_slice(
    high: &[f64],
    low: &[f64],
    sweep: &LrsiBatchRange,
    kern: Kernel,
) -> Result<LrsiBatchOutput, LrsiError> {
    lrsi_batch_inner(high, low, sweep, kern, true)
}
792
/// Shared implementation behind the sequential and parallel batch entry points.
///
/// Builds the combo grid, validates the input, allocates a rows x cols matrix
/// with NaN-filled warm-up prefixes, and fills each row with one alpha's
/// LRSI series via `lrsi_batch_inner_into`.
#[inline(always)]
fn lrsi_batch_inner(
    high: &[f64],
    low: &[f64],
    sweep: &LrsiBatchRange,
    kern: Kernel,
    parallel: bool,
) -> Result<LrsiBatchOutput, LrsiError> {
    let combos = expand_grid(sweep)?;
    if combos.is_empty() {
        return Err(LrsiError::InvalidRange {
            start: sweep.alpha.0.to_string(),
            end: sweep.alpha.1.to_string(),
            step: sweep.alpha.2.to_string(),
        });
    }
    if high.is_empty() || low.is_empty() {
        return Err(LrsiError::EmptyInputData);
    }
    if high.len() != low.len() {
        return Err(LrsiError::EmptyInputData);
    }

    let cols = high.len();
    let rows = combos.len();
    // Overflow guard only; the product itself is not used below.
    let total = rows.checked_mul(cols).ok_or(LrsiError::InvalidRange {
        start: rows.to_string(),
        end: cols.to_string(),
        step: "rows*cols".into(),
    })?;

    // First bar with a finite median price; all rows share this warm-up start.
    let first = (0..cols)
        .find(|&i| ((high[i] + low[i]) / 2.0).is_finite())
        .ok_or(LrsiError::AllValuesNaN)?;
    if cols - first < 4 {
        return Err(LrsiError::NotEnoughValidData {
            needed: 4,
            valid: cols - first,
        });
    }

    // Uninitialized matrix; each row's warm-up prefix (first + 3) is NaN-filled.
    let mut buf_mu = make_uninit_matrix(rows, cols);
    let warm = vec![first + 3; rows];
    init_matrix_prefixes(&mut buf_mu, cols, &warm);

    // Reinterpret the buffer as &mut [f64] for the row kernels.
    // NOTE(review): if `lrsi_batch_inner_into` errors below, the ManuallyDrop
    // buffer is leaked — confirm that is acceptable.
    let mut guard = core::mem::ManuallyDrop::new(buf_mu);
    let out_slice: &mut [f64] =
        unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len()) };

    let resolved = match kern {
        Kernel::Auto => detect_best_batch_kernel(),
        k if k.is_batch() => k,
        other => return Err(LrsiError::InvalidKernelForBatch(other)),
    };
    // Map the batch kernel onto its per-row single-series counterpart.
    let row_kernel = match resolved {
        Kernel::Avx512Batch => Kernel::Avx512,
        Kernel::Avx2Batch => Kernel::Avx2,
        Kernel::ScalarBatch => Kernel::Scalar,
        _ => unreachable!(),
    };

    let combos = lrsi_batch_inner_into(high, low, sweep, row_kernel, parallel, out_slice)?;

    // Take ownership back from the guard as a plain Vec<f64>.
    // NOTE(review): assumes the allocation produced by `make_uninit_matrix` is
    // layout-compatible with Vec<f64> — confirm against the helper's contract.
    let values = unsafe {
        Vec::from_raw_parts(
            guard.as_mut_ptr() as *mut f64,
            guard.len(),
            guard.capacity(),
        )
    };

    Ok(LrsiBatchOutput {
        values,
        combos,
        rows,
        cols,
    })
}
871
/// Per-row batch dispatch target (scalar, high/low). Marked `unsafe` only to
/// match the dispatch signature; the body performs no unsafe operations.
#[inline(always)]
unsafe fn lrsi_row_scalar_hl(high: &[f64], low: &[f64], first: usize, alpha: f64, out: &mut [f64]) {
    lrsi_scalar_hl(high, low, alpha, first, out)
}

/// Per-row AVX2 high/low target. Currently delegates to the scalar kernel.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn lrsi_row_avx2_hl(high: &[f64], low: &[f64], first: usize, alpha: f64, out: &mut [f64]) {
    lrsi_scalar_hl(high, low, alpha, first, out)
}

/// Per-row AVX-512 high/low target. Currently delegates to the scalar kernel.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn lrsi_row_avx512_hl(high: &[f64], low: &[f64], first: usize, alpha: f64, out: &mut [f64]) {
    lrsi_scalar_hl(high, low, alpha, first, out)
}

/// Per-row scalar target over a pre-computed price series.
#[inline(always)]
unsafe fn lrsi_row_scalar(price: &[f64], first: usize, alpha: f64, out: &mut [f64]) {
    lrsi_scalar(price, alpha, first, out)
}

/// Per-row AVX2 price target. Currently delegates to the scalar kernel.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn lrsi_row_avx2(price: &[f64], first: usize, alpha: f64, out: &mut [f64]) {
    lrsi_scalar(price, alpha, first, out)
}

/// Per-row AVX-512 price target. Currently delegates to the scalar kernel.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn lrsi_row_avx512(price: &[f64], first: usize, alpha: f64, out: &mut [f64]) {
    lrsi_scalar(price, alpha, first, out)
}

/// Per-row AVX-512 short-input target. Currently delegates to the scalar kernel.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn lrsi_row_avx512_short(price: &[f64], first: usize, alpha: f64, out: &mut [f64]) {
    lrsi_scalar(price, alpha, first, out)
}

/// Per-row AVX-512 long-input target. Currently delegates to the scalar kernel.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn lrsi_row_avx512_long(price: &[f64], first: usize, alpha: f64, out: &mut [f64]) {
    lrsi_scalar(price, alpha, first, out)
}
917
/// wasm-bindgen export: computes LRSI and returns a freshly allocated vector
/// (copied to JS). Errors are surfaced as `JsValue` strings.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn lrsi_js(high: &[f64], low: &[f64], alpha: f64) -> Result<Vec<f64>, JsValue> {
    let params = LrsiParams { alpha: Some(alpha) };
    let input = LrsiInput::from_slices(high, low, params);

    let mut output = vec![0.0; high.len()];
    lrsi_into_slice(&mut output, &input, Kernel::Auto)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    Ok(output)
}
930
/// JS-facing batch configuration, deserialized from a plain JS object.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct LrsiBatchConfig {
    // (start, end, step) alpha sweep, mirroring `LrsiBatchRange`.
    pub alpha_range: (f64, f64, f64),
}

/// JS-facing batch result, serialized back to a plain JS object.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct LrsiBatchJsOutput {
    pub values: Vec<f64>,
    pub combos: Vec<LrsiParams>,
    pub rows: usize,
    pub cols: usize,
}
945
/// wasm-bindgen export (`lrsi_batch` in JS): runs an alpha sweep configured by
/// a `LrsiBatchConfig`-shaped JS object and returns the full result object.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = lrsi_batch)]
pub fn lrsi_batch_unified_js(
    high: &[f64],
    low: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let config: LrsiBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = LrsiBatchRange {
        alpha: config.alpha_range,
    };
    let result = lrsi_batch_with_kernel(high, low, &sweep, Kernel::Auto)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let output = LrsiBatchJsOutput {
        values: result.values,
        combos: result.combos,
        rows: result.rows,
        cols: result.cols,
    };

    serde_wasm_bindgen::to_value(&output).map_err(|e| JsValue::from_str(&e.to_string()))
}
971
/// Allocates an uninitialized `len`-element f64 buffer on the wasm heap and
/// leaks it, returning the raw pointer for JS-side use. Pair with `lrsi_free`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn lrsi_alloc(len: usize) -> *mut f64 {
    let mut vec = Vec::<f64>::with_capacity(len);
    let ptr = vec.as_mut_ptr();
    std::mem::forget(vec);
    ptr
}

/// Reclaims a buffer previously produced by `lrsi_alloc`. The caller must pass
/// the same `len` used at allocation so the reconstructed Vec's capacity matches.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn lrsi_free(ptr: *mut f64, len: usize) {
    unsafe {
        let _ = Vec::from_raw_parts(ptr, len, len);
    }
}
988
/// wasm-bindgen export: pointer-based, allocation-free LRSI.
///
/// Caller contract: `high_ptr` and `low_ptr` must each point to `len` readable
/// f64s and `out_ptr` to `len` writable f64s. Output aliasing an input is
/// handled by computing into a temporary and copying back.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn lrsi_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    alpha: f64,
) -> Result<(), JsValue> {
    if high_ptr.is_null() || low_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to lrsi_into"));
    }

    unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let params = LrsiParams { alpha: Some(alpha) };
        let input = LrsiInput::from_slices(high, low, params);

        if high_ptr == out_ptr || low_ptr == out_ptr {
            // Output aliases an input: compute into a scratch buffer first so
            // the kernel never reads values it has already overwritten.
            let mut temp = vec![0.0; len];
            lrsi_into_slice(&mut temp, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            out.copy_from_slice(&temp);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            lrsi_into_slice(out, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }

        Ok(())
    }
}
1023
/// wasm-bindgen export: pointer-based batch sweep. Returns the number of rows
/// written; `out_ptr` must have room for `rows * len` f64s.
///
/// Caller contract: `high_ptr`/`low_ptr` point to `len` readable f64s and
/// `out_ptr` to `rows * len` writable f64s, where `rows` is the combo count of
/// the requested sweep.
///
/// NOTE(review): the bounds checks here accept alpha == 1.0, whereas the
/// single-series API requires 0 < alpha < 1 strictly — confirm which is intended.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn lrsi_batch_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    alpha_start: f64,
    alpha_end: f64,
    alpha_step: f64,
) -> Result<usize, JsValue> {
    if high_ptr.is_null() || low_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to lrsi_batch_into"));
    }

    if !(0.0 < alpha_start && alpha_start <= 1.0) {
        return Err(JsValue::from_str(&format!(
            "Invalid alpha_start: {}",
            alpha_start
        )));
    }
    if !(0.0 < alpha_end && alpha_end <= 1.0) {
        return Err(JsValue::from_str(&format!(
            "Invalid alpha_end: {}",
            alpha_end
        )));
    }

    unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let sweep = LrsiBatchRange {
            alpha: (alpha_start, alpha_end, alpha_step),
        };
        // Expand the grid up front to size the output slice and to report the
        // row count back to JS.
        let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
        let rows = combos.len();
        let cols = len;
        let total = rows
            .checked_mul(cols)
            .ok_or_else(|| JsValue::from_str("rows*cols overflow in lrsi_batch"))?;
        let out = std::slice::from_raw_parts_mut(out_ptr, total);

        // Map the detected batch kernel onto its per-row counterpart.
        let row_kernel = match detect_best_batch_kernel() {
            Kernel::Avx512Batch => Kernel::Avx512,
            Kernel::Avx2Batch => Kernel::Avx2,
            _ => Kernel::Scalar,
        };
        lrsi_batch_inner_into(high, low, &sweep, row_kernel, false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        Ok(rows)
    }
}
1076
1077#[cfg(test)]
1078mod tests {
1079 use super::*;
1080 use crate::skip_if_unsupported;
1081 use crate::utilities::data_loader::read_candles_from_csv;
1082
    // Default params (alpha = None) must run and produce one value per bar.
    fn check_lrsi_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let default_params = LrsiParams { alpha: None };
        let input = LrsiInput::from_candles(&candles, default_params);
        let output = lrsi_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());
        Ok(())
    }

    // Regression check: the last five values of the reference dataset are
    // pinned (all zeros) to 1e-9 tolerance.
    fn check_lrsi_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let input = LrsiInput::from_candles(&candles, LrsiParams::default());
        let lrsi_result = lrsi_with_kernel(&input, kernel)?;
        assert_eq!(lrsi_result.values.len(), candles.close.len());
        let expected_last_five_lrsi = [0.0, 0.0, 0.0, 0.0, 0.0];
        let start_index = lrsi_result.values.len() - 5;
        let result_last_five_lrsi = &lrsi_result.values[start_index..];
        for (i, &value) in result_last_five_lrsi.iter().enumerate() {
            let expected_value = expected_last_five_lrsi[i];
            assert!(
                (value - expected_value).abs() < 1e-9,
                "LRSI mismatch at index {}: expected {}, got {}",
                i,
                expected_value,
                value
            );
        }
        Ok(())
    }

    // The `with_default_candles` convenience constructor must run end to end.
    fn check_lrsi_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let input = LrsiInput::with_default_candles(&candles);
        let output = lrsi_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());
        Ok(())
    }

    // alpha outside (0, 1) must be rejected.
    fn check_lrsi_invalid_alpha(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let high = [1.0, 2.0];
        let low = [1.0, 2.0];
        let params = LrsiParams { alpha: Some(1.2) };
        let input = LrsiInput::from_slices(&high, &low, params);
        let result = lrsi_with_kernel(&input, kernel);
        assert!(result.is_err());
        Ok(())
    }

    // Empty slices must be rejected.
    fn check_lrsi_empty_data(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let high: [f64; 0] = [];
        let low: [f64; 0] = [];
        let params = LrsiParams::default();
        let input = LrsiInput::from_slices(&high, &low, params);
        let result = lrsi_with_kernel(&input, kernel);
        assert!(result.is_err());
        Ok(())
    }

    // All-NaN input must be rejected.
    fn check_lrsi_all_nan(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let high = [f64::NAN, f64::NAN, f64::NAN];
        let low = [f64::NAN, f64::NAN, f64::NAN];
        let params = LrsiParams::default();
        let input = LrsiInput::from_slices(&high, &low, params);
        let result = lrsi_with_kernel(&input, kernel);
        assert!(result.is_err());
        Ok(())
    }

    // Fewer than 4 valid bars must be rejected (warm-up needs seed + 3 updates).
    fn check_lrsi_very_small_dataset(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let high = [1.0, 1.0];
        let low = [1.0, 1.0];
        let params = LrsiParams::default();
        let input = LrsiInput::from_slices(&high, &low, params);
        let result = lrsi_with_kernel(&input, kernel);
        assert!(result.is_err());
        Ok(())
    }

    // The streaming state machine must agree with the batch kernel to 1e-9
    // at every index (NaN warm-up positions compare as equal).
    fn check_lrsi_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let high = candles.select_candle_field("high").unwrap();
        let low = candles.select_candle_field("low").unwrap();

        let input = LrsiInput::from_slices(high, low, LrsiParams::default());
        let batch_output = lrsi_with_kernel(&input, kernel)?.values;

        let mut stream = LrsiStream::try_new(LrsiParams::default())?;
        let mut stream_values = Vec::with_capacity(high.len());
        for i in 0..high.len() {
            let price = (high[i] + low[i]) / 2.0;
            match stream.update(price) {
                Some(val) => stream_values.push(val),
                None => stream_values.push(f64::NAN),
            }
        }
        assert_eq!(batch_output.len(), stream_values.len());
        for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
            if b.is_nan() && s.is_nan() {
                continue;
            }
            let diff = (b - s).abs();
            assert!(
                diff < 1e-9,
                "[{}] LRSI streaming mismatch at idx {}: batch={}, stream={}, diff={}",
                test_name,
                i,
                b,
                s,
                diff
            );
        }
        Ok(())
    }
1211
    /// Debug-build check that no "poison" sentinel bit patterns leak into the
    /// output. The input buffers are pre-filled with recognizable patterns
    /// (0x1111…, 0x2222…) before being overwritten with real data, so any of
    /// those exact bit patterns appearing in the result indicates a cell that
    /// was never written by the kernel.
    #[cfg(debug_assertions)]
    fn check_lrsi_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let len = candles.close.len();
        // Cacheline-aligned buffers, matching the allocation style of the
        // production code paths.
        let mut high = AVec::<f64>::with_capacity(CACHELINE_ALIGN, len);
        let mut low = AVec::<f64>::with_capacity(CACHELINE_ALIGN, len);

        // Seed with distinct poison patterns, then overwrite with real data;
        // the patterns can only survive through an uninitialized-read bug.
        high.resize(len, f64::from_bits(0x11111111_11111111));
        low.resize(len, f64::from_bits(0x22222222_22222222));

        high.copy_from_slice(&candles.high);
        low.copy_from_slice(&candles.low);

        // Sweep a spread of alphas to exercise different warmup behaviors.
        let test_params = vec![
            LrsiParams { alpha: Some(0.1) },
            LrsiParams { alpha: Some(0.2) },
            LrsiParams { alpha: Some(0.5) },
            LrsiParams { alpha: Some(0.8) },
            LrsiParams { alpha: Some(0.95) },
        ];

        for params in test_params {
            let input = LrsiInput::from_slices(&high, &low, params.clone());
            let result = lrsi_with_kernel(&input, kernel)?;

            for (i, &val) in result.values.iter().enumerate() {
                // NaN is the legitimate warmup marker, not poison.
                if val.is_nan() {
                    continue;
                }

                // Compare exact bit patterns, not float values.
                let bits = val.to_bits();

                if bits == 0x11111111_11111111 {
                    panic!(
                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
                         with params: alpha={}",
                        test_name,
                        val,
                        bits,
                        i,
                        params.alpha.unwrap_or(0.2)
                    );
                }

                if bits == 0x22222222_22222222 {
                    panic!(
                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
                         with params: alpha={}",
                        test_name,
                        val,
                        bits,
                        i,
                        params.alpha.unwrap_or(0.2)
                    );
                }

                if bits == 0x33333333_33333333 {
                    panic!(
                        "[{}] Found third poison value {} (0x{:016X}) at index {} \
                         with params: alpha={}",
                        test_name,
                        val,
                        bits,
                        i,
                        params.alpha.unwrap_or(0.2)
                    );
                }
            }
        }

        Ok(())
    }
1287
    /// Poison scanning relies on debug-only fill patterns, so in release
    /// builds this check is a deliberate no-op.
    #[cfg(not(debug_assertions))]
    fn check_lrsi_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
1292
    /// Property-based test of the LRSI kernel. Randomized high/low series
    /// (either constant-price or base-price-plus-spread) are generated and
    /// the chosen `kernel` is checked for: output length, NaN warmup prefix,
    /// [0, 1] range, bitwise/ULP agreement with the scalar reference kernel,
    /// stability under constant prices, smoothness/responsiveness consistent
    /// with the alpha value, and trend-following behavior on monotonic input.
    #[cfg(feature = "proptest")]
    fn check_lrsi_property(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        // Strategy: length 4..=400, alpha in (0.01, 0.99), and ~10% of cases
        // use a constant price (only for short series) to probe stability.
        let strat = (4usize..=400, 0.01f64..0.99f64, prop::bool::weighted(0.1)).prop_flat_map(
            |(len, alpha, use_constant_price)| {
                if use_constant_price && len < 50 {
                    // Flat series: same price for every bar.
                    let constant_price = (10.0f64..200.0f64);
                    constant_price
                        .prop_map(move |price| {
                            let high = vec![price; len];
                            let low = vec![price; len];
                            (high, low, alpha)
                        })
                        .boxed()
                } else {
                    // General series: per-bar base price plus a relative
                    // spread of up to 5%, split evenly into high/low.
                    (
                        proptest::collection::vec(
                            (10.0f64..200.0f64).prop_filter("finite", |x| x.is_finite()),
                            len,
                        ),
                        proptest::collection::vec((0.0f64..0.05f64), len),
                        Just(alpha),
                    )
                        .prop_map(|(base_prices, spreads, alpha)| {
                            let mut high = Vec::with_capacity(base_prices.len());
                            let mut low = Vec::with_capacity(base_prices.len());

                            for (base, spread) in base_prices.iter().zip(spreads.iter()) {
                                let half_spread = base * spread / 2.0;
                                high.push(base + half_spread);
                                low.push(base - half_spread);
                            }

                            (high, low, alpha)
                        })
                        .boxed()
                }
            },
        );

        proptest::test_runner::TestRunner::default()
            .run(&strat, |(high, low, alpha)| {
                let params = LrsiParams { alpha: Some(alpha) };
                let input = LrsiInput::from_slices(&high, &low, params.clone());

                // Output of the kernel under test.
                let result = lrsi_with_kernel(&input, kernel)?;
                let out = result.values;

                // Scalar kernel serves as the ground-truth reference.
                let ref_result = lrsi_with_kernel(&input, Kernel::Scalar)?;
                let ref_out = ref_result.values;

                prop_assert_eq!(out.len(), high.len(), "Output length mismatch");

                // Locate the first bar with a finite midpoint price.
                let mut first_valid_idx = None;
                for i in 0..high.len() {
                    let price = (high[i] + low[i]) / 2.0;
                    if !price.is_nan() {
                        first_valid_idx = Some(i);
                        break;
                    }
                }

                if let Some(first_idx) = first_valid_idx {
                    // LRSI needs 3 extra bars after the first valid price.
                    let warmup_end = first_idx + 3;

                    // Everything before the first valid price must be NaN.
                    for i in 0..first_idx {
                        prop_assert!(
                            out[i].is_nan(),
                            "Expected NaN before first valid at index {}, got {}",
                            i,
                            out[i]
                        );
                    }

                    // The first post-warmup output (if present) is in [0, 1].
                    let first_output_idx = first_idx + 3;
                    if first_output_idx < out.len() && !out[first_output_idx].is_nan() {
                        prop_assert!(
                            out[first_output_idx] >= 0.0 && out[first_output_idx] <= 1.0,
                            "First output after warmup at index {} = {}, should be in [0, 1]",
                            first_output_idx,
                            out[first_output_idx]
                        );
                    }

                    // All finite outputs stay inside the [0, 1] LRSI range.
                    for i in (first_idx + 3)..out.len() {
                        if !out[i].is_nan() {
                            prop_assert!(
                                out[i] >= 0.0 && out[i] <= 1.0,
                                "LRSI value {} at index {} outside [0, 1] range",
                                out[i],
                                i
                            );
                        }
                    }

                    // Cross-kernel agreement: exact bit match on non-finite
                    // values, otherwise within 1e-9 or 5 ULPs of scalar.
                    for i in 0..out.len() {
                        let y = out[i];
                        let r = ref_out[i];

                        if !y.is_finite() || !r.is_finite() {
                            prop_assert_eq!(
                                y.to_bits(),
                                r.to_bits(),
                                "NaN/infinite mismatch at index {}: {} vs {}",
                                i,
                                y,
                                r
                            );
                        } else {
                            let y_bits = y.to_bits();
                            let r_bits = r.to_bits();
                            let ulp_diff = y_bits.abs_diff(r_bits);

                            prop_assert!(
                                (y - r).abs() <= 1e-9 || ulp_diff <= 5,
                                "Kernel mismatch at index {}: {} vs {} (ULP={}, alpha={})",
                                i,
                                y,
                                r,
                                ulp_diff,
                                alpha
                            );
                        }
                    }

                    // Constant prices: the tail of the output must settle
                    // (largest adjacent step below 0.1).
                    let is_constant = high
                        .iter()
                        .zip(low.iter())
                        .all(|(h, l)| (h - l).abs() < f64::EPSILON && h.is_finite());

                    if is_constant && out.len() > first_idx + 10 {
                        let last_values = &out[out.len() - 5..];
                        let valid_last = last_values
                            .iter()
                            .filter(|v| v.is_finite())
                            .collect::<Vec<_>>();

                        if valid_last.len() >= 2 {
                            let variance = valid_last
                                .windows(2)
                                .map(|w| (w[1] - w[0]).abs())
                                .fold(0.0, f64::max);

                            prop_assert!(
                                variance < 0.1,
                                "LRSI not stable for constant prices, variance: {}",
                                variance
                            );
                        }
                    }

                    // Smoothness/responsiveness vs. alpha: compare the average
                    // output step in the mid-section against input volatility.
                    if out.len() > first_idx + 25 {
                        let prices: Vec<f64> = high
                            .iter()
                            .zip(low.iter())
                            .map(|(h, l)| (h + l) / 2.0)
                            .filter(|p| p.is_finite())
                            .collect();

                        if prices.len() >= 10 {
                            let price_changes: Vec<f64> = prices
                                .windows(2)
                                .map(|w| ((w[1] - w[0]) / w[0]).abs())
                                .collect();
                            let input_volatility = if !price_changes.is_empty() {
                                price_changes.iter().sum::<f64>() / price_changes.len() as f64
                            } else {
                                0.01
                            };

                            // Skip warmup and trailing edge effects.
                            let start = (first_idx + 5).min(out.len().saturating_sub(5));
                            let end = out.len().saturating_sub(5);
                            if start < end {
                                let mid_section = &out[start..end];
                                let valid_mid: Vec<f64> = mid_section
                                    .iter()
                                    .filter(|v| v.is_finite())
                                    .copied()
                                    .collect();

                                if valid_mid.len() >= 10 {
                                    let avg_change = valid_mid
                                        .windows(2)
                                        .map(|w| (w[1] - w[0]).abs())
                                        .sum::<f64>()
                                        / (valid_mid.len() - 1) as f64;

                                    if alpha < 0.05 {
                                        // Low alpha => heavy smoothing.
                                        let expected_max_change =
                                            input_volatility * (alpha * 20.0).max(0.1);
                                        prop_assert!(
                                            avg_change <= expected_max_change,
                                            "Low alpha ({}) should produce smooth output relative to input volatility. \
                                             Avg change: {}, Expected max: {}, Input volatility: {}",
                                            alpha,
                                            avg_change,
                                            expected_max_change,
                                            input_volatility
                                        );
                                    } else if alpha > 0.95 {
                                        // High alpha => fast response.
                                        let expected_min_change = (input_volatility * 0.2).min(0.1);

                                        if input_volatility > 0.01 {
                                            prop_assert!(
                                                avg_change >= expected_min_change || avg_change < 0.001,
                                                "High alpha ({}) should be responsive to input changes. \
                                                 Avg change: {}, Expected min: {}, Input volatility: {}",
                                                alpha,
                                                avg_change,
                                                expected_min_change,
                                                input_volatility
                                            );
                                        }
                                    }
                                }
                            }
                        }
                    }

                    // Trend response: on strictly monotonic input the average
                    // of the final third should move with the trend.
                    let is_monotonic_up = high.windows(2).all(|w| w[1] >= w[0])
                        && high.windows(2).any(|w| w[1] > w[0] + f64::EPSILON);
                    let is_monotonic_down = high.windows(2).all(|w| w[1] <= w[0])
                        && high.windows(2).any(|w| w[1] < w[0] - f64::EPSILON);

                    if (is_monotonic_up || is_monotonic_down) && out.len() > first_idx + 20 {
                        let valid_out: Vec<(usize, f64)> = out
                            .iter()
                            .enumerate()
                            .skip(first_idx + 1)
                            .filter(|(_, v)| v.is_finite())
                            .map(|(i, v)| (i, *v))
                            .collect();

                        if valid_out.len() >= 20 {
                            let chunk_size = valid_out.len() / 3;
                            if chunk_size >= 5 {
                                let first_chunk_avg =
                                    valid_out[..chunk_size].iter().map(|(_, v)| v).sum::<f64>()
                                        / chunk_size as f64;
                                let last_chunk_avg = valid_out[valid_out.len() - chunk_size..]
                                    .iter()
                                    .map(|(_, v)| v)
                                    .sum::<f64>()
                                    / chunk_size as f64;

                                // Relative price range, used only in messages.
                                let price_range = high
                                    .iter()
                                    .zip(low.iter())
                                    .map(|(h, l)| (h + l) / 2.0)
                                    .filter(|p| p.is_finite())
                                    .fold((f64::INFINITY, f64::NEG_INFINITY), |(min, max), p| {
                                        (min.min(p), max.max(p))
                                    });
                                let trend_strength = if price_range.1 > price_range.0 {
                                    (price_range.1 - price_range.0) / price_range.0
                                } else {
                                    0.01
                                };

                                let tolerance = (0.05 * (1.0 - alpha * 0.5)).min(0.05);

                                if is_monotonic_up {
                                    prop_assert!(
                                        last_chunk_avg >= first_chunk_avg - tolerance,
                                        "LRSI should respond to uptrend, but first_avg={}, last_avg={}, \
                                         tolerance={}, alpha={}, trend_strength={}",
                                        first_chunk_avg,
                                        last_chunk_avg,
                                        tolerance,
                                        alpha,
                                        trend_strength
                                    );
                                } else if is_monotonic_down {
                                    prop_assert!(
                                        last_chunk_avg <= first_chunk_avg + tolerance,
                                        "LRSI should respond to downtrend, but first_avg={}, last_avg={}, \
                                         tolerance={}, alpha={}, trend_strength={}",
                                        first_chunk_avg,
                                        last_chunk_avg,
                                        tolerance,
                                        alpha,
                                        trend_strength
                                    );
                                }
                            }
                        }
                    }

                    // Extreme alphas: near-zero must converge, near-one must
                    // still vary when the input varies.
                    if alpha < 0.02 || alpha > 0.98 {
                        if out.len() > first_idx + 50 {
                            let valid_values: Vec<f64> = out
                                .iter()
                                .skip(first_idx + 10)
                                .filter(|v| v.is_finite())
                                .copied()
                                .collect();

                            if valid_values.len() >= 20 {
                                if alpha < 0.02 {
                                    // Ignore the initial settling portion.
                                    let settled_values = if valid_values.len() > 10 {
                                        &valid_values[5..]
                                    } else {
                                        &valid_values[..]
                                    };

                                    if settled_values.len() >= 5 {
                                        let max_step = settled_values
                                            .windows(2)
                                            .map(|w| (w[1] - w[0]).abs())
                                            .fold(0.0f64, f64::max);

                                        prop_assert!(
                                            max_step < 0.7,
                                            "Extreme low alpha ({}) should produce smooth output after settling, \
                                             but max step is {}",
                                            alpha,
                                            max_step
                                        );
                                    }

                                    if valid_values.len() >= 20 {
                                        let last_10 = &valid_values[valid_values.len() - 10..];

                                        let min_val =
                                            last_10.iter().fold(f64::INFINITY, |a, &b| a.min(b));
                                        let max_val = last_10
                                            .iter()
                                            .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
                                        let range = max_val - min_val;

                                        prop_assert!(
                                            range < 0.5,
                                            "Extreme low alpha ({}) should converge to stable value, \
                                             but range in last 10 values is {}",
                                            alpha,
                                            range
                                        );
                                    }
                                } else {
                                    // alpha > 0.98: output should track input.
                                    let range = valid_values.iter().fold(
                                        (f64::INFINITY, f64::NEG_INFINITY),
                                        |(min, max), &v| (min.min(v), max.max(v)),
                                    );

                                    let input_has_variation =
                                        high.windows(2).any(|w| (w[1] - w[0]).abs() > w[0] * 0.001);

                                    if input_has_variation {
                                        prop_assert!(
                                            range.1 - range.0 > 0.05,
                                            "Extreme high alpha ({}) should produce varied output \
                                             for varied input, but range is only {}",
                                            alpha,
                                            range.1 - range.0
                                        );
                                    }
                                }
                            }
                        }
                    }
                }

                Ok(())
            })
            .unwrap();

        Ok(())
    }
1663
    /// No-op stand-in when the `proptest` feature is disabled, so that the
    /// generated test wrappers still compile and pass.
    #[cfg(not(feature = "proptest"))]
    fn check_lrsi_property(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        Ok(())
    }
1669
1670 macro_rules! generate_all_lrsi_tests {
1671 ($($test_fn:ident),*) => {
1672 paste::paste! {
1673 $(
1674 #[test]
1675 fn [<$test_fn _scalar_f64>]() {
1676 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1677 }
1678 )*
1679 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1680 $(
1681 #[test]
1682 fn [<$test_fn _avx2_f64>]() {
1683 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1684 }
1685 #[test]
1686 fn [<$test_fn _avx512_f64>]() {
1687 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1688 }
1689 )*
1690 }
1691 }
1692 }
1693
    // Instantiate the per-kernel #[test] wrappers for every check above.
    generate_all_lrsi_tests!(
        check_lrsi_partial_params,
        check_lrsi_accuracy,
        check_lrsi_default_candles,
        check_lrsi_invalid_alpha,
        check_lrsi_empty_data,
        check_lrsi_all_nan,
        check_lrsi_very_small_dataset,
        check_lrsi_streaming,
        check_lrsi_no_poison,
        check_lrsi_property
    );
1706
1707 fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1708 skip_if_unsupported!(kernel, test);
1709
1710 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1711 let c = read_candles_from_csv(file)?;
1712
1713 let output = LrsiBatchBuilder::new().kernel(kernel).apply_candles(&c)?;
1714
1715 let def = LrsiParams::default();
1716 let row = output.values_for(&def).expect("default row missing");
1717 assert_eq!(row.len(), c.close.len());
1718
1719 let expected = [0.0, 0.0, 0.0, 0.0, 0.0];
1720 let start = row.len() - 5;
1721 for (i, &v) in row[start..].iter().enumerate() {
1722 assert!(
1723 (v - expected[i]).abs() < 1e-9,
1724 "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
1725 );
1726 }
1727 Ok(())
1728 }
1729
    /// Debug-build check that batch sweeps never expose "poison" sentinel
    /// bit patterns (0x1111…, 0x2222…, 0x3333…) in their output matrix —
    /// each would indicate a cell left unwritten by the batch kernel.
    #[cfg(debug_assertions)]
    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);

        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;

        // Cap the series at 1000 bars to keep the sweep fast.
        let slice_end = c.close.len().min(1000);
        let high_slice = &c.high[..slice_end];
        let low_slice = &c.low[..slice_end];

        // (alpha_start, alpha_end, alpha_step) sweeps of varying density.
        let test_configs = vec![
            (0.1, 0.3, 0.1),
            (0.2, 0.8, 0.2),
            (0.5, 0.9, 0.1),
            (0.1, 0.95, 0.15),
            (0.85, 0.95, 0.05),
        ];

        for (cfg_idx, &(a_start, a_end, a_step)) in test_configs.iter().enumerate() {
            let output = LrsiBatchBuilder::new()
                .kernel(kernel)
                .alpha_range(a_start, a_end, a_step)
                .apply_slices(high_slice, low_slice)?;

            // Scan the flattened row-major matrix cell by cell.
            for (idx, &val) in output.values.iter().enumerate() {
                // NaN is the legitimate warmup marker, not poison.
                if val.is_nan() {
                    continue;
                }

                // Compare exact bit patterns; recover row/col for diagnostics.
                let bits = val.to_bits();
                let row = idx / output.cols;
                let col = idx % output.cols;
                let combo = &output.combos[row];

                if bits == 0x11111111_11111111 {
                    panic!(
                        "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
                         at row {} col {} (flat index {}) with params: alpha={}",
                        test,
                        cfg_idx,
                        val,
                        bits,
                        row,
                        col,
                        idx,
                        combo.alpha.unwrap_or(0.2)
                    );
                }

                if bits == 0x22222222_22222222 {
                    panic!(
                        "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
                         at row {} col {} (flat index {}) with params: alpha={}",
                        test,
                        cfg_idx,
                        val,
                        bits,
                        row,
                        col,
                        idx,
                        combo.alpha.unwrap_or(0.2)
                    );
                }

                if bits == 0x33333333_33333333 {
                    panic!(
                        "[{}] Config {}: Found third poison value {} (0x{:016X}) \
                         at row {} col {} (flat index {}) with params: alpha={}",
                        test,
                        cfg_idx,
                        val,
                        bits,
                        row,
                        col,
                        idx,
                        combo.alpha.unwrap_or(0.2)
                    );
                }
            }
        }

        Ok(())
    }
1814
    /// Poison scanning relies on debug-only fill patterns, so in release
    /// builds this check is a deliberate no-op.
    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
1819
1820 macro_rules! gen_batch_tests {
1821 ($fn_name:ident) => {
1822 paste::paste! {
1823 #[test] fn [<$fn_name _scalar>]() {
1824 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
1825 }
1826 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1827 #[test] fn [<$fn_name _avx2>]() {
1828 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
1829 }
1830 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1831 #[test] fn [<$fn_name _avx512>]() {
1832 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
1833 }
1834 #[test] fn [<$fn_name _auto_detect>]() {
1835 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
1836 }
1837 }
1838 };
1839 }
    // Instantiate the batch-kernel test wrappers for both batch checks.
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
1842}
1843
#[inline(always)]
/// Core batch driver: expands `sweep` into concrete parameter combos and
/// writes one LRSI row per combo into `out`, row-major with shape
/// `combos.len() x high.len()`. Returns the expanded combos on success so
/// callers can label the rows.
///
/// `kern` must already be a resolved per-row kernel (Scalar/Avx2/Avx512);
/// `Auto` and batch kernels hit `unreachable!`.
fn lrsi_batch_inner_into(
    high: &[f64],
    low: &[f64],
    sweep: &LrsiBatchRange,
    kern: Kernel,
    parallel: bool,
    out: &mut [f64],
) -> Result<Vec<LrsiParams>, LrsiError> {
    let combos = expand_grid(sweep)?;
    if combos.is_empty() {
        // An empty grid means the (start, end, step) alpha range is degenerate.
        return Err(LrsiError::InvalidRange {
            start: sweep.alpha.0.to_string(),
            end: sweep.alpha.1.to_string(),
            step: sweep.alpha.2.to_string(),
        });
    }

    if high.is_empty() || low.is_empty() {
        return Err(LrsiError::EmptyInputData);
    }
    if high.len() != low.len() {
        // NOTE(review): a high/low length mismatch is reported as
        // EmptyInputData; a dedicated mismatch variant would be clearer —
        // confirm whether one exists on LrsiError.
        return Err(LrsiError::EmptyInputData);
    }

    let len = high.len();
    // Midpoint price series (high+low)/2, shared by every parameter row.
    let mut prices = Vec::with_capacity(len);
    prices.extend((0..len).map(|i| (high[i] + low[i]) * 0.5));
    let first = (0..len)
        .find(|&i| prices[i].is_finite())
        .ok_or(LrsiError::AllValuesNaN)?;
    // At least 4 valid points are required before any output can be produced.
    if len - first < 4 {
        return Err(LrsiError::NotEnoughValidData {
            needed: 4,
            valid: len - first,
        });
    }

    let rows = combos.len();
    let cols = high.len();
    // checked_mul guards against overflow when sizing the output matrix.
    let expected = rows.checked_mul(cols).ok_or(LrsiError::InvalidRange {
        start: rows.to_string(),
        end: cols.to_string(),
        step: "rows*cols".into(),
    })?;
    if out.len() != expected {
        return Err(LrsiError::OutputLengthMismatch {
            expected,
            got: out.len(),
        });
    }

    // SAFETY: MaybeUninit<f64> has the same size and alignment as f64, and
    // `out` is a valid, exclusively borrowed slice; viewing initialized f64s
    // as MaybeUninit is sound, and we only ever write plain f64s through it.
    let out_mu: &mut [std::mem::MaybeUninit<f64>] = unsafe {
        std::slice::from_raw_parts_mut(
            out.as_mut_ptr() as *mut std::mem::MaybeUninit<f64>,
            out.len(),
        )
    };

    // Computes one parameter row into its destination slice (length = cols).
    let do_row = |row: usize, dst_row_mu: &mut [std::mem::MaybeUninit<f64>]| unsafe {
        let alpha = combos[row].alpha.unwrap();

        // SAFETY: same layout argument as above, in the reverse direction;
        // the underlying memory originated from an initialized &mut [f64].
        let dst_row =
            std::slice::from_raw_parts_mut(dst_row_mu.as_mut_ptr() as *mut f64, dst_row_mu.len());

        // NaN-fill the warmup prefix before the first computable output.
        let warmup_end = first + 3;
        for i in 0..warmup_end.min(dst_row.len()) {
            dst_row[i] = f64::NAN;
        }

        match kern {
            Kernel::Scalar => lrsi_row_scalar(&prices, first, alpha, dst_row),
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx2 => lrsi_row_avx2(&prices, first, alpha, dst_row),
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx512 => lrsi_row_avx512(&prices, first, alpha, dst_row),
            // Callers must resolve Auto/batch kernels before reaching here.
            Kernel::Auto | _ => unreachable!(),
        }
    };

    if parallel {
        // Row-parallel via rayon on native targets; wasm has no rayon and
        // falls back to the serial loop.
        #[cfg(not(target_arch = "wasm32"))]
        out_mu
            .par_chunks_mut(cols)
            .enumerate()
            .for_each(|(r, row)| do_row(r, row));
        #[cfg(target_arch = "wasm32")]
        for (r, row) in out_mu.chunks_mut(cols).enumerate() {
            do_row(r, row);
        }
    } else {
        for (r, row) in out_mu.chunks_mut(cols).enumerate() {
            do_row(r, row);
        }
    }

    Ok(combos)
}
1942
#[cfg(feature = "python")]
#[pyfunction(name = "lrsi")]
#[pyo3(signature = (high, low, alpha, kernel=None))]
/// Python entry point: computes LRSI over high/low arrays and returns a
/// NumPy array of the same length. Indicator errors surface as ValueError.
pub fn lrsi_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    alpha: f64,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    // Borrow the NumPy buffers as plain slices (zero-copy).
    let high_slice = high.as_slice()?;
    let low_slice = low.as_slice()?;
    let kern = validate_kernel(kernel, false)?;
    let input = LrsiInput::from_slices(
        high_slice,
        low_slice,
        LrsiParams { alpha: Some(alpha) },
    );

    // Release the GIL while the Rust kernel runs.
    let values = py
        .allow_threads(|| lrsi_with_kernel(&input, kern).map(|o| o.values))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;
    Ok(values.into_pyarray(py))
}
1964
#[cfg(feature = "python")]
#[pyclass(name = "LrsiStream")]
/// Python wrapper around the incremental Rust `LrsiStream`.
pub struct LrsiStreamPy {
    // Underlying Rust streaming LRSI state machine.
    stream: LrsiStream,
}
1970
#[cfg(feature = "python")]
#[pymethods]
impl LrsiStreamPy {
    /// Construct a streaming LRSI with the given alpha; invalid parameters
    /// surface as a Python ValueError.
    #[new]
    fn new(alpha: f64) -> PyResult<Self> {
        match LrsiStream::try_new(LrsiParams { alpha: Some(alpha) }) {
            Ok(stream) => Ok(Self { stream }),
            Err(e) => Err(PyValueError::new_err(e.to_string())),
        }
    }

    /// Push one bar; the stream consumes the high/low midpoint. Returns
    /// None while the indicator is still warming up.
    fn update(&mut self, high: f64, low: f64) -> Option<f64> {
        let mid = (high + low) / 2.0;
        self.stream.update(mid)
    }
}
1987
#[cfg(feature = "python")]
#[pyfunction(name = "lrsi_batch")]
#[pyo3(signature = (high, low, alpha_range, kernel=None))]
/// Python entry point for a parameter sweep: computes one LRSI row per
/// alpha in `alpha_range` (start, end, step) and returns a dict with
/// "values" (rows x cols NumPy matrix) and "alphas" (one per row).
pub fn lrsi_batch_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    alpha_range: (f64, f64, f64),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};

    let h = high.as_slice()?;
    let l = low.as_slice()?;
    let sweep = LrsiBatchRange { alpha: alpha_range };
    // Expand the grid up-front so the output array can be sized exactly.
    let combos = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let rows = combos.len();
    let cols = h.len();

    // SAFETY: the array is allocated uninitialized (`false`) and is fully
    // written by `lrsi_batch_inner_into` before being handed to Python; on
    // error it is dropped unread.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [rows * cols], false) };
    // SAFETY: this is the only reference to the freshly created array.
    // NOTE(review): the slice is used inside allow_threads below; sound only
    // because no Python code can observe the array yet — confirm.
    let out_slice = unsafe { out_arr.as_slice_mut()? };

    let kern = validate_kernel(kernel, true)?;
    py.allow_threads(|| {
        // Resolve Auto to a concrete batch kernel, then map the batch
        // kernel to its per-row equivalent expected by the inner driver.
        let resolved = match kern {
            Kernel::Auto => detect_best_batch_kernel(),
            k => k,
        };
        let row_kernel = match resolved {
            Kernel::Avx512Batch => Kernel::Avx512,
            Kernel::Avx2Batch => Kernel::Avx2,
            Kernel::ScalarBatch => Kernel::Scalar,
            _ => unreachable!(),
        };
        lrsi_batch_inner_into(h, l, &sweep, row_kernel, true, out_slice)
    })
    .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "alphas",
        combos
            .iter()
            .map(|p| p.alpha.unwrap())
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    Ok(dict)
}
2038
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "lrsi_cuda_batch_dev")]
#[pyo3(signature = (high_f32, low_f32, alpha_range, device_id=0))]
/// CUDA sweep entry point: runs the alpha sweep on the GPU and returns a
/// handle to the device-resident f32 result (no copy back to the host).
pub fn lrsi_cuda_batch_dev_py(
    py: Python<'_>,
    high_f32: numpy::PyReadonlyArray1<'_, f32>,
    low_f32: numpy::PyReadonlyArray1<'_, f32>,
    alpha_range: (f64, f64, f64),
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    use crate::cuda::cuda_available;
    use crate::cuda::oscillators::CudaLrsi;

    // Fail fast when no CUDA driver/device is present.
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let highs = high_f32.as_slice()?;
    let lows = low_f32.as_slice()?;
    if highs.len() != lows.len() {
        return Err(PyValueError::new_err("mismatched input lengths"));
    }

    let sweep = LrsiBatchRange { alpha: alpha_range };
    // Run the whole sweep on the GPU without holding the GIL.
    let inner = py.allow_threads(|| {
        let mut cuda =
            CudaLrsi::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        cuda.lrsi_batch_dev(highs, lows, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;
    make_device_array_py(device_id, inner)
}
2069
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "lrsi_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm_f32, low_tm_f32, alpha, device_id=0))]
/// CUDA entry point: one alpha applied across many series supplied as
/// time-major 2-D f32 matrices; returns a device-resident result handle.
pub fn lrsi_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    high_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    low_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    alpha: f64,
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    use crate::cuda::cuda_available;
    use crate::cuda::oscillators::CudaLrsi;
    use numpy::PyUntypedArrayMethods;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    // as_slice() also enforces that the arrays are contiguous.
    let h = high_tm_f32.as_slice()?;
    let l = low_tm_f32.as_slice()?;
    let rows = high_tm_f32.shape()[0];
    let cols = high_tm_f32.shape()[1];
    // Both matrices must have identical shape.
    if low_tm_f32.shape() != [rows, cols] {
        return Err(PyValueError::new_err("mismatched matrix shapes"));
    }
    let inner = py.allow_threads(|| {
        let mut cuda =
            CudaLrsi::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        // NOTE(review): arguments are passed as (cols, rows) — presumably
        // the CUDA API takes (num_series, num_timesteps) for time-major
        // input; confirm against CudaLrsi's signature.
        cuda.lrsi_many_series_one_param_time_major_dev(h, l, cols, rows, alpha)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;
    let handle = make_device_array_py(device_id, inner)?;
    Ok(handle)
}