1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::PyDict;
9#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
10use serde::{Deserialize, Serialize};
11#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
12use wasm_bindgen::prelude::*;
13
14#[cfg(all(feature = "python", feature = "cuda"))]
15use crate::cuda::moving_averages::DeviceArrayF32;
16use crate::indicators::ema::{EmaError, EmaParams, EmaStream};
17use crate::utilities::data_loader::{source_type, Candles};
18#[cfg(all(feature = "python", feature = "cuda"))]
19use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
20use crate::utilities::enums::Kernel;
21use crate::utilities::helpers::{alloc_with_nan_prefix, init_matrix_prefixes, make_uninit_matrix};
22#[cfg(feature = "python")]
23use crate::utilities::kernel_validation::validate_kernel;
24#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
25use core::arch::x86_64::*;
26#[cfg(all(feature = "python", feature = "cuda"))]
27use cust::context::Context;
28#[cfg(all(feature = "python", feature = "cuda"))]
29use cust::memory::DeviceBuffer;
30#[cfg(not(target_arch = "wasm32"))]
31use rayon::prelude::*;
32use std::convert::AsRef;
33#[cfg(all(feature = "python", feature = "cuda"))]
34use std::sync::Arc;
35use thiserror::Error;
36
37impl<'a> AsRef<[f64]> for TsiInput<'a> {
38 #[inline(always)]
39 fn as_ref(&self) -> &[f64] {
40 match &self.data {
41 TsiData::Slice(slice) => slice,
42 TsiData::Candles { candles, source } => source_type(candles, source),
43 }
44 }
45}
46
47#[derive(Debug, Clone)]
48pub enum TsiData<'a> {
49 Candles {
50 candles: &'a Candles,
51 source: &'a str,
52 },
53 Slice(&'a [f64]),
54}
55
/// Result of a single-series TSI computation.
#[derive(Clone, Debug)]
pub struct TsiOutput {
    /// One value per input sample, clamped to `[-100, 100]`. Warm-up slots are NaN.
    pub values: Vec<f64>,
}
60
/// Smoothing periods for the True Strength Index.
///
/// `None` means "use the default": 25 for `long_period`, 13 for `short_period`.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct TsiParams {
    /// Period of the EMA applied directly to the momentum (and to its absolute value).
    pub long_period: Option<usize>,
    /// Period of the second EMA, applied on top of the long EMA.
    pub short_period: Option<usize>,
}

