1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::PyDict;
9
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use serde::{Deserialize, Serialize};
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use wasm_bindgen::prelude::*;
14
15use crate::utilities::data_loader::{source_type, Candles};
16use crate::utilities::enums::Kernel;
17use crate::utilities::helpers::{
18 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
19 make_uninit_matrix,
20};
21#[cfg(feature = "python")]
22use crate::utilities::kernel_validation::validate_kernel;
23#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
24use core::arch::x86_64::*;
25#[cfg(not(target_arch = "wasm32"))]
26use rayon::prelude::*;
27use std::convert::AsRef;
28use std::error::Error;
29use std::mem::MaybeUninit;
30use thiserror::Error;
31
32#[cfg(all(feature = "python", feature = "cuda"))]
33use crate::{
34 cuda::moving_averages::CudaTsf,
35 indicators::moving_averages::alma::{make_device_array_py, DeviceArrayF32Py},
36};
37
38impl<'a> AsRef<[f64]> for TsfInput<'a> {
39 #[inline(always)]
40 fn as_ref(&self) -> &[f64] {
41 match &self.data {
42 TsfData::Slice(slice) => slice,
43 TsfData::Candles { candles, source } => source_type(candles, source),
44 }
45 }
46}
47
/// Where a [`TsfInput`] reads its values from.
#[derive(Debug, Clone)]
pub enum TsfData<'a> {
    /// A column of a candle set; `source` is the column name handed to `source_type`
    /// (e.g. `"close"`).
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A raw, already-extracted series.
    Slice(&'a [f64]),
}
56
/// Result of a single-series TSF computation.
#[derive(Debug, Clone)]
pub struct TsfOutput {
    /// One value per input sample; the warm-up prefix (before the first full window) is NaN.
    pub values: Vec<f64>,
}
61
/// Parameters for the Time Series Forecast.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct TsfParams {
    /// Regression window length. `None` means the default of 14; values below 2 are rejected.
    pub period: Option<usize>,
}
70
71impl Default for TsfParams {
72 fn default() -> Self {
73 Self { period: Some(14) }
74 }
75}
76
/// Input to the TSF functions: a data source plus its parameters.
#[derive(Debug, Clone)]
pub struct TsfInput<'a> {
    /// Series to read (candle column or raw slice).
    pub data: TsfData<'a>,
    /// Indicator parameters; see [`TsfInput::get_period`] for the effective period.
    pub params: TsfParams,
}
82
83impl<'a> TsfInput<'a> {
84 #[inline]
85 pub fn from_candles(c: &'a Candles, s: &'a str, p: TsfParams) -> Self {
86 Self {
87 data: TsfData::Candles {
88 candles: c,
89 source: s,
90 },
91 params: p,
92 }
93 }
94 #[inline]
95 pub fn from_slice(sl: &'a [f64], p: TsfParams) -> Self {
96 Self {
97 data: TsfData::Slice(sl),
98 params: p,
99 }
100 }
101 #[inline]
102 pub fn with_default_candles(c: &'a Candles) -> Self {
103 Self::from_candles(c, "close", TsfParams::default())
104 }
105 #[inline]
106 pub fn get_period(&self) -> usize {
107 self.params.period.unwrap_or(14)
108 }
109}
110
/// Fluent builder for one-shot or streaming TSF computations.
#[derive(Copy, Clone, Debug)]
pub struct TsfBuilder {
    /// Requested period; `None` lets the downstream code fall back to 14.
    period: Option<usize>,
    /// Kernel passed to `tsf_with_kernel` by `apply`/`apply_slice`.
    kernel: Kernel,
}
116
117impl Default for TsfBuilder {
118 fn default() -> Self {
119 Self {
120 period: None,
121 kernel: Kernel::Auto,
122 }
123 }
124}
125
126impl TsfBuilder {
127 #[inline(always)]
128 pub fn new() -> Self {
129 Self::default()
130 }
131 #[inline(always)]
132 pub fn period(mut self, n: usize) -> Self {
133 self.period = Some(n);
134 self
135 }
136 #[inline(always)]
137 pub fn kernel(mut self, k: Kernel) -> Self {
138 self.kernel = k;
139 self
140 }
141 #[inline(always)]
142 pub fn apply(self, c: &Candles) -> Result<TsfOutput, TsfError> {
143 let p = TsfParams {
144 period: self.period,
145 };
146 let i = TsfInput::from_candles(c, "close", p);
147 tsf_with_kernel(&i, self.kernel)
148 }
149 #[inline(always)]
150 pub fn apply_slice(self, d: &[f64]) -> Result<TsfOutput, TsfError> {
151 let p = TsfParams {
152 period: self.period,
153 };
154 let i = TsfInput::from_slice(d, p);
155 tsf_with_kernel(&i, self.kernel)
156 }
157 #[inline(always)]
158 pub fn into_stream(self) -> Result<TsfStream, TsfError> {
159 let p = TsfParams {
160 period: self.period,
161 };
162 TsfStream::try_new(p)
163 }
164}
165
/// Errors produced by the TSF single-series, batch, and streaming APIs.
#[derive(Debug, Error)]
pub enum TsfError {
    /// The input series has no elements.
    #[error("tsf: Input data slice is empty.")]
    EmptyInputData,
    /// Every input value is NaN, so no first valid sample exists.
    #[error("tsf: All values are NaN.")]
    AllValuesNaN,
    /// The period is longer than the input series.
    #[error("tsf: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },
    /// Fewer than `needed` samples remain after the leading NaNs.
    #[error("tsf: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// A caller-provided output buffer does not have the required length.
    #[error("tsf: Output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// The period is below 2, the minimum needed to fit a line.
    #[error("tsf: Period must be at least 2 for linear regression, got {period}")]
    PeriodTooSmall { period: usize },
    /// A batch sweep produced no combinations, or its output size overflows `usize`.
    #[error("tsf: Invalid batch range: start = {start}, end = {end}, step = {step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },
    /// A kernel that the batch path does not accept was requested.
    #[error("tsf: Invalid kernel for batch path: {0:?}")]
    InvalidKernelForBatch(Kernel),
}
189
190#[inline]
191pub fn tsf(input: &TsfInput) -> Result<TsfOutput, TsfError> {
192 tsf_with_kernel(input, Kernel::Auto)
193}
194
195#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
196pub fn tsf_into(input: &TsfInput, out: &mut [f64]) -> Result<(), TsfError> {
197 let data: &[f64] = match &input.data {
198 TsfData::Candles { candles, source } => source_type(candles, source),
199 TsfData::Slice(sl) => sl,
200 };
201
202 let len = data.len();
203 if len == 0 {
204 return Err(TsfError::EmptyInputData);
205 }
206
207 let first = data
208 .iter()
209 .position(|x| !x.is_nan())
210 .ok_or(TsfError::AllValuesNaN)?;
211 let period = input.get_period();
212
213 if period < 2 {
214 return Err(TsfError::PeriodTooSmall { period });
215 }
216 if period > len {
217 return Err(TsfError::InvalidPeriod {
218 period,
219 data_len: len,
220 });
221 }
222 if (len - first) < period {
223 return Err(TsfError::NotEnoughValidData {
224 needed: period,
225 valid: len - first,
226 });
227 }
228 if out.len() != len {
229 return Err(TsfError::OutputLengthMismatch {
230 expected: len,
231 got: out.len(),
232 });
233 }
234
235 let warmup_end = first + period - 1;
236 let nanq = f64::from_bits(0x7ff8_0000_0000_0000);
237 for v in &mut out[..warmup_end] {
238 *v = nanq;
239 }
240
241 let chosen = match Kernel::Auto {
242 Kernel::Auto => Kernel::Scalar,
243 other => other,
244 };
245
246 unsafe {
247 match chosen {
248 Kernel::Scalar | Kernel::ScalarBatch => tsf_scalar(data, period, first, out),
249 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
250 Kernel::Avx2 | Kernel::Avx2Batch => tsf_avx2(data, period, first, out),
251 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
252 Kernel::Avx512 | Kernel::Avx512Batch => tsf_avx512(data, period, first, out),
253 _ => unreachable!(),
254 }
255 }
256
257 Ok(())
258}
259
260pub fn tsf_with_kernel(input: &TsfInput, kernel: Kernel) -> Result<TsfOutput, TsfError> {
261 let data: &[f64] = match &input.data {
262 TsfData::Candles { candles, source } => source_type(candles, source),
263 TsfData::Slice(sl) => sl,
264 };
265
266 let len = data.len();
267 if len == 0 {
268 return Err(TsfError::EmptyInputData);
269 }
270 let first = data
271 .iter()
272 .position(|x| !x.is_nan())
273 .ok_or(TsfError::AllValuesNaN)?;
274 let period = input.get_period();
275
276 if period < 2 {
277 return Err(TsfError::PeriodTooSmall { period });
278 }
279 if period > len {
280 return Err(TsfError::InvalidPeriod {
281 period,
282 data_len: len,
283 });
284 }
285 if (len - first) < period {
286 return Err(TsfError::NotEnoughValidData {
287 needed: period,
288 valid: len - first,
289 });
290 }
291
292 let chosen = match kernel {
293 Kernel::Auto => Kernel::Scalar,
294 other => other,
295 };
296 let mut out = alloc_with_nan_prefix(len, first + period - 1);
297
298 unsafe {
299 match chosen {
300 Kernel::Scalar | Kernel::ScalarBatch => tsf_scalar(data, period, first, &mut out),
301 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
302 Kernel::Avx2 | Kernel::Avx2Batch => tsf_avx2(data, period, first, &mut out),
303 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
304 Kernel::Avx512 | Kernel::Avx512Batch => tsf_avx512(data, period, first, &mut out),
305 _ => unreachable!(),
306 }
307 }
308
309 Ok(TsfOutput { values: out })
310}
311
312#[inline]
313pub fn tsf_into_slice(dst: &mut [f64], input: &TsfInput, kern: Kernel) -> Result<(), TsfError> {
314 let data: &[f64] = match &input.data {
315 TsfData::Candles { candles, source } => source_type(candles, source),
316 TsfData::Slice(sl) => sl,
317 };
318
319 let len = data.len();
320 if len == 0 {
321 return Err(TsfError::EmptyInputData);
322 }
323 let first = data
324 .iter()
325 .position(|x| !x.is_nan())
326 .ok_or(TsfError::AllValuesNaN)?;
327 let period = input.get_period();
328
329 if period < 2 {
330 return Err(TsfError::PeriodTooSmall { period });
331 }
332 if period > len {
333 return Err(TsfError::InvalidPeriod {
334 period,
335 data_len: len,
336 });
337 }
338 if (len - first) < period {
339 return Err(TsfError::NotEnoughValidData {
340 needed: period,
341 valid: len - first,
342 });
343 }
344 if dst.len() != data.len() {
345 return Err(TsfError::OutputLengthMismatch {
346 expected: data.len(),
347 got: dst.len(),
348 });
349 }
350
351 let chosen = match kern {
352 Kernel::Auto => Kernel::Scalar,
353 k => k,
354 };
355
356 match chosen {
357 Kernel::Scalar | Kernel::ScalarBatch => tsf_scalar(data, period, first, dst),
358 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
359 Kernel::Avx2 | Kernel::Avx2Batch => unsafe { tsf_avx2(data, period, first, dst) },
360 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
361 Kernel::Avx512 | Kernel::Avx512Batch => unsafe { tsf_avx512(data, period, first, dst) },
362 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
363 Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
364 tsf_scalar(data, period, first, dst)
365 }
366 _ => unreachable!(),
367 }
368
369 let warmup_end = first + period - 1;
370 let nanq = f64::from_bits(0x7ff8_0000_0000_0000);
371 for v in &mut dst[..warmup_end] {
372 *v = nanq;
373 }
374
375 Ok(())
376}
377
/// Scalar Time Series Forecast kernel.
///
/// For each index `i` from `first_val + period - 1` up to `data.len() - 1`, fits an
/// ordinary-least-squares line to the `period` samples ending at `i` (abscissae
/// `0..period`, oldest first). It writes that line evaluated one step ahead (`x = period`)
/// into `out[i]`. A window containing a NaN yields NaN. Entries of `out` before the first
/// full window are not written; callers fill that warm-up prefix.
///
/// The window sums `Σy` and `Σ j·y` are slid in O(1) per step. Whenever the previous sums
/// are not finite (the window held a NaN or ±∞, or the sums overflowed), they are
/// recomputed from scratch. This way a single non-finite sample cannot poison every
/// later output.
///
/// Does nothing when `data` is empty, `period` is 0, or there is no complete window.
///
/// # Panics
/// Panics if at least one output would be written and `out.len() < data.len()`.
#[inline(always)]
pub fn tsf_scalar(data: &[f64], period: usize, first_val: usize, out: &mut [f64]) {
    let p = period;
    let n = data.len();
    if n == 0 || p == 0 {
        return;
    }

    let pf = p as f64;

    // Σx and Σx² over x = 0..p: the OLS denominator depends only on the period.
    let mut sum_x = 0.0f64;
    let mut sum_x2 = 0.0f64;
    for x in 0..p {
        let xf = x as f64;
        sum_x += xf;
        sum_x2 += xf * xf;
    }
    let divisor = pf * sum_x2 - sum_x * sum_x;

    // forecast = b + m·p with m = (pΣxy − ΣxΣy)/divisor and b = (Σy − mΣx)/p,
    // rearranged as Σy/p + m·(p − mean_x) so the loop needs no divisions.
    let inv_div = 1.0 / divisor;
    let inv_pf = 1.0 / pf;
    let pf_over_div = pf * inv_div;
    let sumx_over_div = sum_x * inv_div;
    let p_minus_mean_x = pf - sum_x * inv_pf;

    let mut base = first_val;
    let mut i = base + p - 1;
    if i >= n {
        return;
    }
    // Every write below targets an index < n; this one check makes them all in bounds.
    assert!(
        out.len() >= n,
        "tsf_scalar: output length {} is shorter than input length {}",
        out.len(),
        n
    );

    // Sums over the first window data[base..base + p] (base + p == i + 1 <= n).
    let mut s0 = 0.0f64;
    let mut s1 = 0.0f64;
    let mut nan_count = 0usize;
    for (j, &v) in data[base..base + p].iter().enumerate() {
        if v.is_nan() {
            nan_count += 1;
        } else {
            s0 += v;
            s1 += (j as f64) * v;
        }
    }

    if nan_count == 0 {
        let m = s1 * pf_over_div - s0 * sumx_over_div;
        // SAFETY: i < n <= out.len() (asserted above).
        unsafe { *out.get_unchecked_mut(i) = s0 * inv_pf + m * p_minus_mean_x };
    } else {
        // NaN sums mark the state dirty, which forces a recompute on the next clean window.
        s0 = f64::NAN;
        s1 = f64::NAN;
        // SAFETY: i < n <= out.len().
        unsafe { *out.get_unchecked_mut(i) = f64::NAN };
    }

    while i + 1 < n {
        // SAFETY: base < base + p == i + 1 < n.
        let (y_old, y_new) =
            unsafe { (*data.get_unchecked(base), *data.get_unchecked(base + p)) };
        base += 1;
        i += 1;

        if y_old.is_nan() {
            nan_count = nan_count.saturating_sub(1);
        }
        if y_new.is_nan() {
            nan_count = nan_count.saturating_add(1);
        }

        if nan_count == 0 {
            // The sums are NaN whenever the previous window was dirty, so this finiteness
            // test also covers "a NaN just left the window" and ±∞/overflow.
            if s0.is_finite() && s1.is_finite() {
                let new_s0 = s0 + (y_new - y_old);
                // Shifting the window lowers every weight by one: Σj·y' = Σj·y + p·y_new − Σy'.
                let new_s1 = pf * y_new + s1 - new_s0;
                s0 = new_s0;
                s1 = new_s1;
            } else {
                // Rebuild from the window data[base..base + p]; it is NaN-free here.
                let mut r0 = 0.0f64;
                let mut r1 = 0.0f64;
                for (j, &v) in data[base..base + p].iter().enumerate() {
                    r0 += v;
                    r1 += (j as f64) * v;
                }
                s0 = r0;
                s1 = r1;
            }

            let m = s1 * pf_over_div - s0 * sumx_over_div;
            // SAFETY: i < n <= out.len().
            unsafe { *out.get_unchecked_mut(i) = s0 * inv_pf + m * p_minus_mean_x };
        } else {
            s0 = f64::NAN;
            s1 = f64::NAN;
            // SAFETY: i < n <= out.len().
            unsafe { *out.get_unchecked_mut(i) = f64::NAN };
        }
    }
}
484
/// AVX-512 TSF entry point: routes periods up to 32 to the short variant and longer
/// periods to the long variant (both currently delegate to the scalar kernel).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn tsf_avx512(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    let use_short = period <= 32;
    // SAFETY: both variants forward to `tsf_scalar` and add no extra preconditions.
    unsafe {
        if use_short {
            tsf_avx512_short(data, period, first_valid, out)
        } else {
            tsf_avx512_long(data, period, first_valid, out)
        }
    }
}
496
/// AVX2 Time Series Forecast kernel.
///
/// Produces the same series as `tsf_scalar` (up to floating-point rounding). For each
/// index from `first_valid + period - 1`, it fits an OLS line to the trailing `period`
/// samples and writes its value at `x = period`. Window sums are slid in O(1) and
/// recomputed whenever the sums or the boundary samples are not finite. Four forecasts at
/// a time are evaluated in 256-bit lanes, and a scalar loop finishes the tail. Indices
/// before the first full window are not written.
///
/// # Panics
/// Panics if at least one output would be written and `out.len() < data.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
#[inline]
pub fn tsf_avx2(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    /// Returns `(Σ y_j, Σ j·y_j)` over `data[start..start + p]`, or `(NaN, NaN)` as soon as
    /// a NaN sample is found.
    ///
    /// # Safety
    /// Requires `start + p <= data.len()`.
    #[inline(always)]
    unsafe fn window_sums(data: &[f64], start: usize, p: usize) -> (f64, f64) {
        let mut r0 = 0.0f64;
        let mut r1 = 0.0f64;
        for j in 0..p {
            // SAFETY: start + j < start + p <= data.len().
            let v = unsafe { *data.get_unchecked(start + j) };
            if v.is_nan() {
                return (f64::NAN, f64::NAN);
            }
            r0 += v;
            r1 += (j as f64) * v;
        }
        (r0, r1)
    }

    /// Moves `(s0, s1)` from the window starting at `start` to the one starting at
    /// `start + 1`. This is O(1) when the sums and both boundary samples are finite;
    /// otherwise the new window is recomputed in full.
    ///
    /// # Safety
    /// Requires `start + p < data.len()`.
    #[inline(always)]
    unsafe fn slide(
        data: &[f64],
        start: usize,
        p: usize,
        pf: f64,
        s0: f64,
        s1: f64,
    ) -> (f64, f64) {
        // SAFETY: start < start + p < data.len().
        let (y_old, y_new) =
            unsafe { (*data.get_unchecked(start), *data.get_unchecked(start + p)) };
        if s0.is_finite() && s1.is_finite() && y_old.is_finite() && y_new.is_finite() {
            let ns0 = s0 + (y_new - y_old);
            (ns0, pf * y_new + s1 - ns0)
        } else {
            // SAFETY: (start + 1) + p <= data.len() because start + p < data.len().
            unsafe { window_sums(data, start + 1, p) }
        }
    }

    let n = data.len();
    if n == 0 {
        return;
    }

    let p = period;
    let pf = p as f64;

    let mut sum_x = 0.0f64;
    let mut sum_x2 = 0.0f64;
    for x in 0..p {
        let xf = x as f64;
        sum_x += xf;
        sum_x2 += xf * xf;
    }
    let divisor = pf * sum_x2 - sum_x * sum_x;

    // Scalar form of the lane math below: slope m, intercept b, value at x = p.
    let forecast = |s0: f64, s1: f64| -> f64 {
        let m = (pf * s1 - sum_x * s0) / divisor;
        let b = (s0 - m * sum_x) / pf;
        b + m * pf
    };

    // Invariant for the rest of the function: i == base + p - 1.
    let mut base = first_valid;
    let mut i = base + p - 1;
    if i >= n {
        return;
    }
    // Every store below targets an index < n, so this single check makes them all in bounds.
    assert!(
        out.len() >= n,
        "tsf_avx2: output length {} is shorter than input length {}",
        out.len(),
        n
    );

    unsafe {
        let pfv = _mm256_set1_pd(pf);
        let sumxv = _mm256_set1_pd(sum_x);
        let divv = _mm256_set1_pd(divisor);

        // SAFETY: base + p == i + 1 <= n.
        let (mut s0, mut s1) = window_sums(data, base, p);
        *out.get_unchecked_mut(i) = forecast(s0, s1);

        // Four forecasts (indices i + 1 ..= i + 4) per iteration.
        while i + 4 < n {
            let mut s0_buf = [0.0f64; 4];
            let mut s1_buf = [0.0f64; 4];
            for lane in 0..4 {
                // SAFETY: base + lane + p <= base + p + 3 == i + 4 < n.
                let (ns0, ns1) = slide(data, base + lane, p, pf, s0, s1);
                s0 = ns0;
                s1 = ns1;
                s0_buf[lane] = s0;
                s1_buf[lane] = s1;
            }

            let s0v = _mm256_loadu_pd(s0_buf.as_ptr());
            let s1v = _mm256_loadu_pd(s1_buf.as_ptr());
            let num = _mm256_sub_pd(_mm256_mul_pd(pfv, s1v), _mm256_mul_pd(sumxv, s0v));
            let mv = _mm256_div_pd(num, divv);
            let bv = _mm256_div_pd(_mm256_sub_pd(s0v, _mm256_mul_pd(mv, sumxv)), pfv);
            let outv = _mm256_add_pd(bv, _mm256_mul_pd(mv, pfv));
            // SAFETY: i + 4 < n <= out.len(), so out[i + 1 ..= i + 4] is in bounds.
            _mm256_storeu_pd(out.as_mut_ptr().add(i + 1), outv);

            base += 4;
            i += 4;
        }

        // Scalar tail. `slide` recomputes the *new* window (starting at base + 1), so the
        // forecast written at i + 1 always belongs to the window ending at i + 1.
        while i + 1 < n {
            // SAFETY: base + p == i + 1 < n.
            let (ns0, ns1) = slide(data, base, p, pf, s0, s1);
            s0 = ns0;
            s1 = ns1;
            base += 1;
            i += 1;
            *out.get_unchecked_mut(i) = forecast(s0, s1);
        }
    }
}
734
/// AVX-512 path for short periods (`period <= 32`, as routed by `tsf_avx512`).
///
/// Currently delegates to [`tsf_scalar`], so its output is identical to the scalar kernel.
///
/// # Safety
/// The body performs no unsafe operation itself and adds no preconditions beyond
/// `tsf_scalar`'s. NOTE(review): the `unsafe` qualifier appears to be kept for signature
/// parity with a future SIMD implementation — confirm.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn tsf_avx512_short(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    tsf_scalar(data, period, first_valid, out)
}
740
/// AVX-512 path for long periods (`period > 32`, as routed by `tsf_avx512`).
///
/// Currently delegates to [`tsf_scalar`], so its output is identical to the scalar kernel.
///
/// # Safety
/// The body performs no unsafe operation itself and adds no preconditions beyond
/// `tsf_scalar`'s. NOTE(review): the `unsafe` qualifier appears to be kept for signature
/// parity with a future SIMD implementation — confirm.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn tsf_avx512_long(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    tsf_scalar(data, period, first_valid, out)
}
746
/// Incremental (one sample at a time) Time Series Forecast.
#[derive(Debug, Clone)]
pub struct TsfStream {
    /// Window length.
    period: usize,
    /// Ring buffer holding the last `period` samples.
    buffer: Vec<f64>,
    /// Next slot to overwrite; once the ring is full this is also the oldest sample.
    head: usize,
    /// Set once the first `period` samples have been received.
    filled: bool,

