1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::moving_averages::DeviceArrayF32;
3#[cfg(feature = "cuda")]
4use crate::cuda::{CudaZscore, CudaZscoreError};
5use crate::indicators::deviation::{
6 deviation, DevError, DevInput, DevParams, DeviationData, DeviationOutput,
7};
8use crate::indicators::moving_averages::ma::{ma, MaData};
9use crate::utilities::data_loader::{source_type, Candles};
10#[cfg(all(feature = "python", feature = "cuda"))]
11use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
12use crate::utilities::enums::Kernel;
13use crate::utilities::helpers::{
14 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
15 make_uninit_matrix,
16};
17#[cfg(feature = "python")]
18use crate::utilities::kernel_validation::validate_kernel;
19use aligned_vec::{AVec, CACHELINE_ALIGN};
20#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
21use core::arch::x86_64::*;
22#[cfg(all(feature = "python", feature = "cuda"))]
23use cust::context::Context;
24#[cfg(all(feature = "python", feature = "cuda"))]
25use cust::memory::DeviceBuffer;
26#[cfg(feature = "python")]
27use numpy::{IntoPyArray, PyArray1};
28#[cfg(feature = "python")]
29use pyo3::exceptions::PyValueError;
30#[cfg(feature = "python")]
31use pyo3::prelude::*;
32#[cfg(feature = "python")]
33use pyo3::types::{PyDict, PyList};
34#[cfg(not(target_arch = "wasm32"))]
35use rayon::prelude::*;
36#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
37use serde::{Deserialize, Serialize};
38use std::collections::HashMap;
39use std::error::Error;
40#[cfg(all(feature = "python", feature = "cuda"))]
41use std::sync::Arc;
42use thiserror::Error;
43#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
44use wasm_bindgen::prelude::*;
45
46impl<'a> AsRef<[f64]> for ZscoreInput<'a> {
47 #[inline(always)]
48 fn as_ref(&self) -> &[f64] {
49 match &self.data {
50 ZscoreData::Slice(slice) => slice,
51 ZscoreData::Candles { candles, source } => source_type(candles, source),
52 }
53 }
54}
55
/// Where the z-score calculation reads its input series from.
#[derive(Debug, Clone)]
pub enum ZscoreData<'a> {
    /// A candle set plus the name of the column to read; the column is
    /// resolved through `source_type` when the input is viewed as a slice.
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A caller-provided series used directly.
    Slice(&'a [f64]),
}
64
/// Result of a single z-score run.
#[derive(Debug, Clone)]
pub struct ZscoreOutput {
    /// One z-score per input element. The kernels in this file allocate it
    /// with a NaN prefix covering the warm-up region and also emit NaN when
    /// the scaled deviation is zero or NaN.
    pub values: Vec<f64>,
}
69
/// Tunable parameters of the z-score indicator.
///
/// Every field is optional; `None` falls back to the same values that
/// `Default` produces (14, `"sma"`, 1.0, 0).
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct ZscoreParams {
    /// Window length used for both the mean and the deviation.
    pub period: Option<usize>,
    /// Name of the moving average passed to `ma` (e.g. `"sma"`, `"ema"`).
    pub ma_type: Option<String>,
    /// Multiplier applied to the deviation before dividing.
    pub nbdev: Option<f64>,
    /// Deviation selector forwarded to `deviation`; value 0 enables the
    /// in-file SMA/EMA fast paths.
    pub devtype: Option<usize>,
}
81
82impl Default for ZscoreParams {
83 fn default() -> Self {
84 Self {
85 period: Some(14),
86 ma_type: Some("sma".to_string()),
87 nbdev: Some(1.0),
88 devtype: Some(0),
89 }
90 }
91}
92
/// A data source paired with the parameters to run the z-score with.
#[derive(Debug, Clone)]
pub struct ZscoreInput<'a> {
    /// Series to analyse (candles column or raw slice).
    pub data: ZscoreData<'a>,
    /// Indicator parameters; unset fields fall back to the defaults exposed
    /// by the `get_*` accessors.
    pub params: ZscoreParams,
}
98
99impl<'a> ZscoreInput<'a> {
100 #[inline]
101 pub fn from_candles(candles: &'a Candles, source: &'a str, params: ZscoreParams) -> Self {
102 Self {
103 data: ZscoreData::Candles { candles, source },
104 params,
105 }
106 }
107 #[inline]
108 pub fn from_slice(slice: &'a [f64], params: ZscoreParams) -> Self {
109 Self {
110 data: ZscoreData::Slice(slice),
111 params,
112 }
113 }
114 #[inline]
115 pub fn with_default_candles(candles: &'a Candles) -> Self {
116 Self::from_candles(candles, "close", ZscoreParams::default())
117 }
118 #[inline]
119 pub fn get_period(&self) -> usize {
120 self.params.period.unwrap_or(14)
121 }
122 #[inline]
123 pub fn get_ma_type(&self) -> String {
124 self.params
125 .ma_type
126 .clone()
127 .unwrap_or_else(|| "sma".to_string())
128 }
129 #[inline]
130 pub fn get_nbdev(&self) -> f64 {
131 self.params.nbdev.unwrap_or(1.0)
132 }
133 #[inline]
134 pub fn get_devtype(&self) -> usize {
135 self.params.devtype.unwrap_or(0)
136 }
137}
138
/// Fluent builder for a single z-score computation or a streaming instance.
///
/// Unset parameters stay `None` and are resolved to their defaults by the
/// code that consumes the resulting `ZscoreParams`.
#[derive(Clone, Debug)]
pub struct ZscoreBuilder {
    // Window length override.
    period: Option<usize>,
    // Moving-average name override.
    ma_type: Option<String>,
    // Deviation multiplier override.
    nbdev: Option<f64>,
    // Deviation selector override.
    devtype: Option<usize>,
    // Kernel handed to `zscore_with_kernel` by `apply` / `apply_slice`.
    kernel: Kernel,
}
147
148impl Default for ZscoreBuilder {
149 fn default() -> Self {
150 Self {
151 period: None,
152 ma_type: None,
153 nbdev: None,
154 devtype: None,
155 kernel: Kernel::Auto,
156 }
157 }
158}
159
160impl ZscoreBuilder {
161 #[inline(always)]
162 pub fn new() -> Self {
163 Self::default()
164 }
165 #[inline(always)]
166 pub fn period(mut self, n: usize) -> Self {
167 self.period = Some(n);
168 self
169 }
170 #[inline(always)]
171 pub fn ma_type<T: Into<String>>(mut self, t: T) -> Self {
172 self.ma_type = Some(t.into());
173 self
174 }
175 #[inline(always)]
176 pub fn nbdev(mut self, x: f64) -> Self {
177 self.nbdev = Some(x);
178 self
179 }
180 #[inline(always)]
181 pub fn devtype(mut self, dt: usize) -> Self {
182 self.devtype = Some(dt);
183 self
184 }
185 #[inline(always)]
186 pub fn kernel(mut self, k: Kernel) -> Self {
187 self.kernel = k;
188 self
189 }
190 #[inline(always)]
191 pub fn apply(self, candles: &Candles) -> Result<ZscoreOutput, ZscoreError> {
192 let p = ZscoreParams {
193 period: self.period,
194 ma_type: self.ma_type,
195 nbdev: self.nbdev,
196 devtype: self.devtype,
197 };
198 let i = ZscoreInput::from_candles(candles, "close", p);
199 zscore_with_kernel(&i, self.kernel)
200 }
201 #[inline(always)]
202 pub fn apply_slice(self, data: &[f64]) -> Result<ZscoreOutput, ZscoreError> {
203 let p = ZscoreParams {
204 period: self.period,
205 ma_type: self.ma_type,
206 nbdev: self.nbdev,
207 devtype: self.devtype,
208 };
209 let i = ZscoreInput::from_slice(data, p);
210 zscore_with_kernel(&i, self.kernel)
211 }
212 #[inline(always)]
213 pub fn into_stream(self) -> Result<ZscoreStream, ZscoreError> {
214 let p = ZscoreParams {
215 period: self.period,
216 ma_type: self.ma_type,
217 nbdev: self.nbdev,
218 devtype: self.devtype,
219 };
220 ZscoreStream::try_new(p)
221 }
222}
223
/// Errors reported by the z-score entry points in this file.
#[derive(Debug, Error)]
pub enum ZscoreError {
    /// The input series has no elements.
    #[error("zscore: Input data slice is empty.")]
    EmptyInputData,
    /// No non-NaN element was found in the input.
    #[error("zscore: All values are NaN.")]
    AllValuesNaN,
    /// The period is zero or larger than the input length.
    #[error("zscore: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },
    /// Fewer than `needed` elements remain from the first non-NaN value on.
    #[error("zscore: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// An output buffer did not have the expected length (not produced by
    /// the code visible in this part of the file).
    #[error("zscore: output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// A batch sweep range could not be expanded, or a size computation
    /// derived from it overflowed.
    #[error("zscore: invalid range: start={start}, end={end}, step={step}")]
    InvalidRange { start: f64, end: f64, step: f64 },
    /// A non-batch kernel was passed to the batch entry point.
    #[error("zscore: invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(Kernel),
    /// Error propagated from the deviation indicator.
    #[error("zscore: DevError {0}")]
    DevError(#[from] DevError),
    /// Stringified error from the moving-average dispatcher.
    #[error("zscore: MaError {0}")]
    MaError(String),
}
245
/// Computes the z-score for `input`, letting the library choose the kernel.
///
/// Equivalent to `zscore_with_kernel(input, Kernel::Auto)`.
///
/// # Errors
/// Returns whatever `zscore_with_kernel` returns for the same input.
#[inline]
pub fn zscore(input: &ZscoreInput) -> Result<ZscoreOutput, ZscoreError> {
    zscore_with_kernel(input, Kernel::Auto)
}
250
251pub fn zscore_with_kernel(
252 input: &ZscoreInput,
253 kernel: Kernel,
254) -> Result<ZscoreOutput, ZscoreError> {
255 let data: &[f64] = input.as_ref();
256 if data.is_empty() {
257 return Err(ZscoreError::EmptyInputData);
258 }
259
260 let first = data
261 .iter()
262 .position(|x| !x.is_nan())
263 .ok_or(ZscoreError::AllValuesNaN)?;
264 let len = data.len();
265 let period = input.get_period();
266 if period == 0 || period > len {
267 return Err(ZscoreError::InvalidPeriod {
268 period,
269 data_len: len,
270 });
271 }
272 if (len - first) < period {
273 return Err(ZscoreError::NotEnoughValidData {
274 needed: period,
275 valid: len - first,
276 });
277 }
278
279 let ma_type = input.get_ma_type();
280 let nbdev = input.get_nbdev();
281 let devtype = input.get_devtype();
282
283 let chosen = match kernel {
284 Kernel::Auto => match detect_best_kernel() {
285 Kernel::Avx512 | Kernel::Avx512Batch => Kernel::Avx2,
286 other => other,
287 },
288 other => other,
289 };
290
291 unsafe {
292 match chosen {
293 Kernel::Scalar | Kernel::ScalarBatch => {
294 zscore_scalar(data, period, first, &ma_type, nbdev, devtype)
295 }
296 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
297 Kernel::Avx2 | Kernel::Avx2Batch => {
298 zscore_avx2(data, period, first, &ma_type, nbdev, devtype)
299 }
300 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
301 Kernel::Avx512 | Kernel::Avx512Batch => {
302 zscore_avx512(data, period, first, &ma_type, nbdev, devtype)
303 }
304 _ => unreachable!(),
305 }
306 }
307}
308
309#[inline]
310pub unsafe fn zscore_scalar(
311 data: &[f64],
312 period: usize,
313 first: usize,
314 ma_type: &str,
315 nbdev: f64,
316 devtype: usize,
317) -> Result<ZscoreOutput, ZscoreError> {
318 if devtype == 0 {
319 if ma_type == "sma" {
320 return zscore_scalar_classic_sma(data, period, first, nbdev);
321 } else if ma_type == "ema" {
322 return zscore_scalar_classic_ema(data, period, first, nbdev);
323 }
324 }
325
326 let means = ma(ma_type, MaData::Slice(data), period)
327 .map_err(|e| ZscoreError::MaError(e.to_string()))?;
328 let dev_input = DevInput {
329 data: DeviationData::Slice(data),
330 params: DevParams {
331 period: Some(period),
332 devtype: Some(devtype),
333 },
334 };
335 let mut sigmas = deviation(&dev_input)?.values;
336 for v in &mut sigmas {
337 *v *= nbdev;
338 }
339
340 let warmup_end = first + period - 1;
341 let n = data.len();
342 let mut out = alloc_with_nan_prefix(n, warmup_end);
343 let mut i = warmup_end;
344 while i < n {
345 let mean = *means.get_unchecked(i);
346 let sigma = *sigmas.get_unchecked(i);
347 let value = *data.get_unchecked(i);
348 *out.get_unchecked_mut(i) = if sigma == 0.0 || sigma.is_nan() {
349 f64::NAN
350 } else {
351 (value - mean) / sigma
352 };
353 i += 1;
354 }
355 Ok(ZscoreOutput { values: out })
356}
357
358#[inline]
359pub unsafe fn zscore_scalar_classic_sma(
360 data: &[f64],
361 period: usize,
362 first: usize,
363 nbdev: f64,
364) -> Result<ZscoreOutput, ZscoreError> {
365 let warmup_end = first + period - 1;
366 let mut out = alloc_with_nan_prefix(data.len(), warmup_end);
367
368 let inv = 1.0 / (period as f64);
369 let mut sum = 0.0f64;
370 let mut sum_sqr = 0.0f64;
371 {
372 let mut j = first;
373 while j <= warmup_end {
374 let v = *data.get_unchecked(j);
375 sum += v;
376
377 sum_sqr = v.mul_add(v, sum_sqr);
378 j += 1;
379 }
380 }
381 let mut mean = sum * inv;
382
383 let mut variance = (-mean).mul_add(mean, sum_sqr * inv);
384 if variance < 0.0 {
385 variance = 0.0;
386 }
387 let mut stddev = if variance == 0.0 {
388 0.0
389 } else {
390 variance.sqrt() * nbdev
391 };
392
393 let xw = *data.get_unchecked(warmup_end);
394 *out.get_unchecked_mut(warmup_end) = if stddev == 0.0 || stddev.is_nan() {
395 f64::NAN
396 } else {
397 (xw - mean) / stddev
398 };
399
400 let n = data.len();
401 let mut i = warmup_end + 1;
402 while i < n {
403 let old_val = *data.get_unchecked(i - period);
404 let new_val = *data.get_unchecked(i);
405 let dd = new_val - old_val;
406 sum += dd;
407
408 sum_sqr = (new_val + old_val).mul_add(dd, sum_sqr);
409 mean = sum * inv;
410
411 variance = (-mean).mul_add(mean, sum_sqr * inv);
412 if variance < 0.0 {
413 variance = 0.0;
414 }
415 stddev = if variance == 0.0 {
416 0.0
417 } else {
418 variance.sqrt() * nbdev
419 };
420
421 *out.get_unchecked_mut(i) = if stddev == 0.0 || stddev.is_nan() {
422 f64::NAN
423 } else {
424 (new_val - mean) / stddev
425 };
426 i += 1;
427 }
428
429 Ok(ZscoreOutput { values: out })
430}
431
432#[inline]
433pub unsafe fn zscore_scalar_classic_ema(
434 data: &[f64],
435 period: usize,
436 first: usize,
437 nbdev: f64,
438) -> Result<ZscoreOutput, ZscoreError> {
439 let n = data.len();
440 let warmup_end = first + period - 1;
441 let mut out = alloc_with_nan_prefix(n, warmup_end);
442
443 if n <= warmup_end {
444 return Ok(ZscoreOutput { values: out });
445 }
446
447 let den = period as f64;
448 let inv = 1.0 / den;
449 let alpha = 2.0 / (den + 1.0);
450 let one_minus_alpha = 1.0 - alpha;
451
452 let mut sum = 0.0f64;
453 let mut sum2 = 0.0f64;
454 {
455 let mut j = first;
456 while j <= warmup_end {
457 let v = *data.get_unchecked(j);
458 sum += v;
459 sum2 = v.mul_add(v, sum2);
460 j += 1;
461 }
462 }
463 let mut ema = sum * inv;
464
465 let mut ex = sum * inv;
466 let mut ex2 = sum2 * inv;
467 let mut mse = (-2.0 * ema).mul_add(ex, ema.mul_add(ema, ex2));
468 if mse < 0.0 {
469 mse = 0.0;
470 }
471 let mut sd = mse.sqrt() * nbdev;
472
473 let xw = *data.get_unchecked(warmup_end);
474 *out.get_unchecked_mut(warmup_end) = if sd == 0.0 || sd.is_nan() {
475 f64::NAN
476 } else {
477 (xw - ema) / sd
478 };
479
480 let mut i = warmup_end + 1;
481 while i < n {
482 let new = *data.get_unchecked(i);
483 let old = *data.get_unchecked(i - period);
484
485 let dd = new - old;
486 sum += dd;
487 sum2 = (new + old).mul_add(dd, sum2);
488 ex = sum * inv;
489 ex2 = sum2 * inv;
490
491 ema = ema.mul_add(one_minus_alpha, alpha * new);
492
493 mse = (-2.0 * ema).mul_add(ex, ema.mul_add(ema, ex2));
494 if mse < 0.0 {
495 mse = 0.0;
496 }
497 sd = mse.sqrt() * nbdev;
498
499 *out.get_unchecked_mut(i) = if sd == 0.0 || sd.is_nan() {
500 f64::NAN
501 } else {
502 (new - ema) / sd
503 };
504 i += 1;
505 }
506
507 Ok(ZscoreOutput { values: out })
508}
509
/// AVX2 entry point for the z-score kernel.
///
/// Currently forwards to `zscore_scalar` unchanged; no AVX2-specific code is
/// present in this function.
///
/// # Safety
/// Same contract as `zscore_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn zscore_avx2(
    data: &[f64],
    period: usize,
    first: usize,
    ma_type: &str,
    nbdev: f64,
    devtype: usize,
) -> Result<ZscoreOutput, ZscoreError> {
    zscore_scalar(data, period, first, ma_type, nbdev, devtype)
}
522
/// AVX-512 entry point: routes periods up to 32 to the "short" variant and
/// everything larger to the "long" variant.
///
/// # Safety
/// Same contract as the variant it dispatches to.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn zscore_avx512(
    data: &[f64],
    period: usize,
    first: usize,
    ma_type: &str,
    nbdev: f64,
    devtype: usize,
) -> Result<ZscoreOutput, ZscoreError> {
    match period {
        0..=32 => zscore_avx512_short(data, period, first, ma_type, nbdev, devtype),
        _ => zscore_avx512_long(data, period, first, ma_type, nbdev, devtype),
    }
}
539
/// AVX-512 variant selected by `zscore_avx512` for `period <= 32`.
///
/// Currently forwards to `zscore_scalar` unchanged.
///
/// # Safety
/// Same contract as `zscore_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn zscore_avx512_short(
    data: &[f64],
    period: usize,
    first: usize,
    ma_type: &str,
    nbdev: f64,
    devtype: usize,
) -> Result<ZscoreOutput, ZscoreError> {
    zscore_scalar(data, period, first, ma_type, nbdev, devtype)
}
552
/// AVX-512 variant selected by `zscore_avx512` for `period > 32`.
///
/// Currently forwards to `zscore_scalar` unchanged.
///
/// # Safety
/// Same contract as `zscore_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn zscore_avx512_long(
    data: &[f64],
    period: usize,
    first: usize,
    ma_type: &str,
    nbdev: f64,
    devtype: usize,
) -> Result<ZscoreOutput, ZscoreError> {
    zscore_scalar(data, period, first, ma_type, nbdev, devtype)
}
565
/// Incremental z-score calculator fed one value at a time via `update`.
#[derive(Debug, Clone)]
pub struct ZscoreStream {
    // Window length (always >= 1; enforced by `try_new`).
    period: usize,
    // Moving-average name as supplied by the caller; passed to `ma` on the
    // slow path.
    ma_type: String,
    // Deviation multiplier.
    nbdev: f64,
    // Deviation selector; only 0 is eligible for the rolling fast path.
    devtype: usize,

