1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::moving_averages::DeviceArrayF32;
3#[cfg(all(feature = "python", feature = "cuda"))]
4use crate::cuda::CudaVidya;
5#[cfg(all(feature = "python", feature = "cuda"))]
6use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
7#[cfg(feature = "python")]
8use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
9#[cfg(feature = "python")]
10use pyo3::exceptions::PyValueError;
11#[cfg(feature = "python")]
12use pyo3::prelude::*;
13#[cfg(feature = "python")]
14use pyo3::types::PyDict;
15#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
16use serde::{Deserialize, Serialize};
17#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
18use wasm_bindgen::prelude::*;
19
20use crate::utilities::data_loader::{source_type, Candles};
21use crate::utilities::enums::Kernel;
22use crate::utilities::helpers::{
23 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
24 make_uninit_matrix,
25};
26#[cfg(feature = "python")]
27use crate::utilities::kernel_validation::validate_kernel;
28use aligned_vec::{AVec, CACHELINE_ALIGN};
29#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
30use core::arch::x86_64::*;
31#[cfg(all(feature = "python", feature = "cuda"))]
32use cust::context::Context;
33#[cfg(all(feature = "python", feature = "cuda"))]
34use cust::memory::DeviceBuffer;
35use paste::paste;
36#[cfg(not(target_arch = "wasm32"))]
37use rayon::prelude::*;
38use std::convert::AsRef;
39use thiserror::Error;
40
41impl<'a> AsRef<[f64]> for VidyaInput<'a> {
42 #[inline(always)]
43 fn as_ref(&self) -> &[f64] {
44 match &self.data {
45 VidyaData::Slice(slice) => slice,
46 VidyaData::Candles { candles, source } => source_type(candles, source),
47 }
48 }
49}
50
#[cfg(test)]
mod tests_into_parity {
    use super::*;
    use crate::utilities::data_loader::read_candles_from_csv;

    /// Two samples agree when both are NaN, exactly equal, or within `1e-12` of each other.
    fn eq_or_both_nan(a: f64, b: f64) -> bool {
        if a.is_nan() && b.is_nan() {
            return true;
        }
        a == b || (a - b).abs() <= 1e-12
    }

