1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{
4 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
5 make_uninit_matrix,
6};
7#[cfg(feature = "python")]
8use crate::utilities::kernel_validation::validate_kernel;
9#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
10use core::arch::x86_64::*;
11#[cfg(not(target_arch = "wasm32"))]
12use rayon::prelude::*;
13use std::error::Error;
14use std::mem::MaybeUninit;
15use thiserror::Error;
16
17#[cfg(all(feature = "python", feature = "cuda"))]
18use crate::cuda::cuda_available;
19#[cfg(all(feature = "python", feature = "cuda"))]
20use crate::cuda::moving_averages::DeviceArrayF32;
21#[cfg(all(feature = "python", feature = "cuda"))]
22use crate::cuda::CudaEmv;
23#[cfg(all(feature = "python", feature = "cuda"))]
24use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
25#[cfg(all(feature = "python", feature = "cuda"))]
26use cust::context::Context;
27#[cfg(all(feature = "python", feature = "cuda"))]
28use cust::memory::DeviceBuffer;
29#[cfg(feature = "python")]
30use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
31#[cfg(feature = "python")]
32use pyo3::exceptions::PyValueError;
33#[cfg(feature = "python")]
34use pyo3::prelude::*;
35#[cfg(feature = "python")]
36use pyo3::types::PyDict;
37#[cfg(all(feature = "python", feature = "cuda"))]
38use std::sync::Arc;
39
40#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
41use serde::{Deserialize, Serialize};
42#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
43use wasm_bindgen::prelude::*;
44
/// Input data for the EMV (Ease of Movement) computation.
///
/// Either a full [`Candles`] set, from which the `"high"`, `"low"`, `"close"` and
/// `"volume"` series are looked up with `source_type`, or four borrowed slices.
#[derive(Debug, Clone)]
pub enum EmvData<'a> {
    /// Read the series from a candle set.
    Candles {
        candles: &'a Candles,
    },
    /// Caller-supplied series. `close` is accepted for API symmetry; the EMV kernels in
    /// this file only read `high`, `low` and `volume`.
    Slices {
        high: &'a [f64],
        low: &'a [f64],
        close: &'a [f64],
        volume: &'a [f64],
    },
}
57
/// Result of a single-series EMV computation.
#[derive(Debug, Clone)]
pub struct EmvOutput {
    /// One value per bar. The warm-up prefix, bars with a NaN input and bars with a zero
    /// high-low range hold NaN.
    pub values: Vec<f64>,
}
62
/// EMV parameters. The indicator has no tunable parameters, so this is a unit struct
/// kept for API uniformity; it is (de)serialisable in wasm builds.
#[derive(Debug, Clone, Default)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct EmvParams;
69
/// Complete input for an EMV computation: the data source plus (empty) parameters.
#[derive(Debug, Clone)]
pub struct EmvInput<'a> {
    /// Where the high/low/close/volume series come from.
    pub data: EmvData<'a>,
    /// Always `EmvParams`; EMV has no parameters.
    pub params: EmvParams,
}
75
76impl<'a> EmvInput<'a> {
77 #[inline(always)]
78 pub fn from_candles(candles: &'a Candles) -> Self {
79 Self {
80 data: EmvData::Candles { candles },
81 params: EmvParams,
82 }
83 }
84
85 #[inline(always)]
86 pub fn from_slices(
87 high: &'a [f64],
88 low: &'a [f64],
89 close: &'a [f64],
90 volume: &'a [f64],
91 ) -> Self {
92 Self {
93 data: EmvData::Slices {
94 high,
95 low,
96 close,
97 volume,
98 },
99 params: EmvParams,
100 }
101 }
102
103 #[inline(always)]
104 pub fn with_default_candles(candles: &'a Candles) -> Self {
105 Self::from_candles(candles)
106 }
107}
108
/// Builder for one-shot EMV computations. Only the kernel is configurable; the
/// `Default` kernel is used unless [`EmvBuilder::kernel`] is called.
#[derive(Copy, Clone, Debug, Default)]
pub struct EmvBuilder {
    /// Kernel forwarded to `emv_with_kernel`.
    kernel: Kernel,
}
113
114impl EmvBuilder {
115 #[inline(always)]
116 pub fn new() -> Self {
117 Self::default()
118 }
119 #[inline(always)]
120 pub fn kernel(mut self, k: Kernel) -> Self {
121 self.kernel = k;
122 self
123 }
124 #[inline(always)]
125 pub fn apply(self, c: &Candles) -> Result<EmvOutput, EmvError> {
126 let input = EmvInput::from_candles(c);
127 emv_with_kernel(&input, self.kernel)
128 }
129 #[inline(always)]
130 pub fn apply_slices(
131 self,
132 high: &[f64],
133 low: &[f64],
134 close: &[f64],
135 volume: &[f64],
136 ) -> Result<EmvOutput, EmvError> {
137 let input = EmvInput::from_slices(high, low, close, volume);
138 emv_with_kernel(&input, self.kernel)
139 }
140 #[inline(always)]
141 pub fn into_stream(self) -> Result<EmvStream, EmvError> {
142 EmvStream::try_new()
143 }
144}
145
/// Errors produced by the EMV functions.
#[derive(Debug, Error)]
pub enum EmvError {
    /// The shortest of high/low/volume is empty.
    #[error("emv: input data slice is empty")]
    EmptyInputData,
    /// No index has high, low and volume all non-NaN.
    #[error("emv: All values are NaN")]
    AllValuesNaN,
    /// NOTE(review): not constructed by the visible code in this file (EMV has no period).
    #[error("emv: invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },
    /// Fewer than `needed` (two) fully valid bars were found.
    #[error("emv: not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// A caller-provided output buffer has the wrong length.
    #[error("emv: output length mismatch: expected {expected}, got {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// NOTE(review): not constructed by the visible code in this file.
    #[error("emv: invalid range expansion: start={start} end={end} step={step}")]
    InvalidRange {
        start: isize,
        end: isize,
        step: isize,
    },
    /// A non-batch kernel was passed to a batch entry point.
    #[error("emv: invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(Kernel),
    /// Miscellaneous invalid input (e.g. the batch `rows * cols` overflow check).
    #[error("emv: invalid input: {0}")]
    InvalidInput(&'static str),
}
169
170#[inline]
171pub fn emv(input: &EmvInput) -> Result<EmvOutput, EmvError> {
172 emv_with_kernel(input, Kernel::Auto)
173}
174
175pub fn emv_with_kernel(input: &EmvInput, kernel: Kernel) -> Result<EmvOutput, EmvError> {
176 let (high, low, _close, volume) = match &input.data {
177 EmvData::Candles { candles } => {
178 let high = source_type(candles, "high");
179 let low = source_type(candles, "low");
180 let close = source_type(candles, "close");
181 let volume = source_type(candles, "volume");
182 (high, low, close, volume)
183 }
184 EmvData::Slices {
185 high,
186 low,
187 close,
188 volume,
189 } => (*high, *low, *close, *volume),
190 };
191
192 if high.is_empty() || low.is_empty() || volume.is_empty() {
193 return Err(EmvError::EmptyInputData);
194 }
195 let len = high.len().min(low.len()).min(volume.len());
196 if len == 0 {
197 return Err(EmvError::EmptyInputData);
198 }
199
200 let first = (0..len).find(|&i| !(high[i].is_nan() || low[i].is_nan() || volume[i].is_nan()));
201 let first = match first {
202 Some(idx) => idx,
203 None => return Err(EmvError::AllValuesNaN),
204 };
205
206 let has_second = (first + 1..len)
207 .find(|&i| !(high[i].is_nan() || low[i].is_nan() || volume[i].is_nan()))
208 .is_some();
209 if !has_second {
210 return Err(EmvError::NotEnoughValidData {
211 needed: 2,
212 valid: 1,
213 });
214 }
215
216 let mut out = alloc_with_nan_prefix(len, first + 1);
217 let chosen = match kernel {
218 Kernel::Auto => Kernel::Scalar,
219 other => other,
220 };
221
222 unsafe {
223 match chosen {
224 Kernel::Scalar | Kernel::ScalarBatch => emv_scalar(high, low, volume, first, &mut out),
225 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
226 Kernel::Avx2 | Kernel::Avx2Batch => emv_avx2(high, low, volume, first, &mut out),
227 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
228 Kernel::Avx512 | Kernel::Avx512Batch => emv_avx512(high, low, volume, first, &mut out),
229 _ => unreachable!(),
230 }
231 }
232 Ok(EmvOutput { values: out })
233}
234
235#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
236#[inline]
237pub fn emv_into(input: &EmvInput, out: &mut [f64]) -> Result<(), EmvError> {
238 emv_into_slice(out, input, Kernel::Auto)
239}
240
241#[inline]
242pub fn emv_into_slice(dst: &mut [f64], input: &EmvInput, kern: Kernel) -> Result<(), EmvError> {
243 let (high, low, _close, volume) = match &input.data {
244 EmvData::Candles { candles } => {
245 let high = source_type(candles, "high");
246 let low = source_type(candles, "low");
247 let close = source_type(candles, "close");
248 let volume = source_type(candles, "volume");
249 (high, low, close, volume)
250 }
251 EmvData::Slices {
252 high,
253 low,
254 close,
255 volume,
256 } => (*high, *low, *close, *volume),
257 };
258
259 if high.is_empty() || low.is_empty() || volume.is_empty() {
260 return Err(EmvError::EmptyInputData);
261 }
262 let len = high.len().min(low.len()).min(volume.len());
263 if len == 0 {
264 return Err(EmvError::EmptyInputData);
265 }
266
267 if dst.len() != len {
268 return Err(EmvError::OutputLengthMismatch {
269 expected: len,
270 got: dst.len(),
271 });
272 }
273
274 let first = (0..len).find(|&i| !(high[i].is_nan() || low[i].is_nan() || volume[i].is_nan()));
275 let first = match first {
276 Some(idx) => idx,
277 None => return Err(EmvError::AllValuesNaN),
278 };
279
280 let has_second = (first + 1..len)
281 .find(|&i| !(high[i].is_nan() || low[i].is_nan() || volume[i].is_nan()))
282 .is_some();
283 if !has_second {
284 return Err(EmvError::NotEnoughValidData {
285 needed: 2,
286 valid: 1,
287 });
288 }
289
290 let warm = first + 1;
291 let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
292 for v in &mut dst[..warm] {
293 *v = qnan;
294 }
295
296 let chosen = match kern {
297 Kernel::Auto => Kernel::Scalar,
298 other => other,
299 };
300
301 unsafe {
302 match chosen {
303 Kernel::Scalar | Kernel::ScalarBatch => emv_scalar(high, low, volume, first, dst),
304 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
305 Kernel::Avx2 | Kernel::Avx2Batch => emv_avx2(high, low, volume, first, dst),
306 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
307 Kernel::Avx512 | Kernel::Avx512Batch => emv_avx512(high, low, volume, first, dst),
308 _ => unreachable!(),
309 }
310 }
311 Ok(())
312}
313
/// Scalar EMV kernel.
///
/// Writes `out[i]` for every `i` in `first + 1 .. len`, where `len` is the shortest of
/// `high`, `low` and `volume`. `out[..=first]` is left untouched; callers fill the
/// warm-up prefix. The midpoint of bar `first` seeds the "previous midpoint". For each
/// later bar:
///
/// - any NaN input → NaN output, previous midpoint kept;
/// - zero `high - low` range → NaN output, previous midpoint advanced;
/// - otherwise `(mid - prev_mid) / (volume / 10000 / (high - low))`, where
///   `mid = (high + low) / 2`.
///
/// # Panics
/// Panics if `first` is out of bounds for `high` or `low`. Also panics if there is at
/// least one bar to compute and `out` is shorter than `len`; writes are bounds-checked
/// rather than going through raw pointers, so a short `out` cannot cause out-of-bounds
/// writes.
#[inline]
pub fn emv_scalar(high: &[f64], low: &[f64], volume: &[f64], first: usize, out: &mut [f64]) {
    let len = high.len().min(low.len()).min(volume.len());
    let mut last_mid = 0.5 * (high[first] + low[first]);

    let start = first + 1;
    if start >= len {
        return;
    }

    // Slicing every series to the same range up front lets the loop run without
    // per-element bounds checks.
    let bars = high[start..len]
        .iter()
        .zip(&low[start..len])
        .zip(&volume[start..len]);
    for (dst, ((&h, &l), &v)) in out[start..len].iter_mut().zip(bars) {
        if h.is_nan() || l.is_nan() || v.is_nan() {
            *dst = f64::NAN;
            continue;
        }

        let current_mid = 0.5 * (h + l);
        let range = h - l;
        if range == 0.0 {
            *dst = f64::NAN;
        } else {
            // "Box ratio": volume (scaled down by 10_000) per unit of range.
            let br = v / 10000.0 / range;
            *dst = (current_mid - last_mid) / br;
        }
        last_mid = current_mid;
    }
}
355
/// AVX-512 entry point for EMV. It currently forwards to [`emv_avx2`], so it produces
/// the same values as the other kernels.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn emv_avx512(high: &[f64], low: &[f64], volume: &[f64], first: usize, out: &mut [f64]) {
    emv_avx2(high, low, volume, first, out)
}
361
/// AVX2 entry point for EMV.
///
/// It has the same contract as [`emv_scalar`]: it writes `out[first + 1..len]` using the
/// same formula and NaN / zero-range handling. It forwards to `emv_scalar` instead of
/// keeping a second, unchecked-pointer copy of the same loop, so the two kernels cannot
/// drift apart.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn emv_avx2(high: &[f64], low: &[f64], volume: &[f64], first: usize, out: &mut [f64]) {
    emv_scalar(high, low, volume, first, out)
}
400
/// Short-input AVX-512 entry point (by its name); currently forwards to [`emv_avx2`].
///
/// # Safety
/// This function is compiled with `#[target_feature(enable = "avx512f")]`, so the
/// caller must ensure the executing CPU supports AVX-512F. `out` should be at least as
/// long as the shortest of `high`, `low` and `volume`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
pub unsafe fn emv_avx512_short(
    high: &[f64],
    low: &[f64],
    volume: &[f64],
    first: usize,
    out: &mut [f64],
) {
    emv_avx2(high, low, volume, first, out)
}
412
/// Long-input AVX-512 entry point (by its name); currently forwards to [`emv_avx2`].
///
/// # Safety
/// This function is compiled with `#[target_feature(enable = "avx512f")]`, so the
/// caller must ensure the executing CPU supports AVX-512F. `out` should be at least as
/// long as the shortest of `high`, `low` and `volume`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
pub unsafe fn emv_avx512_long(
    high: &[f64],
    low: &[f64],
    volume: &[f64],
    first: usize,
    out: &mut [f64],
) {
    emv_avx2(high, low, volume, first, out)
}
424
425#[derive(Debug, Clone)]
426pub struct EmvStream {
427 last_mid: Option<f64>,
428}
429
430impl EmvStream {
431 pub fn try_new() -> Result<Self, EmvError> {
432 Ok(Self { last_mid: None })
433 }
434
435 #[inline(always)]
436 pub fn update(&mut self, high: f64, low: f64, volume: f64) -> Option<f64> {
437 if high.is_nan() || low.is_nan() || volume.is_nan() {
438 return None;
439 }
440 let current_mid = 0.5 * (high + low);
441 if self.last_mid.is_none() {
442 self.last_mid = Some(current_mid);
443 return None;
444 }
445 let last_mid = self.last_mid.unwrap();
446 let range = high - low;
447 if range == 0.0 {
448 self.last_mid = Some(current_mid);
449 return None;
450 }
451 let br = volume / 10000.0 / range;
452 let out = (current_mid - last_mid) / br;
453 self.last_mid = Some(current_mid);
454 Some(out)
455 }
456
457 #[inline(always)]
458 pub fn update_fast(&mut self, high: f64, low: f64, volume: f64) -> Option<f64> {
459 if high.is_nan() || low.is_nan() || volume.is_nan() {
460 return None;
461 }
462 let current_mid = 0.5 * (high + low);
463 if self.last_mid.is_none() {
464 self.last_mid = Some(current_mid);
465 return None;
466 }
467 let last_mid = self.last_mid.unwrap();
468 let range = high - low;
469 if range == 0.0 {
470 self.last_mid = Some(current_mid);
471 return None;
472 }
473
474 let inv_v = fast_recip_f64(volume);
475 let out = (current_mid - last_mid) * range * 10_000.0 * inv_v;
476 self.last_mid = Some(current_mid);
477 Some(out)
478 }
479}
480
/// One Newton–Raphson step towards `1 / x`: returns `y0 * (2 - x * y0)`.
///
/// Each step roughly doubles the number of correct bits in the estimate `y0`. Only the
/// AVX-512 path of `fast_recip_f64` calls this, so the function is unused in every other
/// build; the `cfg_attr` silences the resulting `dead_code` warning there.
#[cfg_attr(
    not(all(
        feature = "nightly-avx",
        target_arch = "x86_64",
        target_feature = "avx512f"
    )),
    allow(dead_code)
)]
#[inline(always)]
fn newton_refine_recip(y0: f64, x: f64) -> f64 {
    // A plain product rounds exactly like the former `x.mul_add(y0, 0.0)`, but cannot
    // lower to a slow software `fma` call on targets without hardware FMA.
    y0 * (2.0 - x * y0)
}
486
/// Approximate reciprocal `1 / x` (AVX-512 build).
///
/// Starts from the 14-bit `rcp14` estimate and applies two Newton–Raphson steps. The
/// result is close to, but not guaranteed bit-identical with, `1.0 / x`.
///
/// There are two cfg-gated definitions instead of an early `return` inside a cfg block,
/// because the early return made the scalar fallback an unreachable expression in this
/// configuration.
#[cfg(all(
    feature = "nightly-avx",
    target_arch = "x86_64",
    target_feature = "avx512f"
))]
#[inline(always)]
fn fast_recip_f64(x: f64) -> f64 {
    // SAFETY: this definition only exists when `avx512f` is statically enabled for the
    // compilation target, so the AVX-512 intrinsics are available wherever it runs.
    let y0 = unsafe {
        use core::arch::x86_64::*;
        let rcp = _mm512_rcp14_pd(_mm512_set1_pd(x));
        _mm_cvtsd_f64(_mm512_castpd512_pd128(rcp))
    };
    newton_refine_recip(newton_refine_recip(y0, x), x)
}

