Skip to main content

vector_ta/indicators/
var.rs

1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::PyDict;
9#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
10use serde::{Deserialize, Serialize};
11#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
12use wasm_bindgen::prelude::*;
13
14use crate::utilities::data_loader::{source_type, Candles};
15use crate::utilities::enums::Kernel;
16use crate::utilities::helpers::{
17    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
18    make_uninit_matrix,
19};
20#[cfg(feature = "python")]
21use crate::utilities::kernel_validation::validate_kernel;
22#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
23use core::arch::x86_64::*;
24#[cfg(not(target_arch = "wasm32"))]
25use rayon::prelude::*;
26use std::convert::AsRef;
27use std::error::Error;
28use thiserror::Error;
29
30#[derive(Debug, Clone)]
31pub enum VarData<'a> {
32    Candles {
33        candles: &'a Candles,
34        source: &'a str,
35    },
36    Slice(&'a [f64]),
37}
38
39impl<'a> AsRef<[f64]> for VarInput<'a> {
40    #[inline(always)]
41    fn as_ref(&self) -> &[f64] {
42        match &self.data {
43            VarData::Slice(slice) => slice,
44            VarData::Candles { candles, source } => source_type(candles, source),
45        }
46    }
47}
48
/// Result of a single-series variance run.
#[derive(Debug, Clone)]
pub struct VarOutput {
    /// One value per input sample; positions before the first full window are the
    /// NaN warmup prefix.
    pub values: Vec<f64>,
}
53
/// Tunable parameters of the variance indicator.
///
/// Unset fields fall back to the defaults applied by `VarInput::get_period` (14) and
/// `VarInput::get_nbdev` (1.0).
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct VarParams {
    /// Rolling window length in samples.
    pub period: Option<usize>,
    /// Multiplier; every output is scaled by `nbdev * nbdev`.
    pub nbdev: Option<f64>,
}
63
64impl Default for VarParams {
65    fn default() -> Self {
66        Self {
67            period: Some(14),
68            nbdev: Some(1.0),
69        }
70    }
71}
72
73#[derive(Debug, Clone)]
74pub struct VarInput<'a> {
75    pub data: VarData<'a>,
76    pub params: VarParams,
77}
78
79impl<'a> VarInput<'a> {
80    #[inline]
81    pub fn from_candles(c: &'a Candles, s: &'a str, p: VarParams) -> Self {
82        Self {
83            data: VarData::Candles {
84                candles: c,
85                source: s,
86            },
87            params: p,
88        }
89    }
90    #[inline]
91    pub fn from_slice(sl: &'a [f64], p: VarParams) -> Self {
92        Self {
93            data: VarData::Slice(sl),
94            params: p,
95        }
96    }
97    #[inline]
98    pub fn with_default_candles(c: &'a Candles) -> Self {
99        Self::from_candles(c, "close", VarParams::default())
100    }
101    #[inline]
102    pub fn get_period(&self) -> usize {
103        self.params.period.unwrap_or(14)
104    }
105    #[inline]
106    pub fn get_nbdev(&self) -> f64 {
107        self.params.nbdev.unwrap_or(1.0)
108    }
109}
110
111#[derive(Copy, Clone, Debug)]
112pub struct VarBuilder {
113    period: Option<usize>,
114    nbdev: Option<f64>,
115    kernel: Kernel,
116}
117
118impl Default for VarBuilder {
119    fn default() -> Self {
120        Self {
121            period: None,
122            nbdev: None,
123            kernel: Kernel::Auto,
124        }
125    }
126}
127
128impl VarBuilder {
129    #[inline(always)]
130    pub fn new() -> Self {
131        Self::default()
132    }
133    #[inline(always)]
134    pub fn period(mut self, n: usize) -> Self {
135        self.period = Some(n);
136        self
137    }
138    #[inline(always)]
139    pub fn nbdev(mut self, x: f64) -> Self {
140        self.nbdev = Some(x);
141        self
142    }
143    #[inline(always)]
144    pub fn kernel(mut self, k: Kernel) -> Self {
145        self.kernel = k;
146        self
147    }
148    #[inline(always)]
149    pub fn apply(self, c: &Candles) -> Result<VarOutput, VarError> {
150        let p = VarParams {
151            period: self.period,
152            nbdev: self.nbdev,
153        };
154        let i = VarInput::from_candles(c, "close", p);
155        var_with_kernel(&i, self.kernel)
156    }
157    #[inline(always)]
158    pub fn apply_slice(self, d: &[f64]) -> Result<VarOutput, VarError> {
159        let p = VarParams {
160            period: self.period,
161            nbdev: self.nbdev,
162        };
163        let i = VarInput::from_slice(d, p);
164        var_with_kernel(&i, self.kernel)
165    }
166    #[inline(always)]
167    pub fn into_stream(self) -> Result<VarStream, VarError> {
168        let p = VarParams {
169            period: self.period,
170            nbdev: self.nbdev,
171        };
172        VarStream::try_new(p)
173    }
174}
175
176#[derive(Debug, Error)]
177pub enum VarError {
178    #[error("var: input data is empty (All values are NaN).")]
179    EmptyInputData,
180    #[error("var: All values are NaN.")]
181    AllValuesNaN,
182    #[error("var: Invalid period: period = {period}, data length = {data_len}")]
183    InvalidPeriod { period: usize, data_len: usize },
184    #[error("var: Not enough valid data: needed = {needed}, valid = {valid}")]
185    NotEnoughValidData { needed: usize, valid: usize },
186    #[error("var: output length mismatch: expected = {expected}, got = {got}")]
187    OutputLengthMismatch { expected: usize, got: usize },
188    #[error("var: invalid range: start = {start}, end = {end}, step = {step}")]
189    InvalidRange { start: f64, end: f64, step: f64 },
190    #[error("var: invalid kernel for batch: {0:?}")]
191    InvalidKernelForBatch(Kernel),
192    #[error("var: nbdev is NaN or infinite: {nbdev}")]
193    InvalidNbdev { nbdev: f64 },
194}
195
196#[inline]
197pub fn var(input: &VarInput) -> Result<VarOutput, VarError> {
198    var_with_kernel(input, Kernel::Auto)
199}
200
201pub fn var_with_kernel(input: &VarInput, kernel: Kernel) -> Result<VarOutput, VarError> {
202    let data: &[f64] = input.as_ref();
203    let len = data.len();
204    if len == 0 {
205        return Err(VarError::EmptyInputData);
206    }
207    let first = data
208        .iter()
209        .position(|x| !x.is_nan())
210        .ok_or(VarError::AllValuesNaN)?;
211    let period = input.get_period();
212    let nbdev = input.get_nbdev();
213
214    if period == 0 || period > len {
215        return Err(VarError::InvalidPeriod {
216            period,
217            data_len: len,
218        });
219    }
220    if (len - first) < period {
221        return Err(VarError::NotEnoughValidData {
222            needed: period,
223            valid: len - first,
224        });
225    }
226    if nbdev.is_nan() || nbdev.is_infinite() {
227        return Err(VarError::InvalidNbdev { nbdev });
228    }
229
230    let chosen = match kernel {
231        Kernel::Auto => match detect_best_kernel() {
232            Kernel::Avx512 => Kernel::Avx2,
233            other => other,
234        },
235        other => other,
236    };
237
238    let mut out = alloc_with_nan_prefix(len, first + period - 1);
239
240    unsafe {
241        match chosen {
242            Kernel::Scalar | Kernel::ScalarBatch => {
243                var_scalar(data, period, first, nbdev, &mut out)?
244            }
245            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
246            Kernel::Avx2 | Kernel::Avx2Batch => var_avx2(data, period, first, nbdev, &mut out)?,
247            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
248            Kernel::Avx512 | Kernel::Avx512Batch => {
249                var_avx512(data, period, first, nbdev, &mut out)?
250            }
251            _ => unreachable!(),
252        }
253    }
254
255    Ok(VarOutput { values: out })
256}
257
258#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
259pub fn var_into(input: &VarInput, out: &mut [f64]) -> Result<(), VarError> {
260    let data: &[f64] = input.as_ref();
261    let len = data.len();
262    if len == 0 {
263        return Err(VarError::EmptyInputData);
264    }
265    if out.len() != len {
266        return Err(VarError::OutputLengthMismatch {
267            expected: len,
268            got: out.len(),
269        });
270    }
271    let first = data
272        .iter()
273        .position(|x| !x.is_nan())
274        .ok_or(VarError::AllValuesNaN)?;
275    let period = input.get_period();
276    let nbdev = input.get_nbdev();
277
278    if period == 0 || period > len {
279        return Err(VarError::InvalidPeriod {
280            period,
281            data_len: len,
282        });
283    }
284    if (len - first) < period {
285        return Err(VarError::NotEnoughValidData {
286            needed: period,
287            valid: len - first,
288        });
289    }
290    if nbdev.is_nan() || nbdev.is_infinite() {
291        return Err(VarError::InvalidNbdev { nbdev });
292    }
293
294    let warmup_end = (first + period - 1).min(len);
295    for v in &mut out[..warmup_end] {
296        *v = f64::from_bits(0x7ff8_0000_0000_0000);
297    }
298
299    let chosen = match detect_best_kernel() {
300        Kernel::Avx512 => Kernel::Avx2,
301        other => other,
302    };
303
304    unsafe {
305        match chosen {
306            Kernel::Scalar | Kernel::ScalarBatch => var_scalar(data, period, first, nbdev, out)?,
307            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
308            Kernel::Avx2 | Kernel::Avx2Batch => var_avx2(data, period, first, nbdev, out)?,
309            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
310            Kernel::Avx512 | Kernel::Avx512Batch => var_avx512(data, period, first, nbdev, out)?,
311            _ => unreachable!(),
312        }
313    }
314
315    Ok(())
316}
317
318#[inline(always)]
319pub fn var_scalar(
320    data: &[f64],
321    period: usize,
322    first: usize,
323    nbdev: f64,
324    out: &mut [f64],
325) -> Result<(), VarError> {
326    let len = data.len();
327    let nbdev2 = nbdev * nbdev;
328    let inv_p = 1.0 / (period as f64);
329
330    let mut sum = 0.0f64;
331    let mut sum_sq = 0.0f64;
332    let init = &data[first..first + period];
333    let mut it = init.chunks_exact(4);
334    for c in &mut it {
335        let x0 = c[0];
336        let x1 = c[1];
337        let x2 = c[2];
338        let x3 = c[3];
339        sum += x0 + x1 + x2 + x3;
340        sum_sq += x0 * x0 + x1 * x1 + x2 * x2 + x3 * x3;
341    }
342    for &x in it.remainder() {
343        sum += x;
344        sum_sq += x * x;
345    }
346
347    let idx0 = first + period - 1;
348    let mean0 = sum * inv_p;
349    out[idx0] = (sum_sq * inv_p - mean0 * mean0) * nbdev2;
350
351    unsafe {
352        let mut out_ptr = out.as_mut_ptr().add(idx0 + 1);
353        let mut in_new = data.as_ptr().add(first + period);
354        let mut in_old = data.as_ptr().add(first);
355        let end = data.as_ptr().add(len);
356
357        while in_new < end {
358            let new = *in_new;
359            let old = *in_old;
360
361            sum += new - old;
362            sum_sq += new * new - old * old;
363
364            let mean = sum * inv_p;
365            *out_ptr = (sum_sq * inv_p - mean * mean) * nbdev2;
366
367            in_new = in_new.add(1);
368            in_old = in_old.add(1);
369            out_ptr = out_ptr.add(1);
370        }
371    }
372
373    Ok(())
374}
375
/// AVX2-assisted kernel: rolling population variance of `data` over `period` samples
/// starting at `first`, scaled by `nbdev²`.
///
/// Only the seed-window sums use 256-bit loads; the sliding update is scalar. Writes
/// `out[first + period - 1 .. data.len()]` and leaves the warmup prefix untouched.
///
/// NOTE(review): this is a safe fn that executes AVX instructions; it relies on callers
/// only selecting it on CPUs that support them.
///
/// # Errors
/// - [`VarError::OutputLengthMismatch`] if `out` is shorter than `data`. Previously this
///   wrote past the end of `out`.
/// - [`VarError::InvalidPeriod`] if `period == 0` or the first window does not fit.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn var_avx2(
    data: &[f64],
    period: usize,
    first: usize,
    nbdev: f64,
    out: &mut [f64],
) -> Result<(), VarError> {
    let len = data.len();
    if out.len() < len {
        return Err(VarError::OutputLengthMismatch {
            expected: len,
            got: out.len(),
        });
    }
    let end = match first.checked_add(period) {
        Some(end) if period > 0 && end <= len => end,
        _ => {
            return Err(VarError::InvalidPeriod {
                period,
                data_len: len,
            })
        }
    };

    let nbdev2 = nbdev * nbdev;
    let inv_p = 1.0 / (period as f64);

    let mut sum = 0.0f64;
    let mut sum_sq = 0.0f64;
    let mut lanes = [0.0f64; 4];
    let mut idx = first;
    while idx + 4 <= end {
        // SAFETY: `idx + 4 <= end <= data.len()` (validated above), so the unaligned
        // 4-lane load reads only elements of `data`; `lanes` holds exactly 4 f64.
        unsafe {
            let v = _mm256_loadu_pd(data.as_ptr().add(idx));
            _mm256_storeu_pd(lanes.as_mut_ptr(), v);
            sum += lanes[0] + lanes[1] + lanes[2] + lanes[3];
            _mm256_storeu_pd(lanes.as_mut_ptr(), _mm256_mul_pd(v, v));
            sum_sq += lanes[0] + lanes[1] + lanes[2] + lanes[3];
        }
        idx += 4;
    }
    for &x in &data[idx..end] {
        sum += x;
        sum_sq += x * x;
    }

    let mean0 = sum * inv_p;
    out[end - 1] = (sum_sq * inv_p - mean0 * mean0) * nbdev2;

    // Slide the window: `new` enters, `old` (period samples back) leaves.
    let entering = &data[end..];
    let leaving = &data[first..len - period];
    for ((dst, &new), &old) in out[end..len].iter_mut().zip(entering).zip(leaving) {
        sum += new - old;
        sum_sq += new * new - old * old;
        let mean = sum * inv_p;
        *dst = (sum_sq * inv_p - mean * mean) * nbdev2;
    }

    Ok(())
}
451
/// AVX-512-assisted kernel: rolling population variance of `data` over `period`
/// samples starting at `first`, scaled by `nbdev²`.
///
/// The seed-window sums use 512-bit loads, then 256-bit loads, then scalar for the
/// tail. The sliding update is scalar. Writes `out[first + period - 1 .. data.len()]`
/// and leaves the warmup prefix untouched.
///
/// NOTE(review): this is a safe fn that executes AVX-512 instructions; it relies on
/// callers only selecting it on CPUs that support them.
///
/// # Errors
/// - [`VarError::OutputLengthMismatch`] if `out` is shorter than `data`. Previously this
///   wrote past the end of `out`.
/// - [`VarError::InvalidPeriod`] if `period == 0` or the first window does not fit.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn var_avx512(
    data: &[f64],
    period: usize,
    first: usize,
    nbdev: f64,
    out: &mut [f64],
) -> Result<(), VarError> {
    use core::arch::x86_64::*;

    let len = data.len();
    if out.len() < len {
        return Err(VarError::OutputLengthMismatch {
            expected: len,
            got: out.len(),
        });
    }
    let end = match first.checked_add(period) {
        Some(end) if period > 0 && end <= len => end,
        _ => {
            return Err(VarError::InvalidPeriod {
                period,
                data_len: len,
            })
        }
    };

    let nbdev2 = nbdev * nbdev;
    let inv_p = 1.0 / (period as f64);

    let mut sum = 0.0f64;
    let mut sum_sq = 0.0f64;
    let mut idx = first;

    let mut lanes8 = [0.0f64; 8];
    while idx + 8 <= end {
        // SAFETY: `idx + 8 <= end <= data.len()`, so the 8-lane unaligned load stays
        // inside `data`; `lanes8` holds exactly 8 f64 for the unaligned stores.
        unsafe {
            let v = _mm512_loadu_pd(data.as_ptr().add(idx));
            _mm512_storeu_pd(lanes8.as_mut_ptr(), v);
            sum += lanes8[0] + lanes8[1] + lanes8[2] + lanes8[3];
            sum += lanes8[4] + lanes8[5] + lanes8[6] + lanes8[7];
            _mm512_storeu_pd(lanes8.as_mut_ptr(), _mm512_mul_pd(v, v));
            sum_sq += lanes8[0] + lanes8[1] + lanes8[2] + lanes8[3];
            sum_sq += lanes8[4] + lanes8[5] + lanes8[6] + lanes8[7];
        }
        idx += 8;
    }

    let mut lanes4 = [0.0f64; 4];
    while idx + 4 <= end {
        // SAFETY: `idx + 4 <= end <= data.len()`; `lanes4` holds exactly 4 f64.
        unsafe {
            let v4 = _mm256_loadu_pd(data.as_ptr().add(idx));
            _mm256_storeu_pd(lanes4.as_mut_ptr(), v4);
            sum += lanes4[0] + lanes4[1] + lanes4[2] + lanes4[3];
            _mm256_storeu_pd(lanes4.as_mut_ptr(), _mm256_mul_pd(v4, v4));
            sum_sq += lanes4[0] + lanes4[1] + lanes4[2] + lanes4[3];
        }
        idx += 4;
    }

    for &x in &data[idx..end] {
        sum += x;
        sum_sq += x * x;
    }

    let mean0 = sum * inv_p;
    out[end - 1] = (sum_sq * inv_p - mean0 * mean0) * nbdev2;

    // Slide the window: `new` enters, `old` (period samples back) leaves.
    let entering = &data[end..];
    let leaving = &data[first..len - period];
    for ((dst, &new), &old) in out[end..len].iter_mut().zip(entering).zip(leaving) {
        sum += new - old;
        sum_sq += new * new - old * old;
        let mean = sum * inv_p;
        *dst = (sum_sq * inv_p - mean * mean) * nbdev2;
    }

    Ok(())
}
554
555#[inline]
556pub fn var_into_slice(dst: &mut [f64], input: &VarInput, kern: Kernel) -> Result<(), VarError> {
557    let data: &[f64] = input.as_ref();
558    let len = data.len();
559    if len == 0 {
560        return Err(VarError::EmptyInputData);
561    }
562    if dst.len() != len {
563        return Err(VarError::OutputLengthMismatch {
564            expected: len,
565            got: dst.len(),
566        });
567    }
568    let first = data
569        .iter()
570        .position(|x| !x.is_nan())
571        .ok_or(VarError::AllValuesNaN)?;
572    let period = input.get_period();
573    let nbdev = input.get_nbdev();
574
575    if period == 0 || period > len {
576        return Err(VarError::InvalidPeriod {
577            period,
578            data_len: len,
579        });
580    }
581    if (len - first) < period {
582        return Err(VarError::NotEnoughValidData {
583            needed: period,
584            valid: len - first,
585        });
586    }
587    if nbdev.is_nan() || nbdev.is_infinite() {
588        return Err(VarError::InvalidNbdev { nbdev });
589    }
590    let chosen = match kern {
591        Kernel::Auto => Kernel::Scalar,
592        other => other,
593    };
594
595    unsafe {
596        match chosen {
597            Kernel::Scalar | Kernel::ScalarBatch => {
598                var_scalar(data, period, first, nbdev, dst)?;
599            }
600            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
601            Kernel::Avx2 | Kernel::Avx2Batch => {
602                var_avx2(data, period, first, nbdev, dst)?;
603            }
604            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
605            Kernel::Avx512 | Kernel::Avx512Batch => {
606                var_avx512(data, period, first, nbdev, dst)?;
607            }
608            _ => unreachable!(),
609        }
610    }
611
612    let warmup_end = first + period - 1;
613    for v in &mut dst[..warmup_end] {
614        *v = f64::NAN;
615    }
616
617    Ok(())
618}
619
/// Scalar row kernel for the batch path.
///
/// Computes the rolling population variance of `data` over `period` samples starting
/// at `first`, scaled by `nbdev²`. Only `out[first + period - 1 .. data.len()]` is
/// written; the warmup prefix keeps whatever the caller put there. `_stride` is
/// accepted for signature parity with the SIMD row kernels and is unused.
///
/// # Safety
/// Callers must guarantee `period >= 1`, `first + period <= data.len()` and
/// `out.len() >= data.len()`.
#[inline(always)]
pub unsafe fn var_row_scalar(
    data: &[f64],
    first: usize,
    period: usize,
    _stride: usize,
    nbdev: f64,
    out: &mut [f64],
) {
    let n = data.len();
    let scale = nbdev * nbdev;
    let inv_len = 1.0 / (period as f64);
    let seed_end = first + period;

    // Seed the running sums from the first window, four samples per step.
    let (mut acc, mut acc_sq) = (0.0f64, 0.0f64);
    let mut groups = data[first..seed_end].chunks_exact(4);
    for g in groups.by_ref() {
        acc += g[0] + g[1] + g[2] + g[3];
        acc_sq += g[0] * g[0] + g[1] * g[1] + g[2] * g[2] + g[3] * g[3];
    }
    for &v in groups.remainder() {
        acc += v;
        acc_sq += v * v;
    }

    let variance = |s: f64, s2: f64| {
        let m = s * inv_len;
        (s2 * inv_len - m * m) * scale
    };

    out[seed_end - 1] = variance(acc, acc_sq);

    // Slide the window: add the entering sample, drop the one `period` back.
    let incoming = &data[seed_end..n];
    let outgoing = &data[first..n - period];
    for ((slot, &new), &old) in out[seed_end..n].iter_mut().zip(incoming).zip(outgoing) {
        acc += new - old;
        acc_sq += new * new - old * old;
        *slot = variance(acc, acc_sq);
    }
}
690
/// AVX2 row kernel for the batch path.
///
/// This currently forwards to [`var_row_scalar`], so its output is identical to the
/// scalar row kernel.
///
/// # Safety
/// Same contract as [`var_row_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn var_row_avx2(
    data: &[f64],
    first: usize,
    period: usize,
    stride: usize,
    nbdev: f64,
    out: &mut [f64],
) {
    var_row_scalar(data, first, period, stride, nbdev, out);
}
703
/// AVX-512 row kernel for the batch path.
///
/// Windows of at most 32 samples go to the "short" variant and longer ones go to the
/// "long" variant.
///
/// # Safety
/// Same contract as [`var_row_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn var_row_avx512(
    data: &[f64],
    first: usize,
    period: usize,
    stride: usize,
    nbdev: f64,
    out: &mut [f64],
) {
    match period {
        0..=32 => var_row_avx512_short(data, first, period, stride, nbdev, out),
        _ => var_row_avx512_long(data, first, period, stride, nbdev, out),
    }
}
720
/// AVX-512 row kernel variant for windows of at most 32 samples.
///
/// This currently forwards to [`var_row_scalar`].
///
/// # Safety
/// Same contract as [`var_row_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn var_row_avx512_short(
    data: &[f64],
    first: usize,
    period: usize,
    stride: usize,
    nbdev: f64,
    out: &mut [f64],
) {
    var_row_scalar(data, first, period, stride, nbdev, out);
}
733
/// AVX-512 row kernel variant for windows longer than 32 samples.
///
/// This currently forwards to [`var_row_scalar`].
///
/// # Safety
/// Same contract as [`var_row_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn var_row_avx512_long(
    data: &[f64],
    first: usize,
    period: usize,
    stride: usize,
    nbdev: f64,
    out: &mut [f64],
) {
    var_row_scalar(data, first, period, stride, nbdev, out);
}
746
/// Parameter sweep for batch variance. Each axis is `(start, end, step)`.
#[derive(Clone, Debug)]
pub struct VarBatchRange {
    /// Period axis; a zero `step` or `start == end` yields just `start`.
    pub period: (usize, usize, usize),
    /// nbdev axis; a (near-)zero `step` or `start ≈ end` yields just `start`.
    pub nbdev: (f64, f64, f64),
}
752
753impl Default for VarBatchRange {
754    fn default() -> Self {
755        Self {
756            period: (14, 263, 1),
757            nbdev: (1.0, 1.0, 0.0),
758        }
759    }
760}
761
762#[derive(Clone, Debug, Default)]
763pub struct VarBatchBuilder {
764    range: VarBatchRange,
765    kernel: Kernel,
766}
767
768impl VarBatchBuilder {
769    pub fn new() -> Self {
770        Self::default()
771    }
772    pub fn kernel(mut self, k: Kernel) -> Self {
773        self.kernel = k;
774        self
775    }
776    #[inline]
777    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
778        self.range.period = (start, end, step);
779        self
780    }
781    #[inline]
782    pub fn period_static(mut self, p: usize) -> Self {
783        self.range.period = (p, p, 0);
784        self
785    }
786    #[inline]
787    pub fn nbdev_range(mut self, start: f64, end: f64, step: f64) -> Self {
788        self.range.nbdev = (start, end, step);
789        self
790    }
791    #[inline]
792    pub fn nbdev_static(mut self, x: f64) -> Self {
793        self.range.nbdev = (x, x, 0.0);
794        self
795    }
796    pub fn apply_slice(self, data: &[f64]) -> Result<VarBatchOutput, VarError> {
797        var_batch_with_kernel(data, &self.range, self.kernel)
798    }
799    pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<VarBatchOutput, VarError> {
800        VarBatchBuilder::new().kernel(k).apply_slice(data)
801    }
802    pub fn apply_candles(self, c: &Candles, src: &str) -> Result<VarBatchOutput, VarError> {
803        let slice = source_type(c, src);
804        self.apply_slice(slice)
805    }
806    pub fn with_default_candles(c: &Candles) -> Result<VarBatchOutput, VarError> {
807        VarBatchBuilder::new()
808            .kernel(Kernel::Auto)
809            .apply_candles(c, "close")
810    }
811}
812
813#[derive(Clone, Debug)]
814pub struct VarBatchOutput {
815    pub values: Vec<f64>,
816    pub combos: Vec<VarParams>,
817    pub rows: usize,
818    pub cols: usize,
819}
820impl VarBatchOutput {
821    pub fn row_for_params(&self, p: &VarParams) -> Option<usize> {
822        self.combos.iter().position(|c| {
823            c.period.unwrap_or(14) == p.period.unwrap_or(14)
824                && (c.nbdev.unwrap_or(1.0) - p.nbdev.unwrap_or(1.0)).abs() < 1e-12
825        })
826    }
827    pub fn values_for(&self, p: &VarParams) -> Option<&[f64]> {
828        self.row_for_params(p).map(|row| {
829            let start = row * self.cols;
830            &self.values[start..start + self.cols]
831        })
832    }
833}
834
835#[inline(always)]
836fn expand_grid(r: &VarBatchRange) -> Vec<VarParams> {
837    fn axis_usize((start, end, step): (usize, usize, usize)) -> Vec<usize> {
838        if step == 0 || start == end {
839            return vec![start];
840        }
841        let mut v = Vec::new();
842        if start < end {
843            let mut x = start;
844            while x <= end {
845                v.push(x);
846                match x.checked_add(step) {
847                    Some(next) if next > x => x = next,
848                    _ => break,
849                }
850            }
851        } else {
852            let mut x = start;
853            while x >= end {
854                v.push(x);
855                match x.checked_sub(step) {
856                    Some(next) if next < x => x = next,
857                    _ => break,
858                }
859            }
860        }
861        v
862    }
863    fn axis_f64((start, end, step): (f64, f64, f64)) -> Vec<f64> {
864        let eps = 1e-12;
865        if step.abs() < eps || (start - end).abs() < eps {
866            return vec![start];
867        }
868        let mut v = Vec::new();
869        let mut x = start;
870        if step > 0.0 {
871            if start > end + eps {
872                return v;
873            }
874            while x <= end + eps {
875                v.push(x);
876                x += step;
877            }
878        } else {
879            if start < end - eps {
880                return v;
881            }
882            while x >= end - eps {
883                v.push(x);
884                x += step;
885            }
886        }
887        v
888    }
889    let periods = axis_usize(r.period);
890    let nbdevs = axis_f64(r.nbdev);
891    let capacity = periods.len().checked_mul(nbdevs.len()).unwrap_or(0);
892    let mut out = Vec::with_capacity(capacity);
893    for &p in &periods {
894        for &n in &nbdevs {
895            out.push(VarParams {
896                period: Some(p),
897                nbdev: Some(n),
898            });
899        }
900    }
901    out
902}
903
904#[inline(always)]
905pub fn var_expand_grid(r: &VarBatchRange) -> Vec<VarParams> {
906    expand_grid(r)
907}
908
909#[inline(always)]
910pub fn var_batch_slice(
911    data: &[f64],
912    sweep: &VarBatchRange,
913    kern: Kernel,
914) -> Result<VarBatchOutput, VarError> {
915    var_batch_inner(data, sweep, kern, false)
916}
917#[inline(always)]
918pub fn var_batch_par_slice(
919    data: &[f64],
920    sweep: &VarBatchRange,
921    kern: Kernel,
922) -> Result<VarBatchOutput, VarError> {
923    var_batch_inner(data, sweep, kern, true)
924}
925
/// Rounds `x` up to the next multiple of 8 (0 stays 0).
fn round_up8(x: usize) -> usize {
    (x + 7) / 8 * 8
}
929
/// When true (and the scalar kernel is requested), `var_batch_inner` computes rows from
/// shared prefix sums instead of per-row rolling sums. It is currently disabled.
const ENABLE_VAR_BATCH_PREFIX_SUMS: bool = false;
931
/// Builds prefix sums of `data` and of its squares. Both vectors have length
/// `data.len() + 1`, so for `a >= first` the sum over `data[a..b]` is `ps[b] - ps[a]`.
///
/// Accumulation starts at index `first`, so entries `0..=first` stay 0.0. The batch path
/// passes the index of the first non-NaN value.
#[inline(always)]
fn make_prefix_sums(data: &[f64], first: usize) -> (Vec<f64>, Vec<f64>) {
    let n = data.len();
    let (mut ps, mut psq) = (vec![0.0f64; n + 1], vec![0.0f64; n + 1]);
    let (mut running, mut running_sq) = (0.0f64, 0.0f64);
    for (j, &x) in data.iter().enumerate().skip(first) {
        running += x;
        running_sq += x * x;
        ps[j + 1] = running;
        psq[j + 1] = running_sq;
    }
    (ps, psq)
}
948
949#[inline(always)]
950fn var_batch_inner(
951    data: &[f64],
952    sweep: &VarBatchRange,
953    kern: Kernel,
954    parallel: bool,
955) -> Result<VarBatchOutput, VarError> {
956    if data.is_empty() {
957        return Err(VarError::EmptyInputData);
958    }
959    let combos = expand_grid(sweep);
960    if combos.is_empty() {
961        return Err(VarError::InvalidRange {
962            start: sweep.period.0 as f64,
963            end: sweep.period.1 as f64,
964            step: sweep.period.2 as f64,
965        });
966    }
967    let first = data
968        .iter()
969        .position(|x| !x.is_nan())
970        .ok_or(VarError::AllValuesNaN)?;
971    let max_period = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
972    if data.len() - first < max_period {
973        return Err(VarError::NotEnoughValidData {
974            needed: max_period,
975            valid: data.len() - first,
976        });
977    }
978    let stride = round_up8(max_period);
979    let rows = combos.len();
980    let cols = data.len();
981    let mut buf_mu = make_uninit_matrix(rows, cols);
982
983    let warmup_periods: Vec<usize> = combos
984        .iter()
985        .map(|c| first + c.period.unwrap() - 1)
986        .collect();
987    init_matrix_prefixes(&mut buf_mu, cols, &warmup_periods);
988
989    let mut buf_guard = core::mem::ManuallyDrop::new(buf_mu);
990    let out: &mut [f64] = unsafe {
991        core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
992    };
993
994    if ENABLE_VAR_BATCH_PREFIX_SUMS && matches!(kern, Kernel::Scalar) {
995        let (ps, psq) = make_prefix_sums(data, first);
996        let compute_row = |row: usize, out_row: &mut [f64]| {
997            let period = combos[row].period.unwrap();
998            let inv_p = 1.0 / (period as f64);
999            let nbdev2 = combos[row].nbdev.unwrap() * combos[row].nbdev.unwrap();
1000            let start = first + period - 1;
1001            for i in start..cols {
1002                let s = ps[i + 1] - ps[i + 1 - period];
1003                let s2 = psq[i + 1] - psq[i + 1 - period];
1004                let m = s * inv_p;
1005                out_row[i] = (s2 * inv_p - m * m) * nbdev2;
1006            }
1007        };
1008
1009        if parallel {
1010            #[cfg(not(target_arch = "wasm32"))]
1011            {
1012                out.par_chunks_mut(cols)
1013                    .enumerate()
1014                    .for_each(|(row, slice)| compute_row(row, slice));
1015            }
1016
1017            #[cfg(target_arch = "wasm32")]
1018            {
1019                for (row, slice) in out.chunks_mut(cols).enumerate() {
1020                    compute_row(row, slice);
1021                }
1022            }
1023        } else {
1024            for (row, slice) in out.chunks_mut(cols).enumerate() {
1025                compute_row(row, slice);
1026            }
1027        }
1028    } else {
1029        let do_row = |row: usize, out_row: &mut [f64]| unsafe {
1030            let period = combos[row].period.unwrap();
1031            let nbdev = combos[row].nbdev.unwrap();
1032            match kern {
1033                Kernel::Scalar => var_row_scalar(data, first, period, stride, nbdev, out_row),
1034                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1035                Kernel::Avx2 => var_row_avx2(data, first, period, stride, nbdev, out_row),
1036                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1037                Kernel::Avx512 => var_row_avx512(data, first, period, stride, nbdev, out_row),
1038                #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
1039                Kernel::Avx2 | Kernel::Avx512 => {
1040                    var_row_scalar(data, first, period, stride, nbdev, out_row)
1041                }
1042                _ => unreachable!(),
1043            }
1044        };
1045        if parallel {
1046            #[cfg(not(target_arch = "wasm32"))]
1047            {
1048                out.par_chunks_mut(cols)
1049                    .enumerate()
1050                    .for_each(|(row, slice)| do_row(row, slice));
1051            }
1052
1053            #[cfg(target_arch = "wasm32")]
1054            {
1055                for (row, slice) in out.chunks_mut(cols).enumerate() {
1056                    do_row(row, slice);
1057                }
1058            }
1059        } else {
1060            for (row, slice) in out.chunks_mut(cols).enumerate() {
1061                do_row(row, slice);
1062            }
1063        }
1064    }
1065
1066    let values = unsafe {
1067        Vec::from_raw_parts(
1068            buf_guard.as_mut_ptr() as *mut f64,
1069            buf_guard.len(),
1070            buf_guard.capacity(),
1071        )
1072    };
1073    std::mem::forget(buf_guard);
1074
1075    Ok(VarBatchOutput {
1076        values,
1077        combos,
1078        rows,
1079        cols,
1080    })
1081}
1082
1083pub fn var_batch_with_kernel(
1084    data: &[f64],
1085    sweep: &VarBatchRange,
1086    k: Kernel,
1087) -> Result<VarBatchOutput, VarError> {
1088    let kernel = match k {
1089        Kernel::Auto => detect_best_batch_kernel(),
1090        other if other.is_batch() => other,
1091        _ => return Err(VarError::InvalidKernelForBatch(k)),
1092    };
1093    let simd = match kernel {
1094        Kernel::Avx512Batch => Kernel::Avx512,
1095        Kernel::Avx2Batch => Kernel::Avx2,
1096        Kernel::ScalarBatch => Kernel::Scalar,
1097        _ => unreachable!(),
1098    };
1099    var_batch_par_slice(data, sweep, simd)
1100}
1101
1102#[derive(Debug, Clone)]
1103pub struct VarStream {
1104    period: usize,
1105    nbdev: f64,
1106    buffer: Vec<f64>,
1107    sum: f64,
1108    sum_sq: f64,
1109    head: usize,
1110    filled: bool,
1111}
1112impl VarStream {
1113    pub fn try_new(params: VarParams) -> Result<Self, VarError> {
1114        let period = params.period.unwrap_or(14);
1115        if period == 0 {
1116            return Err(VarError::InvalidPeriod {
1117                period,
1118                data_len: 0,
1119            });
1120        }
1121        let nbdev = params.nbdev.unwrap_or(1.0);
1122        if nbdev.is_nan() || nbdev.is_infinite() {
1123            return Err(VarError::InvalidNbdev { nbdev });
1124        }
1125        Ok(Self {
1126            period,
1127            nbdev,
1128            buffer: vec![f64::NAN; period],
1129            sum: 0.0,
1130            sum_sq: 0.0,
1131            head: 0,
1132            filled: false,
1133        })
1134    }
1135    #[inline(always)]
1136    pub fn update(&mut self, value: f64) -> Option<f64> {
1137        let old = self.buffer[self.head];
1138        self.buffer[self.head] = value;
1139        self.head = (self.head + 1) % self.period;
1140        if !self.filled {
1141            self.sum += value;
1142            self.sum_sq += value * value;
1143            if self.head == 0 {
1144                self.filled = true;
1145            } else {
1146                return None;
1147            }
1148        } else {
1149            self.sum += value - old;
1150            self.sum_sq += value * value - old * old;
1151        }
1152        let inv_p = 1.0 / self.period as f64;
1153        let mean = self.sum * inv_p;
1154        let mean_sq = self.sum_sq * inv_p;
1155        Some((mean_sq - mean * mean) * self.nbdev * self.nbdev)
1156    }
1157}
1158
/// JavaScript entry point: returns the rolling variance of `data` over `period` samples,
/// scaled by `nbdev * nbdev`, as a new array with the same length as `data`.
///
/// Errors from [`var_into_slice`] reach JS as their display text.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn var_js(data: &[f64], period: usize, nbdev: f64) -> Result<Vec<f64>, JsValue> {
    let input = VarInput::from_slice(
        data,
        VarParams {
            period: Some(period),
            nbdev: Some(nbdev),
        },
    );

    let mut values = vec![0.0; data.len()];
    match var_into_slice(&mut values, &input, detect_best_kernel()) {
        Ok(_) => Ok(values),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
1175
/// Reserves (uninitialised) room for `len` `f64`s and returns the pointer to JS.
///
/// The allocation is intentionally leaked; hand the same `ptr`/`len` pair to `var_free`
/// to release it.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn var_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop keeps the Vec's destructor from freeing the buffer we hand out.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1184
/// Releases a buffer obtained from `var_alloc(len)`; `len` must match the allocation call.
///
/// A null `ptr` is ignored. The vector is reconstructed with length 0 and capacity `len`,
/// so only the allocation is returned. It does not claim that the `len` elements were ever
/// initialised (JS may never have written them).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn var_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: caller contract is that `ptr` came from `var_alloc(len)`, which leaked a
    // `Vec<f64>` built with `Vec::with_capacity(len)` (assumed to allocate exactly `len`
    // slots). Length 0 is always valid for that allocation.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
1192
/// Raw-pointer variance for JS callers: reads `len` values at `in_ptr` and writes `len`
/// results to `out_ptr`.
///
/// If both pointers are equal, the result is computed into a scratch buffer first and then
/// copied over the shared memory.
///
/// # Errors
/// `"Null pointer provided"` for a null pointer, `"Invalid period"` when `period` is 0 or
/// larger than `len`, and otherwise any `var_into_slice` error as its display text.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn var_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
    nbdev: f64,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }
    if period == 0 || period > len {
        return Err(JsValue::from_str("Invalid period"));
    }

    // SAFETY: the caller guarantees `in_ptr` addresses `len` readable f64s.
    let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };
    let input = VarInput::from_slice(
        data,
        VarParams {
            period: Some(period),
            nbdev: Some(nbdev),
        },
    );
    let kernel = detect_best_kernel();

    if std::ptr::eq(in_ptr, out_ptr as *const f64) {
        // In-place request: never write into memory that `data` is still reading.
        let mut scratch = vec![0.0; len];
        var_into_slice(&mut scratch, &input, kernel)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // SAFETY: the caller guarantees `out_ptr` addresses `len` writable f64s.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        out.copy_from_slice(&scratch);
    } else {
        // SAFETY: as above. Only exact pointer equality is detected; partially overlapping
        // buffers are not.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        var_into_slice(out, &input, kernel).map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(())
}
1234
/// Parameter sweep passed from JavaScript to `var_batch` (decoded via `serde_wasm_bindgen`).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct VarBatchConfig {
    /// `(start, end, step)` for the window length; copied into `VarBatchRange::period`.
    pub period_range: (usize, usize, usize),
    /// `(start, end, step)` for the `nbdev` multiplier; copied into `VarBatchRange::nbdev`.
    pub nbdev_range: (f64, f64, f64),
}
1241
/// Batch result serialised back to JavaScript by `var_batch`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct VarBatchJsOutput {
    /// Flattened `rows * cols` matrix; row `r` holds the series computed with `combos[r]`.
    pub values: Vec<f64>,
    /// Parameter set used for each row, in row order.
    pub combos: Vec<VarParams>,
    /// Number of parameter combinations (rows).
    pub rows: usize,
    /// Number of samples per row (columns).
    pub cols: usize,
}
1250
/// JS entry point `var_batch`: evaluates every (period, nbdev) combination described by
/// `config` over `data` and returns `{ values, combos, rows, cols }`.
///
/// # Errors
/// A config that fails to deserialise yields `"Invalid config: …"`. Errors from
/// `var_batch_inner` or from serialising the result are returned as their display text.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = var_batch)]
pub fn var_batch_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let VarBatchConfig {
        period_range,
        nbdev_range,
    } = serde_wasm_bindgen::from_value::<VarBatchConfig>(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = VarBatchRange {
        period: period_range,
        nbdev: nbdev_range,
    };

    // The trailing `false` presumably selects the non-parallel path (no rayon on wasm).
    let batch = var_batch_inner(data, &sweep, detect_best_kernel(), false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let payload = VarBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };
    serde_wasm_bindgen::to_value(&payload).map_err(|e| JsValue::from_str(&e.to_string()))
}
1274
/// Raw-pointer batch variance for JS. Writes one row per parameter combination of the sweep
/// into `out_ptr` and returns the number of rows.
///
/// `in_ptr` must address `len` readable `f64`s. `out_ptr` must address `rows * len`
/// writable `f64`s, where `rows` is the number of combinations the sweep expands to.
///
/// # Errors
/// * `"Null pointer provided"` if either pointer is null.
/// * `"No valid parameter combinations"` if the sweep expands to nothing.
/// * `"rows * cols overflow"` if the output size does not fit in `usize`.
/// * Any `var_batch_inner_into` error, as its display text.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn var_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
    nbdev_start: f64,
    nbdev_end: f64,
    nbdev_step: f64,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // SAFETY: the caller guarantees `in_ptr` addresses `len` readable f64s and `out_ptr`
    // addresses `rows * len` writable f64s (`rows` computed below).
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);

        let sweep = VarBatchRange {
            period: (period_start, period_end, period_step),
            nbdev: (nbdev_start, nbdev_end, nbdev_step),
        };

        let combos = expand_grid(&sweep);
        if combos.is_empty() {
            return Err(JsValue::from_str("No valid parameter combinations"));
        }
        let rows = combos.len();
        let cols = len;

        let total = rows
            .checked_mul(cols)
            .ok_or_else(|| JsValue::from_str("rows * cols overflow"))?;
        let out = std::slice::from_raw_parts_mut(out_ptr, total);

        // NaN-fill each row's warm-up prefix (first valid index + period - 1, clamped to the
        // row). Saturating arithmetic keeps a zero or huge period from under/overflowing
        // here. Such periods are left for `var_batch_inner_into` to deal with.
        let first = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
        for (r, prm) in combos.iter().enumerate() {
            let warm = first
                .saturating_add(prm.period.unwrap())
                .saturating_sub(1)
                .min(cols);
            out[r * cols..r * cols + warm].fill(f64::NAN);
        }

        let _ = var_batch_inner_into(data, &sweep, detect_best_kernel(), false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

        Ok(rows)
    }
}
1327
1328#[cfg(test)]
1329mod tests {
1330    use super::*;
1331    use crate::skip_if_unsupported;
1332    use crate::utilities::data_loader::read_candles_from_csv;
1333    use paste::paste;
1334
1335    #[test]
1336    fn test_var_into_matches_api() -> Result<(), Box<dyn Error>> {
1337        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1338        let candles = read_candles_from_csv(file_path)?;
1339        let input = VarInput::from_candles(&candles, "close", VarParams::default());
1340
1341        let VarOutput { values: expected } = var(&input)?;
1342
1343        let mut out = vec![0.0f64; expected.len()];
1344
1345        #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1346        {
1347            var_into(&input, &mut out)?;
1348        }
1349        #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1350        {
1351            var_into_slice(&mut out, &input, detect_best_kernel())?;
1352        }
1353
1354        assert_eq!(out.len(), expected.len());
1355
1356        for (i, (a, b)) in out.iter().zip(expected.iter()).enumerate() {
1357            let eq_or_both_nan = (*a == *b) || (a.is_nan() && b.is_nan());
1358            assert!(
1359                eq_or_both_nan,
1360                "Mismatch at index {}: got {:?}, expected {:?}",
1361                i, a, b
1362            );
1363        }
1364
1365        Ok(())
1366    }
1367
1368    fn check_var_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1369        skip_if_unsupported!(kernel, test_name);
1370        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1371        let candles = read_candles_from_csv(file_path)?;
1372        let default_params = VarParams {
1373            period: None,
1374            nbdev: None,
1375        };
1376        let input = VarInput::from_candles(&candles, "close", default_params);
1377        let output = var_with_kernel(&input, kernel)?;
1378        assert_eq!(output.values.len(), candles.close.len());
1379        Ok(())
1380    }
1381
1382    fn check_var_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1383        skip_if_unsupported!(kernel, test_name);
1384        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1385        let candles = read_candles_from_csv(file_path)?;
1386        let input = VarInput::from_candles(&candles, "close", VarParams::default());
1387        let var_result = var_with_kernel(&input, kernel)?;
1388        assert_eq!(var_result.values.len(), candles.close.len());
1389        let expected_last_five = [
1390            350987.4081501961,
1391            348493.9183540344,
1392            302611.06121110916,
1393            106092.2499871254,
1394            121941.35202789307,
1395        ];
1396        let start_index = var_result.values.len() - 5;
1397        let result_last_five = &var_result.values[start_index..];
1398        for (i, &value) in result_last_five.iter().enumerate() {
1399            let expected_value = expected_last_five[i];
1400            assert!(
1401                (value - expected_value).abs() < 1e-1,
1402                "[{}] VAR mismatch at idx {}: got {}, expected {}",
1403                test_name,
1404                i,
1405                value,
1406                expected_value
1407            );
1408        }
1409        Ok(())
1410    }
1411
1412    fn check_var_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1413        skip_if_unsupported!(kernel, test_name);
1414        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1415        let candles = read_candles_from_csv(file_path)?;
1416        let input = VarInput::with_default_candles(&candles);
1417        match input.data {
1418            VarData::Candles { source, .. } => assert_eq!(source, "close"),
1419            _ => panic!("Expected VarData::Candles"),
1420        }
1421        let output = var_with_kernel(&input, kernel)?;
1422        assert_eq!(output.values.len(), candles.close.len());
1423        Ok(())
1424    }
1425
1426    fn check_var_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1427        skip_if_unsupported!(kernel, test_name);
1428        let input_data = [10.0, 20.0, 30.0];
1429        let params = VarParams {
1430            period: Some(0),
1431            nbdev: None,
1432        };
1433        let input = VarInput::from_slice(&input_data, params);
1434        let res = var_with_kernel(&input, kernel);
1435        assert!(
1436            res.is_err(),
1437            "[{}] VAR should fail with zero period",
1438            test_name
1439        );
1440        Ok(())
1441    }
1442
1443    fn check_var_period_exceeds_length(
1444        test_name: &str,
1445        kernel: Kernel,
1446    ) -> Result<(), Box<dyn Error>> {
1447        skip_if_unsupported!(kernel, test_name);
1448        let data_small = [10.0, 20.0, 30.0];
1449        let params = VarParams {
1450            period: Some(10),
1451            nbdev: None,
1452        };
1453        let input = VarInput::from_slice(&data_small, params);
1454        let res = var_with_kernel(&input, kernel);
1455        assert!(
1456            res.is_err(),
1457            "[{}] VAR should fail with period exceeding length",
1458            test_name
1459        );
1460        Ok(())
1461    }
1462
1463    fn check_var_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1464        skip_if_unsupported!(kernel, test_name);
1465        let single_point = [42.0];
1466        let params = VarParams {
1467            period: Some(14),
1468            nbdev: None,
1469        };
1470        let input = VarInput::from_slice(&single_point, params);
1471        let res = var_with_kernel(&input, kernel);
1472        assert!(
1473            res.is_err(),
1474            "[{}] VAR should fail with insufficient data",
1475            test_name
1476        );
1477        Ok(())
1478    }
1479
1480    fn check_var_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1481        skip_if_unsupported!(kernel, test_name);
1482        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1483        let candles = read_candles_from_csv(file_path)?;
1484        let first_params = VarParams {
1485            period: Some(14),
1486            nbdev: Some(1.0),
1487        };
1488        let first_input = VarInput::from_candles(&candles, "close", first_params);
1489        let first_result = var_with_kernel(&first_input, kernel)?;
1490        let second_params = VarParams {
1491            period: Some(14),
1492            nbdev: Some(1.0),
1493        };
1494        let second_input = VarInput::from_slice(&first_result.values, second_params);
1495        let second_result = var_with_kernel(&second_input, kernel)?;
1496        assert_eq!(second_result.values.len(), first_result.values.len());
1497        Ok(())
1498    }
1499
1500    fn check_var_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1501        skip_if_unsupported!(kernel, test_name);
1502        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1503        let candles = read_candles_from_csv(file_path)?;
1504        let input = VarInput::from_candles(
1505            &candles,
1506            "close",
1507            VarParams {
1508                period: Some(14),
1509                nbdev: None,
1510            },
1511        );
1512        let res = var_with_kernel(&input, kernel)?;
1513        assert_eq!(res.values.len(), candles.close.len());
1514        if res.values.len() > 30 {
1515            for (i, &val) in res.values[30..].iter().enumerate() {
1516                assert!(
1517                    !val.is_nan(),
1518                    "[{}] Found unexpected NaN at out-index {}",
1519                    test_name,
1520                    30 + i
1521                );
1522            }
1523        }
1524        Ok(())
1525    }
1526
1527    fn check_var_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1528        skip_if_unsupported!(kernel, test_name);
1529        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1530        let candles = read_candles_from_csv(file_path)?;
1531        let period = 14;
1532        let nbdev = 1.0;
1533        let input = VarInput::from_candles(
1534            &candles,
1535            "close",
1536            VarParams {
1537                period: Some(period),
1538                nbdev: Some(nbdev),
1539            },
1540        );
1541        let batch_output = var_with_kernel(&input, kernel)?.values;
1542        let mut stream = VarStream::try_new(VarParams {
1543            period: Some(period),
1544            nbdev: Some(nbdev),
1545        })?;
1546        let mut stream_values = Vec::with_capacity(candles.close.len());
1547        for &price in &candles.close {
1548            match stream.update(price) {
1549                Some(var_val) => stream_values.push(var_val),
1550                None => stream_values.push(f64::NAN),
1551            }
1552        }
1553        assert_eq!(batch_output.len(), stream_values.len());
1554        for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
1555            if b.is_nan() && s.is_nan() {
1556                continue;
1557            }
1558            let diff = (b - s).abs();
1559            assert!(
1560                diff < 1e-6,
1561                "[{}] VAR streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
1562                test_name,
1563                i,
1564                b,
1565                s,
1566                diff
1567            );
1568        }
1569        Ok(())
1570    }
1571
1572    #[cfg(debug_assertions)]
1573    fn check_var_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1574        skip_if_unsupported!(kernel, test_name);
1575
1576        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1577        let candles = read_candles_from_csv(file_path)?;
1578
1579        let test_params = vec![
1580            VarParams::default(),
1581            VarParams {
1582                period: Some(2),
1583                nbdev: Some(1.0),
1584            },
1585            VarParams {
1586                period: Some(5),
1587                nbdev: Some(1.0),
1588            },
1589            VarParams {
1590                period: Some(10),
1591                nbdev: Some(1.0),
1592            },
1593            VarParams {
1594                period: Some(20),
1595                nbdev: Some(1.0),
1596            },
1597            VarParams {
1598                period: Some(30),
1599                nbdev: Some(1.0),
1600            },
1601            VarParams {
1602                period: Some(50),
1603                nbdev: Some(1.0),
1604            },
1605            VarParams {
1606                period: Some(100),
1607                nbdev: Some(1.0),
1608            },
1609            VarParams {
1610                period: Some(200),
1611                nbdev: Some(1.0),
1612            },
1613            VarParams {
1614                period: Some(14),
1615                nbdev: Some(0.5),
1616            },
1617            VarParams {
1618                period: Some(14),
1619                nbdev: Some(2.0),
1620            },
1621            VarParams {
1622                period: Some(14),
1623                nbdev: Some(3.0),
1624            },
1625            VarParams {
1626                period: Some(7),
1627                nbdev: Some(1.5),
1628            },
1629            VarParams {
1630                period: Some(21),
1631                nbdev: Some(2.5),
1632            },
1633            VarParams {
1634                period: Some(50),
1635                nbdev: Some(0.75),
1636            },
1637        ];
1638
1639        for (param_idx, params) in test_params.iter().enumerate() {
1640            let input = VarInput::from_candles(&candles, "close", params.clone());
1641            let output = var_with_kernel(&input, kernel)?;
1642
1643            for (i, &val) in output.values.iter().enumerate() {
1644                if val.is_nan() {
1645                    continue;
1646                }
1647
1648                let bits = val.to_bits();
1649
1650                if bits == 0x11111111_11111111 {
1651                    panic!(
1652                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1653						 with params: period={}, nbdev={} (param set {})",
1654                        test_name,
1655                        val,
1656                        bits,
1657                        i,
1658                        params.period.unwrap_or(14),
1659                        params.nbdev.unwrap_or(1.0),
1660                        param_idx
1661                    );
1662                }
1663
1664                if bits == 0x22222222_22222222 {
1665                    panic!(
1666                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1667						 with params: period={}, nbdev={} (param set {})",
1668                        test_name,
1669                        val,
1670                        bits,
1671                        i,
1672                        params.period.unwrap_or(14),
1673                        params.nbdev.unwrap_or(1.0),
1674                        param_idx
1675                    );
1676                }
1677
1678                if bits == 0x33333333_33333333 {
1679                    panic!(
1680                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1681						 with params: period={}, nbdev={} (param set {})",
1682                        test_name,
1683                        val,
1684                        bits,
1685                        i,
1686                        params.period.unwrap_or(14),
1687                        params.nbdev.unwrap_or(1.0),
1688                        param_idx
1689                    );
1690                }
1691            }
1692        }
1693
1694        Ok(())
1695    }
1696
    /// Release-build stand-in: the real poison check is compiled only under
    /// `debug_assertions`, so this variant always succeeds.
    #[cfg(not(debug_assertions))]
    fn check_var_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
