1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{alloc_with_nan_prefix, init_matrix_prefixes, make_uninit_matrix};
4#[cfg(feature = "python")]
5use crate::utilities::kernel_validation::validate_kernel;
6#[cfg(feature = "python")]
7use numpy::{IntoPyArray, PyArray1};
8#[cfg(feature = "python")]
9use pyo3::exceptions::PyValueError;
10#[cfg(feature = "python")]
11use pyo3::prelude::*;
12#[cfg(feature = "python")]
13use pyo3::types::PyDict;
14#[cfg(not(target_arch = "wasm32"))]
15use rayon::prelude::*;
16#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
17use serde::{Deserialize, Serialize};
18use std::convert::AsRef;
19use std::mem::MaybeUninit;
20use thiserror::Error;
21#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
22use wasm_bindgen::prelude::*;
23
24#[cfg(all(feature = "python", feature = "cuda"))]
25use crate::cuda::moving_averages::DeviceArrayF32;
26#[cfg(all(feature = "python", feature = "cuda"))]
27use crate::utilities::dlpack_cuda::{make_device_array_py, DeviceArrayF32Py};
28
/// Where the EDCF filter reads its price series from.
#[derive(Debug, Clone)]
pub enum EdcfData<'a> {
    /// A named column of a [`Candles`] set, resolved through `source_type`
    /// (e.g. `"close"` or `"hl2"`).
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A raw, caller-owned price slice.
    Slice(&'a [f64]),
}
37
38impl<'a> AsRef<[f64]> for EdcfInput<'a> {
39 #[inline(always)]
40 fn as_ref(&self) -> &[f64] {
41 match &self.data {
42 EdcfData::Slice(slice) => slice,
43 EdcfData::Candles { candles, source } => source_type(candles, source),
44 }
45 }
46}
47
/// Tunable parameters of the EDCF filter.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct EdcfParams {
    /// Window length; `None` resolves to 15 wherever the period is read.
    pub period: Option<usize>,
}
57
58impl Default for EdcfParams {
59 fn default() -> Self {
60 Self { period: Some(15) }
61 }
62}
63
/// Result of a single-series EDCF run.
#[derive(Debug, Clone)]
pub struct EdcfOutput {
    /// One value per input sample; warm-up slots hold NaN.
    pub values: Vec<f64>,
}
68
/// Input bundle for the EDCF filter: the data source plus its parameters.
#[derive(Debug, Clone)]
pub struct EdcfInput<'a> {
    /// Series to filter (candle column or raw slice).
    pub data: EdcfData<'a>,

    /// Filter parameters.
    pub params: EdcfParams,
}
75
76impl<'a> EdcfInput<'a> {
77 #[inline]
78 pub fn from_candles(c: &'a Candles, s: &'a str, p: EdcfParams) -> Self {
79 Self {
80 data: EdcfData::Candles {
81 candles: c,
82 source: s,
83 },
84 params: p,
85 }
86 }
87 #[inline]
88 pub fn from_slice(sl: &'a [f64], p: EdcfParams) -> Self {
89 Self {
90 data: EdcfData::Slice(sl),
91 params: p,
92 }
93 }
94 #[inline]
95 pub fn with_default_candles(c: &'a Candles) -> Self {
96 Self::from_candles(c, "close", EdcfParams::default())
97 }
98 #[inline]
99 pub fn get_period(&self) -> usize {
100 self.params.period.unwrap_or(15)
101 }
102}
103
/// Errors reported by the EDCF single-series, streaming and batch APIs.
#[derive(Debug, Error)]
pub enum EdcfError {
    /// The input series has length zero.
    #[error("edcf: No data provided to EDCF filter.")]
    NoData,
    /// Empty input data. NOTE(review): no code in this chunk constructs this
    /// variant (`NoData` is used instead) — confirm whether it is still needed.
    #[error("edcf: Empty input data.")]
    EmptyInputData,
    /// Every input value is NaN, so there is no first valid sample.
    #[error("edcf: All values are NaN.")]
    AllValuesNaN,
    /// The period is zero, or (single-series API) larger than the input length.
    /// The streaming constructor reports `data_len: 0`.
    #[error("edcf: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },
    /// Fewer than `2 * period` samples remain after the leading NaNs.
    #[error("edcf: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// A caller-supplied output buffer has the wrong length.
    #[error("edcf: Output buffer length mismatch: expected = {expected}, got = {got}")]
    OutputLenMismatch { expected: usize, got: usize },
    /// An unsupported kernel was requested. NOTE(review): not constructed in
    /// the code visible in this chunk.
    #[error("edcf: Invalid kernel specified")]
    InvalidKernel,
    /// The batch period sweep expanded to no combinations.
    #[error("edcf: Invalid range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },
    /// A non-batch kernel was handed to the batch API.
    #[error("edcf: Invalid kernel for batch API: {0:?}")]
    InvalidKernelForBatch(Kernel),
    /// `rows * cols` overflowed `usize` while sizing the batch matrix.
    #[error("edcf: size overflow during allocation ({op})")]
    SizeOverflow { op: &'static str },
}
131
/// Fluent builder for single-series and streaming EDCF computations.
#[derive(Copy, Clone, Debug)]
pub struct EdcfBuilder {
    /// Window length; `None` resolves to 15 downstream.
    period: Option<usize>,
    /// Kernel passed to `edcf_with_kernel`.
    kernel: Kernel,
}
138
139impl Default for EdcfBuilder {
140 fn default() -> Self {
141 Self {
142 period: None,
143 kernel: Kernel::Auto,
144 }
145 }
146}
147
148impl EdcfBuilder {
149 #[inline(always)]
150 pub fn new() -> Self {
151 Self::default()
152 }
153 #[inline(always)]
154 pub fn period(mut self, n: usize) -> Self {
155 self.period = Some(n);
156 self
157 }
158 #[inline(always)]
159 pub fn kernel(mut self, k: Kernel) -> Self {
160 self.kernel = k;
161 self
162 }
163 #[inline(always)]
164 pub fn apply(self, c: &Candles) -> Result<EdcfOutput, EdcfError> {
165 let p = EdcfParams {
166 period: self.period,
167 };
168 let i = EdcfInput::from_candles(c, "close", p);
169 edcf_with_kernel(&i, self.kernel)
170 }
171 #[inline(always)]
172 pub fn apply_slice(self, d: &[f64]) -> Result<EdcfOutput, EdcfError> {
173 let p = EdcfParams {
174 period: self.period,
175 };
176 let i = EdcfInput::from_slice(d, p);
177 edcf_with_kernel(&i, self.kernel)
178 }
179 #[inline(always)]
180 pub fn into_stream(self) -> Result<EdcfStream, EdcfError> {
181 let p = EdcfParams {
182 period: self.period,
183 };
184 EdcfStream::try_new(p)
185 }
186}
187
188#[inline]
189pub fn edcf(input: &EdcfInput) -> Result<EdcfOutput, EdcfError> {
190 edcf_with_kernel(input, Kernel::Auto)
191}
192
193#[inline(always)]
194fn edcf_prepare<'a>(
195 input: &'a EdcfInput,
196 kernel: Kernel,
197) -> Result<(&'a [f64], usize, usize, usize, Kernel), EdcfError> {
198 let data: &[f64] = input.as_ref();
199 let len = data.len();
200 if len == 0 {
201 return Err(EdcfError::NoData);
202 }
203 let first = data
204 .iter()
205 .position(|x| !x.is_nan())
206 .ok_or(EdcfError::AllValuesNaN)?;
207 let period = input.get_period();
208 if period == 0 || period > len {
209 return Err(EdcfError::InvalidPeriod {
210 period,
211 data_len: len,
212 });
213 }
214 let needed = 2 * period;
215 if (len - first) < needed {
216 return Err(EdcfError::NotEnoughValidData {
217 needed,
218 valid: len - first,
219 });
220 }
221
222 let warm = first + 2 * period;
223
224 let chosen = match kernel {
225 Kernel::Auto => Kernel::Scalar,
226 other => other,
227 };
228
229 Ok((data, period, first, warm, chosen))
230}
231
232#[inline(always)]
233fn edcf_compute_into(data: &[f64], period: usize, first: usize, chosen: Kernel, out: &mut [f64]) {
234 #[cfg(target_arch = "wasm32")]
235 {
236 if matches!(chosen, Kernel::Scalar | Kernel::ScalarBatch) {
237 edcf_scalar_wasm(data, period, first, out);
238 return;
239 }
240 }
241
242 unsafe {
243 match chosen {
244 Kernel::Scalar | Kernel::ScalarBatch => edcf_scalar(data, period, first, out),
245 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
246 Kernel::Avx2 | Kernel::Avx2Batch => edcf_avx2(data, period, first, out),
247 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
248 Kernel::Avx512 | Kernel::Avx512Batch => edcf_avx512(data, period, first, out),
249 _ => unreachable!(),
250 }
251 }
252}
253
254pub fn edcf_with_kernel(input: &EdcfInput, kernel: Kernel) -> Result<EdcfOutput, EdcfError> {
255 let (data, period, first, warm, chosen) = edcf_prepare(input, kernel)?;
256 let len = data.len();
257 let mut out = alloc_with_nan_prefix(len, warm);
258 edcf_compute_into(data, period, first, chosen, &mut out);
259 Ok(EdcfOutput { values: out })
260}
261
262#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
263#[inline]
264pub fn edcf_into(input: &EdcfInput, out: &mut [f64]) -> Result<(), EdcfError> {
265 let (data, period, first, warm, chosen) = edcf_prepare(input, Kernel::Auto)?;
266
267 if out.len() != data.len() {
268 return Err(EdcfError::OutputLenMismatch {
269 expected: data.len(),
270 got: out.len(),
271 });
272 }
273
274 let in_ptr = data.as_ptr();
275 let out_ptr = out.as_ptr();
276 if core::ptr::eq(in_ptr, out_ptr) {
277 let mut temp = alloc_with_nan_prefix(out.len(), warm);
278 edcf_compute_into(data, period, first, chosen, &mut temp);
279 out.copy_from_slice(&temp);
280 return Ok(());
281 }
282
283 let warm = warm.min(out.len());
284 for v in &mut out[..warm] {
285 *v = f64::from_bits(0x7ff8_0000_0000_0000);
286 }
287
288 edcf_compute_into(data, period, first, chosen, out);
289
290 Ok(())
291}
292
293#[inline]
294pub fn edcf_into_slice(dst: &mut [f64], input: &EdcfInput, kern: Kernel) -> Result<(), EdcfError> {
295 let (data, period, first, warm, chosen) = edcf_prepare(input, kern)?;
296
297 if dst.len() != data.len() {
298 return Err(EdcfError::OutputLenMismatch {
299 expected: data.len(),
300 got: dst.len(),
301 });
302 }
303
304 edcf_compute_into(data, period, first, chosen, dst);
305
306 for v in &mut dst[..warm] {
307 *v = f64::NAN;
308 }
309
310 Ok(())
311}
312
313#[inline(always)]
314pub fn edcf_scalar(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
315 let mut buf = vec![0.0; period];
316 let mut wbuf = vec![0.0; period];
317 edcf_scalar_o1_into(data, period, first_valid, out, &mut buf, &mut wbuf);
318}
319
/// O(1)-per-sample EDCF kernel operating on caller-provided ring buffers.
///
/// For each sample `x_i` at or after `first_valid` it computes the distance
/// weight `w_i = sum_{k=1}^{period-1} (x_i - x_{i-k})^2`. The weight is zero
/// until `period` earlier samples have been seen. The kernel writes the
/// weighted average `sum(w * x) / sum(w)` over the last `period` samples into
/// `out[idx]` for every `idx >= first_valid + 2 * period`, and NaN when the
/// weight sum is exactly zero. Earlier slots of `out` are not touched.
///
/// `buf` and `wbuf` are scratch rings for the last `period` inputs and
/// weights. Only their first `period` slots are used, and they are zeroed on
/// entry. `period == 0` is a no-op.
///
/// # Panics
/// Panics if either scratch buffer is shorter than `period`. May panic if
/// `out` is shorter than `data`.
#[inline(always)]
fn edcf_scalar_o1_into(
    data: &[f64],
    period: usize,
    first_valid: usize,
    out: &mut [f64],
    buf: &mut [f64],
    wbuf: &mut [f64],
) {
    // Without this guard `period - 1` underflows. In release builds the
    // unchecked ring reads below would then touch an empty buffer (UB).
    if period == 0 {
        return;
    }
    // The unchecked accesses rely on both rings holding `period` slots.
    // Enforce that in every build profile, not only under debug assertions.
    assert!(
        buf.len() >= period && wbuf.len() >= period,
        "edcf: scratch buffers must hold at least `period` slots"
    );
    let buf = &mut buf[..period];
    let wbuf = &mut wbuf[..period];

    buf.fill(0.0);
    wbuf.fill(0.0);

    let len = data.len();
    let warm = first_valid + 2 * period;

    let mut head = 0usize;
    let mut count = 0usize;

    // Sum and sum of squares of the previous `period - 1` inputs.
    let mut sum_prev = 0.0;
    let mut sum_prev_sq = 0.0;
    // Running sum(w) and sum(w * x) over the last `period` samples.
    let mut den = 0.0;
    let mut num = 0.0;
    let p_minus1_f = (period - 1) as f64;

    for idx in first_valid..len {
        let value = data[idx];

        // SAFETY: `head` starts at 0, is reset to 0 when it reaches `period`,
        // and both rings were resliced to exactly `period` slots above.
        let old_x = unsafe { *buf.get_unchecked(head) };
        let old_w = unsafe { *wbuf.get_unchecked(head) };
        let had_full_window = count >= period;

        // sum_k (v - x_k)^2 = (p-1) v^2 + sum x_k^2 - 2 v sum x_k
        let w_new = if count >= period {
            let x2 = value * value;
            p_minus1_f.mul_add(x2, sum_prev_sq) - (2.0 * value * sum_prev)
        } else {
            0.0
        };

        // Slide the weighted window: retire the sample leaving it, add the new one.
        if had_full_window {
            den -= old_w;
            num -= old_w * old_x;
        }
        den += w_new;
        num = w_new.mul_add(value, num);

        // SAFETY: same bound on `head` as above.
        unsafe {
            *buf.get_unchecked_mut(head) = value;
            *wbuf.get_unchecked_mut(head) = w_new;
        }
        head += 1;
        if head == period {
            head = 0;
        }

        // Keep the running sums over the latest `period - 1` inputs. Once the
        // ring is full, the oldest input sits at the new `head`.
        sum_prev += value;
        sum_prev_sq = value.mul_add(value, sum_prev_sq);
        if count >= (period - 1) {
            // SAFETY: `head < period` after the wrap above.
            let drop_x = unsafe { *buf.get_unchecked(head) };
            sum_prev -= drop_x;
            sum_prev_sq -= drop_x * drop_x;
        }

        count += 1;

        if idx >= warm {
            out[idx] = if den != 0.0 { num / den } else { f64::NAN };
        }
    }
}
396
397#[inline(always)]
398fn edcf_scalar_into_with_scratch(
399 data: &[f64],
400 period: usize,
401 first_valid: usize,
402 out: &mut [f64],
403 scratch: &mut Vec<f64>,
404) {
405 let need = period * 2;
406 if scratch.len() < need {
407 scratch.resize(need, 0.0);
408 }
409 let (buf, wbuf) = scratch.split_at_mut(period);
410 edcf_scalar_o1_into(data, period, first_valid, out, buf, wbuf);
411}
412
/// wasm32 scalar entry point: a thin forwarder to [`edcf_scalar`].
#[cfg(target_arch = "wasm32")]
#[inline]
fn edcf_scalar_wasm(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    edcf_scalar(data, period, first_valid, out)
}
418
/// AVX2/FMA entry point for EDCF. It currently just delegates to
/// [`edcf_scalar`]; there is no hand-written SIMD path yet.
///
/// # Safety
/// Compiled with `#[target_feature(enable = "avx2,fma")]`. Calling it on a
/// CPU without AVX2 and FMA support is undefined behaviour.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
pub unsafe fn edcf_avx2(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    edcf_scalar(data, period, first_valid, out);
}
/// AVX-512 entry point for EDCF. It currently just delegates to
/// [`edcf_scalar`]; there is no hand-written SIMD path yet.
///
/// # Safety
/// Compiled with `#[target_feature(enable = "avx512f,avx512dq,fma")]`.
/// Calling it on a CPU without those features is undefined behaviour.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,avx512dq,fma")]
pub unsafe fn edcf_avx512(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    edcf_scalar(data, period, first_valid, out);
}
429
/// Incremental (one sample at a time) EDCF filter.
#[derive(Debug, Clone)]
pub struct EdcfStream {
    /// Window length.
    period: usize,

