Skip to main content

vector_ta/indicators/moving_averages/
tema.rs

1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{
4    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
5    make_uninit_matrix,
6};
7#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
8use core::arch::x86_64::*;
9#[cfg(not(target_arch = "wasm32"))]
10use rayon::prelude::*;
11use std::convert::AsRef;
12use std::error::Error;
13use std::mem::MaybeUninit;
14use thiserror::Error;
15
16#[cfg(all(feature = "python", feature = "cuda"))]
17use crate::cuda::cuda_available;
18#[cfg(all(feature = "python", feature = "cuda"))]
19use crate::cuda::moving_averages::CudaTema;
20#[cfg(all(feature = "python", feature = "cuda"))]
21use crate::utilities::dlpack_cuda::{make_device_array_py, DeviceArrayF32Py};
22#[cfg(feature = "python")]
23use crate::utilities::kernel_validation::validate_kernel;
24#[cfg(feature = "python")]
25use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
26#[cfg(all(feature = "python", feature = "cuda"))]
27use numpy::{PyReadonlyArray1, PyReadonlyArray2};
28#[cfg(feature = "python")]
29use pyo3::exceptions::PyValueError;
30#[cfg(feature = "python")]
31use pyo3::prelude::*;
32#[cfg(feature = "python")]
33use pyo3::types::{PyDict, PyList};
34#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
35use serde::{Deserialize, Serialize};
36#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
37use wasm_bindgen::prelude::*;
38
/// Source of the series a TEMA calculation reads from.
#[derive(Debug, Clone)]
pub enum TemaData<'a> {
    /// A candle set plus the name of the field to read; the field is resolved
    /// through `source_type` when the input is viewed as a slice.
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A caller-provided series of values.
    Slice(&'a [f64]),
}
47
/// Borrowed input series together with the parameters of a TEMA calculation.
#[derive(Debug, Clone)]
pub struct TemaInput<'a> {
    /// Where the values come from (candles field or raw slice).
    pub data: TemaData<'a>,
    /// Indicator parameters; a `None` period is read as 9 by `get_period`.
    pub params: TemaParams,
}
53
54impl<'a> TemaInput<'a> {
55    #[inline]
56    pub fn from_candles(candles: &'a Candles, source: &'a str, params: TemaParams) -> Self {
57        Self {
58            data: TemaData::Candles { candles, source },
59            params,
60        }
61    }
62    #[inline]
63    pub fn from_slice(slice: &'a [f64], params: TemaParams) -> Self {
64        Self {
65            data: TemaData::Slice(slice),
66            params,
67        }
68    }
69    #[inline]
70    pub fn with_default_candles(candles: &'a Candles) -> Self {
71        Self::from_candles(candles, "close", TemaParams::default())
72    }
73    #[inline]
74    pub fn get_period(&self) -> usize {
75        self.params.period.unwrap_or(9)
76    }
77}
78
79impl<'a> AsRef<[f64]> for TemaInput<'a> {
80    #[inline(always)]
81    fn as_ref(&self) -> &[f64] {
82        match &self.data {
83            TemaData::Slice(slice) => slice,
84            TemaData::Candles { candles, source } => source_type(candles, source),
85        }
86    }
87}
88
/// Parameters of the TEMA indicator.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct TemaParams {
    /// Smoothing period; consumers in this file substitute 9 when it is `None`.
    pub period: Option<usize>,
}
97
impl Default for TemaParams {
    /// Default parameters: a period of 9.
    fn default() -> Self {
        Self { period: Some(9) }
    }
}
103
/// Result of a single-series TEMA calculation.
#[derive(Debug, Clone)]
pub struct TemaOutput {
    /// One value per input element; the warm-up prefix is NaN.
    pub values: Vec<f64>,
}
108
/// Errors reported by the TEMA entry points in this module.
#[derive(Debug, Error)]
pub enum TemaError {
    /// The input series has no elements.
    // NOTE(review): not constructed in the visible part of this file; presumably raised by `tema_prepare`.
    #[error("tema: Input data slice is empty.")]
    EmptyInputData,
    /// No non-NaN element was found in the input.
    #[error("tema: All values are NaN.")]
    AllValuesNaN,
    /// The period is not usable (e.g. zero when building a `TemaStream`).
    #[error("tema: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },
    /// Fewer elements follow the first non-NaN value than the period requires.
    #[error("tema: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// The caller-supplied output buffer length differs from the input length.
    #[error("tema: Output length mismatch: expected {expected}, got {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// A batch period range could not be expanded into any period.
    #[error("tema: Invalid range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },
    /// A non-batch kernel was passed to the batch entry point.
    #[error("tema: Invalid kernel for batch path: {0:?}")]
    InvalidKernelForBatch(Kernel),
    /// Any other invalid input, described by the message.
    #[error("tema: Invalid input: {0}")]
    InvalidInput(String),
}
132
/// Builder that collects a period and a kernel choice before running TEMA.
#[derive(Copy, Clone, Debug)]
pub struct TemaBuilder {
    // `None` is forwarded as-is into `TemaParams`.
    period: Option<usize>,
    // Kernel forwarded to `tema_with_kernel`.
    kernel: Kernel,
}
138
impl Default for TemaBuilder {
    /// Builder with no explicit period and `Kernel::Auto`.
    fn default() -> Self {
        Self {
            period: None,
            kernel: Kernel::Auto,
        }
    }
}
147
148impl TemaBuilder {
149    #[inline(always)]
150    pub fn new() -> Self {
151        Self::default()
152    }
153    #[inline(always)]
154    pub fn period(mut self, n: usize) -> Self {
155        self.period = Some(n);
156        self
157    }
158    #[inline(always)]
159    pub fn kernel(mut self, k: Kernel) -> Self {
160        self.kernel = k;
161        self
162    }
163    #[inline(always)]
164    pub fn apply(self, c: &Candles) -> Result<TemaOutput, TemaError> {
165        let p = TemaParams {
166            period: self.period,
167        };
168        let i = TemaInput::from_candles(c, "close", p);
169        tema_with_kernel(&i, self.kernel)
170    }
171    #[inline(always)]
172    pub fn apply_slice(self, d: &[f64]) -> Result<TemaOutput, TemaError> {
173        let p = TemaParams {
174            period: self.period,
175        };
176        let i = TemaInput::from_slice(d, p);
177        tema_with_kernel(&i, self.kernel)
178    }
179    #[inline(always)]
180    pub fn into_stream(self) -> Result<TemaStream, TemaError> {
181        let p = TemaParams {
182            period: self.period,
183        };
184        TemaStream::try_new(p)
185    }
186}
187
/// Computes TEMA for `input` by delegating to `tema_with_kernel` with `Kernel::Auto`.
#[inline]
pub fn tema(input: &TemaInput) -> Result<TemaOutput, TemaError> {
    tema_with_kernel(input, Kernel::Auto)
}
192
193pub fn tema_with_kernel(input: &TemaInput, kernel: Kernel) -> Result<TemaOutput, TemaError> {
194    let (data, period, first, len, chosen) = tema_prepare(input, kernel)?;
195    let lookback = (period - 1) * 3;
196    let warm = first + lookback;
197
198    let mut out = alloc_with_nan_prefix(len, warm);
199    tema_compute_into(data, period, first, chosen, &mut out);
200    Ok(TemaOutput { values: out })
201}
202
203#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
204#[inline]
205pub fn tema_into(input: &TemaInput, out: &mut [f64]) -> Result<(), TemaError> {
206    let (data, period, first, len, chosen) = tema_prepare(input, Kernel::Auto)?;
207
208    if out.len() != len {
209        return Err(TemaError::OutputLengthMismatch {
210            expected: len,
211            got: out.len(),
212        });
213    }
214
215    let warm = first + (period - 1) * 3;
216    let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
217    let pref = warm.min(out.len());
218    for v in &mut out[..pref] {
219        *v = qnan;
220    }
221
222    tema_compute_into(data, period, first, chosen, out);
223
224    Ok(())
225}
226
/// Scalar TEMA kernel: writes `3 * ema1 - 3 * ema2 + ema3` into `out`.
///
/// * `data` – input series; `first_val` is the index of the first element to use.
/// * `period` – smoothing period; the smoothing factor is `2 / (period + 1)`.
/// * `out` – destination, expected to be as long as `data` (checked in debug
///   builds only). Only indices `first_val + 3 * (period - 1)` and later are
///   written; earlier cells are left untouched so the caller can pre-fill them.
///
/// With `period == 1` the input is copied from `first_val` onward.
///
/// Degenerate arguments (`period == 0`, `first_val >= data.len()`, empty
/// `data`) are a no-op rather than an arithmetic underflow or an
/// out-of-bounds panic.
#[inline]
pub fn tema_scalar(data: &[f64], period: usize, first_val: usize, out: &mut [f64]) {
    debug_assert_eq!(data.len(), out.len());

    let n = data.len();
    // Nothing can be computed for these inputs; `period - 1` below would
    // underflow for zero and `data[first_val]` would be out of bounds.
    if n == 0 || period == 0 || first_val >= n {
        return;
    }

    if period == 1 {
        out[first_val..n].copy_from_slice(&data[first_val..n]);
        return;
    }

    let per = 2.0 / (period as f64 + 1.0);
    let per1 = 1.0 - per;

    // Stage boundaries: ema2 is seeded at `start2`, ema3 at `start3`, and the
    // first output is produced at `start_out`. Saturating adds keep absurdly
    // large periods from wrapping around.
    let p1 = period - 1;
    let start2 = first_val.saturating_add(p1);
    let start3 = start2.saturating_add(p1);
    let start_out = start3.saturating_add(p1);

    // Stage 1: only ema1 is running.
    let mut ema1 = data[first_val];
    for &price in &data[first_val..start2.min(n)] {
        ema1 = ema1 * per1 + price * per;
    }
    if start2 >= n {
        return;
    }

    // Stage 2: seed ema2 from ema1, then run both. The seed is immediately
    // blended once to keep the exact operation sequence of this kernel.
    ema1 = ema1 * per1 + data[start2] * per;
    let mut ema2 = ema1;
    ema2 = ema2 * per1 + ema1 * per;
    for &price in &data[start2 + 1..start3.min(n)] {
        ema1 = ema1 * per1 + price * per;
        ema2 = ema2 * per1 + ema1 * per;
    }
    if start3 >= n {
        return;
    }

    // Stage 3: seed ema3 from ema2, then run all three without emitting.
    ema1 = ema1 * per1 + data[start3] * per;
    ema2 = ema2 * per1 + ema1 * per;
    let mut ema3 = ema2;
    ema3 = ema3 * per1 + ema2 * per;
    for &price in &data[start3 + 1..start_out.min(n)] {
        ema1 = ema1 * per1 + price * per;
        ema2 = ema2 * per1 + ema1 * per;
        ema3 = ema3 * per1 + ema2 * per;
    }
    if start_out >= n {
        return;
    }

    // Stage 4: all three averages are warmed up; emit one value per sample.
    for (slot, &price) in out[start_out..n].iter_mut().zip(&data[start_out..n]) {
        ema1 = ema1 * per1 + price * per;
        ema2 = ema2 * per1 + ema1 * per;
        ema3 = ema3 * per1 + ema2 * per;

        *slot = 3.0 * ema1 - 3.0 * ema2 + ema3;
    }
}
298
/// AVX-512 entry point; always forwards to `tema_avx512_long`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn tema_avx512(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    // SAFETY: `tema_avx512_long` only calls the safe `tema_scalar` and uses no
    // target-specific intrinsics in this file.
    unsafe { tema_avx512_long(data, period, first_valid, out) }
}
304
/// AVX2 entry point; currently a plain call to `tema_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn tema_avx2(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    tema_scalar(data, period, first_valid, out)
}
310
/// AVX-512 variant for short periods; currently a plain call to `tema_scalar`.
///
/// # Safety
/// The body performs no unsafe operation; the `unsafe` marker is part of the
/// existing signature only.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn tema_avx512_short(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    tema_scalar(data, period, first_valid, out)
}
316
/// AVX-512 variant for long periods; currently a plain call to `tema_scalar`.
///
/// # Safety
/// The body performs no unsafe operation; the `unsafe` marker is part of the
/// existing signature only.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn tema_avx512_long(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    tema_scalar(data, period, first_valid, out)
}
322
/// Incremental TEMA calculator that consumes one sample per `update` call.
#[derive(Debug, Clone)]
pub struct TemaStream {
    // Smoothing period; `try_new` rejects zero.
    period: usize,

