Skip to main content

vector_ta/indicators/
emd.rs

1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::{cuda_available, CudaEmd, CudaEmdBatchResult, DeviceArrayF32Triple};
3use crate::utilities::data_loader::{source_type, Candles};
4#[cfg(all(feature = "python", feature = "cuda"))]
5use crate::utilities::dlpack_cuda::{make_device_array_py, DeviceArrayF32Py};
6use crate::utilities::enums::Kernel;
7use crate::utilities::helpers::{
8    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
9    make_uninit_matrix,
10};
11#[cfg(feature = "python")]
12use crate::utilities::kernel_validation::validate_kernel;
13#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
14use core::arch::x86_64::*;
15#[cfg(feature = "python")]
16use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
17#[cfg(feature = "python")]
18use pyo3::exceptions::PyValueError;
19#[cfg(feature = "python")]
20use pyo3::prelude::*;
21#[cfg(feature = "python")]
22use pyo3::types::PyDict;
23#[cfg(not(target_arch = "wasm32"))]
24use rayon::prelude::*;
25use std::convert::AsRef;
26use std::error::Error;
27use std::mem::MaybeUninit;
28use thiserror::Error;
29
30impl<'a> AsRef<[f64]> for EmdInput<'a> {
31    #[inline(always)]
32    fn as_ref(&self) -> &[f64] {
33        match &self.data {
34            EmdData::Candles { candles } => source_type(candles, "close"),
35            EmdData::Slices { close, .. } => close,
36        }
37    }
38}
39
/// Source series for the EMD indicator.
///
/// The band computation reads only `high` and `low` (their midpoint is the
/// price fed to the band-pass filter); `close` is what `AsRef<[f64]>` exposes.
#[derive(Debug, Clone)]
pub enum EmdData<'a> {
    /// Candle set; the `"high"`, `"low"` and `"close"` sources are read from it
    /// via `source_type`.
    Candles {
        candles: &'a Candles,
    },
    /// Caller-provided slices. `emd_prepare` requires `low` to have the same
    /// length as `high`.
    // NOTE(review): `volume` is not read by any code in this section — confirm
    // whether other paths use it.
    Slices {
        high: &'a [f64],
        low: &'a [f64],
        close: &'a [f64],
        volume: &'a [f64],
    },
}
52
/// Result of an EMD run: three series, each as long as the input.
///
/// Warm-up entries are NaN. The upper/lower bands start at the 50th valid bar,
/// and the middle band starts at the `2 * period`-th valid bar.
#[derive(Debug, Clone)]
pub struct EmdOutput {
    /// 50-bar mean of `fraction * most recent band-pass peak`.
    pub upperband: Vec<f64>,
    /// `2 * period`-bar mean of the band-pass filter output.
    pub middleband: Vec<f64>,
    /// 50-bar mean of `fraction * most recent band-pass valley`.
    pub lowerband: Vec<f64>,
}
59
/// Tunable parameters of the EMD indicator.
///
/// A `None` field falls back to its default: `period = 20`, `delta = 0.5`,
/// `fraction = 0.1`.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct EmdParams {
    /// Band-pass filter period in bars; must be non-zero.
    pub period: Option<usize>,
    /// Bandwidth parameter of the band-pass filter; it enters the coefficient
    /// `gamma = 1 / cos(4 * pi * delta / period)`.
    pub delta: Option<f64>,
    /// Multiplier applied to detected peaks/valleys before they are averaged.
    pub fraction: Option<f64>,
}

impl Default for EmdParams {
    /// Every parameter explicitly set to its default value.
    fn default() -> Self {
        let (period, delta, fraction) = (20, 0.5, 0.1);
        EmdParams {
            period: Some(period),
            delta: Some(delta),
            fraction: Some(fraction),
        }
    }
}
80
/// Input to the EMD indicator: a price source together with its parameters.
#[derive(Debug, Clone)]
pub struct EmdInput<'a> {
    /// Where the high/low (and close) series come from.
    pub data: EmdData<'a>,
    /// Parameters; unset fields are defaulted by the `get_*` accessors.
    pub params: EmdParams,
}
86
87impl<'a> EmdInput<'a> {
88    #[inline]
89    pub fn from_candles(candles: &'a Candles, params: EmdParams) -> Self {
90        Self {
91            data: EmdData::Candles { candles },
92            params,
93        }
94    }
95
96    #[inline]
97    pub fn from_slices(
98        high: &'a [f64],
99        low: &'a [f64],
100        close: &'a [f64],
101        volume: &'a [f64],
102        params: EmdParams,
103    ) -> Self {
104        Self {
105            data: EmdData::Slices {
106                high,
107                low,
108                close,
109                volume,
110            },
111            params,
112        }
113    }
114
115    #[inline]
116    pub fn with_default_candles(candles: &'a Candles) -> Self {
117        Self::from_candles(candles, EmdParams::default())
118    }
119
120    #[inline]
121    pub fn get_period(&self) -> usize {
122        self.params.period.unwrap_or(20)
123    }
124    #[inline]
125    pub fn get_delta(&self) -> f64 {
126        self.params.delta.unwrap_or(0.5)
127    }
128    #[inline]
129    pub fn get_fraction(&self) -> f64 {
130        self.params.fraction.unwrap_or(0.1)
131    }
132}
133
/// Fluent builder for one-off EMD runs (`apply`, `apply_slices`) or for an
/// `EmdStream` (`into_stream`).
///
/// Parameters left as `None` are forwarded unset, so the defaults of
/// `EmdInput` / `EmdStream::try_new` apply.
#[derive(Copy, Clone, Debug)]
pub struct EmdBuilder {
    period: Option<usize>,
    delta: Option<f64>,
    fraction: Option<f64>,
    // Kernel handed to `emd_with_kernel` by `apply` / `apply_slices`.
    kernel: Kernel,
}
141
142impl Default for EmdBuilder {
143    fn default() -> Self {
144        Self {
145            period: None,
146            delta: None,
147            fraction: None,
148            kernel: Kernel::Auto,
149        }
150    }
151}
152
153impl EmdBuilder {
154    #[inline(always)]
155    pub fn new() -> Self {
156        Self::default()
157    }
158    #[inline(always)]
159    pub fn period(mut self, n: usize) -> Self {
160        self.period = Some(n);
161        self
162    }
163    #[inline(always)]
164    pub fn delta(mut self, d: f64) -> Self {
165        self.delta = Some(d);
166        self
167    }
168    #[inline(always)]
169    pub fn fraction(mut self, f: f64) -> Self {
170        self.fraction = Some(f);
171        self
172    }
173    #[inline(always)]
174    pub fn kernel(mut self, k: Kernel) -> Self {
175        self.kernel = k;
176        self
177    }
178
179    #[inline(always)]
180    pub fn apply(self, c: &Candles) -> Result<EmdOutput, EmdError> {
181        let p = EmdParams {
182            period: self.period,
183            delta: self.delta,
184            fraction: self.fraction,
185        };
186        let i = EmdInput::from_candles(c, p);
187        emd_with_kernel(&i, self.kernel)
188    }
189
190    #[inline(always)]
191    pub fn apply_slices(
192        self,
193        high: &[f64],
194        low: &[f64],
195        close: &[f64],
196        volume: &[f64],
197    ) -> Result<EmdOutput, EmdError> {
198        let p = EmdParams {
199            period: self.period,
200            delta: self.delta,
201            fraction: self.fraction,
202        };
203        let i = EmdInput::from_slices(high, low, close, volume, p);
204        emd_with_kernel(&i, self.kernel)
205    }
206
207    #[inline(always)]
208    pub fn into_stream(self) -> Result<EmdStream, EmdError> {
209        let p = EmdParams {
210            period: self.period,
211            delta: self.delta,
212            fraction: self.fraction,
213        };
214        EmdStream::try_new(p)
215    }
216}
217
/// Errors produced by the EMD indicator.
#[derive(Debug, Error)]
pub enum EmdError {
    /// The high series is empty.
    #[error("emd: Invalid input length (empty input data)")]
    EmptyInputData,

    /// No index has both a non-NaN `high` and a non-NaN `low`.
    #[error("emd: All values are NaN.")]
    AllValuesNaN,

    /// `period` is zero or larger than the data length. The streaming
    /// constructor reports `data_len = 0` because no data is known yet.
    #[error("emd: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },

    /// Fewer than `max(2 * period, 50)` bars remain from the first valid index.
    #[error("emd: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },

    /// `delta` is NaN or infinite.
    #[error("emd: Invalid delta: {delta}")]
    InvalidDelta { delta: f64 },

    /// `fraction` is NaN or infinite.
    #[error("emd: Invalid fraction: {fraction}")]
    InvalidFraction { fraction: f64 },

    /// An input series has a different length from `high`.
    #[error("emd: Invalid input length: expected = {expected}, actual = {actual}")]
    InvalidInputLength { expected: usize, actual: usize },

    /// A caller-provided output slice does not match the input length;
    /// `got` is the shortest of the output lengths.
    #[error("emd: Output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },

    /// Invalid integer `(start, end, step)` sweep.
    // NOTE(review): raised by batch code outside this section — confirm.
    #[error("emd: Invalid range (usize): start={start}, end={end}, step={step}")]
    InvalidRangeU {
        start: usize,
        end: usize,
        step: usize,
    },

    /// Invalid floating-point `(start, end, step)` sweep.
    // NOTE(review): raised by batch code outside this section — confirm.
    #[error("emd: Invalid range (float): start={start}, end={end}, step={step}")]
    InvalidRangeF { start: f64, end: f64, step: f64 },

