Skip to main content

vector_ta/indicators/
emv.rs

1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{
4    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
5    make_uninit_matrix,
6};
7#[cfg(feature = "python")]
8use crate::utilities::kernel_validation::validate_kernel;
9#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
10use core::arch::x86_64::*;
11#[cfg(not(target_arch = "wasm32"))]
12use rayon::prelude::*;
13use std::error::Error;
14use std::mem::MaybeUninit;
15use thiserror::Error;
16
17#[cfg(all(feature = "python", feature = "cuda"))]
18use crate::cuda::cuda_available;
19#[cfg(all(feature = "python", feature = "cuda"))]
20use crate::cuda::moving_averages::DeviceArrayF32;
21#[cfg(all(feature = "python", feature = "cuda"))]
22use crate::cuda::CudaEmv;
23#[cfg(all(feature = "python", feature = "cuda"))]
24use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
25#[cfg(all(feature = "python", feature = "cuda"))]
26use cust::context::Context;
27#[cfg(all(feature = "python", feature = "cuda"))]
28use cust::memory::DeviceBuffer;
29#[cfg(feature = "python")]
30use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
31#[cfg(feature = "python")]
32use pyo3::exceptions::PyValueError;
33#[cfg(feature = "python")]
34use pyo3::prelude::*;
35#[cfg(feature = "python")]
36use pyo3::types::PyDict;
37#[cfg(all(feature = "python", feature = "cuda"))]
38use std::sync::Arc;
39
40#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
41use serde::{Deserialize, Serialize};
42#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
43use wasm_bindgen::prelude::*;
44
45#[derive(Debug, Clone)]
46pub enum EmvData<'a> {
47    Candles {
48        candles: &'a Candles,
49    },
50    Slices {
51        high: &'a [f64],
52        low: &'a [f64],
53        close: &'a [f64],
54        volume: &'a [f64],
55    },
56}
57
58#[derive(Debug, Clone)]
59pub struct EmvOutput {
60    pub values: Vec<f64>,
61}
62
63#[derive(Debug, Clone, Default)]
64#[cfg_attr(
65    all(target_arch = "wasm32", feature = "wasm"),
66    derive(Serialize, Deserialize)
67)]
68pub struct EmvParams;
69
70#[derive(Debug, Clone)]
71pub struct EmvInput<'a> {
72    pub data: EmvData<'a>,
73    pub params: EmvParams,
74}
75
76impl<'a> EmvInput<'a> {
77    #[inline(always)]
78    pub fn from_candles(candles: &'a Candles) -> Self {
79        Self {
80            data: EmvData::Candles { candles },
81            params: EmvParams,
82        }
83    }
84
85    #[inline(always)]
86    pub fn from_slices(
87        high: &'a [f64],
88        low: &'a [f64],
89        close: &'a [f64],
90        volume: &'a [f64],
91    ) -> Self {
92        Self {
93            data: EmvData::Slices {
94                high,
95                low,
96                close,
97                volume,
98            },
99            params: EmvParams,
100        }
101    }
102
103    #[inline(always)]
104    pub fn with_default_candles(candles: &'a Candles) -> Self {
105        Self::from_candles(candles)
106    }
107}
108
109#[derive(Copy, Clone, Debug, Default)]
110pub struct EmvBuilder {
111    kernel: Kernel,
112}
113
114impl EmvBuilder {
115    #[inline(always)]
116    pub fn new() -> Self {
117        Self::default()
118    }
119    #[inline(always)]
120    pub fn kernel(mut self, k: Kernel) -> Self {
121        self.kernel = k;
122        self
123    }
124    #[inline(always)]
125    pub fn apply(self, c: &Candles) -> Result<EmvOutput, EmvError> {
126        let input = EmvInput::from_candles(c);
127        emv_with_kernel(&input, self.kernel)
128    }
129    #[inline(always)]
130    pub fn apply_slices(
131        self,
132        high: &[f64],
133        low: &[f64],
134        close: &[f64],
135        volume: &[f64],
136    ) -> Result<EmvOutput, EmvError> {
137        let input = EmvInput::from_slices(high, low, close, volume);
138        emv_with_kernel(&input, self.kernel)
139    }
140    #[inline(always)]
141    pub fn into_stream(self) -> Result<EmvStream, EmvError> {
142        EmvStream::try_new()
143    }
144}
145
146#[derive(Debug, Error)]
147pub enum EmvError {
148    #[error("emv: input data slice is empty")]
149    EmptyInputData,
150    #[error("emv: All values are NaN")]
151    AllValuesNaN,
152    #[error("emv: invalid period: period = {period}, data length = {data_len}")]
153    InvalidPeriod { period: usize, data_len: usize },
154    #[error("emv: not enough valid data: needed = {needed}, valid = {valid}")]
155    NotEnoughValidData { needed: usize, valid: usize },
156    #[error("emv: output length mismatch: expected {expected}, got {got}")]
157    OutputLengthMismatch { expected: usize, got: usize },
158    #[error("emv: invalid range expansion: start={start} end={end} step={step}")]
159    InvalidRange {
160        start: isize,
161        end: isize,
162        step: isize,
163    },
164    #[error("emv: invalid kernel for batch: {0:?}")]
165    InvalidKernelForBatch(Kernel),
166    #[error("emv: invalid input: {0}")]
167    InvalidInput(&'static str),
168}
169
170#[inline]
171pub fn emv(input: &EmvInput) -> Result<EmvOutput, EmvError> {
172    emv_with_kernel(input, Kernel::Auto)
173}
174
175pub fn emv_with_kernel(input: &EmvInput, kernel: Kernel) -> Result<EmvOutput, EmvError> {
176    let (high, low, _close, volume) = match &input.data {
177        EmvData::Candles { candles } => {
178            let high = source_type(candles, "high");
179            let low = source_type(candles, "low");
180            let close = source_type(candles, "close");
181            let volume = source_type(candles, "volume");
182            (high, low, close, volume)
183        }
184        EmvData::Slices {
185            high,
186            low,
187            close,
188            volume,
189        } => (*high, *low, *close, *volume),
190    };
191
192    if high.is_empty() || low.is_empty() || volume.is_empty() {
193        return Err(EmvError::EmptyInputData);
194    }
195    let len = high.len().min(low.len()).min(volume.len());
196    if len == 0 {
197        return Err(EmvError::EmptyInputData);
198    }
199
200    let first = (0..len).find(|&i| !(high[i].is_nan() || low[i].is_nan() || volume[i].is_nan()));
201    let first = match first {
202        Some(idx) => idx,
203        None => return Err(EmvError::AllValuesNaN),
204    };
205
206    let has_second = (first + 1..len)
207        .find(|&i| !(high[i].is_nan() || low[i].is_nan() || volume[i].is_nan()))
208        .is_some();
209    if !has_second {
210        return Err(EmvError::NotEnoughValidData {
211            needed: 2,
212            valid: 1,
213        });
214    }
215
216    let mut out = alloc_with_nan_prefix(len, first + 1);
217    let chosen = match kernel {
218        Kernel::Auto => Kernel::Scalar,
219        other => other,
220    };
221
222    unsafe {
223        match chosen {
224            Kernel::Scalar | Kernel::ScalarBatch => emv_scalar(high, low, volume, first, &mut out),
225            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
226            Kernel::Avx2 | Kernel::Avx2Batch => emv_avx2(high, low, volume, first, &mut out),
227            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
228            Kernel::Avx512 | Kernel::Avx512Batch => emv_avx512(high, low, volume, first, &mut out),
229            _ => unreachable!(),
230        }
231    }
232    Ok(EmvOutput { values: out })
233}
234
235#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
236#[inline]
237pub fn emv_into(input: &EmvInput, out: &mut [f64]) -> Result<(), EmvError> {
238    emv_into_slice(out, input, Kernel::Auto)
239}
240
241#[inline]
242pub fn emv_into_slice(dst: &mut [f64], input: &EmvInput, kern: Kernel) -> Result<(), EmvError> {
243    let (high, low, _close, volume) = match &input.data {
244        EmvData::Candles { candles } => {
245            let high = source_type(candles, "high");
246            let low = source_type(candles, "low");
247            let close = source_type(candles, "close");
248            let volume = source_type(candles, "volume");
249            (high, low, close, volume)
250        }
251        EmvData::Slices {
252            high,
253            low,
254            close,
255            volume,
256        } => (*high, *low, *close, *volume),
257    };
258
259    if high.is_empty() || low.is_empty() || volume.is_empty() {
260        return Err(EmvError::EmptyInputData);
261    }
262    let len = high.len().min(low.len()).min(volume.len());
263    if len == 0 {
264        return Err(EmvError::EmptyInputData);
265    }
266
267    if dst.len() != len {
268        return Err(EmvError::OutputLengthMismatch {
269            expected: len,
270            got: dst.len(),
271        });
272    }
273
274    let first = (0..len).find(|&i| !(high[i].is_nan() || low[i].is_nan() || volume[i].is_nan()));
275    let first = match first {
276        Some(idx) => idx,
277        None => return Err(EmvError::AllValuesNaN),
278    };
279
280    let has_second = (first + 1..len)
281        .find(|&i| !(high[i].is_nan() || low[i].is_nan() || volume[i].is_nan()))
282        .is_some();
283    if !has_second {
284        return Err(EmvError::NotEnoughValidData {
285            needed: 2,
286            valid: 1,
287        });
288    }
289
290    let warm = first + 1;
291    let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
292    for v in &mut dst[..warm] {
293        *v = qnan;
294    }
295
296    let chosen = match kern {
297        Kernel::Auto => Kernel::Scalar,
298        other => other,
299    };
300
301    unsafe {
302        match chosen {
303            Kernel::Scalar | Kernel::ScalarBatch => emv_scalar(high, low, volume, first, dst),
304            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
305            Kernel::Avx2 | Kernel::Avx2Batch => emv_avx2(high, low, volume, first, dst),
306            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
307            Kernel::Avx512 | Kernel::Avx512Batch => emv_avx512(high, low, volume, first, dst),
308            _ => unreachable!(),
309        }
310    }
311    Ok(())
312}
313
/// Scalar EMV kernel.
///
/// Writes `out[i]` for every `i` in `first + 1 .. n`, where `n` is the shortest of `high`,
/// `low`, `volume` and `out`. Entries at or before `first` are left untouched, so callers
/// are expected to pre-fill that prefix.
///
/// * A row with NaN in `high`, `low` or `volume` produces NaN and does not replace the
///   previous midpoint.
/// * A row with `high == low` (zero range) produces NaN but does replace the previous
///   midpoint.
/// * Otherwise `out[i] = (mid[i] - prev_mid) / ((volume[i] / 10000) / (high[i] - low[i]))`.
///
/// # Panics
/// Panics if `first` is out of bounds for `high` or `low`.
#[inline]
pub fn emv_scalar(high: &[f64], low: &[f64], volume: &[f64], first: usize, out: &mut [f64]) {
    let mut last_mid = 0.5 * (high[first] + low[first]);

    // Bounding by `out.len()` keeps every write in bounds even when `out` is shorter
    // than the inputs.
    let len = high.len().min(low.len()).min(volume.len()).min(out.len());
    let start = first + 1;
    if start >= len {
        return;
    }

    let rows = high[start..len]
        .iter()
        .zip(&low[start..len])
        .zip(&volume[start..len])
        .zip(&mut out[start..len]);

    for (((&h, &l), &v), slot) in rows {
        if h.is_nan() || l.is_nan() || v.is_nan() {
            *slot = f64::NAN;
            continue;
        }

        let current_mid = 0.5 * (h + l);
        let range = h - l;
        *slot = if range == 0.0 {
            f64::NAN
        } else {
            let br = v / 10000.0 / range;
            (current_mid - last_mid) / br
        };
        last_mid = current_mid;
    }
}
355
/// EMV kernel selected for `Kernel::Avx512`.
///
/// EMV depends on the previous row's midpoint, so there is no vectorized variant here;
/// this forwards to [`emv_avx2`], which runs the same element-by-element loop.
#[inline]
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
pub fn emv_avx512(high: &[f64], low: &[f64], volume: &[f64], first: usize, out: &mut [f64]) {
    emv_avx2(high, low, volume, first, out);
}
361
/// EMV kernel selected for `Kernel::Avx2`.
///
/// EMV depends on the previous row's midpoint, so this runs the same element-by-element
/// arithmetic as `emv_scalar`; it exists so kernel dispatch has a target.
///
/// Writes `out[first + 1 .. n]`, where `n` is the shortest of `high`, `low`, `volume` and
/// `out`. NaN rows produce NaN without updating the previous midpoint; zero-range rows
/// produce NaN but do update it.
///
/// # Panics
/// Panics if `first` is out of bounds for `high` or `low`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn emv_avx2(high: &[f64], low: &[f64], volume: &[f64], first: usize, out: &mut [f64]) {
    let mut last_mid = 0.5 * (high[first] + low[first]);

    // Including `out.len()` keeps every write in bounds; re-slicing to `n` also lets
    // the optimizer drop per-iteration bounds checks.
    let n = high.len().min(low.len()).min(volume.len()).min(out.len());
    let (high, low, volume) = (&high[..n], &low[..n], &volume[..n]);
    let out = &mut out[..n];

    for i in first + 1..n {
        let (h, l, v) = (high[i], low[i], volume[i]);

        if !(h.is_nan() || l.is_nan() || v.is_nan()) {
            let range = h - l;
            let current_mid = 0.5 * (h + l);

            if range == 0.0 {
                out[i] = f64::NAN;
            } else {
                let br = (v / 10000.0) / range;
                out[i] = (current_mid - last_mid) / br;
            }
            last_mid = current_mid;
        } else {
            out[i] = f64::NAN;
        }
    }
}
400
/// AVX-512 "short" EMV variant; forwards to [`emv_avx2`].
///
/// # Safety
/// Compiled with `#[target_feature(enable = "avx512f")]`: the caller must ensure the
/// running CPU supports AVX-512F.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
pub unsafe fn emv_avx512_short(
    high: &[f64],
    low: &[f64],
    volume: &[f64],
    first: usize,
    out: &mut [f64],
) {
    emv_avx2(high, low, volume, first, out)
}

