Skip to main content

vector_ta/indicators/
correl_hl.rs

1#[cfg(all(feature = "python", feature = "cuda"))]
2use numpy::PyUntypedArrayMethods;
3#[cfg(feature = "python")]
4use numpy::{IntoPyArray, PyArray1};
5#[cfg(feature = "python")]
6use pyo3::exceptions::PyValueError;
7#[cfg(feature = "python")]
8use pyo3::prelude::*;
9#[cfg(feature = "python")]
10use pyo3::types::PyDict;
11
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use serde::{Deserialize, Serialize};
14#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
15use wasm_bindgen::prelude::*;
16
17use crate::utilities::data_loader::Candles;
18use crate::utilities::enums::Kernel;
19use crate::utilities::helpers::{
20    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
21    make_uninit_matrix,
22};
23#[cfg(feature = "python")]
24use crate::utilities::kernel_validation::validate_kernel;
25#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
26use core::arch::x86_64::*;
27#[cfg(not(target_arch = "wasm32"))]
28use rayon::prelude::*;
29use std::mem::{ManuallyDrop, MaybeUninit};
30#[cfg(all(feature = "python", feature = "cuda"))]
31use std::sync::Arc;
32use thiserror::Error;
33
34#[cfg(all(feature = "python", feature = "cuda"))]
35use crate::cuda::moving_averages::DeviceArrayF32;
36#[cfg(all(feature = "python", feature = "cuda"))]
37use cust::context::Context;
38#[cfg(all(feature = "python", feature = "cuda"))]
39use cust::memory::DeviceBuffer;
40
41#[derive(Debug, Clone)]
42pub enum CorrelHlData<'a> {
43    Candles { candles: &'a Candles },
44    Slices { high: &'a [f64], low: &'a [f64] },
45}
46
/// Result of a single-period run: one value per input bar.
#[derive(Debug, Clone)]
pub struct CorrelHlOutput {
    /// Rolling correlation; entries before the warm-up point are NaN.
    pub values: Vec<f64>,
}
51
/// Tunable parameters of CORREL_HL.
#[derive(Debug, Clone)]
pub struct CorrelHlParams {
    /// Window length in bars; `None` falls back to 9 wherever it is read.
    pub period: Option<usize>,
}

