1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::{PyDict, PyList};
9
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use serde::{Deserialize, Serialize};
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use wasm_bindgen::prelude::*;
14
15#[cfg(all(feature = "python", feature = "cuda"))]
16use crate::cuda::moving_averages::DeviceArrayF32;
17#[cfg(all(feature = "python", feature = "cuda"))]
18use crate::cuda::oscillators::CudaCci;
19use crate::utilities::data_loader::{source_type, Candles};
20#[cfg(all(feature = "python", feature = "cuda"))]
21use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
22use crate::utilities::enums::Kernel;
23use crate::utilities::helpers::{
24 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, make_uninit_matrix,
25};
26#[cfg(feature = "python")]
27use crate::utilities::kernel_validation::validate_kernel;
28#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
29use core::arch::x86_64::*;
30#[cfg(all(feature = "python", feature = "cuda"))]
31use cust::context::Context;
32#[cfg(all(feature = "python", feature = "cuda"))]
33use cust::memory::DeviceBuffer;
34#[cfg(not(target_arch = "wasm32"))]
35use rayon::prelude::*;
36use std::convert::AsRef;
37use std::error::Error;
38use std::mem::MaybeUninit;
39#[cfg(all(feature = "python", feature = "cuda"))]
40use std::sync::Arc;
41use thiserror::Error;
42
/// Where the price series for a CCI computation comes from.
#[derive(Debug, Clone)]
pub enum CciData<'a> {
    /// A candle set plus the name of the series to read from it; resolved through
    /// `source_type` (e.g. `"hlc3"`, the default used by the builders).
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A raw price slice used as-is.
    Slice(&'a [f64]),
}

/// Result of a single-period CCI computation.
#[derive(Debug, Clone)]
pub struct CciOutput {
    /// One value per input element. The leading `first_valid + period - 1` entries are
    /// the warmup prefix allocated by `alloc_with_nan_prefix`.
    pub values: Vec<f64>,
}
56
/// Tunable parameters for CCI.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct CciParams {
    /// Look-back window length. `None` falls back to 14 wherever it is read.
    pub period: Option<usize>,
}

impl Default for CciParams {
    /// A 14-bar window.
    fn default() -> Self {
        Self { period: Some(14) }
    }
}

/// Everything a single CCI run needs: the data source and its parameters.
#[derive(Debug, Clone)]
pub struct CciInput<'a> {
    pub data: CciData<'a>,
    pub params: CciParams,
}
77
78impl<'a> AsRef<[f64]> for CciInput<'a> {
79 #[inline(always)]
80 fn as_ref(&self) -> &[f64] {
81 match &self.data {
82 CciData::Slice(slice) => slice,
83 CciData::Candles { candles, source } => source_type(candles, source),
84 }
85 }
86}
87
88impl<'a> CciInput<'a> {
89 #[inline]
90 pub fn from_candles(c: &'a Candles, s: &'a str, p: CciParams) -> Self {
91 Self {
92 data: CciData::Candles {
93 candles: c,
94 source: s,
95 },
96 params: p,
97 }
98 }
99 #[inline]
100 pub fn from_slice(sl: &'a [f64], p: CciParams) -> Self {
101 Self {
102 data: CciData::Slice(sl),
103 params: p,
104 }
105 }
106 #[inline]
107 pub fn with_default_candles(c: &'a Candles) -> Self {
108 Self::from_candles(c, "hlc3", CciParams::default())
109 }
110 #[inline]
111 pub fn get_period(&self) -> usize {
112 self.params.period.unwrap_or(14)
113 }
114 #[inline]
115 pub fn data_len(&self) -> usize {
116 match &self.data {
117 CciData::Slice(slice) => slice.len(),
118 CciData::Candles { candles, .. } => candles.close.len(),
119 }
120 }
121}
122
123#[derive(Copy, Clone, Debug)]
124pub struct CciBuilder {
125 period: Option<usize>,
126 kernel: Kernel,
127}
128
129impl Default for CciBuilder {
130 fn default() -> Self {
131 Self {
132 period: None,
133 kernel: Kernel::Auto,
134 }
135 }
136}
137
138impl CciBuilder {
139 #[inline(always)]
140 pub fn new() -> Self {
141 Self::default()
142 }
143 #[inline(always)]
144 pub fn period(mut self, n: usize) -> Self {
145 self.period = Some(n);
146 self
147 }
148 #[inline(always)]
149 pub fn kernel(mut self, k: Kernel) -> Self {
150 self.kernel = k;
151 self
152 }
153 #[inline(always)]
154 pub fn apply(self, c: &Candles) -> Result<CciOutput, CciError> {
155 let p = CciParams {
156 period: self.period,
157 };
158 let i = CciInput::from_candles(c, "hlc3", p);
159 cci_with_kernel(&i, self.kernel)
160 }
161 #[inline(always)]
162 pub fn apply_slice(self, d: &[f64]) -> Result<CciOutput, CciError> {
163 let p = CciParams {
164 period: self.period,
165 };
166 let i = CciInput::from_slice(d, p);
167 cci_with_kernel(&i, self.kernel)
168 }
169 #[inline(always)]
170 pub fn into_stream(self) -> Result<CciStream, CciError> {
171 let p = CciParams {
172 period: self.period,
173 };
174 CciStream::try_new(p)
175 }
176}
177
/// Errors produced by the CCI entry points in this module.
#[derive(Debug, Error)]
pub enum CciError {
    /// The input series has length zero.
    #[error("cci: Input data slice is empty.")]
    EmptyInputData,
    /// Every input value is NaN, so no first valid index exists.
    #[error("cci: All values are NaN.")]
    AllValuesNaN,
    /// `period` is zero or larger than the data length. `CciStream::try_new` reports
    /// `data_len: 0` because it has no data yet.
    #[error("cci: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },
    /// Fewer than `needed` values remain from the first non-NaN value onward.
    #[error("cci: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// A caller-provided output buffer does not have the expected length.
    #[error("cci: output length mismatch: expected {expected}, got {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// A batch period sweep produced no combos, or `rows * cols` overflowed.
    #[error("cci: invalid range expansion: start={start} end={end} step={step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },
    /// A non-batch (and non-`Auto`) kernel was given to the batch entry point.
    #[error("cci: invalid kernel for batch path: {0:?}")]
    InvalidKernelForBatch(crate::utilities::enums::Kernel),
}
199
200#[inline]
201pub fn cci(input: &CciInput) -> Result<CciOutput, CciError> {
202 cci_with_kernel(input, Kernel::Auto)
203}
204
205#[inline(always)]
206fn cci_prepare<'a>(
207 input: &'a CciInput,
208 kernel: Kernel,
209) -> Result<(&'a [f64], usize, usize, Kernel), CciError> {
210 let data: &[f64] = input.as_ref();
211 let len = data.len();
212 if len == 0 {
213 return Err(CciError::EmptyInputData);
214 }
215 let first = data
216 .iter()
217 .position(|x| !x.is_nan())
218 .ok_or(CciError::AllValuesNaN)?;
219 let period = input.get_period();
220
221 if period == 0 || period > len {
222 return Err(CciError::InvalidPeriod {
223 period,
224 data_len: len,
225 });
226 }
227 if len - first < period {
228 return Err(CciError::NotEnoughValidData {
229 needed: period,
230 valid: len - first,
231 });
232 }
233
234 let chosen = match kernel {
235 Kernel::Auto => detect_best_kernel(),
236 k => k,
237 };
238
239 Ok((data, period, first, chosen))
240}
241
242pub fn cci_with_kernel(input: &CciInput, kernel: Kernel) -> Result<CciOutput, CciError> {
243 let (data, period, first, chosen) = cci_prepare(input, kernel)?;
244
245 let prefix = first + period - 1;
246 let mut out = alloc_with_nan_prefix(data.len(), prefix);
247
248 unsafe {
249 match chosen {
250 Kernel::Scalar | Kernel::ScalarBatch => cci_scalar(data, period, first, &mut out),
251 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
252 Kernel::Avx2 | Kernel::Avx2Batch => cci_avx2(data, period, first, &mut out),
253 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
254 Kernel::Avx512 | Kernel::Avx512Batch => cci_avx512(data, period, first, &mut out),
255 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
256 Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
257 cci_scalar(data, period, first, &mut out)
258 }
259 _ => unreachable!(),
260 }
261 }
262 Ok(CciOutput { values: out })
263}
264
265#[inline]
266pub fn cci_into_slice(dst: &mut [f64], input: &CciInput, kern: Kernel) -> Result<(), CciError> {
267 let (data, period, first, chosen) = cci_prepare(input, kern)?;
268
269 if dst.len() != data.len() {
270 return Err(CciError::OutputLengthMismatch {
271 expected: data.len(),
272 got: dst.len(),
273 });
274 }
275
276 unsafe {
277 match chosen {
278 Kernel::Scalar | Kernel::ScalarBatch => cci_scalar(data, period, first, dst),
279 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
280 Kernel::Avx2 | Kernel::Avx2Batch => cci_avx2(data, period, first, dst),
281 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
282 Kernel::Avx512 | Kernel::Avx512Batch => cci_avx512(data, period, first, dst),
283 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
284 Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
285 cci_scalar(data, period, first, dst)
286 }
287 _ => unreachable!(),
288 }
289 }
290
291 let warmup_end = first + period - 1;
292 for v in &mut dst[..warmup_end] {
293 *v = f64::NAN;
294 }
295
296 Ok(())
297}
298
299#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
300#[inline]
301pub fn cci_into(input: &CciInput, out: &mut [f64]) -> Result<(), CciError> {
302 cci_into_slice(out, input, Kernel::Auto)
303}
304
/// Scalar CCI kernel.
///
/// For each window of `period` values ending at index `i`, the kernel computes:
/// - `sma`: the window mean.
/// - `mad`: the mean absolute deviation from `sma`.
/// - `out[i] = (data[i] - sma) / (0.015 * mad)`, or `0.0` when `mad` is zero.
///
/// Only `out[first_valid + period - 1..]` is written. Earlier entries are left
/// untouched; callers fill the warmup prefix.
///
/// The window sum is maintained as a rolling sum. If it becomes non-finite (a NaN or
/// infinity inside the window), the current window is re-summed instead. Without this,
/// `NaN - NaN` would keep the rolling sum NaN forever after the bad value left the
/// window. For data whose running sum stays finite, this never triggers.
///
/// # Panics
/// Panics via slice indexing if `data` is non-empty and
/// `first_valid + period > data.len()`, or if `out` is shorter than `data`.
#[inline]
pub fn cci_scalar(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    debug_assert_eq!(data.len(), out.len());
    let n = data.len();
    if n == 0 {
        return;
    }

    let inv_p = 1.0 / (period as f64);

    let start0 = first_valid;
    let end0 = start0 + period;
    let mut sum: f64 = data[start0..end0].iter().sum();
    let mut sma = sum * inv_p;

    let mut sum_abs = 0.0;
    for &v in &data[start0..end0] {
        sum_abs += (v - sma).abs();
    }

    let first_out = first_valid + period - 1;
    let price0 = data[first_out];
    out[first_out] = {
        let denom = 0.015 * (sum_abs * inv_p);
        if denom == 0.0 {
            0.0
        } else {
            (price0 - sma) / denom
        }
    };

    for i in (first_out + 1)..n {
        let exiting = data[i - period];
        let entering = data[i];
        let window = &data[i + 1 - period..i + 1];

        sum = sum - exiting + entering;
        if !sum.is_finite() {
            // Recover from NaN/inf poisoning: the rolling update cannot undo it, so
            // recompute from the current window. This yields NaN/inf again only while
            // the offending value is still inside the window.
            sum = window.iter().sum();
        }
        sma = sum * inv_p;

        let mut sabs = 0.0;
        for &v in window {
            sabs += (v - sma).abs();
        }

        out[i] = {
            let denom = 0.015 * (sabs * inv_p);
            if denom == 0.0 {
                0.0
            } else {
                (entering - sma) / denom
            }
        };
    }
}
359
/// AVX-512 entry point for the single-series CCI kernel.
///
/// Currently forwards to `cci_scalar`; the intrinsic body `cci_avx512_impl` is not
/// called from here.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn cci_avx512(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    cci_scalar(data, period, first_valid, out)
}