/// AVX-512 "long" EMV variant; forwards to [`emv_avx2`].
///
/// # Safety
/// Compiled with `#[target_feature(enable = "avx512f")]`: the caller must ensure the
/// running CPU supports AVX-512F.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
pub unsafe fn emv_avx512_long(
    high: &[f64],
    low: &[f64],
    volume: &[f64],
    first: usize,
    out: &mut [f64],
) {
    emv_avx2(high, low, volume, first, out)
}
424
425#[derive(Debug, Clone)]
426pub struct EmvStream {
427    last_mid: Option<f64>,
428}
429
430impl EmvStream {
431    pub fn try_new() -> Result<Self, EmvError> {
432        Ok(Self { last_mid: None })
433    }
434
435    #[inline(always)]
436    pub fn update(&mut self, high: f64, low: f64, volume: f64) -> Option<f64> {
437        if high.is_nan() || low.is_nan() || volume.is_nan() {
438            return None;
439        }
440        let current_mid = 0.5 * (high + low);
441        if self.last_mid.is_none() {
442            self.last_mid = Some(current_mid);
443            return None;
444        }
445        let last_mid = self.last_mid.unwrap();
446        let range = high - low;
447        if range == 0.0 {
448            self.last_mid = Some(current_mid);
449            return None;
450        }
451        let br = volume / 10000.0 / range;
452        let out = (current_mid - last_mid) / br;
453        self.last_mid = Some(current_mid);
454        Some(out)
455    }
456
457    #[inline(always)]
458    pub fn update_fast(&mut self, high: f64, low: f64, volume: f64) -> Option<f64> {
459        if high.is_nan() || low.is_nan() || volume.is_nan() {
460            return None;
461        }
462        let current_mid = 0.5 * (high + low);
463        if self.last_mid.is_none() {
464            self.last_mid = Some(current_mid);
465            return None;
466        }
467        let last_mid = self.last_mid.unwrap();
468        let range = high - low;
469        if range == 0.0 {
470            self.last_mid = Some(current_mid);
471            return None;
472        }
473
474        let inv_v = fast_recip_f64(volume);
475        let out = (current_mid - last_mid) * range * 10_000.0 * inv_v;
476        self.last_mid = Some(current_mid);
477        Some(out)
478    }
479}
480
/// One Newton–Raphson step toward `1/x`: `y1 = y0 * (2 - x * y0)`.
///
/// The residual `2 - x * y0` is evaluated with a single fused multiply-add, so it is
/// rounded only once. The previous form, `2 - x.mul_add(y0, 0.0)`, gained nothing from
/// the FMA.
#[inline(always)]
fn newton_refine_recip(y0: f64, x: f64) -> f64 {
    y0 * (-x).mul_add(y0, 2.0)
}