/// Reciprocal `1 / x` (portable build): a plain IEEE division.
#[cfg(not(all(
    feature = "nightly-avx",
    target_arch = "x86_64",
    target_feature = "avx512f"
)))]
#[inline(always)]
fn fast_recip_f64(x: f64) -> f64 {
    1.0 / x
}
506
/// Parameter sweep for batch EMV.
///
/// EMV has no parameters, so the range carries no fields and a batch always has a
/// single row. `Default` is derived; the equivalent hand-written impl tripped clippy's
/// `derivable_impls`.
#[derive(Clone, Debug, Default)]
pub struct EmvBatchRange {}
515
/// Builder for batch EMV computations.
#[derive(Clone, Debug, Default)]
pub struct EmvBatchBuilder {
    /// Kernel forwarded to `emv_batch_with_kernel`.
    kernel: Kernel,
    /// Parameter sweep; empty for EMV and not read by the builder's methods.
    _range: EmvBatchRange,
}
521
522impl EmvBatchBuilder {
523 pub fn new() -> Self {
524 Self::default()
525 }
526 pub fn kernel(mut self, k: Kernel) -> Self {
527 self.kernel = k;
528 self
529 }
530
531 pub fn apply_slices(
532 self,
533 high: &[f64],
534 low: &[f64],
535 close: &[f64],
536 volume: &[f64],
537 ) -> Result<EmvBatchOutput, EmvError> {
538 emv_batch_with_kernel(high, low, close, volume, self.kernel)
539 }
540
541 pub fn with_default_slices(
542 high: &[f64],
543 low: &[f64],
544 close: &[f64],
545 volume: &[f64],
546 k: Kernel,
547 ) -> Result<EmvBatchOutput, EmvError> {
548 EmvBatchBuilder::new()
549 .kernel(k)
550 .apply_slices(high, low, close, volume)
551 }
552
553 pub fn apply_candles(self, c: &Candles) -> Result<EmvBatchOutput, EmvError> {
554 let high = source_type(c, "high");
555 let low = source_type(c, "low");
556 let close = source_type(c, "close");
557 let volume = source_type(c, "volume");
558 self.apply_slices(high, low, close, volume)
559 }
560
561 pub fn with_default_candles(c: &Candles, k: Kernel) -> Result<EmvBatchOutput, EmvError> {
562 EmvBatchBuilder::new().kernel(k).apply_candles(c)
563 }
564}
565
566pub fn emv_batch_with_kernel(
567 high: &[f64],
568 low: &[f64],
569 _close: &[f64],
570 volume: &[f64],
571 kernel: Kernel,
572) -> Result<EmvBatchOutput, EmvError> {
573 let simd = match kernel {
574 Kernel::Auto => detect_best_batch_kernel(),
575 other if other.is_batch() => other,
576 other => return Err(EmvError::InvalidKernelForBatch(other)),
577 };
578 emv_batch_par_slice(high, low, volume, simd)
579}
580
/// Result of a batch EMV computation.
#[derive(Clone, Debug)]
pub struct EmvBatchOutput {
    /// Flattened output values; EMV batches have one row of `cols` values.
    pub values: Vec<f64>,
    /// One (empty) parameter set per row.
    pub combos: Vec<EmvParams>,
    /// Number of rows (always 1 for EMV in this file).
    pub rows: usize,
    /// Number of columns: the shortest of high/low/volume.
    pub cols: usize,
}
588
589impl EmvBatchOutput {
590 #[inline]
591 pub fn single_row(&self) -> &[f64] {
592 debug_assert_eq!(self.rows, 1);
593 &self.values[..self.cols]
594 }
595}
596
597#[inline(always)]
598fn expand_grid(_r: &EmvBatchRange) -> Vec<()> {
599 vec![()]
600}
601
602#[inline(always)]
603pub fn emv_batch_slice(
604 high: &[f64],
605 low: &[f64],
606 volume: &[f64],
607 kern: Kernel,
608) -> Result<EmvBatchOutput, EmvError> {
609 emv_batch_inner(high, low, volume, kern, false)
610}
611
612#[inline(always)]
613pub fn emv_batch_par_slice(
614 high: &[f64],
615 low: &[f64],
616 volume: &[f64],
617 kern: Kernel,
618) -> Result<EmvBatchOutput, EmvError> {
619 emv_batch_inner(high, low, volume, kern, true)
620}
621
622fn emv_batch_inner(
623 high: &[f64],
624 low: &[f64],
625 volume: &[f64],
626 kern: Kernel,
627 _parallel: bool,
628) -> Result<EmvBatchOutput, EmvError> {
629 let len = high.len().min(low.len()).min(volume.len());
630 if len == 0 {
631 return Err(EmvError::EmptyInputData);
632 }
633
634 let first = (0..len)
635 .find(|&i| !(high[i].is_nan() || low[i].is_nan() || volume[i].is_nan()))
636 .ok_or(EmvError::AllValuesNaN)?;
637
638 let valid = (first..len)
639 .filter(|&i| !(high[i].is_nan() || low[i].is_nan() || volume[i].is_nan()))
640 .count();
641 if valid < 2 {
642 return Err(EmvError::NotEnoughValidData { needed: 2, valid });
643 }
644
645 let rows = 1usize;
646 let cols = len;
647 let _ = rows
648 .checked_mul(cols)
649 .ok_or(EmvError::InvalidInput("rows*cols overflow"))?;
650
651 let mut buf_mu = make_uninit_matrix(rows, cols);
652 init_matrix_prefixes(&mut buf_mu, cols, &[first + 1]);
653
654 let mut guard = core::mem::ManuallyDrop::new(buf_mu);
655 let out: &mut [f64] =
656 unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len()) };
657
658 unsafe {
659 match kern {
660 Kernel::Scalar | Kernel::ScalarBatch => emv_scalar(high, low, volume, first, out),
661 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
662 Kernel::Avx2 | Kernel::Avx2Batch => emv_avx2(high, low, volume, first, out),
663 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
664 Kernel::Avx512 | Kernel::Avx512Batch => emv_avx512(high, low, volume, first, out),
665 _ => emv_scalar(high, low, volume, first, out),
666 }
667 }
668
669 let values = unsafe {
670 Vec::from_raw_parts(
671 guard.as_mut_ptr() as *mut f64,
672 guard.len(),
673 guard.capacity(),
674 )
675 };
676
677 Ok(EmvBatchOutput {
678 values,
679 combos: vec![EmvParams],
680 rows,
681 cols,
682 })
683}
684
685#[inline(always)]
686pub fn emv_row_scalar(
687 high: &[f64],
688 low: &[f64],
689 volume: &[f64],
690 first: usize,
691 _stride: usize,
692 _w_ptr: *const f64,
693 _inv_n: f64,
694 out: &mut [f64],
695) {
696 emv_scalar(high, low, volume, first, out);
697}
698
/// Row-kernel adapter for the AVX2 slot; it runs the scalar kernel.
///
/// `_stride`, `_w_ptr` and `_inv_n` are unused, and `_w_ptr` is never dereferenced.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn emv_row_avx2(
    high: &[f64],
    low: &[f64],
    volume: &[f64],
    first: usize,
    _stride: usize,
    _w_ptr: *const f64,
    _inv_n: f64,
    out: &mut [f64],
) {
    emv_scalar(high, low, volume, first, out)
}
713
/// Row-kernel adapter for the AVX-512 slot; it forwards to `emv_avx512`.
///
/// `_stride`, `_w_ptr` and `_inv_n` are unused, and `_w_ptr` is never dereferenced.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn emv_row_avx512(
    high: &[f64],
    low: &[f64],
    volume: &[f64],
    first: usize,
    _stride: usize,
    _w_ptr: *const f64,
    _inv_n: f64,
    out: &mut [f64],
) {
    emv_avx512(high, low, volume, first, out)
}
728
729#[inline(always)]
730fn expand_grid_emv(_r: &EmvBatchRange) -> Vec<()> {
731 vec![()]
732}
733
/// Python binding `emv(high, low, close, volume, kernel=None)`, returning a 1-D float64
/// array.
///
/// `close` is accepted but not used by the computation. The EMV calculation runs with
/// the GIL released. Computation errors are raised as `ValueError`; errors from input
/// conversion and kernel validation propagate unchanged.
///
/// The function-local `use numpy::{IntoPyArray, PyArrayMethods}` was removed: the
/// file-level import already provides both names, and `PyArrayMethods` was unused here.
#[cfg(feature = "python")]
#[pyfunction(name = "emv")]
#[pyo3(signature = (high, low, close, volume, kernel=None))]
pub fn emv_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    volume: PyReadonlyArray1<'py, f64>,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let high_slice = high.as_slice()?;
    let low_slice = low.as_slice()?;
    let close_slice = close.as_slice()?;
    let volume_slice = volume.as_slice()?;
    let kern = validate_kernel(kernel, false)?;

    let input = EmvInput::from_slices(high_slice, low_slice, close_slice, volume_slice);

    let result_vec: Vec<f64> = py
        .allow_threads(|| emv_with_kernel(&input, kern).map(|o| o.values))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok(result_vec.into_pyarray(py))
}
770
/// Python-visible `EmvStream` class wrapping the Rust [`EmvStream`].
#[cfg(feature = "python")]
#[pyclass(name = "EmvStream")]
pub struct EmvStreamPy {
    /// The wrapped incremental calculator.
    stream: EmvStream,
}
776
#[cfg(feature = "python")]
#[pymethods]
impl EmvStreamPy {
    /// Creates an empty stream.
    #[new]
    fn new() -> PyResult<Self> {
        let stream = EmvStream::try_new().map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok(EmvStreamPy { stream })
    }

