1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::moving_averages::DeviceArrayF32;
3#[cfg(feature = "cuda")]
4use crate::cuda::{CudaZscore, CudaZscoreError};
5use crate::indicators::deviation::{
6 deviation, DevError, DevInput, DevParams, DeviationData, DeviationOutput,
7};
8use crate::indicators::moving_averages::ma::{ma, MaData};
9use crate::utilities::data_loader::{source_type, Candles};
10#[cfg(all(feature = "python", feature = "cuda"))]
11use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
12use crate::utilities::enums::Kernel;
13use crate::utilities::helpers::{
14 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
15 make_uninit_matrix,
16};
17#[cfg(feature = "python")]
18use crate::utilities::kernel_validation::validate_kernel;
19use aligned_vec::{AVec, CACHELINE_ALIGN};
20#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
21use core::arch::x86_64::*;
22#[cfg(all(feature = "python", feature = "cuda"))]
23use cust::context::Context;
24#[cfg(all(feature = "python", feature = "cuda"))]
25use cust::memory::DeviceBuffer;
26#[cfg(feature = "python")]
27use numpy::{IntoPyArray, PyArray1};
28#[cfg(feature = "python")]
29use pyo3::exceptions::PyValueError;
30#[cfg(feature = "python")]
31use pyo3::prelude::*;
32#[cfg(feature = "python")]
33use pyo3::types::{PyDict, PyList};
34#[cfg(not(target_arch = "wasm32"))]
35use rayon::prelude::*;
36#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
37use serde::{Deserialize, Serialize};
38use std::collections::HashMap;
39use std::error::Error;
40#[cfg(all(feature = "python", feature = "cuda"))]
41use std::sync::Arc;
42use thiserror::Error;
43#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
44use wasm_bindgen::prelude::*;
45
46impl<'a> AsRef<[f64]> for ZscoreInput<'a> {
47 #[inline(always)]
48 fn as_ref(&self) -> &[f64] {
49 match &self.data {
50 ZscoreData::Slice(slice) => slice,
51 ZscoreData::Candles { candles, source } => source_type(candles, source),
52 }
53 }
54}
55
/// Origin of the input series handed to the z-score routines.
#[derive(Debug, Clone)]
pub enum ZscoreData<'a> {
    /// A borrowed candle set together with the name of the field to read;
    /// the field is resolved through `source_type` when the data is accessed.
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A borrowed slice of values that is used as-is.
    Slice(&'a [f64]),
}
64
/// Result of a single z-score computation.
#[derive(Debug, Clone)]
pub struct ZscoreOutput {
    /// One z-score per input element; the warm-up prefix and any position
    /// whose scaled deviation is zero or NaN hold `NaN`.
    pub values: Vec<f64>,
}
69
/// Tunable parameters of the z-score indicator.
///
/// Every field is optional; the accessors on `ZscoreInput` substitute the
/// defaults noted on each field when it is `None`.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct ZscoreParams {
    /// Look-back window length (default 14).
    pub period: Option<usize>,
    /// Name of the moving average used as the mean (default `"sma"`).
    pub ma_type: Option<String>,
    /// Multiplier applied to the deviation before dividing (default 1.0).
    pub nbdev: Option<f64>,
    /// Deviation kind forwarded to `deviation` (default 0). A value of 0
    /// combined with `"sma"` or `"ema"` selects the closed-form paths in this file.
    pub devtype: Option<usize>,
}
81
82impl Default for ZscoreParams {
83 fn default() -> Self {
84 Self {
85 period: Some(14),
86 ma_type: Some("sma".to_string()),
87 nbdev: Some(1.0),
88 devtype: Some(0),
89 }
90 }
91}
92
/// Input bundle for a z-score computation: the series plus its parameters.
#[derive(Debug, Clone)]
pub struct ZscoreInput<'a> {
    /// Where the series comes from (candles or raw slice).
    pub data: ZscoreData<'a>,
    /// Indicator parameters; unset fields fall back to defaults.
    pub params: ZscoreParams,
}
98
99impl<'a> ZscoreInput<'a> {
100 #[inline]
101 pub fn from_candles(candles: &'a Candles, source: &'a str, params: ZscoreParams) -> Self {
102 Self {
103 data: ZscoreData::Candles { candles, source },
104 params,
105 }
106 }
107 #[inline]
108 pub fn from_slice(slice: &'a [f64], params: ZscoreParams) -> Self {
109 Self {
110 data: ZscoreData::Slice(slice),
111 params,
112 }
113 }
114 #[inline]
115 pub fn with_default_candles(candles: &'a Candles) -> Self {
116 Self::from_candles(candles, "close", ZscoreParams::default())
117 }
118 #[inline]
119 pub fn get_period(&self) -> usize {
120 self.params.period.unwrap_or(14)
121 }
122 #[inline]
123 pub fn get_ma_type(&self) -> String {
124 self.params
125 .ma_type
126 .clone()
127 .unwrap_or_else(|| "sma".to_string())
128 }
129 #[inline]
130 pub fn get_nbdev(&self) -> f64 {
131 self.params.nbdev.unwrap_or(1.0)
132 }
133 #[inline]
134 pub fn get_devtype(&self) -> usize {
135 self.params.devtype.unwrap_or(0)
136 }
137}
138
/// Fluent builder that collects z-score parameters and a kernel choice
/// before running the indicator or creating a stream.
#[derive(Clone, Debug)]
pub struct ZscoreBuilder {
    // Each `None` is passed through unchanged into `ZscoreParams`, where the
    // computation substitutes its own default.
    period: Option<usize>,
    ma_type: Option<String>,
    nbdev: Option<f64>,
    devtype: Option<usize>,
    // Kernel handed to `zscore_with_kernel` by `apply` / `apply_slice`.
    kernel: Kernel,
}
147
148impl Default for ZscoreBuilder {
149 fn default() -> Self {
150 Self {
151 period: None,
152 ma_type: None,
153 nbdev: None,
154 devtype: None,
155 kernel: Kernel::Auto,
156 }
157 }
158}
159
160impl ZscoreBuilder {
161 #[inline(always)]
162 pub fn new() -> Self {
163 Self::default()
164 }
165 #[inline(always)]
166 pub fn period(mut self, n: usize) -> Self {
167 self.period = Some(n);
168 self
169 }
170 #[inline(always)]
171 pub fn ma_type<T: Into<String>>(mut self, t: T) -> Self {
172 self.ma_type = Some(t.into());
173 self
174 }
175 #[inline(always)]
176 pub fn nbdev(mut self, x: f64) -> Self {
177 self.nbdev = Some(x);
178 self
179 }
180 #[inline(always)]
181 pub fn devtype(mut self, dt: usize) -> Self {
182 self.devtype = Some(dt);
183 self
184 }
185 #[inline(always)]
186 pub fn kernel(mut self, k: Kernel) -> Self {
187 self.kernel = k;
188 self
189 }
190 #[inline(always)]
191 pub fn apply(self, candles: &Candles) -> Result<ZscoreOutput, ZscoreError> {
192 let p = ZscoreParams {
193 period: self.period,
194 ma_type: self.ma_type,
195 nbdev: self.nbdev,
196 devtype: self.devtype,
197 };
198 let i = ZscoreInput::from_candles(candles, "close", p);
199 zscore_with_kernel(&i, self.kernel)
200 }
201 #[inline(always)]
202 pub fn apply_slice(self, data: &[f64]) -> Result<ZscoreOutput, ZscoreError> {
203 let p = ZscoreParams {
204 period: self.period,
205 ma_type: self.ma_type,
206 nbdev: self.nbdev,
207 devtype: self.devtype,
208 };
209 let i = ZscoreInput::from_slice(data, p);
210 zscore_with_kernel(&i, self.kernel)
211 }
212 #[inline(always)]
213 pub fn into_stream(self) -> Result<ZscoreStream, ZscoreError> {
214 let p = ZscoreParams {
215 period: self.period,
216 ma_type: self.ma_type,
217 nbdev: self.nbdev,
218 devtype: self.devtype,
219 };
220 ZscoreStream::try_new(p)
221 }
222}
223
/// Errors reported by the z-score routines in this file.
#[derive(Debug, Error)]
pub enum ZscoreError {
    /// The input series has no elements.
    #[error("zscore: Input data slice is empty.")]
    EmptyInputData,
    /// No non-NaN element was found in the input.
    #[error("zscore: All values are NaN.")]
    AllValuesNaN,
    /// The period is zero or larger than the input length.
    #[error("zscore: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },
    /// Fewer than `needed` elements remain after the first non-NaN value.
    #[error("zscore: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// A buffer did not have the expected number of elements.
    #[error("zscore: output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// A batch sweep range could not be expanded, or a size computation overflowed.
    #[error("zscore: invalid range: start={start}, end={end}, step={step}")]
    InvalidRange { start: f64, end: f64, step: f64 },
    /// A non-batch kernel was passed to a batch entry point.
    #[error("zscore: invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(Kernel),
    /// Error propagated from the `deviation` indicator.
    #[error("zscore: DevError {0}")]
    DevError(#[from] DevError),
    /// Error from the moving-average dispatcher, stored as its message text.
    #[error("zscore: MaError {0}")]
    MaError(String),
}
245
246#[inline]
247pub fn zscore(input: &ZscoreInput) -> Result<ZscoreOutput, ZscoreError> {
248 zscore_with_kernel(input, Kernel::Auto)
249}
250
251pub fn zscore_with_kernel(
252 input: &ZscoreInput,
253 kernel: Kernel,
254) -> Result<ZscoreOutput, ZscoreError> {
255 let data: &[f64] = input.as_ref();
256 if data.is_empty() {
257 return Err(ZscoreError::EmptyInputData);
258 }
259
260 let first = data
261 .iter()
262 .position(|x| !x.is_nan())
263 .ok_or(ZscoreError::AllValuesNaN)?;
264 let len = data.len();
265 let period = input.get_period();
266 if period == 0 || period > len {
267 return Err(ZscoreError::InvalidPeriod {
268 period,
269 data_len: len,
270 });
271 }
272 if (len - first) < period {
273 return Err(ZscoreError::NotEnoughValidData {
274 needed: period,
275 valid: len - first,
276 });
277 }
278
279 let ma_type = input.get_ma_type();
280 let nbdev = input.get_nbdev();
281 let devtype = input.get_devtype();
282
283 let chosen = match kernel {
284 Kernel::Auto => match detect_best_kernel() {
285 Kernel::Avx512 | Kernel::Avx512Batch => Kernel::Avx2,
286 other => other,
287 },
288 other => other,
289 };
290
291 unsafe {
292 match chosen {
293 Kernel::Scalar | Kernel::ScalarBatch => {
294 zscore_scalar(data, period, first, &ma_type, nbdev, devtype)
295 }
296 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
297 Kernel::Avx2 | Kernel::Avx2Batch => {
298 zscore_avx2(data, period, first, &ma_type, nbdev, devtype)
299 }
300 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
301 Kernel::Avx512 | Kernel::Avx512Batch => {
302 zscore_avx512(data, period, first, &ma_type, nbdev, devtype)
303 }
304 _ => unreachable!(),
305 }
306 }
307}
308
309#[inline]
310pub unsafe fn zscore_scalar(
311 data: &[f64],
312 period: usize,
313 first: usize,
314 ma_type: &str,
315 nbdev: f64,
316 devtype: usize,
317) -> Result<ZscoreOutput, ZscoreError> {
318 if devtype == 0 {
319 if ma_type == "sma" {
320 return zscore_scalar_classic_sma(data, period, first, nbdev);
321 } else if ma_type == "ema" {
322 return zscore_scalar_classic_ema(data, period, first, nbdev);
323 }
324 }
325
326 let means = ma(ma_type, MaData::Slice(data), period)
327 .map_err(|e| ZscoreError::MaError(e.to_string()))?;
328 let dev_input = DevInput {
329 data: DeviationData::Slice(data),
330 params: DevParams {
331 period: Some(period),
332 devtype: Some(devtype),
333 },
334 };
335 let mut sigmas = deviation(&dev_input)?.values;
336 for v in &mut sigmas {
337 *v *= nbdev;
338 }
339
340 let warmup_end = first + period - 1;
341 let n = data.len();
342 let mut out = alloc_with_nan_prefix(n, warmup_end);
343 let mut i = warmup_end;
344 while i < n {
345 let mean = *means.get_unchecked(i);
346 let sigma = *sigmas.get_unchecked(i);
347 let value = *data.get_unchecked(i);
348 *out.get_unchecked_mut(i) = if sigma == 0.0 || sigma.is_nan() {
349 f64::NAN
350 } else {
351 (value - mean) / sigma
352 };
353 i += 1;
354 }
355 Ok(ZscoreOutput { values: out })
356}
357
358#[inline]
359pub unsafe fn zscore_scalar_classic_sma(
360 data: &[f64],
361 period: usize,
362 first: usize,
363 nbdev: f64,
364) -> Result<ZscoreOutput, ZscoreError> {
365 let warmup_end = first + period - 1;
366 let mut out = alloc_with_nan_prefix(data.len(), warmup_end);
367
368 let inv = 1.0 / (period as f64);
369 let mut sum = 0.0f64;
370 let mut sum_sqr = 0.0f64;
371 {
372 let mut j = first;
373 while j <= warmup_end {
374 let v = *data.get_unchecked(j);
375 sum += v;
376
377 sum_sqr = v.mul_add(v, sum_sqr);
378 j += 1;
379 }
380 }
381 let mut mean = sum * inv;
382
383 let mut variance = (-mean).mul_add(mean, sum_sqr * inv);
384 if variance < 0.0 {
385 variance = 0.0;
386 }
387 let mut stddev = if variance == 0.0 {
388 0.0
389 } else {
390 variance.sqrt() * nbdev
391 };
392
393 let xw = *data.get_unchecked(warmup_end);
394 *out.get_unchecked_mut(warmup_end) = if stddev == 0.0 || stddev.is_nan() {
395 f64::NAN
396 } else {
397 (xw - mean) / stddev
398 };
399
400 let n = data.len();
401 let mut i = warmup_end + 1;
402 while i < n {
403 let old_val = *data.get_unchecked(i - period);
404 let new_val = *data.get_unchecked(i);
405 let dd = new_val - old_val;
406 sum += dd;
407
408 sum_sqr = (new_val + old_val).mul_add(dd, sum_sqr);
409 mean = sum * inv;
410
411 variance = (-mean).mul_add(mean, sum_sqr * inv);
412 if variance < 0.0 {
413 variance = 0.0;
414 }
415 stddev = if variance == 0.0 {
416 0.0
417 } else {
418 variance.sqrt() * nbdev
419 };
420
421 *out.get_unchecked_mut(i) = if stddev == 0.0 || stddev.is_nan() {
422 f64::NAN
423 } else {
424 (new_val - mean) / stddev
425 };
426 i += 1;
427 }
428
429 Ok(ZscoreOutput { values: out })
430}
431
432#[inline]
433pub unsafe fn zscore_scalar_classic_ema(
434 data: &[f64],
435 period: usize,
436 first: usize,
437 nbdev: f64,
438) -> Result<ZscoreOutput, ZscoreError> {
439 let n = data.len();
440 let warmup_end = first + period - 1;
441 let mut out = alloc_with_nan_prefix(n, warmup_end);
442
443 if n <= warmup_end {
444 return Ok(ZscoreOutput { values: out });
445 }
446
447 let den = period as f64;
448 let inv = 1.0 / den;
449 let alpha = 2.0 / (den + 1.0);
450 let one_minus_alpha = 1.0 - alpha;
451
452 let mut sum = 0.0f64;
453 let mut sum2 = 0.0f64;
454 {
455 let mut j = first;
456 while j <= warmup_end {
457 let v = *data.get_unchecked(j);
458 sum += v;
459 sum2 = v.mul_add(v, sum2);
460 j += 1;
461 }
462 }
463 let mut ema = sum * inv;
464
465 let mut ex = sum * inv;
466 let mut ex2 = sum2 * inv;
467 let mut mse = (-2.0 * ema).mul_add(ex, ema.mul_add(ema, ex2));
468 if mse < 0.0 {
469 mse = 0.0;
470 }
471 let mut sd = mse.sqrt() * nbdev;
472
473 let xw = *data.get_unchecked(warmup_end);
474 *out.get_unchecked_mut(warmup_end) = if sd == 0.0 || sd.is_nan() {
475 f64::NAN
476 } else {
477 (xw - ema) / sd
478 };
479
480 let mut i = warmup_end + 1;
481 while i < n {
482 let new = *data.get_unchecked(i);
483 let old = *data.get_unchecked(i - period);
484
485 let dd = new - old;
486 sum += dd;
487 sum2 = (new + old).mul_add(dd, sum2);
488 ex = sum * inv;
489 ex2 = sum2 * inv;
490
491 ema = ema.mul_add(one_minus_alpha, alpha * new);
492
493 mse = (-2.0 * ema).mul_add(ex, ema.mul_add(ema, ex2));
494 if mse < 0.0 {
495 mse = 0.0;
496 }
497 sd = mse.sqrt() * nbdev;
498
499 *out.get_unchecked_mut(i) = if sd == 0.0 || sd.is_nan() {
500 f64::NAN
501 } else {
502 (new - ema) / sd
503 };
504 i += 1;
505 }
506
507 Ok(ZscoreOutput { values: out })
508}
509
/// AVX2 entry point.
///
/// Currently forwards every call to [`zscore_scalar`] unchanged.
///
/// # Safety
/// Same contract as [`zscore_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn zscore_avx2(
    data: &[f64],
    period: usize,
    first: usize,
    ma_type: &str,
    nbdev: f64,
    devtype: usize,
) -> Result<ZscoreOutput, ZscoreError> {
    // No dedicated AVX2 code path: delegate to the scalar implementation.
    zscore_scalar(data, period, first, ma_type, nbdev, devtype)
}
522
/// AVX-512 entry point; picks the short- or long-period variant.
///
/// Periods up to and including 32 go to [`zscore_avx512_short`], larger ones
/// to [`zscore_avx512_long`].
///
/// # Safety
/// Same contract as [`zscore_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn zscore_avx512(
    data: &[f64],
    period: usize,
    first: usize,
    ma_type: &str,
    nbdev: f64,
    devtype: usize,
) -> Result<ZscoreOutput, ZscoreError> {
    match period {
        p if p <= 32 => zscore_avx512_short(data, p, first, ma_type, nbdev, devtype),
        p => zscore_avx512_long(data, p, first, ma_type, nbdev, devtype),
    }
}
539
/// AVX-512 variant for short periods.
///
/// Currently forwards every call to [`zscore_scalar`] unchanged.
///
/// # Safety
/// Same contract as [`zscore_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn zscore_avx512_short(
    data: &[f64],
    period: usize,
    first: usize,
    ma_type: &str,
    nbdev: f64,
    devtype: usize,
) -> Result<ZscoreOutput, ZscoreError> {
    // No dedicated short-period code path: delegate to the scalar implementation.
    zscore_scalar(data, period, first, ma_type, nbdev, devtype)
}
552
/// AVX-512 variant for long periods.
///
/// Currently forwards every call to [`zscore_scalar`] unchanged.
///
/// # Safety
/// Same contract as [`zscore_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn zscore_avx512_long(
    data: &[f64],
    period: usize,
    first: usize,
    ma_type: &str,
    nbdev: f64,
    devtype: usize,
) -> Result<ZscoreOutput, ZscoreError> {
    // No dedicated long-period code path: delegate to the scalar implementation.
    zscore_scalar(data, period, first, ma_type, nbdev, devtype)
}
565
/// Incremental z-score calculator fed one value at a time through `update`.
#[derive(Debug, Clone)]
pub struct ZscoreStream {
    // Resolved parameters (defaults already applied in `try_new`).
    period: usize,
    ma_type: String,
    nbdev: f64,
    devtype: usize,