impl Default for TsiParams {
    /// The classic TSI(25, 13) configuration.
    fn default() -> Self {
        let (long, short) = (25, 13);
        TsiParams {
            long_period: Some(long),
            short_period: Some(short),
        }
    }
}
79
80#[derive(Debug, Clone)]
81pub struct TsiInput<'a> {
82 pub data: TsiData<'a>,
83 pub params: TsiParams,
84}
85
86impl<'a> TsiInput<'a> {
87 #[inline]
88 pub fn from_candles(c: &'a Candles, s: &'a str, p: TsiParams) -> Self {
89 Self {
90 data: TsiData::Candles {
91 candles: c,
92 source: s,
93 },
94 params: p,
95 }
96 }
97 #[inline]
98 pub fn from_slice(sl: &'a [f64], p: TsiParams) -> Self {
99 Self {
100 data: TsiData::Slice(sl),
101 params: p,
102 }
103 }
104 #[inline]
105 pub fn with_default_candles(c: &'a Candles) -> Self {
106 Self::from_candles(c, "close", TsiParams::default())
107 }
108 #[inline]
109 pub fn get_long_period(&self) -> usize {
110 self.params.long_period.unwrap_or(25)
111 }
112 #[inline]
113 pub fn get_short_period(&self) -> usize {
114 self.params.short_period.unwrap_or(13)
115 }
116}
117
118#[derive(Copy, Clone, Debug)]
119pub struct TsiBuilder {
120 long_period: Option<usize>,
121 short_period: Option<usize>,
122 kernel: Kernel,
123}
124
125impl Default for TsiBuilder {
126 fn default() -> Self {
127 Self {
128 long_period: None,
129 short_period: None,
130 kernel: Kernel::Auto,
131 }
132 }
133}
134
135impl TsiBuilder {
136 #[inline(always)]
137 pub fn new() -> Self {
138 Self::default()
139 }
140 #[inline(always)]
141 pub fn long_period(mut self, n: usize) -> Self {
142 self.long_period = Some(n);
143 self
144 }
145 #[inline(always)]
146 pub fn short_period(mut self, n: usize) -> Self {
147 self.short_period = Some(n);
148 self
149 }
150 #[inline(always)]
151 pub fn kernel(mut self, k: Kernel) -> Self {
152 self.kernel = k;
153 self
154 }
155
156 #[inline(always)]
157 pub fn apply(self, c: &Candles) -> Result<TsiOutput, TsiError> {
158 let p = TsiParams {
159 long_period: self.long_period,
160 short_period: self.short_period,
161 };
162 let i = TsiInput::from_candles(c, "close", p);
163 tsi_with_kernel(&i, self.kernel)
164 }
165
166 #[inline(always)]
167 pub fn apply_slice(self, d: &[f64]) -> Result<TsiOutput, TsiError> {
168 let p = TsiParams {
169 long_period: self.long_period,
170 short_period: self.short_period,
171 };
172 let i = TsiInput::from_slice(d, p);
173 tsi_with_kernel(&i, self.kernel)
174 }
175
176 #[inline(always)]
177 pub fn into_stream(self) -> Result<TsiStream, TsiError> {
178 let p = TsiParams {
179 long_period: self.long_period,
180 short_period: self.short_period,
181 };
182 TsiStream::try_new(p)
183 }
184}
185
186#[derive(Debug, Error)]
187pub enum TsiError {
188 #[error("tsi: Input data slice is empty.")]
189 EmptyInputData,
190 #[error("tsi: All values are NaN.")]
191 AllValuesNaN,
192 #[error("tsi: Invalid period: long = {long_period}, short = {short_period}, data length = {data_len}")]
193 InvalidPeriod {
194 long_period: usize,
195 short_period: usize,
196 data_len: usize,
197 },
198 #[error("tsi: Not enough valid data: needed = {needed}, valid = {valid}")]
199 NotEnoughValidData { needed: usize, valid: usize },
200 #[error("tsi: Non-batch kernel passed to batch path: {0:?}")]
201 InvalidKernelForBatch(Kernel),
202 #[error("tsi: Output length mismatch: expected = {expected}, got = {got}")]
203 OutputLengthMismatch { expected: usize, got: usize },
204 #[error("tsi: Invalid range expansion: start = {start}, end = {end}, step = {step}")]
205 InvalidRange {
206 start: usize,
207 end: usize,
208 step: usize,
209 },
210 #[error("tsi: size overflow computing rows*cols")]
211 SizeOverflow,
212 #[error("tsi: EMA sub-error: {0}")]
213 EmaSubError(#[from] EmaError),
214}
215
216#[inline(always)]
217fn tsi_prepare<'a>(input: &'a TsiInput) -> Result<(&'a [f64], usize, usize, usize), TsiError> {
218 let data: &[f64] = input.as_ref();
219 let len = data.len();
220 if len == 0 {
221 return Err(TsiError::EmptyInputData);
222 }
223 let first = data
224 .iter()
225 .position(|x| !x.is_nan())
226 .ok_or(TsiError::AllValuesNaN)?;
227 let long = input.get_long_period();
228 let short = input.get_short_period();
229
230 if long == 0 || short == 0 || long > len || short > len {
231 return Err(TsiError::InvalidPeriod {
232 long_period: long,
233 short_period: short,
234 data_len: len,
235 });
236 }
237 let needed = 1 + long + short;
238 if len - first < needed {
239 return Err(TsiError::NotEnoughValidData {
240 needed,
241 valid: len - first,
242 });
243 }
244
245 Ok((data, long, short, first))
246}
247
248#[inline(always)]
249fn tsi_compute_into_streaming(
250 data: &[f64],
251 long: usize,
252 short: usize,
253 first: usize,
254 out: &mut [f64],
255) -> Result<(), TsiError> {
256 let mut ema_long_num = EmaStream::try_new(EmaParams { period: Some(long) })?;
257 let mut ema_short_num = EmaStream::try_new(EmaParams {
258 period: Some(short),
259 })?;
260 let mut ema_long_den = EmaStream::try_new(EmaParams { period: Some(long) })?;
261 let mut ema_short_den = EmaStream::try_new(EmaParams {
262 period: Some(short),
263 })?;
264
265 let warmup_end = first + long + short;
266 if out.len() != data.len() {
267 return Err(TsiError::OutputLengthMismatch {
268 expected: data.len(),
269 got: out.len(),
270 });
271 }
272
273 if first + 1 >= data.len() {
274 return Ok(());
275 }
276 let mut prev = data[first];
277
278 for i in (first + 1)..data.len() {
279 let cur = data[i];
280 if !cur.is_finite() {
281 out[i] = f64::NAN;
282 continue;
283 }
284
285 let m = cur - prev;
286 prev = cur;
287
288 let n1 = match ema_long_num.update(m) {
289 Some(v) => v,
290 None => {
291 out[i] = f64::NAN;
292 continue;
293 }
294 };
295 let n2 = match ema_short_num.update(n1) {
296 Some(v) => v,
297 None => {
298 out[i] = f64::NAN;
299 continue;
300 }
301 };
302
303 let d1 = match ema_long_den.update(m.abs()) {
304 Some(v) => v,
305 None => {
306 out[i] = f64::NAN;
307 continue;
308 }
309 };
310 let d2 = match ema_short_den.update(d1) {
311 Some(v) => v,
312 None => {
313 out[i] = f64::NAN;
314 continue;
315 }
316 };
317
318 if i >= warmup_end {
319 out[i] = if d2 == 0.0 {
320 f64::NAN
321 } else {
322 (100.0 * (n2 / d2)).clamp(-100.0, 100.0)
323 };
324 }
325 }
326
327 Ok(())
328}
329
330#[inline(always)]
331fn tsi_compute_into_inline(
332 data: &[f64],
333 long: usize,
334 short: usize,
335 first: usize,
336 out: &mut [f64],
337) -> Result<(), TsiError> {
338 let n = data.len();
339 if n == 0 || first >= n {
340 return Ok(());
341 }
342
343 let long_alpha = 2.0 / (long as f64 + 1.0);
344 let short_alpha = 2.0 / (short as f64 + 1.0);
345 let long_1minus = 1.0 - long_alpha;
346 let short_1minus = 1.0 - short_alpha;
347
348 let warmup_end = first + long + short;
349
350 let mut prev = data[first];
351
352 let mut i = first + 1;
353 while i < n {
354 let cur = data[i];
355 if cur.is_finite() {
356 break;
357 } else {
358 if i >= warmup_end {
359 out[i] = f64::NAN;
360 }
361 i += 1;
362 }
363 }
364 if i >= n {
365 return Ok(());
366 }
367
368 let mut momentum = data[i] - prev;
369 prev = data[i];
370
371 let mut ema_long_num = momentum;
372 let mut ema_short_num = momentum;
373 let mut ema_long_den = momentum.abs();
374 let mut ema_short_den = ema_long_den;
375
376 let mut idx = i + 1;
377 let end_warm = warmup_end.min(n);
378 while idx < end_warm {
379 let cur = data[idx];
380 if cur.is_finite() {
381 momentum = cur - prev;
382 prev = cur;
383
384 let am = momentum.abs();
385
386 ema_long_num = long_alpha * momentum + long_1minus * ema_long_num;
387 ema_short_num = short_alpha * ema_long_num + short_1minus * ema_short_num;
388
389 ema_long_den = long_alpha * am + long_1minus * ema_long_den;
390 ema_short_den = short_alpha * ema_long_den + short_1minus * ema_short_den;
391 }
392
393 idx += 1;
394 }
395
396 while idx < n {
397 let cur = data[idx];
398 if cur.is_finite() {
399 momentum = cur - prev;
400 prev = cur;
401
402 let am = momentum.abs();
403
404 ema_long_num = long_alpha * momentum + long_1minus * ema_long_num;
405 ema_short_num = short_alpha * ema_long_num + short_1minus * ema_short_num;
406
407 ema_long_den = long_alpha * am + long_1minus * ema_long_den;
408 ema_short_den = short_alpha * ema_long_den + short_1minus * ema_short_den;
409
410 let den = ema_short_den;
411 let val = if den == 0.0 {
412 f64::NAN
413 } else {
414 (100.0 * (ema_short_num / den)).clamp(-100.0, 100.0)
415 };
416 out[idx] = val;
417 } else {
418 out[idx] = f64::NAN;
419 }
420 idx += 1;
421 }
422
423 Ok(())
424}
425
426#[inline]
427pub fn tsi(input: &TsiInput) -> Result<TsiOutput, TsiError> {
428 tsi_with_kernel(input, Kernel::Auto)
429}
430
431pub fn tsi_with_kernel(input: &TsiInput, kernel: Kernel) -> Result<TsiOutput, TsiError> {
432 let (data, long, short, first) = tsi_prepare(input)?;
433 let warmup_end = first + long + short;
434 let mut out = alloc_with_nan_prefix(data.len(), warmup_end);
435
436 let resolved_kernel = match kernel {
437 Kernel::Auto => Kernel::Scalar,
438 k => k,
439 };
440
441 if resolved_kernel == Kernel::Scalar && long == 25 && short == 13 {
442 unsafe {
443 tsi_scalar_classic(data, long, short, first, &mut out)?;
444 }
445 } else {
446 tsi_compute_into_inline(data, long, short, first, &mut out)?;
447 }
448 Ok(TsiOutput { values: out })
449}
450
451#[inline]
452pub fn tsi_into_slice(dst: &mut [f64], input: &TsiInput, kern: Kernel) -> Result<(), TsiError> {
453 let (data, long, short, first) = tsi_prepare(input)?;
454 let warmup_end = first + long + short;
455
456 let end = warmup_end.min(dst.len());
457 for v in &mut dst[..end] {
458 *v = f64::NAN;
459 }
460
461 let resolved_kernel = match kern {
462 Kernel::Auto => Kernel::Scalar,
463 k => k,
464 };
465
466 if resolved_kernel == Kernel::Scalar && long == 25 && short == 13 {
467 unsafe {
468 tsi_scalar_classic(data, long, short, first, dst)?;
469 }
470 } else {
471 tsi_compute_into_inline(data, long, short, first, dst)?;
472 }
473 Ok(())
474}
475
476#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
477#[inline]
478pub fn tsi_into(input: &TsiInput, out: &mut [f64]) -> Result<(), TsiError> {
479 let data_len = input.as_ref().len();
480 if out.len() != data_len {
481 return Err(TsiError::OutputLengthMismatch {
482 expected: data_len,
483 got: out.len(),
484 });
485 }
486 tsi_into_slice(out, input, Kernel::Auto)
487}
488
489#[inline]
490pub unsafe fn tsi_scalar(
491 data: &[f64],
492 long: usize,
493 short: usize,
494 first: usize,
495) -> Result<TsiOutput, TsiError> {
496 let warmup_end = first + long + short;
497 let mut out = alloc_with_nan_prefix(data.len(), warmup_end);
498
499 if long == 25 && short == 13 {
500 tsi_scalar_classic(data, long, short, first, &mut out)?;
501 } else {
502 tsi_compute_into_inline(data, long, short, first, &mut out)?;
503 }
504 Ok(TsiOutput { values: out })
505}
506
507#[inline]
508pub unsafe fn tsi_scalar_classic(
509 data: &[f64],
510 long: usize,
511 short: usize,
512 first: usize,
513 out: &mut [f64],
514) -> Result<(), TsiError> {
515 let n = data.len();
516 let warmup_end = first + long + short;
517
518 if first + 1 >= n {
519 return Ok(());
520 }
521
522 let long_alpha = 2.0 / (long as f64 + 1.0);
523 let short_alpha = 2.0 / (short as f64 + 1.0);
524 let long_1minus = 1.0 - long_alpha;
525 let short_1minus = 1.0 - short_alpha;
526
527 let mut prev = data[first];
528
529 if first + 1 >= n || !data[first + 1].is_finite() {
530 return Ok(());
531 }
532
533 let first_momentum = data[first + 1] - prev;
534 prev = data[first + 1];
535
536 let mut ema_long_num = first_momentum;
537 let mut ema_short_num = first_momentum;
538 let mut ema_long_den = first_momentum.abs();
539 let mut ema_short_den = first_momentum.abs();
540
541 for i in (first + 2)..n {
542 let cur = data[i];
543 if !cur.is_finite() {
544 out[i] = f64::NAN;
545 continue;
546 }
547
548 let momentum = cur - prev;
549 prev = cur;
550
551 ema_long_num = long_alpha * momentum + long_1minus * ema_long_num;
552
553 ema_short_num = short_alpha * ema_long_num + short_1minus * ema_short_num;
554
555 ema_long_den = long_alpha * momentum.abs() + long_1minus * ema_long_den;
556
557 ema_short_den = short_alpha * ema_long_den + short_1minus * ema_short_den;
558
559 if i >= warmup_end {
560 out[i] = if ema_short_den == 0.0 {
561 f64::NAN
562 } else {
563 (100.0 * (ema_short_num / ema_short_den)).clamp(-100.0, 100.0)
564 };
565 }
566 }
567
568 Ok(())
569}
570
/// AVX2 entry point. There is no dedicated AVX2 kernel yet, so this forwards to `tsi_scalar`.
///
/// # Safety
/// Same contract as `tsi_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn tsi_avx2(
    data: &[f64],
    long: usize,
    short: usize,
    first: usize,
) -> Result<TsiOutput, TsiError> {
    let result = tsi_scalar(data, long, short, first);
    result
}
/// AVX-512 entry point. It routes to the "short" variant when both periods are at most 32,
/// and to the "long" variant otherwise. Both variants currently use the scalar kernel.
///
/// # Safety
/// Same contract as `tsi_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn tsi_avx512(
    data: &[f64],
    long: usize,
    short: usize,
    first: usize,
) -> Result<TsiOutput, TsiError> {
    let both_small = long.max(short) <= 32;
    match both_small {
        true => tsi_avx512_short(data, long, short, first),
        false => tsi_avx512_long(data, long, short, first),
    }
}
/// AVX-512 variant for periods of at most 32. It currently forwards to `tsi_scalar`.
///
/// # Safety
/// Same contract as `tsi_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn tsi_avx512_short(
    data: &[f64],
    long: usize,
    short: usize,
    first: usize,
) -> Result<TsiOutput, TsiError> {
    let result = tsi_scalar(data, long, short, first);
    result
}
/// AVX-512 variant for periods above 32. It currently forwards to `tsi_scalar`.
///
/// # Safety
/// Same contract as `tsi_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn tsi_avx512_long(
    data: &[f64],
    long: usize,
    short: usize,
    first: usize,
) -> Result<TsiOutput, TsiError> {
    let result = tsi_scalar(data, long, short, first);
    result
}
615
616#[derive(Debug, Clone)]
617pub struct TsiStream {
618 long: usize,
619 short: usize,
620
621 alpha_l: f64,
622 alpha_s: f64,
623
624 ema_long_num: f64,
625 ema_short_num: f64,
626 ema_long_den: f64,
627 ema_short_den: f64,
628
629 prev_price: f64,
630 have_prev: bool,
631
632 seeded: bool,
633
634 warmup_ctr: usize,
635 warmup_needed: usize,
636}
637impl TsiStream {
638 #[inline]
639 pub fn try_new(params: TsiParams) -> Result<Self, TsiError> {
640 let long = params.long_period.unwrap_or(25);
641 let short = params.short_period.unwrap_or(13);
642
643 if long == 0 || short == 0 {
644 return Err(TsiError::InvalidPeriod {
645 long_period: long,
646 short_period: short,
647 data_len: 0,
648 });
649 }
650
651 let alpha_l = 2.0 / (long as f64 + 1.0);
652 let alpha_s = 2.0 / (short as f64 + 1.0);
653
654 Ok(Self {
655 long,
656 short,
657 alpha_l,
658 alpha_s,
659 ema_long_num: 0.0,
660 ema_short_num: 0.0,
661 ema_long_den: 0.0,
662 ema_short_den: 0.0,
663 prev_price: f64::NAN,
664 have_prev: false,
665 seeded: false,
666 warmup_ctr: 0,
667 warmup_needed: long + short,
668 })
669 }
670
671 #[inline(always)]
672 pub fn update(&mut self, value: f64) -> Option<f64> {
673 if !value.is_finite() {
674 return None;
675 }
676
677 if !self.have_prev {
678 self.prev_price = value;
679 self.have_prev = true;
680 return None;
681 }
682
683 let m = value - self.prev_price;
684 self.prev_price = value;
685 let am = m.abs();
686
687 if !self.seeded {
688 self.ema_long_num = m;
689 self.ema_short_num = m;
690 self.ema_long_den = am;
691 self.ema_short_den = am;
692 self.seeded = true;
693 self.warmup_ctr = 1;
694 return None;
695 }
696
697 self.ema_long_num += self.alpha_l * (m - self.ema_long_num);
698 self.ema_short_num += self.alpha_s * (self.ema_long_num - self.ema_short_num);
699
700 self.ema_long_den += self.alpha_l * (am - self.ema_long_den);
701 self.ema_short_den += self.alpha_s * (self.ema_long_den - self.ema_short_den);
702
703 self.warmup_ctr += 1;
704
705 if self.warmup_ctr < self.warmup_needed {
706 return None;
707 }
708
709 let den = self.ema_short_den;
710 if den == 0.0 {
711 return Some(f64::NAN);
712 }
713
714 let tsi = 100.0 * (self.ema_short_num / den);
715 Some(tsi.clamp(-100.0, 100.0))
716 }
717}
718
/// Parameter grid for a batch TSI sweep.
///
/// Each axis is `(start, end, step)`. A zero step, or `start == end`, yields only `start`.
#[derive(Clone, Debug)]
pub struct TsiBatchRange {
    /// Long-period axis.
    pub long_period: (usize, usize, usize),
    /// Short-period axis.
    pub short_period: (usize, usize, usize),
}
impl Default for TsiBatchRange {
    /// Sweeps the long period over 25..=274 in steps of 1, with the short period fixed at 13.
    fn default() -> Self {
        let long_period = (25, 274, 1);
        let short_period = (13, 13, 0);
        TsiBatchRange {
            long_period,
            short_period,
        }
    }
}
732
733#[derive(Clone, Debug, Default)]
734pub struct TsiBatchBuilder {
735 range: TsiBatchRange,
736 kernel: Kernel,
737}
738impl TsiBatchBuilder {
739 pub fn new() -> Self {
740 Self::default()
741 }
742 pub fn kernel(mut self, k: Kernel) -> Self {
743 self.kernel = k;
744 self
745 }
746 pub fn long_range(mut self, start: usize, end: usize, step: usize) -> Self {
747 self.range.long_period = (start, end, step);
748 self
749 }
750 pub fn long_static(mut self, n: usize) -> Self {
751 self.range.long_period = (n, n, 1);
752 self
753 }
754 pub fn short_range(mut self, start: usize, end: usize, step: usize) -> Self {
755 self.range.short_period = (start, end, step);
756 self
757 }
758 pub fn short_static(mut self, n: usize) -> Self {
759 self.range.short_period = (n, n, 1);
760 self
761 }
762 pub fn apply_slice(self, data: &[f64]) -> Result<TsiBatchOutput, TsiError> {
763 tsi_batch_with_kernel(data, &self.range, self.kernel)
764 }
765 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<TsiBatchOutput, TsiError> {
766 TsiBatchBuilder::new().kernel(k).apply_slice(data)
767 }
768 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<TsiBatchOutput, TsiError> {
769 let slice = source_type(c, src);
770 self.apply_slice(slice)
771 }
772 pub fn with_default_candles(c: &Candles) -> Result<TsiBatchOutput, TsiError> {
773 TsiBatchBuilder::new()
774 .kernel(Kernel::Auto)
775 .apply_candles(c, "close")
776 }
777}
778
779pub fn tsi_batch_with_kernel(
780 data: &[f64],
781 sweep: &TsiBatchRange,
782 k: Kernel,
783) -> Result<TsiBatchOutput, TsiError> {
784 let kernel = match k {
785 Kernel::Auto => Kernel::ScalarBatch,
786 other if other.is_batch() => other,
787 _ => {
788 return Err(TsiError::InvalidKernelForBatch(k));
789 }
790 };
791
792 let simd = match kernel {
793 Kernel::Avx512Batch => Kernel::Avx512,
794 Kernel::Avx2Batch => Kernel::Avx2,
795 Kernel::ScalarBatch => Kernel::Scalar,
796 _ => unreachable!(),
797 };
798 tsi_batch_par_slice(data, sweep, simd)
799}
800
801#[derive(Clone, Debug)]
802pub struct TsiBatchOutput {
803 pub values: Vec<f64>,
804 pub combos: Vec<TsiParams>,
805 pub rows: usize,
806 pub cols: usize,
807}
808impl TsiBatchOutput {
809 pub fn row_for_params(&self, p: &TsiParams) -> Option<usize> {
810 self.combos.iter().position(|c| {
811 c.long_period.unwrap_or(25) == p.long_period.unwrap_or(25)
812 && c.short_period.unwrap_or(13) == p.short_period.unwrap_or(13)
813 })
814 }
815 pub fn values_for(&self, p: &TsiParams) -> Option<&[f64]> {
816 self.row_for_params(p).map(|row| {
817 let start = row * self.cols;
818 &self.values[start..start + self.cols]
819 })
820 }
821}
822
823#[inline(always)]
824fn expand_grid(r: &TsiBatchRange) -> Result<Vec<TsiParams>, TsiError> {
825 fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, TsiError> {
826 if step == 0 || start == end {
827 return Ok(vec![start]);
828 }
829 if start < end {
830 let vals: Vec<usize> = (start..=end).step_by(step).collect();
831 if vals.is_empty() {
832 return Err(TsiError::InvalidRange { start, end, step });
833 }
834 return Ok(vals);
835 }
836
837 let mut v = start;
838 let mut out = Vec::new();
839 loop {
840 out.push(v);
841 let guard = end.saturating_add(step);
842 if v <= guard {
843 break;
844 }
845 v = v.saturating_sub(step);
846 if v == 0 && step == 0 {
847 break;
848 }
849 }
850 if out.is_empty() {
851 return Err(TsiError::InvalidRange { start, end, step });
852 }
853 if *out.last().unwrap() != end {
854 out.push(end);
855 }
856 Ok(out)
857 }
858
859 let longs = axis_usize(r.long_period)?;
860 let shorts = axis_usize(r.short_period)?;
861 if longs.is_empty() || shorts.is_empty() {
862 return Err(TsiError::InvalidRange {
863 start: r.long_period.0,
864 end: r.long_period.1,
865 step: r.long_period.2,
866 });
867 }
868
869 let total = longs
870 .len()
871 .checked_mul(shorts.len())
872 .ok_or(TsiError::SizeOverflow)?;
873
874 let mut out = Vec::with_capacity(total);
875 for &l in &longs {
876 for &s in &shorts {
877 out.push(TsiParams {
878 long_period: Some(l),
879 short_period: Some(s),
880 });
881 }
882 }
883 Ok(out)
884}
885
886#[inline(always)]
887pub fn tsi_batch_slice(
888 data: &[f64],
889 sweep: &TsiBatchRange,
890 kern: Kernel,
891) -> Result<TsiBatchOutput, TsiError> {
892 tsi_batch_inner(data, sweep, kern, false)
893}
894
895#[inline(always)]
896pub fn tsi_batch_par_slice(
897 data: &[f64],
898 sweep: &TsiBatchRange,
899 kern: Kernel,
900) -> Result<TsiBatchOutput, TsiError> {
901 tsi_batch_inner(data, sweep, kern, true)
902}
903
904fn tsi_batch_inner(
905 data: &[f64],
906 sweep: &TsiBatchRange,
907 kern: Kernel,
908 parallel: bool,
909) -> Result<TsiBatchOutput, TsiError> {
910 let combos = expand_grid(sweep)?;
911
912 let first = data
913 .iter()
914 .position(|x| !x.is_nan())
915 .ok_or(TsiError::AllValuesNaN)?;
916 let max_long = combos.iter().map(|c| c.long_period.unwrap()).max().unwrap();
917 let max_short = combos
918 .iter()
919 .map(|c| c.short_period.unwrap())
920 .max()
921 .unwrap();
922 let max_needed = 1 + max_long + max_short;
923 if data.len() - first < max_needed {
924 return Err(TsiError::NotEnoughValidData {
925 needed: max_needed,
926 valid: data.len() - first,
927 });
928 }
929
930 let rows = combos.len();
931 let cols = data.len();
932 if rows.checked_mul(cols).is_none() {
933 return Err(TsiError::SizeOverflow);
934 }
935
936 let mut buf_mu = make_uninit_matrix(rows, cols);
937
938 let warmup_periods: Vec<usize> = combos
939 .iter()
940 .map(|c| first + 1 + c.long_period.unwrap() + c.short_period.unwrap() - 1)
941 .collect();
942 init_matrix_prefixes(&mut buf_mu, cols, &warmup_periods);
943
944 let mut buf_guard = core::mem::ManuallyDrop::new(buf_mu);
945 let values: &mut [f64] = unsafe {
946 core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
947 };
948
949 let do_row = |row: usize, out_row: &mut [f64]| {
950 let p = &combos[row];
951 let l = p.long_period.unwrap();
952 let s = p.short_period.unwrap();
953 match kern {
954 Kernel::Scalar => unsafe { tsi_row_scalar_into(data, l, s, first, out_row) },
955 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
956 Kernel::Avx2 => unsafe { tsi_row_avx2_into(data, l, s, first, out_row) },
957 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
958 Kernel::Avx512 => unsafe { tsi_row_avx512_into(data, l, s, first, out_row) },
959 _ => unreachable!(),
960 }
961 .unwrap();
962 };
963
964 if parallel {
965 #[cfg(not(target_arch = "wasm32"))]
966 {
967 values
968 .par_chunks_mut(cols)
969 .enumerate()
970 .for_each(|(row, slice)| do_row(row, slice));
971 }
972
973 #[cfg(target_arch = "wasm32")]
974 {
975 for (row, slice) in values.chunks_mut(cols).enumerate() {
976 do_row(row, slice);
977 }
978 }
979 } else {
980 for (row, slice) in values.chunks_mut(cols).enumerate() {
981 do_row(row, slice);
982 }
983 }
984
985 let values = unsafe {
986 Vec::from_raw_parts(
987 buf_guard.as_mut_ptr() as *mut f64,
988 buf_guard.len(),
989 buf_guard.capacity(),
990 )
991 };
992
993 Ok(TsiBatchOutput {
994 values,
995 combos,
996 rows,
997 cols,
998 })
999}
1000
1001#[inline(always)]
1002fn tsi_batch_inner_into(
1003 data: &[f64],
1004 sweep: &TsiBatchRange,
1005 kern: Kernel,
1006 parallel: bool,
1007 out: &mut [f64],
1008) -> Result<Vec<TsiParams>, TsiError> {
1009 let combos = expand_grid(sweep)?;
1010
1011 let first = data
1012 .iter()
1013 .position(|x| !x.is_nan())
1014 .ok_or(TsiError::AllValuesNaN)?;
1015 let max_long = combos.iter().map(|c| c.long_period.unwrap()).max().unwrap();
1016 let max_short = combos
1017 .iter()
1018 .map(|c| c.short_period.unwrap())
1019 .max()
1020 .unwrap();
1021 let max_needed = 1 + max_long + max_short;
1022 if data.len() - first < max_needed {
1023 return Err(TsiError::NotEnoughValidData {
1024 needed: max_needed,
1025 valid: data.len() - first,
1026 });
1027 }
1028
1029 let rows = combos.len();
1030 let cols = data.len();
1031 let expected = rows.checked_mul(cols).ok_or(TsiError::SizeOverflow)?;
1032 if out.len() != expected {
1033 return Err(TsiError::OutputLengthMismatch {
1034 expected,
1035 got: out.len(),
1036 });
1037 }
1038
1039 let do_row = |row: usize, out_row: &mut [f64]| {
1040 let p = &combos[row];
1041 let l = p.long_period.unwrap();
1042 let s = p.short_period.unwrap();
1043 match kern {
1044 Kernel::Scalar => unsafe { tsi_row_scalar_into(data, l, s, first, out_row) },
1045 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1046 Kernel::Avx2 => unsafe { tsi_row_avx2_into(data, l, s, first, out_row) },
1047 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1048 Kernel::Avx512 => unsafe { tsi_row_avx512_into(data, l, s, first, out_row) },
1049 _ => unreachable!(),
1050 }
1051 .unwrap();
1052 };
1053
1054 if parallel {
1055 #[cfg(not(target_arch = "wasm32"))]
1056 {
1057 out.par_chunks_mut(cols)
1058 .enumerate()
1059 .for_each(|(row, slice)| do_row(row, slice));
1060 }
1061 #[cfg(target_arch = "wasm32")]
1062 {
1063 for (row, slice) in out.chunks_mut(cols).enumerate() {
1064 do_row(row, slice);
1065 }
1066 }
1067 } else {
1068 for (row, slice) in out.chunks_mut(cols).enumerate() {
1069 do_row(row, slice);
1070 }
1071 }
1072
1073 Ok(combos)
1074}
1075
1076#[inline(always)]
1077pub unsafe fn tsi_row_scalar_into(
1078 data: &[f64],
1079 long: usize,
1080 short: usize,
1081 first: usize,
1082 out_row: &mut [f64],
1083) -> Result<(), TsiError> {
1084 tsi_compute_into_inline(data, long, short, first, out_row)
1085}
1086
/// AVX2 batch row kernel. It currently delegates to the `EmaStream`-based
/// `tsi_compute_into_streaming`, not to the inline kernel used by the scalar row path, so
/// its numbers may differ slightly from the scalar rows.
///
/// # Safety
/// No unchecked operations are performed. Errors with `OutputLengthMismatch` unless
/// `out_row.len() == data.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn tsi_row_avx2_into(
    data: &[f64],
    long: usize,
    short: usize,
    first: usize,
    out_row: &mut [f64],
) -> Result<(), TsiError> {
    let result = tsi_compute_into_streaming(data, long, short, first, out_row);
    result
}
1098
/// AVX-512 batch row kernel. It currently delegates to the `EmaStream`-based
/// `tsi_compute_into_streaming`, not to the inline kernel used by the scalar row path, so
/// its numbers may differ slightly from the scalar rows.
///
/// # Safety
/// No unchecked operations are performed. Errors with `OutputLengthMismatch` unless
/// `out_row.len() == data.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn tsi_row_avx512_into(
    data: &[f64],
    long: usize,
    short: usize,
    first: usize,
    out_row: &mut [f64],
) -> Result<(), TsiError> {
    let result = tsi_compute_into_streaming(data, long, short, first, out_row);
    result
}
1110
1111#[cfg(test)]
1112mod tests {
1113 use super::*;
1114 use crate::skip_if_unsupported;
1115 use crate::utilities::data_loader::read_candles_from_csv;
1116 use paste::paste;
1117
1118 fn check_tsi_partial_params(
1119 test_name: &str,
1120 kernel: Kernel,
1121 ) -> Result<(), Box<dyn std::error::Error>> {
1122 skip_if_unsupported!(kernel, test_name);
1123 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1124 let candles = read_candles_from_csv(file_path)?;
1125
1126 let default_params = TsiParams {
1127 long_period: None,
1128 short_period: None,
1129 };
1130 let input = TsiInput::from_candles(&candles, "close", default_params);
1131 let output = tsi_with_kernel(&input, kernel)?;
1132 assert_eq!(output.values.len(), candles.close.len());
1133 Ok(())
1134 }
1135
1136 fn check_tsi_accuracy(
1137 test_name: &str,
1138 kernel: Kernel,
1139 ) -> Result<(), Box<dyn std::error::Error>> {
1140 skip_if_unsupported!(kernel, test_name);
1141 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1142 let candles = read_candles_from_csv(file_path)?;
1143
1144 let params = TsiParams {
1145 long_period: Some(25),
1146 short_period: Some(13),
1147 };
1148 let input = TsiInput::from_candles(&candles, "close", params);
1149 let tsi_result = tsi_with_kernel(&input, kernel)?;
1150
1151 let expected_last_five = [
1152 -17.757654061849838,
1153 -17.367527062626184,
1154 -17.305577681249513,
1155 -16.937565646991143,
1156 -17.61825617316731,
1157 ];
1158 let start = tsi_result.values.len().saturating_sub(5);
1159 for (i, &val) in tsi_result.values[start..].iter().enumerate() {
1160 let diff = (val - expected_last_five[i]).abs();
1161 assert!(
1162 diff < 1e-7,
1163 "[{}] TSI {:?} mismatch at idx {}: got {}, expected {}",
1164 test_name,
1165 kernel,
1166 i,
1167 val,
1168 expected_last_five[i]
1169 );
1170 }
1171 Ok(())
1172 }
1173
1174 fn check_tsi_zero_period(
1175 test_name: &str,
1176 kernel: Kernel,
1177 ) -> Result<(), Box<dyn std::error::Error>> {
1178 skip_if_unsupported!(kernel, test_name);
1179 let input_data = [10.0, 20.0, 30.0];
1180 let params = TsiParams {
1181 long_period: Some(0),
1182 short_period: Some(13),
1183 };
1184 let input = TsiInput::from_slice(&input_data, params);
1185 let res = tsi_with_kernel(&input, kernel);
1186 assert!(
1187 res.is_err(),
1188 "[{}] TSI should fail with zero period",
1189 test_name
1190 );
1191 Ok(())
1192 }
1193
1194 fn check_tsi_period_exceeds_length(
1195 test_name: &str,
1196 kernel: Kernel,
1197 ) -> Result<(), Box<dyn std::error::Error>> {
1198 skip_if_unsupported!(kernel, test_name);
1199 let data_small = [10.0, 20.0, 30.0];
1200 let params = TsiParams {
1201 long_period: Some(25),
1202 short_period: Some(13),
1203 };
1204 let input = TsiInput::from_slice(&data_small, params);
1205 let res = tsi_with_kernel(&input, kernel);
1206 assert!(
1207 res.is_err(),
1208 "[{}] TSI should fail with period exceeding length",
1209 test_name
1210 );
1211 Ok(())
1212 }
1213
1214 fn check_tsi_very_small_dataset(
1215 test_name: &str,
1216 kernel: Kernel,
1217 ) -> Result<(), Box<dyn std::error::Error>> {
1218 skip_if_unsupported!(kernel, test_name);
1219 let single_point = [42.0];
1220 let params = TsiParams {
1221 long_period: Some(25),
1222 short_period: Some(13),
1223 };
1224 let input = TsiInput::from_slice(&single_point, params);
1225 let res = tsi_with_kernel(&input, kernel);
1226 assert!(
1227 res.is_err(),
1228 "[{}] TSI should fail with insufficient data",
1229 test_name
1230 );
1231 Ok(())
1232 }
1233
1234 fn check_tsi_default_candles(
1235 test_name: &str,
1236 kernel: Kernel,
1237 ) -> Result<(), Box<dyn std::error::Error>> {
1238 skip_if_unsupported!(kernel, test_name);
1239 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1240 let candles = read_candles_from_csv(file_path)?;
1241
1242 let input = TsiInput::with_default_candles(&candles);
1243 match input.data {
1244 TsiData::Candles { source, .. } => assert_eq!(source, "close"),
1245 _ => panic!("Expected TsiData::Candles"),
1246 }
1247 let output = tsi_with_kernel(&input, kernel)?;
1248 assert_eq!(output.values.len(), candles.close.len());
1249 Ok(())
1250 }
1251
1252 fn check_tsi_reinput(
1253 test_name: &str,
1254 kernel: Kernel,
1255 ) -> Result<(), Box<dyn std::error::Error>> {
1256 skip_if_unsupported!(kernel, test_name);
1257 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1258 let candles = read_candles_from_csv(file_path)?;
1259
1260 let first_params = TsiParams {
1261 long_period: Some(25),
1262 short_period: Some(13),
1263 };
1264 let first_input = TsiInput::from_candles(&candles, "close", first_params);
1265 let first_result = tsi_with_kernel(&first_input, kernel)?;
1266
1267 let second_params = TsiParams {
1268 long_period: Some(25),
1269 short_period: Some(13),
1270 };
1271 let second_input = TsiInput::from_slice(&first_result.values, second_params);
1272 let second_result = tsi_with_kernel(&second_input, kernel)?;
1273
1274 assert_eq!(second_result.values.len(), first_result.values.len());
1275 Ok(())
1276 }
1277
1278 fn check_tsi_nan_handling(
1279 test_name: &str,
1280 kernel: Kernel,
1281 ) -> Result<(), Box<dyn std::error::Error>> {
1282 skip_if_unsupported!(kernel, test_name);
1283 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1284 let candles = read_candles_from_csv(file_path)?;
1285
1286 let input = TsiInput::from_candles(
1287 &candles,
1288 "close",
1289 TsiParams {
1290 long_period: Some(25),
1291 short_period: Some(13),
1292 },
1293 );
1294 let res = tsi_with_kernel(&input, kernel)?;
1295 assert_eq!(res.values.len(), candles.close.len());
1296 if res.values.len() > 240 {
1297 for (i, &val) in res.values[240..].iter().enumerate() {
1298 assert!(
1299 !val.is_nan(),
1300 "[{}] Found unexpected NaN at out-index {}",
1301 test_name,
1302 240 + i
1303 );
1304 }
1305 }
1306 Ok(())
1307 }
1308
1309 #[test]
1310 fn test_tsi_into_matches_api() -> Result<(), Box<dyn std::error::Error>> {
1311 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1312 let candles = read_candles_from_csv(file_path)?;
1313 let input = TsiInput::from_candles(&candles, "close", TsiParams::default());
1314
1315 let baseline = tsi(&input)?.values;
1316
1317 let mut out = vec![0.0; candles.close.len()];
1318 #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1319 {
1320 tsi_into(&input, &mut out)?;
1321 }
1322 #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1323 {
1324 tsi_into_slice(&mut out, &input, Kernel::Auto)?;
1325 }
1326
1327 assert_eq!(out.len(), baseline.len());
1328 let eq_or_both_nan = |a: f64, b: f64| (a.is_nan() && b.is_nan()) || (a == b);
1329 for i in 0..out.len() {
1330 assert!(
1331 eq_or_both_nan(out[i], baseline[i]),
1332 "mismatch at {}: got {}, expected {}",
1333 i,
1334 out[i],
1335 baseline[i]
1336 );
1337 }
1338 Ok(())
1339 }
1340
1341 #[cfg(debug_assertions)]
1342 fn check_tsi_no_poison(
1343 test_name: &str,
1344 kernel: Kernel,
1345 ) -> Result<(), Box<dyn std::error::Error>> {
1346 skip_if_unsupported!(kernel, test_name);
1347
1348 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1349 let candles = read_candles_from_csv(file_path)?;
1350
1351 let test_params = vec![
1352 TsiParams::default(),
1353 TsiParams {
1354 long_period: Some(2),
1355 short_period: Some(2),
1356 },
1357 TsiParams {
1358 long_period: Some(5),
1359 short_period: Some(3),
1360 },
1361 TsiParams {
1362 long_period: Some(10),
1363 short_period: Some(5),
1364 },
1365 TsiParams {
1366 long_period: Some(15),
1367 short_period: Some(7),
1368 },
1369 TsiParams {
1370 long_period: Some(25),
1371 short_period: Some(13),
1372 },
1373 TsiParams {
1374 long_period: Some(30),
1375 short_period: Some(15),
1376 },
1377 TsiParams {
1378 long_period: Some(40),
1379 short_period: Some(20),
1380 },
1381 TsiParams {
1382 long_period: Some(50),
1383 short_period: Some(25),
1384 },
1385 TsiParams {
1386 long_period: Some(100),
1387 short_period: Some(50),
1388 },
1389 TsiParams {
1390 long_period: Some(200),
1391 short_period: Some(100),
1392 },
1393 TsiParams {
1394 long_period: Some(50),
1395 short_period: Some(2),
1396 },
1397 TsiParams {
1398 long_period: Some(100),
1399 short_period: Some(5),
1400 },
1401 TsiParams {
1402 long_period: Some(10),
1403 short_period: Some(9),
1404 },
1405 ];
1406
1407 for (param_idx, params) in test_params.iter().enumerate() {
1408 let input = TsiInput::from_candles(&candles, "close", params.clone());
1409 let output = tsi_with_kernel(&input, kernel)?;
1410
1411 for (i, &val) in output.values.iter().enumerate() {
1412 if val.is_nan() {
1413 continue;
1414 }
1415
1416 let bits = val.to_bits();
1417
1418 if bits == 0x11111111_11111111 {
1419 panic!(
1420 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1421 with params: {:?} (param set {})",
1422 test_name, val, bits, i, params, param_idx
1423 );
1424 }
1425
1426 if bits == 0x22222222_22222222 {
1427 panic!(
1428 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1429 with params: {:?} (param set {})",
1430 test_name, val, bits, i, params, param_idx
1431 );
1432 }
1433
1434 if bits == 0x33333333_33333333 {
1435 panic!(
1436 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1437 with params: {:?} (param set {})",
1438 test_name, val, bits, i, params, param_idx
1439 );
1440 }
1441 }
1442 }
1443
1444 Ok(())
1445 }
1446
    // Release-build stand-in for the poison check: the poison scan is compiled only
    // under `debug_assertions`, so here the test is a no-op that always succeeds. It
    // keeps the same signature so the test-generating macro can still reference it.
    #[cfg(not(debug_assertions))]
    fn check_tsi_no_poison(
        _test_name: &str,
        _kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        Ok(())
    }