    // Ring buffer of the last `period` inputs, initialised to NaN.
    buffer: Vec<f64>,
    // Index of the slot the next input will overwrite.
    head: usize,
    // Becomes true once `head` has wrapped around to 0 for the first time.
    filled: bool,

    // Rolling sum of the window, counting NaN entries as 0.
    sum: f64,
    // Rolling sum of squares of the window, counting NaN entries as 0.
    sum2: f64,

    // Rolling linearly weighted sum, used for the WMA mean.
    wsum: f64,

    // Current EMA value; meaningful only when `ema_inited` is true.
    ema: f64,
    // Whether `ema` has been seeded from the window mean.
    ema_inited: bool,

    // Number of NaN entries currently in `buffer` (starts at `period`).
    nan_count: usize,

    // 1 / period.
    inv_period: f64,
    // period * (period + 1) / 2. Not read by the code visible here.
    wma_denom: f64,
    // 1 / wma_denom.
    inv_wma_denom: f64,
    // 1 / nbdev, or +inf when nbdev is 0. Not read by the code visible here.
    inv_nbdev: f64,

    // Moving-average family derived from the lowercased `ma_type`.
    kind: MaKind,
}
594
/// Moving-average families the stream distinguishes; `Sma`, `Ema` and `Wma`
/// have a rolling fast path, `Other` always goes through the slow path.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MaKind {
    // Simple moving average ("sma").
    Sma,
    // Exponential moving average ("ema").
    Ema,
    // Linearly weighted moving average ("wma").
    Wma,
    // Any other moving-average name.
    Other,
}
602
603impl ZscoreStream {
604 pub fn try_new(params: ZscoreParams) -> Result<Self, ZscoreError> {
605 let period = params.period.unwrap_or(14);
606 if period == 0 {
607 return Err(ZscoreError::InvalidPeriod {
608 period,
609 data_len: 0,
610 });
611 }
612 let ma_type = params.ma_type.unwrap_or_else(|| "sma".to_string());
613 let nbdev = params.nbdev.unwrap_or(1.0);
614 let devtype = params.devtype.unwrap_or(0);
615
616 let kind = match ma_type.to_ascii_lowercase().as_str() {
617 "sma" => MaKind::Sma,
618 "ema" => MaKind::Ema,
619 "wma" => MaKind::Wma,
620 _ => MaKind::Other,
621 };
622
623 let n = period as f64;
624 let wden = n * (n + 1.0) * 0.5;
625 let inv_nbdev = if nbdev != 0.0 {
626 1.0 / nbdev
627 } else {
628 f64::INFINITY
629 };
630
631 Ok(Self {
632 period,
633 ma_type,
634 nbdev,
635 devtype,
636 buffer: vec![f64::NAN; period],
637 head: 0,
638 filled: false,
639
640 sum: 0.0,
641 sum2: 0.0,
642 wsum: 0.0,
643
644 ema: 0.0,
645 ema_inited: false,
646
647 nan_count: period,
648
649 inv_period: 1.0 / n,
650 wma_denom: wden,
651 inv_wma_denom: 1.0 / wden,
652 inv_nbdev,
653 kind,
654 })
655 }
656
657 #[inline(always)]
658 pub fn update(&mut self, value: f64) -> Option<f64> {
659 let old = self.buffer[self.head];
660 self.buffer[self.head] = value;
661 self.head = (self.head + 1) % self.period;
662 let just_filled = !self.filled && self.head == 0;
663 if just_filled {
664 self.filled = true;
665 }
666
667 if old.is_nan() {
668 self.nan_count = self.nan_count.saturating_sub(1);
669 }
670 if value.is_nan() {
671 self.nan_count += 1;
672 }
673
674 let sum_prev = self.sum;
675
676 let old_c = if old.is_nan() { 0.0 } else { old };
677 let new_c = if value.is_nan() { 0.0 } else { value };
678
679 self.sum = self.sum + new_c - old_c;
680 self.sum2 = self.sum2 + new_c * new_c - old_c * old_c;
681
682 self.wsum = self.wsum - sum_prev + (self.period as f64) * new_c;
683
684 if !self.filled {
685 return None;
686 }
687
688 if self.devtype == 0 && self.nbdev != 0.0 && self.nan_count == 0 {
689 let mean = match self.kind {
690 MaKind::Sma => self.sum * self.inv_period,
691 MaKind::Ema => {
692 if !self.ema_inited {
693 self.ema = self.sum * self.inv_period;
694 self.ema_inited = true;
695 self.ema
696 } else {
697 let alpha = 2.0 / ((self.period as f64) + 1.0);
698 self.ema = self.ema.mul_add(1.0 - alpha, alpha * new_c);
699 self.ema
700 }
701 }
702 MaKind::Wma => self.wsum * self.inv_wma_denom,
703 MaKind::Other => {
704 return Some(self.compute_zscore_slow());
705 }
706 };
707
708 let ex = self.sum * self.inv_period;
709 let ex2 = self.sum2 * self.inv_period;
710
711 let var = if self.kind == MaKind::Sma {
712 let v = ex2 - mean * mean;
713 if v < 0.0 {
714 0.0
715 } else {
716 v
717 }
718 } else {
719 let v = ex2 - 2.0 * mean * ex + mean * mean;
720 if v < 0.0 {
721 0.0
722 } else {
723 v
724 }
725 };
726
727 let sd = var.sqrt();
728 if sd == 0.0 || !sd.is_finite() {
729 return Some(f64::NAN);
730 }
731
732 let last_idx = if self.head == 0 {
733 self.period - 1
734 } else {
735 self.head - 1
736 };
737 let last = self.buffer[last_idx];
738 let z = (last - mean) / (sd * self.nbdev);
739 return Some(z);
740 }
741
742 Some(self.compute_zscore_slow())
743 }
744
745 #[inline(always)]
746 fn compute_zscore_slow(&self) -> f64 {
747 let mut ordered = vec![0.0; self.period];
748 let mut idx = self.head;
749 for i in 0..self.period {
750 ordered[i] = self.buffer[idx];
751 idx = (idx + 1) % self.period;
752 }
753
754 let means = match ma(&self.ma_type, MaData::Slice(&ordered), self.period) {
755 Ok(m) => m,
756 Err(_) => return f64::NAN,
757 };
758
759 let dev_input = DevInput {
760 data: DeviationData::Slice(&ordered),
761 params: DevParams {
762 period: Some(self.period),
763 devtype: Some(self.devtype),
764 },
765 };
766 let mut sigmas = match deviation(&dev_input) {
767 Ok(d) => d.values,
768 Err(_) => return f64::NAN,
769 };
770 for s in &mut sigmas {
771 *s *= self.nbdev;
772 }
773
774 let mean = means[self.period - 1];
775 let sigma = sigmas[self.period - 1];
776 let value = ordered[self.period - 1];
777 if sigma == 0.0 || sigma.is_nan() {
778 f64::NAN
779 } else {
780 (value - mean) / sigma
781 }
782 }
783}
784
/// Parameter sweep for a batch run; each tuple is `(start, end, step)`.
///
/// A step of zero (or `start == end`) yields the single value `start`.
#[derive(Clone, Debug)]
pub struct ZscoreBatchRange {
    /// Window lengths to sweep.
    pub period: (usize, usize, usize),
    /// Moving-average name; only the first element is used by the grid
    /// expansion in this file.
    pub ma_type: (String, String, String),
    /// Deviation multipliers to sweep.
    pub nbdev: (f64, f64, f64),
    /// Deviation selectors to sweep.
    pub devtype: (usize, usize, usize),
}
792
793impl Default for ZscoreBatchRange {
794 fn default() -> Self {
795 Self {
796 period: (14, 263, 1),
797 ma_type: ("sma".to_string(), "sma".to_string(), "".to_string()),
798 nbdev: (1.0, 1.0, 0.0),
799 devtype: (0, 0, 0),
800 }
801 }
802}
803
/// Fluent builder for a batch (parameter-sweep) z-score run.
#[derive(Clone, Debug, Default)]
pub struct ZscoreBatchBuilder {
    // Sweep definition handed to `zscore_batch_with_kernel`.
    range: ZscoreBatchRange,
    // Kernel handed to `zscore_batch_with_kernel`; starts as `Kernel::default()`.
    kernel: Kernel,
}
809
810impl ZscoreBatchBuilder {
811 pub fn new() -> Self {
812 Self::default()
813 }
814 pub fn kernel(mut self, k: Kernel) -> Self {
815 self.kernel = k;
816 self
817 }
818 #[inline]
819 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
820 self.range.period = (start, end, step);
821 self
822 }
823 #[inline]
824 pub fn period_static(mut self, p: usize) -> Self {
825 self.range.period = (p, p, 0);
826 self
827 }
828 #[inline]
829 pub fn ma_type_static<T: Into<String>>(mut self, s: T) -> Self {
830 let val = s.into();
831 self.range.ma_type = (val.clone(), val.clone(), "".to_string());
832 self
833 }
834 #[inline]
835 pub fn nbdev_range(mut self, start: f64, end: f64, step: f64) -> Self {
836 self.range.nbdev = (start, end, step);
837 self
838 }
839 #[inline]
840 pub fn nbdev_static(mut self, x: f64) -> Self {
841 self.range.nbdev = (x, x, 0.0);
842 self
843 }
844 #[inline]
845 pub fn devtype_range(mut self, start: usize, end: usize, step: usize) -> Self {
846 self.range.devtype = (start, end, step);
847 self
848 }
849 #[inline]
850 pub fn devtype_static(mut self, x: usize) -> Self {
851 self.range.devtype = (x, x, 0);
852 self
853 }
854 pub fn apply_slice(self, data: &[f64]) -> Result<ZscoreBatchOutput, ZscoreError> {
855 zscore_batch_with_kernel(data, &self.range, self.kernel)
856 }
857 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<ZscoreBatchOutput, ZscoreError> {
858 ZscoreBatchBuilder::new().kernel(k).apply_slice(data)
859 }
860 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<ZscoreBatchOutput, ZscoreError> {
861 let slice = source_type(c, src);
862 self.apply_slice(slice)
863 }
864 pub fn with_default_candles(c: &Candles) -> Result<ZscoreBatchOutput, ZscoreError> {
865 ZscoreBatchBuilder::new()
866 .kernel(Kernel::Auto)
867 .apply_candles(c, "close")
868 }
869}
870
871pub fn zscore_batch_with_kernel(
872 data: &[f64],
873 sweep: &ZscoreBatchRange,
874 k: Kernel,
875) -> Result<ZscoreBatchOutput, ZscoreError> {
876 let kernel = match k {
877 Kernel::Auto => detect_best_batch_kernel(),
878 other if other.is_batch() => other,
879 other => return Err(ZscoreError::InvalidKernelForBatch(other)),
880 };
881
882 let simd = match kernel {
883 Kernel::Avx512Batch => Kernel::Avx512,
884 Kernel::Avx2Batch => Kernel::Avx2,
885 Kernel::ScalarBatch => Kernel::Scalar,
886 _ => unreachable!(),
887 };
888 zscore_batch_par_slice(data, sweep, simd)
889}
890
/// Result of a batch run: a row-major `rows x cols` matrix plus the
/// parameter combination that produced each row.
#[derive(Clone, Debug)]
pub struct ZscoreBatchOutput {
    /// Flattened matrix; row `r` occupies `values[r * cols..(r + 1) * cols]`.
    pub values: Vec<f64>,
    /// Parameter combination for each row, in row order.
    pub combos: Vec<ZscoreParams>,
    /// Number of parameter combinations.
    pub rows: usize,
    /// Length of the input series.
    pub cols: usize,
}
898
899impl ZscoreBatchOutput {
900 pub fn row_for_params(&self, p: &ZscoreParams) -> Option<usize> {
901 self.combos.iter().position(|c| {
902 c.period.unwrap_or(14) == p.period.unwrap_or(14)
903 && c.ma_type.as_ref().unwrap_or(&"sma".to_string())
904 == p.ma_type.as_ref().unwrap_or(&"sma".to_string())
905 && (c.nbdev.unwrap_or(1.0) - p.nbdev.unwrap_or(1.0)).abs() < 1e-12
906 && c.devtype.unwrap_or(0) == p.devtype.unwrap_or(0)
907 })
908 }
909 pub fn values_for(&self, p: &ZscoreParams) -> Option<&[f64]> {
910 self.row_for_params(p).map(|row| {
911 let start = row * self.cols;
912 &self.values[start..start + self.cols]
913 })
914 }
915}
916
917#[inline(always)]
918fn expand_grid(r: &ZscoreBatchRange) -> Result<Vec<ZscoreParams>, ZscoreError> {
919 fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, ZscoreError> {
920 if step == 0 || start == end {
921 return Ok(vec![start]);
922 }
923 let mut vals = Vec::new();
924 if start < end {
925 let mut v = start;
926 while v <= end {
927 vals.push(v);
928 match v.checked_add(step) {
929 Some(next) => {
930 if next == v {
931 break;
932 }
933 v = next;
934 }
935 None => break,
936 }
937 }
938 } else {
939 let mut v = start;
940 loop {
941 vals.push(v);
942 if v == end {
943 break;
944 }
945 let next = v.saturating_sub(step);
946 if next == v {
947 break;
948 }
949 v = next;
950 if v < end {
951 break;
952 }
953 }
954 }
955 if vals.is_empty() {
956 return Err(ZscoreError::InvalidRange {
957 start: start as f64,
958 end: end as f64,
959 step: step as f64,
960 });
961 }
962 Ok(vals)
963 }
964 fn axis_f64((start, end, step): (f64, f64, f64)) -> Result<Vec<f64>, ZscoreError> {
965 if !step.is_finite() {
966 return Err(ZscoreError::InvalidRange { start, end, step });
967 }
968 if step.abs() < 1e-12 || (start - end).abs() < 1e-12 {
969 return Ok(vec![start]);
970 }
971 let mut vals = Vec::new();
972 let tol = 1e-12;
973 if start <= end {
974 let step_pos = step.abs();
975 let mut x = start;
976 while x <= end + tol {
977 vals.push(x);
978 x += step_pos;
979 if !x.is_finite() {
980 break;
981 }
982 }
983 } else {
984 let step_neg = -step.abs();
985 let mut x = start;
986 while x >= end - tol {
987 vals.push(x);
988 x += step_neg;
989 if !x.is_finite() {
990 break;
991 }
992 }
993 }
994 if vals.is_empty() {
995 return Err(ZscoreError::InvalidRange { start, end, step });
996 }
997 Ok(vals)
998 }
999 fn axis_string((start, end, _step): (String, String, String)) -> Vec<String> {
1000 if start == end {
1001 vec![start]
1002 } else {
1003 vec![start]
1004 }
1005 }
1006
1007 let periods = axis_usize(r.period)?;
1008 let ma_types = axis_string(r.ma_type.clone());
1009 let nbdevs = axis_f64(r.nbdev)?;
1010 let devtypes = axis_usize(r.devtype)?;
1011
1012 let total = periods
1013 .len()
1014 .checked_mul(ma_types.len())
1015 .and_then(|v| v.checked_mul(nbdevs.len()))
1016 .and_then(|v| v.checked_mul(devtypes.len()))
1017 .ok_or(ZscoreError::InvalidRange {
1018 start: periods.len() as f64,
1019 end: devtypes.len() as f64,
1020 step: nbdevs.len() as f64,
1021 })?;
1022
1023 let mut out = Vec::with_capacity(total);
1024 for &p in &periods {
1025 for mt in &ma_types {
1026 for &n in &nbdevs {
1027 for &dt in &devtypes {
1028 out.push(ZscoreParams {
1029 period: Some(p),
1030 ma_type: Some(mt.clone()),
1031 nbdev: Some(n),
1032 devtype: Some(dt),
1033 });
1034 }
1035 }
1036 }
1037 }
1038 Ok(out)
1039}
1040
/// Runs the sweep over `data` with the row kernel `kern`, processing rows
/// sequentially (`parallel = false`).
///
/// # Errors
/// Returns whatever `zscore_batch_inner` returns.
#[inline(always)]
pub fn zscore_batch_slice(
    data: &[f64],
    sweep: &ZscoreBatchRange,
    kern: Kernel,
) -> Result<ZscoreBatchOutput, ZscoreError> {
    zscore_batch_inner(data, sweep, kern, false)
}
1049
/// Runs the sweep over `data` with the row kernel `kern`, requesting
/// parallel row processing (`parallel = true`).
///
/// # Errors
/// Returns whatever `zscore_batch_inner` returns.
#[inline(always)]
pub fn zscore_batch_par_slice(
    data: &[f64],
    sweep: &ZscoreBatchRange,
    kern: Kernel,
) -> Result<ZscoreBatchOutput, ZscoreError> {
    zscore_batch_inner(data, sweep, kern, true)
}
1058
/// Core of the batch run: expands `sweep`, validates every combination,
/// allocates the `rows x cols` output matrix and fills one row per
/// combination.
///
/// Rows with `devtype == 0` and `ma_type == "sma"` are grouped by period: a
/// base row is computed once per period from shared prefix data and then
/// written to each row with that row's `nbdev`. All other rows go through
/// the per-row kernels.
///
/// `kern` must be one of the per-row kernels `Scalar`, `Avx2` or `Avx512`
/// (the latter two only with `nightly-avx` on x86_64); any other value hits
/// `unreachable!()` as soon as a row is processed.
///
/// # Errors
/// * `EmptyInputData` / `AllValuesNaN` for unusable input.
/// * `InvalidRange` when the sweep cannot be expanded or a size computation
///   overflows.
/// * `InvalidPeriod` when any period is zero or exceeds `data.len()`.
/// * `NotEnoughValidData` when the largest period does not fit after the
///   first non-NaN element.
#[inline(always)]
fn zscore_batch_inner(
    data: &[f64],
    sweep: &ZscoreBatchRange,
    kern: Kernel,
    parallel: bool,
) -> Result<ZscoreBatchOutput, ZscoreError> {
    if data.is_empty() {
        return Err(ZscoreError::EmptyInputData);
    }

    let combos = expand_grid(sweep)?;

    let first = data
        .iter()
        .position(|x| !x.is_nan())
        .ok_or(ZscoreError::AllValuesNaN)?;
    let cols = data.len();
    // Validate every period up front and remember the largest one.
    let mut max_p = 0usize;
    for prm in &combos {
        let period = prm.period.unwrap();
        if period == 0 || period > cols {
            return Err(ZscoreError::InvalidPeriod {
                period,
                data_len: cols,
            });
        }
        if period > max_p {
            max_p = period;
        }
    }
    if cols - first < max_p {
        return Err(ZscoreError::NotEnoughValidData {
            needed: max_p,
            valid: cols - first,
        });
    }

    let rows = combos.len();

    // Only the overflow check matters here; the product itself is unused.
    rows.checked_mul(cols).ok_or(ZscoreError::InvalidRange {
        start: rows as f64,
        end: cols as f64,
        step: 0.0,
    })?;

    // NOTE(review): presumably an uninitialised rows*cols buffer — confirm
    // against `make_uninit_matrix`.
    let mut buf_mu = make_uninit_matrix(rows, cols);

    // Index of the first computed output in each row: first + period - 1.
    let warmup_periods: Vec<usize> = combos
        .iter()
        .map(|c| {
            let p = c.period.unwrap();
            first
                .checked_add(p)
                .and_then(|v| v.checked_sub(1))
                .ok_or(ZscoreError::InvalidRange {
                    start: first as f64,
                    end: p as f64,
                    step: 1.0,
                })
        })
        .collect::<Result<_, _>>()?;

    // NOTE(review): presumably writes the NaN warm-up prefix of every row —
    // confirm against `init_matrix_prefixes`.
    init_matrix_prefixes(&mut buf_mu, cols, &warmup_periods);

    // Keep the allocation alive without running its destructor; ownership is
    // handed to `values` through `Vec::from_raw_parts` at the end.
    let mut buf_guard = core::mem::ManuallyDrop::new(buf_mu);
    let out: &mut [f64] = unsafe {
        core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
    };

    // period -> [(row index, nbdev)] for rows eligible for the shared SMA
    // path; everything else is listed in `fallback_rows`.
    let mut groups: HashMap<usize, Vec<(usize, f64)>> = HashMap::new();
    let mut fallback_rows: Vec<usize> = Vec::new();
    for (row_idx, prm) in combos.iter().enumerate() {
        let period = prm.period.unwrap();
        let ma_type = prm.ma_type.as_ref().unwrap();
        let devtype = prm.devtype.unwrap();
        if devtype == 0 && ma_type == "sma" {
            groups
                .entry(period)
                .or_default()
                .push((row_idx, prm.nbdev.unwrap()));
        } else {
            fallback_rows.push(row_idx);
        }
    }

    // Prefix data is only built when at least one row uses the shared path.
    let prefixes = if groups.is_empty() {
        None
    } else {
        Some(build_sma_std_prefixes(data))
    };

    // NOTE(review): `RowWriter` is defined outside this view; it appears to
    // hand out per-row mutable slices from a raw pointer — confirm its
    // safety contract (each row index is written by exactly one task here).
    let writer = RowWriter {
        ptr: out.as_mut_ptr(),
        cols,
    };

    for (period, rows_for_period) in groups.into_iter() {
        let warmup_end = first + period - 1;
        // Base row for this period, computed once and shared by every row
        // in the group.
        let mut base = vec![f64::NAN; cols];

        match kern {
            Kernel::Scalar => {
                let pre = prefixes.as_ref().expect("prefixes missing for scalar path");
                zscore_sma_std_from_prefix_scalar(data, period, warmup_end, pre, &mut base);
            }
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx2 => unsafe {
                let pre = prefixes.as_ref().expect("prefixes missing for AVX2 path");
                zscore_sma_std_from_prefix_avx2(data, period, warmup_end, pre, &mut base);
            },
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx512 => unsafe {
                let pre = prefixes.as_ref().expect("prefixes missing for AVX512 path");
                zscore_sma_std_from_prefix_avx512(data, period, warmup_end, pre, &mut base);
            },
            _ => unreachable!(),
        }

        let base_ref = &base;

        // One writer closure per kernel; each derives a row from `base`
        // using that row's `nbdev`.
        let write_scalar = |row_idx: usize, nbdev: f64| unsafe {
            writer.with_row(row_idx, |dst| {
                scale_copy_row_scalar(base_ref, warmup_end, nbdev, dst);
            });
        };

        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
        let write_avx2 = |row_idx: usize, nbdev: f64| unsafe {
            writer.with_row(row_idx, |dst| {
                scale_copy_row_avx2(base_ref, warmup_end, nbdev, dst);
            });
        };

        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
        let write_avx512 = |row_idx: usize, nbdev: f64| unsafe {
            writer.with_row(row_idx, |dst| {
                scale_copy_row_avx512(base_ref, warmup_end, nbdev, dst);
            });
        };

        let dispatch_write = |row_idx: usize, nbdev: f64| match kern {
            Kernel::Scalar => write_scalar(row_idx, nbdev),
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx2 => write_avx2(row_idx, nbdev),
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx512 => write_avx512(row_idx, nbdev),
            _ => unreachable!(),
        };

        if parallel {
            #[cfg(not(target_arch = "wasm32"))]
            {
                use rayon::prelude::*;
                rows_for_period
                    .par_iter()
                    .for_each(|(row_idx, nb)| dispatch_write(*row_idx, *nb));
            }
            // On wasm32 the rows are processed sequentially even when
            // `parallel` is requested.
            #[cfg(target_arch = "wasm32")]
            {
                for (row_idx, nb) in rows_for_period.iter() {
                    dispatch_write(*row_idx, *nb);
                }
            }
        } else {
            for (row_idx, nb) in rows_for_period.iter() {
                dispatch_write(*row_idx, *nb);
            }
        }
    }

    if !fallback_rows.is_empty() {
        // Computes one non-grouped row with the per-row kernel.
        let do_row = |row: usize| unsafe {
            let prm = &combos[row];
            let period = prm.period.unwrap();
            let ma_type = prm.ma_type.as_ref().unwrap();
            let nbdev = prm.nbdev.unwrap();
            let devtype = prm.devtype.unwrap();
            writer.with_row(row, |dst| match kern {
                Kernel::Scalar => {
                    zscore_row_scalar(data, first, period, ma_type, nbdev, devtype, dst)
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                Kernel::Avx2 => zscore_row_avx2(data, first, period, ma_type, nbdev, devtype, dst),
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                Kernel::Avx512 => {
                    zscore_row_avx512(data, first, period, ma_type, nbdev, devtype, dst)
                }
                _ => unreachable!(),
            });
        };

        if parallel {
            #[cfg(not(target_arch = "wasm32"))]
            {
                use rayon::prelude::*;
                fallback_rows.par_iter().for_each(|&row| do_row(row));
            }
            #[cfg(target_arch = "wasm32")]
            {
                for &row in &fallback_rows {
                    do_row(row);
                }
            }
        } else {
            for &row in &fallback_rows {
                do_row(row);
            }
        }
    }

    // Reclaim the buffer as a `Vec<f64>`, reusing the pointer, length and
    // capacity of the allocation held by `buf_guard`.
    let values = unsafe {
        Vec::from_raw_parts(
            buf_guard.as_mut_ptr() as *mut f64,
            buf_guard.len(),
            buf_guard.capacity(),
        )
    };

    Ok(ZscoreBatchOutput {
        values,
        combos,
        rows,
        cols,
    })
}
1285
1286#[inline(always)]
1287unsafe fn zscore_row_scalar(
1288 data: &[f64],
1289 first: usize,
1290 period: usize,
1291 ma_type: &str,
1292 nbdev: f64,
1293 devtype: usize,
1294 out: &mut [f64],
1295) {
1296 if devtype == 0 {
1297 if ma_type == "sma" {
1298 zscore_row_scalar_classic_sma(data, first, period, nbdev, out);
1299 return;
1300 } else if ma_type == "ema" {
1301 zscore_row_scalar_classic_ema(data, first, period, nbdev, out);
1302 return;
1303 }
1304 }
1305
1306 let means = match ma(ma_type, MaData::Slice(data), period) {
1307 Ok(m) => m,
1308 Err(_) => {
1309 out.fill(f64::NAN);
1310 return;
1311 }
1312 };
1313 let dev_input = DevInput {
1314 data: DeviationData::Slice(data),
1315 params: DevParams {
1316 period: Some(period),
1317 devtype: Some(devtype),
1318 },
1319 };
1320 let mut sigmas = match deviation(&dev_input) {
1321 Ok(d) => d.values,
1322 Err(_) => {
1323 out.fill(f64::NAN);
1324 return;
1325 }
1326 };
1327 for v in &mut sigmas {
1328 *v *= nbdev;
1329 }
1330 let warmup_end = first + period - 1;
1331 for i in warmup_end..data.len() {
1332 let mean = means[i];
1333 let sigma = sigmas[i];
1334 let value = data[i];
1335 out[i] = if sigma == 0.0 || sigma.is_nan() {
1336 f64::NAN
1337 } else {
1338 (value - mean) / sigma
1339 };
1340 }
1341}
1342
/// Rolling z-score using an SMA mean and the matching windowed standard deviation.
///
/// Running sums of the values and their squares are seeded over
/// `data[first..=first + period - 1]` and then slid one sample at a time. For each
/// position the mean is `sum / period`, the variance is `sum_sqr / period - mean^2`,
/// and the result is `(value - mean) / (sqrt(variance) * nbdev)`. A non-positive
/// variance, or a scaled deviation that is zero or NaN, produces NaN.
///
/// Positions before `first + period - 1` are not written.
///
/// # Safety
///
/// No unchecked memory access happens here; out-of-range arguments panic on indexing.
/// The function is `unsafe` to match the other row kernels.
#[inline(always)]
unsafe fn zscore_row_scalar_classic_sma(
    data: &[f64],
    first: usize,
    period: usize,
    nbdev: f64,
    out: &mut [f64],
) {
    let start = first + period - 1;
    let window = period as f64;

    // Converts the current running sums into the z-score of `value`.
    let zscore_of = |value: f64, total: f64, total_sq: f64| -> f64 {
        let mean = total / window;
        let variance = (total_sq / window) - (mean * mean);
        let scaled_dev = if variance <= 0.0 {
            0.0
        } else {
            variance.sqrt() * nbdev
        };
        if scaled_dev == 0.0 || scaled_dev.is_nan() {
            f64::NAN
        } else {
            (value - mean) / scaled_dev
        }
    };

    // Seed the sums with the first full window.
    let mut total = 0.0f64;
    let mut total_sq = 0.0f64;
    for &sample in &data[first..=start] {
        total += sample;
        total_sq += sample * sample;
    }
    out[start] = zscore_of(data[start], total, total_sq);

    // Slide the window: add the entering sample, drop the leaving one.
    for idx in (start + 1)..data.len() {
        let leaving = data[idx - period];
        let entering = data[idx];

        total += entering - leaving;
        total_sq += entering * entering - leaving * leaving;

        out[idx] = zscore_of(entering, total, total_sq);
    }
}
1397
/// Prefix tables over an input series, each of length `data.len() + 1`.
///
/// Entry `k` of every table describes the first `k` samples, so the aggregate over
/// the half-open index range `[a, b)` is `table[b] - table[a]`.
#[derive(Clone, Debug)]
struct SmaStdPrefixes {
    /// Cumulative sum of the non-NaN samples.
    ps: Vec<f64>,
    /// Cumulative sum of the squares of the non-NaN samples.
    ps2: Vec<f64>,
    /// Cumulative count of NaN samples.
    pnan: Vec<i32>,
}

/// Builds the prefix tables for `data` in a single pass.
///
/// NaN samples contribute nothing to the value sums and increment the NaN counter,
/// which lets consumers detect windows that contain a NaN.
#[inline]
fn build_sma_std_prefixes(data: &[f64]) -> SmaStdPrefixes {
    let entries = data.len() + 1;
    let mut ps = Vec::with_capacity(entries);
    let mut ps2 = Vec::with_capacity(entries);
    let mut pnan = Vec::with_capacity(entries);

    let mut running_sum = 0.0f64;
    let mut running_sq = 0.0f64;
    let mut running_nan = 0i32;

    // Entry 0 is the empty prefix.
    ps.push(running_sum);
    ps2.push(running_sq);
    pnan.push(running_nan);

    for &sample in data {
        if sample.is_nan() {
            running_nan += 1;
        } else {
            running_sum += sample;
            running_sq += sample * sample;
        }
        ps.push(running_sum);
        ps2.push(running_sq);
        pnan.push(running_nan);
    }

    SmaStdPrefixes { ps, ps2, pnan }
}
1427
1428#[inline]
1429fn zscore_sma_std_from_prefix_scalar(
1430 data: &[f64],
1431 period: usize,
1432 warmup_end: usize,
1433 pre: &SmaStdPrefixes,
1434 base_out: &mut [f64],
1435) {
1436 let n = data.len();
1437 debug_assert_eq!(base_out.len(), n);
1438
1439 for v in &mut base_out[..warmup_end] {
1440 *v = f64::NAN;
1441 }
1442 if n <= warmup_end {
1443 return;
1444 }
1445
1446 let denom = period as f64;
1447 for i in warmup_end..n {
1448 let nan_count = pre.pnan[i + 1] - pre.pnan[i + 1 - period];
1449 if nan_count > 0 {
1450 base_out[i] = f64::NAN;
1451 continue;
1452 }
1453
1454 let sum = pre.ps[i + 1] - pre.ps[i + 1 - period];
1455 let sum2 = pre.ps2[i + 1] - pre.ps2[i + 1 - period];
1456 let mean = sum / denom;
1457 let variance = sum2 / denom - mean * mean;
1458 let stdv = if variance <= 0.0 {
1459 0.0
1460 } else {
1461 variance.sqrt()
1462 };
1463 base_out[i] = if stdv == 0.0 || stdv.is_nan() {
1464 f64::NAN
1465 } else {
1466 (data[i] - mean) / stdv
1467 };
1468 }
1469}
1470
/// Writes one output row from a shared base row by dividing it by `nbdev`.
///
/// The first `warmup_end` entries are copied verbatim from `src_base`. The remaining
/// entries become `src_base[i] / nbdev`, or NaN for every one of them when `nbdev`
/// is exactly zero.
#[inline]
fn scale_copy_row_scalar(src_base: &[f64], warmup_end: usize, nbdev: f64, dst: &mut [f64]) {
    debug_assert_eq!(src_base.len(), dst.len());

    let (warmup, body) = dst.split_at_mut(warmup_end);
    warmup.copy_from_slice(&src_base[..warmup_end]);

    if body.is_empty() {
        return;
    }

    if nbdev == 0.0 {
        body.fill(f64::NAN);
        return;
    }

    let base_body = &src_base[warmup_end..];
    for (target, &base) in body.iter_mut().zip(base_body) {
        *target = base / nbdev;
    }
}
1494
/// AVX2 variant of the prefix-table SMA z-score: four outputs per iteration.
///
/// `base_out[..warmup_end]` is set to NaN. Each later lane gets
/// `(data[i] - mean) / std` for the window `[i + 1 - period, i + 1)`, replaced by NaN
/// when the window contains a NaN sample or when the standard deviation is zero or
/// NaN. Elements that do not fill a whole vector are handled by a scalar tail using
/// the same rules.
///
/// # Safety
///
/// * The CPU must support AVX2.
/// * `base_out` and the tables in `pre` are accessed through raw pointers: the caller
///   must guarantee `base_out.len() >= data.len()`, that every table in `pre` has at
///   least `data.len() + 1` entries, and that `warmup_end + 1 >= period`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn zscore_sma_std_from_prefix_avx2(
    data: &[f64],
    period: usize,
    warmup_end: usize,
    pre: &SmaStdPrefixes,
    base_out: &mut [f64],
) {
    let len = data.len();
    debug_assert_eq!(base_out.len(), len);

    base_out[..warmup_end].fill(f64::NAN);
    if len <= warmup_end {
        return;
    }

    let sums = pre.ps.as_ptr();
    let squares = pre.ps2.as_ptr();
    let nan_counts = pre.pnan.as_ptr();
    let src = data.as_ptr();
    let dst = base_out.as_mut_ptr();

    let window_v = _mm256_set1_pd(period as f64);
    let zero_v = _mm256_set1_pd(0.0);
    let nan_v = _mm256_set1_pd(f64::NAN);

    let mut pos = warmup_end;
    while pos + 4 <= len {
        // Prefix offsets delimiting the window of the first lane.
        let hi = pos + 1;
        let lo = pos + 1 - period;

        let window_sum = _mm256_sub_pd(_mm256_loadu_pd(sums.add(hi)), _mm256_loadu_pd(sums.add(lo)));
        let window_sq = _mm256_sub_pd(
            _mm256_loadu_pd(squares.add(hi)),
            _mm256_loadu_pd(squares.add(lo)),
        );

        let mean = _mm256_div_pd(window_sum, window_v);
        let variance = _mm256_sub_pd(
            _mm256_div_pd(window_sq, window_v),
            _mm256_mul_pd(mean, mean),
        );
        // Clamp negative variances to zero before the square root.
        let std_dev = _mm256_sqrt_pd(_mm256_max_pd(variance, zero_v));

        let samples = _mm256_loadu_pd(src.add(pos));
        let zscore = _mm256_div_pd(_mm256_sub_pd(samples, mean), std_dev);

        let std_is_zero = _mm256_cmp_pd(std_dev, zero_v, _CMP_EQ_OQ);
        let std_is_nan = _mm256_cmp_pd(std_dev, std_dev, _CMP_UNORD_Q);

        // Per-lane count of NaN samples inside the window, widened to f64 for the compare.
        let nan_delta = _mm_sub_epi32(
            _mm_loadu_si128(nan_counts.add(hi) as *const _),
            _mm_loadu_si128(nan_counts.add(lo) as *const _),
        );
        let window_has_nan = _mm256_cmp_pd(_mm256_cvtepi32_pd(nan_delta), zero_v, _CMP_GT_OQ);

        let invalid = _mm256_or_pd(_mm256_or_pd(std_is_zero, std_is_nan), window_has_nan);
        _mm256_storeu_pd(dst.add(pos), _mm256_blendv_pd(zscore, nan_v, invalid));

        pos += 4;
    }

    // Scalar tail for the elements that did not fill a vector.
    let window = period as f64;
    for idx in pos..len {
        let hi = idx + 1;
        let lo = idx + 1 - period;
        let nan_in_window = pre.pnan[hi] - pre.pnan[lo];
        base_out[idx] = if nan_in_window > 0 {
            f64::NAN
        } else {
            let total = pre.ps[hi] - pre.ps[lo];
            let total_sq = pre.ps2[hi] - pre.ps2[lo];
            let mean = total / window;
            let variance = total_sq / window - mean * mean;
            let std_dev = if variance <= 0.0 {
                0.0
            } else {
                variance.sqrt()
            };
            if std_dev == 0.0 || std_dev.is_nan() {
                f64::NAN
            } else {
                (data[idx] - mean) / std_dev
            }
        };
    }
}
1577
/// AVX-512 variant of the prefix-table SMA z-score: eight outputs per iteration.
///
/// `base_out[..warmup_end]` is set to NaN. Each later lane gets
/// `(data[i] - mean) / std` for the window `[i + 1 - period, i + 1)`, replaced by NaN
/// when the window contains a NaN sample or when the standard deviation is zero or
/// NaN. Elements that do not fill a whole vector are handled by a scalar tail using
/// the same rules.
///
/// # Safety
///
/// * The CPU must support AVX-512F.
/// * `base_out` and the tables in `pre` are accessed through raw pointers: the caller
///   must guarantee `base_out.len() >= data.len()`, that every table in `pre` has at
///   least `data.len() + 1` entries, and that `warmup_end + 1 >= period`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
#[inline]
unsafe fn zscore_sma_std_from_prefix_avx512(
    data: &[f64],
    period: usize,
    warmup_end: usize,
    pre: &SmaStdPrefixes,
    base_out: &mut [f64],
) {
    let len = data.len();
    debug_assert_eq!(base_out.len(), len);

    base_out[..warmup_end].fill(f64::NAN);
    if len <= warmup_end {
        return;
    }

    let sums = pre.ps.as_ptr();
    let squares = pre.ps2.as_ptr();
    let nan_counts = pre.pnan.as_ptr();
    let src = data.as_ptr();
    let dst = base_out.as_mut_ptr();

    let window_v = _mm512_set1_pd(period as f64);
    let zero_v = _mm512_set1_pd(0.0);
    let nan_v = _mm512_set1_pd(f64::NAN);

    let mut pos = warmup_end;
    while pos + 8 <= len {
        // Prefix offsets delimiting the window of the first lane.
        let hi = pos + 1;
        let lo = pos + 1 - period;

        let window_sum = _mm512_sub_pd(_mm512_loadu_pd(sums.add(hi)), _mm512_loadu_pd(sums.add(lo)));
        let window_sq = _mm512_sub_pd(
            _mm512_loadu_pd(squares.add(hi)),
            _mm512_loadu_pd(squares.add(lo)),
        );

        let mean = _mm512_div_pd(window_sum, window_v);
        let variance = _mm512_sub_pd(
            _mm512_div_pd(window_sq, window_v),
            _mm512_mul_pd(mean, mean),
        );
        // Clamp negative variances to zero before the square root.
        let std_dev = _mm512_sqrt_pd(_mm512_max_pd(variance, zero_v));

        let samples = _mm512_loadu_pd(src.add(pos));
        let zscore = _mm512_div_pd(_mm512_sub_pd(samples, mean), std_dev);

        let std_is_zero = _mm512_cmp_pd_mask(std_dev, zero_v, _CMP_EQ_OQ);
        let std_is_nan = _mm512_cmp_pd_mask(std_dev, std_dev, _CMP_UNORD_Q);

        // Per-lane count of NaN samples inside the window, widened to f64 for the compare.
        let nan_delta = _mm256_sub_epi32(
            _mm256_loadu_si256(nan_counts.add(hi) as *const _),
            _mm256_loadu_si256(nan_counts.add(lo) as *const _),
        );
        let window_has_nan =
            _mm512_cmp_pd_mask(_mm512_cvtepi32_pd(nan_delta), zero_v, _CMP_GT_OQ);

        let invalid = std_is_zero | std_is_nan | window_has_nan;
        // Lanes flagged in `invalid` take NaN; the others keep the computed z-score.
        _mm512_storeu_pd(dst.add(pos), _mm512_mask_mov_pd(zscore, invalid, nan_v));

        pos += 8;
    }

    // Scalar tail for the elements that did not fill a vector.
    let window = period as f64;
    for idx in pos..len {
        let hi = idx + 1;
        let lo = idx + 1 - period;
        let nan_in_window = pre.pnan[hi] - pre.pnan[lo];
        base_out[idx] = if nan_in_window > 0 {
            f64::NAN
        } else {
            let total = pre.ps[hi] - pre.ps[lo];
            let total_sq = pre.ps2[hi] - pre.ps2[lo];
            let mean = total / window;
            let variance = total_sq / window - mean * mean;
            let std_dev = if variance <= 0.0 {
                0.0
            } else {
                variance.sqrt()
            };
            if std_dev == 0.0 || std_dev.is_nan() {
                f64::NAN
            } else {
                (data[idx] - mean) / std_dev
            }
        };
    }
}
1660
/// AVX2 row writer: copies the warm-up prefix and scales the rest of the base row.
///
/// The first `warmup_end` entries are copied verbatim from `src_base`. When `nbdev`
/// is exactly zero every remaining entry becomes NaN. Otherwise full groups of four
/// elements are multiplied by `1.0 / nbdev`, and the leftover elements are divided by
/// `nbdev` directly.
///
/// Note: multiplying by the reciprocal and dividing can round differently, so vector
/// lanes and tail elements may differ in the last bit for some `nbdev` values.
///
/// # Safety
///
/// The CPU must support AVX2, and `src_base.len() >= dst.len()` must hold because the
/// vector loop reads `src_base` through a raw pointer.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn scale_copy_row_avx2(src_base: &[f64], warmup_end: usize, nbdev: f64, dst: &mut [f64]) {
    debug_assert_eq!(src_base.len(), dst.len());

    if warmup_end > 0 {
        dst[..warmup_end].copy_from_slice(&src_base[..warmup_end]);
    }

    let len = dst.len();
    if len <= warmup_end {
        return;
    }

    if nbdev == 0.0 {
        dst[warmup_end..].fill(f64::NAN);
        return;
    }

    // End of the region that can be processed in whole 4-lane vectors.
    let vector_end = warmup_end + (len - warmup_end) / 4 * 4;
    let reciprocal = _mm256_set1_pd(1.0 / nbdev);
    let src_ptr = src_base.as_ptr();
    let dst_ptr = dst.as_mut_ptr();

    let mut pos = warmup_end;
    while pos < vector_end {
        let scaled = _mm256_mul_pd(_mm256_loadu_pd(src_ptr.add(pos)), reciprocal);
        _mm256_storeu_pd(dst_ptr.add(pos), scaled);
        pos += 4;
    }

    for idx in vector_end..len {
        dst[idx] = src_base[idx] / nbdev;
    }
}
1695
/// AVX-512 row writer: copies the warm-up prefix and scales the rest of the base row.
///
/// The first `warmup_end` entries are copied verbatim from `src_base`. When `nbdev`
/// is exactly zero every remaining entry becomes NaN. Otherwise full groups of eight
/// elements are multiplied by `1.0 / nbdev`, and the leftover elements are divided by
/// `nbdev` directly.
///
/// Note: multiplying by the reciprocal and dividing can round differently, so vector
/// lanes and tail elements may differ in the last bit for some `nbdev` values.
///
/// # Safety
///
/// The CPU must support AVX-512F, and `src_base.len() >= dst.len()` must hold because
/// the vector loop reads `src_base` through a raw pointer.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
#[inline]
unsafe fn scale_copy_row_avx512(src_base: &[f64], warmup_end: usize, nbdev: f64, dst: &mut [f64]) {
    debug_assert_eq!(src_base.len(), dst.len());

    if warmup_end > 0 {
        dst[..warmup_end].copy_from_slice(&src_base[..warmup_end]);
    }

    let len = dst.len();
    if len <= warmup_end {
        return;
    }

    if nbdev == 0.0 {
        dst[warmup_end..].fill(f64::NAN);
        return;
    }

    // End of the region that can be processed in whole 8-lane vectors.
    let vector_end = warmup_end + (len - warmup_end) / 8 * 8;
    let reciprocal = _mm512_set1_pd(1.0 / nbdev);
    let src_ptr = src_base.as_ptr();
    let dst_ptr = dst.as_mut_ptr();

    let mut pos = warmup_end;
    while pos < vector_end {
        let scaled = _mm512_mul_pd(_mm512_loadu_pd(src_ptr.add(pos)), reciprocal);
        _mm512_storeu_pd(dst_ptr.add(pos), scaled);
        pos += 8;
    }

    for idx in vector_end..len {
        dst[idx] = src_base[idx] / nbdev;
    }
}
1730
/// Copyable raw handle to a row-major `f64` matrix with `cols` elements per row.
///
/// It exists so that closures running on several threads can each obtain a mutable
/// view of a different row of the same output buffer without borrowing the buffer
/// itself.
#[derive(Clone, Copy)]
struct RowWriter {
    /// Pointer to the first element of the matrix.
    ptr: *mut f64,
    /// Number of elements in every row.
    cols: usize,
}