    /// Feeds one bar and returns its EMV value, or `None` when no value is produced
    /// (warm-up, NaN input, or zero high-low range).
    ///
    /// `close` is part of the Python signature but is not used by EMV.
    fn update(&mut self, high: f64, low: f64, close: f64, volume: f64) -> Option<f64> {
        // Discard `close` explicitly: renaming it to `_close` would change the Python
        // keyword argument name, since PyO3 derives it from the Rust parameter name.
        let _ = close;
        self.stream.update(high, low, volume)
    }
}
790
/// Batch EMV written into a caller-provided buffer (used by the Python binding).
///
/// `out` must be exactly as long as the shortest of high/low/volume. It may be
/// uninitialised memory; `emv_batch_py` passes a freshly created, uninitialised NumPy
/// array. `_close`, `_range` and `_parallel` are unused.
///
/// On success, every element of `out` is initialised:
/// - the prefix comes from `init_matrix_prefixes`;
/// - the kernel writes the rest.
///
/// # Errors
/// - [`EmvError::EmptyInputData`] or [`EmvError::OutputLengthMismatch`] on bad shapes.
/// - [`EmvError::AllValuesNaN`] / [`EmvError::NotEnoughValidData`] on insufficient valid
///   data.
#[cfg(feature = "python")]
fn emv_batch_inner_into(
    high: &[f64],
    low: &[f64],
    _close: &[f64],
    volume: &[f64],
    _range: &EmvBatchRange,
    kern: Kernel,
    _parallel: bool,
    out: &mut [f64],
) -> Result<Vec<EmvParams>, EmvError> {
    let len = high.len().min(low.len()).min(volume.len());
    if len == 0 {
        return Err(EmvError::EmptyInputData);
    }
    if out.len() != len {
        return Err(EmvError::OutputLengthMismatch {
            expected: len,
            got: out.len(),
        });
    }

    let is_valid = |i: usize| !(high[i].is_nan() || low[i].is_nan() || volume[i].is_nan());
    let first = (0..len)
        .find(|&i| is_valid(i))
        .ok_or(EmvError::AllValuesNaN)?;
    let valid = (first..len).filter(|&i| is_valid(i)).count();
    if valid < 2 {
        return Err(EmvError::NotEnoughValidData { needed: 2, valid });
    }

    // SAFETY: same pointer and length as `out`. `MaybeUninit<f64>` has the same layout
    // as `f64`, and viewing the buffer as possibly-uninitialised is always permitted.
    let out_mu: &mut [MaybeUninit<f64>] = unsafe {
        core::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
    };
    init_matrix_prefixes(out_mu, len, &[first + 1]);

    // SAFETY: same pointer and length again. The prefix is now initialised, and the
    // kernel below writes every index in `first + 1..len`.
    let out_f: &mut [f64] =
        unsafe { core::slice::from_raw_parts_mut(out_mu.as_mut_ptr() as *mut f64, out_mu.len()) };

    // The kernels are safe functions, so the dispatch needs no `unsafe` block.
    match kern {
        Kernel::Scalar | Kernel::ScalarBatch => emv_scalar(high, low, volume, first, out_f),
        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
        Kernel::Avx2 | Kernel::Avx2Batch => emv_avx2(high, low, volume, first, out_f),
        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
        Kernel::Avx512 | Kernel::Avx512Batch => emv_avx512(high, low, volume, first, out_f),
        // `Auto`, and any kernel not compiled into this build, use the scalar path.
        _ => emv_scalar(high, low, volume, first, out_f),
    }

    Ok(vec![EmvParams])
}
847
/// Python binding `emv_batch(high, low, close, volume, kernel=None)`.
///
/// Returns a dict whose `"values"` entry is a `(rows, cols)` float64 array; EMV batches
/// have a single row. The computation runs with the GIL released, and `Kernel::Auto` is
/// resolved with `detect_best_batch_kernel`. Computation errors are raised as
/// `ValueError`.
///
/// The function-local `use numpy::{...}` was removed: the file-level import already
/// provides `PyArray1` and `PyArrayMethods`, and `IntoPyArray` was unused here.
#[cfg(feature = "python")]
#[pyfunction(name = "emv_batch")]
#[pyo3(signature = (high, low, close, volume, kernel=None))]
pub fn emv_batch_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    volume: PyReadonlyArray1<'py, f64>,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let high_slice = high.as_slice()?;
    let low_slice = low.as_slice()?;
    let close_slice = close.as_slice()?;
    let volume_slice = volume.as_slice()?;
    let kern = validate_kernel(kernel, true)?;

    let sweep = EmvBatchRange {};
    let combos = expand_grid(&sweep);
    let rows = combos.len();
    let cols = high_slice
        .len()
        .min(low_slice.len())
        .min(volume_slice.len());

    // SAFETY: the array is created uninitialised; `emv_batch_inner_into` initialises
    // every element before returning `Ok`, and on error the array is discarded.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [rows * cols], false) };
    // SAFETY: `out_arr` was just created here, so no other view of its data exists.
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let _params = py
        .allow_threads(|| {
            let kernel = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k => k,
            };
            emv_batch_inner_into(
                high_slice,
                low_slice,
                close_slice,
                volume_slice,
                &sweep,
                kernel,
                true,
                slice_out,
            )
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;

    Ok(dict)
}
902
/// WASM binding: computes EMV and returns it as a new array.
///
/// The output length is the shortest of `high`, `low` and `volume`. `close` is accepted
/// but not read, so its length does not affect the result.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn emv_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    volume: &[f64],
) -> Result<Vec<f64>, JsValue> {
    let input = EmvInput::from_slices(high, low, close, volume);

    // Must equal the length `emv_into_slice` expects, which ignores `close`. Including
    // `close.len()` here made every call with a shorter `close` fail with an
    // output-length mismatch.
    let len = high.len().min(low.len()).min(volume.len());
    let mut output = vec![0.0; len];

    emv_into_slice(&mut output, &input, Kernel::Auto)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    Ok(output)
}
920
/// WASM binding: computes EMV from raw input pointers into `out_ptr`.
///
/// All five pointers must be non-null and valid for `len` `f64`s (for example, buffers
/// from `emv_alloc(len)`). If `out_ptr` equals one of the input pointers (in-place use),
/// the result is first computed into a scratch buffer and then copied out.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn emv_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    volume_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
) -> Result<(), JsValue> {
    let inputs = [high_ptr, low_ptr, close_ptr, volume_ptr];
    if inputs.iter().any(|p| p.is_null()) || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to emv_into"));
    }

    // SAFETY: all pointers are non-null (checked above), and the JS caller guarantees
    // each one is valid for `len` elements.
    unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);
        let volume = std::slice::from_raw_parts(volume_ptr, len);

        let input = EmvInput::from_slices(high, low, close, volume);
        let to_js = |e: EmvError| JsValue::from_str(&e.to_string());

        if inputs.contains(&(out_ptr as *const f64)) {
            let mut scratch = vec![0.0; len];
            emv_into_slice(&mut scratch, &input, Kernel::Auto).map_err(to_js)?;
            std::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&scratch);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            emv_into_slice(out, &input, Kernel::Auto).map_err(to_js)?;
        }
    }

    Ok(())
}
967
/// WASM helper: allocates an uninitialised buffer with room for `len` `f64`s and leaks
/// it to the caller. Release it with `emv_free(ptr, len)`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn emv_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` hands ownership of the allocation to the caller.
    let mut buf = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buf.as_mut_ptr()
}
976
/// WASM helper: frees a buffer obtained from `emv_alloc` with the same `len`. A null
/// pointer is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn emv_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr`/`len` must come from `emv_alloc(len)` (a leaked
    // `Vec::<f64>::with_capacity(len)`) and must not have been freed already.
    // The length is 0 because the buffer may never have been written, and
    // `Vec::from_raw_parts` requires its first `length` elements to be initialised.
    // `f64` has no destructor, so nothing is lost by not dropping elements.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