/// AVX2 entry point for the single-series CCI kernel.
///
/// Currently forwards to `cci_scalar`; the intrinsic body `cci_avx2_impl` is not
/// called from here.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn cci_avx2(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    cci_scalar(data, period, first_valid, out)
}
371
/// AVX2/FMA implementation of the CCI kernel, with the same output contract as
/// `cci_scalar` (writes `out[first_valid + period - 1..]`, `0.0` for zero deviation).
///
/// How it computes each window:
/// - The window sum is a rolling scalar sum (not recomputed per window).
/// - Each window's absolute deviations `|x - sma|` are formed four lanes at a time, by
///   clearing the sign bit with `andnot(-0.0, d)`.
/// - The deviations are accumulated lane by lane with Kahan-compensated summation.
///
/// NOTE(review): no caller is visible in this section (`cci_avx2` forwards to
/// `cci_scalar`), and the local `scale` is computed but never used.
///
/// # Safety
/// The caller must ensure all of the following:
/// - The CPU supports AVX2 and FMA.
/// - `data.len() == out.len()`.
/// - `period >= 1 && first_valid + period <= data.len()`.
///
/// These are only `debug_assert!`-checked; all accesses use unchecked pointers/indices.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
unsafe fn cci_avx2_impl(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    debug_assert!(data.len() == out.len());
    debug_assert!(period >= 1 && first_valid + period <= data.len());

    let n = data.len();
    let inv_p = 1.0 / (period as f64);
    let scale = (period as f64) * (1.0 / 0.015);

    // Initial window sum, manually unrolled by 4.
    let base = data.as_ptr().add(first_valid);
    let mut sum = 0.0;
    {
        let mut k = 0usize;
        while k + 4 <= period {
            let x0 = *base.add(k + 0);
            let x1 = *base.add(k + 1);
            let x2 = *base.add(k + 2);
            let x3 = *base.add(k + 3);
            sum = sum + x0 + x1 + x2 + x3;
            k += 4;
        }
        while k < period {
            sum += *base.add(k);
            k += 1;
        }
    }

    let first_out = first_valid + period - 1;
    let mut sma = sum * inv_p;

    // First output: Kahan-summed |x - sma| over the initial window.
    {
        let vmean = _mm256_set1_pd(sma);
        // -0.0 has only the sign bit set; andnot with it yields |d|.
        let vsgn = _mm256_set1_pd(-0.0f64);
        let mut k = 0usize;
        let mut sum_abs = 0.0f64;
        let mut comp = 0.0f64;
        while k + 4 <= period {
            let x = _mm256_loadu_pd(base.add(k));
            let d = _mm256_sub_pd(x, vmean);
            let a = _mm256_andnot_pd(vsgn, d);
            let mut lane = [0.0f64; 4];
            _mm256_storeu_pd(lane.as_mut_ptr(), a);

            for &val in &lane {
                let y = val - comp;
                let t = sum_abs + y;
                comp = (t - sum_abs) - y;
                sum_abs = t;
            }
            k += 4;
        }
        // Scalar tail for period % 4 remaining elements.
        while k < period {
            let val = (*base.add(k) - sma).abs();
            let y = val - comp;
            let t = sum_abs + y;
            comp = (t - sum_abs) - y;
            sum_abs = t;
            k += 1;
        }
        let price0 = *data.get_unchecked(first_out);
        let denom = 0.015 * (sum_abs * inv_p);
        *out.get_unchecked_mut(first_out) = if denom == 0.0 {
            0.0
        } else {
            (price0 - sma) / denom
        };
    }

    // Remaining outputs: slide the sum, then recompute the deviation over the window.
    let mut i = first_out + 1;
    while i < n {
        let exiting = *data.get_unchecked(i - period);
        let entering = *data.get_unchecked(i);
        sum = sum - exiting + entering;
        sma = sum * inv_p;

        let start = i + 1 - period;
        let wptr = data.as_ptr().add(start);

        let vmean = _mm256_set1_pd(sma);
        let vsgn = _mm256_set1_pd(-0.0f64);
        let mut k = 0usize;
        let mut sum_abs = 0.0f64;
        let mut comp = 0.0f64;
        while k + 4 <= period {
            let x = _mm256_loadu_pd(wptr.add(k));
            let d = _mm256_sub_pd(x, vmean);
            let a = _mm256_andnot_pd(vsgn, d);
            let mut lane = [0.0f64; 4];
            _mm256_storeu_pd(lane.as_mut_ptr(), a);
            for &val in &lane {
                let y = val - comp;
                let t = sum_abs + y;
                comp = (t - sum_abs) - y;
                sum_abs = t;
            }
            k += 4;
        }
        while k < period {
            let val = (*wptr.add(k) - sma).abs();
            let y = val - comp;
            let t = sum_abs + y;
            comp = (t - sum_abs) - y;
            sum_abs = t;
            k += 1;
        }

        let denom = 0.015 * (sum_abs * inv_p);
        *out.get_unchecked_mut(i) = if denom == 0.0 {
            0.0
        } else {
            (entering - sma) / denom
        };
        i += 1;
    }
}
488
/// AVX-512 "short period" entry point. Currently forwards to `cci_scalar`.
///
/// # Safety
/// Declared `unsafe` for signature parity; the body performs no unsafe operation.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn cci_avx512_short(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    cci_scalar(data, period, first_valid, out)
}

/// AVX-512 "long period" entry point. Currently forwards to `cci_scalar`.
///
/// # Safety
/// Declared `unsafe` for signature parity; the body performs no unsafe operation.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn cci_avx512_long(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    cci_scalar(data, period, first_valid, out)
}
500
/// AVX-512F/FMA implementation of the CCI kernel, with the same output contract as
/// `cci_scalar`.
///
/// How it computes each window:
/// - The window sum is a rolling scalar sum.
/// - Absolute deviations are formed eight lanes at a time, by AND-ing with a mask that
///   clears the sign bit.
/// - The deviations are accumulated lane by lane with Kahan-compensated summation.
///
/// NOTE(review): no caller is visible in this section (`cci_avx512` forwards to
/// `cci_scalar`), and the local `scale` is computed but never used.
///
/// # Safety
/// The caller must ensure all of the following:
/// - The CPU supports AVX-512F and FMA.
/// - `data.len() == out.len()`.
/// - `period >= 1 && first_valid + period <= data.len()`.
///
/// These are only `debug_assert!`-checked; all accesses are unchecked.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
unsafe fn cci_avx512_impl(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    debug_assert!(data.len() == out.len());
    debug_assert!(period >= 1 && first_valid + period <= data.len());

    let n = data.len();
    let inv_p = 1.0 / (period as f64);
    let scale = (period as f64) * (1.0 / 0.015);

    // Initial window sum, manually unrolled by 4.
    let base = data.as_ptr().add(first_valid);
    let mut sum = 0.0;
    {
        let mut k = 0usize;
        while k + 4 <= period {
            let x0 = *base.add(k + 0);
            let x1 = *base.add(k + 1);
            let x2 = *base.add(k + 2);
            let x3 = *base.add(k + 3);
            sum = sum + x0 + x1 + x2 + x3;
            k += 4;
        }
        while k < period {
            sum += *base.add(k);
            k += 1;
        }
    }

    let first_out = first_valid + period - 1;
    let mut sma = sum * inv_p;

    // All bits except the sign bit: AND with this gives |x|.
    let pos_mask_i = _mm512_set1_epi64(0x7FFF_FFFF_FFFF_FFFFu64 as i64);
    let pos_mask = _mm512_castsi512_pd(pos_mask_i);

    // First output: Kahan-summed |x - sma| over the initial window.
    {
        let vmean = _mm512_set1_pd(sma);
        let mut k = 0usize;
        let mut sum_abs = 0.0f64;
        let mut comp = 0.0f64;
        while k + 8 <= period {
            let x = _mm512_loadu_pd(base.add(k));
            let d = _mm512_sub_pd(x, vmean);
            let a = _mm512_and_pd(d, pos_mask);
            let mut lane = [0.0f64; 8];
            _mm512_storeu_pd(lane.as_mut_ptr(), a);
            for &val in &lane {
                let y = val - comp;
                let t = sum_abs + y;
                comp = (t - sum_abs) - y;
                sum_abs = t;
            }
            k += 8;
        }
        // Scalar tail for period % 8 remaining elements.
        while k < period {
            let val = (*base.add(k) - sma).abs();
            let y = val - comp;
            let t = sum_abs + y;
            comp = (t - sum_abs) - y;
            sum_abs = t;
            k += 1;
        }
        let price0 = *data.get_unchecked(first_out);
        let denom = 0.015 * (sum_abs * inv_p);
        *out.get_unchecked_mut(first_out) = if denom == 0.0 {
            0.0
        } else {
            (price0 - sma) / denom
        };
    }

    // Remaining outputs: slide the sum, then recompute the deviation over the window.
    let mut i = first_out + 1;
    while i < n {
        let exiting = *data.get_unchecked(i - period);
        let entering = *data.get_unchecked(i);
        sum = sum - exiting + entering;
        sma = sum * inv_p;

        let start = i + 1 - period;
        let wptr = data.as_ptr().add(start);

        let vmean = _mm512_set1_pd(sma);
        let mut k = 0usize;
        let mut sum_abs = 0.0f64;
        let mut comp = 0.0f64;
        while k + 8 <= period {
            let x = _mm512_loadu_pd(wptr.add(k));
            let d = _mm512_sub_pd(x, vmean);
            let a = _mm512_and_pd(d, pos_mask);
            let mut lane = [0.0f64; 8];
            _mm512_storeu_pd(lane.as_mut_ptr(), a);
            for &val in &lane {
                let y = val - comp;
                let t = sum_abs + y;
                comp = (t - sum_abs) - y;
                sum_abs = t;
            }
            k += 8;
        }
        while k < period {
            let val = (*wptr.add(k) - sma).abs();
            let y = val - comp;
            let t = sum_abs + y;
            comp = (t - sum_abs) - y;
            sum_abs = t;
            k += 1;
        }
        let denom = 0.015 * (sum_abs * inv_p);
        *out.get_unchecked_mut(i) = if denom == 0.0 {
            0.0
        } else {
            (entering - sma) / denom
        };
        i += 1;
    }
}
616
/// Batch row kernel (scalar): runs `cci_scalar` over one full output row.
///
/// `_stride`, `_w_ptr` and `_inv` exist for signature parity with other indicators'
/// row kernels and are ignored here.
///
/// # Safety
/// Declared `unsafe` for signature parity; the body performs no unsafe operation.
#[inline(always)]
pub unsafe fn cci_row_scalar(
    data: &[f64],
    first: usize,
    period: usize,
    _stride: usize,
    _w_ptr: *const f64,
    _inv: f64,
    out: &mut [f64],
) {
    cci_scalar(data, period, first, out)
}

