vector_ta/indicators/moving_averages/
trima.rs

1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::cuda_available;
3#[cfg(all(feature = "python", feature = "cuda"))]
4use crate::cuda::moving_averages::trima_wrapper::DeviceArrayF32Trima;
5#[cfg(all(feature = "python", feature = "cuda"))]
6use crate::cuda::moving_averages::CudaTrima;
7use crate::indicators::sma::{sma, SmaData, SmaInput, SmaParams};
8use crate::utilities::data_loader::{source_type, Candles};
9#[cfg(all(feature = "python", feature = "cuda"))]
10use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
11use crate::utilities::enums::Kernel;
12use crate::utilities::helpers::{
13    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
14    make_uninit_matrix,
15};
16#[cfg(feature = "python")]
17use crate::utilities::kernel_validation::validate_kernel;
18#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
19use core::arch::x86_64::*;
20use paste::paste;
21#[cfg(not(target_arch = "wasm32"))]
22use rayon::prelude::*;
23use std::convert::AsRef;
24use std::mem::MaybeUninit;
25use thiserror::Error;
26
27impl<'a> AsRef<[f64]> for TrimaInput<'a> {
28    #[inline(always)]
29    fn as_ref(&self) -> &[f64] {
30        match &self.data {
31            TrimaData::Slice(slice) => slice,
32            TrimaData::Candles { candles, source } => source_type(candles, source),
33        }
34    }
35}
36
37#[derive(Debug, Clone)]
38pub enum TrimaData<'a> {
39    Candles {
40        candles: &'a Candles,
41        source: &'a str,
42    },
43    Slice(&'a [f64]),
44}
45
/// Result of a single TRIMA run.
#[derive(Clone, Debug)]
pub struct TrimaOutput {
    /// One value per input sample; the entry points in this file fill the
    /// first `first + period - 1` slots with NaN.
    pub values: Vec<f64>,
}
50
/// Tunable parameters of the TRIMA indicator.
#[derive(Clone, Debug)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct TrimaParams {
    /// Window length; `None` is resolved to 30 by `TrimaInput::get_period`
    /// and `TrimaStream::try_new`.
    pub period: Option<usize>,
}
59
60impl Default for TrimaParams {
61    fn default() -> Self {
62        Self { period: Some(30) }
63    }
64}
65
66#[derive(Debug, Clone)]
67pub struct TrimaInput<'a> {
68    pub data: TrimaData<'a>,
69    pub params: TrimaParams,
70}
71
72impl<'a> TrimaInput<'a> {
73    #[inline]
74    pub fn from_candles(c: &'a Candles, s: &'a str, p: TrimaParams) -> Self {
75        Self {
76            data: TrimaData::Candles {
77                candles: c,
78                source: s,
79            },
80            params: p,
81        }
82    }
83    #[inline]
84    pub fn from_slice(sl: &'a [f64], p: TrimaParams) -> Self {
85        Self {
86            data: TrimaData::Slice(sl),
87            params: p,
88        }
89    }
90    #[inline]
91    pub fn with_default_candles(c: &'a Candles) -> Self {
92        Self::from_candles(c, "close", TrimaParams::default())
93    }
94    #[inline]
95    pub fn get_period(&self) -> usize {
96        self.params.period.unwrap_or(30)
97    }
98}
99
100#[derive(Copy, Clone, Debug)]
101pub struct TrimaBuilder {
102    period: Option<usize>,
103    kernel: Kernel,
104}
105
106impl Default for TrimaBuilder {
107    fn default() -> Self {
108        Self {
109            period: None,
110            kernel: Kernel::Auto,
111        }
112    }
113}
114
115impl TrimaBuilder {
116    #[inline(always)]
117    pub fn new() -> Self {
118        Self::default()
119    }
120    #[inline(always)]
121    pub fn period(mut self, n: usize) -> Self {
122        self.period = Some(n);
123        self
124    }
125    #[inline(always)]
126    pub fn kernel(mut self, k: Kernel) -> Self {
127        self.kernel = k;
128        self
129    }
130    #[inline(always)]
131    pub fn apply(self, c: &Candles) -> Result<TrimaOutput, TrimaError> {
132        let p = TrimaParams {
133            period: self.period,
134        };
135        let i = TrimaInput::from_candles(c, "close", p);
136        trima_with_kernel(&i, self.kernel)
137    }
138    #[inline(always)]
139    pub fn apply_slice(self, d: &[f64]) -> Result<TrimaOutput, TrimaError> {
140        let p = TrimaParams {
141            period: self.period,
142        };
143        let i = TrimaInput::from_slice(d, p);
144        trima_with_kernel(&i, self.kernel)
145    }
146    #[inline(always)]
147    pub fn into_stream(self) -> Result<TrimaStream, TrimaError> {
148        let p = TrimaParams {
149            period: self.period,
150        };
151        TrimaStream::try_new(p)
152    }
153}
154
155#[derive(Debug, Error)]
156pub enum TrimaError {
157    #[error("trima: No data provided (input data slice is empty).")]
158    EmptyInputData,
159    #[error("trima: All values are NaN.")]
160    AllValuesNaN,
161
162    #[error("trima: Invalid period: period = {period}, data length = {data_len}")]
163    InvalidPeriod { period: usize, data_len: usize },
164
165    #[error("trima: Not enough valid data: needed = {needed}, valid = {valid}")]
166    NotEnoughValidData { needed: usize, valid: usize },
167
168    #[error("trima: Period too small: {period}")]
169    PeriodTooSmall { period: usize },
170
171    #[error("trima: No data provided.")]
172    NoData,
173
174    #[error("trima: Output length mismatch: expected = {expected}, got = {got}")]
175    OutputLengthMismatch { expected: usize, got: usize },
176
177    #[error("trima: Invalid range: start = {start}, end = {end}, step = {step}")]
178    InvalidRange {
179        start: usize,
180        end: usize,
181        step: usize,
182    },
183
184    #[error("trima: Invalid kernel for batch path: {0:?}")]
185    InvalidKernelForBatch(Kernel),
186}
187
188#[inline]
189pub fn trima(input: &TrimaInput) -> Result<TrimaOutput, TrimaError> {
190    trima_with_kernel(input, Kernel::Auto)
191}
192
193#[inline(always)]
194fn trima_prepare<'a>(
195    input: &'a TrimaInput,
196    kernel: Kernel,
197) -> Result<(&'a [f64], usize, usize, usize, usize, Kernel), TrimaError> {
198    let data: &[f64] = input.as_ref();
199    let len = data.len();
200    if len == 0 {
201        return Err(TrimaError::EmptyInputData);
202    }
203    let first = data
204        .iter()
205        .position(|x| !x.is_nan())
206        .ok_or(TrimaError::AllValuesNaN)?;
207    let period = input.get_period();
208
209    if period == 0 || period > len {
210        return Err(TrimaError::InvalidPeriod {
211            period,
212            data_len: len,
213        });
214    }
215    if period <= 3 {
216        return Err(TrimaError::PeriodTooSmall { period });
217    }
218    if (len - first) < period {
219        return Err(TrimaError::NotEnoughValidData {
220            needed: period,
221            valid: len - first,
222        });
223    }
224
225    let m1 = (period + 1) / 2;
226    let m2 = period - m1 + 1;
227
228    let chosen = match kernel {
229        Kernel::Auto => Kernel::Scalar,
230        k => k,
231    };
232
233    Ok((data, period, m1, m2, first, chosen))
234}
235
236#[inline(always)]
237fn trima_compute_into(
238    data: &[f64],
239    period: usize,
240    m1: usize,
241    m2: usize,
242    first: usize,
243    kernel: Kernel,
244    out: &mut [f64],
245) {
246    unsafe {
247        #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
248        {
249            if matches!(kernel, Kernel::Scalar | Kernel::ScalarBatch) {
250                trima_simd128(data, m1, m2, first, out);
251                return;
252            }
253        }
254
255        match kernel {
256            Kernel::Scalar | Kernel::ScalarBatch => {
257                trima_scalar_optimized(data, period, m1, m2, first, out)
258            }
259            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
260            Kernel::Avx2 | Kernel::Avx2Batch => {
261                trima_avx2(data, period, first, out);
262            }
263            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
264            Kernel::Avx512 | Kernel::Avx512Batch => {
265                trima_avx512(data, period, first, out);
266            }
267            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
268            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
269                trima_scalar_optimized(data, period, m1, m2, first, out)
270            }
271            Kernel::Auto => trima_scalar_optimized(data, period, m1, m2, first, out),
272        }
273    }
274}
275
/// Single-pass scalar TRIMA: a running SMA of length `m1` feeds a running SMA
/// of length `m2` held in a ring buffer.
///
/// Writes `out[first + period - 1..]`; earlier slots are left untouched.
/// Returns without writing when `data` is empty or the warm-up does not fit.
///
/// # Safety
/// Reads and writes are unchecked. The caller must ensure
/// `out.len() == data.len()` (only debug-asserted) and that `m1`/`m2` are the
/// window lengths derived from `period` as in `trima_prepare`.
#[inline(always)]
unsafe fn trima_scalar_optimized(
    data: &[f64],
    period: usize,
    m1: usize,
    m2: usize,
    first: usize,
    out: &mut [f64],
) {
    debug_assert_eq!(data.len(), out.len());
    let n = data.len();
    if n == 0 {
        return;
    }
    let warm = first + period - 1;
    if warm >= n {
        return;
    }

    let inv_m1 = 1.0 / (m1 as f64);
    let inv_m2 = 1.0 / (m2 as f64);
    let src = data.as_ptr();

    // Seed the first-stage sum over data[first..first + m1], four at a time.
    let window = src.add(first);
    let unrolled = m1 & !3usize;
    let mut sum1 = 0.0;
    for k in (0..unrolled).step_by(4) {
        sum1 += *window.add(k) + *window.add(k + 1) + *window.add(k + 2) + *window.add(k + 3);
    }
    for k in unrolled..m1 {
        sum1 += *window.add(k);
    }

    let mut ring: Vec<f64> = Vec::with_capacity(m2);
    let mut sum2 = 0.0;

    // `incoming` / `outgoing` index the samples entering / leaving the
    // first-stage window on the next slide.
    let mut t = first + m1 - 1;
    let mut incoming = first + m1;
    let mut outgoing = first;

    let seed = sum1 * inv_m1;
    ring.push(seed);
    sum2 += seed;

    // Fill the remaining ring slots with successive first-stage averages.
    for _ in 1..m2 {
        t += 1;
        sum1 += *src.add(incoming) - *src.add(outgoing);
        incoming += 1;
        outgoing += 1;

        let stage1 = sum1 * inv_m1;
        ring.push(stage1);
        sum2 += stage1;
    }

    *out.get_unchecked_mut(warm) = sum2 * inv_m2;

    // Steady state: slide both windows by one sample per output.
    let mut head = 0usize;
    for idx in (t + 1)..n {
        sum1 += *src.add(incoming) - *src.add(outgoing);
        incoming += 1;
        outgoing += 1;

        let fresh = sum1 * inv_m1;
        let slot = ring.get_unchecked_mut(head);
        sum2 += fresh - *slot;
        *slot = fresh;

        head += 1;
        if head == m2 {
            head = 0;
        }

        *out.get_unchecked_mut(idx) = sum2 * inv_m2;
    }
}
363
/// WASM `simd128` TRIMA: two explicit SMA passes with a temporary buffer.
///
/// Only the initial window sum of each pass is vectorised (two lanes at a
/// time); the rolling updates are scalar. Writes `out[first + m1 + m2 - 2..]`.
///
/// # Safety
/// Performs unaligned-typed vector loads from `data` and the scratch buffer;
/// the surrounding length guards keep those loads inside the slices.
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
#[inline]
unsafe fn trima_simd128(data: &[f64], m1: usize, m2: usize, first: usize, out: &mut [f64]) {
    use core::arch::wasm32::*;

    const STEP: usize = 2;
    let n = data.len();
    let m1_f = m1 as f64;
    let m2_f = m2 as f64;

    // First-stage SMA; NaN until its window is full.
    let mut sma1 = vec![f64::NAN; n];

    if first + m1 <= n {
        let pairs = m1 / STEP;

        let mut acc = f64x2_splat(0.0);
        for k in 0..pairs {
            let lanes = v128_load(data.as_ptr().add(first + k * STEP) as *const v128);
            acc = f64x2_add(acc, lanes);
        }

        let mut sum = f64x2_extract_lane::<0>(acc) + f64x2_extract_lane::<1>(acc);
        if m1 % STEP != 0 {
            sum += data[first + pairs * STEP];
        }

        sma1[first + m1 - 1] = sum / m1_f;

        for i in (first + m1)..n {
            sum += data[i] - data[i - m1];
            sma1[i] = sum / m1_f;
        }
    }

    if first + m1 + m2 - 1 <= n {
        let stage1_start = first + m1 - 1;
        let pairs2 = m2 / STEP;

        let mut acc2 = f64x2_splat(0.0);
        for k in 0..pairs2 {
            let lanes = v128_load(sma1.as_ptr().add(stage1_start + k * STEP) as *const v128);
            acc2 = f64x2_add(acc2, lanes);
        }

        let mut sum2 = f64x2_extract_lane::<0>(acc2) + f64x2_extract_lane::<1>(acc2);
        if m2 % STEP != 0 {
            sum2 += sma1[stage1_start + pairs2 * STEP];
        }

        out[stage1_start + m2 - 1] = sum2 / m2_f;

        for i in (stage1_start + m2)..n {
            sum2 += sma1[i] - sma1[i - m2];
            out[i] = sum2 / m2_f;
        }
    }
}
424
425pub fn trima_with_kernel(input: &TrimaInput, kernel: Kernel) -> Result<TrimaOutput, TrimaError> {
426    let (data, period, m1, m2, first, chosen) = trima_prepare(input, kernel)?;
427    let len = data.len();
428    let warm = first + period - 1;
429    let mut out = alloc_with_nan_prefix(len, warm);
430    trima_compute_into(data, period, m1, m2, first, chosen, &mut out);
431    Ok(TrimaOutput { values: out })
432}
433
434#[inline]
435pub fn trima_into_slice(
436    output: &mut [f64],
437    input: &TrimaInput,
438    kernel: Kernel,
439) -> Result<(), TrimaError> {
440    let (data, period, m1, m2, first, chosen) = trima_prepare(input, kernel)?;
441
442    if output.len() != data.len() {
443        return Err(TrimaError::OutputLengthMismatch {
444            expected: data.len(),
445            got: output.len(),
446        });
447    }
448
449    trima_compute_into(data, period, m1, m2, first, chosen, output);
450
451    let warmup = first + period - 1;
452    for i in 0..warmup.min(output.len()) {
453        output[i] = f64::NAN;
454    }
455
456    Ok(())
457}
458
459#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
460#[inline(always)]
461pub fn trima_into(input: &TrimaInput, out: &mut [f64]) -> Result<(), TrimaError> {
462    let (data, period, m1, m2, first, chosen) = trima_prepare(input, Kernel::Auto)?;
463
464    if out.len() != data.len() {
465        return Err(TrimaError::OutputLengthMismatch {
466            expected: data.len(),
467            got: out.len(),
468        });
469    }
470
471    let warm = first + period - 1;
472    let end = warm.min(out.len());
473    let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
474    for v in &mut out[..end] {
475        *v = qnan;
476    }
477
478    trima_compute_into(data, period, m1, m2, first, chosen, out);
479    Ok(())
480}
481
/// Reference TRIMA: an SMA of length `m1 = (period + 1) / 2` followed by an
/// SMA of length `m2 = period - m1 + 1`, using a scratch buffer for the first
/// pass.
///
/// `first` is the index of the first sample to use. Only
/// `out[first + period - 1..]` is written; nothing is written when the
/// warm-up does not fit in `data`.
#[inline]
pub fn trima_scalar_classic(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    let n = data.len();
    let m1 = (period + 1) / 2;
    let m2 = period - m1 + 1;
    let m1_f = m1 as f64;
    let m2_f = m2 as f64;

    // First pass: SMA(m1) of the input, NaN until the window is full.
    let mut inner = vec![f64::NAN; n];
    if first + m1 <= n {
        let mut acc = 0.0;
        for &x in &data[first..first + m1] {
            acc += x;
        }
        inner[first + m1 - 1] = acc / m1_f;

        for i in (first + m1)..n {
            acc += data[i] - data[i - m1];
            inner[i] = acc / m1_f;
        }
    }

    let warmup_end = first + period - 1;
    if warmup_end >= n {
        return;
    }
    let inner_start = first + m1 - 1;
    let outer_start = inner_start + m2 - 1;
    if outer_start >= n {
        return;
    }

    // Second pass: SMA(m2) over the first-pass values.
    let mut acc2 = 0.0;
    for &s in &inner[inner_start..inner_start + m2] {
        acc2 += s;
    }
    out[warmup_end] = acc2 / m2_f;

    for i in (warmup_end + 1)..n {
        let leaving = i - m2;
        if leaving < inner_start {
            continue;
        }
        acc2 += inner[i] - inner[leaving];
        out[i] = acc2 / m2_f;
    }
}
529
530pub fn trima_scalar(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
531    let n = data.len();
532    let m1 = (period + 1) / 2;
533    let m2 = period - m1 + 1;
534
535    let sma1_in = SmaInput {
536        data: SmaData::Slice(data),
537        params: SmaParams { period: Some(m1) },
538    };
539
540    let pass1 = sma(&sma1_in).unwrap();
541
542    let sma2_in = SmaInput {
543        data: SmaData::Slice(&pass1.values),
544        params: SmaParams { period: Some(m2) },
545    };
546    let pass2 = sma(&sma2_in).unwrap();
547
548    let warmup_end = first + period - 1;
549    for i in warmup_end..n {
550        out[i] = pass2.values[i];
551    }
552}
553
/// AVX-512 entry point: periods up to 32 go to the "short" routine, longer
/// periods to the "long" one.
///
/// NOTE(review): this is a safe `fn` wrapping kernels that use unchecked
/// access; callers are expected to pass `out.len() == data.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn trima_avx512(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    match period {
        0..=32 => unsafe { trima_avx512_short(data, period, first, out) },
        _ => unsafe { trima_avx512_long(data, period, first, out) },
    }
}
/// AVX2 TRIMA: the initial first-stage window is summed with `sum_u_avx2`,
/// after which both rolling averages are updated one sample at a time.
///
/// Writes `out[first + period - 1..]`; returns without writing when `data`
/// is empty or the warm-up does not fit.
///
/// # Safety
/// Reads and writes are unchecked. The caller must ensure
/// `out.len() == data.len()` (only debug-asserted) and that the CPU supports
/// the AVX instructions used by `sum_u_avx2`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn trima_avx2(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    debug_assert_eq!(data.len(), out.len());
    let n = data.len();
    if n == 0 {
        return;
    }

    let m1 = (period + 1) / 2;
    let m2 = period - m1 + 1;
    let warm = first + period - 1;
    if warm >= n {
        return;
    }

    let inv_m1 = 1.0 / (m1 as f64);
    let inv_m2 = 1.0 / (m2 as f64);
    let src = data.as_ptr();

    let mut sum1 = sum_u_avx2(src.add(first), m1);

    let mut ring: Vec<f64> = Vec::with_capacity(m2);
    let mut sum2 = 0.0;

    let seed = sum1 * inv_m1;
    ring.push(seed);
    sum2 += seed;

    // Warm-up: the j-th extra first-stage average ends at index first + m1 - 1 + j.
    for j in 1..m2 {
        let incoming = first + m1 + j - 1;
        sum1 += *src.add(incoming) - *src.add(incoming - m1);

        let stage1 = sum1 * inv_m1;
        ring.push(stage1);
        sum2 += stage1;
    }

    *out.get_unchecked_mut(warm) = sum2 * inv_m2;

    // Steady state: sample `t` enters and sample `t - m1` leaves the first window.
    let mut head = 0usize;
    for t in (warm + 1)..n {
        sum1 += *src.add(t) - *src.add(t - m1);

        let fresh = sum1 * inv_m1;
        let slot = ring.get_unchecked_mut(head);
        sum2 += fresh - *slot;
        *slot = fresh;

        head += 1;
        if head == m2 {
            head = 0;
        }

        *out.get_unchecked_mut(t) = sum2 * inv_m2;
    }
}
/// "Short period" AVX-512 TRIMA.
///
/// Despite the name, the initial window sum uses `sum_u_avx2`; the rolling
/// updates are scalar. Writes `out[first + period - 1..]` and returns without
/// writing when `data` is empty or the warm-up does not fit.
///
/// # Safety
/// Reads and writes are unchecked. The caller must ensure
/// `out.len() == data.len()` (only debug-asserted) and that the CPU supports
/// the instructions used by `sum_u_avx2`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn trima_avx512_short(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    debug_assert_eq!(data.len(), out.len());
    let n = data.len();
    if n == 0 {
        return;
    }

    let m1 = (period + 1) / 2;
    let m2 = period - m1 + 1;
    let warm = first + period - 1;
    if warm >= n {
        return;
    }

    let inv_m1 = 1.0 / (m1 as f64);
    let inv_m2 = 1.0 / (m2 as f64);
    let src = data.as_ptr();

    let mut sum1 = sum_u_avx2(src.add(first), m1);

    let mut ring: Vec<f64> = Vec::with_capacity(m2);
    let mut sum2 = 0.0;

    // Indices of the samples entering / leaving the first-stage window next.
    let mut t = first + m1 - 1;
    let mut lead = first + m1;
    let mut lag = first;

    let seed = sum1 * inv_m1;
    ring.push(seed);
    sum2 += seed;

    for _ in 1..m2 {
        t += 1;
        sum1 += *src.add(lead) - *src.add(lag);
        lead += 1;
        lag += 1;

        let stage1 = sum1 * inv_m1;
        ring.push(stage1);
        sum2 += stage1;
    }

    *out.get_unchecked_mut(warm) = sum2 * inv_m2;

    let mut head = 0usize;
    for idx in (t + 1)..n {
        sum1 += *src.add(lead) - *src.add(lag);
        lead += 1;
        lag += 1;

        let fresh = sum1 * inv_m1;
        let evicted = *ring.get_unchecked(head);
        sum2 += fresh - evicted;
        *ring.get_unchecked_mut(head) = fresh;

        head += 1;
        if head == m2 {
            head = 0;
        }

        *out.get_unchecked_mut(idx) = sum2 * inv_m2;
    }
}
/// "Long period" AVX-512 TRIMA; currently forwards to
/// [`trima_avx512_short`] unchanged.
///
/// # Safety
/// Same contract as [`trima_avx512_short`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn trima_avx512_long(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    trima_avx512_short(data, period, first, out);
}
702
/// Horizontal sum of the four `f64` lanes of `v`.
///
/// # Safety
/// Requires a CPU that supports the AVX intrinsics used here.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn hsum256d(v: __m256d) -> f64 {
    let upper = _mm256_extractf128_pd(v, 1);
    let lower = _mm256_castpd256_pd128(v);
    // Fold 4 lanes to 2, then 2 lanes to 1.
    let pair = _mm_add_pd(lower, upper);
    let high_lane = _mm_unpackhi_pd(pair, pair);
    let total = _mm_add_sd(pair, high_lane);
    _mm_cvtsd_f64(total)
}
712
/// Sums `len` doubles starting at `ptr` using unaligned 256-bit loads, with a
/// scalar tail for the last `len % 4` elements.
///
/// # Safety
/// `ptr` must be valid for reading `len` consecutive `f64` values, and the
/// CPU must support the AVX intrinsics used.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn sum_u_avx2(ptr: *const f64, len: usize) -> f64 {
    let full_vectors = len / 4;
    let mut acc = _mm256_setzero_pd();
    for v in 0..full_vectors {
        acc = _mm256_add_pd(acc, _mm256_loadu_pd(ptr.add(v * 4)));
    }

    let tail = ptr.add(full_vectors * 4);
    let mut total = hsum256d(acc);
    for i in 0..(len & 3) {
        total += *tail.add(i);
    }
    total
}
729
/// Horizontal sum of the eight `f64` lanes of `v`.
///
/// # Safety
/// Requires a CPU that supports the AVX-512 intrinsics used here.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn hsum512d(v: __m512d) -> f64 {
    // Fold 8 lanes to 4, 4 to 2, then 2 to 1.
    let low256 = _mm512_castpd512_pd256(v);
    let high256 = _mm512_extractf64x4_pd(v, 1);
    let quad = _mm256_add_pd(low256, high256);

    let pair = _mm_add_pd(_mm256_castpd256_pd128(quad), _mm256_extractf128_pd(quad, 1));
    let high_lane = _mm_unpackhi_pd(pair, pair);
    _mm_cvtsd_f64(_mm_add_sd(pair, high_lane))
}
738
/// Sums `len` doubles starting at `ptr` using unaligned 512-bit loads, with a
/// scalar tail for the last `len % 8` elements.
///
/// # Safety
/// `ptr` must be valid for reading `len` consecutive `f64` values, and the
/// CPU must support the AVX-512 intrinsics used.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn sum_u_avx512(ptr: *const f64, len: usize) -> f64 {
    let full_vectors = len / 8;
    let mut acc = _mm512_setzero_pd();
    for v in 0..full_vectors {
        acc = _mm512_add_pd(acc, _mm512_loadu_pd(ptr.add(v * 8)));
    }

    let tail = ptr.add(full_vectors * 8);
    let mut total = hsum512d(acc);
    for i in 0..(len & 7) {
        total += *tail.add(i);
    }
    total
}
755
756#[inline(always)]
757pub fn trima_row_scalar(
758    data: &[f64],
759    first: usize,
760    period: usize,
761    _stride: usize,
762    _w_ptr: *const f64,
763    _inv_n: f64,
764    out: &mut [f64],
765) {
766    let m1 = (period + 1) / 2;
767    let m2 = period - m1 + 1;
768    unsafe { trima_scalar_optimized(data, period, m1, m2, first, out) }
769}
770
/// Batch-row adapter for the AVX2 kernel; `_stride`, `_w_ptr` and `_inv_n`
/// are ignored.
///
/// # Safety
/// Same contract as [`trima_avx2`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn trima_row_avx2(
    data: &[f64],
    first: usize,
    period: usize,
    _stride: usize,
    _w_ptr: *const f64,
    _inv_n: f64,
    out: &mut [f64],
) {
    trima_avx2(data, period, first, out);
}
784
/// Batch-row adapter for the AVX-512 dispatcher; `_stride`, `_w_ptr` and
/// `_inv_n` are ignored.
///
/// # Safety
/// `out.len()` must equal `data.len()`; the underlying kernels use unchecked
/// access.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn trima_row_avx512(
    data: &[f64],
    first: usize,
    period: usize,
    _stride: usize,
    _w_ptr: *const f64,
    _inv_n: f64,
    out: &mut [f64],
) {
    trima_avx512(data, period, first, out);
}
/// Batch-row adapter for the short-period AVX-512 kernel; `_stride`,
/// `_w_ptr` and `_inv_n` are ignored.
///
/// # Safety
/// Same contract as [`trima_avx512_short`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn trima_row_avx512_short(
    data: &[f64],
    first: usize,
    period: usize,
    _stride: usize,
    _w_ptr: *const f64,
    _inv_n: f64,
    out: &mut [f64],
) {
    trima_avx512_short(data, period, first, out);
}
/// Batch-row adapter for the long-period AVX-512 kernel; `_stride`, `_w_ptr`
/// and `_inv_n` are ignored.
///
/// # Safety
/// Same contract as [`trima_avx512_long`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn trima_row_avx512_long(
    data: &[f64],
    first: usize,
    period: usize,
    _stride: usize,
    _w_ptr: *const f64,
    _inv_n: f64,
    out: &mut [f64],
) {
    trima_avx512_long(data, period, first, out);
}
824
825#[derive(Debug, Clone)]
826pub struct TrimaStream {
827    period: usize,
828    m1: usize,
829    m2: usize,
830
831    inv_m1: f64,
832    inv_m2: f64,
833
834    buf1: Box<[f64]>,
835    sum1: f64,
836    head1: usize,
837    filled1: bool,
838
839    buf2: Box<[f64]>,
840    sum2: f64,
841    head2: usize,
842    filled2: bool,
843}
844
845impl TrimaStream {
846    #[inline]
847    pub fn try_new(params: TrimaParams) -> Result<Self, TrimaError> {
848        let period = params.period.unwrap_or(30);
849        if period <= 3 {
850            return Err(TrimaError::PeriodTooSmall { period });
851        }
852
853        let m1 = (period + 1) / 2;
854        let m2 = period - m1 + 1;
855
856        Ok(Self {
857            period,
858            m1,
859            m2,
860            inv_m1: 1.0 / (m1 as f64),
861            inv_m2: 1.0 / (m2 as f64),
862
863            buf1: vec![f64::NAN; m1].into_boxed_slice(),
864            sum1: 0.0,
865            head1: 0,
866            filled1: false,
867
868            buf2: vec![f64::NAN; m2].into_boxed_slice(),
869            sum2: 0.0,
870            head2: 0,
871            filled2: false,
872        })
873    }
874
875    #[inline(always)]
876    pub fn update(&mut self, x: f64) -> Option<f64> {
877        let old1 = self.buf1[self.head1];
878        self.buf1[self.head1] = x;
879        self.head1 += 1;
880        if self.head1 == self.m1 {
881            self.head1 = 0;
882            self.filled1 = true;
883        }
884
885        if !old1.is_nan() {
886            self.sum1 -= old1;
887        }
888        if !x.is_nan() {
889            self.sum1 += x;
890        }
891
892        if !self.filled1 {
893            return None;
894        }
895
896        let s1 = self.sum1 * self.inv_m1;
897
898        let old2 = self.buf2[self.head2];
899        self.buf2[self.head2] = s1;
900        self.head2 += 1;
901        if self.head2 == self.m2 {
902            self.head2 = 0;
903            self.filled2 = true;
904        }
905
906        if !old2.is_nan() {
907            self.sum2 -= old2;
908        }
909
910        self.sum2 += s1;
911
912        if self.filled2 {
913            Some(self.sum2 * self.inv_m2)
914        } else {
915            None
916        }
917    }
918}
919
/// Period sweep for batch runs.
#[derive(Debug, Clone)]
pub struct TrimaBatchRange {
    /// `(start, end, step)`; a zero step or `start == end` yields the single
    /// period `start` when the range is expanded.
    pub period: (usize, usize, usize),
}
924
925impl Default for TrimaBatchRange {
926    fn default() -> Self {
927        Self {
928            period: (14, 263, 1),
929        }
930    }
931}
932
933#[derive(Clone, Debug, Default)]
934pub struct TrimaBatchBuilder {
935    range: TrimaBatchRange,
936    kernel: Kernel,
937}
938
939impl TrimaBatchBuilder {
940    pub fn new() -> Self {
941        Self::default()
942    }
943    pub fn kernel(mut self, k: Kernel) -> Self {
944        self.kernel = k;
945        self
946    }
947
948    #[inline]
949    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
950        self.range.period = (start, end, step);
951        self
952    }
953    #[inline]
954    pub fn period_static(mut self, p: usize) -> Self {
955        self.range.period = (p, p, 0);
956        self
957    }
958
959    pub fn apply_slice(self, data: &[f64]) -> Result<TrimaBatchOutput, TrimaError> {
960        trima_batch_with_kernel(data, &self.range, self.kernel)
961    }
962
963    pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<TrimaBatchOutput, TrimaError> {
964        TrimaBatchBuilder::new().kernel(k).apply_slice(data)
965    }
966
967    pub fn apply_candles(self, c: &Candles, src: &str) -> Result<TrimaBatchOutput, TrimaError> {
968        let slice = source_type(c, src);
969        self.apply_slice(slice)
970    }
971
972    pub fn with_default_candles(c: &Candles) -> Result<TrimaBatchOutput, TrimaError> {
973        TrimaBatchBuilder::new()
974            .kernel(Kernel::Auto)
975            .apply_candles(c, "close")
976    }
977}
978
979pub fn trima_batch_with_kernel(
980    data: &[f64],
981    sweep: &TrimaBatchRange,
982    k: Kernel,
983) -> Result<TrimaBatchOutput, TrimaError> {
984    let kernel = match k {
985        Kernel::Auto => detect_best_batch_kernel(),
986        other if other.is_batch() => other,
987        other => return Err(TrimaError::InvalidKernelForBatch(other)),
988    };
989
990    let simd = match kernel {
991        Kernel::Avx512Batch => Kernel::Avx512,
992        Kernel::Avx2Batch => Kernel::Avx2,
993        Kernel::ScalarBatch => Kernel::Scalar,
994        Kernel::Avx512 | Kernel::Avx2 | Kernel::Scalar => kernel,
995        _ => Kernel::Scalar,
996    };
997
998    trima_batch_par_slice(data, sweep, simd)
999}
1000
1001#[derive(Clone, Debug)]
1002pub struct TrimaBatchOutput {
1003    pub values: Vec<f64>,
1004    pub combos: Vec<TrimaParams>,
1005    pub rows: usize,
1006    pub cols: usize,
1007}
1008impl TrimaBatchOutput {
1009    pub fn row_for_params(&self, p: &TrimaParams) -> Option<usize> {
1010        self.combos
1011            .iter()
1012            .position(|c| c.period.unwrap_or(14) == p.period.unwrap_or(14))
1013    }
1014
1015    pub fn values_for(&self, p: &TrimaParams) -> Option<&[f64]> {
1016        self.row_for_params(p).map(|row| {
1017            let start = row * self.cols;
1018            &self.values[start..start + self.cols]
1019        })
1020    }
1021}
1022
1023#[inline(always)]
1024fn expand_grid(r: &TrimaBatchRange) -> Vec<TrimaParams> {
1025    fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, TrimaError> {
1026        if step == 0 || start == end {
1027            return Ok(vec![start]);
1028        }
1029        let (lo, hi) = if start <= end {
1030            (start, end)
1031        } else {
1032            (end, start)
1033        };
1034        let mut v = Vec::new();
1035        let mut cur = lo;
1036        while cur <= hi {
1037            v.push(cur);
1038            cur = cur
1039                .checked_add(step)
1040                .ok_or(TrimaError::InvalidRange { start, end, step })?;
1041            if cur == *v.last().unwrap() {
1042                break;
1043            }
1044        }
1045        if v.is_empty() {
1046            return Err(TrimaError::InvalidRange { start, end, step });
1047        }
1048        Ok(v)
1049    }
1050
1051    let periods = match axis_usize(r.period) {
1052        Ok(v) => v,
1053        Err(_) => return Vec::new(),
1054    };
1055    let mut out = Vec::with_capacity(periods.len());
1056    for &p in &periods {
1057        out.push(TrimaParams { period: Some(p) });
1058    }
1059    out
1060}
1061
1062#[inline(always)]
1063pub fn trima_batch_slice(
1064    data: &[f64],
1065    sweep: &TrimaBatchRange,
1066    kern: Kernel,
1067) -> Result<TrimaBatchOutput, TrimaError> {
1068    trima_batch_inner(data, sweep, kern, false)
1069}
1070
1071#[inline(always)]
1072pub fn trima_batch_par_slice(
1073    data: &[f64],
1074    sweep: &TrimaBatchRange,
1075    kern: Kernel,
1076) -> Result<TrimaBatchOutput, TrimaError> {
1077    trima_batch_inner(data, sweep, kern, true)
1078}
1079
1080#[inline(always)]
1081fn trima_batch_inner(
1082    data: &[f64],
1083    sweep: &TrimaBatchRange,
1084    kern: Kernel,
1085    parallel: bool,
1086) -> Result<TrimaBatchOutput, TrimaError> {
1087    let combos = expand_grid(sweep);
1088    if combos.is_empty() {
1089        return Err(TrimaError::InvalidRange {
1090            start: sweep.period.0,
1091            end: sweep.period.1,
1092            step: sweep.period.2,
1093        });
1094    }
1095
1096    let first = data
1097        .iter()
1098        .position(|x| !x.is_nan())
1099        .ok_or(TrimaError::AllValuesNaN)?;
1100    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1101    if data.len() - first < max_p {
1102        return Err(TrimaError::NotEnoughValidData {
1103            needed: max_p,
1104            valid: data.len() - first,
1105        });
1106    }
1107
1108    let rows = combos.len();
1109    let cols = data.len();
1110    let warm: Vec<usize> = combos
1111        .iter()
1112        .map(|c| first + c.period.unwrap() - 1)
1113        .collect();
1114
1115    let _total = rows.checked_mul(cols).ok_or(TrimaError::InvalidRange {
1116        start: sweep.period.0,
1117        end: sweep.period.1,
1118        step: sweep.period.2,
1119    })?;
1120    let mut raw = make_uninit_matrix(rows, cols);
1121    unsafe { init_matrix_prefixes(&mut raw, cols, &warm) };
1122
1123    let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
1124        let period = combos[row].period.unwrap();
1125
1126        let out_row =
1127            core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
1128
1129        match kern {
1130            Kernel::Scalar => {
1131                trima_row_scalar(data, first, period, 0, core::ptr::null(), 1.0, out_row)
1132            }
1133            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1134            Kernel::Avx2 => trima_row_avx2(data, first, period, 0, core::ptr::null(), 1.0, out_row),
1135            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1136            Kernel::Avx512 => {
1137                trima_row_avx512(data, first, period, 0, core::ptr::null(), 1.0, out_row)
1138            }
1139            _ => trima_row_scalar(data, first, period, 0, core::ptr::null(), 1.0, out_row),
1140        }
1141    };
1142
1143    if parallel {
1144        #[cfg(not(target_arch = "wasm32"))]
1145        {
1146            raw.par_chunks_mut(cols)
1147                .enumerate()
1148                .for_each(|(row, slice)| do_row(row, slice));
1149        }
1150
1151        #[cfg(target_arch = "wasm32")]
1152        {
1153            for (row, slice) in raw.chunks_mut(cols).enumerate() {
1154                do_row(row, slice);
1155            }
1156        }
1157    } else {
1158        for (row, slice) in raw.chunks_mut(cols).enumerate() {
1159            do_row(row, slice);
1160        }
1161    }
1162
1163    let mut buf_guard = core::mem::ManuallyDrop::new(raw);
1164    let values = unsafe {
1165        Vec::from_raw_parts(
1166            buf_guard.as_mut_ptr() as *mut f64,
1167            buf_guard.len(),
1168            buf_guard.capacity(),
1169        )
1170    };
1171
1172    Ok(TrimaBatchOutput {
1173        values,
1174        combos,
1175        rows,
1176        cols,
1177    })
1178}
1179
1180#[inline(always)]
1181pub fn trima_batch_inner_into(
1182    data: &[f64],
1183    sweep: &TrimaBatchRange,
1184    kern: Kernel,
1185    parallel: bool,
1186    out: &mut [f64],
1187) -> Result<Vec<TrimaParams>, TrimaError> {
1188    let combos = expand_grid(sweep);
1189    if combos.is_empty() {
1190        return Err(TrimaError::InvalidRange {
1191            start: sweep.period.0,
1192            end: sweep.period.1,
1193            step: sweep.period.2,
1194        });
1195    }
1196
1197    let first = data
1198        .iter()
1199        .position(|x| !x.is_nan())
1200        .ok_or(TrimaError::AllValuesNaN)?;
1201    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1202    if data.len() - first < max_p {
1203        return Err(TrimaError::NotEnoughValidData {
1204            needed: max_p,
1205            valid: data.len() - first,
1206        });
1207    }
1208
1209    let rows = combos.len();
1210    let cols = data.len();
1211    let expected = rows.checked_mul(cols).ok_or(TrimaError::InvalidRange {
1212        start: sweep.period.0,
1213        end: sweep.period.1,
1214        step: sweep.period.2,
1215    })?;
1216    if out.len() != expected {
1217        return Err(TrimaError::OutputLengthMismatch {
1218            expected,
1219            got: out.len(),
1220        });
1221    }
1222
1223    let warm: Vec<usize> = combos
1224        .iter()
1225        .map(|c| first + c.period.unwrap() - 1)
1226        .collect();
1227    let out_mu = unsafe {
1228        core::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
1229    };
1230    init_matrix_prefixes(out_mu, cols, &warm);
1231
1232    let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
1233        let period = combos[row].period.unwrap();
1234        let out_row =
1235            core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
1236        match kern {
1237            Kernel::Scalar => {
1238                trima_row_scalar(data, first, period, 0, core::ptr::null(), 1.0, out_row)
1239            }
1240            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1241            Kernel::Avx2 => trima_row_avx2(data, first, period, 0, core::ptr::null(), 1.0, out_row),
1242            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1243            Kernel::Avx512 => {
1244                trima_row_avx512(data, first, period, 0, core::ptr::null(), 1.0, out_row)
1245            }
1246            _ => trima_row_scalar(data, first, period, 0, core::ptr::null(), 1.0, out_row),
1247        }
1248    };
1249
1250    if parallel {
1251        #[cfg(not(target_arch = "wasm32"))]
1252        {
1253            out_mu
1254                .par_chunks_mut(cols)
1255                .enumerate()
1256                .for_each(|(row, slice)| do_row(row, slice));
1257        }
1258        #[cfg(target_arch = "wasm32")]
1259        {
1260            for (row, slice) in out_mu.chunks_mut(cols).enumerate() {
1261                do_row(row, slice);
1262            }
1263        }
1264    } else {
1265        for (row, slice) in out_mu.chunks_mut(cols).enumerate() {
1266            do_row(row, slice);
1267        }
1268    }
1269
1270    Ok(combos)
1271}
1272
1273#[cfg(feature = "python")]
1274use pyo3::exceptions::PyValueError;
1275#[cfg(feature = "python")]
1276use pyo3::prelude::*;
1277
// Python entry point `trima(data, period, kernel=None)`.
//
// Borrows the NumPy input as a contiguous slice, runs the computation with the
// GIL released, and returns the result as a new 1-D NumPy array. Any
// `TrimaError` is surfaced as a Python `ValueError` carrying its message.
#[cfg(feature = "python")]
#[pyfunction(name = "trima")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn trima_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let prices = data.as_slice()?;
    let kern = validate_kernel(kernel, false)?;
    let input = TrimaInput::from_slice(
        prices,
        TrimaParams {
            period: Some(period),
        },
    );

    // Only plain Rust data crosses the `allow_threads` boundary.
    let computed = py.allow_threads(|| trima_with_kernel(&input, kern).map(|o| o.values));
    match computed {
        Ok(values) => Ok(values.into_pyarray(py)),
        Err(err) => Err(PyValueError::new_err(err.to_string())),
    }
}
1304
// Python-visible wrapper (`TrimaStream`) around the Rust `TrimaStream`.
#[cfg(feature = "python")]
#[pyclass(name = "TrimaStream")]
pub struct TrimaStreamPy {
    stream: TrimaStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl TrimaStreamPy {
    // Builds the underlying stream; construction errors become `ValueError`.
    #[new]
    fn new(period: Option<usize>) -> PyResult<Self> {
        match TrimaStream::try_new(TrimaParams { period }) {
            Ok(stream) => Ok(Self { stream }),
            Err(err) => Err(PyValueError::new_err(err.to_string())),
        }
    }

    // Feeds one value to the stream and returns whatever it yields.
    fn update(&mut self, value: f64) -> Option<f64> {
        let stream = &mut self.stream;
        stream.update(value)
    }
}
1326
// Python entry point `trima_batch(data, period_range, kernel=None)`.
//
// Allocates the flat output array on the Python side, fills it with the GIL
// released, and returns a dict with:
//   "values"  - array reshaped to (rows, cols)
//   "periods" - one u64 period per row
#[cfg(feature = "python")]
#[pyfunction(name = "trima_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn trima_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let prices = data.as_slice()?;
    let sweep = TrimaBatchRange {
        period: period_range,
    };

