1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{alloc_with_nan_prefix, init_matrix_prefixes, make_uninit_matrix};
4use aligned_vec::{AVec, CACHELINE_ALIGN};
5#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
6use core::arch::x86_64::*;
7#[cfg(not(target_arch = "wasm32"))]
8use rayon::prelude::*;
9use std::convert::AsRef;
10use std::error::Error;
11use thiserror::Error;
12
13#[cfg(all(feature = "python", feature = "cuda"))]
14use crate::cuda::moving_averages::CudaTrix;
15#[cfg(all(feature = "python", feature = "cuda"))]
16use crate::utilities::dlpack_cuda::{make_device_array_py, DeviceArrayF32Py};
17#[cfg(feature = "python")]
18use crate::utilities::kernel_validation::validate_kernel;
19#[cfg(feature = "python")]
20use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
21#[cfg(feature = "python")]
22use pyo3::exceptions::PyValueError;
23#[cfg(feature = "python")]
24use pyo3::prelude::*;
25#[cfg(feature = "python")]
26use pyo3::types::PyDict;
27
28#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
29use serde::{Deserialize, Serialize};
30#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
31use wasm_bindgen::prelude::*;
32
33impl<'a> AsRef<[f64]> for TrixInput<'a> {
34 #[inline(always)]
35 fn as_ref(&self) -> &[f64] {
36 match &self.data {
37 TrixData::Slice(slice) => slice,
38 TrixData::Candles { candles, source } => source_type(candles, source),
39 }
40 }
41}
42
43#[derive(Debug, Clone)]
44pub enum TrixData<'a> {
45 Candles {
46 candles: &'a Candles,
47 source: &'a str,
48 },
49 Slice(&'a [f64]),
50}
51
/// Result of a single-series TRIX computation.
#[derive(Clone, Debug)]
pub struct TrixOutput {
    /// Output series; `trix_with_kernel` sizes it to the input length.
    pub values: Vec<f64>,
}
56
/// Tunable parameters of the TRIX indicator.
#[derive(Clone, Debug)]
pub struct TrixParams {
    /// EMA period; consumers in this file fall back to 18 when it is `None`.
    pub period: Option<usize>,
}
61
62impl Default for TrixParams {
63 fn default() -> Self {
64 Self { period: Some(18) }
65 }
66}
67
68#[derive(Debug, Clone)]
69pub struct TrixInput<'a> {
70 pub data: TrixData<'a>,
71 pub params: TrixParams,
72}
73
74impl<'a> TrixInput<'a> {
75 #[inline]
76 pub fn from_candles(c: &'a Candles, s: &'a str, p: TrixParams) -> Self {
77 Self {
78 data: TrixData::Candles {
79 candles: c,
80 source: s,
81 },
82 params: p,
83 }
84 }
85 #[inline]
86 pub fn from_slice(sl: &'a [f64], p: TrixParams) -> Self {
87 Self {
88 data: TrixData::Slice(sl),
89 params: p,
90 }
91 }
92 #[inline]
93 pub fn with_default_candles(c: &'a Candles) -> Self {
94 Self::from_candles(c, "close", TrixParams::default())
95 }
96 #[inline]
97 pub fn get_period(&self) -> usize {
98 self.params.period.unwrap_or(18)
99 }
100}
101
102#[derive(Copy, Clone, Debug)]
103pub struct TrixBuilder {
104 period: Option<usize>,
105 kernel: Kernel,
106}
107
108impl Default for TrixBuilder {
109 fn default() -> Self {
110 Self {
111 period: None,
112 kernel: Kernel::Auto,
113 }
114 }
115}
116
117impl TrixBuilder {
118 #[inline(always)]
119 pub fn new() -> Self {
120 Self::default()
121 }
122 #[inline(always)]
123 pub fn period(mut self, n: usize) -> Self {
124 self.period = Some(n);
125 self
126 }
127 #[inline(always)]
128 pub fn kernel(mut self, k: Kernel) -> Self {
129 self.kernel = k;
130 self
131 }
132 #[inline(always)]
133 pub fn apply(self, c: &Candles) -> Result<TrixOutput, TrixError> {
134 let p = TrixParams {
135 period: self.period,
136 };
137 let i = TrixInput::from_candles(c, "close", p);
138 trix_with_kernel(&i, self.kernel)
139 }
140 #[inline(always)]
141 pub fn apply_slice(self, d: &[f64]) -> Result<TrixOutput, TrixError> {
142 let p = TrixParams {
143 period: self.period,
144 };
145 let i = TrixInput::from_slice(d, p);
146 trix_with_kernel(&i, self.kernel)
147 }
148 #[inline(always)]
149 pub fn into_stream(self) -> Result<TrixStream, TrixError> {
150 let p = TrixParams {
151 period: self.period,
152 };
153 TrixStream::try_new(p)
154 }
155}
156
157#[derive(Debug, Error)]
158pub enum TrixError {
159 #[error("trix: Empty data provided.")]
160 EmptyInputData,
161 #[error("trix: Invalid period: period = {period}, data length = {data_len}")]
162 InvalidPeriod { period: usize, data_len: usize },
163 #[error("trix: Not enough valid data: needed = {needed}, valid = {valid}")]
164 NotEnoughValidData { needed: usize, valid: usize },
165 #[error("trix: All values are NaN.")]
166 AllValuesNaN,
167 #[error("trix: output length mismatch: expected = {expected}, got = {got}")]
168 OutputLengthMismatch { expected: usize, got: usize },
169 #[error("trix: invalid range: start={start}, end={end}, step={step}")]
170 InvalidRange {
171 start: usize,
172 end: usize,
173 step: usize,
174 },
175 #[error("trix: invalid kernel for batch: {0:?}")]
176 InvalidKernelForBatch(Kernel),
177 #[error("trix: invalid input: {0}")]
178 InvalidInput(String),
179}
180
181#[inline]
182pub fn trix(input: &TrixInput) -> Result<TrixOutput, TrixError> {
183 trix_with_kernel(input, Kernel::Auto)
184}
185
186#[inline(always)]
187fn trix_needed_len(period: usize) -> Result<usize, TrixError> {
188 let base = period
189 .checked_sub(1)
190 .and_then(|v| v.checked_mul(3))
191 .and_then(|v| v.checked_add(2))
192 .ok_or_else(|| {
193 TrixError::InvalidInput("period overflow when computing TRIX warmup length".into())
194 })?;
195 Ok(base)
196}
197
198#[inline(always)]
199fn trix_warmup_end(first: usize, period: usize) -> Result<usize, TrixError> {
200 let delta = period
201 .checked_sub(1)
202 .and_then(|v| v.checked_mul(3))
203 .and_then(|v| v.checked_add(1))
204 .ok_or_else(|| {
205 TrixError::InvalidInput("period overflow when computing TRIX warmup index".into())
206 })?;
207 first.checked_add(delta).ok_or_else(|| {
208 TrixError::InvalidInput("index overflow when computing TRIX warmup index".into())
209 })
210}
211
212#[inline(always)]
213fn trix_prepare<'a>(
214 input: &'a TrixInput,
215 k: Kernel,
216) -> Result<(&'a [f64], usize, usize, Kernel, f64, usize), TrixError> {
217 let data: &[f64] = input.as_ref();
218 let len = data.len();
219 if len == 0 {
220 return Err(TrixError::EmptyInputData);
221 }
222 let period = input.get_period();
223 if period == 0 || period > len {
224 return Err(TrixError::InvalidPeriod {
225 period,
226 data_len: len,
227 });
228 }
229 let first = data
230 .iter()
231 .position(|x| !x.is_nan())
232 .ok_or(TrixError::AllValuesNaN)?;
233 let needed = trix_needed_len(period)?;
234 let valid_len = len.saturating_sub(first);
235 if valid_len < needed {
236 return Err(TrixError::NotEnoughValidData {
237 needed,
238 valid: valid_len,
239 });
240 }
241 let chosen = match k {
242 Kernel::Auto => Kernel::Scalar,
243 other => other,
244 };
245 let alpha = 2.0 / (period as f64 + 1.0);
246 let warmup_end = trix_warmup_end(first, period)?;
247 Ok((data, period, first, chosen, alpha, warmup_end))
248}
249
/// Scalar TRIX kernel: writes `(ema3[t] - ema3[t-1]) * 10000` of the natural
/// log of `data` into `out`, starting at index `first + 3 * (period - 1) + 1`.
///
/// The three chained EMAs are seeded in stages: EMA1 from the mean of the
/// first `period` logs, EMA2 from the mean of the first `period` EMA1 values,
/// and EMA3 from the mean of the first `period` EMA2 values. Entries of `out`
/// before the first output index are left untouched. Nothing is written when
/// the series is too short to reach the first output index.
///
/// `first` is the index of the first sample to use and `alpha` the EMA
/// smoothing factor. `period` must be non-zero and `out` at least as long as
/// `data`; callers in this file validate both beforehand.
#[inline(always)]
fn trix_compute_into_scalar(
    data: &[f64],
    period: usize,
    first: usize,
    alpha: f64,
    out: &mut [f64],
) {
    const SCALE: f64 = 10000.0;

    let n = data.len();
    let warmup_end = first + 3 * (period - 1) + 1;
    if warmup_end >= n {
        return;
    }

    let inv_n = 1.0 / period as f64;
    let seed1_end = first + period;
    let seed2_end = first + 2 * period - 1;
    let seed3_end = first + 3 * period - 2;

    // Stage 1: EMA1 seed is the simple mean of the first `period` logs.
    let mut acc = 0.0;
    for i in first..seed1_end {
        acc += data[i].ln();
    }
    let mut ema1 = acc * inv_n;

    // Stage 2: run EMA1 and average `period` of its values to seed EMA2.
    acc = ema1;
    for i in seed1_end..seed2_end {
        ema1 = (data[i].ln() - ema1).mul_add(alpha, ema1);
        acc += ema1;
    }
    let mut ema2 = acc * inv_n;

    // Stage 3: run EMA1/EMA2 and average `period` EMA2 values to seed EMA3.
    acc = ema2;
    for i in seed2_end..seed3_end {
        ema1 = (data[i].ln() - ema1).mul_add(alpha, ema1);
        ema2 = (ema1 - ema2).mul_add(alpha, ema2);
        acc += ema2;
    }
    let mut ema3_prev = acc * inv_n;

    // Steady state: every further sample yields one output value.
    for idx in warmup_end..n {
        ema1 = (data[idx].ln() - ema1).mul_add(alpha, ema1);
        ema2 = (ema1 - ema2).mul_add(alpha, ema2);
        let ema3 = (ema2 - ema3_prev).mul_add(alpha, ema3_prev);
        out[idx] = (ema3 - ema3_prev) * SCALE;
        ema3_prev = ema3;
    }
}
394
395pub fn trix_with_kernel(input: &TrixInput, kernel: Kernel) -> Result<TrixOutput, TrixError> {
396 let (data, period, first, chosen, alpha, warmup_end) = trix_prepare(input, kernel)?;
397 let mut out = alloc_with_nan_prefix(data.len(), warmup_end);
398 unsafe {
399 match chosen {
400 Kernel::Scalar | Kernel::ScalarBatch => {
401 trix_compute_into_scalar(data, period, first, alpha, &mut out);
402 }
403 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
404 Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
405 trix_compute_into_scalar(data, period, first, alpha, &mut out);
406 }
407 #[allow(unreachable_patterns)]
408 _ => trix_compute_into_scalar(data, period, first, alpha, &mut out),
409 }
410 }
411 Ok(TrixOutput { values: out })
412}
413
414#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
415#[inline]
416pub fn trix_into(input: &TrixInput, out: &mut [f64]) -> Result<(), TrixError> {
417 let (data, period, first, chosen, alpha, warmup_end) = trix_prepare(input, Kernel::Auto)?;
418
419 if out.len() != data.len() {
420 return Err(TrixError::OutputLengthMismatch {
421 expected: data.len(),
422 got: out.len(),
423 });
424 }
425
426 let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
427 let warm = warmup_end.min(out.len());
428 for v in &mut out[..warm] {
429 *v = qnan;
430 }
431
432 unsafe {
433 match chosen {
434 Kernel::Scalar | Kernel::ScalarBatch => {
435 trix_compute_into_scalar(data, period, first, alpha, out)
436 }
437 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
438 Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
439 trix_compute_into_scalar(data, period, first, alpha, out)
440 }
441 #[allow(unreachable_patterns)]
442 _ => trix_compute_into_scalar(data, period, first, alpha, out),
443 }
444 }
445 Ok(())
446}
447
448#[derive(Debug, Clone)]
449pub struct TrixStream {
450 period: usize,
451 alpha: f64,
452 inv_n: f64,
453 state: StreamState,
454}
455
/// Stage of the streaming TRIX state machine.
#[derive(Clone, Debug)]
enum StreamState {
    /// Accumulating log values to seed EMA1; `need` samples still missing.
    Seed1 { need: usize, sum1: f64 },