// SAFETY: `RowWriter` is only a pointer plus a length and performs no access on its
// own. Every dereference goes through the `unsafe fn with_row`, whose contract
// requires callers to keep concurrently accessed rows disjoint and the buffer alive.
unsafe impl Send for RowWriter {}
unsafe impl Sync for RowWriter {}

impl RowWriter {
    /// Runs `f` on a mutable slice covering row `row` of the matrix.
    ///
    /// # Safety
    ///
    /// * `ptr` must point to a live allocation holding at least `(row + 1) * cols`
    ///   initialized-or-writable `f64` values for the duration of the call.
    /// * No other reference to the same row may exist while `f` runs; in particular,
    ///   concurrent callers must use distinct `row` values.
    #[inline(always)]
    unsafe fn with_row<F>(&self, row: usize, f: F)
    where
        F: FnOnce(&mut [f64]),
    {
        let slice = std::slice::from_raw_parts_mut(self.ptr.add(row * self.cols), self.cols);
        f(slice);
    }
}
1750
/// Rolling z-score using an EMA as the center and a windowed deviation around it.
///
/// The EMA is seeded with the mean of `data[first..=first + period - 1]` and then
/// updated with `alpha = 2 / (period + 1)`. For every position the mean squared
/// deviation of the last `period` samples around the current EMA is derived from
/// running sums (`sum2/n - 2*ema*sum/n + ema^2`, clamped at zero). The result is
/// `(value - ema) / (sqrt(mse) * nbdev)`, or NaN when that denominator is zero or NaN.
///
/// Nothing is written when the series is too short to contain one full window, and
/// positions before `first + period - 1` are never written.
///
/// # Safety
///
/// No unchecked memory access happens here; out-of-range arguments panic on indexing.
/// The function is `unsafe` to match the other row kernels.
#[inline(always)]
unsafe fn zscore_row_scalar_classic_ema(
    data: &[f64],
    first: usize,
    period: usize,
    nbdev: f64,
    out: &mut [f64],
) {
    let len = data.len();
    let start = first + period - 1;
    if len <= start {
        return;
    }

    let window = period as f64;
    let alpha = 2.0 / (window + 1.0);
    let decay = 1.0 - alpha;

    // Converts the running sums and the current EMA into the z-score of `value`.
    let zscore_of = |value: f64, total: f64, total_sq: f64, center: f64| -> f64 {
        let mut mse = (total_sq / window) - 2.0 * center * (total / window) + center * center;
        if mse < 0.0 {
            mse = 0.0;
        }
        let scaled_dev = mse.sqrt() * nbdev;
        if scaled_dev == 0.0 || scaled_dev.is_nan() {
            f64::NAN
        } else {
            (value - center) / scaled_dev
        }
    };

    // Seed the sums (and the EMA) with the first full window.
    let mut total = 0.0f64;
    let mut total_sq = 0.0f64;
    for &sample in &data[first..=start] {
        total += sample;
        total_sq += sample * sample;
    }
    let mut ema = total / window;
    out[start] = zscore_of(data[start], total, total_sq, ema);

    for idx in (start + 1)..len {
        let entering = data[idx];
        let leaving = data[idx - period];

        total += entering - leaving;
        total_sq += entering * entering - leaving * leaving;

        ema = alpha * entering + decay * ema;

        out[idx] = zscore_of(entering, total, total_sq, ema);
    }
}
1819
/// AVX2 entry point for a single z-score row.
///
/// There is no dedicated AVX2 implementation here: the call is forwarded unchanged
/// to [`zscore_row_scalar`].
///
/// # Safety
///
/// Same requirements as [`zscore_row_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn zscore_row_avx2(
    data: &[f64],
    first: usize,
    period: usize,
    ma_type: &str,
    nbdev: f64,
    devtype: usize,
    out: &mut [f64],
) {
    zscore_row_scalar(data, first, period, ma_type, nbdev, devtype, out);
}
1833
/// AVX-512 entry point for a single z-score row.
///
/// Periods up to and including 32 go to [`zscore_row_avx512_short`]; longer periods
/// go to [`zscore_row_avx512_long`].
///
/// # Safety
///
/// Same requirements as the two kernels it dispatches to.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn zscore_row_avx512(
    data: &[f64],
    first: usize,
    period: usize,
    ma_type: &str,
    nbdev: f64,
    devtype: usize,
    out: &mut [f64],
) {
    match period {
        0..=32 => zscore_row_avx512_short(data, first, period, ma_type, nbdev, devtype, out),
        _ => zscore_row_avx512_long(data, first, period, ma_type, nbdev, devtype, out),
    }
}
1851
/// AVX-512 row kernel selected for short periods.
///
/// There is no dedicated implementation here: the call is forwarded unchanged to
/// [`zscore_row_scalar`].
///
/// # Safety
///
/// Same requirements as [`zscore_row_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn zscore_row_avx512_short(
    data: &[f64],
    first: usize,
    period: usize,
    ma_type: &str,
    nbdev: f64,
    devtype: usize,
    out: &mut [f64],
) {
    zscore_row_scalar(data, first, period, ma_type, nbdev, devtype, out);
}
1865
/// AVX-512 row kernel selected for long periods.
///
/// There is no dedicated implementation here: the call is forwarded unchanged to
/// [`zscore_row_scalar`].
///
/// # Safety
///
/// Same requirements as [`zscore_row_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn zscore_row_avx512_long(
    data: &[f64],
    first: usize,
    period: usize,
    ma_type: &str,
    nbdev: f64,
    devtype: usize,
    out: &mut [f64],
) {
    zscore_row_scalar(data, first, period, ma_type, nbdev, devtype, out);
}
1879
1880#[inline(always)]
1881pub fn zscore_batch_inner_into(
1882 data: &[f64],
1883 sweep: &ZscoreBatchRange,
1884 kern: Kernel,
1885 parallel: bool,
1886 out: &mut [f64],
1887) -> Result<Vec<ZscoreParams>, ZscoreError> {
1888 if data.is_empty() {
1889 return Err(ZscoreError::EmptyInputData);
1890 }
1891
1892 let combos = expand_grid(sweep)?;
1893
1894 let first = data
1895 .iter()
1896 .position(|x| !x.is_nan())
1897 .ok_or(ZscoreError::AllValuesNaN)?;
1898 let cols = data.len();
1899 let mut max_p = 0usize;
1900 for prm in &combos {
1901 let period = prm.period.unwrap();
1902 if period == 0 || period > cols {
1903 return Err(ZscoreError::InvalidPeriod {
1904 period,
1905 data_len: cols,
1906 });
1907 }
1908 if period > max_p {
1909 max_p = period;
1910 }
1911 }
1912 if cols - first < max_p {
1913 return Err(ZscoreError::NotEnoughValidData {
1914 needed: max_p,
1915 valid: cols - first,
1916 });
1917 }
1918
1919 let rows = combos.len();
1920
1921 let expected = rows.checked_mul(cols).ok_or(ZscoreError::InvalidRange {
1922 start: rows as f64,
1923 end: cols as f64,
1924 step: 0.0,
1925 })?;
1926 if out.len() != expected {
1927 return Err(ZscoreError::OutputLengthMismatch {
1928 expected,
1929 got: out.len(),
1930 });
1931 }
1932
1933 let warm: Vec<usize> = combos
1934 .iter()
1935 .map(|c| first + c.period.unwrap() - 1)
1936 .collect();
1937 {
1938 let out_uninit = unsafe {
1939 std::slice::from_raw_parts_mut(
1940 out.as_mut_ptr() as *mut std::mem::MaybeUninit<f64>,
1941 out.len(),
1942 )
1943 };
1944 init_matrix_prefixes(out_uninit, cols, &warm);
1945 }
1946
1947 let mut groups: HashMap<usize, Vec<(usize, f64)>> = HashMap::new();
1948 let mut fallback_rows: Vec<usize> = Vec::new();
1949 for (row_idx, prm) in combos.iter().enumerate() {
1950 let period = prm.period.unwrap();
1951 let ma_type = prm.ma_type.as_ref().unwrap();
1952 let devtype = prm.devtype.unwrap();
1953 if devtype == 0 && ma_type == "sma" {
1954 groups
1955 .entry(period)
1956 .or_default()
1957 .push((row_idx, prm.nbdev.unwrap()));
1958 } else {
1959 fallback_rows.push(row_idx);
1960 }
1961 }
1962
1963 let prefixes = if groups.is_empty() {
1964 None
1965 } else {
1966 Some(build_sma_std_prefixes(data))
1967 };
1968
1969 let writer = RowWriter {
1970 ptr: out.as_mut_ptr(),
1971 cols,
1972 };
1973
1974 for (period, rows_for_period) in groups.into_iter() {
1975 let warmup_end = first + period - 1;
1976 let mut base = vec![f64::NAN; cols];
1977
1978 match kern {
1979 Kernel::Scalar => {
1980 let pre = prefixes.as_ref().expect("prefixes missing for scalar path");
1981 zscore_sma_std_from_prefix_scalar(data, period, warmup_end, pre, &mut base);
1982 }
1983 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1984 Kernel::Avx2 => unsafe {
1985 let pre = prefixes.as_ref().expect("prefixes missing for AVX2 path");
1986 zscore_sma_std_from_prefix_avx2(data, period, warmup_end, pre, &mut base);
1987 },
1988 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1989 Kernel::Avx512 => unsafe {
1990 let pre = prefixes.as_ref().expect("prefixes missing for AVX512 path");
1991 zscore_sma_std_from_prefix_avx512(data, period, warmup_end, pre, &mut base);
1992 },
1993 _ => unreachable!(),
1994 }
1995
1996 let base_ref = &base;
1997
1998 let write_scalar = |row_idx: usize, nbdev: f64| unsafe {
1999 writer.with_row(row_idx, |dst| {
2000 scale_copy_row_scalar(base_ref, warmup_end, nbdev, dst);
2001 });
2002 };
2003
2004 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2005 let write_avx2 = |row_idx: usize, nbdev: f64| unsafe {
2006 writer.with_row(row_idx, |dst| {
2007 scale_copy_row_avx2(base_ref, warmup_end, nbdev, dst);
2008 });
2009 };
2010
2011 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2012 let write_avx512 = |row_idx: usize, nbdev: f64| unsafe {
2013 writer.with_row(row_idx, |dst| {
2014 scale_copy_row_avx512(base_ref, warmup_end, nbdev, dst);
2015 });
2016 };
2017
2018 let dispatch_write = |row_idx: usize, nbdev: f64| match kern {
2019 Kernel::Scalar => write_scalar(row_idx, nbdev),
2020 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2021 Kernel::Avx2 => write_avx2(row_idx, nbdev),
2022 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2023 Kernel::Avx512 => write_avx512(row_idx, nbdev),
2024 _ => unreachable!(),
2025 };
2026
2027 if parallel {
2028 #[cfg(not(target_arch = "wasm32"))]
2029 {
2030 use rayon::prelude::*;
2031 rows_for_period
2032 .par_iter()
2033 .for_each(|(row_idx, nb)| dispatch_write(*row_idx, *nb));
2034 }
2035 #[cfg(target_arch = "wasm32")]
2036 {
2037 for (row_idx, nb) in rows_for_period.iter() {
2038 dispatch_write(*row_idx, *nb);
2039 }
2040 }
2041 } else {
2042 for (row_idx, nb) in rows_for_period.iter() {
2043 dispatch_write(*row_idx, *nb);
2044 }
2045 }
2046 }
2047
2048 if !fallback_rows.is_empty() {
2049 let do_row = |row: usize| unsafe {
2050 let prm = &combos[row];
2051 let period = prm.period.unwrap();
2052 let ma_type = prm.ma_type.as_ref().unwrap();
2053 let nbdev = prm.nbdev.unwrap();
2054 let devtype = prm.devtype.unwrap();
2055 writer.with_row(row, |dst| match kern {
2056 Kernel::Scalar => {
2057 zscore_row_scalar(data, first, period, ma_type, nbdev, devtype, dst)
2058 }
2059 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2060 Kernel::Avx2 => zscore_row_avx2(data, first, period, ma_type, nbdev, devtype, dst),
2061 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2062 Kernel::Avx512 => {
2063 zscore_row_avx512(data, first, period, ma_type, nbdev, devtype, dst)
2064 }
2065 _ => unreachable!(),
2066 });
2067 };
2068
2069 if parallel {
2070 #[cfg(not(target_arch = "wasm32"))]
2071 {
2072 use rayon::prelude::*;
2073 fallback_rows.par_iter().for_each(|&row| do_row(row));
2074 }
2075 #[cfg(target_arch = "wasm32")]
2076 {
2077 for &row in &fallback_rows {
2078 do_row(row);
2079 }
2080 }
2081 } else {
2082 for &row in &fallback_rows {
2083 do_row(row);
2084 }
2085 }
2086 }
2087
2088 Ok(combos)
2089}
2090
// Python binding `zscore(data, period=14, ma_type="sma", nbdev=1.0, devtype=0, kernel=None)`.
//
// Borrows the NumPy input as a contiguous slice, validates the kernel name, runs the
// computation with the GIL released, and returns the result as a new 1-D NumPy array.
// Computation errors surface as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "zscore")]
#[pyo3(signature = (data, period=14, ma_type="sma", nbdev=1.0, devtype=0, kernel=None))]
pub fn zscore_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    ma_type: &str,
    nbdev: f64,
    devtype: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let samples = data.as_slice()?;
    let kern = validate_kernel(kernel, false)?;

    let input = ZscoreInput::from_slice(
        samples,
        ZscoreParams {
            period: Some(period),
            ma_type: Some(ma_type.to_string()),
            nbdev: Some(nbdev),
            devtype: Some(devtype),
        },
    );

    // Release the GIL while the indicator runs.
    let computed = py.allow_threads(|| zscore_with_kernel(&input, kern).map(|o| o.values));

    match computed {
        Ok(values) => Ok(values.into_pyarray(py)),
        Err(err) => Err(PyValueError::new_err(err.to_string())),
    }
}
2121
// Python-visible owner of an `f32` device matrix returned by the CUDA z-score
// functions. It keeps the CUDA context handle next to the buffer and remembers the
// device ordinal reported through `__dlpack_device__`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct ZscoreDeviceArrayF32Py {
    pub(crate) inner: DeviceArrayF32,
    _ctx_guard: Arc<Context>,
    _device_id: u32,
}

