Skip to main content

vector_ta/indicators/
sar.rs

1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
3#[cfg(all(feature = "python", feature = "cuda"))]
4use cust::context::Context;
5#[cfg(all(feature = "python", feature = "cuda"))]
6use cust::memory::DeviceBuffer;
7#[cfg(feature = "python")]
8use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
9#[cfg(feature = "python")]
10use pyo3::exceptions::PyValueError;
11#[cfg(feature = "python")]
12use pyo3::prelude::*;
13#[cfg(feature = "python")]
14use pyo3::types::PyDict;
15
16#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
17use serde::{Deserialize, Serialize};
18#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
19use wasm_bindgen::prelude::*;
20
21#[cfg(all(feature = "python", feature = "cuda"))]
22use std::sync::Arc;
23
24use crate::utilities::data_loader::Candles;
25use crate::utilities::enums::Kernel;
26use crate::utilities::helpers::{
27    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
28    make_uninit_matrix,
29};
30#[cfg(feature = "python")]
31use crate::utilities::kernel_validation::validate_kernel;
32#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
33use core::arch::x86_64::*;
34#[cfg(not(target_arch = "wasm32"))]
35use rayon::prelude::*;
36use std::error::Error;
37use thiserror::Error;
38
39#[derive(Debug, Clone)]
40pub enum SarData<'a> {
41    Candles { candles: &'a Candles },
42    Slices { high: &'a [f64], low: &'a [f64] },
43}
44
45#[derive(Debug, Clone)]
46pub struct SarOutput {
47    pub values: Vec<f64>,
48}
49
50#[derive(Debug, Clone)]
51#[cfg_attr(
52    all(target_arch = "wasm32", feature = "wasm"),
53    derive(Serialize, Deserialize)
54)]
55pub struct SarParams {
56    pub acceleration: Option<f64>,
57    pub maximum: Option<f64>,
58}
59
60impl Default for SarParams {
61    fn default() -> Self {
62        Self {
63            acceleration: Some(0.02),
64            maximum: Some(0.2),
65        }
66    }
67}
68
69#[derive(Debug, Clone)]
70pub struct SarInput<'a> {
71    pub data: SarData<'a>,
72    pub params: SarParams,
73}
74
75impl<'a> SarInput<'a> {
76    #[inline]
77    pub fn from_candles(candles: &'a Candles, params: SarParams) -> Self {
78        Self {
79            data: SarData::Candles { candles },
80            params,
81        }
82    }
83
84    #[inline]
85    pub fn from_slices(high: &'a [f64], low: &'a [f64], params: SarParams) -> Self {
86        Self {
87            data: SarData::Slices { high, low },
88            params,
89        }
90    }
91
92    #[inline]
93    pub fn with_default_candles(candles: &'a Candles) -> Self {
94        Self {
95            data: SarData::Candles { candles },
96            params: SarParams::default(),
97        }
98    }
99
100    #[inline]
101    pub fn get_acceleration(&self) -> f64 {
102        self.params.acceleration.unwrap_or(0.02)
103    }
104
105    #[inline]
106    pub fn get_maximum(&self) -> f64 {
107        self.params.maximum.unwrap_or(0.2)
108    }
109}
110
111#[derive(Copy, Clone, Debug)]
112pub struct SarBuilder {
113    acceleration: Option<f64>,
114    maximum: Option<f64>,
115    kernel: Kernel,
116}
117
118impl Default for SarBuilder {
119    fn default() -> Self {
120        Self {
121            acceleration: None,
122            maximum: None,
123            kernel: Kernel::Auto,
124        }
125    }
126}
127
128impl SarBuilder {
129    #[inline(always)]
130    pub fn new() -> Self {
131        Self::default()
132    }
133    #[inline(always)]
134    pub fn acceleration(mut self, v: f64) -> Self {
135        self.acceleration = Some(v);
136        self
137    }
138    #[inline(always)]
139    pub fn maximum(mut self, v: f64) -> Self {
140        self.maximum = Some(v);
141        self
142    }
143    #[inline(always)]
144    pub fn kernel(mut self, k: Kernel) -> Self {
145        self.kernel = k;
146        self
147    }
148    #[inline(always)]
149    pub fn apply(self, c: &Candles) -> Result<SarOutput, SarError> {
150        let params = SarParams {
151            acceleration: self.acceleration,
152            maximum: self.maximum,
153        };
154        let input = SarInput::from_candles(c, params);
155        sar_with_kernel(&input, self.kernel)
156    }
157    #[inline(always)]
158    pub fn apply_slices(self, high: &[f64], low: &[f64]) -> Result<SarOutput, SarError> {
159        let params = SarParams {
160            acceleration: self.acceleration,
161            maximum: self.maximum,
162        };
163        let input = SarInput::from_slices(high, low, params);
164        sar_with_kernel(&input, self.kernel)
165    }
166    #[inline(always)]
167    pub fn into_stream(self) -> Result<SarStream, SarError> {
168        let params = SarParams {
169            acceleration: self.acceleration,
170            maximum: self.maximum,
171        };
172        SarStream::try_new(params)
173    }
174}
175
176#[derive(Debug, Error)]
177pub enum SarError {
178    #[error("sar: Empty data provided.")]
179    EmptyInputData,
180    #[error("sar: All values are NaN.")]
181    AllValuesNaN,
182    #[error("sar: Not enough valid data. needed = {needed}, valid = {valid}")]
183    NotEnoughValidData { needed: usize, valid: usize },
184    #[error("sar: Invalid acceleration: {acceleration}")]
185    InvalidAcceleration { acceleration: f64 },
186    #[error("sar: Invalid maximum: {maximum}")]
187    InvalidMaximum { maximum: f64 },
188    #[error("sar: Output length mismatch: expected = {expected}, got = {got}")]
189    OutputLengthMismatch { expected: usize, got: usize },
190    #[error("sar: Invalid parameter range: start = {start}, end = {end}, step = {step}")]
191    InvalidRange { start: f64, end: f64, step: f64 },
192    #[error("sar: Invalid kernel for batch: {0:?}")]
193    InvalidKernelForBatch(Kernel),
194}
195
196#[inline]
197pub fn sar(input: &SarInput) -> Result<SarOutput, SarError> {
198    sar_with_kernel(input, Kernel::Auto)
199}
200
201pub fn sar_with_kernel(input: &SarInput, kernel: Kernel) -> Result<SarOutput, SarError> {
202    let (high, low) = match &input.data {
203        SarData::Candles { candles } => (candles.high.as_slice(), candles.low.as_slice()),
204        SarData::Slices { high, low } => (*high, *low),
205    };
206
207    if high.is_empty() || low.is_empty() {
208        return Err(SarError::EmptyInputData);
209    }
210
211    let min_len = high.len().min(low.len());
212    let (high, low) = (&high[..min_len], &low[..min_len]);
213
214    let first_valid_idx = high
215        .iter()
216        .zip(low.iter())
217        .position(|(&h, &l)| !h.is_nan() && !l.is_nan());
218    let first = match first_valid_idx {
219        Some(idx) => idx,
220        None => return Err(SarError::AllValuesNaN),
221    };
222
223    if (high.len() - first) < 2 {
224        return Err(SarError::NotEnoughValidData {
225            needed: 2,
226            valid: high.len() - first,
227        });
228    }
229
230    let acceleration = input.get_acceleration();
231    let maximum = input.get_maximum();
232
233    if !(acceleration > 0.0) || acceleration.is_nan() || acceleration.is_infinite() {
234        return Err(SarError::InvalidAcceleration { acceleration });
235    }
236    if !(maximum > 0.0) || maximum.is_nan() || maximum.is_infinite() {
237        return Err(SarError::InvalidMaximum { maximum });
238    }
239
240    let mut out = alloc_with_nan_prefix(high.len(), first + 1);
241
242    let chosen = match kernel {
243        Kernel::Auto => Kernel::Scalar,
244        other => other,
245    };
246
247    unsafe {
248        match chosen {
249            Kernel::Scalar | Kernel::ScalarBatch => {
250                sar_scalar(high, low, first, acceleration, maximum, &mut out)
251            }
252            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
253            Kernel::Avx2 | Kernel::Avx2Batch => {
254                sar_avx2(high, low, first, acceleration, maximum, &mut out)
255            }
256            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
257            Kernel::Avx512 | Kernel::Avx512Batch => {
258                sar_avx512(high, low, first, acceleration, maximum, &mut out)
259            }
260            _ => unreachable!(),
261        }
262    }
263
264    Ok(SarOutput { values: out })
265}
266
267#[inline]
268pub fn sar_into_slice(dst: &mut [f64], input: &SarInput, kern: Kernel) -> Result<(), SarError> {
269    let (high, low) = match &input.data {
270        SarData::Candles { candles } => (candles.high.as_slice(), candles.low.as_slice()),
271        SarData::Slices { high, low } => (*high, *low),
272    };
273
274    if high.is_empty() || low.is_empty() {
275        return Err(SarError::EmptyInputData);
276    }
277
278    let expected_len = high.len().min(low.len());
279    if dst.len() != expected_len {
280        return Err(SarError::OutputLengthMismatch {
281            expected: expected_len,
282            got: dst.len(),
283        });
284    }
285
286    let (high, low) = (&high[..expected_len], &low[..expected_len]);
287
288    let first_valid_idx = high
289        .iter()
290        .zip(low.iter())
291        .position(|(&h, &l)| !h.is_nan() && !l.is_nan());
292    let first = match first_valid_idx {
293        Some(idx) => idx,
294        None => return Err(SarError::AllValuesNaN),
295    };
296
297    if (high.len() - first) < 2 {
298        return Err(SarError::NotEnoughValidData {
299            valid: high.len() - first,
300            needed: 2,
301        });
302    }
303
304    let acceleration = input.params.acceleration.unwrap_or(0.02);
305    let maximum = input.params.maximum.unwrap_or(0.2);
306
307    if acceleration <= 0.0 || acceleration.is_nan() || acceleration.is_infinite() {
308        return Err(SarError::InvalidAcceleration { acceleration });
309    }
310    if maximum <= 0.0 || maximum.is_nan() || maximum.is_infinite() {
311        return Err(SarError::InvalidMaximum { maximum });
312    }
313
314    for v in &mut dst[..first.saturating_add(1)] {
315        *v = f64::from_bits(0x7ff8_0000_0000_0000);
316    }
317
318    let chosen = match kern {
319        Kernel::Auto => Kernel::Scalar,
320        x => x,
321    };
322
323    match chosen {
324        Kernel::Scalar => sar_scalar(high, low, first, acceleration, maximum, dst),
325        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
326        Kernel::Avx2 | Kernel::Avx2Batch => sar_avx2(high, low, first, acceleration, maximum, dst),
327        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
328        Kernel::Avx512 | Kernel::Avx512Batch => {
329            sar_avx512(high, low, first, acceleration, maximum, dst)
330        }
331
332        _ => sar_scalar(high, low, first, acceleration, maximum, dst),
333    }
334    Ok(())
335}
336
337#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
338#[inline]
339pub fn sar_into(input: &SarInput, out: &mut [f64]) -> Result<(), SarError> {
340    sar_into_slice(out, input, Kernel::Auto)
341}
342
/// AVX-512 entry point for the single-series SAR.
///
/// The recurrence is sequential, so this forwards to [`sar_avx2`] with identical arguments.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn sar_avx512(
    high: &[f64],
    low: &[f64],
    first_valid: usize,
    acceleration: f64,
    maximum: f64,
    out: &mut [f64],
) {
    let forward = sar_avx2;
    forward(high, low, first_valid, acceleration, maximum, out)
}
355
/// Scalar Parabolic SAR recurrence, writing into `out` from index `first` onward.
///
/// Algorithm:
/// - `out[first]` is set to NaN.
/// - `out[first + 1]` is the seed SAR: the first low if `high[first + 1] > high[first]`
///   (uptrend), otherwise the first high.
/// - Each later bar projects `sar + acc * (ep - sar)`.
/// - The trend reverses when price crosses the projection; the new SAR is then the previous
///   extreme point and `acc` resets.
/// - Otherwise the projection is clamped against the two previous lows (uptrend) or highs
///   (downtrend).
/// - `acc` grows by `acceleration` on each new extreme point, capped at `maximum`.
///
/// Only indices below `min(high.len(), low.len(), out.len())` are read or written. Entries
/// before `first` are never touched. If fewer than two bars are available from `first`, the
/// function returns without writing anything.
#[inline]
pub fn sar_scalar(
    high: &[f64],
    low: &[f64],
    first: usize,
    acceleration: f64,
    maximum: f64,
    out: &mut [f64],
) {
    // Work on the common prefix of all three slices: mismatched lengths are treated like a
    // shorter input rather than panicking part-way through the loop. Bounding by one length
    // also lets the compiler drop per-iteration bounds checks.
    let len = high.len().min(low.len()).min(out.len());
    if first >= len || len - first < 2 {
        return;
    }
    let (high, low, out) = (&high[..len], &low[..len], &mut out[..len]);

    let i0 = first;
    let i1 = i0 + 1;
    let (h0, h1, l0, l1) = (high[i0], high[i1], low[i0], low[i1]);

    let mut trend_up = h1 > h0;
    let mut sar = if trend_up { l0 } else { h0 };
    let mut ep = if trend_up { h1 } else { l1 };
    let mut acc = acceleration;

    out[i0] = f64::NAN;
    out[i1] = sar;

    let (mut low_prev2, mut low_prev) = (l0, l1);
    let (mut high_prev2, mut high_prev) = (h0, h1);

    let rest = i1 + 1;
    for ((&hi, &lo), slot) in high[rest..]
        .iter()
        .zip(&low[rest..])
        .zip(&mut out[rest..])
    {
        let mut next_sar = acc.mul_add(ep - sar, sar);

        if trend_up {
            if lo < next_sar {
                // Uptrend broken: flip short, SAR jumps to the prior extreme high.
                trend_up = false;
                next_sar = ep;
                ep = lo;
                acc = acceleration;
            } else {
                if hi > ep {
                    ep = hi;
                    acc = (acc + acceleration).min(maximum);
                }
                // SAR may not sit above either of the two previous lows.
                next_sar = next_sar.min(low_prev).min(low_prev2);
            }
        } else if hi > next_sar {
            // Downtrend broken: flip long, SAR jumps to the prior extreme low.
            trend_up = true;
            next_sar = ep;
            ep = hi;
            acc = acceleration;
        } else {
            if lo < ep {
                ep = lo;
                acc = (acc + acceleration).min(maximum);
            }
            // SAR may not sit below either of the two previous highs.
            next_sar = next_sar.max(high_prev).max(high_prev2);
        }

        *slot = next_sar;
        sar = next_sar;

        low_prev2 = low_prev;
        low_prev = lo;
        high_prev2 = high_prev;
        high_prev = hi;
    }
}
437
/// AVX2-feature entry point for the single-series Parabolic SAR.
///
/// The SAR recurrence is inherently sequential, so this is the same algorithm as
/// [`sar_scalar`], written with unchecked indexing. Only indices below
/// `min(high.len(), low.len(), out.len())` are read or written. Entries before `first_valid`
/// are never touched. With fewer than two bars available from `first_valid`, nothing is
/// written.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn sar_avx2(
    high: &[f64],
    low: &[f64],
    first_valid: usize,
    acceleration: f64,
    maximum: f64,
    out: &mut [f64],
) {
    // Bound every index by the shortest of the three slices. The unchecked accesses below are
    // only sound for indices in range for all of them, and this is a safe `pub fn`, so
    // mismatched lengths from a caller must not become out-of-bounds reads or writes.
    let len = high.len().min(low.len()).min(out.len());
    if first_valid >= len || len - first_valid < 2 {
        return;
    }

    let i0 = first_valid;
    let i1 = i0 + 1;

    // SAFETY: `i0 < i1 < len` was established above and the loop keeps `i < len`. `len` never
    // exceeds `high.len()`, `low.len()` or `out.len()`, so every unchecked access is in bounds.
    unsafe {
        let h0 = *high.get_unchecked(i0);
        let h1 = *high.get_unchecked(i1);
        let l0 = *low.get_unchecked(i0);
        let l1 = *low.get_unchecked(i1);

        let mut trend_up = h1 > h0;
        let mut sar = if trend_up { l0 } else { h0 };
        let mut ep = if trend_up { h1 } else { l1 };
        let mut acc = acceleration;

        *out.get_unchecked_mut(i0) = f64::NAN;
        *out.get_unchecked_mut(i1) = sar;

        let mut low_prev2 = l0;
        let mut low_prev = l1;
        let mut high_prev2 = h0;
        let mut high_prev = h1;

        let mut i = i1 + 1;
        while i < len {
            let hi = *high.get_unchecked(i);
            let lo = *low.get_unchecked(i);

            let mut next_sar = acc.mul_add(ep - sar, sar);

            if trend_up {
                if lo < next_sar {
                    // Uptrend broken: flip short, SAR jumps to the prior extreme high.
                    trend_up = false;
                    next_sar = ep;
                    ep = lo;
                    acc = acceleration;
                } else {
                    if hi > ep {
                        ep = hi;
                        acc = (acc + acceleration).min(maximum);
                    }
                    next_sar = next_sar.min(low_prev).min(low_prev2);
                }
            } else if hi > next_sar {
                // Downtrend broken: flip long, SAR jumps to the prior extreme low.
                trend_up = true;
                next_sar = ep;
                ep = hi;
                acc = acceleration;
            } else {
                if lo < ep {
                    ep = lo;
                    acc = (acc + acceleration).min(maximum);
                }
                next_sar = next_sar.max(high_prev).max(high_prev2);
            }

            *out.get_unchecked_mut(i) = next_sar;
            sar = next_sar;

            low_prev2 = low_prev;
            low_prev = lo;
            high_prev2 = high_prev;
            high_prev = hi;

            i += 1;
        }
    }
}
522
/// WASM SIMD128 entry point; the recurrence is sequential, so it runs [`sar_scalar`].
#[cfg(all(feature = "simd128", target_arch = "wasm32"))]
#[inline]
pub fn sar_simd128(
    high: &[f64],
    low: &[f64],
    first_valid: usize,
    acceleration: f64,
    maximum: f64,
    out: &mut [f64],
) {
    let forward = sar_scalar;
    forward(high, low, first_valid, acceleration, maximum, out)
}

