Skip to main content

vector_ta/indicators/
trix.rs

1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{alloc_with_nan_prefix, init_matrix_prefixes, make_uninit_matrix};
4use aligned_vec::{AVec, CACHELINE_ALIGN};
5#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
6use core::arch::x86_64::*;
7#[cfg(not(target_arch = "wasm32"))]
8use rayon::prelude::*;
9use std::convert::AsRef;
10use std::error::Error;
11use thiserror::Error;
12
13#[cfg(all(feature = "python", feature = "cuda"))]
14use crate::cuda::moving_averages::CudaTrix;
15#[cfg(all(feature = "python", feature = "cuda"))]
16use crate::utilities::dlpack_cuda::{make_device_array_py, DeviceArrayF32Py};
17#[cfg(feature = "python")]
18use crate::utilities::kernel_validation::validate_kernel;
19#[cfg(feature = "python")]
20use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
21#[cfg(feature = "python")]
22use pyo3::exceptions::PyValueError;
23#[cfg(feature = "python")]
24use pyo3::prelude::*;
25#[cfg(feature = "python")]
26use pyo3::types::PyDict;
27
28#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
29use serde::{Deserialize, Serialize};
30#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
31use wasm_bindgen::prelude::*;
32
33impl<'a> AsRef<[f64]> for TrixInput<'a> {
34    #[inline(always)]
35    fn as_ref(&self) -> &[f64] {
36        match &self.data {
37            TrixData::Slice(slice) => slice,
38            TrixData::Candles { candles, source } => source_type(candles, source),
39        }
40    }
41}
42
43#[derive(Debug, Clone)]
44pub enum TrixData<'a> {
45    Candles {
46        candles: &'a Candles,
47        source: &'a str,
48    },
49    Slice(&'a [f64]),
50}
51
/// Result of a single-series TRIX computation.
#[derive(Debug, Clone)]
pub struct TrixOutput {
    /// One entry per input sample, including the leading warm-up region.
    pub values: Vec<f64>,
}
56
/// Tunable parameters of the TRIX indicator.
#[derive(Debug, Clone)]
pub struct TrixParams {
    /// Length of each of the three EMAs; `None` is resolved to 18 by consumers.
    pub period: Option<usize>,
}