    /// A non-batch kernel was supplied to a batch entry point.
    // NOTE(review): raised by batch code outside this section — confirm.
    #[error("emd: Invalid kernel for batch path: {0:?}")]
    InvalidKernelForBatch(Kernel),
}
257
258#[inline]
259pub fn emd(input: &EmdInput) -> Result<EmdOutput, EmdError> {
260    emd_with_kernel(input, Kernel::Auto)
261}
262
263#[inline]
264pub fn emd_into_slices(
265    ub: &mut [f64],
266    mb: &mut [f64],
267    lb: &mut [f64],
268    input: &EmdInput,
269    kernel: Kernel,
270) -> Result<(), EmdError> {
271    let (high, low, period, delta, fraction, first, chosen) = emd_prepare(input, kernel)?;
272    if ub.len() != high.len() || mb.len() != high.len() || lb.len() != high.len() {
273        return Err(EmdError::OutputLengthMismatch {
274            expected: high.len(),
275            got: ub.len().min(mb.len()).min(lb.len()),
276        });
277    }
278
279    emd_compute_into(
280        high, low, period, delta, fraction, first, chosen, ub, mb, lb,
281    );
282
283    let up_low_warm = first + 50 - 1;
284    let mid_warm = first + 2 * period - 1;
285    let ub_len = ub.len();
286    let lb_len = lb.len();
287    let mb_len = mb.len();
288    for v in &mut ub[..up_low_warm.min(ub_len)] {
289        *v = f64::NAN;
290    }
291    for v in &mut lb[..up_low_warm.min(lb_len)] {
292        *v = f64::NAN;
293    }
294    for v in &mut mb[..mid_warm.min(mb_len)] {
295        *v = f64::NAN;
296    }
297
298    Ok(())
299}
300
301#[inline]
302fn emd_prepare<'a>(
303    input: &'a EmdInput<'a>,
304    kernel: Kernel,
305) -> Result<(&'a [f64], &'a [f64], usize, f64, f64, usize, Kernel), EmdError> {
306    let (high, low) = match &input.data {
307        EmdData::Candles { candles } => (source_type(candles, "high"), source_type(candles, "low")),
308        EmdData::Slices { high, low, .. } => (*high, *low),
309    };
310
311    let len = high.len();
312    if len == 0 {
313        return Err(EmdError::EmptyInputData);
314    }
315    if low.len() != len {
316        return Err(EmdError::InvalidInputLength {
317            expected: len,
318            actual: low.len(),
319        });
320    }
321
322    let period = input.get_period();
323    let delta = input.get_delta();
324    let fraction = input.get_fraction();
325
326    let first = (0..len)
327        .find(|&i| !high[i].is_nan() && !low[i].is_nan())
328        .ok_or(EmdError::AllValuesNaN)?;
329
330    if period == 0 || period > len {
331        return Err(EmdError::InvalidPeriod {
332            period,
333            data_len: len,
334        });
335    }
336    let needed = (2 * period).max(50);
337    if len - first < needed {
338        return Err(EmdError::NotEnoughValidData {
339            needed,
340            valid: len - first,
341        });
342    }
343    if delta.is_nan() || delta.is_infinite() {
344        return Err(EmdError::InvalidDelta { delta });
345    }
346    if fraction.is_nan() || fraction.is_infinite() {
347        return Err(EmdError::InvalidFraction { fraction });
348    }
349
350    let chosen = match kernel {
351        Kernel::Auto => Kernel::Scalar,
352        k => k,
353    };
354    Ok((high, low, period, delta, fraction, first, chosen))
355}
356
357#[inline(always)]
358fn emd_compute_into(
359    high: &[f64],
360    low: &[f64],
361    period: usize,
362    delta: f64,
363    fraction: f64,
364    first: usize,
365    kernel: Kernel,
366    ub: &mut [f64],
367    mb: &mut [f64],
368    lb: &mut [f64],
369) {
370    unsafe {
371        match kernel {
372            Kernel::Scalar | Kernel::ScalarBatch => {
373                emd_scalar_into(high, low, period, delta, fraction, first, ub, mb, lb)
374            }
375            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
376            Kernel::Avx2 | Kernel::Avx2Batch => {
377                emd_scalar_into(high, low, period, delta, fraction, first, ub, mb, lb)
378            }
379            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
380            Kernel::Avx512 | Kernel::Avx512Batch => {
381                emd_scalar_into(high, low, period, delta, fraction, first, ub, mb, lb)
382            }
383            _ => unreachable!(),
384        }
385    }
386}
387
388pub fn emd_with_kernel(input: &EmdInput, kernel: Kernel) -> Result<EmdOutput, EmdError> {
389    let (high, low, period, delta, fraction, first, chosen) = emd_prepare(input, kernel)?;
390    let len = high.len();
391    let up_low_warm = first + 50 - 1;
392    let mid_warm = first + 2 * period - 1;
393
394    let mut upperband = alloc_with_nan_prefix(len, up_low_warm);
395    let mut middleband = alloc_with_nan_prefix(len, mid_warm);
396    let mut lowerband = alloc_with_nan_prefix(len, up_low_warm);
397
398    emd_compute_into(
399        high,
400        low,
401        period,
402        delta,
403        fraction,
404        first,
405        chosen,
406        &mut upperband,
407        &mut middleband,
408        &mut lowerband,
409    );
410
411    Ok(EmdOutput {
412        upperband,
413        middleband,
414        lowerband,
415    })
416}
417
418#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
419pub fn emd_into(
420    input: &EmdInput,
421    upperband_out: &mut [f64],
422    middleband_out: &mut [f64],
423    lowerband_out: &mut [f64],
424) -> Result<(), EmdError> {
425    let (high, low, period, delta, fraction, first, chosen) = emd_prepare(input, Kernel::Auto)?;
426
427    if upperband_out.len() != high.len()
428        || middleband_out.len() != high.len()
429        || lowerband_out.len() != high.len()
430    {
431        return Err(EmdError::OutputLengthMismatch {
432            expected: high.len(),
433            got: upperband_out
434                .len()
435                .min(middleband_out.len())
436                .min(lowerband_out.len()),
437        });
438    }
439
440    let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
441    let up_low_warm = first + 50 - 1;
442    let mid_warm = first + 2 * period - 1;
443
444    let end_u = up_low_warm.min(upperband_out.len());
445    for v in &mut upperband_out[..end_u] {
446        *v = qnan;
447    }
448    let end_l = up_low_warm.min(lowerband_out.len());
449    for v in &mut lowerband_out[..end_l] {
450        *v = qnan;
451    }
452    let end_m = mid_warm.min(middleband_out.len());
453    for v in &mut middleband_out[..end_m] {
454        *v = qnan;
455    }
456
457    emd_compute_into(
458        high,
459        low,
460        period,
461        delta,
462        fraction,
463        first,
464        chosen,
465        upperband_out,
466        middleband_out,
467        lowerband_out,
468    );
469
470    Ok(())
471}
472
/// Scalar EMD kernel over `high`/`low` pairs, writing into `ub`/`mb`/`lb`.
///
/// Starting at index `first`, each bar's price is `(high + low) / 2`. That
/// price drives a two-pole band-pass filter with coefficients
/// `beta = cos(2*pi/period)`, `gamma = 1 / cos(4*pi*delta/period)` and
/// `alpha = gamma - sqrt(gamma^2 - 1)`. For the first two bars the filter
/// output is just the price.
///
/// Local maxima and minima of the filter output (checked one bar late) update
/// the current peak and valley. Then:
/// - `ub[i]` / `lb[i]` = mean of `fraction * peak` / `fraction * valley` over
///   the last 50 bars, written once 50 bars have been processed;
/// - `mb[i]` = mean of the filter output over the last `2 * period` bars,
///   written once that many bars have been processed.
///
/// Entries before those points are NOT written; callers pre-fill them
/// (e.g. with NaN).
///
/// # Safety
/// - `low`, `ub`, `mb` and `lb` must all have length `high.len()`. This is only
///   checked with `debug_assert!`; raw-pointer accesses go up to `high.len()`.
/// - `period` must be non-zero and `2 * period` must not overflow. Otherwise the
///   mid-band ring buffer is empty and `get_unchecked(idx_mid)` is out of bounds.
/// - `first >= high.len()` is allowed; nothing is written in that case.
#[inline]
pub unsafe fn emd_scalar_into(
    high: &[f64],
    low: &[f64],
    period: usize,
    delta: f64,
    fraction: f64,
    first: usize,
    ub: &mut [f64],
    mb: &mut [f64],
    lb: &mut [f64],
) {
    let len = high.len();
    debug_assert_eq!(low.len(), len);
    debug_assert_eq!(ub.len(), len);
    debug_assert_eq!(mb.len(), len);
    debug_assert_eq!(lb.len(), len);

    // Averaging windows: fixed 50 bars for the outer bands, 2*period for the middle band.
    let per_up_low = 50usize;
    let per_mid = 2 * period;
    let inv_up_low = 1.0 / (per_up_low as f64);
    let inv_mid = 1.0 / (per_mid as f64);

    // Band-pass filter coefficients.
    let two_pi = core::f64::consts::PI * 2.0;
    let beta = (two_pi / (period as f64)).cos();
    let gamma = 1.0 / ((two_pi * 2.0 * delta / (period as f64)).cos());
    let alpha = gamma - (gamma * gamma - 1.0).sqrt();
    let half_one_minus_alpha = 0.5 * (1.0 - alpha);
    let beta_times_one_plus_alpha = beta * (1.0 + alpha);

    // Ring buffers backing the running-sum moving averages (zero-initialised,
    // so the sums cover only the bars seen so far until each window fills).
    let mut sp_ring = vec![0.0f64; per_up_low];
    let mut sv_ring = vec![0.0f64; per_up_low];
    let mut bp_ring = vec![0.0f64; per_mid];

    let mut idx_ul = 0usize;
    let mut idx_mid = 0usize;

    let mut sum_up = 0.0f64;
    let mut sum_low = 0.0f64;
    let mut sum_mb = 0.0f64;

    let mut bp_prev1 = 0.0f64;
    let mut bp_prev2 = 0.0f64;
    let mut peak_prev = 0.0f64;
    let mut valley_prev = 0.0f64;

    let mut price_prev1 = 0.0f64;
    let mut price_prev2 = 0.0f64;

    let hi_ptr = high.as_ptr();
    let lo_ptr = low.as_ptr();
    let ub_ptr = ub.as_mut_ptr();
    let mb_ptr = mb.as_mut_ptr();
    let lb_ptr = lb.as_mut_ptr();

    // Seed all filter history with the first valid price.
    let mut i = first;
    if i < len {
        let p0 = ((*hi_ptr.add(i)) + (*lo_ptr.add(i))) * 0.5;
        bp_prev1 = p0;
        bp_prev2 = p0;
        peak_prev = p0;
        valley_prev = p0;
        price_prev1 = p0;
        price_prev2 = p0;
    }

    // Number of bars processed since `first`.
    let mut count = 0usize;

    while i < len {
        let price = ((*hi_ptr.add(i)) + (*lo_ptr.add(i))) * 0.5;

        // Two-pole band-pass once two bars of history exist; raw price before that.
        let bp_curr = if count >= 2 {
            half_one_minus_alpha * (price - price_prev2) + beta_times_one_plus_alpha * bp_prev1
                - alpha * bp_prev2
        } else {
            price
        };

        // bp_prev1 is a strict local max/min if it beats both neighbours;
        // otherwise the last peak/valley carries forward.
        let mut peak_curr = peak_prev;
        let mut valley_curr = valley_prev;
        if count >= 2 {
            if bp_prev1 > bp_curr && bp_prev1 > bp_prev2 {
                peak_curr = bp_prev1;
            }
            if bp_prev1 < bp_curr && bp_prev1 < bp_prev2 {
                valley_curr = bp_prev1;
            }
        }

        let sp = peak_curr * fraction;
        let sv = valley_curr * fraction;

        // Swap the oldest window entries for the new ones and adjust the running sums.
        let old_sp = *sp_ring.get_unchecked(idx_ul);
        let old_sv = *sv_ring.get_unchecked(idx_ul);
        let old_bp = *bp_ring.get_unchecked(idx_mid);

        *sp_ring.get_unchecked_mut(idx_ul) = sp;
        *sv_ring.get_unchecked_mut(idx_ul) = sv;
        *bp_ring.get_unchecked_mut(idx_mid) = bp_curr;

        sum_up += sp - old_sp;
        sum_low += sv - old_sv;
        sum_mb += bp_curr - old_bp;

        idx_ul += 1;
        if idx_ul == per_up_low {
            idx_ul = 0;
        }
        idx_mid += 1;
        if idx_mid == per_mid {
            idx_mid = 0;
        }

        // Emit each average only once its window is completely filled.
        let filled = count + 1;
        if filled >= per_up_low {
            *ub_ptr.add(i) = sum_up * inv_up_low;
            *lb_ptr.add(i) = sum_low * inv_up_low;
        }
        if filled >= per_mid {
            *mb_ptr.add(i) = sum_mb * inv_mid;
        }

        bp_prev2 = bp_prev1;
        bp_prev1 = bp_curr;
        peak_prev = peak_curr;
        valley_prev = valley_curr;
        price_prev2 = price_prev1;
        price_prev1 = price;

        count += 1;
        i += 1;
    }
}
606
607#[inline(always)]
608fn emd_compute_from_prices_into(
609    prices: &[f64],
610    period: usize,
611    delta: f64,
612    fraction: f64,
613    first: usize,
614    kernel: Kernel,
615    ub: &mut [f64],
616    mb: &mut [f64],
617    lb: &mut [f64],
618) {
619    unsafe {
620        match kernel {
621            Kernel::Scalar | Kernel::ScalarBatch => {
622                emd_scalar_prices_into(prices, period, delta, fraction, first, ub, mb, lb)
623            }
624            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
625            Kernel::Avx2 | Kernel::Avx2Batch => {
626                emd_scalar_prices_into(prices, period, delta, fraction, first, ub, mb, lb)
627            }
628            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
629            Kernel::Avx512 | Kernel::Avx512Batch => {
630                emd_scalar_prices_into(prices, period, delta, fraction, first, ub, mb, lb)
631            }
632            _ => unreachable!(),
633        }
634    }
635}
636
/// Scalar EMD kernel over a precomputed price series.
///
/// Identical to `emd_scalar_into` except that each bar's price is read directly
/// from `prices[i]` rather than computed as `(high + low) / 2`. Entries before
/// the warm-up points (50 bars for `ub`/`lb`, `2 * period` bars for `mb`,
/// counted from `first`) are not written.
///
/// # Safety
/// - `ub`, `mb` and `lb` must have length `prices.len()` (only debug-checked).
/// - `period` must be non-zero and `2 * period` must not overflow; otherwise
///   `bp_ring` is empty and `get_unchecked(idx_mid)` is out of bounds.
#[inline]
unsafe fn emd_scalar_prices_into(
    prices: &[f64],
    period: usize,
    delta: f64,
    fraction: f64,
    first: usize,
    ub: &mut [f64],
    mb: &mut [f64],
    lb: &mut [f64],
) {
    let len = prices.len();
    debug_assert_eq!(ub.len(), len);
    debug_assert_eq!(mb.len(), len);
    debug_assert_eq!(lb.len(), len);

    // Averaging windows: fixed 50 bars for the outer bands, 2*period for the middle band.
    let per_up_low = 50usize;
    let per_mid = 2 * period;
    let inv_up_low = 1.0 / (per_up_low as f64);
    let inv_mid = 1.0 / (per_mid as f64);

    // Band-pass filter coefficients (same expressions as `emd_scalar_into`).
    let two_pi = core::f64::consts::PI * 2.0;
    let beta = (two_pi / (period as f64)).cos();
    let gamma = 1.0 / ((two_pi * 2.0 * delta / (period as f64)).cos());
    let alpha = gamma - (gamma * gamma - 1.0).sqrt();
    let half_one_minus_alpha = 0.5 * (1.0 - alpha);
    let beta_times_one_plus_alpha = beta * (1.0 + alpha);

    // Zero-initialised ring buffers backing the running-sum moving averages.
    let mut sp_ring = vec![0.0f64; per_up_low];
    let mut sv_ring = vec![0.0f64; per_up_low];
    let mut bp_ring = vec![0.0f64; per_mid];
    let mut idx_ul = 0usize;
    let mut idx_mid = 0usize;

    let mut sum_up = 0.0f64;
    let mut sum_low = 0.0f64;
    let mut sum_mb = 0.0f64;

    let mut bp_prev1 = 0.0f64;
    let mut bp_prev2 = 0.0f64;
    let mut peak_prev = 0.0f64;
    let mut valley_prev = 0.0f64;

    let mut price_prev1 = 0.0f64;
    let mut price_prev2 = 0.0f64;

    let pr_ptr = prices.as_ptr();
    let ub_ptr = ub.as_mut_ptr();
    let mb_ptr = mb.as_mut_ptr();
    let lb_ptr = lb.as_mut_ptr();

    // Seed all filter history with the first price at `first`.
    let mut i = first;
    if i < len {
        let p0 = *pr_ptr.add(i);
        bp_prev1 = p0;
        bp_prev2 = p0;
        peak_prev = p0;
        valley_prev = p0;
        price_prev1 = p0;
        price_prev2 = p0;
    }

    // Number of bars processed since `first`.
    let mut count = 0usize;
    while i < len {
        let price = *pr_ptr.add(i);

        // Two-pole band-pass once two bars of history exist; raw price before that.
        let bp_curr = if count >= 2 {
            half_one_minus_alpha * (price - price_prev2) + beta_times_one_plus_alpha * bp_prev1
                - alpha * bp_prev2
        } else {
            price
        };

        // bp_prev1 becomes the new peak/valley if it is a strict local extremum.
        let mut peak_curr = peak_prev;
        let mut valley_curr = valley_prev;
        if count >= 2 {
            if bp_prev1 > bp_curr && bp_prev1 > bp_prev2 {
                peak_curr = bp_prev1;
            }
            if bp_prev1 < bp_curr && bp_prev1 < bp_prev2 {
                valley_curr = bp_prev1;
            }
        }
        let sp = peak_curr * fraction;
        let sv = valley_curr * fraction;

        // Replace the oldest window entries and adjust the running sums.
        let old_sp = *sp_ring.get_unchecked(idx_ul);
        let old_sv = *sv_ring.get_unchecked(idx_ul);
        let old_bp = *bp_ring.get_unchecked(idx_mid);

        *sp_ring.get_unchecked_mut(idx_ul) = sp;
        *sv_ring.get_unchecked_mut(idx_ul) = sv;
        *bp_ring.get_unchecked_mut(idx_mid) = bp_curr;

        sum_up += sp - old_sp;
        sum_low += sv - old_sv;
        sum_mb += bp_curr - old_bp;

        idx_ul += 1;
        if idx_ul == per_up_low {
            idx_ul = 0;
        }
        idx_mid += 1;
        if idx_mid == per_mid {
            idx_mid = 0;
        }

        // Emit each average only once its window is completely filled.
        let filled = count + 1;
        if filled >= per_up_low {
            *ub_ptr.add(i) = sum_up * inv_up_low;
            *lb_ptr.add(i) = sum_low * inv_up_low;
        }
        if filled >= per_mid {
            *mb_ptr.add(i) = sum_mb * inv_mid;
        }

        bp_prev2 = bp_prev1;
        bp_prev1 = bp_curr;
        peak_prev = peak_curr;
        valley_prev = valley_curr;
        price_prev2 = price_prev1;
        price_prev1 = price;

        count += 1;
        i += 1;
    }
}
764
765#[inline]
766pub unsafe fn emd_scalar(
767    high: &[f64],
768    low: &[f64],
769    period: usize,
770    delta: f64,
771    fraction: f64,
772    first: usize,
773    len: usize,
774) -> Result<EmdOutput, EmdError> {
775    let per_up_low = 50;
776    let per_mid = 2 * period;
777    let upperband_warmup = first + per_up_low - 1;
778    let middleband_warmup = first + per_mid - 1;
779
780    let mut upperband = alloc_with_nan_prefix(len, upperband_warmup);
781    let mut middleband = alloc_with_nan_prefix(len, middleband_warmup);
782    let mut lowerband = alloc_with_nan_prefix(len, upperband_warmup);
783
784    emd_scalar_into(
785        high,
786        low,
787        period,
788        delta,
789        fraction,
790        first,
791        &mut upperband,
792        &mut middleband,
793        &mut lowerband,
794    );
795
796    Ok(EmdOutput {
797        upperband,
798        middleband,
799        lowerband,
800    })
801}
802
/// AVX2 entry point. There is no AVX2-specific implementation yet, so this
/// forwards its arguments unchanged to `emd_scalar` and returns its result.
///
/// # Safety
/// Same contract as `emd_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn emd_avx2(
    high: &[f64],
    low: &[f64],
    period: usize,
    delta: f64,
    fraction: f64,
    first: usize,
    len: usize,
) -> Result<EmdOutput, EmdError> {
    // Scalar fallback (see doc comment).
    self::emd_scalar(high, low, period, delta, fraction, first, len)
}
816
/// AVX-512 entry point. Routes periods up to 32 to `emd_avx512_short` and
/// longer periods to `emd_avx512_long`.
///
/// # Safety
/// Same contract as the routed-to functions.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn emd_avx512(
    high: &[f64],
    low: &[f64],
    period: usize,
    delta: f64,
    fraction: f64,
    first: usize,
    len: usize,
) -> Result<EmdOutput, EmdError> {
    match period {
        0..=32 => emd_avx512_short(high, low, period, delta, fraction, first, len),
        _ => emd_avx512_long(high, low, period, delta, fraction, first, len),
    }
}
834
/// AVX-512 variant for short periods (`period <= 32`). There is no specialised
/// implementation yet, so this forwards unchanged to `emd_scalar`.
///
/// # Safety
/// Same contract as `emd_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn emd_avx512_short(
    high: &[f64],
    low: &[f64],
    period: usize,
    delta: f64,
    fraction: f64,
    first: usize,
    len: usize,
) -> Result<EmdOutput, EmdError> {
    // Scalar fallback (see doc comment).
    self::emd_scalar(high, low, period, delta, fraction, first, len)
}
848
/// AVX-512 variant for long periods (`period > 32`). There is no specialised
/// implementation yet, so this forwards unchanged to `emd_scalar`.
///
/// # Safety
/// Same contract as `emd_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn emd_avx512_long(
    high: &[f64],
    low: &[f64],
    period: usize,
    delta: f64,
    fraction: f64,
    first: usize,
    len: usize,
) -> Result<EmdOutput, EmdError> {
    // Scalar fallback (see doc comment).
    self::emd_scalar(high, low, period, delta, fraction, first, len)
}
862
863#[inline(always)]
864pub fn emd_row_scalar(
865    high: &[f64],
866    low: &[f64],
867    period: usize,
868    delta: f64,
869    fraction: f64,
870    first: usize,
871    len: usize,
872) -> Result<EmdOutput, EmdError> {
873    unsafe { emd_scalar(high, low, period, delta, fraction, first, len) }
874}
875
/// Safe single-row AVX2 EMD. Forwards every argument unchanged to `emd_avx2`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn emd_row_avx2(
    high: &[f64],
    low: &[f64],
    period: usize,
    delta: f64,
    fraction: f64,
    first: usize,
    len: usize,
) -> Result<EmdOutput, EmdError> {
    // NOTE(review): soundness depends on the scalar path validating `len` and
    // `period`; see `emd_scalar`.
    unsafe { self::emd_avx2(high, low, period, delta, fraction, first, len) }
}
889
/// Safe single-row AVX-512 EMD. Forwards every argument unchanged to `emd_avx512`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn emd_row_avx512(
    high: &[f64],
    low: &[f64],
    period: usize,
    delta: f64,
    fraction: f64,
    first: usize,
    len: usize,
) -> Result<EmdOutput, EmdError> {
    // NOTE(review): soundness depends on the scalar path validating `len` and
    // `period`; see `emd_scalar`.
    unsafe { self::emd_avx512(high, low, period, delta, fraction, first, len) }
}
903
/// Safe single-row EMD through the short-period AVX-512 entry point.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn emd_row_avx512_short(
    high: &[f64],
    low: &[f64],
    period: usize,
    delta: f64,
    fraction: f64,
    first: usize,
    len: usize,
) -> Result<EmdOutput, EmdError> {
    // NOTE(review): soundness depends on the scalar path validating `len` and
    // `period`; see `emd_scalar`.
    unsafe { self::emd_avx512_short(high, low, period, delta, fraction, first, len) }
}
917
/// Safe single-row EMD through the long-period AVX-512 entry point.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn emd_row_avx512_long(
    high: &[f64],
    low: &[f64],
    period: usize,
    delta: f64,
    fraction: f64,
    first: usize,
    len: usize,
) -> Result<EmdOutput, EmdError> {
    // NOTE(review): soundness depends on the scalar path validating `len` and
    // `period`; see `emd_scalar`.
    unsafe { self::emd_avx512_long(high, low, period, delta, fraction, first, len) }
}
931
/// Incremental (bar-by-bar) EMD calculator, created with `EmdStream::try_new`
/// and fed through `EmdStream::update`.
#[derive(Debug, Clone)]
pub struct EmdStream {
    // Configured parameters.
    period: usize,
    delta: f64,
    fraction: f64,
    // Averaging windows: 50 bars for the upper/lower bands, 2 * period for the middle band.
    per_up_low: usize,
    per_mid: usize,

