Skip to main content

vector_ta/indicators/
tsi.rs

1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::PyDict;
9#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
10use serde::{Deserialize, Serialize};
11#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
12use wasm_bindgen::prelude::*;
13
14#[cfg(all(feature = "python", feature = "cuda"))]
15use crate::cuda::moving_averages::DeviceArrayF32;
16use crate::indicators::ema::{EmaError, EmaParams, EmaStream};
17use crate::utilities::data_loader::{source_type, Candles};
18#[cfg(all(feature = "python", feature = "cuda"))]
19use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
20use crate::utilities::enums::Kernel;
21use crate::utilities::helpers::{alloc_with_nan_prefix, init_matrix_prefixes, make_uninit_matrix};
22#[cfg(feature = "python")]
23use crate::utilities::kernel_validation::validate_kernel;
24#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
25use core::arch::x86_64::*;
26#[cfg(all(feature = "python", feature = "cuda"))]
27use cust::context::Context;
28#[cfg(all(feature = "python", feature = "cuda"))]
29use cust::memory::DeviceBuffer;
30#[cfg(not(target_arch = "wasm32"))]
31use rayon::prelude::*;
32use std::convert::AsRef;
33#[cfg(all(feature = "python", feature = "cuda"))]
34use std::sync::Arc;
35use thiserror::Error;
36
37impl<'a> AsRef<[f64]> for TsiInput<'a> {
38    #[inline(always)]
39    fn as_ref(&self) -> &[f64] {
40        match &self.data {
41            TsiData::Slice(slice) => slice,
42            TsiData::Candles { candles, source } => source_type(candles, source),
43        }
44    }
45}
46
/// Where a TSI computation reads its input series from.
#[derive(Debug, Clone)]
pub enum TsiData<'a> {
    /// A candle set together with the source name handed to `source_type`
    /// to select one series from it.
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A caller-provided series used as-is.
    Slice(&'a [f64]),
}
55
/// Result of a single-parameter TSI computation.
#[derive(Debug, Clone)]
pub struct TsiOutput {
    /// Output series; `tsi_with_kernel` allocates it with the same length as
    /// the input and a NaN warm-up prefix.
    pub values: Vec<f64>,
}
60
/// Smoothing periods for the TSI.
///
/// A `None` field falls back to the default used throughout this file:
/// 25 for the long period and 13 for the short period.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct TsiParams {
    /// Period of the first (long) EMA pass.
    pub long_period: Option<usize>,
    /// Period of the second (short) EMA pass, applied to the long EMA.
    pub short_period: Option<usize>,
}
70
71impl Default for TsiParams {
72    fn default() -> Self {
73        Self {
74            long_period: Some(25),
75            short_period: Some(13),
76        }
77    }
78}
79
/// A TSI request: the data to read plus the smoothing parameters.
#[derive(Debug, Clone)]
pub struct TsiInput<'a> {
    /// Source of the input series (candles + source name, or a raw slice).
    pub data: TsiData<'a>,
    /// Long/short periods; unset fields fall back to 25 / 13 in the getters.
    pub params: TsiParams,
}
85
86impl<'a> TsiInput<'a> {
87    #[inline]
88    pub fn from_candles(c: &'a Candles, s: &'a str, p: TsiParams) -> Self {
89        Self {
90            data: TsiData::Candles {
91                candles: c,
92                source: s,
93            },
94            params: p,
95        }
96    }
97    #[inline]
98    pub fn from_slice(sl: &'a [f64], p: TsiParams) -> Self {
99        Self {
100            data: TsiData::Slice(sl),
101            params: p,
102        }
103    }
104    #[inline]
105    pub fn with_default_candles(c: &'a Candles) -> Self {
106        Self::from_candles(c, "close", TsiParams::default())
107    }
108    #[inline]
109    pub fn get_long_period(&self) -> usize {
110        self.params.long_period.unwrap_or(25)
111    }
112    #[inline]
113    pub fn get_short_period(&self) -> usize {
114        self.params.short_period.unwrap_or(13)
115    }
116}
117
/// Fluent builder for a single TSI computation or stream.
#[derive(Copy, Clone, Debug)]
pub struct TsiBuilder {
    /// Long period override; `None` leaves the default to the computation.
    long_period: Option<usize>,
    /// Short period override; `None` leaves the default to the computation.
    short_period: Option<usize>,
    /// Kernel forwarded to `tsi_with_kernel`.
    kernel: Kernel,
}
124
125impl Default for TsiBuilder {
126    fn default() -> Self {
127        Self {
128            long_period: None,
129            short_period: None,
130            kernel: Kernel::Auto,
131        }
132    }
133}
134
135impl TsiBuilder {
136    #[inline(always)]
137    pub fn new() -> Self {
138        Self::default()
139    }
140    #[inline(always)]
141    pub fn long_period(mut self, n: usize) -> Self {
142        self.long_period = Some(n);
143        self
144    }
145    #[inline(always)]
146    pub fn short_period(mut self, n: usize) -> Self {
147        self.short_period = Some(n);
148        self
149    }
150    #[inline(always)]
151    pub fn kernel(mut self, k: Kernel) -> Self {
152        self.kernel = k;
153        self
154    }
155
156    #[inline(always)]
157    pub fn apply(self, c: &Candles) -> Result<TsiOutput, TsiError> {
158        let p = TsiParams {
159            long_period: self.long_period,
160            short_period: self.short_period,
161        };
162        let i = TsiInput::from_candles(c, "close", p);
163        tsi_with_kernel(&i, self.kernel)
164    }
165
166    #[inline(always)]
167    pub fn apply_slice(self, d: &[f64]) -> Result<TsiOutput, TsiError> {
168        let p = TsiParams {
169            long_period: self.long_period,
170            short_period: self.short_period,
171        };
172        let i = TsiInput::from_slice(d, p);
173        tsi_with_kernel(&i, self.kernel)
174    }
175
176    #[inline(always)]
177    pub fn into_stream(self) -> Result<TsiStream, TsiError> {
178        let p = TsiParams {
179            long_period: self.long_period,
180            short_period: self.short_period,
181        };
182        TsiStream::try_new(p)
183    }
184}
185
/// Errors reported by the TSI functions in this file.
#[derive(Debug, Error)]
pub enum TsiError {
    /// The input series has length zero.
    #[error("tsi: Input data slice is empty.")]
    EmptyInputData,
    /// No non-NaN sample was found in the input.
    #[error("tsi: All values are NaN.")]
    AllValuesNaN,
    /// A period is zero or (in `tsi_prepare`) longer than the input.
    #[error("tsi: Invalid period: long = {long_period}, short = {short_period}, data length = {data_len}")]
    InvalidPeriod {
        long_period: usize,
        short_period: usize,
        data_len: usize,
    },
    /// Fewer samples remain after the first non-NaN value than
    /// `1 + long + short`.
    #[error("tsi: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// A kernel that is neither `Auto` nor a batch kernel was passed to
    /// `tsi_batch_with_kernel`.
    #[error("tsi: Non-batch kernel passed to batch path: {0:?}")]
    InvalidKernelForBatch(Kernel),
    /// The output buffer length differs from the expected length.
    #[error("tsi: Output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// A `(start, end, step)` sweep could not be expanded.
    #[error("tsi: Invalid range expansion: start = {start}, end = {end}, step = {step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },
    /// `rows * cols` (or the combination count) overflowed `usize`.
    #[error("tsi: size overflow computing rows*cols")]
    SizeOverflow,
    /// Error propagated from the EMA indicator.
    #[error("tsi: EMA sub-error: {0}")]
    EmaSubError(#[from] EmaError),
}
215
216#[inline(always)]
217fn tsi_prepare<'a>(input: &'a TsiInput) -> Result<(&'a [f64], usize, usize, usize), TsiError> {
218    let data: &[f64] = input.as_ref();
219    let len = data.len();
220    if len == 0 {
221        return Err(TsiError::EmptyInputData);
222    }
223    let first = data
224        .iter()
225        .position(|x| !x.is_nan())
226        .ok_or(TsiError::AllValuesNaN)?;
227    let long = input.get_long_period();
228    let short = input.get_short_period();
229
230    if long == 0 || short == 0 || long > len || short > len {
231        return Err(TsiError::InvalidPeriod {
232            long_period: long,
233            short_period: short,
234            data_len: len,
235        });
236    }
237    let needed = 1 + long + short;
238    if len - first < needed {
239        return Err(TsiError::NotEnoughValidData {
240            needed,
241            valid: len - first,
242        });
243    }
244
245    Ok((data, long, short, first))
246}
247
248#[inline(always)]
249fn tsi_compute_into_streaming(
250    data: &[f64],
251    long: usize,
252    short: usize,
253    first: usize,
254    out: &mut [f64],
255) -> Result<(), TsiError> {
256    let mut ema_long_num = EmaStream::try_new(EmaParams { period: Some(long) })?;
257    let mut ema_short_num = EmaStream::try_new(EmaParams {
258        period: Some(short),
259    })?;
260    let mut ema_long_den = EmaStream::try_new(EmaParams { period: Some(long) })?;
261    let mut ema_short_den = EmaStream::try_new(EmaParams {
262        period: Some(short),
263    })?;
264
265    let warmup_end = first + long + short;
266    if out.len() != data.len() {
267        return Err(TsiError::OutputLengthMismatch {
268            expected: data.len(),
269            got: out.len(),
270        });
271    }
272
273    if first + 1 >= data.len() {
274        return Ok(());
275    }
276    let mut prev = data[first];
277
278    for i in (first + 1)..data.len() {
279        let cur = data[i];
280        if !cur.is_finite() {
281            out[i] = f64::NAN;
282            continue;
283        }
284
285        let m = cur - prev;
286        prev = cur;
287
288        let n1 = match ema_long_num.update(m) {
289            Some(v) => v,
290            None => {
291                out[i] = f64::NAN;
292                continue;
293            }
294        };
295        let n2 = match ema_short_num.update(n1) {
296            Some(v) => v,
297            None => {
298                out[i] = f64::NAN;
299                continue;
300            }
301        };
302
303        let d1 = match ema_long_den.update(m.abs()) {
304            Some(v) => v,
305            None => {
306                out[i] = f64::NAN;
307                continue;
308            }
309        };
310        let d2 = match ema_short_den.update(d1) {
311            Some(v) => v,
312            None => {
313                out[i] = f64::NAN;
314                continue;
315            }
316        };
317
318        if i >= warmup_end {
319            out[i] = if d2 == 0.0 {
320                f64::NAN
321            } else {
322                (100.0 * (n2 / d2)).clamp(-100.0, 100.0)
323            };
324        }
325    }
326
327    Ok(())
328}
329
330#[inline(always)]
331fn tsi_compute_into_inline(
332    data: &[f64],
333    long: usize,
334    short: usize,
335    first: usize,
336    out: &mut [f64],
337) -> Result<(), TsiError> {
338    let n = data.len();
339    if n == 0 || first >= n {
340        return Ok(());
341    }
342
343    let long_alpha = 2.0 / (long as f64 + 1.0);
344    let short_alpha = 2.0 / (short as f64 + 1.0);
345    let long_1minus = 1.0 - long_alpha;
346    let short_1minus = 1.0 - short_alpha;
347
348    let warmup_end = first + long + short;
349
350    let mut prev = data[first];
351
352    let mut i = first + 1;
353    while i < n {
354        let cur = data[i];
355        if cur.is_finite() {
356            break;
357        } else {
358            if i >= warmup_end {
359                out[i] = f64::NAN;
360            }
361            i += 1;
362        }
363    }
364    if i >= n {
365        return Ok(());
366    }
367
368    let mut momentum = data[i] - prev;
369    prev = data[i];
370
371    let mut ema_long_num = momentum;
372    let mut ema_short_num = momentum;
373    let mut ema_long_den = momentum.abs();
374    let mut ema_short_den = ema_long_den;
375
376    let mut idx = i + 1;
377    let end_warm = warmup_end.min(n);
378    while idx < end_warm {
379        let cur = data[idx];
380        if cur.is_finite() {
381            momentum = cur - prev;
382            prev = cur;
383
384            let am = momentum.abs();
385
386            ema_long_num = long_alpha * momentum + long_1minus * ema_long_num;
387            ema_short_num = short_alpha * ema_long_num + short_1minus * ema_short_num;
388
389            ema_long_den = long_alpha * am + long_1minus * ema_long_den;
390            ema_short_den = short_alpha * ema_long_den + short_1minus * ema_short_den;
391        }
392
393        idx += 1;
394    }
395
396    while idx < n {
397        let cur = data[idx];
398        if cur.is_finite() {
399            momentum = cur - prev;
400            prev = cur;
401
402            let am = momentum.abs();
403
404            ema_long_num = long_alpha * momentum + long_1minus * ema_long_num;
405            ema_short_num = short_alpha * ema_long_num + short_1minus * ema_short_num;
406
407            ema_long_den = long_alpha * am + long_1minus * ema_long_den;
408            ema_short_den = short_alpha * ema_long_den + short_1minus * ema_short_den;
409
410            let den = ema_short_den;
411            let val = if den == 0.0 {
412                f64::NAN
413            } else {
414                (100.0 * (ema_short_num / den)).clamp(-100.0, 100.0)
415            };
416            out[idx] = val;
417        } else {
418            out[idx] = f64::NAN;
419        }
420        idx += 1;
421    }
422
423    Ok(())
424}
425
426#[inline]
427pub fn tsi(input: &TsiInput) -> Result<TsiOutput, TsiError> {
428    tsi_with_kernel(input, Kernel::Auto)
429}
430
431pub fn tsi_with_kernel(input: &TsiInput, kernel: Kernel) -> Result<TsiOutput, TsiError> {
432    let (data, long, short, first) = tsi_prepare(input)?;
433    let warmup_end = first + long + short;
434    let mut out = alloc_with_nan_prefix(data.len(), warmup_end);
435
436    let resolved_kernel = match kernel {
437        Kernel::Auto => Kernel::Scalar,
438        k => k,
439    };
440
441    if resolved_kernel == Kernel::Scalar && long == 25 && short == 13 {
442        unsafe {
443            tsi_scalar_classic(data, long, short, first, &mut out)?;
444        }
445    } else {
446        tsi_compute_into_inline(data, long, short, first, &mut out)?;
447    }
448    Ok(TsiOutput { values: out })
449}
450
451#[inline]
452pub fn tsi_into_slice(dst: &mut [f64], input: &TsiInput, kern: Kernel) -> Result<(), TsiError> {
453    let (data, long, short, first) = tsi_prepare(input)?;
454    let warmup_end = first + long + short;
455
456    let end = warmup_end.min(dst.len());
457    for v in &mut dst[..end] {
458        *v = f64::NAN;
459    }
460
461    let resolved_kernel = match kern {
462        Kernel::Auto => Kernel::Scalar,
463        k => k,
464    };
465
466    if resolved_kernel == Kernel::Scalar && long == 25 && short == 13 {
467        unsafe {
468            tsi_scalar_classic(data, long, short, first, dst)?;
469        }
470    } else {
471        tsi_compute_into_inline(data, long, short, first, dst)?;
472    }
473    Ok(())
474}
475
476#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
477#[inline]
478pub fn tsi_into(input: &TsiInput, out: &mut [f64]) -> Result<(), TsiError> {
479    let data_len = input.as_ref().len();
480    if out.len() != data_len {
481        return Err(TsiError::OutputLengthMismatch {
482            expected: data_len,
483            got: out.len(),
484        });
485    }
486    tsi_into_slice(out, input, Kernel::Auto)
487}
488
489#[inline]
490pub unsafe fn tsi_scalar(
491    data: &[f64],
492    long: usize,
493    short: usize,
494    first: usize,
495) -> Result<TsiOutput, TsiError> {
496    let warmup_end = first + long + short;
497    let mut out = alloc_with_nan_prefix(data.len(), warmup_end);
498
499    if long == 25 && short == 13 {
500        tsi_scalar_classic(data, long, short, first, &mut out)?;
501    } else {
502        tsi_compute_into_inline(data, long, short, first, &mut out)?;
503    }
504    Ok(TsiOutput { values: out })
505}
506
507#[inline]
508pub unsafe fn tsi_scalar_classic(
509    data: &[f64],
510    long: usize,
511    short: usize,
512    first: usize,
513    out: &mut [f64],
514) -> Result<(), TsiError> {
515    let n = data.len();
516    let warmup_end = first + long + short;
517
518    if first + 1 >= n {
519        return Ok(());
520    }
521
522    let long_alpha = 2.0 / (long as f64 + 1.0);
523    let short_alpha = 2.0 / (short as f64 + 1.0);
524    let long_1minus = 1.0 - long_alpha;
525    let short_1minus = 1.0 - short_alpha;
526
527    let mut prev = data[first];
528
529    if first + 1 >= n || !data[first + 1].is_finite() {
530        return Ok(());
531    }
532
533    let first_momentum = data[first + 1] - prev;
534    prev = data[first + 1];
535
536    let mut ema_long_num = first_momentum;
537    let mut ema_short_num = first_momentum;
538    let mut ema_long_den = first_momentum.abs();
539    let mut ema_short_den = first_momentum.abs();
540
541    for i in (first + 2)..n {
542        let cur = data[i];
543        if !cur.is_finite() {
544            out[i] = f64::NAN;
545            continue;
546        }
547
548        let momentum = cur - prev;
549        prev = cur;
550
551        ema_long_num = long_alpha * momentum + long_1minus * ema_long_num;
552
553        ema_short_num = short_alpha * ema_long_num + short_1minus * ema_short_num;
554
555        ema_long_den = long_alpha * momentum.abs() + long_1minus * ema_long_den;
556
557        ema_short_den = short_alpha * ema_long_den + short_1minus * ema_short_den;
558
559        if i >= warmup_end {
560            out[i] = if ema_short_den == 0.0 {
561                f64::NAN
562            } else {
563                (100.0 * (ema_short_num / ema_short_den)).clamp(-100.0, 100.0)
564            };
565        }
566    }
567
568    Ok(())
569}
570
/// AVX2 entry point; currently forwards unchanged to `tsi_scalar`.
///
/// # Safety
/// Same contract as `tsi_scalar`, which this function calls directly.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn tsi_avx2(
    data: &[f64],
    long: usize,
    short: usize,
    first: usize,
) -> Result<TsiOutput, TsiError> {
    tsi_scalar(data, long, short, first)
}
/// AVX-512 entry point: routes to `tsi_avx512_short` when both periods are
/// at most 32 and to `tsi_avx512_long` otherwise.
///
/// # Safety
/// Same contract as the two functions it dispatches to.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn tsi_avx512(
    data: &[f64],
    long: usize,
    short: usize,
    first: usize,
) -> Result<TsiOutput, TsiError> {
    let both_small = long.max(short) <= 32;
    match both_small {
        true => tsi_avx512_short(data, long, short, first),
        false => tsi_avx512_long(data, long, short, first),
    }
}
/// AVX-512 variant selected for periods up to 32; currently forwards
/// unchanged to `tsi_scalar`.
///
/// # Safety
/// Same contract as `tsi_scalar`, which this function calls directly.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn tsi_avx512_short(
    data: &[f64],
    long: usize,
    short: usize,
    first: usize,
) -> Result<TsiOutput, TsiError> {
    tsi_scalar(data, long, short, first)
}
/// AVX-512 variant selected when a period exceeds 32; currently forwards
/// unchanged to `tsi_scalar`.
///
/// # Safety
/// Same contract as `tsi_scalar`, which this function calls directly.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn tsi_avx512_long(
    data: &[f64],
    long: usize,
    short: usize,
    first: usize,
) -> Result<TsiOutput, TsiError> {
    tsi_scalar(data, long, short, first)
}
615
/// Incremental TSI calculator fed one sample at a time through `update`.
#[derive(Debug, Clone)]
pub struct TsiStream {
    /// Long period as resolved at construction.
    long: usize,
    /// Short period as resolved at construction.
    short: usize,

