1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
3#[cfg(all(feature = "python", feature = "cuda"))]
4use cust::context::Context;
5#[cfg(all(feature = "python", feature = "cuda"))]
6use cust::memory::DeviceBuffer;
7#[cfg(feature = "python")]
8use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
9#[cfg(feature = "python")]
10use pyo3::exceptions::PyValueError;
11#[cfg(feature = "python")]
12use pyo3::prelude::*;
13#[cfg(feature = "python")]
14use pyo3::types::PyDict;
15
16#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
17use serde::{Deserialize, Serialize};
18#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
19use wasm_bindgen::prelude::*;
20
21#[cfg(all(feature = "python", feature = "cuda"))]
22use std::sync::Arc;
23
24use crate::utilities::data_loader::Candles;
25use crate::utilities::enums::Kernel;
26use crate::utilities::helpers::{
27 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
28 make_uninit_matrix,
29};
30#[cfg(feature = "python")]
31use crate::utilities::kernel_validation::validate_kernel;
32#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
33use core::arch::x86_64::*;
34#[cfg(not(target_arch = "wasm32"))]
35use rayon::prelude::*;
36use std::error::Error;
37use thiserror::Error;
38
/// Input price series for a SAR computation: either a borrowed `Candles`
/// struct or raw high/low slices (callers keep the two slices index-aligned).
#[derive(Debug, Clone)]
pub enum SarData<'a> {
    Candles { candles: &'a Candles },
    Slices { high: &'a [f64], low: &'a [f64] },
}
44
/// Result of a SAR computation: one value per input bar
/// (the warm-up prefix is NaN).
#[derive(Debug, Clone)]
pub struct SarOutput {
    pub values: Vec<f64>,
}
49
/// SAR tuning parameters. `None` falls back to the conventional defaults
/// (acceleration 0.02, maximum 0.2) applied by `Default` and the getters.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct SarParams {
    // Acceleration-factor increment applied on each new extreme (default 0.02).
    pub acceleration: Option<f64>,
    // Upper cap on the acceleration factor (default 0.2).
    pub maximum: Option<f64>,
}
59
impl Default for SarParams {
    /// Conventional SAR defaults: acceleration 0.02, maximum 0.2.
    fn default() -> Self {
        Self {
            acceleration: Some(0.02),
            maximum: Some(0.2),
        }
    }
}
68
/// Bundles the price data with the parameters for one SAR computation.
#[derive(Debug, Clone)]
pub struct SarInput<'a> {
    pub data: SarData<'a>,
    pub params: SarParams,
}
74
impl<'a> SarInput<'a> {
    /// Builds an input over full candle data.
    #[inline]
    pub fn from_candles(candles: &'a Candles, params: SarParams) -> Self {
        Self {
            data: SarData::Candles { candles },
            params,
        }
    }

    /// Builds an input over raw high/low slices.
    #[inline]
    pub fn from_slices(high: &'a [f64], low: &'a [f64], params: SarParams) -> Self {
        Self {
            data: SarData::Slices { high, low },
            params,
        }
    }

    /// Candle input with the default parameters (0.02 / 0.2).
    #[inline]
    pub fn with_default_candles(candles: &'a Candles) -> Self {
        Self {
            data: SarData::Candles { candles },
            params: SarParams::default(),
        }
    }

    /// Acceleration increment, defaulting to 0.02 when unset.
    #[inline]
    pub fn get_acceleration(&self) -> f64 {
        self.params.acceleration.unwrap_or(0.02)
    }

