// vector_ta/indicators/vidya.rs

1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::moving_averages::DeviceArrayF32;
3#[cfg(all(feature = "python", feature = "cuda"))]
4use crate::cuda::CudaVidya;
5#[cfg(all(feature = "python", feature = "cuda"))]
6use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
7#[cfg(feature = "python")]
8use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
9#[cfg(feature = "python")]
10use pyo3::exceptions::PyValueError;
11#[cfg(feature = "python")]
12use pyo3::prelude::*;
13#[cfg(feature = "python")]
14use pyo3::types::PyDict;
15#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
16use serde::{Deserialize, Serialize};
17#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
18use wasm_bindgen::prelude::*;
19
20use crate::utilities::data_loader::{source_type, Candles};
21use crate::utilities::enums::Kernel;
22use crate::utilities::helpers::{
23    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
24    make_uninit_matrix,
25};
26#[cfg(feature = "python")]
27use crate::utilities::kernel_validation::validate_kernel;
28use aligned_vec::{AVec, CACHELINE_ALIGN};
29#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
30use core::arch::x86_64::*;
31#[cfg(all(feature = "python", feature = "cuda"))]
32use cust::context::Context;
33#[cfg(all(feature = "python", feature = "cuda"))]
34use cust::memory::DeviceBuffer;
35use paste::paste;
36#[cfg(not(target_arch = "wasm32"))]
37use rayon::prelude::*;
38use std::convert::AsRef;
39use thiserror::Error;
40
41impl<'a> AsRef<[f64]> for VidyaInput<'a> {
42    #[inline(always)]
43    fn as_ref(&self) -> &[f64] {
44        match &self.data {
45            VidyaData::Slice(slice) => slice,
46            VidyaData::Candles { candles, source } => source_type(candles, source),
47        }
48    }
49}
50
#[cfg(test)]
mod tests_into_parity {
    use super::*;
    use crate::utilities::data_loader::read_candles_from_csv;

    /// Equality used by the parity check: both `NaN`, exactly equal, or within 1e-12.
    fn same_or_both_nan(lhs: f64, rhs: f64) -> bool {
        let both_nan = lhs.is_nan() && rhs.is_nan();
        both_nan || lhs == rhs || (lhs - rhs).abs() <= 1e-12
    }

    /// The caller-buffer entry point must reproduce `vidya()` element for element,
    /// including its `NaN` warm-up prefix.
    #[test]
    fn test_vidya_into_matches_api() -> Result<(), Box<dyn std::error::Error>> {
        let candles = read_candles_from_csv("src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv")?;
        let input = VidyaInput::with_default_candles(&candles);

        let expected = vidya(&input)?.values;
        let mut actual = vec![0.0; candles.close.len()];

        #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
        {
            vidya_into(&input, &mut actual)?;
        }
        #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
        {
            vidya_into_slice(&mut actual, &input, Kernel::Auto)?;
        }

        assert_eq!(expected.len(), actual.len());

        for (idx, (&want, &got)) in expected.iter().zip(actual.iter()).enumerate() {
            assert!(
                same_or_both_nan(want, got),
                "Mismatch at index {}: baseline={} into={}",
                idx,
                want,
                got
            );
        }

        Ok(())
    }
}
93
94#[derive(Debug, Clone)]
95pub enum VidyaData<'a> {
96    Candles {
97        candles: &'a Candles,
98        source: &'a str,
99    },
100    Slice(&'a [f64]),
101}
102
103#[derive(Debug, Clone)]
104pub struct VidyaOutput {
105    pub values: Vec<f64>,
106}
107
/// Tunable VIDYA parameters.
///
/// Any field left as `None` falls back to its default: `short_period = 2`,
/// `long_period = 5`, `alpha = 0.2`.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct VidyaParams {
    /// Length of the short (fast) volatility window; must be at least 2.
    pub short_period: Option<usize>,
    /// Length of the long (slow) volatility window; must be `>= short_period`.
    pub long_period: Option<usize>,
    /// Base smoothing factor in `[0, 1]`, scaled each step by the volatility ratio.
    pub alpha: Option<f64>,
}

