Skip to main content

vector_ta/indicators/
cci.rs

1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::{PyDict, PyList};
9
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use serde::{Deserialize, Serialize};
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use wasm_bindgen::prelude::*;
14
15#[cfg(all(feature = "python", feature = "cuda"))]
16use crate::cuda::moving_averages::DeviceArrayF32;
17#[cfg(all(feature = "python", feature = "cuda"))]
18use crate::cuda::oscillators::CudaCci;
19use crate::utilities::data_loader::{source_type, Candles};
20#[cfg(all(feature = "python", feature = "cuda"))]
21use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
22use crate::utilities::enums::Kernel;
23use crate::utilities::helpers::{
24    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, make_uninit_matrix,
25};
26#[cfg(feature = "python")]
27use crate::utilities::kernel_validation::validate_kernel;
28#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
29use core::arch::x86_64::*;
30#[cfg(all(feature = "python", feature = "cuda"))]
31use cust::context::Context;
32#[cfg(all(feature = "python", feature = "cuda"))]
33use cust::memory::DeviceBuffer;
34#[cfg(not(target_arch = "wasm32"))]
35use rayon::prelude::*;
36use std::convert::AsRef;
37use std::error::Error;
38use std::mem::MaybeUninit;
39#[cfg(all(feature = "python", feature = "cuda"))]
40use std::sync::Arc;
41use thiserror::Error;
42
43#[derive(Debug, Clone)]
44pub enum CciData<'a> {
45    Candles {
46        candles: &'a Candles,
47        source: &'a str,
48    },
49    Slice(&'a [f64]),
50}
51
/// Result of a single-series CCI computation.
///
/// `values` has the same length as the input; warm-up positions hold NaN.
#[derive(Debug, Clone)]
pub struct CciOutput {
    /// One CCI value per input element.
    pub values: Vec<f64>,
}
56
/// Tunable parameters of the Commodity Channel Index.
///
/// A `period` of `None` is resolved to 14 by the consumers of this type.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct CciParams {
    /// Look-back window length.
    pub period: Option<usize>,
}

impl Default for CciParams {
    /// The conventional 14-period configuration.
    fn default() -> Self {
        CciParams { period: Some(14) }
    }
}
71
72#[derive(Debug, Clone)]
73pub struct CciInput<'a> {
74    pub data: CciData<'a>,
75    pub params: CciParams,
76}
77
78impl<'a> AsRef<[f64]> for CciInput<'a> {
79    #[inline(always)]
80    fn as_ref(&self) -> &[f64] {
81        match &self.data {
82            CciData::Slice(slice) => slice,
83            CciData::Candles { candles, source } => source_type(candles, source),
84        }
85    }
86}
87
88impl<'a> CciInput<'a> {
89    #[inline]
90    pub fn from_candles(c: &'a Candles, s: &'a str, p: CciParams) -> Self {
91        Self {
92            data: CciData::Candles {
93                candles: c,
94                source: s,
95            },
96            params: p,
97        }
98    }
99    #[inline]
100    pub fn from_slice(sl: &'a [f64], p: CciParams) -> Self {
101        Self {
102            data: CciData::Slice(sl),
103            params: p,
104        }
105    }
106    #[inline]
107    pub fn with_default_candles(c: &'a Candles) -> Self {
108        Self::from_candles(c, "hlc3", CciParams::default())
109    }
110    #[inline]
111    pub fn get_period(&self) -> usize {
112        self.params.period.unwrap_or(14)
113    }
114    #[inline]
115    pub fn data_len(&self) -> usize {
116        match &self.data {
117            CciData::Slice(slice) => slice.len(),
118            CciData::Candles { candles, .. } => candles.close.len(),
119        }
120    }
121}
122
123#[derive(Copy, Clone, Debug)]
124pub struct CciBuilder {
125    period: Option<usize>,
126    kernel: Kernel,
127}
128
129impl Default for CciBuilder {
130    fn default() -> Self {
131        Self {
132            period: None,
133            kernel: Kernel::Auto,
134        }
135    }
136}
137
138impl CciBuilder {
139    #[inline(always)]
140    pub fn new() -> Self {
141        Self::default()
142    }
143    #[inline(always)]
144    pub fn period(mut self, n: usize) -> Self {
145        self.period = Some(n);
146        self
147    }
148    #[inline(always)]
149    pub fn kernel(mut self, k: Kernel) -> Self {
150        self.kernel = k;
151        self
152    }
153    #[inline(always)]
154    pub fn apply(self, c: &Candles) -> Result<CciOutput, CciError> {
155        let p = CciParams {
156            period: self.period,
157        };
158        let i = CciInput::from_candles(c, "hlc3", p);
159        cci_with_kernel(&i, self.kernel)
160    }
161    #[inline(always)]
162    pub fn apply_slice(self, d: &[f64]) -> Result<CciOutput, CciError> {
163        let p = CciParams {
164            period: self.period,
165        };
166        let i = CciInput::from_slice(d, p);
167        cci_with_kernel(&i, self.kernel)
168    }
169    #[inline(always)]
170    pub fn into_stream(self) -> Result<CciStream, CciError> {
171        let p = CciParams {
172            period: self.period,
173        };
174        CciStream::try_new(p)
175    }
176}
177
178#[derive(Debug, Error)]
179pub enum CciError {
180    #[error("cci: Input data slice is empty.")]
181    EmptyInputData,
182    #[error("cci: All values are NaN.")]
183    AllValuesNaN,
184    #[error("cci: Invalid period: period = {period}, data length = {data_len}")]
185    InvalidPeriod { period: usize, data_len: usize },
186    #[error("cci: Not enough valid data: needed = {needed}, valid = {valid}")]
187    NotEnoughValidData { needed: usize, valid: usize },
188    #[error("cci: output length mismatch: expected {expected}, got {got}")]
189    OutputLengthMismatch { expected: usize, got: usize },
190    #[error("cci: invalid range expansion: start={start} end={end} step={step}")]
191    InvalidRange {
192        start: usize,
193        end: usize,
194        step: usize,
195    },
196    #[error("cci: invalid kernel for batch path: {0:?}")]
197    InvalidKernelForBatch(crate::utilities::enums::Kernel),
198}
199
200#[inline]
201pub fn cci(input: &CciInput) -> Result<CciOutput, CciError> {
202    cci_with_kernel(input, Kernel::Auto)
203}
204
205#[inline(always)]
206fn cci_prepare<'a>(
207    input: &'a CciInput,
208    kernel: Kernel,
209) -> Result<(&'a [f64], usize, usize, Kernel), CciError> {
210    let data: &[f64] = input.as_ref();
211    let len = data.len();
212    if len == 0 {
213        return Err(CciError::EmptyInputData);
214    }
215    let first = data
216        .iter()
217        .position(|x| !x.is_nan())
218        .ok_or(CciError::AllValuesNaN)?;
219    let period = input.get_period();
220
221    if period == 0 || period > len {
222        return Err(CciError::InvalidPeriod {
223            period,
224            data_len: len,
225        });
226    }
227    if len - first < period {
228        return Err(CciError::NotEnoughValidData {
229            needed: period,
230            valid: len - first,
231        });
232    }
233
234    let chosen = match kernel {
235        Kernel::Auto => detect_best_kernel(),
236        k => k,
237    };
238
239    Ok((data, period, first, chosen))
240}
241
242pub fn cci_with_kernel(input: &CciInput, kernel: Kernel) -> Result<CciOutput, CciError> {
243    let (data, period, first, chosen) = cci_prepare(input, kernel)?;
244
245    let prefix = first + period - 1;
246    let mut out = alloc_with_nan_prefix(data.len(), prefix);
247
248    unsafe {
249        match chosen {
250            Kernel::Scalar | Kernel::ScalarBatch => cci_scalar(data, period, first, &mut out),
251            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
252            Kernel::Avx2 | Kernel::Avx2Batch => cci_avx2(data, period, first, &mut out),
253            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
254            Kernel::Avx512 | Kernel::Avx512Batch => cci_avx512(data, period, first, &mut out),
255            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
256            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
257                cci_scalar(data, period, first, &mut out)
258            }
259            _ => unreachable!(),
260        }
261    }
262    Ok(CciOutput { values: out })
263}
264
265#[inline]
266pub fn cci_into_slice(dst: &mut [f64], input: &CciInput, kern: Kernel) -> Result<(), CciError> {
267    let (data, period, first, chosen) = cci_prepare(input, kern)?;
268
269    if dst.len() != data.len() {
270        return Err(CciError::OutputLengthMismatch {
271            expected: data.len(),
272            got: dst.len(),
273        });
274    }
275
276    unsafe {
277        match chosen {
278            Kernel::Scalar | Kernel::ScalarBatch => cci_scalar(data, period, first, dst),
279            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
280            Kernel::Avx2 | Kernel::Avx2Batch => cci_avx2(data, period, first, dst),
281            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
282            Kernel::Avx512 | Kernel::Avx512Batch => cci_avx512(data, period, first, dst),
283            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
284            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
285                cci_scalar(data, period, first, dst)
286            }
287            _ => unreachable!(),
288        }
289    }
290
291    let warmup_end = first + period - 1;
292    for v in &mut dst[..warmup_end] {
293        *v = f64::NAN;
294    }
295
296    Ok(())
297}
298
299#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
300#[inline]
301pub fn cci_into(input: &CciInput, out: &mut [f64]) -> Result<(), CciError> {
302    cci_into_slice(out, input, Kernel::Auto)
303}
304
/// Scalar Commodity Channel Index kernel.
///
/// For every index `i >= first_valid + period - 1`, writes
/// `(data[i] - sma) / (0.015 * mad)` into `out[i]`. Here `sma` is the mean of the
/// trailing `period` values and `mad` is their mean absolute deviation from `sma`.
/// When the denominator is exactly zero, the output is `0.0`. Slots before
/// `first_valid + period - 1` are not written; callers pre-fill them.
///
/// The window sum is maintained incrementally. If it ever becomes non-finite
/// (a NaN or infinity entered the window), it is rebuilt from the current window
/// on the next step. Outputs therefore recover once the offending value has slid
/// out, instead of staying NaN for the rest of the series. On all-finite data the
/// results are unchanged.
///
/// Expects `data.len() == out.len()`, `period >= 1` and
/// `first_valid + period <= data.len()` for non-empty input (debug-asserted).
#[inline]
pub fn cci_scalar(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    debug_assert_eq!(data.len(), out.len());
    let n = data.len();
    if n == 0 {
        return;
    }
    debug_assert!(period >= 1 && first_valid + period <= n);

    let inv_p = 1.0 / (period as f64);
    let first_out = first_valid + period - 1;

    // Seed the running sum with the first full window.
    let mut sum: f64 = data[first_valid..=first_out].iter().sum();

    for i in first_out..n {
        let window = &data[i + 1 - period..=i];

        if i > first_out {
            sum = if sum.is_finite() {
                // Slide by one: drop the oldest value, add the newest.
                sum - data[i - period] + data[i]
            } else {
                // A NaN/inf poisoned the running sum; rebuild it from the current window
                // so the output recovers once the non-finite value has left the window.
                window.iter().sum()
            };
        }

        let sma = sum * inv_p;
        // The mean absolute deviation has to be rescanned because `sma` moves every step.
        let sum_abs = window.iter().fold(0.0, |acc, &v| acc + (v - sma).abs());
        let denom = 0.015 * (sum_abs * inv_p);
        out[i] = if denom == 0.0 {
            0.0
        } else {
            (data[i] - sma) / denom
        };
    }
}
359
/// AVX-512 entry point for single-series CCI.
///
/// Currently forwards to [`cci_scalar`], so its results are identical to the scalar kernel.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn cci_avx512(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    cci_scalar(data, period, first_valid, out);
}