1454
    // Property-based test (proptest). For each generated (series, long, short) case it
    // checks:
    //   1. every non-NaN output lies in [-100, 100] (1e-9 slack);
    //   2. out[0] is NaN and, when the input varies, the last quarter of the output
    //      is not entirely NaN;
    //   3. `kernel` agrees with `Kernel::Scalar` within 1e-9 (non-finite values must
    //      match bit-for-bit);
    //   4. a constant series with (10, 5) yields only NaN, if the call succeeds;
    //   5. linear up/down trends with the generated periods give a positive/negative
    //      average over their last 10 outputs;
    //   6. steep linear trends with (5, 3) give TSI > 80 / < -80 near the end;
    //   7. mostly-rising / mostly-falling generated series keep the average of the
    //      last 10 defined outputs above -20 / below 20;
    //   8. (debug builds) no poison bit patterns appear in the output;
    //   9. when 5-bar price changes reverse direction, TSI follows at least 30% of
    //      the time.
    // Properties 4-6 use fixed synthetic series and are re-evaluated on every case.
    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_tsi_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        // Cases: long_period in 2..=50; series of finite prices in [1, 10000) with
        // length in (long_period + 30)..400; short_period in 2..=min(long_period, 25).
        let strat = (2usize..=50).prop_flat_map(|long_period| {
            (
                prop::collection::vec(
                    (1.0f64..10000.0f64).prop_filter("finite", |x| x.is_finite()),
                    (long_period + 30)..400,
                ),
                Just(long_period),
                2usize..=long_period.min(25),
            )
        });

        proptest::test_runner::TestRunner::default()
            .run(&strat, |(data, long_period, short_period)| {
                let params = TsiParams {
                    long_period: Some(long_period),
                    short_period: Some(short_period),
                };
                let input = TsiInput::from_slice(&data, params.clone());

                // Output under test, plus the scalar reference used by Property 3.
                let TsiOutput { values: out } = tsi_with_kernel(&input, kernel).unwrap();
                let TsiOutput { values: ref_out } =
                    tsi_with_kernel(&input, Kernel::Scalar).unwrap();

                // Property 1: bounded output.
                for (i, &val) in out.iter().enumerate() {
                    if !val.is_nan() {
                        prop_assert!(
                            val >= -100.0 - 1e-9 && val <= 100.0 + 1e-9,
                            "Property 1: TSI value out of range at idx {}: {} ∉ [-100, 100]",
                            i,
                            val
                        );
                    }
                }

                // The strategy only generates finite values, so this is 0 in practice.
                let first_valid = data.iter().position(|x| !x.is_nan()).unwrap_or(0);

                // Property 2a: the first output is always NaN.
                prop_assert!(
                    out[0].is_nan(),
                    "Property 2: First value should always be NaN, got {}",
                    out[0]
                );

                let has_variation = data.windows(2).any(|w| (w[0] - w[1]).abs() > 1e-10);

                // Property 2b: with a non-constant input, defined values must eventually
                // appear, at the earliest at index 1.
                if has_variation {
                    let first_non_nan = out.iter().position(|x| !x.is_nan());

                    if let Some(idx) = first_non_nan {
                        prop_assert!(
                            idx >= 1,
                            "Property 2: First non-NaN too early at idx {}",
                            idx
                        );

                        let sufficient_data =
                            (first_valid + long_period + short_period).min(out.len());
                        if sufficient_data < out.len() {
                            // NOTE(review): this recomputes the same condition as
                            // `has_variation`, so it is always true inside this branch.
                            let mut has_movement = false;
                            for i in 1..data.len() {
                                if (data[i] - data[i - 1]).abs() > 1e-10 {
                                    has_movement = true;
                                    break;
                                }
                            }

                            if has_movement {
                                let last_quarter_start = out.len() - (out.len() / 4).max(1);

                                let nan_count = out[last_quarter_start..]
                                    .iter()
                                    .filter(|v| v.is_nan())
                                    .count();
                                let total_count = out.len() - last_quarter_start;
                                prop_assert!(
                                    nan_count < total_count,
                                    "Property 2: All values in last quarter are NaN (from idx {}), which suggests an issue",
                                    last_quarter_start
                                );
                            }
                        }
                    }
                }

                // Property 3: kernel output matches the scalar reference.
                for (i, (&val, &ref_val)) in out.iter().zip(ref_out.iter()).enumerate() {
                    if val.is_nan() && ref_val.is_nan() {
                        continue;
                    }
                    if !val.is_finite() || !ref_val.is_finite() {
                        prop_assert!(
                            val.to_bits() == ref_val.to_bits(),
                            "Property 3: Kernel finite/NaN mismatch at idx {}: {} vs {}",
                            i,
                            val,
                            ref_val
                        );
                        continue;
                    }
                    prop_assert!(
                        (val - ref_val).abs() <= 1e-9,
                        "Property 3: Kernel mismatch at idx {}: {} vs {} (diff: {})",
                        i,
                        val,
                        ref_val,
                        (val - ref_val).abs()
                    );
                }

                // Property 4: a constant series (fixed 10/5 periods) gives all-NaN output.
                // An `Err` from the call is tolerated and skips the check.
                let constant_data = vec![100.0; 50];
                let const_params = TsiParams {
                    long_period: Some(10),
                    short_period: Some(5),
                };
                let const_input = TsiInput::from_slice(&constant_data, const_params);
                if let Ok(TsiOutput { values: const_out }) = tsi_with_kernel(&const_input, kernel) {
                    for (i, &val) in const_out.iter().enumerate() {
                        prop_assert!(
                            val.is_nan(),
                            "Property 4: TSI should be NaN for constant prices at idx {}, got {}",
                            i,
                            val
                        );
                    }
                }

                // Property 5: 100-bar linear trends (+/-20 per bar) with the generated
                // periods; the sign of the recent average must match the trend.
                let uptrend: Vec<f64> = (0..100).map(|i| 100.0 + i as f64 * 20.0).collect();
                let uptrend_input = TsiInput::from_slice(&uptrend, params.clone());

                let downtrend: Vec<f64> = (0..100).map(|i| 2100.0 - i as f64 * 20.0).collect();
                let downtrend_input = TsiInput::from_slice(&downtrend, params.clone());

                if let (Ok(TsiOutput { values: up_out }), Ok(TsiOutput { values: down_out })) = (
                    tsi_with_kernel(&uptrend_input, kernel),
                    tsi_with_kernel(&downtrend_input, kernel),
                ) {
                    let test_warmup = 1 + long_period + short_period - 1;
                    if test_warmup + 10 < up_out.len() {
                        let up_vals: Vec<f64> = up_out[up_out.len() - 10..]
                            .iter()
                            .filter(|&&x| !x.is_nan())
                            .copied()
                            .collect();
                        let down_vals: Vec<f64> = down_out[down_out.len() - 10..]
                            .iter()
                            .filter(|&&x| !x.is_nan())
                            .copied()
                            .collect();

                        if !up_vals.is_empty() && !down_vals.is_empty() {
                            let up_avg = up_vals.iter().sum::<f64>() / up_vals.len() as f64;
                            let down_avg = down_vals.iter().sum::<f64>() / down_vals.len() as f64;

                            // Longer periods must clear a +/-10 margin; shorter ones only need the sign.
                            let tolerance = if long_period > 20 { 10.0 } else { 0.0 };

                            prop_assert!(
                                up_avg > tolerance,
                                "Property 5: Uptrend TSI should be positive, got avg: {} (long={}, short={})",
                                up_avg, long_period, short_period
                            );
                            prop_assert!(
                                down_avg < -tolerance,
                                "Property 5: Downtrend TSI should be negative, got avg: {} (long={}, short={})",
                                down_avg, long_period, short_period
                            );
                        }
                    }
                }

                // Property 6: steep +50/bar uptrend with (5, 3) periods; the six values
                // ending at the last defined index must exceed 80.
                let extreme_up: Vec<f64> = (0..100).map(|i| 100.0 + i as f64 * 50.0).collect();
                let extreme_params = TsiParams {
                    long_period: Some(5),
                    short_period: Some(3),
                };
                let extreme_input = TsiInput::from_slice(&extreme_up, extreme_params.clone());
                if let Ok(TsiOutput {
                    values: extreme_out,
                }) = tsi_with_kernel(&extreme_input, kernel)
                {
                    let last_valid = extreme_out.iter().rposition(|x| !x.is_nan());
                    if let Some(idx) = last_valid {
                        if idx >= 20 {
                            let last_vals = &extreme_out[idx - 5..=idx];
                            for &v in last_vals {
                                if !v.is_nan() {
                                    prop_assert!(
                                        v > 80.0,
                                        "Property 6: Very strong uptrend should have TSI > 80, got: {}",
                                        v
                                    );
                                }
                            }
                        }
                    }
                }

                // Property 6 (mirror): steep -50/bar downtrend must end below -80.
                let extreme_down: Vec<f64> = (0..100).map(|i| 5100.0 - i as f64 * 50.0).collect();
                let extreme_down_input = TsiInput::from_slice(&extreme_down, extreme_params);
                if let Ok(TsiOutput {
                    values: extreme_down_out,
                }) = tsi_with_kernel(&extreme_down_input, kernel)
                {
                    let last_valid = extreme_down_out.iter().rposition(|x| !x.is_nan());
                    if let Some(idx) = last_valid {
                        if idx >= 20 {
                            let last_vals = &extreme_down_out[idx - 5..=idx];
                            for &v in last_vals {
                                if !v.is_nan() {
                                    prop_assert!(
                                        v < -80.0,
                                        "Property 6: Very strong downtrend should have TSI < -80, got: {}",
                                        v
                                    );
                                }
                            }
                        }
                    }
                }

                // Property 7: if rises outnumber falls more than 2:1 (or vice versa), the
                // average of the defined values among the last 10 outputs must not
                // strongly contradict that direction.
                if data.len() >= 20 {
                    let increasing_count = data.windows(2).filter(|w| w[1] > w[0]).count();
                    let decreasing_count = data.windows(2).filter(|w| w[1] < w[0]).count();

                    let valid_tsi: Vec<f64> = out
                        .iter()
                        .rev()
                        .take(10)
                        .filter(|&&x| !x.is_nan())
                        .copied()
                        .collect();

                    if !valid_tsi.is_empty() {
                        let avg_tsi = valid_tsi.iter().sum::<f64>() / valid_tsi.len() as f64;

                        if increasing_count > decreasing_count * 2 {
                            prop_assert!(
                                avg_tsi > -20.0,
                                "Property 7: Mostly increasing prices should have TSI > -20, got: {}",
                                avg_tsi
                            );
                        } else if decreasing_count > increasing_count * 2 {
                            prop_assert!(
                                avg_tsi < 20.0,
                                "Property 7: Mostly decreasing prices should have TSI < 20, got: {}",
                                avg_tsi
                            );
                        }
                    }
                }

                // Property 8 (debug builds only): no allocation-helper poison patterns.
                #[cfg(debug_assertions)]
                {
                    for (i, &val) in out.iter().enumerate() {
                        if !val.is_nan() {
                            let bits = val.to_bits();
                            prop_assert!(
                                bits != 0x11111111_11111111
                                    && bits != 0x22222222_22222222
                                    && bits != 0x33333333_33333333,
                                "Property 8: Found poison value at idx {}: {} (0x{:016X})",
                                i,
                                val,
                                bits
                            );
                        }
                    }
                }

                // Property 9: at each index i past warm-up, compare the price change over
                // [i, i+5] with [i+5, i+10]. When they have opposite signs (a reversal),
                // count TSI as "following" if its own 5-bar changes also reverse, or if its
                // second change moves further in the new price direction than its first.
                // At least 30% of reversals must be followed.
                if data.len() >= 50 && has_variation {
                    let start_idx = (1 + long_period + short_period).max(20);
                    if start_idx + 20 < out.len() {
                        let mut momentum_changes = 0;
                        let mut tsi_follows = 0;

                        for i in start_idx..out.len() - 10 {
                            if !out[i].is_nan() && !out[i + 5].is_nan() && !out[i + 10].is_nan() {
                                let price_change1 = data[i + 5] - data[i];
                                let price_change2 = data[i + 10] - data[i + 5];

                                let tsi_change1 = out[i + 5] - out[i];
                                let tsi_change2 = out[i + 10] - out[i + 5];

                                if price_change1 * price_change2 < 0.0 {
                                    momentum_changes += 1;

                                    if tsi_change1 * tsi_change2 < 0.0
                                        || (price_change2 > 0.0 && tsi_change2 > tsi_change1)
                                        || (price_change2 < 0.0 && tsi_change2 < tsi_change1)
                                    {
                                        tsi_follows += 1;
                                    }
                                }
                            }
                        }

                        if momentum_changes > 0 {
                            let follow_rate = tsi_follows as f64 / momentum_changes as f64;
                            prop_assert!(
                                follow_rate >= 0.3,
                                "Property 9: TSI should respond to momentum changes, follow rate: {:.2}",
                                follow_rate
                            );
                        }
                    }
                }

                Ok(())
            })
            .unwrap();

        Ok(())
    }
