1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::CudaStddev;
3#[cfg(all(feature = "python", feature = "cuda"))]
4use crate::indicators::moving_averages::alma::{make_device_array_py, DeviceArrayF32Py};
5#[cfg(feature = "python")]
6use numpy::{IntoPyArray, PyArray1};
7#[cfg(feature = "python")]
8use pyo3::exceptions::PyValueError;
9#[cfg(feature = "python")]
10use pyo3::prelude::*;
11#[cfg(feature = "python")]
12use pyo3::types::PyDict;
13
14#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
15use serde::{Deserialize, Serialize};
16#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
17use wasm_bindgen::prelude::*;
18
19use crate::utilities::data_loader::{source_type, Candles};
20use crate::utilities::enums::Kernel;
21use crate::utilities::helpers::{
22 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
23 make_uninit_matrix,
24};
25#[cfg(feature = "python")]
26use crate::utilities::kernel_validation::validate_kernel;
27
28#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
29use core::arch::x86_64::*;
30#[cfg(not(target_arch = "wasm32"))]
31use rayon::prelude::*;
32use std::convert::AsRef;
33use std::error::Error;
34use std::mem::MaybeUninit;
35use thiserror::Error;
36
37impl<'a> AsRef<[f64]> for StdDevInput<'a> {
38 #[inline(always)]
39 fn as_ref(&self) -> &[f64] {
40 match &self.data {
41 StdDevData::Slice(slice) => slice,
42 StdDevData::Candles { candles, source } => source_type(candles, source),
43 }
44 }
45}
46
47#[derive(Debug, Clone)]
48pub enum StdDevData<'a> {
49 Candles {
50 candles: &'a Candles,
51 source: &'a str,
52 },
53 Slice(&'a [f64]),
54}
55
/// Result of a single standard-deviation run.
#[derive(Clone, Debug)]
pub struct StdDevOutput {
    /// One value per input element; the warm-up prefix is filled with NaN
    /// by the functions in this file that produce this type.
    pub values: Vec<f64>,
}
60
/// Tunable parameters of the standard-deviation indicator.
///
/// A `None` field means "use the default"; the accessors in this file fall
/// back to a period of 5 and an `nbdev` of 1.0.
#[derive(Clone, Debug)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct StdDevParams {
    /// Length of the rolling window, in samples.
    pub period: Option<usize>,
    /// Multiplier applied to the computed standard deviation.
    pub nbdev: Option<f64>,
}
70
71impl Default for StdDevParams {
72 fn default() -> Self {
73 Self {
74 period: Some(5),
75 nbdev: Some(1.0),
76 }
77 }
78}
79
80#[derive(Debug, Clone)]
81pub struct StdDevInput<'a> {
82 pub data: StdDevData<'a>,
83 pub params: StdDevParams,
84}
85
86impl<'a> StdDevInput<'a> {
87 #[inline]
88 pub fn from_candles(c: &'a Candles, s: &'a str, p: StdDevParams) -> Self {
89 Self {
90 data: StdDevData::Candles {
91 candles: c,
92 source: s,
93 },
94 params: p,
95 }
96 }
97 #[inline]
98 pub fn from_slice(sl: &'a [f64], p: StdDevParams) -> Self {
99 Self {
100 data: StdDevData::Slice(sl),
101 params: p,
102 }
103 }
104 #[inline]
105 pub fn with_default_candles(c: &'a Candles) -> Self {
106 Self::from_candles(c, "close", StdDevParams::default())
107 }
108 #[inline]
109 pub fn get_period(&self) -> usize {
110 self.params.period.unwrap_or(5)
111 }
112 #[inline]
113 pub fn get_nbdev(&self) -> f64 {
114 self.params.nbdev.unwrap_or(1.0)
115 }
116}
117
118#[derive(Copy, Clone, Debug)]
119pub struct StdDevBuilder {
120 period: Option<usize>,
121 nbdev: Option<f64>,
122 kernel: Kernel,
123}
124
125impl Default for StdDevBuilder {
126 fn default() -> Self {
127 Self {
128 period: None,
129 nbdev: None,
130 kernel: Kernel::Auto,
131 }
132 }
133}
134
135impl StdDevBuilder {
136 #[inline(always)]
137 pub fn new() -> Self {
138 Self::default()
139 }
140 #[inline(always)]
141 pub fn period(mut self, n: usize) -> Self {
142 self.period = Some(n);
143 self
144 }
145 #[inline(always)]
146 pub fn nbdev(mut self, x: f64) -> Self {
147 self.nbdev = Some(x);
148 self
149 }
150 #[inline(always)]
151 pub fn kernel(mut self, k: Kernel) -> Self {
152 self.kernel = k;
153 self
154 }
155 #[inline(always)]
156 pub fn apply(self, c: &Candles) -> Result<StdDevOutput, StdDevError> {
157 let p = StdDevParams {
158 period: self.period,
159 nbdev: self.nbdev,
160 };
161 let i = StdDevInput::from_candles(c, "close", p);
162 stddev_with_kernel(&i, self.kernel)
163 }
164 #[inline(always)]
165 pub fn apply_slice(self, d: &[f64]) -> Result<StdDevOutput, StdDevError> {
166 let p = StdDevParams {
167 period: self.period,
168 nbdev: self.nbdev,
169 };
170 let i = StdDevInput::from_slice(d, p);
171 stddev_with_kernel(&i, self.kernel)
172 }
173 #[inline(always)]
174 pub fn into_stream(self) -> Result<StdDevStream, StdDevError> {
175 let p = StdDevParams {
176 period: self.period,
177 nbdev: self.nbdev,
178 };
179 StdDevStream::try_new(p)
180 }
181}
182
183#[derive(Debug, Error)]
184pub enum StdDevError {
185 #[error("stddev: Input data slice is empty.")]
186 EmptyInputData,
187 #[error("stddev: All values are NaN.")]
188 AllValuesNaN,
189 #[error("stddev: Invalid period: period = {period}, data length = {data_len}")]
190 InvalidPeriod { period: usize, data_len: usize },
191 #[error("stddev: Not enough valid data: needed = {needed}, valid = {valid}")]
192 NotEnoughValidData { needed: usize, valid: usize },
193 #[error("stddev: Invalid nbdev: {nbdev}. Must be non-negative and finite.")]
194 InvalidNbdev { nbdev: f64 },
195 #[error("stddev: Output length mismatch: expected = {expected}, got = {got}")]
196 OutputLengthMismatch { expected: usize, got: usize },
197 #[error("stddev: Invalid range: start={start}, end={end}, step={step}")]
198 InvalidRange {
199 start: String,
200 end: String,
201 step: String,
202 },
203 #[error("stddev: Invalid kernel for batch: {0:?}")]
204 InvalidKernelForBatch(crate::utilities::enums::Kernel),
205
206 #[error("stddev: Output length mismatch: dst = {dst_len}, expected = {expected_len}")]
207 MismatchedOutputLen { dst_len: usize, expected_len: usize },
208 #[error("stddev: Invalid kernel type: {msg}")]
209 InvalidKernel { msg: String },
210 #[error("stddev: Invalid input: {msg}")]
211 InvalidInput { msg: String },
212}
213
214#[inline]
215pub fn stddev(input: &StdDevInput) -> Result<StdDevOutput, StdDevError> {
216 stddev_with_kernel(input, Kernel::Auto)
217}
218
219pub fn stddev_with_kernel(
220 input: &StdDevInput,
221 kernel: Kernel,
222) -> Result<StdDevOutput, StdDevError> {
223 let data: &[f64] = input.as_ref();
224 let len = data.len();
225 if len == 0 {
226 return Err(StdDevError::EmptyInputData);
227 }
228
229 let first = data
230 .iter()
231 .position(|x| !x.is_nan())
232 .ok_or(StdDevError::AllValuesNaN)?;
233 let period = input.get_period();
234 let nbdev = input.get_nbdev();
235
236 if !nbdev.is_finite() || nbdev < 0.0 {
237 return Err(StdDevError::InvalidNbdev { nbdev });
238 }
239 if period == 0 || period > len {
240 return Err(StdDevError::InvalidPeriod {
241 period,
242 data_len: len,
243 });
244 }
245 if (len - first) < period {
246 return Err(StdDevError::NotEnoughValidData {
247 needed: period,
248 valid: len - first,
249 });
250 }
251
252 let warmup = first + period - 1;
253 let mut out = alloc_with_nan_prefix(len, warmup);
254
255 let chosen = match kernel {
256 Kernel::Auto => detect_best_kernel(),
257 k => k,
258 };
259
260 unsafe {
261 match chosen {
262 Kernel::Scalar | Kernel::ScalarBatch => {
263 stddev_scalar(data, period, first, nbdev, &mut out)
264 }
265 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
266 Kernel::Avx2 | Kernel::Avx2Batch => stddev_avx2(data, period, first, nbdev, &mut out),
267 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
268 Kernel::Avx512 | Kernel::Avx512Batch => {
269 stddev_avx512(data, period, first, nbdev, &mut out)
270 }
271
272 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
273 Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
274 stddev_scalar(data, period, first, nbdev, &mut out)
275 }
276 _ => unreachable!(),
277 }
278 }
279
280 Ok(StdDevOutput { values: out })
281}
282
283#[inline]
284pub fn stddev_into_slice(
285 dst: &mut [f64],
286 input: &StdDevInput,
287 kern: Kernel,
288) -> Result<(), StdDevError> {
289 let data: &[f64] = input.as_ref();
290 let len = data.len();
291 if len == 0 {
292 return Err(StdDevError::EmptyInputData);
293 }
294
295 let first = data
296 .iter()
297 .position(|x| !x.is_nan())
298 .ok_or(StdDevError::AllValuesNaN)?;
299 let period = input.get_period();
300 let nbdev = input.get_nbdev();
301
302 if !nbdev.is_finite() || nbdev < 0.0 {
303 return Err(StdDevError::InvalidNbdev { nbdev });
304 }
305 if period == 0 || period > len {
306 return Err(StdDevError::InvalidPeriod {
307 period,
308 data_len: len,
309 });
310 }
311 if (len - first) < period {
312 return Err(StdDevError::NotEnoughValidData {
313 needed: period,
314 valid: len - first,
315 });
316 }
317 if dst.len() != len {
318 return Err(StdDevError::OutputLengthMismatch {
319 expected: len,
320 got: dst.len(),
321 });
322 }
323
324 let warmup = first + period - 1;
325 for v in &mut dst[..warmup] {
326 *v = f64::NAN;
327 }
328
329 let chosen = match kern {
330 Kernel::Auto => detect_best_kernel(),
331 k => k,
332 };
333 unsafe {
334 match chosen {
335 Kernel::Scalar | Kernel::ScalarBatch => stddev_scalar(data, period, first, nbdev, dst),
336 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
337 Kernel::Avx2 | Kernel::Avx2Batch => stddev_avx2(data, period, first, nbdev, dst),
338 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
339 Kernel::Avx512 | Kernel::Avx512Batch => stddev_avx512(data, period, first, nbdev, dst),
340 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
341 Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
342 stddev_scalar(data, period, first, nbdev, dst)
343 }
344 _ => unreachable!(),
345 }
346 }
347
348 Ok(())
349}
350
351#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
352#[inline]
353pub fn stddev_into(input: &StdDevInput, out: &mut [f64]) -> Result<(), StdDevError> {
354 stddev_into_slice(out, input, Kernel::Auto)
355}
356
/// Scalar kernel: rolling population standard deviation scaled by `nbdev`.
///
/// The first window covers `data[first..first + period]` and its result is
/// written to `out[first + period - 1]`; every later index up to
/// `data.len() - 1` receives the value of the window ending there. Indices
/// before `first + period - 1` are left untouched.
///
/// The variance is computed as `E[x^2] - E[x]^2` from running sums, and a
/// non-positive variance (rounding noise) is reported as `0.0`. A NaN that
/// enters the window after `first` propagates into the running sums and
/// therefore into every later output.
///
/// This is a safe function, so its preconditions are checked instead of
/// being assumed by unchecked pointer arithmetic.
///
/// # Panics
/// Panics if `period == 0`, if `first + period > data.len()`, or if `out` is
/// shorter than `data`.
#[inline]
pub fn stddev_scalar(data: &[f64], period: usize, first: usize, nbdev: f64, out: &mut [f64]) {
    let len = data.len();
    assert!(period > 0, "stddev_scalar: period must be non-zero");
    let seed_end = match first.checked_add(period) {
        Some(end) if end <= len => end,
        _ => panic!("stddev_scalar: first + period exceeds data length"),
    };
    assert!(
        out.len() >= len,
        "stddev_scalar: output buffer is shorter than the input"
    );

    let inv_den = 1.0 / period as f64;
    let to_stddev = |sum: f64, sum_sqr: f64| -> f64 {
        let mean = sum * inv_den;
        let var = (sum_sqr * inv_den) - (mean * mean);
        if var <= 0.0 {
            0.0
        } else {
            var.sqrt() * nbdev
        }
    };

    // Seed the running sums with the first full window.
    let mut sum = 0.0;
    let mut sum_sqr = 0.0;
    for &v in &data[first..seed_end] {
        sum += v;
        sum_sqr += v * v;
    }
    out[seed_end - 1] = to_stddev(sum, sum_sqr);

    // Slide the window: `new` enters on the right while `old` leaves on the left.
    let entering = &data[seed_end..];
    let leaving = &data[first..];
    let slots = &mut out[seed_end..len];
    for ((slot, &new), &old) in slots.iter_mut().zip(entering).zip(leaving) {
        sum += new - old;
        sum_sqr += new * new - old * old;
        *slot = to_stddev(sum, sum_sqr);
    }
}
409
/// AVX2 entry point for the single-series kernel.
///
/// It currently forwards to [`stddev_scalar`] and contains no SIMD code.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn stddev_avx2(data: &[f64], period: usize, first: usize, nbdev: f64, out: &mut [f64]) {
    stddev_scalar(data, period, first, nbdev, out)
}
415
/// AVX-512 entry point for the single-series kernel.
///
/// Periods of at most 32 go to the "short" variant, longer ones to the
/// "long" variant.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn stddev_avx512(data: &[f64], period: usize, first: usize, nbdev: f64, out: &mut [f64]) {
    let long_window = period > 32;
    // SAFETY: both callees are `unsafe fn`s that, as written in this file,
    // only forward to the safe `stddev_scalar`.
    unsafe {
        if long_window {
            stddev_avx512_long(data, period, first, nbdev, out)
        } else {
            stddev_avx512_short(data, period, first, nbdev, out)
        }
    }
}
425
/// AVX-512 variant selected for periods of at most 32.
///
/// It currently forwards to [`stddev_scalar`].
///
/// # Safety
/// The body performs no unsafe operation itself; the arguments must satisfy
/// whatever `stddev_scalar` requires of them.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
unsafe fn stddev_avx512_short(
    data: &[f64],
    period: usize,
    first: usize,
    nbdev: f64,
    out: &mut [f64],
) {
    stddev_scalar(data, period, first, nbdev, out)
}
437
/// AVX-512 variant selected for periods above 32.
///
/// It currently forwards to [`stddev_scalar`].
///
/// # Safety
/// The body performs no unsafe operation itself; the arguments must satisfy
/// whatever `stddev_scalar` requires of them.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
unsafe fn stddev_avx512_long(
    data: &[f64],
    period: usize,
    first: usize,
    nbdev: f64,
    out: &mut [f64],
) {
    stddev_scalar(data, period, first, nbdev, out)
}
449
/// Incremental (one sample at a time) standard-deviation calculator.
#[derive(Clone, Debug)]
pub struct StdDevStream {
    /// Window length; also the capacity of `buffer`.
    period: usize,
    /// Multiplier applied to the standard deviation.
    nbdev: f64,
    /// Cached `1.0 / period`.
    inv_den: f64,
    /// Ring buffer holding the most recent `period` samples.
    buffer: Vec<f64>,
    /// Index of the slot the next sample will overwrite.
    head: usize,
    /// Set once `period` samples have been received.
    filled: bool,
    /// Running sum of the non-NaN samples in the window.
    sum: f64,
    /// Running sum of squares of the non-NaN samples in the window.
    sum_sqr: f64,
    /// Number of NaN samples currently inside the window.
    nan_count: usize,
}
462
463impl StdDevStream {
464 pub fn try_new(params: StdDevParams) -> Result<Self, StdDevError> {
465 let period = params.period.unwrap_or(5);
466 if period == 0 {
467 return Err(StdDevError::InvalidPeriod {
468 period,
469 data_len: 0,
470 });
471 }
472 let nbdev = params.nbdev.unwrap_or(1.0);
473 if !nbdev.is_finite() || nbdev < 0.0 {
474 return Err(StdDevError::InvalidNbdev { nbdev });
475 }
476 Ok(Self {
477 period,
478 nbdev,
479 inv_den: 1.0 / period as f64,
480 buffer: vec![f64::NAN; period],
481 head: 0,
482 filled: false,
483 sum: 0.0,
484 sum_sqr: 0.0,
485 nan_count: 0,
486 })
487 }
488
489 #[inline(always)]
490 pub fn update(&mut self, value: f64) -> Option<f64> {
491 if !self.filled {
492 if value.is_nan() {
493 self.nan_count += 1;
494 } else {
495 self.sum += value;
496 self.sum_sqr += value * value;
497 }
498 self.buffer[self.head] = value;
499
500 let next = self.head + 1;
501 if next == self.period {
502 self.head = 0;
503 self.filled = true;
504
505 if self.nan_count > 0 {
506 return Some(f64::NAN);
507 }
508 let mean = self.sum * self.inv_den;
509 let var = (self.sum_sqr * self.inv_den) - (mean * mean);
510 return Some(if var <= 0.0 {
511 0.0
512 } else {
513 var.sqrt() * self.nbdev
514 });
515 } else {
516 self.head = next;
517 return None;
518 }
519 }
520
521 let old = self.buffer[self.head];
522 let new_is_nan = value.is_nan();
523 let old_is_nan = old.is_nan();
524
525 match (old_is_nan, new_is_nan) {
526 (false, false) => {
527 self.sum += value - old;
528 self.sum_sqr += (value * value) - (old * old);
529 }
530 (false, true) => {
531 self.sum -= old;
532 self.sum_sqr -= old * old;
533 self.nan_count += 1;
534 }
535 (true, false) => {
536 if self.nan_count > 0 {
537 self.nan_count -= 1;
538 }
539 self.sum += value;
540 self.sum_sqr += value * value;
541 }
542 (true, true) => {
543 if self.nan_count > 0 {
544 self.nan_count -= 1;
545 }
546
547 self.nan_count += 1;
548 }
549 }
550
551 self.buffer[self.head] = value;
552
553 self.head += 1;
554 if self.head == self.period {
555 self.head = 0;
556 }
557
558 if self.nan_count > 0 {
559 return Some(f64::NAN);
560 }
561
562 let mean = self.sum * self.inv_den;
563 let var = (self.sum_sqr * self.inv_den) - (mean * mean);
564 Some(if var <= 0.0 {
565 0.0
566 } else {
567 var.sqrt() * self.nbdev
568 })
569 }
570}
571
/// Parameter sweep for batch runs; each axis is a `(start, end, step)` triple.
#[derive(Debug, Clone)]
pub struct StdDevBatchRange {
    /// Window lengths to evaluate.
    pub period: (usize, usize, usize),
    /// Multipliers to evaluate.
    pub nbdev: (f64, f64, f64),
}
577
578impl Default for StdDevBatchRange {
579 fn default() -> Self {
580 Self {
581 period: (5, 254, 1),
582 nbdev: (1.0, 1.0, 0.0),
583 }
584 }
585}
586
587#[derive(Clone, Debug, Default)]
588pub struct StdDevBatchBuilder {
589 range: StdDevBatchRange,
590 kernel: Kernel,
591}
592
593impl StdDevBatchBuilder {
594 pub fn new() -> Self {
595 Self::default()
596 }
597 pub fn kernel(mut self, k: Kernel) -> Self {
598 self.kernel = k;
599 self
600 }
601 #[inline]
602 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
603 self.range.period = (start, end, step);
604 self
605 }
606 #[inline]
607 pub fn period_static(mut self, p: usize) -> Self {
608 self.range.period = (p, p, 0);
609 self
610 }
611 #[inline]
612 pub fn nbdev_range(mut self, start: f64, end: f64, step: f64) -> Self {
613 self.range.nbdev = (start, end, step);
614 self
615 }
616 #[inline]
617 pub fn nbdev_static(mut self, x: f64) -> Self {
618 self.range.nbdev = (x, x, 0.0);
619 self
620 }
621 pub fn apply_slice(self, data: &[f64]) -> Result<StdDevBatchOutput, StdDevError> {
622 stddev_batch_with_kernel(data, &self.range, self.kernel)
623 }
624 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<StdDevBatchOutput, StdDevError> {
625 StdDevBatchBuilder::new().kernel(k).apply_slice(data)
626 }
627 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<StdDevBatchOutput, StdDevError> {
628 let slice = source_type(c, src);
629 self.apply_slice(slice)
630 }
631 pub fn with_default_candles(c: &Candles) -> Result<StdDevBatchOutput, StdDevError> {
632 StdDevBatchBuilder::new()
633 .kernel(Kernel::Auto)
634 .apply_candles(c, "close")
635 }
636}
637
638pub fn stddev_batch_with_kernel(
639 data: &[f64],
640 sweep: &StdDevBatchRange,
641 k: Kernel,
642) -> Result<StdDevBatchOutput, StdDevError> {
643 let kernel = match k {
644 Kernel::Auto => detect_best_batch_kernel(),
645 other if other.is_batch() => other,
646 other => return Err(StdDevError::InvalidKernelForBatch(other)),
647 };
648
649 let simd = match kernel {
650 Kernel::Avx512Batch => Kernel::Avx512,
651 Kernel::Avx2Batch => Kernel::Avx2,
652 Kernel::ScalarBatch => Kernel::Scalar,
653 _ => unreachable!(),
654 };
655 stddev_batch_par_slice(data, sweep, simd)
656}
657
658#[derive(Clone, Debug)]
659pub struct StdDevBatchOutput {
660 pub values: Vec<f64>,
661 pub combos: Vec<StdDevParams>,
662 pub rows: usize,
663 pub cols: usize,
664}
665impl StdDevBatchOutput {
666 pub fn row_for_params(&self, p: &StdDevParams) -> Option<usize> {
667 self.combos.iter().position(|c| {
668 c.period.unwrap_or(5) == p.period.unwrap_or(5)
669 && (c.nbdev.unwrap_or(1.0) - p.nbdev.unwrap_or(1.0)).abs() < 1e-12
670 })
671 }
672 pub fn values_for(&self, p: &StdDevParams) -> Option<&[f64]> {
673 self.row_for_params(p).and_then(|row| {
674 let start = row.checked_mul(self.cols)?;
675 let end = start.checked_add(self.cols)?;
676 self.values.get(start..end)
677 })
678 }
679}
680
681#[inline(always)]
682fn expand_grid_checked(r: &StdDevBatchRange) -> Result<Vec<StdDevParams>, StdDevError> {
683 fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, StdDevError> {
684 if step == 0 || start == end {
685 return Ok(vec![start]);
686 }
687 let mut v = Vec::new();
688 if start < end {
689 let mut cur = start;
690 while cur <= end {
691 v.push(cur);
692 let next = cur.saturating_add(step);
693 if next == cur {
694 break;
695 }
696 cur = next;
697 }
698 } else {
699 let mut cur = start;
700 while cur >= end {
701 v.push(cur);
702 let next = cur.saturating_sub(step);
703 if next == cur {
704 break;
705 }
706 cur = next;
707 if cur == 0 && end > 0 {
708 break;
709 }
710 }
711 }
712 if v.is_empty() {
713 return Err(StdDevError::InvalidRange {
714 start: start.to_string(),
715 end: end.to_string(),
716 step: step.to_string(),
717 });
718 }
719 Ok(v)
720 }
721 fn axis_f64((start, end, step): (f64, f64, f64)) -> Result<Vec<f64>, StdDevError> {
722 if step.abs() < 1e-12 || (start - end).abs() < 1e-12 {
723 return Ok(vec![start]);
724 }
725 let mut out = Vec::new();
726 if start < end {
727 let st = if step > 0.0 { step } else { -step };
728 let mut x = start;
729 while x <= end + 1e-12 {
730 out.push(x);
731 x += st;
732 }
733 } else {
734 let st = if step > 0.0 { -step } else { step };
735 if st.abs() < 1e-12 {
736 return Ok(vec![start]);
737 }
738 let mut x = start;
739 while x >= end - 1e-12 {
740 out.push(x);
741 x += st;
742 }
743 }
744 if out.is_empty() {
745 return Err(StdDevError::InvalidRange {
746 start: start.to_string(),
747 end: end.to_string(),
748 step: step.to_string(),
749 });
750 }
751 Ok(out)
752 }
753
754 let periods = axis_usize(r.period)?;
755 if periods.iter().any(|&p| p == 0) {
756 return Err(StdDevError::InvalidPeriod {
757 period: 0,
758 data_len: 0,
759 });
760 }
761
762 let (nb_start, nb_end, nb_step) = r.nbdev;
763 if !nb_start.is_finite() || nb_start < 0.0 {
764 return Err(StdDevError::InvalidNbdev { nbdev: nb_start });
765 }
766 if !nb_end.is_finite() || nb_end < 0.0 {
767 return Err(StdDevError::InvalidNbdev { nbdev: nb_end });
768 }
769 if !nb_step.is_finite() {
770 return Err(StdDevError::InvalidRange {
771 start: nb_start.to_string(),
772 end: nb_end.to_string(),
773 step: nb_step.to_string(),
774 });
775 }
776 let nbdevs = axis_f64(r.nbdev)?;
777 let cap = periods
778 .len()
779 .checked_mul(nbdevs.len())
780 .ok_or_else(|| StdDevError::InvalidInput {
781 msg: "stddev: parameter grid size overflow".to_string(),
782 })?;
783
784 let mut out = Vec::with_capacity(cap);
785 for &p in &periods {
786 for &n in &nbdevs {
787 out.push(StdDevParams {
788 period: Some(p),
789 nbdev: Some(n),
790 });
791 }
792 }
793 if out.is_empty() {
794 return Err(StdDevError::InvalidRange {
795 start: r.period.0.to_string(),
796 end: r.period.1.to_string(),
797 step: r.period.2.to_string(),
798 });
799 }
800 Ok(out)
801}
802
803#[inline(always)]
804pub fn stddev_batch_slice(
805 data: &[f64],
806 sweep: &StdDevBatchRange,
807 kern: Kernel,
808) -> Result<StdDevBatchOutput, StdDevError> {
809 stddev_batch_inner(data, sweep, kern, false)
810}
811
812#[inline(always)]
813pub fn stddev_batch_par_slice(
814 data: &[f64],
815 sweep: &StdDevBatchRange,
816 kern: Kernel,
817) -> Result<StdDevBatchOutput, StdDevError> {
818 stddev_batch_inner(data, sweep, kern, true)
819}
820
821#[inline(always)]
822fn stddev_batch_inner(
823 data: &[f64],
824 sweep: &StdDevBatchRange,
825 kern: Kernel,
826 parallel: bool,
827) -> Result<StdDevBatchOutput, StdDevError> {
828 let combos = expand_grid_checked(sweep)?;
829 let len = data.len();
830 if len == 0 {
831 return Err(StdDevError::EmptyInputData);
832 }
833
834 let first = data
835 .iter()
836 .position(|x| !x.is_nan())
837 .ok_or(StdDevError::AllValuesNaN)?;
838 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
839 if len - first < max_p {
840 return Err(StdDevError::NotEnoughValidData {
841 needed: max_p,
842 valid: len - first,
843 });
844 }
845 let rows = combos.len();
846 let cols = len;
847
848 let _ = rows
849 .checked_mul(cols)
850 .ok_or_else(|| StdDevError::InvalidInput {
851 msg: "stddev: rows*cols overflow in batch".to_string(),
852 })?;
853
854 let warm: Vec<usize> = combos
855 .iter()
856 .map(|c| first + c.period.unwrap() - 1)
857 .collect();
858
859 let mut buf_mu = make_uninit_matrix(rows, cols);
860 init_matrix_prefixes(&mut buf_mu, cols, &warm);
861
862 let mut values = unsafe {
863 Vec::from_raw_parts(
864 buf_mu.as_mut_ptr() as *mut f64,
865 buf_mu.len(),
866 buf_mu.capacity(),
867 )
868 };
869 std::mem::forget(buf_mu);
870
871 #[derive(Clone)]
872 struct StdPrefixes {
873 ps: Vec<f64>,
874 ps2: Vec<f64>,
875 pnan: Vec<i32>,
876 }
877 #[inline]
878 fn build_std_prefixes(data: &[f64]) -> StdPrefixes {
879 let n = data.len();
880 let mut ps = vec![0.0f64; n + 1];
881 let mut ps2 = vec![0.0f64; n + 1];
882 let mut pnan = vec![0i32; n + 1];
883 for i in 0..n {
884 let v = data[i];
885 if v.is_nan() {
886 ps[i + 1] = ps[i];
887 ps2[i + 1] = ps2[i];
888 pnan[i + 1] = pnan[i] + 1;
889 } else {
890 ps[i + 1] = ps[i] + v;
891 ps2[i + 1] = ps2[i] + v * v;
892 pnan[i + 1] = pnan[i];
893 }
894 }
895 StdPrefixes { ps, ps2, pnan }
896 }
897
898 #[inline]
899 fn stddev_from_prefix_scalar(
900 warmup_end: usize,
901 period: usize,
902 nbdev: f64,
903 pre: &StdPrefixes,
904 out_row: &mut [f64],
905 ) {
906 let n = out_row.len();
907 if n <= warmup_end {
908 return;
909 }
910
911 let inv_den = 1.0 / (period as f64);
912 let inv_den2 = inv_den * inv_den;
913
914 let no_nans = pre.pnan[n] == 0;
915 if no_nans {
916 for i in warmup_end..n {
917 let sum = pre.ps[i + 1] - pre.ps[i + 1 - period];
918 let sum2 = pre.ps2[i + 1] - pre.ps2[i + 1 - period];
919
920 let var = sum2.mul_add(inv_den, -(sum * sum) * inv_den2);
921 out_row[i] = if var <= 0.0 { 0.0 } else { var.sqrt() * nbdev };
922 }
923 return;
924 }
925
926 for i in warmup_end..n {
927 if pre.pnan[i + 1] - pre.pnan[i + 1 - period] > 0 {
928 out_row[i] = f64::NAN;
929 continue;
930 }
931 let sum = pre.ps[i + 1] - pre.ps[i + 1 - period];
932 let sum2 = pre.ps2[i + 1] - pre.ps2[i + 1 - period];
933 let var = sum2.mul_add(inv_den, -(sum * sum) * inv_den2);
934 out_row[i] = if var <= 0.0 { 0.0 } else { var.sqrt() * nbdev };
935 }
936 }
937
938 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
939 #[inline]
940 unsafe fn stddev_from_prefix_avx2(
941 warmup_end: usize,
942 period: usize,
943 nbdev: f64,
944 pre: &StdPrefixes,
945 out_row: &mut [f64],
946 ) {
947 use core::arch::x86_64::*;
948 let n = out_row.len();
949 if n <= warmup_end {
950 return;
951 }
952 let no_nans = pre.pnan[n] == 0;
953 if !no_nans {
954 stddev_from_prefix_scalar(warmup_end, period, nbdev, pre, out_row);
955 return;
956 }
957
958 let inv_den = 1.0 / (period as f64);
959 let inv_den2 = inv_den * inv_den;
960 let v_inv_den = _mm256_set1_pd(inv_den);
961 let v_inv_den2 = _mm256_set1_pd(inv_den2);
962 let v_nbdev = _mm256_set1_pd(nbdev);
963 let v_zero = _mm256_set1_pd(0.0);
964
965 let mut i = warmup_end;
966 while i + 4 <= n {
967 let s_hi = _mm256_loadu_pd(pre.ps.as_ptr().add(i + 1));
968 let s_lo = _mm256_loadu_pd(pre.ps.as_ptr().add(i + 1 - period));
969 let sum = _mm256_sub_pd(s_hi, s_lo);
970
971 let q_hi = _mm256_loadu_pd(pre.ps2.as_ptr().add(i + 1));
972 let q_lo = _mm256_loadu_pd(pre.ps2.as_ptr().add(i + 1 - period));
973 let sum2 = _mm256_sub_pd(q_hi, q_lo);
974
975 let sum_sq = _mm256_mul_pd(sum, sum);
976 let term = _mm256_mul_pd(sum_sq, v_inv_den2);
977 let var = _mm256_sub_pd(_mm256_mul_pd(sum2, v_inv_den), term);
978 let var_pos = _mm256_max_pd(var, v_zero);
979 let stdv = _mm256_sqrt_pd(var_pos);
980 let outv = _mm256_mul_pd(stdv, v_nbdev);
981 _mm256_storeu_pd(out_row.as_mut_ptr().add(i), outv);
982 i += 4;
983 }
984
985 if i < n {
986 stddev_from_prefix_scalar(i, period, nbdev, pre, out_row);
987 }
988 }
989
990 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
991 #[inline]
992 unsafe fn stddev_from_prefix_avx512(
993 warmup_end: usize,
994 period: usize,
995 nbdev: f64,
996 pre: &StdPrefixes,
997 out_row: &mut [f64],
998 ) {
999 use core::arch::x86_64::*;
1000 let n = out_row.len();
1001 if n <= warmup_end {
1002 return;
1003 }
1004 let no_nans = pre.pnan[n] == 0;
1005 if !no_nans {
1006 stddev_from_prefix_scalar(warmup_end, period, nbdev, pre, out_row);
1007 return;
1008 }
1009
1010 let inv_den = 1.0 / (period as f64);
1011 let inv_den2 = inv_den * inv_den;
1012 let v_inv_den = _mm512_set1_pd(inv_den);
1013 let v_inv_den2 = _mm512_set1_pd(inv_den2);
1014 let v_nbdev = _mm512_set1_pd(nbdev);
1015 let v_zero = _mm512_set1_pd(0.0);
1016
1017 let mut i = warmup_end;
1018 while i + 8 <= n {
1019 let s_hi = _mm512_loadu_pd(pre.ps.as_ptr().add(i + 1));
1020 let s_lo = _mm512_loadu_pd(pre.ps.as_ptr().add(i + 1 - period));
1021 let sum = _mm512_sub_pd(s_hi, s_lo);
1022
1023 let q_hi = _mm512_loadu_pd(pre.ps2.as_ptr().add(i + 1));
1024 let q_lo = _mm512_loadu_pd(pre.ps2.as_ptr().add(i + 1 - period));
1025 let sum2 = _mm512_sub_pd(q_hi, q_lo);
1026
1027 let sum_sq = _mm512_mul_pd(sum, sum);
1028 let term = _mm512_mul_pd(sum_sq, v_inv_den2);
1029 let var = _mm512_sub_pd(_mm512_mul_pd(sum2, v_inv_den), term);
1030 let var_pos = _mm512_max_pd(var, v_zero);
1031 let stdv = _mm512_sqrt_pd(var_pos);
1032 let outv = _mm512_mul_pd(stdv, v_nbdev);
1033 _mm512_storeu_pd(out_row.as_mut_ptr().add(i), outv);
1034 i += 8;
1035 }
1036 if i < n {
1037 stddev_from_prefix_scalar(i, period, nbdev, pre, out_row);
1038 }
1039 }
1040
1041 let prefixes = build_std_prefixes(data);
1042
1043 let do_row = |row: usize, out_row: &mut [f64]| {
1044 let period = combos[row].period.unwrap();
1045 let nbdev = combos[row].nbdev.unwrap();
1046 let warmup_end = first + period - 1;
1047 match kern {
1048 Kernel::Scalar => {
1049 stddev_from_prefix_scalar(warmup_end, period, nbdev, &prefixes, out_row)
1050 }
1051 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1052 Kernel::Avx2 => unsafe {
1053 stddev_from_prefix_avx2(warmup_end, period, nbdev, &prefixes, out_row)
1054 },
1055 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1056 Kernel::Avx512 => unsafe {
1057 stddev_from_prefix_avx512(warmup_end, period, nbdev, &prefixes, out_row)
1058 },
1059 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
1060 Kernel::Avx2 | Kernel::Avx512 => {
1061 stddev_from_prefix_scalar(warmup_end, period, nbdev, &prefixes, out_row)
1062 }
1063 _ => unreachable!(),
1064 }
1065 };
1066
1067 if parallel {
1068 #[cfg(not(target_arch = "wasm32"))]
1069 {
1070 values
1071 .par_chunks_mut(cols)
1072 .enumerate()
1073 .for_each(|(row, slice)| do_row(row, slice));
1074 }
1075
1076 #[cfg(target_arch = "wasm32")]
1077 {
1078 for (row, slice) in values.chunks_mut(cols).enumerate() {
1079 do_row(row, slice);
1080 }
1081 }
1082 } else {
1083 for (row, slice) in values.chunks_mut(cols).enumerate() {
1084 do_row(row, slice);
1085 }
1086 }
1087
1088 Ok(StdDevBatchOutput {
1089 values,
1090 combos,
1091 rows,
1092 cols,
1093 })
1094}
1095
1096#[inline(always)]
1097unsafe fn stddev_row_scalar(
1098 data: &[f64],
1099 first: usize,
1100 period: usize,
1101 nbdev: f64,
1102 out: &mut [f64],
1103) {
1104 stddev_scalar(data, period, first, nbdev, out)
1105}
1106
/// Computes one batch row for the AVX2 kernel; currently forwards to
/// [`stddev_scalar`].
///
/// # Safety
/// The body performs no unsafe operation itself; the arguments must satisfy
/// whatever `stddev_scalar` requires of them.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn stddev_row_avx2(data: &[f64], first: usize, period: usize, nbdev: f64, out: &mut [f64]) {
    stddev_scalar(data, period, first, nbdev, out);
}
1112
/// Computes one batch row for the AVX-512 kernel, choosing the short variant
/// for periods of at most 32 and the long variant otherwise.
///
/// # Safety
/// Forwards to `unsafe fn`s that, as written in this file, only call the
/// safe `stddev_scalar`; the arguments must satisfy its requirements.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn stddev_row_avx512(
    data: &[f64],
    first: usize,
    period: usize,
    nbdev: f64,
    out: &mut [f64],
) {
    let long_window = period > 32;
    if long_window {
        stddev_row_avx512_long(data, first, period, nbdev, out)
    } else {
        stddev_row_avx512_short(data, first, period, nbdev, out)
    }
}
1128
/// AVX-512 batch-row variant for periods of at most 32; currently forwards
/// to [`stddev_scalar`].
///
/// # Safety
/// The body performs no unsafe operation itself; the arguments must satisfy
/// whatever `stddev_scalar` requires of them.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn stddev_row_avx512_short(
    data: &[f64],
    first: usize,
    period: usize,
    nbdev: f64,
    out: &mut [f64],
) {
    stddev_scalar(data, period, first, nbdev, out);
}
1140
/// AVX-512 batch-row variant for periods above 32; currently forwards to
/// [`stddev_scalar`].
///
/// # Safety
/// The body performs no unsafe operation itself; the arguments must satisfy
/// whatever `stddev_scalar` requires of them.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn stddev_row_avx512_long(
    data: &[f64],
    first: usize,
    period: usize,
    nbdev: f64,
    out: &mut [f64],
) {
    stddev_scalar(data, period, first, nbdev, out);
}
1152
1153#[inline(always)]
1154pub fn stddev_batch_inner_into(
1155 data: &[f64],
1156 sweep: &StdDevBatchRange,
1157 kern: Kernel,
1158 parallel: bool,
1159 out: &mut [f64],
1160) -> Result<Vec<StdDevParams>, StdDevError> {
1161 let combos = expand_grid_checked(sweep)?;
1162
1163 let len = data.len();
1164 if len == 0 {
1165 return Err(StdDevError::EmptyInputData);
1166 }
1167
1168 let first = data
1169 .iter()
1170 .position(|x| !x.is_nan())
1171 .ok_or(StdDevError::AllValuesNaN)?;
1172 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1173 if len - first < max_p {
1174 return Err(StdDevError::NotEnoughValidData {
1175 needed: max_p,
1176 valid: len - first,
1177 });
1178 }
1179
1180 let rows = combos.len();
1181 let cols = len;
1182 let expected = rows
1183 .checked_mul(cols)
1184 .ok_or_else(|| StdDevError::InvalidInput {
1185 msg: "stddev: rows*cols overflow in batch_into".to_string(),
1186 })?;
1187 if out.len() != expected {
1188 return Err(StdDevError::OutputLengthMismatch {
1189 expected,
1190 got: out.len(),
1191 });
1192 }
1193
1194 let warm: Vec<usize> = combos
1195 .iter()
1196 .map(|c| first + c.period.unwrap() - 1)
1197 .collect();
1198 let out_mu: &mut [MaybeUninit<f64>] = unsafe {
1199 std::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
1200 };
1201 init_matrix_prefixes(out_mu, cols, &warm);
1202
1203 let do_row = |row: usize, out_row: &mut [f64]| unsafe {
1204 let period = combos[row].period.unwrap();
1205 let nbdev = combos[row].nbdev.unwrap();
1206 match kern {
1207 Kernel::Scalar => stddev_row_scalar(data, first, period, nbdev, out_row),
1208 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1209 Kernel::Avx2 => stddev_row_avx2(data, first, period, nbdev, out_row),
1210 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1211 Kernel::Avx512 => stddev_row_avx512(data, first, period, nbdev, out_row),
1212 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
1213 Kernel::Avx2 | Kernel::Avx512 => stddev_row_scalar(data, first, period, nbdev, out_row),
1214 _ => unreachable!(),
1215 }
1216 };
1217
1218 if parallel {
1219 #[cfg(not(target_arch = "wasm32"))]
1220 {
1221 out.par_chunks_mut(cols)
1222 .enumerate()
1223 .for_each(|(row, slice)| do_row(row, slice));
1224 }
1225
1226 #[cfg(target_arch = "wasm32")]
1227 {
1228 for (row, slice) in out.chunks_mut(cols).enumerate() {
1229 do_row(row, slice);
1230 }
1231 }
1232 } else {
1233 for (row, slice) in out.chunks_mut(cols).enumerate() {
1234 do_row(row, slice);
1235 }
1236 }
1237
1238 Ok(combos)
1239}
1240
1241#[inline(always)]
1242pub fn expand_grid_stddev(r: &StdDevBatchRange) -> Vec<StdDevParams> {
1243 expand_grid_checked(r).unwrap_or_else(|_| Vec::new())
1244}
1245
/// Python binding `stddev(data, period, nbdev, kernel=None)`.
///
/// Runs `stddev_with_kernel` over the contiguous float64 array `data` with the
/// GIL released and returns the result as a new 1-D NumPy array.
///
/// Raises `ValueError` carrying the indicator's error message when the
/// computation fails; errors from `as_slice` and `validate_kernel` are
/// propagated unchanged.
#[cfg(feature = "python")]
#[pyfunction(name = "stddev")]
#[pyo3(signature = (data, period, nbdev, kernel=None))]
pub fn stddev_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    nbdev: f64,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let series = data.as_slice()?;
    let chosen = validate_kernel(kernel, false)?;

    let input = StdDevInput::from_slice(
        series,
        StdDevParams {
            period: Some(period),
            nbdev: Some(nbdev),
        },
    );

    // The heavy computation runs without holding the GIL.
    let computed = py.allow_threads(|| stddev_with_kernel(&input, chosen).map(|o| o.values));
    let values: Vec<f64> = match computed {
        Ok(v) => v,
        Err(e) => return Err(PyValueError::new_err(e.to_string())),
    };

    Ok(values.into_pyarray(py))
}
1273
/// Python-facing wrapper (`StdDevStream`) around the incremental
/// `StdDevStream` calculator.
#[cfg(feature = "python")]
#[pyclass(name = "StdDevStream")]
pub struct StdDevStreamPy {
    stream: StdDevStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl StdDevStreamPy {
    /// Builds a stream for the given `period` and `nbdev`.
    ///
    /// Raises `ValueError` with the underlying message when
    /// `StdDevStream::try_new` rejects the parameters.
    #[new]
    fn new(period: usize, nbdev: f64) -> PyResult<Self> {
        StdDevStream::try_new(StdDevParams {
            period: Some(period),
            nbdev: Some(nbdev),
        })
        .map(|stream| StdDevStreamPy { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feeds one value to the stream and returns whatever the underlying
    /// stream yields for it (`None` maps to Python `None`).
    fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
1298
/// Python binding `stddev_batch(data, period_range, nbdev_range, kernel=None)`.
///
/// Expands the two `(start, end, step)` ranges into a parameter grid, computes
/// one row per combination directly into a freshly allocated NumPy buffer
/// (GIL released, rows processed in parallel) and returns a dict with:
///
/// * `"values"`  - the buffer reshaped to `(rows, cols)`, `cols == len(data)`;
/// * `"periods"` - the period of each row, as unsigned 64-bit integers;
/// * `"nbdevs"`  - the `nbdev` of each row.
///
/// Raises `ValueError` when the grid is invalid, `rows * cols` overflows, or
/// the batch computation fails.
#[cfg(feature = "python")]
#[pyfunction(name = "stddev_batch")]
#[pyo3(signature = (data, period_range, nbdev_range, kernel=None))]
pub fn stddev_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    nbdev_range: (f64, f64, f64),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let series = data.as_slice()?;
    let requested = validate_kernel(kernel, true)?;

    let sweep = StdDevBatchRange {
        period: period_range,
        nbdev: nbdev_range,
    };

    // The grid is expanded up front only to size the output buffer.
    let grid = expand_grid_checked(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let (rows, cols) = (grid.len(), series.len());
    let total = match rows.checked_mul(cols) {
        Some(n) => n,
        None => return Err(PyValueError::new_err("rows*cols overflow")),
    };

    // Uninitialised allocation; the batch routine is expected to write it.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            let batch_kernel = match requested {
                Kernel::Auto => detect_best_batch_kernel(),
                other => other,
            };
            // Translate the batch kernel into the per-row kernel it implies.
            let row_kernel = match batch_kernel {
                Kernel::ScalarBatch => Kernel::Scalar,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::Avx512Batch => Kernel::Avx512,
                _ => unreachable!(),
            };
            stddev_batch_inner_into(series, &sweep, row_kernel, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let periods: Vec<u64> = combos.iter().map(|p| p.period.unwrap() as u64).collect();
    let nbdevs: Vec<f64> = combos.iter().map(|p| p.nbdev.unwrap()).collect();

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item("periods", periods.into_pyarray(py))?;
    dict.set_item("nbdevs", nbdevs.into_pyarray(py))?;

    Ok(dict)
}
1367
/// Python binding `stddev_cuda_batch_dev(data_f32, period_range,
/// nbdev_range=(1.0, 1.0, 0.0), device_id=0)`.
///
/// Runs the batch sweep through `CudaStddev::stddev_batch_dev` on the selected
/// device (GIL released) and wraps the first element of its result in a device
/// array handle via `make_device_array_py`.
///
/// Raises `ValueError("CUDA not available")` when `cuda_available()` is false,
/// and `ValueError` with the underlying message when creating the CUDA context
/// or running the kernel fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "stddev_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, nbdev_range=(1.0, 1.0, 0.0), device_id=0))]
pub fn stddev_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    data_f32: numpy::PyReadonlyArray1<'py, f32>,
    period_range: (usize, usize, usize),
    nbdev_range: (f64, f64, f64),
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    use crate::cuda::cuda_available;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let host_data = data_f32.as_slice()?;
    let sweep = StdDevBatchRange {
        period: period_range,
        nbdev: nbdev_range,
    };

    let device_buf = py.allow_threads(|| {
        let cuda = CudaStddev::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        // Only the device buffer is kept; the second tuple element is dropped.
        cuda.stddev_batch_dev(host_data, &sweep)
            .map(|(dev, _)| dev)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    Ok(make_device_array_py(device_id, device_buf)?)
}
1399
/// Python binding `stddev_cuda_many_series_one_param_dev(data_tm_f32, cols,
/// rows, period, nbdev, device_id=0)`.
///
/// Forwards the flat `f32` buffer together with `cols`, `rows`, `period` and
/// `nbdev` (narrowed to `f32`) to
/// `CudaStddev::stddev_many_series_one_param_time_major_dev` with the GIL
/// released, then wraps the result in a device array handle.
///
/// Raises `ValueError("CUDA not available")` when `cuda_available()` is false,
/// and `ValueError` with the underlying message on any CUDA-side failure.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "stddev_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, cols, rows, period, nbdev, device_id=0))]
pub fn stddev_cuda_many_series_one_param_dev_py<'py>(
    py: Python<'py>,
    data_tm_f32: numpy::PyReadonlyArray1<'py, f32>,
    cols: usize,
    rows: usize,
    period: usize,
    nbdev: f64,
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    use crate::cuda::cuda_available;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let host_data = data_tm_f32.as_slice()?;
    let device_buf = py.allow_threads(|| {
        let cuda = CudaStddev::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        cuda.stddev_many_series_one_param_time_major_dev(
            host_data,
            cols,
            rows,
            period,
            nbdev as f32,
        )
        .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    Ok(make_device_array_py(device_id, device_buf)?)
}
1426
/// WebAssembly binding: computes the indicator for `data` with the given
/// `period` and `nbdev` and returns a new vector of `data.len()` values.
///
/// The output buffer is zero-filled and then handed to `stddev_into_slice`
/// together with the kernel chosen by `detect_best_kernel()`. Failures are
/// returned as a `JsValue` string holding the error message.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn stddev_js(data: &[f64], period: usize, nbdev: f64) -> Result<Vec<f64>, JsValue> {
    let input = StdDevInput::from_slice(
        data,
        StdDevParams {
            period: Some(period),
            nbdev: Some(nbdev),
        },
    );

    let mut values = vec![0.0; data.len()];
    match stddev_into_slice(&mut values, &input, detect_best_kernel()) {
        Ok(_) => Ok(values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1442
/// WebAssembly binding that reads `len` values from `in_ptr` and writes `len`
/// results to `out_ptr`.
///
/// Both pointers are turned into slices of `len` elements, so the caller must
/// pass pointers valid for that many `f64`s. When the two pointers are equal
/// (in-place call) the result is computed into a temporary vector first and
/// copied to the destination afterwards.
///
/// Returns `Err("Null pointer provided")` when either pointer is null, or the
/// indicator's error message when the computation fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn stddev_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
    nbdev: f64,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let input = StdDevInput::from_slice(
            data,
            StdDevParams {
                period: Some(period),
                nbdev: Some(nbdev),
            },
        );
        let kernel = detect_best_kernel();

        if in_ptr != out_ptr as *const f64 {
            // Distinct buffers: write straight into the destination.
            let dst = std::slice::from_raw_parts_mut(out_ptr, len);
            stddev_into_slice(dst, &input, kernel)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        } else {
            // In-place call: compute into scratch space so the input is not
            // overwritten while it is still being read.
            let mut scratch = vec![0.0; len];
            stddev_into_slice(&mut scratch, &input, kernel)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            std::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&scratch);
        }
    }

    Ok(())
}
1478
/// WebAssembly helper: allocates room for `len` `f64` values and returns the
/// raw pointer to it.
///
/// The backing `Vec` is deliberately leaked with `mem::forget`, so the memory
/// stays allocated after this function returns. The elements are not
/// initialised.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn stddev_alloc(len: usize) -> *mut f64 {
    let mut buffer: Vec<f64> = Vec::with_capacity(len);
    let raw = buffer.as_mut_ptr();
    // Give up ownership without running the destructor.
    std::mem::forget(buffer);
    raw
}
1487
/// WebAssembly helper: releases a buffer of capacity `len` by rebuilding a
/// `Vec<f64>` around `ptr` and dropping it.
///
/// A null `ptr` is ignored. `ptr` / `len` are passed to
/// `Vec::from_raw_parts` as pointer and capacity, so they must describe an
/// allocation made by `Vec<f64>` with exactly that capacity.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn stddev_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: relies on the caller contract documented above. The vector is
    // rebuilt with length 0 so that no element is claimed to be initialised:
    // the buffer may never have been written, and `f64` needs no drop, so the
    // deallocation is identical either way.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
1497
/// Batch-sweep configuration deserialised from a JavaScript value.
///
/// Each tuple is assigned as-is to the matching field of `StdDevBatchRange`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct StdDevBatchConfig {
    /// Period range as `(start, end, step)`.
    pub period_range: (usize, usize, usize),
    /// `nbdev` range as `(start, end, step)`.
    pub nbdev_range: (f64, f64, f64),
}
1504
/// Batch result serialised back to JavaScript.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct StdDevBatchJsOutput {
    /// Flattened output values (presumably `rows * cols` entries, row after
    /// row — confirm against the batch routine that fills it).
    pub values: Vec<f64>,
    /// Parameter combination used for each row.
    pub combos: Vec<StdDevParams>,

    /// Period of each row, extracted from `combos`.
    pub periods: Vec<usize>,
    /// `nbdev` of each row, extracted from `combos`.
    pub nbdevs: Vec<f64>,
    /// Number of rows.
    pub rows: usize,
    /// Number of columns.
    pub cols: usize,
}
1516
/// WebAssembly binding exported as `stddev_batch`.
///
/// Deserialises `config` into a `StdDevBatchConfig`, runs `stddev_batch_inner`
/// sequentially with the kernel from `detect_best_kernel()`, and returns the
/// result serialised as a `StdDevBatchJsOutput`.
///
/// Errors are returned as `JsValue` strings: `"Invalid config: …"` for a bad
/// configuration, the indicator's message for a failed computation, and
/// `"Serialization error: …"` when the output cannot be serialised.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = stddev_batch)]
pub fn stddev_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let config: StdDevBatchConfig = match serde_wasm_bindgen::from_value(config) {
        Ok(parsed) => parsed,
        Err(e) => return Err(JsValue::from_str(&format!("Invalid config: {}", e))),
    };

    let sweep = StdDevBatchRange {
        period: config.period_range,
        nbdev: config.nbdev_range,
    };

    let batch = stddev_batch_inner(data, &sweep, detect_best_kernel(), false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    // Extract the per-row parameters before `combos` is moved into the output.
    let periods: Vec<usize> = batch.combos.iter().map(|p| p.period.unwrap()).collect();
    let nbdevs: Vec<f64> = batch.combos.iter().map(|p| p.nbdev.unwrap()).collect();

    let js_output = StdDevBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        periods,
        nbdevs,
        rows: batch.rows,
        cols: batch.cols,
    };

    match serde_wasm_bindgen::to_value(&js_output) {
        Ok(value) => Ok(value),
        Err(e) => Err(JsValue::from_str(&format!("Serialization error: {}", e))),
    }
}
1543
/// WebAssembly binding that runs a batch sweep from raw pointers.
///
/// Reads `len` values from `in_ptr` and writes `rows * len` values to
/// `out_ptr`, where `rows` is the number of parameter combinations produced by
/// the two `(start, end, step)` ranges. The caller must pass pointers valid
/// for those element counts. Returns the number of rows on success.
///
/// Errors are returned as `JsValue` strings: a null-pointer message, the
/// grid-expansion error, a `rows*cols` overflow message, or the batch error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn stddev_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
    nbdev_start: f64,
    nbdev_end: f64,
    nbdev_step: f64,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str(
            "null pointer passed to stddev_batch_into",
        ));
    }

    let sweep = StdDevBatchRange {
        period: (period_start, period_end, period_step),
        nbdev: (nbdev_start, nbdev_end, nbdev_step),
    };

    // The grid is expanded here only to learn how many rows will be written.
    let rows = expand_grid_checked(&sweep)
        .map_err(|e| JsValue::from_str(&e.to_string()))?
        .len();
    let total = match rows.checked_mul(len) {
        Some(n) => n,
        None => return Err(JsValue::from_str("rows*cols overflow in stddev_batch_into")),
    };

    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let out = std::slice::from_raw_parts_mut(out_ptr, total);

        stddev_batch_inner_into(data, &sweep, detect_best_kernel(), false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(rows)
}
1583
/// WebAssembly binding that runs a batch sweep from raw pointers, configured
/// through a JavaScript `config` object (`StdDevBatchConfig`).
///
/// Reads `len` values from `in_ptr` and writes `rows * len` values to
/// `out_ptr`; the caller must pass pointers valid for those element counts.
/// The returned value is a serialised `StdDevBatchJsOutput` whose `values`
/// field is left empty (the results live in the output buffer) while
/// `combos`, `periods`, `nbdevs`, `rows` and `cols` describe the layout.
///
/// Errors are returned as `JsValue` strings: null pointers, an invalid
/// configuration, a failed or empty grid expansion, a `rows*cols` overflow,
/// a failed batch computation, or a serialisation failure.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn stddev_batch_into_cfg(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    config: JsValue,
) -> Result<JsValue, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    let config: StdDevBatchConfig = match serde_wasm_bindgen::from_value(config) {
        Ok(parsed) => parsed,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };

    let sweep = StdDevBatchRange {
        period: config.period_range,
        nbdev: config.nbdev_range,
    };

    let rows = expand_grid_checked(&sweep)
        .map_err(|e| JsValue::from_str(&e.to_string()))?
        .len();
    if rows == 0 {
        return Err(JsValue::from_str("No parameter combinations generated"));
    }

    let cols = len;
    let total = match rows.checked_mul(cols) {
        Some(n) => n,
        None => {
            return Err(JsValue::from_str(
                "rows*cols overflow in stddev_batch_into_cfg",
            ))
        }
    };

    let params = unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let out = std::slice::from_raw_parts_mut(out_ptr, total);

        stddev_batch_inner_into(data, &sweep, detect_best_kernel(), false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?
    };

    let periods: Vec<usize> = params.iter().map(|p| p.period.unwrap()).collect();
    let nbdevs: Vec<f64> = params.iter().map(|p| p.nbdev.unwrap()).collect();
    let result = StdDevBatchJsOutput {
        // Results were written through `out_ptr`; nothing is copied here.
        values: Vec::new(),
        combos: params,
        periods,
        nbdevs,
        rows,
        cols,
    };

    serde_wasm_bindgen::to_value(&result).map_err(|e| JsValue::from_str(&e.to_string()))
}
1634
1635#[cfg(test)]
1636mod tests {
1637 use super::*;
1638 use crate::skip_if_unsupported;
1639 use crate::utilities::data_loader::read_candles_from_csv;
1640
1641 fn check_stddev_empty_input(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1642 skip_if_unsupported!(kernel, test);
1643 let empty: [f64; 0] = [];
1644 let input = StdDevInput::from_slice(&empty, StdDevParams::default());
1645 let res = stddev_with_kernel(&input, kernel);
1646 assert!(matches!(res, Err(StdDevError::EmptyInputData)));
1647 Ok(())
1648 }
1649
1650 fn check_stddev_invalid_batch_kernel(
1651 test: &str,
1652 _kernel: Kernel,
1653 ) -> Result<(), Box<dyn Error>> {
1654 let data = [1.0, 2.0, 3.0, 4.0, 5.0];
1655 let sweep = StdDevBatchRange::default();
1656
1657 let res = stddev_batch_with_kernel(&data, &sweep, Kernel::Scalar);
1658 assert!(matches!(res, Err(StdDevError::InvalidKernelForBatch(_))));
1659
1660 let res2 = stddev_batch_with_kernel(&data, &sweep, Kernel::Avx2);
1661 assert!(matches!(res2, Err(StdDevError::InvalidKernelForBatch(_))));
1662 Ok(())
1663 }
1664
1665 fn check_stddev_mismatched_output_len(
1666 test: &str,
1667 kernel: Kernel,
1668 ) -> Result<(), Box<dyn Error>> {
1669 skip_if_unsupported!(kernel, test);
1670 let data = [1.0, 2.0, 3.0, 4.0, 5.0];
1671 let params = StdDevParams::default();
1672 let input = StdDevInput::from_slice(&data, params);
1673
1674 let mut wrong_size_output = vec![0.0; 10];
1675 let res = stddev_into_slice(&mut wrong_size_output, &input, kernel);
1676 assert!(matches!(res, Err(StdDevError::OutputLengthMismatch { .. })));
1677
1678 let mut small_output = vec![0.0; 3];
1679 let res2 = stddev_into_slice(&mut small_output, &input, kernel);
1680 assert!(matches!(
1681 res2,
1682 Err(StdDevError::OutputLengthMismatch { .. })
1683 ));
1684 Ok(())
1685 }
1686
1687 fn check_stddev_negative_nbdev(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1688 skip_if_unsupported!(kernel, test);
1689 let data = [1.0, 2.0, 3.0, 4.0, 5.0];
1690 let params = StdDevParams {
1691 period: Some(3),
1692 nbdev: Some(-1.0),
1693 };
1694 let input = StdDevInput::from_slice(&data, params.clone());
1695 let res = stddev_with_kernel(&input, kernel);
1696 assert!(matches!(res, Err(StdDevError::InvalidNbdev { .. })));
1697
1698 let stream_res = StdDevStream::try_new(params);
1699 assert!(matches!(stream_res, Err(StdDevError::InvalidNbdev { .. })));
1700
1701 let inf_params = StdDevParams {
1702 period: Some(3),
1703 nbdev: Some(f64::INFINITY),
1704 };
1705 let inf_input = StdDevInput::from_slice(&data, inf_params);
1706 let inf_res = stddev_with_kernel(&inf_input, kernel);
1707 assert!(matches!(inf_res, Err(StdDevError::InvalidNbdev { .. })));
1708
1709 let nan_params = StdDevParams {
1710 period: Some(3),
1711 nbdev: Some(f64::NAN),
1712 };
1713 let nan_input = StdDevInput::from_slice(&data, nan_params);
1714 let nan_res = stddev_with_kernel(&nan_input, kernel);
1715 assert!(matches!(nan_res, Err(StdDevError::InvalidNbdev { .. })));
1716 Ok(())
1717 }
1718
1719 fn check_stddev_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1720 skip_if_unsupported!(kernel, test_name);
1721 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1722 let candles = read_candles_from_csv(file_path)?;
1723
1724 let default_params = StdDevParams {
1725 period: None,
1726 nbdev: None,
1727 };
1728 let input = StdDevInput::from_candles(&candles, "close", default_params);
1729 let output = stddev_with_kernel(&input, kernel)?;
1730 assert_eq!(output.values.len(), candles.close.len());
1731 Ok(())
1732 }
1733
1734 fn check_stddev_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1735 skip_if_unsupported!(kernel, test_name);
1736 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1737 let candles = read_candles_from_csv(file_path)?;
1738
1739 let input = StdDevInput::from_candles(&candles, "close", StdDevParams::default());
1740 let result = stddev_with_kernel(&input, kernel)?;
1741 let expected_last_five = [
1742 180.12506767314034,
1743 77.7395652441455,
1744 127.16225857341935,
1745 89.40156600773197,
1746 218.50034325919697,
1747 ];
1748 let start = result.values.len().saturating_sub(5);
1749 for (i, &val) in result.values[start..].iter().enumerate() {
1750 let diff = (val - expected_last_five[i]).abs();
1751 assert!(
1752 diff < 1e-1,
1753 "[{}] STDDEV {:?} mismatch at idx {}: got {}, expected {}",
1754 test_name,
1755 kernel,
1756 i,
1757 val,
1758 expected_last_five[i]
1759 );
1760 }
1761 Ok(())
1762 }
1763
1764 fn check_stddev_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1765 skip_if_unsupported!(kernel, test_name);
1766 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1767 let candles = read_candles_from_csv(file_path)?;
1768
1769 let input = StdDevInput::with_default_candles(&candles);
1770 match input.data {
1771 StdDevData::Candles { source, .. } => assert_eq!(source, "close"),
1772 _ => panic!("Expected StdDevData::Candles"),
1773 }
1774 let output = stddev_with_kernel(&input, kernel)?;
1775 assert_eq!(output.values.len(), candles.close.len());
1776 Ok(())
1777 }
1778
1779 fn check_stddev_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1780 skip_if_unsupported!(kernel, test_name);
1781 let input_data = [10.0, 20.0, 30.0];
1782 let params = StdDevParams {
1783 period: Some(0),
1784 nbdev: None,
1785 };
1786 let input = StdDevInput::from_slice(&input_data, params);
1787 let res = stddev_with_kernel(&input, kernel);
1788 assert!(
1789 res.is_err(),
1790 "[{}] STDDEV should fail with zero period",
1791 test_name
1792 );
1793 Ok(())
1794 }
1795
1796 fn check_stddev_period_exceeds_length(
1797 test_name: &str,
1798 kernel: Kernel,
1799 ) -> Result<(), Box<dyn Error>> {
1800 skip_if_unsupported!(kernel, test_name);
1801 let data_small = [10.0, 20.0, 30.0];
1802 let params = StdDevParams {
1803 period: Some(10),
1804 nbdev: None,
1805 };
1806 let input = StdDevInput::from_slice(&data_small, params);
1807 let res = stddev_with_kernel(&input, kernel);
1808 assert!(
1809 res.is_err(),
1810 "[{}] STDDEV should fail with period exceeding length",
1811 test_name
1812 );
1813 Ok(())
1814 }
1815
1816 fn check_stddev_very_small_dataset(
1817 test_name: &str,
1818 kernel: Kernel,
1819 ) -> Result<(), Box<dyn Error>> {
1820 skip_if_unsupported!(kernel, test_name);
1821 let single_point = [42.0];
1822 let params = StdDevParams {
1823 period: Some(5),
1824 nbdev: None,
1825 };
1826 let input = StdDevInput::from_slice(&single_point, params);
1827 let res = stddev_with_kernel(&input, kernel);
1828 assert!(
1829 res.is_err(),
1830 "[{}] STDDEV should fail with insufficient data",
1831 test_name
1832 );
1833 Ok(())
1834 }
1835
1836 fn check_stddev_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1837 skip_if_unsupported!(kernel, test_name);
1838 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1839 let candles = read_candles_from_csv(file_path)?;
1840
1841 let first_params = StdDevParams {
1842 period: Some(10),
1843 nbdev: Some(1.0),
1844 };
1845 let first_input = StdDevInput::from_candles(&candles, "close", first_params);
1846 let first_result = stddev_with_kernel(&first_input, kernel)?;
1847
1848 let second_params = StdDevParams {
1849 period: Some(10),
1850 nbdev: Some(1.0),
1851 };
1852 let second_input = StdDevInput::from_slice(&first_result.values, second_params);
1853 let second_result = stddev_with_kernel(&second_input, kernel)?;
1854
1855 assert_eq!(second_result.values.len(), first_result.values.len());
1856 for i in 19..second_result.values.len() {
1857 assert!(
1858 !second_result.values[i].is_nan(),
1859 "STDDEV slice reinput: Expected no NaN after index 19, but found NaN at index {}",
1860 i
1861 );
1862 }
1863 Ok(())
1864 }
1865
1866 #[test]
1867 fn test_stddev_into_matches_api() -> Result<(), Box<dyn Error>> {
1868 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1869 let candles = read_candles_from_csv(file_path)?;
1870
1871 let input = StdDevInput::from_candles(&candles, "close", StdDevParams::default());
1872
1873 let baseline = stddev(&input)?.values;
1874
1875 let mut out = vec![0.0; baseline.len()];
1876 #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1877 {
1878 stddev_into(&input, &mut out)?;
1879 }
1880 #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1881 {
1882 stddev_into_slice(&mut out, &input, Kernel::Auto)?;
1883 }
1884
1885 assert_eq!(baseline.len(), out.len());
1886
1887 let eq_or_both_nan = |a: f64, b: f64| -> bool {
1888 (a.is_nan() && b.is_nan()) || (a == b) || ((a - b).abs() <= 1e-12)
1889 };
1890
1891 for i in 0..out.len() {
1892 assert!(
1893 eq_or_both_nan(baseline[i], out[i]),
1894 "Mismatch at index {}: baseline={}, into={}",
1895 i,
1896 baseline[i],
1897 out[i]
1898 );
1899 }
1900 Ok(())
1901 }
1902
1903 fn check_stddev_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1904 skip_if_unsupported!(kernel, test_name);
1905 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1906 let candles = read_candles_from_csv(file_path)?;
1907
1908 let input = StdDevInput::from_candles(
1909 &candles,
1910 "close",
1911 StdDevParams {
1912 period: Some(5),
1913 nbdev: None,
1914 },
1915 );
1916 let res = stddev_with_kernel(&input, kernel)?;
1917 assert_eq!(res.values.len(), candles.close.len());
1918 if res.values.len() > 20 {
1919 for (i, &val) in res.values[20..].iter().enumerate() {
1920 assert!(
1921 !val.is_nan(),
1922 "[{}] Found unexpected NaN at out-index {}",
1923 test_name,
1924 20 + i
1925 );
1926 }
1927 }
1928 Ok(())
1929 }
1930
1931 fn check_stddev_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1932 skip_if_unsupported!(kernel, test_name);
1933
1934 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1935 let candles = read_candles_from_csv(file_path)?;
1936
1937 let period = 5;
1938 let nbdev = 1.0;
1939
1940 let input = StdDevInput::from_candles(
1941 &candles,
1942 "close",
1943 StdDevParams {
1944 period: Some(period),
1945 nbdev: Some(nbdev),
1946 },
1947 );
1948 let batch_output = stddev_with_kernel(&input, kernel)?.values;
1949
1950 let mut stream = StdDevStream::try_new(StdDevParams {
1951 period: Some(period),
1952 nbdev: Some(nbdev),
1953 })?;
1954
1955 let mut stream_values = Vec::with_capacity(candles.close.len());
1956 for &price in &candles.close {
1957 match stream.update(price) {
1958 Some(val) => stream_values.push(val),
1959 None => stream_values.push(f64::NAN),
1960 }
1961 }
1962
1963 assert_eq!(batch_output.len(), stream_values.len());
1964 for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
1965 if b.is_nan() && s.is_nan() {
1966 continue;
1967 }
1968 let diff = (b - s).abs();
1969 assert!(
1970 diff < 1e-9,
1971 "[{}] STDDEV streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
1972 test_name,
1973 i,
1974 b,
1975 s,
1976 diff
1977 );
1978 }
1979 Ok(())
1980 }
1981
1982 #[cfg(debug_assertions)]
1983 fn check_stddev_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1984 skip_if_unsupported!(kernel, test_name);
1985
1986 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1987 let candles = read_candles_from_csv(file_path)?;
1988
1989 let test_params = vec![
1990 StdDevParams::default(),
1991 StdDevParams {
1992 period: Some(2),
1993 nbdev: Some(1.0),
1994 },
1995 StdDevParams {
1996 period: Some(3),
1997 nbdev: Some(0.5),
1998 },
1999 StdDevParams {
2000 period: Some(5),
2001 nbdev: Some(1.0),
2002 },
2003 StdDevParams {
2004 period: Some(5),
2005 nbdev: Some(2.0),
2006 },
2007 StdDevParams {
2008 period: Some(7),
2009 nbdev: Some(1.5),
2010 },
2011 StdDevParams {
2012 period: Some(10),
2013 nbdev: Some(1.0),
2014 },
2015 StdDevParams {
2016 period: Some(10),
2017 nbdev: Some(3.0),
2018 },
2019 StdDevParams {
2020 period: Some(20),
2021 nbdev: Some(1.0),
2022 },
2023 StdDevParams {
2024 period: Some(20),
2025 nbdev: Some(2.5),
2026 },
2027 StdDevParams {
2028 period: Some(30),
2029 nbdev: Some(1.0),
2030 },
2031 StdDevParams {
2032 period: Some(50),
2033 nbdev: Some(2.0),
2034 },
2035 StdDevParams {
2036 period: Some(100),
2037 nbdev: Some(1.0),
2038 },
2039 StdDevParams {
2040 period: Some(100),
2041 nbdev: Some(3.0),
2042 },
2043 ];
2044
2045 for (param_idx, params) in test_params.iter().enumerate() {
2046 let input = StdDevInput::from_candles(&candles, "close", params.clone());
2047 let output = stddev_with_kernel(&input, kernel)?;
2048
2049 for (i, &val) in output.values.iter().enumerate() {
2050 if val.is_nan() {
2051 continue;
2052 }
2053
2054 let bits = val.to_bits();
2055
2056 if bits == 0x11111111_11111111 {
2057 panic!(
2058 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
2059 with params: {:?} (param set {})",
2060 test_name, val, bits, i, params, param_idx
2061 );
2062 }
2063
2064 if bits == 0x22222222_22222222 {
2065 panic!(
2066 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
2067 with params: {:?} (param set {})",
2068 test_name, val, bits, i, params, param_idx
2069 );
2070 }
2071
2072 if bits == 0x33333333_33333333 {
2073 panic!(
2074 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
2075 with params: {:?} (param set {})",
2076 test_name, val, bits, i, params, param_idx
2077 );
2078 }
2079 }
2080 }
2081
2082 Ok(())
2083 }
2084
2085 #[cfg(not(debug_assertions))]
2086 fn check_stddev_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2087 Ok(())
2088 }
2089
2090 #[cfg(feature = "proptest")]
2091 #[allow(clippy::float_cmp)]
2092 fn check_stddev_property(
2093 test_name: &str,
2094 kernel: Kernel,
2095 ) -> Result<(), Box<dyn std::error::Error>> {
2096 use proptest::prelude::*;
2097 skip_if_unsupported!(kernel, test_name);
2098
2099 let strat = (2usize..=30, 0.5f64..=3.0f64).prop_flat_map(|(period, nbdev)| {
2100 (
2101 prop::collection::vec(
2102 (0.01f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
2103 period..400,
2104 ),
2105 Just(period),
2106 Just(nbdev),
2107 )
2108 });
2109
2110 proptest::test_runner::TestRunner::default().run(&strat, |(data, period, nbdev)| {
2111 let params = StdDevParams {
2112 period: Some(period),
2113 nbdev: Some(nbdev),
2114 };
2115 let input = StdDevInput::from_slice(&data, params);
2116
2117 let StdDevOutput { values: out } = stddev_with_kernel(&input, kernel).unwrap();
2118 let StdDevOutput { values: ref_out } =
2119 stddev_with_kernel(&input, Kernel::Scalar).unwrap();
2120
2121 let warmup_period = period - 1;
2122
2123 for i in 0..warmup_period {
2124 prop_assert!(
2125 out[i].is_nan(),
2126 "Expected NaN during warmup at index {}, got {}",
2127 i,
2128 out[i]
2129 );
2130 }
2131
2132 for i in warmup_period..data.len() {
2133 let y = out[i];
2134 let r = ref_out[i];
2135
2136 prop_assert!(
2137 y.is_nan() || y >= 0.0,
2138 "StdDev at index {} is negative: {}",
2139 i,
2140 y
2141 );
2142
2143 let y_bits = y.to_bits();
2144 let r_bits = r.to_bits();
2145 let ulp_diff = if y_bits > r_bits {
2146 y_bits - r_bits
2147 } else {
2148 r_bits - y_bits
2149 };
2150 prop_assert!(
2151 ulp_diff <= 3 || (y.is_nan() && r.is_nan()),
2152 "Kernel mismatch at index {}: {} vs {} (ULP diff: {})",
2153 i,
2154 y,
2155 r,
2156 ulp_diff
2157 );
2158
2159 let window = &data[i + 1 - period..=i];
2160 let is_constant = window.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-12);
2161 if is_constant {
2162 prop_assert!(
2163 y.abs() < 1e-9,
2164 "StdDev should be ~0 for constant data at index {}, got {}",
2165 i,
2166 y
2167 );
2168 }
2169
2170 if (nbdev - 1.0).abs() < 1e-9 {
2171 let window_min = window.iter().cloned().fold(f64::INFINITY, f64::min);
2172 let window_max = window.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
2173 let range = window_max - window_min;
2174
2175 let max_stddev =
2176 (range / 2.0) * ((period as f64) / ((period - 1) as f64)).sqrt();
2177
2178 prop_assert!(
2179 y <= max_stddev + 1e-9,
2180 "StdDev at index {} ({}) exceeds theoretical maximum ({})",
2181 i,
2182 y,
2183 max_stddev
2184 );
2185 }
2186
2187 if (nbdev - 1.0).abs() > 1e-9 {
2188 let params_unit = StdDevParams {
2189 period: Some(period),
2190 nbdev: Some(1.0),
2191 };
2192 let input_unit = StdDevInput::from_slice(&data, params_unit);
2193 let StdDevOutput { values: out_unit } =
2194 stddev_with_kernel(&input_unit, kernel).unwrap();
2195 let y_unit = out_unit[i];
2196
2197 let expected = y_unit * nbdev;
2198 let diff = (y - expected).abs();
2199 prop_assert!(
2200 diff < 1e-9 || (y.is_nan() && y_unit.is_nan()),
2201 "Scaling mismatch at index {}: {} != {} * {} = {}",
2202 i,
2203 y,
2204 y_unit,
2205 nbdev,
2206 expected
2207 );
2208 }
2209 }
2210
2211 if period == 2 && data.len() >= 2 {
2212 let identical_data = vec![42.0; 10];
2213 let params2 = StdDevParams {
2214 period: Some(2),
2215 nbdev: Some(nbdev),
2216 };
2217 let input2 = StdDevInput::from_slice(&identical_data, params2);
2218 let StdDevOutput { values: out2 } = stddev_with_kernel(&input2, kernel).unwrap();
2219
2220 for i in 1..out2.len() {
2221 prop_assert!(
2222 out2[i].abs() < 1e-9,
2223 "StdDev for identical pairs should be 0, got {} at index {}",
2224 out2[i],
2225 i
2226 );
2227 }
2228 }
2229
2230 if data.len() >= period * 2 {
2231 let monotonic_data: Vec<f64> = (0..100).map(|i| 100.0 + i as f64 * 10.0).collect();
2232 let mono_params = StdDevParams {
2233 period: Some(period),
2234 nbdev: Some(1.0),
2235 };
2236 let mono_input = StdDevInput::from_slice(&monotonic_data, mono_params);
2237 let StdDevOutput { values: mono_out } =
2238 stddev_with_kernel(&mono_input, kernel).unwrap();
2239
2240 let step_size = 10.0;
2241 let expected_stddev = step_size * ((period * period - 1) as f64 / 12.0).sqrt();
2242
2243 for i in (period - 1)..mono_out.len().min(period * 3) {
2244 let deviation = (mono_out[i] - expected_stddev).abs();
2245 prop_assert!(
2246 deviation < 1.0,
2247 "Monotonic pattern stddev mismatch at index {}: got {}, expected ~{}",
2248 i,
2249 mono_out[i],
2250 expected_stddev
2251 );
2252 }
2253 }
2254
2255 if data.len() >= period * 2 && period >= 4 {
2256 let alternating_data: Vec<f64> = (0..100)
2257 .map(|i| if i % 2 == 0 { 1000.0 } else { 100.0 })
2258 .collect();
2259 let alt_params = StdDevParams {
2260 period: Some(period),
2261 nbdev: Some(1.0),
2262 };
2263 let alt_input = StdDevInput::from_slice(&alternating_data, alt_params);
2264 let StdDevOutput { values: alt_out } =
2265 stddev_with_kernel(&alt_input, kernel).unwrap();
2266
2267 let alt_range = 900.0;
2268 let expected_alt_stddev = alt_range / 2.0;
2269
2270 for i in (period - 1)..alt_out.len().min(period * 3) {
2271 prop_assert!(
2272 alt_out[i] > alt_range * 0.4,
2273 "Alternating pattern should produce high stddev at index {}: got {}",
2274 i,
2275 alt_out[i]
2276 );
2277
2278 let max_possible =
2279 (alt_range / 2.0) * ((period as f64) / ((period - 1) as f64)).sqrt();
2280 prop_assert!(
2281 alt_out[i] <= max_possible + 1e-9,
2282 "Alternating pattern stddev exceeds maximum at index {}: got {}, max {}",
2283 i,
2284 alt_out[i],
2285 max_possible
2286 );
2287 }
2288 }
2289
2290 Ok(())
2291 })?;
2292
2293 Ok(())
2294 }
2295
// Generates the `#[test]` wrappers for each listed check function: one for the
// scalar kernel and, on x86_64 builds with the `nightly-avx` feature, one each
// for the AVX2 and AVX512 kernels.
//
// Each check function is called as `check(test_name, kernel)` and must return
// `Result<(), Box<dyn std::error::Error>>`. The wrapper returns that result to
// the test harness, so an `Err` fails the test instead of being discarded.
//
// The `#[cfg]` gate is written on every AVX wrapper individually. An attribute
// placed in front of a `$( ... )*` repetition attaches only to the first item
// produced by the expansion, leaving all later items ungated.
macro_rules! generate_all_stddev_tests {
    ($($test_fn:ident),* $(,)?) => {
        paste::paste! {
            $(
                #[test]
                fn [<$test_fn _scalar_f64>]() -> Result<(), Box<dyn ::std::error::Error>> {
                    $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar)
                }

                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$test_fn _avx2_f64>]() -> Result<(), Box<dyn ::std::error::Error>> {
                    $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2)
                }

                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$test_fn _avx512_f64>]() -> Result<(), Box<dyn ::std::error::Error>> {
                    $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512)
                }
            )*
        }
    }
}
// Instantiate the per-kernel test wrappers for every single-series check
// function listed below.
generate_all_stddev_tests!(
    check_stddev_empty_input,
    check_stddev_invalid_batch_kernel,
    check_stddev_mismatched_output_len,
    check_stddev_negative_nbdev,
    check_stddev_partial_params,
    check_stddev_accuracy,
    check_stddev_default_candles,
    check_stddev_zero_period,
    check_stddev_period_exceeds_length,
    check_stddev_very_small_dataset,
    check_stddev_reinput,
    check_stddev_nan_handling,
    check_stddev_streaming,
    check_stddev_no_poison
);

