1#[cfg(all(feature = "python", feature = "cuda"))]
2use cust::context::Context as CudaContext;
3#[cfg(feature = "python")]
4use numpy::{IntoPyArray, PyArray1};
5#[cfg(feature = "python")]
6use pyo3::exceptions::PyValueError;
7#[cfg(feature = "python")]
8use pyo3::prelude::*;
9#[cfg(feature = "python")]
10use pyo3::types::{PyDict, PyList};
11#[cfg(all(feature = "python", feature = "cuda"))]
12use std::sync::Arc as StdArc;
13
14#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
15use serde::{Deserialize, Serialize};
16#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
17use wasm_bindgen::prelude::*;
18
19use crate::utilities::data_loader::{source_type, Candles};
20use crate::utilities::enums::Kernel;
21use crate::utilities::helpers::{
22 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
23 make_uninit_matrix,
24};
25#[cfg(feature = "python")]
26use crate::utilities::kernel_validation::validate_kernel;
27#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
28use core::arch::x86_64::*;
29#[cfg(not(target_arch = "wasm32"))]
30use rayon::prelude::*;
31use std::convert::AsRef;
32use std::error::Error;
33use std::mem::ManuallyDrop;
34use thiserror::Error;
35
36#[inline(always)]
37fn correlation_cycle_auto_kernel() -> Kernel {
38 Kernel::Scalar
39}
40
41impl<'a> AsRef<[f64]> for CorrelationCycleInput<'a> {
42 #[inline(always)]
43 fn as_ref(&self) -> &[f64] {
44 match &self.data {
45 CorrelationCycleData::Slice(slice) => slice,
46 CorrelationCycleData::Candles { candles, source } => source_type(candles, source),
47 }
48 }
49}
50
51#[derive(Debug, Clone)]
52pub enum CorrelationCycleData<'a> {
53 Candles {
54 candles: &'a Candles,
55 source: &'a str,
56 },
57 Slice(&'a [f64]),
58}
59
/// The four output series produced by the correlation-cycle indicator.
///
/// Every vector has the same length as the input. Positions before the first
/// computed value are NaN-filled by the entry points that allocate this struct.
#[derive(Clone, Debug)]
pub struct CorrelationCycleOutput {
    /// Correlation of the trailing window with the cosine reference.
    pub real: Vec<f64>,
    /// Correlation of the trailing window with the negated-sine reference.
    pub imag: Vec<f64>,
    /// Phase angle in degrees derived from `real` and `imag`.
    pub angle: Vec<f64>,
    /// Market-state flag: `1.0`, `-1.0` or `0.0` depending on how much the
    /// angle moved since the previous bar and on the sign of the angle.
    pub state: Vec<f64>,
}
67
/// Tunable parameters of the correlation-cycle indicator.
///
/// A `None` field means "use the default" (period `20`, threshold `9.0`).
#[derive(Clone, Debug)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct CorrelationCycleParams {
    /// Length of the trailing correlation window, in samples.
    pub period: Option<usize>,
    /// Maximum bar-to-bar angle change (degrees) still counted as a steady state.
    pub threshold: Option<f64>,
}
77
78impl Default for CorrelationCycleParams {
79 fn default() -> Self {
80 Self {
81 period: Some(20),
82 threshold: Some(9.0),
83 }
84 }
85}
86
87#[derive(Debug, Clone)]
88pub struct CorrelationCycleInput<'a> {
89 pub data: CorrelationCycleData<'a>,
90 pub params: CorrelationCycleParams,
91}
92
93impl<'a> CorrelationCycleInput<'a> {
94 #[inline]
95 pub fn from_candles(
96 candles: &'a Candles,
97 source: &'a str,
98 params: CorrelationCycleParams,
99 ) -> Self {
100 Self {
101 data: CorrelationCycleData::Candles { candles, source },
102 params,
103 }
104 }
105
106 #[inline]
107 pub fn from_slice(slice: &'a [f64], params: CorrelationCycleParams) -> Self {
108 Self {
109 data: CorrelationCycleData::Slice(slice),
110 params,
111 }
112 }
113
114 #[inline]
115 pub fn with_default_candles(candles: &'a Candles) -> Self {
116 Self::from_candles(candles, "close", CorrelationCycleParams::default())
117 }
118
119 #[inline]
120 pub fn get_period(&self) -> usize {
121 self.params.period.unwrap_or(20)
122 }
123
124 #[inline]
125 pub fn get_threshold(&self) -> f64 {
126 self.params.threshold.unwrap_or(9.0)
127 }
128}
129
130#[derive(Debug, Clone)]
131pub struct CorrelationCycleBuilder {
132 period: Option<usize>,
133 threshold: Option<f64>,
134 kernel: Kernel,
135}
136
137impl Default for CorrelationCycleBuilder {
138 fn default() -> Self {
139 Self {
140 period: None,
141 threshold: None,
142 kernel: Kernel::Auto,
143 }
144 }
145}
146
147impl CorrelationCycleBuilder {
148 #[inline(always)]
149 pub fn new() -> Self {
150 Self::default()
151 }
152 #[inline(always)]
153 pub fn period(mut self, n: usize) -> Self {
154 self.period = Some(n);
155 self
156 }
157 #[inline(always)]
158 pub fn threshold(mut self, t: f64) -> Self {
159 self.threshold = Some(t);
160 self
161 }
162 #[inline(always)]
163 pub fn kernel(mut self, k: Kernel) -> Self {
164 self.kernel = k;
165 self
166 }
167
168 #[inline(always)]
169 pub fn apply(self, c: &Candles) -> Result<CorrelationCycleOutput, CorrelationCycleError> {
170 let p = CorrelationCycleParams {
171 period: self.period,
172 threshold: self.threshold,
173 };
174 let i = CorrelationCycleInput::from_candles(c, "close", p);
175 correlation_cycle_with_kernel(&i, self.kernel)
176 }
177
178 #[inline(always)]
179 pub fn apply_slice(self, d: &[f64]) -> Result<CorrelationCycleOutput, CorrelationCycleError> {
180 let p = CorrelationCycleParams {
181 period: self.period,
182 threshold: self.threshold,
183 };
184 let i = CorrelationCycleInput::from_slice(d, p);
185 correlation_cycle_with_kernel(&i, self.kernel)
186 }
187
188 #[inline(always)]
189 pub fn into_stream(self) -> Result<CorrelationCycleStream, CorrelationCycleError> {
190 let p = CorrelationCycleParams {
191 period: self.period,
192 threshold: self.threshold,
193 };
194 CorrelationCycleStream::try_new(p)
195 }
196}
197
198#[derive(Debug, Error)]
199pub enum CorrelationCycleError {
200 #[error("correlation_cycle: Empty data provided.")]
201 EmptyInputData,
202 #[error("correlation_cycle: Invalid period: period = {period}, data length = {data_len}")]
203 InvalidPeriod { period: usize, data_len: usize },
204 #[error("correlation_cycle: Not enough valid data: needed = {needed}, valid = {valid}")]
205 NotEnoughValidData { needed: usize, valid: usize },
206 #[error("correlation_cycle: All values are NaN.")]
207 AllValuesNaN,
208 #[error("correlation_cycle: output length mismatch: expected = {expected}, got = {got}")]
209 OutputLengthMismatch { expected: usize, got: usize },
210 #[error("correlation_cycle: invalid range: start={start}, end={end}, step={step}")]
211 InvalidRange {
212 start: usize,
213 end: usize,
214 step: usize,
215 },
216 #[error("correlation_cycle: invalid kernel for batch: {0:?}")]
217 InvalidKernelForBatch(Kernel),
218 #[error("correlation_cycle: invalid input: {0}")]
219 InvalidInput(String),
220}
221
222#[inline]
223pub fn correlation_cycle(
224 input: &CorrelationCycleInput,
225) -> Result<CorrelationCycleOutput, CorrelationCycleError> {
226 correlation_cycle_with_kernel(input, Kernel::Auto)
227}
228
229#[inline(always)]
230pub fn correlation_cycle_with_kernel(
231 input: &CorrelationCycleInput,
232 kernel: Kernel,
233) -> Result<CorrelationCycleOutput, CorrelationCycleError> {
234 let data: &[f64] = input.as_ref();
235 let len = data.len();
236 if len == 0 {
237 return Err(CorrelationCycleError::EmptyInputData);
238 }
239 let first = data
240 .iter()
241 .position(|x| !x.is_nan())
242 .ok_or(CorrelationCycleError::AllValuesNaN)?;
243 let period = input.get_period();
244 if period == 0 || period > len {
245 return Err(CorrelationCycleError::InvalidPeriod {
246 period,
247 data_len: len,
248 });
249 }
250 if len - first < period {
251 return Err(CorrelationCycleError::NotEnoughValidData {
252 needed: period,
253 valid: len - first,
254 });
255 }
256
257 let threshold = input.get_threshold();
258 let chosen = match kernel {
259 Kernel::Auto => correlation_cycle_auto_kernel(),
260 k => k,
261 };
262
263 let warm_real_imag_angle = first + period;
264 let warm_state = first + period + 1;
265
266 let mut real = alloc_with_nan_prefix(len, warm_real_imag_angle);
267 let mut imag = alloc_with_nan_prefix(len, warm_real_imag_angle);
268 let mut angle = alloc_with_nan_prefix(len, warm_real_imag_angle);
269 let mut state = alloc_with_nan_prefix(len, warm_state);
270
271 unsafe {
272 match chosen {
273 Kernel::Scalar | Kernel::ScalarBatch => correlation_cycle_compute_into(
274 data, period, threshold, first, &mut real, &mut imag, &mut angle, &mut state,
275 ),
276 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
277 Kernel::Avx2 | Kernel::Avx2Batch => correlation_cycle_avx2(
278 data, period, threshold, first, &mut real, &mut imag, &mut angle, &mut state,
279 ),
280 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
281 Kernel::Avx512 | Kernel::Avx512Batch => correlation_cycle_avx512(
282 data, period, threshold, first, &mut real, &mut imag, &mut angle, &mut state,
283 ),
284 _ => unreachable!(),
285 }
286 }
287
288 Ok(CorrelationCycleOutput {
289 real,
290 imag,
291 angle,
292 state,
293 })
294}
295
296#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
297pub fn correlation_cycle_into(
298 input: &CorrelationCycleInput,
299 out_real: &mut [f64],
300 out_imag: &mut [f64],
301 out_angle: &mut [f64],
302 out_state: &mut [f64],
303) -> Result<(), CorrelationCycleError> {
304 let data: &[f64] = input.as_ref();
305 let len = data.len();
306 if len == 0 {
307 return Err(CorrelationCycleError::EmptyInputData);
308 }
309
310 if out_real.len() != len
311 || out_imag.len() != len
312 || out_angle.len() != len
313 || out_state.len() != len
314 {
315 let got = *[
316 out_real.len(),
317 out_imag.len(),
318 out_angle.len(),
319 out_state.len(),
320 ]
321 .iter()
322 .min()
323 .unwrap_or(&0);
324 return Err(CorrelationCycleError::OutputLengthMismatch { expected: len, got });
325 }
326
327 let first = data
328 .iter()
329 .position(|x| !x.is_nan())
330 .ok_or(CorrelationCycleError::AllValuesNaN)?;
331
332 let period = input.get_period();
333 if period == 0 || period > len {
334 return Err(CorrelationCycleError::InvalidPeriod {
335 period,
336 data_len: len,
337 });
338 }
339 if len - first < period {
340 return Err(CorrelationCycleError::NotEnoughValidData {
341 needed: period,
342 valid: len - first,
343 });
344 }
345
346 let threshold = input.get_threshold();
347 let chosen = correlation_cycle_auto_kernel();
348
349 let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
350 let warm_ria = (first + period).min(len);
351 let warm_s = (first + period + 1).min(len);
352 for v in &mut out_real[..warm_ria] {
353 *v = qnan;
354 }
355 for v in &mut out_imag[..warm_ria] {
356 *v = qnan;
357 }
358 for v in &mut out_angle[..warm_ria] {
359 *v = qnan;
360 }
361 for v in &mut out_state[..warm_s] {
362 *v = qnan;
363 }
364
365 unsafe {
366 match chosen {
367 Kernel::Scalar | Kernel::ScalarBatch => correlation_cycle_compute_into(
368 data, period, threshold, first, out_real, out_imag, out_angle, out_state,
369 ),
370 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
371 Kernel::Avx2 | Kernel::Avx2Batch => correlation_cycle_avx2(
372 data, period, threshold, first, out_real, out_imag, out_angle, out_state,
373 ),
374 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
375 Kernel::Avx512 | Kernel::Avx512Batch => correlation_cycle_avx512(
376 data, period, threshold, first, out_real, out_imag, out_angle, out_state,
377 ),
378 _ => correlation_cycle_compute_into(
379 data, period, threshold, first, out_real, out_imag, out_angle, out_state,
380 ),
381 }
382 }
383
384 Ok(())
385}
386
387#[inline(always)]
388pub fn correlation_cycle_into_slices(
389 dst_real: &mut [f64],
390 dst_imag: &mut [f64],
391 dst_angle: &mut [f64],
392 dst_state: &mut [f64],
393 input: &CorrelationCycleInput,
394 kernel: Kernel,
395) -> Result<(), CorrelationCycleError> {
396 let data: &[f64] = input.as_ref();
397 let len = data.len();
398 if len == 0 {
399 return Err(CorrelationCycleError::EmptyInputData);
400 }
401 if dst_real.len() != len
402 || dst_imag.len() != len
403 || dst_angle.len() != len
404 || dst_state.len() != len
405 {
406 let got = *[
407 dst_real.len(),
408 dst_imag.len(),
409 dst_angle.len(),
410 dst_state.len(),
411 ]
412 .iter()
413 .min()
414 .unwrap_or(&0);
415 return Err(CorrelationCycleError::OutputLengthMismatch { expected: len, got });
416 }
417 let first = data
418 .iter()
419 .position(|x| !x.is_nan())
420 .ok_or(CorrelationCycleError::AllValuesNaN)?;
421 let period = input.get_period();
422 if period == 0 || period > len {
423 return Err(CorrelationCycleError::InvalidPeriod {
424 period,
425 data_len: len,
426 });
427 }
428 if len - first < period {
429 return Err(CorrelationCycleError::NotEnoughValidData {
430 needed: period,
431 valid: len - first,
432 });
433 }
434 let threshold = input.get_threshold();
435 let chosen = match kernel {
436 Kernel::Auto => correlation_cycle_auto_kernel(),
437 k => k,
438 };
439
440 unsafe {
441 match chosen {
442 Kernel::Scalar | Kernel::ScalarBatch => correlation_cycle_compute_into(
443 data, period, threshold, first, dst_real, dst_imag, dst_angle, dst_state,
444 ),
445 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
446 Kernel::Avx2 | Kernel::Avx2Batch => correlation_cycle_avx2(
447 data, period, threshold, first, dst_real, dst_imag, dst_angle, dst_state,
448 ),
449 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
450 Kernel::Avx512 | Kernel::Avx512Batch => correlation_cycle_avx512(
451 data, period, threshold, first, dst_real, dst_imag, dst_angle, dst_state,
452 ),
453 _ => correlation_cycle_compute_into(
454 data, period, threshold, first, dst_real, dst_imag, dst_angle, dst_state,
455 ),
456 }
457 }
458
459 let warm_ria = first + period;
460 let warm_s = first + period + 1;
461 for v in &mut dst_real[..warm_ria] {
462 *v = f64::NAN;
463 }
464 for v in &mut dst_imag[..warm_ria] {
465 *v = f64::NAN;
466 }
467 for v in &mut dst_angle[..warm_ria] {
468 *v = f64::NAN;
469 }
470 for v in &mut dst_state[..warm_s] {
471 *v = f64::NAN;
472 }
473
474 Ok(())
475}
476
477#[inline]
478pub fn correlation_cycle_scalar(
479 data: &[f64],
480 period: usize,
481 threshold: f64,
482 first: usize,
483 real: &mut [f64],
484 imag: &mut [f64],
485 angle: &mut [f64],
486 state: &mut [f64],
487) {
488 unsafe {
489 correlation_cycle_compute_into(data, period, threshold, first, real, imag, angle, state)
490 }
491}
492
/// Scalar reference kernel of the correlation-cycle indicator.
///
/// For every index `i` in `first + period .. data.len()` the `period` samples
/// preceding `i` (NaN samples count as `0.0`) are correlated against a cosine
/// and a negated-sine reference of one full cycle. The two Pearson
/// correlations are stored in `real[i]` / `imag[i]`, the derived phase angle
/// (degrees) in `angle[i]`, and — from the second computed index on —
/// `state[i]` receives `1.0` / `-1.0` (by sign of the angle) when the angle
/// moved by less than `threshold` since the previous bar, else `0.0`.
/// Indices outside that range are not written.
///
/// # Safety
/// `real`, `imag`, `angle` and `state` must each hold at least `data.len()`
/// elements; writes are performed without bounds checks.
#[inline(always)]
unsafe fn correlation_cycle_compute_into(
    data: &[f64],
    period: usize,
    threshold: f64,
    first: usize,
    real: &mut [f64],
    imag: &mut [f64],
    angle: &mut [f64],
    state: &mut [f64],
) {
    let half_pi = f64::asin(1.0);
    let two_pi = 4.0 * f64::asin(1.0);

    let n = period as f64;
    let w = two_pi / n;

    // Largest multiple of four not exceeding `period`: the part of the window
    // handled four samples at a time.
    let unrolled = period - period % 4;

    let mut cos_table = vec![0.0f64; period];
    let mut sin_table = vec![0.0f64; period];

    let mut sum_cos = 0.0f64;
    let mut sum_sin = 0.0f64;
    let mut sum_cos2 = 0.0f64;
    let mut sum_sin2 = 0.0f64;

    // Reference tables. Within a group of four the angle is advanced by
    // repeated addition of `w`; the remainder recomputes it by multiplication.
    let mut base = 0usize;
    while base < unrolled {
        let mut theta = w * ((base as f64) + 1.0);
        for lane in 0..4 {
            let (sine, cosine) = theta.sin_cos();
            let neg_sine = -sine;
            cos_table[base + lane] = cosine;
            sin_table[base + lane] = neg_sine;
            sum_cos += cosine;
            sum_sin += neg_sine;
            sum_cos2 += cosine * cosine;
            sum_sin2 += neg_sine * neg_sine;
            theta += w;
        }
        base += 4;
    }
    for j in unrolled..period {
        let (sine, cosine) = (w * ((j as f64) + 1.0)).sin_cos();
        let neg_sine = -sine;
        cos_table[j] = cosine;
        sin_table[j] = neg_sine;
        sum_cos += cosine;
        sum_sin += neg_sine;
        sum_cos2 += cosine * cosine;
        sum_sin2 += neg_sine * neg_sine;
    }

    // Variance terms of the two references; constant for the whole run.
    let t2_const = n.mul_add(sum_cos2, -(sum_cos * sum_cos));
    let t4_const = n.mul_add(sum_sin2, -(sum_sin * sum_sin));
    let has_t2 = t2_const > 0.0;
    let has_t4 = t4_const > 0.0;
    let sqrt_t2c = if has_t2 { t2_const.sqrt() } else { 0.0 };
    let sqrt_t4c = if has_t4 { t4_const.sqrt() } else { 0.0 };

    let start_ria = first + period;
    let start_s = start_ria + 1;

    let src = data.as_ptr();
    let cos_ptr = cos_table.as_ptr();
    let sin_ptr = sin_table.as_ptr();

    // NaN samples contribute zero to every sum.
    let scrub = |v: f64| if v.is_nan() { 0.0 } else { v };

    let mut last_angle = f64::NAN;

    for i in start_ria..data.len() {
        let mut sx = 0.0f64;
        let mut sxx = 0.0f64;
        let mut sxc = 0.0f64;
        let mut sxs = 0.0f64;

        let mut j = 0usize;
        while j < unrolled {
            // Table slot `j` pairs with the sample `j + 1` positions before `i`.
            let newest = src.add(i - (j + 1));
            let x0 = scrub(*newest);
            let x1 = scrub(*newest.sub(1));
            let x2 = scrub(*newest.sub(2));
            let x3 = scrub(*newest.sub(3));

            let cp = cos_ptr.add(j);
            let sp = sin_ptr.add(j);
            let (c0, c1, c2, c3) = (*cp, *cp.add(1), *cp.add(2), *cp.add(3));
            let (s0, s1, s2, s3) = (*sp, *sp.add(1), *sp.add(2), *sp.add(3));

            sx += x0 + x1 + x2 + x3;

            // Fused accumulation, oldest sample of the group first.
            sxx = x3.mul_add(x3, sxx);
            sxx = x2.mul_add(x2, sxx);
            sxx = x1.mul_add(x1, sxx);
            sxx = x0.mul_add(x0, sxx);

            sxc = x3.mul_add(c3, sxc);
            sxc = x2.mul_add(c2, sxc);
            sxc = x1.mul_add(c1, sxc);
            sxc = x0.mul_add(c0, sxc);

            sxs = x3.mul_add(s3, sxs);
            sxs = x2.mul_add(s2, sxs);
            sxs = x1.mul_add(s1, sxs);
            sxs = x0.mul_add(s0, sxs);

            j += 4;
        }
        while j < period {
            let x = scrub(*src.add(i - (j + 1)));
            let c = *cos_ptr.add(j);
            let s = *sin_ptr.add(j);
            sx += x;
            sxx = x.mul_add(x, sxx);
            sxc = x.mul_add(c, sxc);
            sxs = x.mul_add(s, sxs);
            j += 1;
        }

        // Variance term of the data window; correlations stay zero when it
        // (or the matching reference term) is not strictly positive.
        let t1 = n.mul_add(sxx, -(sx * sx));
        let mut re = 0.0;
        let mut im = 0.0;
        if t1 > 0.0 {
            let root_t1 = t1.sqrt();
            if has_t2 {
                let denom = root_t1 * sqrt_t2c;
                if denom > 0.0 {
                    re = n.mul_add(sxc, -(sx * sum_cos)) / denom;
                }
            }
            if has_t4 {
                let denom = root_t1 * sqrt_t4c;
                if denom > 0.0 {
                    im = n.mul_add(sxs, -(sx * sum_sin)) / denom;
                }
            }
        }

        *real.get_unchecked_mut(i) = re;
        *imag.get_unchecked_mut(i) = im;

        let current_angle = if im == 0.0 {
            0.0
        } else {
            let degrees = ((re / im).atan() + half_pi).to_degrees();
            if im > 0.0 {
                degrees - 180.0
            } else {
                degrees
            }
        };
        *angle.get_unchecked_mut(i) = current_angle;

        // The state needs a previous angle, so it starts one index later.
        if i >= start_s {
            let steady =
                !last_angle.is_nan() && (current_angle - last_angle).abs() < threshold;
            let flag = match (steady, current_angle >= 0.0) {
                (false, _) => 0.0,
                (true, true) => 1.0,
                (true, false) => -1.0,
            };
            *state.get_unchecked_mut(i) = flag;
        }

        last_angle = current_angle;
    }
}
705
/// AVX2/FMA kernel of the correlation-cycle indicator.
///
/// Produces the same outputs as the scalar kernel for indices
/// `first + period .. data.len()`, accumulating four window samples per
/// vector step and finishing the remainder in scalar code. NaN samples are
/// treated as `0.0`.
///
/// # Safety
/// * `real`, `imag`, `angle` and `state` must each hold at least
///   `data.len()` elements; writes are performed without bounds checks.
/// * The function executes AVX and FMA instructions; the caller must ensure
///   the running CPU supports them.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn correlation_cycle_avx2(
    data: &[f64],
    period: usize,
    threshold: f64,
    first: usize,
    real: &mut [f64],
    imag: &mut [f64],
    angle: &mut [f64],
    state: &mut [f64],
) {
    use core::arch::x86_64::*;

    /// Horizontal sum of the four lanes: (upper half + lower half), then the
    /// two remaining lanes.
    #[inline(always)]
    fn hsum256(v: __m256d) -> f64 {
        unsafe {
            let upper = _mm256_extractf128_pd(v, 1);
            let lower = _mm256_castpd256_pd128(v);
            let pair = _mm_add_pd(upper, lower);
            let high_lane = _mm_unpackhi_pd(pair, pair);
            _mm_cvtsd_f64(_mm_add_sd(pair, high_lane))
        }
    }

    let half_pi = f64::asin(1.0);
    let two_pi = 4.0 * f64::asin(1.0);
    let n = period as f64;
    let w = two_pi / n;

    let mut cos_table = vec![0.0f64; period];
    let mut sin_table = vec![0.0f64; period];

    let mut sum_cos = 0.0f64;
    let mut sum_sin = 0.0f64;
    let mut sum_cos2 = 0.0f64;
    let mut sum_sin2 = 0.0f64;

    // Reference tables: slot `j` holds cos / -sin of `w * (j + 1)`.
    for (j, (cos_slot, sin_slot)) in cos_table
        .iter_mut()
        .zip(sin_table.iter_mut())
        .enumerate()
    {
        let (sine, cosine) = (w * ((j as f64) + 1.0)).sin_cos();
        let neg_sine = -sine;
        *cos_slot = cosine;
        *sin_slot = neg_sine;
        sum_cos += cosine;
        sum_sin += neg_sine;
        sum_cos2 += cosine * cosine;
        sum_sin2 += neg_sine * neg_sine;
    }

    // Variance terms of the two references; constant for the whole run.
    let t2_const = n.mul_add(sum_cos2, -(sum_cos * sum_cos));
    let t4_const = n.mul_add(sum_sin2, -(sum_sin * sum_sin));
    let has_t2 = t2_const > 0.0;
    let has_t4 = t4_const > 0.0;
    let sqrt_t2c = if has_t2 { t2_const.sqrt() } else { 0.0 };
    let sqrt_t4c = if has_t4 { t4_const.sqrt() } else { 0.0 };

    let start_ria = first + period;
    let start_s = start_ria + 1;

    // Part of the window processed four samples per vector step.
    let vector_end = period - period % 4;

    let src = data.as_ptr();
    let cos_ptr = cos_table.as_ptr();
    let sin_ptr = sin_table.as_ptr();

    let mut last_angle = f64::NAN;

    for i in start_ria..data.len() {
        let mut acc_x = _mm256_setzero_pd();
        let mut acc_xx = _mm256_setzero_pd();
        let mut acc_xc = _mm256_setzero_pd();
        let mut acc_xs = _mm256_setzero_pd();

        let mut j = 0usize;
        while j < vector_end {
            // Lane k holds the sample `j + 1 + k` positions before `i`, so it
            // lines up with table slot `j + k`.
            let newest = src.add(i - (j + 1));
            let samples = _mm256_setr_pd(
                *newest,
                *newest.sub(1),
                *newest.sub(2),
                *newest.sub(3),
            );

            // Zero out NaN lanes: the ordered self-compare is all-ones only
            // for non-NaN values.
            let not_nan = _mm256_cmp_pd(samples, samples, _CMP_ORD_Q);
            let samples = _mm256_and_pd(samples, not_nan);

            let cos_ref = _mm256_loadu_pd(cos_ptr.add(j));
            let sin_ref = _mm256_loadu_pd(sin_ptr.add(j));

            acc_x = _mm256_add_pd(acc_x, samples);
            acc_xx = _mm256_fmadd_pd(samples, samples, acc_xx);
            acc_xc = _mm256_fmadd_pd(samples, cos_ref, acc_xc);
            acc_xs = _mm256_fmadd_pd(samples, sin_ref, acc_xs);

            j += 4;
        }

        let mut sx = hsum256(acc_x);
        let mut sxx = hsum256(acc_xx);
        let mut sxc = hsum256(acc_xc);
        let mut sxs = hsum256(acc_xs);

        // Scalar remainder of the window.
        while j < period {
            let raw = *src.add(i - (j + 1));
            let x = if raw.is_nan() { 0.0 } else { raw };
            let c = *cos_ptr.add(j);
            let s = *sin_ptr.add(j);
            sx += x;
            sxx = x.mul_add(x, sxx);
            sxc = x.mul_add(c, sxc);
            sxs = x.mul_add(s, sxs);
            j += 1;
        }

        let t1 = n.mul_add(sxx, -(sx * sx));
        let mut re = 0.0;
        let mut im = 0.0;
        if t1 > 0.0 {
            let root_t1 = t1.sqrt();
            if has_t2 {
                let denom = root_t1 * sqrt_t2c;
                if denom > 0.0 {
                    re = n.mul_add(sxc, -(sx * sum_cos)) / denom;
                }
            }
            if has_t4 {
                let denom = root_t1 * sqrt_t4c;
                if denom > 0.0 {
                    im = n.mul_add(sxs, -(sx * sum_sin)) / denom;
                }
            }
        }

        *real.get_unchecked_mut(i) = re;
        *imag.get_unchecked_mut(i) = im;

        let current_angle = if im == 0.0 {
            0.0
        } else {
            let degrees = ((re / im).atan() + half_pi).to_degrees();
            if im > 0.0 {
                degrees - 180.0
            } else {
                degrees
            }
        };
        *angle.get_unchecked_mut(i) = current_angle;

        // The state needs a previous angle, so it starts one index later.
        if i >= start_s {
            let steady =
                !last_angle.is_nan() && (current_angle - last_angle).abs() < threshold;
            let flag = match (steady, current_angle >= 0.0) {
                (false, _) => 0.0,
                (true, true) => 1.0,
                (true, false) => -1.0,
            };
            *state.get_unchecked_mut(i) = flag;
        }

        last_angle = current_angle;
    }
}
871
/// AVX-512 kernel of the correlation-cycle indicator.
///
/// Produces the same outputs as the scalar kernel for indices
/// `first + period .. data.len()`, accumulating eight window samples per
/// vector step and finishing the remainder in scalar code. NaN samples are
/// treated as `0.0`.
///
/// # Safety
/// * `real`, `imag`, `angle` and `state` must each hold at least
///   `data.len()` elements; writes are performed without bounds checks.
/// * The function executes AVX-512F instructions; the caller must ensure the
///   running CPU supports them.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn correlation_cycle_avx512(
    data: &[f64],
    period: usize,
    threshold: f64,
    first: usize,
    real: &mut [f64],
    imag: &mut [f64],
    angle: &mut [f64],
    state: &mut [f64],
) {
    use core::arch::x86_64::*;

    /// Horizontal sum of the eight lanes: fold 512 → 256 → 128 bits, then the
    /// two remaining lanes.
    #[inline(always)]
    fn hsum512(v: __m512d) -> f64 {
        unsafe {
            let lower256 = _mm512_castpd512_pd256(v);
            let upper256 = _mm512_extractf64x4_pd(v, 1);
            let folded256 = _mm256_add_pd(lower256, upper256);
            let upper128 = _mm256_extractf128_pd(folded256, 1);
            let lower128 = _mm256_castpd256_pd128(folded256);
            let pair = _mm_add_pd(upper128, lower128);
            let high_lane = _mm_unpackhi_pd(pair, pair);
            _mm_cvtsd_f64(_mm_add_sd(pair, high_lane))
        }
    }

    let half_pi = f64::asin(1.0);
    let two_pi = 4.0 * f64::asin(1.0);
    let n = period as f64;
    let w = two_pi / n;

    let mut cos_table = vec![0.0f64; period];
    let mut sin_table = vec![0.0f64; period];

    let mut sum_cos = 0.0f64;
    let mut sum_sin = 0.0f64;
    let mut sum_cos2 = 0.0f64;
    let mut sum_sin2 = 0.0f64;

    // Reference tables: slot `j` holds cos / -sin of `w * (j + 1)`.
    for (j, (cos_slot, sin_slot)) in cos_table
        .iter_mut()
        .zip(sin_table.iter_mut())
        .enumerate()
    {
        let (sine, cosine) = (w * ((j as f64) + 1.0)).sin_cos();
        let neg_sine = -sine;
        *cos_slot = cosine;
        *sin_slot = neg_sine;
        sum_cos += cosine;
        sum_sin += neg_sine;
        sum_cos2 += cosine * cosine;
        sum_sin2 += neg_sine * neg_sine;
    }

    // Variance terms of the two references; constant for the whole run.
    let t2_const = n.mul_add(sum_cos2, -(sum_cos * sum_cos));
    let t4_const = n.mul_add(sum_sin2, -(sum_sin * sum_sin));
    let has_t2 = t2_const > 0.0;
    let has_t4 = t4_const > 0.0;
    let sqrt_t2c = if has_t2 { t2_const.sqrt() } else { 0.0 };
    let sqrt_t4c = if has_t4 { t4_const.sqrt() } else { 0.0 };

    let start_ria = first + period;
    let start_s = start_ria + 1;

    // Part of the window processed eight samples per vector step.
    let vector_end = period - period % 8;

    let src = data.as_ptr();
    let cos_ptr = cos_table.as_ptr();
    let sin_ptr = sin_table.as_ptr();

    let mut last_angle = f64::NAN;

    for i in start_ria..data.len() {
        let mut acc_x = _mm512_setzero_pd();
        let mut acc_xx = _mm512_setzero_pd();
        let mut acc_xc = _mm512_setzero_pd();
        let mut acc_xs = _mm512_setzero_pd();

        let mut j = 0usize;
        while j < vector_end {
            // Lane k holds the sample `j + 1 + k` positions before `i`, so it
            // lines up with table slot `j + k`.
            let newest = src.add(i - (j + 1));
            let samples = _mm512_setr_pd(
                *newest,
                *newest.sub(1),
                *newest.sub(2),
                *newest.sub(3),
                *newest.sub(4),
                *newest.sub(5),
                *newest.sub(6),
                *newest.sub(7),
            );

            // Keep only lanes that compare ordered with themselves (non-NaN);
            // the rest become zero.
            let not_nan = _mm512_cmp_pd_mask(samples, samples, _CMP_ORD_Q);
            let samples = _mm512_maskz_mov_pd(not_nan, samples);

            let cos_ref = _mm512_loadu_pd(cos_ptr.add(j));
            let sin_ref = _mm512_loadu_pd(sin_ptr.add(j));

            acc_x = _mm512_add_pd(acc_x, samples);
            acc_xx = _mm512_fmadd_pd(samples, samples, acc_xx);
            acc_xc = _mm512_fmadd_pd(samples, cos_ref, acc_xc);
            acc_xs = _mm512_fmadd_pd(samples, sin_ref, acc_xs);

            j += 8;
        }

        let mut sx = hsum512(acc_x);
        let mut sxx = hsum512(acc_xx);
        let mut sxc = hsum512(acc_xc);
        let mut sxs = hsum512(acc_xs);

        // Scalar remainder of the window.
        while j < period {
            let raw = *src.add(i - (j + 1));
            let x = if raw.is_nan() { 0.0 } else { raw };
            let c = *cos_ptr.add(j);
            let s = *sin_ptr.add(j);
            sx += x;
            sxx = x.mul_add(x, sxx);
            sxc = x.mul_add(c, sxc);
            sxs = x.mul_add(s, sxs);
            j += 1;
        }

        let t1 = n.mul_add(sxx, -(sx * sx));
        let mut re = 0.0;
        let mut im = 0.0;
        if t1 > 0.0 {
            let root_t1 = t1.sqrt();
            if has_t2 {
                let denom = root_t1 * sqrt_t2c;
                if denom > 0.0 {
                    re = n.mul_add(sxc, -(sx * sum_cos)) / denom;
                }
            }
            if has_t4 {
                let denom = root_t1 * sqrt_t4c;
                if denom > 0.0 {
                    im = n.mul_add(sxs, -(sx * sum_sin)) / denom;
                }
            }
        }

        *real.get_unchecked_mut(i) = re;
        *imag.get_unchecked_mut(i) = im;

        let current_angle = if im == 0.0 {
            0.0
        } else {
            let degrees = ((re / im).atan() + half_pi).to_degrees();
            if im > 0.0 {
                degrees - 180.0
            } else {
                degrees
            }
        };
        *angle.get_unchecked_mut(i) = current_angle;

        // The state needs a previous angle, so it starts one index later.
        if i >= start_s {
            let steady =
                !last_angle.is_nan() && (current_angle - last_angle).abs() < threshold;
            let flag = match (steady, current_angle >= 0.0) {
                (false, _) => 0.0,
                (true, true) => 1.0,
                (true, false) => -1.0,
            };
            *state.get_unchecked_mut(i) = flag;
        }

        last_angle = current_angle;
    }
}
1045
/// "Short period" AVX-512 entry point.
///
/// Currently has no dedicated implementation: it forwards every argument
/// unchanged to the scalar kernel.
///
/// # Safety
/// Same contract as the scalar kernel: each output slice must hold at least
/// `data.len()` elements.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn correlation_cycle_avx512_short(
    data: &[f64],
    period: usize,
    threshold: f64,
    first: usize,
    real: &mut [f64],
    imag: &mut [f64],
    angle: &mut [f64],
    state: &mut [f64],
) {
    // SAFETY: the caller upholds the scalar kernel's contract (see above).
    unsafe {
        correlation_cycle_compute_into(data, period, threshold, first, real, imag, angle, state)
    }
}
1060
/// "Long period" AVX-512 entry point.
///
/// Currently has no dedicated implementation: it forwards every argument
/// unchanged to the scalar kernel.
///
/// # Safety
/// Same contract as the scalar kernel: each output slice must hold at least
/// `data.len()` elements.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn correlation_cycle_avx512_long(
    data: &[f64],
    period: usize,
    threshold: f64,
    first: usize,
    real: &mut [f64],
    imag: &mut [f64],
    angle: &mut [f64],
    state: &mut [f64],
) {
    // SAFETY: the caller upholds the scalar kernel's contract (see above).
    unsafe {
        correlation_cycle_compute_into(data, period, threshold, first, real, imag, angle, state)
    }
}
1075
1076#[inline(always)]
1077pub unsafe fn correlation_cycle_row_scalar_with_first(
1078 data: &[f64],
1079 period: usize,
1080 threshold: f64,
1081 first: usize,
1082 real: &mut [f64],
1083 imag: &mut [f64],
1084 angle: &mut [f64],
1085 state: &mut [f64],
1086) {
1087 correlation_cycle_compute_into(data, period, threshold, first, real, imag, angle, state)
1088}
1089
/// Row-level AVX2 entry point taking a precomputed `first` index.
///
/// Forwards every argument unchanged to [`correlation_cycle_avx2`].
///
/// # Safety
/// Same contract as [`correlation_cycle_avx2`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn correlation_cycle_row_avx2_with_first(
    data: &[f64],
    period: usize,
    threshold: f64,
    first: usize,
    real: &mut [f64],
    imag: &mut [f64],
    angle: &mut [f64],
    state: &mut [f64],
) {
    // SAFETY: the caller upholds the AVX2 kernel's contract (see above).
    unsafe { correlation_cycle_avx2(data, period, threshold, first, real, imag, angle, state) }
}
1104
/// Row-level AVX-512 entry point taking a precomputed `first` index.
///
/// Forwards every argument unchanged to [`correlation_cycle_avx512`].
///
/// # Safety
/// Same contract as [`correlation_cycle_avx512`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn correlation_cycle_row_avx512_with_first(
    data: &[f64],
    period: usize,
    threshold: f64,
    first: usize,
    real: &mut [f64],
    imag: &mut [f64],
    angle: &mut [f64],
    state: &mut [f64],
) {
    // SAFETY: the caller upholds the AVX-512 kernel's contract (see above).
    unsafe { correlation_cycle_avx512(data, period, threshold, first, real, imag, angle, state) }
}
1119
/// Incremental (one value at a time) correlation-cycle calculator.
#[derive(Clone, Debug)]
pub struct CorrelationCycleStream {
    /// Window length in samples.
    period: usize,
    /// Angle-change threshold in degrees.
    threshold: f64,

