Skip to main content

vector_ta/indicators/moving_averages/
edcf.rs

1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{alloc_with_nan_prefix, init_matrix_prefixes, make_uninit_matrix};
4#[cfg(feature = "python")]
5use crate::utilities::kernel_validation::validate_kernel;
6#[cfg(feature = "python")]
7use numpy::{IntoPyArray, PyArray1};
8#[cfg(feature = "python")]
9use pyo3::exceptions::PyValueError;
10#[cfg(feature = "python")]
11use pyo3::prelude::*;
12#[cfg(feature = "python")]
13use pyo3::types::PyDict;
14#[cfg(not(target_arch = "wasm32"))]
15use rayon::prelude::*;
16#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
17use serde::{Deserialize, Serialize};
18use std::convert::AsRef;
19use std::mem::MaybeUninit;
20use thiserror::Error;
21#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
22use wasm_bindgen::prelude::*;
23
24#[cfg(all(feature = "python", feature = "cuda"))]
25use crate::cuda::moving_averages::DeviceArrayF32;
26#[cfg(all(feature = "python", feature = "cuda"))]
27use crate::utilities::dlpack_cuda::{make_device_array_py, DeviceArrayF32Py};
28
/// Where an EDCF computation reads its input series from.
#[derive(Debug, Clone)]
pub enum EdcfData<'a> {
    /// A candle set plus the name of the field to read; the `AsRef<[f64]>` impl on
    /// `EdcfInput` resolves the pair through `source_type`.
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A caller-provided series that is used as-is.
    Slice(&'a [f64]),
}
37
38impl<'a> AsRef<[f64]> for EdcfInput<'a> {
39    #[inline(always)]
40    fn as_ref(&self) -> &[f64] {
41        match &self.data {
42            EdcfData::Slice(slice) => slice,
43            EdcfData::Candles { candles, source } => source_type(candles, source),
44        }
45    }
46}
47
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]