    // Reciprocals of the window lengths, so each average is one multiply.
    inv_up_low: f64,
    inv_mid: f64,
    // Running sums over the ring buffers below.
    sum_up: f64,
    sum_low: f64,
    sum_mb: f64,
    // Ring buffers of scaled peaks, scaled valleys and band-pass outputs.
    sp_ring: Vec<f64>,
    sv_ring: Vec<f64>,
    bp_ring: Vec<f64>,
    // Next write positions in the ring buffers.
    idx_up_low: usize,
    idx_mid: usize,
    // Last two band-pass outputs, current peak/valley, and last two prices.
    bp_prev1: f64,
    bp_prev2: f64,
    peak_prev: f64,
    valley_prev: f64,
    price_prev1: f64,
    price_prev2: f64,
    // Band-pass filter coefficients.
    alpha: f64,
    beta: f64,

    // Precomputed `beta * (1 + alpha)` and `0.5 * (1 - alpha)`.
    beta_times_one_plus_alpha: f64,
    half_one_minus_alpha: f64,
    // Whether the filter history has been seeded from a first price.
    initialized: bool,
    // Number of bars processed by `update`.
    count: usize,
}
964
965impl EmdStream {
966    pub fn try_new(params: EmdParams) -> Result<Self, EmdError> {
967        let period = params.period.unwrap_or(20);
968        let delta = params.delta.unwrap_or(0.5);
969        let fraction = params.fraction.unwrap_or(0.1);
970
971        if period == 0 {
972            return Err(EmdError::InvalidPeriod {
973                period,
974                data_len: 0,
975            });
976        }
977        if delta.is_nan() || delta.is_infinite() {
978            return Err(EmdError::InvalidDelta { delta });
979        }
980        if fraction.is_nan() || fraction.is_infinite() {
981            return Err(EmdError::InvalidFraction { fraction });
982        }
983
984        let two_pi_over_p = 2.0 * std::f64::consts::PI / (period as f64);
985        let beta = (two_pi_over_p).cos();
986        let gamma = 1.0 / ((2.0 * delta * two_pi_over_p).cos());
987        let alpha = gamma - (gamma * gamma - 1.0).sqrt();
988        let half_one_minus_alpha = 0.5 * (1.0 - alpha);
989        let beta_times_one_plus_alpha = beta * (1.0 + alpha);
990        let per_up_low = 50usize;
991        let per_mid = 2 * period;
992
993        Ok(Self {
994            period,
995            delta,
996            fraction,
997            per_up_low,
998            per_mid,
999            inv_up_low: 1.0 / (per_up_low as f64),
1000            inv_mid: 1.0 / (per_mid as f64),
1001            sum_up: 0.0,
1002            sum_low: 0.0,
1003            sum_mb: 0.0,
1004            sp_ring: vec![0.0; per_up_low],
1005            sv_ring: vec![0.0; per_up_low],
1006            bp_ring: vec![0.0; per_mid],
1007            idx_up_low: 0,
1008            idx_mid: 0,
1009            bp_prev1: 0.0,
1010            bp_prev2: 0.0,
1011            peak_prev: 0.0,
1012            valley_prev: 0.0,
1013            price_prev1: 0.0,
1014            price_prev2: 0.0,
1015            alpha,
1016            beta,
1017            beta_times_one_plus_alpha,
1018            half_one_minus_alpha,
1019            initialized: false,
1020            count: 0,
1021        })
1022    }
1023
1024    #[inline(always)]
    /// Feeds one bar into the streaming EMD and returns `(upperband, middleband, lowerband)`.
    ///
    /// The working price is the bar midpoint `(high + low) * 0.5`. The upper and lower bands
    /// are `Some` once `per_up_low` bars have been fed (including this one), and the middle
    /// band once `per_mid` bars have been fed; until then the corresponding entry is `None`.
    pub fn update(&mut self, high: f64, low: f64) -> (Option<f64>, Option<f64>, Option<f64>) {
        let price = (high + low) * 0.5;

        // First call only: seed every history slot with the current price so the
        // recursions below always read defined predecessors.
        if !self.initialized {
            self.bp_prev1 = price;
            self.bp_prev2 = price;
            self.peak_prev = price;
            self.valley_prev = price;
            self.price_prev1 = price;
            self.price_prev2 = price;
            self.initialized = true;
        }
        // Band-pass step:
        //   bp = k1 * (price - price[t-2]) + k2 * bp[t-1] - alpha * bp[t-2]
        // where, judging by the field names, k1 = 0.5 * (1 - alpha) and
        // k2 = beta * (1 + alpha) (precomputed elsewhere — confirm in the constructor).
        // For the first two bars price[t-2] is not yet real, so the raw price is used.
        let bp_curr = if self.count >= 2 {
            self.half_one_minus_alpha * (price - self.price_prev2)
                + self
                    .beta_times_one_plus_alpha
                    .mul_add(self.bp_prev1, -self.alpha * self.bp_prev2)
        } else {
            price
        };
        // Carry the last detected peak/valley forward, replacing it only when the previous
        // band-pass value is a strict local maximum/minimum versus its two neighbours.
        let mut peak_curr = self.peak_prev;
        let mut valley_curr = self.valley_prev;
        if self.count >= 2 {
            if self.bp_prev1 > bp_curr && self.bp_prev1 > self.bp_prev2 {
                peak_curr = self.bp_prev1;
            }
            if self.bp_prev1 < bp_curr && self.bp_prev1 < self.bp_prev2 {
                valley_curr = self.bp_prev1;
            }
        }
        // Scale the latched extremes by `fraction` before they are averaged.
        let sp = peak_curr * self.fraction;
        let sv = valley_curr * self.fraction;

        // Rolling sums over fixed-size rings: subtract the slot being overwritten and add
        // the new value, so each moving sum is updated in O(1). The upper/lower rings share
        // one write index; the middle ring has its own. (The first pass of subtractions
        // assumes the rings start zero-filled — TODO confirm in the constructor.)
        let old_sp = self.sp_ring[self.idx_up_low];
        let old_sv = self.sv_ring[self.idx_up_low];
        let old_bp = self.bp_ring[self.idx_mid];
        self.sum_up += sp - old_sp;
        self.sum_low += sv - old_sv;
        self.sum_mb += bp_curr - old_bp;
        self.sp_ring[self.idx_up_low] = sp;
        self.sv_ring[self.idx_up_low] = sv;
        self.bp_ring[self.idx_mid] = bp_curr;

        // Advance the circular write positions, wrapping at each window length.
        self.idx_up_low += 1;
        if self.idx_up_low == self.per_up_low {
            self.idx_up_low = 0;
        }
        self.idx_mid += 1;
        if self.idx_mid == self.per_mid {
            self.idx_mid = 0;
        }
        // Emit a band only once its window is fully populated (`count + 1` bars seen,
        // counting this one). `inv_*` are presumably the reciprocal window lengths, turning
        // the sums into means — confirm in the constructor.
        let mut ub = None;
        let mut lb = None;
        let mut mb = None;
        if self.count + 1 >= self.per_up_low {
            ub = Some(self.sum_up * self.inv_up_low);
            lb = Some(self.sum_low * self.inv_up_low);
        }
        if self.count + 1 >= self.per_mid {
            mb = Some(self.sum_mb * self.inv_mid);
        }
        // Shift the histories so the next call sees this bar as t-1.
        self.bp_prev2 = self.bp_prev1;
        self.bp_prev1 = bp_curr;
        self.peak_prev = peak_curr;
        self.valley_prev = valley_curr;
        self.price_prev2 = self.price_prev1;
        self.price_prev1 = price;
        self.count += 1;
        (ub, mb, lb)
    }
1095}
1096
/// Parameter sweep for a batch EMD run.
///
/// Every field is an inclusive `(start, end, step)` axis; a `step` of zero (or
/// `start == end`) selects just `start`. The batch expands the cartesian product of
/// the three axes into one parameter combination per output row.
#[derive(Clone, Debug)]
pub struct EmdBatchRange {
    /// Period axis.
    pub period: (usize, usize, usize),
    /// Delta axis (copied into each combination's `delta`).
    pub delta: (f64, f64, f64),
    /// Fraction axis (copied into each combination's `fraction`).
    pub fraction: (f64, f64, f64),
}