/// Batch row kernel (AVX2): forwards to `cci_avx2`. Extra arguments are ignored.
///
/// # Safety
/// Declared `unsafe` for signature parity; the body performs no unsafe operation.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn cci_row_avx2(
    data: &[f64],
    first: usize,
    period: usize,
    _stride: usize,
    _w_ptr: *const f64,
    _inv: f64,
    out: &mut [f64],
) {
    cci_avx2(data, period, first, out)
}

/// Batch row kernel (AVX-512): forwards to `cci_avx512`. Extra arguments are ignored.
///
/// # Safety
/// Declared `unsafe` for signature parity; the body performs no unsafe operation.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn cci_row_avx512(
    data: &[f64],
    first: usize,
    period: usize,
    _stride: usize,
    _w_ptr: *const f64,
    _inv: f64,
    out: &mut [f64],
) {
    cci_avx512(data, period, first, out)
}

/// Batch row kernel (AVX-512, short periods): forwards to `cci_avx512_short`.
///
/// # Safety
/// Inherits the contract of `cci_avx512_short`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn cci_row_avx512_short(
    data: &[f64],
    first: usize,
    period: usize,
    _stride: usize,
    _w_ptr: *const f64,
    _inv: f64,
    out: &mut [f64],
) {
    cci_avx512_short(data, period, first, out)
}

/// Batch row kernel (AVX-512, long periods): forwards to `cci_avx512_long`.
///
/// # Safety
/// Inherits the contract of `cci_avx512_long`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn cci_row_avx512_long(
    data: &[f64],
    first: usize,
    period: usize,
    _stride: usize,
    _w_ptr: *const f64,
    _inv: f64,
    out: &mut [f64],
) {
    cci_avx512_long(data, period, first, out)
}
685
/// Incremental CCI over a fixed-size window, fed one value at a time.
///
/// The mean absolute deviation is derived from an order-statistics treap instead of
/// re-scanning the window on every update.
#[derive(Debug, Clone)]
pub struct CciStream {
    /// Window length.
    period: usize,
    /// Ring buffer of the last `period` inputs; NaN slots mean "not yet filled".
    buffer: Vec<f64>,
    /// Next slot of `buffer` to overwrite.
    head: usize,
    /// Becomes `true` once `period` values have been pushed.
    filled: bool,

    /// Rolling sum of the values currently in the window.
    sum: f64,

    /// `period / 0.015`, so CCI = (x - mean) * scale / sum_abs.
    scale: f64,

    /// Multiset of window values with per-subtree counts and sums.
    ost: OrderStatsTreap,
}
699
/// One SplitMix64 step: adds the golden-ratio increment to `x` and scrambles the
/// result. In this file it seeds and advances the treap priority sequence.
#[inline(always)]
fn splitmix64(mut x: u64) -> u64 {
    x = x.wrapping_add(0x9E3779B97F4A7C15);
    let mut z = x;
    z = (z ^ (z >> 30)).wrapping_mul(0xBF58476D1CE4E5B9);
    z = (z ^ (z >> 27)).wrapping_mul(0x94D049BB133111EB);
    z ^ (z >> 31)
}

/// Treap (randomized BST, max-heap on `prio`) over `f64` keys.
///
/// Each node caches the multiplicity and sum of its subtree, so `prefix_le` can report
/// how many stored values are `<= key`, and their sum, in one root-to-leaf walk.
/// Equal keys are normally folded into one node's `cnt`, although a rotation-free
/// split can occasionally leave equal keys in separate nodes; counts and sums stay
/// correct either way.
///
/// Keys are expected to be finite (checked only by `debug_assert!`).
#[derive(Debug, Clone)]
struct OrderStatsTreap {
    root: Option<Box<Node>>,
    /// Current SplitMix64 state used to draw node priorities.
    seed: u64,
}

/// A treap node.
#[derive(Debug, Clone)]
struct Node {
    key: f64,
    /// Heap priority: a parent's priority is higher than its children's.
    prio: u64,
    /// Multiplicity of `key` stored in this node.
    cnt: u32,
    /// Total multiplicity within this subtree.
    size: usize,
    /// Sum of every stored value (`key * cnt` per node) within this subtree.
    sum: f64,
    left: Option<Box<Node>>,
    right: Option<Box<Node>>,
}

/// Subtree multiplicity (0 for an empty subtree).
#[inline(always)]
fn sz(t: &Option<Box<Node>>) -> usize {
    t.as_ref().map(|n| n.size).unwrap_or(0)
}
/// Subtree value sum (0.0 for an empty subtree).
#[inline(always)]
fn sm(t: &Option<Box<Node>>) -> f64 {
    t.as_ref().map(|n| n.sum).unwrap_or(0.0)
}

impl Node {
    /// Leaf holding one copy of `key`.
    #[inline(always)]
    fn new(key: f64, prio: u64) -> Self {
        Self {
            key,
            prio,
            cnt: 1,
            size: 1,
            sum: key,
            left: None,
            right: None,
        }
    }
    /// Recomputes the cached `size`/`sum` from `cnt` and the children.
    #[inline(always)]
    fn recalc(&mut self) {
        self.size = self.cnt as usize + sz(&self.left) + sz(&self.right);
        self.sum = (self.key * self.cnt as f64) + sm(&self.left) + sm(&self.right);
    }
}

impl OrderStatsTreap {
    /// Empty treap.
    ///
    /// The seed is derived from the address of `&()`. That address is likely a fixed
    /// constant, so the priority sequence is effectively deterministic.
    #[inline(always)]
    fn new() -> Self {
        let seed = splitmix64((&() as *const () as usize as u64) ^ 0xA5A5_A5A5_A5A5_A5A5);
        Self { root: None, seed }
    }

    /// Advances the generator and returns the next node priority.
    #[inline(always)]
    fn next_prio(&mut self) -> u64 {
        self.seed = splitmix64(self.seed);
        self.seed
    }

    /// Joins two treaps where every key of `a` is `<=` every key of `b`.
    fn merge(a: Option<Box<Node>>, b: Option<Box<Node>>) -> Option<Box<Node>> {
        match (a, b) {
            (None, t) | (t, None) => t,
            (Some(mut x), Some(mut y)) => {
                if x.prio > y.prio {
                    x.right = Self::merge(x.right.take(), Some(y));
                    x.recalc();
                    Some(x)
                } else {
                    y.left = Self::merge(Some(x), y.left.take());
                    y.recalc();
                    Some(y)
                }
            }
        }
    }

    /// Splits `t` into (keys `<= key`, keys `> key`).
    fn split(t: Option<Box<Node>>, key: f64) -> (Option<Box<Node>>, Option<Box<Node>>) {
        match t {
            None => (None, None),
            Some(mut n) => {
                if key < n.key {
                    let (l, r) = Self::split(n.left.take(), key);
                    n.left = r;
                    n.recalc();
                    (l, Some(n))
                } else {
                    let (l, r) = Self::split(n.right.take(), key);
                    n.right = l;
                    n.recalc();
                    (Some(n), r)
                }
            }
        }
    }

    /// Adds one copy of `key`. Draws exactly one priority per call.
    fn insert(&mut self, key: f64) {
        debug_assert!(key.is_finite());
        let prio = self.next_prio();
        self.root = Self::insert_into(self.root.take(), key, prio);
    }

