1use crate::indicators::moving_averages::sma::{
2 sma, SmaData, SmaError, SmaInput, SmaOutput, SmaParams,
3};
4use crate::indicators::roc::{roc, RocData, RocError, RocInput, RocOutput, RocParams};
5use crate::utilities::data_loader::{source_type, Candles};
6use crate::utilities::enums::Kernel;
7use crate::utilities::helpers::{
8 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
9 make_uninit_matrix,
10};
11
12#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
13use core::arch::x86_64::*;
14#[cfg(not(target_arch = "wasm32"))]
15use rayon::prelude::*;
16use std::convert::AsRef;
17use std::error::Error;
18use thiserror::Error;
19
20#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
21use js_sys;
22#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
23use serde::{Deserialize, Serialize};
24#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
25use wasm_bindgen::prelude::*;
26
27#[cfg(all(feature = "python", feature = "cuda"))]
28use crate::cuda::oscillators::CudaKst;
29#[cfg(all(feature = "python", feature = "cuda"))]
30use crate::indicators::moving_averages::alma::DeviceArrayF32Py;
31#[cfg(feature = "python")]
32use numpy::PyReadonlyArray1;
33#[cfg(feature = "python")]
34use pyo3::exceptions::PyValueError;
35#[cfg(feature = "python")]
36use pyo3::prelude::*;
37
38#[derive(Debug, Clone)]
39pub enum KstData<'a> {
40 Candles {
41 candles: &'a Candles,
42 source: &'a str,
43 },
44 Slice(&'a [f64]),
45}
46
/// Result of a single KST computation.
///
/// `line` holds the KST values and `signal` their moving average; the
/// warm-up region at the front of each is filled with NaN by the producers.
#[derive(Clone, Debug)]
pub struct KstOutput {
    pub line: Vec<f64>,
    pub signal: Vec<f64>,
}
52
/// Tunable periods for KST.
///
/// Each `sma_periodN` smooths the `roc_periodN` rate-of-change leg, and
/// `signal_period` smooths the resulting line. `None` means "use the default"
/// (see the `Default` impl below).
#[derive(Clone, Copy, Debug)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct KstParams {
    pub sma_period1: Option<usize>,
    pub sma_period2: Option<usize>,
    pub sma_period3: Option<usize>,
    pub sma_period4: Option<usize>,
    pub roc_period1: Option<usize>,
    pub roc_period2: Option<usize>,
    pub roc_period3: Option<usize>,
    pub roc_period4: Option<usize>,
    pub signal_period: Option<usize>,
}

impl Default for KstParams {
    /// Classic KST setup: ROC 10/15/20/30 smoothed by SMA 10/10/10/15,
    /// with a 9-period signal.
    fn default() -> Self {
        let (roc1, roc2, roc3, roc4) = (10, 15, 20, 30);
        let (sma1, sma2, sma3, sma4) = (10, 10, 10, 15);
        Self {
            roc_period1: Some(roc1),
            roc_period2: Some(roc2),
            roc_period3: Some(roc3),
            roc_period4: Some(roc4),
            sma_period1: Some(sma1),
            sma_period2: Some(sma2),
            sma_period3: Some(sma3),
            sma_period4: Some(sma4),
            signal_period: Some(9),
        }
    }
}
85
86#[derive(Debug, Clone)]
87pub struct KstInput<'a> {
88 pub data: KstData<'a>,
89 pub params: KstParams,
90}
91
92impl<'a> AsRef<[f64]> for KstInput<'a> {
93 #[inline(always)]
94 fn as_ref(&self) -> &[f64] {
95 match &self.data {
96 KstData::Slice(slice) => slice,
97 KstData::Candles { candles, source } => source_type(candles, source),
98 }
99 }
100}
101
102impl<'a> KstInput<'a> {
103 #[inline]
104 pub fn from_candles(c: &'a Candles, s: &'a str, p: KstParams) -> Self {
105 Self {
106 data: KstData::Candles {
107 candles: c,
108 source: s,
109 },
110 params: p,
111 }
112 }
113 #[inline]
114 pub fn from_slice(sl: &'a [f64], p: KstParams) -> Self {
115 Self {
116 data: KstData::Slice(sl),
117 params: p,
118 }
119 }
120 #[inline]
121 pub fn with_default_candles(c: &'a Candles) -> Self {
122 Self::from_candles(c, "close", KstParams::default())
123 }
124 #[inline]
125 pub fn get_sma_period1(&self) -> usize {
126 self.params.sma_period1.unwrap_or(10)
127 }
128 #[inline]
129 pub fn get_sma_period2(&self) -> usize {
130 self.params.sma_period2.unwrap_or(10)
131 }
132 #[inline]
133 pub fn get_sma_period3(&self) -> usize {
134 self.params.sma_period3.unwrap_or(10)
135 }
136 #[inline]
137 pub fn get_sma_period4(&self) -> usize {
138 self.params.sma_period4.unwrap_or(15)
139 }
140 #[inline]
141 pub fn get_roc_period1(&self) -> usize {
142 self.params.roc_period1.unwrap_or(10)
143 }
144 #[inline]
145 pub fn get_roc_period2(&self) -> usize {
146 self.params.roc_period2.unwrap_or(15)
147 }
148 #[inline]
149 pub fn get_roc_period3(&self) -> usize {
150 self.params.roc_period3.unwrap_or(20)
151 }
152 #[inline]
153 pub fn get_roc_period4(&self) -> usize {
154 self.params.roc_period4.unwrap_or(30)
155 }
156 #[inline]
157 pub fn get_signal_period(&self) -> usize {
158 self.params.signal_period.unwrap_or(9)
159 }
160}
161
162#[derive(Copy, Clone, Debug)]
163pub struct KstBuilder {
164 sma_period1: Option<usize>,
165 sma_period2: Option<usize>,
166 sma_period3: Option<usize>,
167 sma_period4: Option<usize>,
168 roc_period1: Option<usize>,
169 roc_period2: Option<usize>,
170 roc_period3: Option<usize>,
171 roc_period4: Option<usize>,
172 signal_period: Option<usize>,
173 kernel: Kernel,
174}
175
176impl Default for KstBuilder {
177 fn default() -> Self {
178 Self {
179 sma_period1: None,
180 sma_period2: None,
181 sma_period3: None,
182 sma_period4: None,
183 roc_period1: None,
184 roc_period2: None,
185 roc_period3: None,
186 roc_period4: None,
187 signal_period: None,
188 kernel: Kernel::Auto,
189 }
190 }
191}
192
193impl KstBuilder {
194 #[inline(always)]
195 pub fn new() -> Self {
196 Self::default()
197 }
198 #[inline(always)]
199 pub fn sma_period1(mut self, n: usize) -> Self {
200 self.sma_period1 = Some(n);
201 self
202 }
203 #[inline(always)]
204 pub fn sma_period2(mut self, n: usize) -> Self {
205 self.sma_period2 = Some(n);
206 self
207 }
208 #[inline(always)]
209 pub fn sma_period3(mut self, n: usize) -> Self {
210 self.sma_period3 = Some(n);
211 self
212 }
213 #[inline(always)]
214 pub fn sma_period4(mut self, n: usize) -> Self {
215 self.sma_period4 = Some(n);
216 self
217 }
218 #[inline(always)]
219 pub fn roc_period1(mut self, n: usize) -> Self {
220 self.roc_period1 = Some(n);
221 self
222 }
223 #[inline(always)]
224 pub fn roc_period2(mut self, n: usize) -> Self {
225 self.roc_period2 = Some(n);
226 self
227 }
228 #[inline(always)]
229 pub fn roc_period3(mut self, n: usize) -> Self {
230 self.roc_period3 = Some(n);
231 self
232 }
233 #[inline(always)]
234 pub fn roc_period4(mut self, n: usize) -> Self {
235 self.roc_period4 = Some(n);
236 self
237 }
238 #[inline(always)]
239 pub fn signal_period(mut self, n: usize) -> Self {
240 self.signal_period = Some(n);
241 self
242 }
243 #[inline(always)]
244 pub fn kernel(mut self, k: Kernel) -> Self {
245 self.kernel = k;
246 self
247 }
248 #[inline(always)]
249 pub fn apply(self, c: &Candles) -> Result<KstOutput, KstError> {
250 let p = KstParams {
251 sma_period1: self.sma_period1,
252 sma_period2: self.sma_period2,
253 sma_period3: self.sma_period3,
254 sma_period4: self.sma_period4,
255 roc_period1: self.roc_period1,
256 roc_period2: self.roc_period2,
257 roc_period3: self.roc_period3,
258 roc_period4: self.roc_period4,
259 signal_period: self.signal_period,
260 };
261 let i = KstInput::from_candles(c, "close", p);
262 kst_with_kernel(&i, self.kernel)
263 }
264 #[inline(always)]
265 pub fn apply_slice(self, d: &[f64]) -> Result<KstOutput, KstError> {
266 let p = KstParams {
267 sma_period1: self.sma_period1,
268 sma_period2: self.sma_period2,
269 sma_period3: self.sma_period3,
270 sma_period4: self.sma_period4,
271 roc_period1: self.roc_period1,
272 roc_period2: self.roc_period2,
273 roc_period3: self.roc_period3,
274 roc_period4: self.roc_period4,
275 signal_period: self.signal_period,
276 };
277 let i = KstInput::from_slice(d, p);
278 kst_with_kernel(&i, self.kernel)
279 }
280 #[inline(always)]
281 pub fn into_stream(self) -> Result<KstStream, KstError> {
282 let p = KstParams {
283 sma_period1: self.sma_period1,
284 sma_period2: self.sma_period2,
285 sma_period3: self.sma_period3,
286 sma_period4: self.sma_period4,
287 roc_period1: self.roc_period1,
288 roc_period2: self.roc_period2,
289 roc_period3: self.roc_period3,
290 roc_period4: self.roc_period4,
291 signal_period: self.signal_period,
292 };
293 KstStream::try_new(p)
294 }
295}
296
297#[derive(Debug, Error)]
298pub enum KstError {
299 #[error("kst: {0}")]
300 Roc(#[from] RocError),
301 #[error("kst: {0}")]
302 Sma(#[from] SmaError),
303 #[error("kst: Input data slice is empty.")]
304 EmptyInputData,
305 #[error("kst: All values are NaN.")]
306 AllValuesNaN,
307 #[error("kst: Invalid period: period = {period}, data length = {data_len}")]
308 InvalidPeriod { period: usize, data_len: usize },
309 #[error("kst: Not enough valid data: needed = {needed}, valid = {valid}")]
310 NotEnoughValidData { needed: usize, valid: usize },
311 #[error("kst: Output length mismatch: expected = {expected}, got = {got}")]
312 OutputLengthMismatch { expected: usize, got: usize },
313 #[error("kst: Invalid range: start = {start}, end = {end}, step = {step}")]
314 InvalidRange {
315 start: usize,
316 end: usize,
317 step: usize,
318 },
319 #[error("kst: Invalid kernel for batch path: {0:?}")]
320 InvalidKernelForBatch(Kernel),
321 #[error("kst: size arithmetic overflow")]
322 SizeOverflow,
323}
324
325#[inline]
326pub fn kst(input: &KstInput) -> Result<KstOutput, KstError> {
327 kst_with_kernel(input, Kernel::Auto)
328}
329
330#[inline(always)]
331fn kst_prepare<'a>(
332 input: &'a KstInput,
333 kernel: Kernel,
334) -> Result<
335 (
336 &'a [f64],
337 (usize, usize, usize, usize),
338 (usize, usize, usize, usize),
339 usize,
340 usize,
341 usize,
342 usize,
343 Kernel,
344 ),
345 KstError,
346> {
347 let data = input.as_ref();
348 let len = data.len();
349 if len == 0 {
350 return Err(KstError::EmptyInputData);
351 }
352 let first = data
353 .iter()
354 .position(|x| !x.is_nan())
355 .ok_or(KstError::AllValuesNaN)?;
356
357 let s1 = input.get_sma_period1();
358 let s2 = input.get_sma_period2();
359 let s3 = input.get_sma_period3();
360 let s4 = input.get_sma_period4();
361 let r1 = input.get_roc_period1();
362 let r2 = input.get_roc_period2();
363 let r3 = input.get_roc_period3();
364 let r4 = input.get_roc_period4();
365 let sig = input.get_signal_period();
366
367 for &p in [s1, s2, s3, s4, r1, r2, r3, r4, sig].iter() {
368 if p == 0 || p > len {
369 return Err(KstError::InvalidPeriod {
370 period: p,
371 data_len: len,
372 });
373 }
374 }
375
376 let warm1 = r1
377 .checked_add(s1)
378 .and_then(|x| x.checked_sub(1))
379 .ok_or(KstError::SizeOverflow)?;
380 let warm2 = r2
381 .checked_add(s2)
382 .and_then(|x| x.checked_sub(1))
383 .ok_or(KstError::SizeOverflow)?;
384 let warm3 = r3
385 .checked_add(s3)
386 .and_then(|x| x.checked_sub(1))
387 .ok_or(KstError::SizeOverflow)?;
388 let warm4 = r4
389 .checked_add(s4)
390 .and_then(|x| x.checked_sub(1))
391 .ok_or(KstError::SizeOverflow)?;
392 let warm_line = warm1.max(warm2).max(warm3).max(warm4);
393 if len - first < warm_line {
394 return Err(KstError::NotEnoughValidData {
395 needed: warm_line,
396 valid: len - first,
397 });
398 }
399 let warm_sig = warm_line
400 .checked_add(sig)
401 .and_then(|x| x.checked_sub(1))
402 .ok_or(KstError::SizeOverflow)?;
403 if warm_sig > len {
404 return Err(KstError::NotEnoughValidData {
405 needed: warm_sig,
406 valid: len,
407 });
408 }
409
410 let chosen = match kernel {
411 Kernel::Auto => Kernel::Scalar,
412 k => k,
413 };
414 Ok((
415 data,
416 (s1, s2, s3, s4),
417 (r1, r2, r3, r4),
418 sig,
419 first,
420 warm_line,
421 warm_sig,
422 chosen,
423 ))
424}
425
/// Single-pass scalar KST kernel shared by `kst_with_kernel` and `kst_into_slice`.
///
/// For each leg `k` in 1..=4 it keeps a ring buffer of the last `s_k` values
/// of `ROC(r_k)` plus their running sum. It then combines the four running
/// means as `1*SMA1 + 2*SMA2 + 3*SMA3 + 4*SMA4`. A fifth ring buffer of length
/// `sig` averages the line into the signal.
///
/// * `first` — index of the first non-NaN input; nothing before it is read.
/// * `warm_line` / `warm_sig` — warm-ups relative to `first`, as computed by
///   `kst_prepare` (`warm_sig == warm_line + sig - 1`).
/// * Only indices `>= first + warm_line` of `out_line` and `>= first + warm_sig`
///   of `out_sig` (below `data.len()`) are written; callers own the NaN prefix.
///
/// A ROC whose current or lagged value is non-finite, or whose lagged value is
/// zero, contributes `0.0` instead of NaN.
///
/// # Panics
/// Panics if either output slice is shorter than `data`, or if any SMA or
/// signal period is zero. Either case would otherwise make the raw-pointer
/// writes below go out of bounds.
#[inline(always)]
fn kst_compute_into(
    data: &[f64],
    s: (usize, usize, usize, usize),
    r: (usize, usize, usize, usize),
    sig: usize,
    first: usize,
    warm_line: usize,
    warm_sig: usize,
    out_line: &mut [f64],
    out_sig: &mut [f64],
) {
    let len = data.len();
    let (s1, s2, s3, s4) = s;
    let (r1, r2, r3, r4) = r;

    // These checks are what make the unchecked writes below sound for a safe
    // `fn`; inputs validated by `kst_prepare` always satisfy them.
    assert!(
        out_line.len() >= len && out_sig.len() >= len,
        "kst: output buffers must be at least as long as the input"
    );
    assert!(
        s1 > 0 && s2 > 0 && s3 > 0 && s4 > 0 && sig > 0,
        "kst: smoothing periods must be non-zero"
    );

    // Windows of up to STACK values use stack arrays; longer ones spill to the heap.
    const STACK: usize = 256;
    let mut sb1 = [0.0f64; STACK];
    let mut sb2 = [0.0f64; STACK];
    let mut sb3 = [0.0f64; STACK];
    let mut sb4 = [0.0f64; STACK];
    let mut sbs = [0.0f64; STACK];

    let mut v1_heap;
    let mut v2_heap;
    let mut v3_heap;
    let mut v4_heap;
    let mut vs_heap;

    let (b1, b2, b3, b4, sbuf): (&mut [f64], &mut [f64], &mut [f64], &mut [f64], &mut [f64]) = {
        v1_heap = if s1 > STACK {
            vec![0.0; s1]
        } else {
            Vec::new()
        };
        v2_heap = if s2 > STACK {
            vec![0.0; s2]
        } else {
            Vec::new()
        };
        v3_heap = if s3 > STACK {
            vec![0.0; s3]
        } else {
            Vec::new()
        };
        v4_heap = if s4 > STACK {
            vec![0.0; s4]
        } else {
            Vec::new()
        };
        vs_heap = if sig > STACK {
            vec![0.0; sig]
        } else {
            Vec::new()
        };

        let b1 = if s1 <= STACK {
            &mut sb1[..s1]
        } else {
            v1_heap.as_mut_slice()
        };
        let b2 = if s2 <= STACK {
            &mut sb2[..s2]
        } else {
            v2_heap.as_mut_slice()
        };
        let b3 = if s3 <= STACK {
            &mut sb3[..s3]
        } else {
            v3_heap.as_mut_slice()
        };
        let b4 = if s4 <= STACK {
            &mut sb4[..s4]
        } else {
            v4_heap.as_mut_slice()
        };
        let sbuf = if sig <= STACK {
            &mut sbs[..sig]
        } else {
            vs_heap.as_mut_slice()
        };
        (b1, b2, b3, b4, sbuf)
    };

    let mut i1 = 0usize;
    let mut i2 = 0usize;
    let mut i3 = 0usize;
    let mut i4 = 0usize;
    let mut sum1 = 0.0f64;
    let mut sum2 = 0.0f64;
    let mut sum3 = 0.0f64;
    let mut sum4 = 0.0f64;

    // Leg weights 1..4 folded into the 1/s_k mean factors.
    let inv1 = 1.0 / (s1 as f64);
    let inv2 = 1.0 / (s2 as f64);
    let inv3 = 1.0 / (s3 as f64);
    let inv4 = 1.0 / (s4 as f64);
    let w2 = inv2 + inv2;
    let w3 = inv3 + inv3 + inv3;
    let w4 = (4.0f64) * inv4;

    // First absolute index at which each ROC leg has its lagged sample.
    let start1 = first + r1;
    let start2 = first + r2;
    let start3 = first + r3;
    let start4 = first + r4;

    let start_line = first + warm_line;
    let warm_sig_abs = first + warm_sig;

    let mut sidx = 0usize;
    let mut ssum = 0.0f64;

    #[inline(always)]
    fn safe_roc(curr: f64, prev: f64) -> f64 {
        if prev != 0.0 && curr.is_finite() && prev.is_finite() {
            ((curr / prev) - 1.0) * 100.0
        } else {
            0.0
        }
    }

    // SAFETY: `i < len`, and both output slices have length >= len (asserted
    // above). Lagged reads `i - r_k` only happen once `i >= first + r_k`.
    // Every ring index stays below its capacity, and all capacities are
    // non-zero (asserted above).
    unsafe {
        let out_line_ptr = out_line.as_mut_ptr();
        let out_sig_ptr = out_sig.as_mut_ptr();
        let data_ptr = data.as_ptr();

        let b1_ptr = b1.as_mut_ptr();
        let b2_ptr = b2.as_mut_ptr();
        let b3_ptr = b3.as_mut_ptr();
        let b4_ptr = b4.as_mut_ptr();
        let sb_ptr = sbuf.as_mut_ptr();

        /// Replaces the oldest value of a ring buffer with `v`, updating the
        /// running `sum`. Caller guarantees `*idx < cap` and `buf` holds `cap` values.
        #[inline(always)]
        unsafe fn ring_update(buf: *mut f64, idx: &mut usize, cap: usize, sum: &mut f64, v: f64) {
            let old = *buf.add(*idx);
            *sum = (*sum) + (v - old);
            *buf.add(*idx) = v;
            *idx += 1;
            if *idx == cap {
                *idx = 0;
            }
        }

        for i in first..len {
            let x = *data_ptr.add(i);

            if i >= start1 {
                let p = *data_ptr.add(i - r1);
                let v = safe_roc(x, p);
                ring_update(b1_ptr, &mut i1, s1, &mut sum1, v);
            }
            if i >= start2 {
                let p = *data_ptr.add(i - r2);
                let v = safe_roc(x, p);
                ring_update(b2_ptr, &mut i2, s2, &mut sum2, v);
            }
            if i >= start3 {
                let p = *data_ptr.add(i - r3);
                let v = safe_roc(x, p);
                ring_update(b3_ptr, &mut i3, s3, &mut sum3, v);
            }
            if i >= start4 {
                let p = *data_ptr.add(i - r4);
                let v = safe_roc(x, p);
                ring_update(b4_ptr, &mut i4, s4, &mut sum4, v);
            }

            // Until every SMA window is full the line is still warming up.
            if i < start_line {
                continue;
            }

            let kst = sum1.mul_add(inv1, sum2.mul_add(w2, sum3.mul_add(w3, sum4 * w4)));
            *out_line_ptr.add(i) = kst;

            // Signal = running mean of the last `sig` line values. The window
            // first becomes full at `warm_sig_abs`, so output starts there.
            ring_update(sb_ptr, &mut sidx, sig, &mut ssum, kst);
            if i >= warm_sig_abs {
                *out_sig_ptr.add(i) = ssum / (sig as f64);
            }
        }
    }
}
627
628pub fn kst_with_kernel(input: &KstInput, kernel: Kernel) -> Result<KstOutput, KstError> {
629 let (data, s, r, sig, first, warm_line, warm_sig, chosen) = kst_prepare(input, kernel)?;
630 let len = data.len();
631
632 let actual_warm_line = first.checked_add(warm_line).ok_or(KstError::SizeOverflow)?;
633 let actual_warm_sig = first.checked_add(warm_sig).ok_or(KstError::SizeOverflow)?;
634 let mut line = alloc_with_nan_prefix(len, actual_warm_line);
635 let mut signal = alloc_with_nan_prefix(len, actual_warm_sig);
636
637 unsafe {
638 match chosen {
639 Kernel::Scalar | Kernel::ScalarBatch => {
640 kst_compute_into(
641 data,
642 s,
643 r,
644 sig,
645 first,
646 warm_line,
647 warm_sig,
648 &mut line,
649 &mut signal,
650 );
651 }
652 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
653 Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
654 kst_compute_into(
655 data,
656 s,
657 r,
658 sig,
659 first,
660 warm_line,
661 warm_sig,
662 &mut line,
663 &mut signal,
664 );
665 }
666 _ => unreachable!(),
667 }
668 }
669 Ok(KstOutput { line, signal })
670}
671
672#[inline]
673pub fn kst_into_slice(
674 out_line: &mut [f64],
675 out_signal: &mut [f64],
676 input: &KstInput,
677 kernel: Kernel,
678) -> Result<(), KstError> {
679 let (data, s, r, sig, first, warm_line, warm_sig, _chosen) = kst_prepare(input, kernel)?;
680 let expected = data.len();
681 if out_line.len() != expected || out_signal.len() != expected {
682 return Err(KstError::OutputLengthMismatch {
683 expected,
684 got: out_line.len().max(out_signal.len()),
685 });
686 }
687
688 kst_compute_into(
689 data, s, r, sig, first, warm_line, warm_sig, out_line, out_signal,
690 );
691
692 let prefix_line = (first + warm_line).min(out_line.len());
693 let prefix_sig = (first + warm_sig).min(out_signal.len());
694 for v in &mut out_line[..prefix_line] {
695 *v = f64::NAN;
696 }
697 for v in &mut out_signal[..prefix_sig] {
698 *v = f64::NAN;
699 }
700 Ok(())
701}
702
703#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
704#[inline]
705pub fn kst_into(
706 input: &KstInput,
707 out_line: &mut [f64],
708 out_signal: &mut [f64],
709) -> Result<(), KstError> {
710 kst_into_slice(out_line, out_signal, input, Kernel::Auto)
711}
712
/// Placeholder for a dedicated AVX2 KST kernel.
///
/// # Safety
/// Must not be called: the body is `unreachable!`. KST dispatch routes every
/// kernel to the scalar implementation instead.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub(crate) unsafe fn kst_avx2(
    _input: &KstInput,
    _first: usize,
    _len: usize,
) -> Result<KstOutput, KstError> {
    unreachable!("AVX2 stub should not be called directly")
}