impl Default for TrixParams {
    /// Returns parameters with the conventional period of 18.
    fn default() -> Self {
        const DEFAULT_PERIOD: usize = 18;
        TrixParams {
            period: Some(DEFAULT_PERIOD),
        }
    }
}
67
68#[derive(Debug, Clone)]
69pub struct TrixInput<'a> {
70    pub data: TrixData<'a>,
71    pub params: TrixParams,
72}
73
74impl<'a> TrixInput<'a> {
75    #[inline]
76    pub fn from_candles(c: &'a Candles, s: &'a str, p: TrixParams) -> Self {
77        Self {
78            data: TrixData::Candles {
79                candles: c,
80                source: s,
81            },
82            params: p,
83        }
84    }
85    #[inline]
86    pub fn from_slice(sl: &'a [f64], p: TrixParams) -> Self {
87        Self {
88            data: TrixData::Slice(sl),
89            params: p,
90        }
91    }
92    #[inline]
93    pub fn with_default_candles(c: &'a Candles) -> Self {
94        Self::from_candles(c, "close", TrixParams::default())
95    }
96    #[inline]
97    pub fn get_period(&self) -> usize {
98        self.params.period.unwrap_or(18)
99    }
100}
101
102#[derive(Copy, Clone, Debug)]
103pub struct TrixBuilder {
104    period: Option<usize>,
105    kernel: Kernel,
106}
107
108impl Default for TrixBuilder {
109    fn default() -> Self {
110        Self {
111            period: None,
112            kernel: Kernel::Auto,
113        }
114    }
115}
116
117impl TrixBuilder {
118    #[inline(always)]
119    pub fn new() -> Self {
120        Self::default()
121    }
122    #[inline(always)]
123    pub fn period(mut self, n: usize) -> Self {
124        self.period = Some(n);
125        self
126    }
127    #[inline(always)]
128    pub fn kernel(mut self, k: Kernel) -> Self {
129        self.kernel = k;
130        self
131    }
132    #[inline(always)]
133    pub fn apply(self, c: &Candles) -> Result<TrixOutput, TrixError> {
134        let p = TrixParams {
135            period: self.period,
136        };
137        let i = TrixInput::from_candles(c, "close", p);
138        trix_with_kernel(&i, self.kernel)
139    }
140    #[inline(always)]
141    pub fn apply_slice(self, d: &[f64]) -> Result<TrixOutput, TrixError> {
142        let p = TrixParams {
143            period: self.period,
144        };
145        let i = TrixInput::from_slice(d, p);
146        trix_with_kernel(&i, self.kernel)
147    }
148    #[inline(always)]
149    pub fn into_stream(self) -> Result<TrixStream, TrixError> {
150        let p = TrixParams {
151            period: self.period,
152        };
153        TrixStream::try_new(p)
154    }
155}
156
157#[derive(Debug, Error)]
158pub enum TrixError {
159    #[error("trix: Empty data provided.")]
160    EmptyInputData,
161    #[error("trix: Invalid period: period = {period}, data length = {data_len}")]
162    InvalidPeriod { period: usize, data_len: usize },
163    #[error("trix: Not enough valid data: needed = {needed}, valid = {valid}")]
164    NotEnoughValidData { needed: usize, valid: usize },
165    #[error("trix: All values are NaN.")]
166    AllValuesNaN,
167    #[error("trix: output length mismatch: expected = {expected}, got = {got}")]
168    OutputLengthMismatch { expected: usize, got: usize },
169    #[error("trix: invalid range: start={start}, end={end}, step={step}")]
170    InvalidRange {
171        start: usize,
172        end: usize,
173        step: usize,
174    },
175    #[error("trix: invalid kernel for batch: {0:?}")]
176    InvalidKernelForBatch(Kernel),
177    #[error("trix: invalid input: {0}")]
178    InvalidInput(String),
179}
180
181#[inline]
182pub fn trix(input: &TrixInput) -> Result<TrixOutput, TrixError> {
183    trix_with_kernel(input, Kernel::Auto)
184}
185
186#[inline(always)]
187fn trix_needed_len(period: usize) -> Result<usize, TrixError> {
188    let base = period
189        .checked_sub(1)
190        .and_then(|v| v.checked_mul(3))
191        .and_then(|v| v.checked_add(2))
192        .ok_or_else(|| {
193            TrixError::InvalidInput("period overflow when computing TRIX warmup length".into())
194        })?;
195    Ok(base)
196}
197
198#[inline(always)]
199fn trix_warmup_end(first: usize, period: usize) -> Result<usize, TrixError> {
200    let delta = period
201        .checked_sub(1)
202        .and_then(|v| v.checked_mul(3))
203        .and_then(|v| v.checked_add(1))
204        .ok_or_else(|| {
205            TrixError::InvalidInput("period overflow when computing TRIX warmup index".into())
206        })?;
207    first.checked_add(delta).ok_or_else(|| {
208        TrixError::InvalidInput("index overflow when computing TRIX warmup index".into())
209    })
210}
211
212#[inline(always)]
213fn trix_prepare<'a>(
214    input: &'a TrixInput,
215    k: Kernel,
216) -> Result<(&'a [f64], usize, usize, Kernel, f64, usize), TrixError> {
217    let data: &[f64] = input.as_ref();
218    let len = data.len();
219    if len == 0 {
220        return Err(TrixError::EmptyInputData);
221    }
222    let period = input.get_period();
223    if period == 0 || period > len {
224        return Err(TrixError::InvalidPeriod {
225            period,
226            data_len: len,
227        });
228    }
229    let first = data
230        .iter()
231        .position(|x| !x.is_nan())
232        .ok_or(TrixError::AllValuesNaN)?;
233    let needed = trix_needed_len(period)?;
234    let valid_len = len.saturating_sub(first);
235    if valid_len < needed {
236        return Err(TrixError::NotEnoughValidData {
237            needed,
238            valid: valid_len,
239        });
240    }
241    let chosen = match k {
242        Kernel::Auto => Kernel::Scalar,
243        other => other,
244    };
245    let alpha = 2.0 / (period as f64 + 1.0);
246    let warmup_end = trix_warmup_end(first, period)?;
247    Ok((data, period, first, chosen, alpha, warmup_end))
248}
249
/// Scalar TRIX kernel over raw prices, writing results into `out`.
///
/// Every price is log-transformed (`ln`) before smoothing. Three cascaded EMAs with the
/// caller-supplied `alpha` are each seeded with a simple average:
/// - EMA1 from the mean of `ln(data[first..first + period])`,
/// - EMA2 from the mean of the first `period` EMA1 values,
/// - EMA3 from the mean of the first `period` EMA2 values.
///
/// For every `i` in `warmup_end..data.len()` (with
/// `warmup_end = first + 3 * (period - 1) + 1`) it stores
/// `(ema3[i] - ema3[i - 1]) * 10000` in `out[i]`. Indices before `warmup_end` are NOT
/// written; the caller owns the warm-up prefix. If `warmup_end >= data.len()` nothing
/// is written.
///
/// Expects `period >= 1` and `first < data.len()` (validated by `trix_prepare`), and
/// `out.len() >= data.len()`; indexing is bounds-checked, so violations panic.
#[inline(always)]
fn trix_compute_into_scalar(
    data: &[f64],
    period: usize,
    first: usize,
    alpha: f64,
    out: &mut [f64],
) {
    let len = data.len();
    let warmup_end = first + 3 * (period - 1) + 1;
    if warmup_end >= len {
        return;
    }

    let inv_n = 1.0 / period as f64;
    const SCALE: f64 = 10000.0;

    // Phase 1: seed EMA1 with the mean of the first `period` log prices.
    let mut sum1 = 0.0;
    let end1 = first + period;
    let mut i = first;
    while i < end1 {
        sum1 += data[i].ln();
        i += 1;
    }
    let mut ema1 = sum1 * inv_n;

    // Phase 2: advance EMA1 over `period - 1` more samples, accumulating EMA1 values
    // (seed included) to seed EMA2. Unrolled by four; the operation order matches a
    // plain loop.
    let mut sum_ema1 = ema1;
    let end2 = first + 2 * period - 1;
    i = end1;

    while i + 3 < end2 {
        let mut lv = data[i].ln();
        ema1 = (lv - ema1).mul_add(alpha, ema1);
        sum_ema1 += ema1;

        lv = data[i + 1].ln();
        ema1 = (lv - ema1).mul_add(alpha, ema1);
        sum_ema1 += ema1;

        lv = data[i + 2].ln();
        ema1 = (lv - ema1).mul_add(alpha, ema1);
        sum_ema1 += ema1;

        lv = data[i + 3].ln();
        ema1 = (lv - ema1).mul_add(alpha, ema1);
        sum_ema1 += ema1;
        i += 4;
    }
    while i < end2 {
        let lv = data[i].ln();
        ema1 = (lv - ema1).mul_add(alpha, ema1);
        sum_ema1 += ema1;
        i += 1;
    }

    let mut ema2 = sum_ema1 * inv_n;

    // Phase 3: advance EMA1 and EMA2 over `period - 1` samples, accumulating EMA2
    // values (seed included) to seed EMA3. Unrolled by four.
    let mut sum_ema2 = ema2;
    let end3 = first + 3 * period - 2;
    i = end2;

    while i + 3 < end3 {
        let mut lv = data[i].ln();
        ema1 = (lv - ema1).mul_add(alpha, ema1);
        ema2 = (ema1 - ema2).mul_add(alpha, ema2);
        sum_ema2 += ema2;

        lv = data[i + 1].ln();
        ema1 = (lv - ema1).mul_add(alpha, ema1);
        ema2 = (ema1 - ema2).mul_add(alpha, ema2);
        sum_ema2 += ema2;

        lv = data[i + 2].ln();
        ema1 = (lv - ema1).mul_add(alpha, ema1);
        ema2 = (ema1 - ema2).mul_add(alpha, ema2);
        sum_ema2 += ema2;

        lv = data[i + 3].ln();
        ema1 = (lv - ema1).mul_add(alpha, ema1);
        ema2 = (ema1 - ema2).mul_add(alpha, ema2);
        sum_ema2 += ema2;
        i += 4;
    }
    while i < end3 {
        let lv = data[i].ln();
        ema1 = (lv - ema1).mul_add(alpha, ema1);
        ema2 = (ema1 - ema2).mul_add(alpha, ema2);
        sum_ema2 += ema2;
        i += 1;
    }

    let mut ema3_prev = sum_ema2 * inv_n;

    // First emitted value (end3 == warmup_end): one-step change of EMA3, scaled.
    let mut src = warmup_end;
    let mut lv = data[src].ln();
    ema1 = (lv - ema1).mul_add(alpha, ema1);
    ema2 = (ema1 - ema2).mul_add(alpha, ema2);
    let mut ema3 = (ema2 - ema3_prev).mul_add(alpha, ema3_prev);
    out[src] = (ema3 - ema3_prev) * SCALE;
    ema3_prev = ema3;
    src += 1;

    // Steady state, unrolled by four.
    while src + 3 < len {
        lv = data[src].ln();
        ema1 = (lv - ema1).mul_add(alpha, ema1);
        ema2 = (ema1 - ema2).mul_add(alpha, ema2);
        ema3 = (ema2 - ema3_prev).mul_add(alpha, ema3_prev);
        out[src] = (ema3 - ema3_prev) * SCALE;
        ema3_prev = ema3;

        let lv1 = data[src + 1].ln();
        ema1 = (lv1 - ema1).mul_add(alpha, ema1);
        ema2 = (ema1 - ema2).mul_add(alpha, ema2);
        ema3 = (ema2 - ema3_prev).mul_add(alpha, ema3_prev);
        out[src + 1] = (ema3 - ema3_prev) * SCALE;
        ema3_prev = ema3;

        let lv2 = data[src + 2].ln();
        ema1 = (lv2 - ema1).mul_add(alpha, ema1);
        ema2 = (ema1 - ema2).mul_add(alpha, ema2);
        ema3 = (ema2 - ema3_prev).mul_add(alpha, ema3_prev);
        out[src + 2] = (ema3 - ema3_prev) * SCALE;
        ema3_prev = ema3;

        let lv3 = data[src + 3].ln();
        ema1 = (lv3 - ema1).mul_add(alpha, ema1);
        ema2 = (ema1 - ema2).mul_add(alpha, ema2);
        ema3 = (ema2 - ema3_prev).mul_add(alpha, ema3_prev);
        out[src + 3] = (ema3 - ema3_prev) * SCALE;
        ema3_prev = ema3;

        src += 4;
    }

    // Remainder (fewer than four samples left).
    while src < len {
        lv = data[src].ln();
        ema1 = (lv - ema1).mul_add(alpha, ema1);
        ema2 = (ema1 - ema2).mul_add(alpha, ema2);
        ema3 = (ema2 - ema3_prev).mul_add(alpha, ema3_prev);
        out[src] = (ema3 - ema3_prev) * SCALE;
        ema3_prev = ema3;
        src += 1;
    }
}
394
395pub fn trix_with_kernel(input: &TrixInput, kernel: Kernel) -> Result<TrixOutput, TrixError> {
396    let (data, period, first, chosen, alpha, warmup_end) = trix_prepare(input, kernel)?;
397    let mut out = alloc_with_nan_prefix(data.len(), warmup_end);
398    unsafe {
399        match chosen {
400            Kernel::Scalar | Kernel::ScalarBatch => {
401                trix_compute_into_scalar(data, period, first, alpha, &mut out);
402            }
403            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
404            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
405                trix_compute_into_scalar(data, period, first, alpha, &mut out);
406            }
407            #[allow(unreachable_patterns)]
408            _ => trix_compute_into_scalar(data, period, first, alpha, &mut out),
409        }
410    }
411    Ok(TrixOutput { values: out })
412}
413
414#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
415#[inline]
416pub fn trix_into(input: &TrixInput, out: &mut [f64]) -> Result<(), TrixError> {
417    let (data, period, first, chosen, alpha, warmup_end) = trix_prepare(input, Kernel::Auto)?;
418
419    if out.len() != data.len() {
420        return Err(TrixError::OutputLengthMismatch {
421            expected: data.len(),
422            got: out.len(),
423        });
424    }
425
426    let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
427    let warm = warmup_end.min(out.len());
428    for v in &mut out[..warm] {
429        *v = qnan;
430    }
431
432    unsafe {
433        match chosen {
434            Kernel::Scalar | Kernel::ScalarBatch => {
435                trix_compute_into_scalar(data, period, first, alpha, out)
436            }
437            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
438            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
439                trix_compute_into_scalar(data, period, first, alpha, out)
440            }
441            #[allow(unreachable_patterns)]
442            _ => trix_compute_into_scalar(data, period, first, alpha, out),
443        }
444    }
445    Ok(())
446}
447
/// Incremental TRIX calculator fed one price at a time.
#[derive(Debug, Clone)]
pub struct TrixStream {
    /// EMA length.
    period: usize,
    /// Smoothing factor `2 / (period + 1)`.
    alpha: f64,
    /// `1 / period`, used to turn seed sums into averages.
    inv_n: f64,
    /// Current seeding / running phase.
    state: StreamState,
}

/// Phases of [`TrixStream`]: three seeding phases (one per EMA), then steady state.
#[derive(Debug, Clone)]
enum StreamState {
    /// Summing log prices to seed EMA1; `need` samples are still required.
    Seed1 { need: usize, sum1: f64 },

    /// Advancing EMA1 while summing its values to seed EMA2.
    Seed2 { remain: usize, ema1: f64, sum_ema1: f64 },

    /// Advancing EMA1/EMA2 while summing EMA2 values to seed EMA3.
    Seed3 {
        remain: usize,
        ema1: f64,
        ema2: f64,
        sum_ema2: f64,
    },

