1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::{cuda_available, CudaEmd, CudaEmdBatchResult, DeviceArrayF32Triple};
3use crate::utilities::data_loader::{source_type, Candles};
4#[cfg(all(feature = "python", feature = "cuda"))]
5use crate::utilities::dlpack_cuda::{make_device_array_py, DeviceArrayF32Py};
6use crate::utilities::enums::Kernel;
7use crate::utilities::helpers::{
8 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
9 make_uninit_matrix,
10};
11#[cfg(feature = "python")]
12use crate::utilities::kernel_validation::validate_kernel;
13#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
14use core::arch::x86_64::*;
15#[cfg(feature = "python")]
16use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
17#[cfg(feature = "python")]
18use pyo3::exceptions::PyValueError;
19#[cfg(feature = "python")]
20use pyo3::prelude::*;
21#[cfg(feature = "python")]
22use pyo3::types::PyDict;
23#[cfg(not(target_arch = "wasm32"))]
24use rayon::prelude::*;
25use std::convert::AsRef;
26use std::error::Error;
27use std::mem::MaybeUninit;
28use thiserror::Error;
29
30impl<'a> AsRef<[f64]> for EmdInput<'a> {
31 #[inline(always)]
32 fn as_ref(&self) -> &[f64] {
33 match &self.data {
34 EmdData::Candles { candles } => source_type(candles, "close"),
35 EmdData::Slices { close, .. } => close,
36 }
37 }
38}
39
/// Borrowed price data accepted by the EMD indicator.
#[derive(Debug, Clone)]
pub enum EmdData<'a> {
    /// A candle set; individual series are looked up by name via `source_type`.
    Candles {
        candles: &'a Candles,
    },
    /// Caller-provided series.
    Slices {
        /// High prices (read by `emd_prepare`).
        high: &'a [f64],
        /// Low prices (read by `emd_prepare`).
        low: &'a [f64],
        /// Close prices (returned by the `AsRef<[f64]>` impl).
        close: &'a [f64],
        /// Volume series; not read by any code visible in this part of the file.
        volume: &'a [f64],
    },
}
52
/// Result of a single EMD computation: three bands of equal length.
#[derive(Debug, Clone)]
pub struct EmdOutput {
    /// Upper band: 50-sample mean of `peak * fraction`.
    pub upperband: Vec<f64>,
    /// Middle band: mean of the band-pass value over `2 * period` samples.
    pub middleband: Vec<f64>,
    /// Lower band: 50-sample mean of `valley * fraction`.
    pub lowerband: Vec<f64>,
}
59
/// Tunable EMD parameters. A `None` field falls back to the default used by
/// the `EmdInput` getters and `EmdStream::try_new` (20, 0.5 and 0.1).
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct EmdParams {
    /// Filter period; the middle band averages over `2 * period` samples.
    pub period: Option<usize>,
    /// Enters the filter coefficient `gamma = 1 / cos(4 * PI * delta / period)`.
    pub delta: Option<f64>,
    /// Multiplier applied to the tracked peak/valley before averaging.
    pub fraction: Option<f64>,
}
70
71impl Default for EmdParams {
72 fn default() -> Self {
73 Self {
74 period: Some(20),
75 delta: Some(0.5),
76 fraction: Some(0.1),
77 }
78 }
79}
80
/// Everything one EMD computation needs: the price data plus parameters.
#[derive(Debug, Clone)]
pub struct EmdInput<'a> {
    /// Borrowed price data (candles or raw slices).
    pub data: EmdData<'a>,
    /// Parameters; unset fields resolve to defaults through the `get_*` getters.
    pub params: EmdParams,
}
86
87impl<'a> EmdInput<'a> {
88 #[inline]
89 pub fn from_candles(candles: &'a Candles, params: EmdParams) -> Self {
90 Self {
91 data: EmdData::Candles { candles },
92 params,
93 }
94 }
95
96 #[inline]
97 pub fn from_slices(
98 high: &'a [f64],
99 low: &'a [f64],
100 close: &'a [f64],
101 volume: &'a [f64],
102 params: EmdParams,
103 ) -> Self {
104 Self {
105 data: EmdData::Slices {
106 high,
107 low,
108 close,
109 volume,
110 },
111 params,
112 }
113 }
114
115 #[inline]
116 pub fn with_default_candles(candles: &'a Candles) -> Self {
117 Self::from_candles(candles, EmdParams::default())
118 }
119
120 #[inline]
121 pub fn get_period(&self) -> usize {
122 self.params.period.unwrap_or(20)
123 }
124 #[inline]
125 pub fn get_delta(&self) -> f64 {
126 self.params.delta.unwrap_or(0.5)
127 }
128 #[inline]
129 pub fn get_fraction(&self) -> f64 {
130 self.params.fraction.unwrap_or(0.1)
131 }
132}
133
/// Fluent builder for one EMD run (or a stream) with optional overrides.
#[derive(Copy, Clone, Debug)]
pub struct EmdBuilder {
    /// Period override; `None` leaves the default in effect.
    period: Option<usize>,
    /// Delta override; `None` leaves the default in effect.
    delta: Option<f64>,
    /// Fraction override; `None` leaves the default in effect.
    fraction: Option<f64>,
    /// Kernel passed to `emd_with_kernel`.
    kernel: Kernel,
}
141
142impl Default for EmdBuilder {
143 fn default() -> Self {
144 Self {
145 period: None,
146 delta: None,
147 fraction: None,
148 kernel: Kernel::Auto,
149 }
150 }
151}
152
153impl EmdBuilder {
154 #[inline(always)]
155 pub fn new() -> Self {
156 Self::default()
157 }
158 #[inline(always)]
159 pub fn period(mut self, n: usize) -> Self {
160 self.period = Some(n);
161 self
162 }
163 #[inline(always)]
164 pub fn delta(mut self, d: f64) -> Self {
165 self.delta = Some(d);
166 self
167 }
168 #[inline(always)]
169 pub fn fraction(mut self, f: f64) -> Self {
170 self.fraction = Some(f);
171 self
172 }
173 #[inline(always)]
174 pub fn kernel(mut self, k: Kernel) -> Self {
175 self.kernel = k;
176 self
177 }
178
179 #[inline(always)]
180 pub fn apply(self, c: &Candles) -> Result<EmdOutput, EmdError> {
181 let p = EmdParams {
182 period: self.period,
183 delta: self.delta,
184 fraction: self.fraction,
185 };
186 let i = EmdInput::from_candles(c, p);
187 emd_with_kernel(&i, self.kernel)
188 }
189
190 #[inline(always)]
191 pub fn apply_slices(
192 self,
193 high: &[f64],
194 low: &[f64],
195 close: &[f64],
196 volume: &[f64],
197 ) -> Result<EmdOutput, EmdError> {
198 let p = EmdParams {
199 period: self.period,
200 delta: self.delta,
201 fraction: self.fraction,
202 };
203 let i = EmdInput::from_slices(high, low, close, volume, p);
204 emd_with_kernel(&i, self.kernel)
205 }
206
207 #[inline(always)]
208 pub fn into_stream(self) -> Result<EmdStream, EmdError> {
209 let p = EmdParams {
210 period: self.period,
211 delta: self.delta,
212 fraction: self.fraction,
213 };
214 EmdStream::try_new(p)
215 }
216}
217
/// Errors reported by the EMD single-shot, streaming and batch entry points.
#[derive(Debug, Error)]
pub enum EmdError {
    /// The high series has length zero.
    #[error("emd: Invalid input length (empty input data)")]
    EmptyInputData,

    /// No index has both a non-NaN high and a non-NaN low.
    #[error("emd: All values are NaN.")]
    AllValuesNaN,

    /// The period is zero or exceeds the data length.
    #[error("emd: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },

    /// Fewer than `max(2 * period, 50)` samples remain from the first valid index.
    #[error("emd: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },

    /// Delta is NaN or infinite.
    #[error("emd: Invalid delta: {delta}")]
    InvalidDelta { delta: f64 },

    /// Fraction is NaN or infinite.
    #[error("emd: Invalid fraction: {fraction}")]
    InvalidFraction { fraction: f64 },

    /// An input series has a different length than the high series.
    #[error("emd: Invalid input length: expected = {expected}, actual = {actual}")]
    InvalidInputLength { expected: usize, actual: usize },

    /// A caller-provided output buffer does not match the input length.
    #[error("emd: Output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },

    /// An integer sweep range could not be expanded.
    #[error("emd: Invalid range (usize): start={start}, end={end}, step={step}")]
    InvalidRangeU {
        start: usize,
        end: usize,
        step: usize,
    },

    /// A floating-point sweep range could not be expanded.
    #[error("emd: Invalid range (float): start={start}, end={end}, step={step}")]
    InvalidRangeF { start: f64, end: f64, step: f64 },