impl Default for EmdBatchRange {
    /// Periods 20 through 269 in unit steps, with delta pinned at 0.5 and fraction at 0.1.
    fn default() -> Self {
        let period = (20, 269, 1);
        let delta = (0.5, 0.5, 0.0);
        let fraction = (0.1, 0.1, 0.0);
        EmdBatchRange {
            period,
            delta,
            fraction,
        }
    }
}
1113
/// Fluent builder for batch EMD sweeps.
///
/// Collects an [`EmdBatchRange`] and a [`Kernel`], then runs them through
/// [`emd_batch_with_kernel`] via the `apply_*` methods.
#[derive(Clone, Debug, Default)]
pub struct EmdBatchBuilder {
    // Parameter sweep that will be expanded into per-row combinations.
    range: EmdBatchRange,
    // Kernel requested for the batch run (must be a batch kernel or `Kernel::Auto`).
    kernel: Kernel,
}
1119
1120impl EmdBatchBuilder {
1121    pub fn new() -> Self {
1122        Self::default()
1123    }
1124    pub fn kernel(mut self, k: Kernel) -> Self {
1125        self.kernel = k;
1126        self
1127    }
1128    #[inline]
1129    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
1130        self.range.period = (start, end, step);
1131        self
1132    }
1133    #[inline]
1134    pub fn period_static(mut self, p: usize) -> Self {
1135        self.range.period = (p, p, 0);
1136        self
1137    }
1138    #[inline]
1139    pub fn delta_range(mut self, start: f64, end: f64, step: f64) -> Self {
1140        self.range.delta = (start, end, step);
1141        self
1142    }
1143    #[inline]
1144    pub fn delta_static(mut self, x: f64) -> Self {
1145        self.range.delta = (x, x, 0.0);
1146        self
1147    }
1148    #[inline]
1149    pub fn fraction_range(mut self, start: f64, end: f64, step: f64) -> Self {
1150        self.range.fraction = (start, end, step);
1151        self
1152    }
1153    #[inline]
1154    pub fn fraction_static(mut self, x: f64) -> Self {
1155        self.range.fraction = (x, x, 0.0);
1156        self
1157    }
1158    pub fn apply_slices(
1159        self,
1160        high: &[f64],
1161        low: &[f64],
1162        close: &[f64],
1163        volume: &[f64],
1164    ) -> Result<EmdBatchOutput, EmdError> {
1165        emd_batch_with_kernel(high, low, &self.range, self.kernel)
1166    }
1167    pub fn with_default_slices(
1168        high: &[f64],
1169        low: &[f64],
1170        close: &[f64],
1171        volume: &[f64],
1172        k: Kernel,
1173    ) -> Result<EmdBatchOutput, EmdError> {
1174        EmdBatchBuilder::new()
1175            .kernel(k)
1176            .apply_slices(high, low, close, volume)
1177    }
1178    pub fn apply_candles(self, c: &Candles) -> Result<EmdBatchOutput, EmdError> {
1179        let high = source_type(c, "high");
1180        let low = source_type(c, "low");
1181        let close = source_type(c, "close");
1182        let volume = source_type(c, "volume");
1183        self.apply_slices(high, low, close, volume)
1184    }
1185}
1186
1187pub fn emd_batch_with_kernel(
1188    high: &[f64],
1189    low: &[f64],
1190    sweep: &EmdBatchRange,
1191    k: Kernel,
1192) -> Result<EmdBatchOutput, EmdError> {
1193    let kernel = match k {
1194        Kernel::Auto => Kernel::ScalarBatch,
1195        other if other.is_batch() => other,
1196        _ => {
1197            return Err(EmdError::InvalidKernelForBatch(k));
1198        }
1199    };
1200    emd_batch_par_slice(high, low, sweep, kernel)
1201}
1202
/// Result of a batch EMD sweep.
///
/// Each band is a row-major `rows x cols` matrix: row `r` holds the series computed for
/// `combos[r]` and occupies indices `r * cols .. (r + 1) * cols`.
#[derive(Clone, Debug)]
pub struct EmdBatchOutput {
    /// Upper band values, row-major.
    pub upperband: Vec<f64>,
    /// Middle band values, row-major.
    pub middleband: Vec<f64>,
    /// Lower band values, row-major.
    pub lowerband: Vec<f64>,
    /// Parameter combination used for each row, in row order.
    pub combos: Vec<EmdParams>,
    /// Number of rows (one per parameter combination).
    pub rows: usize,
    /// Number of columns (the input length).
    pub cols: usize,
}
1212
1213#[inline(always)]
1214fn expand_grid(r: &EmdBatchRange) -> Result<Vec<EmdParams>, EmdError> {
1215    fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, EmdError> {
1216        if step == 0 || start == end {
1217            return Ok(vec![start]);
1218        }
1219        let mut v = Vec::new();
1220        if start < end {
1221            let mut cur = start;
1222            while cur <= end {
1223                v.push(cur);
1224                cur = match cur.checked_add(step) {
1225                    Some(n) => n,
1226                    None => break,
1227                };
1228            }
1229        } else {
1230            let mut cur = start;
1231            while cur >= end {
1232                v.push(cur);
1233                match cur.checked_sub(step) {
1234                    Some(n) => cur = n,
1235                    None => break,
1236                }
1237            }
1238        }
1239        if v.is_empty() {
1240            Err(EmdError::InvalidRangeU { start, end, step })
1241        } else {
1242            Ok(v)
1243        }
1244    }
1245    fn axis_f64((start, end, step): (f64, f64, f64)) -> Result<Vec<f64>, EmdError> {
1246        if step.abs() < 1e-12 || (start - end).abs() < 1e-12 {
1247            return Ok(vec![start]);
1248        }
1249        let mut v = Vec::new();
1250        if start < end {
1251            let mut x = start;
1252            while x <= end + 1e-12 {
1253                v.push(x);
1254                x += step;
1255                if !x.is_finite() {
1256                    break;
1257                }
1258            }
1259        } else {
1260            let mut x = start;
1261            while x >= end - 1e-12 {
1262                v.push(x);
1263                x -= step.abs();
1264                if !x.is_finite() {
1265                    break;
1266                }
1267            }
1268        }
1269        if v.is_empty() {
1270            Err(EmdError::InvalidRangeF { start, end, step })
1271        } else {
1272            Ok(v)
1273        }
1274    }
1275
1276    let periods = axis_usize(r.period)?;
1277    let deltas = axis_f64(r.delta)?;
1278    let fractions = axis_f64(r.fraction)?;
1279
1280    let cap = periods
1281        .len()
1282        .checked_mul(deltas.len())
1283        .and_then(|t| t.checked_mul(fractions.len()))
1284        .ok_or(EmdError::InvalidRangeU {
1285            start: 0,
1286            end: 0,
1287            step: 0,
1288        })?;
1289    let mut out = Vec::with_capacity(cap);
1290    for &p in &periods {
1291        for &d in &deltas {
1292            for &f in &fractions {
1293                out.push(EmdParams {
1294                    period: Some(p),
1295                    delta: Some(d),
1296                    fraction: Some(f),
1297                });
1298            }
1299        }
1300    }
1301    Ok(out)
1302}
1303
1304#[inline(always)]
1305pub fn emd_batch_slice(
1306    high: &[f64],
1307    low: &[f64],
1308    sweep: &EmdBatchRange,
1309    kern: Kernel,
1310) -> Result<EmdBatchOutput, EmdError> {
1311    if !kern.is_batch() && kern != Kernel::Auto {
1312        return Err(EmdError::InvalidKernelForBatch(kern));
1313    }
1314    emd_batch_inner(high, low, sweep, kern, false)
1315}
1316
1317#[inline(always)]
1318pub fn emd_batch_par_slice(
1319    high: &[f64],
1320    low: &[f64],
1321    sweep: &EmdBatchRange,
1322    kern: Kernel,
1323) -> Result<EmdBatchOutput, EmdError> {
1324    if !kern.is_batch() && kern != Kernel::Auto {
1325        return Err(EmdError::InvalidKernelForBatch(kern));
1326    }
1327    emd_batch_inner(high, low, sweep, kern, true)
1328}
1329
1330#[inline(always)]
1331fn emd_batch_inner(
1332    high: &[f64],
1333    low: &[f64],
1334    sweep: &EmdBatchRange,
1335    kern: Kernel,
1336    parallel: bool,
1337) -> Result<EmdBatchOutput, EmdError> {
1338    let combos = expand_grid(sweep)?;
1339    if combos.is_empty() {
1340        return Err(EmdError::InvalidRangeU {
1341            start: 0,
1342            end: 0,
1343            step: 0,
1344        });
1345    }
1346
1347    let len = high.len();
1348    if low.len() != len {
1349        return Err(EmdError::InvalidInputLength {
1350            expected: len,
1351            actual: low.len(),
1352        });
1353    }
1354
1355    let first = (0..len)
1356        .find(|&i| !high[i].is_nan() && !low[i].is_nan())
1357        .ok_or(EmdError::AllValuesNaN)?;
1358    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1359    let needed = (2 * max_p).max(50);
1360    if len - first < needed {
1361        return Err(EmdError::NotEnoughValidData {
1362            needed,
1363            valid: len - first,
1364        });
1365    }
1366
1367    let rows = combos.len();
1368    let cols = len;
1369    let total = rows.checked_mul(cols).ok_or(EmdError::InvalidRangeU {
1370        start: rows,
1371        end: cols,
1372        step: usize::MAX,
1373    })?;
1374
1375    let prices: Vec<f64> = high
1376        .iter()
1377        .zip(low.iter())
1378        .map(|(&h, &l)| (h + l) * 0.5)
1379        .collect();
1380
1381    let mut ub_mu = make_uninit_matrix(rows, cols);
1382    let mut mb_mu = make_uninit_matrix(rows, cols);
1383    let mut lb_mu = make_uninit_matrix(rows, cols);
1384
1385    let warm_up_low: Vec<usize> = combos.iter().map(|_| first + 50 - 1).collect();
1386    let warm_mid: Vec<usize> = combos
1387        .iter()
1388        .map(|c| first + 2 * c.period.unwrap() - 1)
1389        .collect();
1390
1391    init_matrix_prefixes(&mut ub_mu, cols, &warm_up_low);
1392    init_matrix_prefixes(&mut mb_mu, cols, &warm_mid);
1393    init_matrix_prefixes(&mut lb_mu, cols, &warm_up_low);
1394
1395    let ub_ptr = ub_mu.as_mut_ptr() as *mut f64 as usize;
1396    let mb_ptr = mb_mu.as_mut_ptr() as *mut f64 as usize;
1397    let lb_ptr = lb_mu.as_mut_ptr() as *mut f64 as usize;
1398
1399    let simd = match kern {
1400        Kernel::Auto => match detect_best_batch_kernel() {
1401            Kernel::Avx512Batch => Kernel::Avx512,
1402            Kernel::Avx2Batch => Kernel::Avx2,
1403            Kernel::ScalarBatch => Kernel::Scalar,
1404            _ => Kernel::Scalar,
1405        },
1406        Kernel::Avx512 | Kernel::Avx512Batch => Kernel::Avx512,
1407        Kernel::Avx2 | Kernel::Avx2Batch => Kernel::Avx2,
1408        Kernel::Scalar | Kernel::ScalarBatch => Kernel::Scalar,
1409        _ => Kernel::Scalar,
1410    };
1411
1412    if parallel {
1413        #[cfg(not(target_arch = "wasm32"))]
1414        (0..rows).into_par_iter().for_each(|row| {
1415            let prm = &combos[row];
1416            let p = prm.period.unwrap();
1417            let d = prm.delta.unwrap();
1418            let f = prm.fraction.unwrap();
1419
1420            let ub = unsafe {
1421                std::slice::from_raw_parts_mut((ub_ptr as *mut f64).add(row * cols), cols)
1422            };
1423            let mb = unsafe {
1424                std::slice::from_raw_parts_mut((mb_ptr as *mut f64).add(row * cols), cols)
1425            };
1426            let lb = unsafe {
1427                std::slice::from_raw_parts_mut((lb_ptr as *mut f64).add(row * cols), cols)
1428            };
1429
1430            emd_compute_from_prices_into(&prices, p, d, f, first, simd, ub, mb, lb);
1431        });
1432        #[cfg(target_arch = "wasm32")]
1433        {
1434            let ub_rows = unsafe { std::slice::from_raw_parts_mut(ub_ptr as *mut f64, total) };
1435            let mb_rows = unsafe { std::slice::from_raw_parts_mut(mb_ptr as *mut f64, total) };
1436            let lb_rows = unsafe { std::slice::from_raw_parts_mut(lb_ptr as *mut f64, total) };
1437            for row in 0..rows {
1438                let prm = &combos[row];
1439                let p = prm.period.unwrap();
1440                let d = prm.delta.unwrap();
1441                let f = prm.fraction.unwrap();
1442
1443                let ub = &mut ub_rows[row * cols..(row + 1) * cols];
1444                let mb = &mut mb_rows[row * cols..(row + 1) * cols];
1445                let lb = &mut lb_rows[row * cols..(row + 1) * cols];
1446
1447                emd_compute_from_prices_into(&prices, p, d, f, first, simd, ub, mb, lb);
1448            }
1449        }
1450    } else {
1451        let ub_rows = unsafe { std::slice::from_raw_parts_mut(ub_ptr as *mut f64, total) };
1452        let mb_rows = unsafe { std::slice::from_raw_parts_mut(mb_ptr as *mut f64, total) };
1453        let lb_rows = unsafe { std::slice::from_raw_parts_mut(lb_ptr as *mut f64, total) };
1454        for row in 0..rows {
1455            let prm = &combos[row];
1456            let p = prm.period.unwrap();
1457            let d = prm.delta.unwrap();
1458            let f = prm.fraction.unwrap();
1459
1460            let ub = &mut ub_rows[row * cols..(row + 1) * cols];
1461            let mb = &mut mb_rows[row * cols..(row + 1) * cols];
1462            let lb = &mut lb_rows[row * cols..(row + 1) * cols];
1463
1464            emd_compute_from_prices_into(&prices, p, d, f, first, simd, ub, mb, lb);
1465        }
1466    }
1467
1468    let upperband = unsafe { Vec::from_raw_parts(ub_mu.as_mut_ptr() as *mut f64, total, total) };
1469    let middleband = unsafe { Vec::from_raw_parts(mb_mu.as_mut_ptr() as *mut f64, total, total) };
1470    let lowerband = unsafe { Vec::from_raw_parts(lb_mu.as_mut_ptr() as *mut f64, total, total) };
1471
1472    std::mem::forget(ub_mu);
1473    std::mem::forget(mb_mu);
1474    std::mem::forget(lb_mu);
1475
1476    Ok(EmdBatchOutput {
1477        upperband,
1478        middleband,
1479        lowerband,
1480        combos,
1481        rows,
1482        cols,
1483    })
1484}
1485
/// In-place batch EMD: writes results into caller-provided row-major `rows x cols`
/// buffers and returns the expanded parameter combinations (one per row, in row order).
///
/// `rows` is the number of combinations produced by [`expand_grid`] for `sweep`, and
/// `cols == high.len()`; each output slice must have exactly `rows * cols` elements.
/// Rows are computed by `emd_compute_into` directly from `high`/`low`, in parallel via
/// rayon when `parallel` is set on non-wasm32 targets.
///
/// # Errors
/// - [`EmdError::InvalidRangeU`] if the sweep yields no combinations or `rows * cols`
///   overflows (plus anything [`expand_grid`] reports).
/// - [`EmdError::InvalidInputLength`] if `low.len() != high.len()`.
/// - [`EmdError::OutputLengthMismatch`] if any output has the wrong length (`got` reports
///   the shortest of the three).
/// - [`EmdError::AllValuesNaN`] if no index has both a non-NaN high and a non-NaN low.
/// - [`EmdError::NotEnoughValidData`] if fewer than `max(2 * max_period, 50)` bars follow
///   the first valid one.
#[inline(always)]
fn emd_batch_inner_into(
    high: &[f64],
    low: &[f64],
    sweep: &EmdBatchRange,
    kern: Kernel,
    parallel: bool,
    upperband_out: &mut [f64],
    middleband_out: &mut [f64],
    lowerband_out: &mut [f64],
) -> Result<Vec<EmdParams>, EmdError> {
    let combos = expand_grid(sweep)?;
    if combos.is_empty() {
        return Err(EmdError::InvalidRangeU {
            start: 0,
            end: 0,
            step: 0,
        });
    }
    let len = high.len();
    if low.len() != len {
        return Err(EmdError::InvalidInputLength {
            expected: len,
            actual: low.len(),
        });
    }

    let rows = combos.len();
    let cols = len;
    let expected = rows.checked_mul(cols).ok_or(EmdError::InvalidRangeU {
        start: rows,
        end: cols,
        step: usize::MAX,
    })?;
    // Every output must be exactly `rows * cols`; the unsafe row slicing below relies on it.
    if upperband_out.len() != expected
        || middleband_out.len() != expected
        || lowerband_out.len() != expected
    {
        return Err(EmdError::OutputLengthMismatch {
            expected,
            got: upperband_out
                .len()
                .min(middleband_out.len())
                .min(lowerband_out.len()),
        });
    }

    let first = (0..len)
        .find(|&i| !high[i].is_nan() && !low[i].is_nan())
        .ok_or(EmdError::AllValuesNaN)?;
    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
    let needed = (2 * max_p).max(50);
    if len - first < needed {
        return Err(EmdError::NotEnoughValidData {
            needed,
            valid: len - first,
        });
    }

    // Temporarily view the output buffers as `MaybeUninit<f64>` so `init_matrix_prefixes`
    // can write each row's warm-up prefix. These views are confined to this scope and end
    // before the buffers are accessed again below.
    {
        let mut ub_mu = unsafe {
            std::slice::from_raw_parts_mut(
                upperband_out.as_mut_ptr() as *mut MaybeUninit<f64>,
                rows * cols,
            )
        };
        let mut mb_mu = unsafe {
            std::slice::from_raw_parts_mut(
                middleband_out.as_mut_ptr() as *mut MaybeUninit<f64>,
                rows * cols,
            )
        };
        let mut lb_mu = unsafe {
            std::slice::from_raw_parts_mut(
                lowerband_out.as_mut_ptr() as *mut MaybeUninit<f64>,
                rows * cols,
            )
        };

        // Warm-up prefix lengths per row, counted from `first`: a fixed 50 bars for the
        // upper/lower bands and `2 * period` bars for the middle band.
        let warm_up_low: Vec<usize> = combos.iter().map(|_| first + 50 - 1).collect();
        let warm_mid: Vec<usize> = combos
            .iter()
            .map(|c| first + 2 * c.period.unwrap() - 1)
            .collect();

        init_matrix_prefixes(&mut ub_mu, cols, &warm_up_low);
        init_matrix_prefixes(&mut mb_mu, cols, &warm_mid);
        init_matrix_prefixes(&mut lb_mu, cols, &warm_up_low);
    }

    // Base addresses stored as `usize` so the rayon closure is `Send + Sync`; each task
    // rebuilds a disjoint `cols`-long row slice from them.
    let ub_ptr = upperband_out.as_mut_ptr() as usize;
    let mb_ptr = middleband_out.as_mut_ptr() as usize;
    let lb_ptr = lowerband_out.as_mut_ptr() as usize;

    // Map the requested (batch) kernel to the per-row kernel; `Auto` uses runtime detection.
    let simd = match kern {
        Kernel::Auto => match detect_best_batch_kernel() {
            Kernel::Avx512Batch => Kernel::Avx512,
            Kernel::Avx2Batch => Kernel::Avx2,
            Kernel::ScalarBatch => Kernel::Scalar,
            _ => Kernel::Scalar,
        },
        Kernel::Avx512 | Kernel::Avx512Batch => Kernel::Avx512,
        Kernel::Avx2 | Kernel::Avx2Batch => Kernel::Avx2,
        Kernel::Scalar | Kernel::ScalarBatch => Kernel::Scalar,
        _ => Kernel::Scalar,
    };

    if parallel {
        // Rows are disjoint `cols`-long ranges of the validated `rows * cols` buffers, so
        // tasks for different rows never write the same memory.
        #[cfg(not(target_arch = "wasm32"))]
        (0..rows).into_par_iter().for_each(|row| {
            let prm = &combos[row];
            let p = prm.period.unwrap();
            let d = prm.delta.unwrap();
            let f = prm.fraction.unwrap();

            let ub = unsafe {
                std::slice::from_raw_parts_mut((ub_ptr as *mut f64).add(row * cols), cols)
            };
            let mb = unsafe {
                std::slice::from_raw_parts_mut((mb_ptr as *mut f64).add(row * cols), cols)
            };
            let lb = unsafe {
                std::slice::from_raw_parts_mut((lb_ptr as *mut f64).add(row * cols), cols)
            };

            emd_compute_into(high, low, p, d, f, first, simd, ub, mb, lb);
        });
        // wasm32 has no rayon: the parallel request falls back to a sequential loop.
        #[cfg(target_arch = "wasm32")]
        for row in 0..rows {
            let prm = &combos[row];
            let p = prm.period.unwrap();
            let d = prm.delta.unwrap();
            let f = prm.fraction.unwrap();

            let ub = &mut upperband_out[row * cols..(row + 1) * cols];
            let mb = &mut middleband_out[row * cols..(row + 1) * cols];
            let lb = &mut lowerband_out[row * cols..(row + 1) * cols];

            emd_compute_into(high, low, p, d, f, first, simd, ub, mb, lb);
        }
    } else {
        for row in 0..rows {
            let prm = &combos[row];
            let p = prm.period.unwrap();
            let d = prm.delta.unwrap();
            let f = prm.fraction.unwrap();

            let ub = &mut upperband_out[row * cols..(row + 1) * cols];
            let mb = &mut middleband_out[row * cols..(row + 1) * cols];
            let lb = &mut lowerband_out[row * cols..(row + 1) * cols];

            emd_compute_into(high, low, p, d, f, first, simd, ub, mb, lb);
        }
    }

    Ok(combos)
}
1643
1644impl EmdBatchOutput {
1645    pub fn row_for_params(&self, p: &EmdParams) -> Option<usize> {
1646        self.combos.iter().position(|c| {
1647            c.period.unwrap_or(20) == p.period.unwrap_or(20)
1648                && (c.delta.unwrap_or(0.5) - p.delta.unwrap_or(0.5)).abs() < 1e-12
1649                && (c.fraction.unwrap_or(0.1) - p.fraction.unwrap_or(0.1)).abs() < 1e-12
1650        })
1651    }
1652    pub fn bands_for(&self, p: &EmdParams) -> Option<(&[f64], &[f64], &[f64])> {
1653        self.row_for_params(p).map(|row| {
1654            let start = row * self.cols;
1655            (
1656                &self.upperband[start..start + self.cols],
1657                &self.middleband[start..start + self.cols],
1658                &self.lowerband[start..start + self.cols],
1659            )
1660        })
1661    }
1662}
1663
1664#[cfg(test)]
1665mod tests {
1666    use super::*;
1667    use crate::skip_if_unsupported;
1668    use crate::utilities::data_loader::read_candles_from_csv;
1669
1670    #[test]
1671    #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1672    fn test_emd_into_matches_api() -> Result<(), Box<dyn Error>> {
1673        let n = 256usize;
1674        let mut high = Vec::with_capacity(n);
1675        let mut low = Vec::with_capacity(n);
1676        for i in 0..n {
1677            let base = 100.0
1678                + (i as f64 * 0.01)
1679                + (2.0 * std::f64::consts::PI * (i as f64) / 17.0).sin() * 3.0
1680                + (2.0 * std::f64::consts::PI * (i as f64) / 49.0).cos() * 2.0;
1681            high.push(base + 1.25);
1682            low.push(base - 1.10);
1683        }
1684
1685        let params = EmdParams::default();
1686        let input = EmdInput::from_slices(&high, &low, &[], &[], params);
1687
1688        let baseline = emd(&input)?;
1689
1690        let mut ub = vec![0.0; n];
1691        let mut mb = vec![0.0; n];
1692        let mut lb = vec![0.0; n];
1693        emd_into(&input, &mut ub, &mut mb, &mut lb)?;
1694
1695        fn eq_or_both_nan(a: f64, b: f64) -> bool {
1696            (a.is_nan() && b.is_nan()) || (a == b)
1697        }
1698
1699        assert_eq!(baseline.upperband.len(), ub.len());
1700        assert_eq!(baseline.middleband.len(), mb.len());
1701        assert_eq!(baseline.lowerband.len(), lb.len());
1702        for i in 0..n {
1703            assert!(
1704                eq_or_both_nan(baseline.upperband[i], ub[i]),
1705                "upperband mismatch at {}: {:?} vs {:?}",
1706                i,
1707                baseline.upperband[i],
1708                ub[i]
1709            );
1710            assert!(
1711                eq_or_both_nan(baseline.middleband[i], mb[i]),
1712                "middleband mismatch at {}: {:?} vs {:?}",
1713                i,
1714                baseline.middleband[i],
1715                mb[i]
1716            );
1717            assert!(
1718                eq_or_both_nan(baseline.lowerband[i], lb[i]),
1719                "lowerband mismatch at {}: {:?} vs {:?}",
1720                i,
1721                baseline.lowerband[i],
1722                lb[i]
1723            );
1724        }
1725        Ok(())
1726    }
1727
1728    fn check_emd_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1729        skip_if_unsupported!(kernel, test_name);
1730        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1731        let candles = read_candles_from_csv(file_path)?;
1732
1733        let params = EmdParams::default();
1734        let input = EmdInput::from_candles(&candles, params);
1735        let emd_result = emd_with_kernel(&input, kernel)?;
1736
1737        let expected_last_five_upper = [
1738            50.33760237677157,
1739            50.28850695686447,
1740            50.23941153695737,
1741            50.19031611705027,
1742            48.709744457737344,
1743        ];
1744        let expected_last_five_middle = [
1745            -368.71064280396706,
1746            -399.11033986231377,
1747            -421.9368852621732,
1748            -437.879217150269,
1749            -447.3257167904511,
1750        ];
1751        let expected_last_five_lower = [
1752            -60.67834136221248,
1753            -60.93110347122829,
1754            -61.68154077026321,
1755            -62.43197806929814,
1756            -63.18241536833306,
1757        ];
1758
1759        let len = candles.close.len();
1760        let start_idx = len - 5;
1761        let actual_ub = &emd_result.upperband[start_idx..];
1762        let actual_mb = &emd_result.middleband[start_idx..];
1763        let actual_lb = &emd_result.lowerband[start_idx..];
1764        for i in 0..5 {
1765            assert!(
1766                (actual_ub[i] - expected_last_five_upper[i]).abs() < 1e-6,
1767                "Upperband mismatch at index {}: expected {}, got {}",
1768                i,
1769                expected_last_five_upper[i],
1770                actual_ub[i]
1771            );
1772            assert!(
1773                (actual_mb[i] - expected_last_five_middle[i]).abs() < 1e-6,
1774                "Middleband mismatch at index {}: expected {}, got {}",
1775                i,
1776                expected_last_five_middle[i],
1777                actual_mb[i]
1778            );
1779            assert!(
1780                (actual_lb[i] - expected_last_five_lower[i]).abs() < 1e-6,
1781                "Lowerband mismatch at index {}: expected {}, got {}",
1782                i,
1783                expected_last_five_lower[i],
1784                actual_lb[i]
1785            );
1786        }
1787        Ok(())
1788    }
1789
1790    fn check_emd_empty_data(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1791        skip_if_unsupported!(kernel, test_name);
1792        let empty_data: [f64; 0] = [];
1793        let params = EmdParams::default();
1794        let input =
1795            EmdInput::from_slices(&empty_data, &empty_data, &empty_data, &empty_data, params);
1796        let result = emd_with_kernel(&input, kernel);
1797        assert!(result.is_err(), "Expected error on empty data");
1798        Ok(())
1799    }
1800
1801    fn check_emd_all_nans(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1802        skip_if_unsupported!(kernel, test_name);
1803        let data = [f64::NAN, f64::NAN, f64::NAN];
1804        let params = EmdParams::default();
1805        let input = EmdInput::from_slices(&data, &data, &data, &data, params);
1806        let result = emd_with_kernel(&input, kernel);
1807        assert!(result.is_err(), "Expected error for all-NaN data");
1808        Ok(())
1809    }
1810
1811    fn check_emd_invalid_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1812        skip_if_unsupported!(kernel, test_name);
1813        let data = [1.0, 2.0, 3.0];
1814        let params = EmdParams {
1815            period: Some(0),
1816            ..Default::default()
1817        };
1818        let input = EmdInput::from_slices(&data, &data, &data, &data, params);
1819        let result = emd_with_kernel(&input, kernel);
1820        assert!(result.is_err(), "Expected error for zero period");
1821        Ok(())
1822    }
1823
1824    fn check_emd_not_enough_valid_data(
1825        test_name: &str,
1826        kernel: Kernel,
1827    ) -> Result<(), Box<dyn Error>> {
1828        skip_if_unsupported!(kernel, test_name);
1829        let data = vec![10.0; 10];
1830        let params = EmdParams {
1831            period: Some(20),
1832            ..Default::default()
1833        };
1834        let input = EmdInput::from_slices(&data, &data, &data, &data, params);
1835        let result = emd_with_kernel(&input, kernel);
1836        assert!(result.is_err(), "Expected error for not enough valid data");
1837        Ok(())
1838    }
1839
1840    fn check_emd_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1841        skip_if_unsupported!(kernel, test_name);
1842        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1843        let candles = read_candles_from_csv(file_path)?;
1844        let input = EmdInput::with_default_candles(&candles);
1845        let result = emd_with_kernel(&input, kernel);
1846        assert!(
1847            result.is_ok(),
1848            "Expected EMD to succeed with default params"
1849        );
1850        Ok(())
1851    }
1852
1853    #[cfg(feature = "proptest")]
1854    #[allow(clippy::float_cmp)]
1855    fn check_emd_property(
1856        test_name: &str,
1857        kernel: Kernel,
1858    ) -> Result<(), Box<dyn std::error::Error>> {
1859        use proptest::prelude::*;
1860        skip_if_unsupported!(kernel, test_name);
1861
1862        let strat1 = (2usize..=64).prop_flat_map(|period| {
1863            (
1864                prop::collection::vec(
1865                    (1f64..1000f64).prop_filter("finite", |x| x.is_finite()),
1866                    (2 * period).max(50)..400,
1867                ),
1868                Just(period),
1869                (0.1f64..1.0f64).prop_filter("finite", |x| x.is_finite()),
1870                (0.01f64..0.5f64).prop_filter("finite", |x| x.is_finite()),
1871            )
1872        });
1873
1874        let strat2 = prop::collection::vec(
1875            (100f64..10000f64).prop_filter("finite", |x| x.is_finite()),
1876            100..500,
1877        )
1878        .prop_map(|data| (data, 20usize, 0.5f64, 0.1f64));
1879
1880        let strat3 = (100usize..400, prop::bool::ANY).prop_map(|(len, increasing)| {
1881            let mut data = Vec::with_capacity(len);
1882            let mut val = 100.0;
1883            for _ in 0..len {
1884                data.push(val);
1885                val += if increasing { 1.0 } else { -1.0 };
1886            }
1887            (data, 14usize, 0.5f64, 0.1f64)
1888        });
1889
1890        let strat4 = (100usize..400, 5usize..50).prop_map(|(len, period_wave)| {
1891            let mut data = Vec::with_capacity(len);
1892            for i in 0..len {
1893                let val = 1000.0
1894                    + 100.0 * (2.0 * std::f64::consts::PI * i as f64 / period_wave as f64).sin();
1895                data.push(val);
1896            }
1897            (data, 20usize, 0.5f64, 0.1f64)
1898        });
1899
1900        let strat5 = (2usize..=30).prop_flat_map(|period| {
1901            let min_len = (2 * period).max(50);
1902            prop::collection::vec(
1903                (
1904                    50f64..150f64,
1905                    0.1f64..10f64,
1906                    -0.5f64..0.5f64,
1907                    0f64..0.5f64,
1908                    0f64..0.5f64,
1909                )
1910                    .prop_map(
1911                        |(base, range, close_offset, high_extra, low_extra)| {
1912                            let open = base;
1913                            let close = base + range * close_offset;
1914                            let high = open.max(close) + range * high_extra;
1915                            let low = open.min(close) - range * low_extra;
1916                            (high, low, close, base * 1000.0)
1917                        },
1918                    ),
1919                min_len..300,
1920            )
1921            .prop_map(move |ohlc_data| {
1922                let (highs, lows, closes, volumes): (Vec<_>, Vec<_>, Vec<_>, Vec<_>) =
1923                    ohlc_data.into_iter().unzip4();
1924                (highs, lows, closes, volumes, period, 0.5f64, 0.1f64)
1925            })
1926        });
1927
1928        trait Unzip4<A, B, C, D> {
1929            fn unzip4(self) -> (Vec<A>, Vec<B>, Vec<C>, Vec<D>);
1930        }
1931        impl<A, B, C, D, I: Iterator<Item = (A, B, C, D)>> Unzip4<A, B, C, D> for I {
1932            fn unzip4(self) -> (Vec<A>, Vec<B>, Vec<C>, Vec<D>) {
1933                let (mut a_vec, mut b_vec, mut c_vec, mut d_vec) =
1934                    (Vec::new(), Vec::new(), Vec::new(), Vec::new());
1935                for (a, b, c, d) in self {
1936                    a_vec.push(a);
1937                    b_vec.push(b);
1938                    c_vec.push(c);
1939                    d_vec.push(d);
1940                }
1941                (a_vec, b_vec, c_vec, d_vec)
1942            }
1943        }
1944
1945        let combined_strat = prop_oneof![
1946            strat1.prop_map(|(data, period, delta, fraction)| {
1947                let high = data.clone();
1948                let low = data.clone();
1949                let close = data.clone();
1950                let volume = vec![1000.0; data.len()];
1951                (high, low, close, volume, period, delta, fraction)
1952            }),
1953            strat2.prop_map(|(data, period, delta, fraction)| {
1954                let high = data.iter().map(|x| x + 10.0).collect();
1955                let low = data.iter().map(|x| x - 10.0).collect();
1956                let close = data.clone();
1957                let volume = vec![1000.0; data.len()];
1958                (high, low, close, volume, period, delta, fraction)
1959            }),
1960            strat3.prop_map(|(data, period, delta, fraction)| {
1961                let high = data.iter().map(|x| x + 5.0).collect();
1962                let low = data.iter().map(|x| x - 5.0).collect();
1963                let close = data.clone();
1964                let volume = vec![1000.0; data.len()];
1965                (high, low, close, volume, period, delta, fraction)
1966            }),
1967            strat4.prop_map(|(data, period, delta, fraction)| {
1968                let high = data.iter().map(|x| x + 20.0).collect();
1969                let low = data.iter().map(|x| x - 20.0).collect();
1970                let close = data.clone();
1971                let volume = vec![5000.0; data.len()];
1972                (high, low, close, volume, period, delta, fraction)
1973            }),
1974            strat5,
1975        ];
1976
1977        proptest::test_runner::TestRunner::default().run(
1978            &combined_strat,
1979            |(high, low, close, volume, period, delta, fraction)| {
1980                for i in 0..high.len() {
1981                    prop_assert!(
1982                        high[i] >= low[i],
1983                        "Invalid OHLC data at index {}: high ({}) < low ({})",
1984                        i,
1985                        high[i],
1986                        low[i]
1987                    );
1988                }
1989
1990                let params = EmdParams {
1991                    period: Some(period),
1992                    delta: Some(delta),
1993                    fraction: Some(fraction),
1994                };
1995                let input = EmdInput::from_slices(&high, &low, &close, &volume, params);
1996
1997                let result = emd_with_kernel(&input, kernel).unwrap();
1998                let upperband = &result.upperband;
1999                let middleband = &result.middleband;
2000                let lowerband = &result.lowerband;
2001
2002                let ref_result = emd_with_kernel(&input, Kernel::Scalar).unwrap();
2003                let ref_upperband = &ref_result.upperband;
2004                let ref_middleband = &ref_result.middleband;
2005                let ref_lowerband = &ref_result.lowerband;
2006
2007                prop_assert_eq!(upperband.len(), high.len());
2008                prop_assert_eq!(middleband.len(), high.len());
2009                prop_assert_eq!(lowerband.len(), high.len());
2010
2011                let upperband_warmup = (50 - 1).min(high.len());
2012                let middleband_warmup = ((2 * period) - 1).min(high.len());
2013
2014                for i in 0..upperband_warmup {
2015                    prop_assert!(
2016                        upperband[i].is_nan(),
2017                        "Upperband should be NaN during warmup at index {}",
2018                        i
2019                    );
2020                    prop_assert!(
2021                        lowerband[i].is_nan(),
2022                        "Lowerband should be NaN during warmup at index {}",
2023                        i
2024                    );
2025                }
2026
2027                for i in 0..middleband_warmup {
2028                    prop_assert!(
2029                        middleband[i].is_nan(),
2030                        "Middleband should be NaN during warmup at index {}",
2031                        i
2032                    );
2033                }
2034
2035                let start_idx = upperband_warmup.max(middleband_warmup) + 1;
2036                if start_idx < high.len() {
2037                    let input_min = high[start_idx..]
2038                        .iter()
2039                        .chain(low[start_idx..].iter())
2040                        .fold(
2041                            f64::INFINITY,
2042                            |a, &b| if b.is_finite() { a.min(b) } else { a },
2043                        );
2044                    let input_max = high[start_idx..]
2045                        .iter()
2046                        .chain(low[start_idx..].iter())
2047                        .fold(
2048                            f64::NEG_INFINITY,
2049                            |a, &b| if b.is_finite() { a.max(b) } else { a },
2050                        );
2051
2052                    if input_min.is_finite() && input_max.is_finite() {
2053                        let range = input_max - input_min;
2054                        let center = (input_max + input_min) / 2.0;
2055
2056                        let bounds_factor = 3.0;
2057                        let lower_bound = center - bounds_factor * range.max(1.0);
2058                        let upper_bound = center + bounds_factor * range.max(1.0);
2059
2060                        for i in start_idx..high.len() {
2061                            if !upperband[i].is_nan()
2062                                && !middleband[i].is_nan()
2063                                && !lowerband[i].is_nan()
2064                            {
2065                                prop_assert!(
2066                                    upperband[i].is_finite(),
2067                                    "Upperband should be finite at index {}",
2068                                    i
2069                                );
2070                                prop_assert!(
2071                                    middleband[i].is_finite(),
2072                                    "Middleband should be finite at index {}",
2073                                    i
2074                                );
2075                                prop_assert!(
2076                                    lowerband[i].is_finite(),
2077                                    "Lowerband should be finite at index {}",
2078                                    i
2079                                );
2080
2081                                prop_assert!(
2082                                    upperband[i] >= lower_bound && upperband[i] <= upper_bound,
2083                                    "Upperband {} at index {} outside reasonable bounds [{}, {}]",
2084                                    upperband[i],
2085                                    i,
2086                                    lower_bound,
2087                                    upper_bound
2088                                );
2089                                prop_assert!(
2090                                    middleband[i] >= lower_bound && middleband[i] <= upper_bound,
2091                                    "Middleband {} at index {} outside reasonable bounds [{}, {}]",
2092                                    middleband[i],
2093                                    i,
2094                                    lower_bound,
2095                                    upper_bound
2096                                );
2097                                prop_assert!(
2098                                    lowerband[i] >= lower_bound && lowerband[i] <= upper_bound,
2099                                    "Lowerband {} at index {} outside reasonable bounds [{}, {}]",
2100                                    lowerband[i],
2101                                    i,
2102                                    lower_bound,
2103                                    upper_bound
2104                                );
2105                            }
2106                        }
2107                    }
2108                }
2109
2110                let tolerance = 1e-10;
2111                for i in 0..high.len() {
2112                    let ub_diff = (upperband[i] - ref_upperband[i]).abs();
2113                    let mb_diff = (middleband[i] - ref_middleband[i]).abs();
2114                    let lb_diff = (lowerband[i] - ref_lowerband[i]).abs();
2115
2116                    if !upperband[i].is_nan() && !ref_upperband[i].is_nan() {
2117                        prop_assert!(
2118                            ub_diff < tolerance,
2119                            "Upperband kernel mismatch at index {}: {} vs {} (diff: {})",
2120                            i,
2121                            upperband[i],
2122                            ref_upperband[i],
2123                            ub_diff
2124                        );
2125                    }
2126                    if !middleband[i].is_nan() && !ref_middleband[i].is_nan() {
2127                        prop_assert!(
2128                            mb_diff < tolerance,
2129                            "Middleband kernel mismatch at index {}: {} vs {} (diff: {})",
2130                            i,
2131                            middleband[i],
2132                            ref_middleband[i],
2133                            mb_diff
2134                        );
2135                    }
2136                    if !lowerband[i].is_nan() && !ref_lowerband[i].is_nan() {
2137                        prop_assert!(
2138                            lb_diff < tolerance,
2139                            "Lowerband kernel mismatch at index {}: {} vs {} (diff: {})",
2140                            i,
2141                            lowerband[i],
2142                            ref_lowerband[i],
2143                            lb_diff
2144                        );
2145                    }
2146                }
2147
2148                for i in 0..high.len() {
2149                    let ub_bits = upperband[i].to_bits();
2150                    let mb_bits = middleband[i].to_bits();
2151                    let lb_bits = lowerband[i].to_bits();
2152
2153                    prop_assert_ne!(
2154                        ub_bits,
2155                        0x1111_1111_1111_1111,
2156                        "Poison value in upperband at {}",
2157                        i
2158                    );
2159                    prop_assert_ne!(
2160                        ub_bits,
2161                        0x2222_2222_2222_2222,
2162                        "Poison value in upperband at {}",
2163                        i
2164                    );
2165                    prop_assert_ne!(
2166                        ub_bits,
2167                        0x3333_3333_3333_3333,
2168                        "Poison value in upperband at {}",
2169                        i
2170                    );
2171
2172                    prop_assert_ne!(
2173                        mb_bits,
2174                        0x1111_1111_1111_1111,
2175                        "Poison value in middleband at {}",
2176                        i
2177                    );
2178                    prop_assert_ne!(
2179                        mb_bits,
2180                        0x2222_2222_2222_2222,
2181                        "Poison value in middleband at {}",
2182                        i
2183                    );
2184                    prop_assert_ne!(
2185                        mb_bits,
2186                        0x3333_3333_3333_3333,
2187                        "Poison value in middleband at {}",
2188                        i
2189                    );
2190
2191                    prop_assert_ne!(
2192                        lb_bits,
2193                        0x1111_1111_1111_1111,
2194                        "Poison value in lowerband at {}",
2195                        i
2196                    );
2197                    prop_assert_ne!(
2198                        lb_bits,
2199                        0x2222_2222_2222_2222,
2200                        "Poison value in lowerband at {}",
2201                        i
2202                    );
2203                    prop_assert_ne!(
2204                        lb_bits,
2205                        0x3333_3333_3333_3333,
2206                        "Poison value in lowerband at {}",
2207                        i
2208                    );
2209                }
2210
2211                if period == 2 {
2212                    let min_warmup = (2 * 2).max(50);
2213                    if high.len() > min_warmup {
2214                        prop_assert!(
2215                            middleband[min_warmup].is_finite() || middleband[min_warmup].is_nan(),
2216                            "Period=2 should produce valid or NaN output"
2217                        );
2218                    }
2219                }
2220
2221                if high.windows(2).all(|w| (w[0] - w[1]).abs() < f64::EPSILON)
2222                    && low.windows(2).all(|w| (w[0] - w[1]).abs() < f64::EPSILON)
2223                    && close.windows(2).all(|w| (w[0] - w[1]).abs() < f64::EPSILON)
2224                {
2225                    let check_start = start_idx + period;
2226                    if check_start + 5 < high.len() {
2227                        let ub_stable = &upperband[check_start..check_start + 5];
2228                        let mb_stable = &middleband[check_start..check_start + 5];
2229                        let lb_stable = &lowerband[check_start..check_start + 5];
2230
2231                        for w in ub_stable.windows(2) {
2232                            if !w[0].is_nan() && !w[1].is_nan() {
2233                                prop_assert!(
2234                                    (w[0] - w[1]).abs() < 1e-6,
2235                                    "Upperband should be stable for constant input"
2236                                );
2237                            }
2238                        }
2239                        for w in mb_stable.windows(2) {
2240                            if !w[0].is_nan() && !w[1].is_nan() {
2241                                prop_assert!(
2242                                    (w[0] - w[1]).abs() < 1e-6,
2243                                    "Middleband should be stable for constant input"
2244                                );
2245                            }
2246                        }
2247                        for w in lb_stable.windows(2) {
2248                            if !w[0].is_nan() && !w[1].is_nan() {
2249                                prop_assert!(
2250                                    (w[0] - w[1]).abs() < 1e-6,
2251                                    "Lowerband should be stable for constant input"
2252                                );
2253                            }
2254                        }
2255                    }
2256                }
2257
2258                if fraction < 0.01 && start_idx < high.len() {
2259                    let check_end = (start_idx + 10).min(high.len());
2260                    for i in start_idx..check_end {
2261                        if !upperband[i].is_nan() && !lowerband[i].is_nan() {
2262                            prop_assert!(
2263                                upperband[i].abs() < 10.0,
2264                                "With fraction={}, upperband should be small, got {} at index {}",
2265                                fraction,
2266                                upperband[i],
2267                                i
2268                            );
2269                            prop_assert!(
2270                                lowerband[i].abs() < 10.0,
2271                                "With fraction={}, lowerband should be small, got {} at index {}",
2272                                fraction,
2273                                lowerband[i],
2274                                i
2275                            );
2276                        }
2277                    }
2278                }
2279
2280                Ok(())
2281            },
2282        )?;
2283
2284        Ok(())
2285    }
2286
2287    #[cfg(not(feature = "proptest"))]
2288    fn check_emd_property(
2289        _test_name: &str,
2290        _kernel: Kernel,
2291    ) -> Result<(), Box<dyn std::error::Error>> {
2292        Ok(())
2293    }
2294
2295    macro_rules! generate_all_emd_tests {
2296        ($($test_fn:ident),*) => {
2297            paste::paste! {
2298                $(
2299                    #[test]
2300                    fn [<$test_fn _scalar_f64>]() {
2301                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
2302                    }
2303                )*
2304                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2305                $(
2306                    #[test]
2307                    fn [<$test_fn _avx2_f64>]() {
2308                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
2309                    }
2310                    #[test]
2311                    fn [<$test_fn _avx512_f64>]() {
2312                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
2313                    }
2314                )*
2315            }
2316        }
2317    }
2318
2319    #[cfg(debug_assertions)]
2320    fn check_emd_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2321        skip_if_unsupported!(kernel, test_name);
2322
2323        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2324        let candles = read_candles_from_csv(file_path)?;
2325
2326        let test_params = vec![
2327            EmdParams::default(),
2328            EmdParams {
2329                period: Some(2),
2330                delta: Some(0.1),
2331                fraction: Some(0.05),
2332            },
2333            EmdParams {
2334                period: Some(5),
2335                delta: Some(0.3),
2336                fraction: Some(0.1),
2337            },
2338            EmdParams {
2339                period: Some(10),
2340                delta: Some(0.5),
2341                fraction: Some(0.15),
2342            },
2343            EmdParams {
2344                period: Some(20),
2345                delta: Some(0.4),
2346                fraction: Some(0.1),
2347            },
2348            EmdParams {
2349                period: Some(30),
2350                delta: Some(0.6),
2351                fraction: Some(0.2),
2352            },
2353            EmdParams {
2354                period: Some(50),
2355                delta: Some(0.7),
2356                fraction: Some(0.25),
2357            },
2358            EmdParams {
2359                period: Some(100),
2360                delta: Some(0.8),
2361                fraction: Some(0.3),
2362            },
2363            EmdParams {
2364                period: Some(15),
2365                delta: Some(0.1),
2366                fraction: Some(0.1),
2367            },
2368            EmdParams {
2369                period: Some(15),
2370                delta: Some(0.9),
2371                fraction: Some(0.1),
2372            },
2373            EmdParams {
2374                period: Some(25),
2375                delta: Some(0.5),
2376                fraction: Some(0.01),
2377            },
2378            EmdParams {
2379                period: Some(25),
2380                delta: Some(0.5),
2381                fraction: Some(0.5),
2382            },
2383            EmdParams {
2384                period: Some(40),
2385                delta: Some(0.65),
2386                fraction: Some(0.12),
2387            },
2388            EmdParams {
2389                period: Some(7),
2390                delta: Some(0.25),
2391                fraction: Some(0.08),
2392            },
2393        ];
2394
2395        for (param_idx, params) in test_params.iter().enumerate() {
2396            let input = EmdInput::from_candles(&candles, params.clone());
2397            let output = emd_with_kernel(&input, kernel)?;
2398
2399            for (i, &val) in output.upperband.iter().enumerate() {
2400                if val.is_nan() {
2401                    continue;
2402                }
2403
2404                let bits = val.to_bits();
2405
2406                if bits == 0x11111111_11111111 {
2407                    panic!(
2408						"[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} in upperband \
2409						 with params: period={}, delta={}, fraction={} (param set {})",
2410						test_name, val, bits, i,
2411						params.period.unwrap_or(20),
2412						params.delta.unwrap_or(0.5),
2413						params.fraction.unwrap_or(0.1),
2414						param_idx
2415					);
2416                }
2417
2418                if bits == 0x22222222_22222222 {
2419                    panic!(
2420						"[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} in upperband \
2421						 with params: period={}, delta={}, fraction={} (param set {})",
2422						test_name, val, bits, i,
2423						params.period.unwrap_or(20),
2424						params.delta.unwrap_or(0.5),
2425						params.fraction.unwrap_or(0.1),
2426						param_idx
2427					);
2428                }
2429
2430                if bits == 0x33333333_33333333 {
2431                    panic!(
2432						"[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} in upperband \
2433						 with params: period={}, delta={}, fraction={} (param set {})",
2434						test_name, val, bits, i,
2435						params.period.unwrap_or(20),
2436						params.delta.unwrap_or(0.5),
2437						params.fraction.unwrap_or(0.1),
2438						param_idx
2439					);
2440                }
2441            }
2442
2443            for (i, &val) in output.middleband.iter().enumerate() {
2444                if val.is_nan() {
2445                    continue;
2446                }
2447
2448                let bits = val.to_bits();
2449
2450                if bits == 0x11111111_11111111 {
2451                    panic!(
2452						"[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} in middleband \
2453						 with params: period={}, delta={}, fraction={} (param set {})",
2454						test_name, val, bits, i,
2455						params.period.unwrap_or(20),
2456						params.delta.unwrap_or(0.5),
2457						params.fraction.unwrap_or(0.1),
2458						param_idx
2459					);
2460                }
2461
2462                if bits == 0x22222222_22222222 {
2463                    panic!(
2464						"[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} in middleband \
2465						 with params: period={}, delta={}, fraction={} (param set {})",
2466						test_name, val, bits, i,
2467						params.period.unwrap_or(20),
2468						params.delta.unwrap_or(0.5),
2469						params.fraction.unwrap_or(0.1),
2470						param_idx
2471					);
2472                }
2473
2474                if bits == 0x33333333_33333333 {
2475                    panic!(
2476						"[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} in middleband \
2477						 with params: period={}, delta={}, fraction={} (param set {})",
2478						test_name, val, bits, i,
2479						params.period.unwrap_or(20),
2480						params.delta.unwrap_or(0.5),
2481						params.fraction.unwrap_or(0.1),
2482						param_idx
2483					);
2484                }
2485            }
2486
2487            for (i, &val) in output.lowerband.iter().enumerate() {
2488                if val.is_nan() {
2489                    continue;
2490                }
2491
2492                let bits = val.to_bits();
2493
2494                if bits == 0x11111111_11111111 {
2495                    panic!(
2496						"[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} in lowerband \
2497						 with params: period={}, delta={}, fraction={} (param set {})",
2498						test_name, val, bits, i,
2499						params.period.unwrap_or(20),
2500						params.delta.unwrap_or(0.5),
2501						params.fraction.unwrap_or(0.1),
2502						param_idx
2503					);
2504                }
2505
2506                if bits == 0x22222222_22222222 {
2507                    panic!(
2508						"[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} in lowerband \
2509						 with params: period={}, delta={}, fraction={} (param set {})",
2510						test_name, val, bits, i,
2511						params.period.unwrap_or(20),
2512						params.delta.unwrap_or(0.5),
2513						params.fraction.unwrap_or(0.1),
2514						param_idx
2515					);
2516                }
2517
2518                if bits == 0x33333333_33333333 {
2519                    panic!(
2520						"[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} in lowerband \
2521						 with params: period={}, delta={}, fraction={} (param set {})",
2522						test_name, val, bits, i,
2523						params.period.unwrap_or(20),
2524						params.delta.unwrap_or(0.5),
2525						params.fraction.unwrap_or(0.1),
2526						param_idx
2527					);
2528                }
2529            }
2530        }
2531
2532        Ok(())
2533    }
2534
2535    #[cfg(not(debug_assertions))]
2536    fn check_emd_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2537        Ok(())
2538    }
2539
2540    generate_all_emd_tests!(
2541        check_emd_accuracy,
2542        check_emd_empty_data,
2543        check_emd_all_nans,
2544        check_emd_invalid_period,
2545        check_emd_not_enough_valid_data,
2546        check_emd_default_candles,
2547        check_emd_no_poison
2548    );
2549
2550    #[cfg(feature = "proptest")]
2551    generate_all_emd_tests!(check_emd_property);
2552
2553    #[cfg(test)]
2554    mod batch_tests {
2555        use super::*;
2556        use crate::skip_if_unsupported;
2557        use crate::utilities::data_loader::read_candles_from_csv;
2558
2559        fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2560            skip_if_unsupported!(kernel, test);
2561
2562            let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2563            let c = read_candles_from_csv(file)?;
2564
2565            let output = EmdBatchBuilder::new().kernel(kernel).apply_candles(&c)?;
2566
2567            let def = EmdParams::default();
2568            let (ub, mb, lb) = output.bands_for(&def).expect("default row missing");
2569
2570            assert_eq!(ub.len(), c.close.len(), "Upperband length mismatch");
2571            assert_eq!(mb.len(), c.close.len(), "Middleband length mismatch");
2572            assert_eq!(lb.len(), c.close.len(), "Lowerband length mismatch");
2573
2574            let expected_last_five_upper = [
2575                50.33760237677157,
2576                50.28850695686447,
2577                50.23941153695737,
2578                50.19031611705027,
2579                48.709744457737344,
2580            ];
2581            let expected_last_five_middle = [
2582                -368.71064280396706,
2583                -399.11033986231377,
2584                -421.9368852621732,
2585                -437.879217150269,
2586                -447.3257167904511,
2587            ];
2588            let expected_last_five_lower = [
2589                -60.67834136221248,
2590                -60.93110347122829,
2591                -61.68154077026321,
2592                -62.43197806929814,
2593                -63.18241536833306,
2594            ];
2595            let len = ub.len();
2596            for i in 0..5 {
2597                assert!(
2598                    (ub[len - 5 + i] - expected_last_five_upper[i]).abs() < 1e-6,
2599                    "[{test}] upperband mismatch idx {i}: {} vs {}",
2600                    ub[len - 5 + i],
2601                    expected_last_five_upper[i]
2602                );
2603                assert!(
2604                    (mb[len - 5 + i] - expected_last_five_middle[i]).abs() < 1e-6,
2605                    "[{test}] middleband mismatch idx {i}: {} vs {}",
2606                    mb[len - 5 + i],
2607                    expected_last_five_middle[i]
2608                );
2609                assert!(
2610                    (lb[len - 5 + i] - expected_last_five_lower[i]).abs() < 1e-6,
2611                    "[{test}] lowerband mismatch idx {i}: {} vs {}",
2612                    lb[len - 5 + i],
2613                    expected_last_five_lower[i]
2614                );
2615            }
2616
2617            Ok(())
2618        }
2619
2620        fn check_batch_param_sweep(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2621            skip_if_unsupported!(kernel, test);
2622
2623            let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2624            let c = read_candles_from_csv(file)?;
2625
2626            let output = EmdBatchBuilder::new()
2627                .kernel(kernel)
2628                .period_range(20, 22, 2)
2629                .delta_range(0.5, 0.6, 0.1)
2630                .fraction_range(0.1, 0.2, 0.1)
2631                .apply_candles(&c)?;
2632
2633            assert!(
2634                output.rows == 8,
2635                "Expected 8 rows (2*2*2 grid), got {}",
2636                output.rows
2637            );
2638            assert_eq!(output.cols, c.close.len());
2639
2640            for params in &output.combos {
2641                let (ub, mb, lb) = output
2642                    .bands_for(params)
2643                    .expect("row for params missing in sweep");
2644                assert_eq!(ub.len(), output.cols);
2645                assert_eq!(mb.len(), output.cols);
2646                assert_eq!(lb.len(), output.cols);
2647            }
2648
2649            Ok(())
2650        }
2651
2652        macro_rules! gen_batch_tests {
2653            ($fn_name:ident) => {
2654                paste::paste! {
2655                    #[test] fn [<$fn_name _scalar>]()      {
2656                        let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
2657                    }
2658                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2659                    #[test] fn [<$fn_name _avx2>]()        {
2660                        let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
2661                    }
2662                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2663                    #[test] fn [<$fn_name _avx512>]()      {
2664                        let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
2665                    }
2666                    #[test] fn [<$fn_name _auto_detect>]() {
2667                        let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
2668                    }
2669                }
2670            };
2671        }
2672
2673        #[cfg(debug_assertions)]
2674        fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2675            skip_if_unsupported!(kernel, test);
2676
2677            let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2678            let c = read_candles_from_csv(file)?;
2679
2680            let test_configs = vec![
2681                (5, 15, 5, 0.1, 0.5, 0.2, 0.05, 0.15, 0.05),
2682                (10, 30, 10, 0.3, 0.7, 0.2, 0.1, 0.2, 0.05),
2683                (20, 50, 15, 0.5, 0.8, 0.15, 0.15, 0.3, 0.075),
2684                (8, 12, 1, 0.4, 0.6, 0.1, 0.08, 0.12, 0.02),
2685                (20, 20, 0, 0.5, 0.5, 0.0, 0.1, 0.1, 0.0),
2686                (5, 40, 5, 0.2, 0.8, 0.1, 0.05, 0.25, 0.05),
2687                (2, 6, 2, 0.1, 0.9, 0.4, 0.01, 0.5, 0.245),
2688            ];
2689
2690            for (
2691                cfg_idx,
2692                &(p_start, p_end, p_step, d_start, d_end, d_step, f_start, f_end, f_step),
2693            ) in test_configs.iter().enumerate()
2694            {
2695                let output = EmdBatchBuilder::new()
2696                    .kernel(kernel)
2697                    .period_range(p_start, p_end, p_step)
2698                    .delta_range(d_start, d_end, d_step)
2699                    .fraction_range(f_start, f_end, f_step)
2700                    .apply_candles(&c)?;
2701
2702                for (idx, &val) in output.upperband.iter().enumerate() {
2703                    if val.is_nan() {
2704                        continue;
2705                    }
2706
2707                    let bits = val.to_bits();
2708                    let row = idx / output.cols;
2709                    let col = idx % output.cols;
2710                    let combo = &output.combos[row];
2711
2712                    if bits == 0x11111111_11111111 {
2713                        panic!(
2714							"[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) in upperband \
2715							 at row {} col {} (flat index {}) with params: period={}, delta={}, fraction={}",
2716							test, cfg_idx, val, bits, row, col, idx,
2717							combo.period.unwrap_or(20),
2718							combo.delta.unwrap_or(0.5),
2719							combo.fraction.unwrap_or(0.1)
2720						);
2721                    }
2722
2723                    if bits == 0x22222222_22222222 {
2724                        panic!(
2725							"[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) in upperband \
2726							 at row {} col {} (flat index {}) with params: period={}, delta={}, fraction={}",
2727							test, cfg_idx, val, bits, row, col, idx,
2728							combo.period.unwrap_or(20),
2729							combo.delta.unwrap_or(0.5),
2730							combo.fraction.unwrap_or(0.1)
2731						);
2732                    }
2733
2734                    if bits == 0x33333333_33333333 {
2735                        panic!(
2736							"[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) in upperband \
2737							 at row {} col {} (flat index {}) with params: period={}, delta={}, fraction={}",
2738							test, cfg_idx, val, bits, row, col, idx,
2739							combo.period.unwrap_or(20),
2740							combo.delta.unwrap_or(0.5),
2741							combo.fraction.unwrap_or(0.1)
2742						);
2743                    }
2744                }
2745
2746                for (idx, &val) in output.middleband.iter().enumerate() {
2747                    if val.is_nan() {
2748                        continue;
2749                    }
2750
2751                    let bits = val.to_bits();
2752                    let row = idx / output.cols;
2753                    let col = idx % output.cols;
2754                    let combo = &output.combos[row];
2755
2756                    if bits == 0x11111111_11111111 {
2757                        panic!(
2758							"[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) in middleband \
2759							 at row {} col {} (flat index {}) with params: period={}, delta={}, fraction={}",
2760							test, cfg_idx, val, bits, row, col, idx,
2761							combo.period.unwrap_or(20),
2762							combo.delta.unwrap_or(0.5),
2763							combo.fraction.unwrap_or(0.1)
2764						);
2765                    }
2766
2767                    if bits == 0x22222222_22222222 {
2768                        panic!(
2769							"[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) in middleband \
2770							 at row {} col {} (flat index {}) with params: period={}, delta={}, fraction={}",
2771							test, cfg_idx, val, bits, row, col, idx,
2772							combo.period.unwrap_or(20),
2773							combo.delta.unwrap_or(0.5),
2774							combo.fraction.unwrap_or(0.1)
2775						);
2776                    }
2777
2778                    if bits == 0x33333333_33333333 {
2779                        panic!(
2780							"[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) in middleband \
2781							 at row {} col {} (flat index {}) with params: period={}, delta={}, fraction={}",
2782							test, cfg_idx, val, bits, row, col, idx,
2783							combo.period.unwrap_or(20),
2784							combo.delta.unwrap_or(0.5),
2785							combo.fraction.unwrap_or(0.1)
2786						);
2787                    }
2788                }
2789
2790                for (idx, &val) in output.lowerband.iter().enumerate() {
2791                    if val.is_nan() {
2792                        continue;
2793                    }
2794
2795                    let bits = val.to_bits();
2796                    let row = idx / output.cols;
2797                    let col = idx % output.cols;
2798                    let combo = &output.combos[row];
2799
2800                    if bits == 0x11111111_11111111 {
2801                        panic!(
2802							"[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) in lowerband \
2803							 at row {} col {} (flat index {}) with params: period={}, delta={}, fraction={}",
2804							test, cfg_idx, val, bits, row, col, idx,
2805							combo.period.unwrap_or(20),
2806							combo.delta.unwrap_or(0.5),
2807							combo.fraction.unwrap_or(0.1)
2808						);
2809                    }
2810
2811                    if bits == 0x22222222_22222222 {
2812                        panic!(
2813							"[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) in lowerband \
2814							 at row {} col {} (flat index {}) with params: period={}, delta={}, fraction={}",
2815							test, cfg_idx, val, bits, row, col, idx,
2816							combo.period.unwrap_or(20),
2817							combo.delta.unwrap_or(0.5),
2818							combo.fraction.unwrap_or(0.1)
2819						);
2820                    }
2821
2822                    if bits == 0x33333333_33333333 {
2823                        panic!(
2824							"[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) in lowerband \
2825							 at row {} col {} (flat index {}) with params: period={}, delta={}, fraction={}",
2826							test, cfg_idx, val, bits, row, col, idx,
2827							combo.period.unwrap_or(20),
2828							combo.delta.unwrap_or(0.5),
2829							combo.fraction.unwrap_or(0.1)
2830						);
2831                    }
2832                }
2833            }
2834
2835            Ok(())
2836        }
2837
2838        #[cfg(not(debug_assertions))]
2839        fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2840            Ok(())
2841        }
2842
        // Expand each batch check into per-kernel #[test] fns via `gen_batch_tests!`:
        // `_scalar`, `_auto_detect`, and `_avx2` / `_avx512` when built with `nightly-avx`
        // on x86_64.
        gen_batch_tests!(check_batch_default_row);
        gen_batch_tests!(check_batch_param_sweep);
        gen_batch_tests!(check_batch_no_poison);