    /// Ring buffer of the last `period` samples (NaN inputs stored as `0.0`).
    buffer: Vec<f64>,
    /// Next write position in `buffer`.
    head: usize,
    /// Set once `head` has wrapped around for the first time.
    filled: bool,

    /// Starts as `None`.
    // NOTE(review): presumably caches the most recent output tuple — confirm
    // against the remainder of `update`.
    last: Option<(f64, f64, f64, f64)>,

    /// Running sum of the samples, updated incrementally on every push.
    sum_x: f64,
    /// Running sum of the squared samples, updated incrementally on every push.
    sum_x2: f64,

    /// Real part of the recursively rotated phasor accumulator.
    phasor_re: f64,
    /// Imaginary part of the recursively rotated phasor accumulator.
    phasor_im: f64,

    /// Angle of the previous output; NaN until one exists.
    prev_angle: f64,

    /// `period` as `f64`.
    n: f64,
    /// π/2, computed once as `asin(1.0)`.
    half_pi: f64,
    /// `cos(2π / period)`: real part of the per-step rotation.
    z_re: f64,
    /// `-sin(2π / period)`: imaginary part of the per-step rotation.
    z_im: f64,
    /// Sum of the cosine reference over one window.
    sum_cos: f64,
    /// Sum of the negated-sine reference over one window.
    sum_sin: f64,
    /// Square root of the cosine-reference variance term (`0.0` if not positive).
    sqrt_t2c: f64,
    /// Square root of the sine-reference variance term (`0.0` if not positive).
    sqrt_t4c: f64,
    /// Whether the cosine-reference variance term is strictly positive.
    has_t2: bool,
    /// Whether the sine-reference variance term is strictly positive.
    has_t4: bool,
}
1150
1151impl CorrelationCycleStream {
1152 pub fn try_new(params: CorrelationCycleParams) -> Result<Self, CorrelationCycleError> {
1153 let period = params.period.unwrap_or(20);
1154 if period == 0 {
1155 return Err(CorrelationCycleError::InvalidPeriod {
1156 period,
1157 data_len: 0,
1158 });
1159 }
1160 let threshold = params.threshold.unwrap_or(9.0);
1161
1162 let half_pi = f64::asin(1.0);
1163 let two_pi = 4.0 * half_pi;
1164 let n = period as f64;
1165 let w = two_pi / n;
1166
1167 let (s_w, c_w) = w.sin_cos();
1168 let z_re = c_w;
1169 let z_im = -s_w;
1170
1171 let mut sum_cos = 0.0f64;
1172 let mut sum_sin = 0.0f64;
1173 let mut sum_cos2 = 0.0f64;
1174 let mut sum_sin2 = 0.0f64;
1175
1176 let mut j = 0usize;
1177 while j + 4 <= period {
1178 let a0 = w * ((j as f64) + 1.0);
1179 let (s0, c0) = a0.sin_cos();
1180 let ys0 = -s0;
1181
1182 let a1 = a0 + w;
1183 let (s1, c1) = a1.sin_cos();
1184 let ys1 = -s1;
1185
1186 let a2 = a1 + w;
1187 let (s2, c2) = a2.sin_cos();
1188 let ys2 = -s2;
1189
1190 let a3 = a2 + w;
1191 let (s3, c3) = a3.sin_cos();
1192 let ys3 = -s3;
1193
1194 sum_cos += c0 + c1 + c2 + c3;
1195 sum_sin += ys0 + ys1 + ys2 + ys3;
1196 sum_cos2 += c0 * c0 + c1 * c1 + c2 * c2 + c3 * c3;
1197 sum_sin2 += ys0 * ys0 + ys1 * ys1 + ys2 * ys2 + ys3 * ys3;
1198
1199 j += 4;
1200 }
1201 while j < period {
1202 let a = w * ((j as f64) + 1.0);
1203 let (s, c) = a.sin_cos();
1204 let ys = -s;
1205 sum_cos += c;
1206 sum_sin += ys;
1207 sum_cos2 += c * c;
1208 sum_sin2 += ys * ys;
1209 j += 1;
1210 }
1211
1212 let t2_const = n.mul_add(sum_cos2, -(sum_cos * sum_cos));
1213 let t4_const = n.mul_add(sum_sin2, -(sum_sin * sum_sin));
1214 let has_t2 = t2_const > 0.0;
1215 let has_t4 = t4_const > 0.0;
1216 let sqrt_t2c = if has_t2 { t2_const.sqrt() } else { 0.0 };
1217 let sqrt_t4c = if has_t4 { t4_const.sqrt() } else { 0.0 };
1218
1219 Ok(Self {
1220 period,
1221 threshold,
1222 buffer: vec![0.0; period],
1223 head: 0,
1224 filled: false,
1225 last: None,
1226
1227 sum_x: 0.0,
1228 sum_x2: 0.0,
1229 phasor_re: 0.0,
1230 phasor_im: 0.0,
1231 prev_angle: f64::NAN,
1232
1233 n,
1234 half_pi,
1235 z_re,
1236 z_im,
1237 sum_cos,
1238 sum_sin,
1239 sqrt_t2c,
1240 sqrt_t4c,
1241 has_t2,
1242 has_t4,
1243 })
1244 }
1245
1246 #[inline(always)]
1247 pub fn update(&mut self, value: f64) -> Option<(f64, f64, f64, f64)> {
1248 let x_new = if value.is_nan() { 0.0 } else { value };
1249 let x_old = self.buffer[self.head];
1250 self.buffer[self.head] = x_new;
1251 self.head = (self.head + 1) % self.period;
1252
1253 self.sum_x += x_new - x_old;
1254
1255 self.sum_x2 = (x_new * x_new) - (x_old * x_old) + self.sum_x2;
1256
1257 let dx = x_new - x_old;
1258 let s = self.phasor_re + dx;
1259
1260 let new_re = self.z_re.mul_add(s, -self.z_im * self.phasor_im);
1261 let new_im = self.z_im.mul_add(s, self.z_re * self.phasor_im);
1262
1263 self.phasor_re = new_re;
1264 self.phasor_im = new_im;
1265
1266 let first_wrap_now = !self.filled && self.head == 0;
1267 if first_wrap_now {
1268 self.filled = true;
1269 } else if !self.filled {
1270 return None;
1271 }
1272
1273 let mut sum_x_exact = 0.0f64;
1274 let mut sum_x2_exact = 0.0f64;
1275 let mut k = 0usize;
1276 while k + 4 <= self.period {
1277 let idx0 = (self.head + k) % self.period;
1278 let idx1 = (self.head + k + 1) % self.period;
1279 let idx2 = (self.head + k + 2) % self.period;
1280 let idx3 = (self.head + k + 3) % self.period;
1281 let x0 = self.buffer[idx0];
1282 let x1 = self.buffer[idx1];
1283 let x2 = self.buffer[idx2];
1284 let x3 = self.buffer[idx3];
1285 sum_x_exact += x0 + x1 + x2 + x3;
1286 sum_x2_exact = x0.mul_add(x0, sum_x2_exact);
1287 sum_x2_exact = x1.mul_add(x1, sum_x2_exact);
1288 sum_x2_exact = x2.mul_add(x2, sum_x2_exact);
1289 sum_x2_exact = x3.mul_add(x3, sum_x2_exact);
1290 k += 4;
1291 }
1292 while k < self.period {
1293 let idx = (self.head + k) % self.period;
1294 let x = self.buffer[idx];
1295 sum_x_exact += x;
1296 sum_x2_exact = x.mul_add(x, sum_x2_exact);
1297 k += 1;
1298 }
1299
1300 let t1 = self.n.mul_add(sum_x2_exact, -(sum_x_exact * sum_x_exact));
1301
1302 let mut r_val = 0.0;
1303 let mut i_val = 0.0;
1304
1305 if t1 > 0.0 {
1306 let sqrt_t1 = t1.sqrt();
1307 if self.has_t2 {
1308 let denom_r = sqrt_t1 * self.sqrt_t2c;
1309 if denom_r > 0.0 {
1310 r_val = (self
1311 .n
1312 .mul_add(self.phasor_re, -(sum_x_exact * self.sum_cos)))
1313 / denom_r;
1314 }
1315 }
1316 if self.has_t4 {
1317 let denom_i = sqrt_t1 * self.sqrt_t4c;
1318 if denom_i > 0.0 {
1319 i_val = (self
1320 .n
1321 .mul_add(self.phasor_im, -(sum_x_exact * self.sum_sin)))
1322 / denom_i;
1323 }
1324 }
1325 }
1326
1327 let mut ang = if i_val == 0.0 {
1328 0.0
1329 } else {
1330 let mut a = (r_val / i_val).atan() + self.half_pi;
1331 a = a.to_degrees();
1332 if i_val > 0.0 {
1333 a -= 180.0;
1334 }
1335 a
1336 };
1337
1338 let st = if self.prev_angle.is_finite() && (ang - self.prev_angle).abs() < self.threshold {
1339 if ang >= 0.0 {
1340 1.0
1341 } else {
1342 -1.0
1343 }
1344 } else if self.prev_angle.is_finite() {
1345 0.0
1346 } else {
1347 f64::NAN
1348 };
1349
1350 self.prev_angle = ang;
1351
1352 let to_emit = self.last.take();
1353 self.last = Some((r_val, i_val, ang, st));
1354
1355 if first_wrap_now {
1356 None
1357 } else {
1358 to_emit
1359 }
1360 }
1361}
1362
/// Parameter sweep description for the batch API.
///
/// Each axis is a `(start, end, step)` triple.
#[derive(Clone, Debug)]
pub struct CorrelationCycleBatchRange {
    /// `(start, end, step)` for the period axis.
    pub period: (usize, usize, usize),
    /// `(start, end, step)` for the threshold axis.
    pub threshold: (f64, f64, f64),
}

