Skip to main content

vector_ta/indicators/
vi.rs

1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{
4    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
5    make_uninit_matrix,
6};
7#[cfg(feature = "python")]
8use crate::utilities::kernel_validation::validate_kernel;
9use aligned_vec::{AVec, CACHELINE_ALIGN};
10#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
11use core::arch::x86_64::*;
12#[cfg(feature = "python")]
13use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
14#[cfg(feature = "python")]
15use pyo3::exceptions::PyValueError;
16#[cfg(feature = "python")]
17use pyo3::prelude::*;
18#[cfg(feature = "python")]
19use pyo3::types::PyDict;
20#[cfg(not(target_arch = "wasm32"))]
21use rayon::prelude::*;
22use std::error::Error;
23use std::mem::MaybeUninit;
24use thiserror::Error;
25
26#[derive(Debug, Clone)]
27pub enum ViData<'a> {
28    Candles {
29        candles: &'a Candles,
30    },
31    Slices {
32        high: &'a [f64],
33        low: &'a [f64],
34        close: &'a [f64],
35    },
36}
37
/// Result of a single-series VI computation.
///
/// Both vectors have the input length; `plus` holds VI+ and `minus` holds VI-.
#[derive(Debug, Clone)]
pub struct ViOutput {
    /// VI+ values, one per input bar.
    pub plus: Vec<f64>,
    /// VI- values, one per input bar.
    pub minus: Vec<f64>,
}
43
/// Parameters of the Vortex Indicator.
///
/// `period` is the look-back window length; `None` means "use the default"
/// (14), and [`Default`] sets it to `Some(14)` explicitly.
#[derive(Debug, Clone)]
pub struct ViParams {
    pub period: Option<usize>,
}

impl Default for ViParams {
    fn default() -> Self {
        ViParams { period: Some(14) }
    }
}
54
55#[derive(Debug, Clone)]
56pub struct ViInput<'a> {
57    pub data: ViData<'a>,
58    pub params: ViParams,
59}
60
61impl<'a> ViInput<'a> {
62    #[inline]
63    pub fn from_candles(candles: &'a Candles, params: ViParams) -> Self {
64        Self {
65            data: ViData::Candles { candles },
66            params,
67        }
68    }
69    #[inline]
70    pub fn from_slices(
71        high: &'a [f64],
72        low: &'a [f64],
73        close: &'a [f64],
74        params: ViParams,
75    ) -> Self {
76        Self {
77            data: ViData::Slices { high, low, close },
78            params,
79        }
80    }
81    #[inline]
82    pub fn with_default_candles(candles: &'a Candles) -> Self {
83        Self {
84            data: ViData::Candles { candles },
85            params: ViParams::default(),
86        }
87    }
88    #[inline]
89    pub fn get_period(&self) -> usize {
90        self.params.period.unwrap_or(14)
91    }
92}
93
94#[derive(Copy, Clone, Debug)]
95pub struct ViBuilder {
96    period: Option<usize>,
97    kernel: Kernel,
98}
99
100impl Default for ViBuilder {
101    fn default() -> Self {
102        Self {
103            period: None,
104            kernel: Kernel::Auto,
105        }
106    }
107}
108
109impl ViBuilder {
110    #[inline(always)]
111    pub fn new() -> Self {
112        Self::default()
113    }
114    #[inline(always)]
115    pub fn period(mut self, n: usize) -> Self {
116        self.period = Some(n);
117        self
118    }
119    #[inline(always)]
120    pub fn kernel(mut self, k: Kernel) -> Self {
121        self.kernel = k;
122        self
123    }
124    #[inline(always)]
125    pub fn apply(self, c: &Candles) -> Result<ViOutput, ViError> {
126        let p = ViParams {
127            period: self.period,
128        };
129        let i = ViInput::from_candles(c, p);
130        vi_with_kernel(&i, self.kernel)
131    }
132    #[inline(always)]
133    pub fn apply_slices(
134        self,
135        high: &[f64],
136        low: &[f64],
137        close: &[f64],
138    ) -> Result<ViOutput, ViError> {
139        let p = ViParams {
140            period: self.period,
141        };
142        let i = ViInput::from_slices(high, low, close, p);
143        vi_with_kernel(&i, self.kernel)
144    }
145    #[inline(always)]
146    pub fn into_stream(self) -> Result<ViStream, ViError> {
147        let p = ViParams {
148            period: self.period,
149        };
150        ViStream::try_new(p)
151    }
152}
153
154#[derive(Debug, Error)]
155pub enum ViError {
156    #[error("vi: Empty data provided.")]
157    EmptyInputData,
158    #[error("vi: All values are NaN.")]
159    AllValuesNaN,
160    #[error("vi: Invalid period: period = {period}, data length = {data_len}")]
161    InvalidPeriod { period: usize, data_len: usize },
162    #[error("vi: Not enough valid data: needed = {needed}, valid = {valid}")]
163    NotEnoughValidData { needed: usize, valid: usize },
164    #[error("vi: output length mismatch: expected = {expected}, got = {got}")]
165    OutputLengthMismatch { expected: usize, got: usize },
166    #[error("vi: invalid range: start={start}, end={end}, step={step}")]
167    InvalidRange {
168        start: usize,
169        end: usize,
170        step: usize,
171    },
172    #[error("vi: invalid kernel for batch: {0:?}")]
173    InvalidKernelForBatch(Kernel),
174    #[error("vi: invalid input: {0}")]
175    InvalidInput(String),
176}
177
178#[inline]
179pub fn vi(input: &ViInput) -> Result<ViOutput, ViError> {
180    vi_with_kernel(input, Kernel::Auto)
181}
182
183#[inline(always)]
184fn vi_prepare<'a>(
185    input: &'a ViInput,
186    kernel: Kernel,
187) -> Result<(&'a [f64], &'a [f64], &'a [f64], usize, usize, Kernel), ViError> {
188    let (high, low, close) = match &input.data {
189        ViData::Candles { candles } => (
190            source_type(candles, "high"),
191            source_type(candles, "low"),
192            source_type(candles, "close"),
193        ),
194        ViData::Slices { high, low, close } => (*high, *low, *close),
195    };
196
197    if high.is_empty() || low.is_empty() || close.is_empty() {
198        return Err(ViError::EmptyInputData);
199    }
200    let len = high.len();
201    if len != low.len() || len != close.len() {
202        return Err(ViError::EmptyInputData);
203    }
204
205    let period = input.get_period();
206    if period == 0 || period > len {
207        return Err(ViError::InvalidPeriod {
208            period,
209            data_len: len,
210        });
211    }
212
213    let first = (0..len)
214        .find(|&i| !high[i].is_nan() && !low[i].is_nan() && !close[i].is_nan())
215        .ok_or(ViError::AllValuesNaN)?;
216
217    if len - first < period {
218        return Err(ViError::NotEnoughValidData {
219            needed: period,
220            valid: len - first,
221        });
222    }
223
224    let chosen = match kernel {
225        Kernel::Auto => detect_best_kernel(),
226        k => k,
227    };
228    Ok((high, low, close, period, first, chosen))
229}
230
231#[inline(always)]
232fn vi_compute_into(
233    high: &[f64],
234    low: &[f64],
235    close: &[f64],
236    period: usize,
237    first: usize,
238    kernel: Kernel,
239    plus: &mut [f64],
240    minus: &mut [f64],
241) {
242    unsafe {
243        match kernel {
244            Kernel::Scalar | Kernel::ScalarBatch => {
245                vi_scalar(high, low, close, period, first, plus, minus)
246            }
247            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
248            Kernel::Avx2 | Kernel::Avx2Batch => {
249                vi_avx2(high, low, close, period, first, plus, minus)
250            }
251            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
252            Kernel::Avx512 | Kernel::Avx512Batch => {
253                vi_avx512(high, low, close, period, first, plus, minus)
254            }
255            _ => unreachable!(),
256        }
257    }
258}
259
260pub fn vi_with_kernel(input: &ViInput, kernel: Kernel) -> Result<ViOutput, ViError> {
261    let (h, l, c, period, first, chosen) = vi_prepare(input, kernel)?;
262    let mut plus = alloc_with_nan_prefix(h.len(), first + period - 1);
263    let mut minus = alloc_with_nan_prefix(h.len(), first + period - 1);
264    vi_compute_into(h, l, c, period, first, chosen, &mut plus, &mut minus);
265    Ok(ViOutput { plus, minus })
266}
267
268pub fn vi_into_slice(
269    dst_plus: &mut [f64],
270    dst_minus: &mut [f64],
271    input: &ViInput,
272    kernel: Kernel,
273) -> Result<(), ViError> {
274    let (h, l, c, period, first, chosen) = vi_prepare(input, kernel)?;
275    if dst_plus.len() != h.len() || dst_minus.len() != h.len() {
276        let expected = h.len();
277        let got = dst_plus.len().max(dst_minus.len());
278        return Err(ViError::OutputLengthMismatch { expected, got });
279    }
280    vi_compute_into(h, l, c, period, first, chosen, dst_plus, dst_minus);
281    let warm = first + period - 1;
282    for i in 0..warm {
283        dst_plus[i] = f64::NAN;
284        dst_minus[i] = f64::NAN;
285    }
286    Ok(())
287}
288
/// Scalar Vortex Indicator (VI) kernel for a single series.
///
/// Per-bar terms for each bar `i > first`:
/// - `TR = max(h - l, |h - c_prev|, |l - c_prev|)`
/// - `VM+ = |h - l_prev|`
/// - `VM- = |l - h_prev|`
///
/// Bar `first` has no predecessor and contributes `TR = h - l` with zero
/// movement.
///
/// For every `i >= first + period - 1` the kernel writes
/// `plus[i] = ΣVM+ / ΣTR` and `minus[i] = ΣVM- / ΣTR` over the trailing
/// `period` bars. It keeps the three sums incrementally with ring buffers.
/// With `period == 1`, the value at index `first` is written as `0.0`.
///
/// Output indices before the warm-up point are not touched. The callers in
/// this module pre-fill them with NaN.
///
/// Inputs with an empty `high`, `period == 0` or `first >= high.len()` are a
/// no-op.
///
/// # Panics
/// Panics if `low`, `close`, `plus` or `minus` is shorter than `high`.
///
/// # Safety
/// The early return and the length assertion keep every raw-pointer access
/// below in bounds. The function stays `unsafe` only to keep its public
/// signature unchanged.
#[inline(always)]
pub unsafe fn vi_scalar(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first: usize,
    plus: &mut [f64],
    minus: &mut [f64],
) {
    let n = high.len();
    if n == 0 || period == 0 || first >= n {
        return;
    }
    assert!(
        low.len() >= n && close.len() >= n && plus.len() >= n && minus.len() >= n,
        "vi_scalar: low/close/plus/minus must be at least as long as high"
    );

    let warm = first + period - 1;

    let h = high.as_ptr();
    let l = low.as_ptr();
    let c = close.as_ptr();
    let p_out = plus.as_mut_ptr();
    let m_out = minus.as_mut_ptr();

    // Ring buffers of the last `period` per-bar terms. They are zero-initialised
    // so every slot holds a valid f64; `Vec::set_len` over uninitialised
    // capacity would violate `set_len`'s safety contract.
    let mut tr_buf = vec![0.0f64; period];
    let mut vp_buf = vec![0.0f64; period];
    let mut vm_buf = vec![0.0f64; period];
    let trp = tr_buf.as_mut_ptr();
    let vpp = vp_buf.as_mut_ptr();
    let vmp = vm_buf.as_mut_ptr();

    let mut prev_h = *h.add(first);
    let mut prev_l = *l.add(first);
    let mut prev_c = *c.add(first);

    // Seed the window with the first valid bar: range only, no movement.
    let mut sum_tr = prev_h - prev_l;
    let mut sum_vp = 0.0f64;
    let mut sum_vm = 0.0f64;

    *trp.add(0) = sum_tr;
    *vpp.add(0) = 0.0;
    *vmp.add(0) = 0.0;

    if period == 1 {
        *p_out.add(warm) = 0.0;
        *m_out.add(warm) = 0.0;
    }

    // Ring slot for the next bar; with period 1 the single slot is reused.
    let mut r = if period == 1 { 0 } else { 1 };

    let mut i = first + 1;
    while i < n {
        let hi = *h.add(i);
        let lo = *l.add(i);

        let hl = hi - lo;
        let hc = (hi - prev_c).abs();
        let lc = (lo - prev_c).abs();
        let mut tr_new = if hl > hc { hl } else { hc };
        if lc > tr_new {
            tr_new = lc;
        }

        let vp_new = (hi - prev_l).abs();
        let vm_new = (lo - prev_h).abs();

        if i <= warm {
            // Still filling the first window: accumulate only.
            sum_tr += tr_new;
            sum_vp += vp_new;
            sum_vm += vm_new;

            *trp.add(r) = tr_new;
            *vpp.add(r) = vp_new;
            *vmp.add(r) = vm_new;

            if i == warm {
                *p_out.add(i) = sum_vp / sum_tr;
                *m_out.add(i) = sum_vm / sum_tr;
            }
        } else {
            // Slide the window: drop the oldest term, add the newest.
            let tr_old = *trp.add(r);
            let vp_old = *vpp.add(r);
            let vm_old = *vmp.add(r);

            sum_tr += tr_new - tr_old;
            sum_vp += vp_new - vp_old;
            sum_vm += vm_new - vm_old;

            *trp.add(r) = tr_new;
            *vpp.add(r) = vp_new;
            *vmp.add(r) = vm_new;

            *p_out.add(i) = sum_vp / sum_tr;
            *m_out.add(i) = sum_vm / sum_tr;
        }

        prev_h = hi;
        prev_l = lo;
        prev_c = *c.add(i);

        r += 1;
        if r == period {
            r = 0;
        }

        i += 1;
    }
}
399
/// AVX2 entry point for the single-series VI kernel.
///
/// Currently delegates to [`vi_scalar`].
///
/// # Safety
/// Same contract as [`vi_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn vi_avx2(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first: usize,
    plus: &mut [f64],
    minus: &mut [f64],
) {
    vi_scalar(high, low, close, period, first, plus, minus)
}
413
/// AVX-512 entry point for the single-series VI kernel.
///
/// Periods up to 32 go to [`vi_avx512_short`], longer ones to
/// [`vi_avx512_long`].
///
/// # Safety
/// Same contract as [`vi_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn vi_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first: usize,
    plus: &mut [f64],
    minus: &mut [f64],
) {
    match period {
        0..=32 => vi_avx512_short(high, low, close, period, first, plus, minus),
        _ => vi_avx512_long(high, low, close, period, first, plus, minus),
    }
}
431
/// AVX-512 path for short periods (<= 32).
///
/// Currently delegates to [`vi_scalar`].
///
/// # Safety
/// Same contract as [`vi_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn vi_avx512_short(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first: usize,
    plus: &mut [f64],
    minus: &mut [f64],
) {
    vi_scalar(high, low, close, period, first, plus, minus)
}
445
/// AVX-512 path for long periods (> 32).
///
/// Currently delegates to [`vi_scalar`].
///
/// # Safety
/// Same contract as [`vi_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn vi_avx512_long(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first: usize,
    plus: &mut [f64],
    minus: &mut [f64],
) {
    vi_scalar(high, low, close, period, first, plus, minus)
}
459
460#[derive(Debug, Clone)]
461pub struct ViStream {
462    period: usize,
463    tr: Vec<f64>,
464    vp: Vec<f64>,
465    vm: Vec<f64>,
466    idx: usize,
467    filled: bool,
468    sum_tr: f64,
469    sum_vp: f64,
470    sum_vm: f64,
471}
472
473impl ViStream {
474    pub fn try_new(params: ViParams) -> Result<Self, ViError> {
475        let period = params.period.unwrap_or(14);
476        if period == 0 {
477            return Err(ViError::InvalidPeriod {
478                period,
479                data_len: 0,
480            });
481        }
482        Ok(Self {
483            period,
484            tr: vec![0.0; period],
485            vp: vec![0.0; period],
486            vm: vec![0.0; period],
487            idx: 0,
488            filled: false,
489            sum_tr: 0.0,
490            sum_vp: 0.0,
491            sum_vm: 0.0,
492        })
493    }
494
495    pub fn update(
496        &mut self,
497        high: f64,
498        low: f64,
499        close: f64,
500        prev_low: f64,
501        prev_high: f64,
502        prev_close: f64,
503    ) -> Option<(f64, f64)> {
504        let _ = close;
505
506        let i = self.idx;
507
508        let hl = high - low;
509        let hc = (high - prev_close).abs();
510        let lc = (low - prev_close).abs();
511        let tr_new = hl.max(hc.max(lc));
512
513        let vp_new = (high - prev_low).abs();
514        let vm_new = (low - prev_high).abs();
515
516        let tr_old = self.tr[i];
517        let vp_old = self.vp[i];
518        let vm_old = self.vm[i];
519
520        self.sum_tr += tr_new - tr_old;
521        self.sum_vp += vp_new - vp_old;
522        self.sum_vm += vm_new - vm_old;
523
524        self.tr[i] = tr_new;
525        self.vp[i] = vp_new;
526        self.vm[i] = vm_new;
527
528        self.idx += 1;
529        if self.idx == self.period {
530            self.idx = 0;
531            self.filled = true;
532        }
533
534        if self.filled {
535            let inv_tr = 1.0 / self.sum_tr;
536            let vi_p = self.sum_vp * inv_tr;
537            let vi_m = self.sum_vm * inv_tr;
538            Some((vi_p, vi_m))
539        } else {
540            None
541        }
542    }
543}
544
/// Period sweep for batch VI, as `(start, end, step)`.
///
/// The default sweeps periods 14 through 263 in steps of 1.
#[derive(Clone, Debug)]
pub struct ViBatchRange {
    pub period: (usize, usize, usize),
}