    /// The write-into-buffer entry point must reproduce the allocating `vidya` output
    /// element for element on the bundled candle data set.
    #[test]
    fn test_vidya_into_matches_api() -> Result<(), Box<dyn std::error::Error>> {
        let candles = read_candles_from_csv("src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv")?;
        let input = VidyaInput::with_default_candles(&candles);

        let baseline = vidya(&input)?.values;
        let mut out = vec![0.0; candles.close.len()];

        // `vidya_into` is compiled out on wasm builds, so use the slice variant there.
        #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
        {
            vidya_into(&input, &mut out)?;
        }
        #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
        {
            vidya_into_slice(&mut out, &input, Kernel::Auto)?;
        }

        assert_eq!(baseline.len(), out.len());

        let pairs = baseline.iter().copied().zip(out.iter().copied());
        for (i, (a, b)) in pairs.enumerate() {
            assert!(
                eq_or_both_nan(a, b),
                "Mismatch at index {}: baseline={} into={}",
                i,
                a,
                b
            );
        }

        Ok(())
    }
}
93
/// Where the input series for a VIDYA calculation comes from.
#[derive(Debug, Clone)]
pub enum VidyaData<'a> {
    /// A candle set plus the name of the field to read; the name is resolved through
    /// `source_type` when the data is accessed.
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A caller-provided series that is used as-is.
    Slice(&'a [f64]),
}
102
/// Result of a single VIDYA run.
#[derive(Debug, Clone)]
pub struct VidyaOutput {
    /// One value per input sample; `vidya_with_kernel` allocates it with a NaN warm-up prefix.
    pub values: Vec<f64>,
}
107
/// Tunable parameters for VIDYA. A `None` field falls back to the built-in default
/// (2, 5 and 0.2 respectively) wherever the parameters are read in this file.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct VidyaParams {
    /// Length of the short standard-deviation window (validated to be at least 2).
    pub short_period: Option<usize>,
    /// Length of the long standard-deviation window (validated to be at least `short_period`).
    pub long_period: Option<usize>,
    /// Base smoothing factor (validated to lie in `[0, 1]`).
    pub alpha: Option<f64>,
}
118
119impl Default for VidyaParams {
120 fn default() -> Self {
121 Self {
122 short_period: Some(2),
123 long_period: Some(5),
124 alpha: Some(0.2),
125 }
126 }
127}
128
/// A data source paired with the parameters to run VIDYA with.
#[derive(Debug, Clone)]
pub struct VidyaInput<'a> {
    /// The series to process (candle field or raw slice).
    pub data: VidyaData<'a>,
    /// Window lengths and alpha; unset fields use the defaults.
    pub params: VidyaParams,
}
134
135impl<'a> VidyaInput<'a> {
136 #[inline]
137 pub fn from_candles(c: &'a Candles, s: &'a str, p: VidyaParams) -> Self {
138 Self {
139 data: VidyaData::Candles {
140 candles: c,
141 source: s,
142 },
143 params: p,
144 }
145 }
146 #[inline]
147 pub fn from_slice(sl: &'a [f64], p: VidyaParams) -> Self {
148 Self {
149 data: VidyaData::Slice(sl),
150 params: p,
151 }
152 }
153 #[inline]
154 pub fn with_default_candles(c: &'a Candles) -> Self {
155 Self::from_candles(c, "close", VidyaParams::default())
156 }
157 #[inline]
158 pub fn get_short_period(&self) -> usize {
159 self.params.short_period.unwrap_or(2)
160 }
161 #[inline]
162 pub fn get_long_period(&self) -> usize {
163 self.params.long_period.unwrap_or(5)
164 }
165 #[inline]
166 pub fn get_alpha(&self) -> f64 {
167 self.params.alpha.unwrap_or(0.2)
168 }
169}
170
/// Fluent builder for a single VIDYA run or a streaming instance.
/// Unset parameters stay `None` and take their defaults when the run is performed.
#[derive(Copy, Clone, Debug)]
pub struct VidyaBuilder {
    short_period: Option<usize>,
    long_period: Option<usize>,
    alpha: Option<f64>,
    // Kernel passed to `vidya_with_kernel` by `apply` / `apply_slice`.
    kernel: Kernel,
}
178
179impl Default for VidyaBuilder {
180 fn default() -> Self {
181 Self {
182 short_period: None,
183 long_period: None,
184 alpha: None,
185 kernel: Kernel::Auto,
186 }
187 }
188}
189
190impl VidyaBuilder {
191 #[inline(always)]
192 pub fn new() -> Self {
193 Self::default()
194 }
195 #[inline(always)]
196 pub fn short_period(mut self, n: usize) -> Self {
197 self.short_period = Some(n);
198 self
199 }
200 #[inline(always)]
201 pub fn long_period(mut self, n: usize) -> Self {
202 self.long_period = Some(n);
203 self
204 }
205 #[inline(always)]
206 pub fn alpha(mut self, a: f64) -> Self {
207 self.alpha = Some(a);
208 self
209 }
210 #[inline(always)]
211 pub fn kernel(mut self, k: Kernel) -> Self {
212 self.kernel = k;
213 self
214 }
215 #[inline(always)]
216 pub fn apply(self, c: &Candles) -> Result<VidyaOutput, VidyaError> {
217 let p = VidyaParams {
218 short_period: self.short_period,
219 long_period: self.long_period,
220 alpha: self.alpha,
221 };
222 let i = VidyaInput::from_candles(c, "close", p);
223 vidya_with_kernel(&i, self.kernel)
224 }
225 #[inline(always)]
226 pub fn apply_slice(self, d: &[f64]) -> Result<VidyaOutput, VidyaError> {
227 let p = VidyaParams {
228 short_period: self.short_period,
229 long_period: self.long_period,
230 alpha: self.alpha,
231 };
232 let i = VidyaInput::from_slice(d, p);
233 vidya_with_kernel(&i, self.kernel)
234 }
235 #[inline(always)]
236 pub fn into_stream(self) -> Result<VidyaStream, VidyaError> {
237 let p = VidyaParams {
238 short_period: self.short_period,
239 long_period: self.long_period,
240 alpha: self.alpha,
241 };
242 VidyaStream::try_new(p)
243 }
244}
245
/// Errors reported by the VIDYA entry points in this module.
#[derive(Debug, Error)]
pub enum VidyaError {
    /// The input series has length zero.
    #[error("vidya: Input data slice is empty.")]
    EmptyInputData,
    /// No non-NaN sample exists in the input.
    #[error("vidya: All values are NaN.")]
    AllValuesNaN,
    /// A window length failed validation; `period` is the value reported by the check.
    #[error("vidya: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },
    /// Fewer samples remain after the first non-NaN value than the long window needs.
    #[error("vidya: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// `alpha` is outside `[0, 1]` (this includes NaN and infinities).
    #[error("vidya: Invalid alpha: {alpha}")]
    InvalidAlpha { alpha: f64 },
    /// The destination buffer length differs from the input length.
    #[error("vidya: Output length mismatch: expected {expected}, got {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// A batch sweep range could not be expanded, or a size computation overflowed.
    #[error("vidya: Invalid range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: String,
        end: String,
        step: String,
    },
    /// A non-batch kernel was passed to the batch entry point.
    #[error("vidya: Invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(Kernel),
}
269
270#[inline]
271pub fn vidya(input: &VidyaInput) -> Result<VidyaOutput, VidyaError> {
272 vidya_with_kernel(input, Kernel::Auto)
273}
274
275pub fn vidya_with_kernel(input: &VidyaInput, kernel: Kernel) -> Result<VidyaOutput, VidyaError> {
276 let data: &[f64] = match &input.data {
277 VidyaData::Candles { candles, source } => source_type(candles, source),
278 VidyaData::Slice(sl) => sl,
279 };
280
281 if data.is_empty() {
282 return Err(VidyaError::EmptyInputData);
283 }
284
285 let short_period = input.get_short_period();
286 let long_period = input.get_long_period();
287 let alpha = input.get_alpha();
288
289 if short_period < 2 {
290 return Err(VidyaError::InvalidPeriod {
291 period: short_period,
292 data_len: data.len(),
293 });
294 }
295 if long_period < short_period || long_period < 2 || long_period > data.len() {
296 return Err(VidyaError::InvalidPeriod {
297 period: long_period,
298 data_len: data.len(),
299 });
300 }
301 if !(0.0..=1.0).contains(&alpha) || alpha.is_nan() || alpha.is_infinite() {
302 return Err(VidyaError::InvalidAlpha { alpha });
303 }
304
305 let first = data
306 .iter()
307 .position(|&x| !x.is_nan())
308 .ok_or(VidyaError::AllValuesNaN)?;
309 if (data.len() - first) < long_period {
310 return Err(VidyaError::NotEnoughValidData {
311 needed: long_period,
312 valid: data.len() - first,
313 });
314 }
315
316 let chosen = match kernel {
317 Kernel::Auto => match detect_best_kernel() {
318 Kernel::Avx512 => Kernel::Avx2,
319 other => other,
320 },
321 other => other,
322 };
323
324 let warmup_period = first + long_period - 2;
325 let mut out = alloc_with_nan_prefix(data.len(), warmup_period);
326 unsafe {
327 match chosen {
328 Kernel::Scalar | Kernel::ScalarBatch => {
329 vidya_scalar(data, short_period, long_period, alpha, first, &mut out)
330 }
331 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
332 Kernel::Avx2 | Kernel::Avx2Batch => {
333 vidya_avx2(data, short_period, long_period, alpha, first, &mut out)
334 }
335 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
336 Kernel::Avx512 | Kernel::Avx512Batch => {
337 vidya_avx512(data, short_period, long_period, alpha, first, &mut out)
338 }
339 _ => vidya_scalar(data, short_period, long_period, alpha, first, &mut out),
340 }
341 }
342
343 Ok(VidyaOutput { values: out })
344}
345
346#[inline]
347pub fn vidya_into_slice(
348 dst: &mut [f64],
349 input: &VidyaInput,
350 kern: Kernel,
351) -> Result<(), VidyaError> {
352 let data: &[f64] = match &input.data {
353 VidyaData::Candles { candles, source } => source_type(candles, source),
354 VidyaData::Slice(sl) => sl,
355 };
356
357 if data.is_empty() {
358 return Err(VidyaError::EmptyInputData);
359 }
360
361 if dst.len() != data.len() {
362 return Err(VidyaError::OutputLengthMismatch {
363 expected: data.len(),
364 got: dst.len(),
365 });
366 }
367
368 let short_period = input.get_short_period();
369 let long_period = input.get_long_period();
370 let alpha = input.get_alpha();
371
372 if short_period < 2 {
373 return Err(VidyaError::InvalidPeriod {
374 period: short_period,
375 data_len: data.len(),
376 });
377 }
378 if long_period < short_period || long_period < 2 || long_period > data.len() {
379 return Err(VidyaError::InvalidPeriod {
380 period: long_period,
381 data_len: data.len(),
382 });
383 }
384 if !(0.0..=1.0).contains(&alpha) || alpha.is_nan() || alpha.is_infinite() {
385 return Err(VidyaError::InvalidAlpha { alpha });
386 }
387
388 let first = data
389 .iter()
390 .position(|&x| !x.is_nan())
391 .ok_or(VidyaError::AllValuesNaN)?;
392 if (data.len() - first) < long_period {
393 return Err(VidyaError::NotEnoughValidData {
394 needed: long_period,
395 valid: data.len() - first,
396 });
397 }
398
399 let chosen = match kern {
400 Kernel::Auto => match detect_best_kernel() {
401 Kernel::Avx512 => Kernel::Avx2,
402 other => other,
403 },
404 other => other,
405 };
406
407 let warmup_period = first + long_period - 2;
408
409 unsafe {
410 #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
411 {
412 if matches!(chosen, Kernel::Scalar | Kernel::ScalarBatch) {
413 vidya_simd128(data, short_period, long_period, alpha, first, dst);
414
415 for v in &mut dst[..warmup_period] {
416 *v = f64::NAN;
417 }
418 return Ok(());
419 }
420 }
421
422 match chosen {
423 Kernel::Scalar | Kernel::ScalarBatch => {
424 vidya_scalar(data, short_period, long_period, alpha, first, dst)
425 }
426 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
427 Kernel::Avx2 | Kernel::Avx2Batch => {
428 vidya_avx2(data, short_period, long_period, alpha, first, dst)
429 }
430 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
431 Kernel::Avx512 | Kernel::Avx512Batch => {
432 vidya_avx512(data, short_period, long_period, alpha, first, dst)
433 }
434 _ => vidya_scalar(data, short_period, long_period, alpha, first, dst),
435 }
436 }
437
438 for v in &mut dst[..warmup_period] {
439 *v = f64::NAN;
440 }
441
442 Ok(())
443}
444
445#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
446#[inline]
447pub fn vidya_into(input: &VidyaInput, out: &mut [f64]) -> Result<(), VidyaError> {
448 vidya_into_slice(out, input, Kernel::Auto)
449}
450
/// Scalar VIDYA kernel.
///
/// Maintains running sums and sums of squares over a short and a long trailing window,
/// and smooths the series with a factor of `alpha * short_std / long_std`, where a NaN
/// ratio is replaced by zero before scaling. The first value written is `data` itself at
/// index `first + long_period - 2`; every later index receives the smoothed value.
/// Entries of `out` before that index are not touched.
///
/// # Safety
/// The body performs only bounds-checked accesses; the function is marked `unsafe`
/// like the other kernels. Callers are expected to have validated the arguments:
/// `2 <= short_period <= long_period`, `first + long_period <= data.len()` and
/// `out.len() >= data.len()`. Violating these panics on an out-of-range index.
#[inline]
pub unsafe fn vidya_scalar(
    data: &[f64],
    short_period: usize,
    long_period: usize,
    alpha: f64,
    first: usize,
    out: &mut [f64],
) {
    /// Smoothing factor from the running sums: `alpha * sqrt(short_var) / sqrt(long_var)`,
    /// with a NaN ratio replaced by zero before scaling by `alpha`.
    #[inline(always)]
    fn smoothing_factor(
        s_sum: f64,
        s_sum2: f64,
        l_sum: f64,
        l_sum2: f64,
        s_inv: f64,
        l_inv: f64,
        alpha: f64,
    ) -> f64 {
        let s_mean = s_sum * s_inv;
        let l_mean = l_sum * l_inv;
        let s_var = s_sum2 * s_inv - (s_mean * s_mean);
        let l_var = l_sum2 * l_inv - (l_mean * l_mean);
        let ratio = s_var.sqrt() / l_var.sqrt();
        let base = if ratio.is_nan() { 0.0 } else { ratio };
        base * alpha
    }

    let n = data.len();

    let short_inv = 1.0 / (short_period as f64);
    let long_inv = 1.0 / (long_period as f64);

    let warm_end = first + long_period;
    let short_head = warm_end - short_period;

    let (mut l_sum, mut l_sum2) = (0.0_f64, 0.0_f64);
    let (mut s_sum, mut s_sum2) = (0.0_f64, 0.0_f64);

    // Samples that belong only to the long window.
    for x in (first..short_head).map(|i| data[i]) {
        l_sum += x;
        l_sum2 = x.mul_add(x, l_sum2);
    }
    // Samples shared by both windows.
    for x in (short_head..warm_end).map(|i| data[i]) {
        l_sum += x;
        l_sum2 = x.mul_add(x, l_sum2);
        s_sum += x;
        s_sum2 = x.mul_add(x, s_sum2);
    }

    // Seed the average with the raw sample two positions before the end of the warm-up.
    let seed_idx = warm_end - 2;
    let last_warm_idx = warm_end - 1;

    let mut val = data[seed_idx];
    out[seed_idx] = val;

    if last_warm_idx < n {
        let k = smoothing_factor(s_sum, s_sum2, l_sum, l_sum2, short_inv, long_inv, alpha);
        let x = data[last_warm_idx];
        val = (x - val).mul_add(k, val);
        out[last_warm_idx] = val;
    }

    let mut t = warm_end;
    while t < n {
        let incoming = data[t];
        let incoming_sq = incoming * incoming;
        let long_leaving = data[t - long_period];
        let short_leaving = data[t - short_period];

        // Slide the long window: add the new sample, drop the oldest.
        l_sum += incoming;
        l_sum2 += incoming_sq;
        l_sum -= long_leaving;
        l_sum2 = (-long_leaving).mul_add(long_leaving, l_sum2);

        // Slide the short window the same way.
        s_sum += incoming;
        s_sum2 += incoming_sq;
        s_sum -= short_leaving;
        s_sum2 = (-short_leaving).mul_add(short_leaving, s_sum2);

        let k = smoothing_factor(s_sum, s_sum2, l_sum, l_sum2, short_inv, long_inv, alpha);
        val = (incoming - val).mul_add(k, val);
        out[t] = val;

        t += 1;
    }
}
550
/// VIDYA kernel used on wasm32 builds with `simd128` enabled.
///
/// Despite the name, the computation is carried out with scalar arithmetic: running
/// sums and sums of squares over the short and long windows give the two standard
/// deviations, and the series is smoothed with `alpha * short_std / long_std` (a NaN
/// ratio is replaced by zero). The first value written is `data[first + long_period - 2]`;
/// earlier entries of `out` are not touched.
///
/// # Safety
/// Only bounds-checked accesses are performed. Callers are expected to pass validated
/// arguments: `2 <= short_period <= long_period`, `first + long_period <= data.len()`
/// and `out.len() >= data.len()`; otherwise an index panics.
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
#[inline]
unsafe fn vidya_simd128(
    data: &[f64],
    short_period: usize,
    long_period: usize,
    alpha: f64,
    first: usize,
    out: &mut [f64],
) {
    let len = data.len();
    let mut long_sum = 0.0;
    let mut long_sum2 = 0.0;
    let mut short_sum = 0.0;
    let mut short_sum2 = 0.0;

    // Warm-up: fill the long window; the last `short_period` samples also feed the short one.
    for i in first..(first + long_period) {
        long_sum += data[i];
        long_sum2 += data[i] * data[i];
        if i >= (first + long_period - short_period) {
            short_sum += data[i];
            short_sum2 += data[i] * data[i];
        }
    }

    // Seed the average with the raw sample two positions before the end of the warm-up.
    let mut val = data[first + long_period - 2];
    out[first + long_period - 2] = val;

    if first + long_period - 1 < data.len() {
        let sp = short_period as f64;
        let lp = long_period as f64;
        let short_div = 1.0 / sp;
        let long_div = 1.0 / lp;
        let short_stddev =
            (short_sum2 * short_div - (short_sum * short_div) * (short_sum * short_div)).sqrt();
        let long_stddev =
            (long_sum2 * long_div - (long_sum * long_div) * (long_sum * long_div)).sqrt();
        let mut k = short_stddev / long_stddev;
        if k.is_nan() {
            k = 0.0;
        }
        k *= alpha;
        val = (data[first + long_period - 1] - val) * k + val;
        out[first + long_period - 1] = val;
    }

    for i in (first + long_period)..len {
        // Slide both windows: add the new sample, drop the oldest of each.
        let new_val = data[i];
        long_sum += new_val;
        long_sum2 += new_val * new_val;
        short_sum += new_val;
        short_sum2 += new_val * new_val;

        let remove_long_idx = i - long_period;
        let remove_short_idx = i - short_period;
        let remove_long = data[remove_long_idx];
        let remove_short = data[remove_short_idx];

        long_sum -= remove_long;
        long_sum2 -= remove_long * remove_long;
        short_sum -= remove_short;
        short_sum2 -= remove_short * remove_short;

        let short_mean = short_sum / short_period as f64;
        let long_mean = long_sum / long_period as f64;

        let short_variance = short_sum2 / short_period as f64 - short_mean * short_mean;
        let long_variance = long_sum2 / long_period as f64 - long_mean * long_mean;

        let short_stddev = short_variance.sqrt();
        let long_stddev = long_variance.sqrt();

        let mut k = short_stddev / long_stddev;
        if k.is_nan() {
            k = 0.0;
        }
        k *= alpha;
        val = (new_val - val) * k + val;
        out[i] = val;
    }
}
640
/// AVX2 entry point; forwards unchanged to `vidya_avx2_experimental`.
///
/// # Safety
/// Same contract as `vidya_avx2_experimental`, which reads and writes through raw
/// pointers: the caller must pass validated periods and buffers that are long enough.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn vidya_avx2(
    data: &[f64],
    short_period: usize,
    long_period: usize,
    alpha: f64,
    first: usize,
    out: &mut [f64],
) {
    vidya_avx2_experimental(data, short_period, long_period, alpha, first, out)
}
653
/// Pointer-based VIDYA kernel behind [`vidya_avx2`].
///
/// Uses the same running-sum scheme as the scalar kernel, but accesses `data` and `out`
/// through raw pointers and updates every sum of squares with `mul_add`. The first
/// value written is `data[first + long_period - 2]`; earlier entries of `out` are not
/// touched.
///
/// # Safety
/// All accesses are unchecked. The caller must guarantee:
/// * `2 <= short_period <= long_period`,
/// * `first + long_period <= data.len()`,
/// * `out.len() >= data.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
unsafe fn vidya_avx2_experimental(
    data: &[f64],
    short_period: usize,
    long_period: usize,
    alpha: f64,
    first: usize,
    out: &mut [f64],
) {
    let len = data.len();
    let ptr = data.as_ptr();
    let out_ptr = out.as_mut_ptr();

    let mut long_sum = 0.0_f64;
    let mut long_sum2 = 0.0_f64;
    let mut short_sum = 0.0_f64;
    let mut short_sum2 = 0.0_f64;

    let sp_f = short_period as f64;
    let lp_f = long_period as f64;
    let short_inv = 1.0 / sp_f;
    let long_inv = 1.0 / lp_f;

    let warm_end = first + long_period;
    let short_head = warm_end - short_period;

    // SAFETY (for every `ptr.add` below): indices stay below `warm_end <= len` during
    // warm-up and below `len` in the main loop, per the caller contract.

    // Samples that belong only to the long window.
    let mut i = first;
    while i < short_head {
        let x = *ptr.add(i);
        long_sum += x;
        long_sum2 = x.mul_add(x, long_sum2);
        i += 1;
    }

    // Samples shared by both windows.
    while i < warm_end {
        let x = *ptr.add(i);
        long_sum += x;
        long_sum2 = x.mul_add(x, long_sum2);
        short_sum += x;
        short_sum2 = x.mul_add(x, short_sum2);
        i += 1;
    }

    let idx_m2 = warm_end - 2;
    let idx_m1 = warm_end - 1;

    // Seed the average with the raw sample two positions before the end of the warm-up.
    let mut val = *ptr.add(idx_m2);
    *out_ptr.add(idx_m2) = val;

    if idx_m1 < len {
        let short_mean = short_sum * short_inv;
        let long_mean = long_sum * long_inv;
        let short_var = short_sum2 * short_inv - (short_mean * short_mean);
        let long_var = long_sum2 * long_inv - (long_mean * long_mean);
        let short_std = short_var.sqrt();
        let long_std = long_var.sqrt();

        let mut k = short_std / long_std;
        if k.is_nan() {
            k = 0.0;
        }
        k *= alpha;

        let x = *ptr.add(idx_m1);

        val = (x - val).mul_add(k, val);
        *out_ptr.add(idx_m1) = val;
    }

    let mut t = warm_end;
    while t < len {
        let x_new = *ptr.add(t);

        long_sum += x_new;
        long_sum2 = x_new.mul_add(x_new, long_sum2);
        short_sum += x_new;
        short_sum2 = x_new.mul_add(x_new, short_sum2);

        // `t >= warm_end >= long_period >= short_period`, so neither subtraction underflows.
        let x_long_out = *ptr.add(t - long_period);
        let x_short_out = *ptr.add(t - short_period);
        long_sum -= x_long_out;
        long_sum2 = (-x_long_out).mul_add(x_long_out, long_sum2);
        short_sum -= x_short_out;
        short_sum2 = (-x_short_out).mul_add(x_short_out, short_sum2);

        let short_mean = short_sum * short_inv;
        let long_mean = long_sum * long_inv;
        let short_var = short_sum2 * short_inv - (short_mean * short_mean);
        let long_var = long_sum2 * long_inv - (long_mean * long_mean);
        let short_std = short_var.sqrt();
        let long_std = long_var.sqrt();

        let mut k = short_std / long_std;
        if k.is_nan() {
            k = 0.0;
        }
        k *= alpha;

        val = (x_new - val).mul_add(k, val);
        *out_ptr.add(t) = val;

        t += 1;
    }
}
760
/// AVX-512 entry point: routes to the short-window variant when `long_period <= 32`
/// and to the long-window variant otherwise.
///
/// # Safety
/// Same argument contract as the variants it calls: validated periods and buffers
/// that are long enough for `first + long_period` and `data.len()` respectively.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn vidya_avx512(
    data: &[f64],
    short_period: usize,
    long_period: usize,
    alpha: f64,
    first: usize,
    out: &mut [f64],
) {
    match long_period {
        0..=32 => vidya_avx512_short(data, short_period, long_period, alpha, first, out),
        _ => vidya_avx512_long(data, short_period, long_period, alpha, first, out),
    }
}
777
/// Short-window AVX-512 variant. Currently delegates to the scalar kernel.
///
/// # Safety
/// Same argument contract as `vidya_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn vidya_avx512_short(
    data: &[f64],
    short_period: usize,
    long_period: usize,
    alpha: f64,
    first: usize,
    out: &mut [f64],
) {
    vidya_scalar(data, short_period, long_period, alpha, first, out)
}
790
/// Long-window AVX-512 variant. Currently delegates to the scalar kernel.
///
/// # Safety
/// Same argument contract as `vidya_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn vidya_avx512_long(
    data: &[f64],
    short_period: usize,
    long_period: usize,
    alpha: f64,
    first: usize,
    out: &mut [f64],
) {
    vidya_scalar(data, short_period, long_period, alpha, first, out)
}
803
/// Incremental VIDYA calculator that consumes one sample per `update` call.
#[derive(Debug, Clone)]
pub struct VidyaStream {
    short_period: usize,
    long_period: usize,
    alpha: f64,
    // Ring buffers holding the most recent `long_period` / `short_period` samples.
    long_buf: Vec<f64>,
    short_buf: Vec<f64>,
    // Running sums and sums of squares over the two windows.
    long_sum: f64,
    long_sum2: f64,
    short_sum: f64,
    short_sum2: f64,
    // Next write position in `long_buf`.
    head: usize,
    // Number of samples consumed so far.
    idx: usize,
    // Most recent average (NaN until the first sample arrives).
    val: f64,
    // Initialised to `false` in `try_new`; `update` does not modify it.
    filled: bool,
}
820
821impl VidyaStream {
822 pub fn try_new(params: VidyaParams) -> Result<Self, VidyaError> {
823 let short_period = params.short_period.unwrap_or(2);
824 let long_period = params.long_period.unwrap_or(5);
825 let alpha = params.alpha.unwrap_or(0.2);
826
827 if short_period < 2 || long_period < short_period || long_period < 2 {
828 return Err(VidyaError::InvalidPeriod {
829 period: long_period,
830 data_len: 0,
831 });
832 }
833 if !(0.0..=1.0).contains(&alpha) || alpha.is_nan() || alpha.is_infinite() {
834 return Err(VidyaError::InvalidAlpha { alpha });
835 }
836 Ok(Self {
837 short_period,
838 long_period,
839 alpha,
840 long_buf: alloc_with_nan_prefix(long_period, long_period),
841 short_buf: alloc_with_nan_prefix(short_period, short_period),
842 long_sum: 0.0,
843 long_sum2: 0.0,
844 short_sum: 0.0,
845 short_sum2: 0.0,
846 head: 0,
847 idx: 0,
848 val: f64::NAN,
849 filled: false,
850 })
851 }
852
853 #[inline(always)]
854 pub fn update(&mut self, x: f64) -> Option<f64> {
855 let long_tail = self.long_buf[self.head];
856 let short_idx = self.idx % self.short_period;
857 let short_tail = self.short_buf[short_idx];
858
859 let phase2_start = self.long_period - self.short_period;
860
861 self.long_sum += x;
862 self.long_sum2 = x.mul_add(x, self.long_sum2);
863
864 if self.idx >= phase2_start {
865 self.short_sum += x;
866 self.short_sum2 = x.mul_add(x, self.short_sum2);
867 }
868
869 if self.idx >= self.long_period {
870 self.long_sum -= long_tail;
871 self.long_sum2 = (-long_tail).mul_add(long_tail, self.long_sum2);
872
873 self.short_sum -= short_tail;
874 self.short_sum2 = (-short_tail).mul_add(short_tail, self.short_sum2);
875 }
876
877 self.long_buf[self.head] = x;
878 self.short_buf[short_idx] = x;
879
880 let mut h = self.head + 1;
881 if h == self.long_period {
882 h = 0;
883 }
884 self.head = h;
885
886 self.idx += 1;
887
888 if self.idx < self.long_period - 1 {
889 self.val = x;
890 return None;
891 }
892 if self.idx == self.long_period - 1 {
893 self.val = x;
894 return Some(self.val);
895 }
896
897 let short_inv = 1.0 / (self.short_period as f64);
898 let long_inv = 1.0 / (self.long_period as f64);
899
900 let short_mean = self.short_sum * short_inv;
901 let long_mean = self.long_sum * long_inv;
902
903 let short_var = self.short_sum2 * short_inv - (short_mean * short_mean);
904 let long_var = self.long_sum2 * long_inv - (long_mean * long_mean);
905
906 let mut k = 0.0;
907 if long_var > 0.0 && short_var > 0.0 {
908 k = (short_var / long_var).sqrt() * self.alpha;
909 }
910
911 self.val = (x - self.val).mul_add(k, self.val);
912 Some(self.val)
913 }
914}
915
/// Parameter sweep for batch runs. Each field is a `(start, end, step)` triple that
/// `expand_grid` turns into an axis; a zero step (or `start == end`) yields just `start`.
#[derive(Clone, Debug)]
pub struct VidyaBatchRange {
    pub short_period: (usize, usize, usize),
    pub long_period: (usize, usize, usize),
    pub alpha: (f64, f64, f64),
}
922
923impl Default for VidyaBatchRange {
924 fn default() -> Self {
925 Self {
926 short_period: (2, 2, 0),
927 long_period: (5, 254, 1),
928 alpha: (0.2, 0.2, 0.0),
929 }
930 }
931}
932
/// Fluent builder for batch (parameter-sweep) VIDYA runs.
#[derive(Clone, Debug, Default)]
pub struct VidyaBatchBuilder {
    // Sweep definition handed to `vidya_batch_with_kernel`.
    range: VidyaBatchRange,
    // Kernel handed to `vidya_batch_with_kernel`; the derived default is `Kernel::default()`.
    kernel: Kernel,
}
938
939impl VidyaBatchBuilder {
940 pub fn new() -> Self {
941 Self::default()
942 }
943 pub fn kernel(mut self, k: Kernel) -> Self {
944 self.kernel = k;
945 self
946 }
947 #[inline]
948 pub fn short_period_range(mut self, start: usize, end: usize, step: usize) -> Self {
949 self.range.short_period = (start, end, step);
950 self
951 }
952 #[inline]
953 pub fn short_period_static(mut self, n: usize) -> Self {
954 self.range.short_period = (n, n, 0);
955 self
956 }
957 #[inline]
958 pub fn long_period_range(mut self, start: usize, end: usize, step: usize) -> Self {
959 self.range.long_period = (start, end, step);
960 self
961 }
962 #[inline]
963 pub fn long_period_static(mut self, n: usize) -> Self {
964 self.range.long_period = (n, n, 0);
965 self
966 }
967 #[inline]
968 pub fn alpha_range(mut self, start: f64, end: f64, step: f64) -> Self {
969 self.range.alpha = (start, end, step);
970 self
971 }
972 #[inline]
973 pub fn alpha_static(mut self, a: f64) -> Self {
974 self.range.alpha = (a, a, 0.0);
975 self
976 }
977 pub fn apply_slice(self, data: &[f64]) -> Result<VidyaBatchOutput, VidyaError> {
978 vidya_batch_with_kernel(data, &self.range, self.kernel)
979 }
980 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<VidyaBatchOutput, VidyaError> {
981 VidyaBatchBuilder::new().kernel(k).apply_slice(data)
982 }
983 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<VidyaBatchOutput, VidyaError> {
984 let slice = source_type(c, src);
985 self.apply_slice(slice)
986 }
987 pub fn with_default_candles(c: &Candles) -> Result<VidyaBatchOutput, VidyaError> {
988 VidyaBatchBuilder::new()
989 .kernel(Kernel::Auto)
990 .apply_candles(c, "close")
991 }
992}
993
994pub fn vidya_batch_with_kernel(
995 data: &[f64],
996 sweep: &VidyaBatchRange,
997 k: Kernel,
998) -> Result<VidyaBatchOutput, VidyaError> {
999 if data.is_empty() {
1000 return Err(VidyaError::EmptyInputData);
1001 }
1002 let kernel = match k {
1003 Kernel::Auto => match detect_best_batch_kernel() {
1004 Kernel::Avx512Batch => Kernel::Avx2Batch,
1005 other => other,
1006 },
1007 other if other.is_batch() => other,
1008 other => {
1009 return Err(VidyaError::InvalidKernelForBatch(other));
1010 }
1011 };
1012
1013 let simd = match kernel {
1014 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1015 Kernel::Avx512Batch => Kernel::Avx512,
1016 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1017 Kernel::Avx2Batch => Kernel::Avx2,
1018 Kernel::ScalarBatch => Kernel::Scalar,
1019 _ => Kernel::Scalar,
1020 };
1021 vidya_batch_par_slice(data, sweep, simd)
1022}
1023
/// Result of a batch run: a row-major matrix with one row per parameter combination.
#[derive(Clone, Debug)]
pub struct VidyaBatchOutput {
    /// Flattened matrix; `values_for` reads row `r` as `values[r * cols..(r + 1) * cols]`.
    pub values: Vec<f64>,
    /// Parameter combination for each row, in row order.
    pub combos: Vec<VidyaParams>,
    /// Number of rows in the matrix.
    pub rows: usize,
    /// Row length used to slice `values`.
    pub cols: usize,
}
1031impl VidyaBatchOutput {
1032 pub fn row_for_params(&self, p: &VidyaParams) -> Option<usize> {
1033 self.combos.iter().position(|c| {
1034 c.short_period.unwrap_or(2) == p.short_period.unwrap_or(2)
1035 && c.long_period.unwrap_or(5) == p.long_period.unwrap_or(5)
1036 && (c.alpha.unwrap_or(0.2) - p.alpha.unwrap_or(0.2)).abs() < 1e-12
1037 })
1038 }
1039 pub fn values_for(&self, p: &VidyaParams) -> Option<&[f64]> {
1040 self.row_for_params(p).map(|row| {
1041 let start = row * self.cols;
1042 &self.values[start..start + self.cols]
1043 })
1044 }
1045}
1046
1047#[inline(always)]
1048fn expand_grid(r: &VidyaBatchRange) -> Result<Vec<VidyaParams>, VidyaError> {
1049 fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, VidyaError> {
1050 if step == 0 || start == end {
1051 return Ok(vec![start]);
1052 }
1053 if start < end {
1054 let mut v = Vec::new();
1055 let mut x = start;
1056 let st = step.max(1);
1057 while x <= end {
1058 v.push(x);
1059 x = x.saturating_add(st);
1060 }
1061 if v.is_empty() {
1062 return Err(VidyaError::InvalidRange {
1063 start: start.to_string(),
1064 end: end.to_string(),
1065 step: step.to_string(),
1066 });
1067 }
1068 return Ok(v);
1069 }
1070 let mut v = Vec::new();
1071 let mut x = start as isize;
1072 let end_i = end as isize;
1073 let st = (step as isize).max(1);
1074 while x >= end_i {
1075 v.push(x as usize);
1076 x -= st;
1077 }
1078 if v.is_empty() {
1079 return Err(VidyaError::InvalidRange {
1080 start: start.to_string(),
1081 end: end.to_string(),
1082 step: step.to_string(),
1083 });
1084 }
1085 Ok(v)
1086 }
1087 fn axis_f64((start, end, step): (f64, f64, f64)) -> Result<Vec<f64>, VidyaError> {
1088 if step.abs() < 1e-12 || (start - end).abs() < 1e-12 {
1089 return Ok(vec![start]);
1090 }
1091 if start < end {
1092 let mut v = Vec::new();
1093 let mut x = start;
1094 let st = step.abs();
1095 while x <= end + 1e-12 {
1096 v.push(x);
1097 x += st;
1098 }
1099 if v.is_empty() {
1100 return Err(VidyaError::InvalidRange {
1101 start: start.to_string(),
1102 end: end.to_string(),
1103 step: step.to_string(),
1104 });
1105 }
1106 return Ok(v);
1107 }
1108 let mut v = Vec::new();
1109 let mut x = start;
1110 let st = step.abs();
1111 while x + 1e-12 >= end {
1112 v.push(x);
1113 x -= st;
1114 }
1115 if v.is_empty() {
1116 return Err(VidyaError::InvalidRange {
1117 start: start.to_string(),
1118 end: end.to_string(),
1119 step: step.to_string(),
1120 });
1121 }
1122 Ok(v)
1123 }
1124
1125 let short_periods = axis_usize(r.short_period)?;
1126 let long_periods = axis_usize(r.long_period)?;
1127 let alphas = axis_f64(r.alpha)?;
1128
1129 let cap = short_periods
1130 .len()
1131 .checked_mul(long_periods.len())
1132 .and_then(|x| x.checked_mul(alphas.len()))
1133 .ok_or_else(|| VidyaError::InvalidRange {
1134 start: "cap".into(),
1135 end: "overflow".into(),
1136 step: "mul".into(),
1137 })?;
1138
1139 let mut out = Vec::with_capacity(cap);
1140 for &sp in &short_periods {
1141 for &lp in &long_periods {
1142 for &a in &alphas {
1143 out.push(VidyaParams {
1144 short_period: Some(sp),
1145 long_period: Some(lp),
1146 alpha: Some(a),
1147 });
1148 }
1149 }
1150 }
1151 Ok(out)
1152}
1153
1154#[inline(always)]
1155pub fn vidya_batch_slice(
1156 data: &[f64],
1157 sweep: &VidyaBatchRange,
1158 kern: Kernel,
1159) -> Result<VidyaBatchOutput, VidyaError> {
1160 vidya_batch_inner(data, sweep, kern, false)
1161}
1162
1163#[inline(always)]
1164pub fn vidya_batch_par_slice(
1165 data: &[f64],
1166 sweep: &VidyaBatchRange,
1167 kern: Kernel,
1168) -> Result<VidyaBatchOutput, VidyaError> {
1169 vidya_batch_inner(data, sweep, kern, true)
1170}
1171
1172#[inline(always)]
1173fn vidya_batch_inner(
1174 data: &[f64],
1175 sweep: &VidyaBatchRange,
1176 kern: Kernel,
1177 parallel: bool,
1178) -> Result<VidyaBatchOutput, VidyaError> {
1179 let combos = expand_grid(sweep)?;
1180 if data.is_empty() {
1181 return Err(VidyaError::EmptyInputData);
1182 }
1183 let first = data
1184 .iter()
1185 .position(|x| !x.is_nan())
1186 .ok_or(VidyaError::AllValuesNaN)?;
1187 let max_long = combos.iter().map(|c| c.long_period.unwrap()).max().unwrap();
1188 if data.len() - first < max_long {
1189 return Err(VidyaError::NotEnoughValidData {
1190 needed: max_long,
1191 valid: data.len() - first,
1192 });
1193 }
1194
1195 let rows = combos.len();
1196 let cols = data.len();
1197
1198 let _ = rows
1199 .checked_mul(cols)
1200 .ok_or_else(|| VidyaError::InvalidRange {
1201 start: rows.to_string(),
1202 end: cols.to_string(),
1203 step: "rows*cols".into(),
1204 })?;
1205
1206 let warmup_periods: Vec<usize> = combos
1207 .iter()
1208 .map(|c| first + c.long_period.unwrap() - 2)
1209 .collect();
1210
1211 let mut buf_mu = make_uninit_matrix(rows, cols);
1212 init_matrix_prefixes(&mut buf_mu, cols, &warmup_periods);
1213
1214 let mut buf_guard = core::mem::ManuallyDrop::new(buf_mu);
1215 let out: &mut [f64] = unsafe {
1216 core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
1217 };
1218 let mut values = out;
1219
1220 let do_row = |row: usize, out_row: &mut [f64]| unsafe {
1221 let p = &combos[row];
1222 let sp = p.short_period.unwrap();
1223 let lp = p.long_period.unwrap();
1224 let a = p.alpha.unwrap();
1225 match kern {
1226 Kernel::Scalar | Kernel::ScalarBatch => {
1227 vidya_row_scalar(data, first, sp, lp, a, out_row)
1228 }
1229 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1230 Kernel::Avx2 | Kernel::Avx2Batch => vidya_row_avx2(data, first, sp, lp, a, out_row),
1231 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1232 Kernel::Avx512 | Kernel::Avx512Batch => {
1233 vidya_row_avx512(data, first, sp, lp, a, out_row)
1234 }
1235 _ => vidya_row_scalar(data, first, sp, lp, a, out_row),
1236 }
1237 };
1238 if parallel {
1239 #[cfg(not(target_arch = "wasm32"))]
1240 {
1241 values
1242 .par_chunks_mut(cols)
1243 .enumerate()
1244 .for_each(|(row, slice)| do_row(row, slice));
1245 }
1246
1247 #[cfg(target_arch = "wasm32")]
1248 {
1249 for (row, slice) in values.chunks_mut(cols).enumerate() {
1250 do_row(row, slice);
1251 }
1252 }
1253 } else {
1254 for (row, slice) in values.chunks_mut(cols).enumerate() {
1255 do_row(row, slice);
1256 }
1257 }
1258
1259 let values = unsafe {
1260 Vec::from_raw_parts(
1261 buf_guard.as_mut_ptr() as *mut f64,
1262 buf_guard.len(),
1263 buf_guard.capacity(),
1264 )
1265 };
1266
1267 Ok(VidyaBatchOutput {
1268 values,
1269 combos,
1270 rows,
1271 cols,
1272 })
1273}
1274
1275#[inline(always)]
1276unsafe fn vidya_row_scalar(
1277 data: &[f64],
1278 first: usize,
1279 short_period: usize,
1280 long_period: usize,
1281 alpha: f64,
1282 out: &mut [f64],
1283) {
1284 vidya_scalar(data, short_period, long_period, alpha, first, out);
1285}
1286
/// AVX2 row kernel for the batch drivers (only built with `nightly-avx` on x86_64).
///
/// Adapts the batch argument order to `vidya_avx2_experimental`, which writes the
/// result for one parameter combination into `out`.
///
/// # Safety
/// The body only forwards to `vidya_avx2_experimental`; its requirements apply
/// (presumably AVX2 support on the running CPU -- confirm at the callee).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn vidya_row_avx2(
    data: &[f64],
    first: usize,
    short_period: usize,
    long_period: usize,
    alpha: f64,
    out: &mut [f64],
) {
    vidya_avx2_experimental(
        data,
        short_period,
        long_period,
        alpha,
        first,
        out,
    );
}
1299
/// AVX-512 row kernel for the batch drivers (only built with `nightly-avx` on x86_64).
///
/// Chooses between the "short" and "long" variants by `long_period`: values up to and
/// including 32 go to `vidya_row_avx512_short`, larger ones to `vidya_row_avx512_long`.
///
/// # Safety
/// Forwards to other `unsafe` row kernels; their requirements apply.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn vidya_row_avx512(
    data: &[f64],
    first: usize,
    short_period: usize,
    long_period: usize,
    alpha: f64,
    out: &mut [f64],
) {
    match long_period {
        0..=32 => vidya_row_avx512_short(data, first, short_period, long_period, alpha, out),
        _ => vidya_row_avx512_long(data, first, short_period, long_period, alpha, out),
    }
}
1316
/// "Short period" branch of the AVX-512 row kernel.
///
/// Currently identical to the scalar path: it forwards straight to `vidya_scalar`.
///
/// # Safety
/// The body only forwards to `vidya_scalar`; its requirements apply.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn vidya_row_avx512_short(
    data: &[f64],
    first: usize,
    short_period: usize,
    long_period: usize,
    alpha: f64,
    out: &mut [f64],
) {
    vidya_scalar(
        data,
        short_period,
        long_period,
        alpha,
        first,
        out,
    );
}
1329
/// "Long period" branch of the AVX-512 row kernel.
///
/// Currently identical to the scalar path: it forwards straight to `vidya_scalar`.
///
/// # Safety
/// The body only forwards to `vidya_scalar`; its requirements apply.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn vidya_row_avx512_long(
    data: &[f64],
    first: usize,
    short_period: usize,
    long_period: usize,
    alpha: f64,
    out: &mut [f64],
) {
    vidya_scalar(
        data,
        short_period,
        long_period,
        alpha,
        first,
        out,
    );
}
1342
1343#[cfg(test)]
1344mod tests {
1345 use super::*;
1346 use crate::skip_if_unsupported;
1347 use crate::utilities::data_loader::read_candles_from_csv;
1348
1349 fn check_vidya_partial_params(
1350 test_name: &str,
1351 kernel: Kernel,
1352 ) -> Result<(), Box<dyn std::error::Error>> {
1353 skip_if_unsupported!(kernel, test_name);
1354 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1355 let candles = read_candles_from_csv(file_path)?;
1356 let default_params = VidyaParams {
1357 short_period: None,
1358 long_period: Some(10),
1359 alpha: None,
1360 };
1361 let input_default = VidyaInput::from_candles(&candles, "close", default_params);
1362 let output_default = vidya_with_kernel(&input_default, kernel)?;
1363 assert_eq!(output_default.values.len(), candles.close.len());
1364 Ok(())
1365 }
1366
1367 fn check_vidya_accuracy(
1368 test_name: &str,
1369 kernel: Kernel,
1370 ) -> Result<(), Box<dyn std::error::Error>> {
1371 skip_if_unsupported!(kernel, test_name);
1372 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1373 let candles = read_candles_from_csv(file_path)?;
1374 let close_prices = &candles.close;
1375
1376 let params = VidyaParams {
1377 short_period: Some(2),
1378 long_period: Some(5),
1379 alpha: Some(0.2),
1380 };
1381 let input = VidyaInput::from_candles(&candles, "close", params);
1382 let vidya_result = vidya_with_kernel(&input, kernel)?;
1383 assert_eq!(vidya_result.values.len(), close_prices.len());
1384
1385 if vidya_result.values.len() >= 5 {
1386 let expected_last_five = [
1387 59553.42785306692,
1388 59503.60445032524,
1389 59451.72283651444,
1390 59413.222561244685,
1391 59239.716526894175,
1392 ];
1393 let start_index = vidya_result.values.len() - 5;
1394 let result_last_five = &vidya_result.values[start_index..];
1395 for (i, &value) in result_last_five.iter().enumerate() {
1396 let expected_value = expected_last_five[i];
1397 assert!(
1398 (value - expected_value).abs() < 1e-1,
1399 "VIDYA mismatch at index {}: expected {}, got {}",
1400 i,
1401 expected_value,
1402 value
1403 );
1404 }
1405 }
1406 Ok(())
1407 }
1408
1409 fn check_vidya_default_candles(
1410 test_name: &str,
1411 kernel: Kernel,
1412 ) -> Result<(), Box<dyn std::error::Error>> {
1413 skip_if_unsupported!(kernel, test_name);
1414 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1415 let candles = read_candles_from_csv(file_path)?;
1416 let input = VidyaInput::with_default_candles(&candles);
1417 match input.data {
1418 VidyaData::Candles { source, .. } => assert_eq!(source, "close"),
1419 _ => panic!("Expected VidyaData::Candles"),
1420 }
1421 let output = vidya_with_kernel(&input, kernel)?;
1422 assert_eq!(output.values.len(), candles.close.len());
1423 Ok(())
1424 }
1425
1426 fn check_vidya_invalid_params(
1427 test_name: &str,
1428 kernel: Kernel,
1429 ) -> Result<(), Box<dyn std::error::Error>> {
1430 skip_if_unsupported!(kernel, test_name);
1431 let data = [10.0, 20.0, 30.0];
1432 let params = VidyaParams {
1433 short_period: Some(0),
1434 long_period: Some(5),
1435 alpha: Some(0.2),
1436 };
1437 let input = VidyaInput::from_slice(&data, params);
1438 let result = vidya_with_kernel(&input, kernel);
1439 assert!(result.is_err(), "Expected error for invalid short period");
1440 Ok(())
1441 }
1442
1443 fn check_vidya_exceeding_data_length(
1444 test_name: &str,
1445 kernel: Kernel,
1446 ) -> Result<(), Box<dyn std::error::Error>> {
1447 skip_if_unsupported!(kernel, test_name);
1448 let data = [10.0, 20.0, 30.0];
1449 let params = VidyaParams {
1450 short_period: Some(2),
1451 long_period: Some(5),
1452 alpha: Some(0.2),
1453 };
1454 let input = VidyaInput::from_slice(&data, params);
1455 let result = vidya_with_kernel(&input, kernel);
1456 assert!(result.is_err(), "Expected error for period > data.len()");
1457 Ok(())
1458 }
1459
1460 fn check_vidya_very_small_data_set(
1461 test_name: &str,
1462 kernel: Kernel,
1463 ) -> Result<(), Box<dyn std::error::Error>> {
1464 skip_if_unsupported!(kernel, test_name);
1465 let data = [42.0, 43.0];
1466 let params = VidyaParams {
1467 short_period: Some(2),
1468 long_period: Some(5),
1469 alpha: Some(0.2),
1470 };
1471 let input = VidyaInput::from_slice(&data, params);
1472 let result = vidya_with_kernel(&input, kernel);
1473 assert!(
1474 result.is_err(),
1475 "Expected error for data smaller than long period"
1476 );
1477 Ok(())
1478 }
1479
1480 fn check_vidya_reinput(
1481 test_name: &str,
1482 kernel: Kernel,
1483 ) -> Result<(), Box<dyn std::error::Error>> {
1484 skip_if_unsupported!(kernel, test_name);
1485 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1486 let candles = read_candles_from_csv(file_path)?;
1487
1488 let first_params = VidyaParams {
1489 short_period: Some(2),
1490 long_period: Some(5),
1491 alpha: Some(0.2),
1492 };
1493 let first_input = VidyaInput::from_candles(&candles, "close", first_params);
1494 let first_result = vidya_with_kernel(&first_input, kernel)?;
1495
1496 let second_params = VidyaParams {
1497 short_period: Some(2),
1498 long_period: Some(5),
1499 alpha: Some(0.2),
1500 };
1501 let second_input = VidyaInput::from_slice(&first_result.values, second_params);
1502 let second_result = vidya_with_kernel(&second_input, kernel)?;
1503 assert_eq!(second_result.values.len(), first_result.values.len());
1504 Ok(())
1505 }
1506
1507 fn check_vidya_nan_handling(
1508 test_name: &str,
1509 kernel: Kernel,
1510 ) -> Result<(), Box<dyn std::error::Error>> {
1511 skip_if_unsupported!(kernel, test_name);
1512 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1513 let candles = read_candles_from_csv(file_path)?;
1514 let params = VidyaParams {
1515 short_period: Some(2),
1516 long_period: Some(5),
1517 alpha: Some(0.2),
1518 };
1519 let input = VidyaInput::from_candles(&candles, "close", params);
1520 let vidya_result = vidya_with_kernel(&input, kernel)?;
1521 if vidya_result.values.len() > 10 {
1522 for i in 10..vidya_result.values.len() {
1523 assert!(!vidya_result.values[i].is_nan());
1524 }
1525 }
1526 Ok(())
1527 }
1528
1529 fn check_vidya_streaming(
1530 test_name: &str,
1531 kernel: Kernel,
1532 ) -> Result<(), Box<dyn std::error::Error>> {
1533 skip_if_unsupported!(kernel, test_name);
1534 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1535 let candles = read_candles_from_csv(file_path)?;
1536
1537 let short_period = 2;
1538 let long_period = 5;
1539 let alpha = 0.2;
1540
1541 let params = VidyaParams {
1542 short_period: Some(short_period),
1543 long_period: Some(long_period),
1544 alpha: Some(alpha),
1545 };
1546 let input = VidyaInput::from_candles(&candles, "close", params.clone());
1547 let batch_output = vidya_with_kernel(&input, kernel)?.values;
1548
1549 let mut stream = VidyaStream::try_new(params.clone())?;
1550 let mut stream_values = Vec::with_capacity(candles.close.len());
1551 for &price in &candles.close {
1552 match stream.update(price) {
1553 Some(val) => stream_values.push(val),
1554 None => stream_values.push(f64::NAN),
1555 }
1556 }
1557 assert_eq!(batch_output.len(), stream_values.len());
1558 for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
1559 if b.is_nan() && s.is_nan() {
1560 continue;
1561 }
1562 let diff = (b - s).abs();
1563 assert!(
1564 diff < 1e-3,
1565 "[{}] VIDYA streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
1566 test_name,
1567 i,
1568 b,
1569 s,
1570 diff
1571 );
1572 }
1573 Ok(())
1574 }
1575
1576 #[cfg(debug_assertions)]
1577 fn check_vidya_no_poison(
1578 test_name: &str,
1579 kernel: Kernel,
1580 ) -> Result<(), Box<dyn std::error::Error>> {
1581 skip_if_unsupported!(kernel, test_name);
1582
1583 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1584 let candles = read_candles_from_csv(file_path)?;
1585
1586 let test_params = vec![
1587 VidyaParams::default(),
1588 VidyaParams {
1589 short_period: Some(1),
1590 long_period: Some(2),
1591 alpha: Some(0.1),
1592 },
1593 VidyaParams {
1594 short_period: Some(2),
1595 long_period: Some(3),
1596 alpha: Some(0.2),
1597 },
1598 VidyaParams {
1599 short_period: Some(2),
1600 long_period: Some(5),
1601 alpha: Some(0.5),
1602 },
1603 VidyaParams {
1604 short_period: Some(3),
1605 long_period: Some(7),
1606 alpha: Some(0.3),
1607 },
1608 VidyaParams {
1609 short_period: Some(4),
1610 long_period: Some(10),
1611 alpha: Some(0.2),
1612 },
1613 VidyaParams {
1614 short_period: Some(5),
1615 long_period: Some(20),
1616 alpha: Some(0.4),
1617 },
1618 VidyaParams {
1619 short_period: Some(10),
1620 long_period: Some(30),
1621 alpha: Some(0.2),
1622 },
1623 VidyaParams {
1624 short_period: Some(15),
1625 long_period: Some(50),
1626 alpha: Some(0.3),
1627 },
1628 VidyaParams {
1629 short_period: Some(20),
1630 long_period: Some(100),
1631 alpha: Some(0.2),
1632 },
1633 VidyaParams {
1634 short_period: Some(50),
1635 long_period: Some(200),
1636 alpha: Some(0.1),
1637 },
1638 VidyaParams {
1639 short_period: Some(2),
1640 long_period: Some(10),
1641 alpha: Some(0.8),
1642 },
1643 VidyaParams {
1644 short_period: Some(3),
1645 long_period: Some(15),
1646 alpha: Some(1.0),
1647 },
1648 VidyaParams {
1649 short_period: Some(1),
1650 long_period: Some(100),
1651 alpha: Some(0.01),
1652 },
1653 VidyaParams {
1654 short_period: Some(99),
1655 long_period: Some(100),
1656 alpha: Some(0.99),
1657 },
1658 ];
1659
1660 for (param_idx, params) in test_params.iter().enumerate() {
1661 let input = VidyaInput::from_candles(&candles, "close", params.clone());
1662 let output = vidya_with_kernel(&input, kernel)?;
1663
1664 for (i, &val) in output.values.iter().enumerate() {
1665 if val.is_nan() {
1666 continue;
1667 }
1668
1669 let bits = val.to_bits();
1670
1671 if bits == 0x11111111_11111111 {
1672 panic!(
1673 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1674 with params: {:?} (param set {})",
1675 test_name, val, bits, i, params, param_idx
1676 );
1677 }
1678
1679 if bits == 0x22222222_22222222 {
1680 panic!(
1681 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1682 with params: {:?} (param set {})",
1683 test_name, val, bits, i, params, param_idx
1684 );
1685 }
1686
1687 if bits == 0x33333333_33333333 {
1688 panic!(
1689 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1690 with params: {:?} (param set {})",
1691 test_name, val, bits, i, params, param_idx
1692 );
1693 }
1694 }
1695 }
1696
1697 Ok(())
1698 }
1699
1700 #[cfg(not(debug_assertions))]
1701 fn check_vidya_no_poison(
1702 _test_name: &str,
1703 _kernel: Kernel,
1704 ) -> Result<(), Box<dyn std::error::Error>> {
1705 Ok(())
1706 }
1707
    /// Property-based test (only with the `proptest` feature).
    ///
    /// Generates random price paths and parameter sets, runs `vidya_with_kernel` with
    /// the kernel under test and with `Kernel::Scalar`, and checks:
    /// kernel/scalar agreement, NaN warm-up prefix, a loose output range, stability for
    /// very small `alpha`, constant-in/constant-out, a minimum rate of following the
    /// price direction, and absence of the poison bit patterns.
    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_vidya_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        // Strategy: short_period in 2..=20, long_period strictly greater than it and
        // capped at min(100, long_min + 50).
        let strat = (2usize..=20).prop_flat_map(|short_period| {
            let long_min = (short_period + 1).max(2);
            let long_max = 100.min(long_min + 50);

            (long_min..=long_max).prop_flat_map(move |long_period| {
                // At least `long_period` (and at least 10) samples, fewer than 400.
                let data_len = long_period.max(10)..400;

                (
                    prop::collection::vec(
                        (-0.05f64..0.05f64).prop_filter("finite", |x| x.is_finite()),
                        data_len,
                    )
                    .prop_map(|returns| {
                        // Turn the random returns into a price path starting from 100.
                        let mut prices = Vec::with_capacity(returns.len());
                        let mut price = 100.0;

                        for (i, ret) in returns.iter().enumerate() {
                            // Alternate every 20 samples between damped (0.5x) and
                            // amplified (2x) returns.
                            let volatility_factor = if (i / 20) % 2 == 0 { 0.5 } else { 2.0 };
                            price *= 1.0 + (ret * volatility_factor);
                            prices.push(price);
                        }
                        prices
                    }),
                    Just(short_period),
                    Just(long_period),
                    0.01f64..1.0f64,
                )
            })
        });

        proptest::test_runner::TestRunner::default()
            .run(&strat, |(data, short_period, long_period, alpha)| {
                let params = VidyaParams {
                    short_period: Some(short_period),
                    long_period: Some(long_period),
                    alpha: Some(alpha),
                };
                let input = VidyaInput::from_slice(&data, params.clone());

                // Output of the kernel under test and of the scalar reference.
                let VidyaOutput { values: out } = vidya_with_kernel(&input, kernel).unwrap();
                let VidyaOutput { values: ref_out } = vidya_with_kernel(&input, Kernel::Scalar).unwrap();

                // Property 1: the kernel must agree with the scalar reference.
                for i in 0..data.len() {
                    let y = out[i];
                    let r = ref_out[i];

                    // Non-finite values must match bit for bit.
                    if !y.is_finite() || !r.is_finite() {
                        prop_assert!(y.to_bits() == r.to_bits(),
                            "[{}] finite/NaN mismatch at idx {}: {} vs {}", test_name, i, y, r);
                        continue;
                    }

                    let ulp_diff = y.to_bits().abs_diff(r.to_bits());
                    prop_assert!(
                        // Accept either a small absolute difference or a few ULPs.
                        (y - r).abs() <= 5e-8 || ulp_diff <= 4,
                        "[{}] kernel mismatch at idx {}: {} vs {} (ULP={})",
                        test_name, i, y, r, ulp_diff
                    );
                }

                // Property 2: the first `first + long_period - 2` outputs are NaN
                // (clamped to 0 when that expression would underflow).
                let first = data.iter().position(|&x| !x.is_nan()).unwrap_or(0);
                let first_valid_idx = if first + long_period >= 2 {
                    first + long_period - 2
                } else {
                    0
                };

                for i in 0..first_valid_idx.min(data.len()) {
                    prop_assert!(out[i].is_nan(),
                        "[{}] Expected NaN during warmup at idx {}, got {}", test_name, i, out[i]);
                }

                // ...and the value right after the warm-up prefix is not NaN.
                if first_valid_idx < data.len() {
                    prop_assert!(!out[first_valid_idx].is_nan(),
                        "[{}] Expected valid value at first_valid_idx {}, got NaN", test_name, first_valid_idx);
                }

                // Same index as `first_valid_idx`, computed without the underflow guard.
                let warmup_end = first + long_period - 2;

                // Property 3: outputs stay within the data range widened by a margin
                // that grows with `alpha`.
                if data.len() > warmup_end + 1 {
                    let data_min = data.iter().cloned().fold(f64::INFINITY, f64::min);
                    let data_max = data.iter().cloned().fold(f64::NEG_INFINITY, f64::max);

                    let range = data_max - data_min;

                    let alpha_factor = 1.0 + alpha * 2.0;
                    let margin = if range < 1.0 {
                        // Narrow range: base the margin on the average magnitude.
                        let avg_magnitude = (data_max.abs() + data_min.abs()) / 2.0;
                        avg_magnitude * 0.3 * alpha_factor
                    } else {
                        // Otherwise: half the range, scaled by the alpha factor.
                        range * 0.5 * alpha_factor
                    };

                    for i in (warmup_end + 1)..data.len() {
                        let y = out[i];
                        if y.is_finite() {
                            prop_assert!(
                                y >= data_min - margin && y <= data_max + margin,
                                "[{}] Output {} at idx {} outside reasonable range [{}, {}] (alpha={:.3})",
                                test_name, y, i, data_min - margin, data_max + margin, alpha
                            );
                        }
                    }
                }

                // Property 4: for alpha below 0.05, consecutive outputs change by less
                // than 10% relative to the previous value.
                if alpha < 0.05 && data.len() > warmup_end + 10 {

                    let vidya_section = &out[(warmup_end + 1)..];
                    if vidya_section.len() > 2 {

                        for window in vidya_section.windows(2) {
                            let change_ratio = (window[1] - window[0]).abs() / window[0].abs().max(1e-10);
                            prop_assert!(
                                change_ratio < 0.1,
                                "[{}] With alpha={}, VIDYA should be stable but found large change ratio {}",
                                test_name, alpha, change_ratio
                            );
                        }
                    }
                }

                // Property 5: (numerically) constant input yields that constant as output.
                if data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10) && data.len() > warmup_end + 1 {

                    let constant_value = data[0];
                    for i in (warmup_end + 1)..data.len() {
                        prop_assert!(
                            (out[i] - constant_value).abs() <= 1e-9,
                            "[{}] Constant input should produce constant output, got {} expected {}",
                            test_name, out[i], constant_value
                        );
                    }
                }

                // Property 6: when the output moves, it moves in the same direction as
                // the price in at least 40% of the counted steps.
                if alpha >= 0.05 && data.len() > warmup_end + 20 {

                    let mut same_direction_count = 0;
                    let mut total_movements = 0;
                    let mut frozen_periods = 0;

                    for i in (warmup_end + 1)..data.len() {
                        let price_change = data[i] - data[i - 1];
                        let vidya_change = out[i] - out[i - 1];

                        // Price moved but the output did not: counted as "frozen".
                        if price_change.abs() > 1e-6 && vidya_change.abs() <= 1e-10 {
                            frozen_periods += 1;
                        }
                        // Both moved: counted as a movement, with or against the price.
                        else if price_change.abs() > 1e-6 && vidya_change.abs() > 1e-10 {
                            total_movements += 1;
                            if price_change.signum() == vidya_change.signum() {
                                same_direction_count += 1;
                            }
                        }
                    }

                    // Only judge runs with enough movements that were frozen for less
                    // than half of the post-warm-up span.
                    if total_movements > 10 && frozen_periods < (data.len() - warmup_end) / 2 {
                        let direction_ratio = same_direction_count as f64 / total_movements as f64;
                        prop_assert!(
                            direction_ratio >= 0.40,
                            "[{}] VIDYA should generally follow price direction when moving, but only followed {:.1}% of the time (frozen for {} periods)",
                            test_name, direction_ratio * 100.0, frozen_periods
                        );
                    }
                }

                // Property 7: no finite output equals one of the poison bit patterns.
                for (i, &val) in out.iter().enumerate() {
                    if val.is_finite() {
                        let bits = val.to_bits();
                        prop_assert!(
                            bits != 0x11111111_11111111 &&
                            bits != 0x22222222_22222222 &&
                            bits != 0x33333333_33333333,
                            "[{}] Found poison value {} (0x{:016X}) at index {}",
                            test_name, val, bits, i
                        );
                    }
                }

                Ok(())
            })
            .unwrap();

        Ok(())
    }