    // The grid is expanded here only to size the output buffer.
    let rows = expand_grid(&sweep).len();
    let cols = prices.len();
    let total = match rows.checked_mul(cols) {
        Some(n) => n,
        None => {
            return Err(PyValueError::new_err(
                "size overflow: rows*cols exceeds usize",
            ))
        }
    };

    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let requested = match validate_kernel(kernel, true)? {
        Kernel::Auto => detect_best_batch_kernel(),
        other => other,
    };
    // Map batch kernel variants onto the per-row kernel they imply; anything
    // unrecognized falls back to scalar.
    let simd = match requested {
        Kernel::Avx512Batch | Kernel::Avx512 => Kernel::Avx512,
        Kernel::Avx2Batch | Kernel::Avx2 => Kernel::Avx2,
        _ => Kernel::Scalar,
    };

    let filled =
        py.allow_threads(|| trima_batch_inner_into(prices, &sweep, simd, true, slice_out));
    let combos = match filled {
        Ok(c) => c,
        Err(err) => return Err(PyValueError::new_err(err.to_string())),
    };

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    let periods: Vec<u64> = combos
        .iter()
        .map(|p| p.period.unwrap_or(30) as u64)
        .collect();
    dict.set_item("periods", periods.into_pyarray(py))?;
    Ok(dict)
}
1385
// Python entry point `trima_cuda_batch_dev(data, period_range, device_id=0)`.
//
// Narrows the f64 input to f32 on the host, then runs the CUDA batch sweep
// with the GIL released and returns the device-resident result wrapper.
// Fails with `ValueError` when CUDA is unavailable or the CUDA call errors.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "trima_cuda_batch_dev")]
#[pyo3(signature = (data, period_range, device_id=0))]
pub fn trima_cuda_batch_dev_py(
    py: Python<'_>,
    data: numpy::PyReadonlyArray1<'_, f64>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<DeviceArrayF32TrimaPy> {
    use numpy::PyArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let host_f64 = data.as_slice()?;
    let sweep = TrimaBatchRange {
        period: period_range,
    };
    let data_f32: Vec<f32> = host_f64.iter().map(|&v| v as f32).collect();

    let inner = py.allow_threads(|| match CudaTrima::new(device_id) {
        Ok(cuda) => cuda
            .trima_batch_dev(&data_f32, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string())),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    })?;

    Ok(DeviceArrayF32TrimaPy { inner: Some(inner) })
}
1416
// Python entry point
// `trima_cuda_many_series_one_param_dev(data_tm_f32, period, device_id=0)`.
//
// Takes a 2-D f32 array and applies one TRIMA period to it on the GPU with the
// GIL released. The array's second dimension is passed first and its first
// dimension second to the CUDA wrapper.
// NOTE(review): presumably dim 0 is time and dim 1 is series ("time major") --
// confirm against the CUDA wrapper's signature.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "trima_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, device_id=0))]
pub fn trima_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32TrimaPy> {
    use numpy::PyUntypedArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let flat_in = data_tm_f32.as_slice()?;
    let dims = data_tm_f32.shape();
    let (rows, cols) = (dims[0], dims[1]);
    let params = TrimaParams {
        period: Some(period),
    };

    let inner = py.allow_threads(|| match CudaTrima::new(device_id) {
        Ok(cuda) => cuda
            .trima_multi_series_one_param_time_major_dev(flat_in, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string())),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    })?;

    Ok(DeviceArrayF32TrimaPy { inner: Some(inner) })
}
1447
// Python handle to a device-resident f32 matrix produced by the CUDA TRIMA
// kernels. `inner` becomes `None` once the buffer has been handed out through
// `__dlpack__`; every accessor then reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(
    module = "ta_indicators.cuda",
    name = "DeviceArrayF32Trima",
    unsendable
)]
pub struct DeviceArrayF32TrimaPy {
    pub(crate) inner: Option<DeviceArrayF32Trima>,
}