impl Default for ViBatchRange {
    fn default() -> Self {
        let (start, end, step) = (14, 263, 1);
        ViBatchRange {
            period: (start, end, step),
        }
    }
}
557
558#[derive(Clone, Debug, Default)]
559pub struct ViBatchBuilder {
560    range: ViBatchRange,
561    kernel: Kernel,
562}
563
564impl ViBatchBuilder {
565    pub fn new() -> Self {
566        Self::default()
567    }
568    pub fn kernel(mut self, k: Kernel) -> Self {
569        self.kernel = k;
570        self
571    }
572    #[inline]
573    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
574        self.range.period = (start, end, step);
575        self
576    }
577    #[inline]
578    pub fn period_static(mut self, p: usize) -> Self {
579        self.range.period = (p, p, 0);
580        self
581    }
582    pub fn apply_slices(
583        self,
584        high: &[f64],
585        low: &[f64],
586        close: &[f64],
587    ) -> Result<ViBatchOutput, ViError> {
588        vi_batch_with_kernel(high, low, close, &self.range, self.kernel)
589    }
590    pub fn apply_candles(self, c: &Candles) -> Result<ViBatchOutput, ViError> {
591        let high = source_type(c, "high");
592        let low = source_type(c, "low");
593        let close = source_type(c, "close");
594        self.apply_slices(high, low, close)
595    }
596}
597
598#[derive(Clone, Debug)]
599pub struct ViBatchOutput {
600    pub plus: Vec<f64>,
601    pub minus: Vec<f64>,
602    pub combos: Vec<ViParams>,
603    pub rows: usize,
604    pub cols: usize,
605}
606impl ViBatchOutput {
607    pub fn row_for_params(&self, p: &ViParams) -> Option<usize> {
608        self.combos
609            .iter()
610            .position(|c| c.period.unwrap_or(14) == p.period.unwrap_or(14))
611    }
612    pub fn plus_for(&self, p: &ViParams) -> Option<&[f64]> {
613        self.row_for_params(p).map(|row| {
614            let start = row * self.cols;
615            &self.plus[start..start + self.cols]
616        })
617    }
618    pub fn minus_for(&self, p: &ViParams) -> Option<&[f64]> {
619        self.row_for_params(p).map(|row| {
620            let start = row * self.cols;
621            &self.minus[start..start + self.cols]
622        })
623    }
624}
625
626#[inline(always)]
627fn expand_grid(r: &ViBatchRange) -> Vec<ViParams> {
628    fn axis_usize(range: (usize, usize, usize)) -> Result<Vec<usize>, ViError> {
629        let (start, end, step) = range;
630        if step == 0 || start == end {
631            return Ok(vec![start]);
632        }
633        if start < end {
634            let v: Vec<usize> = (start..=end).step_by(step).collect();
635            if v.is_empty() {
636                return Err(ViError::InvalidRange { start, end, step });
637            }
638            return Ok(v);
639        }
640        let mut v = Vec::new();
641        let mut cur = start;
642        loop {
643            v.push(cur);
644            if cur == end {
645                break;
646            }
647            cur = cur
648                .checked_sub(step)
649                .ok_or(ViError::InvalidRange { start, end, step })?;
650            if cur < end {
651                break;
652            }
653        }
654        if v.is_empty() {
655            return Err(ViError::InvalidRange { start, end, step });
656        }
657        Ok(v)
658    }
659
660    let periods = match axis_usize(r.period) {
661        Ok(v) => v,
662        Err(_) => return Vec::new(),
663    };
664    let mut out = Vec::with_capacity(periods.len());
665    for &p in &periods {
666        out.push(ViParams { period: Some(p) });
667    }
668    out
669}
670
671pub fn vi_batch_with_kernel(
672    high: &[f64],
673    low: &[f64],
674    close: &[f64],
675    sweep: &ViBatchRange,
676    k: Kernel,
677) -> Result<ViBatchOutput, ViError> {
678    let kernel = match k {
679        Kernel::Auto => match detect_best_batch_kernel() {
680            Kernel::Avx512Batch => Kernel::Avx2Batch,
681            other => other,
682        },
683        other if other.is_batch() => other,
684        other => {
685            return Err(ViError::InvalidKernelForBatch(other));
686        }
687    };
688    let simd = match kernel {
689        Kernel::Avx512Batch => Kernel::Avx512,
690        Kernel::Avx2Batch => Kernel::Avx2,
691        Kernel::ScalarBatch => Kernel::Scalar,
692        _ => unreachable!(),
693    };
694    vi_batch_par_slice(high, low, close, sweep, simd)
695}
696
697pub fn vi_batch_slice(
698    high: &[f64],
699    low: &[f64],
700    close: &[f64],
701    sweep: &ViBatchRange,
702    kern: Kernel,
703) -> Result<ViBatchOutput, ViError> {
704    vi_batch_inner(high, low, close, sweep, kern, false)
705}
706pub fn vi_batch_par_slice(
707    high: &[f64],
708    low: &[f64],
709    close: &[f64],
710    sweep: &ViBatchRange,
711    kern: Kernel,
712) -> Result<ViBatchOutput, ViError> {
713    vi_batch_inner(high, low, close, sweep, kern, true)
714}
715#[inline(always)]
716fn vi_batch_inner(
717    high: &[f64],
718    low: &[f64],
719    close: &[f64],
720    sweep: &ViBatchRange,
721    kern: Kernel,
722    parallel: bool,
723) -> Result<ViBatchOutput, ViError> {
724    let combos = expand_grid(sweep);
725    if combos.is_empty() {
726        return Err(ViError::InvalidRange {
727            start: sweep.period.0,
728            end: sweep.period.1,
729            step: sweep.period.2,
730        });
731    }
732    if high.is_empty() || low.is_empty() || close.is_empty() {
733        return Err(ViError::EmptyInputData);
734    }
735    let first = (0..high.len())
736        .find(|&i| !high[i].is_nan() && !low[i].is_nan() && !close[i].is_nan())
737        .ok_or(ViError::AllValuesNaN)?;
738    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
739    if high.len() - first < max_p {
740        return Err(ViError::NotEnoughValidData {
741            needed: max_p,
742            valid: high.len() - first,
743        });
744    }
745    let rows = combos.len();
746    let cols = high.len();
747    rows.checked_mul(cols)
748        .ok_or_else(|| ViError::InvalidRange {
749            start: sweep.period.0,
750            end: sweep.period.1,
751            step: sweep.period.2,
752        })?;
753
754    let mut plus_mu = make_uninit_matrix(rows, cols);
755    let mut minus_mu = make_uninit_matrix(rows, cols);
756
757    let mut warm: Vec<usize> = Vec::with_capacity(combos.len());
758    for c in &combos {
759        let p = c.period.unwrap();
760        let warm_i = first
761            .checked_add(p)
762            .and_then(|v| v.checked_sub(1))
763            .ok_or_else(|| ViError::InvalidPeriod {
764                period: p,
765                data_len: high.len(),
766            })?;
767        warm.push(warm_i);
768    }
769
770    init_matrix_prefixes(&mut plus_mu, cols, &warm);
771    init_matrix_prefixes(&mut minus_mu, cols, &warm);
772
773    let mut plus_guard = core::mem::ManuallyDrop::new(plus_mu);
774    let mut minus_guard = core::mem::ManuallyDrop::new(minus_mu);
775    let plus: &mut [f64] = unsafe {
776        core::slice::from_raw_parts_mut(plus_guard.as_mut_ptr() as *mut f64, plus_guard.len())
777    };
778    let minus: &mut [f64] = unsafe {
779        core::slice::from_raw_parts_mut(minus_guard.as_mut_ptr() as *mut f64, minus_guard.len())
780    };
781
782    let mut pfx_tr = vec![0.0f64; cols];
783    let mut pfx_vp = vec![0.0f64; cols];
784    let mut pfx_vm = vec![0.0f64; cols];
785    if cols > 0 && first < cols {
786        unsafe {
787            match kern {
788                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
789                Kernel::Avx512 | Kernel::Avx512Batch => vi_prefix_avx512(
790                    high,
791                    low,
792                    close,
793                    first,
794                    &mut pfx_tr,
795                    &mut pfx_vp,
796                    &mut pfx_vm,
797                ),
798                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
799                Kernel::Avx2 | Kernel::Avx2Batch => vi_prefix_avx2(
800                    high,
801                    low,
802                    close,
803                    first,
804                    &mut pfx_tr,
805                    &mut pfx_vp,
806                    &mut pfx_vm,
807                ),
808                _ => vi_prefix_scalar(
809                    high,
810                    low,
811                    close,
812                    first,
813                    &mut pfx_tr,
814                    &mut pfx_vp,
815                    &mut pfx_vm,
816                ),
817            }
818        }
819    }
820
821    let do_row = |row: usize, plus_row: &mut [f64], minus_row: &mut [f64]| {
822        let period = combos[row].period.unwrap();
823        let warm = first + period - 1;
824        if warm >= cols {
825            return;
826        }
827        let mut i = warm;
828        while i < cols {
829            let tr_sum = if i >= period {
830                pfx_tr[i] - pfx_tr[i - period]
831            } else {
832                pfx_tr[i]
833            };
834            let vp_sum = if i >= period {
835                pfx_vp[i] - pfx_vp[i - period]
836            } else {
837                pfx_vp[i]
838            };
839            let vm_sum = if i >= period {
840                pfx_vm[i] - pfx_vm[i - period]
841            } else {
842                pfx_vm[i]
843            };
844            plus_row[i] = vp_sum / tr_sum;
845            minus_row[i] = vm_sum / tr_sum;
846            i += 1;
847        }
848    };
849    if parallel {
850        #[cfg(not(target_arch = "wasm32"))]
851        {
852            plus.par_chunks_mut(cols)
853                .zip(minus.par_chunks_mut(cols))
854                .enumerate()
855                .for_each(|(row, (p, m))| do_row(row, p, m));
856        }
857
858        #[cfg(target_arch = "wasm32")]
859        {
860            for ((row, p), m) in plus
861                .chunks_mut(cols)
862                .enumerate()
863                .zip(minus.chunks_mut(cols))
864            {
865                do_row(row, p, m);
866            }
867        }
868    } else {
869        for ((row, p), m) in plus
870            .chunks_mut(cols)
871            .enumerate()
872            .zip(minus.chunks_mut(cols))
873        {
874            do_row(row, p, m);
875        }
876    }
877
878    let plus_vec = unsafe {
879        Vec::from_raw_parts(
880            plus_guard.as_mut_ptr() as *mut f64,
881            plus_guard.len(),
882            plus_guard.capacity(),
883        )
884    };
885    let minus_vec = unsafe {
886        Vec::from_raw_parts(
887            minus_guard.as_mut_ptr() as *mut f64,
888            minus_guard.len(),
889            minus_guard.capacity(),
890        )
891    };
892
893    Ok(ViBatchOutput {
894        plus: plus_vec,
895        minus: minus_vec,
896        combos,
897        rows,
898        cols,
899    })
900}
901
902#[inline(always)]
903unsafe fn vi_row_scalar(
904    high: &[f64],
905    low: &[f64],
906    close: &[f64],
907    first: usize,
908    period: usize,
909    plus: &mut [f64],
910    minus: &mut [f64],
911) {
912    vi_scalar(high, low, close, period, first, plus, minus);
913}
/// AVX2 per-row VI kernel; currently delegates to [`vi_scalar`].
///
/// # Safety
/// Same contract as [`vi_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn vi_row_avx2(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    plus: &mut [f64],
    minus: &mut [f64],
) {
    vi_scalar(high, low, close, period, first, plus, minus)
}
/// AVX-512 per-row VI kernel.
///
/// Periods up to 32 use the short variant, longer periods the long one.
///
/// # Safety
/// Same contract as [`vi_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn vi_row_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    plus: &mut [f64],
    minus: &mut [f64],
) {
    match period {
        0..=32 => vi_row_avx512_short(high, low, close, first, period, plus, minus),
        _ => vi_row_avx512_long(high, low, close, first, period, plus, minus),
    }
}
/// AVX-512 per-row kernel for short periods; currently delegates to [`vi_scalar`].
///
/// # Safety
/// Same contract as [`vi_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn vi_row_avx512_short(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    plus: &mut [f64],
    minus: &mut [f64],
) {
    vi_scalar(high, low, close, period, first, plus, minus)
}
/// AVX-512 per-row kernel for long periods; currently delegates to [`vi_scalar`].
///
/// # Safety
/// Same contract as [`vi_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn vi_row_avx512_long(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    plus: &mut [f64],
    minus: &mut [f64],
) {
    vi_scalar(high, low, close, period, first, plus, minus)
}
970
971fn vi_batch_inner_into(
972    high: &[f64],
973    low: &[f64],
974    close: &[f64],
975    sweep: &ViBatchRange,
976    kernel: Kernel,
977    parallel: bool,
978    out_plus: &mut [f64],
979    out_minus: &mut [f64],
980) -> Result<Vec<ViParams>, ViError> {
981    let combos = expand_grid(&sweep);
982    let rows = combos.len();
983    let cols = close.len();
984    if rows == 0 {
985        return Err(ViError::InvalidRange {
986            start: sweep.period.0,
987            end: sweep.period.1,
988            step: sweep.period.2,
989        });
990    }
991    let expected = rows
992        .checked_mul(cols)
993        .ok_or_else(|| ViError::InvalidRange {
994            start: sweep.period.0,
995            end: sweep.period.1,
996            step: sweep.period.2,
997        })?;
998    if out_plus.len() != expected || out_minus.len() != expected {
999        let got = out_plus.len().max(out_minus.len());
1000        return Err(ViError::OutputLengthMismatch { expected, got });
1001    }
1002
1003    if high.is_empty() || low.is_empty() || close.is_empty() {
1004        return Err(ViError::EmptyInputData);
1005    }
1006    let first = (0..high.len())
1007        .find(|&i| !high[i].is_nan() && !low[i].is_nan() && !close[i].is_nan())
1008        .ok_or(ViError::AllValuesNaN)?;
1009    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1010    if high.len() - first < max_p {
1011        return Err(ViError::NotEnoughValidData {
1012            needed: max_p,
1013            valid: high.len() - first,
1014        });
1015    }
1016
1017    let cols = close.len();
1018    let mut pfx_tr = vec![0.0f64; cols];
1019    let mut pfx_vp = vec![0.0f64; cols];
1020    let mut pfx_vm = vec![0.0f64; cols];
1021    if cols > 0 && first < cols {
1022        unsafe {
1023            match kernel {
1024                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1025                Kernel::Avx512 | Kernel::Avx512Batch => vi_prefix_avx512(
1026                    high,
1027                    low,
1028                    close,
1029                    first,
1030                    &mut pfx_tr,
1031                    &mut pfx_vp,
1032                    &mut pfx_vm,
1033                ),
1034                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1035                Kernel::Avx2 | Kernel::Avx2Batch => vi_prefix_avx2(
1036                    high,
1037                    low,
1038                    close,
1039                    first,
1040                    &mut pfx_tr,
1041                    &mut pfx_vp,
1042                    &mut pfx_vm,
1043                ),
1044                _ => vi_prefix_scalar(
1045                    high,
1046                    low,
1047                    close,
1048                    first,
1049                    &mut pfx_tr,
1050                    &mut pfx_vp,
1051                    &mut pfx_vm,
1052                ),
1053            }
1054        }
1055    }
1056
1057    let do_row = |row: usize, p_row: &mut [f64], m_row: &mut [f64]| {
1058        let period = combos[row].period.unwrap();
1059        let warm = first + period - 1;
1060
1061        for i in 0..warm.min(cols) {
1062            p_row[i] = f64::NAN;
1063            m_row[i] = f64::NAN;
1064        }
1065        if warm >= cols {
1066            return;
1067        }
1068        let mut i = warm;
1069        while i < cols {
1070            let tr_sum = if i >= period {
1071                pfx_tr[i] - pfx_tr[i - period]
1072            } else {
1073                pfx_tr[i]
1074            };
1075            let vp_sum = if i >= period {
1076                pfx_vp[i] - pfx_vp[i - period]
1077            } else {
1078                pfx_vp[i]
1079            };
1080            let vm_sum = if i >= period {
1081                pfx_vm[i] - pfx_vm[i - period]
1082            } else {
1083                pfx_vm[i]
1084            };
1085            p_row[i] = vp_sum / tr_sum;
1086            m_row[i] = vm_sum / tr_sum;
1087            i += 1;
1088        }
1089    };
1090
1091    if parallel {
1092        #[cfg(not(target_arch = "wasm32"))]
1093        out_plus
1094            .par_chunks_mut(cols)
1095            .zip(out_minus.par_chunks_mut(cols))
1096            .enumerate()
1097            .for_each(|(row, (p, m))| do_row(row, p, m));
1098        #[cfg(target_arch = "wasm32")]
1099        for ((row, p), m) in out_plus
1100            .chunks_mut(cols)
1101            .enumerate()
1102            .zip(out_minus.chunks_mut(cols))
1103        {
1104            do_row(row, p, m);
1105        }
1106    } else {
1107        for ((row, p), m) in out_plus
1108            .chunks_mut(cols)
1109            .enumerate()
1110            .zip(out_minus.chunks_mut(cols))
1111        {
1112            do_row(row, p, m);
1113        }
1114    }
1115    Ok(combos)
1116}
1117
/// Scalar kernel for the Vortex Indicator prefix sums.
///
/// Starting at `first`, writes running totals into the three output buffers:
/// - `pfx_tr`: true range, `max(high - low, |high - prev_close|, |low - prev_close|)`
/// - `pfx_vp`: `|high - prev_low|`
/// - `pfx_vm`: `|low - prev_high|`
///
/// Index `first` has no predecessor, so it is seeded with `high - low` for the true range
/// and `0.0` for both movement sums. Entries below `first` are left untouched.
///
/// # Safety
/// Performs no unchecked operations itself (every access is bounds-checked); the `unsafe`
/// marker matches the SIMD kernels it is dispatched alongside. Panics if `first` is out of
/// range for `high`, or if `low`, `close` or any prefix buffer is shorter than `high`.
#[inline(always)]
unsafe fn vi_prefix_scalar(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    pfx_tr: &mut [f64],
    pfx_vp: &mut [f64],
    pfx_vm: &mut [f64],
) {
    let end = high.len();

    // Running totals, seeded from the first valid bar.
    let mut sum_tr = high[first] - low[first];
    let mut sum_vp = 0.0;
    let mut sum_vm = 0.0;
    pfx_tr[first] = sum_tr;
    pfx_vp[first] = sum_vp;
    pfx_vm[first] = sum_vm;

    // (high, low, close) of the bar immediately preceding `idx`.
    let (mut last_high, mut last_low, mut last_close) = (high[first], low[first], close[first]);

    for idx in first + 1..end {
        let cur_high = high[idx];
        let cur_low = low[idx];

        let range = cur_high - cur_low;
        let gap_up = (cur_high - last_close).abs();
        let gap_down = (cur_low - last_close).abs();
        sum_tr += range.max(gap_up.max(gap_down));
        sum_vp += (cur_high - last_low).abs();
        sum_vm += (cur_low - last_high).abs();

        pfx_tr[idx] = sum_tr;
        pfx_vp[idx] = sum_vp;
        pfx_vm[idx] = sum_vm;

        last_high = cur_high;
        last_low = cur_low;
        last_close = close[idx];
    }
}
1154
/// Lane-wise absolute value of four packed `f64`s, computed as the lane-wise maximum of
/// `x` and `0.0 - x`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn abs256(x: __m256d) -> __m256d {
    let zero = _mm256_set1_pd(0.0);
    _mm256_max_pd(x, _mm256_sub_pd(zero, x))
}
/// AVX2 kernel for the VI prefix sums, with the same seeding and accumulation order as
/// `vi_prefix_scalar`.
///
/// Per-bar true range, `|high - prev_low|` and `|low - prev_high|` are computed four bars at
/// a time. The lanes are then spilled to stack arrays and added to the running sums one by
/// one in index order, so the sums are built in the same order as the scalar loop. The
/// remaining (fewer than four) bars go through a scalar tail.
///
/// # Safety
/// - `first < high.len()`, and `low`, `close` and the three prefix buffers must be at least
///   `high.len()` long: the vector loads and the tail loop read `low`/`close` through
///   unchecked pointer accesses.
/// - The CPU must support the AVX instructions used here; this function does not check.
///   NOTE(review): no `#[target_feature]` is declared on this function, so the intrinsics
///   may not be inlined into it — confirm this is intended.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn vi_prefix_avx2(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    pfx_tr: &mut [f64],
    pfx_vp: &mut [f64],
    pfx_vm: &mut [f64],
) {
    use core::arch::x86_64::*;
    let n = high.len();
    // Seed index `first`: there is no previous bar, so TR = high - low and both movements are 0.
    pfx_tr[first] = high[first] - low[first];
    pfx_vp[first] = 0.0;
    pfx_vm[first] = 0.0;
    let mut i = first + 1;

    let mut carry_tr = pfx_tr[i - 1];
    let mut carry_vp = pfx_vp[i - 1];
    let mut carry_vm = pfx_vm[i - 1];
    let step = 4;
    // Full 4-lane blocks. `i >= 1` here, and `i + step <= n` keeps every `[i - 1, i + 3]`
    // load inside the first `n` elements.
    while i + step <= n {
        let v_hi = _mm256_loadu_pd(high.as_ptr().add(i));
        let v_lo = _mm256_loadu_pd(low.as_ptr().add(i));
        let v_cl_prev = _mm256_loadu_pd(close.as_ptr().add(i - 1));
        let v_lo_prev = _mm256_loadu_pd(low.as_ptr().add(i - 1));
        let v_hi_prev = _mm256_loadu_pd(high.as_ptr().add(i - 1));

        let hl = _mm256_sub_pd(v_hi, v_lo);
        let hc = abs256(_mm256_sub_pd(v_hi, v_cl_prev));
        let lc = abs256(_mm256_sub_pd(v_lo, v_cl_prev));
        let tr_v = _mm256_max_pd(hl, _mm256_max_pd(hc, lc));
        let vp_v = abs256(_mm256_sub_pd(v_hi, v_lo_prev));
        let vm_v = abs256(_mm256_sub_pd(v_lo, v_hi_prev));

        // Prefix sums are inherently sequential: spill the lanes and accumulate them in
        // index order.
        let mut tr_tmp = [0.0f64; 4];
        let mut vp_tmp = [0.0f64; 4];
        let mut vm_tmp = [0.0f64; 4];
        _mm256_storeu_pd(tr_tmp.as_mut_ptr(), tr_v);
        _mm256_storeu_pd(vp_tmp.as_mut_ptr(), vp_v);
        _mm256_storeu_pd(vm_tmp.as_mut_ptr(), vm_v);
        let mut k = 0;
        while k < step {
            carry_tr += tr_tmp[k];
            carry_vp += vp_tmp[k];
            carry_vm += vm_tmp[k];
            pfx_tr[i + k] = carry_tr;
            pfx_vp[i + k] = carry_vp;
            pfx_vm[i + k] = carry_vm;
            k += 1;
        }

        i += step;
    }
    // Scalar tail for the bars that do not fill a whole vector.
    while i < n {
        let hi = *high.get_unchecked(i);
        let lo = *low.get_unchecked(i);
        let prev_c = *close.get_unchecked(i - 1);
        let prev_l = *low.get_unchecked(i - 1);
        let prev_h = *high.get_unchecked(i - 1);
        let hl = hi - lo;
        let hc = (hi - prev_c).abs();
        let lc = (lo - prev_c).abs();
        let tr_i = hl.max(hc.max(lc));
        let vp_i = (hi - prev_l).abs();
        let vm_i = (lo - prev_h).abs();
        carry_tr += tr_i;
        carry_vp += vp_i;
        carry_vm += vm_i;
        pfx_tr[i] = carry_tr;
        pfx_vp[i] = carry_vp;
        pfx_vm[i] = carry_vm;
        i += 1;
    }
}
1237
/// Lane-wise absolute value of eight packed `f64`s, computed as the lane-wise maximum of
/// `x` and `0.0 - x`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn abs512(x: __m512d) -> __m512d {
    let zero = _mm512_set1_pd(0.0);
    _mm512_max_pd(x, _mm512_sub_pd(zero, x))
}
/// AVX-512 kernel for the VI prefix sums, with the same seeding and accumulation order as
/// `vi_prefix_scalar`.
///
/// Per-bar true range, `|high - prev_low|` and `|low - prev_high|` are computed eight bars
/// at a time. The lanes are then spilled to stack arrays and added to the running sums one
/// by one in index order, so the sums are built in the same order as the scalar loop. The
/// remaining (fewer than eight) bars go through a scalar tail.
///
/// # Safety
/// - `first < high.len()`, and `low`, `close` and the three prefix buffers must be at least
///   `high.len()` long: the vector loads and the tail loop read `low`/`close` through
///   unchecked pointer accesses.
/// - The CPU must support the AVX-512 instructions used here; this function does not check.
///   NOTE(review): no `#[target_feature]` is declared on this function, so the intrinsics
///   may not be inlined into it — confirm this is intended.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn vi_prefix_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    pfx_tr: &mut [f64],
    pfx_vp: &mut [f64],
    pfx_vm: &mut [f64],
) {
    use core::arch::x86_64::*;
    let n = high.len();
    // Seed index `first`: there is no previous bar, so TR = high - low and both movements are 0.
    pfx_tr[first] = high[first] - low[first];
    pfx_vp[first] = 0.0;
    pfx_vm[first] = 0.0;
    let mut i = first + 1;
    let mut carry_tr = pfx_tr[i - 1];
    let mut carry_vp = pfx_vp[i - 1];
    let mut carry_vm = pfx_vm[i - 1];
    let step = 8;
    // Full 8-lane blocks. `i >= 1` here, and `i + step <= n` keeps every `[i - 1, i + 7]`
    // load inside the first `n` elements.
    while i + step <= n {
        let v_hi = _mm512_loadu_pd(high.as_ptr().add(i));
        let v_lo = _mm512_loadu_pd(low.as_ptr().add(i));
        let v_cl_prev = _mm512_loadu_pd(close.as_ptr().add(i - 1));
        let v_lo_prev = _mm512_loadu_pd(low.as_ptr().add(i - 1));
        let v_hi_prev = _mm512_loadu_pd(high.as_ptr().add(i - 1));

        let hl = _mm512_sub_pd(v_hi, v_lo);
        let hc = abs512(_mm512_sub_pd(v_hi, v_cl_prev));
        let lc = abs512(_mm512_sub_pd(v_lo, v_cl_prev));
        let tr_v = _mm512_max_pd(hl, _mm512_max_pd(hc, lc));
        let vp_v = abs512(_mm512_sub_pd(v_hi, v_lo_prev));
        let vm_v = abs512(_mm512_sub_pd(v_lo, v_hi_prev));

        // Prefix sums are inherently sequential: spill the lanes and accumulate them in
        // index order.
        let mut tr_tmp = [0.0f64; 8];
        let mut vp_tmp = [0.0f64; 8];
        let mut vm_tmp = [0.0f64; 8];
        _mm512_storeu_pd(tr_tmp.as_mut_ptr(), tr_v);
        _mm512_storeu_pd(vp_tmp.as_mut_ptr(), vp_v);
        _mm512_storeu_pd(vm_tmp.as_mut_ptr(), vm_v);
        let mut k = 0;
        while k < step {
            carry_tr += tr_tmp[k];
            carry_vp += vp_tmp[k];
            carry_vm += vm_tmp[k];
            pfx_tr[i + k] = carry_tr;
            pfx_vp[i + k] = carry_vp;
            pfx_vm[i + k] = carry_vm;
            k += 1;
        }
        i += step;
    }
    // Scalar tail for the bars that do not fill a whole vector.
    while i < n {
        let hi = *high.get_unchecked(i);
        let lo = *low.get_unchecked(i);
        let prev_c = *close.get_unchecked(i - 1);
        let prev_l = *low.get_unchecked(i - 1);
        let prev_h = *high.get_unchecked(i - 1);
        let hl = hi - lo;
        let hc = (hi - prev_c).abs();
        let lc = (lo - prev_c).abs();
        let tr_i = hl.max(hc.max(lc));
        let vp_i = (hi - prev_l).abs();
        let vm_i = (lo - prev_h).abs();
        carry_tr += tr_i;
        carry_vp += vp_i;
        carry_vm += vm_i;
        pfx_tr[i] = carry_tr;
        pfx_vp[i] = carry_vp;
        pfx_vm[i] = carry_vm;
        i += 1;
    }
}
1318
1319#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1320use serde::{Deserialize, Serialize};
1321#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1322use wasm_bindgen::prelude::*;
1323
/// Serialized result of [`vi_js`]: VI+ and VI- series, each the same length as the input.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct ViJsResult {
    /// VI+ values.
    pub plus: Vec<f64>,
    /// VI- values.
    pub minus: Vec<f64>,
}