    /// A non-batch kernel was passed to a batch entry point.
    #[error("emd: Invalid kernel for batch path: {0:?}")]
    InvalidKernelForBatch(Kernel),
}
257
/// Computes the EMD bands for `input`, delegating to [`emd_with_kernel`]
/// with `Kernel::Auto`.
///
/// # Errors
/// Propagates any [`EmdError`] from `emd_with_kernel`.
#[inline]
pub fn emd(input: &EmdInput) -> Result<EmdOutput, EmdError> {
    emd_with_kernel(input, Kernel::Auto)
}
262
263#[inline]
264pub fn emd_into_slices(
265 ub: &mut [f64],
266 mb: &mut [f64],
267 lb: &mut [f64],
268 input: &EmdInput,
269 kernel: Kernel,
270) -> Result<(), EmdError> {
271 let (high, low, period, delta, fraction, first, chosen) = emd_prepare(input, kernel)?;
272 if ub.len() != high.len() || mb.len() != high.len() || lb.len() != high.len() {
273 return Err(EmdError::OutputLengthMismatch {
274 expected: high.len(),
275 got: ub.len().min(mb.len()).min(lb.len()),
276 });
277 }
278
279 emd_compute_into(
280 high, low, period, delta, fraction, first, chosen, ub, mb, lb,
281 );
282
283 let up_low_warm = first + 50 - 1;
284 let mid_warm = first + 2 * period - 1;
285 let ub_len = ub.len();
286 let lb_len = lb.len();
287 let mb_len = mb.len();
288 for v in &mut ub[..up_low_warm.min(ub_len)] {
289 *v = f64::NAN;
290 }
291 for v in &mut lb[..up_low_warm.min(lb_len)] {
292 *v = f64::NAN;
293 }
294 for v in &mut mb[..mid_warm.min(mb_len)] {
295 *v = f64::NAN;
296 }
297
298 Ok(())
299}
300
301#[inline]
302fn emd_prepare<'a>(
303 input: &'a EmdInput<'a>,
304 kernel: Kernel,
305) -> Result<(&'a [f64], &'a [f64], usize, f64, f64, usize, Kernel), EmdError> {
306 let (high, low) = match &input.data {
307 EmdData::Candles { candles } => (source_type(candles, "high"), source_type(candles, "low")),
308 EmdData::Slices { high, low, .. } => (*high, *low),
309 };
310
311 let len = high.len();
312 if len == 0 {
313 return Err(EmdError::EmptyInputData);
314 }
315 if low.len() != len {
316 return Err(EmdError::InvalidInputLength {
317 expected: len,
318 actual: low.len(),
319 });
320 }
321
322 let period = input.get_period();
323 let delta = input.get_delta();
324 let fraction = input.get_fraction();
325
326 let first = (0..len)
327 .find(|&i| !high[i].is_nan() && !low[i].is_nan())
328 .ok_or(EmdError::AllValuesNaN)?;
329
330 if period == 0 || period > len {
331 return Err(EmdError::InvalidPeriod {
332 period,
333 data_len: len,
334 });
335 }
336 let needed = (2 * period).max(50);
337 if len - first < needed {
338 return Err(EmdError::NotEnoughValidData {
339 needed,
340 valid: len - first,
341 });
342 }
343 if delta.is_nan() || delta.is_infinite() {
344 return Err(EmdError::InvalidDelta { delta });
345 }
346 if fraction.is_nan() || fraction.is_infinite() {
347 return Err(EmdError::InvalidFraction { fraction });
348 }
349
350 let chosen = match kernel {
351 Kernel::Auto => Kernel::Scalar,
352 k => k,
353 };
354 Ok((high, low, period, delta, fraction, first, chosen))
355}
356
357#[inline(always)]
358fn emd_compute_into(
359 high: &[f64],
360 low: &[f64],
361 period: usize,
362 delta: f64,
363 fraction: f64,
364 first: usize,
365 kernel: Kernel,
366 ub: &mut [f64],
367 mb: &mut [f64],
368 lb: &mut [f64],
369) {
370 unsafe {
371 match kernel {
372 Kernel::Scalar | Kernel::ScalarBatch => {
373 emd_scalar_into(high, low, period, delta, fraction, first, ub, mb, lb)
374 }
375 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
376 Kernel::Avx2 | Kernel::Avx2Batch => {
377 emd_scalar_into(high, low, period, delta, fraction, first, ub, mb, lb)
378 }
379 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
380 Kernel::Avx512 | Kernel::Avx512Batch => {
381 emd_scalar_into(high, low, period, delta, fraction, first, ub, mb, lb)
382 }
383 _ => unreachable!(),
384 }
385 }
386}
387
388pub fn emd_with_kernel(input: &EmdInput, kernel: Kernel) -> Result<EmdOutput, EmdError> {
389 let (high, low, period, delta, fraction, first, chosen) = emd_prepare(input, kernel)?;
390 let len = high.len();
391 let up_low_warm = first + 50 - 1;
392 let mid_warm = first + 2 * period - 1;
393
394 let mut upperband = alloc_with_nan_prefix(len, up_low_warm);
395 let mut middleband = alloc_with_nan_prefix(len, mid_warm);
396 let mut lowerband = alloc_with_nan_prefix(len, up_low_warm);
397
398 emd_compute_into(
399 high,
400 low,
401 period,
402 delta,
403 fraction,
404 first,
405 chosen,
406 &mut upperband,
407 &mut middleband,
408 &mut lowerband,
409 );
410
411 Ok(EmdOutput {
412 upperband,
413 middleband,
414 lowerband,
415 })
416}
417
418#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
419pub fn emd_into(
420 input: &EmdInput,
421 upperband_out: &mut [f64],
422 middleband_out: &mut [f64],
423 lowerband_out: &mut [f64],
424) -> Result<(), EmdError> {
425 let (high, low, period, delta, fraction, first, chosen) = emd_prepare(input, Kernel::Auto)?;
426
427 if upperband_out.len() != high.len()
428 || middleband_out.len() != high.len()
429 || lowerband_out.len() != high.len()
430 {
431 return Err(EmdError::OutputLengthMismatch {
432 expected: high.len(),
433 got: upperband_out
434 .len()
435 .min(middleband_out.len())
436 .min(lowerband_out.len()),
437 });
438 }
439
440 let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
441 let up_low_warm = first + 50 - 1;
442 let mid_warm = first + 2 * period - 1;
443
444 let end_u = up_low_warm.min(upperband_out.len());
445 for v in &mut upperband_out[..end_u] {
446 *v = qnan;
447 }
448 let end_l = up_low_warm.min(lowerband_out.len());
449 for v in &mut lowerband_out[..end_l] {
450 *v = qnan;
451 }
452 let end_m = mid_warm.min(middleband_out.len());
453 for v in &mut middleband_out[..end_m] {
454 *v = qnan;
455 }
456
457 emd_compute_into(
458 high,
459 low,
460 period,
461 delta,
462 fraction,
463 first,
464 chosen,
465 upperband_out,
466 middleband_out,
467 lowerband_out,
468 );
469
470 Ok(())
471}
472
/// Scalar EMD kernel over high/low series, writing into preallocated outputs.
///
/// For every index `i >= first` the price is `(high[i] + low[i]) / 2`. The
/// first two samples seed the band-pass value with the price itself; later
/// samples use a two-pole recursion on the price two steps back and the two
/// previous band-pass values. A peak (valley) is recorded whenever the
/// previous band-pass value is strictly above (below) both neighbours.
///
/// * `ub` / `lb`: 50-sample running mean of `peak * fraction` /
///   `valley * fraction`, written from index `first + 49` onward.
/// * `mb`: running mean of the band-pass value over `2 * period` samples,
///   written from index `first + 2 * period - 1` onward.
///
/// Entries before those indices are left untouched.
///
/// # Safety
/// Elements are accessed through raw pointers for every `i < high.len()`,
/// and lengths are only checked with `debug_assert_eq!`. The caller must
/// ensure `low`, `ub`, `mb` and `lb` are at least as long as `high`, and
/// that `period > 0` (the band-pass ring has `2 * period` slots and slot 0
/// is read unchecked).
#[inline]
pub unsafe fn emd_scalar_into(
    high: &[f64],
    low: &[f64],
    period: usize,
    delta: f64,
    fraction: f64,
    first: usize,
    ub: &mut [f64],
    mb: &mut [f64],
    lb: &mut [f64],
) {
    let len = high.len();
    debug_assert_eq!(low.len(), len);
    debug_assert_eq!(ub.len(), len);
    debug_assert_eq!(mb.len(), len);
    debug_assert_eq!(lb.len(), len);

    // Window lengths of the running means and their reciprocals.
    let per_up_low = 50usize;
    let per_mid = 2 * period;
    let inv_up_low = 1.0 / (per_up_low as f64);
    let inv_mid = 1.0 / (per_mid as f64);

    // Band-pass filter coefficients derived from `period` and `delta`.
    let two_pi = core::f64::consts::PI * 2.0;
    let beta = (two_pi / (period as f64)).cos();
    let gamma = 1.0 / ((two_pi * 2.0 * delta / (period as f64)).cos());
    let alpha = gamma - (gamma * gamma - 1.0).sqrt();
    let half_one_minus_alpha = 0.5 * (1.0 - alpha);
    let beta_times_one_plus_alpha = beta * (1.0 + alpha);

    // Ring buffers backing the three running sums (zero-initialised).
    let mut sp_ring = vec![0.0f64; per_up_low];
    let mut sv_ring = vec![0.0f64; per_up_low];
    let mut bp_ring = vec![0.0f64; per_mid];

    let mut idx_ul = 0usize;
    let mut idx_mid = 0usize;

    let mut sum_up = 0.0f64;
    let mut sum_low = 0.0f64;
    let mut sum_mb = 0.0f64;

    let mut bp_prev1 = 0.0f64;
    let mut bp_prev2 = 0.0f64;
    let mut peak_prev = 0.0f64;
    let mut valley_prev = 0.0f64;

    let mut price_prev1 = 0.0f64;
    let mut price_prev2 = 0.0f64;

    let hi_ptr = high.as_ptr();
    let lo_ptr = low.as_ptr();
    let ub_ptr = ub.as_mut_ptr();
    let mb_ptr = mb.as_mut_ptr();
    let lb_ptr = lb.as_mut_ptr();

    // Seed all history with the first valid price.
    let mut i = first;
    if i < len {
        let p0 = ((*hi_ptr.add(i)) + (*lo_ptr.add(i))) * 0.5;
        bp_prev1 = p0;
        bp_prev2 = p0;
        peak_prev = p0;
        valley_prev = p0;
        price_prev1 = p0;
        price_prev2 = p0;
    }

    // Number of samples consumed so far (relative to `first`).
    let mut count = 0usize;

    while i < len {
        let price = ((*hi_ptr.add(i)) + (*lo_ptr.add(i))) * 0.5;

        // The recursion needs two samples of history; before that, pass the
        // price through unchanged.
        let bp_curr = if count >= 2 {
            half_one_minus_alpha * (price - price_prev2) + beta_times_one_plus_alpha * bp_prev1
                - alpha * bp_prev2
        } else {
            price
        };

        // Carry the last peak/valley forward unless the previous band-pass
        // value is a strict local extremum.
        let mut peak_curr = peak_prev;
        let mut valley_curr = valley_prev;
        if count >= 2 {
            if bp_prev1 > bp_curr && bp_prev1 > bp_prev2 {
                peak_curr = bp_prev1;
            }
            if bp_prev1 < bp_curr && bp_prev1 < bp_prev2 {
                valley_curr = bp_prev1;
            }
        }

        let sp = peak_curr * fraction;
        let sv = valley_curr * fraction;

        // Replace the oldest ring entries and update the running sums.
        let old_sp = *sp_ring.get_unchecked(idx_ul);
        let old_sv = *sv_ring.get_unchecked(idx_ul);
        let old_bp = *bp_ring.get_unchecked(idx_mid);

        *sp_ring.get_unchecked_mut(idx_ul) = sp;
        *sv_ring.get_unchecked_mut(idx_ul) = sv;
        *bp_ring.get_unchecked_mut(idx_mid) = bp_curr;

        sum_up += sp - old_sp;
        sum_low += sv - old_sv;
        sum_mb += bp_curr - old_bp;

        idx_ul += 1;
        if idx_ul == per_up_low {
            idx_ul = 0;
        }
        idx_mid += 1;
        if idx_mid == per_mid {
            idx_mid = 0;
        }

        // Emit a band only once its window has been filled.
        let filled = count + 1;
        if filled >= per_up_low {
            *ub_ptr.add(i) = sum_up * inv_up_low;
            *lb_ptr.add(i) = sum_low * inv_up_low;
        }
        if filled >= per_mid {
            *mb_ptr.add(i) = sum_mb * inv_mid;
        }

        // Shift history for the next sample.
        bp_prev2 = bp_prev1;
        bp_prev1 = bp_curr;
        peak_prev = peak_curr;
        valley_prev = valley_curr;
        price_prev2 = price_prev1;
        price_prev1 = price;

        count += 1;
        i += 1;
    }
}
606
607#[inline(always)]
608fn emd_compute_from_prices_into(
609 prices: &[f64],
610 period: usize,
611 delta: f64,
612 fraction: f64,
613 first: usize,
614 kernel: Kernel,
615 ub: &mut [f64],
616 mb: &mut [f64],
617 lb: &mut [f64],
618) {
619 unsafe {
620 match kernel {
621 Kernel::Scalar | Kernel::ScalarBatch => {
622 emd_scalar_prices_into(prices, period, delta, fraction, first, ub, mb, lb)
623 }
624 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
625 Kernel::Avx2 | Kernel::Avx2Batch => {
626 emd_scalar_prices_into(prices, period, delta, fraction, first, ub, mb, lb)
627 }
628 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
629 Kernel::Avx512 | Kernel::Avx512Batch => {
630 emd_scalar_prices_into(prices, period, delta, fraction, first, ub, mb, lb)
631 }
632 _ => unreachable!(),
633 }
634 }
635}
636
/// Scalar EMD kernel over a precomputed price series.
///
/// Same computation as the high/low kernel, except that `prices[i]` is used
/// directly instead of `(high[i] + low[i]) / 2`. `ub` / `lb` are written from
/// index `first + 49` onward and `mb` from `first + 2 * period - 1` onward;
/// earlier entries are left untouched.
///
/// # Safety
/// Elements are accessed through raw pointers for every `i < prices.len()`,
/// and lengths are only checked with `debug_assert_eq!`. The caller must
/// ensure `ub`, `mb` and `lb` are at least as long as `prices`, and that
/// `period > 0` (the band-pass ring has `2 * period` slots and slot 0 is
/// read unchecked).
#[inline]
unsafe fn emd_scalar_prices_into(
    prices: &[f64],
    period: usize,
    delta: f64,
    fraction: f64,
    first: usize,
    ub: &mut [f64],
    mb: &mut [f64],
    lb: &mut [f64],
) {
    let len = prices.len();
    debug_assert_eq!(ub.len(), len);
    debug_assert_eq!(mb.len(), len);
    debug_assert_eq!(lb.len(), len);

    // Window lengths of the running means and their reciprocals.
    let per_up_low = 50usize;
    let per_mid = 2 * period;
    let inv_up_low = 1.0 / (per_up_low as f64);
    let inv_mid = 1.0 / (per_mid as f64);

    // Band-pass filter coefficients derived from `period` and `delta`.
    let two_pi = core::f64::consts::PI * 2.0;
    let beta = (two_pi / (period as f64)).cos();
    let gamma = 1.0 / ((two_pi * 2.0 * delta / (period as f64)).cos());
    let alpha = gamma - (gamma * gamma - 1.0).sqrt();
    let half_one_minus_alpha = 0.5 * (1.0 - alpha);
    let beta_times_one_plus_alpha = beta * (1.0 + alpha);

    // Ring buffers backing the three running sums (zero-initialised).
    let mut sp_ring = vec![0.0f64; per_up_low];
    let mut sv_ring = vec![0.0f64; per_up_low];
    let mut bp_ring = vec![0.0f64; per_mid];
    let mut idx_ul = 0usize;
    let mut idx_mid = 0usize;

    let mut sum_up = 0.0f64;
    let mut sum_low = 0.0f64;
    let mut sum_mb = 0.0f64;

    let mut bp_prev1 = 0.0f64;
    let mut bp_prev2 = 0.0f64;
    let mut peak_prev = 0.0f64;
    let mut valley_prev = 0.0f64;

    let mut price_prev1 = 0.0f64;
    let mut price_prev2 = 0.0f64;

    let pr_ptr = prices.as_ptr();
    let ub_ptr = ub.as_mut_ptr();
    let mb_ptr = mb.as_mut_ptr();
    let lb_ptr = lb.as_mut_ptr();

    // Seed all history with the first valid price.
    let mut i = first;
    if i < len {
        let p0 = *pr_ptr.add(i);
        bp_prev1 = p0;
        bp_prev2 = p0;
        peak_prev = p0;
        valley_prev = p0;
        price_prev1 = p0;
        price_prev2 = p0;
    }

    // Number of samples consumed so far (relative to `first`).
    let mut count = 0usize;
    while i < len {
        let price = *pr_ptr.add(i);

        // The recursion needs two samples of history; before that, pass the
        // price through unchanged.
        let bp_curr = if count >= 2 {
            half_one_minus_alpha * (price - price_prev2) + beta_times_one_plus_alpha * bp_prev1
                - alpha * bp_prev2
        } else {
            price
        };

        // Carry the last peak/valley forward unless the previous band-pass
        // value is a strict local extremum.
        let mut peak_curr = peak_prev;
        let mut valley_curr = valley_prev;
        if count >= 2 {
            if bp_prev1 > bp_curr && bp_prev1 > bp_prev2 {
                peak_curr = bp_prev1;
            }
            if bp_prev1 < bp_curr && bp_prev1 < bp_prev2 {
                valley_curr = bp_prev1;
            }
        }
        let sp = peak_curr * fraction;
        let sv = valley_curr * fraction;

        // Replace the oldest ring entries and update the running sums.
        let old_sp = *sp_ring.get_unchecked(idx_ul);
        let old_sv = *sv_ring.get_unchecked(idx_ul);
        let old_bp = *bp_ring.get_unchecked(idx_mid);

        *sp_ring.get_unchecked_mut(idx_ul) = sp;
        *sv_ring.get_unchecked_mut(idx_ul) = sv;
        *bp_ring.get_unchecked_mut(idx_mid) = bp_curr;

        sum_up += sp - old_sp;
        sum_low += sv - old_sv;
        sum_mb += bp_curr - old_bp;

        idx_ul += 1;
        if idx_ul == per_up_low {
            idx_ul = 0;
        }
        idx_mid += 1;
        if idx_mid == per_mid {
            idx_mid = 0;
        }

        // Emit a band only once its window has been filled.
        let filled = count + 1;
        if filled >= per_up_low {
            *ub_ptr.add(i) = sum_up * inv_up_low;
            *lb_ptr.add(i) = sum_low * inv_up_low;
        }
        if filled >= per_mid {
            *mb_ptr.add(i) = sum_mb * inv_mid;
        }

        // Shift history for the next sample.
        bp_prev2 = bp_prev1;
        bp_prev1 = bp_curr;
        peak_prev = peak_curr;
        valley_prev = valley_curr;
        price_prev2 = price_prev1;
        price_prev1 = price;

        count += 1;
        i += 1;
    }
}
764
765#[inline]
766pub unsafe fn emd_scalar(
767 high: &[f64],
768 low: &[f64],
769 period: usize,
770 delta: f64,
771 fraction: f64,
772 first: usize,
773 len: usize,
774) -> Result<EmdOutput, EmdError> {
775 let per_up_low = 50;
776 let per_mid = 2 * period;
777 let upperband_warmup = first + per_up_low - 1;
778 let middleband_warmup = first + per_mid - 1;
779
780 let mut upperband = alloc_with_nan_prefix(len, upperband_warmup);
781 let mut middleband = alloc_with_nan_prefix(len, middleband_warmup);
782 let mut lowerband = alloc_with_nan_prefix(len, upperband_warmup);
783
784 emd_scalar_into(
785 high,
786 low,
787 period,
788 delta,
789 fraction,
790 first,
791 &mut upperband,
792 &mut middleband,
793 &mut lowerband,
794 );
795
796 Ok(EmdOutput {
797 upperband,
798 middleband,
799 lowerband,
800 })
801}
802
/// AVX2 entry point; currently forwards unchanged to [`emd_scalar`].
///
/// # Safety
/// Same contract as `emd_scalar`, to which all arguments are passed through.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn emd_avx2(
    high: &[f64],
    low: &[f64],
    period: usize,
    delta: f64,
    fraction: f64,
    first: usize,
    len: usize,
) -> Result<EmdOutput, EmdError> {
    emd_scalar(high, low, period, delta, fraction, first, len)
}
816
/// AVX-512 entry point: routes periods up to 32 to the "short" variant and
/// larger periods to the "long" variant.
///
/// # Safety
/// Same contract as the variant it forwards to.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn emd_avx512(
    high: &[f64],
    low: &[f64],
    period: usize,
    delta: f64,
    fraction: f64,
    first: usize,
    len: usize,
) -> Result<EmdOutput, EmdError> {
    match period {
        0..=32 => emd_avx512_short(high, low, period, delta, fraction, first, len),
        _ => emd_avx512_long(high, low, period, delta, fraction, first, len),
    }
}
834
/// AVX-512 short-period variant; currently forwards unchanged to
/// [`emd_scalar`].
///
/// # Safety
/// Same contract as `emd_scalar`, to which all arguments are passed through.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn emd_avx512_short(
    high: &[f64],
    low: &[f64],
    period: usize,
    delta: f64,
    fraction: f64,
    first: usize,
    len: usize,
) -> Result<EmdOutput, EmdError> {
    emd_scalar(high, low, period, delta, fraction, first, len)
}
848
/// AVX-512 long-period variant; currently forwards unchanged to
/// [`emd_scalar`].
///
/// # Safety
/// Same contract as `emd_scalar`, to which all arguments are passed through.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn emd_avx512_long(
    high: &[f64],
    low: &[f64],
    period: usize,
    delta: f64,
    fraction: f64,
    first: usize,
    len: usize,
) -> Result<EmdOutput, EmdError> {
    emd_scalar(high, low, period, delta, fraction, first, len)
}
862
/// Safe-signature wrapper that computes one row via [`emd_scalar`].
///
/// NOTE(review): the arguments are forwarded into an `unsafe fn` without any
/// check in this function; whether that is sound depends on `emd_scalar`
/// validating `period` and the slice lengths against `len`.
#[inline(always)]
pub fn emd_row_scalar(
    high: &[f64],
    low: &[f64],
    period: usize,
    delta: f64,
    fraction: f64,
    first: usize,
    len: usize,
) -> Result<EmdOutput, EmdError> {
    unsafe { emd_scalar(high, low, period, delta, fraction, first, len) }
}
875
/// Safe-signature wrapper that computes one row via [`emd_avx2`].
///
/// NOTE(review): the arguments are forwarded into an `unsafe fn` without any
/// check in this function; soundness depends on the callee validating them.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn emd_row_avx2(
    high: &[f64],
    low: &[f64],
    period: usize,
    delta: f64,
    fraction: f64,
    first: usize,
    len: usize,
) -> Result<EmdOutput, EmdError> {
    unsafe { emd_avx2(high, low, period, delta, fraction, first, len) }
}
889
/// Safe-signature wrapper that computes one row via [`emd_avx512`].
///
/// NOTE(review): the arguments are forwarded into an `unsafe fn` without any
/// check in this function; soundness depends on the callee validating them.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn emd_row_avx512(
    high: &[f64],
    low: &[f64],
    period: usize,
    delta: f64,
    fraction: f64,
    first: usize,
    len: usize,
) -> Result<EmdOutput, EmdError> {
    unsafe { emd_avx512(high, low, period, delta, fraction, first, len) }
}
903
/// Safe-signature wrapper that computes one row via [`emd_avx512_short`].
///
/// NOTE(review): the arguments are forwarded into an `unsafe fn` without any
/// check in this function; soundness depends on the callee validating them.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn emd_row_avx512_short(
    high: &[f64],
    low: &[f64],
    period: usize,
    delta: f64,
    fraction: f64,
    first: usize,
    len: usize,
) -> Result<EmdOutput, EmdError> {
    unsafe { emd_avx512_short(high, low, period, delta, fraction, first, len) }
}
917
/// Safe-signature wrapper that computes one row via [`emd_avx512_long`].
///
/// NOTE(review): the arguments are forwarded into an `unsafe fn` without any
/// check in this function; soundness depends on the callee validating them.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn emd_row_avx512_long(
    high: &[f64],
    low: &[f64],
    period: usize,
    delta: f64,
    fraction: f64,
    first: usize,
    len: usize,
) -> Result<EmdOutput, EmdError> {
    unsafe { emd_avx512_long(high, low, period, delta, fraction, first, len) }
}
931
/// Incremental EMD calculator fed one high/low pair at a time.
#[derive(Debug, Clone)]
pub struct EmdStream {
    // Resolved parameters.
    period: usize,
    delta: f64,
    fraction: f64,
    // Window lengths: 50 for the outer bands, `2 * period` for the middle band.
    per_up_low: usize,
    per_mid: usize,

