1#[cfg(all(feature = "python", feature = "cuda"))]
2use numpy::PyUntypedArrayMethods;
3#[cfg(feature = "python")]
4use numpy::{IntoPyArray, PyArray1};
5#[cfg(feature = "python")]
6use pyo3::exceptions::PyValueError;
7#[cfg(feature = "python")]
8use pyo3::prelude::*;
9#[cfg(feature = "python")]
10use pyo3::types::PyDict;
11
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use serde::{Deserialize, Serialize};
14#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
15use wasm_bindgen::prelude::*;
16
17use crate::utilities::data_loader::Candles;
18use crate::utilities::enums::Kernel;
19use crate::utilities::helpers::{
20 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
21 make_uninit_matrix,
22};
23#[cfg(feature = "python")]
24use crate::utilities::kernel_validation::validate_kernel;
25#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
26use core::arch::x86_64::*;
27#[cfg(not(target_arch = "wasm32"))]
28use rayon::prelude::*;
29use std::mem::{ManuallyDrop, MaybeUninit};
30#[cfg(all(feature = "python", feature = "cuda"))]
31use std::sync::Arc;
32use thiserror::Error;
33
34#[cfg(all(feature = "python", feature = "cuda"))]
35use crate::cuda::moving_averages::DeviceArrayF32;
36#[cfg(all(feature = "python", feature = "cuda"))]
37use cust::context::Context;
38#[cfg(all(feature = "python", feature = "cuda"))]
39use cust::memory::DeviceBuffer;
40
41#[derive(Debug, Clone)]
42pub enum CorrelHlData<'a> {
43 Candles { candles: &'a Candles },
44 Slices { high: &'a [f64], low: &'a [f64] },
45}
46
/// Output of a single-parameter `correl_hl` run.
#[derive(Clone, Debug)]
pub struct CorrelHlOutput {
    /// Correlation series produced by the computation.
    pub values: Vec<f64>,
}
51
/// Tunable parameters for `correl_hl`.
#[derive(Clone, Debug)]
pub struct CorrelHlParams {
    /// Rolling window length; `None` makes the consumers in this file fall back to 9.
    pub period: Option<usize>,
}
56
57impl Default for CorrelHlParams {
58 fn default() -> Self {
59 Self { period: Some(9) }
60 }
61}
62
63#[derive(Debug, Clone)]
64pub struct CorrelHlInput<'a> {
65 pub data: CorrelHlData<'a>,
66 pub params: CorrelHlParams,
67}
68
69impl<'a> CorrelHlInput<'a> {
70 #[inline]
71 pub fn from_candles(candles: &'a Candles, params: CorrelHlParams) -> Self {
72 Self {
73 data: CorrelHlData::Candles { candles },
74 params,
75 }
76 }
77
78 #[inline]
79 pub fn from_slices(high: &'a [f64], low: &'a [f64], params: CorrelHlParams) -> Self {
80 Self {
81 data: CorrelHlData::Slices { high, low },
82 params,
83 }
84 }
85
86 #[inline]
87 pub fn with_default_candles(candles: &'a Candles) -> Self {
88 Self::from_candles(candles, CorrelHlParams::default())
89 }
90
91 #[inline]
92 pub fn get_period(&self) -> usize {
93 self.params.period.unwrap_or(9)
94 }
95
96 #[inline(always)]
97 pub fn as_refs(&'a self) -> Result<(&'a [f64], &'a [f64]), CorrelHlError> {
98 match &self.data {
99 CorrelHlData::Candles { candles } => {
100 let hi = candles
101 .select_candle_field("high")
102 .map_err(|_| CorrelHlError::CandleFieldError { field: "high" })?;
103 let lo = candles
104 .select_candle_field("low")
105 .map_err(|_| CorrelHlError::CandleFieldError { field: "low" })?;
106 Ok((hi, lo))
107 }
108 CorrelHlData::Slices { high, low } => Ok((*high, *low)),
109 }
110 }
111
112 #[inline(always)]
113 pub fn period_or_default(&self) -> usize {
114 self.params.period.unwrap_or(9)
115 }
116}
117
118#[derive(Copy, Clone, Debug)]
119pub struct CorrelHlBuilder {
120 period: Option<usize>,
121 kernel: Kernel,
122}
123
124impl Default for CorrelHlBuilder {
125 fn default() -> Self {
126 Self {
127 period: None,
128 kernel: Kernel::Auto,
129 }
130 }
131}
132
133impl CorrelHlBuilder {
134 #[inline(always)]
135 pub fn new() -> Self {
136 Self::default()
137 }
138 #[inline(always)]
139 pub fn period(mut self, n: usize) -> Self {
140 self.period = Some(n);
141 self
142 }
143 #[inline(always)]
144 pub fn kernel(mut self, k: Kernel) -> Self {
145 self.kernel = k;
146 self
147 }
148
149 #[inline(always)]
150 pub fn apply(self, candles: &Candles) -> Result<CorrelHlOutput, CorrelHlError> {
151 let params = CorrelHlParams {
152 period: self.period,
153 };
154 let input = CorrelHlInput::from_candles(candles, params);
155 correl_hl_with_kernel(&input, self.kernel)
156 }
157
158 #[inline(always)]
159 pub fn apply_slices(self, high: &[f64], low: &[f64]) -> Result<CorrelHlOutput, CorrelHlError> {
160 let params = CorrelHlParams {
161 period: self.period,
162 };
163 let input = CorrelHlInput::from_slices(high, low, params);
164 correl_hl_with_kernel(&input, self.kernel)
165 }
166
167 #[inline(always)]
168 pub fn into_stream(self) -> Result<CorrelHlStream, CorrelHlError> {
169 let params = CorrelHlParams {
170 period: self.period,
171 };
172 CorrelHlStream::try_new(params)
173 }
174}
175
176#[derive(Debug, Error)]
177pub enum CorrelHlError {
178 #[error("correl_hl: Empty data (high or low).")]
179 EmptyInputData,
180 #[error("correl_hl: Invalid period: period = {period}, data length = {data_len}")]
181 InvalidPeriod { period: usize, data_len: usize },
182 #[error("correl_hl: Data length mismatch between high and low.")]
183 DataLengthMismatch,
184 #[error("correl_hl: Not enough valid data: needed = {needed}, valid = {valid}")]
185 NotEnoughValidData { needed: usize, valid: usize },
186 #[error("correl_hl: All values are NaN in high or low.")]
187 AllValuesNaN,
188 #[error("correl_hl: Candle field error: {field}")]
189 CandleFieldError { field: &'static str },
190 #[error("correl_hl: Output length mismatch (expected {expected}, got {got})")]
191 OutputLengthMismatch { expected: usize, got: usize },
192 #[error("correl_hl: invalid input: {0}")]
193 InvalidInput(&'static str),
194
195 #[error("correl_hl: invalid range: start={start} end={end} step={step}")]
196 InvalidRange {
197 start: usize,
198 end: usize,
199 step: usize,
200 },
201 #[error("correl_hl: invalid kernel for batch path: {0:?}")]
202 InvalidKernelForBatch(Kernel),
203}
204
205#[inline]
206pub fn correl_hl(input: &CorrelHlInput) -> Result<CorrelHlOutput, CorrelHlError> {
207 correl_hl_with_kernel(input, Kernel::Auto)
208}
209
210#[inline(always)]
211fn correl_hl_prepare<'a>(
212 input: &'a CorrelHlInput,
213 kernel: Kernel,
214) -> Result<(&'a [f64], &'a [f64], usize, usize, Kernel), CorrelHlError> {
215 let (high, low) = input.as_refs()?;
216 if high.is_empty() || low.is_empty() {
217 return Err(CorrelHlError::EmptyInputData);
218 }
219 if high.len() != low.len() {
220 return Err(CorrelHlError::DataLengthMismatch);
221 }
222
223 let period = input.period_or_default();
224 if period == 0 || period > high.len() {
225 return Err(CorrelHlError::InvalidPeriod {
226 period,
227 data_len: high.len(),
228 });
229 }
230
231 let first = high
232 .iter()
233 .zip(low.iter())
234 .position(|(&h, &l)| !h.is_nan() && !l.is_nan())
235 .ok_or(CorrelHlError::AllValuesNaN)?;
236
237 if high.len() - first < period {
238 return Err(CorrelHlError::NotEnoughValidData {
239 needed: period,
240 valid: high.len() - first,
241 });
242 }
243
244 let chosen = match kernel {
245 Kernel::Auto => detect_best_kernel(),
246 k => k,
247 };
248 Ok((high, low, period, first, chosen))
249}
250
251#[inline(always)]
252fn correl_hl_compute_into(
253 high: &[f64],
254 low: &[f64],
255 period: usize,
256 first: usize,
257 kern: Kernel,
258 out: &mut [f64],
259) {
260 unsafe {
261 match kern {
262 Kernel::Scalar | Kernel::ScalarBatch => correl_hl_scalar(high, low, period, first, out),
263 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
264 Kernel::Avx2 | Kernel::Avx2Batch => correl_hl_avx2(high, low, period, first, out),
265 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
266 Kernel::Avx512 | Kernel::Avx512Batch => correl_hl_avx512(high, low, period, first, out),
267 _ => correl_hl_scalar(high, low, period, first, out),
268 }
269 }
270}
271
272pub fn correl_hl_with_kernel(
273 input: &CorrelHlInput,
274 kernel: Kernel,
275) -> Result<CorrelHlOutput, CorrelHlError> {
276 let (high, low, period, first, chosen) = correl_hl_prepare(input, kernel)?;
277 let warm = first + period - 1;
278 let mut out = alloc_with_nan_prefix(high.len(), warm);
279 correl_hl_compute_into(high, low, period, first, chosen, &mut out);
280 Ok(CorrelHlOutput { values: out })
281}
282
283#[inline]
284pub fn correl_hl_into_slice(
285 dst: &mut [f64],
286 input: &CorrelHlInput,
287 kernel: Kernel,
288) -> Result<(), CorrelHlError> {
289 let (high, low, period, first, chosen) = correl_hl_prepare(input, kernel)?;
290 if dst.len() != high.len() {
291 return Err(CorrelHlError::OutputLengthMismatch {
292 expected: high.len(),
293 got: dst.len(),
294 });
295 }
296 correl_hl_compute_into(high, low, period, first, chosen, dst);
297 let warm = first + period - 1;
298 for v in &mut dst[..warm] {
299 *v = f64::from_bits(0x7ff8_0000_0000_0000);
300 }
301 Ok(())
302}
303
304#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
305#[inline]
306pub fn correl_hl_into(out: &mut [f64], input: &CorrelHlInput) -> Result<(), CorrelHlError> {
307 let (high, _low, period, first, _chosen) = correl_hl_prepare(input, Kernel::Auto)?;
308 if out.len() != high.len() {
309 return Err(CorrelHlError::OutputLengthMismatch {
310 expected: high.len(),
311 got: out.len(),
312 });
313 }
314
315 let warm = first + period - 1;
316 let warm_cap = warm.min(out.len());
317 for v in &mut out[..warm_cap] {
318 *v = f64::from_bits(0x7ff8_0000_0000_0000);
319 }
320
321 correl_hl_into_slice(out, input, Kernel::Auto)
322}
323
/// Scalar kernel: rolling Pearson correlation between `high` and `low`.
///
/// Writes `out[first + period - 1..high.len()]`; earlier entries of `out` are
/// left untouched. The window sums are updated incrementally, except when the
/// bar entering or leaving the window has a NaN in either series, in which case
/// the whole window is re-summed.
///
/// A window whose variance term is `<= 0.0` for either series yields `0.0`.
///
/// # Panics
/// Panics on out-of-range indexing when `first + period` exceeds the input
/// length, when `low` or `out` is shorter than `high`, or (in builds with
/// overflow checks) when `first + period` is zero.
#[inline]
pub fn correl_hl_scalar(high: &[f64], low: &[f64], period: usize, first: usize, out: &mut [f64]) {
    /// Running sums over one window.
    #[derive(Clone, Copy)]
    struct Sums {
        h: f64,
        h2: f64,
        l: f64,
        l2: f64,
        hl: f64,
    }

    /// Sums `high`/`low` over `start..end`, four bars at a time then the tail.
    #[inline(always)]
    fn window_sums(high: &[f64], low: &[f64], start: usize, end: usize) -> Sums {
        let mut acc = Sums {
            h: 0.0,
            h2: 0.0,
            l: 0.0,
            l2: 0.0,
            hl: 0.0,
        };

        let mut idx = start;
        while idx + 4 <= end {
            let (a0, b0) = (high[idx], low[idx]);
            let (a1, b1) = (high[idx + 1], low[idx + 1]);
            let (a2, b2) = (high[idx + 2], low[idx + 2]);
            let (a3, b3) = (high[idx + 3], low[idx + 3]);

            acc.h += a0 + a1 + a2 + a3;
            acc.l += b0 + b1 + b2 + b3;
            acc.h2 += a0 * a0 + a1 * a1 + a2 * a2 + a3 * a3;
            acc.l2 += b0 * b0 + b1 * b1 + b2 * b2 + b3 * b3;
            acc.hl += a0 * b0 + a1 * b1 + a2 * b2 + a3 * b3;
            idx += 4;
        }
        for tail in idx..end {
            let (a, b) = (high[tail], low[tail]);
            acc.h += a;
            acc.l += b;
            acc.h2 += a * a;
            acc.l2 += b * b;
            acc.hl += a * b;
        }
        acc
    }

    /// Turns window sums into a correlation; `inv_pf` is `1 / period`.
    #[inline(always)]
    fn pearson(s: Sums, inv_pf: f64) -> f64 {
        let cov = s.hl - (s.h * s.l) * inv_pf;
        let var_h = s.h2 - (s.h * s.h) * inv_pf;
        let var_l = s.l2 - (s.l * s.l) * inv_pf;
        // Kept in this exact form: a NaN variance must fall through to the division.
        if var_h <= 0.0 || var_l <= 0.0 {
            0.0
        } else {
            cov / (var_h.sqrt() * var_l.sqrt())
        }
    }

    let inv_pf = 1.0 / (period as f64);
    let init_end = first + period;

    let mut sums = window_sums(high, low, first, init_end);
    out[init_end - 1] = pearson(sums, inv_pf);

    for i in init_end..high.len() {
        let leaving = i - period;
        let (old_h, old_l) = (high[leaving], low[leaving]);
        let (new_h, new_l) = (high[i], low[i]);

        let touches_nan = old_h.is_nan() || old_l.is_nan() || new_h.is_nan() || new_l.is_nan();
        if touches_nan {
            // Incremental sums cannot recover from NaN; rebuild the window.
            sums = window_sums(high, low, i + 1 - period, i + 1);
        } else {
            sums.h += new_h - old_h;
            sums.l += new_l - old_l;
            sums.h2 += new_h * new_h - old_h * old_h;
            sums.l2 += new_l * new_l - old_l * old_l;
            let old_hl = old_h * old_l;
            sums.hl = new_h.mul_add(new_l, sums.hl - old_hl);
        }

        out[i] = pearson(sums, inv_pf);
    }
}
444
/// AVX2 kernel: rolling Pearson correlation between `high` and `low`.
///
/// Writes `out[first + period - 1..high.len()]`; earlier entries of `out` are
/// left untouched. Window sums are updated incrementally and re-summed with
/// 256-bit loads whenever the bar entering or leaving the window has a NaN in
/// either series.
///
/// This is a safe function that performs unchecked loads internally, so the
/// preconditions those loads rely on are asserted up front instead of being
/// left to the caller.
///
/// NOTE(review): the function does not verify that the CPU supports AVX2;
/// presumably callers only reach it through kernel detection — confirm.
///
/// # Panics
/// Panics when `first + period` overflows or exceeds `high.len()`, when `low`
/// is shorter than `high`, or when `out` is too short for the written indices.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn correl_hl_avx2(high: &[f64], low: &[f64], period: usize, first: usize, out: &mut [f64]) {
    /// Adds the four lanes of `v`.
    #[inline(always)]
    unsafe fn hsum256_pd(v: __m256d) -> f64 {
        let hi: __m128d = _mm256_extractf128_pd(v, 1);
        let lo: __m128d = _mm256_castpd256_pd128(v);
        let sum2 = _mm_add_pd(lo, hi);
        let shuf = _mm_unpackhi_pd(sum2, sum2);
        let sum1 = _mm_add_sd(sum2, shuf);
        _mm_cvtsd_f64(sum1)
    }

    /// Returns `(Σh, Σh², Σl, Σl², Σhl)` over `start..end`.
    ///
    /// # Safety
    /// `start..end` must lie within both `high` and `low`.
    #[inline(always)]
    unsafe fn sum_window_avx2(
        high: &[f64],
        low: &[f64],
        start: usize,
        end: usize,
    ) -> (f64, f64, f64, f64, f64) {
        let mut v_h = _mm256_setzero_pd();
        let mut v_l = _mm256_setzero_pd();
        let mut v_h2 = _mm256_setzero_pd();
        let mut v_l2 = _mm256_setzero_pd();
        let mut v_hl = _mm256_setzero_pd();

        let mut i = start;
        let ptr_h = high.as_ptr();
        let ptr_l = low.as_ptr();

        // Four bars per iteration.
        while i + 4 <= end {
            let mh = _mm256_loadu_pd(ptr_h.add(i));
            let ml = _mm256_loadu_pd(ptr_l.add(i));
            v_h = _mm256_add_pd(v_h, mh);
            v_l = _mm256_add_pd(v_l, ml);
            let mh2 = _mm256_mul_pd(mh, mh);
            let ml2 = _mm256_mul_pd(ml, ml);
            v_h2 = _mm256_add_pd(v_h2, mh2);
            v_l2 = _mm256_add_pd(v_l2, ml2);
            let mhl = _mm256_mul_pd(mh, ml);
            v_hl = _mm256_add_pd(v_hl, mhl);
            i += 4;
        }

        let mut sum_h = hsum256_pd(v_h);
        let mut sum_l = hsum256_pd(v_l);
        let mut sum_h2 = hsum256_pd(v_h2);
        let mut sum_l2 = hsum256_pd(v_l2);
        let mut sum_hl = hsum256_pd(v_hl);

        // Scalar tail for the last `(end - start) % 4` bars.
        while i < end {
            let h = *high.get_unchecked(i);
            let l = *low.get_unchecked(i);
            sum_h += h;
            sum_l += l;
            sum_h2 += h * h;
            sum_l2 += l * l;
            sum_hl += h * l;
            i += 1;
        }
        (sum_h, sum_h2, sum_l, sum_l2, sum_hl)
    }

    /// Turns window sums into a correlation; `inv_pf` is `1 / period`.
    #[inline(always)]
    fn corr_from_sums(
        sum_h: f64,
        sum_h2: f64,
        sum_l: f64,
        sum_l2: f64,
        sum_hl: f64,
        inv_pf: f64,
    ) -> f64 {
        let cov = sum_hl - (sum_h * sum_l) * inv_pf;
        let varh = sum_h2 - (sum_h * sum_h) * inv_pf;
        let varl = sum_l2 - (sum_l * sum_l) * inv_pf;
        if varh <= 0.0 || varl <= 0.0 {
            0.0
        } else {
            cov / (varh.sqrt() * varl.sqrt())
        }
    }

    let n = high.len();
    let init_end = first
        .checked_add(period)
        .expect("correl_hl_avx2: first + period overflows usize");
    // These two checks are what make the unchecked accesses below sound.
    assert!(
        init_end <= n,
        "correl_hl_avx2: first + period exceeds the input length"
    );
    assert!(low.len() >= n, "correl_hl_avx2: `low` is shorter than `high`");

    let inv_pf = 1.0 / (period as f64);

    // SAFETY: `first..init_end` lies within `high`, and `low` is at least as long.
    let (mut sum_h, mut sum_h2, mut sum_l, mut sum_l2, mut sum_hl) =
        unsafe { sum_window_avx2(high, low, first, init_end) };

    let warm = init_end - 1;
    out[warm] = corr_from_sums(sum_h, sum_h2, sum_l, sum_l2, sum_hl, inv_pf);

    for i in init_end..n {
        // SAFETY: `first <= i - period` and `i < n <= low.len()`.
        let (old_h, old_l, new_h, new_l) = unsafe {
            (
                *high.get_unchecked(i - period),
                *low.get_unchecked(i - period),
                *high.get_unchecked(i),
                *low.get_unchecked(i),
            )
        };

        if old_h.is_nan() || old_l.is_nan() || new_h.is_nan() || new_l.is_nan() {
            // SAFETY: the window `i + 1 - period..i + 1` ends at or before `n`.
            let (sh, sh2, sl, sl2, shl) =
                unsafe { sum_window_avx2(high, low, i + 1 - period, i + 1) };
            sum_h = sh;
            sum_h2 = sh2;
            sum_l = sl;
            sum_l2 = sl2;
            sum_hl = shl;
        } else {
            sum_h += new_h - old_h;
            sum_l += new_l - old_l;
            sum_h2 += new_h * new_h - old_h * old_h;
            sum_l2 += new_l * new_l - old_l * old_l;
            let old_hl = old_h * old_l;
            sum_hl = new_h.mul_add(new_l, sum_hl - old_hl);
        }

        out[i] = corr_from_sums(sum_h, sum_h2, sum_l, sum_l2, sum_hl, inv_pf);
    }
}
567
/// AVX-512 kernel entry point: picks the short or long variant by `period`.
///
/// Periods up to 32 go to `correl_hl_avx512_short`, longer ones to
/// `correl_hl_avx512_long`. Both variants read the inputs without bounds
/// checks, so this safe wrapper asserts the preconditions they rely on before
/// calling them.
///
/// NOTE(review): the function does not verify that the CPU supports AVX-512;
/// presumably callers only reach it through kernel detection — confirm.
///
/// # Panics
/// Panics when `first + period` overflows or exceeds `high.len()`, or when
/// `low` is shorter than `high`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn correl_hl_avx512(high: &[f64], low: &[f64], period: usize, first: usize, out: &mut [f64]) {
    let window_fits = first
        .checked_add(period)
        .map_or(false, |init_end| init_end <= high.len());
    assert!(
        window_fits,
        "correl_hl_avx512: first + period exceeds the input length"
    );
    assert!(
        low.len() >= high.len(),
        "correl_hl_avx512: `low` is shorter than `high`"
    );

    // SAFETY: the assertions above guarantee every window the kernels read lies
    // within both `high` and `low`; writes to `out` are bounds-checked.
    unsafe {
        if period <= 32 {
            correl_hl_avx512_short(high, low, period, first, out)
        } else {
            correl_hl_avx512_long(high, low, period, first, out)
        }
    }
}
577
/// Short-period AVX-512 variant; currently forwards to `correl_hl_avx512_long`.
///
/// # Safety
/// Same contract as `correl_hl_avx512_long`, which performs unchecked reads of
/// `high` and `low`: `first + period` must not exceed `high.len()` and `low`
/// must be at least as long as `high`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn correl_hl_avx512_short(
    high: &[f64],
    low: &[f64],
    period: usize,
    first: usize,
    out: &mut [f64],
) {
    // SAFETY: the caller upholds the contract documented above.
    unsafe { correl_hl_avx512_long(high, low, period, first, out) }
}
589
/// AVX-512 kernel: rolling Pearson correlation between `high` and `low`.
///
/// Writes `out[first + period - 1..high.len()]` using bounds-checked stores.
/// Window sums are updated incrementally and re-summed with 512-bit loads
/// (masked for the tail) whenever the bar entering or leaving the window has a
/// NaN in either series.
///
/// # Safety
/// `high` and `low` are read without bounds checks. The caller must ensure
/// `first + period <= high.len()` and `low.len() >= high.len()`.
/// NOTE(review): AVX-512 CPU support is presumably also required — confirm.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn correl_hl_avx512_long(
    high: &[f64],
    low: &[f64],
    period: usize,
    first: usize,
    out: &mut [f64],
) {
    /// Adds the four lanes of a 256-bit vector.
    #[inline(always)]
    unsafe fn reduce_add_256(v: __m256d) -> f64 {
        let upper: __m128d = _mm256_extractf128_pd(v, 1);
        let lower: __m128d = _mm256_castpd256_pd128(v);
        let pair = _mm_add_pd(lower, upper);
        let high_lane = _mm_unpackhi_pd(pair, pair);
        _mm_cvtsd_f64(_mm_add_sd(pair, high_lane))
    }

    /// Adds the eight lanes of a 512-bit vector.
    #[inline(always)]
    unsafe fn reduce_add_512(v: __m512d) -> f64 {
        let lower: __m256d = _mm512_castpd512_pd256(v);
        let upper: __m256d = _mm512_extractf64x4_pd(v, 1);
        reduce_add_256(_mm256_add_pd(lower, upper))
    }

    /// Returns `(Σh, Σh², Σl, Σl², Σhl)` over `start..end`.
    #[inline(always)]
    unsafe fn window_sums(
        high: &[f64],
        low: &[f64],
        start: usize,
        end: usize,
    ) -> (f64, f64, f64, f64, f64) {
        let mut acc_h = _mm512_setzero_pd();
        let mut acc_l = _mm512_setzero_pd();
        let mut acc_h2 = _mm512_setzero_pd();
        let mut acc_l2 = _mm512_setzero_pd();
        let mut acc_hl = _mm512_setzero_pd();

        let base_h = high.as_ptr();
        let base_l = low.as_ptr();

        // Eight bars per iteration.
        let mut pos = start;
        while pos + 8 <= end {
            let vh = _mm512_loadu_pd(base_h.add(pos));
            let vl = _mm512_loadu_pd(base_l.add(pos));
            acc_h = _mm512_add_pd(acc_h, vh);
            acc_l = _mm512_add_pd(acc_l, vl);
            let vh_sq = _mm512_mul_pd(vh, vh);
            let vl_sq = _mm512_mul_pd(vl, vl);
            acc_h2 = _mm512_add_pd(acc_h2, vh_sq);
            acc_l2 = _mm512_add_pd(acc_l2, vl_sq);
            let cross = _mm512_mul_pd(vh, vl);
            acc_hl = _mm512_add_pd(acc_hl, cross);
            pos += 8;
        }

        // Remaining 1..=7 bars: masked load zeroes the unused lanes.
        let tail = (end - pos) as i32;
        if tail != 0 {
            let mask: __mmask8 = ((1u16 << tail) - 1) as __mmask8;
            let vh = _mm512_maskz_loadu_pd(mask, base_h.add(pos));
            let vl = _mm512_maskz_loadu_pd(mask, base_l.add(pos));
            acc_h = _mm512_add_pd(acc_h, vh);
            acc_l = _mm512_add_pd(acc_l, vl);
            acc_h2 = _mm512_add_pd(acc_h2, _mm512_mul_pd(vh, vh));
            acc_l2 = _mm512_add_pd(acc_l2, _mm512_mul_pd(vl, vl));
            acc_hl = _mm512_add_pd(acc_hl, _mm512_mul_pd(vh, vl));
        }

        (
            reduce_add_512(acc_h),
            reduce_add_512(acc_h2),
            reduce_add_512(acc_l),
            reduce_add_512(acc_l2),
            reduce_add_512(acc_hl),
        )
    }

    /// Turns window sums `(Σh, Σh², Σl, Σl², Σhl)` into a correlation.
    #[inline(always)]
    fn pearson(sums: (f64, f64, f64, f64, f64), inv_pf: f64) -> f64 {
        let (s_h, s_h2, s_l, s_l2, s_hl) = sums;
        let cov = s_hl - (s_h * s_l) * inv_pf;
        let var_h = s_h2 - (s_h * s_h) * inv_pf;
        let var_l = s_l2 - (s_l * s_l) * inv_pf;
        // Kept in this exact form: a NaN variance must fall through to the division.
        if var_h <= 0.0 || var_l <= 0.0 {
            0.0
        } else {
            cov / (var_h.sqrt() * var_l.sqrt())
        }
    }

    let inv_pf = 1.0 / (period as f64);
    let init_end = first + period;

    let mut sums = window_sums(high, low, first, init_end);
    out[init_end - 1] = pearson(sums, inv_pf);

    for i in init_end..high.len() {
        let leaving = i - period;
        let old_h = *high.get_unchecked(leaving);
        let old_l = *low.get_unchecked(leaving);
        let new_h = *high.get_unchecked(i);
        let new_l = *low.get_unchecked(i);

        let touches_nan = old_h.is_nan() || old_l.is_nan() || new_h.is_nan() || new_l.is_nan();
        if touches_nan {
            // Incremental sums cannot recover from NaN; rebuild the window.
            sums = window_sums(high, low, i + 1 - period, i + 1);
        } else {
            sums.0 += new_h - old_h;
            sums.2 += new_l - old_l;
            sums.1 += new_h * new_h - old_h * old_h;
            sums.3 += new_l * new_l - old_l * old_l;
            let old_hl = old_h * old_l;
            sums.4 = new_h.mul_add(new_l, sums.4 - old_hl);
        }

        out[i] = pearson(sums, inv_pf);
    }
}
725
/// Streaming calculator that keeps the last `period` bars in ring buffers.
#[derive(Clone, Debug)]
pub struct CorrelHlStream {
    /// Window length.
    period: usize,
    /// Ring buffer of the most recent highs.
    buffer_high: Vec<f64>,
    /// Ring buffer of the most recent lows.
    buffer_low: Vec<f64>,
    /// Slot that the next update writes to.
    head: usize,
    /// Number of bars stored so far, capped at `period`.
    len: usize,
    /// Count of stored bars whose high or low is NaN.
    nan_in_win: usize,