2846    }
2847}
2848
/// Python entry point `emd`: computes the three EMD bands for a single high/low series using
/// the given `period`, `delta` and `fraction`.
///
/// `kernel` optionally selects a compute kernel and is checked by `validate_kernel`. Returns
/// three float64 arrays with the same length as the inputs. Raises `ValueError` when `high`
/// and `low` differ in length or the computation fails; kernel-validation errors propagate.
#[cfg(feature = "python")]
#[pyfunction(name = "emd")]
#[pyo3(signature = (high, low, period, delta, fraction, kernel=None))]
pub fn emd_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    period: usize,
    delta: f64,
    fraction: f64,
    kernel: Option<&str>,
) -> PyResult<(
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
)> {
    let high_slice = high.as_slice()?;
    let low_slice = low.as_slice()?;
    let n = high_slice.len();
    if low_slice.len() != n {
        return Err(PyValueError::new_err("high and low must have same length"));
    }

    let input = EmdInput::from_slices(
        high_slice,
        low_slice,
        &[],
        &[],
        EmdParams {
            period: Some(period),
            delta: Some(delta),
            fraction: Some(fraction),
        },
    );
    let chosen_kernel = validate_kernel(kernel, false)?;

    // SAFETY: the arrays are created uninitialized and relied upon to be fully written by
    // `emd_into_slices` below. NOTE(review): confirm it writes the warmup prefix as well.
    let (upper, middle, lower) = unsafe {
        (
            PyArray1::<f64>::new(py, [n], false),
            PyArray1::<f64>::new(py, [n], false),
            PyArray1::<f64>::new(py, [n], false),
        )
    };
    // SAFETY: the arrays were just created here and are not shared with Python yet.
    let (upper_out, middle_out, lower_out) = unsafe {
        (
            upper.as_slice_mut()?,
            middle.as_slice_mut()?,
            lower.as_slice_mut()?,
        )
    };

    py.allow_threads(|| emd_into_slices(upper_out, middle_out, lower_out, &input, chosen_kernel))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok((upper, middle, lower))
}
2892
/// Python class `EmdStream`: thin wrapper around the Rust `EmdStream` for bar-by-bar updates.
#[cfg(feature = "python")]
#[pyclass(name = "EmdStream")]
pub struct EmdStreamPy {
    stream: EmdStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl EmdStreamPy {
    // Constructor: builds the inner stream from the three parameters; any error from
    // `EmdStream::try_new` is raised as `ValueError`.
    #[new]
    fn new(period: usize, delta: f64, fraction: f64) -> PyResult<Self> {
        let params = EmdParams {
            period: Some(period),
            delta: Some(delta),
            fraction: Some(fraction),
        };
        EmdStream::try_new(params)
            .map(|stream| Self { stream })
            .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    // Feeds one (high, low) bar and returns the three `Option<f64>` values produced by
    // `EmdStream::update` unchanged (presumably upper/middle/lower — confirm ordering there).
    fn update(&mut self, high: f64, low: f64) -> (Option<f64>, Option<f64>, Option<f64>) {
        let Self { stream } = self;
        stream.update(high, low)
    }
}
2918
/// Python entry point `emd_batch`: evaluates EMD for every combination of the
/// `(start, end, step)` ranges for period, delta and fraction over one high/low series.
///
/// Returns a dict with `upper`, `middle` and `lower` arrays of shape `(rows, cols)`. Each row
/// is one parameter combination and each column is one input sample. The dict also holds the
/// per-row `periods`, `deltas` and `fractions`. Raises `ValueError` for mismatched input
/// lengths, an invalid sweep, a `rows * cols` overflow, or a computation failure.
#[cfg(feature = "python")]
#[pyfunction(name = "emd_batch")]
#[pyo3(signature = (high, low, period_range, delta_range, fraction_range, kernel=None))]
pub fn emd_batch_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    delta_range: (f64, f64, f64),
    fraction_range: (f64, f64, f64),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let high_slice = high.as_slice()?;
    let low_slice = low.as_slice()?;
    if low_slice.len() != high_slice.len() {
        return Err(PyValueError::new_err("high and low must have same length"));
    }

    let sweep = EmdBatchRange {
        period: period_range,
        delta: delta_range,
        fraction: fraction_range,
    };
    // Only the number of combinations is needed up front to size the outputs.
    let rows = expand_grid(&sweep)
        .map_err(|e| PyValueError::new_err(e.to_string()))?
        .len();
    let cols = high_slice.len();
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows * cols overflow in emd_batch_py"))?;

    // SAFETY: uninitialized flat buffers; `emd_batch_inner_into` is relied upon to fill every
    // cell. NOTE(review): confirm warmup cells are written there.
    let upper = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let middle = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let lower = unsafe { PyArray1::<f64>::new(py, [total], false) };

    // SAFETY: freshly created arrays, not yet visible to Python code.
    let upper_buf = unsafe { upper.as_slice_mut()? };
    let middle_buf = unsafe { middle.as_slice_mut()? };
    let lower_buf = unsafe { lower.as_slice_mut()? };

    let requested = validate_kernel(kernel, true)?;

    let combos = py
        .allow_threads(|| {
            let batch_kernel = if matches!(requested, Kernel::Auto) {
                detect_best_batch_kernel()
            } else {
                requested
            };
            // The per-row routine takes the single-series kernel matching the batch kernel.
            let row_kernel = match batch_kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                _ => unreachable!(),
            };
            emd_batch_inner_into(
                high_slice,
                low_slice,
                &sweep,
                row_kernel,
                true,
                upper_buf,
                middle_buf,
                lower_buf,
            )
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("upper", upper.reshape((rows, cols))?)?;
    dict.set_item("middle", middle.reshape((rows, cols))?)?;
    dict.set_item("lower", lower.reshape((rows, cols))?)?;

    let periods: Vec<u64> = combos.iter().map(|p| p.period.unwrap() as u64).collect();
    dict.set_item("periods", periods.into_pyarray(py))?;

    let deltas: Vec<f64> = combos.iter().map(|p| p.delta.unwrap()).collect();
    dict.set_item("deltas", deltas.into_pyarray(py))?;

    let fractions: Vec<f64> = combos.iter().map(|p| p.fraction.unwrap()).collect();
    dict.set_item("fractions", fractions.into_pyarray(py))?;

    Ok(dict)
}
3007
/// Python entry point `emd_cuda_batch_dev`: runs the EMD parameter sweep on a CUDA device for
/// one float32 high/low series and leaves the three band matrices in device memory.
///
/// Returns a dict with device arrays `upperband`, `middleband` and `lowerband`, the per-row
/// `periods`, `deltas` and `fractions`, and the matrix dimensions `rows` (combinations) and
/// `cols` (input length). Raises `ValueError` when CUDA is unavailable, the inputs differ in
/// length, or device setup / computation fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "emd_cuda_batch_dev")]
#[pyo3(signature = (high, low, period_range, delta_range, fraction_range, device_id=0))]
pub fn emd_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f32>,
    low: PyReadonlyArray1<'py, f32>,
    period_range: (usize, usize, usize),
    delta_range: (f64, f64, f64),
    fraction_range: (f64, f64, f64),
    device_id: usize,
) -> PyResult<Bound<'py, PyDict>> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let high_slice = high.as_slice()?;
    let low_slice = low.as_slice()?;
    if low_slice.len() != high_slice.len() {
        return Err(PyValueError::new_err("high and low must have same length"));
    }
    let sweep = EmdBatchRange {
        period: period_range,
        delta: delta_range,
        fraction: fraction_range,
    };