    /// Acceleration cap, defaulting to 0.2 when unset.
    #[inline]
    pub fn get_maximum(&self) -> f64 {
        self.params.maximum.unwrap_or(0.2)
    }
}
110
/// Fluent builder for a single SAR run; unset fields use the defaults,
/// and the kernel defaults to `Kernel::Auto`.
#[derive(Copy, Clone, Debug)]
pub struct SarBuilder {
    acceleration: Option<f64>,
    maximum: Option<f64>,
    kernel: Kernel,
}
117
impl Default for SarBuilder {
    /// No explicit parameters; kernel selection left to `Kernel::Auto`.
    fn default() -> Self {
        Self {
            acceleration: None,
            maximum: None,
            kernel: Kernel::Auto,
        }
    }
}
127
impl SarBuilder {
    /// Starts a builder with all settings at their defaults.
    #[inline(always)]
    pub fn new() -> Self {
        Self::default()
    }
    /// Sets the acceleration increment.
    #[inline(always)]
    pub fn acceleration(mut self, v: f64) -> Self {
        self.acceleration = Some(v);
        self
    }
    /// Sets the acceleration cap.
    #[inline(always)]
    pub fn maximum(mut self, v: f64) -> Self {
        self.maximum = Some(v);
        self
    }
    /// Overrides the compute kernel.
    #[inline(always)]
    pub fn kernel(mut self, k: Kernel) -> Self {
        self.kernel = k;
        self
    }
    /// Runs SAR on candle data with the configured settings.
    #[inline(always)]
    pub fn apply(self, c: &Candles) -> Result<SarOutput, SarError> {
        let params = SarParams {
            acceleration: self.acceleration,
            maximum: self.maximum,
        };
        let input = SarInput::from_candles(c, params);
        sar_with_kernel(&input, self.kernel)
    }
    /// Runs SAR on raw high/low slices with the configured settings.
    #[inline(always)]
    pub fn apply_slices(self, high: &[f64], low: &[f64]) -> Result<SarOutput, SarError> {
        let params = SarParams {
            acceleration: self.acceleration,
            maximum: self.maximum,
        };
        let input = SarInput::from_slices(high, low, params);
        sar_with_kernel(&input, self.kernel)
    }
    /// Converts the builder into an incremental `SarStream`.
    /// Note: the configured `kernel` does not apply to streaming.
    #[inline(always)]
    pub fn into_stream(self) -> Result<SarStream, SarError> {
        let params = SarParams {
            acceleration: self.acceleration,
            maximum: self.maximum,
        };
        SarStream::try_new(params)
    }
}
175
/// Errors produced by the SAR entry points; each message carries the
/// offending values for diagnosis.
#[derive(Debug, Error)]
pub enum SarError {
    #[error("sar: Empty data provided.")]
    EmptyInputData,
    #[error("sar: All values are NaN.")]
    AllValuesNaN,
    #[error("sar: Not enough valid data. needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    #[error("sar: Invalid acceleration: {acceleration}")]
    InvalidAcceleration { acceleration: f64 },
    #[error("sar: Invalid maximum: {maximum}")]
    InvalidMaximum { maximum: f64 },
    #[error("sar: Output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    #[error("sar: Invalid parameter range: start = {start}, end = {end}, step = {step}")]
    InvalidRange { start: f64, end: f64, step: f64 },
    #[error("sar: Invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(Kernel),
}
195
/// Computes the Parabolic SAR with automatic kernel selection.
#[inline]
pub fn sar(input: &SarInput) -> Result<SarOutput, SarError> {
    sar_with_kernel(input, Kernel::Auto)
}
200
201pub fn sar_with_kernel(input: &SarInput, kernel: Kernel) -> Result<SarOutput, SarError> {
202 let (high, low) = match &input.data {
203 SarData::Candles { candles } => (candles.high.as_slice(), candles.low.as_slice()),
204 SarData::Slices { high, low } => (*high, *low),
205 };
206
207 if high.is_empty() || low.is_empty() {
208 return Err(SarError::EmptyInputData);
209 }
210
211 let min_len = high.len().min(low.len());
212 let (high, low) = (&high[..min_len], &low[..min_len]);
213
214 let first_valid_idx = high
215 .iter()
216 .zip(low.iter())
217 .position(|(&h, &l)| !h.is_nan() && !l.is_nan());
218 let first = match first_valid_idx {
219 Some(idx) => idx,
220 None => return Err(SarError::AllValuesNaN),
221 };
222
223 if (high.len() - first) < 2 {
224 return Err(SarError::NotEnoughValidData {
225 needed: 2,
226 valid: high.len() - first,
227 });
228 }
229
230 let acceleration = input.get_acceleration();
231 let maximum = input.get_maximum();
232
233 if !(acceleration > 0.0) || acceleration.is_nan() || acceleration.is_infinite() {
234 return Err(SarError::InvalidAcceleration { acceleration });
235 }
236 if !(maximum > 0.0) || maximum.is_nan() || maximum.is_infinite() {
237 return Err(SarError::InvalidMaximum { maximum });
238 }
239
240 let mut out = alloc_with_nan_prefix(high.len(), first + 1);
241
242 let chosen = match kernel {
243 Kernel::Auto => Kernel::Scalar,
244 other => other,
245 };
246
247 unsafe {
248 match chosen {
249 Kernel::Scalar | Kernel::ScalarBatch => {
250 sar_scalar(high, low, first, acceleration, maximum, &mut out)
251 }
252 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
253 Kernel::Avx2 | Kernel::Avx2Batch => {
254 sar_avx2(high, low, first, acceleration, maximum, &mut out)
255 }
256 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
257 Kernel::Avx512 | Kernel::Avx512Batch => {
258 sar_avx512(high, low, first, acceleration, maximum, &mut out)
259 }
260 _ => unreachable!(),
261 }
262 }
263
264 Ok(SarOutput { values: out })
265}
266
267#[inline]
268pub fn sar_into_slice(dst: &mut [f64], input: &SarInput, kern: Kernel) -> Result<(), SarError> {
269 let (high, low) = match &input.data {
270 SarData::Candles { candles } => (candles.high.as_slice(), candles.low.as_slice()),
271 SarData::Slices { high, low } => (*high, *low),
272 };
273
274 if high.is_empty() || low.is_empty() {
275 return Err(SarError::EmptyInputData);
276 }
277
278 let expected_len = high.len().min(low.len());
279 if dst.len() != expected_len {
280 return Err(SarError::OutputLengthMismatch {
281 expected: expected_len,
282 got: dst.len(),
283 });
284 }
285
286 let (high, low) = (&high[..expected_len], &low[..expected_len]);
287
288 let first_valid_idx = high
289 .iter()
290 .zip(low.iter())
291 .position(|(&h, &l)| !h.is_nan() && !l.is_nan());
292 let first = match first_valid_idx {
293 Some(idx) => idx,
294 None => return Err(SarError::AllValuesNaN),
295 };
296
297 if (high.len() - first) < 2 {
298 return Err(SarError::NotEnoughValidData {
299 valid: high.len() - first,
300 needed: 2,
301 });
302 }
303
304 let acceleration = input.params.acceleration.unwrap_or(0.02);
305 let maximum = input.params.maximum.unwrap_or(0.2);
306
307 if acceleration <= 0.0 || acceleration.is_nan() || acceleration.is_infinite() {
308 return Err(SarError::InvalidAcceleration { acceleration });
309 }
310 if maximum <= 0.0 || maximum.is_nan() || maximum.is_infinite() {
311 return Err(SarError::InvalidMaximum { maximum });
312 }
313
314 for v in &mut dst[..first.saturating_add(1)] {
315 *v = f64::from_bits(0x7ff8_0000_0000_0000);
316 }
317
318 let chosen = match kern {
319 Kernel::Auto => Kernel::Scalar,
320 x => x,
321 };
322
323 match chosen {
324 Kernel::Scalar => sar_scalar(high, low, first, acceleration, maximum, dst),
325 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
326 Kernel::Avx2 | Kernel::Avx2Batch => sar_avx2(high, low, first, acceleration, maximum, dst),
327 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
328 Kernel::Avx512 | Kernel::Avx512Batch => {
329 sar_avx512(high, low, first, acceleration, maximum, dst)
330 }
331
332 _ => sar_scalar(high, low, first, acceleration, maximum, dst),
333 }
334 Ok(())
335}
336
/// Convenience wrapper: SAR into a caller buffer with automatic kernel
/// selection. Not compiled on wasm builds (they use `sar_js`).
#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
#[inline]
pub fn sar_into(input: &SarInput, out: &mut [f64]) -> Result<(), SarError> {
    sar_into_slice(out, input, Kernel::Auto)
}
342
/// AVX-512 entry point. Currently forwards to `sar_avx2` — the SAR
/// recurrence is serially dependent, so no dedicated AVX-512 path exists.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn sar_avx512(
    high: &[f64],
    low: &[f64],
    first_valid: usize,
    acceleration: f64,
    maximum: f64,
    out: &mut [f64],
) {
    sar_avx2(high, low, first_valid, acceleration, maximum, out)
}
355
/// Scalar Parabolic SAR kernel.
///
/// Seeds trend state from the first two valid bars (`first` and `first + 1`),
/// writes NaN at `out[first]` and the seed SAR at `out[first + 1]`, then
/// iterates the classic recurrence: accelerate toward the extreme point,
/// clamp the SAR against the previous two bars' range, and flip direction
/// when price penetrates the SAR. Returns early (writing nothing) when fewer
/// than two bars fit in all three slices.
#[inline]
pub fn sar_scalar(
    high: &[f64],
    low: &[f64],
    first: usize,
    acceleration: f64,
    maximum: f64,
    out: &mut [f64],
) {
    let len = high.len();
    let seed = first;
    let second = seed + 1;

    // Need two bars, and room for both in every slice, to seed the state.
    if second >= len || second >= low.len() || second >= out.len() {
        return;
    }

    let h_seed = high[seed];
    let h_second = high[second];
    let l_seed = low[seed];
    let l_second = low[second];

    // A rising second bar starts an up-trend anchored at the seed bar's low.
    let mut long = h_second > h_seed;
    let mut sar = if long { l_seed } else { h_seed };
    let mut extreme = if long { h_second } else { l_second };
    let mut af = acceleration;

    out[seed] = f64::NAN;
    out[second] = sar;

    // Rolling window of the two most recent lows/highs before bar `i`.
    let mut lo_back2 = l_seed;
    let mut lo_back1 = l_second;
    let mut hi_back2 = h_seed;
    let mut hi_back1 = h_second;

    for i in (second + 1)..len {
        let hi = high[i];
        let lo = low[i];

        // Candidate SAR: step toward the extreme point by the current AF.
        let mut next = af.mul_add(extreme - sar, sar);

        if long {
            if lo < next {
                // Price pierced the SAR: flip to a down-trend.
                long = false;
                next = extreme;
                extreme = lo;
                af = acceleration;
            } else {
                if hi > extreme {
                    // New high: advance the extreme and accelerate.
                    extreme = hi;
                    af = (af + acceleration).min(maximum);
                }
                // The SAR may not enter the prior two bars' range.
                next = next.min(lo_back1).min(lo_back2);
            }
        } else {
            if hi > next {
                // Price pierced the SAR: flip to an up-trend.
                long = true;
                next = extreme;
                extreme = hi;
                af = acceleration;
            } else {
                if lo < extreme {
                    extreme = lo;
                    af = (af + acceleration).min(maximum);
                }
                next = next.max(hi_back1).max(hi_back2);
            }
        }

        out[i] = next;
        sar = next;

        lo_back2 = lo_back1;
        lo_back1 = lo;
        hi_back2 = hi_back1;
        hi_back1 = hi;
    }
}
437
/// "AVX2" SAR kernel. Despite the name there are no SIMD intrinsics here —
/// the SAR recurrence is serially dependent — this is the scalar algorithm
/// with bounds checks removed via `get_unchecked`.
///
/// Behavior is identical to `sar_scalar` for the same inputs.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn sar_avx2(
    high: &[f64],
    low: &[f64],
    first_valid: usize,
    acceleration: f64,
    maximum: f64,
    out: &mut [f64],
) {
    let len = high.len();
    let i0 = first_valid;
    let i1 = i0 + 1;

    // Guards i0/i1 for all three slices before the unchecked accesses below.
    if i1 >= len || i1 >= low.len() || i1 >= out.len() {
        return;
    }

    // SAFETY: i0 < i1 < len and i1 < low.len() / out.len() per the guard
    // above; the loop indexes up to len - 1 on `high` and `out`.
    // NOTE(review): `low` is only bounds-checked at i1; callers must pass
    // `low.len() >= high.len()` (all call sites trim to min_len — confirm
    // this invariant before adding new callers).
    unsafe {
        let h0 = *high.get_unchecked(i0);
        let h1 = *high.get_unchecked(i1);
        let l0 = *low.get_unchecked(i0);
        let l1 = *low.get_unchecked(i1);

        // Seed: rising second bar starts an up-trend anchored at the first low.
        let mut trend_up = h1 > h0;
        let mut sar = if trend_up { l0 } else { h0 };
        let mut ep = if trend_up { h1 } else { l1 };
        let mut acc = acceleration;

        *out.get_unchecked_mut(i0) = f64::NAN;
        *out.get_unchecked_mut(i1) = sar;

        // Two-bar rolling window used to clamp the SAR.
        let mut low_prev2 = l0;
        let mut low_prev = l1;
        let mut high_prev2 = h0;
        let mut high_prev = h1;

        let mut i = i1 + 1;
        while i < len {
            let hi = *high.get_unchecked(i);
            let lo = *low.get_unchecked(i);

            // Candidate SAR: accelerate toward the extreme point.
            let mut next_sar = acc.mul_add(ep - sar, sar);

            if trend_up {
                if lo < next_sar {
                    // Price pierced the SAR: flip to a down-trend.
                    trend_up = false;
                    next_sar = ep;
                    ep = lo;
                    acc = acceleration;
                } else {
                    if hi > ep {
                        ep = hi;
                        acc = (acc + acceleration).min(maximum);
                    }
                    // SAR may not enter the prior two bars' range.
                    next_sar = next_sar.min(low_prev).min(low_prev2);
                }
            } else {
                if hi > next_sar {
                    trend_up = true;
                    next_sar = ep;
                    ep = hi;
                    acc = acceleration;
                } else {
                    if lo < ep {
                        ep = lo;
                        acc = (acc + acceleration).min(maximum);
                    }
                    next_sar = next_sar.max(high_prev).max(high_prev2);
                }
            }

            *out.get_unchecked_mut(i) = next_sar;
            sar = next_sar;

            low_prev2 = low_prev;
            low_prev = lo;
            high_prev2 = high_prev;
            high_prev = hi;

            i += 1;
        }
    }
}
522
/// wasm simd128 entry point; forwards to the scalar kernel (the SAR
/// recurrence does not vectorize).
#[cfg(all(feature = "simd128", target_arch = "wasm32"))]
#[inline]
pub fn sar_simd128(
    high: &[f64],
    low: &[f64],
    first_valid: usize,
    acceleration: f64,
    maximum: f64,
    out: &mut [f64],
) {
    sar_scalar(high, low, first_valid, acceleration, maximum, out)
}
535
/// AVX-512 "short input" variant; forwards to `sar_avx2`.
///
/// # Safety
/// Same contract as `sar_avx2`: `low.len() >= high.len()` expected by its
/// unchecked indexing.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn sar_avx512_short(
    high: &[f64],
    low: &[f64],
    first_valid: usize,
    acceleration: f64,
    maximum: f64,
    out: &mut [f64],
) {
    sar_avx2(high, low, first_valid, acceleration, maximum, out)
}
548
/// AVX-512 "long input" variant; forwards to `sar_avx2`.
///
/// # Safety
/// Same contract as `sar_avx2`: `low.len() >= high.len()` expected by its
/// unchecked indexing.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn sar_avx512_long(
    high: &[f64],
    low: &[f64],
    first_valid: usize,
    acceleration: f64,
    maximum: f64,
    out: &mut [f64],
) {
    sar_avx2(high, low, first_valid, acceleration, maximum, out)
}
561
/// Incremental (bar-by-bar) SAR calculator; mirrors `sar_scalar` output.
#[derive(Debug, Clone)]
pub struct SarStream {
    acceleration: f64,
    maximum: f64,
    // None until the first finite bar is accepted.
    state: Option<StreamState>,

    // Count of accepted bars: 0 = empty, 1 = seed bar only, >= 2 = running.
    idx: usize,
}
570
/// Mutable SAR recurrence state carried between `update` calls.
#[derive(Debug, Clone)]
struct StreamState {
    trend_up: bool,
    sar: f64,
    // Extreme point of the current trend.
    ep: f64,
    // Current acceleration factor.
    acc: f64,

    // Previous two bars' extremes, used to clamp the SAR.
    prev_high: f64,
    prev_high2: f64,
    prev_low: f64,
    prev_low2: f64,
}
583
impl SarStream {
    /// Validates the parameters and builds an empty stream.
    ///
    /// # Errors
    /// Rejects non-finite or non-positive `acceleration` / `maximum`.
    pub fn try_new(params: SarParams) -> Result<Self, SarError> {
        let acceleration = params.acceleration.unwrap_or(0.02);
        let maximum = params.maximum.unwrap_or(0.2);

        if !(acceleration > 0.0) || !acceleration.is_finite() {
            return Err(SarError::InvalidAcceleration { acceleration });
        }
        if !(maximum > 0.0) || !maximum.is_finite() {
            return Err(SarError::InvalidMaximum { maximum });
        }

        Ok(Self {
            acceleration,
            maximum,
            state: None,
            idx: 0,
        })
    }