/// JS binding: computes VI+ / VI- for a single `period` with the best detected kernel and
/// returns them as a `{ plus, minus }` object.
///
/// Any indicator or serialization error is surfaced to JS as a string `JsValue`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn vi_js(high: &[f64], low: &[f64], close: &[f64], period: usize) -> Result<JsValue, JsValue> {
    let n = high.len();
    let (mut plus, mut minus) = (vec![0.0; n], vec![0.0; n]);

    if let Err(e) = vi_into_slice_wasm(
        &mut plus,
        &mut minus,
        high,
        low,
        close,
        period,
        detect_best_kernel(),
    ) {
        return Err(JsValue::from_str(&e.to_string()));
    }

    serde_wasm_bindgen::to_value(&ViJsResult { plus, minus })
        .map_err(|e| JsValue::from_str(&e.to_string()))
}
1352
/// JS binding returning VI+ and VI- packed into one flat vector of length `2 * high.len()`.
///
/// The first `high.len()` values are VI+, the following `high.len()` values are VI-.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn vi_unified_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
) -> Result<Vec<f64>, JsValue> {
    let n = high.len();
    let mut packed = vec![0.0; n * 2];

    {
        // Split the flat buffer into the two output halves for the duration of the call.
        let (plus_half, minus_half) = packed.split_at_mut(n);
        vi_into_slice_wasm(
            plus_half,
            minus_half,
            high,
            low,
            close,
            period,
            detect_best_kernel(),
        )
        .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(packed)
}
1378
/// Allocates an uninitialized buffer with room for `2 * len` `f64`s and leaks it to the
/// caller (JS side).
///
/// Returns a null pointer if `2 * len` overflows `usize`, instead of silently allocating
/// a wrapped-around (too small) buffer. Release the buffer with [`vi_free`] using the same
/// `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn vi_alloc(len: usize) -> *mut f64 {
    let cap = match len.checked_mul(2) {
        Some(cap) => cap,
        None => return std::ptr::null_mut(),
    };
    let mut vec = Vec::<f64>::with_capacity(cap);
    let ptr = vec.as_mut_ptr();
    std::mem::forget(vec);
    ptr
}