    /// Ring of the last `period` accepted inputs.
    buffer: Vec<f64>,

    /// Ring of the distance weight computed for each input in `buffer`.
    dist: Vec<f64>,

    /// Next ring slot to overwrite; this is also the oldest entry once the ring is full.
    head: usize,

    /// Number of finite samples accepted so far.
    count: usize,

    /// Sum and sum of squares of the most recent `period - 1` inputs.
    sum_prev: f64,
    sum_prev_sq: f64,

    /// Running sum(w) and sum(w * x) over the last `period` samples.
    den: f64,
    num: f64,

    /// `period - 1` as f64, cached for the weight formula.
    p_minus1_f: f64,
}
451
impl EdcfStream {
    /// Creates an empty stream. A missing `period` defaults to 15.
    ///
    /// # Errors
    /// Returns [`EdcfError::InvalidPeriod`] when the period is zero. It is
    /// reported with `data_len: 0` because a stream has no fixed length.
    #[inline]
    pub fn try_new(params: EdcfParams) -> Result<Self, EdcfError> {
        let period = params.period.unwrap_or(15);
        if period == 0 {
            return Err(EdcfError::InvalidPeriod {
                period,
                data_len: 0,
            });
        }

        // Ring of recent inputs. `update` only uses a slot's value after the
        // slot has been overwritten, so the initial fill from
        // `alloc_with_nan_prefix` (presumably NaN) never reaches a result.
        let buffer = alloc_with_nan_prefix(period, period);
        // Ring of per-input distance weights.
        let dist = vec![0.0; period];

        Ok(Self {
            period,
            buffer,
            dist,
            head: 0,
            count: 0,
            sum_prev: 0.0,
            sum_prev_sq: 0.0,
            den: 0.0,
            num: 0.0,
            p_minus1_f: (period - 1) as f64,
        })
    }

