1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::PyDict;
9#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
10use serde::{Deserialize, Serialize};
11#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
12use wasm_bindgen::prelude::*;
13
14use crate::utilities::data_loader::{source_type, Candles};
15use crate::utilities::enums::Kernel;
16use crate::utilities::helpers::{
17 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
18 make_uninit_matrix,
19};
20#[cfg(feature = "python")]
21use crate::utilities::kernel_validation::validate_kernel;
22#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
23use core::arch::x86_64::*;
24#[cfg(not(target_arch = "wasm32"))]
25use rayon::prelude::*;
26use std::convert::AsRef;
27use std::error::Error;
28use thiserror::Error;
29
30#[derive(Debug, Clone)]
31pub enum VarData<'a> {
32 Candles {
33 candles: &'a Candles,
34 source: &'a str,
35 },
36 Slice(&'a [f64]),
37}
38
39impl<'a> AsRef<[f64]> for VarInput<'a> {
40 #[inline(always)]
41 fn as_ref(&self) -> &[f64] {
42 match &self.data {
43 VarData::Slice(slice) => slice,
44 VarData::Candles { candles, source } => source_type(candles, source),
45 }
46 }
47}
48
/// Result of a single-series variance computation: one value per input sample.
///
/// Entries before the first complete window are NaN. The first complete window
/// ends at index `first + period - 1`, where `first` is the first non-NaN input.
#[derive(Clone, Debug)]
pub struct VarOutput {
    pub values: Vec<f64>,
}

/// Tunable parameters of the variance indicator.
///
/// * `period` – window length; `None` is treated as 14.
/// * `nbdev` – multiplier whose square scales the variance; `None` is treated as 1.0.
#[derive(Clone, Debug)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct VarParams {
    pub period: Option<usize>,
    pub nbdev: Option<f64>,
}