/// AVX-512 "short input" variant; forwards to [`sar_avx2`].
///
/// # Safety
/// The body performs no unsafe operations itself; it only forwards to the safe `sar_avx2`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn sar_avx512_short(
    high: &[f64],
    low: &[f64],
    first_valid: usize,
    acceleration: f64,
    maximum: f64,
    out: &mut [f64],
) {
    let forward = sar_avx2;
    forward(high, low, first_valid, acceleration, maximum, out)
}

/// AVX-512 "long input" variant; forwards to [`sar_avx2`].
///
/// # Safety
/// The body performs no unsafe operations itself; it only forwards to the safe `sar_avx2`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn sar_avx512_long(
    high: &[f64],
    low: &[f64],
    first_valid: usize,
    acceleration: f64,
    maximum: f64,
    out: &mut [f64],
) {
    let forward = sar_avx2;
    forward(high, low, first_valid, acceleration, maximum, out)
}
561
/// Incremental (bar-by-bar) Parabolic SAR calculator; build it with [`SarStream::try_new`].
#[derive(Debug, Clone)]
pub struct SarStream {
    /// Acceleration step (validated finite and > 0 in `try_new`).
    acceleration: f64,
    /// Cap on the accumulated acceleration factor (validated finite and > 0 in `try_new`).
    maximum: f64,
    /// `None` until the first finite bar has been accepted.
    state: Option<StreamState>,
    /// Number of accepted bars; values 1 and 2 select the warm-up branches of `update`.
    idx: usize,
}

/// Running recurrence state of a [`SarStream`].
#[derive(Debug, Clone)]
struct StreamState {
    /// Current trend direction (`true` = long).
    trend_up: bool,
    /// Most recently emitted SAR.
    sar: f64,
    /// Extreme point of the current trend.
    ep: f64,
    /// Current acceleration factor.
    acc: f64,