    // Ring buffer of the last `period` inputs; starts filled with NaN.
    buffer: Vec<f64>,
    // Index of the slot the next `update` overwrites.
    head: usize,
    // Becomes true once `head` has wrapped to 0 for the first time.
    filled: bool,

    // Running sum and sum of squares of the buffered values (NaN counts as 0).
    sum: f64,
    sum2: f64,

    // Running linearly weighted sum used for the WMA mean.
    wsum: f64,

    // EMA state; seeded with the window mean on the first fast-path update.
    ema: f64,
    ema_inited: bool,

    // Number of NaN entries currently in `buffer`.
    nan_count: usize,

    // Precomputed constants: 1/period, period*(period+1)/2 and its reciprocal,
    // and 1/nbdev (infinity when nbdev is 0).
    inv_period: f64,
    wma_denom: f64,
    inv_wma_denom: f64,
    inv_nbdev: f64,

    // Moving-average family parsed from `ma_type`.
    kind: MaKind,
}
594
/// Moving-average families that have a dedicated fast path in this file.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
enum MaKind {
    /// Simple moving average.
    Sma,
    /// Exponential moving average.
    Ema,
    /// Linearly weighted moving average.
    Wma,
    /// Any other name; handled through the generic `ma` dispatcher.
    Other,
}
602
603impl ZscoreStream {
604 pub fn try_new(params: ZscoreParams) -> Result<Self, ZscoreError> {
605 let period = params.period.unwrap_or(14);
606 if period == 0 {
607 return Err(ZscoreError::InvalidPeriod {
608 period,
609 data_len: 0,
610 });
611 }
612 let ma_type = params.ma_type.unwrap_or_else(|| "sma".to_string());
613 let nbdev = params.nbdev.unwrap_or(1.0);
614 let devtype = params.devtype.unwrap_or(0);
615
616 let kind = match ma_type.to_ascii_lowercase().as_str() {
617 "sma" => MaKind::Sma,
618 "ema" => MaKind::Ema,
619 "wma" => MaKind::Wma,
620 _ => MaKind::Other,
621 };
622
623 let n = period as f64;
624 let wden = n * (n + 1.0) * 0.5;
625 let inv_nbdev = if nbdev != 0.0 {
626 1.0 / nbdev
627 } else {
628 f64::INFINITY
629 };
630
631 Ok(Self {
632 period,
633 ma_type,
634 nbdev,
635 devtype,
636 buffer: vec![f64::NAN; period],
637 head: 0,
638 filled: false,
639
640 sum: 0.0,
641 sum2: 0.0,
642 wsum: 0.0,
643
644 ema: 0.0,
645 ema_inited: false,
646
647 nan_count: period,
648
649 inv_period: 1.0 / n,
650 wma_denom: wden,
651 inv_wma_denom: 1.0 / wden,
652 inv_nbdev,
653 kind,
654 })
655 }
656
657 #[inline(always)]
658 pub fn update(&mut self, value: f64) -> Option<f64> {
659 let old = self.buffer[self.head];
660 self.buffer[self.head] = value;
661 self.head = (self.head + 1) % self.period;
662 let just_filled = !self.filled && self.head == 0;
663 if just_filled {
664 self.filled = true;
665 }
666
667 if old.is_nan() {
668 self.nan_count = self.nan_count.saturating_sub(1);
669 }
670 if value.is_nan() {
671 self.nan_count += 1;
672 }
673
674 let sum_prev = self.sum;
675
676 let old_c = if old.is_nan() { 0.0 } else { old };
677 let new_c = if value.is_nan() { 0.0 } else { value };
678
679 self.sum = self.sum + new_c - old_c;
680 self.sum2 = self.sum2 + new_c * new_c - old_c * old_c;
681
682 self.wsum = self.wsum - sum_prev + (self.period as f64) * new_c;
683
684 if !self.filled {
685 return None;
686 }
687
688 if self.devtype == 0 && self.nbdev != 0.0 && self.nan_count == 0 {
689 let mean = match self.kind {
690 MaKind::Sma => self.sum * self.inv_period,
691 MaKind::Ema => {
692 if !self.ema_inited {
693 self.ema = self.sum * self.inv_period;
694 self.ema_inited = true;
695 self.ema
696 } else {
697 let alpha = 2.0 / ((self.period as f64) + 1.0);
698 self.ema = self.ema.mul_add(1.0 - alpha, alpha * new_c);
699 self.ema
700 }
701 }
702 MaKind::Wma => self.wsum * self.inv_wma_denom,
703 MaKind::Other => {
704 return Some(self.compute_zscore_slow());
705 }
706 };
707
708 let ex = self.sum * self.inv_period;
709 let ex2 = self.sum2 * self.inv_period;
710
711 let var = if self.kind == MaKind::Sma {
712 let v = ex2 - mean * mean;
713 if v < 0.0 {
714 0.0
715 } else {
716 v
717 }
718 } else {
719 let v = ex2 - 2.0 * mean * ex + mean * mean;
720 if v < 0.0 {
721 0.0
722 } else {
723 v
724 }
725 };
726
727 let sd = var.sqrt();
728 if sd == 0.0 || !sd.is_finite() {
729 return Some(f64::NAN);
730 }
731
732 let last_idx = if self.head == 0 {
733 self.period - 1
734 } else {
735 self.head - 1
736 };
737 let last = self.buffer[last_idx];
738 let z = (last - mean) / (sd * self.nbdev);
739 return Some(z);
740 }
741
742 Some(self.compute_zscore_slow())
743 }
744
745 #[inline(always)]
746 fn compute_zscore_slow(&self) -> f64 {
747 let mut ordered = vec![0.0; self.period];
748 let mut idx = self.head;
749 for i in 0..self.period {
750 ordered[i] = self.buffer[idx];
751 idx = (idx + 1) % self.period;
752 }
753
754 let means = match ma(&self.ma_type, MaData::Slice(&ordered), self.period) {
755 Ok(m) => m,
756 Err(_) => return f64::NAN,
757 };
758
759 let dev_input = DevInput {
760 data: DeviationData::Slice(&ordered),
761 params: DevParams {
762 period: Some(self.period),
763 devtype: Some(self.devtype),
764 },
765 };
766 let mut sigmas = match deviation(&dev_input) {
767 Ok(d) => d.values,
768 Err(_) => return f64::NAN,
769 };
770 for s in &mut sigmas {
771 *s *= self.nbdev;
772 }
773
774 let mean = means[self.period - 1];
775 let sigma = sigmas[self.period - 1];
776 let value = ordered[self.period - 1];
777 if sigma == 0.0 || sigma.is_nan() {
778 f64::NAN
779 } else {
780 (value - mean) / sigma
781 }
782 }
783}
784
/// Parameter sweep for batch computation.
///
/// Each numeric axis is a `(start, end, step)` triple that `expand_grid`
/// turns into a list of values; a zero step or `start == end` yields just
/// `start`.
#[derive(Clone, Debug)]
pub struct ZscoreBatchRange {
    /// Window lengths to sweep.
    pub period: (usize, usize, usize),
    /// Moving-average name; only the first element is used by `expand_grid`.
    pub ma_type: (String, String, String),
    /// Deviation multipliers to sweep.
    pub nbdev: (f64, f64, f64),
    /// Deviation types to sweep.
    pub devtype: (usize, usize, usize),
}
792
793impl Default for ZscoreBatchRange {
794 fn default() -> Self {
795 Self {
796 period: (14, 263, 1),
797 ma_type: ("sma".to_string(), "sma".to_string(), "".to_string()),
798 nbdev: (1.0, 1.0, 0.0),
799 devtype: (0, 0, 0),
800 }
801 }
802}
803
/// Fluent builder for batch z-score runs over a parameter sweep.
#[derive(Clone, Debug, Default)]
pub struct ZscoreBatchBuilder {
    // Sweep definition handed to `zscore_batch_with_kernel`.
    range: ZscoreBatchRange,
    // Kernel handed to `zscore_batch_with_kernel`; starts as `Kernel::default()`.
    kernel: Kernel,
}
809
810impl ZscoreBatchBuilder {
811 pub fn new() -> Self {
812 Self::default()
813 }
814 pub fn kernel(mut self, k: Kernel) -> Self {
815 self.kernel = k;
816 self
817 }
818 #[inline]
819 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
820 self.range.period = (start, end, step);
821 self
822 }
823 #[inline]
824 pub fn period_static(mut self, p: usize) -> Self {
825 self.range.period = (p, p, 0);
826 self
827 }
828 #[inline]
829 pub fn ma_type_static<T: Into<String>>(mut self, s: T) -> Self {
830 let val = s.into();
831 self.range.ma_type = (val.clone(), val.clone(), "".to_string());
832 self
833 }
834 #[inline]
835 pub fn nbdev_range(mut self, start: f64, end: f64, step: f64) -> Self {
836 self.range.nbdev = (start, end, step);
837 self
838 }
839 #[inline]
840 pub fn nbdev_static(mut self, x: f64) -> Self {
841 self.range.nbdev = (x, x, 0.0);
842 self
843 }
844 #[inline]
845 pub fn devtype_range(mut self, start: usize, end: usize, step: usize) -> Self {
846 self.range.devtype = (start, end, step);
847 self
848 }
849 #[inline]
850 pub fn devtype_static(mut self, x: usize) -> Self {
851 self.range.devtype = (x, x, 0);
852 self
853 }
854 pub fn apply_slice(self, data: &[f64]) -> Result<ZscoreBatchOutput, ZscoreError> {
855 zscore_batch_with_kernel(data, &self.range, self.kernel)
856 }
857 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<ZscoreBatchOutput, ZscoreError> {
858 ZscoreBatchBuilder::new().kernel(k).apply_slice(data)
859 }
860 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<ZscoreBatchOutput, ZscoreError> {
861 let slice = source_type(c, src);
862 self.apply_slice(slice)
863 }
864 pub fn with_default_candles(c: &Candles) -> Result<ZscoreBatchOutput, ZscoreError> {
865 ZscoreBatchBuilder::new()
866 .kernel(Kernel::Auto)
867 .apply_candles(c, "close")
868 }
869}
870
871pub fn zscore_batch_with_kernel(
872 data: &[f64],
873 sweep: &ZscoreBatchRange,
874 k: Kernel,
875) -> Result<ZscoreBatchOutput, ZscoreError> {
876 let kernel = match k {
877 Kernel::Auto => detect_best_batch_kernel(),
878 other if other.is_batch() => other,
879 other => return Err(ZscoreError::InvalidKernelForBatch(other)),
880 };
881
882 let simd = match kernel {
883 Kernel::Avx512Batch => Kernel::Avx512,
884 Kernel::Avx2Batch => Kernel::Avx2,
885 Kernel::ScalarBatch => Kernel::Scalar,
886 _ => unreachable!(),
887 };
888 zscore_batch_par_slice(data, sweep, simd)
889}
890
/// Result of a batch sweep: one output row per parameter combination.
#[derive(Clone, Debug)]
pub struct ZscoreBatchOutput {
    /// Row-major matrix of z-scores (`rows * cols` elements).
    pub values: Vec<f64>,
    /// Parameter combination of each row, in row order.
    pub combos: Vec<ZscoreParams>,
    /// Number of parameter combinations.
    pub rows: usize,
    /// Length of the input series.
    pub cols: usize,
}
898
899impl ZscoreBatchOutput {
900 pub fn row_for_params(&self, p: &ZscoreParams) -> Option<usize> {
901 self.combos.iter().position(|c| {
902 c.period.unwrap_or(14) == p.period.unwrap_or(14)
903 && c.ma_type.as_ref().unwrap_or(&"sma".to_string())
904 == p.ma_type.as_ref().unwrap_or(&"sma".to_string())
905 && (c.nbdev.unwrap_or(1.0) - p.nbdev.unwrap_or(1.0)).abs() < 1e-12
906 && c.devtype.unwrap_or(0) == p.devtype.unwrap_or(0)
907 })
908 }
909 pub fn values_for(&self, p: &ZscoreParams) -> Option<&[f64]> {
910 self.row_for_params(p).map(|row| {
911 let start = row * self.cols;
912 &self.values[start..start + self.cols]
913 })
914 }
915}
916
917#[inline(always)]
918fn expand_grid(r: &ZscoreBatchRange) -> Result<Vec<ZscoreParams>, ZscoreError> {
919 fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, ZscoreError> {
920 if step == 0 || start == end {
921 return Ok(vec![start]);
922 }
923 let mut vals = Vec::new();
924 if start < end {
925 let mut v = start;
926 while v <= end {
927 vals.push(v);
928 match v.checked_add(step) {
929 Some(next) => {
930 if next == v {
931 break;
932 }
933 v = next;
934 }
935 None => break,
936 }
937 }
938 } else {
939 let mut v = start;
940 loop {
941 vals.push(v);
942 if v == end {
943 break;
944 }
945 let next = v.saturating_sub(step);
946 if next == v {
947 break;
948 }
949 v = next;
950 if v < end {
951 break;
952 }
953 }
954 }
955 if vals.is_empty() {
956 return Err(ZscoreError::InvalidRange {
957 start: start as f64,
958 end: end as f64,
959 step: step as f64,
960 });
961 }
962 Ok(vals)
963 }
964 fn axis_f64((start, end, step): (f64, f64, f64)) -> Result<Vec<f64>, ZscoreError> {
965 if !step.is_finite() {
966 return Err(ZscoreError::InvalidRange { start, end, step });
967 }
968 if step.abs() < 1e-12 || (start - end).abs() < 1e-12 {
969 return Ok(vec![start]);
970 }
971 let mut vals = Vec::new();
972 let tol = 1e-12;
973 if start <= end {
974 let step_pos = step.abs();
975 let mut x = start;
976 while x <= end + tol {
977 vals.push(x);
978 x += step_pos;
979 if !x.is_finite() {
980 break;
981 }
982 }
983 } else {
984 let step_neg = -step.abs();
985 let mut x = start;
986 while x >= end - tol {
987 vals.push(x);
988 x += step_neg;
989 if !x.is_finite() {
990 break;
991 }
992 }
993 }
994 if vals.is_empty() {
995 return Err(ZscoreError::InvalidRange { start, end, step });
996 }
997 Ok(vals)
998 }
999 fn axis_string((start, end, _step): (String, String, String)) -> Vec<String> {
1000 if start == end {
1001 vec![start]
1002 } else {
1003 vec![start]
1004 }
1005 }
1006
1007 let periods = axis_usize(r.period)?;
1008 let ma_types = axis_string(r.ma_type.clone());
1009 let nbdevs = axis_f64(r.nbdev)?;
1010 let devtypes = axis_usize(r.devtype)?;
1011
1012 let total = periods
1013 .len()
1014 .checked_mul(ma_types.len())
1015 .and_then(|v| v.checked_mul(nbdevs.len()))
1016 .and_then(|v| v.checked_mul(devtypes.len()))
1017 .ok_or(ZscoreError::InvalidRange {
1018 start: periods.len() as f64,
1019 end: devtypes.len() as f64,
1020 step: nbdevs.len() as f64,
1021 })?;
1022
1023 let mut out = Vec::with_capacity(total);
1024 for &p in &periods {
1025 for mt in &ma_types {
1026 for &n in &nbdevs {
1027 for &dt in &devtypes {
1028 out.push(ZscoreParams {
1029 period: Some(p),
1030 ma_type: Some(mt.clone()),
1031 nbdev: Some(n),
1032 devtype: Some(dt),
1033 });
1034 }
1035 }
1036 }
1037 }
1038 Ok(out)
1039}
1040
1041#[inline(always)]
1042pub fn zscore_batch_slice(
1043 data: &[f64],
1044 sweep: &ZscoreBatchRange,
1045 kern: Kernel,
1046) -> Result<ZscoreBatchOutput, ZscoreError> {
1047 zscore_batch_inner(data, sweep, kern, false)
1048}
1049
1050#[inline(always)]
1051pub fn zscore_batch_par_slice(
1052 data: &[f64],
1053 sweep: &ZscoreBatchRange,
1054 kern: Kernel,
1055) -> Result<ZscoreBatchOutput, ZscoreError> {
1056 zscore_batch_inner(data, sweep, kern, true)
1057}
1058
1059#[inline(always)]
1060fn zscore_batch_inner(
1061 data: &[f64],
1062 sweep: &ZscoreBatchRange,
1063 kern: Kernel,
1064 parallel: bool,
1065) -> Result<ZscoreBatchOutput, ZscoreError> {
1066 if data.is_empty() {
1067 return Err(ZscoreError::EmptyInputData);
1068 }
1069
1070 let combos = expand_grid(sweep)?;
1071
1072 let first = data
1073 .iter()
1074 .position(|x| !x.is_nan())
1075 .ok_or(ZscoreError::AllValuesNaN)?;
1076 let cols = data.len();
1077 let mut max_p = 0usize;
1078 for prm in &combos {
1079 let period = prm.period.unwrap();
1080 if period == 0 || period > cols {
1081 return Err(ZscoreError::InvalidPeriod {
1082 period,
1083 data_len: cols,
1084 });
1085 }
1086 if period > max_p {
1087 max_p = period;
1088 }
1089 }
1090 if cols - first < max_p {
1091 return Err(ZscoreError::NotEnoughValidData {
1092 needed: max_p,
1093 valid: cols - first,
1094 });
1095 }
1096
1097 let rows = combos.len();
1098
1099 rows.checked_mul(cols).ok_or(ZscoreError::InvalidRange {
1100 start: rows as f64,
1101 end: cols as f64,
1102 step: 0.0,
1103 })?;
1104
1105 let mut buf_mu = make_uninit_matrix(rows, cols);
1106
1107 let warmup_periods: Vec<usize> = combos
1108 .iter()
1109 .map(|c| {
1110 let p = c.period.unwrap();
1111 first
1112 .checked_add(p)
1113 .and_then(|v| v.checked_sub(1))
1114 .ok_or(ZscoreError::InvalidRange {
1115 start: first as f64,
1116 end: p as f64,
1117 step: 1.0,
1118 })
1119 })
1120 .collect::<Result<_, _>>()?;
1121
1122 init_matrix_prefixes(&mut buf_mu, cols, &warmup_periods);
1123
1124 let mut buf_guard = core::mem::ManuallyDrop::new(buf_mu);
1125 let out: &mut [f64] = unsafe {
1126 core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
1127 };
1128
1129 let mut groups: HashMap<(usize, MaKind), Vec<(usize, f64)>> = HashMap::new();
1130 let mut fallback_rows: Vec<usize> = Vec::new();
1131 for (row_idx, prm) in combos.iter().enumerate() {
1132 let period = prm.period.unwrap();
1133 let ma_type = prm.ma_type.as_ref().unwrap();
1134 let devtype = prm.devtype.unwrap();
1135 match zscore_fast_batch_kind(ma_type, devtype) {
1136 Some(kind) => {
1137 groups
1138 .entry((period, kind))
1139 .or_default()
1140 .push((row_idx, prm.nbdev.unwrap()));
1141 }
1142 None => fallback_rows.push(row_idx),
1143 }
1144 }
1145
1146 let prefixes = if groups.keys().any(|(_, kind)| matches!(kind, MaKind::Sma)) {
1147 Some(build_sma_std_prefixes(data))
1148 } else {
1149 None
1150 };
1151
1152 let writer = RowWriter {
1153 ptr: out.as_mut_ptr(),
1154 cols,
1155 };
1156
1157 for ((period, kind), rows_for_period) in groups.into_iter() {
1158 let warmup_end = first + period - 1;
1159 let mut base = vec![f64::NAN; cols];
1160
1161 match kind {
1162 MaKind::Sma => match kern {
1163 Kernel::Scalar => {
1164 let pre = prefixes.as_ref().expect("prefixes missing for scalar path");
1165 zscore_sma_std_from_prefix_scalar(data, period, warmup_end, pre, &mut base);
1166 }
1167 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1168 Kernel::Avx2 => unsafe {
1169 let pre = prefixes.as_ref().expect("prefixes missing for AVX2 path");
1170 zscore_sma_std_from_prefix_avx2(data, period, warmup_end, pre, &mut base);
1171 },
1172 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1173 Kernel::Avx512 => unsafe {
1174 let pre = prefixes.as_ref().expect("prefixes missing for AVX512 path");
1175 zscore_sma_std_from_prefix_avx512(data, period, warmup_end, pre, &mut base);
1176 },
1177 _ => unreachable!(),
1178 },
1179 MaKind::Ema => unsafe {
1180 zscore_row_scalar_classic_ema(data, first, period, 1.0, &mut base);
1181 },
1182 _ => unreachable!(),
1183 }
1184
1185 let base_ref = &base;
1186
1187 let write_scalar = |row_idx: usize, nbdev: f64| unsafe {
1188 writer.with_row(row_idx, |dst| {
1189 scale_copy_row_scalar(base_ref, warmup_end, nbdev, dst);
1190 });
1191 };
1192
1193 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1194 let write_avx2 = |row_idx: usize, nbdev: f64| unsafe {
1195 writer.with_row(row_idx, |dst| {
1196 scale_copy_row_avx2(base_ref, warmup_end, nbdev, dst);
1197 });
1198 };
1199
1200 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1201 let write_avx512 = |row_idx: usize, nbdev: f64| unsafe {
1202 writer.with_row(row_idx, |dst| {
1203 scale_copy_row_avx512(base_ref, warmup_end, nbdev, dst);
1204 });
1205 };
1206
1207 let dispatch_write = |row_idx: usize, nbdev: f64| match kern {
1208 Kernel::Scalar => write_scalar(row_idx, nbdev),
1209 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1210 Kernel::Avx2 => write_avx2(row_idx, nbdev),
1211 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1212 Kernel::Avx512 => write_avx512(row_idx, nbdev),
1213 _ => unreachable!(),
1214 };
1215
1216 if parallel {
1217 #[cfg(not(target_arch = "wasm32"))]
1218 {
1219 use rayon::prelude::*;
1220 rows_for_period
1221 .par_iter()
1222 .for_each(|(row_idx, nb)| dispatch_write(*row_idx, *nb));
1223 }
1224 #[cfg(target_arch = "wasm32")]
1225 {
1226 for (row_idx, nb) in rows_for_period.iter() {
1227 dispatch_write(*row_idx, *nb);
1228 }
1229 }
1230 } else {
1231 for (row_idx, nb) in rows_for_period.iter() {
1232 dispatch_write(*row_idx, *nb);
1233 }
1234 }
1235 }
1236
1237 if !fallback_rows.is_empty() {
1238 let do_row = |row: usize| unsafe {
1239 let prm = &combos[row];
1240 let period = prm.period.unwrap();
1241 let ma_type = prm.ma_type.as_ref().unwrap();
1242 let nbdev = prm.nbdev.unwrap();
1243 let devtype = prm.devtype.unwrap();
1244 writer.with_row(row, |dst| match kern {
1245 Kernel::Scalar => {
1246 zscore_row_scalar(data, first, period, ma_type, nbdev, devtype, dst)
1247 }
1248 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1249 Kernel::Avx2 => zscore_row_avx2(data, first, period, ma_type, nbdev, devtype, dst),
1250 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1251 Kernel::Avx512 => {
1252 zscore_row_avx512(data, first, period, ma_type, nbdev, devtype, dst)
1253 }
1254 _ => unreachable!(),
1255 });
1256 };
1257
1258 if parallel {
1259 #[cfg(not(target_arch = "wasm32"))]
1260 {
1261 use rayon::prelude::*;
1262 fallback_rows.par_iter().for_each(|&row| do_row(row));
1263 }
1264 #[cfg(target_arch = "wasm32")]
1265 {
1266 for &row in &fallback_rows {
1267 do_row(row);
1268 }
1269 }
1270 } else {
1271 for &row in &fallback_rows {
1272 do_row(row);
1273 }
1274 }
1275 }
1276
1277 let values = unsafe {
1278 Vec::from_raw_parts(
1279 buf_guard.as_mut_ptr() as *mut f64,
1280 buf_guard.len(),
1281 buf_guard.capacity(),
1282 )
1283 };
1284
1285 Ok(ZscoreBatchOutput {
1286 values,
1287 combos,
1288 rows,
1289 cols,
1290 })
1291}
1292
1293#[inline(always)]
1294fn zscore_fast_batch_kind(ma_type: &str, devtype: usize) -> Option<MaKind> {
1295 if devtype != 0 {
1296 None
1297 } else if ma_type.eq_ignore_ascii_case("sma") {
1298 Some(MaKind::Sma)
1299 } else if ma_type.eq_ignore_ascii_case("ema") {
1300 Some(MaKind::Ema)
1301 } else {
1302 None
1303 }
1304}
1305
1306#[inline(always)]
1307unsafe fn zscore_row_scalar(
1308 data: &[f64],
1309 first: usize,
1310 period: usize,
1311 ma_type: &str,
1312 nbdev: f64,
1313 devtype: usize,
1314 out: &mut [f64],
1315) {
1316 if devtype == 0 {
1317 if ma_type == "sma" {
1318 zscore_row_scalar_classic_sma(data, first, period, nbdev, out);
1319 return;
1320 } else if ma_type == "ema" {
1321 zscore_row_scalar_classic_ema(data, first, period, nbdev, out);
1322 return;
1323 }
1324 }
1325
1326 let means = match ma(ma_type, MaData::Slice(data), period) {
1327 Ok(m) => m,
1328 Err(_) => {
1329 out.fill(f64::NAN);
1330 return;
1331 }
1332 };
1333 let dev_input = DevInput {
1334 data: DeviationData::Slice(data),
1335 params: DevParams {
1336 period: Some(period),
1337 devtype: Some(devtype),
1338 },
1339 };
1340 let mut sigmas = match deviation(&dev_input) {
1341 Ok(d) => d.values,
1342 Err(_) => {
1343 out.fill(f64::NAN);
1344 return;
1345 }
1346 };
1347 for v in &mut sigmas {
1348 *v *= nbdev;
1349 }
1350 let warmup_end = first + period - 1;
1351 for i in warmup_end..data.len() {
1352 let mean = means[i];
1353 let sigma = sigmas[i];
1354 let value = data[i];
1355 out[i] = if sigma == 0.0 || sigma.is_nan() {
1356 f64::NAN
1357 } else {
1358 (value - mean) / sigma
1359 };
1360 }
1361}
1362
/// Single-pass rolling z-score using a simple moving average and the matching
/// `sum_sqr / period - mean^2` variance.
///
/// For every index `i >= first + period - 1` this writes
/// `(data[i] - mean) / (sqrt(variance) * nbdev)`, where mean and variance cover the `period`
/// values ending at `i`. A non-positive variance, or a zero/NaN scaled deviation, yields NaN.
/// Indices before the warm-up boundary are left untouched.
///
/// The function returns without writing anything when `period == 0` or when `data` is too
/// short to hold one full window (`data.len() <= first + period - 1`). This matches the
/// early return of the EMA sibling instead of panicking on an out-of-bounds index.
///
/// Note: the running sums are updated incrementally, so a NaN inside the data after `first`
/// stays in the sums for the rest of the row.
///
/// # Safety
/// The body performs only bounds-checked slice accesses. The `unsafe` qualifier is kept for
/// signature compatibility with the other row kernels.
#[inline(always)]
unsafe fn zscore_row_scalar_classic_sma(
    data: &[f64],
    first: usize,
    period: usize,
    nbdev: f64,
    out: &mut [f64],
) {
    // `first + period - 1` would underflow for a zero period; nothing sensible to compute.
    if period == 0 {
        return;
    }
    let n = data.len();
    let warmup_end = first + period - 1;
    if n <= warmup_end {
        return;
    }

    let den = period as f64;
    // Turns the current window sums into the z-score of `value`.
    let zscore = |value: f64, sum: f64, sum_sqr: f64| -> f64 {
        let mean = sum / den;
        let variance = (sum_sqr / den) - (mean * mean);
        let stddev = if variance <= 0.0 {
            0.0
        } else {
            variance.sqrt() * nbdev
        };
        if stddev == 0.0 || stddev.is_nan() {
            f64::NAN
        } else {
            (value - mean) / stddev
        }
    };

    // Seed the sums with the first complete window.
    let mut sum = 0.0;
    let mut sum_sqr = 0.0;
    for &val in &data[first..=warmup_end] {
        sum += val;
        sum_sqr += val * val;
    }
    out[warmup_end] = zscore(data[warmup_end], sum, sum_sqr);

    // Slide the window: add the incoming value, drop the one leaving.
    for i in warmup_end + 1..n {
        let old_val = data[i - period];
        let new_val = data[i];

        sum += new_val - old_val;
        sum_sqr += new_val * new_val - old_val * old_val;

        out[i] = zscore(new_val, sum, sum_sqr);
    }
}
1417
/// Prefix tables over an input series, each with `data.len() + 1` entries.
///
/// Entry `k` describes the first `k` input values, so a window `[lo, hi)` is obtained by
/// subtracting entry `lo` from entry `hi`.
#[derive(Debug, Clone)]
struct SmaStdPrefixes {
    /// Running sum of the non-NaN values.
    ps: Vec<f64>,
    /// Running sum of squares of the non-NaN values.
    ps2: Vec<f64>,
    /// Running count of NaN values.
    pnan: Vec<i32>,
}