impl Default for VarParams {
    /// Returns `period = Some(14)` and `nbdev = Some(1.0)`.
    fn default() -> Self {
        let (period, nbdev) = (14usize, 1.0f64);
        VarParams {
            period: Some(period),
            nbdev: Some(nbdev),
        }
    }
}
72
73#[derive(Debug, Clone)]
74pub struct VarInput<'a> {
75 pub data: VarData<'a>,
76 pub params: VarParams,
77}
78
79impl<'a> VarInput<'a> {
80 #[inline]
81 pub fn from_candles(c: &'a Candles, s: &'a str, p: VarParams) -> Self {
82 Self {
83 data: VarData::Candles {
84 candles: c,
85 source: s,
86 },
87 params: p,
88 }
89 }
90 #[inline]
91 pub fn from_slice(sl: &'a [f64], p: VarParams) -> Self {
92 Self {
93 data: VarData::Slice(sl),
94 params: p,
95 }
96 }
97 #[inline]
98 pub fn with_default_candles(c: &'a Candles) -> Self {
99 Self::from_candles(c, "close", VarParams::default())
100 }
101 #[inline]
102 pub fn get_period(&self) -> usize {
103 self.params.period.unwrap_or(14)
104 }
105 #[inline]
106 pub fn get_nbdev(&self) -> f64 {
107 self.params.nbdev.unwrap_or(1.0)
108 }
109}
110
111#[derive(Copy, Clone, Debug)]
112pub struct VarBuilder {
113 period: Option<usize>,
114 nbdev: Option<f64>,
115 kernel: Kernel,
116}
117
118impl Default for VarBuilder {
119 fn default() -> Self {
120 Self {
121 period: None,
122 nbdev: None,
123 kernel: Kernel::Auto,
124 }
125 }
126}
127
128impl VarBuilder {
129 #[inline(always)]
130 pub fn new() -> Self {
131 Self::default()
132 }
133 #[inline(always)]
134 pub fn period(mut self, n: usize) -> Self {
135 self.period = Some(n);
136 self
137 }
138 #[inline(always)]
139 pub fn nbdev(mut self, x: f64) -> Self {
140 self.nbdev = Some(x);
141 self
142 }
143 #[inline(always)]
144 pub fn kernel(mut self, k: Kernel) -> Self {
145 self.kernel = k;
146 self
147 }
148 #[inline(always)]
149 pub fn apply(self, c: &Candles) -> Result<VarOutput, VarError> {
150 let p = VarParams {
151 period: self.period,
152 nbdev: self.nbdev,
153 };
154 let i = VarInput::from_candles(c, "close", p);
155 var_with_kernel(&i, self.kernel)
156 }
157 #[inline(always)]
158 pub fn apply_slice(self, d: &[f64]) -> Result<VarOutput, VarError> {
159 let p = VarParams {
160 period: self.period,
161 nbdev: self.nbdev,
162 };
163 let i = VarInput::from_slice(d, p);
164 var_with_kernel(&i, self.kernel)
165 }
166 #[inline(always)]
167 pub fn into_stream(self) -> Result<VarStream, VarError> {
168 let p = VarParams {
169 period: self.period,
170 nbdev: self.nbdev,
171 };
172 VarStream::try_new(p)
173 }
174}
175
176#[derive(Debug, Error)]
177pub enum VarError {
178 #[error("var: input data is empty (All values are NaN).")]
179 EmptyInputData,
180 #[error("var: All values are NaN.")]
181 AllValuesNaN,
182 #[error("var: Invalid period: period = {period}, data length = {data_len}")]
183 InvalidPeriod { period: usize, data_len: usize },
184 #[error("var: Not enough valid data: needed = {needed}, valid = {valid}")]
185 NotEnoughValidData { needed: usize, valid: usize },
186 #[error("var: output length mismatch: expected = {expected}, got = {got}")]
187 OutputLengthMismatch { expected: usize, got: usize },
188 #[error("var: invalid range: start = {start}, end = {end}, step = {step}")]
189 InvalidRange { start: f64, end: f64, step: f64 },
190 #[error("var: invalid kernel for batch: {0:?}")]
191 InvalidKernelForBatch(Kernel),
192 #[error("var: nbdev is NaN or infinite: {nbdev}")]
193 InvalidNbdev { nbdev: f64 },
194}
195
196#[inline]
197pub fn var(input: &VarInput) -> Result<VarOutput, VarError> {
198 var_with_kernel(input, Kernel::Auto)
199}
200
201pub fn var_with_kernel(input: &VarInput, kernel: Kernel) -> Result<VarOutput, VarError> {
202 let data: &[f64] = input.as_ref();
203 let len = data.len();
204 if len == 0 {
205 return Err(VarError::EmptyInputData);
206 }
207 let first = data
208 .iter()
209 .position(|x| !x.is_nan())
210 .ok_or(VarError::AllValuesNaN)?;
211 let period = input.get_period();
212 let nbdev = input.get_nbdev();
213
214 if period == 0 || period > len {
215 return Err(VarError::InvalidPeriod {
216 period,
217 data_len: len,
218 });
219 }
220 if (len - first) < period {
221 return Err(VarError::NotEnoughValidData {
222 needed: period,
223 valid: len - first,
224 });
225 }
226 if nbdev.is_nan() || nbdev.is_infinite() {
227 return Err(VarError::InvalidNbdev { nbdev });
228 }
229
230 let chosen = match kernel {
231 Kernel::Auto => match detect_best_kernel() {
232 Kernel::Avx512 => Kernel::Avx2,
233 other => other,
234 },
235 other => other,
236 };
237
238 let mut out = alloc_with_nan_prefix(len, first + period - 1);
239
240 unsafe {
241 match chosen {
242 Kernel::Scalar | Kernel::ScalarBatch => {
243 var_scalar(data, period, first, nbdev, &mut out)?
244 }
245 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
246 Kernel::Avx2 | Kernel::Avx2Batch => var_avx2(data, period, first, nbdev, &mut out)?,
247 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
248 Kernel::Avx512 | Kernel::Avx512Batch => {
249 var_avx512(data, period, first, nbdev, &mut out)?
250 }
251 _ => unreachable!(),
252 }
253 }
254
255 Ok(VarOutput { values: out })
256}
257
258#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
259pub fn var_into(input: &VarInput, out: &mut [f64]) -> Result<(), VarError> {
260 let data: &[f64] = input.as_ref();
261 let len = data.len();
262 if len == 0 {
263 return Err(VarError::EmptyInputData);
264 }
265 if out.len() != len {
266 return Err(VarError::OutputLengthMismatch {
267 expected: len,
268 got: out.len(),
269 });
270 }
271 let first = data
272 .iter()
273 .position(|x| !x.is_nan())
274 .ok_or(VarError::AllValuesNaN)?;
275 let period = input.get_period();
276 let nbdev = input.get_nbdev();
277
278 if period == 0 || period > len {
279 return Err(VarError::InvalidPeriod {
280 period,
281 data_len: len,
282 });
283 }
284 if (len - first) < period {
285 return Err(VarError::NotEnoughValidData {
286 needed: period,
287 valid: len - first,
288 });
289 }
290 if nbdev.is_nan() || nbdev.is_infinite() {
291 return Err(VarError::InvalidNbdev { nbdev });
292 }
293
294 let warmup_end = (first + period - 1).min(len);
295 for v in &mut out[..warmup_end] {
296 *v = f64::from_bits(0x7ff8_0000_0000_0000);
297 }
298
299 let chosen = match detect_best_kernel() {
300 Kernel::Avx512 => Kernel::Avx2,
301 other => other,
302 };
303
304 unsafe {
305 match chosen {
306 Kernel::Scalar | Kernel::ScalarBatch => var_scalar(data, period, first, nbdev, out)?,
307 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
308 Kernel::Avx2 | Kernel::Avx2Batch => var_avx2(data, period, first, nbdev, out)?,
309 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
310 Kernel::Avx512 | Kernel::Avx512Batch => var_avx512(data, period, first, nbdev, out)?,
311 _ => unreachable!(),
312 }
313 }
314
315 Ok(())
316}
317
318#[inline(always)]
319pub fn var_scalar(
320 data: &[f64],
321 period: usize,
322 first: usize,
323 nbdev: f64,
324 out: &mut [f64],
325) -> Result<(), VarError> {
326 let len = data.len();
327 let nbdev2 = nbdev * nbdev;
328 let inv_p = 1.0 / (period as f64);
329
330 let mut sum = 0.0f64;
331 let mut sum_sq = 0.0f64;
332 let init = &data[first..first + period];
333 let mut it = init.chunks_exact(4);
334 for c in &mut it {
335 let x0 = c[0];
336 let x1 = c[1];
337 let x2 = c[2];
338 let x3 = c[3];
339 sum += x0 + x1 + x2 + x3;
340 sum_sq += x0 * x0 + x1 * x1 + x2 * x2 + x3 * x3;
341 }
342 for &x in it.remainder() {
343 sum += x;
344 sum_sq += x * x;
345 }
346
347 let idx0 = first + period - 1;
348 let mean0 = sum * inv_p;
349 out[idx0] = (sum_sq * inv_p - mean0 * mean0) * nbdev2;
350
351 unsafe {
352 let mut out_ptr = out.as_mut_ptr().add(idx0 + 1);
353 let mut in_new = data.as_ptr().add(first + period);
354 let mut in_old = data.as_ptr().add(first);
355 let end = data.as_ptr().add(len);
356
357 while in_new < end {
358 let new = *in_new;
359 let old = *in_old;
360
361 sum += new - old;
362 sum_sq += new * new - old * old;
363
364 let mean = sum * inv_p;
365 *out_ptr = (sum_sq * inv_p - mean * mean) * nbdev2;
366
367 in_new = in_new.add(1);
368 in_old = in_old.add(1);
369 out_ptr = out_ptr.add(1);
370 }
371 }
372
373 Ok(())
374}
375
/// AVX2-flavoured kernel: rolling population variance of `data` scaled by
/// `nbdev²`, written to `out[first + period - 1 .. data.len()]`.
///
/// The first window is loaded 4 lanes at a time. Each vector is reduced right
/// away in the same `((l0 + l1) + l2) + l3` order that `var_scalar` uses for
/// its 4-element chunks, so results match the scalar kernel. The sliding
/// update is scalar code.
///
/// Entries of `out` before `first + period - 1` are not written.
///
/// NOTE(review): the intrinsics run without `#[target_feature(enable = "avx2")]`;
/// callers are expected to dispatch here only on AVX-capable CPUs.
///
/// # Errors
/// Arguments that would read or write out of bounds are rejected:
/// * `VarError::InvalidPeriod` – `period == 0` or `period > data.len()`.
/// * `VarError::NotEnoughValidData` – `first + period > data.len()`.
/// * `VarError::OutputLengthMismatch` – `out.len() < data.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn var_avx2(
    data: &[f64],
    period: usize,
    first: usize,
    nbdev: f64,
    out: &mut [f64],
) -> Result<(), VarError> {
    let len = data.len();
    if period == 0 || period > len {
        return Err(VarError::InvalidPeriod {
            period,
            data_len: len,
        });
    }
    if first > len - period {
        return Err(VarError::NotEnoughValidData {
            needed: period,
            valid: len.saturating_sub(first),
        });
    }
    if out.len() < len {
        return Err(VarError::OutputLengthMismatch {
            expected: len,
            got: out.len(),
        });
    }

    let nbdev2 = nbdev * nbdev;
    let inv_p = 1.0 / (period as f64);

    let mut sum = 0.0f64;
    let mut sum_sq = 0.0f64;
    // SAFETY: `first + period <= len` was checked above, so every 4-lane load
    // at `idx` with `idx + 4 <= end` and every `get_unchecked(idx)` with
    // `idx < end` stays inside `data`.
    unsafe {
        let mut idx = first;
        let end = first + period;

        while idx + 4 <= end {
            let v = _mm256_loadu_pd(data.as_ptr().add(idx));
            let v2 = _mm256_mul_pd(v, v);

            let mut lanes = [0.0f64; 4];
            _mm256_storeu_pd(lanes.as_mut_ptr(), v);
            let mut t = lanes[0] + lanes[1];
            t = t + lanes[2];
            t = t + lanes[3];
            sum += t;

            _mm256_storeu_pd(lanes.as_mut_ptr(), v2);
            let mut t2 = lanes[0] + lanes[1];
            t2 = t2 + lanes[2];
            t2 = t2 + lanes[3];
            sum_sq += t2;

            idx += 4;
        }

        while idx < end {
            let x = *data.get_unchecked(idx);
            sum += x;
            sum_sq += x * x;
            idx += 1;
        }
    }

    let idx0 = first + period - 1;
    let mean0 = sum * inv_p;
    out[idx0] = (sum_sq * inv_p - mean0 * mean0) * nbdev2;

    // SAFETY: `out.len() >= len`, so writes at indices `idx0 + 1 .. len` are in
    // bounds; `in_new` walks `first + period .. len` and `in_old` trails it by
    // `period`, both inside `data`.
    unsafe {
        let mut out_ptr = out.as_mut_ptr().add(idx0 + 1);
        let mut in_new = data.as_ptr().add(first + period);
        let mut in_old = data.as_ptr().add(first);
        let end = data.as_ptr().add(len);

        while in_new < end {
            let new = *in_new;
            let old = *in_old;

            sum += new - old;
            sum_sq += new * new - old * old;

            let mean = sum * inv_p;
            *out_ptr = (sum_sq * inv_p - mean * mean) * nbdev2;

            in_new = in_new.add(1);
            in_old = in_old.add(1);
            out_ptr = out_ptr.add(1);
        }
    }

    Ok(())
}
451
/// AVX-512-flavoured kernel: rolling population variance of `data` scaled by
/// `nbdev²`, written to `out[first + period - 1 .. data.len()]`.
///
/// The first window is loaded 8 lanes at a time. Each vector is reduced as two
/// 4-lane groups, each added to the running sum in the order `var_scalar` uses
/// for consecutive 4-element chunks. An optional 4-lane step and a scalar tail
/// follow, so the grouping (and hence the result) matches the scalar kernel.
///
/// Entries of `out` before `first + period - 1` are not written.
///
/// NOTE(review): the intrinsics run without `#[target_feature(enable = "avx512f")]`;
/// callers are expected to dispatch here only on AVX-512-capable CPUs.
///
/// # Errors
/// Arguments that would read or write out of bounds are rejected:
/// * `VarError::InvalidPeriod` – `period == 0` or `period > data.len()`.
/// * `VarError::NotEnoughValidData` – `first + period > data.len()`.
/// * `VarError::OutputLengthMismatch` – `out.len() < data.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn var_avx512(
    data: &[f64],
    period: usize,
    first: usize,
    nbdev: f64,
    out: &mut [f64],
) -> Result<(), VarError> {
    use core::arch::x86_64::*;

    let len = data.len();
    if period == 0 || period > len {
        return Err(VarError::InvalidPeriod {
            period,
            data_len: len,
        });
    }
    if first > len - period {
        return Err(VarError::NotEnoughValidData {
            needed: period,
            valid: len.saturating_sub(first),
        });
    }
    if out.len() < len {
        return Err(VarError::OutputLengthMismatch {
            expected: len,
            got: out.len(),
        });
    }

    let nbdev2 = nbdev * nbdev;
    let inv_p = 1.0 / (period as f64);

    let mut sum = 0.0f64;
    let mut sum_sq = 0.0f64;

    // SAFETY: `first + period <= len` and `out.len() >= len` were checked above.
    // Every load stays within `data[first..first + period]`; the sliding loop
    // reads `data[first..len]` and writes `out[idx0 + 1..len]`.
    unsafe {
        let mut idx = first;
        let end = first + period;
        while idx + 8 <= end {
            let v = _mm512_loadu_pd(data.as_ptr().add(idx));
            let v2 = _mm512_mul_pd(v, v);

            let mut lanes = [0.0f64; 8];
            _mm512_storeu_pd(lanes.as_mut_ptr(), v);

            let mut t0 = lanes[0] + lanes[1];
            t0 = t0 + lanes[2];
            t0 = t0 + lanes[3];
            sum += t0;

            let mut t1 = lanes[4] + lanes[5];
            t1 = t1 + lanes[6];
            t1 = t1 + lanes[7];
            sum += t1;

            _mm512_storeu_pd(lanes.as_mut_ptr(), v2);
            let mut u0 = lanes[0] + lanes[1];
            u0 = u0 + lanes[2];
            u0 = u0 + lanes[3];
            sum_sq += u0;
            let mut u1 = lanes[4] + lanes[5];
            u1 = u1 + lanes[6];
            u1 = u1 + lanes[7];
            sum_sq += u1;

            idx += 8;
        }

        while idx + 4 <= end {
            let v4 = _mm256_loadu_pd(data.as_ptr().add(idx));
            let v4sq = _mm256_mul_pd(v4, v4);
            let mut lanes4 = [0.0f64; 4];
            _mm256_storeu_pd(lanes4.as_mut_ptr(), v4);
            let mut t = lanes4[0] + lanes4[1];
            t = t + lanes4[2];
            t = t + lanes4[3];
            sum += t;
            _mm256_storeu_pd(lanes4.as_mut_ptr(), v4sq);
            let mut u = lanes4[0] + lanes4[1];
            u = u + lanes4[2];
            u = u + lanes4[3];
            sum_sq += u;
            idx += 4;
        }

        while idx < end {
            let x = *data.get_unchecked(idx);
            sum += x;
            sum_sq += x * x;
            idx += 1;
        }

        let idx0 = first + period - 1;
        let mean0 = sum * inv_p;
        out[idx0] = (sum_sq * inv_p - mean0 * mean0) * nbdev2;

        let mut out_ptr = out.as_mut_ptr().add(idx0 + 1);
        let mut in_new = data.as_ptr().add(first + period);
        let mut in_old = data.as_ptr().add(first);
        let end_ptr = data.as_ptr().add(len);

        while in_new < end_ptr {
            let new = *in_new;
            let old = *in_old;

            sum += new - old;
            sum_sq += new * new - old * old;

            let mean = sum * inv_p;
            *out_ptr = (sum_sq * inv_p - mean * mean) * nbdev2;

            in_new = in_new.add(1);
            in_old = in_old.add(1);
            out_ptr = out_ptr.add(1);
        }
    }

    Ok(())
}
554
555#[inline]
556pub fn var_into_slice(dst: &mut [f64], input: &VarInput, kern: Kernel) -> Result<(), VarError> {
557 let data: &[f64] = input.as_ref();
558 let len = data.len();
559 if len == 0 {
560 return Err(VarError::EmptyInputData);
561 }
562 if dst.len() != len {
563 return Err(VarError::OutputLengthMismatch {
564 expected: len,
565 got: dst.len(),
566 });
567 }
568 let first = data
569 .iter()
570 .position(|x| !x.is_nan())
571 .ok_or(VarError::AllValuesNaN)?;
572 let period = input.get_period();
573 let nbdev = input.get_nbdev();
574
575 if period == 0 || period > len {
576 return Err(VarError::InvalidPeriod {
577 period,
578 data_len: len,
579 });
580 }
581 if (len - first) < period {
582 return Err(VarError::NotEnoughValidData {
583 needed: period,
584 valid: len - first,
585 });
586 }
587 if nbdev.is_nan() || nbdev.is_infinite() {
588 return Err(VarError::InvalidNbdev { nbdev });
589 }
590 let chosen = match kern {
591 Kernel::Auto => Kernel::Scalar,
592 other => other,
593 };
594
595 unsafe {
596 match chosen {
597 Kernel::Scalar | Kernel::ScalarBatch => {
598 var_scalar(data, period, first, nbdev, dst)?;
599 }
600 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
601 Kernel::Avx2 | Kernel::Avx2Batch => {
602 var_avx2(data, period, first, nbdev, dst)?;
603 }
604 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
605 Kernel::Avx512 | Kernel::Avx512Batch => {
606 var_avx512(data, period, first, nbdev, dst)?;
607 }
608 _ => unreachable!(),
609 }
610 }
611
612 let warmup_end = first + period - 1;
613 for v in &mut dst[..warmup_end] {
614 *v = f64::NAN;
615 }
616
617 Ok(())
618}
619
/// Scalar row kernel for batch variance. Writes the rolling population variance
/// of `data` scaled by `nbdev²` into `out[first + period - 1 .. data.len()]`.
///
/// The warm-up prefix of `out` is not written; the batch driver initialises it.
/// `_stride` is accepted for signature parity with the SIMD row kernels and is
/// unused.
///
/// This kernel performs the same sequence of floating-point operations as
/// `var_scalar`. The first window is summed in groups of 4, then each step adds
/// the entering sample and subtracts the leaving one.
///
/// # Safety
/// The body is memory-safe. The previous pointer-based loop could compute an
/// out-of-bounds pointer (`p.add(4)` near the end of `data`), which is UB. The
/// function stays `unsafe fn` for signature compatibility. Callers must still
/// pass `period >= 1`, `first + period <= data.len()` and
/// `out.len() >= data.len()`; violations now panic instead of causing UB.
#[inline(always)]
pub unsafe fn var_row_scalar(
    data: &[f64],
    first: usize,
    period: usize,
    _stride: usize,
    nbdev: f64,
    out: &mut [f64],
) {
    let len = data.len();
    let inv_p = 1.0 / (period as f64);
    let nbdev2 = nbdev * nbdev;

    let mut sum = 0.0f64;
    let mut sum_sq = 0.0f64;

    let window = &data[first..first + period];
    let mut quads = window.chunks_exact(4);
    for q in &mut quads {
        let (x0, x1, x2, x3) = (q[0], q[1], q[2], q[3]);
        sum += x0 + x1 + x2 + x3;
        sum_sq += x0 * x0 + x1 * x1 + x2 * x2 + x3 * x3;
    }
    for &x in quads.remainder() {
        sum += x;
        sum_sq += x * x;
    }

    let idx0 = first + period - 1;
    let mean0 = sum * inv_p;
    out[idx0] = (sum_sq * inv_p - mean0 * mean0) * nbdev2;

    // Output i (> idx0) takes data[i] in and drops data[i - period]; the three
    // zipped ranges all have `len - first - period` elements.
    let entering = &data[first + period..];
    let leaving = &data[first..len - period];
    for ((slot, &new), &old) in out[idx0 + 1..len]
        .iter_mut()
        .zip(entering)
        .zip(leaving)
    {
        sum += new - old;
        sum_sq += new * new - old * old;
        let mean = sum * inv_p;
        *slot = (sum_sq * inv_p - mean * mean) * nbdev2;
    }
}
690
/// AVX2 row kernel for batch variance.
///
/// Currently a direct delegation to `var_row_scalar`, so its output is
/// identical to the scalar row kernel.
///
/// # Safety
/// Same contract as `var_row_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn var_row_avx2(
    data: &[f64],
    first: usize,
    period: usize,
    stride: usize,
    nbdev: f64,
    out: &mut [f64],
) {
    var_row_scalar(data, first, period, stride, nbdev, out);
}