impl Default for CorrelationCycleBatchRange {
    /// Periods 20 through 269 in steps of 1, with the threshold fixed at 9.0.
    fn default() -> Self {
        let period = (20, 269, 1);
        let threshold = (9.0, 9.0, 0.0);
        CorrelationCycleBatchRange { period, threshold }
    }
}
1377
1378#[derive(Clone, Debug, Default)]
1379pub struct CorrelationCycleBatchBuilder {
1380 range: CorrelationCycleBatchRange,
1381 kernel: Kernel,
1382}
1383
1384impl CorrelationCycleBatchBuilder {
1385 pub fn new() -> Self {
1386 Self::default()
1387 }
1388 pub fn kernel(mut self, k: Kernel) -> Self {
1389 self.kernel = k;
1390 self
1391 }
1392
1393 #[inline]
1394 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
1395 self.range.period = (start, end, step);
1396 self
1397 }
1398 #[inline]
1399 pub fn period_static(mut self, p: usize) -> Self {
1400 self.range.period = (p, p, 0);
1401 self
1402 }
1403
1404 #[inline]
1405 pub fn threshold_range(mut self, start: f64, end: f64, step: f64) -> Self {
1406 self.range.threshold = (start, end, step);
1407 self
1408 }
1409 #[inline]
1410 pub fn threshold_static(mut self, x: f64) -> Self {
1411 self.range.threshold = (x, x, 0.0);
1412 self
1413 }
1414
1415 pub fn apply_slice(
1416 self,
1417 data: &[f64],
1418 ) -> Result<CorrelationCycleBatchOutput, CorrelationCycleError> {
1419 correlation_cycle_batch_with_kernel(data, &self.range, self.kernel)
1420 }
1421
1422 pub fn with_default_slice(
1423 data: &[f64],
1424 k: Kernel,
1425 ) -> Result<CorrelationCycleBatchOutput, CorrelationCycleError> {
1426 CorrelationCycleBatchBuilder::new()
1427 .kernel(k)
1428 .apply_slice(data)
1429 }
1430
1431 pub fn apply_candles(
1432 self,
1433 c: &Candles,
1434 src: &str,
1435 ) -> Result<CorrelationCycleBatchOutput, CorrelationCycleError> {
1436 let slice = source_type(c, src);
1437 self.apply_slice(slice)
1438 }
1439
1440 pub fn with_default_candles(
1441 c: &Candles,
1442 ) -> Result<CorrelationCycleBatchOutput, CorrelationCycleError> {
1443 CorrelationCycleBatchBuilder::new()
1444 .kernel(Kernel::Auto)
1445 .apply_candles(c, "close")
1446 }
1447}
1448
1449#[inline(always)]
1450pub fn correlation_cycle_batch_with_kernel(
1451 data: &[f64],
1452 sweep: &CorrelationCycleBatchRange,
1453 k: Kernel,
1454) -> Result<CorrelationCycleBatchOutput, CorrelationCycleError> {
1455 let kernel = match k {
1456 Kernel::Auto => detect_best_batch_kernel(),
1457 other if other.is_batch() => other,
1458 _ => return Err(CorrelationCycleError::InvalidKernelForBatch(k)),
1459 };
1460
1461 let simd = match kernel {
1462 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1463 Kernel::Avx512Batch => Kernel::Avx512,
1464 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1465 Kernel::Avx2Batch => Kernel::Avx2,
1466 Kernel::ScalarBatch => Kernel::Scalar,
1467 _ => Kernel::Scalar,
1468 };
1469 correlation_cycle_batch_par_slice(data, sweep, simd)
1470}
1471
1472#[derive(Clone, Debug)]
1473pub struct CorrelationCycleBatchOutput {
1474 pub real: Vec<f64>,
1475 pub imag: Vec<f64>,
1476 pub angle: Vec<f64>,
1477 pub state: Vec<f64>,
1478 pub combos: Vec<CorrelationCycleParams>,
1479 pub rows: usize,
1480 pub cols: usize,
1481}
1482
1483impl CorrelationCycleBatchOutput {
1484 pub fn row_for_params(&self, p: &CorrelationCycleParams) -> Option<usize> {
1485 self.combos.iter().position(|c| {
1486 c.period.unwrap_or(20) == p.period.unwrap_or(20)
1487 && (c.threshold.unwrap_or(9.0) - p.threshold.unwrap_or(9.0)).abs() < 1e-12
1488 })
1489 }
1490 pub fn values_for(
1491 &self,
1492 p: &CorrelationCycleParams,
1493 ) -> Option<(&[f64], &[f64], &[f64], &[f64])> {
1494 self.row_for_params(p).map(|row| {
1495 let start = row * self.cols;
1496 (
1497 &self.real[start..start + self.cols],
1498 &self.imag[start..start + self.cols],
1499 &self.angle[start..start + self.cols],
1500 &self.state[start..start + self.cols],
1501 )
1502 })
1503 }
1504}
1505
1506#[inline(always)]
1507fn expand_grid(
1508 r: &CorrelationCycleBatchRange,
1509) -> Result<Vec<CorrelationCycleParams>, CorrelationCycleError> {
1510 fn axis_usize(
1511 (start, end, step): (usize, usize, usize),
1512 ) -> Result<Vec<usize>, CorrelationCycleError> {
1513 if step == 0 || start == end {
1514 return Ok(vec![start]);
1515 }
1516 let mut vals = Vec::new();
1517 if start < end {
1518 let mut v = start;
1519 while v <= end {
1520 vals.push(v);
1521 v = match v.checked_add(step) {
1522 Some(n) => n,
1523 None => break,
1524 };
1525 }
1526 } else {
1527 let mut v = start;
1528 while v >= end {
1529 vals.push(v);
1530 v = match v.checked_sub(step) {
1531 Some(n) => n,
1532 None => break,
1533 };
1534 }
1535 }
1536 if vals.is_empty() {
1537 return Err(CorrelationCycleError::InvalidRange { start, end, step });
1538 }
1539 Ok(vals)
1540 }
1541 fn axis_f64((start, end, step): (f64, f64, f64)) -> Result<Vec<f64>, CorrelationCycleError> {
1542 if step.abs() < 1e-12 || (start - end).abs() < 1e-12 {
1543 return Ok(vec![start]);
1544 }
1545 let mut vals = Vec::new();
1546 if start <= end {
1547 let mut x = start;
1548 loop {
1549 vals.push(x);
1550 if x >= end {
1551 break;
1552 }
1553 let next = x + step;
1554 if !next.is_finite() || next == x {
1555 break;
1556 }
1557 x = next;
1558 if x > end + 1e-12 {
1559 break;
1560 }
1561 }
1562 } else {
1563 let mut x = start;
1564 loop {
1565 vals.push(x);
1566 if x <= end {
1567 break;
1568 }
1569 let next = x - step.abs();
1570 if !next.is_finite() || next == x {
1571 break;
1572 }
1573 x = next;
1574 if x < end - 1e-12 {
1575 break;
1576 }
1577 }
1578 }
1579 if vals.is_empty() {
1580 return Err(CorrelationCycleError::InvalidRange {
1581 start: start as usize,
1582 end: end as usize,
1583 step: step.abs() as usize,
1584 });
1585 }
1586 Ok(vals)
1587 }
1588
1589 let periods = axis_usize(r.period)?;
1590 let thresholds = axis_f64(r.threshold)?;
1591
1592 let cap = periods
1593 .len()
1594 .checked_mul(thresholds.len())
1595 .ok_or_else(|| CorrelationCycleError::InvalidInput("rows*cols overflow".into()))?;
1596 let mut out = Vec::with_capacity(cap);
1597 for &p in &periods {
1598 for &t in &thresholds {
1599 out.push(CorrelationCycleParams {
1600 period: Some(p),
1601 threshold: Some(t),
1602 });
1603 }
1604 }
1605 Ok(out)
1606}
1607
1608#[inline(always)]
1609pub fn correlation_cycle_batch_slice(
1610 data: &[f64],
1611 sweep: &CorrelationCycleBatchRange,
1612 kern: Kernel,
1613) -> Result<CorrelationCycleBatchOutput, CorrelationCycleError> {
1614 correlation_cycle_batch_inner(data, sweep, kern, false)
1615}
1616
1617#[inline(always)]
1618pub fn correlation_cycle_batch_par_slice(
1619 data: &[f64],
1620 sweep: &CorrelationCycleBatchRange,
1621 kern: Kernel,
1622) -> Result<CorrelationCycleBatchOutput, CorrelationCycleError> {
1623 correlation_cycle_batch_inner(data, sweep, kern, true)
1624}
1625
1626#[inline(always)]
1627fn correlation_cycle_batch_inner(
1628 data: &[f64],
1629 sweep: &CorrelationCycleBatchRange,
1630 kern: Kernel,
1631 parallel: bool,
1632) -> Result<CorrelationCycleBatchOutput, CorrelationCycleError> {
1633 if data.is_empty() {
1634 return Err(CorrelationCycleError::EmptyInputData);
1635 }
1636 let combos = expand_grid(sweep)?;
1637
1638 let first = data
1639 .iter()
1640 .position(|x| !x.is_nan())
1641 .ok_or(CorrelationCycleError::AllValuesNaN)?;
1642 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1643 if data.len() - first < max_p {
1644 return Err(CorrelationCycleError::NotEnoughValidData {
1645 needed: max_p,
1646 valid: data.len() - first,
1647 });
1648 }
1649
1650 let rows = combos.len();
1651 let cols = data.len();
1652 let _total = rows
1653 .checked_mul(cols)
1654 .ok_or_else(|| CorrelationCycleError::InvalidInput("rows*cols overflow".into()))?;
1655
1656 let mut real_mu = make_uninit_matrix(rows, cols);
1657 let mut imag_mu = make_uninit_matrix(rows, cols);
1658 let mut angle_mu = make_uninit_matrix(rows, cols);
1659 let mut state_mu = make_uninit_matrix(rows, cols);
1660
1661 let ria_warm: Vec<usize> = combos
1662 .iter()
1663 .map(|c| first.checked_add(c.period.unwrap()).unwrap_or(usize::MAX))
1664 .collect();
1665 let st_warm: Vec<usize> = combos
1666 .iter()
1667 .map(|c| {
1668 first
1669 .checked_add(c.period.unwrap())
1670 .and_then(|v| v.checked_add(1))
1671 .unwrap_or(usize::MAX)
1672 })
1673 .collect();
1674
1675 init_matrix_prefixes(&mut real_mu, cols, &ria_warm);
1676 init_matrix_prefixes(&mut imag_mu, cols, &ria_warm);
1677 init_matrix_prefixes(&mut angle_mu, cols, &ria_warm);
1678 init_matrix_prefixes(&mut state_mu, cols, &st_warm);
1679
1680 let mut real_guard = ManuallyDrop::new(real_mu);
1681 let mut imag_guard = ManuallyDrop::new(imag_mu);
1682 let mut angle_guard = ManuallyDrop::new(angle_mu);
1683 let mut state_guard = ManuallyDrop::new(state_mu);
1684
1685 let real: &mut [f64] = unsafe {
1686 core::slice::from_raw_parts_mut(real_guard.as_mut_ptr() as *mut f64, real_guard.len())
1687 };
1688 let imag: &mut [f64] = unsafe {
1689 core::slice::from_raw_parts_mut(imag_guard.as_mut_ptr() as *mut f64, imag_guard.len())
1690 };
1691 let angle: &mut [f64] = unsafe {
1692 core::slice::from_raw_parts_mut(angle_guard.as_mut_ptr() as *mut f64, angle_guard.len())
1693 };
1694 let state: &mut [f64] = unsafe {
1695 core::slice::from_raw_parts_mut(state_guard.as_mut_ptr() as *mut f64, state_guard.len())
1696 };
1697
1698 let do_row = |row: usize,
1699 out_real: &mut [f64],
1700 out_imag: &mut [f64],
1701 out_angle: &mut [f64],
1702 out_state: &mut [f64]| unsafe {
1703 let period = combos[row].period.unwrap();
1704 let threshold = combos[row].threshold.unwrap();
1705 match kern {
1706 Kernel::Scalar => correlation_cycle_row_scalar_with_first(
1707 data, period, threshold, first, out_real, out_imag, out_angle, out_state,
1708 ),
1709 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1710 Kernel::Avx2 => correlation_cycle_row_avx2_with_first(
1711 data, period, threshold, first, out_real, out_imag, out_angle, out_state,
1712 ),
1713 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1714 Kernel::Avx512 => correlation_cycle_row_avx512_with_first(
1715 data, period, threshold, first, out_real, out_imag, out_angle, out_state,
1716 ),
1717 _ => correlation_cycle_row_scalar_with_first(
1718 data, period, threshold, first, out_real, out_imag, out_angle, out_state,
1719 ),
1720 }
1721 };
1722
1723 if parallel {
1724 #[cfg(not(target_arch = "wasm32"))]
1725 {
1726 real.par_chunks_mut(cols)
1727 .zip(imag.par_chunks_mut(cols))
1728 .zip(angle.par_chunks_mut(cols))
1729 .zip(state.par_chunks_mut(cols))
1730 .enumerate()
1731 .for_each(|(row, (((r, im), an), st))| do_row(row, r, im, an, st));
1732 }
1733
1734 #[cfg(target_arch = "wasm32")]
1735 {
1736 for (row, (((r, im), an), st)) in real
1737 .chunks_mut(cols)
1738 .zip(imag.chunks_mut(cols))
1739 .zip(angle.chunks_mut(cols))
1740 .zip(state.chunks_mut(cols))
1741 .enumerate()
1742 {
1743 do_row(row, r, im, an, st);
1744 }
1745 }
1746 } else {
1747 for (row, (((r, im), an), st)) in real
1748 .chunks_mut(cols)
1749 .zip(imag.chunks_mut(cols))
1750 .zip(angle.chunks_mut(cols))
1751 .zip(state.chunks_mut(cols))
1752 .enumerate()
1753 {
1754 do_row(row, r, im, an, st);
1755 }
1756 }
1757
1758 let real = unsafe {
1759 Vec::from_raw_parts(
1760 real_guard.as_mut_ptr() as *mut f64,
1761 real_guard.len(),
1762 real_guard.capacity(),
1763 )
1764 };
1765 let imag = unsafe {
1766 Vec::from_raw_parts(
1767 imag_guard.as_mut_ptr() as *mut f64,
1768 imag_guard.len(),
1769 imag_guard.capacity(),
1770 )
1771 };
1772 let angle = unsafe {
1773 Vec::from_raw_parts(
1774 angle_guard.as_mut_ptr() as *mut f64,
1775 angle_guard.len(),
1776 angle_guard.capacity(),
1777 )
1778 };
1779 let state = unsafe {
1780 Vec::from_raw_parts(
1781 state_guard.as_mut_ptr() as *mut f64,
1782 state_guard.len(),
1783 state_guard.capacity(),
1784 )
1785 };
1786
1787 Ok(CorrelationCycleBatchOutput {
1788 real,
1789 imag,
1790 angle,
1791 state,
1792 combos,
1793 rows,
1794 cols,
1795 })
1796}
1797
1798#[inline(always)]
1799fn correlation_cycle_batch_inner_into(
1800 data: &[f64],
1801 sweep: &CorrelationCycleBatchRange,
1802 kern: Kernel,
1803 parallel: bool,
1804 out_real: &mut [f64],
1805 out_imag: &mut [f64],
1806 out_angle: &mut [f64],
1807 out_state: &mut [f64],
1808) -> Result<Vec<CorrelationCycleParams>, CorrelationCycleError> {
1809 if data.is_empty() {
1810 return Err(CorrelationCycleError::EmptyInputData);
1811 }
1812 let combos = expand_grid(sweep)?;
1813
1814 let first = data
1815 .iter()
1816 .position(|x| !x.is_nan())
1817 .ok_or(CorrelationCycleError::AllValuesNaN)?;
1818 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1819 if data.len() - first < max_p {
1820 return Err(CorrelationCycleError::NotEnoughValidData {
1821 needed: max_p,
1822 valid: data.len() - first,
1823 });
1824 }
1825
1826 let rows = combos.len();
1827 let cols = data.len();
1828 let expected = rows
1829 .checked_mul(cols)
1830 .ok_or_else(|| CorrelationCycleError::InvalidInput("rows*cols overflow".into()))?;
1831 if out_real.len() != expected
1832 || out_imag.len() != expected
1833 || out_angle.len() != expected
1834 || out_state.len() != expected
1835 {
1836 let got = *[
1837 out_real.len(),
1838 out_imag.len(),
1839 out_angle.len(),
1840 out_state.len(),
1841 ]
1842 .iter()
1843 .min()
1844 .unwrap_or(&0);
1845 return Err(CorrelationCycleError::OutputLengthMismatch { expected, got });
1846 }
1847
1848 let do_row = |row: usize, r: &mut [f64], im: &mut [f64], an: &mut [f64], st: &mut [f64]| unsafe {
1849 let p = combos[row].period.unwrap();
1850 let t = combos[row].threshold.unwrap();
1851 match kern {
1852 Kernel::Scalar | Kernel::ScalarBatch => {
1853 correlation_cycle_row_scalar_with_first(data, p, t, first, r, im, an, st)
1854 }
1855 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1856 Kernel::Avx2 | Kernel::Avx2Batch => {
1857 correlation_cycle_row_avx2_with_first(data, p, t, first, r, im, an, st)
1858 }
1859 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1860 Kernel::Avx512 | Kernel::Avx512Batch => {
1861 correlation_cycle_row_avx512_with_first(data, p, t, first, r, im, an, st)
1862 }
1863 _ => correlation_cycle_row_scalar_with_first(data, p, t, first, r, im, an, st),
1864 }
1865
1866 let ria = first + p;
1867 let stp = first + p + 1;
1868 for v in &mut r[..ria] {
1869 *v = f64::NAN;
1870 }
1871 for v in &mut im[..ria] {
1872 *v = f64::NAN;
1873 }
1874 for v in &mut an[..ria] {
1875 *v = f64::NAN;
1876 }
1877 for v in &mut st[..stp] {
1878 *v = f64::NAN;
1879 }
1880 };
1881
1882 if parallel {
1883 #[cfg(not(target_arch = "wasm32"))]
1884 {
1885 out_real
1886 .par_chunks_mut(cols)
1887 .zip(out_imag.par_chunks_mut(cols))
1888 .zip(out_angle.par_chunks_mut(cols))
1889 .zip(out_state.par_chunks_mut(cols))
1890 .enumerate()
1891 .for_each(|(row, (((r, im), an), st))| do_row(row, r, im, an, st));
1892 }
1893 #[cfg(target_arch = "wasm32")]
1894 {
1895 for (row, (((r, im), an), st)) in out_real
1896 .chunks_mut(cols)
1897 .zip(out_imag.chunks_mut(cols))
1898 .zip(out_angle.chunks_mut(cols))
1899 .zip(out_state.chunks_mut(cols))
1900 .enumerate()
1901 {
1902 do_row(row, r, im, an, st);
1903 }
1904 }
1905 } else {
1906 for (row, (((r, im), an), st)) in out_real
1907 .chunks_mut(cols)
1908 .zip(out_imag.chunks_mut(cols))
1909 .zip(out_angle.chunks_mut(cols))
1910 .zip(out_state.chunks_mut(cols))
1911 .enumerate()
1912 {
1913 do_row(row, r, im, an, st);
1914 }
1915 }
1916
1917 Ok(combos)
1918}
1919
1920#[cfg(test)]
1921mod tests {
1922 use super::*;
1923 use crate::skip_if_unsupported;
1924 use crate::utilities::data_loader::read_candles_from_csv;
1925 use crate::utilities::enums::Kernel;
1926
1927 fn check_cc_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1928 skip_if_unsupported!(kernel, test_name);
1929 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1930 let candles = read_candles_from_csv(file_path)?;
1931 let default_params = CorrelationCycleParams {
1932 period: None,
1933 threshold: None,
1934 };
1935 let input = CorrelationCycleInput::from_candles(&candles, "close", default_params);
1936 let output = correlation_cycle_with_kernel(&input, kernel)?;
1937 assert_eq!(output.real.len(), candles.close.len());
1938 Ok(())
1939 }
1940
1941 fn check_cc_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1942 skip_if_unsupported!(kernel, test_name);
1943 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1944 let candles = read_candles_from_csv(file_path)?;
1945 let params = CorrelationCycleParams {
1946 period: Some(20),
1947 threshold: Some(9.0),
1948 };
1949 let input = CorrelationCycleInput::from_candles(&candles, "close", params);
1950 let result = correlation_cycle_with_kernel(&input, kernel)?;
1951 let expected_last_five_real = [
1952 -0.3348928030992766,
1953 -0.2908979303392832,
1954 -0.10648582811938148,
1955 -0.09118320471750277,
1956 0.0826798259258665,
1957 ];
1958 let expected_last_five_imag = [
1959 0.2902308064575494,
1960 0.4025192756952553,
1961 0.4704322460080054,
1962 0.5404405595224989,
1963 0.5418162415918566,
1964 ];
1965 let expected_last_five_angle = [
1966 -139.0865569687123,
1967 -125.8553823569915,
1968 -102.75438860700636,
1969 -99.576759208278,
1970 -81.32373697835556,
1971 ];
1972 let start = result.real.len().saturating_sub(5);
1973 for i in 0..5 {
1974 let diff_real = (result.real[start + i] - expected_last_five_real[i]).abs();
1975 let diff_imag = (result.imag[start + i] - expected_last_five_imag[i]).abs();
1976 let diff_angle = (result.angle[start + i] - expected_last_five_angle[i]).abs();
1977 assert!(
1978 diff_real < 1e-8,
1979 "[{}] CC {:?} real mismatch at idx {}: got {}, expected {}",
1980 test_name,
1981 kernel,
1982 i,
1983 result.real[start + i],
1984 expected_last_five_real[i]
1985 );
1986 assert!(
1987 diff_imag < 1e-8,
1988 "[{}] CC {:?} imag mismatch at idx {}: got {}, expected {}",
1989 test_name,
1990 kernel,
1991 i,
1992 result.imag[start + i],
1993 expected_last_five_imag[i]
1994 );
1995 assert!(
1996 diff_angle < 1e-8,
1997 "[{}] CC {:?} angle mismatch at idx {}: got {}, expected {}",
1998 test_name,
1999 kernel,
2000 i,
2001 result.angle[start + i],
2002 expected_last_five_angle[i]
2003 );
2004 }
2005 Ok(())
2006 }
2007
2008 fn check_cc_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2009 skip_if_unsupported!(kernel, test_name);
2010 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2011 let candles = read_candles_from_csv(file_path)?;
2012 let input = CorrelationCycleInput::with_default_candles(&candles);
2013 match input.data {
2014 CorrelationCycleData::Candles { source, .. } => assert_eq!(source, "close"),
2015 _ => panic!("Expected CorrelationCycleData::Candles"),
2016 }
2017 let output = correlation_cycle_with_kernel(&input, kernel)?;
2018 assert_eq!(output.real.len(), candles.close.len());
2019 Ok(())
2020 }
2021
2022 fn check_cc_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2023 skip_if_unsupported!(kernel, test_name);
2024 let input_data = [10.0, 20.0, 30.0];
2025 let params = CorrelationCycleParams {
2026 period: Some(0),
2027 threshold: None,
2028 };
2029 let input = CorrelationCycleInput::from_slice(&input_data, params);
2030 let res = correlation_cycle_with_kernel(&input, kernel);
2031 assert!(
2032 res.is_err(),
2033 "[{}] CC should fail with zero period",
2034 test_name
2035 );
2036 Ok(())
2037 }
2038
2039 fn check_cc_period_exceeds_length(
2040 test_name: &str,
2041 kernel: Kernel,
2042 ) -> Result<(), Box<dyn Error>> {
2043 skip_if_unsupported!(kernel, test_name);
2044 let data_small = [10.0, 20.0, 30.0];
2045 let params = CorrelationCycleParams {
2046 period: Some(10),
2047 threshold: None,
2048 };
2049 let input = CorrelationCycleInput::from_slice(&data_small, params);
2050 let res = correlation_cycle_with_kernel(&input, kernel);
2051 assert!(
2052 res.is_err(),
2053 "[{}] CC should fail with period exceeding length",
2054 test_name
2055 );
2056 Ok(())
2057 }
2058
2059 fn check_cc_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2060 skip_if_unsupported!(kernel, test_name);
2061 let single_point = [42.0];
2062 let params = CorrelationCycleParams {
2063 period: Some(9),
2064 threshold: None,
2065 };
2066 let input = CorrelationCycleInput::from_slice(&single_point, params);
2067 let res = correlation_cycle_with_kernel(&input, kernel);
2068 assert!(
2069 res.is_err(),
2070 "[{}] CC should fail with insufficient data",
2071 test_name
2072 );
2073 Ok(())
2074 }
2075
2076 fn check_cc_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2077 skip_if_unsupported!(kernel, test_name);
2078 let data = [10.0, 10.5, 11.0, 11.5, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0];
2079 let params = CorrelationCycleParams {
2080 period: Some(4),
2081 threshold: Some(2.0),
2082 };
2083 let input = CorrelationCycleInput::from_slice(&data, params.clone());
2084 let first_result = correlation_cycle_with_kernel(&input, kernel)?;
2085 let second_input = CorrelationCycleInput::from_slice(&first_result.real, params);
2086 let second_result = correlation_cycle_with_kernel(&second_input, kernel)?;
2087 assert_eq!(first_result.real.len(), data.len());
2088 assert_eq!(second_result.real.len(), data.len());
2089 Ok(())
2090 }
2091
2092 fn check_cc_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2093 skip_if_unsupported!(kernel, test_name);
2094 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2095 let candles = read_candles_from_csv(file_path)?;
2096 let input = CorrelationCycleInput::from_candles(
2097 &candles,
2098 "close",
2099 CorrelationCycleParams {
2100 period: Some(20),
2101 threshold: None,
2102 },
2103 );
2104 let res = correlation_cycle_with_kernel(&input, kernel)?;
2105 assert_eq!(res.real.len(), candles.close.len());
2106 if res.real.len() > 40 {
2107 for (i, &val) in res.real[40..].iter().enumerate() {
2108 assert!(
2109 !val.is_nan(),
2110 "[{}] Found unexpected NaN at out-index {}",
2111 test_name,
2112 40 + i
2113 );
2114 }
2115 }
2116 Ok(())
2117 }
2118
2119 fn check_cc_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2120 skip_if_unsupported!(kernel, test_name);
2121 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2122 let candles = read_candles_from_csv(file_path)?;
2123 let period = 20;
2124 let threshold = 9.0;
2125 let input = CorrelationCycleInput::from_candles(
2126 &candles,
2127 "close",
2128 CorrelationCycleParams {
2129 period: Some(period),
2130 threshold: Some(threshold),
2131 },
2132 );
2133 let batch_output = correlation_cycle_with_kernel(&input, kernel)?;
2134 let mut stream = CorrelationCycleStream::try_new(CorrelationCycleParams {
2135 period: Some(period),
2136 threshold: Some(threshold),
2137 })?;
2138 let mut stream_real = Vec::with_capacity(candles.close.len());
2139 let mut stream_imag = Vec::with_capacity(candles.close.len());
2140 let mut stream_angle = Vec::with_capacity(candles.close.len());
2141 let mut stream_state = Vec::with_capacity(candles.close.len());
2142 for &price in &candles.close {
2143 match stream.update(price) {
2144 Some((r, im, ang, st)) => {
2145 stream_real.push(r);
2146 stream_imag.push(im);
2147 stream_angle.push(ang);
2148 stream_state.push(st);
2149 }
2150 None => {
2151 stream_real.push(f64::NAN);
2152 stream_imag.push(f64::NAN);
2153 stream_angle.push(f64::NAN);
2154 stream_state.push(0.0);
2155 }
2156 }
2157 }
2158 assert_eq!(batch_output.real.len(), stream_real.len());
2159 for (i, (&b, &s)) in batch_output.real.iter().zip(stream_real.iter()).enumerate() {
2160 if b.is_nan() && s.is_nan() {
2161 continue;
2162 }
2163 let diff = (b - s).abs();
2164 assert!(
2165 diff < 1e-9,
2166 "[{}] CC streaming real f64 mismatch at idx {}: batch={}, stream={}, diff={}",
2167 test_name,
2168 i,
2169 b,
2170 s,
2171 diff
2172 );
2173 }
2174 Ok(())
2175 }
2176
2177 #[cfg(debug_assertions)]
2178 fn check_cc_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2179 skip_if_unsupported!(kernel, test_name);
2180
2181 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2182 let candles = read_candles_from_csv(file_path)?;
2183
2184 let test_params = vec![
2185 CorrelationCycleParams {
2186 period: Some(20),
2187 threshold: Some(9.0),
2188 },
2189 CorrelationCycleParams {
2190 period: Some(10),
2191 threshold: Some(5.0),
2192 },
2193 CorrelationCycleParams {
2194 period: Some(30),
2195 threshold: Some(15.0),
2196 },
2197 CorrelationCycleParams {
2198 period: None,
2199 threshold: None,
2200 },
2201 ];
2202
2203 for params in test_params {
2204 let input = CorrelationCycleInput::from_candles(&candles, "close", params.clone());
2205 let output = correlation_cycle_with_kernel(&input, kernel)?;
2206
2207 let arrays = vec![
2208 ("real", &output.real),
2209 ("imag", &output.imag),
2210 ("angle", &output.angle),
2211 ("state", &output.state),
2212 ];
2213
2214 for (array_name, values) in arrays {
2215 for (i, &val) in values.iter().enumerate() {
2216 if val.is_nan() {
2217 continue;
2218 }
2219
2220 let bits = val.to_bits();
2221
2222 if bits == 0x11111111_11111111 {
2223 panic!(
2224 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} in {} array with params {:?}",
2225 test_name, val, bits, i, array_name, params
2226 );
2227 }
2228
2229 if bits == 0x22222222_22222222 {
2230 panic!(
2231 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} in {} array with params {:?}",
2232 test_name, val, bits, i, array_name, params
2233 );
2234 }
2235
2236 if bits == 0x33333333_33333333 {
2237 panic!(
2238 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} in {} array with params {:?}",
2239 test_name, val, bits, i, array_name, params
2240 );
2241 }
2242 }
2243 }
2244 }
2245
2246 Ok(())
2247 }
2248
2249 #[cfg(not(debug_assertions))]
2250 fn check_cc_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2251 Ok(())
2252 }
2253
2254 macro_rules! generate_all_cc_tests {
2255 ($($test_fn:ident),*) => {
2256 paste::paste! {
2257 $(
2258 #[test]
2259 fn [<$test_fn _scalar_f64>]() {
2260 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
2261 }
2262 )*
2263 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2264 $(
2265 #[test]
2266 fn [<$test_fn _avx2_f64>]() {
2267 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
2268 }
2269 #[test]
2270 fn [<$test_fn _avx512_f64>]() {
2271 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
2272 }
2273 )*
2274 }
2275 }
2276 }
2277
2278 generate_all_cc_tests!(
2279 check_cc_partial_params,
2280 check_cc_accuracy,
2281 check_cc_default_candles,
2282 check_cc_zero_period,
2283 check_cc_period_exceeds_length,
2284 check_cc_very_small_dataset,
2285 check_cc_reinput,
2286 check_cc_nan_handling,
2287 check_cc_streaming,
2288 check_cc_no_poison
2289 );
2290
// The write-into-buffers API must reproduce the allocating API: same lengths,
// and every cell either both NaN, exactly equal, or within 1e-12.
#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
#[test]
fn test_correlation_cycle_into_matches_api_v2() -> Result<(), Box<dyn Error>> {
    fn eq_or_both_nan_eps(a: f64, b: f64) -> bool {
        if a.is_nan() && b.is_nan() {
            return true;
        }
        a == b || (a - b).abs() <= 1e-12
    }

    let candles = read_candles_from_csv("src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv")?;
    let input = CorrelationCycleInput::from_candles(
        &candles,
        "close",
        CorrelationCycleParams::default(),
    );

    let baseline = correlation_cycle(&input)?;

    let len = candles.close.len();
    let mut out_real = vec![0.0f64; len];
    let mut out_imag = vec![0.0f64; len];
    let mut out_angle = vec![0.0f64; len];
    let mut out_state = vec![0.0f64; len];

    correlation_cycle_into(
        &input,
        &mut out_real,
        &mut out_imag,
        &mut out_angle,
        &mut out_state,
    )?;

    assert_eq!(baseline.real.len(), out_real.len());
    assert_eq!(baseline.imag.len(), out_imag.len());
    assert_eq!(baseline.angle.len(), out_angle.len());
    assert_eq!(baseline.state.len(), out_state.len());

    for i in 0..len {
        let (base_real, into_real) = (baseline.real[i], out_real[i]);
        assert!(
            eq_or_both_nan_eps(base_real, into_real),
            "real mismatch at {}: base={}, into={}",
            i,
            base_real,
            into_real
        );
        let (base_imag, into_imag) = (baseline.imag[i], out_imag[i]);
        assert!(
            eq_or_both_nan_eps(base_imag, into_imag),
            "imag mismatch at {}: base={}, into={}",
            i,
            base_imag,
            into_imag
        );
        let (base_angle, into_angle) = (baseline.angle[i], out_angle[i]);
        assert!(
            eq_or_both_nan_eps(base_angle, into_angle),
            "angle mismatch at {}: base={}, into={}",
            i,
            base_angle,
            into_angle
        );
        let (base_state, into_state) = (baseline.state[i], out_state[i]);
        assert!(
            eq_or_both_nan_eps(base_state, into_state),
            "state mismatch at {}: base={}, into={}",
            i,
            base_state,
            into_state
        );
    }

    Ok(())
}
2361
2362 #[cfg(feature = "proptest")]
2363 fn check_correlation_cycle_property(
2364 test_name: &str,
2365 kernel: Kernel,
2366 ) -> Result<(), Box<dyn std::error::Error>> {
2367 use proptest::prelude::*;
2368 skip_if_unsupported!(kernel, test_name);
2369
2370 let strat = (5usize..=50).prop_flat_map(|period| {
2371 (
2372 prop::collection::vec(
2373 (10.0f64..10000.0f64).prop_filter("finite", |x| x.is_finite()),
2374 period + 10..400,
2375 ),
2376 Just(period),
2377 1.0f64..20.0f64,
2378 )
2379 });
2380
2381 proptest::test_runner::TestRunner::default()
2382 .run(&strat, |(data, period, threshold)| {
2383 let params = CorrelationCycleParams {
2384 period: Some(period),
2385 threshold: Some(threshold),
2386 };
2387 let input = CorrelationCycleInput::from_slice(&data, params);
2388
2389 let output = correlation_cycle_with_kernel(&input, kernel).unwrap();
2390
2391 let ref_output = correlation_cycle_with_kernel(&input, Kernel::Scalar).unwrap();
2392
2393 let warmup_period = period;
2394
2395 for i in 0..warmup_period {
2396 prop_assert!(
2397 output.real[i].is_nan(),
2398 "[{}] real[{}] should be NaN during warmup but is {}",
2399 test_name,
2400 i,
2401 output.real[i]
2402 );
2403 prop_assert!(
2404 output.imag[i].is_nan(),
2405 "[{}] imag[{}] should be NaN during warmup but is {}",
2406 test_name,
2407 i,
2408 output.imag[i]
2409 );
2410 prop_assert!(
2411 output.angle[i].is_nan(),
2412 "[{}] angle[{}] should be NaN during warmup but is {}",
2413 test_name,
2414 i,
2415 output.angle[i]
2416 );
2417 }
2418
2419 for i in 0..=period {
2420 prop_assert!(
2421 output.state[i].is_nan(),
2422 "[{}] state[{}] should be NaN during warmup but is {}",
2423 test_name,
2424 i,
2425 output.state[i]
2426 );
2427 }
2428
2429 for i in warmup_period..data.len() {
2430 if !output.real[i].is_nan() {
2431 prop_assert!(
2432 output.real[i] >= -1.0 - 1e-9 && output.real[i] <= 1.0 + 1e-9,
2433 "[{}] real[{}] = {} is outside [-1, 1] bounds",
2434 test_name,
2435 i,
2436 output.real[i]
2437 );
2438 }
2439 if !output.imag[i].is_nan() {
2440 prop_assert!(
2441 output.imag[i] >= -1.0 - 1e-9 && output.imag[i] <= 1.0 + 1e-9,
2442 "[{}] imag[{}] = {} is outside [-1, 1] bounds",
2443 test_name,
2444 i,
2445 output.imag[i]
2446 );
2447 }
2448 }
2449
2450 for i in warmup_period..data.len() {
2451 if !output.angle[i].is_nan() {
2452 prop_assert!(
2453 output.angle[i] >= -180.0 - 1e-9 && output.angle[i] <= 180.0 + 1e-9,
2454 "[{}] angle[{}] = {} is outside [-180, 180] bounds",
2455 test_name,
2456 i,
2457 output.angle[i]
2458 );
2459 }
2460 }
2461
2462 for i in (period + 1)..data.len() {
2463 if !output.state[i].is_nan() {
2464 let state_val = output.state[i];
2465 prop_assert!(
2466 (state_val + 1.0).abs() < 1e-9
2467 || state_val.abs() < 1e-9
2468 || (state_val - 1.0).abs() < 1e-9,
2469 "[{}] state[{}] = {} is not -1, 0, or 1",
2470 test_name,
2471 i,
2472 state_val
2473 );
2474 }
2475 }
2476
2477 for i in 0..data.len() {
2478 let real_bits = output.real[i].to_bits();
2479 let ref_real_bits = ref_output.real[i].to_bits();
2480
2481 if !output.real[i].is_finite() || !ref_output.real[i].is_finite() {
2482 prop_assert!(
2483 real_bits == ref_real_bits,
2484 "[{}] real finite/NaN mismatch at idx {}: {} vs {}",
2485 test_name,
2486 i,
2487 output.real[i],
2488 ref_output.real[i]
2489 );
2490 } else {
2491 let ulp_diff = real_bits.abs_diff(ref_real_bits);
2492 prop_assert!(
2493 (output.real[i] - ref_output.real[i]).abs() <= 1e-9 || ulp_diff <= 4,
2494 "[{}] real mismatch at idx {}: {} vs {} (ULP={})",
2495 test_name,
2496 i,
2497 output.real[i],
2498 ref_output.real[i],
2499 ulp_diff
2500 );
2501 }
2502
2503 let imag_bits = output.imag[i].to_bits();
2504 let ref_imag_bits = ref_output.imag[i].to_bits();
2505
2506 if !output.imag[i].is_finite() || !ref_output.imag[i].is_finite() {
2507 prop_assert!(
2508 imag_bits == ref_imag_bits,
2509 "[{}] imag finite/NaN mismatch at idx {}: {} vs {}",
2510 test_name,
2511 i,
2512 output.imag[i],
2513 ref_output.imag[i]
2514 );
2515 } else {
2516 let ulp_diff = imag_bits.abs_diff(ref_imag_bits);
2517 prop_assert!(
2518 (output.imag[i] - ref_output.imag[i]).abs() <= 1e-9 || ulp_diff <= 4,
2519 "[{}] imag mismatch at idx {}: {} vs {} (ULP={})",
2520 test_name,
2521 i,
2522 output.imag[i],
2523 ref_output.imag[i],
2524 ulp_diff
2525 );
2526 }
2527
2528 let angle_bits = output.angle[i].to_bits();
2529 let ref_angle_bits = ref_output.angle[i].to_bits();
2530
2531 if !output.angle[i].is_finite() || !ref_output.angle[i].is_finite() {
2532 prop_assert!(
2533 angle_bits == ref_angle_bits,
2534 "[{}] angle finite/NaN mismatch at idx {}: {} vs {}",
2535 test_name,
2536 i,
2537 output.angle[i],
2538 ref_output.angle[i]
2539 );
2540 } else {
2541 let ulp_diff = angle_bits.abs_diff(ref_angle_bits);
2542 prop_assert!(
2543 (output.angle[i] - ref_output.angle[i]).abs() <= 1e-9 || ulp_diff <= 4,
2544 "[{}] angle mismatch at idx {}: {} vs {} (ULP={})",
2545 test_name,
2546 i,
2547 output.angle[i],
2548 ref_output.angle[i],
2549 ulp_diff
2550 );
2551 }
2552
2553 let state_bits = output.state[i].to_bits();
2554 let ref_state_bits = ref_output.state[i].to_bits();
2555
2556 if !output.state[i].is_finite() || !ref_output.state[i].is_finite() {
2557 prop_assert!(
2558 state_bits == ref_state_bits,
2559 "[{}] state finite/NaN mismatch at idx {}: {} vs {}",
2560 test_name,
2561 i,
2562 output.state[i],
2563 ref_output.state[i]
2564 );
2565 } else {
2566 prop_assert!(
2567 (output.state[i] - ref_output.state[i]).abs() <= 1e-9,
2568 "[{}] state mismatch at idx {}: {} vs {}",
2569 test_name,
2570 i,
2571 output.state[i],
2572 ref_output.state[i]
2573 );
2574 }
2575 }
2576
2577 if data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10) {
2578 for i in warmup_period..data.len() {
2579 if !output.real[i].is_nan() {
2580 prop_assert!(
2581 output.real[i].abs() < 1e-6,
2582 "[{}] real[{}] = {} should be near 0 for constant data",
2583 test_name,
2584 i,
2585 output.real[i]
2586 );
2587 }
2588 if !output.imag[i].is_nan() {
2589 prop_assert!(
2590 output.imag[i].abs() < 1e-6,
2591 "[{}] imag[{}] = {} should be near 0 for constant data",
2592 test_name,
2593 i,
2594 output.imag[i]
2595 );
2596 }
2597 }
2598 }
2599
2600 Ok(())
2601 })
2602 .unwrap();
2603
2604 Ok(())
2605 }
2606
    // Expand `check_correlation_cycle_property` through `generate_all_cc_tests!` (defined
    // earlier in this module; presumably one #[test] per kernel variant — confirm at the
    // macro definition). Only compiled when the `proptest` feature is enabled.
    #[cfg(feature = "proptest")]
    generate_all_cc_tests!(check_correlation_cycle_property);