/// Tunable parameters of the EDCF filter.
pub struct EdcfParams {
    /// Window length; consumers in this file substitute 15 when this is `None`.
    pub period: Option<usize>,
}
57
58impl Default for EdcfParams {
59    fn default() -> Self {
60        Self { period: Some(15) }
61    }
62}
63
/// Result of a single-series EDCF run.
#[derive(Debug, Clone)]
pub struct EdcfOutput {
    /// One value per input element; `edcf_with_kernel` allocates it with a NaN warm-up prefix.
    pub values: Vec<f64>,
}
68
/// Bundles the data source and the parameters for one EDCF computation.
#[derive(Debug, Clone)]
pub struct EdcfInput<'a> {
    /// Where the series comes from (candles + field name, or a raw slice).
    pub data: EdcfData<'a>,

    /// Filter parameters; see `get_period` for how a missing period is defaulted.
    pub params: EdcfParams,
}
75
76impl<'a> EdcfInput<'a> {
77    #[inline]
78    pub fn from_candles(c: &'a Candles, s: &'a str, p: EdcfParams) -> Self {
79        Self {
80            data: EdcfData::Candles {
81                candles: c,
82                source: s,
83            },
84            params: p,
85        }
86    }
87    #[inline]
88    pub fn from_slice(sl: &'a [f64], p: EdcfParams) -> Self {
89        Self {
90            data: EdcfData::Slice(sl),
91            params: p,
92        }
93    }
94    #[inline]
95    pub fn with_default_candles(c: &'a Candles) -> Self {
96        Self::from_candles(c, "close", EdcfParams::default())
97    }
98    #[inline]
99    pub fn get_period(&self) -> usize {
100        self.params.period.unwrap_or(15)
101    }
102}
103
/// Errors reported by the EDCF entry points in this module.
#[derive(Debug, Error)]
pub enum EdcfError {
    /// The input series has length zero.
    #[error("edcf: No data provided to EDCF filter.")]
    NoData,
    /// Empty-input variant; it is not constructed in the visible part of this file.
    #[error("edcf: Empty input data.")]
    EmptyInputData,
    /// No element of the input is non-NaN.
    #[error("edcf: All values are NaN.")]
    AllValuesNaN,
    /// The period is zero or larger than the input length.
    #[error("edcf: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },
    /// Fewer than `needed` elements remain from the first non-NaN value onward.
    #[error("edcf: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// A caller-provided output buffer does not match the input length.
    #[error("edcf: Output buffer length mismatch: expected = {expected}, got = {got}")]
    OutputLenMismatch { expected: usize, got: usize },
    /// Kernel-selection error; it is not constructed in the visible part of this file.
    #[error("edcf: Invalid kernel specified")]
    InvalidKernel,
    /// The period sweep expanded to no parameter combinations.
    #[error("edcf: Invalid range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },
    /// A non-batch kernel was passed to the batch API.
    #[error("edcf: Invalid kernel for batch API: {0:?}")]
    InvalidKernelForBatch(Kernel),
    /// A size computation (named by `op`) overflowed `usize`.
    #[error("edcf: size overflow during allocation ({op})")]
    SizeOverflow { op: &'static str },
}
131
#[derive(Copy, Clone, Debug)]

/// Fluent builder for single-series EDCF runs and streaming instances.
pub struct EdcfBuilder {
    // Optional period override; forwarded unchanged into `EdcfParams`.
    period: Option<usize>,
    // Kernel forwarded to `edcf_with_kernel`.
    kernel: Kernel,
}
138
139impl Default for EdcfBuilder {
140    fn default() -> Self {
141        Self {
142            period: None,
143            kernel: Kernel::Auto,
144        }
145    }
146}
147
148impl EdcfBuilder {
149    #[inline(always)]
150    pub fn new() -> Self {
151        Self::default()
152    }
153    #[inline(always)]
154    pub fn period(mut self, n: usize) -> Self {
155        self.period = Some(n);
156        self
157    }
158    #[inline(always)]
159    pub fn kernel(mut self, k: Kernel) -> Self {
160        self.kernel = k;
161        self
162    }
163    #[inline(always)]
164    pub fn apply(self, c: &Candles) -> Result<EdcfOutput, EdcfError> {
165        let p = EdcfParams {
166            period: self.period,
167        };
168        let i = EdcfInput::from_candles(c, "close", p);
169        edcf_with_kernel(&i, self.kernel)
170    }
171    #[inline(always)]
172    pub fn apply_slice(self, d: &[f64]) -> Result<EdcfOutput, EdcfError> {
173        let p = EdcfParams {
174            period: self.period,
175        };
176        let i = EdcfInput::from_slice(d, p);
177        edcf_with_kernel(&i, self.kernel)
178    }
179    #[inline(always)]
180    pub fn into_stream(self) -> Result<EdcfStream, EdcfError> {
181        let p = EdcfParams {
182            period: self.period,
183        };
184        EdcfStream::try_new(p)
185    }
186}
187
188#[inline]
189pub fn edcf(input: &EdcfInput) -> Result<EdcfOutput, EdcfError> {
190    edcf_with_kernel(input, Kernel::Auto)
191}
192
193#[inline(always)]
194fn edcf_prepare<'a>(
195    input: &'a EdcfInput,
196    kernel: Kernel,
197) -> Result<(&'a [f64], usize, usize, usize, Kernel), EdcfError> {
198    let data: &[f64] = input.as_ref();
199    let len = data.len();
200    if len == 0 {
201        return Err(EdcfError::NoData);
202    }
203    let first = data
204        .iter()
205        .position(|x| !x.is_nan())
206        .ok_or(EdcfError::AllValuesNaN)?;
207    let period = input.get_period();
208    if period == 0 || period > len {
209        return Err(EdcfError::InvalidPeriod {
210            period,
211            data_len: len,
212        });
213    }
214    let needed = 2 * period;
215    if (len - first) < needed {
216        return Err(EdcfError::NotEnoughValidData {
217            needed,
218            valid: len - first,
219        });
220    }
221
222    let warm = first + 2 * period;
223
224    let chosen = match kernel {
225        Kernel::Auto => Kernel::Scalar,
226        other => other,
227    };
228
229    Ok((data, period, first, warm, chosen))
230}
231
232#[inline(always)]
233fn edcf_compute_into(data: &[f64], period: usize, first: usize, chosen: Kernel, out: &mut [f64]) {
234    #[cfg(target_arch = "wasm32")]
235    {
236        if matches!(chosen, Kernel::Scalar | Kernel::ScalarBatch) {
237            edcf_scalar_wasm(data, period, first, out);
238            return;
239        }
240    }
241
242    unsafe {
243        match chosen {
244            Kernel::Scalar | Kernel::ScalarBatch => edcf_scalar(data, period, first, out),
245            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
246            Kernel::Avx2 | Kernel::Avx2Batch => edcf_avx2(data, period, first, out),
247            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
248            Kernel::Avx512 | Kernel::Avx512Batch => edcf_avx512(data, period, first, out),
249            _ => unreachable!(),
250        }
251    }
252}
253
254pub fn edcf_with_kernel(input: &EdcfInput, kernel: Kernel) -> Result<EdcfOutput, EdcfError> {
255    let (data, period, first, warm, chosen) = edcf_prepare(input, kernel)?;
256    let len = data.len();
257    let mut out = alloc_with_nan_prefix(len, warm);
258    edcf_compute_into(data, period, first, chosen, &mut out);
259    Ok(EdcfOutput { values: out })
260}
261
262#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
263#[inline]
264pub fn edcf_into(input: &EdcfInput, out: &mut [f64]) -> Result<(), EdcfError> {
265    let (data, period, first, warm, chosen) = edcf_prepare(input, Kernel::Auto)?;
266
267    if out.len() != data.len() {
268        return Err(EdcfError::OutputLenMismatch {
269            expected: data.len(),
270            got: out.len(),
271        });
272    }
273
274    let in_ptr = data.as_ptr();
275    let out_ptr = out.as_ptr();
276    if core::ptr::eq(in_ptr, out_ptr) {
277        let mut temp = alloc_with_nan_prefix(out.len(), warm);
278        edcf_compute_into(data, period, first, chosen, &mut temp);
279        out.copy_from_slice(&temp);
280        return Ok(());
281    }
282
283    let warm = warm.min(out.len());
284    for v in &mut out[..warm] {
285        *v = f64::from_bits(0x7ff8_0000_0000_0000);
286    }
287
288    edcf_compute_into(data, period, first, chosen, out);
289
290    Ok(())
291}
292
293#[inline]
294pub fn edcf_into_slice(dst: &mut [f64], input: &EdcfInput, kern: Kernel) -> Result<(), EdcfError> {
295    let (data, period, first, warm, chosen) = edcf_prepare(input, kern)?;
296
297    if dst.len() != data.len() {
298        return Err(EdcfError::OutputLenMismatch {
299            expected: data.len(),
300            got: dst.len(),
301        });
302    }
303
304    edcf_compute_into(data, period, first, chosen, dst);
305
306    for v in &mut dst[..warm] {
307        *v = f64::NAN;
308    }
309
310    Ok(())
311}
312
313#[inline(always)]
314pub fn edcf_scalar(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
315    let mut buf = vec![0.0; period];
316    let mut wbuf = vec![0.0; period];
317    edcf_scalar_o1_into(data, period, first_valid, out, &mut buf, &mut wbuf);
318}
319
/// Streaming EDCF over `data[first_valid..]` with constant work per sample.
///
/// For each sample the weight is the sum of squared differences to the preceding
/// `period - 1` samples, computed from running sums as
/// `(period - 1) * x^2 + sum(prev^2) - 2 * x * sum(prev)`. The output is the weighted mean
/// `num / den` over the last `period` (weight, value) pairs.
///
/// * `out[idx]` is written only for `idx >= first_valid + 2 * period`; earlier elements are
///   left untouched. When the weight sum is exactly zero the output is NaN.
/// * `buf` / `wbuf` are scratch rings for values and weights. Only their first `period`
///   elements are used, and they are zeroed here.
///
/// # Panics
///
/// Panics when `period == 0`, when either scratch buffer holds fewer than `period`
/// elements, or when `out` is shorter than `data`. The first two used to be debug-only
/// checks guarding unchecked indexing, which made release builds unsound for bad arguments.
#[inline(always)]
fn edcf_scalar_o1_into(
    data: &[f64],
    period: usize,
    first_valid: usize,
    out: &mut [f64],
    buf: &mut [f64],
    wbuf: &mut [f64],
) {
    assert!(period > 0, "edcf: period must be at least 1");
    assert!(
        buf.len() >= period && wbuf.len() >= period,
        "edcf: scratch buffers must hold at least `period` elements"
    );
    // Work on exactly `period` slots so the ring index invariant below holds.
    let buf = &mut buf[..period];
    let wbuf = &mut wbuf[..period];

    buf.fill(0.0);
    wbuf.fill(0.0);

    let len = data.len();
    let warm = first_valid + 2 * period;

    // Next ring slot to overwrite, and the number of samples consumed so far.
    let mut head = 0usize;
    let mut count = 0usize;

    // Running sum / sum of squares of the samples preceding the current one.
    let mut sum_prev = 0.0;
    let mut sum_prev_sq = 0.0;
    // Running weight sum and weighted value sum over the ring.
    let mut den = 0.0;
    let mut num = 0.0;
    let p_minus1_f = (period - 1) as f64;

    for idx in first_valid..len {
        let value = data[idx];

        // SAFETY: `head` starts at 0 and is reset to 0 whenever it reaches `period`, so
        // `head < period == buf.len() == wbuf.len()`.
        let (old_x, old_w) = unsafe { (*buf.get_unchecked(head), *wbuf.get_unchecked(head)) };
        let had_full_window = count >= period;

        // Weights are only defined once a full window of earlier samples exists.
        let w_new = if had_full_window {
            let x2 = value * value;
            p_minus1_f.mul_add(x2, sum_prev_sq) - (2.0 * value * sum_prev)
        } else {
            0.0
        };

        // Evict the pair that is about to be overwritten, then add the new one.
        if had_full_window {
            den -= old_w;
            num -= old_w * old_x;
        }
        den += w_new;
        num = w_new.mul_add(value, num);

        // SAFETY: `head < period`, see above.
        unsafe {
            *buf.get_unchecked_mut(head) = value;
            *wbuf.get_unchecked_mut(head) = w_new;
        }
        head += 1;
        if head == period {
            head = 0;
        }

        sum_prev += value;
        sum_prev_sq = value.mul_add(value, sum_prev_sq);
        if count >= (period - 1) {
            // After advancing, `head` points at the oldest sample in the ring; dropping it
            // keeps the running sums at the most recent `period - 1` samples.
            // SAFETY: `head < period`, see above.
            let drop_x = unsafe { *buf.get_unchecked(head) };
            sum_prev -= drop_x;
            sum_prev_sq -= drop_x * drop_x;
        }

        count += 1;

        if idx >= warm {
            out[idx] = if den != 0.0 { num / den } else { f64::NAN };
        }
    }
}
396
397#[inline(always)]
398fn edcf_scalar_into_with_scratch(
399    data: &[f64],
400    period: usize,
401    first_valid: usize,
402    out: &mut [f64],
403    scratch: &mut Vec<f64>,
404) {
405    let need = period * 2;
406    if scratch.len() < need {
407        scratch.resize(need, 0.0);
408    }
409    let (buf, wbuf) = scratch.split_at_mut(period);
410    edcf_scalar_o1_into(data, period, first_valid, out, buf, wbuf);
411}
412
/// wasm32 entry point for the scalar kernel; currently a direct call to `edcf_scalar`.
#[cfg(target_arch = "wasm32")]
#[inline]
fn edcf_scalar_wasm(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    edcf_scalar(data, period, first_valid, out);
}
418
/// AVX2 entry point; currently delegates to `edcf_scalar`, compiled with `avx2` and `fma`
/// enabled.
///
/// # Safety
///
/// Because of the `target_feature` attribute, the caller must ensure the executing CPU
/// supports AVX2 and FMA.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
pub unsafe fn edcf_avx2(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    edcf_scalar(data, period, first_valid, out);
}
/// AVX-512 entry point; currently delegates to `edcf_scalar`, compiled with `avx512f`,
/// `avx512dq` and `fma` enabled.
///
/// # Safety
///
/// Because of the `target_feature` attribute, the caller must ensure the executing CPU
/// supports AVX-512F, AVX-512DQ and FMA.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,avx512dq,fma")]
pub unsafe fn edcf_avx512(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    edcf_scalar(data, period, first_valid, out);
}
429
#[derive(Debug, Clone)]