/// AVX-512 row kernel for batch variance: routes windows of up to 32 samples
/// to the "short" variant and longer ones to the "long" variant.
///
/// # Safety
/// Same contract as `var_row_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn var_row_avx512(
    data: &[f64],
    first: usize,
    period: usize,
    stride: usize,
    nbdev: f64,
    out: &mut [f64],
) {
    match period {
        0..=32 => var_row_avx512_short(data, first, period, stride, nbdev, out),
        _ => var_row_avx512_long(data, first, period, stride, nbdev, out),
    }
}

/// Short-window AVX-512 row kernel; presently delegates to `var_row_scalar`.
///
/// # Safety
/// Same contract as `var_row_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn var_row_avx512_short(
    data: &[f64],
    first: usize,
    period: usize,
    stride: usize,
    nbdev: f64,
    out: &mut [f64],
) {
    var_row_scalar(data, first, period, stride, nbdev, out);
}

/// Long-window AVX-512 row kernel; presently delegates to `var_row_scalar`.
///
/// # Safety
/// Same contract as `var_row_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn var_row_avx512_long(
    data: &[f64],
    first: usize,
    period: usize,
    stride: usize,
    nbdev: f64,
    out: &mut [f64],
) {
    var_row_scalar(data, first, period, stride, nbdev, out);
}
746
/// Parameter sweep for batch variance.
///
/// Each field is a `(start, end, step)` triple expanded by `expand_grid`: one
/// for the window length and one for the `nbdev` multiplier.
#[derive(Debug, Clone)]
pub struct VarBatchRange {
    pub period: (usize, usize, usize),
    pub nbdev: (f64, f64, f64),
}

impl Default for VarBatchRange {
    /// Periods 14 through 263 in steps of 1, with the single `nbdev` value 1.0.
    fn default() -> Self {
        let period = (14, 263, 1);
        let nbdev = (1.0, 1.0, 0.0);
        VarBatchRange { period, nbdev }
    }
}
761
762#[derive(Clone, Debug, Default)]
763pub struct VarBatchBuilder {
764 range: VarBatchRange,
765 kernel: Kernel,
766}
767
768impl VarBatchBuilder {
769 pub fn new() -> Self {
770 Self::default()
771 }
772 pub fn kernel(mut self, k: Kernel) -> Self {
773 self.kernel = k;
774 self
775 }
776 #[inline]
777 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
778 self.range.period = (start, end, step);
779 self
780 }
781 #[inline]
782 pub fn period_static(mut self, p: usize) -> Self {
783 self.range.period = (p, p, 0);
784 self
785 }
786 #[inline]
787 pub fn nbdev_range(mut self, start: f64, end: f64, step: f64) -> Self {
788 self.range.nbdev = (start, end, step);
789 self
790 }
791 #[inline]
792 pub fn nbdev_static(mut self, x: f64) -> Self {
793 self.range.nbdev = (x, x, 0.0);
794 self
795 }
796 pub fn apply_slice(self, data: &[f64]) -> Result<VarBatchOutput, VarError> {
797 var_batch_with_kernel(data, &self.range, self.kernel)
798 }
799 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<VarBatchOutput, VarError> {
800 VarBatchBuilder::new().kernel(k).apply_slice(data)
801 }
802 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<VarBatchOutput, VarError> {
803 let slice = source_type(c, src);
804 self.apply_slice(slice)
805 }
806 pub fn with_default_candles(c: &Candles) -> Result<VarBatchOutput, VarError> {
807 VarBatchBuilder::new()
808 .kernel(Kernel::Auto)
809 .apply_candles(c, "close")
810 }
811}
812
813#[derive(Clone, Debug)]
814pub struct VarBatchOutput {
815 pub values: Vec<f64>,
816 pub combos: Vec<VarParams>,
817 pub rows: usize,
818 pub cols: usize,
819}
820impl VarBatchOutput {
821 pub fn row_for_params(&self, p: &VarParams) -> Option<usize> {
822 self.combos.iter().position(|c| {
823 c.period.unwrap_or(14) == p.period.unwrap_or(14)
824 && (c.nbdev.unwrap_or(1.0) - p.nbdev.unwrap_or(1.0)).abs() < 1e-12
825 })
826 }
827 pub fn values_for(&self, p: &VarParams) -> Option<&[f64]> {
828 self.row_for_params(p).map(|row| {
829 let start = row * self.cols;
830 &self.values[start..start + self.cols]
831 })
832 }
833}
834
835#[inline(always)]
836fn expand_grid(r: &VarBatchRange) -> Vec<VarParams> {
837 fn axis_usize((start, end, step): (usize, usize, usize)) -> Vec<usize> {
838 if step == 0 || start == end {
839 return vec![start];
840 }
841 let mut v = Vec::new();
842 if start < end {
843 let mut x = start;
844 while x <= end {
845 v.push(x);
846 match x.checked_add(step) {
847 Some(next) if next > x => x = next,
848 _ => break,
849 }
850 }
851 } else {
852 let mut x = start;
853 while x >= end {
854 v.push(x);
855 match x.checked_sub(step) {
856 Some(next) if next < x => x = next,
857 _ => break,
858 }
859 }
860 }
861 v
862 }
863 fn axis_f64((start, end, step): (f64, f64, f64)) -> Vec<f64> {
864 let eps = 1e-12;
865 if step.abs() < eps || (start - end).abs() < eps {
866 return vec![start];
867 }
868 let mut v = Vec::new();
869 let mut x = start;
870 if step > 0.0 {
871 if start > end + eps {
872 return v;
873 }
874 while x <= end + eps {
875 v.push(x);
876 x += step;
877 }
878 } else {
879 if start < end - eps {
880 return v;
881 }
882 while x >= end - eps {
883 v.push(x);
884 x += step;
885 }
886 }
887 v
888 }
889 let periods = axis_usize(r.period);
890 let nbdevs = axis_f64(r.nbdev);
891 let capacity = periods.len().checked_mul(nbdevs.len()).unwrap_or(0);
892 let mut out = Vec::with_capacity(capacity);
893 for &p in &periods {
894 for &n in &nbdevs {
895 out.push(VarParams {
896 period: Some(p),
897 nbdev: Some(n),
898 });
899 }
900 }
901 out
902}
903
904#[inline(always)]
905pub fn var_expand_grid(r: &VarBatchRange) -> Vec<VarParams> {
906 expand_grid(r)
907}
908
909#[inline(always)]
910pub fn var_batch_slice(
911 data: &[f64],
912 sweep: &VarBatchRange,
913 kern: Kernel,
914) -> Result<VarBatchOutput, VarError> {
915 var_batch_inner(data, sweep, kern, false)
916}
917#[inline(always)]
918pub fn var_batch_par_slice(
919 data: &[f64],
920 sweep: &VarBatchRange,
921 kern: Kernel,
922) -> Result<VarBatchOutput, VarError> {
923 var_batch_inner(data, sweep, kern, true)
924}
925
/// Rounds `x` up to the next multiple of 8.
fn round_up8(x: usize) -> usize {
    (x + 7) / 8 * 8
}