#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl DeviceArrayF32TrimaPy {
    // Builds the `__cuda_array_interface__` dict (version 3): shape, "<f4"
    // typestr, row-major byte strides, and the device pointer (0 for an empty
    // matrix) flagged as not read-only.
    #[getter]
    fn __cuda_array_interface__<'py>(
        &self,
        py: Python<'py>,
    ) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
        use pyo3::types::PyDict;

        let inner = match self.inner.as_ref() {
            Some(arr) => arr,
            None => {
                return Err(PyValueError::new_err(
                    "buffer already exported via __dlpack__",
                ))
            }
        };

        let elem_bytes = std::mem::size_of::<f32>();
        let iface = PyDict::new(py);
        iface.set_item("shape", (inner.rows, inner.cols))?;
        iface.set_item("typestr", "<f4")?;
        iface.set_item("strides", (inner.cols * elem_bytes, elem_bytes))?;

        let is_empty = inner.rows == 0 || inner.cols == 0;
        let ptr_val: usize = if is_empty {
            0
        } else {
            inner.device_ptr() as usize
        };
        iface.set_item("data", (ptr_val, false))?;

        iface.set_item("version", 3)?;
        Ok(iface)
    }

    // Returns `(device_type, device_id)`; the device type is the constant 2
    // (presumably DLPack's CUDA device code -- confirm against the DLPack spec).
    fn __dlpack_device__(&self) -> PyResult<(i32, i32)> {
        match self.inner.as_ref() {
            Some(arr) => Ok((2, arr.device_id as i32)),
            None => Err(PyValueError::new_err(
                "buffer already exported via __dlpack__",
            )),
        }
    }

    // Exports the buffer as a DLPack capsule, consuming it: afterwards `inner`
    // is `None`. A `dl_device` that parses as `(i32, i32)` and names a different
    // device is rejected (no cross-device copy is implemented). `stream` is
    // accepted but ignored.
    #[pyo3(signature=(stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<PyObject>,
        max_version: Option<PyObject>,
        dl_device: Option<PyObject>,
        copy: Option<PyObject>,
    ) -> PyResult<PyObject> {
        let (kdl, alloc_dev) = self.__dlpack_device__()?;

        // A `dl_device` that does not parse as a pair is silently ignored.
        let requested = dl_device
            .as_ref()
            .and_then(|obj| obj.extract::<(i32, i32)>(py).ok());
        if let Some((dev_ty, dev_id)) = requested {
            if dev_ty != kdl || dev_id != alloc_dev {
                let wants_copy = copy
                    .as_ref()
                    .and_then(|c| c.extract::<bool>(py).ok())
                    .unwrap_or(false);
                let reason = if wants_copy {
                    "device copy not implemented for __dlpack__"
                } else {
                    "dl_device mismatch for __dlpack__"
                };
                return Err(PyValueError::new_err(reason));
            }
        }

        let _ = stream;

        let inner = match self.inner.take() {
            Some(arr) => arr,
            None => {
                return Err(PyValueError::new_err(
                    "buffer already exported via __dlpack__",
                ))
            }
        };

        // `_ctx` is bound (not discarded) so it lives until this function returns.
        let DeviceArrayF32Trima {
            buf,
            rows,
            cols,
            ctx: _ctx,
            device_id,
        } = inner;

        if device_id as i32 != alloc_dev {
            return Err(PyValueError::new_err("device id mismatch for __dlpack__"));
        }

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
1554
1555#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1556use serde::{Deserialize, Serialize};
1557#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1558use serde_wasm_bindgen;
1559#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1560use wasm_bindgen::prelude::*;
1561
// JS entry point: computes TRIMA for `data` with the given `period` and
// returns a freshly allocated array of the same length. Errors are reported
// as a JS string holding the `TrimaError` message.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn trima_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = TrimaInput::from_slice(
        data,
        TrimaParams {
            period: Some(period),
        },
    );

    let mut output = vec![0.0; data.len()];
    match trima_into_slice(&mut output, &input, Kernel::Auto) {
        Ok(_) => Ok(output),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
1578
// Config object accepted by `trima_batch`: `period_range` is
// `(start, end, step)`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct TrimaBatchConfig {
    pub period_range: (usize, usize, usize),
}

