1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::cuda_available;
3#[cfg(all(feature = "python", feature = "cuda"))]
4use crate::cuda::moving_averages::gaussian_wrapper::DeviceArrayF32Py;
5#[cfg(all(feature = "python", feature = "cuda"))]
6use crate::cuda::moving_averages::CudaGaussian;
7use crate::utilities::data_loader::{source_type, Candles};
8use crate::utilities::enums::Kernel;
9use crate::utilities::helpers::{
10 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
11 make_uninit_matrix,
12};
13#[cfg(not(target_arch = "wasm32"))]
14use rayon::prelude::*;
15#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
16use std::arch::x86_64::*;
17use std::f64::consts::PI;
18use std::mem::MaybeUninit;
19use thiserror::Error;
20
21const LANES_AVX512: usize = 8;
22const LANES_AVX2: usize = 4;
23
24impl<'a> AsRef<[f64]> for GaussianInput<'a> {
25 #[inline(always)]
26 fn as_ref(&self) -> &[f64] {
27 match &self.data {
28 GaussianData::Slice(slice) => slice,
29 GaussianData::Candles { candles, source } => source_type(candles, source),
30 }
31 }
32}
33
#[cfg(test)]
mod tests_into {
    use super::*;

    /// `gaussian_into` must write exactly what the allocating `gaussian` API returns
    /// (NaN positions included) for a series that starts with a run of NaNs.
    #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
    #[test]
    fn test_gaussian_into_matches_api() -> Result<(), Box<dyn std::error::Error>> {
        const TOTAL: usize = 256;
        const LEADING_NAN: usize = 8;

        // 8 NaNs followed by a smooth ramp-plus-sine signal.
        let data: Vec<f64> = std::iter::repeat(f64::NAN)
            .take(LEADING_NAN)
            .chain((0..(TOTAL - LEADING_NAN)).map(|i| {
                let x = i as f64;
                x * 0.1 + (x * 0.07).sin()
            }))
            .collect();

        let input = GaussianInput::from_slice(&data, GaussianParams::default());

        let base = gaussian(&input)?.values;

        let mut out = vec![0.0f64; data.len()];
        gaussian_into(&input, &mut out)?;

        assert_eq!(base.len(), out.len());

        // NaN never compares equal to itself, so treat a NaN pair as a match.
        let same = |a: f64, b: f64| (a.is_nan() && b.is_nan()) || (a - b).abs() <= 1e-12;

        for (i, (&expected, &actual)) in base.iter().zip(out.iter()).enumerate() {
            assert!(
                same(expected, actual),
                "mismatch at idx {}: base={}, out={}",
                i,
                expected,
                actual
            );
        }
        Ok(())
    }
}
76
77#[derive(Debug, Clone)]
78pub enum GaussianData<'a> {
79 Candles {
80 candles: &'a Candles,
81 source: &'a str,
82 },
83 Slice(&'a [f64]),
84}
85
/// Result of a single-series Gaussian filter run.
#[derive(Clone, Debug)]
pub struct GaussianOutput {
    /// Filtered series; the producers in this file size it to the input length.
    pub values: Vec<f64>,
}
90
/// Tunable parameters of the Gaussian filter.
///
/// A `None` field means "use the default"; the accessors in this file fall back
/// to a period of 14 and 4 poles.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct GaussianParams {
    /// Filter period in samples.
    pub period: Option<usize>,
    /// Number of filter poles (the validators in this file accept 1 through 4).
    pub poles: Option<usize>,
}
100
101impl Default for GaussianParams {
102 fn default() -> Self {
103 Self {
104 period: Some(14),
105 poles: Some(4),
106 }
107 }
108}
109
110#[derive(Debug, Clone)]
111pub struct GaussianInput<'a> {
112 pub data: GaussianData<'a>,
113 pub params: GaussianParams,
114}
115
116impl<'a> GaussianInput<'a> {
117 pub fn from_candles(c: &'a Candles, s: &'a str, p: GaussianParams) -> Self {
118 Self {
119 data: GaussianData::Candles {
120 candles: c,
121 source: s,
122 },
123 params: p,
124 }
125 }
126 pub fn from_slice(sl: &'a [f64], p: GaussianParams) -> Self {
127 Self {
128 data: GaussianData::Slice(sl),
129 params: p,
130 }
131 }
132 pub fn with_default_candles(c: &'a Candles) -> Self {
133 Self::from_candles(c, "close", GaussianParams::default())
134 }
135 pub fn get_period(&self) -> usize {
136 self.params.period.unwrap_or(14)
137 }
138 pub fn get_poles(&self) -> usize {
139 self.params.poles.unwrap_or(4)
140 }
141}
142
143#[derive(Copy, Clone, Debug)]
144pub struct GaussianBuilder {
145 period: Option<usize>,
146 poles: Option<usize>,
147 kernel: Kernel,
148}
149
150impl Default for GaussianBuilder {
151 fn default() -> Self {
152 Self {
153 period: None,
154 poles: None,
155 kernel: Kernel::Auto,
156 }
157 }
158}
159
160impl GaussianBuilder {
161 pub fn new() -> Self {
162 Self::default()
163 }
164 pub fn period(mut self, n: usize) -> Self {
165 self.period = Some(n);
166 self
167 }
168 pub fn poles(mut self, k: usize) -> Self {
169 self.poles = Some(k);
170 self
171 }
172 pub fn kernel(mut self, k: Kernel) -> Self {
173 self.kernel = k;
174 self
175 }
176 pub fn apply(self, c: &Candles) -> Result<GaussianOutput, GaussianError> {
177 let p = GaussianParams {
178 period: self.period,
179 poles: self.poles,
180 };
181 let i = GaussianInput::from_candles(c, "close", p);
182 gaussian_with_kernel(&i, self.kernel)
183 }
184 pub fn apply_slice(self, d: &[f64]) -> Result<GaussianOutput, GaussianError> {
185 let p = GaussianParams {
186 period: self.period,
187 poles: self.poles,
188 };
189 let i = GaussianInput::from_slice(d, p);
190 gaussian_with_kernel(&i, self.kernel)
191 }
192 pub fn into_stream(self) -> Result<GaussianStream, GaussianError> {
193 let p = GaussianParams {
194 period: self.period,
195 poles: self.poles,
196 };
197 GaussianStream::try_new(p)
198 }
199}
200
201#[derive(Debug, Error)]
202pub enum GaussianError {
203 #[error("gaussian: Input data slice is empty.")]
204 EmptyInputData,
205 #[error("gaussian: No data provided to Gaussian filter.")]
206 NoData,
207 #[error("gaussian: Invalid period: period = {period}, data length = {data_len}")]
208 InvalidPeriod { period: usize, data_len: usize },
209 #[error("gaussian: Invalid number of poles: expected 1..4, got {poles}")]
210 InvalidPoles { poles: usize },
211 #[error("gaussian: Period is longer than the data. period={period}, data_len={data_len}")]
212 PeriodLongerThanData { period: usize, data_len: usize },
213 #[error("gaussian: All values are NaN.")]
214 AllValuesNaN,
215 #[error("gaussian: Not enough valid data: needed = {needed}, valid = {valid}")]
216 NotEnoughValidData { needed: usize, valid: usize },
217 #[error("gaussian: Period too small. Period must be >= 2 for meaningful Gaussian filtering. Got period={period}")]
218 DegeneratePeriod { period: usize },
219 #[error("gaussian: Period of 1 causes degenerate filter (alpha=0). This produces constant zero output. Use period >= 2.")]
220 PeriodOneDegenerate,
221 #[error("gaussian: output length mismatch: expected {expected}, got {got}")]
222 OutputLengthMismatch { expected: usize, got: usize },
223 #[error("gaussian: invalid range: start={start}, end={end}, step={step}")]
224 InvalidRange {
225 start: usize,
226 end: usize,
227 step: usize,
228 },
229 #[error("gaussian: invalid kernel for batch API: {0:?}")]
230 InvalidKernelForBatch(Kernel),
231 #[error("gaussian: size overflow (rows={rows}, cols={cols})")]
232 SizeOverflow { rows: usize, cols: usize },
233}
234
235#[inline]
236pub fn gaussian(input: &GaussianInput) -> Result<GaussianOutput, GaussianError> {
237 gaussian_with_kernel(input, Kernel::Auto)
238}
239
/// Reports whether the running CPU supports the x86 `fma` feature.
///
/// Always `false` on targets other than x86 / x86_64.
#[inline(always)]
fn has_fma() -> bool {
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    let detected = std::is_x86_feature_detected!("fma");

    #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
    let detected = false;

    detected
}
252
253#[inline(always)]
254pub fn gaussian_scalar(data: &[f64], period: usize, poles: usize, out: &mut [f64]) {
255 debug_assert_eq!(data.len(), out.len());
256
257 if has_fma() {
258 unsafe { gaussian_scalar_fma(data, period, poles, out) }
259 } else {
260 gaussian_scalar_fallback(data, period, poles, out)
261 }
262}
263
264#[cfg_attr(
265 any(target_arch = "x86", target_arch = "x86_64"),
266 target_feature(enable = "fma")
267)]
268unsafe fn gaussian_scalar_fma(data: &[f64], period: usize, poles: usize, out: &mut [f64]) {
269 use core::f64::consts::PI;
270
271 let beta = {
272 let num = 1.0 - (2.0 * PI / period as f64).cos();
273 let den = (2.0f64).powf(1.0 / poles as f64) - 1.0;
274 num / den
275 };
276 let alpha = {
277 let tmp = beta * beta + 2.0 * beta;
278 -beta + tmp.sqrt()
279 };
280
281 match poles {
282 1 => gaussian_poles1_fma(data, alpha, out),
283 2 => gaussian_poles2_fma(data, alpha, out),
284 3 => gaussian_poles3_fma(data, alpha, out),
285 4 => gaussian_poles4_fma(data, alpha, out),
286 _ => core::hint::unreachable_unchecked(),
287 }
288}
289
290fn gaussian_scalar_fallback(data: &[f64], period: usize, poles: usize, out: &mut [f64]) {
291 use core::f64::consts::PI;
292
293 let beta = {
294 let num = 1.0 - (2.0 * PI / period as f64).cos();
295 let den = (2.0f64).powf(1.0 / poles as f64) - 1.0;
296 num / den
297 };
298 let alpha = {
299 let tmp = beta * beta + 2.0 * beta;
300 -beta + tmp.sqrt()
301 };
302
303 unsafe {
304 match poles {
305 1 => gaussian_poles1_fma(data, alpha, out),
306 2 => gaussian_poles2_fma(data, alpha, out),
307 3 => gaussian_poles3_fma(data, alpha, out),
308 4 => gaussian_poles4_fma(data, alpha, out),
309 _ => core::hint::unreachable_unchecked(),
310 }
311 }
312}
313
314pub fn gaussian_with_kernel(
315 input: &GaussianInput,
316 kernel: Kernel,
317) -> Result<GaussianOutput, GaussianError> {
318 let data: &[f64] = match &input.data {
319 GaussianData::Candles { candles, source } => source_type(candles, source),
320 GaussianData::Slice(sl) => sl,
321 };
322
323 let len = data.len();
324 let period = input.get_period();
325 let poles = input.get_poles();
326
327 if len == 0 {
328 return Err(GaussianError::NoData);
329 }
330
331 if period == 1 {
332 return Err(GaussianError::PeriodOneDegenerate);
333 }
334
335 if period < 2 {
336 return Err(GaussianError::DegeneratePeriod { period });
337 }
338
339 if period > len {
340 return Err(GaussianError::PeriodLongerThanData {
341 period,
342 data_len: len,
343 });
344 }
345
346 if !(1..=4).contains(&poles) {
347 return Err(GaussianError::InvalidPoles { poles });
348 }
349
350 let first_valid = data
351 .iter()
352 .position(|x| !x.is_nan())
353 .ok_or(GaussianError::AllValuesNaN)?;
354 if len - first_valid < period {
355 return Err(GaussianError::NotEnoughValidData {
356 needed: period,
357 valid: len - first_valid,
358 });
359 }
360
361 let chosen = match kernel {
362 Kernel::Auto => Kernel::Scalar,
363 other => other,
364 };
365 let warm = first_valid + period;
366 let mut out = alloc_with_nan_prefix(len, warm);
367
368 unsafe {
369 match chosen {
370 Kernel::Scalar | Kernel::ScalarBatch => gaussian_scalar(data, period, poles, &mut out),
371
372 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
373 Kernel::Avx2 | Kernel::Avx2Batch => gaussian_avx2(data, period, poles, &mut out),
374 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
375 Kernel::Avx512 | Kernel::Avx512Batch => gaussian_avx512(data, period, poles, &mut out),
376 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
377 Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
378 gaussian_scalar(data, period, poles, &mut out)
379 }
380
381 Kernel::Auto => unreachable!(),
382 }
383 }
384
385 Ok(GaussianOutput { values: out })
386}
387
/// AVX2 entry point for the single-series filter.
///
/// Currently forwards to [`gaussian_scalar`] with the same arguments.
///
/// # Safety
///
/// Declared `unsafe` to match the other kernel entry points; the body only
/// calls the safe scalar implementation.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
#[allow(unused_variables)]
pub unsafe fn gaussian_avx2(data: &[f64], period: usize, poles: usize, out: &mut [f64]) {
    gaussian_scalar(data, period, poles, out)
}
394
/// AVX-512 entry point for the single-series filter.
///
/// Currently forwards to [`gaussian_scalar`] with the same arguments.
///
/// # Safety
///
/// Declared `unsafe` to match the other kernel entry points; the body only
/// calls the safe scalar implementation.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
#[allow(unused_variables)]
pub unsafe fn gaussian_avx512(data: &[f64], period: usize, poles: usize, out: &mut [f64]) {
    gaussian_scalar(data, period, poles, out)
}
401
/// One-pole recursion: `y[i] = alpha * x[i] + (1 - alpha) * y[i-1]`, starting
/// from zero state and evaluated with fused multiply-add.
///
/// Input and output are walked in lockstep, so no indexing is needed: if the
/// slices differ in length only the common prefix is processed.
///
/// # Safety
///
/// The body performs no unchecked memory access; the function stays `unsafe`
/// so existing call sites remain unchanged.
#[inline(always)]
unsafe fn gaussian_poles1_fma(inp: &[f64], alpha: f64, out: &mut [f64]) {
    let c0 = alpha;
    let c1 = 1.0 - alpha;

    let mut prev = 0.0_f64;
    for (dst, &x) in out.iter_mut().zip(inp) {
        prev = c1.mul_add(prev, c0 * x);
        *dst = prev;
    }
}
414
/// Two-pole recursion:
/// `y[i] = alpha^2 * x[i] + 2(1-alpha) * y[i-1] - (1-alpha)^2 * y[i-2]`,
/// starting from zero state and evaluated with fused multiply-add.
///
/// Input and output are walked in lockstep, so no indexing is needed: if the
/// slices differ in length only the common prefix is processed.
///
/// # Safety
///
/// The body performs no unchecked memory access; the function stays `unsafe`
/// so existing call sites remain unchanged.
#[inline(always)]
unsafe fn gaussian_poles2_fma(inp: &[f64], alpha: f64, out: &mut [f64]) {
    let a2 = alpha * alpha;
    let one = 1.0 - alpha;
    let c0 = a2;
    let c1 = 2.0 * one;
    let c2 = -(one * one);

    // prev1 = y[i-1], prev0 = y[i-2]
    let mut prev1 = 0.0_f64;
    let mut prev0 = 0.0_f64;

    for (dst, &x) in out.iter_mut().zip(inp) {
        let y = c2.mul_add(prev0, c1.mul_add(prev1, c0 * x));
        prev0 = prev1;
        prev1 = y;
        *dst = y;
    }
}
434
/// Three-pole recursion:
/// `y[i] = alpha^3 * x[i] + 3(1-alpha) * y[i-1] - 3(1-alpha)^2 * y[i-2] + (1-alpha)^3 * y[i-3]`,
/// starting from zero state and evaluated with fused multiply-add.
///
/// Input and output are walked in lockstep, so no indexing is needed: if the
/// slices differ in length only the common prefix is processed.
///
/// # Safety
///
/// The body performs no unchecked memory access; the function stays `unsafe`
/// so existing call sites remain unchanged.
#[inline(always)]
unsafe fn gaussian_poles3_fma(inp: &[f64], alpha: f64, out: &mut [f64]) {
    let a3 = alpha * alpha * alpha;
    let one = 1.0 - alpha;
    let one2 = one * one;

    let c0 = a3;
    let c1 = 3.0 * one;
    let c2 = -3.0 * one2;
    let c3 = one2 * one;

    // p2 = y[i-1], p1 = y[i-2], p0 = y[i-3]
    let mut p2 = 0.0_f64;
    let mut p1 = 0.0_f64;
    let mut p0 = 0.0_f64;

    for (dst, &x) in out.iter_mut().zip(inp) {
        let y = c3.mul_add(p0, c2.mul_add(p1, c1.mul_add(p2, c0 * x)));
        p0 = p1;
        p1 = p2;
        p2 = y;
        *dst = y;
    }
}
459
/// Four-pole recursion:
/// `y[i] = alpha^4 * x[i] + 4(1-alpha) * y[i-1] - 6(1-alpha)^2 * y[i-2]
///        + 4(1-alpha)^3 * y[i-3] - (1-alpha)^4 * y[i-4]`,
/// starting from zero state and evaluated with fused multiply-add.
///
/// Input and output are walked in lockstep, so no indexing is needed: if the
/// slices differ in length only the common prefix is processed.
///
/// # Safety
///
/// The body performs no unchecked memory access; the function stays `unsafe`
/// so existing call sites remain unchanged.
#[inline(always)]
unsafe fn gaussian_poles4_fma(inp: &[f64], alpha: f64, out: &mut [f64]) {
    let a4 = alpha * alpha * alpha * alpha;
    let one = 1.0 - alpha;
    let one2 = one * one;
    let one3 = one2 * one;

    let c0 = a4;
    let c1 = 4.0 * one;
    let c2 = -6.0 * one2;
    let c3 = 4.0 * one3;
    let c4 = -(one3 * one);

    // p3 = y[i-1], p2 = y[i-2], p1 = y[i-3], p0 = y[i-4]
    let mut p3 = 0.0_f64;
    let mut p2 = 0.0_f64;
    let mut p1 = 0.0_f64;
    let mut p0 = 0.0_f64;

    for (dst, &x) in out.iter_mut().zip(inp) {
        let y = c4.mul_add(p0, c3.mul_add(p1, c2.mul_add(p2, c1.mul_add(p3, c0 * x))));
        p0 = p1;
        p1 = p2;
        p2 = p3;
        p3 = y;
        *dst = y;
    }
}
488
/// One-pole recursion over the first `n` samples of `data`, returned as a new vector.
///
/// `y[i] = alpha * x[i] + (1 - alpha) * y[i-1]` with zero initial state.
/// Panics if `n` exceeds `data.len()`.
#[inline(always)]
fn gaussian_poles1(data: &[f64], n: usize, alpha: f64) -> Vec<f64> {
    let gain = alpha;
    let feedback = 1.0 - alpha;

    let mut last = 0.0_f64;
    data[..n]
        .iter()
        .map(|&x| {
            last = gain * x + feedback * last;
            last
        })
        .collect()
}
/// Two-pole recursion over the first `n` samples of `data`, returned as a new vector.
///
/// `y[i] = alpha^2 * x[i] + 2(1-alpha) * y[i-1] - (1-alpha)^2 * y[i-2]` with
/// zero initial state. Panics if `n` exceeds `data.len()`.
#[inline(always)]
fn gaussian_poles2(data: &[f64], n: usize, alpha: f64) -> Vec<f64> {
    let rest = 1.0 - alpha;
    let k0 = alpha * alpha;
    let k1 = 2.0 * rest;
    let k2 = -(rest * rest);

    // y1 = y[i-1], y2 = y[i-2]
    let (mut y1, mut y2) = (0.0_f64, 0.0_f64);
    data[..n]
        .iter()
        .map(|&x| {
            let y = k0 * x + k1 * y1 + k2 * y2;
            y2 = y1;
            y1 = y;
            y
        })
        .collect()
}
/// Three-pole recursion over the first `n` samples of `data`, returned as a new vector.
///
/// `y[i] = alpha^3 * x[i] + 3(1-alpha) * y[i-1] - 3(1-alpha)^2 * y[i-2] + (1-alpha)^3 * y[i-3]`
/// with zero initial state. Panics if `n` exceeds `data.len()`.
#[inline(always)]
fn gaussian_poles3(data: &[f64], n: usize, alpha: f64) -> Vec<f64> {
    let rest = 1.0 - alpha;
    let rest_sq = rest * rest;
    let k0 = alpha * alpha * alpha;
    let k1 = 3.0 * rest;
    let k2 = -3.0 * rest_sq;
    let k3 = rest_sq * rest;

    // y1 = y[i-1], y2 = y[i-2], y3 = y[i-3]
    let (mut y1, mut y2, mut y3) = (0.0_f64, 0.0_f64, 0.0_f64);
    data[..n]
        .iter()
        .map(|&x| {
            let y = k0 * x + k1 * y1 + k2 * y2 + k3 * y3;
            y3 = y2;
            y2 = y1;
            y1 = y;
            y
        })
        .collect()
}
/// Four-pole recursion over the first `n` samples of `data`, returned as a new vector.
///
/// `y[i] = alpha^4 * x[i] + 4(1-alpha) * y[i-1] - 6(1-alpha)^2 * y[i-2]
///        + 4(1-alpha)^3 * y[i-3] - (1-alpha)^4 * y[i-4]`
/// with zero initial state. Panics if `n` exceeds `data.len()`.
#[inline(always)]
fn gaussian_poles4(data: &[f64], n: usize, alpha: f64) -> Vec<f64> {
    let rest = 1.0 - alpha;
    let rest_sq = rest * rest;
    let rest_cu = rest_sq * rest;
    let k0 = alpha * alpha * alpha * alpha;
    let k1 = 4.0 * rest;
    let k2 = -6.0 * rest_sq;
    let k3 = 4.0 * rest_cu;
    let k4 = -(rest_cu * rest);

    // y1 = y[i-1] ... y4 = y[i-4]
    let (mut y1, mut y2, mut y3, mut y4) = (0.0_f64, 0.0_f64, 0.0_f64, 0.0_f64);
    data[..n]
        .iter()
        .map(|&x| {
            let y = k0 * x + k1 * y1 + k2 * y2 + k3 * y3 + k4 * y4;
            y4 = y3;
            y3 = y2;
            y2 = y1;
            y1 = y;
            y
        })
        .collect()
}
545
/// Incremental (sample-at-a-time) Gaussian filter state.
#[derive(Clone, Debug)]
pub struct GaussianStream {
    /// Period the filter was built with.
    period: usize,
    /// Number of poles (1 through 4).
    poles: u8,
    /// Smoothing factor derived from `period` and `poles`.
    alpha: f64,
    /// Cached `1 - alpha`.
    one_minus: f64,