    // Reciprocals of the window lengths.
    inv_up_low: f64,
    inv_mid: f64,
    // Running sums over the three ring buffers.
    sum_up: f64,
    sum_low: f64,
    sum_mb: f64,
    // Ring buffers: scaled peaks, scaled valleys and band-pass values.
    sp_ring: Vec<f64>,
    sv_ring: Vec<f64>,
    bp_ring: Vec<f64>,
    // Next write position in the outer-band rings / the band-pass ring.
    idx_up_low: usize,
    idx_mid: usize,
    // Previous two band-pass values.
    bp_prev1: f64,
    bp_prev2: f64,
    // Most recently recorded peak and valley.
    peak_prev: f64,
    valley_prev: f64,
    // Previous two prices.
    price_prev1: f64,
    price_prev2: f64,
    // Filter coefficients computed in `try_new`.
    alpha: f64,
    beta: f64,

    // Precomputed `beta * (1 + alpha)` and `0.5 * (1 - alpha)`.
    beta_times_one_plus_alpha: f64,
    half_one_minus_alpha: f64,
    // Set once the first sample has seeded the history fields.
    initialized: bool,
    // Number of samples consumed so far.
    count: usize,
}
964
impl EmdStream {
    /// Creates a stream from `params`, using 20 / 0.5 / 0.1 for unset fields.
    ///
    /// # Errors
    /// * [`EmdError::InvalidPeriod`] when the period is zero.
    /// * [`EmdError::InvalidDelta`] / [`EmdError::InvalidFraction`] when the
    ///   value is NaN or infinite.
    pub fn try_new(params: EmdParams) -> Result<Self, EmdError> {
        let period = params.period.unwrap_or(20);
        let delta = params.delta.unwrap_or(0.5);
        let fraction = params.fraction.unwrap_or(0.1);

        if period == 0 {
            return Err(EmdError::InvalidPeriod {
                period,
                data_len: 0,
            });
        }
        if delta.is_nan() || delta.is_infinite() {
            return Err(EmdError::InvalidDelta { delta });
        }
        if fraction.is_nan() || fraction.is_infinite() {
            return Err(EmdError::InvalidFraction { fraction });
        }

        // Band-pass filter coefficients derived from `period` and `delta`.
        let two_pi_over_p = 2.0 * std::f64::consts::PI / (period as f64);
        let beta = (two_pi_over_p).cos();
        let gamma = 1.0 / ((2.0 * delta * two_pi_over_p).cos());
        let alpha = gamma - (gamma * gamma - 1.0).sqrt();
        let half_one_minus_alpha = 0.5 * (1.0 - alpha);
        let beta_times_one_plus_alpha = beta * (1.0 + alpha);
        let per_up_low = 50usize;
        let per_mid = 2 * period;

        Ok(Self {
            period,
            delta,
            fraction,
            per_up_low,
            per_mid,
            inv_up_low: 1.0 / (per_up_low as f64),
            inv_mid: 1.0 / (per_mid as f64),
            sum_up: 0.0,
            sum_low: 0.0,
            sum_mb: 0.0,
            sp_ring: vec![0.0; per_up_low],
            sv_ring: vec![0.0; per_up_low],
            bp_ring: vec![0.0; per_mid],
            idx_up_low: 0,
            idx_mid: 0,
            bp_prev1: 0.0,
            bp_prev2: 0.0,
            peak_prev: 0.0,
            valley_prev: 0.0,
            price_prev1: 0.0,
            price_prev2: 0.0,
            alpha,
            beta,
            beta_times_one_plus_alpha,
            half_one_minus_alpha,
            initialized: false,
            count: 0,
        })
    }