    /// High of the previous accepted bar.
    prev_high: f64,
    /// High of the bar before that.
    prev_high2: f64,
    /// Low of the previous accepted bar.
    prev_low: f64,
    /// Low of the bar before that.
    prev_low2: f64,
}
583
584impl SarStream {
585    pub fn try_new(params: SarParams) -> Result<Self, SarError> {
586        let acceleration = params.acceleration.unwrap_or(0.02);
587        let maximum = params.maximum.unwrap_or(0.2);
588
589        if !(acceleration > 0.0) || !acceleration.is_finite() {
590            return Err(SarError::InvalidAcceleration { acceleration });
591        }
592        if !(maximum > 0.0) || !maximum.is_finite() {
593            return Err(SarError::InvalidMaximum { maximum });
594        }
595
596        Ok(Self {
597            acceleration,
598            maximum,
599            state: None,
600            idx: 0,
601        })
602    }
603
604    #[inline(always)]
605    pub fn update(&mut self, high: f64, low: f64) -> Option<f64> {
606        if !high.is_finite() || !low.is_finite() {
607            return None;
608        }
609
610        match self.state.as_mut() {
611            None => {
612                self.state = Some(StreamState {
613                    trend_up: false,
614                    sar: f64::NAN,
615                    ep: f64::NAN,
616                    acc: self.acceleration,
617                    prev_high: high,
618                    prev_high2: high,
619                    prev_low: low,
620                    prev_low2: low,
621                });
622                self.idx = 1;
623                None
624            }
625
626            Some(st) if self.idx == 1 => {
627                let trend_up = high > st.prev_high;
628
629                let sar = if trend_up { st.prev_low } else { st.prev_high };
630                let ep = if trend_up { high } else { low };
631
632                st.prev_high2 = st.prev_high;
633                st.prev_low2 = st.prev_low;
634                st.prev_high = high;
635                st.prev_low = low;
636
637                st.trend_up = trend_up;
638                st.sar = sar;
639                st.ep = ep;
640                st.acc = self.acceleration;
641
642                self.idx = 2;
643                Some(sar)
644            }
645
646            Some(st) => {
647                let mut next_sar = st.acc.mul_add(st.ep - st.sar, st.sar);
648
649                if st.trend_up {
650                    if low < next_sar {
651                        st.trend_up = false;
652                        next_sar = st.ep;
653                        st.ep = low;
654                        st.acc = self.acceleration;
655                    } else {
656                        if high > st.ep {
657                            st.ep = high;
658                            st.acc = (st.acc + self.acceleration).min(self.maximum);
659                        }
660                        next_sar = min3(next_sar, st.prev_low, st.prev_low2);
661                    }
662                } else {
663                    if high > next_sar {
664                        st.trend_up = true;
665                        next_sar = st.ep;
666                        st.ep = high;
667                        st.acc = self.acceleration;
668                    } else {
669                        if low < st.ep {
670                            st.ep = low;
671                            st.acc = (st.acc + self.acceleration).min(self.maximum);
672                        }
673                        next_sar = max3(next_sar, st.prev_high, st.prev_high2);
674                    }
675                }
676
677                st.prev_high2 = st.prev_high;
678                st.prev_low2 = st.prev_low;
679                st.prev_high = high;
680                st.prev_low = low;
681
682                st.sar = next_sar;
683                self.idx += 1;
684                Some(next_sar)
685            }
686        }
687    }
688}
689
/// Smallest of three values, using `f64::min` semantics (a NaN operand is ignored when the
/// other operand is a number).
#[inline(always)]
fn min3(a: f64, b: f64, c: f64) -> f64 {
    let tail = f64::min(b, c);
    f64::min(a, tail)
}

/// Largest of three values, using `f64::max` semantics (a NaN operand is ignored when the
/// other operand is a number).
#[inline(always)]
fn max3(a: f64, b: f64, c: f64) -> f64 {
    let tail = f64::max(b, c);
    f64::max(a, tail)
}
699
/// Parameter sweep for batch SAR: each axis is a `(start, end, step)` triple.
#[derive(Clone, Debug)]
pub struct SarBatchRange {
    /// Sweep over the acceleration step.
    pub acceleration: (f64, f64, f64),
    /// Sweep over the acceleration cap.
    pub maximum: (f64, f64, f64),
}