/// Placeholder AVX-512 entry point. It forwards to the short variant for
/// `len <= 32` and to the long variant otherwise; both are unimplemented.
///
/// # Safety
/// Must not be called: both targets are `unreachable!`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub(crate) unsafe fn kst_avx512(
    input: &KstInput,
    first: usize,
    len: usize,
) -> Result<KstOutput, KstError> {
    match len {
        0..=32 => kst_avx512_short(input, first, len),
        _ => kst_avx512_long(input, first, len),
    }
}

/// Placeholder short-series AVX-512 kernel.
///
/// # Safety
/// Must not be called: the body is `unreachable!`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub(crate) unsafe fn kst_avx512_short(
    _input: &KstInput,
    _first: usize,
    _len: usize,
) -> Result<KstOutput, KstError> {
    unreachable!("AVX512 short stub should not be called directly")
}

/// Placeholder long-series AVX-512 kernel.
///
/// # Safety
/// Must not be called: the body is `unreachable!`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub(crate) unsafe fn kst_avx512_long(
    _input: &KstInput,
    _first: usize,
    _len: usize,
) -> Result<KstOutput, KstError> {
    unreachable!("AVX512 long stub should not be called directly")
}
756
757#[inline]
758pub fn kst_batch_with_kernel(
759 data: &[f64],
760 sweep: &KstBatchRange,
761 k: Kernel,
762) -> Result<KstBatchOutput, KstError> {
763 let kernel = match k {
764 Kernel::Auto => Kernel::ScalarBatch,
765 other if other.is_batch() => other,
766 other => return Err(KstError::InvalidKernelForBatch(other)),
767 };
768 let simd = match kernel {
769 Kernel::Avx512Batch => Kernel::Avx512,
770 Kernel::Avx2Batch => Kernel::Avx2,
771 Kernel::ScalarBatch => Kernel::Scalar,
772 _ => unreachable!(),
773 };
774 kst_batch_par_slice(data, sweep, simd)
775}
776
/// Inclusive `(start, end, step)` sweeps for every KST parameter.
///
/// Each axis is expanded by `axis_usize`. A `step` of 0, or `start == end`,
/// gives the single value `start`. Otherwise values run from `start` toward
/// `end` in `step` increments.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct KstBatchRange {
    pub sma_period1: (usize, usize, usize),
    pub sma_period2: (usize, usize, usize),
    pub sma_period3: (usize, usize, usize),
    pub sma_period4: (usize, usize, usize),
    pub roc_period1: (usize, usize, usize),
    pub roc_period2: (usize, usize, usize),
    pub roc_period3: (usize, usize, usize),
    pub roc_period4: (usize, usize, usize),
    pub signal_period: (usize, usize, usize),
}