    /// Consumes one high/low pair and returns `(upper, middle, lower)`.
    ///
    /// The outer bands are `Some` from the 50th sample onward and the middle
    /// band from sample `2 * period` onward; before that they are `None`.
    /// The price is `(high + low) / 2`; inputs are not checked for NaN.
    #[inline(always)]
    pub fn update(&mut self, high: f64, low: f64) -> (Option<f64>, Option<f64>, Option<f64>) {
        let price = (high + low) * 0.5;

        // Seed all history with the first price seen.
        if !self.initialized {
            self.bp_prev1 = price;
            self.bp_prev2 = price;
            self.peak_prev = price;
            self.valley_prev = price;
            self.price_prev1 = price;
            self.price_prev2 = price;
            self.initialized = true;
        }
        // The recursion needs two samples of history; before that, pass the
        // price through unchanged. Note the fused multiply-add here.
        let bp_curr = if self.count >= 2 {
            self.half_one_minus_alpha * (price - self.price_prev2)
                + self
                    .beta_times_one_plus_alpha
                    .mul_add(self.bp_prev1, -self.alpha * self.bp_prev2)
        } else {
            price
        };
        // Carry the last peak/valley forward unless the previous band-pass
        // value is a strict local extremum.
        let mut peak_curr = self.peak_prev;
        let mut valley_curr = self.valley_prev;
        if self.count >= 2 {
            if self.bp_prev1 > bp_curr && self.bp_prev1 > self.bp_prev2 {
                peak_curr = self.bp_prev1;
            }
            if self.bp_prev1 < bp_curr && self.bp_prev1 < self.bp_prev2 {
                valley_curr = self.bp_prev1;
            }
        }
        let sp = peak_curr * self.fraction;
        let sv = valley_curr * self.fraction;

        // Replace the oldest ring entries and update the running sums.
        let old_sp = self.sp_ring[self.idx_up_low];
        let old_sv = self.sv_ring[self.idx_up_low];
        let old_bp = self.bp_ring[self.idx_mid];
        self.sum_up += sp - old_sp;
        self.sum_low += sv - old_sv;
        self.sum_mb += bp_curr - old_bp;
        self.sp_ring[self.idx_up_low] = sp;
        self.sv_ring[self.idx_up_low] = sv;
        self.bp_ring[self.idx_mid] = bp_curr;

        self.idx_up_low += 1;
        if self.idx_up_low == self.per_up_low {
            self.idx_up_low = 0;
        }
        self.idx_mid += 1;
        if self.idx_mid == self.per_mid {
            self.idx_mid = 0;
        }
        // Emit a band only once its window has been filled.
        let mut ub = None;
        let mut lb = None;
        let mut mb = None;
        if self.count + 1 >= self.per_up_low {
            ub = Some(self.sum_up * self.inv_up_low);
            lb = Some(self.sum_low * self.inv_up_low);
        }
        if self.count + 1 >= self.per_mid {
            mb = Some(self.sum_mb * self.inv_mid);
        }
        // Shift history for the next sample.
        self.bp_prev2 = self.bp_prev1;
        self.bp_prev1 = bp_curr;
        self.peak_prev = peak_curr;
        self.valley_prev = valley_curr;
        self.price_prev2 = self.price_prev1;
        self.price_prev1 = price;
        self.count += 1;
        (ub, mb, lb)
    }
}
1096
/// Parameter sweep for batch runs; each axis is `(start, end, step)`.
///
/// A zero step (or `start == end`) yields the single value `start`.
#[derive(Clone, Debug)]
pub struct EmdBatchRange {
    /// Period axis.
    pub period: (usize, usize, usize),
    /// Delta axis.
    pub delta: (f64, f64, f64),
    /// Fraction axis.
    pub fraction: (f64, f64, f64),
}
1103
1104impl Default for EmdBatchRange {
1105 fn default() -> Self {
1106 Self {
1107 period: (20, 269, 1),
1108 delta: (0.5, 0.5, 0.0),
1109 fraction: (0.1, 0.1, 0.0),
1110 }
1111 }
1112}
1113
/// Fluent builder for a batch (parameter-sweep) EMD run.
#[derive(Clone, Debug, Default)]
pub struct EmdBatchBuilder {
    /// Sweep axes for period, delta and fraction.
    range: EmdBatchRange,
    /// Kernel passed to `emd_batch_with_kernel`.
    kernel: Kernel,
}
1119
1120impl EmdBatchBuilder {
1121 pub fn new() -> Self {
1122 Self::default()
1123 }
1124 pub fn kernel(mut self, k: Kernel) -> Self {
1125 self.kernel = k;
1126 self
1127 }
1128 #[inline]
1129 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
1130 self.range.period = (start, end, step);
1131 self
1132 }
1133 #[inline]
1134 pub fn period_static(mut self, p: usize) -> Self {
1135 self.range.period = (p, p, 0);
1136 self
1137 }
1138 #[inline]
1139 pub fn delta_range(mut self, start: f64, end: f64, step: f64) -> Self {
1140 self.range.delta = (start, end, step);
1141 self
1142 }
1143 #[inline]
1144 pub fn delta_static(mut self, x: f64) -> Self {
1145 self.range.delta = (x, x, 0.0);
1146 self
1147 }
1148 #[inline]
1149 pub fn fraction_range(mut self, start: f64, end: f64, step: f64) -> Self {
1150 self.range.fraction = (start, end, step);
1151 self
1152 }
1153 #[inline]
1154 pub fn fraction_static(mut self, x: f64) -> Self {
1155 self.range.fraction = (x, x, 0.0);
1156 self
1157 }
1158 pub fn apply_slices(
1159 self,
1160 high: &[f64],
1161 low: &[f64],
1162 close: &[f64],
1163 volume: &[f64],
1164 ) -> Result<EmdBatchOutput, EmdError> {
1165 emd_batch_with_kernel(high, low, &self.range, self.kernel)
1166 }
1167 pub fn with_default_slices(
1168 high: &[f64],
1169 low: &[f64],
1170 close: &[f64],
1171 volume: &[f64],
1172 k: Kernel,
1173 ) -> Result<EmdBatchOutput, EmdError> {
1174 EmdBatchBuilder::new()
1175 .kernel(k)
1176 .apply_slices(high, low, close, volume)
1177 }
1178 pub fn apply_candles(self, c: &Candles) -> Result<EmdBatchOutput, EmdError> {
1179 let high = source_type(c, "high");
1180 let low = source_type(c, "low");
1181 let close = source_type(c, "close");
1182 let volume = source_type(c, "volume");
1183 self.apply_slices(high, low, close, volume)
1184 }
1185}
1186
1187pub fn emd_batch_with_kernel(
1188 high: &[f64],
1189 low: &[f64],
1190 sweep: &EmdBatchRange,
1191 k: Kernel,
1192) -> Result<EmdBatchOutput, EmdError> {
1193 let kernel = match k {
1194 Kernel::Auto => Kernel::ScalarBatch,
1195 other if other.is_batch() => other,
1196 _ => {
1197 return Err(EmdError::InvalidKernelForBatch(k));
1198 }
1199 };
1200 emd_batch_par_slice(high, low, sweep, kernel)
1201}
1202
/// Result of a batch sweep.
///
/// NOTE(review): the code that fills this struct is outside the part of the
/// file reviewed here; the layout notes below are assumptions to confirm.
#[derive(Clone, Debug)]
pub struct EmdBatchOutput {
    /// Upper-band values (presumably `rows * cols`, one row per combo).
    pub upperband: Vec<f64>,
    /// Middle-band values (presumably the same layout).
    pub middleband: Vec<f64>,
    /// Lower-band values (presumably the same layout).
    pub lowerband: Vec<f64>,
    /// Parameter combinations that were evaluated.
    pub combos: Vec<EmdParams>,
    /// Row count (presumably `combos.len()`).
    pub rows: usize,
    /// Column count (presumably the input length).
    pub cols: usize,
}
1212
1213#[inline(always)]
1214fn expand_grid(r: &EmdBatchRange) -> Result<Vec<EmdParams>, EmdError> {
1215 fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, EmdError> {
1216 if step == 0 || start == end {
1217 return Ok(vec![start]);
1218 }
1219 let mut v = Vec::new();
1220 if start < end {
1221 let mut cur = start;
1222 while cur <= end {
1223 v.push(cur);
1224 cur = match cur.checked_add(step) {
1225 Some(n) => n,
1226 None => break,
1227 };
1228 }
1229 } else {
1230 let mut cur = start;
1231 while cur >= end {
1232 v.push(cur);
1233 match cur.checked_sub(step) {
1234 Some(n) => cur = n,
1235 None => break,
1236 }
1237 }
1238 }
1239 if v.is_empty() {
1240 Err(EmdError::InvalidRangeU { start, end, step })
1241 } else {
1242 Ok(v)
1243 }
1244 }
1245 fn axis_f64((start, end, step): (f64, f64, f64)) -> Result<Vec<f64>, EmdError> {
1246 if step.abs() < 1e-12 || (start - end).abs() < 1e-12 {
1247 return Ok(vec![start]);
1248 }
1249 let mut v = Vec::new();
1250 if start < end {
1251 let mut x = start;
1252 while x <= end + 1e-12 {
1253 v.push(x);
1254 x += step;
1255 if !x.is_finite() {
1256 break;
1257 }
1258 }
1259 } else {
1260 let mut x = start;
1261 while x >= end - 1e-12 {
1262 v.push(x);
1263 x -= step.abs();
1264 if !x.is_finite() {
1265 break;
1266 }
1267 }
1268 }
1269 if v.is_empty() {
1270 Err(EmdError::InvalidRangeF { start, end, step })
1271 } else {
1272 Ok(v)
1273 }
1274 }
1275
1276 let periods = axis_usize(r.period)?;
1277 let deltas = axis_f64(r.delta)?;
1278 let fractions = axis_f64(r.fraction)?;
1279
1280 let cap = periods
1281 .len()
1282 .checked_mul(deltas.len())
1283 .and_then(|t| t.checked_mul(fractions.len()))
1284 .ok_or(EmdError::InvalidRangeU {
1285 start: 0,
1286 end: 0,
1287 step: 0,
1288 })?;
1289 let mut out = Vec::with_capacity(cap);
1290 for &p in &periods {
1291 for &d in &deltas {
1292 for &f in &fractions {
1293 out.push(EmdParams {
1294 period: Some(p),
1295 delta: Some(d),
1296 fraction: Some(f),
1297 });
1298 }
1299 }
1300 }
1301 Ok(out)
1302}
1303
1304#[inline(always)]
1305pub fn emd_batch_slice(
1306 high: &[f64],
1307 low: &[f64],
1308 sweep: &EmdBatchRange,
1309 kern: Kernel,
1310) -> Result<EmdBatchOutput, EmdError> {
1311 if !kern.is_batch() && kern != Kernel::Auto {
1312 return Err(EmdError::InvalidKernelForBatch(kern));
1313 }
1314 emd_batch_inner(high, low, sweep, kern, false)
1315}
1316
1317#[inline(always)]
1318pub fn emd_batch_par_slice(
1319 high: &[f64],
1320 low: &[f64],
1321 sweep: &EmdBatchRange,
1322 kern: Kernel,
1323) -> Result<EmdBatchOutput, EmdError> {
1324 if !kern.is_batch() && kern != Kernel::Auto {
1325 return Err(EmdError::InvalidKernelForBatch(kern));
1326 }
1327 emd_batch_inner(high, low, sweep, kern, true)
1328}
1329
1330#[inline(always)]
1331fn emd_batch_inner(
1332 high: &[f64],
1333 low: &[f64],
1334 sweep: &EmdBatchRange,
1335 kern: Kernel,
1336 parallel: bool,
1337) -> Result<EmdBatchOutput, EmdError> {
1338 let combos = expand_grid(sweep)?;
1339 if combos.is_empty() {
1340 return Err(EmdError::InvalidRangeU {
1341 start: 0,
1342 end: 0,
1343 step: 0,
1344 });
1345 }
1346
1347 let len = high.len();
1348 if low.len() != len {
1349 return Err(EmdError::InvalidInputLength {
1350 expected: len,
1351 actual: low.len(),
1352 });
1353 }
1354
1355 let first = (0..len)
1356 .find(|&i| !high[i].is_nan() && !low[i].is_nan())
1357 .ok_or(EmdError::AllValuesNaN)?;
1358 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1359 let needed = (2 * max_p).max(50);
1360 if len - first < needed {
1361 return Err(EmdError::NotEnoughValidData {
1362 needed,
1363 valid: len - first,
1364 });
1365 }
1366
1367 let rows = combos.len();
1368 let cols = len;
1369 let total = rows.checked_mul(cols).ok_or(EmdError::InvalidRangeU {
1370 start: rows,
1371 end: cols,
1372 step: usize::MAX,
1373 })?;
1374
1375 let prices: Vec<f64> = high
1376 .iter()
1377 .zip(low.iter())
1378 .map(|(&h, &l)| (h + l) * 0.5)
1379 .collect();
1380
1381 let mut ub_mu = make_uninit_matrix(rows, cols);
1382 let mut mb_mu = make_uninit_matrix(rows, cols);
1383 let mut lb_mu = make_uninit_matrix(rows, cols);
1384
1385 let warm_up_low: Vec<usize> = combos.iter().map(|_| first + 50 - 1).collect();
1386 let warm_mid: Vec<usize> = combos
1387 .iter()
1388 .map(|c| first + 2 * c.period.unwrap() - 1)
1389 .collect();
1390
1391 init_matrix_prefixes(&mut ub_mu, cols, &warm_up_low);
1392 init_matrix_prefixes(&mut mb_mu, cols, &warm_mid);
1393 init_matrix_prefixes(&mut lb_mu, cols, &warm_up_low);
1394
1395 let ub_ptr = ub_mu.as_mut_ptr() as *mut f64 as usize;
1396 let mb_ptr = mb_mu.as_mut_ptr() as *mut f64 as usize;
1397 let lb_ptr = lb_mu.as_mut_ptr() as *mut f64 as usize;
1398
1399 let simd = match kern {
1400 Kernel::Auto => match detect_best_batch_kernel() {
1401 Kernel::Avx512Batch => Kernel::Avx512,
1402 Kernel::Avx2Batch => Kernel::Avx2,
1403 Kernel::ScalarBatch => Kernel::Scalar,
1404 _ => Kernel::Scalar,
1405 },
1406 Kernel::Avx512 | Kernel::Avx512Batch => Kernel::Avx512,
1407 Kernel::Avx2 | Kernel::Avx2Batch => Kernel::Avx2,
1408 Kernel::Scalar | Kernel::ScalarBatch => Kernel::Scalar,
1409 _ => Kernel::Scalar,
1410 };
1411
1412 if parallel {
1413 #[cfg(not(target_arch = "wasm32"))]
1414 (0..rows).into_par_iter().for_each(|row| {
1415 let prm = &combos[row];
1416 let p = prm.period.unwrap();
1417 let d = prm.delta.unwrap();
1418 let f = prm.fraction.unwrap();
1419
1420 let ub = unsafe {
1421 std::slice::from_raw_parts_mut((ub_ptr as *mut f64).add(row * cols), cols)
1422 };
1423 let mb = unsafe {
1424 std::slice::from_raw_parts_mut((mb_ptr as *mut f64).add(row * cols), cols)
1425 };
1426 let lb = unsafe {
1427 std::slice::from_raw_parts_mut((lb_ptr as *mut f64).add(row * cols), cols)
1428 };
1429
1430 emd_compute_from_prices_into(&prices, p, d, f, first, simd, ub, mb, lb);
1431 });
1432 #[cfg(target_arch = "wasm32")]
1433 {
1434 let ub_rows = unsafe { std::slice::from_raw_parts_mut(ub_ptr as *mut f64, total) };
1435 let mb_rows = unsafe { std::slice::from_raw_parts_mut(mb_ptr as *mut f64, total) };
1436 let lb_rows = unsafe { std::slice::from_raw_parts_mut(lb_ptr as *mut f64, total) };
1437 for row in 0..rows {
1438 let prm = &combos[row];
1439 let p = prm.period.unwrap();
1440 let d = prm.delta.unwrap();
1441 let f = prm.fraction.unwrap();
1442
1443 let ub = &mut ub_rows[row * cols..(row + 1) * cols];
1444 let mb = &mut mb_rows[row * cols..(row + 1) * cols];
1445 let lb = &mut lb_rows[row * cols..(row + 1) * cols];
1446
1447 emd_compute_from_prices_into(&prices, p, d, f, first, simd, ub, mb, lb);
1448 }
1449 }
1450 } else {
1451 let ub_rows = unsafe { std::slice::from_raw_parts_mut(ub_ptr as *mut f64, total) };
1452 let mb_rows = unsafe { std::slice::from_raw_parts_mut(mb_ptr as *mut f64, total) };
1453 let lb_rows = unsafe { std::slice::from_raw_parts_mut(lb_ptr as *mut f64, total) };
1454 for row in 0..rows {
1455 let prm = &combos[row];
1456 let p = prm.period.unwrap();
1457 let d = prm.delta.unwrap();
1458 let f = prm.fraction.unwrap();
1459
1460 let ub = &mut ub_rows[row * cols..(row + 1) * cols];
1461 let mb = &mut mb_rows[row * cols..(row + 1) * cols];
1462 let lb = &mut lb_rows[row * cols..(row + 1) * cols];
1463
1464 emd_compute_from_prices_into(&prices, p, d, f, first, simd, ub, mb, lb);
1465 }
1466 }
1467
1468 let upperband = unsafe { Vec::from_raw_parts(ub_mu.as_mut_ptr() as *mut f64, total, total) };
1469 let middleband = unsafe { Vec::from_raw_parts(mb_mu.as_mut_ptr() as *mut f64, total, total) };
1470 let lowerband = unsafe { Vec::from_raw_parts(lb_mu.as_mut_ptr() as *mut f64, total, total) };
1471
1472 std::mem::forget(ub_mu);
1473 std::mem::forget(mb_mu);
1474 std::mem::forget(lb_mu);
1475
1476 Ok(EmdBatchOutput {
1477 upperband,
1478 middleband,
1479 lowerband,
1480 combos,
1481 rows,
1482 cols,
1483 })
1484}
1485
1486#[inline(always)]
1487fn emd_batch_inner_into(
1488 high: &[f64],
1489 low: &[f64],
1490 sweep: &EmdBatchRange,
1491 kern: Kernel,
1492 parallel: bool,
1493 upperband_out: &mut [f64],
1494 middleband_out: &mut [f64],
1495 lowerband_out: &mut [f64],
1496) -> Result<Vec<EmdParams>, EmdError> {
1497 let combos = expand_grid(sweep)?;
1498 if combos.is_empty() {
1499 return Err(EmdError::InvalidRangeU {
1500 start: 0,
1501 end: 0,
1502 step: 0,
1503 });
1504 }
1505 let len = high.len();
1506 if low.len() != len {
1507 return Err(EmdError::InvalidInputLength {
1508 expected: len,
1509 actual: low.len(),
1510 });
1511 }
1512
1513 let rows = combos.len();
1514 let cols = len;
1515 let expected = rows.checked_mul(cols).ok_or(EmdError::InvalidRangeU {
1516 start: rows,
1517 end: cols,
1518 step: usize::MAX,
1519 })?;
1520 if upperband_out.len() != expected
1521 || middleband_out.len() != expected
1522 || lowerband_out.len() != expected
1523 {
1524 return Err(EmdError::OutputLengthMismatch {
1525 expected,
1526 got: upperband_out
1527 .len()
1528 .min(middleband_out.len())
1529 .min(lowerband_out.len()),
1530 });
1531 }
1532
1533 let first = (0..len)
1534 .find(|&i| !high[i].is_nan() && !low[i].is_nan())
1535 .ok_or(EmdError::AllValuesNaN)?;
1536 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1537 let needed = (2 * max_p).max(50);
1538 if len - first < needed {
1539 return Err(EmdError::NotEnoughValidData {
1540 needed,
1541 valid: len - first,
1542 });
1543 }
1544
1545 {
1546 let mut ub_mu = unsafe {
1547 std::slice::from_raw_parts_mut(
1548 upperband_out.as_mut_ptr() as *mut MaybeUninit<f64>,
1549 rows * cols,
1550 )
1551 };
1552 let mut mb_mu = unsafe {
1553 std::slice::from_raw_parts_mut(
1554 middleband_out.as_mut_ptr() as *mut MaybeUninit<f64>,
1555 rows * cols,
1556 )
1557 };
1558 let mut lb_mu = unsafe {
1559 std::slice::from_raw_parts_mut(
1560 lowerband_out.as_mut_ptr() as *mut MaybeUninit<f64>,
1561 rows * cols,
1562 )
1563 };
1564
1565 let warm_up_low: Vec<usize> = combos.iter().map(|_| first + 50 - 1).collect();
1566 let warm_mid: Vec<usize> = combos
1567 .iter()
1568 .map(|c| first + 2 * c.period.unwrap() - 1)
1569 .collect();
1570
1571 init_matrix_prefixes(&mut ub_mu, cols, &warm_up_low);
1572 init_matrix_prefixes(&mut mb_mu, cols, &warm_mid);
1573 init_matrix_prefixes(&mut lb_mu, cols, &warm_up_low);
1574 }
1575
1576 let ub_ptr = upperband_out.as_mut_ptr() as usize;
1577 let mb_ptr = middleband_out.as_mut_ptr() as usize;
1578 let lb_ptr = lowerband_out.as_mut_ptr() as usize;
1579
1580 let simd = match kern {
1581 Kernel::Auto => match detect_best_batch_kernel() {
1582 Kernel::Avx512Batch => Kernel::Avx512,
1583 Kernel::Avx2Batch => Kernel::Avx2,
1584 Kernel::ScalarBatch => Kernel::Scalar,
1585 _ => Kernel::Scalar,
1586 },
1587 Kernel::Avx512 | Kernel::Avx512Batch => Kernel::Avx512,
1588 Kernel::Avx2 | Kernel::Avx2Batch => Kernel::Avx2,
1589 Kernel::Scalar | Kernel::ScalarBatch => Kernel::Scalar,
1590 _ => Kernel::Scalar,
1591 };
1592
1593 if parallel {
1594 #[cfg(not(target_arch = "wasm32"))]
1595 (0..rows).into_par_iter().for_each(|row| {
1596 let prm = &combos[row];
1597 let p = prm.period.unwrap();
1598 let d = prm.delta.unwrap();
1599 let f = prm.fraction.unwrap();
1600
1601 let ub = unsafe {
1602 std::slice::from_raw_parts_mut((ub_ptr as *mut f64).add(row * cols), cols)
1603 };
1604 let mb = unsafe {
1605 std::slice::from_raw_parts_mut((mb_ptr as *mut f64).add(row * cols), cols)
1606 };
1607 let lb = unsafe {
1608 std::slice::from_raw_parts_mut((lb_ptr as *mut f64).add(row * cols), cols)
1609 };
1610
1611 emd_compute_into(high, low, p, d, f, first, simd, ub, mb, lb);
1612 });
1613 #[cfg(target_arch = "wasm32")]
1614 for row in 0..rows {
1615 let prm = &combos[row];
1616 let p = prm.period.unwrap();
1617 let d = prm.delta.unwrap();
1618 let f = prm.fraction.unwrap();
1619
1620 let ub = &mut upperband_out[row * cols..(row + 1) * cols];
1621 let mb = &mut middleband_out[row * cols..(row + 1) * cols];
1622 let lb = &mut lowerband_out[row * cols..(row + 1) * cols];
1623
1624 emd_compute_into(high, low, p, d, f, first, simd, ub, mb, lb);
1625 }
1626 } else {
1627 for row in 0..rows {
1628 let prm = &combos[row];
1629 let p = prm.period.unwrap();
1630 let d = prm.delta.unwrap();
1631 let f = prm.fraction.unwrap();
1632
1633 let ub = &mut upperband_out[row * cols..(row + 1) * cols];
1634 let mb = &mut middleband_out[row * cols..(row + 1) * cols];
1635 let lb = &mut lowerband_out[row * cols..(row + 1) * cols];
1636
1637 emd_compute_into(high, low, p, d, f, first, simd, ub, mb, lb);
1638 }
1639 }
1640
1641 Ok(combos)
1642}
1643
1644impl EmdBatchOutput {
1645 pub fn row_for_params(&self, p: &EmdParams) -> Option<usize> {
1646 self.combos.iter().position(|c| {
1647 c.period.unwrap_or(20) == p.period.unwrap_or(20)
1648 && (c.delta.unwrap_or(0.5) - p.delta.unwrap_or(0.5)).abs() < 1e-12
1649 && (c.fraction.unwrap_or(0.1) - p.fraction.unwrap_or(0.1)).abs() < 1e-12
1650 })
1651 }
1652 pub fn bands_for(&self, p: &EmdParams) -> Option<(&[f64], &[f64], &[f64])> {
1653 self.row_for_params(p).map(|row| {
1654 let start = row * self.cols;
1655 (
1656 &self.upperband[start..start + self.cols],
1657 &self.middleband[start..start + self.cols],
1658 &self.lowerband[start..start + self.cols],
1659 )
1660 })
1661 }
1662}
1663
1664#[cfg(test)]
1665mod tests {
1666 use super::*;
1667 use crate::skip_if_unsupported;
1668 use crate::utilities::data_loader::read_candles_from_csv;
1669
1670 #[test]
1671 #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1672 fn test_emd_into_matches_api() -> Result<(), Box<dyn Error>> {
1673 let n = 256usize;
1674 let mut high = Vec::with_capacity(n);
1675 let mut low = Vec::with_capacity(n);
1676 for i in 0..n {
1677 let base = 100.0
1678 + (i as f64 * 0.01)
1679 + (2.0 * std::f64::consts::PI * (i as f64) / 17.0).sin() * 3.0
1680 + (2.0 * std::f64::consts::PI * (i as f64) / 49.0).cos() * 2.0;
1681 high.push(base + 1.25);
1682 low.push(base - 1.10);
1683 }
1684
1685 let params = EmdParams::default();
1686 let input = EmdInput::from_slices(&high, &low, &[], &[], params);
1687
1688 let baseline = emd(&input)?;
1689
1690 let mut ub = vec![0.0; n];
1691 let mut mb = vec![0.0; n];
1692 let mut lb = vec![0.0; n];
1693 emd_into(&input, &mut ub, &mut mb, &mut lb)?;
1694
1695 fn eq_or_both_nan(a: f64, b: f64) -> bool {
1696 (a.is_nan() && b.is_nan()) || (a == b)
1697 }
1698
1699 assert_eq!(baseline.upperband.len(), ub.len());
1700 assert_eq!(baseline.middleband.len(), mb.len());
1701 assert_eq!(baseline.lowerband.len(), lb.len());
1702 for i in 0..n {
1703 assert!(
1704 eq_or_both_nan(baseline.upperband[i], ub[i]),
1705 "upperband mismatch at {}: {:?} vs {:?}",
1706 i,
1707 baseline.upperband[i],
1708 ub[i]
1709 );
1710 assert!(
1711 eq_or_both_nan(baseline.middleband[i], mb[i]),
1712 "middleband mismatch at {}: {:?} vs {:?}",
1713 i,
1714 baseline.middleband[i],
1715 mb[i]
1716 );
1717 assert!(
1718 eq_or_both_nan(baseline.lowerband[i], lb[i]),
1719 "lowerband mismatch at {}: {:?} vs {:?}",
1720 i,
1721 baseline.lowerband[i],
1722 lb[i]
1723 );
1724 }
1725 Ok(())
1726 }
1727
1728 fn check_emd_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1729 skip_if_unsupported!(kernel, test_name);
1730 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1731 let candles = read_candles_from_csv(file_path)?;
1732
1733 let params = EmdParams::default();
1734 let input = EmdInput::from_candles(&candles, params);
1735 let emd_result = emd_with_kernel(&input, kernel)?;
1736
1737 let expected_last_five_upper = [
1738 50.33760237677157,
1739 50.28850695686447,
1740 50.23941153695737,
1741 50.19031611705027,
1742 48.709744457737344,
1743 ];
1744 let expected_last_five_middle = [
1745 -368.71064280396706,
1746 -399.11033986231377,
1747 -421.9368852621732,
1748 -437.879217150269,
1749 -447.3257167904511,
1750 ];
1751 let expected_last_five_lower = [
1752 -60.67834136221248,
1753 -60.93110347122829,
1754 -61.68154077026321,
1755 -62.43197806929814,
1756 -63.18241536833306,
1757 ];
1758
1759 let len = candles.close.len();
1760 let start_idx = len - 5;
1761 let actual_ub = &emd_result.upperband[start_idx..];
1762 let actual_mb = &emd_result.middleband[start_idx..];
1763 let actual_lb = &emd_result.lowerband[start_idx..];
1764 for i in 0..5 {
1765 assert!(
1766 (actual_ub[i] - expected_last_five_upper[i]).abs() < 1e-6,
1767 "Upperband mismatch at index {}: expected {}, got {}",
1768 i,
1769 expected_last_five_upper[i],
1770 actual_ub[i]
1771 );
1772 assert!(
1773 (actual_mb[i] - expected_last_five_middle[i]).abs() < 1e-6,
1774 "Middleband mismatch at index {}: expected {}, got {}",
1775 i,
1776 expected_last_five_middle[i],
1777 actual_mb[i]
1778 );
1779 assert!(
1780 (actual_lb[i] - expected_last_five_lower[i]).abs() < 1e-6,
1781 "Lowerband mismatch at index {}: expected {}, got {}",
1782 i,
1783 expected_last_five_lower[i],
1784 actual_lb[i]
1785 );
1786 }
1787 Ok(())
1788 }
1789
1790 fn check_emd_empty_data(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1791 skip_if_unsupported!(kernel, test_name);
1792 let empty_data: [f64; 0] = [];
1793 let params = EmdParams::default();
1794 let input =
1795 EmdInput::from_slices(&empty_data, &empty_data, &empty_data, &empty_data, params);
1796 let result = emd_with_kernel(&input, kernel);
1797 assert!(result.is_err(), "Expected error on empty data");
1798 Ok(())
1799 }
1800
1801 fn check_emd_all_nans(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1802 skip_if_unsupported!(kernel, test_name);
1803 let data = [f64::NAN, f64::NAN, f64::NAN];
1804 let params = EmdParams::default();
1805 let input = EmdInput::from_slices(&data, &data, &data, &data, params);
1806 let result = emd_with_kernel(&input, kernel);
1807 assert!(result.is_err(), "Expected error for all-NaN data");
1808 Ok(())
1809 }
1810
1811 fn check_emd_invalid_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1812 skip_if_unsupported!(kernel, test_name);
1813 let data = [1.0, 2.0, 3.0];
1814 let params = EmdParams {
1815 period: Some(0),
1816 ..Default::default()
1817 };
1818 let input = EmdInput::from_slices(&data, &data, &data, &data, params);
1819 let result = emd_with_kernel(&input, kernel);
1820 assert!(result.is_err(), "Expected error for zero period");
1821 Ok(())
1822 }
1823
1824 fn check_emd_not_enough_valid_data(
1825 test_name: &str,
1826 kernel: Kernel,
1827 ) -> Result<(), Box<dyn Error>> {
1828 skip_if_unsupported!(kernel, test_name);
1829 let data = vec![10.0; 10];
1830 let params = EmdParams {
1831 period: Some(20),
1832 ..Default::default()
1833 };
1834 let input = EmdInput::from_slices(&data, &data, &data, &data, params);
1835 let result = emd_with_kernel(&input, kernel);
1836 assert!(result.is_err(), "Expected error for not enough valid data");
1837 Ok(())
1838 }
1839
1840 fn check_emd_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1841 skip_if_unsupported!(kernel, test_name);
1842 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1843 let candles = read_candles_from_csv(file_path)?;
1844 let input = EmdInput::with_default_candles(&candles);
1845 let result = emd_with_kernel(&input, kernel);
1846 assert!(
1847 result.is_ok(),
1848 "Expected EMD to succeed with default params"
1849 );
1850 Ok(())
1851 }
1852
1853 #[cfg(feature = "proptest")]
1854 #[allow(clippy::float_cmp)]
1855 fn check_emd_property(
1856 test_name: &str,
1857 kernel: Kernel,
1858 ) -> Result<(), Box<dyn std::error::Error>> {
1859 use proptest::prelude::*;
1860 skip_if_unsupported!(kernel, test_name);
1861
1862 let strat1 = (2usize..=64).prop_flat_map(|period| {
1863 (
1864 prop::collection::vec(
1865 (1f64..1000f64).prop_filter("finite", |x| x.is_finite()),
1866 (2 * period).max(50)..400,
1867 ),
1868 Just(period),
1869 (0.1f64..1.0f64).prop_filter("finite", |x| x.is_finite()),
1870 (0.01f64..0.5f64).prop_filter("finite", |x| x.is_finite()),
1871 )
1872 });
1873
1874 let strat2 = prop::collection::vec(
1875 (100f64..10000f64).prop_filter("finite", |x| x.is_finite()),
1876 100..500,
1877 )
1878 .prop_map(|data| (data, 20usize, 0.5f64, 0.1f64));
1879
1880 let strat3 = (100usize..400, prop::bool::ANY).prop_map(|(len, increasing)| {
1881 let mut data = Vec::with_capacity(len);
1882 let mut val = 100.0;
1883 for _ in 0..len {
1884 data.push(val);
1885 val += if increasing { 1.0 } else { -1.0 };
1886 }
1887 (data, 14usize, 0.5f64, 0.1f64)
1888 });
1889
1890 let strat4 = (100usize..400, 5usize..50).prop_map(|(len, period_wave)| {
1891 let mut data = Vec::with_capacity(len);
1892 for i in 0..len {
1893 let val = 1000.0
1894 + 100.0 * (2.0 * std::f64::consts::PI * i as f64 / period_wave as f64).sin();
1895 data.push(val);
1896 }
1897 (data, 20usize, 0.5f64, 0.1f64)
1898 });
1899
1900 let strat5 = (2usize..=30).prop_flat_map(|period| {
1901 let min_len = (2 * period).max(50);
1902 prop::collection::vec(
1903 (
1904 50f64..150f64,
1905 0.1f64..10f64,
1906 -0.5f64..0.5f64,
1907 0f64..0.5f64,
1908 0f64..0.5f64,
1909 )
1910 .prop_map(
1911 |(base, range, close_offset, high_extra, low_extra)| {
1912 let open = base;
1913 let close = base + range * close_offset;
1914 let high = open.max(close) + range * high_extra;
1915 let low = open.min(close) - range * low_extra;
1916 (high, low, close, base * 1000.0)
1917 },
1918 ),
1919 min_len..300,
1920 )
1921 .prop_map(move |ohlc_data| {
1922 let (highs, lows, closes, volumes): (Vec<_>, Vec<_>, Vec<_>, Vec<_>) =
1923 ohlc_data.into_iter().unzip4();
1924 (highs, lows, closes, volumes, period, 0.5f64, 0.1f64)
1925 })
1926 });
1927
1928 trait Unzip4<A, B, C, D> {
1929 fn unzip4(self) -> (Vec<A>, Vec<B>, Vec<C>, Vec<D>);
1930 }
1931 impl<A, B, C, D, I: Iterator<Item = (A, B, C, D)>> Unzip4<A, B, C, D> for I {
1932 fn unzip4(self) -> (Vec<A>, Vec<B>, Vec<C>, Vec<D>) {
1933 let (mut a_vec, mut b_vec, mut c_vec, mut d_vec) =
1934 (Vec::new(), Vec::new(), Vec::new(), Vec::new());
1935 for (a, b, c, d) in self {
1936 a_vec.push(a);
1937 b_vec.push(b);
1938 c_vec.push(c);
1939 d_vec.push(d);
1940 }
1941 (a_vec, b_vec, c_vec, d_vec)
1942 }
1943 }
1944
1945 let combined_strat = prop_oneof![
1946 strat1.prop_map(|(data, period, delta, fraction)| {
1947 let high = data.clone();
1948 let low = data.clone();
1949 let close = data.clone();
1950 let volume = vec![1000.0; data.len()];
1951 (high, low, close, volume, period, delta, fraction)
1952 }),
1953 strat2.prop_map(|(data, period, delta, fraction)| {
1954 let high = data.iter().map(|x| x + 10.0).collect();
1955 let low = data.iter().map(|x| x - 10.0).collect();
1956 let close = data.clone();
1957 let volume = vec![1000.0; data.len()];
1958 (high, low, close, volume, period, delta, fraction)
1959 }),
1960 strat3.prop_map(|(data, period, delta, fraction)| {
1961 let high = data.iter().map(|x| x + 5.0).collect();
1962 let low = data.iter().map(|x| x - 5.0).collect();
1963 let close = data.clone();
1964 let volume = vec![1000.0; data.len()];
1965 (high, low, close, volume, period, delta, fraction)
1966 }),
1967 strat4.prop_map(|(data, period, delta, fraction)| {
1968 let high = data.iter().map(|x| x + 20.0).collect();
1969 let low = data.iter().map(|x| x - 20.0).collect();
1970 let close = data.clone();
1971 let volume = vec![5000.0; data.len()];
1972 (high, low, close, volume, period, delta, fraction)
1973 }),
1974 strat5,
1975 ];
1976
1977 proptest::test_runner::TestRunner::default().run(
1978 &combined_strat,
1979 |(high, low, close, volume, period, delta, fraction)| {
1980 for i in 0..high.len() {
1981 prop_assert!(
1982 high[i] >= low[i],
1983 "Invalid OHLC data at index {}: high ({}) < low ({})",
1984 i,
1985 high[i],
1986 low[i]
1987 );
1988 }
1989
1990 let params = EmdParams {
1991 period: Some(period),
1992 delta: Some(delta),
1993 fraction: Some(fraction),
1994 };
1995 let input = EmdInput::from_slices(&high, &low, &close, &volume, params);
1996
1997 let result = emd_with_kernel(&input, kernel).unwrap();
1998 let upperband = &result.upperband;
1999 let middleband = &result.middleband;
2000 let lowerband = &result.lowerband;
2001
2002 let ref_result = emd_with_kernel(&input, Kernel::Scalar).unwrap();
2003 let ref_upperband = &ref_result.upperband;
2004 let ref_middleband = &ref_result.middleband;
2005 let ref_lowerband = &ref_result.lowerband;
2006
2007 prop_assert_eq!(upperband.len(), high.len());
2008 prop_assert_eq!(middleband.len(), high.len());
2009 prop_assert_eq!(lowerband.len(), high.len());
2010
2011 let upperband_warmup = (50 - 1).min(high.len());
2012 let middleband_warmup = ((2 * period) - 1).min(high.len());
2013
2014 for i in 0..upperband_warmup {
2015 prop_assert!(
2016 upperband[i].is_nan(),
2017 "Upperband should be NaN during warmup at index {}",
2018 i
2019 );
2020 prop_assert!(
2021 lowerband[i].is_nan(),
2022 "Lowerband should be NaN during warmup at index {}",
2023 i
2024 );
2025 }
2026
2027 for i in 0..middleband_warmup {
2028 prop_assert!(
2029 middleband[i].is_nan(),
2030 "Middleband should be NaN during warmup at index {}",
2031 i
2032 );
2033 }
2034
2035 let start_idx = upperband_warmup.max(middleband_warmup) + 1;
2036 if start_idx < high.len() {
2037 let input_min = high[start_idx..]
2038 .iter()
2039 .chain(low[start_idx..].iter())
2040 .fold(
2041 f64::INFINITY,
2042 |a, &b| if b.is_finite() { a.min(b) } else { a },
2043 );
2044 let input_max = high[start_idx..]
2045 .iter()
2046 .chain(low[start_idx..].iter())
2047 .fold(
2048 f64::NEG_INFINITY,
2049 |a, &b| if b.is_finite() { a.max(b) } else { a },
2050 );
2051
2052 if input_min.is_finite() && input_max.is_finite() {
2053 let range = input_max - input_min;
2054 let center = (input_max + input_min) / 2.0;
2055
2056 let bounds_factor = 3.0;
2057 let lower_bound = center - bounds_factor * range.max(1.0);
2058 let upper_bound = center + bounds_factor * range.max(1.0);
2059
2060 for i in start_idx..high.len() {
2061 if !upperband[i].is_nan()
2062 && !middleband[i].is_nan()
2063 && !lowerband[i].is_nan()
2064 {
2065 prop_assert!(
2066 upperband[i].is_finite(),
2067 "Upperband should be finite at index {}",
2068 i
2069 );
2070 prop_assert!(
2071 middleband[i].is_finite(),
2072 "Middleband should be finite at index {}",
2073 i
2074 );
2075 prop_assert!(
2076 lowerband[i].is_finite(),
2077 "Lowerband should be finite at index {}",
2078 i
2079 );
2080
2081 prop_assert!(
2082 upperband[i] >= lower_bound && upperband[i] <= upper_bound,
2083 "Upperband {} at index {} outside reasonable bounds [{}, {}]",
2084 upperband[i],
2085 i,
2086 lower_bound,
2087 upper_bound
2088 );
2089 prop_assert!(
2090 middleband[i] >= lower_bound && middleband[i] <= upper_bound,
2091 "Middleband {} at index {} outside reasonable bounds [{}, {}]",
2092 middleband[i],
2093 i,
2094 lower_bound,
2095 upper_bound
2096 );
2097 prop_assert!(
2098 lowerband[i] >= lower_bound && lowerband[i] <= upper_bound,
2099 "Lowerband {} at index {} outside reasonable bounds [{}, {}]",
2100 lowerband[i],
2101 i,
2102 lower_bound,
2103 upper_bound
2104 );
2105 }
2106 }
2107 }
2108 }
2109
2110 let tolerance = 1e-10;
2111 for i in 0..high.len() {
2112 let ub_diff = (upperband[i] - ref_upperband[i]).abs();
2113 let mb_diff = (middleband[i] - ref_middleband[i]).abs();
2114 let lb_diff = (lowerband[i] - ref_lowerband[i]).abs();
2115
2116 if !upperband[i].is_nan() && !ref_upperband[i].is_nan() {
2117 prop_assert!(
2118 ub_diff < tolerance,
2119 "Upperband kernel mismatch at index {}: {} vs {} (diff: {})",
2120 i,
2121 upperband[i],
2122 ref_upperband[i],
2123 ub_diff
2124 );
2125 }
2126 if !middleband[i].is_nan() && !ref_middleband[i].is_nan() {
2127 prop_assert!(
2128 mb_diff < tolerance,
2129 "Middleband kernel mismatch at index {}: {} vs {} (diff: {})",
2130 i,
2131 middleband[i],
2132 ref_middleband[i],
2133 mb_diff
2134 );
2135 }
2136 if !lowerband[i].is_nan() && !ref_lowerband[i].is_nan() {
2137 prop_assert!(
2138 lb_diff < tolerance,
2139 "Lowerband kernel mismatch at index {}: {} vs {} (diff: {})",
2140 i,
2141 lowerband[i],
2142 ref_lowerband[i],
2143 lb_diff
2144 );
2145 }
2146 }
2147
2148 for i in 0..high.len() {
2149 let ub_bits = upperband[i].to_bits();
2150 let mb_bits = middleband[i].to_bits();
2151 let lb_bits = lowerband[i].to_bits();
2152
2153 prop_assert_ne!(
2154 ub_bits,
2155 0x1111_1111_1111_1111,
2156 "Poison value in upperband at {}",
2157 i
2158 );
2159 prop_assert_ne!(
2160 ub_bits,
2161 0x2222_2222_2222_2222,
2162 "Poison value in upperband at {}",
2163 i
2164 );
2165 prop_assert_ne!(
2166 ub_bits,
2167 0x3333_3333_3333_3333,
2168 "Poison value in upperband at {}",
2169 i
2170 );
2171
2172 prop_assert_ne!(
2173 mb_bits,
2174 0x1111_1111_1111_1111,
2175 "Poison value in middleband at {}",
2176 i
2177 );
2178 prop_assert_ne!(
2179 mb_bits,
2180 0x2222_2222_2222_2222,
2181 "Poison value in middleband at {}",
2182 i
2183 );
2184 prop_assert_ne!(
2185 mb_bits,
2186 0x3333_3333_3333_3333,
2187 "Poison value in middleband at {}",
2188 i
2189 );
2190
2191 prop_assert_ne!(
2192 lb_bits,
2193 0x1111_1111_1111_1111,
2194 "Poison value in lowerband at {}",
2195 i
2196 );
2197 prop_assert_ne!(
2198 lb_bits,
2199 0x2222_2222_2222_2222,
2200 "Poison value in lowerband at {}",
2201 i
2202 );
2203 prop_assert_ne!(
2204 lb_bits,
2205 0x3333_3333_3333_3333,
2206 "Poison value in lowerband at {}",
2207 i
2208 );
2209 }
2210
2211 if period == 2 {
2212 let min_warmup = (2 * 2).max(50);
2213 if high.len() > min_warmup {
2214 prop_assert!(
2215 middleband[min_warmup].is_finite() || middleband[min_warmup].is_nan(),
2216 "Period=2 should produce valid or NaN output"
2217 );
2218 }
2219 }
2220
2221 if high.windows(2).all(|w| (w[0] - w[1]).abs() < f64::EPSILON)
2222 && low.windows(2).all(|w| (w[0] - w[1]).abs() < f64::EPSILON)
2223 && close.windows(2).all(|w| (w[0] - w[1]).abs() < f64::EPSILON)
2224 {
2225 let check_start = start_idx + period;
2226 if check_start + 5 < high.len() {
2227 let ub_stable = &upperband[check_start..check_start + 5];
2228 let mb_stable = &middleband[check_start..check_start + 5];
2229 let lb_stable = &lowerband[check_start..check_start + 5];
2230
2231 for w in ub_stable.windows(2) {
2232 if !w[0].is_nan() && !w[1].is_nan() {
2233 prop_assert!(
2234 (w[0] - w[1]).abs() < 1e-6,
2235 "Upperband should be stable for constant input"
2236 );
2237 }
2238 }
2239 for w in mb_stable.windows(2) {
2240 if !w[0].is_nan() && !w[1].is_nan() {
2241 prop_assert!(
2242 (w[0] - w[1]).abs() < 1e-6,
2243 "Middleband should be stable for constant input"
2244 );
2245 }
2246 }
2247 for w in lb_stable.windows(2) {
2248 if !w[0].is_nan() && !w[1].is_nan() {
2249 prop_assert!(
2250 (w[0] - w[1]).abs() < 1e-6,
2251 "Lowerband should be stable for constant input"
2252 );
2253 }
2254 }
2255 }
2256 }
2257
2258 if fraction < 0.01 && start_idx < high.len() {
2259 let check_end = (start_idx + 10).min(high.len());
2260 for i in start_idx..check_end {
2261 if !upperband[i].is_nan() && !lowerband[i].is_nan() {
2262 prop_assert!(
2263 upperband[i].abs() < 10.0,
2264 "With fraction={}, upperband should be small, got {} at index {}",
2265 fraction,
2266 upperband[i],
2267 i
2268 );
2269 prop_assert!(
2270 lowerband[i].abs() < 10.0,
2271 "With fraction={}, lowerband should be small, got {} at index {}",
2272 fraction,
2273 lowerband[i],
2274 i
2275 );
2276 }
2277 }
2278 }
2279
2280 Ok(())
2281 },
2282 )?;
2283
2284 Ok(())
2285 }
2286
    /// No-op stand-in for the property test, compiled when the `proptest` feature is
    /// disabled. Ignores both arguments and always returns `Ok(())`.
    #[cfg(not(feature = "proptest"))]
    fn check_emd_property(
        _test_name: &str,
        _kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        Ok(())
    }