// The property-based check is only instantiated when the `proptest` feature
// is enabled.
#[cfg(feature = "proptest")]
generate_all_stddev_tests!(check_stddev_property);
2338
2339 fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2340 skip_if_unsupported!(kernel, test);
2341
2342 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2343 let c = read_candles_from_csv(file)?;
2344
2345 let output = StdDevBatchBuilder::new()
2346 .kernel(kernel)
2347 .apply_candles(&c, "close")?;
2348
2349 let def = StdDevParams::default();
2350 let row = output.values_for(&def).expect("default row missing");
2351
2352 assert_eq!(row.len(), c.close.len());
2353
2354 let expected = [
2355 180.12506767314034,
2356 77.7395652441455,
2357 127.16225857341935,
2358 89.40156600773197,
2359 218.50034325919697,
2360 ];
2361 let start = row.len() - 5;
2362 for (i, &v) in row[start..].iter().enumerate() {
2363 assert!(
2364 (v - expected[i]).abs() < 1e-1,
2365 "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
2366 );
2367 }
2368 Ok(())
2369 }
2370
2371 #[cfg(debug_assertions)]
2372 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2373 skip_if_unsupported!(kernel, test);
2374
2375 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2376 let c = read_candles_from_csv(file)?;
2377
2378 let test_configs = vec![
2379 (2, 10, 2, 1.0, 1.0, 0.0),
2380 (5, 25, 5, 0.5, 2.5, 0.5),
2381 (30, 60, 15, 1.0, 1.0, 0.0),
2382 (2, 5, 1, 1.0, 3.0, 1.0),
2383 (10, 10, 0, 0.5, 3.0, 0.5),
2384 (20, 50, 10, 2.0, 2.0, 0.0),
2385 (100, 100, 0, 1.0, 3.0, 1.0),
2386 ];
2387
2388 for (cfg_idx, &(p_start, p_end, p_step, n_start, n_end, n_step)) in
2389 test_configs.iter().enumerate()
2390 {
2391 let output = StdDevBatchBuilder::new()
2392 .kernel(kernel)
2393 .period_range(p_start, p_end, p_step)
2394 .nbdev_range(n_start, n_end, n_step)
2395 .apply_candles(&c, "close")?;
2396
2397 for (idx, &val) in output.values.iter().enumerate() {
2398 if val.is_nan() {
2399 continue;
2400 }
2401
2402 let bits = val.to_bits();
2403 let row = idx / output.cols;
2404 let col = idx % output.cols;
2405 let combo = &output.combos[row];
2406
2407 if bits == 0x11111111_11111111 {
2408 panic!(
2409 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
2410 at row {} col {} (flat index {}) with params: {:?}",
2411 test, cfg_idx, val, bits, row, col, idx, combo
2412 );
2413 }
2414
2415 if bits == 0x22222222_22222222 {
2416 panic!(
2417 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
2418 at row {} col {} (flat index {}) with params: {:?}",
2419 test, cfg_idx, val, bits, row, col, idx, combo
2420 );
2421 }
2422
2423 if bits == 0x33333333_33333333 {
2424 panic!(
2425 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
2426 at row {} col {} (flat index {}) with params: {:?}",
2427 test, cfg_idx, val, bits, row, col, idx, combo
2428 );
2429 }
2430 }
2431 }
2432
2433 Ok(())
2434 }
2435
2436 #[cfg(not(debug_assertions))]
2437 fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2438 Ok(())
2439 }
2440
// Generates the `#[test]` wrappers for one batch check function: scalar-batch,
// auto-detected, and (on x86_64 builds with the `nightly-avx` feature) the
// AVX2 and AVX512 batch kernels.
//
// The check function is called as `check(test_name, kernel)` and must return
// `Result<(), Box<dyn std::error::Error>>`. The wrapper returns that result to
// the test harness, so an `Err` (for example a failure to load the candle
// file) fails the test instead of being discarded.
macro_rules! gen_batch_tests {
    ($fn_name:ident) => {
        paste::paste! {
            #[test]
            fn [<$fn_name _scalar>]() -> Result<(), Box<dyn ::std::error::Error>> {
                $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch)
            }

            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            #[test]
            fn [<$fn_name _avx2>]() -> Result<(), Box<dyn ::std::error::Error>> {
                $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch)
            }

            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            #[test]
            fn [<$fn_name _avx512>]() -> Result<(), Box<dyn ::std::error::Error>> {
                $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch)
            }

            #[test]
            fn [<$fn_name _auto_detect>]() -> Result<(), Box<dyn ::std::error::Error>> {
                $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto)
            }
        }
    };
}
// Instantiate the per-kernel batch test wrappers for both batch checks.
gen_batch_tests!(check_batch_default_row);
gen_batch_tests!(check_batch_no_poison);
2463}