1#[cfg(feature = "python")]
2use crate::utilities::kernel_validation::validate_kernel;
3#[cfg(all(feature = "python", feature = "cuda"))]
4use numpy::PyUntypedArrayMethods;
5#[cfg(feature = "python")]
6use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
7#[cfg(feature = "python")]
8use pyo3::exceptions::PyValueError;
9#[cfg(feature = "python")]
10use pyo3::prelude::*;
11#[cfg(feature = "python")]
12use pyo3::types::PyDict;
13
14#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
15use serde::{Deserialize, Serialize};
16#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
17use wasm_bindgen::prelude::*;
18
19use crate::utilities::data_loader::{source_type, Candles};
20#[cfg(all(feature = "python", feature = "cuda"))]
21use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
22use crate::utilities::enums::Kernel;
23use crate::utilities::helpers::{
24 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
25 make_uninit_matrix,
26};
27#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
28use core::arch::x86_64::*;
29#[cfg(all(feature = "python", feature = "cuda"))]
30use cust::context::Context;
31#[cfg(all(feature = "python", feature = "cuda"))]
32use cust::memory::DeviceBuffer;
33#[cfg(not(target_arch = "wasm32"))]
34use rayon::prelude::*;
35use std::convert::AsRef;
36use std::error::Error;
37#[cfg(all(feature = "python", feature = "cuda"))]
38use std::sync::Arc;
39use thiserror::Error;
40
41impl<'a> AsRef<[f64]> for VossInput<'a> {
42 #[inline(always)]
43 fn as_ref(&self) -> &[f64] {
44 match &self.data {
45 VossData::Slice(slice) => slice,
46 VossData::Candles { candles, source } => source_type(candles, source),
47 }
48 }
49}
50
/// Where the input series for a Voss computation comes from.
#[derive(Debug, Clone)]
pub enum VossData<'a> {
    /// A candle set together with the source name that is passed to
    /// `source_type` to obtain the actual series.
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A caller-provided series that is used as-is.
    Slice(&'a [f64]),
}
59
/// Result of a single-series Voss computation.
///
/// Both vectors are allocated with the same length as the input series.
#[derive(Debug, Clone)]
pub struct VossOutput {
    /// The Voss series, derived from `filt` minus a weighted sum of its own
    /// recent history.
    pub voss: Vec<f64>,
    /// The two-pole recursive filter series that feeds `voss`.
    pub filt: Vec<f64>,
}
65
/// Tunable parameters of the Voss computation.
///
/// Every field is optional; consumers in this file substitute 20, 3 and 0.25
/// respectively when a field is `None`.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct VossParams {
    /// Filter period; must be non-zero and no longer than the input.
    pub period: Option<usize>,
    /// Prediction parameter; the feedback order is `3 * predict`.
    pub predict: Option<usize>,
    /// Bandwidth factor applied to the angular frequency `2*PI / period`.
    pub bandwidth: Option<f64>,
}
76
77impl Default for VossParams {
78 fn default() -> Self {
79 Self {
80 period: Some(20),
81 predict: Some(3),
82 bandwidth: Some(0.25),
83 }
84 }
85}
86
/// A complete request for a Voss computation: the data plus its parameters.
#[derive(Debug, Clone)]
pub struct VossInput<'a> {
    /// The series to process (candles + source name, or a raw slice).
    pub data: VossData<'a>,
    /// Parameters; unset fields fall back to the defaults in the getters.
    pub params: VossParams,
}
92
93impl<'a> VossInput<'a> {
94 #[inline]
95 pub fn from_candles(c: &'a Candles, s: &'a str, p: VossParams) -> Self {
96 Self {
97 data: VossData::Candles {
98 candles: c,
99 source: s,
100 },
101 params: p,
102 }
103 }
104 #[inline]
105 pub fn from_slice(sl: &'a [f64], p: VossParams) -> Self {
106 Self {
107 data: VossData::Slice(sl),
108 params: p,
109 }
110 }
111 #[inline]
112 pub fn with_default_candles(c: &'a Candles) -> Self {
113 Self::from_candles(c, "close", VossParams::default())
114 }
115 #[inline]
116 pub fn get_period(&self) -> usize {
117 self.params.period.unwrap_or(20)
118 }
119 #[inline]
120 pub fn get_predict(&self) -> usize {
121 self.params.predict.unwrap_or(3)
122 }
123 #[inline]
124 pub fn get_bandwidth(&self) -> f64 {
125 self.params.bandwidth.unwrap_or(0.25)
126 }
127}
128
/// Fluent builder for single-series and streaming Voss computations.
///
/// Unset parameters stay `None` and are resolved to defaults downstream.
#[derive(Copy, Clone, Debug)]
pub struct VossBuilder {
    /// Optional period override.
    period: Option<usize>,
    /// Optional predict override.
    predict: Option<usize>,
    /// Optional bandwidth override.
    bandwidth: Option<f64>,
    /// Kernel handed to `voss_with_kernel` by the `apply*` methods.
    kernel: Kernel,
}
136
137impl Default for VossBuilder {
138 fn default() -> Self {
139 Self {
140 period: None,
141 predict: None,
142 bandwidth: None,
143 kernel: Kernel::Auto,
144 }
145 }
146}
147
148impl VossBuilder {
149 #[inline(always)]
150 pub fn new() -> Self {
151 Self::default()
152 }
153 #[inline(always)]
154 pub fn period(mut self, n: usize) -> Self {
155 self.period = Some(n);
156 self
157 }
158 #[inline(always)]
159 pub fn predict(mut self, n: usize) -> Self {
160 self.predict = Some(n);
161 self
162 }
163 #[inline(always)]
164 pub fn bandwidth(mut self, b: f64) -> Self {
165 self.bandwidth = Some(b);
166 self
167 }
168 #[inline(always)]
169 pub fn kernel(mut self, k: Kernel) -> Self {
170 self.kernel = k;
171 self
172 }
173
174 #[inline(always)]
175 pub fn apply(self, c: &Candles) -> Result<VossOutput, VossError> {
176 let p = VossParams {
177 period: self.period,
178 predict: self.predict,
179 bandwidth: self.bandwidth,
180 };
181 let i = VossInput::from_candles(c, "close", p);
182 voss_with_kernel(&i, self.kernel)
183 }
184 #[inline(always)]
185 pub fn apply_slice(self, d: &[f64]) -> Result<VossOutput, VossError> {
186 let p = VossParams {
187 period: self.period,
188 predict: self.predict,
189 bandwidth: self.bandwidth,
190 };
191 let i = VossInput::from_slice(d, p);
192 voss_with_kernel(&i, self.kernel)
193 }
194 #[inline(always)]
195 pub fn into_stream(self) -> Result<VossStream, VossError> {
196 let p = VossParams {
197 period: self.period,
198 predict: self.predict,
199 bandwidth: self.bandwidth,
200 };
201 VossStream::try_new(p)
202 }
203}
204
/// Errors reported by the Voss entry points in this file.
#[derive(Debug, Error)]
pub enum VossError {
    /// The input series has length zero.
    #[error("voss: Input data slice is empty.")]
    EmptyInputData,
    /// No non-NaN value was found in the input series.
    #[error("voss: All values are NaN.")]
    AllValuesNaN,
    /// The period is zero or longer than the input series.
    #[error("voss: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },
    /// Fewer values remain after the first non-NaN element than the
    /// computation needs.
    #[error("voss: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// A caller-provided output buffer does not have the expected length.
    #[error("voss: Output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// A batch sweep range could not be expanded, or a size computation
    /// overflowed (the fields then carry descriptive labels, not numbers).
    #[error("voss: Invalid range: start = {start}, end = {end}, step = {step}")]
    InvalidRange {
        start: String,
        end: String,
        step: String,
    },
    /// A non-batch kernel was passed to the batch entry point.
    #[error("voss: Invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(Kernel),
}
226
/// Computes the Voss and filter series for `input` with automatic kernel
/// selection.
///
/// # Errors
/// Returns whatever `voss_with_kernel` reports for the given input.
#[inline]
pub fn voss(input: &VossInput) -> Result<VossOutput, VossError> {
    voss_with_kernel(input, Kernel::Auto)
}
231
232#[inline(always)]
233fn voss_prepare<'a>(
234 input: &'a VossInput,
235 kernel: Kernel,
236) -> Result<(&'a [f64], usize, usize, f64, usize, usize, Kernel), VossError> {
237 let data: &[f64] = match &input.data {
238 VossData::Candles { candles, source } => source_type(candles, source),
239 VossData::Slice(sl) => sl,
240 };
241
242 if data.is_empty() {
243 return Err(VossError::EmptyInputData);
244 }
245
246 let first = data
247 .iter()
248 .position(|x| !x.is_nan())
249 .ok_or(VossError::AllValuesNaN)?;
250 let len = data.len();
251 let period = input.get_period();
252 let predict = input.get_predict();
253 let bandwidth = input.get_bandwidth();
254
255 if period == 0 || period > len {
256 return Err(VossError::InvalidPeriod {
257 period,
258 data_len: len,
259 });
260 }
261
262 let order = 3 * predict;
263 let min_index = period.max(5).max(order);
264 if (len - first) < min_index {
265 return Err(VossError::NotEnoughValidData {
266 needed: min_index,
267 valid: len - first,
268 });
269 }
270
271 let chosen = match kernel {
272 Kernel::Auto => detect_best_kernel(),
273 other => other,
274 };
275
276 Ok((data, period, predict, bandwidth, first, min_index, chosen))
277}
278
279#[inline(always)]
280fn voss_compute_into(
281 data: &[f64],
282 period: usize,
283 predict: usize,
284 bandwidth: f64,
285 first: usize,
286 min_index: usize,
287 kernel: Kernel,
288 voss: &mut [f64],
289 filt: &mut [f64],
290) {
291 #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
292 {
293 if matches!(kernel, Kernel::Scalar | Kernel::ScalarBatch) {
294 unsafe {
295 voss_simd128(
296 data, period, predict, bandwidth, first, min_index, voss, filt,
297 );
298 }
299 return;
300 }
301 }
302
303 match kernel {
304 Kernel::Scalar | Kernel::ScalarBatch => {
305 voss_scalar(data, period, predict, bandwidth, first, voss, filt)
306 }
307 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
308 Kernel::Avx2 | Kernel::Avx2Batch => unsafe {
309 voss_avx2(data, period, predict, bandwidth, first, voss, filt)
310 },
311 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
312 Kernel::Avx512 | Kernel::Avx512Batch => unsafe {
313 voss_avx512(data, period, predict, bandwidth, first, voss, filt)
314 },
315 _ => unreachable!(),
316 }
317}
318
319pub fn voss_with_kernel(input: &VossInput, kernel: Kernel) -> Result<VossOutput, VossError> {
320 let (data, period, predict, bandwidth, first, min_index, chosen) = voss_prepare(input, kernel)?;
321
322 let warmup_period = first + min_index;
323 let mut voss = alloc_with_nan_prefix(data.len(), warmup_period);
324 let mut filt = alloc_with_nan_prefix(data.len(), warmup_period);
325
326 voss_compute_into(
327 data, period, predict, bandwidth, first, min_index, chosen, &mut voss, &mut filt,
328 );
329
330 Ok(VossOutput { voss, filt })
331}
332
333#[inline]
334pub fn voss_into_slice(
335 voss_dst: &mut [f64],
336 filt_dst: &mut [f64],
337 input: &VossInput,
338 kern: Kernel,
339) -> Result<(), VossError> {
340 let (data, period, predict, bandwidth, first, min_index, chosen) = voss_prepare(input, kern)?;
341
342 if voss_dst.len() != data.len() {
343 return Err(VossError::OutputLengthMismatch {
344 expected: data.len(),
345 got: voss_dst.len(),
346 });
347 }
348 if filt_dst.len() != data.len() {
349 return Err(VossError::OutputLengthMismatch {
350 expected: data.len(),
351 got: filt_dst.len(),
352 });
353 }
354
355 let warmup_end = first + min_index;
356 let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
357 let v_end = if warmup_end < voss_dst.len() {
358 warmup_end
359 } else {
360 voss_dst.len()
361 };
362 for v in &mut voss_dst[..v_end] {
363 *v = qnan;
364 }
365 let f_end = if warmup_end < filt_dst.len() {
366 warmup_end
367 } else {
368 filt_dst.len()
369 };
370 for v in &mut filt_dst[..f_end] {
371 *v = qnan;
372 }
373
374 voss_compute_into(
375 data, period, predict, bandwidth, first, min_index, chosen, voss_dst, filt_dst,
376 );
377
378 Ok(())
379}
380
/// Computes the Voss and filter series into `voss_out` / `filt_out` with
/// automatic kernel selection.
///
/// Not compiled for wasm32 builds that enable the `wasm` feature.
///
/// # Errors
/// Returns whatever `voss_into_slice` reports (validation failures or an
/// output length mismatch).
#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
#[inline]
pub fn voss_into(
    input: &VossInput,
    voss_out: &mut [f64],
    filt_out: &mut [f64],
) -> Result<(), VossError> {
    // Note the argument order: the slice-based API takes the outputs first.
    voss_into_slice(voss_out, filt_out, input, Kernel::Auto)
}
390
/// Scalar Voss kernel for one series.
///
/// Forwards to `voss_row_scalar`; only the parameter order differs
/// (`first` moves to the second position).
#[inline]
pub fn voss_scalar(
    data: &[f64],
    period: usize,
    predict: usize,
    bandwidth: f64,
    first: usize,
    voss: &mut [f64],
    filt: &mut [f64],
) {
    voss_row_scalar(data, first, period, predict, bandwidth, voss, filt)
}
403
/// wasm32 `simd128` variant of the Voss kernel.
///
/// Runs in two passes: first the whole filter series, then the Voss series,
/// where each output re-evaluates the weighted sum over the previous `order`
/// Voss values two lanes at a time.
///
/// The `min_index` argument is not read; the same quantity is recomputed
/// locally as `min_idx`.
///
/// # Safety
/// Uses wasm SIMD intrinsics, so the caller must ensure `simd128` is
/// available. All slice accesses are bounds-checked; `voss` and `filt` are
/// indexed up to `data.len() - 1`.
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
#[inline]
unsafe fn voss_simd128(
    data: &[f64],
    period: usize,
    predict: usize,
    bandwidth: f64,
    first: usize,
    min_index: usize,
    voss: &mut [f64],
    filt: &mut [f64],
) {
    use core::arch::wasm32::*;
    use std::f64::consts::PI;

    let order = 3 * predict;
    let min_idx = period.max(5).max(order);

    // Filter coefficients derived from the period and bandwidth.
    let f1 = (2.0 * PI / period as f64).cos();
    let g1 = (bandwidth * 2.0 * PI / period as f64).cos();
    let s1 = 1.0 / g1 - (1.0 / (g1 * g1) - 1.0).sqrt();

    // Seed the two filter values the recurrence reads before `start_idx`.
    let start_idx = first + min_idx;
    if start_idx >= 2 {
        filt[start_idx - 2] = 0.0;
        filt[start_idx - 1] = 0.0;
    }

    let coeff1 = 0.5 * (1.0 - s1);
    let coeff2 = f1 * (1.0 + s1);
    let coeff3 = -s1;

    // Pass 1: two-pole recurrence on the difference data[i] - data[i - 2].
    for i in start_idx..data.len() {
        let current = data[i];
        let prev_2 = data[i - 2];
        let prev_filt_1 = filt[i - 1];
        let prev_filt_2 = filt[i - 2];
        filt[i] = coeff1 * (current - prev_2) + coeff2 * prev_filt_1 + coeff3 * prev_filt_2;
    }

    let scale = (3 + order) as f64 / 2.0;
    let inv_order = 1.0 / order as f64;

    // Pass 2: voss[i] = scale * filt[i] - sum_k ((k + 1) / order) * voss[i - order + k].
    for i in start_idx..data.len() {
        let mut sum = 0.0;
        let base_idx = i - order;

        // Process the history two entries per iteration.
        let pairs = order / 2;
        for j in 0..pairs {
            let idx1 = base_idx + j * 2;
            let idx2 = base_idx + j * 2 + 1;

            // Entries before `first`, or still NaN (warm-up), count as zero.
            let voss_val1 = if idx1 >= first && !voss[idx1].is_nan() {
                voss[idx1]
            } else {
                0.0
            };
            let voss_val2 = if idx2 >= first && !voss[idx2].is_nan() {
                voss[idx2]
            } else {
                0.0
            };

            // Pack the two history values into one f64x2 vector.
            let v1 = v128_load64_splat(&voss_val1 as *const f64 as *const u64);
            let v2 = v128_load64_splat(&voss_val2 as *const f64 as *const u64);
            let vals = f64x2_replace_lane::<1>(v1, f64x2_extract_lane::<0>(v2));

            let w1 = (j * 2 + 1) as f64 * inv_order;
            let w2 = (j * 2 + 2) as f64 * inv_order;
            let weights = f64x2(w1, w2);

            let prod = f64x2_mul(vals, weights);
            sum += f64x2_extract_lane::<0>(prod) + f64x2_extract_lane::<1>(prod);
        }

        // Odd order: the most recent history entry is left over.
        if order % 2 != 0 {
            let idx = base_idx + order - 1;
            let voss_val = if idx >= first && !voss[idx].is_nan() {
                voss[idx]
            } else {
                0.0
            };
            sum += order as f64 * inv_order * voss_val;
        }

        voss[i] = scale * filt[i] - sum;
    }
}
492
/// Kernel selected for `Kernel::Avx2` / `Kernel::Avx2Batch`.
///
/// Forwards to `voss_row_scalar_unchecked`; no AVX2 intrinsics are used in
/// this function itself.
///
/// # Safety
/// The callee performs unchecked indexing, so `voss` and `filt` must each be
/// at least `data.len()` elements long.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn voss_avx2(
    data: &[f64],
    period: usize,
    predict: usize,
    bandwidth: f64,
    first: usize,
    voss: &mut [f64],
    filt: &mut [f64],
) {
    voss_row_scalar_unchecked(data, first, period, predict, bandwidth, voss, filt)
}
506
/// Kernel selected for `Kernel::Avx512` / `Kernel::Avx512Batch`.
///
/// Periods up to and including 32 go to the "short" variant, anything longer
/// to the "long" variant.
///
/// # Safety
/// Same contract as the variants it forwards to: `voss` and `filt` must each
/// be at least `data.len()` elements long.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn voss_avx512(
    data: &[f64],
    period: usize,
    predict: usize,
    bandwidth: f64,
    first: usize,
    voss: &mut [f64],
    filt: &mut [f64],
) {
    match period {
        0..=32 => voss_avx512_short(data, period, predict, bandwidth, first, voss, filt),
        _ => voss_avx512_long(data, period, predict, bandwidth, first, voss, filt),
    }
}
524
/// AVX-512 variant used by `voss_avx512` for periods of at most 32.
///
/// Forwards to `voss_row_scalar_unchecked`; no AVX-512 intrinsics are used
/// in this function itself.
///
/// # Safety
/// The callee performs unchecked indexing, so `voss` and `filt` must each be
/// at least `data.len()` elements long.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn voss_avx512_short(
    data: &[f64],
    period: usize,
    predict: usize,
    bandwidth: f64,
    first: usize,
    voss: &mut [f64],
    filt: &mut [f64],
) {
    voss_row_scalar_unchecked(data, first, period, predict, bandwidth, voss, filt)
}
538
/// AVX-512 variant used by `voss_avx512` for periods above 32.
///
/// Forwards to `voss_row_scalar_unchecked`; no AVX-512 intrinsics are used
/// in this function itself.
///
/// # Safety
/// The callee performs unchecked indexing, so `voss` and `filt` must each be
/// at least `data.len()` elements long.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn voss_avx512_long(
    data: &[f64],
    period: usize,
    predict: usize,
    bandwidth: f64,
    first: usize,
    voss: &mut [f64],
    filt: &mut [f64],
) {
    voss_row_scalar_unchecked(data, first, period, predict, bandwidth, voss, filt)
}
552
/// Incremental (one value at a time) Voss calculator.
///
/// Created through `VossStream::try_new`; fed through `update`.
#[derive(Debug, Clone)]
pub struct VossStream {
    // Resolved parameters, kept as given.
    period: usize,
    predict: usize,
    bandwidth: f64,