    /// Advances the ring cursor, wrapping at `period`.
    #[inline(always)]
    fn bump_head(&mut self) {
        let n = self.head + 1;
        self.head = if n == self.period { 0 } else { n };
    }

    /// Feeds one sample and returns the filter value once warmed up.
    ///
    /// Non-finite inputs (NaN or ±inf) are ignored: they leave the state
    /// untouched and yield `None`. Otherwise the method returns `None` for the
    /// first `2 * period - 1` accepted samples, and also whenever the weight
    /// sum `den` is exactly zero. In all other cases it returns
    /// `Some(num / den)`.
    #[inline(always)]
    pub fn update(&mut self, value: f64) -> Option<f64> {
        if !value.is_finite() {
            return None;
        }

        let p = self.period;

        // Slot about to be overwritten: the input and weight from `p` samples ago.
        let old_x = self.buffer[self.head];
        let old_w = self.dist[self.head];
        let had_full_window = self.count >= p;

        // Distance weight sum_k (value - x_k)^2 over the previous p-1 inputs,
        // expanded as (p-1)*value^2 + sum(x_k^2) - 2*value*sum(x_k). Samples
        // seen before `p` earlier inputs exist get weight 0.
        let w_new = if self.count >= p {
            let x2 = value * value;
            self.p_minus1_f.mul_add(x2, self.sum_prev_sq) - (2.0 * value * self.sum_prev)
        } else {
            0.0
        };

        // Slide the weighted window: retire the leaving sample, add the new one.
        if had_full_window {
            self.den -= old_w;
            self.num -= old_w * old_x;
        }

        self.den += w_new;
        self.num = w_new.mul_add(value, self.num);

        self.buffer[self.head] = value;
        self.dist[self.head] = w_new;
        self.bump_head();

        // Keep `sum_prev`/`sum_prev_sq` over the latest p-1 inputs. Once the
        // ring is full, the oldest input now sits at `head`.
        self.sum_prev += value;
        self.sum_prev_sq = value.mul_add(value, self.sum_prev_sq);
        if self.count >= (p - 1) {
            let drop_x = self.buffer[self.head];
            self.sum_prev -= drop_x;
            self.sum_prev_sq -= drop_x * drop_x;
        }

        self.count += 1;
        if self.count < 2 * p {
            return None;
        }
        // A zero weight sum (e.g. no variation in the window) has no defined average.
        if self.den != 0.0 {
            Some(fast_div(self.num, self.den))
        } else {
            None
        }
    }
}
536
/// Plain IEEE-754 division `num / den`, used by the streaming filter.
///
/// No reciprocal approximation is involved, so results match the batch
/// kernel's `num / den` exactly.
#[inline(always)]
fn fast_div(num: f64, den: f64) -> f64 {
    core::ops::Div::div(num, den)
}
541
/// Period sweep for batch evaluation, expressed as `(start, end, step)`.
///
/// See `expand_grid` for how reversed bounds and `step == 0` are handled.
#[derive(Clone, Debug)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct EdcfBatchRange {
    pub period: (usize, usize, usize),
}
550
551impl Default for EdcfBatchRange {
552 fn default() -> Self {
553 Self {
554 period: (15, 264, 1),
555 }
556 }
557}
558
/// Fluent builder for batch (many-period) EDCF runs.
#[derive(Clone, Debug, Default)]
pub struct EdcfBatchBuilder {
    /// Period sweep to evaluate.
    range: EdcfBatchRange,
    /// Kernel passed to `edcf_batch_with_kernel`.
    kernel: Kernel,
}
565
566impl EdcfBatchBuilder {
567 pub fn new() -> Self {
568 Self::default()
569 }
570 pub fn kernel(mut self, k: Kernel) -> Self {
571 self.kernel = k;
572 self
573 }
574 #[inline]
575 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
576 self.range.period = (start, end, step);
577 self
578 }
579 #[inline]
580 pub fn period_static(mut self, p: usize) -> Self {
581 self.range.period = (p, p, 0);
582 self
583 }
584 pub fn apply_slice(self, data: &[f64]) -> Result<EdcfBatchOutput, EdcfError> {
585 edcf_batch_with_kernel(data, &self.range, self.kernel)
586 }
587 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<EdcfBatchOutput, EdcfError> {
588 EdcfBatchBuilder::new().kernel(k).apply_slice(data)
589 }
590 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<EdcfBatchOutput, EdcfError> {
591 let slice = source_type(c, src);
592 self.apply_slice(slice)
593 }
594 pub fn with_default_candles(c: &Candles) -> Result<EdcfBatchOutput, EdcfError> {
595 EdcfBatchBuilder::new()
596 .kernel(Kernel::Auto)
597 .apply_candles(c, "close")
598 }
599}
600
601pub fn edcf_batch_with_kernel(
602 data: &[f64],
603 sweep: &EdcfBatchRange,
604 k: Kernel,
605) -> Result<EdcfBatchOutput, EdcfError> {
606 if data.is_empty() {
607 return Err(EdcfError::NoData);
608 }
609 let kernel = match k {
610 Kernel::Auto => Kernel::ScalarBatch,
611 other if other.is_batch() => other,
612 other => return Err(EdcfError::InvalidKernelForBatch(other)),
613 };
614 let simd = match kernel {
615 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
616 Kernel::Avx512Batch => Kernel::Avx512,
617 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
618 Kernel::Avx2Batch => Kernel::Avx2,
619 Kernel::ScalarBatch => Kernel::Scalar,
620 _ => unreachable!(),
621 };
622 edcf_batch_par_slice(data, sweep, simd)
623}
624
/// Row-major result of a batch EDCF run.
#[derive(Clone, Debug)]
pub struct EdcfBatchOutput {
    /// `rows * cols` values; row `r` corresponds to `combos[r]`.
    pub values: Vec<f64>,

    /// Parameter set used for each row.
    pub combos: Vec<EdcfParams>,

    /// Number of parameter combinations (rows).
    pub rows: usize,

