1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{
4 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
5 make_uninit_matrix,
6};
7#[cfg(feature = "python")]
8use crate::utilities::kernel_validation::validate_kernel;
9use aligned_vec::{AVec, CACHELINE_ALIGN};
10#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
11use core::arch::x86_64::*;
12#[cfg(feature = "python")]
13use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
14#[cfg(feature = "python")]
15use pyo3::exceptions::PyValueError;
16#[cfg(feature = "python")]
17use pyo3::prelude::*;
18#[cfg(feature = "python")]
19use pyo3::types::PyDict;
20#[cfg(not(target_arch = "wasm32"))]
21use rayon::prelude::*;
22use std::error::Error;
23use std::mem::MaybeUninit;
24use thiserror::Error;
25
26#[derive(Debug, Clone)]
27pub enum ViData<'a> {
28 Candles {
29 candles: &'a Candles,
30 },
31 Slices {
32 high: &'a [f64],
33 low: &'a [f64],
34 close: &'a [f64],
35 },
36}
37
/// Result of a single Vortex Indicator computation.
///
/// Both series have one value per input bar; entries before the first complete window are
/// `NaN`.
#[derive(Debug, Clone)]
pub struct ViOutput {
    /// VI+: windowed sum of `|high - previous low|` divided by the windowed true-range sum.
    pub plus: Vec<f64>,
    /// VI-: windowed sum of `|low - previous high|` divided by the windowed true-range sum.
    pub minus: Vec<f64>,
}
43
/// Tunable parameters of the Vortex Indicator.
#[derive(Debug, Clone)]
pub struct ViParams {
    /// Window length in bars; `None` falls back to 14.
    pub period: Option<usize>,
}

impl Default for ViParams {
    /// A 14-bar window.
    fn default() -> Self {
        ViParams { period: Some(14) }
    }
}
54
55#[derive(Debug, Clone)]
56pub struct ViInput<'a> {
57 pub data: ViData<'a>,
58 pub params: ViParams,
59}
60
61impl<'a> ViInput<'a> {
62 #[inline]
63 pub fn from_candles(candles: &'a Candles, params: ViParams) -> Self {
64 Self {
65 data: ViData::Candles { candles },
66 params,
67 }
68 }
69 #[inline]
70 pub fn from_slices(
71 high: &'a [f64],
72 low: &'a [f64],
73 close: &'a [f64],
74 params: ViParams,
75 ) -> Self {
76 Self {
77 data: ViData::Slices { high, low, close },
78 params,
79 }
80 }
81 #[inline]
82 pub fn with_default_candles(candles: &'a Candles) -> Self {
83 Self {
84 data: ViData::Candles { candles },
85 params: ViParams::default(),
86 }
87 }
88 #[inline]
89 pub fn get_period(&self) -> usize {
90 self.params.period.unwrap_or(14)
91 }
92}
93
94#[derive(Copy, Clone, Debug)]
95pub struct ViBuilder {
96 period: Option<usize>,
97 kernel: Kernel,
98}
99
100impl Default for ViBuilder {
101 fn default() -> Self {
102 Self {
103 period: None,
104 kernel: Kernel::Auto,
105 }
106 }
107}
108
109impl ViBuilder {
110 #[inline(always)]
111 pub fn new() -> Self {
112 Self::default()
113 }
114 #[inline(always)]
115 pub fn period(mut self, n: usize) -> Self {
116 self.period = Some(n);
117 self
118 }
119 #[inline(always)]
120 pub fn kernel(mut self, k: Kernel) -> Self {
121 self.kernel = k;
122 self
123 }
124 #[inline(always)]
125 pub fn apply(self, c: &Candles) -> Result<ViOutput, ViError> {
126 let p = ViParams {
127 period: self.period,
128 };
129 let i = ViInput::from_candles(c, p);
130 vi_with_kernel(&i, self.kernel)
131 }
132 #[inline(always)]
133 pub fn apply_slices(
134 self,
135 high: &[f64],
136 low: &[f64],
137 close: &[f64],
138 ) -> Result<ViOutput, ViError> {
139 let p = ViParams {
140 period: self.period,
141 };
142 let i = ViInput::from_slices(high, low, close, p);
143 vi_with_kernel(&i, self.kernel)
144 }
145 #[inline(always)]
146 pub fn into_stream(self) -> Result<ViStream, ViError> {
147 let p = ViParams {
148 period: self.period,
149 };
150 ViStream::try_new(p)
151 }
152}
153
154#[derive(Debug, Error)]
155pub enum ViError {
156 #[error("vi: Empty data provided.")]
157 EmptyInputData,
158 #[error("vi: All values are NaN.")]
159 AllValuesNaN,
160 #[error("vi: Invalid period: period = {period}, data length = {data_len}")]
161 InvalidPeriod { period: usize, data_len: usize },
162 #[error("vi: Not enough valid data: needed = {needed}, valid = {valid}")]
163 NotEnoughValidData { needed: usize, valid: usize },
164 #[error("vi: output length mismatch: expected = {expected}, got = {got}")]
165 OutputLengthMismatch { expected: usize, got: usize },
166 #[error("vi: invalid range: start={start}, end={end}, step={step}")]
167 InvalidRange {
168 start: usize,
169 end: usize,
170 step: usize,
171 },
172 #[error("vi: invalid kernel for batch: {0:?}")]
173 InvalidKernelForBatch(Kernel),
174 #[error("vi: invalid input: {0}")]
175 InvalidInput(String),
176}
177
178#[inline]
179pub fn vi(input: &ViInput) -> Result<ViOutput, ViError> {
180 vi_with_kernel(input, Kernel::Auto)
181}
182
183#[inline(always)]
184fn vi_prepare<'a>(
185 input: &'a ViInput,
186 kernel: Kernel,
187) -> Result<(&'a [f64], &'a [f64], &'a [f64], usize, usize, Kernel), ViError> {
188 let (high, low, close) = match &input.data {
189 ViData::Candles { candles } => (
190 source_type(candles, "high"),
191 source_type(candles, "low"),
192 source_type(candles, "close"),
193 ),
194 ViData::Slices { high, low, close } => (*high, *low, *close),
195 };
196
197 if high.is_empty() || low.is_empty() || close.is_empty() {
198 return Err(ViError::EmptyInputData);
199 }
200 let len = high.len();
201 if len != low.len() || len != close.len() {
202 return Err(ViError::EmptyInputData);
203 }
204
205 let period = input.get_period();
206 if period == 0 || period > len {
207 return Err(ViError::InvalidPeriod {
208 period,
209 data_len: len,
210 });
211 }
212
213 let first = (0..len)
214 .find(|&i| !high[i].is_nan() && !low[i].is_nan() && !close[i].is_nan())
215 .ok_or(ViError::AllValuesNaN)?;
216
217 if len - first < period {
218 return Err(ViError::NotEnoughValidData {
219 needed: period,
220 valid: len - first,
221 });
222 }
223
224 let chosen = match kernel {
225 Kernel::Auto => detect_best_kernel(),
226 k => k,
227 };
228 Ok((high, low, close, period, first, chosen))
229}
230
231#[inline(always)]
232fn vi_compute_into(
233 high: &[f64],
234 low: &[f64],
235 close: &[f64],
236 period: usize,
237 first: usize,
238 kernel: Kernel,
239 plus: &mut [f64],
240 minus: &mut [f64],
241) {
242 unsafe {
243 match kernel {
244 Kernel::Scalar | Kernel::ScalarBatch => {
245 vi_scalar(high, low, close, period, first, plus, minus)
246 }
247 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
248 Kernel::Avx2 | Kernel::Avx2Batch => {
249 vi_avx2(high, low, close, period, first, plus, minus)
250 }
251 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
252 Kernel::Avx512 | Kernel::Avx512Batch => {
253 vi_avx512(high, low, close, period, first, plus, minus)
254 }
255 _ => unreachable!(),
256 }
257 }
258}
259
260pub fn vi_with_kernel(input: &ViInput, kernel: Kernel) -> Result<ViOutput, ViError> {
261 let (h, l, c, period, first, chosen) = vi_prepare(input, kernel)?;
262 let mut plus = alloc_with_nan_prefix(h.len(), first + period - 1);
263 let mut minus = alloc_with_nan_prefix(h.len(), first + period - 1);
264 vi_compute_into(h, l, c, period, first, chosen, &mut plus, &mut minus);
265 Ok(ViOutput { plus, minus })
266}
267
268pub fn vi_into_slice(
269 dst_plus: &mut [f64],
270 dst_minus: &mut [f64],
271 input: &ViInput,
272 kernel: Kernel,
273) -> Result<(), ViError> {
274 let (h, l, c, period, first, chosen) = vi_prepare(input, kernel)?;
275 if dst_plus.len() != h.len() || dst_minus.len() != h.len() {
276 let expected = h.len();
277 let got = dst_plus.len().max(dst_minus.len());
278 return Err(ViError::OutputLengthMismatch { expected, got });
279 }
280 vi_compute_into(h, l, c, period, first, chosen, dst_plus, dst_minus);
281 let warm = first + period - 1;
282 for i in 0..warm {
283 dst_plus[i] = f64::NAN;
284 dst_minus[i] = f64::NAN;
285 }
286 Ok(())
287}
288
/// Scalar Vortex Indicator kernel using O(1) rolling-window sums.
///
/// For each bar `i > first` it computes:
/// - the true range `max(high - low, |high - prev_close|, |low - prev_close|)`;
/// - VM+ `|high - prev_low|`;
/// - VM- `|low - prev_high|`.
///
/// It keeps the last `period` values of each term in ring buffers with running sums. The bar
/// at `first` contributes `high - low` to the true-range sum and zero to VM+/VM-.
///
/// From `warm = first + period - 1` onward it writes `plus[i] = ΣVM+ / ΣTR` and
/// `minus[i] = ΣVM- / ΣTR`. Indices before `warm` are not touched. If `high` is empty,
/// `period == 0` or `first >= high.len()`, nothing is written.
///
/// # Safety
/// `low`, `close`, `plus` and `minus` must all be at least `high.len()` long. Elements are
/// accessed through raw pointers without bounds checks.
#[inline(always)]
pub unsafe fn vi_scalar(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first: usize,
    plus: &mut [f64],
    minus: &mut [f64],
) {
    let n = high.len();
    // A zero period has no window, and `first` must index a real bar; the pointer arithmetic
    // below would otherwise touch memory outside the buffers.
    if n == 0 || period == 0 || first >= n {
        return;
    }

    let warm = first + period - 1;

    let h = high.as_ptr();
    let l = low.as_ptr();
    let c = close.as_ptr();
    let p_out = plus.as_mut_ptr();
    let m_out = minus.as_mut_ptr();

    // Ring buffers of the last `period` terms. They are zero-initialised: exposing
    // uninitialised capacity via `Vec::set_len` is unsound, even though every slot is written
    // before it is read.
    let mut tr_buf = vec![0.0f64; period];
    let mut vp_buf = vec![0.0f64; period];
    let mut vm_buf = vec![0.0f64; period];
    let trp = tr_buf.as_mut_ptr();
    let vpp = vp_buf.as_mut_ptr();
    let vmp = vm_buf.as_mut_ptr();

    let mut prev_h = *h.add(first);
    let mut prev_l = *l.add(first);
    let mut prev_c = *c.add(first);

    // Seed bar: no predecessor, so only its own range contributes.
    let mut sum_tr = prev_h - prev_l;
    let mut sum_vp = 0.0f64;
    let mut sum_vm = 0.0f64;

    *trp.add(0) = sum_tr;
    *vpp.add(0) = 0.0;
    *vmp.add(0) = 0.0;

    // With a one-bar window the seed bar is already a complete window (VM sums are zero).
    if period == 1 {
        *p_out.add(warm) = 0.0;
        *m_out.add(warm) = 0.0;
    }

    // Next ring slot to fill or replace.
    let mut r = if period == 1 { 0 } else { 1 };

    let mut i = first + 1;
    while i < n {
        let hi = *h.add(i);
        let lo = *l.add(i);

        let hl = hi - lo;
        let hc = (hi - prev_c).abs();
        let lc = (lo - prev_c).abs();
        let mut tr_new = if hl > hc { hl } else { hc };
        if lc > tr_new {
            tr_new = lc;
        }

        let vp_new = (hi - prev_l).abs();
        let vm_new = (lo - prev_h).abs();

        if i <= warm {
            // Still filling the first window: accumulate only.
            sum_tr += tr_new;
            sum_vp += vp_new;
            sum_vm += vm_new;

            *trp.add(r) = tr_new;
            *vpp.add(r) = vp_new;
            *vmp.add(r) = vm_new;

            if i == warm {
                *p_out.add(i) = sum_vp / sum_tr;
                *m_out.add(i) = sum_vm / sum_tr;
            }
        } else {
            // Slide the window: swap the oldest term for the newest.
            let tr_old = *trp.add(r);
            let vp_old = *vpp.add(r);
            let vm_old = *vmp.add(r);

            sum_tr += tr_new - tr_old;
            sum_vp += vp_new - vp_old;
            sum_vm += vm_new - vm_old;

            *trp.add(r) = tr_new;
            *vpp.add(r) = vp_new;
            *vmp.add(r) = vm_new;

            *p_out.add(i) = sum_vp / sum_tr;
            *m_out.add(i) = sum_vm / sum_tr;
        }

        prev_h = hi;
        prev_l = lo;
        prev_c = *c.add(i);

        r += 1;
        if r == period {
            r = 0;
        }

        i += 1;
    }
}
399
/// AVX2 entry point for the single-series kernel; currently runs the scalar kernel.
///
/// # Safety
/// Same requirements as [`vi_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn vi_avx2(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first: usize,
    plus: &mut [f64],
    minus: &mut [f64],
) {
    vi_scalar(high, low, close, period, first, plus, minus)
}
413
/// AVX-512 entry point for the single-series kernel.
///
/// Routes windows of up to 32 bars to the short variant and longer windows to the long
/// variant.
///
/// # Safety
/// Same requirements as [`vi_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn vi_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first: usize,
    plus: &mut [f64],
    minus: &mut [f64],
) {
    match period {
        0..=32 => vi_avx512_short(high, low, close, period, first, plus, minus),
        _ => vi_avx512_long(high, low, close, period, first, plus, minus),
    }
}
431
/// AVX-512 kernel for windows of at most 32 bars; currently runs the scalar kernel.
///
/// # Safety
/// Same requirements as [`vi_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn vi_avx512_short(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first: usize,
    plus: &mut [f64],
    minus: &mut [f64],
) {
    vi_scalar(high, low, close, period, first, plus, minus)
}
445
/// AVX-512 kernel for windows longer than 32 bars; currently runs the scalar kernel.
///
/// # Safety
/// Same requirements as [`vi_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn vi_avx512_long(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first: usize,
    plus: &mut [f64],
    minus: &mut [f64],
) {
    vi_scalar(high, low, close, period, first, plus, minus)
}
459
/// Bar-by-bar Vortex Indicator calculator.
///
/// Holds ring buffers of the last `period` true-range, VM+ and VM- terms together with their
/// running sums.
#[derive(Debug, Clone)]
pub struct ViStream {
    /// Window length.
    period: usize,
    /// Ring buffer of true-range terms.
    tr: Vec<f64>,
    /// Ring buffer of `|high - previous low|` terms.
    vp: Vec<f64>,
    /// Ring buffer of `|low - previous high|` terms.
    vm: Vec<f64>,
    /// Next ring slot to overwrite.
    idx: usize,
    /// Becomes `true` once `period` updates have been received.
    filled: bool,
    /// Running sum of `tr`.
    sum_tr: f64,
    /// Running sum of `vp`.
    sum_vp: f64,
    /// Running sum of `vm`.
    sum_vm: f64,
}
472
473impl ViStream {
474 pub fn try_new(params: ViParams) -> Result<Self, ViError> {
475 let period = params.period.unwrap_or(14);
476 if period == 0 {
477 return Err(ViError::InvalidPeriod {
478 period,
479 data_len: 0,
480 });
481 }
482 Ok(Self {
483 period,
484 tr: vec![0.0; period],
485 vp: vec![0.0; period],
486 vm: vec![0.0; period],
487 idx: 0,
488 filled: false,
489 sum_tr: 0.0,
490 sum_vp: 0.0,
491 sum_vm: 0.0,
492 })
493 }
494
495 pub fn update(
496 &mut self,
497 high: f64,
498 low: f64,
499 close: f64,
500 prev_low: f64,
501 prev_high: f64,
502 prev_close: f64,
503 ) -> Option<(f64, f64)> {
504 let _ = close;
505
506 let i = self.idx;
507
508 let hl = high - low;
509 let hc = (high - prev_close).abs();
510 let lc = (low - prev_close).abs();
511 let tr_new = hl.max(hc.max(lc));
512
513 let vp_new = (high - prev_low).abs();
514 let vm_new = (low - prev_high).abs();
515
516 let tr_old = self.tr[i];
517 let vp_old = self.vp[i];
518 let vm_old = self.vm[i];
519
520 self.sum_tr += tr_new - tr_old;
521 self.sum_vp += vp_new - vp_old;
522 self.sum_vm += vm_new - vm_old;
523
524 self.tr[i] = tr_new;
525 self.vp[i] = vp_new;
526 self.vm[i] = vm_new;
527
528 self.idx += 1;
529 if self.idx == self.period {
530 self.idx = 0;
531 self.filled = true;
532 }
533
534 if self.filled {
535 let inv_tr = 1.0 / self.sum_tr;
536 let vi_p = self.sum_vp * inv_tr;
537 let vi_m = self.sum_vm * inv_tr;
538 Some((vi_p, vi_m))
539 } else {
540 None
541 }
542 }
543}
544
/// Period sweep for batch VI, given as `(start, end, step)`.
///
/// A zero `step` (or `start == end`) selects only `start`; `start > end` sweeps downward.
#[derive(Clone, Debug)]
pub struct ViBatchRange {
    pub period: (usize, usize, usize),
}

