1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::cuda_available;
3#[cfg(all(feature = "python", feature = "cuda"))]
4use crate::cuda::moving_averages::trima_wrapper::DeviceArrayF32Trima;
5#[cfg(all(feature = "python", feature = "cuda"))]
6use crate::cuda::moving_averages::CudaTrima;
7use crate::indicators::sma::{sma, SmaData, SmaInput, SmaParams};
8use crate::utilities::data_loader::{source_type, Candles};
9#[cfg(all(feature = "python", feature = "cuda"))]
10use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
11use crate::utilities::enums::Kernel;
12use crate::utilities::helpers::{
13 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
14 make_uninit_matrix,
15};
16#[cfg(feature = "python")]
17use crate::utilities::kernel_validation::validate_kernel;
18#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
19use core::arch::x86_64::*;
20use paste::paste;
21#[cfg(not(target_arch = "wasm32"))]
22use rayon::prelude::*;
23use std::convert::AsRef;
24use std::mem::MaybeUninit;
25use thiserror::Error;
26
27impl<'a> AsRef<[f64]> for TrimaInput<'a> {
28 #[inline(always)]
29 fn as_ref(&self) -> &[f64] {
30 match &self.data {
31 TrimaData::Slice(slice) => slice,
32 TrimaData::Candles { candles, source } => source_type(candles, source),
33 }
34 }
35}
36
/// Where a TRIMA computation reads its input series from.
#[derive(Debug, Clone)]
pub enum TrimaData<'a> {
    /// A named series (e.g. `"close"`) of a candle set, resolved with `source_type`.
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A raw slice used as-is.
    Slice(&'a [f64]),
}

/// Result of a single-series TRIMA computation.
#[derive(Debug, Clone)]
pub struct TrimaOutput {
    /// One value per input sample; the warmup prefix (`first + period - 1` entries) is NaN.
    pub values: Vec<f64>,
}
50
/// Parameters for the triangular moving average (TRIMA).
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct TrimaParams {
    /// Window length. `None` falls back to 30 in `TrimaInput::get_period` and
    /// `TrimaStream::try_new`; single-series computation rejects periods of 3 or less.
    pub period: Option<usize>,
}

/// Default parameters: `period = Some(30)`.
impl Default for TrimaParams {
    fn default() -> Self {
        Self { period: Some(30) }
    }
}
65
66#[derive(Debug, Clone)]
67pub struct TrimaInput<'a> {
68 pub data: TrimaData<'a>,
69 pub params: TrimaParams,
70}
71
72impl<'a> TrimaInput<'a> {
73 #[inline]
74 pub fn from_candles(c: &'a Candles, s: &'a str, p: TrimaParams) -> Self {
75 Self {
76 data: TrimaData::Candles {
77 candles: c,
78 source: s,
79 },
80 params: p,
81 }
82 }
83 #[inline]
84 pub fn from_slice(sl: &'a [f64], p: TrimaParams) -> Self {
85 Self {
86 data: TrimaData::Slice(sl),
87 params: p,
88 }
89 }
90 #[inline]
91 pub fn with_default_candles(c: &'a Candles) -> Self {
92 Self::from_candles(c, "close", TrimaParams::default())
93 }
94 #[inline]
95 pub fn get_period(&self) -> usize {
96 self.params.period.unwrap_or(30)
97 }
98}
99
100#[derive(Copy, Clone, Debug)]
101pub struct TrimaBuilder {
102 period: Option<usize>,
103 kernel: Kernel,
104}
105
106impl Default for TrimaBuilder {
107 fn default() -> Self {
108 Self {
109 period: None,
110 kernel: Kernel::Auto,
111 }
112 }
113}
114
115impl TrimaBuilder {
116 #[inline(always)]
117 pub fn new() -> Self {
118 Self::default()
119 }
120 #[inline(always)]
121 pub fn period(mut self, n: usize) -> Self {
122 self.period = Some(n);
123 self
124 }
125 #[inline(always)]
126 pub fn kernel(mut self, k: Kernel) -> Self {
127 self.kernel = k;
128 self
129 }
130 #[inline(always)]
131 pub fn apply(self, c: &Candles) -> Result<TrimaOutput, TrimaError> {
132 let p = TrimaParams {
133 period: self.period,
134 };
135 let i = TrimaInput::from_candles(c, "close", p);
136 trima_with_kernel(&i, self.kernel)
137 }
138 #[inline(always)]
139 pub fn apply_slice(self, d: &[f64]) -> Result<TrimaOutput, TrimaError> {
140 let p = TrimaParams {
141 period: self.period,
142 };
143 let i = TrimaInput::from_slice(d, p);
144 trima_with_kernel(&i, self.kernel)
145 }
146 #[inline(always)]
147 pub fn into_stream(self) -> Result<TrimaStream, TrimaError> {
148 let p = TrimaParams {
149 period: self.period,
150 };
151 TrimaStream::try_new(p)
152 }
153}
154
/// Errors produced by the TRIMA single-series, streaming and batch entry points.
#[derive(Debug, Error)]
pub enum TrimaError {
    /// The input series has length zero.
    #[error("trima: No data provided (input data slice is empty).")]
    EmptyInputData,
    /// Every value in the input series is NaN, so no starting index exists.
    #[error("trima: All values are NaN.")]
    AllValuesNaN,

    /// The period is 0 or longer than the input series.
    #[error("trima: Invalid period: period = {period}, data_len = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },

    /// Fewer non-NaN-prefixed values remain after the first valid index than the period needs.
    #[error("trima: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },

    /// The period is 3 or less (rejected by the single-series and streaming paths).
    #[error("trima: Period too small: {period}")]
    PeriodTooSmall { period: usize },

    /// NOTE(review): not constructed by any code visible in this section; possibly unused.
    #[error("trima: No data provided.")]
    NoData,

    /// A caller-provided output buffer does not have the required length.
    #[error("trima: Output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },

    /// A batch period sweep expanded to no values, or its matrix size overflowed.
    #[error("trima: Invalid range: start = {start}, end = {end}, step = {step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },

    /// A non-batch kernel was passed to a batch entry point.
    #[error("trima: Invalid kernel for batch path: {0:?}")]
    InvalidKernelForBatch(Kernel),
}
187
188#[inline]
189pub fn trima(input: &TrimaInput) -> Result<TrimaOutput, TrimaError> {
190 trima_with_kernel(input, Kernel::Auto)
191}
192
193#[inline(always)]
194fn trima_prepare<'a>(
195 input: &'a TrimaInput,
196 kernel: Kernel,
197) -> Result<(&'a [f64], usize, usize, usize, usize, Kernel), TrimaError> {
198 let data: &[f64] = input.as_ref();
199 let len = data.len();
200 if len == 0 {
201 return Err(TrimaError::EmptyInputData);
202 }
203 let first = data
204 .iter()
205 .position(|x| !x.is_nan())
206 .ok_or(TrimaError::AllValuesNaN)?;
207 let period = input.get_period();
208
209 if period == 0 || period > len {
210 return Err(TrimaError::InvalidPeriod {
211 period,
212 data_len: len,
213 });
214 }
215 if period <= 3 {
216 return Err(TrimaError::PeriodTooSmall { period });
217 }
218 if (len - first) < period {
219 return Err(TrimaError::NotEnoughValidData {
220 needed: period,
221 valid: len - first,
222 });
223 }
224
225 let m1 = (period + 1) / 2;
226 let m2 = period - m1 + 1;
227
228 let chosen = match kernel {
229 Kernel::Auto => Kernel::Scalar,
230 k => k,
231 };
232
233 Ok((data, period, m1, m2, first, chosen))
234}
235
236#[inline(always)]
237fn trima_compute_into(
238 data: &[f64],
239 period: usize,
240 m1: usize,
241 m2: usize,
242 first: usize,
243 kernel: Kernel,
244 out: &mut [f64],
245) {
246 unsafe {
247 #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
248 {
249 if matches!(kernel, Kernel::Scalar | Kernel::ScalarBatch) {
250 trima_simd128(data, m1, m2, first, out);
251 return;
252 }
253 }
254
255 match kernel {
256 Kernel::Scalar | Kernel::ScalarBatch => {
257 trima_scalar_optimized(data, period, m1, m2, first, out)
258 }
259 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
260 Kernel::Avx2 | Kernel::Avx2Batch => {
261 trima_avx2(data, period, first, out);
262 }
263 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
264 Kernel::Avx512 | Kernel::Avx512Batch => {
265 trima_avx512(data, period, first, out);
266 }
267 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
268 Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
269 trima_scalar_optimized(data, period, m1, m2, first, out)
270 }
271 Kernel::Auto => trima_scalar_optimized(data, period, m1, m2, first, out),
272 }
273 }
274}
275
/// Scalar TRIMA kernel: a running SMA of length `m1` over `data`, whose means feed a
/// running SMA of length `m2` kept in a ring buffer.
///
/// Writes `out[first + period - 1..data.len()]` and leaves earlier entries untouched.
/// It returns immediately if `data` is empty or that warmup index is past the end.
/// Divisions are replaced by multiplication with precomputed reciprocals.
///
/// # Safety
/// All indexing is unchecked. The caller must guarantee:
/// - `out.len() >= data.len()` (only checked by `debug_assert_eq!`),
/// - `m1 >= 1` and `m1 + m2 == period + 1` (the split computed by `trima_prepare`).
///
/// Otherwise reads and writes may go out of bounds.
#[inline(always)]
unsafe fn trima_scalar_optimized(
    data: &[f64],
    period: usize,
    m1: usize,
    m2: usize,
    first: usize,
    out: &mut [f64],
) {
    debug_assert_eq!(data.len(), out.len());
    let n = data.len();
    if n == 0 {
        return;
    }
    // Index of the first output value.
    let warm = first + period - 1;
    if warm >= n {
        return;
    }

    let inv_m1 = 1.0 / (m1 as f64);
    let inv_m2 = 1.0 / (m2 as f64);

    // Initial first-stage window: sum of data[first..first + m1], unrolled by 4.
    let base = data.as_ptr().add(first);
    let mut sum1 = 0.0;
    {
        let mut j = 0usize;
        let end_unroll = m1 & !3usize;
        while j < end_unroll {
            sum1 += *base.add(j) + *base.add(j + 1) + *base.add(j + 2) + *base.add(j + 3);
            j += 4;
        }
        while j < m1 {
            sum1 += *base.add(j);
            j += 1;
        }
    }

    // Ring of the most recent `m2` first-stage means, plus their running sum.
    let mut ring: Vec<f64> = Vec::with_capacity(m2);
    let mut sum2 = 0.0;

    // `t` is the data index of the newest sample in the first-stage window.
    let mut t = first + m1 - 1;

    // Samples entering / leaving the sliding first-stage sum.
    let mut p_new = data.as_ptr().add(first + m1);
    let mut p_old = data.as_ptr().add(first);

    {
        let s1 = sum1 * inv_m1;
        ring.push(s1);
        sum2 += s1;
    }

    // Fill the ring; because m1 + m2 == period + 1 this ends with `t == warm`.
    while ring.len() < m2 {
        t += 1;

        sum1 += *p_new - *p_old;
        p_new = p_new.add(1);
        p_old = p_old.add(1);

        let s1 = sum1 * inv_m1;
        ring.push(s1);
        sum2 += s1;
    }

    *out.get_unchecked_mut(warm) = sum2 * inv_m2;

    // Steady state: slide both windows one sample per output. `head` is the oldest
    // ring slot, which is replaced by the newest first-stage mean.
    let mut head = 0usize;
    t += 1;
    while t < n {
        sum1 += *p_new - *p_old;
        p_new = p_new.add(1);
        p_old = p_old.add(1);

        let new_s1 = sum1 * inv_m1;
        let old_s1 = *ring.get_unchecked(head);
        sum2 += new_s1 - old_s1;
        *ring.get_unchecked_mut(head) = new_s1;

        head += 1;
        if head == m2 {
            head = 0;
        }

        *out.get_unchecked_mut(t) = sum2 * inv_m2;

        t += 1;
    }
}
363
/// WASM `simd128` TRIMA kernel: two chained SMAs (`m1`, then `m2`), with the initial
/// window sums accumulated two lanes at a time.
///
/// Allocates an `n`-length scratch buffer for the first-stage SMA. It divides by
/// `m1`/`m2` rather than multiplying by reciprocals, and its lane-wise initial sums
/// use a different addition order. Results may therefore differ from
/// `trima_scalar_optimized` in the last bits.
///
/// Writes `out[first + m1 + m2 - 2..]` (i.e. from `first + period - 1` when `m1`/`m2`
/// come from `trima_prepare`); earlier entries are left untouched.
///
/// # Safety
/// `v128_load` reads are unchecked. The guards `first + m1 <= n` and
/// `first + m1 + m2 - 1 <= n` keep them in bounds. This holds only if `m1 >= 1`,
/// otherwise `first + m1 - 1` can underflow.
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
#[inline]
unsafe fn trima_simd128(data: &[f64], m1: usize, m2: usize, first: usize, out: &mut [f64]) {
    use core::arch::wasm32::*;

    // Number of f64 lanes in a v128.
    const STEP: usize = 2;
    let n = data.len();

    // First-stage SMA; NaN wherever it is not yet defined.
    let mut sma1 = vec![f64::NAN; n];

    if first + m1 <= n {
        let chunks = m1 / STEP;
        let tail = m1 % STEP;

        let mut acc = f64x2_splat(0.0);
        for i in 0..chunks {
            let idx = first + i * STEP;
            let d = v128_load(data.as_ptr().add(idx) as *const v128);
            acc = f64x2_add(acc, d);
        }

        let mut sum = f64x2_extract_lane::<0>(acc) + f64x2_extract_lane::<1>(acc);
        if tail != 0 {
            sum += data[first + chunks * STEP];
        }

        sma1[first + m1 - 1] = sum / m1 as f64;

        for i in (first + m1)..n {
            sum += data[i] - data[i - m1];
            sma1[i] = sum / m1 as f64;
        }
    }

    if first + m1 + m2 - 1 <= n {
        // Index of the first defined first-stage mean.
        let sma1_first = first + m1 - 1;

        let chunks2 = m2 / STEP;
        let tail2 = m2 % STEP;

        let mut acc2 = f64x2_splat(0.0);
        for i in 0..chunks2 {
            let idx = sma1_first + i * STEP;
            let d = v128_load(sma1.as_ptr().add(idx) as *const v128);
            acc2 = f64x2_add(acc2, d);
        }

        let mut sum2 = f64x2_extract_lane::<0>(acc2) + f64x2_extract_lane::<1>(acc2);
        if tail2 != 0 {
            sum2 += sma1[sma1_first + chunks2 * STEP];
        }

        out[sma1_first + m2 - 1] = sum2 / m2 as f64;

        for i in (sma1_first + m2)..n {
            sum2 += sma1[i] - sma1[i - m2];
            out[i] = sum2 / m2 as f64;
        }
    }
}
424
425pub fn trima_with_kernel(input: &TrimaInput, kernel: Kernel) -> Result<TrimaOutput, TrimaError> {
426 let (data, period, m1, m2, first, chosen) = trima_prepare(input, kernel)?;
427 let len = data.len();
428 let warm = first + period - 1;
429 let mut out = alloc_with_nan_prefix(len, warm);
430 trima_compute_into(data, period, m1, m2, first, chosen, &mut out);
431 Ok(TrimaOutput { values: out })
432}
433
434#[inline]
435pub fn trima_into_slice(
436 output: &mut [f64],
437 input: &TrimaInput,
438 kernel: Kernel,
439) -> Result<(), TrimaError> {
440 let (data, period, m1, m2, first, chosen) = trima_prepare(input, kernel)?;
441
442 if output.len() != data.len() {
443 return Err(TrimaError::OutputLengthMismatch {
444 expected: data.len(),
445 got: output.len(),
446 });
447 }
448
449 trima_compute_into(data, period, m1, m2, first, chosen, output);
450
451 let warmup = first + period - 1;
452 for i in 0..warmup.min(output.len()) {
453 output[i] = f64::NAN;
454 }
455
456 Ok(())
457}
458
459#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
460#[inline(always)]
461pub fn trima_into(input: &TrimaInput, out: &mut [f64]) -> Result<(), TrimaError> {
462 let (data, period, m1, m2, first, chosen) = trima_prepare(input, Kernel::Auto)?;
463
464 if out.len() != data.len() {
465 return Err(TrimaError::OutputLengthMismatch {
466 expected: data.len(),
467 got: out.len(),
468 });
469 }
470
471 let warm = first + period - 1;
472 let end = warm.min(out.len());
473 let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
474 for v in &mut out[..end] {
475 *v = qnan;
476 }
477
478 trima_compute_into(data, period, m1, m2, first, chosen, out);
479 Ok(())
480}
481
/// Reference (non-optimized) TRIMA.
///
/// It computes a length-`m1` SMA of `data` into a scratch buffer, then a length-`m2`
/// SMA of that buffer, where `m1 = (period + 1) / 2` and `m2 = period - m1 + 1`.
///
/// Writes `out[first + period - 1..data.len()]` and leaves earlier entries
/// untouched. It does nothing if `period == 0` or fewer than `period` samples
/// follow `first`.
///
/// # Panics
/// Panics (bounds check) if `out` is shorter than `data` and a value must be
/// written past its end.
#[inline]
pub fn trima_scalar_classic(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    let n = data.len();
    // A zero period has no window; it would also underflow `first + period - 1`
    // and `first + m1 - 1` below when `first == 0`.
    if period == 0 {
        return;
    }
    let m1 = (period + 1) / 2;
    let m2 = period - m1 + 1;

    // Index of the first TRIMA output. Since m1 + m2 == period + 1, this is also the
    // first index at which the second-stage window is full.
    let warmup_end = first + period - 1;
    if warmup_end >= n {
        return;
    }
    let first_valid_sma1 = first + m1 - 1;

    // Pass 1: SMA of the raw data over `m1` samples.
    let mut sma1 = vec![f64::NAN; n];
    let mut sum1 = 0.0;
    for &x in &data[first..first + m1] {
        sum1 += x;
    }
    sma1[first_valid_sma1] = sum1 / m1 as f64;
    for i in (first + m1)..n {
        sum1 += data[i] - data[i - m1];
        sma1[i] = sum1 / m1 as f64;
    }

    // Pass 2: SMA of the first-pass means over `m2` samples.
    let mut sum2 = 0.0;
    for &s in &sma1[first_valid_sma1..first_valid_sma1 + m2] {
        sum2 += s;
    }
    out[warmup_end] = sum2 / m2 as f64;
    for i in (warmup_end + 1)..n {
        sum2 += sma1[i] - sma1[i - m2];
        out[i] = sum2 / m2 as f64;
    }
}
529
530pub fn trima_scalar(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
531 let n = data.len();
532 let m1 = (period + 1) / 2;
533 let m2 = period - m1 + 1;
534
535 let sma1_in = SmaInput {
536 data: SmaData::Slice(data),
537 params: SmaParams { period: Some(m1) },
538 };
539
540 let pass1 = sma(&sma1_in).unwrap();
541
542 let sma2_in = SmaInput {
543 data: SmaData::Slice(&pass1.values),
544 params: SmaParams { period: Some(m2) },
545 };
546 let pass2 = sma(&sma2_in).unwrap();
547
548 let warmup_end = first + period - 1;
549 for i in warmup_end..n {
550 out[i] = pass2.values[i];
551 }
552}
553
/// AVX-512 TRIMA entry point: periods up to 32 use `trima_avx512_short`, longer
/// periods use `trima_avx512_long`.
///
/// Writes `out[first + period - 1..]`; earlier entries are left untouched.
///
/// # Panics
/// Panics if `out.len() != data.len()`. The underlying kernels write `out` through
/// unchecked indices, so this check keeps this safe function from writing out of
/// bounds.
///
/// NOTE(review): the kernels execute AVX instructions without runtime CPU-feature
/// detection; callers are expected to select this path only on capable hardware.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn trima_avx512(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    assert_eq!(
        data.len(),
        out.len(),
        "trima_avx512: output length must equal input length"
    );
    // SAFETY: `out` and `data` have equal length (asserted above), and the kernels
    // derive `m1`/`m2` from `period` and return early if the warmup index is not
    // inside `data`.
    unsafe {
        if period <= 32 {
            trima_avx512_short(data, period, first, out)
        } else {
            trima_avx512_long(data, period, first, out)
        }
    }
}
/// AVX2 TRIMA kernel.
///
/// Identical to `trima_scalar_optimized` except that the initial `m1`-sample sum
/// uses `sum_u_avx2`, so only that first sum is vectorised. The sliding updates are
/// scalar. The period split (`m1 = (period + 1) / 2`, `m2 = period - m1 + 1`) is
/// computed here.
///
/// Writes `out[first + period - 1..]`; earlier entries are left untouched.
///
/// # Safety
/// - Requires a CPU with AVX support.
/// - Requires `out.len() >= data.len()` (checked only by `debug_assert_eq!`),
///   because every read and write is unchecked.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn trima_avx2(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    debug_assert_eq!(data.len(), out.len());
    let n = data.len();
    if n == 0 {
        return;
    }

    let m1 = (period + 1) / 2;
    let m2 = period - m1 + 1;
    // Index of the first output value.
    let warm = first + period - 1;
    if warm >= n {
        return;
    }

    let inv_m1 = 1.0 / (m1 as f64);
    let inv_m2 = 1.0 / (m2 as f64);

    // Initial first-stage window: sum of data[first..first + m1].
    let mut sum1 = sum_u_avx2(data.as_ptr().add(first), m1);

    // Ring of the most recent `m2` first-stage means, plus their running sum.
    let mut ring: Vec<f64> = Vec::with_capacity(m2);
    let mut sum2 = 0.0;

    let mut t = first + m1 - 1;
    let mut p_new = data.as_ptr().add(first + m1);
    let mut p_old = data.as_ptr().add(first);

    {
        let s1 = sum1 * inv_m1;
        ring.push(s1);
        sum2 += s1;
    }

    // Fill the ring; ends with `t == warm`.
    while ring.len() < m2 {
        t += 1;
        sum1 += *p_new - *p_old;
        p_new = p_new.add(1);
        p_old = p_old.add(1);

        let s1 = sum1 * inv_m1;
        ring.push(s1);
        sum2 += s1;
    }

    *out.get_unchecked_mut(warm) = sum2 * inv_m2;

    // Steady state: slide both windows, replacing the oldest ring slot each step.
    let mut head = 0usize;
    t += 1;
    while t < n {
        sum1 += *p_new - *p_old;
        p_new = p_new.add(1);
        p_old = p_old.add(1);

        let new_s1 = sum1 * inv_m1;
        let old_s1 = *ring.get_unchecked(head);
        sum2 += new_s1 - old_s1;
        *ring.get_unchecked_mut(head) = new_s1;

        head += 1;
        if head == m2 {
            head = 0;
        }

        *out.get_unchecked_mut(t) = sum2 * inv_m2;
        t += 1;
    }
}
/// "Short-period" AVX-512 TRIMA kernel.
///
/// Despite the name, the initial `m1`-sample sum uses `sum_u_avx2` (256-bit AVX), and
/// the sliding updates are scalar. The logic matches `trima_avx2`.
///
/// Writes `out[first + period - 1..]`; earlier entries are left untouched.
///
/// # Safety
/// - Requires a CPU with AVX support.
/// - Requires `out.len() >= data.len()` (checked only by `debug_assert_eq!`),
///   because every read and write is unchecked.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn trima_avx512_short(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    debug_assert_eq!(data.len(), out.len());
    let n = data.len();
    if n == 0 {
        return;
    }

    let m1 = (period + 1) / 2;
    let m2 = period - m1 + 1;
    // Index of the first output value.
    let warm = first + period - 1;
    if warm >= n {
        return;
    }

    let inv_m1 = 1.0 / (m1 as f64);
    let inv_m2 = 1.0 / (m2 as f64);

    // Initial first-stage window: sum of data[first..first + m1].
    let mut sum1 = sum_u_avx2(data.as_ptr().add(first), m1);

    // Ring of the most recent `m2` first-stage means, plus their running sum.
    let mut ring: Vec<f64> = Vec::with_capacity(m2);
    let mut sum2 = 0.0;

    let mut t = first + m1 - 1;
    let mut p_new = data.as_ptr().add(first + m1);
    let mut p_old = data.as_ptr().add(first);

    let s1 = sum1 * inv_m1;
    ring.push(s1);
    sum2 += s1;

    // Fill the ring; ends with `t == warm`.
    while ring.len() < m2 {
        t += 1;
        sum1 += *p_new - *p_old;
        p_new = p_new.add(1);
        p_old = p_old.add(1);

        let s1 = sum1 * inv_m1;
        ring.push(s1);
        sum2 += s1;
    }

    *out.get_unchecked_mut(warm) = sum2 * inv_m2;

    // Steady state: slide both windows, replacing the oldest ring slot each step.
    let mut head = 0usize;
    t += 1;
    while t < n {
        sum1 += *p_new - *p_old;
        p_new = p_new.add(1);
        p_old = p_old.add(1);

        let new_s1 = sum1 * inv_m1;
        let old_s1 = *ring.get_unchecked(head);
        sum2 += new_s1 - old_s1;
        *ring.get_unchecked_mut(head) = new_s1;

        head += 1;
        if head == m2 {
            head = 0;
        }

        *out.get_unchecked_mut(t) = sum2 * inv_m2;
        t += 1;
    }
}
/// Long-period AVX-512 TRIMA entry point (used by `trima_avx512` for `period > 32`).
///
/// There is no dedicated long-period kernel yet; this forwards to the same
/// ring-buffer kernel as the short-period path.
///
/// # Safety
/// Same contract as `trima_avx512_short`:
/// - an AVX-capable CPU,
/// - `out.len() >= data.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn trima_avx512_long(
    data: &[f64],
    period: usize,
    first: usize,
    out: &mut [f64],
) {
    trima_avx512_short(data, period, first, out)
}
702
/// Horizontal sum of the four lanes of a 256-bit vector: the high and low halves
/// are added first, then the two remaining lanes.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn hsum256d(v: __m256d) -> f64 {
    let hi = _mm256_extractf128_pd(v, 1);
    let lo = _mm256_castpd256_pd128(v);
    let sum128 = _mm_add_pd(lo, hi);
    let shuffled = _mm_unpackhi_pd(sum128, sum128);
    _mm_cvtsd_f64(_mm_add_sd(sum128, shuffled))
}