impl Default for SarBatchRange {
    /// A single acceleration of 0.02, with maxima from 0.2 to 0.449 in steps of 0.001.
    fn default() -> Self {
        let acceleration = (0.02, 0.02, 0.0);
        let maximum = (0.2, 0.449, 0.001);
        Self {
            acceleration,
            maximum,
        }
    }
}
714
715#[derive(Clone, Debug, Default)]
716pub struct SarBatchBuilder {
717    range: SarBatchRange,
718    kernel: Kernel,
719}
720
721impl SarBatchBuilder {
722    pub fn new() -> Self {
723        Self::default()
724    }
725    pub fn kernel(mut self, k: Kernel) -> Self {
726        self.kernel = k;
727        self
728    }
729
730    pub fn acceleration_range(mut self, start: f64, end: f64, step: f64) -> Self {
731        self.range.acceleration = (start, end, step);
732        self
733    }
734    pub fn acceleration_static(mut self, x: f64) -> Self {
735        self.range.acceleration = (x, x, 0.0);
736        self
737    }
738    pub fn maximum_range(mut self, start: f64, end: f64, step: f64) -> Self {
739        self.range.maximum = (start, end, step);
740        self
741    }
742    pub fn maximum_static(mut self, x: f64) -> Self {
743        self.range.maximum = (x, x, 0.0);
744        self
745    }
746
747    pub fn apply_slices(self, high: &[f64], low: &[f64]) -> Result<SarBatchOutput, SarError> {
748        sar_batch_with_kernel(high, low, &self.range, self.kernel)
749    }
750
751    pub fn with_default_slices(
752        high: &[f64],
753        low: &[f64],
754        k: Kernel,
755    ) -> Result<SarBatchOutput, SarError> {
756        SarBatchBuilder::new().kernel(k).apply_slices(high, low)
757    }
758
759    pub fn apply_candles(self, c: &Candles) -> Result<SarBatchOutput, SarError> {
760        self.apply_slices(&c.high, &c.low)
761    }
762
763    pub fn with_default_candles(c: &Candles) -> Result<SarBatchOutput, SarError> {
764        SarBatchBuilder::new().kernel(Kernel::Auto).apply_candles(c)
765    }
766}
767
768pub fn sar_batch_with_kernel(
769    high: &[f64],
770    low: &[f64],
771    sweep: &SarBatchRange,
772    k: Kernel,
773) -> Result<SarBatchOutput, SarError> {
774    let kernel = match k {
775        Kernel::Auto => detect_best_batch_kernel(),
776        other if other.is_batch() => other,
777        other => return Err(SarError::InvalidKernelForBatch(other)),
778    };
779    let simd = match kernel {
780        Kernel::Avx512Batch => Kernel::Avx512,
781        Kernel::Avx2Batch => Kernel::Avx2,
782        Kernel::ScalarBatch => Kernel::Scalar,
783        _ => unreachable!(),
784    };
785    sar_batch_par_slice(high, low, sweep, simd)
786}
787
788#[derive(Clone, Debug)]
789pub struct SarBatchOutput {
790    pub values: Vec<f64>,
791    pub combos: Vec<SarParams>,
792    pub rows: usize,
793    pub cols: usize,
794}
795impl SarBatchOutput {
796    pub fn row_for_params(&self, p: &SarParams) -> Option<usize> {
797        self.combos.iter().position(|c| {
798            (c.acceleration.unwrap_or(0.02) - p.acceleration.unwrap_or(0.02)).abs() < 1e-12
799                && (c.maximum.unwrap_or(0.2) - p.maximum.unwrap_or(0.2)).abs() < 1e-12
800        })
801    }
802    pub fn values_for(&self, p: &SarParams) -> Option<&[f64]> {
803        self.row_for_params(p).map(|row| {
804            let start = row * self.cols;
805            &self.values[start..start + self.cols]
806        })
807    }
808}
809
810#[inline(always)]
811fn axis_f64_checked(axis: (f64, f64, f64)) -> Result<Vec<f64>, SarError> {
812    let (start, end, step) = axis;
813    if !step.is_finite() {
814        return Err(SarError::InvalidRange { start, end, step });
815    }
816    if step.abs() < 1e-12 || (start - end).abs() < 1e-12 {
817        return Ok(vec![start]);
818    }
819
820    let mut v = Vec::new();
821    let tol = step.abs() * 1e-12;
822
823    if step > 0.0 {
824        if start <= end {
825            let mut x = start;
826            while x <= end + tol {
827                v.push(x);
828                x += step;
829            }
830        } else {
831            let mut x = start;
832            while x >= end - tol {
833                v.push(x);
834                x -= step;
835            }
836        }
837    } else {
838        if start >= end {
839            let mut x = start;
840            while x >= end - tol {
841                v.push(x);
842                x += step;
843            }
844        } else {
845            return Err(SarError::InvalidRange { start, end, step });
846        }
847    }
848
849    if v.is_empty() {
850        Err(SarError::InvalidRange { start, end, step })
851    } else {
852        Ok(v)
853    }
854}
855
856#[inline(always)]
857fn expand_grid(r: &SarBatchRange) -> Result<Vec<SarParams>, SarError> {
858    let accs = axis_f64_checked(r.acceleration)?;
859    let maxs = axis_f64_checked(r.maximum)?;
860
861    let capacity = accs
862        .len()
863        .checked_mul(maxs.len())
864        .ok_or(SarError::InvalidRange {
865            start: r.acceleration.0,
866            end: r.acceleration.1,
867            step: r.acceleration.2,
868        })?;
869
870    let mut out = Vec::with_capacity(capacity);
871    for &a in &accs {
872        for &m in &maxs {
873            out.push(SarParams {
874                acceleration: Some(a),
875                maximum: Some(m),
876            });
877        }
878    }
879    Ok(out)
880}
881
882#[inline(always)]
883pub fn sar_batch_slice(
884    high: &[f64],
885    low: &[f64],
886    sweep: &SarBatchRange,
887    kern: Kernel,
888) -> Result<SarBatchOutput, SarError> {
889    sar_batch_inner(high, low, sweep, kern, false)
890}
891
892#[inline(always)]
893pub fn sar_batch_par_slice(
894    high: &[f64],
895    low: &[f64],
896    sweep: &SarBatchRange,
897    kern: Kernel,
898) -> Result<SarBatchOutput, SarError> {
899    sar_batch_inner(high, low, sweep, kern, true)
900}
901
902#[inline(always)]
903fn sar_batch_inner(
904    high: &[f64],
905    low: &[f64],
906    sweep: &SarBatchRange,
907    kern: Kernel,
908    parallel: bool,
909) -> Result<SarBatchOutput, SarError> {
910    let combos = expand_grid(sweep)?;
911
912    let min_len = high.len().min(low.len());
913    let (high, low) = (&high[..min_len], &low[..min_len]);
914    let first = high
915        .iter()
916        .zip(low.iter())
917        .position(|(&h, &l)| !h.is_nan() && !l.is_nan())
918        .ok_or(SarError::AllValuesNaN)?;
919
920    if high.len() - first < 2 {
921        return Err(SarError::NotEnoughValidData {
922            needed: 2,
923            valid: high.len() - first,
924        });
925    }
926    let rows = combos.len();
927    let cols = high.len();
928
929    let mut buf_mu = make_uninit_matrix(rows, cols);
930
931    let warm = vec![first + 1; rows];
932    init_matrix_prefixes(&mut buf_mu, cols, &warm);
933
934    let mut buf_guard = core::mem::ManuallyDrop::new(buf_mu);
935    let values: &mut [f64] = unsafe {
936        core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
937    };
938
939    let do_row = |row: usize, out_row: &mut [f64]| unsafe {
940        let p = &combos[row];
941        match kern {
942            Kernel::Scalar => sar_row_scalar(
943                high,
944                low,
945                first,
946                p.acceleration.unwrap(),
947                p.maximum.unwrap(),
948                out_row,
949            ),
950            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
951            Kernel::Avx2 => sar_row_avx2(
952                high,
953                low,
954                first,
955                p.acceleration.unwrap(),
956                p.maximum.unwrap(),
957                out_row,
958            ),
959            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
960            Kernel::Avx512 => sar_row_avx512(
961                high,
962                low,
963                first,
964                p.acceleration.unwrap(),
965                p.maximum.unwrap(),
966                out_row,
967            ),
968            _ => unreachable!(),
969        }
970    };
971
972    if parallel {
973        #[cfg(not(target_arch = "wasm32"))]
974        {
975            values
976                .par_chunks_mut(cols)
977                .enumerate()
978                .for_each(|(row, slice)| do_row(row, slice));
979        }
980
981        #[cfg(target_arch = "wasm32")]
982        {
983            for (row, slice) in values.chunks_mut(cols).enumerate() {
984                do_row(row, slice);
985            }
986        }
987    } else {
988        for (row, slice) in values.chunks_mut(cols).enumerate() {
989            do_row(row, slice);
990        }
991    }
992
993    let values = unsafe {
994        Vec::from_raw_parts(
995            buf_guard.as_mut_ptr() as *mut f64,
996            buf_guard.len(),
997            buf_guard.capacity(),
998        )
999    };
1000
1001    Ok(SarBatchOutput {
1002        values,
1003        combos,
1004        rows,
1005        cols,
1006    })
1007}
1008
1009#[inline(always)]
1010unsafe fn sar_row_scalar(
1011    high: &[f64],
1012    low: &[f64],
1013    first: usize,
1014    acceleration: f64,
1015    maximum: f64,
1016    out: &mut [f64],
1017) {
1018    sar_scalar(high, low, first, acceleration, maximum, out)
1019}
1020
1021#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1022#[inline(always)]
1023unsafe fn sar_row_avx2(
1024    high: &[f64],
1025    low: &[f64],
1026    first: usize,
1027    acceleration: f64,
1028    maximum: f64,
1029    out: &mut [f64],
1030) {
1031    sar_avx2(high, low, first, acceleration, maximum, out)
1032}
1033
1034#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1035#[inline(always)]
1036pub unsafe fn sar_row_avx512(
1037    high: &[f64],
1038    low: &[f64],
1039    first: usize,
1040    acceleration: f64,
1041    maximum: f64,
1042    out: &mut [f64],
1043) {
1044    sar_avx2(high, low, first, acceleration, maximum, out)
1045}
1046
1047#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1048#[inline(always)]
1049pub unsafe fn sar_row_avx512_short(
1050    high: &[f64],
1051    low: &[f64],
1052    first: usize,
1053    acceleration: f64,
1054    maximum: f64,
1055    out: &mut [f64],
1056) {
1057    sar_avx2(high, low, first, acceleration, maximum, out)
1058}
1059
/// "Long series" AVX-512 row entry point; delegates to the AVX2 kernel (`sar_avx2`).
///
/// # Safety
/// Same requirements as `sar_avx2` (CPU feature support plus its slice/index contract).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn sar_row_avx512_long(
    high: &[f64],
    low: &[f64],
    first: usize,
    acceleration: f64,
    maximum: f64,
    out: &mut [f64],
) {
    sar_avx2(high, low, first, acceleration, maximum, out);
}
1072
1073#[inline(always)]
1074fn expand_grid_for_test(r: &SarBatchRange) -> Vec<SarParams> {
1075    expand_grid(r).expect("expand_grid_for_test should not fail")
1076}
1077
/// Python `sar`: Parabolic SAR over one high/low series.
///
/// `high` and `low` are truncated to their common length. `acceleration` / `maximum`
/// left as `None` are passed through in `SarParams` for the core routine to default.
/// `kernel` is validated with `validate_kernel` in non-batch mode, and the computation
/// runs with the GIL released.
///
/// # Errors
/// `ValueError` for an invalid kernel name or any `SarError` from the computation.
#[cfg(feature = "python")]
#[pyfunction(name = "sar")]
#[pyo3(signature = (high, low, acceleration=None, maximum=None, kernel=None))]
pub fn sar_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    acceleration: Option<f64>,
    maximum: Option<f64>,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let (h, l) = (high.as_slice()?, low.as_slice()?);
    let n = h.len().min(l.len());

    let kern = validate_kernel(kernel, false)?;
    let input = SarInput::from_slices(
        &h[..n],
        &l[..n],
        SarParams {
            acceleration,
            maximum,
        },
    );

    let values = py
        .allow_threads(|| sar_with_kernel(&input, kern))
        .map(|computed| computed.values)
        .map_err(|e| PyValueError::new_err(e.to_string()))?;
    Ok(values.into_pyarray(py))
}
1110
/// Python class `SarStream`: wraps the incremental Rust `SarStream` calculator.
#[cfg(feature = "python")]
#[pyclass(name = "SarStream")]
pub struct SarStreamPy {
    /// Underlying stream state; advanced one bar at a time by `update`.
    stream: SarStream,
}
1116
#[cfg(feature = "python")]
#[pymethods]
impl SarStreamPy {
    /// Builds a stream; unset parameters are left as `None` for `SarStream::try_new`.
    ///
    /// Raises `ValueError` if `SarStream::try_new` rejects the parameters.
    #[new]
    #[pyo3(signature = (acceleration=None, maximum=None))]
    fn new(acceleration: Option<f64>, maximum: Option<f64>) -> PyResult<Self> {
        SarStream::try_new(SarParams {
            acceleration,
            maximum,
        })
        .map(|stream| Self { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feeds one bar and returns the stream's output for it (`None` when the
    /// underlying `SarStream::update` yields no value, e.g. presumably during warm-up).
    fn update(&mut self, high: f64, low: f64) -> Option<f64> {
        let inner = &mut self.stream;
        inner.update(high, low)
    }
}
1136
/// Python `sar_batch`: evaluates SAR for every `(acceleration, maximum)` pair in a sweep.
///
/// `high` and `low` are truncated to their common length `cols`. Each range is a
/// `(start, end, step)` triple expanded by `expand_grid`. Rows are computed in parallel
/// with the GIL released.
///
/// Returns a dict with:
/// * `"values"` – `(rows, cols)` float64 array, one row per combination;
/// * `"accelerations"` / `"maximums"` – the per-row parameters (unset values are
///   reported as 0.02 / 0.2).
///
/// # Errors
/// `ValueError` for an invalid range, a `rows * cols` overflow, an invalid kernel name,
/// or any `SarError` raised while filling the output.
#[cfg(feature = "python")]
#[pyfunction(name = "sar_batch")]
#[pyo3(signature = (high, low, acceleration_range, maximum_range, kernel=None))]
pub fn sar_batch_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    acceleration_range: (f64, f64, f64),
    maximum_range: (f64, f64, f64),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let high_all = high.as_slice()?;
    let low_all = low.as_slice()?;
    let cols = high_all.len().min(low_all.len());
    let high_slice = &high_all[..cols];
    let low_slice = &low_all[..cols];

    let sweep = SarBatchRange {
        acceleration: acceleration_range,
        maximum: maximum_range,
    };
    let combos = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let rows = combos.len();
    let total = match rows.checked_mul(cols) {
        Some(t) => t,
        None => {
            return Err(PyValueError::new_err(
                "sar_batch_py: size overflow in rows*cols",
            ))
        }
    };

    // SAFETY: the uninitialised array is only handed to Python after
    // `sar_batch_inner_into_noalloc` succeeds; that routine initialises every row's
    // warm-up prefix and runs a row kernel over each row (assumed to write every
    // remaining cell).
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let kern = validate_kernel(kernel, true)?;
    py.allow_threads(|| {
        let resolved = if let Kernel::Auto = kern {
            detect_best_batch_kernel()
        } else {
            kern
        };
        sar_batch_inner_into_noalloc(high_slice, low_slice, &sweep, resolved, true, slice_out)
    })
    .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let accelerations: Vec<f64> = combos
        .iter()
        .map(|p| p.acceleration.unwrap_or(0.02))
        .collect();
    let maximums: Vec<f64> = combos.iter().map(|p| p.maximum.unwrap_or(0.2)).collect();

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item("accelerations", accelerations.into_pyarray(py))?;
    dict.set_item("maximums", maximums.into_pyarray(py))?;
    Ok(dict)
}
1199
/// Fills a caller-provided `rows * cols` buffer with one SAR series per parameter
/// combination in `sweep`, without allocating the output matrix.
///
/// `high` and `low` are first trimmed to their common length, which becomes `cols`.
/// This keeps the row kernels from indexing past the shorter series even if a caller
/// forgets to trim. `first` is the index of the first bar where both inputs are
/// non-NaN. The first `first + 1` cells of every row are initialised by
/// `init_matrix_prefixes` (presumably NaN warm-up); the selected row kernel fills the
/// rest. `Kernel::Auto` is resolved with `detect_best_batch_kernel`; unknown or
/// unavailable kernels fall back to the scalar row kernel.
///
/// Returns the expanded parameter combinations in row order.
///
/// # Errors
/// * any error from `expand_grid`;
/// * `SarError::AllValuesNaN` when no bar has both `high` and `low` non-NaN;
/// * `SarError::NotEnoughValidData` when fewer than two bars remain from `first`;
/// * `SarError::InvalidRange` when `rows * cols` overflows `usize`;
/// * `SarError::OutputLengthMismatch` when `out.len() != rows * cols`.
#[cfg(feature = "python")]
fn sar_batch_inner_into_noalloc(
    high: &[f64],
    low: &[f64],
    sweep: &SarBatchRange,
    kern: Kernel,
    parallel: bool,
    out: &mut [f64],
) -> Result<Vec<SarParams>, SarError> {
    let combos = expand_grid(sweep)?;

    // Work on the common prefix so both inputs have exactly `cols` elements; mirrors
    // `sar_batch_inner_into_noalloc_wasm`.
    let min_len = high.len().min(low.len());
    let (high, low) = (&high[..min_len], &low[..min_len]);

    let first = high
        .iter()
        .zip(low.iter())
        .position(|(&h, &l)| !h.is_nan() && !l.is_nan())
        .ok_or(SarError::AllValuesNaN)?;
    if high.len() - first < 2 {
        return Err(SarError::NotEnoughValidData {
            needed: 2,
            valid: high.len() - first,
        });
    }

    let rows = combos.len();
    let cols = high.len();
    let expected = rows.checked_mul(cols).ok_or(SarError::InvalidRange {
        start: sweep.acceleration.0,
        end: sweep.acceleration.1,
        step: sweep.acceleration.2,
    })?;
    if out.len() != expected {
        return Err(SarError::OutputLengthMismatch {
            expected,
            got: out.len(),
        });
    }

    // SAFETY: `out` is a valid, exclusively borrowed `f64` slice; viewing it as
    // `MaybeUninit<f64>` of the same length is layout-compatible and only writes occur.
    unsafe {
        let out_mu = std::slice::from_raw_parts_mut(
            out.as_mut_ptr() as *mut std::mem::MaybeUninit<f64>,
            out.len(),
        );
        let warm: Vec<usize> = vec![first + 1; rows];
        init_matrix_prefixes(out_mu, cols, &warm);
    }

    let simd = match kern {
        Kernel::Auto => detect_best_batch_kernel(),
        k => k,
    };
    // `expand_grid` output is expected to carry concrete values (the original code
    // relies on this too) — NOTE(review): confirm expand_grid never yields `None`.
    let exec = |row: usize, row_out: &mut [f64]| {
        let p = &combos[row];
        match simd {
            Kernel::Scalar | Kernel::ScalarBatch => unsafe {
                sar_row_scalar(
                    high,
                    low,
                    first,
                    p.acceleration.unwrap(),
                    p.maximum.unwrap(),
                    row_out,
                );
            },
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx2 | Kernel::Avx2Batch => unsafe {
                sar_row_avx2(
                    high,
                    low,
                    first,
                    p.acceleration.unwrap(),
                    p.maximum.unwrap(),
                    row_out,
                );
            },
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx512 | Kernel::Avx512Batch => unsafe {
                sar_row_avx512(
                    high,
                    low,
                    first,
                    p.acceleration.unwrap(),
                    p.maximum.unwrap(),
                    row_out,
                );
            },
            _ => unsafe {
                sar_row_scalar(
                    high,
                    low,
                    first,
                    p.acceleration.unwrap(),
                    p.maximum.unwrap(),
                    row_out,
                );
            },
        }
    };

    if parallel {
        #[cfg(not(target_arch = "wasm32"))]
        {
            use rayon::prelude::*;
            out.par_chunks_mut(cols)
                .enumerate()
                .for_each(|(row, r)| exec(row, r));
        }
        #[cfg(target_arch = "wasm32")]
        {
            for (row, r) in out.chunks_mut(cols).enumerate() {
                exec(row, r);
            }
        }
    } else {
        for (row, r) in out.chunks_mut(cols).enumerate() {
            exec(row, r);
        }
    }

    Ok(combos)
}
1320
/// JS `sar_js`: Parabolic SAR over one high/low series.
///
/// Inputs are truncated to their common length. The result is produced through
/// `sar_into_slice` with `Kernel::Auto`.
///
/// # Errors
/// Rejects with the `SarError` message as a JS string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn sar_js(
    high: &[f64],
    low: &[f64],
    acceleration: f64,
    maximum: f64,
) -> Result<Vec<f64>, JsValue> {
    let n = high.len().min(low.len());
    let input = SarInput::from_slices(
        &high[..n],
        &low[..n],
        SarParams {
            acceleration: Some(acceleration),
            maximum: Some(maximum),
        },
    );

    let mut values = vec![0.0; n];
    match sar_into_slice(&mut values, &input, Kernel::Auto) {
        Ok(_) => Ok(values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1345
/// Python handle to a float32 SAR result matrix resident on a CUDA device.
///
/// Registered in module `ta_indicators.cuda` and marked `unsendable`, so pyo3 restricts
/// access to the thread that created it. Consumers obtain the data through
/// `__cuda_array_interface__` or a one-shot `__dlpack__` export.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct SarDeviceArrayF32Py {
    /// Device allocation holding `rows * cols` values; becomes `None` once ownership has
    /// been moved out by `__dlpack__`.
    pub(crate) buf: Option<DeviceBuffer<f32>>,
    /// Number of rows in the matrix.
    pub(crate) rows: usize,
    /// Number of columns in the matrix (row stride in elements).
    pub(crate) cols: usize,
    /// Shared CUDA context handle. Declared after `buf`, so Rust's declaration-order drop
    /// releases the buffer before this reference. NOTE(review): presumably held to keep
    /// the context alive while `buf` exists — confirm.
    pub(crate) _ctx: Arc<Context>,
    /// CUDA device ordinal reported by `__dlpack_device__`.
    pub(crate) device_id: u32,
}
1355
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl SarDeviceArrayF32Py {
    /// CUDA Array Interface (version 3) descriptor for the device buffer.
    ///
    /// Describes a row-major `(rows, cols)` float32 array (`"<f4"`) with byte strides
    /// `(cols * 4, 4)`. `data` is `(device_pointer, false)`; per the CAI spec the second
    /// element is the read-only flag. No `"stream"` key is provided.
    ///
    /// # Errors
    /// `ValueError` if the buffer has already been handed off via `__dlpack__`.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let d = PyDict::new(py);
        d.set_item("shape", (self.rows, self.cols))?;
        d.set_item("typestr", "<f4")?;
        // Byte strides: one row is `cols` f32 values, one element is one f32.
        d.set_item(
            "strides",
            (
                self.cols * std::mem::size_of::<f32>(),
                std::mem::size_of::<f32>(),
            ),
        )?;
        let buf = self
            .buf
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("buffer already exported via __dlpack__"))?;
        let ptr = buf.as_device_ptr().as_raw() as usize;
        d.set_item("data", (ptr, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    /// DLPack device tuple `(device_type, device_id)`; 2 is `kDLCUDA` in DLPack's
    /// `DLDeviceType` enumeration.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.device_id as i32)
    }

    /// Exports the buffer as a DLPack capsule. This is a one-shot transfer of ownership.
    ///
    /// * `dl_device`: if it extracts as `(i32, i32)` and differs from this array's
    ///   device, a `ValueError` is raised (cross-device copies are not implemented).
    ///   Values that do not extract as a pair are ignored.
    /// * `stream`: accepted but unused; no producer-side synchronisation happens here.
    ///   NOTE(review): confirm the producing kernels have completed before export.
    /// * `max_version`: forwarded to `export_f32_cuda_dlpack_2d`.
    /// * `copy`: only consulted to pick the error message on a device mismatch.
    ///
    /// # Errors
    /// `ValueError` on a device mismatch or if the buffer was already exported. After a
    /// successful call, `__cuda_array_interface__` also fails because `buf` is `None`.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        let (kdl, alloc_dev) = self.__dlpack_device__();
        // Validate the requested device before taking the buffer, so a rejected request
        // leaves the array still usable.
        if let Some(dev_obj) = dl_device.as_ref() {
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(PyValueError::new_err(
                            "cross-device DLPack copy is not implemented for SarDeviceArrayF32Py",
                        ));
                    } else {
                        return Err(PyValueError::new_err(
                            "requested dl_device does not match SAR producer device",
                        ));
                    }
                }
            }
        }
        let _ = stream;