    /// Length of the input series (columns).
    pub cols: usize,
}
635impl EdcfBatchOutput {
636 pub fn row_for_params(&self, p: &EdcfParams) -> Option<usize> {
637 self.combos
638 .iter()
639 .position(|c| c.period.unwrap_or(15) == p.period.unwrap_or(15))
640 }
641 pub fn values_for(&self, p: &EdcfParams) -> Option<&[f64]> {
642 self.row_for_params(p).map(|row| {
643 let start = row * self.cols;
644 &self.values[start..start + self.cols]
645 })
646 }
647}
648
649#[inline(always)]
650fn expand_grid(r: &EdcfBatchRange) -> Vec<EdcfParams> {
651 let (mut start, mut end, step) = r.period;
652
653 if start > end {
654 core::mem::swap(&mut start, &mut end);
655 }
656 let periods: Vec<usize> = if step == 0 || start == end {
657 vec![start]
658 } else {
659 (start..=end).step_by(step).collect()
660 };
661
662 periods
663 .into_iter()
664 .map(|p| EdcfParams { period: Some(p) })
665 .collect()
666}
667
668#[inline(always)]
669pub fn edcf_batch_slice(
670 data: &[f64],
671 sweep: &EdcfBatchRange,
672 kern: Kernel,
673) -> Result<EdcfBatchOutput, EdcfError> {
674 edcf_batch_inner(data, sweep, kern, false)
675}
676
677#[inline(always)]
678pub fn edcf_batch_par_slice(
679 data: &[f64],
680 sweep: &EdcfBatchRange,
681 kern: Kernel,
682) -> Result<EdcfBatchOutput, EdcfError> {
683 edcf_batch_inner(data, sweep, kern, true)
684}
685
686#[inline(always)]
687fn edcf_batch_inner(
688 data: &[f64],
689 sweep: &EdcfBatchRange,
690 kern: Kernel,
691 parallel: bool,
692) -> Result<EdcfBatchOutput, EdcfError> {
693 if data.is_empty() {
694 return Err(EdcfError::NoData);
695 }
696
697 let combos = expand_grid(sweep);
698 if combos.is_empty() {
699 return Err(EdcfError::InvalidRange {
700 start: sweep.period.0,
701 end: sweep.period.1,
702 step: sweep.period.2,
703 });
704 }
705
706 let rows = combos.len();
707 let cols = data.len();
708
709 let _total = rows
710 .checked_mul(cols)
711 .ok_or(EdcfError::SizeOverflow { op: "rows*cols" })?;
712
713 let mut buf_mu = make_uninit_matrix(rows, cols);
714
715 let warm: Vec<usize> = combos
716 .iter()
717 .map(|c| data.iter().position(|x| !x.is_nan()).unwrap_or(0) + 2 * c.period.unwrap_or(15))
718 .collect();
719
720 init_matrix_prefixes(&mut buf_mu, cols, &warm);
721
722 let mut buf_guard = std::mem::ManuallyDrop::new(buf_mu);
723 let out: &mut [f64] = unsafe {
724 core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
725 };
726
727 let result_combos = edcf_batch_inner_into(data, sweep, kern, parallel, out)?;
728
729 let values = unsafe {
730 Vec::from_raw_parts(
731 buf_guard.as_mut_ptr() as *mut f64,
732 buf_guard.len(),
733 buf_guard.capacity(),
734 )
735 };
736
737 Ok(EdcfBatchOutput {
738 values,
739 combos: result_combos,
740 rows,
741 cols,
742 })
743}
744
745#[inline(always)]
746fn edcf_batch_inner_into(
747 data: &[f64],
748 sweep: &EdcfBatchRange,
749 kern: Kernel,
750 parallel: bool,
751 out: &mut [f64],
752) -> Result<Vec<EdcfParams>, EdcfError> {
753 if data.is_empty() {
754 return Err(EdcfError::NoData);
755 }
756 let combos = expand_grid(sweep);
757 if combos.is_empty() {
758 return Err(EdcfError::InvalidRange {
759 start: sweep.period.0,
760 end: sweep.period.1,
761 step: sweep.period.2,
762 });
763 }
764
765 let first = data
766 .iter()
767 .position(|x| !x.is_nan())
768 .ok_or(EdcfError::AllValuesNaN)?;
769 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
770 if data.len() - first < 2 * max_p {
771 return Err(EdcfError::NotEnoughValidData {
772 needed: 2 * max_p,
773 valid: data.len() - first,
774 });
775 }
776
777 let rows = combos.len();
778 let cols = data.len();
779
780 if parallel {
781 #[cfg(not(target_arch = "wasm32"))]
782 {
783 use rayon::prelude::*;
784
785 out.par_chunks_mut(cols).enumerate().for_each(|(row, dst)| {
786 let period = combos[row].period.unwrap();
787
788 let mut scratch = Vec::<f64>::new();
789 match kern {
790 Kernel::Scalar | Kernel::Avx2 | Kernel::Avx512 => {
791 edcf_scalar_into_with_scratch(data, period, first, dst, &mut scratch);
792 }
793 _ => unsafe { edcf_row_scalar(data, first, period, dst) },
794 }
795 });
796 }
797
798 #[cfg(target_arch = "wasm32")]
799 {
800 for (row, dst) in out.chunks_mut(cols).enumerate() {
801 let period = combos[row].period.unwrap();
802 unsafe { edcf_row_scalar(data, first, period, dst) }
803 }
804 }
805 } else {
806 #[cfg(not(target_arch = "wasm32"))]
807 {
808 let mut scratch = Vec::<f64>::new();
809 for (row, dst) in out.chunks_mut(cols).enumerate() {
810 let period = combos[row].period.unwrap();
811 match kern {
812 Kernel::Scalar | Kernel::Avx2 | Kernel::Avx512 => {
813 edcf_scalar_into_with_scratch(data, period, first, dst, &mut scratch)
814 }
815 _ => unsafe { edcf_row_scalar(data, first, period, dst) },
816 }
817 }
818 }
819 #[cfg(target_arch = "wasm32")]
820 {
821 for (row, dst) in out.chunks_mut(cols).enumerate() {
822 let period = combos[row].period.unwrap();
823 unsafe { edcf_row_scalar(data, first, period, dst) }
824 }
825 }
826 }
827
828 Ok(combos)
829}
830
/// Batch row kernel: computes one EDCF row with [`edcf_scalar`].
///
/// # Safety
/// Performs no unsafe operation itself. It is presumably `unsafe` only to
/// mirror the SIMD row kernels' signature.
#[inline(always)]
unsafe fn edcf_row_scalar(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    edcf_scalar(data, period, first, out)
}
/// Batch row kernel that forwards to `edcf_avx2`. NOTE(review): it is not
/// called by any code visible in this chunk.
///
/// # Safety
/// The CPU must support AVX2 and FMA, as required by `edcf_avx2`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn edcf_row_avx2(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    edcf_avx2(data, period, first, out);
}
/// Batch row kernel that forwards to `edcf_avx512`. NOTE(review): it is not
/// called by any code visible in this chunk.
///
/// # Safety
/// The CPU must support AVX-512F/DQ and FMA, as required by `edcf_avx512`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn edcf_row_avx512(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    edcf_avx512(data, period, first, out);
}
845
/// Registers the EDCF Python bindings on module `m`.
///
/// The bindings are `edcf_py`, `edcf_batch_py` and the `EdcfStreamPy` class.
/// With the `cuda` feature, the CUDA batch and many-series entry points are
/// also registered.
#[cfg(feature = "python")]
pub fn register_edcf_module(m: &Bound<'_, pyo3::types::PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(edcf_py, m)?)?;
    m.add_function(wrap_pyfunction!(edcf_batch_py, m)?)?;
    m.add_class::<EdcfStreamPy>()?;
    #[cfg(feature = "cuda")]
    {
        m.add_function(wrap_pyfunction!(edcf_cuda_batch_dev_py, m)?)?;
        m.add_function(wrap_pyfunction!(edcf_cuda_many_series_one_param_dev_py, m)?)?;
    }
    Ok(())
}
858
859#[cfg(test)]
860mod tests {
861 use super::*;
862 use crate::skip_if_unsupported;
863 use crate::utilities::data_loader::read_candles_from_csv;
864 use proptest::prelude::*;
865 use std::error::Error;
866
867 #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
868 #[test]
869 fn test_edcf_into_matches_api() -> Result<(), Box<dyn Error>> {
870 let mut data: Vec<f64> = Vec::new();
871 data.extend_from_slice(&[f64::NAN, f64::NAN, f64::NAN, f64::NAN, f64::NAN]);
872 for i in 0..250usize {
873 let x = (i as f64).sin() * 3.0 + (i as f64) * 0.05 + ((i % 7) as f64) * 0.1;
874 data.push(x);
875 }
876
877 let input = EdcfInput::from_slice(&data, EdcfParams::default());
878
879 let baseline = edcf(&input)?.values;
880
881 let mut out = vec![0.0; data.len()];
882 edcf_into(&input, &mut out)?;
883
884 assert_eq!(baseline.len(), out.len());
885
886 fn eq_or_both_nan(a: f64, b: f64) -> bool {
887 (a.is_nan() && b.is_nan()) || (a - b).abs() <= 1e-12
888 }
889
890 for i in 0..out.len() {
891 assert!(
892 eq_or_both_nan(baseline[i], out[i]),
893 "mismatch at {}: expected {:?}, got {:?}",
894 i,
895 baseline[i],
896 out[i]
897 );
898 }
899
900 Ok(())
901 }
902
903 fn check_edcf_partial_params(
904 test_name: &str,
905 kernel: Kernel,
906 ) -> Result<(), Box<dyn std::error::Error>> {
907 skip_if_unsupported!(kernel, test_name);
908 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
909 let candles = read_candles_from_csv(file_path)?;
910 let input = EdcfInput::from_candles(&candles, "close", EdcfParams { period: None });
911 let result = edcf_with_kernel(&input, kernel)?;
912 assert_eq!(result.values.len(), candles.close.len());
913 Ok(())
914 }
915
916 fn check_edcf_accuracy_last_five(
917 test_name: &str,
918 kernel: Kernel,
919 ) -> Result<(), Box<dyn std::error::Error>> {
920 skip_if_unsupported!(kernel, test_name);
921 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
922 let candles = read_candles_from_csv(file_path)?;
923 let input = EdcfInput::from_candles(&candles, "hl2", EdcfParams { period: Some(15) });
924 let result = edcf_with_kernel(&input, kernel)?;
925 let expected = [
926 59593.332275678375,
927 59731.70263288801,
928 59766.41512339413,
929 59655.66162110993,
930 59332.492883847,
931 ];
932 let len = result.values.len();
933 let start = len - expected.len();
934 for (i, &v) in result.values[start..].iter().enumerate() {
935 assert!(
936 (v - expected[i]).abs() < 1e-8,
937 "[{}] EDCF mismatch at {}: got {}, expected {}",
938 test_name,
939 start + i,
940 v,
941 expected[i]
942 );
943 }
944 Ok(())
945 }
946
947 fn check_edcf_with_default_candles(
948 test_name: &str,
949 kernel: Kernel,
950 ) -> Result<(), Box<dyn std::error::Error>> {
951 skip_if_unsupported!(kernel, test_name);
952 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
953 let candles = read_candles_from_csv(file_path)?;
954 let input = EdcfInput::with_default_candles(&candles);
955 match input.data {
956 EdcfData::Candles { source, .. } => assert_eq!(source, "close"),
957 _ => panic!("Expected EdcfData::Candles"),
958 }
959 let period = input.get_period();
960 assert_eq!(period, 15);
961 Ok(())
962 }
963
964 fn check_edcf_with_zero_period(
965 test_name: &str,
966 kernel: Kernel,
967 ) -> Result<(), Box<dyn std::error::Error>> {
968 skip_if_unsupported!(kernel, test_name);
969 let data = [10.0, 20.0, 30.0];
970 let input = EdcfInput::from_slice(&data, EdcfParams { period: Some(0) });
971 let result = edcf_with_kernel(&input, kernel);
972 assert!(result.is_err());
973 Ok(())
974 }
975
976 fn check_edcf_with_no_data(
977 test_name: &str,
978 kernel: Kernel,
979 ) -> Result<(), Box<dyn std::error::Error>> {
980 skip_if_unsupported!(kernel, test_name);
981 let data: [f64; 0] = [];
982 let input = EdcfInput::from_slice(&data, EdcfParams { period: Some(15) });
983 let result = edcf_with_kernel(&input, kernel);
984 assert!(result.is_err());
985 Ok(())
986 }
987
988 fn check_edcf_with_period_exceeding_data_length(
989 test_name: &str,
990 kernel: Kernel,
991 ) -> Result<(), Box<dyn std::error::Error>> {
992 skip_if_unsupported!(kernel, test_name);
993 let data = [10.0, 20.0, 30.0];
994 let input = EdcfInput::from_slice(&data, EdcfParams { period: Some(10) });
995 let result = edcf_with_kernel(&input, kernel);
996 assert!(result.is_err());
997 Ok(())
998 }
999
1000 fn check_edcf_very_small_data_set(
1001 test_name: &str,
1002 kernel: Kernel,
1003 ) -> Result<(), Box<dyn std::error::Error>> {
1004 skip_if_unsupported!(kernel, test_name);
1005 let data = [42.0];
1006 let input = EdcfInput::from_slice(&data, EdcfParams { period: Some(15) });
1007 let result = edcf_with_kernel(&input, kernel);
1008 assert!(result.is_err());
1009 Ok(())
1010 }
1011
1012 fn check_edcf_with_slice_data_reinput(
1013 test_name: &str,
1014 kernel: Kernel,
1015 ) -> Result<(), Box<dyn std::error::Error>> {
1016 skip_if_unsupported!(kernel, test_name);
1017 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1018 let candles = read_candles_from_csv(file_path)?;
1019 let first_input =
1020 EdcfInput::from_candles(&candles, "close", EdcfParams { period: Some(15) });
1021 let first_result = edcf_with_kernel(&first_input, kernel)?;
1022 let first_values = first_result.values;
1023 let second_input = EdcfInput::from_slice(&first_values, EdcfParams { period: Some(5) });
1024 let second_result = edcf_with_kernel(&second_input, kernel)?;
1025 assert_eq!(second_result.values.len(), first_values.len());
1026 if second_result.values.len() > 240 {
1027 for (i, &val) in second_result.values.iter().enumerate().skip(240) {
1028 assert!(
1029 !val.is_nan(),
1030 "Found NaN in second EDCF output at index {}",
1031 i
1032 );
1033 }
1034 }
1035 Ok(())
1036 }
1037
1038 fn check_edcf_accuracy_nan_check(
1039 test_name: &str,
1040 kernel: Kernel,
1041 ) -> Result<(), Box<dyn std::error::Error>> {
1042 skip_if_unsupported!(kernel, test_name);
1043 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1044 let candles = read_candles_from_csv(file_path)?;
1045 let period = 15;
1046 let input = EdcfInput::from_candles(
1047 &candles,
1048 "close",
1049 EdcfParams {
1050 period: Some(period),
1051 },
1052 );
1053 let result = edcf_with_kernel(&input, kernel)?;
1054 assert_eq!(result.values.len(), candles.close.len());
1055 let start_index = 2 * period;
1056 if result.values.len() > start_index {
1057 for (i, &val) in result.values.iter().enumerate().skip(start_index) {
1058 assert!(!val.is_nan(), "Found NaN in EDCF output at index {}", i);
1059 }
1060 }
1061 Ok(())
1062 }
1063
1064 fn check_edcf_streaming(
1065 test_name: &str,
1066 kernel: Kernel,
1067 ) -> Result<(), Box<dyn std::error::Error>> {
1068 skip_if_unsupported!(kernel, test_name);
1069 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1070 let candles = read_candles_from_csv(file_path)?;
1071
1072 let input = EdcfInput::from_candles(&candles, "close", EdcfParams { period: Some(15) });
1073 let _batch = edcf_with_kernel(&input, kernel)?;
1074
1075 let mut stream = EdcfStream::try_new(EdcfParams { period: Some(15) })?;
1076 let mut vals = Vec::with_capacity(candles.close.len());
1077 for &v in &candles.close {
1078 vals.push(stream.update(v).unwrap_or(f64::NAN));
1079 }
1080 for (i, &v) in vals.iter().enumerate().skip(30) {
1081 assert!(!v.is_nan(), "[{test_name}] NaN at {i}");
1082 }
1083 Ok(())
1084 }
1085
1086 #[cfg(debug_assertions)]
1087 fn check_edcf_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1088 skip_if_unsupported!(kernel, test_name);
1089
1090 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1091 let candles = read_candles_from_csv(file_path)?;
1092
1093 let test_periods = vec![3, 5, 10, 15, 30, 50, 100, 200];
1094 let test_sources = vec!["open", "high", "low", "close", "hl2", "hlc3", "ohlc4"];
1095
1096 for period in &test_periods {
1097 for source in &test_sources {
1098 let input = EdcfInput::from_candles(
1099 &candles,
1100 source,
1101 EdcfParams {
1102 period: Some(*period),
1103 },
1104 );
1105 let output = edcf_with_kernel(&input, kernel)?;
1106
1107 for (i, &val) in output.values.iter().enumerate() {
1108 if val.is_nan() {
1109 continue;
1110 }
1111
1112 let bits = val.to_bits();
1113
1114 if bits == 0x11111111_11111111 {
1115 panic!(
1116 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} with period={}, source={}",
1117 test_name, val, bits, i, period, source
1118 );
1119 }
1120
1121 if bits == 0x22222222_22222222 {
1122 panic!(
1123 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} with period={}, source={}",
1124 test_name, val, bits, i, period, source
1125 );
1126 }
1127
1128 if bits == 0x33333333_33333333 {
1129 panic!(
1130 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} with period={}, source={}",
1131 test_name, val, bits, i, period, source
1132 );
1133 }
1134 }
1135 }
1136 }
1137
1138 Ok(())
1139 }
1140
1141 #[cfg(not(debug_assertions))]
1142 fn check_edcf_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1143 Ok(())
1144 }
1145
    /// Property test for EDCF on random finite series.
    ///
    /// For each generated case it checks that:
    /// - the requested `kernel` and `Kernel::Scalar` either fail with the same
    ///   error variant or both succeed with values within 1e-9 / 4 ULP;
    /// - every value past `2 * period` lies inside the min/max of its trailing
    ///   `period`-sample window (or is NaN);
    /// - a window whose samples are all identical produces NaN;
    /// - feeding `a * x + b` yields `a * y + b` (affine equivariance);
    /// - the streaming API (`EdcfStream`) matches the batch result;
    /// - the first `first + 2 * period` outputs are NaN.
    ///
    /// Afterwards a few fixed inputs (empty, all-NaN, period larger than the
    /// data, zero period) must return errors.
    // Exact float equality is intentional in the constant-window check below.
    #[allow(clippy::float_cmp)]
    fn check_edcf_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        // Strategy: period in 3..=30, a finite series of length 2*period..400,
        // and affine coefficients (a, b) with |a| > 1e-12.
        let strat = (3usize..=30).prop_flat_map(|period| {
            (
                prop::collection::vec(
                    (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
                    2 * period..400,
                ),
                Just(period),
                (-1e3f64..1e3f64).prop_filter("a≠0", |a| a.abs() > 1e-12),
                -1e3f64..1e3f64,
            )
        });

        proptest::test_runner::TestRunner::default().run(&strat, |(data, period, a, b)| {
            let params = EdcfParams {
                period: Some(period),
            };
            let input = EdcfInput::from_slice(&data, params.clone());

            // Kernel under test vs. the scalar reference implementation.
            let fast = edcf_with_kernel(&input, kernel);
            let slow = edcf_with_kernel(&input, Kernel::Scalar);

            match (fast, slow) {
                // Both failed with the same error variant: consistent, nothing more to check.
                (Err(e1), Err(e2))
                    if std::mem::discriminant(&e1) == std::mem::discriminant(&e2) =>
                {
                    return Ok(());
                }

                (Err(e1), Err(e2)) => {
                    prop_assert!(false, "different errors: fast={:?} slow={:?}", e1, e2)
                }

                (Err(e), Ok(_)) => prop_assert!(false, "fast errored {e:?} but scalar succeeded"),
                (Ok(_), Err(e)) => prop_assert!(false, "scalar errored {e:?} but fast succeeded"),

                (Ok(fast), Ok(reference)) => {
                    let EdcfOutput { values: out } = fast;
                    let EdcfOutput { values: rref } = reference;

                    // Streaming path: a `None` from `update` is recorded as NaN.
                    let mut stream = EdcfStream::try_new(params.clone()).unwrap();
                    let mut s_out = Vec::with_capacity(data.len());
                    for &v in &data {
                        s_out.push(stream.update(v).unwrap_or(f64::NAN));
                    }

                    // Same computation on the affinely transformed series.
                    let transformed: Vec<f64> = data.iter().map(|x| a * x + b).collect();
                    let t_out = edcf(&EdcfInput::from_slice(&transformed, params.clone()))?.values;

                    // Warm-up length asserted by this test (also used for the NaN-prefix check below).
                    let warm = 2 * period;

                    for i in warm..data.len() {
                        // Trailing window of `period` samples ending at i, and its min/max.
                        let win = &data[i + 1 - period..=i];
                        let (lo, hi) = win
                            .iter()
                            .fold((f64::INFINITY, f64::NEG_INFINITY), |(l, h), &v| {
                                (l.min(v), h.max(v))
                            });
                        let y = out[i];
                        let yr = rref[i];
                        let ys = s_out[i];
                        let yt = t_out[i];

                        // Output must stay within the window's range (small tolerance).
                        prop_assert!(
                            y.is_nan() || (y >= lo - 1e-9 && y <= hi + 1e-9),
                            "idx {i}: {y} ∉ [{lo}, {hi}]"
                        );

                        // Flat window: the test expects NaN (presumably all weights are zero).
                        if win.iter().all(|v| *v == win[0]) {
                            prop_assert!(y.is_nan(), "idx {i}: expected NaN on constant series");
                        }

                        // Affine equivariance: edcf(a*x + b) ≈ a*edcf(x) + b, by abs/rel tolerance or ULPs.
                        if y.is_finite() && yt.is_finite() {
                            let expect = a * y + b;
                            let diff = (yt - expect).abs();
                            let tol = 1e-9_f64.max(expect.abs() * 1e-9);
                            let ulp = yt.to_bits().abs_diff(expect.to_bits());
                            prop_assert!(
                                diff <= tol || ulp <= 8,
                                "idx {i}: affine mismatch diff={diff:e} ULP={ulp}"
                            );
                        }

                        // Kernel under test vs. scalar reference.
                        let ulp = y.to_bits().abs_diff(yr.to_bits());
                        prop_assert!(
                            (y - yr).abs() <= 1e-9 || ulp <= 4,
                            "idx {i}: fast={y} ref={yr} ULP={ulp}"
                        );

                        // Batch vs. streaming (both NaN is also accepted).
                        prop_assert!(
                            (y - ys).abs() <= 1e-9 || (y.is_nan() && ys.is_nan()),
                            "idx {i}: stream mismatch"
                        );
                    }

                    // Everything before the first valid sample plus the warm-up must be NaN.
                    let first = data.iter().position(|x| !x.is_nan()).unwrap_or(data.len());
                    let warm_expected = first + warm;
                    prop_assert!(out[..warm_expected].iter().all(|v| v.is_nan()));
                }
            }

            Ok(())
        })?;

        // Fixed error cases: empty input, all-NaN input, period > len, zero period.
        assert!(edcf(&EdcfInput::from_slice(&[], EdcfParams::default())).is_err());
        assert!(edcf(&EdcfInput::from_slice(
            &[f64::NAN; 12],
            EdcfParams::default()
        ))
        .is_err());
        assert!(edcf(&EdcfInput::from_slice(
            &[1.0; 5],
            EdcfParams { period: Some(8) }
        ))
        .is_err());
        assert!(edcf(&EdcfInput::from_slice(
            &[1.0; 5],
            EdcfParams { period: Some(0) }
        ))
        .is_err());

        Ok(())
    }