    // Device work runs with the GIL released.
    let (outputs, combos, dev_id) = py.allow_threads(|| {
        let cuda = CudaEmd::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let dev_id = cuda.device_id();
        let batch = cuda
            .emd_batch_dev(high_slice, low_slice, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, PyErr>((batch.outputs, batch.combos, dev_id))
    })?;

    let DeviceArrayF32Triple {
        upper,
        middle,
        lower,
    } = outputs;
    let device = dev_id as usize;

    let dict = PyDict::new(py);
    dict.set_item("upperband", Py::new(py, make_device_array_py(device, upper)?)?)?;
    dict.set_item("middleband", Py::new(py, make_device_array_py(device, middle)?)?)?;
    dict.set_item("lowerband", Py::new(py, make_device_array_py(device, lower)?)?)?;

    let periods: Vec<usize> = combos.iter().map(|c| c.period.unwrap()).collect();
    let deltas: Vec<f64> = combos.iter().map(|c| c.delta.unwrap()).collect();
    let fractions: Vec<f64> = combos.iter().map(|c| c.fraction.unwrap()).collect();
    dict.set_item("periods", periods.into_pyarray(py))?;
    dict.set_item("deltas", deltas.into_pyarray(py))?;
    dict.set_item("fractions", fractions.into_pyarray(py))?;
    dict.set_item("rows", combos.len())?;
    dict.set_item("cols", high_slice.len())?;
    Ok(dict)
}
3065
/// Python entry point `emd_cuda_many_series_one_param_dev`: runs EMD with one parameter set
/// over many float32 series laid out time-major in a 2-D array of shape `(rows, cols)`.
/// Element `[t, s]` is time step `t` of series `s`. Results stay on the CUDA device.
///
/// Every series must contain at least one finite value. The index of each series' first
/// finite value is passed to the device routine. Returns a dict with device arrays
/// `upperband`, `middleband` and `lowerband`, plus `rows`, `cols`, `period`, `delta` and
/// `fraction`. Raises `ValueError` when CUDA is unavailable, the array is not 2-D, a series
/// has no finite values, or device setup / computation fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "emd_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, delta, fraction, device_id=0))]
pub fn emd_cuda_many_series_one_param_dev_py<'py>(
    py: Python<'py>,
    data_tm_f32: numpy::PyReadonlyArray2<'py, f32>,
    period: usize,
    delta: f64,
    fraction: f64,
    device_id: usize,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::PyUntypedArrayMethods;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let (rows, cols) = match data_tm_f32.shape() {
        &[r, c] => (r, c),
        _ => return Err(PyValueError::new_err("expected 2D array")),
    };
    let flat = data_tm_f32.as_slice()?;