/// Incremental EDCF state that is fed one value at a time through `update`.
pub struct EdcfStream {
    // Window length; validated to be non-zero in `try_new`.
    period: usize,

    // Ring of the most recent `period` accepted values.
    buffer: Vec<f64>,

    // Ring of the weights stored alongside `buffer`, slot for slot.
    dist: Vec<f64>,

    // Ring slot that the next accepted value overwrites.
    head: usize,

    // Number of values accepted so far.
    count: usize,

    // Running sum and sum of squares of recent values, used to derive the next weight.
    sum_prev: f64,
    sum_prev_sq: f64,

    // Running weight sum (`den`) and weighted value sum (`num`) over the ring.
    den: f64,
    num: f64,

    // `(period - 1) as f64`, cached for the weight formula.
    p_minus1_f: f64,
}
451
452impl EdcfStream {
453    #[inline]
454    pub fn try_new(params: EdcfParams) -> Result<Self, EdcfError> {
455        let period = params.period.unwrap_or(15);
456        if period == 0 {
457            return Err(EdcfError::InvalidPeriod {
458                period,
459                data_len: 0,
460            });
461        }
462
463        let buffer = alloc_with_nan_prefix(period, period);
464        let dist = vec![0.0; period];
465
466        Ok(Self {
467            period,
468            buffer,
469            dist,
470            head: 0,
471            count: 0,
472            sum_prev: 0.0,
473            sum_prev_sq: 0.0,
474            den: 0.0,
475            num: 0.0,
476            p_minus1_f: (period - 1) as f64,
477        })
478    }
479
480    #[inline(always)]
481    fn bump_head(&mut self) {
482        let n = self.head + 1;
483        self.head = if n == self.period { 0 } else { n };
484    }
485
486    #[inline(always)]
487    pub fn update(&mut self, value: f64) -> Option<f64> {
488        if !value.is_finite() {
489            return None;
490        }
491
492        let p = self.period;
493
494        let old_x = self.buffer[self.head];
495        let old_w = self.dist[self.head];
496        let had_full_window = self.count >= p;
497
498        let w_new = if self.count >= p {
499            let x2 = value * value;
500            self.p_minus1_f.mul_add(x2, self.sum_prev_sq) - (2.0 * value * self.sum_prev)
501        } else {
502            0.0
503        };
504
505        if had_full_window {
506            self.den -= old_w;
507            self.num -= old_w * old_x;
508        }
509
510        self.den += w_new;
511        self.num = w_new.mul_add(value, self.num);
512
513        self.buffer[self.head] = value;
514        self.dist[self.head] = w_new;
515        self.bump_head();
516
517        self.sum_prev += value;
518        self.sum_prev_sq = value.mul_add(value, self.sum_prev_sq);
519        if self.count >= (p - 1) {
520            let drop_x = self.buffer[self.head];
521            self.sum_prev -= drop_x;
522            self.sum_prev_sq -= drop_x * drop_x;
523        }
524
525        self.count += 1;
526        if self.count < 2 * p {
527            return None;
528        }
529        if self.den != 0.0 {
530            Some(fast_div(self.num, self.den))
531        } else {
532            None
533        }
534    }
535}
536
/// Plain floating-point division of `num` by `den`.
#[inline(always)]
fn fast_div(num: f64, den: f64) -> f64 {
    std::ops::Div::div(num, den)
}
541
/// Period sweep for batch runs.
#[derive(Clone, Debug)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct EdcfBatchRange {
    /// `(start, end, step)`; `expand_grid` swaps reversed bounds and treats a zero step or
    /// equal bounds as a single period.
    pub period: (usize, usize, usize),
}
550
551impl Default for EdcfBatchRange {
552    fn default() -> Self {
553        Self {
554            period: (15, 264, 1),
555        }
556    }
557}
558
#[derive(Clone, Debug, Default)]

/// Fluent builder for batch (multi-period) EDCF runs.
pub struct EdcfBatchBuilder {
    // Period sweep forwarded to `edcf_batch_with_kernel`.
    range: EdcfBatchRange,
    // Kernel forwarded to `edcf_batch_with_kernel`.
    kernel: Kernel,
}
565
566impl EdcfBatchBuilder {
567    pub fn new() -> Self {
568        Self::default()
569    }
570    pub fn kernel(mut self, k: Kernel) -> Self {
571        self.kernel = k;
572        self
573    }
574    #[inline]
575    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
576        self.range.period = (start, end, step);
577        self
578    }
579    #[inline]
580    pub fn period_static(mut self, p: usize) -> Self {
581        self.range.period = (p, p, 0);
582        self
583    }
584    pub fn apply_slice(self, data: &[f64]) -> Result<EdcfBatchOutput, EdcfError> {
585        edcf_batch_with_kernel(data, &self.range, self.kernel)
586    }
587    pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<EdcfBatchOutput, EdcfError> {
588        EdcfBatchBuilder::new().kernel(k).apply_slice(data)
589    }
590    pub fn apply_candles(self, c: &Candles, src: &str) -> Result<EdcfBatchOutput, EdcfError> {
591        let slice = source_type(c, src);
592        self.apply_slice(slice)
593    }
594    pub fn with_default_candles(c: &Candles) -> Result<EdcfBatchOutput, EdcfError> {
595        EdcfBatchBuilder::new()
596            .kernel(Kernel::Auto)
597            .apply_candles(c, "close")
598    }
599}
600
601pub fn edcf_batch_with_kernel(
602    data: &[f64],
603    sweep: &EdcfBatchRange,
604    k: Kernel,
605) -> Result<EdcfBatchOutput, EdcfError> {
606    if data.is_empty() {
607        return Err(EdcfError::NoData);
608    }
609    let kernel = match k {
610        Kernel::Auto => Kernel::ScalarBatch,
611        other if other.is_batch() => other,
612        other => return Err(EdcfError::InvalidKernelForBatch(other)),
613    };
614    let simd = match kernel {
615        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
616        Kernel::Avx512Batch => Kernel::Avx512,
617        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
618        Kernel::Avx2Batch => Kernel::Avx2,
619        Kernel::ScalarBatch => Kernel::Scalar,
620        _ => unreachable!(),
621    };
622    edcf_batch_par_slice(data, sweep, simd)
623}
624
/// Result of a batch run: one output row per parameter combination.
#[derive(Clone, Debug)]
pub struct EdcfBatchOutput {
    /// Row-major matrix; `values_for` reads row `r` from `values[r * cols..(r + 1) * cols]`.
    pub values: Vec<f64>,

    /// Parameter combination for each row, in row order.
    pub combos: Vec<EdcfParams>,

    /// Number of rows (set to the number of combinations by `edcf_batch_inner`).
    pub rows: usize,