    // Filter coefficients derived from the parameters at construction.
    f1: f64,
    s1: f64,
    c1: f64,
    c2: f64,
    c3: f64,
    // Number of past Voss values in the feedback sum (3 * predict).
    order: usize,
    // Remaining calls to `update` that will return `None`.
    warm_left: usize,

    // Last two filter outputs.
    prev_f1: f64,
    prev_f2: f64,
    // Last two inputs; NaN until two values have been seen.
    last_x1: f64,
    last_x2: f64,

    // Ring buffer of the last `order` Voss values (NaN stored as 0.0).
    ring: Vec<f64>,
    // Next write position inside `ring`.
    rpos: usize,
    // Plain sum of the ring contents.
    a_sum: f64,
    // Position-weighted sum of the ring contents, maintained incrementally.
    d_sum: f64,
    // `order` as f64, and its reciprocal (0.0 when `order` is zero).
    ord_f: f64,
    inv_order: f64,
    // 0.5 * (3 + order).
    scale: f64,

    // Set once `update` has returned its first `Some`.
    filled: bool,
}
582
583impl VossStream {
584 pub fn try_new(params: VossParams) -> Result<Self, VossError> {
585 use std::f64::consts::PI;
586 let period = params.period.unwrap_or(20);
587 if period == 0 {
588 return Err(VossError::InvalidPeriod {
589 period,
590 data_len: 0,
591 });
592 }
593 let predict = params.predict.unwrap_or(3);
594 let bandwidth = params.bandwidth.unwrap_or(0.25);
595
596 let order = 3 * predict;
597 let min_index = period.max(5).max(order);
598
599 let w0 = 2.0 * PI / period as f64;
600 let f1 = w0.cos();
601 let g1 = (bandwidth * w0).cos();
602
603 let inv_g1 = 1.0 / g1;
604 let root = (inv_g1.mul_add(inv_g1, -1.0)).sqrt();
605 let s1 = 1.0 / (inv_g1 + root);
606
607 Ok(Self {
608 period,
609 predict,
610 bandwidth,
611 f1,
612 s1,
613 c1: 0.5 * (1.0 - s1),
614 c2: f1 * (1.0 + s1),
615 c3: -s1,
616 order,
617 warm_left: min_index,
618 prev_f1: 0.0,
619 prev_f2: 0.0,
620 last_x1: f64::NAN,
621 last_x2: f64::NAN,
622 ring: if order > 0 {
623 vec![0.0; order]
624 } else {
625 Vec::new()
626 },
627 rpos: 0,
628 a_sum: 0.0,
629 d_sum: 0.0,
630 ord_f: order as f64,
631 inv_order: if order > 0 { 1.0 / order as f64 } else { 0.0 },
632 scale: 0.5 * (3 + order) as f64,
633
634 filled: false,
635 })
636 }
637
638 #[inline(always)]
639 pub fn update(&mut self, x: f64) -> Option<(f64, f64)> {
640 let x_im2 = self.last_x2;
641 self.last_x2 = self.last_x1;
642 self.last_x1 = x;
643
644 let filt_val = if x_im2.is_nan() {
645 0.0
646 } else {
647 let diff = x - x_im2;
648 let t = self.c1 * diff + self.c3 * self.prev_f2;
649 let f = self.c2.mul_add(self.prev_f1, t);
650 self.prev_f2 = self.prev_f1;
651 self.prev_f1 = f;
652 f
653 };
654
655 let voss_val = if self.order == 0 {
656 self.scale * filt_val
657 } else {
658 let sumc = self.d_sum * self.inv_order;
659 let vi = self.scale.mul_add(filt_val, -sumc);
660
661 let v_new_nz = if vi.is_nan() { 0.0 } else { vi };
662 let v_old = self.ring[self.rpos];
663 let a_prev = self.a_sum;
664 self.a_sum = a_prev - v_old + v_new_nz;
665 self.d_sum = self.ord_f.mul_add(v_new_nz, self.d_sum - a_prev);
666
667 self.ring[self.rpos] = v_new_nz;
668 self.rpos += 1;
669 if self.rpos == self.order {
670 self.rpos = 0;
671 }
672
673 vi
674 };
675
676 if self.warm_left > 0 {
677 self.warm_left -= 1;
678 return None;
679 }
680 self.filled = true;
681 Some((voss_val, filt_val))
682 }
683}
684
/// Parameter sweep for batch computations.
///
/// Each axis is a `(start, end, step)` triple; a zero step (or `start == end`)
/// yields just `start`.
#[derive(Clone, Debug)]
pub struct VossBatchRange {
    /// Period axis.
    pub period: (usize, usize, usize),
    /// Predict axis.
    pub predict: (usize, usize, usize),
    /// Bandwidth axis.
    pub bandwidth: (f64, f64, f64),
}
691
692impl Default for VossBatchRange {
693 fn default() -> Self {
694 Self {
695 period: (20, 269, 1),
696 predict: (3, 3, 0),
697 bandwidth: (0.25, 0.25, 0.0),
698 }
699 }
700}
701
/// Fluent builder for batch (parameter sweep) Voss computations.
#[derive(Clone, Debug, Default)]
pub struct VossBatchBuilder {
    /// The sweep to evaluate.
    range: VossBatchRange,
    /// Kernel handed to `voss_batch_with_kernel`.
    kernel: Kernel,
}
707
708impl VossBatchBuilder {
709 pub fn new() -> Self {
710 Self::default()
711 }
712 pub fn kernel(mut self, k: Kernel) -> Self {
713 self.kernel = k;
714 self
715 }
716 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
717 self.range.period = (start, end, step);
718 self
719 }
720 pub fn predict_range(mut self, start: usize, end: usize, step: usize) -> Self {
721 self.range.predict = (start, end, step);
722 self
723 }
724 pub fn bandwidth_range(mut self, start: f64, end: f64, step: f64) -> Self {
725 self.range.bandwidth = (start, end, step);
726 self
727 }
728
729 pub fn apply_slice(self, data: &[f64]) -> Result<VossBatchOutput, VossError> {
730 voss_batch_with_kernel(data, &self.range, self.kernel)
731 }
732
733 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<VossBatchOutput, VossError> {
734 let slice = source_type(c, src);
735 self.apply_slice(slice)
736 }
737}
738
739pub fn voss_batch_with_kernel(
740 data: &[f64],
741 sweep: &VossBatchRange,
742 k: Kernel,
743) -> Result<VossBatchOutput, VossError> {
744 let kernel = match k {
745 Kernel::Auto => detect_best_batch_kernel(),
746 other if other.is_batch() => other,
747 other => return Err(VossError::InvalidKernelForBatch(other)),
748 };
749 let simd = match kernel {
750 Kernel::Avx512Batch => Kernel::Avx512,
751 Kernel::Avx2Batch => Kernel::Avx2,
752 Kernel::ScalarBatch => Kernel::Scalar,
753 _ => unreachable!(),
754 };
755 voss_batch_par_slice(data, sweep, simd)
756}
757
/// Result of a batch sweep, stored as two row-major `rows x cols` matrices.
#[derive(Clone, Debug)]
pub struct VossBatchOutput {
    /// Voss series, one row per parameter combination.
    pub voss: Vec<f64>,
    /// Filter series, one row per parameter combination.
    pub filt: Vec<f64>,
    /// Parameter combination of each row, in row order.
    pub combos: Vec<VossParams>,
    /// Number of parameter combinations.
    pub rows: usize,
    /// Length of the input series.
    pub cols: usize,
}
766impl VossBatchOutput {
767 pub fn row_for_params(&self, p: &VossParams) -> Option<usize> {
768 self.combos.iter().position(|c| {
769 c.period.unwrap_or(20) == p.period.unwrap_or(20)
770 && c.predict.unwrap_or(3) == p.predict.unwrap_or(3)
771 && (c.bandwidth.unwrap_or(0.25) - p.bandwidth.unwrap_or(0.25)).abs() < 1e-12
772 })
773 }
774 pub fn voss_for(&self, p: &VossParams) -> Option<&[f64]> {
775 self.row_for_params(p).map(|row| {
776 let start = row * self.cols;
777 &self.voss[start..start + self.cols]
778 })
779 }
780 pub fn filt_for(&self, p: &VossParams) -> Option<&[f64]> {
781 self.row_for_params(p).map(|row| {
782 let start = row * self.cols;
783 &self.filt[start..start + self.cols]
784 })
785 }
786}
787
788#[inline(always)]
789fn expand_grid(r: &VossBatchRange) -> Result<Vec<VossParams>, VossError> {
790 fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, VossError> {
791 if step == 0 || start == end {
792 return Ok(vec![start]);
793 }
794 let mut out = Vec::new();
795 if start < end {
796 let mut v = start;
797 while v <= end {
798 out.push(v);
799 match v.checked_add(step) {
800 Some(next) => {
801 if next == v {
802 break;
803 }
804 v = next;
805 }
806 None => break,
807 }
808 }
809 } else {
810 let mut v = start as isize;
811 let end_i = end as isize;
812 let st = (step as isize).max(1);
813 while v >= end_i {
814 out.push(v as usize);
815 v -= st;
816 }
817 }
818 if out.is_empty() {
819 return Err(VossError::InvalidRange {
820 start: start.to_string(),
821 end: end.to_string(),
822 step: step.to_string(),
823 });
824 }
825 Ok(out)
826 }
827 fn axis_f64((start, end, step): (f64, f64, f64)) -> Result<Vec<f64>, VossError> {
828 if step.abs() < 1e-12 || (start - end).abs() < 1e-12 {
829 return Ok(vec![start]);
830 }
831 if start < end {
832 let mut v = Vec::new();
833 let mut x = start;
834 let st = step.abs();
835 while x <= end + 1e-12 {
836 v.push(x);
837 x += st;
838 }
839 if v.is_empty() {
840 return Err(VossError::InvalidRange {
841 start: start.to_string(),
842 end: end.to_string(),
843 step: step.to_string(),
844 });
845 }
846 return Ok(v);
847 }
848 let mut v = Vec::new();
849 let mut x = start;
850 let st = step.abs();
851 while x + 1e-12 >= end {
852 v.push(x);
853 x -= st;
854 }
855 if v.is_empty() {
856 return Err(VossError::InvalidRange {
857 start: start.to_string(),
858 end: end.to_string(),
859 step: step.to_string(),
860 });
861 }
862 Ok(v)
863 }
864 let periods = axis_usize(r.period)?;
865 let predicts = axis_usize(r.predict)?;
866 let bandwidths = axis_f64(r.bandwidth)?;
867 let cap = periods
868 .len()
869 .checked_mul(predicts.len())
870 .and_then(|x| x.checked_mul(bandwidths.len()))
871 .ok_or_else(|| VossError::InvalidRange {
872 start: "cap".into(),
873 end: "overflow".into(),
874 step: "mul".into(),
875 })?;
876 let mut out = Vec::with_capacity(cap);
877 for &p in &periods {
878 for &q in &predicts {
879 for &b in &bandwidths {
880 out.push(VossParams {
881 period: Some(p),
882 predict: Some(q),
883 bandwidth: Some(b),
884 });
885 }
886 }
887 }
888 if out.is_empty() {
889 return Err(VossError::InvalidRange {
890 start: "combos".into(),
891 end: "empty".into(),
892 step: "voss".into(),
893 });
894 }
895 Ok(out)
896}
897
/// Evaluates the sweep over `data` row by row on the calling thread.
///
/// `kern` must be a non-batch kernel (`Scalar`, `Avx2` or `Avx512`); see
/// `voss_batch_inner` for the error conditions.
#[inline(always)]
pub fn voss_batch_slice(
    data: &[f64],
    sweep: &VossBatchRange,
    kern: Kernel,
) -> Result<VossBatchOutput, VossError> {
    voss_batch_inner(data, sweep, kern, false)
}
906
/// Evaluates the sweep over `data` with the `parallel` flag set.
///
/// Rows are processed through rayon except on wasm32, where the inner
/// routine falls back to a sequential loop. `kern` must be a non-batch
/// kernel (`Scalar`, `Avx2` or `Avx512`).
#[inline(always)]
pub fn voss_batch_par_slice(
    data: &[f64],
    sweep: &VossBatchRange,
    kern: Kernel,
) -> Result<VossBatchOutput, VossError> {
    voss_batch_inner(data, sweep, kern, true)
}
915
916#[inline(always)]
917pub fn voss_batch_inner_into(
918 data: &[f64],
919 sweep: &VossBatchRange,
920 kern: Kernel,
921 parallel: bool,
922 voss_out: &mut [f64],
923 filt_out: &mut [f64],
924) -> Result<Vec<VossParams>, VossError> {
925 let combos = expand_grid(sweep)?;
926
927 let first = data
928 .iter()
929 .position(|x| !x.is_nan())
930 .ok_or(VossError::AllValuesNaN)?;
931 let rows = combos.len();
932 let cols = data.len();
933
934 let expected = rows
935 .checked_mul(cols)
936 .ok_or_else(|| VossError::InvalidRange {
937 start: rows.to_string(),
938 end: cols.to_string(),
939 step: "rows*cols".into(),
940 })?;
941 if voss_out.len() != expected {
942 return Err(VossError::OutputLengthMismatch {
943 expected,
944 got: voss_out.len(),
945 });
946 }
947 if filt_out.len() != expected {
948 return Err(VossError::OutputLengthMismatch {
949 expected,
950 got: filt_out.len(),
951 });
952 }
953
954 let do_row = |row: usize, vo: &mut [f64], fo: &mut [f64]| {
955 let p = combos[row].period.unwrap_or(20);
956 let q = combos[row].predict.unwrap_or(3);
957 let order = 3 * q;
958 let min_index = p.max(5).max(order);
959 let warm = first + min_index;
960 for x in &mut vo[..warm.min(cols)] {
961 *x = f64::NAN;
962 }
963 for x in &mut fo[..warm.min(cols)] {
964 *x = f64::NAN;
965 }
966
967 match kern {
968 Kernel::Scalar | Kernel::ScalarBatch => voss_scalar(
969 data,
970 p,
971 q,
972 combos[row].bandwidth.unwrap_or(0.25),
973 first,
974 vo,
975 fo,
976 ),
977 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
978 Kernel::Avx2 | Kernel::Avx2Batch => unsafe {
979 voss_avx2(
980 data,
981 p,
982 q,
983 combos[row].bandwidth.unwrap_or(0.25),
984 first,
985 vo,
986 fo,
987 )
988 },
989 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
990 Kernel::Avx512 | Kernel::Avx512Batch => unsafe {
991 voss_avx512(
992 data,
993 p,
994 q,
995 combos[row].bandwidth.unwrap_or(0.25),
996 first,
997 vo,
998 fo,
999 )
1000 },
1001 Kernel::Auto => unreachable!(),
1002 _ => unreachable!(),
1003 }
1004 };
1005
1006 if parallel {
1007 #[cfg(not(target_arch = "wasm32"))]
1008 {
1009 use rayon::prelude::*;
1010 voss_out
1011 .par_chunks_mut(cols)
1012 .zip(filt_out.par_chunks_mut(cols))
1013 .enumerate()
1014 .for_each(|(row, (vo, fo))| do_row(row, vo, fo));
1015 }
1016 #[cfg(target_arch = "wasm32")]
1017 {
1018 for (row, (vo, fo)) in voss_out
1019 .chunks_mut(cols)
1020 .zip(filt_out.chunks_mut(cols))
1021 .enumerate()
1022 {
1023 do_row(row, vo, fo);
1024 }
1025 }
1026 } else {
1027 for (row, (vo, fo)) in voss_out
1028 .chunks_mut(cols)
1029 .zip(filt_out.chunks_mut(cols))
1030 .enumerate()
1031 {
1032 do_row(row, vo, fo);
1033 }
1034 }
1035
1036 Ok(combos)
1037}
1038
1039#[inline(always)]
1040fn voss_batch_inner(
1041 data: &[f64],
1042 sweep: &VossBatchRange,
1043 kern: Kernel,
1044 parallel: bool,
1045) -> Result<VossBatchOutput, VossError> {
1046 let combos = expand_grid(sweep)?;
1047
1048 let first = data
1049 .iter()
1050 .position(|x| !x.is_nan())
1051 .ok_or(VossError::AllValuesNaN)?;
1052 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1053 if data.len() - first < max_p {
1054 return Err(VossError::NotEnoughValidData {
1055 needed: max_p,
1056 valid: data.len() - first,
1057 });
1058 }
1059
1060 let rows = combos.len();
1061 let cols = data.len();
1062 let _total = rows
1063 .checked_mul(cols)
1064 .ok_or_else(|| VossError::InvalidRange {
1065 start: rows.to_string(),
1066 end: cols.to_string(),
1067 step: "rows*cols".into(),
1068 })?;
1069
1070 let warmup_periods: Vec<usize> = combos
1071 .iter()
1072 .map(|c| {
1073 let period = c.period.unwrap();
1074 let predict = c.predict.unwrap();
1075 let order = 3 * predict;
1076 let min_index = period.max(5).max(order);
1077 first + min_index
1078 })
1079 .collect();
1080
1081 let mut voss_mu = make_uninit_matrix(rows, cols);
1082 let mut filt_mu = make_uninit_matrix(rows, cols);
1083 init_matrix_prefixes(&mut voss_mu, cols, &warmup_periods);
1084 init_matrix_prefixes(&mut filt_mu, cols, &warmup_periods);
1085
1086 let mut voss_guard = core::mem::ManuallyDrop::new(voss_mu);
1087 let mut filt_guard = core::mem::ManuallyDrop::new(filt_mu);
1088 let voss_slice: &mut [f64] = unsafe {
1089 core::slice::from_raw_parts_mut(voss_guard.as_mut_ptr() as *mut f64, voss_guard.len())
1090 };
1091 let filt_slice: &mut [f64] = unsafe {
1092 core::slice::from_raw_parts_mut(filt_guard.as_mut_ptr() as *mut f64, filt_guard.len())
1093 };
1094
1095 let do_row = |row: usize, out_voss: &mut [f64], out_filt: &mut [f64]| {
1096 let prm = &combos[row];
1097 let period = prm.period.unwrap();
1098 let predict = prm.predict.unwrap();
1099 let bandwidth = prm.bandwidth.unwrap();
1100 match kern {
1101 Kernel::Scalar => {
1102 voss_row_scalar(data, first, period, predict, bandwidth, out_voss, out_filt)
1103 }
1104 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1105 Kernel::Avx2 => unsafe {
1106 voss_row_avx2(data, first, period, predict, bandwidth, out_voss, out_filt)
1107 },
1108 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1109 Kernel::Avx512 => unsafe {
1110 voss_row_avx512(data, first, period, predict, bandwidth, out_voss, out_filt)
1111 },
1112 _ => unreachable!(),
1113 }
1114 };
1115
1116 if parallel {
1117 #[cfg(not(target_arch = "wasm32"))]
1118 {
1119 voss_slice
1120 .par_chunks_mut(cols)
1121 .zip(filt_slice.par_chunks_mut(cols))
1122 .enumerate()
1123 .for_each(|(row, (vo, fo))| do_row(row, vo, fo));
1124 }
1125
1126 #[cfg(target_arch = "wasm32")]
1127 {
1128 for (row, (vo, fo)) in voss_slice
1129 .chunks_mut(cols)
1130 .zip(filt_slice.chunks_mut(cols))
1131 .enumerate()
1132 {
1133 do_row(row, vo, fo);
1134 }
1135 }
1136 } else {
1137 for (row, (vo, fo)) in voss_slice
1138 .chunks_mut(cols)
1139 .zip(filt_slice.chunks_mut(cols))
1140 .enumerate()
1141 {
1142 do_row(row, vo, fo);
1143 }
1144 }
1145
1146 let voss_vec = unsafe {
1147 let raw_vec = Vec::from_raw_parts(
1148 voss_guard.as_mut_ptr() as *mut f64,
1149 voss_guard.len(),
1150 voss_guard.capacity(),
1151 );
1152 core::mem::forget(voss_guard);
1153 raw_vec
1154 };
1155 let filt_vec = unsafe {
1156 let raw_vec = Vec::from_raw_parts(
1157 filt_guard.as_mut_ptr() as *mut f64,
1158 filt_guard.len(),
1159 filt_guard.capacity(),
1160 );
1161 core::mem::forget(filt_guard);
1162 raw_vec
1163 };
1164
1165 Ok(VossBatchOutput {
1166 voss: voss_vec,
1167 filt: filt_vec,
1168 combos,
1169 rows,
1170 cols,
1171 })
1172}
1173
/// Bounds-checked scalar Voss kernel for one output row.
///
/// Writing starts at `start = first + max(period, 5, 3 * predict)`; when
/// `start` is not below `data.len()` nothing is written. Otherwise
/// `filt[start - 2]` and `filt[start - 1]` are set to `0.0` and, for every
/// `i` in `start..data.len()`:
///
/// * `filt[i]` is the next value of a two-pole recurrence driven by
///   `data[i] - data[i - 2]` (the recurrence state starts at zero),
/// * `voss[i]` is `scale * filt[i]` minus the position-weighted mean of the
///   previous `order = 3 * predict` Voss values (NaN values count as zero),
///   with `scale = 0.5 * (3 + order)`.
///
/// # Panics
/// Panics on out-of-bounds indexing when `voss` or `filt` is shorter than
/// `data`.
#[inline(always)]
fn voss_row_scalar(
    data: &[f64],
    first: usize,
    period: usize,
    predict: usize,
    bandwidth: f64,
    voss: &mut [f64],
    filt: &mut [f64],
) {
    use std::f64::consts::PI;

    let order = 3 * predict;
    let start = first + period.max(5).max(order);
    let n = data.len();
    if start >= n {
        return;
    }

    // Filter coefficients.
    let w0 = 2.0 * PI / period as f64;
    let f1 = w0.cos();
    let g1 = (bandwidth * w0).cos();
    let s1 = 1.0 / g1 - (1.0 / (g1 * g1) - 1.0).sqrt();
    let (c1, c2, c3) = (0.5 * (1.0 - s1), f1 * (1.0 + s1), -s1);

    if start >= 2 {
        filt[start - 2] = 0.0;
        filt[start - 1] = 0.0;
    }

    let scale = 0.5 * (3 + order) as f64;
    // (previous, one-before-previous) filter outputs.
    let (mut f_prev, mut f_prev2) = (0.0f64, 0.0f64);

    if order == 0 {
        // No feedback history: the Voss value is just the scaled filter.
        for i in start..n {
            let delta = data[i] - data[i - 2];
            let f = c2.mul_add(f_prev, c3.mul_add(f_prev2, c1 * delta));
            filt[i] = f;
            f_prev2 = f_prev;
            f_prev = f;
            voss[i] = scale * f;
        }
        return;
    }

    let ord_f = order as f64;
    let inv_order = 1.0 / ord_f;
    // Ring of the last `order` Voss values plus two running sums:
    // `plain_sum` is their sum, `weighted_sum` their position-weighted sum.
    let mut history = vec![0.0f64; order];
    let mut slot = 0usize;
    let mut plain_sum = 0.0f64;
    let mut weighted_sum = 0.0f64;

    for i in start..n {
        let delta = data[i] - data[i - 2];
        let f = c2.mul_add(f_prev, c3.mul_add(f_prev2, c1 * delta));
        filt[i] = f;
        f_prev2 = f_prev;
        f_prev = f;

        let feedback = weighted_sum * inv_order;
        let v = scale.mul_add(f, -feedback);
        voss[i] = v;

        // NaN outputs enter the history as zero.
        let stored = if v.is_nan() { 0.0 } else { v };
        let evicted = history[slot];

        let plain_before = plain_sum;
        plain_sum = plain_before - evicted + stored;
        weighted_sum = ord_f.mul_add(stored, weighted_sum - plain_before);

        history[slot] = stored;
        slot = if slot + 1 == order { 0 } else { slot + 1 };
    }
}
1256
/// Same computation as the bounds-checked scalar row kernel, but with
/// unchecked element access.
///
/// Writing starts at `start = first + max(period, 5, 3 * predict)`; when
/// `start` is not below `data.len()` nothing is written. Otherwise
/// `filt[start - 2]` and `filt[start - 1]` are set to `0.0`, and `filt[i]` /
/// `voss[i]` are written for every `i` in `start..data.len()`.
///
/// # Safety
/// `voss` and `filt` must each hold at least `data.len()` elements; every
/// access into them (and into `data`) skips the bounds check.
#[inline(always)]
unsafe fn voss_row_scalar_unchecked(
    data: &[f64],
    first: usize,
    period: usize,
    predict: usize,
    bandwidth: f64,
    voss: &mut [f64],
    filt: &mut [f64],
) {
    use std::f64::consts::PI;

    let order = 3 * predict;
    let start = first + period.max(5).max(order);
    let n = data.len();
    if start >= n {
        return;
    }

    // Filter coefficients.
    let w0 = 2.0 * PI / period as f64;
    let f1 = w0.cos();
    let g1 = (bandwidth * w0).cos();
    let s1 = 1.0 / g1 - (1.0 / (g1 * g1) - 1.0).sqrt();
    let (c1, c2, c3) = (0.5 * (1.0 - s1), f1 * (1.0 + s1), -s1);

    if start >= 2 {
        // SAFETY: start < n, and the caller guarantees filt.len() >= n.
        *filt.get_unchecked_mut(start - 2) = 0.0;
        *filt.get_unchecked_mut(start - 1) = 0.0;
    }

    let scale = 0.5 * (3 + order) as f64;
    // (previous, one-before-previous) filter outputs.
    let (mut f_prev, mut f_prev2) = (0.0f64, 0.0f64);

    if order == 0 {
        // No feedback history: the Voss value is just the scaled filter.
        for i in start..n {
            // SAFETY: 2 <= start <= i < n == data.len(); outputs hold >= n
            // elements per the caller contract.
            let delta = *data.get_unchecked(i) - *data.get_unchecked(i - 2);
            let f = c2.mul_add(f_prev, c3.mul_add(f_prev2, c1 * delta));
            *filt.get_unchecked_mut(i) = f;
            *voss.get_unchecked_mut(i) = scale * f;
            f_prev2 = f_prev;
            f_prev = f;
        }
        return;
    }

    let ord_f = order as f64;
    let inv_order = 1.0 / ord_f;
    // Ring of the last `order` Voss values plus two running sums:
    // `plain_sum` is their sum, `weighted_sum` their position-weighted sum.
    let mut history = vec![0.0f64; order];
    let mut slot = 0usize;
    let mut plain_sum = 0.0f64;
    let mut weighted_sum = 0.0f64;

    for i in start..n {
        // SAFETY: 2 <= start <= i < n == data.len(); outputs hold >= n
        // elements per the caller contract.
        let delta = *data.get_unchecked(i) - *data.get_unchecked(i - 2);
        let f = c2.mul_add(f_prev, c3.mul_add(f_prev2, c1 * delta));
        *filt.get_unchecked_mut(i) = f;
        f_prev2 = f_prev;
        f_prev = f;

        let feedback = weighted_sum * inv_order;
        let v = scale.mul_add(f, -feedback);
        *voss.get_unchecked_mut(i) = v;

        // NaN outputs enter the history as zero.
        let stored = if v.is_nan() { 0.0 } else { v };
        // SAFETY: slot < order == history.len() is maintained below.
        let evicted = *history.get_unchecked(slot);

        let plain_before = plain_sum;
        plain_sum = plain_before - evicted + stored;
        weighted_sum = ord_f.mul_add(stored, weighted_sum - plain_before);

        *history.get_unchecked_mut(slot) = stored;
        slot = if slot + 1 == order { 0 } else { slot + 1 };
    }
}
1345
/// Row-level wrapper around `voss_avx2`.
///
/// Forwards every argument unchanged; only the position of `first` differs (second here,
/// after `bandwidth` in the callee).
///
/// # Safety
///
/// Performs no checks of its own; the caller must meet whatever `voss_avx2` requires.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn voss_row_avx2(
    data: &[f64],
    first: usize,
    period: usize,
    predict: usize,
    bandwidth: f64,
    voss: &mut [f64],
    filt: &mut [f64],
) {
    voss_avx2(data, period, predict, bandwidth, first, voss, filt);
}
1359
/// Row-level wrapper around `voss_avx512`.
///
/// Forwards every argument unchanged; only the position of `first` differs (second here,
/// after `bandwidth` in the callee).
///
/// # Safety
///
/// Performs no checks of its own; the caller must meet whatever `voss_avx512` requires.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn voss_row_avx512(
    data: &[f64],
    first: usize,
    period: usize,
    predict: usize,
    bandwidth: f64,
    voss: &mut [f64],
    filt: &mut [f64],
) {
    voss_avx512(data, period, predict, bandwidth, first, voss, filt);
}
1373
/// Row-level wrapper around `voss_avx512_short`.
///
/// Forwards every argument unchanged; only the position of `first` differs (second here,
/// after `bandwidth` in the callee).
///
/// # Safety
///
/// Performs no checks of its own; the caller must meet whatever `voss_avx512_short`
/// requires.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn voss_row_avx512_short(
    data: &[f64],
    first: usize,
    period: usize,
    predict: usize,
    bandwidth: f64,
    voss: &mut [f64],
    filt: &mut [f64],
) {
    voss_avx512_short(data, period, predict, bandwidth, first, voss, filt);
}
1387
/// Row-level wrapper around `voss_avx512_long`.
///
/// Forwards every argument unchanged; only the position of `first` differs (second here,
/// after `bandwidth` in the callee).
///
/// # Safety
///
/// Performs no checks of its own; the caller must meet whatever `voss_avx512_long`
/// requires.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn voss_row_avx512_long(
    data: &[f64],
    first: usize,
    period: usize,
    predict: usize,
    bandwidth: f64,
    voss: &mut [f64],
    filt: &mut [f64],
) {
    voss_avx512_long(data, period, predict, bandwidth, first, voss, filt);
}
1401
1402#[inline(always)]
1403pub fn expand_grid_voss(r: &VossBatchRange) -> Result<Vec<VossParams>, VossError> {
1404 expand_grid(r)
1405}
1406
/// Registers the VOSS Python bindings on module `m`.
///
/// Adds the `voss_py` and `voss_batch_py` functions and the `VossStreamPy` class. When the
/// `cuda` feature is enabled as well, the `VossDeviceArrayF32Py` class is added too.
///
/// # Errors
///
/// Propagates the first failure reported while wrapping or adding an item.
#[cfg(feature = "python")]
pub fn register_voss_module(m: &Bound<'_, pyo3::types::PyModule>) -> PyResult<()> {
    let single_fn = wrap_pyfunction!(voss_py, m)?;
    m.add_function(single_fn)?;

    let batch_fn = wrap_pyfunction!(voss_batch_py, m)?;
    m.add_function(batch_fn)?;

    m.add_class::<VossStreamPy>()?;

    #[cfg(all(feature = "python", feature = "cuda"))]
    {
        m.add_class::<VossDeviceArrayF32Py>()?;
    }

    Ok(())
}
1418
/// Python-visible handle to a 2-D `f32` CUDA device buffer.
///
/// The buffer can be viewed through `__cuda_array_interface__` for as long as it is still
/// held here; `__dlpack__` moves it out, after which both methods report an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct VossDeviceArrayF32Py {
    /// Device allocation; becomes `None` once `__dlpack__` has taken it.
    pub(crate) buf: Option<DeviceBuffer<f32>>,
    /// First dimension of the exported shape.
    pub(crate) rows: usize,
    /// Second dimension of the exported shape.
    pub(crate) cols: usize,
    /// CUDA context handle stored next to the buffer.
    /// NOTE(review): presumably held so the context outlives `buf` — confirm.
    pub(crate) _ctx: Arc<Context>,
    /// Device ordinal reported by `__dlpack_device__`.
    pub(crate) device_id: u32,
}