        // Move the allocation out; the capsule takes over ownership from here.
        let buf = self
            .buf
            .take()
            .ok_or_else(|| PyValueError::new_err("__dlpack__ may only be called once"))?;

        let rows = self.rows;
        let cols = self.cols;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
1430
/// Python `sar_cuda_batch_dev`: runs the SAR parameter sweep on a CUDA device.
///
/// Takes float32 `high`/`low` series of equal length and `(start, end, step)` ranges for
/// acceleration and maximum. Returns the device-resident result matrix without copying
/// it back to the host. The launch runs with the GIL released.
///
/// # Errors
/// `ValueError` if CUDA is unavailable, the inputs differ in length, or the CUDA
/// backend reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "sar_cuda_batch_dev")]
#[pyo3(signature = (high_f32, low_f32, acceleration_range, maximum_range, device_id=0))]
pub fn sar_cuda_batch_dev_py(
    py: Python<'_>,
    high_f32: numpy::PyReadonlyArray1<'_, f32>,
    low_f32: numpy::PyReadonlyArray1<'_, f32>,
    acceleration_range: (f64, f64, f64),
    maximum_range: (f64, f64, f64),
    device_id: usize,
) -> PyResult<SarDeviceArrayF32Py> {
    use crate::cuda::{cuda_available, CudaSar};

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let (h, l) = (high_f32.as_slice()?, low_f32.as_slice()?);
    if h.len() != l.len() {
        return Err(PyValueError::new_err("high/low length mismatch"));
    }
    let sweep = SarBatchRange {
        acceleration: acceleration_range,
        maximum: maximum_range,
    };

    let launched = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaSar::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let (dev, _) = cuda
            .sar_batch_dev(h, l, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((
            dev.buf,
            dev.rows,
            dev.cols,
            cuda.context_arc(),
            cuda.device_id(),
        ))
    });
    let (buf, rows, cols, ctx, dev_ordinal) = launched?;