impl Default for KstBatchRange {
    /// Default KST periods held fixed, with the signal period swept over
    /// 9..=258 in steps of 1 (250 rows).
    fn default() -> Self {
        let fixed = |v: usize| (v, v, 0);
        Self {
            signal_period: (9, 258, 1),
            roc_period1: fixed(10),
            roc_period2: fixed(15),
            roc_period3: fixed(20),
            roc_period4: fixed(30),
            sma_period1: fixed(10),
            sma_period2: fixed(10),
            sma_period3: fixed(10),
            sma_period4: fixed(15),
        }
    }
}
809
810#[derive(Clone, Debug, Default)]
811pub struct KstBatchBuilder {
812 range: KstBatchRange,
813 kernel: Kernel,
814}
815
816impl KstBatchBuilder {
817 pub fn new() -> Self {
818 Self::default()
819 }
820 pub fn kernel(mut self, k: Kernel) -> Self {
821 self.kernel = k;
822 self
823 }
824 pub fn sma_period1_range(mut self, start: usize, end: usize, step: usize) -> Self {
825 self.range.sma_period1 = (start, end, step);
826 self
827 }
828 pub fn sma_period2_range(mut self, start: usize, end: usize, step: usize) -> Self {
829 self.range.sma_period2 = (start, end, step);
830 self
831 }
832 pub fn sma_period3_range(mut self, start: usize, end: usize, step: usize) -> Self {
833 self.range.sma_period3 = (start, end, step);
834 self
835 }
836 pub fn sma_period4_range(mut self, start: usize, end: usize, step: usize) -> Self {
837 self.range.sma_period4 = (start, end, step);
838 self
839 }
840 pub fn roc_period1_range(mut self, start: usize, end: usize, step: usize) -> Self {
841 self.range.roc_period1 = (start, end, step);
842 self
843 }
844 pub fn roc_period2_range(mut self, start: usize, end: usize, step: usize) -> Self {
845 self.range.roc_period2 = (start, end, step);
846 self
847 }
848 pub fn roc_period3_range(mut self, start: usize, end: usize, step: usize) -> Self {
849 self.range.roc_period3 = (start, end, step);
850 self
851 }
852 pub fn roc_period4_range(mut self, start: usize, end: usize, step: usize) -> Self {
853 self.range.roc_period4 = (start, end, step);
854 self
855 }
856 pub fn signal_period_range(mut self, start: usize, end: usize, step: usize) -> Self {
857 self.range.signal_period = (start, end, step);
858 self
859 }
860 pub fn apply_slice(self, data: &[f64]) -> Result<KstBatchOutput, KstError> {
861 kst_batch_with_kernel(data, &self.range, self.kernel)
862 }
863 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<KstBatchOutput, KstError> {
864 KstBatchBuilder::new().kernel(k).apply_slice(data)
865 }
866 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<KstBatchOutput, KstError> {
867 self.apply_slice(source_type(c, src))
868 }
869 pub fn with_default_candles(c: &Candles) -> Result<KstBatchOutput, KstError> {
870 KstBatchBuilder::new()
871 .kernel(Kernel::Auto)
872 .apply_candles(c, "close")
873 }
874}
875
876#[derive(Clone, Debug)]
877pub struct KstBatchOutput {
878 pub lines: Vec<f64>,
879 pub signals: Vec<f64>,
880 pub combos: Vec<KstParams>,
881 pub rows: usize,
882 pub cols: usize,
883}
884impl KstBatchOutput {
885 pub fn row_for_params(&self, p: &KstParams) -> Option<usize> {
886 self.combos.iter().position(|c| {
887 c.sma_period1.unwrap_or(10) == p.sma_period1.unwrap_or(10)
888 && c.sma_period2.unwrap_or(10) == p.sma_period2.unwrap_or(10)
889 && c.sma_period3.unwrap_or(10) == p.sma_period3.unwrap_or(10)
890 && c.sma_period4.unwrap_or(15) == p.sma_period4.unwrap_or(15)
891 && c.roc_period1.unwrap_or(10) == p.roc_period1.unwrap_or(10)
892 && c.roc_period2.unwrap_or(15) == p.roc_period2.unwrap_or(15)
893 && c.roc_period3.unwrap_or(20) == p.roc_period3.unwrap_or(20)
894 && c.roc_period4.unwrap_or(30) == p.roc_period4.unwrap_or(30)
895 && c.signal_period.unwrap_or(9) == p.signal_period.unwrap_or(9)
896 })
897 }
898 pub fn lines_for(&self, p: &KstParams) -> Option<&[f64]> {
899 self.row_for_params(p).map(|row| {
900 let start = row * self.cols;
901 &self.lines[start..start + self.cols]
902 })
903 }
904 pub fn signals_for(&self, p: &KstParams) -> Option<&[f64]> {
905 self.row_for_params(p).map(|row| {
906 let start = row * self.cols;
907 &self.signals[start..start + self.cols]
908 })
909 }
910}
911
/// Expands an inclusive `(start, end, step)` sweep into concrete values.
///
/// `step == 0` or `start == end` yields just `[start]`. Otherwise values run
/// from `start` toward `end` (ascending or descending) in `step` increments.
/// `end` is included only when it lands on the grid. The result is never empty.
#[inline(always)]
fn axis_usize((start, end, step): (usize, usize, usize)) -> Vec<usize> {
    if step == 0 || start == end {
        return vec![start];
    }
    if start < end {
        (start..=end).step_by(step).collect()
    } else {
        (end..=start).rev().step_by(step).collect()
    }
}
943
944#[inline(always)]
945fn expand_grid(r: &KstBatchRange) -> Result<Vec<KstParams>, KstError> {
946 let s1 = axis_usize(r.sma_period1);
947 let s2 = axis_usize(r.sma_period2);
948 let s3 = axis_usize(r.sma_period3);
949 let s4 = axis_usize(r.sma_period4);
950 let r1 = axis_usize(r.roc_period1);
951 let r2 = axis_usize(r.roc_period2);
952 let r3 = axis_usize(r.roc_period3);
953 let r4 = axis_usize(r.roc_period4);
954 let sig = axis_usize(r.signal_period);
955
956 if s1.is_empty() {
957 return Err(KstError::InvalidRange {
958 start: r.sma_period1.0,
959 end: r.sma_period1.1,
960 step: r.sma_period1.2,
961 });
962 }
963 if s2.is_empty() {
964 return Err(KstError::InvalidRange {
965 start: r.sma_period2.0,
966 end: r.sma_period2.1,
967 step: r.sma_period2.2,
968 });
969 }
970 if s3.is_empty() {
971 return Err(KstError::InvalidRange {
972 start: r.sma_period3.0,
973 end: r.sma_period3.1,
974 step: r.sma_period3.2,
975 });
976 }
977 if s4.is_empty() {
978 return Err(KstError::InvalidRange {
979 start: r.sma_period4.0,
980 end: r.sma_period4.1,
981 step: r.sma_period4.2,
982 });
983 }
984 if r1.is_empty() {
985 return Err(KstError::InvalidRange {
986 start: r.roc_period1.0,
987 end: r.roc_period1.1,
988 step: r.roc_period1.2,
989 });
990 }
991 if r2.is_empty() {
992 return Err(KstError::InvalidRange {
993 start: r.roc_period2.0,
994 end: r.roc_period2.1,
995 step: r.roc_period2.2,
996 });
997 }
998 if r3.is_empty() {
999 return Err(KstError::InvalidRange {
1000 start: r.roc_period3.0,
1001 end: r.roc_period3.1,
1002 step: r.roc_period3.2,
1003 });
1004 }
1005 if r4.is_empty() {
1006 return Err(KstError::InvalidRange {
1007 start: r.roc_period4.0,
1008 end: r.roc_period4.1,
1009 step: r.roc_period4.2,
1010 });
1011 }
1012 if sig.is_empty() {
1013 return Err(KstError::InvalidRange {
1014 start: r.signal_period.0,
1015 end: r.signal_period.1,
1016 step: r.signal_period.2,
1017 });
1018 }
1019
1020 let total = s1
1021 .len()
1022 .checked_mul(s2.len())
1023 .and_then(|x| x.checked_mul(s3.len()))
1024 .and_then(|x| x.checked_mul(s4.len()))
1025 .and_then(|x| x.checked_mul(r1.len()))
1026 .and_then(|x| x.checked_mul(r2.len()))
1027 .and_then(|x| x.checked_mul(r3.len()))
1028 .and_then(|x| x.checked_mul(r4.len()))
1029 .and_then(|x| x.checked_mul(sig.len()))
1030 .ok_or(KstError::SizeOverflow)?;
1031
1032 let mut out = Vec::with_capacity(total);
1033 for &s1v in &s1 {
1034 for &s2v in &s2 {
1035 for &s3v in &s3 {
1036 for &s4v in &s4 {
1037 for &r1v in &r1 {
1038 for &r2v in &r2 {
1039 for &r3v in &r3 {
1040 for &r4v in &r4 {
1041 for &sigv in &sig {
1042 out.push(KstParams {
1043 sma_period1: Some(s1v),
1044 sma_period2: Some(s2v),
1045 sma_period3: Some(s3v),
1046 sma_period4: Some(s4v),
1047 roc_period1: Some(r1v),
1048 roc_period2: Some(r2v),
1049 roc_period3: Some(r3v),
1050 roc_period4: Some(r4v),
1051 signal_period: Some(sigv),
1052 });
1053 }
1054 }
1055 }
1056 }
1057 }
1058 }
1059 }
1060 }
1061 }
1062 Ok(out)
1063}
1064
1065#[inline(always)]
1066pub fn kst_batch_slice(
1067 data: &[f64],
1068 sweep: &KstBatchRange,
1069 kern: Kernel,
1070) -> Result<KstBatchOutput, KstError> {
1071 kst_batch_inner(data, sweep, kern, false)
1072}
1073#[inline(always)]
1074pub fn kst_batch_par_slice(
1075 data: &[f64],
1076 sweep: &KstBatchRange,
1077 kern: Kernel,
1078) -> Result<KstBatchOutput, KstError> {
1079 kst_batch_inner(data, sweep, kern, true)
1080}
1081
1082#[inline(always)]
1083fn kst_batch_inner(
1084 data: &[f64],
1085 sweep: &KstBatchRange,
1086 kern: Kernel,
1087 parallel: bool,
1088) -> Result<KstBatchOutput, KstError> {
1089 let combos = expand_grid(sweep)?;
1090 let cols = data.len();
1091 if cols == 0 {
1092 return Err(KstError::EmptyInputData);
1093 }
1094
1095 let first = data
1096 .iter()
1097 .position(|x| !x.is_nan())
1098 .ok_or(KstError::AllValuesNaN)?;
1099
1100 let mut warm_line = Vec::with_capacity(combos.len());
1101 let mut warm_sig = Vec::with_capacity(combos.len());
1102 for c in &combos {
1103 let s1 = c.sma_period1.unwrap();
1104 let s2 = c.sma_period2.unwrap();
1105 let s3 = c.sma_period3.unwrap();
1106 let s4 = c.sma_period4.unwrap();
1107 let r1 = c.roc_period1.unwrap();
1108 let r2 = c.roc_period2.unwrap();
1109 let r3 = c.roc_period3.unwrap();
1110 let r4 = c.roc_period4.unwrap();
1111 let sig = c.signal_period.unwrap();
1112 let wl = (r1 + s1 - 1)
1113 .max(r2 + s2 - 1)
1114 .max(r3 + s3 - 1)
1115 .max(r4 + s4 - 1);
1116 warm_line.push(wl);
1117 warm_sig.push(wl + sig - 1);
1118 }
1119
1120 let rows = combos.len();
1121 let mut line_mu = make_uninit_matrix(rows, cols);
1122 let mut signal_mu = make_uninit_matrix(rows, cols);
1123 let warm_line_abs: Vec<usize> = warm_line.iter().map(|&wl| first + wl).collect();
1124 let warm_sig_abs: Vec<usize> = warm_sig.iter().map(|&ws| first + ws).collect();
1125 init_matrix_prefixes(&mut line_mu, cols, &warm_line_abs);
1126 init_matrix_prefixes(&mut signal_mu, cols, &warm_sig_abs);
1127
1128 let mut line_guard = core::mem::ManuallyDrop::new(line_mu);
1129 let mut signal_guard = core::mem::ManuallyDrop::new(signal_mu);
1130 let line_out: &mut [f64] = unsafe {
1131 core::slice::from_raw_parts_mut(line_guard.as_mut_ptr() as *mut f64, line_guard.len())
1132 };
1133 let signal_out: &mut [f64] = unsafe {
1134 core::slice::from_raw_parts_mut(signal_guard.as_mut_ptr() as *mut f64, signal_guard.len())
1135 };
1136
1137 let actual = match kern {
1138 Kernel::Auto => detect_best_batch_kernel(),
1139 k => k,
1140 };
1141 let simd = match actual {
1142 Kernel::ScalarBatch => Kernel::Scalar,
1143 Kernel::Avx2Batch => Kernel::Scalar,
1144 Kernel::Avx512Batch => Kernel::Scalar,
1145 _ => Kernel::Scalar,
1146 };
1147
1148 use std::collections::HashMap;
1149 #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
1150 struct R4(usize, usize, usize, usize);
1151
1152 let mut groups: HashMap<R4, Vec<usize>> = HashMap::new();
1153 for (idx, prm) in combos.iter().enumerate() {
1154 groups
1155 .entry(R4(
1156 prm.roc_period1.unwrap(),
1157 prm.roc_period2.unwrap(),
1158 prm.roc_period3.unwrap(),
1159 prm.roc_period4.unwrap(),
1160 ))
1161 .or_default()
1162 .push(idx);
1163 }
1164
1165 fn compute_roc_streams(
1166 data: &[f64],
1167 first: usize,
1168 r: (usize, usize, usize, usize),
1169 ) -> (Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>) {
1170 let len = data.len();
1171 let (r1, r2, r3, r4) = r;
1172 let mut v1 = vec![0.0f64; len];
1173 let mut v2 = vec![0.0f64; len];
1174 let mut v3 = vec![0.0f64; len];
1175 let mut v4 = vec![0.0f64; len];
1176 for i in first..len {
1177 let x = data[i];
1178 if i >= first + r1 {
1179 let p = data[i - r1];
1180 if x.is_finite() && p.is_finite() && p != 0.0 {
1181 v1[i] = ((x / p) - 1.0) * 100.0;
1182 }
1183 }
1184 if i >= first + r2 {
1185 let p = data[i - r2];
1186 if x.is_finite() && p.is_finite() && p != 0.0 {
1187 v2[i] = ((x / p) - 1.0) * 100.0;
1188 }
1189 }
1190 if i >= first + r3 {
1191 let p = data[i - r3];
1192 if x.is_finite() && p.is_finite() && p != 0.0 {
1193 v3[i] = ((x / p) - 1.0) * 100.0;
1194 }
1195 }
1196 if i >= first + r4 {
1197 let p = data[i - r4];
1198 if x.is_finite() && p.is_finite() && p != 0.0 {
1199 v4[i] = ((x / p) - 1.0) * 100.0;
1200 }
1201 }
1202 }
1203 (v1, v2, v3, v4)
1204 }
1205
1206 struct Streams {
1207 v1: Vec<f64>,
1208 v2: Vec<f64>,
1209 v3: Vec<f64>,
1210 v4: Vec<f64>,
1211 }
1212
1213 let mut streams_map: HashMap<R4, Streams> = HashMap::with_capacity(groups.len());
1214 for (key @ R4(r1, r2, r3, r4), _rows) in groups.iter() {
1215 let (v1, v2, v3, v4) = compute_roc_streams(data, first, (*r1, *r2, *r3, *r4));
1216 streams_map.insert(*key, Streams { v1, v2, v3, v4 });
1217 }
1218
1219 let do_row = |row: usize, ldst: &mut [f64], sdst: &mut [f64]| {
1220 let prm = &combos[row];
1221 let s = (
1222 prm.sma_period1.unwrap(),
1223 prm.sma_period2.unwrap(),
1224 prm.sma_period3.unwrap(),
1225 prm.sma_period4.unwrap(),
1226 );
1227 let r = (
1228 prm.roc_period1.unwrap(),
1229 prm.roc_period2.unwrap(),
1230 prm.roc_period3.unwrap(),
1231 prm.roc_period4.unwrap(),
1232 );
1233 let sig = prm.signal_period.unwrap();
1234 let wl = (r.0 + s.0 - 1)
1235 .max(r.1 + s.1 - 1)
1236 .max(r.2 + s.2 - 1)
1237 .max(r.3 + s.3 - 1);
1238 let ws = wl + sig - 1;
1239
1240 let key = R4(r.0, r.1, r.2, r.3);
1241 let st = streams_map.get(&key).unwrap();
1242
1243 let len = ldst.len();
1244 let (s1, s2, s3, s4) = s;
1245 let inv1 = 1.0 / (s1 as f64);
1246 let inv2 = 1.0 / (s2 as f64);
1247 let inv3 = 1.0 / (s3 as f64);
1248 let inv4 = 1.0 / (s4 as f64);
1249 let w2 = inv2 + inv2;
1250 let w3 = inv3 + inv3 + inv3;
1251 let w4 = 4.0f64 * inv4;
1252
1253 let mut b1 = vec![0.0f64; s1];
1254 let mut b2 = vec![0.0f64; s2];
1255 let mut b3 = vec![0.0f64; s3];
1256 let mut b4 = vec![0.0f64; s4];
1257 let mut i1 = 0usize;
1258 let mut i2 = 0usize;
1259 let mut i3 = 0usize;
1260 let mut i4 = 0usize;
1261 let mut sum1 = 0.0f64;
1262 let mut sum2 = 0.0f64;
1263 let mut sum3 = 0.0f64;
1264 let mut sum4 = 0.0f64;
1265
1266 let start_line = first + wl;
1267 let warm_sig_abs = first + ws;
1268
1269 let mut sbuf = vec![0.0f64; sig];
1270 let mut sidx = 0usize;
1271 let mut ssum = 0.0f64;
1272 let mut sbuilt = 0usize;
1273
1274 unsafe {
1275 let b1p = b1.as_mut_ptr();
1276 let b2p = b2.as_mut_ptr();
1277 let b3p = b3.as_mut_ptr();
1278 let b4p = b4.as_mut_ptr();
1279 let sbp = sbuf.as_mut_ptr();
1280 let lptr = ldst.as_mut_ptr();
1281 let sptr = sdst.as_mut_ptr();
1282 let v1p = st.v1.as_ptr();
1283 let v2p = st.v2.as_ptr();
1284 let v3p = st.v3.as_ptr();
1285 let v4p = st.v4.as_ptr();
1286
1287 #[inline(always)]
1288 unsafe fn ring_update(
1289 buf: *mut f64,
1290 idx: &mut usize,
1291 cap: usize,
1292 sum: &mut f64,
1293 v: f64,
1294 ) {
1295 let old = *buf.add(*idx);
1296 *sum = (*sum) + (v - old);
1297 *buf.add(*idx) = v;
1298 *idx += 1;
1299 if *idx == cap {
1300 *idx = 0;
1301 }
1302 }
1303
1304 for i in first..len {
1305 let v1 = *v1p.add(i);
1306 let v2 = *v2p.add(i);
1307 let v3 = *v3p.add(i);
1308 let v4 = *v4p.add(i);
1309
1310 ring_update(b1p, &mut i1, s1, &mut sum1, v1);
1311 ring_update(b2p, &mut i2, s2, &mut sum2, v2);
1312 ring_update(b3p, &mut i3, s3, &mut sum3, v3);
1313 ring_update(b4p, &mut i4, s4, &mut sum4, v4);
1314
1315 if i < start_line {
1316 continue;
1317 }
1318
1319 let kst = sum1.mul_add(inv1, sum2.mul_add(w2, sum3.mul_add(w3, sum4 * w4)));
1320 *lptr.add(i) = kst;
1321
1322 if sbuilt < sig {
1323 let old = *sbp.add(sidx);
1324 ssum += kst - old;
1325 *sbp.add(sidx) = kst;
1326 sidx += 1;
1327 if sidx == sig {
1328 sidx = 0;
1329 }
1330 sbuilt += 1;
1331 if i >= warm_sig_abs {
1332 *sptr.add(i) = ssum / (sig as f64);
1333 }
1334 } else {
1335 let old = *sbp.add(sidx);
1336 ssum += kst - old;
1337 *sbp.add(sidx) = kst;
1338 sidx += 1;
1339 if sidx == sig {
1340 sidx = 0;
1341 }
1342 *sptr.add(i) = ssum / (sig as f64);
1343 }
1344 }
1345 }
1346 };
1347
1348 if parallel {
1349 #[cfg(not(target_arch = "wasm32"))]
1350 line_out
1351 .par_chunks_mut(cols)
1352 .zip(signal_out.par_chunks_mut(cols))
1353 .enumerate()
1354 .for_each(|(row, (l, s))| do_row(row, l, s));
1355 #[cfg(target_arch = "wasm32")]
1356 for (row, (l, s)) in line_out
1357 .chunks_mut(cols)
1358 .zip(signal_out.chunks_mut(cols))
1359 .enumerate()
1360 {
1361 do_row(row, l, s);
1362 }
1363 } else {
1364 for (row, (l, s)) in line_out
1365 .chunks_mut(cols)
1366 .zip(signal_out.chunks_mut(cols))
1367 .enumerate()
1368 {
1369 do_row(row, l, s);
1370 }
1371 }
1372
1373 let lines = unsafe {
1374 Vec::from_raw_parts(
1375 line_guard.as_mut_ptr() as *mut f64,
1376 line_guard.len(),
1377 line_guard.capacity(),
1378 )
1379 };
1380 let signals = unsafe {
1381 Vec::from_raw_parts(
1382 signal_guard.as_mut_ptr() as *mut f64,
1383 signal_guard.len(),
1384 signal_guard.capacity(),
1385 )
1386 };
1387
1388 Ok(KstBatchOutput {
1389 lines,
1390 signals,
1391 combos,
1392 rows,
1393 cols,
1394 })
1395}
1396
1397#[inline(always)]
1398fn kst_batch_inner_into(
1399 data: &[f64],
1400 sweep: &KstBatchRange,
1401 kern: Kernel,
1402 parallel: bool,
1403 lines_out: &mut [f64],
1404 signals_out: &mut [f64],
1405) -> Result<Vec<KstParams>, KstError> {
1406 let combos = expand_grid(sweep)?;
1407 let cols = data.len();
1408 let rows = combos.len();
1409 let total = rows.checked_mul(cols).ok_or(KstError::SizeOverflow)?;
1410 if lines_out.len() != total || signals_out.len() != total {
1411 return Err(KstError::OutputLengthMismatch {
1412 expected: total,
1413 got: lines_out.len().max(signals_out.len()),
1414 });
1415 }
1416
1417 let first = data
1418 .iter()
1419 .position(|x| !x.is_nan())
1420 .ok_or(KstError::AllValuesNaN)?;
1421
1422 let mut warm_line = vec![0usize; rows];
1423 let mut warm_sig = vec![0usize; rows];
1424 for (row, c) in combos.iter().enumerate() {
1425 let wl = (c.roc_period1.unwrap() + c.sma_period1.unwrap() - 1)
1426 .max(c.roc_period2.unwrap() + c.sma_period2.unwrap() - 1)
1427 .max(c.roc_period3.unwrap() + c.sma_period3.unwrap() - 1)
1428 .max(c.roc_period4.unwrap() + c.sma_period4.unwrap() - 1);
1429 warm_line[row] = wl;
1430 warm_sig[row] = wl + c.signal_period.unwrap() - 1;
1431
1432 let abs_wl = first + warm_line[row];
1433 let abs_ws = first + warm_sig[row];
1434 for v in &mut lines_out[row * cols..row * cols + abs_wl.min(cols)] {
1435 *v = f64::NAN;
1436 }
1437 for v in &mut signals_out[row * cols..row * cols + abs_ws.min(cols)] {
1438 *v = f64::NAN;
1439 }
1440 }
1441
1442 let actual = match kern {
1443 Kernel::Auto => detect_best_batch_kernel(),
1444 k => k,
1445 };
1446 let simd = match actual {
1447 Kernel::ScalarBatch => Kernel::Scalar,
1448 Kernel::Avx2Batch => Kernel::Scalar,
1449 Kernel::Avx512Batch => Kernel::Scalar,
1450 _ => Kernel::Scalar,
1451 };
1452
1453 use std::collections::HashMap;
1454 #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
1455 struct R4(usize, usize, usize, usize);
1456
1457 let mut groups: HashMap<R4, Vec<usize>> = HashMap::new();
1458 for (idx, prm) in combos.iter().enumerate() {
1459 groups
1460 .entry(R4(
1461 prm.roc_period1.unwrap(),
1462 prm.roc_period2.unwrap(),
1463 prm.roc_period3.unwrap(),
1464 prm.roc_period4.unwrap(),
1465 ))
1466 .or_default()
1467 .push(idx);
1468 }
1469
1470 fn compute_roc_streams(
1471 data: &[f64],
1472 first: usize,
1473 r: (usize, usize, usize, usize),
1474 ) -> (Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>) {
1475 let len = data.len();
1476 let (r1, r2, r3, r4) = r;
1477 let mut v1 = vec![0.0f64; len];
1478 let mut v2 = vec![0.0f64; len];
1479 let mut v3 = vec![0.0f64; len];
1480 let mut v4 = vec![0.0f64; len];
1481 for i in first..len {
1482 let x = data[i];
1483 if i >= first + r1 {
1484 let p = data[i - r1];
1485 if x.is_finite() && p.is_finite() && p != 0.0 {
1486 v1[i] = ((x / p) - 1.0) * 100.0;
1487 }
1488 }
1489 if i >= first + r2 {
1490 let p = data[i - r2];
1491 if x.is_finite() && p.is_finite() && p != 0.0 {
1492 v2[i] = ((x / p) - 1.0) * 100.0;
1493 }
1494 }
1495 if i >= first + r3 {
1496 let p = data[i - r3];
1497 if x.is_finite() && p.is_finite() && p != 0.0 {
1498 v3[i] = ((x / p) - 1.0) * 100.0;
1499 }
1500 }
1501 if i >= first + r4 {
1502 let p = data[i - r4];
1503 if x.is_finite() && p.is_finite() && p != 0.0 {
1504 v4[i] = ((x / p) - 1.0) * 100.0;
1505 }
1506 }
1507 }
1508 (v1, v2, v3, v4)
1509 }
1510
1511 struct Streams {
1512 v1: Vec<f64>,
1513 v2: Vec<f64>,
1514 v3: Vec<f64>,
1515 v4: Vec<f64>,
1516 }
1517
1518 let mut streams_map: HashMap<R4, Streams> = HashMap::with_capacity(groups.len());
1519 for (key @ R4(r1, r2, r3, r4), _rows) in groups.iter() {
1520 let (v1, v2, v3, v4) = compute_roc_streams(data, first, (*r1, *r2, *r3, *r4));
1521 streams_map.insert(*key, Streams { v1, v2, v3, v4 });
1522 }
1523
1524 let do_row = |row: usize, ldst: &mut [f64], sdst: &mut [f64]| {
1525 let prm = &combos[row];
1526 let s = (
1527 prm.sma_period1.unwrap(),
1528 prm.sma_period2.unwrap(),
1529 prm.sma_period3.unwrap(),
1530 prm.sma_period4.unwrap(),
1531 );
1532 let r = (
1533 prm.roc_period1.unwrap(),
1534 prm.roc_period2.unwrap(),
1535 prm.roc_period3.unwrap(),
1536 prm.roc_period4.unwrap(),
1537 );
1538 let sig = prm.signal_period.unwrap();
1539 let wl = (r.0 + s.0 - 1)
1540 .max(r.1 + s.1 - 1)
1541 .max(r.2 + s.2 - 1)
1542 .max(r.3 + s.3 - 1);
1543 let ws = wl + sig - 1;
1544
1545 let key = R4(r.0, r.1, r.2, r.3);
1546 let st = streams_map.get(&key).unwrap();
1547
1548 let len = ldst.len();
1549 let (s1, s2, s3, s4) = s;
1550 let inv1 = 1.0 / (s1 as f64);
1551 let inv2 = 1.0 / (s2 as f64);
1552 let inv3 = 1.0 / (s3 as f64);
1553 let inv4 = 1.0 / (s4 as f64);
1554 let w2 = inv2 + inv2;
1555 let w3 = inv3 + inv3 + inv3;
1556 let w4 = 4.0f64 * inv4;
1557
1558 let mut b1 = vec![0.0f64; s1];
1559 let mut b2 = vec![0.0f64; s2];
1560 let mut b3 = vec![0.0f64; s3];
1561 let mut b4 = vec![0.0f64; s4];
1562 let mut i1 = 0usize;
1563 let mut i2 = 0usize;
1564 let mut i3 = 0usize;
1565 let mut i4 = 0usize;
1566 let mut sum1 = 0.0f64;
1567 let mut sum2 = 0.0f64;
1568 let mut sum3 = 0.0f64;
1569 let mut sum4 = 0.0f64;
1570
1571 let start_line = first + wl;
1572 let warm_sig_abs = first + ws;
1573
1574 let mut sbuf = vec![0.0f64; sig];
1575 let mut sidx = 0usize;
1576 let mut ssum = 0.0f64;
1577 let mut sbuilt = 0usize;
1578
1579 unsafe {
1580 let b1p = b1.as_mut_ptr();
1581 let b2p = b2.as_mut_ptr();
1582 let b3p = b3.as_mut_ptr();
1583 let b4p = b4.as_mut_ptr();
1584 let sbp = sbuf.as_mut_ptr();
1585 let lptr = ldst.as_mut_ptr();
1586 let sptr = sdst.as_mut_ptr();
1587 let v1p = st.v1.as_ptr();
1588 let v2p = st.v2.as_ptr();
1589 let v3p = st.v3.as_ptr();
1590 let v4p = st.v4.as_ptr();
1591
1592 #[inline(always)]
1593 unsafe fn ring_update(
1594 buf: *mut f64,
1595 idx: &mut usize,
1596 cap: usize,
1597 sum: &mut f64,
1598 v: f64,
1599 ) {
1600 let old = *buf.add(*idx);
1601 *sum = (*sum) + (v - old);
1602 *buf.add(*idx) = v;
1603 *idx += 1;
1604 if *idx == cap {
1605 *idx = 0;
1606 }
1607 }
1608
1609 for i in first..len {
1610 let v1 = *v1p.add(i);
1611 let v2 = *v2p.add(i);
1612 let v3 = *v3p.add(i);
1613 let v4 = *v4p.add(i);
1614
1615 ring_update(b1p, &mut i1, s1, &mut sum1, v1);
1616 ring_update(b2p, &mut i2, s2, &mut sum2, v2);
1617 ring_update(b3p, &mut i3, s3, &mut sum3, v3);
1618 ring_update(b4p, &mut i4, s4, &mut sum4, v4);
1619
1620 if i < start_line {
1621 continue;
1622 }
1623
1624 let kst = sum1.mul_add(inv1, sum2.mul_add(w2, sum3.mul_add(w3, sum4 * w4)));
1625 *lptr.add(i) = kst;
1626
1627 if sbuilt < sig {
1628 let old = *sbp.add(sidx);
1629 ssum += kst - old;
1630 *sbp.add(sidx) = kst;
1631 sidx += 1;
1632 if sidx == sig {
1633 sidx = 0;
1634 }
1635 sbuilt += 1;
1636 if i >= warm_sig_abs {
1637 *sptr.add(i) = ssum / (sig as f64);
1638 }
1639 } else {
1640 let old = *sbp.add(sidx);
1641 ssum += kst - old;
1642 *sbp.add(sidx) = kst;
1643 sidx += 1;
1644 if sidx == sig {
1645 sidx = 0;
1646 }
1647 *sptr.add(i) = ssum / (sig as f64);
1648 }
1649 }
1650 }
1651 };
1652
1653 if parallel {
1654 #[cfg(not(target_arch = "wasm32"))]
1655 lines_out
1656 .par_chunks_mut(cols)
1657 .zip(signals_out.par_chunks_mut(cols))
1658 .enumerate()
1659 .for_each(|(r, (l, s))| do_row(r, l, s));
1660 #[cfg(target_arch = "wasm32")]
1661 for (r, (l, s)) in lines_out
1662 .chunks_mut(cols)
1663 .zip(signals_out.chunks_mut(cols))
1664 .enumerate()
1665 {
1666 do_row(r, l, s);
1667 }
1668 } else {
1669 for (r, (l, s)) in lines_out
1670 .chunks_mut(cols)
1671 .zip(signals_out.chunks_mut(cols))
1672 .enumerate()
1673 {
1674 do_row(r, l, s);
1675 }
1676 }
1677
1678 Ok(combos)
1679}
1680
/// Incremental (bar-by-bar) KST state; feed one price at a time via `update`.
///
/// Holds four rolling SMA windows over ROC values, a rolling window for the
/// signal line, and rings of recent prices / price reciprocals used to form
/// the ROC terms.
#[derive(Debug, Clone)]
pub struct KstStream {
    /// SMA window lengths `(s1, s2, s3, s4)` applied to the four ROC series.
    s: (usize, usize, usize, usize),
    /// ROC look-backs `(r1, r2, r3, r4)`, in bars.
    r: (usize, usize, usize, usize),
    /// Signal-line SMA length.
    sig: usize,