#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl ZscoreDeviceArrayF32Py {
    // Builds the `__cuda_array_interface__` dictionary: a 2-D little-endian `f32`
    // array with row-major byte strides, a writable device pointer, interface
    // version 3.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let elem_bytes = std::mem::size_of::<f32>();
        let device_ptr = self.inner.buf.as_device_ptr().as_raw() as usize;

        let iface = PyDict::new(py);
        iface.set_item("shape", (self.inner.rows, self.inner.cols))?;
        iface.set_item("typestr", "<f4")?;
        iface.set_item("strides", (self.inner.cols * elem_bytes, elem_bytes))?;
        // Second tuple element is the read-only flag.
        iface.set_item("data", (device_ptr, false))?;

        iface.set_item("version", 3)?;
        Ok(iface)
    }

    // Returns `(device_type, device_id)`; device type 2 is the DLPack code for CUDA.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self._device_id as i32)
    }

    // Exports the buffer as a DLPack capsule.
    //
    // The device buffer is moved out of `self` and replaced by an empty one, so the
    // export hands over ownership. `stream` and `copy` are accepted and ignored. A
    // `dl_device` that parses as `(i32, i32)` must match this array's device;
    // values that do not parse are ignored.
    #[pyo3(signature=(stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        let (dev_type, alloc_dev) = self.__dlpack_device__();

        let requested_device = dl_device
            .as_ref()
            .and_then(|obj| obj.extract::<(i32, i32)>(py).ok());
        if let Some((want_type, want_id)) = requested_device {
            let same_device = want_type == dev_type && want_id == alloc_dev;
            if !same_device {
                return Err(PyValueError::new_err(
                    "zscore: dl_device mismatch; cross-device copy not implemented",
                ));
            }
        }
        let _ = stream;
        let _ = copy;

        // Swap an empty placeholder in so the real buffer can be moved out.
        let placeholder = DeviceArrayF32 {
            buf: DeviceBuffer::from_slice(&[])
                .map_err(|e| PyValueError::new_err(e.to_string()))?,
            rows: 0,
            cols: 0,
        };
        let DeviceArrayF32 { buf, rows, cols } = std::mem::replace(&mut self.inner, placeholder);

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}