    /// Feeds one bar; returns its SAR value, or `None` during warm-up.
    ///
    /// Non-finite inputs are rejected without advancing the stream. The first
    /// accepted bar only seeds state (no output), matching the batch kernels
    /// where `out[first]` is NaN.
    #[inline(always)]
    pub fn update(&mut self, high: f64, low: f64) -> Option<f64> {
        if !high.is_finite() || !low.is_finite() {
            return None;
        }

        match self.state.as_mut() {
            // First bar: record its extremes; direction is decided next bar.
            None => {
                self.state = Some(StreamState {
                    trend_up: false,
                    sar: f64::NAN,
                    ep: f64::NAN,
                    acc: self.acceleration,
                    prev_high: high,
                    prev_high2: high,
                    prev_low: low,
                    prev_low2: low,
                });
                self.idx = 1;
                None
            }

            // Second bar: seed trend, SAR, and extreme point exactly like the
            // batch kernels do from bars `first` and `first + 1`.
            Some(st) if self.idx == 1 => {
                let trend_up = high > st.prev_high;

                let sar = if trend_up { st.prev_low } else { st.prev_high };
                let ep = if trend_up { high } else { low };

                st.prev_high2 = st.prev_high;
                st.prev_low2 = st.prev_low;
                st.prev_high = high;
                st.prev_low = low;

                st.trend_up = trend_up;
                st.sar = sar;
                st.ep = ep;
                st.acc = self.acceleration;

                self.idx = 2;
                Some(sar)
            }

            // Steady state: one step of the SAR recurrence.
            Some(st) => {
                // Candidate SAR: accelerate toward the extreme point.
                let mut next_sar = st.acc.mul_add(st.ep - st.sar, st.sar);

                if st.trend_up {
                    if low < next_sar {
                        // Price pierced the SAR: flip to a down-trend.
                        st.trend_up = false;
                        next_sar = st.ep;
                        st.ep = low;
                        st.acc = self.acceleration;
                    } else {
                        if high > st.ep {
                            st.ep = high;
                            st.acc = (st.acc + self.acceleration).min(self.maximum);
                        }
                        // Clamp against the previous two bars' range.
                        next_sar = min3(next_sar, st.prev_low, st.prev_low2);
                    }
                } else {
                    if high > next_sar {
                        // Price pierced the SAR: flip to an up-trend.
                        st.trend_up = true;
                        next_sar = st.ep;
                        st.ep = high;
                        st.acc = self.acceleration;
                    } else {
                        if low < st.ep {
                            st.ep = low;
                            st.acc = (st.acc + self.acceleration).min(self.maximum);
                        }
                        next_sar = max3(next_sar, st.prev_high, st.prev_high2);
                    }
                }

                st.prev_high2 = st.prev_high;
                st.prev_low2 = st.prev_low;
                st.prev_high = high;
                st.prev_low = low;

                st.sar = next_sar;
                self.idx += 1;
                Some(next_sar)
            }
        }
    }
}
689
/// Smallest of three values, via `f64::min` (NaN operands are ignored in
/// favor of the non-NaN side, as `f64::min` does).
#[inline(always)]
fn min3(a: f64, b: f64, c: f64) -> f64 {
    let tail = b.min(c);
    a.min(tail)
}
694
/// Largest of three values, via `f64::max` (NaN operands are ignored in
/// favor of the non-NaN side, as `f64::max` does).
#[inline(always)]
fn max3(a: f64, b: f64, c: f64) -> f64 {
    let tail = b.max(c);
    a.max(tail)
}
699
/// Parameter sweep for batch SAR: each axis is `(start, end, step)`.
/// A near-zero step, or `start ~= end`, collapses the axis to `start`
/// (see `axis_f64_checked`).
#[derive(Clone, Debug)]
pub struct SarBatchRange {
    pub acceleration: (f64, f64, f64),
    pub maximum: (f64, f64, f64),
}
705
impl Default for SarBatchRange {
    /// Fixed acceleration (0.02) and a maximum sweep 0.2..=0.449 by 0.001.
    fn default() -> Self {
        Self {
            acceleration: (0.02, 0.02, 0.0),
            maximum: (0.2, 0.449, 0.001),
        }
    }
}
714
/// Fluent builder for a batch (parameter-sweep) SAR run.
#[derive(Clone, Debug, Default)]
pub struct SarBatchBuilder {
    range: SarBatchRange,
    kernel: Kernel,
}
720
impl SarBatchBuilder {
    /// Starts a builder with the default sweep and `Kernel::Auto`.
    pub fn new() -> Self {
        Self::default()
    }
    /// Overrides the batch kernel.
    pub fn kernel(mut self, k: Kernel) -> Self {
        self.kernel = k;
        self
    }

    /// Sweeps acceleration over `(start, end, step)`.
    pub fn acceleration_range(mut self, start: f64, end: f64, step: f64) -> Self {
        self.range.acceleration = (start, end, step);
        self
    }
    /// Pins acceleration to a single value.
    pub fn acceleration_static(mut self, x: f64) -> Self {
        self.range.acceleration = (x, x, 0.0);
        self
    }
    /// Sweeps maximum over `(start, end, step)`.
    pub fn maximum_range(mut self, start: f64, end: f64, step: f64) -> Self {
        self.range.maximum = (start, end, step);
        self
    }
    /// Pins maximum to a single value.
    pub fn maximum_static(mut self, x: f64) -> Self {
        self.range.maximum = (x, x, 0.0);
        self
    }

    /// Runs the configured sweep over raw high/low slices.
    pub fn apply_slices(self, high: &[f64], low: &[f64]) -> Result<SarBatchOutput, SarError> {
        sar_batch_with_kernel(high, low, &self.range, self.kernel)
    }

    /// Default sweep over slices with an explicit kernel.
    pub fn with_default_slices(
        high: &[f64],
        low: &[f64],
        k: Kernel,
    ) -> Result<SarBatchOutput, SarError> {
        SarBatchBuilder::new().kernel(k).apply_slices(high, low)
    }

    /// Runs the configured sweep over candle data.
    pub fn apply_candles(self, c: &Candles) -> Result<SarBatchOutput, SarError> {
        self.apply_slices(&c.high, &c.low)
    }