2294
2295 macro_rules! generate_all_emd_tests {
2296 ($($test_fn:ident),*) => {
2297 paste::paste! {
2298 $(
2299 #[test]
2300 fn [<$test_fn _scalar_f64>]() {
2301 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
2302 }
2303 )*
2304 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2305 $(
2306 #[test]
2307 fn [<$test_fn _avx2_f64>]() {
2308 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
2309 }
2310 #[test]
2311 fn [<$test_fn _avx512_f64>]() {
2312 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
2313 }
2314 )*
2315 }
2316 }
2317 }
2318
2319 #[cfg(debug_assertions)]
2320 fn check_emd_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2321 skip_if_unsupported!(kernel, test_name);
2322
2323 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2324 let candles = read_candles_from_csv(file_path)?;
2325
2326 let test_params = vec![
2327 EmdParams::default(),
2328 EmdParams {
2329 period: Some(2),
2330 delta: Some(0.1),
2331 fraction: Some(0.05),
2332 },
2333 EmdParams {
2334 period: Some(5),
2335 delta: Some(0.3),
2336 fraction: Some(0.1),
2337 },
2338 EmdParams {
2339 period: Some(10),
2340 delta: Some(0.5),
2341 fraction: Some(0.15),
2342 },
2343 EmdParams {
2344 period: Some(20),
2345 delta: Some(0.4),
2346 fraction: Some(0.1),
2347 },
2348 EmdParams {
2349 period: Some(30),
2350 delta: Some(0.6),
2351 fraction: Some(0.2),
2352 },
2353 EmdParams {
2354 period: Some(50),
2355 delta: Some(0.7),
2356 fraction: Some(0.25),
2357 },
2358 EmdParams {
2359 period: Some(100),
2360 delta: Some(0.8),
2361 fraction: Some(0.3),
2362 },
2363 EmdParams {
2364 period: Some(15),
2365 delta: Some(0.1),
2366 fraction: Some(0.1),
2367 },
2368 EmdParams {
2369 period: Some(15),
2370 delta: Some(0.9),
2371 fraction: Some(0.1),
2372 },
2373 EmdParams {
2374 period: Some(25),
2375 delta: Some(0.5),
2376 fraction: Some(0.01),
2377 },
2378 EmdParams {
2379 period: Some(25),
2380 delta: Some(0.5),
2381 fraction: Some(0.5),
2382 },
2383 EmdParams {
2384 period: Some(40),
2385 delta: Some(0.65),
2386 fraction: Some(0.12),
2387 },
2388 EmdParams {
2389 period: Some(7),
2390 delta: Some(0.25),
2391 fraction: Some(0.08),
2392 },
2393 ];
2394
2395 for (param_idx, params) in test_params.iter().enumerate() {
2396 let input = EmdInput::from_candles(&candles, params.clone());
2397 let output = emd_with_kernel(&input, kernel)?;
2398
2399 for (i, &val) in output.upperband.iter().enumerate() {
2400 if val.is_nan() {
2401 continue;
2402 }
2403
2404 let bits = val.to_bits();
2405
2406 if bits == 0x11111111_11111111 {
2407 panic!(
2408 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} in upperband \
2409 with params: period={}, delta={}, fraction={} (param set {})",
2410 test_name, val, bits, i,
2411 params.period.unwrap_or(20),
2412 params.delta.unwrap_or(0.5),
2413 params.fraction.unwrap_or(0.1),
2414 param_idx
2415 );
2416 }
2417
2418 if bits == 0x22222222_22222222 {
2419 panic!(
2420 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} in upperband \
2421 with params: period={}, delta={}, fraction={} (param set {})",
2422 test_name, val, bits, i,
2423 params.period.unwrap_or(20),
2424 params.delta.unwrap_or(0.5),
2425 params.fraction.unwrap_or(0.1),
2426 param_idx
2427 );
2428 }
2429
2430 if bits == 0x33333333_33333333 {
2431 panic!(
2432 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} in upperband \
2433 with params: period={}, delta={}, fraction={} (param set {})",
2434 test_name, val, bits, i,
2435 params.period.unwrap_or(20),
2436 params.delta.unwrap_or(0.5),
2437 params.fraction.unwrap_or(0.1),
2438 param_idx
2439 );
2440 }
2441 }
2442
2443 for (i, &val) in output.middleband.iter().enumerate() {
2444 if val.is_nan() {
2445 continue;
2446 }
2447
2448 let bits = val.to_bits();
2449
2450 if bits == 0x11111111_11111111 {
2451 panic!(
2452 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} in middleband \
2453 with params: period={}, delta={}, fraction={} (param set {})",
2454 test_name, val, bits, i,
2455 params.period.unwrap_or(20),
2456 params.delta.unwrap_or(0.5),
2457 params.fraction.unwrap_or(0.1),
2458 param_idx
2459 );
2460 }
2461
2462 if bits == 0x22222222_22222222 {
2463 panic!(
2464 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} in middleband \
2465 with params: period={}, delta={}, fraction={} (param set {})",
2466 test_name, val, bits, i,
2467 params.period.unwrap_or(20),
2468 params.delta.unwrap_or(0.5),
2469 params.fraction.unwrap_or(0.1),
2470 param_idx
2471 );
2472 }
2473
2474 if bits == 0x33333333_33333333 {
2475 panic!(
2476 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} in middleband \
2477 with params: period={}, delta={}, fraction={} (param set {})",
2478 test_name, val, bits, i,
2479 params.period.unwrap_or(20),
2480 params.delta.unwrap_or(0.5),
2481 params.fraction.unwrap_or(0.1),
2482 param_idx
2483 );
2484 }
2485 }
2486
2487 for (i, &val) in output.lowerband.iter().enumerate() {
2488 if val.is_nan() {
2489 continue;
2490 }
2491
2492 let bits = val.to_bits();
2493
2494 if bits == 0x11111111_11111111 {
2495 panic!(
2496 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} in lowerband \
2497 with params: period={}, delta={}, fraction={} (param set {})",
2498 test_name, val, bits, i,
2499 params.period.unwrap_or(20),
2500 params.delta.unwrap_or(0.5),
2501 params.fraction.unwrap_or(0.1),
2502 param_idx
2503 );
2504 }
2505
2506 if bits == 0x22222222_22222222 {
2507 panic!(
2508 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} in lowerband \
2509 with params: period={}, delta={}, fraction={} (param set {})",
2510 test_name, val, bits, i,
2511 params.period.unwrap_or(20),
2512 params.delta.unwrap_or(0.5),
2513 params.fraction.unwrap_or(0.1),
2514 param_idx
2515 );
2516 }
2517
2518 if bits == 0x33333333_33333333 {
2519 panic!(
2520 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} in lowerband \
2521 with params: period={}, delta={}, fraction={} (param set {})",
2522 test_name, val, bits, i,
2523 params.period.unwrap_or(20),
2524 params.delta.unwrap_or(0.5),
2525 params.fraction.unwrap_or(0.1),
2526 param_idx
2527 );
2528 }
2529 }
2530 }
2531
2532 Ok(())
2533 }
2534
    /// Variant of `check_emd_no_poison` compiled when `debug_assertions` is off.
    ///
    /// Performs no work and always returns `Ok(())`.
    #[cfg(not(debug_assertions))]
    fn check_emd_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