    /// Smoothing factor `2 / (long + 1)`.
    alpha_l: f64,
    /// Smoothing factor `2 / (short + 1)`.
    alpha_s: f64,

    /// Long EMA of the momentum.
    ema_long_num: f64,
    /// Short EMA of `ema_long_num`.
    ema_short_num: f64,
    /// Long EMA of the absolute momentum.
    ema_long_den: f64,
    /// Short EMA of `ema_long_den`.
    ema_short_den: f64,

    /// Last finite sample seen; NaN until the first one arrives.
    prev_price: f64,
    /// Whether `prev_price` holds a real sample.
    have_prev: bool,

    /// Whether the EMAs have been initialised from the first momentum.
    seeded: bool,

    /// Number of momentum values consumed so far (the seed counts as 1).
    warmup_ctr: usize,
    /// Momentum count required before `update` starts returning values
    /// (`long + short`).
    warmup_needed: usize,
}
637impl TsiStream {
638    #[inline]
639    pub fn try_new(params: TsiParams) -> Result<Self, TsiError> {
640        let long = params.long_period.unwrap_or(25);
641        let short = params.short_period.unwrap_or(13);
642
643        if long == 0 || short == 0 {
644            return Err(TsiError::InvalidPeriod {
645                long_period: long,
646                short_period: short,
647                data_len: 0,
648            });
649        }
650
651        let alpha_l = 2.0 / (long as f64 + 1.0);
652        let alpha_s = 2.0 / (short as f64 + 1.0);
653
654        Ok(Self {
655            long,
656            short,
657            alpha_l,
658            alpha_s,
659            ema_long_num: 0.0,
660            ema_short_num: 0.0,
661            ema_long_den: 0.0,
662            ema_short_den: 0.0,
663            prev_price: f64::NAN,
664            have_prev: false,
665            seeded: false,
666            warmup_ctr: 0,
667            warmup_needed: long + short,
668        })
669    }
670
671    #[inline(always)]
672    pub fn update(&mut self, value: f64) -> Option<f64> {
673        if !value.is_finite() {
674            return None;
675        }
676
677        if !self.have_prev {
678            self.prev_price = value;
679            self.have_prev = true;
680            return None;
681        }
682
683        let m = value - self.prev_price;
684        self.prev_price = value;
685        let am = m.abs();
686
687        if !self.seeded {
688            self.ema_long_num = m;
689            self.ema_short_num = m;
690            self.ema_long_den = am;
691            self.ema_short_den = am;
692            self.seeded = true;
693            self.warmup_ctr = 1;
694            return None;
695        }
696
697        self.ema_long_num += self.alpha_l * (m - self.ema_long_num);
698        self.ema_short_num += self.alpha_s * (self.ema_long_num - self.ema_short_num);
699
700        self.ema_long_den += self.alpha_l * (am - self.ema_long_den);
701        self.ema_short_den += self.alpha_s * (self.ema_long_den - self.ema_short_den);
702
703        self.warmup_ctr += 1;
704
705        if self.warmup_ctr < self.warmup_needed {
706            return None;
707        }
708
709        let den = self.ema_short_den;
710        if den == 0.0 {
711            return Some(f64::NAN);
712        }
713
714        let tsi = 100.0 * (self.ema_short_num / den);
715        Some(tsi.clamp(-100.0, 100.0))
716    }
717}
718
/// Parameter sweep for a batch run; each axis is `(start, end, step)`.
///
/// `expand_grid` treats `step == 0` or `start == end` as the single value
/// `start`.
#[derive(Clone, Debug)]
pub struct TsiBatchRange {
    /// Sweep of the long period.
    pub long_period: (usize, usize, usize),
    /// Sweep of the short period.
    pub short_period: (usize, usize, usize),
}
724impl Default for TsiBatchRange {
725    fn default() -> Self {
726        Self {
727            long_period: (25, 274, 1),
728            short_period: (13, 13, 0),
729        }
730    }
731}
732
/// Fluent builder for a batch (parameter-sweep) TSI run.
#[derive(Clone, Debug, Default)]
pub struct TsiBatchBuilder {
    /// Sweep definition passed to `tsi_batch_with_kernel`.
    range: TsiBatchRange,
    /// Kernel passed to `tsi_batch_with_kernel`.
    kernel: Kernel,
}
738impl TsiBatchBuilder {
739    pub fn new() -> Self {
740        Self::default()
741    }
742    pub fn kernel(mut self, k: Kernel) -> Self {
743        self.kernel = k;
744        self
745    }
746    pub fn long_range(mut self, start: usize, end: usize, step: usize) -> Self {
747        self.range.long_period = (start, end, step);
748        self
749    }
750    pub fn long_static(mut self, n: usize) -> Self {
751        self.range.long_period = (n, n, 1);
752        self
753    }
754    pub fn short_range(mut self, start: usize, end: usize, step: usize) -> Self {
755        self.range.short_period = (start, end, step);
756        self
757    }
758    pub fn short_static(mut self, n: usize) -> Self {
759        self.range.short_period = (n, n, 1);
760        self
761    }
762    pub fn apply_slice(self, data: &[f64]) -> Result<TsiBatchOutput, TsiError> {
763        tsi_batch_with_kernel(data, &self.range, self.kernel)
764    }
765    pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<TsiBatchOutput, TsiError> {
766        TsiBatchBuilder::new().kernel(k).apply_slice(data)
767    }
768    pub fn apply_candles(self, c: &Candles, src: &str) -> Result<TsiBatchOutput, TsiError> {
769        let slice = source_type(c, src);
770        self.apply_slice(slice)
771    }
772    pub fn with_default_candles(c: &Candles) -> Result<TsiBatchOutput, TsiError> {
773        TsiBatchBuilder::new()
774            .kernel(Kernel::Auto)
775            .apply_candles(c, "close")
776    }
777}
778
779pub fn tsi_batch_with_kernel(
780    data: &[f64],
781    sweep: &TsiBatchRange,
782    k: Kernel,
783) -> Result<TsiBatchOutput, TsiError> {
784    let kernel = match k {
785        Kernel::Auto => Kernel::ScalarBatch,
786        other if other.is_batch() => other,
787        _ => {
788            return Err(TsiError::InvalidKernelForBatch(k));
789        }
790    };
791
792    let simd = match kernel {
793        Kernel::Avx512Batch => Kernel::Avx512,
794        Kernel::Avx2Batch => Kernel::Avx2,
795        Kernel::ScalarBatch => Kernel::Scalar,
796        _ => unreachable!(),
797    };
798    tsi_batch_par_slice(data, sweep, simd)
799}
800
/// Result of a batch run: one output row per parameter combination.
#[derive(Clone, Debug)]
pub struct TsiBatchOutput {
    /// Row-major matrix of `rows * cols` values; row `r` occupies
    /// `values[r * cols .. (r + 1) * cols]`.
    pub values: Vec<f64>,
    /// Parameter combination for each row, in row order.
    pub combos: Vec<TsiParams>,
    /// Number of parameter combinations.
    pub rows: usize,
    /// Length of the input series.
    pub cols: usize,
}
808impl TsiBatchOutput {
809    pub fn row_for_params(&self, p: &TsiParams) -> Option<usize> {
810        self.combos.iter().position(|c| {
811            c.long_period.unwrap_or(25) == p.long_period.unwrap_or(25)
812                && c.short_period.unwrap_or(13) == p.short_period.unwrap_or(13)
813        })
814    }
815    pub fn values_for(&self, p: &TsiParams) -> Option<&[f64]> {
816        self.row_for_params(p).map(|row| {
817            let start = row * self.cols;
818            &self.values[start..start + self.cols]
819        })
820    }
821}
822
823#[inline(always)]
824fn expand_grid(r: &TsiBatchRange) -> Result<Vec<TsiParams>, TsiError> {
825    fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, TsiError> {
826        if step == 0 || start == end {
827            return Ok(vec![start]);
828        }
829        if start < end {
830            let vals: Vec<usize> = (start..=end).step_by(step).collect();
831            if vals.is_empty() {
832                return Err(TsiError::InvalidRange { start, end, step });
833            }
834            return Ok(vals);
835        }
836
837        let mut v = start;
838        let mut out = Vec::new();
839        loop {
840            out.push(v);
841            let guard = end.saturating_add(step);
842            if v <= guard {
843                break;
844            }
845            v = v.saturating_sub(step);
846            if v == 0 && step == 0 {
847                break;
848            }
849        }
850        if out.is_empty() {
851            return Err(TsiError::InvalidRange { start, end, step });
852        }
853        if *out.last().unwrap() != end {
854            out.push(end);
855        }
856        Ok(out)
857    }
858
859    let longs = axis_usize(r.long_period)?;
860    let shorts = axis_usize(r.short_period)?;
861    if longs.is_empty() || shorts.is_empty() {
862        return Err(TsiError::InvalidRange {
863            start: r.long_period.0,
864            end: r.long_period.1,
865            step: r.long_period.2,
866        });
867    }
868
869    let total = longs
870        .len()
871        .checked_mul(shorts.len())
872        .ok_or(TsiError::SizeOverflow)?;
873
874    let mut out = Vec::with_capacity(total);
875    for &l in &longs {
876        for &s in &shorts {
877            out.push(TsiParams {
878                long_period: Some(l),
879                short_period: Some(s),
880            });
881        }
882    }
883    Ok(out)
884}
885
886#[inline(always)]
887pub fn tsi_batch_slice(
888    data: &[f64],
889    sweep: &TsiBatchRange,
890    kern: Kernel,
891) -> Result<TsiBatchOutput, TsiError> {
892    tsi_batch_inner(data, sweep, kern, false)
893}
894
895#[inline(always)]
896pub fn tsi_batch_par_slice(
897    data: &[f64],
898    sweep: &TsiBatchRange,
899    kern: Kernel,
900) -> Result<TsiBatchOutput, TsiError> {
901    tsi_batch_inner(data, sweep, kern, true)
902}
903
904fn tsi_batch_inner(
905    data: &[f64],
906    sweep: &TsiBatchRange,
907    kern: Kernel,
908    parallel: bool,
909) -> Result<TsiBatchOutput, TsiError> {
910    let combos = expand_grid(sweep)?;
911
912    let first = data
913        .iter()
914        .position(|x| !x.is_nan())
915        .ok_or(TsiError::AllValuesNaN)?;
916    let max_long = combos.iter().map(|c| c.long_period.unwrap()).max().unwrap();
917    let max_short = combos
918        .iter()
919        .map(|c| c.short_period.unwrap())
920        .max()
921        .unwrap();
922    let max_needed = 1 + max_long + max_short;
923    if data.len() - first < max_needed {
924        return Err(TsiError::NotEnoughValidData {
925            needed: max_needed,
926            valid: data.len() - first,
927        });
928    }
929
930    let rows = combos.len();
931    let cols = data.len();
932    if rows.checked_mul(cols).is_none() {
933        return Err(TsiError::SizeOverflow);
934    }
935
936    let mut buf_mu = make_uninit_matrix(rows, cols);
937
938    let warmup_periods: Vec<usize> = combos
939        .iter()
940        .map(|c| first + 1 + c.long_period.unwrap() + c.short_period.unwrap() - 1)
941        .collect();
942    init_matrix_prefixes(&mut buf_mu, cols, &warmup_periods);
943
944    let mut buf_guard = core::mem::ManuallyDrop::new(buf_mu);
945    let values: &mut [f64] = unsafe {
946        core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
947    };
948
949    let do_row = |row: usize, out_row: &mut [f64]| {
950        let p = &combos[row];
951        let l = p.long_period.unwrap();
952        let s = p.short_period.unwrap();
953        match kern {
954            Kernel::Scalar => unsafe { tsi_row_scalar_into(data, l, s, first, out_row) },
955            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
956            Kernel::Avx2 => unsafe { tsi_row_avx2_into(data, l, s, first, out_row) },
957            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
958            Kernel::Avx512 => unsafe { tsi_row_avx512_into(data, l, s, first, out_row) },
959            _ => unreachable!(),
960        }
961        .unwrap();
962    };
963
964    if parallel {
965        #[cfg(not(target_arch = "wasm32"))]
966        {
967            values
968                .par_chunks_mut(cols)
969                .enumerate()
970                .for_each(|(row, slice)| do_row(row, slice));
971        }
972
973        #[cfg(target_arch = "wasm32")]
974        {
975            for (row, slice) in values.chunks_mut(cols).enumerate() {
976                do_row(row, slice);
977            }
978        }
979    } else {
980        for (row, slice) in values.chunks_mut(cols).enumerate() {
981            do_row(row, slice);
982        }
983    }
984
985    let values = unsafe {
986        Vec::from_raw_parts(
987            buf_guard.as_mut_ptr() as *mut f64,
988            buf_guard.len(),
989            buf_guard.capacity(),
990        )
991    };
992
993    Ok(TsiBatchOutput {
994        values,
995        combos,
996        rows,
997        cols,
998    })
999}
1000
1001#[inline(always)]
1002fn tsi_batch_inner_into(
1003    data: &[f64],
1004    sweep: &TsiBatchRange,
1005    kern: Kernel,
1006    parallel: bool,
1007    out: &mut [f64],
1008) -> Result<Vec<TsiParams>, TsiError> {
1009    let combos = expand_grid(sweep)?;
1010
1011    let first = data
1012        .iter()
1013        .position(|x| !x.is_nan())
1014        .ok_or(TsiError::AllValuesNaN)?;
1015    let max_long = combos.iter().map(|c| c.long_period.unwrap()).max().unwrap();
1016    let max_short = combos
1017        .iter()
1018        .map(|c| c.short_period.unwrap())
1019        .max()
1020        .unwrap();
1021    let max_needed = 1 + max_long + max_short;
1022    if data.len() - first < max_needed {
1023        return Err(TsiError::NotEnoughValidData {
1024            needed: max_needed,
1025            valid: data.len() - first,
1026        });
1027    }
1028
1029    let rows = combos.len();
1030    let cols = data.len();
1031    let expected = rows.checked_mul(cols).ok_or(TsiError::SizeOverflow)?;
1032    if out.len() != expected {
1033        return Err(TsiError::OutputLengthMismatch {
1034            expected,
1035            got: out.len(),
1036        });
1037    }
1038
1039    let do_row = |row: usize, out_row: &mut [f64]| {
1040        let p = &combos[row];
1041        let l = p.long_period.unwrap();
1042        let s = p.short_period.unwrap();
1043        match kern {
1044            Kernel::Scalar => unsafe { tsi_row_scalar_into(data, l, s, first, out_row) },
1045            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1046            Kernel::Avx2 => unsafe { tsi_row_avx2_into(data, l, s, first, out_row) },
1047            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1048            Kernel::Avx512 => unsafe { tsi_row_avx512_into(data, l, s, first, out_row) },
1049            _ => unreachable!(),
1050        }
1051        .unwrap();
1052    };
1053
1054    if parallel {
1055        #[cfg(not(target_arch = "wasm32"))]
1056        {
1057            out.par_chunks_mut(cols)
1058                .enumerate()
1059                .for_each(|(row, slice)| do_row(row, slice));
1060        }
1061        #[cfg(target_arch = "wasm32")]
1062        {
1063            for (row, slice) in out.chunks_mut(cols).enumerate() {
1064                do_row(row, slice);
1065            }
1066        }
1067    } else {
1068        for (row, slice) in out.chunks_mut(cols).enumerate() {
1069            do_row(row, slice);
1070        }
1071    }
1072
1073    Ok(combos)
1074}
1075
/// Row kernel selected for `Kernel::Scalar` when filling one batch output row.
///
/// Forwards `data`, `long`, `short`, `first` and `out_row` unchanged to
/// `tsi_compute_into_inline` and returns whatever that call returns.
///
/// # Safety
///
/// NOTE(review): the body is a single forwarding call, so the real safety
/// contract is not visible here. Presumably `first` must be a valid index
/// into `data` and `out_row` must have the same length as `data` — confirm
/// against `tsi_compute_into_inline`.
#[inline(always)]
pub unsafe fn tsi_row_scalar_into(
    data: &[f64],
    long: usize,
    short: usize,
    first: usize,
    out_row: &mut [f64],
) -> Result<(), TsiError> {
    tsi_compute_into_inline(data, long, short, first, out_row)
}
1086
/// Row kernel selected for `Kernel::Avx2` when filling one batch output row.
///
/// Only compiled with the `nightly-avx` feature on `x86_64`. The body uses no
/// AVX2 intrinsics itself: it forwards every argument unchanged to
/// `tsi_compute_into_streaming` and returns that call's result.
///
/// # Safety
///
/// NOTE(review): the safety contract is not visible in this wrapper.
/// Presumably `first` must be a valid index into `data` and `out_row` must
/// have the same length as `data` — confirm against
/// `tsi_compute_into_streaming`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn tsi_row_avx2_into(
    data: &[f64],
    long: usize,
    short: usize,
    first: usize,
    out_row: &mut [f64],
) -> Result<(), TsiError> {
    tsi_compute_into_streaming(data, long, short, first, out_row)
}
1098
/// Row kernel selected for `Kernel::Avx512` when filling one batch output row.
///
/// Only compiled with the `nightly-avx` feature on `x86_64`. The body uses no
/// AVX-512 intrinsics itself: it forwards every argument unchanged to
/// `tsi_compute_into_streaming` and returns that call's result.
///
/// # Safety
///
/// NOTE(review): the safety contract is not visible in this wrapper.
/// Presumably `first` must be a valid index into `data` and `out_row` must
/// have the same length as `data` — confirm against
/// `tsi_compute_into_streaming`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn tsi_row_avx512_into(
    data: &[f64],
    long: usize,
    short: usize,
    first: usize,
    out_row: &mut [f64],
) -> Result<(), TsiError> {
    tsi_compute_into_streaming(data, long, short, first, out_row)
}
1110
1111#[cfg(test)]
1112mod tests {
1113    use super::*;
1114    use crate::skip_if_unsupported;
1115    use crate::utilities::data_loader::read_candles_from_csv;
1116    use paste::paste;
1117
1118    fn check_tsi_partial_params(
1119        test_name: &str,
1120        kernel: Kernel,
1121    ) -> Result<(), Box<dyn std::error::Error>> {
1122        skip_if_unsupported!(kernel, test_name);
1123        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1124        let candles = read_candles_from_csv(file_path)?;
1125
1126        let default_params = TsiParams {
1127            long_period: None,
1128            short_period: None,
1129        };
1130        let input = TsiInput::from_candles(&candles, "close", default_params);
1131        let output = tsi_with_kernel(&input, kernel)?;
1132        assert_eq!(output.values.len(), candles.close.len());
1133        Ok(())
1134    }
1135
1136    fn check_tsi_accuracy(
1137        test_name: &str,
1138        kernel: Kernel,
1139    ) -> Result<(), Box<dyn std::error::Error>> {
1140        skip_if_unsupported!(kernel, test_name);
1141        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1142        let candles = read_candles_from_csv(file_path)?;
1143
1144        let params = TsiParams {
1145            long_period: Some(25),
1146            short_period: Some(13),
1147        };
1148        let input = TsiInput::from_candles(&candles, "close", params);
1149        let tsi_result = tsi_with_kernel(&input, kernel)?;
1150
1151        let expected_last_five = [
1152            -17.757654061849838,
1153            -17.367527062626184,
1154            -17.305577681249513,
1155            -16.937565646991143,
1156            -17.61825617316731,
1157        ];
1158        let start = tsi_result.values.len().saturating_sub(5);
1159        for (i, &val) in tsi_result.values[start..].iter().enumerate() {
1160            let diff = (val - expected_last_five[i]).abs();
1161            assert!(
1162                diff < 1e-7,
1163                "[{}] TSI {:?} mismatch at idx {}: got {}, expected {}",
1164                test_name,
1165                kernel,
1166                i,
1167                val,
1168                expected_last_five[i]
1169            );
1170        }
1171        Ok(())
1172    }
1173
1174    fn check_tsi_zero_period(
1175        test_name: &str,
1176        kernel: Kernel,
1177    ) -> Result<(), Box<dyn std::error::Error>> {
1178        skip_if_unsupported!(kernel, test_name);
1179        let input_data = [10.0, 20.0, 30.0];
1180        let params = TsiParams {
1181            long_period: Some(0),
1182            short_period: Some(13),
1183        };
1184        let input = TsiInput::from_slice(&input_data, params);
1185        let res = tsi_with_kernel(&input, kernel);
1186        assert!(
1187            res.is_err(),
1188            "[{}] TSI should fail with zero period",
1189            test_name
1190        );
1191        Ok(())
1192    }
1193
1194    fn check_tsi_period_exceeds_length(
1195        test_name: &str,
1196        kernel: Kernel,
1197    ) -> Result<(), Box<dyn std::error::Error>> {
1198        skip_if_unsupported!(kernel, test_name);
1199        let data_small = [10.0, 20.0, 30.0];
1200        let params = TsiParams {
1201            long_period: Some(25),
1202            short_period: Some(13),
1203        };
1204        let input = TsiInput::from_slice(&data_small, params);
1205        let res = tsi_with_kernel(&input, kernel);
1206        assert!(
1207            res.is_err(),
1208            "[{}] TSI should fail with period exceeding length",
1209            test_name
1210        );
1211        Ok(())
1212    }
1213
1214    fn check_tsi_very_small_dataset(
1215        test_name: &str,
1216        kernel: Kernel,
1217    ) -> Result<(), Box<dyn std::error::Error>> {
1218        skip_if_unsupported!(kernel, test_name);
1219        let single_point = [42.0];
1220        let params = TsiParams {
1221            long_period: Some(25),
1222            short_period: Some(13),
1223        };
1224        let input = TsiInput::from_slice(&single_point, params);
1225        let res = tsi_with_kernel(&input, kernel);
1226        assert!(
1227            res.is_err(),
1228            "[{}] TSI should fail with insufficient data",
1229            test_name
1230        );
1231        Ok(())
1232    }
1233
1234    fn check_tsi_default_candles(
1235        test_name: &str,
1236        kernel: Kernel,
1237    ) -> Result<(), Box<dyn std::error::Error>> {
1238        skip_if_unsupported!(kernel, test_name);
1239        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1240        let candles = read_candles_from_csv(file_path)?;
1241
1242        let input = TsiInput::with_default_candles(&candles);
1243        match input.data {
1244            TsiData::Candles { source, .. } => assert_eq!(source, "close"),
1245            _ => panic!("Expected TsiData::Candles"),
1246        }
1247        let output = tsi_with_kernel(&input, kernel)?;
1248        assert_eq!(output.values.len(), candles.close.len());
1249        Ok(())
1250    }
1251
1252    fn check_tsi_reinput(
1253        test_name: &str,
1254        kernel: Kernel,
1255    ) -> Result<(), Box<dyn std::error::Error>> {
1256        skip_if_unsupported!(kernel, test_name);
1257        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1258        let candles = read_candles_from_csv(file_path)?;
1259
1260        let first_params = TsiParams {
1261            long_period: Some(25),
1262            short_period: Some(13),
1263        };
1264        let first_input = TsiInput::from_candles(&candles, "close", first_params);
1265        let first_result = tsi_with_kernel(&first_input, kernel)?;
1266
1267        let second_params = TsiParams {
1268            long_period: Some(25),
1269            short_period: Some(13),
1270        };
1271        let second_input = TsiInput::from_slice(&first_result.values, second_params);
1272        let second_result = tsi_with_kernel(&second_input, kernel)?;
1273
1274        assert_eq!(second_result.values.len(), first_result.values.len());
1275        Ok(())
1276    }
1277
1278    fn check_tsi_nan_handling(
1279        test_name: &str,
1280        kernel: Kernel,
1281    ) -> Result<(), Box<dyn std::error::Error>> {
1282        skip_if_unsupported!(kernel, test_name);
1283        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1284        let candles = read_candles_from_csv(file_path)?;
1285
1286        let input = TsiInput::from_candles(
1287            &candles,
1288            "close",
1289            TsiParams {
1290                long_period: Some(25),
1291                short_period: Some(13),
1292            },
1293        );
1294        let res = tsi_with_kernel(&input, kernel)?;
1295        assert_eq!(res.values.len(), candles.close.len());
1296        if res.values.len() > 240 {
1297            for (i, &val) in res.values[240..].iter().enumerate() {
1298                assert!(
1299                    !val.is_nan(),
1300                    "[{}] Found unexpected NaN at out-index {}",
1301                    test_name,
1302                    240 + i
1303                );
1304            }
1305        }
1306        Ok(())
1307    }
1308
1309    #[test]
1310    fn test_tsi_into_matches_api() -> Result<(), Box<dyn std::error::Error>> {
1311        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1312        let candles = read_candles_from_csv(file_path)?;
1313        let input = TsiInput::from_candles(&candles, "close", TsiParams::default());
1314
1315        let baseline = tsi(&input)?.values;
1316
1317        let mut out = vec![0.0; candles.close.len()];
1318        #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1319        {
1320            tsi_into(&input, &mut out)?;
1321        }
1322        #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1323        {
1324            tsi_into_slice(&mut out, &input, Kernel::Auto)?;
1325        }
1326
1327        assert_eq!(out.len(), baseline.len());
1328        let eq_or_both_nan = |a: f64, b: f64| (a.is_nan() && b.is_nan()) || (a == b);
1329        for i in 0..out.len() {
1330            assert!(
1331                eq_or_both_nan(out[i], baseline[i]),
1332                "mismatch at {}: got {}, expected {}",
1333                i,
1334                out[i],
1335                baseline[i]
1336            );
1337        }
1338        Ok(())
1339    }
1340
1341    #[cfg(debug_assertions)]
1342    fn check_tsi_no_poison(
1343        test_name: &str,
1344        kernel: Kernel,
1345    ) -> Result<(), Box<dyn std::error::Error>> {
1346        skip_if_unsupported!(kernel, test_name);
1347
1348        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1349        let candles = read_candles_from_csv(file_path)?;
1350
1351        let test_params = vec![
1352            TsiParams::default(),
1353            TsiParams {
1354                long_period: Some(2),
1355                short_period: Some(2),
1356            },
1357            TsiParams {
1358                long_period: Some(5),
1359                short_period: Some(3),
1360            },
1361            TsiParams {
1362                long_period: Some(10),
1363                short_period: Some(5),
1364            },
1365            TsiParams {
1366                long_period: Some(15),
1367                short_period: Some(7),
1368            },
1369            TsiParams {
1370                long_period: Some(25),
1371                short_period: Some(13),
1372            },
1373            TsiParams {
1374                long_period: Some(30),
1375                short_period: Some(15),
1376            },
1377            TsiParams {
1378                long_period: Some(40),
1379                short_period: Some(20),
1380            },
1381            TsiParams {
1382                long_period: Some(50),
1383                short_period: Some(25),
1384            },
1385            TsiParams {
1386                long_period: Some(100),
1387                short_period: Some(50),
1388            },
1389            TsiParams {
1390                long_period: Some(200),
1391                short_period: Some(100),
1392            },
1393            TsiParams {
1394                long_period: Some(50),
1395                short_period: Some(2),
1396            },
1397            TsiParams {
1398                long_period: Some(100),
1399                short_period: Some(5),
1400            },
1401            TsiParams {
1402                long_period: Some(10),
1403                short_period: Some(9),
1404            },
1405        ];
1406
1407        for (param_idx, params) in test_params.iter().enumerate() {
1408            let input = TsiInput::from_candles(&candles, "close", params.clone());
1409            let output = tsi_with_kernel(&input, kernel)?;
1410
1411            for (i, &val) in output.values.iter().enumerate() {
1412                if val.is_nan() {
1413                    continue;
1414                }
1415
1416                let bits = val.to_bits();
1417
1418                if bits == 0x11111111_11111111 {
1419                    panic!(
1420                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1421						 with params: {:?} (param set {})",
1422                        test_name, val, bits, i, params, param_idx
1423                    );
1424                }
1425
1426                if bits == 0x22222222_22222222 {
1427                    panic!(
1428                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1429						 with params: {:?} (param set {})",
1430                        test_name, val, bits, i, params, param_idx
1431                    );
1432                }
1433
1434                if bits == 0x33333333_33333333 {
1435                    panic!(
1436                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1437						 with params: {:?} (param set {})",
1438                        test_name, val, bits, i, params, param_idx
1439                    );
1440                }
1441            }
1442        }
1443
1444        Ok(())
1445    }
1446
    /// Release-build counterpart of the debug-only poison check: it performs
    /// no work and always returns `Ok(())`, so the generated test names exist
    /// in both build profiles.
    #[cfg(not(debug_assertions))]
    fn check_tsi_no_poison(
        _test_name: &str,
        _kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        Ok(())
    }