#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl VossDeviceArrayF32Py {
    /// Builds the `__cuda_array_interface__` dictionary (interface version 3).
    ///
    /// Reports a row-major `(rows, cols)` array of little-endian `f32` (`"<f4"`) with byte
    /// strides `(cols * 4, 4)` and a writable data pointer.
    ///
    /// # Errors
    ///
    /// Fails with `ValueError` when the buffer has already been moved out by `__dlpack__`.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let elem_bytes = std::mem::size_of::<f32>();

        let iface = PyDict::new(py);
        iface.set_item("shape", (self.rows, self.cols))?;
        iface.set_item("typestr", "<f4")?;
        iface.set_item("strides", (self.cols * elem_bytes, elem_bytes))?;

        let buffer = match self.buf.as_ref() {
            Some(b) => b,
            None => {
                return Err(PyValueError::new_err(
                    "buffer already exported via __dlpack__",
                ))
            }
        };
        let device_ptr = buffer.as_device_ptr().as_raw() as usize;
        // Second tuple element is the read-only flag.
        iface.set_item("data", (device_ptr, false))?;

        iface.set_item("version", 3)?;
        Ok(iface)
    }

    /// Returns `(device_type, device_id)`; `2` is DLPack's device-type code for CUDA.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.device_id as i32)
    }

    /// Exports the buffer as a DLPack object, transferring ownership out of `self`.
    ///
    /// `stream` is accepted but ignored. If `dl_device` parses as an `(i32, i32)` tuple
    /// and differs from this array's own device, the call fails; a `dl_device` that does
    /// not parse is ignored. `max_version` is handed on to the exporter.
    ///
    /// # Errors
    ///
    /// `ValueError` on a device mismatch (with a separate message when `copy` is `True`)
    /// or when the buffer has already been exported.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        let (own_type, own_id) = self.__dlpack_device__();

        let requested = dl_device
            .as_ref()
            .and_then(|dev| dev.extract::<(i32, i32)>(py).ok());
        if let Some((req_type, req_id)) = requested {
            if req_type != own_type || req_id != own_id {
                // Only an explicit `copy=True` selects the "copy" message.
                let copy_requested = matches!(
                    copy.as_ref().map(|flag| flag.extract::<bool>(py)),
                    Some(Ok(true))
                );
                let message = if copy_requested {
                    "device copy not implemented for __dlpack__"
                } else {
                    "dl_device mismatch for __dlpack__"
                };
                return Err(PyValueError::new_err(message));
            }
        }
        let _ = stream;

        let buffer = match self.buf.take() {
            Some(b) => b,
            None => {
                return Err(PyValueError::new_err(
                    "__dlpack__ may only be called once",
                ))
            }
        };

        let version_cap = max_version.map(|obj| obj.into_bound(py));
        export_f32_cuda_dlpack_2d(py, buffer, self.rows, self.cols, own_id, version_cap)
    }
}
/// Python binding `voss_cuda_batch_dev`: runs a VOSS parameter sweep on a CUDA device.
///
/// `period`, `predict` and `bandwidth` are three-element range tuples that are placed
/// into a [`VossBatchRange`] unchanged. The computation runs with the GIL released.
///
/// Returns a dict with the keys `voss` and `filt` (device arrays), `rows` (number of
/// parameter combinations), `cols` (input length) and the per-row parameter arrays
/// `periods`, `predicts` and `bandwidths`.
///
/// # Errors
///
/// `ValueError` when CUDA is unavailable or when the CUDA wrapper reports an error; also
/// fails if the input array cannot be viewed as a slice.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "voss_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period=(20,20,0), predict=(3,3,0), bandwidth=(0.25,0.25,0.0), device_id=0))]
pub fn voss_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    data_f32: numpy::PyReadonlyArray1<'py, f32>,
    period: (usize, usize, usize),
    predict: (usize, usize, usize),
    bandwidth: (f64, f64, f64),
    device_id: usize,
) -> PyResult<Bound<'py, PyDict>> {
    use crate::cuda::cuda_available;
    use crate::cuda::moving_averages::DeviceArrayF32;
    use numpy::IntoPyArray;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let input = data_f32.as_slice()?;
    let sweep = VossBatchRange {
        period,
        predict,
        bandwidth,
    };

    let (voss_dev, filt_dev, combos, ctx, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let cuda = match crate::cuda::CudaVoss::new(device_id) {
            Ok(handle) => handle,
            Err(e) => return Err(PyValueError::new_err(e.to_string())),
        };
        let (voss_dev, filt_dev, combos) = match cuda.voss_batch_dev(input, &sweep) {
            Ok(outputs) => outputs,
            Err(e) => return Err(PyValueError::new_err(e.to_string())),
        };
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id();
        Ok((voss_dev, filt_dev, combos, ctx, dev_id))
    })?;

    // Turns a raw device array into the Python-facing wrapper.
    let wrap = |array: DeviceArrayF32, ctx: Arc<Context>| {
        let DeviceArrayF32 { buf, rows, cols } = array;
        VossDeviceArrayF32Py {
            buf: Some(buf),
            rows,
            cols,
            _ctx: ctx,
            device_id: dev_id,
        }
    };

    let result = PyDict::new(py);
    let voss_obj = Py::new(py, wrap(voss_dev, Arc::clone(&ctx)))?;
    result.set_item("voss", voss_obj)?;
    let filt_obj = Py::new(py, wrap(filt_dev, ctx))?;
    result.set_item("filt", filt_obj)?;
    result.set_item("rows", combos.len())?;
    result.set_item("cols", input.len())?;

    // Unset parameters fall back to 20 / 3 / 0.25.
    let periods: Vec<u64> = combos
        .iter()
        .map(|combo| combo.period.unwrap_or(20) as u64)
        .collect();
    result.set_item("periods", periods.into_pyarray(py))?;

    let predicts: Vec<u64> = combos
        .iter()
        .map(|combo| combo.predict.unwrap_or(3) as u64)
        .collect();
    result.set_item("predicts", predicts.into_pyarray(py))?;

    let bandwidths: Vec<f64> = combos
        .iter()
        .map(|combo| combo.bandwidth.unwrap_or(0.25))
        .collect();
    result.set_item("bandwidths", bandwidths.into_pyarray(py))?;

    Ok(result)
}
1594
/// Python binding `voss_cuda_many_series_one_param_dev`: runs VOSS with a single
/// parameter set over many series on a CUDA device.
///
/// `data_tm_f32` is a 2-D array handed to the time-major CUDA entry point as a flat
/// slice; its second dimension is passed first and its first dimension second.
/// The computation runs with the GIL released.
///
/// Returns the `(voss, filt)` pair of device arrays.
///
/// # Errors
///
/// `ValueError` when CUDA is unavailable or when the CUDA wrapper reports an error; also
/// fails if the input array cannot be viewed as a contiguous slice.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "voss_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period=20, predict=3, bandwidth=0.25, device_id=0))]
pub fn voss_cuda_many_series_one_param_dev_py<'py>(
    py: Python<'py>,
    data_tm_f32: numpy::PyReadonlyArray2<'py, f32>,
    period: usize,
    predict: usize,
    bandwidth: f64,
    device_id: usize,
) -> PyResult<(VossDeviceArrayF32Py, VossDeviceArrayF32Py)> {
    use crate::cuda::cuda_available;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let rows = data_tm_f32.shape()[0];
    let cols = data_tm_f32.shape()[1];
    let flat = data_tm_f32.as_slice()?;
    let params = VossParams {
        period: Some(period),
        predict: Some(predict),
        bandwidth: Some(bandwidth),
    };
    let (voss_dev, filt_dev, ctx, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let cuda = crate::cuda::CudaVoss::new(device_id)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        // The last argument was the mis-encoded token `¶ms`; it is a borrow of `params`.
        let (voss_dev, filt_dev) = cuda
            .voss_many_series_one_param_time_major_dev(flat, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((voss_dev, filt_dev, cuda.context_arc(), cuda.device_id()))
    })?;

    // Both wrappers share the same context handle and device id.
    let crate::cuda::moving_averages::DeviceArrayF32 { buf, rows, cols } = voss_dev;
    let voss_py = VossDeviceArrayF32Py {
        buf: Some(buf),
        rows,
        cols,
        _ctx: ctx.clone(),
        device_id: dev_id,
    };
    let crate::cuda::moving_averages::DeviceArrayF32 { buf, rows, cols } = filt_dev;
    let filt_py = VossDeviceArrayF32Py {
        buf: Some(buf),
        rows,
        cols,
        _ctx: ctx,
        device_id: dev_id,
    };
    Ok((voss_py, filt_py))
}
1644
1645#[cfg(test)]
1646mod tests {
1647 use super::*;
1648 use crate::skip_if_unsupported;
1649 use crate::utilities::data_loader::read_candles_from_csv;
1650
1651 #[inline]
1652 fn eq_or_both_nan(a: f64, b: f64) -> bool {
1653 (a.is_nan() && b.is_nan()) || (a == b)
1654 }
1655
1656 #[test]
1657 fn test_voss_into_matches_api() -> Result<(), Box<dyn Error>> {
1658 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1659 let candles = read_candles_from_csv(file_path)?;
1660 let input = VossInput::with_default_candles(&candles);
1661
1662 let base = voss(&input)?;
1663
1664 let len = candles.close.len();
1665 let mut out_voss = vec![0.0f64; len];
1666 let mut out_filt = vec![0.0f64; len];
1667 #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1668 {
1669 voss_into(&input, &mut out_voss, &mut out_filt)?;
1670 }
1671 #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1672 {
1673 voss_into_slice(&mut out_voss, &mut out_filt, &input, Kernel::Auto)?;
1674 }
1675
1676 assert_eq!(base.voss.len(), out_voss.len());
1677 assert_eq!(base.filt.len(), out_filt.len());
1678
1679 for i in 0..len {
1680 assert!(
1681 eq_or_both_nan(base.voss[i], out_voss[i]),
1682 "voss mismatch at {}: got {}, expected {}",
1683 i,
1684 out_voss[i],
1685 base.voss[i]
1686 );
1687 assert!(
1688 eq_or_both_nan(base.filt[i], out_filt[i]),
1689 "filt mismatch at {}: got {}, expected {}",
1690 i,
1691 out_filt[i],
1692 base.filt[i]
1693 );
1694 }
1695
1696 Ok(())
1697 }
1698
1699 fn check_voss_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1700 skip_if_unsupported!(kernel, test_name);
1701 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1702 let candles = read_candles_from_csv(file_path)?;
1703
1704 let params = VossParams {
1705 period: None,
1706 predict: Some(2),
1707 bandwidth: None,
1708 };
1709 let input = VossInput::from_candles(&candles, "close", params);
1710 let output = voss_with_kernel(&input, kernel)?;
1711 assert_eq!(output.voss.len(), candles.close.len());
1712 assert_eq!(output.filt.len(), candles.close.len());
1713 Ok(())
1714 }
1715
1716 fn check_voss_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1717 skip_if_unsupported!(kernel, test_name);
1718 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1719 let candles = read_candles_from_csv(file_path)?;
1720
1721 let params = VossParams {
1722 period: Some(20),
1723 predict: Some(3),
1724 bandwidth: Some(0.25),
1725 };
1726 let input = VossInput::from_candles(&candles, "close", params);
1727 let output = voss_with_kernel(&input, kernel)?;
1728
1729 let expected_voss_last_five = [
1730 -290.430249544605,
1731 -269.74949153549596,
1732 -241.08179139844515,
1733 -149.2113276943419,
1734 -138.60361772412466,
1735 ];
1736 let expected_filt_last_five = [
1737 -228.0283989610523,
1738 -257.79056527053103,
1739 -270.3220395771822,
1740 -257.4282859799144,
1741 -235.78021136041997,
1742 ];
1743
1744 let start = output.voss.len() - 5;
1745 for (i, &val) in output.voss[start..].iter().enumerate() {
1746 let expected = expected_voss_last_five[i];
1747 assert!(
1748 (val - expected).abs() < 1e-6,
1749 "[{}] VOSS mismatch at idx {}: got {}, expected {}",
1750 test_name,
1751 i,
1752 val,
1753 expected
1754 );
1755 }
1756 for (i, &val) in output.filt[start..].iter().enumerate() {
1757 let expected = expected_filt_last_five[i];
1758 assert!(
1759 (val - expected).abs() < 1e-6,
1760 "[{}] Filt mismatch at idx {}: got {}, expected {}",
1761 test_name,
1762 i,
1763 val,
1764 expected
1765 );
1766 }
1767 Ok(())
1768 }
1769
1770 fn check_voss_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1771 skip_if_unsupported!(kernel, test_name);
1772 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1773 let candles = read_candles_from_csv(file_path)?;
1774
1775 let input = VossInput::with_default_candles(&candles);
1776 match input.data {
1777 VossData::Candles { source, .. } => assert_eq!(source, "close"),
1778 _ => panic!("Expected VossData::Candles"),
1779 }
1780 let output = voss_with_kernel(&input, kernel)?;
1781 assert_eq!(output.voss.len(), candles.close.len());
1782 Ok(())
1783 }
1784
1785 fn check_voss_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1786 skip_if_unsupported!(kernel, test_name);
1787 let input_data = [10.0, 20.0, 30.0];
1788 let params = VossParams {
1789 period: Some(0),
1790 predict: None,
1791 bandwidth: None,
1792 };
1793 let input = VossInput::from_slice(&input_data, params);
1794 let res = voss_with_kernel(&input, kernel);
1795 assert!(
1796 res.is_err(),
1797 "[{}] VOSS should fail with zero period",
1798 test_name
1799 );
1800 Ok(())
1801 }
1802
1803 fn check_voss_period_exceeds_length(
1804 test_name: &str,
1805 kernel: Kernel,
1806 ) -> Result<(), Box<dyn Error>> {
1807 skip_if_unsupported!(kernel, test_name);
1808 let data_small = [10.0, 20.0, 30.0];
1809 let params = VossParams {
1810 period: Some(10),
1811 predict: None,
1812 bandwidth: None,
1813 };
1814 let input = VossInput::from_slice(&data_small, params);
1815 let res = voss_with_kernel(&input, kernel);
1816 assert!(
1817 res.is_err(),
1818 "[{}] VOSS should fail with period exceeding length",
1819 test_name
1820 );
1821 Ok(())
1822 }
1823
1824 fn check_voss_very_small_dataset(
1825 test_name: &str,
1826 kernel: Kernel,
1827 ) -> Result<(), Box<dyn Error>> {
1828 skip_if_unsupported!(kernel, test_name);
1829 let single_point = [42.0];
1830 let params = VossParams {
1831 period: Some(20),
1832 predict: None,
1833 bandwidth: None,
1834 };
1835 let input = VossInput::from_slice(&single_point, params);
1836 let res = voss_with_kernel(&input, kernel);
1837 assert!(
1838 res.is_err(),
1839 "[{}] VOSS should fail with insufficient data",
1840 test_name
1841 );
1842 Ok(())
1843 }
1844
1845 fn check_voss_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1846 skip_if_unsupported!(kernel, test_name);
1847 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1848 let candles = read_candles_from_csv(file_path)?;
1849
1850 let first_params = VossParams {
1851 period: Some(10),
1852 predict: Some(2),
1853 bandwidth: Some(0.2),
1854 };
1855 let first_input = VossInput::from_candles(&candles, "close", first_params);
1856 let first_result = voss_with_kernel(&first_input, kernel)?;
1857
1858 let second_params = VossParams {
1859 period: Some(10),
1860 predict: Some(2),
1861 bandwidth: Some(0.2),
1862 };
1863 let second_input = VossInput::from_slice(&first_result.voss, second_params);
1864 let second_result = voss_with_kernel(&second_input, kernel)?;
1865
1866 assert_eq!(second_result.voss.len(), first_result.voss.len());
1867 Ok(())
1868 }
1869
1870 #[cfg(debug_assertions)]
1871 fn check_voss_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1872 skip_if_unsupported!(kernel, test_name);
1873
1874 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1875 let candles = read_candles_from_csv(file_path)?;
1876
1877 let test_params = vec![
1878 VossParams::default(),
1879 VossParams {
1880 period: Some(2),
1881 predict: Some(1),
1882 bandwidth: Some(0.1),
1883 },
1884 VossParams {
1885 period: Some(5),
1886 predict: Some(2),
1887 bandwidth: Some(0.2),
1888 },
1889 VossParams {
1890 period: Some(10),
1891 predict: Some(3),
1892 bandwidth: Some(0.3),
1893 },
1894 VossParams {
1895 period: Some(15),
1896 predict: Some(4),
1897 bandwidth: Some(0.4),
1898 },
1899 VossParams {
1900 period: Some(30),
1901 predict: Some(5),
1902 bandwidth: Some(0.5),
1903 },
1904 VossParams {
1905 period: Some(50),
1906 predict: Some(8),
1907 bandwidth: Some(0.6),
1908 },
1909 VossParams {
1910 period: Some(100),
1911 predict: Some(10),
1912 bandwidth: Some(0.75),
1913 },
1914 VossParams {
1915 period: Some(20),
1916 predict: Some(1),
1917 bandwidth: Some(0.1),
1918 },
1919 VossParams {
1920 period: Some(20),
1921 predict: Some(10),
1922 bandwidth: Some(0.9),
1923 },
1924 VossParams {
1925 period: Some(20),
1926 predict: Some(3),
1927 bandwidth: Some(0.05),
1928 },
1929 VossParams {
1930 period: Some(20),
1931 predict: Some(3),
1932 bandwidth: Some(1.0),
1933 },
1934 VossParams {
1935 period: Some(3),
1936 predict: Some(1),
1937 bandwidth: Some(0.15),
1938 },
1939 VossParams {
1940 period: Some(7),
1941 predict: Some(2),
1942 bandwidth: Some(0.35),
1943 },
1944 VossParams {
1945 period: Some(13),
1946 predict: Some(3),
1947 bandwidth: Some(0.45),
1948 },
1949 ];
1950
1951 for (param_idx, params) in test_params.iter().enumerate() {
1952 let input = VossInput::from_candles(&candles, "close", params.clone());
1953 let output = voss_with_kernel(&input, kernel)?;
1954
1955 for (i, &val) in output.voss.iter().enumerate() {
1956 if val.is_nan() {
1957 continue;
1958 }
1959
1960 let bits = val.to_bits();
1961
1962 if bits == 0x11111111_11111111 {
1963 panic!(
1964 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1965 in voss output with params: period={}, predict={}, bandwidth={} (param set {})",
1966 test_name,
1967 val,
1968 bits,
1969 i,
1970 params.period.unwrap_or(20),
1971 params.predict.unwrap_or(3),
1972 params.bandwidth.unwrap_or(0.25),
1973 param_idx
1974 );
1975 }
1976
1977 if bits == 0x22222222_22222222 {
1978 panic!(
1979 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1980 in voss output with params: period={}, predict={}, bandwidth={} (param set {})",
1981 test_name,
1982 val,
1983 bits,
1984 i,
1985 params.period.unwrap_or(20),
1986 params.predict.unwrap_or(3),
1987 params.bandwidth.unwrap_or(0.25),
1988 param_idx
1989 );
1990 }
1991
1992 if bits == 0x33333333_33333333 {
1993 panic!(
1994 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1995 in voss output with params: period={}, predict={}, bandwidth={} (param set {})",
1996 test_name,
1997 val,
1998 bits,
1999 i,
2000 params.period.unwrap_or(20),
2001 params.predict.unwrap_or(3),
2002 params.bandwidth.unwrap_or(0.25),
2003 param_idx
2004 );
2005 }
2006 }
2007
2008 for (i, &val) in output.filt.iter().enumerate() {
2009 if val.is_nan() {
2010 continue;
2011 }
2012
2013 let bits = val.to_bits();
2014
2015 if bits == 0x11111111_11111111 {
2016 panic!(
2017 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
2018 in filt output with params: period={}, predict={}, bandwidth={} (param set {})",
2019 test_name,
2020 val,
2021 bits,
2022 i,
2023 params.period.unwrap_or(20),
2024 params.predict.unwrap_or(3),
2025 params.bandwidth.unwrap_or(0.25),
2026 param_idx
2027 );
2028 }
2029
2030 if bits == 0x22222222_22222222 {
2031 panic!(
2032 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
2033 in filt output with params: period={}, predict={}, bandwidth={} (param set {})",
2034 test_name,
2035 val,
2036 bits,
2037 i,
2038 params.period.unwrap_or(20),
2039 params.predict.unwrap_or(3),
2040 params.bandwidth.unwrap_or(0.25),
2041 param_idx
2042 );
2043 }
2044
2045 if bits == 0x33333333_33333333 {
2046 panic!(
2047 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
2048 in filt output with params: period={}, predict={}, bandwidth={} (param set {})",
2049 test_name,
2050 val,
2051 bits,
2052 i,
2053 params.period.unwrap_or(20),
2054 params.predict.unwrap_or(3),
2055 params.bandwidth.unwrap_or(0.25),
2056 param_idx
2057 );
2058 }
2059 }
2060 }
2061
2062 Ok(())
2063 }
2064
2065 #[cfg(not(debug_assertions))]
2066 fn check_voss_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2067 Ok(())
2068 }
2069
2070 #[cfg(feature = "proptest")]
2071 #[allow(clippy::float_cmp)]
2072 fn check_voss_property(
2073 test_name: &str,
2074 kernel: Kernel,
2075 ) -> Result<(), Box<dyn std::error::Error>> {
2076 use proptest::prelude::*;
2077 skip_if_unsupported!(kernel, test_name);
2078
2079 let strat = (10usize..=20).prop_flat_map(|period| {
2080 (
2081 prop::collection::vec((100.0f64..10000.0f64), 100..200),
2082 Just(period),
2083 Just(1usize),
2084 Just(0.5f64),
2085 )
2086 });
2087
2088 proptest::test_runner::TestRunner::default()
2089 .run(&strat, |(mut data, period, predict, bandwidth)| {
2090 for (i, val) in data.iter_mut().enumerate() {
2091 *val += (i as f64).sin() * 100.0;
2092 }
2093
2094 let params = VossParams {
2095 period: Some(period),
2096 predict: Some(predict),
2097 bandwidth: Some(bandwidth),
2098 };
2099 let input = VossInput::from_slice(&data, params);
2100
2101 let output = voss_with_kernel(&input, kernel).unwrap();
2102 let ref_output = voss_with_kernel(&input, Kernel::Scalar).unwrap();
2103
2104 let order = 3 * predict;
2105 let min_index = period.max(5).max(order);
2106 let first = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
2107 let warmup_end = first + min_index;
2108
2109 prop_assert_eq!(
2110 output.voss.len(),
2111 data.len(),
2112 "[{}] VOSS length mismatch",
2113 test_name
2114 );
2115 prop_assert_eq!(
2116 output.filt.len(),
2117 data.len(),
2118 "[{}] Filt length mismatch",
2119 test_name
2120 );
2121
2122 for i in 0..warmup_end.min(5) {
2123 prop_assert!(
2124 output.voss[i].is_nan(),
2125 "[{}] Expected NaN in voss during warmup at idx {}, got {}",
2126 test_name,
2127 i,
2128 output.voss[i]
2129 );
2130 prop_assert!(
2131 output.filt[i].is_nan(),
2132 "[{}] Expected NaN in filt during warmup at idx {}, got {}",
2133 test_name,
2134 i,
2135 output.filt[i]
2136 );
2137 }
2138
2139 for i in 0..data.len() {
2140 let voss_val = output.voss[i];
2141 let voss_ref = ref_output.voss[i];
2142 let filt_val = output.filt[i];
2143 let filt_ref = ref_output.filt[i];
2144
2145 prop_assert!(
2146 voss_val.is_nan() == voss_ref.is_nan(),
2147 "[{}] VOSS NaN pattern mismatch at idx {}: {} vs {}",
2148 test_name,
2149 i,
2150 voss_val,
2151 voss_ref
2152 );
2153 prop_assert!(
2154 filt_val.is_nan() == filt_ref.is_nan(),
2155 "[{}] Filt NaN pattern mismatch at idx {}: {} vs {}",
2156 test_name,
2157 i,
2158 filt_val,
2159 filt_ref
2160 );
2161
2162 if filt_val.is_finite() && filt_ref.is_finite() {
2163 let diff = (filt_val - filt_ref).abs();
2164 let scale = filt_val.abs().max(filt_ref.abs()).max(1.0);
2165 prop_assert!(
2166 diff / scale < 1e-4,
2167 "[{}] Filt kernel value mismatch at idx {}: {} vs {} (rel diff: {})",
2168 test_name,
2169 i,
2170 filt_val,
2171 filt_ref,
2172 diff / scale
2173 );
2174 }
2175 }
2176
2177 Ok(())
2178 })
2179 .unwrap();
2180 Ok(())
2181 }
2182
2183 macro_rules! generate_all_voss_tests {
2184 ($($test_fn:ident),*) => {
2185 paste::paste! {
2186 $(
2187 #[test]
2188 fn [<$test_fn _scalar_f64>]() {
2189 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
2190 }
2191 )*
2192 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2193 $(
2194 #[test]
2195 fn [<$test_fn _avx2_f64>]() {
2196 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
2197 }
2198 #[test]
2199 fn [<$test_fn _avx512_f64>]() {
2200 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
2201 }
2202 )*
2203 }
2204 }
2205 }
2206 generate_all_voss_tests!(
2207 check_voss_partial_params,
2208 check_voss_accuracy,
2209 check_voss_default_candles,
2210 check_voss_zero_period,
2211 check_voss_period_exceeds_length,
2212 check_voss_very_small_dataset,
2213 check_voss_reinput,
2214 check_voss_no_poison
2215 );
2216
2217 #[cfg(feature = "proptest")]
2218 generate_all_voss_tests!(check_voss_property);
2219
2220 fn check_batch_default_row(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2221 skip_if_unsupported!(kernel, test_name);
2222
2223 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2224 let c = read_candles_from_csv(file)?;
2225
2226 let output = VossBatchBuilder::new()
2227 .kernel(kernel)
2228 .apply_candles(&c, "close")?;
2229
2230 let def = VossParams::default();
2231 let row = output.voss_for(&def).expect("default row missing");
2232
2233 assert_eq!(row.len(), c.close.len());
2234
2235 for (i, &v) in row.iter().enumerate() {
2236 if i < def.period.unwrap() {
2237 assert!(
2238 v.is_nan() || v.abs() < 1e-8,
2239 "[{test_name}] expected NaN or 0 at idx {i}, got {v}"
2240 );
2241 }
2242 }
2243 Ok(())
2244 }
2245
2246 fn check_batch_param_grid(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2247 skip_if_unsupported!(kernel, test_name);
2248
2249 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2250 let c = read_candles_from_csv(file)?;
2251
2252 let builder = VossBatchBuilder::new()
2253 .kernel(kernel)
2254 .period_range(10, 14, 2)
2255 .predict_range(2, 4, 1)
2256 .bandwidth_range(0.1, 0.2, 0.1);
2257
2258 let output = builder.apply_candles(&c, "close")?;
2259 let expected_param_count = ((14 - 10) / 2 + 1) * (4 - 2 + 1) * 2;
2260 assert_eq!(
2261 output.combos.len(),
2262 expected_param_count,
2263 "[{test_name}] Unexpected grid size: got {}, expected {}",
2264 output.combos.len(),
2265 expected_param_count
2266 );
2267
2268 for p in &output.combos {
2269 let row = output.voss_for(p).unwrap();
2270 assert_eq!(row.len(), c.close.len());
2271 }
2272 Ok(())
2273 }
2274
2275 fn check_batch_nan_propagation(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2276 skip_if_unsupported!(kernel, test_name);
2277
2278 let data = [f64::NAN, f64::NAN, 1.0, 2.0, 3.0, 4.0];
2279 let range = VossBatchRange {
2280 period: (3, 3, 0),
2281 predict: (2, 2, 0),
2282 bandwidth: (0.1, 0.1, 0.0),
2283 };
2284 let output = VossBatchBuilder::new().kernel(kernel).apply_slice(&data)?;
2285 for row in 0..output.rows {
2286 let v = &output.voss[row * output.cols..][..output.cols];
2287 assert!(
2288 v.iter().any(|&x| x.is_nan()),
2289 "[{test_name}] No NaNs found in output row"
2290 );
2291 }
2292 Ok(())
2293 }
2294
2295 fn check_batch_invalid_range(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2296 skip_if_unsupported!(kernel, test_name);
2297
2298 let data = [1.0, 2.0, 3.0];
2299 let range = VossBatchRange {
2300 period: (10, 10, 0),
2301 predict: (3, 3, 0),
2302 bandwidth: (0.25, 0.25, 0.0),
2303 };
2304 let result = VossBatchBuilder::new().kernel(kernel).apply_slice(&data);
2305 assert!(
2306 result.is_err(),
2307 "[{test_name}] Expected error for invalid batch range"
2308 );
2309 Ok(())
2310 }
2311
2312 #[cfg(debug_assertions)]
2313 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2314 skip_if_unsupported!(kernel, test);
2315
2316 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2317 let c = read_candles_from_csv(file)?;
2318
2319 let test_configs = vec![
2320 (2, 10, 2, 1, 3, 1, 0.1, 0.3, 0.1),
2321 (5, 25, 5, 2, 4, 1, 0.2, 0.4, 0.1),
2322 (30, 60, 15, 5, 8, 1, 0.3, 0.6, 0.15),
2323 (2, 5, 1, 1, 2, 1, 0.1, 0.2, 0.05),
2324 (10, 30, 10, 3, 5, 1, 0.25, 0.5, 0.25),
2325 (50, 100, 25, 8, 10, 1, 0.5, 0.75, 0.25),
2326 ];
2327
2328 for (cfg_idx, config) in test_configs.iter().enumerate() {
2329 let output = VossBatchBuilder::new()
2330 .kernel(kernel)
2331 .period_range(config.0, config.1, config.2)
2332 .predict_range(config.3, config.4, config.5)
2333 .bandwidth_range(config.6, config.7, config.8)
2334 .apply_candles(&c, "close")?;
2335
2336 for (idx, &val) in output.voss.iter().enumerate() {
2337 if val.is_nan() {
2338 continue;
2339 }
2340
2341 let bits = val.to_bits();
2342 let row = idx / output.cols;
2343 let col = idx % output.cols;
2344 let combo = &output.combos[row];
2345
2346 if bits == 0x11111111_11111111 {
2347 panic!(
2348 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
2349 at row {} col {} (flat index {}) in voss output with params: period={}, predict={}, bandwidth={}",
2350 test, cfg_idx, val, bits, row, col, idx,
2351 combo.period.unwrap_or(20),
2352 combo.predict.unwrap_or(3),
2353 combo.bandwidth.unwrap_or(0.25)
2354 );
2355 }
2356
2357 if bits == 0x22222222_22222222 {
2358 panic!(
2359 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
2360 at row {} col {} (flat index {}) in voss output with params: period={}, predict={}, bandwidth={}",
2361 test, cfg_idx, val, bits, row, col, idx,
2362 combo.period.unwrap_or(20),
2363 combo.predict.unwrap_or(3),
2364 combo.bandwidth.unwrap_or(0.25)
2365 );
2366 }
2367
2368 if bits == 0x33333333_33333333 {
2369 panic!(
2370 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
2371 at row {} col {} (flat index {}) in voss output with params: period={}, predict={}, bandwidth={}",
2372 test, cfg_idx, val, bits, row, col, idx,
2373 combo.period.unwrap_or(20),
2374 combo.predict.unwrap_or(3),
2375 combo.bandwidth.unwrap_or(0.25)
2376 );
2377 }
2378 }
2379
2380 for (idx, &val) in output.filt.iter().enumerate() {
2381 if val.is_nan() {
2382 continue;
2383 }
2384
2385 let bits = val.to_bits();
2386 let row = idx / output.cols;
2387 let col = idx % output.cols;
2388 let combo = &output.combos[row];
2389
2390 if bits == 0x11111111_11111111 {
2391 panic!(
2392 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
2393 at row {} col {} (flat index {}) in filt output with params: period={}, predict={}, bandwidth={}",
2394 test, cfg_idx, val, bits, row, col, idx,
2395 combo.period.unwrap_or(20),
2396 combo.predict.unwrap_or(3),
2397 combo.bandwidth.unwrap_or(0.25)
2398 );
2399 }
2400
2401 if bits == 0x22222222_22222222 {
2402 panic!(
2403 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
2404 at row {} col {} (flat index {}) in filt output with params: period={}, predict={}, bandwidth={}",
2405 test, cfg_idx, val, bits, row, col, idx,
2406 combo.period.unwrap_or(20),
2407 combo.predict.unwrap_or(3),
2408 combo.bandwidth.unwrap_or(0.25)
2409 );
2410 }
2411
2412 if bits == 0x33333333_33333333 {
2413 panic!(
2414 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
2415 at row {} col {} (flat index {}) in filt output with params: period={}, predict={}, bandwidth={}",
2416 test, cfg_idx, val, bits, row, col, idx,
2417 combo.period.unwrap_or(20),
2418 combo.predict.unwrap_or(3),
2419 combo.bandwidth.unwrap_or(0.25)
2420 );
2421 }
2422 }
2423 }
2424
2425 Ok(())
2426 }
2427
    /// Variant of the poison-value batch check used when `debug_assertions` is off.
    ///
    /// Both arguments are ignored and the function always returns `Ok(())`; it exists so
    /// that `gen_batch_tests!(check_batch_no_poison)` still has a function to call in
    /// builds without debug assertions.
    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