/// Builds the prefix tables for `data`.
///
/// NaN inputs add nothing to the two sums and increment the NaN counter; every other value
/// is added to the sum and its square to the sum of squares.
#[inline]
fn build_sma_std_prefixes(data: &[f64]) -> SmaStdPrefixes {
    let entries = data.len() + 1;
    let mut ps = Vec::with_capacity(entries);
    let mut ps2 = Vec::with_capacity(entries);
    let mut pnan = Vec::with_capacity(entries);

    let mut running_sum = 0.0f64;
    let mut running_sq = 0.0f64;
    let mut running_nan = 0i32;

    // Entry 0 describes the empty prefix.
    ps.push(running_sum);
    ps2.push(running_sq);
    pnan.push(running_nan);

    for &value in data {
        if value.is_nan() {
            running_nan += 1;
        } else {
            running_sum += value;
            running_sq += value * value;
        }
        ps.push(running_sum);
        ps2.push(running_sq);
        pnan.push(running_nan);
    }

    SmaStdPrefixes { ps, ps2, pnan }
}
1447
1448#[inline]
1449fn zscore_sma_std_from_prefix_scalar(
1450 data: &[f64],
1451 period: usize,
1452 warmup_end: usize,
1453 pre: &SmaStdPrefixes,
1454 base_out: &mut [f64],
1455) {
1456 let n = data.len();
1457 debug_assert_eq!(base_out.len(), n);
1458
1459 for v in &mut base_out[..warmup_end] {
1460 *v = f64::NAN;
1461 }
1462 if n <= warmup_end {
1463 return;
1464 }
1465
1466 let denom = period as f64;
1467 for i in warmup_end..n {
1468 let nan_count = pre.pnan[i + 1] - pre.pnan[i + 1 - period];
1469 if nan_count > 0 {
1470 base_out[i] = f64::NAN;
1471 continue;
1472 }
1473
1474 let sum = pre.ps[i + 1] - pre.ps[i + 1 - period];
1475 let sum2 = pre.ps2[i + 1] - pre.ps2[i + 1 - period];
1476 let mean = sum / denom;
1477 let variance = sum2 / denom - mean * mean;
1478 let stdv = if variance <= 0.0 {
1479 0.0
1480 } else {
1481 variance.sqrt()
1482 };
1483 base_out[i] = if stdv == 0.0 || stdv.is_nan() {
1484 f64::NAN
1485 } else {
1486 (data[i] - mean) / stdv
1487 };
1488 }
1489}
1490
/// Derives one output row from a shared unscaled base row.
///
/// The first `warmup_end` entries are copied verbatim. The remaining entries become
/// `src_base[i] / nbdev`, or NaN for every one of them when `nbdev` is exactly zero.
#[inline]
fn scale_copy_row_scalar(src_base: &[f64], warmup_end: usize, nbdev: f64, dst: &mut [f64]) {
    debug_assert_eq!(src_base.len(), dst.len());

    let (warmup, tail) = dst.split_at_mut(warmup_end);
    warmup.copy_from_slice(&src_base[..warmup_end]);

    if tail.is_empty() {
        return;
    }

    if nbdev == 0.0 {
        tail.fill(f64::NAN);
        return;
    }

    for (slot, base) in tail.iter_mut().zip(&src_base[warmup_end..]) {
        *slot = *base / nbdev;
    }
}
1514
/// AVX2 variant of the prefix-table SMA z-score: four outputs per iteration, with a scalar
/// loop for the remaining elements.
///
/// `base_out[..warmup_end]` is set to NaN. Later lanes are NaN when the window contains a
/// NaN input or the deviation is zero/NaN. Otherwise they hold
/// `(data[i] - mean) / sqrt(max(variance, 0))`.
///
/// # Safety
/// The CPU must support AVX2. The vector loop reads `pre.ps`, `pre.ps2`, `pre.pnan` and
/// `data` and writes `base_out` through raw pointers, so the prefix tables must have
/// `data.len() + 1` entries, `base_out` must be as long as `data`, and
/// `warmup_end + 1 >= period` must hold.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn zscore_sma_std_from_prefix_avx2(
    data: &[f64],
    period: usize,
    warmup_end: usize,
    pre: &SmaStdPrefixes,
    base_out: &mut [f64],
) {
    const LANES: usize = 4;

    let n = data.len();
    debug_assert_eq!(base_out.len(), n);

    base_out[..warmup_end].fill(f64::NAN);
    if n <= warmup_end {
        return;
    }

    let window = period as f64;
    let window_v = _mm256_set1_pd(window);
    let zero_v = _mm256_setzero_pd();
    let nan_v = _mm256_set1_pd(f64::NAN);

    let sums = pre.ps.as_ptr();
    let squares = pre.ps2.as_ptr();
    let nan_counts = pre.pnan.as_ptr();
    let input = data.as_ptr();

    let mut idx = warmup_end;
    while idx + LANES <= n {
        // Prefix offsets bracketing the windows that end at idx..idx+3.
        let hi = idx + 1;
        let lo = idx + 1 - period;

        let sum = _mm256_sub_pd(_mm256_loadu_pd(sums.add(hi)), _mm256_loadu_pd(sums.add(lo)));
        let sum_sq = _mm256_sub_pd(
            _mm256_loadu_pd(squares.add(hi)),
            _mm256_loadu_pd(squares.add(lo)),
        );

        let mean = _mm256_div_pd(sum, window_v);
        let variance = _mm256_sub_pd(_mm256_div_pd(sum_sq, window_v), _mm256_mul_pd(mean, mean));
        let stdv = _mm256_sqrt_pd(_mm256_max_pd(variance, zero_v));

        let values = _mm256_loadu_pd(input.add(idx));
        let z = _mm256_div_pd(_mm256_sub_pd(values, mean), stdv);

        let std_is_zero = _mm256_cmp_pd(stdv, zero_v, _CMP_EQ_OQ);
        let std_is_nan = _mm256_cmp_pd(stdv, stdv, _CMP_UNORD_Q);

        // Per-lane NaN count inside the window, widened to f64 for the comparison.
        let nan_hi = _mm_loadu_si128(nan_counts.add(hi) as *const _);
        let nan_lo = _mm_loadu_si128(nan_counts.add(lo) as *const _);
        let nan_in_window = _mm256_cvtepi32_pd(_mm_sub_epi32(nan_hi, nan_lo));
        let window_has_nan = _mm256_cmp_pd(nan_in_window, zero_v, _CMP_GT_OQ);

        let invalid = _mm256_or_pd(_mm256_or_pd(std_is_zero, std_is_nan), window_has_nan);
        _mm256_storeu_pd(
            base_out.as_mut_ptr().add(idx),
            _mm256_blendv_pd(z, nan_v, invalid),
        );

        idx += LANES;
    }

    // Scalar remainder, same rules as the vector body.
    for tail in idx..n {
        let hi = tail + 1;
        let lo = tail + 1 - period;
        base_out[tail] = if pre.pnan[hi] - pre.pnan[lo] > 0 {
            f64::NAN
        } else {
            let mean = (pre.ps[hi] - pre.ps[lo]) / window;
            let variance = (pre.ps2[hi] - pre.ps2[lo]) / window - mean * mean;
            let sd = if variance <= 0.0 {
                0.0
            } else {
                variance.sqrt()
            };
            if sd == 0.0 || sd.is_nan() {
                f64::NAN
            } else {
                (data[tail] - mean) / sd
            }
        };
    }
}
1597
/// AVX-512 variant of the prefix-table SMA z-score: eight outputs per iteration, with a
/// scalar loop for the remaining elements.
///
/// `base_out[..warmup_end]` is set to NaN. Later lanes are NaN when the window contains a
/// NaN input or the deviation is zero/NaN. Otherwise they hold
/// `(data[i] - mean) / sqrt(max(variance, 0))`.
///
/// # Safety
/// The CPU must support AVX-512F. The vector loop reads `pre.ps`, `pre.ps2`, `pre.pnan` and
/// `data` and writes `base_out` through raw pointers, so the prefix tables must have
/// `data.len() + 1` entries, `base_out` must be as long as `data`, and
/// `warmup_end + 1 >= period` must hold.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
#[inline]
unsafe fn zscore_sma_std_from_prefix_avx512(
    data: &[f64],
    period: usize,
    warmup_end: usize,
    pre: &SmaStdPrefixes,
    base_out: &mut [f64],
) {
    const LANES: usize = 8;

    let n = data.len();
    debug_assert_eq!(base_out.len(), n);

    base_out[..warmup_end].fill(f64::NAN);
    if n <= warmup_end {
        return;
    }

    let window = period as f64;
    let window_v = _mm512_set1_pd(window);
    let zero_v = _mm512_setzero_pd();
    let nan_v = _mm512_set1_pd(f64::NAN);

    let sums = pre.ps.as_ptr();
    let squares = pre.ps2.as_ptr();
    let nan_counts = pre.pnan.as_ptr();
    let input = data.as_ptr();

    let mut idx = warmup_end;
    while idx + LANES <= n {
        // Prefix offsets bracketing the windows that end at idx..idx+7.
        let hi = idx + 1;
        let lo = idx + 1 - period;

        let sum = _mm512_sub_pd(_mm512_loadu_pd(sums.add(hi)), _mm512_loadu_pd(sums.add(lo)));
        let sum_sq = _mm512_sub_pd(
            _mm512_loadu_pd(squares.add(hi)),
            _mm512_loadu_pd(squares.add(lo)),
        );

        let mean = _mm512_div_pd(sum, window_v);
        let variance = _mm512_sub_pd(_mm512_div_pd(sum_sq, window_v), _mm512_mul_pd(mean, mean));
        let stdv = _mm512_sqrt_pd(_mm512_max_pd(variance, zero_v));

        let values = _mm512_loadu_pd(input.add(idx));
        let z = _mm512_div_pd(_mm512_sub_pd(values, mean), stdv);

        let std_is_zero = _mm512_cmp_pd_mask(stdv, zero_v, _CMP_EQ_OQ);
        let std_is_nan = _mm512_cmp_pd_mask(stdv, stdv, _CMP_UNORD_Q);

        // Per-lane NaN count inside the window, widened to f64 for the comparison.
        let nan_hi = _mm256_loadu_si256(nan_counts.add(hi) as *const _);
        let nan_lo = _mm256_loadu_si256(nan_counts.add(lo) as *const _);
        let nan_in_window = _mm512_cvtepi32_pd(_mm256_sub_epi32(nan_hi, nan_lo));
        let window_has_nan = _mm512_cmp_pd_mask(nan_in_window, zero_v, _CMP_GT_OQ);

        let invalid = std_is_zero | std_is_nan | window_has_nan;
        _mm512_storeu_pd(
            base_out.as_mut_ptr().add(idx),
            _mm512_mask_mov_pd(z, invalid, nan_v),
        );

        idx += LANES;
    }

    // Scalar remainder, same rules as the vector body.
    for tail in idx..n {
        let hi = tail + 1;
        let lo = tail + 1 - period;
        base_out[tail] = if pre.pnan[hi] - pre.pnan[lo] > 0 {
            f64::NAN
        } else {
            let mean = (pre.ps[hi] - pre.ps[lo]) / window;
            let variance = (pre.ps2[hi] - pre.ps2[lo]) / window - mean * mean;
            let sd = if variance <= 0.0 {
                0.0
            } else {
                variance.sqrt()
            };
            if sd == 0.0 || sd.is_nan() {
                f64::NAN
            } else {
                (data[tail] - mean) / sd
            }
        };
    }
}
1680
/// AVX2 version of the per-row scale/copy step.
///
/// The first `warmup_end` entries are copied verbatim. The remaining entries become
/// `src_base[i] / nbdev`, or NaN for every one of them when `nbdev` is exactly zero.
///
/// The vector body uses a true division rather than a multiplication by `1.0 / nbdev`, so
/// every element of the row (vector lanes and scalar tail alike) is bit-identical to what
/// the scalar kernel produces.
///
/// # Panics
/// Panics if `src_base` and `dst` differ in length. The vector loop addresses both through
/// raw pointers, so the check is also performed in release builds.
///
/// # Safety
/// The CPU must support AVX2.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn scale_copy_row_avx2(src_base: &[f64], warmup_end: usize, nbdev: f64, dst: &mut [f64]) {
    assert_eq!(src_base.len(), dst.len());

    if warmup_end > 0 {
        dst[..warmup_end].copy_from_slice(&src_base[..warmup_end]);
    }

    let n = dst.len();
    if n <= warmup_end {
        return;
    }

    if nbdev == 0.0 {
        dst[warmup_end..].fill(f64::NAN);
        return;
    }

    let divisor = _mm256_set1_pd(nbdev);
    let mut i = warmup_end;
    while i + 4 <= n {
        let v = _mm256_loadu_pd(src_base.as_ptr().add(i));
        _mm256_storeu_pd(dst.as_mut_ptr().add(i), _mm256_div_pd(v, divisor));
        i += 4;
    }
    while i < n {
        dst[i] = src_base[i] / nbdev;
        i += 1;
    }
}
1715
/// AVX-512 version of the per-row scale/copy step.
///
/// The first `warmup_end` entries are copied verbatim. The remaining entries become
/// `src_base[i] / nbdev`, or NaN for every one of them when `nbdev` is exactly zero.
///
/// The vector body uses a true division rather than a multiplication by `1.0 / nbdev`, so
/// every element of the row (vector lanes and scalar tail alike) is bit-identical to what
/// the scalar kernel produces.
///
/// # Panics
/// Panics if `src_base` and `dst` differ in length. The vector loop addresses both through
/// raw pointers, so the check is also performed in release builds.
///
/// # Safety
/// The CPU must support AVX-512F.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
#[inline]
unsafe fn scale_copy_row_avx512(src_base: &[f64], warmup_end: usize, nbdev: f64, dst: &mut [f64]) {
    assert_eq!(src_base.len(), dst.len());

    if warmup_end > 0 {
        dst[..warmup_end].copy_from_slice(&src_base[..warmup_end]);
    }

    let n = dst.len();
    if n <= warmup_end {
        return;
    }

    if nbdev == 0.0 {
        dst[warmup_end..].fill(f64::NAN);
        return;
    }

    let divisor = _mm512_set1_pd(nbdev);
    let mut i = warmup_end;
    while i + 8 <= n {
        let v = _mm512_loadu_pd(src_base.as_ptr().add(i));
        _mm512_storeu_pd(dst.as_mut_ptr().add(i), _mm512_div_pd(v, divisor));
        i += 8;
    }
    while i < n {
        dst[i] = src_base[i] / nbdev;
        i += 1;
    }
}
1750
/// Copyable raw view of a row-major `f64` matrix with `cols` values per row.
///
/// It lets several closures (possibly on different threads) each obtain a mutable slice for
/// a distinct row of the same output buffer.
#[derive(Clone, Copy)]
struct RowWriter {
    /// Start of the matrix storage.
    ptr: *mut f64,
    /// Number of values per row.
    cols: usize,
}