    // Smoothing factor `2 / (period + 1)` and its complement.
    alpha: f64,
    one_minus_alpha: f64,

    // The three cascaded averages. `ema1` starts as NaN and is seeded from the
    // first non-NaN sample.
    ema1: f64,
    ema2: f64,
    ema3: f64,

    // Samples consumed during warm-up, counting from the first non-NaN one;
    // no longer advanced once `ready` is set.
    count: usize,

    // `count` value at which `ema2` is seeded (= period).
    start2: usize,
    // `count` value at which `ema3` is seeded (= 2 * period - 1).
    start3: usize,
    // `3 * (period - 1)`; outputs start once `count` exceeds it.
    lookback: usize,

    // Set once the first output has been produced.
    ready: bool,
}
342
343impl TemaStream {
344    #[inline]
345    pub fn try_new(params: TemaParams) -> Result<Self, TemaError> {
346        let period = params.period.unwrap_or(9);
347        if period == 0 {
348            return Err(TemaError::InvalidPeriod {
349                period,
350                data_len: 0,
351            });
352        }
353
354        let alpha = 2.0 / (period as f64 + 1.0);
355        let one_minus_alpha = 1.0 - alpha;
356
357        let start2 = period;
358
359        let start3 = period + period - 1;
360
361        let lookback = (period - 1) * 3;
362
363        Ok(Self {
364            period,
365            alpha,
366            one_minus_alpha,
367            ema1: f64::NAN,
368            ema2: 0.0,
369            ema3: 0.0,
370            count: 0,
371            start2,
372            start3,
373            lookback,
374            ready: false,
375        })
376    }
377
378    #[inline(always)]
379    pub fn update(&mut self, x: f64) -> Option<f64> {
380        if self.ready {
381            self.ema1 = self.ema1 * self.one_minus_alpha + x * self.alpha;
382            self.ema2 = self.ema2 * self.one_minus_alpha + self.ema1 * self.alpha;
383            self.ema3 = self.ema3 * self.one_minus_alpha + self.ema2 * self.alpha;
384            return Some(3.0 * self.ema1 - 3.0 * self.ema2 + self.ema3);
385        }
386
387        if self.count == 0 {
388            if x.is_nan() {
389                return None;
390            }
391            self.ema1 = x;
392            self.count = 1;
393
394            if self.period == 1 {
395                self.ema2 = self.ema1;
396                self.ema3 = self.ema2;
397                self.ready = true;
398                return Some(self.ema1);
399            }
400            return None;
401        }
402
403        self.ema1 = self.ema1 * self.one_minus_alpha + x * self.alpha;
404        self.count += 1;
405
406        if self.count == self.start2 {
407            self.ema2 = self.ema1;
408        } else if self.count > self.start2 {
409            self.ema2 = self.ema2 * self.one_minus_alpha + self.ema1 * self.alpha;
410        }
411
412        if self.count == self.start3 {
413            self.ema3 = self.ema2;
414        } else if self.count > self.start3 {
415            self.ema3 = self.ema3 * self.one_minus_alpha + self.ema2 * self.alpha;
416        }
417
418        if self.count > self.lookback {
419            self.ready = true;
420            Some(3.0 * self.ema1 - 3.0 * self.ema2 + self.ema3)
421        } else {
422            None
423        }
424    }
425
426    #[inline(always)]
427    pub fn update_unchecked(&mut self, x: f64) -> f64 {
428        debug_assert!(self.ready, "call update() until Some(..) first");
429        self.ema1 = self.ema1 * self.one_minus_alpha + x * self.alpha;
430        self.ema2 = self.ema2 * self.one_minus_alpha + self.ema1 * self.alpha;
431        self.ema3 = self.ema3 * self.one_minus_alpha + self.ema2 * self.alpha;
432        3.0 * self.ema1 - 3.0 * self.ema2 + self.ema3
433    }
434}
435
/// Period sweep for the batch API.
#[derive(Clone, Debug)]
pub struct TemaBatchRange {
    /// `(start, end, step)`. A zero step or `start == end` yields the single
    /// period `start`; `start > end` is walked downwards.
    pub period: (usize, usize, usize),
}
440
impl Default for TemaBatchRange {
    /// Default sweep: periods 9 through 258 in steps of 1.
    fn default() -> Self {
        Self {
            period: (9, 258, 1),
        }
    }
}
448
/// Builder for batch TEMA runs over a range of periods.
#[derive(Clone, Debug, Default)]
pub struct TemaBatchBuilder {
    // Period sweep handed to `tema_batch_with_kernel`.
    range: TemaBatchRange,
    // Kernel handed to `tema_batch_with_kernel`; the derived default is `Kernel::default()`.
    kernel: Kernel,
}
454
455impl TemaBatchBuilder {
456    pub fn new() -> Self {
457        Self::default()
458    }
459    pub fn kernel(mut self, k: Kernel) -> Self {
460        self.kernel = k;
461        self
462    }
463    #[inline]
464    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
465        self.range.period = (start, end, step);
466        self
467    }
468    #[inline]
469    pub fn period_static(mut self, p: usize) -> Self {
470        self.range.period = (p, p, 0);
471        self
472    }
473    pub fn apply_slice(self, data: &[f64]) -> Result<TemaBatchOutput, TemaError> {
474        tema_batch_with_kernel(data, &self.range, self.kernel)
475    }
476    pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<TemaBatchOutput, TemaError> {
477        TemaBatchBuilder::new().kernel(k).apply_slice(data)
478    }
479    pub fn apply_candles(self, c: &Candles, src: &str) -> Result<TemaBatchOutput, TemaError> {
480        let slice = source_type(c, src);
481        self.apply_slice(slice)
482    }
483    pub fn with_default_candles(c: &Candles) -> Result<TemaBatchOutput, TemaError> {
484        TemaBatchBuilder::new()
485            .kernel(Kernel::Auto)
486            .apply_candles(c, "close")
487    }
488}
489
490pub fn tema_batch_with_kernel(
491    data: &[f64],
492    sweep: &TemaBatchRange,
493    k: Kernel,
494) -> Result<TemaBatchOutput, TemaError> {
495    let kernel = match k {
496        Kernel::Auto => detect_best_batch_kernel(),
497        other if other.is_batch() => other,
498        other => return Err(TemaError::InvalidKernelForBatch(other)),
499    };
500    let simd = match kernel {
501        Kernel::Avx512Batch => Kernel::Avx512,
502        Kernel::Avx2Batch => Kernel::Avx2,
503        Kernel::ScalarBatch => Kernel::Scalar,
504        _ => unreachable!(),
505    };
506    tema_batch_par_slice(data, sweep, simd)
507}
508
/// Result of a batch run: one row of values per parameter combination.
#[derive(Clone, Debug)]
pub struct TemaBatchOutput {
    /// Row-major matrix of `rows * cols` values.
    pub values: Vec<f64>,
    /// Parameters of each row, in row order.
    pub combos: Vec<TemaParams>,
    /// Number of parameter combinations.
    pub rows: usize,
    /// Length of the input series.
    pub cols: usize,
}
516impl TemaBatchOutput {
517    pub fn row_for_params(&self, p: &TemaParams) -> Option<usize> {
518        self.combos
519            .iter()
520            .position(|c| c.period.unwrap_or(9) == p.period.unwrap_or(9))
521    }
522    pub fn values_for(&self, p: &TemaParams) -> Option<&[f64]> {
523        self.row_for_params(p).map(|row| {
524            let start = row * self.cols;
525            &self.values[start..start + self.cols]
526        })
527    }
528}
529
530#[inline(always)]
531fn expand_grid(r: &TemaBatchRange) -> Result<Vec<TemaParams>, TemaError> {
532    fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, TemaError> {
533        if step == 0 || start == end {
534            return Ok(vec![start]);
535        }
536
537        if start > end {
538            let mut v = Vec::new();
539            let mut cur = start;
540            while cur >= end {
541                v.push(cur);
542
543                if let Some(next) = cur.checked_sub(step) {
544                    if next == cur {
545                        break;
546                    }
547                    cur = next;
548                } else {
549                    break;
550                }
551                if cur == usize::MAX {
552                    break;
553                }
554                if cur < end {
555                    break;
556                }
557            }
558            if v.is_empty() {
559                return Err(TemaError::InvalidRange { start, end, step });
560            }
561            return Ok(v);
562        }
563        let it = (start..=end).step_by(step);
564        let v: Vec<usize> = it.collect();
565        if v.is_empty() {
566            return Err(TemaError::InvalidRange { start, end, step });
567        }
568        Ok(v)
569    }
570    let periods = axis_usize(r.period)?;
571    let mut out = Vec::with_capacity(periods.len());
572    for &p in &periods {
573        out.push(TemaParams { period: Some(p) });
574    }
575    Ok(out)
576}
577
/// Batch TEMA over `data` with rows evaluated sequentially
/// (calls `tema_batch_inner` with `parallel = false`).
#[inline(always)]
pub fn tema_batch_slice(
    data: &[f64],
    sweep: &TemaBatchRange,
    kern: Kernel,
) -> Result<TemaBatchOutput, TemaError> {
    tema_batch_inner(data, sweep, kern, false)
}
586
/// Batch TEMA over `data` with `parallel = true`: rows are spread over rayon
/// workers, except on `wasm32` where they are still evaluated sequentially.
#[inline(always)]
pub fn tema_batch_par_slice(
    data: &[f64],
    sweep: &TemaBatchRange,
    kern: Kernel,
) -> Result<TemaBatchOutput, TemaError> {
    tema_batch_inner(data, sweep, kern, true)
}
595
596#[inline(always)]
597fn tema_batch_inner(
598    data: &[f64],
599    sweep: &TemaBatchRange,
600    kern: Kernel,
601    parallel: bool,
602) -> Result<TemaBatchOutput, TemaError> {
603    let combos = expand_grid(sweep)?;
604    if combos.is_empty() {
605        return Err(TemaError::InvalidRange {
606            start: sweep.period.0,
607            end: sweep.period.1,
608            step: sweep.period.2,
609        });
610    }
611
612    let first = data
613        .iter()
614        .position(|x| !x.is_nan())
615        .ok_or(TemaError::AllValuesNaN)?;
616    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
617    if data.len() - first < max_p {
618        return Err(TemaError::NotEnoughValidData {
619            needed: max_p,
620            valid: data.len() - first,
621        });
622    }
623
624    let rows = combos.len();
625    let cols = data.len();
626
627    let _total = rows
628        .checked_mul(cols)
629        .ok_or_else(|| TemaError::InvalidInput("rows * cols overflow".into()))?;
630
631    let actual_kern = match kern {
632        Kernel::Auto => detect_best_batch_kernel(),
633        k => k,
634    };
635
636    let warm: Vec<usize> = combos
637        .iter()
638        .map(|c| (first + (c.period.unwrap() - 1) * 3).min(cols))
639        .collect();
640
641    let mut raw = make_uninit_matrix(rows, cols);
642    unsafe { init_matrix_prefixes(&mut raw, cols, &warm) };
643
644    let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
645        let period = combos[row].period.unwrap();
646
647        let out_row =
648            core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
649
650        match actual_kern {
651            Kernel::Scalar | Kernel::ScalarBatch => tema_row_scalar(data, first, period, out_row),
652            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
653            Kernel::Avx2 | Kernel::Avx2Batch => tema_row_avx2(data, first, period, out_row),
654            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
655            Kernel::Avx512 | Kernel::Avx512Batch => tema_row_avx512(data, first, period, out_row),
656            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
657            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
658                tema_row_scalar(data, first, period, out_row)
659            }
660            Kernel::Auto => unreachable!("Auto kernel should have been resolved"),
661        }
662    };
663
664    if parallel {
665        #[cfg(not(target_arch = "wasm32"))]
666        {
667            raw.par_chunks_mut(cols)
668                .enumerate()
669                .for_each(|(row, slice)| do_row(row, slice));
670        }
671
672        #[cfg(target_arch = "wasm32")]
673        {
674            for (row, slice) in raw.chunks_mut(cols).enumerate() {
675                do_row(row, slice);
676            }
677        }
678    } else {
679        for (row, slice) in raw.chunks_mut(cols).enumerate() {
680            do_row(row, slice);
681        }
682    }
683
684    let mut guard = core::mem::ManuallyDrop::new(raw);
685    let values: Vec<f64> = unsafe {
686        Vec::from_raw_parts(
687            guard.as_mut_ptr() as *mut f64,
688            guard.len(),
689            guard.capacity(),
690        )
691    };
692
693    Ok(TemaBatchOutput {
694        values,
695        combos,
696        rows,
697        cols,
698    })
699}
700
/// Computes one batch row with the scalar kernel.
///
/// Note the argument order: `first` comes before `period` here, the reverse
/// of `tema_scalar`.
///
/// # Safety
/// The body performs no unsafe operation; the `unsafe` marker is part of the
/// existing signature only.
#[inline(always)]
unsafe fn tema_row_scalar(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    tema_scalar(data, period, first, out)
}
705
/// AVX2 batch row; currently a plain call to `tema_scalar`.
///
/// # Safety
/// The body performs no unsafe operation; the `unsafe` marker is part of the
/// existing signature only.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
unsafe fn tema_row_avx2(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    tema_scalar(data, period, first, out)
}
711
/// AVX-512 batch row: picks the short variant for periods up to 32 and the
/// long variant otherwise. Both currently forward to `tema_scalar`.
///
/// # Safety
/// Neither callee performs an unsafe operation in this file; the `unsafe`
/// marker is part of the existing signature only.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
unsafe fn tema_row_avx512(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    if period <= 32 {
        tema_row_avx512_short(data, first, period, out)
    } else {
        tema_row_avx512_long(data, first, period, out)
    }
}
721
/// AVX-512 batch row for short periods; currently a plain call to `tema_scalar`.
///
/// # Safety
/// The body performs no unsafe operation; the `unsafe` marker is part of the
/// existing signature only.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
unsafe fn tema_row_avx512_short(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    tema_scalar(data, period, first, out)
}
727
/// AVX-512 batch row for long periods; currently a plain call to `tema_scalar`.
///
/// # Safety
/// The body performs no unsafe operation; the `unsafe` marker is part of the
/// existing signature only.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
unsafe fn tema_row_avx512_long(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    tema_scalar(data, period, first, out)
}
733
734#[cfg(test)]
735mod tests {
736    use super::*;
737    use crate::skip_if_unsupported;
738    use crate::utilities::data_loader::read_candles_from_csv;
739
740    #[test]
741    fn test_tema_into_matches_api() -> Result<(), Box<dyn Error>> {
742        let len = 256usize;
743        let mut data = vec![f64::NAN; len];
744        for i in 2..len {
745            let x = i as f64;
746            data[i] = (x * 0.31337).sin() * 10.0 + x * 0.01;
747        }
748
749        let input = TemaInput::from_slice(&data, TemaParams::default());
750
751        let baseline = tema(&input)?.values;
752
753        let mut out = vec![0.0; len];
754        #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
755        {
756            tema_into(&input, &mut out)?;
757        }
758        #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
759        {
760            tema_into_slice(&mut out, &input, Kernel::Auto)?;
761        }
762
763        assert_eq!(baseline.len(), out.len());
764
765        fn eq_or_both_nan(a: f64, b: f64) -> bool {
766            (a.is_nan() && b.is_nan()) || (a - b).abs() <= 1e-12
767        }
768
769        for (i, (&a, &b)) in baseline.iter().zip(out.iter()).enumerate() {
770            assert!(
771                eq_or_both_nan(a, b),
772                "parity mismatch at idx {}: api={} into={} diff={}",
773                i,
774                a,
775                b,
776                (a - b).abs()
777            );
778        }
779
780        Ok(())
781    }
782
783    fn check_tema_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
784        skip_if_unsupported!(kernel, test_name);
785        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
786        let candles = read_candles_from_csv(file_path)?;
787        let default_params = TemaParams { period: None };
788        let input = TemaInput::from_candles(&candles, "close", default_params);
789        let output = tema_with_kernel(&input, kernel)?;
790        assert_eq!(output.values.len(), candles.close.len());
791        Ok(())
792    }
793    fn check_tema_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
794        skip_if_unsupported!(kernel, test_name);
795        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
796        let candles = read_candles_from_csv(file_path)?;
797        let input = TemaInput::from_candles(&candles, "close", TemaParams::default());
798        let result = tema_with_kernel(&input, kernel)?;
799        let expected_last_five = [
800            59281.895570662884,
801            59257.25021607971,
802            59172.23342859784,
803            59175.218345941066,
804            58934.24395798363,
805        ];
806        let start = result.values.len().saturating_sub(5);
807        for (i, &val) in result.values[start..].iter().enumerate() {
808            let diff = (val - expected_last_five[i]).abs();
809            assert!(
810                diff < 1e-8,
811                "[{}] TEMA {:?} mismatch at idx {}: got {}, expected {}",
812                test_name,
813                kernel,
814                i,
815                val,
816                expected_last_five[i]
817            );
818        }
819        Ok(())
820    }
821    fn check_tema_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
822        skip_if_unsupported!(kernel, test_name);
823        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
824        let candles = read_candles_from_csv(file_path)?;
825        let input = TemaInput::with_default_candles(&candles);
826        match input.data {
827            TemaData::Candles { source, .. } => assert_eq!(source, "close"),
828            _ => panic!("Expected TemaData::Candles"),
829        }
830        let output = tema_with_kernel(&input, kernel)?;
831        assert_eq!(output.values.len(), candles.close.len());
832        Ok(())
833    }
834    fn check_tema_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
835        skip_if_unsupported!(kernel, test_name);
836        let input_data = [10.0, 20.0, 30.0];
837        let params = TemaParams { period: Some(0) };
838        let input = TemaInput::from_slice(&input_data, params);
839        let res = tema_with_kernel(&input, kernel);
840        assert!(
841            res.is_err(),
842            "[{}] TEMA should fail with zero period",
843            test_name
844        );
845        Ok(())
846    }
847    fn check_tema_empty_input(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
848        skip_if_unsupported!(kernel, test_name);
849        let input_data: [f64; 0] = [];
850        let params = TemaParams { period: Some(9) };
851        let input = TemaInput::from_slice(&input_data, params);
852        let res = tema_with_kernel(&input, kernel);
853        assert!(
854            res.is_err(),
855            "[{}] TEMA should fail with empty input",
856            test_name
857        );
858        if let Err(e) = res {
859            assert!(
860                matches!(e, TemaError::EmptyInputData),
861                "[{}] Expected EmptyInputData error",
862                test_name
863            );
864        }
865        Ok(())
866    }
867    fn check_tema_period_exceeds_length(
868        test_name: &str,
869        kernel: Kernel,
870    ) -> Result<(), Box<dyn Error>> {
871        skip_if_unsupported!(kernel, test_name);
872        let data_small = [10.0, 20.0, 30.0];
873        let params = TemaParams { period: Some(10) };
874        let input = TemaInput::from_slice(&data_small, params);
875        let res = tema_with_kernel(&input, kernel);
876        assert!(
877            res.is_err(),
878            "[{}] TEMA should fail with period exceeding length",
879            test_name
880        );
881        Ok(())
882    }
883    fn check_tema_very_small_dataset(
884        test_name: &str,
885        kernel: Kernel,
886    ) -> Result<(), Box<dyn Error>> {
887        skip_if_unsupported!(kernel, test_name);
888        let single_point = [42.0];
889        let params = TemaParams { period: Some(9) };
890        let input = TemaInput::from_slice(&single_point, params);
891        let res = tema_with_kernel(&input, kernel);
892        assert!(
893            res.is_err(),
894            "[{}] TEMA should fail with insufficient data",
895            test_name
896        );
897        Ok(())
898    }
899    fn check_tema_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
900        skip_if_unsupported!(kernel, test_name);
901        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
902        let candles = read_candles_from_csv(file_path)?;
903        let first_params = TemaParams { period: Some(9) };
904        let first_input = TemaInput::from_candles(&candles, "close", first_params);
905        let first_result = tema_with_kernel(&first_input, kernel)?;
906        let second_params = TemaParams { period: Some(9) };
907        let second_input = TemaInput::from_slice(&first_result.values, second_params);
908        let second_result = tema_with_kernel(&second_input, kernel)?;
909        assert_eq!(second_result.values.len(), first_result.values.len());
910        Ok(())
911    }
912    fn check_tema_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
913        skip_if_unsupported!(kernel, test_name);
914        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
915        let candles = read_candles_from_csv(file_path)?;
916        let input = TemaInput::from_candles(&candles, "close", TemaParams { period: Some(9) });
917        let res = tema_with_kernel(&input, kernel)?;
918        assert_eq!(res.values.len(), candles.close.len());
919        if res.values.len() > 50 {
920            for (i, &val) in res.values[50..].iter().enumerate() {
921                assert!(
922                    !val.is_nan(),
923                    "[{}] Found unexpected NaN at out-index {}",
924                    test_name,
925                    50 + i
926                );
927            }
928        }
929        Ok(())
930    }
931    fn check_tema_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
932        skip_if_unsupported!(kernel, test_name);
933        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
934        let candles = read_candles_from_csv(file_path)?;
935        let period = 9;
936        let input = TemaInput::from_candles(
937            &candles,
938            "close",
939            TemaParams {
940                period: Some(period),
941            },
942        );
943        let batch_output = tema_with_kernel(&input, kernel)?.values;
944        let mut stream = TemaStream::try_new(TemaParams {
945            period: Some(period),
946        })?;
947        let mut stream_values = Vec::with_capacity(candles.close.len());
948        for &price in &candles.close {
949            match stream.update(price) {
950                Some(val) => stream_values.push(val),
951                None => stream_values.push(f64::NAN),
952            }
953        }
954        assert_eq!(batch_output.len(), stream_values.len());
955        for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
956            if b.is_nan() && s.is_nan() {
957                continue;
958            }
959            let diff = (b - s).abs();
960            assert!(
961                diff < 1e-9,
962                "[{}] TEMA streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
963                test_name,
964                i,
965                b,
966                s,
967                diff
968            );
969        }
970        Ok(())
971    }
972
973    macro_rules! generate_all_tema_tests {
974        ($($test_fn:ident),*) => {
975            paste::paste! {
976                $(
977                    #[test]
978                    fn [<$test_fn _scalar_f64>]() {
979                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
980                    }
981                )*
982                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
983                $(
984                    #[test]
985                    fn [<$test_fn _avx2_f64>]() {
986                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
987                    }
988                    #[test]
989                    fn [<$test_fn _avx512_f64>]() {
990                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
991                    }
992                )*
993            }
994        }
995    }
996
997    #[cfg(debug_assertions)]
998    fn check_tema_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
999        skip_if_unsupported!(kernel, test_name);
1000
1001        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1002        let candles = read_candles_from_csv(file_path)?;
1003
1004        let test_periods = vec![5, 9, 14, 20, 50, 100, 200];
1005
1006        for &period in &test_periods {
1007            let params = TemaParams {
1008                period: Some(period),
1009            };
1010            let input = TemaInput::from_candles(&candles, "close", params);
1011
1012            if candles.close.len() < period {
1013                continue;
1014            }
1015
1016            let output = match tema_with_kernel(&input, kernel) {
1017                Ok(o) => o,
1018                Err(_) => continue,
1019            };
1020
1021            for (i, &val) in output.values.iter().enumerate() {
1022                if val.is_nan() {
1023                    continue;
1024                }
1025
1026                let bits = val.to_bits();
1027
1028                if bits == 0x11111111_11111111 {
1029                    panic!(
1030						"[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} with period {}",
1031						test_name, val, bits, i, period
1032					);
1033                }
1034
1035                if bits == 0x22222222_22222222 {
1036                    panic!(
1037						"[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} with period {}",
1038						test_name, val, bits, i, period
1039					);
1040                }
1041
1042                if bits == 0x33333333_33333333 {
1043                    panic!(
1044						"[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} with period {}",
1045						test_name, val, bits, i, period
1046					);
1047                }
1048            }
1049        }
1050
1051        Ok(())
1052    }
1053
1054    #[cfg(not(debug_assertions))]
1055    fn check_tema_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1056        Ok(())
1057    }
1058
1059    #[cfg(feature = "proptest")]
1060    fn check_tema_property(
1061        test_name: &str,
1062        kernel: Kernel,
1063    ) -> Result<(), Box<dyn std::error::Error>> {
1064        use proptest::prelude::*;
1065        skip_if_unsupported!(kernel, test_name);
1066
1067        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1068        let candles = read_candles_from_csv(file_path)?;
1069        let close_data = &candles.close;
1070
1071        let strat = (
1072            2usize..=50,
1073            0usize..close_data.len().saturating_sub(200),
1074            100usize..=200,
1075        );
1076
1077        proptest::test_runner::TestRunner::default()
1078            .run(&strat, |(period, start_idx, slice_len)| {
1079                let end_idx = (start_idx + slice_len).min(close_data.len());
1080                if end_idx <= start_idx || end_idx - start_idx < period * 3 + 10 {
1081                    return Ok(());
1082                }
1083
1084                let data_slice = &close_data[start_idx..end_idx];
1085                let params = TemaParams {
1086                    period: Some(period),
1087                };
1088                let input = TemaInput::from_slice(data_slice, params);
1089
1090                let result = tema_with_kernel(&input, kernel);
1091
1092                let scalar_result = tema_with_kernel(&input, Kernel::Scalar);
1093
1094                match (result, scalar_result) {
1095                    (Ok(TemaOutput { values: out }), Ok(TemaOutput { values: ref_out })) => {
1096                        prop_assert_eq!(out.len(), data_slice.len());
1097                        prop_assert_eq!(ref_out.len(), data_slice.len());
1098
1099                        let first = data_slice.iter().position(|x| !x.is_nan()).unwrap_or(0);
1100                        let lookback = (period - 1) * 3;
1101                        let expected_warmup = first + lookback;
1102
1103                        for i in 0..expected_warmup.min(out.len()) {
1104                            prop_assert!(
1105                                out[i].is_nan(),
1106                                "Expected NaN at index {} during warmup, got {}",
1107                                i,
1108                                out[i]
1109                            );
1110                        }
1111
1112                        let multiplier = 2.0 / (period as f64 + 1.0);
1113                        prop_assert!(
1114                            multiplier > 0.0 && multiplier <= 1.0,
1115                            "EMA multiplier should be in (0, 1]: {}",
1116                            multiplier
1117                        );
1118
1119                        for i in expected_warmup..out.len() {
1120                            let y = out[i];
1121                            let r = ref_out[i];
1122
1123                            prop_assert!(!y.is_nan(), "Unexpected NaN at index {}", i);
1124                            prop_assert!(y.is_finite(), "Non-finite value at index {}: {}", i, y);
1125
1126                            let y_bits = y.to_bits();
1127                            let r_bits = r.to_bits();
1128
1129                            if !y.is_finite() || !r.is_finite() {
1130                                prop_assert_eq!(
1131                                    y_bits,
1132                                    r_bits,
1133                                    "NaN/Inf mismatch at {}: {} vs {}",
1134                                    i,
1135                                    y,
1136                                    r
1137                                );
1138                                continue;
1139                            }
1140
1141                            let ulp_diff: u64 = y_bits.abs_diff(r_bits);
1142                            prop_assert!(
1143                                (y - r).abs() <= 1e-9 || ulp_diff <= 5,
1144                                "Kernel mismatch at {}: {} vs {} (ULP={})",
1145                                i,
1146                                y,
1147                                r,
1148                                ulp_diff
1149                            );
1150                        }
1151
1152                        let const_value = 100.0;
1153                        let const_data = vec![const_value; period * 4];
1154                        let const_input = TemaInput::from_slice(
1155                            &const_data,
1156                            TemaParams {
1157                                period: Some(period),
1158                            },
1159                        );
1160                        if let Ok(TemaOutput { values: const_out }) =
1161                            tema_with_kernel(&const_input, kernel)
1162                        {
1163                            let const_warmup = lookback;
1164                            for (i, &val) in const_out.iter().enumerate() {
1165                                if i >= const_warmup && !val.is_nan() {
1166                                    prop_assert!(
1167										(val - const_value).abs() < 1e-9,
1168										"TEMA of constant data should equal the constant at {}: got {}",
1169										i, val
1170									);
1171                                }
1172                            }
1173                        }
1174
1175                        if period <= 20 {
1176                            let mut stream = TemaStream::try_new(TemaParams {
1177                                period: Some(period),
1178                            })
1179                            .unwrap();
1180                            let mut stream_values = Vec::with_capacity(data_slice.len());
1181
1182                            for &price in data_slice {
1183                                match stream.update(price) {
1184                                    Some(val) => stream_values.push(val),
1185                                    None => stream_values.push(f64::NAN),
1186                                }
1187                            }
1188
1189                            for (i, (&batch_val, &stream_val)) in
1190                                out.iter().zip(stream_values.iter()).enumerate()
1191                            {
1192                                if batch_val.is_nan() && stream_val.is_nan() {
1193                                    continue;
1194                                }
1195                                if !batch_val.is_nan() && !stream_val.is_nan() {
1196                                    prop_assert!(
1197                                        (batch_val - stream_val).abs() < 1e-9,
1198                                        "Streaming mismatch at {}: batch={} vs stream={}",
1199                                        i,
1200                                        batch_val,
1201                                        stream_val
1202                                    );
1203                                }
1204                            }
1205                        }
1206
1207                        if data_slice.len() > period * 2 {
1208                            let trend_start = expected_warmup;
1209                            let trend_end = (trend_start + period).min(data_slice.len());
1210
1211                            if trend_end > trend_start + 3 {
1212                                let trend_data = &data_slice[trend_start..trend_end];
1213                                let is_uptrend =
1214                                    trend_data.windows(2).filter(|w| w[1] > w[0]).count()
1215                                        > trend_data.windows(2).filter(|w| w[1] < w[0]).count();
1216
1217                                if is_uptrend {
1218                                    let last_price = data_slice[trend_end - 1];
1219                                    let tema_value = out[trend_end - 1];
1220
1221                                    let price_range = trend_data
1222                                        .iter()
1223                                        .cloned()
1224                                        .fold(f64::NEG_INFINITY, f64::max)
1225                                        - trend_data.iter().cloned().fold(f64::INFINITY, f64::min);
1226                                    prop_assert!(
1227										(tema_value - last_price).abs() < price_range * 1.5,
1228										"TEMA diverged too much from price: TEMA={}, price={}, range={}",
1229										tema_value, last_price, price_range
1230									);
1231                                }
1232                            }
1233                        }
1234                    }
1235                    (Err(e1), Err(e2)) => {
1236                        prop_assert_eq!(
1237                            std::mem::discriminant(&e1),
1238                            std::mem::discriminant(&e2),
1239                            "Different error types: {:?} vs {:?}",
1240                            e1,
1241                            e2
1242                        );
1243                    }
1244                    _ => {
1245                        prop_assert!(
1246                            false,
1247                            "Kernel consistency failure: one succeeded, one failed"
1248                        );
1249                    }
1250                }
1251
1252                Ok(())
1253            })
1254            .unwrap();
1255
1256        Ok(())
1257    }
1258
    // Instantiate the per-kernel `#[test]` wrappers for each listed check function.
    generate_all_tema_tests!(
        check_tema_partial_params,
        check_tema_accuracy,
        check_tema_default_candles,
        check_tema_zero_period,
        check_tema_empty_input,
        check_tema_period_exceeds_length,
        check_tema_very_small_dataset,
        check_tema_reinput,
        check_tema_nan_handling,
        check_tema_streaming,
        check_tema_no_poison
    );

    // The property test only exists when the `proptest` feature is enabled, so its
    // wrappers are generated under the same gate.
    #[cfg(feature = "proptest")]
    generate_all_tema_tests!(check_tema_property);