    /// Fully seeded; every update emits a value.
    Running { ema1: f64, ema2: f64, ema3_prev: f64 },
}
482
483impl TrixStream {
484    #[inline(always)]
485    pub fn try_new(params: TrixParams) -> Result<Self, TrixError> {
486        let period = params.period.unwrap_or(18);
487        if period == 0 {
488            return Err(TrixError::InvalidPeriod {
489                period,
490                data_len: 0,
491            });
492        }
493        let alpha = 2.0 / (period as f64 + 1.0);
494        Ok(Self {
495            period,
496            alpha,
497            inv_n: 1.0 / period as f64,
498            state: StreamState::Seed1 {
499                need: period,
500                sum1: 0.0,
501            },
502        })
503    }
504
505    #[inline(always)]
506    fn reset(&mut self) {
507        self.state = StreamState::Seed1 {
508            need: self.period,
509            sum1: 0.0,
510        };
511    }
512
513    #[inline(always)]
514    pub fn update(&mut self, value: f64) -> Option<f64> {
515        if !value.is_finite() || value <= 0.0 {
516            self.reset();
517            return None;
518        }
519
520        let lv = value.ln();
521        let a = self.alpha;
522
523        match &mut self.state {
524            StreamState::Seed1 { need, sum1 } => {
525                *sum1 += lv;
526                *need -= 1;
527                if *need == 0 {
528                    let ema1 = *sum1 * self.inv_n;
529                    self.state = StreamState::Seed2 {
530                        remain: self.period - 1,
531                        ema1,
532                        sum_ema1: ema1,
533                    };
534                }
535                None
536            }
537
538            StreamState::Seed2 {
539                remain,
540                ema1,
541                sum_ema1,
542            } => {
543                *ema1 = (lv - *ema1).mul_add(a, *ema1);
544                *sum_ema1 += *ema1;
545                *remain -= 1;
546                if *remain == 0 {
547                    let ema2 = *sum_ema1 * self.inv_n;
548                    let e1 = *ema1;
549                    self.state = StreamState::Seed3 {
550                        remain: self.period - 1,
551                        ema1: e1,
552                        ema2,
553                        sum_ema2: ema2,
554                    };
555                }
556                None
557            }
558
559            StreamState::Seed3 {
560                remain,
561                ema1,
562                ema2,
563                sum_ema2,
564            } => {
565                *ema1 = (lv - *ema1).mul_add(a, *ema1);
566                *ema2 = (*ema1 - *ema2).mul_add(a, *ema2);
567                *sum_ema2 += *ema2;
568                *remain -= 1;
569                if *remain == 0 {
570                    let ema3_prev = *sum_ema2 * self.inv_n;
571                    let e1 = *ema1;
572                    let e2 = *ema2;
573                    self.state = StreamState::Running {
574                        ema1: e1,
575                        ema2: e2,
576                        ema3_prev,
577                    };
578                }
579                None
580            }
581
582            StreamState::Running {
583                ema1,
584                ema2,
585                ema3_prev,
586            } => {
587                *ema1 = (lv - *ema1).mul_add(a, *ema1);
588                *ema2 = (*ema1 - *ema2).mul_add(a, *ema2);
589                let ema3 = (*ema2 - *ema3_prev).mul_add(a, *ema3_prev);
590                let out = (ema3 - *ema3_prev) * 10000.0;
591                *ema3_prev = ema3;
592                Some(out)
593            }
594        }
595    }
596}
597
/// Period sweep for batch TRIX, expressed as `(start, end, step)`.
///
/// See `expand_grid` for how the triple is expanded.
#[derive(Clone, Debug)]
pub struct TrixBatchRange {
    pub period: (usize, usize, usize),
}

