Skip to main content

vector_ta/indicators/
acosc.rs

1use crate::utilities::data_loader::Candles;
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{
4    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
5    make_uninit_matrix,
6};
7#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
8use core::arch::x86_64::*;
9use std::mem::ManuallyDrop;
10use thiserror::Error;
11
/// Input source for the Accelerator Oscillator (AC).
///
/// Either a full candle set (the "high" and "low" fields are selected from it when the
/// indicator runs) or a pair of borrowed high/low slices.
#[derive(Debug, Clone)]
pub enum AcoscData<'a> {
    /// Read the "high" and "low" fields from `candles`.
    Candles { candles: &'a Candles },
    /// Use these slices directly; they must have the same length.
    Slices { high: &'a [f64], low: &'a [f64] },
}

/// Parameters for AC. The periods (5/34 on the median price, 5 on the AO) are fixed in the
/// kernels, so there is nothing to configure.
#[derive(Debug, Clone, Default)]
pub struct AcoscParams {}

/// Data source plus parameters, consumed by `acosc` / `acosc_with_kernel`.
#[derive(Debug, Clone)]
pub struct AcoscInput<'a> {
    pub data: AcoscData<'a>,
    pub params: AcoscParams,
}
26impl<'a> AcoscInput<'a> {
27    #[inline]
28    pub fn from_candles(candles: &'a Candles, params: AcoscParams) -> Self {
29        Self {
30            data: AcoscData::Candles { candles },
31            params,
32        }
33    }
34    #[inline]
35    pub fn from_slices(high: &'a [f64], low: &'a [f64], params: AcoscParams) -> Self {
36        Self {
37            data: AcoscData::Slices { high, low },
38            params,
39        }
40    }
41    #[inline]
42    pub fn with_default_candles(candles: &'a Candles) -> Self {
43        Self {
44            data: AcoscData::Candles { candles },
45            params: AcoscParams::default(),
46        }
47    }
48}
49
/// Result of `acosc` / `acosc_with_kernel`.
///
/// Both vectors have the input's length. The warmup prefix (first fully valid bar + 38) is
/// produced by `alloc_with_nan_prefix`; every later index is written by the kernel.
#[derive(Debug, Clone)]
pub struct AcoscOutput {
    /// Accelerator Oscillator values: `AO - SMA5(AO)`.
    pub osc: Vec<f64>,
    /// Bar-to-bar change of `osc` (the first emitted value is measured against `0.0`).
    pub change: Vec<f64>,
}
55
/// Errors returned by the AC entry points in this module.
#[derive(Debug, Error)]
pub enum AcoscError {
    /// `select_candle_field` failed for "high" or "low"; `msg` holds the underlying error text.
    #[error("acosc: Failed to get high/low fields from candles: {msg}")]
    CandleFieldError { msg: String },
    /// The high and low series have different lengths.
    #[error(
        "acosc: Mismatch in high/low candle data lengths: high_len={high_len}, low_len={low_len}"
    )]
    LengthMismatch { high_len: usize, low_len: usize },
    /// The input series are empty.
    #[error("acosc: Empty input data")]
    EmptyInputData,
    /// No bar has both `high` and `low` non-NaN.
    #[error("acosc: Not enough data: all values are NaN")]
    AllValuesNaN,
    /// NOTE(review): not constructed in the visible part of this file; AC has no period
    /// parameter, so this appears to be kept for parity with other indicators.
    #[error("acosc: Invalid period: period={period}, data_len={data_len}")]
    InvalidPeriod { period: usize, data_len: usize },
    /// Fewer than `needed` bars remain from the first fully valid bar onward.
    #[error("acosc: Not enough data: needed={needed}, valid={valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// An output buffer length differs from the expected length (per the message).
    /// NOTE(review): not constructed in the visible part of this file — confirm its callers.
    #[error("acosc: Output length mismatch: expected={expected}, got={got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// Batch range / sizing error.
    #[error("acosc: Invalid range: start={start}, end={end}, step={step}")]
    InvalidRange { start: i64, end: i64, step: i64 },
    /// A non-batch kernel was passed to `acosc_batch_with_kernel`.
    #[error("acosc: Invalid kernel for batch operation. Expected batch kernel, got: {0:?}")]
    InvalidKernelForBatch(Kernel),

    /// NOTE(review): not constructed in the visible part of this file.
    #[error("acosc: Not enough data points: required={required}, actual={actual}")]
    NotEnoughData { required: usize, actual: usize },
    /// Same meaning and message as `InvalidKernelForBatch`, with a named field.
    /// NOTE(review): looks like a duplicate variant; not constructed in the visible code.
    #[error("acosc: Invalid kernel for batch operation. Expected batch kernel, got: {kernel:?}")]
    InvalidBatchKernel { kernel: Kernel },
}
84
85#[inline]
86pub fn acosc(input: &AcoscInput) -> Result<AcoscOutput, AcoscError> {
87    acosc_with_kernel(input, Kernel::Auto)
88}
89
90#[inline(always)]
91fn acosc_prepare<'a>(
92    input: &'a AcoscInput,
93    kernel: Kernel,
94) -> Result<(&'a [f64], &'a [f64], usize, Kernel), AcoscError> {
95    let (high, low) = match &input.data {
96        AcoscData::Candles { candles } => {
97            let h = candles
98                .select_candle_field("high")
99                .map_err(|e| AcoscError::CandleFieldError { msg: e.to_string() })?;
100            let l = candles
101                .select_candle_field("low")
102                .map_err(|e| AcoscError::CandleFieldError { msg: e.to_string() })?;
103            (h, l)
104        }
105        AcoscData::Slices { high, low } => (*high, *low),
106    };
107
108    if high.len() != low.len() {
109        return Err(AcoscError::LengthMismatch {
110            high_len: high.len(),
111            low_len: low.len(),
112        });
113    }
114
115    let len = high.len();
116    if len == 0 {
117        return Err(AcoscError::EmptyInputData);
118    }
119    const REQUIRED_LENGTH: usize = 39;
120
121    let first = (0..len)
122        .find(|&i| !high[i].is_nan() && !low[i].is_nan())
123        .unwrap_or(len);
124    let valid = len.saturating_sub(first);
125    if valid == 0 {
126        return Err(AcoscError::AllValuesNaN);
127    }
128    if valid < REQUIRED_LENGTH {
129        return Err(AcoscError::NotEnoughValidData {
130            needed: REQUIRED_LENGTH,
131            valid,
132        });
133    }
134
135    let chosen = match kernel {
136        Kernel::Auto => detect_best_kernel(),
137        other => other,
138    };
139    Ok((high, low, first, chosen))
140}
141pub fn acosc_with_kernel(input: &AcoscInput, kernel: Kernel) -> Result<AcoscOutput, AcoscError> {
142    let (high, low, first, chosen) = acosc_prepare(input, kernel)?;
143
144    let len = low.len();
145    const WARMUP: usize = 38;
146    let warmup_end = first + WARMUP;
147
148    let mut osc = alloc_with_nan_prefix(len, warmup_end);
149    let mut change = alloc_with_nan_prefix(len, warmup_end);
150
151    if first < len {
152        let valid_len = len - first;
153        if valid_len > WARMUP {
154            acosc_compute_into(
155                &high[first..],
156                &low[first..],
157                chosen,
158                &mut osc[first..],
159                &mut change[first..],
160            );
161        }
162    }
163
164    Ok(AcoscOutput { osc, change })
165}
166
167#[inline(always)]
168fn acosc_compute_into(
169    high: &[f64],
170    low: &[f64],
171    kernel: Kernel,
172    osc_out: &mut [f64],
173    change_out: &mut [f64],
174) {
175    unsafe {
176        match kernel {
177            Kernel::Scalar | Kernel::ScalarBatch => acosc_scalar(high, low, osc_out, change_out),
178            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
179            Kernel::Avx2 | Kernel::Avx2Batch => acosc_avx2(high, low, osc_out, change_out),
180            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
181            Kernel::Avx512 | Kernel::Avx512Batch => acosc_avx512(high, low, osc_out, change_out),
182            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
183            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
184                acosc_scalar(high, low, osc_out, change_out)
185            }
186            Kernel::Auto => {
187                unreachable!("Kernel::Auto should be resolved before calling compute_into")
188            }
189        }
190    }
191}
192
193#[inline(always)]
194pub fn acosc_scalar(high: &[f64], low: &[f64], osc: &mut [f64], change: &mut [f64]) {
195    const PERIOD_SMA5: usize = 5;
196    const PERIOD_SMA34: usize = 34;
197    const INV5: f64 = 1.0 / 5.0;
198    const INV34: f64 = 1.0 / 34.0;
199    let len = high.len();
200    debug_assert_eq!(low.len(), len);
201    debug_assert_eq!(osc.len(), len);
202    debug_assert_eq!(change.len(), len);
203    debug_assert!(len >= PERIOD_SMA34 + PERIOD_SMA5);
204    let mut queue5 = [0.0; PERIOD_SMA5];
205    let mut queue34 = [0.0; PERIOD_SMA34];
206    let mut queue5_ao = [0.0; PERIOD_SMA5];
207    let mut sum5 = 0.0;
208    let mut sum34 = 0.0;
209    let mut sum5_ao = 0.0;
210    let mut idx5 = 0;
211    let mut idx34 = 0;
212    let mut idx5_ao = 0;
213
214    unsafe {
215        let h_ptr = high.as_ptr();
216        let l_ptr = low.as_ptr();
217        let osc_ptr = osc.as_mut_ptr();
218        let ch_ptr = change.as_mut_ptr();
219
220        for i in 0..PERIOD_SMA34 {
221            let med = (*h_ptr.add(i) + *l_ptr.add(i)) * 0.5;
222            sum34 += med;
223            queue34[i] = med;
224            if i < PERIOD_SMA5 {
225                sum5 += med;
226                queue5[i] = med;
227            }
228        }
229        for i in PERIOD_SMA34..(PERIOD_SMA34 + PERIOD_SMA5 - 1) {
230            let med = (*h_ptr.add(i) + *l_ptr.add(i)) * 0.5;
231            sum34 += med - queue34[idx34];
232            queue34[idx34] = med;
233            idx34 += 1;
234            if idx34 == PERIOD_SMA34 {
235                idx34 = 0;
236            }
237            let sma34 = sum34 * INV34;
238            sum5 += med - queue5[idx5];
239            queue5[idx5] = med;
240            idx5 += 1;
241            if idx5 == PERIOD_SMA5 {
242                idx5 = 0;
243            }
244            let sma5 = sum5 * INV5;
245            let ao = sma5 - sma34;
246            sum5_ao += ao;
247            queue5_ao[idx5_ao] = ao;
248            idx5_ao += 1;
249        }
250        if idx5_ao == PERIOD_SMA5 {
251            idx5_ao = 0;
252        }
253        let mut prev_res = 0.0;
254        for i in (PERIOD_SMA34 + PERIOD_SMA5 - 1)..len {
255            let med = (*h_ptr.add(i) + *l_ptr.add(i)) * 0.5;
256            sum34 += med - queue34[idx34];
257            queue34[idx34] = med;
258            idx34 += 1;
259            if idx34 == PERIOD_SMA34 {
260                idx34 = 0;
261            }
262            let sma34 = sum34 * INV34;
263            sum5 += med - queue5[idx5];
264            queue5[idx5] = med;
265            idx5 += 1;
266            if idx5 == PERIOD_SMA5 {
267                idx5 = 0;
268            }
269            let sma5 = sum5 * INV5;
270            let ao = sma5 - sma34;
271            let old_ao = queue5_ao[idx5_ao];
272            sum5_ao += ao - old_ao;
273            queue5_ao[idx5_ao] = ao;
274            idx5_ao += 1;
275            if idx5_ao == PERIOD_SMA5 {
276                idx5_ao = 0;
277            }
278            let sma5_ao = sum5_ao * INV5;
279            let res = ao - sma5_ao;
280            let mom = res - prev_res;
281            prev_res = res;
282            *osc_ptr.add(i) = res;
283            *ch_ptr.add(i) = mom;
284        }
285    }
286}
287
288#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
289#[inline]
290pub fn acosc_avx512(high: &[f64], low: &[f64], osc: &mut [f64], change: &mut [f64]) {
291    acosc_scalar(high, low, osc, change)
292}
293#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
294#[inline]
295pub fn acosc_avx2(high: &[f64], low: &[f64], osc: &mut [f64], change: &mut [f64]) {
296    acosc_scalar(high, low, osc, change)
297}
298#[inline]
299pub fn acosc_avx512_short(high: &[f64], low: &[f64], osc: &mut [f64], change: &mut [f64]) {
300    acosc_scalar(high, low, osc, change)
301}
302#[inline]
303pub fn acosc_avx512_long(high: &[f64], low: &[f64], osc: &mut [f64], change: &mut [f64]) {
304    acosc_scalar(high, low, osc, change)
305}
306
307#[derive(Debug, Clone)]
308pub struct AcoscStream {
309    queue5: [f64; 5],
310    queue34: [f64; 34],
311    queue5_ao: [f64; 5],
312    sum5: f64,
313    sum34: f64,
314    sum5_ao: f64,
315    idx5: usize,
316    idx34: usize,
317    idx5_ao: usize,
318    filled: usize,
319    prev_res: f64,
320}
321impl AcoscStream {
322    pub fn try_new(_params: AcoscParams) -> Result<Self, AcoscError> {
323        Ok(Self {
324            queue5: [0.0; 5],
325            queue34: [0.0; 34],
326            queue5_ao: [0.0; 5],
327            sum5: 0.0,
328            sum34: 0.0,
329            sum5_ao: 0.0,
330            idx5: 0,
331            idx34: 0,
332            idx5_ao: 0,
333            filled: 0,
334            prev_res: 0.0,
335        })
336    }
337    #[inline(always)]
338    pub fn update(&mut self, high: f64, low: f64) -> Option<(f64, f64)> {
339        const PERIOD_SMA5: usize = 5;
340        const PERIOD_SMA34: usize = 34;
341        const INV5: f64 = 1.0 / 5.0;
342        const INV34: f64 = 1.0 / 34.0;
343
344        let med = (high + low) * 0.5;
345
346        self.filled += 1;
347
348        if self.filled <= PERIOD_SMA34 {
349            self.sum34 += med;
350            self.queue34[self.filled - 1] = med;
351
352            if self.filled <= PERIOD_SMA5 {
353                self.sum5 += med;
354                self.queue5[self.filled - 1] = med;
355            }
356            return None;
357        }
358
359        if self.filled < (PERIOD_SMA34 + PERIOD_SMA5) {
360            let old34 = self.queue34[self.idx34];
361            self.sum34 += med - old34;
362            self.queue34[self.idx34] = med;
363            self.idx34 += 1;
364            if self.idx34 == PERIOD_SMA34 {
365                self.idx34 = 0;
366            }
367            let sma34 = self.sum34 * INV34;
368
369            let old5 = self.queue5[self.idx5];
370            self.sum5 += med - old5;
371            self.queue5[self.idx5] = med;
372            self.idx5 += 1;
373            if self.idx5 == PERIOD_SMA5 {
374                self.idx5 = 0;
375            }
376            let sma5 = self.sum5 * INV5;
377
378            let ao = sma5 - sma34;
379            self.sum5_ao += ao;
380            self.queue5_ao[self.idx5_ao] = ao;
381            self.idx5_ao += 1;
382            if self.idx5_ao == PERIOD_SMA5 {
383                self.idx5_ao = 0;
384            }
385            return None;
386        }
387
388        let old34 = self.queue34[self.idx34];
389        self.sum34 += med - old34;
390        self.queue34[self.idx34] = med;
391        self.idx34 += 1;
392        if self.idx34 == PERIOD_SMA34 {
393            self.idx34 = 0;
394        }
395        let sma34 = self.sum34 * INV34;
396
397        let old5 = self.queue5[self.idx5];
398        self.sum5 += med - old5;
399        self.queue5[self.idx5] = med;
400        self.idx5 += 1;
401        if self.idx5 == PERIOD_SMA5 {
402            self.idx5 = 0;
403        }
404        let sma5 = self.sum5 * INV5;
405
406        let ao = sma5 - sma34;
407        let old_ao = self.queue5_ao[self.idx5_ao];
408        self.sum5_ao += ao - old_ao;
409        self.queue5_ao[self.idx5_ao] = ao;
410        self.idx5_ao += 1;
411        if self.idx5_ao == PERIOD_SMA5 {
412            self.idx5_ao = 0;
413        }
414
415        let sma5_ao = self.sum5_ao * INV5;
416
417        let res = ao - sma5_ao;
418        let mom = res - self.prev_res;
419        self.prev_res = res;
420
421        Some((res, mom))
422    }
423}
424
/// Batch parameter range for AC.
///
/// AC has no tunable parameters, so the range is empty and always expands to a single row.
/// `Default` is derived rather than hand-written, which has identical behavior for an empty
/// struct (clippy `derivable_impls`).
#[derive(Clone, Debug, Default)]
pub struct AcoscBatchRange {}
433
434#[derive(Clone, Debug, Default)]
435pub struct AcoscBatchBuilder {
436    kernel: Kernel,
437}
438impl AcoscBatchBuilder {
439    pub fn new() -> Self {
440        Self::default()
441    }
442    pub fn kernel(mut self, k: Kernel) -> Self {
443        self.kernel = k;
444        self
445    }
446    pub fn apply_slice(self, high: &[f64], low: &[f64]) -> Result<AcoscBatchOutput, AcoscError> {
447        acosc_batch_with_kernel(high, low, self.kernel)
448    }
449    pub fn with_default_slice(
450        high: &[f64],
451        low: &[f64],
452        k: Kernel,
453    ) -> Result<AcoscBatchOutput, AcoscError> {
454        AcoscBatchBuilder::new().kernel(k).apply_slice(high, low)
455    }
456    pub fn apply_candles(self, c: &Candles) -> Result<AcoscBatchOutput, AcoscError> {
457        let high = c
458            .select_candle_field("high")
459            .map_err(|e| AcoscError::CandleFieldError { msg: e.to_string() })?;
460        let low = c
461            .select_candle_field("low")
462            .map_err(|e| AcoscError::CandleFieldError { msg: e.to_string() })?;
463        self.apply_slice(high, low)
464    }
465    pub fn with_default_candles(c: &Candles) -> Result<AcoscBatchOutput, AcoscError> {
466        AcoscBatchBuilder::new()
467            .kernel(Kernel::Auto)
468            .apply_candles(c)
469    }
470}
/// Batch result: `rows x cols` matrices stored flat.
///
/// AC has no parameters, so the batch path in this file always produces `rows == 1` and
/// `cols == input length`.
#[derive(Clone, Debug)]
pub struct AcoscBatchOutput {
    /// Flattened AC values.
    pub osc: Vec<f64>,
    /// Flattened bar-to-bar change of `osc`.
    pub change: Vec<f64>,
    pub rows: usize,
    pub cols: usize,
}
478pub fn acosc_batch_with_kernel(
479    high: &[f64],
480    low: &[f64],
481    k: Kernel,
482) -> Result<AcoscBatchOutput, AcoscError> {
483    let kernel = match k {
484        Kernel::Auto => detect_best_batch_kernel(),
485        other if other.is_batch() => other,
486        _ => return Err(AcoscError::InvalidKernelForBatch(k)),
487    };
488    let simd = match kernel {
489        Kernel::Avx512Batch => Kernel::Avx512,
490        Kernel::Avx2Batch => Kernel::Avx2,
491        Kernel::ScalarBatch => Kernel::Scalar,
492        _ => unreachable!(),
493    };
494    acosc_batch_par_slice(high, low, simd)
495}
496#[inline(always)]
497pub fn acosc_batch_slice(
498    high: &[f64],
499    low: &[f64],
500    kern: Kernel,
501) -> Result<AcoscBatchOutput, AcoscError> {
502    acosc_batch_inner(high, low, kern, false)
503}
504#[inline(always)]
505pub fn acosc_batch_par_slice(
506    high: &[f64],
507    low: &[f64],
508    kern: Kernel,
509) -> Result<AcoscBatchOutput, AcoscError> {
510    acosc_batch_inner(high, low, kern, true)
511}
512#[inline(always)]
513fn acosc_batch_inner(
514    high: &[f64],
515    low: &[f64],
516    kern: Kernel,
517    _parallel: bool,
518) -> Result<AcoscBatchOutput, AcoscError> {
519    let cols = high.len();
520    let rows: usize = 1;
521
522    let _total = rows.checked_mul(cols).ok_or(AcoscError::InvalidRange {
523        start: 0,
524        end: cols as i64,
525        step: 0,
526    })?;
527
528    let first = (0..cols)
529        .find(|&i| !high[i].is_nan() && !low[i].is_nan())
530        .unwrap_or(cols);
531    const REQUIRED_LENGTH: usize = 39;
532    let valid = cols.saturating_sub(first);
533    if valid == 0 {
534        return Err(AcoscError::AllValuesNaN);
535    }
536    if valid < REQUIRED_LENGTH {
537        return Err(AcoscError::NotEnoughValidData {
538            needed: REQUIRED_LENGTH,
539            valid,
540        });
541    }
542
543    let mut buf_osc_mu = make_uninit_matrix(rows, cols);
544    let mut buf_change_mu = make_uninit_matrix(rows, cols);
545
546    const WARMUP: usize = 38;
547    let warmups = vec![first + WARMUP];
548    init_matrix_prefixes(&mut buf_osc_mu, cols, &warmups);
549    init_matrix_prefixes(&mut buf_change_mu, cols, &warmups);
550
551    let mut osc_guard = core::mem::ManuallyDrop::new(buf_osc_mu);
552    let mut change_guard = core::mem::ManuallyDrop::new(buf_change_mu);
553
554    let osc_slice: &mut [f64] = unsafe {
555        core::slice::from_raw_parts_mut(osc_guard.as_mut_ptr() as *mut f64, osc_guard.len())
556    };
557    let change_slice: &mut [f64] = unsafe {
558        core::slice::from_raw_parts_mut(change_guard.as_mut_ptr() as *mut f64, change_guard.len())
559    };
560
561    let simd = match kern {
562        Kernel::Auto => detect_best_kernel(),
563        other => other,
564    };
565
566    if first < cols {
567        let valid_len = cols - first;
568        if valid_len > WARMUP {
569            acosc_compute_into(
570                &high[first..],
571                &low[first..],
572                simd,
573                &mut osc_slice[first..],
574                &mut change_slice[first..],
575            );
576        }
577    }
578
579    let osc = unsafe {
580        Vec::from_raw_parts(
581            osc_guard.as_mut_ptr() as *mut f64,
582            osc_guard.len(),
583            osc_guard.capacity(),
584        )
585    };
586    let change = unsafe {
587        Vec::from_raw_parts(
588            change_guard.as_mut_ptr() as *mut f64,
589            change_guard.len(),
590            change_guard.capacity(),
591        )
592    };
593
594    Ok(AcoscBatchOutput {
595        osc,
596        change,
597        rows,
598        cols,
599    })
600}
601#[inline(always)]
602pub fn expand_grid(_r: &AcoscBatchRange) -> Vec<AcoscParams> {
603    vec![AcoscParams::default()]
604}
605
606#[inline(always)]
607pub unsafe fn acosc_row_scalar(
608    high: &[f64],
609    low: &[f64],
610    out_osc: &mut [f64],
611    out_change: &mut [f64],
612) {
613    acosc_scalar(high, low, out_osc, out_change)
614}
615#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
616#[inline(always)]
617pub unsafe fn acosc_row_avx2(
618    high: &[f64],
619    low: &[f64],
620    out_osc: &mut [f64],
621    out_change: &mut [f64],
622) {
623    acosc_avx2(high, low, out_osc, out_change)
624}
625#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
626#[inline(always)]
627pub unsafe fn acosc_row_avx512(
628    high: &[f64],
629    low: &[f64],
630    out_osc: &mut [f64],
631    out_change: &mut [f64],
632) {
633    acosc_avx512(high, low, out_osc, out_change)
634}
635#[inline(always)]
636pub fn acosc_row_avx512_short(
637    high: &[f64],
638    low: &[f64],
639    out_osc: &mut [f64],
640    out_change: &mut [f64],
641) {
642    acosc_scalar(high, low, out_osc, out_change)
643}
644#[inline(always)]
645pub fn acosc_row_avx512_long(
646    high: &[f64],
647    low: &[f64],
648    out_osc: &mut [f64],
649    out_change: &mut [f64],
650) {
651    acosc_scalar(high, low, out_osc, out_change)
652}
653
654#[derive(Copy, Clone, Debug, Default)]
655pub struct AcoscBuilder {
656    kernel: Kernel,
657}
658impl AcoscBuilder {
659    #[inline(always)]
660    pub fn new() -> Self {
661        Self::default()
662    }
663    #[inline(always)]
664    pub fn kernel(mut self, k: Kernel) -> Self {
665        self.kernel = k;
666        self
667    }
668    #[inline(always)]
669    pub fn apply_candles(self, candles: &Candles) -> Result<AcoscOutput, AcoscError> {
670        let input = AcoscInput::with_default_candles(candles);
671        acosc_with_kernel(&input, self.kernel)
672    }
673    #[inline(always)]
674    pub fn apply_slices(self, high: &[f64], low: &[f64]) -> Result<AcoscOutput, AcoscError> {
675        let input = AcoscInput::from_slices(high, low, AcoscParams::default());
676        acosc_with_kernel(&input, self.kernel)
677    }
678}
679
680#[cfg(all(feature = "python", feature = "cuda"))]
681use crate::cuda::cuda_available;
682#[cfg(all(feature = "python", feature = "cuda"))]
683use crate::cuda::oscillators::CudaAcosc;
684#[cfg(all(feature = "python", feature = "cuda"))]
685use crate::indicators::moving_averages::alma::DeviceArrayF32Py;
686#[cfg(feature = "python")]
687use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
688#[cfg(feature = "python")]
689use pyo3::exceptions::PyValueError;
690#[cfg(feature = "python")]
691use pyo3::prelude::*;
692#[cfg(feature = "python")]
693use pyo3::types::PyDict;
694#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
695use serde::{Deserialize, Serialize};
696#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
697use wasm_bindgen::prelude::*;
698
/// Python binding: returns `(osc, change)` numpy arrays for the given high/low arrays.
///
/// The computation runs with the GIL released. Errors from validation or kernel selection are
/// raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "acosc")]
#[pyo3(signature = (high, low, kernel=None))]
pub fn acosc_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    kernel: Option<&str>,
) -> PyResult<(Bound<'py, PyArray1<f64>>, Bound<'py, PyArray1<f64>>)> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let (high_slice, low_slice) = (high.as_slice()?, low.as_slice()?);
    let kern = crate::utilities::kernel_validation::validate_kernel(kernel, false)?;

    let input = AcoscInput::from_slices(high_slice, low_slice, AcoscParams::default());
    let result = py
        .allow_threads(|| acosc_with_kernel(&input, kern))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let AcoscOutput { osc, change } = result;
    Ok((osc.into_pyarray(py), change.into_pyarray(py)))
}
726
/// Python wrapper around [`AcoscStream`], exposed to Python as `AcoscStream`.
#[cfg(feature = "python")]
#[pyclass(name = "AcoscStream")]
pub struct AcoscStreamPy {
    stream: AcoscStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl AcoscStreamPy {
    /// Creates a fresh stream (AC has no parameters).
    #[new]
    fn new() -> PyResult<Self> {
        AcoscStream::try_new(AcoscParams::default())
            .map(|stream| AcoscStreamPy { stream })
            .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feeds one bar; returns `None` during warmup, otherwise `(osc, change)`.
    fn update(&mut self, high: f64, low: f64) -> Option<(f64, f64)> {
        let stream = &mut self.stream;
        stream.update(high, low)
    }
}
748
/// Python batch binding: returns a dict with `"osc"` and `"change"` arrays shaped `(1, len)`.
///
/// Inputs are validated the same way as the Rust entry points. Mismatched lengths, empty
/// input, all-NaN input, and fewer than 39 valid bars all raise `ValueError` with the
/// corresponding `AcoscError` message. The length check matters for soundness: mismatched
/// arrays must never reach the kernel. The computation runs with the GIL released.
#[cfg(feature = "python")]
#[pyfunction(name = "acosc_batch")]
#[pyo3(signature = (high, low, kernel=None))]
pub fn acosc_batch_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    const REQUIRED_LENGTH: usize = 39;
    const WARMUP: usize = 38;

    let h = high.as_slice()?;
    let l = low.as_slice()?;
    let kern = crate::utilities::kernel_validation::validate_kernel(kernel, true)?;

    let rows = 1usize;
    let cols = h.len();

    // Cheap O(1) checks before allocating the output arrays.
    if l.len() != cols {
        return Err(PyValueError::new_err(
            AcoscError::LengthMismatch {
                high_len: cols,
                low_len: l.len(),
            }
            .to_string(),
        ));
    }
    if cols == 0 {
        return Err(PyValueError::new_err(AcoscError::EmptyInputData.to_string()));
    }

    let out_osc = unsafe { PyArray1::<f64>::new(py, [rows * cols], false) };
    let out_change = unsafe { PyArray1::<f64>::new(py, [rows * cols], false) };
    let slice_osc = unsafe { out_osc.as_slice_mut()? };
    let slice_change = unsafe { out_change.as_slice_mut()? };

    py.allow_threads(|| -> Result<(), AcoscError> {
        let simd = match kern {
            Kernel::Auto => detect_best_batch_kernel(),
            k => k,
        };
        let simd = match simd {
            Kernel::Avx512Batch => Kernel::Avx512,
            Kernel::Avx2Batch => Kernel::Avx2,
            Kernel::ScalarBatch => Kernel::Scalar,
            _ => simd,
        };

        let first = (0..cols)
            .find(|&i| !h[i].is_nan() && !l[i].is_nan())
            .unwrap_or(cols);
        let valid = cols - first;
        if valid == 0 {
            return Err(AcoscError::AllValuesNaN);
        }
        if valid < REQUIRED_LENGTH {
            return Err(AcoscError::NotEnoughValidData {
                needed: REQUIRED_LENGTH,
                valid,
            });
        }

        // `valid >= 39` guarantees `warm < cols`.
        let warm = first + WARMUP;
        let nan = f64::from_bits(0x7ff8_0000_0000_0000);
        slice_osc[..warm].fill(nan);
        slice_change[..warm].fill(nan);

        acosc_compute_into(
            &h[first..],
            &l[first..],
            simd,
            &mut slice_osc[first..],
            &mut slice_change[first..],
        );
        Ok(())
    })
    .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let d = PyDict::new(py);
    d.set_item("osc", out_osc.reshape((rows, cols))?)?;
    d.set_item("change", out_change.reshape((rows, cols))?)?;
    Ok(d)
}
822
/// Python binding: runs the AC batch kernel on a CUDA device from f32 host arrays.
///
/// Returns device-resident `(osc, change)` handles tagged with `device_id`. Fails with
/// `ValueError` if CUDA is unavailable or the CUDA backend reports an error. Host-side work
/// runs with the GIL released.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "acosc_cuda_batch_dev")]
#[pyo3(signature = (high_f32, low_f32, device_id=0))]
pub fn acosc_cuda_batch_dev_py(
    py: Python<'_>,
    high_f32: numpy::PyReadonlyArray1<'_, f32>,
    low_f32: numpy::PyReadonlyArray1<'_, f32>,
    device_id: usize,
) -> PyResult<(AcoscDeviceArrayF32Py, AcoscDeviceArrayF32Py)> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let (high_host, low_host) = (high_f32.as_slice()?, low_f32.as_slice()?);