    // Series `s` occupies column `s`; find its first finite time index, failing on the first
    // series (in column order) that has none.
    let first_valids = (0..cols)
        .map(|s| {
            (0..rows)
                .find(|&t| flat[t * cols + s].is_finite())
                .map(|t| t as i32)
                .ok_or_else(|| {
                    PyValueError::new_err(format!("series {} has no finite values", s))
                })
        })
        .collect::<PyResult<Vec<i32>>>()?;

    let params = EmdParams {
        period: Some(period),
        delta: Some(delta),
        fraction: Some(fraction),
    };
    let (outputs, dev_id) = py.allow_threads(|| {
        let cuda = CudaEmd::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let dev_id = cuda.device_id();
        match cuda.emd_many_series_one_param_time_major_dev(flat, cols, rows, &params, &first_valids)
        {
            Ok(triple) => Ok((triple, dev_id)),
            Err(e) => Err(PyValueError::new_err(e.to_string())),
        }
    })?;

    let DeviceArrayF32Triple {
        upper,
        middle,
        lower,
    } = outputs;
    let device = dev_id as usize;

    let dict = pyo3::types::PyDict::new(py);
    dict.set_item("upperband", Py::new(py, make_device_array_py(device, upper)?)?)?;
    dict.set_item("middleband", Py::new(py, make_device_array_py(device, middle)?)?)?;
    dict.set_item("lowerband", Py::new(py, make_device_array_py(device, lower)?)?)?;
    dict.set_item("rows", rows)?;
    dict.set_item("cols", cols)?;
    dict.set_item("period", period)?;
    dict.set_item("delta", delta)?;
    dict.set_item("fraction", fraction)?;
    Ok(dict)
}
3134
3135#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
3136use serde::{Deserialize, Serialize};
3137#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
3138use wasm_bindgen::prelude::*;
3139
/// Serialized result of `emd_js`: `values` holds `rows` (always 3) consecutive series of
/// `cols` samples each — upper band, then middle band, then lower band.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct EmdJsOutput {
    pub values: Vec<f64>,
    pub rows: usize,
    pub cols: usize,
}