1276
1277 fn check_edcf_invalid_kernel(
1278 test_name: &str,
1279 _kernel: Kernel,
1280 ) -> Result<(), Box<dyn std::error::Error>> {
1281 let data = [1.0, 2.0, 3.0];
1282 let range = EdcfBatchRange::default();
1283 let res = edcf_batch_with_kernel(&data, &range, Kernel::Avx2);
1284 assert!(
1285 matches!(res, Err(EdcfError::InvalidKernelForBatch(Kernel::Avx2))),
1286 "{}",
1287 test_name
1288 );
1289 Ok(())
1290 }
1291
1292 macro_rules! generate_all_edcf_tests {
1293 ($($test_fn:ident),*) => {
1294 paste::paste! {
1295 $(
1296 #[test]
1297 fn [<$test_fn _scalar_f64>]() {
1298 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1299 }
1300 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1301 #[test]
1302 fn [<$test_fn _avx2_f64>]() {
1303 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1304 }
1305 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1306 #[test]
1307 fn [<$test_fn _avx512_f64>]() {
1308 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1309 }
1310 )*
1311 }
1312 }
1313 }
1314
    // Instantiate the per-kernel tests for every single-series checker.
    // NOTE: the macro matcher is `$($test_fn:ident),*`, which does not accept a
    // trailing comma after the last entry.
    generate_all_edcf_tests!(
        check_edcf_partial_params,
        check_edcf_accuracy_last_five,
        check_edcf_with_default_candles,
        check_edcf_with_zero_period,
        check_edcf_with_no_data,
        check_edcf_with_period_exceeding_data_length,
        check_edcf_very_small_data_set,
        check_edcf_with_slice_data_reinput,
        check_edcf_accuracy_nan_check,
        check_edcf_streaming,
        check_edcf_property,
        check_edcf_invalid_kernel,
        check_edcf_no_poison
    );