// SAFETY: `RowWriter` is only a pointer plus a length and performs no access by itself.
// Soundness of sharing it across threads rests on the `with_row` contract: callers must
// never touch the same row concurrently and must keep the buffer alive.
unsafe impl Send for RowWriter {}
unsafe impl Sync for RowWriter {}

impl RowWriter {
    /// Runs `f` on the mutable slice covering row `row`.
    ///
    /// # Safety
    /// * `ptr` must point to a live allocation holding at least `(row + 1) * cols` values.
    /// * No other reference to that row may exist while `f` runs; in particular, concurrent
    ///   callers must use distinct rows.
    #[inline(always)]
    unsafe fn with_row<F>(&self, row: usize, f: F)
    where
        F: FnOnce(&mut [f64]),
    {
        let slice = std::slice::from_raw_parts_mut(self.ptr.add(row * self.cols), self.cols);
        f(slice);
    }
}
1770
/// Single-pass rolling z-score around an exponential moving average.
///
/// The EMA is seeded with the plain mean of the first full window and then updated with
/// `alpha = 2 / (period + 1)`. The dispersion at each step is the square root of the mean
/// squared distance of the last `period` values from the current EMA, clamped at zero and
/// multiplied by `nbdev`. A zero or NaN dispersion yields NaN.
///
/// Returns without writing when `data.len() <= first + period - 1`. Indices before that
/// warm-up boundary are never written.
#[inline(always)]
unsafe fn zscore_row_scalar_classic_ema(
    data: &[f64],
    first: usize,
    period: usize,
    nbdev: f64,
    out: &mut [f64],
) {
    let n = data.len();
    let warmup_end = first + period - 1;
    if n <= warmup_end {
        return;
    }

    let den = period as f64;
    let alpha = 2.0 / (den + 1.0);
    let one_minus_alpha = 1.0 - alpha;

    // z-score of `value` given the window sums and the current EMA.
    let zscore = |value: f64, sum: f64, sum2: f64, ema: f64| -> f64 {
        let mut mse = (sum2 / den) - 2.0 * ema * (sum / den) + ema * ema;
        if mse < 0.0 {
            mse = 0.0;
        }
        let sd = mse.sqrt() * nbdev;
        if sd == 0.0 || sd.is_nan() {
            f64::NAN
        } else {
            (value - ema) / sd
        }
    };

    // Seed window: sums over the first `period` valid values.
    let mut sum = 0.0;
    let mut sum2 = 0.0;
    for &v in &data[first..=warmup_end] {
        sum += v;
        sum2 += v * v;
    }
    let mut ema = sum / den;
    out[warmup_end] = zscore(data[warmup_end], sum, sum2, ema);

    for i in warmup_end + 1..n {
        let incoming = data[i];
        let outgoing = data[i - period];

        sum += incoming - outgoing;
        sum2 += incoming * incoming - outgoing * outgoing;
        ema = alpha * incoming + one_minus_alpha * ema;

        out[i] = zscore(incoming, sum, sum2, ema);
    }
}
1839
/// AVX2 entry point for a single z-score row.
///
/// There is no AVX2-specific arithmetic here: the call is forwarded unchanged to
/// `zscore_row_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn zscore_row_avx2(
    data: &[f64],
    first: usize,
    period: usize,
    ma_type: &str,
    nbdev: f64,
    devtype: usize,
    out: &mut [f64],
) {
    zscore_row_scalar(data, first, period, ma_type, nbdev, devtype, out);
}
1853
/// AVX-512 entry point for a single z-score row.
///
/// Periods up to and including 32 go to `zscore_row_avx512_short`, larger ones to
/// `zscore_row_avx512_long`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn zscore_row_avx512(
    data: &[f64],
    first: usize,
    period: usize,
    ma_type: &str,
    nbdev: f64,
    devtype: usize,
    out: &mut [f64],
) {
    match period {
        0..=32 => zscore_row_avx512_short(data, first, period, ma_type, nbdev, devtype, out),
        _ => zscore_row_avx512_long(data, first, period, ma_type, nbdev, devtype, out),
    }
}
1871
/// Short-period branch of the AVX-512 row kernel.
///
/// Forwards unchanged to `zscore_row_scalar`; no AVX-512-specific arithmetic is performed.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn zscore_row_avx512_short(
    data: &[f64],
    first: usize,
    period: usize,
    ma_type: &str,
    nbdev: f64,
    devtype: usize,
    out: &mut [f64],
) {
    zscore_row_scalar(data, first, period, ma_type, nbdev, devtype, out);
}
1885
/// Long-period branch of the AVX-512 row kernel.
///
/// Forwards unchanged to `zscore_row_scalar`; no AVX-512-specific arithmetic is performed.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn zscore_row_avx512_long(
    data: &[f64],
    first: usize,
    period: usize,
    ma_type: &str,
    nbdev: f64,
    devtype: usize,
    out: &mut [f64],
) {
    zscore_row_scalar(data, first, period, ma_type, nbdev, devtype, out);
}
1899
1900#[inline(always)]
1901pub fn zscore_batch_inner_into(
1902 data: &[f64],
1903 sweep: &ZscoreBatchRange,
1904 kern: Kernel,
1905 parallel: bool,
1906 out: &mut [f64],
1907) -> Result<Vec<ZscoreParams>, ZscoreError> {
1908 if data.is_empty() {
1909 return Err(ZscoreError::EmptyInputData);
1910 }
1911
1912 let combos = expand_grid(sweep)?;
1913
1914 let first = data
1915 .iter()
1916 .position(|x| !x.is_nan())
1917 .ok_or(ZscoreError::AllValuesNaN)?;
1918 let cols = data.len();
1919 let mut max_p = 0usize;
1920 for prm in &combos {
1921 let period = prm.period.unwrap();
1922 if period == 0 || period > cols {
1923 return Err(ZscoreError::InvalidPeriod {
1924 period,
1925 data_len: cols,
1926 });
1927 }
1928 if period > max_p {
1929 max_p = period;
1930 }
1931 }
1932 if cols - first < max_p {
1933 return Err(ZscoreError::NotEnoughValidData {
1934 needed: max_p,
1935 valid: cols - first,
1936 });
1937 }
1938
1939 let rows = combos.len();
1940
1941 let expected = rows.checked_mul(cols).ok_or(ZscoreError::InvalidRange {
1942 start: rows as f64,
1943 end: cols as f64,
1944 step: 0.0,
1945 })?;
1946 if out.len() != expected {
1947 return Err(ZscoreError::OutputLengthMismatch {
1948 expected,
1949 got: out.len(),
1950 });
1951 }
1952
1953 let warm: Vec<usize> = combos
1954 .iter()
1955 .map(|c| first + c.period.unwrap() - 1)
1956 .collect();
1957 {
1958 let out_uninit = unsafe {
1959 std::slice::from_raw_parts_mut(
1960 out.as_mut_ptr() as *mut std::mem::MaybeUninit<f64>,
1961 out.len(),
1962 )
1963 };
1964 init_matrix_prefixes(out_uninit, cols, &warm);
1965 }
1966
1967 let mut groups: HashMap<(usize, MaKind), Vec<(usize, f64)>> = HashMap::new();
1968 let mut fallback_rows: Vec<usize> = Vec::new();
1969 for (row_idx, prm) in combos.iter().enumerate() {
1970 let period = prm.period.unwrap();
1971 let ma_type = prm.ma_type.as_ref().unwrap();
1972 let devtype = prm.devtype.unwrap();
1973 match zscore_fast_batch_kind(ma_type, devtype) {
1974 Some(kind) => {
1975 groups
1976 .entry((period, kind))
1977 .or_default()
1978 .push((row_idx, prm.nbdev.unwrap()));
1979 }
1980 None => fallback_rows.push(row_idx),
1981 }
1982 }
1983
1984 let prefixes = if groups.keys().any(|(_, kind)| matches!(kind, MaKind::Sma)) {
1985 Some(build_sma_std_prefixes(data))
1986 } else {
1987 None
1988 };
1989
1990 let writer = RowWriter {
1991 ptr: out.as_mut_ptr(),
1992 cols,
1993 };
1994
1995 for ((period, kind), rows_for_period) in groups.into_iter() {
1996 let warmup_end = first + period - 1;
1997 let mut base = vec![f64::NAN; cols];
1998
1999 match kind {
2000 MaKind::Sma => match kern {
2001 Kernel::Scalar => {
2002 let pre = prefixes.as_ref().expect("prefixes missing for scalar path");
2003 zscore_sma_std_from_prefix_scalar(data, period, warmup_end, pre, &mut base);
2004 }
2005 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2006 Kernel::Avx2 => unsafe {
2007 let pre = prefixes.as_ref().expect("prefixes missing for AVX2 path");
2008 zscore_sma_std_from_prefix_avx2(data, period, warmup_end, pre, &mut base);
2009 },
2010 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2011 Kernel::Avx512 => unsafe {
2012 let pre = prefixes.as_ref().expect("prefixes missing for AVX512 path");
2013 zscore_sma_std_from_prefix_avx512(data, period, warmup_end, pre, &mut base);
2014 },
2015 _ => unreachable!(),
2016 },
2017 MaKind::Ema => unsafe {
2018 zscore_row_scalar_classic_ema(data, first, period, 1.0, &mut base);
2019 },
2020 _ => unreachable!(),
2021 }
2022
2023 let base_ref = &base;
2024
2025 let write_scalar = |row_idx: usize, nbdev: f64| unsafe {
2026 writer.with_row(row_idx, |dst| {
2027 scale_copy_row_scalar(base_ref, warmup_end, nbdev, dst);
2028 });
2029 };
2030
2031 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2032 let write_avx2 = |row_idx: usize, nbdev: f64| unsafe {
2033 writer.with_row(row_idx, |dst| {
2034 scale_copy_row_avx2(base_ref, warmup_end, nbdev, dst);
2035 });
2036 };
2037
2038 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2039 let write_avx512 = |row_idx: usize, nbdev: f64| unsafe {
2040 writer.with_row(row_idx, |dst| {
2041 scale_copy_row_avx512(base_ref, warmup_end, nbdev, dst);
2042 });
2043 };
2044
2045 let dispatch_write = |row_idx: usize, nbdev: f64| match kern {
2046 Kernel::Scalar => write_scalar(row_idx, nbdev),
2047 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2048 Kernel::Avx2 => write_avx2(row_idx, nbdev),
2049 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2050 Kernel::Avx512 => write_avx512(row_idx, nbdev),
2051 _ => unreachable!(),
2052 };
2053
2054 if parallel {
2055 #[cfg(not(target_arch = "wasm32"))]
2056 {
2057 use rayon::prelude::*;
2058 rows_for_period
2059 .par_iter()
2060 .for_each(|(row_idx, nb)| dispatch_write(*row_idx, *nb));
2061 }
2062 #[cfg(target_arch = "wasm32")]
2063 {
2064 for (row_idx, nb) in rows_for_period.iter() {
2065 dispatch_write(*row_idx, *nb);
2066 }
2067 }
2068 } else {
2069 for (row_idx, nb) in rows_for_period.iter() {
2070 dispatch_write(*row_idx, *nb);
2071 }
2072 }
2073 }
2074
2075 if !fallback_rows.is_empty() {
2076 let do_row = |row: usize| unsafe {
2077 let prm = &combos[row];
2078 let period = prm.period.unwrap();
2079 let ma_type = prm.ma_type.as_ref().unwrap();
2080 let nbdev = prm.nbdev.unwrap();
2081 let devtype = prm.devtype.unwrap();
2082 writer.with_row(row, |dst| match kern {
2083 Kernel::Scalar => {
2084 zscore_row_scalar(data, first, period, ma_type, nbdev, devtype, dst)
2085 }
2086 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2087 Kernel::Avx2 => zscore_row_avx2(data, first, period, ma_type, nbdev, devtype, dst),
2088 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2089 Kernel::Avx512 => {
2090 zscore_row_avx512(data, first, period, ma_type, nbdev, devtype, dst)
2091 }
2092 _ => unreachable!(),
2093 });
2094 };
2095
2096 if parallel {
2097 #[cfg(not(target_arch = "wasm32"))]
2098 {
2099 use rayon::prelude::*;
2100 fallback_rows.par_iter().for_each(|&row| do_row(row));
2101 }
2102 #[cfg(target_arch = "wasm32")]
2103 {
2104 for &row in &fallback_rows {
2105 do_row(row);
2106 }
2107 }
2108 } else {
2109 for &row in &fallback_rows {
2110 do_row(row);
2111 }
2112 }
2113 }
2114
2115 Ok(combos)
2116}
2117
// Python binding `zscore`: runs the single-series z-score on a contiguous NumPy array and
// returns a new 1-D float64 array. The computation runs inside `allow_threads`; any
// indicator error is surfaced as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "zscore")]
#[pyo3(signature = (data, period=14, ma_type="sma", nbdev=1.0, devtype=0, kernel=None))]
pub fn zscore_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    ma_type: &str,
    nbdev: f64,
    devtype: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, false)?;

    let input = ZscoreInput::from_slice(
        slice_in,
        ZscoreParams {
            period: Some(period),
            ma_type: Some(ma_type.to_string()),
            nbdev: Some(nbdev),
            devtype: Some(devtype),
        },
    );

    let computed: Result<Vec<f64>, _> =
        py.allow_threads(|| zscore_with_kernel(&input, kern).map(|out| out.values));

    match computed {
        Ok(values) => Ok(values.into_pyarray(py)),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
2148
// Python-visible handle around a device-resident f32 matrix. It keeps the CUDA context
// alive through `_ctx_guard` and remembers which device the buffer was allocated on.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct ZscoreDeviceArrayF32Py {
    pub(crate) inner: DeviceArrayF32,
    _ctx_guard: Arc<Context>,
    _device_id: u32,
}