1774
1775 macro_rules! generate_all_tsi_tests {
1776 ($($test_fn:ident),*) => {
1777 paste! {
1778 $(
1779 #[test]
1780 fn [<$test_fn _scalar_f64>]() {
1781 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1782 }
1783 )*
1784 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1785 $(
1786 #[test]
1787 fn [<$test_fn _avx2_f64>]() {
1788 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1789 }
1790 #[test]
1791 fn [<$test_fn _avx512_f64>]() {
1792 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1793 }
1794 )*
1795 }
1796 }
1797 }
1798 generate_all_tsi_tests!(
1799 check_tsi_partial_params,
1800 check_tsi_accuracy,
1801 check_tsi_zero_period,
1802 check_tsi_period_exceeds_length,
1803 check_tsi_very_small_dataset,
1804 check_tsi_default_candles,
1805 check_tsi_reinput,
1806 check_tsi_nan_handling,
1807 check_tsi_no_poison
1808 );
1809
1810 #[cfg(feature = "proptest")]
1811 generate_all_tsi_tests!(check_tsi_property);
1812
1813 fn check_batch_default_row(
1814 test: &str,
1815 kernel: Kernel,
1816 ) -> Result<(), Box<dyn std::error::Error>> {
1817 skip_if_unsupported!(kernel, test);
1818 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1819 let c = read_candles_from_csv(file)?;
1820 let output = TsiBatchBuilder::new()
1821 .kernel(kernel)
1822 .apply_candles(&c, "close")?;
1823 let def = TsiParams::default();
1824 let row = output.values_for(&def).expect("default row missing");
1825 assert_eq!(row.len(), c.close.len());
1826 let expected = [
1827 -17.757654061849838,
1828 -17.367527062626184,
1829 -17.305577681249513,
1830 -16.937565646991143,
1831 -17.61825617316731,
1832 ];
1833 let start = row.len() - 5;
1834 for (i, &v) in row[start..].iter().enumerate() {
1835 assert!(
1836 (v - expected[i]).abs() < 1e-7,
1837 "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
1838 );
1839 }
1840 Ok(())
1841 }
1842
1843 #[cfg(debug_assertions)]
1844 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn std::error::Error>> {
1845 skip_if_unsupported!(kernel, test);
1846
1847 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1848 let c = read_candles_from_csv(file)?;
1849
1850 let test_configs = vec![
1851 (5, 10, 1, 2, 5, 1),
1852 (10, 30, 5, 5, 15, 5),
1853 (25, 50, 5, 10, 25, 5),
1854 (50, 100, 10, 25, 50, 5),
1855 (25, 25, 0, 2, 20, 2),
1856 (10, 50, 10, 13, 13, 0),
1857 (100, 200, 50, 50, 100, 25),
1858 (2, 5, 1, 2, 5, 1),
1859 (30, 30, 0, 5, 25, 5),
1860 ];
1861
1862 for (cfg_idx, &(l_start, l_end, l_step, s_start, s_end, s_step)) in
1863 test_configs.iter().enumerate()
1864 {
1865 let output = TsiBatchBuilder::new()
1866 .kernel(kernel)
1867 .long_range(l_start, l_end, l_step)
1868 .short_range(s_start, s_end, s_step)
1869 .apply_candles(&c, "close")?;
1870
1871 for (idx, &val) in output.values.iter().enumerate() {
1872 if val.is_nan() {
1873 continue;
1874 }
1875
1876 let bits = val.to_bits();
1877 let row = idx / output.cols;
1878 let col = idx % output.cols;
1879 let combo = &output.combos[row];
1880
1881 if bits == 0x11111111_11111111 {
1882 panic!(
1883 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
1884 at row {} col {} (flat index {}) with params: {:?}",
1885 test, cfg_idx, val, bits, row, col, idx, combo
1886 );
1887 }
1888
1889 if bits == 0x22222222_22222222 {
1890 panic!(
1891 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
1892 at row {} col {} (flat index {}) with params: {:?}",
1893 test, cfg_idx, val, bits, row, col, idx, combo
1894 );
1895 }
1896
1897 if bits == 0x33333333_33333333 {
1898 panic!(
1899 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
1900 at row {} col {} (flat index {}) with params: {:?}",
1901 test, cfg_idx, val, bits, row, col, idx, combo
1902 );
1903 }
1904 }
1905 }
1906
1907 Ok(())
1908 }
1909
    // Release-build stand-in for the batch poison check: the scan is compiled only
    // under `debug_assertions`, so here the test is a no-op that always succeeds. It
    // keeps the same signature so `gen_batch_tests!` can still reference it.
    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(
        _test: &str,
        _kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        Ok(())
    }