    /// Default sweep over candle data with automatic kernel selection.
    pub fn with_default_candles(c: &Candles) -> Result<SarBatchOutput, SarError> {
        SarBatchBuilder::new().kernel(Kernel::Auto).apply_candles(c)
    }
}
767
/// Runs a parameter-sweep SAR with a batch kernel.
///
/// `Kernel::Auto` picks the best available batch kernel; non-batch kernels
/// are rejected with `InvalidKernelForBatch`. The batch kernel is then
/// mapped to its per-row SIMD counterpart before dispatch.
pub fn sar_batch_with_kernel(
    high: &[f64],
    low: &[f64],
    sweep: &SarBatchRange,
    k: Kernel,
) -> Result<SarBatchOutput, SarError> {
    let kernel = match k {
        Kernel::Auto => detect_best_batch_kernel(),
        other if other.is_batch() => other,
        other => return Err(SarError::InvalidKernelForBatch(other)),
    };
    // `kernel` is guaranteed to be a batch variant here, so the catch-all
    // arm below is genuinely unreachable.
    let simd = match kernel {
        Kernel::Avx512Batch => Kernel::Avx512,
        Kernel::Avx2Batch => Kernel::Avx2,
        Kernel::ScalarBatch => Kernel::Scalar,
        _ => unreachable!(),
    };
    sar_batch_par_slice(high, low, sweep, simd)
}
787
/// Batch sweep result: `values` is a row-major `rows x cols` matrix with one
/// row per parameter combo (same order as `combos`) and one column per bar.
#[derive(Clone, Debug)]
pub struct SarBatchOutput {
    pub values: Vec<f64>,
    pub combos: Vec<SarParams>,
    pub rows: usize,
    pub cols: usize,
}
795impl SarBatchOutput {
796 pub fn row_for_params(&self, p: &SarParams) -> Option<usize> {
797 self.combos.iter().position(|c| {
798 (c.acceleration.unwrap_or(0.02) - p.acceleration.unwrap_or(0.02)).abs() < 1e-12
799 && (c.maximum.unwrap_or(0.2) - p.maximum.unwrap_or(0.2)).abs() < 1e-12
800 })
801 }
802 pub fn values_for(&self, p: &SarParams) -> Option<&[f64]> {
803 self.row_for_params(p).map(|row| {
804 let start = row * self.cols;
805 &self.values[start..start + self.cols]
806 })
807 }
808}
809
/// Expands one `(start, end, step)` sweep axis into concrete values.
///
/// Conventions:
/// - non-finite `step` -> `InvalidRange`;
/// - `|step| < 1e-12` or `start ~= end` -> the single value `start`;
/// - positive `step` walks from `start` toward `end` in either direction
///   (its magnitude is used when `start > end`);
/// - negative `step` is only accepted for descending axes (`start >= end`).
#[inline(always)]
fn axis_f64_checked(axis: (f64, f64, f64)) -> Result<Vec<f64>, SarError> {
    let (start, end, step) = axis;
    if !step.is_finite() {
        return Err(SarError::InvalidRange { start, end, step });
    }
    // Degenerate axis: a single static value.
    if step.abs() < 1e-12 || (start - end).abs() < 1e-12 {
        return Ok(vec![start]);
    }

    let mut v = Vec::new();
    // Tolerance so accumulated floating-point error doesn't drop the last step.
    let tol = step.abs() * 1e-12;

    if step > 0.0 {
        if start <= end {
            // Ascending walk.
            let mut x = start;
            while x <= end + tol {
                v.push(x);
                x += step;
            }
        } else {
            // Descending walk using the positive step's magnitude.
            let mut x = start;
            while x >= end - tol {
                v.push(x);
                x -= step;
            }
        }
    } else {
        if start >= end {
            // Negative step: descending walk (x += step moves down).
            let mut x = start;
            while x >= end - tol {
                v.push(x);
                x += step;
            }
        } else {
            // Negative step cannot reach an end above start.
            return Err(SarError::InvalidRange { start, end, step });
        }
    }

    if v.is_empty() {
        Err(SarError::InvalidRange { start, end, step })
    } else {
        Ok(v)
    }
}
855
856#[inline(always)]
857fn expand_grid(r: &SarBatchRange) -> Result<Vec<SarParams>, SarError> {
858 let accs = axis_f64_checked(r.acceleration)?;
859 let maxs = axis_f64_checked(r.maximum)?;
860
861 let capacity = accs
862 .len()
863 .checked_mul(maxs.len())
864 .ok_or(SarError::InvalidRange {
865 start: r.acceleration.0,
866 end: r.acceleration.1,
867 step: r.acceleration.2,
868 })?;
869
870 let mut out = Vec::with_capacity(capacity);
871 for &a in &accs {
872 for &m in &maxs {
873 out.push(SarParams {
874 acceleration: Some(a),
875 maximum: Some(m),
876 });
877 }
878 }
879 Ok(out)
880}
881
/// Sequential (single-threaded) batch SAR over slices.
#[inline(always)]
pub fn sar_batch_slice(
    high: &[f64],
    low: &[f64],
    sweep: &SarBatchRange,
    kern: Kernel,
) -> Result<SarBatchOutput, SarError> {
    sar_batch_inner(high, low, sweep, kern, false)
}
891
/// Parallel (rayon, one row per combo) batch SAR over slices.
/// Falls back to sequential execution on wasm.
#[inline(always)]
pub fn sar_batch_par_slice(
    high: &[f64],
    low: &[f64],
    sweep: &SarBatchRange,
    kern: Kernel,
) -> Result<SarBatchOutput, SarError> {
    sar_batch_inner(high, low, sweep, kern, true)
}
901
/// Shared batch driver: expands the sweep, allocates the rows x cols matrix,
/// and runs one SAR kernel per parameter combo (optionally in parallel).
#[inline(always)]
fn sar_batch_inner(
    high: &[f64],
    low: &[f64],
    sweep: &SarBatchRange,
    kern: Kernel,
    parallel: bool,
) -> Result<SarBatchOutput, SarError> {
    let combos = expand_grid(sweep)?;

    // Keep the series index-aligned and locate the first fully valid bar.
    let min_len = high.len().min(low.len());
    let (high, low) = (&high[..min_len], &low[..min_len]);
    let first = high
        .iter()
        .zip(low.iter())
        .position(|(&h, &l)| !h.is_nan() && !l.is_nan())
        .ok_or(SarError::AllValuesNaN)?;

    if high.len() - first < 2 {
        return Err(SarError::NotEnoughValidData {
            needed: 2,
            valid: high.len() - first,
        });
    }
    let rows = combos.len();
    let cols = high.len();

    // Uninitialized rows x cols matrix; only the NaN warm-up prefix of each
    // row is initialized up-front, the kernels write the rest.
    let mut buf_mu = make_uninit_matrix(rows, cols);

    // Every combo shares the same warm-up length (first + 1).
    let warm = vec![first + 1; rows];
    init_matrix_prefixes(&mut buf_mu, cols, &warm);

    // SAFETY: `MaybeUninit<f64>` has the same layout as `f64`. The prefix of
    // each row was just initialized, and each row's remainder is written by
    // the kernel below before `values` is read or returned.
    // NOTE(review): the tail elements are still uninitialized at this point;
    // correctness relies on every row being fully written below.
    let mut buf_guard = core::mem::ManuallyDrop::new(buf_mu);
    let values: &mut [f64] = unsafe {
        core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
    };

    // Runs the selected per-row kernel for one parameter combo.
    // NOTE(review): non-Scalar/Avx kernels hit `unreachable!()`; callers
    // reaching this via `sar_batch_with_kernel` are safe, but the public
    // `sar_batch_slice` accepts arbitrary kernels — confirm call sites.
    let do_row = |row: usize, out_row: &mut [f64]| unsafe {
        let p = &combos[row];
        match kern {
            Kernel::Scalar => sar_row_scalar(
                high,
                low,
                first,
                p.acceleration.unwrap(),
                p.maximum.unwrap(),
                out_row,
            ),
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx2 => sar_row_avx2(
                high,
                low,
                first,
                p.acceleration.unwrap(),
                p.maximum.unwrap(),
                out_row,
            ),
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx512 => sar_row_avx512(
                high,
                low,
                first,
                p.acceleration.unwrap(),
                p.maximum.unwrap(),
                out_row,
            ),
            _ => unreachable!(),
        }
    };

    if parallel {
        // Rows are independent: each combo owns a disjoint chunk.
        #[cfg(not(target_arch = "wasm32"))]
        {
            values
                .par_chunks_mut(cols)
                .enumerate()
                .for_each(|(row, slice)| do_row(row, slice));
        }

        // wasm has no rayon: degrade to sequential.
        #[cfg(target_arch = "wasm32")]
        {
            for (row, slice) in values.chunks_mut(cols).enumerate() {
                do_row(row, slice);
            }
        }
    } else {
        for (row, slice) in values.chunks_mut(cols).enumerate() {
            do_row(row, slice);
        }
    }

    // SAFETY: reassembles the Vec from the ManuallyDrop'd allocation with the
    // original pointer/len/capacity; ownership is transferred exactly once,
    // and every element has been initialized above.
    let values = unsafe {
        Vec::from_raw_parts(
            buf_guard.as_mut_ptr() as *mut f64,
            buf_guard.len(),
            buf_guard.capacity(),
        )
    };

    Ok(SarBatchOutput {
        values,
        combos,
        rows,
        cols,
    })
}
1008
/// Per-row batch wrapper around the (safe) scalar kernel.
///
/// # Safety
/// No unsafe operations of its own; declared `unsafe` only for signature
/// parity with the SIMD row wrappers.
#[inline(always)]
unsafe fn sar_row_scalar(
    high: &[f64],
    low: &[f64],
    first: usize,
    acceleration: f64,
    maximum: f64,
    out: &mut [f64],
) {
    sar_scalar(high, low, first, acceleration, maximum, out)
}
1020
/// Per-row batch wrapper around `sar_avx2`.
///
/// # Safety
/// Same contract as `sar_avx2`: `low.len() >= high.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn sar_row_avx2(
    high: &[f64],
    low: &[f64],
    first: usize,
    acceleration: f64,
    maximum: f64,
    out: &mut [f64],
) {
    sar_avx2(high, low, first, acceleration, maximum, out)
}
1033
/// Per-row batch wrapper for AVX-512; forwards to `sar_avx2`.
///
/// # Safety
/// Same contract as `sar_avx2`: `low.len() >= high.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn sar_row_avx512(
    high: &[f64],
    low: &[f64],
    first: usize,
    acceleration: f64,
    maximum: f64,
    out: &mut [f64],
) {
    sar_avx2(high, low, first, acceleration, maximum, out)
}
1046
/// Per-row AVX-512 "short input" variant; forwards to `sar_avx2`.
///
/// # Safety
/// Same contract as `sar_avx2`: `low.len() >= high.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn sar_row_avx512_short(
    high: &[f64],
    low: &[f64],
    first: usize,
    acceleration: f64,
    maximum: f64,
    out: &mut [f64],
) {
    sar_avx2(high, low, first, acceleration, maximum, out)
}
1059
/// Per-row AVX-512 "long input" variant; forwards to `sar_avx2`.
///
/// # Safety
/// Same contract as `sar_avx2`: `low.len() >= high.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn sar_row_avx512_long(
    high: &[f64],
    low: &[f64],
    first: usize,
    acceleration: f64,
    maximum: f64,
    out: &mut [f64],
) {
    sar_avx2(high, low, first, acceleration, maximum, out)
}
1072
/// Test-support wrapper around `expand_grid` that panics on error.
// NOTE(review): not gated behind #[cfg(test)]; confirm no non-test callers
// rely on it before adding the gate.
#[inline(always)]
fn expand_grid_for_test(r: &SarBatchRange) -> Vec<SarParams> {
    expand_grid(r).expect("expand_grid_for_test should not fail")
}
1077
/// Python binding: computes SAR over two 1-D float64 arrays.
///
/// Inputs are trimmed to the shorter length; the GIL is released while the
/// Rust kernel runs, and errors surface as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "sar")]
#[pyo3(signature = (high, low, acceleration=None, maximum=None, kernel=None))]
pub fn sar_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    acceleration: Option<f64>,
    maximum: Option<f64>,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let high_slice = high.as_slice()?;
    let low_slice = low.as_slice()?;

    // Keep the two arrays index-aligned.
    let min_len = high_slice.len().min(low_slice.len());
    let high_trimmed = &high_slice[..min_len];
    let low_trimmed = &low_slice[..min_len];

    // `false`: a non-batch kernel is expected here.
    let kern = validate_kernel(kernel, false)?;

    let params = SarParams {
        acceleration,
        maximum,
    };
    let input = SarInput::from_slices(high_trimmed, low_trimmed, params);

    // Release the GIL for the duration of the computation.
    let result_vec: Vec<f64> = py
        .allow_threads(|| sar_with_kernel(&input, kern).map(|o| o.values))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok(result_vec.into_pyarray(py))
}
1110
/// Python wrapper around the incremental `SarStream`.
#[cfg(feature = "python")]
#[pyclass(name = "SarStream")]
pub struct SarStreamPy {
    stream: SarStream,
}
1116
#[cfg(feature = "python")]
#[pymethods]
impl SarStreamPy {
    /// Creates a stream; `None` parameters use the 0.02 / 0.2 defaults.
    /// Invalid parameters raise `ValueError`.
    #[new]
    #[pyo3(signature = (acceleration=None, maximum=None))]
    fn new(acceleration: Option<f64>, maximum: Option<f64>) -> PyResult<Self> {
        let params = SarParams {
            acceleration,
            maximum,
        };
        let stream =
            SarStream::try_new(params).map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok(SarStreamPy { stream })
    }