/// Sum `len` consecutive f64 values starting at `ptr`.
///
/// Full groups of 4 are added with unaligned 256-bit loads; the last `len % 4` values
/// are added one by one after the horizontal sum.
///
/// # Safety
/// - `ptr..ptr + len` must be readable.
/// - Requires AVX.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn sum_u_avx2(ptr: *const f64, len: usize) -> f64 {
    let mut acc = _mm256_setzero_pd();
    let mut p = ptr;
    let chunks = len / 4;
    for _ in 0..chunks {
        acc = _mm256_add_pd(acc, _mm256_loadu_pd(p));
        p = p.add(4);
    }
    let mut s = hsum256d(acc);
    // Remaining `len % 4` elements.
    for i in 0..(len & 3) {
        s += *p.add(i);
    }
    s
}

/// Horizontal sum of the eight lanes of a 512-bit vector, reduced 512 -> 256 -> 128 -> scalar.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn hsum512d(v: __m512d) -> f64 {
    let v4 = _mm256_add_pd(_mm512_castpd512_pd256(v), _mm512_extractf64x4_pd(v, 1));
    let v2 = _mm_add_pd(_mm256_castpd256_pd128(v4), _mm256_extractf128_pd(v4, 1));
    let hi = _mm_unpackhi_pd(v2, v2);
    _mm_cvtsd_f64(_mm_add_sd(v2, hi))
}

/// Sum `len` consecutive f64 values starting at `ptr`, using unaligned 512-bit loads
/// for full groups of 8 and scalar adds for the last `len % 8` values.
///
/// NOTE(review): not called by any kernel visible in this section; the AVX-512
/// kernels currently use `sum_u_avx2`.
///
/// # Safety
/// - `ptr..ptr + len` must be readable.
/// - Requires AVX-512F.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn sum_u_avx512(ptr: *const f64, len: usize) -> f64 {
    let mut acc = _mm512_setzero_pd();
    let mut p = ptr;
    let chunks = len / 8;
    for _ in 0..chunks {
        acc = _mm512_add_pd(acc, _mm512_loadu_pd(p));
        p = p.add(8);
    }
    let mut s = hsum512d(acc);
    // Remaining `len % 8` elements.
    for i in 0..(len & 7) {
        s += *p.add(i);
    }
    s
}
755
756#[inline(always)]
757pub fn trima_row_scalar(
758 data: &[f64],
759 first: usize,
760 period: usize,
761 _stride: usize,
762 _w_ptr: *const f64,
763 _inv_n: f64,
764 out: &mut [f64],
765) {
766 let m1 = (period + 1) / 2;
767 let m2 = period - m1 + 1;
768 unsafe { trima_scalar_optimized(data, period, m1, m2, first, out) }
769}
770
/// Batch row kernel (AVX2). Forwards to `trima_avx2`; the `_stride`, `_w_ptr` and
/// `_inv_n` parameters are unused and exist only for a uniform row-kernel signature.
///
/// # Safety
/// Same contract as `trima_avx2`: an AVX-capable CPU and `out.len() >= data.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn trima_row_avx2(
    data: &[f64],
    first: usize,
    period: usize,
    _stride: usize,
    _w_ptr: *const f64,
    _inv_n: f64,
    out: &mut [f64],
) {
    // Note the argument order differs: row kernels take (first, period),
    // the series kernels take (period, first).
    trima_avx2(data, period, first, out)
}

/// Batch row kernel (AVX-512 dispatcher). Forwards to `trima_avx512`, which picks the
/// short or long variant by period. Unused parameters keep the signature uniform.
///
/// # Safety
/// Requires an AVX-capable CPU (see `trima_avx512`).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn trima_row_avx512(
    data: &[f64],
    first: usize,
    period: usize,
    _stride: usize,
    _w_ptr: *const f64,
    _inv_n: f64,
    out: &mut [f64],
) {
    trima_avx512(data, period, first, out)
}

/// Batch row kernel forcing the short-period AVX-512 path.
///
/// # Safety
/// Same contract as `trima_avx512_short`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn trima_row_avx512_short(
    data: &[f64],
    first: usize,
    period: usize,
    _stride: usize,
    _w_ptr: *const f64,
    _inv_n: f64,
    out: &mut [f64],
) {
    trima_avx512_short(data, period, first, out)
}