1917
1918 macro_rules! gen_batch_tests {
1919 ($fn_name:ident) => {
1920 paste! {
1921 #[test] fn [<$fn_name _scalar>]() {
1922 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
1923 }
1924 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1925 #[test] fn [<$fn_name _avx2>]() {
1926 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
1927 }
1928 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1929 #[test] fn [<$fn_name _avx512>]() {
1930 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
1931 }
1932 #[test] fn [<$fn_name _auto_detect>]() {
1933 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
1934 }
1935 }
1936 };
1937 }
1938 gen_batch_tests!(check_batch_default_row);
1939 gen_batch_tests!(check_batch_no_poison);
1940}
1941
/// Python binding: True Strength Index of a 1-D `float64` NumPy array.
///
/// `kernel` is validated by `validate_kernel` (single-series mode). The computation
/// runs with the GIL released. Any error from the computation is raised as
/// `ValueError`; the result is a new NumPy array with one value per input element.
#[cfg(feature = "python")]
#[pyfunction(name = "tsi")]
#[pyo3(signature = (data, long_period=25, short_period=13, kernel=None))]
pub fn tsi_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    long_period: usize,
    short_period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let series = data.as_slice()?;
    let chosen_kernel = validate_kernel(kernel, false)?;

    let input = TsiInput::from_slice(
        series,
        TsiParams {
            long_period: Some(long_period),
            short_period: Some(short_period),
        },
    );

    match py.allow_threads(|| tsi_with_kernel(&input, chosen_kernel)) {
        Ok(output) => Ok(output.values.into_pyarray(py)),
        Err(err) => Err(PyValueError::new_err(err.to_string())),
    }
}
1969
/// Python binding: TSI parameter sweep over a 1-D `float64` NumPy array.
///
/// `long_period_range` / `short_period_range` are `(start, end, step)` tuples expanded
/// by `expand_grid`. Returns a dict with:
/// - `"values"`: `(rows, cols)` array, one row per parameter combination and one column
///   per input element;
/// - `"long_periods"` / `"short_periods"`: `uint64` arrays giving each row's parameters.
///
/// Errors from grid expansion, size overflow, or the computation raise `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "tsi_batch")]
#[pyo3(signature = (data, long_period_range, short_period_range, kernel=None))]
pub fn tsi_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    long_period_range: (usize, usize, usize),
    short_period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let series = data.as_slice()?;
    let sweep = TsiBatchRange {
        long_period: long_period_range,
        short_period: short_period_range,
    };

    // Expanded only to size the output; the combos actually written are returned by
    // `tsi_batch_inner_into` below.
    let grid = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let (rows, cols) = (grid.len(), series.len());
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("tsi_batch: size overflow"))?;

    // SAFETY: the array is created uninitialised and is only exposed to Python after
    // `tsi_batch_inner_into` succeeds (on error the `?` returns first and the array is
    // dropped unseen). NOTE(review): assumes `tsi_batch_inner_into` writes all
    // rows*cols slots — confirm.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let requested = validate_kernel(kernel, true)?;

    let combos = py
        .allow_threads(|| {
            // Batch kernels select the per-row kernel; `Auto` falls back to scalar.
            let row_kernel = match requested {
                Kernel::Auto | Kernel::ScalarBatch => Kernel::Scalar,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::Avx512Batch => Kernel::Avx512,
                other => other,
            };
            tsi_batch_inner_into(series, &sweep, row_kernel, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;

    let long_periods: Vec<u64> = combos
        .iter()
        .map(|p| p.long_period.unwrap() as u64)
        .collect();
    dict.set_item("long_periods", long_periods.into_pyarray(py))?;

    let short_periods: Vec<u64> = combos
        .iter()
        .map(|p| p.short_period.unwrap() as u64)
        .collect();
    dict.set_item("short_periods", short_periods.into_pyarray(py))?;

    Ok(dict)
}
2043
2044#[cfg(all(feature = "python", feature = "cuda"))]
2045use crate::cuda::cuda_available;
2046#[cfg(all(feature = "python", feature = "cuda"))]
2047use crate::cuda::oscillators::tsi_wrapper::CudaTsi;
2048
/// Python handle to a 2-D `f32` TSI result held in CUDA device memory.
///
/// Exposed to Python as `ta_indicators.cuda.TsiDeviceArrayF32`. It supports
/// `__cuda_array_interface__` and `__dlpack__` for zero-copy hand-off. The class is
/// marked `unsendable` in pyo3.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", name = "TsiDeviceArrayF32", unsendable)]
pub struct TsiDeviceArrayF32Py {
    /// Device buffer plus its `rows` x `cols` shape.
    pub(crate) inner: DeviceArrayF32,
    // Never read in this file; presumably held so the CUDA context outlives `inner`
    // — TODO confirm.
    ctx_guard: Arc<Context>,
    /// CUDA device ordinal reported by `__dlpack_device__`.
    device_id: u32,
}
2056
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl TsiDeviceArrayF32Py {
    /// CUDA Array Interface (version 3): describes the buffer as a row-major
    /// `(rows, cols)` array of little-endian `f32` (`"<f4"`).
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let inner = &self.inner;
        let d = PyDict::new(py);
        let itemsize = std::mem::size_of::<f32>();
        d.set_item("shape", (inner.rows, inner.cols))?;
        d.set_item("typestr", "<f4")?;
        // Byte strides for a contiguous row-major layout: one row is `cols` values.
        d.set_item("strides", (inner.cols * itemsize, itemsize))?;
        // `(pointer, read_only)`: the buffer is exposed as writable.
        d.set_item("data", (inner.device_ptr() as usize, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    /// DLPack device tuple `(device_type, device_id)`. Device type 2 is the CUDA code
    /// (`kDLCUDA`) in the DLPack spec.
    fn __dlpack_device__(&self) -> PyResult<(i32, i32)> {
        Ok((2, self.device_id as i32))
    }

    /// Exports the buffer as a DLPack capsule, transferring ownership of the device
    /// memory to the consumer.
    ///
    /// After a successful call this object holds an empty `0 x 0` buffer. A
    /// `dl_device` that differs from this buffer's device is rejected with `ValueError`
    /// (cross-device copies are not implemented). `stream` is accepted but ignored.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        use cust::memory::DeviceBuffer;

        // Validate a requested target device against where the data actually lives.
        // A `dl_device` that does not extract as `(i32, i32)` is not checked.
        let (kdl, alloc_dev) = self.__dlpack_device__()?;
        if let Some(dev_obj) = dl_device.as_ref() {
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(PyValueError::new_err(
                            "device copy not implemented for __dlpack__",
                        ));
                    } else {
                        return Err(PyValueError::new_err("dl_device mismatch for __dlpack__"));
                    }
                }
            }
        }
        // No stream synchronisation is performed.
        let _ = stream;

        // Move the real buffer out, leaving an empty placeholder so `self` stays valid
        // after the exported capsule takes ownership.
        let dummy =
            DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let inner = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32 {
                buf: dummy,
                rows: 0,
                cols: 0,
            },
        );

        let rows = inner.rows;
        let cols = inner.cols;
        let buf = inner.buf;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
2129
#[cfg(all(feature = "python", feature = "cuda"))]
impl TsiDeviceArrayF32Py {
    /// Wraps a device array produced on the Rust side, together with the CUDA context
    /// handle and device ordinal it came from, for return to Python.
    pub fn new_from_rust(inner: DeviceArrayF32, ctx_guard: Arc<Context>, device_id: u32) -> Self {
        Self { device_id, ctx_guard, inner }
    }
}
2140
/// Python binding: TSI parameter sweep on the GPU, with the result left in device
/// memory.
///
/// Returns the device-resident matrix and a dict with `"long_periods"` /
/// `"short_periods"` (`uint64` arrays, one entry per row). Raises `ValueError` if CUDA
/// is unavailable or the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "tsi_cuda_batch_dev")]
#[pyo3(signature = (data_f32, long_period_range, short_period_range, device_id=0))]
pub fn tsi_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    data_f32: numpy::PyReadonlyArray1<'py, f32>,
    long_period_range: (usize, usize, usize),
    short_period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<(TsiDeviceArrayF32Py, Bound<'py, PyDict>)> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let host_series = data_f32.as_slice()?;
    let sweep = TsiBatchRange {
        long_period: long_period_range,
        short_period: short_period_range,
    };

    // Device work runs with the GIL released.
    let (device_arr, combos, context, dev_ordinal) = py.allow_threads(|| -> PyResult<_> {
        let mut cuda = CudaTsi::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let context = cuda.context_arc();
        let dev_ordinal = cuda.device_id();
        let (arr, combos) = cuda
            .tsi_batch_dev(host_series, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((arr, combos, context, dev_ordinal))
    })?;

    let meta = PyDict::new(py);

    let long_periods: Vec<u64> = combos
        .iter()
        .map(|p| p.long_period.unwrap() as u64)
        .collect();
    meta.set_item("long_periods", long_periods.into_pyarray(py))?;

    let short_periods: Vec<u64> = combos
        .iter()
        .map(|p| p.short_period.unwrap() as u64)
        .collect();
    meta.set_item("short_periods", short_periods.into_pyarray(py))?;

    let handle = TsiDeviceArrayF32Py::new_from_rust(device_arr, context, dev_ordinal);
    Ok((handle, meta))
}
2189
/// Python binding: TSI with one `(long_period, short_period)` pair applied to many
/// series on the GPU.
///
/// `data_tm_f32` is a flat buffer passed with `cols` / `rows` to the CUDA wrapper's
/// time-major entry point (the layout is defined by that wrapper). The result stays in
/// device memory. Raises `ValueError` if CUDA is unavailable or the wrapper reports an
/// error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "tsi_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, cols, rows, long_period, short_period, device_id=0))]
pub fn tsi_cuda_many_series_one_param_dev_py<'py>(
    py: Python<'py>,
    data_tm_f32: numpy::PyReadonlyArray1<'py, f32>,
    cols: usize,
    rows: usize,
    long_period: usize,
    short_period: usize,
    device_id: usize,
) -> PyResult<TsiDeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let host_tm = data_tm_f32.as_slice()?;

    py.allow_threads(|| -> PyResult<_> {
        let mut cuda = CudaTsi::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let context = cuda.context_arc();
        let dev_ordinal = cuda.device_id();
        let device_arr = cuda
            .tsi_many_series_one_param_time_major_dev(host_tm, cols, rows, long_period, short_period)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((device_arr, context, dev_ordinal))
    })
    .map(|(device_arr, context, dev_ordinal)| {
        TsiDeviceArrayF32Py::new_from_rust(device_arr, context, dev_ordinal)
    })
}
2223
/// Python wrapper (`TsiStream`) around the incremental Rust `TsiStream`.
#[cfg(feature = "python")]
#[pyclass(name = "TsiStream")]
pub struct TsiStreamPy {
    /// The wrapped Rust streaming state.
    inner: TsiStream,
}
2229
#[cfg(feature = "python")]
#[pymethods]
impl TsiStreamPy {
    /// Creates a streaming TSI. Raises `ValueError` if `TsiStream::try_new` rejects
    /// the periods.
    #[new]
    pub fn new(long_period: usize, short_period: usize) -> PyResult<Self> {
        let config = TsiParams {
            long_period: Some(long_period),
            short_period: Some(short_period),
        };
        TsiStream::try_new(config)
            .map(|inner| TsiStreamPy { inner })
            .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feeds one value and returns whatever the underlying stream returns for it
    /// (`None` maps to Python `None`).
    pub fn update(&mut self, value: f64) -> Option<f64> {
        let stream = &mut self.inner;
        stream.update(value)
    }
}
2247
/// wasm binding: computes TSI over `data` and returns a new array of the same length.
///
/// Errors from the computation are returned as a JS string value.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn tsi_js(data: &[f64], long_period: usize, short_period: usize) -> Result<Vec<f64>, JsValue> {
    let input = TsiInput::from_slice(
        data,
        TsiParams {
            long_period: Some(long_period),
            short_period: Some(short_period),
        },
    );