    /// Feeds one bar; returns the SAR value, or `None` during warm-up /
    /// for non-finite inputs.
    fn update(&mut self, high: f64, low: f64) -> Option<f64> {
        self.stream.update(high, low)
    }
}
1136
/// Python binding: parameter-sweep SAR.
///
/// Returns a dict with a `(rows, cols)` `values` matrix plus the flattened
/// `accelerations` / `maximums` axes (one entry per row, same order).
#[cfg(feature = "python")]
#[pyfunction(name = "sar_batch")]
#[pyo3(signature = (high, low, acceleration_range, maximum_range, kernel=None))]
pub fn sar_batch_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    acceleration_range: (f64, f64, f64),
    maximum_range: (f64, f64, f64),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let high_slice = high.as_slice()?;
    let low_slice = low.as_slice()?;

    // Keep the two arrays index-aligned.
    let min_len = high_slice.len().min(low_slice.len());
    let (high_slice, low_slice) = (&high_slice[..min_len], &low_slice[..min_len]);

    let sweep = SarBatchRange {
        acceleration: acceleration_range,
        maximum: maximum_range,
    };
    // Expand up-front to size the output buffer before computing.
    let combos = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let rows = combos.len();
    let cols = min_len;
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("sar_batch_py: size overflow in rows*cols"))?;

    // SAFETY: the array is created uninitialized and every element is written
    // by `sar_batch_inner_into_noalloc` (NaN prefixes + kernel rows) before
    // the array is handed back to Python.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    // `true`: a batch kernel is expected here.
    let kern = validate_kernel(kernel, true)?;
    // Release the GIL while the sweep runs.
    py.allow_threads(|| {
        let k = match kern {
            Kernel::Auto => detect_best_batch_kernel(),
            k => k,
        };
        sar_batch_inner_into_noalloc(high_slice, low_slice, &sweep, k, true, slice_out)
    })
    .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "accelerations",
        combos
            .iter()
            .map(|p| p.acceleration.unwrap_or(0.02))
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "maximums",
        combos
            .iter()
            .map(|p| p.maximum.unwrap_or(0.2))
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;

    Ok(dict)
}
1199
/// Batch driver writing into a caller-provided (NumPy-backed) buffer, so no
/// Rust-side matrix allocation is needed. Returns the expanded combos.
///
/// `out.len()` must equal `rows * cols` for the expanded sweep; callers
/// (see `sar_batch_py`) size the buffer from the same `expand_grid` result.
#[cfg(feature = "python")]
fn sar_batch_inner_into_noalloc(
    high: &[f64],
    low: &[f64],
    sweep: &SarBatchRange,
    kern: Kernel,
    parallel: bool,
    out: &mut [f64],
) -> Result<Vec<SarParams>, SarError> {
    let combos = expand_grid(sweep)?;

    // First index at which BOTH series hold a real (non-NaN) value.
    // NOTE(review): unlike `sar_batch_inner`, the slices are not re-trimmed
    // here — the caller is expected to pass equal-length slices.
    let first = high
        .iter()
        .zip(low.iter())
        .position(|(&h, &l)| !h.is_nan() && !l.is_nan())
        .ok_or(SarError::AllValuesNaN)?;
    if high.len() - first < 2 {
        return Err(SarError::NotEnoughValidData {
            needed: 2,
            valid: high.len() - first,
        });
    }

    let rows = combos.len();
    let cols = high.len();
    let expected = rows.checked_mul(cols).ok_or(SarError::InvalidRange {
        start: sweep.acceleration.0,
        end: sweep.acceleration.1,
        step: sweep.acceleration.2,
    })?;
    if out.len() != expected {
        return Err(SarError::OutputLengthMismatch {
            expected,
            got: out.len(),
        });
    }

    // SAFETY: `MaybeUninit<f64>` has the same layout as `f64`; the view is
    // only used to write the NaN warm-up prefix of each row.
    unsafe {
        let out_mu = std::slice::from_raw_parts_mut(
            out.as_mut_ptr() as *mut std::mem::MaybeUninit<f64>,
            out.len(),
        );
        let warm: Vec<usize> = vec![first + 1; rows];
        init_matrix_prefixes(out_mu, cols, &warm);
    }

    let simd = match kern {
        Kernel::Auto => detect_best_batch_kernel(),
        k => k,
    };
    // Runs the selected per-row kernel for one combo; unknown kernels fall
    // back to the scalar row.
    let exec = |row: usize, row_out: &mut [f64]| {
        let p = &combos[row];
        match simd {
            Kernel::Scalar | Kernel::ScalarBatch => unsafe {
                sar_row_scalar(
                    high,
                    low,
                    first,
                    p.acceleration.unwrap(),
                    p.maximum.unwrap(),
                    row_out,
                );
            },
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx2 | Kernel::Avx2Batch => unsafe {
                sar_row_avx2(
                    high,
                    low,
                    first,
                    p.acceleration.unwrap(),
                    p.maximum.unwrap(),
                    row_out,
                );
            },
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx512 | Kernel::Avx512Batch => unsafe {
                sar_row_avx512(
                    high,
                    low,
                    first,
                    p.acceleration.unwrap(),
                    p.maximum.unwrap(),
                    row_out,
                );
            },
            _ => unsafe {
                sar_row_scalar(
                    high,
                    low,
                    first,
                    p.acceleration.unwrap(),
                    p.maximum.unwrap(),
                    row_out,
                );
            },
        }
    };

    if parallel {
        // Rows are independent: each combo owns a disjoint chunk.
        #[cfg(not(target_arch = "wasm32"))]
        {
            use rayon::prelude::*;
            out.par_chunks_mut(cols)
                .enumerate()
                .for_each(|(row, r)| exec(row, r));
        }
        // wasm has no rayon: degrade to sequential.
        #[cfg(target_arch = "wasm32")]
        {
            for (row, r) in out.chunks_mut(cols).enumerate() {
                exec(row, r);
            }
        }
    } else {
        for (row, r) in out.chunks_mut(cols).enumerate() {
            exec(row, r);
        }
    }

    Ok(combos)
}
1320
/// wasm entry point: compute SAR for one (acceleration, maximum) pair and
/// return the values as a fresh `Vec<f64>`.
///
/// The two series are trimmed to the shorter length so they stay aligned.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn sar_js(
    high: &[f64],
    low: &[f64],
    acceleration: f64,
    maximum: f64,
) -> Result<Vec<f64>, JsValue> {
    let n = high.len().min(low.len());
    let input = SarInput::from_slices(
        &high[..n],
        &low[..n],
        SarParams {
            acceleration: Some(acceleration),
            maximum: Some(maximum),
        },
    );

    let mut values = vec![0.0; n];
    if let Err(e) = sar_into_slice(&mut values, &input, Kernel::Auto) {
        return Err(JsValue::from_str(&e.to_string()));
    }
    Ok(values)
}
1345
/// CUDA device-resident f32 matrix handed to Python, exposing both the CUDA
/// Array Interface and DLPack so consumers can take it zero-copy.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct SarDeviceArrayF32Py {
    // Becomes `None` once ownership is handed off via `__dlpack__`.
    pub(crate) buf: Option<DeviceBuffer<f32>>,
    pub(crate) rows: usize,
    pub(crate) cols: usize,
    // Held only to keep the CUDA context alive as long as the buffer.
    pub(crate) _ctx: Arc<Context>,
    pub(crate) device_id: u32,
}
1355
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl SarDeviceArrayF32Py {
    /// CUDA Array Interface dict (`shape`, `typestr`, `strides`, `data`,
    /// `version`) describing the row-major f32 matrix on the device.
    ///
    /// Errors if the buffer has already been exported via `__dlpack__`.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let d = PyDict::new(py);
        d.set_item("shape", (self.rows, self.cols))?;
        // "<f4": little-endian 32-bit float.
        d.set_item("typestr", "<f4")?;
        // Row-major strides, in bytes.
        d.set_item(
            "strides",
            (
                self.cols * std::mem::size_of::<f32>(),
                std::mem::size_of::<f32>(),
            ),
        )?;
        let buf = self
            .buf
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("buffer already exported via __dlpack__"))?;
        let ptr = buf.as_device_ptr().as_raw() as usize;
        // (device pointer, read_only = false).
        d.set_item("data", (ptr, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    /// DLPack device tuple: (2, ordinal) — 2 is the DLPack CUDA device type.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.device_id as i32)
    }

    /// One-shot DLPack export: ownership of the device buffer moves into the
    /// returned capsule, so a second call (or a later CAI access) fails.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        let (kdl, alloc_dev) = self.__dlpack_device__();
        // An explicit dl_device request is honoured only when it matches
        // where the data already lives; cross-device copies are unsupported.
        if let Some(dev_obj) = dl_device.as_ref() {
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(PyValueError::new_err(
                            "cross-device DLPack copy is not implemented for SarDeviceArrayF32Py",
                        ));
                    } else {
                        return Err(PyValueError::new_err(
                            "requested dl_device does not match SAR producer device",
                        ));
                    }
                }
            }
        }
        // NOTE(review): `stream` is accepted but ignored — no stream
        // synchronisation happens here; confirm producers synchronise
        // before export.
        let _ = stream;

        let buf = self
            .buf
            .take()
            .ok_or_else(|| PyValueError::new_err("__dlpack__ may only be called once"))?;

        let rows = self.rows;
        let cols = self.cols;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