/// Feature switch for the prefix-sum batch path in `var_batch_inner`
/// (currently disabled).
const ENABLE_VAR_BATCH_PREFIX_SUMS: bool = false;

/// Builds exclusive prefix sums of `data` and `data²` starting at `first`.
///
/// Both vectors have length `data.len() + 1`. Entries `0..=first` are 0.0, and
/// entry `j + 1` holds the sum over `data[first..=j]`.
#[inline(always)]
fn make_prefix_sums(data: &[f64], first: usize) -> (Vec<f64>, Vec<f64>) {
    let n = data.len();
    let mut sums = vec![0.0f64; n + 1];
    let mut sq_sums = vec![0.0f64; n + 1];
    let (mut acc, mut acc_sq) = (0.0f64, 0.0f64);
    for (j, &x) in data.iter().enumerate().skip(first) {
        acc += x;
        acc_sq += x * x;
        sums[j + 1] = acc;
        sq_sums[j + 1] = acc_sq;
    }
    (sums, sq_sums)
}
948
949#[inline(always)]
950fn var_batch_inner(
951 data: &[f64],
952 sweep: &VarBatchRange,
953 kern: Kernel,
954 parallel: bool,
955) -> Result<VarBatchOutput, VarError> {
956 if data.is_empty() {
957 return Err(VarError::EmptyInputData);
958 }
959 let combos = expand_grid(sweep);
960 if combos.is_empty() {
961 return Err(VarError::InvalidRange {
962 start: sweep.period.0 as f64,
963 end: sweep.period.1 as f64,
964 step: sweep.period.2 as f64,
965 });
966 }
967 let first = data
968 .iter()
969 .position(|x| !x.is_nan())
970 .ok_or(VarError::AllValuesNaN)?;
971 let max_period = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
972 if data.len() - first < max_period {
973 return Err(VarError::NotEnoughValidData {
974 needed: max_period,
975 valid: data.len() - first,
976 });
977 }
978 let stride = round_up8(max_period);
979 let rows = combos.len();
980 let cols = data.len();
981 let mut buf_mu = make_uninit_matrix(rows, cols);
982
983 let warmup_periods: Vec<usize> = combos
984 .iter()
985 .map(|c| first + c.period.unwrap() - 1)
986 .collect();
987 init_matrix_prefixes(&mut buf_mu, cols, &warmup_periods);
988
989 let mut buf_guard = core::mem::ManuallyDrop::new(buf_mu);
990 let out: &mut [f64] = unsafe {
991 core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
992 };
993
994 if ENABLE_VAR_BATCH_PREFIX_SUMS && matches!(kern, Kernel::Scalar) {
995 let (ps, psq) = make_prefix_sums(data, first);
996 let compute_row = |row: usize, out_row: &mut [f64]| {
997 let period = combos[row].period.unwrap();
998 let inv_p = 1.0 / (period as f64);
999 let nbdev2 = combos[row].nbdev.unwrap() * combos[row].nbdev.unwrap();
1000 let start = first + period - 1;
1001 for i in start..cols {
1002 let s = ps[i + 1] - ps[i + 1 - period];
1003 let s2 = psq[i + 1] - psq[i + 1 - period];
1004 let m = s * inv_p;
1005 out_row[i] = (s2 * inv_p - m * m) * nbdev2;
1006 }
1007 };
1008
1009 if parallel {
1010 #[cfg(not(target_arch = "wasm32"))]
1011 {
1012 out.par_chunks_mut(cols)
1013 .enumerate()
1014 .for_each(|(row, slice)| compute_row(row, slice));
1015 }
1016
1017 #[cfg(target_arch = "wasm32")]
1018 {
1019 for (row, slice) in out.chunks_mut(cols).enumerate() {
1020 compute_row(row, slice);
1021 }
1022 }
1023 } else {
1024 for (row, slice) in out.chunks_mut(cols).enumerate() {
1025 compute_row(row, slice);
1026 }
1027 }
1028 } else {
1029 let do_row = |row: usize, out_row: &mut [f64]| unsafe {
1030 let period = combos[row].period.unwrap();
1031 let nbdev = combos[row].nbdev.unwrap();
1032 match kern {
1033 Kernel::Scalar => var_row_scalar(data, first, period, stride, nbdev, out_row),
1034 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1035 Kernel::Avx2 => var_row_avx2(data, first, period, stride, nbdev, out_row),
1036 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1037 Kernel::Avx512 => var_row_avx512(data, first, period, stride, nbdev, out_row),
1038 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
1039 Kernel::Avx2 | Kernel::Avx512 => {
1040 var_row_scalar(data, first, period, stride, nbdev, out_row)
1041 }
1042 _ => unreachable!(),
1043 }
1044 };
1045 if parallel {
1046 #[cfg(not(target_arch = "wasm32"))]
1047 {
1048 out.par_chunks_mut(cols)
1049 .enumerate()
1050 .for_each(|(row, slice)| do_row(row, slice));
1051 }
1052
1053 #[cfg(target_arch = "wasm32")]
1054 {
1055 for (row, slice) in out.chunks_mut(cols).enumerate() {
1056 do_row(row, slice);
1057 }
1058 }
1059 } else {
1060 for (row, slice) in out.chunks_mut(cols).enumerate() {
1061 do_row(row, slice);
1062 }
1063 }
1064 }
1065
1066 let values = unsafe {
1067 Vec::from_raw_parts(
1068 buf_guard.as_mut_ptr() as *mut f64,
1069 buf_guard.len(),
1070 buf_guard.capacity(),
1071 )
1072 };
1073 std::mem::forget(buf_guard);
1074
1075 Ok(VarBatchOutput {
1076 values,
1077 combos,
1078 rows,
1079 cols,
1080 })
1081}
1082
1083pub fn var_batch_with_kernel(
1084 data: &[f64],
1085 sweep: &VarBatchRange,
1086 k: Kernel,
1087) -> Result<VarBatchOutput, VarError> {
1088 let kernel = match k {
1089 Kernel::Auto => detect_best_batch_kernel(),
1090 other if other.is_batch() => other,
1091 _ => return Err(VarError::InvalidKernelForBatch(k)),
1092 };
1093 let simd = match kernel {
1094 Kernel::Avx512Batch => Kernel::Avx512,
1095 Kernel::Avx2Batch => Kernel::Avx2,
1096 Kernel::ScalarBatch => Kernel::Scalar,
1097 _ => unreachable!(),
1098 };
1099 var_batch_par_slice(data, sweep, simd)
1100}
1101
1102#[derive(Debug, Clone)]
1103pub struct VarStream {
1104 period: usize,
1105 nbdev: f64,
1106 buffer: Vec<f64>,
1107 sum: f64,
1108 sum_sq: f64,
1109 head: usize,
1110 filled: bool,
1111}
1112impl VarStream {
1113 pub fn try_new(params: VarParams) -> Result<Self, VarError> {
1114 let period = params.period.unwrap_or(14);
1115 if period == 0 {
1116 return Err(VarError::InvalidPeriod {
1117 period,
1118 data_len: 0,
1119 });
1120 }
1121 let nbdev = params.nbdev.unwrap_or(1.0);
1122 if nbdev.is_nan() || nbdev.is_infinite() {
1123 return Err(VarError::InvalidNbdev { nbdev });
1124 }
1125 Ok(Self {
1126 period,
1127 nbdev,
1128 buffer: vec![f64::NAN; period],
1129 sum: 0.0,
1130 sum_sq: 0.0,
1131 head: 0,
1132 filled: false,
1133 })
1134 }
1135 #[inline(always)]
1136 pub fn update(&mut self, value: f64) -> Option<f64> {
1137 let old = self.buffer[self.head];
1138 self.buffer[self.head] = value;
1139 self.head = (self.head + 1) % self.period;
1140 if !self.filled {
1141 self.sum += value;
1142 self.sum_sq += value * value;
1143 if self.head == 0 {
1144 self.filled = true;
1145 } else {
1146 return None;
1147 }
1148 } else {
1149 self.sum += value - old;
1150 self.sum_sq += value * value - old * old;
1151 }
1152 let inv_p = 1.0 / self.period as f64;
1153 let mean = self.sum * inv_p;
1154 let mean_sq = self.sum_sq * inv_p;
1155 Some((mean_sq - mean * mean) * self.nbdev * self.nbdev)
1156 }
1157}
1158
/// wasm-bindgen entry point: rolling variance of `data` returned as a new array.
///
/// Errors from the core computation are surfaced as JS string values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn var_js(data: &[f64], period: usize, nbdev: f64) -> Result<Vec<f64>, JsValue> {
    let input = VarInput::from_slice(
        data,
        VarParams {
            period: Some(period),
            nbdev: Some(nbdev),
        },
    );
    let mut output = vec![0.0; data.len()];

    match var_into_slice(&mut output, &input, detect_best_kernel()) {
        Ok(()) => Ok(output),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1175
/// Allocates room for `len` `f64`s in wasm linear memory and hands the raw
/// pointer to the caller. The contents are uninitialised.
///
/// The allocation is intentionally leaked here. It must be released with
/// `var_free(ptr, len)` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn var_alloc(len: usize) -> *mut f64 {
    let mut vec = Vec::<f64>::with_capacity(len);
    let ptr = vec.as_mut_ptr();
    // Give up ownership so the buffer outlives this call; `var_free` reclaims it.
    std::mem::forget(vec);
    ptr
}