/// AVX2 entry point for single-series CCI.
///
/// Currently forwards to [`cci_scalar`], so its results are identical to the scalar kernel.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn cci_avx2(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    cci_scalar(data, period, first_valid, out);
}
371
/// AVX2 variant of the CCI kernel.
///
/// The public `cci_avx2` wrapper currently forwards to `cci_scalar` rather than
/// calling this function. It computes the initial window sum unrolled by four and
/// then maintains it incrementally. `|x - sma|` is produced four lanes at a time
/// and accumulated with Kahan compensation, so results may differ from
/// `cci_scalar` in the last bits. If the running sum becomes non-finite (NaN or inf
/// in the window), it is rebuilt from the current window. This lets outputs recover
/// after the bad value has left the window.
///
/// # Safety
/// The CPU must support AVX2 and FMA. The caller must also guarantee
/// `data.len() == out.len()`, `period >= 1` and `first_valid + period <= data.len()`.
/// Only `out[first_valid + period - 1..]` is written.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
unsafe fn cci_avx2_impl(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    /// One Kahan-compensated accumulation step.
    #[inline(always)]
    fn kahan_add(sum: &mut f64, comp: &mut f64, val: f64) {
        let y = val - *comp;
        let t = *sum + y;
        *comp = (t - *sum) - y;
        *sum = t;
    }

    debug_assert!(data.len() == out.len());
    debug_assert!(period >= 1 && first_valid + period <= data.len());

    let n = data.len();
    let inv_p = 1.0 / (period as f64);
    let first_out = first_valid + period - 1;

    // Initial window sum, unrolled by four (left-to-right summation order).
    let base = data.as_ptr().add(first_valid);
    let mut sum = 0.0;
    let mut k = 0usize;
    while k + 4 <= period {
        sum = sum + *base.add(k) + *base.add(k + 1) + *base.add(k + 2) + *base.add(k + 3);
        k += 4;
    }
    while k < period {
        sum += *base.add(k);
        k += 1;
    }

    // Loop-invariant sign-bit mask: andnot(-0.0, x) clears the sign, giving |x|.
    let vsgn = _mm256_set1_pd(-0.0f64);

    for i in first_out..n {
        let wptr = data.as_ptr().add(i + 1 - period);

        if i > first_out {
            sum = if sum.is_finite() {
                sum - *data.get_unchecked(i - period) + *data.get_unchecked(i)
            } else {
                // Running sum was poisoned by a non-finite value; rebuild from the window.
                let mut s = 0.0;
                for j in 0..period {
                    s += *wptr.add(j);
                }
                s
            };
        }

        let sma = sum * inv_p;
        let vmean = _mm256_set1_pd(sma);
        let mut sum_abs = 0.0f64;
        let mut comp = 0.0f64;
        let mut k = 0usize;
        while k + 4 <= period {
            let d = _mm256_sub_pd(_mm256_loadu_pd(wptr.add(k)), vmean);
            let a = _mm256_andnot_pd(vsgn, d);
            let mut lane = [0.0f64; 4];
            _mm256_storeu_pd(lane.as_mut_ptr(), a);
            for &val in &lane {
                kahan_add(&mut sum_abs, &mut comp, val);
            }
            k += 4;
        }
        while k < period {
            kahan_add(&mut sum_abs, &mut comp, (*wptr.add(k) - sma).abs());
            k += 1;
        }

        let price = *data.get_unchecked(i);
        let denom = 0.015 * (sum_abs * inv_p);
        *out.get_unchecked_mut(i) = if denom == 0.0 {
            0.0
        } else {
            (price - sma) / denom
        };
    }
}
488
/// Short-period AVX-512 CCI entry point; currently forwards to [`cci_scalar`].
///
/// # Safety
/// The body only calls the safe `cci_scalar`. No requirements are added beyond
/// its preconditions; the `unsafe` marker presumably keeps signature parity with
/// the other kernel entry points.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn cci_avx512_short(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    cci_scalar(data, period, first_valid, out);
}

/// Long-period AVX-512 CCI entry point; currently forwards to [`cci_scalar`].
///
/// # Safety
/// Same as `cci_avx512_short`: the body only calls the safe `cci_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn cci_avx512_long(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    cci_scalar(data, period, first_valid, out);
}
500
/// AVX-512 variant of the CCI kernel.
///
/// The public `cci_avx512` wrapper currently forwards to `cci_scalar` rather than
/// calling this function. The structure matches the AVX2 variant: an initial window
/// sum unrolled by four, an incremental rolling sum, and `|x - sma|` produced eight
/// lanes at a time with Kahan-compensated accumulation. A non-finite running sum is
/// rebuilt from the current window, so outputs recover after NaN/inf values leave it.
///
/// The absolute value is taken with an integer AND on the bit pattern
/// (`_mm512_and_si512`, AVX-512F). `_mm512_and_pd` would need AVX-512DQ, which this
/// function does not enable.
///
/// # Safety
/// The CPU must support AVX-512F and FMA. The caller must also guarantee
/// `data.len() == out.len()`, `period >= 1` and `first_valid + period <= data.len()`.
/// Only `out[first_valid + period - 1..]` is written.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
unsafe fn cci_avx512_impl(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    /// One Kahan-compensated accumulation step.
    #[inline(always)]
    fn kahan_add(sum: &mut f64, comp: &mut f64, val: f64) {
        let y = val - *comp;
        let t = *sum + y;
        *comp = (t - *sum) - y;
        *sum = t;
    }

    debug_assert!(data.len() == out.len());
    debug_assert!(period >= 1 && first_valid + period <= data.len());

    let n = data.len();
    let inv_p = 1.0 / (period as f64);
    let first_out = first_valid + period - 1;

    // Initial window sum, unrolled by four (left-to-right summation order).
    let base = data.as_ptr().add(first_valid);
    let mut sum = 0.0;
    let mut k = 0usize;
    while k + 4 <= period {
        sum = sum + *base.add(k) + *base.add(k + 1) + *base.add(k + 2) + *base.add(k + 3);
        k += 4;
    }
    while k < period {
        sum += *base.add(k);
        k += 1;
    }

    // All bits except the sign bit; AND-ing with it yields |x|.
    let abs_mask = _mm512_set1_epi64(0x7FFF_FFFF_FFFF_FFFFu64 as i64);

    for i in first_out..n {
        let wptr = data.as_ptr().add(i + 1 - period);

        if i > first_out {
            sum = if sum.is_finite() {
                sum - *data.get_unchecked(i - period) + *data.get_unchecked(i)
            } else {
                // Running sum was poisoned by a non-finite value; rebuild from the window.
                let mut s = 0.0;
                for j in 0..period {
                    s += *wptr.add(j);
                }
                s
            };
        }

        let sma = sum * inv_p;
        let vmean = _mm512_set1_pd(sma);
        let mut sum_abs = 0.0f64;
        let mut comp = 0.0f64;
        let mut k = 0usize;
        while k + 8 <= period {
            let d = _mm512_sub_pd(_mm512_loadu_pd(wptr.add(k)), vmean);
            let a = _mm512_castsi512_pd(_mm512_and_si512(_mm512_castpd_si512(d), abs_mask));
            let mut lane = [0.0f64; 8];
            _mm512_storeu_pd(lane.as_mut_ptr(), a);
            for &val in &lane {
                kahan_add(&mut sum_abs, &mut comp, val);
            }
            k += 8;
        }
        while k < period {
            kahan_add(&mut sum_abs, &mut comp, (*wptr.add(k) - sma).abs());
            k += 1;
        }

        let price = *data.get_unchecked(i);
        let denom = 0.015 * (sum_abs * inv_p);
        *out.get_unchecked_mut(i) = if denom == 0.0 {
            0.0
        } else {
            (price - sma) / denom
        };
    }
}
616
617#[inline(always)]
618pub unsafe fn cci_row_scalar(
619    data: &[f64],
620    first: usize,
621    period: usize,
622    _stride: usize,
623    _w_ptr: *const f64,
624    _inv: f64,
625    out: &mut [f64],
626) {
627    cci_scalar(data, period, first, out)
628}
629
/// Batch-row adapter for the AVX2 entry point (`cci_avx2`).
///
/// `_stride`, `_w_ptr` and `_inv` are ignored.
///
/// # Safety
/// The body only calls the safe `cci_avx2`; callers need only meet its preconditions.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn cci_row_avx2(
    data: &[f64],
    first: usize,
    period: usize,
    _stride: usize,
    _w_ptr: *const f64,
    _inv: f64,
    out: &mut [f64],
) {
    let (first_valid, window) = (first, period);
    cci_avx2(data, window, first_valid, out);
}