    /// Recursion coefficients: `c[0]` multiplies the input, `c[1..]` the previous outputs.
    c: [f64; 5],

    /// Previous outputs, most recent first.
    y: [f64; 4],

    /// Number of samples consumed so far (wrapping).
    idx: usize,
    /// Set to `true` at construction.
    init: bool,
}
560
561impl GaussianStream {
562 #[inline]
563 pub fn try_new(params: GaussianParams) -> Result<Self, GaussianError> {
564 let period = params.period.unwrap_or(14);
565 let poles = params.poles.unwrap_or(4);
566
567 if period == 1 {
568 return Err(GaussianError::PeriodOneDegenerate);
569 }
570 if period < 2 {
571 return Err(GaussianError::DegeneratePeriod { period });
572 }
573 if !(1..=4).contains(&poles) {
574 return Err(GaussianError::InvalidPoles { poles });
575 }
576
577 let inv_n = 1.0 / (period as f64);
578 let theta = core::f64::consts::PI * 2.0 * inv_n;
579 let one_minus_cos = {
580 let s = (0.5 * theta).sin();
581 2.0 * (s * s)
582 };
583
584 #[inline(always)]
585 fn denom_for_poles(k: usize) -> f64 {
586 match k {
587 1 => 1.0,
588 2 => core::f64::consts::SQRT_2 - 1.0,
589 3 => 0.259_921_049_894_873_2,
590 4 => 0.189_207_115_002_721_06,
591 _ => unreachable!(),
592 }
593 }
594 let beta = one_minus_cos / denom_for_poles(poles);
595
596 let alpha = if beta == 0.0 {
597 0.0
598 } else {
599 let r = (beta * beta + 2.0 * beta).sqrt();
600 (2.0 * beta) / (beta + r)
601 };
602
603 let one_minus = 1.0 - alpha;
604
605 let a = alpha;
606 let a2 = a * a;
607 let a3 = a2 * a;
608 let a4 = a2 * a2;
609 let o = one_minus;
610 let o2 = o * o;
611 let o3 = o2 * o;
612 let o4 = o2 * o2;
613
614 let mut c = [0.0; 5];
615 match poles {
616 1 => {
617 c[0] = a;
618 c[1] = o;
619 }
620 2 => {
621 c[0] = a2;
622 c[1] = 2.0 * o;
623 c[2] = -o2;
624 }
625 3 => {
626 c[0] = a3;
627 c[1] = 3.0 * o;
628 c[2] = -3.0 * o2;
629 c[3] = o3;
630 }
631 4 => {
632 c[0] = a4;
633 c[1] = 4.0 * o;
634 c[2] = -6.0 * o2;
635 c[3] = 4.0 * o3;
636 c[4] = -o4;
637 }
638 _ => unreachable!(),
639 }
640
641 Ok(Self {
642 period,
643 poles: poles as u8,
644 alpha,
645 one_minus,
646 c,
647 y: [0.0; 4],
648 idx: 0,
649 init: true,
650 })
651 }
652
653 #[inline(always)]
654 pub fn update(&mut self, x: f64) -> f64 {
655 let p = self.poles;
656 let c = self.c;
657 let mut y0 = self.y[0];
658 let mut y1 = self.y[1];
659 let mut y2 = self.y[2];
660 let mut y3 = self.y[3];
661
662 let y = match p {
663 1 => self.alpha.mul_add(x, self.one_minus * y0),
664 2 => c[2].mul_add(y1, c[1].mul_add(y0, c[0] * x)),
665 3 => c[3].mul_add(y2, c[2].mul_add(y1, c[1].mul_add(y0, c[0] * x))),
666 _ => c[4].mul_add(
667 y3,
668 c[3].mul_add(y2, c[2].mul_add(y1, c[1].mul_add(y0, c[0] * x))),
669 ),
670 };
671
672 self.y[3] = y2;
673 self.y[2] = y1;
674 self.y[1] = y0;
675 self.y[0] = y;
676
677 self.idx = self.idx.wrapping_add(1);
678 y
679 }
680
681 #[inline]
682 pub fn update_many(&mut self, xs: &[f64], out: &mut [f64]) {
683 debug_assert_eq!(xs.len(), out.len());
684 for (i, &x) in xs.iter().enumerate() {
685 out[i] = self.update(x);
686 }
687 }
688}
689
/// Parameter sweep description for the batch API.
///
/// Each axis is a `(start, end, step)` triple; a step of 0 or `start == end`
/// yields the single value `start`.
#[derive(Debug, Clone)]
pub struct GaussianBatchRange {
    /// Period axis as `(start, end, step)`.
    pub period: (usize, usize, usize),
    /// Pole-count axis as `(start, end, step)`.
    pub poles: (usize, usize, usize),
}
695
696impl Default for GaussianBatchRange {
697 fn default() -> Self {
698 Self {
699 period: (14, 263, 1),
700 poles: (4, 4, 0),
701 }
702 }
703}
704
705#[derive(Clone, Debug, Default)]
706pub struct GaussianBatchBuilder {
707 range: GaussianBatchRange,
708 kernel: Kernel,
709}
710
711impl GaussianBatchBuilder {
712 pub fn new() -> Self {
713 Self::default()
714 }
715 pub fn kernel(mut self, k: Kernel) -> Self {
716 self.kernel = k;
717 self
718 }
719 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
720 self.range.period = (start, end, step);
721 self
722 }
723 pub fn period_static(mut self, p: usize) -> Self {
724 self.range.period = (p, p, 0);
725 self
726 }
727 pub fn poles_range(mut self, start: usize, end: usize, step: usize) -> Self {
728 self.range.poles = (start, end, step);
729 self
730 }
731 pub fn poles_static(mut self, p: usize) -> Self {
732 self.range.poles = (p, p, 0);
733 self
734 }
735 pub fn apply_slice(self, data: &[f64]) -> Result<GaussianBatchOutput, GaussianError> {
736 gaussian_batch_with_kernel(data, &self.range, self.kernel)
737 }
738 pub fn apply_candles(
739 self,
740 c: &Candles,
741 src: &str,
742 ) -> Result<GaussianBatchOutput, GaussianError> {
743 let slice = source_type(c, src);
744 self.apply_slice(slice)
745 }
746 pub fn with_default_slice(
747 data: &[f64],
748 k: Kernel,
749 ) -> Result<GaussianBatchOutput, GaussianError> {
750 GaussianBatchBuilder::new().kernel(k).apply_slice(data)
751 }
752 pub fn with_default_candles(c: &Candles) -> Result<GaussianBatchOutput, GaussianError> {
753 GaussianBatchBuilder::new()
754 .kernel(Kernel::Auto)
755 .apply_candles(c, "close")
756 }
757}
758
759#[derive(Clone, Debug)]
760pub struct GaussianBatchOutput {
761 pub values: Vec<f64>,
762 pub combos: Vec<GaussianParams>,
763 pub rows: usize,
764 pub cols: usize,
765}
766
767impl GaussianBatchOutput {
768 pub fn row_for_params(&self, p: &GaussianParams) -> Option<usize> {
769 self.combos.iter().position(|c| {
770 c.period.unwrap_or(14) == p.period.unwrap_or(14)
771 && c.poles.unwrap_or(4) == p.poles.unwrap_or(4)
772 })
773 }
774 pub fn values_for(&self, p: &GaussianParams) -> Option<&[f64]> {
775 self.row_for_params(p).map(|row| {
776 let start = row * self.cols;
777 &self.values[start..start + self.cols]
778 })
779 }
780}
781
782#[inline(always)]
783fn expand_grid(r: &GaussianBatchRange) -> Result<Vec<GaussianParams>, GaussianError> {
784 #[inline(always)]
785 fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, GaussianError> {
786 if step == 0 || start == end {
787 return Ok(vec![start]);
788 }
789 let (lo, hi) = if start <= end {
790 (start, end)
791 } else {
792 (end, start)
793 };
794 let v: Vec<usize> = (lo..=hi).step_by(step).collect();
795 if v.is_empty() {
796 return Err(GaussianError::InvalidRange { start, end, step });
797 }
798 Ok(v)
799 }
800 let periods = axis_usize(r.period)?;
801 let poles = axis_usize(r.poles)?;
802 let mut out = Vec::with_capacity(periods.len().saturating_mul(poles.len()));
803 for &p in &periods {
804 for &k in &poles {
805 out.push(GaussianParams {
806 period: Some(p),
807 poles: Some(k),
808 });
809 }
810 }
811 Ok(out)
812}
813
814pub fn gaussian_batch_with_kernel(
815 data: &[f64],
816 sweep: &GaussianBatchRange,
817 k: Kernel,
818) -> Result<GaussianBatchOutput, GaussianError> {
819 let kernel = match k {
820 Kernel::Auto => match detect_best_batch_kernel() {
821 Kernel::Avx512Batch => Kernel::Avx2Batch,
822 other => other,
823 },
824 other if other.is_batch() => other,
825 other => return Err(GaussianError::InvalidKernelForBatch(other)),
826 };
827 let simd = match kernel {
828 Kernel::Avx512Batch => Kernel::Avx512,
829 Kernel::Avx2Batch => Kernel::Avx2,
830 Kernel::ScalarBatch => Kernel::Scalar,
831 _ => unreachable!(),
832 };
833 gaussian_batch_par_slice(data, sweep, simd)
834}
835
836#[inline(always)]
837pub fn gaussian_batch_slice(
838 data: &[f64],
839 sweep: &GaussianBatchRange,
840 kern: Kernel,
841) -> Result<GaussianBatchOutput, GaussianError> {
842 gaussian_batch_inner(data, sweep, kern, false)
843}
844
845#[inline(always)]
846pub fn gaussian_batch_par_slice(
847 data: &[f64],
848 sweep: &GaussianBatchRange,
849 kern: Kernel,
850) -> Result<GaussianBatchOutput, GaussianError> {
851 gaussian_batch_inner(data, sweep, kern, true)
852}
853
854#[inline]
855pub fn gaussian_into_slice(
856 dst: &mut [f64],
857 input: &GaussianInput,
858 kern: Kernel,
859) -> Result<(), GaussianError> {
860 let data = input.as_ref();
861
862 if dst.len() != data.len() {
863 return Err(GaussianError::OutputLengthMismatch {
864 expected: data.len(),
865 got: dst.len(),
866 });
867 }
868
869 gaussian_with_kernel_into(input, kern, dst)
870}
871
872#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
873#[inline]
874pub fn gaussian_into(input: &GaussianInput, out: &mut [f64]) -> Result<(), GaussianError> {
875 let data = input.as_ref();
876 if out.len() != data.len() {
877 return Err(GaussianError::OutputLengthMismatch {
878 expected: data.len(),
879 got: out.len(),
880 });
881 }
882
883 gaussian_with_kernel_into(input, Kernel::Auto, out)
884}
885
886#[inline(always)]
887fn gaussian_with_kernel_into(
888 input: &GaussianInput,
889 kernel: Kernel,
890 out: &mut [f64],
891) -> Result<(), GaussianError> {
892 let (data, period, poles, first, chosen) = gaussian_prepare(input, kernel)?;
893
894 out[..first + period].fill(f64::NAN);
895
896 gaussian_compute_into(data, period, poles, first, chosen, out);
897
898 Ok(())
899}
900
901#[inline(always)]
902fn gaussian_prepare<'a>(
903 input: &'a GaussianInput,
904 kernel: Kernel,
905) -> Result<(&'a [f64], usize, usize, usize, Kernel), GaussianError> {
906 let data: &[f64] = input.as_ref();
907 let len = data.len();
908
909 if len == 0 {
910 return Err(GaussianError::NoData);
911 }
912
913 let first = data
914 .iter()
915 .position(|x| !x.is_nan())
916 .ok_or(GaussianError::AllValuesNaN)?;
917
918 let period = input.params.period.unwrap_or(14);
919 let poles = input.params.poles.unwrap_or(4);
920
921 if period == 1 {
922 return Err(GaussianError::PeriodOneDegenerate);
923 }
924
925 if period < 2 {
926 return Err(GaussianError::DegeneratePeriod { period });
927 }
928
929 if period > len {
930 return Err(GaussianError::PeriodLongerThanData {
931 period,
932 data_len: len,
933 });
934 }
935
936 if !(1..=4).contains(&poles) {
937 return Err(GaussianError::InvalidPoles { poles });
938 }
939
940 if len - first < period {
941 return Err(GaussianError::NotEnoughValidData {
942 needed: period,
943 valid: len - first,
944 });
945 }
946
947 let chosen = match kernel {
948 Kernel::Auto => Kernel::Scalar,
949 other => other,
950 };
951
952 Ok((data, period, poles, first, chosen))
953}
954
955#[inline(always)]
956fn gaussian_compute_into(
957 data: &[f64],
958 period: usize,
959 poles: usize,
960 first: usize,
961 kernel: Kernel,
962 out: &mut [f64],
963) {
964 unsafe {
965 match kernel {
966 Kernel::Scalar | Kernel::ScalarBatch => gaussian_scalar(data, period, poles, out),
967 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
968 Kernel::Avx2 | Kernel::Avx2Batch => gaussian_avx2(data, period, poles, out),
969 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
970 Kernel::Avx512 | Kernel::Avx512Batch => gaussian_avx512(data, period, poles, out),
971 _ => unreachable!(),
972 }
973 }
974}
975
/// Smoothing factor for a Gaussian filter with the given period and pole count.
///
/// `beta = (1 - cos(2*pi / period)) / (2^(1/poles) - 1)` and
/// `alpha = -beta + sqrt(beta^2 + 2*beta)`.
#[inline(always)]
fn alpha_from(period: usize, poles: usize) -> f64 {
    let one_minus_cos = 1.0 - (2.0 * PI / period as f64).cos();
    let pole_scale = (2.0_f64).powf(1.0 / poles as f64) - 1.0;
    let beta = one_minus_cos / pole_scale;

    let radicand = beta * beta + 2.0 * beta;
    -beta + radicand.sqrt()
}
986
/// Computes eight parameter rows at once with AVX-512, one lane per row.
///
/// Each lane runs a cascade of up to four identical one-pole stages; a lane's
/// pole count decides how many stages are active, and the output comes from its
/// last active stage. Row `lane` is written to
/// `out_rows[lane * cols..lane * cols + data.len()]`.
///
/// # Safety
///
/// The CPU must support `avx512f` and `fma`. `params` is expected to hold
/// exactly `LANES_AVX512` entries and `out_rows` exactly `LANES_AVX512 * cols`
/// elements (both checked in debug builds only).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
#[inline]
unsafe fn gaussian_rows8_avx512(
    data: &[f64],
    params: &[GaussianParams],
    out_rows: &mut [f64],
    cols: usize,
) {
    debug_assert_eq!(params.len(), LANES_AVX512);
    debug_assert_eq!(out_rows.len(), LANES_AVX512 * cols);

    // Per-lane smoothing factor and pole count.
    let mut lane_alpha = [0.0f64; LANES_AVX512];
    let mut lane_poles = [0u32; LANES_AVX512];
    for (lane, prm) in params.iter().enumerate() {
        let period = prm.period.unwrap_or(14);
        let poles = prm.poles.unwrap_or(4);
        lane_alpha[lane] = alpha_from(period, poles);
        lane_poles[lane] = poles as u32;
    }

    let alpha_v = _mm512_loadu_pd(lane_alpha.as_ptr());
    let one_minus_v = _mm512_sub_pd(_mm512_set1_pd(1.0), alpha_v);

    // Bit `lane` is set when that lane has more than `stage` poles.
    let stage_mask = |stage: u32| -> __mmask8 {
        let bits = lane_poles
            .iter()
            .enumerate()
            .fold(0u8, |acc, (lane, &poles)| {
                if poles > stage {
                    acc | (1u8 << lane)
                } else {
                    acc
                }
            });
        bits as __mmask8
    };
    let m0 = stage_mask(0);
    let m1 = stage_mask(1);
    let m2 = stage_mask(2);
    let m3 = stage_mask(3);

    let mut s0 = _mm512_setzero_pd();
    let mut s1 = _mm512_setzero_pd();
    let mut s2 = _mm512_setzero_pd();
    let mut s3 = _mm512_setzero_pd();

    for (t, &sample) in data.iter().enumerate() {
        let x_v = _mm512_set1_pd(sample);

        // Each stage smooths the previous stage's state; inactive lanes keep theirs.
        let n0 = _mm512_fmadd_pd(alpha_v, x_v, _mm512_mul_pd(one_minus_v, s0));
        s0 = _mm512_mask_mov_pd(s0, m0, n0);

        let n1 = _mm512_fmadd_pd(alpha_v, s0, _mm512_mul_pd(one_minus_v, s1));
        s1 = _mm512_mask_mov_pd(s1, m1, n1);

        let n2 = _mm512_fmadd_pd(alpha_v, s1, _mm512_mul_pd(one_minus_v, s2));
        s2 = _mm512_mask_mov_pd(s2, m2, n2);

        let n3 = _mm512_fmadd_pd(alpha_v, s2, _mm512_mul_pd(one_minus_v, s3));
        s3 = _mm512_mask_mov_pd(s3, m3, n3);

        // Take each lane's deepest active stage.
        let mut result = s0;
        result = _mm512_mask_mov_pd(result, m1, s1);
        result = _mm512_mask_mov_pd(result, m2, s2);
        result = _mm512_mask_mov_pd(result, m3, s3);

        let mut lanes = [0.0f64; LANES_AVX512];
        _mm512_storeu_pd(lanes.as_mut_ptr(), result);
        for (lane, &value) in lanes.iter().enumerate() {
            out_rows[lane * cols + t] = value;
        }
    }
}
1057
/// Fills a tile of output rows using the AVX2 four-row kernel.
///
/// Rows are processed in groups of `LANES_AVX2`; leftover rows go through the
/// scalar row routine.
///
/// # Safety
///
/// The CPU must support `avx2` and `fma`. `out_mu` is reinterpreted as
/// `&mut [f64]` and must hold at least `combos.len() * cols` elements.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
#[inline]
unsafe fn gaussian_batch_tile_avx2(
    data: &[f64],
    combos: &[GaussianParams],
    out_mu: &mut [core::mem::MaybeUninit<f64>],
    cols: usize,
) {
    let out = core::slice::from_raw_parts_mut(out_mu.as_mut_ptr() as *mut f64, out_mu.len());

    let total = combos.len();
    let vectorised = (total / LANES_AVX2) * LANES_AVX2;

    for base in (0..vectorised).step_by(LANES_AVX2) {
        let end = base + LANES_AVX2;
        gaussian_rows4_avx2(
            data,
            &combos[base..end],
            &mut out[base * cols..end * cols],
            cols,
        );
    }

    for r in vectorised..total {
        gaussian_row_scalar(data, &combos[r], &mut out[r * cols..(r + 1) * cols]);
    }
}
1083
/// Computes four parameter rows at once with AVX2, one lane per row.
///
/// All four cascaded one-pole stages are advanced for every lane on every
/// sample; each lane's output is then taken from the stage matching its pole
/// count (1 uses the first stage, 2 the second, 3 the third, anything else the
/// fourth). Row `lane` is written to
/// `out_rows[lane * cols..lane * cols + data.len()]`.
///
/// # Safety
///
/// The CPU must support `avx2` and `fma`. `params` is expected to hold exactly
/// `LANES_AVX2` entries and `out_rows` exactly `LANES_AVX2 * cols` elements
/// (both checked in debug builds only).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
#[inline]
unsafe fn gaussian_rows4_avx2(
    data: &[f64],
    params: &[GaussianParams],
    out_rows: &mut [f64],
    cols: usize,
) {
    debug_assert_eq!(params.len(), LANES_AVX2);
    debug_assert_eq!(out_rows.len(), LANES_AVX2 * cols);

    // Per-lane smoothing factor and pole count.
    let mut lane_alpha = [0.0f64; LANES_AVX2];
    let mut lane_poles = [0u32; LANES_AVX2];
    for (lane, prm) in params.iter().enumerate() {
        let poles = prm.poles.unwrap_or(4);
        lane_alpha[lane] = alpha_from(prm.period.unwrap_or(14), poles);
        lane_poles[lane] = poles as u32;
    }
    let alpha_v = _mm256_loadu_pd(lane_alpha.as_ptr());
    let one_minus_v = _mm256_sub_pd(_mm256_set1_pd(1.0), alpha_v);

    let mut s0 = _mm256_setzero_pd();
    let mut s1 = _mm256_setzero_pd();
    let mut s2 = _mm256_setzero_pd();
    let mut s3 = _mm256_setzero_pd();

    // Scratch copies of each stage so single lanes can be picked out.
    let mut stage0 = [0.0; LANES_AVX2];
    let mut stage1 = [0.0; LANES_AVX2];
    let mut stage2 = [0.0; LANES_AVX2];
    let mut stage3 = [0.0; LANES_AVX2];

    for (t, &sample) in data.iter().enumerate() {
        let x_v = _mm256_set1_pd(sample);

        s0 = _mm256_fmadd_pd(alpha_v, x_v, _mm256_mul_pd(one_minus_v, s0));
        s1 = _mm256_fmadd_pd(alpha_v, s0, _mm256_mul_pd(one_minus_v, s1));
        s2 = _mm256_fmadd_pd(alpha_v, s1, _mm256_mul_pd(one_minus_v, s2));
        s3 = _mm256_fmadd_pd(alpha_v, s2, _mm256_mul_pd(one_minus_v, s3));

        _mm256_storeu_pd(stage0.as_mut_ptr(), s0);
        _mm256_storeu_pd(stage1.as_mut_ptr(), s1);
        _mm256_storeu_pd(stage2.as_mut_ptr(), s2);
        _mm256_storeu_pd(stage3.as_mut_ptr(), s3);

        for (lane, &poles) in lane_poles.iter().enumerate() {
            let picked = match poles {
                1 => stage0[lane],
                2 => stage1[lane],
                3 => stage2[lane],
                _ => stage3[lane],
            };
            out_rows[lane * cols + t] = picked;
        }
    }
}
1146
/// Fills a tile of output rows using the AVX-512 eight-row kernel.
///
/// Rows are processed in groups of `LANES_AVX512`; leftover rows go through
/// the scalar row routine.
///
/// # Safety
///
/// The CPU must support `avx512f` and `fma`. `out_mu` is reinterpreted as
/// `&mut [f64]` and must hold at least `combos.len() * cols` elements.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
#[inline]
unsafe fn gaussian_batch_tile_avx512(
    data: &[f64],
    combos: &[GaussianParams],
    out_mu: &mut [core::mem::MaybeUninit<f64>],
    cols: usize,
) {
    let out = core::slice::from_raw_parts_mut(out_mu.as_mut_ptr() as *mut f64, out_mu.len());

    let total = combos.len();
    let vectorised = (total / LANES_AVX512) * LANES_AVX512;

    for base in (0..vectorised).step_by(LANES_AVX512) {
        let end = base + LANES_AVX512;
        gaussian_rows8_avx512(
            data,
            &combos[base..end],
            &mut out[base * cols..end * cols],
            cols,
        );
    }

    for r in vectorised..total {
        gaussian_row_scalar(data, &combos[r], &mut out[r * cols..(r + 1) * cols]);
    }
}
1173
/// Runs the Gaussian filter once per parameter combination described by `sweep`.
///
/// The combinations come from `expand_grid(sweep)`. The result is a row-major
/// buffer with one row per combination (`rows`) and `data.len()` columns (`cols`),
/// returned together with the combinations.
///
/// * `data` - input series; the first non-NaN element marks the start of valid data.
/// * `sweep` - parameter ranges to expand into combinations.
/// * `kern` - requested kernel. Only `Kernel::Avx512` / `Kernel::Avx2` are honoured,
///   and only when the matching availability flag computed below is true; every
///   other value falls back to `Kernel::Scalar`.
/// * `parallel` - when `true`, rows are processed through rayon (on `wasm32` a
///   sequential loop is used instead).
///
/// # Errors
///
/// * `GaussianError::NoData` - `data` is empty or the sweep expands to nothing.
/// * `GaussianError::SizeOverflow` - `rows * cols` overflows `usize`.
/// * `GaussianError::AllValuesNaN` - `data` has no non-NaN element.
/// * `GaussianError::PeriodOneDegenerate` - a combination has `period == 1`.
/// * `GaussianError::DegeneratePeriod` - a combination has `period == 0`.
/// * `GaussianError::PeriodLongerThanData` - a period exceeds `data.len()`.
/// * `GaussianError::InvalidPoles` - `poles` is outside `1..=4`.
/// * `GaussianError::NotEnoughValidData` - fewer than `period` elements remain
///   from the first non-NaN element onward.
/// * Any error returned by `expand_grid`.
#[inline(always)]
fn gaussian_batch_inner(
    data: &[f64],
    sweep: &GaussianBatchRange,
    kern: Kernel,
    parallel: bool,
) -> Result<GaussianBatchOutput, GaussianError> {
    #[cfg(not(target_arch = "wasm32"))]
    use rayon::prelude::*;
    use std::{arch::is_x86_feature_detected, mem::MaybeUninit};

    let combos = expand_grid(sweep)?;
    if combos.is_empty() || data.is_empty() {
        return Err(GaussianError::NoData);
    }
    let cols = data.len();
    let rows = combos.len();

    // Only the overflow check matters here; the product itself is never read.
    let _total = rows
        .checked_mul(cols)
        .ok_or(GaussianError::SizeOverflow { rows, cols })?;

    // Index of the first non-NaN input element.
    let first_valid = data
        .iter()
        .position(|x| !x.is_nan())
        .ok_or(GaussianError::AllValuesNaN)?;

    // Validate every combination up front so no row is computed when any one is invalid.
    // Missing fields use the same defaults (period 14, poles 4) as `gaussian_row_scalar`.
    for c in &combos {
        let period = c.period.unwrap_or(14);
        let poles = c.poles.unwrap_or(4);

        if period == 1 {
            return Err(GaussianError::PeriodOneDegenerate);
        }

        // With `period == 1` handled above, only `period == 0` reaches this branch.
        if period < 2 {
            return Err(GaussianError::DegeneratePeriod { period });
        }

        if period > cols {
            return Err(GaussianError::PeriodLongerThanData {
                period,
                data_len: cols,
            });
        }
        if !(1..=4).contains(&poles) {
            return Err(GaussianError::InvalidPoles { poles });
        }
        if cols - first_valid < period {
            return Err(GaussianError::NotEnoughValidData {
                needed: period,
                valid: cols - first_valid,
            });
        }
    }

    // Per-row prefix length (`first_valid + period`, capped at `cols`) handed to
    // `init_matrix_prefixes`. Presumably that helper pre-fills this many leading
    // cells of each row before the kernels run — verify against the helper.
    let warm: Vec<usize> = combos
        .iter()
        .map(|c| {
            let p = c.period.unwrap_or(14);
            (first_valid + p).min(cols)
        })
        .collect();

    let mut raw = make_uninit_matrix(rows, cols);
    init_matrix_prefixes(&mut raw, cols, &warm);

    // NOTE(review): `cfg!(target_feature = ...)` is evaluated at compile time, so the
    // SIMD paths can only be chosen when the crate is built with that target feature
    // enabled; the runtime check alone is not sufficient. Confirm this is intended.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    let have_avx512 = cfg!(target_feature = "avx512f") && is_x86_feature_detected!("avx512f");
    #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
    let have_avx512 = false;

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    let have_avx2 = cfg!(target_feature = "avx2") && is_x86_feature_detected!("avx2");
    #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
    let have_avx2 = false;

    // Anything that is not an available AVX-512 / AVX2 request runs the scalar path.
    let chosen = match kern {
        Kernel::Avx512 if have_avx512 => Kernel::Avx512,
        Kernel::Avx2 if have_avx2 => Kernel::Avx2,
        _ => Kernel::Scalar,
    };

    // Signature shared by the single-row kernels.
    type RowRunner = unsafe fn(&[f64], &GaussianParams, &mut [f64]);

    // Single-row kernel used for rows that are not covered by a full SIMD tile.
    // The SIMD arms only exist when `nightly-avx` is enabled on x86_64.
    let row_runner: RowRunner = match chosen {
        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
        Kernel::Avx512 => gaussian_row_avx512,
        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
        Kernel::Avx2 => gaussian_row_avx2,
        _ => gaussian_row_scalar,
    };

    // Reinterprets one `MaybeUninit<f64>` row as `&mut [f64]` of length `cols` and
    // runs `runner` on it with the parameters of row `row_idx`.
    #[inline(always)]
    unsafe fn compute_row(
        row_idx: usize,
        dst_mu: &mut [MaybeUninit<f64>],
        combos: &[GaussianParams],
        data: &[f64],
        cols: usize,
        runner: RowRunner,
    ) {
        let out = std::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, cols);
        runner(data, &combos[row_idx], out);
    }

    if parallel {
        #[cfg(not(target_arch = "wasm32"))]
        {
            match chosen {
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                Kernel::Avx512 => {
                    // Number of complete 8-row tiles.
                    let tiles = rows / LANES_AVX512;

                    // Full tiles: each parallel task gets 8 rows and their 8 parameter sets.
                    raw.par_chunks_exact_mut(cols * LANES_AVX512)
                        .zip(combos.par_chunks_exact(LANES_AVX512))
                        .for_each(|(dst_blk, prm_blk)| unsafe {
                            gaussian_batch_tile_avx512(data, prm_blk, dst_blk, cols);
                        });

                    // Leftover rows after the last full tile, one row per task.
                    raw[tiles * cols * LANES_AVX512..]
                        .par_chunks_mut(cols)
                        .enumerate()
                        .for_each(|(i, dst)| unsafe {
                            compute_row(
                                tiles * LANES_AVX512 + i,
                                dst,
                                &combos,
                                data,
                                cols,
                                row_runner,
                            );
                        });
                }

                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                Kernel::Avx2 => {
                    // Number of complete 4-row tiles.
                    let tiles = rows / LANES_AVX2;

                    raw.par_chunks_exact_mut(cols * LANES_AVX2)
                        .zip(combos.par_chunks_exact(LANES_AVX2))
                        .for_each(|(dst_blk, prm_blk)| unsafe {
                            gaussian_batch_tile_avx2(data, prm_blk, dst_blk, cols);
                        });

                    // Leftover rows after the last full tile.
                    raw[tiles * cols * LANES_AVX2..]
                        .par_chunks_mut(cols)
                        .enumerate()
                        .for_each(|(i, dst)| unsafe {
                            compute_row(
                                tiles * LANES_AVX2 + i,
                                dst,
                                &combos,
                                data,
                                cols,
                                row_runner,
                            );
                        });
                }

                // Scalar (or SIMD arms compiled out): every row is its own task.
                _ => {
                    raw.par_chunks_mut(cols)
                        .enumerate()
                        .for_each(|(row, dst)| unsafe {
                            compute_row(row, dst, &combos, data, cols, row_runner);
                        });
                }
            }
        }
        // On wasm32 the `parallel` request is served by a sequential loop.
        #[cfg(target_arch = "wasm32")]
        {
            for (row, dst) in raw.chunks_mut(cols).enumerate() {
                unsafe {
                    compute_row(row, dst, &combos, data, cols, row_runner);
                }
            }
        }
    } else {
        match chosen {
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx512 => {
                let tiles = rows / LANES_AVX512;
                // All full tiles are processed by one call over the tiled prefix.
                unsafe {
                    gaussian_batch_tile_avx512(
                        data,
                        &combos[..tiles * LANES_AVX512],
                        &mut raw[..tiles * cols * LANES_AVX512],
                        cols,
                    );
                }
                // Leftover rows after the last full tile.
                for (i, dst) in raw[tiles * cols * LANES_AVX512..]
                    .chunks_mut(cols)
                    .enumerate()
                {
                    unsafe {
                        compute_row(
                            tiles * LANES_AVX512 + i,
                            dst,
                            &combos,
                            data,
                            cols,
                            row_runner,
                        );
                    }
                }
            }

            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx2 => {
                let tiles = rows / LANES_AVX2;
                unsafe {
                    gaussian_batch_tile_avx2(
                        data,
                        &combos[..tiles * LANES_AVX2],
                        &mut raw[..tiles * cols * LANES_AVX2],
                        cols,
                    );
                }
                // Leftover rows after the last full tile.
                for (i, dst) in raw[tiles * cols * LANES_AVX2..]
                    .chunks_mut(cols)
                    .enumerate()
                {
                    unsafe {
                        compute_row(tiles * LANES_AVX2 + i, dst, &combos, data, cols, row_runner);
                    }
                }
            }

            _ => {
                for (row, dst) in raw.chunks_mut(cols).enumerate() {
                    unsafe {
                        compute_row(row, dst, &combos, data, cols, row_runner);
                    }
                }
            }
        }
    }

    // Hand the allocation over as `Vec<f64>` without copying: `ManuallyDrop` keeps the
    // original vector from freeing the buffer that the new `Vec` now owns.
    let mut guard = core::mem::ManuallyDrop::new(raw);
    let values = unsafe {
        Vec::from_raw_parts(
            guard.as_mut_ptr() as *mut f64,
            guard.len(),
            guard.capacity(),
        )
    };
    Ok(GaussianBatchOutput {
        values,
        combos,
        rows,
        cols,
    })
}
1427
1428#[inline(always)]
1429pub unsafe fn gaussian_row_scalar(data: &[f64], prm: &GaussianParams, out: &mut [f64]) {
1430 gaussian_scalar(data, prm.period.unwrap_or(14), prm.poles.unwrap_or(4), out);
1431}
1432
/// Single-row entry point compiled with the `avx2` and `fma` target features.
///
/// The body forwards directly to `gaussian_row_scalar`; there is no hand-written
/// AVX2 code in this function.
///
/// # Safety
///
/// The caller must only invoke this on a CPU that supports `avx2` and `fma`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
#[inline]
pub unsafe fn gaussian_row_avx2(data: &[f64], prm: &GaussianParams, out: &mut [f64]) {
    gaussian_row_scalar(data, prm, out);
}
1439
/// Single-row entry point compiled with the `avx512f` and `fma` target features.
///
/// The body forwards directly to `gaussian_row_scalar`; there is no hand-written
/// AVX-512 code in this function.
///
/// # Safety
///
/// The caller must only invoke this on a CPU that supports `avx512f` and `fma`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
#[inline]
pub unsafe fn gaussian_row_avx512(data: &[f64], prm: &GaussianParams, out: &mut [f64]) {
    gaussian_row_scalar(data, prm, out);
}
1446
1447#[cfg(test)]
1448mod tests {
1449 use super::*;
1450 use crate::skip_if_unsupported;
1451 use crate::utilities::data_loader::read_candles_from_csv;
1452 use proptest::prelude::*;
1453
1454 fn check_gaussian_partial_params(
1455 test_name: &str,
1456 kernel: Kernel,
1457 ) -> Result<(), Box<dyn std::error::Error>> {
1458 skip_if_unsupported!(kernel, test_name);
1459 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1460 let candles = read_candles_from_csv(file_path)?;
1461 let default_params = GaussianParams {
1462 period: None,
1463 poles: None,
1464 };
1465 let input = GaussianInput::from_candles(&candles, "close", default_params);
1466 let output = gaussian_with_kernel(&input, kernel)?;
1467 assert_eq!(output.values.len(), candles.close.len());
1468 Ok(())
1469 }
1470
1471 fn check_gaussian_accuracy(
1472 test_name: &str,
1473 kernel: Kernel,
1474 ) -> Result<(), Box<dyn std::error::Error>> {
1475 skip_if_unsupported!(kernel, test_name);
1476 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1477 let candles = read_candles_from_csv(file_path)?;
1478 let params = GaussianParams {
1479 period: Some(14),
1480 poles: Some(4),
1481 };
1482 let input = GaussianInput::from_candles(&candles, "close", params);
1483 let result = gaussian_with_kernel(&input, kernel)?;
1484 let expected_last_five = [
1485 59221.90637814869,
1486 59236.15215167245,
1487 59207.10087088464,
1488 59178.48276885589,
1489 59085.36983209433,
1490 ];
1491 let start = result.values.len().saturating_sub(5);
1492 for (i, &val) in result.values[start..].iter().enumerate() {
1493 let diff = (val - expected_last_five[i]).abs();
1494 assert!(
1495 diff < 1e-4,
1496 "[{}] Gaussian {:?} mismatch at idx {}: got {}, expected {}",
1497 test_name,
1498 kernel,
1499 i,
1500 val,
1501 expected_last_five[i]
1502 );
1503 }
1504 Ok(())
1505 }
1506
1507 fn check_gaussian_default_candles(
1508 test_name: &str,
1509 kernel: Kernel,
1510 ) -> Result<(), Box<dyn std::error::Error>> {
1511 skip_if_unsupported!(kernel, test_name);
1512 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1513 let candles = read_candles_from_csv(file_path)?;
1514 let input = GaussianInput::with_default_candles(&candles);
1515 match input.data {
1516 GaussianData::Candles { source, .. } => assert_eq!(source, "close"),
1517 _ => panic!("Expected GaussianData::Candles"),
1518 }
1519 let output = gaussian_with_kernel(&input, kernel)?;
1520 assert_eq!(output.values.len(), candles.close.len());
1521 Ok(())
1522 }
1523
1524 fn check_gaussian_zero_period(
1525 test_name: &str,
1526 kernel: Kernel,
1527 ) -> Result<(), Box<dyn std::error::Error>> {
1528 skip_if_unsupported!(kernel, test_name);
1529 let data = [1.0, 2.0, 3.0];
1530 let params = GaussianParams {
1531 period: Some(0),
1532 poles: Some(2),
1533 };
1534 let input = GaussianInput::from_slice(&data, params);
1535 let res = gaussian_with_kernel(&input, kernel);
1536 assert!(
1537 matches!(res, Err(GaussianError::DegeneratePeriod { .. })),
1538 "[{test_name}] expected DegeneratePeriod error for period=0"
1539 );
1540 Ok(())
1541 }
1542
1543 fn check_gaussian_empty_input(
1544 test_name: &str,
1545 kernel: Kernel,
1546 ) -> Result<(), Box<dyn std::error::Error>> {
1547 skip_if_unsupported!(kernel, test_name);
1548 let empty: [f64; 0] = [];
1549 let input = GaussianInput::from_slice(&empty, GaussianParams::default());
1550 let res = gaussian_with_kernel(&input, kernel);
1551 assert!(
1552 matches!(res, Err(GaussianError::NoData)),
1553 "[{test_name}] expected NoData error"
1554 );
1555 Ok(())
1556 }
1557
1558 fn check_gaussian_invalid_poles(
1559 test_name: &str,
1560 kernel: Kernel,
1561 ) -> Result<(), Box<dyn std::error::Error>> {
1562 skip_if_unsupported!(kernel, test_name);
1563 let data = [1.0, 2.0, 3.0, 4.0];
1564 let params = GaussianParams {
1565 period: Some(2),
1566 poles: Some(5),
1567 };
1568 let input = GaussianInput::from_slice(&data, params);
1569 let res = gaussian_with_kernel(&input, kernel);
1570 assert!(
1571 matches!(res, Err(GaussianError::InvalidPoles { .. })),
1572 "[{test_name}] expected InvalidPoles error"
1573 );
1574 Ok(())
1575 }
1576
1577 fn check_gaussian_all_nan(
1578 test_name: &str,
1579 kernel: Kernel,
1580 ) -> Result<(), Box<dyn std::error::Error>> {
1581 skip_if_unsupported!(kernel, test_name);
1582 let data = [f64::NAN; 5];
1583 let params = GaussianParams {
1584 period: Some(3),
1585 poles: None,
1586 };
1587 let input = GaussianInput::from_slice(&data, params);
1588 let res = gaussian_with_kernel(&input, kernel);
1589 assert!(
1590 matches!(res, Err(GaussianError::AllValuesNaN)),
1591 "[{test_name}] expected AllValuesNaN error"
1592 );
1593 Ok(())
1594 }
1595
1596 fn check_gaussian_period_exceeds_length(
1597 test_name: &str,
1598 kernel: Kernel,
1599 ) -> Result<(), Box<dyn std::error::Error>> {
1600 skip_if_unsupported!(kernel, test_name);
1601 let data_small = [10.0, 20.0, 30.0];
1602 let params = GaussianParams {
1603 period: Some(10),
1604 poles: None,
1605 };
1606 let input = GaussianInput::from_slice(&data_small, params);
1607 let res = gaussian_with_kernel(&input, kernel);
1608 assert!(
1609 res.is_err(),
1610 "[{}] Gaussian should fail with period exceeding length",
1611 test_name
1612 );
1613 Ok(())
1614 }
1615
1616 fn check_gaussian_very_small_dataset(
1617 test_name: &str,
1618 kernel: Kernel,
1619 ) -> Result<(), Box<dyn std::error::Error>> {
1620 skip_if_unsupported!(kernel, test_name);
1621 let single_point = [42.0];
1622 let params = GaussianParams {
1623 period: Some(14),
1624 poles: None,
1625 };
1626 let input = GaussianInput::from_slice(&single_point, params);
1627 let res = gaussian_with_kernel(&input, kernel);
1628 assert!(
1629 res.is_err(),
1630 "[{}] Gaussian should fail with insufficient data",
1631 test_name
1632 );
1633 Ok(())
1634 }
1635
1636 fn check_gaussian_reinput(
1637 test_name: &str,
1638 kernel: Kernel,
1639 ) -> Result<(), Box<dyn std::error::Error>> {
1640 skip_if_unsupported!(kernel, test_name);
1641 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1642 let candles = read_candles_from_csv(file_path)?;
1643 let first_params = GaussianParams {
1644 period: Some(14),
1645 poles: Some(4),
1646 };
1647 let first_input = GaussianInput::from_candles(&candles, "close", first_params);
1648 let first_result = gaussian_with_kernel(&first_input, kernel)?;
1649 assert_eq!(first_result.values.len(), candles.close.len());
1650 let second_params = GaussianParams {
1651 period: Some(7),
1652 poles: Some(2),
1653 };
1654 let second_input = GaussianInput::from_slice(&first_result.values, second_params);
1655 let second_result = gaussian_with_kernel(&second_input, kernel)?;
1656 assert_eq!(second_result.values.len(), first_result.values.len());
1657
1658 for i in 10..second_result.values.len() {
1659 assert!(
1660 !second_result.values[i].is_nan(),
1661 "NaN found at index {}",
1662 i
1663 );
1664 }
1665 Ok(())
1666 }
1667
1668 fn check_gaussian_nan_handling(
1669 test_name: &str,
1670 kernel: Kernel,
1671 ) -> Result<(), Box<dyn std::error::Error>> {
1672 skip_if_unsupported!(kernel, test_name);
1673 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1674 let candles = read_candles_from_csv(file_path)?;
1675 let input = GaussianInput::from_candles(&candles, "close", GaussianParams::default());
1676 let res = gaussian_with_kernel(&input, kernel)?;
1677 assert_eq!(res.values.len(), candles.close.len());
1678
1679 let skip = input.params.poles.unwrap_or(4);
1680 for val in res.values.iter().skip(skip) {
1681 assert!(
1682 val.is_finite(),
1683 "[{}] Gaussian output should be finite once settled.",
1684 test_name
1685 );
1686 }
1687 Ok(())
1688 }
1689
1690 fn check_gaussian_streaming(
1691 test_name: &str,
1692 kernel: Kernel,
1693 ) -> Result<(), Box<dyn std::error::Error>> {
1694 skip_if_unsupported!(kernel, test_name);
1695 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1696 let candles = read_candles_from_csv(file_path)?;
1697 let period = 14;
1698 let poles = 4;
1699 let input = GaussianInput::from_candles(
1700 &candles,
1701 "close",
1702 GaussianParams {
1703 period: Some(period),
1704 poles: Some(poles),
1705 },
1706 );
1707 let batch_output = gaussian_with_kernel(&input, kernel)?.values;
1708 let mut stream = GaussianStream::try_new(GaussianParams {
1709 period: Some(period),
1710 poles: Some(poles),
1711 })?;
1712 let mut stream_values = Vec::with_capacity(candles.close.len());
1713 for &price in &candles.close {
1714 stream_values.push(stream.update(price));
1715 }
1716 assert_eq!(batch_output.len(), stream_values.len());
1717 let skip = poles;
1718 for (i, (&b, &s)) in batch_output
1719 .iter()
1720 .zip(stream_values.iter())
1721 .enumerate()
1722 .skip(skip)
1723 {
1724 if b.is_nan() && s.is_nan() {
1725 continue;
1726 }
1727 let diff = (b - s).abs();
1728 assert!(
1729 diff < 1e-9,
1730 "[{}] Gaussian streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
1731 test_name,
1732 i,
1733 b,
1734 s,
1735 diff
1736 );
1737 }
1738 Ok(())
1739 }
1740
    /// Property-based test of the filter under `kernel`.
    ///
    /// Inputs are generated with a period in `2..=50`, between `period` and 399
    /// finite values in `[-1e6, 1e6)`, and a pole count in `1..=4`. For each case
    /// the test checks: output length, NaN/finite placement, variance not
    /// exceeding the input's in a post-warm-up window, constant input producing
    /// constant output, agreement with the scalar kernel, and pole-specific
    /// smoothness expectations.
    fn check_gaussian_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);

        // (data, period, poles): the data vector is always at least `period` long.
        let strat = (2usize..=50).prop_flat_map(|period| {
            (
                prop::collection::vec(
                    (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
                    period..400,
                ),
                Just(period),
                1usize..=4,
            )
        });

        proptest::test_runner::TestRunner::default()
            .run(&strat, |(data, period, poles)| {
                let params = GaussianParams {
                    period: Some(period),
                    poles: Some(poles),
                };
                let input = GaussianInput::from_slice(&data, params);

                // Output under the kernel being tested, plus the scalar reference.
                let GaussianOutput { values: out } = gaussian_with_kernel(&input, kernel).unwrap();
                let GaussianOutput { values: ref_out } =
                    gaussian_with_kernel(&input, Kernel::Scalar).unwrap();

                prop_assert_eq!(out.len(), data.len());

                let first_valid = data.iter().position(|x| !x.is_nan()).unwrap_or(data.len());

                let expected_warmup = first_valid + period;

                // Outputs before the first non-NaN input must be NaN.
                for i in 0..first_valid {
                    prop_assert!(
                        out[i].is_nan(),
                        "idx {}: expected NaN for NaN input, got {}",
                        i,
                        out[i]
                    );
                }

                // From the first valid input on, finite inputs with no NaN in
                // between must give finite outputs.
                for i in first_valid..data.len() {
                    if data[i].is_finite() && !data[first_valid..=i].iter().any(|x| x.is_nan()) {
                        prop_assert!(
                            out[i].is_finite(),
                            "idx {}: expected finite value, got {}",
                            i,
                            out[i]
                        );
                    }
                }

                // Variance check on a window of up to 50 samples starting two
                // periods after the first valid input.
                let stability_point = first_valid + period * 2;
                if period > 1 && stability_point + 20 < data.len() {
                    let window_start = stability_point;
                    let window_end = (stability_point + 50).min(data.len());

                    let input_slice = &data[window_start..window_end];
                    let output_slice = &out[window_start..window_end];

                    if input_slice.len() > 10
                        && input_slice.iter().all(|x| x.is_finite())
                        && output_slice.iter().all(|x| x.is_finite())
                    {
                        let input_mean: f64 =
                            input_slice.iter().sum::<f64>() / input_slice.len() as f64;
                        let output_mean: f64 =
                            output_slice.iter().sum::<f64>() / output_slice.len() as f64;

                        // Population variance (divides by the sample count).
                        let input_var: f64 = input_slice
                            .iter()
                            .map(|x| (x - input_mean).powi(2))
                            .sum::<f64>()
                            / input_slice.len() as f64;
                        let output_var: f64 = output_slice
                            .iter()
                            .map(|x| (x - output_mean).powi(2))
                            .sum::<f64>()
                            / output_slice.len() as f64;

                        if input_var > 1e-10 {
                            prop_assert!(
                                output_var <= input_var + 1e-15,
                                "Gaussian filter MUST reduce variance: input_var={}, output_var={}, period={}, poles={}",
                                input_var, output_var, period, poles
                            );
                        }
                    }
                }

                // Guard against the strategy drifting: it only generates period >= 2.
                if period < 2 {
                    prop_assert!(
                        false,
                        "Test should not generate period < 2, but got period={}",
                        period
                    );
                }

                // Constant input (adjacent differences below 1e-10): after three
                // periods every finite output must equal that constant.
                let stability_check = first_valid + period * 3;
                if data[first_valid..]
                    .windows(2)
                    .all(|w| (w[0] - w[1]).abs() < 1e-10)
                    && stability_check < data.len()
                {
                    let constant_val = data[first_valid];
                    for i in stability_check..data.len() {
                        if out[i].is_finite() {
                            prop_assert!(
                                (out[i] - constant_val).abs() <= 1e-12,
                                "constant input MUST produce constant output: idx={}, expected={}, got={}, diff={}",
                                i, constant_val, out[i], (out[i] - constant_val).abs()
                            );
                        }
                    }
                }

                // Non-scalar kernels must agree with the scalar reference, either
                // within 10 units of bit-pattern distance or within 1e-14 absolute.
                if kernel != Kernel::Scalar {
                    for i in first_valid..data.len() {
                        if out[i].is_finite() && ref_out[i].is_finite() {
                            let y_bits = out[i].to_bits();
                            let r_bits = ref_out[i].to_bits();
                            let diff_bits = if y_bits > r_bits {
                                y_bits - r_bits
                            } else {
                                r_bits - y_bits
                            };

                            prop_assert!(
                                diff_bits <= 10 || (out[i] - ref_out[i]).abs() < 1e-14,
                                "kernel consistency failed at idx {}: {:?}={}, Scalar={}, diff_bits={}, abs_diff={}",
                                i, kernel, out[i], ref_out[i], diff_bits, (out[i] - ref_out[i]).abs()
                            );
                        }
                    }
                }

                // Pole-specific checks, only when enough samples follow the warm-up.
                if stability_point + 30 < data.len() {
                    match poles {
                        // 1 pole: output variance over the tail must not exceed the input's.
                        1 => {
                            let check_start = (first_valid + period * 2).min(data.len() / 2);
                            let check_end = data.len();

                            if check_end > check_start + 10 {
                                let input_slice = &data[check_start..check_end];
                                let output_slice = &out[check_start..check_end];

                                if input_slice.iter().all(|x| x.is_finite())
                                    && output_slice.iter().all(|x| x.is_finite())
                                {
                                    let input_mean =
                                        input_slice.iter().sum::<f64>() / input_slice.len() as f64;
                                    let output_mean = output_slice.iter().sum::<f64>()
                                        / output_slice.len() as f64;

                                    let input_var = input_slice
                                        .iter()
                                        .map(|x| (x - input_mean).powi(2))
                                        .sum::<f64>()
                                        / input_slice.len() as f64;
                                    let output_var = output_slice
                                        .iter()
                                        .map(|x| (x - output_mean).powi(2))
                                        .sum::<f64>()
                                        / output_slice.len() as f64;

                                    if input_var > 1e-10 {
                                        prop_assert!(
                                            output_var <= input_var,
                                            "1-pole filter should reduce variance: input_var={}, output_var={}",
                                            input_var, output_var
                                        );
                                    }
                                }
                            }
                        }
                        // 2-4 poles: summed absolute second differences must not exceed
                        // 1.1x those of the 1-pole filter with the same period.
                        2 | 3 | 4 => {
                            let params_1pole = GaussianParams {
                                period: Some(period),
                                poles: Some(1),
                            };
                            let input_1pole = GaussianInput::from_slice(&data, params_1pole);
                            if let Ok(GaussianOutput { values: out_1pole }) =
                                gaussian_with_kernel(&input_1pole, kernel)
                            {
                                let window_start = stability_point;
                                let window_end = (stability_point + 30).min(data.len() - 1);

                                if window_end > window_start + 5 {
                                    // Non-finite triples contribute 0 to the sum.
                                    let accel_multi: f64 = (window_start + 2..window_end)
                                        .map(|i| {
                                            if out[i - 2].is_finite()
                                                && out[i - 1].is_finite()
                                                && out[i].is_finite()
                                            {
                                                ((out[i] - out[i - 1]) - (out[i - 1] - out[i - 2]))
                                                    .abs()
                                            } else {
                                                0.0
                                            }
                                        })
                                        .sum::<f64>();

                                    let accel_1pole: f64 = (window_start + 2..window_end)
                                        .map(|i| {
                                            if out_1pole[i - 2].is_finite()
                                                && out_1pole[i - 1].is_finite()
                                                && out_1pole[i].is_finite()
                                            {
                                                ((out_1pole[i] - out_1pole[i - 1])
                                                    - (out_1pole[i - 1] - out_1pole[i - 2]))
                                                    .abs()
                                            } else {
                                                0.0
                                            }
                                        })
                                        .sum::<f64>();

                                    // Skip the comparison for near-zero signals.
                                    let non_zero_count =
                                        data.iter().filter(|&&x| x.abs() > 1e-10).count();
                                    if accel_1pole > 1e-10
                                        && accel_multi > 1e-10
                                        && non_zero_count > 5
                                    {
                                        prop_assert!(
                                            accel_multi <= accel_1pole * 1.1,
                                            "{}-pole filter should be smoother than 1-pole: accel_{}pole={}, accel_1pole={}",
                                            poles, poles, accel_multi, accel_1pole
                                        );
                                    }
                                }
                            }
                        }
                        _ => {}
                    }
                }

                Ok(())
            })
            .unwrap();

        Ok(())
    }