impl Default for CorrelHlParams {
    /// A 9-bar window.
    fn default() -> Self {
        CorrelHlParams { period: Some(9) }
    }
}
62
63#[derive(Debug, Clone)]
64pub struct CorrelHlInput<'a> {
65    pub data: CorrelHlData<'a>,
66    pub params: CorrelHlParams,
67}
68
69impl<'a> CorrelHlInput<'a> {
70    #[inline]
71    pub fn from_candles(candles: &'a Candles, params: CorrelHlParams) -> Self {
72        Self {
73            data: CorrelHlData::Candles { candles },
74            params,
75        }
76    }
77
78    #[inline]
79    pub fn from_slices(high: &'a [f64], low: &'a [f64], params: CorrelHlParams) -> Self {
80        Self {
81            data: CorrelHlData::Slices { high, low },
82            params,
83        }
84    }
85
86    #[inline]
87    pub fn with_default_candles(candles: &'a Candles) -> Self {
88        Self::from_candles(candles, CorrelHlParams::default())
89    }
90
91    #[inline]
92    pub fn get_period(&self) -> usize {
93        self.params.period.unwrap_or(9)
94    }
95
96    #[inline(always)]
97    pub fn as_refs(&'a self) -> Result<(&'a [f64], &'a [f64]), CorrelHlError> {
98        match &self.data {
99            CorrelHlData::Candles { candles } => {
100                let hi = candles
101                    .select_candle_field("high")
102                    .map_err(|_| CorrelHlError::CandleFieldError { field: "high" })?;
103                let lo = candles
104                    .select_candle_field("low")
105                    .map_err(|_| CorrelHlError::CandleFieldError { field: "low" })?;
106                Ok((hi, lo))
107            }
108            CorrelHlData::Slices { high, low } => Ok((*high, *low)),
109        }
110    }
111
112    #[inline(always)]
113    pub fn period_or_default(&self) -> usize {
114        self.params.period.unwrap_or(9)
115    }
116}
117
118#[derive(Copy, Clone, Debug)]
119pub struct CorrelHlBuilder {
120    period: Option<usize>,
121    kernel: Kernel,
122}
123
124impl Default for CorrelHlBuilder {
125    fn default() -> Self {
126        Self {
127            period: None,
128            kernel: Kernel::Auto,
129        }
130    }
131}
132
133impl CorrelHlBuilder {
134    #[inline(always)]
135    pub fn new() -> Self {
136        Self::default()
137    }
138    #[inline(always)]
139    pub fn period(mut self, n: usize) -> Self {
140        self.period = Some(n);
141        self
142    }
143    #[inline(always)]
144    pub fn kernel(mut self, k: Kernel) -> Self {
145        self.kernel = k;
146        self
147    }
148
149    #[inline(always)]
150    pub fn apply(self, candles: &Candles) -> Result<CorrelHlOutput, CorrelHlError> {
151        let params = CorrelHlParams {
152            period: self.period,
153        };
154        let input = CorrelHlInput::from_candles(candles, params);
155        correl_hl_with_kernel(&input, self.kernel)
156    }
157
158    #[inline(always)]
159    pub fn apply_slices(self, high: &[f64], low: &[f64]) -> Result<CorrelHlOutput, CorrelHlError> {
160        let params = CorrelHlParams {
161            period: self.period,
162        };
163        let input = CorrelHlInput::from_slices(high, low, params);
164        correl_hl_with_kernel(&input, self.kernel)
165    }
166
167    #[inline(always)]
168    pub fn into_stream(self) -> Result<CorrelHlStream, CorrelHlError> {
169        let params = CorrelHlParams {
170            period: self.period,
171        };
172        CorrelHlStream::try_new(params)
173    }
174}
175
176#[derive(Debug, Error)]
177pub enum CorrelHlError {
178    #[error("correl_hl: Empty data (high or low).")]
179    EmptyInputData,
180    #[error("correl_hl: Invalid period: period = {period}, data length = {data_len}")]
181    InvalidPeriod { period: usize, data_len: usize },
182    #[error("correl_hl: Data length mismatch between high and low.")]
183    DataLengthMismatch,
184    #[error("correl_hl: Not enough valid data: needed = {needed}, valid = {valid}")]
185    NotEnoughValidData { needed: usize, valid: usize },
186    #[error("correl_hl: All values are NaN in high or low.")]
187    AllValuesNaN,
188    #[error("correl_hl: Candle field error: {field}")]
189    CandleFieldError { field: &'static str },
190    #[error("correl_hl: Output length mismatch (expected {expected}, got {got})")]
191    OutputLengthMismatch { expected: usize, got: usize },
192    #[error("correl_hl: invalid input: {0}")]
193    InvalidInput(&'static str),
194
195    #[error("correl_hl: invalid range: start={start} end={end} step={step}")]
196    InvalidRange {
197        start: usize,
198        end: usize,
199        step: usize,
200    },
201    #[error("correl_hl: invalid kernel for batch path: {0:?}")]
202    InvalidKernelForBatch(Kernel),
203}
204
205#[inline]
206pub fn correl_hl(input: &CorrelHlInput) -> Result<CorrelHlOutput, CorrelHlError> {
207    correl_hl_with_kernel(input, Kernel::Auto)
208}
209
210#[inline(always)]
211fn correl_hl_prepare<'a>(
212    input: &'a CorrelHlInput,
213    kernel: Kernel,
214) -> Result<(&'a [f64], &'a [f64], usize, usize, Kernel), CorrelHlError> {
215    let (high, low) = input.as_refs()?;
216    if high.is_empty() || low.is_empty() {
217        return Err(CorrelHlError::EmptyInputData);
218    }
219    if high.len() != low.len() {
220        return Err(CorrelHlError::DataLengthMismatch);
221    }
222
223    let period = input.period_or_default();
224    if period == 0 || period > high.len() {
225        return Err(CorrelHlError::InvalidPeriod {
226            period,
227            data_len: high.len(),
228        });
229    }
230
231    let first = high
232        .iter()
233        .zip(low.iter())
234        .position(|(&h, &l)| !h.is_nan() && !l.is_nan())
235        .ok_or(CorrelHlError::AllValuesNaN)?;
236
237    if high.len() - first < period {
238        return Err(CorrelHlError::NotEnoughValidData {
239            needed: period,
240            valid: high.len() - first,
241        });
242    }
243
244    let chosen = match kernel {
245        Kernel::Auto => detect_best_kernel(),
246        k => k,
247    };
248    Ok((high, low, period, first, chosen))
249}
250
251#[inline(always)]
252fn correl_hl_compute_into(
253    high: &[f64],
254    low: &[f64],
255    period: usize,
256    first: usize,
257    kern: Kernel,
258    out: &mut [f64],
259) {
260    unsafe {
261        match kern {
262            Kernel::Scalar | Kernel::ScalarBatch => correl_hl_scalar(high, low, period, first, out),
263            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
264            Kernel::Avx2 | Kernel::Avx2Batch => correl_hl_avx2(high, low, period, first, out),
265            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
266            Kernel::Avx512 | Kernel::Avx512Batch => correl_hl_avx512(high, low, period, first, out),
267            _ => correl_hl_scalar(high, low, period, first, out),
268        }
269    }
270}
271
272pub fn correl_hl_with_kernel(
273    input: &CorrelHlInput,
274    kernel: Kernel,
275) -> Result<CorrelHlOutput, CorrelHlError> {
276    let (high, low, period, first, chosen) = correl_hl_prepare(input, kernel)?;
277    let warm = first + period - 1;
278    let mut out = alloc_with_nan_prefix(high.len(), warm);
279    correl_hl_compute_into(high, low, period, first, chosen, &mut out);
280    Ok(CorrelHlOutput { values: out })
281}
282
283#[inline]
284pub fn correl_hl_into_slice(
285    dst: &mut [f64],
286    input: &CorrelHlInput,
287    kernel: Kernel,
288) -> Result<(), CorrelHlError> {
289    let (high, low, period, first, chosen) = correl_hl_prepare(input, kernel)?;
290    if dst.len() != high.len() {
291        return Err(CorrelHlError::OutputLengthMismatch {
292            expected: high.len(),
293            got: dst.len(),
294        });
295    }
296    correl_hl_compute_into(high, low, period, first, chosen, dst);
297    let warm = first + period - 1;
298    for v in &mut dst[..warm] {
299        *v = f64::from_bits(0x7ff8_0000_0000_0000);
300    }
301    Ok(())
302}
303
304#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
305#[inline]
306pub fn correl_hl_into(out: &mut [f64], input: &CorrelHlInput) -> Result<(), CorrelHlError> {
307    let (high, _low, period, first, _chosen) = correl_hl_prepare(input, Kernel::Auto)?;
308    if out.len() != high.len() {
309        return Err(CorrelHlError::OutputLengthMismatch {
310            expected: high.len(),
311            got: out.len(),
312        });
313    }
314
315    let warm = first + period - 1;
316    let warm_cap = warm.min(out.len());
317    for v in &mut out[..warm_cap] {
318        *v = f64::from_bits(0x7ff8_0000_0000_0000);
319    }
320
321    correl_hl_into_slice(out, input, Kernel::Auto)
322}
323
/// Scalar CORREL_HL kernel: rolling Pearson correlation of `high` vs `low` over `period`.
///
/// Writes `out[first + period - 1..high.len()]`. Earlier entries are left untouched;
/// callers pre-fill them.
///
/// Per-window results:
/// - a window containing NaN evaluates to NaN;
/// - otherwise, a non-positive variance in either series gives `0.0`.
///
/// Window sums are updated incrementally. Whenever a non-finite value (NaN or ±inf)
/// enters or leaves the window, the sums are rebuilt from scratch. Such values therefore
/// only affect the windows that actually contain them.
///
/// Expects `period >= 1`, `first + period <= high.len()`, and `low`/`out` at least as
/// long as `high`. Out-of-range indices panic.
#[inline]
pub fn correl_hl_scalar(high: &[f64], low: &[f64], period: usize, first: usize, out: &mut [f64]) {
    /// Returns `(Σh, Σh², Σl, Σl², Σh·l)` over `start..end`, unrolled by four.
    #[inline(always)]
    fn window_sums(high: &[f64], low: &[f64], start: usize, end: usize) -> (f64, f64, f64, f64, f64) {
        let mut sum_h = 0.0_f64;
        let mut sum_h2 = 0.0_f64;
        let mut sum_l = 0.0_f64;
        let mut sum_l2 = 0.0_f64;
        let mut sum_hl = 0.0_f64;

        let mut k = start;
        while k + 4 <= end {
            let (h0, h1, h2, h3) = (high[k], high[k + 1], high[k + 2], high[k + 3]);
            let (l0, l1, l2, l3) = (low[k], low[k + 1], low[k + 2], low[k + 3]);
            sum_h += h0 + h1 + h2 + h3;
            sum_l += l0 + l1 + l2 + l3;
            sum_h2 += h0 * h0 + h1 * h1 + h2 * h2 + h3 * h3;
            sum_l2 += l0 * l0 + l1 * l1 + l2 * l2 + l3 * l3;
            sum_hl += h0 * l0 + h1 * l1 + h2 * l2 + h3 * l3;
            k += 4;
        }
        while k < end {
            let h = high[k];
            let l = low[k];
            sum_h += h;
            sum_l += l;
            sum_h2 += h * h;
            sum_l2 += l * l;
            sum_hl += h * l;
            k += 1;
        }
        (sum_h, sum_h2, sum_l, sum_l2, sum_hl)
    }

    /// Pearson correlation from raw window sums (`inv_pf` = 1 / period).
    #[inline(always)]
    fn corr_from_sums(
        sum_h: f64,
        sum_h2: f64,
        sum_l: f64,
        sum_l2: f64,
        sum_hl: f64,
        inv_pf: f64,
    ) -> f64 {
        let cov = sum_hl - (sum_h * sum_l) * inv_pf;
        let var_h = sum_h2 - (sum_h * sum_h) * inv_pf;
        let var_l = sum_l2 - (sum_l * sum_l) * inv_pf;
        if cov.is_nan() || var_h.is_nan() || var_l.is_nan() {
            // Missing data in the window. Without this check, a flat companion series
            // would turn the window into a spurious 0.0.
            f64::NAN
        } else if var_h <= 0.0 || var_l <= 0.0 {
            0.0
        } else {
            cov / (var_h.sqrt() * var_l.sqrt())
        }
    }

    let inv_pf = 1.0 / (period as f64);
    let init_end = first + period;

    let (mut sum_h, mut sum_h2, mut sum_l, mut sum_l2, mut sum_hl) =
        window_sums(high, low, first, init_end);
    out[init_end - 1] = corr_from_sums(sum_h, sum_h2, sum_l, sum_l2, sum_hl, inv_pf);

    for i in init_end..high.len() {
        let old_h = high[i - period];
        let old_l = low[i - period];
        let new_h = high[i];
        let new_l = low[i];

        if old_h.is_finite() && old_l.is_finite() && new_h.is_finite() && new_l.is_finite() {
            sum_h += new_h - old_h;
            sum_l += new_l - old_l;
            sum_h2 += new_h * new_h - old_h * old_h;
            sum_l2 += new_l * new_l - old_l * old_l;
            let old_hl = old_h * old_l;
            sum_hl = new_h.mul_add(new_l, sum_hl - old_hl);
        } else {
            // NaN never cancels, and removing ±inf computes inf - inf = NaN. Either
            // would poison every later window, so rebuild this window's sums instead.
            (sum_h, sum_h2, sum_l, sum_l2, sum_hl) = window_sums(high, low, i + 1 - period, i + 1);
        }

        out[i] = corr_from_sums(sum_h, sum_h2, sum_l, sum_l2, sum_hl, inv_pf);
    }
}
444
/// AVX2 CORREL_HL kernel. It has the same output contract as `correl_hl_scalar`.
///
/// Window sums are updated incrementally and rebuilt with AVX whenever a non-finite
/// value (NaN or ±inf) enters or leaves the window. A window containing NaN yields NaN.
///
/// # Panics
/// Panics when `period == 0`, when `first + period > high.len()`, or when `low`/`out` are
/// shorter than `high`. These checks guard the unchecked SIMD loads below.
///
/// NOTE(review): this safe fn executes AVX instructions without checking CPU support.
/// It relies on callers selecting it only after runtime feature detection.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn correl_hl_avx2(high: &[f64], low: &[f64], period: usize, first: usize, out: &mut [f64]) {
    /// Horizontal sum of the four lanes of `v`.
    #[inline(always)]
    unsafe fn hsum256_pd(v: __m256d) -> f64 {
        let hi: __m128d = _mm256_extractf128_pd(v, 1);
        let lo: __m128d = _mm256_castpd256_pd128(v);
        let sum2 = _mm_add_pd(lo, hi);
        let shuf = _mm_unpackhi_pd(sum2, sum2);
        let sum1 = _mm_add_sd(sum2, shuf);
        _mm_cvtsd_f64(sum1)
    }

    /// Returns `(Σh, Σh², Σl, Σl², Σh·l)` over `start..end`, four lanes at a time.
    ///
    /// # Safety
    /// `start <= end <= min(high.len(), low.len())`, and the CPU must support AVX.
    #[inline(always)]
    unsafe fn sum_window_avx2(
        high: &[f64],
        low: &[f64],
        start: usize,
        end: usize,
    ) -> (f64, f64, f64, f64, f64) {
        let mut v_h = _mm256_setzero_pd();
        let mut v_l = _mm256_setzero_pd();
        let mut v_h2 = _mm256_setzero_pd();
        let mut v_l2 = _mm256_setzero_pd();
        let mut v_hl = _mm256_setzero_pd();

        let mut i = start;
        let ptr_h = high.as_ptr();
        let ptr_l = low.as_ptr();

        while i + 4 <= end {
            let mh = _mm256_loadu_pd(ptr_h.add(i));
            let ml = _mm256_loadu_pd(ptr_l.add(i));
            v_h = _mm256_add_pd(v_h, mh);
            v_l = _mm256_add_pd(v_l, ml);
            v_h2 = _mm256_add_pd(v_h2, _mm256_mul_pd(mh, mh));
            v_l2 = _mm256_add_pd(v_l2, _mm256_mul_pd(ml, ml));
            v_hl = _mm256_add_pd(v_hl, _mm256_mul_pd(mh, ml));
            i += 4;
        }

        let mut sum_h = hsum256_pd(v_h);
        let mut sum_l = hsum256_pd(v_l);
        let mut sum_h2 = hsum256_pd(v_h2);
        let mut sum_l2 = hsum256_pd(v_l2);
        let mut sum_hl = hsum256_pd(v_hl);

        while i < end {
            let h = *high.get_unchecked(i);
            let l = *low.get_unchecked(i);
            sum_h += h;
            sum_l += l;
            sum_h2 += h * h;
            sum_l2 += l * l;
            sum_hl += h * l;
            i += 1;
        }
        (sum_h, sum_h2, sum_l, sum_l2, sum_hl)
    }

    /// Pearson correlation from raw window sums (`inv_pf` = 1 / period).
    #[inline(always)]
    fn corr_from_sums(
        sum_h: f64,
        sum_h2: f64,
        sum_l: f64,
        sum_l2: f64,
        sum_hl: f64,
        inv_pf: f64,
    ) -> f64 {
        let cov = sum_hl - (sum_h * sum_l) * inv_pf;
        let varh = sum_h2 - (sum_h * sum_h) * inv_pf;
        let varl = sum_l2 - (sum_l * sum_l) * inv_pf;
        if cov.is_nan() || varh.is_nan() || varl.is_nan() {
            // Missing data in the window: report NaN, never a spurious 0.0.
            f64::NAN
        } else if varh <= 0.0 || varl <= 0.0 {
            0.0
        } else {
            cov / (varh.sqrt() * varl.sqrt())
        }
    }

    let n = high.len();
    // The SIMD helper reads through raw pointers. Validate the geometry once so this
    // safe `pub fn` cannot be driven out of bounds.
    assert!(
        period >= 1 && first <= n && period <= n - first,
        "correl_hl_avx2: window exceeds input length"
    );
    assert!(
        low.len() >= n && out.len() >= n,
        "correl_hl_avx2: `low`/`out` shorter than `high`"
    );

    let inv_pf = 1.0 / (period as f64);
    let init_end = first + period;

    // SAFETY: `first..init_end` lies within both slices by the asserts above.
    let (mut sum_h, mut sum_h2, mut sum_l, mut sum_l2, mut sum_hl) =
        unsafe { sum_window_avx2(high, low, first, init_end) };
    out[init_end - 1] = corr_from_sums(sum_h, sum_h2, sum_l, sum_l2, sum_hl, inv_pf);

    for i in init_end..n {
        let old_h = high[i - period];
        let old_l = low[i - period];
        let new_h = high[i];
        let new_l = low[i];

        if old_h.is_finite() && old_l.is_finite() && new_h.is_finite() && new_l.is_finite() {
            sum_h += new_h - old_h;
            sum_l += new_l - old_l;
            sum_h2 += new_h * new_h - old_h * old_h;
            sum_l2 += new_l * new_l - old_l * old_l;
            let old_hl = old_h * old_l;
            sum_hl = new_h.mul_add(new_l, sum_hl - old_hl);
        } else {
            // NaN never cancels and removing ±inf yields inf - inf = NaN; rebuild instead.
            // SAFETY: `i + 1 - period >= first` and `i + 1 <= n`, so in bounds (asserts above).
            (sum_h, sum_h2, sum_l, sum_l2, sum_hl) =
                unsafe { sum_window_avx2(high, low, i + 1 - period, i + 1) };
        }

        out[i] = corr_from_sums(sum_h, sum_h2, sum_l, sum_l2, sum_hl, inv_pf);
    }
}
567
/// AVX-512 CORREL_HL kernel entry point. It has the same output contract as `correl_hl_scalar`.
///
/// # Panics
/// Panics when `period == 0`, when `first + period > high.len()`, or when `low`/`out` are
/// shorter than `high`. These checks establish the contract of the unsafe kernels it calls.
///
/// NOTE(review): AVX-512 support is not checked here. Callers are presumed to select this
/// kernel only after runtime detection.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn correl_hl_avx512(high: &[f64], low: &[f64], period: usize, first: usize, out: &mut [f64]) {
    let n = high.len();
    assert!(
        period >= 1 && first <= n && period <= n - first,
        "correl_hl_avx512: window exceeds input length"
    );
    assert!(
        low.len() >= n && out.len() >= n,
        "correl_hl_avx512: `low`/`out` shorter than `high`"
    );
    // SAFETY: the slice-geometry requirements of both kernels are asserted above.
    unsafe {
        if period <= 32 {
            correl_hl_avx512_short(high, low, period, first, out)
        } else {
            correl_hl_avx512_long(high, low, period, first, out)
        }
    }
}