    // Rolling windows of ROC values (lengths s1..s4), their next-write cursors,
    // and the running sums of their contents.
    b1: Vec<f64>,
    b2: Vec<f64>,
    b3: Vec<f64>,
    b4: Vec<f64>,
    i1: usize,
    i2: usize,
    i3: usize,
    i4: usize,
    sum1: f64,
    sum2: f64,
    sum3: f64,
    sum4: f64,

    // Line weights set in `try_new`: inv1 = 1/s1, w2 = 2/s2, w3 = 3/s3, w4 = 4/s4.
    inv1: f64,
    w2: f64,
    w3: f64,
    w4: f64,

    // Rolling window of line values for the signal SMA (length `sig`),
    // its next-write cursor and running sum.
    sig_buf: Vec<f64>,
    sig_idx: usize,
    sig_sum: f64,

    /// Recent raw prices (capacity = largest ROC period + 1). `update` writes
    /// here but, in the code shown, only reads the ring's length.
    price_ring: Vec<f64>,
    /// `1 / price` for the same slots; NaN when the price was non-finite or zero.
    recip_ring: Vec<f64>,
    /// Next slot to write in `price_ring` / `recip_ring`.
    head: usize,

    /// Number of prices consumed so far.
    t: usize,
    /// `update` returns `None` while `t <= warm_line`.
    warm_line: usize,
    /// `warm_line + sig - 1`; gates when the signal value is reported.
    warm_sig: usize,