2539
    // Instantiate the per-kernel test wrappers for every single-run check.
    generate_all_emd_tests!(
        check_emd_accuracy,
        check_emd_empty_data,
        check_emd_all_nans,
        check_emd_invalid_period,
        check_emd_not_enough_valid_data,
        check_emd_default_candles,
        check_emd_no_poison
    );

    // The property check only gets test wrappers when the `proptest` feature is on.
    #[cfg(feature = "proptest")]
    generate_all_emd_tests!(check_emd_property);
2552
2553 #[cfg(test)]
2554 mod batch_tests {
2555 use super::*;
2556 use crate::skip_if_unsupported;
2557 use crate::utilities::data_loader::read_candles_from_csv;
2558
2559 fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2560 skip_if_unsupported!(kernel, test);
2561
2562 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2563 let c = read_candles_from_csv(file)?;
2564
2565 let output = EmdBatchBuilder::new().kernel(kernel).apply_candles(&c)?;
2566
2567 let def = EmdParams::default();
2568 let (ub, mb, lb) = output.bands_for(&def).expect("default row missing");
2569
2570 assert_eq!(ub.len(), c.close.len(), "Upperband length mismatch");
2571 assert_eq!(mb.len(), c.close.len(), "Middleband length mismatch");
2572 assert_eq!(lb.len(), c.close.len(), "Lowerband length mismatch");
2573
2574 let expected_last_five_upper = [
2575 50.33760237677157,
2576 50.28850695686447,
2577 50.23941153695737,
2578 50.19031611705027,
2579 48.709744457737344,
2580 ];
2581 let expected_last_five_middle = [
2582 -368.71064280396706,
2583 -399.11033986231377,
2584 -421.9368852621732,
2585 -437.879217150269,
2586 -447.3257167904511,
2587 ];
2588 let expected_last_five_lower = [
2589 -60.67834136221248,
2590 -60.93110347122829,
2591 -61.68154077026321,
2592 -62.43197806929814,
2593 -63.18241536833306,
2594 ];
2595 let len = ub.len();
2596 for i in 0..5 {
2597 assert!(
2598 (ub[len - 5 + i] - expected_last_five_upper[i]).abs() < 1e-6,
2599 "[{test}] upperband mismatch idx {i}: {} vs {}",
2600 ub[len - 5 + i],
2601 expected_last_five_upper[i]
2602 );
2603 assert!(
2604 (mb[len - 5 + i] - expected_last_five_middle[i]).abs() < 1e-6,
2605 "[{test}] middleband mismatch idx {i}: {} vs {}",
2606 mb[len - 5 + i],
2607 expected_last_five_middle[i]
2608 );
2609 assert!(
2610 (lb[len - 5 + i] - expected_last_five_lower[i]).abs() < 1e-6,
2611 "[{test}] lowerband mismatch idx {i}: {} vs {}",
2612 lb[len - 5 + i],
2613 expected_last_five_lower[i]
2614 );
2615 }
2616
2617 Ok(())
2618 }
2619
2620 fn check_batch_param_sweep(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2621 skip_if_unsupported!(kernel, test);
2622
2623 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2624 let c = read_candles_from_csv(file)?;
2625
2626 let output = EmdBatchBuilder::new()
2627 .kernel(kernel)
2628 .period_range(20, 22, 2)
2629 .delta_range(0.5, 0.6, 0.1)
2630 .fraction_range(0.1, 0.2, 0.1)
2631 .apply_candles(&c)?;
2632
2633 assert!(
2634 output.rows == 8,
2635 "Expected 8 rows (2*2*2 grid), got {}",
2636 output.rows
2637 );
2638 assert_eq!(output.cols, c.close.len());
2639
2640 for params in &output.combos {
2641 let (ub, mb, lb) = output
2642 .bands_for(params)
2643 .expect("row for params missing in sweep");
2644 assert_eq!(ub.len(), output.cols);
2645 assert_eq!(mb.len(), output.cols);
2646 assert_eq!(lb.len(), output.cols);
2647 }
2648
2649 Ok(())
2650 }
2651
2652 macro_rules! gen_batch_tests {
2653 ($fn_name:ident) => {
2654 paste::paste! {
2655 #[test] fn [<$fn_name _scalar>]() {
2656 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
2657 }
2658 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2659 #[test] fn [<$fn_name _avx2>]() {
2660 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
2661 }
2662 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2663 #[test] fn [<$fn_name _avx512>]() {
2664 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
2665 }
2666 #[test] fn [<$fn_name _auto_detect>]() {
2667 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
2668 }
2669 }
2670 };
2671 }
2672
2673 #[cfg(debug_assertions)]
2674 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2675 skip_if_unsupported!(kernel, test);
2676
2677 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2678 let c = read_candles_from_csv(file)?;
2679
2680 let test_configs = vec![
2681 (5, 15, 5, 0.1, 0.5, 0.2, 0.05, 0.15, 0.05),
2682 (10, 30, 10, 0.3, 0.7, 0.2, 0.1, 0.2, 0.05),
2683 (20, 50, 15, 0.5, 0.8, 0.15, 0.15, 0.3, 0.075),
2684 (8, 12, 1, 0.4, 0.6, 0.1, 0.08, 0.12, 0.02),
2685 (20, 20, 0, 0.5, 0.5, 0.0, 0.1, 0.1, 0.0),
2686 (5, 40, 5, 0.2, 0.8, 0.1, 0.05, 0.25, 0.05),
2687 (2, 6, 2, 0.1, 0.9, 0.4, 0.01, 0.5, 0.245),
2688 ];
2689
2690 for (
2691 cfg_idx,
2692 &(p_start, p_end, p_step, d_start, d_end, d_step, f_start, f_end, f_step),
2693 ) in test_configs.iter().enumerate()
2694 {
2695 let output = EmdBatchBuilder::new()
2696 .kernel(kernel)
2697 .period_range(p_start, p_end, p_step)
2698 .delta_range(d_start, d_end, d_step)
2699 .fraction_range(f_start, f_end, f_step)
2700 .apply_candles(&c)?;
2701
2702 for (idx, &val) in output.upperband.iter().enumerate() {
2703 if val.is_nan() {
2704 continue;
2705 }
2706
2707 let bits = val.to_bits();
2708 let row = idx / output.cols;
2709 let col = idx % output.cols;
2710 let combo = &output.combos[row];
2711
2712 if bits == 0x11111111_11111111 {
2713 panic!(
2714 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) in upperband \
2715 at row {} col {} (flat index {}) with params: period={}, delta={}, fraction={}",
2716 test, cfg_idx, val, bits, row, col, idx,
2717 combo.period.unwrap_or(20),
2718 combo.delta.unwrap_or(0.5),
2719 combo.fraction.unwrap_or(0.1)
2720 );
2721 }
2722
2723 if bits == 0x22222222_22222222 {
2724 panic!(
2725 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) in upperband \
2726 at row {} col {} (flat index {}) with params: period={}, delta={}, fraction={}",
2727 test, cfg_idx, val, bits, row, col, idx,
2728 combo.period.unwrap_or(20),
2729 combo.delta.unwrap_or(0.5),
2730 combo.fraction.unwrap_or(0.1)
2731 );
2732 }
2733
2734 if bits == 0x33333333_33333333 {
2735 panic!(
2736 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) in upperband \
2737 at row {} col {} (flat index {}) with params: period={}, delta={}, fraction={}",
2738 test, cfg_idx, val, bits, row, col, idx,
2739 combo.period.unwrap_or(20),
2740 combo.delta.unwrap_or(0.5),
2741 combo.fraction.unwrap_or(0.1)
2742 );
2743 }
2744 }
2745
2746 for (idx, &val) in output.middleband.iter().enumerate() {
2747 if val.is_nan() {
2748 continue;
2749 }
2750
2751 let bits = val.to_bits();
2752 let row = idx / output.cols;
2753 let col = idx % output.cols;
2754 let combo = &output.combos[row];
2755
2756 if bits == 0x11111111_11111111 {
2757 panic!(
2758 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) in middleband \
2759 at row {} col {} (flat index {}) with params: period={}, delta={}, fraction={}",
2760 test, cfg_idx, val, bits, row, col, idx,
2761 combo.period.unwrap_or(20),
2762 combo.delta.unwrap_or(0.5),
2763 combo.fraction.unwrap_or(0.1)
2764 );
2765 }
2766
2767 if bits == 0x22222222_22222222 {
2768 panic!(
2769 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) in middleband \
2770 at row {} col {} (flat index {}) with params: period={}, delta={}, fraction={}",
2771 test, cfg_idx, val, bits, row, col, idx,
2772 combo.period.unwrap_or(20),
2773 combo.delta.unwrap_or(0.5),
2774 combo.fraction.unwrap_or(0.1)
2775 );
2776 }
2777
2778 if bits == 0x33333333_33333333 {
2779 panic!(
2780 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) in middleband \
2781 at row {} col {} (flat index {}) with params: period={}, delta={}, fraction={}",
2782 test, cfg_idx, val, bits, row, col, idx,
2783 combo.period.unwrap_or(20),
2784 combo.delta.unwrap_or(0.5),
2785 combo.fraction.unwrap_or(0.1)
2786 );
2787 }
2788 }
2789
2790 for (idx, &val) in output.lowerband.iter().enumerate() {
2791 if val.is_nan() {
2792 continue;
2793 }
2794
2795 let bits = val.to_bits();
2796 let row = idx / output.cols;
2797 let col = idx % output.cols;
2798 let combo = &output.combos[row];
2799
2800 if bits == 0x11111111_11111111 {
2801 panic!(
2802 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) in lowerband \
2803 at row {} col {} (flat index {}) with params: period={}, delta={}, fraction={}",
2804 test, cfg_idx, val, bits, row, col, idx,
2805 combo.period.unwrap_or(20),
2806 combo.delta.unwrap_or(0.5),
2807 combo.fraction.unwrap_or(0.1)
2808 );
2809 }
2810
2811 if bits == 0x22222222_22222222 {
2812 panic!(
2813 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) in lowerband \
2814 at row {} col {} (flat index {}) with params: period={}, delta={}, fraction={}",
2815 test, cfg_idx, val, bits, row, col, idx,
2816 combo.period.unwrap_or(20),
2817 combo.delta.unwrap_or(0.5),
2818 combo.fraction.unwrap_or(0.1)
2819 );
2820 }
2821
2822 if bits == 0x33333333_33333333 {
2823 panic!(
2824 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) in lowerband \
2825 at row {} col {} (flat index {}) with params: period={}, delta={}, fraction={}",
2826 test, cfg_idx, val, bits, row, col, idx,
2827 combo.period.unwrap_or(20),
2828 combo.delta.unwrap_or(0.5),
2829 combo.fraction.unwrap_or(0.1)
2830 );
2831 }
2832 }
2833 }
2834
2835 Ok(())
2836 }
2837
        /// Variant of `check_batch_no_poison` compiled when `debug_assertions` is off.
        ///
        /// Performs no work and always returns `Ok(())`.
        #[cfg(not(debug_assertions))]
        fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
            Ok(())
        }