    /// Recursive insert.
    ///
    /// - If `key` matches the node's key, its count is bumped.
    /// - Otherwise, if `prio` is higher than the node's, a new node is placed here and
    ///   the subtree is split around it.
    /// - Otherwise the insert descends left or right.
    fn insert_into(t: Option<Box<Node>>, key: f64, prio: u64) -> Option<Box<Node>> {
        match t {
            None => Some(Box::new(Node::new(key, prio))),
            Some(mut n) => {
                if key == n.key {
                    n.cnt += 1;
                    n.recalc();
                    Some(n)
                } else if prio > n.prio {
                    let (l, r) = Self::split(Some(n), key);
                    let mut m = Box::new(Node::new(key, prio));
                    m.left = l;
                    m.right = r;
                    m.recalc();
                    Some(m)
                } else if key < n.key {
                    n.left = Self::insert_into(n.left.take(), key, prio);
                    n.recalc();
                    Some(n)
                } else {
                    n.right = Self::insert_into(n.right.take(), key, prio);
                    n.recalc();
                    Some(n)
                }
            }
        }
    }

    /// Removes one copy of `key`; does nothing if `key` is absent.
    fn erase(&mut self, key: f64) {
        debug_assert!(key.is_finite());
        self.root = Self::erase_from(self.root.take(), key);
    }

    /// Recursive erase: decrements the count, or unlinks the node when its last copy
    /// goes.
    fn erase_from(t: Option<Box<Node>>, key: f64) -> Option<Box<Node>> {
        match t {
            None => None,
            Some(mut n) => {
                if key == n.key {
                    if n.cnt > 1 {
                        n.cnt -= 1;
                        n.recalc();
                        Some(n)
                    } else {
                        Self::merge(n.left.take(), n.right.take())
                    }
                } else if key < n.key {
                    n.left = Self::erase_from(n.left.take(), key);
                    n.recalc();
                    Some(n)
                } else {
                    n.right = Self::erase_from(n.right.take(), key);
                    n.recalc();
                    Some(n)
                }
            }
        }
    }

    /// Returns `(count, sum)` of all stored values `<= key`.
    fn prefix_le(&self, key: f64) -> (usize, f64) {
        debug_assert!(key.is_finite());
        let mut t = &self.root;
        let mut count = 0usize;
        let mut sum = 0.0f64;

        while let Some(n) = t.as_ref() {
            if key < n.key {
                t = &n.left;
            } else {
                // This node and its whole left subtree are <= key.
                count += sz(&n.left) + n.cnt as usize;
                sum += sm(&n.left) + (n.key * n.cnt as f64);
                t = &n.right;
            }
        }
        (count, sum)
    }

    /// Total number of stored values, counting multiplicity.
    #[inline(always)]
    fn size(&self) -> usize {
        sz(&self.root)
    }
}
896
impl CciStream {
    /// Creates a stream for `params.period` (default 14).
    ///
    /// # Errors
    /// `InvalidPeriod` with `data_len: 0` when the period is zero.
    pub fn try_new(params: CciParams) -> Result<Self, CciError> {
        let period = params.period.unwrap_or(14);
        if period == 0 {
            return Err(CciError::InvalidPeriod {
                period,
                data_len: 0,
            });
        }

        // Presumably an all-NaN buffer (the NaN prefix covers its whole length).
        // `update` treats a NaN slot as "nothing to evict".
        let buffer = alloc_with_nan_prefix(period, period);
        let scale = (period as f64) * (1.0 / 0.015);

        Ok(Self {
            period,
            buffer,
            head: 0,
            filled: false,
            sum: 0.0,
            scale,
            ost: OrderStatsTreap::new(),
        })
    }

    /// Pushes one value.
    ///
    /// Returns `None` for the first `period - 1` calls. After that, every call returns
    /// the CCI of the latest `period` values. Inputs are expected to be finite
    /// (enforced only by `debug_assert!`).
    #[inline(always)]
    pub fn update(&mut self, value: f64) -> Option<f64> {
        debug_assert!(value.is_finite(), "CCI stream expects finite inputs");

        // Overwrite the oldest slot and advance the ring head.
        let old = self.buffer[self.head];
        self.buffer[self.head] = value;
        self.head = (self.head + 1) % self.period;
        if !self.filled && self.head == 0 {
            self.filled = true;
        }

        if !self.filled {
            self.sum += value;
            self.ost.insert(value);
            return None;
        }

        // On the call that completes the first window, `old` is still the initial NaN
        // and there is nothing to evict.
        if !old.is_nan() {
            self.sum -= old;
            self.ost.erase(old);
        }
        self.sum += value;
        self.ost.insert(value);

        let mean = self.sum / (self.period as f64);

        let (k_le, sum_le) = self.ost.prefix_le(mean);

        // Sum of |x - mean| over the window, split at `mean`:
        //   sum_{x>mean}(x - mean) + sum_{x<=mean}(mean - x)
        //   = (sum - 2*sum_le) + mean*(2*k_le - n)
        // NOTE(review): this subtracts nearly equal quantities. For (near-)constant
        // windows, rounding may leave a tiny nonzero `sum_abs`, and the result can then
        // differ noticeably from the batch kernel, which sums |x - mean| directly.
        let n = self.period as f64;
        let sum_abs = mean.mul_add(2.0 * (k_le as f64) - n, self.sum - 2.0 * sum_le);

        if sum_abs == 0.0 {
            Some(0.0)
        } else {
            // (value - mean) / (0.015 * sum_abs / n)
            Some((value - mean) * (self.scale / sum_abs))
        }
    }
}
959
/// Period sweep for batch CCI, as an inclusive `(start, end, step)` triple.
///
/// How the triple expands:
/// - `step == 0` or `start == end` gives the single period `start`.
/// - `start > end` walks downward from `start` by `step`.
#[derive(Clone, Debug)]
pub struct CciBatchRange {
    pub period: (usize, usize, usize),
}

impl Default for CciBatchRange {
    /// Periods 14 through 263 in steps of 1.
    fn default() -> Self {
        Self {
            period: (14, 263, 1),
        }
    }
}