/// Releases a buffer previously obtained from [`vi_alloc`] with the same `len`.
///
/// Null pointers (including the one `vi_alloc` returns on overflow) are ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn vi_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    let cap = match len.checked_mul(2) {
        Some(cap) => cap,
        // `vi_alloc` never hands out a buffer for an overflowing size.
        None => return,
    };
    // SAFETY: `ptr` came from `vi_alloc(len)`, i.e. a leaked `Vec<f64>` with capacity
    // `cap`. The length is 0 because the contents may never have been initialized
    // (`from_raw_parts` requires the first `length` elements to be initialized); `f64` has
    // no destructor, so only the allocation itself is released.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, cap));
    }
}
1395
/// Raw-pointer WASM entry point: writes VI+ to `plus_ptr` and VI- to `minus_ptr`.
///
/// All five pointers must reference at least `len` `f64`s. The output buffers may overlap
/// the input buffers (e.g. in-place use from JS); in that case the result is computed into
/// scratch vectors and copied out afterwards, so no `&mut` view ever aliases an input view.
/// `plus_ptr` and `minus_ptr` must not overlap each other.
///
/// # Errors
/// Returns an error if any pointer is null, if the two output buffers overlap, or if the
/// indicator computation fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn vi_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    plus_ptr: *mut f64,
    minus_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if high_ptr.is_null()
        || low_ptr.is_null()
        || close_ptr.is_null()
        || plus_ptr.is_null()
        || minus_ptr.is_null()
    {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // True when two `f64` buffers (start pointer + element count) share at least one byte.
    fn overlaps(a: *const f64, a_len: usize, b: *const f64, b_len: usize) -> bool {
        let elem = std::mem::size_of::<f64>();
        let (a_start, b_start) = (a as usize, b as usize);
        let a_end = a_start.saturating_add(a_len.saturating_mul(elem));
        let b_end = b_start.saturating_add(b_len.saturating_mul(elem));
        a_start < a_end && b_start < b_end && a_start < b_end && b_start < a_end
    }

    let plus_c = plus_ptr as *const f64;
    let minus_c = minus_ptr as *const f64;
    if overlaps(plus_c, len, minus_c, len) {
        return Err(JsValue::from_str("plus_ptr and minus_ptr must not overlap"));
    }
    let aliases_input = [high_ptr, low_ptr, close_ptr]
        .iter()
        .any(|&src| overlaps(src, len, plus_c, len) || overlaps(src, len, minus_c, len));

    let kernel = detect_best_kernel();

    // SAFETY: all pointers are non-null and, per the caller contract, valid for `len`
    // elements. Mutable output views are created either when they cannot alias any input
    // view, or (aliased case) only after the input views are no longer used.
    unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);

        if aliases_input {
            let mut plus_tmp = vec![0.0; len];
            let mut minus_tmp = vec![0.0; len];
            vi_into_slice_wasm(
                &mut plus_tmp,
                &mut minus_tmp,
                high,
                low,
                close,
                period,
                kernel,
            )
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
            std::slice::from_raw_parts_mut(plus_ptr, len).copy_from_slice(&plus_tmp);
            std::slice::from_raw_parts_mut(minus_ptr, len).copy_from_slice(&minus_tmp);
        } else {
            let plus_out = std::slice::from_raw_parts_mut(plus_ptr, len);
            let minus_out = std::slice::from_raw_parts_mut(minus_ptr, len);
            vi_into_slice_wasm(plus_out, minus_out, high, low, close, period, kernel)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(())
}
1438
/// Batch sweep configuration accepted by `vi_batch` from JavaScript.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct ViBatchConfig {
    /// Forwarded unchanged as `ViBatchRange::period` (start, end, step).
    pub period_range: (usize, usize, usize),
}