    /// Number of columns (set to the input length by `edcf_batch_inner`).
    pub cols: usize,
}
635impl EdcfBatchOutput {
636    pub fn row_for_params(&self, p: &EdcfParams) -> Option<usize> {
637        self.combos
638            .iter()
639            .position(|c| c.period.unwrap_or(15) == p.period.unwrap_or(15))
640    }
641    pub fn values_for(&self, p: &EdcfParams) -> Option<&[f64]> {
642        self.row_for_params(p).map(|row| {
643            let start = row * self.cols;
644            &self.values[start..start + self.cols]
645        })
646    }
647}
648
649#[inline(always)]
650fn expand_grid(r: &EdcfBatchRange) -> Vec<EdcfParams> {
651    let (mut start, mut end, step) = r.period;
652
653    if start > end {
654        core::mem::swap(&mut start, &mut end);
655    }
656    let periods: Vec<usize> = if step == 0 || start == end {
657        vec![start]
658    } else {
659        (start..=end).step_by(step).collect()
660    };
661
662    periods
663        .into_iter()
664        .map(|p| EdcfParams { period: Some(p) })
665        .collect()
666}
667
668#[inline(always)]
669pub fn edcf_batch_slice(
670    data: &[f64],
671    sweep: &EdcfBatchRange,
672    kern: Kernel,
673) -> Result<EdcfBatchOutput, EdcfError> {
674    edcf_batch_inner(data, sweep, kern, false)
675}
676
677#[inline(always)]
678pub fn edcf_batch_par_slice(
679    data: &[f64],
680    sweep: &EdcfBatchRange,
681    kern: Kernel,
682) -> Result<EdcfBatchOutput, EdcfError> {
683    edcf_batch_inner(data, sweep, kern, true)
684}
685
686#[inline(always)]
687fn edcf_batch_inner(
688    data: &[f64],
689    sweep: &EdcfBatchRange,
690    kern: Kernel,
691    parallel: bool,
692) -> Result<EdcfBatchOutput, EdcfError> {
693    if data.is_empty() {
694        return Err(EdcfError::NoData);
695    }
696
697    let combos = expand_grid(sweep);
698    if combos.is_empty() {
699        return Err(EdcfError::InvalidRange {
700            start: sweep.period.0,
701            end: sweep.period.1,
702            step: sweep.period.2,
703        });
704    }
705
706    let rows = combos.len();
707    let cols = data.len();
708
709    let _total = rows
710        .checked_mul(cols)
711        .ok_or(EdcfError::SizeOverflow { op: "rows*cols" })?;
712
713    let mut buf_mu = make_uninit_matrix(rows, cols);
714
715    let warm: Vec<usize> = combos
716        .iter()
717        .map(|c| data.iter().position(|x| !x.is_nan()).unwrap_or(0) + 2 * c.period.unwrap_or(15))
718        .collect();
719
720    init_matrix_prefixes(&mut buf_mu, cols, &warm);
721
722    let mut buf_guard = std::mem::ManuallyDrop::new(buf_mu);
723    let out: &mut [f64] = unsafe {
724        core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
725    };
726
727    let result_combos = edcf_batch_inner_into(data, sweep, kern, parallel, out)?;
728
729    let values = unsafe {
730        Vec::from_raw_parts(
731            buf_guard.as_mut_ptr() as *mut f64,
732            buf_guard.len(),
733            buf_guard.capacity(),
734        )
735    };
736
737    Ok(EdcfBatchOutput {
738        values,
739        combos: result_combos,
740        rows,
741        cols,
742    })
743}
744
745#[inline(always)]
746fn edcf_batch_inner_into(
747    data: &[f64],
748    sweep: &EdcfBatchRange,
749    kern: Kernel,
750    parallel: bool,
751    out: &mut [f64],
752) -> Result<Vec<EdcfParams>, EdcfError> {
753    if data.is_empty() {
754        return Err(EdcfError::NoData);
755    }
756    let combos = expand_grid(sweep);
757    if combos.is_empty() {
758        return Err(EdcfError::InvalidRange {
759            start: sweep.period.0,
760            end: sweep.period.1,
761            step: sweep.period.2,
762        });
763    }
764
765    let first = data
766        .iter()
767        .position(|x| !x.is_nan())
768        .ok_or(EdcfError::AllValuesNaN)?;
769    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
770    if data.len() - first < 2 * max_p {
771        return Err(EdcfError::NotEnoughValidData {
772            needed: 2 * max_p,
773            valid: data.len() - first,
774        });
775    }
776
777    let rows = combos.len();
778    let cols = data.len();
779
780    if parallel {
781        #[cfg(not(target_arch = "wasm32"))]
782        {
783            use rayon::prelude::*;
784
785            out.par_chunks_mut(cols).enumerate().for_each(|(row, dst)| {
786                let period = combos[row].period.unwrap();
787
788                let mut scratch = Vec::<f64>::new();
789                match kern {
790                    Kernel::Scalar | Kernel::Avx2 | Kernel::Avx512 => {
791                        edcf_scalar_into_with_scratch(data, period, first, dst, &mut scratch);
792                    }
793                    _ => unsafe { edcf_row_scalar(data, first, period, dst) },
794                }
795            });
796        }
797
798        #[cfg(target_arch = "wasm32")]
799        {
800            for (row, dst) in out.chunks_mut(cols).enumerate() {
801                let period = combos[row].period.unwrap();
802                unsafe { edcf_row_scalar(data, first, period, dst) }
803            }
804        }
805    } else {
806        #[cfg(not(target_arch = "wasm32"))]
807        {
808            let mut scratch = Vec::<f64>::new();
809            for (row, dst) in out.chunks_mut(cols).enumerate() {
810                let period = combos[row].period.unwrap();
811                match kern {
812                    Kernel::Scalar | Kernel::Avx2 | Kernel::Avx512 => {
813                        edcf_scalar_into_with_scratch(data, period, first, dst, &mut scratch)
814                    }
815                    _ => unsafe { edcf_row_scalar(data, first, period, dst) },
816                }
817            }
818        }
819        #[cfg(target_arch = "wasm32")]
820        {
821            for (row, dst) in out.chunks_mut(cols).enumerate() {
822                let period = combos[row].period.unwrap();
823                unsafe { edcf_row_scalar(data, first, period, dst) }
824            }
825        }
826    }
827
828    Ok(combos)
829}
830
/// Row-level scalar kernel for batch runs; forwards to `edcf_scalar`. Note that the
/// argument order here is `(first, period)`, the reverse of `edcf_scalar`.
///
/// # Safety
///
/// The body only calls the safe `edcf_scalar`, so no additional requirement is visible.
#[inline(always)]
unsafe fn edcf_row_scalar(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    edcf_scalar(data, period, first, out)
}
/// Row-level AVX2 kernel for batch runs; forwards to `edcf_avx2` with the
/// `(first, period)` arguments swapped into that function's order.
///
/// # Safety
///
/// Same requirement as `edcf_avx2`, which is declared with `target_feature`: the CPU must
/// support the enabled features.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn edcf_row_avx2(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    edcf_avx2(data, period, first, out);
}
/// Row-level AVX-512 kernel for batch runs; forwards to `edcf_avx512` with the
/// `(first, period)` arguments swapped into that function's order.
///
/// # Safety
///
/// Same requirement as `edcf_avx512`, which is declared with `target_feature`: the CPU
/// must support the enabled features.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn edcf_row_avx512(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    edcf_avx512(data, period, first, out);
}
845
/// Registers the EDCF Python bindings on module `m`: the `edcf_py` and `edcf_batch_py`
/// functions, the `EdcfStreamPy` class and, with the `cuda` feature, the two CUDA entry
/// points. The registered items are defined outside the part of the file visible here.
#[cfg(feature = "python")]
pub fn register_edcf_module(m: &Bound<'_, pyo3::types::PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(edcf_py, m)?)?;
    m.add_function(wrap_pyfunction!(edcf_batch_py, m)?)?;
    m.add_class::<EdcfStreamPy>()?;
    // CUDA-backed functions are only registered when the `cuda` feature is enabled.
    #[cfg(feature = "cuda")]
    {
        m.add_function(wrap_pyfunction!(edcf_cuda_batch_dev_py, m)?)?;
        m.add_function(wrap_pyfunction!(edcf_cuda_many_series_one_param_dev_py, m)?)?;
    }
    Ok(())
}
858
859#[cfg(test)]
860mod tests {
861    use super::*;
862    use crate::skip_if_unsupported;
863    use crate::utilities::data_loader::read_candles_from_csv;
864    use proptest::prelude::*;
865    use std::error::Error;
866
867    #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
868    #[test]
869    fn test_edcf_into_matches_api() -> Result<(), Box<dyn Error>> {
870        let mut data: Vec<f64> = Vec::new();
871        data.extend_from_slice(&[f64::NAN, f64::NAN, f64::NAN, f64::NAN, f64::NAN]);
872        for i in 0..250usize {
873            let x = (i as f64).sin() * 3.0 + (i as f64) * 0.05 + ((i % 7) as f64) * 0.1;
874            data.push(x);
875        }
876
877        let input = EdcfInput::from_slice(&data, EdcfParams::default());
878
879        let baseline = edcf(&input)?.values;
880
881        let mut out = vec![0.0; data.len()];
882        edcf_into(&input, &mut out)?;
883
884        assert_eq!(baseline.len(), out.len());
885
886        fn eq_or_both_nan(a: f64, b: f64) -> bool {
887            (a.is_nan() && b.is_nan()) || (a - b).abs() <= 1e-12
888        }
889
890        for i in 0..out.len() {
891            assert!(
892                eq_or_both_nan(baseline[i], out[i]),
893                "mismatch at {}: expected {:?}, got {:?}",
894                i,
895                baseline[i],
896                out[i]
897            );
898        }
899
900        Ok(())
901    }
902
903    fn check_edcf_partial_params(
904        test_name: &str,
905        kernel: Kernel,
906    ) -> Result<(), Box<dyn std::error::Error>> {
907        skip_if_unsupported!(kernel, test_name);
908        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
909        let candles = read_candles_from_csv(file_path)?;
910        let input = EdcfInput::from_candles(&candles, "close", EdcfParams { period: None });
911        let result = edcf_with_kernel(&input, kernel)?;
912        assert_eq!(result.values.len(), candles.close.len());
913        Ok(())
914    }
915
916    fn check_edcf_accuracy_last_five(
917        test_name: &str,
918        kernel: Kernel,
919    ) -> Result<(), Box<dyn std::error::Error>> {
920        skip_if_unsupported!(kernel, test_name);
921        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
922        let candles = read_candles_from_csv(file_path)?;
923        let input = EdcfInput::from_candles(&candles, "hl2", EdcfParams { period: Some(15) });
924        let result = edcf_with_kernel(&input, kernel)?;
925        let expected = [
926            59593.332275678375,
927            59731.70263288801,
928            59766.41512339413,
929            59655.66162110993,
930            59332.492883847,
931        ];
932        let len = result.values.len();
933        let start = len - expected.len();
934        for (i, &v) in result.values[start..].iter().enumerate() {
935            assert!(
936                (v - expected[i]).abs() < 1e-8,
937                "[{}] EDCF mismatch at {}: got {}, expected {}",
938                test_name,
939                start + i,
940                v,
941                expected[i]
942            );
943        }
944        Ok(())
945    }
946
947    fn check_edcf_with_default_candles(
948        test_name: &str,
949        kernel: Kernel,
950    ) -> Result<(), Box<dyn std::error::Error>> {
951        skip_if_unsupported!(kernel, test_name);
952        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
953        let candles = read_candles_from_csv(file_path)?;
954        let input = EdcfInput::with_default_candles(&candles);
955        match input.data {
956            EdcfData::Candles { source, .. } => assert_eq!(source, "close"),
957            _ => panic!("Expected EdcfData::Candles"),
958        }
959        let period = input.get_period();
960        assert_eq!(period, 15);
961        Ok(())
962    }
963
964    fn check_edcf_with_zero_period(
965        test_name: &str,
966        kernel: Kernel,
967    ) -> Result<(), Box<dyn std::error::Error>> {
968        skip_if_unsupported!(kernel, test_name);
969        let data = [10.0, 20.0, 30.0];
970        let input = EdcfInput::from_slice(&data, EdcfParams { period: Some(0) });
971        let result = edcf_with_kernel(&input, kernel);
972        assert!(result.is_err());
973        Ok(())
974    }
975
976    fn check_edcf_with_no_data(
977        test_name: &str,
978        kernel: Kernel,
979    ) -> Result<(), Box<dyn std::error::Error>> {
980        skip_if_unsupported!(kernel, test_name);
981        let data: [f64; 0] = [];
982        let input = EdcfInput::from_slice(&data, EdcfParams { period: Some(15) });
983        let result = edcf_with_kernel(&input, kernel);
984        assert!(result.is_err());
985        Ok(())
986    }
987
988    fn check_edcf_with_period_exceeding_data_length(
989        test_name: &str,
990        kernel: Kernel,
991    ) -> Result<(), Box<dyn std::error::Error>> {
992        skip_if_unsupported!(kernel, test_name);
993        let data = [10.0, 20.0, 30.0];
994        let input = EdcfInput::from_slice(&data, EdcfParams { period: Some(10) });
995        let result = edcf_with_kernel(&input, kernel);
996        assert!(result.is_err());
997        Ok(())
998    }
999
1000    fn check_edcf_very_small_data_set(
1001        test_name: &str,
1002        kernel: Kernel,
1003    ) -> Result<(), Box<dyn std::error::Error>> {
1004        skip_if_unsupported!(kernel, test_name);
1005        let data = [42.0];
1006        let input = EdcfInput::from_slice(&data, EdcfParams { period: Some(15) });
1007        let result = edcf_with_kernel(&input, kernel);
1008        assert!(result.is_err());
1009        Ok(())
1010    }
1011
1012    fn check_edcf_with_slice_data_reinput(
1013        test_name: &str,
1014        kernel: Kernel,
1015    ) -> Result<(), Box<dyn std::error::Error>> {
1016        skip_if_unsupported!(kernel, test_name);
1017        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1018        let candles = read_candles_from_csv(file_path)?;
1019        let first_input =
1020            EdcfInput::from_candles(&candles, "close", EdcfParams { period: Some(15) });
1021        let first_result = edcf_with_kernel(&first_input, kernel)?;
1022        let first_values = first_result.values;
1023        let second_input = EdcfInput::from_slice(&first_values, EdcfParams { period: Some(5) });
1024        let second_result = edcf_with_kernel(&second_input, kernel)?;
1025        assert_eq!(second_result.values.len(), first_values.len());
1026        if second_result.values.len() > 240 {
1027            for (i, &val) in second_result.values.iter().enumerate().skip(240) {
1028                assert!(
1029                    !val.is_nan(),
1030                    "Found NaN in second EDCF output at index {}",
1031                    i
1032                );
1033            }
1034        }
1035        Ok(())
1036    }
1037
1038    fn check_edcf_accuracy_nan_check(
1039        test_name: &str,
1040        kernel: Kernel,
1041    ) -> Result<(), Box<dyn std::error::Error>> {
1042        skip_if_unsupported!(kernel, test_name);
1043        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1044        let candles = read_candles_from_csv(file_path)?;
1045        let period = 15;
1046        let input = EdcfInput::from_candles(
1047            &candles,
1048            "close",
1049            EdcfParams {
1050                period: Some(period),
1051            },
1052        );
1053        let result = edcf_with_kernel(&input, kernel)?;
1054        assert_eq!(result.values.len(), candles.close.len());
1055        let start_index = 2 * period;
1056        if result.values.len() > start_index {
1057            for (i, &val) in result.values.iter().enumerate().skip(start_index) {
1058                assert!(!val.is_nan(), "Found NaN in EDCF output at index {}", i);
1059            }
1060        }
1061        Ok(())
1062    }
1063
1064    fn check_edcf_streaming(
1065        test_name: &str,
1066        kernel: Kernel,
1067    ) -> Result<(), Box<dyn std::error::Error>> {
1068        skip_if_unsupported!(kernel, test_name);
1069        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1070        let candles = read_candles_from_csv(file_path)?;
1071
1072        let input = EdcfInput::from_candles(&candles, "close", EdcfParams { period: Some(15) });
1073        let _batch = edcf_with_kernel(&input, kernel)?;
1074
1075        let mut stream = EdcfStream::try_new(EdcfParams { period: Some(15) })?;
1076        let mut vals = Vec::with_capacity(candles.close.len());
1077        for &v in &candles.close {
1078            vals.push(stream.update(v).unwrap_or(f64::NAN));
1079        }
1080        for (i, &v) in vals.iter().enumerate().skip(30) {
1081            assert!(!v.is_nan(), "[{test_name}] NaN at {i}");
1082        }
1083        Ok(())
1084    }
1085
1086    #[cfg(debug_assertions)]
1087    fn check_edcf_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1088        skip_if_unsupported!(kernel, test_name);
1089
1090        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1091        let candles = read_candles_from_csv(file_path)?;
1092
1093        let test_periods = vec![3, 5, 10, 15, 30, 50, 100, 200];
1094        let test_sources = vec!["open", "high", "low", "close", "hl2", "hlc3", "ohlc4"];
1095
1096        for period in &test_periods {
1097            for source in &test_sources {
1098                let input = EdcfInput::from_candles(
1099                    &candles,
1100                    source,
1101                    EdcfParams {
1102                        period: Some(*period),
1103                    },
1104                );
1105                let output = edcf_with_kernel(&input, kernel)?;
1106
1107                for (i, &val) in output.values.iter().enumerate() {
1108                    if val.is_nan() {
1109                        continue;
1110                    }
1111
1112                    let bits = val.to_bits();
1113
1114                    if bits == 0x11111111_11111111 {
1115                        panic!(
1116                            "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} with period={}, source={}",
1117                            test_name, val, bits, i, period, source
1118                        );
1119                    }
1120
1121                    if bits == 0x22222222_22222222 {
1122                        panic!(
1123                            "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} with period={}, source={}",
1124                            test_name, val, bits, i, period, source
1125                        );
1126                    }
1127
1128                    if bits == 0x33333333_33333333 {
1129                        panic!(
1130                            "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} with period={}, source={}",
1131                            test_name, val, bits, i, period, source
1132                        );
1133                    }
1134                }
1135            }
1136        }
1137
1138        Ok(())
1139    }
1140
1141    #[cfg(not(debug_assertions))]
1142    fn check_edcf_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1143        Ok(())
1144    }
1145
1146    #[allow(clippy::float_cmp)]
1147    fn check_edcf_property(
1148        test_name: &str,
1149        kernel: Kernel,
1150    ) -> Result<(), Box<dyn std::error::Error>> {
1151        use proptest::prelude::*;
1152        skip_if_unsupported!(kernel, test_name);
1153
1154        let strat = (3usize..=30).prop_flat_map(|period| {
1155            (
1156                prop::collection::vec(
1157                    (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
1158                    2 * period..400,
1159                ),
1160                Just(period),
1161                (-1e3f64..1e3f64).prop_filter("a≠0", |a| a.abs() > 1e-12),
1162                -1e3f64..1e3f64,
1163            )
1164        });
1165
1166        proptest::test_runner::TestRunner::default().run(&strat, |(data, period, a, b)| {
1167            let params = EdcfParams {
1168                period: Some(period),
1169            };
1170            let input = EdcfInput::from_slice(&data, params.clone());
1171
1172            let fast = edcf_with_kernel(&input, kernel);
1173            let slow = edcf_with_kernel(&input, Kernel::Scalar);
1174
1175            match (fast, slow) {
1176                (Err(e1), Err(e2))
1177                    if std::mem::discriminant(&e1) == std::mem::discriminant(&e2) =>
1178                {
1179                    return Ok(());
1180                }
1181
1182                (Err(e1), Err(e2)) => {
1183                    prop_assert!(false, "different errors: fast={:?} slow={:?}", e1, e2)
1184                }
1185
1186                (Err(e), Ok(_)) => prop_assert!(false, "fast errored {e:?} but scalar succeeded"),
1187                (Ok(_), Err(e)) => prop_assert!(false, "scalar errored {e:?} but fast succeeded"),
1188
1189                (Ok(fast), Ok(reference)) => {
1190                    let EdcfOutput { values: out } = fast;
1191                    let EdcfOutput { values: rref } = reference;
1192
1193                    let mut stream = EdcfStream::try_new(params.clone()).unwrap();
1194                    let mut s_out = Vec::with_capacity(data.len());
1195                    for &v in &data {
1196                        s_out.push(stream.update(v).unwrap_or(f64::NAN));
1197                    }
1198
1199                    let transformed: Vec<f64> = data.iter().map(|x| a * x + b).collect();
1200                    let t_out = edcf(&EdcfInput::from_slice(&transformed, params.clone()))?.values;
1201
1202                    let warm = 2 * period;
1203
1204                    for i in warm..data.len() {
1205                        let win = &data[i + 1 - period..=i];
1206                        let (lo, hi) = win
1207                            .iter()
1208                            .fold((f64::INFINITY, f64::NEG_INFINITY), |(l, h), &v| {
1209                                (l.min(v), h.max(v))
1210                            });
1211                        let y = out[i];
1212                        let yr = rref[i];
1213                        let ys = s_out[i];
1214                        let yt = t_out[i];
1215
1216                        prop_assert!(
1217                            y.is_nan() || (y >= lo - 1e-9 && y <= hi + 1e-9),
1218                            "idx {i}: {y} ∉ [{lo}, {hi}]"
1219                        );
1220
1221                        if win.iter().all(|v| *v == win[0]) {
1222                            prop_assert!(y.is_nan(), "idx {i}: expected NaN on constant series");
1223                        }
1224
1225                        if y.is_finite() && yt.is_finite() {
1226                            let expect = a * y + b;
1227                            let diff = (yt - expect).abs();
1228                            let tol = 1e-9_f64.max(expect.abs() * 1e-9);
1229                            let ulp = yt.to_bits().abs_diff(expect.to_bits());
1230                            prop_assert!(
1231                                diff <= tol || ulp <= 8,
1232                                "idx {i}: affine mismatch diff={diff:e}  ULP={ulp}"
1233                            );
1234                        }
1235
1236                        let ulp = y.to_bits().abs_diff(yr.to_bits());
1237                        prop_assert!(
1238                            (y - yr).abs() <= 1e-9 || ulp <= 4,
1239                            "idx {i}: fast={y} ref={yr} ULP={ulp}"
1240                        );
1241
1242                        prop_assert!(
1243                            (y - ys).abs() <= 1e-9 || (y.is_nan() && ys.is_nan()),
1244                            "idx {i}: stream mismatch"
1245                        );
1246                    }
1247
1248                    let first = data.iter().position(|x| !x.is_nan()).unwrap_or(data.len());
1249                    let warm_expected = first + warm;
1250                    prop_assert!(out[..warm_expected].iter().all(|v| v.is_nan()));
1251                }
1252            }
1253
1254            Ok(())
1255        })?;
1256
1257        assert!(edcf(&EdcfInput::from_slice(&[], EdcfParams::default())).is_err());
1258        assert!(edcf(&EdcfInput::from_slice(
1259            &[f64::NAN; 12],
1260            EdcfParams::default()
1261        ))
1262        .is_err());
1263        assert!(edcf(&EdcfInput::from_slice(
1264            &[1.0; 5],
1265            EdcfParams { period: Some(8) }
1266        ))
1267        .is_err());
1268        assert!(edcf(&EdcfInput::from_slice(
1269            &[1.0; 5],
1270            EdcfParams { period: Some(0) }
1271        ))
1272        .is_err());
1273
1274        Ok(())
1275    }
1276
1277    fn check_edcf_invalid_kernel(
1278        test_name: &str,
1279        _kernel: Kernel,
1280    ) -> Result<(), Box<dyn std::error::Error>> {
1281        let data = [1.0, 2.0, 3.0];
1282        let range = EdcfBatchRange::default();
1283        let res = edcf_batch_with_kernel(&data, &range, Kernel::Avx2);
1284        assert!(
1285            matches!(res, Err(EdcfError::InvalidKernelForBatch(Kernel::Avx2))),
1286            "{}",
1287            test_name
1288        );
1289        Ok(())
1290    }
1291
1292    macro_rules! generate_all_edcf_tests {
1293        ($($test_fn:ident),*) => {
1294            paste::paste! {
1295                $(
1296                    #[test]
1297                    fn [<$test_fn _scalar_f64>]() {
1298                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1299                    }
1300                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1301                    #[test]
1302                    fn [<$test_fn _avx2_f64>]() {
1303                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1304                    }
1305                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1306                    #[test]
1307                    fn [<$test_fn _avx512_f64>]() {
1308                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1309                    }
1310                )*
1311            }
1312        }
1313    }
1314
1315    generate_all_edcf_tests!(
1316        check_edcf_partial_params,
1317        check_edcf_accuracy_last_five,
1318        check_edcf_with_default_candles,
1319        check_edcf_with_zero_period,
1320        check_edcf_with_no_data,
1321        check_edcf_with_period_exceeding_data_length,
1322        check_edcf_very_small_data_set,
1323        check_edcf_with_slice_data_reinput,
1324        check_edcf_accuracy_nan_check,
1325        check_edcf_streaming,
1326        check_edcf_property,
1327        check_edcf_invalid_kernel,
1328        check_edcf_no_poison
1329    );
1330
1331    fn check_batch_default_row(
1332        test: &str,
1333        kernel: Kernel,
1334    ) -> Result<(), Box<dyn std::error::Error>> {
1335        skip_if_unsupported!(kernel, test);
1336        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1337        let c = read_candles_from_csv(file)?;
1338        let output = EdcfBatchBuilder::new()
1339            .kernel(kernel)
1340            .apply_candles(&c, "close")?;
1341        let def = EdcfParams::default();
1342        let row = output.values_for(&def).expect("default row missing");
1343        assert_eq!(row.len(), c.close.len());
1344        Ok(())
1345    }
1346
1347    #[cfg(debug_assertions)]
1348    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1349        skip_if_unsupported!(kernel, test);
1350
1351        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1352        let c = read_candles_from_csv(file)?;
1353
1354        let test_sources = vec!["open", "high", "low", "close", "hl2", "hlc3", "ohlc4"];
1355
1356        for source in &test_sources {
1357            let output = EdcfBatchBuilder::new()
1358                .kernel(kernel)
1359                .period_range(3, 200, 5)
1360                .apply_candles(&c, source)?;
1361
1362            for (idx, &val) in output.values.iter().enumerate() {
1363                if val.is_nan() {
1364                    continue;
1365                }
1366
1367                let bits = val.to_bits();
1368                let row = idx / output.cols;
1369                let col = idx % output.cols;
1370
1371                if bits == 0x11111111_11111111 {
1372                    panic!(
1373                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at row {} col {} (flat index {}) with source={}",
1374                        test, val, bits, row, col, idx, source
1375                    );
1376                }
1377
1378                if bits == 0x22222222_22222222 {
1379                    panic!(
1380                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at row {} col {} (flat index {}) with source={}",
1381                        test, val, bits, row, col, idx, source
1382                    );
1383                }
1384
1385                if bits == 0x33333333_33333333 {
1386                    panic!(
1387                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at row {} col {} (flat index {}) with source={}",
1388                        test, val, bits, row, col, idx, source
1389                    );
1390                }
1391            }
1392        }
1393
1394        let edge_case_ranges = vec![(3, 5, 1), (190, 200, 2), (50, 100, 10)];
1395        for (start, end, step) in edge_case_ranges {
1396            let output = EdcfBatchBuilder::new()
1397                .kernel(kernel)
1398                .period_range(start, end, step)
1399                .apply_candles(&c, "close")?;
1400
1401            for (idx, &val) in output.values.iter().enumerate() {
1402                if val.is_nan() {
1403                    continue;
1404                }
1405
1406                let bits = val.to_bits();
1407                let row = idx / output.cols;
1408                let col = idx % output.cols;
1409
1410                if bits == 0x11111111_11111111
1411                    || bits == 0x22222222_22222222
1412                    || bits == 0x33333333_33333333
1413                {
1414                    panic!(
1415						"[{}] Found poison value {} (0x{:016X}) at row {} col {} with range ({},{},{})",
1416						test, val, bits, row, col, start, end, step
1417					);
1418                }
1419            }
1420        }
1421
1422        Ok(())
1423    }
1424
1425    #[cfg(not(debug_assertions))]
1426    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1427        Ok(())
1428    }
1429
1430    macro_rules! gen_batch_tests {
1431        ($fn_name:ident) => {
1432            paste::paste! {
1433                #[test]
1434                fn [<$fn_name _scalar>]() {
1435                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
1436                }
1437                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1438                #[test]
1439                fn [<$fn_name _avx2>]() {
1440                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
1441                }
1442                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1443                #[test]
1444                fn [<$fn_name _avx512>]() {
1445                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
1446                }
1447                #[test]
1448                fn [<$fn_name _auto_detect>]() {
1449                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
1450                }
1451            }
1452        };
1453    }
1454
1455    gen_batch_tests!(check_batch_default_row);
1456    gen_batch_tests!(check_batch_no_poison);
1457}
1458
// Python binding `edcf(data, period, kernel=None)`.
//
// Borrows the NumPy input as a slice, validates the kernel name (non-batch mode), runs
// the computation through `Python::allow_threads`, and returns the values as a new 1-D
// NumPy array. Any EDCF error is surfaced as a Python `ValueError` carrying the error's
// display text.
#[cfg(feature = "python")]
#[pyfunction(name = "edcf")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn edcf_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let series = data.as_slice()?;
    let kern = validate_kernel(kernel, false)?;
    let input = EdcfInput::from_slice(
        series,
        EdcfParams {
            period: Some(period),
        },
    );

    let computed = py.allow_threads(|| edcf_with_kernel(&input, kern).map(|o| o.values));
    match computed {
        Ok(values) => Ok(values.into_pyarray(py)),
        Err(err) => Err(PyValueError::new_err(err.to_string())),
    }
}
1484
// Python-visible `EdcfStream` class: a thin wrapper that owns a native `EdcfStream`.
#[cfg(feature = "python")]
#[pyclass(name = "EdcfStream")]
pub struct EdcfStreamPy {
    stream: EdcfStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl EdcfStreamPy {
    // Constructor: builds the wrapped stream for `period`; a construction error becomes
    // a Python `ValueError` carrying the error's display text.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        EdcfStream::try_new(EdcfParams {
            period: Some(period),
        })
        .map(|stream| Self { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    // Forwards one value to the wrapped stream and returns whatever it yields
    // (`None` maps to Python `None`).
    fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
1508
// Python binding `edcf_batch(data, period_range, kernel=None)`.
//
// Expands `period_range` into a parameter grid, allocates a flat `rows * cols` NumPy
// buffer, computes every row directly into it, and returns a dict with:
//   "values"  - the buffer reshaped to (rows, cols)
//   "periods" - the period of each row as u64
// Errors (size overflow, invalid kernel, EDCF failures) become Python `ValueError`s.
#[cfg(feature = "python")]
#[pyfunction(name = "edcf_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn edcf_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};

    let slice_in = data.as_slice()?;
    let sweep = EdcfBatchRange {
        period: period_range,
    };

    let grid = expand_grid(&sweep);
    let (rows, cols) = (grid.len(), slice_in.len());
    let total = match rows.checked_mul(cols) {
        Some(cells) => cells,
        None => return Err(PyValueError::new_err("edcf_batch: rows*cols overflow")),
    };

    // The array comes from the unsafe `PyArray1::new` and is filled in place below.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    // An empty input has no non-NaN position, so it is covered by the `None` case.
    let first_valid = if rows > 0 {
        slice_in.iter().position(|x| !x.is_nan())
    } else {
        None
    };
    if let Some(first) = first_valid {
        // Per-row warm-up length: first valid index + 2 * period, capped at the row width.
        let warm: Vec<usize> = grid
            .iter()
            .map(|c| (first + 2 * c.period.unwrap_or(15)).min(cols))
            .collect();

        let buf_mu: &mut [MaybeUninit<f64>] = unsafe {
            core::slice::from_raw_parts_mut(
                slice_out.as_mut_ptr() as *mut MaybeUninit<f64>,
                slice_out.len(),
            )
        };
        init_matrix_prefixes(buf_mu, cols, &warm);
    }

    let kern = validate_kernel(kernel, true)?;

    let evaluated = py
        .allow_threads(|| {
            // `Auto` is treated as the scalar batch kernel; each batch kernel maps to
            // its per-row counterpart.
            let simd = match kern {
                Kernel::Auto | Kernel::ScalarBatch => Kernel::Scalar,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::Avx512Batch => Kernel::Avx512,
                _ => unreachable!(),
            };
            edcf_batch_inner_into(slice_in, &sweep, simd, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let periods: Vec<u64> = evaluated
        .iter()
        .map(|p| p.period.unwrap() as u64)
        .collect();

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item("periods", periods.into_pyarray(py))?;

    Ok(dict)
}
1593
// Python binding `edcf_cuda_batch_dev(data_f32, period_range, device_id=0)`.
//
// Fails with `ValueError("CUDA not available")` when `cuda_available()` is false.
// Otherwise creates a `CudaEdcf` for `device_id`, runs the batch sweep inside
// `Python::allow_threads`, and wraps the resulting device array for Python. CUDA-side
// errors are converted to `ValueError`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "edcf_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, device_id=0))]
pub fn edcf_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: numpy::PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    use crate::cuda::cuda_available;
    use crate::cuda::moving_averages::CudaEdcf;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let prices = data_f32.as_slice()?;
    let sweep = EdcfBatchRange {
        period: period_range,
    };

    let launched = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaEdcf::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let dev_id = cuda.device_id();
        let out = cuda
            .edcf_batch_dev(prices, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((out, dev_id))
    });
    let (inner, dev_id) = launched?;

    make_device_array_py(dev_id as usize, inner)
}
1626
// Python binding `edcf_cuda_many_series_one_param_dev(data_tm_f32, period, device_id=0)`.
//
// Takes a 2-D f32 array, reads its two dimensions as (rows, cols), and passes the flat
// data to `edcf_many_series_one_param_time_major_dev` as `(flat, cols, rows, params)`.
// Fails with `ValueError("CUDA not available")` when `cuda_available()` is false;
// CUDA-side errors are converted to `ValueError`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "edcf_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, device_id=0))]
pub fn edcf_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    use crate::cuda::cuda_available;
    use crate::cuda::moving_averages::CudaEdcf;
    use numpy::PyUntypedArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let flat_in: &[f32] = data_tm_f32.as_slice()?;
    let dims = data_tm_f32.shape();
    let (rows, cols) = (dims[0], dims[1]);
    let params = EdcfParams {
        period: Some(period),
    };

    let launched = py.allow_threads(|| -> PyResult<_> {
        let mut cuda =
            CudaEdcf::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let dev_id = cuda.device_id();
        let out = cuda
            .edcf_many_series_one_param_time_major_dev(flat_in, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((out, dev_id))
    });
    let (inner, dev_id) = launched?;

    make_device_array_py(dev_id as usize, inner)
}
1663
// WASM binding: computes EDCF for `data` with the given `period` using `Kernel::Auto`
// and returns a freshly allocated vector of the same length. Errors are returned as a
// JS string holding the error's display text.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn edcf_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = EdcfInput::from_slice(
        data,
        EdcfParams {
            period: Some(period),
        },
    );

    let mut output = vec![0.0; data.len()];
    let status = edcf_into_slice(&mut output, &input, Kernel::Auto);
    match status {
        Ok(_) => Ok(output),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
1679
// WASM binding: runs the EDCF batch sweep described by (start, end, step) with the
// scalar kernel and returns the flat `values` buffer of the batch output. Errors are
// returned as a JS string holding the error's display text.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn edcf_batch_js(
    data: &[f64],
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let sweep = EdcfBatchRange {
        period: (period_start, period_end, period_step),
    };

    match edcf_batch_inner(data, &sweep, Kernel::Scalar, false) {
        Ok(batch) => Ok(batch.values),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
1696
// WASM binding: returns the period of every combination that `expand_grid` produces
// for the (start, end, step) sweep, as f64 values in grid order. This function never
// returns `Err` itself.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn edcf_batch_metadata_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let sweep = EdcfBatchRange {
        period: (period_start, period_end, period_step),
    };

    let grid = expand_grid(&sweep);
    let mut periods: Vec<f64> = Vec::with_capacity(grid.len());
    periods.extend(grid.iter().map(|combo| combo.period.unwrap() as f64));

    Ok(periods)
}
1716
// WASM binding: reserves room for `len` f64 values and hands the raw pointer to the
// caller. The vector is deliberately never dropped here; the buffer is expected to be
// released later through `edcf_free` with the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn edcf_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` keeps the allocation alive after this function returns.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1725
// WASM binding: releases a buffer previously obtained from `edcf_alloc(len)`.
//
// `ptr` must be null (then this is a no-op) or a pointer returned by `edcf_alloc`
// called with the same `len`, and it must not be used or freed again afterwards.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn edcf_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: per the contract above, `ptr` owns an allocation with capacity `len`.
    // The vector is rebuilt with length 0 because `edcf_alloc` never initialises the
    // buffer and `Vec::from_raw_parts` requires the first `length` elements to be
    // initialised. Only the capacity matters for deallocation, and `f64` has no
    // destructor to run.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
1735
// WASM binding: computes EDCF for the `len` values at `in_ptr` and writes `len`
// results to `out_ptr`, using `Kernel::Auto`.
//
// Returns an error string when either pointer is null, when `period` is 0 or larger
// than `len`, or when the computation itself fails. Both pointers must be valid for
// `len` f64 values. The input and output ranges may overlap, fully or partially:
// in that case the result is computed into a temporary buffer and copied out
// afterwards, so the input is never overwritten while it is still being read.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn edcf_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to edcf_into"));
    }
    if period == 0 || period > len {
        return Err(JsValue::from_str("Invalid period"));
    }

    // Detect any intersection of the two byte ranges, not just identical start
    // addresses: a shared input slice and a mutable output slice over intersecting
    // memory must never be alive at the same time.
    let byte_len = len.saturating_mul(std::mem::size_of::<f64>());
    let in_addr = in_ptr as usize;
    let out_addr = out_ptr as usize;
    let overlaps = in_addr < out_addr.saturating_add(byte_len)
        && out_addr < in_addr.saturating_add(byte_len);

    // SAFETY: the caller guarantees both pointers are valid for `len` f64 values.
    // When the ranges overlap, the mutable output slice is only created after the
    // last read through the input slice.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let params = EdcfParams {
            period: Some(period),
        };
        let input = EdcfInput::from_slice(data, params);

        if overlaps {
            let mut temp = vec![0.0; len];
            edcf_into_slice(&mut temp, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;

            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            out.copy_from_slice(&temp);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            edcf_into_slice(out, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(())
}
1776
/// Configuration object accepted by the JS `edcf_batch` entry point.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct EdcfBatchConfig {
    /// Period sweep, forwarded unchanged as `EdcfBatchRange::period`.
    pub period_range: (usize, usize, usize),
}

/// Result object returned by the JS `edcf_batch` entry point; its fields are copied
/// one-to-one from the batch output.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct EdcfBatchJsOutput {
    pub values: Vec<f64>,
    pub combos: Vec<EdcfParams>,
    pub rows: usize,
    pub cols: usize,
}