impl Default for TrixBatchRange {
    /// Sweeps periods 18 through 267 in steps of 1.
    fn default() -> Self {
        let (start, end, step) = (18, 267, 1);
        Self {
            period: (start, end, step),
        }
    }
}
610
611#[derive(Clone, Debug, Default)]
612pub struct TrixBatchBuilder {
613    range: TrixBatchRange,
614    kernel: Kernel,
615}
616
617impl TrixBatchBuilder {
618    pub fn new() -> Self {
619        Self::default()
620    }
621    pub fn kernel(mut self, k: Kernel) -> Self {
622        self.kernel = k;
623        self
624    }
625    #[inline]
626    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
627        self.range.period = (start, end, step);
628        self
629    }
630    #[inline]
631    pub fn period_static(mut self, p: usize) -> Self {
632        self.range.period = (p, p, 0);
633        self
634    }
635    pub fn apply_slice(self, data: &[f64]) -> Result<TrixBatchOutput, TrixError> {
636        trix_batch_with_kernel(data, &self.range, self.kernel)
637    }
638    pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<TrixBatchOutput, TrixError> {
639        TrixBatchBuilder::new().kernel(k).apply_slice(data)
640    }
641    pub fn apply_candles(self, c: &Candles, src: &str) -> Result<TrixBatchOutput, TrixError> {
642        let slice = source_type(c, src);
643        self.apply_slice(slice)
644    }
645    pub fn with_default_candles(c: &Candles) -> Result<TrixBatchOutput, TrixError> {
646        TrixBatchBuilder::new()
647            .kernel(Kernel::Auto)
648            .apply_candles(c, "close")
649    }
650}
651
652pub fn trix_batch_with_kernel(
653    data: &[f64],
654    sweep: &TrixBatchRange,
655    k: Kernel,
656) -> Result<TrixBatchOutput, TrixError> {
657    let kernel = match k {
658        Kernel::Auto => Kernel::ScalarBatch,
659        other if other.is_batch() => other,
660        _ => {
661            return Err(TrixError::InvalidKernelForBatch(k));
662        }
663    };
664    let simd = match kernel {
665        Kernel::Avx512Batch => Kernel::Avx512,
666        Kernel::Avx2Batch => Kernel::Avx2,
667        Kernel::ScalarBatch => Kernel::Scalar,
668        _ => unreachable!(),
669    };
670    trix_batch_par_slice(data, sweep, simd)
671}
672
673#[derive(Clone, Debug)]
674pub struct TrixBatchOutput {
675    pub values: Vec<f64>,
676    pub combos: Vec<TrixParams>,
677    pub rows: usize,
678    pub cols: usize,
679}
680impl TrixBatchOutput {
681    pub fn row_for_params(&self, p: &TrixParams) -> Option<usize> {
682        self.combos
683            .iter()
684            .position(|c| c.period.unwrap_or(18) == p.period.unwrap_or(18))
685    }
686    pub fn values_for(&self, p: &TrixParams) -> Option<&[f64]> {
687        self.row_for_params(p).map(|row| {
688            let start = row * self.cols;
689            &self.values[start..start + self.cols]
690        })
691    }
692}
693
694#[inline(always)]
695fn expand_grid(r: &TrixBatchRange) -> Result<Vec<TrixParams>, TrixError> {
696    fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, TrixError> {
697        if step == 0 || start == end {
698            return Ok(vec![start]);
699        }
700        let mut vals = Vec::new();
701        if start < end {
702            let mut v = start;
703            while v <= end {
704                vals.push(v);
705                let next = match v.checked_add(step) {
706                    Some(n) => n,
707                    None => break,
708                };
709                if next == v {
710                    break;
711                }
712                v = next;
713            }
714        } else {
715            let mut v = start;
716            while v >= end {
717                vals.push(v);
718                let next = v.saturating_sub(step);
719                if next == v {
720                    break;
721                }
722                v = next;
723            }
724        }
725        if vals.is_empty() {
726            return Err(TrixError::InvalidRange { start, end, step });
727        }
728        Ok(vals)
729    }
730    let periods = axis_usize(r.period)?;
731    let mut out = Vec::with_capacity(periods.len());
732    for &p in &periods {
733        out.push(TrixParams { period: Some(p) });
734    }
735    Ok(out)
736}
737
738#[inline(always)]
739pub fn trix_batch_slice(
740    data: &[f64],
741    sweep: &TrixBatchRange,
742    kern: Kernel,
743) -> Result<TrixBatchOutput, TrixError> {
744    trix_batch_inner(data, sweep, kern, false)
745}
746
747#[inline(always)]
748pub fn trix_batch_par_slice(
749    data: &[f64],
750    sweep: &TrixBatchRange,
751    kern: Kernel,
752) -> Result<TrixBatchOutput, TrixError> {
753    trix_batch_inner(data, sweep, kern, true)
754}
755
756#[inline(always)]
757fn trix_batch_inner(
758    data: &[f64],
759    sweep: &TrixBatchRange,
760    kern: Kernel,
761    parallel: bool,
762) -> Result<TrixBatchOutput, TrixError> {
763    let combos = expand_grid(sweep)?;
764    let first = data
765        .iter()
766        .position(|x| !x.is_nan())
767        .ok_or(TrixError::AllValuesNaN)?;
768    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
769    let needed = trix_needed_len(max_p)?;
770    if data.len() - first < needed {
771        return Err(TrixError::NotEnoughValidData {
772            needed,
773            valid: data.len() - first,
774        });
775    }
776    let rows = combos.len();
777    let cols = data.len();
778    let _total = rows
779        .checked_mul(cols)
780        .ok_or_else(|| TrixError::InvalidInput("rows*cols overflow".into()))?;
781
782    let mut buf_mu = make_uninit_matrix(rows, cols);
783
784    let warm: Vec<usize> = combos
785        .iter()
786        .map(|c| trix_warmup_end(first, c.period.unwrap()))
787        .collect::<Result<_, _>>()?;
788
789    init_matrix_prefixes(&mut buf_mu, cols, &warm);
790
791    let mut buf_guard = core::mem::ManuallyDrop::new(buf_mu);
792    let values: &mut [f64] = unsafe {
793        core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
794    };
795
796    let mut logs: AVec<f64> = AVec::with_capacity(CACHELINE_ALIGN, cols);
797    unsafe { logs.set_len(cols) };
798    for i in 0..first {
799        logs[i] = 0.0;
800    }
801    for i in first..cols {
802        logs[i] = data[i].ln();
803    }
804
805    let do_row = |row: usize, out_row: &mut [f64]| unsafe {
806        let period = combos[row].period.unwrap();
807        trix_row_scalar_with_logs(&logs, first, period, out_row)
808    };
809    if parallel {
810        #[cfg(not(target_arch = "wasm32"))]
811        {
812            values
813                .par_chunks_mut(cols)
814                .enumerate()
815                .for_each(|(row, slice)| do_row(row, slice));
816        }
817
818        #[cfg(target_arch = "wasm32")]
819        {
820            for (row, slice) in values.chunks_mut(cols).enumerate() {
821                do_row(row, slice);
822            }
823        }
824    } else {
825        for (row, slice) in values.chunks_mut(cols).enumerate() {
826            do_row(row, slice);
827        }
828    }
829
830    let values = unsafe {
831        let ptr = buf_guard.as_mut_ptr() as *mut f64;
832        let len = buf_guard.len();
833        core::mem::forget(buf_guard);
834        Vec::from_raw_parts(ptr, len, len)
835    };
836
837    Ok(TrixBatchOutput {
838        values,
839        combos,
840        rows,
841        cols,
842    })
843}
844
/// Computes one TRIX row directly from raw prices (log-transforming each sample).
///
/// Derives `alpha = 2 / (period + 1)` and `1 / period` itself. It writes `f64::NAN`
/// into `out[..min(warmup_end, len)]` (with `warmup_end = first + 3 * (period - 1) + 1`)
/// and then the TRIX values `(ema3[i] - ema3[i - 1]) * 10000` for
/// `i in warmup_end..len`. It uses the same seeding scheme as `trix_compute_into_scalar`.
/// The steady-state loop is unrolled by two.
///
/// NOTE(review): no caller is visible in this part of the file; the batch path uses
/// `trix_row_scalar_with_logs` instead.
///
/// # Safety
/// - `out.len() >= data.len()`: outputs are written with `get_unchecked_mut`.
/// - `period >= 1`: otherwise `period - 1` underflows.
///
/// Inputs are read through a raw pointer only at indices `< warmup_end <= data.len() - 1`
/// once the early return has been passed.
#[inline(always)]
unsafe fn trix_row_scalar(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    let len = data.len();
    let alpha = 2.0 / (period as f64 + 1.0);
    let inv_n = 1.0 / period as f64;
    const SCALE: f64 = 10000.0;

    let warmup_end = first + 3 * (period - 1) + 1;
    // Warm-up prefix (bounds-checked slice).
    for v in &mut out[..warmup_end.min(len)] {
        *v = f64::NAN;
    }
    if warmup_end >= len {
        return;
    }

    let p = data.as_ptr();

    // Seed EMA1 with the mean of the first `period` log prices.
    let mut sum1 = 0.0;
    let end1 = first + period;
    let mut i = first;
    while i < end1 {
        sum1 += (*p.add(i)).ln();
        i += 1;
    }
    let mut ema1 = sum1 * inv_n;

    // Advance EMA1, collecting `period` EMA1 values (seed included) to seed EMA2.
    let mut sum_ema1 = ema1;
    let end2 = first + 2 * period - 1;
    i = end1;
    while i < end2 {
        let lv = (*p.add(i)).ln();
        ema1 = (lv - ema1).mul_add(alpha, ema1);
        sum_ema1 += ema1;
        i += 1;
    }

    let mut ema2 = sum_ema1 * inv_n;

    // Advance EMA1/EMA2, collecting `period` EMA2 values (seed included) to seed EMA3.
    let mut sum_ema2 = ema2;
    let end3 = first + 3 * period - 2;
    i = end2;
    while i < end3 {
        let lv = (*p.add(i)).ln();
        ema1 = (lv - ema1).mul_add(alpha, ema1);
        ema2 = (ema1 - ema2).mul_add(alpha, ema2);
        sum_ema2 += ema2;
        i += 1;
    }

    let mut ema3_prev = sum_ema2 * inv_n;

    // First emitted value at `warmup_end` (== end3).
    let mut src = warmup_end;
    let mut lv = (*p.add(src)).ln();
    ema1 = (lv - ema1).mul_add(alpha, ema1);
    ema2 = (ema1 - ema2).mul_add(alpha, ema2);
    let mut ema3 = (ema2 - ema3_prev).mul_add(alpha, ema3_prev);
    *out.get_unchecked_mut(src) = (ema3 - ema3_prev) * SCALE;
    ema3_prev = ema3;
    src += 1;

    // Steady state, two samples per iteration.
    while src + 1 < len {
        lv = (*p.add(src)).ln();
        ema1 = (lv - ema1).mul_add(alpha, ema1);
        ema2 = (ema1 - ema2).mul_add(alpha, ema2);
        ema3 = (ema2 - ema3_prev).mul_add(alpha, ema3_prev);
        *out.get_unchecked_mut(src) = (ema3 - ema3_prev) * SCALE;
        ema3_prev = ema3;
        src += 1;

        lv = (*p.add(src)).ln();
        ema1 = (lv - ema1).mul_add(alpha, ema1);
        ema2 = (ema1 - ema2).mul_add(alpha, ema2);
        ema3 = (ema2 - ema3_prev).mul_add(alpha, ema3_prev);
        *out.get_unchecked_mut(src) = (ema3 - ema3_prev) * SCALE;
        ema3_prev = ema3;
        src += 1;
    }

    // At most one trailing sample; `ema3_prev` is not needed afterwards.
    if src < len {
        lv = (*p.add(src)).ln();
        ema1 = (lv - ema1).mul_add(alpha, ema1);
        ema2 = (ema1 - ema2).mul_add(alpha, ema2);
        ema3 = (ema2 - ema3_prev).mul_add(alpha, ema3_prev);
        *out.get_unchecked_mut(src) = (ema3 - ema3_prev) * SCALE;
    }
}
931
/// Computes one TRIX row from precomputed log prices (`logs[i] == ln(price[i])` for
/// `i >= first`; entries before `first` are never read).
///
/// Writes `(ema3[i] - ema3[i - 1]) * 10000` into `out[i]` for
/// `i in warmup_end..logs.len()`, where `warmup_end = first + 3 * (period - 1) + 1`.
/// It does NOT touch `out[..warmup_end]`; the batch caller initialises that prefix. If
/// `warmup_end >= logs.len()`, nothing is written. The seeding scheme matches
/// `trix_compute_into_scalar`; the loops are unrolled by four.
///
/// # Safety
/// - `out.len() >= logs.len()`: outputs are written with `get_unchecked_mut`.
/// - `period >= 1`: otherwise `period - 1` underflows.
///
/// `logs` is read through a raw pointer only at indices `< logs.len()` once the early
/// return has been passed.
#[inline(always)]
unsafe fn trix_row_scalar_with_logs(logs: &[f64], first: usize, period: usize, out: &mut [f64]) {
    let len = logs.len();
    let alpha = 2.0 / (period as f64 + 1.0);
    let inv_n = 1.0 / period as f64;
    const SCALE: f64 = 10000.0;

    let warmup_end = first + 3 * (period - 1) + 1;
    if warmup_end >= len {
        return;
    }

    let p = logs.as_ptr();

    // Seed EMA1 with the mean of the first `period` log prices.
    let mut sum1 = 0.0;
    let end1 = first + period;
    let mut i = first;
    while i < end1 {
        sum1 += *p.add(i);
        i += 1;
    }
    let mut ema1 = sum1 * inv_n;

    // Advance EMA1, collecting `period` EMA1 values (seed included) to seed EMA2.
    let mut sum_ema1 = ema1;
    let end2 = first + 2 * period - 1;
    i = end1;

    while i + 3 < end2 {
        let mut lv = *p.add(i);
        ema1 = (lv - ema1).mul_add(alpha, ema1);
        sum_ema1 += ema1;

        lv = *p.add(i + 1);
        ema1 = (lv - ema1).mul_add(alpha, ema1);
        sum_ema1 += ema1;

        lv = *p.add(i + 2);
        ema1 = (lv - ema1).mul_add(alpha, ema1);
        sum_ema1 += ema1;

        lv = *p.add(i + 3);
        ema1 = (lv - ema1).mul_add(alpha, ema1);
        sum_ema1 += ema1;
        i += 4;
    }
    while i < end2 {
        let lv = *p.add(i);
        ema1 = (lv - ema1).mul_add(alpha, ema1);
        sum_ema1 += ema1;
        i += 1;
    }

    let mut ema2 = sum_ema1 * inv_n;

    // Advance EMA1/EMA2, collecting `period` EMA2 values (seed included) to seed EMA3.
    let mut sum_ema2 = ema2;
    let end3 = first + 3 * period - 2;
    i = end2;

    while i + 3 < end3 {
        let mut lv = *p.add(i);
        ema1 = (lv - ema1).mul_add(alpha, ema1);
        ema2 = (ema1 - ema2).mul_add(alpha, ema2);
        sum_ema2 += ema2;

        lv = *p.add(i + 1);
        ema1 = (lv - ema1).mul_add(alpha, ema1);
        ema2 = (ema1 - ema2).mul_add(alpha, ema2);
        sum_ema2 += ema2;

        lv = *p.add(i + 2);
        ema1 = (lv - ema1).mul_add(alpha, ema1);
        ema2 = (ema1 - ema2).mul_add(alpha, ema2);
        sum_ema2 += ema2;

        lv = *p.add(i + 3);
        ema1 = (lv - ema1).mul_add(alpha, ema1);
        ema2 = (ema1 - ema2).mul_add(alpha, ema2);
        sum_ema2 += ema2;
        i += 4;
    }
    while i < end3 {
        let lv = *p.add(i);
        ema1 = (lv - ema1).mul_add(alpha, ema1);
        ema2 = (ema1 - ema2).mul_add(alpha, ema2);
        sum_ema2 += ema2;
        i += 1;
    }

    let mut ema3_prev = sum_ema2 * inv_n;

    // First emitted value at `warmup_end` (== end3).
    let mut src = warmup_end;
    let mut lv = *p.add(src);
    ema1 = (lv - ema1).mul_add(alpha, ema1);
    ema2 = (ema1 - ema2).mul_add(alpha, ema2);
    let mut ema3 = (ema2 - ema3_prev).mul_add(alpha, ema3_prev);
    *out.get_unchecked_mut(src) = (ema3 - ema3_prev) * SCALE;
    ema3_prev = ema3;
    src += 1;

    // Steady state, four samples per iteration.
    while src + 3 < len {
        lv = *p.add(src);
        ema1 = (lv - ema1).mul_add(alpha, ema1);
        ema2 = (ema1 - ema2).mul_add(alpha, ema2);
        ema3 = (ema2 - ema3_prev).mul_add(alpha, ema3_prev);
        *out.get_unchecked_mut(src) = (ema3 - ema3_prev) * SCALE;
        ema3_prev = ema3;

        let lv1 = *p.add(src + 1);
        ema1 = (lv1 - ema1).mul_add(alpha, ema1);
        ema2 = (ema1 - ema2).mul_add(alpha, ema2);
        ema3 = (ema2 - ema3_prev).mul_add(alpha, ema3_prev);
        *out.get_unchecked_mut(src + 1) = (ema3 - ema3_prev) * SCALE;
        ema3_prev = ema3;

        let lv2 = *p.add(src + 2);
        ema1 = (lv2 - ema1).mul_add(alpha, ema1);
        ema2 = (ema1 - ema2).mul_add(alpha, ema2);
        ema3 = (ema2 - ema3_prev).mul_add(alpha, ema3_prev);
        *out.get_unchecked_mut(src + 2) = (ema3 - ema3_prev) * SCALE;
        ema3_prev = ema3;

        let lv3 = *p.add(src + 3);
        ema1 = (lv3 - ema1).mul_add(alpha, ema1);
        ema2 = (ema1 - ema2).mul_add(alpha, ema2);
        ema3 = (ema2 - ema3_prev).mul_add(alpha, ema3_prev);
        *out.get_unchecked_mut(src + 3) = (ema3 - ema3_prev) * SCALE;
        ema3_prev = ema3;

        src += 4;
    }
    // Remainder (fewer than four samples left).
    while src < len {
        lv = *p.add(src);
        ema1 = (lv - ema1).mul_add(alpha, ema1);
        ema2 = (ema1 - ema2).mul_add(alpha, ema2);
        ema3 = (ema2 - ema3_prev).mul_add(alpha, ema3_prev);
        *out.get_unchecked_mut(src) = (ema3 - ema3_prev) * SCALE;
        ema3_prev = ema3;
        src += 1;
    }
}
1072
#[cfg(feature = "python")]
#[pyfunction(name = "trix")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn trix_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    // Python binding for single-period TRIX over a contiguous float64 array.
    // The computation itself runs with the GIL released; any TrixError is
    // surfaced to Python as ValueError.
    use numpy::{IntoPyArray, PyArrayMethods};

    let series = data.as_slice()?;
    let chosen_kernel = validate_kernel(kernel, false)?;
    let input = TrixInput::from_slice(series, TrixParams { period: Some(period) });

    let computed = py.allow_threads(|| trix_with_kernel(&input, chosen_kernel));
    match computed {
        Ok(output) => Ok(output.values.into_pyarray(py)),
        Err(err) => Err(PyValueError::new_err(err.to_string())),
    }
}
1098
// Python-facing wrapper around the native `TrixStream`; registered with
// Python under the class name `TrixStream`.
#[cfg(feature = "python")]
#[pyclass(name = "TrixStream")]
pub struct TrixStreamPy {
    // Native streaming state; every Python call is delegated to it.
    stream: TrixStream,
}
1104
#[cfg(feature = "python")]
#[pymethods]
impl TrixStreamPy {
    // Constructs the stream for `period`; construction failures from
    // `TrixStream::try_new` are raised as Python ValueError.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        TrixStream::try_new(TrixParams { period: Some(period) })
            .map(|stream| Self { stream })
            .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    // Feeds one sample into the stream. The meaning of `None` (e.g. still
    // warming up) is whatever `TrixStream::update` defines.
    fn update(&mut self, value: f64) -> Option<f64> {
        let Self { stream } = self;
        stream.update(value)
    }
}
1122
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "trix_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, device_id=0))]
pub fn trix_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    data_f32: numpy::PyReadonlyArray1<'py, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<(DeviceArrayF32Py, Bound<'py, PyDict>)> {
    // Runs a TRIX period sweep on the GPU `device_id` and returns the device
    // array handle together with a dict whose "periods" entry lists the swept
    // periods. GPU work runs with the GIL released.
    use crate::cuda::cuda_available;
    use numpy::IntoPyArray;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let slice_in = data_f32.as_slice()?;
    let sweep = TrixBatchRange {
        period: period_range,
    };

    let inner = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaTrix::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let arr = cuda
            .trix_batch_dev(slice_in, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        cuda.synchronize()
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok(arr)
    })?;

    let dict = PyDict::new(py);
    let (start, end, step) = period_range;
    let mut periods: Vec<u64> = Vec::new();
    if step == 0 {
        periods.push(start as u64);
    } else {
        let mut p = start;
        while p <= end {
            periods.push(p as u64);
            // `checked_add` rather than `saturating_add`: saturating would pin
            // `p` at usize::MAX and, with `end == usize::MAX`, loop forever.
            match p.checked_add(step) {
                Some(next) => p = next,
                None => break,
            }
        }
    }
    dict.set_item("periods", periods.into_pyarray(py))?;

    let handle = make_device_array_py(device_id, inner)?;
    Ok((handle, dict))
}
1171
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "trix_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, device_id=0))]
pub fn trix_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    // GPU TRIX for many series sharing one period. The 2-D input is passed to
    // the CUDA wrapper as a flat buffer plus (cols, rows); the wrapper's name
    // suggests a time-major layout (rows = time steps) — confirm against CudaTrix.
    use crate::cuda::cuda_available;
    use numpy::PyUntypedArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let flat_in = data_tm_f32.as_slice()?;
    let (rows, cols) = {
        let shape = data_tm_f32.shape();
        (shape[0], shape[1])
    };

    let device_arr = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaTrix::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let computed = cuda
            .trix_many_series_one_param_time_major_dev(flat_in, cols, rows, period)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        cuda.synchronize()
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok(computed)
    })?;

    Ok(make_device_array_py(device_id, device_arr)?)
}
1205
#[cfg(feature = "python")]
#[pyfunction(name = "trix_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn trix_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    // Python binding for a TRIX period sweep. Returns a dict with
    // "values" (rows x len(data) matrix, one row per period) and "periods".
    // The output is written straight into a freshly allocated NumPy array.
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, true)?;

    let sweep = TrixBatchRange {
        period: period_range,
    };
    let combos_probe = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let rows = combos_probe.len();
    let cols = slice_in.len();

    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("trix_batch_py: rows*cols overflow"))?;
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let first = slice_in
        .iter()
        .position(|x| !x.is_nan())
        .ok_or_else(|| PyValueError::new_err("AllValuesNaN"))?;
    for (r, prm) in combos_probe.iter().enumerate() {
        // Saturating arithmetic: periods come straight from Python and are only
        // validated downstream (expand_grid / trix_batch_inner_into). A zero
        // or huge period must not underflow/overflow (and panic) here first.
        let period = prm.period.unwrap();
        let warm = first
            .saturating_add(3usize.saturating_mul(period.saturating_sub(1)))
            .saturating_add(1);
        let start = r * cols;
        slice_out[start..start + warm.min(cols)].fill(f64::NAN);
    }

    let combos = py
        .allow_threads(|| {
            let kernel = match kern {
                Kernel::Auto => Kernel::ScalarBatch,
                k => k,
            };
            let simd = match kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                _ => unreachable!(),
            };
            trix_batch_inner_into(slice_in, &sweep, simd, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    Ok(dict)
}
1275
1276#[inline(always)]
1277fn trix_batch_inner_into(
1278    data: &[f64],
1279    sweep: &TrixBatchRange,
1280    kern: Kernel,
1281    parallel: bool,
1282    out: &mut [f64],
1283) -> Result<Vec<TrixParams>, TrixError> {
1284    let combos = expand_grid(sweep)?;
1285    let first = data
1286        .iter()
1287        .position(|x| !x.is_nan())
1288        .ok_or(TrixError::AllValuesNaN)?;
1289    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1290    let needed = trix_needed_len(max_p)?;
1291    if data.len() - first < needed {
1292        return Err(TrixError::NotEnoughValidData {
1293            needed,
1294            valid: data.len() - first,
1295        });
1296    }
1297    let rows = combos.len();
1298    let cols = data.len();
1299
1300    let mut logs: AVec<f64> = AVec::with_capacity(CACHELINE_ALIGN, cols);
1301    unsafe { logs.set_len(cols) };
1302    for i in 0..first {
1303        logs[i] = 0.0;
1304    }
1305    for i in first..cols {
1306        logs[i] = data[i].ln();
1307    }
1308
1309    let do_row = |row: usize, out_row: &mut [f64]| unsafe {
1310        let period = combos[row].period.unwrap();
1311        trix_row_scalar_with_logs(&logs, first, period, out_row)
1312    };
1313
1314    if parallel {
1315        #[cfg(not(target_arch = "wasm32"))]
1316        {
1317            out.par_chunks_mut(cols)
1318                .enumerate()
1319                .for_each(|(row, slice)| do_row(row, slice));
1320        }
1321
1322        #[cfg(target_arch = "wasm32")]
1323        {
1324            for (row, slice) in out.chunks_mut(cols).enumerate() {
1325                do_row(row, slice);
1326            }
1327        }
1328    } else {
1329        for (row, slice) in out.chunks_mut(cols).enumerate() {
1330            do_row(row, slice);
1331        }
1332    }
1333    Ok(combos)
1334}
1335
1336#[inline(always)]
1337pub fn trix_into_slice(dst: &mut [f64], input: &TrixInput, kern: Kernel) -> Result<(), TrixError> {
1338    let (data, period, first, chosen, alpha, warmup_end) = trix_prepare(input, kern)?;
1339    if dst.len() != data.len() {
1340        return Err(TrixError::OutputLengthMismatch {
1341            expected: data.len(),
1342            got: dst.len(),
1343        });
1344    }
1345
1346    let warmup_len = warmup_end.min(dst.len());
1347    for v in &mut dst[..warmup_len] {
1348        *v = f64::NAN;
1349    }
1350    unsafe {
1351        match chosen {
1352            Kernel::Scalar | Kernel::ScalarBatch => {
1353                trix_compute_into_scalar(data, period, first, alpha, dst)
1354            }
1355            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1356            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
1357                trix_compute_into_scalar(data, period, first, alpha, dst)
1358            }
1359            #[allow(unreachable_patterns)]
1360            _ => trix_compute_into_scalar(data, period, first, alpha, dst),
1361        }
1362    }
1363    Ok(())
1364}
1365
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn trix_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    // JS entry point: single-period TRIX using the scalar kernel, returned in a
    // freshly allocated vector. Errors become JS string values.
    let input = TrixInput::from_slice(data, TrixParams { period: Some(period) });
    let mut values = vec![f64::NAN; data.len()];
    match trix_into_slice(&mut values, &input, Kernel::Scalar) {
        Ok(()) => Ok(values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1378
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn trix_alloc(len: usize) -> *mut f64 {
    // Leaks an uninitialised buffer with room for `len` f64 values to the JS
    // side. Release it with `trix_free(ptr, len)` using the same `len`.
    let mut buf = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buf.as_mut_ptr()
}
1387
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn trix_free(ptr: *mut f64, len: usize) {
    // Reclaims a buffer previously handed out by `trix_alloc(len)`.
    // A null pointer is ignored.
    if ptr.is_null() {
        return;
    }
    // SAFETY: the caller must pass a pointer obtained from `trix_alloc` with the
    // same `len` (used here as both length and capacity) and free it only once.
    drop(unsafe { Vec::from_raw_parts(ptr, len, len) });
}
1397
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn trix_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    // Pointer-based JS entry point: reads `len` inputs from `in_ptr` and writes
    // `len` TRIX values to `out_ptr`. The same pointer may be passed for both.
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }
    let to_js = |e: TrixError| JsValue::from_str(&e.to_string());

    // SAFETY: the JS caller guarantees both pointers address `len` valid f64s.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let input = TrixInput::from_slice(data, TrixParams { period: Some(period) });
        if in_ptr != out_ptr {
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            trix_into_slice(out, &input, Kernel::Scalar).map_err(to_js)?;
        } else {
            // In-place request: compute into scratch space first so the input
            // is not overwritten while it is still being read.
            let mut scratch = vec![f64::NAN; len];
            trix_into_slice(&mut scratch, &input, Kernel::Scalar).map_err(to_js)?;
            std::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&scratch);
        }
    }
    Ok(())
}
1428
/// Batch configuration accepted from JS by `trix_batch`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct TrixBatchConfig {
    /// Period sweep as `(start, end, step)`, forwarded unchanged to `TrixBatchRange`.
    pub period_range: (usize, usize, usize),
}
1434
/// Serialized result of `trix_batch` returned to JS.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct TrixBatchJsOutput {
    /// Flattened TRIX values taken from the batch output's `values`.
    pub values: Vec<f64>,
    /// Period of each combination, in the batch output's combo order.
    pub periods: Vec<usize>,
    /// Row count reported by the batch computation.
    pub rows: usize,
    /// Column count reported by the batch computation.
    pub cols: usize,
}
1443
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = trix_batch)]
pub fn trix_batch_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    // JS entry point for a TRIX period sweep (scalar kernel, serial). The
    // config object is deserialized into `TrixBatchConfig`; the result is
    // serialized from `TrixBatchJsOutput`.
    let config: TrixBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let batch = trix_batch_inner(
        data,
        &TrixBatchRange {
            period: config.period_range,
        },
        Kernel::Scalar,
        false,
    )
    .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let js_output = TrixBatchJsOutput {
        periods: batch.combos.iter().map(|p| p.period.unwrap()).collect(),
        rows: batch.rows,
        cols: batch.cols,
        values: batch.values,
    };

    serde_wasm_bindgen::to_value(&js_output)
        .map_err(|e| JsValue::from_str(&format!("Failed to serialize output: {}", e)))
}
1469
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn trix_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    // Pointer-based JS entry point for a TRIX period sweep. Writes a row-major
    // `rows x len` matrix (one row per period) to `out_ptr` and returns `rows`.
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);

        let sweep = TrixBatchRange {
            period: (period_start, period_end, period_step),
        };

        let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
        let num_combos = combos.len();
        let total_size = num_combos
            .checked_mul(len)
            .ok_or_else(|| JsValue::from_str("trix_batch_into: rows*cols overflow"))?;

        // Locate the first valid sample before any output is written, so this
        // still works if the caller aliases the input and output buffers.
        let first = data.iter().position(|x| !x.is_nan());

        let out = std::slice::from_raw_parts_mut(out_ptr, total_size);

        trix_batch_inner_into(data, &sweep, Kernel::Scalar, false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

        // The output buffer may come from `trix_alloc`, which is uninitialised.
        // Mirror `trix_batch_py` and NaN-fill each row's warmup prefix so stale
        // memory never reaches JS. Saturating math keeps odd periods panic-free.
        if let Some(first) = first {
            for (row, prm) in combos.iter().enumerate() {
                let period = prm.period.unwrap();
                let warm = first
                    .saturating_add(3usize.saturating_mul(period.saturating_sub(1)))
                    .saturating_add(1);
                let start = row * len;
                out[start..start + warm.min(len)].fill(f64::NAN);
            }
        }

        Ok(num_combos)
    }
}
1505
#[cfg(feature = "python")]
pub fn register_trix_module(m: &Bound<'_, pyo3::types::PyModule>) -> PyResult<()> {
    // Registers the CPU TRIX bindings (single, batch, and streaming class) on `m`.
    let single = wrap_pyfunction!(trix_py, m)?;
    m.add_function(single)?;
    let batch = wrap_pyfunction!(trix_batch_py, m)?;
    m.add_function(batch)?;
    m.add_class::<TrixStreamPy>()?;
    Ok(())
}
1513
#[cfg(test)]
mod tests_into {
    use super::*;