    let pair = py.allow_threads(|| -> PyResult<_> {
        let engine =
            CudaAcosc::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        engine
            .acosc_batch_dev(high_host, low_host)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    let device = device_id as u32;
    Ok((
        AcoscDeviceArrayF32Py {
            inner: Some(pair.osc),
            device_id: device,
        },
        AcoscDeviceArrayF32Py {
            inner: Some(pair.change),
            device_id: device,
        },
    ))
}
853
/// Python binding: runs AC on many series at once on a CUDA device.
///
/// `high_tm_f32` and `low_tm_f32` must share one 2-D shape `(rows, cols)`. NOTE(review): the
/// `_tm_` naming and the `time_major` kernel name suggest rows are time steps and cols are
/// series — confirm against `CudaAcosc`. Returns device-resident `(osc, change)` handles.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "acosc_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm_f32, low_tm_f32, device_id=0))]
pub fn acosc_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    high_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    low_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    device_id: usize,
) -> PyResult<(AcoscDeviceArrayF32Py, AcoscDeviceArrayF32Py)> {
    use numpy::PyUntypedArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let shape = high_tm_f32.shape();
    let same_2d_shape = shape == low_tm_f32.shape() && shape.len() == 2;
    if !same_2d_shape {
        return Err(PyValueError::new_err("high/low must be same 2D shape"));
    }
    let (rows, cols) = (shape[0], shape[1]);
    let (high_host, low_host) = (high_tm_f32.as_slice()?, low_tm_f32.as_slice()?);

    let pair = py.allow_threads(|| -> PyResult<_> {
        let engine =
            CudaAcosc::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        engine
            .acosc_many_series_one_param_time_major_dev(high_host, low_host, cols, rows)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    let device = device_id as u32;
    Ok((
        AcoscDeviceArrayF32Py {
            inner: Some(pair.osc),
            device_id: device,
        },
        AcoscDeviceArrayF32Py {
            inner: Some(pair.change),
            device_id: device,
        },
    ))
}
892
893#[cfg(all(feature = "python", feature = "cuda"))]
894use crate::cuda::oscillators::DeviceArrayF32Acosc;
/// Python handle to one device-resident f32 AC output matrix (`osc` or `change`).
///
/// The buffer can be consumed exactly once through `__dlpack__`. After that, `inner` is `None`
/// and every accessor raises "buffer already exported via __dlpack__".
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct AcoscDeviceArrayF32Py {
    pub(crate) inner: Option<DeviceArrayF32Acosc>,
    pub(crate) device_id: u32,
}
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl AcoscDeviceArrayF32Py {
    /// CUDA Array Interface (version 3) description of the buffer.
    ///
    /// Describes a row-major `(rows, cols)` little-endian float32 matrix with explicit byte
    /// strides, reported as writable (`readonly = false`).
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let inner = self
            .inner
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("buffer already exported via __dlpack__"))?;
        let d = PyDict::new(py);
        d.set_item("shape", (inner.rows, inner.cols))?;
        d.set_item("typestr", "<f4")?;
        d.set_item(
            "strides",
            (
                inner.cols * std::mem::size_of::<f32>(),
                std::mem::size_of::<f32>(),
            ),
        )?;
        d.set_item("data", (inner.device_ptr() as usize, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    /// DLPack device tuple: device type 2 (CUDA) and the allocation's device ordinal.
    fn __dlpack_device__(&self) -> PyResult<(i32, i32)> {
        let inner = self
            .inner
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("buffer already exported via __dlpack__"))?;
        Ok((2, inner.device_id as i32))
    }

    /// Exports the buffer as a DLPack capsule and transfers ownership to the consumer.
    ///
    /// Every fallible step (export state, requested device, `max_version` conversion) runs
    /// before the buffer is taken out of `self`. A rejected call therefore leaves the handle
    /// intact instead of silently dropping the device allocation.
    ///
    /// `stream` is currently ignored. NOTE(review): this assumes the producing kernels were
    /// already synchronized — confirm in `CudaAcosc`. Also confirm how PyO3 exposes the
    /// `_copy` parameter name to Python callers.
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<i64>,
        max_version: Option<(u32, u32)>,
        dl_device: Option<(i32, i32)>,
        _copy: Option<bool>,
    ) -> PyResult<PyObject> {
        use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;

        let alloc_device = self
            .inner
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("buffer already exported via __dlpack__"))?
            .device_id;

        if let Some((_ty, dev_id)) = dl_device {
            if dev_id as u32 != alloc_device {
                return Err(PyValueError::new_err(
                    "dl_device does not match allocation device",
                ));
            }
        }

        let _ = stream;

        let max_version_bound = max_version
            .map(|(maj, min)| -> PyResult<_> {
                use pyo3::IntoPyObjectExt;
                (maj as i32, min as i32).into_bound_py_any(py)
            })
            .transpose()?;

        // Only now, with all validation done, take ownership of the buffer.
        let inner = self
            .inner
            .take()
            .ok_or_else(|| PyValueError::new_err("buffer already exported via __dlpack__"))?;

        let DeviceArrayF32Acosc {
            buf,
            rows,
            cols,
            ctx: _,
            device_id,
        } = inner;

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, device_id as i32, max_version_bound)
    }
}
977
/// Compute ACOSC for `high`/`low` with default parameters and the
/// automatically selected kernel.
///
/// Returns one flat vector of length `2 * high.len()`: the oscillator values
/// first, then the change values. Errors from the computation are returned as
/// their display string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn acosc_js(high: &[f64], low: &[f64]) -> Result<Vec<f64>, JsValue> {
    let n = high.len();
    let total = match n.checked_mul(2) {
        Some(t) => t,
        None => return Err(JsValue::from_str("acosc_js: size overflow")),
    };

    let mut flat = vec![0.0; total];
    {
        let (osc_half, change_half) = flat.split_at_mut(n);
        let input = AcoscInput::from_slices(high, low, AcoscParams::default());
        if let Err(e) = acosc_into_slice(osc_half, change_half, &input, Kernel::Auto) {
            return Err(JsValue::from_str(&e.to_string()));
        }
    }
    Ok(flat)
}
997
/// Configuration object for the JS batch API.
///
/// It has no fields because `AcoscParams` is empty as well; ACOSC exposes no
/// tunable parameters.
/// NOTE(review): `acosc_batch_unified_js` takes its config as an untyped
/// `JsValue` and ignores it, so this struct is not used there.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct AcoscBatchConfig {}
1001
/// Object returned to JS by `acosc_batch` (`acosc_batch_unified_js`).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct AcoscBatchJsOutput {
    /// Oscillator values followed by change values, each `cols` long, so the
    /// total length is `2 * cols` as built by `acosc_batch_unified_js`.
    pub values: Vec<f64>,
    /// Number of parameter rows; `acosc_batch_unified_js` always sets 1.
    pub rows: usize,
    /// Length of the input series (`high.len()`).
    pub cols: usize,
}
1009
/// JS `acosc_batch` entry point.
///
/// ACOSC has no parameters, so the batch is a single row. `_config` is accepted
/// for API symmetry and ignored. The result serialises an `AcoscBatchJsOutput`
/// with `values = [osc..., change...]`, `rows = 1` and `cols = high.len()`.
///
/// NOTE(review): `values.len()` is `2 * rows * cols`, not `rows * cols`; JS
/// consumers must split it into the osc and change halves themselves.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = acosc_batch)]
pub fn acosc_batch_unified_js(
    high: &[f64],
    low: &[f64],
    _config: JsValue,
) -> Result<JsValue, JsValue> {
    let cols = high.len();
    let total = cols
        .checked_mul(2)
        .ok_or_else(|| JsValue::from_str("acosc_batch_unified_js: size overflow"))?;

    let mut values = vec![0.0; total];
    let input = AcoscInput::from_slices(high, low, AcoscParams::default());
    {
        let (osc_half, change_half) = values.split_at_mut(cols);
        acosc_into_slice(osc_half, change_half, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    let payload = AcoscBatchJsOutput {
        values,
        rows: 1,
        cols,
    };
    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
1042
/// Batch-named alias that returns exactly what `acosc_js` returns: a flat
/// `[osc..., change...]` vector of length `2 * high.len()`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn acosc_batch_js(high: &[f64], low: &[f64]) -> Result<Vec<f64>, JsValue> {
    let flat = acosc_js(high, low)?;
    Ok(flat)
}
1048
/// Parameter metadata for the JS batch API.
///
/// ACOSC has no parameters, so the list is always empty.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn acosc_batch_metadata_js() -> Result<Vec<f64>, JsValue> {
    Ok(Vec::new())
}
1054
1055pub fn acosc_into_slice(
1056    osc_dst: &mut [f64],
1057    change_dst: &mut [f64],
1058    input: &AcoscInput,
1059    kern: Kernel,
1060) -> Result<(), AcoscError> {
1061    let (high, low, first, kernel) = acosc_prepare(input, kern)?;
1062
1063    if osc_dst.len() != high.len() {
1064        return Err(AcoscError::OutputLengthMismatch {
1065            expected: high.len(),
1066            got: osc_dst.len(),
1067        });
1068    }
1069    if change_dst.len() != high.len() {
1070        return Err(AcoscError::OutputLengthMismatch {
1071            expected: high.len(),
1072            got: change_dst.len(),
1073        });
1074    }
1075
1076    const WARMUP: usize = 38;
1077    let warm = first + WARMUP;
1078    for i in 0..warm.min(osc_dst.len()) {
1079        osc_dst[i] = f64::from_bits(0x7ff8_0000_0000_0000);
1080        change_dst[i] = f64::from_bits(0x7ff8_0000_0000_0000);
1081    }
1082
1083    let valid = high.len() - first;
1084    if first < high.len() && valid > WARMUP {
1085        acosc_compute_into(
1086            &high[first..],
1087            &low[first..],
1088            kernel,
1089            &mut osc_dst[first..],
1090            &mut change_dst[first..],
1091        );
1092    }
1093    Ok(())
1094}
1095
1096#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1097#[inline]
1098pub fn acosc_into(
1099    input: &AcoscInput,
1100    osc_out: &mut [f64],
1101    change_out: &mut [f64],
1102) -> Result<(), AcoscError> {
1103    acosc_into_slice(osc_out, change_out, input, Kernel::Auto)
1104}
1105
/// Raw-pointer ACOSC entry point for JS callers that manage WASM memory.
///
/// Reads `len` values from `high_ptr` / `low_ptr` and writes `len` values to
/// `osc_ptr` and `change_ptr`. If any input range overlaps an output range, or
/// the two output ranges overlap each other, the result is computed into
/// temporary buffers first and then copied out: osc first, change second. When
/// the outputs overlap, the change values therefore win.
///
/// # Errors
/// Returns an error for a null pointer, for `len < 39`, or when the
/// computation itself fails.
///
/// # Safety (caller contract)
/// Every pointer must be valid for `len` `f64` reads (inputs) or writes
/// (outputs), e.g. obtained from `acosc_alloc(len)`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn acosc_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    osc_ptr: *mut f64,
    change_ptr: *mut f64,
    len: usize,
) -> Result<(), JsValue> {
    if high_ptr.is_null() || low_ptr.is_null() || osc_ptr.is_null() || change_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to acosc_into"));
    }
    if len < 39 {
        return Err(JsValue::from_str("Not enough data"));
    }

    // Byte-range overlap between two `len`-element f64 buffers. Comparing start
    // addresses alone would miss partial overlap and let a `&[f64]` alias a
    // `&mut [f64]`, which is undefined behaviour.
    let span = len.saturating_mul(std::mem::size_of::<f64>());
    let overlaps = |a: *const f64, b: *const f64| {
        let (a, b) = (a as usize, b as usize);
        a < b.saturating_add(span) && b < a.saturating_add(span)
    };
    let osc_c = osc_ptr as *const f64;
    let change_c = change_ptr as *const f64;
    let need_temp = overlaps(high_ptr, osc_c)
        || overlaps(high_ptr, change_c)
        || overlaps(low_ptr, osc_c)
        || overlaps(low_ptr, change_c)
        || overlaps(osc_c, change_c);

    // Builds the input view only for the duration of the computation, so the
    // shared input slices are gone before any overlapping output is written.
    let compute = |osc: &mut [f64], change: &mut [f64]| -> Result<(), JsValue> {
        // SAFETY: both pointers are non-null (checked above) and, per the caller
        // contract, valid for `len` reads. High and low may alias each other,
        // which shared slices permit.
        let (high, low) = unsafe {
            (
                std::slice::from_raw_parts(high_ptr, len),
                std::slice::from_raw_parts(low_ptr, len),
            )
        };
        let input = AcoscInput::from_slices(high, low, AcoscParams::default());
        acosc_into_slice(osc, change, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))
    };

    if need_temp {
        let mut temp_osc = vec![0.0; len];
        let mut temp_change = vec![0.0; len];
        compute(&mut temp_osc[..], &mut temp_change[..])?;

        // Each output slice is created, written and dropped before the next one
        // exists, so two `&mut` never coexist even if the outputs overlap.
        // SAFETY: non-null and valid for `len` writes per the caller contract; no
        // input slice is alive at this point.
        unsafe { std::slice::from_raw_parts_mut(osc_ptr, len) }.copy_from_slice(&temp_osc);
        // SAFETY: as above.
        unsafe { std::slice::from_raw_parts_mut(change_ptr, len) }
            .copy_from_slice(&temp_change);
    } else {
        // SAFETY: non-null and valid for `len` writes per the caller contract.
        // `need_temp == false` shows that neither output overlaps an input or
        // the other output.
        let osc_out = unsafe { std::slice::from_raw_parts_mut(osc_ptr, len) };
        let change_out = unsafe { std::slice::from_raw_parts_mut(change_ptr, len) };
        compute(osc_out, change_out)?;
    }

    Ok(())
}
1158
/// Allocate an uninitialised buffer with room for `len` `f64` values and
/// return its raw pointer to JS.
///
/// The allocation is intentionally not freed here; release it with
/// `acosc_free(ptr, len)` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn acosc_alloc(len: usize) -> *mut f64 {
    let mut storage = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    storage.as_mut_ptr()
}
1167
/// Release a buffer previously returned by `acosc_alloc(len)`.
///
/// `len` must be the same value that was passed to `acosc_alloc`. Null
/// pointers are ignored. The vector is rebuilt with length 0, so freeing never
/// depends on whether JS ever initialised the contents.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn acosc_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` is assumed to come from `acosc_alloc(len)`, i.e. a leaked
    // `Vec<f64>` allocated with capacity `len`. A length of 0 means no element is
    // asserted to be initialised (required by `Vec::from_raw_parts`). `f64` has
    // no destructor, so nothing else needs dropping.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