#[cfg(all(feature = "python", feature = "cuda"))]
impl ZscoreDeviceArrayF32Py {
    /// Wraps a device array produced on the Rust side together with the context
    /// handle and device ordinal it belongs to.
    pub fn new_from_rust(inner: DeviceArrayF32, ctx_guard: Arc<Context>, device_id: u32) -> Self {
        Self {
            _device_id: device_id,
            _ctx_guard: ctx_guard,
            inner,
        }
    }
}
2203
// Python binding `zscore_cuda_batch_dev(data_f32, period_range, nbdev_range=(1.0, 1.0, 0.0), device_id=0)`.
//
// Runs the CUDA batch z-score for an SMA / devtype-0 sweep and returns the device
// array together with a dict describing each row ("periods", "nbdevs", "ma_types",
// "devtypes"). Fails with `ValueError` when CUDA is unavailable or the CUDA call
// reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "zscore_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, nbdev_range=(1.0, 1.0, 0.0), device_id=0))]
pub fn zscore_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    data_f32: numpy::PyReadonlyArray1<'py, f32>,
    period_range: (usize, usize, usize),
    nbdev_range: (f64, f64, f64),
    device_id: usize,
) -> PyResult<(ZscoreDeviceArrayF32Py, Bound<'py, PyDict>)> {
    use crate::cuda::cuda_available;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let host_data = data_f32.as_slice()?;
    let sweep = ZscoreBatchRange {
        period: period_range,
        ma_type: (String::from("sma"), String::from("sma"), String::new()),
        nbdev: nbdev_range,
        devtype: (0, 0, 0),
    };

    // Device work happens with the GIL released.
    let launched = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaZscore::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id();
        let (arr, combos) = cuda
            .zscore_batch_dev(host_data, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((arr, ctx, dev_id, combos))
    });
    let (device_array, ctx, dev_id, combos) = launched?;

    // Per-row metadata; the moving-average type and devtype are fixed for this entry point.
    let row_count = combos.len();
    let (periods, nbdevs): (Vec<u64>, Vec<f64>) = combos
        .iter()
        .map(|(p, nb)| (*p as u64, *nb as f64))
        .unzip();
    let devtypes: Vec<u64> = vec![0u64; row_count];
    let ma_types = PyList::new(py, vec!["sma"; row_count])?;

    let meta = PyDict::new(py);
    meta.set_item("periods", periods.into_pyarray(py))?;
    meta.set_item("nbdevs", nbdevs.into_pyarray(py))?;
    meta.set_item("ma_types", ma_types)?;
    meta.set_item("devtypes", devtypes.into_pyarray(py))?;

    let wrapped = ZscoreDeviceArrayF32Py::new_from_rust(device_array, ctx, dev_id);
    Ok((wrapped, meta))
}
2254
// Python binding `zscore_cuda_many_series_one_param_dev(data_tm_f32, cols, rows, period, nbdev=1.0, device_id=0)`.
//
// Runs the CUDA z-score over a flat buffer passed to the time-major many-series
// kernel, using a single `period` / `nbdev` pair. `nbdev` must be finite and
// non-negative; it is narrowed to `f32` for the device call. Fails with `ValueError`
// when CUDA is unavailable, `nbdev` is rejected, or the CUDA call reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "zscore_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, cols, rows, period, nbdev=1.0, device_id=0))]
pub fn zscore_cuda_many_series_one_param_dev_py<'py>(
    py: Python<'py>,
    data_tm_f32: numpy::PyReadonlyArray1<'py, f32>,
    cols: usize,
    rows: usize,
    period: usize,
    nbdev: f64,
    device_id: usize,
) -> PyResult<ZscoreDeviceArrayF32Py> {
    use crate::cuda::cuda_available;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let nbdev_is_valid = nbdev >= 0.0 && nbdev.is_finite();
    if !nbdev_is_valid {
        return Err(PyValueError::new_err(
            "nbdev must be non-negative and finite",
        ));
    }

    let host_data = data_tm_f32.as_slice()?;

    // Device work happens with the GIL released.
    let launched = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaZscore::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id();
        let arr = cuda
            .zscore_many_series_one_param_time_major_dev(host_data, cols, rows, period, nbdev as f32)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((arr, ctx, dev_id))
    });
    let (device_array, ctx, dev_id) = launched?;

    Ok(ZscoreDeviceArrayF32Py::new_from_rust(device_array, ctx, dev_id))
}
2291
// Python class `ZscoreStream`: a thin wrapper around the Rust `ZscoreStream`.
#[cfg(feature = "python")]
#[pyclass(name = "ZscoreStream")]
pub struct ZscoreStreamPy {
    stream: ZscoreStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl ZscoreStreamPy {
    // Constructor; parameter validation is delegated to `ZscoreStream::try_new`, whose
    // error is surfaced as `ValueError`.
    #[new]
    fn new(period: usize, ma_type: &str, nbdev: f64, devtype: usize) -> PyResult<Self> {
        let params = ZscoreParams {
            period: Some(period),
            ma_type: Some(ma_type.to_string()),
            nbdev: Some(nbdev),
            devtype: Some(devtype),
        };
        match ZscoreStream::try_new(params) {
            Ok(stream) => Ok(Self { stream }),
            Err(err) => Err(PyValueError::new_err(err.to_string())),
        }
    }

    // Feeds one value to the underlying stream and returns whatever it yields.
    fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
2318
// Python binding `zscore_batch(data, period_range, ma_type="sma", nbdev_range=(1.0, 1.0, 0.0), devtype_range=(0, 0, 0), kernel=None)`.
//
// Expands the parameter sweep, allocates a flat NumPy buffer of `rows * cols`
// elements, fills it with the GIL released, and returns a dict with the reshaped
// "values" matrix plus the per-row "periods", "ma_types", "nbdevs" and "devtypes".
// Errors from the sweep expansion or the computation surface as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "zscore_batch")]
#[pyo3(signature = (data, period_range, ma_type="sma", nbdev_range=(1.0, 1.0, 0.0), devtype_range=(0, 0, 0), kernel=None))]
pub fn zscore_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    ma_type: &str,
    nbdev_range: (f64, f64, f64),
    devtype_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let samples = data.as_slice()?;

    let sweep = ZscoreBatchRange {
        period: period_range,
        ma_type: (ma_type.to_string(), ma_type.to_string(), "".to_string()),
        nbdev: nbdev_range,
        devtype: devtype_range,
    };

    // The grid is expanded once here only to size the output buffer.
    let rows = expand_grid(&sweep)
        .map_err(|e| PyValueError::new_err(e.to_string()))?
        .len();
    let cols = samples.len();

    let total = match rows.checked_mul(cols) {
        Some(count) => count,
        None => return Err(PyValueError::new_err("zscore_batch: rows*cols overflow")),
    };
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let out_slice = unsafe { out_arr.as_slice_mut()? };

    let kern = validate_kernel(kernel, true)?;

    // Resolve the batch kernel to its per-row SIMD counterpart and compute with the
    // GIL released.
    let outcome = py.allow_threads(|| {
        let batch_kernel = match kern {
            Kernel::Auto => detect_best_batch_kernel(),
            explicit => explicit,
        };
        let row_kernel = match batch_kernel {
            Kernel::Avx512Batch => Kernel::Avx512,
            Kernel::Avx2Batch => Kernel::Avx2,
            Kernel::ScalarBatch => Kernel::Scalar,
            _ => unreachable!(),
        };
        zscore_batch_inner_into(samples, &sweep, row_kernel, true, out_slice)
    });
    let combos = outcome.map_err(|e| PyValueError::new_err(e.to_string()))?;

    let result = PyDict::new(py);
    result.set_item("values", out_arr.reshape((rows, cols))?)?;

    let periods: Vec<u64> = combos.iter().map(|p| p.period.unwrap() as u64).collect();
    result.set_item("periods", periods.into_pyarray(py))?;

    let ma_names: Vec<String> = combos
        .iter()
        .map(|p| p.ma_type.as_ref().unwrap().clone())
        .collect();
    result.set_item("ma_types", PyList::new(py, ma_names)?)?;

    let nbdevs: Vec<f64> = combos.iter().map(|p| p.nbdev.unwrap()).collect();
    result.set_item("nbdevs", nbdevs.into_pyarray(py))?;

    let devtypes: Vec<u64> = combos.iter().map(|p| p.devtype.unwrap() as u64).collect();
    result.set_item("devtypes", devtypes.into_pyarray(py))?;

    Ok(result)
}
2407
2408pub fn zscore_into_slice(
2409 dst: &mut [f64],
2410 input: &ZscoreInput,
2411 kern: Kernel,
2412) -> Result<(), ZscoreError> {
2413 let data: &[f64] = input.as_ref();
2414 if data.is_empty() {
2415 return Err(ZscoreError::EmptyInputData);
2416 }
2417 if dst.len() != data.len() {
2418 return Err(ZscoreError::OutputLengthMismatch {
2419 expected: data.len(),
2420 got: dst.len(),
2421 });
2422 }
2423
2424 let first = data
2425 .iter()
2426 .position(|x| !x.is_nan())
2427 .ok_or(ZscoreError::AllValuesNaN)?;
2428 let len = data.len();
2429 let period = input.get_period();
2430 if period == 0 || period > len {
2431 return Err(ZscoreError::InvalidPeriod {
2432 period,
2433 data_len: len,
2434 });
2435 }
2436 if (len - first) < period {
2437 return Err(ZscoreError::NotEnoughValidData {
2438 needed: period,
2439 valid: len - first,
2440 });
2441 }
2442
2443 let ma_type = input.get_ma_type();
2444 let nbdev = input.get_nbdev();
2445 let devtype = input.get_devtype();
2446
2447 let chosen = match kern {
2448 Kernel::Auto => match detect_best_kernel() {
2449 Kernel::Avx512 | Kernel::Avx512Batch => Kernel::Avx2,
2450 other => other,
2451 },
2452 other => other,
2453 };
2454
2455 unsafe {
2456 match chosen {
2457 Kernel::Scalar | Kernel::ScalarBatch => {
2458 zscore_compute_into_scalar(data, period, first, &ma_type, nbdev, devtype, dst)
2459 }
2460 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2461 Kernel::Avx2 | Kernel::Avx2Batch => {
2462 zscore_compute_into_avx2(data, period, first, &ma_type, nbdev, devtype, dst)
2463 }
2464 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2465 Kernel::Avx512 | Kernel::Avx512Batch => {
2466 zscore_compute_into_avx512(data, period, first, &ma_type, nbdev, devtype, dst)
2467 }
2468 _ => {
2469 return Err(ZscoreError::InvalidPeriod {
2470 period: 0,
2471 data_len: 0,
2472 })
2473 }
2474 }
2475 }?;
2476
2477 Ok(())
2478}
2479
/// Native (non-wasm) convenience wrapper around [`zscore_into_slice`].
///
/// Writes the z-score of `input` into `out` using `Kernel::Auto`, and
/// forwards whatever `zscore_into_slice` returns.
#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
pub fn zscore_into(input: &ZscoreInput, out: &mut [f64]) -> Result<(), ZscoreError> {
    zscore_into_slice(out, input, Kernel::Auto)
}
2484
/// Scalar z-score kernel: writes `(x - mean) / (deviation * nbdev)` into `out`.
///
/// `out[..first + period - 1]` is filled with a quiet NaN. From that index on,
/// each output is NaN whenever the scaled deviation is zero or NaN.
///
/// Two fast paths exist for `devtype == 0`:
/// * `ma_type == "sma"`: rolling sum / sum of squares, population standard
///   deviation around the rolling mean.
/// * `ma_type == "ema"`: EMA seeded with the mean of the first window, with the
///   deviation taken as the root mean squared distance of the window from the EMA.
///
/// Every other combination goes through the generic `ma` and `deviation`
/// helpers.
///
/// # Safety
/// The fast paths index `data` and `out` without bounds checks up to
/// `data.len() - 1`, so `out.len()` must be at least `data.len()`, and `period`
/// must be non-zero with `first + period - 1` not overflowing.
///
/// # Errors
/// Returns `ZscoreError::MaError` if the moving average fails, or the converted
/// error from `deviation`, on the generic path.
#[inline]
unsafe fn zscore_compute_into_scalar(
    data: &[f64],
    period: usize,
    first: usize,
    ma_type: &str,
    nbdev: f64,
    devtype: usize,
    out: &mut [f64],
) -> Result<(), ZscoreError> {
    // Index of the last sample of the first complete window.
    let warmup_end = first + period - 1;
    // Warm-up prefix gets the canonical quiet-NaN bit pattern.
    for v in &mut out[..warmup_end] {
        *v = f64::from_bits(0x7ff8_0000_0000_0000);
    }

    // Not even one full window available: nothing more to write.
    if data.len() <= warmup_end {
        return Ok(());
    }

    if devtype == 0 {
        if ma_type == "sma" {
            let inv = 1.0 / (period as f64);
            let mut sum = 0.0f64;
            let mut sum_sqr = 0.0f64;
            {
                // Seed the running sum and sum of squares with the first window.
                let mut j = first;
                while j <= warmup_end {
                    let v = *data.get_unchecked(j);
                    sum += v;
                    sum_sqr = v.mul_add(v, sum_sqr);
                    j += 1;
                }
            }
            let mut mean = sum * inv;
            // variance = E[x^2] - mean^2, clamped because rounding can push it
            // slightly below zero.
            let mut variance = (-mean).mul_add(mean, sum_sqr * inv);
            if variance < 0.0 {
                variance = 0.0;
            }
            let mut sd = if variance == 0.0 {
                0.0
            } else {
                variance.sqrt() * nbdev
            };

            let xw = *data.get_unchecked(warmup_end);
            *out.get_unchecked_mut(warmup_end) = if sd == 0.0 || sd.is_nan() {
                f64::NAN
            } else {
                (xw - mean) / sd
            };

            let n = data.len();
            let mut i = warmup_end + 1;
            while i < n {
                // Slide the window: drop data[i - period], add data[i].
                let old_val = *data.get_unchecked(i - period);
                let new_val = *data.get_unchecked(i);
                let dd = new_val - old_val;
                sum += dd;
                // (new + old) * (new - old) == new^2 - old^2.
                sum_sqr = (new_val + old_val).mul_add(dd, sum_sqr);
                mean = sum * inv;

                variance = (-mean).mul_add(mean, sum_sqr * inv);
                if variance < 0.0 {
                    variance = 0.0;
                }
                sd = if variance == 0.0 {
                    0.0
                } else {
                    variance.sqrt() * nbdev
                };

                *out.get_unchecked_mut(i) = if sd == 0.0 || sd.is_nan() {
                    f64::NAN
                } else {
                    (new_val - mean) / sd
                };
                i += 1;
            }

            return Ok(());
        }

        if ma_type == "ema" {
            let den = period as f64;
            let inv = 1.0 / den;
            let alpha = 2.0 / (den + 1.0);
            let one_minus_alpha = 1.0 - alpha;

            let mut sum = 0.0f64;
            let mut sum2 = 0.0f64;
            {
                // Seed the window sums with the first complete window.
                let mut j = first;
                while j <= warmup_end {
                    let v = *data.get_unchecked(j);
                    sum += v;
                    sum2 = v.mul_add(v, sum2);
                    j += 1;
                }
            }
            // The EMA starts from the simple mean of the first window.
            let mut ema = sum * inv;

            let mut ex = sum * inv;
            let mut ex2 = sum2 * inv;
            // Mean squared distance of the window from the EMA:
            // E[x^2] - 2 * ema * E[x] + ema^2.
            let mut mse = (-2.0 * ema).mul_add(ex, ema.mul_add(ema, ex2));
            if mse < 0.0 {
                mse = 0.0;
            }
            let mut sd = mse.sqrt() * nbdev;

            let xw = *data.get_unchecked(warmup_end);
            *out.get_unchecked_mut(warmup_end) = if sd == 0.0 || sd.is_nan() {
                f64::NAN
            } else {
                (xw - ema) / sd
            };

            let n = data.len();
            let mut i = warmup_end + 1;
            while i < n {
                let new = *data.get_unchecked(i);
                let old = *data.get_unchecked(i - period);

                // Update the windowed first and second moments incrementally.
                let dd = new - old;
                sum += dd;
                sum2 = (new + old).mul_add(dd, sum2);
                ex = sum * inv;
                ex2 = sum2 * inv;

                // ema = ema * (1 - alpha) + alpha * new.
                ema = ema.mul_add(one_minus_alpha, alpha * new);

                mse = (-2.0 * ema).mul_add(ex, ema.mul_add(ema, ex2));
                if mse < 0.0 {
                    mse = 0.0;
                }
                sd = mse.sqrt() * nbdev;

                *out.get_unchecked_mut(i) = if sd == 0.0 || sd.is_nan() {
                    f64::NAN
                } else {
                    (new - ema) / sd
                };
                i += 1;
            }

            return Ok(());
        }
    }

    // Generic path: delegate the mean and the deviation to the shared helpers.
    let means = ma(ma_type, MaData::Slice(data), period)
        .map_err(|e| ZscoreError::MaError(e.to_string()))?;
    let dev_input = DevInput {
        data: DeviationData::Slice(data),
        params: DevParams {
            period: Some(period),
            devtype: Some(devtype),
        },
    };
    let mut sigmas = deviation(&dev_input)?.values;
    // Scale the deviation once up front instead of inside the output loop.
    for v in &mut sigmas {
        *v *= nbdev;
    }

    for i in warmup_end..data.len() {
        let mean = means[i];
        let sigma = sigmas[i];
        let value = data[i];
        out[i] = if sigma == 0.0 || sigma.is_nan() {
            f64::NAN
        } else {
            (value - mean) / sigma
        };
    }
    Ok(())
}
2659
/// AVX2 entry point for the z-score kernel.
///
/// This body contains no AVX2-specific code: it forwards every argument
/// unchanged to [`zscore_compute_into_scalar`] and returns its result.
///
/// # Safety
/// Same requirements as `zscore_compute_into_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
unsafe fn zscore_compute_into_avx2(
    data: &[f64],
    period: usize,
    first: usize,
    ma_type: &str,
    nbdev: f64,
    devtype: usize,
    out: &mut [f64],
) -> Result<(), ZscoreError> {
    zscore_compute_into_scalar(data, period, first, ma_type, nbdev, devtype, out)
}
2673
/// AVX-512 entry point for the z-score kernel.
///
/// This body contains no AVX-512-specific code: it forwards every argument
/// unchanged to [`zscore_compute_into_scalar`] and returns its result.
///
/// # Safety
/// Same requirements as `zscore_compute_into_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
unsafe fn zscore_compute_into_avx512(
    data: &[f64],
    period: usize,
    first: usize,
    ma_type: &str,
    nbdev: f64,
    devtype: usize,
    out: &mut [f64],
) -> Result<(), ZscoreError> {
    zscore_compute_into_scalar(data, period, first, ma_type, nbdev, devtype, out)
}
2687
/// JavaScript-facing z-score: returns a freshly allocated vector with one
/// output per element of `data`.
///
/// All four parameters are passed explicitly; the computation runs through
/// `zscore_into_slice` with `Kernel::Auto`. Any indicator error is converted
/// to a `JsValue` holding the error's display string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn zscore_js(
    data: &[f64],
    period: usize,
    ma_type: &str,
    nbdev: f64,
    devtype: usize,
) -> Result<Vec<f64>, JsValue> {
    let input = ZscoreInput::from_slice(
        data,
        ZscoreParams {
            period: Some(period),
            ma_type: Some(ma_type.to_string()),
            nbdev: Some(nbdev),
            devtype: Some(devtype),
        },
    );

    let mut values = vec![0.0; data.len()];
    match zscore_into_slice(&mut values, &input, Kernel::Auto) {
        Ok(()) => Ok(values),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
2712
/// Reserves room for `len` `f64` values and hands the raw pointer to the caller.
///
/// The backing vector is never dropped here, so the allocation stays alive
/// until it is released through `zscore_free` with the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn zscore_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop suppresses the destructor, leaking the buffer on purpose.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
2721
/// Releases a buffer previously handed out as a raw `f64` pointer.
///
/// A null `ptr` is ignored. Otherwise the memory is re-adopted as a vector of
/// length and capacity `len` and dropped immediately.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn zscore_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: relies on the caller passing a pointer and `len` that describe an
    // allocation made with exactly that capacity (as `zscore_alloc` requests).
    unsafe {
        drop(Vec::from_raw_parts(ptr, len, len));
    }
}
2731
/// Pointer-based wasm entry point: reads `len` values from `in_ptr` and writes
/// `len` z-scores to `out_ptr`.
///
/// Either pointer being null yields a `"Null pointer provided"` error. When
/// both pointers are equal the result is computed into a scratch vector first
/// and copied back afterwards, so in-place use is supported. Indicator errors
/// are returned as their display string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn zscore_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
    ma_type: &str,
    nbdev: f64,
    devtype: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    let params = ZscoreParams {
        period: Some(period),
        ma_type: Some(ma_type.to_string()),
        nbdev: Some(nbdev),
        devtype: Some(devtype),
    };

    // SAFETY: both pointers were checked for null; the caller is trusted to
    // provide `len` readable values at `in_ptr` and `len` writable slots at
    // `out_ptr`.
    unsafe {
        let source = std::slice::from_raw_parts(in_ptr, len);
        let input = ZscoreInput::from_slice(source, params);

        if in_ptr == out_ptr {
            // Input and output alias: compute aside, then overwrite in one go.
            let mut scratch = vec![0.0; len];
            zscore_into_slice(&mut scratch, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            std::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&scratch);
        } else {
            let target = std::slice::from_raw_parts_mut(out_ptr, len);
            zscore_into_slice(target, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(())
}
2772
/// Batch sweep configuration deserialized from the JavaScript side.
///
/// Each `*_range` tuple is copied verbatim into the matching field of
/// `ZscoreBatchRange`, whose tuples are laid out as `(start, end, step)`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct ZscoreBatchConfig {
    /// Window lengths to sweep.
    pub period_range: (usize, usize, usize),
    /// Single moving-average type applied to every combination.
    pub ma_type: String,
    /// Deviation multipliers to sweep.
    pub nbdev_range: (f64, f64, f64),
    /// Deviation type codes to sweep.
    pub devtype_range: (usize, usize, usize),
}
2781
/// Serializable result of a batch run, returned to JavaScript.
///
/// The fields are copied one-to-one from the output of `zscore_batch_inner`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct ZscoreBatchJsOutput {
    /// Flat buffer of computed values for all combinations.
    pub values: Vec<f64>,
    /// Parameter combinations that were evaluated.
    pub combos: Vec<ZscoreParams>,
    /// Row count reported by the batch computation.
    pub rows: usize,
    /// Column count reported by the batch computation.
    pub cols: usize,
}
2790
/// Runs a z-score parameter sweep described by a JS config object.
///
/// `config` is deserialized into `ZscoreBatchConfig`; a malformed object gives
/// an `"Invalid config: ..."` error. The sweep is evaluated by
/// `zscore_batch_inner` with `Kernel::Auto`, and the result is serialized back
/// as a `ZscoreBatchJsOutput`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = zscore_batch)]
pub fn zscore_batch_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let parsed: ZscoreBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;
    let ZscoreBatchConfig {
        period_range,
        ma_type,
        nbdev_range,
        devtype_range,
    } = parsed;

    // The moving-average "range" is a single fixed type: same start and end,
    // empty step.
    let sweep = ZscoreBatchRange {
        period: period_range,
        ma_type: (ma_type.clone(), ma_type, String::new()),
        nbdev: nbdev_range,
        devtype: devtype_range,
    };

    let batch = zscore_batch_inner(data, &sweep, Kernel::Auto, false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let payload = ZscoreBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };

    serde_wasm_bindgen::to_value(&payload).map_err(|e| JsValue::from_str(&e.to_string()))
}
2820
/// Pointer-based batch sweep for wasm callers.
///
/// Reads `len` inputs from `in_ptr` and writes `combos * len` values to
/// `out_ptr`, where `combos` is the number of parameter combinations produced
/// by `expand_grid`. Returns that combination count.
///
/// Errors (as display strings) are produced for null pointers, an invalid
/// sweep, an overflowing `combos * len`, or a failure in the batch kernel.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn zscore_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
    ma_type: &str,
    nbdev_start: f64,
    nbdev_end: f64,
    nbdev_step: f64,
    devtype_start: usize,
    devtype_end: usize,
    devtype_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    let sweep = ZscoreBatchRange {
        period: (period_start, period_end, period_step),
        ma_type: (ma_type.to_string(), ma_type.to_string(), "".to_string()),
        nbdev: (nbdev_start, nbdev_end, nbdev_step),
        devtype: (devtype_start, devtype_end, devtype_step),
    };

    let rows = expand_grid(&sweep)
        .map_err(|e| JsValue::from_str(&e.to_string()))?
        .len();
    let total = match rows.checked_mul(len) {
        Some(cells) => cells,
        None => return Err(JsValue::from_str("zscore_batch_into: rows*cols overflow")),
    };

    // SAFETY: pointers are non-null; the caller is trusted to provide `len`
    // readable inputs and `rows * len` writable outputs.
    unsafe {
        let input = std::slice::from_raw_parts(in_ptr, len);
        let output = std::slice::from_raw_parts_mut(out_ptr, total);

        zscore_batch_inner_into(input, &sweep, detect_best_kernel(), false, output)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(rows)
}
2866
2867#[cfg(test)]
2868mod tests {
2869 use super::*;
2870 use crate::skip_if_unsupported;
2871 use crate::utilities::data_loader::read_candles_from_csv;
2872 fn check_zscore_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2873 skip_if_unsupported!(kernel, test_name);
2874 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2875 let candles = read_candles_from_csv(file_path)?;
2876 let default_params = ZscoreParams {
2877 period: None,
2878 ma_type: None,
2879 nbdev: None,
2880 devtype: None,
2881 };
2882 let input = ZscoreInput::from_candles(&candles, "close", default_params);
2883 let output = zscore_with_kernel(&input, kernel)?;
2884 assert_eq!(output.values.len(), candles.close.len());
2885 Ok(())
2886 }
2887 fn check_zscore_with_zero_period(
2888 test_name: &str,
2889 kernel: Kernel,
2890 ) -> Result<(), Box<dyn Error>> {
2891 skip_if_unsupported!(kernel, test_name);
2892 let input_data = [10.0, 20.0, 30.0];
2893 let params = ZscoreParams {
2894 period: Some(0),
2895 ma_type: None,
2896 nbdev: None,
2897 devtype: None,
2898 };
2899 let input = ZscoreInput::from_slice(&input_data, params);
2900 let res = zscore_with_kernel(&input, kernel);
2901 assert!(
2902 res.is_err(),
2903 "[{}] Zscore should fail with zero period",
2904 test_name
2905 );
2906 Ok(())
2907 }
2908 fn check_zscore_period_exceeds_length(
2909 test_name: &str,
2910 kernel: Kernel,
2911 ) -> Result<(), Box<dyn Error>> {
2912 skip_if_unsupported!(kernel, test_name);
2913 let data_small = [10.0, 20.0, 30.0];
2914 let params = ZscoreParams {
2915 period: Some(10),
2916 ma_type: None,
2917 nbdev: None,
2918 devtype: None,
2919 };
2920 let input = ZscoreInput::from_slice(&data_small, params);
2921 let res = zscore_with_kernel(&input, kernel);
2922 assert!(
2923 res.is_err(),
2924 "[{}] Zscore should fail with period exceeding length",
2925 test_name
2926 );
2927 Ok(())
2928 }
2929 fn check_zscore_very_small_dataset(
2930 test_name: &str,
2931 kernel: Kernel,
2932 ) -> Result<(), Box<dyn Error>> {
2933 skip_if_unsupported!(kernel, test_name);
2934 let single_point = [42.0];
2935 let params = ZscoreParams {
2936 period: Some(14),
2937 ma_type: None,
2938 nbdev: None,
2939 devtype: None,
2940 };
2941 let input = ZscoreInput::from_slice(&single_point, params);
2942 let res = zscore_with_kernel(&input, kernel);
2943 assert!(
2944 res.is_err(),
2945 "[{}] Zscore should fail with insufficient data",
2946 test_name
2947 );
2948 Ok(())
2949 }
2950 fn check_zscore_all_nan(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2951 skip_if_unsupported!(kernel, test_name);
2952 let input_data = [f64::NAN, f64::NAN, f64::NAN];
2953 let params = ZscoreParams::default();
2954 let input = ZscoreInput::from_slice(&input_data, params);
2955 let res = zscore_with_kernel(&input, kernel);
2956 assert!(
2957 res.is_err(),
2958 "[{}] Zscore should fail when all values are NaN",
2959 test_name
2960 );
2961 Ok(())
2962 }
2963 fn check_zscore_input_with_default_candles(
2964 test_name: &str,
2965 kernel: Kernel,
2966 ) -> Result<(), Box<dyn Error>> {
2967 skip_if_unsupported!(kernel, test_name);
2968 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2969 let candles = read_candles_from_csv(file_path)?;
2970 let input = ZscoreInput::with_default_candles(&candles);
2971 match input.data {
2972 ZscoreData::Candles { source, .. } => assert_eq!(source, "close"),
2973 _ => panic!("Expected ZscoreData::Candles"),
2974 }
2975 Ok(())
2976 }
2977 fn check_zscore_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2978 skip_if_unsupported!(kernel, test_name);
2979 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2980 let candles = read_candles_from_csv(file_path)?;
2981
2982 let input = ZscoreInput::from_candles(&candles, "close", ZscoreParams::default());
2983 let result = zscore_with_kernel(&input, kernel)?;
2984
2985 let expected_last_five = [
2986 -0.3040683926967643,
2987 -0.41042159719064014,
2988 -0.5411993612192193,
2989 -0.1673226261513698,
2990 -1.431635486349618,
2991 ];
2992 let start = result.values.len().saturating_sub(5);
2993
2994 for (i, &val) in result.values[start..].iter().enumerate() {
2995 let diff = (val - expected_last_five[i]).abs();
2996 assert!(
2997 diff < 1e-8,
2998 "[{}] Zscore {:?} mismatch at idx {}: got {}, expected {}",
2999 test_name,
3000 kernel,
3001 i,
3002 val,
3003 expected_last_five[i]
3004 );
3005 }
3006 Ok(())
3007 }
3008 macro_rules! generate_all_zscore_tests {
3009 ($($test_fn:ident),*) => {
3010 paste::paste! {
3011 $(
3012 #[test]
3013 fn [<$test_fn _scalar_f64>]() {
3014 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
3015 }
3016 )*
3017 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3018 $(
3019 #[test]
3020 fn [<$test_fn _avx2_f64>]() {
3021 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
3022 }
3023 #[test]
3024 fn [<$test_fn _avx512_f64>]() {
3025 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
3026 }
3027 )*
3028 }
3029 }
3030 }
3031 #[cfg(debug_assertions)]
3032 fn check_zscore_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
3033 skip_if_unsupported!(kernel, test_name);
3034
3035 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3036 let candles = read_candles_from_csv(file_path)?;
3037
3038 let test_params = vec![
3039 ZscoreParams::default(),
3040 ZscoreParams {
3041 period: Some(2),
3042 ma_type: Some("sma".to_string()),
3043 nbdev: Some(1.0),
3044 devtype: Some(0),
3045 },
3046 ZscoreParams {
3047 period: Some(5),
3048 ma_type: Some("ema".to_string()),
3049 nbdev: Some(1.0),
3050 devtype: Some(0),
3051 },
3052 ZscoreParams {
3053 period: Some(10),
3054 ma_type: Some("wma".to_string()),
3055 nbdev: Some(2.0),
3056 devtype: Some(0),
3057 },
3058 ZscoreParams {
3059 period: Some(20),
3060 ma_type: Some("sma".to_string()),
3061 nbdev: Some(1.5),
3062 devtype: Some(1),
3063 },
3064 ZscoreParams {
3065 period: Some(30),
3066 ma_type: Some("ema".to_string()),
3067 nbdev: Some(2.5),
3068 devtype: Some(2),
3069 },
3070 ZscoreParams {
3071 period: Some(50),
3072 ma_type: Some("wma".to_string()),
3073 nbdev: Some(3.0),
3074 devtype: Some(0),
3075 },
3076 ZscoreParams {
3077 period: Some(100),
3078 ma_type: Some("sma".to_string()),
3079 nbdev: Some(1.0),
3080 devtype: Some(1),
3081 },
3082 ZscoreParams {
3083 period: Some(14),
3084 ma_type: Some("ema".to_string()),
3085 nbdev: Some(0.5),
3086 devtype: Some(0),
3087 },
3088 ZscoreParams {
3089 period: Some(14),
3090 ma_type: Some("sma".to_string()),
3091 nbdev: Some(0.1),
3092 devtype: Some(2),
3093 },
3094 ZscoreParams {
3095 period: Some(25),
3096 ma_type: Some("wma".to_string()),
3097 nbdev: Some(4.0),
3098 devtype: Some(1),
3099 },
3100 ZscoreParams {
3101 period: Some(7),
3102 ma_type: Some("ema".to_string()),
3103 nbdev: Some(1.618),
3104 devtype: Some(0),
3105 },
3106 ZscoreParams {
3107 period: Some(21),
3108 ma_type: Some("sma".to_string()),
3109 nbdev: Some(2.718),
3110 devtype: Some(1),
3111 },
3112 ZscoreParams {
3113 period: Some(42),
3114 ma_type: Some("wma".to_string()),
3115 nbdev: Some(3.14159),
3116 devtype: Some(2),
3117 },
3118 ];
3119
3120 for (param_idx, params) in test_params.iter().enumerate() {
3121 let input = ZscoreInput::from_candles(&candles, "close", params.clone());
3122 let output = zscore_with_kernel(&input, kernel)?;
3123
3124 for (i, &val) in output.values.iter().enumerate() {
3125 if val.is_nan() {
3126 continue;
3127 }
3128
3129 let bits = val.to_bits();
3130
3131 if bits == 0x11111111_11111111 {
3132 panic!(
3133 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
3134 with params: period={}, ma_type={}, nbdev={}, devtype={} (param set {})",
3135 test_name,
3136 val,
3137 bits,
3138 i,
3139 params.period.unwrap_or(14),
3140 params.ma_type.as_deref().unwrap_or("sma"),
3141 params.nbdev.unwrap_or(1.0),
3142 params.devtype.unwrap_or(0),
3143 param_idx
3144 );
3145 }
3146
3147 if bits == 0x22222222_22222222 {
3148 panic!(
3149 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
3150 with params: period={}, ma_type={}, nbdev={}, devtype={} (param set {})",
3151 test_name,
3152 val,
3153 bits,
3154 i,
3155 params.period.unwrap_or(14),
3156 params.ma_type.as_deref().unwrap_or("sma"),
3157 params.nbdev.unwrap_or(1.0),
3158 params.devtype.unwrap_or(0),
3159 param_idx
3160 );
3161 }
3162
3163 if bits == 0x33333333_33333333 {
3164 panic!(
3165 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
3166 with params: period={}, ma_type={}, nbdev={}, devtype={} (param set {})",
3167 test_name,
3168 val,
3169 bits,
3170 i,
3171 params.period.unwrap_or(14),
3172 params.ma_type.as_deref().unwrap_or("sma"),
3173 params.nbdev.unwrap_or(1.0),
3174 params.devtype.unwrap_or(0),
3175 param_idx
3176 );
3177 }
3178 }
3179 }
3180
3181 Ok(())
3182 }
3183
    /// Release-build stand-in for the poison check: does nothing and succeeds,
    /// so the generated test names exist in both build profiles.
    #[cfg(not(debug_assertions))]
    fn check_zscore_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