1454
    /// Property-based test for TSI (only built with the `proptest` feature).
    ///
    /// Each generated case consists of a `long_period` in `2..=50`, a
    /// `short_period` in `2..=min(long_period, 25)`, and a price vector of
    /// `long_period + 30 .. 400` values drawn from `[1.0, 10000.0)`. For each
    /// case the closure checks the numbered properties annotated below; any
    /// failing `prop_assert!` makes the runner's `.run(..)` return an error,
    /// which the trailing `.unwrap()` turns into a panic.
    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_tsi_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        // Strategy: pick long_period first, then derive the data length range
        // and the short_period range from it.
        let strat = (2usize..=50).prop_flat_map(|long_period| {
            (
                prop::collection::vec(
                    (1.0f64..10000.0f64).prop_filter("finite", |x| x.is_finite()),
                    (long_period + 30)..400,
                ),
                Just(long_period),
                2usize..=long_period.min(25),
            )
        });

        proptest::test_runner::TestRunner::default()
            .run(&strat, |(data, long_period, short_period)| {
                let params = TsiParams {
                    long_period: Some(long_period),
                    short_period: Some(short_period),
                };
                let input = TsiInput::from_slice(&data, params.clone());

                // `out` comes from the kernel under test, `ref_out` from the
                // scalar kernel used as the reference in Property 3.
                let TsiOutput { values: out } = tsi_with_kernel(&input, kernel).unwrap();
                let TsiOutput { values: ref_out } =
                    tsi_with_kernel(&input, Kernel::Scalar).unwrap();

                // Property 1: every non-NaN output lies in [-100, 100], with
                // 1e-9 of slack on each side.
                for (i, &val) in out.iter().enumerate() {
                    if !val.is_nan() {
                        prop_assert!(
                            val >= -100.0 - 1e-9 && val <= 100.0 + 1e-9,
                            "Property 1: TSI value out of range at idx {}: {} ∉ [-100, 100]",
                            i,
                            val
                        );
                    }
                }

                let first_valid = data.iter().position(|x| !x.is_nan()).unwrap_or(0);

                // Property 2: the first output element is NaN.
                prop_assert!(
                    out[0].is_nan(),
                    "Property 2: First value should always be NaN, got {}",
                    out[0]
                );

                // True when at least one adjacent pair differs by more than 1e-10.
                let has_variation = data.windows(2).any(|w| (w[0] - w[1]).abs() > 1e-10);

                if has_variation {
                    let first_non_nan = out.iter().position(|x| !x.is_nan());

                    if let Some(idx) = first_non_nan {
                        prop_assert!(
                            idx >= 1,
                            "Property 2: First non-NaN too early at idx {}",
                            idx
                        );

                        let sufficient_data =
                            (first_valid + long_period + short_period).min(out.len());
                        if sufficient_data < out.len() {
                            // Same adjacent-difference test as `has_variation`,
                            // written as an explicit loop.
                            let mut has_movement = false;
                            for i in 1..data.len() {
                                if (data[i] - data[i - 1]).abs() > 1e-10 {
                                    has_movement = true;
                                    break;
                                }
                            }

                            if has_movement {
                                // The final quarter of the output (at least one
                                // element) must not consist solely of NaN.
                                let last_quarter_start = out.len() - (out.len() / 4).max(1);

                                let nan_count = out[last_quarter_start..]
                                    .iter()
                                    .filter(|v| v.is_nan())
                                    .count();
                                let total_count = out.len() - last_quarter_start;
                                prop_assert!(
									nan_count < total_count,
									"Property 2: All values in last quarter are NaN (from idx {}), which suggests an issue",
									last_quarter_start
								);
                            }
                        }
                    }
                }

                // Property 3: the kernel under test agrees with the scalar
                // kernel. Positions where both are NaN are skipped; if either
                // value is non-finite the bit patterns must match exactly;
                // otherwise the absolute difference must be at most 1e-9.
                for (i, (&val, &ref_val)) in out.iter().zip(ref_out.iter()).enumerate() {
                    if val.is_nan() && ref_val.is_nan() {
                        continue;
                    }
                    if !val.is_finite() || !ref_val.is_finite() {
                        prop_assert!(
                            val.to_bits() == ref_val.to_bits(),
                            "Property 3: Kernel finite/NaN mismatch at idx {}: {} vs {}",
                            i,
                            val,
                            ref_val
                        );
                        continue;
                    }
                    prop_assert!(
                        (val - ref_val).abs() <= 1e-9,
                        "Property 3: Kernel mismatch at idx {}: {} vs {} (diff: {})",
                        i,
                        val,
                        ref_val,
                        (val - ref_val).abs()
                    );
                }

                // Property 4: on a fixed constant series (50 x 100.0, periods
                // 10/5) every output must be NaN. Skipped if the call errors.
                let constant_data = vec![100.0; 50];
                let const_params = TsiParams {
                    long_period: Some(10),
                    short_period: Some(5),
                };
                let const_input = TsiInput::from_slice(&constant_data, const_params);
                if let Ok(TsiOutput { values: const_out }) = tsi_with_kernel(&const_input, kernel) {
                    for (i, &val) in const_out.iter().enumerate() {
                        prop_assert!(
                            val.is_nan(),
                            "Property 4: TSI should be NaN for constant prices at idx {}, got {}",
                            i,
                            val
                        );
                    }
                }

                // Property 5: fixed linear up/down ramps of 100 points (step
                // 20) evaluated with the generated periods. Skipped if either
                // call errors or the series is too short for the warmup.
                let uptrend: Vec<f64> = (0..100).map(|i| 100.0 + i as f64 * 20.0).collect();
                let uptrend_input = TsiInput::from_slice(&uptrend, params.clone());

                let downtrend: Vec<f64> = (0..100).map(|i| 2100.0 - i as f64 * 20.0).collect();
                let downtrend_input = TsiInput::from_slice(&downtrend, params.clone());

                if let (Ok(TsiOutput { values: up_out }), Ok(TsiOutput { values: down_out })) = (
                    tsi_with_kernel(&uptrend_input, kernel),
                    tsi_with_kernel(&downtrend_input, kernel),
                ) {
                    let test_warmup = 1 + long_period + short_period - 1;
                    if test_warmup + 10 < up_out.len() {
                        // Average the non-NaN values among the last ten outputs.
                        let up_vals: Vec<f64> = up_out[up_out.len() - 10..]
                            .iter()
                            .filter(|&&x| !x.is_nan())
                            .copied()
                            .collect();
                        let down_vals: Vec<f64> = down_out[down_out.len() - 10..]
                            .iter()
                            .filter(|&&x| !x.is_nan())
                            .copied()
                            .collect();

                        if !up_vals.is_empty() && !down_vals.is_empty() {
                            let up_avg = up_vals.iter().sum::<f64>() / up_vals.len() as f64;
                            let down_avg = down_vals.iter().sum::<f64>() / down_vals.len() as f64;

                            // Stricter threshold (10) for long periods above 20.
                            let tolerance = if long_period > 20 { 10.0 } else { 0.0 };

                            prop_assert!(
								up_avg > tolerance,
								"Property 5: Uptrend TSI should be positive, got avg: {} (long={}, short={})",
								up_avg, long_period, short_period
							);
                            prop_assert!(
								down_avg < -tolerance,
								"Property 5: Downtrend TSI should be negative, got avg: {} (long={}, short={})",
								down_avg, long_period, short_period
							);
                        }
                    }
                }

                // Property 6: fixed steep ramps (step 50, periods 5/3). The six
                // values ending at the last non-NaN index must exceed 80 for
                // the up ramp and be below -80 for the down ramp.
                let extreme_up: Vec<f64> = (0..100).map(|i| 100.0 + i as f64 * 50.0).collect();
                let extreme_params = TsiParams {
                    long_period: Some(5),
                    short_period: Some(3),
                };
                let extreme_input = TsiInput::from_slice(&extreme_up, extreme_params.clone());
                if let Ok(TsiOutput {
                    values: extreme_out,
                }) = tsi_with_kernel(&extreme_input, kernel)
                {
                    let last_valid = extreme_out.iter().rposition(|x| !x.is_nan());
                    if let Some(idx) = last_valid {
                        // `idx >= 20` also keeps `idx - 5` from underflowing.
                        if idx >= 20 {
                            let last_vals = &extreme_out[idx - 5..=idx];
                            for &v in last_vals {
                                if !v.is_nan() {
                                    prop_assert!(
										v > 80.0,
										"Property 6: Very strong uptrend should have TSI > 80, got: {}",
										v
									);
                                }
                            }
                        }
                    }
                }

                let extreme_down: Vec<f64> = (0..100).map(|i| 5100.0 - i as f64 * 50.0).collect();
                let extreme_down_input = TsiInput::from_slice(&extreme_down, extreme_params);
                if let Ok(TsiOutput {
                    values: extreme_down_out,
                }) = tsi_with_kernel(&extreme_down_input, kernel)
                {
                    let last_valid = extreme_down_out.iter().rposition(|x| !x.is_nan());
                    if let Some(idx) = last_valid {
                        if idx >= 20 {
                            let last_vals = &extreme_down_out[idx - 5..=idx];
                            for &v in last_vals {
                                if !v.is_nan() {
                                    prop_assert!(
										v < -80.0,
										"Property 6: Very strong downtrend should have TSI < -80, got: {}",
										v
									);
                                }
                            }
                        }
                    }
                }

                // Property 7: when up-moves outnumber down-moves by more than
                // 2:1 (or the reverse), the average of the non-NaN values among
                // the last ten outputs must not sit far on the opposite side.
                if data.len() >= 20 {
                    let increasing_count = data.windows(2).filter(|w| w[1] > w[0]).count();
                    let decreasing_count = data.windows(2).filter(|w| w[1] < w[0]).count();

                    let valid_tsi: Vec<f64> = out
                        .iter()
                        .rev()
                        .take(10)
                        .filter(|&&x| !x.is_nan())
                        .copied()
                        .collect();

                    if !valid_tsi.is_empty() {
                        let avg_tsi = valid_tsi.iter().sum::<f64>() / valid_tsi.len() as f64;

                        if increasing_count > decreasing_count * 2 {
                            prop_assert!(
								avg_tsi > -20.0,
								"Property 7: Mostly increasing prices should have TSI > -20, got: {}",
								avg_tsi
							);
                        } else if decreasing_count > increasing_count * 2 {
                            prop_assert!(
								avg_tsi < 20.0,
								"Property 7: Mostly decreasing prices should have TSI < 20, got: {}",
								avg_tsi
							);
                        }
                    }
                }

                // Property 8 (debug builds only): no non-NaN output carries one
                // of the three poison bit patterns.
                #[cfg(debug_assertions)]
                {
                    for (i, &val) in out.iter().enumerate() {
                        if !val.is_nan() {
                            let bits = val.to_bits();
                            prop_assert!(
                                bits != 0x11111111_11111111
                                    && bits != 0x22222222_22222222
                                    && bits != 0x33333333_33333333,
                                "Property 8: Found poison value at idx {}: {} (0x{:016X})",
                                i,
                                val,
                                bits
                            );
                        }
                    }
                }

                // Property 9: compare consecutive 5-step price changes; where
                // their signs differ, count how often the matching TSI changes
                // also flip sign or move in the direction of the newer price
                // change. At least 30% of such reversals must be followed.
                if data.len() >= 50 && has_variation {
                    let start_idx = (1 + long_period + short_period).max(20);
                    if start_idx + 20 < out.len() {
                        let mut momentum_changes = 0;
                        let mut tsi_follows = 0;

                        for i in start_idx..out.len() - 10 {
                            if !out[i].is_nan() && !out[i + 5].is_nan() && !out[i + 10].is_nan() {
                                let price_change1 = data[i + 5] - data[i];
                                let price_change2 = data[i + 10] - data[i + 5];

                                let tsi_change1 = out[i + 5] - out[i];
                                let tsi_change2 = out[i + 10] - out[i + 5];

                                if price_change1 * price_change2 < 0.0 {
                                    momentum_changes += 1;

                                    if tsi_change1 * tsi_change2 < 0.0
                                        || (price_change2 > 0.0 && tsi_change2 > tsi_change1)
                                        || (price_change2 < 0.0 && tsi_change2 < tsi_change1)
                                    {
                                        tsi_follows += 1;
                                    }
                                }
                            }
                        }

                        if momentum_changes > 0 {
                            let follow_rate = tsi_follows as f64 / momentum_changes as f64;
                            prop_assert!(
								follow_rate >= 0.3,
								"Property 9: TSI should respond to momentum changes, follow rate: {:.2}",
								follow_rate
							);
                        }
                    }
                }

                Ok(())
            })
            .unwrap();

        Ok(())
    }