/// Batch-row adapter for the AVX-512 entry point (`cci_avx512`).
///
/// `_stride`, `_w_ptr` and `_inv` are ignored.
///
/// # Safety
/// The body only calls the safe `cci_avx512`; callers need only meet its preconditions.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn cci_row_avx512(
    data: &[f64],
    first: usize,
    period: usize,
    _stride: usize,
    _w_ptr: *const f64,
    _inv: f64,
    out: &mut [f64],
) {
    let (first_valid, window) = (first, period);
    cci_avx512(data, window, first_valid, out);
}

/// Batch-row adapter for `cci_avx512_short`.
///
/// `_stride`, `_w_ptr` and `_inv` are ignored.
///
/// # Safety
/// Inherits the requirements of `cci_avx512_short`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn cci_row_avx512_short(
    data: &[f64],
    first: usize,
    period: usize,
    _stride: usize,
    _w_ptr: *const f64,
    _inv: f64,
    out: &mut [f64],
) {
    let (first_valid, window) = (first, period);
    cci_avx512_short(data, window, first_valid, out);
}

/// Batch-row adapter for `cci_avx512_long`.
///
/// `_stride`, `_w_ptr` and `_inv` are ignored.
///
/// # Safety
/// Inherits the requirements of `cci_avx512_long`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn cci_row_avx512_long(
    data: &[f64],
    first: usize,
    period: usize,
    _stride: usize,
    _w_ptr: *const f64,
    _inv: f64,
    out: &mut [f64],
) {
    let (first_valid, window) = (first, period);
    cci_avx512_long(data, window, first_valid, out);
}
685
686#[derive(Debug, Clone)]
687pub struct CciStream {
688    period: usize,
689    buffer: Vec<f64>,
690    head: usize,
691    filled: bool,
692
693    sum: f64,
694
695    scale: f64,
696
697    ost: OrderStatsTreap,
698}
699
/// One SplitMix64 step, used as a tiny deterministic PRNG for treap priorities.
#[inline(always)]
fn splitmix64(mut x: u64) -> u64 {
    x = x.wrapping_add(0x9E3779B97F4A7C15);
    let mut z = x;
    z = (z ^ (z >> 30)).wrapping_mul(0xBF58476D1CE4E5B9);
    z = (z ^ (z >> 27)).wrapping_mul(0x94D049BB133111EB);
    z ^ (z >> 31)
}

/// Treap-backed multiset of `f64` keys that answers "how many keys are `<= x`,
/// and what is their sum" in expected O(log n).
///
/// Keys are expected to be finite (debug-asserted). Equal keys are usually folded
/// into one node via `cnt`, but a key can also appear in several nodes. When a
/// split-insert places a new node above an existing equal key, the older node
/// stays in the left subtree (left keys `<=` node key, right keys `>` node key).
/// Counts and sums remain correct in either case.
#[derive(Debug, Clone)]
struct OrderStatsTreap {
    root: Option<Box<Node>>,
    /// PRNG state for node priorities.
    seed: u64,
}

/// A treap node storing a key with multiplicity and subtree aggregates.
#[derive(Debug, Clone)]
struct Node {
    key: f64,
    /// Heap priority (max-heap: parents have larger priorities on merge).
    prio: u64,
    /// Multiplicity of `key` in this node.
    cnt: u32,
    /// Total multiplicity in this subtree.
    size: usize,
    /// Sum of `key * cnt` over this subtree.
    sum: f64,
    left: Option<Box<Node>>,
    right: Option<Box<Node>>,
}

/// Subtree size of an optional node (0 for `None`).
#[inline(always)]
fn sz(t: &Option<Box<Node>>) -> usize {
    t.as_ref().map(|n| n.size).unwrap_or(0)
}
/// Subtree key sum of an optional node (0.0 for `None`).
#[inline(always)]
fn sm(t: &Option<Box<Node>>) -> f64 {
    t.as_ref().map(|n| n.sum).unwrap_or(0.0)
}

impl Node {
    #[inline(always)]
    fn new(key: f64, prio: u64) -> Self {
        Self {
            key,
            prio,
            cnt: 1,
            size: 1,
            sum: key,
            left: None,
            right: None,
        }
    }
    /// Recomputes `size` and `sum` from this node and its children.
    #[inline(always)]
    fn recalc(&mut self) {
        self.size = self.cnt as usize + sz(&self.left) + sz(&self.right);
        self.sum = (self.key * self.cnt as f64) + sm(&self.left) + sm(&self.right);
    }
}

impl OrderStatsTreap {
    /// Creates an empty treap.
    ///
    /// NOTE(review): the seed mixes in the address of a zero-sized value. That
    /// address presumably carries no entropy, so the priority sequence is
    /// effectively fixed per build. This is deterministic, but not randomised
    /// against adversarial input.
    #[inline(always)]
    fn new() -> Self {
        let seed = splitmix64((&() as *const () as usize as u64) ^ 0xA5A5_A5A5_A5A5_A5A5);
        Self { root: None, seed }
    }

    /// Advances the PRNG and returns the next priority.
    #[inline(always)]
    fn next_prio(&mut self) -> u64 {
        self.seed = splitmix64(self.seed);
        self.seed
    }

    /// Joins two treaps where every key in `a` is `<=` every key in `b`.
    fn merge(a: Option<Box<Node>>, b: Option<Box<Node>>) -> Option<Box<Node>> {
        match (a, b) {
            (None, t) | (t, None) => t,
            (Some(mut x), Some(mut y)) => {
                if x.prio > y.prio {
                    x.right = Self::merge(x.right.take(), Some(y));
                    x.recalc();
                    Some(x)
                } else {
                    y.left = Self::merge(Some(x), y.left.take());
                    y.recalc();
                    Some(y)
                }
            }
        }
    }

    /// Splits `t` into (keys `<= key`, keys `> key`).
    fn split(t: Option<Box<Node>>, key: f64) -> (Option<Box<Node>>, Option<Box<Node>>) {
        match t {
            None => (None, None),
            Some(mut n) => {
                if key < n.key {
                    let (l, r) = Self::split(n.left.take(), key);
                    n.left = r;
                    n.recalc();
                    (l, Some(n))
                } else {
                    let (l, r) = Self::split(n.right.take(), key);
                    n.right = l;
                    n.recalc();
                    (Some(n), r)
                }
            }
        }
    }

    /// Inserts one occurrence of `key`.
    fn insert(&mut self, key: f64) {
        debug_assert!(key.is_finite());
        // Exactly one priority is drawn per insert, whether or not the tree is empty.
        let prio = self.next_prio();
        self.root = Self::insert_into(self.root.take(), key, prio);
    }

    fn insert_into(t: Option<Box<Node>>, key: f64, prio: u64) -> Option<Box<Node>> {
        match t {
            None => Some(Box::new(Node::new(key, prio))),
            Some(mut n) => {
                if key == n.key {
                    n.cnt += 1;
                    n.recalc();
                    Some(n)
                } else if prio > n.prio {
                    // The new node outranks `n`: it becomes this subtree's root.
                    let (l, r) = Self::split(Some(n), key);
                    let mut m = Box::new(Node::new(key, prio));
                    m.left = l;
                    m.right = r;
                    m.recalc();
                    Some(m)
                } else if key < n.key {
                    n.left = Self::insert_into(n.left.take(), key, prio);
                    n.recalc();
                    Some(n)
                } else {
                    n.right = Self::insert_into(n.right.take(), key, prio);
                    n.recalc();
                    Some(n)
                }
            }
        }
    }

    /// Removes one occurrence of `key`; absent keys are ignored.
    fn erase(&mut self, key: f64) {
        debug_assert!(key.is_finite());
        self.root = Self::erase_from(self.root.take(), key);
    }

    fn erase_from(t: Option<Box<Node>>, key: f64) -> Option<Box<Node>> {
        match t {
            None => None,
            Some(mut n) => {
                if key == n.key {
                    if n.cnt > 1 {
                        n.cnt -= 1;
                        n.recalc();
                        Some(n)
                    } else {
                        Self::merge(n.left.take(), n.right.take())
                    }
                } else if key < n.key {
                    n.left = Self::erase_from(n.left.take(), key);
                    n.recalc();
                    Some(n)
                } else {
                    n.right = Self::erase_from(n.right.take(), key);
                    n.recalc();
                    Some(n)
                }
            }
        }
    }

    /// Returns `(count, sum)` of all stored keys `<= key`.
    fn prefix_le(&self, key: f64) -> (usize, f64) {
        debug_assert!(key.is_finite());
        let mut t = &self.root;
        let mut count = 0usize;
        let mut sum = 0.0f64;

        while let Some(n) = t.as_ref() {
            if key < n.key {
                t = &n.left;
            } else {
                count += sz(&n.left) + n.cnt as usize;
                sum += sm(&n.left) + (n.key * n.cnt as f64);
                t = &n.right;
            }
        }
        (count, sum)
    }

    /// Total number of stored occurrences.
    #[inline(always)]
    fn size(&self) -> usize {
        sz(&self.root)
    }
}
896
897impl CciStream {
898    pub fn try_new(params: CciParams) -> Result<Self, CciError> {
899        let period = params.period.unwrap_or(14);
900        if period == 0 {
901            return Err(CciError::InvalidPeriod {
902                period,
903                data_len: 0,
904            });
905        }
906
907        let buffer = alloc_with_nan_prefix(period, period);
908        let scale = (period as f64) * (1.0 / 0.015);
909
910        Ok(Self {
911            period,
912            buffer,
913            head: 0,
914            filled: false,
915            sum: 0.0,
916            scale,
917            ost: OrderStatsTreap::new(),
918        })
919    }
920
921    #[inline(always)]
922    pub fn update(&mut self, value: f64) -> Option<f64> {
923        debug_assert!(value.is_finite(), "CCI stream expects finite inputs");
924
925        let old = self.buffer[self.head];
926        self.buffer[self.head] = value;
927        self.head = (self.head + 1) % self.period;
928        if !self.filled && self.head == 0 {
929            self.filled = true;
930        }
931
932        if !self.filled {
933            self.sum += value;
934            self.ost.insert(value);
935            return None;
936        }
937
938        if !old.is_nan() {
939            self.sum -= old;
940            self.ost.erase(old);
941        }
942        self.sum += value;
943        self.ost.insert(value);
944
945        let mean = self.sum / (self.period as f64);
946
947        let (k_le, sum_le) = self.ost.prefix_le(mean);
948
949        let n = self.period as f64;
950        let sum_abs = mean.mul_add(2.0 * (k_le as f64) - n, self.sum - 2.0 * sum_le);
951
952        if sum_abs == 0.0 {
953            Some(0.0)
954        } else {
955            Some((value - mean) * (self.scale / sum_abs))
956        }
957    }
958}
959
/// Inclusive `(start, end, step)` sweep of CCI periods for batch evaluation.
#[derive(Clone, Debug)]
pub struct CciBatchRange {
    pub period: (usize, usize, usize),
}