3188
    /// Property test (only with the `proptest` feature) comparing `kernel`
    /// against the scalar kernel on random inputs.
    ///
    /// Generated cases: period in 2..=64, between `period + 10` and 399 finite
    /// values in [-1e5, 1e5), an MA type among sma/ema/wma, `nbdev` in
    /// [0.5, 3.0) and `devtype` in 0..=2.
    ///
    /// Checked properties:
    /// * output length equals input length;
    /// * the first `period - 1` outputs are NaN;
    /// * every later output matches the scalar reference (bit-exact for
    ///   non-finite values, otherwise within 1e-9 or 4 ULP);
    /// * constant data yields NaN when `devtype == 0`;
    /// * for period 2 / sma / devtype 0, finite outputs match a direct
    ///   two-point computation within 1e-6.
    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_zscore_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        // The data length depends on the period, hence the flat_map.
        let strat = (2usize..=64).prop_flat_map(|period| {
            (
                prop::collection::vec(
                    (-1e5f64..1e5f64).prop_filter("finite", |x| x.is_finite()),
                    period + 10..400,
                ),
                Just(period),
                prop::sample::select(vec!["sma", "ema", "wma"]),
                0.5f64..3.0f64,
                0usize..=2,
            )
        });

        proptest::test_runner::TestRunner::default().run(
            &strat,
            |(data, period, ma_type, nbdev, devtype)| {
                let params = ZscoreParams {
                    period: Some(period),
                    ma_type: Some(ma_type.to_string()),
                    nbdev: Some(nbdev),
                    devtype: Some(devtype),
                };
                let input = ZscoreInput::from_slice(&data, params.clone());

                // Kernel under test.
                let ZscoreOutput { values: out } = zscore_with_kernel(&input, kernel)?;

                // Scalar reference for the same input.
                let ZscoreOutput { values: ref_out } = zscore_with_kernel(&input, Kernel::Scalar)?;

                prop_assert_eq!(out.len(), data.len(), "Output length mismatch");

                // Generated data contains no NaN, so warm-up is exactly period - 1.
                for i in 0..(period - 1) {
                    prop_assert!(
                        out[i].is_nan(),
                        "Expected NaN during warmup at index {}, got {}",
                        i,
                        out[i]
                    );
                }

                for i in (period - 1)..data.len() {
                    let y = out[i];
                    let r = ref_out[i];

                    if !y.is_finite() || !r.is_finite() {
                        // Non-finite results must agree bit for bit.
                        prop_assert_eq!(
                            y.to_bits(),
                            r.to_bits(),
                            "NaN/infinite mismatch at index {}: {} vs {}",
                            i,
                            y,
                            r
                        );
                    } else {
                        let y_bits = y.to_bits();
                        let r_bits = r.to_bits();
                        let ulp_diff = y_bits.abs_diff(r_bits);

                        // Accept either a small absolute error or a few ULPs.
                        prop_assert!(
                            (y - r).abs() <= 1e-9 || ulp_diff <= 4,
                            "Kernel mismatch at index {}: {} vs {} (ULP={})",
                            i,
                            y,
                            r,
                            ulp_diff
                        );
                    }
                }

                // Constant input has zero spread, so devtype 0 must give NaN.
                if data.windows(2).all(|w| (w[0] - w[1]).abs() < f64::EPSILON) {
                    for i in (period - 1)..data.len() {
                        prop_assert!(
                            out[i].is_nan() || devtype != 0,
                            "Expected NaN for constant data with stddev at index {}, got {}",
                            i,
                            out[i]
                        );
                    }
                }

                // Closed-form cross-check for the smallest window.
                if period == 2 && devtype == 0 && ma_type == "sma" {
                    for i in 1..data.len() {
                        if out[i].is_finite() {
                            let mean = (data[i - 1] + data[i]) / 2.0;
                            let diff1 = (data[i - 1] - mean).powi(2);
                            let diff2 = (data[i] - mean).powi(2);
                            let variance = (diff1 + diff2) / 2.0;
                            let stddev = variance.sqrt();

                            if stddev > f64::EPSILON {
                                let expected = (data[i] - mean) / (stddev * nbdev);
                                prop_assert!(
                                    (out[i] - expected).abs() <= 1e-6,
                                    "Zscore calculation mismatch at index {}: {} vs expected {}",
                                    i,
                                    out[i],
                                    expected
                                );
                            }
                        }
                    }
                }

                Ok(())
            },
        )?;

        Ok(())
    }