/// Reciprocal of `x`, built from a hardware estimate refined by two Newton steps.
///
/// This variant is compiled only when the target statically enables AVX-512F. It seeds
/// from `_mm512_rcp14_pd` (roughly 14-bit accurate) and refines twice.
#[cfg(all(
    feature = "nightly-avx",
    target_arch = "x86_64",
    target_feature = "avx512f"
))]
#[inline(always)]
fn fast_recip_f64(x: f64) -> f64 {
    use core::arch::x86_64::*;
    // SAFETY: this definition only exists when AVX-512F is enabled for the whole target.
    let y0 = unsafe {
        let rcp = _mm512_rcp14_pd(_mm512_set1_pd(x));
        _mm_cvtsd_f64(_mm512_castpd512_pd128(rcp))
    };
    newton_refine_recip(newton_refine_recip(y0, x), x)
}

/// Reciprocal of `x`: plain `1.0 / x` on targets without static AVX-512F support.
#[cfg(not(all(
    feature = "nightly-avx",
    target_arch = "x86_64",
    target_feature = "avx512f"
)))]
#[inline(always)]
fn fast_recip_f64(x: f64) -> f64 {
    1.0 / x
}
506
/// Parameter sweep for batch EMV.
///
/// EMV has no parameters, so the sweep is empty and a batch run always produces exactly
/// one row. `Default` is derived rather than hand-written (clippy `derivable_impls`).
#[derive(Clone, Debug, Default)]
pub struct EmvBatchRange {}
515
516#[derive(Clone, Debug, Default)]
517pub struct EmvBatchBuilder {
518    kernel: Kernel,
519    _range: EmvBatchRange,
520}
521
522impl EmvBatchBuilder {
523    pub fn new() -> Self {
524        Self::default()
525    }
526    pub fn kernel(mut self, k: Kernel) -> Self {
527        self.kernel = k;
528        self
529    }
530
531    pub fn apply_slices(
532        self,
533        high: &[f64],
534        low: &[f64],
535        close: &[f64],
536        volume: &[f64],
537    ) -> Result<EmvBatchOutput, EmvError> {
538        emv_batch_with_kernel(high, low, close, volume, self.kernel)
539    }
540
541    pub fn with_default_slices(
542        high: &[f64],
543        low: &[f64],
544        close: &[f64],
545        volume: &[f64],
546        k: Kernel,
547    ) -> Result<EmvBatchOutput, EmvError> {
548        EmvBatchBuilder::new()
549            .kernel(k)
550            .apply_slices(high, low, close, volume)
551    }
552
553    pub fn apply_candles(self, c: &Candles) -> Result<EmvBatchOutput, EmvError> {
554        let high = source_type(c, "high");
555        let low = source_type(c, "low");
556        let close = source_type(c, "close");
557        let volume = source_type(c, "volume");
558        self.apply_slices(high, low, close, volume)
559    }
560
561    pub fn with_default_candles(c: &Candles, k: Kernel) -> Result<EmvBatchOutput, EmvError> {
562        EmvBatchBuilder::new().kernel(k).apply_candles(c)
563    }
564}
565
566pub fn emv_batch_with_kernel(
567    high: &[f64],
568    low: &[f64],
569    _close: &[f64],
570    volume: &[f64],
571    kernel: Kernel,
572) -> Result<EmvBatchOutput, EmvError> {
573    let simd = match kernel {
574        Kernel::Auto => detect_best_batch_kernel(),
575        other if other.is_batch() => other,
576        other => return Err(EmvError::InvalidKernelForBatch(other)),
577    };
578    emv_batch_par_slice(high, low, volume, simd)
579}
580
581#[derive(Clone, Debug)]
582pub struct EmvBatchOutput {
583    pub values: Vec<f64>,
584    pub combos: Vec<EmvParams>,
585    pub rows: usize,
586    pub cols: usize,
587}
588
589impl EmvBatchOutput {
590    #[inline]
591    pub fn single_row(&self) -> &[f64] {
592        debug_assert_eq!(self.rows, 1);
593        &self.values[..self.cols]
594    }
595}
596
597#[inline(always)]
598fn expand_grid(_r: &EmvBatchRange) -> Vec<()> {
599    vec![()]
600}
601
602#[inline(always)]
603pub fn emv_batch_slice(
604    high: &[f64],
605    low: &[f64],
606    volume: &[f64],
607    kern: Kernel,
608) -> Result<EmvBatchOutput, EmvError> {
609    emv_batch_inner(high, low, volume, kern, false)
610}
611
612#[inline(always)]
613pub fn emv_batch_par_slice(
614    high: &[f64],
615    low: &[f64],
616    volume: &[f64],
617    kern: Kernel,
618) -> Result<EmvBatchOutput, EmvError> {
619    emv_batch_inner(high, low, volume, kern, true)
620}
621
622fn emv_batch_inner(
623    high: &[f64],
624    low: &[f64],
625    volume: &[f64],
626    kern: Kernel,
627    _parallel: bool,
628) -> Result<EmvBatchOutput, EmvError> {
629    let len = high.len().min(low.len()).min(volume.len());
630    if len == 0 {
631        return Err(EmvError::EmptyInputData);
632    }
633
634    let first = (0..len)
635        .find(|&i| !(high[i].is_nan() || low[i].is_nan() || volume[i].is_nan()))
636        .ok_or(EmvError::AllValuesNaN)?;
637
638    let valid = (first..len)
639        .filter(|&i| !(high[i].is_nan() || low[i].is_nan() || volume[i].is_nan()))
640        .count();
641    if valid < 2 {
642        return Err(EmvError::NotEnoughValidData { needed: 2, valid });
643    }
644
645    let rows = 1usize;
646    let cols = len;
647    let _ = rows
648        .checked_mul(cols)
649        .ok_or(EmvError::InvalidInput("rows*cols overflow"))?;
650
651    let mut buf_mu = make_uninit_matrix(rows, cols);
652    init_matrix_prefixes(&mut buf_mu, cols, &[first + 1]);
653
654    let mut guard = core::mem::ManuallyDrop::new(buf_mu);
655    let out: &mut [f64] =
656        unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len()) };
657
658    unsafe {
659        match kern {
660            Kernel::Scalar | Kernel::ScalarBatch => emv_scalar(high, low, volume, first, out),
661            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
662            Kernel::Avx2 | Kernel::Avx2Batch => emv_avx2(high, low, volume, first, out),
663            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
664            Kernel::Avx512 | Kernel::Avx512Batch => emv_avx512(high, low, volume, first, out),
665            _ => emv_scalar(high, low, volume, first, out),
666        }
667    }
668
669    let values = unsafe {
670        Vec::from_raw_parts(
671            guard.as_mut_ptr() as *mut f64,
672            guard.len(),
673            guard.capacity(),
674        )
675    };
676
677    Ok(EmvBatchOutput {
678        values,
679        combos: vec![EmvParams],
680        rows,
681        cols,
682    })
683}
684
685#[inline(always)]
686pub fn emv_row_scalar(
687    high: &[f64],
688    low: &[f64],
689    volume: &[f64],
690    first: usize,
691    _stride: usize,
692    _w_ptr: *const f64,
693    _inv_n: f64,
694    out: &mut [f64],
695) {
696    emv_scalar(high, low, volume, first, out);
697}
698
/// Row-kernel entry for the AVX2 path; `_stride`, `_w_ptr` and `_inv_n` are unused.
///
/// Dispatches to `emv_avx2`, mirroring `emv_row_avx512` → `emv_avx512`, so each row
/// entry targets its own kernel. The numerical result is the same as `emv_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn emv_row_avx2(
    high: &[f64],
    low: &[f64],
    volume: &[f64],
    first: usize,
    _stride: usize,
    _w_ptr: *const f64,
    _inv_n: f64,
    out: &mut [f64],
) {
    emv_avx2(high, low, volume, first, out);
}
713
/// Row-kernel entry for the AVX-512 path; `_stride`, `_w_ptr` and `_inv_n` are unused.
#[inline(always)]
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
pub fn emv_row_avx512(
    high: &[f64],
    low: &[f64],
    volume: &[f64],
    first: usize,
    _stride: usize,
    _w_ptr: *const f64,
    _inv_n: f64,
    out: &mut [f64],
) {
    emv_avx512(high, low, volume, first, out)
}
728
729#[inline(always)]
730fn expand_grid_emv(_r: &EmvBatchRange) -> Vec<()> {
731    vec![()]
732}
733
// Python binding `emv(high, low, close, volume, kernel=None)`.
// Computes EMV with the validated kernel while the GIL is released; `close` is accepted
// but not read. Errors surface as `ValueError`.
// (Plain `//` comments so the Python-visible `__doc__` stays unchanged.)
#[cfg(feature = "python")]
#[pyfunction(name = "emv")]
#[pyo3(signature = (high, low, close, volume, kernel=None))]
pub fn emv_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    volume: PyReadonlyArray1<'py, f64>,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let input = EmvInput::from_slices(
        high.as_slice()?,
        low.as_slice()?,
        close.as_slice()?,
        volume.as_slice()?,
    );
    let kern = validate_kernel(kernel, false)?;

    let values: Vec<f64> = py
        .allow_threads(|| emv_with_kernel(&input, kern).map(|o| o.values))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok(values.into_pyarray(py))
}
770
// Python class `EmvStream`: thin wrapper over the Rust `EmvStream`.
// (Plain `//` comments so the Python-visible `__doc__` strings stay unchanged.)
#[cfg(feature = "python")]
#[pyclass(name = "EmvStream")]
pub struct EmvStreamPy {
    stream: EmvStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl EmvStreamPy {
    #[new]
    fn new() -> PyResult<Self> {
        EmvStream::try_new()
            .map(|stream| EmvStreamPy { stream })
            .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    // `close` keeps the Python signature (high, low, close, volume); EMV ignores it.
    // The parameter keeps its name because pyo3 exposes it as a keyword argument.
    fn update(&mut self, high: f64, low: f64, close: f64, volume: f64) -> Option<f64> {
        let _ = close;
        self.stream.update(high, low, volume)
    }
}
790
/// Batch EMV written into a caller-provided buffer (used by the Python binding).
///
/// `out` must have exactly `min(high.len(), low.len(), volume.len())` elements. The
/// warm-up prefix is filled by `init_matrix_prefixes` and every later element by the
/// kernel. `_close`, `_range` and `_parallel` are unused. Returns the single (empty)
/// parameter combination.
#[cfg(feature = "python")]
fn emv_batch_inner_into(
    high: &[f64],
    low: &[f64],
    _close: &[f64],
    volume: &[f64],
    _range: &EmvBatchRange,
    kern: Kernel,
    _parallel: bool,
    out: &mut [f64],
) -> Result<Vec<EmvParams>, EmvError> {
    let len = high.len().min(low.len()).min(volume.len());
    if len == 0 {
        return Err(EmvError::EmptyInputData);
    }
    if out.len() != len {
        return Err(EmvError::OutputLengthMismatch {
            expected: len,
            got: out.len(),
        });
    }

    let is_valid = |i: &usize| !(high[*i].is_nan() || low[*i].is_nan() || volume[*i].is_nan());
    let first = (0..len).find(is_valid).ok_or(EmvError::AllValuesNaN)?;
    let valid = (first..len).filter(is_valid).count();
    if valid < 2 {
        return Err(EmvError::NotEnoughValidData { needed: 2, valid });
    }

    {
        // SAFETY: MaybeUninit<f64> has the same size and alignment as f64, and this view
        // covers exactly `out`; it only lets `init_matrix_prefixes` write the prefix.
        let out_mu: &mut [MaybeUninit<f64>] = unsafe {
            core::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
        };
        init_matrix_prefixes(out_mu, len, &[first + 1]);
    }

    match kern {
        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
        Kernel::Avx2 | Kernel::Avx2Batch => emv_avx2(high, low, volume, first, out),
        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
        Kernel::Avx512 | Kernel::Avx512Batch => emv_avx512(high, low, volume, first, out),
        _ => emv_scalar(high, low, volume, first, out),
    }

    Ok(vec![EmvParams])
}
847
// Python binding `emv_batch(high, low, close, volume, kernel=None)`.
// Returns a dict whose "values" entry is a (1, len) float64 array; `close` is not read.
// (Plain `//` comments so the Python-visible `__doc__` stays unchanged.)
#[cfg(feature = "python")]
#[pyfunction(name = "emv_batch")]
#[pyo3(signature = (high, low, close, volume, kernel=None))]
pub fn emv_batch_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    volume: PyReadonlyArray1<'py, f64>,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let (high_slice, low_slice, close_slice, volume_slice) = (
        high.as_slice()?,
        low.as_slice()?,
        close.as_slice()?,
        volume.as_slice()?,
    );
    let kern = validate_kernel(kernel, true)?;

