1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{
4 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
5 make_uninit_matrix,
6};
7#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
8use core::arch::x86_64::*;
9#[cfg(not(target_arch = "wasm32"))]
10use rayon::prelude::*;
11use std::convert::AsRef;
12use std::error::Error;
13use std::mem::MaybeUninit;
14use thiserror::Error;
15
16#[cfg(all(feature = "python", feature = "cuda"))]
17use crate::cuda::cuda_available;
18#[cfg(all(feature = "python", feature = "cuda"))]
19use crate::cuda::moving_averages::CudaTema;
20#[cfg(all(feature = "python", feature = "cuda"))]
21use crate::utilities::dlpack_cuda::{make_device_array_py, DeviceArrayF32Py};
22#[cfg(feature = "python")]
23use crate::utilities::kernel_validation::validate_kernel;
24#[cfg(feature = "python")]
25use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
26#[cfg(all(feature = "python", feature = "cuda"))]
27use numpy::{PyReadonlyArray1, PyReadonlyArray2};
28#[cfg(feature = "python")]
29use pyo3::exceptions::PyValueError;
30#[cfg(feature = "python")]
31use pyo3::prelude::*;
32#[cfg(feature = "python")]
33use pyo3::types::{PyDict, PyList};
34#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
35use serde::{Deserialize, Serialize};
36#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
37use wasm_bindgen::prelude::*;
38
39#[derive(Debug, Clone)]
40pub enum TemaData<'a> {
41 Candles {
42 candles: &'a Candles,
43 source: &'a str,
44 },
45 Slice(&'a [f64]),
46}
47
48#[derive(Debug, Clone)]
49pub struct TemaInput<'a> {
50 pub data: TemaData<'a>,
51 pub params: TemaParams,
52}
53
54impl<'a> TemaInput<'a> {
55 #[inline]
56 pub fn from_candles(candles: &'a Candles, source: &'a str, params: TemaParams) -> Self {
57 Self {
58 data: TemaData::Candles { candles, source },
59 params,
60 }
61 }
62 #[inline]
63 pub fn from_slice(slice: &'a [f64], params: TemaParams) -> Self {
64 Self {
65 data: TemaData::Slice(slice),
66 params,
67 }
68 }
69 #[inline]
70 pub fn with_default_candles(candles: &'a Candles) -> Self {
71 Self::from_candles(candles, "close", TemaParams::default())
72 }
73 #[inline]
74 pub fn get_period(&self) -> usize {
75 self.params.period.unwrap_or(9)
76 }
77}
78
79impl<'a> AsRef<[f64]> for TemaInput<'a> {
80 #[inline(always)]
81 fn as_ref(&self) -> &[f64] {
82 match &self.data {
83 TemaData::Slice(slice) => slice,
84 TemaData::Candles { candles, source } => source_type(candles, source),
85 }
86 }
87}
88
/// Tunable parameters for TEMA.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct TemaParams {
    /// EMA period; consumers in this file treat `None` as 9.
    pub period: Option<usize>,
}