/// Builder for batch (multi-period) CCI runs.
#[derive(Clone, Debug, Default)]
pub struct CciBatchBuilder {
    range: CciBatchRange,
    kernel: Kernel,
}
978
979impl CciBatchBuilder {
980 pub fn new() -> Self {
981 Self::default()
982 }
983 pub fn kernel(mut self, k: Kernel) -> Self {
984 self.kernel = k;
985 self
986 }
987 #[inline]
988 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
989 self.range.period = (start, end, step);
990 self
991 }
992 #[inline]
993 pub fn period_static(mut self, p: usize) -> Self {
994 self.range.period = (p, p, 0);
995 self
996 }
997 pub fn apply_slice(self, data: &[f64]) -> Result<CciBatchOutput, CciError> {
998 cci_batch_with_kernel(data, &self.range, self.kernel)
999 }
1000 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<CciBatchOutput, CciError> {
1001 CciBatchBuilder::new().kernel(k).apply_slice(data)
1002 }
1003 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<CciBatchOutput, CciError> {
1004 let slice = source_type(c, src);
1005 self.apply_slice(slice)
1006 }
1007 pub fn with_default_candles(c: &Candles) -> Result<CciBatchOutput, CciError> {
1008 CciBatchBuilder::new()
1009 .kernel(Kernel::Auto)
1010 .apply_candles(c, "hlc3")
1011 }
1012}
1013
1014#[derive(Clone, Debug)]
1015pub struct CciBatchOutput {
1016 pub values: Vec<f64>,
1017 pub combos: Vec<CciParams>,
1018 pub rows: usize,
1019 pub cols: usize,
1020}
1021impl CciBatchOutput {
1022 pub fn row_for_params(&self, p: &CciParams) -> Option<usize> {
1023 self.combos
1024 .iter()
1025 .position(|c| c.period.unwrap_or(14) == p.period.unwrap_or(14))
1026 }
1027 pub fn values_for(&self, p: &CciParams) -> Option<&[f64]> {
1028 self.row_for_params(p).map(|row| {
1029 let start = row * self.cols;
1030 &self.values[start..start + self.cols]
1031 })
1032 }
1033}
1034
1035#[inline(always)]
1036fn expand_grid(r: &CciBatchRange) -> Result<Vec<CciParams>, CciError> {
1037 fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, CciError> {
1038 if step == 0 || start == end {
1039 return Ok(vec![start]);
1040 }
1041 if start < end {
1042 let mut v = Vec::new();
1043 let mut cur = start;
1044 while cur <= end {
1045 v.push(cur);
1046 match cur.checked_add(step) {
1047 Some(next) => {
1048 if next == cur {
1049 break;
1050 }
1051 cur = next;
1052 }
1053 None => break,
1054 }
1055 }
1056 if v.is_empty() {
1057 return Err(CciError::InvalidRange { start, end, step });
1058 }
1059 Ok(v)
1060 } else {
1061 let mut v = Vec::new();
1062 let mut cur = start;
1063 loop {
1064 v.push(cur);
1065 if cur <= end {
1066 break;
1067 }
1068
1069 if cur < step {
1070 break;
1071 }
1072 cur -= step;
1073 if cur == end {
1074 v.push(cur);
1075 break;
1076 }
1077 if cur < end {
1078 break;
1079 }
1080 }
1081 v.sort_unstable();
1082 v.dedup();
1083 if v.is_empty() {
1084 return Err(CciError::InvalidRange { start, end, step });
1085 }
1086 Ok(v)
1087 }
1088 }
1089 let periods = axis_usize(r.period)?;
1090 let mut out = Vec::with_capacity(periods.len());
1091 for &p in &periods {
1092 out.push(CciParams { period: Some(p) });
1093 }
1094 Ok(out)
1095}
1096
1097pub fn cci_batch_with_kernel(
1098 data: &[f64],
1099 sweep: &CciBatchRange,
1100 k: Kernel,
1101) -> Result<CciBatchOutput, CciError> {
1102 let kernel = match k {
1103 Kernel::Auto => detect_best_batch_kernel(),
1104 other if other.is_batch() => other,
1105 other => return Err(CciError::InvalidKernelForBatch(other)),
1106 };
1107 let simd = match kernel {
1108 Kernel::Avx512Batch => Kernel::Avx512,
1109 Kernel::Avx2Batch => Kernel::Avx2,
1110 Kernel::ScalarBatch => Kernel::Scalar,
1111 _ => unreachable!(),
1112 };
1113 cci_batch_par_slice(data, sweep, simd)
1114}
1115
/// Sequential batch CCI over `sweep`.
///
/// `kern` may be `Auto`, a batch kernel or a single-series kernel; it is resolved in
/// `cci_batch_inner_into`.
#[inline(always)]
pub fn cci_batch_slice(
    data: &[f64],
    sweep: &CciBatchRange,
    kern: Kernel,
) -> Result<CciBatchOutput, CciError> {
    cci_batch_inner(data, sweep, kern, false)
}
/// Parallel variant of `cci_batch_slice`. Rows run on rayon on non-wasm32 targets when
/// there is more than one row.
#[inline(always)]
pub fn cci_batch_par_slice(
    data: &[f64],
    sweep: &CciBatchRange,
    kern: Kernel,
) -> Result<CciBatchOutput, CciError> {
    cci_batch_inner(data, sweep, kern, true)
}
1132
/// Allocating batch driver: validates the sweep, allocates a `rows x cols` matrix and
/// fills it through `cci_batch_inner_into`.
///
/// Validation is repeated in `cci_batch_inner_into`. Doing it here as well reports
/// errors before any allocation.
#[inline(always)]
fn cci_batch_inner(
    data: &[f64],
    sweep: &CciBatchRange,
    kern: Kernel,
    parallel: bool,
) -> Result<CciBatchOutput, CciError> {
    if data.is_empty() {
        return Err(CciError::EmptyInputData);
    }
    let combos = expand_grid(sweep)?;
    let cols = data.len();
    let rows = combos.len();

    let first = data
        .iter()
        .position(|x| !x.is_nan())
        .ok_or(CciError::AllValuesNaN)?;
    let mut max_p = 0usize;
    for c in &combos {
        let p = c.period.unwrap();
        if p == 0 || p > cols {
            return Err(CciError::InvalidPeriod {
                period: p,
                data_len: cols,
            });
        }
        max_p = max_p.max(p);
    }
    // The longest period must fit after the first valid value.
    if cols - first < max_p {
        return Err(CciError::NotEnoughValidData {
            needed: max_p,
            valid: cols - first,
        });
    }

    // Only checks that the matrix size fits in `usize`; the product itself is unused.
    rows.checked_mul(cols).ok_or(CciError::InvalidRange {
        start: sweep.period.0,
        end: sweep.period.1,
        step: sweep.period.2,
    })?;

    // Uninitialized backing storage; `as_mut_ptr() as *mut f64` suggests a
    // `Vec<MaybeUninit<f64>>` (TODO confirm). `cci_batch_inner_into` writes the warmup
    // NaNs plus the kernel output for every row before anything reads it.
    let mut buf_mu = make_uninit_matrix(rows, cols);
    let out_ptr = buf_mu.as_mut_ptr() as *mut f64;
    let out_len = buf_mu.len();
    let out_cap = buf_mu.capacity();
    // SAFETY (relied upon here): the pointer/length come from `buf_mu`, which stays
    // alive and is not otherwise accessed while `out` is in use.
    let out: &mut [f64] = unsafe { core::slice::from_raw_parts_mut(out_ptr, out_len) };

    // On error, `?` returns early and `buf_mu` is dropped normally.
    let _ = cci_batch_inner_into(data, sweep, kern, parallel, out)?;

    // Re-own the now-initialized allocation as `Vec<f64>`, then forget `buf_mu` so the
    // memory is freed exactly once (by `values`).
    let values = unsafe { Vec::from_raw_parts(out_ptr, out_len, out_cap) };
    core::mem::forget(buf_mu);

    Ok(CciBatchOutput {
        values,
        combos,
        rows,
        cols,
    })
}
1193
/// Fills `out` (row-major, `rows x data.len()`) with one CCI row per period in `sweep`.
///
/// For each row:
/// - The first `first_valid + period - 1` columns are set to NaN.
/// - The remaining columns are written by the per-row kernel.
///
/// `kern` may be `Auto`, a single-series kernel (promoted to its batch form) or a batch
/// kernel. Rows run on rayon when `parallel && rows > 1` on non-wasm32 targets.
///
/// # Errors
/// The following errors can be returned:
/// - `InvalidRange`: empty sweep or size overflow.
/// - `EmptyInputData`, `AllValuesNaN`, `InvalidPeriod`, `NotEnoughValidData`: input
///   validation failures.
/// - `OutputLengthMismatch`: `out` is not exactly `rows * cols` long.
///
/// Returns the expanded parameter combos on success.
#[inline(always)]
fn cci_batch_inner_into(
    data: &[f64],
    sweep: &CciBatchRange,
    kern: Kernel,
    parallel: bool,
    out: &mut [f64],
) -> Result<Vec<CciParams>, CciError> {
    let combos = expand_grid(sweep)?;
    if combos.is_empty() {
        return Err(CciError::InvalidRange {
            start: sweep.period.0,
            end: sweep.period.1,
            step: sweep.period.2,
        });
    }

    if data.is_empty() {
        return Err(CciError::EmptyInputData);
    }

    let first = data
        .iter()
        .position(|x| !x.is_nan())
        .ok_or(CciError::AllValuesNaN)?;
    let mut max_p = 0usize;
    for c in &combos {
        let p = c.period.unwrap();
        if p == 0 || p > data.len() {
            return Err(CciError::InvalidPeriod {
                period: p,
                data_len: data.len(),
            });
        }
        max_p = max_p.max(p);
    }
    if data.len() - first < max_p {
        return Err(CciError::NotEnoughValidData {
            needed: max_p,
            valid: data.len() - first,
        });
    }

    let rows = combos.len();
    let cols = data.len();

    let expected = rows.checked_mul(cols).ok_or(CciError::InvalidRange {
        start: sweep.period.0,
        end: sweep.period.1,
        step: sweep.period.2,
    })?;
    if out.len() != expected {
        return Err(CciError::OutputLengthMismatch {
            expected,
            got: out.len(),
        });
    }

    // Normalize to a batch kernel, then map to the per-row kernel.
    let kernel = match kern {
        Kernel::Auto => detect_best_batch_kernel(),
        Kernel::Scalar => Kernel::ScalarBatch,
        Kernel::Avx2 => Kernel::Avx2Batch,
        Kernel::Avx512 => Kernel::Avx512Batch,
        k => k,
    };
    let simd = match kernel {
        Kernel::Avx512Batch => Kernel::Avx512,
        Kernel::Avx2Batch => Kernel::Avx2,
        Kernel::ScalarBatch => Kernel::Scalar,
        _ => unreachable!(),
    };

    // Per-row index of the first CCI output (end of the warmup prefix).
    let warm: Vec<usize> = combos
        .iter()
        .map(|c| first + c.period.unwrap() - 1)
        .collect();

    // View the buffer as `MaybeUninit<f64>` so callers may pass uninitialized memory
    // (as `cci_batch_inner` does); `f64` and `MaybeUninit<f64>` share a layout.
    let out_uninit = unsafe {
        std::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
    };

    // Warmup prefix of every row is NaN; the kernels write only from `warm[row]` on.
    for (row, &warmup) in warm.iter().enumerate() {
        let row_start = row * cols;
        for col in 0..warmup.min(cols) {
            out_uninit[row_start + col].write(f64::NAN);
        }
    }

    // Computes one row in place. Row chunks are disjoint, so rows can run in parallel.
    let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
        let period = combos[row].period.unwrap();

        let dst = core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());

        assert_eq!(dst.len(), cols, "Output row length mismatch");

        match simd {
            Kernel::Scalar => cci_row_scalar(data, first, period, 0, std::ptr::null(), 0.0, dst),
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx2 => cci_row_avx2(data, first, period, 0, std::ptr::null(), 0.0, dst),
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx512 => cci_row_avx512(data, first, period, 0, std::ptr::null(), 0.0, dst),
            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
            Kernel::Avx2 | Kernel::Avx512 => {
                cci_row_scalar(data, first, period, 0, std::ptr::null(), 0.0, dst)
            }
            _ => unreachable!(),
        }
    };

    #[cfg(not(target_arch = "wasm32"))]
    if parallel && rows > 1 {
        out_uninit
            .par_chunks_mut(cols)
            .enumerate()
            .for_each(|(row, out_row)| do_row(row, out_row));
    } else {
        out_uninit
            .chunks_mut(cols)
            .enumerate()
            .for_each(|(row, out_row)| do_row(row, out_row));
    }

    // wasm32 has no rayon: always sequential.
    #[cfg(target_arch = "wasm32")]
    {
        out_uninit
            .chunks_mut(cols)
            .enumerate()
            .for_each(|(row, out_row)| do_row(row, out_row));
    }

    Ok(combos)
}
1326
#[cfg(test)]
mod tests {
    // Unit, regression and (optionally) property tests for CCI.
    //
    // Each `check_*` function returns a `Result`. The macros at the bottom turn every
    // check into `#[test]` functions that return that `Result`, so an `Err` propagated
    // with `?` fails the test instead of being silently discarded.
    use super::*;
    use crate::skip_if_unsupported;
    use crate::utilities::data_loader::read_candles_from_csv;
    use paste::paste;
    #[cfg(feature = "proptest")]
    use proptest::prelude::*;

    fn check_cci_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let default_params = CciParams { period: None };
        let input_default = CciInput::from_candles(&candles, "close", default_params);
        let output_default = cci_with_kernel(&input_default, kernel)?;
        assert_eq!(output_default.values.len(), candles.close.len());

        let params_20 = CciParams { period: Some(20) };
        let input_20 = CciInput::from_candles(&candles, "hl2", params_20);
        let output_20 = cci_with_kernel(&input_20, kernel)?;
        assert_eq!(output_20.values.len(), candles.close.len());