impl Default for ViBatchRange {
    /// Periods 14 through 263 in steps of 1.
    fn default() -> Self {
        let (start, end, step) = (14, 263, 1);
        ViBatchRange {
            period: (start, end, step),
        }
    }
}
557
558#[derive(Clone, Debug, Default)]
559pub struct ViBatchBuilder {
560 range: ViBatchRange,
561 kernel: Kernel,
562}
563
564impl ViBatchBuilder {
565 pub fn new() -> Self {
566 Self::default()
567 }
568 pub fn kernel(mut self, k: Kernel) -> Self {
569 self.kernel = k;
570 self
571 }
572 #[inline]
573 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
574 self.range.period = (start, end, step);
575 self
576 }
577 #[inline]
578 pub fn period_static(mut self, p: usize) -> Self {
579 self.range.period = (p, p, 0);
580 self
581 }
582 pub fn apply_slices(
583 self,
584 high: &[f64],
585 low: &[f64],
586 close: &[f64],
587 ) -> Result<ViBatchOutput, ViError> {
588 vi_batch_with_kernel(high, low, close, &self.range, self.kernel)
589 }
590 pub fn apply_candles(self, c: &Candles) -> Result<ViBatchOutput, ViError> {
591 let high = source_type(c, "high");
592 let low = source_type(c, "low");
593 let close = source_type(c, "close");
594 self.apply_slices(high, low, close)
595 }
596}
597
598#[derive(Clone, Debug)]
599pub struct ViBatchOutput {
600 pub plus: Vec<f64>,
601 pub minus: Vec<f64>,
602 pub combos: Vec<ViParams>,
603 pub rows: usize,
604 pub cols: usize,
605}
606impl ViBatchOutput {
607 pub fn row_for_params(&self, p: &ViParams) -> Option<usize> {
608 self.combos
609 .iter()
610 .position(|c| c.period.unwrap_or(14) == p.period.unwrap_or(14))
611 }
612 pub fn plus_for(&self, p: &ViParams) -> Option<&[f64]> {
613 self.row_for_params(p).map(|row| {
614 let start = row * self.cols;
615 &self.plus[start..start + self.cols]
616 })
617 }
618 pub fn minus_for(&self, p: &ViParams) -> Option<&[f64]> {
619 self.row_for_params(p).map(|row| {
620 let start = row * self.cols;
621 &self.minus[start..start + self.cols]
622 })
623 }
624}
625
626#[inline(always)]
627fn expand_grid(r: &ViBatchRange) -> Vec<ViParams> {
628 fn axis_usize(range: (usize, usize, usize)) -> Result<Vec<usize>, ViError> {
629 let (start, end, step) = range;
630 if step == 0 || start == end {
631 return Ok(vec![start]);
632 }
633 if start < end {
634 let v: Vec<usize> = (start..=end).step_by(step).collect();
635 if v.is_empty() {
636 return Err(ViError::InvalidRange { start, end, step });
637 }
638 return Ok(v);
639 }
640 let mut v = Vec::new();
641 let mut cur = start;
642 loop {
643 v.push(cur);
644 if cur == end {
645 break;
646 }
647 cur = cur
648 .checked_sub(step)
649 .ok_or(ViError::InvalidRange { start, end, step })?;
650 if cur < end {
651 break;
652 }
653 }
654 if v.is_empty() {
655 return Err(ViError::InvalidRange { start, end, step });
656 }
657 Ok(v)
658 }
659
660 let periods = match axis_usize(r.period) {
661 Ok(v) => v,
662 Err(_) => return Vec::new(),
663 };
664 let mut out = Vec::with_capacity(periods.len());
665 for &p in &periods {
666 out.push(ViParams { period: Some(p) });
667 }
668 out
669}
670
671pub fn vi_batch_with_kernel(
672 high: &[f64],
673 low: &[f64],
674 close: &[f64],
675 sweep: &ViBatchRange,
676 k: Kernel,
677) -> Result<ViBatchOutput, ViError> {
678 let kernel = match k {
679 Kernel::Auto => match detect_best_batch_kernel() {
680 Kernel::Avx512Batch => Kernel::Avx2Batch,
681 other => other,
682 },
683 other if other.is_batch() => other,
684 other => {
685 return Err(ViError::InvalidKernelForBatch(other));
686 }
687 };
688 let simd = match kernel {
689 Kernel::Avx512Batch => Kernel::Avx512,
690 Kernel::Avx2Batch => Kernel::Avx2,
691 Kernel::ScalarBatch => Kernel::Scalar,
692 _ => unreachable!(),
693 };
694 vi_batch_par_slice(high, low, close, sweep, simd)
695}
696
697pub fn vi_batch_slice(
698 high: &[f64],
699 low: &[f64],
700 close: &[f64],
701 sweep: &ViBatchRange,
702 kern: Kernel,
703) -> Result<ViBatchOutput, ViError> {
704 vi_batch_inner(high, low, close, sweep, kern, false)
705}
706pub fn vi_batch_par_slice(
707 high: &[f64],
708 low: &[f64],
709 close: &[f64],
710 sweep: &ViBatchRange,
711 kern: Kernel,
712) -> Result<ViBatchOutput, ViError> {
713 vi_batch_inner(high, low, close, sweep, kern, true)
714}
715#[inline(always)]
716fn vi_batch_inner(
717 high: &[f64],
718 low: &[f64],
719 close: &[f64],
720 sweep: &ViBatchRange,
721 kern: Kernel,
722 parallel: bool,
723) -> Result<ViBatchOutput, ViError> {
724 let combos = expand_grid(sweep);
725 if combos.is_empty() {
726 return Err(ViError::InvalidRange {
727 start: sweep.period.0,
728 end: sweep.period.1,
729 step: sweep.period.2,
730 });
731 }
732 if high.is_empty() || low.is_empty() || close.is_empty() {
733 return Err(ViError::EmptyInputData);
734 }
735 let first = (0..high.len())
736 .find(|&i| !high[i].is_nan() && !low[i].is_nan() && !close[i].is_nan())
737 .ok_or(ViError::AllValuesNaN)?;
738 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
739 if high.len() - first < max_p {
740 return Err(ViError::NotEnoughValidData {
741 needed: max_p,
742 valid: high.len() - first,
743 });
744 }
745 let rows = combos.len();
746 let cols = high.len();
747 rows.checked_mul(cols)
748 .ok_or_else(|| ViError::InvalidRange {
749 start: sweep.period.0,
750 end: sweep.period.1,
751 step: sweep.period.2,
752 })?;
753
754 let mut plus_mu = make_uninit_matrix(rows, cols);
755 let mut minus_mu = make_uninit_matrix(rows, cols);
756
757 let mut warm: Vec<usize> = Vec::with_capacity(combos.len());
758 for c in &combos {
759 let p = c.period.unwrap();
760 let warm_i = first
761 .checked_add(p)
762 .and_then(|v| v.checked_sub(1))
763 .ok_or_else(|| ViError::InvalidPeriod {
764 period: p,
765 data_len: high.len(),
766 })?;
767 warm.push(warm_i);
768 }
769
770 init_matrix_prefixes(&mut plus_mu, cols, &warm);
771 init_matrix_prefixes(&mut minus_mu, cols, &warm);
772
773 let mut plus_guard = core::mem::ManuallyDrop::new(plus_mu);
774 let mut minus_guard = core::mem::ManuallyDrop::new(minus_mu);
775 let plus: &mut [f64] = unsafe {
776 core::slice::from_raw_parts_mut(plus_guard.as_mut_ptr() as *mut f64, plus_guard.len())
777 };
778 let minus: &mut [f64] = unsafe {
779 core::slice::from_raw_parts_mut(minus_guard.as_mut_ptr() as *mut f64, minus_guard.len())
780 };
781
782 let mut pfx_tr = vec![0.0f64; cols];
783 let mut pfx_vp = vec![0.0f64; cols];
784 let mut pfx_vm = vec![0.0f64; cols];
785 if cols > 0 && first < cols {
786 unsafe {
787 match kern {
788 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
789 Kernel::Avx512 | Kernel::Avx512Batch => vi_prefix_avx512(
790 high,
791 low,
792 close,
793 first,
794 &mut pfx_tr,
795 &mut pfx_vp,
796 &mut pfx_vm,
797 ),
798 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
799 Kernel::Avx2 | Kernel::Avx2Batch => vi_prefix_avx2(
800 high,
801 low,
802 close,
803 first,
804 &mut pfx_tr,
805 &mut pfx_vp,
806 &mut pfx_vm,
807 ),
808 _ => vi_prefix_scalar(
809 high,
810 low,
811 close,
812 first,
813 &mut pfx_tr,
814 &mut pfx_vp,
815 &mut pfx_vm,
816 ),
817 }
818 }
819 }
820
821 let do_row = |row: usize, plus_row: &mut [f64], minus_row: &mut [f64]| {
822 let period = combos[row].period.unwrap();
823 let warm = first + period - 1;
824 if warm >= cols {
825 return;
826 }
827 let mut i = warm;
828 while i < cols {
829 let tr_sum = if i >= period {
830 pfx_tr[i] - pfx_tr[i - period]
831 } else {
832 pfx_tr[i]
833 };
834 let vp_sum = if i >= period {
835 pfx_vp[i] - pfx_vp[i - period]
836 } else {
837 pfx_vp[i]
838 };
839 let vm_sum = if i >= period {
840 pfx_vm[i] - pfx_vm[i - period]
841 } else {
842 pfx_vm[i]
843 };
844 plus_row[i] = vp_sum / tr_sum;
845 minus_row[i] = vm_sum / tr_sum;
846 i += 1;
847 }
848 };
849 if parallel {
850 #[cfg(not(target_arch = "wasm32"))]
851 {
852 plus.par_chunks_mut(cols)
853 .zip(minus.par_chunks_mut(cols))
854 .enumerate()
855 .for_each(|(row, (p, m))| do_row(row, p, m));
856 }
857
858 #[cfg(target_arch = "wasm32")]
859 {
860 for ((row, p), m) in plus
861 .chunks_mut(cols)
862 .enumerate()
863 .zip(minus.chunks_mut(cols))
864 {
865 do_row(row, p, m);
866 }
867 }
868 } else {
869 for ((row, p), m) in plus
870 .chunks_mut(cols)
871 .enumerate()
872 .zip(minus.chunks_mut(cols))
873 {
874 do_row(row, p, m);
875 }
876 }
877
878 let plus_vec = unsafe {
879 Vec::from_raw_parts(
880 plus_guard.as_mut_ptr() as *mut f64,
881 plus_guard.len(),
882 plus_guard.capacity(),
883 )
884 };
885 let minus_vec = unsafe {
886 Vec::from_raw_parts(
887 minus_guard.as_mut_ptr() as *mut f64,
888 minus_guard.len(),
889 minus_guard.capacity(),
890 )
891 };
892
893 Ok(ViBatchOutput {
894 plus: plus_vec,
895 minus: minus_vec,
896 combos,
897 rows,
898 cols,
899 })
900}
901
902#[inline(always)]
903unsafe fn vi_row_scalar(
904 high: &[f64],
905 low: &[f64],
906 close: &[f64],
907 first: usize,
908 period: usize,
909 plus: &mut [f64],
910 minus: &mut [f64],
911) {
912 vi_scalar(high, low, close, period, first, plus, minus);
913}
/// Per-row AVX2 kernel; currently runs the scalar kernel. Arguments are `(first, period)`.
///
/// # Safety
/// Same requirements as [`vi_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn vi_row_avx2(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    plus: &mut [f64],
    minus: &mut [f64],
) {
    vi_scalar(high, low, close, period, first, plus, minus)
}
/// Per-row AVX-512 kernel.
///
/// Windows of up to 32 bars use the short variant and longer windows use the long variant.
///
/// # Safety
/// Same requirements as [`vi_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn vi_row_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    plus: &mut [f64],
    minus: &mut [f64],
) {
    match period {
        0..=32 => vi_row_avx512_short(high, low, close, first, period, plus, minus),
        _ => vi_row_avx512_long(high, low, close, first, period, plus, minus),
    }
}
/// Per-row AVX-512 kernel for short windows; currently runs the scalar kernel.
///
/// # Safety
/// Same requirements as [`vi_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn vi_row_avx512_short(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    plus: &mut [f64],
    minus: &mut [f64],
) {
    vi_scalar(high, low, close, period, first, plus, minus)
}
/// Per-row AVX-512 kernel for long windows; currently runs the scalar kernel.
///
/// # Safety
/// Same requirements as [`vi_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn vi_row_avx512_long(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    plus: &mut [f64],
    minus: &mut [f64],
) {
    vi_scalar(high, low, close, period, first, plus, minus)
}
970
971fn vi_batch_inner_into(
972 high: &[f64],
973 low: &[f64],
974 close: &[f64],
975 sweep: &ViBatchRange,
976 kernel: Kernel,
977 parallel: bool,
978 out_plus: &mut [f64],
979 out_minus: &mut [f64],
980) -> Result<Vec<ViParams>, ViError> {
981 let combos = expand_grid(&sweep);
982 let rows = combos.len();
983 let cols = close.len();
984 if rows == 0 {
985 return Err(ViError::InvalidRange {
986 start: sweep.period.0,
987 end: sweep.period.1,
988 step: sweep.period.2,
989 });
990 }
991 let expected = rows
992 .checked_mul(cols)
993 .ok_or_else(|| ViError::InvalidRange {
994 start: sweep.period.0,
995 end: sweep.period.1,
996 step: sweep.period.2,
997 })?;
998 if out_plus.len() != expected || out_minus.len() != expected {
999 let got = out_plus.len().max(out_minus.len());
1000 return Err(ViError::OutputLengthMismatch { expected, got });
1001 }
1002
1003 if high.is_empty() || low.is_empty() || close.is_empty() {
1004 return Err(ViError::EmptyInputData);
1005 }
1006 let first = (0..high.len())
1007 .find(|&i| !high[i].is_nan() && !low[i].is_nan() && !close[i].is_nan())
1008 .ok_or(ViError::AllValuesNaN)?;
1009 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1010 if high.len() - first < max_p {
1011 return Err(ViError::NotEnoughValidData {
1012 needed: max_p,
1013 valid: high.len() - first,
1014 });
1015 }
1016
1017 let cols = close.len();
1018 let mut pfx_tr = vec![0.0f64; cols];
1019 let mut pfx_vp = vec![0.0f64; cols];
1020 let mut pfx_vm = vec![0.0f64; cols];
1021 if cols > 0 && first < cols {
1022 unsafe {
1023 match kernel {
1024 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1025 Kernel::Avx512 | Kernel::Avx512Batch => vi_prefix_avx512(
1026 high,
1027 low,
1028 close,
1029 first,
1030 &mut pfx_tr,
1031 &mut pfx_vp,
1032 &mut pfx_vm,
1033 ),
1034 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1035 Kernel::Avx2 | Kernel::Avx2Batch => vi_prefix_avx2(
1036 high,
1037 low,
1038 close,
1039 first,
1040 &mut pfx_tr,
1041 &mut pfx_vp,
1042 &mut pfx_vm,
1043 ),
1044 _ => vi_prefix_scalar(
1045 high,
1046 low,
1047 close,
1048 first,
1049 &mut pfx_tr,
1050 &mut pfx_vp,
1051 &mut pfx_vm,
1052 ),
1053 }
1054 }
1055 }
1056
1057 let do_row = |row: usize, p_row: &mut [f64], m_row: &mut [f64]| {
1058 let period = combos[row].period.unwrap();
1059 let warm = first + period - 1;
1060
1061 for i in 0..warm.min(cols) {
1062 p_row[i] = f64::NAN;
1063 m_row[i] = f64::NAN;
1064 }
1065 if warm >= cols {
1066 return;
1067 }
1068 let mut i = warm;
1069 while i < cols {
1070 let tr_sum = if i >= period {
1071 pfx_tr[i] - pfx_tr[i - period]
1072 } else {
1073 pfx_tr[i]
1074 };
1075 let vp_sum = if i >= period {
1076 pfx_vp[i] - pfx_vp[i - period]
1077 } else {
1078 pfx_vp[i]
1079 };
1080 let vm_sum = if i >= period {
1081 pfx_vm[i] - pfx_vm[i - period]
1082 } else {
1083 pfx_vm[i]
1084 };
1085 p_row[i] = vp_sum / tr_sum;
1086 m_row[i] = vm_sum / tr_sum;
1087 i += 1;
1088 }
1089 };
1090
1091 if parallel {
1092 #[cfg(not(target_arch = "wasm32"))]
1093 out_plus
1094 .par_chunks_mut(cols)
1095 .zip(out_minus.par_chunks_mut(cols))
1096 .enumerate()
1097 .for_each(|(row, (p, m))| do_row(row, p, m));
1098 #[cfg(target_arch = "wasm32")]
1099 for ((row, p), m) in out_plus
1100 .chunks_mut(cols)
1101 .enumerate()
1102 .zip(out_minus.chunks_mut(cols))
1103 {
1104 do_row(row, p, m);
1105 }
1106 } else {
1107 for ((row, p), m) in out_plus
1108 .chunks_mut(cols)
1109 .enumerate()
1110 .zip(out_minus.chunks_mut(cols))
1111 {
1112 do_row(row, p, m);
1113 }
1114 }
1115 Ok(combos)
1116}
1117
/// Builds running sums for the batch path, starting at `first`.
///
/// For every `i >= first`, `pfx_*[i]` is the sum over bars `first..=i` of:
/// - the true range, which is `high - low` for the seed bar and otherwise
///   `max(high - low, |high - prev_close|, |low - prev_close|)`;
/// - `|high - prev_low|`, which is zero for the seed bar;
/// - `|low - prev_high|`, which is zero for the seed bar.
///
/// Indices before `first` are left untouched.
#[inline(always)]
unsafe fn vi_prefix_scalar(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    pfx_tr: &mut [f64],
    pfx_vp: &mut [f64],
    pfx_vm: &mut [f64],
) {
    pfx_tr[first] = high[first] - low[first];
    pfx_vp[first] = 0.0;
    pfx_vm[first] = 0.0;

    for cur in (first + 1)..high.len() {
        let prev = cur - 1;
        let (hi, lo) = (high[cur], low[cur]);
        let prev_close = close[prev];

        let true_range = (hi - lo).max((hi - prev_close).abs().max((lo - prev_close).abs()));
        pfx_tr[cur] = pfx_tr[prev] + true_range;
        pfx_vp[cur] = pfx_vp[prev] + (hi - low[prev]).abs();
        pfx_vm[cur] = pfx_vm[prev] + (lo - high[prev]).abs();
    }
}
1154
/// Lane-wise absolute value of four `f64`s, computed as `max(x, 0 - x)`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn abs256(x: __m256d) -> __m256d {
    let negated = _mm256_sub_pd(_mm256_set1_pd(0.0), x);
    _mm256_max_pd(x, negated)
}
/// AVX2 variant of the batch prefix-sum builder.
///
/// For every `i >= first` it fills the running sums over bars `first..=i` of:
/// - the true range, which is `high[first] - low[first]` for the seed bar and otherwise
///   `max(high - low, |high - prev_close|, |low - prev_close|)`;
/// - `|high - prev_low|`, which is zero for the seed bar;
/// - `|low - prev_high|`, which is zero for the seed bar.
///
/// Indices before `first` are not written.
///
/// Per-bar terms are computed four bars at a time with 256-bit vectors. The running sums are
/// then accumulated one lane at a time, so the additions happen in the same sequential order
/// as the scalar builder. Fewer than four trailing bars go through a scalar tail loop.
///
/// NOTE(review): `_mm256_max_pd` returns its second operand when either input is NaN, whereas
/// `f64::max` ignores NaN. Bars with NaN inputs after `first` may therefore produce different
/// prefixes than `vi_prefix_scalar`. Confirm this divergence is acceptable.
///
/// # Safety
/// This function carries no `#[target_feature]`, so the caller must ensure the CPU supports
/// the AVX instructions used. The caller must also ensure `first < high.len()`, and that
/// `low`, `close` and the three prefix slices are at least `high.len()` long. The vector
/// loads and the tail loop use unchecked accesses.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn vi_prefix_avx2(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    pfx_tr: &mut [f64],
    pfx_vp: &mut [f64],
    pfx_vm: &mut [f64],
) {
    use core::arch::x86_64::*;
    let n = high.len();
    // Seed bar: no predecessor, so only its own range contributes.
    pfx_tr[first] = high[first] - low[first];
    pfx_vp[first] = 0.0;
    pfx_vm[first] = 0.0;
    let mut i = first + 1;

    // Running totals carried across vector blocks.
    let mut carry_tr = pfx_tr[i - 1];
    let mut carry_vp = pfx_vp[i - 1];
    let mut carry_vm = pfx_vm[i - 1];
    let step = 4;
    while i + step <= n {
        // Current bars i..i+4 and their predecessors i-1..i+3 (unaligned loads).
        let v_hi = _mm256_loadu_pd(high.as_ptr().add(i));
        let v_lo = _mm256_loadu_pd(low.as_ptr().add(i));
        let v_cl_prev = _mm256_loadu_pd(close.as_ptr().add(i - 1));
        let v_lo_prev = _mm256_loadu_pd(low.as_ptr().add(i - 1));
        let v_hi_prev = _mm256_loadu_pd(high.as_ptr().add(i - 1));

        let hl = _mm256_sub_pd(v_hi, v_lo);
        let hc = abs256(_mm256_sub_pd(v_hi, v_cl_prev));
        let lc = abs256(_mm256_sub_pd(v_lo, v_cl_prev));
        let tr_v = _mm256_max_pd(hl, _mm256_max_pd(hc, lc));
        let vp_v = abs256(_mm256_sub_pd(v_hi, v_lo_prev));
        let vm_v = abs256(_mm256_sub_pd(v_lo, v_hi_prev));

        // Spill the lanes and accumulate them serially; a prefix sum is inherently sequential.
        let mut tr_tmp = [0.0f64; 4];
        let mut vp_tmp = [0.0f64; 4];
        let mut vm_tmp = [0.0f64; 4];
        _mm256_storeu_pd(tr_tmp.as_mut_ptr(), tr_v);
        _mm256_storeu_pd(vp_tmp.as_mut_ptr(), vp_v);
        _mm256_storeu_pd(vm_tmp.as_mut_ptr(), vm_v);
        let mut k = 0;
        while k < step {
            carry_tr += tr_tmp[k];
            carry_vp += vp_tmp[k];
            carry_vm += vm_tmp[k];
            pfx_tr[i + k] = carry_tr;
            pfx_vp[i + k] = carry_vp;
            pfx_vm[i + k] = carry_vm;
            k += 1;
        }

        i += step;
    }
    // Scalar tail for the bars left over after the last full vector block.
    while i < n {
        let hi = *high.get_unchecked(i);
        let lo = *low.get_unchecked(i);
        let prev_c = *close.get_unchecked(i - 1);
        let prev_l = *low.get_unchecked(i - 1);
        let prev_h = *high.get_unchecked(i - 1);
        let hl = hi - lo;
        let hc = (hi - prev_c).abs();
        let lc = (lo - prev_c).abs();
        let tr_i = hl.max(hc.max(lc));
        let vp_i = (hi - prev_l).abs();
        let vm_i = (lo - prev_h).abs();
        carry_tr += tr_i;
        carry_vp += vp_i;
        carry_vm += vm_i;
        pfx_tr[i] = carry_tr;
        pfx_vp[i] = carry_vp;
        pfx_vm[i] = carry_vm;
        i += 1;
    }
}
1237
/// Lane-wise absolute value of eight `f64`s, computed as `max(x, 0 - x)`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn abs512(x: __m512d) -> __m512d {
    let negated = _mm512_sub_pd(_mm512_set1_pd(0.0), x);
    _mm512_max_pd(x, negated)
}
/// AVX-512 variant of the batch prefix-sum builder.
///
/// It produces the same running sums as `vi_prefix_scalar`: true range, `|high - prev_low|`
/// and `|low - prev_high|`, summed over bars `first..=i`, with the seed bar contributing
/// `high - low` and zeros. Indices before `first` are not written.
///
/// Per-bar terms are computed eight bars at a time with 512-bit vectors. The running sums are
/// accumulated one lane at a time, in the same sequential order as the scalar builder. Fewer
/// than eight trailing bars go through a scalar tail loop.
///
/// NOTE(review): `_mm512_max_pd` returns its second operand when either input is NaN, whereas
/// `f64::max` ignores NaN. Bars with NaN inputs after `first` may therefore produce different
/// prefixes than `vi_prefix_scalar`. Confirm this divergence is acceptable.
///
/// # Safety
/// This function carries no `#[target_feature]`, so the caller must ensure the CPU supports
/// AVX-512F. The caller must also ensure `first < high.len()`, and that `low`, `close` and
/// the three prefix slices are at least `high.len()` long. The vector loads and the tail loop
/// use unchecked accesses.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn vi_prefix_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    pfx_tr: &mut [f64],
    pfx_vp: &mut [f64],
    pfx_vm: &mut [f64],
) {
    use core::arch::x86_64::*;
    let n = high.len();
    // Seed bar: no predecessor, so only its own range contributes.
    pfx_tr[first] = high[first] - low[first];
    pfx_vp[first] = 0.0;
    pfx_vm[first] = 0.0;
    let mut i = first + 1;
    // Running totals carried across vector blocks.
    let mut carry_tr = pfx_tr[i - 1];
    let mut carry_vp = pfx_vp[i - 1];
    let mut carry_vm = pfx_vm[i - 1];
    let step = 8;
    while i + step <= n {
        // Current bars i..i+8 and their predecessors i-1..i+7 (unaligned loads).
        let v_hi = _mm512_loadu_pd(high.as_ptr().add(i));
        let v_lo = _mm512_loadu_pd(low.as_ptr().add(i));
        let v_cl_prev = _mm512_loadu_pd(close.as_ptr().add(i - 1));
        let v_lo_prev = _mm512_loadu_pd(low.as_ptr().add(i - 1));
        let v_hi_prev = _mm512_loadu_pd(high.as_ptr().add(i - 1));

        let hl = _mm512_sub_pd(v_hi, v_lo);
        let hc = abs512(_mm512_sub_pd(v_hi, v_cl_prev));
        let lc = abs512(_mm512_sub_pd(v_lo, v_cl_prev));
        let tr_v = _mm512_max_pd(hl, _mm512_max_pd(hc, lc));
        let vp_v = abs512(_mm512_sub_pd(v_hi, v_lo_prev));
        let vm_v = abs512(_mm512_sub_pd(v_lo, v_hi_prev));

        // Spill the lanes and accumulate them serially; a prefix sum is inherently sequential.
        let mut tr_tmp = [0.0f64; 8];
        let mut vp_tmp = [0.0f64; 8];
        let mut vm_tmp = [0.0f64; 8];
        _mm512_storeu_pd(tr_tmp.as_mut_ptr(), tr_v);
        _mm512_storeu_pd(vp_tmp.as_mut_ptr(), vp_v);
        _mm512_storeu_pd(vm_tmp.as_mut_ptr(), vm_v);
        let mut k = 0;
        while k < step {
            carry_tr += tr_tmp[k];
            carry_vp += vp_tmp[k];
            carry_vm += vm_tmp[k];
            pfx_tr[i + k] = carry_tr;
            pfx_vp[i + k] = carry_vp;
            pfx_vm[i + k] = carry_vm;
            k += 1;
        }
        i += step;
    }
    // Scalar tail for the bars left over after the last full vector block.
    while i < n {
        let hi = *high.get_unchecked(i);
        let lo = *low.get_unchecked(i);
        let prev_c = *close.get_unchecked(i - 1);
        let prev_l = *low.get_unchecked(i - 1);
        let prev_h = *high.get_unchecked(i - 1);
        let hl = hi - lo;
        let hc = (hi - prev_c).abs();
        let lc = (lo - prev_c).abs();
        let tr_i = hl.max(hc.max(lc));
        let vp_i = (hi - prev_l).abs();
        let vm_i = (lo - prev_h).abs();
        carry_tr += tr_i;
        carry_vp += vp_i;
        carry_vm += vm_i;
        pfx_tr[i] = carry_tr;
        pfx_vp[i] = carry_vp;
        pfx_vm[i] = carry_vm;
        i += 1;
    }
}
1318
1319#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1320use serde::{Deserialize, Serialize};
1321#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1322use wasm_bindgen::prelude::*;
1323
/// Result object returned to JavaScript by `vi_js`, serialized with
/// `serde_wasm_bindgen` as `{ plus, minus }`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct ViJsResult {
    /// VI+ series, one value per input bar.
    pub plus: Vec<f64>,
    /// VI- series, one value per input bar.
    pub minus: Vec<f64>,
}
1330
/// Computes the Vortex Indicator for `period` over `high`/`low`/`close`.
///
/// Returns a JS object `{ plus, minus }`, each with one value per input bar.
/// Any `ViError` or serialization failure is surfaced as a JS string error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn vi_js(high: &[f64], low: &[f64], close: &[f64], period: usize) -> Result<JsValue, JsValue> {
    let len = high.len();
    let (mut plus, mut minus) = (vec![0.0; len], vec![0.0; len]);

    if let Err(err) = vi_into_slice_wasm(
        &mut plus,
        &mut minus,
        high,
        low,
        close,
        period,
        detect_best_kernel(),
    ) {
        return Err(JsValue::from_str(&err.to_string()));
    }

    serde_wasm_bindgen::to_value(&ViJsResult { plus, minus })
        .map_err(|err| JsValue::from_str(&err.to_string()))
}
1352
/// Computes VI and returns both series packed into one flat array.
///
/// The first `high.len()` values are VI+ and the next `high.len()` are VI-.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn vi_unified_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
) -> Result<Vec<f64>, JsValue> {
    let n = high.len();
    let mut packed = vec![0.0; n * 2];
    {
        let (plus_half, minus_half) = packed.split_at_mut(n);
        vi_into_slice_wasm(
            plus_half,
            minus_half,
            high,
            low,
            close,
            period,
            detect_best_kernel(),
        )
        .map_err(|err| JsValue::from_str(&err.to_string()))?;
    }
    Ok(packed)
}
1378
/// Allocates an uninitialized buffer with room for `2 * len` `f64` values and
/// hands ownership to the caller.
///
/// The buffer must be released with `vi_free` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn vi_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` keeps the allocation alive after this function returns.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len * 2));
    buffer.as_mut_ptr()
}
1387
/// Releases a buffer previously returned by `vi_alloc` with the same `len`.
///
/// Passing a null pointer is a no-op.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn vi_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` came from `vi_alloc(len)`, i.e. from
    // `Vec::<f64>::with_capacity(len * 2)`, so the capacity matches the
    // allocation. The length is 0 because the buffer's contents may never have
    // been initialized. `Vec::from_raw_parts` requires the first `length`
    // elements to be initialized, and `f64` has no destructor to run.
    unsafe {
        drop(Vec::<f64>::from_raw_parts(ptr, 0, len * 2));
    }
}
1395
/// Computes VI for `period` directly into caller-provided WASM memory.
///
/// Every pointer must reference `len` `f64`s: the inputs readable, the outputs
/// writable.
///
/// An output buffer may be the same buffer as an input (in-place use). In that
/// case the result is computed into scratch memory and copied out afterwards,
/// so a shared and a mutable view of the same memory never coexist.
/// `plus_ptr` and `minus_ptr` must be distinct whenever `len > 0`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn vi_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    plus_ptr: *mut f64,
    minus_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if high_ptr.is_null()
        || low_ptr.is_null()
        || close_ptr.is_null()
        || plus_ptr.is_null()
        || minus_ptr.is_null()
    {
        return Err(JsValue::from_str("Null pointer provided"));
    }
    // Two outputs in one buffer would make one series overwrite the other.
    // Zero-length buffers may legitimately share a dangling pointer, so only
    // reject this when there is something to write.
    if len > 0 && plus_ptr == minus_ptr {
        return Err(JsValue::from_str("plus_ptr and minus_ptr must not alias"));
    }

    let aliases_input =
        |out: *const f64| out == high_ptr || out == low_ptr || out == close_ptr;
    let needs_scratch =
        aliases_input(plus_ptr as *const f64) || aliases_input(minus_ptr as *const f64);

    // SAFETY: the caller guarantees each pointer references `len` valid f64s.
    // In the scratch path, the output slices are created only after the input
    // slices' last use, so shared and mutable borrows never overlap.
    unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);

        if needs_scratch {
            let mut plus_tmp = vec![0.0; len];
            let mut minus_tmp = vec![0.0; len];
            vi_into_slice_wasm(
                &mut plus_tmp,
                &mut minus_tmp,
                high,
                low,
                close,
                period,
                detect_best_kernel(),
            )
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

            std::slice::from_raw_parts_mut(plus_ptr, len).copy_from_slice(&plus_tmp);
            std::slice::from_raw_parts_mut(minus_ptr, len).copy_from_slice(&minus_tmp);
        } else {
            let plus_out = std::slice::from_raw_parts_mut(plus_ptr, len);
            let minus_out = std::slice::from_raw_parts_mut(minus_ptr, len);

            vi_into_slice_wasm(
                plus_out,
                minus_out,
                high,
                low,
                close,
                period,
                detect_best_kernel(),
            )
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(())
}
1438
/// Batch configuration deserialized from the JavaScript `config` argument of
/// `vi_batch`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct ViBatchConfig {
    /// `(start, end, step)` period sweep, forwarded unchanged as
    /// `ViBatchRange::period`.
    pub period_range: (usize, usize, usize),
}
1444
/// Batch result returned to JavaScript by `vi_batch`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct ViBatchJsOutput {
    /// VI+ values for every parameter combination, flattened row by row
    /// (`rows * cols` entries; row `r` starts at index `r * cols`).
    pub plus: Vec<f64>,
    /// VI- values, laid out like `plus`.
    pub minus: Vec<f64>,
    /// Period used for each row, in row order (an unset period is reported as 14).
    pub periods: Vec<usize>,
    /// Number of parameter combinations.
    pub rows: usize,
    /// Values per row (the input series length).
    pub cols: usize,
}
1454
/// Runs a VI period sweep and returns a `ViBatchJsOutput` object to JavaScript.
///
/// `config` must deserialize into `ViBatchConfig`. Invalid configs, compute
/// errors and serialization failures are reported as JS string errors.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = vi_batch)]
pub fn vi_batch_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let ViBatchConfig { period_range } = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = ViBatchRange {
        period: period_range,
    };
    let batch = vi_batch_with_kernel(high, low, close, &sweep, Kernel::Auto)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let mut periods = Vec::with_capacity(batch.combos.len());
    for combo in &batch.combos {
        periods.push(combo.period.unwrap_or(14));
    }

    let js_output = ViBatchJsOutput {
        periods,
        rows: batch.rows,
        cols: batch.cols,
        plus: batch.plus,
        minus: batch.minus,
    };

    serde_wasm_bindgen::to_value(&js_output)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