2432
2433 macro_rules! gen_batch_tests {
2434 ($fn_name:ident) => {
2435 paste::paste! {
2436 #[test] fn [<$fn_name _scalar>]() {
2437 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
2438 }
2439 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2440 #[test] fn [<$fn_name _avx2>]() {
2441 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
2442 }
2443 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2444 #[test] fn [<$fn_name _avx512>]() {
2445 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
2446 }
2447 #[test] fn [<$fn_name _auto_detect>]() {
2448 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
2449 }
2450 }
2451 };
2452 }
2453
    // Instantiate the per-kernel test wrappers for each batch check function.
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_param_grid);
    gen_batch_tests!(check_batch_nan_propagation);
    gen_batch_tests!(check_batch_invalid_range);
    gen_batch_tests!(check_batch_no_poison);
2459}
2460
// Python entry point `voss`: runs the indicator over a contiguous 1-D float64 array and
// returns the `(voss, filt)` series as two NumPy arrays.
//
// `period`, `predict` and `bandwidth` become the `VossParams` fields of the same names;
// `kernel` is an optional kernel name checked by `validate_kernel` in non-batch mode.
// A failure of the computation is raised as `ValueError` carrying the error's message.
#[cfg(feature = "python")]
#[pyfunction(name = "voss")]
#[pyo3(signature = (data, period=20, predict=3, bandwidth=0.25, kernel=None))]
pub fn voss_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period: usize,
    predict: usize,
    bandwidth: f64,
    kernel: Option<&str>,
) -> PyResult<(Bound<'py, PyArray1<f64>>, Bound<'py, PyArray1<f64>>)> {
    let values = data.as_slice()?;
    let chosen_kernel = validate_kernel(kernel, false)?;

    let input = VossInput::from_slice(
        values,
        VossParams {
            period: Some(period),
            predict: Some(predict),
            bandwidth: Some(bandwidth),
        },
    );

    // The computation runs inside `allow_threads`, i.e. without holding the GIL.
    let computed =
        py.allow_threads(|| voss_with_kernel(&input, chosen_kernel).map(|o| (o.voss, o.filt)));

    match computed {
        Ok((voss_vec, filt_vec)) => Ok((voss_vec.into_pyarray(py), filt_vec.into_pyarray(py))),
        Err(err) => Err(PyValueError::new_err(err.to_string())),
    }
}
2488
// Python-facing wrapper exported to Python under the class name `VossStream`.
// It owns one Rust `VossStream` value.
#[cfg(feature = "python")]
#[pyclass(name = "VossStream")]
pub struct VossStreamPy {
    // The wrapped Rust stream.
    stream: VossStream,
}
2494
#[cfg(feature = "python")]
#[pymethods]
impl VossStreamPy {
    // Builds the wrapped stream from the three parameters; a construction error is
    // raised as `ValueError` carrying the error's message.
    #[new]
    #[pyo3(signature = (period=20, predict=3, bandwidth=0.25))]
    fn new(period: usize, predict: usize, bandwidth: f64) -> PyResult<Self> {
        let built = VossStream::try_new(VossParams {
            period: Some(period),
            predict: Some(predict),
            bandwidth: Some(bandwidth),
        });

        match built {
            Ok(stream) => Ok(VossStreamPy { stream }),
            Err(err) => Err(PyValueError::new_err(err.to_string())),
        }
    }