2609
2610 fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2611 skip_if_unsupported!(kernel, test);
2612 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2613 let c = read_candles_from_csv(file)?;
2614 let output = CorrelationCycleBatchBuilder::new()
2615 .kernel(kernel)
2616 .apply_candles(&c, "close")?;
2617 let def = CorrelationCycleParams::default();
2618 let (row_real, row_imag, row_angle, row_state) =
2619 output.values_for(&def).expect("default row missing");
2620 assert_eq!(row_real.len(), c.close.len());
2621 assert_eq!(row_imag.len(), c.close.len());
2622 assert_eq!(row_angle.len(), c.close.len());
2623 assert_eq!(row_state.len(), c.close.len());
2624 Ok(())
2625 }
2626
2627 macro_rules! gen_batch_tests {
2628 ($fn_name:ident) => {
2629 paste::paste! {
2630 #[test] fn [<$fn_name _scalar>]() {
2631 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
2632 }
2633 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2634 #[test] fn [<$fn_name _avx2>]() {
2635 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
2636 }
2637 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2638 #[test] fn [<$fn_name _avx512>]() {
2639 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
2640 }
2641 #[test] fn [<$fn_name _auto_detect>]() {
2642 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
2643 }
2644 }
2645 };
2646 }
2647
2648 #[cfg(debug_assertions)]
2649 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2650 skip_if_unsupported!(kernel, test);
2651
2652 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2653 let c = read_candles_from_csv(file)?;
2654
2655 let output = CorrelationCycleBatchBuilder::new()
2656 .kernel(kernel)
2657 .period_range(10, 40, 10)
2658 .threshold_range(5.0, 15.0, 5.0)
2659 .apply_candles(&c, "close")?;
2660
2661 let matrices = vec![
2662 ("real", &output.real),
2663 ("imag", &output.imag),
2664 ("angle", &output.angle),
2665 ("state", &output.state),
2666 ];
2667
2668 for (matrix_name, values) in matrices {
2669 for (idx, &val) in values.iter().enumerate() {
2670 if val.is_nan() {
2671 continue;
2672 }
2673
2674 let bits = val.to_bits();
2675 let row = idx / output.cols;
2676 let col = idx % output.cols;
2677 let period = output.combos[row].period.unwrap();
2678 let threshold = output.combos[row].threshold.unwrap();
2679
2680 if bits == 0x11111111_11111111 {
2681 panic!(
2682 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at row {} col {} (flat index {}) in {} matrix, params: period={}, threshold={}",
2683 test, val, bits, row, col, idx, matrix_name, period, threshold
2684 );
2685 }
2686
2687 if bits == 0x22222222_22222222 {
2688 panic!(
2689 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at row {} col {} (flat index {}) in {} matrix, params: period={}, threshold={}",
2690 test, val, bits, row, col, idx, matrix_name, period, threshold
2691 );
2692 }
2693
2694 if bits == 0x33333333_33333333 {
2695 panic!(
2696 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at row {} col {} (flat index {}) in {} matrix, params: period={}, threshold={}",
2697 test, val, bits, row, col, idx, matrix_name, period, threshold
2698 );
2699 }
2700 }
2701 }
2702
2703 Ok(())
2704 }
2705
    /// Release-build counterpart of the poison check: performs no work and reports
    /// success, so the generated tests still exist when `debug_assertions` is off.
    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