2842
        // Instantiate the per-kernel test wrappers for each batch check.
        gen_batch_tests!(check_batch_default_row);
        gen_batch_tests!(check_batch_param_sweep);
        gen_batch_tests!(check_batch_no_poison);
2846 }
2847}
2848
// Python entry point `emd(high, low, period, delta, fraction, kernel=None)`.
//
// Validates that `high` and `low` have equal length, resolves the kernel name,
// then lets `emd_into_slices` write straight into three newly allocated NumPy
// arrays with the GIL released. Returns `(upper, middle, lower)`; any indicator
// error is surfaced as a Python `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "emd")]
#[pyo3(signature = (high, low, period, delta, fraction, kernel=None))]
pub fn emd_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    period: usize,
    delta: f64,
    fraction: f64,
    kernel: Option<&str>,
) -> PyResult<(
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
)> {
    let high_slice = high.as_slice()?;
    let low_slice = low.as_slice()?;
    let n = high_slice.len();
    if n != low_slice.len() {
        return Err(PyValueError::new_err("high and low must have same length"));
    }

    let input = EmdInput::from_slices(
        high_slice,
        low_slice,
        &[],
        &[],
        EmdParams {
            period: Some(period),
            delta: Some(delta),
            fraction: Some(fraction),
        },
    );
    let selected = validate_kernel(kernel, false)?;

    // The arrays are allocated without initialisation; they are handed to
    // `emd_into_slices` below before anything else can observe them.
    let upper = unsafe { PyArray1::<f64>::new(py, [n], false) };
    let middle = unsafe { PyArray1::<f64>::new(py, [n], false) };
    let lower = unsafe { PyArray1::<f64>::new(py, [n], false) };

    let upper_out = unsafe { upper.as_slice_mut()? };
    let middle_out = unsafe { middle.as_slice_mut()? };
    let lower_out = unsafe { lower.as_slice_mut()? };

    let outcome = py
        .allow_threads(|| emd_into_slices(upper_out, middle_out, lower_out, &input, selected));

    match outcome {
        Ok(_) => Ok((upper, middle, lower)),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
2892
// Python class `EmdStream`: a thin wrapper that owns a Rust `EmdStream` and
// forwards calls to it.
#[cfg(feature = "python")]
#[pyclass(name = "EmdStream")]
pub struct EmdStreamPy {
    // The wrapped streaming state.
    stream: EmdStream,
}
2898
#[cfg(feature = "python")]
#[pymethods]
impl EmdStreamPy {
    // Constructor: builds the wrapped stream from explicit parameters and turns
    // a construction error into a Python `ValueError`.
    #[new]
    fn new(period: usize, delta: f64, fraction: f64) -> PyResult<Self> {
        let built = EmdStream::try_new(EmdParams {
            period: Some(period),
            delta: Some(delta),
            fraction: Some(fraction),
        });

        match built {
            Ok(stream) => Ok(Self { stream }),
            Err(err) => Err(PyValueError::new_err(err.to_string())),
        }
    }

    // Forwards one (high, low) sample to the wrapped stream and returns its
    // three optional outputs unchanged.
    fn update(&mut self, high: f64, low: f64) -> (Option<f64>, Option<f64>, Option<f64>) {
        let Self { stream } = self;
        stream.update(high, low)
    }
}
2918
// Python entry point `emd_batch(high, low, period_range, delta_range,
// fraction_range, kernel=None)`.
//
// Expands the three `(start, end, step)` ranges into a parameter grid, lets
// `emd_batch_inner_into` fill three flat `rows * cols` arrays with the GIL
// released, and returns a dict with the arrays reshaped to `(rows, cols)` under
// "upper"/"middle"/"lower" plus the per-row "periods", "deltas" and "fractions".
// Errors are reported as Python `ValueError`s.
#[cfg(feature = "python")]
#[pyfunction(name = "emd_batch")]
#[pyo3(signature = (high, low, period_range, delta_range, fraction_range, kernel=None))]
pub fn emd_batch_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    delta_range: (f64, f64, f64),
    fraction_range: (f64, f64, f64),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::PyArrayMethods;

    let high_slice = high.as_slice()?;
    let low_slice = low.as_slice()?;
    if high_slice.len() != low_slice.len() {
        return Err(PyValueError::new_err("high and low must have same length"));
    }

    let sweep = EmdBatchRange {
        period: period_range,
        delta: delta_range,
        fraction: fraction_range,
    };

    // The grid is expanded here only to size the output buffers.
    let rows = expand_grid(&sweep)
        .map_err(|e| PyValueError::new_err(e.to_string()))?
        .len();
    let cols = high_slice.len();
    let total = match rows.checked_mul(cols) {
        Some(n) => n,
        None => {
            return Err(PyValueError::new_err(
                "rows * cols overflow in emd_batch_py",
            ))
        }
    };

    let upper = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let middle = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let lower = unsafe { PyArray1::<f64>::new(py, [total], false) };

    let upper_out = unsafe { upper.as_slice_mut()? };
    let middle_out = unsafe { middle.as_slice_mut()? };
    let lower_out = unsafe { lower.as_slice_mut()? };

    let requested = validate_kernel(kernel, true)?;

    let computed = py.allow_threads(|| {
        let batch_kernel = if matches!(requested, Kernel::Auto) {
            detect_best_batch_kernel()
        } else {
            requested
        };
        // The inner routine takes the per-row kernel, not the batch variant.
        let row_kernel = match batch_kernel {
            Kernel::Avx512Batch => Kernel::Avx512,
            Kernel::Avx2Batch => Kernel::Avx2,
            Kernel::ScalarBatch => Kernel::Scalar,
            _ => unreachable!(),
        };
        emd_batch_inner_into(
            high_slice, low_slice, &sweep, row_kernel, true, upper_out, middle_out, lower_out,
        )
    });
    let combos = match computed {
        Ok(list) => list,
        Err(e) => return Err(PyValueError::new_err(e.to_string())),
    };

    let periods: Vec<_> = combos.iter().map(|p| p.period.unwrap() as u64).collect();
    let deltas: Vec<_> = combos.iter().map(|p| p.delta.unwrap()).collect();
    let fractions: Vec<_> = combos.iter().map(|p| p.fraction.unwrap()).collect();

    let result = PyDict::new(py);
    result.set_item("upper", upper.reshape((rows, cols))?)?;
    result.set_item("middle", middle.reshape((rows, cols))?)?;
    result.set_item("lower", lower.reshape((rows, cols))?)?;
    result.set_item("periods", periods.into_pyarray(py))?;
    result.set_item("deltas", deltas.into_pyarray(py))?;
    result.set_item("fractions", fractions.into_pyarray(py))?;
    Ok(result)
}
3007
// Python entry point `emd_cuda_batch_dev(high, low, period_range, delta_range,
// fraction_range, device_id=0)`.
//
// Runs the parameter sweep through `CudaEmd::emd_batch_dev` with the GIL
// released and returns a dict holding the three device arrays ("upperband",
// "middleband", "lowerband"), the per-row "periods"/"deltas"/"fractions", and
// the "rows"/"cols" counts. Fails with `ValueError` when CUDA is unavailable,
// the inputs differ in length, or the CUDA call reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "emd_cuda_batch_dev")]
#[pyo3(signature = (high, low, period_range, delta_range, fraction_range, device_id=0))]
pub fn emd_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f32>,
    low: PyReadonlyArray1<'py, f32>,
    period_range: (usize, usize, usize),
    delta_range: (f64, f64, f64),
    fraction_range: (f64, f64, f64),
    device_id: usize,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::PyArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let high_slice = high.as_slice()?;
    let low_slice = low.as_slice()?;
    if high_slice.len() != low_slice.len() {
        return Err(PyValueError::new_err("high and low must have same length"));
    }

    let sweep = EmdBatchRange {
        period: period_range,
        delta: delta_range,
        fraction: fraction_range,
    };

    let (outputs, combos, dev_id) = py.allow_threads(|| {
        let cuda = CudaEmd::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let dev_id = cuda.device_id();
        let batch = cuda
            .emd_batch_dev(high_slice, low_slice, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, PyErr>((batch.outputs, batch.combos, dev_id))
    })?;

    let DeviceArrayF32Triple {
        upper,
        middle,
        lower,
    } = outputs;

    let dict = pyo3::types::PyDict::new(py);
    for (key, device_array) in [
        ("upperband", upper),
        ("middleband", middle),
        ("lowerband", lower),
    ] {
        let wrapped = make_device_array_py(dev_id as usize, device_array)?;
        dict.set_item(key, Py::new(py, wrapped)?)?;
    }

    let mut periods: Vec<usize> = Vec::with_capacity(combos.len());
    let mut deltas: Vec<f64> = Vec::with_capacity(combos.len());
    let mut fractions: Vec<f64> = Vec::with_capacity(combos.len());
    for combo in combos.iter() {
        periods.push(combo.period.unwrap());
        deltas.push(combo.delta.unwrap());
        fractions.push(combo.fraction.unwrap());
    }

    dict.set_item("periods", periods.into_pyarray(py))?;
    dict.set_item("deltas", deltas.into_pyarray(py))?;
    dict.set_item("fractions", fractions.into_pyarray(py))?;
    dict.set_item("rows", combos.len())?;
    dict.set_item("cols", high_slice.len())?;
    Ok(dict)
}
3065
// Python entry point `emd_cuda_many_series_one_param_dev(data_tm_f32, period,
// delta, fraction, device_id=0)`.
//
// `data_tm_f32` is a 2-D array indexed here as `flat[t * cols + s]`, i.e. row
// `t` and column (series) `s`. For every column the index of its first finite
// value is located, then all series are processed with one parameter set by
// `CudaEmd::emd_many_series_one_param_time_major_dev` with the GIL released.
//
// Returns a dict with the three device arrays ("upperband", "middleband",
// "lowerband"), the "rows"/"cols" of the input and the parameters echoed back.
// Raises `ValueError` when CUDA is unavailable, a series has no finite value,
// or the CUDA call fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "emd_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, delta, fraction, device_id=0))]
pub fn emd_cuda_many_series_one_param_dev_py<'py>(
    py: Python<'py>,
    data_tm_f32: numpy::PyReadonlyArray2<'py, f32>,
    period: usize,
    delta: f64,
    fraction: f64,
    device_id: usize,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::PyUntypedArrayMethods;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let shape = data_tm_f32.shape();
    if shape.len() != 2 {
        return Err(PyValueError::new_err("expected 2D array"));
    }
    let rows = shape[0];
    let cols = shape[1];
    let flat = data_tm_f32.as_slice()?;

    // First finite row index per series; a series without one is an error.
    let first_valids = (0..cols)
        .map(|s| {
            (0..rows)
                .find(|&t| flat[t * cols + s].is_finite())
                .map(|t| t as i32)
                .ok_or_else(|| {
                    PyValueError::new_err(format!("series {} has no finite values", s))
                })
        })
        .collect::<PyResult<Vec<i32>>>()?;

    let params = EmdParams {
        period: Some(period),
        delta: Some(delta),
        fraction: Some(fraction),
    };
    let (outputs, dev_id) = py.allow_threads(|| {
        let cuda = CudaEmd::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let dev_id = cuda.device_id();
        // Fixed: the parameter argument was the mis-encoded token `¶ms`; it must
        // be a reference to `params`.
        cuda.emd_many_series_one_param_time_major_dev(flat, cols, rows, &params, &first_valids)
            .map(|o| (o, dev_id))
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;
    let DeviceArrayF32Triple {
        upper,
        middle,
        lower,
    } = outputs;
    let dict = pyo3::types::PyDict::new(py);
    let upper_dev = make_device_array_py(dev_id as usize, upper)?;
    dict.set_item("upperband", Py::new(py, upper_dev)?)?;
    let middle_dev = make_device_array_py(dev_id as usize, middle)?;
    dict.set_item("middleband", Py::new(py, middle_dev)?)?;
    let lower_dev = make_device_array_py(dev_id as usize, lower)?;
    dict.set_item("lowerband", Py::new(py, lower_dev)?)?;
    dict.set_item("rows", rows)?;
    dict.set_item("cols", cols)?;
    dict.set_item("period", period)?;
    dict.set_item("delta", delta)?;
    dict.set_item("fraction", fraction)?;
    Ok(dict)
}
3134
3135#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
3136use serde::{Deserialize, Serialize};
3137#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
3138use wasm_bindgen::prelude::*;
3139
/// Serialized result of `emd_js`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct EmdJsOutput {
    /// The three bands stored back to back: upper, then middle, then lower,
    /// each `cols` values long.
    pub values: Vec<f64>,
    /// Number of bands in `values`; `emd_js` always sets this to 3.
    pub rows: usize,
    /// Length of each band; `emd_js` sets this to the input length.
    pub cols: usize,
}
3147
// JS entry point `emd_js`: computes the three EMD bands for `high`/`low` and
// returns a serialized `EmdJsOutput` whose `values` hold upper, middle and
// lower back to back. `_close` and `_volume` are accepted but not read.
// Fails when the inputs differ in length, the indicator reports an error, or
// serialization fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "emd_js")]
pub fn emd_js(
    high: &[f64],
    low: &[f64],
    _close: &[f64],
    _volume: &[f64],
    period: usize,
    delta: f64,
    fraction: f64,
) -> Result<JsValue, JsValue> {
    let n = high.len();
    if n != low.len() {
        return Err(JsValue::from_str("high and low must have same length"));
    }

    let input = EmdInput::from_slices(
        high,
        low,
        &[],
        &[],
        EmdParams {
            period: Some(period),
            delta: Some(delta),
            fraction: Some(fraction),
        },
    );

    // One buffer, carved into three equally sized output windows.
    let mut values = vec![f64::NAN; 3 * n];
    {
        let (upper, tail) = values.split_at_mut(n);
        let (middle, lower) = tail.split_at_mut(n);
        if let Err(e) = emd_into_slices(upper, middle, lower, &input, detect_best_kernel()) {
            return Err(JsValue::from_str(&e.to_string()));
        }
    }

    let payload = EmdJsOutput {
        values,
        rows: 3,
        cols: n,
    };
    match serde_wasm_bindgen::to_value(&payload) {
        Ok(js) => Ok(js),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
3183
// Allocates room for `len` `f64` values and returns the raw pointer. The
// allocation is deliberately not dropped here; `emd_free` in this file rebuilds
// a `Vec` from such a pointer to release it.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn emd_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` keeps the buffer alive past the end of this function.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
3192
// Releases a buffer by rebuilding a `Vec<f64>` of length and capacity `len`
// from `ptr` and dropping it. A null `ptr` is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn emd_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: the caller must pass a pointer and length that describe a live
    // `Vec<f64>` allocation of exactly that capacity (e.g. one from `emd_alloc`).
    drop(unsafe { Vec::from_raw_parts(ptr, len, len) });
}
3202
// Pointer-based EMD entry point: reads `len` values from `high_ptr`/`low_ptr`
// and writes `len` values to each of `upper_ptr`, `middle_ptr`, `lower_ptr`.
// `_close_ptr` and `_volume_ptr` are accepted but never read or checked.
//
// When either input pointer is identical to any output pointer the result is
// first computed into temporary vectors and then copied out, so the inputs are
// fully read before anything is overwritten. Otherwise the output is written in
// place. Both inputs are compared with all three outputs; comparing `high_ptr`
// with `upper_ptr` alone would let a `high` buffer that doubles as the middle or
// lower output reach the in-place path and be overwritten while still being read.
//
// NOTE(review): only exact pointer equality is detected; partially overlapping
// ranges and outputs that alias each other are not - confirm whether callers can
// produce those.
//
// Errors: "null pointer" if any used pointer is null, otherwise the indicator's
// error message.
//
// Safety contract for callers: every non-null pointer must be valid for `len`
// `f64` values for the duration of the call.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn emd_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    _close_ptr: *const f64,
    _volume_ptr: *const f64,
    upper_ptr: *mut f64,
    middle_ptr: *mut f64,
    lower_ptr: *mut f64,
    len: usize,
    period: usize,
    delta: f64,
    fraction: f64,
) -> Result<(), JsValue> {
    if high_ptr.is_null()
        || low_ptr.is_null()
        || upper_ptr.is_null()
        || middle_ptr.is_null()
        || lower_ptr.is_null()
    {
        return Err(JsValue::from_str("null pointer"));
    }

    let outputs = [
        upper_ptr as *const f64,
        middle_ptr as *const f64,
        lower_ptr as *const f64,
    ];
    let input_aliases_output = outputs
        .iter()
        .any(|&out| out == high_ptr || out == low_ptr);

    let params = EmdParams {
        period: Some(period),
        delta: Some(delta),
        fraction: Some(fraction),
    };

    // SAFETY: all five pointers were checked for null above; the caller
    // guarantees each is valid for `len` elements.
    unsafe {
        let hi = std::slice::from_raw_parts(high_ptr, len);
        let lo = std::slice::from_raw_parts(low_ptr, len);
        let input = EmdInput::from_slices(hi, lo, &[], &[], params);

        if input_aliases_output {
            // Compute into owned buffers first; the inputs are not touched again
            // once the mutable output views are created.
            let output = emd_with_kernel(&input, detect_best_kernel())
                .map_err(|e| JsValue::from_str(&e.to_string()))?;

            std::slice::from_raw_parts_mut(upper_ptr, len).copy_from_slice(&output.upperband);
            std::slice::from_raw_parts_mut(middle_ptr, len).copy_from_slice(&output.middleband);
            std::slice::from_raw_parts_mut(lower_ptr, len).copy_from_slice(&output.lowerband);
        } else {
            let ub = std::slice::from_raw_parts_mut(upper_ptr, len);
            let mb = std::slice::from_raw_parts_mut(middle_ptr, len);
            let lb = std::slice::from_raw_parts_mut(lower_ptr, len);
            emd_into_slices(ub, mb, lb, &input, detect_best_kernel())
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(())
}
3269
/// Sweep configuration deserialized from the `config` argument of the JS
/// `emd_batch` function. Each tuple is copied verbatim into the matching field
/// of `EmdBatchRange`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct EmdBatchConfig {
    /// Period axis; presumably `(start, end, step)` - the pointer-based
    /// `emd_batch_into` fills the same range field in that order.
    pub period_range: (usize, usize, usize),
    /// Delta axis, same layout as `period_range`.
    pub delta_range: (f64, f64, f64),
    /// Fraction axis, same layout as `period_range`.
    pub fraction_range: (f64, f64, f64),
}
3277
/// Serialized result of the JS `emd_batch` function.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct EmdBatchJsOutput {
    /// Flat upper-band buffer of `rows * cols` values.
    // NOTE(review): presumably row-major (one row per combo) - confirm against
    // `emd_batch_inner_into`.
    pub upperband: Vec<f64>,
    /// Flat middle-band buffer of `rows * cols` values.
    pub middleband: Vec<f64>,
    /// Flat lower-band buffer of `rows * cols` values.
    pub lowerband: Vec<f64>,
    /// Parameter combinations produced by `expand_grid` for the requested sweep.
    pub combos: Vec<EmdParams>,
    /// Number of parameter combinations (`combos.len()`).
    pub rows: usize,
    /// Input length.
    pub cols: usize,
}
3288
// JS entry point `emd_batch`: deserializes an `EmdBatchConfig`, expands it into
// a parameter grid, fills three NaN-initialised `rows * cols` buffers through
// `emd_batch_inner_into`, and returns a serialized `EmdBatchJsOutput`.
// `_close` and `_volume` are accepted but not read.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "emd_batch")]
pub fn emd_batch_unified_js(
    high: &[f64],
    low: &[f64],
    _close: &[f64],
    _volume: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let cols = high.len();
    if cols != low.len() {
        return Err(JsValue::from_str("high and low must have same length"));
    }

    let EmdBatchConfig {
        period_range,
        delta_range,
        fraction_range,
    } = match serde_wasm_bindgen::from_value(config) {
        Ok(parsed) => parsed,
        Err(e) => return Err(JsValue::from_str(&format!("Invalid config: {}", e))),
    };

    let sweep = EmdBatchRange {
        period: period_range,
        delta: delta_range,
        fraction: fraction_range,
    };

    let combos = match expand_grid(&sweep) {
        Ok(list) => list,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };
    let rows = combos.len();
    let total = match rows.checked_mul(cols) {
        Some(n) => n,
        None => {
            return Err(JsValue::from_str(
                "rows * cols overflow in emd_batch_unified_js",
            ))
        }
    };

    let mut upperband = vec![f64::NAN; total];
    let mut middleband = vec![f64::NAN; total];
    let mut lowerband = vec![f64::NAN; total];

    if let Err(e) = emd_batch_inner_into(
        high,
        low,
        &sweep,
        detect_best_kernel(),
        false,
        &mut upperband,
        &mut middleband,
        &mut lowerband,
    ) {
        return Err(JsValue::from_str(&e.to_string()));
    }

    let payload = EmdBatchJsOutput {
        upperband,
        middleband,
        lowerband,
        combos,
        rows,
        cols,
    };
    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
3343
// Pointer-based batch entry point: reads `len` values from `high_ptr`/`low_ptr`
// and writes `rows * len` values to each output pointer, where `rows` is the
// number of parameter combinations in the requested sweep. Returns `rows`.
//
// Callers must size the three output buffers for `rows * len` elements.
//
// NOTE(review): `close_ptr` and `volume_ptr` must be non-null even though they
// are never read here, unlike `emd_into`, which ignores them - confirm this is
// intended.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn emd_batch_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    volume_ptr: *const f64,
    upper_ptr: *mut f64,
    middle_ptr: *mut f64,
    lower_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
    delta_start: f64,
    delta_end: f64,
    delta_step: f64,
    fraction_start: f64,
    fraction_end: f64,
    fraction_step: f64,
) -> Result<usize, JsValue> {
    let input_ptrs = [high_ptr, low_ptr, close_ptr, volume_ptr];
    let output_ptrs = [upper_ptr, middle_ptr, lower_ptr];
    let any_null =
        input_ptrs.iter().any(|p| p.is_null()) || output_ptrs.iter().any(|p| p.is_null());
    if any_null {
        return Err(JsValue::from_str("null pointer passed to emd_batch_into"));
    }

    let sweep = EmdBatchRange {
        period: (period_start, period_end, period_step),
        delta: (delta_start, delta_end, delta_step),
        fraction: (fraction_start, fraction_end, fraction_step),
    };

    // The grid is expanded only to learn how many rows the outputs must hold.
    let rows = match expand_grid(&sweep) {
        Ok(combos) => combos.len(),
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };
    let total_len = match rows.checked_mul(len) {
        Some(n) => n,
        None => return Err(JsValue::from_str("rows * cols overflow in emd_batch_into")),
    };

    // SAFETY: all pointers were checked for null above; the caller guarantees
    // the inputs are valid for `len` elements and the outputs for `total_len`.
    let outcome = unsafe {
        emd_batch_inner_into(
            std::slice::from_raw_parts(high_ptr, len),
            std::slice::from_raw_parts(low_ptr, len),
            &sweep,
            detect_best_kernel(),
            false,
            std::slice::from_raw_parts_mut(upper_ptr, total_len),
            std::slice::from_raw_parts_mut(middle_ptr, total_len),
            std::slice::from_raw_parts_mut(lower_ptr, total_len),
        )
    };

    match outcome {
        Ok(_) => Ok(rows),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}