/// Frees a buffer previously returned by `var_alloc(len)`.
///
/// The buffer is rebuilt as a `Vec` with length and capacity `len` and dropped.
/// NOTE(review): `from_raw_parts` with length `len` formally requires
/// initialised elements. It is harmless for `f64` (no destructor), but
/// `from_raw_parts(ptr, 0, len)` would express the intent without that
/// requirement. Passing a pointer not obtained from `var_alloc`, or a different
/// `len`, is undefined behaviour.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn var_free(ptr: *mut f64, len: usize) {
    unsafe {
        let _ = Vec::from_raw_parts(ptr, len, len);
    }
}
1192
/// wasm-bindgen entry point: computes the rolling variance of the `len` values
/// at `in_ptr` and writes `len` results to `out_ptr`.
///
/// Presumably both pointers come from `var_alloc` (or otherwise address `len`
/// valid `f64`s in wasm memory). Only the null check is enforced here.
/// When `in_ptr == out_ptr`, the result is computed into a temporary buffer and
/// then copied over, so in-place use is supported.
/// NOTE(review): buffers that overlap partially but are not equal are not
/// detected. They would alias a `&[f64]` with a `&mut [f64]`, so callers must
/// avoid that.
///
/// # Errors
/// Returns a JS string for null pointers, for `period == 0 || period > len`,
/// or for any error from `var_into_slice` (formatted via `to_string`).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn var_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
    nbdev: f64,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);

        // Early, message-specific period check; `var_into_slice` validates again.
        if period == 0 || period > len {
            return Err(JsValue::from_str("Invalid period"));
        }

        let params = VarParams {
            period: Some(period),
            nbdev: Some(nbdev),
        };
        let input = VarInput::from_slice(data, params);

        if in_ptr == out_ptr as *const f64 {
            // In-place request: avoid writing through `out` while `data` is borrowed.
            let mut temp = vec![0.0; len];
            var_into_slice(&mut temp, &input, detect_best_kernel())
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            out.copy_from_slice(&temp);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            var_into_slice(out, &input, detect_best_kernel())
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }

        Ok(())
    }
}
1234
/// JS-side configuration for `var_batch`: `(start, end, step)` triples for the
/// period and `nbdev` sweeps.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Deserialize, Serialize)]
pub struct VarBatchConfig {
    pub period_range: (usize, usize, usize),
    pub nbdev_range: (f64, f64, f64),
}

/// Serialisable mirror of `VarBatchOutput` returned to JS.
///
/// `values` is row-major with `rows` rows of `cols` entries, one row per
/// element of `combos`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Deserialize, Serialize)]
pub struct VarBatchJsOutput {
    pub values: Vec<f64>,
    pub combos: Vec<VarParams>,
    pub rows: usize,
    pub cols: usize,
}
1250
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = var_batch)]
/// JS entry point for a VAR parameter sweep.
///
/// `config` must deserialize into a `VarBatchConfig`. The result is a
/// `VarBatchJsOutput` (values, combos, rows, cols) converted to a JS value.
///
/// # Errors
/// A JS string error if `config` cannot be parsed ("Invalid config: ..."), if
/// the batch computation fails, or if the output cannot be serialized.
pub fn var_batch_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let cfg: VarBatchConfig = match serde_wasm_bindgen::from_value(config) {
        Ok(parsed) => parsed,
        Err(e) => return Err(JsValue::from_str(&format!("Invalid config: {}", e))),
    };

    let sweep = VarBatchRange {
        period: cfg.period_range,
        nbdev: cfg.nbdev_range,
    };

    let batch = match var_batch_inner(data, &sweep, detect_best_kernel(), false) {
        Ok(result) => result,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };

    let payload = VarBatchJsOutput {
        rows: batch.rows,
        cols: batch.cols,
        values: batch.values,
        combos: batch.combos,
    };

    serde_wasm_bindgen::to_value(&payload).map_err(|e| JsValue::from_str(&e.to_string()))
}
1274
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// Runs a VAR parameter sweep over `len` inputs at `in_ptr`, writing a row-major
/// `rows x len` matrix to `out_ptr`, and returns `rows` (the number of
/// parameter combinations).
///
/// `out_ptr` must have room for `rows * len` values. The input and output
/// buffers may overlap (including `in_ptr == out_ptr`). In that case the input
/// is copied first, so the warm-up prefill and the computation never read
/// already-overwritten samples.
///
/// # Errors
/// A JS string error for null pointers, an empty sweep, `rows * len` overflow,
/// or any error reported by the batch computation.
pub fn var_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
    nbdev_start: f64,
    nbdev_end: f64,
    nbdev_step: f64,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    let sweep = VarBatchRange {
        period: (period_start, period_end, period_step),
        nbdev: (nbdev_start, nbdev_end, nbdev_step),
    };

    let combos = expand_grid(&sweep);
    if combos.is_empty() {
        return Err(JsValue::from_str("No valid parameter combinations"));
    }
    let rows = combos.len();
    let cols = len;

    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("rows * cols overflow"))?;

    // Byte ranges of both buffers, used to detect any overlap (not only exact
    // pointer equality). Row 0 of the output always covers the input when
    // `in_ptr == out_ptr`.
    let elem = std::mem::size_of::<f64>();
    let in_start = in_ptr as usize;
    let in_end = in_start.saturating_add(len.saturating_mul(elem));
    let out_start = out_ptr as usize;
    let out_end = out_start.saturating_add(total.saturating_mul(elem));
    let overlaps = in_start < out_end && out_start < in_end;

    let owned_input: Vec<f64>;
    let data: &[f64] = if overlaps {
        // SAFETY: the caller guarantees `in_ptr` is valid for `len` reads. The
        // temporary shared slice is dropped after copying, before any `&mut`
        // view of the (overlapping) output is created.
        owned_input = unsafe { std::slice::from_raw_parts(in_ptr, len) }.to_vec();
        &owned_input
    } else {
        // SAFETY: the caller guarantees `in_ptr` is valid for `len` reads, and
        // it does not overlap the output, so no `&`/`&mut` aliasing occurs.
        unsafe { std::slice::from_raw_parts(in_ptr, len) }
    };

    // SAFETY: the caller guarantees `out_ptr` is valid for `total` writes, and
    // `data` no longer borrows that memory (see above).
    let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };

    // Pre-fill each row's warm-up prefix with NaN. Saturating arithmetic keeps a
    // zero period (with `first == 0`) from underflowing.
    let first = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
    for (r, prm) in combos.iter().enumerate() {
        let warm = first
            .saturating_add(prm.period.unwrap())
            .saturating_sub(1)
            .min(cols);
        out[r * cols..r * cols + warm].fill(f64::NAN);
    }

    var_batch_inner_into(data, &sweep, detect_best_kernel(), false, out)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    Ok(rows)
}
1327
#[cfg(test)]
mod tests {
    use super::*;
    use crate::skip_if_unsupported;
    use crate::utilities::data_loader::read_candles_from_csv;
    use paste::paste;

    /// The in-place `var_into` API must match the allocating `var` API bit-for-bit
    /// (NaN positions included).
    #[test]
    fn test_var_into_matches_api() -> Result<(), Box<dyn Error>> {
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let input = VarInput::from_candles(&candles, "close", VarParams::default());

        let VarOutput { values: expected } = var(&input)?;

        let mut out = vec![0.0f64; expected.len()];

        #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
        {
            var_into(&input, &mut out)?;
        }
        #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
        {
            var_into_slice(&mut out, &input, detect_best_kernel())?;
        }

        assert_eq!(out.len(), expected.len());

        for (i, (a, b)) in out.iter().zip(expected.iter()).enumerate() {
            let eq_or_both_nan = (*a == *b) || (a.is_nan() && b.is_nan());
            assert!(
                eq_or_both_nan,
                "Mismatch at index {}: got {:?}, expected {:?}",
                i, a, b
            );
        }

        Ok(())
    }