    let sweep = EmvBatchRange {};
    let rows = expand_grid(&sweep).len();
    let cols = high_slice
        .len()
        .min(low_slice.len())
        .min(volume_slice.len());

    // The array starts uninitialized. emv_batch_inner_into either writes every element
    // or returns an error, in which case the array is never handed to Python.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [rows * cols], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    py.allow_threads(|| {
        let resolved = match kern {
            Kernel::Auto => detect_best_batch_kernel(),
            k => k,
        };
        emv_batch_inner_into(
            high_slice,
            low_slice,
            close_slice,
            volume_slice,
            &sweep,
            resolved,
            true,
            slice_out,
        )
    })
    .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    Ok(dict)
}
902
// WASM binding: computes EMV and returns a newly allocated array.
//
// The output is sized as min(high, low, volume), which is exactly what `emv_into_slice`
// requires. Including `close.len()` (as before) made a shorter `close`, which EMV never
// reads, trigger a spurious OutputLengthMismatch error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn emv_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    volume: &[f64],
) -> Result<Vec<f64>, JsValue> {
    let input = EmvInput::from_slices(high, low, close, volume);

    let len = high.len().min(low.len()).min(volume.len());
    let mut output = vec![0.0; len];

    emv_into_slice(&mut output, &input, Kernel::Auto)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    Ok(output)
}
920
// WASM binding: computes EMV from raw input pointers into `out_ptr`.
//
// Each pointer is assumed to address `len` f64 values in wasm linear memory (the JS
// caller's responsibility). If the output range overlaps any input range, even
// partially, the result is computed into a temporary buffer and copied afterwards.
// Writing directly would create a `&mut` slice that aliases live `&` input slices (UB)
// and could corrupt inputs that have not been read yet. The previous check only caught
// identical start pointers.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn emv_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    volume_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
) -> Result<(), JsValue> {
    if high_ptr.is_null()
        || low_ptr.is_null()
        || close_ptr.is_null()
        || volume_ptr.is_null()
        || out_ptr.is_null()
    {
        return Err(JsValue::from_str("null pointer passed to emv_into"));
    }

    // Byte-range overlap test between the output and an input buffer of `len` f64s.
    let bytes = len.saturating_mul(std::mem::size_of::<f64>());
    let out_start = out_ptr as usize;
    let out_end = out_start.saturating_add(bytes);
    let overlaps = |p: *const f64| {
        let start = p as usize;
        start < out_end && out_start < start.saturating_add(bytes)
    };
    let aliased =
        overlaps(high_ptr) || overlaps(low_ptr) || overlaps(close_ptr) || overlaps(volume_ptr);

    unsafe {
        // SAFETY: non-null (checked above); validity for `len` reads is the caller's contract.
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);
        let volume = std::slice::from_raw_parts(volume_ptr, len);

        let input = EmvInput::from_slices(high, low, close, volume);

        if aliased {
            let mut temp = vec![0.0; len];
            emv_into_slice(&mut temp, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            out.copy_from_slice(&temp);
        } else {
            // SAFETY: `out` overlaps no input, and the caller guarantees `len` writable f64s.
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            emv_into_slice(out, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }

        Ok(())
    }
}
967
// WASM helper: reserves (without initializing) room for `len` f64 values and hands the
// pointer to JS. The allocation is intentionally leaked; release it with
// `emv_free(ptr, len)`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn emv_alloc(len: usize) -> *mut f64 {
    let mut buf = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buf.as_mut_ptr()
}
976
// WASM helper: releases a buffer previously returned by `emv_alloc(len)`.
//
// The Vec is rebuilt with length 0 and capacity `len`. The buffer may never have been
// written, and `Vec::from_raw_parts` requires the first `length` elements to be
// initialized, so claiming length `len` was unsound. f64 has no destructor, so the
// length does not affect how the memory is freed.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn emv_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` must come from `emv_alloc(len)`, which allocated capacity for `len`
    // f64s via `Vec::with_capacity(len)` (exact for non-zero-sized types in practice).
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
986
/// Batch configuration for the WASM binding. Empty because EMV has no parameters.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct EmvBatchConfig {}