1701
    /// Property test over random finite series (or a near-linear trend pattern).
    ///
    /// Checks that:
    /// * variance is non-negative;
    /// * period 2 matches the closed-form two-point variance;
    /// * constant data gives near-zero variance;
    /// * the selected kernel agrees with `Kernel::Scalar`;
    /// * one index matches `(mean(x²) - mean(x)²) * nbdev²` computed directly.
    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_var_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        // Inputs: period in 2..=64, a series of max(period, 10)..400 values in (-1e6, 1e6),
        // nbdev, a trend slope, an intercept, and a flag that swaps in the trend pattern.
        let strat = (2usize..=64).prop_flat_map(|period| {
            (
                prop::collection::vec(
                    (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
                    period.max(10)..400,
                ),
                Just(period),
                0.1f64..3.0f64,
                -100.0f64..100.0f64,
                -1e5f64..1e5f64,
                prop::bool::ANY,
            )
        });

        proptest::test_runner::TestRunner::default()
            .run(
                &strat,
                |(mut data, period, nbdev, trend, intercept, use_special_pattern)| {
                    if use_special_pattern {
                        // Replace the random data with intercept + trend * i ...
                        for (i, val) in data.iter_mut().enumerate() {
                            *val = intercept + trend * (i as f64);
                        }

                        // ... then push each value 1% further from zero (offset capped at 10).
                        for val in data.iter_mut() {
                            *val += (val.abs() * 0.01).min(10.0)
                                * (if val.is_sign_positive() { 1.0 } else { -1.0 });
                        }
                    }

                    let params = VarParams {
                        period: Some(period),
                        nbdev: Some(nbdev),
                    };
                    let input = VarInput::from_slice(&data, params);

                    // Output under test and the scalar reference for the same input.
                    let VarOutput { values: out } = var_with_kernel(&input, kernel).unwrap();
                    let VarOutput { values: ref_out } =
                        var_with_kernel(&input, Kernel::Scalar).unwrap();

                    // Non-negativity, with slack for floating-point cancellation in
                    // mean(x²) - mean(x)².
                    for i in (period - 1)..data.len() {
                        let y = out[i];
                        if !y.is_nan() {
                            prop_assert!(
                                y >= -1e-6,
                                "[{}] Variance should be non-negative at idx {}: got {}",
                                test_name,
                                i,
                                y
                            );
                        }
                    }

                    // Period 2 on random data: compare against the mean squared deviation of
                    // the two-point window (first 10 indices only).
                    if period == 2 && !use_special_pattern {
                        for i in (period - 1)..data.len().min(10) {
                            if !out[i].is_nan() {
                                let window = &data[i + 1 - period..=i];
                                let mean = (window[0] + window[1]) / 2.0;
                                let expected =
                                    ((window[0] - mean).powi(2) + (window[1] - mean).powi(2)) / 2.0
                                        * nbdev
                                        * nbdev;

                                let tolerance = (expected.abs() + 1.0) * 1e-8;
                                prop_assert!(
                                    (out[i] - expected).abs() <= tolerance,
                                    "[{}] Period=2 variance mismatch at idx {}: got {} expected {}",
                                    test_name,
                                    i,
                                    out[i],
                                    expected
                                );
                            }
                        }
                    }

                    // Constant series (adjacent values within 1e-10) must have ~zero variance.
                    let is_constant = data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10);
                    if is_constant && data.len() >= period {
                        for i in (period - 1)..data.len() {
                            if !out[i].is_nan() {
                                prop_assert!(
								out[i].abs() <= 1e-6,
								"[{}] Constant data should have near-zero variance at idx {}: got {}",
								test_name, i, out[i]
							);
                            }
                        }
                    }

                    // Kernel vs scalar: non-finite values must match bit-for-bit; finite values
                    // may differ by 1e-9 absolute or at most 4 ULPs (bit-distance heuristic).
                    for i in (period - 1)..data.len() {
                        let y = out[i];
                        let r = ref_out[i];

                        if !y.is_finite() || !r.is_finite() {
                            prop_assert!(
                                y.to_bits() == r.to_bits(),
                                "[{}] finite/NaN mismatch at idx {}: {} vs {}",
                                test_name,
                                i,
                                y,
                                r
                            );
                            continue;
                        }

                        let y_bits = y.to_bits();
                        let r_bits = r.to_bits();
                        let ulp_diff: u64 = y_bits.abs_diff(r_bits);

                        prop_assert!(
                            (y - r).abs() <= 1e-9 || ulp_diff <= 4,
                            "[{}] Kernel mismatch at idx {}: {} vs {} (ULP={})",
                            test_name,
                            i,
                            y,
                            r,
                            ulp_diff
                        );
                    }

                    // Spot-check index 2 * period against the textbook formula (small periods,
                    // random data only).
                    if !use_special_pattern && data.len() >= period && period <= 10 {
                        let idx = period * 2;
                        if idx < data.len() && !out[idx].is_nan() {
                            let window = &data[idx + 1 - period..=idx];
                            let mean: f64 = window.iter().sum::<f64>() / (period as f64);
                            let mean_sq: f64 =
                                window.iter().map(|x| x * x).sum::<f64>() / (period as f64);
                            let expected_var = (mean_sq - mean * mean) * nbdev * nbdev;

                            let tolerance = (expected_var.abs() + 1.0) * 1e-8;
                            prop_assert!(
							(out[idx] - expected_var).abs() <= tolerance,
							"[{}] Mathematical formula mismatch at idx {}: got {} expected {} (diff: {})",
							test_name, idx, out[idx], expected_var, (out[idx] - expected_var).abs()
						);
                        }
                    }

                    Ok(())
                },
            )
            .unwrap();

        Ok(())
    }