1177
#[cfg(test)]
mod tests {
    use super::*;
    use crate::skip_if_unsupported;
    use crate::utilities::data_loader::read_candles_from_csv;
    use std::error::Error;

    // Default params on the 4h fixture: both outputs are as long as the input.
    fn check_acosc_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let default_params = AcoscParams::default();
        let input = AcoscInput::from_candles(&candles, default_params);
        let output = acosc_with_kernel(&input, kernel)?;
        assert_eq!(output.osc.len(), candles.close.len());
        assert_eq!(output.change.len(), candles.close.len());
        Ok(())
    }

    // The last five osc/change values match reference values to within 0.1.
    fn check_acosc_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let input = AcoscInput::with_default_candles(&candles);
        let result = acosc_with_kernel(&input, kernel)?;
        assert_eq!(result.osc.len(), candles.close.len());
        assert_eq!(result.change.len(), candles.close.len());
        let expected_last_five_acosc_osc = [273.30, 383.72, 357.7, 291.25, 176.84];
        let expected_last_five_acosc_change = [49.6, 110.4, -26.0, -66.5, -114.4];
        let start = result.osc.len().saturating_sub(5);
        for (i, &val) in result.osc[start..].iter().enumerate() {
            assert!(
                (val - expected_last_five_acosc_osc[i]).abs() < 1e-1,
                "[{}] ACOSC {:?} osc mismatch idx {}: got {}, expected {}",
                test_name,
                kernel,
                i,
                val,
                expected_last_five_acosc_osc[i]
            );
        }
        for (i, &val) in result.change[start..].iter().enumerate() {
            assert!(
                (val - expected_last_five_acosc_change[i]).abs() < 1e-1,
                "[{}] ACOSC {:?} change mismatch idx {}: got {}, expected {}",
                test_name,
                kernel,
                i,
                val,
                expected_last_five_acosc_change[i]
            );
        }
        Ok(())
    }

    // `with_default_candles` builds the Candles variant and produces full-length output.
    fn check_acosc_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let input = AcoscInput::with_default_candles(&candles);
        match input.data {
            AcoscData::Candles { .. } => {}
            _ => panic!("Expected AcoscData::Candles variant"),
        }
        let output = acosc_with_kernel(&input, kernel)?;
        assert_eq!(output.osc.len(), candles.close.len());
        Ok(())
    }

    // Two samples are far below the warm-up requirement and must be rejected.
    fn check_acosc_too_short(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let high = [100.0, 101.0];
        let low = [99.0, 98.0];
        let params = AcoscParams::default();
        let input = AcoscInput::from_slices(&high, &low, params);
        let result = acosc_with_kernel(&input, kernel);
        assert!(
            result.is_err(),
            "[{}] Should fail with not enough data",
            test_name
        );
        Ok(())
    }

    // Candles input and the equivalent high/low slice input give the same osc series.
    fn check_acosc_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let input = AcoscInput::with_default_candles(&candles);
        let first_result = acosc_with_kernel(&input, kernel)?;
        assert_eq!(first_result.osc.len(), candles.close.len());
        assert_eq!(first_result.change.len(), candles.close.len());
        let input2 = AcoscInput::from_slices(&candles.high, &candles.low, AcoscParams::default());
        let second_result = acosc_with_kernel(&input2, kernel)?;
        assert_eq!(second_result.osc.len(), candles.close.len());
        for (a, b) in second_result.osc.iter().zip(first_result.osc.iter()) {
            if a.is_nan() && b.is_nan() {
                continue;
            }
            assert!(
                (a - b).abs() < 1e-8,
                "Reinput values mismatch: {} vs {}",
                a,
                b
            );
        }
        Ok(())
    }

    // Well past warm-up (index 240 onward) no NaN may appear in either output.
    fn check_acosc_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let input = AcoscInput::with_default_candles(&candles);
        let result = acosc_with_kernel(&input, kernel)?;
        if result.osc.len() > 240 {
            for i in 240..result.osc.len() {
                assert!(!result.osc[i].is_nan(), "Found NaN in osc at {}", i);
                assert!(!result.change[i].is_nan(), "Found NaN in change at {}", i);
            }
        }
        Ok(())
    }

    // Feeding the stream bar by bar reproduces the batch osc/change series.
    fn check_acosc_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let input = AcoscInput::with_default_candles(&candles);
        let batch = acosc_with_kernel(&input, kernel)?;
        let mut stream = AcoscStream::try_new(AcoscParams::default())?;
        let mut osc_stream = Vec::with_capacity(candles.close.len());
        let mut change_stream = Vec::with_capacity(candles.close.len());
        for (&h, &l) in candles.high.iter().zip(candles.low.iter()) {
            match stream.update(h, l) {
                Some((o, c)) => {
                    osc_stream.push(o);
                    change_stream.push(c);
                }
                None => {
                    osc_stream.push(f64::NAN);
                    change_stream.push(f64::NAN);
                }
            }
        }
        assert_eq!(batch.osc.len(), osc_stream.len());
        assert_eq!(batch.change.len(), change_stream.len());
        for (i, (&a, &b)) in batch.osc.iter().zip(osc_stream.iter()).enumerate() {
            if a.is_nan() && b.is_nan() {
                continue;
            }
            assert!(
                (a - b).abs() < 1e-9,
                "Streaming osc mismatch at idx {}: {} vs {}",
                i,
                a,
                b
            );
        }
        for (i, (&a, &b)) in batch.change.iter().zip(change_stream.iter()).enumerate() {
            if a.is_nan() && b.is_nan() {
                continue;
            }
            assert!(
                (a - b).abs() < 1e-9,
                "Streaming change mismatch at idx {}: {} vs {}",
                i,
                a,
                b
            );
        }
        Ok(())
    }

    // Generates one `#[test]` per (check fn, kernel). Each generated test returns the
    // check's `Result`, so `?`-propagated failures (fixture I/O errors, proptest failures
    // reported through `TestRunner::run`) fail the test instead of being discarded. The
    // SIMD variants carry their own `cfg` so that every generated AVX test is gated, not
    // only the first one in the repetition.
    macro_rules! generate_all_acosc_tests {
        ($($test_fn:ident),*) => {
            paste::paste! {
                $(
                    #[test]
                    fn [<$test_fn _scalar_f64>]() -> Result<(), Box<dyn std::error::Error>> {
                        $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar)
                    }

                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx2_f64>]() -> Result<(), Box<dyn std::error::Error>> {
                        $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2)
                    }

                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx512_f64>]() -> Result<(), Box<dyn std::error::Error>> {
                        $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512)
                    }
                )*
            }
        };
    }
    generate_all_acosc_tests!(
        check_acosc_partial_params,
        check_acosc_accuracy,
        check_acosc_default_candles,
        check_acosc_too_short,
        check_acosc_reinput,
        check_acosc_nan_handling,
        check_acosc_streaming,
        check_acosc_no_poison
    );

    #[cfg(feature = "proptest")]
    generate_all_acosc_tests!(check_acosc_property);

    // Property test on random high/low series:
    // - the chosen kernel matches the scalar kernel;
    // - indices 0..38 are NaN and index 38 is finite;
    // - change[i] == osc[i] - osc[i-1];
    // - constant prices give near-zero osc;
    // - on a fixture prefix, streaming matches batch.
    #[cfg(feature = "proptest")]
    fn check_acosc_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        let strat = (40usize..=400).prop_flat_map(|len| {
            prop::collection::vec(
                (1.0f64..10000.0f64)
                    .prop_flat_map(|base_price| {
                        (0.0f64..0.1f64).prop_map(move |spread_pct| {
                            let half_spread = base_price * spread_pct * 0.5;
                            let high = base_price + half_spread;
                            let low = base_price - half_spread;
                            (high, low)
                        })
                    })
                    .prop_filter("prices must be finite", |(h, l)| {
                        h.is_finite() && l.is_finite()
                    }),
                len,
            )
        });

        proptest::test_runner::TestRunner::default().run(&strat, |price_pairs| {
            let (high_vec, low_vec): (Vec<f64>, Vec<f64>) = price_pairs.into_iter().unzip();
            let params = AcoscParams::default();
            let input = AcoscInput::from_slices(&high_vec, &low_vec, params);

            let result = acosc_with_kernel(&input, kernel).unwrap();
            let scalar_result = acosc_with_kernel(&input, Kernel::Scalar).unwrap();

            for i in 0..result.osc.len() {
                let y = result.osc[i];
                let r = scalar_result.osc[i];

                if !y.is_finite() || !r.is_finite() {
                    prop_assert_eq!(
                        y.to_bits(),
                        r.to_bits(),
                        "NaN/finite mismatch in osc at idx {}: {} vs {}",
                        i,
                        y,
                        r
                    );
                    continue;
                }

                let y_bits = y.to_bits();
                let r_bits = r.to_bits();
                let ulp_diff: u64 = y_bits.abs_diff(r_bits);

                prop_assert!(
                    (y - r).abs() <= 1e-9 || ulp_diff <= 4,
                    "Kernel mismatch in osc at idx {}: {} vs {} (ULP={})",
                    i,
                    y,
                    r,
                    ulp_diff
                );
            }

            for i in 0..result.change.len() {
                let y = result.change[i];
                let r = scalar_result.change[i];

                if !y.is_finite() || !r.is_finite() {
                    prop_assert_eq!(
                        y.to_bits(),
                        r.to_bits(),
                        "NaN/finite mismatch in change at idx {}: {} vs {}",
                        i,
                        y,
                        r
                    );
                    continue;
                }

                let y_bits = y.to_bits();
                let r_bits = r.to_bits();
                let ulp_diff: u64 = y_bits.abs_diff(r_bits);

                prop_assert!(
                    (y - r).abs() <= 1e-9 || ulp_diff <= 4,
                    "Kernel mismatch in change at idx {}: {} vs {} (ULP={})",
                    i,
                    y,
                    r,
                    ulp_diff
                );
            }

            // `len().min(38)` rather than `38.min(len())`: a method call on an
            // unsuffixed integer literal is an ambiguous-numeric-type error (E0689).
            for i in 0..result.osc.len().min(38) {
                prop_assert!(
                    result.osc[i].is_nan(),
                    "Expected NaN in osc warmup at idx {}, got {}",
                    i,
                    result.osc[i]
                );
                prop_assert!(
                    result.change[i].is_nan(),
                    "Expected NaN in change warmup at idx {}, got {}",
                    i,
                    result.change[i]
                );
            }

            if result.osc.len() > 38 {
                prop_assert!(
                    result.osc[38].is_finite(),
                    "Expected finite value at idx 38 in osc, got {}",
                    result.osc[38]
                );
                prop_assert!(
                    result.change[38].is_finite(),
                    "Expected finite value at idx 38 in change, got {}",
                    result.change[38]
                );
            }

            for i in 39..result.osc.len() {
                if result.osc[i].is_finite() && result.osc[i - 1].is_finite() {
                    let expected_change = result.osc[i] - result.osc[i - 1];
                    let actual_change = result.change[i];

                    prop_assert!(
                        (expected_change - actual_change).abs() <= 1e-9,
                        "Change formula mismatch at idx {}: expected {} ({}−{}), got {}",
                        i,
                        expected_change,
                        result.osc[i],
                        result.osc[i - 1],
                        actual_change
                    );
                }
            }

            if high_vec.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10)
                && low_vec.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10)
            {
                for i in 39..result.osc.len() {
                    prop_assert!(
                        result.osc[i].abs() <= 1e-6,
                        "Expected near-zero osc with constant prices at idx {}, got {}",
                        i,
                        result.osc[i]
                    );
                }
            }

            Ok(())
        })?;

        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let test_len = candles.high.len().min(200);
        let high_data = &candles.high[..test_len];
        let low_data = &candles.low[..test_len];

        {
            let params = AcoscParams::default();
            let input = AcoscInput::from_slices(high_data, low_data, params.clone());
            let batch_result = acosc_with_kernel(&input, kernel)?;

            let mut stream = AcoscStream::try_new(params)?;
            let mut stream_osc = Vec::with_capacity(test_len);
            let mut stream_change = Vec::with_capacity(test_len);

            for i in 0..test_len {
                match stream.update(high_data[i], low_data[i]) {
                    Some((osc, change)) => {
                        stream_osc.push(osc);
                        stream_change.push(change);
                    }
                    None => {
                        stream_osc.push(f64::NAN);
                        stream_change.push(f64::NAN);
                    }
                }
            }

            for i in 0..test_len {
                let batch_o = batch_result.osc[i];
                let stream_o = stream_osc[i];
                if !(batch_o.is_nan() && stream_o.is_nan()) {
                    assert!(
                        (batch_o - stream_o).abs() <= 1e-9,
                        "[{}] Streaming vs batch mismatch in osc at idx {}: {} vs {}",
                        test_name,
                        i,
                        batch_o,
                        stream_o
                    );
                }

                // Checked independently of osc so a shared-NaN osc no longer
                // skips the change comparison.
                let batch_c = batch_result.change[i];
                let stream_c = stream_change[i];
                if !(batch_c.is_nan() && stream_c.is_nan()) {
                    assert!(
                        (batch_c - stream_c).abs() <= 1e-9,
                        "[{}] Streaming vs batch mismatch in change at idx {}: {} vs {}",
                        test_name,
                        i,
                        batch_c,
                        stream_c
                    );
                }
            }
        }

        Ok(())
    }

    // The batch builder's default single row covers the full input length.
    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);
        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;
        let output = AcoscBatchBuilder::new().kernel(kernel).apply_candles(&c)?;
        assert_eq!(output.osc.len(), c.close.len());
        Ok(())
    }

    // Debug builds only: no alloc_with_nan_prefix poison pattern leaks into outputs.
    #[cfg(debug_assertions)]
    fn check_acosc_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let input = AcoscInput::with_default_candles(&candles);
        let output = acosc_with_kernel(&input, kernel)?;

        for (i, &val) in output.osc.iter().enumerate() {
            if val.is_nan() {
                continue;
            }

            let bits = val.to_bits();

            if bits == 0x11111111_11111111 {
                panic!(
                    "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} in osc",
                    test_name, val, bits, i
                );
            }
        }

        for (i, &val) in output.change.iter().enumerate() {
            if val.is_nan() {
                continue;
            }

            let bits = val.to_bits();

            if bits == 0x11111111_11111111 {
                panic!(
                    "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} in change",
                    test_name, val, bits, i
                );
            }
        }

        Ok(())
    }

    // Debug builds only: no allocator poison pattern leaks into batch outputs.
    #[cfg(debug_assertions)]
    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);
        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;
        let output = AcoscBatchBuilder::new().kernel(kernel).apply_candles(&c)?;

        for (idx, &val) in output.osc.iter().enumerate() {
            if val.is_nan() {
                continue;
            }

            let bits = val.to_bits();

            if bits == 0x11111111_11111111 {
                panic!(
                    "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} in osc",
                    test, val, bits, idx
                );
            }

            if bits == 0x22222222_22222222 {
                panic!(
                    "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} in osc",
                    test, val, bits, idx
                );
            }

            if bits == 0x33333333_33333333 {
                panic!(
                    "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} in osc",
                    test, val, bits, idx
                );
            }
        }

        for (idx, &val) in output.change.iter().enumerate() {
            if val.is_nan() {
                continue;
            }

            let bits = val.to_bits();

            if bits == 0x11111111_11111111 {
                panic!(
                    "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} in change",
                    test, val, bits, idx
                );
            }

            if bits == 0x22222222_22222222 {
                panic!(
                    "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} in change",
                    test, val, bits, idx
                );
            }

            if bits == 0x33333333_33333333 {
                panic!(
                    "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} in change",
                    test, val, bits, idx
                );
            }
        }

        Ok(())
    }

    #[cfg(not(debug_assertions))]
    fn check_acosc_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    // Batch-kernel counterpart of `generate_all_acosc_tests!`. Generated tests likewise
    // return the check's `Result` instead of discarding it.
    macro_rules! gen_batch_tests {
        ($fn_name:ident) => {
            paste::paste! {
                #[test]
                fn [<$fn_name _scalar>]() -> Result<(), Box<dyn std::error::Error>> {
                    $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch)
                }

                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$fn_name _avx2>]() -> Result<(), Box<dyn std::error::Error>> {
                    $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch)
                }

                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$fn_name _avx512>]() -> Result<(), Box<dyn std::error::Error>> {
                    $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch)
                }

                #[test]
                fn [<$fn_name _auto_detect>]() -> Result<(), Box<dyn std::error::Error>> {
                    $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto)
                }
            }
        };
    }
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);

    // Non-batch kernels are rejected by the batch entry point with InvalidKernelForBatch.
    #[test]
    fn test_batch_kernel_error() {
        let high = vec![100.0; 50];
        let low = vec![99.0; 50];

        let result = acosc_batch_with_kernel(&high, &low, Kernel::Scalar);
        assert!(result.is_err());

        match result.unwrap_err() {
            AcoscError::InvalidKernelForBatch(kernel) => {
                assert_eq!(kernel, Kernel::Scalar);
            }
            _ => panic!("Expected InvalidKernelForBatch error"),
        }

        let result = acosc_batch_with_kernel(&high, &low, Kernel::Avx2);
        assert!(matches!(
            result,
            Err(AcoscError::InvalidKernelForBatch(Kernel::Avx2))
        ));
    }

    // `acosc_into` writes exactly what the allocating `acosc` API returns (NaN == NaN).
    #[test]
    fn test_acosc_into_matches_api() -> Result<(), Box<dyn Error>> {
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let n = candles.high.len().min(512).max(64);
        let high = &candles.high[..n];
        let low = &candles.low[..n];

        let params = AcoscParams::default();
        let input = AcoscInput::from_slices(high, low, params);

        let base = acosc(&input)?;

        let mut out_osc = vec![0.0; n];
        let mut out_change = vec![0.0; n];

        acosc_into(&input, &mut out_osc, &mut out_change)?;

        assert_eq!(base.osc.len(), out_osc.len());
        assert_eq!(base.change.len(), out_change.len());

        fn eq_or_both_nan(a: f64, b: f64) -> bool {
            (a.is_nan() && b.is_nan()) || (a == b)
        }

        for i in 0..n {
            assert!(
                eq_or_both_nan(base.osc[i], out_osc[i]),
                "osc mismatch at {}: base={} out={}",
                i,
                base.osc[i],
                out_osc[i]
            );
            assert!(
                eq_or_both_nan(base.change[i], out_change[i]),
                "change mismatch at {}: base={} out={}",
                i,
                base.change[i],
                out_change[i]
            );
        }

        Ok(())
    }
}