    /// Most recent KST line value (NaN until the first line is produced).
    last_line: f64,
}
1719impl KstStream {
1720 #[inline]
1721 pub fn try_new(params: KstParams) -> Result<Self, KstError> {
1722 let s1 = params.sma_period1.unwrap_or(10);
1723 let s2 = params.sma_period2.unwrap_or(10);
1724 let s3 = params.sma_period3.unwrap_or(10);
1725 let s4 = params.sma_period4.unwrap_or(15);
1726
1727 let r1 = params.roc_period1.unwrap_or(10);
1728 let r2 = params.roc_period2.unwrap_or(15);
1729 let r3 = params.roc_period3.unwrap_or(20);
1730 let r4 = params.roc_period4.unwrap_or(30);
1731
1732 let sig = params.signal_period.unwrap_or(9);
1733
1734 for &p in [s1, s2, s3, s4, r1, r2, r3, r4, sig].iter() {
1735 if p == 0 {
1736 return Err(KstError::InvalidPeriod {
1737 period: p,
1738 data_len: 0,
1739 });
1740 }
1741 }
1742
1743 let warm_line = (r1 + s1 - 1)
1744 .max(r2 + s2 - 1)
1745 .max(r3 + s3 - 1)
1746 .max(r4 + s4 - 1);
1747 let warm_sig = warm_line + sig - 1;
1748
1749 let max_roc = r1.max(r2).max(r3).max(r4);
1750 let price_cap = max_roc + 1;
1751
1752 Ok(Self {
1753 s: (s1, s2, s3, s4),
1754 r: (r1, r2, r3, r4),
1755 sig,
1756
1757 b1: vec![0.0; s1],
1758 b2: vec![0.0; s2],
1759 b3: vec![0.0; s3],
1760 b4: vec![0.0; s4],
1761 i1: 0,
1762 i2: 0,
1763 i3: 0,
1764 i4: 0,
1765 sum1: 0.0,
1766 sum2: 0.0,
1767 sum3: 0.0,
1768 sum4: 0.0,
1769
1770 inv1: 1.0 / (s1 as f64),
1771 w2: (2.0f64) / (s2 as f64),
1772 w3: (3.0f64) / (s3 as f64),
1773 w4: (4.0f64) / (s4 as f64),
1774
1775 sig_buf: vec![0.0; sig],
1776 sig_idx: 0,
1777 sig_sum: 0.0,
1778
1779 price_ring: vec![f64::NAN; price_cap],
1780 recip_ring: vec![f64::NAN; price_cap],
1781 head: 0,
1782
1783 t: 0,
1784 warm_line,
1785 warm_sig,
1786
1787 last_line: f64::NAN,
1788 })
1789 }
1790
1791 #[inline(always)]
1792 pub fn update(&mut self, price: f64) -> Option<(f64, f64)> {
1793 self.price_ring[self.head] = price;
1794 self.recip_ring[self.head] = if price.is_finite() && price != 0.0 {
1795 1.0 / price
1796 } else {
1797 f64::NAN
1798 };
1799
1800 Self::wrap_inc(&mut self.head, self.price_ring.len());
1801
1802 let cap = self.price_ring.len();
1803 let (s1, s2, s3, s4) = self.s;
1804 let (r1, r2, r3, r4) = self.r;
1805
1806 let mut v1 = 0.0;
1807 if self.t >= r1 {
1808 let idx = Self::back_from_next(self.head, cap, r1 + 1);
1809 let pinv = self.recip_ring[idx];
1810 if price.is_finite() && pinv.is_finite() {
1811 v1 = (price * pinv - 1.0) * 100.0;
1812 }
1813 }
1814 let mut v2 = 0.0;
1815 if self.t >= r2 {
1816 let idx = Self::back_from_next(self.head, cap, r2 + 1);
1817 let pinv = self.recip_ring[idx];
1818 if price.is_finite() && pinv.is_finite() {
1819 v2 = (price * pinv - 1.0) * 100.0;
1820 }
1821 }
1822 let mut v3 = 0.0;
1823 if self.t >= r3 {
1824 let idx = Self::back_from_next(self.head, cap, r3 + 1);
1825 let pinv = self.recip_ring[idx];
1826 if price.is_finite() && pinv.is_finite() {
1827 v3 = (price * pinv - 1.0) * 100.0;
1828 }
1829 }
1830 let mut v4 = 0.0;
1831 if self.t >= r4 {
1832 let idx = Self::back_from_next(self.head, cap, r4 + 1);
1833 let pinv = self.recip_ring[idx];
1834 if price.is_finite() && pinv.is_finite() {
1835 v4 = (price * pinv - 1.0) * 100.0;
1836 }
1837 }
1838
1839 if self.t >= r1 {
1840 self.sum1 -= self.b1[self.i1];
1841 self.b1[self.i1] = v1;
1842 self.sum1 += v1;
1843 Self::wrap_inc(&mut self.i1, s1);
1844 }
1845 if self.t >= r2 {
1846 self.sum2 -= self.b2[self.i2];
1847 self.b2[self.i2] = v2;
1848 self.sum2 += v2;
1849 Self::wrap_inc(&mut self.i2, s2);
1850 }
1851 if self.t >= r3 {
1852 self.sum3 -= self.b3[self.i3];
1853 self.b3[self.i3] = v3;
1854 self.sum3 += v3;
1855 Self::wrap_inc(&mut self.i3, s3);
1856 }
1857 if self.t >= r4 {
1858 self.sum4 -= self.b4[self.i4];
1859 self.b4[self.i4] = v4;
1860 self.sum4 += v4;
1861 Self::wrap_inc(&mut self.i4, s4);
1862 }
1863
1864 self.t += 1;
1865
1866 if self.t <= self.warm_line {
1867 return None;
1868 }
1869
1870 let line = self.sum1.mul_add(
1871 self.inv1,
1872 self.sum2
1873 .mul_add(self.w2, self.sum3.mul_add(self.w3, self.sum4 * self.w4)),
1874 );
1875
1876 self.last_line = line;
1877
1878 let old = self.sig_buf[self.sig_idx];
1879 self.sig_sum += line - old;
1880 self.sig_buf[self.sig_idx] = line;
1881 Self::wrap_inc(&mut self.sig_idx, self.sig);
1882
1883 let signal = if self.t >= self.warm_sig {
1884 self.sig_sum / (self.sig as f64)
1885 } else {
1886 f64::NAN
1887 };
1888
1889 Some((line, signal))
1890 }
1891
1892 #[inline(always)]
1893 fn wrap_inc(idx: &mut usize, cap: usize) {
1894 *idx += 1;
1895 if *idx == cap {
1896 *idx = 0;
1897 }
1898 }
1899
1900 #[inline(always)]
1901 fn back_from_next(next: usize, cap: usize, k: usize) -> usize {
1902 debug_assert!(k <= cap);
1903 let mut idx = next;
1904 if idx < k {
1905 idx += cap;
1906 }
1907 idx - k
1908 }
1909}
1910
1911#[cfg(test)]
1912mod tests {
1913 use super::*;
1914 use crate::skip_if_unsupported;
1915 use crate::utilities::data_loader::read_candles_from_csv;
1916 #[cfg(feature = "proptest")]
1917 use proptest::prelude::*;
1918
1919 fn check_kst_default_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1920 skip_if_unsupported!(kernel, test_name);
1921 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1922 let candles = read_candles_from_csv(file_path)?;
1923 let input = KstInput::with_default_candles(&candles);
1924 let result = kst_with_kernel(&input, kernel)?;
1925 assert_eq!(result.line.len(), candles.close.len());
1926 assert_eq!(result.signal.len(), candles.close.len());
1927 Ok(())
1928 }
1929
1930 fn check_kst_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1931 skip_if_unsupported!(kernel, test_name);
1932 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1933 let candles = read_candles_from_csv(file_path)?;
1934 let input = KstInput::with_default_candles(&candles);
1935 let result = kst_with_kernel(&input, kernel)?;
1936 let expected_last_five_line = [
1937 -47.38570195278667,
1938 -44.42926180347176,
1939 -42.185693049429034,
1940 -40.10697793942024,
1941 -40.17466795905724,
1942 ];
1943 let expected_last_five_signal = [
1944 -52.66743277411538,
1945 -51.559775662725556,
1946 -50.113844191238954,
1947 -48.58923772989874,
1948 -47.01112630514571,
1949 ];
1950 let l = result.line.len();
1951 let s = result.signal.len();
1952 for (i, &v) in result.line[l - 5..].iter().enumerate() {
1953 assert!(
1954 (v - expected_last_five_line[i]).abs() < 1e-1,
1955 "KST line mismatch {}: {} vs {}",
1956 i,
1957 v,
1958 expected_last_five_line[i]
1959 );
1960 }
1961 for (i, &v) in result.signal[s - 5..].iter().enumerate() {
1962 assert!(
1963 (v - expected_last_five_signal[i]).abs() < 1e-1,
1964 "KST signal mismatch {}: {} vs {}",
1965 i,
1966 v,
1967 expected_last_five_signal[i]
1968 );
1969 }
1970 Ok(())
1971 }
1972
1973 fn check_kst_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1974 skip_if_unsupported!(kernel, test_name);
1975 let nan_data = [f64::NAN, f64::NAN, f64::NAN];
1976 let input = KstInput::from_slice(&nan_data, KstParams::default());
1977 let result = kst_with_kernel(&input, kernel);
1978 assert!(result.is_err(), "[{}] Should error with all NaN", test_name);
1979 Ok(())
1980 }
1981
1982 macro_rules! generate_all_kst_tests {
1983 ($($test_fn:ident),*) => {
1984 paste::paste! {
1985 $(
1986 #[test]
1987 fn [<$test_fn _scalar_f64>]() { let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar); }
1988 )*
1989 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1990 $(
1991 #[test]
1992 fn [<$test_fn _avx2_f64>]() { let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2); }
1993 #[test]
1994 fn [<$test_fn _avx512_f64>]() { let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512); }
1995 )*
1996 }
1997 }
1998 }
1999
2000 #[cfg(debug_assertions)]
2001 fn check_kst_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2002 skip_if_unsupported!(kernel, test_name);
2003
2004 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2005 let candles = read_candles_from_csv(file_path)?;
2006
2007 let test_params = vec![
2008 KstParams::default(),
2009 KstParams {
2010 sma_period1: Some(2),
2011 sma_period2: Some(2),
2012 sma_period3: Some(2),
2013 sma_period4: Some(2),
2014 roc_period1: Some(2),
2015 roc_period2: Some(2),
2016 roc_period3: Some(2),
2017 roc_period4: Some(2),
2018 signal_period: Some(2),
2019 },
2020 KstParams {
2021 sma_period1: Some(5),
2022 sma_period2: Some(5),
2023 sma_period3: Some(5),
2024 sma_period4: Some(7),
2025 roc_period1: Some(5),
2026 roc_period2: Some(7),
2027 roc_period3: Some(10),
2028 roc_period4: Some(15),
2029 signal_period: Some(5),
2030 },
2031 KstParams {
2032 sma_period1: Some(7),
2033 sma_period2: Some(10),
2034 sma_period3: Some(12),
2035 sma_period4: Some(15),
2036 roc_period1: Some(8),
2037 roc_period2: Some(12),
2038 roc_period3: Some(16),
2039 roc_period4: Some(20),
2040 signal_period: Some(7),
2041 },
2042 KstParams {
2043 sma_period1: Some(10),
2044 sma_period2: Some(10),
2045 sma_period3: Some(10),
2046 sma_period4: Some(15),
2047 roc_period1: Some(10),
2048 roc_period2: Some(15),
2049 roc_period3: Some(20),
2050 roc_period4: Some(30),
2051 signal_period: Some(9),
2052 },
2053 KstParams {
2054 sma_period1: Some(20),
2055 sma_period2: Some(25),
2056 sma_period3: Some(30),
2057 sma_period4: Some(35),
2058 roc_period1: Some(25),
2059 roc_period2: Some(35),
2060 roc_period3: Some(45),
2061 roc_period4: Some(60),
2062 signal_period: Some(15),
2063 },
2064 KstParams {
2065 sma_period1: Some(30),
2066 sma_period2: Some(40),
2067 sma_period3: Some(50),
2068 sma_period4: Some(60),
2069 roc_period1: Some(40),
2070 roc_period2: Some(60),
2071 roc_period3: Some(80),
2072 roc_period4: Some(100),
2073 signal_period: Some(21),
2074 },
2075 KstParams {
2076 sma_period1: Some(5),
2077 sma_period2: Some(10),
2078 sma_period3: Some(20),
2079 sma_period4: Some(50),
2080 roc_period1: Some(7),
2081 roc_period2: Some(14),
2082 roc_period3: Some(28),
2083 roc_period4: Some(56),
2084 signal_period: Some(12),
2085 },
2086 KstParams {
2087 sma_period1: Some(10),
2088 sma_period2: Some(10),
2089 sma_period3: Some(10),
2090 sma_period4: Some(15),
2091 roc_period1: Some(10),
2092 roc_period2: Some(15),
2093 roc_period3: Some(20),
2094 roc_period4: Some(30),
2095 signal_period: Some(2),
2096 },
2097 KstParams {
2098 sma_period1: Some(1),
2099 sma_period2: Some(1),
2100 sma_period3: Some(1),
2101 sma_period4: Some(1),
2102 roc_period1: Some(1),
2103 roc_period2: Some(1),
2104 roc_period3: Some(1),
2105 roc_period4: Some(1),
2106 signal_period: Some(1),
2107 },
2108 KstParams {
2109 sma_period1: Some(100),
2110 sma_period2: Some(120),
2111 sma_period3: Some(140),
2112 sma_period4: Some(160),
2113 roc_period1: Some(100),
2114 roc_period2: Some(150),
2115 roc_period3: Some(200),
2116 roc_period4: Some(250),
2117 signal_period: Some(50),
2118 },
2119 KstParams {
2120 sma_period1: Some(10),
2121 sma_period2: Some(15),
2122 sma_period3: Some(20),
2123 sma_period4: Some(30),
2124 roc_period1: Some(10),
2125 roc_period2: Some(15),
2126 roc_period3: Some(20),
2127 roc_period4: Some(30),
2128 signal_period: Some(10),
2129 },
2130 KstParams {
2131 sma_period1: Some(3),
2132 sma_period2: Some(6),
2133 sma_period3: Some(12),
2134 sma_period4: Some(24),
2135 roc_period1: Some(5),
2136 roc_period2: Some(10),
2137 roc_period3: Some(20),
2138 roc_period4: Some(40),
2139 signal_period: Some(8),
2140 },
2141 ];
2142
2143 for (param_idx, params) in test_params.iter().enumerate() {
2144 let input = KstInput::from_candles(&candles, "close", params.clone());
2145 let output = kst_with_kernel(&input, kernel)?;
2146
2147 for (i, &val) in output.line.iter().enumerate() {
2148 if val.is_nan() {
2149 continue;
2150 }
2151
2152 let bits = val.to_bits();
2153
2154 if bits == 0x11111111_11111111 {
2155 panic!(
2156 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
2157 in KST line with params: sma_periods=({},{},{},{}), roc_periods=({},{},{},{}), \
2158 signal_period={} (param set {})",
2159 test_name,
2160 val,
2161 bits,
2162 i,
2163 params.sma_period1.unwrap_or(10),
2164 params.sma_period2.unwrap_or(10),
2165 params.sma_period3.unwrap_or(10),
2166 params.sma_period4.unwrap_or(15),
2167 params.roc_period1.unwrap_or(10),
2168 params.roc_period2.unwrap_or(15),
2169 params.roc_period3.unwrap_or(20),
2170 params.roc_period4.unwrap_or(30),
2171 params.signal_period.unwrap_or(9),
2172 param_idx
2173 );
2174 }
2175
2176 if bits == 0x22222222_22222222 {
2177 panic!(
2178 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
2179 in KST line with params: sma_periods=({},{},{},{}), roc_periods=({},{},{},{}), \
2180 signal_period={} (param set {})",
2181 test_name,
2182 val,
2183 bits,
2184 i,
2185 params.sma_period1.unwrap_or(10),
2186 params.sma_period2.unwrap_or(10),
2187 params.sma_period3.unwrap_or(10),
2188 params.sma_period4.unwrap_or(15),
2189 params.roc_period1.unwrap_or(10),
2190 params.roc_period2.unwrap_or(15),
2191 params.roc_period3.unwrap_or(20),
2192 params.roc_period4.unwrap_or(30),
2193 params.signal_period.unwrap_or(9),
2194 param_idx
2195 );
2196 }
2197
2198 if bits == 0x33333333_33333333 {
2199 panic!(
2200 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
2201 in KST line with params: sma_periods=({},{},{},{}), roc_periods=({},{},{},{}), \
2202 signal_period={} (param set {})",
2203 test_name,
2204 val,
2205 bits,
2206 i,
2207 params.sma_period1.unwrap_or(10),
2208 params.sma_period2.unwrap_or(10),
2209 params.sma_period3.unwrap_or(10),
2210 params.sma_period4.unwrap_or(15),
2211 params.roc_period1.unwrap_or(10),
2212 params.roc_period2.unwrap_or(15),
2213 params.roc_period3.unwrap_or(20),
2214 params.roc_period4.unwrap_or(30),
2215 params.signal_period.unwrap_or(9),
2216 param_idx
2217 );
2218 }
2219 }
2220
2221 for (i, &val) in output.signal.iter().enumerate() {
2222 if val.is_nan() {
2223 continue;
2224 }
2225
2226 let bits = val.to_bits();
2227
2228 if bits == 0x11111111_11111111 {
2229 panic!(
2230 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
2231 in KST signal with params: sma_periods=({},{},{},{}), roc_periods=({},{},{},{}), \
2232 signal_period={} (param set {})",
2233 test_name,
2234 val,
2235 bits,
2236 i,
2237 params.sma_period1.unwrap_or(10),
2238 params.sma_period2.unwrap_or(10),
2239 params.sma_period3.unwrap_or(10),
2240 params.sma_period4.unwrap_or(15),
2241 params.roc_period1.unwrap_or(10),
2242 params.roc_period2.unwrap_or(15),
2243 params.roc_period3.unwrap_or(20),
2244 params.roc_period4.unwrap_or(30),
2245 params.signal_period.unwrap_or(9),
2246 param_idx
2247 );
2248 }
2249
2250 if bits == 0x22222222_22222222 {
2251 panic!(
2252 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
2253 in KST signal with params: sma_periods=({},{},{},{}), roc_periods=({},{},{},{}), \
2254 signal_period={} (param set {})",
2255 test_name,
2256 val,
2257 bits,
2258 i,
2259 params.sma_period1.unwrap_or(10),
2260 params.sma_period2.unwrap_or(10),
2261 params.sma_period3.unwrap_or(10),
2262 params.sma_period4.unwrap_or(15),
2263 params.roc_period1.unwrap_or(10),
2264 params.roc_period2.unwrap_or(15),
2265 params.roc_period3.unwrap_or(20),
2266 params.roc_period4.unwrap_or(30),
2267 params.signal_period.unwrap_or(9),
2268 param_idx
2269 );
2270 }
2271
2272 if bits == 0x33333333_33333333 {
2273 panic!(
2274 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
2275 in KST signal with params: sma_periods=({},{},{},{}), roc_periods=({},{},{},{}), \
2276 signal_period={} (param set {})",
2277 test_name,
2278 val,
2279 bits,
2280 i,
2281 params.sma_period1.unwrap_or(10),
2282 params.sma_period2.unwrap_or(10),
2283 params.sma_period3.unwrap_or(10),
2284 params.sma_period4.unwrap_or(15),
2285 params.roc_period1.unwrap_or(10),
2286 params.roc_period2.unwrap_or(15),
2287 params.roc_period3.unwrap_or(20),
2288 params.roc_period4.unwrap_or(30),
2289 params.signal_period.unwrap_or(9),
2290 param_idx
2291 );
2292 }
2293 }
2294 }
2295
2296 Ok(())
2297 }
2298
2299 #[cfg(not(debug_assertions))]
2300 fn check_kst_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2301 Ok(())
2302 }
2303
2304 #[cfg(feature = "proptest")]
2305 #[allow(clippy::float_cmp)]
2306 fn check_kst_property(
2307 test_name: &str,
2308 kernel: Kernel,
2309 ) -> Result<(), Box<dyn std::error::Error>> {
2310 use proptest::prelude::*;
2311 skip_if_unsupported!(kernel, test_name);
2312
2313 let strat = (
2314 (3usize..=20),
2315 (3usize..=20),
2316 (3usize..=20),
2317 (5usize..=25),
2318 (5usize..=15),
2319 (10usize..=20),
2320 (15usize..=25),
2321 (20usize..=35),
2322 (3usize..=15),
2323 (0usize..=3),
2324 )
2325 .prop_flat_map(|(s1, s2, s3, s4, r1, r2, r3, r4, sig, scenario)| {
2326 let warmup1 = r1 + s1 - 1;
2327 let warmup2 = r2 + s2 - 1;
2328 let warmup3 = r3 + s3 - 1;
2329 let warmup4 = r4 + s4 - 1;
2330 let warmup = warmup1.max(warmup2).max(warmup3).max(warmup4);
2331 let min_data_len = warmup + sig + 20;
2332
2333 let data_strategy = match scenario {
2334 0 => prop::collection::vec(
2335 (10.0f64..10000.0f64).prop_filter("finite", |x| x.is_finite()),
2336 min_data_len..400,
2337 )
2338 .boxed(),
2339 1 => prop::collection::vec(
2340 (0.01f64..5.0f64).prop_filter("finite", |x| x.is_finite()),
2341 min_data_len..400,
2342 )
2343 .boxed(),
2344 2 => prop::collection::vec((10.0f64..1000.0f64), min_data_len..400)
2345 .prop_map(|mut v| {
2346 for i in 0..v.len() / 4 {
2347 let plateau_start = i * 4;
2348 let plateau_end = (plateau_start + 3).min(v.len() - 1);
2349 let plateau_value = v[plateau_start];
2350 for j in plateau_start..=plateau_end {
2351 v[j] = plateau_value;
2352 }
2353 }
2354 v
2355 })
2356 .boxed(),
2357 _ => prop::collection::vec((10.0f64..1000.0f64), min_data_len..400)
2358 .prop_map(|mut v| {
2359 for i in (5..v.len()).step_by(20) {
2360 v[i] = v[i - 1] * (1.5 + (i % 3) as f64 * 0.5);
2361 }
2362 v
2363 })
2364 .boxed(),
2365 };
2366
2367 (
2368 data_strategy,
2369 Just(s1),
2370 Just(s2),
2371 Just(s3),
2372 Just(s4),
2373 Just(r1),
2374 Just(r2),
2375 Just(r3),
2376 Just(r4),
2377 Just(sig),
2378 )
2379 });
2380
2381 proptest::test_runner::TestRunner::default()
2382 .run(&strat, |(data, s1, s2, s3, s4, r1, r2, r3, r4, sig)| {
2383 let params = KstParams {
2384 sma_period1: Some(s1),
2385 sma_period2: Some(s2),
2386 sma_period3: Some(s3),
2387 sma_period4: Some(s4),
2388 roc_period1: Some(r1),
2389 roc_period2: Some(r2),
2390 roc_period3: Some(r3),
2391 roc_period4: Some(r4),
2392 signal_period: Some(sig),
2393 };
2394 let input = KstInput::from_slice(&data, params);
2395
2396 let warmup1 = r1 + s1 - 1;
2397 let warmup2 = r2 + s2 - 1;
2398 let warmup3 = r3 + s3 - 1;
2399 let warmup4 = r4 + s4 - 1;
2400 let warmup = warmup1.max(warmup2).max(warmup3).max(warmup4);
2401 let signal_warmup = warmup + sig - 1;
2402
2403 let KstOutput { line, signal } = kst_with_kernel(&input, kernel).unwrap();
2404 let KstOutput {
2405 line: ref_line,
2406 signal: ref_signal,
2407 } = kst_with_kernel(&input, Kernel::Scalar).unwrap();
2408
2409 for i in 0..warmup.min(data.len()) {
2410 prop_assert!(
2411 line[i].is_nan(),
2412 "KST line should be NaN during warmup at index {i}"
2413 );
2414 }
2415 for i in 0..signal_warmup.min(data.len()) {
2416 prop_assert!(
2417 signal[i].is_nan(),
2418 "Signal should be NaN during warmup at index {i}"
2419 );
2420 }
2421
2422 for i in warmup..data.len() {
2423 let y = line[i];
2424 let r = ref_line[i];
2425
2426 if !y.is_finite() || !r.is_finite() {
2427 prop_assert!(
2428 y.to_bits() == r.to_bits(),
2429 "NaN/Inf mismatch at idx {i}: {y} vs {r}"
2430 );
2431 continue;
2432 }
2433
2434 let ulp_diff = y.to_bits().abs_diff(r.to_bits());
2435 prop_assert!(
2436 (y - r).abs() <= 1e-9 || ulp_diff <= 4,
2437 "KST line kernel mismatch idx {i}: {y} vs {r} (ULP={ulp_diff})"
2438 );
2439 }
2440
2441 for i in signal_warmup..data.len() {
2442 let y = signal[i];
2443 let r = ref_signal[i];
2444
2445 if !y.is_finite() || !r.is_finite() {
2446 prop_assert!(
2447 y.to_bits() == r.to_bits(),
2448 "Signal NaN/Inf mismatch at idx {i}: {y} vs {r}"
2449 );
2450 continue;
2451 }
2452
2453 let ulp_diff = y.to_bits().abs_diff(r.to_bits());
2454 prop_assert!(
2455 (y - r).abs() <= 1e-9 || ulp_diff <= 4,
2456 "Signal kernel mismatch idx {i}: {y} vs {r} (ULP={ulp_diff})"
2457 );
2458 }
2459
2460 if data.windows(2).all(|w| (w[0] - w[1]).abs() <= f64::EPSILON) {
2461 for i in warmup..data.len() {
2462 prop_assert!(
2463 line[i].abs() <= 1e-9,
2464 "KST should be ~0 for constant data at idx {i}: {}",
2465 line[i]
2466 );
2467 }
2468 }
2469
2470 if data.windows(2).all(|w| w[1] > w[0] + f64::EPSILON) {
2471 let check_start = warmup + 10;
2472 for i in check_start..data.len() {
2473 if line[i].is_finite() {
2474 prop_assert!(
2475 line[i] > -1e-6,
2476 "KST should be positive for increasing data at idx {i}: {}",
2477 line[i]
2478 );
2479 }
2480 }
2481 }
2482
2483 if data.windows(2).all(|w| w[0] > w[1] + f64::EPSILON) {
2484 let check_start = warmup + 10;
2485 for i in check_start..data.len() {
2486 if line[i].is_finite() {
2487 prop_assert!(
2488 line[i] < 1e-6,
2489 "KST should be negative for decreasing data at idx {i}: {}",
2490 line[i]
2491 );
2492 }
2493 }
2494 }
2495
2496 if signal_warmup < data.len() {
2497 for i in signal_warmup..data.len() {
2498 let y = signal[i];
2499 if !y.is_finite() {
2500 continue;
2501 }
2502
2503 let start = i + 1 - sig;
2504 let window = &line[start..=i];
2505 let lo = window.iter().cloned().fold(f64::INFINITY, f64::min);
2506 let hi = window.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
2507
2508 prop_assert!(
2509 y >= lo - 1e-9 && y <= hi + 1e-9,
2510 "Signal out of window bounds at idx {i}: {y} ∉ [{lo}, {hi}]"
2511 );
2512 }
2513 }
2514
2515 let min_price = data.iter().cloned().fold(f64::INFINITY, f64::min);
2516 let max_price = data.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
2517 let max_roc = if min_price > 0.0 {
2518 ((max_price / min_price) - 1.0) * 100.0
2519 } else {
2520 10000.0
2521 };
2522
2523 let kst_bound = max_roc * 10.0;
2524
2525 for i in warmup..data.len() {
2526 if line[i].is_finite() {
2527 prop_assert!(
2528 line[i] >= -kst_bound && line[i] <= kst_bound,
2529 "KST line out of reasonable bounds at idx {i}: {} (bound: ±{})",
2530 line[i],
2531 kst_bound
2532 );
2533 }
2534 }
2535 for i in signal_warmup..data.len() {
2536 if signal[i].is_finite() {
2537 prop_assert!(
2538 signal[i] >= -kst_bound && signal[i] <= kst_bound,
2539 "Signal out of reasonable bounds at idx {i}: {} (bound: ±{})",
2540 signal[i],
2541 kst_bound
2542 );
2543 }
2544 }
2545
2546 for i in warmup..data.len() {
2547 prop_assert!(
2548 line[i].is_nan() || line[i].is_finite(),
2549 "KST line has infinite value at idx {i}: {}",
2550 line[i]
2551 );
2552 }
2553 for i in signal_warmup..data.len() {
2554 prop_assert!(
2555 signal[i].is_nan() || signal[i].is_finite(),
2556 "Signal has infinite value at idx {i}: {}",
2557 signal[i]
2558 );
2559 }
2560
2561 if signal_warmup + sig + 5 < data.len() {
2562 for i in (signal_warmup + sig)..data.len() {
2563 let line_window = &line[i.saturating_sub(sig - 1)..=i.min(data.len() - 1)];
2564 let valid_values: Vec<f64> = line_window
2565 .iter()
2566 .filter(|x| x.is_finite())
2567 .cloned()
2568 .collect();
2569
2570 if !valid_values.is_empty() && signal[i].is_finite() {
2571 let line_avg =
2572 valid_values.iter().sum::<f64>() / valid_values.len() as f64;
2573
2574 let tolerance = if line_avg.abs() > 100.0 {
2575 0.005
2576 } else if line_avg.abs() > 10.0 {
2577 0.007
2578 } else {
2579 0.01
2580 };
2581
2582 prop_assert!(
2583 (signal[i] - line_avg).abs() <= 1e-6 ||
2584 (signal[i] - line_avg).abs() / line_avg.abs().max(1.0) <= tolerance,
2585 "Signal deviates from KST trend at idx {i}: signal={}, line_avg={}, tolerance={}%",
2586 signal[i], line_avg, tolerance * 100.0
2587 );
2588 }
2589 }
2590 }
2591
2592 Ok(())
2593 })
2594 .unwrap();
2595
2596 Ok(())
2597 }
2598
    // Property-based cross-kernel checks; only compiled when the `proptest` feature is on.
    #[cfg(feature = "proptest")]
    generate_all_kst_tests!(check_kst_property);

    // Deterministic single-series checks, expanded into #[test] wrappers by
    // `generate_all_kst_tests!` (macro defined earlier in this file).
    generate_all_kst_tests!(
        check_kst_default_params,
        check_kst_accuracy,
        check_kst_nan_handling,
        check_kst_no_poison
    );