1489
/// Computes a VI period sweep directly into caller-provided WASM memory.
///
/// Pointer and length requirements:
/// * The input pointers must reference `len` readable `f64`s.
/// * `plus_ptr` and `minus_ptr` must each reference `rows * len` writable
///   `f64`s, where `rows` is the number of periods in
///   `(period_start, period_end, period_step)`.
///
/// Results are written row by row, and `rows` is returned on success.
///
/// An output buffer may be the same buffer as an input. In that case the
/// sweep is computed into scratch memory and copied out afterwards, so a
/// shared and a mutable view of the same memory never coexist. `plus_ptr`
/// and `minus_ptr` must be distinct whenever there is anything to write.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn vi_batch_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    plus_ptr: *mut f64,
    minus_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if high_ptr.is_null()
        || low_ptr.is_null()
        || close_ptr.is_null()
        || plus_ptr.is_null()
        || minus_ptr.is_null()
    {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    let sweep = ViBatchRange {
        period: (period_start, period_end, period_step),
    };
    let rows = expand_grid(&sweep).len();
    let cols = len;
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow in vi_batch_into"))?;

    // Two outputs in one buffer would make one series overwrite the other.
    if total > 0 && plus_ptr == minus_ptr {
        return Err(JsValue::from_str("plus_ptr and minus_ptr must not alias"));
    }

    let aliases_input =
        |out: *const f64| out == high_ptr || out == low_ptr || out == close_ptr;
    let needs_scratch =
        aliases_input(plus_ptr as *const f64) || aliases_input(minus_ptr as *const f64);

    // SAFETY: the caller guarantees the pointer/length contract documented
    // above. In the scratch path, the output slices are created only after the
    // input slices' last use, so shared and mutable borrows never overlap.
    unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);

        if needs_scratch {
            let mut plus_tmp = vec![0.0; total];
            let mut minus_tmp = vec![0.0; total];
            let _ = vi_batch_inner_into(
                high,
                low,
                close,
                &sweep,
                Kernel::Auto,
                false,
                &mut plus_tmp,
                &mut minus_tmp,
            )
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

            std::slice::from_raw_parts_mut(plus_ptr, total).copy_from_slice(&plus_tmp);
            std::slice::from_raw_parts_mut(minus_ptr, total).copy_from_slice(&minus_tmp);
        } else {
            let plus_out = std::slice::from_raw_parts_mut(plus_ptr, total);
            let minus_out = std::slice::from_raw_parts_mut(minus_ptr, total);

            let _ = vi_batch_inner_into(
                high,
                low,
                close,
                &sweep,
                Kernel::Auto,
                false,
                plus_out,
                minus_out,
            )
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(rows)
}
1546
/// WASM helper: builds a slice-backed `ViInput` for `period` and writes VI+
/// and VI- into `dst_plus` / `dst_minus` via `vi_into_slice` with kernel `kern`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
pub fn vi_into_slice_wasm(
    dst_plus: &mut [f64],
    dst_minus: &mut [f64],
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    kern: Kernel,
) -> Result<(), ViError> {
    let input = ViInput::from_slices(
        high,
        low,
        close,
        ViParams {
            period: Some(period),
        },
    );
    vi_into_slice(dst_plus, dst_minus, &input, kern)
}
1563
/// Registers the VI Python bindings (`vi`, `vi_batch`, `ViStream`) on `m`.
#[cfg(feature = "python")]
pub fn register_vi_module(m: &Bound<'_, pyo3::types::PyModule>) -> PyResult<()> {
    let single = wrap_pyfunction!(vi_py, m)?;
    m.add_function(single)?;
    let batch = wrap_pyfunction!(vi_batch_py, m)?;
    m.add_function(batch)?;
    m.add_class::<ViStreamPy>()
}
1571
1572#[cfg(test)]
1573mod tests {
1574 use super::*;
1575 use crate::skip_if_unsupported;
1576 use crate::utilities::data_loader::read_candles_from_csv;
1577
1578 fn check_vi_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1579 skip_if_unsupported!(kernel, test_name);
1580 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1581 let candles = read_candles_from_csv(file_path)?;
1582 let default_params = ViParams { period: None };
1583 let input = ViInput::from_candles(&candles, default_params);
1584 let output = vi_with_kernel(&input, kernel)?;
1585 assert_eq!(output.plus.len(), candles.close.len());
1586 assert_eq!(output.minus.len(), candles.close.len());
1587 Ok(())
1588 }
1589 fn check_vi_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1590 skip_if_unsupported!(kernel, test_name);
1591 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1592 let candles = read_candles_from_csv(file_path)?;
1593 let input = ViInput::from_candles(&candles, ViParams::default());
1594 let result = vi_with_kernel(&input, kernel)?;
1595 let expected_last_five_plus = [
1596 0.9970238095238095,
1597 0.9871071716357775,
1598 0.9464453759945247,
1599 0.890897412369242,
1600 0.9206478557604156,
1601 ];
1602 let expected_last_five_minus = [
1603 1.0097117794486214,
1604 1.04174053182917,
1605 1.1152365471811105,
1606 1.181684712791338,
1607 1.1894672506875827,
1608 ];
1609 let n = result.plus.len();
1610 let plus_slice = &result.plus[n - 5..];
1611 let minus_slice = &result.minus[n - 5..];
1612 for (i, &val) in plus_slice.iter().enumerate() {
1613 let expected = expected_last_five_plus[i];
1614 assert!(
1615 (val - expected).abs() < 1e-8,
1616 "[{}] VI+ mismatch at idx {}: got {}, expected {}",
1617 test_name,
1618 i,
1619 val,
1620 expected
1621 );
1622 }
1623 for (i, &val) in minus_slice.iter().enumerate() {
1624 let expected = expected_last_five_minus[i];
1625 assert!(
1626 (val - expected).abs() < 1e-8,
1627 "[{}] VI- mismatch at idx {}: got {}, expected {}",
1628 test_name,
1629 i,
1630 val,
1631 expected
1632 );
1633 }
1634 Ok(())
1635 }
1636 fn check_vi_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1637 skip_if_unsupported!(kernel, test_name);
1638 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1639 let candles = read_candles_from_csv(file_path)?;
1640 let input = ViInput::with_default_candles(&candles);
1641 let output = vi_with_kernel(&input, kernel)?;
1642 assert_eq!(output.plus.len(), candles.close.len());
1643 assert_eq!(output.minus.len(), candles.close.len());
1644 Ok(())
1645 }
1646 fn check_vi_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1647 skip_if_unsupported!(kernel, test_name);
1648 let input_data = [10.0, 20.0, 30.0];
1649 let params = ViParams { period: Some(0) };
1650 let input = ViInput::from_slices(&input_data, &input_data, &input_data, params);
1651 let res = vi_with_kernel(&input, kernel);
1652 assert!(
1653 res.is_err(),
1654 "[{}] VI should fail with zero period",
1655 test_name
1656 );
1657 Ok(())
1658 }
1659 fn check_vi_period_exceeds_length(
1660 test_name: &str,
1661 kernel: Kernel,
1662 ) -> Result<(), Box<dyn Error>> {
1663 skip_if_unsupported!(kernel, test_name);
1664 let data_small = [10.0, 20.0, 30.0];
1665 let params = ViParams { period: Some(10) };
1666 let input = ViInput::from_slices(&data_small, &data_small, &data_small, params);
1667 let res = vi_with_kernel(&input, kernel);
1668 assert!(
1669 res.is_err(),
1670 "[{}] VI should fail with period exceeding length",
1671 test_name
1672 );
1673 Ok(())
1674 }
1675 fn check_vi_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1676 skip_if_unsupported!(kernel, test_name);
1677 let single_point = [42.0];
1678 let params = ViParams { period: Some(14) };
1679 let input = ViInput::from_slices(&single_point, &single_point, &single_point, params);
1680 let res = vi_with_kernel(&input, kernel);
1681 assert!(
1682 res.is_err(),
1683 "[{}] VI should fail with insufficient data",
1684 test_name
1685 );
1686 Ok(())
1687 }
1688 fn check_vi_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1689 skip_if_unsupported!(kernel, test_name);
1690 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1691 let candles = read_candles_from_csv(file_path)?;
1692 let input = ViInput::from_candles(&candles, ViParams::default());
1693 let res = vi_with_kernel(&input, kernel)?;
1694 assert_eq!(res.plus.len(), candles.close.len());
1695 if res.plus.len() > 20 {
1696 for (i, &val) in res.plus[20..].iter().enumerate() {
1697 assert!(
1698 !val.is_nan(),
1699 "[{}] Found unexpected NaN at out-index {}",
1700 test_name,
1701 20 + i
1702 );
1703 }
1704 }
1705 Ok(())
1706 }
1707
1708 #[cfg(debug_assertions)]
1709 fn check_vi_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1710 skip_if_unsupported!(kernel, test_name);
1711
1712 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1713 let candles = read_candles_from_csv(file_path)?;
1714
1715 let test_params = vec![
1716 ViParams::default(),
1717 ViParams { period: Some(1) },
1718 ViParams { period: Some(2) },
1719 ViParams { period: Some(5) },
1720 ViParams { period: Some(7) },
1721 ViParams { period: Some(10) },
1722 ViParams { period: Some(20) },
1723 ViParams { period: Some(30) },
1724 ViParams { period: Some(50) },
1725 ViParams { period: Some(100) },
1726 ViParams { period: Some(200) },
1727 ];
1728
1729 for (param_idx, params) in test_params.iter().enumerate() {
1730 let input = ViInput::from_candles(&candles, params.clone());
1731 let output = vi_with_kernel(&input, kernel)?;
1732
1733 for (i, &val) in output.plus.iter().enumerate() {
1734 if val.is_nan() {
1735 continue;
1736 }
1737
1738 let bits = val.to_bits();
1739
1740 if bits == 0x11111111_11111111 {
1741 panic!(
1742 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} in plus array \
1743 with params: period={} (param set {})",
1744 test_name,
1745 val,
1746 bits,
1747 i,
1748 params.period.unwrap_or(14),
1749 param_idx
1750 );
1751 }
1752
1753 if bits == 0x22222222_22222222 {
1754 panic!(
1755 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} in plus array \
1756 with params: period={} (param set {})",
1757 test_name,
1758 val,
1759 bits,
1760 i,
1761 params.period.unwrap_or(14),
1762 param_idx
1763 );
1764 }
1765
1766 if bits == 0x33333333_33333333 {
1767 panic!(
1768 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} in plus array \
1769 with params: period={} (param set {})",
1770 test_name,
1771 val,
1772 bits,
1773 i,
1774 params.period.unwrap_or(14),
1775 param_idx
1776 );
1777 }
1778 }
1779
1780 for (i, &val) in output.minus.iter().enumerate() {
1781 if val.is_nan() {
1782 continue;
1783 }
1784
1785 let bits = val.to_bits();
1786
1787 if bits == 0x11111111_11111111 {
1788 panic!(
1789 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} in minus array \
1790 with params: period={} (param set {})",
1791 test_name,
1792 val,
1793 bits,
1794 i,
1795 params.period.unwrap_or(14),
1796 param_idx
1797 );
1798 }
1799
1800 if bits == 0x22222222_22222222 {
1801 panic!(
1802 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} in minus array \
1803 with params: period={} (param set {})",
1804 test_name,
1805 val,
1806 bits,
1807 i,
1808 params.period.unwrap_or(14),
1809 param_idx
1810 );
1811 }
1812
1813 if bits == 0x33333333_33333333 {
1814 panic!(
1815 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} in minus array \
1816 with params: period={} (param set {})",
1817 test_name,
1818 val,
1819 bits,
1820 i,
1821 params.period.unwrap_or(14),
1822 param_idx
1823 );
1824 }
1825 }
1826 }
1827
1828 Ok(())
1829 }
1830
1831 #[cfg(not(debug_assertions))]
1832 fn check_vi_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1833 Ok(())
1834 }
1835
1836 #[cfg(feature = "proptest")]
1837 #[allow(clippy::float_cmp)]
1838 fn check_vi_property(
1839 test_name: &str,
1840 kernel: Kernel,
1841 ) -> Result<(), Box<dyn std::error::Error>> {
1842 use proptest::prelude::*;
1843 skip_if_unsupported!(kernel, test_name);
1844
1845 let strat = (2usize..=100).prop_flat_map(|period| {
1846 (period + 50..400).prop_flat_map(move |len| {
1847 (
1848 prop::collection::vec(
1849 (50.0f64..500.0f64).prop_filter("finite", |x| x.is_finite()),
1850 len,
1851 ),
1852 prop::collection::vec((0.001f64..0.05f64), len),
1853 prop::collection::vec((0.0f64..1.0f64), len),
1854 Just(period),
1855 )
1856 })
1857 });
1858
1859 proptest::test_runner::TestRunner::default()
1860 .run(
1861 &strat,
1862 |(base_prices, volatilities, close_positions, period)| {
1863 let mut high = Vec::with_capacity(base_prices.len());
1864 let mut low = Vec::with_capacity(base_prices.len());
1865 let mut close = Vec::with_capacity(base_prices.len());
1866
1867 assert_eq!(base_prices.len(), volatilities.len());
1868 assert_eq!(base_prices.len(), close_positions.len());
1869
1870 for i in 0..base_prices.len() {
1871 let price = base_prices[i];
1872 let vol = volatilities[i];
1873 let close_pos = close_positions[i];
1874
1875 let range = price * vol;
1876 let h = price + range * (0.3 + vol * 2.0);
1877 let l = price - range * (0.3 + vol * 2.0);
1878 let c = l + (h - l) * close_pos;
1879
1880 high.push(h);
1881 low.push(l);
1882 close.push(c);
1883 }
1884
1885 let params = ViParams {
1886 period: Some(period),
1887 };
1888 let input = ViInput::from_slices(&high, &low, &close, params.clone());
1889
1890 let ViOutput {
1891 plus: out_plus,
1892 minus: out_minus,
1893 } = vi_with_kernel(&input, kernel).unwrap();
1894 let ViOutput {
1895 plus: ref_plus,
1896 minus: ref_minus,
1897 } = vi_with_kernel(&input, Kernel::Scalar).unwrap();
1898
1899 for i in 0..out_plus.len() {
1900 if !out_plus[i].is_nan() {
1901 prop_assert!(
1902 out_plus[i] >= -1e-9,
1903 "[{}] VI+ negative at idx {}: {}",
1904 test_name,
1905 i,
1906 out_plus[i]
1907 );
1908 }
1909 if !out_minus[i].is_nan() {
1910 prop_assert!(
1911 out_minus[i] >= -1e-9,
1912 "[{}] VI- negative at idx {}: {}",
1913 test_name,
1914 i,
1915 out_minus[i]
1916 );
1917 }
1918 }
1919
1920 let first_valid = (0..high.len())
1921 .find(|&i| !high[i].is_nan() && !low[i].is_nan() && !close[i].is_nan())
1922 .unwrap_or(0);
1923 let warmup_end = first_valid + period - 1;
1924
1925 for i in 0..warmup_end.min(out_plus.len()) {
1926 prop_assert!(
1927 out_plus[i].is_nan(),
1928 "[{}] Expected NaN during warmup at idx {}, got {}",
1929 test_name,
1930 i,
1931 out_plus[i]
1932 );
1933 prop_assert!(
1934 out_minus[i].is_nan(),
1935 "[{}] Expected NaN during warmup at idx {}, got {}",
1936 test_name,
1937 i,
1938 out_minus[i]
1939 );
1940 }
1941
1942 for i in warmup_end..out_plus.len() {
1943 let plus_bits = out_plus[i].to_bits();
1944 let ref_plus_bits = ref_plus[i].to_bits();
1945 let minus_bits = out_minus[i].to_bits();
1946 let ref_minus_bits = ref_minus[i].to_bits();
1947
1948 if !out_plus[i].is_finite() || !ref_plus[i].is_finite() {
1949 prop_assert!(
1950 plus_bits == ref_plus_bits,
1951 "[{}] VI+ finite/NaN mismatch at idx {}: {} vs {}",
1952 test_name,
1953 i,
1954 out_plus[i],
1955 ref_plus[i]
1956 );
1957 } else {
1958 let ulp_diff = plus_bits.abs_diff(ref_plus_bits);
1959 prop_assert!(
1960 (out_plus[i] - ref_plus[i]).abs() <= 1e-9 || ulp_diff <= 4,
1961 "[{}] VI+ mismatch at idx {}: {} vs {} (ULP={})",
1962 test_name,
1963 i,
1964 out_plus[i],
1965 ref_plus[i],
1966 ulp_diff
1967 );
1968 }
1969
1970 if !out_minus[i].is_finite() || !ref_minus[i].is_finite() {
1971 prop_assert!(
1972 minus_bits == ref_minus_bits,
1973 "[{}] VI- finite/NaN mismatch at idx {}: {} vs {}",
1974 test_name,
1975 i,
1976 out_minus[i],
1977 ref_minus[i]
1978 );
1979 } else {
1980 let ulp_diff = minus_bits.abs_diff(ref_minus_bits);
1981 prop_assert!(
1982 (out_minus[i] - ref_minus[i]).abs() <= 1e-9 || ulp_diff <= 4,
1983 "[{}] VI- mismatch at idx {}: {} vs {} (ULP={})",
1984 test_name,
1985 i,
1986 out_minus[i],
1987 ref_minus[i],
1988 ulp_diff
1989 );
1990 }
1991 }
1992
1993 if period == 1 {
1994 if warmup_end < out_plus.len() {
1995 prop_assert!(
1996 out_plus[warmup_end].is_finite(),
1997 "[{}] VI+ should be finite for period=1 at idx {}",
1998 test_name,
1999 warmup_end
2000 );
2001 prop_assert!(
2002 out_minus[warmup_end].is_finite(),
2003 "[{}] VI- should be finite for period=1 at idx {}",
2004 test_name,
2005 warmup_end
2006 );
2007 }
2008 }
2009
2010 if period <= 5 && warmup_end + 5 < high.len() && warmup_end >= period {
2011 let idx = warmup_end;
2012
2013 let mut tr_sum = 0.0;
2014 let mut vp_sum = 0.0;
2015 let mut vm_sum = 0.0;
2016
2017 let first_idx = idx + 1 - period;
2018 tr_sum += high[first_idx] - low[first_idx];
2019
2020 for j in (first_idx + 1)..=idx {
2021 let tr = (high[j] - low[j])
2022 .max((high[j] - close[j - 1]).abs())
2023 .max((low[j] - close[j - 1]).abs());
2024 let vp = (high[j] - low[j - 1]).abs();
2025 let vm = (low[j] - high[j - 1]).abs();
2026
2027 tr_sum += tr;
2028 vp_sum += vp;
2029 vm_sum += vm;
2030 }
2031
2032 if tr_sum > 1e-10 {
2033 let expected_plus = vp_sum / tr_sum;
2034 let expected_minus = vm_sum / tr_sum;
2035
2036 prop_assert!(
2037 (out_plus[idx] - expected_plus).abs() < 1e-6,
2038 "[{}] VI+ formula verification failed at idx {}: {} vs {}",
2039 test_name,
2040 idx,
2041 out_plus[idx],
2042 expected_plus
2043 );
2044 prop_assert!(
2045 (out_minus[idx] - expected_minus).abs() < 1e-6,
2046 "[{}] VI- formula verification failed at idx {}: {} vs {}",
2047 test_name,
2048 idx,
2049 out_minus[idx],
2050 expected_minus
2051 );
2052 }
2053 }
2054
2055 #[cfg(debug_assertions)]
2056 {
2057 for i in 0..out_plus.len() {
2058 if !out_plus[i].is_nan() {
2059 let bits = out_plus[i].to_bits();
2060 prop_assert!(
2061 bits != 0x11111111_11111111
2062 && bits != 0x22222222_22222222
2063 && bits != 0x33333333_33333333,
2064 "[{}] Found poison value in VI+ at idx {}: 0x{:016X}",
2065 test_name,
2066 i,
2067 bits
2068 );
2069 }
2070 if !out_minus[i].is_nan() {
2071 let bits = out_minus[i].to_bits();
2072 prop_assert!(
2073 bits != 0x11111111_11111111
2074 && bits != 0x22222222_22222222
2075 && bits != 0x33333333_33333333,
2076 "[{}] Found poison value in VI- at idx {}: 0x{:016X}",
2077 test_name,
2078 i,
2079 bits
2080 );
2081 }
2082 }
2083 }
2084
2085 Ok(())
2086 },
2087 )
2088 .unwrap();
2089
2090 Ok(())
2091 }
2092
2093 macro_rules! generate_all_vi_tests {
2094 ($($test_fn:ident),*) => {
2095 paste::paste! {
2096 $(
2097 #[test]
2098 fn [<$test_fn _scalar_f64>]() {
2099 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
2100 }
2101 )*
2102 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2103 $(
2104 #[test]
2105 fn [<$test_fn _avx2_f64>]() {
2106 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
2107 }
2108 #[test]
2109 fn [<$test_fn _avx512_f64>]() {
2110 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
2111 }
2112 )*
2113 }
2114 }
2115 }
2116 generate_all_vi_tests!(
2117 check_vi_partial_params,
2118 check_vi_accuracy,
2119 check_vi_default_candles,
2120 check_vi_zero_period,
2121 check_vi_period_exceeds_length,
2122 check_vi_very_small_dataset,
2123 check_vi_nan_handling,
2124 check_vi_no_poison
2125 );
2126
2127 #[cfg(feature = "proptest")]
2128 generate_all_vi_tests!(check_vi_property);
2129 fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2130 skip_if_unsupported!(kernel, test);
2131 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2132 let c = read_candles_from_csv(file)?;
2133 let output = ViBatchBuilder::new().kernel(kernel).apply_candles(&c)?;
2134 let def = ViParams::default();
2135 let row = output.plus_for(&def).expect("default row missing");
2136 assert_eq!(row.len(), c.close.len());
2137 Ok(())
2138 }
2139
2140 #[cfg(debug_assertions)]
2141 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2142 skip_if_unsupported!(kernel, test);
2143
2144 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2145 let c = read_candles_from_csv(file)?;
2146
2147 let test_configs = vec![
2148 (2, 10, 2),
2149 (5, 25, 5),
2150 (20, 50, 10),
2151 (2, 5, 1),
2152 (14, 14, 0),
2153 (30, 60, 15),
2154 (50, 100, 25),
2155 (100, 200, 50),
2156 ];
2157
2158 for (cfg_idx, &(p_start, p_end, p_step)) in test_configs.iter().enumerate() {
2159 let output = ViBatchBuilder::new()
2160 .kernel(kernel)
2161 .period_range(p_start, p_end, p_step)
2162 .apply_candles(&c)?;
2163
2164 for (idx, &val) in output.plus.iter().enumerate() {
2165 if val.is_nan() {
2166 continue;
2167 }
2168
2169 let bits = val.to_bits();
2170 let row = idx / output.cols;
2171 let col = idx % output.cols;
2172 let combo = &output.combos[row];
2173
2174 if bits == 0x11111111_11111111 {
2175 panic!(
2176 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
2177 at row {} col {} (flat index {}) in plus array with params: period={}",
2178 test,
2179 cfg_idx,
2180 val,
2181 bits,
2182 row,
2183 col,
2184 idx,
2185 combo.period.unwrap_or(14)
2186 );
2187 }
2188
2189 if bits == 0x22222222_22222222 {
2190 panic!(
2191 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
2192 at row {} col {} (flat index {}) in plus array with params: period={}",
2193 test,
2194 cfg_idx,
2195 val,
2196 bits,
2197 row,
2198 col,
2199 idx,
2200 combo.period.unwrap_or(14)
2201 );
2202 }
2203
2204 if bits == 0x33333333_33333333 {
2205 panic!(
2206 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
2207 at row {} col {} (flat index {}) in plus array with params: period={}",
2208 test,
2209 cfg_idx,
2210 val,
2211 bits,
2212 row,
2213 col,
2214 idx,
2215 combo.period.unwrap_or(14)
2216 );
2217 }
2218 }
2219
2220 for (idx, &val) in output.minus.iter().enumerate() {
2221 if val.is_nan() {
2222 continue;
2223 }
2224
2225 let bits = val.to_bits();
2226 let row = idx / output.cols;
2227 let col = idx % output.cols;
2228 let combo = &output.combos[row];
2229
2230 if bits == 0x11111111_11111111 {
2231 panic!(
2232 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
2233 at row {} col {} (flat index {}) in minus array with params: period={}",
2234 test,
2235 cfg_idx,
2236 val,
2237 bits,
2238 row,
2239 col,
2240 idx,
2241 combo.period.unwrap_or(14)
2242 );
2243 }
2244
2245 if bits == 0x22222222_22222222 {
2246 panic!(
2247 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
2248 at row {} col {} (flat index {}) in minus array with params: period={}",
2249 test,
2250 cfg_idx,
2251 val,
2252 bits,
2253 row,
2254 col,
2255 idx,
2256 combo.period.unwrap_or(14)
2257 );
2258 }
2259
2260 if bits == 0x33333333_33333333 {
2261 panic!(
2262 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
2263 at row {} col {} (flat index {}) in minus array with params: period={}",
2264 test,
2265 cfg_idx,
2266 val,
2267 bits,
2268 row,
2269 col,
2270 idx,
2271 combo.period.unwrap_or(14)
2272 );
2273 }
2274 }
2275 }
2276
2277 Ok(())
2278 }
2279
2280 #[cfg(not(debug_assertions))]
2281 fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2282 Ok(())
2283 }
2284 macro_rules! gen_batch_tests {
2285 ($fn_name:ident) => {
2286 paste::paste! {
2287 #[test] fn [<$fn_name _scalar>]() {
2288 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
2289 }
2290 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2291 #[test] fn [<$fn_name _avx2>]() {
2292 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
2293 }
2294 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2295 #[test] fn [<$fn_name _avx512>]() {
2296 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
2297 }
2298 #[test] fn [<$fn_name _auto_detect>]() {
2299 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
2300 }
2301 }
2302 };
2303 }
2304 gen_batch_tests!(check_batch_default_row);
2305 gen_batch_tests!(check_batch_no_poison);
2306}
2307
/// Python binding: computes VI for `period` and returns `{"plus": ..., "minus": ...}`.
///
/// Raises `ValueError` if the three inputs differ in length, if the kernel
/// name is invalid, or if the computation fails. The GIL is released while
/// computing.
#[cfg(feature = "python")]
#[pyfunction(name = "vi")]
#[pyo3(signature = (high, low, close, period, kernel=None))]
pub fn vi_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let (h, l, c) = (high.as_slice()?, low.as_slice()?, close.as_slice()?);
    let len = h.len();

    if l.len() != len || c.len() != len {
        return Err(PyValueError::new_err(format!(
            "Input data length mismatch: high={}, low={}, close={}",
            len,
            l.len(),
            c.len()
        )));
    }

    let input = ViInput::from_slices(
        h,
        l,
        c,
        ViParams {
            period: Some(period),
        },
    );
    let kern = validate_kernel(kernel, false)?;

    let computed: Result<(Vec<f64>, Vec<f64>), ViError> = py.allow_threads(|| {
        let mut plus = vec![0.0; len];
        let mut minus = vec![0.0; len];
        vi_into_slice(&mut plus, &mut minus, &input, kern)?;
        Ok((plus, minus))
    });
    let (plus, minus) = computed.map_err(|e| PyValueError::new_err(e.to_string()))?;

    let out = PyDict::new(py);
    out.set_item("plus", plus.into_pyarray(py))?;
    out.set_item("minus", minus.into_pyarray(py))?;
    Ok(out)
}
2351
/// Python binding: VI over a `(start, end, step)` period sweep.
///
/// Returns a dict with:
/// * `plus` / `minus`: `(rows, cols)` arrays,
/// * `periods`: the period used for each row.
///
/// The GIL is released while computing.
#[cfg(feature = "python")]
#[pyfunction(name = "vi_batch")]
#[pyo3(signature = (high, low, close, period_range, kernel=None))]
pub fn vi_batch_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::PyArray2;

    let h = high.as_slice()?;
    let l = low.as_slice()?;
    let c = close.as_slice()?;

    let sweep = ViBatchRange {
        period: period_range,
    };
    let kern = validate_kernel(kernel, true)?;

    let combos = expand_grid(&sweep);
    let (rows, cols) = (combos.len(), h.len());
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow in vi_batch_py"))?;

    // Output buffers are allocated uninitialized and filled by the batch core.
    let out_plus = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let out_minus = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_plus = unsafe { out_plus.as_slice_mut()? };
    let slice_minus = unsafe { out_minus.as_slice_mut()? };

    py.allow_threads(|| {
        // Map the batch selector to the per-row kernel. With `Auto`, an
        // AVX-512 batch pick is demoted to AVX2.
        let row_kernel = match kern {
            Kernel::Avx512Batch => Kernel::Avx512,
            Kernel::Avx2Batch => Kernel::Avx2,
            Kernel::ScalarBatch => Kernel::Scalar,
            Kernel::Auto => match detect_best_batch_kernel() {
                Kernel::Avx512Batch | Kernel::Avx2Batch => Kernel::Avx2,
                other => other.to_scalar_equivalent(),
            },
            _ => Kernel::Scalar,
        };
        vi_batch_inner_into(h, l, c, &sweep, row_kernel, true, slice_plus, slice_minus)
    })
    .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let d = PyDict::new(py);
    d.set_item("plus", out_plus.reshape((rows, cols))?)?;
    d.set_item("minus", out_minus.reshape((rows, cols))?)?;
    let periods: Vec<u64> = combos
        .iter()
        .map(|p| p.period.unwrap_or(14) as u64)
        .collect();
    d.set_item("periods", periods.into_pyarray(py))?;
    Ok(d)
}
2416
/// Python-facing streaming VI, exposed as `ViStream`.
///
/// Wraps the Rust `ViStream` and remembers the previous bar's high, low and
/// close. `update` forwards these to the inner stream together with the new bar.
#[cfg(feature = "python")]
#[pyclass(name = "ViStream")]
pub struct ViStreamPy {
    stream: ViStream,
    // Previous bar's values; `None` until the first `update` call.
    prev_high: Option<f64>,
    prev_low: Option<f64>,
    prev_close: Option<f64>,
}
2425
#[cfg(feature = "python")]
#[pymethods]
impl ViStreamPy {
    /// Creates a streaming VI calculator for `period`.
    ///
    /// Raises `ValueError` if `ViStream::try_new` rejects the parameters.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        let params = ViParams {
            period: Some(period),
        };
        match ViStream::try_new(params) {
            Ok(stream) => Ok(Self {
                stream,
                prev_high: None,
                prev_low: None,
                prev_close: None,
            }),
            Err(e) => Err(PyValueError::new_err(e.to_string())),
        }
    }

    /// Feeds one bar.
    ///
    /// The first bar only primes the previous-bar state and returns `None`.
    /// Later bars return whatever the inner stream produces for the new bar
    /// combined with the previous one.
    fn update(&mut self, high: f64, low: f64, close: f64) -> Option<(f64, f64)> {
        let result = match (self.prev_high, self.prev_low, self.prev_close) {
            (Some(ph), Some(pl), Some(pc)) => self.stream.update(high, low, close, pl, ph, pc),
            _ => None,
        };
        self.prev_high = Some(high);
        self.prev_low = Some(low);
        self.prev_close = Some(close);
        result
    }
}
2461
/// Converts a kernel selector into its single-series (non-batch) counterpart.
#[cfg(feature = "python")]
trait BatchToScalar {
    /// Returns the per-series kernel matching `self`.
    fn to_scalar_equivalent(self) -> Kernel;
}
#[cfg(feature = "python")]
impl BatchToScalar for Kernel {
    fn to_scalar_equivalent(self) -> Kernel {
        match self {
            // `Auto` has no concrete SIMD level here, so it falls back to scalar.
            Kernel::ScalarBatch | Kernel::Auto => Kernel::Scalar,
            Kernel::Avx2Batch => Kernel::Avx2,
            Kernel::Avx512Batch => Kernel::Avx512,
            // Kernels that are already non-batch pass through unchanged.
            other => other,
        }
    }
}
2478
2479#[cfg(all(feature = "python", feature = "cuda"))]
2480use crate::cuda::cuda_available;
2481#[cfg(all(feature = "python", feature = "cuda"))]
2482use crate::cuda::vi_wrapper::CudaVi;
2483#[cfg(all(feature = "python", feature = "cuda"))]
2484use crate::utilities::dlpack_cuda::DeviceArrayF32Py;
2485
/// Runs the VI parameter sweep on the GPU and returns device-resident results.
///
/// `high_f32`, `low_f32` and `close_f32` must be contiguous 1-D `float32`
/// arrays of equal length. `period_range` is forwarded unchanged as
/// `ViBatchRange::period`.
///
/// The returned dict contains:
/// - `"plus"` / `"minus"`: device arrays produced by `CudaVi::vi_batch_dev`
/// - `"rows"`: number of parameter combinations
/// - `"cols"`: input length
/// - `"periods"`: the period of each combination (14 when unset)
///
/// # Errors
/// Raises `ValueError` when CUDA is unavailable, when the input lengths
/// differ, or when the CUDA wrapper fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "vi_cuda_batch_dev")]
#[pyo3(signature = (high_f32, low_f32, close_f32, period_range, device_id=0))]
pub fn vi_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    high_f32: numpy::PyReadonlyArray1<'py, f32>,
    low_f32: numpy::PyReadonlyArray1<'py, f32>,
    close_f32: numpy::PyReadonlyArray1<'py, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::IntoPyArray;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let high = high_f32.as_slice()?;
    let low = low_f32.as_slice()?;
    let close = close_f32.as_slice()?;
    let len = high.len();
    if low.len() != len || close.len() != len {
        return Err(PyValueError::new_err("Input data length mismatch"));
    }

    let sweep = ViBatchRange {
        period: period_range,
    };

    // GPU work runs with the GIL released. The context handle and device id
    // are captured so the returned device arrays can carry them.
    let (outputs, context, dev) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaVi::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let context = cuda.context_arc();
        let dev = cuda.device_id();
        let outputs = cuda
            .vi_batch_dev(high, low, close, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((outputs, context, dev))
    })?;
    let (pair, combos) = outputs;

    let out = PyDict::new(py);

    let plus = DeviceArrayF32Py {
        inner: pair.a,
        _ctx: Some(context.clone()),
        device_id: Some(dev),
    };
    out.set_item("plus", Py::new(py, plus)?)?;

    let minus = DeviceArrayF32Py {
        inner: pair.b,
        _ctx: Some(context),
        device_id: Some(dev),
    };
    out.set_item("minus", Py::new(py, minus)?)?;

    out.set_item("rows", combos.len())?;
    out.set_item("cols", len)?;

    let periods: Vec<u64> = combos
        .iter()
        .map(|p| p.period.unwrap_or(14) as u64)
        .collect();
    out.set_item("periods", periods.into_pyarray(py))?;

    Ok(out)
}
2553
/// Computes VI for many series sharing one `period` on the GPU.
///
/// The three inputs must be C-contiguous 2-D `float32` arrays of identical
/// shape `(rows, cols)`. The `_tm` suffix and the `time_major` wrapper name
/// suggest rows index time steps and cols index series (TODO confirm). The
/// wrapper receives `cols` before `rows`.
///
/// The returned dict contains `"plus"` / `"minus"` device arrays, plus the
/// `"rows"`, `"cols"` and `"period"` that were used.
///
/// # Errors
/// Raises `ValueError` when CUDA is unavailable, when the shapes disagree, or
/// when the CUDA wrapper fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "vi_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm_f32, low_tm_f32, close_tm_f32, period, device_id=0))]
pub fn vi_cuda_many_series_one_param_dev_py<'py>(
    py: Python<'py>,
    high_tm_f32: numpy::PyReadonlyArray2<'py, f32>,
    low_tm_f32: numpy::PyReadonlyArray2<'py, f32>,
    close_tm_f32: numpy::PyReadonlyArray2<'py, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::PyUntypedArrayMethods;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let shape = high_tm_f32.shape();
    // Defensive guard; `PyReadonlyArray2` extraction should already enforce ndim == 2.
    if shape.len() != 2 {
        return Err(PyValueError::new_err("expected 2D array for high"));
    }
    if low_tm_f32.shape() != shape || close_tm_f32.shape() != shape {
        return Err(PyValueError::new_err(
            "input arrays must share the same shape",
        ));
    }
    let rows = shape[0];
    let cols = shape[1];
    // `as_slice` fails for non-contiguous arrays, which the flat GPU layout requires.
    let h = high_tm_f32.as_slice()?;
    let l = low_tm_f32.as_slice()?;
    let c = close_tm_f32.as_slice()?;
    let params = ViParams {
        period: Some(period),
    };
    // GPU work runs with the GIL released. The context handle and device id
    // are captured so the returned device arrays can carry them.
    let (pair, ctx, dev_id) = py.allow_threads(|| {
        let cuda = CudaVi::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id();
        // Fixed: the params reference was mis-encoded as `¶ms`, which does not lex.
        cuda.vi_many_series_one_param_time_major_dev(h, l, c, cols, rows, &params)
            .map(|res| (res, ctx, dev_id))
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;
    let dict = PyDict::new(py);
    dict.set_item(
        "plus",
        Py::new(
            py,
            DeviceArrayF32Py {
                inner: pair.a,
                _ctx: Some(ctx.clone()),
                device_id: Some(dev_id),
            },
        )?,
    )?;
    dict.set_item(
        "minus",
        Py::new(
            py,
            DeviceArrayF32Py {
                inner: pair.b,
                _ctx: Some(ctx),
                device_id: Some(dev_id),
            },
        )?,
    )?;
    dict.set_item("rows", rows)?;
    dict.set_item("cols", cols)?;
    dict.set_item("period", period)?;
    Ok(dict)
}