1855
1856    macro_rules! generate_all_var_tests {
1857        ($($test_fn:ident),*) => {
1858            paste! {
1859                $(
1860                    #[test]
1861                    fn [<$test_fn _scalar_f64>]() {
1862                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1863                    }
1864                )*
1865                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1866                $(
1867                    #[test]
1868                    fn [<$test_fn _avx2_f64>]() {
1869                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1870                    }
1871                    #[test]
1872                    fn [<$test_fn _avx512_f64>]() {
1873                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1874                    }
1875                )*
1876            }
1877        }
1878    }
    // Single-series checks, instantiated once per kernel by `generate_all_var_tests!`.
    generate_all_var_tests!(
        check_var_partial_params,
        check_var_accuracy,
        check_var_default_candles,
        check_var_zero_period,
        check_var_period_exceeds_length,
        check_var_very_small_dataset,
        check_var_reinput,
        check_var_nan_handling,
        check_var_streaming,
        check_var_no_poison
    );

    // The property test only exists when the `proptest` feature is enabled.
    #[cfg(feature = "proptest")]
    generate_all_var_tests!(check_var_property);
1894
1895    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1896        skip_if_unsupported!(kernel, test);
1897        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1898        let c = read_candles_from_csv(file)?;
1899        let output = VarBatchBuilder::new()
1900            .kernel(kernel)
1901            .apply_candles(&c, "close")?;
1902        let def = VarParams::default();
1903        let row = output.values_for(&def).expect("default row missing");
1904        assert_eq!(row.len(), c.close.len());
1905        let expected = [
1906            350987.4081501961,
1907            348493.9183540344,
1908            302611.06121110916,
1909            106092.2499871254,
1910            121941.35202789307,
1911        ];
1912        let start = row.len() - 5;
1913        for (i, &v) in row[start..].iter().enumerate() {
1914            assert!(
1915                (v - expected[i]).abs() < 1e-1,
1916                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
1917            );
1918        }
1919        Ok(())
1920    }
1921
1922    #[cfg(debug_assertions)]
1923    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1924        skip_if_unsupported!(kernel, test);
1925
1926        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1927        let c = read_candles_from_csv(file)?;
1928
1929        let test_configs = vec![
1930            ((2, 10, 2), (1.0, 1.0, 0.0)),
1931            ((10, 30, 5), (1.0, 2.0, 0.5)),
1932            ((30, 100, 10), (0.5, 1.5, 0.5)),
1933            ((2, 5, 1), (1.0, 3.0, 1.0)),
1934            ((14, 14, 0), (1.0, 1.0, 0.0)),
1935            ((5, 25, 5), (2.0, 2.0, 0.0)),
1936            ((50, 100, 25), (1.0, 2.0, 0.25)),
1937            ((14, 28, 7), (0.5, 2.5, 0.5)),
1938        ];
1939
1940        for (cfg_idx, &(period_range, nbdev_range)) in test_configs.iter().enumerate() {
1941            let output = VarBatchBuilder::new()
1942                .kernel(kernel)
1943                .period_range(period_range.0, period_range.1, period_range.2)
1944                .nbdev_range(nbdev_range.0, nbdev_range.1, nbdev_range.2)
1945                .apply_candles(&c, "close")?;
1946
1947            for (idx, &val) in output.values.iter().enumerate() {
1948                if val.is_nan() {
1949                    continue;
1950                }
1951
1952                let bits = val.to_bits();
1953                let row = idx / output.cols;
1954                let col = idx % output.cols;
1955                let combo = &output.combos[row];
1956
1957                if bits == 0x11111111_11111111 {
1958                    panic!(
1959                        "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
1960						 at row {} col {} (flat index {}) with params: period={}, nbdev={}",
1961                        test,
1962                        cfg_idx,
1963                        val,
1964                        bits,
1965                        row,
1966                        col,
1967                        idx,
1968                        combo.period.unwrap_or(14),
1969                        combo.nbdev.unwrap_or(1.0)
1970                    );
1971                }
1972
1973                if bits == 0x22222222_22222222 {
1974                    panic!(
1975                        "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
1976						 at row {} col {} (flat index {}) with params: period={}, nbdev={}",
1977                        test,
1978                        cfg_idx,
1979                        val,
1980                        bits,
1981                        row,
1982                        col,
1983                        idx,
1984                        combo.period.unwrap_or(14),
1985                        combo.nbdev.unwrap_or(1.0)
1986                    );
1987                }
1988
1989                if bits == 0x33333333_33333333 {
1990                    panic!(
1991                        "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
1992						 at row {} col {} (flat index {}) with params: period={}, nbdev={}",
1993                        test,
1994                        cfg_idx,
1995                        val,
1996                        bits,
1997                        row,
1998                        col,
1999                        idx,
2000                        combo.period.unwrap_or(14),
2001                        combo.nbdev.unwrap_or(1.0)
2002                    );
2003                }
2004            }
2005        }
2006
2007        Ok(())
2008    }
2009
2010    #[cfg(not(debug_assertions))]
2011    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2012        Ok(())
2013    }
2014
2015    macro_rules! gen_batch_tests {
2016        ($fn_name:ident) => {
2017            paste! {
2018                #[test] fn [<$fn_name _scalar>]()      {
2019                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
2020                }
2021                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2022                #[test] fn [<$fn_name _avx2>]()        {
2023                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
2024                }
2025                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2026                #[test] fn [<$fn_name _avx512>]()      {
2027                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
2028                }
2029                #[test] fn [<$fn_name _auto_detect>]() {
2030                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
2031                }
2032            }
2033        };
2034    }
    // Instantiate the per-kernel `#[test]` wrappers for each batch check.
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
2037
2038    #[test]
2039    fn test_batch_non_aligned_periods() {
2040        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
2041
2042        let sweep = VarBatchRange {
2043            period: (7, 7, 0),
2044            nbdev: (1.0, 1.0, 0.0),
2045        };
2046
2047        let result = var_batch_slice(&data, &sweep, Kernel::Scalar);
2048        assert!(result.is_ok(), "Should handle period=7 with 10 data points");
2049
2050        let sweep_multi = VarBatchRange {
2051            period: (5, 7, 1),
2052            nbdev: (1.0, 1.0, 0.0),
2053        };
2054
2055        let result_multi = var_batch_slice(&data, &sweep_multi, Kernel::Scalar);
2056        assert!(
2057            result_multi.is_ok(),
2058            "Should handle periods 5,6,7 with 10 data points"
2059        );
2060        let output = result_multi.unwrap();
2061        assert_eq!(output.rows, 3, "Should have 3 rows for periods 5,6,7");
2062
2063        let data_short = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
2064        let result_short = var_batch_slice(&data_short, &sweep, Kernel::Scalar);
2065        assert!(
2066            result_short.is_err(),
2067            "Should reject when data length (6) < period (7)"
2068        );
2069
2070        let data_15 = vec![1.0; 15];
2071        let sweep_15 = VarBatchRange {
2072            period: (15, 15, 0),
2073            nbdev: (1.0, 1.0, 0.0),
2074        };
2075
2076        let result_15 = var_batch_slice(&data_15, &sweep_15, Kernel::Scalar);
2077        assert!(
2078            result_15.is_ok(),
2079            "Should handle period=15 with exactly 15 data points"
2080        );
2081    }
2082}
2083
/// Python binding: rolling variance of a 1-D float64 array.
///
/// `period` (default 14) and `nbdev` (default 1.0) populate [`VarParams`];
/// `kernel` optionally names a compute kernel and is checked by `validate_kernel`.
/// The computation runs with the GIL released; any `VarError` surfaces as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "var")]
#[pyo3(signature = (data, period=14, nbdev=1.0, kernel=None))]
pub fn var_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period: usize,
    nbdev: f64,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let series = data.as_slice()?;
    let chosen = validate_kernel(kernel, false)?;

    let input = VarInput::from_slice(
        series,
        VarParams {
            period: Some(period),
            nbdev: Some(nbdev),
        },
    );

    match py.allow_threads(|| var_with_kernel(&input, chosen)) {
        Ok(output) => Ok(output.values.into_pyarray(py)),
        Err(err) => Err(PyValueError::new_err(err.to_string())),
    }
}
2109
/// Python-facing wrapper (`VarStream`) around the incremental [`VarStream`] state.
#[cfg(feature = "python")]
#[pyclass(name = "VarStream")]
pub struct VarStreamPy {
    stream: VarStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl VarStreamPy {
    /// Builds a stream for the given `period` and `nbdev`; invalid parameters
    /// are raised as `ValueError`.
    #[new]
    fn new(period: usize, nbdev: f64) -> PyResult<Self> {
        VarStream::try_new(VarParams {
            period: Some(period),
            nbdev: Some(nbdev),
        })
        .map(|stream| VarStreamPy { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feeds one value into the stream and returns whatever the underlying
    /// `VarStream::update` yields (`None` maps to Python `None`).
    fn update(&mut self, value: f64) -> Option<f64> {
        let inner = &mut self.stream;
        inner.update(value)
    }
}
2134
/// Python binding: variance for every `(period, nbdev)` combination of a sweep.
///
/// Returns a dict with:
/// - `"values"`: a `rows x cols` float64 array, one row per combination;
/// - `"periods"`: the period of each row (as u64);
/// - `"nbdevs"`: the nbdev of each row.
///
/// The batch work runs with the GIL released; errors surface as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "var_batch")]
#[pyo3(signature = (data, period_range, nbdev_range=(1.0, 1.0, 0.0), kernel=None))]
pub fn var_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    nbdev_range: (f64, f64, f64),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let series = data.as_slice()?;
    let sweep = VarBatchRange {
        period: period_range,
        nbdev: nbdev_range,
    };

    // Resolve the requested batch kernel, then map it to the per-row kernel
    // that the batch driver dispatches on.
    let batch_kernel = match validate_kernel(kernel, true)? {
        Kernel::Auto => detect_best_batch_kernel(),
        other => other,
    };
    let row_kernel = match batch_kernel {
        Kernel::Avx512Batch => Kernel::Avx512,
        Kernel::Avx2Batch => Kernel::Avx2,
        Kernel::ScalarBatch => Kernel::Scalar,
        _ => unreachable!(),
    };

    let batch = py
        .allow_threads(|| var_batch_par_slice(series, &sweep, row_kernel))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;
    let (rows, cols) = (batch.rows, batch.cols);

    let dict = PyDict::new(py);

    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows * cols overflow"))?;
    // SAFETY: the array is freshly allocated with `rows * cols` (= `total`)
    // elements (`false` requests C order per the numpy crate docs) and is fully
    // overwritten by `copy_from_slice` before it is handed to Python.
    let values_2d = unsafe { numpy::PyArray2::<f64>::new(py, [rows, cols], false) };
    let dst = unsafe { std::slice::from_raw_parts_mut(values_2d.data() as *mut f64, total) };
    dst.copy_from_slice(&batch.values);
    dict.set_item("values", values_2d)?;

    let periods: Vec<u64> = batch
        .combos
        .iter()
        .map(|p| p.period.unwrap() as u64)
        .collect();
    dict.set_item("periods", periods.into_pyarray(py))?;

    let nbdevs: Vec<f64> = batch.combos.iter().map(|p| p.nbdev.unwrap()).collect();
    dict.set_item("nbdevs", nbdevs.into_pyarray(py))?;

    Ok(dict)
}
2200
2201#[cfg(all(feature = "python", feature = "cuda"))]
2202use crate::cuda::cuda_available;
2203#[cfg(all(feature = "python", feature = "cuda"))]
2204use crate::cuda::moving_averages::DeviceArrayF32;
2205#[cfg(all(feature = "python", feature = "cuda"))]
2206use crate::cuda::var_wrapper::CudaVar;
2207#[cfg(all(feature = "python", feature = "cuda"))]
2208use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
2209#[cfg(all(feature = "python", feature = "cuda"))]
2210use cust::context::Context;
2211#[cfg(all(feature = "python", feature = "cuda"))]
2212use cust::memory::DeviceBuffer;
2213#[cfg(all(feature = "python", feature = "cuda"))]
2214use std::sync::Arc;
2215
/// Python handle (`ta_indicators.cuda.VarDeviceArrayF32`) to a device-resident
/// 2-D f32 VAR result, exposed via `__cuda_array_interface__` and DLPack.
///
/// Declared `unsendable`, so PyO3 restricts it to the thread that created it.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", name = "VarDeviceArrayF32", unsendable)]
pub struct VarDeviceArrayF32Py {
    // Device buffer plus its `rows` x `cols` shape.
    pub(crate) inner: DeviceArrayF32,
    // Held but never read here; presumably keeps the CUDA context alive for as
    // long as `inner` exists — NOTE(review): confirm.
    pub(crate) _ctx: Arc<Context>,
    // CUDA device ordinal reported by `__dlpack_device__`.
    pub(crate) device_id: u32,
}

#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl VarDeviceArrayF32Py {
    /// CUDA Array Interface (version 3) description of the buffer: shape
    /// `(rows, cols)`, little-endian float32 (`"<f4"`), row-major byte strides,
    /// and a writable (`False` = not read-only) device pointer.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let d = PyDict::new(py);
        d.set_item("shape", (self.inner.rows, self.inner.cols))?;
        d.set_item("typestr", "<f4")?;
        // Strides are in bytes: one full row, then one f32 element.
        d.set_item(
            "strides",
            (
                self.inner.cols * std::mem::size_of::<f32>(),
                std::mem::size_of::<f32>(),
            ),
        )?;
        d.set_item("data", (self.inner.device_ptr() as usize, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    /// DLPack device tuple: type 2 (`kDLCUDA` in the DLPack spec) and this
    /// buffer's device ordinal.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.device_id as i32)
    }

    /// Exports the buffer as a DLPack capsule, transferring ownership of the
    /// device memory to the consumer.
    ///
    /// - `dl_device`: if it extracts as `(i32, i32)`, it must equal
    ///   `__dlpack_device__()`, otherwise `ValueError`; other values are ignored.
    /// - `copy=True` raises `ValueError` (copy export is not supported).
    /// - `stream` is accepted but ignored. NOTE(review): no stream
    ///   synchronisation happens here — confirm producer work is complete.
    /// - `max_version` is forwarded to `export_f32_cuda_dlpack_2d`.
    ///
    /// After the call this object holds an empty 0 x 0 placeholder buffer.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        mut slf: pyo3::PyRefMut<'py, Self>,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        use cust::memory::DeviceBuffer;

        let (expected_type, expected_dev) = slf.__dlpack_device__();
        if let Some(dev_obj) = dl_device.as_ref() {
            if let Ok((dev_type, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_type != expected_type || dev_id != expected_dev {
                    return Err(PyValueError::new_err("dl_device mismatch for VAR buffer"));
                }
            }
        }

        // Anything that does not extract as a bool is treated as "no copy requested".
        let wants_copy = copy
            .as_ref()
            .and_then(|c| c.extract::<bool>(py).ok())
            .unwrap_or(false);
        if wants_copy {
            return Err(PyValueError::new_err(
                "copy=True not supported for VAR DLPack export",
            ));
        }

        let _ = stream;

        // Move the real buffer out, leaving an empty placeholder so `slf` stays
        // valid; the shape is read first because the placeholder reports 0 x 0.
        let dummy =
            DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let rows = slf.inner.rows;
        let cols = slf.inner.cols;
        let inner = std::mem::replace(
            &mut slf.inner,
            DeviceArrayF32 {
                buf: dummy,
                rows: 0,
                cols: 0,
            },
        );

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, inner.buf, rows, cols, expected_dev, max_version_bound)
    }
}
2299
/// Python binding: runs the VAR parameter sweep on a CUDA device and returns the
/// `rows x cols` f32 result as a device-resident [`VarDeviceArrayF32Py`].
///
/// `nbdev_range` arrives as f32 and is widened losslessly to f64 for the sweep.
/// Raises `ValueError` if CUDA is unavailable or any device step fails; the
/// device work runs with the GIL released.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "var_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, nbdev_range=(1.0,1.0,0.0), device_id=0))]
pub fn var_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: numpy::PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    nbdev_range: (f32, f32, f32),
    device_id: usize,
) -> PyResult<VarDeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let host = data_f32.as_slice()?;
    let (nb_start, nb_end, nb_step) = nbdev_range;
    let sweep = VarBatchRange {
        period: period_range,
        nbdev: (f64::from(nb_start), f64::from(nb_end), f64::from(nb_step)),
    };