    let mut output = vec![0.0; data.len()];
    match tsi_into_slice(&mut output, &input, Kernel::Auto) {
        Ok(_) => Ok(output),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
2264
/// wasm binding: writes TSI for the `len` values at `in_ptr` into the `len`-slot
/// buffer at `out_ptr`.
///
/// Intended for JS callers that manage wasm memory themselves (e.g. with
/// `tsi_alloc` / `tsi_free`). Both pointers must be non-null and valid for `len`
/// `f64`s. If the input and output regions overlap at all, including the fully
/// in-place case `in_ptr == out_ptr`, the result is first computed into a scratch
/// buffer and then copied out. This avoids holding a shared and a mutable slice over
/// the same memory at once.
///
/// # Errors
/// Returns a JS string value if either pointer is null or the computation fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn tsi_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    long_period: usize,
    short_period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // Byte ranges [start, start + span) of the two buffers. Any intersection, not
    // just identical start pointers, must go through the scratch buffer.
    let span = len.saturating_mul(std::mem::size_of::<f64>());
    let (in_start, out_start) = (in_ptr as usize, out_ptr as usize);
    let overlaps = in_start == out_start
        || (in_start < out_start.saturating_add(span) && out_start < in_start.saturating_add(span));

    // SAFETY: the caller guarantees both pointers are valid for `len` f64s. When the
    // regions overlap, the output slice is only created after the last read of the
    // input; otherwise the two slices cover disjoint memory.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let params = TsiParams {
            long_period: Some(long_period),
            short_period: Some(short_period),
        };
        let input = TsiInput::from_slice(data, params);