    /// EMA1 is live; accumulating its values to seed EMA2.
    Seed2 {
        remain: usize,
        ema1: f64,
        sum_ema1: f64,
    },

    /// EMA1 and EMA2 are live; accumulating EMA2 values to seed EMA3.
    Seed3 {
        remain: usize,
        ema1: f64,
        ema2: f64,
        sum_ema2: f64,
    },

    /// All three EMAs are seeded; every update yields an output.
    Running {
        ema1: f64,
        ema2: f64,
        ema3_prev: f64,
    },
}
482
483impl TrixStream {
484 #[inline(always)]
485 pub fn try_new(params: TrixParams) -> Result<Self, TrixError> {
486 let period = params.period.unwrap_or(18);
487 if period == 0 {
488 return Err(TrixError::InvalidPeriod {
489 period,
490 data_len: 0,
491 });
492 }
493 let alpha = 2.0 / (period as f64 + 1.0);
494 Ok(Self {
495 period,
496 alpha,
497 inv_n: 1.0 / period as f64,
498 state: StreamState::Seed1 {
499 need: period,
500 sum1: 0.0,
501 },
502 })
503 }
504
505 #[inline(always)]
506 fn reset(&mut self) {
507 self.state = StreamState::Seed1 {
508 need: self.period,
509 sum1: 0.0,
510 };
511 }
512
513 #[inline(always)]
514 pub fn update(&mut self, value: f64) -> Option<f64> {
515 if !value.is_finite() || value <= 0.0 {
516 self.reset();
517 return None;
518 }
519
520 let lv = value.ln();
521 let a = self.alpha;
522
523 match &mut self.state {
524 StreamState::Seed1 { need, sum1 } => {
525 *sum1 += lv;
526 *need -= 1;
527 if *need == 0 {
528 let ema1 = *sum1 * self.inv_n;
529 self.state = StreamState::Seed2 {
530 remain: self.period - 1,
531 ema1,
532 sum_ema1: ema1,
533 };
534 }
535 None
536 }
537
538 StreamState::Seed2 {
539 remain,
540 ema1,
541 sum_ema1,
542 } => {
543 *ema1 = (lv - *ema1).mul_add(a, *ema1);
544 *sum_ema1 += *ema1;
545 *remain -= 1;
546 if *remain == 0 {
547 let ema2 = *sum_ema1 * self.inv_n;
548 let e1 = *ema1;
549 self.state = StreamState::Seed3 {
550 remain: self.period - 1,
551 ema1: e1,
552 ema2,
553 sum_ema2: ema2,
554 };
555 }
556 None
557 }
558
559 StreamState::Seed3 {
560 remain,
561 ema1,
562 ema2,
563 sum_ema2,
564 } => {
565 *ema1 = (lv - *ema1).mul_add(a, *ema1);
566 *ema2 = (*ema1 - *ema2).mul_add(a, *ema2);
567 *sum_ema2 += *ema2;
568 *remain -= 1;
569 if *remain == 0 {
570 let ema3_prev = *sum_ema2 * self.inv_n;
571 let e1 = *ema1;
572 let e2 = *ema2;
573 self.state = StreamState::Running {
574 ema1: e1,
575 ema2: e2,
576 ema3_prev,
577 };
578 }
579 None
580 }
581
582 StreamState::Running {
583 ema1,
584 ema2,
585 ema3_prev,
586 } => {
587 *ema1 = (lv - *ema1).mul_add(a, *ema1);
588 *ema2 = (*ema1 - *ema2).mul_add(a, *ema2);
589 let ema3 = (*ema2 - *ema3_prev).mul_add(a, *ema3_prev);
590 let out = (ema3 - *ema3_prev) * 10000.0;
591 *ema3_prev = ema3;
592 Some(out)
593 }
594 }
595 }
596}
597
/// Parameter sweep description for batch TRIX runs.
#[derive(Debug, Clone)]
pub struct TrixBatchRange {
    /// Period sweep as `(start, end, step)`.
    pub period: (usize, usize, usize),
}
602
603impl Default for TrixBatchRange {
604 fn default() -> Self {
605 Self {
606 period: (18, 267, 1),
607 }
608 }
609}
610
611#[derive(Clone, Debug, Default)]
612pub struct TrixBatchBuilder {
613 range: TrixBatchRange,
614 kernel: Kernel,
615}
616
617impl TrixBatchBuilder {
618 pub fn new() -> Self {
619 Self::default()
620 }
621 pub fn kernel(mut self, k: Kernel) -> Self {
622 self.kernel = k;
623 self
624 }
625 #[inline]
626 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
627 self.range.period = (start, end, step);
628 self
629 }
630 #[inline]
631 pub fn period_static(mut self, p: usize) -> Self {
632 self.range.period = (p, p, 0);
633 self
634 }
635 pub fn apply_slice(self, data: &[f64]) -> Result<TrixBatchOutput, TrixError> {
636 trix_batch_with_kernel(data, &self.range, self.kernel)
637 }
638 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<TrixBatchOutput, TrixError> {
639 TrixBatchBuilder::new().kernel(k).apply_slice(data)
640 }
641 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<TrixBatchOutput, TrixError> {
642 let slice = source_type(c, src);
643 self.apply_slice(slice)
644 }
645 pub fn with_default_candles(c: &Candles) -> Result<TrixBatchOutput, TrixError> {
646 TrixBatchBuilder::new()
647 .kernel(Kernel::Auto)
648 .apply_candles(c, "close")
649 }
650}
651
652pub fn trix_batch_with_kernel(
653 data: &[f64],
654 sweep: &TrixBatchRange,
655 k: Kernel,
656) -> Result<TrixBatchOutput, TrixError> {
657 let kernel = match k {
658 Kernel::Auto => Kernel::ScalarBatch,
659 other if other.is_batch() => other,
660 _ => {
661 return Err(TrixError::InvalidKernelForBatch(k));
662 }
663 };
664 let simd = match kernel {
665 Kernel::Avx512Batch => Kernel::Avx512,
666 Kernel::Avx2Batch => Kernel::Avx2,
667 Kernel::ScalarBatch => Kernel::Scalar,
668 _ => unreachable!(),
669 };
670 trix_batch_par_slice(data, sweep, simd)
671}
672
673#[derive(Clone, Debug)]
674pub struct TrixBatchOutput {
675 pub values: Vec<f64>,
676 pub combos: Vec<TrixParams>,
677 pub rows: usize,
678 pub cols: usize,
679}
680impl TrixBatchOutput {
681 pub fn row_for_params(&self, p: &TrixParams) -> Option<usize> {
682 self.combos
683 .iter()
684 .position(|c| c.period.unwrap_or(18) == p.period.unwrap_or(18))
685 }
686 pub fn values_for(&self, p: &TrixParams) -> Option<&[f64]> {
687 self.row_for_params(p).map(|row| {
688 let start = row * self.cols;
689 &self.values[start..start + self.cols]
690 })
691 }
692}
693
694#[inline(always)]
695fn expand_grid(r: &TrixBatchRange) -> Result<Vec<TrixParams>, TrixError> {
696 fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, TrixError> {
697 if step == 0 || start == end {
698 return Ok(vec![start]);
699 }
700 let mut vals = Vec::new();
701 if start < end {
702 let mut v = start;
703 while v <= end {
704 vals.push(v);
705 let next = match v.checked_add(step) {
706 Some(n) => n,
707 None => break,
708 };
709 if next == v {
710 break;
711 }
712 v = next;
713 }
714 } else {
715 let mut v = start;
716 while v >= end {
717 vals.push(v);
718 let next = v.saturating_sub(step);
719 if next == v {
720 break;
721 }
722 v = next;
723 }
724 }
725 if vals.is_empty() {
726 return Err(TrixError::InvalidRange { start, end, step });
727 }
728 Ok(vals)
729 }
730 let periods = axis_usize(r.period)?;
731 let mut out = Vec::with_capacity(periods.len());
732 for &p in &periods {
733 out.push(TrixParams { period: Some(p) });
734 }
735 Ok(out)
736}
737
738#[inline(always)]
739pub fn trix_batch_slice(
740 data: &[f64],
741 sweep: &TrixBatchRange,
742 kern: Kernel,
743) -> Result<TrixBatchOutput, TrixError> {
744 trix_batch_inner(data, sweep, kern, false)
745}
746
747#[inline(always)]
748pub fn trix_batch_par_slice(
749 data: &[f64],
750 sweep: &TrixBatchRange,
751 kern: Kernel,
752) -> Result<TrixBatchOutput, TrixError> {
753 trix_batch_inner(data, sweep, kern, true)
754}
755
756#[inline(always)]
757fn trix_batch_inner(
758 data: &[f64],
759 sweep: &TrixBatchRange,
760 kern: Kernel,
761 parallel: bool,
762) -> Result<TrixBatchOutput, TrixError> {
763 let combos = expand_grid(sweep)?;
764 let first = data
765 .iter()
766 .position(|x| !x.is_nan())
767 .ok_or(TrixError::AllValuesNaN)?;
768 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
769 let needed = trix_needed_len(max_p)?;
770 if data.len() - first < needed {
771 return Err(TrixError::NotEnoughValidData {
772 needed,
773 valid: data.len() - first,
774 });
775 }
776 let rows = combos.len();
777 let cols = data.len();
778 let _total = rows
779 .checked_mul(cols)
780 .ok_or_else(|| TrixError::InvalidInput("rows*cols overflow".into()))?;
781
782 let mut buf_mu = make_uninit_matrix(rows, cols);
783
784 let warm: Vec<usize> = combos
785 .iter()
786 .map(|c| trix_warmup_end(first, c.period.unwrap()))
787 .collect::<Result<_, _>>()?;
788
789 init_matrix_prefixes(&mut buf_mu, cols, &warm);
790
791 let mut buf_guard = core::mem::ManuallyDrop::new(buf_mu);
792 let values: &mut [f64] = unsafe {
793 core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
794 };
795
796 let mut logs: AVec<f64> = AVec::with_capacity(CACHELINE_ALIGN, cols);
797 unsafe { logs.set_len(cols) };
798 for i in 0..first {
799 logs[i] = 0.0;
800 }
801 for i in first..cols {
802 logs[i] = data[i].ln();
803 }
804
805 let do_row = |row: usize, out_row: &mut [f64]| unsafe {
806 let period = combos[row].period.unwrap();
807 trix_row_scalar_with_logs(&logs, first, period, out_row)
808 };
809 if parallel {
810 #[cfg(not(target_arch = "wasm32"))]
811 {
812 values
813 .par_chunks_mut(cols)
814 .enumerate()
815 .for_each(|(row, slice)| do_row(row, slice));
816 }
817
818 #[cfg(target_arch = "wasm32")]
819 {
820 for (row, slice) in values.chunks_mut(cols).enumerate() {
821 do_row(row, slice);
822 }
823 }
824 } else {
825 for (row, slice) in values.chunks_mut(cols).enumerate() {
826 do_row(row, slice);
827 }
828 }
829
830 let values = unsafe {
831 let ptr = buf_guard.as_mut_ptr() as *mut f64;
832 let len = buf_guard.len();
833 core::mem::forget(buf_guard);
834 Vec::from_raw_parts(ptr, len, len)
835 };
836
837 Ok(TrixBatchOutput {
838 values,
839 combos,
840 rows,
841 cols,
842 })
843}
844
/// Computes one TRIX row from raw prices, including the NaN warmup prefix.
///
/// Entries of `out` before index `first + 3 * (period - 1) + 1` (clamped to
/// the data length) are set to NaN. From that index on, `out` receives
/// `(ema3[t] - ema3[t-1]) * 10000` of the natural log of `data`, with
/// smoothing factor `2 / (period + 1)`.
///
/// # Safety
/// Elements are accessed without bounds checks. The caller must guarantee
/// `period >= 1`, `first < data.len()` and `out.len() >= data.len()`.
#[inline(always)]
unsafe fn trix_row_scalar(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    const SCALE: f64 = 10000.0;

    let n = data.len();
    let alpha = 2.0 / (period as f64 + 1.0);
    let inv_n = 1.0 / period as f64;

    let warmup_end = first + 3 * (period - 1) + 1;
    for slot in out[..warmup_end.min(n)].iter_mut() {
        *slot = f64::NAN;
    }
    if warmup_end >= n {
        return;
    }

    let seed1_end = first + period;
    let seed2_end = first + 2 * period - 1;
    let seed3_end = first + 3 * period - 2;

    // Stage 1: seed EMA1 with the mean of the first `period` logs.
    let mut acc = 0.0;
    for i in first..seed1_end {
        acc += (*data.get_unchecked(i)).ln();
    }
    let mut ema1 = acc * inv_n;

    // Stage 2: seed EMA2 with the mean of `period` EMA1 values.
    acc = ema1;
    for i in seed1_end..seed2_end {
        let lv = (*data.get_unchecked(i)).ln();
        ema1 = (lv - ema1).mul_add(alpha, ema1);
        acc += ema1;
    }
    let mut ema2 = acc * inv_n;

    // Stage 3: seed EMA3 with the mean of `period` EMA2 values.
    acc = ema2;
    for i in seed2_end..seed3_end {
        let lv = (*data.get_unchecked(i)).ln();
        ema1 = (lv - ema1).mul_add(alpha, ema1);
        ema2 = (ema1 - ema2).mul_add(alpha, ema2);
        acc += ema2;
    }
    let mut ema3_prev = acc * inv_n;

    // Steady state: one output per remaining sample.
    for idx in warmup_end..n {
        let lv = (*data.get_unchecked(idx)).ln();
        ema1 = (lv - ema1).mul_add(alpha, ema1);
        ema2 = (ema1 - ema2).mul_add(alpha, ema2);
        let ema3 = (ema2 - ema3_prev).mul_add(alpha, ema3_prev);
        *out.get_unchecked_mut(idx) = (ema3 - ema3_prev) * SCALE;
        ema3_prev = ema3;
    }
}
931
/// Computes one TRIX row from precomputed natural logs.
///
/// From index `first + 3 * (period - 1) + 1` on, `out` receives
/// `(ema3[t] - ema3[t-1]) * 10000`, with smoothing factor `2 / (period + 1)`.
/// Entries before that index are left untouched (the caller is expected to
/// have initialised them). Nothing is written when the series is too short to
/// reach that index.
///
/// # Safety
/// Elements are accessed without bounds checks. The caller must guarantee
/// `period >= 1`, `first < logs.len()` and `out.len() >= logs.len()`.
#[inline(always)]
unsafe fn trix_row_scalar_with_logs(logs: &[f64], first: usize, period: usize, out: &mut [f64]) {
    const SCALE: f64 = 10000.0;

    let n = logs.len();
    let alpha = 2.0 / (period as f64 + 1.0);
    let inv_n = 1.0 / period as f64;

    let warmup_end = first + 3 * (period - 1) + 1;
    if warmup_end >= n {
        return;
    }

    let seed1_end = first + period;
    let seed2_end = first + 2 * period - 1;
    let seed3_end = first + 3 * period - 2;

    // Stage 1: seed EMA1 with the mean of the first `period` logs.
    let mut acc = 0.0;
    for i in first..seed1_end {
        acc += *logs.get_unchecked(i);
    }
    let mut ema1 = acc * inv_n;

    // Stage 2: seed EMA2 with the mean of `period` EMA1 values.
    acc = ema1;
    for i in seed1_end..seed2_end {
        let lv = *logs.get_unchecked(i);
        ema1 = (lv - ema1).mul_add(alpha, ema1);
        acc += ema1;
    }
    let mut ema2 = acc * inv_n;

    // Stage 3: seed EMA3 with the mean of `period` EMA2 values.
    acc = ema2;
    for i in seed2_end..seed3_end {
        let lv = *logs.get_unchecked(i);
        ema1 = (lv - ema1).mul_add(alpha, ema1);
        ema2 = (ema1 - ema2).mul_add(alpha, ema2);
        acc += ema2;
    }
    let mut ema3_prev = acc * inv_n;

    // Steady state: one output per remaining sample.
    for idx in warmup_end..n {
        let lv = *logs.get_unchecked(idx);
        ema1 = (lv - ema1).mul_add(alpha, ema1);
        ema2 = (ema1 - ema2).mul_add(alpha, ema2);
        let ema3 = (ema2 - ema3_prev).mul_add(alpha, ema3_prev);
        *out.get_unchecked_mut(idx) = (ema3 - ema3_prev) * SCALE;
        ema3_prev = ema3;
    }
}
1072
// Python binding `trix(data, period, kernel=None)`.
//
// Validates the optional kernel name, runs the computation with the GIL
// released, and returns the result as a 1-D NumPy array. Any `TrixError` is
// surfaced as a Python `ValueError`.
// (Plain `//` comments are used so the Python-visible docstring is unchanged.)
#[cfg(feature = "python")]
#[pyfunction(name = "trix")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn trix_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, false)?;

    let trix_in = TrixInput::from_slice(
        slice_in,
        TrixParams {
            period: Some(period),
        },
    );

    let computed = py.allow_threads(|| trix_with_kernel(&trix_in, kern));
    match computed {
        Ok(output) => Ok(output.values.into_pyarray(py)),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
1098
// Python-facing wrapper (exposed as `TrixStream`) around the Rust streaming
// calculator. Plain `//` comments keep the Python class docstring unchanged.
#[cfg(feature = "python")]
#[pyclass(name = "TrixStream")]
pub struct TrixStreamPy {
    // The wrapped Rust stream that holds all state.
    stream: TrixStream,
}
1104
#[cfg(feature = "python")]
#[pymethods]
impl TrixStreamPy {
    // Constructor: builds a stream for `period`; an invalid period is raised
    // as a Python `ValueError`.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        let params = TrixParams {
            period: Some(period),
        };
        match TrixStream::try_new(params) {
            Ok(stream) => Ok(Self { stream }),
            Err(e) => Err(PyValueError::new_err(e.to_string())),
        }
    }

    // Feeds one sample; returns the TRIX value, or `None` during warmup.
    fn update(&mut self, value: f64) -> Option<f64> {
        let stream = &mut self.stream;
        stream.update(value)
    }
}
1122
// Python binding `trix_cuda_batch_dev(data_f32, period_range, device_id=0)`.
//
// Runs the batch sweep through `CudaTrix` with the GIL released and returns
// the device array handle plus a dict whose "periods" entry lists the periods
// of the sweep. Raises `ValueError` when CUDA is unavailable or any CUDA call
// fails.
// (Plain `//` comments are used so the Python-visible docstring is unchanged.)
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "trix_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, device_id=0))]
pub fn trix_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    data_f32: numpy::PyReadonlyArray1<'py, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<(DeviceArrayF32Py, Bound<'py, PyDict>)> {
    use crate::cuda::cuda_available;
    use numpy::IntoPyArray;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let slice_in = data_f32.as_slice()?;
    let sweep = TrixBatchRange {
        period: period_range,
    };

    let inner = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaTrix::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let arr = cuda
            .trix_batch_dev(slice_in, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        // Wait for the device work to finish before handing the array out.
        cuda.synchronize()
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok(arr)
    })?;

    let dict = PyDict::new(py);
    let (start, end, step) = period_range;
    let mut periods: Vec<u64> = Vec::new();
    if step == 0 {
        periods.push(start as u64);
    } else {
        let mut p = start;
        while p <= end {
            periods.push(p as u64);
            // Stop on overflow: a saturating add would pin `p` at
            // `usize::MAX` and never leave the loop when `end == usize::MAX`.
            match p.checked_add(step) {
                Some(next) => p = next,
                None => break,
            }
        }
    }
    dict.set_item("periods", periods.into_pyarray(py))?;

    let handle = make_device_array_py(device_id, inner)?;
    Ok((handle, dict))
}
1171
// Python binding
// `trix_cuda_many_series_one_param_dev(data_tm_f32, period, device_id=0)`.
//
// Passes the flattened 2-D array, with `shape[1]` and `shape[0]` (in that
// order), to `CudaTrix::trix_many_series_one_param_time_major_dev`, waits for
// the device, and returns the device array handle. Raises `ValueError` when
// CUDA is unavailable or any CUDA call fails.
// (Plain `//` comments are used so the Python-visible docstring is unchanged.)
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "trix_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, device_id=0))]
pub fn trix_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    use crate::cuda::cuda_available;
    use numpy::PyUntypedArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let flat_in = data_tm_f32.as_slice()?;
    let shape = data_tm_f32.shape();
    let (rows, cols) = (shape[0], shape[1]);

    let inner = py.allow_threads(|| -> PyResult<_> {
        let cuda = match CudaTrix::new(device_id) {
            Ok(ctx) => ctx,
            Err(e) => return Err(PyValueError::new_err(e.to_string())),
        };
        let arr = match cuda.trix_many_series_one_param_time_major_dev(flat_in, cols, rows, period)
        {
            Ok(device_arr) => device_arr,
            Err(e) => return Err(PyValueError::new_err(e.to_string())),
        };
        if let Err(e) = cuda.synchronize() {
            return Err(PyValueError::new_err(e.to_string()));
        }
        Ok(arr)
    })?;

    make_device_array_py(device_id, inner)
}
1205
/// Python entry point `trix_batch`.
///
/// Evaluates TRIX for every period produced by `period_range`
/// (`(start, end, step)`) over `data` and returns a dict with:
/// * `"values"`: a `rows x cols` array, one row per period;
/// * `"periods"`: the period used for each row, as `u64`.
///
/// # Errors
/// Returns a `ValueError` for an invalid kernel name, an invalid period grid,
/// a `rows * cols` overflow, all-NaN input, or any error from the batch kernel.
#[cfg(feature = "python")]
#[pyfunction(name = "trix_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn trix_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, true)?;

    let sweep = TrixBatchRange {
        period: period_range,
    };
    let combos_probe = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let (rows, cols) = (combos_probe.len(), slice_in.len());

    let total = match rows.checked_mul(cols) {
        Some(n) => n,
        None => return Err(PyValueError::new_err("trix_batch_py: rows*cols overflow")),
    };

    // The output is allocated uninitialised and written in place below.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let first = match slice_in.iter().position(|x| !x.is_nan()) {
        Some(idx) => idx,
        None => return Err(PyValueError::new_err("AllValuesNaN")),
    };

    // Pre-fill each row's warmup prefix with NaN before the kernel runs.
    for (r, prm) in combos_probe.iter().enumerate() {
        let warm = first + 3 * (prm.period.unwrap() - 1) + 1;
        let row_start = r * cols;
        slice_out[row_start..row_start + warm.min(cols)].fill(f64::NAN);
    }

    let combos = py
        .allow_threads(|| {
            // `Auto` resolves to the scalar batch path; batch kernels map to
            // their single-series counterparts.
            let simd = match kern {
                Kernel::Auto | Kernel::ScalarBatch => Kernel::Scalar,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::Avx512Batch => Kernel::Avx512,
                _ => unreachable!(),
            };
            trix_batch_inner_into(slice_in, &sweep, simd, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;

    let periods: Vec<u64> = combos.iter().map(|p| p.period.unwrap() as u64).collect();
    dict.set_item("periods", periods.into_pyarray(py))?;

    Ok(dict.into())
}
1275
1276#[inline(always)]
1277fn trix_batch_inner_into(
1278 data: &[f64],
1279 sweep: &TrixBatchRange,
1280 kern: Kernel,
1281 parallel: bool,
1282 out: &mut [f64],
1283) -> Result<Vec<TrixParams>, TrixError> {
1284 let combos = expand_grid(sweep)?;
1285 let first = data
1286 .iter()
1287 .position(|x| !x.is_nan())
1288 .ok_or(TrixError::AllValuesNaN)?;
1289 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1290 let needed = trix_needed_len(max_p)?;
1291 if data.len() - first < needed {
1292 return Err(TrixError::NotEnoughValidData {
1293 needed,
1294 valid: data.len() - first,
1295 });
1296 }
1297 let rows = combos.len();
1298 let cols = data.len();
1299
1300 let mut logs: AVec<f64> = AVec::with_capacity(CACHELINE_ALIGN, cols);
1301 unsafe { logs.set_len(cols) };
1302 for i in 0..first {
1303 logs[i] = 0.0;
1304 }
1305 for i in first..cols {
1306 logs[i] = data[i].ln();
1307 }
1308
1309 let do_row = |row: usize, out_row: &mut [f64]| unsafe {
1310 let period = combos[row].period.unwrap();
1311 trix_row_scalar_with_logs(&logs, first, period, out_row)
1312 };
1313
1314 if parallel {
1315 #[cfg(not(target_arch = "wasm32"))]
1316 {
1317 out.par_chunks_mut(cols)
1318 .enumerate()
1319 .for_each(|(row, slice)| do_row(row, slice));
1320 }
1321
1322 #[cfg(target_arch = "wasm32")]
1323 {
1324 for (row, slice) in out.chunks_mut(cols).enumerate() {
1325 do_row(row, slice);
1326 }
1327 }
1328 } else {
1329 for (row, slice) in out.chunks_mut(cols).enumerate() {
1330 do_row(row, slice);
1331 }
1332 }
1333 Ok(combos)
1334}
1335
1336#[inline(always)]
1337pub fn trix_into_slice(dst: &mut [f64], input: &TrixInput, kern: Kernel) -> Result<(), TrixError> {
1338 let (data, period, first, chosen, alpha, warmup_end) = trix_prepare(input, kern)?;
1339 if dst.len() != data.len() {
1340 return Err(TrixError::OutputLengthMismatch {
1341 expected: data.len(),
1342 got: dst.len(),
1343 });
1344 }
1345
1346 let warmup_len = warmup_end.min(dst.len());
1347 for v in &mut dst[..warmup_len] {
1348 *v = f64::NAN;
1349 }
1350 unsafe {
1351 match chosen {
1352 Kernel::Scalar | Kernel::ScalarBatch => {
1353 trix_compute_into_scalar(data, period, first, alpha, dst)
1354 }
1355 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1356 Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
1357 trix_compute_into_scalar(data, period, first, alpha, dst)
1358 }
1359 #[allow(unreachable_patterns)]
1360 _ => trix_compute_into_scalar(data, period, first, alpha, dst),
1361 }
1362 }
1363 Ok(())
1364}
1365
/// WASM binding: computes TRIX over `data` with the given `period` using the
/// scalar kernel and returns a freshly allocated vector of the same length.
///
/// Errors from the computation are converted to a string `JsValue`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn trix_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = TrixInput::from_slice(
        data,
        TrixParams {
            period: Some(period),
        },
    );

    let mut output = vec![f64::NAN; data.len()];
    match trix_into_slice(&mut output, &input, Kernel::Scalar) {
        Ok(()) => Ok(output),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1378
/// WASM binding: allocates room for `len` `f64` values and returns the raw
/// pointer. Ownership passes to the caller, who must release it with
/// `trix_free` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn trix_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` keeps the allocation alive after this function returns.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1387
/// WASM binding: releases a buffer of `len` `f64` values at `ptr`.
/// A null `ptr` is ignored.
///
/// The buffer is rebuilt as a `Vec` with length and capacity both equal to
/// `len` and then dropped, so `ptr`/`len` must describe such an allocation
/// (for example one obtained from `trix_alloc(len)`).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn trix_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    unsafe {
        drop(Vec::from_raw_parts(ptr, len, len));
    }
}
1397
/// WASM binding: computes TRIX from the `len` values at `in_ptr` into the
/// `len` slots at `out_ptr`, using the scalar kernel.
///
/// When both pointers are equal the result is computed into a temporary
/// buffer and copied back, so the input is not overwritten mid-computation.
///
/// # Errors
/// Returns a string `JsValue` for a null pointer or a computation error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn trix_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    let to_js = |e: TrixError| JsValue::from_str(&e.to_string());

    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let input = TrixInput::from_slice(
            data,
            TrixParams {
                period: Some(period),
            },
        );

        if in_ptr == out_ptr {
            // In-place request: stage the result, then copy it over the input.
            let mut staged = vec![f64::NAN; len];
            trix_into_slice(&mut staged, &input, Kernel::Scalar).map_err(to_js)?;
            std::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&staged);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            trix_into_slice(out, &input, Kernel::Scalar).map_err(to_js)?;
        }
    }
    Ok(())
}
1428
/// Configuration object accepted by the WASM `trix_batch` binding.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Deserialize, Serialize)]
pub struct TrixBatchConfig {
    /// Period sweep; forwarded unchanged as `TrixBatchRange::period`.
    pub period_range: (usize, usize, usize),
}
1434
/// Result object returned to JavaScript by the WASM `trix_batch` binding.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Deserialize, Serialize)]
pub struct TrixBatchJsOutput {
    /// Flattened batch output values.
    pub values: Vec<f64>,
    /// Period of each parameter combination, in combination order.
    pub periods: Vec<usize>,
    /// Row count reported by the batch computation.
    pub rows: usize,
    /// Column count reported by the batch computation.
    pub cols: usize,
}
1443
/// WASM binding `trix_batch`: runs a sequential scalar batch sweep over
/// `data` using the period range found in `config` (a `TrixBatchConfig`)
/// and returns a serialized `TrixBatchJsOutput`.
///
/// # Errors
/// Returns a string `JsValue` when `config` cannot be deserialized, when the
/// batch computation fails, or when the output cannot be serialized.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = trix_batch)]
pub fn trix_batch_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let parsed = serde_wasm_bindgen::from_value::<TrixBatchConfig>(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = TrixBatchRange {
        period: parsed.period_range,
    };

    let output = match trix_batch_inner(data, &sweep, Kernel::Scalar, false) {
        Ok(batch) => batch,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };

    let mut periods: Vec<usize> = Vec::with_capacity(output.combos.len());
    for combo in output.combos.iter() {
        periods.push(combo.period.unwrap());
    }

    let js_output = TrixBatchJsOutput {
        values: output.values,
        periods,
        rows: output.rows,
        cols: output.cols,
    };

    serde_wasm_bindgen::to_value(&js_output)
        .map_err(|e| JsValue::from_str(&format!("Failed to serialize output: {}", e)))
}
1469
/// WASM binding: runs a sequential scalar batch sweep over the `len` values
/// at `in_ptr` and writes one row per period into `out_ptr`.
///
/// `out_ptr` must have room for `combinations * len` values, where the
/// combinations come from `(period_start, period_end, period_step)`.
/// Returns the number of combinations (rows) written.
///
/// # Errors
/// Returns a string `JsValue` for a null pointer, an invalid period grid,
/// a `rows * cols` overflow, or a computation error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn trix_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    let sweep = TrixBatchRange {
        period: (period_start, period_end, period_step),
    };

    let num_combos = expand_grid(&sweep)
        .map_err(|e| JsValue::from_str(&e.to_string()))?
        .len();

    let total_size = match num_combos.checked_mul(len) {
        Some(n) => n,
        None => return Err(JsValue::from_str("trix_batch_into: rows*cols overflow")),
    };

    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let out = std::slice::from_raw_parts_mut(out_ptr, total_size);

        trix_batch_inner_into(data, &sweep, Kernel::Scalar, false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(num_combos)
}
1505
/// Registers the TRIX Python API on module `m`: the `trix_py` and
/// `trix_batch_py` functions and the `TrixStreamPy` class.
#[cfg(feature = "python")]
pub fn register_trix_module(m: &Bound<'_, pyo3::types::PyModule>) -> PyResult<()> {
    let single_fn = wrap_pyfunction!(trix_py, m)?;
    m.add_function(single_fn)?;

    let batch_fn = wrap_pyfunction!(trix_batch_py, m)?;
    m.add_function(batch_fn)?;

    m.add_class::<TrixStreamPy>()?;
    Ok(())
}
1513
/// Checks that the in-place `trix_into` API agrees with the allocating `trix` API.
#[cfg(test)]
mod tests_into {
    use super::*;