#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl ZscoreDeviceArrayF32Py {
    // Builds the `__cuda_array_interface__` dict: 2-D shape, little-endian f32, row-major
    // byte strides, the raw device pointer flagged writable, interface version 3.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let elem_bytes = std::mem::size_of::<f32>();
        let (rows, cols) = (self.inner.rows, self.inner.cols);
        let device_ptr = self.inner.buf.as_device_ptr().as_raw() as usize;

        let iface = PyDict::new(py);
        iface.set_item("shape", (rows, cols))?;
        iface.set_item("typestr", "<f4")?;
        iface.set_item("strides", (cols * elem_bytes, elem_bytes))?;
        iface.set_item("data", (device_ptr, false))?;

        iface.set_item("version", 3)?;
        Ok(iface)
    }

    // (device_type, device_id) pair; 2 is presumably DLPack's kDLCUDA code.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self._device_id as i32)
    }

    // Exports the buffer as a DLPack capsule. The device buffer is moved out of `self`
    // (replaced by an empty 0x0 array). `stream` and `copy` are accepted but ignored, and a
    // `dl_device` that parses as a pair and differs from the owning device is rejected.
    #[pyo3(signature=(stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        let (dev_type, alloc_dev) = self.__dlpack_device__();

        // A `dl_device` that is not an (int, int) pair is silently ignored.
        let requested = dl_device
            .as_ref()
            .and_then(|obj| obj.extract::<(i32, i32)>(py).ok());
        if let Some((want_type, want_id)) = requested {
            if want_type != dev_type || want_id != alloc_dev {
                return Err(PyValueError::new_err(
                    "zscore: dl_device mismatch; cross-device copy not implemented",
                ));
            }
        }
        let _ = stream;
        let _ = copy;

        // Swap an empty placeholder into `self` so the real buffer can be handed over.
        let placeholder =
            DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let DeviceArrayF32 { buf, rows, cols } = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32 {
                buf: placeholder,
                rows: 0,
                cols: 0,
            },
        );

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}