    Ok(SarDeviceArrayF32Py {
        buf: Some(buf),
        rows,
        cols,
        _ctx: ctx,
        device_id: dev_ordinal,
    })
}
1473
/// Python `sar_cuda_many_series_one_param_dev`: SAR for many series with one parameter set.
///
/// `high_tm_f32` / `low_tm_f32` are flat time-major buffers of `cols * rows` float32
/// values each. The result stays on the device. The launch runs with the GIL released.
///
/// # Errors
/// `ValueError` if CUDA is unavailable, `cols * rows` overflows, the buffer lengths do
/// not match `cols * rows`, or the CUDA backend reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "sar_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm_f32, low_tm_f32, cols, rows, acceleration, maximum, device_id=0))]
pub fn sar_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    high_tm_f32: numpy::PyReadonlyArray1<'_, f32>,
    low_tm_f32: numpy::PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    acceleration: f64,
    maximum: f64,
    device_id: usize,
) -> PyResult<SarDeviceArrayF32Py> {
    use crate::cuda::{cuda_available, CudaSar};

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let high_tm = high_tm_f32.as_slice()?;
    let low_tm = low_tm_f32.as_slice()?;
    let expected = match cols.checked_mul(rows) {
        Some(n) => n,
        None => {
            return Err(PyValueError::new_err(
                "sar_cuda_many_series_one_param_dev: size overflow in cols*rows",
            ))
        }
    };
    let shapes_agree = high_tm.len() == expected && high_tm.len() == low_tm.len();
    if !shapes_agree {
        return Err(PyValueError::new_err(
            "time‑major inputs must be equal length and cols*rows",
        ));
    }
    let params = SarParams {
        acceleration: Some(acceleration),
        maximum: Some(maximum),
    };

    let launched = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaSar::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let dev = cuda
            .sar_many_series_one_param_time_major_dev(high_tm, low_tm, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((
            dev.buf,
            dev.rows,
            dev.cols,
            cuda.context_arc(),
            cuda.device_id(),
        ))
    });
    let (buf, out_rows, out_cols, ctx, dev_ordinal) = launched?;

    Ok(SarDeviceArrayF32Py {
        buf: Some(buf),
        rows: out_rows,
        cols: out_cols,
        _ctx: ctx,
        device_id: dev_ordinal,
    })
}
1523
/// Raw-pointer SAR entry point for JS callers that manage their own wasm memory.
///
/// Reads `len` values from `high_ptr` and `low_ptr` and writes `len` SAR values to
/// `out_ptr`. If the output byte range overlaps either input range at all, the result
/// is computed into a scratch buffer and copied afterwards. This covers the in-place
/// case `out_ptr == high_ptr` and also partial overlaps. It guarantees that an exclusive
/// output slice never coexists with shared input slices over the same memory.
///
/// The pointers must reference `len` valid, aligned f64 values (e.g. from `sar_alloc`);
/// this cannot be checked here.
///
/// # Errors
/// Returns a JS string error for a null pointer or any `SarError`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn sar_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    acceleration: f64,
    maximum: f64,
) -> Result<(), JsValue> {
    if high_ptr.is_null() || low_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to sar_into"));
    }

    // Half-open byte range [start, end) covered by `len` f64 values at `addr`.
    let byte_span = |addr: usize| {
        let end = addr.saturating_add(len.saturating_mul(std::mem::size_of::<f64>()));
        (addr, end)
    };
    let (out_start, out_end) = byte_span(out_ptr as usize);
    // Pointer equality alone would miss buffers that overlap without sharing a start.
    let overlaps_out = |addr: usize| {
        let (start, end) = byte_span(addr);
        start < out_end && out_start < end
    };
    let aliased = overlaps_out(high_ptr as usize) || overlaps_out(low_ptr as usize);

    let params = SarParams {
        acceleration: Some(acceleration),
        maximum: Some(maximum),
    };

    if aliased {
        let mut temp = vec![0.0; len];
        {
            // SAFETY: the caller guarantees both input pointers reference `len` readable,
            // aligned f64 values. These shared slices go out of scope before the
            // exclusive output slice is created below.
            let high_slice = unsafe { std::slice::from_raw_parts(high_ptr, len) };
            let low_slice = unsafe { std::slice::from_raw_parts(low_ptr, len) };
            let input = SarInput::from_slices(high_slice, low_slice, params);
            sar_into_slice(&mut temp, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
        // SAFETY: `out_ptr` references `len` writable f64 values and no input slice is live.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        out.copy_from_slice(&temp);
    } else {
        // SAFETY: the caller guarantees all three pointers reference `len` valid, aligned
        // f64 values. The output range is disjoint from both input ranges, so the
        // exclusive slice does not alias the shared ones.
        let (high_slice, low_slice, out) = unsafe {
            (
                std::slice::from_raw_parts(high_ptr, len),
                std::slice::from_raw_parts(low_ptr, len),
                std::slice::from_raw_parts_mut(out_ptr, len),
            )
        };
        let input = SarInput::from_slices(high_slice, low_slice, params);
        sar_into_slice(out, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(())
}
1563
/// Allocates uninitialised wasm memory for `len` f64 values and returns its pointer.
///
/// Ownership passes to the caller, who must release it with `sar_free(ptr, len)` using
/// the same `len`. For `len == 0` the returned pointer is dangling (non-null) and must
/// not be dereferenced.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn sar_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop suppresses the Vec destructor so the allocation outlives this call.
    let mut storage = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    storage.as_mut_ptr()
}
1572
/// Frees a buffer previously returned by `sar_alloc(len)`.
///
/// `len` must be exactly the value that was passed to `sar_alloc`. Null pointers are
/// ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn sar_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` came from `Vec::<f64>::with_capacity(len)` in `sar_alloc`, so the
    // allocation's capacity is exactly `len`. The length is 0 because the caller may
    // never have written the contents: `Vec::from_raw_parts` requires the first `length`
    // elements to be initialised. Since f64 has no destructor, nothing is skipped.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
1582
/// Raw-pointer batch SAR for JS callers: writes one `len`-long row per parameter
/// combination of the acceleration × maximum sweep into `out_ptr`.
///
/// `out_ptr` must reference `rows * len` writable f64 values, where `rows` is the number
/// of combinations (returned on success). If the output range overlaps either input
/// range, the matrix is computed into a scratch buffer first. This way the inputs are
/// never read after being overwritten, and an exclusive slice never aliases the shared
/// input slices.
///
/// # Errors
/// Returns a JS string error for a null pointer, an invalid range, a `rows * len`
/// overflow, or any `SarError`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn sar_batch_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    acc_start: f64,
    acc_end: f64,
    acc_step: f64,
    max_start: f64,
    max_end: f64,
    max_step: f64,
) -> Result<usize, JsValue> {
    if high_ptr.is_null() || low_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to sar_batch_into"));
    }
    let sweep = SarBatchRange {
        acceleration: (acc_start, acc_end, acc_step),
        maximum: (max_start, max_end, max_step),
    };
    let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let cols = len;
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("sar_batch_into: size overflow in rows*cols"))?;

    // Half-open byte ranges used for the overlap test.
    const ELEM: usize = std::mem::size_of::<f64>();
    let out_start = out_ptr as usize;
    let out_end = out_start.saturating_add(total.saturating_mul(ELEM));
    let overlaps_out = |ptr: *const f64| {
        let start = ptr as usize;
        let end = start.saturating_add(cols.saturating_mul(ELEM));
        start < out_end && out_start < end
    };
    let aliased = overlaps_out(high_ptr) || overlaps_out(low_ptr);

    // Computes the batch into `dst`; the input slices live only for this call.
    let run = |dst: &mut [f64]| -> Result<(), JsValue> {
        // SAFETY: the caller guarantees both input pointers reference `cols` readable,
        // aligned f64 values.
        let high = unsafe { std::slice::from_raw_parts(high_ptr, cols) };
        let low = unsafe { std::slice::from_raw_parts(low_ptr, cols) };
        sar_batch_inner_into_noalloc_wasm(high, low, &sweep, Kernel::Scalar, false, dst)
            .map(|_| ())
            .map_err(|e| JsValue::from_str(&e.to_string()))
    };

    if aliased {
        let mut temp = vec![0.0; total];
        run(&mut temp)?;
        // SAFETY: `out_ptr` references `total` writable f64 values and no input slice
        // is live at this point.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };
        out.copy_from_slice(&temp);
    } else {
        // SAFETY: `out_ptr` references `total` writable f64 values whose range is
        // disjoint from both inputs, so this exclusive slice aliases nothing.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };
        run(out)?;
    }
    Ok(rows)
}
1621
/// wasm32 batch writer: fills `out` (`rows * cols`) with one scalar-kernel SAR row per
/// parameter combination in `sweep`, without allocating the output matrix.
///
/// `high`/`low` are trimmed to their common length, which becomes `cols`. `_kern` and
/// `_parallel` are accepted for signature parity with the native writer but ignored.
/// This function is compiled only for wasm32, where only the scalar kernel is used and
/// rows are processed sequentially. The first `first + 1` cells of each row are
/// initialised by `init_matrix_prefixes` (presumably NaN warm-up).
///
/// Returns the expanded parameter combinations in row order.
///
/// # Errors
/// * any error from `expand_grid`;
/// * `SarError::AllValuesNaN` when no bar has both `high` and `low` non-NaN;
/// * `SarError::NotEnoughValidData` when fewer than two bars remain from `first`;
/// * `SarError::InvalidRange` when `rows * cols` overflows `usize`;
/// * `SarError::OutputLengthMismatch` when `out.len() != rows * cols`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
fn sar_batch_inner_into_noalloc_wasm(
    high: &[f64],
    low: &[f64],
    sweep: &SarBatchRange,
    _kern: Kernel,
    _parallel: bool,
    out: &mut [f64],
) -> Result<Vec<SarParams>, SarError> {
    let combos = expand_grid(sweep)?;

    let min_len = high.len().min(low.len());
    let (high, low) = (&high[..min_len], &low[..min_len]);

    let first = high
        .iter()
        .zip(low.iter())
        .position(|(&h, &l)| !h.is_nan() && !l.is_nan())
        .ok_or(SarError::AllValuesNaN)?;
    if high.len() - first < 2 {
        return Err(SarError::NotEnoughValidData {
            needed: 2,
            valid: high.len() - first,
        });
    }

    let rows = combos.len();
    let cols = high.len();
    let expected = rows.checked_mul(cols).ok_or(SarError::InvalidRange {
        start: sweep.acceleration.0,
        end: sweep.acceleration.1,
        step: sweep.acceleration.2,
    })?;
    if out.len() != expected {
        return Err(SarError::OutputLengthMismatch {
            expected,
            got: out.len(),
        });
    }

    // SAFETY: `out` is a valid, exclusively borrowed `f64` slice; viewing it as
    // `MaybeUninit<f64>` of the same length is layout-compatible and only writes occur.
    unsafe {
        let out_mu = std::slice::from_raw_parts_mut(
            out.as_mut_ptr() as *mut std::mem::MaybeUninit<f64>,
            out.len(),
        );
        let warm: Vec<usize> = vec![first + 1; rows];
        init_matrix_prefixes(out_mu, cols, &warm);
    }

    // wasm32 has no thread pool here, so rows are always computed in order.
    for (row_out, p) in out.chunks_mut(cols).zip(combos.iter()) {
        // SAFETY: `high` and `low` both have length `cols`, `first + 1 < cols`, and each
        // `row_out` chunk is exactly `cols` long.
        unsafe {
            sar_row_scalar(
                high,
                low,
                first,
                p.acceleration.unwrap(),
                p.maximum.unwrap(),
                row_out,
            );
        }
    }

    Ok(combos)
}
1705
/// Sweep configuration accepted by the JS `sar_batch` binding (deserialised from a JS
/// value via `serde_wasm_bindgen`).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct SarBatchConfig {
    /// `(start, end, step)` range for the acceleration factor.
    pub acceleration_range: (f64, f64, f64),
    /// `(start, end, step)` range for the maximum acceleration.
    pub maximum_range: (f64, f64, f64),
}
1712
/// Result object returned to JS by the `sar_batch` binding.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct SarBatchJsOutput {
    /// Flattened `rows * cols` SAR matrix (presumably row-major, one row per entry of
    /// `combos` — confirm against `sar_batch_inner`).
    pub values: Vec<f64>,
    /// Parameter combination associated with each row.
    pub combos: Vec<SarParams>,
    /// Number of parameter combinations.
    pub rows: usize,
    /// Number of bars per row.
    pub cols: usize,
}
1721
/// JS `sar_batch`: computes SAR for every combination in the configured sweep.
///
/// `config` must deserialise into `SarBatchConfig`, i.e.
/// `{ acceleration_range: [start, end, step], maximum_range: [start, end, step] }`.
/// Returns an object with `values`, `combos`, `rows` and `cols`.
///
/// # Errors
/// Rejects with a JS string for an invalid config or any `SarError`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = sar_batch)]
pub fn sar_batch_unified_js(
    high: &[f64],
    low: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let config: SarBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = SarBatchRange {
        acceleration: config.acceleration_range,
        maximum: config.maximum_range,
    };

    // This binding is compiled only for wasm32, so the scalar kernel is always used;
    // a runtime `cfg!(target_arch = "wasm32")` check could never take another branch.
    let output = sar_batch_inner(high, low, &sweep, Kernel::Scalar, false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let js_output = SarBatchJsOutput {
        values: output.values,
        combos: output.combos,
        rows: output.rows,
        cols: output.cols,
    };

    serde_wasm_bindgen::to_value(&js_output).map_err(|e| JsValue::from_str(&e.to_string()))
}
1754
1755#[cfg(test)]
1756mod tests {
1757    use super::*;
1758    use crate::skip_if_unsupported;
1759    use crate::utilities::data_loader::read_candles_from_csv;
1760
1761    fn check_sar_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1762        skip_if_unsupported!(kernel, test_name);
1763        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1764        let candles = read_candles_from_csv(file_path)?;
1765
1766        let default_params = SarParams {
1767            acceleration: None,
1768            maximum: None,
1769        };
1770        let input = SarInput::from_candles(&candles, default_params);
1771        let output = sar_with_kernel(&input, kernel)?;
1772        assert_eq!(output.values.len(), candles.close.len());
1773
1774        Ok(())
1775    }
1776
1777    fn check_sar_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1778        skip_if_unsupported!(kernel, test_name);
1779        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1780        let candles = read_candles_from_csv(file_path)?;
1781
1782        let input = SarInput::from_candles(&candles, SarParams::default());
1783        let result = sar_with_kernel(&input, kernel)?;
1784        let expected_last_five = [
1785            60370.00224209362,
1786            60220.362107568006,
1787            60079.70038111392,
1788            59947.478358247085,
1789            59823.189656752256,
1790        ];
1791        let start = result.values.len().saturating_sub(5);
1792        for (i, &val) in result.values[start..].iter().enumerate() {
1793            let diff = (val - expected_last_five[i]).abs();
1794            assert!(
1795                diff < 1e-4,
1796                "[{}] SAR {:?} mismatch at idx {}: got {}, expected {}",
1797                test_name,
1798                kernel,
1799                i,
1800                val,
1801                expected_last_five[i]
1802            );
1803        }
1804        Ok(())
1805    }
1806
1807    fn check_sar_from_slices(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1808        skip_if_unsupported!(kernel, test_name);
1809        let high = [50000.0, 50500.0, 51000.0];
1810        let low = [49000.0, 49500.0, 49900.0];
1811        let params = SarParams::default();
1812        let input = SarInput::from_slices(&high, &low, params);
1813        let result = sar_with_kernel(&input, kernel)?;
1814        assert_eq!(result.values.len(), high.len());
1815        Ok(())
1816    }
1817
1818    #[test]
1819    fn test_sar_into_slice_trims_mismatched_lengths() -> Result<(), Box<dyn Error>> {
1820        let high = [50000.0, 50500.0, 51000.0, 52000.0];
1821        let low = [49000.0, 49500.0, 49900.0];
1822        let params = SarParams::default();
1823        let input = SarInput::from_slices(&high, &low, params);
1824
1825        let expected = sar_with_kernel(&input, Kernel::Scalar)?;
1826        assert_eq!(expected.values.len(), low.len());
1827
1828        let mut out = vec![0.0; low.len()];
1829        sar_into_slice(&mut out, &input, Kernel::Scalar)?;
1830
1831        for i in 0..out.len() {
1832            let a = out[i];
1833            let b = expected.values[i];
1834            assert!(
1835                (a.is_nan() && b.is_nan()) || a == b,
1836                "mismatch at {}: {} vs {}",
1837                i,
1838                a,
1839                b
1840            );
1841        }
1842
1843        Ok(())
1844    }
1845
1846    fn check_sar_all_nan(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1847        skip_if_unsupported!(kernel, test_name);
1848        let high = [f64::NAN, f64::NAN, f64::NAN];
1849        let low = [f64::NAN, f64::NAN, f64::NAN];
1850        let params = SarParams::default();
1851        let input = SarInput::from_slices(&high, &low, params);
1852        let result = sar_with_kernel(&input, kernel);
1853        assert!(result.is_err());
1854        if let Err(e) = result {
1855            assert!(e.to_string().contains("All values are NaN"));
1856        }
1857        Ok(())
1858    }
1859
1860    #[cfg(debug_assertions)]
1861    fn check_sar_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1862        skip_if_unsupported!(kernel, test_name);
1863
1864        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1865        let candles = read_candles_from_csv(file_path)?;
1866
1867        let test_params = vec![
1868            SarParams::default(),
1869            SarParams {
1870                acceleration: Some(0.001),
1871                maximum: Some(0.001),
1872            },
1873            SarParams {
1874                acceleration: Some(0.01),
1875                maximum: Some(0.1),
1876            },
1877            SarParams {
1878                acceleration: Some(0.02),
1879                maximum: Some(0.3),
1880            },
1881            SarParams {
1882                acceleration: Some(0.05),
1883                maximum: Some(0.2),
1884            },
1885            SarParams {
1886                acceleration: Some(0.05),
1887                maximum: Some(0.5),
1888            },
1889            SarParams {
1890                acceleration: Some(0.1),
1891                maximum: Some(0.5),
1892            },
1893            SarParams {
1894                acceleration: Some(0.1),
1895                maximum: Some(0.9),
1896            },
1897            SarParams {
1898                acceleration: Some(0.2),
1899                maximum: Some(0.9),
1900            },
1901            SarParams {
1902                acceleration: Some(0.001),
1903                maximum: Some(0.9),
1904            },
1905            SarParams {
1906                acceleration: Some(0.2),
1907                maximum: Some(0.01),
1908            },
1909        ];
1910
1911        for (param_idx, params) in test_params.iter().enumerate() {
1912            let input = SarInput::from_candles(&candles, params.clone());
1913            let output = sar_with_kernel(&input, kernel)?;
1914
1915            for (i, &val) in output.values.iter().enumerate() {
1916                if val.is_nan() {
1917                    continue;
1918                }
1919
1920                let bits = val.to_bits();
1921
1922                if bits == 0x11111111_11111111 {
1923                    panic!(
1924                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1925						 with params: acceleration={}, maximum={} (param set {})",
1926                        test_name,
1927                        val,
1928                        bits,
1929                        i,
1930                        params.acceleration.unwrap_or(0.02),
1931                        params.maximum.unwrap_or(0.2),
1932                        param_idx
1933                    );
1934                }
1935
1936                if bits == 0x22222222_22222222 {
1937                    panic!(
1938                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1939						 with params: acceleration={}, maximum={} (param set {})",
1940                        test_name,
1941                        val,
1942                        bits,
1943                        i,
1944                        params.acceleration.unwrap_or(0.02),
1945                        params.maximum.unwrap_or(0.2),
1946                        param_idx
1947                    );
1948                }
1949
1950                if bits == 0x33333333_33333333 {
1951                    panic!(
1952                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1953						 with params: acceleration={}, maximum={} (param set {})",
1954                        test_name,
1955                        val,
1956                        bits,
1957                        i,
1958                        params.acceleration.unwrap_or(0.02),
1959                        params.maximum.unwrap_or(0.2),
1960                        param_idx
1961                    );
1962                }
1963            }
1964        }
1965
1966        Ok(())
1967    }
1968
    /// Release-build stand-in: the poison scan is compiled only with `debug_assertions`,
    /// so this stub keeps the test table compiling and always succeeds.
    #[cfg(not(debug_assertions))]
    fn check_sar_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