    // Forwards one value to the wrapped stream and returns whatever it yields
    // (`None` or a pair of floats).
    fn update(&mut self, value: f64) -> Option<(f64, f64)> {
        let inner = &mut self.stream;
        inner.update(value)
    }
}
2515
// Python entry point `voss_batch`: evaluates the indicator for every parameter combination
// produced by the three ranges and returns a dict with
//   "voss", "filt"                      -> arrays reshaped to (rows, cols)
//   "periods", "predicts", "bandwidths" -> one entry per parameter combination
// where rows is the number of combinations and cols the input length.
//
// Range expansion, size overflow and computation failures are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "voss_batch")]
#[pyo3(signature = (data, period_range=(20, 100, 1), predict_range=(3, 3, 0), bandwidth_range=(0.25, 0.25, 0.0), kernel=None))]
pub fn voss_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    predict_range: (usize, usize, usize),
    bandwidth_range: (f64, f64, f64),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::PyArray1;

    let series = data.as_slice()?;
    let grid = VossBatchRange {
        period: period_range,
        predict: predict_range,
        bandwidth: bandwidth_range,
    };

    // Output shape: one row per parameter combination, one column per input sample.
    let rows = expand_grid(&grid)
        .map_err(|e| PyValueError::new_err(e.to_string()))?
        .len();
    let cols = series.len();
    let cells = match rows.checked_mul(cols) {
        Some(n) => n,
        None => {
            return Err(PyValueError::new_err(
                "rows*cols overflow in voss_batch_py",
            ))
        }
    };

    // Flat output buffers owned by NumPy; the batch routine receives them as mutable slices.
    let voss_arr = unsafe { PyArray1::<f64>::new(py, [cells], false) };
    let filt_arr = unsafe { PyArray1::<f64>::new(py, [cells], false) };
    let voss_slice = unsafe { voss_arr.as_slice_mut()? };
    let filt_slice = unsafe { filt_arr.as_slice_mut()? };

    let kern = validate_kernel(kernel, true)?;

    let computed = py.allow_threads(|| {
        // Resolve `Auto`, then translate the batch kernel into its per-row counterpart.
        let batch_kernel = match kern {
            Kernel::Auto => detect_best_batch_kernel(),
            explicit => explicit,
        };
        let row_kernel = match batch_kernel {
            Kernel::Avx512Batch => Kernel::Avx512,
            Kernel::Avx2Batch => Kernel::Avx2,
            Kernel::ScalarBatch => Kernel::Scalar,
            _ => batch_kernel,
        };
        voss_batch_inner_into(series, &grid, row_kernel, true, voss_slice, filt_slice)
    });
    let combos = computed.map_err(|e| PyValueError::new_err(e.to_string()))?;

    // Per-row parameter columns, falling back to the defaults for unset fields.
    let periods: Vec<u64> = combos
        .iter()
        .map(|c| c.period.unwrap_or(20) as u64)
        .collect();
    let predicts: Vec<u64> = combos
        .iter()
        .map(|c| c.predict.unwrap_or(3) as u64)
        .collect();
    let bandwidths: Vec<f64> = combos
        .iter()
        .map(|c| c.bandwidth.unwrap_or(0.25))
        .collect();

    let dict = PyDict::new(py);
    dict.set_item("voss", voss_arr.reshape((rows, cols))?)?;
    dict.set_item("filt", filt_arr.reshape((rows, cols))?)?;
    dict.set_item("periods", periods.into_pyarray(py))?;
    dict.set_item("predicts", predicts.into_pyarray(py))?;
    dict.set_item("bandwidths", bandwidths.into_pyarray(py))?;

    Ok(dict)
}
2596
// WASM entry point: computes the indicator for `data` and returns a JS value serialized
// from a struct with the fields `voss` and `filt`.
// Computation and serialization failures are returned as string `JsValue` errors.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn voss_js(
    data: &[f64],
    period: usize,
    predict: usize,
    bandwidth: f64,
) -> Result<JsValue, JsValue> {
    // Serialization shape of the result handed back to JS.
    #[derive(Serialize)]
    struct Out {
        voss: Vec<f64>,
        filt: Vec<f64>,
    }

    let input = VossInput::from_slice(
        data,
        VossParams {
            period: Some(period),
            predict: Some(predict),
            bandwidth: Some(bandwidth),
        },
    );

    let computed = match voss_with_kernel(&input, detect_best_kernel()) {
        Ok(result) => result,
        Err(err) => return Err(JsValue::from_str(&err.to_string())),
    };

    let payload = Out {
        voss: computed.voss,
        filt: computed.filt,
    };
    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2626
/// Computes the indicator for `len` values read from `in_ptr` and writes the two result
/// series to `voss_ptr` and `filt_ptr`.
///
/// Caller contract: every pointer must be valid for `len` `f64` values (reads for
/// `in_ptr`, writes for `voss_ptr` and `filt_ptr`). The buffers may overlap: whenever any
/// two of them share memory the result is computed into temporary buffers first and then
/// copied out, `voss` before `filt` (so with identical output pointers the `filt` values
/// are what remains).
///
/// Returns an error string if any pointer is null or if the computation fails; in the
/// overlapping case the output buffers are left untouched on failure.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn voss_into(
    in_ptr: *const f64,
    voss_ptr: *mut f64,
    filt_ptr: *mut f64,
    len: usize,
    period: usize,
    predict: usize,
    bandwidth: f64,
) -> Result<(), JsValue> {
    // True when two `len`-element f64 buffers share at least one byte.
    // Zero-length buffers never overlap.
    fn buffers_overlap(a: *const f64, b: *const f64, len: usize) -> bool {
        let span = len.saturating_mul(std::mem::size_of::<f64>());
        let (a, b) = (a as usize, b as usize);
        a < b.saturating_add(span) && b < a.saturating_add(span)
    }

    if in_ptr.is_null() || voss_ptr.is_null() || filt_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // Range-based checks: exact pointer equality would miss partially overlapping buffers.
    let in_aliased =
        buffers_overlap(in_ptr, voss_ptr, len) || buffers_overlap(in_ptr, filt_ptr, len);
    let out_aliased = buffers_overlap(voss_ptr, filt_ptr, len);

    // SAFETY: the pointers are non-null and, per the caller contract, valid for `len`
    // elements. A shared and a mutable slice (or two mutable slices) over the same memory
    // are never alive at the same time: the direct path is only taken when all three
    // ranges are disjoint, and the fallback path materialises one output slice at a time
    // after the last use of the input slice.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let params = VossParams {
            period: Some(period),
            predict: Some(predict),
            bandwidth: Some(bandwidth),
        };
        let input = VossInput::from_slice(data, params);

        if in_aliased || out_aliased {
            let mut temp_voss = vec![0.0; len];
            let mut temp_filt = vec![0.0; len];
            voss_into_slice(&mut temp_voss, &mut temp_filt, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;

            // Write back sequentially so the two `&mut` slices never coexist, even when
            // `voss_ptr` and `filt_ptr` refer to the same memory.
            std::slice::from_raw_parts_mut(voss_ptr, len).copy_from_slice(&temp_voss);
            std::slice::from_raw_parts_mut(filt_ptr, len).copy_from_slice(&temp_filt);
        } else {
            let voss_out = std::slice::from_raw_parts_mut(voss_ptr, len);
            let filt_out = std::slice::from_raw_parts_mut(filt_ptr, len);
            voss_into_slice(voss_out, filt_out, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(())
}
2673
// Allocates room for `len` f64 values and returns the raw pointer without freeing it;
// ownership passes to the caller, and `voss_free` rebuilds a `Vec` from such a pointer
// and length in order to release it.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn voss_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` suppresses the Vec's destructor, so the allocation outlives this call.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
2682
/// Releases a buffer previously obtained from `voss_alloc`.
///
/// Caller contract: `ptr` must come from `voss_alloc(len)` with the same `len` and must
/// not be used or freed again afterwards. A null `ptr` is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn voss_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }

    // SAFETY: per the caller contract the allocation was created by a `Vec<f64>` with
    // capacity `len`. The vector is rebuilt with length 0 because the buffer contents may
    // never have been initialised; only the capacity matters for deallocation.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
2692
/// Sweep configuration deserialized from the `config` argument of the JS `voss_batch`
/// function. Each field is copied unchanged into the `VossBatchRange` field it corresponds
/// to (`period`, `predict`, `bandwidth`).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct VossBatchConfig {
    /// Period sweep, passed on as `VossBatchRange::period` (a `(start, end, step)` triple).
    pub period_range: (usize, usize, usize),
    /// Predict sweep, passed on as `VossBatchRange::predict` (a `(start, end, step)` triple).
    pub predict_range: (usize, usize, usize),
    /// Bandwidth sweep, passed on as `VossBatchRange::bandwidth` (a `(start, end, step)` triple).
    pub bandwidth_range: (f64, f64, f64),
}
2700
/// Serializable result returned to JS by `voss_batch`; every field is copied from the
/// batch output field of the same name.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct VossBatchJsOutput {
    /// Flat `voss` values for all rows (presumably row-major, `rows * cols` entries).
    pub voss: Vec<f64>,
    /// Flat `filt` values for all rows (presumably row-major, `rows * cols` entries).
    pub filt: Vec<f64>,
    /// Parameter combinations of the batch run (presumably one per row).
    pub combos: Vec<VossParams>,
    /// Row count reported by the batch run.
    pub rows: usize,
    /// Column count reported by the batch run.
    pub cols: usize,
}
2710
// WASM entry point exported to JS as `voss_batch`: deserializes `config` into a
// `VossBatchConfig`, runs the batch computation over `data` and returns the result
// serialized from a `VossBatchJsOutput`.
// Invalid config, computation and serialization failures come back as string errors.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "voss_batch")]
pub fn voss_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let parsed: Result<VossBatchConfig, _> = serde_wasm_bindgen::from_value(config);
    let VossBatchConfig {
        period_range,
        predict_range,
        bandwidth_range,
    } = parsed.map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = VossBatchRange {
        period: period_range,
        predict: predict_range,
        bandwidth: bandwidth_range,
    };

    let batch = match voss_batch_inner(data, &sweep, detect_best_kernel(), false) {
        Ok(result) => result,
        Err(err) => return Err(JsValue::from_str(&err.to_string())),
    };

    let payload = VossBatchJsOutput {
        voss: batch.voss,
        filt: batch.filt,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };

    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2737