/// Short-window AVX-512 path. It currently forwards to [`correl_hl_avx512_long`].
///
/// # Safety
/// Same contract as [`correl_hl_avx512_long`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn correl_hl_avx512_short(
    high: &[f64],
    low: &[f64],
    period: usize,
    first: usize,
    out: &mut [f64],
) {
    correl_hl_avx512_long(high, low, period, first, out)
}
589
/// AVX-512 CORREL_HL kernel. It has the same output contract as `correl_hl_scalar`.
///
/// Window sums use 8-lane accumulators with a masked tail load. They are updated
/// incrementally and rebuilt whenever a non-finite value (NaN or ±inf) enters or leaves
/// the window. A window containing NaN yields NaN.
///
/// # Safety
/// - `period >= 1` and `first + period <= high.len()`.
/// - `low` and `out` are at least `high.len()` long; input reads are unchecked.
/// - The CPU supports the AVX-512F (and AVX) instructions used here.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn correl_hl_avx512_long(
    high: &[f64],
    low: &[f64],
    period: usize,
    first: usize,
    out: &mut [f64],
) {
    /// Horizontal sum of the four lanes of `v`.
    #[inline(always)]
    unsafe fn hsum256_pd(v: __m256d) -> f64 {
        let hi: __m128d = _mm256_extractf128_pd(v, 1);
        let lo: __m128d = _mm256_castpd256_pd128(v);
        let sum2 = _mm_add_pd(lo, hi);
        let shuf = _mm_unpackhi_pd(sum2, sum2);
        let sum1 = _mm_add_sd(sum2, shuf);
        _mm_cvtsd_f64(sum1)
    }

    /// Horizontal sum of the eight lanes of `v`.
    #[inline(always)]
    unsafe fn hsum512_pd(v: __m512d) -> f64 {
        let lo256: __m256d = _mm512_castpd512_pd256(v);
        let hi256: __m256d = _mm512_extractf64x4_pd(v, 1);
        hsum256_pd(_mm256_add_pd(lo256, hi256))
    }

    /// Returns `(Σh, Σh², Σl, Σl², Σh·l)` over `start..end`.
    ///
    /// The tail uses a masked load, so no lane past `end` is read.
    ///
    /// # Safety
    /// `start <= end <= min(high.len(), low.len())`, and AVX-512F is available.
    #[inline(always)]
    unsafe fn sum_window_avx512(
        high: &[f64],
        low: &[f64],
        start: usize,
        end: usize,
    ) -> (f64, f64, f64, f64, f64) {
        let mut v_h = _mm512_setzero_pd();
        let mut v_l = _mm512_setzero_pd();
        let mut v_h2 = _mm512_setzero_pd();
        let mut v_l2 = _mm512_setzero_pd();
        let mut v_hl = _mm512_setzero_pd();

        let ptr_h = high.as_ptr();
        let ptr_l = low.as_ptr();

        let mut i = start;
        while i + 8 <= end {
            let mh = _mm512_loadu_pd(ptr_h.add(i));
            let ml = _mm512_loadu_pd(ptr_l.add(i));
            v_h = _mm512_add_pd(v_h, mh);
            v_l = _mm512_add_pd(v_l, ml);
            v_h2 = _mm512_add_pd(v_h2, _mm512_mul_pd(mh, mh));
            v_l2 = _mm512_add_pd(v_l2, _mm512_mul_pd(ml, ml));
            v_hl = _mm512_add_pd(v_hl, _mm512_mul_pd(mh, ml));
            i += 8;
        }

        // 0..=7 remaining elements; masked-off lanes load as zero.
        let rem = (end - i) as i32;
        if rem != 0 {
            let mask: __mmask8 = ((1u16 << rem) - 1) as __mmask8;
            let mh = _mm512_maskz_loadu_pd(mask, ptr_h.add(i));
            let ml = _mm512_maskz_loadu_pd(mask, ptr_l.add(i));
            v_h = _mm512_add_pd(v_h, mh);
            v_l = _mm512_add_pd(v_l, ml);
            v_h2 = _mm512_add_pd(v_h2, _mm512_mul_pd(mh, mh));
            v_l2 = _mm512_add_pd(v_l2, _mm512_mul_pd(ml, ml));
            v_hl = _mm512_add_pd(v_hl, _mm512_mul_pd(mh, ml));
        }

        (
            hsum512_pd(v_h),
            hsum512_pd(v_h2),
            hsum512_pd(v_l),
            hsum512_pd(v_l2),
            hsum512_pd(v_hl),
        )
    }

    /// Pearson correlation from raw window sums (`inv_pf` = 1 / period).
    #[inline(always)]
    fn corr_from_sums(
        sum_h: f64,
        sum_h2: f64,
        sum_l: f64,
        sum_l2: f64,
        sum_hl: f64,
        inv_pf: f64,
    ) -> f64 {
        let cov = sum_hl - (sum_h * sum_l) * inv_pf;
        let varh = sum_h2 - (sum_h * sum_h) * inv_pf;
        let varl = sum_l2 - (sum_l * sum_l) * inv_pf;
        if cov.is_nan() || varh.is_nan() || varl.is_nan() {
            // Missing data in the window: report NaN, never a spurious 0.0.
            f64::NAN
        } else if varh <= 0.0 || varl <= 0.0 {
            0.0
        } else {
            cov / (varh.sqrt() * varl.sqrt())
        }
    }

    let n = high.len();
    debug_assert!(period >= 1 && first + period <= n);
    debug_assert!(low.len() >= n && out.len() >= n);

    let inv_pf = 1.0 / (period as f64);
    let init_end = first + period;

    let (mut sum_h, mut sum_h2, mut sum_l, mut sum_l2, mut sum_hl) =
        sum_window_avx512(high, low, first, init_end);
    out[init_end - 1] = corr_from_sums(sum_h, sum_h2, sum_l, sum_l2, sum_hl, inv_pf);

    for i in init_end..n {
        // SAFETY (caller contract): `i < n <= low.len()` and `i - period >= first`.
        let old_h = *high.get_unchecked(i - period);
        let old_l = *low.get_unchecked(i - period);
        let new_h = *high.get_unchecked(i);
        let new_l = *low.get_unchecked(i);

        if old_h.is_finite() && old_l.is_finite() && new_h.is_finite() && new_l.is_finite() {
            sum_h += new_h - old_h;
            sum_l += new_l - old_l;
            sum_h2 += new_h * new_h - old_h * old_h;
            sum_l2 += new_l * new_l - old_l * old_l;
            let old_hl = old_h * old_l;
            sum_hl = new_h.mul_add(new_l, sum_hl - old_hl);
        } else {
            // NaN never cancels and removing ±inf yields inf - inf = NaN; rebuild instead.
            (sum_h, sum_h2, sum_l, sum_l2, sum_hl) =
                sum_window_avx512(high, low, i + 1 - period, i + 1);
        }

        out[i] = corr_from_sums(sum_h, sum_h2, sum_l, sum_l2, sum_hl, inv_pf);
    }
}
725
/// Incremental CORREL_HL over a ring buffer holding the most recent `period` bars.
#[derive(Debug, Clone)]
pub struct CorrelHlStream {
    /// Window length.
    period: usize,
    /// Ring buffer of recent highs (NaN until written).
    buffer_high: Vec<f64>,
    /// Ring buffer of recent lows (NaN until written).
    buffer_low: Vec<f64>,
    /// Slot the next bar overwrites.
    head: usize,
    /// Bars currently buffered; stops growing at `period`.
    len: usize,
    /// Bars in the window that were left out of the running sums (missing data).
    nan_in_win: usize,

    /// Running Σh over the bars included in the sums.
    sum_h: f64,
    /// Running Σh².
    sum_h2: f64,
    /// Running Σl.
    sum_l: f64,
    /// Running Σl².
    sum_l2: f64,
    /// Running Σh·l.
    sum_hl: f64,

    /// Cached `1 / period`.
    inv_pf: f64,
}
743
744impl CorrelHlStream {
745    #[inline]
746    pub fn try_new(params: CorrelHlParams) -> Result<Self, CorrelHlError> {
747        let period = params.period.unwrap_or(9);
748        if period == 0 {
749            return Err(CorrelHlError::InvalidPeriod {
750                period,
751                data_len: 0,
752            });
753        }
754
755        Ok(Self {
756            period,
757            buffer_high: vec![f64::NAN; period],
758            buffer_low: vec![f64::NAN; period],
759            head: 0,
760            len: 0,
761            nan_in_win: 0,
762            sum_h: 0.0,
763            sum_h2: 0.0,
764            sum_l: 0.0,
765            sum_l2: 0.0,
766            sum_hl: 0.0,
767            inv_pf: 1.0 / (period as f64),
768        })
769    }
770
771    #[inline(always)]
772    pub fn update(&mut self, h: f64, l: f64) -> Option<f64> {
773        if self.len == self.period {
774            let old_h = self.buffer_high[self.head];
775            let old_l = self.buffer_low[self.head];
776
777            if old_h.is_nan() || old_l.is_nan() {
778                if self.nan_in_win > 0 {
779                    self.nan_in_win -= 1;
780                }
781            } else {
782                self.sum_h -= old_h;
783                self.sum_l -= old_l;
784                self.sum_h2 -= old_h * old_h;
785                self.sum_l2 -= old_l * old_l;
786                self.sum_hl -= old_h * old_l;
787            }
788        }
789
790        self.buffer_high[self.head] = h;
791        self.buffer_low[self.head] = l;
792
793        if h.is_nan() || l.is_nan() {
794            self.nan_in_win += 1;
795        } else {
796            self.sum_h += h;
797            self.sum_l += l;
798            self.sum_h2 += h * h;
799            self.sum_l2 += l * l;
800            self.sum_hl += h * l;
801        }
802
803        self.head += 1;
804        if self.head == self.period {
805            self.head = 0;
806        }
807        if self.len < self.period {
808            self.len += 1;
809        }
810
811        if self.len < self.period {
812            return None;
813        }
814        if self.nan_in_win != 0 {
815            return Some(f64::NAN);
816        }
817
818        let cov = self.sum_hl - (self.sum_h * self.sum_l) * self.inv_pf;
819        let var_h = self.sum_h2 - (self.sum_h * self.sum_h) * self.inv_pf;
820        let var_l = self.sum_l2 - (self.sum_l * self.sum_l) * self.inv_pf;
821
822        if var_h <= 0.0 || var_l <= 0.0 {
823            return Some(0.0);
824        }
825
826        let denom = (var_h * var_l).sqrt();
827        Some(cov / denom)
828    }
829}
830
/// Period sweep for batch runs, as an inclusive `(start, end, step)` triple.
#[derive(Clone, Debug)]
pub struct CorrelHlBatchRange {
    pub period: (usize, usize, usize),
}

impl Default for CorrelHlBatchRange {
    /// Periods 9 through 258 in steps of 1.
    fn default() -> Self {
        let (start, end, step) = (9, 258, 1);
        CorrelHlBatchRange {
            period: (start, end, step),
        }
    }
}
843
/// Batch configuration accepted from JavaScript (wasm builds only).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct CorrelHlBatchConfig {
    /// Inclusive `(start, end, step)` period sweep.
    pub period_range: (usize, usize, usize),
}