1430
/// Python entry point: run a full (acceleration × maximum) SAR sweep on the
/// GPU and return the result as a device-resident f32 matrix.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "sar_cuda_batch_dev")]
#[pyo3(signature = (high_f32, low_f32, acceleration_range, maximum_range, device_id=0))]
pub fn sar_cuda_batch_dev_py(
    py: Python<'_>,
    high_f32: numpy::PyReadonlyArray1<'_, f32>,
    low_f32: numpy::PyReadonlyArray1<'_, f32>,
    acceleration_range: (f64, f64, f64),
    maximum_range: (f64, f64, f64),
    device_id: usize,
) -> PyResult<SarDeviceArrayF32Py> {
    use crate::cuda::{cuda_available, CudaSar};

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let highs = high_f32.as_slice()?;
    let lows = low_f32.as_slice()?;
    if highs.len() != lows.len() {
        return Err(PyValueError::new_err("high/low length mismatch"));
    }

    let sweep = SarBatchRange {
        acceleration: acceleration_range,
        maximum: maximum_range,
    };

    // Release the GIL while the kernel runs on the device.
    let (buf, rows, cols, ctx, dev_id) = py.allow_threads(|| {
        let cuda = CudaSar::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let (dev, _combos) = cuda
            .sar_batch_dev(highs, lows, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let id = cuda.device_id();
        Ok::<_, pyo3::PyErr>((dev.buf, dev.rows, dev.cols, ctx, id))
    })?;

    Ok(SarDeviceArrayF32Py {
        buf: Some(buf),
        rows,
        cols,
        _ctx: ctx,
        device_id: dev_id,
    })
}
1473
/// Python entry point: run one (acceleration, maximum) pair across many
/// series stored time-major (`cols` series × `rows` time steps, flattened),
/// keeping the result on the CUDA device.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "sar_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm_f32, low_tm_f32, cols, rows, acceleration, maximum, device_id=0))]
pub fn sar_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    high_tm_f32: numpy::PyReadonlyArray1<'_, f32>,
    low_tm_f32: numpy::PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    acceleration: f64,
    maximum: f64,
    device_id: usize,
) -> PyResult<SarDeviceArrayF32Py> {
    use crate::cuda::cuda_available;
    use crate::cuda::CudaSar;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let h = high_tm_f32.as_slice()?;
    let l = low_tm_f32.as_slice()?;
    let expected = cols.checked_mul(rows).ok_or_else(|| {
        PyValueError::new_err("sar_cuda_many_series_one_param_dev: size overflow in cols*rows")
    })?;
    if expected != h.len() || h.len() != l.len() {
        // Fixed: the previous message contained a non-ASCII U+2011
        // (non-breaking hyphen) in "time-major".
        return Err(PyValueError::new_err(
            "time-major inputs must be equal length and cols*rows",
        ));
    }
    let params = SarParams {
        acceleration: Some(acceleration),
        maximum: Some(maximum),
    };
    // Release the GIL while the kernel runs on the device.
    let (buf, r_out, c_out, ctx, dev_id) = py.allow_threads(|| {
        let cuda = CudaSar::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let dev = cuda
            .sar_many_series_one_param_time_major_dev(h, l, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        Ok::<_, pyo3::PyErr>((dev.buf, dev.rows, dev.cols, ctx, cuda.device_id()))
    })?;
    Ok(SarDeviceArrayF32Py {
        buf: Some(buf),
        rows: r_out,
        cols: c_out,
        _ctx: ctx,
        device_id: dev_id,
    })
}
1523
/// Computes SAR into a caller-owned buffer through raw pointers (wasm FFI).
///
/// All three pointers must address at least `len` `f64`s valid for the whole
/// call. `out_ptr` may alias one of the inputs; that case is detected and
/// routed through a temporary buffer.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn sar_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    acceleration: f64,
    maximum: f64,
) -> Result<(), JsValue> {
    if high_ptr.is_null() || low_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to sar_into"));
    }

    // SAFETY: the JS caller guarantees each pointer addresses `len` readable
    // (and, for `out_ptr`, writable) f64 elements.
    unsafe {
        let high_slice = std::slice::from_raw_parts(high_ptr, len);
        let low_slice = std::slice::from_raw_parts(low_ptr, len);

        let params = SarParams {
            acceleration: Some(acceleration),
            maximum: Some(maximum),
        };
        let input = SarInput::from_slices(high_slice, low_slice, params);

        if high_ptr == out_ptr || low_ptr == out_ptr {
            // Output aliases an input: compute into a scratch buffer first so
            // the kernel never reads data it has already overwritten.
            let mut temp = vec![0.0; len];
            sar_into_slice(&mut temp, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            out.copy_from_slice(&temp);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            sar_into_slice(out, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }

        Ok(())
    }
}
1563
/// Allocates a zero-initialised `len`-element f64 buffer for the JS side.
/// Release it with [`sar_free`] using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn sar_alloc(len: usize) -> *mut f64 {
    // Fixes two problems with the previous `Vec::with_capacity` + `forget`
    // version: (1) it handed uninitialised memory to JS; (2) `with_capacity`
    // does not guarantee capacity == len, while `sar_free` rebuilds the
    // allocation with `Vec::from_raw_parts(ptr, len, len)`, which requires an
    // exact capacity match. A boxed slice is guaranteed to be exactly
    // `len * size_of::<f64>()` bytes.
    let boxed: Box<[f64]> = vec![0.0f64; len].into_boxed_slice();
    Box::into_raw(boxed) as *mut f64
}
1572
/// Frees a buffer previously returned by [`sar_alloc`].
///
/// `len` must equal the length passed to `sar_alloc`; any other value (or a
/// pointer from a different allocator) is undefined behaviour.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn sar_free(ptr: *mut f64, len: usize) {
    if !ptr.is_null() {
        // SAFETY: caller guarantees `ptr` came from `sar_alloc(len)`, so
        // rebuilding a Vec with length == capacity == len and dropping it
        // releases that allocation exactly once.
        unsafe {
            let _ = Vec::from_raw_parts(ptr, len, len);
        }
    }
}
1582
/// Runs a SAR parameter sweep into a caller-owned flat buffer (wasm FFI).
///
/// `out_ptr` must address at least `rows * len` writable `f64`s, where `rows`
/// is the number of (acceleration, maximum) combinations produced by the two
/// ranges. Returns that row count on success.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn sar_batch_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    acc_start: f64,
    acc_end: f64,
    acc_step: f64,
    max_start: f64,
    max_end: f64,
    max_step: f64,
) -> Result<usize, JsValue> {
    if high_ptr.is_null() || low_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to sar_batch_into"));
    }
    // SAFETY: the JS caller guarantees `high_ptr`/`low_ptr` address `len`
    // readable f64s and `out_ptr` addresses `rows * len` writable f64s. The
    // grid is expanded first so `total` is known before `out` is formed.
    unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let sweep = SarBatchRange {
            acceleration: (acc_start, acc_end, acc_step),
            maximum: (max_start, max_end, max_step),
        };
        let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
        let rows = combos.len();
        let cols = len;
        let total = rows
            .checked_mul(cols)
            .ok_or_else(|| JsValue::from_str("sar_batch_into: size overflow in rows*cols"))?;

        let out = std::slice::from_raw_parts_mut(out_ptr, total);

        sar_batch_inner_into_noalloc_wasm(high, low, &sweep, Kernel::Scalar, false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        Ok(rows)
    }
}
1621
/// wasm-only variant of `sar_batch_inner_into_noalloc`: rows always run
/// serially because rayon is unavailable on wasm32.
///
/// `out` must hold exactly `combos.len() * min(high.len(), low.len())`
/// elements, row-major. Returns the expanded parameter grid so callers can
/// map rows back to their parameters.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
fn sar_batch_inner_into_noalloc_wasm(
    high: &[f64],
    low: &[f64],
    sweep: &SarBatchRange,
    _kern: Kernel,
    parallel: bool,
    out: &mut [f64],
) -> Result<Vec<SarParams>, SarError> {
    let combos = expand_grid(sweep)?;

    // Trim to the shorter series so the two stay aligned.
    let min_len = high.len().min(low.len());
    let (high, low) = (&high[..min_len], &low[..min_len]);

    // First bar where both series are valid; SAR needs two bars from there.
    let first = high
        .iter()
        .zip(low.iter())
        .position(|(&h, &l)| !h.is_nan() && !l.is_nan())
        .ok_or(SarError::AllValuesNaN)?;
    if high.len() - first < 2 {
        return Err(SarError::NotEnoughValidData {
            needed: 2,
            valid: high.len() - first,
        });
    }

    let rows = combos.len();
    let cols = high.len();
    let expected = rows.checked_mul(cols).ok_or(SarError::InvalidRange {
        start: sweep.acceleration.0,
        end: sweep.acceleration.1,
        step: sweep.acceleration.2,
    })?;
    if out.len() != expected {
        return Err(SarError::OutputLengthMismatch {
            expected,
            got: out.len(),
        });
    }

    // SAFETY: `out` is a valid, initialised `&mut [f64]`; viewing it as
    // `MaybeUninit<f64>` is sound, and `init_matrix_prefixes` writes the
    // `first + 1`-element NaN warmup prefix of each row.
    unsafe {
        let out_mu = std::slice::from_raw_parts_mut(
            out.as_mut_ptr() as *mut std::mem::MaybeUninit<f64>,
            out.len(),
        );
        let warm: Vec<usize> = vec![first + 1; rows];
        init_matrix_prefixes(out_mu, cols, &warm);
    }

    // This function only compiles on wasm32, so the old
    // `#[cfg(not(target_arch = "wasm32"))]` rayon branch was dead code and
    // both reachable branches ran the same serial loop; `parallel` is
    // therefore deliberately ignored.
    let _ = parallel;
    for (row, row_out) in out.chunks_mut(cols).enumerate() {
        let p = &combos[row];
        unsafe {
            sar_row_scalar(
                high,
                low,
                first,
                p.acceleration.unwrap(),
                p.maximum.unwrap(),
                row_out,
            );
        }
    }

    Ok(combos)
}
1705
/// JS-facing sweep configuration: each range is (start, end, step).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct SarBatchConfig {
    pub acceleration_range: (f64, f64, f64),
    pub maximum_range: (f64, f64, f64),
}
1712
/// JS-facing sweep result: `values` is a row-major `rows x cols` matrix with
/// one row per parameter combination in `combos`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct SarBatchJsOutput {
    pub values: Vec<f64>,
    pub combos: Vec<SarParams>,
    pub rows: usize,
    pub cols: usize,
}
1721
/// wasm entry point for a SAR parameter sweep driven by a JSON config object.
///
/// Deserialises a [`SarBatchConfig`], runs the sweep, and serialises a
/// [`SarBatchJsOutput`] back to JS.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = sar_batch)]
pub fn sar_batch_unified_js(
    high: &[f64],
    low: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let config: SarBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = SarBatchRange {
        acceleration: config.acceleration_range,
        maximum: config.maximum_range,
    };

    // This function only compiles on wasm32, so the previous
    // `cfg!(target_arch = "wasm32")` runtime check was always true and its
    // `detect_best_batch_kernel()` branch was dead code: the scalar kernel is
    // the only option here.
    let output = sar_batch_inner(high, low, &sweep, Kernel::Scalar, false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let js_output = SarBatchJsOutput {
        values: output.values,
        combos: output.combos,
        rows: output.rows,
        cols: output.cols,
    };

    serde_wasm_bindgen::to_value(&js_output).map_err(|e| JsValue::from_str(&e.to_string()))
}
1754
1755#[cfg(test)]
1756mod tests {
1757 use super::*;
1758 use crate::skip_if_unsupported;
1759 use crate::utilities::data_loader::read_candles_from_csv;
1760
1761 fn check_sar_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1762 skip_if_unsupported!(kernel, test_name);
1763 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1764 let candles = read_candles_from_csv(file_path)?;
1765
1766 let default_params = SarParams {
1767 acceleration: None,
1768 maximum: None,
1769 };
1770 let input = SarInput::from_candles(&candles, default_params);
1771 let output = sar_with_kernel(&input, kernel)?;
1772 assert_eq!(output.values.len(), candles.close.len());
1773
1774 Ok(())
1775 }
1776
1777 fn check_sar_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1778 skip_if_unsupported!(kernel, test_name);
1779 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1780 let candles = read_candles_from_csv(file_path)?;
1781
1782 let input = SarInput::from_candles(&candles, SarParams::default());
1783 let result = sar_with_kernel(&input, kernel)?;
1784 let expected_last_five = [
1785 60370.00224209362,
1786 60220.362107568006,
1787 60079.70038111392,
1788 59947.478358247085,
1789 59823.189656752256,
1790 ];
1791 let start = result.values.len().saturating_sub(5);
1792 for (i, &val) in result.values[start..].iter().enumerate() {
1793 let diff = (val - expected_last_five[i]).abs();
1794 assert!(
1795 diff < 1e-4,
1796 "[{}] SAR {:?} mismatch at idx {}: got {}, expected {}",
1797 test_name,
1798 kernel,
1799 i,
1800 val,
1801 expected_last_five[i]
1802 );
1803 }
1804 Ok(())
1805 }
1806
1807 fn check_sar_from_slices(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1808 skip_if_unsupported!(kernel, test_name);
1809 let high = [50000.0, 50500.0, 51000.0];
1810 let low = [49000.0, 49500.0, 49900.0];
1811 let params = SarParams::default();
1812 let input = SarInput::from_slices(&high, &low, params);
1813 let result = sar_with_kernel(&input, kernel)?;
1814 assert_eq!(result.values.len(), high.len());
1815 Ok(())
1816 }
1817
1818 #[test]
1819 fn test_sar_into_slice_trims_mismatched_lengths() -> Result<(), Box<dyn Error>> {
1820 let high = [50000.0, 50500.0, 51000.0, 52000.0];
1821 let low = [49000.0, 49500.0, 49900.0];
1822 let params = SarParams::default();
1823 let input = SarInput::from_slices(&high, &low, params);
1824
1825 let expected = sar_with_kernel(&input, Kernel::Scalar)?;
1826 assert_eq!(expected.values.len(), low.len());
1827
1828 let mut out = vec![0.0; low.len()];
1829 sar_into_slice(&mut out, &input, Kernel::Scalar)?;
1830
1831 for i in 0..out.len() {
1832 let a = out[i];
1833 let b = expected.values[i];
1834 assert!(
1835 (a.is_nan() && b.is_nan()) || a == b,
1836 "mismatch at {}: {} vs {}",
1837 i,
1838 a,
1839 b
1840 );
1841 }
1842
1843 Ok(())
1844 }
1845
1846 fn check_sar_all_nan(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1847 skip_if_unsupported!(kernel, test_name);
1848 let high = [f64::NAN, f64::NAN, f64::NAN];
1849 let low = [f64::NAN, f64::NAN, f64::NAN];
1850 let params = SarParams::default();
1851 let input = SarInput::from_slices(&high, &low, params);
1852 let result = sar_with_kernel(&input, kernel);
1853 assert!(result.is_err());
1854 if let Err(e) = result {
1855 assert!(e.to_string().contains("All values are NaN"));
1856 }
1857 Ok(())
1858 }
1859
1860 #[cfg(debug_assertions)]
1861 fn check_sar_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1862 skip_if_unsupported!(kernel, test_name);
1863
1864 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1865 let candles = read_candles_from_csv(file_path)?;
1866
1867 let test_params = vec![
1868 SarParams::default(),
1869 SarParams {
1870 acceleration: Some(0.001),
1871 maximum: Some(0.001),
1872 },
1873 SarParams {
1874 acceleration: Some(0.01),
1875 maximum: Some(0.1),
1876 },
1877 SarParams {
1878 acceleration: Some(0.02),
1879 maximum: Some(0.3),
1880 },
1881 SarParams {
1882 acceleration: Some(0.05),
1883 maximum: Some(0.2),
1884 },
1885 SarParams {
1886 acceleration: Some(0.05),
1887 maximum: Some(0.5),
1888 },
1889 SarParams {
1890 acceleration: Some(0.1),
1891 maximum: Some(0.5),
1892 },
1893 SarParams {
1894 acceleration: Some(0.1),
1895 maximum: Some(0.9),
1896 },
1897 SarParams {
1898 acceleration: Some(0.2),
1899 maximum: Some(0.9),
1900 },
1901 SarParams {
1902 acceleration: Some(0.001),
1903 maximum: Some(0.9),
1904 },
1905 SarParams {
1906 acceleration: Some(0.2),
1907 maximum: Some(0.01),
1908 },
1909 ];
1910
1911 for (param_idx, params) in test_params.iter().enumerate() {
1912 let input = SarInput::from_candles(&candles, params.clone());
1913 let output = sar_with_kernel(&input, kernel)?;
1914
1915 for (i, &val) in output.values.iter().enumerate() {
1916 if val.is_nan() {
1917 continue;
1918 }
1919
1920 let bits = val.to_bits();
1921
1922 if bits == 0x11111111_11111111 {
1923 panic!(
1924 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1925 with params: acceleration={}, maximum={} (param set {})",
1926 test_name,
1927 val,
1928 bits,
1929 i,
1930 params.acceleration.unwrap_or(0.02),
1931 params.maximum.unwrap_or(0.2),
1932 param_idx
1933 );
1934 }
1935
1936 if bits == 0x22222222_22222222 {
1937 panic!(
1938 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1939 with params: acceleration={}, maximum={} (param set {})",
1940 test_name,
1941 val,
1942 bits,
1943 i,
1944 params.acceleration.unwrap_or(0.02),
1945 params.maximum.unwrap_or(0.2),
1946 param_idx
1947 );
1948 }
1949
1950 if bits == 0x33333333_33333333 {
1951 panic!(
1952 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1953 with params: acceleration={}, maximum={} (param set {})",
1954 test_name,
1955 val,
1956 bits,
1957 i,
1958 params.acceleration.unwrap_or(0.02),
1959 params.maximum.unwrap_or(0.2),
1960 param_idx
1961 );
1962 }
1963 }
1964 }
1965
1966 Ok(())
1967 }
1968
    /// Release builds skip the poison scan: the poison bit patterns are only
    /// written under `debug_assertions`.
    #[cfg(not(debug_assertions))]
    fn check_sar_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