impl Default for VidyaParams {
    /// Every field populated with its documented default.
    fn default() -> Self {
        let (short_period, long_period, alpha) = (2, 5, 0.2);
        Self {
            short_period: Some(short_period),
            long_period: Some(long_period),
            alpha: Some(alpha),
        }
    }
}
128
129#[derive(Debug, Clone)]
130pub struct VidyaInput<'a> {
131    pub data: VidyaData<'a>,
132    pub params: VidyaParams,
133}
134
135impl<'a> VidyaInput<'a> {
136    #[inline]
137    pub fn from_candles(c: &'a Candles, s: &'a str, p: VidyaParams) -> Self {
138        Self {
139            data: VidyaData::Candles {
140                candles: c,
141                source: s,
142            },
143            params: p,
144        }
145    }
146    #[inline]
147    pub fn from_slice(sl: &'a [f64], p: VidyaParams) -> Self {
148        Self {
149            data: VidyaData::Slice(sl),
150            params: p,
151        }
152    }
153    #[inline]
154    pub fn with_default_candles(c: &'a Candles) -> Self {
155        Self::from_candles(c, "close", VidyaParams::default())
156    }
157    #[inline]
158    pub fn get_short_period(&self) -> usize {
159        self.params.short_period.unwrap_or(2)
160    }
161    #[inline]
162    pub fn get_long_period(&self) -> usize {
163        self.params.long_period.unwrap_or(5)
164    }
165    #[inline]
166    pub fn get_alpha(&self) -> f64 {
167        self.params.alpha.unwrap_or(0.2)
168    }
169}
170
171#[derive(Copy, Clone, Debug)]
172pub struct VidyaBuilder {
173    short_period: Option<usize>,
174    long_period: Option<usize>,
175    alpha: Option<f64>,
176    kernel: Kernel,
177}
178
179impl Default for VidyaBuilder {
180    fn default() -> Self {
181        Self {
182            short_period: None,
183            long_period: None,
184            alpha: None,
185            kernel: Kernel::Auto,
186        }
187    }
188}
189
190impl VidyaBuilder {
191    #[inline(always)]
192    pub fn new() -> Self {
193        Self::default()
194    }
195    #[inline(always)]
196    pub fn short_period(mut self, n: usize) -> Self {
197        self.short_period = Some(n);
198        self
199    }
200    #[inline(always)]
201    pub fn long_period(mut self, n: usize) -> Self {
202        self.long_period = Some(n);
203        self
204    }
205    #[inline(always)]
206    pub fn alpha(mut self, a: f64) -> Self {
207        self.alpha = Some(a);
208        self
209    }
210    #[inline(always)]
211    pub fn kernel(mut self, k: Kernel) -> Self {
212        self.kernel = k;
213        self
214    }
215    #[inline(always)]
216    pub fn apply(self, c: &Candles) -> Result<VidyaOutput, VidyaError> {
217        let p = VidyaParams {
218            short_period: self.short_period,
219            long_period: self.long_period,
220            alpha: self.alpha,
221        };
222        let i = VidyaInput::from_candles(c, "close", p);
223        vidya_with_kernel(&i, self.kernel)
224    }
225    #[inline(always)]
226    pub fn apply_slice(self, d: &[f64]) -> Result<VidyaOutput, VidyaError> {
227        let p = VidyaParams {
228            short_period: self.short_period,
229            long_period: self.long_period,
230            alpha: self.alpha,
231        };
232        let i = VidyaInput::from_slice(d, p);
233        vidya_with_kernel(&i, self.kernel)
234    }
235    #[inline(always)]
236    pub fn into_stream(self) -> Result<VidyaStream, VidyaError> {
237        let p = VidyaParams {
238            short_period: self.short_period,
239            long_period: self.long_period,
240            alpha: self.alpha,
241        };
242        VidyaStream::try_new(p)
243    }
244}
245
246#[derive(Debug, Error)]
247pub enum VidyaError {
248    #[error("vidya: Input data slice is empty.")]
249    EmptyInputData,
250    #[error("vidya: All values are NaN.")]
251    AllValuesNaN,
252    #[error("vidya: Invalid period: period = {period}, data length = {data_len}")]
253    InvalidPeriod { period: usize, data_len: usize },
254    #[error("vidya: Not enough valid data: needed = {needed}, valid = {valid}")]
255    NotEnoughValidData { needed: usize, valid: usize },
256    #[error("vidya: Invalid alpha: {alpha}")]
257    InvalidAlpha { alpha: f64 },
258    #[error("vidya: Output length mismatch: expected {expected}, got {got}")]
259    OutputLengthMismatch { expected: usize, got: usize },
260    #[error("vidya: Invalid range: start={start}, end={end}, step={step}")]
261    InvalidRange {
262        start: String,
263        end: String,
264        step: String,
265    },
266    #[error("vidya: Invalid kernel for batch: {0:?}")]
267    InvalidKernelForBatch(Kernel),
268}
269
270#[inline]
271pub fn vidya(input: &VidyaInput) -> Result<VidyaOutput, VidyaError> {
272    vidya_with_kernel(input, Kernel::Auto)
273}
274
275pub fn vidya_with_kernel(input: &VidyaInput, kernel: Kernel) -> Result<VidyaOutput, VidyaError> {
276    let data: &[f64] = match &input.data {
277        VidyaData::Candles { candles, source } => source_type(candles, source),
278        VidyaData::Slice(sl) => sl,
279    };
280
281    if data.is_empty() {
282        return Err(VidyaError::EmptyInputData);
283    }
284
285    let short_period = input.get_short_period();
286    let long_period = input.get_long_period();
287    let alpha = input.get_alpha();
288
289    if short_period < 2 {
290        return Err(VidyaError::InvalidPeriod {
291            period: short_period,
292            data_len: data.len(),
293        });
294    }
295    if long_period < short_period || long_period < 2 || long_period > data.len() {
296        return Err(VidyaError::InvalidPeriod {
297            period: long_period,
298            data_len: data.len(),
299        });
300    }
301    if !(0.0..=1.0).contains(&alpha) || alpha.is_nan() || alpha.is_infinite() {
302        return Err(VidyaError::InvalidAlpha { alpha });
303    }
304
305    let first = data
306        .iter()
307        .position(|&x| !x.is_nan())
308        .ok_or(VidyaError::AllValuesNaN)?;
309    if (data.len() - first) < long_period {
310        return Err(VidyaError::NotEnoughValidData {
311            needed: long_period,
312            valid: data.len() - first,
313        });
314    }
315
316    let chosen = match kernel {
317        Kernel::Auto => match detect_best_kernel() {
318            Kernel::Avx512 => Kernel::Avx2,
319            other => other,
320        },
321        other => other,
322    };
323
324    let warmup_period = first + long_period - 2;
325    let mut out = alloc_with_nan_prefix(data.len(), warmup_period);
326    unsafe {
327        match chosen {
328            Kernel::Scalar | Kernel::ScalarBatch => {
329                vidya_scalar(data, short_period, long_period, alpha, first, &mut out)
330            }
331            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
332            Kernel::Avx2 | Kernel::Avx2Batch => {
333                vidya_avx2(data, short_period, long_period, alpha, first, &mut out)
334            }
335            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
336            Kernel::Avx512 | Kernel::Avx512Batch => {
337                vidya_avx512(data, short_period, long_period, alpha, first, &mut out)
338            }
339            _ => vidya_scalar(data, short_period, long_period, alpha, first, &mut out),
340        }
341    }
342
343    Ok(VidyaOutput { values: out })
344}
345
346#[inline]
347pub fn vidya_into_slice(
348    dst: &mut [f64],
349    input: &VidyaInput,
350    kern: Kernel,
351) -> Result<(), VidyaError> {
352    let data: &[f64] = match &input.data {
353        VidyaData::Candles { candles, source } => source_type(candles, source),
354        VidyaData::Slice(sl) => sl,
355    };
356
357    if data.is_empty() {
358        return Err(VidyaError::EmptyInputData);
359    }
360
361    if dst.len() != data.len() {
362        return Err(VidyaError::OutputLengthMismatch {
363            expected: data.len(),
364            got: dst.len(),
365        });
366    }
367
368    let short_period = input.get_short_period();
369    let long_period = input.get_long_period();
370    let alpha = input.get_alpha();
371
372    if short_period < 2 {
373        return Err(VidyaError::InvalidPeriod {
374            period: short_period,
375            data_len: data.len(),
376        });
377    }
378    if long_period < short_period || long_period < 2 || long_period > data.len() {
379        return Err(VidyaError::InvalidPeriod {
380            period: long_period,
381            data_len: data.len(),
382        });
383    }
384    if !(0.0..=1.0).contains(&alpha) || alpha.is_nan() || alpha.is_infinite() {
385        return Err(VidyaError::InvalidAlpha { alpha });
386    }
387
388    let first = data
389        .iter()
390        .position(|&x| !x.is_nan())
391        .ok_or(VidyaError::AllValuesNaN)?;
392    if (data.len() - first) < long_period {
393        return Err(VidyaError::NotEnoughValidData {
394            needed: long_period,
395            valid: data.len() - first,
396        });
397    }
398
399    let chosen = match kern {
400        Kernel::Auto => match detect_best_kernel() {
401            Kernel::Avx512 => Kernel::Avx2,
402            other => other,
403        },
404        other => other,
405    };
406
407    let warmup_period = first + long_period - 2;
408
409    unsafe {
410        #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
411        {
412            if matches!(chosen, Kernel::Scalar | Kernel::ScalarBatch) {
413                vidya_simd128(data, short_period, long_period, alpha, first, dst);
414
415                for v in &mut dst[..warmup_period] {
416                    *v = f64::NAN;
417                }
418                return Ok(());
419            }
420        }
421
422        match chosen {
423            Kernel::Scalar | Kernel::ScalarBatch => {
424                vidya_scalar(data, short_period, long_period, alpha, first, dst)
425            }
426            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
427            Kernel::Avx2 | Kernel::Avx2Batch => {
428                vidya_avx2(data, short_period, long_period, alpha, first, dst)
429            }
430            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
431            Kernel::Avx512 | Kernel::Avx512Batch => {
432                vidya_avx512(data, short_period, long_period, alpha, first, dst)
433            }
434            _ => vidya_scalar(data, short_period, long_period, alpha, first, dst),
435        }
436    }
437
438    for v in &mut dst[..warmup_period] {
439        *v = f64::NAN;
440    }
441
442    Ok(())
443}
444
445#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
446#[inline]
447pub fn vidya_into(input: &VidyaInput, out: &mut [f64]) -> Result<(), VidyaError> {
448    vidya_into_slice(out, input, Kernel::Auto)
449}
450
/// Scalar VIDYA kernel: fills `out[first + long_period - 2..]` from `data`.
///
/// Running sums `Σx` and `Σx²` are kept for two windows:
/// - a long window of `long_period` samples;
/// - a short window of the most recent `short_period` samples.
///
/// Each step computes `var = Σx²/n - mean²` for both windows and the adaptive factor
/// `k = alpha * sqrt(short_var) / sqrt(long_var)`. `k` is reset to 0 when the ratio is
/// `NaN`, for example `0/0` or a slightly negative variance caused by rounding. The
/// average is then updated as `val = val + k * (x - val)`, evaluated with `mul_add`.
///
/// Index `first + long_period - 2` is seeded with the raw input value. Entries before it
/// are not written, so callers must fill the warm-up prefix themselves.
///
/// NOTE(review): if `long_var` is exactly 0 while `short_var > 0` (only possible through
/// rounding), `k` becomes `inf` and is not reset. Confirm whether that needs guarding.
///
/// # Safety
/// The body only uses bounds-checked indexing, so contract violations panic rather than
/// read out of bounds. Callers are expected to ensure `2 <= short_period <= long_period`,
/// `first + long_period <= data.len()`, and `out.len() >= data.len()`.
#[inline]
pub unsafe fn vidya_scalar(
    data: &[f64],
    short_period: usize,
    long_period: usize,
    alpha: f64,
    first: usize,
    out: &mut [f64],
) {
    let len = data.len();

    let mut long_sum = 0.0_f64;
    let mut long_sum2 = 0.0_f64;
    let mut short_sum = 0.0_f64;
    let mut short_sum2 = 0.0_f64;

    let sp_f = short_period as f64;
    let lp_f = long_period as f64;
    // Reciprocals are hoisted so each step multiplies instead of dividing.
    let short_inv = 1.0 / sp_f;
    let long_inv = 1.0 / lp_f;

    // [first, warm_end) is the first full long window; [short_head, warm_end) is the
    // short window nested at its tail.
    let warm_end = first + long_period;
    let short_head = warm_end - short_period;

    // Samples that belong only to the long window.
    for i in first..short_head {
        let x = data[i];
        long_sum += x;

        long_sum2 = x.mul_add(x, long_sum2);
    }

    // The last `short_period` warm-up samples feed both windows.
    for i in short_head..warm_end {
        let x = data[i];
        long_sum += x;
        long_sum2 = x.mul_add(x, long_sum2);
        short_sum += x;
        short_sum2 = x.mul_add(x, short_sum2);
    }

    let idx_m2 = warm_end - 2;
    let idx_m1 = warm_end - 1;

    // Seed: the first emitted value is the raw input at warm_end - 2.
    let mut val = data[idx_m2];
    out[idx_m2] = val;

    // Always true when first + long_period <= len; kept as a guard.
    if idx_m1 < len {
        let short_mean = short_sum * short_inv;
        let long_mean = long_sum * long_inv;
        let short_var = short_sum2 * short_inv - (short_mean * short_mean);
        let long_var = long_sum2 * long_inv - (long_mean * long_mean);
        let short_std = short_var.sqrt();
        let long_std = long_var.sqrt();

        let mut k = short_std / long_std;
        if k.is_nan() {
            k = 0.0;
        }
        k *= alpha;

        let x = data[idx_m1];

        val = (x - val).mul_add(k, val);
        out[idx_m1] = val;
    }

    // Slide both windows one sample at a time: add the newest value, then drop the
    // value that leaves each window.
    for t in warm_end..len {
        let x_new = data[t];
        let x_new2 = x_new * x_new;

        long_sum += x_new;
        long_sum2 += x_new2;
        short_sum += x_new;
        short_sum2 += x_new2;

        let x_long_out = data[t - long_period];
        let x_short_out = data[t - short_period];
        long_sum -= x_long_out;

        long_sum2 = (-x_long_out).mul_add(x_long_out, long_sum2);
        short_sum -= x_short_out;
        short_sum2 = (-x_short_out).mul_add(x_short_out, short_sum2);

        let short_mean = short_sum * short_inv;
        let long_mean = long_sum * long_inv;
        let short_var = short_sum2 * short_inv - (short_mean * short_mean);
        let long_var = long_sum2 * long_inv - (long_mean * long_mean);
        let short_std = short_var.sqrt();
        let long_std = long_var.sqrt();

        let mut k = short_std / long_std;
        if k.is_nan() {
            k = 0.0;
        }
        k *= alpha;

        val = (x_new - val).mul_add(k, val);
        out[t] = val;
    }
}
550
/// wasm32 `simd128` build of the VIDYA kernel. `vidya_into_slice` routes scalar kernel
/// requests here on that target.
///
/// Despite the name, the arithmetic is plain scalar `f64` code with no vector lanes and
/// no FMA. The main loop divides by the window length instead of multiplying by a hoisted
/// reciprocal, so results can differ from `vidya_scalar` in the last bits.
///
/// The kernel writes `out[first + long_period - 2..]`; the caller fills the warm-up prefix.
///
/// # Safety
/// Only bounds-checked indexing is used. Callers must ensure
/// `2 <= short_period <= long_period`, `first + long_period <= data.len()`, and
/// `out.len() >= data.len()`.
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
#[inline]
unsafe fn vidya_simd128(
    data: &[f64],
    short_period: usize,
    long_period: usize,
    alpha: f64,
    first: usize,
    out: &mut [f64],
) {
    let len = data.len();
    let mut long_sum = 0.0;
    let mut long_sum2 = 0.0;
    let mut short_sum = 0.0;
    let mut short_sum2 = 0.0;

    // Warm-up: accumulate the first full long window; its last `short_period` samples
    // also form the initial short window.
    for i in first..(first + long_period) {
        long_sum += data[i];
        long_sum2 += data[i] * data[i];
        if i >= (first + long_period - short_period) {
            short_sum += data[i];
            short_sum2 += data[i] * data[i];
        }
    }

    // Seed: the first emitted value is the raw input at first + long_period - 2.
    let mut val = data[first + long_period - 2];
    out[first + long_period - 2] = val;

    if first + long_period - 1 < data.len() {
        let sp = short_period as f64;
        let lp = long_period as f64;
        let short_div = 1.0 / sp;
        let long_div = 1.0 / lp;
        let short_stddev =
            (short_sum2 * short_div - (short_sum * short_div) * (short_sum * short_div)).sqrt();
        let long_stddev =
            (long_sum2 * long_div - (long_sum * long_div) * (long_sum * long_div)).sqrt();
        let mut k = short_stddev / long_stddev;
        // A NaN ratio (0/0 or a rounding-negative variance) means "no adaptation".
        if k.is_nan() {
            k = 0.0;
        }
        k *= alpha;
        val = (data[first + long_period - 1] - val) * k + val;
        out[first + long_period - 1] = val;
    }

    // Slide both windows: add the newest sample, then remove the ones leaving each window.
    for i in (first + long_period)..len {
        let new_val = data[i];
        long_sum += new_val;
        long_sum2 += new_val * new_val;
        short_sum += new_val;
        short_sum2 += new_val * new_val;

        let remove_long_idx = i - long_period;
        let remove_short_idx = i - short_period;
        let remove_long = data[remove_long_idx];
        let remove_short = data[remove_short_idx];

        long_sum -= remove_long;
        long_sum2 -= remove_long * remove_long;
        short_sum -= remove_short;
        short_sum2 -= remove_short * remove_short;

        let short_mean = short_sum / short_period as f64;
        let long_mean = long_sum / long_period as f64;

        let short_variance = short_sum2 / short_period as f64 - short_mean * short_mean;
        let long_variance = long_sum2 / long_period as f64 - long_mean * long_mean;

        let short_stddev = short_variance.sqrt();
        let long_stddev = long_variance.sqrt();

        let mut k = short_stddev / long_stddev;
        if k.is_nan() {
            k = 0.0;
        }
        k *= alpha;
        val = (new_val - val) * k + val;
        out[i] = val;
    }
}
640
/// Public AVX2 entry point for the single-series VIDYA kernel.
///
/// It forwards directly to `vidya_avx2_experimental`. Output layout matches
/// `vidya_scalar`: values from `first + long_period - 2` onward, prefix untouched.
///
/// # Safety
/// The forwarded kernel performs unchecked pointer reads and writes. Callers must ensure
/// `2 <= short_period <= long_period`, `first + long_period <= data.len()`, and
/// `out.len() >= data.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn vidya_avx2(
    data: &[f64],
    short_period: usize,
    long_period: usize,
    alpha: f64,
    first: usize,
    out: &mut [f64],
) {
    vidya_avx2_experimental(data, short_period, long_period, alpha, first, out)
}
653
/// Pointer-walking VIDYA kernel behind the AVX2 entry point.
///
/// Despite the name, this is scalar code. It reads and writes through raw pointers
/// instead of bounds-checked indexing and uses the same rolling `Σx`/`Σx²` scheme as
/// `vidya_scalar`. One difference: the sliding-window addition of `x²` here uses
/// `mul_add` (the scalar kernel adds a precomputed square), so the two can differ in the
/// last bits. `k` is reset to 0 when the std ratio is `NaN`.
///
/// The kernel writes `out[first + long_period - 2..]`; the caller fills the warm-up prefix.
///
/// # Safety
/// No bounds checks are performed. Callers must guarantee
/// `2 <= short_period <= long_period`, `first + long_period <= data.len()`, and
/// `out.len() >= data.len()`. Otherwise pointer reads and writes go out of bounds.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
unsafe fn vidya_avx2_experimental(
    data: &[f64],
    short_period: usize,
    long_period: usize,
    alpha: f64,
    first: usize,
    out: &mut [f64],
) {
    let len = data.len();
    let ptr = data.as_ptr();
    let out_ptr = out.as_mut_ptr();

    let mut long_sum = 0.0_f64;
    let mut long_sum2 = 0.0_f64;
    let mut short_sum = 0.0_f64;
    let mut short_sum2 = 0.0_f64;

    let sp_f = short_period as f64;
    let lp_f = long_period as f64;
    let short_inv = 1.0 / sp_f;
    let long_inv = 1.0 / lp_f;

    let warm_end = first + long_period;
    let short_head = warm_end - short_period;

    // Samples that belong only to the long window.
    let mut i = first;
    while i < short_head {
        let x = *ptr.add(i);
        long_sum += x;
        long_sum2 = x.mul_add(x, long_sum2);
        i += 1;
    }

    // The last `short_period` warm-up samples feed both windows.
    while i < warm_end {
        let x = *ptr.add(i);
        long_sum += x;
        long_sum2 = x.mul_add(x, long_sum2);
        short_sum += x;
        short_sum2 = x.mul_add(x, short_sum2);
        i += 1;
    }

    let idx_m2 = warm_end - 2;
    let idx_m1 = warm_end - 1;

    // Seed: the first emitted value is the raw input at warm_end - 2.
    let mut val = *ptr.add(idx_m2);
    *out_ptr.add(idx_m2) = val;

    if idx_m1 < len {
        let short_mean = short_sum * short_inv;
        let long_mean = long_sum * long_inv;
        let short_var = short_sum2 * short_inv - (short_mean * short_mean);
        let long_var = long_sum2 * long_inv - (long_mean * long_mean);
        let short_std = short_var.sqrt();
        let long_std = long_var.sqrt();

        let mut k = short_std / long_std;
        if k.is_nan() {
            k = 0.0;
        }
        k *= alpha;

        let x = *ptr.add(idx_m1);

        val = (x - val).mul_add(k, val);
        *out_ptr.add(idx_m1) = val;
    }

    // Slide both windows: add the newest sample, then drop the ones leaving each window.
    let mut t = warm_end;
    while t < len {
        let x_new = *ptr.add(t);

        long_sum += x_new;
        long_sum2 = x_new.mul_add(x_new, long_sum2);
        short_sum += x_new;
        short_sum2 = x_new.mul_add(x_new, short_sum2);

        let x_long_out = *ptr.add(t - long_period);
        let x_short_out = *ptr.add(t - short_period);
        long_sum -= x_long_out;
        long_sum2 = (-x_long_out).mul_add(x_long_out, long_sum2);
        short_sum -= x_short_out;
        short_sum2 = (-x_short_out).mul_add(x_short_out, short_sum2);

        let short_mean = short_sum * short_inv;
        let long_mean = long_sum * long_inv;
        let short_var = short_sum2 * short_inv - (short_mean * short_mean);
        let long_var = long_sum2 * long_inv - (long_mean * long_mean);
        let short_std = short_var.sqrt();
        let long_std = long_var.sqrt();

        let mut k = short_std / long_std;
        if k.is_nan() {
            k = 0.0;
        }
        k *= alpha;

        val = (x_new - val).mul_add(k, val);
        *out_ptr.add(t) = val;

        t += 1;
    }
}
760
/// Public AVX-512 entry point for the single-series VIDYA kernel.
///
/// Routes to the short-window variant when `long_period <= 32` and to the long-window
/// variant otherwise. Both variants currently delegate to `vidya_scalar`.
///
/// # Safety
/// Same contract as `vidya_scalar`: `2 <= short_period <= long_period`,
/// `first + long_period <= data.len()`, `out.len() >= data.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn vidya_avx512(
    data: &[f64],
    short_period: usize,
    long_period: usize,
    alpha: f64,
    first: usize,
    out: &mut [f64],
) {
    match long_period {
        0..=32 => vidya_avx512_short(data, short_period, long_period, alpha, first, out),
        _ => vidya_avx512_long(data, short_period, long_period, alpha, first, out),
    }
}