1925
1926 macro_rules! generate_all_vidya_tests {
1927 ($($test_fn:ident),*) => {
1928 paste! {
1929 $(
1930 #[test]
1931 fn [<$test_fn _scalar_f64>]() {
1932 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1933 }
1934 )*
1935 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1936 $(
1937 #[test]
1938 fn [<$test_fn _avx2_f64>]() {
1939 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1940 }
1941 #[test]
1942 fn [<$test_fn _avx512_f64>]() {
1943 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1944 }
1945 )*
1946 #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
1947 $(
1948 #[test]
1949 fn [<$test_fn _simd128_f64>]() {
1950 let _ = $test_fn(stringify!([<$test_fn _simd128_f64>]), Kernel::Scalar);
1951 }
1952 )*
1953 }
1954 }
1955 }
1956
1957 generate_all_vidya_tests!(
1958 check_vidya_partial_params,
1959 check_vidya_accuracy,
1960 check_vidya_default_candles,
1961 check_vidya_invalid_params,
1962 check_vidya_exceeding_data_length,
1963 check_vidya_very_small_data_set,
1964 check_vidya_reinput,
1965 check_vidya_nan_handling,
1966 check_vidya_streaming,
1967 check_vidya_no_poison
1968 );
1969
1970 #[cfg(feature = "proptest")]
1971 generate_all_vidya_tests!(check_vidya_property);
1972
1973 fn check_batch_default_row(
1974 test: &str,
1975 kernel: Kernel,
1976 ) -> Result<(), Box<dyn std::error::Error>> {
1977 skip_if_unsupported!(kernel, test);
1978
1979 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1980 let c = read_candles_from_csv(file)?;
1981
1982 let output = VidyaBatchBuilder::new()
1983 .kernel(kernel)
1984 .apply_candles(&c, "close")?;
1985
1986 let def = VidyaParams::default();
1987 let row = output.values_for(&def).expect("default row missing");
1988
1989 assert_eq!(row.len(), c.close.len());
1990
1991 let expected = [
1992 59553.42785306692,
1993 59503.60445032524,
1994 59451.72283651444,
1995 59413.222561244685,
1996 59239.716526894175,
1997 ];
1998 let start = row.len() - 5;
1999 for (i, &v) in row[start..].iter().enumerate() {
2000 assert!(
2001 (v - expected[i]).abs() < 1e-1,
2002 "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
2003 );
2004 }
2005 Ok(())
2006 }
2007
2008 #[cfg(debug_assertions)]
2009 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn std::error::Error>> {
2010 skip_if_unsupported!(kernel, test);
2011
2012 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2013 let c = read_candles_from_csv(file)?;
2014
2015 let test_configs = vec![
2016 (2, 5, 1, 5, 10, 1, 0.2, 0.2, 0.0),
2017 (2, 10, 2, 10, 30, 5, 0.1, 0.5, 0.1),
2018 (5, 20, 5, 30, 60, 10, 0.2, 0.2, 0.0),
2019 (10, 30, 10, 50, 100, 25, 0.3, 0.3, 0.0),
2020 (2, 2, 0, 5, 50, 5, 0.1, 0.9, 0.2),
2021 (5, 15, 5, 20, 20, 0, 0.2, 0.8, 0.3),
2022 (1, 3, 1, 4, 8, 2, 0.5, 0.5, 0.0),
2023 (20, 50, 15, 100, 200, 50, 0.1, 0.3, 0.1),
2024 (2, 2, 0, 3, 3, 0, 0.1, 1.0, 0.1),
2025 ];
2026
2027 for (cfg_idx, &(s_start, s_end, s_step, l_start, l_end, l_step, a_start, a_end, a_step)) in
2028 test_configs.iter().enumerate()
2029 {
2030 let output = VidyaBatchBuilder::new()
2031 .kernel(kernel)
2032 .short_period_range(s_start, s_end, s_step)
2033 .long_period_range(l_start, l_end, l_step)
2034 .alpha_range(a_start, a_end, a_step)
2035 .apply_candles(&c, "close")?;
2036
2037 for (idx, &val) in output.values.iter().enumerate() {
2038 if val.is_nan() {
2039 continue;
2040 }
2041
2042 let bits = val.to_bits();
2043 let row = idx / output.cols;
2044 let col = idx % output.cols;
2045 let combo = &output.combos[row];
2046
2047 if bits == 0x11111111_11111111 {
2048 panic!(
2049 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
2050 at row {} col {} (flat index {}) with params: {:?}",
2051 test, cfg_idx, val, bits, row, col, idx, combo
2052 );
2053 }
2054
2055 if bits == 0x22222222_22222222 {
2056 panic!(
2057 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
2058 at row {} col {} (flat index {}) with params: {:?}",
2059 test, cfg_idx, val, bits, row, col, idx, combo
2060 );
2061 }
2062
2063 if bits == 0x33333333_33333333 {
2064 panic!(
2065 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
2066 at row {} col {} (flat index {}) with params: {:?}",
2067 test, cfg_idx, val, bits, row, col, idx, combo
2068 );
2069 }
2070 }
2071 }
2072
2073 Ok(())
2074 }
2075
2076 #[cfg(not(debug_assertions))]
2077 fn check_batch_no_poison(
2078 _test: &str,
2079 _kernel: Kernel,
2080 ) -> Result<(), Box<dyn std::error::Error>> {
2081 Ok(())
2082 }
2083
2084 macro_rules! gen_batch_tests {
2085 ($fn_name:ident) => {
2086 paste! {
2087 #[test] fn [<$fn_name _scalar>]() {
2088 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
2089 }
2090 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2091 #[test] fn [<$fn_name _avx2>]() {
2092 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
2093 }
2094 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2095 #[test] fn [<$fn_name _avx512>]() {
2096 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
2097 }
2098 #[test] fn [<$fn_name _auto_detect>]() {
2099 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
2100 }
2101 }
2102 };
2103 }
2104 gen_batch_tests!(check_batch_default_row);
2105 gen_batch_tests!(check_batch_no_poison);
2106}
2107
2108#[inline(always)]
2109pub fn vidya_batch_inner_into(
2110 data: &[f64],
2111 sweep: &VidyaBatchRange,
2112 kern: Kernel,
2113 parallel: bool,
2114 out: &mut [f64],
2115) -> Result<Vec<VidyaParams>, VidyaError> {
2116 let combos = expand_grid(sweep)?;
2117 if data.is_empty() {
2118 return Err(VidyaError::EmptyInputData);
2119 }
2120
2121 let first = data
2122 .iter()
2123 .position(|x| !x.is_nan())
2124 .ok_or(VidyaError::AllValuesNaN)?;
2125 let max_long = combos.iter().map(|c| c.long_period.unwrap()).max().unwrap();
2126 if data.len() - first < max_long {
2127 return Err(VidyaError::NotEnoughValidData {
2128 needed: max_long,
2129 valid: data.len() - first,
2130 });
2131 }
2132
2133 let rows = combos.len();
2134 let cols = data.len();
2135
2136 let expected = rows
2137 .checked_mul(cols)
2138 .ok_or_else(|| VidyaError::InvalidRange {
2139 start: rows.to_string(),
2140 end: cols.to_string(),
2141 step: "rows*cols".into(),
2142 })?;
2143 if out.len() < expected {
2144 return Err(VidyaError::OutputLengthMismatch {
2145 expected,
2146 got: out.len(),
2147 });
2148 }
2149
2150 let warmup_periods: Vec<usize> = combos
2151 .iter()
2152 .map(|c| first + c.long_period.unwrap() - 2)
2153 .collect();
2154
2155 for (row, &warmup) in warmup_periods.iter().enumerate() {
2156 let row_start = row * cols;
2157 out[row_start..row_start + warmup].fill(f64::NAN);
2158 }
2159
2160 let do_row = |row: usize, out_row: &mut [f64]| unsafe {
2161 let p = &combos[row];
2162 let sp = p.short_period.unwrap();
2163 let lp = p.long_period.unwrap();
2164 let a = p.alpha.unwrap();
2165 match kern {
2166 Kernel::Scalar | Kernel::ScalarBatch => {
2167 vidya_row_scalar(data, first, sp, lp, a, out_row)
2168 }
2169 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2170 Kernel::Avx2 | Kernel::Avx2Batch => vidya_row_avx2(data, first, sp, lp, a, out_row),
2171 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2172 Kernel::Avx512 | Kernel::Avx512Batch => {
2173 vidya_row_avx512(data, first, sp, lp, a, out_row)
2174 }
2175 _ => vidya_row_scalar(data, first, sp, lp, a, out_row),
2176 }
2177 };
2178
2179 if parallel {
2180 #[cfg(not(target_arch = "wasm32"))]
2181 {
2182 out.par_chunks_mut(cols)
2183 .enumerate()
2184 .for_each(|(row, slice)| do_row(row, slice));
2185 }
2186 #[cfg(target_arch = "wasm32")]
2187 {
2188 for (row, slice) in out.chunks_mut(cols).enumerate() {
2189 do_row(row, slice);
2190 }
2191 }
2192 } else {
2193 for (row, slice) in out.chunks_mut(cols).enumerate() {
2194 do_row(row, slice);
2195 }
2196 }
2197
2198 Ok(combos)
2199}
2200
/// Python binding `vidya(data, short_period, long_period, alpha, kernel=None)`.
///
/// Reads `data` as a contiguous slice, validates the optional kernel name through
/// `validate_kernel`, runs `vidya_with_kernel` inside `allow_threads`, and returns the
/// values as a new 1-D NumPy array. Computation errors surface as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "vidya")]
#[pyo3(signature = (data, short_period, long_period, alpha, kernel=None))]
pub fn vidya_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    short_period: usize,
    long_period: usize,
    alpha: f64,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, false)?;

    let input = VidyaInput::from_slice(
        slice_in,
        VidyaParams {
            short_period: Some(short_period),
            long_period: Some(long_period),
            alpha: Some(alpha),
        },
    );

    let computed = py.allow_threads(|| vidya_with_kernel(&input, kern).map(|o| o.values));
    match computed {
        Ok(result_vec) => Ok(result_vec.into_pyarray(py)),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
2228
/// Python-facing wrapper (`VidyaStream`) around the Rust `VidyaStream`.
#[cfg(feature = "python")]
#[pyclass(name = "VidyaStream")]
pub struct VidyaStreamPy {
    // The wrapped streaming state.
    stream: VidyaStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl VidyaStreamPy {
    /// Builds a stream from explicit parameters; a construction error reported by
    /// `VidyaStream::try_new` is raised as `ValueError`.
    #[new]
    fn new(short_period: usize, long_period: usize, alpha: f64) -> PyResult<Self> {
        let params = VidyaParams {
            short_period: Some(short_period),
            long_period: Some(long_period),
            alpha: Some(alpha),
        };
        match VidyaStream::try_new(params) {
            Ok(stream) => Ok(Self { stream }),
            Err(e) => Err(PyValueError::new_err(e.to_string())),
        }
    }

    /// Feeds one value into the stream and returns whatever the wrapped
    /// `VidyaStream::update` returns (`None` maps to Python `None`).
    fn update(&mut self, value: f64) -> Option<f64> {
        let stream = &mut self.stream;
        stream.update(value)
    }
}
2254
/// Python binding `vidya_batch(data, short_period_range, long_period_range,
/// alpha_range, kernel=None)`.
///
/// Each range is a `(start, end, step)` tuple. The sweep is expanded once here to size
/// the output, a flat NumPy array of `rows * cols` is allocated, and
/// `vidya_batch_inner_into` fills it (with `parallel = true`) inside `allow_threads`.
///
/// Returns a dict with:
/// * `"values"` - the output reshaped to `(rows, cols)`,
/// * `"short_periods"`, `"long_periods"` - per-row periods as `u64` arrays,
/// * `"alphas"` - per-row alpha values.
///
/// Sweep expansion, size overflow and computation errors surface as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "vidya_batch")]
#[pyo3(signature = (data, short_period_range, long_period_range, alpha_range, kernel=None))]
pub fn vidya_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    short_period_range: (usize, usize, usize),
    long_period_range: (usize, usize, usize),
    alpha_range: (f64, f64, f64),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let slice_in = data.as_slice()?;

    let sweep = VidyaBatchRange {
        short_period: short_period_range,
        long_period: long_period_range,
        alpha: alpha_range,
    };

    // Expanded up front only to learn the matrix shape; the inner call expands again.
    let rows = match expand_grid(&sweep) {
        Ok(grid) => grid.len(),
        Err(e) => return Err(PyValueError::new_err(e.to_string())),
    };
    let cols = slice_in.len();

    let total = match rows.checked_mul(cols) {
        Some(n) => n,
        None => return Err(PyValueError::new_err("rows*cols overflow")),
    };
    // The array is created without initialisation; it is handed to
    // `vidya_batch_inner_into` below to be filled.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let kern = validate_kernel(kernel, true)?;

    let computed = py.allow_threads(|| {
        // `Auto` resolves to the detected batch kernel, except that a detected
        // AVX-512 batch kernel is replaced by the AVX2 one.
        let batch_kernel = if matches!(kern, Kernel::Auto) {
            let detected = detect_best_batch_kernel();
            if matches!(detected, Kernel::Avx512Batch) {
                Kernel::Avx2Batch
            } else {
                detected
            }
        } else {
            kern
        };
        // Map the batch selector to the matching per-row selector.
        let simd = match batch_kernel {
            Kernel::Avx512Batch => Kernel::Avx512,
            Kernel::Avx2Batch => Kernel::Avx2,
            Kernel::ScalarBatch => Kernel::Scalar,
            other => other,
        };
        vidya_batch_inner_into(slice_in, &sweep, simd, true, slice_out)
    });
    let combos = match computed {
        Ok(c) => c,
        Err(e) => return Err(PyValueError::new_err(e.to_string())),
    };

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;

    let short_periods: Vec<u64> = combos
        .iter()
        .map(|p| p.short_period.unwrap() as u64)
        .collect();
    dict.set_item("short_periods", short_periods.into_pyarray(py))?;

    let long_periods: Vec<u64> = combos
        .iter()
        .map(|p| p.long_period.unwrap() as u64)
        .collect();
    dict.set_item("long_periods", long_periods.into_pyarray(py))?;

    let alphas: Vec<f64> = combos.iter().map(|p| p.alpha.unwrap()).collect();
    dict.set_item("alphas", alphas.into_pyarray(py))?;

    Ok(dict)
}
2334
/// Python-visible handle (`ta_indicators.cuda.VidyaDeviceArrayF32`) around a
/// 2-D `f32` result that lives in CUDA device memory.
///
/// Declared `unsendable`, so PyO3 restricts access to the thread that created
/// the object.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(
    module = "ta_indicators.cuda",
    name = "VidyaDeviceArrayF32",
    unsendable
)]
pub struct VidyaDeviceArrayF32Py {
    /// Device buffer together with its `rows` / `cols` shape.
    pub(crate) inner: DeviceArrayF32,
    /// Shared handle to the CUDA context obtained from the `CudaVidya` that
    /// produced `inner`.
    // NOTE(review): presumably held only to keep the context alive for as
    // long as the buffer exists — confirm.
    pub(crate) _ctx: std::sync::Arc<Context>,
    /// CUDA device ordinal reported through `__dlpack_device__`.
    pub(crate) device_id: u32,
}
2346
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl VidyaDeviceArrayF32Py {
    /// CUDA Array Interface (version 3) description of the device buffer:
    /// a little-endian `f32` array of shape `(rows, cols)` with byte strides
    /// `(cols * 4, 4)`.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let elem_bytes = std::mem::size_of::<f32>();
        let row_bytes = self.inner.cols * elem_bytes;

        let iface = PyDict::new(py);
        iface.set_item("shape", (self.inner.rows, self.inner.cols))?;
        iface.set_item("typestr", "<f4")?;
        iface.set_item("strides", (row_bytes, elem_bytes))?;
        // (device pointer, read-only flag)
        iface.set_item("data", (self.inner.device_ptr() as usize, false))?;
        iface.set_item("version", 3)?;
        Ok(iface)
    }

    /// DLPack device tuple `(device_type, device_id)` for this array.
    fn __dlpack_device__(&self) -> PyResult<(i32, i32)> {
        // DLPack device-type code for CUDA (`kDLCUDA`).
        const KDL_CUDA: i32 = 2;
        Ok((KDL_CUDA, self.device_id as i32))
    }

    /// Hands the device buffer over to `export_f32_cuda_dlpack_2d`.
    ///
    /// The buffer is moved out of `self`; afterwards this object holds an
    /// empty `0 x 0` placeholder. `_stream` is accepted but ignored. If
    /// `_dl_device` parses as an `(i32, i32)` tuple that differs from
    /// `__dlpack_device__()`, a `ValueError` is raised (the message depends on
    /// whether `_copy` is `True`).
    // NOTE(review): the Python-visible keyword names come from this signature
    // (`_stream`, `_dl_device`, `_copy`), while the DLPack protocol names them
    // `stream`, `dl_device`, `copy` — confirm consumers call this positionally.
    #[pyo3(signature=(_stream=None, max_version=None, _dl_device=None, _copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        _stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        _dl_device: Option<pyo3::PyObject>,
        _copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        let (own_type, own_id) = self.__dlpack_device__()?;

        // A `_dl_device` that is absent or not an (int, int) tuple is ignored.
        let requested_device = _dl_device
            .as_ref()
            .and_then(|dev| dev.extract::<(i32, i32)>(py).ok());
        if let Some((want_type, want_id)) = requested_device {
            if want_type != own_type || want_id != own_id {
                let copy_requested = matches!(
                    _copy.as_ref().map(|flag| flag.extract::<bool>(py)),
                    Some(Ok(true))
                );
                let reason = if copy_requested {
                    "device copy not implemented for __dlpack__"
                } else {
                    "dl_device mismatch for __dlpack__"
                };
                return Err(PyValueError::new_err(reason));
            }
        }

        // Swap an empty buffer into `self` so the real one can be moved out.
        let placeholder =
            DeviceBuffer::from_slice(&[]).map_err(|err| PyValueError::new_err(err.to_string()))?;
        let DeviceArrayF32 { buf, rows, cols } = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32 {
                buf: placeholder,
                rows: 0,
                cols: 0,
            },
        );

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, own_id, max_version_bound)
    }
}
2421
/// Python entry point `vidya_cuda_batch_dev`: runs the VIDYA parameter sweep
/// on a CUDA device and returns the result as a device-resident array handle.
///
/// Each range is a `(start, end, step)` tuple. Raises `ValueError` when CUDA
/// is unavailable, when `CudaVidya::new(device_id)` fails, or when the batch
/// launch reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "vidya_cuda_batch_dev")]
#[pyo3(signature = (data, short_period_range, long_period_range, alpha_range, device_id=0))]
pub fn vidya_cuda_batch_dev_py(
    py: Python<'_>,
    data: PyReadonlyArray1<'_, f32>,
    short_period_range: (usize, usize, usize),
    long_period_range: (usize, usize, usize),
    alpha_range: (f64, f64, f64),
    device_id: usize,
) -> PyResult<VidyaDeviceArrayF32Py> {
    if !crate::cuda::cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let host_input = data.as_slice()?;
    let range = VidyaBatchRange {
        short_period: short_period_range,
        long_period: long_period_range,
        alpha: alpha_range,
    };

    // Device work runs with the GIL released.
    let launched = py.allow_threads(|| -> PyResult<_> {
        let engine =
            CudaVidya::new(device_id).map_err(|err| PyValueError::new_err(err.to_string()))?;
        let device_out = engine
            .vidya_batch_dev(host_input, &range)
            .map_err(|err| PyValueError::new_err(err.to_string()))?;
        Ok((device_out, engine.context_arc_clone(), engine.device_id()))
    });
    let (inner, _ctx, device_id) = launched?;

    Ok(VidyaDeviceArrayF32Py {
        inner,
        _ctx,
        device_id,
    })
}
2456
/// Python entry point `vidya_cuda_many_series_one_param_dev`: applies one
/// VIDYA parameter set to many series on a CUDA device.
///
/// `data_tm` is a flat buffer that must contain exactly `cols * rows`
/// elements; it is forwarded to
/// `CudaVidya::vidya_many_series_one_param_time_major_dev` together with
/// `cols` and `rows`.
///
/// Raises `ValueError` when CUDA is unavailable, when `cols * rows` overflows,
/// when the buffer length does not match, or when the CUDA wrapper reports an
/// error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "vidya_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm, cols, rows, short_period, long_period, alpha, device_id=0))]
pub fn vidya_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm: PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    short_period: usize,
    long_period: usize,
    alpha: f64,
    device_id: usize,
) -> PyResult<VidyaDeviceArrayF32Py> {
    use crate::cuda::cuda_available;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let slice = data_tm.as_slice()?;
    let expected = cols
        .checked_mul(rows)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;
    if slice.len() != expected {
        return Err(PyValueError::new_err("time-major input length mismatch"));
    }
    let params = VidyaParams {
        short_period: Some(short_period),
        long_period: Some(long_period),
        alpha: Some(alpha),
    };
    // Device work runs with the GIL released.
    let (inner, ctx_arc, dev_id) = py.allow_threads(|| {
        let cuda = CudaVidya::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        // Fixed: the argument previously read `¶ms`, a mis-encoded `&params`.
        let arr = cuda
            .vidya_many_series_one_param_time_major_dev(slice, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, PyErr>((arr, cuda.context_arc_clone(), cuda.device_id()))
    })?;
    Ok(VidyaDeviceArrayF32Py {
        inner,
        _ctx: ctx_arc,
        device_id: dev_id,
    })
}
2499
/// WebAssembly binding: computes VIDYA over `data` and returns a freshly
/// allocated vector with one output per input sample.
///
/// Errors from `vidya_into_slice` are converted to a JS string value.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn vidya_js(
    data: &[f64],
    short_period: usize,
    long_period: usize,
    alpha: f64,
) -> Result<Vec<f64>, JsValue> {
    let input = VidyaInput::from_slice(
        data,
        VidyaParams {
            short_period: Some(short_period),
            long_period: Some(long_period),
            alpha: Some(alpha),
        },
    );

    let mut values = vec![0.0; data.len()];
    match vidya_into_slice(&mut values, &input, Kernel::Auto) {
        Ok(_) => Ok(values),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
2521
/// WebAssembly binding: computes VIDYA for `len` samples read from `in_ptr`
/// and writes `len` results to `out_ptr`.
///
/// Both pointers must be non-null (otherwise an error is returned) and must
/// each be valid for `len` `f64` values. When `in_ptr` and `out_ptr` are the
/// same address the result is computed into a temporary buffer first and then
/// copied over the input.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn vidya_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    short_period: usize,
    long_period: usize,
    alpha: f64,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    unsafe {
        let source = std::slice::from_raw_parts(in_ptr, len);
        let input = VidyaInput::from_slice(
            source,
            VidyaParams {
                short_period: Some(short_period),
                long_period: Some(long_period),
                alpha: Some(alpha),
            },
        );

        let in_place = in_ptr == out_ptr as *const f64;
        if in_place {
            // Compute into scratch space so the input is not overwritten
            // while it is still being read.
            let mut scratch = vec![0.0; len];
            vidya_into_slice(&mut scratch, &input, Kernel::Auto)
                .map_err(|err| JsValue::from_str(&err.to_string()))?;
            std::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&scratch);
        } else {
            let destination = std::slice::from_raw_parts_mut(out_ptr, len);
            vidya_into_slice(destination, &input, Kernel::Auto)
                .map_err(|err| JsValue::from_str(&err.to_string()))?;
        }
    }

    Ok(())
}
2559
/// WebAssembly binding: reserves space for `len` `f64` values and returns a
/// raw pointer to it.
///
/// Ownership of the allocation passes to the caller; release it with
/// `vidya_free(ptr, len)` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn vidya_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` keeps the vector's destructor from running, so the
    // allocation stays alive after this function returns.
    let mut storage = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    storage.as_mut_ptr()
}
2568
/// WebAssembly binding: releases a buffer previously handed out by
/// `vidya_alloc`.
///
/// A null `ptr` is ignored. Otherwise `ptr` and `len` must be exactly the
/// pointer and length of one earlier `vidya_alloc(len)` call that has not been
/// freed yet.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn vidya_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: per the contract above, `ptr` originates from a `Vec<f64>` with
    // capacity `len`; rebuilding the vector and dropping it frees the memory.
    drop(unsafe { Vec::from_raw_parts(ptr, len, len) });
}
2578
/// Sweep configuration deserialized from the JS `config` object passed to
/// `vidya_batch`. Each range is copied verbatim into the matching
/// `VidyaBatchRange` field.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct VidyaBatchConfig {
    /// Short-period range; `(start, end, step)` ordering as used by
    /// `vidya_batch_into`.
    pub short_period_range: (usize, usize, usize),
    /// Long-period range, same tuple layout.
    pub long_period_range: (usize, usize, usize),
    /// Alpha range, same tuple layout.
    pub alpha_range: (f64, f64, f64),
}
2586
/// Result object serialized back to JS by `vidya_batch`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct VidyaBatchJsOutput {
    /// Flat result buffer holding `rows * cols` values.
    // NOTE(review): presumably row-major, i.e. row `r` occupies
    // `values[r * cols..(r + 1) * cols]` — confirm against
    // `vidya_batch_inner_into`.
    pub values: Vec<f64>,
    /// Parameter set of each row, as returned by the batch routine.
    pub combos: Vec<VidyaParams>,
    /// Number of parameter combinations (`combos.len()`).
    pub rows: usize,
    /// Number of input samples (`data.len()`).
    pub cols: usize,
}
2595
/// WebAssembly binding exported as `vidya_batch`: evaluates VIDYA over `data`
/// for every parameter combination described by `config` and returns a
/// serialized `VidyaBatchJsOutput`.
///
/// # Errors
/// Returns a JS string value when `config` cannot be deserialized, when the
/// parameter grid cannot be expanded, when `rows * cols` overflows `usize`,
/// when the batch computation fails, or when the result cannot be serialized.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = vidya_batch)]
pub fn vidya_batch_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let config: VidyaBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = VidyaBatchRange {
        short_period: config.short_period_range,
        long_period: config.long_period_range,
        alpha: config.alpha_range,
    };

    // Size the output from the actual grid instead of a hard-coded row count:
    // the previous `data.len() * 232` was wrong for any grid that is not
    // exactly 232 rows, too small above that, and could overflow unchecked.
    let cols = data.len();
    let expected_rows = expand_grid(&sweep)
        .map_err(|e| JsValue::from_str(&e.to_string()))?
        .len();
    let total = expected_rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;
    let mut output = vec![0.0; total];

    // A detected AVX-512 kernel is replaced by AVX2; anything else is used
    // as detected.
    let kernel = match detect_best_kernel() {
        Kernel::Avx512 => Kernel::Avx2,
        other => other,
    };
    let combos = vidya_batch_inner_into(data, &sweep, kernel, false, &mut output)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let rows = combos.len();
    let result = VidyaBatchJsOutput {
        values: output,
        combos,
        rows,
        cols,
    };

    serde_wasm_bindgen::to_value(&result)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2630