        let params_custom = CciParams { period: Some(9) };
        let input_custom = CciInput::from_candles(&candles, "hlc3", params_custom);
        let output_custom = cci_with_kernel(&input_custom, kernel)?;
        assert_eq!(output_custom.values.len(), candles.close.len());
        Ok(())
    }

    fn check_cci_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let input = CciInput::with_default_candles(&candles);
        let cci_result = cci_with_kernel(&input, kernel)?;
        assert_eq!(cci_result.values.len(), candles.close.len());

        let expected_last_five_cci = [
            -51.55252564125841,
            -43.50326506381541,
            -64.05117302269149,
            -39.05150631680948,
            -152.50523930896998,
        ];

        let start_idx = cci_result.values.len() - 5;
        let last_five_cci = &cci_result.values[start_idx..];
        for (i, &value) in last_five_cci.iter().enumerate() {
            let expected = expected_last_five_cci[i];
            assert!(
                (value - expected).abs() < 1e-6,
                "[{}] CCI mismatch at last five index {}: expected {}, got {}",
                test_name,
                i,
                expected,
                value
            );
        }
        let period: usize = input.get_period();
        for i in 0..(period - 1) {
            assert!(
                cci_result.values[i].is_nan(),
                "Expected NaN at index {} for initial period warm-up",
                i
            );
        }
        Ok(())
    }

    fn check_cci_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let input = CciInput::with_default_candles(&candles);

        match input.data {
            CciData::Candles { source, .. } => {
                assert_eq!(source, "hlc3", "Expected default source to be 'hlc3'");
            }
            _ => panic!("Expected CciData::Candles variant"),
        }
        let output = cci_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());
        Ok(())
    }

    fn check_cci_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let input_data = [10.0, 20.0, 30.0];
        let params = CciParams { period: Some(0) };
        let input = CciInput::from_slice(&input_data, params);
        let res = cci_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] CCI should fail with zero period",
            test_name
        );
        Ok(())
    }

    fn check_cci_period_exceeds_length(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let data_small = [10.0, 20.0, 30.0];
        let params = CciParams { period: Some(10) };
        let input = CciInput::from_slice(&data_small, params);
        let res = cci_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] CCI should fail with period exceeding length",
            test_name
        );
        Ok(())
    }

    fn check_cci_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let single_point = [42.0];
        let params = CciParams { period: Some(9) };
        let input = CciInput::from_slice(&single_point, params);
        let res = cci_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] CCI should fail with insufficient data",
            test_name
        );
        Ok(())
    }

    fn check_cci_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let first_params = CciParams { period: Some(14) };
        let first_input = CciInput::from_candles(&candles, "close", first_params);
        let first_result = cci_with_kernel(&first_input, kernel)?;

        let second_params = CciParams { period: Some(14) };
        let second_input = CciInput::from_slice(&first_result.values, second_params);
        let second_result = cci_with_kernel(&second_input, kernel)?;

        assert_eq!(second_result.values.len(), first_result.values.len());
        if second_result.values.len() > 28 {
            for i in 28..second_result.values.len() {
                assert!(
                    !second_result.values[i].is_nan(),
                    "Expected no NaN after index 28, found NaN at index {}",
                    i
                );
            }
        }
        Ok(())
    }

    fn check_cci_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let input = CciInput::from_candles(&candles, "close", CciParams { period: Some(14) });
        let res = cci_with_kernel(&input, kernel)?;
        assert_eq!(res.values.len(), candles.close.len());
        if res.values.len() > 240 {
            for (i, &val) in res.values[240..].iter().enumerate() {
                assert!(
                    !val.is_nan(),
                    "[{}] Found unexpected NaN at out-index {}",
                    test_name,
                    240 + i
                );
            }
        }
        Ok(())
    }

    fn check_cci_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let period = 14;
        let input = CciInput::from_candles(
            &candles,
            "close",
            CciParams {
                period: Some(period),
            },
        );
        let batch_output = cci_with_kernel(&input, kernel)?.values;

        let mut stream = CciStream::try_new(CciParams {
            period: Some(period),
        })?;

        let mut stream_values = Vec::with_capacity(candles.close.len());
        for &price in &candles.close {
            match stream.update(price) {
                Some(cci_val) => stream_values.push(cci_val),
                None => stream_values.push(f64::NAN),
            }
        }

        assert_eq!(batch_output.len(), stream_values.len());
        for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
            if b.is_nan() && s.is_nan() {
                continue;
            }
            let diff = (b - s).abs();
            assert!(
                diff < 1e-9,
                "[{}] CCI streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
                test_name,
                i,
                b,
                s,
                diff
            );
        }
        Ok(())
    }

    fn check_cci_empty_input(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let empty_data: &[f64] = &[];
        let params = CciParams { period: Some(14) };
        let input = CciInput::from_slice(empty_data, params);
        let res = cci_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] CCI should fail with empty input data",
            test_name
        );
        if let Err(e) = res {
            match e {
                CciError::EmptyInputData => {}
                other => panic!(
                    "[{}] Expected EmptyInputData error, got: {:?}",
                    test_name, other
                ),
            }
        }
        Ok(())
    }

    // Debug builds only: scans outputs for the poison bit patterns used by the
    // uninitialised-buffer helpers, to catch cells the kernels never wrote.
    #[cfg(debug_assertions)]
    fn check_cci_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);

        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let input = CciInput::from_candles(&candles, "close", CciParams::default());
        let output = cci_with_kernel(&input, kernel)?;

        let params_20 = CciParams { period: Some(20) };
        let input_20 = CciInput::from_candles(&candles, "hlc3", params_20);
        let output_20 = cci_with_kernel(&input_20, kernel)?;

        for output in [output, output_20] {
            for (i, &val) in output.values.iter().enumerate() {
                if val.is_nan() {
                    continue;
                }

                let bits = val.to_bits();

                if bits == 0x11111111_11111111 {
                    panic!(
                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {}",
                        test_name, val, bits, i
                    );
                }

                if bits == 0x22222222_22222222 {
                    panic!(
                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {}",
                        test_name, val, bits, i
                    );
                }

                if bits == 0x33333333_33333333 {
                    panic!(
                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {}",
                        test_name, val, bits, i
                    );
                }
            }
        }

        Ok(())
    }

    #[cfg(not(debug_assertions))]
    fn check_cci_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    // Property test: compares `kernel` against the scalar kernel and against a naive
    // SMA / mean-absolute-deviation reference on random finite data.
    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_cci_property(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        let strat = (1usize..=64).prop_flat_map(|period| {
            (
                prop::collection::vec(
                    (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
                    period..400,
                ),
                Just(period),
            )
        });

        proptest::test_runner::TestRunner::default().run(&strat, |(data, period)| {
            let params = CciParams {
                period: Some(period),
            };
            let input = CciInput::from_slice(&data, params);

            let CciOutput { values: out } = cci_with_kernel(&input, kernel).unwrap();

            let CciOutput { values: ref_out } = cci_with_kernel(&input, Kernel::Scalar).unwrap();

            for i in 0..(period - 1) {
                prop_assert!(
                    out[i].is_nan(),
                    "[{}] Expected NaN at index {} during warmup period, got {}",
                    test_name,
                    i,
                    out[i]
                );
            }

            for i in (period - 1)..data.len() {
                prop_assert!(
                    !out[i].is_nan(),
                    "[{}] Expected valid value at index {} after warmup, got NaN",
                    test_name,
                    i
                );
            }

            for i in 0..data.len() {
                let y = out[i];
                let r = ref_out[i];

                if y.is_nan() && r.is_nan() {
                    continue;
                }

                let y_bits = y.to_bits();
                let r_bits = r.to_bits();
                let ulp_diff = if y_bits > r_bits {
                    y_bits - r_bits
                } else {
                    r_bits - y_bits
                };

                prop_assert!(
                    ulp_diff <= 8,
                    "[{}] Kernel mismatch at index {}: {} != {} (ULP diff: {})",
                    test_name,
                    i,
                    y,
                    r,
                    ulp_diff
                );
            }

            if data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10) && data.len() >= period {
                for i in (period - 1)..data.len() {
                    prop_assert!(
                        out[i].abs() < 1e-9,
                        "[{}] CCI should be ~0 for constant prices, got {} at index {}",
                        test_name,
                        out[i],
                        i
                    );
                }
            }

            for i in (period - 1)..data.len() {
                let window_start = i + 1 - period;
                let window = &data[window_start..=i];

                let sum: f64 = window.iter().sum();
                let sma = sum / period as f64;

                let mad: f64 = window.iter().map(|&x| (x - sma).abs()).sum::<f64>() / period as f64;

                let price = data[i];
                let expected_cci = if mad == 0.0 {
                    0.0
                } else {
                    (price - sma) / (0.015 * mad)
                };

                let actual_cci = out[i];
                let diff = (actual_cci - expected_cci).abs();

                prop_assert!(
                    diff < 1e-10,
                    "[{}] CCI calculation mismatch at index {}: expected {}, got {}, diff {}",
                    test_name,
                    i,
                    expected_cci,
                    actual_cci,
                    diff
                );
            }

            if period == 1 {
                for i in 0..data.len() {
                    prop_assert!(
                        out[i].abs() < 1e-9,
                        "[{}] CCI should be ~0 for period=1, got {} at index {}",
                        test_name,
                        out[i],
                        i
                    );
                }
            }

            for i in (period - 1)..data.len() {
                if out[i].abs() > 500.0 {
                    eprintln!(
                        "[{}] Warning: Extreme CCI value {} at index {} (typical range is -300 to 300)",
                        test_name,
                        out[i],
                        i
                    );
                }
            }

            for i in (period - 1)..data.len() {
                let window_start = i + 1 - period;
                let window = &data[window_start..=i];
                let sum: f64 = window.iter().sum();
                let sma = sum / period as f64;
                let mad: f64 = window.iter().map(|&x| (x - sma).abs()).sum::<f64>() / period as f64;

                if mad > 0.0 && mad < 1e-12 {
                    let actual_cci = out[i];

                    prop_assert!(
                        actual_cci.is_finite(),
                        "[{}] CCI should be finite even with very small MAD ({}) at index {}, got {}",
                        test_name,
                        mad,
                        i,
                        actual_cci
                    );

                    let price = data[i];
                    let expected_cci = (price - sma) / (0.015 * mad);
                    let relative_error = ((actual_cci - expected_cci) / expected_cci).abs();

                    prop_assert!(
                        relative_error < 1e-8 || (actual_cci - expected_cci).abs() < 1e-10,
                        "[{}] CCI calculation with small MAD at index {}: expected {}, got {}, relative error {}",
                        test_name,
                        i,
                        expected_cci,
                        actual_cci,
                        relative_error
                    );
                }
            }

            Ok(())
        })?;

        Ok(())
    }

    // Generates `<check>_scalar_f64`, `<check>_avx2_f64` and `<check>_avx512_f64` tests
    // for each check function.
    //
    // The AVX variants carry their own `#[cfg]`. An attribute written in front of a
    // `$( ... )*` repetition attaches only to the first item that repetition expands
    // to, which previously left all other AVX tests ungated. Each generated test returns
    // the check's `Result`, so a propagated `Err` (missing data file, failed proptest
    // run, ...) fails the test.
    macro_rules! generate_all_cci_tests {
        ($($test_fn:ident),*) => {
            paste! {
                $(
                    #[test]
                    fn [<$test_fn _scalar_f64>]() -> Result<(), Box<dyn Error>> {
                        $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar)
                    }

                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx2_f64>]() -> Result<(), Box<dyn Error>> {
                        $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2)
                    }

                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx512_f64>]() -> Result<(), Box<dyn Error>> {
                        $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512)
                    }
                )*
            }
        }
    }

    generate_all_cci_tests!(
        check_cci_partial_params,
        check_cci_accuracy,
        check_cci_default_candles,
        check_cci_zero_period,
        check_cci_period_exceeds_length,
        check_cci_very_small_dataset,
        check_cci_reinput,
        check_cci_nan_handling,
        check_cci_streaming,
        check_cci_empty_input,
        check_cci_no_poison
    );

    #[cfg(feature = "proptest")]
    generate_all_cci_tests!(check_cci_property);

    #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
    #[test]
    fn test_cci_into_matches_api() -> Result<(), Box<dyn Error>> {
        let mut data = Vec::with_capacity(256);
        data.extend_from_slice(&[f64::NAN, f64::NAN, f64::NAN]);
        for i in 0..253usize {
            let x = (i as f64 * 0.037).sin() * 5.0 + 100.0 + (i % 7) as f64 * 0.1;
            data.push(x);
        }

        let input = CciInput::from_slice(&data, CciParams::default());

        let baseline = cci(&input)?.values;

        let mut out = vec![0.0; data.len()];
        cci_into(&input, &mut out)?;

        assert_eq!(baseline.len(), out.len());

        fn eq_or_nan(a: f64, b: f64) -> bool {
            (a.is_nan() && b.is_nan()) || (a == b) || ((a - b).abs() <= 1e-12)
        }

        for (i, (a, b)) in baseline.iter().zip(out.iter()).enumerate() {
            assert!(
                eq_or_nan(*a, *b),
                "cci_into mismatch at idx {}: baseline={}, into={}",
                i,
                a,
                b
            );
        }

        Ok(())
    }

    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);
        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;

        let output = CciBatchBuilder::new()
            .kernel(kernel)
            .apply_candles(&c, "hlc3")?;

        let def = CciParams::default();
        let row = output.values_for(&def).expect("default row missing");
        assert_eq!(row.len(), c.close.len());

        let expected = [
            -51.55252564125841,
            -43.50326506381541,
            -64.05117302269149,
            -39.05150631680948,
            -152.50523930896998,
        ];
        let start = row.len() - 5;
        for (i, &v) in row[start..].iter().enumerate() {
            assert!(
                (v - expected[i]).abs() < 1e-6,
                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
            );
        }
        Ok(())
    }

    #[cfg(debug_assertions)]
    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);

        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;

        let output = CciBatchBuilder::new()
            .kernel(kernel)
            .period_range(10, 30, 10)
            .apply_candles(&c, "close")?;

        for (idx, &val) in output.values.iter().enumerate() {
            if val.is_nan() {
                continue;
            }

            let bits = val.to_bits();
            let row = idx / output.cols;
            let col = idx % output.cols;

            if bits == 0x11111111_11111111 {
                panic!(
                    "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at row {} col {} (flat index {})",
                    test, val, bits, row, col, idx
                );
            }

            if bits == 0x22222222_22222222 {
                panic!(
                    "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at row {} col {} (flat index {})",
                    test, val, bits, row, col, idx
                );
            }

            if bits == 0x33333333_33333333 {
                panic!(
                    "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at row {} col {} (flat index {})",
                    test, val, bits, row, col, idx
                );
            }
        }

        Ok(())
    }

    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    // Generates scalar / AVX2 / AVX-512 / auto-detect batch tests for one check function.
    // Each test returns the check's `Result` so propagated errors fail the test.
    macro_rules! gen_batch_tests {
        ($fn_name:ident) => {
            paste! {
                #[test]
                fn [<$fn_name _scalar>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch)
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$fn_name _avx2>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch)
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$fn_name _avx512>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch)
                }
                #[test]
                fn [<$fn_name _auto_detect>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto)
                }
            }
        };
    }
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
}
1988
/// Python binding: computes CCI over a 1-D float64 array for a single `period`.
///
/// `kernel` is an optional kernel name validated by `validate_kernel`. The computation
/// runs with the GIL released. Any `CciError` is raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "cci")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn cci_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    // Borrow the input first, then validate the kernel, preserving the original error order.
    let series = data.as_slice()?;
    let chosen_kernel = validate_kernel(kernel, false)?;

    let request = CciInput::from_slice(
        series,
        CciParams {
            period: Some(period),
        },
    );

    let computed = py
        .allow_threads(|| cci_with_kernel(&request, chosen_kernel))
        .map(|result| result.values)
        .map_err(|err| PyValueError::new_err(err.to_string()))?;

    Ok(computed.into_pyarray(py))
}
2014
/// Python wrapper around the incremental `CciStream`, exposed as `CciStream`.
#[cfg(feature = "python")]
#[pyclass(name = "CciStream")]
pub struct CciStreamPy {
    /// Underlying Rust streaming state.
    stream: CciStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl CciStreamPy {
    /// Creates a stream for `period`. Construction errors are raised as `ValueError`.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        CciStream::try_new(CciParams {
            period: Some(period),
        })
        .map(|stream| CciStreamPy { stream })
        .map_err(|err| PyValueError::new_err(err.to_string()))
    }

    /// Feeds one value and returns whatever `CciStream::update` yields (`None` maps to `None`).
    fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