1774
1775    macro_rules! generate_all_tsi_tests {
1776        ($($test_fn:ident),*) => {
1777            paste! {
1778                $(
1779                    #[test]
1780                    fn [<$test_fn _scalar_f64>]() {
1781                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1782                    }
1783                )*
1784                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1785                $(
1786                    #[test]
1787                    fn [<$test_fn _avx2_f64>]() {
1788                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1789                    }
1790                    #[test]
1791                    fn [<$test_fn _avx512_f64>]() {
1792                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1793                    }
1794                )*
1795            }
1796        }
1797    }
1798    generate_all_tsi_tests!(
1799        check_tsi_partial_params,
1800        check_tsi_accuracy,
1801        check_tsi_zero_period,
1802        check_tsi_period_exceeds_length,
1803        check_tsi_very_small_dataset,
1804        check_tsi_default_candles,
1805        check_tsi_reinput,
1806        check_tsi_nan_handling,
1807        check_tsi_no_poison
1808    );
1809
1810    #[cfg(feature = "proptest")]
1811    generate_all_tsi_tests!(check_tsi_property);
1812
1813    fn check_batch_default_row(
1814        test: &str,
1815        kernel: Kernel,
1816    ) -> Result<(), Box<dyn std::error::Error>> {
1817        skip_if_unsupported!(kernel, test);
1818        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1819        let c = read_candles_from_csv(file)?;
1820        let output = TsiBatchBuilder::new()
1821            .kernel(kernel)
1822            .apply_candles(&c, "close")?;
1823        let def = TsiParams::default();
1824        let row = output.values_for(&def).expect("default row missing");
1825        assert_eq!(row.len(), c.close.len());
1826        let expected = [
1827            -17.757654061849838,
1828            -17.367527062626184,
1829            -17.305577681249513,
1830            -16.937565646991143,
1831            -17.61825617316731,
1832        ];
1833        let start = row.len() - 5;
1834        for (i, &v) in row[start..].iter().enumerate() {
1835            assert!(
1836                (v - expected[i]).abs() < 1e-7,
1837                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
1838            );
1839        }
1840        Ok(())
1841    }
1842
1843    #[cfg(debug_assertions)]
1844    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn std::error::Error>> {
1845        skip_if_unsupported!(kernel, test);
1846
1847        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1848        let c = read_candles_from_csv(file)?;
1849
1850        let test_configs = vec![
1851            (5, 10, 1, 2, 5, 1),
1852            (10, 30, 5, 5, 15, 5),
1853            (25, 50, 5, 10, 25, 5),
1854            (50, 100, 10, 25, 50, 5),
1855            (25, 25, 0, 2, 20, 2),
1856            (10, 50, 10, 13, 13, 0),
1857            (100, 200, 50, 50, 100, 25),
1858            (2, 5, 1, 2, 5, 1),
1859            (30, 30, 0, 5, 25, 5),
1860        ];
1861
1862        for (cfg_idx, &(l_start, l_end, l_step, s_start, s_end, s_step)) in
1863            test_configs.iter().enumerate()
1864        {
1865            let output = TsiBatchBuilder::new()
1866                .kernel(kernel)
1867                .long_range(l_start, l_end, l_step)
1868                .short_range(s_start, s_end, s_step)
1869                .apply_candles(&c, "close")?;
1870
1871            for (idx, &val) in output.values.iter().enumerate() {
1872                if val.is_nan() {
1873                    continue;
1874                }
1875
1876                let bits = val.to_bits();
1877                let row = idx / output.cols;
1878                let col = idx % output.cols;
1879                let combo = &output.combos[row];
1880
1881                if bits == 0x11111111_11111111 {
1882                    panic!(
1883                        "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
1884						 at row {} col {} (flat index {}) with params: {:?}",
1885                        test, cfg_idx, val, bits, row, col, idx, combo
1886                    );
1887                }
1888
1889                if bits == 0x22222222_22222222 {
1890                    panic!(
1891                        "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
1892						 at row {} col {} (flat index {}) with params: {:?}",
1893                        test, cfg_idx, val, bits, row, col, idx, combo
1894                    );
1895                }
1896
1897                if bits == 0x33333333_33333333 {
1898                    panic!(
1899                        "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
1900						 at row {} col {} (flat index {}) with params: {:?}",
1901                        test, cfg_idx, val, bits, row, col, idx, combo
1902                    );
1903                }
1904            }
1905        }
1906
1907        Ok(())
1908    }
1909
1910    #[cfg(not(debug_assertions))]
1911    fn check_batch_no_poison(
1912        _test: &str,
1913        _kernel: Kernel,
1914    ) -> Result<(), Box<dyn std::error::Error>> {
1915        Ok(())
1916    }
1917
1918    macro_rules! gen_batch_tests {
1919        ($fn_name:ident) => {
1920            paste! {
1921                #[test] fn [<$fn_name _scalar>]()      {
1922                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
1923                }
1924                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1925                #[test] fn [<$fn_name _avx2>]()        {
1926                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
1927                }
1928                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1929                #[test] fn [<$fn_name _avx512>]()      {
1930                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
1931                }
1932                #[test] fn [<$fn_name _auto_detect>]() {
1933                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
1934                }
1935            }
1936        };
1937    }
1938    gen_batch_tests!(check_batch_default_row);
1939    gen_batch_tests!(check_batch_no_poison);
1940}
1941
/// Python binding `tsi(data, long_period=25, short_period=13, kernel=None)`.
///
/// Computes TSI over a 1-D `float64` array and returns the result as a new NumPy
/// array. The computation itself runs inside `allow_threads`.
///
/// Errors from the computation are raised as `ValueError`; failures from obtaining the
/// input slice or from `validate_kernel` are propagated unchanged.
#[cfg(feature = "python")]
#[pyfunction(name = "tsi")]
#[pyo3(signature = (data, long_period=25, short_period=13, kernel=None))]
pub fn tsi_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    long_period: usize,
    short_period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let prices = data.as_slice()?;
    let chosen_kernel = validate_kernel(kernel, false)?;

    let input = TsiInput::from_slice(
        prices,
        TsiParams {
            long_period: Some(long_period),
            short_period: Some(short_period),
        },
    );

    let computed = py.allow_threads(|| tsi_with_kernel(&input, chosen_kernel).map(|o| o.values));
    match computed {
        Ok(values) => Ok(values.into_pyarray(py)),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
1969
/// Python binding `tsi_batch(data, long_period_range, short_period_range, kernel=None)`.
///
/// Evaluates TSI for every parameter combination that `expand_grid` derives from the
/// two range tuples and returns a dict with:
/// - `"values"`: a 2-D `float64` array with one row per combination and one column per
///   input sample,
/// - `"long_periods"` / `"short_periods"`: `uint64` arrays holding each row's periods.
///
/// `Kernel::Auto` and `Kernel::ScalarBatch` both run the scalar kernel; the AVX batch
/// kernels map to their per-series counterparts. Grid, size-overflow and computation
/// errors are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "tsi_batch")]
#[pyo3(signature = (data, long_period_range, short_period_range, kernel=None))]
pub fn tsi_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    long_period_range: (usize, usize, usize),
    short_period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let prices = data.as_slice()?;

    let sweep = TsiBatchRange {
        long_period: long_period_range,
        short_period: short_period_range,
    };

    // The grid is expanded once up front purely to size the output matrix.
    let rows = expand_grid(&sweep)
        .map_err(|e| PyValueError::new_err(e.to_string()))?
        .len();
    let cols = prices.len();
    let total = match rows.checked_mul(cols) {
        Some(n) => n,
        None => return Err(PyValueError::new_err("tsi_batch: size overflow")),
    };

    // Allocated without initialisation; `tsi_batch_inner_into` writes into it below.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let flat_out = unsafe { out_arr.as_slice_mut()? };

    let simd = match validate_kernel(kernel, true)? {
        Kernel::Auto | Kernel::ScalarBatch => Kernel::Scalar,
        Kernel::Avx512Batch => Kernel::Avx512,
        Kernel::Avx2Batch => Kernel::Avx2,
        other => other,
    };

    let combos = py
        .allow_threads(|| tsi_batch_inner_into(prices, &sweep, simd, true, flat_out))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let long_periods: Vec<u64> = combos
        .iter()
        .map(|p| p.long_period.unwrap() as u64)
        .collect();
    let short_periods: Vec<u64> = combos
        .iter()
        .map(|p| p.short_period.unwrap() as u64)
        .collect();

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item("long_periods", long_periods.into_pyarray(py))?;
    dict.set_item("short_periods", short_periods.into_pyarray(py))?;

    Ok(dict)
}
2043
2044#[cfg(all(feature = "python", feature = "cuda"))]
2045use crate::cuda::cuda_available;
2046#[cfg(all(feature = "python", feature = "cuda"))]
2047use crate::cuda::oscillators::tsi_wrapper::CudaTsi;
2048
/// Python-visible handle (`TsiDeviceArrayF32`) around a 2-D `f32` device array.
///
/// `ctx_guard` is never read in this block.
/// NOTE(review): presumably it is held to keep the CUDA context alive for as long as
/// the buffer exists; confirm against `CudaTsi`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", name = "TsiDeviceArrayF32", unsendable)]
pub struct TsiDeviceArrayF32Py {
    pub(crate) inner: DeviceArrayF32,
    ctx_guard: Arc<Context>,
    device_id: u32,
}