impl Default for TemaParams {
    /// The default configuration uses a period of 9.
    fn default() -> Self {
        let period = Some(9);
        Self { period }
    }
}
103
/// Result of a single-series TEMA computation.
#[derive(Clone, Debug)]
pub struct TemaOutput {
    /// One value per input element.
    pub values: Vec<f64>,
}
108
109#[derive(Debug, Error)]
110pub enum TemaError {
111 #[error("tema: Input data slice is empty.")]
112 EmptyInputData,
113 #[error("tema: All values are NaN.")]
114 AllValuesNaN,
115 #[error("tema: Invalid period: period = {period}, data length = {data_len}")]
116 InvalidPeriod { period: usize, data_len: usize },
117 #[error("tema: Not enough valid data: needed = {needed}, valid = {valid}")]
118 NotEnoughValidData { needed: usize, valid: usize },
119 #[error("tema: Output length mismatch: expected {expected}, got {got}")]
120 OutputLengthMismatch { expected: usize, got: usize },
121 #[error("tema: Invalid range: start={start}, end={end}, step={step}")]
122 InvalidRange {
123 start: usize,
124 end: usize,
125 step: usize,
126 },
127 #[error("tema: Invalid kernel for batch path: {0:?}")]
128 InvalidKernelForBatch(Kernel),
129 #[error("tema: Invalid input: {0}")]
130 InvalidInput(String),
131}
132
133#[derive(Copy, Clone, Debug)]
134pub struct TemaBuilder {
135 period: Option<usize>,
136 kernel: Kernel,
137}
138
139impl Default for TemaBuilder {
140 fn default() -> Self {
141 Self {
142 period: None,
143 kernel: Kernel::Auto,
144 }
145 }
146}
147
148impl TemaBuilder {
149 #[inline(always)]
150 pub fn new() -> Self {
151 Self::default()
152 }
153 #[inline(always)]
154 pub fn period(mut self, n: usize) -> Self {
155 self.period = Some(n);
156 self
157 }
158 #[inline(always)]
159 pub fn kernel(mut self, k: Kernel) -> Self {
160 self.kernel = k;
161 self
162 }
163 #[inline(always)]
164 pub fn apply(self, c: &Candles) -> Result<TemaOutput, TemaError> {
165 let p = TemaParams {
166 period: self.period,
167 };
168 let i = TemaInput::from_candles(c, "close", p);
169 tema_with_kernel(&i, self.kernel)
170 }
171 #[inline(always)]
172 pub fn apply_slice(self, d: &[f64]) -> Result<TemaOutput, TemaError> {
173 let p = TemaParams {
174 period: self.period,
175 };
176 let i = TemaInput::from_slice(d, p);
177 tema_with_kernel(&i, self.kernel)
178 }
179 #[inline(always)]
180 pub fn into_stream(self) -> Result<TemaStream, TemaError> {
181 let p = TemaParams {
182 period: self.period,
183 };
184 TemaStream::try_new(p)
185 }
186}
187
188#[inline]
189pub fn tema(input: &TemaInput) -> Result<TemaOutput, TemaError> {
190 tema_with_kernel(input, Kernel::Auto)
191}
192
193pub fn tema_with_kernel(input: &TemaInput, kernel: Kernel) -> Result<TemaOutput, TemaError> {
194 let (data, period, first, len, chosen) = tema_prepare(input, kernel)?;
195 let lookback = (period - 1) * 3;
196 let warm = first + lookback;
197
198 let mut out = alloc_with_nan_prefix(len, warm);
199 tema_compute_into(data, period, first, chosen, &mut out);
200 Ok(TemaOutput { values: out })
201}
202
203#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
204#[inline]
205pub fn tema_into(input: &TemaInput, out: &mut [f64]) -> Result<(), TemaError> {
206 let (data, period, first, len, chosen) = tema_prepare(input, Kernel::Auto)?;
207
208 if out.len() != len {
209 return Err(TemaError::OutputLengthMismatch {
210 expected: len,
211 got: out.len(),
212 });
213 }
214
215 let warm = first + (period - 1) * 3;
216 let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
217 let pref = warm.min(out.len());
218 for v in &mut out[..pref] {
219 *v = qnan;
220 }
221
222 tema_compute_into(data, period, first, chosen, out);
223
224 Ok(())
225}
226
/// Scalar TEMA kernel: writes `3*EMA1 - 3*EMA2 + EMA3` into `out`.
///
/// * `data` - input series.
/// * `period` - EMA period (`alpha = 2 / (period + 1)`).
/// * `first_val` - index of the first sample to use.
/// * `out` - destination; only indices from `first_val + 3 * (period - 1)`
///   onward are written (for `period == 1`, `data[first_val..]` is copied
///   to the same positions). Earlier entries are left untouched.
///
/// EMA2 is seeded from EMA1 at `first_val + (period - 1)` and EMA3 from EMA2
/// at `first_val + 2 * (period - 1)`. If the series ends before a seeding
/// point, nothing further is computed.
#[inline]
pub fn tema_scalar(data: &[f64], period: usize, first_val: usize, out: &mut [f64]) {
    debug_assert_eq!(data.len(), out.len());

    let n = data.len();
    if n == 0 {
        return;
    }

    let alpha = 2.0 / (period as f64 + 1.0);
    let decay = 1.0 - alpha;
    // One EMA update step; shared by all three stages.
    let step = |prev: f64, x: f64| prev * decay + x * alpha;

    let seed_span = period - 1;
    let ema2_seed = first_val + seed_span;
    let ema3_seed = first_val + (seed_span << 1);
    let first_out = first_val + seed_span * 3;

    if period == 1 {
        out[first_val..n].copy_from_slice(&data[first_val..n]);
        return;
    }

    // Stage 1: only EMA1 is running.
    let mut e1 = data[first_val];
    for price in (first_val..ema2_seed.min(n)).map(|i| data[i]) {
        e1 = step(e1, price);
    }
    if ema2_seed >= n {
        return;
    }

    // Stage 2: seed EMA2 from EMA1, then run both.
    e1 = step(e1, data[ema2_seed]);
    let mut e2 = e1;
    e2 = step(e2, e1);
    for price in ((ema2_seed + 1)..ema3_seed.min(n)).map(|i| data[i]) {
        e1 = step(e1, price);
        e2 = step(e2, e1);
    }
    if ema3_seed >= n {
        return;
    }

    // Stage 3: seed EMA3 from EMA2, then run all three without emitting.
    e1 = step(e1, data[ema3_seed]);
    e2 = step(e2, e1);
    let mut e3 = e2;
    e3 = step(e3, e2);
    for price in ((ema3_seed + 1)..first_out.min(n)).map(|i| data[i]) {
        e1 = step(e1, price);
        e2 = step(e2, e1);
        e3 = step(e3, e2);
    }

    // Steady state: update and emit.
    for i in first_out..n {
        e1 = step(e1, data[i]);
        e2 = step(e2, e1);
        e3 = step(e3, e2);
        out[i] = 3.0 * e1 - 3.0 * e2 + e3;
    }
}
298
/// AVX-512 entry point; currently forwards to `tema_avx512_long`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn tema_avx512(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    // SAFETY: `tema_avx512_long` only delegates to the safe scalar kernel.
    unsafe {
        tema_avx512_long(data, period, first_valid, out);
    }
}
304
/// AVX2 entry point; currently delegates to the scalar kernel.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn tema_avx2(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    tema_scalar(data, period, first_valid, out);
}
310
/// AVX-512 short-period variant; currently delegates to the scalar kernel.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn tema_avx512_short(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    tema_scalar(data, period, first_valid, out);
}
316
/// AVX-512 long-period variant; currently delegates to the scalar kernel.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn tema_avx512_long(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    tema_scalar(data, period, first_valid, out);
}
322
323#[derive(Debug, Clone)]
324pub struct TemaStream {
325 period: usize,
326
327 alpha: f64,
328 one_minus_alpha: f64,
329
330 ema1: f64,
331 ema2: f64,
332 ema3: f64,
333
334 count: usize,
335
336 start2: usize,
337 start3: usize,
338 lookback: usize,
339
340 ready: bool,
341}
342
343impl TemaStream {
344 #[inline]
345 pub fn try_new(params: TemaParams) -> Result<Self, TemaError> {
346 let period = params.period.unwrap_or(9);
347 if period == 0 {
348 return Err(TemaError::InvalidPeriod {
349 period,
350 data_len: 0,
351 });
352 }
353
354 let alpha = 2.0 / (period as f64 + 1.0);
355 let one_minus_alpha = 1.0 - alpha;
356
357 let start2 = period;
358
359 let start3 = period + period - 1;
360
361 let lookback = (period - 1) * 3;
362
363 Ok(Self {
364 period,
365 alpha,
366 one_minus_alpha,
367 ema1: f64::NAN,
368 ema2: 0.0,
369 ema3: 0.0,
370 count: 0,
371 start2,
372 start3,
373 lookback,
374 ready: false,
375 })
376 }
377
378 #[inline(always)]
379 pub fn update(&mut self, x: f64) -> Option<f64> {
380 if self.ready {
381 self.ema1 = self.ema1 * self.one_minus_alpha + x * self.alpha;
382 self.ema2 = self.ema2 * self.one_minus_alpha + self.ema1 * self.alpha;
383 self.ema3 = self.ema3 * self.one_minus_alpha + self.ema2 * self.alpha;
384 return Some(3.0 * self.ema1 - 3.0 * self.ema2 + self.ema3);
385 }
386
387 if self.count == 0 {
388 if x.is_nan() {
389 return None;
390 }
391 self.ema1 = x;
392 self.count = 1;
393
394 if self.period == 1 {
395 self.ema2 = self.ema1;
396 self.ema3 = self.ema2;
397 self.ready = true;
398 return Some(self.ema1);
399 }
400 return None;
401 }
402
403 self.ema1 = self.ema1 * self.one_minus_alpha + x * self.alpha;
404 self.count += 1;
405
406 if self.count == self.start2 {
407 self.ema2 = self.ema1;
408 } else if self.count > self.start2 {
409 self.ema2 = self.ema2 * self.one_minus_alpha + self.ema1 * self.alpha;
410 }
411
412 if self.count == self.start3 {
413 self.ema3 = self.ema2;
414 } else if self.count > self.start3 {
415 self.ema3 = self.ema3 * self.one_minus_alpha + self.ema2 * self.alpha;
416 }
417
418 if self.count > self.lookback {
419 self.ready = true;
420 Some(3.0 * self.ema1 - 3.0 * self.ema2 + self.ema3)
421 } else {
422 None
423 }
424 }
425
426 #[inline(always)]
427 pub fn update_unchecked(&mut self, x: f64) -> f64 {
428 debug_assert!(self.ready, "call update() until Some(..) first");
429 self.ema1 = self.ema1 * self.one_minus_alpha + x * self.alpha;
430 self.ema2 = self.ema2 * self.one_minus_alpha + self.ema1 * self.alpha;
431 self.ema3 = self.ema3 * self.one_minus_alpha + self.ema2 * self.alpha;
432 3.0 * self.ema1 - 3.0 * self.ema2 + self.ema3
433 }
434}
435
/// Parameter sweep for batch TEMA.
#[derive(Debug, Clone)]
pub struct TemaBatchRange {
    /// Period sweep as `(start, end, step)`.
    pub period: (usize, usize, usize),
}