986
/// Batch configuration for the WASM API. EMV has no parameters, so it is empty.
/// NOTE(review): presumably the shape of the `config` argument of `emv_batch`, which
/// currently ignores that argument.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct EmvBatchConfig {}
990
/// Serialisable batch result returned to JavaScript by `emv_batch`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct EmvBatchJsOutput {
    /// Flattened values (`rows * cols`).
    pub values: Vec<f64>,
    /// One (empty) parameter set per row.
    pub combos: Vec<EmvParams>,
    /// Number of rows (1 for EMV).
    pub rows: usize,
    /// Number of columns (bars).
    pub cols: usize,
}
999
/// WASM binding `emv_batch`: computes the single EMV row and returns it serialised as
/// `{ values, combos, rows, cols }`.
///
/// The column count is the shortest of `high`, `low` and `volume`; `close` is not read,
/// and `_config` is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = emv_batch)]
pub fn emv_batch_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    volume: &[f64],
    _config: JsValue,
) -> Result<JsValue, JsValue> {
    let input = EmvInput::from_slices(high, low, close, volume);
    // Must equal the length `emv_into_slice` expects, which ignores `close`. Including
    // `close.len()` here made every call with a shorter `close` fail with an
    // output-length mismatch.
    let len = high.len().min(low.len()).min(volume.len());

    let mut output = vec![0.0; len];

    let kernel = detect_best_kernel();

    emv_into_slice(&mut output, &input, kernel).map_err(|e| JsValue::from_str(&e.to_string()))?;

    let js_output = EmvBatchJsOutput {
        values: output,
        combos: vec![EmvParams],
        rows: 1,
        cols: len,
    };

    serde_wasm_bindgen::to_value(&js_output)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
1028
/// WASM binding: batch EMV from raw pointers into `out_ptr`. Returns the number of rows
/// written, which is always 1.
///
/// All five pointers must be non-null and valid for `len` `f64`s. As in `emv_into`, an
/// `out_ptr` equal to one of the input pointers (in-place use) is handled with a scratch
/// buffer. Previously that case created overlapping shared and mutable slices.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn emv_batch_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    volume_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
) -> Result<usize, JsValue> {
    let inputs = [high_ptr, low_ptr, close_ptr, volume_ptr];
    if inputs.iter().any(|p| p.is_null()) || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to emv_batch_into"));
    }

    // SAFETY: all pointers are non-null (checked above), and the JS caller guarantees
    // each one is valid for `len` elements. A mutable view of `out_ptr` is only created
    // while no shared view of the same memory is in use.
    unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);
        let volume = std::slice::from_raw_parts(volume_ptr, len);

        let input = EmvInput::from_slices(high, low, close, volume);

        let kernel = detect_best_kernel();
        let to_js = |e: EmvError| JsValue::from_str(&e.to_string());

        if inputs.contains(&(out_ptr as *const f64)) {
            let mut scratch = vec![0.0; len];
            emv_into_slice(&mut scratch, &input, kernel).map_err(to_js)?;
            std::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&scratch);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            emv_into_slice(out, &input, kernel).map_err(to_js)?;
        }

        Ok(1)
    }
}
1064
1065#[cfg(test)]
1066mod tests {
1067 use super::*;
1068 use crate::skip_if_unsupported;
1069 use crate::utilities::data_loader::read_candles_from_csv;
1070
1071 fn check_emv_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1072 skip_if_unsupported!(kernel, test_name);
1073 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1074 let candles = read_candles_from_csv(file_path)?;
1075 let input = EmvInput::from_candles(&candles);
1076 let output = emv_with_kernel(&input, kernel)?;
1077 assert_eq!(output.values.len(), candles.close.len());
1078 let expected_last_five_emv = [
1079 -6488905.579799851,
1080 2371436.7401001123,
1081 -3855069.958128531,
1082 1051939.877943717,
1083 -8519287.22257077,
1084 ];
1085 let start = output.values.len().saturating_sub(5);
1086 for (i, &val) in output.values[start..].iter().enumerate() {
1087 let diff = (val - expected_last_five_emv[i]).abs();
1088 let tol = expected_last_five_emv[i].abs() * 0.0001;
1089 assert!(
1090 diff <= tol,
1091 "[{}] EMV {:?} mismatch at idx {}: got {}, expected {}, diff={}",
1092 test_name,
1093 kernel,
1094 i,
1095 val,
1096 expected_last_five_emv[i],
1097 diff
1098 );
1099 }
1100 Ok(())
1101 }
1102
1103 fn check_emv_with_default_candles(
1104 test_name: &str,
1105 kernel: Kernel,
1106 ) -> Result<(), Box<dyn Error>> {
1107 skip_if_unsupported!(kernel, test_name);
1108 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1109 let candles = read_candles_from_csv(file_path)?;
1110 let input = EmvInput::with_default_candles(&candles);
1111 let output = emv_with_kernel(&input, kernel)?;
1112 assert_eq!(output.values.len(), candles.close.len());
1113 Ok(())
1114 }
1115
1116 fn check_emv_empty_data(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1117 skip_if_unsupported!(kernel, test_name);
1118 let empty: [f64; 0] = [];
1119 let input = EmvInput::from_slices(&empty, &empty, &empty, &empty);
1120 let result = emv_with_kernel(&input, kernel);
1121 assert!(result.is_err());
1122 Ok(())
1123 }
1124
1125 fn check_emv_all_nan(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1126 skip_if_unsupported!(kernel, test_name);
1127 let nan_arr = [f64::NAN, f64::NAN];
1128 let input = EmvInput::from_slices(&nan_arr, &nan_arr, &nan_arr, &nan_arr);
1129 let result = emv_with_kernel(&input, kernel);
1130 assert!(result.is_err());
1131 Ok(())
1132 }
1133
1134 fn check_emv_not_enough_data(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1135 skip_if_unsupported!(kernel, test_name);
1136 let high = [10000.0, f64::NAN];
1137 let low = [9990.0, f64::NAN];
1138 let close = [9995.0, f64::NAN];
1139 let volume = [1_000_000.0, f64::NAN];
1140 let input = EmvInput::from_slices(&high, &low, &close, &volume);
1141 let result = emv_with_kernel(&input, kernel);
1142 assert!(result.is_err());
1143 Ok(())
1144 }
1145
1146 fn check_emv_basic_calculation(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1147 skip_if_unsupported!(kernel, test_name);
1148 let high = [10.0, 12.0, 13.0, 15.0];
1149 let low = [5.0, 7.0, 8.0, 10.0];
1150 let close = [7.5, 9.0, 10.5, 12.5];
1151 let volume = [10000.0, 20000.0, 25000.0, 30000.0];
1152 let input = EmvInput::from_slices(&high, &low, &close, &volume);
1153 let output = emv_with_kernel(&input, kernel)?;
1154 assert_eq!(output.values.len(), 4);
1155 assert!(output.values[0].is_nan());
1156 for &val in &output.values[1..] {
1157 assert!(!val.is_nan());
1158 }
1159 Ok(())
1160 }
1161
1162 fn check_emv_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1163 skip_if_unsupported!(kernel, test_name);
1164 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1165 let candles = read_candles_from_csv(file_path)?;
1166 let high = source_type(&candles, "high");
1167 let low = source_type(&candles, "low");
1168 let volume = source_type(&candles, "volume");
1169
1170 let output = emv_with_kernel(&EmvInput::from_candles(&candles), kernel)?.values;
1171
1172 let mut stream = EmvStream::try_new()?;
1173 let mut stream_values = Vec::with_capacity(high.len());
1174 for i in 0..high.len() {
1175 match stream.update(high[i], low[i], volume[i]) {
1176 Some(val) => stream_values.push(val),
1177 None => stream_values.push(f64::NAN),
1178 }
1179 }
1180 assert_eq!(output.len(), stream_values.len());
1181 for (b, s) in output.iter().zip(stream_values.iter()) {
1182 if b.is_nan() && s.is_nan() {
1183 continue;
1184 }
1185 let diff = (b - s).abs();
1186 assert!(
1187 diff < 1e-9,
1188 "[{}] EMV streaming f64 mismatch: batch={}, stream={}, diff={}",
1189 test_name,
1190 b,
1191 s,
1192 diff
1193 );
1194 }
1195 Ok(())
1196 }
1197
1198 #[cfg(debug_assertions)]
1199 fn check_emv_no_poison(
1200 test_name: &str,
1201 kernel: Kernel,
1202 ) -> Result<(), Box<dyn std::error::Error>> {
1203 skip_if_unsupported!(kernel, test_name);
1204
1205 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1206 let candles = read_candles_from_csv(file_path)?;
1207
1208 let input1 = EmvInput::from_candles(&candles);
1209 let output1 = emv_with_kernel(&input1, kernel)?;
1210
1211 let high = source_type(&candles, "high");
1212 let low = source_type(&candles, "low");
1213 let close = source_type(&candles, "close");
1214 let volume = source_type(&candles, "volume");
1215 let input2 = EmvInput::from_slices(high, low, close, volume);
1216 let output2 = emv_with_kernel(&input2, kernel)?;
1217
1218 let input3 = EmvInput::with_default_candles(&candles);
1219 let output3 = emv_with_kernel(&input3, kernel)?;
1220
1221 let outputs = [
1222 ("from_candles", &output1.values),
1223 ("from_slices", &output2.values),
1224 ("with_default_candles", &output3.values),
1225 ];
1226
1227 for (method_name, values) in &outputs {
1228 for (i, &val) in values.iter().enumerate() {
1229 if val.is_nan() {
1230 continue;
1231 }
1232
1233 let bits = val.to_bits();
1234
1235 if bits == 0x11111111_11111111 {
1236 panic!(
1237 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1238 using method: {}",
1239 test_name, val, bits, i, method_name
1240 );
1241 }
1242
1243 if bits == 0x22222222_22222222 {
1244 panic!(
1245 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1246 using method: {}",
1247 test_name, val, bits, i, method_name
1248 );
1249 }
1250
1251 if bits == 0x33333333_33333333 {
1252 panic!(
1253 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1254 using method: {}",
1255 test_name, val, bits, i, method_name
1256 );
1257 }
1258 }
1259 }
1260
1261 Ok(())
1262 }
1263
1264 #[cfg(not(debug_assertions))]
1265 fn check_emv_no_poison(
1266 _test_name: &str,
1267 _kernel: Kernel,
1268 ) -> Result<(), Box<dyn std::error::Error>> {
1269 Ok(())
1270 }
1271
    /// Property test over randomly generated bars.
    ///
    /// For each generated case the output of `kernel` must:
    /// - be NaN at index 0 (warmup);
    /// - be finite wherever inputs are finite and `high != low`;
    /// - match the `Kernel::Scalar` output: bit-identical for non-finite values, within
    ///   3 ULPs otherwise;
    /// - match `(mid[i] - mid[i-1]) / (volume[i] / 10000 / (high[i] - low[i]))` within 1e-9;
    /// - stay within a loose magnitude bound relative to the midpoint change;
    /// - be ~0 when highs and lows are constant;
    /// - never contain one of the debug poison bit patterns.
    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_emv_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        // Each bar is drawn as (high, low/high ratio, volume), with 2..400 bars per case.
        // The ratio is < 1, so every generated bar has high > low (a non-zero range).
        let strat = prop::collection::vec(
            (10.0f64..100000.0f64, 0.5f64..0.999f64, 1000.0f64..1e9f64),
            2..400,
        )
        .prop_map(|data| {
            let high: Vec<f64> = data.iter().map(|(h, _, _)| *h).collect();
            let low: Vec<f64> = data
                .iter()
                .zip(&high)
                .map(|((_, l_pct, _), h)| h * l_pct)
                .collect();
            let volume: Vec<f64> = data.iter().map(|(_, _, v)| *v).collect();

            // None of the checks below read `close`; it just mirrors `high` to fill the input.
            let close = high.clone();
            (high, low, close, volume)
        });

        proptest::test_runner::TestRunner::default()
            .run(&strat, |(high, low, close, volume)| {
                let input = EmvInput::from_slices(&high, &low, &close, &volume);

                // Output of the kernel under test.
                let EmvOutput { values: out } = emv_with_kernel(&input, kernel).unwrap();

                // Scalar output used as the cross-kernel reference.
                let EmvOutput { values: ref_out } =
                    emv_with_kernel(&input, Kernel::Scalar).unwrap();

                prop_assert!(
                    out[0].is_nan(),
                    "First EMV value should always be NaN (warmup period)"
                );

                // Finite inputs with a non-zero range must yield a finite EMV.
                for i in 1..out.len() {
                    if high[i].is_finite() && low[i].is_finite() && volume[i].is_finite() {
                        let range = high[i] - low[i];
                        if range != 0.0 {
                            prop_assert!(
                                out[i].is_finite(),
                                "EMV at index {} should be finite when inputs are finite and range != 0",
                                i
                            );
                        }
                    }
                }

                // Kernel vs scalar reference: non-finite values must match bit-for-bit.
                // Finite values may be at most 3 apart when measured as the distance
                // between their raw bit patterns (i.e. ULPs for same-sign values).
                for i in 0..out.len() {
                    let y = out[i];
                    let r = ref_out[i];

                    if !y.is_finite() || !r.is_finite() {
                        prop_assert!(
                            y.to_bits() == r.to_bits(),
                            "Non-finite mismatch at index {}: {} vs {}",
                            i,
                            y,
                            r
                        );
                    } else {
                        let y_bits = y.to_bits();
                        let r_bits = r.to_bits();
                        let ulp_diff = y_bits.abs_diff(r_bits);

                        prop_assert!(
                            ulp_diff <= 3,
                            "ULP difference too large at index {}: {} vs {} (ULP={})",
                            i,
                            y,
                            r,
                            ulp_diff
                        );
                    }
                }

                // Recompute EMV straight from the formula, carrying the previous bar's
                // midpoint, and compare within an absolute tolerance of 1e-9.
                let mut last_mid = 0.5 * (high[0] + low[0]);
                for i in 1..out.len() {
                    let current_mid = 0.5 * (high[i] + low[i]);
                    let range = high[i] - low[i];

                    if range == 0.0 {
                        // A zero range makes the box ratio degenerate; NaN is expected.
                        prop_assert!(
                            out[i].is_nan(),
                            "EMV at index {} should be NaN when range is zero",
                            i
                        );
                    } else {
                        let expected_emv = (current_mid - last_mid) / (volume[i] / 10000.0 / range);

                        if out[i].is_finite() && expected_emv.is_finite() {
                            let diff = (out[i] - expected_emv).abs();
                            let tolerance = 1e-9;
                            prop_assert!(
                                diff <= tolerance,
                                "EMV formula mismatch at index {}: got {}, expected {}, diff={}",
                                i,
                                out[i],
                                expected_emv,
                                diff
                            );
                        }
                    }

                    last_mid = current_mid;
                }

                // Loose sanity bound. EMV = dMid * range * 10000 / volume. Generated
                // ranges are below 5e4 and volumes at least 1e3, so that multiplier stays
                // far below the 1e8 allowed here.
                for i in 1..out.len() {
                    if out[i].is_finite() {
                        let price_change =
                            (high[i] + low[i]) / 2.0 - (high[i - 1] + low[i - 1]) / 2.0;
                        let max_reasonable = price_change.abs() * 1e8;

                        prop_assert!(
                            out[i].abs() <= max_reasonable,
                            "EMV at index {} seems unreasonably large: {} (price change: {})",
                            i,
                            out[i],
                            price_change
                        );
                    }
                }

                // Constant highs and lows give a zero midpoint change, hence EMV ~ 0.
                // NOTE(review): with continuous random ranges this branch is probably rarely hit.
                if high.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10)
                    && low.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10)
                    && high.iter().zip(&low).all(|(h, l)| h > l)
                {
                    for i in 1..out.len() {
                        if out[i].is_finite() {
                            prop_assert!(
                                out[i].abs() < 1e-9,
                                "EMV should be ~0 for constant prices, got {} at index {}",
                                out[i],
                                i
                            );
                        }
                    }
                }

                // None of the debug poison bit patterns may appear in the output.
                for (i, &val) in out.iter().enumerate() {
                    if !val.is_nan() {
                        let bits = val.to_bits();
                        prop_assert!(
                            bits != 0x11111111_11111111
                                && bits != 0x22222222_22222222
                                && bits != 0x33333333_33333333,
                            "Found poison value at index {}: {} (0x{:016X})",
                            i,
                            val,
                            bits
                        );
                    }
                }

                Ok(())
            })
            // A failing case surfaces as a panic carrying proptest's failure report.
            .unwrap();

        Ok(())
    }