1986
1987 macro_rules! generate_all_gaussian_tests {
1988 ($($test_fn:ident),*) => {
1989 paste::paste! {
1990 $(
1991 #[test]
1992 fn [<$test_fn _scalar_f64>]() {
1993 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1994 }
1995 #[test]
1996 fn [<$test_fn _avx2_f64>]() {
1997 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1998 }
1999 #[test]
2000 fn [<$test_fn _avx512_f64>]() {
2001 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
2002 }
2003 )*
2004 }
2005 }
2006 }
2007
2008 fn check_gaussian_edge_cases(
2009 test_name: &str,
2010 kernel: Kernel,
2011 ) -> Result<(), Box<dyn std::error::Error>> {
2012 skip_if_unsupported!(kernel, test_name);
2013
2014 {
2015 let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
2016 let params = GaussianParams {
2017 period: Some(1),
2018 poles: Some(1),
2019 };
2020 let input = GaussianInput::from_slice(&data, params);
2021 let result = gaussian_with_kernel(&input, kernel);
2022
2023 assert!(
2024 result.is_err(),
2025 "[{}] Period=1 should return error, but got Ok",
2026 test_name
2027 );
2028
2029 if let Err(e) = result {
2030 match e {
2031 GaussianError::PeriodOneDegenerate => {}
2032 _ => panic!("[{}] Wrong error type for period=1: {:?}", test_name, e),
2033 }
2034 }
2035 }
2036
2037 {
2038 let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
2039 let params = GaussianParams {
2040 period: Some(data.len()),
2041 poles: Some(2),
2042 };
2043 let input = GaussianInput::from_slice(&data, params);
2044 let result = gaussian_with_kernel(&input, kernel)?;
2045 assert_eq!(result.values.len(), data.len());
2046
2047 for i in 0..data.len() {
2048 assert!(
2049 result.values[i].is_finite() || result.values[i].is_nan(),
2050 "[{}] Output should be finite or NaN at index {}, got {}",
2051 test_name,
2052 i,
2053 result.values[i]
2054 );
2055 }
2056 }
2057
2058 {
2059 let data = vec![42.0];
2060 let params = GaussianParams {
2061 period: Some(1),
2062 poles: Some(1),
2063 };
2064 let input = GaussianInput::from_slice(&data, params);
2065 let result = gaussian_with_kernel(&input, kernel);
2066
2067 assert!(
2068 result.is_err(),
2069 "[{}] Single data point with period=1 should return error",
2070 test_name
2071 );
2072 }
2073
2074 {
2075 let data = vec![0.0; 10];
2076 let params = GaussianParams {
2077 period: Some(3),
2078 poles: Some(2),
2079 };
2080 let input = GaussianInput::from_slice(&data, params);
2081 let result = gaussian_with_kernel(&input, kernel)?;
2082
2083 for i in 3..result.values.len() {
2084 assert!(
2085 result.values[i].abs() < 1e-15 || result.values[i].is_nan(),
2086 "[{}] All-zero input should produce zero output at index {}: got {}",
2087 test_name,
2088 i,
2089 result.values[i]
2090 );
2091 }
2092 }
2093
2094 {
2095 let data = vec![1.0, 2.0, 3.0];
2096 let params = GaussianParams {
2097 period: Some(5),
2098 poles: Some(2),
2099 };
2100 let input = GaussianInput::from_slice(&data, params);
2101 let result = gaussian_with_kernel(&input, kernel);
2102
2103 assert!(
2104 result.is_err(),
2105 "[{}] Period > data.len() should return error, but got Ok",
2106 test_name
2107 );
2108
2109 if let Err(e) = result {
2110 match e {
2111 GaussianError::InvalidPeriod { .. }
2112 | GaussianError::PeriodLongerThanData { .. } => {}
2113 _ => panic!(
2114 "[{}] Wrong error type for period > data.len(): {:?}",
2115 test_name, e
2116 ),
2117 }
2118 }
2119 }
2120
2121 {
2122 let data = vec![1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0];
2123 let params = GaussianParams {
2124 period: Some(2),
2125 poles: Some(4),
2126 };
2127 let input = GaussianInput::from_slice(&data, params);
2128 let result = gaussian_with_kernel(&input, kernel)?;
2129
2130 let start = 2;
2131 if result.values.len() > start + 2 {
2132 let slice = &result.values[start..];
2133 let valid_values: Vec<f64> =
2134 slice.iter().filter(|x| x.is_finite()).copied().collect();
2135
2136 if valid_values.len() > 2 {
2137 let mean = valid_values.iter().sum::<f64>() / valid_values.len() as f64;
2138 let variance = valid_values.iter().map(|x| (x - mean).powi(2)).sum::<f64>()
2139 / valid_values.len() as f64;
2140
2141 assert!(
2142 variance < 0.6,
2143 "[{}] 4-pole filter should smooth alternating input: variance={}",
2144 test_name,
2145 variance
2146 );
2147 }
2148 }
2149 }
2150
2151 Ok(())
2152 }
2153
2154 #[cfg(debug_assertions)]
2155 fn check_gaussian_no_poison(
2156 test_name: &str,
2157 kernel: Kernel,
2158 ) -> Result<(), Box<dyn std::error::Error>> {
2159 skip_if_unsupported!(kernel, test_name);
2160
2161 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2162 let candles = read_candles_from_csv(file_path)?;
2163
2164 let test_cases = vec![
2165 GaussianParams {
2166 period: Some(14),
2167 poles: Some(4),
2168 },
2169 GaussianParams {
2170 period: Some(10),
2171 poles: Some(1),
2172 },
2173 GaussianParams {
2174 period: Some(30),
2175 poles: Some(2),
2176 },
2177 GaussianParams {
2178 period: Some(20),
2179 poles: Some(3),
2180 },
2181 GaussianParams {
2182 period: Some(50),
2183 poles: Some(4),
2184 },
2185 GaussianParams {
2186 period: Some(5),
2187 poles: Some(1),
2188 },
2189 GaussianParams {
2190 period: Some(100),
2191 poles: Some(4),
2192 },
2193 GaussianParams {
2194 period: None,
2195 poles: None,
2196 },
2197 ];
2198
2199 for params in test_cases {
2200 let input = GaussianInput::from_candles(&candles, "close", params);
2201 let output = gaussian_with_kernel(&input, kernel)?;
2202
2203 for (i, &val) in output.values.iter().enumerate() {
2204 if val.is_nan() {
2205 continue;
2206 }
2207
2208 let bits = val.to_bits();
2209
2210 if bits == 0x11111111_11111111 {
2211 panic!(
2212 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
2213 with params period={:?}, poles={:?}",
2214 test_name, val, bits, i, params.period, params.poles
2215 );
2216 }
2217
2218 if bits == 0x22222222_22222222 {
2219 panic!(
2220 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
2221 with params period={:?}, poles={:?}",
2222 test_name, val, bits, i, params.period, params.poles
2223 );
2224 }
2225
2226 if bits == 0x33333333_33333333 {
2227 panic!(
2228 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
2229 with params period={:?}, poles={:?}",
2230 test_name, val, bits, i, params.period, params.poles
2231 );
2232 }
2233 }
2234 }
2235
2236 Ok(())
2237 }
2238
    /// Stand-in for the poison-value check when `debug_assertions` is off.
    ///
    /// It performs no checks and always returns `Ok(())`; it exists so the test
    /// generator macro can reference `check_gaussian_no_poison` in every build.
    #[cfg(not(debug_assertions))]
    fn check_gaussian_no_poison(
        _test_name: &str,
        _kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        Ok(())
    }