/// Batch result handed back to JavaScript (wasm builds only).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct CorrelHlBatchJsOutput {
    /// Flattened `rows * cols` values.
    pub values: Vec<f64>,
    /// Period used for each row.
    pub periods: Vec<usize>,
    /// Number of rows (one per period).
    pub rows: usize,
    /// Number of columns (input length).
    pub cols: usize,
}
858
859#[derive(Clone, Debug, Default)]
860pub struct CorrelHlBatchBuilder {
861    range: CorrelHlBatchRange,
862    kernel: Kernel,
863}
864
865impl CorrelHlBatchBuilder {
866    pub fn new() -> Self {
867        Self::default()
868    }
869    pub fn kernel(mut self, k: Kernel) -> Self {
870        self.kernel = k;
871        self
872    }
873
874    #[inline]
875    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
876        self.range.period = (start, end, step);
877        self
878    }
879    #[inline]
880    pub fn period_static(mut self, p: usize) -> Self {
881        self.range.period = (p, p, 0);
882        self
883    }
884
885    pub fn apply_slices(
886        self,
887        high: &[f64],
888        low: &[f64],
889    ) -> Result<CorrelHlBatchOutput, CorrelHlError> {
890        correl_hl_batch_with_kernel(high, low, &self.range, self.kernel)
891    }
892
893    pub fn apply_candles(self, c: &Candles) -> Result<CorrelHlBatchOutput, CorrelHlError> {
894        let high = c
895            .select_candle_field("high")
896            .map_err(|_| CorrelHlError::EmptyInputData)?;
897        let low = c
898            .select_candle_field("low")
899            .map_err(|_| CorrelHlError::EmptyInputData)?;
900        self.apply_slices(high, low)
901    }
902}
903
904pub fn expand_grid(r: &CorrelHlBatchRange) -> Result<Vec<CorrelHlParams>, CorrelHlError> {
905    fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, CorrelHlError> {
906        if step == 0 || start == end {
907            return Ok(vec![start]);
908        }
909        if start < end {
910            let mut v = Vec::new();
911            let mut x = start;
912            while x <= end {
913                v.push(x);
914                match x.checked_add(step) {
915                    Some(nx) if nx > x => x = nx,
916                    _ => break,
917                }
918            }
919            if v.is_empty() {
920                return Err(CorrelHlError::InvalidRange { start, end, step });
921            }
922            Ok(v)
923        } else {
924            let mut v = Vec::new();
925            let mut x = start;
926            while x >= end {
927                v.push(x);
928                if x < end + step {
929                    break;
930                }
931                x = x.saturating_sub(step);
932                if x == 0 {
933                    break;
934                }
935            }
936            if v.is_empty() {
937                return Err(CorrelHlError::InvalidRange { start, end, step });
938            }
939            Ok(v)
940        }
941    }
942    let periods = axis_usize(r.period)?;
943    if periods.is_empty() {
944        return Err(CorrelHlError::InvalidRange {
945            start: r.period.0,
946            end: r.period.1,
947            step: r.period.2,
948        });
949    }
950    let mut out = Vec::with_capacity(periods.len());
951    for &p in &periods {
952        out.push(CorrelHlParams { period: Some(p) });
953    }
954    Ok(out)
955}
956
957#[derive(Clone, Debug)]
958pub struct CorrelHlBatchOutput {
959    pub values: Vec<f64>,
960    pub combos: Vec<CorrelHlParams>,
961    pub rows: usize,
962    pub cols: usize,
963}
964
965impl CorrelHlBatchOutput {
966    pub fn row_for_params(&self, p: &CorrelHlParams) -> Option<usize> {
967        self.combos
968            .iter()
969            .position(|c| c.period.unwrap_or(9) == p.period.unwrap_or(9))
970    }
971
972    pub fn values_for(&self, p: &CorrelHlParams) -> Option<&[f64]> {
973        self.row_for_params(p).map(|row| {
974            let start = row * self.cols;
975            &self.values[start..start + self.cols]
976        })
977    }
978}
979
980pub fn correl_hl_batch_with_kernel(
981    high: &[f64],
982    low: &[f64],
983    sweep: &CorrelHlBatchRange,
984    k: Kernel,
985) -> Result<CorrelHlBatchOutput, CorrelHlError> {
986    let kernel = match k {
987        Kernel::Auto => detect_best_batch_kernel(),
988        other if other.is_batch() => other,
989        other => return Err(CorrelHlError::InvalidKernelForBatch(other)),
990    };
991
992    let simd = match kernel {
993        Kernel::Avx512Batch => Kernel::Avx512,
994        Kernel::Avx2Batch => Kernel::Avx2,
995        Kernel::ScalarBatch => Kernel::Scalar,
996        _ => unreachable!(),
997    };
998    correl_hl_batch_par_slice(high, low, sweep, simd)
999}
1000
1001#[inline(always)]
1002pub fn correl_hl_batch_slice(
1003    high: &[f64],
1004    low: &[f64],
1005    sweep: &CorrelHlBatchRange,
1006    kern: Kernel,
1007) -> Result<CorrelHlBatchOutput, CorrelHlError> {
1008    correl_hl_batch_inner(high, low, sweep, kern, false)
1009}
1010
1011#[inline(always)]
1012pub fn correl_hl_batch_par_slice(
1013    high: &[f64],
1014    low: &[f64],
1015    sweep: &CorrelHlBatchRange,
1016    kern: Kernel,
1017) -> Result<CorrelHlBatchOutput, CorrelHlError> {
1018    correl_hl_batch_inner(high, low, sweep, kern, true)
1019}
1020
1021#[inline(always)]
1022fn correl_hl_batch_inner(
1023    high: &[f64],
1024    low: &[f64],
1025    sweep: &CorrelHlBatchRange,
1026    kern: Kernel,
1027    parallel: bool,
1028) -> Result<CorrelHlBatchOutput, CorrelHlError> {
1029    let combos = expand_grid(sweep)?;
1030
1031    let first = high
1032        .iter()
1033        .zip(low.iter())
1034        .position(|(&h, &l)| !h.is_nan() && !l.is_nan())
1035        .ok_or(CorrelHlError::AllValuesNaN)?;
1036
1037    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1038    if high.len() - first < max_p {
1039        return Err(CorrelHlError::NotEnoughValidData {
1040            needed: max_p,
1041            valid: high.len() - first,
1042        });
1043    }
1044
1045    let rows = combos.len();
1046    let cols = high.len();
1047
1048    rows.checked_mul(cols)
1049        .ok_or(CorrelHlError::InvalidInput("rows*cols overflow"))?;
1050
1051    let warm: Vec<usize> = combos
1052        .iter()
1053        .map(|c| first + c.period.unwrap() - 1)
1054        .collect();
1055
1056    let mut buf_mu = make_uninit_matrix(rows, cols);
1057
1058    init_matrix_prefixes(&mut buf_mu, cols, &warm);
1059
1060    let mut buf_guard = ManuallyDrop::new(buf_mu);
1061    let values_slice: &mut [f64] = unsafe {
1062        core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
1063    };
1064
1065    let n = high.len();
1066    let mut ps_h = vec![0.0f64; n + 1];
1067    let mut ps_h2 = vec![0.0f64; n + 1];
1068    let mut ps_l = vec![0.0f64; n + 1];
1069    let mut ps_l2 = vec![0.0f64; n + 1];
1070    let mut ps_hl = vec![0.0f64; n + 1];
1071    let mut ps_nan = vec![0i32; n + 1];
1072    for i in 0..n {
1073        let h = high[i];
1074        let l = low[i];
1075        let (ph, ph2, pl, pl2, phl) = (ps_h[i], ps_h2[i], ps_l[i], ps_l2[i], ps_hl[i]);
1076        if h.is_nan() || l.is_nan() {
1077            ps_h[i + 1] = ph;
1078            ps_h2[i + 1] = ph2;
1079            ps_l[i + 1] = pl;
1080            ps_l2[i + 1] = pl2;
1081            ps_hl[i + 1] = phl;
1082            ps_nan[i + 1] = ps_nan[i] + 1;
1083        } else {
1084            ps_h[i + 1] = ph + h;
1085            ps_h2[i + 1] = ph2 + h * h;
1086            ps_l[i + 1] = pl + l;
1087            ps_l2[i + 1] = pl2 + l * l;
1088            ps_hl[i + 1] = phl + h * l;
1089            ps_nan[i + 1] = ps_nan[i];
1090        }
1091    }
1092
1093    let do_row = |row: usize, out_row: &mut [f64]| {
1094        let p = combos[row].period.unwrap();
1095        let inv_pf = 1.0 / (p as f64);
1096        let warm = first + p - 1;
1097        for i in warm..n {
1098            let end = i + 1;
1099            let start = end - p;
1100            let nan_w = ps_nan[end] - ps_nan[start];
1101            if nan_w != 0 {
1102                out_row[i] = f64::NAN;
1103            } else {
1104                let sum_h = ps_h[end] - ps_h[start];
1105                let sum_l = ps_l[end] - ps_l[start];
1106                let sum_h2 = ps_h2[end] - ps_h2[start];
1107                let sum_l2 = ps_l2[end] - ps_l2[start];
1108                let sum_hl = ps_hl[end] - ps_hl[start];
1109                let cov = sum_hl - (sum_h * sum_l) * inv_pf;
1110                let var_h = sum_h2 - (sum_h * sum_h) * inv_pf;
1111                let var_l = sum_l2 - (sum_l * sum_l) * inv_pf;
1112                if var_h <= 0.0 || var_l <= 0.0 {
1113                    out_row[i] = 0.0;
1114                } else {
1115                    out_row[i] = cov / (var_h.sqrt() * var_l.sqrt());
1116                }
1117            }
1118        }
1119    };
1120
1121    if parallel {
1122        #[cfg(not(target_arch = "wasm32"))]
1123        {
1124            values_slice
1125                .par_chunks_mut(cols)
1126                .enumerate()
1127                .for_each(|(row, slice)| do_row(row, slice));
1128        }
1129
1130        #[cfg(target_arch = "wasm32")]
1131        {
1132            for (row, slice) in values_slice.chunks_mut(cols).enumerate() {
1133                do_row(row, slice);
1134            }
1135        }
1136    } else {
1137        for (row, slice) in values_slice.chunks_mut(cols).enumerate() {
1138            do_row(row, slice);
1139        }
1140    }
1141
1142    let values = unsafe {
1143        Vec::from_raw_parts(
1144            buf_guard.as_mut_ptr() as *mut f64,
1145            buf_guard.len(),
1146            buf_guard.capacity(),
1147        )
1148    };
1149
1150    Ok(CorrelHlBatchOutput {
1151        values,
1152        combos,
1153        rows,
1154        cols,
1155    })
1156}
1157
1158#[inline(always)]
1159fn correl_hl_batch_inner_into(
1160    high: &[f64],
1161    low: &[f64],
1162    sweep: &CorrelHlBatchRange,
1163    kern: Kernel,
1164    parallel: bool,
1165    out: &mut [f64],
1166) -> Result<Vec<CorrelHlParams>, CorrelHlError> {
1167    let combos = expand_grid(sweep)?;
1168
1169    let first = high
1170        .iter()
1171        .zip(low.iter())
1172        .position(|(&h, &l)| !h.is_nan() && !l.is_nan())
1173        .ok_or(CorrelHlError::AllValuesNaN)?;
1174
1175    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1176    if high.len() - first < max_p {
1177        return Err(CorrelHlError::NotEnoughValidData {
1178            needed: max_p,
1179            valid: high.len() - first,
1180        });
1181    }
1182
1183    let rows = combos.len();
1184    let cols = high.len();
1185
1186    let total = rows
1187        .checked_mul(cols)
1188        .ok_or(CorrelHlError::InvalidInput("rows*cols overflow"))?;
1189    if out.len() != total {
1190        return Err(CorrelHlError::OutputLengthMismatch {
1191            expected: total,
1192            got: out.len(),
1193        });
1194    }
1195
1196    let warm: Vec<usize> = combos
1197        .iter()
1198        .map(|c| first + c.period.unwrap() - 1)
1199        .collect();
1200    let out_mu: &mut [MaybeUninit<f64>] = unsafe {
1201        core::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
1202    };
1203    init_matrix_prefixes(out_mu, cols, &warm);
1204
1205    let n = high.len();
1206    let mut ps_h = vec![0.0f64; n + 1];
1207    let mut ps_h2 = vec![0.0f64; n + 1];
1208    let mut ps_l = vec![0.0f64; n + 1];
1209    let mut ps_l2 = vec![0.0f64; n + 1];
1210    let mut ps_hl = vec![0.0f64; n + 1];
1211    let mut ps_nan = vec![0i32; n + 1];
1212    for i in 0..n {
1213        let h = high[i];
1214        let l = low[i];
1215        let (prev_h, prev_h2, prev_l, prev_l2, prev_hl) =
1216            (ps_h[i], ps_h2[i], ps_l[i], ps_l2[i], ps_hl[i]);
1217        if h.is_nan() || l.is_nan() {
1218            ps_h[i + 1] = prev_h;
1219            ps_h2[i + 1] = prev_h2;
1220            ps_l[i + 1] = prev_l;
1221            ps_l2[i + 1] = prev_l2;
1222            ps_hl[i + 1] = prev_hl;
1223            ps_nan[i + 1] = ps_nan[i] + 1;
1224        } else {
1225            ps_h[i + 1] = prev_h + h;
1226            ps_h2[i + 1] = prev_h2 + h * h;
1227            ps_l[i + 1] = prev_l + l;
1228            ps_l2[i + 1] = prev_l2 + l * l;
1229            ps_hl[i + 1] = prev_hl + h * l;
1230            ps_nan[i + 1] = ps_nan[i];
1231        }
1232    }
1233
1234    let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| {
1235        let p = combos[row].period.unwrap();
1236        let inv_pf = 1.0 / (p as f64);
1237        let dst: &mut [f64] = unsafe {
1238            core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len())
1239        };
1240
1241        let warm = first + p - 1;
1242        for i in warm..n {
1243            let end = i + 1;
1244            let start = end - p;
1245            let nan_w = ps_nan[end] - ps_nan[start];
1246            if nan_w != 0 {
1247                dst[i] = f64::NAN;
1248            } else {
1249                let sum_h = ps_h[end] - ps_h[start];
1250                let sum_l = ps_l[end] - ps_l[start];
1251                let sum_h2 = ps_h2[end] - ps_h2[start];
1252                let sum_l2 = ps_l2[end] - ps_l2[start];
1253                let sum_hl = ps_hl[end] - ps_hl[start];
1254                let cov = sum_hl - (sum_h * sum_l) * inv_pf;
1255                let var_h = sum_h2 - (sum_h * sum_h) * inv_pf;
1256                let var_l = sum_l2 - (sum_l * sum_l) * inv_pf;
1257                if var_h <= 0.0 || var_l <= 0.0 {
1258                    dst[i] = 0.0;
1259                } else {
1260                    dst[i] = cov / (var_h.sqrt() * var_l.sqrt());
1261                }
1262            }
1263        }
1264    };
1265
1266    if parallel {
1267        #[cfg(not(target_arch = "wasm32"))]
1268        {
1269            out_mu
1270                .par_chunks_mut(cols)
1271                .enumerate()
1272                .for_each(|(r, s)| do_row(r, s));
1273        }
1274        #[cfg(target_arch = "wasm32")]
1275        {
1276            for (r, s) in out_mu.chunks_mut(cols).enumerate() {
1277                do_row(r, s);
1278            }
1279        }
1280    } else {
1281        for (r, s) in out_mu.chunks_mut(cols).enumerate() {
1282            do_row(r, s);
1283        }
1284    }
1285
1286    Ok(combos)
1287}
1288
1289#[inline(always)]
1290unsafe fn correl_hl_row_scalar(
1291    high: &[f64],
1292    low: &[f64],
1293    first: usize,
1294    period: usize,
1295    out: &mut [f64],
1296) {
1297    correl_hl_scalar(high, low, period, first, out)
1298}
1299
/// AVX2 per-row kernel for the batch API.
///
/// It reorders the batch-style `(first, period)` arguments into the
/// `(period, first)` order expected by `correl_hl_avx2`.
///
/// # Safety
/// Same contract as `correl_hl_avx2` (not visible in this chunk); presumably the
/// running CPU must support AVX2 — confirm at the call site.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn correl_hl_row_avx2(
    high: &[f64],
    low: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    let (window, start_idx) = (period, first);
    correl_hl_avx2(high, low, window, start_idx, out)
}
1311
/// AVX-512 per-row dispatcher.
///
/// Periods of 32 or less use the short-window variant; longer periods use the
/// long-window variant.
///
/// # Safety
/// Same contract as the AVX-512 row kernels it forwards to (not visible here);
/// presumably the running CPU must support AVX-512 — confirm at the call site.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn correl_hl_row_avx512(
    high: &[f64],
    low: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    match period {
        0..=32 => correl_hl_row_avx512_short(high, low, first, period, out),
        _ => correl_hl_row_avx512_long(high, low, first, period, out),
    }
}
1327
/// AVX-512 short-window row kernel.
///
/// It forwards to `correl_hl_avx512_short` with the `(first, period)` arguments
/// swapped into `(period, first)` order.
///
/// # Safety
/// Same contract as `correl_hl_avx512_short` (not visible in this chunk).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn correl_hl_row_avx512_short(
    high: &[f64],
    low: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    let (window, start_idx) = (period, first);
    correl_hl_avx512_short(high, low, window, start_idx, out)
}
1339
/// AVX-512 long-window row kernel.
///
/// It forwards to `correl_hl_avx512_long` with the `(first, period)` arguments
/// swapped into `(period, first)` order.
///
/// # Safety
/// Same contract as `correl_hl_avx512_long` (not visible in this chunk).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn correl_hl_row_avx512_long(
    high: &[f64],
    low: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    let (window, start_idx) = (period, first);
    correl_hl_avx512_long(high, low, window, start_idx, out)
}
1351
/// WASM binding: computes the high/low correlation for one `period`.
///
/// Returns a new array with the same length as `high`. Errors are surfaced to
/// JavaScript as string values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn correl_hl_js(high: &[f64], low: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = CorrelHlInput::from_slices(
        high,
        low,
        CorrelHlParams {
            period: Some(period),
        },
    );
    let mut values = vec![0.0; high.len()];
    match correl_hl_into_slice(&mut values, &input, Kernel::Auto) {
        Ok(_) => Ok(values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1367
/// WASM binding that works on raw pointers into linear memory.
///
/// Reads `len` values from `high_ptr` and from `low_ptr`, and writes `len`
/// results to `out_ptr`. When `out_ptr` is the same address as either input, the
/// results are first computed into a temporary buffer and then copied out. Only
/// exact address equality is treated as aliasing.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn correl_hl_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    let any_null = high_ptr.is_null() || low_ptr.is_null() || out_ptr.is_null();
    if any_null {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    let out_const = out_ptr as *const f64;
    let aliased = out_const == high_ptr || out_const == low_ptr;

    // SAFETY: the caller guarantees each pointer addresses `len` valid f64s.
    unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let input = CorrelHlInput::from_slices(
            high,
            low,
            CorrelHlParams {
                period: Some(period),
            },
        );

        if aliased {
            let mut scratch = vec![0.0; len];
            correl_hl_into_slice(&mut scratch, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            std::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&scratch);
        } else {
            let dst = std::slice::from_raw_parts_mut(out_ptr, len);
            correl_hl_into_slice(dst, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }
    Ok(())
}
1403
/// Allocates an uninitialised `f64` buffer with capacity `len` in WASM linear memory.
///
/// Ownership passes to the caller, who must release the buffer with
/// `correl_hl_free` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn correl_hl_alloc(len: usize) -> *mut f64 {
    let mut buffer = ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1412
/// Releases a buffer previously returned by `correl_hl_alloc(len)`.
///
/// The vector is rebuilt with length 0 and capacity `len`. That reproduces the
/// original allocation's layout for deallocation without asserting that any
/// element was initialised. `correl_hl_alloc` never writes the elements, and
/// `Vec::from_raw_parts` requires the first `length` elements to be initialised.
/// Null pointers are ignored. `len` must equal the value given to `correl_hl_alloc`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn correl_hl_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: per the caller contract, `ptr` came from `Vec::<f64>::with_capacity(len)`
    // in `correl_hl_alloc`, so `(ptr, capacity = len)` describes that allocation.
    // A length of 0 makes no claim about element initialisation.
    unsafe {
        drop(Vec::<f64>::from_raw_parts(ptr, 0, len));
    }
}
1422
/// WASM binding for a batch sweep configured by a JS object with a `period_range`.
///
/// Returns a serialised `CorrelHlBatchJsOutput` containing:
/// * `values`: row-major, `rows x cols`;
/// * `periods`: one entry per row;
/// * `rows` and `cols`.
///
/// Mismatched `high`/`low` lengths are rejected up front, so the batch core
/// never panics (which would abort the WASM instance).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = correl_hl_batch)]
pub fn correl_hl_batch_js(high: &[f64], low: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let config: CorrelHlBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    if high.len() != low.len() {
        return Err(JsValue::from_str(&format!(
            "high and low must have the same length (got {} and {})",
            high.len(),
            low.len()
        )));
    }

    let sweep = CorrelHlBatchRange {
        period: config.period_range,
    };

    let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let cols = high.len();
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;
    let mut values = vec![0.0f64; total];

    correl_hl_batch_inner_into(high, low, &sweep, Kernel::Auto, false, &mut values)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let periods: Vec<usize> = combos.iter().map(|c| c.period.unwrap()).collect();

    let js_output = CorrelHlBatchJsOutput {
        values,
        periods,
        rows,
        cols,
    };

    serde_wasm_bindgen::to_value(&js_output)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
1456
/// WASM raw-pointer batch binding.
///
/// Fills `out_ptr` with a row-major `rows x len` matrix for the period sweep
/// `(period_start, period_end, period_step)` and returns `rows`.
///
/// The caller must provide `len` readable f64s at `high_ptr` and at `low_ptr`,
/// and `rows * len` writable f64s at `out_ptr`. If the output region overlaps
/// either input region, results are first computed into a scratch buffer and
/// then copied out. This prevents the batch core from reading inputs that it
/// has already overwritten, and mirrors the aliasing handling of the
/// single-series `correl_hl_into`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn correl_hl_batch_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if high_ptr.is_null() || low_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    let sweep = CorrelHlBatchRange {
        period: (period_start, period_end, period_step),
    };

    let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();

    let total = rows
        .checked_mul(len)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;

    // Byte-range overlap test between the output region and an input region.
    let elem = std::mem::size_of::<f64>();
    let out_lo = out_ptr as usize;
    let out_hi = out_lo.saturating_add(total.saturating_mul(elem));
    let overlaps_out = |p: *const f64| {
        let lo = p as usize;
        let hi = lo.saturating_add(len.saturating_mul(elem));
        lo < out_hi && out_lo < hi
    };
    let aliased = overlaps_out(high_ptr) || overlaps_out(low_ptr);

    // SAFETY: the caller guarantees the pointer/length contract documented above.
    // In the aliased case the mutable output slice is created only after the
    // input slices are no longer used.
    unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);

        if aliased {
            let mut scratch = vec![0.0f64; total];
            correl_hl_batch_inner_into(high, low, &sweep, Kernel::Auto, false, &mut scratch)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            std::slice::from_raw_parts_mut(out_ptr, total).copy_from_slice(&scratch);
        } else {
            let out_slice = std::slice::from_raw_parts_mut(out_ptr, total);
            correl_hl_batch_inner_into(high, low, &sweep, Kernel::Auto, false, out_slice)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(rows)
}
1494
/// Python binding: high/low rolling correlation for a single `period`.
///
/// The computation runs with the GIL released. An optional `kernel` name is
/// validated first. Any computation error is raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "correl_hl")]
#[pyo3(signature = (high, low, period, kernel=None))]
pub fn correl_hl_py<'py>(
    py: Python<'py>,
    high: numpy::PyReadonlyArray1<'py, f64>,
    low: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let (h, l) = (high.as_slice()?, low.as_slice()?);
    let kern = validate_kernel(kernel, false)?;

    let input = CorrelHlInput::from_slices(
        h,
        l,
        CorrelHlParams {
            period: Some(period),
        },
    );

    let computed = py.allow_threads(|| correl_hl_with_kernel(&input, kern));
    let values = match computed {
        Ok(output) => output.values,
        Err(e) => return Err(PyValueError::new_err(e.to_string())),
    };

    Ok(values.into_pyarray(py))
}
1522
/// Python class `CorrelHlStream`, wrapping the Rust `CorrelHlStream` for
/// incremental updates one `(high, low)` pair at a time.
#[cfg(feature = "python")]
#[pyclass(name = "CorrelHlStream")]
pub struct CorrelHlStreamPy {
    /// Underlying Rust streaming state.
    stream: CorrelHlStream,
}
1528
#[cfg(feature = "python")]
#[pymethods]
impl CorrelHlStreamPy {
    /// Builds a streaming calculator for window length `period`.
    ///
    /// Raises `ValueError` when `CorrelHlStream::try_new` rejects the parameters.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        let params = CorrelHlParams {
            period: Some(period),
        };
        match CorrelHlStream::try_new(params) {
            Ok(stream) => Ok(Self { stream }),
            Err(err) => Err(PyValueError::new_err(err.to_string())),
        }
    }

    /// Pushes one `(high, low)` observation.
    ///
    /// Returns exactly what `CorrelHlStream::update` reports (`None` when it has
    /// no value).
    fn update(&mut self, high: f64, low: f64) -> Option<f64> {
        let Self { stream } = self;
        stream.update(high, low)
    }
}
1546
/// Python binding for a batch sweep over `period_range = (start, end, step)`.
///
/// Returns a dict with two entries:
/// * `values`: a `rows x len(high)` float64 array, one row per period;
/// * `periods`: a uint64 array giving the period used for each row.
///
/// Raises `ValueError` if `high` and `low` differ in length, for an invalid
/// range or kernel, or if the computation fails. The computation runs with the
/// GIL released.
#[cfg(feature = "python")]
#[pyfunction(name = "correl_hl_batch")]
#[pyo3(signature = (high, low, period_range, kernel=None))]
pub fn correl_hl_batch_py<'py>(
    py: Python<'py>,
    high: numpy::PyReadonlyArray1<'py, f64>,
    low: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let high_slice = high.as_slice()?;
    let low_slice = low.as_slice()?;

    // The batch core indexes `low` with `high`'s indices. Reject mismatches here
    // so Python sees a ValueError rather than a Rust panic.
    if high_slice.len() != low_slice.len() {
        return Err(PyValueError::new_err(format!(
            "high and low must have the same length (got {} and {})",
            high_slice.len(),
            low_slice.len()
        )));
    }

    // Validate the kernel before allocating the (potentially large) output array.
    let kern = validate_kernel(kernel, true)?;

    let sweep = CorrelHlBatchRange {
        period: period_range,
    };

    let combos = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let rows = combos.len();
    let cols = high_slice.len();

    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;
    // SAFETY: the array starts uninitialised. The batch core writes every cell
    // (warm-up prefixes plus computed values) before the array is returned.
    // On error the array is dropped without being exposed.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            let kernel = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k => k,
            };
            let simd = match kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                _ => unreachable!(),
            };
            correl_hl_batch_inner_into(high_slice, low_slice, &sweep, simd, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;

    Ok(dict)
}
1608
/// Python-visible handle to a CUDA-resident `f32` matrix produced by the
/// correl_hl CUDA entry points in this module.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(
    module = "ta_indicators.cuda",
    name = "CorrelHlDeviceArrayF32",
    unsendable
)]
pub struct CorrelHlDeviceArrayF32Py {
    /// Device buffer together with its `rows` x `cols` shape.
    pub(crate) inner: DeviceArrayF32,
    /// Shared handle to the CUDA context the buffer belongs to, held for this
    /// object's lifetime.
    _ctx_guard: Arc<Context>,
    /// CUDA device ordinal, reported by `__dlpack_device__`.
    _device_id: u32,
}
1620
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl CorrelHlDeviceArrayF32Py {
    /// CUDA Array Interface (version 3) description of the device matrix.
    ///
    /// Reports shape `(rows, cols)`, little-endian float32 (`<f4`), row-major
    /// byte strides, and the raw device pointer with `readonly = False`.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let inner = &self.inner;
        let d = PyDict::new(py);
        let itemsize = std::mem::size_of::<f32>();
        d.set_item("shape", (inner.rows, inner.cols))?;
        d.set_item("typestr", "<f4")?;
        d.set_item("strides", (inner.cols * itemsize, itemsize))?;
        d.set_item("data", (inner.device_ptr() as usize, false))?;
        d.set_item("version", 3)?;
        Ok(d)
    }

    /// DLPack device tuple `(device_type, device_id)`; `2` is `kDLCUDA`.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self._device_id as i32)
    }

    /// Exports the device matrix as a DLPack capsule and transfers ownership of
    /// the device buffer to the consumer.
    ///
    /// After this call the handle holds an empty 0 x 0 buffer. The capsule kind
    /// depends on `max_version`:
    /// * major version >= 1: a `dltensor_versioned` capsule (DLPack 1.x layout);
    /// * otherwise: a legacy `dltensor` capsule.
    ///
    /// `stream`, `dl_device` and `copy` are accepted for protocol compatibility
    /// but are currently ignored.
    ///
    /// The exported memory is owned by a heap-allocated manager holding the
    /// moved-out `DeviceArrayF32` and a clone of the context `Arc`. The manager
    /// is released in one of two ways:
    /// * by the DLPack deleter, once a consumer has taken ownership;
    /// * by the capsule destructor, if the capsule is dropped without being consumed.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<usize>,
        max_version: Option<(u32, u32)>,
        dl_device: Option<(i32, i32)>,
        copy: Option<bool>,
    ) -> PyResult<PyObject> {
        use pyo3::ffi as pyffi;
        use std::ffi::{c_char, c_void};

        // Accepted for protocol compatibility; not acted upon here.
        let _ = (stream, dl_device, copy);

        #[repr(C)]
        struct DLDevice {
            device_type: i32,
            device_id: i32,
        }
        #[repr(C)]
        struct DLDataType {
            code: u8,
            bits: u8,
            lanes: u16,
        }
        #[repr(C)]
        struct DLTensor {
            data: *mut c_void,
            device: DLDevice,
            ndim: i32,
            dtype: DLDataType,
            shape: *mut i64,
            strides: *mut i64,
            byte_offset: u64,
        }
        // Legacy (pre-1.0) managed tensor, exported in a "dltensor" capsule.
        #[repr(C)]
        struct DLManagedTensor {
            dl_tensor: DLTensor,
            manager_ctx: *mut c_void,
            deleter: Option<unsafe extern "C" fn(*mut DLManagedTensor)>,
        }
        #[repr(C)]
        struct DLPackVersion {
            major: u32,
            minor: u32,
        }
        // DLPack >= 1.0 managed tensor, exported in a "dltensor_versioned" capsule.
        // Field order follows dlpack.h: version, manager_ctx, deleter, flags, dl_tensor.
        #[repr(C)]
        struct DLManagedTensorVersioned {
            version: DLPackVersion,
            manager_ctx: *mut c_void,
            deleter: Option<unsafe extern "C" fn(*mut DLManagedTensorVersioned)>,
            flags: u64,
            dl_tensor: DLTensor,
        }

        // Owns everything the exported tensor points into. Fields drop in
        // declaration order, so the device buffer is released before the
        // context handle.
        struct ManagerCtx {
            _shape: Box<[i64; 2]>,
            _strides: Box<[i64; 2]>,
            _arr: DeviceArrayF32,
            _ctx: Arc<Context>,
        }

        // CPython stores the capsule name pointer without copying the string,
        // so the names must have static storage (not a temporary CString).
        static DLTENSOR_NAME: &[u8] = b"dltensor\0";
        static DLTENSOR_VERSIONED_NAME: &[u8] = b"dltensor_versioned\0";

        unsafe extern "C" fn deleter(p: *mut DLManagedTensor) {
            if p.is_null() {
                return;
            }
            let mt = Box::from_raw(p);
            let ctx_ptr = mt.manager_ctx as *mut ManagerCtx;
            if !ctx_ptr.is_null() {
                drop(Box::from_raw(ctx_ptr));
            }
        }

        unsafe extern "C" fn deleter_versioned(p: *mut DLManagedTensorVersioned) {
            if p.is_null() {
                return;
            }
            let mt = Box::from_raw(p);
            let ctx_ptr = mt.manager_ctx as *mut ManagerCtx;
            if !ctx_ptr.is_null() {
                drop(Box::from_raw(ctx_ptr));
            }
        }

        // Capsule destructors. A consumer that takes ownership renames the capsule
        // ("used_dltensor" / "used_dltensor_versioned") and becomes responsible for
        // calling the deleter. If the original name is still present, the tensor
        // was never consumed, so it is released here instead of leaking.
        unsafe extern "C" fn capsule_destructor(cap: *mut pyffi::PyObject) {
            let name = DLTENSOR_NAME.as_ptr() as *const c_char;
            if pyffi::PyCapsule_IsValid(cap, name) != 0 {
                let p = pyffi::PyCapsule_GetPointer(cap, name) as *mut DLManagedTensor;
                if !p.is_null() {
                    if let Some(del) = (*p).deleter {
                        del(p);
                    }
                }
            }
        }

        unsafe extern "C" fn capsule_destructor_versioned(cap: *mut pyffi::PyObject) {
            let name = DLTENSOR_VERSIONED_NAME.as_ptr() as *const c_char;
            if pyffi::PyCapsule_IsValid(cap, name) != 0 {
                let p = pyffi::PyCapsule_GetPointer(cap, name) as *mut DLManagedTensorVersioned;
                if !p.is_null() {
                    if let Some(del) = (*p).deleter {
                        del(p);
                    }
                }
            }
        }

        // Move the device buffer out of `self`, leaving an empty placeholder.
        let dummy =
            DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let inner = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32 {
                buf: dummy,
                rows: 0,
                cols: 0,
            },
        );

        let rows = inner.rows as i64;
        let cols = inner.cols as i64;
        let total = (rows as i128) * (cols as i128);
        let mut shape = Box::new([rows, cols]);
        // DLPack strides are counted in elements, not bytes.
        let mut strides = Box::new([cols, 1]);
        // The boxed arrays keep their heap address when the boxes move into the
        // manager, so these pointers stay valid until the deleter runs.
        let shape_ptr = shape.as_mut_ptr();
        let strides_ptr = strides.as_mut_ptr();

        // No Python self-reference is stored. The manager owns the buffer and a
        // context handle, which is everything the exported memory depends on.
        let mgr_ptr = Box::into_raw(Box::new(ManagerCtx {
            _shape: shape,
            _strides: strides,
            _arr: inner,
            _ctx: Arc::clone(&self._ctx_guard),
        }));

        let data = if total == 0 {
            std::ptr::null_mut()
        } else {
            // SAFETY: `mgr_ptr` was just produced by `Box::into_raw` and is uniquely owned here.
            unsafe {
                (*mgr_ptr)
                    ._arr
                    .buf
                    .as_device_ptr()
                    .as_raw() as *mut c_void
            }
        };
        let dl_tensor = DLTensor {
            data,
            device: DLDevice {
                device_type: 2,
                device_id: self._device_id as i32,
            },
            ndim: 2,
            dtype: DLDataType {
                code: 2,
                bits: 32,
                lanes: 1,
            },
            shape: shape_ptr,
            strides: strides_ptr,
            byte_offset: 0,
        };
        let manager_ctx = mgr_ptr as *mut c_void;

        let want_versioned = max_version.map(|(maj, _)| maj >= 1).unwrap_or(false);

        unsafe {
            let cap = if want_versioned {
                let mt = Box::into_raw(Box::new(DLManagedTensorVersioned {
                    version: DLPackVersion { major: 1, minor: 0 },
                    manager_ctx,
                    deleter: Some(deleter_versioned),
                    flags: 0,
                    dl_tensor,
                }));
                let cap = pyffi::PyCapsule_New(
                    mt as *mut c_void,
                    DLTENSOR_VERSIONED_NAME.as_ptr() as *const c_char,
                    Some(capsule_destructor_versioned),
                );
                if cap.is_null() {
                    // Free both the wrapper and the manager it owns.
                    deleter_versioned(mt);
                }
                cap
            } else {
                let mt = Box::into_raw(Box::new(DLManagedTensor {
                    dl_tensor,
                    manager_ctx,
                    deleter: Some(deleter),
                }));
                let cap = pyffi::PyCapsule_New(
                    mt as *mut c_void,
                    DLTENSOR_NAME.as_ptr() as *const c_char,
                    Some(capsule_destructor),
                );
                if cap.is_null() {
                    deleter(mt);
                }
                cap
            };
            if cap.is_null() {
                return Err(PyValueError::new_err("failed to create DLPack capsule"));
            }
            Ok(PyObject::from_owned_ptr(py, cap))
        }
    }
}
1801
#[cfg(all(feature = "python", feature = "cuda"))]
impl CorrelHlDeviceArrayF32Py {
    /// Wraps a device array produced on the Rust side.
    ///
    /// Stores the context handle it belongs to and the device ordinal it lives on.
    pub fn new_from_rust(inner: DeviceArrayF32, ctx_guard: Arc<Context>, device_id: u32) -> Self {
        let (_ctx_guard, _device_id) = (ctx_guard, device_id);
        Self {
            inner,
            _ctx_guard,
            _device_id,
        }
    }
}
1812
/// Python binding: runs the period sweep on the GPU and returns the result as a
/// device-resident `CorrelHlDeviceArrayF32`.
///
/// Raises `ValueError` when CUDA is unavailable or the CUDA wrapper reports an
/// error. GPU work runs with the GIL released.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "correl_hl_cuda_batch_dev")]
#[pyo3(signature = (high_f32, low_f32, period_range, device_id=0))]
pub fn correl_hl_cuda_batch_dev_py(
    py: Python<'_>,
    high_f32: numpy::PyReadonlyArray1<'_, f32>,
    low_f32: numpy::PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<CorrelHlDeviceArrayF32Py> {
    use crate::cuda::correl_hl_wrapper::CudaCorrelHl;
    use crate::cuda::cuda_available;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let (h, l) = (high_f32.as_slice()?, low_f32.as_slice()?);
    let sweep = CorrelHlBatchRange {
        period: period_range,
    };

    let (inner, ctx, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let cuda =
            CudaCorrelHl::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let (device_matrix, _combos) = cuda
            .correl_hl_batch_dev(h, l, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((device_matrix, cuda.context_arc(), cuda.device_id()))
    })?;

    Ok(CorrelHlDeviceArrayF32Py::new_from_rust(inner, ctx, dev_id))
}
1843
/// Python binding: evaluates one `period` across many series on the GPU.
///
/// Both inputs are 2-D arrays of identical shape, passed to the time-major CUDA
/// entry point as `(cols = shape[1], rows = shape[0])`. The result is returned
/// as a device-resident `CorrelHlDeviceArrayF32`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "correl_hl_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm_f32, low_tm_f32, period, device_id=0))]
pub fn correl_hl_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    high_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    low_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<CorrelHlDeviceArrayF32Py> {
    use crate::cuda::correl_hl_wrapper::CudaCorrelHl;
    use crate::cuda::cuda_available;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let dims = high_tm_f32.shape();
    let shapes_ok = dims.len() == 2 && low_tm_f32.shape() == dims;
    if !shapes_ok {
        return Err(PyValueError::new_err("expected matching 2D arrays"));
    }
    let (n_rows, n_cols) = (dims[0], dims[1]);

    let h = high_tm_f32.as_slice()?;
    let l = low_tm_f32.as_slice()?;

    let (inner, ctx, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let cuda =
            CudaCorrelHl::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let device_matrix = cuda
            .correl_hl_many_series_one_param_time_major_dev(h, l, n_cols, n_rows, period)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((device_matrix, cuda.context_arc(), cuda.device_id()))
    })?;

    Ok(CorrelHlDeviceArrayF32Py::new_from_rust(inner, ctx, dev_id))
}
1877
// Tests for the CORREL_HL indicator: the single-series API and its `_into` variant, input
// validation, streaming vs. batch agreement, cross-kernel property checks, the batch builder,
// and (in debug builds) detection of allocation-poison sentinels leaking into outputs.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::skip_if_unsupported;
    use crate::utilities::data_loader::read_candles_from_csv;
    #[cfg(feature = "proptest")]
    use proptest::prelude::*;

    /// OHLCV fixture shared by the file-backed tests. The path is resolved relative to the
    /// working directory the tests run in (presumably the crate root under `cargo test`).
    const TEST_CSV: &str = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";

    /// Maps a raw `f64` bit pattern to the name of the debug helper whose poison sentinel it
    /// matches, or `None` when it is not one of the known sentinels.
    #[cfg(debug_assertions)]
    fn poison_source(bits: u64) -> Option<&'static str> {
        match bits {
            0x11111111_11111111 => Some("alloc_with_nan_prefix"),
            0x22222222_22222222 => Some("init_matrix_prefixes"),
            0x33333333_33333333 => Some("make_uninit_matrix"),
            _ => None,
        }
    }

    #[test]
    fn test_correl_hl_into_matches_api() -> Result<(), Box<dyn std::error::Error>> {
        let n = 256usize;
        let mut ts = Vec::with_capacity(n);
        let mut open = Vec::with_capacity(n);
        let mut high = Vec::with_capacity(n);
        let mut low = Vec::with_capacity(n);
        let mut close = Vec::with_capacity(n);
        let mut vol = Vec::with_capacity(n);

        // Deterministic synthetic walk: each bar moves by a sine-driven step, and high/low
        // bracket the open/close body.
        let mut cur = 100.0f64;
        for i in 0..n {
            let step = ((i as f64).sin() * 0.5) + 0.1;
            let o = cur;
            let c = cur + step;
            let (lo, hi) = if c >= o {
                (o - 0.3, c + 0.4)
            } else {
                (c - 0.3, o + 0.4)
            };
            ts.push(i as i64);
            open.push(o);
            close.push(c);
            high.push(hi);
            low.push(lo);
            vol.push(1000.0 + (i % 10) as f64);
            cur = c;
        }

        // `high`/`low` are not used after this point, so they are moved rather than cloned.
        let candles =
            crate::utilities::data_loader::Candles::new(ts, open, high, low, close, vol);

        let input = CorrelHlInput::from_candles(&candles, CorrelHlParams::default());

        let baseline = correl_hl(&input)?;

        let mut out = vec![0.0f64; n];
        #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
        {
            correl_hl_into(&mut out, &input)?;
        }
        #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
        {
            correl_hl_into_slice(&mut out, &input, Kernel::Auto)?;
        }

        assert_eq!(baseline.values.len(), out.len());
        for (a, b) in baseline.values.iter().zip(out.iter()) {
            let equal = (a.is_nan() && b.is_nan()) || (*a == *b) || ((*a - *b).abs() <= 1e-12);
            assert!(equal, "Mismatch: baseline={} into={}", a, b);
        }

        Ok(())
    }

    /// A `period: None` parameter set must still run (falling back to a default) and
    /// produce one output per input bar.
    fn check_correl_hl_partial_params(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(TEST_CSV)?;
        let params = CorrelHlParams { period: None };
        let input = CorrelHlInput::from_candles(&candles, params);
        let output = correl_hl_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());
        Ok(())
    }

    /// Compares the last five values for `period = 5` against reference values.
    fn check_correl_hl_accuracy(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(TEST_CSV)?;
        let params = CorrelHlParams { period: Some(5) };
        let input = CorrelHlInput::from_candles(&candles, params);
        let result = correl_hl_with_kernel(&input, kernel)?;
        let expected = [
            0.04589155420456278,
            0.6491664099299647,
            0.9691259236943873,
            0.9915438003818791,
            0.8460608423095615,
        ];
        let start_index = result.values.len() - 5;
        for (i, &val) in result.values[start_index..].iter().enumerate() {
            let exp = expected[i];
            let diff = (val - exp).abs();
            assert!(
                diff < 1e-7,
                "[{}] Value mismatch at index {}: expected {}, got {}",
                test_name,
                i,
                exp,
                val
            );
        }
        Ok(())
    }

    fn check_correl_hl_zero_period(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let high = [1.0, 2.0, 3.0];
        let low = [1.0, 2.0, 3.0];
        let params = CorrelHlParams { period: Some(0) };
        let input = CorrelHlInput::from_slices(&high, &low, params);
        let result = correl_hl_with_kernel(&input, kernel);
        assert!(
            result.is_err(),
            "[{}] correl_hl should fail with zero period",
            test_name
        );
        Ok(())
    }

    fn check_correl_hl_period_exceeds_length(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let high = [1.0, 2.0, 3.0];
        let low = [1.0, 2.0, 3.0];
        let params = CorrelHlParams { period: Some(10) };
        let input = CorrelHlInput::from_slices(&high, &low, params);
        let result = correl_hl_with_kernel(&input, kernel);
        assert!(
            result.is_err(),
            "[{}] correl_hl should fail with period exceeding length",
            test_name
        );
        Ok(())
    }

    fn check_correl_hl_data_length_mismatch(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let high = [1.0, 2.0, 3.0];
        let low = [1.0, 2.0];
        let params = CorrelHlParams { period: Some(2) };
        let input = CorrelHlInput::from_slices(&high, &low, params);
        let result = correl_hl_with_kernel(&input, kernel);
        assert!(
            result.is_err(),
            "[{}] correl_hl should fail on length mismatch",
            test_name
        );
        Ok(())
    }

    fn check_correl_hl_all_nan(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let high = [f64::NAN, f64::NAN, f64::NAN];
        let low = [f64::NAN, f64::NAN, f64::NAN];
        let params = CorrelHlParams { period: Some(2) };
        let input = CorrelHlInput::from_slices(&high, &low, params);
        let result = correl_hl_with_kernel(&input, kernel);
        assert!(
            result.is_err(),
            "[{}] correl_hl should fail on all NaN",
            test_name
        );
        Ok(())
    }

    fn check_correl_hl_from_candles(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(TEST_CSV)?;
        let params = CorrelHlParams { period: Some(9) };
        let input = CorrelHlInput::from_candles(&candles, params);
        let output = correl_hl_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());
        Ok(())
    }

    /// Feeds the output of one run (including its NaN warmup prefix) back in as `high`.
    fn check_correl_hl_reinput(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let high = [1.0, 2.0, 3.0, 4.0, 5.0];
        let low = [0.5, 1.0, 1.5, 2.0, 2.5];
        let params = CorrelHlParams { period: Some(2) };
        let first_input = CorrelHlInput::from_slices(&high, &low, params.clone());
        let first_result = correl_hl_with_kernel(&first_input, kernel)?;
        let second_input = CorrelHlInput::from_slices(&first_result.values, &low, params);
        let second_result = correl_hl_with_kernel(&second_input, kernel)?;
        assert_eq!(second_result.values.len(), low.len());
        Ok(())
    }

    /// A single bar with `period = 1` must yield either NaN or (approximately) zero.
    fn check_correl_hl_very_small_dataset(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);

        let single_high = [42.0];
        let single_low = [21.0];
        let params = CorrelHlParams { period: Some(1) };
        let input = CorrelHlInput::from_slices(&single_high, &single_low, params);
        let result = correl_hl_with_kernel(&input, kernel)?;
        assert_eq!(result.values.len(), 1);

        assert!(result.values[0].is_nan() || result.values[0].abs() < f64::EPSILON);
        Ok(())
    }

    fn check_correl_hl_empty_input(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let empty_high: [f64; 0] = [];
        let empty_low: [f64; 0] = [];
        let params = CorrelHlParams { period: Some(5) };
        let input = CorrelHlInput::from_slices(&empty_high, &empty_low, params);
        let result = correl_hl_with_kernel(&input, kernel);
        assert!(
            result.is_err(),
            "[{}] correl_hl should fail on empty input",
            test_name
        );
        Ok(())
    }

    /// Interior NaNs must not prevent finite correlations from appearing later, once enough
    /// clean bars have been seen.
    fn check_correl_hl_nan_handling(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);

        let high = vec![1.0, 2.0, f64::NAN, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let low = vec![0.5, 1.0, 1.5, f64::NAN, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0];
        let params = CorrelHlParams { period: Some(3) };
        let input = CorrelHlInput::from_slices(&high, &low, params);
        let result = correl_hl_with_kernel(&input, kernel)?;

        assert_eq!(result.values.len(), high.len());

        let mut valid_count = 0;
        for i in 0..high.len() {
            if !high[i].is_nan() && !low[i].is_nan() {
                valid_count += 1;
                if valid_count >= 3 {
                    let has_valid = result.values[i..].iter().any(|&v| !v.is_nan());
                    assert!(
                        has_valid,
                        "[{}] Should have valid correlations after enough data",
                        test_name
                    );
                    break;
                }
            }
        }
        Ok(())
    }

    /// The streaming API must warm up for `period - 1` updates and then agree with the batch
    /// API on the same prefix of data.
    fn check_correl_hl_streaming(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);

        let params = CorrelHlParams { period: Some(3) };
        let mut stream = CorrelHlStream::try_new(params)?;

        let high_data = [1.0, 2.0, 3.0, 4.0, 5.0];
        let low_data = [0.5, 1.0, 1.5, 2.0, 2.5];

        assert!(stream.update(high_data[0], low_data[0]).is_none());
        assert!(stream.update(high_data[1], low_data[1]).is_none());

        let first_corr = stream.update(high_data[2], low_data[2]);
        assert!(first_corr.is_some());

        let second_corr = stream.update(high_data[3], low_data[3]);
        assert!(second_corr.is_some());

        let params_batch = CorrelHlParams { period: Some(3) };
        let input_batch = CorrelHlInput::from_slices(&high_data[..4], &low_data[..4], params_batch);
        let batch_result = correl_hl_with_kernel(&input_batch, kernel)?;

        if let Some(batch_val) = batch_result.values.iter().rev().find(|&&v| !v.is_nan()) {
            if let Some(stream_val) = second_corr {
                assert!(
                    (batch_val - stream_val).abs() < 1e-10,
                    "[{}] Streaming vs batch mismatch: {} vs {}",
                    test_name,
                    stream_val,
                    batch_val
                );
            }
        }

        Ok(())
    }

    /// Property test: for random positive price paths, the kernel under test must match the
    /// scalar reference (bit-exact for non-finite values, within 1e-9 or 4 ULP otherwise),
    /// keep the warmup prefix NaN, and stay within [-1, 1] up to a small tolerance.
    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_correl_hl_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        let strat = (2usize..=100).prop_flat_map(|period| {
            (
                prop::collection::vec(1.0f64..1000.0f64, period..400).prop_flat_map(
                    move |close_prices| {
                        let len = close_prices.len();
                        (
                            Just(close_prices.clone()),
                            prop::collection::vec((0.001f64..0.05f64, 0.001f64..0.05f64), len),
                        )
                            .prop_map(move |(close, spreads)| {
                                let mut high = Vec::with_capacity(len);
                                let mut low = Vec::with_capacity(len);

                                for (i, &close_price) in close.iter().enumerate() {
                                    let (up_spread, down_spread) = spreads[i];

                                    high.push(close_price * (1.0 + up_spread));
                                    low.push(close_price * (1.0 - down_spread));
                                }

                                (high, low)
                            })
                    },
                ),
                Just(period),
            )
        });

        proptest::test_runner::TestRunner::default()
            .run(&strat, |((high, low), period)| {
                let params = CorrelHlParams {
                    period: Some(period),
                };
                let input = CorrelHlInput::from_slices(&high, &low, params);

                let result = correl_hl_with_kernel(&input, kernel);
                let reference = correl_hl_with_kernel(&input, Kernel::Scalar);

                match (result, reference) {
                    (Ok(output), Ok(ref_output)) => {
                        let out = &output.values;
                        let ref_out = &ref_output.values;

                        prop_assert_eq!(out.len(), high.len());

                        let warmup_len = period.saturating_sub(1).min(high.len());
                        for i in 0..warmup_len {
                            prop_assert!(
                                out[i].is_nan(),
                                "Expected NaN during warmup at index {}, got {}",
                                i,
                                out[i]
                            );
                        }

                        for i in 0..out.len() {
                            let y = out[i];
                            let r = ref_out[i];

                            if !y.is_finite() || !r.is_finite() {
                                prop_assert_eq!(
                                    y.to_bits(),
                                    r.to_bits(),
                                    "Special value mismatch at index {}: {} vs {}",
                                    i,
                                    y,
                                    r
                                );
                                continue;
                            }

                            let ulp_diff = y.to_bits().abs_diff(r.to_bits());
                            prop_assert!(
                                (y - r).abs() <= 1e-9 || ulp_diff <= 4,
                                "Kernel mismatch at index {}: {} vs {} (ULP={})",
                                i,
                                y,
                                r,
                                ulp_diff
                            );
                        }

                        for (i, &val) in out.iter().enumerate() {
                            if !val.is_nan() {
                                let tolerance = 1e-3;
                                prop_assert!(
                                    val >= -1.0 - tolerance && val <= 1.0 + tolerance,
                                    "Correlation at index {} out of range: {}",
                                    i,
                                    val
                                );
                            }
                        }
                    }
                    (Err(_), Err(_)) => {}
                    (Ok(_), Err(e)) => {
                        prop_assert!(
                            false,
                            "Reference kernel failed but test kernel succeeded: {}",
                            e
                        );
                    }
                    (Err(e), Ok(_)) => {
                        prop_assert!(
                            false,
                            "Test kernel failed but reference kernel succeeded: {}",
                            e
                        );
                    }
                }

                Ok(())
            })
            .unwrap();

        Ok(())
    }

    /// Debug builds only: no non-NaN output may carry one of the allocation-poison sentinels.
    #[cfg(debug_assertions)]
    fn check_correl_hl_no_poison(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);

        let candles = read_candles_from_csv(TEST_CSV)?;

        let input = CorrelHlInput::from_candles(&candles, CorrelHlParams::default());
        let output = correl_hl_with_kernel(&input, kernel)?;

        for (i, &val) in output.values.iter().enumerate() {
            if val.is_nan() {
                continue;
            }

            let bits = val.to_bits();
            if let Some(source) = poison_source(bits) {
                panic!(
                    "[{}] Found {} poison value {} (0x{:016X}) at index {}",
                    test_name, source, val, bits, i
                );
            }
        }

        Ok(())
    }

    /// Release builds: the poison sentinels are not written, so there is nothing to check.
    #[cfg(not(debug_assertions))]
    fn check_correl_hl_no_poison(
        _test_name: &str,
        _kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        Ok(())
    }

    // Generates one `#[test]` per check function and kernel. Each generated test returns the
    // check's `Result`, so an `Err` (unreadable CSV, indicator error, ...) fails the test
    // instead of being silently discarded. The SIMD `cfg` is attached to every generated
    // function inside the repetition; an attribute placed before a `$(...)*` repetition would
    // only gate the first function it expands to.
    macro_rules! generate_all_correl_hl_tests {
        ($($test_fn:ident),*) => {
            paste::paste! {
                $(
                    #[test]
                    fn [<$test_fn _scalar_f64>]() -> Result<(), Box<dyn std::error::Error>> {
                        $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar)
                    }

                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx2_f64>]() -> Result<(), Box<dyn std::error::Error>> {
                        $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2)
                    }

                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx512_f64>]() -> Result<(), Box<dyn std::error::Error>> {
                        $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512)
                    }
                )*
            }
        };
    }

    /// Exercises `period = 1`. Values must stay NaN or within [-1, 1] when high and low
    /// differ. The identical-series case is only printed for inspection, not asserted.
    #[test]
    fn test_period_one_bug() {
        let high = vec![100.0, 200.0, 300.0];
        let low = vec![90.0, 190.0, 310.0];

        let params = CorrelHlParams { period: Some(1) };
        let input = CorrelHlInput::from_slices(&high, &low, params.clone());
        let result = correl_hl(&input).unwrap();

        println!("Period=1 correlation with different high/low:");
        for (i, &val) in result.values.iter().enumerate() {
            println!(
                "  Index {}: high={}, low={}, corr={}",
                i, high[i], low[i], val
            );

            assert!(
                val.is_nan() || (val >= -1.0 && val <= 1.0),
                "Period=1 correlation at index {} out of bounds: {}",
                i,
                val
            );
        }

        let high2 = vec![100.0, 200.0, 300.0];
        let low2 = vec![100.0, 200.0, 300.0];

        let input2 = CorrelHlInput::from_slices(&high2, &low2, params.clone());
        let result2 = correl_hl(&input2).unwrap();

        println!("\nPeriod=1 correlation with identical high/low:");
        for (i, &val) in result2.values.iter().enumerate() {
            println!(
                "  Index {}: high={}, low={}, corr={}",
                i, high2[i], low2[i], val
            );
        }
    }

    generate_all_correl_hl_tests!(
        check_correl_hl_partial_params,
        check_correl_hl_accuracy,
        check_correl_hl_zero_period,
        check_correl_hl_period_exceeds_length,
        check_correl_hl_data_length_mismatch,
        check_correl_hl_all_nan,
        check_correl_hl_from_candles,
        check_correl_hl_reinput,
        check_correl_hl_very_small_dataset,
        check_correl_hl_empty_input,
        check_correl_hl_nan_handling,
        check_correl_hl_streaming,
        check_correl_hl_no_poison
    );

    #[cfg(feature = "proptest")]
    generate_all_correl_hl_tests!(check_correl_hl_property);

    /// The batch builder with default settings must contain a row for the default params,
    /// with one value per input bar.
    fn check_batch_default_row(
        test: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test);

        let c = read_candles_from_csv(TEST_CSV)?;

        let output = CorrelHlBatchBuilder::new()
            .kernel(kernel)
            .apply_candles(&c)?;

        let def = CorrelHlParams::default();
        let row = output.values_for(&def).expect("default row missing");

        assert_eq!(row.len(), c.close.len());
        Ok(())
    }

    /// Debug builds only: no non-NaN cell of a multi-period batch may carry a poison sentinel.
    #[cfg(debug_assertions)]
    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test);

        let c = read_candles_from_csv(TEST_CSV)?;

        let output = CorrelHlBatchBuilder::new()
            .kernel(kernel)
            .period_range(5, 20, 5)
            .apply_candles(&c)?;

        for (idx, &val) in output.values.iter().enumerate() {
            if val.is_nan() {
                continue;
            }

            let bits = val.to_bits();
            // `values` is read as a flattened matrix with `output.cols` entries per row.
            let row = idx / output.cols;
            let col = idx % output.cols;

            if let Some(source) = poison_source(bits) {
                panic!(
                    "[{}] Found {} poison value {} (0x{:016X}) at row {} col {} (flat index {})",
                    test, source, val, bits, row, col, idx
                );
            }
        }

        Ok(())
    }

    /// Release builds: the poison sentinels are not written, so there is nothing to check.
    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(
        _test: &str,
        _kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        Ok(())
    }

    // Generates per-kernel batch tests. As above, each test returns the check's `Result` so
    // errors are reported rather than discarded.
    macro_rules! gen_batch_tests {
        ($fn_name:ident) => {
            paste::paste! {
                #[test]
                fn [<$fn_name _scalar>]() -> Result<(), Box<dyn std::error::Error>> {
                    $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch)
                }

                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$fn_name _avx2>]() -> Result<(), Box<dyn std::error::Error>> {
                    $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch)
                }

                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$fn_name _avx512>]() -> Result<(), Box<dyn std::error::Error>> {
                    $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch)
                }

                #[test]
                fn [<$fn_name _auto_detect>]() -> Result<(), Box<dyn std::error::Error>> {
                    $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto)
                }
            }
        };
    }
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
}