1437
1438 macro_rules! generate_all_emv_tests {
1439 ($($test_fn:ident),*) => {
1440 paste::paste! {
1441 $(
1442 #[test]
1443 fn [<$test_fn _scalar_f64>]() {
1444 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1445 }
1446 )*
1447 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1448 $(
1449 #[test]
1450 fn [<$test_fn _avx2_f64>]() {
1451 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1452 }
1453 #[test]
1454 fn [<$test_fn _avx512_f64>]() {
1455 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1456 }
1457 )*
1458 }
1459 }
1460 }
1461
    // Instantiate the per-kernel #[test] wrappers for every deterministic EMV check.
    generate_all_emv_tests!(
        check_emv_accuracy,
        check_emv_with_default_candles,
        check_emv_empty_data,
        check_emv_all_nan,
        check_emv_not_enough_data,
        check_emv_basic_calculation,
        check_emv_streaming,
        check_emv_no_poison
    );

    // The property test is compiled only when the optional `proptest` feature is enabled.
    #[cfg(feature = "proptest")]
    generate_all_emv_tests!(check_emv_property);
1475
1476 fn check_batch_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1477 skip_if_unsupported!(kernel, test);
1478 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1479 let c = read_candles_from_csv(file)?;
1480 let output = EmvBatchBuilder::new().kernel(kernel).apply_candles(&c)?;
1481 assert_eq!(output.values.len(), c.close.len());
1482 Ok(())
1483 }
1484
1485 macro_rules! gen_batch_tests {
1486 ($fn_name:ident) => {
1487 paste::paste! {
1488 #[test] fn [<$fn_name _scalar>]() {
1489 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
1490 }
1491 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1492 #[test] fn [<$fn_name _avx2>]() {
1493 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
1494 }
1495 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1496 #[test] fn [<$fn_name _avx512>]() {
1497 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
1498 }
1499 #[test] fn [<$fn_name _auto_detect>]() {
1500 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
1501 }
1502 }
1503 };
1504 }
    // Instantiate the batch-kernel #[test] wrappers for both batch checks.
    gen_batch_tests!(check_batch_row);
    gen_batch_tests!(check_batch_no_poison);