#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl TsiDeviceArrayF32Py {
    /// Builds the `__cuda_array_interface__` dict: shape `(rows, cols)`, typestr `<f4`,
    /// byte strides `(cols * 4, 4)`, the device pointer paired with `false`, and
    /// version `3`.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let arr = &self.inner;
        let elem_bytes = std::mem::size_of::<f32>();

        let iface = PyDict::new(py);
        iface.set_item("shape", (arr.rows, arr.cols))?;
        iface.set_item("typestr", "<f4")?;
        iface.set_item("strides", (arr.cols * elem_bytes, elem_bytes))?;
        iface.set_item("data", (arr.device_ptr() as usize, false))?;

        iface.set_item("version", 3)?;
        Ok(iface)
    }

    /// Returns `(2, device_id)`.
    /// NOTE(review): `2` is used as the DLPack device-type code; confirm it is the CUDA
    /// code in the DLPack device enum.
    fn __dlpack_device__(&self) -> PyResult<(i32, i32)> {
        Ok((2, self.device_id as i32))
    }

    /// Exports the buffer through `export_f32_cuda_dlpack_2d`, handing ownership to it.
    ///
    /// If `dl_device` parses as an `(i32, i32)` pair that differs from
    /// `__dlpack_device__()`, a `ValueError` is raised; the message depends on whether
    /// `copy` extracts to `true`. A `dl_device` that does not parse is ignored, and
    /// `stream` is not used.
    ///
    /// After a successful call `self.inner` is an empty `0 x 0` array.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        let (kdl, alloc_dev) = self.__dlpack_device__()?;

        let device_mismatch = dl_device
            .as_ref()
            .and_then(|dev| dev.extract::<(i32, i32)>(py).ok())
            .map_or(false, |requested| requested != (kdl, alloc_dev));
        if device_mismatch {
            let wants_copy = copy
                .as_ref()
                .and_then(|c| c.extract::<bool>(py).ok())
                .unwrap_or(false);
            let message = if wants_copy {
                "device copy not implemented for __dlpack__"
            } else {
                "dl_device mismatch for __dlpack__"
            };
            return Err(PyValueError::new_err(message));
        }
        let _ = stream;

        // Swap an empty placeholder into `self` so the real buffer can be moved out.
        let placeholder =
            DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let DeviceArrayF32 { buf, rows, cols } = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32 {
                buf: placeholder,
                rows: 0,
                cols: 0,
            },
        );

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}