1973
    /// Property-based cross-kernel check: generates synthetic high/low bands
    /// and verifies (1) range bounds, (2) warmup NaN, (3) exact agreement
    /// with the scalar reference kernel, and (4) qualitative SAR behaviour
    /// under trends, jumps, and debug poison patterns.
    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_sar_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        // Strategy: acceleration in (0.001, 0.5), maximum >= acceleration,
        // 10..400 positive finite prices, plus a volatility factor.
        let strat = (0.001f64..0.5f64)
            .prop_flat_map(|acceleration| (Just(acceleration), acceleration..1.0f64))
            .prop_flat_map(|(acceleration, maximum)| {
                (
                    prop::collection::vec(
                        (1.0f64..1e6f64).prop_filter("finite price", |x| x.is_finite() && *x > 0.0),
                        10..400,
                    ),
                    Just(acceleration),
                    Just(maximum),
                    0.001f64..0.1f64,
                )
            });

        proptest::test_runner::TestRunner::default().run(
            &strat,
            |(base_prices, acceleration, maximum, volatility)| {
                // Build a high/low band around each base price with a slowly
                // wandering spread factor clamped to [0.5, 2.0].
                let mut high = Vec::with_capacity(base_prices.len());
                let mut low = Vec::with_capacity(base_prices.len());

                let mut spread_factor = 1.0;
                for price in &base_prices {
                    spread_factor = (spread_factor + (price % 0.1 - 0.05) * 0.2)
                        .max(0.5)
                        .min(2.0);
                    let spread = price * volatility * spread_factor;
                    high.push(price + spread);
                    low.push(price - spread);
                }

                let params = SarParams {
                    acceleration: Some(acceleration),
                    maximum: Some(maximum),
                };
                let input = SarInput::from_slices(&high, &low, params.clone());

                let SarOutput { values: out } = sar_with_kernel(&input, kernel).unwrap();

                // The scalar kernel is the reference implementation.
                let SarOutput { values: ref_out } =
                    sar_with_kernel(&input, Kernel::Scalar).unwrap();

                // 1) Every non-NaN SAR value lies within the price range.
                for i in 1..out.len() {
                    if !out[i].is_nan() {
                        let min_price = low.iter().cloned().fold(f64::INFINITY, f64::min);
                        let max_price = high.iter().cloned().fold(f64::NEG_INFINITY, f64::max);

                        prop_assert!(
                            out[i] >= min_price - 1e-9 && out[i] <= max_price + 1e-9,
                            "SAR[{}] = {} is outside range [{}, {}]",
                            i,
                            out[i],
                            min_price,
                            max_price
                        );
                    }
                }

                // 2) The first bar is always warmup.
                prop_assert!(
                    out[0].is_nan(),
                    "First SAR value should be NaN during warmup, got {}",
                    out[0]
                );

                // 3) Exact value and NaN-pattern agreement with scalar.
                for i in 0..out.len() {
                    let y = out[i];
                    let r = ref_out[i];

                    if y.is_nan() {
                        prop_assert!(
                            r.is_nan(),
                            "NaN mismatch at index {}: test={}, ref={}",
                            i,
                            y,
                            r
                        );
                    } else if r.is_nan() {
                        prop_assert!(
                            y.is_nan(),
                            "NaN mismatch at index {}: test={}, ref={}",
                            i,
                            y,
                            r
                        );
                    } else {
                        let diff = (y - r).abs();
                        prop_assert!(
                            diff < 1e-9,
                            "Kernel mismatch at index {}: test={}, ref={}, diff={}",
                            i,
                            y,
                            r,
                            diff
                        );
                    }
                }

                // 4) The acceleration factor should make step sizes grow at
                // least once early in the series.
                if out.len() > 10 {
                    let mut last_movement = 0.0;
                    let mut increasing_count = 0;

                    for i in 2..out.len().min(20) {
                        if !out[i].is_nan() && !out[i - 1].is_nan() {
                            let movement = (out[i] - out[i - 1]).abs();
                            if movement > last_movement {
                                increasing_count += 1;
                            }
                            last_movement = movement;
                        }
                    }

                    prop_assert!(
                        increasing_count > 0 || out.len() < 5,
                        "SAR acceleration never increases (count: {})",
                        increasing_count
                    );
                }

                // 5) In a strictly rising market, late SAR sits below price.
                let strong_uptrend = high.windows(2).all(|w| w[1] > w[0] + 1e-9)
                    && low.windows(2).all(|w| w[1] > w[0] + 1e-9);
                if strong_uptrend && out.len() > 10 {
                    let start = out.len() * 3 / 4;
                    for i in start..out.len() {
                        if !out[i].is_nan() {
                            prop_assert!(
                                out[i] <= low[i] + 1e-6,
                                "In uptrend, SAR[{}] = {} should be <= low[{}] = {}",
                                i,
                                out[i],
                                i,
                                low[i]
                            );
                        }
                    }
                }

                // 6) Mirror check for a strictly falling market.
                let strong_downtrend = high.windows(2).all(|w| w[1] < w[0] - 1e-9)
                    && low.windows(2).all(|w| w[1] < w[0] - 1e-9);
                if strong_downtrend && out.len() > 10 {
                    let start = out.len() * 3 / 4;
                    for i in start..out.len() {
                        if !out[i].is_nan() {
                            prop_assert!(
                                out[i] >= high[i] - 1e-6,
                                "In downtrend, SAR[{}] = {} should be >= high[{}] = {}",
                                i,
                                out[i],
                                i,
                                high[i]
                            );
                        }
                    }
                }

                // 7) Large jumps should coincide with a side flip (reversal).
                if out.len() > 5 {
                    for i in 2..out.len() {
                        if !out[i].is_nan() && !out[i - 1].is_nan() {
                            let jump = (out[i] - out[i - 1]).abs();
                            let avg_price = (high[i] + low[i]) / 2.0;

                            if jump > avg_price * 0.05 {
                                let prev_below = out[i - 1] < low[i - 1];
                                let curr_below = out[i] < low[i];

                                prop_assert!(
                                    prev_below != curr_below || jump > avg_price * 0.03,
                                    "Large SAR jump without proper reversal at index {}",
                                    i
                                );
                            }
                        }
                    }
                }

                // 8) Monotonically rising inputs should pull the SAR average
                // upward from the second quarter to the last quarter.
                if base_prices.windows(2).all(|w| w[1] > w[0]) && out.len() > 20 {
                    let quarter = out.len() / 4;
                    let first_quarter: Vec<f64> = out[quarter..quarter * 2]
                        .iter()
                        .filter(|v| !v.is_nan())
                        .cloned()
                        .collect();
                    let last_quarter: Vec<f64> = out[quarter * 3..]
                        .iter()
                        .filter(|v| !v.is_nan())
                        .cloned()
                        .collect();

                    if !first_quarter.is_empty() && !last_quarter.is_empty() {
                        let first_avg =
                            first_quarter.iter().sum::<f64>() / first_quarter.len() as f64;
                        let last_avg = last_quarter.iter().sum::<f64>() / last_quarter.len() as f64;

                        prop_assert!(
                            last_avg >= first_avg * 0.95,
                            "For increasing prices, SAR should generally trend up: first_avg={}, last_avg={}",
                            first_avg, last_avg
                        );
                    }
                }

                // 9) Debug builds: no poison bit patterns may leak out.
                #[cfg(debug_assertions)]
                {
                    for (i, &val) in out.iter().enumerate() {
                        if !val.is_nan() {
                            let bits = val.to_bits();
                            prop_assert!(
                                bits != 0x11111111_11111111
                                    && bits != 0x22222222_22222222
                                    && bits != 0x33333333_33333333,
                                "Found poison value at index {}: {} (0x{:016X})",
                                i,
                                val,
                                bits
                            );
                        }
                    }
                }

                Ok(())
            },
        )?;

        Ok(())
    }