/// Batch row kernel forcing the long-period AVX-512 path.
///
/// # Safety
/// Same contract as `trima_avx512_long`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn trima_row_avx512_long(
    data: &[f64],
    first: usize,
    period: usize,
    _stride: usize,
    _w_ptr: *const f64,
    _inv_n: f64,
    out: &mut [f64],
) {
    trima_avx512_long(data, period, first, out)
}
824
825#[derive(Debug, Clone)]
826pub struct TrimaStream {
827 period: usize,
828 m1: usize,
829 m2: usize,
830
831 inv_m1: f64,
832 inv_m2: f64,
833
834 buf1: Box<[f64]>,
835 sum1: f64,
836 head1: usize,
837 filled1: bool,
838
839 buf2: Box<[f64]>,
840 sum2: f64,
841 head2: usize,
842 filled2: bool,
843}
844
845impl TrimaStream {
846 #[inline]
847 pub fn try_new(params: TrimaParams) -> Result<Self, TrimaError> {
848 let period = params.period.unwrap_or(30);
849 if period <= 3 {
850 return Err(TrimaError::PeriodTooSmall { period });
851 }
852
853 let m1 = (period + 1) / 2;
854 let m2 = period - m1 + 1;
855
856 Ok(Self {
857 period,
858 m1,
859 m2,
860 inv_m1: 1.0 / (m1 as f64),
861 inv_m2: 1.0 / (m2 as f64),
862
863 buf1: vec![f64::NAN; m1].into_boxed_slice(),
864 sum1: 0.0,
865 head1: 0,
866 filled1: false,
867
868 buf2: vec![f64::NAN; m2].into_boxed_slice(),
869 sum2: 0.0,
870 head2: 0,
871 filled2: false,
872 })
873 }
874
875 #[inline(always)]
876 pub fn update(&mut self, x: f64) -> Option<f64> {
877 let old1 = self.buf1[self.head1];
878 self.buf1[self.head1] = x;
879 self.head1 += 1;
880 if self.head1 == self.m1 {
881 self.head1 = 0;
882 self.filled1 = true;
883 }
884
885 if !old1.is_nan() {
886 self.sum1 -= old1;
887 }
888 if !x.is_nan() {
889 self.sum1 += x;
890 }
891
892 if !self.filled1 {
893 return None;
894 }
895
896 let s1 = self.sum1 * self.inv_m1;
897
898 let old2 = self.buf2[self.head2];
899 self.buf2[self.head2] = s1;
900 self.head2 += 1;
901 if self.head2 == self.m2 {
902 self.head2 = 0;
903 self.filled2 = true;
904 }
905
906 if !old2.is_nan() {
907 self.sum2 -= old2;
908 }
909
910 self.sum2 += s1;
911
912 if self.filled2 {
913 Some(self.sum2 * self.inv_m2)
914 } else {
915 None
916 }
917 }
918}
919
920#[derive(Clone, Debug)]
921pub struct TrimaBatchRange {
922 pub period: (usize, usize, usize),
923}
924
925impl Default for TrimaBatchRange {
926 fn default() -> Self {
927 Self {
928 period: (14, 263, 1),
929 }
930 }
931}
932
933#[derive(Clone, Debug, Default)]
934pub struct TrimaBatchBuilder {
935 range: TrimaBatchRange,
936 kernel: Kernel,
937}
938
939impl TrimaBatchBuilder {
940 pub fn new() -> Self {
941 Self::default()
942 }
943 pub fn kernel(mut self, k: Kernel) -> Self {
944 self.kernel = k;
945 self
946 }
947
948 #[inline]
949 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
950 self.range.period = (start, end, step);
951 self
952 }
953 #[inline]
954 pub fn period_static(mut self, p: usize) -> Self {
955 self.range.period = (p, p, 0);
956 self
957 }
958
959 pub fn apply_slice(self, data: &[f64]) -> Result<TrimaBatchOutput, TrimaError> {
960 trima_batch_with_kernel(data, &self.range, self.kernel)
961 }
962
963 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<TrimaBatchOutput, TrimaError> {
964 TrimaBatchBuilder::new().kernel(k).apply_slice(data)
965 }
966
967 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<TrimaBatchOutput, TrimaError> {
968 let slice = source_type(c, src);
969 self.apply_slice(slice)
970 }
971
972 pub fn with_default_candles(c: &Candles) -> Result<TrimaBatchOutput, TrimaError> {
973 TrimaBatchBuilder::new()
974 .kernel(Kernel::Auto)
975 .apply_candles(c, "close")
976 }
977}
978
979pub fn trima_batch_with_kernel(
980 data: &[f64],
981 sweep: &TrimaBatchRange,
982 k: Kernel,
983) -> Result<TrimaBatchOutput, TrimaError> {
984 let kernel = match k {
985 Kernel::Auto => detect_best_batch_kernel(),
986 other if other.is_batch() => other,
987 other => return Err(TrimaError::InvalidKernelForBatch(other)),
988 };
989
990 let simd = match kernel {
991 Kernel::Avx512Batch => Kernel::Avx512,
992 Kernel::Avx2Batch => Kernel::Avx2,
993 Kernel::ScalarBatch => Kernel::Scalar,
994 Kernel::Avx512 | Kernel::Avx2 | Kernel::Scalar => kernel,
995 _ => Kernel::Scalar,
996 };
997
998 trima_batch_par_slice(data, sweep, simd)
999}
1000
1001#[derive(Clone, Debug)]
1002pub struct TrimaBatchOutput {
1003 pub values: Vec<f64>,
1004 pub combos: Vec<TrimaParams>,
1005 pub rows: usize,
1006 pub cols: usize,
1007}
1008impl TrimaBatchOutput {
1009 pub fn row_for_params(&self, p: &TrimaParams) -> Option<usize> {
1010 self.combos
1011 .iter()
1012 .position(|c| c.period.unwrap_or(14) == p.period.unwrap_or(14))
1013 }
1014
1015 pub fn values_for(&self, p: &TrimaParams) -> Option<&[f64]> {
1016 self.row_for_params(p).map(|row| {
1017 let start = row * self.cols;
1018 &self.values[start..start + self.cols]
1019 })
1020 }
1021}
1022
1023#[inline(always)]
1024fn expand_grid(r: &TrimaBatchRange) -> Vec<TrimaParams> {
1025 fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, TrimaError> {
1026 if step == 0 || start == end {
1027 return Ok(vec![start]);
1028 }
1029 let (lo, hi) = if start <= end {
1030 (start, end)
1031 } else {
1032 (end, start)
1033 };
1034 let mut v = Vec::new();
1035 let mut cur = lo;
1036 while cur <= hi {
1037 v.push(cur);
1038 cur = cur
1039 .checked_add(step)
1040 .ok_or(TrimaError::InvalidRange { start, end, step })?;
1041 if cur == *v.last().unwrap() {
1042 break;
1043 }
1044 }
1045 if v.is_empty() {
1046 return Err(TrimaError::InvalidRange { start, end, step });
1047 }
1048 Ok(v)
1049 }
1050
1051 let periods = match axis_usize(r.period) {
1052 Ok(v) => v,
1053 Err(_) => return Vec::new(),
1054 };
1055 let mut out = Vec::with_capacity(periods.len());
1056 for &p in &periods {
1057 out.push(TrimaParams { period: Some(p) });
1058 }
1059 out
1060}
1061
1062#[inline(always)]
1063pub fn trima_batch_slice(
1064 data: &[f64],
1065 sweep: &TrimaBatchRange,
1066 kern: Kernel,
1067) -> Result<TrimaBatchOutput, TrimaError> {
1068 trima_batch_inner(data, sweep, kern, false)
1069}
1070
1071#[inline(always)]
1072pub fn trima_batch_par_slice(
1073 data: &[f64],
1074 sweep: &TrimaBatchRange,
1075 kern: Kernel,
1076) -> Result<TrimaBatchOutput, TrimaError> {
1077 trima_batch_inner(data, sweep, kern, true)
1078}
1079
1080#[inline(always)]
1081fn trima_batch_inner(
1082 data: &[f64],
1083 sweep: &TrimaBatchRange,
1084 kern: Kernel,
1085 parallel: bool,
1086) -> Result<TrimaBatchOutput, TrimaError> {
1087 let combos = expand_grid(sweep);
1088 if combos.is_empty() {
1089 return Err(TrimaError::InvalidRange {
1090 start: sweep.period.0,
1091 end: sweep.period.1,
1092 step: sweep.period.2,
1093 });
1094 }
1095
1096 let first = data
1097 .iter()
1098 .position(|x| !x.is_nan())
1099 .ok_or(TrimaError::AllValuesNaN)?;
1100 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1101 if data.len() - first < max_p {
1102 return Err(TrimaError::NotEnoughValidData {
1103 needed: max_p,
1104 valid: data.len() - first,
1105 });
1106 }
1107
1108 let rows = combos.len();
1109 let cols = data.len();
1110 let warm: Vec<usize> = combos
1111 .iter()
1112 .map(|c| first + c.period.unwrap() - 1)
1113 .collect();
1114
1115 let _total = rows.checked_mul(cols).ok_or(TrimaError::InvalidRange {
1116 start: sweep.period.0,
1117 end: sweep.period.1,
1118 step: sweep.period.2,
1119 })?;
1120 let mut raw = make_uninit_matrix(rows, cols);
1121 unsafe { init_matrix_prefixes(&mut raw, cols, &warm) };
1122
1123 let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
1124 let period = combos[row].period.unwrap();
1125
1126 let out_row =
1127 core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
1128
1129 match kern {
1130 Kernel::Scalar => {
1131 trima_row_scalar(data, first, period, 0, core::ptr::null(), 1.0, out_row)
1132 }
1133 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1134 Kernel::Avx2 => trima_row_avx2(data, first, period, 0, core::ptr::null(), 1.0, out_row),
1135 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1136 Kernel::Avx512 => {
1137 trima_row_avx512(data, first, period, 0, core::ptr::null(), 1.0, out_row)
1138 }
1139 _ => trima_row_scalar(data, first, period, 0, core::ptr::null(), 1.0, out_row),
1140 }
1141 };
1142
1143 if parallel {
1144 #[cfg(not(target_arch = "wasm32"))]
1145 {
1146 raw.par_chunks_mut(cols)
1147 .enumerate()
1148 .for_each(|(row, slice)| do_row(row, slice));
1149 }
1150
1151 #[cfg(target_arch = "wasm32")]
1152 {
1153 for (row, slice) in raw.chunks_mut(cols).enumerate() {
1154 do_row(row, slice);
1155 }
1156 }
1157 } else {
1158 for (row, slice) in raw.chunks_mut(cols).enumerate() {
1159 do_row(row, slice);
1160 }
1161 }
1162
1163 let mut buf_guard = core::mem::ManuallyDrop::new(raw);
1164 let values = unsafe {
1165 Vec::from_raw_parts(
1166 buf_guard.as_mut_ptr() as *mut f64,
1167 buf_guard.len(),
1168 buf_guard.capacity(),
1169 )
1170 };
1171
1172 Ok(TrimaBatchOutput {
1173 values,
1174 combos,
1175 rows,
1176 cols,
1177 })
1178}
1179
1180#[inline(always)]
1181pub fn trima_batch_inner_into(
1182 data: &[f64],
1183 sweep: &TrimaBatchRange,
1184 kern: Kernel,
1185 parallel: bool,
1186 out: &mut [f64],
1187) -> Result<Vec<TrimaParams>, TrimaError> {
1188 let combos = expand_grid(sweep);
1189 if combos.is_empty() {
1190 return Err(TrimaError::InvalidRange {
1191 start: sweep.period.0,
1192 end: sweep.period.1,
1193 step: sweep.period.2,
1194 });
1195 }
1196
1197 let first = data
1198 .iter()
1199 .position(|x| !x.is_nan())
1200 .ok_or(TrimaError::AllValuesNaN)?;
1201 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1202 if data.len() - first < max_p {
1203 return Err(TrimaError::NotEnoughValidData {
1204 needed: max_p,
1205 valid: data.len() - first,
1206 });
1207 }
1208
1209 let rows = combos.len();
1210 let cols = data.len();
1211 let expected = rows.checked_mul(cols).ok_or(TrimaError::InvalidRange {
1212 start: sweep.period.0,
1213 end: sweep.period.1,
1214 step: sweep.period.2,
1215 })?;
1216 if out.len() != expected {
1217 return Err(TrimaError::OutputLengthMismatch {
1218 expected,
1219 got: out.len(),
1220 });
1221 }
1222
1223 let warm: Vec<usize> = combos
1224 .iter()
1225 .map(|c| first + c.period.unwrap() - 1)
1226 .collect();
1227 let out_mu = unsafe {
1228 core::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
1229 };
1230 init_matrix_prefixes(out_mu, cols, &warm);
1231
1232 let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
1233 let period = combos[row].period.unwrap();
1234 let out_row =
1235 core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
1236 match kern {
1237 Kernel::Scalar => {
1238 trima_row_scalar(data, first, period, 0, core::ptr::null(), 1.0, out_row)
1239 }
1240 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1241 Kernel::Avx2 => trima_row_avx2(data, first, period, 0, core::ptr::null(), 1.0, out_row),
1242 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1243 Kernel::Avx512 => {
1244 trima_row_avx512(data, first, period, 0, core::ptr::null(), 1.0, out_row)
1245 }
1246 _ => trima_row_scalar(data, first, period, 0, core::ptr::null(), 1.0, out_row),
1247 }
1248 };
1249
1250 if parallel {
1251 #[cfg(not(target_arch = "wasm32"))]
1252 {
1253 out_mu
1254 .par_chunks_mut(cols)
1255 .enumerate()
1256 .for_each(|(row, slice)| do_row(row, slice));
1257 }
1258 #[cfg(target_arch = "wasm32")]
1259 {
1260 for (row, slice) in out_mu.chunks_mut(cols).enumerate() {
1261 do_row(row, slice);
1262 }
1263 }
1264 } else {
1265 for (row, slice) in out_mu.chunks_mut(cols).enumerate() {
1266 do_row(row, slice);
1267 }
1268 }
1269
1270 Ok(combos)
1271}
1272
1273#[cfg(feature = "python")]
1274use pyo3::exceptions::PyValueError;
1275#[cfg(feature = "python")]
1276use pyo3::prelude::*;
1277
#[cfg(feature = "python")]
/// Python binding: computes the Triangular Moving Average of a 1-D float64 array.
///
/// `kernel` is an optional kernel name checked by `validate_kernel` (non-batch mode).
/// The computation runs with the GIL released and the output has the same length as
/// `data`.
///
/// # Errors
/// Propagates array-access and kernel-validation errors; any TRIMA error is raised as
/// `ValueError`.
#[pyfunction(name = "trima")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn trima_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let series = data.as_slice()?;
    let chosen_kernel = validate_kernel(kernel, false)?;

    let input = TrimaInput::from_slice(
        series,
        TrimaParams {
            period: Some(period),
        },
    );

    let values = py
        .allow_threads(|| trima_with_kernel(&input, chosen_kernel))
        .map(|output| output.values)
        .map_err(|err| PyValueError::new_err(err.to_string()))?;

    Ok(values.into_pyarray(py))
}
1304
#[cfg(feature = "python")]
/// Python-visible `TrimaStream` class wrapping the Rust streaming TRIMA state.
#[pyclass(name = "TrimaStream")]
pub struct TrimaStreamPy {
    /// Underlying streaming state, advanced one sample per `update` call.
    stream: TrimaStream,
}
1310
#[cfg(feature = "python")]
#[pymethods]
impl TrimaStreamPy {
    /// Builds a stream for `period` (`None` lets `TrimaStream::try_new` apply its default).
    ///
    /// # Errors
    /// Any construction error from `TrimaStream::try_new` is raised as `ValueError`.
    #[new]
    fn new(period: Option<usize>) -> PyResult<Self> {
        match TrimaStream::try_new(TrimaParams { period }) {
            Ok(stream) => Ok(Self { stream }),
            Err(err) => Err(PyValueError::new_err(err.to_string())),
        }
    }