2710
    // Instantiate the per-kernel #[test] functions (scalar, AVX2/AVX-512 when enabled,
    // auto-detect) for both batch checks.
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
2713
2714 #[test]
2715 fn test_correlation_cycle_into_matches_api() -> Result<(), Box<dyn Error>> {
2716 let n = 256usize;
2717 let mut data = Vec::with_capacity(n);
2718 for i in 0..n {
2719 let x = 100.0 + (i as f64 * 0.07).sin() * 2.5 + (i as f64 * 0.011).cos() * 0.4;
2720 data.push(x);
2721 }
2722
2723 let input = CorrelationCycleInput::from_slice(&data, CorrelationCycleParams::default());
2724
2725 let base = correlation_cycle(&input)?;
2726
2727 let mut out_r = vec![0.0; n];
2728 let mut out_i = vec![0.0; n];
2729 let mut out_a = vec![0.0; n];
2730 let mut out_s = vec![0.0; n];
2731
2732 #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
2733 {
2734 correlation_cycle_into(&input, &mut out_r, &mut out_i, &mut out_a, &mut out_s)?;
2735 }
2736 #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2737 {
2738 return Ok(());
2739 }
2740
2741 assert_eq!(base.real.len(), n);
2742 assert_eq!(base.imag.len(), n);
2743 assert_eq!(base.angle.len(), n);
2744 assert_eq!(base.state.len(), n);
2745
2746 fn eq_or_both_nan(a: f64, b: f64) -> bool {
2747 (a.is_nan() && b.is_nan()) || (a - b).abs() <= 1e-12
2748 }
2749
2750 for i in 0..n {
2751 assert!(
2752 eq_or_both_nan(base.real[i], out_r[i]),
2753 "real mismatch at {}: base={}, into={}",
2754 i,
2755 base.real[i],
2756 out_r[i]
2757 );
2758 assert!(
2759 eq_or_both_nan(base.imag[i], out_i[i]),
2760 "imag mismatch at {}: base={}, into={}",
2761 i,
2762 base.imag[i],
2763 out_i[i]
2764 );
2765 assert!(
2766 eq_or_both_nan(base.angle[i], out_a[i]),
2767 "angle mismatch at {}: base={}, into={}",
2768 i,
2769 base.angle[i],
2770 out_a[i]
2771 );
2772 assert!(
2773 eq_or_both_nan(base.state[i], out_s[i]),
2774 "state mismatch at {}: base={}, into={}",
2775 i,
2776 base.state[i],
2777 out_s[i]
2778 );
2779 }
2780
2781 Ok(())
2782 }
2783}
2784
// Python-visible wrapper (exposed as `ta_indicators.cuda.DeviceArrayF32CorrelationCycle`)
// around one device-resident f32 matrix produced by the correlation-cycle CUDA functions.
// Declared `unsendable` in the pyclass attributes.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(
    module = "ta_indicators.cuda",
    name = "DeviceArrayF32CorrelationCycle",
    unsendable
)]
pub struct DeviceArrayF32CcPy {
    // Device buffer together with its `rows` x `cols` shape.
    pub(crate) inner: crate::cuda::moving_averages::DeviceArrayF32,