// Result object returned by `trima_batch`: `values` is the flat row-major
// matrix, `combos` the parameters for each row.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct TrimaBatchJsOutput {
    pub values: Vec<f64>,
    pub combos: Vec<TrimaParams>,
    pub rows: usize,
    pub cols: usize,
}

// JS entry point `trima_batch(data, config)`: deserializes the config, runs
// the sequential batch sweep with the detected best kernel, and serializes the
// result back to a JS value. Every failure is returned as a JS string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = trima_batch)]
pub fn trima_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let parsed: TrimaBatchConfig = match serde_wasm_bindgen::from_value(config) {
        Ok(cfg) => cfg,
        Err(e) => return Err(JsValue::from_str(&format!("Invalid config: {}", e))),
    };
    let sweep = TrimaBatchRange {
        period: parsed.period_range,
    };

    let TrimaBatchOutput {
        values,
        combos,
        rows,
        cols,
    } = trima_batch_inner(data, &sweep, detect_best_kernel(), false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let js_output = TrimaBatchJsOutput {
        values,
        combos,
        rows,
        cols,
    };
    match serde_wasm_bindgen::to_value(&js_output) {
        Ok(v) => Ok(v),
        Err(e) => Err(JsValue::from_str(&format!("Serialization error: {}", e))),
    }
}
1615
// JS entry point: runs the sequential batch sweep for the period range
// `(period_start, period_end, period_step)` and returns only the flat
// row-major values. Errors are reported as a JS string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn trima_batch_js(
    data: &[f64],
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let sweep = TrimaBatchRange {
        period: (period_start, period_end, period_step),
    };

    match trima_batch_inner(data, &sweep, Kernel::Auto, false) {
        Ok(output) => Ok(output.values),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
1633
// JS entry point: returns the period of every combination the given range
// expands to, as f64, in row order. An invalid range yields an empty array
// rather than an error (this function never returns `Err`).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn trima_batch_metadata_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let sweep = TrimaBatchRange {
        period: (period_start, period_end, period_step),
    };

    let combos = expand_grid(&sweep);
    let mut metadata: Vec<f64> = Vec::with_capacity(combos.len());
    for combo in &combos {
        metadata.push(combo.period.unwrap_or(30) as f64);
    }

    Ok(metadata)
}
1654
// Allocates room for `len` f64 values and leaks it to the caller, returning
// the raw pointer. The memory is uninitialized; release it with `trima_free`
// using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn trima_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` keeps the allocation alive after this function returns.
    let mut buf = core::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buf.as_mut_ptr()
}
1663
// Releases a buffer previously obtained from `trima_alloc`.
//
// `ptr` and `len` must be exactly the pointer returned by `trima_alloc` and
// the length passed to it, and the buffer must not be freed twice. A null
// `ptr` is ignored, since rebuilding a `Vec` from a null pointer is undefined
// behaviour.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn trima_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: per the contract above, `ptr` came from a `Vec<f64>` with
    // capacity `len`; reconstructing and dropping it frees that allocation.
    unsafe {
        drop(Vec::from_raw_parts(ptr, len, len));
    }
}
1671
// Raw-pointer JS entry point: computes TRIMA over `len` values at `in_ptr`
// and writes `len` results to `out_ptr`.
//
// Both pointers must be non-null and valid for `len` f64 values. `in_ptr` and
// `out_ptr` may be the same buffer (the result is then staged in a temporary);
// partially overlapping buffers are not handled. Returns an error string for
// null pointers, a period of 0 or greater than `len`, or a `TrimaError`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn trima_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to trima_into"));
    }
    if period == 0 || period > len {
        return Err(JsValue::from_str("Invalid period"));
    }

    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let input = TrimaInput::from_slice(
            data,
            TrimaParams {
                period: Some(period),
            },
        );

        if in_ptr == out_ptr {
            // In-place request: compute into scratch space, then copy back.
            let mut scratch = vec![0.0; len];
            trima_into_slice(&mut scratch, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            std::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&scratch);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            trima_into_slice(out, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(())
}
1707
// Raw-pointer JS entry point for the batch sweep: reads `len` values from
// `in_ptr` and writes `rows * len` results (row-major, one row per period
// combination) to `out_ptr`. Returns the number of rows.
//
// Both pointers must be non-null; `in_ptr` must be valid for `len` f64 values
// and `out_ptr` for `rows * len`. Returns an error string for null pointers,
// when `rows * len` does not fit in `usize`, or for any `TrimaError`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn trima_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to trima_batch_into"));
    }

    let sweep = TrimaBatchRange {
        period: (period_start, period_end, period_step),
    };

    // The grid is expanded here only to size the output view.
    let rows = expand_grid(&sweep).len();
    let cols = len;

    // Checked multiply: a wrapped product would build an output slice whose
    // length has nothing to do with the caller's buffer (and panics in debug).
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("size overflow: rows*cols exceeds usize"))?;

    // SAFETY: pointers are non-null (checked above); the caller guarantees
    // `in_ptr` covers `len` values and `out_ptr` covers `rows * len` values.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let out = std::slice::from_raw_parts_mut(out_ptr, total);

        trima_batch_inner_into(data, &sweep, Kernel::Auto, false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(rows)
}
1741
// Deprecated JS-facing context holding a validated TRIMA period together with
// the two derived window lengths `m1` and `m2` and the kernel selection.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
#[deprecated(
    since = "1.0.0",
    note = "For streaming patterns, use the fast/unsafe API with persistent buffers"
)]
pub struct TrimaContext {
    period: usize,
    m1: usize,
    m2: usize,
    first: usize,
    kernel: Kernel,
}