1973
1974    #[cfg(feature = "proptest")]
1975    #[allow(clippy::float_cmp)]
1976    fn check_sar_property(
1977        test_name: &str,
1978        kernel: Kernel,
1979    ) -> Result<(), Box<dyn std::error::Error>> {
1980        use proptest::prelude::*;
1981        skip_if_unsupported!(kernel, test_name);
1982
1983        let strat = (0.001f64..0.5f64)
1984            .prop_flat_map(|acceleration| (Just(acceleration), acceleration..1.0f64))
1985            .prop_flat_map(|(acceleration, maximum)| {
1986                (
1987                    prop::collection::vec(
1988                        (1.0f64..1e6f64).prop_filter("finite price", |x| x.is_finite() && *x > 0.0),
1989                        10..400,
1990                    ),
1991                    Just(acceleration),
1992                    Just(maximum),
1993                    0.001f64..0.1f64,
1994                )
1995            });
1996
1997        proptest::test_runner::TestRunner::default().run(
1998            &strat,
1999            |(base_prices, acceleration, maximum, volatility)| {
2000                let mut high = Vec::with_capacity(base_prices.len());
2001                let mut low = Vec::with_capacity(base_prices.len());
2002
2003                let mut spread_factor = 1.0;
2004                for price in &base_prices {
2005                    spread_factor = (spread_factor + (price % 0.1 - 0.05) * 0.2)
2006                        .max(0.5)
2007                        .min(2.0);
2008                    let spread = price * volatility * spread_factor;
2009                    high.push(price + spread);
2010                    low.push(price - spread);
2011                }
2012
2013                let params = SarParams {
2014                    acceleration: Some(acceleration),
2015                    maximum: Some(maximum),
2016                };
2017                let input = SarInput::from_slices(&high, &low, params.clone());
2018
2019                let SarOutput { values: out } = sar_with_kernel(&input, kernel).unwrap();
2020
2021                let SarOutput { values: ref_out } =
2022                    sar_with_kernel(&input, Kernel::Scalar).unwrap();
2023
2024                for i in 1..out.len() {
2025                    if !out[i].is_nan() {
2026                        let min_price = low.iter().cloned().fold(f64::INFINITY, f64::min);
2027                        let max_price = high.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
2028
2029                        prop_assert!(
2030                            out[i] >= min_price - 1e-9 && out[i] <= max_price + 1e-9,
2031                            "SAR[{}] = {} is outside range [{}, {}]",
2032                            i,
2033                            out[i],
2034                            min_price,
2035                            max_price
2036                        );
2037                    }
2038                }
2039
2040                prop_assert!(
2041                    out[0].is_nan(),
2042                    "First SAR value should be NaN during warmup, got {}",
2043                    out[0]
2044                );
2045
2046                for i in 0..out.len() {
2047                    let y = out[i];
2048                    let r = ref_out[i];
2049
2050                    if y.is_nan() {
2051                        prop_assert!(
2052                            r.is_nan(),
2053                            "NaN mismatch at index {}: test={}, ref={}",
2054                            i,
2055                            y,
2056                            r
2057                        );
2058                    } else if r.is_nan() {
2059                        prop_assert!(
2060                            y.is_nan(),
2061                            "NaN mismatch at index {}: test={}, ref={}",
2062                            i,
2063                            y,
2064                            r
2065                        );
2066                    } else {
2067                        let diff = (y - r).abs();
2068                        prop_assert!(
2069                            diff < 1e-9,
2070                            "Kernel mismatch at index {}: test={}, ref={}, diff={}",
2071                            i,
2072                            y,
2073                            r,
2074                            diff
2075                        );
2076                    }
2077                }
2078
2079                if out.len() > 10 {
2080                    let mut last_movement = 0.0;
2081                    let mut increasing_count = 0;
2082
2083                    for i in 2..out.len().min(20) {
2084                        if !out[i].is_nan() && !out[i - 1].is_nan() {
2085                            let movement = (out[i] - out[i - 1]).abs();
2086                            if movement > last_movement {
2087                                increasing_count += 1;
2088                            }
2089                            last_movement = movement;
2090                        }
2091                    }
2092
2093                    prop_assert!(
2094                        increasing_count > 0 || out.len() < 5,
2095                        "SAR acceleration never increases (count: {})",
2096                        increasing_count
2097                    );
2098                }
2099
2100                let strong_uptrend = high.windows(2).all(|w| w[1] > w[0] + 1e-9)
2101                    && low.windows(2).all(|w| w[1] > w[0] + 1e-9);
2102                if strong_uptrend && out.len() > 10 {
2103                    let start = out.len() * 3 / 4;
2104                    for i in start..out.len() {
2105                        if !out[i].is_nan() {
2106                            prop_assert!(
2107                                out[i] <= low[i] + 1e-6,
2108                                "In uptrend, SAR[{}] = {} should be <= low[{}] = {}",
2109                                i,
2110                                out[i],
2111                                i,
2112                                low[i]
2113                            );
2114                        }
2115                    }
2116                }
2117
2118                let strong_downtrend = high.windows(2).all(|w| w[1] < w[0] - 1e-9)
2119                    && low.windows(2).all(|w| w[1] < w[0] - 1e-9);
2120                if strong_downtrend && out.len() > 10 {
2121                    let start = out.len() * 3 / 4;
2122                    for i in start..out.len() {
2123                        if !out[i].is_nan() {
2124                            prop_assert!(
2125                                out[i] >= high[i] - 1e-6,
2126                                "In downtrend, SAR[{}] = {} should be >= high[{}] = {}",
2127                                i,
2128                                out[i],
2129                                i,
2130                                high[i]
2131                            );
2132                        }
2133                    }
2134                }
2135
2136                if out.len() > 5 {
2137                    for i in 2..out.len() {
2138                        if !out[i].is_nan() && !out[i - 1].is_nan() {
2139                            let jump = (out[i] - out[i - 1]).abs();
2140                            let avg_price = (high[i] + low[i]) / 2.0;
2141
2142                            if jump > avg_price * 0.05 {
2143                                let prev_below = out[i - 1] < low[i - 1];
2144                                let curr_below = out[i] < low[i];
2145
2146                                prop_assert!(
2147                                    prev_below != curr_below || jump > avg_price * 0.03,
2148                                    "Large SAR jump without proper reversal at index {}",
2149                                    i
2150                                );
2151                            }
2152                        }
2153                    }
2154                }
2155
2156                if base_prices.windows(2).all(|w| w[1] > w[0]) && out.len() > 20 {
2157                    let quarter = out.len() / 4;
2158                    let first_quarter: Vec<f64> = out[quarter..quarter * 2]
2159                        .iter()
2160                        .filter(|v| !v.is_nan())
2161                        .cloned()
2162                        .collect();
2163                    let last_quarter: Vec<f64> = out[quarter * 3..]
2164                        .iter()
2165                        .filter(|v| !v.is_nan())
2166                        .cloned()
2167                        .collect();
2168
2169                    if !first_quarter.is_empty() && !last_quarter.is_empty() {
2170                        let first_avg =
2171                            first_quarter.iter().sum::<f64>() / first_quarter.len() as f64;
2172                        let last_avg = last_quarter.iter().sum::<f64>() / last_quarter.len() as f64;
2173
2174                        prop_assert!(
2175							last_avg >= first_avg * 0.95,
2176							"For increasing prices, SAR should generally trend up: first_avg={}, last_avg={}",
2177							first_avg, last_avg
2178						);
2179                    }
2180                }
2181
2182                #[cfg(debug_assertions)]
2183                {
2184                    for (i, &val) in out.iter().enumerate() {
2185                        if !val.is_nan() {
2186                            let bits = val.to_bits();
2187                            prop_assert!(
2188                                bits != 0x11111111_11111111
2189                                    && bits != 0x22222222_22222222
2190                                    && bits != 0x33333333_33333333,
2191                                "Found poison value at index {}: {} (0x{:016X})",
2192                                i,
2193                                val,
2194                                bits
2195                            );
2196                        }
2197                    }
2198                }
2199
2200                Ok(())
2201            },
2202        )?;
2203
2204        Ok(())
2205    }
2206
2207    macro_rules! generate_all_sar_tests {
2208        ($($test_fn:ident),*) => {
2209            paste::paste! {
2210                $(
2211                    #[test]
2212                    fn [<$test_fn _scalar_f64>]() {
2213                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
2214                    }
2215                )*
2216                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2217                $(
2218                    #[test]
2219                    fn [<$test_fn _avx2_f64>]() {
2220                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
2221                    }
2222                    #[test]
2223                    fn [<$test_fn _avx512_f64>]() {
2224                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
2225                    }
2226                )*
2227            }
2228        }
2229    }
2230
2231    generate_all_sar_tests!(
2232        check_sar_partial_params,
2233        check_sar_accuracy,
2234        check_sar_from_slices,
2235        check_sar_all_nan,
2236        check_sar_no_poison
2237    );
2238
2239    #[cfg(feature = "proptest")]
2240    generate_all_sar_tests!(check_sar_property);
2241
2242    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2243        skip_if_unsupported!(kernel, test);
2244
2245        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2246        let c = read_candles_from_csv(file)?;
2247
2248        let output = SarBatchBuilder::new().kernel(kernel).apply_candles(&c)?;
2249
2250        let def = SarParams::default();
2251        let row = output.values_for(&def).expect("default row missing");
2252
2253        assert_eq!(row.len(), c.close.len());
2254
2255        let expected = [
2256            60370.00224209362,
2257            60220.362107568006,
2258            60079.70038111392,
2259            59947.478358247085,
2260            59823.189656752256,
2261        ];
2262        let start = row.len() - 5;
2263        for (i, &v) in row[start..].iter().enumerate() {
2264            assert!(
2265                (v - expected[i]).abs() < 1e-4,
2266                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
2267            );
2268        }
2269        Ok(())
2270    }
2271
2272    #[cfg(debug_assertions)]
2273    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2274        skip_if_unsupported!(kernel, test);
2275
2276        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2277        let c = read_candles_from_csv(file)?;
2278
2279        let test_configs = vec![
2280            (0.01, 0.05, 0.01, 0.1, 0.3, 0.1),
2281            (0.02, 0.1, 0.02, 0.2, 0.5, 0.1),
2282            (0.05, 0.2, 0.05, 0.3, 0.9, 0.2),
2283            (0.001, 0.005, 0.001, 0.05, 0.1, 0.05),
2284            (0.02, 0.02, 0.0, 0.2, 0.2, 0.0),
2285            (0.1, 0.2, 0.025, 0.5, 0.9, 0.1),
2286            (0.001, 0.01, 0.003, 0.1, 0.5, 0.2),
2287        ];
2288
2289        for (cfg_idx, &(a_start, a_end, a_step, m_start, m_end, m_step)) in
2290            test_configs.iter().enumerate()
2291        {
2292            let output = SarBatchBuilder::new()
2293                .kernel(kernel)
2294                .acceleration_range(a_start, a_end, a_step)
2295                .maximum_range(m_start, m_end, m_step)
2296                .apply_candles(&c)?;
2297
2298            for (idx, &val) in output.values.iter().enumerate() {
2299                if val.is_nan() {
2300                    continue;
2301                }
2302
2303                let bits = val.to_bits();
2304                let row = idx / output.cols;
2305                let col = idx % output.cols;
2306                let combo = &output.combos[row];
2307
2308                if bits == 0x11111111_11111111 {
2309                    panic!(
2310                        "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
2311						 at row {} col {} (flat index {}) with params: acceleration={}, maximum={}",
2312                        test,
2313                        cfg_idx,
2314                        val,
2315                        bits,
2316                        row,
2317                        col,
2318                        idx,
2319                        combo.acceleration.unwrap_or(0.02),
2320                        combo.maximum.unwrap_or(0.2)
2321                    );
2322                }
2323
2324                if bits == 0x22222222_22222222 {
2325                    panic!(
2326                        "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
2327						 at row {} col {} (flat index {}) with params: acceleration={}, maximum={}",
2328                        test,
2329                        cfg_idx,
2330                        val,
2331                        bits,
2332                        row,
2333                        col,
2334                        idx,
2335                        combo.acceleration.unwrap_or(0.02),
2336                        combo.maximum.unwrap_or(0.2)
2337                    );
2338                }
2339
2340                if bits == 0x33333333_33333333 {
2341                    panic!(
2342                        "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
2343						 at row {} col {} (flat index {}) with params: acceleration={}, maximum={}",
2344                        test,
2345                        cfg_idx,
2346                        val,
2347                        bits,
2348                        row,
2349                        col,
2350                        idx,
2351                        combo.acceleration.unwrap_or(0.02),
2352                        combo.maximum.unwrap_or(0.2)
2353                    );
2354                }
2355            }
2356        }
2357
2358        Ok(())
2359    }
2360
2361    #[cfg(not(debug_assertions))]
2362    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2363        Ok(())
2364    }
2365
2366    macro_rules! gen_batch_tests {
2367        ($fn_name:ident) => {
2368            paste::paste! {
2369                #[test] fn [<$fn_name _scalar>]()      {
2370                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
2371                }
2372                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2373                #[test] fn [<$fn_name _avx2>]()        {
2374                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
2375                }
2376                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2377                #[test] fn [<$fn_name _avx512>]()      {
2378                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
2379                }
2380                #[test] fn [<$fn_name _auto_detect>]() {
2381                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
2382                }
2383            }
2384        };
2385    }
2386    gen_batch_tests!(check_batch_default_row);
2387    gen_batch_tests!(check_batch_no_poison);
2388
2389    #[test]
2390    fn test_sar_into_matches_api() -> Result<(), Box<dyn Error>> {
2391        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2392        let candles = read_candles_from_csv(file_path)?;
2393
2394        let input = SarInput::from_candles(&candles, SarParams::default());
2395
2396        let SarOutput { values: expected } = sar(&input)?;
2397
2398        let mut actual = vec![0.0; candles.high.len()];
2399        #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
2400        {
2401            sar_into(&input, &mut actual)?;
2402        }
2403        #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2404        {
2405            sar_into_slice(&mut actual, &input, Kernel::Auto)?;
2406        }
2407
2408        assert_eq!(expected.len(), actual.len());
2409
2410        fn eq_or_both_nan(a: f64, b: f64) -> bool {
2411            (a.is_nan() && b.is_nan()) || (a == b)
2412        }
2413
2414        for i in 0..expected.len() {
2415            assert!(
2416                eq_or_both_nan(expected[i], actual[i]),
2417                "Mismatch at index {}: expected {:?}, got {:?}",
2418                i,
2419                expected[i],
2420                actual[i]
2421            );
2422        }
2423        Ok(())
2424    }
2425}