2038
/// Python binding: computes CCI for a sweep of periods over one float64 series.
///
/// Returns a dict with:
/// - `"values"`: a `(rows, cols)` array, one row per period.
/// - `"periods"`: a uint64 array of the period used for each row.
///
/// The computation runs with the GIL released. Errors are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "cci_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn cci_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let series = data.as_slice()?;
    let chosen_kernel = validate_kernel(kernel, true)?;

    let range = CciBatchRange {
        period: period_range,
    };

    let batch = py
        .allow_threads(|| cci_batch_with_kernel(series, &range, chosen_kernel))
        .map_err(|err| PyValueError::new_err(err.to_string()))?;

    // Collect row periods before `batch.values` is moved into the numpy array.
    let row_periods: Vec<u64> = batch
        .combos
        .iter()
        .map(|combo| combo.period.unwrap() as u64)
        .collect();

    let matrix = batch
        .values
        .into_pyarray(py)
        .reshape((batch.rows, batch.cols))?;

    let result = PyDict::new(py);
    result.set_item("values", matrix)?;
    result.set_item("periods", row_periods.into_pyarray(py))?;
    Ok(result)
}
2079
/// Python handle to an f32 CCI result matrix that lives in CUDA device memory.
///
/// The handle exposes `__cuda_array_interface__` and the DLPack protocol.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", name = "CciDeviceArrayF32", unsendable)]
pub struct CciDeviceArrayF32Py {
    /// Device buffer plus its `rows` x `cols` shape.
    pub(crate) inner: DeviceArrayF32,
    /// Keeps the CUDA context that owns `inner` alive for this object's lifetime.
    pub(crate) _ctx: Arc<Context>,
    /// CUDA device ordinal reported through `__dlpack_device__`.
    pub(crate) device_id: u32,
    /// Raw stream handle captured at creation; stored but not read by these methods.
    pub(crate) stream: usize,
}

#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl CciDeviceArrayF32Py {
    /// CUDA Array Interface (version 3) describing a row-major float32 matrix.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let elem_bytes = std::mem::size_of::<f32>();
        let arr = &self.inner;
        let iface = PyDict::new(py);
        iface.set_item("shape", (arr.rows, arr.cols))?;
        iface.set_item("typestr", "<f4")?;
        // Byte strides for a C-contiguous (rows, cols) layout.
        iface.set_item("strides", (arr.cols * elem_bytes, elem_bytes))?;
        // (device pointer, read_only = false)
        iface.set_item("data", (arr.device_ptr() as usize, false))?;

        iface.set_item("version", 3)?;
        Ok(iface)
    }

    /// DLPack device tuple: `2` is `kDLCUDA`, followed by the device ordinal.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.device_id as i32)
    }

    /// Exports the buffer as a DLPack capsule and transfers ownership out of this object.
    ///
    /// After a successful call this object holds an empty 0 x 0 placeholder buffer.
    /// A request for a different `dl_device` is rejected with `ValueError`, whether or
    /// not `copy` is requested. `stream` is ignored; the constructors in this file
    /// synchronize their stream before wrapping the buffer.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    pub fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        let (kdl, alloc_dev) = self.__dlpack_device__();

        // Only a well-formed (type, id) pair is checked; anything else is ignored.
        let requested = dl_device
            .as_ref()
            .and_then(|obj| obj.extract::<(i32, i32)>(py).ok());
        if let Some((dev_ty, dev_id)) = requested {
            if dev_ty != kdl || dev_id != alloc_dev {
                let copy_requested = copy
                    .as_ref()
                    .and_then(|flag| flag.extract::<bool>(py).ok())
                    .unwrap_or(false);
                let message = if copy_requested {
                    "device copy not implemented for __dlpack__"
                } else {
                    "dl_device mismatch for __dlpack__"
                };
                return Err(PyValueError::new_err(message));
            }
        }
        let _ = stream;

        // Swap an empty placeholder in so the real buffer can be moved into the capsule.
        let placeholder =
            DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let taken = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32 {
                buf: placeholder,
                rows: 0,
                cols: 0,
            },
        );
        let DeviceArrayF32 { buf, rows, cols } = taken;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