impl Default for CciBatchRange {
    /// Periods 14 through 263 in steps of 1.
    fn default() -> Self {
        CciBatchRange {
            period: (14, 263, 1),
        }
    }
}
972
973#[derive(Clone, Debug, Default)]
974pub struct CciBatchBuilder {
975    range: CciBatchRange,
976    kernel: Kernel,
977}
978
979impl CciBatchBuilder {
980    pub fn new() -> Self {
981        Self::default()
982    }
983    pub fn kernel(mut self, k: Kernel) -> Self {
984        self.kernel = k;
985        self
986    }
987    #[inline]
988    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
989        self.range.period = (start, end, step);
990        self
991    }
992    #[inline]
993    pub fn period_static(mut self, p: usize) -> Self {
994        self.range.period = (p, p, 0);
995        self
996    }
997    pub fn apply_slice(self, data: &[f64]) -> Result<CciBatchOutput, CciError> {
998        cci_batch_with_kernel(data, &self.range, self.kernel)
999    }
1000    pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<CciBatchOutput, CciError> {
1001        CciBatchBuilder::new().kernel(k).apply_slice(data)
1002    }
1003    pub fn apply_candles(self, c: &Candles, src: &str) -> Result<CciBatchOutput, CciError> {
1004        let slice = source_type(c, src);
1005        self.apply_slice(slice)
1006    }
1007    pub fn with_default_candles(c: &Candles) -> Result<CciBatchOutput, CciError> {
1008        CciBatchBuilder::new()
1009            .kernel(Kernel::Auto)
1010            .apply_candles(c, "hlc3")
1011    }
1012}
1013
1014#[derive(Clone, Debug)]
1015pub struct CciBatchOutput {
1016    pub values: Vec<f64>,
1017    pub combos: Vec<CciParams>,
1018    pub rows: usize,
1019    pub cols: usize,
1020}
1021impl CciBatchOutput {
1022    pub fn row_for_params(&self, p: &CciParams) -> Option<usize> {
1023        self.combos
1024            .iter()
1025            .position(|c| c.period.unwrap_or(14) == p.period.unwrap_or(14))
1026    }
1027    pub fn values_for(&self, p: &CciParams) -> Option<&[f64]> {
1028        self.row_for_params(p).map(|row| {
1029            let start = row * self.cols;
1030            &self.values[start..start + self.cols]
1031        })
1032    }
1033}
1034
1035#[inline(always)]
1036fn expand_grid(r: &CciBatchRange) -> Result<Vec<CciParams>, CciError> {
1037    fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, CciError> {
1038        if step == 0 || start == end {
1039            return Ok(vec![start]);
1040        }
1041        if start < end {
1042            let mut v = Vec::new();
1043            let mut cur = start;
1044            while cur <= end {
1045                v.push(cur);
1046                match cur.checked_add(step) {
1047                    Some(next) => {
1048                        if next == cur {
1049                            break;
1050                        }
1051                        cur = next;
1052                    }
1053                    None => break,
1054                }
1055            }
1056            if v.is_empty() {
1057                return Err(CciError::InvalidRange { start, end, step });
1058            }
1059            Ok(v)
1060        } else {
1061            let mut v = Vec::new();
1062            let mut cur = start;
1063            loop {
1064                v.push(cur);
1065                if cur <= end {
1066                    break;
1067                }
1068
1069                if cur < step {
1070                    break;
1071                }
1072                cur -= step;
1073                if cur == end {
1074                    v.push(cur);
1075                    break;
1076                }
1077                if cur < end {
1078                    break;
1079                }
1080            }
1081            v.sort_unstable();
1082            v.dedup();
1083            if v.is_empty() {
1084                return Err(CciError::InvalidRange { start, end, step });
1085            }
1086            Ok(v)
1087        }
1088    }
1089    let periods = axis_usize(r.period)?;
1090    let mut out = Vec::with_capacity(periods.len());
1091    for &p in &periods {
1092        out.push(CciParams { period: Some(p) });
1093    }
1094    Ok(out)
1095}
1096
1097pub fn cci_batch_with_kernel(
1098    data: &[f64],
1099    sweep: &CciBatchRange,
1100    k: Kernel,
1101) -> Result<CciBatchOutput, CciError> {
1102    let kernel = match k {
1103        Kernel::Auto => detect_best_batch_kernel(),
1104        other if other.is_batch() => other,
1105        other => return Err(CciError::InvalidKernelForBatch(other)),
1106    };
1107    let simd = match kernel {
1108        Kernel::Avx512Batch => Kernel::Avx512,
1109        Kernel::Avx2Batch => Kernel::Avx2,
1110        Kernel::ScalarBatch => Kernel::Scalar,
1111        _ => unreachable!(),
1112    };
1113    cci_batch_par_slice(data, sweep, simd)
1114}
1115
1116#[inline(always)]
1117pub fn cci_batch_slice(
1118    data: &[f64],
1119    sweep: &CciBatchRange,
1120    kern: Kernel,
1121) -> Result<CciBatchOutput, CciError> {
1122    cci_batch_inner(data, sweep, kern, false)
1123}
1124#[inline(always)]
1125pub fn cci_batch_par_slice(
1126    data: &[f64],
1127    sweep: &CciBatchRange,
1128    kern: Kernel,
1129) -> Result<CciBatchOutput, CciError> {
1130    cci_batch_inner(data, sweep, kern, true)
1131}
1132
1133#[inline(always)]
1134fn cci_batch_inner(
1135    data: &[f64],
1136    sweep: &CciBatchRange,
1137    kern: Kernel,
1138    parallel: bool,
1139) -> Result<CciBatchOutput, CciError> {
1140    if data.is_empty() {
1141        return Err(CciError::EmptyInputData);
1142    }
1143    let combos = expand_grid(sweep)?;
1144    let cols = data.len();
1145    let rows = combos.len();
1146
1147    let first = data
1148        .iter()
1149        .position(|x| !x.is_nan())
1150        .ok_or(CciError::AllValuesNaN)?;
1151    let mut max_p = 0usize;
1152    for c in &combos {
1153        let p = c.period.unwrap();
1154        if p == 0 || p > cols {
1155            return Err(CciError::InvalidPeriod {
1156                period: p,
1157                data_len: cols,
1158            });
1159        }
1160        max_p = max_p.max(p);
1161    }
1162    if cols - first < max_p {
1163        return Err(CciError::NotEnoughValidData {
1164            needed: max_p,
1165            valid: cols - first,
1166        });
1167    }
1168
1169    rows.checked_mul(cols).ok_or(CciError::InvalidRange {
1170        start: sweep.period.0,
1171        end: sweep.period.1,
1172        step: sweep.period.2,
1173    })?;
1174
1175    let mut buf_mu = make_uninit_matrix(rows, cols);
1176    let out_ptr = buf_mu.as_mut_ptr() as *mut f64;
1177    let out_len = buf_mu.len();
1178    let out_cap = buf_mu.capacity();
1179    let out: &mut [f64] = unsafe { core::slice::from_raw_parts_mut(out_ptr, out_len) };
1180
1181    let _ = cci_batch_inner_into(data, sweep, kern, parallel, out)?;
1182
1183    let values = unsafe { Vec::from_raw_parts(out_ptr, out_len, out_cap) };
1184    core::mem::forget(buf_mu);
1185
1186    Ok(CciBatchOutput {
1187        values,
1188        combos,
1189        rows,
1190        cols,
1191    })
1192}
1193
1194#[inline(always)]
1195fn cci_batch_inner_into(
1196    data: &[f64],
1197    sweep: &CciBatchRange,
1198    kern: Kernel,
1199    parallel: bool,
1200    out: &mut [f64],
1201) -> Result<Vec<CciParams>, CciError> {
1202    let combos = expand_grid(sweep)?;
1203    if combos.is_empty() {
1204        return Err(CciError::InvalidRange {
1205            start: sweep.period.0,
1206            end: sweep.period.1,
1207            step: sweep.period.2,
1208        });
1209    }
1210
1211    if data.is_empty() {
1212        return Err(CciError::EmptyInputData);
1213    }
1214
1215    let first = data
1216        .iter()
1217        .position(|x| !x.is_nan())
1218        .ok_or(CciError::AllValuesNaN)?;
1219    let mut max_p = 0usize;
1220    for c in &combos {
1221        let p = c.period.unwrap();
1222        if p == 0 || p > data.len() {
1223            return Err(CciError::InvalidPeriod {
1224                period: p,
1225                data_len: data.len(),
1226            });
1227        }
1228        max_p = max_p.max(p);
1229    }
1230    if data.len() - first < max_p {
1231        return Err(CciError::NotEnoughValidData {
1232            needed: max_p,
1233            valid: data.len() - first,
1234        });
1235    }
1236
1237    let rows = combos.len();
1238    let cols = data.len();
1239
1240    let expected = rows.checked_mul(cols).ok_or(CciError::InvalidRange {
1241        start: sweep.period.0,
1242        end: sweep.period.1,
1243        step: sweep.period.2,
1244    })?;
1245    if out.len() != expected {
1246        return Err(CciError::OutputLengthMismatch {
1247            expected,
1248            got: out.len(),
1249        });
1250    }
1251
1252    let kernel = match kern {
1253        Kernel::Auto => detect_best_batch_kernel(),
1254        Kernel::Scalar => Kernel::ScalarBatch,
1255        Kernel::Avx2 => Kernel::Avx2Batch,
1256        Kernel::Avx512 => Kernel::Avx512Batch,
1257        k => k,
1258    };
1259    let simd = match kernel {
1260        Kernel::Avx512Batch => Kernel::Avx512,
1261        Kernel::Avx2Batch => Kernel::Avx2,
1262        Kernel::ScalarBatch => Kernel::Scalar,
1263        _ => unreachable!(),
1264    };
1265
1266    let warm: Vec<usize> = combos
1267        .iter()
1268        .map(|c| first + c.period.unwrap() - 1)
1269        .collect();
1270
1271    let out_uninit = unsafe {
1272        std::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
1273    };
1274
1275    for (row, &warmup) in warm.iter().enumerate() {
1276        let row_start = row * cols;
1277        for col in 0..warmup.min(cols) {
1278            out_uninit[row_start + col].write(f64::NAN);
1279        }
1280    }
1281
1282    let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
1283        let period = combos[row].period.unwrap();
1284
1285        let dst = core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
1286
1287        assert_eq!(dst.len(), cols, "Output row length mismatch");
1288
1289        match simd {
1290            Kernel::Scalar => cci_row_scalar(data, first, period, 0, std::ptr::null(), 0.0, dst),
1291            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1292            Kernel::Avx2 => cci_row_avx2(data, first, period, 0, std::ptr::null(), 0.0, dst),
1293            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1294            Kernel::Avx512 => cci_row_avx512(data, first, period, 0, std::ptr::null(), 0.0, dst),
1295            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
1296            Kernel::Avx2 | Kernel::Avx512 => {
1297                cci_row_scalar(data, first, period, 0, std::ptr::null(), 0.0, dst)
1298            }
1299            _ => unreachable!(),
1300        }
1301    };
1302
1303    #[cfg(not(target_arch = "wasm32"))]
1304    if parallel && rows > 1 {
1305        out_uninit
1306            .par_chunks_mut(cols)
1307            .enumerate()
1308            .for_each(|(row, out_row)| do_row(row, out_row));
1309    } else {
1310        out_uninit
1311            .chunks_mut(cols)
1312            .enumerate()
1313            .for_each(|(row, out_row)| do_row(row, out_row));
1314    }
1315
1316    #[cfg(target_arch = "wasm32")]
1317    {
1318        out_uninit
1319            .chunks_mut(cols)
1320            .enumerate()
1321            .for_each(|(row, out_row)| do_row(row, out_row));
1322    }
1323
1324    Ok(combos)
1325}
1326
1327#[cfg(test)]
1328mod tests {
1329    use super::*;
1330    use crate::skip_if_unsupported;
1331    use crate::utilities::data_loader::read_candles_from_csv;
1332    use paste::paste;
1333    #[cfg(feature = "proptest")]
1334    use proptest::prelude::*;
1335
1336    fn check_cci_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1337        skip_if_unsupported!(kernel, test_name);
1338        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1339        let candles = read_candles_from_csv(file_path)?;
1340
1341        let default_params = CciParams { period: None };
1342        let input_default = CciInput::from_candles(&candles, "close", default_params);
1343        let output_default = cci_with_kernel(&input_default, kernel)?;
1344        assert_eq!(output_default.values.len(), candles.close.len());
1345
1346        let params_20 = CciParams { period: Some(20) };
1347        let input_20 = CciInput::from_candles(&candles, "hl2", params_20);
1348        let output_20 = cci_with_kernel(&input_20, kernel)?;
1349        assert_eq!(output_20.values.len(), candles.close.len());
1350
1351        let params_custom = CciParams { period: Some(9) };
1352        let input_custom = CciInput::from_candles(&candles, "hlc3", params_custom);
1353        let output_custom = cci_with_kernel(&input_custom, kernel)?;
1354        assert_eq!(output_custom.values.len(), candles.close.len());
1355        Ok(())
1356    }
1357
1358    fn check_cci_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1359        skip_if_unsupported!(kernel, test_name);
1360        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1361        let candles = read_candles_from_csv(file_path)?;
1362        let input = CciInput::with_default_candles(&candles);
1363        let cci_result = cci_with_kernel(&input, kernel)?;
1364        assert_eq!(cci_result.values.len(), candles.close.len());
1365
1366        let expected_last_five_cci = [
1367            -51.55252564125841,
1368            -43.50326506381541,
1369            -64.05117302269149,
1370            -39.05150631680948,
1371            -152.50523930896998,
1372        ];
1373
1374        let start_idx = cci_result.values.len() - 5;
1375        let last_five_cci = &cci_result.values[start_idx..];
1376        for (i, &value) in last_five_cci.iter().enumerate() {
1377            let expected = expected_last_five_cci[i];
1378            assert!(
1379                (value - expected).abs() < 1e-6,
1380                "[{}] CCI mismatch at last five index {}: expected {}, got {}",
1381                test_name,
1382                i,
1383                expected,
1384                value
1385            );
1386        }
1387        let period: usize = input.get_period();
1388        for i in 0..(period - 1) {
1389            assert!(
1390                cci_result.values[i].is_nan(),
1391                "Expected NaN at index {} for initial period warm-up",
1392                i
1393            );
1394        }
1395        Ok(())
1396    }
1397
1398    fn check_cci_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1399        skip_if_unsupported!(kernel, test_name);
1400        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1401        let candles = read_candles_from_csv(file_path)?;
1402        let input = CciInput::with_default_candles(&candles);
1403
1404        match input.data {
1405            CciData::Candles { source, .. } => {
1406                assert_eq!(source, "hlc3", "Expected default source to be 'hlc3'");
1407            }
1408            _ => panic!("Expected CciData::Candles variant"),
1409        }
1410        let output = cci_with_kernel(&input, kernel)?;
1411        assert_eq!(output.values.len(), candles.close.len());
1412        Ok(())
1413    }
1414
1415    fn check_cci_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1416        skip_if_unsupported!(kernel, test_name);
1417        let input_data = [10.0, 20.0, 30.0];
1418        let params = CciParams { period: Some(0) };
1419        let input = CciInput::from_slice(&input_data, params);
1420        let res = cci_with_kernel(&input, kernel);
1421        assert!(
1422            res.is_err(),
1423            "[{}] CCI should fail with zero period",
1424            test_name
1425        );
1426        Ok(())
1427    }
1428
1429    fn check_cci_period_exceeds_length(
1430        test_name: &str,
1431        kernel: Kernel,
1432    ) -> Result<(), Box<dyn Error>> {
1433        skip_if_unsupported!(kernel, test_name);
1434        let data_small = [10.0, 20.0, 30.0];
1435        let params = CciParams { period: Some(10) };
1436        let input = CciInput::from_slice(&data_small, params);
1437        let res = cci_with_kernel(&input, kernel);
1438        assert!(
1439            res.is_err(),
1440            "[{}] CCI should fail with period exceeding length",
1441            test_name
1442        );
1443        Ok(())
1444    }
1445
1446    fn check_cci_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1447        skip_if_unsupported!(kernel, test_name);
1448        let single_point = [42.0];
1449        let params = CciParams { period: Some(9) };
1450        let input = CciInput::from_slice(&single_point, params);
1451        let res = cci_with_kernel(&input, kernel);
1452        assert!(
1453            res.is_err(),
1454            "[{}] CCI should fail with insufficient data",
1455            test_name
1456        );
1457        Ok(())
1458    }
1459
1460    fn check_cci_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1461        skip_if_unsupported!(kernel, test_name);
1462        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1463        let candles = read_candles_from_csv(file_path)?;
1464
1465        let first_params = CciParams { period: Some(14) };
1466        let first_input = CciInput::from_candles(&candles, "close", first_params);
1467        let first_result = cci_with_kernel(&first_input, kernel)?;
1468
1469        let second_params = CciParams { period: Some(14) };
1470        let second_input = CciInput::from_slice(&first_result.values, second_params);
1471        let second_result = cci_with_kernel(&second_input, kernel)?;
1472
1473        assert_eq!(second_result.values.len(), first_result.values.len());
1474        if second_result.values.len() > 28 {
1475            for i in 28..second_result.values.len() {
1476                assert!(
1477                    !second_result.values[i].is_nan(),
1478                    "Expected no NaN after index 28, found NaN at index {}",
1479                    i
1480                );
1481            }
1482        }
1483        Ok(())
1484    }
1485
1486    fn check_cci_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1487        skip_if_unsupported!(kernel, test_name);
1488        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1489        let candles = read_candles_from_csv(file_path)?;
1490        let input = CciInput::from_candles(&candles, "close", CciParams { period: Some(14) });
1491        let res = cci_with_kernel(&input, kernel)?;
1492        assert_eq!(res.values.len(), candles.close.len());
1493        if res.values.len() > 240 {
1494            for (i, &val) in res.values[240..].iter().enumerate() {
1495                assert!(
1496                    !val.is_nan(),
1497                    "[{}] Found unexpected NaN at out-index {}",
1498                    test_name,
1499                    240 + i
1500                );
1501            }
1502        }
1503        Ok(())
1504    }
1505
1506    fn check_cci_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1507        skip_if_unsupported!(kernel, test_name);
1508        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1509        let candles = read_candles_from_csv(file_path)?;
1510
1511        let period = 14;
1512        let input = CciInput::from_candles(
1513            &candles,
1514            "close",
1515            CciParams {
1516                period: Some(period),
1517            },
1518        );
1519        let batch_output = cci_with_kernel(&input, kernel)?.values;
1520
1521        let mut stream = CciStream::try_new(CciParams {
1522            period: Some(period),
1523        })?;
1524
1525        let mut stream_values = Vec::with_capacity(candles.close.len());
1526        for &price in &candles.close {
1527            match stream.update(price) {
1528                Some(cci_val) => stream_values.push(cci_val),
1529                None => stream_values.push(f64::NAN),
1530            }
1531        }
1532
1533        assert_eq!(batch_output.len(), stream_values.len());
1534        for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
1535            if b.is_nan() && s.is_nan() {
1536                continue;
1537            }
1538            let diff = (b - s).abs();
1539            assert!(
1540                diff < 1e-9,
1541                "[{}] CCI streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
1542                test_name,
1543                i,
1544                b,
1545                s,
1546                diff
1547            );
1548        }
1549        Ok(())
1550    }
1551
1552    fn check_cci_empty_input(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1553        skip_if_unsupported!(kernel, test_name);
1554        let empty_data: &[f64] = &[];
1555        let params = CciParams { period: Some(14) };
1556        let input = CciInput::from_slice(empty_data, params);
1557        let res = cci_with_kernel(&input, kernel);
1558        assert!(
1559            res.is_err(),
1560            "[{}] CCI should fail with empty input data",
1561            test_name
1562        );
1563        if let Err(e) = res {
1564            match e {
1565                CciError::EmptyInputData => {}
1566                other => panic!(
1567                    "[{}] Expected EmptyInputData error, got: {:?}",
1568                    test_name, other
1569                ),
1570            }
1571        }
1572        Ok(())
1573    }
1574
1575    #[cfg(debug_assertions)]
1576    fn check_cci_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1577        skip_if_unsupported!(kernel, test_name);
1578
1579        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1580        let candles = read_candles_from_csv(file_path)?;
1581
1582        let input = CciInput::from_candles(&candles, "close", CciParams::default());
1583        let output = cci_with_kernel(&input, kernel)?;
1584
1585        let params_20 = CciParams { period: Some(20) };
1586        let input_20 = CciInput::from_candles(&candles, "hlc3", params_20);
1587        let output_20 = cci_with_kernel(&input_20, kernel)?;
1588
1589        for output in [output, output_20] {
1590            for (i, &val) in output.values.iter().enumerate() {
1591                if val.is_nan() {
1592                    continue;
1593                }
1594
1595                let bits = val.to_bits();
1596
1597                if bits == 0x11111111_11111111 {
1598                    panic!(
1599                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {}",
1600                        test_name, val, bits, i
1601                    );
1602                }
1603
1604                if bits == 0x22222222_22222222 {
1605                    panic!(
1606                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {}",
1607                        test_name, val, bits, i
1608                    );
1609                }
1610
1611                if bits == 0x33333333_33333333 {
1612                    panic!(
1613                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {}",
1614                        test_name, val, bits, i
1615                    );
1616                }
1617            }
1618        }
1619
1620        Ok(())
1621    }
1622
1623    #[cfg(not(debug_assertions))]
1624    fn check_cci_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1625        Ok(())
1626    }
1627
1628    #[cfg(feature = "proptest")]
1629    #[allow(clippy::float_cmp)]
1630    fn check_cci_property(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1631        use proptest::prelude::*;
1632        skip_if_unsupported!(kernel, test_name);
1633
1634        let strat = (1usize..=64).prop_flat_map(|period| {
1635            (
1636                prop::collection::vec(
1637                    (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
1638                    period..400,
1639                ),
1640                Just(period),
1641            )
1642        });
1643
1644        proptest::test_runner::TestRunner::default().run(&strat, |(data, period)| {
1645            let params = CciParams {
1646                period: Some(period),
1647            };
1648            let input = CciInput::from_slice(&data, params);
1649
1650            let CciOutput { values: out } = cci_with_kernel(&input, kernel).unwrap();
1651
1652            let CciOutput { values: ref_out } = cci_with_kernel(&input, Kernel::Scalar).unwrap();
1653
1654            for i in 0..(period - 1) {
1655                prop_assert!(
1656                    out[i].is_nan(),
1657                    "[{}] Expected NaN at index {} during warmup period, got {}",
1658                    test_name,
1659                    i,
1660                    out[i]
1661                );
1662            }
1663
1664            for i in (period - 1)..data.len() {
1665                prop_assert!(
1666                    !out[i].is_nan(),
1667                    "[{}] Expected valid value at index {} after warmup, got NaN",
1668                    test_name,
1669                    i
1670                );
1671            }
1672
1673            for i in 0..data.len() {
1674                let y = out[i];
1675                let r = ref_out[i];
1676
1677                if y.is_nan() && r.is_nan() {
1678                    continue;
1679                }
1680
1681                let y_bits = y.to_bits();
1682                let r_bits = r.to_bits();
1683                let ulp_diff = if y_bits > r_bits {
1684                    y_bits - r_bits
1685                } else {
1686                    r_bits - y_bits
1687                };
1688
1689                prop_assert!(
1690                    ulp_diff <= 8,
1691                    "[{}] Kernel mismatch at index {}: {} != {} (ULP diff: {})",
1692                    test_name,
1693                    i,
1694                    y,
1695                    r,
1696                    ulp_diff
1697                );
1698            }
1699
1700            if data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10) && data.len() >= period {
1701                for i in (period - 1)..data.len() {
1702                    prop_assert!(
1703                        out[i].abs() < 1e-9,
1704                        "[{}] CCI should be ~0 for constant prices, got {} at index {}",
1705                        test_name,
1706                        out[i],
1707                        i
1708                    );
1709                }
1710            }
1711
1712            for i in (period - 1)..data.len() {
1713                let window_start = i + 1 - period;
1714                let window = &data[window_start..=i];
1715
1716                let sum: f64 = window.iter().sum();
1717                let sma = sum / period as f64;
1718
1719                let mad: f64 = window.iter().map(|&x| (x - sma).abs()).sum::<f64>() / period as f64;
1720
1721                let price = data[i];
1722                let expected_cci = if mad == 0.0 {
1723                    0.0
1724                } else {
1725                    (price - sma) / (0.015 * mad)
1726                };
1727
1728                let actual_cci = out[i];
1729                let diff = (actual_cci - expected_cci).abs();
1730
1731                prop_assert!(
1732                    diff < 1e-10,
1733                    "[{}] CCI calculation mismatch at index {}: expected {}, got {}, diff {}",
1734                    test_name,
1735                    i,
1736                    expected_cci,
1737                    actual_cci,
1738                    diff
1739                );
1740            }
1741
1742            if period == 1 {
1743                for i in 0..data.len() {
1744                    prop_assert!(
1745                        out[i].abs() < 1e-9,
1746                        "[{}] CCI should be ~0 for period=1, got {} at index {}",
1747                        test_name,
1748                        out[i],
1749                        i
1750                    );
1751                }
1752            }
1753
1754            for i in (period - 1)..data.len() {
1755                if out[i].abs() > 500.0 {
1756                    eprintln!(
1757							"[{}] Warning: Extreme CCI value {} at index {} (typical range is -300 to 300)",
1758							test_name,
1759							out[i],
1760							i
1761						);
1762                }
1763            }
1764
1765            for i in (period - 1)..data.len() {
1766                let window_start = i + 1 - period;
1767                let window = &data[window_start..=i];
1768                let sum: f64 = window.iter().sum();
1769                let sma = sum / period as f64;
1770                let mad: f64 = window.iter().map(|&x| (x - sma).abs()).sum::<f64>() / period as f64;
1771
1772                if mad > 0.0 && mad < 1e-12 {
1773                    let actual_cci = out[i];
1774
1775                    prop_assert!(
1776							actual_cci.is_finite(),
1777							"[{}] CCI should be finite even with very small MAD ({}) at index {}, got {}",
1778							test_name,
1779							mad,
1780							i,
1781							actual_cci
1782						);
1783
1784                    let price = data[i];
1785                    let expected_cci = (price - sma) / (0.015 * mad);
1786                    let relative_error = ((actual_cci - expected_cci) / expected_cci).abs();
1787
1788                    prop_assert!(
1789							relative_error < 1e-8 || (actual_cci - expected_cci).abs() < 1e-10,
1790							"[{}] CCI calculation with small MAD at index {}: expected {}, got {}, relative error {}",
1791							test_name,
1792							i,
1793							expected_cci,
1794							actual_cci,
1795							relative_error
1796						);
1797                }
1798            }
1799
1800            Ok(())
1801        })?;
1802
1803        Ok(())
1804    }
1805
1806    macro_rules! generate_all_cci_tests {
1807        ($($test_fn:ident),*) => {
1808            paste! {
1809                $(
1810                    #[test]
1811                    fn [<$test_fn _scalar_f64>]() {
1812                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1813                    }
1814                )*
1815                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1816                $(
1817                    #[test]
1818                    fn [<$test_fn _avx2_f64>]() {
1819                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1820                    }
1821                    #[test]
1822                    fn [<$test_fn _avx512_f64>]() {
1823                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1824                    }
1825                )*
1826            }
1827        }
1828    }
1829
1830    generate_all_cci_tests!(
1831        check_cci_partial_params,
1832        check_cci_accuracy,
1833        check_cci_default_candles,
1834        check_cci_zero_period,
1835        check_cci_period_exceeds_length,
1836        check_cci_very_small_dataset,
1837        check_cci_reinput,
1838        check_cci_nan_handling,
1839        check_cci_streaming,
1840        check_cci_empty_input,
1841        check_cci_no_poison
1842    );
1843
1844    #[cfg(feature = "proptest")]
1845    generate_all_cci_tests!(check_cci_property);
1846
1847    #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1848    #[test]
1849    fn test_cci_into_matches_api() -> Result<(), Box<dyn Error>> {
1850        let mut data = Vec::with_capacity(256);
1851        data.extend_from_slice(&[f64::NAN, f64::NAN, f64::NAN]);
1852        for i in 0..253usize {
1853            let x = (i as f64 * 0.037).sin() * 5.0 + 100.0 + (i % 7) as f64 * 0.1;
1854            data.push(x);
1855        }
1856
1857        let input = CciInput::from_slice(&data, CciParams::default());
1858
1859        let baseline = cci(&input)?.values;
1860
1861        let mut out = vec![0.0; data.len()];
1862        cci_into(&input, &mut out)?;
1863
1864        assert_eq!(baseline.len(), out.len());
1865
1866        fn eq_or_nan(a: f64, b: f64) -> bool {
1867            (a.is_nan() && b.is_nan()) || (a == b) || ((a - b).abs() <= 1e-12)
1868        }
1869
1870        for (i, (a, b)) in baseline.iter().zip(out.iter()).enumerate() {
1871            assert!(
1872                eq_or_nan(*a, *b),
1873                "cci_into mismatch at idx {}: baseline={}, into={}",
1874                i,
1875                a,
1876                b
1877            );
1878        }
1879
1880        Ok(())
1881    }
1882
1883    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1884        skip_if_unsupported!(kernel, test);
1885        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1886        let c = read_candles_from_csv(file)?;
1887
1888        let output = CciBatchBuilder::new()
1889            .kernel(kernel)
1890            .apply_candles(&c, "hlc3")?;
1891
1892        let def = CciParams::default();
1893        let row = output.values_for(&def).expect("default row missing");
1894        assert_eq!(row.len(), c.close.len());
1895
1896        let expected = [
1897            -51.55252564125841,
1898            -43.50326506381541,
1899            -64.05117302269149,
1900            -39.05150631680948,
1901            -152.50523930896998,
1902        ];
1903        let start = row.len() - 5;
1904        for (i, &v) in row[start..].iter().enumerate() {
1905            assert!(
1906                (v - expected[i]).abs() < 1e-6,
1907                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
1908            );
1909        }
1910        Ok(())
1911    }
1912
1913    #[cfg(debug_assertions)]
1914    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1915        skip_if_unsupported!(kernel, test);
1916
1917        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1918        let c = read_candles_from_csv(file)?;
1919
1920        let output = CciBatchBuilder::new()
1921            .kernel(kernel)
1922            .period_range(10, 30, 10)
1923            .apply_candles(&c, "close")?;
1924
1925        for (idx, &val) in output.values.iter().enumerate() {
1926            if val.is_nan() {
1927                continue;
1928            }
1929
1930            let bits = val.to_bits();
1931            let row = idx / output.cols;
1932            let col = idx % output.cols;
1933
1934            if bits == 0x11111111_11111111 {
1935                panic!(
1936					"[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at row {} col {} (flat index {})",
1937					test, val, bits, row, col, idx
1938				);
1939            }
1940
1941            if bits == 0x22222222_22222222 {
1942                panic!(
1943					"[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at row {} col {} (flat index {})",
1944					test, val, bits, row, col, idx
1945				);
1946            }
1947
1948            if bits == 0x33333333_33333333 {
1949                panic!(
1950					"[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at row {} col {} (flat index {})",
1951					test, val, bits, row, col, idx
1952				);
1953            }
1954        }
1955
1956        Ok(())
1957    }
1958
1959    #[cfg(not(debug_assertions))]
1960    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1961        Ok(())
1962    }
1963
1964    macro_rules! gen_batch_tests {
1965        ($fn_name:ident) => {
1966            paste! {
1967                #[test] fn [<$fn_name _scalar>]()      {
1968                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
1969                }
1970                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1971                #[test] fn [<$fn_name _avx2>]()        {
1972                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
1973                }
1974                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1975                #[test] fn [<$fn_name _avx512>]()      {
1976                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
1977                }
1978                #[test] fn [<$fn_name _auto_detect>]() {
1979                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]),
1980                                     Kernel::Auto);
1981                }
1982            }
1983        };
1984    }
1985    gen_batch_tests!(check_batch_default_row);
1986    gen_batch_tests!(check_batch_no_poison);
1987}
1988
/// Python entry point `cci(data, period, kernel=None)`.
///
/// Computes CCI over a contiguous 1-D float64 array with the requested kernel. The
/// computation runs with the GIL released. Any indicator error is raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "cci")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn cci_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    // Non-contiguous input and an unknown kernel name surface here, in this order.
    let prices = data.as_slice()?;
    let chosen = validate_kernel(kernel, false)?;

    let input = CciInput::from_slice(
        prices,
        CciParams {
            period: Some(period),
        },
    );

    let computed = py.allow_threads(|| cci_with_kernel(&input, chosen));
    match computed {
        Ok(output) => Ok(output.values.into_pyarray(py)),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
2014
/// Python class `CciStream`: a thin wrapper that owns a native [`CciStream`] and feeds it
/// one value per `update` call.
#[cfg(feature = "python")]
#[pyclass(name = "CciStream")]
pub struct CciStreamPy {
    // The wrapped native streaming state.
    stream: CciStream,
}
2020
#[cfg(feature = "python")]
#[pymethods]
impl CciStreamPy {
    // Python constructor `CciStream(period)`. Parameter validation errors from
    // `CciStream::try_new` are raised as `ValueError`.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        match CciStream::try_new(CciParams {
            period: Some(period),
        }) {
            Ok(stream) => Ok(Self { stream }),
            Err(e) => Err(PyValueError::new_err(e.to_string())),
        }
    }

    // Pushes one sample into the stream. Returns whatever the native stream yields for
    // it (`None` maps to Python `None`).
    fn update(&mut self, value: f64) -> Option<f64> {
        let Self { stream } = self;
        stream.update(value)
    }
}
2038
/// Python entry point `cci_batch(data, period_range, kernel=None)`.
///
/// Sweeps `period_range = (start, end, step)` over `data` with the GIL released.
/// Returns a dict with:
/// - `values`: a `(rows, cols)` float64 array, one row per period combination;
/// - `periods`: a uint64 array holding the period used for each row.
///
/// Indicator errors are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "cci_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn cci_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let prices = data.as_slice()?;
    let chosen = validate_kernel(kernel, true)?;
    let sweep = CciBatchRange {
        period: period_range,
    };

    let batch = match py.allow_threads(|| cci_batch_with_kernel(prices, &sweep, chosen)) {
        Ok(b) => b,
        Err(e) => return Err(PyValueError::new_err(e.to_string())),
    };

    let (rows, cols) = (batch.rows, batch.cols);
    let grid = batch.values.into_pyarray(py).reshape((rows, cols))?;

    let result = PyDict::new(py);
    result.set_item("values", grid)?;

    // Every combo produced by the sweep carries a concrete period.
    let periods: Vec<u64> = batch
        .combos
        .iter()
        .map(|combo| combo.period.unwrap() as u64)
        .collect();
    result.set_item("periods", periods.into_pyarray(py))?;

    Ok(result)
}
2079
/// Python-visible handle (`ta_indicators.cuda.CciDeviceArrayF32`) to a 2-D f32 CCI result
/// that lives in CUDA device memory.
///
/// It is exported to other libraries through `__cuda_array_interface__` and `__dlpack__`.
/// Declared `unsendable`, so pyo3 restricts access to the thread that created the object.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", name = "CciDeviceArrayF32", unsendable)]
pub struct CciDeviceArrayF32Py {
    // Device buffer plus its (rows, cols) shape.
    pub(crate) inner: DeviceArrayF32,
    // Shared CUDA context handle, held but never read in this file. Presumably it keeps the
    // context alive as long as `inner` exists — NOTE(review): confirm.
    pub(crate) _ctx: Arc<Context>,
    // Device ordinal reported by `__dlpack_device__`.
    pub(crate) device_id: u32,
    // Raw stream handle captured at construction (from `stream_handle_usize()`); not read
    // by the methods in this file.
    pub(crate) stream: usize,
}
2088
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl CciDeviceArrayF32Py {
    // CUDA Array Interface (version 3) description of the device buffer:
    // row-major `(rows, cols)` little-endian f32 data, byte strides `(cols*4, 4)`, and
    // `data = (device_ptr, false)`, where `false` marks the buffer as writable.
    // NOTE(review): no `stream` key is emitted. The `cci_cuda_*_dev` constructors in this
    // file synchronize their stream before returning; confirm that holds for any other
    // producer of this type.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let inner = &self.inner;
        let d = PyDict::new(py);
        d.set_item("shape", (inner.rows, inner.cols))?;
        d.set_item("typestr", "<f4")?;
        d.set_item(
            "strides",
            (
                inner.cols * std::mem::size_of::<f32>(),
                std::mem::size_of::<f32>(),
            ),
        )?;
        d.set_item("data", (inner.device_ptr() as usize, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    // DLPack device tuple: device type 2 (kDLCUDA in the DLPack enum) and this
    // buffer's device ordinal.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.device_id as i32)
    }

    // DLPack export. This method MOVES the device buffer out of `self`, so after a
    // successful call the object holds an empty (0 x 0) array.
    //
    // - `dl_device`: if given as an `(i32, i32)` tuple that differs from
    //   `__dlpack_device__()`, raises `ValueError`. The message depends on whether
    //   `copy` is truthy; cross-device copies are not implemented. Values that fail to
    //   extract as a tuple are ignored.
    // - `stream`: accepted for protocol compatibility and ignored.
    // - `max_version`: forwarded unchanged to `export_f32_cuda_dlpack_2d`.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    pub fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        let (kdl, alloc_dev) = self.__dlpack_device__();
        if let Some(dev_obj) = dl_device.as_ref() {
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(PyValueError::new_err(
                            "device copy not implemented for __dlpack__",
                        ));
                    } else {
                        return Err(PyValueError::new_err("dl_device mismatch for __dlpack__"));
                    }
                }
            }
        }
        // The consumer's stream is not used for any synchronization here.
        let _ = stream;

        // Swap in an empty placeholder so ownership of the real buffer can be handed to
        // the DLPack capsule. `self` stays valid but is empty from now on.
        let dummy =
            DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let inner = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32 {
                buf: dummy,
                rows: 0,
                cols: 0,
            },
        );

        let rows = inner.rows;
        let cols = inner.cols;
        let buf = inner.buf;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