1507
1508 #[cfg(debug_assertions)]
1509 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn std::error::Error>> {
1510 skip_if_unsupported!(kernel, test);
1511
1512 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1513 let c = read_candles_from_csv(file)?;
1514
1515 let output = EmvBatchBuilder::new().kernel(kernel).apply_candles(&c)?;
1516
1517 for (idx, &val) in output.values.iter().enumerate() {
1518 if val.is_nan() {
1519 continue;
1520 }
1521
1522 let bits = val.to_bits();
1523 let row = idx / output.cols;
1524 let col = idx % output.cols;
1525
1526 if bits == 0x11111111_11111111 {
1527 panic!(
1528 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
1529 at row {} col {} (flat index {})",
1530 test, val, bits, row, col, idx
1531 );
1532 }
1533
1534 if bits == 0x22222222_22222222 {
1535 panic!(
1536 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) \
1537 at row {} col {} (flat index {})",
1538 test, val, bits, row, col, idx
1539 );
1540 }
1541
1542 if bits == 0x33333333_33333333 {
1543 panic!(
1544 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) \
1545 at row {} col {} (flat index {})",
1546 test, val, bits, row, col, idx
1547 );
1548 }
1549 }
1550
1551 Ok(())
1552 }
1553
1554 #[cfg(not(debug_assertions))]
1555 fn check_batch_no_poison(
1556 _test: &str,
1557 _kernel: Kernel,
1558 ) -> Result<(), Box<dyn std::error::Error>> {
1559 Ok(())
1560 }
1561
1562 #[test]
1563 fn test_emv_into_matches_api() -> Result<(), Box<dyn Error>> {
1564 let n = 256usize;
1565 let mut high = Vec::with_capacity(n);
1566 let mut low = Vec::with_capacity(n);
1567 let mut close = Vec::with_capacity(n);
1568 let mut volume = Vec::with_capacity(n);
1569 for i in 0..n {
1570 let base = 100.0 + (i as f64) * 0.1;
1571 let spread = 1.0 + ((i % 5) as f64) * 0.2;
1572 let h = base + spread * 0.6;
1573 let l = base - spread * 0.4;
1574 high.push(h);
1575 low.push(l);
1576 close.push(0.5 * (h + l));
1577 volume.push(10_000.0 + ((i * 37) % 1000) as f64 * 100.0);
1578 }
1579
1580 let input = EmvInput::from_slices(&high, &low, &close, &volume);
1581 let baseline = emv(&input)?.values;
1582
1583 let mut into_out = vec![0.0; baseline.len()];
1584
1585 #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1586 {
1587 emv_into(&input, &mut into_out)?;
1588 }
1589 #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1590 {
1591 emv_into_slice(&mut into_out, &input, Kernel::Auto)?;
1592 }
1593
1594 assert_eq!(baseline.len(), into_out.len());
1595 fn eq_or_both_nan(a: f64, b: f64) -> bool {
1596 (a.is_nan() && b.is_nan()) || (a == b) || (a - b).abs() <= 1e-12
1597 }
1598 for (i, (a, b)) in baseline.iter().zip(into_out.iter()).enumerate() {
1599 assert!(
1600 eq_or_both_nan(*a, *b),
1601 "divergence at idx {}: api={}, into={}",
1602 i,
1603 a,
1604 b
1605 );
1606 }
1607 Ok(())
1608 }
1609}
1610
/// Python-facing handle for an EMV result held in CUDA device memory. It is exposed via
/// `__cuda_array_interface__` and DLPack.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", name = "EmvDeviceArrayF32", unsendable)]
pub struct EmvDeviceArrayF32Py {
    /// Device buffer plus its `rows` x `cols` shape.
    pub inner: DeviceArrayF32,
    // Shared handle to the CUDA context. It is presumably held so the context outlives
    // the device allocation in `inner`; confirm against CudaEmv::context_arc.
    _ctx_guard: Arc<Context>,
    // Device ordinal reported through `__dlpack_device__`.
    device_id: i32,
}
1618
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl EmvDeviceArrayF32Py {
    /// CUDA Array Interface (version 3) description of the row-major `f32` matrix.
    ///
    /// Zero-size arrays report a data pointer of 0, as the interface requires. This also
    /// covers the empty placeholder left behind after `__dlpack__` has moved the buffer out.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let d = PyDict::new(py);
        let inner = &self.inner;
        let itemsize = std::mem::size_of::<f32>();
        d.set_item("shape", (inner.rows, inner.cols))?;
        d.set_item("typestr", "<f4")?;
        // Byte strides for a C-contiguous (row-major) layout.
        d.set_item("strides", (inner.cols * itemsize, itemsize))?;
        let ptr_val = if inner.rows == 0 || inner.cols == 0 {
            0usize
        } else {
            inner.buf.as_device_ptr().as_raw() as usize
        };
        // The second tuple element is the read-only flag; consumers may write.
        d.set_item("data", (ptr_val, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    /// DLPack device tuple: `(2, device_id)`, where 2 is `kDLCUDA` in DLPack's device-type enum.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.device_id)
    }

    /// Exports the buffer as a DLPack capsule, moving it out of this handle.
    ///
    /// After a successful export this object holds an empty 0x0 placeholder.
    ///
    /// If `dl_device` names a different device than the one holding the buffer, the call
    /// errors: with a "not implemented" message when `copy=True`, otherwise with a mismatch
    /// message. A `dl_device` that is not an `(int, int)` tuple is ignored. `stream` is
    /// accepted for protocol compatibility but not used here; `max_version` is forwarded
    /// to the exporter.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        let (kdl, alloc_dev) = self.__dlpack_device__();
        if let Some(dev_obj) = dl_device.as_ref() {
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(PyValueError::new_err(
                            "device copy not implemented for __dlpack__",
                        ));
                    } else {
                        return Err(PyValueError::new_err("dl_device mismatch for __dlpack__"));
                    }
                }
            }
        }
        // NOTE(review): no stream synchronization happens here; presumably the producing
        // CUDA call has already synchronized. Confirm before honoring foreign streams.
        let _ = stream;

        // Swap in an empty buffer so ownership of the real one moves into the capsule.
        let dummy =
            DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let inner = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32 {
                buf: dummy,
                rows: 0,
                cols: 0,
            },
        );

        let rows = inner.rows;
        let cols = inner.cols;
        let buf = inner.buf;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