/// Serialized result of `vi_batch`: flattened VI+ / VI- matrices plus their shape.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct ViBatchJsOutput {
    /// VI+ values for every parameter row, flattened.
    pub plus: Vec<f64>,
    /// VI- values for every parameter row, flattened.
    pub minus: Vec<f64>,
    /// Period used for each row, in row order.
    pub periods: Vec<usize>,
    /// Number of parameter rows.
    pub rows: usize,
    /// Number of columns (values per row).
    pub cols: usize,
}

/// JS binding (`vi_batch`): runs a period sweep with automatic kernel selection and returns
/// a [`ViBatchJsOutput`] object.
///
/// Errors from config parsing, the computation, or serialization are returned as string
/// `JsValue`s.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = vi_batch)]
pub fn vi_batch_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let ViBatchConfig { period_range }: ViBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = ViBatchRange {
        period: period_range,
    };
    let output = vi_batch_with_kernel(high, low, close, &sweep, Kernel::Auto)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let periods = output
        .combos
        .iter()
        .map(|combo| combo.period.unwrap_or(14))
        .collect::<Vec<usize>>();

    serde_wasm_bindgen::to_value(&ViBatchJsOutput {
        plus: output.plus,
        minus: output.minus,
        periods,
        rows: output.rows,
        cols: output.cols,
    })
    .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