    /// Unset (`None`) parameters fall back to defaults and still produce a full-length output.
    fn check_var_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let default_params = VarParams {
            period: None,
            nbdev: None,
        };
        let input = VarInput::from_candles(&candles, "close", default_params);
        let output = var_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());
        Ok(())
    }

    /// Last five default-parameter values match reference numbers.
    fn check_var_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let input = VarInput::from_candles(&candles, "close", VarParams::default());
        let var_result = var_with_kernel(&input, kernel)?;
        assert_eq!(var_result.values.len(), candles.close.len());
        let expected_last_five = [
            350987.4081501961,
            348493.9183540344,
            302611.06121110916,
            106092.2499871254,
            121941.35202789307,
        ];
        let start_index = var_result.values.len() - 5;
        let result_last_five = &var_result.values[start_index..];
        for (i, &value) in result_last_five.iter().enumerate() {
            let expected_value = expected_last_five[i];
            assert!(
                (value - expected_value).abs() < 1e-1,
                "[{}] VAR mismatch at idx {}: got {}, expected {}",
                test_name,
                i,
                value,
                expected_value
            );
        }
        Ok(())
    }

    /// `with_default_candles` selects the "close" source.
    fn check_var_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let input = VarInput::with_default_candles(&candles);
        match input.data {
            VarData::Candles { source, .. } => assert_eq!(source, "close"),
            _ => panic!("Expected VarData::Candles"),
        }
        let output = var_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());
        Ok(())
    }

    /// A zero period is rejected.
    fn check_var_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let input_data = [10.0, 20.0, 30.0];
        let params = VarParams {
            period: Some(0),
            nbdev: None,
        };
        let input = VarInput::from_slice(&input_data, params);
        let res = var_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] VAR should fail with zero period",
            test_name
        );
        Ok(())
    }

    /// A period longer than the input is rejected.
    fn check_var_period_exceeds_length(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let data_small = [10.0, 20.0, 30.0];
        let params = VarParams {
            period: Some(10),
            nbdev: None,
        };
        let input = VarInput::from_slice(&data_small, params);
        let res = var_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] VAR should fail with period exceeding length",
            test_name
        );
        Ok(())
    }

    /// A single data point is not enough for period 14.
    fn check_var_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let single_point = [42.0];
        let params = VarParams {
            period: Some(14),
            nbdev: None,
        };
        let input = VarInput::from_slice(&single_point, params);
        let res = var_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] VAR should fail with insufficient data",
            test_name
        );
        Ok(())
    }

    /// VAR output (with its NaN warm-up prefix) can be fed back in as input.
    fn check_var_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let first_params = VarParams {
            period: Some(14),
            nbdev: Some(1.0),
        };
        let first_input = VarInput::from_candles(&candles, "close", first_params);
        let first_result = var_with_kernel(&first_input, kernel)?;
        let second_params = VarParams {
            period: Some(14),
            nbdev: Some(1.0),
        };
        let second_input = VarInput::from_slice(&first_result.values, second_params);
        let second_result = var_with_kernel(&second_input, kernel)?;
        assert_eq!(second_result.values.len(), first_result.values.len());
        Ok(())
    }

    /// No NaN appears after index 30 on real data.
    fn check_var_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let input = VarInput::from_candles(
            &candles,
            "close",
            VarParams {
                period: Some(14),
                nbdev: None,
            },
        );
        let res = var_with_kernel(&input, kernel)?;
        assert_eq!(res.values.len(), candles.close.len());
        if res.values.len() > 30 {
            for (i, &val) in res.values[30..].iter().enumerate() {
                assert!(
                    !val.is_nan(),
                    "[{}] Found unexpected NaN at out-index {}",
                    test_name,
                    30 + i
                );
            }
        }
        Ok(())
    }

    /// The streaming implementation agrees with the batch implementation.
    fn check_var_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let period = 14;
        let nbdev = 1.0;
        let input = VarInput::from_candles(
            &candles,
            "close",
            VarParams {
                period: Some(period),
                nbdev: Some(nbdev),
            },
        );
        let batch_output = var_with_kernel(&input, kernel)?.values;
        let mut stream = VarStream::try_new(VarParams {
            period: Some(period),
            nbdev: Some(nbdev),
        })?;
        let mut stream_values = Vec::with_capacity(candles.close.len());
        for &price in &candles.close {
            match stream.update(price) {
                Some(var_val) => stream_values.push(var_val),
                None => stream_values.push(f64::NAN),
            }
        }
        assert_eq!(batch_output.len(), stream_values.len());
        for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
            if b.is_nan() && s.is_nan() {
                continue;
            }
            let diff = (b - s).abs();
            assert!(
                diff < 1e-6,
                "[{}] VAR streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
                test_name,
                i,
                b,
                s,
                diff
            );
        }
        Ok(())
    }

    /// Debug builds: no allocator poison patterns leak into the single-series output.
    #[cfg(debug_assertions)]
    fn check_var_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);

        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let test_params = vec![
            VarParams::default(),
            VarParams {
                period: Some(2),
                nbdev: Some(1.0),
            },
            VarParams {
                period: Some(5),
                nbdev: Some(1.0),
            },
            VarParams {
                period: Some(10),
                nbdev: Some(1.0),
            },
            VarParams {
                period: Some(20),
                nbdev: Some(1.0),
            },
            VarParams {
                period: Some(30),
                nbdev: Some(1.0),
            },
            VarParams {
                period: Some(50),
                nbdev: Some(1.0),
            },
            VarParams {
                period: Some(100),
                nbdev: Some(1.0),
            },
            VarParams {
                period: Some(200),
                nbdev: Some(1.0),
            },
            VarParams {
                period: Some(14),
                nbdev: Some(0.5),
            },
            VarParams {
                period: Some(14),
                nbdev: Some(2.0),
            },
            VarParams {
                period: Some(14),
                nbdev: Some(3.0),
            },
            VarParams {
                period: Some(7),
                nbdev: Some(1.5),
            },
            VarParams {
                period: Some(21),
                nbdev: Some(2.5),
            },
            VarParams {
                period: Some(50),
                nbdev: Some(0.75),
            },
        ];

        for (param_idx, params) in test_params.iter().enumerate() {
            let input = VarInput::from_candles(&candles, "close", params.clone());
            let output = var_with_kernel(&input, kernel)?;

            for (i, &val) in output.values.iter().enumerate() {
                if val.is_nan() {
                    continue;
                }

                let bits = val.to_bits();

                if bits == 0x11111111_11111111 {
                    panic!(
                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
                         with params: period={}, nbdev={} (param set {})",
                        test_name,
                        val,
                        bits,
                        i,
                        params.period.unwrap_or(14),
                        params.nbdev.unwrap_or(1.0),
                        param_idx
                    );
                }

                if bits == 0x22222222_22222222 {
                    panic!(
                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
                         with params: period={}, nbdev={} (param set {})",
                        test_name,
                        val,
                        bits,
                        i,
                        params.period.unwrap_or(14),
                        params.nbdev.unwrap_or(1.0),
                        param_idx
                    );
                }

                if bits == 0x33333333_33333333 {
                    panic!(
                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
                         with params: period={}, nbdev={} (param set {})",
                        test_name,
                        val,
                        bits,
                        i,
                        params.period.unwrap_or(14),
                        params.nbdev.unwrap_or(1.0),
                        param_idx
                    );
                }
            }
        }

        Ok(())
    }

    /// Release builds: the poison check is a no-op.
    #[cfg(not(debug_assertions))]
    fn check_var_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    /// Property test: non-negativity, closed-form checks, and agreement with the scalar kernel.
    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_var_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        let strat = (2usize..=64).prop_flat_map(|period| {
            (
                prop::collection::vec(
                    (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
                    period.max(10)..400,
                ),
                Just(period),
                0.1f64..3.0f64,
                -100.0f64..100.0f64,
                -1e5f64..1e5f64,
                prop::bool::ANY,
            )
        });

        proptest::test_runner::TestRunner::default()
            .run(
                &strat,
                |(mut data, period, nbdev, trend, intercept, use_special_pattern)| {
                    if use_special_pattern {
                        for (i, val) in data.iter_mut().enumerate() {
                            *val = intercept + trend * (i as f64);
                        }

                        for val in data.iter_mut() {
                            *val += (val.abs() * 0.01).min(10.0)
                                * (if val.is_sign_positive() { 1.0 } else { -1.0 });
                        }
                    }

                    let params = VarParams {
                        period: Some(period),
                        nbdev: Some(nbdev),
                    };
                    let input = VarInput::from_slice(&data, params);

                    let VarOutput { values: out } = var_with_kernel(&input, kernel).unwrap();
                    let VarOutput { values: ref_out } =
                        var_with_kernel(&input, Kernel::Scalar).unwrap();

                    for i in (period - 1)..data.len() {
                        let y = out[i];
                        if !y.is_nan() {
                            prop_assert!(
                                y >= -1e-6,
                                "[{}] Variance should be non-negative at idx {}: got {}",
                                test_name,
                                i,
                                y
                            );
                        }
                    }

                    if period == 2 && !use_special_pattern {
                        for i in (period - 1)..data.len().min(10) {
                            if !out[i].is_nan() {
                                let window = &data[i + 1 - period..=i];
                                let mean = (window[0] + window[1]) / 2.0;
                                let expected =
                                    ((window[0] - mean).powi(2) + (window[1] - mean).powi(2)) / 2.0
                                        * nbdev
                                        * nbdev;

                                let tolerance = (expected.abs() + 1.0) * 1e-8;
                                prop_assert!(
                                    (out[i] - expected).abs() <= tolerance,
                                    "[{}] Period=2 variance mismatch at idx {}: got {} expected {}",
                                    test_name,
                                    i,
                                    out[i],
                                    expected
                                );
                            }
                        }
                    }

                    let is_constant = data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10);
                    if is_constant && data.len() >= period {
                        for i in (period - 1)..data.len() {
                            if !out[i].is_nan() {
                                prop_assert!(
                                    out[i].abs() <= 1e-6,
                                    "[{}] Constant data should have near-zero variance at idx {}: got {}",
                                    test_name, i, out[i]
                                );
                            }
                        }
                    }

                    for i in (period - 1)..data.len() {
                        let y = out[i];
                        let r = ref_out[i];

                        if !y.is_finite() || !r.is_finite() {
                            prop_assert!(
                                y.to_bits() == r.to_bits(),
                                "[{}] finite/NaN mismatch at idx {}: {} vs {}",
                                test_name,
                                i,
                                y,
                                r
                            );
                            continue;
                        }

                        let y_bits = y.to_bits();
                        let r_bits = r.to_bits();
                        let ulp_diff: u64 = y_bits.abs_diff(r_bits);

                        prop_assert!(
                            (y - r).abs() <= 1e-9 || ulp_diff <= 4,
                            "[{}] Kernel mismatch at idx {}: {} vs {} (ULP={})",
                            test_name,
                            i,
                            y,
                            r,
                            ulp_diff
                        );
                    }

                    if !use_special_pattern && data.len() >= period && period <= 10 {
                        let idx = period * 2;
                        if idx < data.len() && !out[idx].is_nan() {
                            let window = &data[idx + 1 - period..=idx];
                            let mean: f64 = window.iter().sum::<f64>() / (period as f64);
                            let mean_sq: f64 =
                                window.iter().map(|x| x * x).sum::<f64>() / (period as f64);
                            let expected_var = (mean_sq - mean * mean) * nbdev * nbdev;

                            let tolerance = (expected_var.abs() + 1.0) * 1e-8;
                            prop_assert!(
                                (out[idx] - expected_var).abs() <= tolerance,
                                "[{}] Mathematical formula mismatch at idx {}: got {} expected {} (diff: {})",
                                test_name, idx, out[idx], expected_var, (out[idx] - expected_var).abs()
                            );
                        }
                    }

                    Ok(())
                },
            )
            .unwrap();

        Ok(())
    }

    // Generates one `#[test]` per check function and kernel. The wrappers return
    // the check's `Result`, so an `Err` propagated with `?` fails the test instead
    // of being silently discarded. The SIMD cfg sits on every generated fn,
    // because an attribute placed outside a `$( … )*` repetition only applies to
    // the first expanded item.
    macro_rules! generate_all_var_tests {
        ($($test_fn:ident),*) => {
            paste! {
                $(
                    #[test]
                    fn [<$test_fn _scalar_f64>]() -> Result<(), Box<dyn Error>> {
                        $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar)
                    }
                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx2_f64>]() -> Result<(), Box<dyn Error>> {
                        $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2)
                    }
                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx512_f64>]() -> Result<(), Box<dyn Error>> {
                        $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512)
                    }
                )*
            }
        }
    }
    generate_all_var_tests!(
        check_var_partial_params,
        check_var_accuracy,
        check_var_default_candles,
        check_var_zero_period,
        check_var_period_exceeds_length,
        check_var_very_small_dataset,
        check_var_reinput,
        check_var_nan_handling,
        check_var_streaming,
        check_var_no_poison
    );

    #[cfg(feature = "proptest")]
    generate_all_var_tests!(check_var_property);

    /// The default-parameter row of a batch run matches the single-series reference values.
    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);
        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;
        let output = VarBatchBuilder::new()
            .kernel(kernel)
            .apply_candles(&c, "close")?;
        let def = VarParams::default();
        let row = output.values_for(&def).expect("default row missing");
        assert_eq!(row.len(), c.close.len());
        let expected = [
            350987.4081501961,
            348493.9183540344,
            302611.06121110916,
            106092.2499871254,
            121941.35202789307,
        ];
        let start = row.len() - 5;
        for (i, &v) in row[start..].iter().enumerate() {
            assert!(
                (v - expected[i]).abs() < 1e-1,
                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
            );
        }
        Ok(())
    }

    /// Debug builds: no allocator poison patterns leak into any batch matrix cell.
    #[cfg(debug_assertions)]
    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);

        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;

        let test_configs = vec![
            ((2, 10, 2), (1.0, 1.0, 0.0)),
            ((10, 30, 5), (1.0, 2.0, 0.5)),
            ((30, 100, 10), (0.5, 1.5, 0.5)),
            ((2, 5, 1), (1.0, 3.0, 1.0)),
            ((14, 14, 0), (1.0, 1.0, 0.0)),
            ((5, 25, 5), (2.0, 2.0, 0.0)),
            ((50, 100, 25), (1.0, 2.0, 0.25)),
            ((14, 28, 7), (0.5, 2.5, 0.5)),
        ];

        for (cfg_idx, &(period_range, nbdev_range)) in test_configs.iter().enumerate() {
            let output = VarBatchBuilder::new()
                .kernel(kernel)
                .period_range(period_range.0, period_range.1, period_range.2)
                .nbdev_range(nbdev_range.0, nbdev_range.1, nbdev_range.2)
                .apply_candles(&c, "close")?;

            for (idx, &val) in output.values.iter().enumerate() {
                if val.is_nan() {
                    continue;
                }

                let bits = val.to_bits();
                let row = idx / output.cols;
                let col = idx % output.cols;
                let combo = &output.combos[row];

                if bits == 0x11111111_11111111 {
                    panic!(
                        "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
                         at row {} col {} (flat index {}) with params: period={}, nbdev={}",
                        test,
                        cfg_idx,
                        val,
                        bits,
                        row,
                        col,
                        idx,
                        combo.period.unwrap_or(14),
                        combo.nbdev.unwrap_or(1.0)
                    );
                }

                if bits == 0x22222222_22222222 {
                    panic!(
                        "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
                         at row {} col {} (flat index {}) with params: period={}, nbdev={}",
                        test,
                        cfg_idx,
                        val,
                        bits,
                        row,
                        col,
                        idx,
                        combo.period.unwrap_or(14),
                        combo.nbdev.unwrap_or(1.0)
                    );
                }

                if bits == 0x33333333_33333333 {
                    panic!(
                        "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
                         at row {} col {} (flat index {}) with params: period={}, nbdev={}",
                        test,
                        cfg_idx,
                        val,
                        bits,
                        row,
                        col,
                        idx,
                        combo.period.unwrap_or(14),
                        combo.nbdev.unwrap_or(1.0)
                    );
                }
            }
        }

        Ok(())
    }

    /// Release builds: the batch poison check is a no-op.
    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    // Generates batch tests per kernel; wrappers return the check's `Result` so
    // `?`-propagated errors fail the test.
    macro_rules! gen_batch_tests {
        ($fn_name:ident) => {
            paste! {
                #[test] fn [<$fn_name _scalar>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch)
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx2>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch)
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx512>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch)
                }
                #[test] fn [<$fn_name _auto_detect>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto)
                }
            }
        };
    }
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);

    /// Batch sweeps accept periods that are not multiples of any SIMD width, and
    /// reject periods longer than the data.
    #[test]
    fn test_batch_non_aligned_periods() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];

        let sweep = VarBatchRange {
            period: (7, 7, 0),
            nbdev: (1.0, 1.0, 0.0),
        };

        let result = var_batch_slice(&data, &sweep, Kernel::Scalar);
        assert!(result.is_ok(), "Should handle period=7 with 10 data points");

        let sweep_multi = VarBatchRange {
            period: (5, 7, 1),
            nbdev: (1.0, 1.0, 0.0),
        };

        let result_multi = var_batch_slice(&data, &sweep_multi, Kernel::Scalar);
        assert!(
            result_multi.is_ok(),
            "Should handle periods 5,6,7 with 10 data points"
        );
        let output = result_multi.unwrap();
        assert_eq!(output.rows, 3, "Should have 3 rows for periods 5,6,7");

        let data_short = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let result_short = var_batch_slice(&data_short, &sweep, Kernel::Scalar);
        assert!(
            result_short.is_err(),
            "Should reject when data length (6) < period (7)"
        );

        let data_15 = vec![1.0; 15];
        let sweep_15 = VarBatchRange {
            period: (15, 15, 0),
            nbdev: (1.0, 1.0, 0.0),
        };

        let result_15 = var_batch_slice(&data_15, &sweep_15, Kernel::Scalar);
        assert!(
            result_15.is_ok(),
            "Should handle period=15 with exactly 15 data points"
        );
    }
}
2083
#[cfg(feature = "python")]
#[pyfunction(name = "var")]
#[pyo3(signature = (data, period=14, nbdev=1.0, kernel=None))]
/// Python binding: rolling variance of a 1-D float64 array.
///
/// The computation runs with the GIL released. Any VAR error (or invalid
/// kernel name) is raised as `ValueError`.
pub fn var_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period: usize,
    nbdev: f64,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let series = data.as_slice()?;
    let chosen_kernel = validate_kernel(kernel, false)?;

    let input = VarInput::from_slice(
        series,
        VarParams {
            period: Some(period),
            nbdev: Some(nbdev),
        },
    );

    let values = match py.allow_threads(|| var_with_kernel(&input, chosen_kernel)) {
        Ok(output) => output.values,
        Err(e) => return Err(PyValueError::new_err(e.to_string())),
    };

    Ok(values.into_pyarray(py))
}
2109
/// Python-visible wrapper (`VarStream`) around the incremental `VarStream` calculator.
#[cfg(feature = "python")]
#[pyclass(name = "VarStream")]
pub struct VarStreamPy {
    // The underlying Rust streaming state; every `update` call is forwarded to it.
    stream: VarStream,
}
2115
#[cfg(feature = "python")]
#[pymethods]
impl VarStreamPy {
    /// Creates a streaming VAR calculator; raises `ValueError` on invalid parameters.
    #[new]
    fn new(period: usize, nbdev: f64) -> PyResult<Self> {
        VarStream::try_new(VarParams {
            period: Some(period),
            nbdev: Some(nbdev),
        })
        .map(|stream| Self { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feeds one sample; returns the current VAR value, or `None` if none is
    /// produced for this sample.
    fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
2134
#[cfg(feature = "python")]
#[pyfunction(name = "var_batch")]
#[pyo3(signature = (data, period_range, nbdev_range=(1.0, 1.0, 0.0), kernel=None))]
/// Python binding: VAR over a sweep of `(period, nbdev)` combinations.
///
/// Returns a dict with:
/// - `values`: a `(rows, cols)` float64 array, one row per combination
/// - `periods`: the period of each row (uint64)
/// - `nbdevs`: the nbdev of each row (float64)
///
/// The computation runs in parallel with the GIL released. Invalid kernels,
/// invalid ranges and computation errors are raised as `ValueError`.
pub fn var_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    nbdev_range: (f64, f64, f64),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let slice_in = data.as_slice()?;
    let sweep = VarBatchRange {
        period: period_range,
        nbdev: nbdev_range,
    };

    let kern = validate_kernel(kernel, true)?;
    let kernel = match kern {
        Kernel::Auto => detect_best_batch_kernel(),
        k => k,
    };

    // The parallel batch routine takes the per-row SIMD kernel. Accept both the
    // batch and plain variants, and report anything else as a Python error
    // instead of panicking across the FFI boundary.
    let simd = match kernel {
        Kernel::Avx512Batch | Kernel::Avx512 => Kernel::Avx512,
        Kernel::Avx2Batch | Kernel::Avx2 => Kernel::Avx2,
        Kernel::ScalarBatch | Kernel::Scalar => Kernel::Scalar,
        _ => {
            return Err(PyValueError::new_err(
                "var_batch: unsupported kernel for batch computation",
            ))
        }
    };

    let out = py
        .allow_threads(|| var_batch_par_slice(slice_in, &sweep, simd))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let rows = out.rows;
    let cols = out.cols;

    let dict = PyDict::new(py);

    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows * cols overflow"))?;
    // Guard the copy below: `copy_from_slice` panics on a length mismatch.
    if out.values.len() != total {
        return Err(PyValueError::new_err(
            "var_batch: output length does not match rows * cols",
        ));
    }
    // SAFETY: the freshly created (uninitialized) array is fully overwritten by
    // the copy below before it is exposed to Python; it is C-contiguous
    // (fortran = false), so `total` elements starting at `data()` are valid.
    let values_2d = unsafe { numpy::PyArray2::<f64>::new(py, [rows, cols], false) };
    let raw_ptr = values_2d.data() as *mut f64;
    let output_slice = unsafe { std::slice::from_raw_parts_mut(raw_ptr, total) };
    output_slice.copy_from_slice(&out.values);

    dict.set_item("values", values_2d)?;
    dict.set_item(
        "periods",
        out.combos
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "nbdevs",
        out.combos
            .iter()
            .map(|p| p.nbdev.unwrap())
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    Ok(dict)
}
2200
2201#[cfg(all(feature = "python", feature = "cuda"))]
2202use crate::cuda::cuda_available;
2203#[cfg(all(feature = "python", feature = "cuda"))]
2204use crate::cuda::moving_averages::DeviceArrayF32;
2205#[cfg(all(feature = "python", feature = "cuda"))]
2206use crate::cuda::var_wrapper::CudaVar;
2207#[cfg(all(feature = "python", feature = "cuda"))]
2208use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
2209#[cfg(all(feature = "python", feature = "cuda"))]
2210use cust::context::Context;
2211#[cfg(all(feature = "python", feature = "cuda"))]
2212use cust::memory::DeviceBuffer;
2213#[cfg(all(feature = "python", feature = "cuda"))]
2214use std::sync::Arc;
2215
/// Python handle (`ta_indicators.cuda.VarDeviceArrayF32`) for a VAR result held
/// in CUDA device memory as an f32 `rows x cols` matrix. It is exported to other
/// libraries through `__cuda_array_interface__` or DLPack.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", name = "VarDeviceArrayF32", unsendable)]
pub struct VarDeviceArrayF32Py {
    // Device buffer plus its (rows, cols) shape.
    pub(crate) inner: DeviceArrayF32,
    // Never read here; presumably held so the CUDA context outlives `inner`. TODO confirm.
    pub(crate) _ctx: Arc<Context>,
    // CUDA device ordinal reported by `__dlpack_device__`.
    pub(crate) device_id: u32,
}
2223
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl VarDeviceArrayF32Py {
    /// CUDA Array Interface (version 3) dict describing the device buffer as a
    /// `(rows, cols)` float32 matrix.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let d = PyDict::new(py);
        d.set_item("shape", (self.inner.rows, self.inner.cols))?;
        // Little-endian 4-byte float.
        d.set_item("typestr", "<f4")?;
        // Byte strides of a dense row-major layout: a row spans `cols` f32 values.
        d.set_item(
            "strides",
            (
                self.inner.cols * std::mem::size_of::<f32>(),
                std::mem::size_of::<f32>(),
            ),
        )?;
        // (device pointer, read-only flag); the buffer is advertised as writable.
        d.set_item("data", (self.inner.device_ptr() as usize, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    /// DLPack device tuple: `(2, device_id)`, where 2 is the DLPack CUDA device type.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.device_id as i32)
    }

    /// Exports the buffer as a DLPack capsule, transferring ownership.
    ///
    /// After this call the handle holds an empty 0x0 buffer, so a second export
    /// yields an empty tensor. `dl_device`, if given and extractable as
    /// `(i32, i32)`, must match `__dlpack_device__`. `copy=True` is rejected.
    /// `stream` is currently ignored (no synchronization is performed here).
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        mut slf: pyo3::PyRefMut<'py, Self>,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        use cust::memory::DeviceBuffer;

        let (expected_type, expected_dev) = slf.__dlpack_device__();
        // A `dl_device` that cannot be extracted as a tuple is silently ignored.
        if let Some(dev_obj) = dl_device.as_ref() {
            if let Ok((dev_type, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_type != expected_type || dev_id != expected_dev {
                    return Err(PyValueError::new_err("dl_device mismatch for VAR buffer"));
                }
            }
        }

        // A non-bool `copy` value is treated as `False`.
        let wants_copy = copy
            .as_ref()
            .and_then(|c| c.extract::<bool>(py).ok())
            .unwrap_or(false);
        if wants_copy {
            return Err(PyValueError::new_err(
                "copy=True not supported for VAR DLPack export",
            ));
        }

        let _ = stream;

        // Capture the shape before swapping the buffer out; the replacement is an
        // empty 0x0 placeholder so `self` stays in a valid state.
        let dummy =
            DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let rows = slf.inner.rows;
        let cols = slf.inner.cols;
        let inner = std::mem::replace(
            &mut slf.inner,
            DeviceArrayF32 {
                buf: dummy,
                rows: 0,
                cols: 0,
            },
        );

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, inner.buf, rows, cols, expected_dev, max_version_bound)
    }
}
2299
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "var_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, nbdev_range=(1.0,1.0,0.0), device_id=0))]
/// Python binding: VAR parameter sweep on a CUDA device, with the result left in
/// device memory.
///
/// The f32 `nbdev_range` is widened to f64 for the sweep description. GPU work
/// runs with the GIL released. Raises `ValueError` if CUDA is unavailable or the
/// device computation fails.
pub fn var_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: numpy::PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    nbdev_range: (f32, f32, f32),
    device_id: usize,
) -> PyResult<VarDeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let prices = data_f32.as_slice()?;
    let (nb_start, nb_end, nb_step) = nbdev_range;
    let sweep = VarBatchRange {
        period: period_range,
        nbdev: (f64::from(nb_start), f64::from(nb_end), f64::from(nb_step)),
    };

    let (inner, ctx, dev_id) = py.allow_threads(|| {
        let cuda = CudaVar::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let context = cuda.context_arc();
        let ordinal = cuda.device_id();
        let produced = cuda
            .var_batch_dev(prices, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((produced.0, context, ordinal))
    })?;

    Ok(VarDeviceArrayF32Py {
        inner,
        _ctx: ctx,
        device_id: dev_id,
    })
}
2336
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "var_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, cols, rows, period, nbdev=1.0, device_id=0))]
/// Python binding: VAR with a single `(period, nbdev)` over many series on a CUDA
/// device, with the result left in device memory.
///
/// `data_tm_f32` is a flat f32 buffer in time-major order, described by `cols`
/// and `rows`. Presumably element `(t, s)` is at `t * cols + s`; the layout is
/// interpreted by `CudaVar` (TODO confirm). GPU work runs with the GIL released.
/// Raises `ValueError` if CUDA is unavailable or the device computation fails.
pub fn var_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    period: usize,
    nbdev: f64,
    device_id: usize,
) -> PyResult<VarDeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let slice_tm = data_tm_f32.as_slice()?;
    let params = VarParams {
        period: Some(period),
        nbdev: Some(nbdev),
    };
    let (inner, ctx, dev_id) = py.allow_threads(|| {
        let cuda = CudaVar::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id();
        // Pass the parameters by reference (the previous text had a mis-encoded
        // `&params` token that did not compile).
        cuda.var_many_series_one_param_time_major_dev(slice_tm, cols, rows, &params)
            .map(|h| (h, ctx, dev_id))
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;
    Ok(VarDeviceArrayF32Py {
        inner,
        _ctx: ctx,
        device_id: dev_id,
    })
}
2371
2372#[inline(always)]
2373fn var_batch_inner_into(
2374 data: &[f64],
2375 sweep: &VarBatchRange,
2376 kern: Kernel,
2377 parallel: bool,
2378 out: &mut [f64],
2379) -> Result<Vec<VarParams>, VarError> {
2380 if data.is_empty() {
2381 return Err(VarError::EmptyInputData);
2382 }
2383 let combos = expand_grid(sweep);
2384 if combos.is_empty() {
2385 return Err(VarError::InvalidRange {
2386 start: sweep.period.0 as f64,
2387 end: sweep.period.1 as f64,
2388 step: sweep.period.2 as f64,
2389 });
2390 }
2391 let first = data
2392 .iter()
2393 .position(|x| !x.is_nan())
2394 .ok_or(VarError::AllValuesNaN)?;
2395 let max_period = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
2396 if data.len() - first < max_period {
2397 return Err(VarError::NotEnoughValidData {
2398 needed: max_period,
2399 valid: data.len() - first,
2400 });
2401 }
2402 let stride = round_up8(max_period);
2403 let cols = data.len();
2404
2405 if ENABLE_VAR_BATCH_PREFIX_SUMS && matches!(kern, Kernel::Scalar) {
2406 let (ps, psq) = make_prefix_sums(data, first);
2407 let compute_row = |row: usize, out_row: &mut [f64]| {
2408 let period = combos[row].period.unwrap();
2409 let inv_p = 1.0 / (period as f64);
2410 let nbdev2 = combos[row].nbdev.unwrap() * combos[row].nbdev.unwrap();
2411 let start = first + period - 1;
2412 for i in start..cols {
2413 let s = ps[i + 1] - ps[i + 1 - period];
2414 let s2 = psq[i + 1] - psq[i + 1 - period];
2415 let m = s * inv_p;
2416 out_row[i] = (s2 * inv_p - m * m) * nbdev2;
2417 }
2418 };
2419 if parallel {
2420 #[cfg(not(target_arch = "wasm32"))]
2421 {
2422 out.par_chunks_mut(cols)
2423 .enumerate()
2424 .for_each(|(row, slice)| compute_row(row, slice));
2425 }
2426 #[cfg(target_arch = "wasm32")]
2427 {
2428 for (row, slice) in out.chunks_mut(cols).enumerate() {
2429 compute_row(row, slice);
2430 }
2431 }
2432 } else {
2433 for (row, slice) in out.chunks_mut(cols).enumerate() {
2434 compute_row(row, slice);
2435 }
2436 }
2437 } else {
2438 let do_row = |row: usize, out_row: &mut [f64]| unsafe {
2439 let period = combos[row].period.unwrap();
2440 let nbdev = combos[row].nbdev.unwrap();
2441 match kern {
2442 Kernel::Scalar => var_row_scalar(data, first, period, stride, nbdev, out_row),
2443 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2444 Kernel::Avx2 => var_row_avx2(data, first, period, stride, nbdev, out_row),
2445 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2446 Kernel::Avx512 => var_row_avx512(data, first, period, stride, nbdev, out_row),
2447 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
2448 Kernel::Avx2 | Kernel::Avx512 => {
2449 var_row_scalar(data, first, period, stride, nbdev, out_row)
2450 }
2451 _ => unreachable!(),
2452 }
2453 };
2454 if parallel {
2455 #[cfg(not(target_arch = "wasm32"))]
2456 {
2457 out.par_chunks_mut(cols)
2458 .enumerate()
2459 .for_each(|(row, slice)| do_row(row, slice));
2460 }
2461 #[cfg(target_arch = "wasm32")]
2462 {
2463 for (row, slice) in out.chunks_mut(cols).enumerate() {
2464 do_row(row, slice);
2465 }
2466 }
2467 } else {
2468 for (row, slice) in out.chunks_mut(cols).enumerate() {
2469 do_row(row, slice);
2470 }
2471 }
2472 }
2473
2474 Ok(combos)
2475}