/// JS entry point `emd_js`: computes the three EMD bands for one high/low series.
///
/// `_close` and `_volume` are accepted for signature compatibility and ignored. Errors
/// (mismatched lengths, computation failure, serialization failure) are returned as strings.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "emd_js")]
pub fn emd_js(
    high: &[f64],
    low: &[f64],
    _close: &[f64],
    _volume: &[f64],
    period: usize,
    delta: f64,
    fraction: f64,
) -> Result<JsValue, JsValue> {
    let n = high.len();
    if low.len() != n {
        return Err(JsValue::from_str("high and low must have same length"));
    }
    let input = EmdInput::from_slices(
        high,
        low,
        &[],
        &[],
        EmdParams {
            period: Some(period),
            delta: Some(delta),
            fraction: Some(fraction),
        },
    );

    // One contiguous buffer laid out as [upper | middle | lower], `n` values each.
    let mut values = vec![f64::NAN; 3 * n];
    {
        let (upper, tail) = values.split_at_mut(n);
        let (middle, lower) = tail.split_at_mut(n);
        emd_into_slices(upper, middle, lower, &input, detect_best_kernel())
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    serde_wasm_bindgen::to_value(&EmdJsOutput {
        values,
        rows: 3,
        cols: n,
    })
    .map_err(|e| JsValue::from_str(&e.to_string()))
}
3183
/// Reserves room for `len` `f64` values in WASM linear memory and returns a pointer to it,
/// transferring ownership to the caller. The contents are uninitialized. Release the buffer
/// with `emd_free(ptr, len)` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn emd_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop stops the Vec from releasing its allocation when this temporary dies.
    std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len)).as_mut_ptr()
}
3192
/// Releases a buffer previously obtained from `emd_alloc(len)`; `len` must be the same value
/// that was passed to `emd_alloc`. A null pointer is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn emd_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` came from `emd_alloc(len)`, i.e. `Vec::<f64>::with_capacity(len)`, so the
    // allocation has capacity exactly `len`. The length is 0 because the buffer may never have
    // been written: `from_raw_parts` requires the first `length` elements to be initialized.
    // `f64` has no destructor, so dropping with length 0 frees the same allocation.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
3202
/// Computes EMD from raw WASM pointers and writes `len` values into each of the upper, middle
/// and lower output buffers. `_close_ptr` and `_volume_ptr` are accepted for signature
/// compatibility and ignored.
///
/// In-place use is supported: if either input pointer equals any output pointer, the bands
/// are computed into temporary storage and then copied out, so inputs are never read after
/// being overwritten.
///
/// # Errors
/// Returns `"null pointer"` if any input or output pointer is null, or the error message
/// from the EMD computation.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn emd_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    _close_ptr: *const f64,
    _volume_ptr: *const f64,
    upper_ptr: *mut f64,
    middle_ptr: *mut f64,
    lower_ptr: *mut f64,
    len: usize,
    period: usize,
    delta: f64,
    fraction: f64,
) -> Result<(), JsValue> {
    if high_ptr.is_null()
        || low_ptr.is_null()
        || upper_ptr.is_null()
        || middle_ptr.is_null()
        || lower_ptr.is_null()
    {
        return Err(JsValue::from_str("null pointer"));
    }

    let params = EmdParams {
        period: Some(period),
        delta: Some(delta),
        fraction: Some(fraction),
    };

    // Compare BOTH inputs against ALL three outputs. An input reused as any output must go
    // through the copy path; otherwise a shared and a mutable slice would cover the same memory.
    let outputs = [
        upper_ptr as *const f64,
        middle_ptr as *const f64,
        lower_ptr as *const f64,
    ];
    let aliased = outputs
        .iter()
        .any(|&out| out == high_ptr || out == low_ptr);

    // SAFETY: the caller guarantees every pointer addresses `len` valid f64 values.
    unsafe {
        if aliased {
            // The shared input slices live only inside this block, so they are gone before any
            // (possibly identical) output buffer is borrowed mutably.
            let output = {
                let hi = std::slice::from_raw_parts(high_ptr, len);
                let lo = std::slice::from_raw_parts(low_ptr, len);
                let input = EmdInput::from_slices(hi, lo, &[], &[], params);
                emd_with_kernel(&input, detect_best_kernel())
                    .map_err(|e| JsValue::from_str(&e.to_string()))?
            };

            std::slice::from_raw_parts_mut(upper_ptr, len).copy_from_slice(&output.upperband);
            std::slice::from_raw_parts_mut(middle_ptr, len).copy_from_slice(&output.middleband);
            std::slice::from_raw_parts_mut(lower_ptr, len).copy_from_slice(&output.lowerband);
        } else {
            let hi = std::slice::from_raw_parts(high_ptr, len);
            let lo = std::slice::from_raw_parts(low_ptr, len);
            let ub = std::slice::from_raw_parts_mut(upper_ptr, len);
            let mb = std::slice::from_raw_parts_mut(middle_ptr, len);
            let lb = std::slice::from_raw_parts_mut(lower_ptr, len);
            let input = EmdInput::from_slices(hi, lo, &[], &[], params);
            emd_into_slices(ub, mb, lb, &input, detect_best_kernel())
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }
    Ok(())
}
3269
/// Sweep configuration deserialized from the JS `config` argument of `emd_batch`; each field
/// is a `(start, end, step)` triple for the corresponding parameter.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct EmdBatchConfig {
    pub period_range: (usize, usize, usize),
    pub delta_range: (f64, f64, f64),
    pub fraction_range: (f64, f64, f64),
}

/// Serialized batch result: each band holds `rows * cols` values, one row per entry of
/// `combos` and one column per input sample.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct EmdBatchJsOutput {
    pub upperband: Vec<f64>,
    pub middleband: Vec<f64>,
    pub lowerband: Vec<f64>,
    pub combos: Vec<EmdParams>,
    pub rows: usize,
    pub cols: usize,
}

/// JS entry point `emd_batch`: runs the EMD parameter sweep described by `config` over one
/// high/low series. `_close` and `_volume` are accepted for signature compatibility and
/// ignored. Errors are returned as strings.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "emd_batch")]
pub fn emd_batch_unified_js(
    high: &[f64],
    low: &[f64],
    _close: &[f64],
    _volume: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    if low.len() != high.len() {
        return Err(JsValue::from_str("high and low must have same length"));
    }

    let EmdBatchConfig {
        period_range,
        delta_range,
        fraction_range,
    } = serde_wasm_bindgen::from_value::<EmdBatchConfig>(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;
    let sweep = EmdBatchRange {
        period: period_range,
        delta: delta_range,
        fraction: fraction_range,
    };

    let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let cols = high.len();
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("rows * cols overflow in emd_batch_unified_js"))?;

    let (mut upperband, mut middleband, mut lowerband) = (
        vec![f64::NAN; total],
        vec![f64::NAN; total],
        vec![f64::NAN; total],
    );

    emd_batch_inner_into(
        high,
        low,
        &sweep,
        detect_best_kernel(),
        false,
        &mut upperband,
        &mut middleband,
        &mut lowerband,
    )
    .map_err(|e| JsValue::from_str(&e.to_string()))?;

    serde_wasm_bindgen::to_value(&EmdBatchJsOutput {
        upperband,
        middleband,
        lowerband,
        combos,
        rows,
        cols,
    })
    .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
3343
/// Batch EMD over raw WASM pointers. `high_ptr` and `low_ptr` address `len` values each. Each
/// output buffer must hold `rows * len` values, where `rows` is the number of parameter
/// combinations in the sweep. That `rows` count is returned.
///
/// `close_ptr` and `volume_ptr` are not read but must be non-null. If an input buffer overlaps
/// any output region, the inputs are copied first. That way the kernel never reads
/// already-overwritten data, and no shared and mutable slices cover the same memory.
///
/// # Errors
/// Returns an error string for any null pointer, an invalid sweep, a `rows * len` overflow,
/// or a computation failure.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn emd_batch_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    volume_ptr: *const f64,
    upper_ptr: *mut f64,
    middle_ptr: *mut f64,
    lower_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
    delta_start: f64,
    delta_end: f64,
    delta_step: f64,
    fraction_start: f64,
    fraction_end: f64,
    fraction_step: f64,
) -> Result<usize, JsValue> {
    /// True when `a_len` f64s starting at `a` and `b_len` f64s starting at `b` share memory.
    fn overlaps(a: *const f64, a_len: usize, b: *const f64, b_len: usize) -> bool {
        if a_len == 0 || b_len == 0 {
            return false;
        }
        let elem = std::mem::size_of::<f64>();
        let a_start = a as usize;
        let b_start = b as usize;
        let a_end = a_start.saturating_add(a_len.saturating_mul(elem));
        let b_end = b_start.saturating_add(b_len.saturating_mul(elem));
        a_start < b_end && b_start < a_end
    }

    if high_ptr.is_null()
        || low_ptr.is_null()
        || close_ptr.is_null()
        || volume_ptr.is_null()
        || upper_ptr.is_null()
        || middle_ptr.is_null()
        || lower_ptr.is_null()
    {
        return Err(JsValue::from_str("null pointer passed to emd_batch_into"));
    }

    let sweep = EmdBatchRange {
        period: (period_start, period_end, period_step),
        delta: (delta_start, delta_end, delta_step),
        fraction: (fraction_start, fraction_end, fraction_step),
    };

    let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let cols = len;
    let total_len = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("rows * cols overflow in emd_batch_into"))?;

    let inputs_aliased = [upper_ptr, middle_ptr, lower_ptr].iter().any(|&out| {
        let out = out as *const f64;
        overlaps(high_ptr, len, out, total_len) || overlaps(low_ptr, len, out, total_len)
    });

    // SAFETY: the caller guarantees the inputs address `len` valid f64s and each output
    // addresses `rows * len` writable f64s. When inputs overlap outputs they are copied before
    // any mutable output slice exists, so no shared/mutable slices alias.
    unsafe {
        // Owned input copies; only initialized on the aliased path.
        let high_copy: Vec<f64>;
        let low_copy: Vec<f64>;
        let (high, low): (&[f64], &[f64]) = if inputs_aliased {
            high_copy = std::slice::from_raw_parts(high_ptr, len).to_vec();
            low_copy = std::slice::from_raw_parts(low_ptr, len).to_vec();
            (high_copy.as_slice(), low_copy.as_slice())
        } else {
            (
                std::slice::from_raw_parts(high_ptr, len),
                std::slice::from_raw_parts(low_ptr, len),
            )
        };

        let upper_slice = std::slice::from_raw_parts_mut(upper_ptr, total_len);
        let middle_slice = std::slice::from_raw_parts_mut(middle_ptr, total_len);
        let lower_slice = std::slice::from_raw_parts_mut(lower_ptr, total_len);

        emd_batch_inner_into(
            high,
            low,
            &sweep,
            detect_best_kernel(),
            false,
            upper_slice,
            middle_slice,
            lower_slice,
        )
        .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(rows)
}