/// Python binding: runs the CUDA CCI period sweep over an f32 series on `device_id`.
///
/// Returns the result as a device-resident `CciDeviceArrayF32`. Raises `ValueError` if
/// CUDA is unavailable or any CUDA step fails. The stream is synchronized before
/// returning.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "cci_cuda_batch_dev")]
#[pyo3(signature = (data, period_range, device_id=0))]
pub fn cci_cuda_batch_dev_py(
    py: Python<'_>,
    data: numpy::PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<CciDeviceArrayF32Py> {
    use crate::cuda::cuda_available;

    fn to_py_err<E: std::fmt::Display>(err: E) -> PyErr {
        PyValueError::new_err(err.to_string())
    }

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let host = data.as_slice()?;
    let range = CciBatchRange {
        period: period_range,
    };

    py.allow_threads(|| -> PyResult<_> {
        let engine = CudaCci::new(device_id).map_err(to_py_err)?;
        let ordinal = engine.device_id();
        let context = engine.context_arc();
        let device_out = engine.cci_batch_dev(host, &range).map_err(to_py_err)?;
        engine.stream().synchronize().map_err(to_py_err)?;
        Ok(CciDeviceArrayF32Py {
            inner: device_out,
            _ctx: context,
            device_id: ordinal,
            stream: engine.stream_handle_usize(),
        })
    })
}
2200
/// Python binding: CUDA CCI with one `period` applied to many series.
///
/// The input `data_tm` is passed to the kernel with `cols` and `rows`; the method name
/// indicates a time-major layout (assumed, not verified here). Returns a device-resident
/// `CciDeviceArrayF32`. Raises `ValueError` if CUDA is unavailable or any CUDA step
/// fails. The stream is synchronized before returning.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "cci_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm, cols, rows, period, device_id=0))]
pub fn cci_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm: numpy::PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    period: usize,
    device_id: usize,
) -> PyResult<CciDeviceArrayF32Py> {
    use crate::cuda::cuda_available;

    fn to_py_err<E: std::fmt::Display>(err: E) -> PyErr {
        PyValueError::new_err(err.to_string())
    }

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let host = data_tm.as_slice()?;

    py.allow_threads(|| -> PyResult<_> {
        let engine = CudaCci::new(device_id).map_err(to_py_err)?;
        let ordinal = engine.device_id();
        let context = engine.context_arc();
        let device_out = engine
            .cci_many_series_one_param_time_major_dev(host, cols, rows, period)
            .map_err(to_py_err)?;
        engine.stream().synchronize().map_err(to_py_err)?;
        Ok(CciDeviceArrayF32Py {
            inner: device_out,
            _ctx: context,
            device_id: ordinal,
            stream: engine.stream_handle_usize(),
        })
    })
}
2236
/// WASM binding: computes CCI for `period` and returns a new array of `data.len()` values.
///
/// Errors from the computation are returned to JS as string values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cci_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let request = CciInput::from_slice(
        data,
        CciParams {
            period: Some(period),
        },
    );

    let mut values = vec![0.0; data.len()];
    match cci_into_slice(&mut values, &request, detect_best_kernel()) {
        Ok(_) => Ok(values),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
2252
/// WASM binding: computes CCI from `len` values at `in_ptr` into `len` slots at `out_ptr`.
///
/// Both pointers must be non-null and valid for `len` f64 values (for example, from
/// `cci_alloc`). When `in_ptr == out_ptr` the result goes through a scratch buffer first,
/// so in-place use is supported. Partially overlapping buffers are not handled.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cci_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to cci_into"));
    }

    // SAFETY: the caller guarantees `in_ptr` is valid for `len` reads.
    let series = unsafe { std::slice::from_raw_parts(in_ptr, len) };

    if period == 0 || period > len {
        return Err(JsValue::from_str("Invalid period"));
    }

    let request = CciInput::from_slice(
        series,
        CciParams {
            period: Some(period),
        },
    );
    let aliased = in_ptr == out_ptr;
    let kernel = detect_best_kernel();

    if aliased {
        // Compute out of place, then copy over the shared buffer.
        let mut scratch = vec![0.0; len];
        cci_into_slice(&mut scratch, &request, kernel)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // SAFETY: the caller guarantees `out_ptr` is valid for `len` writes.
        let dst = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        dst.copy_from_slice(&scratch);
    } else {
        // SAFETY: the caller guarantees `out_ptr` is valid for `len` writes and does not
        // overlap the input.
        let dst = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        cci_into_slice(dst, &request, kernel).map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(())
}
2292
/// WASM helper: allocates space for `len` f64 values and hands the pointer to JS.
///
/// The memory is not initialised. Release it with `cci_free(ptr, len)` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cci_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` keeps the allocation alive after this function returns.
    let mut storage = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    storage.as_mut_ptr()
}
2301
/// WASM helper: releases a buffer obtained from `cci_alloc(len)`.
///
/// `len` must be the same value that was passed to `cci_alloc`. A null `ptr` is
/// ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cci_free(ptr: *mut f64, len: usize) {
    // `Vec::from_raw_parts` requires a non-null pointer.
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` came from `cci_alloc(len)`, i.e. `Vec::<f64>::with_capacity(len)`, so
    // its capacity is exactly `len`. The length is 0 because JS may never have written the
    // buffer, and `from_raw_parts` requires the first `length` elements to be initialised.
    // f64 has no destructor, so only the allocation is released.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
2309
/// WASM binding: CCI over a `(start, end, step)` period sweep, returned as one flat array.
///
/// Rows are computed serially with the best detected kernel. Use
/// `cci_batch_metadata_js` to get the period of each row.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cci_batch_js(
    data: &[f64],
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let range = CciBatchRange {
        period: (period_start, period_end, period_step),
    };

    match cci_batch_inner(data, &range, detect_best_kernel(), false) {
        Ok(batch) => Ok(batch.values),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
2326
/// WASM binding: returns, as f64, the period of each row `cci_batch_js` would produce
/// for the same sweep.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cci_batch_metadata_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let range = CciBatchRange {
        period: (period_start, period_end, period_step),
    };

    let grid = expand_grid(&range).map_err(|e| JsValue::from_str(&e.to_string()))?;
    Ok(grid
        .into_iter()
        .map(|combo| combo.period.unwrap() as f64)
        .collect())
}
2347
/// Configuration object accepted by the JS-facing `cci_batch` function,
/// deserialized from a `JsValue` via `serde_wasm_bindgen`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct CciBatchConfig {
    /// Period sweep as `(start, end, step)`; passed unchanged to
    /// `CciBatchRange::period`.
    pub period_range: (usize, usize, usize),
}
2353
/// Result object returned to JS by `cci_batch`, serialized via
/// `serde_wasm_bindgen`. Every field is copied from the batch output.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct CciBatchJsOutput {
    /// Flattened CCI values for all combinations. Presumably `rows * cols`
    /// entries in row-major order; confirm against `cci_batch_inner`.
    pub values: Vec<f64>,
    /// Parameter set used for each combination (presumably one per row).
    pub combos: Vec<CciParams>,
    /// Number of rows reported by the batch output (presumably one per combination).
    pub rows: usize,
    /// Number of columns reported by the batch output (presumably the input length).
    pub cols: usize,
}
2362
/// JS entry point `cci_batch`: reads a `CciBatchConfig` from `config`, runs
/// the CCI batch over `data`, and returns a serialized `CciBatchJsOutput`.
///
/// Config parsing, computation and serialization failures are all reported
/// to JS as string errors.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = cci_batch)]
pub fn cci_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let CciBatchConfig { period_range } = serde_wasm_bindgen::from_value::<CciBatchConfig>(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let range = CciBatchRange {
        period: period_range,
    };

    let batch = match cci_batch_inner(data, &range, detect_best_kernel(), false) {
        Ok(result) => result,
        Err(err) => return Err(JsValue::from_str(&err.to_string())),
    };

    let payload = CciBatchJsOutput {
        rows: batch.rows,
        cols: batch.cols,
        values: batch.values,
        combos: batch.combos,
    };

    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2386
/// Computes a CCI batch over a raw input buffer and writes the results into a
/// raw output buffer.
///
/// # Parameters
/// - `in_ptr` / `len`: `len` initialized `f64` input values.
/// - `out_ptr`: a buffer with room for `rows * len` `f64` values, where
///   `rows` is the number of combinations `expand_grid` yields for
///   `(period_start, period_end, period_step)`.
///
/// # Returns
/// The number of rows written.
///
/// # Errors
/// - either pointer is null;
/// - the sweep is invalid;
/// - `rows * len` overflows `usize`;
/// - the batch computation fails.
///
/// The input and output buffers may overlap, for example when the caller
/// reuses the input buffer. In that case the input is copied first, because a
/// shared slice and a mutable slice over the same memory is undefined behavior.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cci_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to cci_batch_into"));
    }

    let sweep = CciBatchRange {
        period: (period_start, period_end, period_step),
    };

    let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let cols = len;
    // usize is 32-bit on wasm32, so a wide sweep over a long series can overflow.
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("cci_batch_into: rows * cols overflows usize"))?;

    // Byte ranges of the input and output buffers, used to detect overlap.
    let elem = std::mem::size_of::<f64>();
    let in_lo = in_ptr as usize;
    let out_lo = out_ptr as usize;
    let in_hi = in_lo.saturating_add(len.saturating_mul(elem));
    let out_hi = out_lo.saturating_add(total.saturating_mul(elem));
    let overlaps = in_lo < out_hi && out_lo < in_hi;

    // Snapshot the input before any `&mut` view of the output exists.
    let owned_input: Option<Vec<f64>> = if overlaps {
        // SAFETY: the caller guarantees `in_ptr` points to `len` initialized f64s.
        Some(unsafe { std::slice::from_raw_parts(in_ptr, len) }.to_vec())
    } else {
        None
    };
    let data: &[f64] = match owned_input.as_deref() {
        Some(copy) => copy,
        // SAFETY: as above. This range does not overlap the output, so the
        // shared borrow cannot alias the mutable slice created below.
        None => unsafe { std::slice::from_raw_parts(in_ptr, len) },
    };

    // SAFETY: the caller guarantees `out_ptr` is valid for `rows * len` f64
    // writes, and no other live reference points into that range.
    let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };

    cci_batch_inner_into(data, &sweep, detect_best_kernel(), false, out)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    Ok(rows)
}