/// Python entry point `cci_cuda_batch_dev(data, period_range, device_id=0)`.
///
/// Runs the CCI period sweep on the GPU over an f32 series and returns the result as a
/// device-resident `CciDeviceArrayF32`. The CUDA stream is synchronized before returning.
/// Raises `ValueError` when CUDA is unavailable or any CUDA step fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "cci_cuda_batch_dev")]
#[pyo3(signature = (data, period_range, device_id=0))]
pub fn cci_cuda_batch_dev_py(
    py: Python<'_>,
    data: numpy::PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<CciDeviceArrayF32Py> {
    use crate::cuda::cuda_available;

    // Uniform conversion of any displayable CUDA error into a Python ValueError.
    fn to_py_err<E: std::fmt::Display>(e: E) -> PyErr {
        PyValueError::new_err(e.to_string())
    }

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let prices = data.as_slice()?;
    let sweep = CciBatchRange {
        period: period_range,
    };

    let launched = py.allow_threads(|| -> PyResult<_> {
        let engine = CudaCci::new(device_id).map_err(to_py_err)?;
        let ordinal = engine.device_id();
        let context = engine.context_arc();
        let result = engine.cci_batch_dev(prices, &sweep).map_err(to_py_err)?;
        engine.stream().synchronize().map_err(to_py_err)?;
        Ok((result, ordinal, context, engine.stream_handle_usize()))
    });
    let (inner, ordinal, context, stream) = launched?;

    Ok(CciDeviceArrayF32Py {
        inner,
        _ctx: context,
        device_id: ordinal,
        stream,
    })
}
2200
/// Python entry point `cci_cuda_many_series_one_param_dev(data_tm, cols, rows, period, device_id=0)`.
///
/// Computes CCI with a single `period` across many series on the GPU. The flat f32 input
/// is interpreted with the given `cols` and `rows` by the time-major CUDA routine.
/// Returns a device-resident `CciDeviceArrayF32` after synchronizing the CUDA stream.
/// Raises `ValueError` when CUDA is unavailable or any CUDA step fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "cci_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm, cols, rows, period, device_id=0))]
pub fn cci_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm: numpy::PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    period: usize,
    device_id: usize,
) -> PyResult<CciDeviceArrayF32Py> {
    use crate::cuda::cuda_available;

    // Uniform conversion of any displayable CUDA error into a Python ValueError.
    fn to_py_err<E: std::fmt::Display>(e: E) -> PyErr {
        PyValueError::new_err(e.to_string())
    }

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let flat = data_tm.as_slice()?;

    let launched = py.allow_threads(|| -> PyResult<_> {
        let engine = CudaCci::new(device_id).map_err(to_py_err)?;
        let ordinal = engine.device_id();
        let context = engine.context_arc();
        let result = engine
            .cci_many_series_one_param_time_major_dev(flat, cols, rows, period)
            .map_err(to_py_err)?;
        engine.stream().synchronize().map_err(to_py_err)?;
        Ok((result, ordinal, context, engine.stream_handle_usize()))
    });
    let (inner, ordinal, context, stream) = launched?;

    Ok(CciDeviceArrayF32Py {
        inner,
        _ctx: context,
        device_id: ordinal,
        stream,
    })
}
2236
/// WASM entry point: computes CCI for `data` with `period` using the best detected kernel
/// and returns a freshly allocated output vector of the same length.
///
/// Indicator errors are returned as a JS string value.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cci_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = CciInput::from_slice(
        data,
        CciParams {
            period: Some(period),
        },
    );
    let mut values = vec![0.0; data.len()];

    match cci_into_slice(&mut values, &input, detect_best_kernel()) {
        Ok(_) => Ok(values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
2252
/// WASM raw-pointer entry point: computes CCI over `len` f64s at `in_ptr` and writes `len`
/// results to `out_ptr`.
///
/// The input and output buffers may overlap, either fully (in-place) or partially. In that
/// case the result is computed into a scratch buffer first and then copied out, so the
/// input is never read after being overwritten and no overlapping `&`/`&mut` slices exist.
///
/// # Errors
/// - Either pointer is null.
/// - `period` is 0 or exceeds `len`.
/// - Any error from the underlying CCI computation, as a JS string.
///
/// # Safety (caller contract)
/// Both pointers must reference at least `len` valid, properly aligned f64 slots in WASM
/// linear memory (for example, allocations from `cci_alloc`).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cci_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to cci_into"));
    }
    if period == 0 || period > len {
        return Err(JsValue::from_str("Invalid period"));
    }

    // Byte ranges [in, in+bytes) and [out, out+bytes) overlap iff each starts before the
    // other ends. This detects partial overlap, not just `in_ptr == out_ptr`.
    let byte_len = len.saturating_mul(std::mem::size_of::<f64>());
    let in_start = in_ptr as usize;
    let out_start = out_ptr as usize;
    let overlaps = in_start < out_start.saturating_add(byte_len)
        && out_start < in_start.saturating_add(byte_len);

    let params = CciParams {
        period: Some(period),
    };

    // SAFETY: the caller guarantees both pointers address `len` valid f64s (see the doc
    // comment). When the ranges overlap, the shared input slice is dropped at the end of
    // the inner block before the mutable output slice is created.
    unsafe {
        if overlaps {
            let computed = {
                let data = std::slice::from_raw_parts(in_ptr, len);
                let input = CciInput::from_slice(data, params);
                let mut temp = vec![0.0; len];
                cci_into_slice(&mut temp, &input, detect_best_kernel())
                    .map_err(|e| JsValue::from_str(&e.to_string()))?;
                temp
            };
            std::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&computed);
        } else {
            let data = std::slice::from_raw_parts(in_ptr, len);
            let input = CciInput::from_slice(data, params);
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            cci_into_slice(out, &input, detect_best_kernel())
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(())
}
2292
/// WASM helper: reserves uninitialized space for `len` f64 values and returns its pointer.
///
/// The allocation is intentionally leaked. Ownership passes to the JS caller, which must
/// release it with `cci_free(ptr, len)` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cci_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` suppresses the destructor, so the buffer outlives this call.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
2301
/// WASM helper: releases a buffer previously obtained from `cci_alloc(len)`.
///
/// `len` must be the same value passed to `cci_alloc`. A null `ptr` is a no-op.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cci_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` was allocated by `cci_alloc(len)` via `Vec::<f64>::with_capacity(len)`,
    // which guarantees capacity exactly `len`. The vector's length is 0 because the slots
    // may never have been initialized; `f64` has no destructor, so only the allocation is
    // freed.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
2309
/// WASM entry point: runs the CCI period sweep `(period_start, period_end, period_step)`
/// over `data` serially, with the best detected kernel.
///
/// Returns the flattened result values. Errors are returned as a JS string value.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cci_batch_js(
    data: &[f64],
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let range = CciBatchRange {
        period: (period_start, period_end, period_step),
    };

    match cci_batch_inner(data, &range, detect_best_kernel(), false) {
        Ok(batch) => Ok(batch.values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
2326
/// WASM entry point: lists the periods (as f64, one per output row) that the sweep
/// `(period_start, period_end, period_step)` expands to.
///
/// Grid-expansion errors are returned as a JS string value.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cci_batch_metadata_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let range = CciBatchRange {
        period: (period_start, period_end, period_step),
    };

    let grid = expand_grid(&range).map_err(|e| JsValue::from_str(&e.to_string()))?;

    // Every expanded combo carries a concrete period.
    Ok(grid
        .into_iter()
        .map(|combo| combo.period.unwrap() as f64)
        .collect())
}
2347
/// Configuration object accepted from JS by `cci_batch`, deserialized via
/// `serde_wasm_bindgen`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct CciBatchConfig {
    // Period sweep as `(start, end, step)`, passed straight through as `CciBatchRange::period`.
    pub period_range: (usize, usize, usize),
}
2353
/// Result object returned to JS by `cci_batch`, serialized via `serde_wasm_bindgen`.
/// Field names form the JS-facing contract.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct CciBatchJsOutput {
    // Flattened `rows * cols` result grid, one row per entry of `combos` (the Python
    // binding reshapes the same data to `(rows, cols)`).
    pub values: Vec<f64>,
    // Parameter set used for each row, in row order.
    pub combos: Vec<CciParams>,
    // Number of parameter combinations.
    pub rows: usize,
    // Number of values per row (the input length).
    pub cols: usize,
}
2362
/// WASM entry point exported to JS as `cci_batch(data, config)`.
///
/// `config` must deserialize into a `CciBatchConfig`. Runs the sweep serially with the best
/// detected kernel and returns a serialized `CciBatchJsOutput` (values, combos, rows, cols).
/// Config, computation and serialization failures are returned as JS string values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = cci_batch)]
pub fn cci_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let CciBatchConfig { period_range } =
        match serde_wasm_bindgen::from_value::<CciBatchConfig>(config) {
            Ok(parsed) => parsed,
            Err(e) => return Err(JsValue::from_str(&format!("Invalid config: {}", e))),
        };

    let range = CciBatchRange {
        period: period_range,
    };

    let batch = match cci_batch_inner(data, &range, detect_best_kernel(), false) {
        Ok(b) => b,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };

    let payload = CciBatchJsOutput {
        rows: batch.rows,
        cols: batch.cols,
        values: batch.values,
        combos: batch.combos,
    };

    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2386