/// Serialized result of the WASM batch binding (field order is the serialized order).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct EmvBatchJsOutput {
    /// `rows * cols` values.
    pub values: Vec<f64>,
    /// One (empty) parameter set per row.
    pub combos: Vec<EmvParams>,
    pub rows: usize,
    pub cols: usize,
}
999
// WASM binding `emv_batch`: returns `{ values, combos, rows: 1, cols }` as a JS object.
//
// `cols` is min(high, low, volume), the exact length `emv_into_slice` requires.
// Including `close.len()` (as before) made a shorter `close`, which EMV never reads,
// fail with OutputLengthMismatch. `_config` is ignored because EMV has no parameters.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = emv_batch)]
pub fn emv_batch_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    volume: &[f64],
    _config: JsValue,
) -> Result<JsValue, JsValue> {
    let input = EmvInput::from_slices(high, low, close, volume);
    let len = high.len().min(low.len()).min(volume.len());

    let mut output = vec![0.0; len];
    let kernel = detect_best_kernel();
    emv_into_slice(&mut output, &input, kernel).map_err(|e| JsValue::from_str(&e.to_string()))?;

    let js_output = EmvBatchJsOutput {
        values: output,
        combos: vec![EmvParams],
        rows: 1,
        cols: len,
    };

    serde_wasm_bindgen::to_value(&js_output)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
1028
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// Computes EMV from raw WASM-memory pointers, writing `len` values to
/// `out_ptr`, and returns the number of output rows (always 1).
///
/// All five pointers must be valid for `len` `f64` elements. `out_ptr` may be
/// the same pointer as one of the inputs (in-place use from JS). In that case
/// the result is first computed into a scratch buffer and then copied out, so
/// the inputs are not overwritten mid-computation and no `&mut` slice aliases
/// a live input slice. Only exact pointer equality is detected; partially
/// overlapping buffers are not supported.
///
/// # Errors
/// Returns a `JsValue` error if any pointer is null or the EMV computation fails.
pub fn emv_batch_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    volume_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
) -> Result<usize, JsValue> {
    if high_ptr.is_null()
        || low_ptr.is_null()
        || close_ptr.is_null()
        || volume_ptr.is_null()
        || out_ptr.is_null()
    {
        return Err(JsValue::from_str("null pointer passed to emv_batch_into"));
    }

    let out_aliases_input = [high_ptr, low_ptr, close_ptr, volume_ptr]
        .iter()
        .any(|&p| p == out_ptr as *const f64);

    unsafe {
        // SAFETY: all pointers are non-null (checked above) and the caller
        // guarantees each is valid for `len` reads.
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);
        let volume = std::slice::from_raw_parts(volume_ptr, len);

        let input = EmvInput::from_slices(high, low, close, volume);
        let kernel = detect_best_kernel();

        if out_aliases_input {
            let mut scratch = vec![0.0; len];
            emv_into_slice(&mut scratch, &input, kernel)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            // SAFETY: the input slices are no longer used past this point, and
            // `out_ptr` is valid for `len` writes.
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            out.copy_from_slice(&scratch);
        } else {
            // SAFETY: `out_ptr` differs from every input pointer and is valid
            // for `len` writes.
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            emv_into_slice(out, &input, kernel).map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(1)
}
1064
#[cfg(test)]
mod tests {
    // Tests for the EMV indicator.
    //
    // Each `check_*` helper takes a test name and a `Kernel` and returns a
    // `Result`. The `generate_all_emv_tests!` and `gen_batch_tests!` macros
    // below expand them into one `#[test]` per kernel and unwrap that `Result`,
    // so an `Err` propagated with `?` (missing data file, failing EMV call)
    // fails the test instead of being silently discarded.
    use super::*;
    use crate::skip_if_unsupported;
    use crate::utilities::data_loader::read_candles_from_csv;

    /// Checks the last five EMV values computed from the bundled Bitfinex 4h
    /// candle CSV against reference values (0.01% relative tolerance).
    fn check_emv_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let input = EmvInput::from_candles(&candles);
        let output = emv_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());
        let expected_last_five_emv = [
            -6488905.579799851,
            2371436.7401001123,
            -3855069.958128531,
            1051939.877943717,
            -8519287.22257077,
        ];
        let start = output.values.len().saturating_sub(5);
        for (i, &val) in output.values[start..].iter().enumerate() {
            let diff = (val - expected_last_five_emv[i]).abs();
            let tol = expected_last_five_emv[i].abs() * 0.0001;
            assert!(
                diff <= tol,
                "[{}] EMV {:?} mismatch at idx {}: got {}, expected {}, diff={}",
                test_name,
                kernel,
                i,
                val,
                expected_last_five_emv[i],
                diff
            );
        }
        Ok(())
    }

    /// The default-candles constructor yields one output per input bar.
    fn check_emv_with_default_candles(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let input = EmvInput::with_default_candles(&candles);
        let output = emv_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());
        Ok(())
    }

    /// Empty inputs must be rejected with an error.
    fn check_emv_empty_data(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let empty: [f64; 0] = [];
        let input = EmvInput::from_slices(&empty, &empty, &empty, &empty);
        let result = emv_with_kernel(&input, kernel);
        assert!(result.is_err());
        Ok(())
    }

    /// All-NaN inputs must be rejected with an error.
    fn check_emv_all_nan(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let nan_arr = [f64::NAN, f64::NAN];
        let input = EmvInput::from_slices(&nan_arr, &nan_arr, &nan_arr, &nan_arr);
        let result = emv_with_kernel(&input, kernel);
        assert!(result.is_err());
        Ok(())
    }

    /// A single valid bar followed by NaN is not enough data and must error.
    fn check_emv_not_enough_data(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let high = [10000.0, f64::NAN];
        let low = [9990.0, f64::NAN];
        let close = [9995.0, f64::NAN];
        let volume = [1_000_000.0, f64::NAN];
        let input = EmvInput::from_slices(&high, &low, &close, &volume);
        let result = emv_with_kernel(&input, kernel);
        assert!(result.is_err());
        Ok(())
    }

    /// On a tiny clean series the first value is NaN (warmup) and the rest are not.
    fn check_emv_basic_calculation(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let high = [10.0, 12.0, 13.0, 15.0];
        let low = [5.0, 7.0, 8.0, 10.0];
        let close = [7.5, 9.0, 10.5, 12.5];
        let volume = [10000.0, 20000.0, 25000.0, 30000.0];
        let input = EmvInput::from_slices(&high, &low, &close, &volume);
        let output = emv_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), 4);
        assert!(output.values[0].is_nan());
        for &val in &output.values[1..] {
            assert!(!val.is_nan());
        }
        Ok(())
    }

    /// `EmvStream` fed bar-by-bar must reproduce the batch output (abs tol 1e-9).
    fn check_emv_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let high = source_type(&candles, "high");
        let low = source_type(&candles, "low");
        let volume = source_type(&candles, "volume");

        let output = emv_with_kernel(&EmvInput::from_candles(&candles), kernel)?.values;

        let mut stream = EmvStream::try_new()?;
        let mut stream_values = Vec::with_capacity(high.len());
        for i in 0..high.len() {
            match stream.update(high[i], low[i], volume[i]) {
                Some(val) => stream_values.push(val),
                None => stream_values.push(f64::NAN),
            }
        }
        assert_eq!(output.len(), stream_values.len());
        for (b, s) in output.iter().zip(stream_values.iter()) {
            if b.is_nan() && s.is_nan() {
                continue;
            }
            let diff = (b - s).abs();
            assert!(
                diff < 1e-9,
                "[{}] EMV streaming f64 mismatch: batch={}, stream={}, diff={}",
                test_name,
                b,
                s,
                diff
            );
        }
        Ok(())
    }

    /// Debug builds only: no output value may carry one of the allocation
    /// helpers' poison bit patterns (0x11.., 0x22.., 0x33..).
    #[cfg(debug_assertions)]
    fn check_emv_no_poison(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);

        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let input1 = EmvInput::from_candles(&candles);
        let output1 = emv_with_kernel(&input1, kernel)?;

        let high = source_type(&candles, "high");
        let low = source_type(&candles, "low");
        let close = source_type(&candles, "close");
        let volume = source_type(&candles, "volume");
        let input2 = EmvInput::from_slices(high, low, close, volume);
        let output2 = emv_with_kernel(&input2, kernel)?;

        let input3 = EmvInput::with_default_candles(&candles);
        let output3 = emv_with_kernel(&input3, kernel)?;

        let outputs = [
            ("from_candles", &output1.values),
            ("from_slices", &output2.values),
            ("with_default_candles", &output3.values),
        ];

        for (method_name, values) in &outputs {
            for (i, &val) in values.iter().enumerate() {
                if val.is_nan() {
                    continue;
                }

                let bits = val.to_bits();

                if bits == 0x11111111_11111111 {
                    panic!(
                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
                         using method: {}",
                        test_name, val, bits, i, method_name
                    );
                }

                if bits == 0x22222222_22222222 {
                    panic!(
                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
                         using method: {}",
                        test_name, val, bits, i, method_name
                    );
                }

                if bits == 0x33333333_33333333 {
                    panic!(
                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
                         using method: {}",
                        test_name, val, bits, i, method_name
                    );
                }
            }
        }

        Ok(())
    }

    /// Release builds: the poison check is a no-op.
    #[cfg(not(debug_assertions))]
    fn check_emv_no_poison(
        _test_name: &str,
        _kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        Ok(())
    }

    /// Property test over random positive high/low/volume series. It checks:
    /// - warmup NaN;
    /// - finiteness when the range is non-zero;
    /// - agreement with the scalar kernel (<= 3 ULP);
    /// - the closed-form EMV formula;
    /// - a magnitude sanity bound;
    /// - ~0 output on constant prices;
    /// - absence of poison values.
    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_emv_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        let strat = prop::collection::vec(
            (10.0f64..100000.0f64, 0.5f64..0.999f64, 1000.0f64..1e9f64),
            2..400,
        )
        .prop_map(|data| {
            let high: Vec<f64> = data.iter().map(|(h, _, _)| *h).collect();
            let low: Vec<f64> = data
                .iter()
                .zip(&high)
                .map(|((_, l_pct, _), h)| h * l_pct)
                .collect();
            let volume: Vec<f64> = data.iter().map(|(_, _, v)| *v).collect();

            let close = high.clone();
            (high, low, close, volume)
        });

        proptest::test_runner::TestRunner::default()
            .run(&strat, |(high, low, close, volume)| {
                let input = EmvInput::from_slices(&high, &low, &close, &volume);

                let EmvOutput { values: out } = emv_with_kernel(&input, kernel).unwrap();

                let EmvOutput { values: ref_out } =
                    emv_with_kernel(&input, Kernel::Scalar).unwrap();

                prop_assert!(
                    out[0].is_nan(),
                    "First EMV value should always be NaN (warmup period)"
                );

                for i in 1..out.len() {
                    if high[i].is_finite() && low[i].is_finite() && volume[i].is_finite() {
                        let range = high[i] - low[i];
                        if range != 0.0 {
                            prop_assert!(
                                out[i].is_finite(),
                                "EMV at index {} should be finite when inputs are finite and range != 0",
                                i
                            );
                        }
                    }
                }

                for i in 0..out.len() {
                    let y = out[i];
                    let r = ref_out[i];

                    if !y.is_finite() || !r.is_finite() {
                        prop_assert!(
                            y.to_bits() == r.to_bits(),
                            "Non-finite mismatch at index {}: {} vs {}",
                            i,
                            y,
                            r
                        );
                    } else {
                        let y_bits = y.to_bits();
                        let r_bits = r.to_bits();
                        let ulp_diff = y_bits.abs_diff(r_bits);

                        prop_assert!(
                            ulp_diff <= 3,
                            "ULP difference too large at index {}: {} vs {} (ULP={})",
                            i,
                            y,
                            r,
                            ulp_diff
                        );
                    }
                }

                // Reference formula:
                // EMV_i = (mid_i - mid_{i-1}) / (volume_i / 10000 / (high_i - low_i)).
                let mut last_mid = 0.5 * (high[0] + low[0]);
                for i in 1..out.len() {
                    let current_mid = 0.5 * (high[i] + low[i]);
                    let range = high[i] - low[i];

                    if range == 0.0 {
                        prop_assert!(
                            out[i].is_nan(),
                            "EMV at index {} should be NaN when range is zero",
                            i
                        );
                    } else {
                        let expected_emv = (current_mid - last_mid) / (volume[i] / 10000.0 / range);

                        if out[i].is_finite() && expected_emv.is_finite() {
                            let diff = (out[i] - expected_emv).abs();
                            let tolerance = 1e-9;
                            prop_assert!(
                                diff <= tolerance,
                                "EMV formula mismatch at index {}: got {}, expected {}, diff={}",
                                i,
                                out[i],
                                expected_emv,
                                diff
                            );
                        }
                    }

                    last_mid = current_mid;
                }

                for i in 1..out.len() {
                    if out[i].is_finite() {
                        let price_change =
                            (high[i] + low[i]) / 2.0 - (high[i - 1] + low[i - 1]) / 2.0;
                        let max_reasonable = price_change.abs() * 1e8;

                        prop_assert!(
                            out[i].abs() <= max_reasonable,
                            "EMV at index {} seems unreasonably large: {} (price change: {})",
                            i,
                            out[i],
                            price_change
                        );
                    }
                }

                if high.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10)
                    && low.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10)
                    && high.iter().zip(&low).all(|(h, l)| h > l)
                {
                    for i in 1..out.len() {
                        if out[i].is_finite() {
                            prop_assert!(
                                out[i].abs() < 1e-9,
                                "EMV should be ~0 for constant prices, got {} at index {}",
                                out[i],
                                i
                            );
                        }
                    }
                }

                for (i, &val) in out.iter().enumerate() {
                    if !val.is_nan() {
                        let bits = val.to_bits();
                        prop_assert!(
                            bits != 0x11111111_11111111
                                && bits != 0x22222222_22222222
                                && bits != 0x33333333_33333333,
                            "Found poison value at index {}: {} (0x{:016X})",
                            i,
                            val,
                            bits
                        );
                    }
                }

                Ok(())
            })
            .unwrap();

        Ok(())
    }

    // Expands each `check_*` helper into per-kernel `#[test]` functions.
    // The SIMD `cfg` gate is attached to every generated SIMD test
    // individually. An outer attribute applies only to the single item that
    // follows it, so one gate in front of a `$( ... )*` repetition would
    // guard just the first expanded function.
    macro_rules! generate_all_emv_tests {
        ($($test_fn:ident),*) => {
            paste::paste! {
                $(
                    #[test]
                    fn [<$test_fn _scalar_f64>]() {
                        $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar).unwrap();
                    }

                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx2_f64>]() {
                        $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2).unwrap();
                    }

                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx512_f64>]() {
                        $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512).unwrap();
                    }
                )*
            }
        }
    }

    generate_all_emv_tests!(
        check_emv_accuracy,
        check_emv_with_default_candles,
        check_emv_empty_data,
        check_emv_all_nan,
        check_emv_not_enough_data,
        check_emv_basic_calculation,
        check_emv_streaming,
        check_emv_no_poison
    );

    #[cfg(feature = "proptest")]
    generate_all_emv_tests!(check_emv_property);

    /// The batch builder yields one value per candle for the single EMV row.
    fn check_batch_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);
        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;
        let output = EmvBatchBuilder::new().kernel(kernel).apply_candles(&c)?;
        assert_eq!(output.values.len(), c.close.len());
        Ok(())
    }

    // Expands a batch `check_*` helper into one `#[test]` per batch kernel
    // (plus `Kernel::Auto`). The result is unwrapped so errors fail the test.
    macro_rules! gen_batch_tests {
        ($fn_name:ident) => {
            paste::paste! {
                #[test] fn [<$fn_name _scalar>]() {
                    $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch).unwrap();
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx2>]() {
                    $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch).unwrap();
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx512>]() {
                    $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch).unwrap();
                }
                #[test] fn [<$fn_name _auto_detect>]() {
                    $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto).unwrap();
                }
            }
        };
    }
    gen_batch_tests!(check_batch_row);
    gen_batch_tests!(check_batch_no_poison);

    /// Debug builds only: the batch output matrix must not contain any of the
    /// allocation helpers' poison bit patterns.
    #[cfg(debug_assertions)]
    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test);

        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;

        let output = EmvBatchBuilder::new().kernel(kernel).apply_candles(&c)?;

        for (idx, &val) in output.values.iter().enumerate() {
            if val.is_nan() {
                continue;
            }

            let bits = val.to_bits();
            let row = idx / output.cols;
            let col = idx % output.cols;

            if bits == 0x11111111_11111111 {
                panic!(
                    "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
                     at row {} col {} (flat index {})",
                    test, val, bits, row, col, idx
                );
            }

            if bits == 0x22222222_22222222 {
                panic!(
                    "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) \
                     at row {} col {} (flat index {})",
                    test, val, bits, row, col, idx
                );
            }

            if bits == 0x33333333_33333333 {
                panic!(
                    "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) \
                     at row {} col {} (flat index {})",
                    test, val, bits, row, col, idx
                );
            }
        }

        Ok(())
    }

    /// Release builds: the batch poison check is a no-op.
    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(
        _test: &str,
        _kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        Ok(())
    }

    /// The `into` API (or `emv_into_slice` on wasm) must match the allocating
    /// `emv` API on a synthetic series (NaN == NaN, otherwise abs tol 1e-12).
    #[test]
    fn test_emv_into_matches_api() -> Result<(), Box<dyn Error>> {
        let n = 256usize;
        let mut high = Vec::with_capacity(n);
        let mut low = Vec::with_capacity(n);
        let mut close = Vec::with_capacity(n);
        let mut volume = Vec::with_capacity(n);
        for i in 0..n {
            let base = 100.0 + (i as f64) * 0.1;
            let spread = 1.0 + ((i % 5) as f64) * 0.2;
            let h = base + spread * 0.6;
            let l = base - spread * 0.4;
            high.push(h);
            low.push(l);
            close.push(0.5 * (h + l));
            volume.push(10_000.0 + ((i * 37) % 1000) as f64 * 100.0);
        }

        let input = EmvInput::from_slices(&high, &low, &close, &volume);
        let baseline = emv(&input)?.values;

        let mut into_out = vec![0.0; baseline.len()];

        #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
        {
            emv_into(&input, &mut into_out)?;
        }
        #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
        {
            emv_into_slice(&mut into_out, &input, Kernel::Auto)?;
        }

        assert_eq!(baseline.len(), into_out.len());
        fn eq_or_both_nan(a: f64, b: f64) -> bool {
            (a.is_nan() && b.is_nan()) || (a == b) || (a - b).abs() <= 1e-12
        }
        for (i, (a, b)) in baseline.iter().zip(into_out.iter()).enumerate() {
            assert!(
                eq_or_both_nan(*a, *b),
                "divergence at idx {}: api={}, into={}",
                i,
                a,
                b
            );
        }
        Ok(())
    }
}
1610
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", name = "EmvDeviceArrayF32", unsendable)]
/// Python handle (`ta_indicators.cuda.EmvDeviceArrayF32`) for an EMV result
/// held in a CUDA device buffer.
///
/// The data can be exposed through `__cuda_array_interface__` or exported via
/// `__dlpack__`.
pub struct EmvDeviceArrayF32Py {
    /// Device buffer together with its `rows` x `cols` shape.
    pub inner: DeviceArrayF32,
    // Shared handle to the CUDA context obtained from `CudaEmv::context_arc`.
    // NOTE(review): presumably held so the context outlives the buffer; confirm.
    _ctx_guard: Arc<Context>,
    // CUDA device ordinal, reported by `__dlpack_device__`.
    device_id: i32,
}
1618
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl EmvDeviceArrayF32Py {
    /// Describes the buffer via the CUDA Array Interface (version 3).
    ///
    /// Reports a C-contiguous `(rows, cols)` little-endian float32 (`<f4`)
    /// matrix. The interface requires a data pointer of 0 for zero-size arrays.
    /// That case also covers the empty placeholder left behind after
    /// `__dlpack__` has moved the buffer out.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let d = PyDict::new(py);
        let inner = &self.inner;
        let itemsize = std::mem::size_of::<f32>();
        d.set_item("shape", (inner.rows, inner.cols))?;
        d.set_item("typestr", "<f4")?;
        d.set_item("strides", (inner.cols * itemsize, itemsize))?;
        let ptr_val = if inner.rows == 0 || inner.cols == 0 {
            0usize
        } else {
            inner.buf.as_device_ptr().as_raw() as usize
        };
        // The second tuple element is the read-only flag.
        d.set_item("data", (ptr_val, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    /// Returns `(device_type, device_id)` for DLPack. `2` is DLPack's CUDA
    /// device type (`kDLCUDA`).
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.device_id)
    }

    /// Exports the buffer through `export_f32_cuda_dlpack_2d`.
    ///
    /// The device buffer is moved out and replaced by an empty 0x0 placeholder.
    /// Only the first call therefore exports the real data.
    ///
    /// Raises `ValueError` when `dl_device` names a different device than the
    /// one the buffer lives on. The error message says whether a copy was
    /// requested; cross-device copies are not implemented.
    /// NOTE(review): `stream` is accepted but ignored, so no synchronization
    /// with the consumer's stream happens here.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        let (kdl, alloc_dev) = self.__dlpack_device__();
        if let Some(dev_obj) = dl_device.as_ref() {
            // A `dl_device` that does not extract as `(i32, i32)` is ignored.
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(PyValueError::new_err(
                            "device copy not implemented for __dlpack__",
                        ));
                    } else {
                        return Err(PyValueError::new_err("dl_device mismatch for __dlpack__"));
                    }
                }
            }
        }
        let _ = stream;

        // Move the real buffer out, leaving an empty placeholder in `self`.
        let dummy =
            DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let inner = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32 {
                buf: dummy,
                rows: 0,
                cols: 0,
            },
        );

        let rows = inner.rows;
        let cols = inner.cols;
        let buf = inner.buf;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