1275
1276    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1277        skip_if_unsupported!(kernel, test);
1278        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1279        let c = read_candles_from_csv(file)?;
1280        let output = TemaBatchBuilder::new()
1281            .kernel(kernel)
1282            .apply_candles(&c, "close")?;
1283        let def = TemaParams::default();
1284        let row = output.values_for(&def).expect("default row missing");
1285        assert_eq!(row.len(), c.close.len());
1286        let expected = [
1287            59281.895570662884,
1288            59257.25021607971,
1289            59172.23342859784,
1290            59175.218345941066,
1291            58934.24395798363,
1292        ];
1293        let start = row.len() - 5;
1294        for (i, &v) in row[start..].iter().enumerate() {
1295            assert!(
1296                (v - expected[i]).abs() < 1e-8,
1297                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
1298            );
1299        }
1300        Ok(())
1301    }
1302
1303    macro_rules! gen_batch_tests {
1304        ($fn_name:ident) => {
1305            paste::paste! {
1306                #[test] fn [<$fn_name _scalar>]()      {
1307                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
1308                }
1309                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1310                #[test] fn [<$fn_name _avx2>]()        {
1311                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
1312                }
1313                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1314                #[test] fn [<$fn_name _avx512>]()      {
1315                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
1316                }
1317                #[test] fn [<$fn_name _auto_detect>]() {
1318                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
1319                }
1320            }
1321        };
1322    }
1323
1324    #[cfg(debug_assertions)]
1325    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1326        skip_if_unsupported!(kernel, test);
1327
1328        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1329        let c = read_candles_from_csv(file)?;
1330
1331        let test_configs = vec![
1332            (5, 15, 2),
1333            (10, 50, 5),
1334            (20, 100, 10),
1335            (50, 200, 25),
1336            (3, 3, 1),
1337            (150, 150, 1),
1338        ];
1339
1340        for (start, end, step) in test_configs {
1341            let output = TemaBatchBuilder::new()
1342                .kernel(kernel)
1343                .period_range(start, end, step)
1344                .apply_candles(&c, "close")?;
1345
1346            for (idx, &val) in output.values.iter().enumerate() {
1347                if val.is_nan() {
1348                    continue;
1349                }
1350
1351                let bits = val.to_bits();
1352                let row = idx / output.cols;
1353                let col = idx % output.cols;
1354                let period = output
1355                    .combos
1356                    .get(row)
1357                    .map(|p| p.period.unwrap_or(0))
1358                    .unwrap_or(0);
1359
1360                if bits == 0x11111111_11111111 {
1361                    panic!(
1362                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at row {} col {} (period {}, flat index {})",
1363                        test, val, bits, row, col, period, idx
1364                    );
1365                }
1366
1367                if bits == 0x22222222_22222222 {
1368                    panic!(
1369                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at row {} col {} (period {}, flat index {})",
1370                        test, val, bits, row, col, period, idx
1371                    );
1372                }
1373
1374                if bits == 0x33333333_33333333 {
1375                    panic!(
1376                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at row {} col {} (period {}, flat index {})",
1377                        test, val, bits, row, col, period, idx
1378                    );
1379                }
1380            }
1381        }
1382
1383        Ok(())
1384    }
1385
1386    #[cfg(not(debug_assertions))]
1387    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1388        Ok(())
1389    }
1390
    // Instantiate the scalar / AVX2 / AVX-512 / auto-detect test wrappers for each batch check.
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
1393}
1394
1395#[inline]
1396fn tema_prepare<'a>(
1397    input: &'a TemaInput,
1398    kernel: Kernel,
1399) -> Result<(&'a [f64], usize, usize, usize, Kernel), TemaError> {
1400    let data: &[f64] = match &input.data {
1401        TemaData::Candles { candles, source } => source_type(candles, source),
1402        TemaData::Slice(sl) => sl,
1403    };
1404
1405    if data.is_empty() {
1406        return Err(TemaError::EmptyInputData);
1407    }
1408
1409    let first = data
1410        .iter()
1411        .position(|x| !x.is_nan())
1412        .ok_or(TemaError::AllValuesNaN)?;
1413    let len = data.len();
1414    let period = input.get_period();
1415
1416    if period == 0 || period > len {
1417        return Err(TemaError::InvalidPeriod {
1418            period,
1419            data_len: len,
1420        });
1421    }
1422    if (len - first) < period {
1423        return Err(TemaError::NotEnoughValidData {
1424            needed: period,
1425            valid: len - first,
1426        });
1427    }
1428
1429    let chosen = match kernel {
1430        Kernel::Auto => Kernel::Scalar,
1431        other => other,
1432    };
1433
1434    Ok((data, period, first, len, chosen))
1435}
1436
1437#[inline]
1438fn tema_compute_into(data: &[f64], period: usize, first: usize, chosen: Kernel, out: &mut [f64]) {
1439    unsafe {
1440        match chosen {
1441            Kernel::Scalar | Kernel::ScalarBatch => tema_scalar(data, period, first, out),
1442            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1443            Kernel::Avx2 | Kernel::Avx2Batch => tema_avx2(data, period, first, out),
1444            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1445            Kernel::Avx512 | Kernel::Avx512Batch => tema_avx512(data, period, first, out),
1446            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
1447            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
1448                tema_scalar(data, period, first, out)
1449            }
1450            Kernel::Auto => unreachable!(),
1451        }
1452    }
1453}
1454
1455#[inline(always)]
1456fn tema_batch_inner_into(
1457    data: &[f64],
1458    sweep: &TemaBatchRange,
1459    kern: Kernel,
1460    parallel: bool,
1461    out: &mut [f64],
1462) -> Result<Vec<TemaParams>, TemaError> {
1463    let combos = expand_grid(sweep)?;
1464    if combos.is_empty() {
1465        return Err(TemaError::InvalidRange {
1466            start: sweep.period.0,
1467            end: sweep.period.1,
1468            step: sweep.period.2,
1469        });
1470    }
1471
1472    if data.is_empty() {
1473        return Err(TemaError::EmptyInputData);
1474    }
1475
1476    let first = data
1477        .iter()
1478        .position(|x| !x.is_nan())
1479        .ok_or(TemaError::AllValuesNaN)?;
1480    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1481    if data.len() - first < max_p {
1482        return Err(TemaError::NotEnoughValidData {
1483            needed: max_p,
1484            valid: data.len() - first,
1485        });
1486    }
1487
1488    let rows = combos.len();
1489    let cols = data.len();
1490    let expected = rows
1491        .checked_mul(cols)
1492        .ok_or_else(|| TemaError::InvalidInput("rows * cols overflow".into()))?;
1493    if out.len() != expected {
1494        return Err(TemaError::OutputLengthMismatch {
1495            expected,
1496            got: out.len(),
1497        });
1498    }
1499
1500    let actual_kern = match kern {
1501        Kernel::Auto => detect_best_batch_kernel(),
1502        k => k,
1503    };
1504
1505    let warm: Vec<usize> = combos
1506        .iter()
1507        .map(|c| (first + (c.period.unwrap() - 1) * 3).min(cols))
1508        .collect();
1509
1510    let out_uninit = unsafe {
1511        std::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
1512    };
1513
1514    unsafe { init_matrix_prefixes(out_uninit, cols, &warm) };
1515
1516    let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
1517        let period = combos[row].period.unwrap();
1518
1519        let out_row =
1520            core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
1521
1522        match actual_kern {
1523            Kernel::Scalar | Kernel::ScalarBatch => tema_row_scalar(data, first, period, out_row),
1524            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1525            Kernel::Avx2 | Kernel::Avx2Batch => tema_row_avx2(data, first, period, out_row),
1526            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1527            Kernel::Avx512 | Kernel::Avx512Batch => tema_row_avx512(data, first, period, out_row),
1528            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
1529            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
1530                tema_row_scalar(data, first, period, out_row)
1531            }
1532            Kernel::Auto => unreachable!("Auto kernel should have been resolved"),
1533        }
1534    };
1535
1536    if parallel {
1537        #[cfg(not(target_arch = "wasm32"))]
1538        {
1539            out_uninit
1540                .par_chunks_mut(cols)
1541                .enumerate()
1542                .for_each(|(row, slice)| do_row(row, slice));
1543        }
1544        #[cfg(target_arch = "wasm32")]
1545        {
1546            for (row, slice) in out_uninit.chunks_mut(cols).enumerate() {
1547                do_row(row, slice);
1548            }
1549        }
1550    } else {
1551        for (row, slice) in out_uninit.chunks_mut(cols).enumerate() {
1552            do_row(row, slice);
1553        }
1554    }
1555
1556    Ok(combos)
1557}
1558
/// Python binding: computes TEMA over a 1-D float64 NumPy array.
///
/// `kernel` is an optional kernel name checked by `validate_kernel`. Contiguous input is
/// used in place; any other layout is first copied into a `Vec` in logical element order.
/// The computation runs with the GIL released.
///
/// Raises `ValueError` when the kernel name is rejected or the TEMA computation fails.
#[cfg(feature = "python")]
#[pyfunction(name = "tema")]
#[pyo3(signature = (data, period, kernel=None))]