3306
    // Instantiate the per-kernel test wrappers for every single-series check.
    generate_all_zscore_tests!(
        check_zscore_partial_params,
        check_zscore_with_zero_period,
        check_zscore_period_exceeds_length,
        check_zscore_very_small_dataset,
        check_zscore_all_nan,
        check_zscore_input_with_default_candles,
        check_zscore_accuracy,
        check_zscore_no_poison
    );

    // The property test only exists when the `proptest` feature is enabled.
    #[cfg(feature = "proptest")]
    generate_all_zscore_tests!(check_zscore_property);
3320 fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
3321 skip_if_unsupported!(kernel, test);
3322 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3323 let c = read_candles_from_csv(file)?;
3324 let output = ZscoreBatchBuilder::new()
3325 .kernel(kernel)
3326 .apply_candles(&c, "close")?;
3327 let def = ZscoreParams::default();
3328 let row = output.values_for(&def).expect("default row missing");
3329 assert_eq!(row.len(), c.close.len());
3330 Ok(())
3331 }
3332 macro_rules! gen_batch_tests {
3333 ($fn_name:ident) => {
3334 paste::paste! {
3335 #[test] fn [<$fn_name _scalar>]() {
3336 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
3337 }
3338 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3339 #[test] fn [<$fn_name _avx2>]() {
3340 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
3341 }
3342 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3343 #[test] fn [<$fn_name _avx512>]() {
3344 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
3345 }
3346 #[test] fn [<$fn_name _auto_detect>]() {
3347 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
3348 }
3349 }
3350 };
3351 }
3352 #[cfg(debug_assertions)]
3353 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
3354 skip_if_unsupported!(kernel, test);
3355
3356 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3357 let c = read_candles_from_csv(file)?;
3358
3359 let test_configs = vec![
3360 (2, 10, 2, 0.5, 2.0, 0.5, 0),
3361 (5, 25, 5, 1.0, 3.0, 1.0, 1),
3362 (10, 50, 10, 1.5, 3.5, 1.0, 2),
3363 (2, 5, 1, 0.1, 1.0, 0.3, 0),
3364 (14, 14, 0, 1.0, 4.0, 0.5, 0),
3365 (20, 40, 10, 2.0, 2.0, 0.0, 1),
3366 ];
3367
3368 let ma_types = vec!["sma", "ema", "wma"];
3369
3370 for (
3371 cfg_idx,
3372 &(period_start, period_end, period_step, nbdev_start, nbdev_end, nbdev_step, devtype),
3373 ) in test_configs.iter().enumerate()
3374 {
3375 for ma_type in &ma_types {
3376 let mut builder = ZscoreBatchBuilder::new().kernel(kernel);
3377
3378 if period_step > 0 {
3379 builder = builder.period_range(period_start, period_end, period_step);
3380 } else {
3381 builder = builder.period_static(period_start);
3382 }
3383
3384 if nbdev_step > 0.0 {
3385 builder = builder.nbdev_range(nbdev_start, nbdev_end, nbdev_step);
3386 } else {
3387 builder = builder.nbdev_static(nbdev_start);
3388 }
3389
3390 builder = builder.ma_type_static(ma_type.to_string());
3391
3392 builder = builder.devtype_static(devtype);
3393
3394 let output = builder.apply_candles(&c, "close")?;
3395
3396 for (idx, &val) in output.values.iter().enumerate() {
3397 if val.is_nan() {
3398 continue;
3399 }
3400
3401 let bits = val.to_bits();
3402 let row = idx / output.cols;
3403 let col = idx % output.cols;
3404 let combo = &output.combos[row];
3405
3406 if bits == 0x11111111_11111111 {
3407 panic!(
3408 "[{}] Config {} (MA: {}): Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
3409 at row {} col {} (flat index {}) with params: period={}, ma_type={}, nbdev={}, devtype={}",
3410 test, cfg_idx, ma_type, val, bits, row, col, idx,
3411 combo.period.unwrap_or(14),
3412 combo.ma_type.as_deref().unwrap_or("sma"),
3413 combo.nbdev.unwrap_or(1.0),
3414 combo.devtype.unwrap_or(0)
3415 );
3416 }
3417
3418 if bits == 0x22222222_22222222 {
3419 panic!(
3420 "[{}] Config {} (MA: {}): Found init_matrix_prefixes poison value {} (0x{:016X}) \
3421 at row {} col {} (flat index {}) with params: period={}, ma_type={}, nbdev={}, devtype={}",
3422 test, cfg_idx, ma_type, val, bits, row, col, idx,
3423 combo.period.unwrap_or(14),
3424 combo.ma_type.as_deref().unwrap_or("sma"),
3425 combo.nbdev.unwrap_or(1.0),
3426 combo.devtype.unwrap_or(0)
3427 );
3428 }
3429
3430 if bits == 0x33333333_33333333 {
3431 panic!(
3432 "[{}] Config {} (MA: {}): Found make_uninit_matrix poison value {} (0x{:016X}) \
3433 at row {} col {} (flat index {}) with params: period={}, ma_type={}, nbdev={}, devtype={}",
3434 test, cfg_idx, ma_type, val, bits, row, col, idx,
3435 combo.period.unwrap_or(14),
3436 combo.ma_type.as_deref().unwrap_or("sma"),
3437 combo.nbdev.unwrap_or(1.0),
3438 combo.devtype.unwrap_or(0)
3439 );
3440 }
3441 }
3442 }
3443 }
3444
3445 let devtype_test = ZscoreBatchBuilder::new()
3446 .kernel(kernel)
3447 .period_range(10, 30, 10)
3448 .nbdev_static(2.0)
3449 .ma_type_static("ema")
3450 .devtype_range(0, 2, 1)
3451 .apply_candles(&c, "close")?;
3452
3453 for (idx, &val) in devtype_test.values.iter().enumerate() {
3454 if val.is_nan() {
3455 continue;
3456 }
3457
3458 let bits = val.to_bits();
3459 let row = idx / devtype_test.cols;
3460 let col = idx % devtype_test.cols;
3461 let combo = &devtype_test.combos[row];
3462
3463 if bits == 0x11111111_11111111
3464 || bits == 0x22222222_22222222
3465 || bits == 0x33333333_33333333
3466 {
3467 let poison_type = if bits == 0x11111111_11111111 {
3468 "alloc_with_nan_prefix"
3469 } else if bits == 0x22222222_22222222 {
3470 "init_matrix_prefixes"
3471 } else {
3472 "make_uninit_matrix"
3473 };
3474
3475 panic!(
3476 "[{}] Devtype test: Found {} poison value {} (0x{:016X}) \
3477 at row {} col {} (flat index {}) with params: period={}, ma_type={}, nbdev={}, devtype={}",
3478 test,
3479 poison_type,
3480 val,
3481 bits,
3482 row,
3483 col,
3484 idx,
3485 combo.period.unwrap_or(14),
3486 combo.ma_type.as_deref().unwrap_or("sma"),
3487 combo.nbdev.unwrap_or(1.0),
3488 combo.devtype.unwrap_or(0)
3489 );
3490 }
3491 }
3492
3493 Ok(())
3494 }
3495
    /// Release-build stand-in for the batch poison check: does nothing and
    /// succeeds, so the generated test names exist in both build profiles.
    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