/// AVX-512 variant for `long_period <= 32`; currently delegates to `vidya_scalar`.
///
/// # Safety
/// Same contract as `vidya_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn vidya_avx512_short(
    data: &[f64],
    short_period: usize,
    long_period: usize,
    alpha: f64,
    first: usize,
    out: &mut [f64],
) {
    vidya_scalar(data, short_period, long_period, alpha, first, out)
}

/// AVX-512 variant for `long_period > 32`; currently delegates to `vidya_scalar`.
///
/// # Safety
/// Same contract as `vidya_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn vidya_avx512_long(
    data: &[f64],
    short_period: usize,
    long_period: usize,
    alpha: f64,
    first: usize,
    out: &mut [f64],
) {
    vidya_scalar(data, short_period, long_period, alpha, first, out)
}
803
/// Incremental VIDYA: feed one sample at a time through `update`.
#[derive(Debug, Clone)]
pub struct VidyaStream {
    /// Short volatility window length.
    short_period: usize,
    /// Long volatility window length.
    long_period: usize,
    /// Base smoothing factor.
    alpha: f64,
    /// Ring buffer of the last `long_period` inputs, written at `head`.
    long_buf: Vec<f64>,
    /// Ring buffer of the last `short_period` inputs, written at `idx % short_period`.
    short_buf: Vec<f64>,
    /// Running `Σx` over the long window.
    long_sum: f64,
    /// Running `Σx²` over the long window.
    long_sum2: f64,
    /// Running `Σx` over the short window.
    short_sum: f64,
    /// Running `Σx²` over the short window.
    short_sum2: f64,
    /// Next write position in `long_buf`.
    head: usize,
    /// Number of samples consumed so far.
    idx: usize,
    /// Most recent smoothed value (or the latest raw input during warm-up).
    val: f64,
    /// NOTE(review): only initialised by `try_new`; never read or updated by `update`.
    filled: bool,
}
820
821impl VidyaStream {
822    pub fn try_new(params: VidyaParams) -> Result<Self, VidyaError> {
823        let short_period = params.short_period.unwrap_or(2);
824        let long_period = params.long_period.unwrap_or(5);
825        let alpha = params.alpha.unwrap_or(0.2);
826
827        if short_period < 2 || long_period < short_period || long_period < 2 {
828            return Err(VidyaError::InvalidPeriod {
829                period: long_period,
830                data_len: 0,
831            });
832        }
833        if !(0.0..=1.0).contains(&alpha) || alpha.is_nan() || alpha.is_infinite() {
834            return Err(VidyaError::InvalidAlpha { alpha });
835        }
836        Ok(Self {
837            short_period,
838            long_period,
839            alpha,
840            long_buf: alloc_with_nan_prefix(long_period, long_period),
841            short_buf: alloc_with_nan_prefix(short_period, short_period),
842            long_sum: 0.0,
843            long_sum2: 0.0,
844            short_sum: 0.0,
845            short_sum2: 0.0,
846            head: 0,
847            idx: 0,
848            val: f64::NAN,
849            filled: false,
850        })
851    }
852
853    #[inline(always)]
854    pub fn update(&mut self, x: f64) -> Option<f64> {
855        let long_tail = self.long_buf[self.head];
856        let short_idx = self.idx % self.short_period;
857        let short_tail = self.short_buf[short_idx];
858
859        let phase2_start = self.long_period - self.short_period;
860
861        self.long_sum += x;
862        self.long_sum2 = x.mul_add(x, self.long_sum2);
863
864        if self.idx >= phase2_start {
865            self.short_sum += x;
866            self.short_sum2 = x.mul_add(x, self.short_sum2);
867        }
868
869        if self.idx >= self.long_period {
870            self.long_sum -= long_tail;
871            self.long_sum2 = (-long_tail).mul_add(long_tail, self.long_sum2);
872
873            self.short_sum -= short_tail;
874            self.short_sum2 = (-short_tail).mul_add(short_tail, self.short_sum2);
875        }
876
877        self.long_buf[self.head] = x;
878        self.short_buf[short_idx] = x;
879
880        let mut h = self.head + 1;
881        if h == self.long_period {
882            h = 0;
883        }
884        self.head = h;
885
886        self.idx += 1;
887
888        if self.idx < self.long_period - 1 {
889            self.val = x;
890            return None;
891        }
892        if self.idx == self.long_period - 1 {
893            self.val = x;
894            return Some(self.val);
895        }
896
897        let short_inv = 1.0 / (self.short_period as f64);
898        let long_inv = 1.0 / (self.long_period as f64);
899
900        let short_mean = self.short_sum * short_inv;
901        let long_mean = self.long_sum * long_inv;
902
903        let short_var = self.short_sum2 * short_inv - (short_mean * short_mean);
904        let long_var = self.long_sum2 * long_inv - (long_mean * long_mean);
905
906        let mut k = 0.0;
907        if long_var > 0.0 && short_var > 0.0 {
908            k = (short_var / long_var).sqrt() * self.alpha;
909        }
910
911        self.val = (x - self.val).mul_add(k, self.val);
912        Some(self.val)
913    }
914}
915
/// Parameter sweep for batch VIDYA.
///
/// Each axis is a `(start, end, step)` tuple. For the integer axes, a `step` of 0 or
/// `start == end` pins the axis to `start`.
#[derive(Clone, Debug)]
pub struct VidyaBatchRange {
    pub short_period: (usize, usize, usize),
    pub long_period: (usize, usize, usize),
    pub alpha: (f64, f64, f64),
}