pub fn tema_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let kern = validate_kernel(kernel, false)?;

    let params = TemaParams {
        period: Some(period),
    };

    // Copy through the view's iterator instead of `to_owned().as_slice().expect(..)`:
    // iteration always yields elements in logical order, so no layout assumption (and no
    // panic path) is needed for strided or reversed inputs.
    let owned: Vec<f64>;
    let slice_in: &[f64] = match data.as_slice() {
        Ok(contiguous) => contiguous,
        Err(_) => {
            owned = data.as_array().iter().copied().collect();
            &owned
        }
    };

    let tema_in = TemaInput::from_slice(slice_in, params);
    let result_vec: Vec<f64> = py
        .allow_threads(|| tema_with_kernel(&tema_in, kern).map(|o| o.values))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok(result_vec.into_pyarray(py))
}
1590
// Python-facing wrapper (exposed as `TemaStream`) around the Rust `TemaStream`.
#[cfg(feature = "python")]
#[pyclass(name = "TemaStream")]
pub struct TemaStreamPy {
    stream: TemaStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl TemaStreamPy {
    // Builds a stream for `period`; a construction error becomes a Python `ValueError`.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        TemaStream::try_new(TemaParams {
            period: Some(period),
        })
        .map(|stream| TemaStreamPy { stream })
        .map_err(|err| PyValueError::new_err(err.to_string()))
    }

    // Pushes one value and forwards whatever the inner stream returns.
    fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