1690
#[cfg(all(feature = "python", feature = "cuda"))]
impl EmvDeviceArrayF32Py {
    /// Wraps a device result together with the shared context handle and its device
    /// ordinal, which is stored as `i32` for the DLPack device tuple.
    fn new_from_cuda(inner: DeviceArrayF32, ctx_guard: Arc<Context>, device_id: u32) -> Self {
        let device_id = device_id as i32;
        Self {
            inner,
            device_id,
            _ctx_guard: ctx_guard,
        }
    }
}
1701
/// Computes EMV on the GPU for a single `f32` series and returns a device-resident handle.
///
/// The GIL is released while the CUDA work runs. CUDA unavailability, non-contiguous
/// input arrays and CUDA errors are reported as `ValueError`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "emv_cuda_batch_dev")]
#[pyo3(signature = (high_f32, low_f32, volume_f32, device_id=0))]
pub fn emv_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    high_f32: numpy::PyReadonlyArray1<'py, f32>,
    low_f32: numpy::PyReadonlyArray1<'py, f32>,
    volume_f32: numpy::PyReadonlyArray1<'py, f32>,
    device_id: usize,
) -> PyResult<EmvDeviceArrayF32Py> {
    // Converts any displayable CUDA-side error into a Python ValueError.
    fn value_err<E: ToString>(e: E) -> PyErr {
        PyValueError::new_err(e.to_string())
    }

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let (high, low, volume) = (
        high_f32.as_slice()?,
        low_f32.as_slice()?,
        volume_f32.as_slice()?,
    );
    let (inner, ctx, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaEmv::new(device_id).map_err(value_err)?;
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id();
        let out = cuda.emv_batch_dev(high, low, volume).map_err(value_err)?;
        Ok((out, ctx, dev_id))
    })?;
    Ok(EmvDeviceArrayF32Py::new_from_cuda(inner, ctx, dev_id))
}
1729
/// Computes EMV on the GPU for many series stored as 2-D `f32` arrays and returns a
/// device-resident handle.
///
/// The inputs are indexed as `[rows, cols]`. The three arrays must share one shape; a
/// mismatch raises `ValueError`. The GIL is released while the CUDA work runs.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "emv_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm_f32, low_tm_f32, volume_tm_f32, device_id=0))]
pub fn emv_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    high_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    low_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    volume_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    device_id: usize,
) -> PyResult<EmvDeviceArrayF32Py> {
    use numpy::PyUntypedArrayMethods;

    // Converts any displayable CUDA-side error into a Python ValueError.
    fn value_err<E: ToString>(e: E) -> PyErr {
        PyValueError::new_err(e.to_string())
    }

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let (high, low, volume) = (
        high_tm_f32.as_slice()?,
        low_tm_f32.as_slice()?,
        volume_tm_f32.as_slice()?,
    );
    let (rows, cols) = (high_tm_f32.shape()[0], high_tm_f32.shape()[1]);
    let expected = [rows, cols];
    if low_tm_f32.shape() != expected || volume_tm_f32.shape() != expected {
        return Err(PyValueError::new_err("high/low/volume shapes mismatch"));
    }
    let (inner, ctx, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaEmv::new(device_id).map_err(value_err)?;
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id();
        let out = cuda
            .emv_many_series_one_param_time_major_dev(high, low, volume, cols, rows)
            .map_err(value_err)?;
        Ok((out, ctx, dev_id))
    })?;
    Ok(EmvDeviceArrayF32Py::new_from_cuda(inner, ctx, dev_id))
}