#[cfg(all(feature = "python", feature = "cuda"))]
impl TsiDeviceArrayF32Py {
    /// Wraps an existing device array together with its context handle and device id.
    pub fn new_from_rust(inner: DeviceArrayF32, ctx_guard: Arc<Context>, device_id: u32) -> Self {
        TsiDeviceArrayF32Py {
            inner,
            ctx_guard,
            device_id,
        }
    }
}
2140
/// Python binding `tsi_cuda_batch_dev(data_f32, long_period_range, short_period_range,
/// device_id=0)`.
///
/// Runs the batch sweep through `CudaTsi::tsi_batch_dev` and returns the device array
/// handle plus a dict with `"long_periods"` and `"short_periods"` (`uint64` arrays, one
/// entry per parameter combination).
///
/// Raises `ValueError` when `cuda_available()` is false or when `CudaTsi` reports an
/// error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "tsi_cuda_batch_dev")]
#[pyo3(signature = (data_f32, long_period_range, short_period_range, device_id=0))]
pub fn tsi_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    data_f32: numpy::PyReadonlyArray1<'py, f32>,
    long_period_range: (usize, usize, usize),
    short_period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<(TsiDeviceArrayF32Py, Bound<'py, PyDict>)> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let prices = data_f32.as_slice()?;
    let sweep = TsiBatchRange {
        long_period: long_period_range,
        short_period: short_period_range,
    };

    // The context handle and device id are taken before the computation is launched.
    let launch = || -> PyResult<_> {
        let mut cuda = CudaTsi::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id();
        let (arr, combos) = cuda
            .tsi_batch_dev(prices, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((arr, combos, ctx, dev_id))
    };
    let (inner, combos, ctx, dev_id) = py.allow_threads(launch)?;

    let long_periods: Vec<u64> = combos
        .iter()
        .map(|p| p.long_period.unwrap() as u64)
        .collect();
    let short_periods: Vec<u64> = combos
        .iter()
        .map(|p| p.short_period.unwrap() as u64)
        .collect();

    let dict = PyDict::new(py);
    dict.set_item("long_periods", long_periods.into_pyarray(py))?;
    dict.set_item("short_periods", short_periods.into_pyarray(py))?;

    Ok((TsiDeviceArrayF32Py::new_from_rust(inner, ctx, dev_id), dict))
}
2189
/// Python binding `tsi_cuda_many_series_one_param_dev(data_tm_f32, cols, rows,
/// long_period, short_period, device_id=0)`.
///
/// Forwards the flat `f32` input together with `cols`, `rows` and the two periods to
/// `CudaTsi::tsi_many_series_one_param_time_major_dev` and wraps the resulting device
/// array.
/// NOTE(review): the callee's name implies a time-major layout for `data_tm_f32`;
/// confirm the expected shape against `CudaTsi`.
///
/// Raises `ValueError` when `cuda_available()` is false or when `CudaTsi` reports an
/// error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "tsi_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, cols, rows, long_period, short_period, device_id=0))]
pub fn tsi_cuda_many_series_one_param_dev_py<'py>(
    py: Python<'py>,
    data_tm_f32: numpy::PyReadonlyArray1<'py, f32>,
    cols: usize,
    rows: usize,
    long_period: usize,
    short_period: usize,
    device_id: usize,
) -> PyResult<TsiDeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let flat_input = data_tm_f32.as_slice()?;

    let launch = || -> PyResult<_> {
        let mut cuda = CudaTsi::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id();
        let computed = cuda.tsi_many_series_one_param_time_major_dev(
            flat_input,
            cols,
            rows,
            long_period,
            short_period,
        );
        match computed {
            Ok(arr) => Ok((arr, ctx, dev_id)),
            Err(e) => Err(PyValueError::new_err(e.to_string())),
        }
    };
    let (inner, ctx, dev_id) = py.allow_threads(launch)?;

    Ok(TsiDeviceArrayF32Py::new_from_rust(inner, ctx, dev_id))
}
2223
/// Python class `TsiStream`: a thin wrapper that forwards to the Rust `TsiStream`.
#[cfg(feature = "python")]
#[pyclass(name = "TsiStream")]
pub struct TsiStreamPy {
    inner: TsiStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl TsiStreamPy {
    /// Creates a stream for the given periods; a `TsiStream::try_new` failure is raised
    /// as `ValueError`.
    #[new]
    pub fn new(long_period: usize, short_period: usize) -> PyResult<Self> {
        TsiStream::try_new(TsiParams {
            long_period: Some(long_period),
            short_period: Some(short_period),
        })
        .map(|inner| TsiStreamPy { inner })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feeds one value to the wrapped stream and returns whatever it yields.
    pub fn update(&mut self, value: f64) -> Option<f64> {
        self.inner.update(value)
    }
}
2247
/// WASM binding: computes TSI for `data` with `Kernel::Auto` and returns a freshly
/// allocated vector of the same length.
///
/// A computation error is returned as a `JsValue` string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn tsi_js(data: &[f64], long_period: usize, short_period: usize) -> Result<Vec<f64>, JsValue> {
    let input = TsiInput::from_slice(
        data,
        TsiParams {
            long_period: Some(long_period),
            short_period: Some(short_period),
        },
    );

    let mut values = vec![0.0; data.len()];
    match tsi_into_slice(&mut values, &input, Kernel::Auto) {
        Ok(_) => Ok(values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
2264
/// WASM binding: computes TSI from `len` values at `in_ptr` and writes `len` results to
/// `out_ptr`, using `Kernel::Auto`.
///
/// When both pointers are equal the result is computed into a temporary buffer and
/// copied back afterwards. Null pointers and computation errors are returned as
/// `JsValue` strings.
///
/// The caller must pass pointers that are valid for `len` `f64` values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn tsi_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    long_period: usize,
    short_period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    let params = TsiParams {
        long_period: Some(long_period),
        short_period: Some(short_period),
    };
    let in_place = in_ptr == out_ptr as *const f64;

    unsafe {
        let input = TsiInput::from_slice(std::slice::from_raw_parts(in_ptr, len), params);

        if in_place {
            // Compute into scratch space first so the input is not overwritten mid-run.
            let mut scratch = vec![0.0; len];
            tsi_into_slice(&mut scratch, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            std::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&scratch);
            return Ok(());
        }

        let dest = std::slice::from_raw_parts_mut(out_ptr, len);
        tsi_into_slice(dest, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        Ok(())
    }
}
2300
/// WASM binding: allocates space for `len` `f64` values and returns the raw pointer.
///
/// The allocation is deliberately not dropped here, and its contents are left
/// uninitialised.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn tsi_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` suppresses the vector's destructor so the allocation outlives this call.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
2309
/// WASM binding: releases a buffer of capacity `len` at `ptr`. A null `ptr` is ignored.
///
/// The caller must pass a pointer whose allocation matches what `Vec::from_raw_parts`
/// requires for a `Vec<f64>` of capacity `len`, and must not use it after this call.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn tsi_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: relies on the caller contract above. The vector is rebuilt with length 0
    // and capacity `len`, so no element is claimed to be initialised (the buffer may
    // never have been written); dropping it frees the whole allocation.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
2319
/// Batch sweep configuration deserialized from the JS `config` value in `tsi_batch_js`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct TsiBatchConfig {
    /// Copied into `TsiBatchRange::long_period` by `tsi_batch_js`.
    pub long_period_range: (usize, usize, usize),
    /// Copied into `TsiBatchRange::short_period` by `tsi_batch_js`.
    pub short_period_range: (usize, usize, usize),
}
2326
/// Batch result serialized back to JS by `tsi_batch_js`; each field is copied from the
/// same-named field of the `tsi_batch_slice` result.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct TsiBatchJsOutput {
    /// Output values of the batch run.
    pub values: Vec<f64>,
    /// Parameter combinations of the batch run.
    pub combos: Vec<TsiParams>,
    /// Row count reported by the batch run.
    pub rows: usize,
    /// Column count reported by the batch run.
    pub cols: usize,
}
2335
/// WASM binding exported as `tsi_batch`: deserializes a `TsiBatchConfig` from `config`,
/// runs `tsi_batch_slice` with `Kernel::Scalar`, and returns the result serialized as a
/// `TsiBatchJsOutput`.
///
/// Config, computation and serialization failures are returned as `JsValue` strings.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = tsi_batch)]
pub fn tsi_batch_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let parsed = serde_wasm_bindgen::from_value::<TsiBatchConfig>(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = TsiBatchRange {
        long_period: parsed.long_period_range,
        short_period: parsed.short_period_range,
    };

    let batch = match tsi_batch_slice(data, &sweep, Kernel::Scalar) {
        Ok(batch) => batch,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };

    let payload = TsiBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };

    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Failed to serialize output: {}", e)))
}
2360
/// WASM binding: runs a batch sweep with `Kernel::Scalar` over `len` values at `in_ptr`
/// and writes `rows * len` results to `out_ptr`, where `rows` is the number of
/// parameter combinations from `expand_grid`. Returns `rows` on success.
///
/// Null pointers, grid errors, size overflow and computation errors are returned as
/// `JsValue` strings.
///
/// The caller must pass an `in_ptr` valid for `len` values and an `out_ptr` valid for
/// `rows * len` values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn tsi_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    long_period_start: usize,
    long_period_end: usize,
    long_period_step: usize,
    short_period_start: usize,
    short_period_end: usize,
    short_period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to tsi_batch_into"));
    }

    let sweep = TsiBatchRange {
        long_period: (long_period_start, long_period_end, long_period_step),
        short_period: (short_period_start, short_period_end, short_period_step),
    };

    // The grid is expanded here only to size the output region.
    let rows = expand_grid(&sweep)
        .map_err(|e| JsValue::from_str(&e.to_string()))?
        .len();
    let total_size = match rows.checked_mul(len) {
        Some(n) => n,
        None => return Err(JsValue::from_str("tsi_batch_into: size overflow")),
    };

    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let out_slice = std::slice::from_raw_parts_mut(out_ptr, total_size);

        tsi_batch_inner_into(data, &sweep, Kernel::Scalar, false, out_slice)
            .map(|_| rows)
            .map_err(|e| JsValue::from_str(&e.to_string()))
    }
}