2206
    /// Expands each listed check into a `#[test]` wrapper for the scalar
    /// kernel and, on nightly-AVX x86_64 builds, AVX2/AVX512 variants too.
    macro_rules! generate_all_sar_tests {
        ($($test_fn:ident),*) => {
            paste::paste! {
                $(
                    #[test]
                    fn [<$test_fn _scalar_f64>]() {
                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
                    }
                )*
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                $(
                    #[test]
                    fn [<$test_fn _avx2_f64>]() {
                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
                    }
                    #[test]
                    fn [<$test_fn _avx512_f64>]() {
                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
                    }
                )*
            }
        }
    }

    // Instantiate the per-kernel test wrappers for every single-shot check.
    generate_all_sar_tests!(
        check_sar_partial_params,
        check_sar_accuracy,
        check_sar_from_slices,
        check_sar_all_nan,
        check_sar_no_poison
    );

    #[cfg(feature = "proptest")]
    generate_all_sar_tests!(check_sar_property);
2241
2242 fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2243 skip_if_unsupported!(kernel, test);
2244
2245 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2246 let c = read_candles_from_csv(file)?;
2247
2248 let output = SarBatchBuilder::new().kernel(kernel).apply_candles(&c)?;
2249
2250 let def = SarParams::default();
2251 let row = output.values_for(&def).expect("default row missing");
2252
2253 assert_eq!(row.len(), c.close.len());
2254
2255 let expected = [
2256 60370.00224209362,
2257 60220.362107568006,
2258 60079.70038111392,
2259 59947.478358247085,
2260 59823.189656752256,
2261 ];
2262 let start = row.len() - 5;
2263 for (i, &v) in row[start..].iter().enumerate() {
2264 assert!(
2265 (v - expected[i]).abs() < 1e-4,
2266 "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
2267 );
2268 }
2269 Ok(())
2270 }
2271
2272 #[cfg(debug_assertions)]
2273 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2274 skip_if_unsupported!(kernel, test);
2275
2276 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2277 let c = read_candles_from_csv(file)?;
2278
2279 let test_configs = vec![
2280 (0.01, 0.05, 0.01, 0.1, 0.3, 0.1),
2281 (0.02, 0.1, 0.02, 0.2, 0.5, 0.1),
2282 (0.05, 0.2, 0.05, 0.3, 0.9, 0.2),
2283 (0.001, 0.005, 0.001, 0.05, 0.1, 0.05),
2284 (0.02, 0.02, 0.0, 0.2, 0.2, 0.0),
2285 (0.1, 0.2, 0.025, 0.5, 0.9, 0.1),
2286 (0.001, 0.01, 0.003, 0.1, 0.5, 0.2),
2287 ];
2288
2289 for (cfg_idx, &(a_start, a_end, a_step, m_start, m_end, m_step)) in
2290 test_configs.iter().enumerate()
2291 {
2292 let output = SarBatchBuilder::new()
2293 .kernel(kernel)
2294 .acceleration_range(a_start, a_end, a_step)
2295 .maximum_range(m_start, m_end, m_step)
2296 .apply_candles(&c)?;
2297
2298 for (idx, &val) in output.values.iter().enumerate() {
2299 if val.is_nan() {
2300 continue;
2301 }
2302
2303 let bits = val.to_bits();
2304 let row = idx / output.cols;
2305 let col = idx % output.cols;
2306 let combo = &output.combos[row];
2307
2308 if bits == 0x11111111_11111111 {
2309 panic!(
2310 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
2311 at row {} col {} (flat index {}) with params: acceleration={}, maximum={}",
2312 test,
2313 cfg_idx,
2314 val,
2315 bits,
2316 row,
2317 col,
2318 idx,
2319 combo.acceleration.unwrap_or(0.02),
2320 combo.maximum.unwrap_or(0.2)
2321 );
2322 }
2323
2324 if bits == 0x22222222_22222222 {
2325 panic!(
2326 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
2327 at row {} col {} (flat index {}) with params: acceleration={}, maximum={}",
2328 test,
2329 cfg_idx,
2330 val,
2331 bits,
2332 row,
2333 col,
2334 idx,
2335 combo.acceleration.unwrap_or(0.02),
2336 combo.maximum.unwrap_or(0.2)
2337 );
2338 }
2339
2340 if bits == 0x33333333_33333333 {
2341 panic!(
2342 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
2343 at row {} col {} (flat index {}) with params: acceleration={}, maximum={}",
2344 test,
2345 cfg_idx,
2346 val,
2347 bits,
2348 row,
2349 col,
2350 idx,
2351 combo.acceleration.unwrap_or(0.02),
2352 combo.maximum.unwrap_or(0.2)
2353 );
2354 }
2355 }
2356 }
2357
2358 Ok(())
2359 }
2360
    /// Release builds skip the batch poison scan: the poison bit patterns are
    /// only written under `debug_assertions`.
    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
2365
    /// Expands a batch check into `#[test]` wrappers for the scalar-batch,
    /// AVX2-batch, AVX512-batch (nightly-AVX x86_64 only) and auto-detected
    /// kernels.
    macro_rules! gen_batch_tests {
        ($fn_name:ident) => {
            paste::paste! {
                #[test] fn [<$fn_name _scalar>]() {
                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx2>]() {
                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx512>]() {
                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
                }
                #[test] fn [<$fn_name _auto_detect>]() {
                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
                }
            }
        };
    }
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
2388
2389 #[test]
2390 fn test_sar_into_matches_api() -> Result<(), Box<dyn Error>> {
2391 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2392 let candles = read_candles_from_csv(file_path)?;
2393
2394 let input = SarInput::from_candles(&candles, SarParams::default());
2395
2396 let SarOutput { values: expected } = sar(&input)?;
2397
2398 let mut actual = vec![0.0; candles.high.len()];
2399 #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
2400 {
2401 sar_into(&input, &mut actual)?;
2402 }
2403 #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2404 {
2405 sar_into_slice(&mut actual, &input, Kernel::Auto)?;
2406 }
2407
2408 assert_eq!(expected.len(), actual.len());
2409
2410 fn eq_or_both_nan(a: f64, b: f64) -> bool {
2411 (a.is_nan() && b.is_nan()) || (a == b)
2412 }
2413
2414 for i in 0..expected.len() {
2415 assert!(
2416 eq_or_both_nan(expected[i], actual[i]),
2417 "Mismatch at index {}: expected {:?}, got {:?}",
2418 i,
2419 expected[i],
2420 actual[i]
2421 );
2422 }
2423 Ok(())
2424 }
2425}