#[cfg(all(feature = "python", feature = "cuda"))]
impl ZscoreDeviceArrayF32Py {
    /// Wraps a device array produced on the Rust side together with its context guard and
    /// device id.
    pub fn new_from_rust(inner: DeviceArrayF32, ctx_guard: Arc<Context>, device_id: u32) -> Self {
        Self {
            inner,
            _ctx_guard: ctx_guard,
            _device_id: device_id,
        }
    }
}
2230
// Python binding `zscore_cuda_batch_dev`: runs the SMA / devtype-0 z-score sweep on the GPU
// and returns the device array plus a dict describing each row ("periods", "nbdevs",
// "ma_types", "devtypes"). Fails with `ValueError` when CUDA is unavailable or the CUDA
// wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "zscore_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, nbdev_range=(1.0, 1.0, 0.0), device_id=0))]
pub fn zscore_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    data_f32: numpy::PyReadonlyArray1<'py, f32>,
    period_range: (usize, usize, usize),
    nbdev_range: (f64, f64, f64),
    device_id: usize,
) -> PyResult<(ZscoreDeviceArrayF32Py, Bound<'py, PyDict>)> {
    use crate::cuda::cuda_available;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let slice_in = data_f32.as_slice()?;
    // Only the "sma" / devtype 0 combination is swept on this path.
    let sweep = ZscoreBatchRange {
        period: period_range,
        ma_type: ("sma".to_string(), "sma".to_string(), String::new()),
        nbdev: nbdev_range,
        devtype: (0, 0, 0),
    };

    let (inner, ctx, dev_id, combos) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaZscore::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id();
        let (arr, combos) = cuda
            .zscore_batch_dev(slice_in, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((arr, ctx, dev_id, combos))
    })?;

    // Per-row metadata columns.
    let row_count = combos.len();
    let periods: Vec<u64> = combos.iter().map(|(p, _)| *p as u64).collect();
    let nbdevs: Vec<f64> = combos.iter().map(|(_, nb)| *nb as f64).collect();
    let devtypes: Vec<u64> = vec![0u64; row_count];

    let dict = PyDict::new(py);
    let ma_types = PyList::new(py, vec!["sma"; row_count])?;

    dict.set_item("periods", periods.into_pyarray(py))?;
    dict.set_item("nbdevs", nbdevs.into_pyarray(py))?;
    dict.set_item("ma_types", ma_types)?;
    dict.set_item("devtypes", devtypes.into_pyarray(py))?;

    let handle = ZscoreDeviceArrayF32Py::new_from_rust(inner, ctx, dev_id);
    Ok((handle, dict))
}
2281
// Python binding `zscore_cuda_many_series_one_param_dev`: runs one (period, nbdev) setting
// over many series on the GPU and returns the device array. The flat input is forwarded to
// the wrapper's time-major entry point together with `cols` and `rows`. Raises `ValueError`
// when CUDA is unavailable, `nbdev` is negative or non-finite, or the wrapper fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "zscore_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, cols, rows, period, nbdev=1.0, device_id=0))]
pub fn zscore_cuda_many_series_one_param_dev_py<'py>(
    py: Python<'py>,
    data_tm_f32: numpy::PyReadonlyArray1<'py, f32>,
    cols: usize,
    rows: usize,
    period: usize,
    nbdev: f64,
    device_id: usize,
) -> PyResult<ZscoreDeviceArrayF32Py> {
    use crate::cuda::cuda_available;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let nbdev_ok = nbdev >= 0.0 && nbdev.is_finite();
    if !nbdev_ok {
        return Err(PyValueError::new_err(
            "nbdev must be non-negative and finite",
        ));
    }

    let slice_in = data_tm_f32.as_slice()?;
    let (inner, ctx, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaZscore::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id();
        // The CUDA entry point takes nbdev as f32.
        let arr = cuda
            .zscore_many_series_one_param_time_major_dev(slice_in, cols, rows, period, nbdev as f32)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((arr, ctx, dev_id))
    })?;

    Ok(ZscoreDeviceArrayF32Py::new_from_rust(inner, ctx, dev_id))
}
2318
// Python class `ZscoreStream`: thin wrapper that owns a Rust `ZscoreStream`.
#[cfg(feature = "python")]
#[pyclass(name = "ZscoreStream")]
pub struct ZscoreStreamPy {
    stream: ZscoreStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl ZscoreStreamPy {
    // Constructor: builds the underlying stream from the four parameters. A construction
    // failure is reported as `ValueError`.
    #[new]
    fn new(period: usize, ma_type: &str, nbdev: f64, devtype: usize) -> PyResult<Self> {
        ZscoreStream::try_new(ZscoreParams {
            period: Some(period),
            ma_type: Some(ma_type.to_string()),
            nbdev: Some(nbdev),
            devtype: Some(devtype),
        })
        .map(|stream| ZscoreStreamPy { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    // Feeds one value to the stream and returns whatever the stream's `update` yields.
    fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
2345
// Python binding `zscore_batch`: evaluates the parameter sweep on the CPU and returns a
// dict with "values" (rows x cols float64 matrix) and the per-row "periods", "ma_types",
// "nbdevs" and "devtypes". The output array is allocated up front and filled in place by
// `zscore_batch_inner_into` while the GIL is released.
#[cfg(feature = "python")]
#[pyfunction(name = "zscore_batch")]
#[pyo3(signature = (data, period_range, ma_type="sma", nbdev_range=(1.0, 1.0, 0.0), devtype_range=(0, 0, 0), kernel=None))]
pub fn zscore_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    ma_type: &str,
    nbdev_range: (f64, f64, f64),
    devtype_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let slice_in = data.as_slice()?;

    // A single moving-average type is swept: start == end, empty step.
    let sweep = ZscoreBatchRange {
        period: period_range,
        ma_type: (ma_type.to_string(), ma_type.to_string(), String::new()),
        nbdev: nbdev_range,
        devtype: devtype_range,
    };

    // Expand once here just to size the output matrix.
    let rows = expand_grid(&sweep)
        .map_err(|e| PyValueError::new_err(e.to_string()))?
        .len();
    let cols = slice_in.len();

    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("zscore_batch: rows*cols overflow"))?;
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let kern = validate_kernel(kernel, true)?;

    let combos = py
        .allow_threads(|| {
            let batch_kernel = if matches!(kern, Kernel::Auto) {
                detect_best_batch_kernel()
            } else {
                kern
            };
            // Map the batch kernel onto the per-row kernel the inner routine expects.
            let simd = match batch_kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                _ => unreachable!(),
            };
            zscore_batch_inner_into(slice_in, &sweep, simd, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;

    let periods: Vec<u64> = combos.iter().map(|p| p.period.unwrap() as u64).collect();
    dict.set_item("periods", periods.into_pyarray(py))?;

    let ma_types = PyList::new(
        py,
        combos.iter().map(|p| p.ma_type.as_ref().unwrap().clone()),
    )?;
    dict.set_item("ma_types", ma_types)?;

    let nbdevs: Vec<f64> = combos.iter().map(|p| p.nbdev.unwrap()).collect();
    dict.set_item("nbdevs", nbdevs.into_pyarray(py))?;

    let devtypes: Vec<u64> = combos.iter().map(|p| p.devtype.unwrap() as u64).collect();
    dict.set_item("devtypes", devtypes.into_pyarray(py))?;

    Ok(dict)
}
2434
2435pub fn zscore_into_slice(
2436 dst: &mut [f64],
2437 input: &ZscoreInput,
2438 kern: Kernel,
2439) -> Result<(), ZscoreError> {
2440 let data: &[f64] = input.as_ref();
2441 if data.is_empty() {
2442 return Err(ZscoreError::EmptyInputData);
2443 }
2444 if dst.len() != data.len() {
2445 return Err(ZscoreError::OutputLengthMismatch {
2446 expected: data.len(),
2447 got: dst.len(),
2448 });
2449 }
2450
2451 let first = data
2452 .iter()
2453 .position(|x| !x.is_nan())
2454 .ok_or(ZscoreError::AllValuesNaN)?;
2455 let len = data.len();
2456 let period = input.get_period();
2457 if period == 0 || period > len {
2458 return Err(ZscoreError::InvalidPeriod {
2459 period,
2460 data_len: len,
2461 });
2462 }
2463 if (len - first) < period {
2464 return Err(ZscoreError::NotEnoughValidData {
2465 needed: period,
2466 valid: len - first,
2467 });
2468 }
2469
2470 let ma_type = input.get_ma_type();
2471 let nbdev = input.get_nbdev();
2472 let devtype = input.get_devtype();
2473
2474 let chosen = match kern {
2475 Kernel::Auto => match detect_best_kernel() {
2476 Kernel::Avx512 | Kernel::Avx512Batch => Kernel::Avx2,
2477 other => other,
2478 },
2479 other => other,
2480 };
2481
2482 unsafe {
2483 match chosen {
2484 Kernel::Scalar | Kernel::ScalarBatch => {
2485 zscore_compute_into_scalar(data, period, first, &ma_type, nbdev, devtype, dst)
2486 }
2487 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2488 Kernel::Avx2 | Kernel::Avx2Batch => {
2489 zscore_compute_into_avx2(data, period, first, &ma_type, nbdev, devtype, dst)
2490 }
2491 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2492 Kernel::Avx512 | Kernel::Avx512Batch => {
2493 zscore_compute_into_avx512(data, period, first, &ma_type, nbdev, devtype, dst)
2494 }
2495 _ => {
2496 return Err(ZscoreError::InvalidPeriod {
2497 period: 0,
2498 data_len: 0,
2499 })
2500 }
2501 }
2502 }?;
2503
2504 Ok(())
2505}
2506
2507#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
2508pub fn zscore_into(input: &ZscoreInput, out: &mut [f64]) -> Result<(), ZscoreError> {
2509 zscore_into_slice(out, input, Kernel::Auto)
2510}
2511
/// Scalar z-score kernel: writes `(x - mean) / (deviation * nbdev)` into `out`.
///
/// Indices below `first + period - 1` are filled with NaN. Two single-pass
/// fast paths exist for `devtype == 0` combined with `ma_type` `"sma"` or
/// `"ema"`; every other combination falls back to the generic `ma` and
/// `deviation` routines. An output is NaN whenever the scaled deviation is
/// zero or NaN.
///
/// # Safety
/// Uses unchecked indexing. `out` must be at least as long as `data`, and
/// `period` must be at least 1. The caller visible in this file
/// (`zscore_into_slice`) validates both before calling.
///
/// # Errors
/// Returns `ZscoreError::MaError` when the generic moving average fails, and
/// propagates a failure of `deviation` via `?`.
#[inline]
unsafe fn zscore_compute_into_scalar(
    data: &[f64],
    period: usize,
    first: usize,
    ma_type: &str,
    nbdev: f64,
    devtype: usize,
    out: &mut [f64],
) -> Result<(), ZscoreError> {
    // Index of the first output that has a full window behind it.
    let warmup_end = first + period - 1;
    // Fill the warmup prefix with a NaN given by an explicit bit pattern.
    for v in &mut out[..warmup_end] {
        *v = f64::from_bits(0x7ff8_0000_0000_0000);
    }

    // Not a single full window: nothing beyond the NaN prefix to compute.
    if data.len() <= warmup_end {
        return Ok(());
    }

    if devtype == 0 {
        if ma_type == "sma" {
            // Fast path: rolling sum and rolling sum of squares.
            let inv = 1.0 / (period as f64);
            let mut sum = 0.0f64;
            let mut sum_sqr = 0.0f64;
            {
                // Seed both accumulators with the first full window.
                let mut j = first;
                while j <= warmup_end {
                    let v = *data.get_unchecked(j);
                    sum += v;
                    sum_sqr = v.mul_add(v, sum_sqr);
                    j += 1;
                }
            }
            let mut mean = sum * inv;
            // variance = E[x^2] - mean^2, clamped at zero against rounding.
            let mut variance = (-mean).mul_add(mean, sum_sqr * inv);
            if variance < 0.0 {
                variance = 0.0;
            }
            let mut sd = if variance == 0.0 {
                0.0
            } else {
                variance.sqrt() * nbdev
            };

            let xw = *data.get_unchecked(warmup_end);
            *out.get_unchecked_mut(warmup_end) = if sd == 0.0 || sd.is_nan() {
                f64::NAN
            } else {
                (xw - mean) / sd
            };

            let n = data.len();
            let mut i = warmup_end + 1;
            while i < n {
                // Slide the window: drop `old_val`, add `new_val`.
                let old_val = *data.get_unchecked(i - period);
                let new_val = *data.get_unchecked(i);
                let dd = new_val - old_val;
                sum += dd;
                // (new + old) * (new - old) == new^2 - old^2.
                sum_sqr = (new_val + old_val).mul_add(dd, sum_sqr);
                mean = sum * inv;

                variance = (-mean).mul_add(mean, sum_sqr * inv);
                if variance < 0.0 {
                    variance = 0.0;
                }
                sd = if variance == 0.0 {
                    0.0
                } else {
                    variance.sqrt() * nbdev
                };

                *out.get_unchecked_mut(i) = if sd == 0.0 || sd.is_nan() {
                    f64::NAN
                } else {
                    (new_val - mean) / sd
                };
                i += 1;
            }

            return Ok(());
        }

        if ma_type == "ema" {
            // Fast path: EMA as the centre, with the window's mean squared
            // distance from that EMA as the dispersion.
            let den = period as f64;
            let inv = 1.0 / den;
            let alpha = 2.0 / (den + 1.0);
            let one_minus_alpha = 1.0 - alpha;

            let mut sum = 0.0f64;
            let mut sum2 = 0.0f64;
            {
                // Seed the rolling sums with the first full window.
                let mut j = first;
                while j <= warmup_end {
                    let v = *data.get_unchecked(j);
                    sum += v;
                    sum2 = v.mul_add(v, sum2);
                    j += 1;
                }
            }
            // The EMA starts from the simple mean of the first window.
            let mut ema = sum * inv;

            let mut ex = sum * inv;
            let mut ex2 = sum2 * inv;
            // mse = E[x^2] - 2*ema*E[x] + ema^2, clamped at zero.
            let mut mse = (-2.0 * ema).mul_add(ex, ema.mul_add(ema, ex2));
            if mse < 0.0 {
                mse = 0.0;
            }
            let mut sd = mse.sqrt() * nbdev;

            let xw = *data.get_unchecked(warmup_end);
            *out.get_unchecked_mut(warmup_end) = if sd == 0.0 || sd.is_nan() {
                f64::NAN
            } else {
                (xw - ema) / sd
            };

            let n = data.len();
            let mut i = warmup_end + 1;
            while i < n {
                let new = *data.get_unchecked(i);
                let old = *data.get_unchecked(i - period);

                // Slide the rolling sums by one element.
                let dd = new - old;
                sum += dd;
                sum2 = (new + old).mul_add(dd, sum2);
                ex = sum * inv;
                ex2 = sum2 * inv;

                // ema = ema * (1 - alpha) + alpha * new.
                ema = ema.mul_add(one_minus_alpha, alpha * new);

                mse = (-2.0 * ema).mul_add(ex, ema.mul_add(ema, ex2));
                if mse < 0.0 {
                    mse = 0.0;
                }
                sd = mse.sqrt() * nbdev;

                *out.get_unchecked_mut(i) = if sd == 0.0 || sd.is_nan() {
                    f64::NAN
                } else {
                    (new - ema) / sd
                };
                i += 1;
            }

            return Ok(());
        }
    }

    // Generic path: compute full mean and deviation series, then combine.
    let means = ma(ma_type, MaData::Slice(data), period)
        .map_err(|e| ZscoreError::MaError(e.to_string()))?;
    let dev_input = DevInput {
        data: DeviationData::Slice(data),
        params: DevParams {
            period: Some(period),
            devtype: Some(devtype),
        },
    };
    let mut sigmas = deviation(&dev_input)?.values;
    // Scale the deviation by the requested multiplier.
    for v in &mut sigmas {
        *v *= nbdev;
    }

    for i in warmup_end..data.len() {
        let mean = means[i];
        let sigma = sigmas[i];
        let value = data[i];
        out[i] = if sigma == 0.0 || sigma.is_nan() {
            f64::NAN
        } else {
            (value - mean) / sigma
        };
    }
    Ok(())
}
2686
/// AVX2 entry point for the z-score kernel.
///
/// Currently forwards every call to `zscore_compute_into_scalar`; no
/// AVX2-specific code is present in this function.
///
/// # Safety
/// Same requirements as `zscore_compute_into_scalar`, which this calls directly.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
unsafe fn zscore_compute_into_avx2(
    data: &[f64],
    period: usize,
    first: usize,
    ma_type: &str,
    nbdev: f64,
    devtype: usize,
    out: &mut [f64],
) -> Result<(), ZscoreError> {
    zscore_compute_into_scalar(data, period, first, ma_type, nbdev, devtype, out)
}
2700
/// AVX-512 entry point for the z-score kernel.
///
/// Currently forwards every call to `zscore_compute_into_scalar`; no
/// AVX-512-specific code is present in this function.
///
/// # Safety
/// Same requirements as `zscore_compute_into_scalar`, which this calls directly.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
unsafe fn zscore_compute_into_avx512(
    data: &[f64],
    period: usize,
    first: usize,
    ma_type: &str,
    nbdev: f64,
    devtype: usize,
    out: &mut [f64],
) -> Result<(), ZscoreError> {
    zscore_compute_into_scalar(data, period, first, ma_type, nbdev, devtype, out)
}
2714
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// JavaScript-facing z-score over a slice; returns a freshly allocated vector
/// of the same length as `data`.
///
/// Any `ZscoreError` is converted to a `JsValue` string.
pub fn zscore_js(
    data: &[f64],
    period: usize,
    ma_type: &str,
    nbdev: f64,
    devtype: usize,
) -> Result<Vec<f64>, JsValue> {
    let input = ZscoreInput::from_slice(
        data,
        ZscoreParams {
            period: Some(period),
            ma_type: Some(ma_type.to_string()),
            nbdev: Some(nbdev),
            devtype: Some(devtype),
        },
    );

    let mut result = vec![0.0; data.len()];
    match zscore_into_slice(&mut result, &input, Kernel::Auto) {
        Ok(()) => Ok(result),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
2739
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// Allocates room for `len` `f64` values and hands the raw pointer to the
/// caller, who takes over ownership of the allocation.
///
/// The buffer is leaked on purpose; `zscore_free` in this file rebuilds a
/// `Vec` from the same pointer and `len` to release it.
pub fn zscore_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop keeps the Vec's destructor from running, so the allocation
    // outlives this function.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
2748
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// Releases a buffer by rebuilding a `Vec` of length and capacity `len` from
/// `ptr` and dropping it. A null `ptr` is ignored.
pub fn zscore_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: relies on the caller passing a pointer/length pair that came
    // from an allocation with capacity `len` (as `zscore_alloc` creates).
    unsafe {
        drop(Vec::from_raw_parts(ptr, len, len));
    }
}
2758
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// Pointer-based z-score for JavaScript callers: reads `len` values from
/// `in_ptr` and writes `len` results to `out_ptr`.
///
/// When both pointers are equal the result is first computed into a temporary
/// buffer and then copied over the input. Null pointers and computation
/// failures are reported as `JsValue` strings.
pub fn zscore_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
    ma_type: &str,
    nbdev: f64,
    devtype: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // SAFETY: relies on the caller supplying pointers valid for `len` f64s.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let input = ZscoreInput::from_slice(
            data,
            ZscoreParams {
                period: Some(period),
                ma_type: Some(ma_type.to_string()),
                nbdev: Some(nbdev),
                devtype: Some(devtype),
            },
        );

        let in_place = in_ptr == out_ptr;
        if in_place {
            // Input and output share memory: compute into scratch, then copy.
            let mut scratch = vec![0.0; len];
            zscore_into_slice(&mut scratch, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            std::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&scratch);
        } else {
            let target = std::slice::from_raw_parts_mut(out_ptr, len);
            zscore_into_slice(target, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(())
}
2799
/// Batch sweep configuration deserialized from the JavaScript `config` value
/// passed to `zscore_batch_js`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct ZscoreBatchConfig {
    // Copied into `ZscoreBatchRange::period`; the pointer-based batch API in
    // this file fills the same tuple as (start, end, step).
    pub period_range: (usize, usize, usize),
    // Single moving-average name, used as both ends of the MA-type range.
    pub ma_type: String,
    // Copied into `ZscoreBatchRange::nbdev`.
    pub nbdev_range: (f64, f64, f64),
    // Copied into `ZscoreBatchRange::devtype`.
    pub devtype_range: (usize, usize, usize),
}
2808
/// Batch result serialized back to JavaScript by `zscore_batch_js`; each field
/// is copied from the same-named field of the batch output.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct ZscoreBatchJsOutput {
    // Flat result buffer (presumably row-major, rows * cols long; confirm
    // against `zscore_batch_inner`).
    pub values: Vec<f64>,
    // Parameter combinations reported by the batch routine.
    pub combos: Vec<ZscoreParams>,
    pub rows: usize,
    pub cols: usize,
}
2817
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = zscore_batch)]
/// JavaScript batch entry point: deserializes a `ZscoreBatchConfig`, runs the
/// sweep with automatic kernel selection and returns a serialized
/// `ZscoreBatchJsOutput`.
///
/// A malformed config yields an `"Invalid config: ..."` error; computation and
/// serialization failures are returned as their display text.
pub fn zscore_batch_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let parsed: ZscoreBatchConfig = match serde_wasm_bindgen::from_value(config) {
        Ok(cfg) => cfg,
        Err(e) => return Err(JsValue::from_str(&format!("Invalid config: {}", e))),
    };

    // One MA type only: start and end are identical, step is empty.
    let ma_range = (
        parsed.ma_type.clone(),
        parsed.ma_type.clone(),
        "".to_string(),
    );
    let sweep = ZscoreBatchRange {
        period: parsed.period_range,
        ma_type: ma_range,
        nbdev: parsed.nbdev_range,
        devtype: parsed.devtype_range,
    };

    let batch = zscore_batch_inner(data, &sweep, Kernel::Auto, false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let payload = ZscoreBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };

    match serde_wasm_bindgen::to_value(&payload) {
        Ok(js) => Ok(js),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
2847
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// Pointer-based batch z-score for JavaScript callers.
///
/// Reads `len` values from `in_ptr`, expands the sweep ranges into a grid and
/// writes `combos * len` results to `out_ptr`. Returns the number of parameter
/// combinations. Null pointers, grid expansion errors, size overflow and
/// computation failures are reported as `JsValue` strings.
pub fn zscore_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
    ma_type: &str,
    nbdev_start: f64,
    nbdev_end: f64,
    nbdev_step: f64,
    devtype_start: usize,
    devtype_end: usize,
    devtype_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    let ma_name = ma_type.to_string();
    let sweep = ZscoreBatchRange {
        period: (period_start, period_end, period_step),
        ma_type: (ma_name.clone(), ma_name, "".to_string()),
        nbdev: (nbdev_start, nbdev_end, nbdev_step),
        devtype: (devtype_start, devtype_end, devtype_step),
    };

    let n_combos = match expand_grid(&sweep) {
        Ok(grid) => grid.len(),
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };

    // SAFETY: relies on the caller providing `len` readable f64s at `in_ptr`
    // and `n_combos * len` writable f64s at `out_ptr`.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let total = match n_combos.checked_mul(len) {
            Some(t) => t,
            None => return Err(JsValue::from_str("zscore_batch_into: rows*cols overflow")),
        };
        let out = std::slice::from_raw_parts_mut(out_ptr, total);

        let row_kernel = detect_best_kernel();
        zscore_batch_inner_into(data, &sweep, row_kernel, false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(n_combos)
}
2893
2894#[cfg(test)]
2895mod tests {
2896 use super::*;
2897 use crate::skip_if_unsupported;
2898 use crate::utilities::data_loader::read_candles_from_csv;
2899 fn check_zscore_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2900 skip_if_unsupported!(kernel, test_name);
2901 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2902 let candles = read_candles_from_csv(file_path)?;
2903 let default_params = ZscoreParams {
2904 period: None,
2905 ma_type: None,
2906 nbdev: None,
2907 devtype: None,
2908 };
2909 let input = ZscoreInput::from_candles(&candles, "close", default_params);
2910 let output = zscore_with_kernel(&input, kernel)?;
2911 assert_eq!(output.values.len(), candles.close.len());
2912 Ok(())
2913 }
2914 fn check_zscore_with_zero_period(
2915 test_name: &str,
2916 kernel: Kernel,
2917 ) -> Result<(), Box<dyn Error>> {
2918 skip_if_unsupported!(kernel, test_name);
2919 let input_data = [10.0, 20.0, 30.0];
2920 let params = ZscoreParams {
2921 period: Some(0),
2922 ma_type: None,
2923 nbdev: None,
2924 devtype: None,
2925 };
2926 let input = ZscoreInput::from_slice(&input_data, params);
2927 let res = zscore_with_kernel(&input, kernel);
2928 assert!(
2929 res.is_err(),
2930 "[{}] Zscore should fail with zero period",
2931 test_name
2932 );
2933 Ok(())
2934 }
2935 fn check_zscore_period_exceeds_length(
2936 test_name: &str,
2937 kernel: Kernel,
2938 ) -> Result<(), Box<dyn Error>> {
2939 skip_if_unsupported!(kernel, test_name);
2940 let data_small = [10.0, 20.0, 30.0];
2941 let params = ZscoreParams {
2942 period: Some(10),
2943 ma_type: None,
2944 nbdev: None,
2945 devtype: None,
2946 };
2947 let input = ZscoreInput::from_slice(&data_small, params);
2948 let res = zscore_with_kernel(&input, kernel);
2949 assert!(
2950 res.is_err(),
2951 "[{}] Zscore should fail with period exceeding length",
2952 test_name
2953 );
2954 Ok(())
2955 }
2956 fn check_zscore_very_small_dataset(
2957 test_name: &str,
2958 kernel: Kernel,
2959 ) -> Result<(), Box<dyn Error>> {
2960 skip_if_unsupported!(kernel, test_name);
2961 let single_point = [42.0];
2962 let params = ZscoreParams {
2963 period: Some(14),
2964 ma_type: None,
2965 nbdev: None,
2966 devtype: None,
2967 };
2968 let input = ZscoreInput::from_slice(&single_point, params);
2969 let res = zscore_with_kernel(&input, kernel);
2970 assert!(
2971 res.is_err(),
2972 "[{}] Zscore should fail with insufficient data",
2973 test_name
2974 );
2975 Ok(())
2976 }
2977 fn check_zscore_all_nan(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2978 skip_if_unsupported!(kernel, test_name);
2979 let input_data = [f64::NAN, f64::NAN, f64::NAN];
2980 let params = ZscoreParams::default();
2981 let input = ZscoreInput::from_slice(&input_data, params);
2982 let res = zscore_with_kernel(&input, kernel);
2983 assert!(
2984 res.is_err(),
2985 "[{}] Zscore should fail when all values are NaN",
2986 test_name
2987 );
2988 Ok(())
2989 }
2990 fn check_zscore_input_with_default_candles(
2991 test_name: &str,
2992 kernel: Kernel,
2993 ) -> Result<(), Box<dyn Error>> {
2994 skip_if_unsupported!(kernel, test_name);
2995 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2996 let candles = read_candles_from_csv(file_path)?;
2997 let input = ZscoreInput::with_default_candles(&candles);
2998 match input.data {
2999 ZscoreData::Candles { source, .. } => assert_eq!(source, "close"),
3000 _ => panic!("Expected ZscoreData::Candles"),
3001 }
3002 Ok(())
3003 }
3004 fn check_zscore_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
3005 skip_if_unsupported!(kernel, test_name);
3006 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3007 let candles = read_candles_from_csv(file_path)?;
3008
3009 let input = ZscoreInput::from_candles(&candles, "close", ZscoreParams::default());
3010 let result = zscore_with_kernel(&input, kernel)?;
3011
3012 let expected_last_five = [
3013 -0.3040683926967643,
3014 -0.41042159719064014,
3015 -0.5411993612192193,
3016 -0.1673226261513698,
3017 -1.431635486349618,
3018 ];
3019 let start = result.values.len().saturating_sub(5);
3020
3021 for (i, &val) in result.values[start..].iter().enumerate() {
3022 let diff = (val - expected_last_five[i]).abs();
3023 assert!(
3024 diff < 1e-8,
3025 "[{}] Zscore {:?} mismatch at idx {}: got {}, expected {}",
3026 test_name,
3027 kernel,
3028 i,
3029 val,
3030 expected_last_five[i]
3031 );
3032 }
3033 Ok(())
3034 }
3035 macro_rules! generate_all_zscore_tests {
3036 ($($test_fn:ident),*) => {
3037 paste::paste! {
3038 $(
3039 #[test]
3040 fn [<$test_fn _scalar_f64>]() {
3041 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
3042 }
3043 )*
3044 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3045 $(
3046 #[test]
3047 fn [<$test_fn _avx2_f64>]() {
3048 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
3049 }
3050 #[test]
3051 fn [<$test_fn _avx512_f64>]() {
3052 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
3053 }
3054 )*
3055 }
3056 }
3057 }
3058 #[cfg(debug_assertions)]
3059 fn check_zscore_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
3060 skip_if_unsupported!(kernel, test_name);
3061
3062 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3063 let candles = read_candles_from_csv(file_path)?;
3064
3065 let test_params = vec![
3066 ZscoreParams::default(),
3067 ZscoreParams {
3068 period: Some(2),
3069 ma_type: Some("sma".to_string()),
3070 nbdev: Some(1.0),
3071 devtype: Some(0),
3072 },
3073 ZscoreParams {
3074 period: Some(5),
3075 ma_type: Some("ema".to_string()),
3076 nbdev: Some(1.0),
3077 devtype: Some(0),
3078 },
3079 ZscoreParams {
3080 period: Some(10),
3081 ma_type: Some("wma".to_string()),
3082 nbdev: Some(2.0),
3083 devtype: Some(0),
3084 },
3085 ZscoreParams {
3086 period: Some(20),
3087 ma_type: Some("sma".to_string()),
3088 nbdev: Some(1.5),
3089 devtype: Some(1),
3090 },
3091 ZscoreParams {
3092 period: Some(30),
3093 ma_type: Some("ema".to_string()),
3094 nbdev: Some(2.5),
3095 devtype: Some(2),
3096 },
3097 ZscoreParams {
3098 period: Some(50),
3099 ma_type: Some("wma".to_string()),
3100 nbdev: Some(3.0),
3101 devtype: Some(0),
3102 },
3103 ZscoreParams {
3104 period: Some(100),
3105 ma_type: Some("sma".to_string()),
3106 nbdev: Some(1.0),
3107 devtype: Some(1),
3108 },
3109 ZscoreParams {
3110 period: Some(14),
3111 ma_type: Some("ema".to_string()),
3112 nbdev: Some(0.5),
3113 devtype: Some(0),
3114 },
3115 ZscoreParams {
3116 period: Some(14),
3117 ma_type: Some("sma".to_string()),
3118 nbdev: Some(0.1),
3119 devtype: Some(2),
3120 },
3121 ZscoreParams {
3122 period: Some(25),
3123 ma_type: Some("wma".to_string()),
3124 nbdev: Some(4.0),
3125 devtype: Some(1),
3126 },
3127 ZscoreParams {
3128 period: Some(7),
3129 ma_type: Some("ema".to_string()),
3130 nbdev: Some(1.618),
3131 devtype: Some(0),
3132 },
3133 ZscoreParams {
3134 period: Some(21),
3135 ma_type: Some("sma".to_string()),
3136 nbdev: Some(2.718),
3137 devtype: Some(1),
3138 },
3139 ZscoreParams {
3140 period: Some(42),
3141 ma_type: Some("wma".to_string()),
3142 nbdev: Some(3.14159),
3143 devtype: Some(2),
3144 },
3145 ];
3146
3147 for (param_idx, params) in test_params.iter().enumerate() {
3148 let input = ZscoreInput::from_candles(&candles, "close", params.clone());
3149 let output = zscore_with_kernel(&input, kernel)?;
3150
3151 for (i, &val) in output.values.iter().enumerate() {
3152 if val.is_nan() {
3153 continue;
3154 }
3155
3156 let bits = val.to_bits();
3157
3158 if bits == 0x11111111_11111111 {
3159 panic!(
3160 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
3161 with params: period={}, ma_type={}, nbdev={}, devtype={} (param set {})",
3162 test_name,
3163 val,
3164 bits,
3165 i,
3166 params.period.unwrap_or(14),
3167 params.ma_type.as_deref().unwrap_or("sma"),
3168 params.nbdev.unwrap_or(1.0),
3169 params.devtype.unwrap_or(0),
3170 param_idx
3171 );
3172 }
3173
3174 if bits == 0x22222222_22222222 {
3175 panic!(
3176 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
3177 with params: period={}, ma_type={}, nbdev={}, devtype={} (param set {})",
3178 test_name,
3179 val,
3180 bits,
3181 i,
3182 params.period.unwrap_or(14),
3183 params.ma_type.as_deref().unwrap_or("sma"),
3184 params.nbdev.unwrap_or(1.0),
3185 params.devtype.unwrap_or(0),
3186 param_idx
3187 );
3188 }
3189
3190 if bits == 0x33333333_33333333 {
3191 panic!(
3192 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
3193 with params: period={}, ma_type={}, nbdev={}, devtype={} (param set {})",
3194 test_name,
3195 val,
3196 bits,
3197 i,
3198 params.period.unwrap_or(14),
3199 params.ma_type.as_deref().unwrap_or("sma"),
3200 params.nbdev.unwrap_or(1.0),
3201 params.devtype.unwrap_or(0),
3202 param_idx
3203 );
3204 }
3205 }
3206 }
3207
3208 Ok(())
3209 }
3210
    // Build without debug assertions: the poison check is a no-op so the
    // generated tests still compile and pass.
    #[cfg(not(debug_assertions))]
    fn check_zscore_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