impl Default for TemaBatchRange {
    /// Sweeps periods 9 through 258 in steps of 1.
    fn default() -> Self {
        let period = (9, 258, 1);
        Self { period }
    }
}
448
449#[derive(Clone, Debug, Default)]
450pub struct TemaBatchBuilder {
451 range: TemaBatchRange,
452 kernel: Kernel,
453}
454
455impl TemaBatchBuilder {
456 pub fn new() -> Self {
457 Self::default()
458 }
459 pub fn kernel(mut self, k: Kernel) -> Self {
460 self.kernel = k;
461 self
462 }
463 #[inline]
464 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
465 self.range.period = (start, end, step);
466 self
467 }
468 #[inline]
469 pub fn period_static(mut self, p: usize) -> Self {
470 self.range.period = (p, p, 0);
471 self
472 }
473 pub fn apply_slice(self, data: &[f64]) -> Result<TemaBatchOutput, TemaError> {
474 tema_batch_with_kernel(data, &self.range, self.kernel)
475 }
476 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<TemaBatchOutput, TemaError> {
477 TemaBatchBuilder::new().kernel(k).apply_slice(data)
478 }
479 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<TemaBatchOutput, TemaError> {
480 let slice = source_type(c, src);
481 self.apply_slice(slice)
482 }
483 pub fn with_default_candles(c: &Candles) -> Result<TemaBatchOutput, TemaError> {
484 TemaBatchBuilder::new()
485 .kernel(Kernel::Auto)
486 .apply_candles(c, "close")
487 }
488}
489
490pub fn tema_batch_with_kernel(
491 data: &[f64],
492 sweep: &TemaBatchRange,
493 k: Kernel,
494) -> Result<TemaBatchOutput, TemaError> {
495 let kernel = match k {
496 Kernel::Auto => detect_best_batch_kernel(),
497 other if other.is_batch() => other,
498 other => return Err(TemaError::InvalidKernelForBatch(other)),
499 };
500 let simd = match kernel {
501 Kernel::Avx512Batch => Kernel::Avx512,
502 Kernel::Avx2Batch => Kernel::Avx2,
503 Kernel::ScalarBatch => Kernel::Scalar,
504 _ => unreachable!(),
505 };
506 tema_batch_par_slice(data, sweep, simd)
507}
508
509#[derive(Clone, Debug)]
510pub struct TemaBatchOutput {
511 pub values: Vec<f64>,
512 pub combos: Vec<TemaParams>,
513 pub rows: usize,
514 pub cols: usize,
515}
516impl TemaBatchOutput {
517 pub fn row_for_params(&self, p: &TemaParams) -> Option<usize> {
518 self.combos
519 .iter()
520 .position(|c| c.period.unwrap_or(9) == p.period.unwrap_or(9))
521 }
522 pub fn values_for(&self, p: &TemaParams) -> Option<&[f64]> {
523 self.row_for_params(p).map(|row| {
524 let start = row * self.cols;
525 &self.values[start..start + self.cols]
526 })
527 }
528}
529
530#[inline(always)]
531fn expand_grid(r: &TemaBatchRange) -> Result<Vec<TemaParams>, TemaError> {
532 fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, TemaError> {
533 if step == 0 || start == end {
534 return Ok(vec![start]);
535 }
536
537 if start > end {
538 let mut v = Vec::new();
539 let mut cur = start;
540 while cur >= end {
541 v.push(cur);
542
543 if let Some(next) = cur.checked_sub(step) {
544 if next == cur {
545 break;
546 }
547 cur = next;
548 } else {
549 break;
550 }
551 if cur == usize::MAX {
552 break;
553 }
554 if cur < end {
555 break;
556 }
557 }
558 if v.is_empty() {
559 return Err(TemaError::InvalidRange { start, end, step });
560 }
561 return Ok(v);
562 }
563 let it = (start..=end).step_by(step);
564 let v: Vec<usize> = it.collect();
565 if v.is_empty() {
566 return Err(TemaError::InvalidRange { start, end, step });
567 }
568 Ok(v)
569 }
570 let periods = axis_usize(r.period)?;
571 let mut out = Vec::with_capacity(periods.len());
572 for &p in &periods {
573 out.push(TemaParams { period: Some(p) });
574 }
575 Ok(out)
576}
577
578#[inline(always)]
579pub fn tema_batch_slice(
580 data: &[f64],
581 sweep: &TemaBatchRange,
582 kern: Kernel,
583) -> Result<TemaBatchOutput, TemaError> {
584 tema_batch_inner(data, sweep, kern, false)
585}
586
587#[inline(always)]
588pub fn tema_batch_par_slice(
589 data: &[f64],
590 sweep: &TemaBatchRange,
591 kern: Kernel,
592) -> Result<TemaBatchOutput, TemaError> {
593 tema_batch_inner(data, sweep, kern, true)
594}
595
596#[inline(always)]
597fn tema_batch_inner(
598 data: &[f64],
599 sweep: &TemaBatchRange,
600 kern: Kernel,
601 parallel: bool,
602) -> Result<TemaBatchOutput, TemaError> {
603 let combos = expand_grid(sweep)?;
604 if combos.is_empty() {
605 return Err(TemaError::InvalidRange {
606 start: sweep.period.0,
607 end: sweep.period.1,
608 step: sweep.period.2,
609 });
610 }
611
612 let first = data
613 .iter()
614 .position(|x| !x.is_nan())
615 .ok_or(TemaError::AllValuesNaN)?;
616 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
617 if data.len() - first < max_p {
618 return Err(TemaError::NotEnoughValidData {
619 needed: max_p,
620 valid: data.len() - first,
621 });
622 }
623
624 let rows = combos.len();
625 let cols = data.len();
626
627 let _total = rows
628 .checked_mul(cols)
629 .ok_or_else(|| TemaError::InvalidInput("rows * cols overflow".into()))?;
630
631 let actual_kern = match kern {
632 Kernel::Auto => detect_best_batch_kernel(),
633 k => k,
634 };
635
636 let warm: Vec<usize> = combos
637 .iter()
638 .map(|c| (first + (c.period.unwrap() - 1) * 3).min(cols))
639 .collect();
640
641 let mut raw = make_uninit_matrix(rows, cols);
642 unsafe { init_matrix_prefixes(&mut raw, cols, &warm) };
643
644 let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
645 let period = combos[row].period.unwrap();
646
647 let out_row =
648 core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
649
650 match actual_kern {
651 Kernel::Scalar | Kernel::ScalarBatch => tema_row_scalar(data, first, period, out_row),
652 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
653 Kernel::Avx2 | Kernel::Avx2Batch => tema_row_avx2(data, first, period, out_row),
654 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
655 Kernel::Avx512 | Kernel::Avx512Batch => tema_row_avx512(data, first, period, out_row),
656 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
657 Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
658 tema_row_scalar(data, first, period, out_row)
659 }
660 Kernel::Auto => unreachable!("Auto kernel should have been resolved"),
661 }
662 };
663
664 if parallel {
665 #[cfg(not(target_arch = "wasm32"))]
666 {
667 raw.par_chunks_mut(cols)
668 .enumerate()
669 .for_each(|(row, slice)| do_row(row, slice));
670 }
671
672 #[cfg(target_arch = "wasm32")]
673 {
674 for (row, slice) in raw.chunks_mut(cols).enumerate() {
675 do_row(row, slice);
676 }
677 }
678 } else {
679 for (row, slice) in raw.chunks_mut(cols).enumerate() {
680 do_row(row, slice);
681 }
682 }
683
684 let mut guard = core::mem::ManuallyDrop::new(raw);
685 let values: Vec<f64> = unsafe {
686 Vec::from_raw_parts(
687 guard.as_mut_ptr() as *mut f64,
688 guard.len(),
689 guard.capacity(),
690 )
691 };
692
693 Ok(TemaBatchOutput {
694 values,
695 combos,
696 rows,
697 cols,
698 })
699}
700
701#[inline(always)]
702unsafe fn tema_row_scalar(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
703 tema_scalar(data, period, first, out)
704}
705
/// AVX2 batch-row kernel; currently delegates to the scalar kernel.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
unsafe fn tema_row_avx2(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    tema_scalar(data, period, first, out);
}
711
/// AVX-512 batch-row kernel: picks the short variant for periods up to 32
/// and the long variant otherwise.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
unsafe fn tema_row_avx512(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    if period > 32 {
        tema_row_avx512_long(data, first, period, out);
    } else {
        tema_row_avx512_short(data, first, period, out);
    }
}
721
/// AVX-512 short-period batch-row kernel; currently delegates to the scalar kernel.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
unsafe fn tema_row_avx512_short(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    tema_scalar(data, period, first, out);
}
727
/// AVX-512 long-period batch-row kernel; currently delegates to the scalar kernel.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
unsafe fn tema_row_avx512_long(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    tema_scalar(data, period, first, out);
}
733
734#[cfg(test)]
735mod tests {
736 use super::*;
737 use crate::skip_if_unsupported;
738 use crate::utilities::data_loader::read_candles_from_csv;
739
740 #[test]
741 fn test_tema_into_matches_api() -> Result<(), Box<dyn Error>> {
742 let len = 256usize;
743 let mut data = vec![f64::NAN; len];
744 for i in 2..len {
745 let x = i as f64;
746 data[i] = (x * 0.31337).sin() * 10.0 + x * 0.01;
747 }
748
749 let input = TemaInput::from_slice(&data, TemaParams::default());
750
751 let baseline = tema(&input)?.values;
752
753 let mut out = vec![0.0; len];
754 #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
755 {
756 tema_into(&input, &mut out)?;
757 }
758 #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
759 {
760 tema_into_slice(&mut out, &input, Kernel::Auto)?;
761 }
762
763 assert_eq!(baseline.len(), out.len());
764
765 fn eq_or_both_nan(a: f64, b: f64) -> bool {
766 (a.is_nan() && b.is_nan()) || (a - b).abs() <= 1e-12
767 }
768
769 for (i, (&a, &b)) in baseline.iter().zip(out.iter()).enumerate() {
770 assert!(
771 eq_or_both_nan(a, b),
772 "parity mismatch at idx {}: api={} into={} diff={}",
773 i,
774 a,
775 b,
776 (a - b).abs()
777 );
778 }
779
780 Ok(())
781 }
782
783 fn check_tema_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
784 skip_if_unsupported!(kernel, test_name);
785 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
786 let candles = read_candles_from_csv(file_path)?;
787 let default_params = TemaParams { period: None };
788 let input = TemaInput::from_candles(&candles, "close", default_params);
789 let output = tema_with_kernel(&input, kernel)?;
790 assert_eq!(output.values.len(), candles.close.len());
791 Ok(())
792 }
793 fn check_tema_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
794 skip_if_unsupported!(kernel, test_name);
795 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
796 let candles = read_candles_from_csv(file_path)?;
797 let input = TemaInput::from_candles(&candles, "close", TemaParams::default());
798 let result = tema_with_kernel(&input, kernel)?;
799 let expected_last_five = [
800 59281.895570662884,
801 59257.25021607971,
802 59172.23342859784,
803 59175.218345941066,
804 58934.24395798363,
805 ];
806 let start = result.values.len().saturating_sub(5);
807 for (i, &val) in result.values[start..].iter().enumerate() {
808 let diff = (val - expected_last_five[i]).abs();
809 assert!(
810 diff < 1e-8,
811 "[{}] TEMA {:?} mismatch at idx {}: got {}, expected {}",
812 test_name,
813 kernel,
814 i,
815 val,
816 expected_last_five[i]
817 );
818 }
819 Ok(())
820 }
821 fn check_tema_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
822 skip_if_unsupported!(kernel, test_name);
823 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
824 let candles = read_candles_from_csv(file_path)?;
825 let input = TemaInput::with_default_candles(&candles);
826 match input.data {
827 TemaData::Candles { source, .. } => assert_eq!(source, "close"),
828 _ => panic!("Expected TemaData::Candles"),
829 }
830 let output = tema_with_kernel(&input, kernel)?;
831 assert_eq!(output.values.len(), candles.close.len());
832 Ok(())
833 }
834 fn check_tema_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
835 skip_if_unsupported!(kernel, test_name);
836 let input_data = [10.0, 20.0, 30.0];
837 let params = TemaParams { period: Some(0) };
838 let input = TemaInput::from_slice(&input_data, params);
839 let res = tema_with_kernel(&input, kernel);
840 assert!(
841 res.is_err(),
842 "[{}] TEMA should fail with zero period",
843 test_name
844 );
845 Ok(())
846 }
847 fn check_tema_empty_input(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
848 skip_if_unsupported!(kernel, test_name);
849 let input_data: [f64; 0] = [];
850 let params = TemaParams { period: Some(9) };
851 let input = TemaInput::from_slice(&input_data, params);
852 let res = tema_with_kernel(&input, kernel);
853 assert!(
854 res.is_err(),
855 "[{}] TEMA should fail with empty input",
856 test_name
857 );
858 if let Err(e) = res {
859 assert!(
860 matches!(e, TemaError::EmptyInputData),
861 "[{}] Expected EmptyInputData error",
862 test_name
863 );
864 }
865 Ok(())
866 }
867 fn check_tema_period_exceeds_length(
868 test_name: &str,
869 kernel: Kernel,
870 ) -> Result<(), Box<dyn Error>> {
871 skip_if_unsupported!(kernel, test_name);
872 let data_small = [10.0, 20.0, 30.0];
873 let params = TemaParams { period: Some(10) };
874 let input = TemaInput::from_slice(&data_small, params);
875 let res = tema_with_kernel(&input, kernel);
876 assert!(
877 res.is_err(),
878 "[{}] TEMA should fail with period exceeding length",
879 test_name
880 );
881 Ok(())
882 }
883 fn check_tema_very_small_dataset(
884 test_name: &str,
885 kernel: Kernel,
886 ) -> Result<(), Box<dyn Error>> {
887 skip_if_unsupported!(kernel, test_name);
888 let single_point = [42.0];
889 let params = TemaParams { period: Some(9) };
890 let input = TemaInput::from_slice(&single_point, params);
891 let res = tema_with_kernel(&input, kernel);
892 assert!(
893 res.is_err(),
894 "[{}] TEMA should fail with insufficient data",
895 test_name
896 );
897 Ok(())
898 }
899 fn check_tema_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
900 skip_if_unsupported!(kernel, test_name);
901 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
902 let candles = read_candles_from_csv(file_path)?;
903 let first_params = TemaParams { period: Some(9) };
904 let first_input = TemaInput::from_candles(&candles, "close", first_params);
905 let first_result = tema_with_kernel(&first_input, kernel)?;
906 let second_params = TemaParams { period: Some(9) };
907 let second_input = TemaInput::from_slice(&first_result.values, second_params);
908 let second_result = tema_with_kernel(&second_input, kernel)?;
909 assert_eq!(second_result.values.len(), first_result.values.len());
910 Ok(())
911 }
912 fn check_tema_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
913 skip_if_unsupported!(kernel, test_name);
914 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
915 let candles = read_candles_from_csv(file_path)?;
916 let input = TemaInput::from_candles(&candles, "close", TemaParams { period: Some(9) });
917 let res = tema_with_kernel(&input, kernel)?;
918 assert_eq!(res.values.len(), candles.close.len());
919 if res.values.len() > 50 {
920 for (i, &val) in res.values[50..].iter().enumerate() {
921 assert!(
922 !val.is_nan(),
923 "[{}] Found unexpected NaN at out-index {}",
924 test_name,
925 50 + i
926 );
927 }
928 }
929 Ok(())
930 }
931 fn check_tema_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
932 skip_if_unsupported!(kernel, test_name);
933 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
934 let candles = read_candles_from_csv(file_path)?;
935 let period = 9;
936 let input = TemaInput::from_candles(
937 &candles,
938 "close",
939 TemaParams {
940 period: Some(period),
941 },
942 );
943 let batch_output = tema_with_kernel(&input, kernel)?.values;
944 let mut stream = TemaStream::try_new(TemaParams {
945 period: Some(period),
946 })?;
947 let mut stream_values = Vec::with_capacity(candles.close.len());
948 for &price in &candles.close {
949 match stream.update(price) {
950 Some(val) => stream_values.push(val),
951 None => stream_values.push(f64::NAN),
952 }
953 }
954 assert_eq!(batch_output.len(), stream_values.len());
955 for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
956 if b.is_nan() && s.is_nan() {
957 continue;
958 }
959 let diff = (b - s).abs();
960 assert!(
961 diff < 1e-9,
962 "[{}] TEMA streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
963 test_name,
964 i,
965 b,
966 s,
967 diff
968 );
969 }
970 Ok(())
971 }
972
973 macro_rules! generate_all_tema_tests {
974 ($($test_fn:ident),*) => {
975 paste::paste! {
976 $(
977 #[test]
978 fn [<$test_fn _scalar_f64>]() {
979 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
980 }
981 )*
982 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
983 $(
984 #[test]
985 fn [<$test_fn _avx2_f64>]() {
986 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
987 }
988 #[test]
989 fn [<$test_fn _avx512_f64>]() {
990 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
991 }
992 )*
993 }
994 }
995 }
996
997 #[cfg(debug_assertions)]
998 fn check_tema_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
999 skip_if_unsupported!(kernel, test_name);
1000
1001 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1002 let candles = read_candles_from_csv(file_path)?;
1003
1004 let test_periods = vec![5, 9, 14, 20, 50, 100, 200];
1005
1006 for &period in &test_periods {
1007 let params = TemaParams {
1008 period: Some(period),
1009 };
1010 let input = TemaInput::from_candles(&candles, "close", params);
1011
1012 if candles.close.len() < period {
1013 continue;
1014 }
1015
1016 let output = match tema_with_kernel(&input, kernel) {
1017 Ok(o) => o,
1018 Err(_) => continue,
1019 };
1020
1021 for (i, &val) in output.values.iter().enumerate() {
1022 if val.is_nan() {
1023 continue;
1024 }
1025
1026 let bits = val.to_bits();
1027
1028 if bits == 0x11111111_11111111 {
1029 panic!(
1030 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} with period {}",
1031 test_name, val, bits, i, period
1032 );
1033 }
1034
1035 if bits == 0x22222222_22222222 {
1036 panic!(
1037 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} with period {}",
1038 test_name, val, bits, i, period
1039 );
1040 }
1041
1042 if bits == 0x33333333_33333333 {
1043 panic!(
1044 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} with period {}",
1045 test_name, val, bits, i, period
1046 );
1047 }
1048 }
1049 }
1050
1051 Ok(())
1052 }
1053
1054 #[cfg(not(debug_assertions))]
1055 fn check_tema_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1056 Ok(())
1057 }
1058
1059 #[cfg(feature = "proptest")]
1060 fn check_tema_property(
1061 test_name: &str,
1062 kernel: Kernel,
1063 ) -> Result<(), Box<dyn std::error::Error>> {
1064 use proptest::prelude::*;
1065 skip_if_unsupported!(kernel, test_name);
1066
1067 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1068 let candles = read_candles_from_csv(file_path)?;
1069 let close_data = &candles.close;
1070
1071 let strat = (
1072 2usize..=50,
1073 0usize..close_data.len().saturating_sub(200),
1074 100usize..=200,
1075 );
1076
1077 proptest::test_runner::TestRunner::default()
1078 .run(&strat, |(period, start_idx, slice_len)| {
1079 let end_idx = (start_idx + slice_len).min(close_data.len());
1080 if end_idx <= start_idx || end_idx - start_idx < period * 3 + 10 {
1081 return Ok(());
1082 }
1083
1084 let data_slice = &close_data[start_idx..end_idx];
1085 let params = TemaParams {
1086 period: Some(period),
1087 };
1088 let input = TemaInput::from_slice(data_slice, params);
1089
1090 let result = tema_with_kernel(&input, kernel);
1091
1092 let scalar_result = tema_with_kernel(&input, Kernel::Scalar);
1093
1094 match (result, scalar_result) {
1095 (Ok(TemaOutput { values: out }), Ok(TemaOutput { values: ref_out })) => {
1096 prop_assert_eq!(out.len(), data_slice.len());
1097 prop_assert_eq!(ref_out.len(), data_slice.len());
1098
1099 let first = data_slice.iter().position(|x| !x.is_nan()).unwrap_or(0);
1100 let lookback = (period - 1) * 3;
1101 let expected_warmup = first + lookback;
1102
1103 for i in 0..expected_warmup.min(out.len()) {
1104 prop_assert!(
1105 out[i].is_nan(),
1106 "Expected NaN at index {} during warmup, got {}",
1107 i,
1108 out[i]
1109 );
1110 }
1111
1112 let multiplier = 2.0 / (period as f64 + 1.0);
1113 prop_assert!(
1114 multiplier > 0.0 && multiplier <= 1.0,
1115 "EMA multiplier should be in (0, 1]: {}",
1116 multiplier
1117 );
1118
1119 for i in expected_warmup..out.len() {
1120 let y = out[i];
1121 let r = ref_out[i];
1122
1123 prop_assert!(!y.is_nan(), "Unexpected NaN at index {}", i);
1124 prop_assert!(y.is_finite(), "Non-finite value at index {}: {}", i, y);
1125
1126 let y_bits = y.to_bits();
1127 let r_bits = r.to_bits();
1128
1129 if !y.is_finite() || !r.is_finite() {
1130 prop_assert_eq!(
1131 y_bits,
1132 r_bits,
1133 "NaN/Inf mismatch at {}: {} vs {}",
1134 i,
1135 y,
1136 r
1137 );
1138 continue;
1139 }
1140
1141 let ulp_diff: u64 = y_bits.abs_diff(r_bits);
1142 prop_assert!(
1143 (y - r).abs() <= 1e-9 || ulp_diff <= 5,
1144 "Kernel mismatch at {}: {} vs {} (ULP={})",
1145 i,
1146 y,
1147 r,
1148 ulp_diff
1149 );
1150 }
1151
1152 let const_value = 100.0;
1153 let const_data = vec![const_value; period * 4];
1154 let const_input = TemaInput::from_slice(
1155 &const_data,
1156 TemaParams {
1157 period: Some(period),
1158 },
1159 );
1160 if let Ok(TemaOutput { values: const_out }) =
1161 tema_with_kernel(&const_input, kernel)
1162 {
1163 let const_warmup = lookback;
1164 for (i, &val) in const_out.iter().enumerate() {
1165 if i >= const_warmup && !val.is_nan() {
1166 prop_assert!(
1167 (val - const_value).abs() < 1e-9,
1168 "TEMA of constant data should equal the constant at {}: got {}",
1169 i, val
1170 );
1171 }
1172 }
1173 }
1174
1175 if period <= 20 {
1176 let mut stream = TemaStream::try_new(TemaParams {
1177 period: Some(period),
1178 })
1179 .unwrap();
1180 let mut stream_values = Vec::with_capacity(data_slice.len());
1181
1182 for &price in data_slice {
1183 match stream.update(price) {
1184 Some(val) => stream_values.push(val),
1185 None => stream_values.push(f64::NAN),
1186 }
1187 }
1188
1189 for (i, (&batch_val, &stream_val)) in
1190 out.iter().zip(stream_values.iter()).enumerate()
1191 {
1192 if batch_val.is_nan() && stream_val.is_nan() {
1193 continue;
1194 }
1195 if !batch_val.is_nan() && !stream_val.is_nan() {
1196 prop_assert!(
1197 (batch_val - stream_val).abs() < 1e-9,
1198 "Streaming mismatch at {}: batch={} vs stream={}",
1199 i,
1200 batch_val,
1201 stream_val
1202 );
1203 }
1204 }
1205 }
1206
1207 if data_slice.len() > period * 2 {
1208 let trend_start = expected_warmup;
1209 let trend_end = (trend_start + period).min(data_slice.len());
1210
1211 if trend_end > trend_start + 3 {
1212 let trend_data = &data_slice[trend_start..trend_end];
1213 let is_uptrend =
1214 trend_data.windows(2).filter(|w| w[1] > w[0]).count()
1215 > trend_data.windows(2).filter(|w| w[1] < w[0]).count();
1216
1217 if is_uptrend {
1218 let last_price = data_slice[trend_end - 1];
1219 let tema_value = out[trend_end - 1];
1220
1221 let price_range = trend_data
1222 .iter()
1223 .cloned()
1224 .fold(f64::NEG_INFINITY, f64::max)
1225 - trend_data.iter().cloned().fold(f64::INFINITY, f64::min);
1226 prop_assert!(
1227 (tema_value - last_price).abs() < price_range * 1.5,
1228 "TEMA diverged too much from price: TEMA={}, price={}, range={}",
1229 tema_value, last_price, price_range
1230 );
1231 }
1232 }
1233 }
1234 }
1235 (Err(e1), Err(e2)) => {
1236 prop_assert_eq!(
1237 std::mem::discriminant(&e1),
1238 std::mem::discriminant(&e2),
1239 "Different error types: {:?} vs {:?}",
1240 e1,
1241 e2
1242 );
1243 }
1244 _ => {
1245 prop_assert!(
1246 false,
1247 "Kernel consistency failure: one succeeded, one failed"
1248 );
1249 }
1250 }
1251
1252 Ok(())
1253 })
1254 .unwrap();
1255
1256 Ok(())
1257 }
1258
1259 generate_all_tema_tests!(
1260 check_tema_partial_params,
1261 check_tema_accuracy,
1262 check_tema_default_candles,
1263 check_tema_zero_period,
1264 check_tema_empty_input,
1265 check_tema_period_exceeds_length,
1266 check_tema_very_small_dataset,
1267 check_tema_reinput,
1268 check_tema_nan_handling,
1269 check_tema_streaming,
1270 check_tema_no_poison
1271 );
1272
1273 #[cfg(feature = "proptest")]
1274 generate_all_tema_tests!(check_tema_property);
1275
1276 fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1277 skip_if_unsupported!(kernel, test);
1278 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1279 let c = read_candles_from_csv(file)?;
1280 let output = TemaBatchBuilder::new()
1281 .kernel(kernel)
1282 .apply_candles(&c, "close")?;
1283 let def = TemaParams::default();
1284 let row = output.values_for(&def).expect("default row missing");
1285 assert_eq!(row.len(), c.close.len());
1286 let expected = [
1287 59281.895570662884,
1288 59257.25021607971,
1289 59172.23342859784,
1290 59175.218345941066,
1291 58934.24395798363,
1292 ];
1293 let start = row.len() - 5;
1294 for (i, &v) in row[start..].iter().enumerate() {
1295 assert!(
1296 (v - expected[i]).abs() < 1e-8,
1297 "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
1298 );
1299 }
1300 Ok(())
1301 }
1302
/// Generates one `#[test]` per batch kernel for a check of the shape
/// `fn(&str, Kernel) -> Result<(), Box<dyn Error>>`.
///
/// Each generated test returns the check's `Result` instead of discarding it, so
/// an `Err` (for example an unreadable data file or a failing batch run) fails the
/// test rather than passing silently.
macro_rules! gen_batch_tests {
    ($fn_name:ident) => {
        paste::paste! {
            #[test]
            fn [<$fn_name _scalar>]() -> Result<(), Box<dyn std::error::Error>> {
                $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch)
            }
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            #[test]
            fn [<$fn_name _avx2>]() -> Result<(), Box<dyn std::error::Error>> {
                $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch)
            }
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            #[test]
            fn [<$fn_name _avx512>]() -> Result<(), Box<dyn std::error::Error>> {
                $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch)
            }
            #[test]
            fn [<$fn_name _auto_detect>]() -> Result<(), Box<dyn std::error::Error>> {
                $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto)
            }
        }
    };
}
1323
1324 #[cfg(debug_assertions)]
1325 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1326 skip_if_unsupported!(kernel, test);
1327
1328 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1329 let c = read_candles_from_csv(file)?;
1330
1331 let test_configs = vec![
1332 (5, 15, 2),
1333 (10, 50, 5),
1334 (20, 100, 10),
1335 (50, 200, 25),
1336 (3, 3, 1),
1337 (150, 150, 1),
1338 ];
1339
1340 for (start, end, step) in test_configs {
1341 let output = TemaBatchBuilder::new()
1342 .kernel(kernel)
1343 .period_range(start, end, step)
1344 .apply_candles(&c, "close")?;
1345
1346 for (idx, &val) in output.values.iter().enumerate() {
1347 if val.is_nan() {
1348 continue;
1349 }
1350
1351 let bits = val.to_bits();
1352 let row = idx / output.cols;
1353 let col = idx % output.cols;
1354 let period = output
1355 .combos
1356 .get(row)
1357 .map(|p| p.period.unwrap_or(0))
1358 .unwrap_or(0);
1359
1360 if bits == 0x11111111_11111111 {
1361 panic!(
1362 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at row {} col {} (period {}, flat index {})",
1363 test, val, bits, row, col, period, idx
1364 );
1365 }
1366
1367 if bits == 0x22222222_22222222 {
1368 panic!(
1369 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at row {} col {} (period {}, flat index {})",
1370 test, val, bits, row, col, period, idx
1371 );
1372 }
1373
1374 if bits == 0x33333333_33333333 {
1375 panic!(
1376 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at row {} col {} (period {}, flat index {})",
1377 test, val, bits, row, col, period, idx
1378 );
1379 }
1380 }
1381 }
1382
1383 Ok(())
1384 }
1385
1386 #[cfg(not(debug_assertions))]
1387 fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1388 Ok(())
1389 }
1390
1391 gen_batch_tests!(check_batch_default_row);
1392 gen_batch_tests!(check_batch_no_poison);
1393}
1394
1395#[inline]
1396fn tema_prepare<'a>(
1397 input: &'a TemaInput,
1398 kernel: Kernel,
1399) -> Result<(&'a [f64], usize, usize, usize, Kernel), TemaError> {
1400 let data: &[f64] = match &input.data {
1401 TemaData::Candles { candles, source } => source_type(candles, source),
1402 TemaData::Slice(sl) => sl,
1403 };
1404
1405 if data.is_empty() {
1406 return Err(TemaError::EmptyInputData);
1407 }
1408
1409 let first = data
1410 .iter()
1411 .position(|x| !x.is_nan())
1412 .ok_or(TemaError::AllValuesNaN)?;
1413 let len = data.len();
1414 let period = input.get_period();
1415
1416 if period == 0 || period > len {
1417 return Err(TemaError::InvalidPeriod {
1418 period,
1419 data_len: len,
1420 });
1421 }
1422 if (len - first) < period {
1423 return Err(TemaError::NotEnoughValidData {
1424 needed: period,
1425 valid: len - first,
1426 });
1427 }
1428
1429 let chosen = match kernel {
1430 Kernel::Auto => Kernel::Scalar,
1431 other => other,
1432 };
1433
1434 Ok((data, period, first, len, chosen))
1435}
1436
1437#[inline]
1438fn tema_compute_into(data: &[f64], period: usize, first: usize, chosen: Kernel, out: &mut [f64]) {
1439 unsafe {
1440 match chosen {
1441 Kernel::Scalar | Kernel::ScalarBatch => tema_scalar(data, period, first, out),
1442 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1443 Kernel::Avx2 | Kernel::Avx2Batch => tema_avx2(data, period, first, out),
1444 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1445 Kernel::Avx512 | Kernel::Avx512Batch => tema_avx512(data, period, first, out),
1446 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
1447 Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
1448 tema_scalar(data, period, first, out)
1449 }
1450 Kernel::Auto => unreachable!(),
1451 }
1452 }
1453}
1454
1455#[inline(always)]
1456fn tema_batch_inner_into(
1457 data: &[f64],
1458 sweep: &TemaBatchRange,
1459 kern: Kernel,
1460 parallel: bool,
1461 out: &mut [f64],
1462) -> Result<Vec<TemaParams>, TemaError> {
1463 let combos = expand_grid(sweep)?;
1464 if combos.is_empty() {
1465 return Err(TemaError::InvalidRange {
1466 start: sweep.period.0,
1467 end: sweep.period.1,
1468 step: sweep.period.2,
1469 });
1470 }
1471
1472 if data.is_empty() {
1473 return Err(TemaError::EmptyInputData);
1474 }
1475
1476 let first = data
1477 .iter()
1478 .position(|x| !x.is_nan())
1479 .ok_or(TemaError::AllValuesNaN)?;
1480 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1481 if data.len() - first < max_p {
1482 return Err(TemaError::NotEnoughValidData {
1483 needed: max_p,
1484 valid: data.len() - first,
1485 });
1486 }
1487
1488 let rows = combos.len();
1489 let cols = data.len();
1490 let expected = rows
1491 .checked_mul(cols)
1492 .ok_or_else(|| TemaError::InvalidInput("rows * cols overflow".into()))?;
1493 if out.len() != expected {
1494 return Err(TemaError::OutputLengthMismatch {
1495 expected,
1496 got: out.len(),
1497 });
1498 }
1499
1500 let actual_kern = match kern {
1501 Kernel::Auto => detect_best_batch_kernel(),
1502 k => k,
1503 };
1504
1505 let warm: Vec<usize> = combos
1506 .iter()
1507 .map(|c| (first + (c.period.unwrap() - 1) * 3).min(cols))
1508 .collect();
1509
1510 let out_uninit = unsafe {
1511 std::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
1512 };
1513
1514 unsafe { init_matrix_prefixes(out_uninit, cols, &warm) };
1515
1516 let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
1517 let period = combos[row].period.unwrap();
1518
1519 let out_row =
1520 core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
1521
1522 match actual_kern {
1523 Kernel::Scalar | Kernel::ScalarBatch => tema_row_scalar(data, first, period, out_row),
1524 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1525 Kernel::Avx2 | Kernel::Avx2Batch => tema_row_avx2(data, first, period, out_row),
1526 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1527 Kernel::Avx512 | Kernel::Avx512Batch => tema_row_avx512(data, first, period, out_row),
1528 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
1529 Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
1530 tema_row_scalar(data, first, period, out_row)
1531 }
1532 Kernel::Auto => unreachable!("Auto kernel should have been resolved"),
1533 }
1534 };
1535
1536 if parallel {
1537 #[cfg(not(target_arch = "wasm32"))]
1538 {
1539 out_uninit
1540 .par_chunks_mut(cols)
1541 .enumerate()
1542 .for_each(|(row, slice)| do_row(row, slice));
1543 }
1544 #[cfg(target_arch = "wasm32")]
1545 {
1546 for (row, slice) in out_uninit.chunks_mut(cols).enumerate() {
1547 do_row(row, slice);
1548 }
1549 }
1550 } else {
1551 for (row, slice) in out_uninit.chunks_mut(cols).enumerate() {
1552 do_row(row, slice);
1553 }
1554 }
1555
1556 Ok(combos)
1557}
1558
// Python entry point `tema(data, period, kernel=None)`.
//
// Validates the kernel name, runs TEMA with the GIL released and returns the
// result as a new NumPy array. Errors from kernel validation are propagated;
// computation errors are raised as `ValueError` carrying the error's text.
//
// Contiguous input is borrowed directly. Non-contiguous input (strided or reversed
// views) is copied element by element in logical order; this avoids relying on the
// memory layout of an owned ndarray copy, which is not guaranteed to be sliceable
// and previously sat behind a panicking `expect`.
#[cfg(feature = "python")]
#[pyfunction(name = "tema")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn tema_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let kern = validate_kernel(kernel, false)?;

    let params = TemaParams {
        period: Some(period),
    };

    // Holds the copy only when the input cannot be borrowed as a slice.
    let copied: Vec<f64>;
    let slice_in: &[f64] = match data.as_slice() {
        Ok(contiguous) => contiguous,
        Err(_) => {
            copied = data.as_array().iter().copied().collect();
            &copied
        }
    };

    let tema_in = TemaInput::from_slice(slice_in, params);
    let result_vec: Vec<f64> = py
        .allow_threads(|| tema_with_kernel(&tema_in, kern).map(|o| o.values))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok(result_vec.into_pyarray(py))
}
1590
// Python class `TemaStream`: a thin wrapper that owns a `TemaStream` and forwards
// updates to it.
#[cfg(feature = "python")]
#[pyclass(name = "TemaStream")]
pub struct TemaStreamPy {
    stream: TemaStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl TemaStreamPy {
    // Constructor: builds the underlying stream for `period`; a construction error
    // is raised as `ValueError` with the error's text.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        TemaStream::try_new(TemaParams {
            period: Some(period),
        })
        .map(|stream| TemaStreamPy { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    // Feeds one value to the stream and returns whatever the stream yields
    // (`None` maps to Python `None`).
    fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
1614
// Python entry point `tema_batch(data, period_range, kernel=None)`.
//
// Allocates the flat output array on the Python side, fills it with the GIL
// released, and returns a dict with:
//   "values":  the output reshaped to (rows, cols)
//   "periods": the period of each row as u64
// `data` must be contiguous (`as_slice` errors otherwise). Computation and grid
// errors are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "tema_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn tema_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, true)?;

    let sweep = TemaBatchRange {
        period: period_range,
    };

    // The grid is expanded here only to size the output buffer.
    let rows = expand_grid(&sweep)
        .map_err(|e| PyValueError::new_err(e.to_string()))?
        .len();
    let cols = slice_in.len();

    let total = match rows.checked_mul(cols) {
        Some(n) => n,
        None => return Err(PyValueError::new_err("rows * cols overflow")),
    };
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            let resolved = if matches!(kern, Kernel::Auto) {
                detect_best_batch_kernel()
            } else {
                kern
            };

            // Map batch kernel variants onto their single-series counterparts.
            let simd = match resolved {
                Kernel::ScalarBatch => Kernel::Scalar,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::Avx512Batch => Kernel::Avx512,
                other => other,
            };

            tema_batch_inner_into(slice_in, &sweep, simd, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let periods: Vec<u64> = combos.iter().map(|p| p.period.unwrap() as u64).collect();

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item("periods", periods.into_pyarray(py))?;

    Ok(dict)
}
1676
// Python entry point `tema_cuda_batch_dev(data_f32, period_range, device_id=0)`.
//
// Raises `ValueError("CUDA not available")` when `cuda_available()` is false.
// Otherwise runs the batch sweep through `CudaTema` with the GIL released and wraps
// the result via `make_device_array_py`. CUDA-side errors are raised as
// `ValueError`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "tema_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, device_id=0))]
pub fn tema_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let slice_in = data_f32.as_slice()?;
    let sweep = TemaBatchRange {
        period: period_range,
    };

    let inner = py.allow_threads(|| -> Result<_, PyErr> {
        let cuda = CudaTema::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        cuda.tema_batch_dev(slice_in, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    make_device_array_py(device_id, inner)
}
1705
// Python entry point
// `tema_cuda_many_series_one_param_dev(prices_tm_f32, period, device_id=0)`.
//
// Raises `ValueError("CUDA not available")` when `cuda_available()` is false.
// The 2-D input must be contiguous; its first dimension is forwarded as `rows` and
// its second as `cols`. NOTE(review): the names suggest a time-major layout
// (rows = time steps, cols = series) — confirm against `CudaTema`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "tema_cuda_many_series_one_param_dev")]
#[pyo3(signature = (prices_tm_f32, period, device_id=0))]
pub fn tema_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    prices_tm_f32: PyReadonlyArray2<'_, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    use numpy::PyUntypedArrayMethods;

    let shape = prices_tm_f32.shape();
    let (rows, cols) = (shape[0], shape[1]);

    let prices_flat = prices_tm_f32.as_slice()?;

    let inner = py.allow_threads(|| -> Result<_, PyErr> {
        let cuda = CudaTema::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        cuda.tema_many_series_one_param_time_major_dev(prices_flat, cols, rows, period)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    make_device_array_py(device_id, inner)
}
1736
1737#[inline]
1738pub fn tema_into_slice(dst: &mut [f64], input: &TemaInput, kern: Kernel) -> Result<(), TemaError> {
1739 let (data, period, first, len, chosen) = tema_prepare(input, kern)?;
1740 if dst.len() != len {
1741 return Err(TemaError::OutputLengthMismatch {
1742 expected: len,
1743 got: dst.len(),
1744 });
1745 }
1746 tema_compute_into(data, period, first, chosen, dst);
1747 let warm = first + (period - 1) * 3;
1748 let pref = warm.min(len);
1749 for v in &mut dst[..pref] {
1750 *v = f64::NAN;
1751 }
1752 Ok(())
1753}
1754
/// WASM entry point: computes TEMA over `data` and returns a new vector of the same
/// length.
///
/// Input is validated up front (empty data, period of 0 or longer than the data,
/// all-NaN data, too few samples after the leading NaNs), each failure producing a
/// string error. If the warm-up `first_valid + 3 * (period - 1)` reaches the end of
/// the data (for `period > 1`), an all-NaN vector is returned without running the
/// kernel.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn tema_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let len = data.len();

    if len == 0 {
        return Err(JsValue::from_str("Input data slice is empty"));
    }
    if period == 0 || period > len {
        return Err(JsValue::from_str(&format!(
            "Invalid period: {} (data length: {})",
            period, len
        )));
    }

    let first_valid = match data.iter().position(|x| !x.is_nan()) {
        Some(idx) => idx,
        None => return Err(JsValue::from_str("All values are NaN")),
    };

    let valid_count = len - first_valid;
    if valid_count < period {
        return Err(JsValue::from_str(&format!(
            "Not enough valid data: need {} but only {} valid values after NaN values",
            period, valid_count
        )));
    }

    // Nothing but warm-up would be produced: short-circuit with all NaN.
    if period > 1 && first_valid + (period - 1) * 3 >= len {
        return Ok(vec![f64::NAN; len]);
    }

    let input = TemaInput::from_slice(
        data,
        TemaParams {
            period: Some(period),
        },
    );

    let mut output = vec![0.0; len];
    match tema_into_slice(&mut output, &input, Kernel::Auto) {
        Ok(()) => Ok(output),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1801
/// Configuration object deserialized from the JS value passed to the `tema_batch`
/// WASM export.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct TemaBatchConfig {
    /// Period sweep, forwarded unchanged as `TemaBatchRange::period`.
    /// Presumably `(start, end, step)` — confirm against `expand_grid`.
    pub period_range: (usize, usize, usize),
}
1807
/// Serializable result returned to JS by the `tema_batch` WASM export; every field
/// is copied from the batch output.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct TemaBatchJsOutput {
    /// Flattened output values (presumably `rows * cols` entries, row-major —
    /// confirm against the batch output type).
    pub values: Vec<f64>,
    /// Parameter combinations of the batch run.
    pub combos: Vec<TemaParams>,
    /// Row count reported by the batch output.
    pub rows: usize,
    /// Column count reported by the batch output.
    pub cols: usize,
}
1816
/// WASM export `tema_batch`: deserializes a `TemaBatchConfig` from `config`, runs
/// the batch computation sequentially with `Kernel::Auto`, and returns the result
/// serialized from a `TemaBatchJsOutput`.
///
/// Deserialization, computation and serialization failures are all reported as
/// string errors.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = tema_batch)]
pub fn tema_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let config: TemaBatchConfig = match serde_wasm_bindgen::from_value(config) {
        Ok(parsed) => parsed,
        Err(e) => return Err(JsValue::from_str(&format!("Invalid config: {}", e))),
    };

    let sweep = TemaBatchRange {
        period: config.period_range,
    };

    let output = match tema_batch_inner(data, &sweep, Kernel::Auto, false) {
        Ok(batch) => batch,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };

    let js_output = TemaBatchJsOutput {
        values: output.values,
        combos: output.combos,
        rows: output.rows,
        cols: output.cols,
    };

    match serde_wasm_bindgen::to_value(&js_output) {
        Ok(value) => Ok(value),
        Err(e) => Err(JsValue::from_str(&format!("Serialization error: {}", e))),
    }
}
1840
/// Deprecated WASM export: runs the batch computation sequentially with the scalar
/// kernel over the sweep `(period_start, period_end, period_step)` and returns only
/// the flattened values. Errors are reported as strings.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
#[deprecated(since = "1.0.0", note = "Use tema_batch instead")]
pub fn tema_batch_js(
    data: &[f64],
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let sweep = TemaBatchRange {
        period: (period_start, period_end, period_step),
    };

    match tema_batch_inner(data, &sweep, Kernel::Scalar, false) {
        Ok(output) => Ok(output.values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1858
/// WASM export: reserves capacity for `len` `f64` values and returns the raw
/// pointer to that buffer. Ownership passes to the caller: the backing `Vec` is
/// deliberately never dropped here. The memory is not initialized.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn tema_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` suppresses the destructor so the allocation outlives this call.
    let mut buf = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buf.as_mut_ptr()
}
1867
/// WASM export: releases a buffer by rebuilding a `Vec<f64>` with length and
/// capacity `len` from `ptr` and dropping it.
///
/// A null `ptr` is ignored instead of being handed to `Vec::from_raw_parts`, which
/// requires a non-null pointer.
///
/// The caller must pass a pointer whose allocation has capacity `len` (e.g. one
/// obtained from `tema_alloc(len)`), and must not free it twice.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn tema_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` is non-null; the caller guarantees it refers to an allocation
    // of capacity `len` that has not been freed yet.
    unsafe {
        drop(Vec::from_raw_parts(ptr, len, len));
    }
}
1875
/// WASM export: computes TEMA over the `len` values at `in_ptr` and writes `len`
/// results to `out_ptr`.
///
/// Null pointers, `len == 0`, and a period of 0 or greater than `len` are rejected
/// with string errors. For `period > 1`, when `3 * (period - 1) >= len` and the
/// input is not entirely NaN, the output is filled with NaN and the kernel is
/// skipped. When `in_ptr` and `out_ptr` are equal the result is computed into a
/// temporary buffer and then copied over the input.
///
/// Both pointers must be valid for `len` `f64` values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn tema_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to tema_into"));
    }
    if len == 0 {
        return Err(JsValue::from_str("Input data slice is empty"));
    }
    if period == 0 || period > len {
        return Err(JsValue::from_str(&format!(
            "Invalid period: {} (data length: {})",
            period, len
        )));
    }

    // SAFETY: `in_ptr` is non-null; the caller guarantees it is valid for `len`
    // reads.
    let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };

    // Warm-up alone covers the whole output: emit NaN everywhere and stop.
    let lookback = (period - 1) * 3;
    if period > 1 && lookback >= len && data.iter().any(|x| !x.is_nan()) {
        // SAFETY: `out_ptr` is non-null; the caller guarantees `len` writable values.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        out.fill(f64::NAN);
        return Ok(());
    }

    let input = TemaInput::from_slice(
        data,
        TemaParams {
            period: Some(period),
        },
    );

    let outcome = if in_ptr == out_ptr {
        // In-place request: compute into scratch space first so the input is not
        // overwritten while it is still being read.
        let mut scratch = vec![0.0; len];
        tema_into_slice(&mut scratch, &input, Kernel::Auto).map(|()| {
            // SAFETY: as above for `out_ptr`.
            let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
            out.copy_from_slice(&scratch);
        })
    } else {
        // SAFETY: as above for `out_ptr`.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        tema_into_slice(out, &input, Kernel::Auto)
    };

    outcome.map_err(|e| JsValue::from_str(&e.to_string()))
}
1928
/// WASM export: runs the batch sweep `(period_start, period_end, period_step)` over
/// the `len` values at `in_ptr` and writes the flattened `rows * len` result to
/// `out_ptr`. Returns the number of rows (parameter combinations).
///
/// Null pointers, grid-expansion errors, a `rows * cols` overflow and computation
/// errors are reported as string errors. The size computation is checked so that
/// an overflowing product can never be used to build the output slice.
///
/// `in_ptr` must be valid for `len` reads and `out_ptr` for `rows * len` writes.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn tema_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to tema_batch_into"));
    }

    // SAFETY: `in_ptr` is non-null; the caller guarantees it is valid for `len`
    // reads.
    let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };
    let sweep = TemaBatchRange {
        period: (period_start, period_end, period_step),
    };

    // The grid is expanded here only to size the output region.
    let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let cols = data.len();
    let total_size = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("rows * cols overflow"))?;

    // SAFETY: `out_ptr` is non-null; the caller guarantees room for `total_size`
    // values.
    let out_slice = unsafe { std::slice::from_raw_parts_mut(out_ptr, total_size) };

    tema_batch_inner_into(data, &sweep, Kernel::Auto, false, out_slice)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    Ok(rows)
}