/// WebAssembly binding: evaluates the VIDYA parameter sweep over `len`
/// samples at `in_ptr` and writes `rows * len` results to `out_ptr`, where
/// `rows` is the number of parameter combinations. Returns `rows`.
///
/// `in_ptr` must be valid for `len` reads and `out_ptr` for `rows * len`
/// writes. The two regions may overlap: in that case the input is copied to a
/// temporary buffer before any output is written.
///
/// # Errors
/// Returns a JS string value when either pointer is null, when the parameter
/// grid cannot be expanded, when `rows * len` overflows `usize`, or when the
/// batch computation fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn vidya_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    short_period_start: usize,
    short_period_end: usize,
    short_period_step: usize,
    long_period_start: usize,
    long_period_end: usize,
    long_period_step: usize,
    alpha_start: f64,
    alpha_end: f64,
    alpha_step: f64,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    let sweep = VidyaBatchRange {
        short_period: (short_period_start, short_period_end, short_period_step),
        long_period: (long_period_start, long_period_end, long_period_step),
        alpha: (alpha_start, alpha_end, alpha_step),
    };

    let rows = expand_grid(&sweep)
        .map_err(|e| JsValue::from_str(&e.to_string()))?
        .len();
    let total_size = rows
        .checked_mul(len)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;

    // Byte ranges of both regions, used only to detect overlap.
    let elem_bytes = std::mem::size_of::<f64>();
    let in_start = in_ptr as usize;
    let in_end = in_start.saturating_add(len.saturating_mul(elem_bytes));
    let out_start = out_ptr as usize;
    let out_end = out_start.saturating_add(total_size.saturating_mul(elem_bytes));
    let overlaps = in_start < out_end && out_start < in_end;

    // A detected AVX-512 kernel is replaced by AVX2; anything else is used
    // as detected.
    let kernel = match detect_best_kernel() {
        Kernel::Avx512 => Kernel::Avx2,
        other => other,
    };

    // When the regions overlap, snapshot the input first so that no shared
    // slice aliases the mutable output slice and no input value is
    // overwritten before it has been read.
    let staged: Vec<f64>;
    let data: &[f64] = if overlaps {
        // SAFETY: `in_ptr` is non-null and the caller guarantees it is valid
        // for `len` reads; the borrow ends once the copy is made.
        staged = unsafe { std::slice::from_raw_parts(in_ptr, len) }.to_vec();
        &staged
    } else {
        // SAFETY: as above; the region does not overlap the output region.
        unsafe { std::slice::from_raw_parts(in_ptr, len) }
    };

    // SAFETY: `out_ptr` is non-null and the caller guarantees it is valid for
    // `total_size` writes; `data` does not alias it (see above).
    let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total_size) };

    vidya_batch_inner_into(data, &sweep, kernel, false, out)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    Ok(rows)
}