impl Default for VidyaBatchRange {
    /// Fixed `short_period = 2` and `alpha = 0.2`; `long_period` sweeps 5..=254 in steps of 1.
    fn default() -> Self {
        let short_period = (2, 2, 0);
        let long_period = (5, 254, 1);
        let alpha = (0.2, 0.2, 0.0);
        Self {
            short_period,
            long_period,
            alpha,
        }
    }
}
932
933#[derive(Clone, Debug, Default)]
934pub struct VidyaBatchBuilder {
935    range: VidyaBatchRange,
936    kernel: Kernel,
937}
938
939impl VidyaBatchBuilder {
940    pub fn new() -> Self {
941        Self::default()
942    }
943    pub fn kernel(mut self, k: Kernel) -> Self {
944        self.kernel = k;
945        self
946    }
947    #[inline]
948    pub fn short_period_range(mut self, start: usize, end: usize, step: usize) -> Self {
949        self.range.short_period = (start, end, step);
950        self
951    }
952    #[inline]
953    pub fn short_period_static(mut self, n: usize) -> Self {
954        self.range.short_period = (n, n, 0);
955        self
956    }
957    #[inline]
958    pub fn long_period_range(mut self, start: usize, end: usize, step: usize) -> Self {
959        self.range.long_period = (start, end, step);
960        self
961    }
962    #[inline]
963    pub fn long_period_static(mut self, n: usize) -> Self {
964        self.range.long_period = (n, n, 0);
965        self
966    }
967    #[inline]
968    pub fn alpha_range(mut self, start: f64, end: f64, step: f64) -> Self {
969        self.range.alpha = (start, end, step);
970        self
971    }
972    #[inline]
973    pub fn alpha_static(mut self, a: f64) -> Self {
974        self.range.alpha = (a, a, 0.0);
975        self
976    }
977    pub fn apply_slice(self, data: &[f64]) -> Result<VidyaBatchOutput, VidyaError> {
978        vidya_batch_with_kernel(data, &self.range, self.kernel)
979    }
980    pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<VidyaBatchOutput, VidyaError> {
981        VidyaBatchBuilder::new().kernel(k).apply_slice(data)
982    }
983    pub fn apply_candles(self, c: &Candles, src: &str) -> Result<VidyaBatchOutput, VidyaError> {
984        let slice = source_type(c, src);
985        self.apply_slice(slice)
986    }
987    pub fn with_default_candles(c: &Candles) -> Result<VidyaBatchOutput, VidyaError> {
988        VidyaBatchBuilder::new()
989            .kernel(Kernel::Auto)
990            .apply_candles(c, "close")
991    }
992}
993
994pub fn vidya_batch_with_kernel(
995    data: &[f64],
996    sweep: &VidyaBatchRange,
997    k: Kernel,
998) -> Result<VidyaBatchOutput, VidyaError> {
999    if data.is_empty() {
1000        return Err(VidyaError::EmptyInputData);
1001    }
1002    let kernel = match k {
1003        Kernel::Auto => match detect_best_batch_kernel() {
1004            Kernel::Avx512Batch => Kernel::Avx2Batch,
1005            other => other,
1006        },
1007        other if other.is_batch() => other,
1008        other => {
1009            return Err(VidyaError::InvalidKernelForBatch(other));
1010        }
1011    };
1012
1013    let simd = match kernel {
1014        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1015        Kernel::Avx512Batch => Kernel::Avx512,
1016        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1017        Kernel::Avx2Batch => Kernel::Avx2,
1018        Kernel::ScalarBatch => Kernel::Scalar,
1019        _ => Kernel::Scalar,
1020    };
1021    vidya_batch_par_slice(data, sweep, simd)
1022}
1023
1024#[derive(Clone, Debug)]
1025pub struct VidyaBatchOutput {
1026    pub values: Vec<f64>,
1027    pub combos: Vec<VidyaParams>,
1028    pub rows: usize,
1029    pub cols: usize,
1030}
1031impl VidyaBatchOutput {
1032    pub fn row_for_params(&self, p: &VidyaParams) -> Option<usize> {
1033        self.combos.iter().position(|c| {
1034            c.short_period.unwrap_or(2) == p.short_period.unwrap_or(2)
1035                && c.long_period.unwrap_or(5) == p.long_period.unwrap_or(5)
1036                && (c.alpha.unwrap_or(0.2) - p.alpha.unwrap_or(0.2)).abs() < 1e-12
1037        })
1038    }
1039    pub fn values_for(&self, p: &VidyaParams) -> Option<&[f64]> {
1040        self.row_for_params(p).map(|row| {
1041            let start = row * self.cols;
1042            &self.values[start..start + self.cols]
1043        })
1044    }
1045}
1046
1047#[inline(always)]
1048fn expand_grid(r: &VidyaBatchRange) -> Result<Vec<VidyaParams>, VidyaError> {
1049    fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, VidyaError> {
1050        if step == 0 || start == end {
1051            return Ok(vec![start]);
1052        }
1053        if start < end {
1054            let mut v = Vec::new();
1055            let mut x = start;
1056            let st = step.max(1);
1057            while x <= end {
1058                v.push(x);
1059                x = x.saturating_add(st);
1060            }
1061            if v.is_empty() {
1062                return Err(VidyaError::InvalidRange {
1063                    start: start.to_string(),
1064                    end: end.to_string(),
1065                    step: step.to_string(),
1066                });
1067            }
1068            return Ok(v);
1069        }
1070        let mut v = Vec::new();
1071        let mut x = start as isize;
1072        let end_i = end as isize;
1073        let st = (step as isize).max(1);
1074        while x >= end_i {
1075            v.push(x as usize);
1076            x -= st;
1077        }
1078        if v.is_empty() {
1079            return Err(VidyaError::InvalidRange {
1080                start: start.to_string(),
1081                end: end.to_string(),
1082                step: step.to_string(),
1083            });
1084        }
1085        Ok(v)
1086    }
1087    fn axis_f64((start, end, step): (f64, f64, f64)) -> Result<Vec<f64>, VidyaError> {
1088        if step.abs() < 1e-12 || (start - end).abs() < 1e-12 {
1089            return Ok(vec![start]);
1090        }
1091        if start < end {
1092            let mut v = Vec::new();
1093            let mut x = start;
1094            let st = step.abs();
1095            while x <= end + 1e-12 {
1096                v.push(x);
1097                x += st;
1098            }
1099            if v.is_empty() {
1100                return Err(VidyaError::InvalidRange {
1101                    start: start.to_string(),
1102                    end: end.to_string(),
1103                    step: step.to_string(),
1104                });
1105            }
1106            return Ok(v);
1107        }
1108        let mut v = Vec::new();
1109        let mut x = start;
1110        let st = step.abs();
1111        while x + 1e-12 >= end {
1112            v.push(x);
1113            x -= st;
1114        }
1115        if v.is_empty() {
1116            return Err(VidyaError::InvalidRange {
1117                start: start.to_string(),
1118                end: end.to_string(),
1119                step: step.to_string(),
1120            });
1121        }
1122        Ok(v)
1123    }
1124
1125    let short_periods = axis_usize(r.short_period)?;
1126    let long_periods = axis_usize(r.long_period)?;
1127    let alphas = axis_f64(r.alpha)?;
1128
1129    let cap = short_periods
1130        .len()
1131        .checked_mul(long_periods.len())
1132        .and_then(|x| x.checked_mul(alphas.len()))
1133        .ok_or_else(|| VidyaError::InvalidRange {
1134            start: "cap".into(),
1135            end: "overflow".into(),
1136            step: "mul".into(),
1137        })?;
1138
1139    let mut out = Vec::with_capacity(cap);
1140    for &sp in &short_periods {
1141        for &lp in &long_periods {
1142            for &a in &alphas {
1143                out.push(VidyaParams {
1144                    short_period: Some(sp),
1145                    long_period: Some(lp),
1146                    alpha: Some(a),
1147                });
1148            }
1149        }
1150    }
1151    Ok(out)
1152}
1153
/// Runs the VIDYA parameter sweep over `data`, filling the output rows sequentially.
///
/// `kern` is the per-row kernel passed through to `vidya_batch_inner`, which also defines the
/// error conditions.
#[inline(always)]
pub fn vidya_batch_slice(
    data: &[f64],
    sweep: &VidyaBatchRange,
    kern: Kernel,
) -> Result<VidyaBatchOutput, VidyaError> {
    vidya_batch_inner(data, sweep, kern, false)
}
1162
/// Runs the VIDYA parameter sweep over `data`, filling the output rows in parallel with
/// rayon (sequentially on wasm32).
///
/// `kern` is the per-row kernel passed through to `vidya_batch_inner`, which also defines the
/// error conditions.
#[inline(always)]
pub fn vidya_batch_par_slice(
    data: &[f64],
    sweep: &VidyaBatchRange,
    kern: Kernel,
) -> Result<VidyaBatchOutput, VidyaError> {
    vidya_batch_inner(data, sweep, kern, true)
}
1171
1172#[inline(always)]
1173fn vidya_batch_inner(
1174    data: &[f64],
1175    sweep: &VidyaBatchRange,
1176    kern: Kernel,
1177    parallel: bool,
1178) -> Result<VidyaBatchOutput, VidyaError> {
1179    let combos = expand_grid(sweep)?;
1180    if data.is_empty() {
1181        return Err(VidyaError::EmptyInputData);
1182    }
1183    let first = data
1184        .iter()
1185        .position(|x| !x.is_nan())
1186        .ok_or(VidyaError::AllValuesNaN)?;
1187    let max_long = combos.iter().map(|c| c.long_period.unwrap()).max().unwrap();
1188    if data.len() - first < max_long {
1189        return Err(VidyaError::NotEnoughValidData {
1190            needed: max_long,
1191            valid: data.len() - first,
1192        });
1193    }
1194
1195    let rows = combos.len();
1196    let cols = data.len();
1197
1198    let _ = rows
1199        .checked_mul(cols)
1200        .ok_or_else(|| VidyaError::InvalidRange {
1201            start: rows.to_string(),
1202            end: cols.to_string(),
1203            step: "rows*cols".into(),
1204        })?;
1205
1206    let warmup_periods: Vec<usize> = combos
1207        .iter()
1208        .map(|c| first + c.long_period.unwrap() - 2)
1209        .collect();
1210
1211    let mut buf_mu = make_uninit_matrix(rows, cols);
1212    init_matrix_prefixes(&mut buf_mu, cols, &warmup_periods);
1213
1214    let mut buf_guard = core::mem::ManuallyDrop::new(buf_mu);
1215    let out: &mut [f64] = unsafe {
1216        core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
1217    };
1218    let mut values = out;
1219
1220    let do_row = |row: usize, out_row: &mut [f64]| unsafe {
1221        let p = &combos[row];
1222        let sp = p.short_period.unwrap();
1223        let lp = p.long_period.unwrap();
1224        let a = p.alpha.unwrap();
1225        match kern {
1226            Kernel::Scalar | Kernel::ScalarBatch => {
1227                vidya_row_scalar(data, first, sp, lp, a, out_row)
1228            }
1229            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1230            Kernel::Avx2 | Kernel::Avx2Batch => vidya_row_avx2(data, first, sp, lp, a, out_row),
1231            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1232            Kernel::Avx512 | Kernel::Avx512Batch => {
1233                vidya_row_avx512(data, first, sp, lp, a, out_row)
1234            }
1235            _ => vidya_row_scalar(data, first, sp, lp, a, out_row),
1236        }
1237    };
1238    if parallel {
1239        #[cfg(not(target_arch = "wasm32"))]
1240        {
1241            values
1242                .par_chunks_mut(cols)
1243                .enumerate()
1244                .for_each(|(row, slice)| do_row(row, slice));
1245        }
1246
1247        #[cfg(target_arch = "wasm32")]
1248        {
1249            for (row, slice) in values.chunks_mut(cols).enumerate() {
1250                do_row(row, slice);
1251            }
1252        }
1253    } else {
1254        for (row, slice) in values.chunks_mut(cols).enumerate() {
1255            do_row(row, slice);
1256        }
1257    }
1258
1259    let values = unsafe {
1260        Vec::from_raw_parts(
1261            buf_guard.as_mut_ptr() as *mut f64,
1262            buf_guard.len(),
1263            buf_guard.capacity(),
1264        )
1265    };
1266
1267    Ok(VidyaBatchOutput {
1268        values,
1269        combos,
1270        rows,
1271        cols,
1272    })
1273}
1274
/// Row kernel for the scalar path: fills `out` with VIDYA for one parameter combination.
///
/// This is a thin forwarder to `vidya_scalar`. Note the argument order differs: `first` is
/// passed last there.
///
/// # Safety
/// Callers must uphold `vidya_scalar`'s contract. NOTE(review): presumably this means
/// `out.len() == data.len()` and already-validated periods — confirm against `vidya_scalar`.
#[inline(always)]
unsafe fn vidya_row_scalar(
    data: &[f64],
    first: usize,
    short_period: usize,
    long_period: usize,
    alpha: f64,
    out: &mut [f64],
) {
    vidya_scalar(data, short_period, long_period, alpha, first, out);
}
1286
/// Row kernel for the AVX2 path (only built with `nightly-avx` on x86_64).
///
/// This forwards to `vidya_avx2_experimental`, using the same argument reordering as
/// `vidya_row_scalar`.
///
/// # Safety
/// Callers must uphold `vidya_avx2_experimental`'s contract. NOTE(review): presumably this
/// includes AVX2 being available on the running CPU — confirm.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn vidya_row_avx2(
    data: &[f64],
    first: usize,
    short_period: usize,
    long_period: usize,
    alpha: f64,
    out: &mut [f64],
) {
    vidya_avx2_experimental(data, short_period, long_period, alpha, first, out);
}
1299
/// Row kernel for the AVX-512 path (only built with `nightly-avx` on x86_64).
///
/// Dispatches on the long window: `long_period <= 32` goes to the "short" variant and longer
/// windows to the "long" variant.
///
/// # Safety
/// Same contract as the variant it dispatches to.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn vidya_row_avx512(
    data: &[f64],
    first: usize,
    short_period: usize,
    long_period: usize,
    alpha: f64,
    out: &mut [f64],
) {
    match long_period {
        0..=32 => vidya_row_avx512_short(data, first, short_period, long_period, alpha, out),
        _ => vidya_row_avx512_long(data, first, short_period, long_period, alpha, out),
    }
}
1316
/// AVX-512 row kernel for `long_period <= 32`.
///
/// Currently a placeholder: it forwards directly to the scalar `vidya_scalar`.
///
/// # Safety
/// Callers must uphold `vidya_scalar`'s contract.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn vidya_row_avx512_short(
    data: &[f64],
    first: usize,
    short_period: usize,
    long_period: usize,
    alpha: f64,
    out: &mut [f64],
) {
    vidya_scalar(data, short_period, long_period, alpha, first, out);
}
1329
/// AVX-512 row kernel for `long_period > 32`.
///
/// Currently a placeholder: it forwards directly to the scalar `vidya_scalar`.
///
/// # Safety
/// Callers must uphold `vidya_scalar`'s contract.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn vidya_row_avx512_long(
    data: &[f64],
    first: usize,
    short_period: usize,
    long_period: usize,
    alpha: f64,
    out: &mut [f64],
) {
    vidya_scalar(data, short_period, long_period, alpha, first, out);
}
1342
1343#[cfg(test)]
1344mod tests {
1345    use super::*;
1346    use crate::skip_if_unsupported;
1347    use crate::utilities::data_loader::read_candles_from_csv;
1348
1349    fn check_vidya_partial_params(
1350        test_name: &str,
1351        kernel: Kernel,
1352    ) -> Result<(), Box<dyn std::error::Error>> {
1353        skip_if_unsupported!(kernel, test_name);
1354        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1355        let candles = read_candles_from_csv(file_path)?;
1356        let default_params = VidyaParams {
1357            short_period: None,
1358            long_period: Some(10),
1359            alpha: None,
1360        };
1361        let input_default = VidyaInput::from_candles(&candles, "close", default_params);
1362        let output_default = vidya_with_kernel(&input_default, kernel)?;
1363        assert_eq!(output_default.values.len(), candles.close.len());
1364        Ok(())
1365    }
1366
1367    fn check_vidya_accuracy(
1368        test_name: &str,
1369        kernel: Kernel,
1370    ) -> Result<(), Box<dyn std::error::Error>> {
1371        skip_if_unsupported!(kernel, test_name);
1372        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1373        let candles = read_candles_from_csv(file_path)?;
1374        let close_prices = &candles.close;
1375
1376        let params = VidyaParams {
1377            short_period: Some(2),
1378            long_period: Some(5),
1379            alpha: Some(0.2),
1380        };
1381        let input = VidyaInput::from_candles(&candles, "close", params);
1382        let vidya_result = vidya_with_kernel(&input, kernel)?;
1383        assert_eq!(vidya_result.values.len(), close_prices.len());
1384
1385        if vidya_result.values.len() >= 5 {
1386            let expected_last_five = [
1387                59553.42785306692,
1388                59503.60445032524,
1389                59451.72283651444,
1390                59413.222561244685,
1391                59239.716526894175,
1392            ];
1393            let start_index = vidya_result.values.len() - 5;
1394            let result_last_five = &vidya_result.values[start_index..];
1395            for (i, &value) in result_last_five.iter().enumerate() {
1396                let expected_value = expected_last_five[i];
1397                assert!(
1398                    (value - expected_value).abs() < 1e-1,
1399                    "VIDYA mismatch at index {}: expected {}, got {}",
1400                    i,
1401                    expected_value,
1402                    value
1403                );
1404            }
1405        }
1406        Ok(())
1407    }
1408
1409    fn check_vidya_default_candles(
1410        test_name: &str,
1411        kernel: Kernel,
1412    ) -> Result<(), Box<dyn std::error::Error>> {
1413        skip_if_unsupported!(kernel, test_name);
1414        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1415        let candles = read_candles_from_csv(file_path)?;
1416        let input = VidyaInput::with_default_candles(&candles);
1417        match input.data {
1418            VidyaData::Candles { source, .. } => assert_eq!(source, "close"),
1419            _ => panic!("Expected VidyaData::Candles"),
1420        }
1421        let output = vidya_with_kernel(&input, kernel)?;
1422        assert_eq!(output.values.len(), candles.close.len());
1423        Ok(())
1424    }
1425
1426    fn check_vidya_invalid_params(
1427        test_name: &str,
1428        kernel: Kernel,
1429    ) -> Result<(), Box<dyn std::error::Error>> {
1430        skip_if_unsupported!(kernel, test_name);
1431        let data = [10.0, 20.0, 30.0];
1432        let params = VidyaParams {
1433            short_period: Some(0),
1434            long_period: Some(5),
1435            alpha: Some(0.2),
1436        };
1437        let input = VidyaInput::from_slice(&data, params);
1438        let result = vidya_with_kernel(&input, kernel);
1439        assert!(result.is_err(), "Expected error for invalid short period");
1440        Ok(())
1441    }
1442
1443    fn check_vidya_exceeding_data_length(
1444        test_name: &str,
1445        kernel: Kernel,
1446    ) -> Result<(), Box<dyn std::error::Error>> {
1447        skip_if_unsupported!(kernel, test_name);
1448        let data = [10.0, 20.0, 30.0];
1449        let params = VidyaParams {
1450            short_period: Some(2),
1451            long_period: Some(5),
1452            alpha: Some(0.2),
1453        };
1454        let input = VidyaInput::from_slice(&data, params);
1455        let result = vidya_with_kernel(&input, kernel);
1456        assert!(result.is_err(), "Expected error for period > data.len()");
1457        Ok(())
1458    }
1459
1460    fn check_vidya_very_small_data_set(
1461        test_name: &str,
1462        kernel: Kernel,
1463    ) -> Result<(), Box<dyn std::error::Error>> {
1464        skip_if_unsupported!(kernel, test_name);
1465        let data = [42.0, 43.0];
1466        let params = VidyaParams {
1467            short_period: Some(2),
1468            long_period: Some(5),
1469            alpha: Some(0.2),
1470        };
1471        let input = VidyaInput::from_slice(&data, params);
1472        let result = vidya_with_kernel(&input, kernel);
1473        assert!(
1474            result.is_err(),
1475            "Expected error for data smaller than long period"
1476        );
1477        Ok(())
1478    }
1479
1480    fn check_vidya_reinput(
1481        test_name: &str,
1482        kernel: Kernel,
1483    ) -> Result<(), Box<dyn std::error::Error>> {
1484        skip_if_unsupported!(kernel, test_name);
1485        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1486        let candles = read_candles_from_csv(file_path)?;
1487
1488        let first_params = VidyaParams {
1489            short_period: Some(2),
1490            long_period: Some(5),
1491            alpha: Some(0.2),
1492        };
1493        let first_input = VidyaInput::from_candles(&candles, "close", first_params);
1494        let first_result = vidya_with_kernel(&first_input, kernel)?;
1495
1496        let second_params = VidyaParams {
1497            short_period: Some(2),
1498            long_period: Some(5),
1499            alpha: Some(0.2),
1500        };
1501        let second_input = VidyaInput::from_slice(&first_result.values, second_params);
1502        let second_result = vidya_with_kernel(&second_input, kernel)?;
1503        assert_eq!(second_result.values.len(), first_result.values.len());
1504        Ok(())
1505    }
1506
1507    fn check_vidya_nan_handling(
1508        test_name: &str,
1509        kernel: Kernel,
1510    ) -> Result<(), Box<dyn std::error::Error>> {
1511        skip_if_unsupported!(kernel, test_name);
1512        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1513        let candles = read_candles_from_csv(file_path)?;
1514        let params = VidyaParams {
1515            short_period: Some(2),
1516            long_period: Some(5),
1517            alpha: Some(0.2),
1518        };
1519        let input = VidyaInput::from_candles(&candles, "close", params);
1520        let vidya_result = vidya_with_kernel(&input, kernel)?;
1521        if vidya_result.values.len() > 10 {
1522            for i in 10..vidya_result.values.len() {
1523                assert!(!vidya_result.values[i].is_nan());
1524            }
1525        }
1526        Ok(())
1527    }
1528
1529    fn check_vidya_streaming(
1530        test_name: &str,
1531        kernel: Kernel,
1532    ) -> Result<(), Box<dyn std::error::Error>> {
1533        skip_if_unsupported!(kernel, test_name);
1534        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1535        let candles = read_candles_from_csv(file_path)?;
1536
1537        let short_period = 2;
1538        let long_period = 5;
1539        let alpha = 0.2;
1540
1541        let params = VidyaParams {
1542            short_period: Some(short_period),
1543            long_period: Some(long_period),
1544            alpha: Some(alpha),
1545        };
1546        let input = VidyaInput::from_candles(&candles, "close", params.clone());
1547        let batch_output = vidya_with_kernel(&input, kernel)?.values;
1548
1549        let mut stream = VidyaStream::try_new(params.clone())?;
1550        let mut stream_values = Vec::with_capacity(candles.close.len());
1551        for &price in &candles.close {
1552            match stream.update(price) {
1553                Some(val) => stream_values.push(val),
1554                None => stream_values.push(f64::NAN),
1555            }
1556        }
1557        assert_eq!(batch_output.len(), stream_values.len());
1558        for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
1559            if b.is_nan() && s.is_nan() {
1560                continue;
1561            }
1562            let diff = (b - s).abs();
1563            assert!(
1564                diff < 1e-3,
1565                "[{}] VIDYA streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
1566                test_name,
1567                i,
1568                b,
1569                s,
1570                diff
1571            );
1572        }
1573        Ok(())
1574    }
1575
1576    #[cfg(debug_assertions)]
1577    fn check_vidya_no_poison(
1578        test_name: &str,
1579        kernel: Kernel,
1580    ) -> Result<(), Box<dyn std::error::Error>> {
1581        skip_if_unsupported!(kernel, test_name);
1582
1583        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1584        let candles = read_candles_from_csv(file_path)?;
1585
1586        let test_params = vec![
1587            VidyaParams::default(),
1588            VidyaParams {
1589                short_period: Some(1),
1590                long_period: Some(2),
1591                alpha: Some(0.1),
1592            },
1593            VidyaParams {
1594                short_period: Some(2),
1595                long_period: Some(3),
1596                alpha: Some(0.2),
1597            },
1598            VidyaParams {
1599                short_period: Some(2),
1600                long_period: Some(5),
1601                alpha: Some(0.5),
1602            },
1603            VidyaParams {
1604                short_period: Some(3),
1605                long_period: Some(7),
1606                alpha: Some(0.3),
1607            },
1608            VidyaParams {
1609                short_period: Some(4),
1610                long_period: Some(10),
1611                alpha: Some(0.2),
1612            },
1613            VidyaParams {
1614                short_period: Some(5),
1615                long_period: Some(20),
1616                alpha: Some(0.4),
1617            },
1618            VidyaParams {
1619                short_period: Some(10),
1620                long_period: Some(30),
1621                alpha: Some(0.2),
1622            },
1623            VidyaParams {
1624                short_period: Some(15),
1625                long_period: Some(50),
1626                alpha: Some(0.3),
1627            },
1628            VidyaParams {
1629                short_period: Some(20),
1630                long_period: Some(100),
1631                alpha: Some(0.2),
1632            },
1633            VidyaParams {
1634                short_period: Some(50),
1635                long_period: Some(200),
1636                alpha: Some(0.1),
1637            },
1638            VidyaParams {
1639                short_period: Some(2),
1640                long_period: Some(10),
1641                alpha: Some(0.8),
1642            },
1643            VidyaParams {
1644                short_period: Some(3),
1645                long_period: Some(15),
1646                alpha: Some(1.0),
1647            },
1648            VidyaParams {
1649                short_period: Some(1),
1650                long_period: Some(100),
1651                alpha: Some(0.01),
1652            },
1653            VidyaParams {
1654                short_period: Some(99),
1655                long_period: Some(100),
1656                alpha: Some(0.99),
1657            },
1658        ];
1659
1660        for (param_idx, params) in test_params.iter().enumerate() {
1661            let input = VidyaInput::from_candles(&candles, "close", params.clone());
1662            let output = vidya_with_kernel(&input, kernel)?;
1663
1664            for (i, &val) in output.values.iter().enumerate() {
1665                if val.is_nan() {
1666                    continue;
1667                }
1668
1669                let bits = val.to_bits();
1670
1671                if bits == 0x11111111_11111111 {
1672                    panic!(
1673                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1674						 with params: {:?} (param set {})",
1675                        test_name, val, bits, i, params, param_idx
1676                    );
1677                }
1678
1679                if bits == 0x22222222_22222222 {
1680                    panic!(
1681                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1682						 with params: {:?} (param set {})",
1683                        test_name, val, bits, i, params, param_idx
1684                    );
1685                }
1686
1687                if bits == 0x33333333_33333333 {
1688                    panic!(
1689                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1690						 with params: {:?} (param set {})",
1691                        test_name, val, bits, i, params, param_idx
1692                    );
1693                }
1694            }
1695        }
1696
1697        Ok(())
1698    }
1699
1700    #[cfg(not(debug_assertions))]
1701    fn check_vidya_no_poison(
1702        _test_name: &str,
1703        _kernel: Kernel,
1704    ) -> Result<(), Box<dyn std::error::Error>> {
1705        Ok(())
1706    }
1707
1708    #[cfg(feature = "proptest")]
1709    #[allow(clippy::float_cmp)]
1710    fn check_vidya_property(
1711        test_name: &str,
1712        kernel: Kernel,
1713    ) -> Result<(), Box<dyn std::error::Error>> {
1714        use proptest::prelude::*;
1715        skip_if_unsupported!(kernel, test_name);
1716
1717        let strat = (2usize..=20).prop_flat_map(|short_period| {
1718            let long_min = (short_period + 1).max(2);
1719            let long_max = 100.min(long_min + 50);
1720
1721            (long_min..=long_max).prop_flat_map(move |long_period| {
1722                let data_len = long_period.max(10)..400;
1723
1724                (
1725                    prop::collection::vec(
1726                        (-0.05f64..0.05f64).prop_filter("finite", |x| x.is_finite()),
1727                        data_len,
1728                    )
1729                    .prop_map(|returns| {
1730                        let mut prices = Vec::with_capacity(returns.len());
1731                        let mut price = 100.0;
1732
1733                        for (i, ret) in returns.iter().enumerate() {
1734                            let volatility_factor = if (i / 20) % 2 == 0 { 0.5 } else { 2.0 };
1735                            price *= 1.0 + (ret * volatility_factor);
1736                            prices.push(price);
1737                        }
1738                        prices
1739                    }),
1740                    Just(short_period),
1741                    Just(long_period),
1742                    0.01f64..1.0f64,
1743                )
1744            })
1745        });
1746
1747        proptest::test_runner::TestRunner::default()
1748			.run(&strat, |(data, short_period, long_period, alpha)| {
1749				let params = VidyaParams {
1750					short_period: Some(short_period),
1751					long_period: Some(long_period),
1752					alpha: Some(alpha),
1753				};
1754				let input = VidyaInput::from_slice(&data, params.clone());
1755
1756
1757				let VidyaOutput { values: out } = vidya_with_kernel(&input, kernel).unwrap();
1758				let VidyaOutput { values: ref_out } = vidya_with_kernel(&input, Kernel::Scalar).unwrap();
1759
1760
1761                for i in 0..data.len() {
1762                    let y = out[i];
1763                    let r = ref_out[i];
1764
1765                    if !y.is_finite() || !r.is_finite() {
1766                        prop_assert!(y.to_bits() == r.to_bits(),
1767                            "[{}] finite/NaN mismatch at idx {}: {} vs {}", test_name, i, y, r);
1768                        continue;
1769                    }
1770
1771	                    let ulp_diff = y.to_bits().abs_diff(r.to_bits());
1772	                    prop_assert!(
1773
1774
1775	                        (y - r).abs() <= 5e-8 || ulp_diff <= 4,
1776	                        "[{}] kernel mismatch at idx {}: {} vs {} (ULP={})",
1777	                        test_name, i, y, r, ulp_diff
1778	                    );
1779	                }
1780
1781
1782				let first = data.iter().position(|&x| !x.is_nan()).unwrap_or(0);
1783				let first_valid_idx = if first + long_period >= 2 {
1784					first + long_period - 2
1785				} else {
1786					0
1787				};
1788
1789
1790				for i in 0..first_valid_idx.min(data.len()) {
1791					prop_assert!(out[i].is_nan(),
1792						"[{}] Expected NaN during warmup at idx {}, got {}", test_name, i, out[i]);
1793				}
1794
1795
1796				if first_valid_idx < data.len() {
1797					prop_assert!(!out[first_valid_idx].is_nan(),
1798						"[{}] Expected valid value at first_valid_idx {}, got NaN", test_name, first_valid_idx);
1799				}
1800
1801
1802				let warmup_end = first + long_period - 2;
1803
1804
1805				if data.len() > warmup_end + 1 {
1806					let data_min = data.iter().cloned().fold(f64::INFINITY, f64::min);
1807					let data_max = data.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
1808
1809					let range = data_max - data_min;
1810
1811
1812					let alpha_factor = 1.0 + alpha * 2.0;
1813					let margin = if range < 1.0 {
1814
1815						let avg_magnitude = (data_max.abs() + data_min.abs()) / 2.0;
1816						avg_magnitude * 0.3 * alpha_factor
1817					} else {
1818
1819
1820						range * 0.5 * alpha_factor
1821					};
1822
1823					for i in (warmup_end + 1)..data.len() {
1824						let y = out[i];
1825						if y.is_finite() {
1826							prop_assert!(
1827								y >= data_min - margin && y <= data_max + margin,
1828								"[{}] Output {} at idx {} outside reasonable range [{}, {}] (alpha={:.3})",
1829								test_name, y, i, data_min - margin, data_max + margin, alpha
1830							);
1831						}
1832					}
1833				}
1834
1835
1836				if alpha < 0.05 && data.len() > warmup_end + 10 {
1837
1838					let vidya_section = &out[(warmup_end + 1)..];
1839					if vidya_section.len() > 2 {
1840
1841						for window in vidya_section.windows(2) {
1842							let change_ratio = (window[1] - window[0]).abs() / window[0].abs().max(1e-10);
1843							prop_assert!(
1844								change_ratio < 0.1,
1845								"[{}] With alpha={}, VIDYA should be stable but found large change ratio {}",
1846								test_name, alpha, change_ratio
1847							);
1848						}
1849					}
1850				}
1851
1852
1853				if data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10) && data.len() > warmup_end + 1 {
1854
1855					let constant_value = data[0];
1856					for i in (warmup_end + 1)..data.len() {
1857						prop_assert!(
1858							(out[i] - constant_value).abs() <= 1e-9,
1859							"[{}] Constant input should produce constant output, got {} expected {}",
1860							test_name, out[i], constant_value
1861						);
1862					}
1863				}
1864
1865
1866
1867
1868
1869				if alpha >= 0.05 && data.len() > warmup_end + 20 {
1870
1871					let mut same_direction_count = 0;
1872					let mut total_movements = 0;
1873					let mut frozen_periods = 0;
1874
1875					for i in (warmup_end + 1)..data.len() {
1876						let price_change = data[i] - data[i - 1];
1877						let vidya_change = out[i] - out[i - 1];
1878
1879
1880						if price_change.abs() > 1e-6 && vidya_change.abs() <= 1e-10 {
1881							frozen_periods += 1;
1882						}
1883
1884						else if price_change.abs() > 1e-6 && vidya_change.abs() > 1e-10 {
1885							total_movements += 1;
1886							if price_change.signum() == vidya_change.signum() {
1887								same_direction_count += 1;
1888							}
1889						}
1890					}
1891
1892
1893
1894
1895					if total_movements > 10 && frozen_periods < (data.len() - warmup_end) / 2 {
1896						let direction_ratio = same_direction_count as f64 / total_movements as f64;
1897						prop_assert!(
1898							direction_ratio >= 0.40,
1899							"[{}] VIDYA should generally follow price direction when moving, but only followed {:.1}% of the time (frozen for {} periods)",
1900							test_name, direction_ratio * 100.0, frozen_periods
1901						);
1902					}
1903				}
1904
1905
1906				for (i, &val) in out.iter().enumerate() {
1907					if val.is_finite() {
1908						let bits = val.to_bits();
1909						prop_assert!(
1910							bits != 0x11111111_11111111 &&
1911							bits != 0x22222222_22222222 &&
1912							bits != 0x33333333_33333333,
1913							"[{}] Found poison value {} (0x{:016X}) at index {}",
1914							test_name, val, bits, i
1915						);
1916					}
1917				}
1918
1919				Ok(())
1920			})
1921			.unwrap();
1922
1923        Ok(())
1924    }
1925
1926    macro_rules! generate_all_vidya_tests {
1927        ($($test_fn:ident),*) => {
1928            paste! {
1929                $(
1930                    #[test]
1931                    fn [<$test_fn _scalar_f64>]() {
1932                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1933                    }
1934                )*
1935                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1936                $(
1937                    #[test]
1938                    fn [<$test_fn _avx2_f64>]() {
1939                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1940                    }
1941                    #[test]
1942                    fn [<$test_fn _avx512_f64>]() {
1943                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1944                    }
1945                )*
1946                #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
1947                $(
1948                    #[test]
1949                    fn [<$test_fn _simd128_f64>]() {
1950                        let _ = $test_fn(stringify!([<$test_fn _simd128_f64>]), Kernel::Scalar);
1951                    }
1952                )*
1953            }
1954        }
1955    }
1956
1957    generate_all_vidya_tests!(
1958        check_vidya_partial_params,
1959        check_vidya_accuracy,
1960        check_vidya_default_candles,
1961        check_vidya_invalid_params,
1962        check_vidya_exceeding_data_length,
1963        check_vidya_very_small_data_set,
1964        check_vidya_reinput,
1965        check_vidya_nan_handling,
1966        check_vidya_streaming,
1967        check_vidya_no_poison
1968    );
1969
1970    #[cfg(feature = "proptest")]
1971    generate_all_vidya_tests!(check_vidya_property);
1972
1973    fn check_batch_default_row(
1974        test: &str,
1975        kernel: Kernel,
1976    ) -> Result<(), Box<dyn std::error::Error>> {
1977        skip_if_unsupported!(kernel, test);
1978
1979        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1980        let c = read_candles_from_csv(file)?;
1981
1982        let output = VidyaBatchBuilder::new()
1983            .kernel(kernel)
1984            .apply_candles(&c, "close")?;
1985
1986        let def = VidyaParams::default();
1987        let row = output.values_for(&def).expect("default row missing");
1988
1989        assert_eq!(row.len(), c.close.len());
1990
1991        let expected = [
1992            59553.42785306692,
1993            59503.60445032524,
1994            59451.72283651444,
1995            59413.222561244685,
1996            59239.716526894175,
1997        ];
1998        let start = row.len() - 5;
1999        for (i, &v) in row[start..].iter().enumerate() {
2000            assert!(
2001                (v - expected[i]).abs() < 1e-1,
2002                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
2003            );
2004        }
2005        Ok(())
2006    }
2007
2008    #[cfg(debug_assertions)]
2009    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn std::error::Error>> {
2010        skip_if_unsupported!(kernel, test);
2011
2012        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2013        let c = read_candles_from_csv(file)?;
2014
2015        let test_configs = vec![
2016            (2, 5, 1, 5, 10, 1, 0.2, 0.2, 0.0),
2017            (2, 10, 2, 10, 30, 5, 0.1, 0.5, 0.1),
2018            (5, 20, 5, 30, 60, 10, 0.2, 0.2, 0.0),
2019            (10, 30, 10, 50, 100, 25, 0.3, 0.3, 0.0),
2020            (2, 2, 0, 5, 50, 5, 0.1, 0.9, 0.2),
2021            (5, 15, 5, 20, 20, 0, 0.2, 0.8, 0.3),
2022            (1, 3, 1, 4, 8, 2, 0.5, 0.5, 0.0),
2023            (20, 50, 15, 100, 200, 50, 0.1, 0.3, 0.1),
2024            (2, 2, 0, 3, 3, 0, 0.1, 1.0, 0.1),
2025        ];
2026
2027        for (cfg_idx, &(s_start, s_end, s_step, l_start, l_end, l_step, a_start, a_end, a_step)) in
2028            test_configs.iter().enumerate()
2029        {
2030            let output = VidyaBatchBuilder::new()
2031                .kernel(kernel)
2032                .short_period_range(s_start, s_end, s_step)
2033                .long_period_range(l_start, l_end, l_step)
2034                .alpha_range(a_start, a_end, a_step)
2035                .apply_candles(&c, "close")?;
2036
2037            for (idx, &val) in output.values.iter().enumerate() {
2038                if val.is_nan() {
2039                    continue;
2040                }
2041
2042                let bits = val.to_bits();
2043                let row = idx / output.cols;
2044                let col = idx % output.cols;
2045                let combo = &output.combos[row];
2046
2047                if bits == 0x11111111_11111111 {
2048                    panic!(
2049                        "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
2050						 at row {} col {} (flat index {}) with params: {:?}",
2051                        test, cfg_idx, val, bits, row, col, idx, combo
2052                    );
2053                }
2054
2055                if bits == 0x22222222_22222222 {
2056                    panic!(
2057                        "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
2058						 at row {} col {} (flat index {}) with params: {:?}",
2059                        test, cfg_idx, val, bits, row, col, idx, combo
2060                    );
2061                }
2062
2063                if bits == 0x33333333_33333333 {
2064                    panic!(
2065                        "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
2066						 at row {} col {} (flat index {}) with params: {:?}",
2067                        test, cfg_idx, val, bits, row, col, idx, combo
2068                    );
2069                }
2070            }
2071        }
2072
2073        Ok(())
2074    }
2075
2076    #[cfg(not(debug_assertions))]
2077    fn check_batch_no_poison(
2078        _test: &str,
2079        _kernel: Kernel,
2080    ) -> Result<(), Box<dyn std::error::Error>> {
2081        Ok(())
2082    }
2083
2084    macro_rules! gen_batch_tests {
2085        ($fn_name:ident) => {
2086            paste! {
2087                #[test] fn [<$fn_name _scalar>]()      {
2088                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
2089                }
2090                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2091                #[test] fn [<$fn_name _avx2>]()        {
2092                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
2093                }
2094                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2095                #[test] fn [<$fn_name _avx512>]()      {
2096                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
2097                }
2098                #[test] fn [<$fn_name _auto_detect>]() {
2099                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
2100                }
2101            }
2102        };
2103    }
2104    gen_batch_tests!(check_batch_default_row);
2105    gen_batch_tests!(check_batch_no_poison);
2106}
2107
2108#[inline(always)]
2109pub fn vidya_batch_inner_into(
2110    data: &[f64],
2111    sweep: &VidyaBatchRange,
2112    kern: Kernel,
2113    parallel: bool,
2114    out: &mut [f64],
2115) -> Result<Vec<VidyaParams>, VidyaError> {
2116    let combos = expand_grid(sweep)?;
2117    if data.is_empty() {
2118        return Err(VidyaError::EmptyInputData);
2119    }
2120
2121    let first = data
2122        .iter()
2123        .position(|x| !x.is_nan())
2124        .ok_or(VidyaError::AllValuesNaN)?;
2125    let max_long = combos.iter().map(|c| c.long_period.unwrap()).max().unwrap();
2126    if data.len() - first < max_long {
2127        return Err(VidyaError::NotEnoughValidData {
2128            needed: max_long,
2129            valid: data.len() - first,
2130        });
2131    }
2132
2133    let rows = combos.len();
2134    let cols = data.len();
2135
2136    let expected = rows
2137        .checked_mul(cols)
2138        .ok_or_else(|| VidyaError::InvalidRange {
2139            start: rows.to_string(),
2140            end: cols.to_string(),
2141            step: "rows*cols".into(),
2142        })?;
2143    if out.len() < expected {
2144        return Err(VidyaError::OutputLengthMismatch {
2145            expected,
2146            got: out.len(),
2147        });
2148    }
2149
2150    let warmup_periods: Vec<usize> = combos
2151        .iter()
2152        .map(|c| first + c.long_period.unwrap() - 2)
2153        .collect();
2154
2155    for (row, &warmup) in warmup_periods.iter().enumerate() {
2156        let row_start = row * cols;
2157        out[row_start..row_start + warmup].fill(f64::NAN);
2158    }
2159
2160    let do_row = |row: usize, out_row: &mut [f64]| unsafe {
2161        let p = &combos[row];
2162        let sp = p.short_period.unwrap();
2163        let lp = p.long_period.unwrap();
2164        let a = p.alpha.unwrap();
2165        match kern {
2166            Kernel::Scalar | Kernel::ScalarBatch => {
2167                vidya_row_scalar(data, first, sp, lp, a, out_row)
2168            }
2169            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2170            Kernel::Avx2 | Kernel::Avx2Batch => vidya_row_avx2(data, first, sp, lp, a, out_row),
2171            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2172            Kernel::Avx512 | Kernel::Avx512Batch => {
2173                vidya_row_avx512(data, first, sp, lp, a, out_row)
2174            }
2175            _ => vidya_row_scalar(data, first, sp, lp, a, out_row),
2176        }
2177    };
2178
2179    if parallel {
2180        #[cfg(not(target_arch = "wasm32"))]
2181        {
2182            out.par_chunks_mut(cols)
2183                .enumerate()
2184                .for_each(|(row, slice)| do_row(row, slice));
2185        }
2186        #[cfg(target_arch = "wasm32")]
2187        {
2188            for (row, slice) in out.chunks_mut(cols).enumerate() {
2189                do_row(row, slice);
2190            }
2191        }
2192    } else {
2193        for (row, slice) in out.chunks_mut(cols).enumerate() {
2194            do_row(row, slice);
2195        }
2196    }
2197
2198    Ok(combos)
2199}
2200
/// Computes VIDYA for a 1-D float64 array and returns a new float64 array of the
/// same length. Raises `ValueError` for an invalid kernel name or invalid
/// parameters/input.
#[cfg(feature = "python")]
#[pyfunction(name = "vidya")]
#[pyo3(signature = (data, short_period, long_period, alpha, kernel=None))]
pub fn vidya_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    short_period: usize,
    long_period: usize,
    alpha: f64,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let series = data.as_slice()?;
    let chosen_kernel = validate_kernel(kernel, false)?;

    let input = VidyaInput::from_slice(
        series,
        VidyaParams {
            short_period: Some(short_period),
            long_period: Some(long_period),
            alpha: Some(alpha),
        },
    );

    // The numeric work runs with the GIL released; only the Vec comes back.
    let computed =
        py.allow_threads(|| vidya_with_kernel(&input, chosen_kernel).map(|out| out.values));

    match computed {
        Ok(values) => Ok(values.into_pyarray(py)),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
2228
/// Python wrapper around the incremental `VidyaStream`.
#[cfg(feature = "python")]
#[pyclass(name = "VidyaStream")]
pub struct VidyaStreamPy {
    stream: VidyaStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl VidyaStreamPy {
    /// Builds a stream for the given parameters; raises `ValueError` when
    /// `VidyaStream::try_new` rejects them.
    #[new]
    fn new(short_period: usize, long_period: usize, alpha: f64) -> PyResult<Self> {
        VidyaStream::try_new(VidyaParams {
            short_period: Some(short_period),
            long_period: Some(long_period),
            alpha: Some(alpha),
        })
        .map(|stream| Self { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feeds one value and returns whatever the underlying stream yields
    /// (`None` maps to Python `None`).
    fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
2254
/// Computes VIDYA over a parameter grid and returns a dict with:
/// `values` (float64, shape `(rows, len(data))`, one row per combination),
/// `short_periods` / `long_periods` (uint64) and `alphas` (float64), all in
/// row order. `kernel=None`/auto never selects AVX-512 here; it is downgraded
/// to AVX2.
#[cfg(feature = "python")]
#[pyfunction(name = "vidya_batch")]
#[pyo3(signature = (data, short_period_range, long_period_range, alpha_range, kernel=None))]
pub fn vidya_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    short_period_range: (usize, usize, usize),
    long_period_range: (usize, usize, usize),
    alpha_range: (f64, f64, f64),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let series = data.as_slice()?;
    let sweep = VidyaBatchRange {
        short_period: short_period_range,
        long_period: long_period_range,
        alpha: alpha_range,
    };

    // Size the output from the parameter grid before doing any numeric work.
    let grid = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let (rows, cols) = (grid.len(), series.len());
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;
    let requested = validate_kernel(kernel, true)?;

    // SAFETY: the array is freshly created and not yet visible to Python, so the
    // mutable slice is unique. Its contents start uninitialised; this relies on
    // `vidya_batch_inner_into` writing every element (warmup NaNs plus the row
    // kernels). NOTE(review): confirm the row kernels cover all post-warmup slots.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let out_slice = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            let batch_kernel = match requested {
                Kernel::Auto => match detect_best_batch_kernel() {
                    Kernel::Avx512Batch => Kernel::Avx2Batch,
                    detected => detected,
                },
                explicit => explicit,
            };
            // Batch kernels map onto their per-row counterparts.
            let row_kernel = match batch_kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                other => other,
            };
            vidya_batch_inner_into(series, &sweep, row_kernel, true, out_slice)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let short_periods: Vec<u64> = combos
        .iter()
        .map(|p| p.short_period.unwrap() as u64)
        .collect();
    let long_periods: Vec<u64> = combos
        .iter()
        .map(|p| p.long_period.unwrap() as u64)
        .collect();
    let alphas: Vec<f64> = combos.iter().map(|p| p.alpha.unwrap()).collect();

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item("short_periods", short_periods.into_pyarray(py))?;
    dict.set_item("long_periods", long_periods.into_pyarray(py))?;
    dict.set_item("alphas", alphas.into_pyarray(py))?;

    Ok(dict)
}
2334
/// Python handle to a VIDYA result matrix held in CUDA device memory.
///
/// The buffer is exposed through `__cuda_array_interface__` (version 3) and the
/// DLPack protocol. `_ctx` is never read here; it is held alongside the buffer,
/// presumably so the CUDA context outlives it.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(
    module = "ta_indicators.cuda",
    name = "VidyaDeviceArrayF32",
    unsendable
)]
pub struct VidyaDeviceArrayF32Py {
    pub(crate) inner: DeviceArrayF32,
    pub(crate) _ctx: std::sync::Arc<Context>,
    pub(crate) device_id: u32,
}

#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl VidyaDeviceArrayF32Py {
    /// CUDA Array Interface v3 description of the buffer: a row-major
    /// little-endian `float32` matrix of shape `(rows, cols)`, writable.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let item = std::mem::size_of::<f32>();
        let d = PyDict::new(py);
        d.set_item("shape", (self.inner.rows, self.inner.cols))?;
        d.set_item("typestr", "<f4")?;
        // Byte strides for a C-contiguous (row-major) layout.
        d.set_item("strides", (self.inner.cols * item, item))?;
        // (device pointer, read_only = false)
        d.set_item("data", (self.inner.device_ptr() as usize, false))?;
        d.set_item("version", 3)?;
        Ok(d)
    }

    /// DLPack device tuple `(2, device_id)`; 2 is `kDLCUDA` in the DLPack enum.
    fn __dlpack_device__(&self) -> PyResult<(i32, i32)> {
        Ok((2, self.device_id as i32))
    }

    /// Exports the buffer as a DLPack capsule and moves ownership into it. The
    /// handle is left holding an empty 0x0 array afterwards.
    ///
    /// Parameter names follow the DLPack / array-API protocol (`stream`,
    /// `max_version`, `dl_device`, `copy`), because consumers such as
    /// `torch.from_dlpack` pass them as keywords. A `dl_device` different from
    /// this buffer's device raises `ValueError`, since cross-device copies are
    /// not implemented.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        let (kdl, alloc_dev) = self.__dlpack_device__()?;

        if let Some(dev_obj) = dl_device.as_ref() {
            // A `dl_device` that is not an `(int, int)` tuple is ignored.
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    return Err(PyValueError::new_err(if wants_copy {
                        "device copy not implemented for __dlpack__"
                    } else {
                        "dl_device mismatch for __dlpack__"
                    }));
                }
            }
        }

        // The consumer stream is accepted but not used for synchronisation.
        // NOTE(review): this assumes the producing kernels have completed before
        // the handle reached Python — TODO confirm.
        let _ = stream;

        // Move the real buffer out, leaving an empty placeholder behind so the
        // capsule becomes its sole owner.
        let placeholder =
            DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let inner = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32 {
                buf: placeholder,
                rows: 0,
                cols: 0,
            },
        );

        let rows = inner.rows;
        let cols = inner.cols;
        let buf = inner.buf;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