    // Shared CUDA context handle; never read here (underscore-prefixed), presumably held
    // only to keep the context alive as long as the buffer — confirm.
    pub(crate) _ctx: StdArc<CudaContext>,
    // CUDA device ordinal reported by `__dlpack_device__`.
    pub(crate) device_id: i32,
}
2797
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl DeviceArrayF32CcPy {
    // Direct construction from Python is refused; instances come from the
    // `correlation_cycle_cuda_*` functions.
    #[new]
    fn py_new() -> PyResult<Self> {
        let msg = "use correlation_cycle_cuda_* factory functions to create this type";
        Err(pyo3::exceptions::PyTypeError::new_err(msg))
    }

    // Builds the `__cuda_array_interface__` dict (version 3) describing the buffer as a
    // row-major `<f4` matrix. An empty matrix reports a data pointer of 0.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let elem_size = std::mem::size_of::<f32>();
        let arr = &self.inner;

        let iface = PyDict::new(py);
        iface.set_item("shape", (arr.rows, arr.cols))?;
        iface.set_item("typestr", "<f4")?;
        iface.set_item("strides", (arr.cols * elem_size, elem_size))?;

        let is_empty = arr.rows == 0 || arr.cols == 0;
        let ptr_val: usize = if is_empty {
            0
        } else {
            arr.buf.as_device_ptr().as_raw() as usize
        };
        // Second tuple element is the interface's read-only flag.
        iface.set_item("data", (ptr_val, false))?;

        iface.set_item("version", 3)?;
        Ok(iface)
    }

    // Returns `(device_type, device_id)`; the device type is the constant 2.
    fn __dlpack_device__(&self) -> PyResult<(i32, i32)> {
        Ok((2, self.device_id))
    }

    // Exports the buffer through `export_f32_cuda_dlpack_2d`. The buffer is moved out of
    // `self` and replaced by an empty 0x0 placeholder, so this object no longer owns the
    // data afterwards. Rejects a `dl_device` that differs from the allocation device and
    // an explicit `stream == 0`.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<pyo3::PyObject> {
        use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;

        let (kdl, alloc_dev) = self.__dlpack_device__()?;

        // A `dl_device` that cannot be read as an (i32, i32) pair is ignored.
        let requested = dl_device
            .as_ref()
            .and_then(|dev| dev.extract::<(i32, i32)>(py).ok());
        if let Some((dev_ty, dev_id)) = requested {
            if dev_ty != kdl || dev_id != alloc_dev {
                let wants_copy = copy
                    .as_ref()
                    .and_then(|c| c.extract::<bool>(py).ok())
                    .unwrap_or(false);
                let msg = if wants_copy {
                    "device copy not implemented for __dlpack__"
                } else {
                    "dl_device mismatch for __dlpack__"
                };
                return Err(PyValueError::new_err(msg));
            }
        }

        // Only an integer stream equal to 0 is rejected; other values are not inspected.
        let stream_id = stream.as_ref().and_then(|s| s.extract::<i64>(py).ok());
        if stream_id == Some(0) {
            return Err(PyValueError::new_err(
                "stream=0 is reserved and not supported by this producer",
            ));
        }

        // Swap the real buffer out for an empty placeholder before handing it off.
        let placeholder = cust::memory::DeviceBuffer::<f32>::from_slice(&[])
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        let taken = std::mem::replace(
            &mut self.inner,
            crate::cuda::moving_averages::DeviceArrayF32 {
                buf: placeholder,
                rows: 0,
                cols: 0,
            },
        );

        let crate::cuda::moving_averages::DeviceArrayF32 { buf, rows, cols } = taken;
        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