        if overlaps {
            let mut temp = vec![0.0; len];
            tsi_into_slice(&mut temp, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            out.copy_from_slice(&temp);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            tsi_into_slice(out, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }
    Ok(())
}
2300
/// wasm binding: allocates an uninitialised buffer with capacity for `len` `f64`s and
/// returns its pointer.
///
/// The allocation is intentionally leaked; release it with `tsi_free(ptr, len)`
/// using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn tsi_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` suppresses the destructor, so the allocation outlives this call.
    let mut storage = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    storage.as_mut_ptr()
}
2309
/// wasm binding: releases a buffer obtained from `tsi_alloc(len)`.
///
/// A null pointer is ignored. `len` must be the same value passed to `tsi_alloc`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn tsi_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` came from `tsi_alloc(len)`, which allocated exactly `len` f64s of
    // capacity; rebuilding the Vec with that capacity and dropping it frees it.
    let reclaimed = unsafe { Vec::from_raw_parts(ptr, len, len) };
    drop(reclaimed);
}
2319
/// Batch sweep configuration deserialised from the JS object passed to `tsi_batch`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct TsiBatchConfig {
    /// `(start, end, step)` for the long EMA period; copied into `TsiBatchRange::long_period`.
    pub long_period_range: (usize, usize, usize),
    /// `(start, end, step)` for the short EMA period; copied into `TsiBatchRange::short_period`.
    pub short_period_range: (usize, usize, usize),
}
2326
/// Result of `tsi_batch`, serialised to a JS object.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct TsiBatchJsOutput {
    /// Flattened `rows * cols` output matrix. Presumably row-major, one row per entry
    /// of `combos` — TODO confirm against `tsi_batch_slice`.
    pub values: Vec<f64>,
    /// Parameter combination used for each row.
    pub combos: Vec<TsiParams>,
    /// Number of rows (parameter combinations).
    pub rows: usize,
    /// Number of columns (input length).
    pub cols: usize,
}
2335
/// wasm binding (`tsi_batch`): runs a TSI parameter sweep described by a
/// `TsiBatchConfig` JS object and returns a `TsiBatchJsOutput` JS object.
///
/// Uses the scalar kernel. Config-parsing, computation, and serialisation failures
/// are returned as JS string values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = tsi_batch)]
pub fn tsi_batch_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let parsed: TsiBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;
    let TsiBatchConfig {
        long_period_range,
        short_period_range,
    } = parsed;

    let sweep = TsiBatchRange {
        long_period: long_period_range,
        short_period: short_period_range,
    };

    let batch = tsi_batch_slice(data, &sweep, Kernel::Scalar)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let payload = TsiBatchJsOutput {
        rows: batch.rows,
        cols: batch.cols,
        values: batch.values,
        combos: batch.combos,
    };

    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Failed to serialize output: {}", e)))
}
2360
/// Runs a TSI batch sweep from WASM memory into a caller-provided output buffer.
///
/// Reads `len` values from `in_ptr` and expands the `(start, end, step)` ranges for the long
/// and short periods into parameter combinations. It then writes `rows * len` values to
/// `out_ptr`, where `rows` is the number of combinations. Returns `rows` on success.
///
/// The input and output buffers may overlap (e.g. the same pointer passed for both). In that
/// case the input is copied first, so the kernel never reads values it has already
/// overwritten and no shared/mutable slices alias.
///
/// # Errors
/// Returns a JS string error if either pointer is null, the ranges cannot be expanded,
/// `rows * len` overflows `usize`, or the batch computation fails.
///
/// # Caller contract
/// `in_ptr` must be valid for `len` reads and `out_ptr` valid for `rows * len` writes
/// (e.g. buffers from `tsi_alloc`). This cannot be verified here.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn tsi_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    long_period_start: usize,
    long_period_end: usize,
    long_period_step: usize,
    short_period_start: usize,
    short_period_end: usize,
    short_period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to tsi_batch_into"));
    }

    let sweep = TsiBatchRange {
        long_period: (long_period_start, long_period_end, long_period_step),
        short_period: (short_period_start, short_period_end, short_period_step),
    };

    let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let cols = len;
    let total_size = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("tsi_batch_into: size overflow"))?;

    // Detect any overlap between the input and output byte ranges. A `&[f64]` and a
    // `&mut [f64]` over the same memory is undefined behaviour, and the kernel would read
    // already-overwritten inputs. Saturation only errs towards "overlapping" (an extra copy).
    let elem = std::mem::size_of::<f64>();
    let in_start = in_ptr as usize;
    let in_end = in_start.saturating_add(len.saturating_mul(elem));
    let out_start = out_ptr as usize;
    let out_end = out_start.saturating_add(total_size.saturating_mul(elem));
    let overlaps = in_start < out_end && out_start < in_end;

    // SAFETY: `in_ptr` is non-null and, per the caller contract, valid for `len` reads.
    let input: &[f64] = unsafe { std::slice::from_raw_parts(in_ptr, len) };
    // When the buffers overlap, snapshot the input before the output slice exists; the
    // borrowed `input` is not used again in that case.
    let owned_input: Option<Vec<f64>> = if overlaps { Some(input.to_vec()) } else { None };
    let data: &[f64] = match owned_input.as_deref() {
        Some(copy) => copy,
        None => input,
    };

    // SAFETY: `out_ptr` is non-null and, per the caller contract, valid for `total_size`
    // writes. `data` does not alias it: either the ranges are disjoint or `data` is an owned copy.
    let out_slice = unsafe { std::slice::from_raw_parts_mut(out_ptr, total_size) };

    tsi_batch_inner_into(data, &sweep, Kernel::Scalar, false, out_slice)
        .map(|_| rows)
        .map_err(|e| JsValue::from_str(&e.to_string()))
}