2246
    // Instantiate the scalar / AVX2 / AVX-512 `#[test]` wrappers for every
    // single-series check function defined above.
    generate_all_gaussian_tests!(
        check_gaussian_partial_params,
        check_gaussian_accuracy,
        check_gaussian_default_candles,
        check_gaussian_zero_period,
        check_gaussian_period_exceeds_length,
        check_gaussian_very_small_dataset,
        check_gaussian_empty_input,
        check_gaussian_invalid_poles,
        check_gaussian_all_nan,
        check_gaussian_reinput,
        check_gaussian_nan_handling,
        check_gaussian_streaming,
        check_gaussian_property,
        check_gaussian_edge_cases,
        check_gaussian_no_poison
    );
2264
2265 fn check_batch_default_row(
2266 test: &str,
2267 kernel: Kernel,
2268 ) -> Result<(), Box<dyn std::error::Error>> {
2269 skip_if_unsupported!(kernel, test);
2270 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2271 let c = read_candles_from_csv(file)?;
2272 let output = GaussianBatchBuilder::new()
2273 .kernel(kernel)
2274 .apply_candles(&c, "close")?;
2275 let def = GaussianParams::default();
2276 let row = output.values_for(&def).expect("default row missing");
2277 assert_eq!(row.len(), c.close.len());
2278 let expected = [
2279 59221.90637814869,
2280 59236.15215167245,
2281 59207.10087088464,
2282 59178.48276885589,
2283 59085.36983209433,
2284 ];
2285 let start = row.len() - 5;
2286 for (i, &v) in row[start..].iter().enumerate() {
2287 assert!(
2288 (v - expected[i]).abs() < 1e-4,
2289 "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
2290 );
2291 }
2292 Ok(())
2293 }
2294
2295 macro_rules! gen_batch_tests {
2296 ($fn_name:ident) => {
2297 paste::paste! {
2298 #[test] fn [<$fn_name _scalar>]() {
2299 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
2300 }
2301 #[test] fn [<$fn_name _avx2>]() {
2302 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
2303 }
2304 #[test] fn [<$fn_name _avx512>]() {
2305 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
2306 }
2307 #[test] fn [<$fn_name _auto_detect>]() {
2308 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
2309 }
2310 }
2311 };
2312 }
2313
2314 #[cfg(debug_assertions)]
2315 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn std::error::Error>> {
2316 skip_if_unsupported!(kernel, test);
2317
2318 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2319 let c = read_candles_from_csv(file)?;
2320
2321 let batch_configs = vec![
2322 ((10, 30, 10), (1, 4, 1)),
2323 ((5, 5, 0), (1, 1, 0)),
2324 ((100, 120, 20), (2, 4, 2)),
2325 ((7, 21, 7), (1, 3, 2)),
2326 ((15, 45, 15), (1, 4, 3)),
2327 ((3, 12, 3), (1, 2, 1)),
2328 ];
2329
2330 for ((p_start, p_end, p_step), (poles_start, poles_end, poles_step)) in batch_configs {
2331 let output = GaussianBatchBuilder::new()
2332 .kernel(kernel)
2333 .period_range(p_start, p_end, p_step)
2334 .poles_range(poles_start, poles_end, poles_step)
2335 .apply_candles(&c, "close")?;
2336
2337 for (idx, &val) in output.values.iter().enumerate() {
2338 if val.is_nan() {
2339 continue;
2340 }
2341
2342 let bits = val.to_bits();
2343 let row = idx / output.cols;
2344 let col = idx % output.cols;
2345 let combo = &output.combos[row];
2346
2347 if bits == 0x11111111_11111111 {
2348 panic!(
2349 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at row {} col {} \
2350 (flat index {}) with params period={:?}, poles={:?}",
2351 test, val, bits, row, col, idx, combo.period, combo.poles
2352 );
2353 }
2354
2355 if bits == 0x22222222_22222222 {
2356 panic!(
2357 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at row {} col {} \
2358 (flat index {}) with params period={:?}, poles={:?}",
2359 test, val, bits, row, col, idx, combo.period, combo.poles
2360 );
2361 }
2362
2363 if bits == 0x33333333_33333333 {
2364 panic!(
2365 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at row {} col {} \
2366 (flat index {}) with params period={:?}, poles={:?}",
2367 test, val, bits, row, col, idx, combo.period, combo.poles
2368 );
2369 }
2370 }
2371 }
2372
2373 Ok(())
2374 }
2375
2376 #[cfg(not(debug_assertions))]
2377 fn check_batch_no_poison(
2378 _test: &str,
2379 _kernel: Kernel,
2380 ) -> Result<(), Box<dyn std::error::Error>> {
2381 Ok(())
2382 }
2383
    // Generate the per-kernel `#[test]` wrappers for both batch checks.
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
2386}
2387
2388#[cfg(feature = "python")]
2389use crate::utilities::kernel_validation::validate_kernel;
2390#[cfg(feature = "python")]
2391use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
2392#[cfg(feature = "python")]
2393use pyo3::exceptions::PyValueError;
2394#[cfg(feature = "python")]
2395use pyo3::prelude::*;
2396#[cfg(feature = "python")]
2397use pyo3::types::PyDict;
2398
2399#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2400use serde::{Deserialize, Serialize};
2401#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2402use wasm_bindgen::prelude::*;
2403
2404#[inline(always)]
2405fn gaussian_batch_inner_into(
2406 data: &[f64],
2407 sweep: &GaussianBatchRange,
2408 kern: Kernel,
2409 parallel: bool,
2410 out: &mut [f64],
2411) -> Result<Vec<GaussianParams>, GaussianError> {
2412 let combos = expand_grid(sweep)?;
2413 if combos.is_empty() || data.is_empty() {
2414 return Err(GaussianError::NoData);
2415 }
2416 let cols = data.len();
2417 let rows = combos.len();
2418 let expected = rows
2419 .checked_mul(cols)
2420 .ok_or(GaussianError::SizeOverflow { rows, cols })?;
2421 if out.len() != expected {
2422 return Err(GaussianError::OutputLengthMismatch {
2423 expected,
2424 got: out.len(),
2425 });
2426 }
2427
2428 let first_valid = data
2429 .iter()
2430 .position(|x| !x.is_nan())
2431 .ok_or(GaussianError::AllValuesNaN)?;
2432
2433 for c in &combos {
2434 let period = c.period.unwrap_or(14);
2435 let poles = c.poles.unwrap_or(4);
2436
2437 if period == 1 {
2438 return Err(GaussianError::PeriodOneDegenerate);
2439 }
2440
2441 if period < 2 {
2442 return Err(GaussianError::DegeneratePeriod { period });
2443 }
2444
2445 if period > cols {
2446 return Err(GaussianError::PeriodLongerThanData {
2447 period,
2448 data_len: cols,
2449 });
2450 }
2451 if !(1..=4).contains(&poles) {
2452 return Err(GaussianError::InvalidPoles { poles });
2453 }
2454 if cols - first_valid < period {
2455 return Err(GaussianError::NotEnoughValidData {
2456 needed: period,
2457 valid: cols - first_valid,
2458 });
2459 }
2460 }
2461
2462 let warm: Vec<usize> = combos
2463 .iter()
2464 .map(|c| {
2465 let w = first_valid + c.period.unwrap_or(14);
2466 w.min(cols)
2467 })
2468 .collect();
2469
2470 let raw = unsafe {
2471 std::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
2472 };
2473
2474 unsafe { init_matrix_prefixes(raw, cols, &warm) };
2475
2476 let chosen = match kern {
2477 Kernel::Auto => Kernel::Scalar,
2478 other => other,
2479 };
2480
2481 type RowRunner = unsafe fn(&[f64], &GaussianParams, &mut [f64]);
2482 let row_runner: RowRunner = match chosen {
2483 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2484 Kernel::Avx512 => gaussian_row_avx512,
2485 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2486 Kernel::Avx2 => gaussian_row_avx2,
2487 _ => gaussian_row_scalar,
2488 };
2489
2490 let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
2491 let dst = std::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
2492 row_runner(data, &combos[row], dst);
2493 };
2494
2495 if parallel {
2496 #[cfg(not(target_arch = "wasm32"))]
2497 {
2498 raw.par_chunks_mut(cols)
2499 .enumerate()
2500 .for_each(|(row, slice)| do_row(row, slice));
2501 }
2502 #[cfg(target_arch = "wasm32")]
2503 {
2504 for (row, slice) in raw.chunks_mut(cols).enumerate() {
2505 do_row(row, slice);
2506 }
2507 }
2508 } else {
2509 for (row, slice) in raw.chunks_mut(cols).enumerate() {
2510 do_row(row, slice);
2511 }
2512 }
2513
2514 Ok(combos)
2515}
2516
// Python entry point `gaussian(data, period, poles, kernel=None)`.
//
// Reads the NumPy input as a contiguous slice, validates the kernel name, runs
// the filter with the GIL released and returns the values as a new 1-D array.
// Any filter error is surfaced to Python as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "gaussian")]
#[pyo3(signature = (data, period, poles, kernel=None))]
pub fn gaussian_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period: usize,
    poles: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};

    let prices = data.as_slice()?;
    let kern = validate_kernel(kernel, false)?;

    let input = GaussianInput::from_slice(
        prices,
        GaussianParams {
            period: Some(period),
            poles: Some(poles),
        },
    );

    let computed: Result<Vec<f64>, _> =
        py.allow_threads(|| gaussian_with_kernel(&input, kern).map(|o| o.values));

    match computed {
        Ok(values) => Ok(values.into_pyarray(py)),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
2544
// Python class `GaussianStream`: a thin wrapper that owns a `GaussianStream`.
#[cfg(feature = "python")]
#[pyclass(name = "GaussianStream")]
pub struct GaussianStreamPy {
    stream: GaussianStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl GaussianStreamPy {
    // Constructor: builds the inner stream from `period` and `poles`; a
    // construction error is raised in Python as `ValueError`.
    #[new]
    fn new(period: usize, poles: usize) -> PyResult<Self> {
        let params = GaussianParams {
            period: Some(period),
            poles: Some(poles),
        };
        match GaussianStream::try_new(params) {
            Ok(stream) => Ok(Self { stream }),
            Err(e) => Err(PyValueError::new_err(e.to_string())),
        }
    }

    // Feeds one sample to the stream and returns its output, always wrapped in `Some`.
    fn update(&mut self, value: f64) -> Option<f64> {
        let filtered = self.stream.update(value);
        Some(filtered)
    }
}
2569
// Python entry point `gaussian_batch(data, period_range, poles_range, kernel=None)`.
//
// Expands the (start, end, step) ranges into a parameter grid, computes one
// output row per combination with the GIL released, and returns a dict with:
//   "values"  -> 2-D array of shape (rows, cols)
//   "periods" -> period of each row
//   "poles"   -> poles of each row
// Grid, size and filter errors are raised in Python as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "gaussian_batch")]
#[pyo3(signature = (data, period_range, poles_range, kernel=None))]
pub fn gaussian_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    poles_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, true)?;

    let sweep = GaussianBatchRange {
        period: period_range,
        poles: poles_range,
    };

    let combos = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let rows = combos.len();
    let cols = slice_in.len();

    // Check the element count before the uninitialised allocation so an
    // overflowing product can never size the buffer (or panic in debug builds).
    let total = rows.checked_mul(cols).ok_or_else(|| {
        PyValueError::new_err(GaussianError::SizeOverflow { rows, cols }.to_string())
    })?;

    // The array is allocated uninitialised; `gaussian_batch_inner_into` is
    // relied on to write every cell or fail before the array is returned.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            // Auto-detection never selects the AVX-512 batch kernel here; it is
            // downgraded to AVX2.
            let kernel = match kern {
                Kernel::Auto => match detect_best_batch_kernel() {
                    Kernel::Avx512Batch => Kernel::Avx2Batch,
                    other => other,
                },
                k => k,
            };
            // Translate the batch kernel into the per-row kernel the inner routine expects.
            // NOTE(review): assumes `validate_kernel(_, true)` only yields batch
            // kernels or `Auto`; anything else reaches `unreachable!()` — confirm.
            let simd = match kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                _ => unreachable!(),
            };
            gaussian_batch_inner_into(slice_in, &sweep, simd, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "periods",
        combos
            .iter()
            // Same fallback the computation itself uses, instead of panicking on `None`.
            .map(|p| p.period.unwrap_or(14) as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "poles",
        combos
            .iter()
            .map(|p| p.poles.unwrap_or(4) as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;

    Ok(dict)
}
2638
// Python entry point `gaussian_cuda_batch_dev(data_f32, period_range, poles_range, device_id=0)`.
//
// Fails with `ValueError("CUDA not available")` when `cuda_available()` is
// false. Otherwise creates a `CudaGaussian` for `device_id`, runs the batch
// with the GIL released and wraps the result, together with the stream
// handle, context and device id, in a `DeviceArrayF32Py`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "gaussian_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, poles_range, device_id=0))]
pub fn gaussian_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: numpy::PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    poles_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let prices = data_f32.as_slice()?;
    let sweep = GaussianBatchRange {
        period: period_range,
        poles: poles_range,
    };

    let cuda = match CudaGaussian::new(device_id) {
        Ok(engine) => engine,
        Err(e) => return Err(PyValueError::new_err(e.to_string())),
    };
    // Collected before the launch; all three are passed on with the result.
    let stream = cuda.stream_handle();
    let dev_id = cuda.device_id();
    let ctx_guard = cuda.context_arc();

    let launched = py.allow_threads(|| cuda.gaussian_batch_dev(prices, &sweep));
    match launched {
        Ok(inner) => Ok(DeviceArrayF32Py::new_from_rust(
            inner, stream, ctx_guard, dev_id,
        )),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
2671
// Python entry point
// `gaussian_cuda_many_series_one_param_dev(prices_tm_f32, period, poles, device_id=0)`.
//
// Applies one (period, poles) setting to a 2-D `f32` array on the GPU. Fails
// with `ValueError("CUDA not available")` when `cuda_available()` is false;
// CUDA construction and launch errors are also raised as `ValueError`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "gaussian_cuda_many_series_one_param_dev")]
#[pyo3(signature = (prices_tm_f32, period, poles, device_id=0))]
pub fn gaussian_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    prices_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    period: usize,
    poles: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    use numpy::PyUntypedArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let shape = prices_tm_f32.shape();
    let rows = shape[0];
    let cols = shape[1];

    let flat = prices_tm_f32.as_slice()?;
    let params = GaussianParams {
        period: Some(period),
        poles: Some(poles),
    };

    let cuda = CudaGaussian::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let stream = cuda.stream_handle();
    let dev_id = cuda.device_id();
    let ctx_guard = cuda.context_arc();
    let inner = py
        .allow_threads(|| {
            // NOTE(review): the dimensions are passed as (cols, rows) — confirm
            // this matches the callee's expected argument order.
            cuda.gaussian_many_series_one_param_time_major_dev(flat, cols, rows, &params)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok(DeviceArrayF32Py::new_from_rust(
        inner, stream, ctx_guard, dev_id,
    ))
}
2712
// wasm export: filters `data` with the given `period` and `poles` and returns
// a freshly allocated vector of the same length. Filter errors are converted
// to a string `JsValue`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn gaussian_js(data: &[f64], period: usize, poles: usize) -> Result<Vec<f64>, JsValue> {
    let input = GaussianInput::from_slice(
        data,
        GaussianParams {
            period: Some(period),
            poles: Some(poles),
        },
    );

    let mut filtered = vec![0.0; data.len()];
    match gaussian_into_slice(&mut filtered, &input, Kernel::Auto) {
        Ok(_) => Ok(filtered),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
2729
// wasm export: pointer-based variant that filters `len` values at `in_ptr`
// into `out_ptr`.
//
// Rejects null pointers, then validates `period` and `poles` before running
// the filter. When both pointers are equal the result is computed into a
// temporary buffer and copied back, so the call may be made in place.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn gaussian_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
    poles: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to gaussian_into"));
    }

    // First failing rule wins; the order matches the checks' precedence:
    // period == 1, period < 2, period > len, then poles outside 1..=4.
    let rejection: Option<String> = match period {
        1 => Some(String::from("Period of 1 causes degenerate filter")),
        0 => Some(format!("Period must be >= 2, got {}", period)),
        p if p > len => Some(format!(
            "Period {} is longer than data length {}",
            p, len
        )),
        _ if !(1..=4).contains(&poles) => Some(String::from("Invalid poles (must be 1-4)")),
        _ => None,
    };
    if let Some(message) = rejection {
        return Err(JsValue::from_str(&message));
    }

    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let input = GaussianInput::from_slice(
            data,
            GaussianParams {
                period: Some(period),
                poles: Some(poles),
            },
        );

        if in_ptr != out_ptr {
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            gaussian_into_slice(out, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        } else {
            // In-place request: compute into scratch space, then overwrite the input.
            let mut scratch = vec![0.0; len];
            gaussian_into_slice(&mut scratch, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;

            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            out.copy_from_slice(&scratch);
        }
    }

    Ok(())
}
2788
// wasm export: reserves capacity for `len` `f64` values and returns the raw
// pointer without releasing the allocation. The memory is uninitialised.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn gaussian_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` keeps the vector's destructor from running, so the
    // allocation outlives this call.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
2797
// wasm export: releases a buffer identified by `ptr` and `len`. A null `ptr`
// is ignored.
//
// NOTE(review): presumably `ptr`/`len` come from `gaussian_alloc(len)` —
// confirm against the JS callers.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn gaussian_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // Rebuild the vector with length 0 and capacity `len`: deallocation only
    // depends on the capacity, and a zero length avoids claiming that `len`
    // elements were initialised when the caller may never have written them.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
2807
// wasm export: runs the batch sweep described by the six range arguments and
// returns the flat `values` buffer of the batch output. Errors are converted
// to a string `JsValue`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn gaussian_batch_js(
    data: &[f64],
    period_start: usize,
    period_end: usize,
    period_step: usize,
    poles_start: usize,
    poles_end: usize,
    poles_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let period = (period_start, period_end, period_step);
    let poles = (poles_start, poles_end, poles_step);
    let sweep = GaussianBatchRange { period, poles };

    match gaussian_batch_inner(data, &sweep, Kernel::Auto, false) {
        Ok(output) => Ok(output.values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
2828
// wasm export: expands the period/poles ranges and returns the grid as a flat
// vector of `[period, poles, period, poles, ...]`, one pair per combination.
// A missing period is reported as 14 and a missing poles value as 4.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn gaussian_batch_metadata_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
    poles_start: usize,
    poles_end: usize,
    poles_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let period = (period_start, period_end, period_step);
    let poles = (poles_start, poles_end, poles_step);
    let sweep = GaussianBatchRange { period, poles };

    let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;

    let mut metadata: Vec<f64> = Vec::with_capacity(combos.len() * 2);
    for combo in &combos {
        metadata.push(combo.period.unwrap_or(14) as f64);
        metadata.push(combo.poles.unwrap_or(4) as f64);
    }

    Ok(metadata)
}
2857
/// Sweep configuration deserialized from the JS `config` argument of the
/// unified `gaussian_batch` export.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct GaussianBatchConfig {
    /// Period range, forwarded as `GaussianBatchRange::period`.
    pub period_range: (usize, usize, usize),
    /// Poles range, forwarded as `GaussianBatchRange::poles`.
    pub poles_range: (usize, usize, usize),
}
2864
/// Batch result serialized back to JS by the unified `gaussian_batch` export.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct GaussianBatchJsOutput {
    /// Flat buffer of all output values.
    pub values: Vec<f64>,
    /// Parameter combinations of the sweep.
    pub combos: Vec<GaussianParams>,
    /// Row count reported by the batch output.
    pub rows: usize,
    /// Column count reported by the batch output.
    pub cols: usize,
}
2873
// wasm export (`gaussian_batch` on the JS side): takes the sweep as a JS
// config object, runs the batch and returns values, combos, rows and cols as
// one serialized JS value. Config, computation and serialization failures are
// all reported as string `JsValue`s.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = gaussian_batch)]
pub fn gaussian_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let parsed: GaussianBatchConfig = match serde_wasm_bindgen::from_value(config) {
        Ok(cfg) => cfg,
        Err(e) => return Err(JsValue::from_str(&format!("Invalid config: {}", e))),
    };

    let sweep = GaussianBatchRange {
        period: parsed.period_range,
        poles: parsed.poles_range,
    };

    let batch = match gaussian_batch_inner(data, &sweep, Kernel::Auto, false) {
        Ok(output) => output,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };

    let payload = GaussianBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };

    match serde_wasm_bindgen::to_value(&payload) {
        Ok(js) => Ok(js),
        Err(e) => Err(JsValue::from_str(&format!("Serialization error: {}", e))),
    }
}
2898
// wasm export: pointer-based batch variant. Reads `len` values at `in_ptr`,
// writes `rows * len` values at `out_ptr` (one row per parameter combination)
// and returns the number of rows.
//
// Errors (as string `JsValue`s): null pointers, an invalid sweep, a
// `rows * len` product that overflows `usize`, and any error reported by
// `gaussian_batch_inner_into`.
//
// The caller must provide `rows * len` writable `f64` slots at `out_ptr`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn gaussian_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
    poles_start: usize,
    poles_end: usize,
    poles_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str(
            "null pointer passed to gaussian_batch_into",
        ));
    }

    let sweep = GaussianBatchRange {
        period: (period_start, period_end, period_step),
        poles: (poles_start, poles_end, poles_step),
    };

    let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let cols = len;

    // Size the output view with a checked product so a wrapped value can never
    // be used as a slice length.
    let total = rows.checked_mul(cols).ok_or_else(|| {
        JsValue::from_str(&GaussianError::SizeOverflow { rows, cols }.to_string())
    })?;

    unsafe {
        // In-place call (same pointer for input and output): the batch routine
        // writes warm-up prefixes into the output before reading the input, so
        // work from a private copy of the input, as `gaussian_into` does with
        // its temporary buffer. Only exact pointer equality is detected.
        let snapshot: Vec<f64>;
        let data: &[f64] = if in_ptr == out_ptr as *const f64 {
            snapshot = std::slice::from_raw_parts(in_ptr, len).to_vec();
            &snapshot
        } else {
            std::slice::from_raw_parts(in_ptr, len)
        };

        let out = std::slice::from_raw_parts_mut(out_ptr, total);

        gaussian_batch_inner_into(data, &sweep, Kernel::Auto, false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(rows)
}