/// WASM raw-pointer entry point for the CCI period sweep.
///
/// Reads `len` f64s from `in_ptr` and writes a `rows * len` grid to `out_ptr`, one row per
/// period combination of `(period_start, period_end, period_step)`. Returns `rows`.
///
/// If the input range overlaps the output range, the input is first copied into a private
/// buffer. Otherwise, writing early rows would clobber input that later rows still read,
/// and overlapping `&`/`&mut` slices would be created.
///
/// # Errors
/// - Either pointer is null.
/// - The sweep is invalid.
/// - `rows * len` overflows `usize`.
/// - The underlying batch computation fails.
///
/// # Safety (caller contract)
/// `in_ptr` must reference `len` valid f64s. `out_ptr` must reference `rows * len` valid,
/// writable f64s, where `rows` is the number of combos the sweep expands to (see
/// `cci_batch_metadata_js`).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cci_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to cci_batch_into"));
    }

    let sweep = CciBatchRange {
        period: (period_start, period_end, period_step),
    };

    let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let cols = len;
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("cci_batch_into: rows * cols overflows usize"))?;

    // Byte ranges [in, in+in_bytes) and [out, out+out_bytes) overlap iff each starts
    // before the other ends.
    let elem = std::mem::size_of::<f64>();
    let in_start = in_ptr as usize;
    let out_start = out_ptr as usize;
    let overlaps = in_start < out_start.saturating_add(total.saturating_mul(elem))
        && out_start < in_start.saturating_add(len.saturating_mul(elem));

    // SAFETY: the caller guarantees `in_ptr` addresses `len` f64s and `out_ptr` addresses
    // `rows * len` f64s (see the doc comment). With overlap, the input is copied out and
    // its borrowed slice is dead before the mutable output slice is created.
    unsafe {
        let owned_copy: Vec<f64>;
        let data: &[f64] = if overlaps {
            owned_copy = std::slice::from_raw_parts(in_ptr, len).to_vec();
            &owned_copy
        } else {
            std::slice::from_raw_parts(in_ptr, len)
        };

        let out = std::slice::from_raw_parts_mut(out_ptr, total);

        cci_batch_inner_into(data, &sweep, detect_best_kernel(), false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(rows)
}