// WASM binding exported to JS as `edcf_batch`: deserializes `config` into an
// `EdcfBatchConfig`, runs the batch sweep with `Kernel::Auto`, and serializes an
// `EdcfBatchJsOutput` back to JS. A malformed config yields "Invalid config: …";
// other failures yield the error's display text.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = edcf_batch)]
pub fn edcf_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let parsed: EdcfBatchConfig = match serde_wasm_bindgen::from_value(config) {
        Ok(cfg) => cfg,
        Err(e) => return Err(JsValue::from_str(&format!("Invalid config: {}", e))),
    };

    let sweep = EdcfBatchRange {
        period: parsed.period_range,
    };

    let batch = edcf_batch_inner(data, &sweep, Kernel::Auto, false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let js_output = EdcfBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };

    match serde_wasm_bindgen::to_value(&js_output) {
        Ok(value) => Ok(value),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1814
// WASM binding: runs the (start, end, step) batch sweep over the `len` values at
// `in_ptr` and writes the flat `rows * len` result matrix to `out_ptr`, where `rows`
// is the number of combinations from `expand_grid`. Returns `rows` on success.
//
// Errors are returned as JS strings: null pointers, `rows * len` overflow, or a
// failure from the batch computation. `in_ptr` must be valid for `len` values and
// `out_ptr` for `rows * len` values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn edcf_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to edcf_batch_into"));
    }

    let sweep = EdcfBatchRange {
        period: (period_start, period_end, period_step),
    };

    let grid = expand_grid(&sweep);
    let rows = grid.len();
    let cols = len;

    let total = match rows.checked_mul(cols) {
        Some(cells) => cells,
        None => return Err(JsValue::from_str("edcf_batch_into: rows*cols overflow")),
    };

    // SAFETY: the caller guarantees `in_ptr` is valid for `len` values and `out_ptr`
    // for `rows * len` values.
    let (data, out) = unsafe {
        (
            std::slice::from_raw_parts(in_ptr, len),
            std::slice::from_raw_parts_mut(out_ptr, total),
        )
    };

    // An empty input has no non-NaN position, so it skips prefix initialisation too.
    if rows > 0 {
        if let Some(first) = data.iter().position(|x| !x.is_nan()) {
            // Per-row warm-up length: first valid index + 2 * period, capped at `cols`.
            let warm: Vec<usize> = grid
                .iter()
                .map(|c| (first + 2 * c.period.unwrap_or(15)).min(cols))
                .collect();

            // SAFETY: reinterprets the output slice as `MaybeUninit<f64>` over exactly
            // the same memory and length.
            let buf_mu: &mut [MaybeUninit<f64>] = unsafe {
                core::slice::from_raw_parts_mut(
                    out.as_mut_ptr() as *mut MaybeUninit<f64>,
                    out.len(),
                )
            };
            init_matrix_prefixes(buf_mu, cols, &warm);
        }
    }

    edcf_batch_inner_into(data, &sweep, Kernel::Auto, false, out)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    Ok(rows)
}