3500
    // Instantiate the per-kernel wrappers for both batch checks.
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
3503
3504 #[test]
3505 fn test_zscore_into_matches_api() -> Result<(), Box<dyn Error>> {
3506 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3507 let candles = read_candles_from_csv(file_path)?;
3508
3509 let input = ZscoreInput::from_candles(&candles, "close", ZscoreParams::default());
3510
3511 let baseline = zscore(&input)?.values;
3512
3513 let mut out = vec![0.0; baseline.len()];
3514 #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
3515 {
3516 zscore_into(&input, &mut out)?;
3517 }
3518 #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
3519 {
3520 zscore_into_slice(&mut out, &input, Kernel::Auto)?;
3521 }
3522
3523 assert_eq!(baseline.len(), out.len());
3524
3525 let eq_or_both_nan = |a: f64, b: f64| -> bool {
3526 (a.is_nan() && b.is_nan()) || (a == b) || ((a - b).abs() <= 1e-9)
3527 };
3528 for i in 0..out.len() {
3529 assert!(
3530 eq_or_both_nan(baseline[i], out[i]),
3531 "Mismatch at index {}: baseline={}, into={}",
3532 i,
3533 baseline[i],
3534 out[i]
3535 );
3536 }
3537 Ok(())
3538 }
3539}