    /// True when both values are NaN, or when they compare equal.
    fn eq_or_both_nan(a: f64, b: f64) -> bool {
        a == b || (a.is_nan() && b.is_nan())
    }

    /// The into-buffer API must reproduce the allocating `trix` API on real data.
    #[test]
    fn test_trix_into_matches_api() -> Result<(), Box<dyn std::error::Error>> {
        let candles = crate::utilities::data_loader::read_candles_from_csv(
            "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv",
        )?;
        let input = TrixInput::from_candles(&candles, "close", TrixParams::default());

        let baseline = trix(&input)?.values;
        let mut out = vec![0.0; baseline.len()];

        #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
        {
            trix_into(&input, &mut out)?;
        }

        assert_eq!(baseline.len(), out.len());
        for (i, (&a, &b)) in baseline.iter().zip(out.iter()).enumerate() {
            if a.is_nan() || b.is_nan() {
                assert!(
                    eq_or_both_nan(a, b),
                    "NaN mismatch at index {}: {:?} vs {:?}",
                    i,
                    a,
                    b
                );
                continue;
            }
            let diff = (a - b).abs();
            assert!(
                diff <= 1e-12,
                "Value mismatch at {}: a={} b={} diff={}",
                i,
                a,
                b,
                diff
            );
        }
        Ok(())
    }
}
1566
1567#[cfg(test)]
1568mod tests {
1569    use super::*;
1570    use crate::skip_if_unsupported;
1571    use crate::utilities::data_loader::read_candles_from_csv;
1572
1573    fn check_trix_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1574        skip_if_unsupported!(kernel, test_name);
1575        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1576        let candles = read_candles_from_csv(file_path)?;
1577        let default_params = TrixParams { period: None };
1578        let input_default = TrixInput::from_candles(&candles, "close", default_params);
1579        let output_default = trix_with_kernel(&input_default, kernel)?;
1580        assert_eq!(output_default.values.len(), candles.close.len());
1581        let params_period_14 = TrixParams { period: Some(14) };
1582        let input_period_14 = TrixInput::from_candles(&candles, "hl2", params_period_14);
1583        let output_period_14 = trix_with_kernel(&input_period_14, kernel)?;
1584        assert_eq!(output_period_14.values.len(), candles.close.len());
1585        let params_custom = TrixParams { period: Some(20) };
1586        let input_custom = TrixInput::from_candles(&candles, "hlc3", params_custom);
1587        let output_custom = trix_with_kernel(&input_custom, kernel)?;
1588        assert_eq!(output_custom.values.len(), candles.close.len());
1589        Ok(())
1590    }
1591
1592    fn check_trix_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1593        skip_if_unsupported!(kernel, test_name);
1594        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1595        let candles = read_candles_from_csv(file_path)?;
1596        let close_prices = candles.select_candle_field("close")?;
1597        let params = TrixParams { period: Some(18) };
1598        let input = TrixInput::from_candles(&candles, "close", params);
1599        let trix_result = trix_with_kernel(&input, kernel)?;
1600        assert_eq!(
1601            trix_result.values.len(),
1602            close_prices.len(),
1603            "TRIX length mismatch"
1604        );
1605        let expected_last_five = [
1606            -16.03736447,
1607            -15.92084231,
1608            -15.76171478,
1609            -15.53571033,
1610            -15.34967155,
1611        ];
1612        assert!(trix_result.values.len() >= 5, "TRIX length too short");
1613        let start_index = trix_result.values.len() - 5;
1614        let result_last_five = &trix_result.values[start_index..];
1615        for (i, &value) in result_last_five.iter().enumerate() {
1616            let expected_value = expected_last_five[i];
1617
1618            let tolerance = 0.3;
1619            assert!(
1620                (value - expected_value).abs() < tolerance,
1621                "TRIX mismatch at index {}: expected {}, got {}, diff={}",
1622                i,
1623                expected_value,
1624                value,
1625                (value - expected_value).abs()
1626            );
1627        }
1628        Ok(())
1629    }
1630
1631    fn check_trix_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1632        skip_if_unsupported!(kernel, test_name);
1633        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1634        let candles = read_candles_from_csv(file_path)?;
1635        let input = TrixInput::with_default_candles(&candles);
1636        match input.data {
1637            TrixData::Candles { source, .. } => assert_eq!(source, "close"),
1638            _ => panic!("Expected TrixData::Candles"),
1639        }
1640        Ok(())
1641    }
1642
1643    fn check_trix_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1644        skip_if_unsupported!(kernel, test_name);
1645        let input_data = [10.0, 20.0, 30.0];
1646        let params = TrixParams { period: Some(0) };
1647        let input = TrixInput::from_slice(&input_data, params);
1648        let res = trix_with_kernel(&input, kernel);
1649        assert!(
1650            res.is_err(),
1651            "[{}] TRIX should fail with zero period",
1652            test_name
1653        );
1654        Ok(())
1655    }
1656
1657    fn check_trix_period_exceeds_length(
1658        test_name: &str,
1659        kernel: Kernel,
1660    ) -> Result<(), Box<dyn Error>> {
1661        skip_if_unsupported!(kernel, test_name);
1662        let data_small = [10.0, 20.0, 30.0];
1663        let params = TrixParams { period: Some(10) };
1664        let input = TrixInput::from_slice(&data_small, params);
1665        let res = trix_with_kernel(&input, kernel);
1666        assert!(
1667            res.is_err(),
1668            "[{}] TRIX should fail with period exceeding length",
1669            test_name
1670        );
1671        Ok(())
1672    }
1673
1674    fn check_trix_very_small_dataset(
1675        test_name: &str,
1676        kernel: Kernel,
1677    ) -> Result<(), Box<dyn Error>> {
1678        skip_if_unsupported!(kernel, test_name);
1679        let single_point = [42.0];
1680        let params = TrixParams { period: Some(18) };
1681        let input = TrixInput::from_slice(&single_point, params);
1682        let res = trix_with_kernel(&input, kernel);
1683        assert!(
1684            res.is_err(),
1685            "[{}] TRIX should fail with insufficient data",
1686            test_name
1687        );
1688        Ok(())
1689    }
1690
1691    fn check_trix_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1692        skip_if_unsupported!(kernel, test_name);
1693        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1694        let candles = read_candles_from_csv(file_path)?;
1695        let params = TrixParams { period: Some(10) };
1696        let input = TrixInput::from_candles(&candles, "close", params);
1697        let first_result = trix_with_kernel(&input, kernel)?;
1698        let second_input =
1699            TrixInput::from_slice(&first_result.values, TrixParams { period: Some(10) });
1700        let second_result = trix_with_kernel(&second_input, kernel)?;
1701        assert_eq!(first_result.values.len(), second_result.values.len());
1702        Ok(())
1703    }
1704
1705    #[cfg(debug_assertions)]
1706    fn check_trix_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1707        skip_if_unsupported!(kernel, test_name);
1708
1709        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1710        let candles = read_candles_from_csv(file_path)?;
1711
1712        let test_params = vec![
1713            TrixParams::default(),
1714            TrixParams { period: Some(2) },
1715            TrixParams { period: Some(5) },
1716            TrixParams { period: Some(10) },
1717            TrixParams { period: Some(14) },
1718            TrixParams { period: Some(20) },
1719            TrixParams { period: Some(30) },
1720            TrixParams { period: Some(50) },
1721            TrixParams { period: Some(100) },
1722            TrixParams { period: Some(200) },
1723        ];
1724
1725        for (param_idx, params) in test_params.iter().enumerate() {
1726            let input = TrixInput::from_candles(&candles, "close", params.clone());
1727            let output = trix_with_kernel(&input, kernel)?;
1728
1729            for (i, &val) in output.values.iter().enumerate() {
1730                if val.is_nan() {
1731                    continue;
1732                }
1733
1734                let bits = val.to_bits();
1735
1736                if bits == 0x11111111_11111111 {
1737                    panic!(
1738                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1739						 with params: period={} (param set {})",
1740                        test_name,
1741                        val,
1742                        bits,
1743                        i,
1744                        params.period.unwrap_or(18),
1745                        param_idx
1746                    );
1747                }
1748
1749                if bits == 0x22222222_22222222 {
1750                    panic!(
1751                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1752						 with params: period={} (param set {})",
1753                        test_name,
1754                        val,
1755                        bits,
1756                        i,
1757                        params.period.unwrap_or(18),
1758                        param_idx
1759                    );
1760                }
1761
1762                if bits == 0x33333333_33333333 {
1763                    panic!(
1764                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1765						 with params: period={} (param set {})",
1766                        test_name,
1767                        val,
1768                        bits,
1769                        i,
1770                        params.period.unwrap_or(18),
1771                        param_idx
1772                    );
1773                }
1774            }
1775        }
1776
1777        Ok(())
1778    }
1779
1780    #[cfg(not(debug_assertions))]
1781    fn check_trix_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1782        Ok(())
1783    }
1784
1785    macro_rules! generate_all_trix_tests {
1786        ($($test_fn:ident),*) => {
1787            paste::paste! {
1788                $(
1789                    #[test]
1790                    fn [<$test_fn _scalar_f64>]() {
1791                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1792                    }
1793                )*
1794                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1795                $(
1796                    #[test]
1797                    fn [<$test_fn _avx2_f64>]() {
1798                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1799                    }
1800                    #[test]
1801                    fn [<$test_fn _avx512_f64>]() {
1802                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1803                    }
1804                )*
1805            }
1806        }
1807    }
1808    generate_all_trix_tests!(
1809        check_trix_partial_params,
1810        check_trix_accuracy,
1811        check_trix_default_candles,
1812        check_trix_zero_period,
1813        check_trix_period_exceeds_length,
1814        check_trix_very_small_dataset,
1815        check_trix_reinput,
1816        check_trix_no_poison
1817    );
1818
1819    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1820        skip_if_unsupported!(kernel, test);
1821        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1822        let c = read_candles_from_csv(file)?;
1823        let output = TrixBatchBuilder::new()
1824            .kernel(kernel)
1825            .apply_candles(&c, "close")?;
1826        let def = TrixParams::default();
1827        let row = output.values_for(&def).expect("default row missing");
1828        assert_eq!(row.len(), c.close.len());
1829        let expected = [
1830            -16.03736447,
1831            -15.92084231,
1832            -15.76171478,
1833            -15.53571033,
1834            -15.34967155,
1835        ];
1836        let start = row.len() - 5;
1837        for (i, &v) in row[start..].iter().enumerate() {
1838            let tolerance = 0.3;
1839            assert!(
1840                (v - expected[i]).abs() < tolerance,
1841                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}, diff={}",
1842                (v - expected[i]).abs()
1843            );
1844        }
1845        Ok(())
1846    }
1847
1848    #[cfg(debug_assertions)]
1849    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1850        skip_if_unsupported!(kernel, test);
1851
1852        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1853        let c = read_candles_from_csv(file)?;
1854
1855        let test_configs = vec![
1856            (2, 10, 2),
1857            (10, 30, 5),
1858            (30, 100, 10),
1859            (2, 5, 1),
1860            (18, 18, 0),
1861            (5, 25, 5),
1862            (50, 100, 25),
1863            (14, 28, 7),
1864        ];
1865
1866        for (cfg_idx, &(period_start, period_end, period_step)) in test_configs.iter().enumerate() {
1867            let output = TrixBatchBuilder::new()
1868                .kernel(kernel)
1869                .period_range(period_start, period_end, period_step)
1870                .apply_candles(&c, "close")?;
1871
1872            for (idx, &val) in output.values.iter().enumerate() {
1873                if val.is_nan() {
1874                    continue;
1875                }
1876
1877                let bits = val.to_bits();
1878                let row = idx / output.cols;
1879                let col = idx % output.cols;
1880                let combo = &output.combos[row];
1881
1882                if bits == 0x11111111_11111111 {
1883                    panic!(
1884                        "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
1885						 at row {} col {} (flat index {}) with params: period={}",
1886                        test,
1887                        cfg_idx,
1888                        val,
1889                        bits,
1890                        row,
1891                        col,
1892                        idx,
1893                        combo.period.unwrap_or(18)
1894                    );
1895                }
1896
1897                if bits == 0x22222222_22222222 {
1898                    panic!(
1899                        "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
1900						 at row {} col {} (flat index {}) with params: period={}",
1901                        test,
1902                        cfg_idx,
1903                        val,
1904                        bits,
1905                        row,
1906                        col,
1907                        idx,
1908                        combo.period.unwrap_or(18)
1909                    );
1910                }
1911
1912                if bits == 0x33333333_33333333 {
1913                    panic!(
1914                        "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
1915						 at row {} col {} (flat index {}) with params: period={}",
1916                        test,
1917                        cfg_idx,
1918                        val,
1919                        bits,
1920                        row,
1921                        col,
1922                        idx,
1923                        combo.period.unwrap_or(18)
1924                    );
1925                }
1926            }
1927        }
1928
1929        Ok(())
1930    }
1931
1932    #[cfg(not(debug_assertions))]
1933    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1934        Ok(())
1935    }
1936
1937    macro_rules! gen_batch_tests {
1938        ($fn_name:ident) => {
1939            paste::paste! {
1940                #[test] fn [<$fn_name _scalar>]()      {
1941                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
1942                }
1943                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1944                #[test] fn [<$fn_name _avx2>]()        {
1945                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
1946                }
1947                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1948                #[test] fn [<$fn_name _avx512>]()      {
1949                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
1950                }
1951                #[test] fn [<$fn_name _auto_detect>]() {
1952                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
1953                }
1954            }
1955        };
1956    }
1957    gen_batch_tests!(check_batch_default_row);
1958    gen_batch_tests!(check_batch_no_poison);
1959
1960    #[cfg(feature = "proptest")]
1961    #[allow(clippy::float_cmp)]
1962    fn check_trix_property(
1963        test_name: &str,
1964        kernel: Kernel,
1965    ) -> Result<(), Box<dyn std::error::Error>> {
1966        use proptest::prelude::*;
1967        skip_if_unsupported!(kernel, test_name);
1968
1969        let strat = (2usize..=20).prop_flat_map(|period| {
1970            let min_data_needed = 3 * (period - 1) + 1 + 10;
1971            (
1972                prop::collection::vec(
1973                    (0.001f64..1e6f64)
1974                        .prop_filter("positive finite", |x| x.is_finite() && *x > 0.0),
1975                    min_data_needed..400,
1976                ),
1977                Just(period),
1978            )
1979        });
1980
1981        proptest::test_runner::TestRunner::default()
1982            .run(&strat, |(data, period)| {
1983                let params = TrixParams {
1984                    period: Some(period),
1985                };
1986                let input = TrixInput::from_slice(&data, params);
1987
1988                let TrixOutput { values: out } = trix_with_kernel(&input, kernel).unwrap();
1989
1990                let TrixOutput { values: ref_out } =
1991                    trix_with_kernel(&input, Kernel::Scalar).unwrap();
1992
1993                let warmup_period = 3 * (period - 1) + 1;
1994                for i in 0..warmup_period.min(data.len()) {
1995                    prop_assert!(
1996                        out[i].is_nan(),
1997                        "Expected NaN during warmup at index {}, got {}",
1998                        i,
1999                        out[i]
2000                    );
2001                }
2002
2003                if data.len() > warmup_period {
2004                    prop_assert!(
2005                        !out[warmup_period].is_nan(),
2006                        "Expected valid value at index {} (after warmup), got NaN",
2007                        warmup_period
2008                    );
2009                }
2010
2011                if data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10)
2012                    && data.len() > warmup_period
2013                {
2014                    for i in warmup_period..data.len() {
2015                        prop_assert!(
2016                            out[i].abs() < 1e-6,
2017                            "TRIX should be near zero for constant data at index {}: got {}",
2018                            i,
2019                            out[i]
2020                        );
2021                    }
2022                }
2023
2024                let increasing_count = data.windows(2).filter(|w| w[1] > w[0]).count();
2025                let is_mostly_increasing = increasing_count as f64 > data.len() as f64 * 0.8;
2026                if is_mostly_increasing && data.len() > warmup_period + 10 {
2027                    let last_values: Vec<f64> = out[(data.len() - 5)..data.len()]
2028                        .iter()
2029                        .filter(|&&v| !v.is_nan())
2030                        .copied()
2031                        .collect();
2032                    if !last_values.is_empty() {
2033                        let avg = last_values.iter().sum::<f64>() / last_values.len() as f64;
2034                        prop_assert!(
2035                            avg > -10.0,
2036                            "TRIX average should be positive for mostly increasing data: got {}",
2037                            avg
2038                        );
2039                    }
2040                }
2041
2042                let decreasing_count = data.windows(2).filter(|w| w[1] < w[0]).count();
2043                let is_mostly_decreasing = decreasing_count as f64 > data.len() as f64 * 0.8;
2044                if is_mostly_decreasing && data.len() > warmup_period + 10 {
2045                    let last_values: Vec<f64> = out[(data.len() - 5)..data.len()]
2046                        .iter()
2047                        .filter(|&&v| !v.is_nan())
2048                        .copied()
2049                        .collect();
2050                    if !last_values.is_empty() {
2051                        let avg = last_values.iter().sum::<f64>() / last_values.len() as f64;
2052                        prop_assert!(
2053                            avg < 10.0,
2054                            "TRIX average should be negative for mostly decreasing data: got {}",
2055                            avg
2056                        );
2057                    }
2058                }
2059
2060                for i in warmup_period..data.len() {
2061                    if !out[i].is_nan() {
2062                        prop_assert!(
2063                            out[i].abs() < 100000.0,
2064                            "TRIX value too large at index {}: {}",
2065                            i,
2066                            out[i]
2067                        );
2068                    }
2069                }
2070
2071                for (i, &val) in out.iter().enumerate() {
2072                    prop_assert!(
2073                        val.is_nan() || val.is_finite(),
2074                        "TRIX should not produce infinite values at index {}: got {}",
2075                        i,
2076                        val
2077                    );
2078                }
2079
2080                if data.len() > warmup_period + 20 {
2081                    let log_returns: Vec<f64> = data
2082                        .windows(2)
2083                        .skip(warmup_period)
2084                        .map(|w| (w[1] / w[0]).ln() * 10000.0)
2085                        .collect();
2086
2087                    let trix_values: Vec<f64> = out
2088                        .iter()
2089                        .skip(warmup_period + 1)
2090                        .filter(|&&v| !v.is_nan())
2091                        .copied()
2092                        .collect();
2093
2094                    if log_returns.len() > 10 && trix_values.len() > 10 {
2095                        let log_std = calculate_std(&log_returns);
2096                        let trix_std = calculate_std(&trix_values);
2097
2098                        prop_assert!(
2099							trix_std <= log_std * 1.2 || trix_std < 1.0,
2100							"TRIX should be smoother than log returns: TRIX std={}, log return std={}",
2101							trix_std,
2102							log_std
2103						);
2104                    }
2105                }
2106
2107                for i in warmup_period..data.len() {
2108                    let y = out[i];
2109                    let r = ref_out[i];
2110
2111                    if !y.is_finite() || !r.is_finite() {
2112                        prop_assert!(
2113                            y.to_bits() == r.to_bits(),
2114                            "finite/NaN mismatch at index {}: {} vs {}",
2115                            i,
2116                            y,
2117                            r
2118                        );
2119                        continue;
2120                    }
2121
2122                    let y_bits = y.to_bits();
2123                    let r_bits = r.to_bits();
2124                    let ulp_diff: u64 = y_bits.abs_diff(r_bits);
2125
2126                    prop_assert!(
2127                        (y - r).abs() <= 1e-9 || ulp_diff <= 4,
2128                        "Kernel mismatch at index {}: {} vs {} (ULP={})",
2129                        i,
2130                        y,
2131                        r,
2132                        ulp_diff
2133                    );
2134                }
2135
2136                let TrixOutput { values: out2 } = trix_with_kernel(&input, kernel).unwrap();
2137                prop_assert_eq!(
2138                    out.len(),
2139                    out2.len(),
2140                    "Output length mismatch on second run"
2141                );
2142                for i in 0..out.len() {
2143                    prop_assert!(
2144                        out[i].to_bits() == out2[i].to_bits(),
2145                        "Determinism failed at index {}: {} vs {}",
2146                        i,
2147                        out[i],
2148                        out2[i]
2149                    );
2150                }
2151
2152                Ok(())
2153            })
2154            .unwrap();
2155
2156        let edge_data = vec![0.001, 0.01, 0.1, 1.0, 10.0, 100.0, 1000.0, 10000.0];
2157        let params = TrixParams { period: Some(2) };
2158        let input = TrixInput::from_slice(&edge_data, params);
2159        let result = trix_with_kernel(&input, kernel);
2160        assert!(
2161            result.is_ok(),
2162            "TRIX should handle very small positive values"
2163        );
2164
2165        Ok(())
2166    }
2167
2168    fn calculate_std(values: &[f64]) -> f64 {
2169        let mean = values.iter().sum::<f64>() / values.len() as f64;
2170        let variance =
2171            values.iter().map(|&v| (v - mean).powi(2)).sum::<f64>() / values.len() as f64;
2172        variance.sqrt()
2173    }
2174
    // Instantiate the test functions for the property check (presumably one per kernel; the
    // macro is defined elsewhere in this file). It is gated on the same `proptest` feature as
    // `check_trix_property` itself.
    #[cfg(feature = "proptest")]
    generate_all_trix_tests!(check_trix_property);
2177}