    /// Feeds one sample; returns `None` until the stream has produced a value.
    fn update(&mut self, value: f64) -> Option<f64> {
        let state = &mut self.stream;
        state.update(value)
    }
}
1326
#[cfg(feature = "python")]
/// Python binding: evaluates TRIMA for every period of a `(start, end, step)` sweep.
///
/// Returns a dict with:
/// * `"values"`: a `(rows, cols)` float64 array, one row per expanded period, where
///   `cols == len(data)`;
/// * `"periods"`: the period used for each row, as unsigned integers.
///
/// `kernel` is validated in batch mode. `Auto` resolves through `detect_best_batch_kernel`,
/// and batch kernels are mapped to the per-row SIMD kernel expected by
/// `trima_batch_inner_into`.
///
/// # Errors
/// `ValueError` when `rows * cols` overflows or the batch computation fails. Array-access
/// and kernel-validation errors are propagated.
#[pyfunction(name = "trima_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn trima_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let series = data.as_slice()?;
    let sweep = TrimaBatchRange {
        period: period_range,
    };

    let rows = expand_grid(&sweep).len();
    let cols = series.len();
    let total = match rows.checked_mul(cols) {
        Some(n) => n,
        None => {
            return Err(PyValueError::new_err(
                "size overflow: rows*cols exceeds usize",
            ))
        }
    };

    // Uninitialised output buffer. This relies on `trima_batch_inner_into` writing every
    // element (warm-up prefixes plus row kernels).
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let out_slice = unsafe { out_arr.as_slice_mut()? };

    let resolved = match validate_kernel(kernel, true)? {
        Kernel::Auto => detect_best_batch_kernel(),
        chosen => chosen,
    };
    // Batch kernels and their plain counterparts share a row kernel; anything else is scalar.
    let row_kernel = match resolved {
        Kernel::Avx512Batch | Kernel::Avx512 => Kernel::Avx512,
        Kernel::Avx2Batch | Kernel::Avx2 => Kernel::Avx2,
        _ => Kernel::Scalar,
    };

    let combos = py
        .allow_threads(|| trima_batch_inner_into(series, &sweep, row_kernel, true, out_slice))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let periods: Vec<u64> = combos
        .iter()
        .map(|combo| combo.period.unwrap_or(30) as u64)
        .collect();

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item("periods", periods.into_pyarray(py))?;
    Ok(dict)
}
1385
#[cfg(all(feature = "python", feature = "cuda"))]
/// Python binding: runs the TRIMA period sweep on a CUDA device.
///
/// The result stays in device memory and is wrapped in a `DeviceArrayF32Trima` handle.
/// The float64 input is narrowed to f32 on the host before being handed to `CudaTrima`.
///
/// # Errors
/// `ValueError` when CUDA is unavailable, the device cannot be opened, or the kernel fails.
#[pyfunction(name = "trima_cuda_batch_dev")]
#[pyo3(signature = (data, period_range, device_id=0))]
pub fn trima_cuda_batch_dev_py(
    py: Python<'_>,
    data: numpy::PyReadonlyArray1<'_, f64>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<DeviceArrayF32TrimaPy> {
    use numpy::PyArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let host_f32: Vec<f32> = data.as_slice()?.iter().map(|&x| x as f32).collect();
    let sweep = TrimaBatchRange {
        period: period_range,
    };

    let device_result = py.allow_threads(|| -> PyResult<_> {
        let engine =
            CudaTrima::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        engine
            .trima_batch_dev(&host_f32, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    Ok(DeviceArrayF32TrimaPy {
        inner: Some(device_result),
    })
}
1416
#[cfg(all(feature = "python", feature = "cuda"))]
/// Python binding: computes TRIMA with a single `period` across many series on a CUDA device.
///
/// `data_tm_f32` must be a contiguous 2-D f32 array. Its shape is read as `(rows, cols)` and
/// forwarded to `CudaTrima::trima_multi_series_one_param_time_major_dev` as `(cols, rows)`.
/// NOTE(review): the `_tm` naming suggests a time-major layout (rows = time steps,
/// cols = series); confirm against the CUDA wrapper.
///
/// # Errors
/// `ValueError` when CUDA is unavailable, the device cannot be opened, or the kernel fails.
/// Non-contiguous input fails at `as_slice`.
#[pyfunction(name = "trima_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, device_id=0))]
pub fn trima_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32TrimaPy> {
    use numpy::PyUntypedArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let flat_in = data_tm_f32.as_slice()?;
    let rows = data_tm_f32.shape()[0];
    let cols = data_tm_f32.shape()[1];
    let params = TrimaParams {
        period: Some(period),
    };

    let inner = py.allow_threads(|| {
        let cuda = CudaTrima::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        // The parameters are passed by reference; the previous `¶ms` token (a mangled
        // `&params`) does not compile.
        cuda.trima_multi_series_one_param_time_major_dev(flat_in, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    Ok(DeviceArrayF32TrimaPy { inner: Some(inner) })
}
1447
#[cfg(all(feature = "python", feature = "cuda"))]
/// Python handle (`ta_indicators.cuda.DeviceArrayF32Trima`) to a TRIMA f32 result in CUDA
/// device memory.
///
/// The buffer can be inspected through `__cuda_array_interface__` or handed off exactly once
/// through `__dlpack__`.
#[pyclass(
    module = "ta_indicators.cuda",
    name = "DeviceArrayF32Trima",
    unsendable
)]
pub struct DeviceArrayF32TrimaPy {
    /// `Some` until `__dlpack__` moves the buffer out (via `Option::take`); `None` afterwards.
    pub(crate) inner: Option<DeviceArrayF32Trima>,
}
1457
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl DeviceArrayF32TrimaPy {
    /// CUDA Array Interface (version 3) description of the device buffer.
    ///
    /// Describes a row-major `(rows, cols)` little-endian f32 matrix. The pointer is reported
    /// as writable (`read_only = false`). A zero-sized matrix reports a null data pointer.
    /// NOTE(review): no `"stream"` key is set, so consumers get no stream-ordering hint.
    ///
    /// # Errors
    /// `ValueError` if the buffer was already exported via `__dlpack__`.
    #[getter]
    fn __cuda_array_interface__<'py>(
        &self,
        py: Python<'py>,
    ) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
        use pyo3::types::PyDict;
        let inner = self
            .inner
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("buffer already exported via __dlpack__"))?;
        let d = PyDict::new(py);
        d.set_item("shape", (inner.rows, inner.cols))?;
        d.set_item("typestr", "<f4")?;
        // Byte strides: one full row, then one element (C-contiguous).
        d.set_item(
            "strides",
            (
                inner.cols * std::mem::size_of::<f32>(),
                std::mem::size_of::<f32>(),
            ),
        )?;

        // Empty arrays advertise a null pointer instead of a device address.
        let ptr_val: usize = if inner.rows == 0 || inner.cols == 0 {
            0
        } else {
            inner.device_ptr() as usize
        };
        d.set_item("data", (ptr_val, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    /// DLPack device tuple `(device_type, device_id)`.
    ///
    /// `2` is DLPack's `kDLCUDA` device type; the id is the buffer's CUDA device.
    ///
    /// # Errors
    /// `ValueError` if the buffer was already exported via `__dlpack__`.
    fn __dlpack_device__(&self) -> PyResult<(i32, i32)> {
        let inner = self
            .inner
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("buffer already exported via __dlpack__"))?;
        Ok((2, inner.device_id as i32))
    }

    /// Exports the buffer as a DLPack capsule, transferring ownership out of this handle.
    ///
    /// * `dl_device`: if given as `(type, id)` and it differs from this buffer's device,
    ///   the call fails. Cross-device copies are not implemented, even with `copy=True`.
    /// * `stream`: accepted but ignored; no stream synchronisation is performed here.
    /// * `max_version`: forwarded to `export_f32_cuda_dlpack_2d`.
    ///
    /// After a successful call `inner` is `None` and every other method on this handle errors.
    #[pyo3(signature=(stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<PyObject>,
        max_version: Option<PyObject>,
        dl_device: Option<PyObject>,
        copy: Option<PyObject>,
    ) -> PyResult<PyObject> {
        // Also fails early if the buffer was already exported.
        let (kdl, alloc_dev) = self.__dlpack_device__()?;

        if let Some(dev_obj) = dl_device.as_ref() {
            // A `dl_device` that is not an `(int, int)` tuple is silently ignored.
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(PyValueError::new_err(
                            "device copy not implemented for __dlpack__",
                        ));
                    } else {
                        return Err(PyValueError::new_err("dl_device mismatch for __dlpack__"));
                    }
                }
            }
        }

        // The producer's stream is not synchronised with `stream`.
        let _ = stream;

        // Move the buffer out; from here on this handle is empty.
        let inner = self
            .inner
            .take()
            .ok_or_else(|| PyValueError::new_err("buffer already exported via __dlpack__"))?;

        let DeviceArrayF32Trima {
            buf,
            rows,
            cols,
            ctx: _ctx,
            device_id,
        } = inner;

        // NOTE(review): `alloc_dev` was derived from this same `device_id` above, so this
        // check appears unable to fail. If it ever did, the buffer taken above would be
        // dropped rather than restored.
        if device_id as i32 != alloc_dev {
            return Err(PyValueError::new_err("device id mismatch for __dlpack__"));
        }

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
1554
1555#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1556use serde::{Deserialize, Serialize};
1557#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1558use serde_wasm_bindgen;
1559#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1560use wasm_bindgen::prelude::*;
1561
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// WASM binding: computes TRIMA over `data` with the automatically selected kernel.
///
/// The returned vector has the same length as `data`.
///
/// # Errors
/// Any TRIMA error is converted to a JS string value.
pub fn trima_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = TrimaInput::from_slice(
        data,
        TrimaParams {
            period: Some(period),
        },
    );

    let mut values = vec![0.0; data.len()];
    match trima_into_slice(&mut values, &input, Kernel::Auto) {
        Ok(_) => Ok(values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1578
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
/// Configuration object accepted by the JS `trima_batch` function (deserialized via serde).
#[derive(Serialize, Deserialize)]
pub struct TrimaBatchConfig {
    /// Period sweep as `(start, end, step)`, used as `TrimaBatchRange::period`.
    pub period_range: (usize, usize, usize),
}
1584
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
/// Result returned by the JS `trima_batch` function (serialized via serde).
#[derive(Serialize, Deserialize)]
pub struct TrimaBatchJsOutput {
    /// Flattened `rows * cols` output values; presumably row-major with one row per combo.
    pub values: Vec<f64>,
    /// Parameter combination used for each row.
    pub combos: Vec<TrimaParams>,
    /// Number of parameter combinations (rows).
    pub rows: usize,
    /// Number of values per row (input length).
    pub cols: usize,
}
1593
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
/// WASM binding (`trima_batch`): runs a TRIMA period sweep described by a JS config object.
///
/// `config` must deserialize into `TrimaBatchConfig`. The result is a serialized
/// `TrimaBatchJsOutput` containing the flattened values, the combos, and the matrix shape.
///
/// # Errors
/// A JS string error for an invalid config, a failed computation, or a serialization failure.
#[wasm_bindgen(js_name = trima_batch)]
pub fn trima_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let parsed = serde_wasm_bindgen::from_value::<TrimaBatchConfig>(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = TrimaBatchRange {
        period: parsed.period_range,
    };

    let batch = match trima_batch_inner(data, &sweep, detect_best_kernel(), false) {
        Ok(batch) => batch,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };

    let payload = TrimaBatchJsOutput {
        rows: batch.rows,
        cols: batch.cols,
        values: batch.values,
        combos: batch.combos,
    };

    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
1615
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// WASM binding: runs a TRIMA period sweep and returns only the flattened output values.
///
/// # Errors
/// Any batch error is converted to a JS string value.
pub fn trima_batch_js(
    data: &[f64],
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let range = TrimaBatchRange {
        period: (period_start, period_end, period_step),
    };

    match trima_batch_inner(data, &range, Kernel::Auto, false) {
        Ok(batch) => Ok(batch.values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1633
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// WASM binding: lists the period of each row a sweep would produce, as `f64` values.
///
/// Combos without an explicit period are reported as `30`.
pub fn trima_batch_metadata_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let combos = expand_grid(&TrimaBatchRange {
        period: (period_start, period_end, period_step),
    });

    let mut periods = Vec::with_capacity(combos.len());
    for combo in &combos {
        periods.push(combo.period.unwrap_or(30) as f64);
    }
    Ok(periods)
}
1654
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// Allocates an uninitialised buffer with room for `len` f64 values and leaks it to the caller.
///
/// The buffer must later be released with `trima_free(ptr, len)` using the same `len`.
pub fn trima_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop keeps the allocation alive after this function returns.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1663
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// Releases a buffer obtained from `trima_alloc(len)`.
///
/// `len` must equal the value passed to `trima_alloc`, because it is the capacity of the
/// leaked `Vec`. A null `ptr` is ignored, mirroring C `free(NULL)`.
pub fn trima_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        // `Vec::from_raw_parts` requires a non-null pointer; treat null as a no-op.
        return;
    }
    // SAFETY: `ptr` was produced by `trima_alloc(len)`, which leaked a `Vec<f64>` of capacity
    // `len`. A length of 0 makes no claim about which elements are initialised, and `f64`
    // has no drop glue, so only the allocation itself is freed.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
1671
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// WASM pointer API: computes TRIMA over `len` values at `in_ptr` into `len` slots at `out_ptr`.
///
/// When `in_ptr == out_ptr`, the result is computed into a scratch buffer and copied back.
///
/// # Errors
/// Null pointers, `period == 0`, `period > len`, or any TRIMA error.
pub fn trima_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to trima_into"));
    }
    if period == 0 || period > len {
        return Err(JsValue::from_str("Invalid period"));
    }

    // SAFETY: the caller guarantees `in_ptr` is valid for `len` reads.
    let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };
    let input = TrimaInput::from_slice(
        data,
        TrimaParams {
            period: Some(period),
        },
    );

    if in_ptr == out_ptr {
        let mut scratch = vec![0.0; len];
        trima_into_slice(&mut scratch, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // SAFETY: the caller guarantees `out_ptr` is valid for `len` writes.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        out.copy_from_slice(&scratch);
    } else {
        // SAFETY: the caller guarantees `out_ptr` is valid for `len` writes.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        trima_into_slice(out, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(())
}
1707
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// WASM pointer API: writes TRIMA for every period of a `(start, end, step)` sweep.
///
/// `in_ptr` must point to `len` readable f64 values. `out_ptr` must point to `rows * len`
/// writable f64 values, where `rows` is the number of periods the sweep expands to; each
/// row holds `len` outputs. Returns `rows`.
///
/// The input may share memory with the output region. Overlap is detected and the input is
/// copied first, so a shared and a mutable slice never alias the same bytes.
///
/// # Errors
/// Null pointers, `rows * len` overflow, or any batch error, as JS string values.
pub fn trima_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to trima_batch_into"));
    }

    let sweep = TrimaBatchRange {
        period: (period_start, period_end, period_step),
    };

    let rows = expand_grid(&sweep).len();
    let cols = len;
    // Must be checked before building the output slice: a wrapped length would make
    // `from_raw_parts_mut` describe the wrong memory region.
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("size overflow: rows*cols exceeds usize"))?;

    // Byte ranges of the input and output buffers, used to detect aliasing.
    let elem = std::mem::size_of::<f64>();
    let in_start = in_ptr as usize;
    let in_end = in_start.saturating_add(len.saturating_mul(elem));
    let out_start = out_ptr as usize;
    let out_end = out_start.saturating_add(total.saturating_mul(elem));
    let overlaps = in_start < out_end && out_start < in_end;

    let owned_input: Vec<f64>;
    // SAFETY: the caller guarantees `in_ptr` is valid for `len` reads. On overlap, the
    // temporary shared slice is only used for the copy and ends before `out` is created.
    let data: &[f64] = unsafe {
        if overlaps {
            owned_input = std::slice::from_raw_parts(in_ptr, len).to_vec();
            &owned_input
        } else {
            std::slice::from_raw_parts(in_ptr, len)
        }
    };

    // SAFETY: the caller guarantees `out_ptr` is valid for `total` writes, and `data` does not
    // alias this region (it is either disjoint or a private copy).
    let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };

    trima_batch_inner_into(data, &sweep, Kernel::Auto, false, out)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    Ok(rows)
}
1741
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
#[deprecated(
    since = "1.0.0",
    note = "For streaming patterns, use the fast/unsafe API with persistent buffers"
)]
/// Reusable TRIMA configuration for repeated pointer-based evaluations from JavaScript.
pub struct TrimaContext {
    /// Window length; `new` rejects 0 and any value `<= 3`.
    period: usize,
    /// First SMA length, `(period + 1) / 2`, computed in `new`.
    m1: usize,
    /// Second SMA length, `period - m1 + 1`, computed in `new`.
    m2: usize,
    /// Set to 0 by `new`. NOTE(review): `update_into` computes its own local `first`, so this
    /// field appears unused.
    first: usize,
    /// Kernel forwarded to `trima_compute_into`; `new` sets `Kernel::Auto`.
    kernel: Kernel,
}
1755
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
#[allow(deprecated)]
impl TrimaContext {
    /// Creates a context for `period`, precomputing the two SMA lengths `m1` and `m2`.
    ///
    /// # Errors
    /// Returns an error when `period == 0` or `period <= 3`.
    #[wasm_bindgen(constructor)]
    #[deprecated(
        since = "1.0.0",
        note = "For streaming patterns, use the fast/unsafe API with persistent buffers"
    )]
    pub fn new(period: usize) -> Result<TrimaContext, JsValue> {
        if period == 0 {
            return Err(JsValue::from_str("Invalid period: 0"));
        }
        if period <= 3 {
            return Err(JsValue::from_str(&format!("Period too small: {}", period)));
        }

        let m1 = (period + 1) / 2;
        let m2 = period - m1 + 1;

        Ok(TrimaContext {
            period,
            m1,
            m2,
            first: 0,
            kernel: Kernel::Auto,
        })
    }

    /// Computes TRIMA over `len` values at `in_ptr` into `len` slots at `out_ptr`.
    ///
    /// `in_ptr` may equal `out_ptr`; the result is then computed into a scratch buffer before
    /// the output memory is touched. The first `first + period - 1` outputs are set to NaN,
    /// where `first` is the index of the first non-NaN input (0 if all inputs are NaN).
    ///
    /// # Errors
    /// * a null pointer;
    /// * `len < period`;
    /// * fewer than `period` values from the first non-NaN input onward. Previously this case
    ///   panicked while writing the warm-up prefix.
    pub fn update_into(
        &self,
        in_ptr: *const f64,
        out_ptr: *mut f64,
        len: usize,
    ) -> Result<(), JsValue> {
        if in_ptr.is_null() || out_ptr.is_null() {
            return Err(JsValue::from_str("null pointer passed to update_into"));
        }
        if len < self.period {
            return Err(JsValue::from_str("Data length less than period"));
        }

        // SAFETY: the caller guarantees `in_ptr` is valid for `len` reads and `out_ptr` for
        // `len` writes. The mutable output slice is only created once the shared `data`
        // slice is no longer used, so the two never alias while both are live.
        unsafe {
            let data = std::slice::from_raw_parts(in_ptr, len);

            let first = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
            if len - first < self.period {
                return Err(JsValue::from_str("Not enough valid data for period"));
            }
            // Bounded by `len - 1` thanks to the check above.
            let warmup = first + self.period - 1;

            if in_ptr == out_ptr {
                let mut temp = vec![0.0; len];
                trima_compute_into(
                    data,
                    self.period,
                    self.m1,
                    self.m2,
                    first,
                    self.kernel,
                    &mut temp,
                );
                temp[..warmup].fill(f64::NAN);
                std::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&temp);
            } else {
                let out = std::slice::from_raw_parts_mut(out_ptr, len);
                trima_compute_into(data, self.period, self.m1, self.m2, first, self.kernel, out);
                out[..warmup].fill(f64::NAN);
            }
        }

        Ok(())
    }

    /// Returns `period - 1`, the warm-up length when the input has no leading NaNs.
    pub fn get_warmup_period(&self) -> usize {
        self.period - 1
    }
}
1831
1832#[cfg(test)]
1833mod tests {
1834 use super::*;
1835 use crate::skip_if_unsupported;
1836 use crate::utilities::data_loader::read_candles_from_csv;
1837 use crate::utilities::enums::Kernel;
1838
1839 fn check_trima_partial_params(
1840 test_name: &str,
1841 kernel: Kernel,
1842 ) -> Result<(), Box<dyn std::error::Error>> {
1843 skip_if_unsupported!(kernel, test_name);
1844 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1845 let candles = read_candles_from_csv(file_path)?;
1846
1847 let default_params = TrimaParams { period: None };
1848 let input = TrimaInput::from_candles(&candles, "close", default_params);
1849 let output = trima_with_kernel(&input, kernel)?;
1850 assert_eq!(output.values.len(), candles.close.len());
1851
1852 let params_period_10 = TrimaParams { period: Some(10) };
1853 let input2 = TrimaInput::from_candles(&candles, "hl2", params_period_10);
1854 let output2 = trima_with_kernel(&input2, kernel)?;
1855 assert_eq!(output2.values.len(), candles.close.len());
1856
1857 let params_custom = TrimaParams { period: Some(14) };
1858 let input3 = TrimaInput::from_candles(&candles, "hlc3", params_custom);
1859 let output3 = trima_with_kernel(&input3, kernel)?;
1860 assert_eq!(output3.values.len(), candles.close.len());
1861
1862 Ok(())
1863 }
1864
1865 fn check_trima_accuracy(
1866 test_name: &str,
1867 kernel: Kernel,
1868 ) -> Result<(), Box<dyn std::error::Error>> {
1869 skip_if_unsupported!(kernel, test_name);
1870 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1871 let candles = read_candles_from_csv(file_path)?;
1872 let close_prices = &candles.close;
1873 let params = TrimaParams { period: Some(30) };
1874 let input = TrimaInput::from_candles(&candles, "close", params);
1875 let trima_result = trima_with_kernel(&input, kernel)?;
1876
1877 assert_eq!(
1878 trima_result.values.len(),
1879 close_prices.len(),
1880 "TRIMA output length should match input data length"
1881 );
1882 let expected_last_five_trima = [
1883 59957.916666666664,
1884 59846.770833333336,
1885 59750.620833333334,
1886 59665.2125,
1887 59581.612499999996,
1888 ];
1889 assert!(
1890 trima_result.values.len() >= 5,
1891 "Not enough TRIMA values for the test"
1892 );
1893 let start_index = trima_result.values.len() - 5;
1894 let result_last_five_trima = &trima_result.values[start_index..];
1895 for (i, &value) in result_last_five_trima.iter().enumerate() {
1896 let expected_value = expected_last_five_trima[i];
1897 assert!(
1898 (value - expected_value).abs() < 1e-6,
1899 "[{}] TRIMA value mismatch at index {}: expected {}, got {}",
1900 test_name,
1901 i,
1902 expected_value,
1903 value
1904 );
1905 }
1906 let period = input.params.period.unwrap_or(14);
1907 for i in 0..(period - 1) {
1908 assert!(
1909 trima_result.values[i].is_nan(),
1910 "[{}] Expected NaN at early index {} for TRIMA, got {}",
1911 test_name,
1912 i,
1913 trima_result.values[i]
1914 );
1915 }
1916 Ok(())
1917 }
1918
1919 fn check_trima_default_candles(
1920 test_name: &str,
1921 kernel: Kernel,
1922 ) -> Result<(), Box<dyn std::error::Error>> {
1923 skip_if_unsupported!(kernel, test_name);
1924 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1925 let candles = read_candles_from_csv(file_path)?;
1926 let input = TrimaInput::with_default_candles(&candles);
1927 match input.data {
1928 TrimaData::Candles { source, .. } => assert_eq!(source, "close"),
1929 _ => panic!("Expected TrimaData::Candles"),
1930 }
1931 let output = trima_with_kernel(&input, kernel)?;
1932 assert_eq!(output.values.len(), candles.close.len());
1933 Ok(())
1934 }
1935
1936 fn check_trima_zero_period(
1937 test_name: &str,
1938 kernel: Kernel,
1939 ) -> Result<(), Box<dyn std::error::Error>> {
1940 skip_if_unsupported!(kernel, test_name);
1941 let input_data = [10.0, 20.0, 30.0];
1942 let params = TrimaParams { period: Some(0) };
1943 let input = TrimaInput::from_slice(&input_data, params);
1944 let res = trima_with_kernel(&input, kernel);
1945 assert!(
1946 res.is_err(),
1947 "[{}] TRIMA should fail with zero period",
1948 test_name
1949 );
1950 Ok(())
1951 }
1952
1953 fn check_trima_period_too_small(
1954 test_name: &str,
1955 kernel: Kernel,
1956 ) -> Result<(), Box<dyn std::error::Error>> {
1957 skip_if_unsupported!(kernel, test_name);
1958 let input_data = [10.0, 20.0, 30.0, 40.0];
1959 let params = TrimaParams { period: Some(3) };
1960 let input = TrimaInput::from_slice(&input_data, params);
1961 let res = trima_with_kernel(&input, kernel);
1962 assert!(
1963 res.is_err(),
1964 "[{}] TRIMA should fail with period <= 3",
1965 test_name
1966 );
1967 Ok(())
1968 }
1969
1970 fn check_trima_period_exceeds_length(
1971 test_name: &str,
1972 kernel: Kernel,
1973 ) -> Result<(), Box<dyn std::error::Error>> {
1974 skip_if_unsupported!(kernel, test_name);
1975 let data_small = [10.0, 20.0, 30.0];
1976 let params = TrimaParams { period: Some(10) };
1977 let input = TrimaInput::from_slice(&data_small, params);
1978 let res = trima_with_kernel(&input, kernel);
1979 assert!(
1980 res.is_err(),
1981 "[{}] TRIMA should fail with period exceeding length",
1982 test_name
1983 );
1984 Ok(())
1985 }
1986
1987 fn check_trima_very_small_dataset(
1988 test_name: &str,
1989 kernel: Kernel,
1990 ) -> Result<(), Box<dyn std::error::Error>> {
1991 skip_if_unsupported!(kernel, test_name);
1992 let single_point = [42.0];
1993 let params = TrimaParams { period: Some(14) };
1994 let input = TrimaInput::from_slice(&single_point, params);
1995 let res = trima_with_kernel(&input, kernel);
1996 assert!(
1997 res.is_err(),
1998 "[{}] TRIMA should fail with insufficient data",
1999 test_name
2000 );
2001 Ok(())
2002 }
2003
2004 fn check_trima_reinput(
2005 test_name: &str,
2006 kernel: Kernel,
2007 ) -> Result<(), Box<dyn std::error::Error>> {
2008 skip_if_unsupported!(kernel, test_name);
2009 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2010 let candles = read_candles_from_csv(file_path)?;
2011
2012 let first_params = TrimaParams { period: Some(14) };
2013 let first_input = TrimaInput::from_candles(&candles, "close", first_params);
2014 let first_result = trima_with_kernel(&first_input, kernel)?;
2015
2016 let second_params = TrimaParams { period: Some(10) };
2017 let second_input = TrimaInput::from_slice(&first_result.values, second_params);
2018 let second_result = trima_with_kernel(&second_input, kernel)?;
2019
2020 assert_eq!(second_result.values.len(), first_result.values.len());
2021 for val in &second_result.values[240..] {
2022 assert!(val.is_finite());
2023 }
2024 Ok(())
2025 }
2026
2027 fn check_trima_nan_handling(
2028 test_name: &str,
2029 kernel: Kernel,
2030 ) -> Result<(), Box<dyn std::error::Error>> {
2031 skip_if_unsupported!(kernel, test_name);
2032 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2033 let candles = read_candles_from_csv(file_path)?;
2034
2035 let input = TrimaInput::from_candles(&candles, "close", TrimaParams { period: Some(14) });
2036 let res = trima_with_kernel(&input, kernel)?;
2037 assert_eq!(res.values.len(), candles.close.len());
2038 if res.values.len() > 240 {
2039 for (i, &val) in res.values[240..].iter().enumerate() {
2040 assert!(
2041 !val.is_nan(),
2042 "[{}] Found unexpected NaN at out-index {}",
2043 test_name,
2044 240 + i
2045 );
2046 }
2047 }
2048 Ok(())
2049 }
2050
2051 fn check_trima_streaming(
2052 test_name: &str,
2053 kernel: Kernel,
2054 ) -> Result<(), Box<dyn std::error::Error>> {
2055 skip_if_unsupported!(kernel, test_name);
2056
2057 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2058 let candles = read_candles_from_csv(file_path)?;
2059
2060 let period = 14;
2061
2062 let input = TrimaInput::from_candles(
2063 &candles,
2064 "close",
2065 TrimaParams {
2066 period: Some(period),
2067 },
2068 );
2069 let batch_output = trima_with_kernel(&input, kernel)?.values;
2070
2071 let mut stream = TrimaStream::try_new(TrimaParams {
2072 period: Some(period),
2073 })?;
2074
2075 let mut stream_values = Vec::with_capacity(candles.close.len());
2076 for &price in &candles.close {
2077 match stream.update(price) {
2078 Some(trima_val) => stream_values.push(trima_val),
2079 None => stream_values.push(f64::NAN),
2080 }
2081 }
2082
2083 assert_eq!(batch_output.len(), stream_values.len());
2084 for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
2085 if b.is_nan() && s.is_nan() {
2086 continue;
2087 }
2088 let diff = (b - s).abs();
2089 assert!(
2090 diff < 1e-8,
2091 "[{}] TRIMA streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
2092 test_name,
2093 i,
2094 b,
2095 s,
2096 diff
2097 );
2098 }
2099 Ok(())
2100 }
2101
2102 macro_rules! generate_all_trima_tests {
2103 ($($test_fn:ident),*) => {
2104 paste! {
2105 $(
2106 #[test]
2107 fn [<$test_fn _scalar_f64>]() {
2108 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
2109 }
2110 )*
2111 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2112 $(
2113 #[test]
2114 fn [<$test_fn _avx2_f64>]() {
2115 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
2116 }
2117 #[test]
2118 fn [<$test_fn _avx512_f64>]() {
2119 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
2120 }
2121 )*
2122 }
2123 }
2124 }
2125
2126 #[cfg(debug_assertions)]
2127 fn check_trima_no_poison(
2128 test_name: &str,
2129 kernel: Kernel,
2130 ) -> Result<(), Box<dyn std::error::Error>> {
2131 skip_if_unsupported!(kernel, test_name);
2132
2133 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2134 let candles = read_candles_from_csv(file_path)?;
2135
2136 let test_periods = vec![4, 10, 14, 30, 50, 100];
2137 let test_sources = vec!["close", "open", "high", "low", "hl2", "hlc3", "ohlc4"];
2138
2139 for period in test_periods {
2140 for source in &test_sources {
2141 let params = TrimaParams {
2142 period: Some(period),
2143 };
2144 let input = TrimaInput::from_candles(&candles, source, params);
2145 let output = trima_with_kernel(&input, kernel)?;
2146
2147 for (i, &val) in output.values.iter().enumerate() {
2148 if val.is_nan() {
2149 continue;
2150 }
2151
2152 let bits = val.to_bits();
2153
2154 if bits == 0x11111111_11111111 {
2155 panic!(
2156 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} (period={}, source={})",
2157 test_name, val, bits, i, period, source
2158 );
2159 }
2160
2161 if bits == 0x22222222_22222222 {
2162 panic!(
2163 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} (period={}, source={})",
2164 test_name, val, bits, i, period, source
2165 );
2166 }
2167
2168 if bits == 0x33333333_33333333 {
2169 panic!(
2170 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} (period={}, source={})",
2171 test_name, val, bits, i, period, source
2172 );
2173 }
2174 }
2175 }
2176 }
2177
2178 Ok(())
2179 }
2180
2181 #[cfg(not(debug_assertions))]
2182 fn check_trima_no_poison(
2183 _test_name: &str,
2184 _kernel: Kernel,
2185 ) -> Result<(), Box<dyn std::error::Error>> {
2186 Ok(())
2187 }
2188
2189 #[cfg(feature = "proptest")]
2190 #[allow(clippy::float_cmp)]
2191 fn check_trima_property(
2192 test_name: &str,
2193 kernel: Kernel,
2194 ) -> Result<(), Box<dyn std::error::Error>> {
2195 use crate::indicators::sma::{sma, SmaData, SmaInput, SmaParams};
2196 use proptest::prelude::*;
2197 skip_if_unsupported!(kernel, test_name);
2198
2199 let strat = (4usize..=100).prop_flat_map(|period| {
2200 (
2201 prop::collection::vec(
2202 (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
2203 period..400,
2204 ),
2205 Just(period),
2206 )
2207 });
2208
2209 proptest::test_runner::TestRunner::default().run(&strat, |(data, period)| {
2210 let params = TrimaParams {
2211 period: Some(period),
2212 };
2213 let input = TrimaInput::from_slice(&data, params);
2214
2215 let result = trima_with_kernel(&input, kernel)?;
2216 let scalar_result = trima_with_kernel(&input, Kernel::Scalar)?;
2217
2218 let first = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
2219 let warmup_end = first + period - 1;
2220
2221 for i in 0..warmup_end.min(data.len()) {
2222 prop_assert!(
2223 result.values[i].is_nan(),
2224 "Expected NaN during warmup at index {}, got {}",
2225 i,
2226 result.values[i]
2227 );
2228 }
2229
2230 for i in warmup_end..data.len() {
2231 prop_assert!(
2232 result.values[i].is_finite() || data[i].is_nan(),
2233 "Expected finite value after warmup at index {}, got {}",
2234 i,
2235 result.values[i]
2236 );
2237 }
2238
2239 if data[first..]
2240 .windows(2)
2241 .all(|w| (w[0] - w[1]).abs() < 1e-10)
2242 && data.len() > first
2243 {
2244 let constant_val = data[first];
2245 for i in warmup_end..data.len() {
2246 prop_assert!(
2247 (result.values[i] - constant_val).abs() < 1e-9,
2248 "Constant input should produce constant output at index {}: expected {}, got {}",
2249 i,
2250 constant_val,
2251 result.values[i]
2252 );
2253 }
2254 }
2255
2256 for i in 0..data.len() {
2257 let val = result.values[i];
2258 let ref_val = scalar_result.values[i];
2259
2260 if val.is_nan() && ref_val.is_nan() {
2261 continue;
2262 }
2263
2264 if !val.is_finite() || !ref_val.is_finite() {
2265 prop_assert_eq!(
2266 val.to_bits(),
2267 ref_val.to_bits(),
2268 "NaN/Inf mismatch at index {}: {} vs {}",
2269 i,
2270 val,
2271 ref_val
2272 );
2273 } else {
2274 let ulp_diff = val.to_bits().abs_diff(ref_val.to_bits());
2275 prop_assert!(
2276 (val - ref_val).abs() < 1e-9 || ulp_diff <= 4,
2277 "Cross-kernel mismatch at index {}: {} vs {} (ULP diff: {})",
2278 i,
2279 val,
2280 ref_val,
2281 ulp_diff
2282 );
2283 }
2284 }
2285
2286 for (i, &val) in result.values.iter().enumerate() {
2287 prop_assert!(
2288 val.is_nan() || val.is_finite(),
2289 "Value should be finite or NaN at index {}, got {}",
2290 i,
2291 val
2292 );
2293 }
2294
2295 for i in warmup_end..data.len() {
2296 if i >= period - 1 {
2297 let start = if i >= period - 1 { i + 1 - period } else { 0 };
2298 let window = &data[start..=i];
2299 let min_val = window
2300 .iter()
2301 .filter(|x| x.is_finite())
2302 .fold(f64::INFINITY, |a, &b| a.min(b));
2303 let max_val = window
2304 .iter()
2305 .filter(|x| x.is_finite())
2306 .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
2307
2308 if min_val.is_finite() && max_val.is_finite() {
2309 let val = result.values[i];
2310
2311 prop_assert!(
2312 val >= min_val - 1e-6 && val <= max_val + 1e-6,
2313 "TRIMA value {} at index {} outside window bounds [{}, {}]",
2314 val,
2315 i,
2316 min_val,
2317 max_val
2318 );
2319 }
2320 }
2321 }
2322
2323 if period == 4 {
2324 let m1 = 2;
2325 let m2 = 3;
2326
2327 let sma1_input = SmaInput {
2328 data: SmaData::Slice(&data),
2329 params: SmaParams { period: Some(m1) },
2330 };
2331 let pass1 = sma(&sma1_input)?;
2332
2333 let sma2_input = SmaInput {
2334 data: SmaData::Slice(&pass1.values),
2335 params: SmaParams { period: Some(m2) },
2336 };
2337 let expected = sma(&sma2_input)?;
2338
2339 for i in warmup_end..data.len().min(warmup_end + 5) {
2340 prop_assert!(
2341 (result.values[i] - expected.values[i]).abs() < 1e-9,
2342 "Period=4: TRIMA mismatch at index {}: got {}, expected {}",
2343 i,
2344 result.values[i],
2345 expected.values[i]
2346 );
2347 }
2348 }
2349
2350 {
2351 let m1 = (period + 1) / 2;
2352 let m2 = period - m1 + 1;
2353
2354 let sma1_input = SmaInput {
2355 data: SmaData::Slice(&data),
2356 params: SmaParams { period: Some(m1) },
2357 };
2358 let pass1 = sma(&sma1_input)?;
2359
2360 let sma2_input = SmaInput {
2361 data: SmaData::Slice(&pass1.values),
2362 params: SmaParams { period: Some(m2) },
2363 };
2364 let expected = sma(&sma2_input)?;
2365
2366 let check_points = vec![
2367 warmup_end,
2368 warmup_end + period / 2,
2369 warmup_end + period,
2370 data.len() - 1,
2371 ];
2372
2373 for &idx in &check_points {
2374 if idx < data.len() {
2375 let trima_val = result.values[idx];
2376 let expected_val = expected.values[idx];
2377
2378 if trima_val.is_finite() && expected_val.is_finite() {
2379 prop_assert!(
2380 (trima_val - expected_val).abs() < 1e-9,
2381 "Two-pass SMA formula mismatch at index {}: TRIMA={}, Expected={}",
2382 idx,
2383 trima_val,
2384 expected_val
2385 );
2386 }
2387 }
2388 }
2389 }
2390
2391 if data.len() >= warmup_end + 20 {
2392 let sma_input = SmaInput {
2393 data: SmaData::Slice(&data),
2394 params: SmaParams {
2395 period: Some(period),
2396 },
2397 };
2398 let single_sma = sma(&sma_input)?;
2399
2400 let trima_roughness: f64 = result.values[warmup_end..warmup_end + 20]
2401 .windows(2)
2402 .map(|w| (w[1] - w[0]).abs())
2403 .sum();
2404
2405 let sma_roughness: f64 = single_sma.values[warmup_end..warmup_end + 20]
2406 .windows(2)
2407 .map(|w| (w[1] - w[0]).abs())
2408 .sum();
2409
2410 if sma_roughness > 1e-10 {
2411 prop_assert!(
2412 trima_roughness <= sma_roughness * 1.1,
2413 "TRIMA should be smoother than single SMA: TRIMA roughness={}, SMA roughness={}",
2414 trima_roughness,
2415 sma_roughness
2416 );
2417 }
2418 }
2419
2420 if data.len() == period {
2421 prop_assert!(
2422 result.values[period - 1].is_finite(),
2423 "With data.len()==period, last value should be finite, got {}",
2424 result.values[period - 1]
2425 );
2426
2427 for i in 0..period - 1 {
2428 prop_assert!(
2429 result.values[i].is_nan(),
2430 "With data.len()==period, value at {} should be NaN, got {}",
2431 i,
2432 result.values[i]
2433 );
2434 }
2435 }
2436
2437 let is_monotonic_increasing = data[first..].windows(2).all(|w| w[1] >= w[0] - 1e-10);
2438 let is_monotonic_decreasing = data[first..].windows(2).all(|w| w[1] <= w[0] + 1e-10);
2439
2440 if is_monotonic_increasing || is_monotonic_decreasing {
2441 let valid_trima = &result.values[warmup_end..];
2442 if valid_trima.len() >= 2 {
2443 if is_monotonic_increasing {
2444 for w in valid_trima.windows(2) {
2445 prop_assert!(
2446 w[1] >= w[0] - 1e-9,
2447 "TRIMA should preserve increasing trend: {} < {}",
2448 w[1],
2449 w[0]
2450 );
2451 }
2452 } else {
2453 for w in valid_trima.windows(2) {
2454 prop_assert!(
2455 w[1] <= w[0] + 1e-9,
2456 "TRIMA should preserve decreasing trend: {} > {}",
2457 w[1],
2458 w[0]
2459 );
2460 }
2461 }
2462 }
2463 }
2464
2465 #[cfg(debug_assertions)]
2466 {
2467 for (i, &val) in result.values.iter().enumerate() {
2468 if !val.is_nan() {
2469 let bits = val.to_bits();
2470 prop_assert!(
2471 bits != 0x11111111_11111111
2472 && bits != 0x22222222_22222222
2473 && bits != 0x33333333_33333333,
2474 "Found poison value at index {}: {} (0x{:016X})",
2475 i,
2476 val,
2477 bits
2478 );
2479 }
2480 }
2481 }
2482
2483 Ok(())
2484 })?;
2485
2486 Ok(())
2487 }
2488
// Property-based check; only compiled when the `proptest` feature is enabled.
// `generate_all_trima_tests!` is defined outside this view — presumably it emits
// one `#[test]` per kernel variant for each listed check (TODO confirm).
#[cfg(feature = "proptest")]
generate_all_trima_tests!(check_trima_property);

// Single-series checks, each expanded into per-kernel `#[test]` functions by
// `generate_all_trima_tests!` (see its definition for the exact variant list).
generate_all_trima_tests!(
    check_trima_partial_params,
    check_trima_accuracy,
    check_trima_default_candles,
    check_trima_zero_period,
    check_trima_period_exceeds_length,
    check_trima_period_too_small,
    check_trima_very_small_dataset,
    check_trima_reinput,
    check_trima_nan_handling,
    check_trima_streaming,
    check_trima_no_poison
);
2505
2506 fn check_batch_default_row(
2507 test: &str,
2508 kernel: Kernel,
2509 ) -> Result<(), Box<dyn std::error::Error>> {
2510 skip_if_unsupported!(kernel, test);
2511
2512 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2513 let c = read_candles_from_csv(file)?;
2514
2515 let output = TrimaBatchBuilder::new()
2516 .kernel(kernel)
2517 .apply_candles(&c, "close")?;
2518
2519 let def = TrimaParams::default();
2520 let row = output.values_for(&def).expect("default row missing");
2521
2522 assert_eq!(row.len(), c.close.len());
2523
2524 let expected = [
2525 59957.916666666664,
2526 59846.770833333336,
2527 59750.620833333334,
2528 59665.2125,
2529 59581.612499999996,
2530 ];
2531 let start = row.len() - 5;
2532 for (i, &v) in row[start..].iter().enumerate() {
2533 assert!(
2534 (v - expected[i]).abs() < 1e-6,
2535 "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
2536 );
2537 }
2538 Ok(())
2539 }
2540
// Expands a batch check `fn(&str, Kernel) -> Result<(), Box<dyn std::error::Error>>`
// into one `#[test]` per batch kernel: `<name>_scalar`, `<name>_avx2` and
// `<name>_avx512` (only with the `nightly-avx` feature on x86_64), and
// `<name>_auto_detect` (Kernel::Auto).
//
// The check's `Result` is unwrapped rather than discarded. An `Err` bubbled up
// through `?` inside the check (missing CSV fixture, builder error, ...) must fail
// the test. Otherwise every assertion after the failing `?` silently never runs
// and the test reports success.
macro_rules! gen_batch_tests {
    ($fn_name:ident) => {
        paste! {
            #[test] fn [<$fn_name _scalar>]() {
                $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch)
                    .expect("batch check returned Err");
            }
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            #[test] fn [<$fn_name _avx2>]() {
                $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch)
                    .expect("batch check returned Err");
            }
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            #[test] fn [<$fn_name _avx512>]() {
                $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch)
                    .expect("batch check returned Err");
            }
            #[test] fn [<$fn_name _auto_detect>]() {
                $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto)
                    .expect("batch check returned Err");
            }
        }
    };
}
2561
2562 #[cfg(debug_assertions)]
2563 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn std::error::Error>> {
2564 skip_if_unsupported!(kernel, test);
2565
2566 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2567 let c = read_candles_from_csv(file)?;
2568
2569 let period_ranges = vec![(4, 20, 4), (20, 50, 10), (50, 100, 25), (5, 15, 1)];
2570
2571 let test_sources = vec!["close", "open", "high", "low", "hl2", "hlc3", "ohlc4"];
2572
2573 for (start, end, step) in period_ranges {
2574 for source in &test_sources {
2575 let output = TrimaBatchBuilder::new()
2576 .kernel(kernel)
2577 .period_range(start, end, step)
2578 .apply_candles(&c, source)?;
2579
2580 for (idx, &val) in output.values.iter().enumerate() {
2581 if val.is_nan() {
2582 continue;
2583 }
2584
2585 let bits = val.to_bits();
2586 let row = idx / output.cols;
2587 let col = idx % output.cols;
2588
2589 if bits == 0x11111111_11111111 {
2590 panic!(
2591 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at row {} col {} (flat index {}) with period_range({},{},{}) source={}",
2592 test, val, bits, row, col, idx, start, end, step, source
2593 );
2594 }
2595
2596 if bits == 0x22222222_22222222 {
2597 panic!(
2598 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at row {} col {} (flat index {}) with period_range({},{},{}) source={}",
2599 test, val, bits, row, col, idx, start, end, step, source
2600 );
2601 }
2602
2603 if bits == 0x33333333_33333333 {
2604 panic!(
2605 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at row {} col {} (flat index {}) with period_range({},{},{}) source={}",
2606 test, val, bits, row, col, idx, start, end, step, source
2607 );
2608 }
2609 }
2610 }
2611 }
2612
2613 Ok(())
2614 }
2615
/// Release-build stand-in: the poison sentinels are only a debug-build concept,
/// so there is nothing to scan and the check trivially succeeds.
#[cfg(not(debug_assertions))]
fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn std::error::Error>> {
    Ok(())
}
2623
// Instantiate each batch check as `_scalar`, `_avx2`/`_avx512` (nightly-avx on
// x86_64 only) and `_auto_detect` tests via `gen_batch_tests!`.
gen_batch_tests!(check_batch_default_row);
gen_batch_tests!(check_batch_no_poison);
2626
/// The in-place entry point (`trima_into`, or `trima_into_slice` on wasm) must
/// produce exactly the same values as the allocating `trima` API, with NaN
/// positions matching as well.
#[test]
fn test_trima_into_matches_api() -> Result<(), Box<dyn std::error::Error>> {
    // Four leading NaNs followed by 252 samples of a drifting sine wave (256 total).
    let data: Vec<f64> = std::iter::repeat(f64::NAN)
        .take(4)
        .chain((0..252).map(|i| {
            let t = i as f64;
            (t * 0.131).sin() * 7.0 + t * 0.02
        }))
        .collect();

    let input = TrimaInput::from_slice(&data, TrimaParams::default());
    let reference = trima(&input)?;

    let mut out = vec![0.0; data.len()];
    #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
    {
        trima_into(&input, &mut out)?;
    }
    #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
    {
        trima_into_slice(&mut out, &input, Kernel::Auto)?;
    }

    assert_eq!(reference.values.len(), out.len());

    for (&a, &b) in reference.values.iter().zip(out.iter()) {
        let same = (a.is_nan() && b.is_nan()) || a == b;
        assert!(same, "mismatch: got {b:?}, expected {a:?}");
    }

    Ok(())
}
2658}