#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
#[allow(deprecated)]
impl TrimaContext {
    // Creates a context for `period`. Periods of 0 and periods of 3 or less
    // are rejected with an error string.
    #[wasm_bindgen(constructor)]
    #[deprecated(
        since = "1.0.0",
        note = "For streaming patterns, use the fast/unsafe API with persistent buffers"
    )]
    pub fn new(period: usize) -> Result<TrimaContext, JsValue> {
        if period == 0 {
            return Err(JsValue::from_str("Invalid period: 0"));
        }
        if period <= 3 {
            return Err(JsValue::from_str(&format!("Period too small: {}", period)));
        }

        // m1 = ceil(period / 2); m1 + m2 = period + 1.
        let m1 = (period + 1) / 2;
        let m2 = period - m1 + 1;

        Ok(TrimaContext {
            period,
            m1,
            m2,
            first: 0,
            kernel: Kernel::Auto,
        })
    }

    // Computes TRIMA over `len` values at `in_ptr` and writes `len` results to
    // `out_ptr`; the first `first_valid + period - 1` outputs are set to NaN.
    //
    // Both pointers must be valid for `len` f64 values. They may be identical
    // (the result is then staged in a temporary); partial overlap is not
    // handled. Returns an error string when `len < period`, when a pointer is
    // null, or when fewer than `period` values remain after the leading NaNs.
    pub fn update_into(
        &self,
        in_ptr: *const f64,
        out_ptr: *mut f64,
        len: usize,
    ) -> Result<(), JsValue> {
        if len < self.period {
            return Err(JsValue::from_str("Data length less than period"));
        }
        if in_ptr.is_null() || out_ptr.is_null() {
            return Err(JsValue::from_str("null pointer passed to update_into"));
        }

        // SAFETY: pointers are non-null (checked above) and the caller
        // guarantees each is valid for `len` f64 values.
        unsafe {
            let data = std::slice::from_raw_parts(in_ptr, len);

            // All-NaN input keeps the historical behaviour of starting at 0.
            let first = data.iter().position(|x| !x.is_nan()).unwrap_or(0);

            // Without this check the warm-up fill below (and the kernel) would
            // run past the end of the buffers.
            if len - first < self.period {
                return Err(JsValue::from_str(
                    "Not enough valid data after leading NaN values",
                ));
            }
            let warmup = first + self.period - 1;

            if in_ptr == out_ptr {
                let mut temp = vec![0.0; len];
                trima_compute_into(
                    data,
                    self.period,
                    self.m1,
                    self.m2,
                    first,
                    self.kernel,
                    &mut temp,
                );
                temp[..warmup].fill(f64::NAN);

                // The mutable view is created only after the last read through
                // `data`, so the two never alias while both are in use.
                std::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&temp);
            } else {
                let out = std::slice::from_raw_parts_mut(out_ptr, len);
                trima_compute_into(data, self.period, self.m1, self.m2, first, self.kernel, out);
                out[..warmup].fill(f64::NAN);
            }
        }

        Ok(())
    }

    // Returns `period - 1`, the warm-up length when the input has no leading NaNs.
    pub fn get_warmup_period(&self) -> usize {
        self.period - 1
    }
}
1831
1832#[cfg(test)]
1833mod tests {
1834    use super::*;
1835    use crate::skip_if_unsupported;
1836    use crate::utilities::data_loader::read_candles_from_csv;
1837    use crate::utilities::enums::Kernel;
1838
1839    fn check_trima_partial_params(
1840        test_name: &str,
1841        kernel: Kernel,
1842    ) -> Result<(), Box<dyn std::error::Error>> {
1843        skip_if_unsupported!(kernel, test_name);
1844        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1845        let candles = read_candles_from_csv(file_path)?;
1846
1847        let default_params = TrimaParams { period: None };
1848        let input = TrimaInput::from_candles(&candles, "close", default_params);
1849        let output = trima_with_kernel(&input, kernel)?;
1850        assert_eq!(output.values.len(), candles.close.len());
1851
1852        let params_period_10 = TrimaParams { period: Some(10) };
1853        let input2 = TrimaInput::from_candles(&candles, "hl2", params_period_10);
1854        let output2 = trima_with_kernel(&input2, kernel)?;
1855        assert_eq!(output2.values.len(), candles.close.len());
1856
1857        let params_custom = TrimaParams { period: Some(14) };
1858        let input3 = TrimaInput::from_candles(&candles, "hlc3", params_custom);
1859        let output3 = trima_with_kernel(&input3, kernel)?;
1860        assert_eq!(output3.values.len(), candles.close.len());
1861
1862        Ok(())
1863    }
1864
1865    fn check_trima_accuracy(
1866        test_name: &str,
1867        kernel: Kernel,
1868    ) -> Result<(), Box<dyn std::error::Error>> {
1869        skip_if_unsupported!(kernel, test_name);
1870        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1871        let candles = read_candles_from_csv(file_path)?;
1872        let close_prices = &candles.close;
1873        let params = TrimaParams { period: Some(30) };
1874        let input = TrimaInput::from_candles(&candles, "close", params);
1875        let trima_result = trima_with_kernel(&input, kernel)?;
1876
1877        assert_eq!(
1878            trima_result.values.len(),
1879            close_prices.len(),
1880            "TRIMA output length should match input data length"
1881        );
1882        let expected_last_five_trima = [
1883            59957.916666666664,
1884            59846.770833333336,
1885            59750.620833333334,
1886            59665.2125,
1887            59581.612499999996,
1888        ];
1889        assert!(
1890            trima_result.values.len() >= 5,
1891            "Not enough TRIMA values for the test"
1892        );
1893        let start_index = trima_result.values.len() - 5;
1894        let result_last_five_trima = &trima_result.values[start_index..];
1895        for (i, &value) in result_last_five_trima.iter().enumerate() {
1896            let expected_value = expected_last_five_trima[i];
1897            assert!(
1898                (value - expected_value).abs() < 1e-6,
1899                "[{}] TRIMA value mismatch at index {}: expected {}, got {}",
1900                test_name,
1901                i,
1902                expected_value,
1903                value
1904            );
1905        }
1906        let period = input.params.period.unwrap_or(14);
1907        for i in 0..(period - 1) {
1908            assert!(
1909                trima_result.values[i].is_nan(),
1910                "[{}] Expected NaN at early index {} for TRIMA, got {}",
1911                test_name,
1912                i,
1913                trima_result.values[i]
1914            );
1915        }
1916        Ok(())
1917    }
1918
1919    fn check_trima_default_candles(
1920        test_name: &str,
1921        kernel: Kernel,
1922    ) -> Result<(), Box<dyn std::error::Error>> {
1923        skip_if_unsupported!(kernel, test_name);
1924        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1925        let candles = read_candles_from_csv(file_path)?;
1926        let input = TrimaInput::with_default_candles(&candles);
1927        match input.data {
1928            TrimaData::Candles { source, .. } => assert_eq!(source, "close"),
1929            _ => panic!("Expected TrimaData::Candles"),
1930        }
1931        let output = trima_with_kernel(&input, kernel)?;
1932        assert_eq!(output.values.len(), candles.close.len());
1933        Ok(())
1934    }
1935
1936    fn check_trima_zero_period(
1937        test_name: &str,
1938        kernel: Kernel,
1939    ) -> Result<(), Box<dyn std::error::Error>> {
1940        skip_if_unsupported!(kernel, test_name);
1941        let input_data = [10.0, 20.0, 30.0];
1942        let params = TrimaParams { period: Some(0) };
1943        let input = TrimaInput::from_slice(&input_data, params);
1944        let res = trima_with_kernel(&input, kernel);
1945        assert!(
1946            res.is_err(),
1947            "[{}] TRIMA should fail with zero period",
1948            test_name
1949        );
1950        Ok(())
1951    }
1952
1953    fn check_trima_period_too_small(
1954        test_name: &str,
1955        kernel: Kernel,
1956    ) -> Result<(), Box<dyn std::error::Error>> {
1957        skip_if_unsupported!(kernel, test_name);
1958        let input_data = [10.0, 20.0, 30.0, 40.0];
1959        let params = TrimaParams { period: Some(3) };
1960        let input = TrimaInput::from_slice(&input_data, params);
1961        let res = trima_with_kernel(&input, kernel);
1962        assert!(
1963            res.is_err(),
1964            "[{}] TRIMA should fail with period <= 3",
1965            test_name
1966        );
1967        Ok(())
1968    }
1969
1970    fn check_trima_period_exceeds_length(
1971        test_name: &str,
1972        kernel: Kernel,
1973    ) -> Result<(), Box<dyn std::error::Error>> {
1974        skip_if_unsupported!(kernel, test_name);
1975        let data_small = [10.0, 20.0, 30.0];
1976        let params = TrimaParams { period: Some(10) };
1977        let input = TrimaInput::from_slice(&data_small, params);
1978        let res = trima_with_kernel(&input, kernel);
1979        assert!(
1980            res.is_err(),
1981            "[{}] TRIMA should fail with period exceeding length",
1982            test_name
1983        );
1984        Ok(())
1985    }
1986
1987    fn check_trima_very_small_dataset(
1988        test_name: &str,
1989        kernel: Kernel,
1990    ) -> Result<(), Box<dyn std::error::Error>> {
1991        skip_if_unsupported!(kernel, test_name);
1992        let single_point = [42.0];
1993        let params = TrimaParams { period: Some(14) };
1994        let input = TrimaInput::from_slice(&single_point, params);
1995        let res = trima_with_kernel(&input, kernel);
1996        assert!(
1997            res.is_err(),
1998            "[{}] TRIMA should fail with insufficient data",
1999            test_name
2000        );
2001        Ok(())
2002    }
2003
2004    fn check_trima_reinput(
2005        test_name: &str,
2006        kernel: Kernel,
2007    ) -> Result<(), Box<dyn std::error::Error>> {
2008        skip_if_unsupported!(kernel, test_name);
2009        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2010        let candles = read_candles_from_csv(file_path)?;
2011
2012        let first_params = TrimaParams { period: Some(14) };
2013        let first_input = TrimaInput::from_candles(&candles, "close", first_params);
2014        let first_result = trima_with_kernel(&first_input, kernel)?;
2015
2016        let second_params = TrimaParams { period: Some(10) };
2017        let second_input = TrimaInput::from_slice(&first_result.values, second_params);
2018        let second_result = trima_with_kernel(&second_input, kernel)?;
2019
2020        assert_eq!(second_result.values.len(), first_result.values.len());
2021        for val in &second_result.values[240..] {
2022            assert!(val.is_finite());
2023        }
2024        Ok(())
2025    }
2026
2027    fn check_trima_nan_handling(
2028        test_name: &str,
2029        kernel: Kernel,
2030    ) -> Result<(), Box<dyn std::error::Error>> {
2031        skip_if_unsupported!(kernel, test_name);
2032        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2033        let candles = read_candles_from_csv(file_path)?;
2034
2035        let input = TrimaInput::from_candles(&candles, "close", TrimaParams { period: Some(14) });
2036        let res = trima_with_kernel(&input, kernel)?;
2037        assert_eq!(res.values.len(), candles.close.len());
2038        if res.values.len() > 240 {
2039            for (i, &val) in res.values[240..].iter().enumerate() {
2040                assert!(
2041                    !val.is_nan(),
2042                    "[{}] Found unexpected NaN at out-index {}",
2043                    test_name,
2044                    240 + i
2045                );
2046            }
2047        }
2048        Ok(())
2049    }
2050
2051    fn check_trima_streaming(
2052        test_name: &str,
2053        kernel: Kernel,
2054    ) -> Result<(), Box<dyn std::error::Error>> {
2055        skip_if_unsupported!(kernel, test_name);
2056
2057        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2058        let candles = read_candles_from_csv(file_path)?;
2059
2060        let period = 14;
2061
2062        let input = TrimaInput::from_candles(
2063            &candles,
2064            "close",
2065            TrimaParams {
2066                period: Some(period),
2067            },
2068        );
2069        let batch_output = trima_with_kernel(&input, kernel)?.values;
2070
2071        let mut stream = TrimaStream::try_new(TrimaParams {
2072            period: Some(period),
2073        })?;
2074
2075        let mut stream_values = Vec::with_capacity(candles.close.len());
2076        for &price in &candles.close {
2077            match stream.update(price) {
2078                Some(trima_val) => stream_values.push(trima_val),
2079                None => stream_values.push(f64::NAN),
2080            }
2081        }
2082
2083        assert_eq!(batch_output.len(), stream_values.len());
2084        for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
2085            if b.is_nan() && s.is_nan() {
2086                continue;
2087            }
2088            let diff = (b - s).abs();
2089            assert!(
2090                diff < 1e-8,
2091                "[{}] TRIMA streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
2092                test_name,
2093                i,
2094                b,
2095                s,
2096                diff
2097            );
2098        }
2099        Ok(())
2100    }
2101
2102    macro_rules! generate_all_trima_tests {
2103        ($($test_fn:ident),*) => {
2104            paste! {
2105                $(
2106                    #[test]
2107                    fn [<$test_fn _scalar_f64>]() {
2108                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
2109                    }
2110                )*
2111                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2112                $(
2113                    #[test]
2114                    fn [<$test_fn _avx2_f64>]() {
2115                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
2116                    }
2117                    #[test]
2118                    fn [<$test_fn _avx512_f64>]() {
2119                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
2120                    }
2121                )*
2122            }
2123        }
2124    }
2125
2126    #[cfg(debug_assertions)]
2127    fn check_trima_no_poison(
2128        test_name: &str,
2129        kernel: Kernel,
2130    ) -> Result<(), Box<dyn std::error::Error>> {
2131        skip_if_unsupported!(kernel, test_name);
2132
2133        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2134        let candles = read_candles_from_csv(file_path)?;
2135
2136        let test_periods = vec![4, 10, 14, 30, 50, 100];
2137        let test_sources = vec!["close", "open", "high", "low", "hl2", "hlc3", "ohlc4"];
2138
2139        for period in test_periods {
2140            for source in &test_sources {
2141                let params = TrimaParams {
2142                    period: Some(period),
2143                };
2144                let input = TrimaInput::from_candles(&candles, source, params);
2145                let output = trima_with_kernel(&input, kernel)?;
2146
2147                for (i, &val) in output.values.iter().enumerate() {
2148                    if val.is_nan() {
2149                        continue;
2150                    }
2151
2152                    let bits = val.to_bits();
2153
2154                    if bits == 0x11111111_11111111 {
2155                        panic!(
2156                            "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} (period={}, source={})",
2157                            test_name, val, bits, i, period, source
2158                        );
2159                    }
2160
2161                    if bits == 0x22222222_22222222 {
2162                        panic!(
2163                            "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} (period={}, source={})",
2164                            test_name, val, bits, i, period, source
2165                        );
2166                    }
2167
2168                    if bits == 0x33333333_33333333 {
2169                        panic!(
2170                            "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} (period={}, source={})",
2171                            test_name, val, bits, i, period, source
2172                        );
2173                    }
2174                }
2175            }
2176        }
2177
2178        Ok(())
2179    }
2180
2181    #[cfg(not(debug_assertions))]
2182    fn check_trima_no_poison(
2183        _test_name: &str,
2184        _kernel: Kernel,
2185    ) -> Result<(), Box<dyn std::error::Error>> {
2186        Ok(())
2187    }
2188
2189    #[cfg(feature = "proptest")]
2190    #[allow(clippy::float_cmp)]
2191    fn check_trima_property(
2192        test_name: &str,
2193        kernel: Kernel,
2194    ) -> Result<(), Box<dyn std::error::Error>> {
2195        use crate::indicators::sma::{sma, SmaData, SmaInput, SmaParams};
2196        use proptest::prelude::*;
2197        skip_if_unsupported!(kernel, test_name);
2198
2199        let strat = (4usize..=100).prop_flat_map(|period| {
2200            (
2201                prop::collection::vec(
2202                    (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
2203                    period..400,
2204                ),
2205                Just(period),
2206            )
2207        });
2208
2209        proptest::test_runner::TestRunner::default().run(&strat, |(data, period)| {
2210            let params = TrimaParams {
2211                period: Some(period),
2212            };
2213            let input = TrimaInput::from_slice(&data, params);
2214
2215            let result = trima_with_kernel(&input, kernel)?;
2216            let scalar_result = trima_with_kernel(&input, Kernel::Scalar)?;
2217
2218            let first = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
2219            let warmup_end = first + period - 1;
2220
2221            for i in 0..warmup_end.min(data.len()) {
2222                prop_assert!(
2223                    result.values[i].is_nan(),
2224                    "Expected NaN during warmup at index {}, got {}",
2225                    i,
2226                    result.values[i]
2227                );
2228            }
2229
2230            for i in warmup_end..data.len() {
2231                prop_assert!(
2232                    result.values[i].is_finite() || data[i].is_nan(),
2233                    "Expected finite value after warmup at index {}, got {}",
2234                    i,
2235                    result.values[i]
2236                );
2237            }
2238
2239            if data[first..]
2240                .windows(2)
2241                .all(|w| (w[0] - w[1]).abs() < 1e-10)
2242                && data.len() > first
2243            {
2244                let constant_val = data[first];
2245                for i in warmup_end..data.len() {
2246                    prop_assert!(
2247							(result.values[i] - constant_val).abs() < 1e-9,
2248							"Constant input should produce constant output at index {}: expected {}, got {}",
2249							i,
2250							constant_val,
2251							result.values[i]
2252						);
2253                }
2254            }
2255
2256            for i in 0..data.len() {
2257                let val = result.values[i];
2258                let ref_val = scalar_result.values[i];
2259
2260                if val.is_nan() && ref_val.is_nan() {
2261                    continue;
2262                }
2263
2264                if !val.is_finite() || !ref_val.is_finite() {
2265                    prop_assert_eq!(
2266                        val.to_bits(),
2267                        ref_val.to_bits(),
2268                        "NaN/Inf mismatch at index {}: {} vs {}",
2269                        i,
2270                        val,
2271                        ref_val
2272                    );
2273                } else {
2274                    let ulp_diff = val.to_bits().abs_diff(ref_val.to_bits());
2275                    prop_assert!(
2276                        (val - ref_val).abs() < 1e-9 || ulp_diff <= 4,
2277                        "Cross-kernel mismatch at index {}: {} vs {} (ULP diff: {})",
2278                        i,
2279                        val,
2280                        ref_val,
2281                        ulp_diff
2282                    );
2283                }
2284            }
2285
2286            for (i, &val) in result.values.iter().enumerate() {
2287                prop_assert!(
2288                    val.is_nan() || val.is_finite(),
2289                    "Value should be finite or NaN at index {}, got {}",
2290                    i,
2291                    val
2292                );
2293            }
2294
2295            for i in warmup_end..data.len() {
2296                if i >= period - 1 {
2297                    let start = if i >= period - 1 { i + 1 - period } else { 0 };
2298                    let window = &data[start..=i];
2299                    let min_val = window
2300                        .iter()
2301                        .filter(|x| x.is_finite())
2302                        .fold(f64::INFINITY, |a, &b| a.min(b));
2303                    let max_val = window
2304                        .iter()
2305                        .filter(|x| x.is_finite())
2306                        .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
2307
2308                    if min_val.is_finite() && max_val.is_finite() {
2309                        let val = result.values[i];
2310
2311                        prop_assert!(
2312                            val >= min_val - 1e-6 && val <= max_val + 1e-6,
2313                            "TRIMA value {} at index {} outside window bounds [{}, {}]",
2314                            val,
2315                            i,
2316                            min_val,
2317                            max_val
2318                        );
2319                    }
2320                }
2321            }
2322
2323            if period == 4 {
2324                let m1 = 2;
2325                let m2 = 3;
2326
2327                let sma1_input = SmaInput {
2328                    data: SmaData::Slice(&data),
2329                    params: SmaParams { period: Some(m1) },
2330                };
2331                let pass1 = sma(&sma1_input)?;
2332
2333                let sma2_input = SmaInput {
2334                    data: SmaData::Slice(&pass1.values),
2335                    params: SmaParams { period: Some(m2) },
2336                };
2337                let expected = sma(&sma2_input)?;
2338
2339                for i in warmup_end..data.len().min(warmup_end + 5) {
2340                    prop_assert!(
2341                        (result.values[i] - expected.values[i]).abs() < 1e-9,
2342                        "Period=4: TRIMA mismatch at index {}: got {}, expected {}",
2343                        i,
2344                        result.values[i],
2345                        expected.values[i]
2346                    );
2347                }
2348            }
2349
2350            {
2351                let m1 = (period + 1) / 2;
2352                let m2 = period - m1 + 1;
2353
2354                let sma1_input = SmaInput {
2355                    data: SmaData::Slice(&data),
2356                    params: SmaParams { period: Some(m1) },
2357                };
2358                let pass1 = sma(&sma1_input)?;
2359
2360                let sma2_input = SmaInput {
2361                    data: SmaData::Slice(&pass1.values),
2362                    params: SmaParams { period: Some(m2) },
2363                };
2364                let expected = sma(&sma2_input)?;
2365
2366                let check_points = vec![
2367                    warmup_end,
2368                    warmup_end + period / 2,
2369                    warmup_end + period,
2370                    data.len() - 1,
2371                ];
2372
2373                for &idx in &check_points {
2374                    if idx < data.len() {
2375                        let trima_val = result.values[idx];
2376                        let expected_val = expected.values[idx];
2377
2378                        if trima_val.is_finite() && expected_val.is_finite() {
2379                            prop_assert!(
2380                                (trima_val - expected_val).abs() < 1e-9,
2381                                "Two-pass SMA formula mismatch at index {}: TRIMA={}, Expected={}",
2382                                idx,
2383                                trima_val,
2384                                expected_val
2385                            );
2386                        }
2387                    }
2388                }
2389            }
2390
2391            if data.len() >= warmup_end + 20 {
2392                let sma_input = SmaInput {
2393                    data: SmaData::Slice(&data),
2394                    params: SmaParams {
2395                        period: Some(period),
2396                    },
2397                };
2398                let single_sma = sma(&sma_input)?;
2399
2400                let trima_roughness: f64 = result.values[warmup_end..warmup_end + 20]
2401                    .windows(2)
2402                    .map(|w| (w[1] - w[0]).abs())
2403                    .sum();
2404
2405                let sma_roughness: f64 = single_sma.values[warmup_end..warmup_end + 20]
2406                    .windows(2)
2407                    .map(|w| (w[1] - w[0]).abs())
2408                    .sum();
2409
2410                if sma_roughness > 1e-10 {
2411                    prop_assert!(
2412							trima_roughness <= sma_roughness * 1.1,
2413							"TRIMA should be smoother than single SMA: TRIMA roughness={}, SMA roughness={}",
2414							trima_roughness,
2415							sma_roughness
2416						);
2417                }
2418            }
2419
2420            if data.len() == period {
2421                prop_assert!(
2422                    result.values[period - 1].is_finite(),
2423                    "With data.len()==period, last value should be finite, got {}",
2424                    result.values[period - 1]
2425                );
2426
2427                for i in 0..period - 1 {
2428                    prop_assert!(
2429                        result.values[i].is_nan(),
2430                        "With data.len()==period, value at {} should be NaN, got {}",
2431                        i,
2432                        result.values[i]
2433                    );
2434                }
2435            }
2436
2437            let is_monotonic_increasing = data[first..].windows(2).all(|w| w[1] >= w[0] - 1e-10);
2438            let is_monotonic_decreasing = data[first..].windows(2).all(|w| w[1] <= w[0] + 1e-10);
2439
2440            if is_monotonic_increasing || is_monotonic_decreasing {
2441                let valid_trima = &result.values[warmup_end..];
2442                if valid_trima.len() >= 2 {
2443                    if is_monotonic_increasing {
2444                        for w in valid_trima.windows(2) {
2445                            prop_assert!(
2446                                w[1] >= w[0] - 1e-9,
2447                                "TRIMA should preserve increasing trend: {} < {}",
2448                                w[1],
2449                                w[0]
2450                            );
2451                        }
2452                    } else {
2453                        for w in valid_trima.windows(2) {
2454                            prop_assert!(
2455                                w[1] <= w[0] + 1e-9,
2456                                "TRIMA should preserve decreasing trend: {} > {}",
2457                                w[1],
2458                                w[0]
2459                            );
2460                        }
2461                    }
2462                }
2463            }
2464
2465            #[cfg(debug_assertions)]
2466            {
2467                for (i, &val) in result.values.iter().enumerate() {
2468                    if !val.is_nan() {
2469                        let bits = val.to_bits();
2470                        prop_assert!(
2471                            bits != 0x11111111_11111111
2472                                && bits != 0x22222222_22222222
2473                                && bits != 0x33333333_33333333,
2474                            "Found poison value at index {}: {} (0x{:016X})",
2475                            i,
2476                            val,
2477                            bits
2478                        );
2479                    }
2480                }
2481            }
2482
2483            Ok(())
2484        })?;
2485
2486        Ok(())
2487    }
2488
    // Per-kernel `#[test]` wrappers for the property check; only generated
    // when the `proptest` feature is enabled.
    #[cfg(feature = "proptest")]
    generate_all_trima_tests!(check_trima_property);

    // Per-kernel `#[test]` wrappers for each of the single-series check
    // functions listed below.
    generate_all_trima_tests!(
        check_trima_partial_params,
        check_trima_accuracy,
        check_trima_default_candles,
        check_trima_zero_period,
        check_trima_period_exceeds_length,
        check_trima_period_too_small,
        check_trima_very_small_dataset,
        check_trima_reinput,
        check_trima_nan_handling,
        check_trima_streaming,
        check_trima_no_poison
    );
2505
2506    fn check_batch_default_row(
2507        test: &str,
2508        kernel: Kernel,
2509    ) -> Result<(), Box<dyn std::error::Error>> {
2510        skip_if_unsupported!(kernel, test);
2511
2512        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2513        let c = read_candles_from_csv(file)?;
2514
2515        let output = TrimaBatchBuilder::new()
2516            .kernel(kernel)
2517            .apply_candles(&c, "close")?;
2518
2519        let def = TrimaParams::default();
2520        let row = output.values_for(&def).expect("default row missing");
2521
2522        assert_eq!(row.len(), c.close.len());
2523
2524        let expected = [
2525            59957.916666666664,
2526            59846.770833333336,
2527            59750.620833333334,
2528            59665.2125,
2529            59581.612499999996,
2530        ];
2531        let start = row.len() - 5;
2532        for (i, &v) in row[start..].iter().enumerate() {
2533            assert!(
2534                (v - expected[i]).abs() < 1e-6,
2535                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
2536            );
2537        }
2538        Ok(())
2539    }
2540
2541    macro_rules! gen_batch_tests {
2542        ($fn_name:ident) => {
2543            paste! {
2544                #[test] fn [<$fn_name _scalar>]()      {
2545                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
2546                }
2547                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2548                #[test] fn [<$fn_name _avx2>]()        {
2549                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
2550                }
2551                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2552                #[test] fn [<$fn_name _avx512>]()      {
2553                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
2554                }
2555                #[test] fn [<$fn_name _auto_detect>]() {
2556                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
2557                }
2558            }
2559        };
2560    }
2561
2562    #[cfg(debug_assertions)]
2563    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn std::error::Error>> {
2564        skip_if_unsupported!(kernel, test);
2565
2566        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2567        let c = read_candles_from_csv(file)?;
2568
2569        let period_ranges = vec![(4, 20, 4), (20, 50, 10), (50, 100, 25), (5, 15, 1)];
2570
2571        let test_sources = vec!["close", "open", "high", "low", "hl2", "hlc3", "ohlc4"];
2572
2573        for (start, end, step) in period_ranges {
2574            for source in &test_sources {
2575                let output = TrimaBatchBuilder::new()
2576                    .kernel(kernel)
2577                    .period_range(start, end, step)
2578                    .apply_candles(&c, source)?;
2579
2580                for (idx, &val) in output.values.iter().enumerate() {
2581                    if val.is_nan() {
2582                        continue;
2583                    }
2584
2585                    let bits = val.to_bits();
2586                    let row = idx / output.cols;
2587                    let col = idx % output.cols;
2588
2589                    if bits == 0x11111111_11111111 {
2590                        panic!(
2591                            "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at row {} col {} (flat index {}) with period_range({},{},{}) source={}",
2592                            test, val, bits, row, col, idx, start, end, step, source
2593                        );
2594                    }
2595
2596                    if bits == 0x22222222_22222222 {
2597                        panic!(
2598                            "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at row {} col {} (flat index {}) with period_range({},{},{}) source={}",
2599                            test, val, bits, row, col, idx, start, end, step, source
2600                        );
2601                    }
2602
2603                    if bits == 0x33333333_33333333 {
2604                        panic!(
2605                            "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at row {} col {} (flat index {}) with period_range({},{},{}) source={}",
2606                            test, val, bits, row, col, idx, start, end, step, source
2607                        );
2608                    }
2609                }
2610            }
2611        }
2612
2613        Ok(())
2614    }
2615
2616    #[cfg(not(debug_assertions))]
2617    fn check_batch_no_poison(
2618        _test: &str,
2619        _kernel: Kernel,
2620    ) -> Result<(), Box<dyn std::error::Error>> {
2621        Ok(())
2622    }
2623
    // Per-kernel `#[test]` wrappers for the two batch check functions.
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
2626
2627    #[test]
2628    fn test_trima_into_matches_api() -> Result<(), Box<dyn std::error::Error>> {
2629        let mut data = Vec::with_capacity(256);
2630        data.extend_from_slice(&[f64::NAN, f64::NAN, f64::NAN, f64::NAN]);
2631        for i in 0..252 {
2632            let x = (i as f64 * 0.131).sin() * 7.0 + (i as f64) * 0.02;
2633            data.push(x);
2634        }
2635
2636        let input = TrimaInput::from_slice(&data, TrimaParams::default());
2637
2638        let baseline = trima(&input)?;
2639
2640        let mut out = vec![0.0; data.len()];
2641        #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
2642        {
2643            trima_into(&input, &mut out)?;
2644        }
2645        #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2646        {
2647            trima_into_slice(&mut out, &input, Kernel::Auto)?;
2648        }
2649
2650        assert_eq!(baseline.values.len(), out.len());
2651
2652        for (a, b) in baseline.values.iter().copied().zip(out.iter().copied()) {
2653            let both_nan = a.is_nan() && b.is_nan();
2654            assert!(both_nan || a == b, "mismatch: got {b:?}, expected {a:?}");
2655        }
2656        Ok(())
2657    }
2658}