1690
#[cfg(all(feature = "python", feature = "cuda"))]
impl EmvDeviceArrayF32Py {
    /// Wraps a computed device array together with the CUDA context it was
    /// produced in and the device ordinal. The ordinal is stored as `i32`
    /// because that is what `__dlpack_device__` reports.
    fn new_from_cuda(inner: DeviceArrayF32, ctx_guard: Arc<Context>, device_id: u32) -> Self {
        let device_id = device_id as i32;
        Self {
            inner,
            device_id,
            _ctx_guard: ctx_guard,
        }
    }
}
1701
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "emv_cuda_batch_dev")]
#[pyo3(signature = (high_f32, low_f32, volume_f32, device_id=0))]
/// Computes EMV on the GPU from 1-D float32 high/low/volume arrays and returns
/// the result as a device-resident `EmvDeviceArrayF32`.
///
/// The GIL is released while the CUDA work runs. A missing CUDA runtime and
/// any `CudaEmv` error are raised as `ValueError`. Errors from `as_slice`
/// (e.g. non-contiguous arrays) are propagated as-is.
pub fn emv_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    high_f32: numpy::PyReadonlyArray1<'py, f32>,
    low_f32: numpy::PyReadonlyArray1<'py, f32>,
    volume_f32: numpy::PyReadonlyArray1<'py, f32>,
    device_id: usize,
) -> PyResult<EmvDeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let high = high_f32.as_slice()?;
    let low = low_f32.as_slice()?;
    let volume = volume_f32.as_slice()?;

    let computed = py.allow_threads(|| -> PyResult<_> {
        let engine = CudaEmv::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx_guard = engine.context_arc();
        let ordinal = engine.device_id();
        let device_out = engine
            .emv_batch_dev(high, low, volume)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((device_out, ctx_guard, ordinal))
    })?;

    let (device_out, ctx_guard, ordinal) = computed;
    Ok(EmvDeviceArrayF32Py::new_from_cuda(device_out, ctx_guard, ordinal))
}
1729
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "emv_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm_f32, low_tm_f32, volume_tm_f32, device_id=0))]
/// Computes EMV on the GPU for many series stored in 2-D float32 arrays and
/// returns a device-resident `EmvDeviceArrayF32`.
///
/// `rows` is read from axis 0 and `cols` from axis 1 of `high_tm_f32`. They are
/// forwarded to the CUDA wrapper in `(cols, rows)` order.
/// NOTE(review): given the `_tm` (time-major) naming, axis 0 is presumably time
/// and axis 1 the series index; confirm.
/// All three arrays must share the same shape, otherwise `ValueError` is raised.
/// The GIL is released while the CUDA work runs.
pub fn emv_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    high_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    low_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    volume_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    device_id: usize,
) -> PyResult<EmvDeviceArrayF32Py> {
    use numpy::PyUntypedArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let high = high_tm_f32.as_slice()?;
    let low = low_tm_f32.as_slice()?;
    let volume = volume_tm_f32.as_slice()?;

    let dims = high_tm_f32.shape();
    let (rows, cols) = (dims[0], dims[1]);
    let matches_high = |shape: &[usize]| shape == [rows, cols];
    if !matches_high(low_tm_f32.shape()) || !matches_high(volume_tm_f32.shape()) {
        return Err(PyValueError::new_err("high/low/volume shapes mismatch"));
    }

    let (device_out, ctx_guard, ordinal) = py.allow_threads(|| -> PyResult<_> {
        let engine = CudaEmv::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx_guard = engine.context_arc();
        let ordinal = engine.device_id();
        let device_out = engine
            .emv_many_series_one_param_time_major_dev(high, low, volume, cols, rows)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((device_out, ctx_guard, ordinal))
    })?;

    Ok(EmvDeviceArrayF32Py::new_from_cuda(device_out, ctx_guard, ordinal))
}