    /// True when the two values are equal or are both NaN.
    fn eq_or_both_nan(a: f64, b: f64) -> bool {
        a == b || (a.is_nan() && b.is_nan())
    }

    #[test]
    fn test_trix_into_matches_api() -> Result<(), Box<dyn std::error::Error>> {
        let candles = crate::utilities::data_loader::read_candles_from_csv(
            "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv",
        )?;
        let input = TrixInput::from_candles(&candles, "close", TrixParams::default());

        let baseline = trix(&input)?.values;
        let mut out = vec![0.0; baseline.len()];

        // The native `trix_into` is only exercised outside the wasm build.
        #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
        {
            trix_into(&input, &mut out)?;
        }

        assert_eq!(baseline.len(), out.len());
        for (i, (&a, &b)) in baseline.iter().zip(out.iter()).enumerate() {
            if a.is_nan() || b.is_nan() {
                assert!(
                    eq_or_both_nan(a, b),
                    "NaN mismatch at index {}: {:?} vs {:?}",
                    i,
                    a,
                    b
                );
                continue;
            }

            let diff = (a - b).abs();
            assert!(
                diff <= 1e-12,
                "Value mismatch at {}: a={} b={} diff={}",
                i,
                a,
                b,
                diff
            );
        }
        Ok(())
    }
}
1566
1567#[cfg(test)]
1568mod tests {
1569 use super::*;
1570 use crate::skip_if_unsupported;
1571 use crate::utilities::data_loader::read_candles_from_csv;
1572
1573 fn check_trix_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1574 skip_if_unsupported!(kernel, test_name);
1575 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1576 let candles = read_candles_from_csv(file_path)?;
1577 let default_params = TrixParams { period: None };
1578 let input_default = TrixInput::from_candles(&candles, "close", default_params);
1579 let output_default = trix_with_kernel(&input_default, kernel)?;
1580 assert_eq!(output_default.values.len(), candles.close.len());
1581 let params_period_14 = TrixParams { period: Some(14) };
1582 let input_period_14 = TrixInput::from_candles(&candles, "hl2", params_period_14);
1583 let output_period_14 = trix_with_kernel(&input_period_14, kernel)?;
1584 assert_eq!(output_period_14.values.len(), candles.close.len());
1585 let params_custom = TrixParams { period: Some(20) };
1586 let input_custom = TrixInput::from_candles(&candles, "hlc3", params_custom);
1587 let output_custom = trix_with_kernel(&input_custom, kernel)?;
1588 assert_eq!(output_custom.values.len(), candles.close.len());
1589 Ok(())
1590 }
1591
1592 fn check_trix_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1593 skip_if_unsupported!(kernel, test_name);
1594 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1595 let candles = read_candles_from_csv(file_path)?;
1596 let close_prices = candles.select_candle_field("close")?;
1597 let params = TrixParams { period: Some(18) };
1598 let input = TrixInput::from_candles(&candles, "close", params);
1599 let trix_result = trix_with_kernel(&input, kernel)?;
1600 assert_eq!(
1601 trix_result.values.len(),
1602 close_prices.len(),
1603 "TRIX length mismatch"
1604 );
1605 let expected_last_five = [
1606 -16.03736447,
1607 -15.92084231,
1608 -15.76171478,
1609 -15.53571033,
1610 -15.34967155,
1611 ];
1612 assert!(trix_result.values.len() >= 5, "TRIX length too short");
1613 let start_index = trix_result.values.len() - 5;
1614 let result_last_five = &trix_result.values[start_index..];
1615 for (i, &value) in result_last_five.iter().enumerate() {
1616 let expected_value = expected_last_five[i];
1617
1618 let tolerance = 0.3;
1619 assert!(
1620 (value - expected_value).abs() < tolerance,
1621 "TRIX mismatch at index {}: expected {}, got {}, diff={}",
1622 i,
1623 expected_value,
1624 value,
1625 (value - expected_value).abs()
1626 );
1627 }
1628 Ok(())
1629 }
1630
1631 fn check_trix_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1632 skip_if_unsupported!(kernel, test_name);
1633 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1634 let candles = read_candles_from_csv(file_path)?;
1635 let input = TrixInput::with_default_candles(&candles);
1636 match input.data {
1637 TrixData::Candles { source, .. } => assert_eq!(source, "close"),
1638 _ => panic!("Expected TrixData::Candles"),
1639 }
1640 Ok(())
1641 }
1642
1643 fn check_trix_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1644 skip_if_unsupported!(kernel, test_name);
1645 let input_data = [10.0, 20.0, 30.0];
1646 let params = TrixParams { period: Some(0) };
1647 let input = TrixInput::from_slice(&input_data, params);
1648 let res = trix_with_kernel(&input, kernel);
1649 assert!(
1650 res.is_err(),
1651 "[{}] TRIX should fail with zero period",
1652 test_name
1653 );
1654 Ok(())
1655 }
1656
1657 fn check_trix_period_exceeds_length(
1658 test_name: &str,
1659 kernel: Kernel,
1660 ) -> Result<(), Box<dyn Error>> {
1661 skip_if_unsupported!(kernel, test_name);
1662 let data_small = [10.0, 20.0, 30.0];
1663 let params = TrixParams { period: Some(10) };
1664 let input = TrixInput::from_slice(&data_small, params);
1665 let res = trix_with_kernel(&input, kernel);
1666 assert!(
1667 res.is_err(),
1668 "[{}] TRIX should fail with period exceeding length",
1669 test_name
1670 );
1671 Ok(())
1672 }
1673
1674 fn check_trix_very_small_dataset(
1675 test_name: &str,
1676 kernel: Kernel,
1677 ) -> Result<(), Box<dyn Error>> {
1678 skip_if_unsupported!(kernel, test_name);
1679 let single_point = [42.0];
1680 let params = TrixParams { period: Some(18) };
1681 let input = TrixInput::from_slice(&single_point, params);
1682 let res = trix_with_kernel(&input, kernel);
1683 assert!(
1684 res.is_err(),
1685 "[{}] TRIX should fail with insufficient data",
1686 test_name
1687 );
1688 Ok(())
1689 }
1690
1691 fn check_trix_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1692 skip_if_unsupported!(kernel, test_name);
1693 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1694 let candles = read_candles_from_csv(file_path)?;
1695 let params = TrixParams { period: Some(10) };
1696 let input = TrixInput::from_candles(&candles, "close", params);
1697 let first_result = trix_with_kernel(&input, kernel)?;
1698 let second_input =
1699 TrixInput::from_slice(&first_result.values, TrixParams { period: Some(10) });
1700 let second_result = trix_with_kernel(&second_input, kernel)?;
1701 assert_eq!(first_result.values.len(), second_result.values.len());
1702 Ok(())
1703 }
1704
1705 #[cfg(debug_assertions)]
1706 fn check_trix_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1707 skip_if_unsupported!(kernel, test_name);
1708
1709 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1710 let candles = read_candles_from_csv(file_path)?;
1711
1712 let test_params = vec![
1713 TrixParams::default(),
1714 TrixParams { period: Some(2) },
1715 TrixParams { period: Some(5) },
1716 TrixParams { period: Some(10) },
1717 TrixParams { period: Some(14) },
1718 TrixParams { period: Some(20) },
1719 TrixParams { period: Some(30) },
1720 TrixParams { period: Some(50) },
1721 TrixParams { period: Some(100) },
1722 TrixParams { period: Some(200) },
1723 ];
1724
1725 for (param_idx, params) in test_params.iter().enumerate() {
1726 let input = TrixInput::from_candles(&candles, "close", params.clone());
1727 let output = trix_with_kernel(&input, kernel)?;
1728
1729 for (i, &val) in output.values.iter().enumerate() {
1730 if val.is_nan() {
1731 continue;
1732 }
1733
1734 let bits = val.to_bits();
1735
1736 if bits == 0x11111111_11111111 {
1737 panic!(
1738 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1739 with params: period={} (param set {})",
1740 test_name,
1741 val,
1742 bits,
1743 i,
1744 params.period.unwrap_or(18),
1745 param_idx
1746 );
1747 }
1748
1749 if bits == 0x22222222_22222222 {
1750 panic!(
1751 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1752 with params: period={} (param set {})",
1753 test_name,
1754 val,
1755 bits,
1756 i,
1757 params.period.unwrap_or(18),
1758 param_idx
1759 );
1760 }
1761
1762 if bits == 0x33333333_33333333 {
1763 panic!(
1764 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1765 with params: period={} (param set {})",
1766 test_name,
1767 val,
1768 bits,
1769 i,
1770 params.period.unwrap_or(18),
1771 param_idx
1772 );
1773 }
1774 }
1775 }
1776
1777 Ok(())
1778 }
1779
1780 #[cfg(not(debug_assertions))]
1781 fn check_trix_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1782 Ok(())
1783 }
1784
1785 macro_rules! generate_all_trix_tests {
1786 ($($test_fn:ident),*) => {
1787 paste::paste! {
1788 $(
1789 #[test]
1790 fn [<$test_fn _scalar_f64>]() {
1791 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1792 }
1793 )*
1794 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1795 $(
1796 #[test]
1797 fn [<$test_fn _avx2_f64>]() {
1798 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1799 }
1800 #[test]
1801 fn [<$test_fn _avx512_f64>]() {
1802 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1803 }
1804 )*
1805 }
1806 }
1807 }
1808 generate_all_trix_tests!(
1809 check_trix_partial_params,
1810 check_trix_accuracy,
1811 check_trix_default_candles,
1812 check_trix_zero_period,
1813 check_trix_period_exceeds_length,
1814 check_trix_very_small_dataset,
1815 check_trix_reinput,
1816 check_trix_no_poison
1817 );
1818
1819 fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1820 skip_if_unsupported!(kernel, test);
1821 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1822 let c = read_candles_from_csv(file)?;
1823 let output = TrixBatchBuilder::new()
1824 .kernel(kernel)
1825 .apply_candles(&c, "close")?;
1826 let def = TrixParams::default();
1827 let row = output.values_for(&def).expect("default row missing");
1828 assert_eq!(row.len(), c.close.len());
1829 let expected = [
1830 -16.03736447,
1831 -15.92084231,
1832 -15.76171478,
1833 -15.53571033,
1834 -15.34967155,
1835 ];
1836 let start = row.len() - 5;
1837 for (i, &v) in row[start..].iter().enumerate() {
1838 let tolerance = 0.3;
1839 assert!(
1840 (v - expected[i]).abs() < tolerance,
1841 "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}, diff={}",
1842 (v - expected[i]).abs()
1843 );
1844 }
1845 Ok(())
1846 }
1847
1848 #[cfg(debug_assertions)]
1849 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1850 skip_if_unsupported!(kernel, test);
1851
1852 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1853 let c = read_candles_from_csv(file)?;
1854
1855 let test_configs = vec![
1856 (2, 10, 2),
1857 (10, 30, 5),
1858 (30, 100, 10),
1859 (2, 5, 1),
1860 (18, 18, 0),
1861 (5, 25, 5),
1862 (50, 100, 25),
1863 (14, 28, 7),
1864 ];
1865
1866 for (cfg_idx, &(period_start, period_end, period_step)) in test_configs.iter().enumerate() {
1867 let output = TrixBatchBuilder::new()
1868 .kernel(kernel)
1869 .period_range(period_start, period_end, period_step)
1870 .apply_candles(&c, "close")?;
1871
1872 for (idx, &val) in output.values.iter().enumerate() {
1873 if val.is_nan() {
1874 continue;
1875 }
1876
1877 let bits = val.to_bits();
1878 let row = idx / output.cols;
1879 let col = idx % output.cols;
1880 let combo = &output.combos[row];
1881
1882 if bits == 0x11111111_11111111 {
1883 panic!(
1884 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
1885 at row {} col {} (flat index {}) with params: period={}",
1886 test,
1887 cfg_idx,
1888 val,
1889 bits,
1890 row,
1891 col,
1892 idx,
1893 combo.period.unwrap_or(18)
1894 );
1895 }
1896
1897 if bits == 0x22222222_22222222 {
1898 panic!(
1899 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
1900 at row {} col {} (flat index {}) with params: period={}",
1901 test,
1902 cfg_idx,
1903 val,
1904 bits,
1905 row,
1906 col,
1907 idx,
1908 combo.period.unwrap_or(18)
1909 );
1910 }
1911
1912 if bits == 0x33333333_33333333 {
1913 panic!(
1914 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
1915 at row {} col {} (flat index {}) with params: period={}",
1916 test,
1917 cfg_idx,
1918 val,
1919 bits,
1920 row,
1921 col,
1922 idx,
1923 combo.period.unwrap_or(18)
1924 );
1925 }
1926 }
1927 }
1928
1929 Ok(())
1930 }
1931
1932 #[cfg(not(debug_assertions))]
1933 fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1934 Ok(())
1935 }
1936
1937 macro_rules! gen_batch_tests {
1938 ($fn_name:ident) => {
1939 paste::paste! {
1940 #[test] fn [<$fn_name _scalar>]() {
1941 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
1942 }
1943 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1944 #[test] fn [<$fn_name _avx2>]() {
1945 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
1946 }
1947 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1948 #[test] fn [<$fn_name _avx512>]() {
1949 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
1950 }
1951 #[test] fn [<$fn_name _auto_detect>]() {
1952 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
1953 }
1954 }
1955 };
1956 }
1957 gen_batch_tests!(check_batch_default_row);
1958 gen_batch_tests!(check_batch_no_poison);
1959
/// Property-based checks for TRIX on random positive series (period 2..=20):
/// NaN warmup, first valid value right after warmup, near-zero output on flat
/// data, loose sign/magnitude bounds, finiteness, smoothness relative to log
/// returns, agreement with the scalar kernel, and run-to-run determinism.
#[cfg(feature = "proptest")]
#[allow(clippy::float_cmp)]
fn check_trix_property(
    test_name: &str,
    kernel: Kernel,
) -> Result<(), Box<dyn std::error::Error>> {
    use proptest::prelude::*;
    skip_if_unsupported!(kernel, test_name);

    /// Mean of the non-NaN entries of `series[len - 5..len]`, or `None` when
    /// every one of them is NaN.
    fn tail_mean(series: &[f64], len: usize) -> Option<f64> {
        let kept: Vec<f64> = series[(len - 5)..len]
            .iter()
            .filter(|&&v| !v.is_nan())
            .copied()
            .collect();
        if kept.is_empty() {
            None
        } else {
            Some(kept.iter().sum::<f64>() / kept.len() as f64)
        }
    }

    let strat = (2usize..=20).prop_flat_map(|period| {
        // Warmup length plus ten extra samples.
        let min_len = 3 * (period - 1) + 1 + 10;
        let price =
            (0.001f64..1e6f64).prop_filter("positive finite", |x| x.is_finite() && *x > 0.0);
        (prop::collection::vec(price, min_len..400), Just(period))
    });

    proptest::test_runner::TestRunner::default()
        .run(&strat, |(data, period)| {
            let n = data.len();
            let input = TrixInput::from_slice(
                &data,
                TrixParams {
                    period: Some(period),
                },
            );

            let out = trix_with_kernel(&input, kernel).unwrap().values;
            let ref_out = trix_with_kernel(&input, Kernel::Scalar).unwrap().values;

            // Warmup region must be NaN.
            let warmup_period = 3 * (period - 1) + 1;
            for i in 0..warmup_period.min(n) {
                prop_assert!(
                    out[i].is_nan(),
                    "Expected NaN during warmup at index {}, got {}",
                    i,
                    out[i]
                );
            }

            // The first sample after warmup must be a real value.
            if n > warmup_period {
                prop_assert!(
                    !out[warmup_period].is_nan(),
                    "Expected valid value at index {} (after warmup), got NaN",
                    warmup_period
                );
            }

            // Flat input should produce (near) zero TRIX.
            let is_flat = data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10);
            if is_flat && n > warmup_period {
                for i in warmup_period..n {
                    prop_assert!(
                        out[i].abs() < 1e-6,
                        "TRIX should be near zero for constant data at index {}: got {}",
                        i,
                        out[i]
                    );
                }
            }

            // Mostly rising input: tail average must stay above a loose floor.
            let rising_steps = data.windows(2).filter(|w| w[1] > w[0]).count();
            if rising_steps as f64 > n as f64 * 0.8 && n > warmup_period + 10 {
                if let Some(avg) = tail_mean(&out, n) {
                    prop_assert!(
                        avg > -10.0,
                        "TRIX average should be positive for mostly increasing data: got {}",
                        avg
                    );
                }
            }

            // Mostly falling input: tail average must stay below a loose ceiling.
            let falling_steps = data.windows(2).filter(|w| w[1] < w[0]).count();
            if falling_steps as f64 > n as f64 * 0.8 && n > warmup_period + 10 {
                if let Some(avg) = tail_mean(&out, n) {
                    prop_assert!(
                        avg < 10.0,
                        "TRIX average should be negative for mostly decreasing data: got {}",
                        avg
                    );
                }
            }

            // Magnitude bound on every non-NaN value after warmup.
            for i in warmup_period..n {
                if out[i].is_nan() {
                    continue;
                }
                prop_assert!(
                    out[i].abs() < 100000.0,
                    "TRIX value too large at index {}: {}",
                    i,
                    out[i]
                );
            }

            // No infinities anywhere.
            for (i, &val) in out.iter().enumerate() {
                prop_assert!(
                    val.is_nan() || val.is_finite(),
                    "TRIX should not produce infinite values at index {}: got {}",
                    i,
                    val
                );
            }

            // TRIX should not be noisier than scaled raw log returns.
            if n > warmup_period + 20 {
                let log_returns: Vec<f64> = data
                    .windows(2)
                    .skip(warmup_period)
                    .map(|w| (w[1] / w[0]).ln() * 10000.0)
                    .collect();

                let trix_values: Vec<f64> = out
                    .iter()
                    .skip(warmup_period + 1)
                    .filter(|&&v| !v.is_nan())
                    .copied()
                    .collect();

                if log_returns.len() > 10 && trix_values.len() > 10 {
                    let log_std = calculate_std(&log_returns);
                    let trix_std = calculate_std(&trix_values);

                    prop_assert!(
                        trix_std <= log_std * 1.2 || trix_std < 1.0,
                        "TRIX should be smoother than log returns: TRIX std={}, log return std={}",
                        trix_std,
                        log_std
                    );
                }
            }

            // The kernel under test must agree with the scalar reference.
            for i in warmup_period..n {
                let (y, r) = (out[i], ref_out[i]);

                if !(y.is_finite() && r.is_finite()) {
                    prop_assert!(
                        y.to_bits() == r.to_bits(),
                        "finite/NaN mismatch at index {}: {} vs {}",
                        i,
                        y,
                        r
                    );
                    continue;
                }

                let ulp_diff: u64 = y.to_bits().abs_diff(r.to_bits());
                prop_assert!(
                    (y - r).abs() <= 1e-9 || ulp_diff <= 4,
                    "Kernel mismatch at index {}: {} vs {} (ULP={})",
                    i,
                    y,
                    r,
                    ulp_diff
                );
            }

            // A second run must be bit-for-bit identical.
            let out2 = trix_with_kernel(&input, kernel).unwrap().values;
            prop_assert_eq!(
                out.len(),
                out2.len(),
                "Output length mismatch on second run"
            );
            for i in 0..out.len() {
                prop_assert!(
                    out[i].to_bits() == out2[i].to_bits(),
                    "Determinism failed at index {}: {} vs {}",
                    i,
                    out[i],
                    out2[i]
                );
            }

            Ok(())
        })
        .unwrap();

    // Fixed edge case: values spanning seven orders of magnitude, period 2.
    let edge_data = vec![0.001, 0.01, 0.1, 1.0, 10.0, 100.0, 1000.0, 10000.0];
    let edge_input = TrixInput::from_slice(&edge_data, TrixParams { period: Some(2) });
    let edge_result = trix_with_kernel(&edge_input, kernel);
    assert!(
        edge_result.is_ok(),
        "TRIX should handle very small positive values"
    );

    Ok(())
}
2167
/// Population standard deviation of `values` (divides by `values.len()`).
/// An empty slice yields NaN.
fn calculate_std(values: &[f64]) -> f64 {
    let count = values.len() as f64;
    let mean = values.iter().sum::<f64>() / count;
    let squared_dev_sum: f64 = values.iter().map(|v| (*v - mean).powi(2)).sum();
    (squared_dev_sum / count).sqrt()
}
2174
// Property tests are only generated when the `proptest` feature is enabled.
#[cfg(feature = "proptest")]
generate_all_trix_tests!(check_trix_property);
2177}