2608
2609 #[test]
2610 fn test_kst_into_matches_api() {
2611 let n = 512usize;
2612 let mut data = Vec::with_capacity(n);
2613 for i in 0..n {
2614 let x = 100.0 + (i as f64) * 0.1 + (i as f64).sin();
2615 data.push(x);
2616 }
2617
2618 let input = KstInput::from_slice(&data, KstParams::default());
2619
2620 let base = kst(&input).expect("kst baseline");
2621
2622 let mut out_line = vec![0.0; n];
2623 let mut out_signal = vec![0.0; n];
2624
2625 #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
2626 {
2627 kst_into(&input, &mut out_line, &mut out_signal).expect("kst_into");
2628 }
2629 #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2630 {
2631 kst_into_slice(&mut out_line, &mut out_signal, &input, Kernel::Auto)
2632 .expect("kst_into_slice");
2633 }
2634
2635 assert_eq!(base.line.len(), n);
2636 assert_eq!(base.signal.len(), n);
2637 assert_eq!(out_line.len(), n);
2638 assert_eq!(out_signal.len(), n);
2639
2640 fn eq_or_both_nan(a: f64, b: f64) -> bool {
2641 (a.is_nan() && b.is_nan()) || (a == b)
2642 }
2643
2644 for i in 0..n {
2645 assert!(
2646 eq_or_both_nan(base.line[i], out_line[i]),
2647 "line mismatch at {i}: {} vs {}",
2648 base.line[i],
2649 out_line[i]
2650 );
2651 assert!(
2652 eq_or_both_nan(base.signal[i], out_signal[i]),
2653 "signal mismatch at {i}: {} vs {}",
2654 base.signal[i],
2655 out_signal[i]
2656 );
2657 }
2658 }
2659
2660 fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2661 skip_if_unsupported!(kernel, test);
2662 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2663 let c = read_candles_from_csv(file)?;
2664 let output = KstBatchBuilder::new()
2665 .kernel(kernel)
2666 .apply_candles(&c, "close")?;
2667 let def = KstParams::default();
2668 let row = output.lines_for(&def).expect("default row missing");
2669 assert_eq!(row.len(), c.close.len());
2670 Ok(())
2671 }
2672
2673 macro_rules! gen_batch_tests {
2674 ($fn_name:ident) => {
2675 paste::paste! {
2676 #[test] fn [<$fn_name _scalar>]() { let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch); }
2677 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2678 #[test] fn [<$fn_name _avx2>]() { let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch); }
2679 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2680 #[test] fn [<$fn_name _avx512>]() { let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch); }
2681 #[test] fn [<$fn_name _auto_detect>]() { let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto); }
2682 }
2683 };
2684 }
2685 #[cfg(debug_assertions)]
2686 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2687 skip_if_unsupported!(kernel, test);
2688
2689 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2690 let c = read_candles_from_csv(file)?;
2691
2692 let test_configs = vec![
2693 (
2694 3, 3, 0, 3, 3, 0, 3, 3, 0, 3, 3, 0, 3, 3, 0, 4, 4, 0, 5, 5, 0, 6, 6, 0, 3, 3, 0,
2695 ),
2696 (
2697 5, 10, 5, 5, 10, 5, 5, 10, 5, 8, 13, 5, 5, 10, 5, 8, 13, 5, 10, 15, 5, 15, 20, 5,
2698 5, 7, 2,
2699 ),
2700 (
2701 10, 10, 0, 10, 10, 0, 10, 10, 0, 15, 15, 0, 10, 10, 0, 15, 15, 0, 20, 20, 0, 30,
2702 30, 0, 9, 9, 0,
2703 ),
2704 (
2705 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 3, 4, 1, 4, 5, 1, 5, 6, 1, 2, 3, 1,
2706 ),
2707 ];
2708
2709 for (
2710 cfg_idx,
2711 &(
2712 s1_start,
2713 s1_end,
2714 s1_step,
2715 s2_start,
2716 s2_end,
2717 s2_step,
2718 s3_start,
2719 s3_end,
2720 s3_step,
2721 s4_start,
2722 s4_end,
2723 s4_step,
2724 r1_start,
2725 r1_end,
2726 r1_step,
2727 r2_start,
2728 r2_end,
2729 r2_step,
2730 r3_start,
2731 r3_end,
2732 r3_step,
2733 r4_start,
2734 r4_end,
2735 r4_step,
2736 sig_start,
2737 sig_end,
2738 sig_step,
2739 ),
2740 ) in test_configs.iter().enumerate()
2741 {
2742 let output = KstBatchBuilder::new()
2743 .kernel(kernel)
2744 .sma_period1_range(s1_start, s1_end, s1_step)
2745 .sma_period2_range(s2_start, s2_end, s2_step)
2746 .sma_period3_range(s3_start, s3_end, s3_step)
2747 .sma_period4_range(s4_start, s4_end, s4_step)
2748 .roc_period1_range(r1_start, r1_end, r1_step)
2749 .roc_period2_range(r2_start, r2_end, r2_step)
2750 .roc_period3_range(r3_start, r3_end, r3_step)
2751 .roc_period4_range(r4_start, r4_end, r4_step)
2752 .signal_period_range(sig_start, sig_end, sig_step)
2753 .apply_candles(&c, "close")?;
2754
2755 for (idx, &val) in output.lines.iter().enumerate() {
2756 if val.is_nan() {
2757 continue;
2758 }
2759
2760 let bits = val.to_bits();
2761 let row = idx / output.cols;
2762 let col = idx % output.cols;
2763 let combo = &output.combos[row];
2764
2765 if bits == 0x11111111_11111111 {
2766 panic!(
2767 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
2768 at row {} col {} (flat index {}) in KST lines with params: \
2769 sma_periods=({},{},{},{}), roc_periods=({},{},{},{}), signal_period={}",
2770 test,
2771 cfg_idx,
2772 val,
2773 bits,
2774 row,
2775 col,
2776 idx,
2777 combo.sma_period1.unwrap_or(10),
2778 combo.sma_period2.unwrap_or(10),
2779 combo.sma_period3.unwrap_or(10),
2780 combo.sma_period4.unwrap_or(15),
2781 combo.roc_period1.unwrap_or(10),
2782 combo.roc_period2.unwrap_or(15),
2783 combo.roc_period3.unwrap_or(20),
2784 combo.roc_period4.unwrap_or(30),
2785 combo.signal_period.unwrap_or(9)
2786 );
2787 }
2788
2789 if bits == 0x22222222_22222222 {
2790 panic!(
2791 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
2792 at row {} col {} (flat index {}) in KST lines with params: \
2793 sma_periods=({},{},{},{}), roc_periods=({},{},{},{}), signal_period={}",
2794 test,
2795 cfg_idx,
2796 val,
2797 bits,
2798 row,
2799 col,
2800 idx,
2801 combo.sma_period1.unwrap_or(10),
2802 combo.sma_period2.unwrap_or(10),
2803 combo.sma_period3.unwrap_or(10),
2804 combo.sma_period4.unwrap_or(15),
2805 combo.roc_period1.unwrap_or(10),
2806 combo.roc_period2.unwrap_or(15),
2807 combo.roc_period3.unwrap_or(20),
2808 combo.roc_period4.unwrap_or(30),
2809 combo.signal_period.unwrap_or(9)
2810 );
2811 }
2812
2813 if bits == 0x33333333_33333333 {
2814 panic!(
2815 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
2816 at row {} col {} (flat index {}) in KST lines with params: \
2817 sma_periods=({},{},{},{}), roc_periods=({},{},{},{}), signal_period={}",
2818 test,
2819 cfg_idx,
2820 val,
2821 bits,
2822 row,
2823 col,
2824 idx,
2825 combo.sma_period1.unwrap_or(10),
2826 combo.sma_period2.unwrap_or(10),
2827 combo.sma_period3.unwrap_or(10),
2828 combo.sma_period4.unwrap_or(15),
2829 combo.roc_period1.unwrap_or(10),
2830 combo.roc_period2.unwrap_or(15),
2831 combo.roc_period3.unwrap_or(20),
2832 combo.roc_period4.unwrap_or(30),
2833 combo.signal_period.unwrap_or(9)
2834 );
2835 }
2836 }
2837
2838 for (idx, &val) in output.signals.iter().enumerate() {
2839 if val.is_nan() {
2840 continue;
2841 }
2842
2843 let bits = val.to_bits();
2844 let row = idx / output.cols;
2845 let col = idx % output.cols;
2846 let combo = &output.combos[row];
2847
2848 if bits == 0x11111111_11111111 {
2849 panic!(
2850 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
2851 at row {} col {} (flat index {}) in KST signals with params: \
2852 sma_periods=({},{},{},{}), roc_periods=({},{},{},{}), signal_period={}",
2853 test,
2854 cfg_idx,
2855 val,
2856 bits,
2857 row,
2858 col,
2859 idx,
2860 combo.sma_period1.unwrap_or(10),
2861 combo.sma_period2.unwrap_or(10),
2862 combo.sma_period3.unwrap_or(10),
2863 combo.sma_period4.unwrap_or(15),
2864 combo.roc_period1.unwrap_or(10),
2865 combo.roc_period2.unwrap_or(15),
2866 combo.roc_period3.unwrap_or(20),
2867 combo.roc_period4.unwrap_or(30),
2868 combo.signal_period.unwrap_or(9)
2869 );
2870 }
2871
2872 if bits == 0x22222222_22222222 {
2873 panic!(
2874 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
2875 at row {} col {} (flat index {}) in KST signals with params: \
2876 sma_periods=({},{},{},{}), roc_periods=({},{},{},{}), signal_period={}",
2877 test,
2878 cfg_idx,
2879 val,
2880 bits,
2881 row,
2882 col,
2883 idx,
2884 combo.sma_period1.unwrap_or(10),
2885 combo.sma_period2.unwrap_or(10),
2886 combo.sma_period3.unwrap_or(10),
2887 combo.sma_period4.unwrap_or(15),
2888 combo.roc_period1.unwrap_or(10),
2889 combo.roc_period2.unwrap_or(15),
2890 combo.roc_period3.unwrap_or(20),
2891 combo.roc_period4.unwrap_or(30),
2892 combo.signal_period.unwrap_or(9)
2893 );
2894 }
2895
2896 if bits == 0x33333333_33333333 {
2897 panic!(
2898 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
2899 at row {} col {} (flat index {}) in KST signals with params: \
2900 sma_periods=({},{},{},{}), roc_periods=({},{},{},{}), signal_period={}",
2901 test,
2902 cfg_idx,
2903 val,
2904 bits,
2905 row,
2906 col,
2907 idx,
2908 combo.sma_period1.unwrap_or(10),
2909 combo.sma_period2.unwrap_or(10),
2910 combo.sma_period3.unwrap_or(10),
2911 combo.sma_period4.unwrap_or(15),
2912 combo.roc_period1.unwrap_or(10),
2913 combo.roc_period2.unwrap_or(15),
2914 combo.roc_period3.unwrap_or(20),
2915 combo.roc_period4.unwrap_or(30),
2916 combo.signal_period.unwrap_or(9)
2917 );
2918 }
2919 }
2920 }
2921
2922 Ok(())
2923 }
2924
    /// Release-build stand-in for the poison scan. The scanning variant above is compiled
    /// only under `debug_assertions`; this one always reports success.
    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