2421
/// Runs the VIDYA parameter sweep on the GPU for one float32 series and returns
/// the result as a device-resident `VidyaDeviceArrayF32`. Raises `ValueError`
/// if CUDA is unavailable or the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "vidya_cuda_batch_dev")]
#[pyo3(signature = (data, short_period_range, long_period_range, alpha_range, device_id=0))]
pub fn vidya_cuda_batch_dev_py(
    py: Python<'_>,
    data: PyReadonlyArray1<'_, f32>,
    short_period_range: (usize, usize, usize),
    long_period_range: (usize, usize, usize),
    alpha_range: (f64, f64, f64),
    device_id: usize,
) -> PyResult<VidyaDeviceArrayF32Py> {
    use crate::cuda::cuda_available;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let host = data.as_slice()?;
    let sweep = VidyaBatchRange {
        short_period: short_period_range,
        long_period: long_period_range,
        alpha: alpha_range,
    };

    // GPU setup and launch happen with the GIL released.
    let (inner, ctx, dev) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaVidya::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let arr = cuda
            .vidya_batch_dev(host, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((arr, cuda.context_arc_clone(), cuda.device_id()))
    })?;

    Ok(VidyaDeviceArrayF32Py {
        inner,
        _ctx: ctx,
        device_id: dev,
    })
}
2456
/// Runs one VIDYA parameter set on the GPU across many float32 series supplied
/// as a flat time-major buffer of `cols * rows` values. Returns a device-resident
/// `VidyaDeviceArrayF32`. Raises `ValueError` if CUDA is unavailable, the buffer
/// length does not match `cols * rows`, or the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "vidya_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm, cols, rows, short_period, long_period, alpha, device_id=0))]
pub fn vidya_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm: PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    short_period: usize,
    long_period: usize,
    alpha: f64,
    device_id: usize,
) -> PyResult<VidyaDeviceArrayF32Py> {
    use crate::cuda::cuda_available;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let host = data_tm.as_slice()?;
    let shape_matches = cols
        .checked_mul(rows)
        .map(|n| n == host.len())
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;
    if !shape_matches {
        return Err(PyValueError::new_err("time-major input length mismatch"));
    }

    let params = VidyaParams {
        short_period: Some(short_period),
        long_period: Some(long_period),
        alpha: Some(alpha),
    };

    // GPU setup and launch happen with the GIL released.
    let (inner, ctx, dev) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaVidya::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let arr = cuda
            .vidya_many_series_one_param_time_major_dev(host, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((arr, cuda.context_arc_clone(), cuda.device_id()))
    })?;

    Ok(VidyaDeviceArrayF32Py {
        inner,
        _ctx: ctx,
        device_id: dev,
    })
}
2499
/// JS binding: computes VIDYA for `data` with the auto-selected kernel and
/// returns a freshly allocated array of the same length. Errors are surfaced as
/// string `JsValue`s.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn vidya_js(
    data: &[f64],
    short_period: usize,
    long_period: usize,
    alpha: f64,
) -> Result<Vec<f64>, JsValue> {
    let input = VidyaInput::from_slice(
        data,
        VidyaParams {
            short_period: Some(short_period),
            long_period: Some(long_period),
            alpha: Some(alpha),
        },
    );

    let mut values = vec![0.0; data.len()];
    if let Err(e) = vidya_into_slice(&mut values, &input, Kernel::Auto) {
        return Err(JsValue::from_str(&e.to_string()));
    }
    Ok(values)
}
2521
/// Raw-pointer JS binding: reads `len` f64 values at `in_ptr` and writes VIDYA
/// for them to `len` f64 slots at `out_ptr` (e.g. buffers from `vidya_alloc`).
///
/// In-place use is supported. If the two regions overlap at all, including
/// `in_ptr == out_ptr`, the result is computed into a temporary and copied
/// out, so no `&[f64]` is live alongside an aliasing `&mut [f64]` during the
/// computation.
///
/// # Errors
/// A string `JsValue` for a null pointer or any error from `vidya_into_slice`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn vidya_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    short_period: usize,
    long_period: usize,
    alpha: f64,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // Two `len`-element regions overlap iff each starts before the other ends.
    let span = len.saturating_mul(std::mem::size_of::<f64>());
    let in_addr = in_ptr as usize;
    let out_addr = out_ptr as usize;
    let overlaps = in_addr == out_addr
        || (in_addr < out_addr.saturating_add(span) && out_addr < in_addr.saturating_add(span));

    let params = VidyaParams {
        short_period: Some(short_period),
        long_period: Some(long_period),
        alpha: Some(alpha),
    };

    // SAFETY: the caller must pass pointers to at least `len` valid f64 slots
    // each. When the regions overlap, the output slice is only created after the
    // input has been fully consumed into `temp`.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let input = VidyaInput::from_slice(data, params);

        if overlaps {
            let mut temp = vec![0.0; len];
            vidya_into_slice(&mut temp, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            out.copy_from_slice(&temp);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            vidya_into_slice(out, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }
    Ok(())
}
2559
/// Allocates an uninitialised buffer with capacity for `len` f64 values and
/// hands its pointer to JS. The memory is intentionally leaked here and must be
/// released with `vidya_free(ptr, len)` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn vidya_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop suppresses the Vec destructor, transferring ownership to JS.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
2568
/// Releases a buffer previously returned by `vidya_alloc(len)`. `len` must be
/// the same value passed to `vidya_alloc`. A null `ptr` is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn vidya_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr`/`len` come from `vidya_alloc`, so the allocation has
    // capacity `len`. The length is 0 because the contents may never have been
    // initialised (`from_raw_parts` requires the first `length` elements to be
    // initialised); f64 has no destructor, so only the allocation is freed.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
2578
/// Parameter grid accepted by the JS `vidya_batch` binding. Each range is a
/// `(start, end, step)` tuple forwarded unchanged into `VidyaBatchRange`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Debug, Serialize, Deserialize)]
pub struct VidyaBatchConfig {
    pub short_period_range: (usize, usize, usize),
    pub long_period_range: (usize, usize, usize),
    pub alpha_range: (f64, f64, f64),
}

/// Result returned by the JS `vidya_batch` binding.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Debug, Serialize, Deserialize)]
pub struct VidyaBatchJsOutput {
    /// Row-major `rows x cols` matrix; row `r` corresponds to `combos[r]`.
    pub values: Vec<f64>,
    /// Parameter combination for each row, in row order.
    pub combos: Vec<VidyaParams>,
    /// Number of parameter combinations.
    pub rows: usize,
    /// Length of the input series.
    pub cols: usize,
}
2595
/// JS binding: computes VIDYA for every combination in `config` and returns a
/// serialized `VidyaBatchJsOutput` (row-major values plus the combinations).
///
/// The output buffer is sized from the actual grid (`rows * data.len()`). The
/// previous fixed `data.len() * 232` buffer failed for more than 232
/// combinations and panicked for fewer.
///
/// # Errors
/// A string `JsValue` for an undeserializable config, an invalid grid, a
/// `rows * cols` overflow, any batch computation error, or a serialization
/// failure.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = vidya_batch)]
pub fn vidya_batch_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let config: VidyaBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = VidyaBatchRange {
        short_period: config.short_period_range,
        long_period: config.long_period_range,
        alpha: config.alpha_range,
    };

    let grid_rows = expand_grid(&sweep)
        .map_err(|e| JsValue::from_str(&e.to_string()))?
        .len();
    let total = grid_rows
        .checked_mul(data.len())
        .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;
    let mut output = vec![0.0; total];

    // AVX-512 is never selected here; it is downgraded to AVX2.
    let kernel = match detect_best_kernel() {
        Kernel::Avx512 => Kernel::Avx2,
        other => other,
    };
    let combos = vidya_batch_inner_into(data, &sweep, kernel, false, &mut output)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let rows = combos.len();
    let cols = data.len();

    let result = VidyaBatchJsOutput {
        values: output,
        combos,
        rows,
        cols,
    };

    serde_wasm_bindgen::to_value(&result)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2630
/// Raw-pointer JS binding for the batch sweep. It reads `len` f64 values at
/// `in_ptr` and writes the row-major `rows x len` result to `out_ptr`, which
/// must have room for `rows * len` values. It returns `rows`, the number of
/// parameter combinations.
///
/// If the input region overlaps the output region, the input is copied first,
/// so the computation never reads from memory it is writing.
///
/// # Errors
/// A string `JsValue` for a null pointer, an invalid grid, a `rows * len`
/// overflow, or any batch computation error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn vidya_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    short_period_start: usize,
    short_period_end: usize,
    short_period_step: usize,
    long_period_start: usize,
    long_period_end: usize,
    long_period_step: usize,
    alpha_start: f64,
    alpha_end: f64,
    alpha_step: f64,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    let sweep = VidyaBatchRange {
        short_period: (short_period_start, short_period_end, short_period_step),
        long_period: (long_period_start, long_period_end, long_period_step),
        alpha: (alpha_start, alpha_end, alpha_step),
    };

    let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let total_size = combos
        .len()
        .checked_mul(len)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;

    // Byte ranges [in, in + len*8) and [out, out + total*8) overlap iff each
    // starts before the other ends.
    let elem = std::mem::size_of::<f64>();
    let in_addr = in_ptr as usize;
    let out_addr = out_ptr as usize;
    let overlaps = in_addr < out_addr.saturating_add(total_size.saturating_mul(elem))
        && out_addr < in_addr.saturating_add(len.saturating_mul(elem));

    let kernel = match detect_best_kernel() {
        Kernel::Avx512 => Kernel::Avx2,
        other => other,
    };

    // SAFETY: the caller must pass `len` readable f64s at `in_ptr` and
    // `rows * len` writable f64s at `out_ptr`. On overlap, the input is
    // snapshotted before the mutable output slice is created, so no shared
    // slice aliases it.
    unsafe {
        let owned_input: Vec<f64>;
        let data: &[f64] = if overlaps {
            owned_input = std::slice::from_raw_parts(in_ptr, len).to_vec();
            &owned_input
        } else {
            std::slice::from_raw_parts(in_ptr, len)
        };
        let out = std::slice::from_raw_parts_mut(out_ptr, total_size);

        vidya_batch_inner_into(data, &sweep, kernel, false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(combos.len())
}