    // Running sums over the non-NaN bars currently stored.
    sum_h: f64,
    sum_h2: f64,
    sum_l: f64,
    sum_l2: f64,
    sum_hl: f64,

    /// Cached `1 / period`.
    inv_pf: f64,
}
743
744impl CorrelHlStream {
745 #[inline]
746 pub fn try_new(params: CorrelHlParams) -> Result<Self, CorrelHlError> {
747 let period = params.period.unwrap_or(9);
748 if period == 0 {
749 return Err(CorrelHlError::InvalidPeriod {
750 period,
751 data_len: 0,
752 });
753 }
754
755 Ok(Self {
756 period,
757 buffer_high: vec![f64::NAN; period],
758 buffer_low: vec![f64::NAN; period],
759 head: 0,
760 len: 0,
761 nan_in_win: 0,
762 sum_h: 0.0,
763 sum_h2: 0.0,
764 sum_l: 0.0,
765 sum_l2: 0.0,
766 sum_hl: 0.0,
767 inv_pf: 1.0 / (period as f64),
768 })
769 }
770
771 #[inline(always)]
772 pub fn update(&mut self, h: f64, l: f64) -> Option<f64> {
773 if self.len == self.period {
774 let old_h = self.buffer_high[self.head];
775 let old_l = self.buffer_low[self.head];
776
777 if old_h.is_nan() || old_l.is_nan() {
778 if self.nan_in_win > 0 {
779 self.nan_in_win -= 1;
780 }
781 } else {
782 self.sum_h -= old_h;
783 self.sum_l -= old_l;
784 self.sum_h2 -= old_h * old_h;
785 self.sum_l2 -= old_l * old_l;
786 self.sum_hl -= old_h * old_l;
787 }
788 }
789
790 self.buffer_high[self.head] = h;
791 self.buffer_low[self.head] = l;
792
793 if h.is_nan() || l.is_nan() {
794 self.nan_in_win += 1;
795 } else {
796 self.sum_h += h;
797 self.sum_l += l;
798 self.sum_h2 += h * h;
799 self.sum_l2 += l * l;
800 self.sum_hl += h * l;
801 }
802
803 self.head += 1;
804 if self.head == self.period {
805 self.head = 0;
806 }
807 if self.len < self.period {
808 self.len += 1;
809 }
810
811 if self.len < self.period {
812 return None;
813 }
814 if self.nan_in_win != 0 {
815 return Some(f64::NAN);
816 }
817
818 let cov = self.sum_hl - (self.sum_h * self.sum_l) * self.inv_pf;
819 let var_h = self.sum_h2 - (self.sum_h * self.sum_h) * self.inv_pf;
820 let var_l = self.sum_l2 - (self.sum_l * self.sum_l) * self.inv_pf;
821
822 if var_h <= 0.0 || var_l <= 0.0 {
823 return Some(0.0);
824 }
825
826 let denom = (var_h * var_l).sqrt();
827 Some(cov / denom)
828 }
829}
830
/// Parameter sweep description for the batch API.
#[derive(Debug, Clone)]
pub struct CorrelHlBatchRange {
    /// `(start, end, step)` for the period axis, expanded by `expand_grid`.
    pub period: (usize, usize, usize),
}
835
836impl Default for CorrelHlBatchRange {
837 fn default() -> Self {
838 Self {
839 period: (9, 258, 1),
840 }
841 }
842}
843
/// Serializable batch configuration, compiled only for the wasm build.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Deserialize, Serialize)]
pub struct CorrelHlBatchConfig {
    /// `(start, end, step)` for the period sweep.
    pub period_range: (usize, usize, usize),
}
849
/// Serializable batch result, compiled only for the wasm build.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Deserialize, Serialize)]
pub struct CorrelHlBatchJsOutput {
    /// Output values for all parameter rows.
    pub values: Vec<f64>,
    /// Period associated with each row.
    pub periods: Vec<usize>,
    /// Number of parameter rows.
    pub rows: usize,
    /// Number of values per row.
    pub cols: usize,
}
858
859#[derive(Clone, Debug, Default)]
860pub struct CorrelHlBatchBuilder {
861 range: CorrelHlBatchRange,
862 kernel: Kernel,
863}
864
865impl CorrelHlBatchBuilder {
866 pub fn new() -> Self {
867 Self::default()
868 }
869 pub fn kernel(mut self, k: Kernel) -> Self {
870 self.kernel = k;
871 self
872 }
873
874 #[inline]
875 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
876 self.range.period = (start, end, step);
877 self
878 }
879 #[inline]
880 pub fn period_static(mut self, p: usize) -> Self {
881 self.range.period = (p, p, 0);
882 self
883 }
884
885 pub fn apply_slices(
886 self,
887 high: &[f64],
888 low: &[f64],
889 ) -> Result<CorrelHlBatchOutput, CorrelHlError> {
890 correl_hl_batch_with_kernel(high, low, &self.range, self.kernel)
891 }
892
893 pub fn apply_candles(self, c: &Candles) -> Result<CorrelHlBatchOutput, CorrelHlError> {
894 let high = c
895 .select_candle_field("high")
896 .map_err(|_| CorrelHlError::EmptyInputData)?;
897 let low = c
898 .select_candle_field("low")
899 .map_err(|_| CorrelHlError::EmptyInputData)?;
900 self.apply_slices(high, low)
901 }
902}
903
904pub fn expand_grid(r: &CorrelHlBatchRange) -> Result<Vec<CorrelHlParams>, CorrelHlError> {
905 fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, CorrelHlError> {
906 if step == 0 || start == end {
907 return Ok(vec![start]);
908 }
909 if start < end {
910 let mut v = Vec::new();
911 let mut x = start;
912 while x <= end {
913 v.push(x);
914 match x.checked_add(step) {
915 Some(nx) if nx > x => x = nx,
916 _ => break,
917 }
918 }
919 if v.is_empty() {
920 return Err(CorrelHlError::InvalidRange { start, end, step });
921 }
922 Ok(v)
923 } else {
924 let mut v = Vec::new();
925 let mut x = start;
926 while x >= end {
927 v.push(x);
928 if x < end + step {
929 break;
930 }
931 x = x.saturating_sub(step);
932 if x == 0 {
933 break;
934 }
935 }
936 if v.is_empty() {
937 return Err(CorrelHlError::InvalidRange { start, end, step });
938 }
939 Ok(v)
940 }
941 }
942 let periods = axis_usize(r.period)?;
943 if periods.is_empty() {
944 return Err(CorrelHlError::InvalidRange {
945 start: r.period.0,
946 end: r.period.1,
947 step: r.period.2,
948 });
949 }
950 let mut out = Vec::with_capacity(periods.len());
951 for &p in &periods {
952 out.push(CorrelHlParams { period: Some(p) });
953 }
954 Ok(out)
955}
956
957#[derive(Clone, Debug)]
958pub struct CorrelHlBatchOutput {
959 pub values: Vec<f64>,
960 pub combos: Vec<CorrelHlParams>,
961 pub rows: usize,
962 pub cols: usize,
963}
964
965impl CorrelHlBatchOutput {
966 pub fn row_for_params(&self, p: &CorrelHlParams) -> Option<usize> {
967 self.combos
968 .iter()
969 .position(|c| c.period.unwrap_or(9) == p.period.unwrap_or(9))
970 }
971
972 pub fn values_for(&self, p: &CorrelHlParams) -> Option<&[f64]> {
973 self.row_for_params(p).map(|row| {
974 let start = row * self.cols;
975 &self.values[start..start + self.cols]
976 })
977 }
978}
979
980pub fn correl_hl_batch_with_kernel(
981 high: &[f64],
982 low: &[f64],
983 sweep: &CorrelHlBatchRange,
984 k: Kernel,
985) -> Result<CorrelHlBatchOutput, CorrelHlError> {
986 let kernel = match k {
987 Kernel::Auto => detect_best_batch_kernel(),
988 other if other.is_batch() => other,
989 other => return Err(CorrelHlError::InvalidKernelForBatch(other)),
990 };
991
992 let simd = match kernel {
993 Kernel::Avx512Batch => Kernel::Avx512,
994 Kernel::Avx2Batch => Kernel::Avx2,
995 Kernel::ScalarBatch => Kernel::Scalar,
996 _ => unreachable!(),
997 };
998 correl_hl_batch_par_slice(high, low, sweep, simd)
999}
1000
1001#[inline(always)]
1002pub fn correl_hl_batch_slice(
1003 high: &[f64],
1004 low: &[f64],
1005 sweep: &CorrelHlBatchRange,
1006 kern: Kernel,
1007) -> Result<CorrelHlBatchOutput, CorrelHlError> {
1008 correl_hl_batch_inner(high, low, sweep, kern, false)
1009}
1010
1011#[inline(always)]
1012pub fn correl_hl_batch_par_slice(
1013 high: &[f64],
1014 low: &[f64],
1015 sweep: &CorrelHlBatchRange,
1016 kern: Kernel,
1017) -> Result<CorrelHlBatchOutput, CorrelHlError> {
1018 correl_hl_batch_inner(high, low, sweep, kern, true)
1019}
1020
1021#[inline(always)]
1022fn correl_hl_batch_inner(
1023 high: &[f64],
1024 low: &[f64],
1025 sweep: &CorrelHlBatchRange,
1026 kern: Kernel,
1027 parallel: bool,
1028) -> Result<CorrelHlBatchOutput, CorrelHlError> {
1029 let combos = expand_grid(sweep)?;
1030
1031 let first = high
1032 .iter()
1033 .zip(low.iter())
1034 .position(|(&h, &l)| !h.is_nan() && !l.is_nan())
1035 .ok_or(CorrelHlError::AllValuesNaN)?;
1036
1037 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1038 if high.len() - first < max_p {
1039 return Err(CorrelHlError::NotEnoughValidData {
1040 needed: max_p,
1041 valid: high.len() - first,
1042 });
1043 }
1044
1045 let rows = combos.len();
1046 let cols = high.len();
1047
1048 rows.checked_mul(cols)
1049 .ok_or(CorrelHlError::InvalidInput("rows*cols overflow"))?;
1050
1051 let warm: Vec<usize> = combos
1052 .iter()
1053 .map(|c| first + c.period.unwrap() - 1)
1054 .collect();
1055
1056 let mut buf_mu = make_uninit_matrix(rows, cols);
1057
1058 init_matrix_prefixes(&mut buf_mu, cols, &warm);
1059
1060 let mut buf_guard = ManuallyDrop::new(buf_mu);
1061 let values_slice: &mut [f64] = unsafe {
1062 core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
1063 };
1064
1065 let n = high.len();
1066 let mut ps_h = vec![0.0f64; n + 1];
1067 let mut ps_h2 = vec![0.0f64; n + 1];
1068 let mut ps_l = vec![0.0f64; n + 1];
1069 let mut ps_l2 = vec![0.0f64; n + 1];
1070 let mut ps_hl = vec![0.0f64; n + 1];
1071 let mut ps_nan = vec![0i32; n + 1];
1072 for i in 0..n {
1073 let h = high[i];
1074 let l = low[i];
1075 let (ph, ph2, pl, pl2, phl) = (ps_h[i], ps_h2[i], ps_l[i], ps_l2[i], ps_hl[i]);
1076 if h.is_nan() || l.is_nan() {
1077 ps_h[i + 1] = ph;
1078 ps_h2[i + 1] = ph2;
1079 ps_l[i + 1] = pl;
1080 ps_l2[i + 1] = pl2;
1081 ps_hl[i + 1] = phl;
1082 ps_nan[i + 1] = ps_nan[i] + 1;
1083 } else {
1084 ps_h[i + 1] = ph + h;
1085 ps_h2[i + 1] = ph2 + h * h;
1086 ps_l[i + 1] = pl + l;
1087 ps_l2[i + 1] = pl2 + l * l;
1088 ps_hl[i + 1] = phl + h * l;
1089 ps_nan[i + 1] = ps_nan[i];
1090 }
1091 }
1092
1093 let do_row = |row: usize, out_row: &mut [f64]| {
1094 let p = combos[row].period.unwrap();
1095 let inv_pf = 1.0 / (p as f64);
1096 let warm = first + p - 1;
1097 for i in warm..n {
1098 let end = i + 1;
1099 let start = end - p;
1100 let nan_w = ps_nan[end] - ps_nan[start];
1101 if nan_w != 0 {
1102 out_row[i] = f64::NAN;
1103 } else {
1104 let sum_h = ps_h[end] - ps_h[start];
1105 let sum_l = ps_l[end] - ps_l[start];
1106 let sum_h2 = ps_h2[end] - ps_h2[start];
1107 let sum_l2 = ps_l2[end] - ps_l2[start];
1108 let sum_hl = ps_hl[end] - ps_hl[start];
1109 let cov = sum_hl - (sum_h * sum_l) * inv_pf;
1110 let var_h = sum_h2 - (sum_h * sum_h) * inv_pf;
1111 let var_l = sum_l2 - (sum_l * sum_l) * inv_pf;
1112 if var_h <= 0.0 || var_l <= 0.0 {
1113 out_row[i] = 0.0;
1114 } else {
1115 out_row[i] = cov / (var_h.sqrt() * var_l.sqrt());
1116 }
1117 }
1118 }
1119 };
1120
1121 if parallel {
1122 #[cfg(not(target_arch = "wasm32"))]
1123 {
1124 values_slice
1125 .par_chunks_mut(cols)
1126 .enumerate()
1127 .for_each(|(row, slice)| do_row(row, slice));
1128 }
1129
1130 #[cfg(target_arch = "wasm32")]
1131 {
1132 for (row, slice) in values_slice.chunks_mut(cols).enumerate() {
1133 do_row(row, slice);
1134 }
1135 }
1136 } else {
1137 for (row, slice) in values_slice.chunks_mut(cols).enumerate() {
1138 do_row(row, slice);
1139 }
1140 }
1141
1142 let values = unsafe {
1143 Vec::from_raw_parts(
1144 buf_guard.as_mut_ptr() as *mut f64,
1145 buf_guard.len(),
1146 buf_guard.capacity(),
1147 )
1148 };
1149
1150 Ok(CorrelHlBatchOutput {
1151 values,
1152 combos,
1153 rows,
1154 cols,
1155 })
1156}
1157
1158#[inline(always)]
1159fn correl_hl_batch_inner_into(
1160 high: &[f64],
1161 low: &[f64],
1162 sweep: &CorrelHlBatchRange,
1163 kern: Kernel,
1164 parallel: bool,
1165 out: &mut [f64],
1166) -> Result<Vec<CorrelHlParams>, CorrelHlError> {
1167 let combos = expand_grid(sweep)?;
1168
1169 let first = high
1170 .iter()
1171 .zip(low.iter())
1172 .position(|(&h, &l)| !h.is_nan() && !l.is_nan())
1173 .ok_or(CorrelHlError::AllValuesNaN)?;
1174
1175 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1176 if high.len() - first < max_p {
1177 return Err(CorrelHlError::NotEnoughValidData {
1178 needed: max_p,
1179 valid: high.len() - first,
1180 });
1181 }
1182
1183 let rows = combos.len();
1184 let cols = high.len();
1185
1186 let total = rows
1187 .checked_mul(cols)
1188 .ok_or(CorrelHlError::InvalidInput("rows*cols overflow"))?;
1189 if out.len() != total {
1190 return Err(CorrelHlError::OutputLengthMismatch {
1191 expected: total,
1192 got: out.len(),
1193 });
1194 }
1195
1196 let warm: Vec<usize> = combos
1197 .iter()
1198 .map(|c| first + c.period.unwrap() - 1)
1199 .collect();
1200 let out_mu: &mut [MaybeUninit<f64>] = unsafe {
1201 core::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
1202 };
1203 init_matrix_prefixes(out_mu, cols, &warm);
1204
1205 let n = high.len();
1206 let mut ps_h = vec![0.0f64; n + 1];
1207 let mut ps_h2 = vec![0.0f64; n + 1];
1208 let mut ps_l = vec![0.0f64; n + 1];
1209 let mut ps_l2 = vec![0.0f64; n + 1];
1210 let mut ps_hl = vec![0.0f64; n + 1];
1211 let mut ps_nan = vec![0i32; n + 1];
1212 for i in 0..n {
1213 let h = high[i];
1214 let l = low[i];
1215 let (prev_h, prev_h2, prev_l, prev_l2, prev_hl) =
1216 (ps_h[i], ps_h2[i], ps_l[i], ps_l2[i], ps_hl[i]);
1217 if h.is_nan() || l.is_nan() {
1218 ps_h[i + 1] = prev_h;
1219 ps_h2[i + 1] = prev_h2;
1220 ps_l[i + 1] = prev_l;
1221 ps_l2[i + 1] = prev_l2;
1222 ps_hl[i + 1] = prev_hl;
1223 ps_nan[i + 1] = ps_nan[i] + 1;
1224 } else {
1225 ps_h[i + 1] = prev_h + h;
1226 ps_h2[i + 1] = prev_h2 + h * h;
1227 ps_l[i + 1] = prev_l + l;
1228 ps_l2[i + 1] = prev_l2 + l * l;
1229 ps_hl[i + 1] = prev_hl + h * l;
1230 ps_nan[i + 1] = ps_nan[i];
1231 }
1232 }
1233
1234 let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| {
1235 let p = combos[row].period.unwrap();
1236 let inv_pf = 1.0 / (p as f64);
1237 let dst: &mut [f64] = unsafe {
1238 core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len())
1239 };
1240
1241 let warm = first + p - 1;
1242 for i in warm..n {
1243 let end = i + 1;
1244 let start = end - p;
1245 let nan_w = ps_nan[end] - ps_nan[start];
1246 if nan_w != 0 {
1247 dst[i] = f64::NAN;
1248 } else {
1249 let sum_h = ps_h[end] - ps_h[start];
1250 let sum_l = ps_l[end] - ps_l[start];
1251 let sum_h2 = ps_h2[end] - ps_h2[start];
1252 let sum_l2 = ps_l2[end] - ps_l2[start];
1253 let sum_hl = ps_hl[end] - ps_hl[start];
1254 let cov = sum_hl - (sum_h * sum_l) * inv_pf;
1255 let var_h = sum_h2 - (sum_h * sum_h) * inv_pf;
1256 let var_l = sum_l2 - (sum_l * sum_l) * inv_pf;
1257 if var_h <= 0.0 || var_l <= 0.0 {
1258 dst[i] = 0.0;
1259 } else {
1260 dst[i] = cov / (var_h.sqrt() * var_l.sqrt());
1261 }
1262 }
1263 }
1264 };
1265
1266 if parallel {
1267 #[cfg(not(target_arch = "wasm32"))]
1268 {
1269 out_mu
1270 .par_chunks_mut(cols)
1271 .enumerate()
1272 .for_each(|(r, s)| do_row(r, s));
1273 }
1274 #[cfg(target_arch = "wasm32")]
1275 {
1276 for (r, s) in out_mu.chunks_mut(cols).enumerate() {
1277 do_row(r, s);
1278 }
1279 }
1280 } else {
1281 for (r, s) in out_mu.chunks_mut(cols).enumerate() {
1282 do_row(r, s);
1283 }
1284 }
1285
1286 Ok(combos)
1287}
1288
1289#[inline(always)]
1290unsafe fn correl_hl_row_scalar(
1291 high: &[f64],
1292 low: &[f64],
1293 first: usize,
1294 period: usize,
1295 out: &mut [f64],
1296) {
1297 correl_hl_scalar(high, low, period, first, out)
1298}
1299
/// Row adapter for the AVX2 kernel (only built with `nightly-avx` on `x86_64`).
///
/// Forwards to `correl_hl_avx2`, swapping the `(first, period)` argument order used by
/// the row helpers into the `(period, first)` order the kernel is called with.
///
/// # Safety
/// Adds no requirements of its own. Callers must uphold whatever contract
/// `correl_hl_avx2` has (its definition is outside this block).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn correl_hl_row_avx2(
    high: &[f64],
    low: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    correl_hl_avx2(high, low, period, first, out);
}
1311
/// AVX-512 row entry point (only built with `nightly-avx` on `x86_64`).
///
/// Periods of at most 32 are routed to the "short" variant, anything longer to the
/// "long" variant.
///
/// # Safety
/// Adds no requirements of its own. Callers must uphold whatever contract the two
/// variants it dispatches to have.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn correl_hl_row_avx512(
    high: &[f64],
    low: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    match period {
        0..=32 => correl_hl_row_avx512_short(high, low, first, period, out),
        _ => correl_hl_row_avx512_long(high, low, first, period, out),
    }
}
1327
/// Row adapter for the short-period AVX-512 kernel (only built with `nightly-avx` on
/// `x86_64`).
///
/// Forwards to `correl_hl_avx512_short`, passing `period` before `first`.
///
/// # Safety
/// Adds no requirements of its own. Callers must uphold whatever contract
/// `correl_hl_avx512_short` has (its definition is outside this block).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn correl_hl_row_avx512_short(
    high: &[f64],
    low: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    correl_hl_avx512_short(high, low, period, first, out);
}
1339
/// Row adapter for the long-period AVX-512 kernel (only built with `nightly-avx` on
/// `x86_64`).
///
/// Forwards to `correl_hl_avx512_long`, passing `period` before `first`.
///
/// # Safety
/// Adds no requirements of its own. Callers must uphold whatever contract
/// `correl_hl_avx512_long` has (its definition is outside this block).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn correl_hl_row_avx512_long(
    high: &[f64],
    low: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    correl_hl_avx512_long(high, low, period, first, out);
}
1351
/// WASM entry point: computes the indicator for one `period` and returns a new vector
/// with one value per input sample.
///
/// The result buffer is zero-initialised to `high.len()` and filled by
/// `correl_hl_into_slice` with automatic kernel selection. Any error from that call is
/// converted to a `JsValue` holding its display string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn correl_hl_js(high: &[f64], low: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = CorrelHlInput::from_slices(
        high,
        low,
        CorrelHlParams {
            period: Some(period),
        },
    );

    let mut values = vec![0.0; high.len()];
    match correl_hl_into_slice(&mut values, &input, Kernel::Auto) {
        Ok(_) => Ok(values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1367
/// WASM entry point that writes the indicator into caller-owned memory.
///
/// `high_ptr`, `low_ptr` and `out_ptr` must each point to `len` readable (and, for
/// `out_ptr`, writable) `f64` values inside the module's linear memory. The output may
/// share memory with either input: whenever the output range overlaps an input range,
/// even partially, the result is computed into a temporary buffer and copied over
/// afterwards, so the inputs are never mutated while they are being read.
///
/// # Errors
/// Returns a `JsValue` string when any pointer is null or when `correl_hl_into_slice`
/// reports an error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn correl_hl_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if high_ptr.is_null() || low_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // Byte ranges are compared rather than start addresses alone: a partially overlapping
    // output would otherwise be handed out as `&mut [f64]` while an input `&[f64]` still
    // covers some of the same memory.
    let span = len.saturating_mul(std::mem::size_of::<f64>());
    let out_start = out_ptr as usize;
    let out_end = out_start.saturating_add(span);
    let overlaps_out = |p: *const f64| {
        let start = p as usize;
        // The equality test keeps exact aliasing on the buffered path even when `len == 0`.
        start == out_start || (start < out_end && out_start < start.saturating_add(span))
    };
    let aliased = overlaps_out(high_ptr) || overlaps_out(low_ptr);

    // SAFETY: the caller guarantees each pointer addresses `len` valid `f64` values. The
    // mutable output slice is created while inputs are in use only when the ranges are
    // disjoint; on the aliased path it is created after the computation has finished.
    unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let params = CorrelHlParams {
            period: Some(period),
        };
        let input = CorrelHlInput::from_slices(high, low, params);

        if aliased {
            let mut temp = vec![0.0; len];
            correl_hl_into_slice(&mut temp, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            out.copy_from_slice(&temp);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            correl_hl_into_slice(out, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(())
}
1403
/// Allocates room for `len` `f64` values and hands the raw pointer to the JavaScript side.
///
/// The backing `Vec` is deliberately never dropped here; the memory must be released
/// with `correl_hl_free` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn correl_hl_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` suppresses the destructor, so the allocation outlives this call.
    let mut buffer = ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1412
/// Releases a buffer previously obtained from `correl_hl_alloc`.
///
/// A null `ptr` is ignored. Otherwise `ptr` and `len` must be exactly the pointer returned
/// by `correl_hl_alloc` and the length it was called with.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn correl_hl_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: relies on the caller passing a pointer/length pair that came from
    // `correl_hl_alloc`; rebuilding the `Vec` and dropping it frees the allocation.
    drop(unsafe { Vec::from_raw_parts(ptr, len, len) });
}
1422
/// WASM batch entry point (exported to JavaScript as `correl_hl_batch`).
///
/// `config` is deserialised into a `CorrelHlBatchConfig`; its `period_range` defines the
/// sweep. The function returns a serialised `CorrelHlBatchJsOutput` holding the row-major
/// `values`, the `periods` of each row, and the matrix dimensions `rows` x `cols`.
///
/// # Errors
/// Returns a `JsValue` string for an undecodable config, an invalid sweep, a
/// `rows * cols` overflow, a computation error, or a serialisation failure.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = correl_hl_batch)]
pub fn correl_hl_batch_js(high: &[f64], low: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let parsed: CorrelHlBatchConfig = match serde_wasm_bindgen::from_value(config) {
        Ok(cfg) => cfg,
        Err(e) => return Err(JsValue::from_str(&format!("Invalid config: {}", e))),
    };
    let sweep = CorrelHlBatchRange {
        period: parsed.period_range,
    };

    // The grid is expanded here only to size the output; the computation expands it again.
    let grid = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let (rows, cols) = (grid.len(), high.len());
    let total = match rows.checked_mul(cols) {
        Some(t) => t,
        None => return Err(JsValue::from_str("rows*cols overflow")),
    };

    let mut values = vec![0.0f64; total];
    correl_hl_batch_inner_into(high, low, &sweep, Kernel::Auto, false, &mut values)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let periods: Vec<usize> = grid.iter().map(|c| c.period.unwrap()).collect();

    serde_wasm_bindgen::to_value(&CorrelHlBatchJsOutput {
        values,
        periods,
        rows,
        cols,
    })
    .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
1456
/// WASM batch entry point that writes into caller-owned memory.
///
/// `high_ptr` and `low_ptr` must each address `len` values; `out_ptr` must address
/// `rows * len` values, where `rows` is the number of periods produced by the
/// `(period_start, period_end, period_step)` sweep. On success the number of rows is
/// returned.
///
/// # Errors
/// Returns a `JsValue` string when a pointer is null, the sweep is invalid,
/// `rows * len` overflows, or the computation fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn correl_hl_batch_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    let any_null = [high_ptr, low_ptr, out_ptr as *const f64]
        .iter()
        .any(|p| p.is_null());
    if any_null {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    let sweep = CorrelHlBatchRange {
        period: (period_start, period_end, period_step),
    };

    let rows = expand_grid(&sweep)
        .map_err(|e| JsValue::from_str(&e.to_string()))?
        .len();
    let total = match rows.checked_mul(len) {
        Some(t) => t,
        None => return Err(JsValue::from_str("rows*cols overflow")),
    };

    // SAFETY: relies on the caller's guarantee that the input pointers address `len` values
    // and the output pointer addresses `rows * len` values.
    unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let out_slice = std::slice::from_raw_parts_mut(out_ptr, total);

        correl_hl_batch_inner_into(high, low, &sweep, Kernel::Auto, false, out_slice)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(rows)
}
1494
/// Python binding `correl_hl(high, low, period, kernel=None)`.
///
/// Reads both arrays as contiguous slices, validates the optional kernel name, runs the
/// computation with the GIL released and returns the values as a new NumPy array.
/// Computation errors are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "correl_hl")]
#[pyo3(signature = (high, low, period, kernel=None))]
pub fn correl_hl_py<'py>(
    py: Python<'py>,
    high: numpy::PyReadonlyArray1<'py, f64>,
    low: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let h = high.as_slice()?;
    let l = low.as_slice()?;
    let kern = validate_kernel(kernel, false)?;

    let input = CorrelHlInput::from_slices(
        h,
        l,
        CorrelHlParams {
            period: Some(period),
        },
    );

    let computed: Result<Vec<f64>, _> =
        py.allow_threads(|| correl_hl_with_kernel(&input, kern).map(|o| o.values));

    match computed {
        Ok(values) => Ok(values.into_pyarray(py)),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
1522
/// Python wrapper (exposed as `CorrelHlStream`) around the incremental `CorrelHlStream`.
#[cfg(feature = "python")]
#[pyclass(name = "CorrelHlStream")]
pub struct CorrelHlStreamPy {
    stream: CorrelHlStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl CorrelHlStreamPy {
    /// Creates a stream for the given `period`; construction errors are raised as
    /// `ValueError`.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        CorrelHlStream::try_new(CorrelHlParams {
            period: Some(period),
        })
        .map(|stream| CorrelHlStreamPy { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feeds one `(high, low)` pair to the stream and returns whatever the underlying
    /// stream returns for it (`None` or a value).
    fn update(&mut self, high: f64, low: f64) -> Option<f64> {
        let CorrelHlStreamPy { stream } = self;
        stream.update(high, low)
    }
}
1546
/// Python binding `correl_hl_batch(high, low, period_range, kernel=None)`.
///
/// Returns a dict with:
/// * `"values"` - a `rows x cols` float64 array (one row per period, one column per sample),
/// * `"periods"` - the period used for each row, as an unsigned 64-bit array.
///
/// The output array is allocated uninitialised and filled by
/// `correl_hl_batch_inner_into` while the GIL is released. Sweep, overflow and
/// computation errors are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "correl_hl_batch")]
#[pyo3(signature = (high, low, period_range, kernel=None))]
pub fn correl_hl_batch_py<'py>(
    py: Python<'py>,
    high: numpy::PyReadonlyArray1<'py, f64>,
    low: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let h = high.as_slice()?;
    let l = low.as_slice()?;

    let sweep = CorrelHlBatchRange {
        period: period_range,
    };

    // Expanded once here to size the output; the computation returns its own expansion.
    let rows = expand_grid(&sweep)
        .map_err(|e| PyValueError::new_err(e.to_string()))?
        .len();
    let cols = h.len();
    let total = match rows.checked_mul(cols) {
        Some(t) => t,
        None => return Err(PyValueError::new_err("rows*cols overflow")),
    };

    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let out_slice = unsafe { out_arr.as_slice_mut()? };

    let kern = validate_kernel(kernel, true)?;

    let computed = py.allow_threads(|| {
        let batch_kernel = if matches!(kern, Kernel::Auto) {
            detect_best_batch_kernel()
        } else {
            kern
        };
        // Translate the batch kernel into the matching per-row kernel.
        let row_kernel = match batch_kernel {
            Kernel::Avx512Batch => Kernel::Avx512,
            Kernel::Avx2Batch => Kernel::Avx2,
            Kernel::ScalarBatch => Kernel::Scalar,
            _ => unreachable!(),
        };
        correl_hl_batch_inner_into(h, l, &sweep, row_kernel, true, out_slice)
    });
    let combos = match computed {
        Ok(c) => c,
        Err(e) => return Err(PyValueError::new_err(e.to_string())),
    };

    let result = PyDict::new(py);
    result.set_item("values", out_arr.reshape((rows, cols))?)?;

    let periods: Vec<u64> = combos.iter().map(|c| c.period.unwrap() as u64).collect();
    result.set_item("periods", periods.into_pyarray(py))?;

    Ok(result)
}
1608
/// Python-visible owner of a row-major `f32` matrix held in CUDA device memory.
///
/// `inner` carries the device buffer together with its `rows`/`cols`, `_ctx_guard` keeps
/// an `Arc` to the CUDA context alongside the buffer, and `_device_id` is the ordinal
/// reported through the DLPack methods.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(
    module = "ta_indicators.cuda",
    name = "CorrelHlDeviceArrayF32",
    unsendable
)]
pub struct CorrelHlDeviceArrayF32Py {
    pub(crate) inner: DeviceArrayF32,
    _ctx_guard: Arc<Context>,
    _device_id: u32,
}

#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl CorrelHlDeviceArrayF32Py {
    /// CUDA Array Interface (version 3) description of the buffer: a `(rows, cols)`
    /// little-endian `f32` array with byte strides `(cols * 4, 4)`, exposed as writable.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let inner = &self.inner;
        let d = PyDict::new(py);
        let itemsize = std::mem::size_of::<f32>();
        d.set_item("shape", (inner.rows, inner.cols))?;
        d.set_item("typestr", "<f4")?;
        d.set_item("strides", (inner.cols * itemsize, itemsize))?;
        d.set_item("data", (inner.device_ptr() as usize, false))?;
        d.set_item("version", 3)?;
        Ok(d)
    }

    /// Returns `(device_type, device_id)`; device type `2` is DLPack's CUDA device code.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self._device_id as i32)
    }

    /// Exports the buffer as a DLPack capsule, transferring ownership of the device
    /// memory to the capsule.
    ///
    /// After this call the object holds an empty buffer, so exporting again yields a
    /// `0 x 0` tensor. A versioned capsule (`"dltensor_versioned"`, DLPack 1.0) is produced
    /// when `max_version` has a major component of at least 1, otherwise a legacy
    /// `"dltensor"` capsule. `stream`, `dl_device` and `copy` are accepted for protocol
    /// compatibility but are not acted on.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<usize>,
        max_version: Option<(u32, u32)>,
        dl_device: Option<(i32, i32)>,
        copy: Option<bool>,
    ) -> PyResult<PyObject> {
        use pyo3::ffi as pyffi;
        use std::ffi::c_void;
        use std::os::raw::c_char;

        // `PyCapsule_New` stores the name pointer without copying it, so the names must
        // outlive every capsule: keep them in static storage.
        static LEGACY_NAME: &[u8] = b"dltensor\0";
        static VERSIONED_NAME: &[u8] = b"dltensor_versioned\0";

        #[repr(C)]
        struct DLDevice {
            device_type: i32,
            device_id: i32,
        }
        #[repr(C)]
        struct DLDataType {
            code: u8,
            bits: u8,
            lanes: u16,
        }
        #[repr(C)]
        struct DLTensor {
            data: *mut c_void,
            device: DLDevice,
            ndim: i32,
            dtype: DLDataType,
            shape: *mut i64,
            strides: *mut i64,
            byte_offset: u64,
        }
        #[repr(C)]
        struct DLManagedTensor {
            dl_tensor: DLTensor,
            manager_ctx: *mut c_void,
            deleter: Option<unsafe extern "C" fn(*mut DLManagedTensor)>,
        }
        #[repr(C)]
        struct DLPackVersion {
            major: u32,
            minor: u32,
        }
        // Field order follows `DLManagedTensorVersioned` in the DLPack 1.0 header; consumers
        // read this struct directly, so the layout has to match the C definition.
        #[repr(C)]
        struct DLManagedTensorVersioned {
            version: DLPackVersion,
            manager_ctx: *mut c_void,
            deleter: Option<unsafe extern "C" fn(*mut DLManagedTensorVersioned)>,
            flags: u64,
            dl_tensor: DLTensor,
        }

        // Everything the exported tensor points into. The fields are dropped in declaration
        // order, so the device buffer is released before the context handle.
        #[repr(C)]
        struct ManagerCtx {
            shape: *mut i64,
            strides: *mut i64,
            _shape: Box<[i64; 2]>,
            _strides: Box<[i64; 2]>,
            _arr: DeviceArrayF32,
            _ctx: Arc<Context>,
        }

        // Deleter handed to consumers of the legacy capsule.
        unsafe extern "C" fn release_legacy(p: *mut DLManagedTensor) {
            if p.is_null() {
                return;
            }
            let mt = Box::from_raw(p);
            let ctx_ptr = mt.manager_ctx as *mut ManagerCtx;
            if !ctx_ptr.is_null() {
                drop(Box::from_raw(ctx_ptr));
            }
        }

        // Deleter handed to consumers of the versioned capsule.
        unsafe extern "C" fn release_versioned(p: *mut DLManagedTensorVersioned) {
            if p.is_null() {
                return;
            }
            let mt = Box::from_raw(p);
            let ctx_ptr = mt.manager_ctx as *mut ManagerCtx;
            if !ctx_ptr.is_null() {
                drop(Box::from_raw(ctx_ptr));
            }
        }

        // Capsule destructors: a consumer that takes ownership renames the capsule, so the
        // original name is still valid only when nobody consumed the tensor. In that case
        // the producer has to release it, otherwise the device memory leaks.
        unsafe extern "C" fn destroy_legacy_capsule(capsule: *mut pyffi::PyObject) {
            let name = LEGACY_NAME.as_ptr() as *const c_char;
            if pyffi::PyCapsule_IsValid(capsule, name) == 1 {
                let p = pyffi::PyCapsule_GetPointer(capsule, name) as *mut DLManagedTensor;
                if !p.is_null() {
                    if let Some(del) = (*p).deleter {
                        del(p);
                    }
                }
            }
        }

        unsafe extern "C" fn destroy_versioned_capsule(capsule: *mut pyffi::PyObject) {
            let name = VERSIONED_NAME.as_ptr() as *const c_char;
            if pyffi::PyCapsule_IsValid(capsule, name) == 1 {
                let p =
                    pyffi::PyCapsule_GetPointer(capsule, name) as *mut DLManagedTensorVersioned;
                if !p.is_null() {
                    if let Some(del) = (*p).deleter {
                        del(p);
                    }
                }
            }
        }

        // Protocol arguments this producer does not act on.
        let _ = (stream, dl_device, copy);

        // Move the real buffer out of `self`, leaving an empty placeholder behind.
        let dummy =
            DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let inner = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32 {
                buf: dummy,
                rows: 0,
                cols: 0,
            },
        );

        let rows = inner.rows as i64;
        let cols = inner.cols as i64;
        let total = (rows as i128) * (cols as i128);
        let mut shape = Box::new([rows, cols]);
        let mut strides = Box::new([cols, 1]);
        // The boxes are moved into `ManagerCtx` below; their heap addresses stay stable.
        let shape_ptr = shape.as_mut_ptr();
        let strides_ptr = strides.as_mut_ptr();

        // The context owns the buffer and the CUDA context handle, which is all the exported
        // tensor needs to stay valid; no reference back to the Python object is kept.
        let mgr_ptr = Box::into_raw(Box::new(ManagerCtx {
            shape: shape_ptr,
            strides: strides_ptr,
            _shape: shape,
            _strides: strides,
            _arr: inner,
            _ctx: self._ctx_guard.clone(),
        }));

        let data = if total == 0 {
            std::ptr::null_mut()
        } else {
            // SAFETY: `mgr_ptr` was just produced by `Box::into_raw` and is not shared yet.
            unsafe { (*mgr_ptr)._arr.buf.as_device_ptr().as_raw() as *mut c_void }
        };

        let dl = DLTensor {
            data,
            device: DLDevice {
                device_type: 2,
                device_id: self._device_id as i32,
            },
            ndim: 2,
            // DLPack dtype code 2 is "float"; 32 bits, one lane.
            dtype: DLDataType {
                code: 2,
                bits: 32,
                lanes: 1,
            },
            shape: shape_ptr,
            strides: strides_ptr,
            byte_offset: 0,
        };

        let want_versioned = max_version.map(|(maj, _)| maj >= 1).unwrap_or(false);

        // SAFETY: the pointers given to `PyCapsule_New` come from `Box::into_raw`, the names
        // are static NUL-terminated strings, and the GIL is held (`py` is in scope). On
        // failure the matching deleter frees both the managed tensor and its context.
        unsafe {
            let cap = if want_versioned {
                let mt = Box::into_raw(Box::new(DLManagedTensorVersioned {
                    version: DLPackVersion { major: 1, minor: 0 },
                    manager_ctx: mgr_ptr as *mut c_void,
                    deleter: Some(release_versioned),
                    flags: 0,
                    dl_tensor: dl,
                }));
                let cap = pyffi::PyCapsule_New(
                    mt as *mut c_void,
                    VERSIONED_NAME.as_ptr() as *const c_char,
                    Some(destroy_versioned_capsule),
                );
                if cap.is_null() {
                    release_versioned(mt);
                }
                cap
            } else {
                let mt = Box::into_raw(Box::new(DLManagedTensor {
                    dl_tensor: dl,
                    manager_ctx: mgr_ptr as *mut c_void,
                    deleter: Some(release_legacy),
                }));
                let cap = pyffi::PyCapsule_New(
                    mt as *mut c_void,
                    LEGACY_NAME.as_ptr() as *const c_char,
                    Some(destroy_legacy_capsule),
                );
                if cap.is_null() {
                    release_legacy(mt);
                }
                cap
            };

            if cap.is_null() {
                return Err(PyValueError::new_err("failed to create DLPack capsule"));
            }
            Ok(PyObject::from_owned_ptr(py, cap))
        }
    }
}
1801
#[cfg(all(feature = "python", feature = "cuda"))]
impl CorrelHlDeviceArrayF32Py {
    /// Wraps a device array produced on the Rust side, storing it together with the
    /// CUDA context handle and the device ordinal it was given.
    pub fn new_from_rust(inner: DeviceArrayF32, ctx_guard: Arc<Context>, device_id: u32) -> Self {
        CorrelHlDeviceArrayF32Py {
            inner,
            _ctx_guard: ctx_guard,
            _device_id: device_id,
        }
    }
}
1812
/// Python binding `correl_hl_cuda_batch_dev(high_f32, low_f32, period_range, device_id=0)`.
///
/// Runs the period sweep through `CudaCorrelHl` with the GIL released and returns the
/// resulting device array wrapped for Python. Raises `ValueError` when CUDA is not
/// available or when the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "correl_hl_cuda_batch_dev")]
#[pyo3(signature = (high_f32, low_f32, period_range, device_id=0))]
pub fn correl_hl_cuda_batch_dev_py(
    py: Python<'_>,
    high_f32: numpy::PyReadonlyArray1<'_, f32>,
    low_f32: numpy::PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<CorrelHlDeviceArrayF32Py> {
    use crate::cuda::correl_hl_wrapper::CudaCorrelHl;
    use crate::cuda::cuda_available;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let high = high_f32.as_slice()?;
    let low = low_f32.as_slice()?;
    let sweep = CorrelHlBatchRange {
        period: period_range,
    };

    let (device_array, context, ordinal) = py.allow_threads(|| -> PyResult<_> {
        let cuda = match CudaCorrelHl::new(device_id) {
            Ok(c) => c,
            Err(e) => return Err(PyValueError::new_err(e.to_string())),
        };
        // The parameter combinations returned next to the array are not needed here.
        let (dev, _combos) = match cuda.correl_hl_batch_dev(high, low, &sweep) {
            Ok(pair) => pair,
            Err(e) => return Err(PyValueError::new_err(e.to_string())),
        };
        Ok((dev, cuda.context_arc(), cuda.device_id()))
    })?;

    Ok(CorrelHlDeviceArrayF32Py::new_from_rust(
        device_array,
        context,
        ordinal,
    ))
}
1843
/// Python binding
/// `correl_hl_cuda_many_series_one_param_dev(high_tm_f32, low_tm_f32, period, device_id=0)`.
///
/// Both inputs must be 2-D arrays of identical shape. With `shape = (rows, cols)` the
/// CUDA wrapper is called as `(..., cols, rows, period)`; the computation runs with the
/// GIL released and the resulting device array is returned wrapped for Python. Raises
/// `ValueError` when CUDA is unavailable, the shapes differ, or the wrapper fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "correl_hl_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm_f32, low_tm_f32, period, device_id=0))]
pub fn correl_hl_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    high_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    low_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<CorrelHlDeviceArrayF32Py> {
    use crate::cuda::correl_hl_wrapper::CudaCorrelHl;
    use crate::cuda::cuda_available;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let dims = high_tm_f32.shape();
    let shapes_ok = dims.len() == 2 && low_tm_f32.shape() == dims;
    if !shapes_ok {
        return Err(PyValueError::new_err("expected matching 2D arrays"));
    }
    let (rows, cols) = (dims[0], dims[1]);

    let high = high_tm_f32.as_slice()?;
    let low = low_tm_f32.as_slice()?;

    let (device_array, context, ordinal) = py.allow_threads(|| -> PyResult<_> {
        let cuda = match CudaCorrelHl::new(device_id) {
            Ok(c) => c,
            Err(e) => return Err(PyValueError::new_err(e.to_string())),
        };
        let dev = match cuda
            .correl_hl_many_series_one_param_time_major_dev(high, low, cols, rows, period)
        {
            Ok(d) => d,
            Err(e) => return Err(PyValueError::new_err(e.to_string())),
        };
        Ok((dev, cuda.context_arc(), cuda.device_id()))
    })?;

    Ok(CorrelHlDeviceArrayF32Py::new_from_rust(
        device_array,
        context,
        ordinal,
    ))
}
1877
1878#[cfg(test)]
1879mod tests {
1880 use super::*;
1881 use crate::skip_if_unsupported;
1882 use crate::utilities::data_loader::read_candles_from_csv;
1883 #[cfg(feature = "proptest")]
1884 use proptest::prelude::*;
1885
/// The in-place API must produce the same series as the allocating `correl_hl` on a
/// synthetic 256-candle data set.
#[test]
fn test_correl_hl_into_matches_api() -> Result<(), Box<dyn std::error::Error>> {
    let n = 256usize;
    let mut ts = Vec::with_capacity(n);
    let mut open = Vec::with_capacity(n);
    let mut high = Vec::with_capacity(n);
    let mut low = Vec::with_capacity(n);
    let mut close = Vec::with_capacity(n);
    let mut vol = Vec::with_capacity(n);

    // Deterministic random-walk-like candles: each close moves by a sine-driven step and
    // the high/low bracket the open/close range by fixed margins.
    let mut price = 100.0f64;
    for i in 0..n {
        let open_px = price;
        let close_px = price + (((i as f64).sin() * 0.5) + 0.1);

        ts.push(i as i64);
        open.push(open_px);
        close.push(close_px);
        high.push(open_px.max(close_px) + 0.4);
        low.push(open_px.min(close_px) - 0.3);
        vol.push(1000.0 + (i % 10) as f64);

        price = close_px;
    }

    let candles = crate::utilities::data_loader::Candles::new(ts, open, high, low, close, vol);
    let input = CorrelHlInput::from_candles(&candles, CorrelHlParams::default());

    let baseline = correl_hl(&input)?;

    let mut out = vec![0.0f64; n];
    #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
    {
        correl_hl_into(&mut out, &input)?;
    }
    // On wasm builds `correl_hl_into` is the pointer-based binding, so use the slice API.
    #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
    {
        correl_hl_into_slice(&mut out, &input, Kernel::Auto)?;
    }

    assert_eq!(baseline.values.len(), out.len());
    for (a, b) in baseline.values.iter().zip(out.iter()) {
        let both_nan = a.is_nan() && b.is_nan();
        let close_enough = (*a == *b) || ((*a - *b).abs() <= 1e-12);
        assert!(
            both_nan || close_enough,
            "Mismatch: baseline={} into={}",
            a,
            b
        );
    }

    Ok(())
}
1946
1947 fn check_correl_hl_partial_params(
1948 test_name: &str,
1949 kernel: Kernel,
1950 ) -> Result<(), Box<dyn std::error::Error>> {
1951 skip_if_unsupported!(kernel, test_name);
1952 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1953 let candles = read_candles_from_csv(file_path)?;
1954 let params = CorrelHlParams { period: None };
1955 let input = CorrelHlInput::from_candles(&candles, params);
1956 let output = correl_hl_with_kernel(&input, kernel)?;
1957 assert_eq!(output.values.len(), candles.close.len());
1958 Ok(())
1959 }
1960
1961 fn check_correl_hl_accuracy(
1962 test_name: &str,
1963 kernel: Kernel,
1964 ) -> Result<(), Box<dyn std::error::Error>> {
1965 skip_if_unsupported!(kernel, test_name);
1966 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1967 let candles = read_candles_from_csv(file_path)?;
1968 let params = CorrelHlParams { period: Some(5) };
1969 let input = CorrelHlInput::from_candles(&candles, params);
1970 let result = correl_hl_with_kernel(&input, kernel)?;
1971 let expected = [
1972 0.04589155420456278,
1973 0.6491664099299647,
1974 0.9691259236943873,
1975 0.9915438003818791,
1976 0.8460608423095615,
1977 ];
1978 let start_index = result.values.len() - 5;
1979 for (i, &val) in result.values[start_index..].iter().enumerate() {
1980 let exp = expected[i];
1981 let diff = (val - exp).abs();
1982 assert!(
1983 diff < 1e-7,
1984 "[{}] Value mismatch at index {}: expected {}, got {}",
1985 test_name,
1986 i,
1987 exp,
1988 val
1989 );
1990 }
1991 Ok(())
1992 }
1993
1994 fn check_correl_hl_zero_period(
1995 test_name: &str,
1996 kernel: Kernel,
1997 ) -> Result<(), Box<dyn std::error::Error>> {
1998 skip_if_unsupported!(kernel, test_name);
1999 let high = [1.0, 2.0, 3.0];
2000 let low = [1.0, 2.0, 3.0];
2001 let params = CorrelHlParams { period: Some(0) };
2002 let input = CorrelHlInput::from_slices(&high, &low, params);
2003 let result = correl_hl_with_kernel(&input, kernel);
2004 assert!(
2005 result.is_err(),
2006 "[{}] correl_hl should fail with zero period",
2007 test_name
2008 );
2009 Ok(())
2010 }
2011
2012 fn check_correl_hl_period_exceeds_length(
2013 test_name: &str,
2014 kernel: Kernel,
2015 ) -> Result<(), Box<dyn std::error::Error>> {
2016 skip_if_unsupported!(kernel, test_name);
2017 let high = [1.0, 2.0, 3.0];
2018 let low = [1.0, 2.0, 3.0];
2019 let params = CorrelHlParams { period: Some(10) };
2020 let input = CorrelHlInput::from_slices(&high, &low, params);
2021 let result = correl_hl_with_kernel(&input, kernel);
2022 assert!(
2023 result.is_err(),
2024 "[{}] correl_hl should fail with period exceeding length",
2025 test_name
2026 );
2027 Ok(())
2028 }
2029
2030 fn check_correl_hl_data_length_mismatch(
2031 test_name: &str,
2032 kernel: Kernel,
2033 ) -> Result<(), Box<dyn std::error::Error>> {
2034 skip_if_unsupported!(kernel, test_name);
2035 let high = [1.0, 2.0, 3.0];
2036 let low = [1.0, 2.0];
2037 let params = CorrelHlParams { period: Some(2) };
2038 let input = CorrelHlInput::from_slices(&high, &low, params);
2039 let result = correl_hl_with_kernel(&input, kernel);
2040 assert!(
2041 result.is_err(),
2042 "[{}] correl_hl should fail on length mismatch",
2043 test_name
2044 );
2045 Ok(())
2046 }
2047
2048 fn check_correl_hl_all_nan(
2049 test_name: &str,
2050 kernel: Kernel,
2051 ) -> Result<(), Box<dyn std::error::Error>> {
2052 skip_if_unsupported!(kernel, test_name);
2053 let high = [f64::NAN, f64::NAN, f64::NAN];
2054 let low = [f64::NAN, f64::NAN, f64::NAN];
2055 let params = CorrelHlParams { period: Some(2) };
2056 let input = CorrelHlInput::from_slices(&high, &low, params);
2057 let result = correl_hl_with_kernel(&input, kernel);
2058 assert!(
2059 result.is_err(),
2060 "[{}] correl_hl should fail on all NaN",
2061 test_name
2062 );
2063 Ok(())
2064 }
2065
2066 fn check_correl_hl_from_candles(
2067 test_name: &str,
2068 kernel: Kernel,
2069 ) -> Result<(), Box<dyn std::error::Error>> {
2070 skip_if_unsupported!(kernel, test_name);
2071 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2072 let candles = read_candles_from_csv(file_path)?;
2073 let params = CorrelHlParams { period: Some(9) };
2074 let input = CorrelHlInput::from_candles(&candles, params);
2075 let output = correl_hl_with_kernel(&input, kernel)?;
2076 assert_eq!(output.values.len(), candles.close.len());
2077 Ok(())
2078 }
2079
2080 fn check_correl_hl_reinput(
2081 test_name: &str,
2082 kernel: Kernel,
2083 ) -> Result<(), Box<dyn std::error::Error>> {
2084 skip_if_unsupported!(kernel, test_name);
2085 let high = [1.0, 2.0, 3.0, 4.0, 5.0];
2086 let low = [0.5, 1.0, 1.5, 2.0, 2.5];
2087 let params = CorrelHlParams { period: Some(2) };
2088 let first_input = CorrelHlInput::from_slices(&high, &low, params.clone());
2089 let first_result = correl_hl_with_kernel(&first_input, kernel)?;
2090 let second_input = CorrelHlInput::from_slices(&first_result.values, &low, params);
2091 let second_result = correl_hl_with_kernel(&second_input, kernel)?;
2092 assert_eq!(second_result.values.len(), low.len());
2093 Ok(())
2094 }
2095
2096 fn check_correl_hl_very_small_dataset(
2097 test_name: &str,
2098 kernel: Kernel,
2099 ) -> Result<(), Box<dyn std::error::Error>> {
2100 skip_if_unsupported!(kernel, test_name);
2101
2102 let single_high = [42.0];
2103 let single_low = [21.0];
2104 let params = CorrelHlParams { period: Some(1) };
2105 let input = CorrelHlInput::from_slices(&single_high, &single_low, params);
2106 let result = correl_hl_with_kernel(&input, kernel)?;
2107 assert_eq!(result.values.len(), 1);
2108
2109 assert!(result.values[0].is_nan() || result.values[0].abs() < f64::EPSILON);
2110 Ok(())
2111 }
2112
2113 fn check_correl_hl_empty_input(
2114 test_name: &str,
2115 kernel: Kernel,
2116 ) -> Result<(), Box<dyn std::error::Error>> {
2117 skip_if_unsupported!(kernel, test_name);
2118 let empty_high: [f64; 0] = [];
2119 let empty_low: [f64; 0] = [];
2120 let params = CorrelHlParams { period: Some(5) };
2121 let input = CorrelHlInput::from_slices(&empty_high, &empty_low, params);
2122 let result = correl_hl_with_kernel(&input, kernel);
2123 assert!(
2124 result.is_err(),
2125 "[{}] correl_hl should fail on empty input",
2126 test_name
2127 );
2128 Ok(())
2129 }
2130
2131 fn check_correl_hl_nan_handling(
2132 test_name: &str,
2133 kernel: Kernel,
2134 ) -> Result<(), Box<dyn std::error::Error>> {
2135 skip_if_unsupported!(kernel, test_name);
2136
2137 let high = vec![1.0, 2.0, f64::NAN, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
2138 let low = vec![0.5, 1.0, 1.5, f64::NAN, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0];
2139 let params = CorrelHlParams { period: Some(3) };
2140 let input = CorrelHlInput::from_slices(&high, &low, params);
2141 let result = correl_hl_with_kernel(&input, kernel)?;
2142
2143 assert_eq!(result.values.len(), high.len());
2144
2145 let mut valid_count = 0;
2146 for i in 0..high.len() {
2147 if !high[i].is_nan() && !low[i].is_nan() {
2148 valid_count += 1;
2149 if valid_count >= 3 {
2150 let has_valid = result.values[i..].iter().any(|&v| !v.is_nan());
2151 assert!(
2152 has_valid,
2153 "[{}] Should have valid correlations after enough data",
2154 test_name
2155 );
2156 break;
2157 }
2158 }
2159 }
2160 Ok(())
2161 }
2162
2163 fn check_correl_hl_streaming(
2164 test_name: &str,
2165 kernel: Kernel,
2166 ) -> Result<(), Box<dyn std::error::Error>> {
2167 skip_if_unsupported!(kernel, test_name);
2168
2169 let params = CorrelHlParams { period: Some(3) };
2170 let mut stream = CorrelHlStream::try_new(params)?;
2171
2172 let high_data = [1.0, 2.0, 3.0, 4.0, 5.0];
2173 let low_data = [0.5, 1.0, 1.5, 2.0, 2.5];
2174
2175 assert!(stream.update(high_data[0], low_data[0]).is_none());
2176 assert!(stream.update(high_data[1], low_data[1]).is_none());
2177
2178 let first_corr = stream.update(high_data[2], low_data[2]);
2179 assert!(first_corr.is_some());
2180
2181 let second_corr = stream.update(high_data[3], low_data[3]);
2182 assert!(second_corr.is_some());
2183
2184 let params_batch = CorrelHlParams { period: Some(3) };
2185 let input_batch = CorrelHlInput::from_slices(&high_data[..4], &low_data[..4], params_batch);
2186 let batch_result = correl_hl_with_kernel(&input_batch, kernel)?;
2187
2188 if let Some(batch_val) = batch_result.values.iter().rev().find(|&&v| !v.is_nan()) {
2189 if let Some(stream_val) = second_corr {
2190 assert!(
2191 (batch_val - stream_val).abs() < 1e-10,
2192 "[{}] Streaming vs batch mismatch: {} vs {}",
2193 test_name,
2194 stream_val,
2195 batch_val
2196 );
2197 }
2198 }
2199
2200 Ok(())
2201 }
2202
/// Property test: for random price series and periods, the kernel under test must
/// * emit NaN for the first `period - 1` outputs,
/// * agree with the scalar kernel (bit-exact for non-finite values, otherwise within
///   `1e-9` or 4 ULPs),
/// * keep every non-NaN value inside `[-1, 1]` with a `1e-3` tolerance,
/// * succeed or fail together with the scalar kernel.
#[cfg(feature = "proptest")]
#[allow(clippy::float_cmp)]
fn check_correl_hl_property(
    test_name: &str,
    kernel: Kernel,
) -> Result<(), Box<dyn std::error::Error>> {
    use proptest::prelude::*;
    skip_if_unsupported!(kernel, test_name);

    // Strategy: pick a period, then a close series at least that long, then derive high/low
    // by applying independent random up/down spreads to every close.
    let strat = (2usize..=100).prop_flat_map(|period| {
        let series = prop::collection::vec(1.0f64..1000.0f64, period..400).prop_flat_map(
            move |closes| {
                let len = closes.len();
                (
                    Just(closes.clone()),
                    prop::collection::vec((0.001f64..0.05f64, 0.001f64..0.05f64), len),
                )
                    .prop_map(move |(close, spreads)| {
                        let high: Vec<f64> = close
                            .iter()
                            .zip(spreads.iter())
                            .map(|(&px, &(up_spread, _))| px * (1.0 + up_spread))
                            .collect();
                        let low: Vec<f64> = close
                            .iter()
                            .zip(spreads.iter())
                            .map(|(&px, &(_, down_spread))| px * (1.0 - down_spread))
                            .collect();

                        (high, low)
                    })
            },
        );
        (series, Just(period))
    });

    proptest::test_runner::TestRunner::default()
        .run(&strat, |((high, low), period)| {
            let input = CorrelHlInput::from_slices(
                &high,
                &low,
                CorrelHlParams {
                    period: Some(period),
                },
            );

            let result = correl_hl_with_kernel(&input, kernel);
            let reference = correl_hl_with_kernel(&input, Kernel::Scalar);

            match (result, reference) {
                (Ok(output), Ok(ref_output)) => {
                    let out = &output.values;
                    let ref_out = &ref_output.values;

                    prop_assert_eq!(out.len(), high.len());

                    // Warm-up region must be NaN.
                    let warmup_len = period.saturating_sub(1).min(high.len());
                    for (i, v) in out.iter().enumerate().take(warmup_len) {
                        prop_assert!(
                            v.is_nan(),
                            "Expected NaN during warmup at index {}, got {}",
                            i,
                            v
                        );
                    }

                    // Kernel under test versus the scalar reference.
                    for (i, (&y, &r)) in out.iter().zip(ref_out.iter()).enumerate() {
                        if !(y.is_finite() && r.is_finite()) {
                            prop_assert_eq!(
                                y.to_bits(),
                                r.to_bits(),
                                "Special value mismatch at index {}: {} vs {}",
                                i,
                                y,
                                r
                            );
                            continue;
                        }

                        let ulp_diff = y.to_bits().abs_diff(r.to_bits());
                        prop_assert!(
                            (y - r).abs() <= 1e-9 || ulp_diff <= 4,
                            "Kernel mismatch at index {}: {} vs {} (ULP={})",
                            i,
                            y,
                            r,
                            ulp_diff
                        );
                    }

                    // A correlation has to stay within [-1, 1], up to rounding.
                    let tolerance = 1e-3;
                    for (i, &val) in out.iter().enumerate() {
                        if val.is_nan() {
                            continue;
                        }
                        prop_assert!(
                            val >= -1.0 - tolerance && val <= 1.0 + tolerance,
                            "Correlation at index {} out of range: {}",
                            i,
                            val
                        );
                    }
                }
                (Err(_), Err(_)) => {}
                (Ok(_), Err(e)) => {
                    prop_assert!(
                        false,
                        "Reference kernel failed but test kernel succeeded: {}",
                        e
                    );
                }
                (Err(e), Ok(_)) => {
                    prop_assert!(
                        false,
                        "Test kernel failed but reference kernel succeeded: {}",
                        e
                    );
                }
            }

            Ok(())
        })
        .unwrap();

    Ok(())
}
2329
2330 #[cfg(debug_assertions)]
2331 fn check_correl_hl_no_poison(
2332 test_name: &str,
2333 kernel: Kernel,
2334 ) -> Result<(), Box<dyn std::error::Error>> {
2335 skip_if_unsupported!(kernel, test_name);
2336
2337 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2338 let candles = read_candles_from_csv(file_path)?;
2339
2340 let input = CorrelHlInput::from_candles(&candles, CorrelHlParams::default());
2341 let output = correl_hl_with_kernel(&input, kernel)?;
2342
2343 for (i, &val) in output.values.iter().enumerate() {
2344 if val.is_nan() {
2345 continue;
2346 }
2347
2348 let bits = val.to_bits();
2349
2350 if bits == 0x11111111_11111111 {
2351 panic!(
2352 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {}",
2353 test_name, val, bits, i
2354 );
2355 }
2356
2357 if bits == 0x22222222_22222222 {
2358 panic!(
2359 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {}",
2360 test_name, val, bits, i
2361 );
2362 }
2363
2364 if bits == 0x33333333_33333333 {
2365 panic!(
2366 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {}",
2367 test_name, val, bits, i
2368 );
2369 }
2370 }
2371
2372 Ok(())
2373 }
2374
    /// Release-build stand-in for the poison-pattern check.
    ///
    /// Compiled only when `debug_assertions` is off. It ignores both arguments and reports
    /// success, so the per-kernel test generator still has a function of this name to call.
    // NOTE(review): presumably the poison fill patterns are only written in debug builds,
    // which would make a scan pointless here — confirm against the allocation helpers.
    #[cfg(not(debug_assertions))]
    fn check_correl_hl_no_poison(
        _test_name: &str,
        _kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        Ok(())
    }
2382
2383 macro_rules! generate_all_correl_hl_tests {
2384 ($($test_fn:ident),*) => {
2385 paste::paste! {
2386 $(
2387 #[test]
2388 fn [<$test_fn _scalar_f64>]() {
2389 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
2390 }
2391 )*
2392 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2393 $(
2394 #[test]
2395 fn [<$test_fn _avx2_f64>]() {
2396 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
2397 }
2398 #[test]
2399 fn [<$test_fn _avx512_f64>]() {
2400 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
2401 }
2402 )*
2403 }
2404 }
2405 }
2406
2407 #[test]
2408 fn test_period_one_bug() {
2409 let high = vec![100.0, 200.0, 300.0];
2410 let low = vec![90.0, 190.0, 310.0];
2411
2412 let params = CorrelHlParams { period: Some(1) };
2413 let input = CorrelHlInput::from_slices(&high, &low, params.clone());
2414 let result = correl_hl(&input).unwrap();
2415
2416 println!("Period=1 correlation with different high/low:");
2417 for (i, &val) in result.values.iter().enumerate() {
2418 println!(
2419 " Index {}: high={}, low={}, corr={}",
2420 i, high[i], low[i], val
2421 );
2422
2423 assert!(
2424 val.is_nan() || (val >= -1.0 && val <= 1.0),
2425 "Period=1 correlation at index {} out of bounds: {}",
2426 i,
2427 val
2428 );
2429 }
2430
2431 let high2 = vec![100.0, 200.0, 300.0];
2432 let low2 = vec![100.0, 200.0, 300.0];
2433
2434 let input2 = CorrelHlInput::from_slices(&high2, &low2, params.clone());
2435 let result2 = correl_hl(&input2).unwrap();
2436
2437 println!("\nPeriod=1 correlation with identical high/low:");
2438 for (i, &val) in result2.values.iter().enumerate() {
2439 println!(
2440 " Index {}: high={}, low={}, corr={}",
2441 i, high2[i], low2[i], val
2442 );
2443 }
2444 }
2445
    // Instantiate the per-kernel `#[test]` wrappers for every single-series check.
    generate_all_correl_hl_tests!(
        check_correl_hl_partial_params,
        check_correl_hl_accuracy,
        check_correl_hl_zero_period,
        check_correl_hl_period_exceeds_length,
        check_correl_hl_data_length_mismatch,
        check_correl_hl_all_nan,
        check_correl_hl_from_candles,
        check_correl_hl_reinput,
        check_correl_hl_very_small_dataset,
        check_correl_hl_empty_input,
        check_correl_hl_nan_handling,
        check_correl_hl_streaming,
        check_correl_hl_no_poison
    );

    // The property-based check only exists when the `proptest` feature is enabled, so its
    // wrappers are generated under the same gate.
    #[cfg(feature = "proptest")]
    generate_all_correl_hl_tests!(check_correl_hl_property);
2464
2465 fn check_batch_default_row(
2466 test: &str,
2467 kernel: Kernel,
2468 ) -> Result<(), Box<dyn std::error::Error>> {
2469 skip_if_unsupported!(kernel, test);
2470
2471 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2472 let c = read_candles_from_csv(file)?;
2473
2474 let output = CorrelHlBatchBuilder::new()
2475 .kernel(kernel)
2476 .apply_candles(&c)?;
2477
2478 let def = CorrelHlParams::default();
2479 let row = output.values_for(&def).expect("default row missing");
2480
2481 assert_eq!(row.len(), c.close.len());
2482 Ok(())
2483 }
2484
2485 #[cfg(debug_assertions)]
2486 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn std::error::Error>> {
2487 skip_if_unsupported!(kernel, test);
2488
2489 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2490 let c = read_candles_from_csv(file)?;
2491
2492 let output = CorrelHlBatchBuilder::new()
2493 .kernel(kernel)
2494 .period_range(5, 20, 5)
2495 .apply_candles(&c)?;
2496
2497 for (idx, &val) in output.values.iter().enumerate() {
2498 if val.is_nan() {
2499 continue;
2500 }
2501
2502 let bits = val.to_bits();
2503 let row = idx / output.cols;
2504 let col = idx % output.cols;
2505
2506 if bits == 0x11111111_11111111 {
2507 panic!(
2508 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at row {} col {} (flat index {})",
2509 test, val, bits, row, col, idx
2510 );
2511 }
2512
2513 if bits == 0x22222222_22222222 {
2514 panic!(
2515 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at row {} col {} (flat index {})",
2516 test, val, bits, row, col, idx
2517 );
2518 }
2519
2520 if bits == 0x33333333_33333333 {
2521 panic!(
2522 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at row {} col {} (flat index {})",
2523 test, val, bits, row, col, idx
2524 );
2525 }
2526 }
2527
2528 Ok(())
2529 }
2530
    /// Release-build stand-in for the batch poison-pattern check.
    ///
    /// Compiled only when `debug_assertions` is off. It ignores both arguments and reports
    /// success, so the batch test generator still has a function of this name to call.
    // NOTE(review): presumably the poison fill patterns are only written in debug builds,
    // which would make a scan pointless here — confirm against the allocation helpers.
    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(
        _test: &str,
        _kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        Ok(())
    }
2538
2539 macro_rules! gen_batch_tests {
2540 ($fn_name:ident) => {
2541 paste::paste! {
2542 #[test] fn [<$fn_name _scalar>]() {
2543 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
2544 }
2545 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2546 #[test] fn [<$fn_name _avx2>]() {
2547 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
2548 }
2549 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2550 #[test] fn [<$fn_name _avx512>]() {
2551 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
2552 }
2553 #[test] fn [<$fn_name _auto_detect>]() {
2554 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
2555 }
2556 }
2557 };
2558 }
    // Instantiate the batch-kernel `#[test]` wrappers for each batch check.
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
2561}