1489
/// Raw-pointer WASM entry point for a period sweep.
///
/// Expands `(period_start, period_end, period_step)` into parameter rows and fills
/// `plus_ptr` / `minus_ptr` with `rows * len` values each (presumably one row of `len`
/// values per period — see `vi_batch_inner_into`). Input pointers must reference at least
/// `len` `f64`s; output pointers at least `rows * len`.
///
/// Returns the number of rows written.
///
/// # Errors
/// Returns an error if any pointer is null, if `rows * len` overflows, if an output buffer
/// overlaps the other output or any input (which would otherwise alias a `&mut` slice with
/// a `&` slice), or if the computation fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn vi_batch_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    plus_ptr: *mut f64,
    minus_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if high_ptr.is_null()
        || low_ptr.is_null()
        || close_ptr.is_null()
        || plus_ptr.is_null()
        || minus_ptr.is_null()
    {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // True when two `f64` buffers (start pointer + element count) share at least one byte.
    fn overlaps(a: *const f64, a_len: usize, b: *const f64, b_len: usize) -> bool {
        let elem = std::mem::size_of::<f64>();
        let (a_start, b_start) = (a as usize, b as usize);
        let a_end = a_start.saturating_add(a_len.saturating_mul(elem));
        let b_end = b_start.saturating_add(b_len.saturating_mul(elem));
        a_start < a_end && b_start < b_end && a_start < b_end && b_start < a_end
    }

    let sweep = ViBatchRange {
        period: (period_start, period_end, period_step),
    };

    let rows = expand_grid(&sweep).len();
    let cols = len;
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow in vi_batch_into"))?;

    let plus_c = plus_ptr as *const f64;
    let minus_c = minus_ptr as *const f64;
    let outputs_collide = overlaps(plus_c, total, minus_c, total)
        || [high_ptr, low_ptr, close_ptr]
            .iter()
            .any(|&src| overlaps(src, len, plus_c, total) || overlaps(src, len, minus_c, total));
    if outputs_collide {
        return Err(JsValue::from_str(
            "vi_batch_into: output buffers must not overlap each other or the inputs",
        ));
    }

    // SAFETY: all pointers are non-null and, per the caller contract, valid for the stated
    // lengths; the overlap check above guarantees the `&mut` output views alias neither each
    // other nor any input view.
    unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);

        let plus_out = std::slice::from_raw_parts_mut(plus_ptr, total);
        let minus_out = std::slice::from_raw_parts_mut(minus_ptr, total);

        vi_batch_inner_into(
            high,
            low,
            close,
            &sweep,
            Kernel::Auto,
            false,
            plus_out,
            minus_out,
        )
        .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(rows)
}
1546
/// Computes VI+ / VI- for a single `period` directly into `dst_plus` / `dst_minus` by
/// wrapping the slices in a `ViInput` and delegating to `vi_into_slice` with kernel `kern`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
pub fn vi_into_slice_wasm(
    dst_plus: &mut [f64],
    dst_minus: &mut [f64],
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    kern: Kernel,
) -> Result<(), ViError> {
    let input = ViInput::from_slices(
        high,
        low,
        close,
        ViParams {
            period: Some(period),
        },
    );
    vi_into_slice(dst_plus, dst_minus, &input, kern)
}
1563
/// Registers the VI Python bindings on module `m`: the functions wrapped from `vi_py` and
/// `vi_batch_py`, and the `ViStreamPy` class.
#[cfg(feature = "python")]
pub fn register_vi_module(m: &Bound<'_, pyo3::types::PyModule>) -> PyResult<()> {
    let single = wrap_pyfunction!(vi_py, m)?;
    m.add_function(single)?;
    let batch = wrap_pyfunction!(vi_batch_py, m)?;
    m.add_function(batch)?;
    m.add_class::<ViStreamPy>()
}
1571
1572#[cfg(test)]
1573mod tests {
1574    use super::*;
1575    use crate::skip_if_unsupported;
1576    use crate::utilities::data_loader::read_candles_from_csv;
1577
1578    fn check_vi_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1579        skip_if_unsupported!(kernel, test_name);
1580        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1581        let candles = read_candles_from_csv(file_path)?;
1582        let default_params = ViParams { period: None };
1583        let input = ViInput::from_candles(&candles, default_params);
1584        let output = vi_with_kernel(&input, kernel)?;
1585        assert_eq!(output.plus.len(), candles.close.len());
1586        assert_eq!(output.minus.len(), candles.close.len());
1587        Ok(())
1588    }
1589    fn check_vi_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1590        skip_if_unsupported!(kernel, test_name);
1591        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1592        let candles = read_candles_from_csv(file_path)?;
1593        let input = ViInput::from_candles(&candles, ViParams::default());
1594        let result = vi_with_kernel(&input, kernel)?;
1595        let expected_last_five_plus = [
1596            0.9970238095238095,
1597            0.9871071716357775,
1598            0.9464453759945247,
1599            0.890897412369242,
1600            0.9206478557604156,
1601        ];
1602        let expected_last_five_minus = [
1603            1.0097117794486214,
1604            1.04174053182917,
1605            1.1152365471811105,
1606            1.181684712791338,
1607            1.1894672506875827,
1608        ];
1609        let n = result.plus.len();
1610        let plus_slice = &result.plus[n - 5..];
1611        let minus_slice = &result.minus[n - 5..];
1612        for (i, &val) in plus_slice.iter().enumerate() {
1613            let expected = expected_last_five_plus[i];
1614            assert!(
1615                (val - expected).abs() < 1e-8,
1616                "[{}] VI+ mismatch at idx {}: got {}, expected {}",
1617                test_name,
1618                i,
1619                val,
1620                expected
1621            );
1622        }
1623        for (i, &val) in minus_slice.iter().enumerate() {
1624            let expected = expected_last_five_minus[i];
1625            assert!(
1626                (val - expected).abs() < 1e-8,
1627                "[{}] VI- mismatch at idx {}: got {}, expected {}",
1628                test_name,
1629                i,
1630                val,
1631                expected
1632            );
1633        }
1634        Ok(())
1635    }
1636    fn check_vi_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1637        skip_if_unsupported!(kernel, test_name);
1638        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1639        let candles = read_candles_from_csv(file_path)?;
1640        let input = ViInput::with_default_candles(&candles);
1641        let output = vi_with_kernel(&input, kernel)?;
1642        assert_eq!(output.plus.len(), candles.close.len());
1643        assert_eq!(output.minus.len(), candles.close.len());
1644        Ok(())
1645    }
1646    fn check_vi_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1647        skip_if_unsupported!(kernel, test_name);
1648        let input_data = [10.0, 20.0, 30.0];
1649        let params = ViParams { period: Some(0) };
1650        let input = ViInput::from_slices(&input_data, &input_data, &input_data, params);
1651        let res = vi_with_kernel(&input, kernel);
1652        assert!(
1653            res.is_err(),
1654            "[{}] VI should fail with zero period",
1655            test_name
1656        );
1657        Ok(())
1658    }
1659    fn check_vi_period_exceeds_length(
1660        test_name: &str,
1661        kernel: Kernel,
1662    ) -> Result<(), Box<dyn Error>> {
1663        skip_if_unsupported!(kernel, test_name);
1664        let data_small = [10.0, 20.0, 30.0];
1665        let params = ViParams { period: Some(10) };
1666        let input = ViInput::from_slices(&data_small, &data_small, &data_small, params);
1667        let res = vi_with_kernel(&input, kernel);
1668        assert!(
1669            res.is_err(),
1670            "[{}] VI should fail with period exceeding length",
1671            test_name
1672        );
1673        Ok(())
1674    }
1675    fn check_vi_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1676        skip_if_unsupported!(kernel, test_name);
1677        let single_point = [42.0];
1678        let params = ViParams { period: Some(14) };
1679        let input = ViInput::from_slices(&single_point, &single_point, &single_point, params);
1680        let res = vi_with_kernel(&input, kernel);
1681        assert!(
1682            res.is_err(),
1683            "[{}] VI should fail with insufficient data",
1684            test_name
1685        );
1686        Ok(())
1687    }
1688    fn check_vi_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1689        skip_if_unsupported!(kernel, test_name);
1690        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1691        let candles = read_candles_from_csv(file_path)?;
1692        let input = ViInput::from_candles(&candles, ViParams::default());
1693        let res = vi_with_kernel(&input, kernel)?;
1694        assert_eq!(res.plus.len(), candles.close.len());
1695        if res.plus.len() > 20 {
1696            for (i, &val) in res.plus[20..].iter().enumerate() {
1697                assert!(
1698                    !val.is_nan(),
1699                    "[{}] Found unexpected NaN at out-index {}",
1700                    test_name,
1701                    20 + i
1702                );
1703            }
1704        }
1705        Ok(())
1706    }
1707
1708    #[cfg(debug_assertions)]
1709    fn check_vi_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1710        skip_if_unsupported!(kernel, test_name);
1711
1712        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1713        let candles = read_candles_from_csv(file_path)?;
1714
1715        let test_params = vec![
1716            ViParams::default(),
1717            ViParams { period: Some(1) },
1718            ViParams { period: Some(2) },
1719            ViParams { period: Some(5) },
1720            ViParams { period: Some(7) },
1721            ViParams { period: Some(10) },
1722            ViParams { period: Some(20) },
1723            ViParams { period: Some(30) },
1724            ViParams { period: Some(50) },
1725            ViParams { period: Some(100) },
1726            ViParams { period: Some(200) },
1727        ];
1728
1729        for (param_idx, params) in test_params.iter().enumerate() {
1730            let input = ViInput::from_candles(&candles, params.clone());
1731            let output = vi_with_kernel(&input, kernel)?;
1732
1733            for (i, &val) in output.plus.iter().enumerate() {
1734                if val.is_nan() {
1735                    continue;
1736                }
1737
1738                let bits = val.to_bits();
1739
1740                if bits == 0x11111111_11111111 {
1741                    panic!(
1742						"[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} in plus array \
1743						 with params: period={} (param set {})",
1744						test_name,
1745						val,
1746						bits,
1747						i,
1748						params.period.unwrap_or(14),
1749						param_idx
1750					);
1751                }
1752
1753                if bits == 0x22222222_22222222 {
1754                    panic!(
1755						"[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} in plus array \
1756						 with params: period={} (param set {})",
1757						test_name,
1758						val,
1759						bits,
1760						i,
1761						params.period.unwrap_or(14),
1762						param_idx
1763					);
1764                }
1765
1766                if bits == 0x33333333_33333333 {
1767                    panic!(
1768						"[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} in plus array \
1769						 with params: period={} (param set {})",
1770						test_name,
1771						val,
1772						bits,
1773						i,
1774						params.period.unwrap_or(14),
1775						param_idx
1776					);
1777                }
1778            }
1779
1780            for (i, &val) in output.minus.iter().enumerate() {
1781                if val.is_nan() {
1782                    continue;
1783                }
1784
1785                let bits = val.to_bits();
1786
1787                if bits == 0x11111111_11111111 {
1788                    panic!(
1789						"[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} in minus array \
1790						 with params: period={} (param set {})",
1791						test_name,
1792						val,
1793						bits,
1794						i,
1795						params.period.unwrap_or(14),
1796						param_idx
1797					);
1798                }
1799
1800                if bits == 0x22222222_22222222 {
1801                    panic!(
1802						"[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} in minus array \
1803						 with params: period={} (param set {})",
1804						test_name,
1805						val,
1806						bits,
1807						i,
1808						params.period.unwrap_or(14),
1809						param_idx
1810					);
1811                }
1812
1813                if bits == 0x33333333_33333333 {
1814                    panic!(
1815						"[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} in minus array \
1816						 with params: period={} (param set {})",
1817						test_name,
1818						val,
1819						bits,
1820						i,
1821						params.period.unwrap_or(14),
1822						param_idx
1823					);
1824                }
1825            }
1826        }
1827
1828        Ok(())
1829    }
1830
1831    #[cfg(not(debug_assertions))]
1832    fn check_vi_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1833        Ok(())
1834    }
1835
1836    #[cfg(feature = "proptest")]
1837    #[allow(clippy::float_cmp)]
1838    fn check_vi_property(
1839        test_name: &str,
1840        kernel: Kernel,
1841    ) -> Result<(), Box<dyn std::error::Error>> {
1842        use proptest::prelude::*;
1843        skip_if_unsupported!(kernel, test_name);
1844
1845        let strat = (2usize..=100).prop_flat_map(|period| {
1846            (period + 50..400).prop_flat_map(move |len| {
1847                (
1848                    prop::collection::vec(
1849                        (50.0f64..500.0f64).prop_filter("finite", |x| x.is_finite()),
1850                        len,
1851                    ),
1852                    prop::collection::vec((0.001f64..0.05f64), len),
1853                    prop::collection::vec((0.0f64..1.0f64), len),
1854                    Just(period),
1855                )
1856            })
1857        });
1858
1859        proptest::test_runner::TestRunner::default()
1860            .run(
1861                &strat,
1862                |(base_prices, volatilities, close_positions, period)| {
1863                    let mut high = Vec::with_capacity(base_prices.len());
1864                    let mut low = Vec::with_capacity(base_prices.len());
1865                    let mut close = Vec::with_capacity(base_prices.len());
1866
1867                    assert_eq!(base_prices.len(), volatilities.len());
1868                    assert_eq!(base_prices.len(), close_positions.len());
1869
1870                    for i in 0..base_prices.len() {
1871                        let price = base_prices[i];
1872                        let vol = volatilities[i];
1873                        let close_pos = close_positions[i];
1874
1875                        let range = price * vol;
1876                        let h = price + range * (0.3 + vol * 2.0);
1877                        let l = price - range * (0.3 + vol * 2.0);
1878                        let c = l + (h - l) * close_pos;
1879
1880                        high.push(h);
1881                        low.push(l);
1882                        close.push(c);
1883                    }
1884
1885                    let params = ViParams {
1886                        period: Some(period),
1887                    };
1888                    let input = ViInput::from_slices(&high, &low, &close, params.clone());
1889
1890                    let ViOutput {
1891                        plus: out_plus,
1892                        minus: out_minus,
1893                    } = vi_with_kernel(&input, kernel).unwrap();
1894                    let ViOutput {
1895                        plus: ref_plus,
1896                        minus: ref_minus,
1897                    } = vi_with_kernel(&input, Kernel::Scalar).unwrap();
1898
1899                    for i in 0..out_plus.len() {
1900                        if !out_plus[i].is_nan() {
1901                            prop_assert!(
1902                                out_plus[i] >= -1e-9,
1903                                "[{}] VI+ negative at idx {}: {}",
1904                                test_name,
1905                                i,
1906                                out_plus[i]
1907                            );
1908                        }
1909                        if !out_minus[i].is_nan() {
1910                            prop_assert!(
1911                                out_minus[i] >= -1e-9,
1912                                "[{}] VI- negative at idx {}: {}",
1913                                test_name,
1914                                i,
1915                                out_minus[i]
1916                            );
1917                        }
1918                    }
1919
1920                    let first_valid = (0..high.len())
1921                        .find(|&i| !high[i].is_nan() && !low[i].is_nan() && !close[i].is_nan())
1922                        .unwrap_or(0);
1923                    let warmup_end = first_valid + period - 1;
1924
1925                    for i in 0..warmup_end.min(out_plus.len()) {
1926                        prop_assert!(
1927                            out_plus[i].is_nan(),
1928                            "[{}] Expected NaN during warmup at idx {}, got {}",
1929                            test_name,
1930                            i,
1931                            out_plus[i]
1932                        );
1933                        prop_assert!(
1934                            out_minus[i].is_nan(),
1935                            "[{}] Expected NaN during warmup at idx {}, got {}",
1936                            test_name,
1937                            i,
1938                            out_minus[i]
1939                        );
1940                    }
1941
1942                    for i in warmup_end..out_plus.len() {
1943                        let plus_bits = out_plus[i].to_bits();
1944                        let ref_plus_bits = ref_plus[i].to_bits();
1945                        let minus_bits = out_minus[i].to_bits();
1946                        let ref_minus_bits = ref_minus[i].to_bits();
1947
1948                        if !out_plus[i].is_finite() || !ref_plus[i].is_finite() {
1949                            prop_assert!(
1950                                plus_bits == ref_plus_bits,
1951                                "[{}] VI+ finite/NaN mismatch at idx {}: {} vs {}",
1952                                test_name,
1953                                i,
1954                                out_plus[i],
1955                                ref_plus[i]
1956                            );
1957                        } else {
1958                            let ulp_diff = plus_bits.abs_diff(ref_plus_bits);
1959                            prop_assert!(
1960                                (out_plus[i] - ref_plus[i]).abs() <= 1e-9 || ulp_diff <= 4,
1961                                "[{}] VI+ mismatch at idx {}: {} vs {} (ULP={})",
1962                                test_name,
1963                                i,
1964                                out_plus[i],
1965                                ref_plus[i],
1966                                ulp_diff
1967                            );
1968                        }
1969
1970                        if !out_minus[i].is_finite() || !ref_minus[i].is_finite() {
1971                            prop_assert!(
1972                                minus_bits == ref_minus_bits,
1973                                "[{}] VI- finite/NaN mismatch at idx {}: {} vs {}",
1974                                test_name,
1975                                i,
1976                                out_minus[i],
1977                                ref_minus[i]
1978                            );
1979                        } else {
1980                            let ulp_diff = minus_bits.abs_diff(ref_minus_bits);
1981                            prop_assert!(
1982                                (out_minus[i] - ref_minus[i]).abs() <= 1e-9 || ulp_diff <= 4,
1983                                "[{}] VI- mismatch at idx {}: {} vs {} (ULP={})",
1984                                test_name,
1985                                i,
1986                                out_minus[i],
1987                                ref_minus[i],
1988                                ulp_diff
1989                            );
1990                        }
1991                    }
1992
1993                    if period == 1 {
1994                        if warmup_end < out_plus.len() {
1995                            prop_assert!(
1996                                out_plus[warmup_end].is_finite(),
1997                                "[{}] VI+ should be finite for period=1 at idx {}",
1998                                test_name,
1999                                warmup_end
2000                            );
2001                            prop_assert!(
2002                                out_minus[warmup_end].is_finite(),
2003                                "[{}] VI- should be finite for period=1 at idx {}",
2004                                test_name,
2005                                warmup_end
2006                            );
2007                        }
2008                    }
2009
2010                    if period <= 5 && warmup_end + 5 < high.len() && warmup_end >= period {
2011                        let idx = warmup_end;
2012
2013                        let mut tr_sum = 0.0;
2014                        let mut vp_sum = 0.0;
2015                        let mut vm_sum = 0.0;
2016
2017                        let first_idx = idx + 1 - period;
2018                        tr_sum += high[first_idx] - low[first_idx];
2019
2020                        for j in (first_idx + 1)..=idx {
2021                            let tr = (high[j] - low[j])
2022                                .max((high[j] - close[j - 1]).abs())
2023                                .max((low[j] - close[j - 1]).abs());
2024                            let vp = (high[j] - low[j - 1]).abs();
2025                            let vm = (low[j] - high[j - 1]).abs();
2026
2027                            tr_sum += tr;
2028                            vp_sum += vp;
2029                            vm_sum += vm;
2030                        }
2031
2032                        if tr_sum > 1e-10 {
2033                            let expected_plus = vp_sum / tr_sum;
2034                            let expected_minus = vm_sum / tr_sum;
2035
2036                            prop_assert!(
2037                                (out_plus[idx] - expected_plus).abs() < 1e-6,
2038                                "[{}] VI+ formula verification failed at idx {}: {} vs {}",
2039                                test_name,
2040                                idx,
2041                                out_plus[idx],
2042                                expected_plus
2043                            );
2044                            prop_assert!(
2045                                (out_minus[idx] - expected_minus).abs() < 1e-6,
2046                                "[{}] VI- formula verification failed at idx {}: {} vs {}",
2047                                test_name,
2048                                idx,
2049                                out_minus[idx],
2050                                expected_minus
2051                            );
2052                        }
2053                    }
2054
2055                    #[cfg(debug_assertions)]
2056                    {
2057                        for i in 0..out_plus.len() {
2058                            if !out_plus[i].is_nan() {
2059                                let bits = out_plus[i].to_bits();
2060                                prop_assert!(
2061                                    bits != 0x11111111_11111111
2062                                        && bits != 0x22222222_22222222
2063                                        && bits != 0x33333333_33333333,
2064                                    "[{}] Found poison value in VI+ at idx {}: 0x{:016X}",
2065                                    test_name,
2066                                    i,
2067                                    bits
2068                                );
2069                            }
2070                            if !out_minus[i].is_nan() {
2071                                let bits = out_minus[i].to_bits();
2072                                prop_assert!(
2073                                    bits != 0x11111111_11111111
2074                                        && bits != 0x22222222_22222222
2075                                        && bits != 0x33333333_33333333,
2076                                    "[{}] Found poison value in VI- at idx {}: 0x{:016X}",
2077                                    test_name,
2078                                    i,
2079                                    bits
2080                                );
2081                            }
2082                        }
2083                    }
2084
2085                    Ok(())
2086                },
2087            )
2088            .unwrap();
2089
2090        Ok(())
2091    }
2092
2093    macro_rules! generate_all_vi_tests {
2094        ($($test_fn:ident),*) => {
2095            paste::paste! {
2096                $(
2097                    #[test]
2098                    fn [<$test_fn _scalar_f64>]() {
2099                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
2100                    }
2101                )*
2102                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2103                $(
2104                    #[test]
2105                    fn [<$test_fn _avx2_f64>]() {
2106                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
2107                    }
2108                    #[test]
2109                    fn [<$test_fn _avx512_f64>]() {
2110                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
2111                    }
2112                )*
2113            }
2114        }
2115    }
2116    generate_all_vi_tests!(
2117        check_vi_partial_params,
2118        check_vi_accuracy,
2119        check_vi_default_candles,
2120        check_vi_zero_period,
2121        check_vi_period_exceeds_length,
2122        check_vi_very_small_dataset,
2123        check_vi_nan_handling,
2124        check_vi_no_poison
2125    );
2126
2127    #[cfg(feature = "proptest")]
2128    generate_all_vi_tests!(check_vi_property);
2129    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2130        skip_if_unsupported!(kernel, test);
2131        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2132        let c = read_candles_from_csv(file)?;
2133        let output = ViBatchBuilder::new().kernel(kernel).apply_candles(&c)?;
2134        let def = ViParams::default();
2135        let row = output.plus_for(&def).expect("default row missing");
2136        assert_eq!(row.len(), c.close.len());
2137        Ok(())
2138    }
2139
2140    #[cfg(debug_assertions)]
2141    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2142        skip_if_unsupported!(kernel, test);
2143
2144        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2145        let c = read_candles_from_csv(file)?;
2146
2147        let test_configs = vec![
2148            (2, 10, 2),
2149            (5, 25, 5),
2150            (20, 50, 10),
2151            (2, 5, 1),
2152            (14, 14, 0),
2153            (30, 60, 15),
2154            (50, 100, 25),
2155            (100, 200, 50),
2156        ];
2157
2158        for (cfg_idx, &(p_start, p_end, p_step)) in test_configs.iter().enumerate() {
2159            let output = ViBatchBuilder::new()
2160                .kernel(kernel)
2161                .period_range(p_start, p_end, p_step)
2162                .apply_candles(&c)?;
2163
2164            for (idx, &val) in output.plus.iter().enumerate() {
2165                if val.is_nan() {
2166                    continue;
2167                }
2168
2169                let bits = val.to_bits();
2170                let row = idx / output.cols;
2171                let col = idx % output.cols;
2172                let combo = &output.combos[row];
2173
2174                if bits == 0x11111111_11111111 {
2175                    panic!(
2176                        "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
2177						 at row {} col {} (flat index {}) in plus array with params: period={}",
2178                        test,
2179                        cfg_idx,
2180                        val,
2181                        bits,
2182                        row,
2183                        col,
2184                        idx,
2185                        combo.period.unwrap_or(14)
2186                    );
2187                }
2188
2189                if bits == 0x22222222_22222222 {
2190                    panic!(
2191                        "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
2192						 at row {} col {} (flat index {}) in plus array with params: period={}",
2193                        test,
2194                        cfg_idx,
2195                        val,
2196                        bits,
2197                        row,
2198                        col,
2199                        idx,
2200                        combo.period.unwrap_or(14)
2201                    );
2202                }
2203
2204                if bits == 0x33333333_33333333 {
2205                    panic!(
2206                        "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
2207						 at row {} col {} (flat index {}) in plus array with params: period={}",
2208                        test,
2209                        cfg_idx,
2210                        val,
2211                        bits,
2212                        row,
2213                        col,
2214                        idx,
2215                        combo.period.unwrap_or(14)
2216                    );
2217                }
2218            }
2219
2220            for (idx, &val) in output.minus.iter().enumerate() {
2221                if val.is_nan() {
2222                    continue;
2223                }
2224
2225                let bits = val.to_bits();
2226                let row = idx / output.cols;
2227                let col = idx % output.cols;
2228                let combo = &output.combos[row];
2229
2230                if bits == 0x11111111_11111111 {
2231                    panic!(
2232                        "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
2233						 at row {} col {} (flat index {}) in minus array with params: period={}",
2234                        test,
2235                        cfg_idx,
2236                        val,
2237                        bits,
2238                        row,
2239                        col,
2240                        idx,
2241                        combo.period.unwrap_or(14)
2242                    );
2243                }
2244
2245                if bits == 0x22222222_22222222 {
2246                    panic!(
2247                        "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
2248						 at row {} col {} (flat index {}) in minus array with params: period={}",
2249                        test,
2250                        cfg_idx,
2251                        val,
2252                        bits,
2253                        row,
2254                        col,
2255                        idx,
2256                        combo.period.unwrap_or(14)
2257                    );
2258                }
2259
2260                if bits == 0x33333333_33333333 {
2261                    panic!(
2262                        "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
2263						 at row {} col {} (flat index {}) in minus array with params: period={}",
2264                        test,
2265                        cfg_idx,
2266                        val,
2267                        bits,
2268                        row,
2269                        col,
2270                        idx,
2271                        combo.period.unwrap_or(14)
2272                    );
2273                }
2274            }
2275        }
2276
2277        Ok(())
2278    }
2279
2280    #[cfg(not(debug_assertions))]
2281    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2282        Ok(())
2283    }
2284    macro_rules! gen_batch_tests {
2285        ($fn_name:ident) => {
2286            paste::paste! {
2287                #[test] fn [<$fn_name _scalar>]()      {
2288                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
2289                }
2290                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2291                #[test] fn [<$fn_name _avx2>]()        {
2292                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
2293                }
2294                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2295                #[test] fn [<$fn_name _avx512>]()      {
2296                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
2297                }
2298                #[test] fn [<$fn_name _auto_detect>]() {
2299                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
2300                }
2301            }
2302        };
2303    }
2304    gen_batch_tests!(check_batch_default_row);
2305    gen_batch_tests!(check_batch_no_poison);
2306}
2307
/// Python entry point `vi(high, low, close, period, kernel=None)`.
///
/// Validates that the three series have equal length, computes VI through
/// `vi_into_slice` with the GIL released, and returns a dict with `"plus"` and
/// `"minus"` arrays of the input length.
///
/// Raises `ValueError` on a length mismatch, on an invalid `kernel` (via
/// `validate_kernel`), or when the VI computation reports an error.
#[cfg(feature = "python")]
#[pyfunction(name = "vi")]
#[pyo3(signature = (high, low, close, period, kernel=None))]
pub fn vi_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let high_s = high.as_slice()?;
    let low_s = low.as_slice()?;
    let close_s = close.as_slice()?;
    let len = high_s.len();

    // The three series must line up bar-for-bar.
    if low_s.len() != len || close_s.len() != len {
        return Err(PyValueError::new_err(format!(
            "Input data length mismatch: high={}, low={}, close={}",
            len,
            low_s.len(),
            close_s.len()
        )));
    }

    let input = ViInput::from_slices(
        high_s,
        low_s,
        close_s,
        ViParams {
            period: Some(period),
        },
    );
    let kern = validate_kernel(kernel, false)?;

    // Heavy lifting happens without holding the GIL.
    let computed: Result<(Vec<f64>, Vec<f64>), ViError> = py.allow_threads(|| {
        let mut plus = vec![0.0; len];
        let mut minus = vec![0.0; len];
        vi_into_slice(&mut plus, &mut minus, &input, kern)?;
        Ok((plus, minus))
    });
    let (plus, minus) = computed.map_err(|e| PyValueError::new_err(e.to_string()))?;

    let out = PyDict::new(py);
    out.set_item("plus", plus.into_pyarray(py))?;
    out.set_item("minus", minus.into_pyarray(py))?;
    Ok(out)
}
2351
/// Python entry point `vi_batch(high, low, close, period_range, kernel=None)`.
///
/// Expands `period_range` into parameter combos and fills two row-major
/// `rows x cols` matrices (`rows` = number of combos, `cols` = series length)
/// via `vi_batch_inner_into`, with the GIL released.
///
/// Returns a dict with `"plus"` and `"minus"` 2-D arrays and `"periods"` (one
/// period per row).
///
/// Raises `ValueError` when the input lengths differ, on an invalid kernel, on
/// `rows * cols` overflow, or when the batch computation fails.
#[cfg(feature = "python")]
#[pyfunction(name = "vi_batch")]
#[pyo3(signature = (high, low, close, period_range, kernel=None))]
pub fn vi_batch_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let h = high.as_slice()?;
    let l = low.as_slice()?;
    let c = close.as_slice()?;

    // Same validation as `vi_py`: `cols` below is taken from `high` alone, so
    // misaligned `low`/`close` must be rejected before any output is allocated.
    if h.len() != l.len() || h.len() != c.len() {
        return Err(PyValueError::new_err(format!(
            "Input data length mismatch: high={}, low={}, close={}",
            h.len(),
            l.len(),
            c.len()
        )));
    }

    let sweep = ViBatchRange {
        period: period_range,
    };
    let kern = validate_kernel(kernel, true)?;

    let combos = expand_grid(&sweep);
    let rows = combos.len();
    let cols = h.len();
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow in vi_batch_py"))?;

    // SAFETY: `PyArray1::new` with `is_fortran = false` allocates a contiguous but
    // uninitialized buffer. The arrays are freshly created and not yet visible to
    // Python, so no other reference can alias the mutable slices taken below.
    // NOTE(review): this relies on `vi_batch_inner_into(.., true, ..)` writing every
    // cell before the arrays are exposed — confirm against its implementation.
    let out_plus = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let out_minus = unsafe { PyArray1::<f64>::new(py, [total], false) };
    // SAFETY: see above — sole owners of contiguous, unshared buffers.
    let slice_plus = unsafe { out_plus.as_slice_mut()? };
    let slice_minus = unsafe { out_minus.as_slice_mut()? };

    py.allow_threads(|| {
        // Map the (batch) kernel request onto the per-row kernel used inside the batch.
        let simd = match kern {
            Kernel::Avx512Batch => Kernel::Avx512,
            Kernel::Avx2Batch => Kernel::Avx2,
            Kernel::ScalarBatch => Kernel::Scalar,
            // On auto-detect an AVX-512 result is deliberately downgraded to AVX2.
            // NOTE(review): rationale not visible here — confirm it is intended.
            Kernel::Auto => match detect_best_batch_kernel() {
                Kernel::Avx512Batch => Kernel::Avx2Batch,
                other => other,
            }
            .to_scalar_equivalent(),
            _ => Kernel::Scalar,
        };

        vi_batch_inner_into(h, l, c, &sweep, simd, true, slice_plus, slice_minus)
    })
    .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let d = PyDict::new(py);
    d.set_item("plus", out_plus.reshape((rows, cols))?)?;
    d.set_item("minus", out_minus.reshape((rows, cols))?)?;
    d.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap_or(14) as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    Ok(d)
}
2416
/// Python streaming wrapper around `ViStream` (exposed to Python as `ViStream`).
///
/// `ViStream::update` takes the current bar plus the previous bar's values, so this
/// wrapper buffers the last bar it was given. The previous bar is held as a single
/// tuple, so it is either fully known or absent — never partially set.
#[cfg(feature = "python")]
#[pyclass(name = "ViStream")]
pub struct ViStreamPy {
    stream: ViStream,
    /// Previous bar as `(high, low, close)`; `None` until the first `update` call.
    prev: Option<(f64, f64, f64)>,
}