2896
#[cfg(all(feature = "python", feature = "cuda"))]
impl DeviceArrayF32CcPy {
    // Rust-side constructor: wraps a device array together with its context handle and
    // stores the device ordinal as `i32`.
    pub fn new_from_rust(
        inner: crate::cuda::moving_averages::DeviceArrayF32,
        ctx: StdArc<CudaContext>,
        device_id: u32,
    ) -> Self {
        let device_id = device_id as i32;
        Self {
            inner,
            _ctx: ctx,
            device_id,
        }
    }
}
2911
// Python entry point `correlation_cycle_cuda_batch_dev`: runs the CUDA batch sweep over
// one f32 series and returns the (real, imag, angle, state) device arrays.
// Raises ValueError when CUDA is unavailable or when the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "correlation_cycle_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, threshold_range, device_id=0))]
pub fn correlation_cycle_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: numpy::PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    threshold_range: (f64, f64, f64),
    device_id: usize,
) -> PyResult<(
    DeviceArrayF32CcPy,
    DeviceArrayF32CcPy,
    DeviceArrayF32CcPy,
    DeviceArrayF32CcPy,
)> {
    use crate::cuda::cuda_available;
    use crate::cuda::moving_averages::{CudaCorrelationCycle, DeviceArrayF32};

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let slice_in = data_f32.as_slice()?;
    let sweep = CorrelationCycleBatchRange {
        period: period_range,
        threshold: threshold_range,
    };

    // The GPU work runs with the GIL released.
    let (quad, ctx, dev_id) = py.allow_threads(|| {
        let mut cuda = CudaCorrelationCycle::new(device_id)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.ctx();
        let dev_id = cuda.device_id();
        let quad = cuda
            .correlation_cycle_batch_dev(slice_in, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, pyo3::PyErr>((quad, ctx, dev_id))
    })?;

    // Every returned array carries its own handle to the shared context.
    let wrap = |arr: DeviceArrayF32| DeviceArrayF32CcPy::new_from_rust(arr, ctx.clone(), dev_id);
    Ok((
        wrap(quad.real),
        wrap(quad.imag),
        wrap(quad.angle),
        wrap(quad.state),
    ))
}
2954
// Python entry point `correlation_cycle_cuda_many_series_one_param_dev`: runs one
// (period, threshold) pair over a 2-D f32 input and returns the (real, imag, angle,
// state) device arrays.
//
// The input is passed to the CUDA wrapper as a flat slice plus `(cols, rows)` taken from
// `shape[1]` and `shape[0]`; the wrapper method name labels this layout "time major".
// Raises ValueError when CUDA is unavailable, the array is not 2-D, `rows * cols`
// overflows or does not match the flat length, or the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "correlation_cycle_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, threshold, device_id=0))]
pub fn correlation_cycle_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    period: usize,
    threshold: f64,
    device_id: usize,
) -> PyResult<(
    DeviceArrayF32CcPy,
    DeviceArrayF32CcPy,
    DeviceArrayF32CcPy,
    DeviceArrayF32CcPy,
)> {
    use crate::cuda::cuda_available;
    use crate::cuda::moving_averages::CudaCorrelationCycle;
    use numpy::PyUntypedArrayMethods;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let shape = data_tm_f32.shape();
    if shape.len() != 2 {
        return Err(PyValueError::new_err("expected 2D array"));
    }
    let rows = shape[0];
    let cols = shape[1];
    let flat = data_tm_f32.as_slice()?;
    let expected = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;
    if flat.len() != expected {
        return Err(PyValueError::new_err("time-major input length mismatch"));
    }
    let params = CorrelationCycleParams {
        period: Some(period),
        threshold: Some(threshold),
    };
    // The GPU work runs with the GIL released.
    let (quad, ctx, dev_id) = py.allow_threads(|| {
        let mut cuda = CudaCorrelationCycle::new(device_id)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.ctx();
        let dev_id = cuda.device_id();
        // The parameter struct is passed by reference; the mis-encoded `¶ms` token that
        // stood here has been restored to `&params`.
        let quad = cuda
            .correlation_cycle_many_series_one_param_time_major_dev(flat, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, pyo3::PyErr>((quad, ctx, dev_id))
    })?;
    Ok((
        DeviceArrayF32CcPy::new_from_rust(quad.real, ctx.clone(), dev_id),
        DeviceArrayF32CcPy::new_from_rust(quad.imag, ctx.clone(), dev_id),
        DeviceArrayF32CcPy::new_from_rust(quad.angle, ctx.clone(), dev_id),
        DeviceArrayF32CcPy::new_from_rust(quad.state, ctx, dev_id),
    ))
}
3010
// Python entry point `correlation_cycle`: computes the indicator for one f64 series and
// returns a dict with NumPy arrays under "real", "imag", "angle" and "state".
// `period` / `threshold` are forwarded as given (None leaves them unset in the params);
// an omitted `kernel` means `Kernel::Auto`. Indicator errors are raised as ValueError.
#[cfg(feature = "python")]
#[pyfunction(name = "correlation_cycle")]
#[pyo3(signature = (data, period=None, threshold=None, kernel=None))]
pub fn correlation_cycle_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period: Option<usize>,
    threshold: Option<f64>,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::PyArrayMethods;

    let data_slice = data.as_slice()?;

    // Only an explicitly named kernel goes through validation.
    let kern = if let Some(name) = kernel {
        crate::utilities::kernel_validation::validate_kernel(Some(name), false)?
    } else {
        Kernel::Auto
    };

    let input =
        CorrelationCycleInput::from_slice(data_slice, CorrelationCycleParams { period, threshold });

    // The computation runs with the GIL released.
    let CorrelationCycleOutput {
        real,
        imag,
        angle,
        state,
    } = py
        .allow_threads(|| correlation_cycle_with_kernel(&input, kern))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let result = PyDict::new(py);
    result.set_item("real", real.into_pyarray(py))?;
    result.set_item("imag", imag.into_pyarray(py))?;
    result.set_item("angle", angle.into_pyarray(py))?;
    result.set_item("state", state.into_pyarray(py))?;

    Ok(result)
}
3044
// Python entry point `correlation_cycle_batch`: evaluates every (period, threshold)
// combination of the sweep over one f64 series. Returns a dict with four
// `(rows, cols)` matrices ("real", "imag", "angle", "state") plus the per-row
// "periods" and "thresholds". Defaults: period range (20, 100, 1), threshold range
// (9.0, 9.0, 0.0). Grid, kernel and indicator errors are raised as ValueError.
#[cfg(feature = "python")]
#[pyfunction(name = "correlation_cycle_batch")]
#[pyo3(signature = (data, period_range=None, threshold_range=None, kernel=None))]
pub fn correlation_cycle_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period_range: Option<(usize, usize, usize)>,
    threshold_range: Option<(f64, f64, f64)>,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::PyArrayMethods;

    let series = data.as_slice()?;

    let sweep = CorrelationCycleBatchRange {
        period: period_range.unwrap_or((20, 100, 1)),
        threshold: threshold_range.unwrap_or((9.0, 9.0, 0.0)),
    };

    let combos = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let rows = combos.len();
    let cols = series.len();
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;

    // Flat output arrays are allocated without initialisation; the batch routine below
    // is handed the mutable slices to fill.
    let out_real = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let out_imag = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let out_angle = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let out_state = unsafe { PyArray1::<f64>::new(py, [total], false) };

    let real_buf = unsafe { out_real.as_slice_mut()? };
    let imag_buf = unsafe { out_imag.as_slice_mut()? };
    let angle_buf = unsafe { out_angle.as_slice_mut()? };
    let state_buf = unsafe { out_state.as_slice_mut()? };

    let kern = validate_kernel(kernel, true)?;
    py.allow_threads(|| {
        let batch_kernel = if let Kernel::Auto = kern {
            detect_best_batch_kernel()
        } else {
            kern
        };
        // Map the batch kernel to the per-row kernel; anything else falls back to scalar.
        let row_kernel = match batch_kernel {
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx512Batch => Kernel::Avx512,
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx2Batch => Kernel::Avx2,
            _ => Kernel::Scalar,
        };
        correlation_cycle_batch_inner_into(
            series, &sweep, row_kernel, true, real_buf, imag_buf, angle_buf, state_buf,
        )
    })
    .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let result = PyDict::new(py);
    result.set_item("real", out_real.reshape((rows, cols))?)?;
    result.set_item("imag", out_imag.reshape((rows, cols))?)?;
    result.set_item("angle", out_angle.reshape((rows, cols))?)?;
    result.set_item("state", out_state.reshape((rows, cols))?)?;

    let periods: Vec<u64> = combos.iter().map(|p| p.period.unwrap() as u64).collect();
    result.set_item("periods", periods.into_pyarray(py))?;

    let thresholds: Vec<f64> = combos.iter().map(|p| p.threshold.unwrap()).collect();
    result.set_item("thresholds", thresholds.into_pyarray(py))?;

    Ok(result)
}
3124
// Python class `CorrelationCycleStream`: thin wrapper that owns a Rust
// `CorrelationCycleStream` and forwards calls to it.
#[cfg(feature = "python")]
#[pyclass(name = "CorrelationCycleStream")]
pub struct CorrelationCycleStreamPy {
    // The wrapped streaming calculator.
    inner: CorrelationCycleStream,
}
3130
#[cfg(feature = "python")]
#[pymethods]
impl CorrelationCycleStreamPy {
    // Builds the stream from optional `period` / `threshold`; a construction error from
    // `CorrelationCycleStream::try_new` is raised as ValueError.
    #[new]
    #[pyo3(signature = (period=None, threshold=None))]
    pub fn new(period: Option<usize>, threshold: Option<f64>) -> PyResult<Self> {
        CorrelationCycleStream::try_new(CorrelationCycleParams { period, threshold })
            .map(|inner| Self { inner })
            .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    // Feeds one value to the wrapped stream and returns whatever it yields: `None` or a
    // 4-tuple of floats (presumably real, imag, angle, state — confirm against the stream).
    pub fn update(&mut self, value: f64) -> Option<(f64, f64, f64, f64)> {
        let stream = &mut self.inner;
        stream.update(value)
    }
}
3147
// Serializable result of `correlation_cycle_js`: the four output series, copied from
// the fields of `CorrelationCycleOutput`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct CorrelationCycleJsOutput {
    pub real: Vec<f64>,
    pub imag: Vec<f64>,
    pub angle: Vec<f64>,
    pub state: Vec<f64>,
}
3156
// wasm entry point: computes the indicator for `data` and returns the four series as a
// serialized `CorrelationCycleJsOutput`. Indicator and serialization failures are
// returned as string `JsValue`s.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn correlation_cycle_js(
    data: &[f64],
    period: Option<usize>,
    threshold: Option<f64>,
) -> Result<JsValue, JsValue> {
    let input =
        CorrelationCycleInput::from_slice(data, CorrelationCycleParams { period, threshold });

    let CorrelationCycleOutput {
        real,
        imag,
        angle,
        state,
    } = correlation_cycle(&input).map_err(|e| JsValue::from_str(&e.to_string()))?;

    let js_output = CorrelationCycleJsOutput {
        real,
        imag,
        angle,
        state,
    };

    serde_wasm_bindgen::to_value(&js_output)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
3179
// Serializable result of `correlation_cycle_batch_js`: the four flat output matrices,
// the parameter combinations, and the matrix dimensions, all copied from the batch output.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct CorrelationCycleBatchJsOutput {
    pub real: Vec<f64>,
    pub imag: Vec<f64>,
    pub angle: Vec<f64>,
    pub state: Vec<f64>,
    // Parameter sets of the sweep (presumably one per matrix row — confirm).
    pub combos: Vec<CorrelationCycleParams>,
    pub rows: usize,
    pub cols: usize,
}
3191
// wasm entry point: runs the batch sweep described by the period and threshold
// (start, end, step) triples over `data` and returns a serialized
// `CorrelationCycleBatchJsOutput`. Errors are returned as string `JsValue`s.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn correlation_cycle_batch_js(
    data: &[f64],
    period_start: usize,
    period_end: usize,
    period_step: usize,
    threshold_start: f64,
    threshold_end: f64,
    threshold_step: f64,
) -> Result<JsValue, JsValue> {
    let sweep = CorrelationCycleBatchRange {
        period: (period_start, period_end, period_step),
        threshold: (threshold_start, threshold_end, threshold_step),
    };

    let kernel = detect_best_kernel();
    let batch = match correlation_cycle_batch_inner(data, &sweep, kernel, false) {
        Ok(batch) => batch,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };

    let js_output = CorrelationCycleBatchJsOutput {
        real: batch.real,
        imag: batch.imag,
        angle: batch.angle,
        state: batch.state,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };

    serde_wasm_bindgen::to_value(&js_output)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
3224
// wasm helper: reserves room for `len` f64 values and returns the raw pointer without
// freeing it. The caller owns the allocation and is expected to release it with
// `correlation_cycle_free(ptr, len)`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn correlation_cycle_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop keeps the Vec's destructor from running, so the allocation outlives
    // this call.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
3233
// wasm helper: releases a buffer by rebuilding a Vec of length and capacity `len` from
// `ptr` and dropping it. A null pointer is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn correlation_cycle_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: relies on the caller passing a pointer/length pair that came from
    // `correlation_cycle_alloc(len)` and has not been freed already.
    unsafe {
        drop(Vec::from_raw_parts(ptr, len, len));
    }
}
3243
// wasm entry point: computes the indicator for `len` f64 values at `in_ptr` and writes
// the four series to `real_ptr`, `imag_ptr`, `angle_ptr` and `state_ptr`, each of which
// must be valid for `len` f64 values.
//
// Returns Err("Null pointer provided") when any pointer is null, or the indicator's
// error message when the computation fails.
//
// Aliasing: if any output range overlaps the input range, or two output ranges overlap
// each other, results are computed into temporary buffers and then copied out in the
// order real, imag, angle, state. The overlap test is range-based, not just pointer
// equality, so partially overlapping buffers never end up behind simultaneous `&` and
// `&mut` slices.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn correlation_cycle_into(
    in_ptr: *const f64,
    real_ptr: *mut f64,
    imag_ptr: *mut f64,
    angle_ptr: *mut f64,
    state_ptr: *mut f64,
    len: usize,
    period: Option<usize>,
    threshold: Option<f64>,
) -> Result<(), JsValue> {
    if in_ptr.is_null()
        || real_ptr.is_null()
        || imag_ptr.is_null()
        || angle_ptr.is_null()
        || state_ptr.is_null()
    {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // Byte length of each buffer; saturating so an absurd `len` cannot wrap around.
    let span = len.saturating_mul(std::mem::size_of::<f64>());
    // Half-open address ranges [a, a+span) and [b, b+span) intersect.
    let overlaps =
        |a: usize, b: usize| a < b.saturating_add(span) && b < a.saturating_add(span);

    let src = in_ptr as usize;
    let dsts = [
        real_ptr as usize,
        imag_ptr as usize,
        angle_ptr as usize,
        state_ptr as usize,
    ];
    let has_aliasing = dsts.iter().enumerate().any(|(i, &dst)| {
        overlaps(src, dst) || dsts[i + 1..].iter().any(|&other| overlaps(dst, other))
    });

    let params = CorrelationCycleParams { period, threshold };

    // SAFETY: all pointers are non-null (checked above) and the caller guarantees each is
    // valid for `len` f64 values. Mutable slices are only formed on the non-aliasing path,
    // where the five ranges are pairwise disjoint; the aliasing path writes through raw
    // pointers from buffers owned by this function.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let input = CorrelationCycleInput::from_slice(data, params);

        if has_aliasing {
            let mut temp_real = vec![0.0; len];
            let mut temp_imag = vec![0.0; len];
            let mut temp_angle = vec![0.0; len];
            let mut temp_state = vec![0.0; len];

            correlation_cycle_into_slices(
                &mut temp_real,
                &mut temp_imag,
                &mut temp_angle,
                &mut temp_state,
                &input,
                Kernel::Auto,
            )
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

            // The temporaries are fresh allocations, so each copy's source and
            // destination are disjoint even when destinations overlap one another.
            std::ptr::copy_nonoverlapping(temp_real.as_ptr(), real_ptr, len);
            std::ptr::copy_nonoverlapping(temp_imag.as_ptr(), imag_ptr, len);
            std::ptr::copy_nonoverlapping(temp_angle.as_ptr(), angle_ptr, len);
            std::ptr::copy_nonoverlapping(temp_state.as_ptr(), state_ptr, len);
        } else {
            let real_out = std::slice::from_raw_parts_mut(real_ptr, len);
            let imag_out = std::slice::from_raw_parts_mut(imag_ptr, len);
            let angle_out = std::slice::from_raw_parts_mut(angle_ptr, len);
            let state_out = std::slice::from_raw_parts_mut(state_ptr, len);

            correlation_cycle_into_slices(
                real_out,
                imag_out,
                angle_out,
                state_out,
                &input,
                Kernel::Auto,
            )
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(())
}