2929
    // Expand scalar / AVX2 / AVX-512 / auto-detect #[test] wrappers for each batch check
    // (AVX variants only with `nightly-avx` on x86_64).
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
2932
2933 #[test]
2934 fn check_empty_input_error() {
2935 let empty_data: Vec<f64> = vec![];
2936 let params = KstParams::default();
2937 let input = KstInput::from_slice(&empty_data, params);
2938
2939 match kst(&input) {
2940 Err(KstError::EmptyInputData) => {}
2941 Err(e) => panic!("Expected EmptyInputData, got: {:?}", e),
2942 Ok(_) => panic!("Empty input should have failed"),
2943 }
2944
2945 let nan_data = vec![f64::NAN; 10];
2946 let input2 = KstInput::from_slice(&nan_data, params);
2947
2948 match kst(&input2) {
2949 Err(KstError::AllValuesNaN) => {}
2950 Err(e) => panic!("Expected AllValuesNaN, got: {:?}", e),
2951 Ok(_) => panic!("All NaN should have failed"),
2952 }
2953 }
2954}
2955
2956#[cfg(feature = "python")]
2957use crate::utilities::kernel_validation::validate_kernel;
2958#[cfg(feature = "python")]
2959use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
2960#[cfg(feature = "python")]
2961use pyo3::types::{PyDict, PyList};
2962
/// Python binding for the KST oscillator on a 1-D float64 array.
///
/// Omitted periods fall back to SMA 10/10/10/15, ROC 10/15/20/30 and signal 9.
/// Returns `(line, signal)` as two NumPy arrays.
///
/// # Errors
/// Propagates errors from `validate_kernel`; KST computation errors become `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "kst")]
#[pyo3(signature=(data,
    sma_period1=None, sma_period2=None, sma_period3=None, sma_period4=None,
    roc_period1=None, roc_period2=None, roc_period3=None, roc_period4=None,
    signal_period=None, kernel=None))]
pub fn kst_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    sma_period1: Option<usize>,
    sma_period2: Option<usize>,
    sma_period3: Option<usize>,
    sma_period4: Option<usize>,
    roc_period1: Option<usize>,
    roc_period2: Option<usize>,
    roc_period3: Option<usize>,
    roc_period4: Option<usize>,
    signal_period: Option<usize>,
    kernel: Option<&str>,
) -> PyResult<(Bound<'py, PyArray1<f64>>, Bound<'py, PyArray1<f64>>)> {
    let values = data.as_slice()?;
    let kern = validate_kernel(kernel, false)?;

    let params = KstParams {
        sma_period1: sma_period1.or(Some(10)),
        sma_period2: sma_period2.or(Some(10)),
        sma_period3: sma_period3.or(Some(10)),
        sma_period4: sma_period4.or(Some(15)),
        roc_period1: roc_period1.or(Some(10)),
        roc_period2: roc_period2.or(Some(15)),
        roc_period3: roc_period3.or(Some(20)),
        roc_period4: roc_period4.or(Some(30)),
        signal_period: signal_period.or(Some(9)),
    };
    let input = KstInput::from_slice(values, params);

    // The computation touches no Python objects, so release the GIL while it runs.
    let KstOutput { line, signal } = py
        .allow_threads(|| kst_with_kernel(&input, kern))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok((line.into_pyarray(py), signal.into_pyarray(py)))
}
3002
/// Python-facing wrapper around [`KstStream`] for feeding KST one value at a time.
#[cfg(feature = "python")]
#[pyclass(name = "KstStream")]
pub struct KstStreamPy {
    stream: KstStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl KstStreamPy {
    /// Creates a stream. Every period is optional, and `None` is forwarded unchanged to
    /// [`KstStream::try_new`].
    ///
    /// The explicit `signature` makes each argument an optional keyword defaulting to
    /// `None`, matching the `kst` function. Without it, the constructor depends on PyO3's
    /// implicit defaults for trailing `Option` parameters, which newer PyO3 releases deprecate.
    ///
    /// # Errors
    /// Raises `ValueError` if `KstStream::try_new` rejects the parameters.
    #[new]
    #[pyo3(signature = (
        sma_period1=None, sma_period2=None, sma_period3=None, sma_period4=None,
        roc_period1=None, roc_period2=None, roc_period3=None, roc_period4=None,
        signal_period=None
    ))]
    fn new(
        sma_period1: Option<usize>,
        sma_period2: Option<usize>,
        sma_period3: Option<usize>,
        sma_period4: Option<usize>,
        roc_period1: Option<usize>,
        roc_period2: Option<usize>,
        roc_period3: Option<usize>,
        roc_period4: Option<usize>,
        signal_period: Option<usize>,
    ) -> PyResult<Self> {
        let params = KstParams {
            sma_period1,
            sma_period2,
            sma_period3,
            sma_period4,
            roc_period1,
            roc_period2,
            roc_period3,
            roc_period4,
            signal_period,
        };
        let stream =
            KstStream::try_new(params).map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok(KstStreamPy { stream })
    }

    /// Feeds one value to [`KstStream::update`] and returns its result unchanged: either
    /// `None` or a `(line, signal)` pair.
    fn update(&mut self, value: f64) -> Option<(f64, f64)> {
        self.stream.update(value)
    }
}
3044
/// Python binding for a KST parameter sweep.
///
/// Each `*_range` is a `(start, end, step)` triple expanded by `expand_grid`. Returns a dict
/// with:
/// - `"line"` and `"signal"`: `(rows, len(data))` matrices, one row per parameter combo;
/// - `"sma1".."sma4"`, `"roc1".."roc4"`, `"sig"`: the per-row period values.
///
/// # Errors
/// Raises `ValueError` for invalid sweeps, an output size that overflows `usize`, or a
/// failed computation. Kernel-name errors come from `validate_kernel`.
#[cfg(feature = "python")]
#[pyfunction(name = "kst_batch")]
#[pyo3(signature=(data,
    sma1_range, sma2_range, sma3_range, sma4_range,
    roc1_range, roc2_range, roc3_range, roc4_range,
    sig_range, kernel=None))]