3215
    /// Property test: for random finite series and parameters, the output of
    /// `kernel` must agree with the scalar kernel, the warmup prefix must be
    /// NaN, and two spot checks (constant data, period-2 SMA) must hold.
    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_zscore_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        // Period in 2..=64; series at least `period + 10` long; MA type from
        // sma/ema/wma; nbdev in [0.5, 3.0); devtype in 0..=2.
        let strat = (2usize..=64).prop_flat_map(|period| {
            (
                prop::collection::vec(
                    (-1e5f64..1e5f64).prop_filter("finite", |x| x.is_finite()),
                    period + 10..400,
                ),
                Just(period),
                prop::sample::select(vec!["sma", "ema", "wma"]),
                0.5f64..3.0f64,
                0usize..=2,
            )
        });

        proptest::test_runner::TestRunner::default().run(
            &strat,
            |(data, period, ma_type, nbdev, devtype)| {
                let params = ZscoreParams {
                    period: Some(period),
                    ma_type: Some(ma_type.to_string()),
                    nbdev: Some(nbdev),
                    devtype: Some(devtype),
                };
                let input = ZscoreInput::from_slice(&data, params.clone());

                // Output of the kernel under test.
                let ZscoreOutput { values: out } = zscore_with_kernel(&input, kernel)?;

                // Scalar reference output.
                let ZscoreOutput { values: ref_out } = zscore_with_kernel(&input, Kernel::Scalar)?;

                prop_assert_eq!(out.len(), data.len(), "Output length mismatch");

                // The generated data contains no NaN, so warmup is period - 1.
                for i in 0..(period - 1) {
                    prop_assert!(
                        out[i].is_nan(),
                        "Expected NaN during warmup at index {}, got {}",
                        i,
                        out[i]
                    );
                }

                for i in (period - 1)..data.len() {
                    let y = out[i];
                    let r = ref_out[i];

                    if !y.is_finite() || !r.is_finite() {
                        // Non-finite results must match bit for bit.
                        prop_assert_eq!(
                            y.to_bits(),
                            r.to_bits(),
                            "NaN/infinite mismatch at index {}: {} vs {}",
                            i,
                            y,
                            r
                        );
                    } else {
                        let y_bits = y.to_bits();
                        let r_bits = r.to_bits();
                        let ulp_diff = y_bits.abs_diff(r_bits);

                        // Accept either a small absolute error or a few ULPs.
                        prop_assert!(
                            (y - r).abs() <= 1e-9 || ulp_diff <= 4,
                            "Kernel mismatch at index {}: {} vs {} (ULP={})",
                            i,
                            y,
                            r,
                            ulp_diff
                        );
                    }
                }

                // Constant series: with devtype 0 every output must be NaN.
                if data.windows(2).all(|w| (w[0] - w[1]).abs() < f64::EPSILON) {
                    for i in (period - 1)..data.len() {
                        prop_assert!(
                            out[i].is_nan() || devtype != 0,
                            "Expected NaN for constant data with stddev at index {}, got {}",
                            i,
                            out[i]
                        );
                    }
                }

                // Period-2 SMA with devtype 0: recompute the z-score by hand.
                if period == 2 && devtype == 0 && ma_type == "sma" {
                    for i in 1..data.len() {
                        if out[i].is_finite() {
                            let mean = (data[i - 1] + data[i]) / 2.0;
                            let diff1 = (data[i - 1] - mean).powi(2);
                            let diff2 = (data[i] - mean).powi(2);
                            let variance = (diff1 + diff2) / 2.0;
                            let stddev = variance.sqrt();

                            if stddev > f64::EPSILON {
                                let expected = (data[i] - mean) / (stddev * nbdev);
                                prop_assert!(
                                    (out[i] - expected).abs() <= 1e-6,
                                    "Zscore calculation mismatch at index {}: {} vs expected {}",
                                    i,
                                    out[i],
                                    expected
                                );
                            }
                        }
                    }
                }

                Ok(())
            },
        )?;

        Ok(())
    }