    // Regression constants for abscissae x = 0..period, computed once in `try_new`.
    // `update` reads `pf`, `inv_pf`, `pf_over_div`, `sumx_over_div` and `p_minus_mean_x`;
    // `sum_x`, `sum_x_sqr`, `divisor` and `inv_divisor` are stored but not read by `update`.
    pf: f64,
    sum_x: f64,
    sum_x_sqr: f64,
    divisor: f64,
    inv_pf: f64,
    inv_divisor: f64,
    pf_over_div: f64,
    sumx_over_div: f64,
    p_minus_mean_x: f64,

    /// Σy over the current window (NaN while the window contains a NaN).
    s0: f64,
    /// Σ j·y over the current window with j = 0 for the oldest sample (NaN while dirty).
    s1: f64,

    /// Number of NaN samples currently in the window.
    nan_count: usize,
}
769
770impl TsfStream {
771 pub fn try_new(params: TsfParams) -> Result<Self, TsfError> {
772 let period = params.period.unwrap_or(14);
773 if period < 2 {
774 return Err(TsfError::PeriodTooSmall { period });
775 }
776
777 let pf = period as f64;
778 let mut sum_x = 0.0f64;
779 let mut sum_x_sqr = 0.0f64;
780 for x in 0..period {
781 let xf = x as f64;
782 sum_x += xf;
783 sum_x_sqr += xf * xf;
784 }
785 let divisor = pf * sum_x_sqr - (sum_x * sum_x);
786 let inv_pf = 1.0 / pf;
787 let inv_divisor = 1.0 / divisor;
788 let pf_over_div = pf * inv_divisor;
789 let sumx_over_div = sum_x * inv_divisor;
790 let p_minus_mean_x = pf - sum_x * inv_pf;
791
792 Ok(Self {
793 period,
794 buffer: vec![f64::NAN; period],
795 head: 0,
796 filled: false,
797 pf,
798 sum_x,
799 sum_x_sqr,
800 divisor,
801 inv_pf,
802 inv_divisor,
803 pf_over_div,
804 sumx_over_div,
805 p_minus_mean_x,
806 s0: 0.0,
807 s1: 0.0,
808 nan_count: 0,
809 })
810 }
811
812 #[inline(always)]
813 pub fn update(&mut self, value: f64) -> Option<f64> {
814 if !self.filled {
815 self.buffer[self.head] = value;
816 self.advance_head();
817 if self.head == 0 {
818 self.filled = true;
819
820 let (s0, s1, nan_count) = self.recompute_from_ring_checked();
821 self.s0 = s0;
822 self.s1 = s1;
823 self.nan_count = nan_count;
824
825 if self.nan_count > 0 || !self.s0.is_finite() || !self.s1.is_finite() {
826 return Some(f64::NAN);
827 }
828
829 let m = self.s1 * self.pf_over_div - self.s0 * self.sumx_over_div;
830 return Some(self.s0 * self.inv_pf + m * self.p_minus_mean_x);
831 }
832 return None;
833 }
834
835 let y_old = self.buffer[self.head];
836 let y_new = value;
837 self.buffer[self.head] = y_new;
838
839 let prev_nan_count = self.nan_count;
840 if y_old.is_nan() {
841 self.nan_count = self.nan_count.saturating_sub(1);
842 }
843 if y_new.is_nan() {
844 self.nan_count = self.nan_count.saturating_add(1);
845 }
846
847 let out = if self.nan_count == 0 {
848 if prev_nan_count == 0
849 && self.s0.is_finite()
850 && self.s1.is_finite()
851 && y_old.is_finite()
852 && y_new.is_finite()
853 {
854 let new_s0 = self.s0 + (y_new - y_old);
855
856 let new_s1 = self.pf * y_new + self.s1 - new_s0;
857 self.s0 = new_s0;
858 self.s1 = new_s1;
859 } else {
860 let (s0, s1) = self.recompute_from_ring_clean();
861 self.s0 = s0;
862 self.s1 = s1;
863 }
864
865 let m = self.s1 * self.pf_over_div - self.s0 * self.sumx_over_div;
866 self.s0 * self.inv_pf + m * self.p_minus_mean_x
867 } else {
868 self.s0 = f64::NAN;
869 self.s1 = f64::NAN;
870 f64::NAN
871 };
872
873 self.advance_head();
874 Some(out)
875 }
876
877 #[inline(always)]
878 fn advance_head(&mut self) {
879 self.head += 1;
880 if self.head == self.period {
881 self.head = 0;
882 }
883 }
884
885 #[inline(always)]
886 fn recompute_from_ring_clean(&self) -> (f64, f64) {
887 let mut s0 = 0.0;
888 let mut s1 = 0.0;
889 let mut idx = self.head;
890 for j in 0..self.period {
891 let v = self.buffer[idx];
892
893 s0 += v;
894 s1 += (j as f64) * v;
895 idx += 1;
896 if idx == self.period {
897 idx = 0;
898 }
899 }
900 (s0, s1)
901 }
902
903 #[inline(always)]
904 fn recompute_from_ring_checked(&self) -> (f64, f64, usize) {
905 let mut s0 = 0.0;
906 let mut s1 = 0.0;
907 let mut cnt = 0usize;
908 let mut idx = self.head;
909 for j in 0..self.period {
910 let v = self.buffer[idx];
911 if v.is_nan() {
912 cnt += 1;
913 } else {
914 s0 += v;
915 s1 += (j as f64) * v;
916 }
917 idx += 1;
918 if idx == self.period {
919 idx = 0;
920 }
921 }
922 if cnt > 0 {
923 (f64::NAN, f64::NAN, cnt)
924 } else {
925 (s0, s1, 0)
926 }
927 }
928}
929
/// Period sweep for batch TSF.
#[derive(Clone, Debug)]
pub struct TsfBatchRange {
    /// `(start, end, step)`, inclusive of `end`. A `step` of 0 or `start == end` yields the
    /// single period `start`; `start > end` sweeps downward.
    pub period: (usize, usize, usize),
}
934
935impl Default for TsfBatchRange {
936 fn default() -> Self {
937 Self {
938 period: (14, 263, 1),
939 }
940 }
941}
942
/// Fluent builder for batch (parameter-sweep) TSF.
#[derive(Clone, Debug, Default)]
pub struct TsfBatchBuilder {
    /// Periods to evaluate, one output row each.
    range: TsfBatchRange,
    /// Kernel handed to `tsf_batch_with_kernel`; defaults to `Kernel::default()`
    /// (NOTE(review): presumably `Kernel::Auto` — confirm).
    kernel: Kernel,
}
948
949impl TsfBatchBuilder {
950 pub fn new() -> Self {
951 Self::default()
952 }
953 pub fn kernel(mut self, k: Kernel) -> Self {
954 self.kernel = k;
955 self
956 }
957 #[inline]
958 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
959 self.range.period = (start, end, step);
960 self
961 }
962 #[inline]
963 pub fn period_static(mut self, p: usize) -> Self {
964 self.range.period = (p, p, 0);
965 self
966 }
967 pub fn apply_slice(self, data: &[f64]) -> Result<TsfBatchOutput, TsfError> {
968 tsf_batch_with_kernel(data, &self.range, self.kernel)
969 }
970 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<TsfBatchOutput, TsfError> {
971 TsfBatchBuilder::new().kernel(k).apply_slice(data)
972 }
973 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<TsfBatchOutput, TsfError> {
974 let slice = source_type(c, src);
975 self.apply_slice(slice)
976 }
977 pub fn with_default_candles(c: &Candles) -> Result<TsfBatchOutput, TsfError> {
978 TsfBatchBuilder::new()
979 .kernel(Kernel::Auto)
980 .apply_candles(c, "close")
981 }
982}
983
984pub fn tsf_batch_with_kernel(
985 data: &[f64],
986 sweep: &TsfBatchRange,
987 k: Kernel,
988) -> Result<TsfBatchOutput, TsfError> {
989 let kernel = match k {
990 Kernel::Auto => Kernel::ScalarBatch,
991 other if other.is_batch() => other,
992 other => return Err(TsfError::InvalidKernelForBatch(other)),
993 };
994
995 let simd = match kernel {
996 Kernel::Avx512Batch => Kernel::Avx512,
997 Kernel::Avx2Batch => Kernel::Avx2,
998 Kernel::ScalarBatch => Kernel::Scalar,
999 _ => unreachable!(),
1000 };
1001 tsf_batch_par_slice(data, sweep, simd)
1002}
1003
/// Result of a batch TSF sweep.
#[derive(Clone, Debug)]
pub struct TsfBatchOutput {
    /// Row-major `rows × cols` matrix; row `r` holds the series for `combos[r]`.
    pub values: Vec<f64>,
    /// Parameter set of each row, in row order.
    pub combos: Vec<TsfParams>,
    /// Number of parameter combinations.
    pub rows: usize,
    /// Length of each row (the input length).
    pub cols: usize,
}
1011impl TsfBatchOutput {
1012 pub fn row_for_params(&self, p: &TsfParams) -> Option<usize> {
1013 self.combos
1014 .iter()
1015 .position(|c| c.period.unwrap_or(14) == p.period.unwrap_or(14))
1016 }
1017
1018 pub fn values_for(&self, p: &TsfParams) -> Option<&[f64]> {
1019 self.row_for_params(p).map(|row| {
1020 let start = row * self.cols;
1021 &self.values[start..start + self.cols]
1022 })
1023 }
1024}
1025
1026#[inline(always)]
1027fn expand_grid(r: &TsfBatchRange) -> Vec<TsfParams> {
1028 fn axis_usize((start, end, step): (usize, usize, usize)) -> Vec<usize> {
1029 if step == 0 || start == end {
1030 return vec![start];
1031 }
1032 let mut out = Vec::new();
1033 if start <= end {
1034 let mut v = start;
1035 while v <= end {
1036 out.push(v);
1037 match v.checked_add(step) {
1038 Some(next) => v = next,
1039 None => break,
1040 }
1041 }
1042 } else {
1043 let mut v = start;
1044 loop {
1045 if v < end {
1046 break;
1047 }
1048 out.push(v);
1049 if v == end {
1050 break;
1051 }
1052 match v.checked_sub(step) {
1053 Some(next) => v = next,
1054 None => break,
1055 }
1056 }
1057 }
1058 out
1059 }
1060
1061 let periods = axis_usize(r.period);
1062
1063 let mut out = Vec::with_capacity(periods.len());
1064 for &p in &periods {
1065 out.push(TsfParams { period: Some(p) });
1066 }
1067 out
1068}
1069
1070#[inline(always)]
1071pub fn tsf_batch_slice(
1072 data: &[f64],
1073 sweep: &TsfBatchRange,
1074 kern: Kernel,
1075) -> Result<TsfBatchOutput, TsfError> {
1076 tsf_batch_inner(data, sweep, kern, false)
1077}
1078
1079#[inline(always)]
1080pub fn tsf_batch_par_slice(
1081 data: &[f64],
1082 sweep: &TsfBatchRange,
1083 kern: Kernel,
1084) -> Result<TsfBatchOutput, TsfError> {
1085 tsf_batch_inner(data, sweep, kern, true)
1086}
1087
/// Shared implementation of the batch (parameter-sweep) TSF.
///
/// The steps are:
/// 1. Expand `sweep` into one `TsfParams` per row and validate every period against `data`.
/// 2. Allocate a row-major `rows × cols` matrix (`cols == data.len()`).
/// 3. NaN-fill each row's warm-up prefix via `init_matrix_prefixes`.
/// 4. Fill the remainder of each row with the selected row kernel.
///
/// When `parallel` is true, rows are processed with rayon's `par_chunks_mut`. On wasm32
/// they are always processed sequentially.
///
/// # Errors
/// `EmptyInputData`, `InvalidKernelForBatch`, `InvalidRange` (empty grid or size overflow),
/// `AllValuesNaN`, `PeriodTooSmall`, `InvalidPeriod`, or `NotEnoughValidData`.
#[inline(always)]
fn tsf_batch_inner(
    data: &[f64],
    sweep: &TsfBatchRange,
    kern: Kernel,
    parallel: bool,
) -> Result<TsfBatchOutput, TsfError> {
    if data.is_empty() {
        return Err(TsfError::EmptyInputData);
    }

    // Batch kernels collapse to their per-row counterpart; `Auto` resolves to scalar.
    let kern = match kern {
        Kernel::Auto | Kernel::Scalar | Kernel::ScalarBatch => Kernel::Scalar,
        Kernel::Avx2 | Kernel::Avx2Batch => Kernel::Avx2,
        Kernel::Avx512 | Kernel::Avx512Batch => Kernel::Avx512,
        other => return Err(TsfError::InvalidKernelForBatch(other)),
    };

    let combos = expand_grid(sweep);
    if combos.is_empty() {
        let (start, end, step) = sweep.period;
        return Err(TsfError::InvalidRange { start, end, step });
    }

    let first = data
        .iter()
        .position(|x| !x.is_nan())
        .ok_or(TsfError::AllValuesNaN)?;

    // Validate each period (expand_grid always sets `Some`) and track the largest one,
    // which bounds the data requirements for the whole sweep.
    let mut max_p = 0usize;
    for prm in &combos {
        let p = prm.period.unwrap();
        if p < 2 {
            return Err(TsfError::PeriodTooSmall { period: p });
        }
        if p > max_p {
            max_p = p;
        }
    }
    if max_p > data.len() {
        return Err(TsfError::InvalidPeriod {
            period: max_p,
            data_len: data.len(),
        });
    }
    if combos.len().checked_mul(max_p).is_none() {
        let (start, end, step) = sweep.period;
        return Err(TsfError::InvalidRange { start, end, step });
    }
    if data.len() - first < max_p {
        return Err(TsfError::NotEnoughValidData {
            needed: max_p,
            valid: data.len() - first,
        });
    }

    let rows = combos.len();
    let cols = data.len();
    // Reject sweeps whose output matrix size would overflow `usize`.
    if rows.checked_mul(cols).is_none() {
        let (start, end, step) = sweep.period;
        return Err(TsfError::InvalidRange { start, end, step });
    }
    // Per-row regression constants over x = 0..period, handed to the row kernels.
    let mut sum_xs = vec![0.0; rows];
    let mut divisors = vec![0.0; rows];

    for (row, prm) in combos.iter().enumerate() {
        let period = prm.period.unwrap();
        let sum_x = (0..period).map(|x| x as f64).sum::<f64>();
        let sum_x2 = (0..period).map(|x| (x as f64) * (x as f64)).sum::<f64>();
        let divisor = (period as f64 * sum_x2) - (sum_x * sum_x);
        sum_xs[row] = sum_x;
        divisors[row] = divisor;
    }

    // NOTE(review): presumably a `Vec<MaybeUninit<f64>>` of `rows * cols` elements — confirm.
    let mut buf_mu = make_uninit_matrix(rows, cols);

    // Row r's warm-up prefix ends just before its first full window.
    let warm: Vec<usize> = combos
        .iter()
        .map(|c| first + c.period.unwrap() - 1)
        .collect();
    init_matrix_prefixes(&mut buf_mu, cols, &warm);

    // Take manual ownership of the allocation. It is viewed as `&mut [f64]` while rows are
    // filled, then reassembled into a `Vec<f64>` below without copying. If a row kernel
    // panics, `ManuallyDrop` means the allocation is leaked rather than freed.
    // NOTE(review): soundness relies on each row kernel writing every element after its
    // warm-up prefix before `Vec::from_raw_parts` — confirm for each `tsf_row_*`.
    let mut buf_guard = core::mem::ManuallyDrop::new(buf_mu);
    let out: &mut [f64] = unsafe {
        core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
    };

    // Fills one output row with the resolved per-row kernel.
    let do_row = |row: usize, out_row: &mut [f64]| unsafe {
        let period = combos[row].period.unwrap();
        let sum_x = sum_xs[row];
        let divisor = divisors[row];

        match kern {
            Kernel::Scalar => tsf_row_scalar(data, first, period, sum_x, divisor, out_row),
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx2 => tsf_row_avx2(data, first, period, sum_x, divisor, out_row),
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx512 => tsf_row_avx512(data, first, period, sum_x, divisor, out_row),
            // Without SIMD support compiled in, SIMD requests use the scalar row kernel.
            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
            Kernel::Avx2 | Kernel::Avx512 => {
                tsf_row_scalar(data, first, period, sum_x, divisor, out_row)
            }
            _ => unreachable!(),
        }
    };

    if parallel {
        #[cfg(not(target_arch = "wasm32"))]
        {
            out.par_chunks_mut(cols)
                .enumerate()
                .for_each(|(row, slice)| do_row(row, slice));
        }

        // No rayon on wasm32: fall back to sequential rows.
        #[cfg(target_arch = "wasm32")]
        {
            for (row, slice) in out.chunks_mut(cols).enumerate() {
                do_row(row, slice);
            }
        }
    } else {
        for (row, slice) in out.chunks_mut(cols).enumerate() {
            do_row(row, slice);
        }
    }

    // Reassemble the now-initialized buffer as a Vec<f64> (same pointer, len, capacity).
    let values = unsafe {
        Vec::from_raw_parts(
            buf_guard.as_mut_ptr() as *mut f64,
            buf_guard.len(),
            buf_guard.capacity(),
        )
    };

    Ok(TsfBatchOutput {
        values,
        combos,
        rows,
        cols,
    })
}
1229
1230#[inline(always)]
1231fn tsf_batch_inner_into(
1232 data: &[f64],
1233 sweep: &TsfBatchRange,
1234 kern: Kernel,
1235 parallel: bool,
1236 output: &mut [f64],
1237) -> Result<Vec<TsfParams>, TsfError> {
1238 if data.is_empty() {
1239 return Err(TsfError::EmptyInputData);
1240 }
1241
1242 let kern = match kern {
1243 Kernel::Auto | Kernel::Scalar | Kernel::ScalarBatch => Kernel::Scalar,
1244 Kernel::Avx2 | Kernel::Avx2Batch => Kernel::Avx2,
1245 Kernel::Avx512 | Kernel::Avx512Batch => Kernel::Avx512,
1246 other => return Err(TsfError::InvalidKernelForBatch(other)),
1247 };
1248
1249 let combos = expand_grid(sweep);
1250 if combos.is_empty() {
1251 let (start, end, step) = sweep.period;
1252 return Err(TsfError::InvalidRange { start, end, step });
1253 }
1254
1255 let first = data
1256 .iter()
1257 .position(|x| !x.is_nan())
1258 .ok_or(TsfError::AllValuesNaN)?;
1259 let mut max_p = 0usize;
1260 for prm in &combos {
1261 let p = prm.period.unwrap();
1262 if p < 2 {
1263 return Err(TsfError::PeriodTooSmall { period: p });
1264 }
1265 if p > max_p {
1266 max_p = p;
1267 }
1268 }
1269 if max_p > data.len() {
1270 return Err(TsfError::InvalidPeriod {
1271 period: max_p,
1272 data_len: data.len(),
1273 });
1274 }
1275 if combos.len().checked_mul(max_p).is_none() {
1276 let (start, end, step) = sweep.period;
1277 return Err(TsfError::InvalidRange { start, end, step });
1278 }
1279 if data.len() - first < max_p {
1280 return Err(TsfError::NotEnoughValidData {
1281 needed: max_p,
1282 valid: data.len() - first,
1283 });
1284 }
1285
1286 let rows = combos.len();
1287 let cols = data.len();
1288 let total = rows.checked_mul(cols).ok_or_else(|| {
1289 let (start, end, step) = sweep.period;
1290 TsfError::InvalidRange { start, end, step }
1291 })?;
1292 if output.len() != total {
1293 return Err(TsfError::OutputLengthMismatch {
1294 expected: total,
1295 got: output.len(),
1296 });
1297 }
1298
1299 let mut sum_xs = vec![0.0; rows];
1300 let mut divisors = vec![0.0; rows];
1301 for (row, prm) in combos.iter().enumerate() {
1302 let p = prm.period.unwrap();
1303 let sum_x = (0..p).map(|x| x as f64).sum::<f64>();
1304 let sum_x2 = (0..p).map(|x| (x as f64) * (x as f64)).sum::<f64>();
1305 sum_xs[row] = sum_x;
1306 divisors[row] = (p as f64 * sum_x2) - (sum_x * sum_x);
1307 }
1308
1309 let warm: Vec<usize> = combos
1310 .iter()
1311 .map(|c| first + c.period.unwrap() - 1)
1312 .collect();
1313 let out_mu: &mut [MaybeUninit<f64>] = unsafe {
1314 core::slice::from_raw_parts_mut(output.as_mut_ptr() as *mut MaybeUninit<f64>, output.len())
1315 };
1316 init_matrix_prefixes(out_mu, cols, &warm);
1317
1318 let do_row = |row: usize, out_row: &mut [f64]| unsafe {
1319 let p = combos[row].period.unwrap();
1320 let sum_x = sum_xs[row];
1321 let div = divisors[row];
1322 match kern {
1323 Kernel::Scalar => tsf_row_scalar(data, first, p, sum_x, div, out_row),
1324 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1325 Kernel::Avx2 => tsf_row_avx2(data, first, p, sum_x, div, out_row),
1326 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1327 Kernel::Avx512 => tsf_row_avx512(data, first, p, sum_x, div, out_row),
1328 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
1329 Kernel::Avx2 | Kernel::Avx512 => tsf_row_scalar(data, first, p, sum_x, div, out_row),
1330 _ => unreachable!(),
1331 }
1332 };
1333
1334 if parallel {
1335 #[cfg(not(target_arch = "wasm32"))]
1336 {
1337 output
1338 .par_chunks_mut(cols)
1339 .enumerate()
1340 .for_each(|(r, sl)| do_row(r, sl));
1341 }
1342 #[cfg(target_arch = "wasm32")]
1343 {
1344 for (r, sl) in output.chunks_mut(cols).enumerate() {
1345 do_row(r, sl);
1346 }
1347 }
1348 } else {
1349 for (r, sl) in output.chunks_mut(cols).enumerate() {
1350 do_row(r, sl);
1351 }
1352 }
1353
1354 Ok(combos)
1355}
1356
/// Scalar row kernel for the Time Series Forecast (TSF).
///
/// For every index `i >= first + period - 1`, fits a least-squares line to the
/// `period` samples ending at `i` (abscissae x = 0..period) and writes the line's
/// value at x = `period` into `out[i]`. Cells of `out` before that index are not
/// touched.
///
/// * `sum_x`   – Σx over x = 0..period.
/// * `divisor` – the regression denominator `period·Σx² − (Σx)²`.
///
/// The window sums (Σy, Σj·y) are rolled forward in O(1) while the sums and both
/// boundary samples are finite. Otherwise the new window is rescanned. A window
/// containing NaN produces NaN.
///
/// # Safety
/// The body performs only bounds-checked indexing, so an `out` shorter than `data`
/// panics instead of corrupting memory. Intended for `period >= 2`.
#[inline(always)]
unsafe fn tsf_row_scalar(
    data: &[f64],
    first: usize,
    period: usize,
    sum_x: f64,
    divisor: f64,
    out: &mut [f64],
) {
    // (Σy, Σj·y) over `data[start..start + width]`, or `None` at the first NaN seen.
    fn window_sums(data: &[f64], start: usize, width: usize) -> Option<(f64, f64)> {
        let mut sum_y = 0.0f64;
        let mut sum_jy = 0.0f64;
        for (j, &y) in data[start..start + width].iter().enumerate() {
            if y.is_nan() {
                return None;
            }
            sum_y += y;
            sum_jy += (j as f64) * y;
        }
        Some((sum_y, sum_jy))
    }

    let len = data.len();
    if len == 0 {
        return;
    }

    let width = period as f64;
    // Evaluate the fitted line intercept + slope·x at x = period.
    let project = |sum_y: f64, sum_jy: f64| {
        let slope = (width * sum_jy - sum_x * sum_y) / divisor;
        let intercept = (sum_y - slope * sum_x) / width;
        intercept + slope * width
    };

    let mut start = first;
    let mut idx = start + period - 1;
    if idx >= len {
        return;
    }

    let mut sums = window_sums(data, start, period);
    out[idx] = sums.map_or(f64::NAN, |(sy, sjy)| project(sy, sjy));

    while idx + 1 < len {
        let leaving = data[start];
        let entering = data[start + period];
        start += 1;
        idx += 1;

        sums = match sums {
            Some((sy, sjy))
                if sy.is_finite()
                    && sjy.is_finite()
                    && leaving.is_finite()
                    && entering.is_finite() =>
            {
                // Shifting every abscissa down by one: Σj·y' = Σj·y + p·y_new − Σy'.
                let rolled_y = sy + (entering - leaving);
                Some((rolled_y, width * entering + sjy - rolled_y))
            }
            _ => window_sums(data, start, period),
        };
        out[idx] = sums.map_or(f64::NAN, |(sy, sjy)| project(sy, sjy));
    }
}
1446
/// AVX row kernel for the Time Series Forecast (selected for `Kernel::Avx2`).
///
/// Same contract as the scalar row kernel. It writes `out[first + period - 1..]`, each
/// value being the least-squares line of the trailing `period` samples (x = 0..period)
/// evaluated at x = `period`. Cells before that index are untouched, and windows
/// containing NaN yield NaN.
///
/// The window sums are advanced serially, because each depends on the previous one.
/// The four forecasts of each group are then finished with one 256-bit vector step,
/// and a scalar tail handles the remaining outputs.
///
/// # Safety
/// * The CPU must support the 256-bit AVX intrinsics used here.
/// * `period >= 1` (callers validate `period >= 2`).
/// * `out.len() >= data.len()`; outputs are written with unchecked indexing.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn tsf_row_avx2(
    data: &[f64],
    first: usize,
    period: usize,
    sum_x: f64,
    divisor: f64,
    out: &mut [f64],
) {
    // Fresh (Σy, Σj·y) over `data[start..start + p]`; `None` as soon as a NaN is found.
    #[inline(always)]
    unsafe fn window_sums(data: &[f64], start: usize, p: usize) -> Option<(f64, f64)> {
        let mut s0 = 0.0f64;
        let mut s1 = 0.0f64;
        for j in 0..p {
            let v = *data.get_unchecked(start + j);
            if v.is_nan() {
                return None;
            }
            s0 += v;
            s1 += (j as f64) * v;
        }
        Some((s0, s1))
    }

    // Moves the window from `start` to `start + 1` and returns `(s0, s1, valid)`.
    // While the sums and both boundary samples are finite, it uses the O(1) rolling update.
    // Otherwise it rescans the NEW window `data[start + 1..start + 1 + p]`.
    // When `valid` is false the sums are NaN.
    #[inline(always)]
    unsafe fn advance(
        data: &[f64],
        start: usize,
        p: usize,
        pf: f64,
        s0: f64,
        s1: f64,
    ) -> (f64, f64, bool) {
        let y_old = *data.get_unchecked(start);
        let y_new = *data.get_unchecked(start + p);
        if s0.is_finite() && s1.is_finite() && y_old.is_finite() && y_new.is_finite() {
            let ns0 = s0 + (y_new - y_old);
            let ns1 = pf * y_new + s1 - ns0;
            (ns0, ns1, true)
        } else {
            match window_sums(data, start + 1, p) {
                Some((r0, r1)) => (r0, r1, true),
                None => (f64::NAN, f64::NAN, false),
            }
        }
    }

    let n = data.len();
    if n == 0 {
        return;
    }
    let p = period;
    let pf = p as f64;
    // Evaluate the fitted line b + m·x at x = period.
    let forecast = |s0: f64, s1: f64| {
        let m = (pf * s1 - sum_x * s0) / divisor;
        let b = (s0 - m * sum_x) / pf;
        b + m * pf
    };

    let mut base = first;
    let mut i = base + p - 1;
    if i >= n {
        return;
    }

    let (mut s0, mut s1) = match window_sums(data, base, p) {
        Some((a, b)) => {
            *out.get_unchecked_mut(i) = forecast(a, b);
            (a, b)
        }
        None => {
            *out.get_unchecked_mut(i) = f64::NAN;
            (f64::NAN, f64::NAN)
        }
    };

    let pfv = _mm256_set1_pd(pf);
    let sumxv = _mm256_set1_pd(sum_x);
    let divv = _mm256_set1_pd(divisor);

    // Vector body: produces out[i + 1..=i + 4]. The furthest sample touched is
    // data[base + 3 + p] == data[i + 4], which is in bounds because i + 4 < n.
    // Invalid windows carry NaN sums, so their lanes come out NaN.
    while i + 4 < n {
        let mut s0_buf = [0.0f64; 4];
        let mut s1_buf = [0.0f64; 4];
        for k in 0..4 {
            let (a, b, _) = advance(data, base + k, p, pf, s0, s1);
            s0 = a;
            s1 = b;
            s0_buf[k] = a;
            s1_buf[k] = b;
        }

        let s0v = _mm256_loadu_pd(s0_buf.as_ptr());
        let s1v = _mm256_loadu_pd(s1_buf.as_ptr());
        let num = _mm256_sub_pd(_mm256_mul_pd(pfv, s1v), _mm256_mul_pd(sumxv, s0v));
        let mv = _mm256_div_pd(num, divv);
        let bv = _mm256_div_pd(_mm256_sub_pd(s0v, _mm256_mul_pd(mv, sumxv)), pfv);
        let outv = _mm256_add_pd(bv, _mm256_mul_pd(mv, pfv));
        _mm256_storeu_pd(out.as_mut_ptr().add(i + 1), outv);

        base += 4;
        i += 4;
    }

    // Scalar tail for the remaining (< 4) outputs. `advance` rescans the window that
    // starts at the new `base`, matching the scalar kernel; the old code rescanned
    // `base - 1`, the window before the step.
    while i + 1 < n {
        let (a, b, valid) = advance(data, base, p, pf, s0, s1);
        s0 = a;
        s1 = b;
        base += 1;
        i += 1;
        *out.get_unchecked_mut(i) = if valid { forecast(s0, s1) } else { f64::NAN };
    }
}
1677
/// AVX-512 row-kernel entry point for the Time Series Forecast.
///
/// Routes windows of up to 32 samples to `tsf_row_avx512_short` and longer ones to
/// `tsf_row_avx512_long`. At present both delegate to the scalar kernel.
///
/// # Safety
/// Same requirements as the row kernels it forwards to.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn tsf_row_avx512(
    data: &[f64],
    first: usize,
    period: usize,
    sum_x: f64,
    divisor: f64,
    out: &mut [f64],
) {
    match period {
        0..=32 => tsf_row_avx512_short(data, first, period, sum_x, divisor, out),
        _ => tsf_row_avx512_long(data, first, period, sum_x, divisor, out),
    }
}
1694
/// AVX-512 TSF row kernel for windows of at most 32 samples.
///
/// There is no dedicated vector implementation yet; this forwards to the scalar kernel.
///
/// # Safety
/// Same requirements as `tsf_row_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn tsf_row_avx512_short(
    data: &[f64],
    first: usize,
    period: usize,
    sum_x: f64,
    divisor: f64,
    out: &mut [f64],
) {
    tsf_row_scalar(data, first, period, sum_x, divisor, out);
}
1707
/// AVX-512 TSF row kernel for windows longer than 32 samples.
///
/// There is no dedicated vector implementation yet; this forwards to the scalar kernel.
///
/// # Safety
/// Same requirements as `tsf_row_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn tsf_row_avx512_long(
    data: &[f64],
    first: usize,
    period: usize,
    sum_x: f64,
    divisor: f64,
    out: &mut [f64],
) {
    tsf_row_scalar(data, first, period, sum_x, divisor, out);
}
1720
/// Time Series Forecast over a 1-D float64 array.
///
/// `kernel` is validated with `validate_kernel`. Computation runs with the GIL
/// released, and any TSF error is raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "tsf")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn tsf_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let series = data.as_slice()?;
    let kern = validate_kernel(kernel, false)?;
    let input = TsfInput::from_slice(
        series,
        TsfParams {
            period: Some(period),
        },
    );

    let values = py
        .allow_threads(|| tsf_with_kernel(&input, kern))
        .map(|output| output.values)
        .map_err(|err| PyValueError::new_err(err.to_string()))?;

    Ok(values.into_pyarray(py))
}
1746
/// Python wrapper around `TsfStream`: push values one at a time with `update`.
#[cfg(feature = "python")]
#[pyclass(name = "TsfStream")]
pub struct TsfStreamPy {
    stream: TsfStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl TsfStreamPy {
    /// Creates a stream for the given `period`; invalid parameters raise `ValueError`.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        TsfStream::try_new(TsfParams {
            period: Some(period),
        })
        .map(|stream| Self { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feeds one value and returns what `TsfStream::update` yields (`None` maps to Python `None`).
    fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
1770
/// Batch Time Series Forecast over a `(start, end, step)` period sweep.
///
/// Returns a dict with:
/// * `values` – a `(rows, len(data))` float64 array, one row per period;
/// * `periods` – the period of each row.
///
/// Computation runs with the GIL released. Invalid input or kernel choices raise
/// `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "tsf_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn tsf_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let slice_in = data.as_slice()?;

    let sweep = TsfBatchRange {
        period: period_range,
    };

    let combos = expand_grid(&sweep);
    let rows = combos.len();
    let cols = slice_in.len();

    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("tsf_batch_py: rows*cols overflow"))?;
    // SAFETY: the array is freshly allocated and not yet shared. On success every cell is
    // written by `tsf_batch_inner_into` (warm-up prefixes plus row kernels). On error the
    // array is dropped without being exposed.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let kern = validate_kernel(kernel, true)?;

    let combos = py
        .allow_threads(|| {
            // Only `Auto` needs resolving here. `tsf_batch_inner_into` maps the *Batch
            // variants onto their per-row kernels and reports unsupported kernels as an
            // error. A local mapping ending in `unreachable!()` would panic instead.
            let kernel = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k => k,
            };
            tsf_batch_inner_into(slice_in, &sweep, kernel, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;

    Ok(dict)
}
1830
/// WASM export: Time Series Forecast of `data` with the given `period`.
///
/// Returns a vector the same length as `data`; errors are reported as JS strings.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn tsf_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = TsfInput::from_slice(
        data,
        TsfParams {
            period: Some(period),
        },
    );
    let mut values = vec![0.0; data.len()];

    match tsf_into_slice(&mut values, &input, detect_best_kernel()) {
        Ok(_) => Ok(values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1846
/// WASM export: allocates room for `len` f64 values and hands the raw pointer to JS.
///
/// The contents are uninitialised. The buffer is leaked here; release it with `tsf_free`,
/// passing the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn tsf_alloc(len: usize) -> *mut f64 {
    let mut storage: Vec<f64> = Vec::with_capacity(len);
    let handle = storage.as_mut_ptr();
    // Give up Rust ownership so the allocation outlives this call.
    std::mem::forget(storage);
    handle
}
1855
/// WASM export: releases a buffer obtained from `tsf_alloc`.
///
/// `len` must equal the value passed to `tsf_alloc`, because it is used as the
/// allocation's capacity. Null pointers are ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn tsf_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` was produced by `Vec::<f64>::with_capacity(len)` in `tsf_alloc`.
    // The length is 0 because that memory may never have been initialised.
    // `Vec::from_raw_parts` requires the first `length` elements to be valid, and
    // dropping only needs the capacity to free the allocation (f64 has no destructor).
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
1865
/// WASM export: writes the Time Series Forecast of `len` values at `in_ptr` into `out_ptr`.
///
/// When both pointers are the same buffer, the result is computed into a scratch vector
/// first and then copied over the input.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn tsf_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    unsafe {
        let series = std::slice::from_raw_parts(in_ptr, len);
        let input = TsfInput::from_slice(
            series,
            TsfParams {
                period: Some(period),
            },
        );
        let kernel = detect_best_kernel();

        if in_ptr == out_ptr {
            let mut scratch = vec![0.0; len];
            tsf_into_slice(&mut scratch, &input, kernel)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            std::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&scratch);
        } else {
            let dst = std::slice::from_raw_parts_mut(out_ptr, len);
            tsf_into_slice(dst, &input, kernel)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
        Ok(())
    }
}
1899
/// Configuration object accepted by the JS `tsf_batch` export (deserialised from a JS value).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct TsfBatchConfig {
    /// `(start, end, step)` period sweep, forwarded unchanged as `TsfBatchRange::period`.
    pub period_range: (usize, usize, usize),
}
1905
/// Result object serialised back to JS by the `tsf_batch` export.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct TsfBatchJsOutput {
    /// Flattened `rows × cols` matrix, one contiguous row per entry of `combos`.
    pub values: Vec<f64>,
    /// Parameter set used for each row, in row order.
    pub combos: Vec<TsfParams>,
    /// Number of rows (parameter combinations).
    pub rows: usize,
    /// Number of columns (samples per row).
    pub cols: usize,
}
1914
/// WASM export `tsf_batch`: runs the TSF over a period sweep described by `config`
/// (a serialised `TsfBatchConfig`). It returns `{ values, combos, rows, cols }`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = tsf_batch)]
pub fn tsf_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let TsfBatchConfig { period_range } =
        serde_wasm_bindgen::from_value::<TsfBatchConfig>(config)
            .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;
    let sweep = TsfBatchRange {
        period: period_range,
    };

    // Batch kernels collapse onto their per-row SIMD level; anything else runs scalar.
    let simd = match detect_best_batch_kernel() {
        Kernel::Avx512Batch => Kernel::Avx512,
        Kernel::Avx2Batch => Kernel::Avx2,
        _ => Kernel::Scalar,
    };

    let TsfBatchOutput {
        values,
        combos,
        rows,
        cols,
    } = tsf_batch_inner(data, &sweep, simd, false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    serde_wasm_bindgen::to_value(&TsfBatchJsOutput {
        values,
        combos,
        rows,
        cols,
    })
    .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
1945
/// WASM export: batch TSF over periods `period_start..=period_end` (step `period_step`).
///
/// Reads `len` values from `in_ptr` and writes a row-major `rows × len` matrix to
/// `out_ptr`, which must have room for that many f64 values. Returns `rows`.
///
/// The input and output buffers may overlap, including the in-place case
/// `in_ptr == out_ptr`. In that case the input is copied first, so the output writes
/// cannot corrupt the data still being read.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn tsf_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to tsf_batch_into"));
    }

    let sweep = TsfBatchRange {
        period: (period_start, period_end, period_step),
    };

    let combos = expand_grid(&sweep);
    let rows = combos.len();
    let cols = len;

    if rows == 0 {
        return Err(JsValue::from_str("No valid parameter combinations"));
    }

    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("tsf_batch_into: rows*cols overflow"))?;

    // Address ranges of both buffers. `wrapping_add` keeps this a pure address computation.
    let in_start = in_ptr as usize;
    let in_end = in_ptr.wrapping_add(len) as usize;
    let out_start = out_ptr as usize;
    let out_end = out_ptr.wrapping_add(total) as usize;
    let overlaps = in_start < out_end && out_start < in_end;

    unsafe {
        // SAFETY: the caller guarantees `in_ptr` is readable for `len` values and
        // `out_ptr` writable for `total` values. If the ranges overlap, the input is
        // copied before the `&mut` output slice is created, so no shared and mutable
        // borrow of the same memory ever coexist.
        let input_copy: Vec<f64>;
        let data: &[f64] = if overlaps {
            input_copy = std::slice::from_raw_parts(in_ptr, len).to_vec();
            &input_copy
        } else {
            std::slice::from_raw_parts(in_ptr, len)
        };
        let out = std::slice::from_raw_parts_mut(out_ptr, total);

        // `tsf_batch_inner_into` maps *Batch kernels to their per-row variants itself and
        // reports unsupported kernels as an error instead of panicking.
        tsf_batch_inner_into(data, &sweep, detect_best_batch_kernel(), true, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(rows)
}
1995
1996#[cfg(test)]
1997mod tests {
1998 use super::*;
1999 use crate::skip_if_unsupported;
2000 use crate::utilities::data_loader::read_candles_from_csv;
2001
2002 fn check_tsf_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2003 skip_if_unsupported!(kernel, test_name);
2004 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2005 let candles = read_candles_from_csv(file_path)?;
2006 let default_params = TsfParams { period: None };
2007 let input = TsfInput::from_candles(&candles, "close", default_params);
2008 let output = tsf_with_kernel(&input, kernel)?;
2009 assert_eq!(output.values.len(), candles.close.len());
2010 Ok(())
2011 }
2012
2013 fn check_tsf_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2014 skip_if_unsupported!(kernel, test_name);
2015 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2016 let candles = read_candles_from_csv(file_path)?;
2017 let input = TsfInput::from_candles(&candles, "close", TsfParams::default());
2018 let result = tsf_with_kernel(&input, kernel)?;
2019 let expected_last_five = [
2020 58846.945054945056,
2021 58818.83516483516,
2022 58854.57142857143,
2023 59083.846153846156,
2024 58962.25274725275,
2025 ];
2026 let start = result.values.len().saturating_sub(5);
2027 for (i, &val) in result.values[start..].iter().enumerate() {
2028 let diff = (val - expected_last_five[i]).abs();
2029 assert!(
2030 diff < 1e-1,
2031 "[{}] TSF {:?} mismatch at idx {}: got {}, expected {}",
2032 test_name,
2033 kernel,
2034 i,
2035 val,
2036 expected_last_five[i]
2037 );
2038 }
2039 Ok(())
2040 }
2041
2042 fn check_tsf_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2043 skip_if_unsupported!(kernel, test_name);
2044 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2045 let candles = read_candles_from_csv(file_path)?;
2046 let input = TsfInput::with_default_candles(&candles);
2047 match input.data {
2048 TsfData::Candles { source, .. } => assert_eq!(source, "close"),
2049 _ => panic!("Expected TsfData::Candles"),
2050 }
2051 let output = tsf_with_kernel(&input, kernel)?;
2052 assert_eq!(output.values.len(), candles.close.len());
2053 Ok(())
2054 }
2055
2056 fn check_tsf_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2057 skip_if_unsupported!(kernel, test_name);
2058 let input_data = [10.0, 20.0, 30.0];
2059 let params = TsfParams { period: Some(0) };
2060 let input = TsfInput::from_slice(&input_data, params);
2061 let res = tsf_with_kernel(&input, kernel);
2062 assert!(
2063 res.is_err(),
2064 "[{}] TSF should fail with zero period",
2065 test_name
2066 );
2067 if let Err(e) = res {
2068 assert!(matches!(e, TsfError::PeriodTooSmall { period: 0 }));
2069 }
2070 Ok(())
2071 }
2072
2073 fn check_tsf_period_one(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2074 skip_if_unsupported!(kernel, test_name);
2075 let input_data = [10.0, 20.0, 30.0, 40.0, 50.0];
2076 let params = TsfParams { period: Some(1) };
2077 let input = TsfInput::from_slice(&input_data, params);
2078 let res = tsf_with_kernel(&input, kernel);
2079 assert!(
2080 res.is_err(),
2081 "[{}] TSF should fail with period=1",
2082 test_name
2083 );
2084 if let Err(e) = res {
2085 assert!(matches!(e, TsfError::PeriodTooSmall { period: 1 }));
2086 }
2087 Ok(())
2088 }
2089
2090 fn check_tsf_mismatched_output_len(
2091 test_name: &str,
2092 kernel: Kernel,
2093 ) -> Result<(), Box<dyn Error>> {
2094 skip_if_unsupported!(kernel, test_name);
2095 let input_data = [10.0, 20.0, 30.0, 40.0, 50.0];
2096 let params = TsfParams { period: Some(3) };
2097 let input = TsfInput::from_slice(&input_data, params);
2098
2099 let mut dst = vec![0.0; 10];
2100 let res = tsf_into_slice(&mut dst, &input, kernel);
2101
2102 assert!(
2103 res.is_err(),
2104 "[{}] TSF should fail with mismatched output length",
2105 test_name
2106 );
2107 if let Err(e) = res {
2108 assert!(matches!(
2109 e,
2110 TsfError::OutputLengthMismatch {
2111 expected: 5,
2112 got: 10
2113 }
2114 ));
2115 }
2116 Ok(())
2117 }
2118
2119 fn check_tsf_period_exceeds_length(
2120 test_name: &str,
2121 kernel: Kernel,
2122 ) -> Result<(), Box<dyn Error>> {
2123 skip_if_unsupported!(kernel, test_name);
2124 let data_small = [10.0, 20.0, 30.0];
2125 let params = TsfParams { period: Some(10) };
2126 let input = TsfInput::from_slice(&data_small, params);
2127 let res = tsf_with_kernel(&input, kernel);
2128 assert!(
2129 res.is_err(),
2130 "[{}] TSF should fail with period exceeding length",
2131 test_name
2132 );
2133 Ok(())
2134 }
2135
2136 fn check_tsf_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2137 skip_if_unsupported!(kernel, test_name);
2138 let single_point = [42.0];
2139 let params = TsfParams { period: Some(9) };
2140 let input = TsfInput::from_slice(&single_point, params);
2141 let res = tsf_with_kernel(&input, kernel);
2142 assert!(
2143 res.is_err(),
2144 "[{}] TSF should fail with insufficient data",
2145 test_name
2146 );
2147 Ok(())
2148 }
2149
2150 fn check_tsf_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2151 skip_if_unsupported!(kernel, test_name);
2152 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2153 let candles = read_candles_from_csv(file_path)?;
2154 let first_params = TsfParams { period: Some(14) };
2155 let first_input = TsfInput::from_candles(&candles, "close", first_params);
2156 let first_result = tsf_with_kernel(&first_input, kernel)?;
2157 let second_params = TsfParams { period: Some(14) };
2158 let second_input = TsfInput::from_slice(&first_result.values, second_params);
2159 let second_result = tsf_with_kernel(&second_input, kernel)?;
2160 assert_eq!(second_result.values.len(), first_result.values.len());
2161 Ok(())
2162 }
2163
2164 fn check_tsf_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2165 skip_if_unsupported!(kernel, test_name);
2166 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2167 let candles = read_candles_from_csv(file_path)?;
2168 let input = TsfInput::from_candles(&candles, "close", TsfParams { period: Some(14) });
2169 let res = tsf_with_kernel(&input, kernel)?;
2170 assert_eq!(res.values.len(), candles.close.len());
2171 if res.values.len() > 240 {
2172 for (i, &val) in res.values[240..].iter().enumerate() {
2173 assert!(
2174 !val.is_nan(),
2175 "[{}] Found unexpected NaN at out-index {}",
2176 test_name,
2177 240 + i
2178 );
2179 }
2180 }
2181 Ok(())
2182 }
2183
2184 fn check_tsf_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2185 skip_if_unsupported!(kernel, test_name);
2186
2187 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2188 let candles = read_candles_from_csv(file_path)?;
2189
2190 let period = 14;
2191 let input = TsfInput::from_candles(
2192 &candles,
2193 "close",
2194 TsfParams {
2195 period: Some(period),
2196 },
2197 );
2198 let batch_output = tsf_with_kernel(&input, kernel)?.values;
2199
2200 let mut stream = TsfStream::try_new(TsfParams {
2201 period: Some(period),
2202 })?;
2203
2204 let mut stream_values = Vec::with_capacity(candles.close.len());
2205 for &price in &candles.close {
2206 match stream.update(price) {
2207 Some(tsf_val) => stream_values.push(tsf_val),
2208 None => stream_values.push(f64::NAN),
2209 }
2210 }
2211
2212 assert_eq!(batch_output.len(), stream_values.len());
2213 for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
2214 if b.is_nan() && s.is_nan() {
2215 continue;
2216 }
2217 let diff = (b - s).abs();
2218 assert!(
2219 diff < 1e-9,
2220 "[{}] TSF streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
2221 test_name,
2222 i,
2223 b,
2224 s,
2225 diff
2226 );
2227 }
2228 Ok(())
2229 }
2230
2231 #[cfg(debug_assertions)]
2232 fn check_tsf_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2233 skip_if_unsupported!(kernel, test_name);
2234
2235 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2236 let candles = read_candles_from_csv(file_path)?;
2237
2238 let test_params = vec![
2239 TsfParams::default(),
2240 TsfParams { period: Some(1) },
2241 TsfParams { period: Some(2) },
2242 TsfParams { period: Some(5) },
2243 TsfParams { period: Some(7) },
2244 TsfParams { period: Some(10) },
2245 TsfParams { period: Some(20) },
2246 TsfParams { period: Some(30) },
2247 TsfParams { period: Some(50) },
2248 TsfParams { period: Some(100) },
2249 TsfParams { period: Some(200) },
2250 ];
2251
2252 for (param_idx, params) in test_params.iter().enumerate() {
2253 let input = TsfInput::from_candles(&candles, "close", params.clone());
2254 let output = tsf_with_kernel(&input, kernel)?;
2255
2256 for (i, &val) in output.values.iter().enumerate() {
2257 if val.is_nan() {
2258 continue;
2259 }
2260
2261 let bits = val.to_bits();
2262
2263 if bits == 0x11111111_11111111 {
2264 panic!(
2265 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
2266 with params: period={} (param set {})",
2267 test_name,
2268 val,
2269 bits,
2270 i,
2271 params.period.unwrap_or(14),
2272 param_idx
2273 );
2274 }
2275
2276 if bits == 0x22222222_22222222 {
2277 panic!(
2278 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
2279 with params: period={} (param set {})",
2280 test_name,
2281 val,
2282 bits,
2283 i,
2284 params.period.unwrap_or(14),
2285 param_idx
2286 );
2287 }
2288
2289 if bits == 0x33333333_33333333 {
2290 panic!(
2291 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
2292 with params: period={} (param set {})",
2293 test_name,
2294 val,
2295 bits,
2296 i,
2297 params.period.unwrap_or(14),
2298 param_idx
2299 );
2300 }
2301 }
2302 }
2303
2304 Ok(())
2305 }
2306
2307 #[cfg(not(debug_assertions))]
2308 fn check_tsf_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2309 Ok(())
2310 }
2311
2312 #[cfg(feature = "proptest")]
2313 #[allow(clippy::float_cmp)]
2314 fn check_tsf_property(
2315 test_name: &str,
2316 kernel: Kernel,
2317 ) -> Result<(), Box<dyn std::error::Error>> {
2318 use proptest::prelude::*;
2319 skip_if_unsupported!(kernel, test_name);
2320
2321 let strat = (2usize..=30).prop_flat_map(|period| {
2322 (
2323 prop::collection::vec(
2324 (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
2325 period..400,
2326 ),
2327 Just(period),
2328 )
2329 });
2330
2331 proptest::test_runner::TestRunner::default()
2332 .run(&strat, |(data, period)| {
2333 let params = TsfParams {
2334 period: Some(period),
2335 };
2336 let input = TsfInput::from_slice(&data, params);
2337
2338 let TsfOutput { values: out } = tsf_with_kernel(&input, kernel).unwrap();
2339 let TsfOutput { values: ref_out } =
2340 tsf_with_kernel(&input, Kernel::Scalar).unwrap();
2341
2342 let first = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
2343 let warmup_end = first + period - 1;
2344
2345 for i in 0..warmup_end {
2346 prop_assert!(
2347 out[i].is_nan(),
2348 "[{}] Expected NaN during warmup at index {}, got {}",
2349 test_name,
2350 i,
2351 out[i]
2352 );
2353 }
2354
2355 for i in warmup_end..data.len() {
2356 let y = out[i];
2357 let r = ref_out[i];
2358
2359 if !y.is_finite() || !r.is_finite() {
2360 prop_assert!(
2361 y.to_bits() == r.to_bits(),
2362 "[{}] finite/NaN mismatch at idx {}: {} vs {}",
2363 test_name,
2364 i,
2365 y,
2366 r
2367 );
2368 continue;
2369 }
2370
2371 let y_bits = y.to_bits();
2372 let r_bits = r.to_bits();
2373 let ulp_diff: u64 = y_bits.abs_diff(r_bits);
2374
2375 prop_assert!(
2376 (y - r).abs() <= 1e-9 || ulp_diff <= 4,
2377 "[{}] Kernel mismatch at idx {}: {} vs {} (ULP={})",
2378 test_name,
2379 i,
2380 y,
2381 r,
2382 ulp_diff
2383 );
2384 }
2385
2386 if data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10) && !data.is_empty() {
2387 for i in warmup_end..data.len() {
2388 if out[i].is_finite() {
2389 prop_assert!(
2390 (out[i] - data[0]).abs() <= 1e-6,
2391 "[{}] Constant data: TSF at {} = {}, expected {}",
2392 test_name,
2393 i,
2394 out[i],
2395 data[0]
2396 );
2397 }
2398 }
2399 }
2400
2401 let data_min = data[first..].iter().fold(f64::INFINITY, |a, &b| a.min(b));
2402 let data_max = data[first..]
2403 .iter()
2404 .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
2405 let data_range = data_max - data_min;
2406
2407 let bound_factor = 2.0;
2408 let lower_bound = data_min - bound_factor * data_range;
2409 let upper_bound = data_max + bound_factor * data_range;
2410
2411 for i in warmup_end..data.len() {
2412 if out[i].is_finite() && data_range > 1e-10 {
2413 prop_assert!(
2414 out[i] >= lower_bound && out[i] <= upper_bound,
2415 "[{}] TSF at {} = {} is outside bounds [{}, {}]",
2416 test_name,
2417 i,
2418 out[i],
2419 lower_bound,
2420 upper_bound
2421 );
2422 }
2423 }
2424
2425 let is_monotonic_inc = data.windows(2).all(|w| w[1] >= w[0] - 1e-10);
2426 let is_monotonic_dec = data.windows(2).all(|w| w[1] <= w[0] + 1e-10);
2427
2428 if is_monotonic_inc || is_monotonic_dec {
2429 let test_points =
2430 vec![warmup_end, (warmup_end + data.len()) / 2, data.len() - 1];
2431
2432 for &i in test_points.iter() {
2433 if i < data.len() && out[i].is_finite() {
2434 let window_end = i;
2435 let window_start = i.saturating_sub(period - 1);
2436
2437 if is_monotonic_inc {
2438 prop_assert!(
2439 out[i] >= data[window_end] - 1e-6,
2440 "[{}] Monotonic increasing: TSF at {} = {} < last value {}",
2441 test_name,
2442 i,
2443 out[i],
2444 data[window_end]
2445 );
2446 } else if is_monotonic_dec {
2447 prop_assert!(
2448 out[i] <= data[window_end] + 1e-6,
2449 "[{}] Monotonic decreasing: TSF at {} = {} > last value {}",
2450 test_name,
2451 i,
2452 out[i],
2453 data[window_end]
2454 );
2455 }
2456 }
2457 }
2458 }
2459
2460 if period == 2 && warmup_end < data.len() {
2461 let i = warmup_end;
2462 if out[i].is_finite() {
2463 let x0 = data[i - 1];
2464 let x1 = data[i];
2465
2466 let expected = 2.0 * x1 - x0;
2467 prop_assert!(
2468 (out[i] - expected).abs() <= 1e-6,
2469 "[{}] Period=2: TSF at {} = {}, expected {}",
2470 test_name,
2471 i,
2472 out[i],
2473 expected
2474 );
2475 }
2476 }
2477
2478 for i in warmup_end..data.len() {
2479 if data[i.saturating_sub(period - 1)..=i]
2480 .iter()
2481 .all(|x| x.is_finite())
2482 {
2483 prop_assert!(
2484 out[i].is_finite(),
2485 "[{}] TSF produced NaN/Inf at {} despite finite input window",
2486 test_name,
2487 i
2488 );
2489 }
2490 }
2491
2492 Ok(())
2493 })
2494 .unwrap();
2495
2496 Ok(())
2497 }
2498
2499 macro_rules! generate_all_tsf_tests {
2500 ($($test_fn:ident),*) => {
2501 paste::paste! {
2502 $(
2503 #[test]
2504 fn [<$test_fn _scalar_f64>]() {
2505 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
2506 }
2507 )*
2508 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2509 $(
2510 #[test]
2511 fn [<$test_fn _avx2_f64>]() {
2512 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
2513 }
2514 #[test]
2515 fn [<$test_fn _avx512_f64>]() {
2516 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
2517 }
2518 )*
2519 }
2520 }
2521 }
2522
2523 generate_all_tsf_tests!(
2524 check_tsf_partial_params,
2525 check_tsf_accuracy,
2526 check_tsf_default_candles,
2527 check_tsf_zero_period,
2528 check_tsf_period_one,
2529 check_tsf_mismatched_output_len,
2530 check_tsf_period_exceeds_length,
2531 check_tsf_very_small_dataset,
2532 check_tsf_reinput,
2533 check_tsf_nan_handling,
2534 check_tsf_streaming,
2535 check_tsf_no_poison
2536 );
2537
2538 #[cfg(feature = "proptest")]
2539 generate_all_tsf_tests!(check_tsf_property);
2540
2541 #[test]
2542 fn test_tsf_into_matches_api() -> Result<(), Box<dyn Error>> {
2543 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2544 let candles = read_candles_from_csv(file_path)?;
2545 let input = TsfInput::from_candles(&candles, "close", TsfParams::default());
2546
2547 let baseline = tsf(&input)?.values;
2548
2549 let mut out = vec![0.0; candles.close.len()];
2550 #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
2551 {
2552 tsf_into(&input, &mut out)?;
2553 }
2554 #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2555 {
2556 return Ok(());
2557 }
2558
2559 assert_eq!(baseline.len(), out.len());
2560
2561 fn eq_or_both_nan(a: f64, b: f64) -> bool {
2562 (a.is_nan() && b.is_nan()) || (a == b)
2563 }
2564 for i in 0..out.len() {
2565 assert!(
2566 eq_or_both_nan(baseline[i], out[i]),
2567 "Mismatch at index {i}: baseline={} out={}",
2568 baseline[i],
2569 out[i]
2570 );
2571 }
2572
2573 Ok(())
2574 }
2575
2576 fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2577 skip_if_unsupported!(kernel, test);
2578
2579 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2580 let c = read_candles_from_csv(file)?;
2581
2582 let output = TsfBatchBuilder::new()
2583 .kernel(kernel)
2584 .apply_candles(&c, "close")?;
2585
2586 let def = TsfParams::default();
2587 let row = output.values_for(&def).expect("default row missing");
2588
2589 assert_eq!(row.len(), c.close.len());
2590
2591 let expected = [
2592 58846.945054945056,
2593 58818.83516483516,
2594 58854.57142857143,
2595 59083.846153846156,
2596 58962.25274725275,
2597 ];
2598 let start = row.len() - 5;
2599 for (i, &v) in row[start..].iter().enumerate() {
2600 assert!(
2601 (v - expected[i]).abs() < 1e-1,
2602 "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
2603 );
2604 }
2605 Ok(())
2606 }
2607
2608 #[cfg(debug_assertions)]
2609 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2610 skip_if_unsupported!(kernel, test);
2611
2612 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2613 let c = read_candles_from_csv(file)?;
2614
2615 let test_configs = vec![
2616 (2, 10, 2),
2617 (5, 25, 5),
2618 (20, 50, 10),
2619 (2, 5, 1),
2620 (14, 14, 0),
2621 (30, 60, 15),
2622 (50, 100, 25),
2623 (100, 200, 50),
2624 ];
2625
2626 for (cfg_idx, &(p_start, p_end, p_step)) in test_configs.iter().enumerate() {
2627 let output = TsfBatchBuilder::new()
2628 .kernel(kernel)
2629 .period_range(p_start, p_end, p_step)
2630 .apply_candles(&c, "close")?;
2631
2632 for (idx, &val) in output.values.iter().enumerate() {
2633 if val.is_nan() {
2634 continue;
2635 }
2636
2637 let bits = val.to_bits();
2638 let row = idx / output.cols;
2639 let col = idx % output.cols;
2640 let combo = &output.combos[row];
2641
2642 if bits == 0x11111111_11111111 {
2643 panic!(
2644 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
2645 at row {} col {} (flat index {}) with params: period={}",
2646 test,
2647 cfg_idx,
2648 val,
2649 bits,
2650 row,
2651 col,
2652 idx,
2653 combo.period.unwrap_or(14)
2654 );
2655 }
2656
2657 if bits == 0x22222222_22222222 {
2658 panic!(
2659 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
2660 at row {} col {} (flat index {}) with params: period={}",
2661 test,
2662 cfg_idx,
2663 val,
2664 bits,
2665 row,
2666 col,
2667 idx,
2668 combo.period.unwrap_or(14)
2669 );
2670 }
2671
2672 if bits == 0x33333333_33333333 {
2673 panic!(
2674 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
2675 at row {} col {} (flat index {}) with params: period={}",
2676 test,
2677 cfg_idx,
2678 val,
2679 bits,
2680 row,
2681 col,
2682 idx,
2683 combo.period.unwrap_or(14)
2684 );
2685 }
2686 }
2687 }
2688
2689 Ok(())
2690 }
2691
2692 #[cfg(not(debug_assertions))]
2693 fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2694 Ok(())
2695 }
2696
2697 macro_rules! gen_batch_tests {
2698 ($fn_name:ident) => {
2699 paste::paste! {
2700 #[test] fn [<$fn_name _scalar>]() {
2701 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
2702 }
2703 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2704 #[test] fn [<$fn_name _avx2>]() {
2705 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
2706 }
2707 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2708 #[test] fn [<$fn_name _avx512>]() {
2709 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
2710 }
2711 #[test] fn [<$fn_name _auto_detect>]() {
2712 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
2713 }
2714 }
2715 };
2716 }
2717 gen_batch_tests!(check_batch_default_row);
2718 gen_batch_tests!(check_batch_no_poison);
2719}