#[cfg(feature = "python")]
#[pymethods]
impl ViStreamPy {
    /// Creates a stream for `period`; raises `ValueError` if `ViStream::try_new`
    /// rejects the parameters.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        let stream = ViStream::try_new(ViParams {
            period: Some(period),
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok(Self { stream, prev: None })
    }

    /// Feeds one bar and returns whatever `ViStream::update` yields.
    ///
    /// The very first call only records the bar and returns `None`. Every call
    /// afterwards stores the given bar as the new "previous" bar.
    fn update(&mut self, high: f64, low: f64, close: f64) -> Option<(f64, f64)> {
        let result = match self.prev {
            // Argument order (prev low before prev high) matches the original call.
            Some((ph, pl, pc)) => self.stream.update(high, low, close, pl, ph, pc),
            None => None,
        };
        self.prev = Some((high, low, close));
        result
    }
}
2461
/// Converts a batch kernel selector into its single-series counterpart.
#[cfg(feature = "python")]
trait BatchToScalar {
    /// Returns the non-batch kernel with the same SIMD level. `Auto` maps to
    /// `Scalar`; kernels without a batch form are returned unchanged.
    fn to_scalar_equivalent(self) -> Kernel;
}
#[cfg(feature = "python")]
impl BatchToScalar for Kernel {
    fn to_scalar_equivalent(self) -> Kernel {
        match self {
            Kernel::ScalarBatch | Kernel::Auto => Kernel::Scalar,
            Kernel::Avx2Batch => Kernel::Avx2,
            Kernel::Avx512Batch => Kernel::Avx512,
            unchanged => unchanged,
        }
    }
}
2478
2479#[cfg(all(feature = "python", feature = "cuda"))]
2480use crate::cuda::cuda_available;
2481#[cfg(all(feature = "python", feature = "cuda"))]
2482use crate::cuda::vi_wrapper::CudaVi;
2483#[cfg(all(feature = "python", feature = "cuda"))]
2484use crate::utilities::dlpack_cuda::DeviceArrayF32Py;
2485
2486#[cfg(all(feature = "python", feature = "cuda"))]
2487#[pyfunction(name = "vi_cuda_batch_dev")]
2488#[pyo3(signature = (high_f32, low_f32, close_f32, period_range, device_id=0))]
2489pub fn vi_cuda_batch_dev_py<'py>(
2490    py: Python<'py>,
2491    high_f32: numpy::PyReadonlyArray1<'py, f32>,
2492    low_f32: numpy::PyReadonlyArray1<'py, f32>,
2493    close_f32: numpy::PyReadonlyArray1<'py, f32>,
2494    period_range: (usize, usize, usize),
2495    device_id: usize,
2496) -> PyResult<Bound<'py, PyDict>> {
2497    use numpy::IntoPyArray;
2498    if !cuda_available() {
2499        return Err(PyValueError::new_err("CUDA not available"));
2500    }
2501    let h = high_f32.as_slice()?;
2502    let l = low_f32.as_slice()?;
2503    let c = close_f32.as_slice()?;
2504    if h.len() != l.len() || h.len() != c.len() {
2505        return Err(PyValueError::new_err("Input data length mismatch"));
2506    }
2507    let sweep = ViBatchRange {
2508        period: period_range,
2509    };
2510    let ((pair, combos), ctx, dev_id) = py.allow_threads(|| {
2511        let cuda = CudaVi::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
2512        let ctx = cuda.context_arc();
2513        let dev_id = cuda.device_id();
2514        cuda.vi_batch_dev(h, l, c, &sweep)
2515            .map(|res| (res, ctx, dev_id))
2516            .map_err(|e| PyValueError::new_err(e.to_string()))
2517    })?;
2518    let dict = PyDict::new(py);
2519    dict.set_item(
2520        "plus",
2521        Py::new(
2522            py,
2523            DeviceArrayF32Py {
2524                inner: pair.a,
2525                _ctx: Some(ctx.clone()),
2526                device_id: Some(dev_id),
2527            },
2528        )?,
2529    )?;
2530    dict.set_item(
2531        "minus",
2532        Py::new(
2533            py,
2534            DeviceArrayF32Py {
2535                inner: pair.b,
2536                _ctx: Some(ctx),
2537                device_id: Some(dev_id),
2538            },
2539        )?,
2540    )?;
2541    dict.set_item("rows", combos.len())?;
2542    dict.set_item("cols", h.len())?;
2543    dict.set_item(
2544        "periods",
2545        combos
2546            .iter()
2547            .map(|p| p.period.unwrap_or(14) as u64)
2548            .collect::<Vec<_>>()
2549            .into_pyarray(py),
2550    )?;
2551    Ok(dict)
2552}
2553
2554#[cfg(all(feature = "python", feature = "cuda"))]
2555#[pyfunction(name = "vi_cuda_many_series_one_param_dev")]
2556#[pyo3(signature = (high_tm_f32, low_tm_f32, close_tm_f32, period, device_id=0))]
2557pub fn vi_cuda_many_series_one_param_dev_py<'py>(
2558    py: Python<'py>,
2559    high_tm_f32: numpy::PyReadonlyArray2<'py, f32>,
2560    low_tm_f32: numpy::PyReadonlyArray2<'py, f32>,
2561    close_tm_f32: numpy::PyReadonlyArray2<'py, f32>,
2562    period: usize,
2563    device_id: usize,
2564) -> PyResult<Bound<'py, PyDict>> {
2565    use numpy::PyUntypedArrayMethods;
2566    if !cuda_available() {
2567        return Err(PyValueError::new_err("CUDA not available"));
2568    }
2569    let shape = high_tm_f32.shape();
2570    if shape.len() != 2 {
2571        return Err(PyValueError::new_err("expected 2D array for high"));
2572    }
2573    if low_tm_f32.shape() != shape || close_tm_f32.shape() != shape {
2574        return Err(PyValueError::new_err(
2575            "input arrays must share the same shape",
2576        ));
2577    }
2578    let rows = shape[0];
2579    let cols = shape[1];
2580    let h = high_tm_f32.as_slice()?;
2581    let l = low_tm_f32.as_slice()?;
2582    let c = close_tm_f32.as_slice()?;
2583    let params = ViParams {
2584        period: Some(period),
2585    };
2586    let (pair, ctx, dev_id) = py.allow_threads(|| {
2587        let cuda = CudaVi::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
2588        let ctx = cuda.context_arc();
2589        let dev_id = cuda.device_id();
2590        cuda.vi_many_series_one_param_time_major_dev(h, l, c, cols, rows, &params)
2591            .map(|res| (res, ctx, dev_id))
2592            .map_err(|e| PyValueError::new_err(e.to_string()))
2593    })?;
2594    let dict = PyDict::new(py);
2595    dict.set_item(
2596        "plus",
2597        Py::new(
2598            py,
2599            DeviceArrayF32Py {
2600                inner: pair.a,
2601                _ctx: Some(ctx.clone()),
2602                device_id: Some(dev_id),
2603            },
2604        )?,
2605    )?;
2606    dict.set_item(
2607        "minus",
2608        Py::new(
2609            py,
2610            DeviceArrayF32Py {
2611                inner: pair.b,
2612                _ctx: Some(ctx),
2613                device_id: Some(dev_id),
2614            },
2615        )?,
2616    )?;
2617    dict.set_item("rows", rows)?;
2618    dict.set_item("cols", cols)?;
2619    dict.set_item("period", period)?;
2620    Ok(dict)
2621}