3333
    // Instantiate per-kernel `#[test]` wrappers for every single-series check.
    generate_all_zscore_tests!(
        check_zscore_partial_params,
        check_zscore_with_zero_period,
        check_zscore_period_exceeds_length,
        check_zscore_very_small_dataset,
        check_zscore_all_nan,
        check_zscore_input_with_default_candles,
        check_zscore_accuracy,
        check_zscore_no_poison
    );

    // The property test is only generated when the `proptest` feature is on.
    #[cfg(feature = "proptest")]
    generate_all_zscore_tests!(check_zscore_property);
3347 fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
3348 skip_if_unsupported!(kernel, test);
3349 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3350 let c = read_candles_from_csv(file)?;
3351 let output = ZscoreBatchBuilder::new()
3352 .kernel(kernel)
3353 .apply_candles(&c, "close")?;
3354 let def = ZscoreParams::default();
3355 let row = output.values_for(&def).expect("default row missing");
3356 assert_eq!(row.len(), c.close.len());
3357 Ok(())
3358 }
    // Generates `#[test]` wrappers for one batch check function: scalar and
    // auto-detect variants always, AVX2/AVX512 variants only with the
    // `nightly-avx` feature on x86_64.
    //
    // NOTE(review): the check's `Result` is discarded (`let _ =`), so a
    // returned `Err` does not fail the test; only panics do.
    macro_rules! gen_batch_tests {
        ($fn_name:ident) => {
            paste::paste! {
                #[test] fn [<$fn_name _scalar>]() {
                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx2>]() {
                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx512>]() {
                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
                }
                #[test] fn [<$fn_name _auto_detect>]() {
                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
                }
            }
        };
    }
3379 #[cfg(debug_assertions)]
3380 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
3381 skip_if_unsupported!(kernel, test);
3382
3383 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3384 let c = read_candles_from_csv(file)?;
3385
3386 let test_configs = vec![
3387 (2, 10, 2, 0.5, 2.0, 0.5, 0),
3388 (5, 25, 5, 1.0, 3.0, 1.0, 1),
3389 (10, 50, 10, 1.5, 3.5, 1.0, 2),
3390 (2, 5, 1, 0.1, 1.0, 0.3, 0),
3391 (14, 14, 0, 1.0, 4.0, 0.5, 0),
3392 (20, 40, 10, 2.0, 2.0, 0.0, 1),
3393 ];
3394
3395 let ma_types = vec!["sma", "ema", "wma"];
3396
3397 for (
3398 cfg_idx,
3399 &(period_start, period_end, period_step, nbdev_start, nbdev_end, nbdev_step, devtype),
3400 ) in test_configs.iter().enumerate()
3401 {
3402 for ma_type in &ma_types {
3403 let mut builder = ZscoreBatchBuilder::new().kernel(kernel);
3404
3405 if period_step > 0 {
3406 builder = builder.period_range(period_start, period_end, period_step);
3407 } else {
3408 builder = builder.period_static(period_start);
3409 }
3410
3411 if nbdev_step > 0.0 {
3412 builder = builder.nbdev_range(nbdev_start, nbdev_end, nbdev_step);
3413 } else {
3414 builder = builder.nbdev_static(nbdev_start);
3415 }
3416
3417 builder = builder.ma_type_static(ma_type.to_string());
3418
3419 builder = builder.devtype_static(devtype);
3420
3421 let output = builder.apply_candles(&c, "close")?;
3422
3423 for (idx, &val) in output.values.iter().enumerate() {
3424 if val.is_nan() {
3425 continue;
3426 }
3427
3428 let bits = val.to_bits();
3429 let row = idx / output.cols;
3430 let col = idx % output.cols;
3431 let combo = &output.combos[row];
3432
3433 if bits == 0x11111111_11111111 {
3434 panic!(
3435 "[{}] Config {} (MA: {}): Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
3436 at row {} col {} (flat index {}) with params: period={}, ma_type={}, nbdev={}, devtype={}",
3437 test, cfg_idx, ma_type, val, bits, row, col, idx,
3438 combo.period.unwrap_or(14),
3439 combo.ma_type.as_deref().unwrap_or("sma"),
3440 combo.nbdev.unwrap_or(1.0),
3441 combo.devtype.unwrap_or(0)
3442 );
3443 }
3444
3445 if bits == 0x22222222_22222222 {
3446 panic!(
3447 "[{}] Config {} (MA: {}): Found init_matrix_prefixes poison value {} (0x{:016X}) \
3448 at row {} col {} (flat index {}) with params: period={}, ma_type={}, nbdev={}, devtype={}",
3449 test, cfg_idx, ma_type, val, bits, row, col, idx,
3450 combo.period.unwrap_or(14),
3451 combo.ma_type.as_deref().unwrap_or("sma"),
3452 combo.nbdev.unwrap_or(1.0),
3453 combo.devtype.unwrap_or(0)
3454 );
3455 }
3456
3457 if bits == 0x33333333_33333333 {
3458 panic!(
3459 "[{}] Config {} (MA: {}): Found make_uninit_matrix poison value {} (0x{:016X}) \
3460 at row {} col {} (flat index {}) with params: period={}, ma_type={}, nbdev={}, devtype={}",
3461 test, cfg_idx, ma_type, val, bits, row, col, idx,
3462 combo.period.unwrap_or(14),
3463 combo.ma_type.as_deref().unwrap_or("sma"),
3464 combo.nbdev.unwrap_or(1.0),
3465 combo.devtype.unwrap_or(0)
3466 );
3467 }
3468 }
3469 }
3470 }
3471
3472 let devtype_test = ZscoreBatchBuilder::new()
3473 .kernel(kernel)
3474 .period_range(10, 30, 10)
3475 .nbdev_static(2.0)
3476 .ma_type_static("ema")
3477 .devtype_range(0, 2, 1)
3478 .apply_candles(&c, "close")?;
3479
3480 for (idx, &val) in devtype_test.values.iter().enumerate() {
3481 if val.is_nan() {
3482 continue;
3483 }
3484
3485 let bits = val.to_bits();
3486 let row = idx / devtype_test.cols;
3487 let col = idx % devtype_test.cols;
3488 let combo = &devtype_test.combos[row];
3489
3490 if bits == 0x11111111_11111111
3491 || bits == 0x22222222_22222222
3492 || bits == 0x33333333_33333333
3493 {
3494 let poison_type = if bits == 0x11111111_11111111 {
3495 "alloc_with_nan_prefix"
3496 } else if bits == 0x22222222_22222222 {
3497 "init_matrix_prefixes"
3498 } else {
3499 "make_uninit_matrix"
3500 };
3501
3502 panic!(
3503 "[{}] Devtype test: Found {} poison value {} (0x{:016X}) \
3504 at row {} col {} (flat index {}) with params: period={}, ma_type={}, nbdev={}, devtype={}",
3505 test,
3506 poison_type,
3507 val,
3508 bits,
3509 row,
3510 col,
3511 idx,
3512 combo.period.unwrap_or(14),
3513 combo.ma_type.as_deref().unwrap_or("sma"),
3514 combo.nbdev.unwrap_or(1.0),
3515 combo.devtype.unwrap_or(0)
3516 );
3517 }
3518 }
3519
3520 Ok(())
3521 }
3522
    // Build without debug assertions: the batch poison check is a no-op so
    // the generated tests still compile and pass.
    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
3527
    // Instantiate per-kernel `#[test]` wrappers for the batch checks.
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
3530
3531 #[test]
3532 fn test_zscore_into_matches_api() -> Result<(), Box<dyn Error>> {
3533 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3534 let candles = read_candles_from_csv(file_path)?;
3535
3536 let input = ZscoreInput::from_candles(&candles, "close", ZscoreParams::default());
3537
3538 let baseline = zscore(&input)?.values;
3539
3540 let mut out = vec![0.0; baseline.len()];
3541 #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
3542 {
3543 zscore_into(&input, &mut out)?;
3544 }
3545 #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
3546 {
3547 zscore_into_slice(&mut out, &input, Kernel::Auto)?;
3548 }
3549
3550 assert_eq!(baseline.len(), out.len());
3551
3552 let eq_or_both_nan = |a: f64, b: f64| -> bool {
3553 (a.is_nan() && b.is_nan()) || (a == b) || ((a - b).abs() <= 1e-9)
3554 };
3555 for i in 0..out.len() {
3556 assert!(
3557 eq_or_both_nan(baseline[i], out[i]),
3558 "Mismatch at index {}: baseline={}, into={}",
3559 i,
3560 baseline[i],
3561 out[i]
3562 );
3563 }
3564 Ok(())
3565 }
3566}