// Expands the three ranges into parameter combinations and returns them flattened as
// `[period, predict, bandwidth, period, predict, bandwidth, ...]`, all as f64.
// Unset fields fall back to 20, 3 and 0.25 respectively; an expansion failure is returned
// as a string error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn voss_batch_metadata_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
    predict_start: usize,
    predict_end: usize,
    predict_step: usize,
    bandwidth_start: f64,
    bandwidth_end: f64,
    bandwidth_step: f64,
) -> Result<Vec<f64>, JsValue> {
    let grid = VossBatchRange {
        period: (period_start, period_end, period_step),
        predict: (predict_start, predict_end, predict_step),
        bandwidth: (bandwidth_start, bandwidth_end, bandwidth_step),
    };

    let combos = match expand_grid(&grid) {
        Ok(list) => list,
        Err(err) => return Err(JsValue::from_str(&err.to_string())),
    };

    // Three values per combination; the capacity falls back to 0 if the size would overflow.
    let mut metadata = Vec::with_capacity(combos.len().checked_mul(3).unwrap_or(0));
    metadata.extend(combos.into_iter().flat_map(|combo| {
        [
            combo.period.unwrap_or(20) as f64,
            combo.predict.unwrap_or(3) as f64,
            combo.bandwidth.unwrap_or(0.25),
        ]
    }));

    Ok(metadata)
}
2769
/// Runs the batch computation over `len` values read from `in_ptr` and copies the flat
/// `voss` and `filt` matrices to `voss_ptr` and `filt_ptr`.
///
/// Caller contract: `in_ptr` must be valid for reading `len` `f64` values, and each output
/// pointer must be valid for writing `rows * cols` `f64` values, where `rows` and `cols`
/// are the dimensions reported by the batch run (note: this is not `len`). The outputs are
/// written one after the other, `voss` first, so output buffers that share memory do not
/// cause aliasing mutable slices; the later `filt` copy wins where they overlap.
///
/// Returns an error string if any pointer is null, if the computation fails, or if
/// `rows * cols` overflows.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn voss_batch_into(
    in_ptr: *const f64,
    voss_ptr: *mut f64,
    filt_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
    predict_start: usize,
    predict_end: usize,
    predict_step: usize,
    bandwidth_start: f64,
    bandwidth_end: f64,
    bandwidth_step: f64,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || voss_ptr.is_null() || filt_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    let sweep = VossBatchRange {
        period: (period_start, period_end, period_step),
        predict: (predict_start, predict_end, predict_step),
        bandwidth: (bandwidth_start, bandwidth_end, bandwidth_step),
    };

    // The result is an owned value, so the input slice is no longer needed once this
    // block ends.
    let output = {
        // SAFETY: `in_ptr` is non-null and, per the caller contract, valid for reading
        // `len` values; the slice does not outlive this block.
        let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };
        voss_batch_par_slice(data, &sweep, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?
    };

    let total_size = output
        .rows
        .checked_mul(output.cols)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow in voss_batch_into"))?;

    if total_size > 0 {
        // SAFETY: both output pointers are non-null and, per the caller contract, valid
        // for writing `total_size` values. Each mutable slice is a temporary that is gone
        // before the next one is created, so no two `&mut` slices are alive at once even
        // if the caller passed overlapping output buffers.
        unsafe {
            std::slice::from_raw_parts_mut(voss_ptr, total_size).copy_from_slice(&output.voss);
            std::slice::from_raw_parts_mut(filt_ptr, total_size).copy_from_slice(&output.filt);
        }
    }

    Ok(())
}