pub fn kst_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    sma1_range: (usize, usize, usize),
    sma2_range: (usize, usize, usize),
    sma3_range: (usize, usize, usize),
    sma4_range: (usize, usize, usize),
    roc1_range: (usize, usize, usize),
    roc2_range: (usize, usize, usize),
    roc3_range: (usize, usize, usize),
    roc4_range: (usize, usize, usize),
    sig_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    let slice = data.as_slice()?;
    let sweep = KstBatchRange {
        sma_period1: sma1_range,
        sma_period2: sma2_range,
        sma_period3: sma3_range,
        sma_period4: sma4_range,
        roc_period1: roc1_range,
        roc_period2: roc2_range,
        roc_period3: roc3_range,
        roc_period4: roc4_range,
        signal_period: sig_range,
    };
    let kern = validate_kernel(kernel, true)?;

    let combos = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let rows = combos.len();
    let cols = slice.len();
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("kst: size overflow in batch output"))?;

    // Output arrays are allocated uninitialised and handed to `kst_batch_inner_into` to fill.
    let line_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let sig_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let line_out = unsafe { line_arr.as_slice_mut()? };
    let sig_out = unsafe { sig_arr.as_slice_mut()? };

    py.allow_threads(|| {
        let batch_kernel = match kern {
            Kernel::Auto => detect_best_batch_kernel(),
            explicit => explicit,
        };
        // Every batch kernel currently runs the scalar per-row kernel.
        let row_kernel = match batch_kernel {
            Kernel::ScalarBatch | Kernel::Avx2Batch | Kernel::Avx512Batch => Kernel::Scalar,
            _ => Kernel::Scalar,
        };
        kst_batch_inner_into(slice, &sweep, row_kernel, true, line_out, sig_out)
    })
    .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let d = pyo3::types::PyDict::new(py);
    d.set_item("line", line_arr.reshape((rows, cols))?)?;
    d.set_item("signal", sig_arr.reshape((rows, cols))?)?;

    // Per-row parameter columns, inserted in this order.
    let period_columns: [(&str, fn(&KstParams) -> Option<usize>); 9] = [
        ("sma1", |c: &KstParams| c.sma_period1),
        ("sma2", |c: &KstParams| c.sma_period2),
        ("sma3", |c: &KstParams| c.sma_period3),
        ("sma4", |c: &KstParams| c.sma_period4),
        ("roc1", |c: &KstParams| c.roc_period1),
        ("roc2", |c: &KstParams| c.roc_period2),
        ("roc3", |c: &KstParams| c.roc_period3),
        ("roc4", |c: &KstParams| c.roc_period4),
        ("sig", |c: &KstParams| c.signal_period),
    ];
    for (key, period_of) in period_columns {
        let column: Vec<u64> = combos
            .iter()
            .map(|c| period_of(c).unwrap() as u64)
            .collect();
        d.set_item(key, column.into_pyarray(py))?;
    }

    Ok(d)
}
3187
/// Registers the KST Python API on `m`:
/// - the `kst` and `kst_batch` functions;
/// - the `KstStream` streaming class, which was previously defined in this file but never
///   registered here;
/// - with the `cuda` feature, the two device-resident CUDA entry points.
#[cfg(feature = "python")]
pub fn register_kst_module(m: &Bound<'_, pyo3::types::PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(kst_py, m)?)?;
    m.add_function(wrap_pyfunction!(kst_batch_py, m)?)?;
    m.add_class::<KstStreamPy>()?;
    #[cfg(feature = "cuda")]
    {
        m.add_function(wrap_pyfunction!(kst_cuda_batch_dev_py, m)?)?;
        m.add_function(wrap_pyfunction!(kst_cuda_many_series_one_param_dev_py, m)?)?;
    }
    Ok(())
}
3199
/// Runs a KST parameter sweep on the GPU and returns the line and signal matrices as
/// device-resident arrays.
///
/// Both results share the CUDA context created for this call and record its device id.
///
/// # Errors
/// Raises `ValueError` when CUDA is unavailable, the device cannot be initialised, or the
/// kernel launch fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "kst_cuda_batch_dev")]
#[pyo3(signature = (
    data_f32,
    s1_range, s2_range, s3_range, s4_range,
    r1_range, r2_range, r3_range, r4_range,
    sig_range,
    device_id=0
))]
pub fn kst_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: PyReadonlyArray1<'_, f32>,
    s1_range: (usize, usize, usize),
    s2_range: (usize, usize, usize),
    s3_range: (usize, usize, usize),
    s4_range: (usize, usize, usize),
    r1_range: (usize, usize, usize),
    r2_range: (usize, usize, usize),
    r3_range: (usize, usize, usize),
    r4_range: (usize, usize, usize),
    sig_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<(DeviceArrayF32Py, DeviceArrayF32Py)> {
    use crate::cuda::cuda_available;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let prices = data_f32.as_slice()?;
    let sweep = KstBatchRange {
        sma_period1: s1_range,
        sma_period2: s2_range,
        sma_period3: s3_range,
        sma_period4: s4_range,
        roc_period1: r1_range,
        roc_period2: r2_range,
        roc_period3: r3_range,
        roc_period4: r4_range,
        signal_period: sig_range,
    };

    // Device setup and launch happen without the GIL.
    let (pair, ctx, dev) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaKst::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev = cuda.device_id();
        let (pair, _combos) = cuda
            .kst_batch_dev(prices, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((pair, ctx, dev))
    })?;

    let line = DeviceArrayF32Py {
        inner: pair.line,
        _ctx: Some(ctx.clone()),
        device_id: Some(dev),
    };
    let signal = DeviceArrayF32Py {
        inner: pair.signal,
        _ctx: Some(ctx),
        device_id: Some(dev),
    };
    Ok((line, signal))
}
3260
/// Runs KST with one parameter set over many series on the GPU and returns the line and
/// signal outputs as device-resident arrays.
///
/// `data_tm_f32` is a flat buffer handed to the time-major CUDA entry point together with
/// `cols` and `rows`. Presumably `cols` is the number of series and `rows` the number of
/// time steps — TODO confirm against `CudaKst`.
///
/// # Errors
/// Raises `ValueError` when CUDA is unavailable, the device cannot be initialised, or the
/// kernel launch fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "kst_cuda_many_series_one_param_dev")]
#[pyo3(signature = (
    data_tm_f32,
    cols, rows,
    s1, s2, s3, s4,
    r1, r2, r3, r4,
    sig,
    device_id=0
))]
pub fn kst_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    s1: usize,
    s2: usize,
    s3: usize,
    s4: usize,
    r1: usize,
    r2: usize,
    r3: usize,
    r4: usize,
    sig: usize,
    device_id: usize,
) -> PyResult<(DeviceArrayF32Py, DeviceArrayF32Py)> {
    use crate::cuda::cuda_available;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let prices_tm = data_tm_f32.as_slice()?;
    let params = KstParams {
        sma_period1: Some(s1),
        sma_period2: Some(s2),
        sma_period3: Some(s3),
        sma_period4: Some(s4),
        roc_period1: Some(r1),
        roc_period2: Some(r2),
        roc_period3: Some(r3),
        roc_period4: Some(r4),
        signal_period: Some(sig),
    };
    // Device setup and launch happen without the GIL.
    let (pair, ctx, dev) = py.allow_threads(|| {
        let cuda = CudaKst::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev = cuda.device_id();
        // Borrow `params` for the launch (this argument was previously mis-encoded as `¶ms`).
        cuda.kst_many_series_one_param_time_major_dev(prices_tm, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))
            .map(|pair| (pair, ctx, dev))
    })?;
    Ok((
        DeviceArrayF32Py {
            inner: pair.line,
            _ctx: Some(ctx.clone()),
            device_id: Some(dev),
        },
        DeviceArrayF32Py {
            inner: pair.signal,
            _ctx: Some(ctx),
            device_id: Some(dev),
        },
    ))
}
3324
/// Serialized result of the JS `kst` binding: a `rows` x `cols` matrix flattened into
/// `values`. As filled by `kst_js`, row 0 is the KST line and row 1 the signal.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct KstJsResult {
    /// Flattened matrix values, one row after another.
    pub values: Vec<f64>,
    /// Number of rows in `values`.
    pub rows: usize,
    /// Number of values per row (the input length).
    pub cols: usize,
}
3332
/// JS binding for KST on a single series.
///
/// Returns a serialized `KstJsResult` with `rows: 2`: the KST line followed by the signal,
/// each `data.len()` values long.
///
/// # Errors
/// Returns the KST error text, or a serialization error, as a JS string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "kst")]
pub fn kst_js(
    data: &[f64],
    sma1: usize,
    sma2: usize,
    sma3: usize,
    sma4: usize,
    roc1: usize,
    roc2: usize,
    roc3: usize,
    roc4: usize,
    sig: usize,
) -> Result<JsValue, JsValue> {
    let params = KstParams {
        sma_period1: Some(sma1),
        sma_period2: Some(sma2),
        sma_period3: Some(sma3),
        sma_period4: Some(sma4),
        roc_period1: Some(roc1),
        roc_period2: Some(roc2),
        roc_period3: Some(roc3),
        roc_period4: Some(roc4),
        signal_period: Some(sig),
    };
    let input = KstInput::from_slice(data, params);
    let n = data.len();

    // One buffer laid out as [line..., signal...]; each half is filled in place.
    let mut values = vec![0.0; 2 * n];
    let (line, signal) = values.split_at_mut(n);
    kst_into_slice(line, signal, &input, detect_best_kernel())
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let result = KstJsResult {
        values,
        rows: 2,
        cols: n,
    };
    serde_wasm_bindgen::to_value(&result)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {e}")))
}
3375
/// Raw-pointer KST entry point for JS callers that manage WASM memory themselves.
///
/// Reads `len` values from `in_ptr` and writes the KST line to `out_line_ptr` and the
/// signal to `out_signal_ptr` (each `len` values). If the input overlaps either output, it
/// is copied first so the computation never reads a buffer it is writing. The two output
/// buffers must not overlap each other.
///
/// # Errors
/// Returns an error for:
/// - null pointers;
/// - overlapping output buffers;
/// - a `len` whose byte size overflows `usize`;
/// - any error reported by `kst_into_slice`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn kst_into(
    in_ptr: *const f64,
    out_line_ptr: *mut f64,
    out_signal_ptr: *mut f64,
    len: usize,
    sma1: usize,
    sma2: usize,
    sma3: usize,
    sma4: usize,
    roc1: usize,
    roc2: usize,
    roc3: usize,
    roc4: usize,
    sig: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_line_ptr.is_null() || out_signal_ptr.is_null() {
        return Err(JsValue::from_str("null pointer"));
    }

    // Byte span [begin, end) of each buffer. Checked arithmetic keeps a huge `len` from
    // wrapping around and silently defeating the overlap tests below.
    let byte_len = len
        .checked_mul(std::mem::size_of::<f64>())
        .ok_or_else(|| JsValue::from_str("kst_into: buffer length overflow"))?;
    let span = |ptr: usize| {
        ptr.checked_add(byte_len)
            .map(|end| (ptr, end))
            .ok_or_else(|| JsValue::from_str("kst_into: buffer length overflow"))
    };
    let in_span = span(in_ptr as usize)?;
    let line_span = span(out_line_ptr as usize)?;
    let signal_span = span(out_signal_ptr as usize)?;
    let overlap = |a: (usize, usize), b: (usize, usize)| a.0 < b.1 && b.0 < a.1;

    // Two aliasing `&mut` slices would be undefined behaviour, so refuse overlapping outputs.
    if overlap(line_span, signal_span) {
        return Err(JsValue::from_str(
            "kst_into: line and signal output buffers must not overlap",
        ));
    }

    let prm = KstParams {
        sma_period1: Some(sma1),
        sma_period2: Some(sma2),
        sma_period3: Some(sma3),
        sma_period4: Some(sma4),
        roc_period1: Some(roc1),
        roc_period2: Some(roc2),
        roc_period3: Some(roc3),
        roc_period4: Some(roc4),
        signal_period: Some(sig),
    };

    // SAFETY: all three pointers are non-null, and the caller must guarantee that each
    // addresses `len` valid f64s. The outputs were checked not to overlap each other. The
    // input is copied before the output slices are created whenever it overlaps an output,
    // so the slice read by `kst_into_slice` never aliases a slice it writes.
    unsafe {
        let data_slice = std::slice::from_raw_parts(in_ptr, len);
        let shadow;
        let data = if overlap(in_span, line_span) || overlap(in_span, signal_span) {
            shadow = data_slice.to_vec();
            &shadow[..]
        } else {
            data_slice
        };

        let input = KstInput::from_slice(data, prm);

        let ldst = std::slice::from_raw_parts_mut(out_line_ptr, len);
        let sdst = std::slice::from_raw_parts_mut(out_signal_ptr, len);

        kst_into_slice(ldst, sdst, &input, detect_best_kernel())
            .map_err(|e| JsValue::from_str(&e.to_string()))
    }
}
3435
/// Allocates an uninitialised buffer with capacity for `len` f64 values and hands ownership
/// of it to the caller. Release it with `kst_free(ptr, len)`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn kst_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop suppresses the Vec destructor, so the allocation outlives this call.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
3444
/// Releases a buffer obtained from `kst_alloc(len)`. `len` must be the same value that was
/// passed to `kst_alloc`. A null pointer is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn kst_free(ptr: *mut f64, len: usize) {
    // `Vec::from_raw_parts` requires a non-null pointer.
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` came from `kst_alloc(len)`, i.e. a `Vec<f64>` allocated with capacity
    // `len`. Length 0 avoids asserting that the (possibly never written) elements are
    // initialised. Only the capacity determines how the memory is deallocated.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
3452
/// Sweep description with one `(start, end, step)` triple per KST period.
///
/// NOTE(review): `kst_batch` in this file deserializes its config directly into
/// `KstBatchRange` rather than this type — confirm whether any JS caller still uses it.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct KstBatchConfig {
    pub sma_period1_range: (usize, usize, usize),
    pub sma_period2_range: (usize, usize, usize),
    pub sma_period3_range: (usize, usize, usize),
    pub sma_period4_range: (usize, usize, usize),
    pub roc_period1_range: (usize, usize, usize),
    pub roc_period2_range: (usize, usize, usize),
    pub roc_period3_range: (usize, usize, usize),
    pub roc_period4_range: (usize, usize, usize),
    pub signal_period_range: (usize, usize, usize),
}
3466
/// Serialized result of the JS `kst_batch` binding.
///
/// As built by `kst_batch`, `values` holds every KST line row followed by every signal row.
/// `rows` is therefore twice `combos.len()`, and `cols` is the input length.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct KstBatchJsOutput {
    /// Flattened rows: all line rows first, then all signal rows.
    pub values: Vec<f64>,
    /// Parameter set for each line row (and, in the same order, each signal row).
    pub combos: Vec<KstParams>,
    /// Total number of rows in `values`.
    pub rows: usize,
    /// Values per row.
    pub cols: usize,
}
3475
/// JS binding for a KST parameter sweep.
///
/// `config` is deserialized into a `KstBatchRange`. The result is a serialized
/// `KstBatchJsOutput` whose `values` stacks all line rows on top of all signal rows.
///
/// # Errors
/// Returns a JS string for an invalid config, an invalid sweep, a size overflow, a failed
/// computation, or a serialization failure.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "kst_batch")]
pub fn kst_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let sweep: KstBatchRange = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {e}")))?;
    let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let (n_combos, n_points) = (combos.len(), data.len());

    let cells = n_combos
        .checked_mul(n_points)
        .ok_or_else(|| JsValue::from_str("kst: size overflow in kst_batch_unified_js"))?;

    let mut values = vec![0.0; cells];
    let mut signals = vec![0.0; cells];
    kst_batch_inner_into(
        data,
        &sweep,
        detect_best_kernel(),
        false,
        &mut values,
        &mut signals,
    )
    .map_err(|e| JsValue::from_str(&e.to_string()))?;

    // Append the signal matrix below the line matrix.
    values.extend(signals);

    let out = KstBatchJsOutput {
        values,
        combos,
        rows: n_combos * 2,
        cols: n_points,
    };
    serde_wasm_bindgen::to_value(&out)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {e}")))
}
3512
/// Raw-pointer wasm entry point that evaluates KST over a parameter sweep.
///
/// Results are written into caller-owned line and signal buffers. The return value is the
/// number of parameter combinations (rows) written.
///
/// Buffer requirements:
/// - `in_ptr` must address `len` initialized `f64`s.
/// - `line_out_ptr` and `signal_out_ptr` must each address at least `rows * len` writable
///   `f64`s, where `rows` is the number of combinations the sweep expands to (the returned
///   value).
/// - Each output is row-major: one row of `len` values per combination.
/// - The two output buffers must not overlap.
/// - The input may overlap an output; in that case it is copied before any output is written.
///
/// # Errors
/// Returns a string `JsValue` when any of the following happens:
/// - a pointer is null;
/// - the output buffers overlap;
/// - the sweep cannot be expanded;
/// - a buffer size overflows `usize`;
/// - the batch computation fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn kst_batch_into(
    in_ptr: *const f64,
    line_out_ptr: *mut f64,
    signal_out_ptr: *mut f64,
    len: usize,
    sma_period1_start: usize,
    sma_period1_end: usize,
    sma_period1_step: usize,
    sma_period2_start: usize,
    sma_period2_end: usize,
    sma_period2_step: usize,
    sma_period3_start: usize,
    sma_period3_end: usize,
    sma_period3_step: usize,
    sma_period4_start: usize,
    sma_period4_end: usize,
    sma_period4_step: usize,
    roc_period1_start: usize,
    roc_period1_end: usize,
    roc_period1_step: usize,
    roc_period2_start: usize,
    roc_period2_end: usize,
    roc_period2_step: usize,
    roc_period3_start: usize,
    roc_period3_end: usize,
    roc_period3_step: usize,
    roc_period4_start: usize,
    roc_period4_end: usize,
    roc_period4_step: usize,
    signal_period_start: usize,
    signal_period_end: usize,
    signal_period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || line_out_ptr.is_null() || signal_out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to kst_batch_into"));
    }

    let sweep = KstBatchRange {
        sma_period1: (sma_period1_start, sma_period1_end, sma_period1_step),
        sma_period2: (sma_period2_start, sma_period2_end, sma_period2_step),
        sma_period3: (sma_period3_start, sma_period3_end, sma_period3_step),
        sma_period4: (sma_period4_start, sma_period4_end, sma_period4_step),
        roc_period1: (roc_period1_start, roc_period1_end, roc_period1_step),
        roc_period2: (roc_period2_start, roc_period2_end, roc_period2_step),
        roc_period3: (roc_period3_start, roc_period3_end, roc_period3_step),
        roc_period4: (roc_period4_start, roc_period4_end, roc_period4_step),
        signal_period: (signal_period_start, signal_period_end, signal_period_step),
    };

    // Take the row count from the same grid expansion used by the JS-object entry point,
    // rather than re-deriving it by hand, so the output slices match the combinations that
    // will actually be written.
    let rows = expand_grid(&sweep)
        .map_err(|e| JsValue::from_str(&e.to_string()))?
        .len();
    let cols = len;

    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("kst: size overflow in kst_batch_into buffers"))?;

    // Byte extents of each buffer, used to detect overlapping regions before any slice exists.
    let elem = std::mem::size_of::<f64>();
    let in_bytes = len
        .checked_mul(elem)
        .ok_or_else(|| JsValue::from_str("kst: size overflow in kst_batch_into buffers"))?;
    let out_bytes = total
        .checked_mul(elem)
        .ok_or_else(|| JsValue::from_str("kst: size overflow in kst_batch_into buffers"))?;
    let overlaps = |a: usize, a_len: usize, b: usize, b_len: usize| {
        a_len != 0 && b_len != 0 && a < b.saturating_add(b_len) && b < a.saturating_add(a_len)
    };
    let in_addr = in_ptr as usize;
    let line_addr = line_out_ptr as usize;
    let signal_addr = signal_out_ptr as usize;

    // Two live `&mut` slices over the same memory would be undefined behaviour, and the
    // results could not be separated anyway.
    if overlaps(line_addr, out_bytes, signal_addr, out_bytes) {
        return Err(JsValue::from_str(
            "kst_batch_into: line and signal output buffers must not overlap",
        ));
    }

    // If the input shares memory with an output, copy it first. Reading it through a shared
    // slice while writing through a mutable one would be undefined behaviour.
    let input_aliases_output = overlaps(in_addr, in_bytes, line_addr, out_bytes)
        || overlaps(in_addr, in_bytes, signal_addr, out_bytes);
    let owned_input: Vec<f64>;
    let data: &[f64] = if input_aliases_output {
        // SAFETY: `in_ptr` is non-null and the caller guarantees it addresses `len`
        // initialized f64s. The borrow ends with `to_vec`, before any mutable slice exists.
        owned_input = unsafe { std::slice::from_raw_parts(in_ptr, len) }.to_vec();
        &owned_input
    } else {
        // SAFETY: `in_ptr` is non-null, addresses `len` initialized f64s (caller contract),
        // and the region is disjoint from both output buffers (checked above).
        unsafe { std::slice::from_raw_parts(in_ptr, len) }
    };

    // SAFETY: both pointers are non-null, and the caller guarantees each addresses `total`
    // writable f64s. The checks above guarantee the two regions are disjoint from each other
    // and from the memory behind `data`.
    let (line_out, signal_out) = unsafe {
        (
            std::slice::from_raw_parts_mut(line_out_ptr, total),
            std::slice::from_raw_parts_mut(signal_out_ptr, total),
        )
    };

    kst_batch_inner_into(data, &sweep, Kernel::Auto, false, line_out, signal_out)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    Ok(rows)
}