1614
/// Python binding: computes TEMA for every period in `period_range`
/// (`(start, end, step)`) over a 1-D float64 NumPy array.
///
/// Returns a dict with:
/// * `"values"`: a `(rows, cols)` float64 array, one row per period;
/// * `"periods"`: the period of each row as unsigned 64-bit integers.
///
/// Like `tema`, contiguous input is used in place and any other layout is copied into a
/// `Vec` in logical element order first. The computation runs with the GIL released.
///
/// Raises `ValueError` when the kernel name is rejected, the range cannot be expanded,
/// `rows * cols` overflows, or the batch computation fails.
#[cfg(feature = "python")]
#[pyfunction(name = "tema_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]

pub fn tema_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    // Accept non-contiguous input the same way `tema_py` does instead of rejecting it.
    let owned_in: Vec<f64>;
    let slice_in: &[f64] = match data.as_slice() {
        Ok(contiguous) => contiguous,
        Err(_) => {
            owned_in = data.as_array().iter().copied().collect();
            &owned_in
        }
    };
    let kern = validate_kernel(kernel, true)?;

    let sweep = TemaBatchRange {
        period: period_range,
    };

    // Expand once here only to size the output; the worker expands again and returns the
    // combinations it actually used.
    let combos = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let rows = combos.len();
    let cols = slice_in.len();

    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows * cols overflow"))?;
    // The array is allocated uninitialised and filled in place by the batch routine, which
    // checks that the buffer length equals rows * cols before writing.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            let kernel = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k => k,
            };

            // Map batch kernel variants to their single-series counterparts.
            let simd = match kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                _ => kernel,
            };

            tema_batch_inner_into(slice_in, &sweep, simd, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;

    Ok(dict)
}
1676
// Python binding: runs the TEMA period sweep on a CUDA device over a contiguous 1-D
// float32 array and returns the device-side result wrapper. Fails with `ValueError` when
// CUDA is unavailable or the CUDA call reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "tema_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, device_id=0))]
pub fn tema_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let prices = data_f32.as_slice()?;
    let sweep = TemaBatchRange {
        period: period_range,
    };

    // Device work happens with the GIL released.
    let device_out = py.allow_threads(|| -> Result<_, PyErr> {
        let cuda = CudaTema::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        cuda.tema_batch_dev(prices, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    make_device_array_py(device_id, device_out)
}
1705
// Python binding: runs TEMA with a single `period` on a CUDA device over a contiguous
// 2-D float32 array. The array's second dimension and first dimension are forwarded, in
// that order, to the time-major CUDA routine. Fails with `ValueError` when CUDA is
// unavailable or the CUDA call reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "tema_cuda_many_series_one_param_dev")]
#[pyo3(signature = (prices_tm_f32, period, device_id=0))]
pub fn tema_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    prices_tm_f32: PyReadonlyArray2<'_, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    use numpy::PyUntypedArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let shape = prices_tm_f32.shape();
    let (rows, cols) = (shape[0], shape[1]);

    let prices_flat = prices_tm_f32.as_slice()?;

    // Device work happens with the GIL released.
    let device_out = py.allow_threads(|| -> Result<_, PyErr> {
        let cuda = CudaTema::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        cuda.tema_many_series_one_param_time_major_dev(prices_flat, cols, rows, period)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    make_device_array_py(device_id, device_out)
}
1736
1737#[inline]
1738pub fn tema_into_slice(dst: &mut [f64], input: &TemaInput, kern: Kernel) -> Result<(), TemaError> {
1739    let (data, period, first, len, chosen) = tema_prepare(input, kern)?;
1740    if dst.len() != len {
1741        return Err(TemaError::OutputLengthMismatch {
1742            expected: len,
1743            got: dst.len(),
1744        });
1745    }
1746    tema_compute_into(data, period, first, chosen, dst);
1747    let warm = first + (period - 1) * 3;
1748    let pref = warm.min(len);
1749    for v in &mut dst[..pref] {
1750        *v = f64::NAN;
1751    }
1752    Ok(())
1753}
1754
// WASM binding: validates the input, then computes TEMA over `data` for `period`.
// Returns an all-NaN vector when the warmup span `(period - 1) * 3`, counted from the
// first non-NaN value, reaches or passes the end of the data. Errors are returned as
// string `JsValue`s.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn tema_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let len = data.len();
    if len == 0 {
        return Err(JsValue::from_str("Input data slice is empty"));
    }
    if period == 0 || period > len {
        return Err(JsValue::from_str(&format!(
            "Invalid period: {} (data length: {})",
            period, len
        )));
    }

    // On non-empty data, having no non-NaN element is the same as "all values are NaN".
    let first_valid = match data.iter().position(|x| !x.is_nan()) {
        Some(idx) => idx,
        None => return Err(JsValue::from_str("All values are NaN")),
    };

    let valid_count = len - first_valid;
    if valid_count < period {
        return Err(JsValue::from_str(&format!(
            "Not enough valid data: need {} but only {} valid values after NaN values",
            period, valid_count
        )));
    }

    // The whole series lies inside the warmup window: nothing but NaN to report.
    if period > 1 && first_valid + (period - 1) * 3 >= len {
        return Ok(vec![f64::NAN; len]);
    }

    let input = TemaInput::from_slice(
        data,
        TemaParams {
            period: Some(period),
        },
    );
    let mut output = vec![0.0; len];

    match tema_into_slice(&mut output, &input, Kernel::Auto) {
        Ok(()) => Ok(output),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1801
/// Configuration object accepted by the `tema_batch` WASM export.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Deserialize, Serialize)]
pub struct TemaBatchConfig {
    /// Period sweep, forwarded verbatim into `TemaBatchRange::period`.
    /// The tuple is laid out as `(start, end, step)`.
    pub period_range: (usize, usize, usize),
}
1807
/// Result object serialized back to JavaScript by the `tema_batch` WASM export.
///
/// Every field is copied from the value returned by `tema_batch_inner`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Deserialize, Serialize)]
pub struct TemaBatchJsOutput {
    /// Flattened batch results (presumably `rows * cols` values, one row per
    /// parameter combination; confirm against `tema_batch_inner`).
    pub values: Vec<f64>,
    /// Parameter combinations that were evaluated.
    pub combos: Vec<TemaParams>,
    /// Row count reported by the batch computation.
    pub rows: usize,
    /// Column count reported by the batch computation.
    pub cols: usize,
}
1816
/// WASM batch entry point, exported to JavaScript as `tema_batch`.
///
/// `config` is deserialized into a [`TemaBatchConfig`]; its `period_range` is
/// used as the period sweep. The batch runs through `tema_batch_inner` with
/// `Kernel::Auto`, and the result is returned as a serialized
/// [`TemaBatchJsOutput`].
///
/// # Errors
///
/// Returns a string `JsValue` when `config` cannot be deserialized
/// (`"Invalid config: …"`), when the batch computation fails, or when the
/// result cannot be serialized (`"Serialization error: …"`).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = tema_batch)]
pub fn tema_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let parsed = serde_wasm_bindgen::from_value::<TemaBatchConfig>(config)
        .map_err(|err| JsValue::from_str(&format!("Invalid config: {}", err)))?;

    let range = TemaBatchRange {
        period: parsed.period_range,
    };

    let batch = match tema_batch_inner(data, &range, Kernel::Auto, false) {
        Ok(batch) => batch,
        Err(err) => return Err(JsValue::from_str(&err.to_string())),
    };

    let payload = TemaBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };

    serde_wasm_bindgen::to_value(&payload)
        .map_err(|err| JsValue::from_str(&format!("Serialization error: {}", err)))
}
1840
/// Legacy WASM batch entry point; superseded by `tema_batch`.
///
/// Sweeps the period over `(period_start, period_end, period_step)` using
/// `Kernel::Scalar` and returns only the flattened values of the batch result.
///
/// # Errors
///
/// Any failure of `tema_batch_inner` is returned as a string `JsValue`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
#[deprecated(since = "1.0.0", note = "Use tema_batch instead")]
pub fn tema_batch_js(
    data: &[f64],
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let range = TemaBatchRange {
        period: (period_start, period_end, period_step),
    };

    match tema_batch_inner(data, &range, Kernel::Scalar, false) {
        Ok(batch) => Ok(batch.values),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
1858
/// Allocates an uninitialised buffer with room for `len` `f64` values and
/// hands ownership of it to the caller as a raw pointer.
///
/// The buffer is never dropped on the Rust side; release it by passing the
/// same pointer and the same `len` to `tema_free`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn tema_alloc(len: usize) -> *mut f64 {
    // Wrapping in `ManuallyDrop` suppresses the destructor, so the allocation
    // outlives this call.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1867
/// Releases a buffer previously obtained from `tema_alloc`.
///
/// A null `ptr` is ignored. For any other pointer the caller must uphold the
/// contract below.
///
/// # Safety contract (caller side)
///
/// * `ptr` must have been returned by `tema_alloc(len)` with this same `len`.
/// * The buffer must not have been freed already and must not be used afterwards.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn tema_free(ptr: *mut f64, len: usize) {
    // `Vec::from_raw_parts` requires a pointer that came from the allocator;
    // a null pointer would be undefined behaviour, so treat it as a no-op.
    if ptr.is_null() {
        return;
    }

    // SAFETY: per the contract above, `ptr` owns an allocation with capacity
    // `len`. The vector is rebuilt with length 0 because the buffer contents
    // may be uninitialised; `f64` has no destructor, so only the allocation
    // itself needs to be released.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
1875
/// WASM entry point that computes TEMA from one raw buffer into another.
///
/// Reads `len` values from `in_ptr` and writes `len` values to `out_ptr`.
/// Passing the same pointer for both is supported: the result is then computed
/// into a temporary buffer and copied back.
///
/// When `period > 1`, the input contains at least one non-NaN value, and
/// `3 * (period - 1) >= len`, the output is filled with NaN and the kernel is
/// not run.
///
/// # Errors
///
/// Returns a string `JsValue` when either pointer is null, `len` is zero,
/// `period` is zero or larger than `len`, or `tema_into_slice` fails.
///
/// # Safety contract (caller side)
///
/// Both pointers must be valid for `len` `f64` values. Apart from the exact
/// in-place case (`in_ptr == out_ptr`), the two regions must not overlap.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn tema_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to tema_into"));
    }
    if len == 0 {
        return Err(JsValue::from_str("Input data slice is empty"));
    }
    if period == 0 || period > len {
        return Err(JsValue::from_str(&format!(
            "Invalid period: {} (data length: {})",
            period, len
        )));
    }

    // SAFETY: `in_ptr` is non-null and the caller guarantees `len` readable values.
    let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };

    // Warm-up alone covers the whole series: emit NaN everywhere and stop.
    let warmup_covers_all = period > 1 && (period - 1) * 3 >= len;
    if warmup_covers_all && data.iter().any(|v| !v.is_nan()) {
        // SAFETY: `out_ptr` is non-null and the caller guarantees `len` writable values.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        out.fill(f64::NAN);
        return Ok(());
    }

    let input = TemaInput::from_slice(
        data,
        TemaParams {
            period: Some(period),
        },
    );

    if in_ptr == out_ptr as *const f64 {
        // In-place request: compute into scratch space so the kernel never
        // reads and writes the same memory, then copy the result back.
        let mut scratch = vec![0.0; len];
        tema_into_slice(&mut scratch, &input, Kernel::Auto)
            .map_err(|err| JsValue::from_str(&err.to_string()))?;
        // SAFETY: `out_ptr` is non-null and the caller guarantees `len` writable values.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        out.copy_from_slice(&scratch);
        return Ok(());
    }

    // SAFETY: `out_ptr` is non-null and the caller guarantees `len` writable values.
    let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
    tema_into_slice(out, &input, Kernel::Auto).map_err(|err| JsValue::from_str(&err.to_string()))
}
1928
/// WASM batch entry point that writes a period sweep into a caller-owned buffer.
///
/// Reads `len` values from `in_ptr`, expands the sweep
/// `(period_start, period_end, period_step)` into parameter combinations, and
/// writes `rows * len` values to `out_ptr`, where `rows` is the number of
/// combinations. Returns `rows` on success.
///
/// # Errors
///
/// Returns a string `JsValue` when either pointer is null, when the sweep
/// cannot be expanded, when `rows * len` does not fit in `usize`, or when the
/// batch computation fails.
///
/// # Safety contract (caller side)
///
/// * `in_ptr` must be valid for reading `len` `f64` values.
/// * `out_ptr` must be valid for writing `rows * len` `f64` values.
/// * The input and output regions must not overlap: a shared slice and a
///   mutable slice are alive over them at the same time.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn tema_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to tema_batch_into"));
    }

    // SAFETY: `in_ptr` is non-null and the caller guarantees `len` readable values.
    let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };
    let sweep = TemaBatchRange {
        period: (period_start, period_end, period_step),
    };

    let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let cols = data.len();

    // `usize` is 32 bits on wasm32 and multiplication wraps silently in release
    // builds; a wrapped product would build the output slice with a bogus length.
    let total_size = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("tema_batch_into: rows * cols overflows usize"))?;

    // SAFETY: `out_ptr` is non-null and the caller guarantees `rows * len`
    // writable values that do not overlap the input region.
    let out_slice = unsafe { std::slice::from_raw_parts_mut(out_ptr, total_size) };

    tema_batch_inner_into(data, &sweep, Kernel::Auto, false, out_slice)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    Ok(rows)
}