    let launched = py.allow_threads(|| {
        let cuda = CudaVar::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id();
        match cuda.var_batch_dev(host, &sweep) {
            Ok(pair) => Ok((pair.0, ctx, dev_id)),
            Err(e) => Err(PyValueError::new_err(e.to_string())),
        }
    });
    let (inner, ctx, dev_id) = launched?;

    Ok(VarDeviceArrayF32Py {
        inner,
        _ctx: ctx,
        device_id: dev_id,
    })
}
2336
/// Python binding: computes VAR with one `(period, nbdev)` over many series
/// stored time-major (`cols` series x `rows` time steps, flattened) on a CUDA
/// device, returning a device-resident [`VarDeviceArrayF32Py`].
///
/// Raises `ValueError` if CUDA is unavailable or any device step fails; the
/// device work runs with the GIL released.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "var_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, cols, rows, period, nbdev=1.0, device_id=0))]
pub fn var_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    period: usize,
    nbdev: f64,
    device_id: usize,
) -> PyResult<VarDeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let time_major = data_tm_f32.as_slice()?;
    let params = VarParams {
        period: Some(period),
        nbdev: Some(nbdev),
    };

    let launched = py.allow_threads(|| {
        let cuda = CudaVar::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id();
        match cuda.var_many_series_one_param_time_major_dev(time_major, cols, rows, &params) {
            Ok(handle) => Ok((handle, ctx, dev_id)),
            Err(e) => Err(PyValueError::new_err(e.to_string())),
        }
    });
    let (inner, ctx, dev_id) = launched?;

    Ok(VarDeviceArrayF32Py {
        inner,
        _ctx: ctx,
        device_id: dev_id,
    })
}
2371
2372#[inline(always)]
2373fn var_batch_inner_into(
2374    data: &[f64],
2375    sweep: &VarBatchRange,
2376    kern: Kernel,
2377    parallel: bool,
2378    out: &mut [f64],
2379) -> Result<Vec<VarParams>, VarError> {
2380    if data.is_empty() {
2381        return Err(VarError::EmptyInputData);
2382    }
2383    let combos = expand_grid(sweep);
2384    if combos.is_empty() {
2385        return Err(VarError::InvalidRange {
2386            start: sweep.period.0 as f64,
2387            end: sweep.period.1 as f64,
2388            step: sweep.period.2 as f64,
2389        });
2390    }
2391    let first = data
2392        .iter()
2393        .position(|x| !x.is_nan())
2394        .ok_or(VarError::AllValuesNaN)?;
2395    let max_period = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
2396    if data.len() - first < max_period {
2397        return Err(VarError::NotEnoughValidData {
2398            needed: max_period,
2399            valid: data.len() - first,
2400        });
2401    }
2402    let stride = round_up8(max_period);
2403    let cols = data.len();
2404
2405    if ENABLE_VAR_BATCH_PREFIX_SUMS && matches!(kern, Kernel::Scalar) {
2406        let (ps, psq) = make_prefix_sums(data, first);
2407        let compute_row = |row: usize, out_row: &mut [f64]| {
2408            let period = combos[row].period.unwrap();
2409            let inv_p = 1.0 / (period as f64);
2410            let nbdev2 = combos[row].nbdev.unwrap() * combos[row].nbdev.unwrap();
2411            let start = first + period - 1;
2412            for i in start..cols {
2413                let s = ps[i + 1] - ps[i + 1 - period];
2414                let s2 = psq[i + 1] - psq[i + 1 - period];
2415                let m = s * inv_p;
2416                out_row[i] = (s2 * inv_p - m * m) * nbdev2;
2417            }
2418        };
2419        if parallel {
2420            #[cfg(not(target_arch = "wasm32"))]
2421            {
2422                out.par_chunks_mut(cols)
2423                    .enumerate()
2424                    .for_each(|(row, slice)| compute_row(row, slice));
2425            }
2426            #[cfg(target_arch = "wasm32")]
2427            {
2428                for (row, slice) in out.chunks_mut(cols).enumerate() {
2429                    compute_row(row, slice);
2430                }
2431            }
2432        } else {
2433            for (row, slice) in out.chunks_mut(cols).enumerate() {
2434                compute_row(row, slice);
2435            }
2436        }
2437    } else {
2438        let do_row = |row: usize, out_row: &mut [f64]| unsafe {
2439            let period = combos[row].period.unwrap();
2440            let nbdev = combos[row].nbdev.unwrap();
2441            match kern {
2442                Kernel::Scalar => var_row_scalar(data, first, period, stride, nbdev, out_row),
2443                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2444                Kernel::Avx2 => var_row_avx2(data, first, period, stride, nbdev, out_row),
2445                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2446                Kernel::Avx512 => var_row_avx512(data, first, period, stride, nbdev, out_row),
2447                #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
2448                Kernel::Avx2 | Kernel::Avx512 => {
2449                    var_row_scalar(data, first, period, stride, nbdev, out_row)
2450                }
2451                _ => unreachable!(),
2452            }
2453        };
2454        if parallel {
2455            #[cfg(not(target_arch = "wasm32"))]
2456            {
2457                out.par_chunks_mut(cols)
2458                    .enumerate()
2459                    .for_each(|(row, slice)| do_row(row, slice));
2460            }
2461            #[cfg(target_arch = "wasm32")]
2462            {
2463                for (row, slice) in out.chunks_mut(cols).enumerate() {
2464                    do_row(row, slice);
2465                }
2466            }
2467        } else {
2468            for (row, slice) in out.chunks_mut(cols).enumerate() {
2469                do_row(row, slice);
2470            }
2471        }
2472    }
2473
2474    Ok(combos)
2475}