1330
1331 fn check_batch_default_row(
1332 test: &str,
1333 kernel: Kernel,
1334 ) -> Result<(), Box<dyn std::error::Error>> {
1335 skip_if_unsupported!(kernel, test);
1336 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1337 let c = read_candles_from_csv(file)?;
1338 let output = EdcfBatchBuilder::new()
1339 .kernel(kernel)
1340 .apply_candles(&c, "close")?;
1341 let def = EdcfParams::default();
1342 let row = output.values_for(&def).expect("default row missing");
1343 assert_eq!(row.len(), c.close.len());
1344 Ok(())
1345 }
1346
1347 #[cfg(debug_assertions)]
1348 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1349 skip_if_unsupported!(kernel, test);
1350
1351 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1352 let c = read_candles_from_csv(file)?;
1353
1354 let test_sources = vec!["open", "high", "low", "close", "hl2", "hlc3", "ohlc4"];
1355
1356 for source in &test_sources {
1357 let output = EdcfBatchBuilder::new()
1358 .kernel(kernel)
1359 .period_range(3, 200, 5)
1360 .apply_candles(&c, source)?;
1361
1362 for (idx, &val) in output.values.iter().enumerate() {
1363 if val.is_nan() {
1364 continue;
1365 }
1366
1367 let bits = val.to_bits();
1368 let row = idx / output.cols;
1369 let col = idx % output.cols;
1370
1371 if bits == 0x11111111_11111111 {
1372 panic!(
1373 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at row {} col {} (flat index {}) with source={}",
1374 test, val, bits, row, col, idx, source
1375 );
1376 }
1377
1378 if bits == 0x22222222_22222222 {
1379 panic!(
1380 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at row {} col {} (flat index {}) with source={}",
1381 test, val, bits, row, col, idx, source
1382 );
1383 }
1384
1385 if bits == 0x33333333_33333333 {
1386 panic!(
1387 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at row {} col {} (flat index {}) with source={}",
1388 test, val, bits, row, col, idx, source
1389 );
1390 }
1391 }
1392 }
1393
1394 let edge_case_ranges = vec![(3, 5, 1), (190, 200, 2), (50, 100, 10)];
1395 for (start, end, step) in edge_case_ranges {
1396 let output = EdcfBatchBuilder::new()
1397 .kernel(kernel)
1398 .period_range(start, end, step)
1399 .apply_candles(&c, "close")?;
1400
1401 for (idx, &val) in output.values.iter().enumerate() {
1402 if val.is_nan() {
1403 continue;
1404 }
1405
1406 let bits = val.to_bits();
1407 let row = idx / output.cols;
1408 let col = idx % output.cols;
1409
1410 if bits == 0x11111111_11111111
1411 || bits == 0x22222222_22222222
1412 || bits == 0x33333333_33333333
1413 {
1414 panic!(
1415 "[{}] Found poison value {} (0x{:016X}) at row {} col {} with range ({},{},{})",
1416 test, val, bits, row, col, start, end, step
1417 );
1418 }
1419 }
1420 }
1421
1422 Ok(())
1423 }
1424
1425 #[cfg(not(debug_assertions))]
1426 fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1427 Ok(())
1428 }
1429
1430 macro_rules! gen_batch_tests {
1431 ($fn_name:ident) => {
1432 paste::paste! {
1433 #[test]
1434 fn [<$fn_name _scalar>]() {
1435 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
1436 }
1437 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1438 #[test]
1439 fn [<$fn_name _avx2>]() {
1440 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
1441 }
1442 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1443 #[test]
1444 fn [<$fn_name _avx512>]() {
1445 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
1446 }
1447 #[test]
1448 fn [<$fn_name _auto_detect>]() {
1449 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
1450 }
1451 }
1452 };
1453 }
1454
    // Instantiate the scalar / AVX2 / AVX-512 / auto-detect variants of the batch checkers.
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
1457}
1458
// Python binding `edcf(data, period, kernel=None)`.
//
// Computes EDCF over a contiguous float64 array with the kernel chosen by
// `validate_kernel(kernel, false)`, releasing the GIL during the computation.
// Any EDCF error is raised as `ValueError`.
// (Plain `//` comments on purpose: `///` would change the Python `__doc__`.)
#[cfg(feature = "python")]
#[pyfunction(name = "edcf")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn edcf_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let input_slice = data.as_slice()?;
    let chosen_kernel = validate_kernel(kernel, false)?;

    let edcf_in = EdcfInput::from_slice(
        input_slice,
        EdcfParams {
            period: Some(period),
        },
    );

    let values = py
        .allow_threads(|| edcf_with_kernel(&edcf_in, chosen_kernel))
        .map(|computed| computed.values)
        .map_err(|err| PyValueError::new_err(err.to_string()))?;

    Ok(values.into_pyarray(py))
}
1484
// Python class exposed as `EdcfStream`: a thin wrapper holding the Rust
// `EdcfStream` incremental calculator.
// (Plain `//` comments on purpose: `///` would change the Python `__doc__`.)
#[cfg(feature = "python")]
#[pyclass(name = "EdcfStream")]
pub struct EdcfStreamPy {
    // Underlying streaming state; every `update` call forwards to it.
    stream: EdcfStream,
}
1490
// Python methods of `EdcfStream`.
// (Plain `//` comments on purpose: `///` would change the Python docstrings.)
#[cfg(feature = "python")]
#[pymethods]
impl EdcfStreamPy {
    // Constructor `EdcfStream(period)`. A period rejected by
    // `EdcfStream::try_new` surfaces as `ValueError`.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        EdcfStream::try_new(EdcfParams {
            period: Some(period),
        })
        .map(|stream| Self { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    // Feeds one sample to the inner stream and returns its `Option` result
    // (presumably `None` while the stream is still warming up).
    fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
1508
// Python binding `edcf_batch(data, period_range, kernel=None)`.
//
// Evaluates EDCF for every period produced by `expand_grid` from
// `period_range = (start, end, step)` and returns a dict:
//   "values"  -> float64 array of shape (rows, cols), one row per period;
//   "periods" -> uint64 array with the period used for each row.
// `kernel` is checked with `validate_kernel(kernel, true)`; a resolved
// `Kernel::Auto` runs the scalar batch kernel.
// Raises `ValueError` on rows*cols overflow or any EDCF error.
// (Plain `//` comments on purpose: `///` would change the Python `__doc__`.)
#[cfg(feature = "python")]
#[pyfunction(name = "edcf_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn edcf_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};

    let slice_in = data.as_slice()?;
    // Reject an invalid kernel name before allocating a potentially large output.
    let kern = validate_kernel(kernel, true)?;

    let sweep = EdcfBatchRange {
        period: period_range,
    };

    let combos = expand_grid(&sweep);
    let rows = combos.len();
    let cols = slice_in.len();

    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("edcf_batch: rows*cols overflow"))?;

    // SAFETY: the array is created uninitialised. It is only returned to Python
    // after `edcf_batch_inner_into` succeeds, which is assumed to fill every
    // cell not covered by the NaN prefixes written below (TODO confirm).
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    if !slice_in.is_empty() && rows > 0 {
        if let Some(first) = slice_in.iter().position(|x| !x.is_nan()) {
            // Per-row warm-up length: first valid index + 2 * period, clamped to
            // the row width. Saturating arithmetic keeps an absurdly large
            // period from overflowing (a panic in debug builds).
            let warm: Vec<usize> = combos
                .iter()
                .map(|c| {
                    let period = c.period.unwrap_or(15);
                    first.saturating_add(period.saturating_mul(2)).min(cols)
                })
                .collect();

            // SAFETY: `MaybeUninit<f64>` has the same layout as `f64`, and the
            // view covers exactly `slice_out`.
            let buf_mu: &mut [MaybeUninit<f64>] = unsafe {
                core::slice::from_raw_parts_mut(
                    slice_out.as_mut_ptr() as *mut MaybeUninit<f64>,
                    slice_out.len(),
                )
            };
            init_matrix_prefixes(buf_mu, cols, &warm);
        }
    }

    let combos = py
        .allow_threads(|| {
            let kernel = match kern {
                Kernel::Auto => Kernel::ScalarBatch,
                k => k,
            };
            // Map the batch kernel to the single-series kernel used per row.
            let simd = match kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                _ => unreachable!(),
            };
            edcf_batch_inner_into(slice_in, &sweep, simd, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;

    Ok(dict)
}
1593
// Python binding `edcf_cuda_batch_dev(data_f32, period_range, device_id=0)`.
//
// Runs the EDCF period sweep on the GPU through `CudaEdcf::edcf_batch_dev` and
// returns the device-resident result wrapped by `make_device_array_py`.
// Raises `ValueError` if CUDA is unavailable or any CUDA step fails; the GIL is
// released while the device work runs.
// (Plain `//` comments on purpose: `///` would change the Python `__doc__`.)
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "edcf_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, device_id=0))]
pub fn edcf_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: numpy::PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    use crate::cuda::cuda_available;
    use crate::cuda::moving_averages::CudaEdcf;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let host_series = data_f32.as_slice()?;
    let sweep = EdcfBatchRange {
        period: period_range,
    };

    let (inner, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaEdcf::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let dev_id = cuda.device_id();
        cuda.edcf_batch_dev(host_series, &sweep)
            .map(|device_out| (device_out, dev_id))
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    make_device_array_py(dev_id as usize, inner)
}
1626
// Python binding `edcf_cuda_many_series_one_param_dev(data_tm_f32, period, device_id=0)`.
//
// Takes a 2-D float32 array (named "time-major" by the CUDA method it feeds)
// and runs EDCF with a single `period` on every series via
// `CudaEdcf::edcf_many_series_one_param_time_major_dev`. The device-resident
// result is wrapped by `make_device_array_py`. Raises `ValueError` if CUDA is
// unavailable, the array is not contiguous, or any CUDA step fails.
// (Plain `//` comments on purpose: `///` would change the Python `__doc__`.)
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "edcf_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, device_id=0))]
pub fn edcf_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    use crate::cuda::cuda_available;
    use crate::cuda::moving_averages::CudaEdcf;
    use numpy::PyUntypedArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let flat_in: &[f32] = data_tm_f32.as_slice()?;
    // shape()[0] is passed as the CUDA method's row count, shape()[1] as its column count.
    let rows = data_tm_f32.shape()[0];
    let cols = data_tm_f32.shape()[1];
    let params = EdcfParams {
        period: Some(period),
    };

    let (inner, dev_id) = py.allow_threads(|| {
        let mut cuda =
            CudaEdcf::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let dev_id = cuda.device_id();
        // Fixed: the params argument had been mangled to `¶ms`; it must borrow `params`.
        let out = cuda
            .edcf_many_series_one_param_time_major_dev(flat_in, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, PyErr>((out, dev_id))
    })?;

    make_device_array_py(dev_id as usize, inner)
}
1663
/// WASM binding: computes EDCF over `data` with the given `period`.
///
/// Returns a freshly allocated vector the same length as `data`. Any EDCF
/// error is returned as a JS string value.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn edcf_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = EdcfInput::from_slice(
        data,
        EdcfParams {
            period: Some(period),
        },
    );

    let mut values = vec![0.0; data.len()];
    match edcf_into_slice(&mut values, &input, Kernel::Auto) {
        Ok(_) => Ok(values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1679
/// WASM binding: runs the EDCF period sweep `(period_start, period_end,
/// period_step)` with the scalar kernel.
///
/// Returns the flattened row-major result matrix. Any EDCF error is returned
/// as a JS string value.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn edcf_batch_js(
    data: &[f64],
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let range = (period_start, period_end, period_step);
    match edcf_batch_inner(data, &EdcfBatchRange { period: range }, Kernel::Scalar, false) {
        Ok(batch) => Ok(batch.values),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
1696
/// WASM binding: lists the periods a sweep `(period_start, period_end,
/// period_step)` expands to, as `f64`, in row order.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn edcf_batch_metadata_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let combos = expand_grid(&EdcfBatchRange {
        period: (period_start, period_end, period_step),
    });

    let mut periods = Vec::with_capacity(combos.len());
    for combo in combos.iter() {
        periods.push(combo.period.unwrap() as f64);
    }

    Ok(periods)
}
1716
/// WASM helper: allocates an uninitialised buffer with capacity for `len`
/// `f64`s and hands ownership to the caller as a raw pointer.
///
/// Release it with `edcf_free(ptr, len)` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn edcf_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop leaks the allocation on purpose; ownership moves to the caller.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1725
/// WASM helper: releases a buffer previously returned by `edcf_alloc(len)`.
///
/// `len` must be the value passed to `edcf_alloc`. A null pointer is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn edcf_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` must come from `edcf_alloc(len)`, i.e. a `Vec<f64>` buffer
    // with capacity `len`. The length is given as 0 because the contents may
    // never have been initialised: `Vec::from_raw_parts` requires the first
    // `length` elements to be initialised, and deallocation depends only on
    // the capacity.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
1735
/// WASM binding: computes EDCF from `len` samples at `in_ptr` into `len` slots
/// at `out_ptr`.
///
/// Errors on null pointers or when `period` is 0 or exceeds `len`. When both
/// pointers are equal, the result is computed into a scratch buffer first and
/// then copied over the shared buffer.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn edcf_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to edcf_into"));
    }

    // SAFETY: the caller guarantees `in_ptr` is valid for `len` reads.
    let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };

    if period == 0 || period > len {
        return Err(JsValue::from_str("Invalid period"));
    }

    let input = EdcfInput::from_slice(
        data,
        EdcfParams {
            period: Some(period),
        },
    );

    let aliased = in_ptr == out_ptr;
    if aliased {
        // Writing straight into the shared buffer would overwrite input samples
        // that are still being read, so go through scratch space.
        let mut scratch = vec![0.0; len];
        edcf_into_slice(&mut scratch, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // SAFETY: the caller guarantees `out_ptr` is valid for `len` writes.
        unsafe { std::slice::from_raw_parts_mut(out_ptr, len) }.copy_from_slice(&scratch);
    } else {
        // SAFETY: the caller guarantees `out_ptr` is valid for `len` writes.
        let destination = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        edcf_into_slice(destination, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(())
}
1776
/// Configuration object accepted by the JS `edcf_batch` binding
/// (deserialised with `serde_wasm_bindgen`).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct EdcfBatchConfig {
    /// Period sweep as `(start, end, step)`; copied into `EdcfBatchRange::period`.
    pub period_range: (usize, usize, usize),
}
1782
/// Result object returned by the JS `edcf_batch` binding
/// (serialised with `serde_wasm_bindgen`).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct EdcfBatchJsOutput {
    /// Flattened result matrix with `rows * cols` entries, taken from the batch output.
    pub values: Vec<f64>,
    /// Parameter set used for each row.
    pub combos: Vec<EdcfParams>,
    /// Number of rows (parameter combinations).
    pub rows: usize,
    /// Number of columns (input length).
    pub cols: usize,
}
1791
/// WASM binding exported to JS as `edcf_batch(data, config)`.
///
/// `config` must deserialise into `EdcfBatchConfig`; an invalid config yields
/// an "Invalid config: …" error. It runs the sweep with `Kernel::Auto` and
/// returns an `EdcfBatchJsOutput` object (values, combos, rows, cols).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = edcf_batch)]
pub fn edcf_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let EdcfBatchConfig { period_range }: EdcfBatchConfig =
        serde_wasm_bindgen::from_value(config)
            .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = EdcfBatchRange {
        period: period_range,
    };
    let batch = edcf_batch_inner(data, &sweep, Kernel::Auto, false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let payload = EdcfBatchJsOutput {
        rows: batch.rows,
        cols: batch.cols,
        values: batch.values,
        combos: batch.combos,
    };

    serde_wasm_bindgen::to_value(&payload).map_err(|e| JsValue::from_str(&e.to_string()))
}
1814
/// WASM binding: runs the EDCF period sweep `(period_start, period_end,
/// period_step)` over `len` samples at `in_ptr` and writes the row-major
/// `rows * len` result matrix to `out_ptr`.
///
/// Returns the number of rows (parameter combinations).
///
/// # Errors
/// Null pointers, a `rows * len` overflow, or any EDCF error are returned as
/// JS string values.
///
/// The caller must make `in_ptr` valid for `len` reads and `out_ptr` valid for
/// `rows * len` writes. If the two regions overlap, the input is copied first
/// so that writing outputs cannot corrupt samples still being read.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn edcf_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to edcf_batch_into"));
    }

    let sweep = EdcfBatchRange {
        period: (period_start, period_end, period_step),
    };

    let combos = expand_grid(&sweep);
    let rows = combos.len();
    let cols = len;

    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("edcf_batch_into: rows*cols overflow"))?;

    // Detect any overlap between [in_ptr, in_ptr + len) and [out_ptr, out_ptr + total).
    // `wrapping_add` is only used for address comparison here.
    let in_end = in_ptr.wrapping_add(len);
    let out_start = out_ptr as *const f64;
    let out_end = out_start.wrapping_add(total);
    let overlaps = len > 0 && total > 0 && in_ptr < out_end && out_start < in_end;

    // SAFETY: the caller guarantees `in_ptr` is valid for `len` reads.
    let raw_input = unsafe { std::slice::from_raw_parts(in_ptr, len) };
    let owned_input: Vec<f64>;
    let data: &[f64] = if overlaps {
        // The output writes (NaN prefixes and results) would clobber the input
        // and alias a live shared borrow; work from a private copy instead.
        owned_input = raw_input.to_vec();
        &owned_input
    } else {
        raw_input
    };

    // SAFETY: the caller guarantees `out_ptr` is valid for `total` writes. If it
    // overlaps the input, `data` points at a private copy, so no live shared
    // borrow aliases `out`.
    let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };

    if !data.is_empty() && rows > 0 {
        if let Some(first) = data.iter().position(|x| !x.is_nan()) {
            // Per-row warm-up length: first valid index + 2 * period, clamped to
            // the row width. Saturating arithmetic avoids overflow for huge periods.
            let warm: Vec<usize> = combos
                .iter()
                .map(|c| {
                    let period = c.period.unwrap_or(15);
                    first.saturating_add(period.saturating_mul(2)).min(cols)
                })
                .collect();

            // SAFETY: `MaybeUninit<f64>` has the same layout as `f64`, and the
            // view covers exactly `out`.
            let buf_mu: &mut [MaybeUninit<f64>] = unsafe {
                core::slice::from_raw_parts_mut(
                    out.as_mut_ptr() as *mut MaybeUninit<f64>,
                    out.len(),
                )
            };
            init_matrix_prefixes(buf_mu, cols, &warm);
        }
    }

    edcf_batch_inner_into(data, &sweep, Kernel::Auto, false, out)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    Ok(rows)
}