1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::moving_averages::alma_wrapper::DeviceArrayF32;
3#[cfg(all(feature = "python", feature = "cuda"))]
4use crate::cuda::moving_averages::CudaTrendflex;
5use crate::utilities::data_loader::{source_type, Candles};
6#[cfg(all(feature = "python", feature = "cuda"))]
7use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
8use crate::utilities::enums::Kernel;
9use crate::utilities::helpers::{
10 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
11 make_uninit_matrix,
12};
13#[cfg(feature = "python")]
14use crate::utilities::kernel_validation::validate_kernel;
15use aligned_vec::{AVec, ConstAlign, CACHELINE_ALIGN};
16#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
17use core::arch::x86_64::*;
18#[cfg(all(feature = "python", feature = "cuda"))]
19use cust::context::Context;
20#[cfg(all(feature = "python", feature = "cuda"))]
21use cust::memory::DeviceBuffer;
22#[cfg(not(target_arch = "wasm32"))]
23use rayon::prelude::*;
24use std::convert::AsRef;
25use std::error::Error;
26use std::mem::MaybeUninit;
27#[cfg(all(feature = "python", feature = "cuda"))]
28use std::sync::Arc;
29use thiserror::Error;
30
31impl<'a> AsRef<[f64]> for TrendFlexInput<'a> {
32 #[inline(always)]
33 fn as_ref(&self) -> &[f64] {
34 match &self.data {
35 TrendFlexData::Slice(slice) => slice,
36 TrendFlexData::Candles { candles, source } => source_type(candles, source),
37 }
38 }
39}
40
/// Where a TrendFlex computation reads its price series from.
#[derive(Debug, Clone)]
pub enum TrendFlexData<'a> {
    /// A column of a candle set, selected by name (resolved via `source_type`).
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A raw, borrowed price slice.
    Slice(&'a [f64]),
}

/// Result of a single-series TrendFlex run.
#[derive(Debug, Clone)]
pub struct TrendFlexOutput {
    /// One value per input sample; the warm-up prefix
    /// (first non-NaN index + period) is NaN.
    pub values: Vec<f64>,
}

/// User-facing TrendFlex parameters.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct TrendFlexParams {
    /// Look-back period; `None` falls back to 20.
    pub period: Option<usize>,
}

impl Default for TrendFlexParams {
    /// Default period of 20.
    fn default() -> Self {
        Self { period: Some(20) }
    }
}

/// A data source paired with the parameters to run TrendFlex with.
#[derive(Debug, Clone)]
pub struct TrendFlexInput<'a> {
    pub data: TrendFlexData<'a>,
    pub params: TrendFlexParams,
}
75
/// Python handle to a device-resident `rows x cols` float32 matrix of
/// TrendFlex results (CUDA builds with Python bindings only).
///
/// Declared `unsendable`, so PyO3 confines the object to the thread that
/// created it.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyo3::prelude::pyclass(
    module = "ta_indicators.cuda",
    name = "TrendFlexDeviceArrayF32",
    unsendable
)]
pub struct TrendFlexDeviceArrayF32Py {
    /// The device buffer plus its logical shape.
    pub(crate) inner: DeviceArrayF32,
    // Presumably held only to keep the CUDA context alive while `inner` is in
    // use; it is never read in this file section — confirm.
    pub(crate) _ctx: Arc<Context>,
    /// CUDA device ordinal, reported through `__dlpack_device__`.
    pub(crate) device_id: u32,
}
87
/// Python interop protocols (CUDA Array Interface and DLPack) for the device array.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyo3::prelude::pymethods]
impl TrendFlexDeviceArrayF32Py {
    /// CUDA Array Interface (version 3) description of the device buffer.
    ///
    /// Describes a C-contiguous `rows x cols` float32 matrix:
    /// - `strides` are in bytes (`cols * 4`, `4`);
    /// - `data` is `(device_ptr, false)`, where the second element is the
    ///   interface's read-only flag.
    ///
    /// NOTE(review): no `"stream"` key is emitted, so consumers get no stream to
    /// synchronise on. Confirm the producing kernel has completed before this
    /// object is handed out.
    #[getter]
    fn __cuda_array_interface__<'py>(
        &self,
        py: pyo3::prelude::Python<'py>,
    ) -> pyo3::PyResult<pyo3::prelude::Bound<'py, pyo3::types::PyDict>> {
        let d = pyo3::types::PyDict::new(py);
        d.set_item("shape", (self.inner.rows, self.inner.cols))?;
        d.set_item("typestr", "<f4")?;
        d.set_item(
            "strides",
            (
                self.inner.cols * std::mem::size_of::<f32>(),
                std::mem::size_of::<f32>(),
            ),
        )?;
        d.set_item("data", (self.inner.device_ptr() as usize, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    /// DLPack device tuple: `(2, device_id)`, where 2 is DLPack's `kDLCUDA`.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.device_id as i32)
    }

    /// Exports the buffer as a DLPack capsule, transferring ownership out of
    /// this object.
    ///
    /// - If `dl_device` names a different device, a `ValueError` is raised;
    ///   cross-device copies are not implemented. A `dl_device` that does not
    ///   extract as `(int, int)` is ignored.
    /// - `stream` is accepted for protocol compatibility but not used.
    /// - The internal buffer is swapped for an empty 0x0 placeholder, so later
    ///   calls (and `__cuda_array_interface__`) describe an empty array.
    ///   NOTE(review): a second `__dlpack__` call silently exports an empty
    ///   matrix rather than raising.
    #[pyo3(signature=(stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: pyo3::prelude::Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> pyo3::PyResult<pyo3::prelude::PyObject> {
        let (kdl, alloc_dev) = self.__dlpack_device__();
        if let Some(dev_obj) = dl_device.as_ref() {
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    // Non-bool or absent `copy` is treated as "no copy requested".
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(pyo3::exceptions::PyValueError::new_err(
                            "device copy not implemented for __dlpack__",
                        ));
                    } else {
                        return Err(pyo3::exceptions::PyValueError::new_err(
                            "dl_device mismatch for __dlpack__",
                        ));
                    }
                }
            }
        }
        // Stream synchronisation is not performed.
        let _ = stream;

        // Move the real buffer out, leaving an empty placeholder behind so
        // `self` stays valid after ownership passes to the DLPack capsule.
        let dummy = DeviceBuffer::from_slice(&[])
            .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;
        let inner = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32 {
                buf: dummy,
                rows: 0,
                cols: 0,
            },
        );

        let rows = inner.rows;
        let cols = inner.cols;
        let buf = inner.buf;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
167
168impl<'a> TrendFlexInput<'a> {
169 #[inline]
170 pub fn from_candles(c: &'a Candles, s: &'a str, p: TrendFlexParams) -> Self {
171 Self {
172 data: TrendFlexData::Candles {
173 candles: c,
174 source: s,
175 },
176 params: p,
177 }
178 }
179 #[inline]
180 pub fn from_slice(sl: &'a [f64], p: TrendFlexParams) -> Self {
181 Self {
182 data: TrendFlexData::Slice(sl),
183 params: p,
184 }
185 }
186 #[inline]
187 pub fn with_default_candles(c: &'a Candles) -> Self {
188 Self::from_candles(c, "close", TrendFlexParams::default())
189 }
190 #[inline]
191 pub fn get_period(&self) -> usize {
192 self.params.period.unwrap_or(20)
193 }
194}
195
/// Fluent builder for one-shot TrendFlex runs and streams.
#[derive(Copy, Clone, Debug)]
pub struct TrendFlexBuilder {
    /// Look-back period; `None` defers to the default of 20 downstream.
    period: Option<usize>,
    /// Kernel passed to `trendflex_with_kernel`.
    kernel: Kernel,
}

impl Default for TrendFlexBuilder {
    /// No explicit period; automatic kernel selection.
    fn default() -> Self {
        Self {
            period: None,
            kernel: Kernel::Auto,
        }
    }
}
210
211impl TrendFlexBuilder {
212 #[inline(always)]
213 pub fn new() -> Self {
214 Self::default()
215 }
216 #[inline(always)]
217 pub fn period(mut self, n: usize) -> Self {
218 self.period = Some(n);
219 self
220 }
221 #[inline(always)]
222 pub fn kernel(mut self, k: Kernel) -> Self {
223 self.kernel = k;
224 self
225 }
226 #[inline(always)]
227 pub fn apply(self, c: &Candles) -> Result<TrendFlexOutput, TrendFlexError> {
228 let p = TrendFlexParams {
229 period: self.period,
230 };
231 let i = TrendFlexInput::from_candles(c, "close", p);
232 trendflex_with_kernel(&i, self.kernel)
233 }
234 #[inline(always)]
235 pub fn apply_slice(self, d: &[f64]) -> Result<TrendFlexOutput, TrendFlexError> {
236 let p = TrendFlexParams {
237 period: self.period,
238 };
239 let i = TrendFlexInput::from_slice(d, p);
240 trendflex_with_kernel(&i, self.kernel)
241 }
242 #[inline(always)]
243 pub fn into_stream(self) -> Result<TrendFlexStream, TrendFlexError> {
244 let p = TrendFlexParams {
245 period: self.period,
246 };
247 TrendFlexStream::try_new(p)
248 }
249}
250
/// Errors reported by the TrendFlex entry points and kernels.
#[derive(Debug, Error)]
pub enum TrendFlexError {
    /// The input series is empty.
    #[error("trendflex: No data provided.")]
    NoDataProvided,
    /// Every input value is NaN.
    #[error("trendflex: All values are NaN.")]
    AllValuesNaN,
    /// The requested period is 0 (`period` carries the offending value).
    #[error("trendflex: period = 0")]
    ZeroTrendFlexPeriod { period: usize },
    /// The period does not fit the data.
    ///
    /// Single-series paths raise this when `period >= data_len`, despite the
    /// `>` in the message. The batch path raises it when the largest swept
    /// period exceeds the count of samples from the first non-NaN value on,
    /// and reports that count as `data_len`.
    #[error("trendflex: period > data len: period = {period}, data_len = {data_len}")]
    TrendFlexPeriodExceedsData { period: usize, data_len: usize },
    /// The derived super-smoother period does not fit the available data.
    #[error(
        "trendflex: smoother period > data len: ss_period = {ss_period}, data_len = {data_len}"
    )]
    SmootherPeriodExceedsData { ss_period: usize, data_len: usize },
    /// A caller-provided output buffer has the wrong length.
    #[error("trendflex: output length mismatch: expected {expected}, got {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// Fewer non-NaN samples (counted from the first valid one) than `needed`.
    #[error("trendflex: not enough valid data: needed {needed}, valid {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// A batch sweep `(start, end, step)` that expands to no periods.
    #[error("trendflex: invalid range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },
    /// A non-batch kernel was passed to a batch entry point.
    #[error("trendflex: invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(Kernel),
    /// `rows * cols` of a batch output overflows `usize`.
    #[error("trendflex: dimensions overflow: rows={rows}, cols={cols}")]
    DimensionsOverflow { rows: usize, cols: usize },
}
280
/// Computes TrendFlex for `input`, choosing the kernel automatically.
///
/// Equivalent to `trendflex_with_kernel(input, Kernel::Auto)`; see that
/// function for the validation rules and error cases.
#[inline]
pub fn trendflex(input: &TrendFlexInput) -> Result<TrendFlexOutput, TrendFlexError> {
    trendflex_with_kernel(input, Kernel::Auto)
}
285
286pub fn trendflex_with_kernel(
287 input: &TrendFlexInput,
288 kernel: Kernel,
289) -> Result<TrendFlexOutput, TrendFlexError> {
290 let data: &[f64] = input.as_ref();
291 let len = data.len();
292 if len == 0 {
293 return Err(TrendFlexError::NoDataProvided);
294 }
295
296 let period = input.get_period();
297 if period == 0 {
298 return Err(TrendFlexError::ZeroTrendFlexPeriod { period });
299 }
300 if period >= len {
301 return Err(TrendFlexError::TrendFlexPeriodExceedsData {
302 period,
303 data_len: len,
304 });
305 }
306
307 let first = data
308 .iter()
309 .position(|x| !x.is_nan())
310 .ok_or(TrendFlexError::AllValuesNaN)?;
311 let ss_period = ((period as f64) / 2.0).round() as usize;
312
313 let valid = len - first;
314 if valid < period {
315 return Err(TrendFlexError::NotEnoughValidData {
316 needed: period,
317 valid,
318 });
319 }
320 if ss_period > len {
321 return Err(TrendFlexError::SmootherPeriodExceedsData {
322 ss_period,
323 data_len: len,
324 });
325 }
326
327 let warm = first + period;
328 let mut out = alloc_with_nan_prefix(len, warm);
329
330 let chosen = match kernel {
331 Kernel::Auto => detect_best_kernel(),
332 k => k,
333 };
334
335 unsafe {
336 match chosen {
337 Kernel::Scalar | Kernel::ScalarBatch => {
338 trendflex_scalar_into(data, period, ss_period, first, &mut out)?
339 }
340 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
341 Kernel::Avx2 | Kernel::Avx2Batch => {
342 trendflex_avx2_into(data, period, ss_period, first, &mut out)?
343 }
344 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
345 Kernel::Avx512 | Kernel::Avx512Batch => {
346 trendflex_avx512_into(data, period, ss_period, first, &mut out)?
347 }
348 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
349 Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
350 trendflex_scalar_into(data, period, ss_period, first, &mut out)?
351 }
352 Kernel::Auto => unreachable!(),
353 }
354 }
355
356 Ok(TrendFlexOutput { values: out })
357}
358
359pub fn trendflex_into_slice(
360 dst: &mut [f64],
361 input: &TrendFlexInput,
362 kernel: Kernel,
363) -> Result<(), TrendFlexError> {
364 let data: &[f64] = input.as_ref();
365 let len = data.len();
366 if dst.len() != len {
367 return Err(TrendFlexError::OutputLengthMismatch {
368 expected: len,
369 got: dst.len(),
370 });
371 }
372 if len == 0 {
373 return Err(TrendFlexError::NoDataProvided);
374 }
375 let period = input.get_period();
376 if period == 0 {
377 return Err(TrendFlexError::ZeroTrendFlexPeriod { period });
378 }
379 if period >= len {
380 return Err(TrendFlexError::TrendFlexPeriodExceedsData {
381 period,
382 data_len: len,
383 });
384 }
385 let first = data
386 .iter()
387 .position(|x| !x.is_nan())
388 .ok_or(TrendFlexError::AllValuesNaN)?;
389 let ss_period = ((period as f64) / 2.0).round() as usize;
390 let valid = len - first;
391 if valid < period {
392 return Err(TrendFlexError::NotEnoughValidData {
393 needed: period,
394 valid,
395 });
396 }
397 if ss_period > data.len() {
398 return Err(TrendFlexError::SmootherPeriodExceedsData {
399 ss_period,
400 data_len: data.len(),
401 });
402 }
403
404 let chosen = match kernel {
405 Kernel::Auto => detect_best_kernel(),
406 k => k,
407 };
408
409 unsafe {
410 match chosen {
411 Kernel::Scalar | Kernel::ScalarBatch => {
412 trendflex_scalar_into(data, period, ss_period, first, dst)?
413 }
414 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
415 Kernel::Avx2 | Kernel::Avx2Batch => {
416 trendflex_avx2_into(data, period, ss_period, first, dst)?
417 }
418 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
419 Kernel::Avx512 | Kernel::Avx512Batch => {
420 trendflex_avx512_into(data, period, ss_period, first, dst)?
421 }
422 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
423 Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
424 trendflex_scalar_into(data, period, ss_period, first, dst)?
425 }
426 Kernel::Auto => unreachable!(),
427 }
428 }
429
430 let warmup_end = first + period;
431 for v in &mut dst[..warmup_end] {
432 *v = f64::NAN;
433 }
434 Ok(())
435}
436
/// Computes TrendFlex into `out` with automatic kernel selection.
///
/// Thin wrapper over `trendflex_into_slice(out, input, Kernel::Auto)`; see it
/// for the validation rules. Not compiled for wasm32 builds with the `wasm`
/// feature.
#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
#[inline]
pub fn trendflex_into(input: &TrendFlexInput, out: &mut [f64]) -> Result<(), TrendFlexError> {
    trendflex_into_slice(out, input, Kernel::Auto)
}
442
443#[inline]
444unsafe fn trendflex_scalar_into(
445 data: &[f64],
446 period: usize,
447 ss_period: usize,
448 first_valid: usize,
449 out: &mut [f64],
450) -> Result<(), TrendFlexError> {
451 use std::f64::consts::PI;
452
453 let len = data.len();
454 let warm = first_valid + period;
455
456 for i in 0..warm.min(out.len()) {
457 out[i] = f64::NAN;
458 }
459
460 if first_valid >= len {
461 return Ok(());
462 }
463
464 let a = (-1.414_f64 * PI / ss_period as f64).exp();
465 let a_sq = a * a;
466 let b = 2.0 * a * (1.414_f64 * PI / ss_period as f64).cos();
467
468 let c = (1.0 + a_sq - b) * 0.5;
469
470 let m = len - first_valid;
471 if m < period {
472 return Ok(());
473 }
474 if m < ss_period {
475 return Err(TrendFlexError::SmootherPeriodExceedsData {
476 ss_period,
477 data_len: m,
478 });
479 }
480
481 let x = &data[first_valid..];
482
483 let mut prev2 = x[0];
484 let mut prev1 = if m > 1 { x[1] } else { x[0] };
485
486 let mut ring = vec![0.0f64; period];
487 let mut head = 0usize;
488 let mut sum = 0.0f64;
489
490 ring[head] = prev2;
491 sum += prev2;
492 head = (head + 1) % period;
493 if m > 1 {
494 ring[head] = prev1;
495 sum += prev1;
496 head = (head + 1) % period;
497 }
498
499 let tp_f = period as f64;
500 let inv_tp = 1.0 / tp_f;
501 let mut ms_prev = 0.0f64;
502
503 let mut i = 2usize;
504 while i < m && i < period {
505 let cur = (-a_sq).mul_add(prev2, b.mul_add(prev1, c * (x[i] + x[i - 1])));
506 prev2 = prev1;
507 prev1 = cur;
508
509 sum += cur;
510 ring[head] = cur;
511 head = (head + 1) % period;
512 i += 1;
513 }
514
515 while i < m {
516 let cur = (-a_sq).mul_add(prev2, b.mul_add(prev1, c * (x[i] + x[i - 1])));
517 prev2 = prev1;
518 prev1 = cur;
519
520 let my_sum = (tp_f * cur - sum) * inv_tp;
521
522 let ms_current = 0.04f64.mul_add(my_sum * my_sum, 0.96f64 * ms_prev);
523 ms_prev = ms_current;
524
525 let out_val = if ms_current != 0.0 {
526 my_sum / ms_current.sqrt()
527 } else {
528 0.0
529 };
530 out[first_valid + i] = out_val;
531
532 let old = ring[head];
533 sum += cur - old;
534 ring[head] = cur;
535 head = (head + 1) % period;
536
537 i += 1;
538 }
539
540 Ok(())
541}
542
/// AVX2/FMA-tagged TrendFlex kernel.
///
/// Same pipeline as the scalar kernel: super smoother, then deviation from the
/// mean of the last `period` smoothed values, then RMS normalisation. The
/// filter recurrence does not use fused multiply-adds, so results may differ
/// from the scalar kernel in the last bits.
///
/// `out[..first_valid + max(period, 2)]` is set to NaN. Entries from there up
/// to `data.len()` receive the indicator.
///
/// # Safety
/// The CPU must support AVX2 and FMA, `period >= 1`, and
/// `out.len() >= data.len()`; outputs are written with unchecked indexing.
///
/// # Errors
/// `SmootherPeriodExceedsData` when fewer than `ss_period` valid samples exist.
/// This is only checked once at least `period` valid samples are present.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx2,fma")]
unsafe fn trendflex_avx2_into(
    data: &[f64],
    period: usize,
    ss_period: usize,
    first_valid: usize,
    out: &mut [f64],
) -> Result<(), TrendFlexError> {
    use std::f64::consts::PI;

    let len = data.len();
    // Emission starts at offset 2 (two seed samples), so with `period == 1`
    // offset 1 would otherwise never be written; include it in the NaN warm-up.
    let nan_end = (first_valid + period.max(2)).min(out.len());
    out[..nan_end].fill(f64::NAN);

    if first_valid >= len {
        return Ok(());
    }

    let a = (-1.414_f64 * PI / ss_period as f64).exp();
    let a_sq = a * a;
    let b = 2.0 * a * (1.414_f64 * PI / ss_period as f64).cos();
    let c = (1.0 + a_sq - b) * 0.5;

    /// Runs the recurrence over `x`, which starts at the first valid sample.
    /// The value for `x[i]` goes to `out[out_off + i]` for `i >= max(period, 2)`.
    ///
    /// # Safety
    /// `period >= 1` and `out.len() >= out_off + x.len()`.
    #[inline(always)]
    unsafe fn run_series_avx2(
        x: &[f64],
        period: usize,
        a_sq: f64,
        b: f64,
        c: f64,
        out: &mut [f64],
        out_off: usize,
    ) {
        let n = x.len();
        if n == 0 {
            return;
        }
        let mut prev2 = x[0];
        let mut prev1 = if n > 1 { x[1] } else { x[0] };

        let mut ring = vec![0.0f64; period];
        let mut sum = 0.0f64;
        let mut head = 0usize;

        // Seed pushes subtract the slot they overwrite (0.0 until the ring
        // wraps), so `sum` equals the ring contents even when `period == 1`.
        sum = sum - ring[head] + prev2;
        ring[head] = prev2;
        head = (head + 1) % period;
        if n > 1 {
            sum = sum - ring[head] + prev1;
            ring[head] = prev1;
            head = (head + 1) % period;
        }

        let tp_f = period as f64;
        let inv_tp = 1.0 / tp_f;
        let mut ms_prev = 0.0f64;

        let mut i = 2usize;
        while i < n && i < period {
            let cur = c * (x[i] + x[i - 1]) + b * prev1 - a_sq * prev2;
            prev2 = prev1;
            prev1 = cur;
            sum = sum - ring[head] + cur;
            ring[head] = cur;
            head = (head + 1) % period;
            i += 1;
        }

        while i < n {
            // Prefetch is only a hint and may target memory past the end of
            // `x`. `wrapping_add` avoids forming an out-of-bounds pointer with
            // `add`, which is undefined behaviour even if never dereferenced.
            _mm_prefetch(x.as_ptr().wrapping_add(i + 16).cast(), _MM_HINT_T0);
            let cur = c * (x[i] + x[i - 1]) + b * prev1 - a_sq * prev2;
            prev2 = prev1;
            prev1 = cur;

            let my_sum = (tp_f * cur - sum) * inv_tp;

            let v = _mm_set_sd(my_sum);
            let sq = _mm_mul_sd(v, v);
            let s04 = _mm_mul_sd(_mm_set_sd(0.04), sq);
            let s96 = _mm_mul_sd(_mm_set_sd(0.96), _mm_set_sd(ms_prev));
            let ms_cur = _mm_add_sd(s04, s96);
            let ms_current = _mm_cvtsd_f64(ms_cur);
            ms_prev = ms_current;

            let out_val = if ms_current != 0.0 {
                let denom = _mm_sqrt_sd(_mm_setzero_pd(), _mm_set_sd(ms_current));
                let denom_s = _mm_cvtsd_f64(denom);
                my_sum / denom_s
            } else {
                0.0
            };

            // Plain store. `_mm_stream_sd` is an SSE4a (AMD-only) instruction
            // that faults on CPUs without SSE4a, and non-temporal stores would
            // also need an `sfence` before the data is read back.
            // SAFETY: out_off + i < out_off + x.len() <= out.len() (fn contract).
            *out.get_unchecked_mut(out_off + i) = out_val;

            let old = ring[head];
            sum += cur - old;
            ring[head] = cur;
            head = (head + 1) % period;

            i += 1;
        }
    }

    // Every input, including `first_valid == 0`, goes through the same validation.
    let m = len - first_valid;
    if m < period {
        return Ok(());
    }
    if m < ss_period {
        return Err(TrendFlexError::SmootherPeriodExceedsData {
            ss_period,
            data_len: m,
        });
    }
    let tail = &data[first_valid..];
    run_series_avx2(tail, period, a_sq, b, c, out, first_valid);
    Ok(())
}
672
/// AVX-512-tagged TrendFlex kernel.
///
/// Same pipeline as the scalar kernel. The mean-square update uses a fused
/// multiply-add, and very long series (`>= 262144` valid samples) take a
/// 2-way unrolled loop.
///
/// `out[..first_valid + max(period, 2)]` is set to NaN. Entries from there up
/// to `data.len()` receive the indicator.
///
/// # Safety
/// The CPU must support AVX-512F/DQ and FMA, `period >= 1`, and
/// `out.len() >= data.len()`; reads and writes use unchecked indexing.
///
/// # Errors
/// `SmootherPeriodExceedsData` when fewer than `ss_period` valid samples exist.
/// This is only checked once at least `period` valid samples are present.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f,avx512dq,fma")]
unsafe fn trendflex_avx512_into(
    data: &[f64],
    period: usize,
    ss_period: usize,
    first_valid: usize,
    out: &mut [f64],
) -> Result<(), TrendFlexError> {
    use std::f64::consts::PI;

    let len = data.len();
    // Emission starts at offset 2 (two seed samples), so with `period == 1`
    // offset 1 would otherwise never be written; include it in the NaN warm-up.
    let nan_end = (first_valid + period.max(2)).min(out.len());
    out[..nan_end].fill(f64::NAN);

    if first_valid >= len {
        return Ok(());
    }

    let a = (-1.414_f64 * PI / ss_period as f64).exp();
    let a_sq = a * a;
    let b = 2.0 * a * (1.414_f64 * PI / ss_period as f64).cos();
    let c = (1.0 + a_sq - b) * 0.5;

    /// Runs the recurrence over `x`, which starts at the first valid sample.
    /// The value for `x[i]` goes to `out[out_off + i]` for `i >= max(period, 2)`.
    ///
    /// # Safety
    /// `period >= 1` and `out.len() >= out_off + x.len()`.
    #[inline(always)]
    unsafe fn run_series_avx512(
        x: &[f64],
        period: usize,
        a_sq: f64,
        b: f64,
        c: f64,
        out: &mut [f64],
        out_off: usize,
    ) {
        let n = x.len();
        if n == 0 {
            return;
        }
        let mut prev2 = *x.get_unchecked(0);
        let mut prev1 = if n > 1 {
            *x.get_unchecked(1)
        } else {
            *x.get_unchecked(0)
        };
        let mut ring = vec![0.0f64; period];
        let mut sum = 0.0f64;
        let mut head = 0usize;

        // Seed pushes subtract the slot they overwrite (0.0 until the ring
        // wraps), so `sum` equals the ring contents even when `period == 1`.
        sum = sum - *ring.get_unchecked(head) + prev2;
        *ring.get_unchecked_mut(head) = prev2;
        head += 1;
        if head == period {
            head = 0;
        }
        if n > 1 {
            sum = sum - *ring.get_unchecked(head) + prev1;
            *ring.get_unchecked_mut(head) = prev1;
            head += 1;
            if head == period {
                head = 0;
            }
        }

        let tp_f = period as f64;
        let inv_tp = 1.0 / tp_f;
        let mut ms_prev = 0.0f64;

        let mut i = 2usize;
        while i < n && i < period {
            let cur =
                c * (*x.get_unchecked(i) + *x.get_unchecked(i - 1)) + b * prev1 - a_sq * prev2;
            prev2 = prev1;
            prev1 = cur;
            sum = sum - *ring.get_unchecked(head) + cur;
            *ring.get_unchecked_mut(head) = cur;
            head += 1;
            if head == period {
                head = 0;
            }
            i += 1;
        }

        // Outputs are written with plain stores. `_mm_stream_sd` is an SSE4a
        // (AMD-only) instruction that faults on CPUs without SSE4a, and
        // non-temporal stores would also need an `sfence` before the data is
        // read back.
        let use_unroll = n >= 262144;

        if use_unroll {
            while i + 1 < n {
                // Prefetch is only a hint and may target memory past the end of
                // `x`; `wrapping_add` avoids forming an out-of-bounds pointer
                // with `add`, which is undefined behaviour.
                _mm_prefetch(x.as_ptr().wrapping_add(i + 32).cast(), _MM_HINT_T0);

                let cur0 =
                    c * (*x.get_unchecked(i) + *x.get_unchecked(i - 1)) + b * prev1 - a_sq * prev2;
                prev2 = prev1;
                prev1 = cur0;

                let my_sum0 = (tp_f * cur0 - sum) * inv_tp;

                let v0 = _mm_set_sd(my_sum0);
                let sq0 = _mm_mul_sd(v0, v0);
                let ms0 = _mm_fmadd_sd(
                    _mm_set_sd(0.04),
                    sq0,
                    _mm_mul_sd(_mm_set_sd(0.96), _mm_set_sd(ms_prev)),
                );
                let ms0_s = _mm_cvtsd_f64(ms0);
                ms_prev = ms0_s;
                let out0 = if ms0_s != 0.0 {
                    let den0 = _mm_sqrt_sd(_mm_setzero_pd(), _mm_set_sd(ms0_s));
                    my_sum0 / _mm_cvtsd_f64(den0)
                } else {
                    0.0
                };
                // SAFETY: out_off + i < out_off + n <= out.len() (fn contract).
                *out.get_unchecked_mut(out_off + i) = out0;

                let old0 = *ring.get_unchecked(head);
                sum += cur0 - old0;
                *ring.get_unchecked_mut(head) = cur0;
                head += 1;
                if head == period {
                    head = 0;
                }

                let cur1 =
                    c * (*x.get_unchecked(i + 1) + *x.get_unchecked(i)) + b * prev1 - a_sq * prev2;
                prev2 = prev1;
                prev1 = cur1;

                let my_sum1 = (tp_f * cur1 - sum) * inv_tp;
                let v1 = _mm_set_sd(my_sum1);
                let sq1 = _mm_mul_sd(v1, v1);
                let ms1 = _mm_fmadd_sd(
                    _mm_set_sd(0.04),
                    sq1,
                    _mm_mul_sd(_mm_set_sd(0.96), _mm_set_sd(ms_prev)),
                );
                let ms1_s = _mm_cvtsd_f64(ms1);
                ms_prev = ms1_s;
                let out1 = if ms1_s != 0.0 {
                    let den1 = _mm_sqrt_sd(_mm_setzero_pd(), _mm_set_sd(ms1_s));
                    my_sum1 / _mm_cvtsd_f64(den1)
                } else {
                    0.0
                };
                // SAFETY: i + 1 < n, so out_off + i + 1 < out.len().
                *out.get_unchecked_mut(out_off + i + 1) = out1;

                let old1 = *ring.get_unchecked(head);
                sum += cur1 - old1;
                *ring.get_unchecked_mut(head) = cur1;
                head += 1;
                if head == period {
                    head = 0;
                }

                i += 2;
            }
        }

        while i < n {
            _mm_prefetch(x.as_ptr().wrapping_add(i + 32).cast(), _MM_HINT_T0);
            let cur =
                c * (*x.get_unchecked(i) + *x.get_unchecked(i - 1)) + b * prev1 - a_sq * prev2;
            prev2 = prev1;
            prev1 = cur;

            let my_sum = (tp_f * cur - sum) * inv_tp;
            let v = _mm_set_sd(my_sum);
            let sq = _mm_mul_sd(v, v);
            let ms = _mm_fmadd_sd(
                _mm_set_sd(0.04),
                sq,
                _mm_mul_sd(_mm_set_sd(0.96), _mm_set_sd(ms_prev)),
            );
            let ms_s = _mm_cvtsd_f64(ms);
            ms_prev = ms_s;
            let out_val = if ms_s != 0.0 {
                let den = _mm_sqrt_sd(_mm_setzero_pd(), _mm_set_sd(ms_s));
                my_sum / _mm_cvtsd_f64(den)
            } else {
                0.0
            };
            // SAFETY: out_off + i < out_off + n <= out.len() (fn contract).
            *out.get_unchecked_mut(out_off + i) = out_val;

            let old = *ring.get_unchecked(head);
            sum += cur - old;
            *ring.get_unchecked_mut(head) = cur;
            head += 1;
            if head == period {
                head = 0;
            }

            i += 1;
        }
    }

    // Every input, including `first_valid == 0`, goes through the same validation.
    let m = len - first_valid;
    if m < period {
        return Ok(());
    }
    if m < ss_period {
        return Err(TrendFlexError::SmootherPeriodExceedsData {
            ss_period,
            data_len: m,
        });
    }
    let tail = &data[first_valid..];
    run_series_avx512(tail, period, a_sq, b, c, out, first_valid);
    Ok(())
}
909
/// Short-period AVX-512 entry point. It currently delegates straight to the
/// scalar kernel, so it has the same contract and output.
///
/// NOTE(review): not referenced by the dispatch code visible in this file
/// section; confirm whether it is still needed.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
unsafe fn trendflex_avx512_short_into(
    data: &[f64],
    period: usize,
    ss_period: usize,
    first_valid: usize,
    out: &mut [f64],
) -> Result<(), TrendFlexError> {
    trendflex_scalar_into(data, period, ss_period, first_valid, out)
}

/// Long-period AVX-512 entry point. It currently delegates straight to the
/// scalar kernel, so it has the same contract and output.
///
/// NOTE(review): not referenced by the dispatch code visible in this file
/// section; confirm whether it is still needed.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
unsafe fn trendflex_avx512_long_into(
    data: &[f64],
    period: usize,
    ss_period: usize,
    first_valid: usize,
    out: &mut [f64],
) -> Result<(), TrendFlexError> {
    trendflex_scalar_into(data, period, ss_period, first_valid, out)
}
933
/// Incremental (one sample at a time) TrendFlex calculator.
///
/// State consists of:
/// - a two-pole super-smoother recurrence;
/// - a ring buffer of the last `period` smoothed values with a rolling sum;
/// - an exponentially weighted mean square used to normalise the output.
#[derive(Debug, Clone)]
pub struct TrendFlexStream {
    /// Look-back window length.
    period: usize,
    /// Super-smoother period, `round(period / 2)`.
    ss_period: usize,

    // Super-smoother coefficients, precomputed in `try_new`.
    // NOTE(review): `a` and `ss_period` are not read by `update` in this file —
    // confirm whether they are kept only for introspection.
    a: f64,
    a_sq: f64,
    b: f64,
    c: f64,

    /// Ring buffer of the most recent `period` smoothed values.
    buf: Vec<f64>,
    /// Rolling sum of the values in `buf`.
    sum: f64,
    /// Next slot of `buf` to overwrite.
    head: usize,

    /// Smoothed value one step back.
    prev1_ssf: f64,
    /// Smoothed value two steps back.
    prev2_ssf: f64,
    /// Previous raw input sample.
    last_raw: f64,

    /// Number of samples consumed so far.
    n_ssf: usize,

    /// Previous exponentially weighted mean square.
    ms_prev: f64,

    /// `1 / period`.
    inv_p: f64,
}
958
959impl TrendFlexStream {
960 pub fn try_new(params: TrendFlexParams) -> Result<Self, TrendFlexError> {
961 let period = params.period.unwrap_or(20);
962 if period == 0 {
963 return Err(TrendFlexError::ZeroTrendFlexPeriod { period });
964 }
965
966 let ss_period = ((period as f64) / 2.0).round() as usize;
967 if ss_period == 0 {
968 return Err(TrendFlexError::SmootherPeriodExceedsData {
969 ss_period,
970 data_len: 0,
971 });
972 }
973
974 use std::f64::consts::PI;
975 let a = (-1.414_f64 * PI / (ss_period as f64)).exp();
976 let a_sq = a * a;
977 let b = 2.0 * a * (1.414_f64 * PI / (ss_period as f64)).cos();
978 let c = (1.0 + a_sq - b) * 0.5;
979
980 Ok(Self {
981 period,
982 ss_period,
983 a,
984 a_sq,
985 b,
986 c,
987 buf: vec![0.0; period],
988 sum: 0.0,
989 head: 0,
990 prev1_ssf: 0.0,
991 prev2_ssf: 0.0,
992 last_raw: 0.0,
993 n_ssf: 0,
994 ms_prev: 0.0,
995 inv_p: 1.0 / (period as f64),
996 })
997 }
998
999 #[inline(always)]
1000 pub fn update(&mut self, x: f64) -> Option<f64> {
1001 if self.n_ssf == 0 {
1002 self.prev2_ssf = x;
1003 self.last_raw = x;
1004
1005 self.buf[self.head] = x;
1006 self.sum += x;
1007 self.head = if self.period > 1 { 1 } else { 0 };
1008 self.n_ssf = 1;
1009 return None;
1010 }
1011
1012 if self.n_ssf == 1 {
1013 self.prev1_ssf = x;
1014 self.last_raw = x;
1015
1016 if self.period > 1 {
1017 self.buf[self.head] = x;
1018 self.sum += x;
1019 self.head = (self.head + 1) % self.period;
1020 } else {
1021 self.buf[0] = x;
1022 self.sum = x;
1023 }
1024 self.n_ssf = 2;
1025 return None;
1026 }
1027
1028 let cur = (-self.a_sq).mul_add(
1029 self.prev2_ssf,
1030 self.b.mul_add(self.prev1_ssf, self.c * (x + self.last_raw)),
1031 );
1032
1033 let tp_cur_minus_sum = (self.period as f64).mul_add(cur, -self.sum);
1034 let my_sum = self.inv_p * tp_cur_minus_sum;
1035
1036 let will_emit = self.n_ssf + 1 > self.period;
1037
1038 let out_val = if will_emit {
1039 let sq = my_sum * my_sum;
1040 let ms_current = 0.04f64.mul_add(sq, 0.96f64 * self.ms_prev);
1041 self.ms_prev = ms_current;
1042 if ms_current > 0.0 {
1043 my_sum / ms_current.sqrt()
1044 } else {
1045 0.0
1046 }
1047 } else {
1048 0.0
1049 };
1050
1051 let old = self.buf[self.head];
1052 self.sum += cur - old;
1053 self.buf[self.head] = cur;
1054 self.head = (self.head + 1) % self.period;
1055
1056 self.prev2_ssf = self.prev1_ssf;
1057 self.prev1_ssf = cur;
1058 self.last_raw = x;
1059 self.n_ssf += 1;
1060
1061 if will_emit {
1062 Some(out_val)
1063 } else {
1064 None
1065 }
1066 }
1067}
1068
1069#[inline(always)]
1070pub fn trendflex_batch_inner_into(
1071 data: &[f64],
1072 sweep: &TrendFlexBatchRange,
1073 kern: Kernel,
1074 parallel: bool,
1075 out: &mut [f64],
1076) -> Result<Vec<TrendFlexParams>, TrendFlexError> {
1077 let combos = expand_grid(sweep)?;
1078 if combos.is_empty() {
1079 return Err(TrendFlexError::InvalidRange {
1080 start: sweep.period.0,
1081 end: sweep.period.1,
1082 step: sweep.period.2,
1083 });
1084 }
1085
1086 let first = data
1087 .iter()
1088 .position(|x| !x.is_nan())
1089 .ok_or(TrendFlexError::AllValuesNaN)?;
1090 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1091 if data.len() - first < max_p {
1092 return Err(TrendFlexError::TrendFlexPeriodExceedsData {
1093 period: max_p,
1094 data_len: data.len() - first,
1095 });
1096 }
1097
1098 let rows = combos.len();
1099 let cols = data.len();
1100 let expected = rows
1101 .checked_mul(cols)
1102 .ok_or(TrendFlexError::DimensionsOverflow { rows, cols })?;
1103 if out.len() != expected {
1104 return Err(TrendFlexError::OutputLengthMismatch {
1105 expected,
1106 got: out.len(),
1107 });
1108 }
1109
1110 let warm: Vec<usize> = combos.iter().map(|c| first + c.period.unwrap()).collect();
1111
1112 for (row, &warmup) in warm.iter().enumerate() {
1113 let start = row * cols;
1114 let end = start + warmup;
1115 out[start..end].fill(f64::NAN);
1116 }
1117
1118 let actual_kern = match kern {
1119 Kernel::Auto => detect_best_batch_kernel(),
1120 k => k,
1121 };
1122
1123 let do_row = |row: usize, out_row: &mut [f64]| unsafe {
1124 let period = combos[row].period.unwrap();
1125
1126 match actual_kern {
1127 Kernel::Scalar | Kernel::ScalarBatch => {
1128 trendflex_row_scalar(data, first, period, out_row)
1129 }
1130 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1131 Kernel::Avx2 | Kernel::Avx2Batch => trendflex_row_avx2(data, first, period, out_row),
1132 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1133 Kernel::Avx512 | Kernel::Avx512Batch => {
1134 trendflex_row_avx512(data, first, period, out_row)
1135 }
1136 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
1137 Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
1138 trendflex_row_scalar(data, first, period, out_row)
1139 }
1140 Kernel::Auto => unreachable!("Auto kernel should have been resolved"),
1141 }
1142 };
1143
1144 if parallel {
1145 #[cfg(not(target_arch = "wasm32"))]
1146 {
1147 use rayon::prelude::*;
1148 out.par_chunks_mut(cols)
1149 .enumerate()
1150 .for_each(|(row, slice)| do_row(row, slice));
1151 }
1152
1153 #[cfg(target_arch = "wasm32")]
1154 {
1155 for (row, slice) in out.chunks_mut(cols).enumerate() {
1156 do_row(row, slice);
1157 }
1158 }
1159 } else {
1160 for (row, slice) in out.chunks_mut(cols).enumerate() {
1161 do_row(row, slice);
1162 }
1163 }
1164
1165 Ok(combos)
1166}
1167
/// Period sweep for batch runs, as an inclusive `(start, end, step)` triple.
///
/// `step == 0` or `start == end` means the single period `start`;
/// `start > end` sweeps downward.
#[derive(Clone, Debug)]
pub struct TrendFlexBatchRange {
    pub period: (usize, usize, usize),
}

impl Default for TrendFlexBatchRange {
    /// Periods 20 through 269 in steps of 1 (250 combinations).
    fn default() -> Self {
        Self {
            period: (20, 269, 1),
        }
    }
}

/// Fluent builder for batch (period-sweep) TrendFlex runs.
#[derive(Clone, Debug, Default)]
pub struct TrendFlexBatchBuilder {
    /// Periods to sweep.
    range: TrendFlexBatchRange,
    /// Batch kernel passed to `trendflex_batch_with_kernel`.
    kernel: Kernel,
}
1186
1187impl TrendFlexBatchBuilder {
1188 pub fn new() -> Self {
1189 Self::default()
1190 }
1191 pub fn kernel(mut self, k: Kernel) -> Self {
1192 self.kernel = k;
1193 self
1194 }
1195 #[inline]
1196 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
1197 self.range.period = (start, end, step);
1198 self
1199 }
1200 #[inline]
1201 pub fn period_static(mut self, p: usize) -> Self {
1202 self.range.period = (p, p, 0);
1203 self
1204 }
1205
1206 pub fn apply_slice(self, data: &[f64]) -> Result<TrendFlexBatchOutput, TrendFlexError> {
1207 trendflex_batch_with_kernel(data, &self.range, self.kernel)
1208 }
1209 pub fn with_default_slice(
1210 data: &[f64],
1211 k: Kernel,
1212 ) -> Result<TrendFlexBatchOutput, TrendFlexError> {
1213 TrendFlexBatchBuilder::new().kernel(k).apply_slice(data)
1214 }
1215 pub fn apply_candles(
1216 self,
1217 c: &Candles,
1218 src: &str,
1219 ) -> Result<TrendFlexBatchOutput, TrendFlexError> {
1220 let slice = source_type(c, src);
1221 self.apply_slice(slice)
1222 }
1223 pub fn with_default_candles(c: &Candles) -> Result<TrendFlexBatchOutput, TrendFlexError> {
1224 TrendFlexBatchBuilder::new()
1225 .kernel(Kernel::Auto)
1226 .apply_candles(c, "close")
1227 }
1228}
1229
1230pub fn trendflex_batch_with_kernel(
1231 data: &[f64],
1232 sweep: &TrendFlexBatchRange,
1233 k: Kernel,
1234) -> Result<TrendFlexBatchOutput, TrendFlexError> {
1235 let kernel = match k {
1236 Kernel::Auto => detect_best_batch_kernel(),
1237 other if other.is_batch() => other,
1238 _ => return Err(TrendFlexError::InvalidKernelForBatch(k)),
1239 };
1240
1241 let simd = match kernel {
1242 Kernel::Avx512Batch => Kernel::Avx512,
1243 Kernel::Avx2Batch => Kernel::Avx2,
1244 Kernel::ScalarBatch => Kernel::Scalar,
1245 _ => unreachable!(),
1246 };
1247 trendflex_batch_par_slice(data, sweep, simd)
1248}
1249
/// Result of a batch (period-sweep) TrendFlex run.
#[derive(Clone, Debug)]
pub struct TrendFlexBatchOutput {
    /// Row-major `rows x cols` matrix: one row per entry of `combos`.
    pub values: Vec<f64>,
    /// Parameter combination for each row, in row order.
    pub combos: Vec<TrendFlexParams>,
    /// Number of rows (parameter combinations).
    pub rows: usize,
    /// Number of columns (input samples).
    pub cols: usize,
}
1257
1258impl TrendFlexBatchOutput {
1259 pub fn row_for_params(&self, p: &TrendFlexParams) -> Option<usize> {
1260 self.combos
1261 .iter()
1262 .position(|c| c.period.unwrap_or(20) == p.period.unwrap_or(20))
1263 }
1264 pub fn values_for(&self, p: &TrendFlexParams) -> Option<&[f64]> {
1265 self.row_for_params(p).map(|row| {
1266 let start = row * self.cols;
1267 &self.values[start..start + self.cols]
1268 })
1269 }
1270}
1271
1272#[inline(always)]
1273fn expand_grid(r: &TrendFlexBatchRange) -> Result<Vec<TrendFlexParams>, TrendFlexError> {
1274 fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, TrendFlexError> {
1275 if step == 0 || start == end {
1276 return Ok(vec![start]);
1277 }
1278 if start < end {
1279 let v: Vec<usize> = (start..=end).step_by(step).collect();
1280 if v.is_empty() {
1281 return Err(TrendFlexError::InvalidRange { start, end, step });
1282 }
1283 return Ok(v);
1284 }
1285
1286 let mut v = Vec::new();
1287 let mut cur = start;
1288 while cur >= end {
1289 v.push(cur);
1290 if let Some(next) = cur.checked_sub(step) {
1291 cur = next;
1292 } else {
1293 break;
1294 }
1295 if cur == usize::MAX {
1296 break;
1297 }
1298 }
1299 if v.is_empty() {
1300 return Err(TrendFlexError::InvalidRange { start, end, step });
1301 }
1302 Ok(v)
1303 }
1304
1305 let periods = axis_usize(r.period)?;
1306 let mut out = Vec::with_capacity(periods.len());
1307 for &p in &periods {
1308 out.push(TrendFlexParams { period: Some(p) });
1309 }
1310 Ok(out)
1311}
1312
1313#[inline(always)]
1314pub fn expand_grid_trendflex(r: &TrendFlexBatchRange) -> Vec<TrendFlexParams> {
1315 expand_grid(r).unwrap_or_default()
1316}
1317
1318#[inline(always)]
1319pub fn expand_grid_trendflex_checked(
1320 r: &TrendFlexBatchRange,
1321) -> Result<Vec<TrendFlexParams>, TrendFlexError> {
1322 expand_grid(r)
1323}
1324
1325#[inline(always)]
1326pub fn trendflex_batch_slice(
1327 data: &[f64],
1328 sweep: &TrendFlexBatchRange,
1329 kern: Kernel,
1330) -> Result<TrendFlexBatchOutput, TrendFlexError> {
1331 trendflex_batch_inner(data, sweep, kern, false)
1332}
1333#[inline(always)]
1334pub fn trendflex_batch_par_slice(
1335 data: &[f64],
1336 sweep: &TrendFlexBatchRange,
1337 kern: Kernel,
1338) -> Result<TrendFlexBatchOutput, TrendFlexError> {
1339 trendflex_batch_inner(data, sweep, kern, true)
1340}
1341
1342#[inline(always)]
1343fn trendflex_batch_inner(
1344 data: &[f64],
1345 sweep: &TrendFlexBatchRange,
1346 kern: Kernel,
1347 parallel: bool,
1348) -> Result<TrendFlexBatchOutput, TrendFlexError> {
1349 let combos = expand_grid(sweep)?;
1350 if combos.is_empty() {
1351 return Err(TrendFlexError::InvalidRange {
1352 start: sweep.period.0,
1353 end: sweep.period.1,
1354 step: sweep.period.2,
1355 });
1356 }
1357 let first = data
1358 .iter()
1359 .position(|x| !x.is_nan())
1360 .ok_or(TrendFlexError::AllValuesNaN)?;
1361 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1362 if data.len() - first < max_p {
1363 return Err(TrendFlexError::TrendFlexPeriodExceedsData {
1364 period: max_p,
1365 data_len: data.len() - first,
1366 });
1367 }
1368 let rows = combos.len();
1369 let cols = data.len();
1370
1371 rows.checked_mul(cols)
1372 .ok_or(TrendFlexError::DimensionsOverflow { rows, cols })?;
1373
1374 let warm: Vec<usize> = combos.iter().map(|c| first + c.period.unwrap()).collect();
1375 let mut raw = make_uninit_matrix(rows, cols);
1376
1377 unsafe {
1378 init_matrix_prefixes(&mut raw, cols, &warm);
1379 }
1380
1381 let actual_kern = match kern {
1382 Kernel::Auto => detect_best_batch_kernel(),
1383 k => k,
1384 };
1385
1386 let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
1387 let period = combos[row].period.unwrap();
1388
1389 let out_row =
1390 core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
1391
1392 match actual_kern {
1393 Kernel::Scalar | Kernel::ScalarBatch => {
1394 trendflex_row_scalar(data, first, period, out_row)
1395 }
1396 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1397 Kernel::Avx2 | Kernel::Avx2Batch => trendflex_row_avx2(data, first, period, out_row),
1398 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1399 Kernel::Avx512 | Kernel::Avx512Batch => {
1400 trendflex_row_avx512(data, first, period, out_row)
1401 }
1402 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
1403 Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
1404 trendflex_row_scalar(data, first, period, out_row)
1405 }
1406 Kernel::Auto => unreachable!("Auto kernel should have been resolved"),
1407 }
1408 };
1409
1410 if parallel {
1411 #[cfg(not(target_arch = "wasm32"))]
1412 {
1413 raw.par_chunks_mut(cols)
1414 .enumerate()
1415 .for_each(|(row, slice)| do_row(row, slice));
1416 }
1417
1418 #[cfg(target_arch = "wasm32")]
1419 {
1420 for (row, slice) in raw.chunks_mut(cols).enumerate() {
1421 do_row(row, slice);
1422 }
1423 }
1424 } else {
1425 for (row, slice) in raw.chunks_mut(cols).enumerate() {
1426 do_row(row, slice);
1427 }
1428 }
1429
1430 use core::mem::ManuallyDrop;
1431 let mut guard = ManuallyDrop::new(raw);
1432 let values: Vec<f64> = unsafe {
1433 Vec::from_raw_parts(
1434 guard.as_mut_ptr() as *mut f64,
1435 guard.len(),
1436 guard.capacity(),
1437 )
1438 };
1439
1440 Ok(TrendFlexBatchOutput {
1441 values,
1442 combos,
1443 rows,
1444 cols,
1445 })
1446}
1447
1448#[inline(always)]
1449unsafe fn trendflex_row_scalar(data: &[f64], first: usize, period: usize, out_row: &mut [f64]) {
1450 let ss_period = ((period as f64) / 2.0).round() as usize;
1451 let _ = trendflex_scalar_into(data, period, ss_period, first, out_row);
1452}
/// Computes one batch row with the AVX2 kernel.
///
/// The super-smoother length passed down is `period / 2` rounded with
/// `f64::round`. Any result returned by `trendflex_avx2_into` is ignored.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn trendflex_row_avx2(data: &[f64], first: usize, period: usize, out_row: &mut [f64]) {
    let half_period = (period as f64 * 0.5).round() as usize;
    let _ = trendflex_avx2_into(data, period, half_period, first, out_row);
}
/// Computes one batch row with the AVX-512 kernel.
///
/// The super-smoother length passed down is `period / 2` rounded with
/// `f64::round`. Any result returned by `trendflex_avx512_into` is ignored.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn trendflex_row_avx512(data: &[f64], first: usize, period: usize, out_row: &mut [f64]) {
    let half_period = (period as f64 * 0.5).round() as usize;
    let _ = trendflex_avx512_into(data, period, half_period, first, out_row);
}
1465
// Unit tests for TrendFlex: single-series kernels, streaming, the `*_into`
// helpers and the batch API. Most checks load the shared 4h Bitfinex candle CSV
// from `src/data/`, so they depend on that fixture being present.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::skip_if_unsupported;
    use crate::utilities::data_loader::read_candles_from_csv;

    // `period: None` must still run on candle closes and produce one output per candle.
    fn check_trendflex_partial_params(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let default_params = TrendFlexParams { period: None };
        let input = TrendFlexInput::from_candles(&candles, "close", default_params);
        let output = trendflex_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());

        Ok(())
    }

    // Reference check: the last five period-20 outputs on closes must match the
    // stored values within an absolute tolerance of 1e-8.
    fn check_trendflex_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let params = TrendFlexParams { period: Some(20) };
        let input = TrendFlexInput::from_candles(&candles, "close", params);
        let result = trendflex_with_kernel(&input, kernel)?;
        let expected_last_five = [
            -0.19724678008015128,
            -0.1238001236481444,
            -0.10515389737087717,
            -0.1149541079904878,
            -0.16006869484450567,
        ];
        let start = result.values.len().saturating_sub(5);
        for (i, &val) in result.values[start..].iter().enumerate() {
            let diff = (val - expected_last_five[i]).abs();
            assert!(
                diff < 1e-8,
                "[{}] TrendFlex {:?} mismatch at idx {}: got {}, expected {}",
                test_name,
                kernel,
                i,
                val,
                expected_last_five[i]
            );
        }
        Ok(())
    }

    // `with_default_candles` must select the "close" source and produce a full-length output.
    fn check_trendflex_default_candles(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let input = TrendFlexInput::with_default_candles(&candles);
        match input.data {
            TrendFlexData::Candles { source, .. } => assert_eq!(source, "close"),
            _ => panic!("Expected TrendFlexData::Candles"),
        }
        let output = trendflex_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());

        Ok(())
    }

    // Period 0 must be rejected.
    fn check_trendflex_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let input_data = [10.0, 20.0, 30.0];
        let params = TrendFlexParams { period: Some(0) };
        let input = TrendFlexInput::from_slice(&input_data, params);
        let res = trendflex_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] TrendFlex should fail with zero period",
            test_name
        );
        Ok(())
    }

    // A period longer than the input must be rejected.
    fn check_trendflex_period_exceeds_length(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let data_small = [10.0, 20.0, 30.0];
        let params = TrendFlexParams { period: Some(10) };
        let input = TrendFlexInput::from_slice(&data_small, params);
        let res = trendflex_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] TrendFlex should fail with period exceeding length",
            test_name
        );
        Ok(())
    }

    // A single data point is insufficient for period 9 and must be rejected.
    fn check_trendflex_very_small_dataset(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let single_point = [42.0];
        let params = TrendFlexParams { period: Some(9) };
        let input = TrendFlexInput::from_slice(&single_point, params);
        let res = trendflex_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] TrendFlex should fail with insufficient data",
            test_name
        );
        Ok(())
    }

    // Feeds the period-20 output back in as period-10 input (leading NaNs
    // included); from index 240 onward the second pass must be NaN-free.
    fn check_trendflex_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let first_params = TrendFlexParams { period: Some(20) };
        let first_input = TrendFlexInput::from_candles(&candles, "close", first_params);
        let first_result = trendflex_with_kernel(&first_input, kernel)?;

        let second_params = TrendFlexParams { period: Some(10) };
        let second_input = TrendFlexInput::from_slice(&first_result.values, second_params);
        let second_result = trendflex_with_kernel(&second_input, kernel)?;

        assert_eq!(second_result.values.len(), first_result.values.len());
        if second_result.values.len() > 240 {
            for (i, &val) in second_result.values[240..].iter().enumerate() {
                assert!(
                    !val.is_nan(),
                    "[{}] Found unexpected NaN at out-index {}",
                    test_name,
                    240 + i
                );
            }
        }
        Ok(())
    }

    // From index 240 onward the period-20 output on closes must be NaN-free.
    fn check_trendflex_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let input =
            TrendFlexInput::from_candles(&candles, "close", TrendFlexParams { period: Some(20) });
        let res = trendflex_with_kernel(&input, kernel)?;
        assert_eq!(res.values.len(), candles.close.len());
        if res.values.len() > 240 {
            for (i, &val) in res.values[240..].iter().enumerate() {
                assert!(
                    !val.is_nan(),
                    "[{}] Found unexpected NaN at out-index {}",
                    test_name,
                    240 + i
                );
            }
        }
        Ok(())
    }

    // TrendFlexStream fed bar-by-bar must match the one-shot output within 1e-9.
    // A `None` from the stream is recorded as NaN and compared as such.
    fn check_trendflex_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);

        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let period = 20;

        let input = TrendFlexInput::from_candles(
            &candles,
            "close",
            TrendFlexParams {
                period: Some(period),
            },
        );
        let batch_output = trendflex_with_kernel(&input, kernel)?.values;

        let mut stream = TrendFlexStream::try_new(TrendFlexParams {
            period: Some(period),
        })?;

        let mut stream_values = Vec::with_capacity(candles.close.len());
        for &price in &candles.close {
            match stream.update(price) {
                Some(tf_val) => stream_values.push(tf_val),
                None => stream_values.push(f64::NAN),
            }
        }

        assert_eq!(batch_output.len(), stream_values.len());
        for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
            if b.is_nan() && s.is_nan() {
                continue;
            }
            let diff = (b - s).abs();
            assert!(
                diff < 1e-9,
                "[{}] TrendFlex streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
                test_name,
                i,
                b,
                s,
                diff
            );
        }
        Ok(())
    }

    // Generates `<check>_scalar_f64`, `<check>_avx2_f64` and `<check>_avx512_f64`
    // #[test] wrappers for each single-series check.
    //
    // NOTE(review): `let _ =` discards the check's `Result`. A check that returns
    // `Err` (missing CSV, a `?`-propagated TrendFlexError, or a proptest failure
    // converted by `map_err`) therefore does not fail the test; only panics do.
    // Consider `.unwrap()`.
    // NOTE(review): an outer `#[cfg(...)]` written before a `$(...)*` repetition
    // attaches only to the first generated item, so it appears that only the
    // first `_avx2_f64` wrapper is feature-gated. The rest rely on
    // `skip_if_unsupported!` — confirm this is intended.
    macro_rules! generate_all_trendflex_tests {
        ($($test_fn:ident),*) => {
            paste::paste! {
                $(
                    #[test]
                    fn [<$test_fn _scalar_f64>]() {
                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
                    }
                )*
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                $(
                    #[test]
                    fn [<$test_fn _avx2_f64>]() {
                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
                    }
                    #[test]
                    fn [<$test_fn _avx512_f64>]() {
                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
                    }
                )*
            }
        }
    }

    // Debug builds only: scan outputs for several periods for the bit patterns
    // the panic messages attribute to `alloc_with_nan_prefix`,
    // `init_matrix_prefixes` and `make_uninit_matrix`. Such a pattern suggests
    // an output slot the kernel never wrote.
    #[cfg(debug_assertions)]
    fn check_trendflex_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);

        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let test_periods = vec![5, 10, 20, 30, 50, 80, 100, 150];

        for &period in &test_periods {
            let params = TrendFlexParams {
                period: Some(period),
            };
            let input = TrendFlexInput::from_candles(&candles, "close", params);

            if candles.close.len() < period {
                continue;
            }

            // Invalid configurations are skipped rather than reported.
            let output = match trendflex_with_kernel(&input, kernel) {
                Ok(o) => o,
                Err(_) => continue,
            };

            for (i, &val) in output.values.iter().enumerate() {
                if val.is_nan() {
                    continue;
                }

                let bits = val.to_bits();

                if bits == 0x11111111_11111111 {
                    panic!(
                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} with period {}",
                        test_name, val, bits, i, period
                    );
                }

                if bits == 0x22222222_22222222 {
                    panic!(
                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} with period {}",
                        test_name, val, bits, i, period
                    );
                }

                if bits == 0x33333333_33333333 {
                    panic!(
                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} with period {}",
                        test_name, val, bits, i, period
                    );
                }
            }
        }

        Ok(())
    }

    // Release builds: the poison scan is a no-op.
    #[cfg(not(debug_assertions))]
    fn check_trendflex_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    // Property test over random finite series (length period..400, period 1..=64):
    // - output length and NaN warm-up of `first + period`;
    // - finite values after warm-up;
    // - approximate scale invariance (x10 input);
    // - sign tendency on monotone stretches;
    // - near-zero output for constant input;
    // - agreement with the scalar kernel within 1e-9.
    // Failures surface as `Err` through `map_err`, which the wrapper macro
    // currently discards (see the NOTE on `generate_all_trendflex_tests`).
    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_trendflex_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        let strat = (1usize..=64).prop_flat_map(|period| {
            (
                prop::collection::vec(
                    (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
                    period..400,
                ),
                Just(period),
            )
        });

        proptest::test_runner::TestRunner::default()
            .run(&strat, |(data, period)| {
                let input = TrendFlexInput::from_slice(
                    &data,
                    TrendFlexParams {
                        period: Some(period),
                    },
                );
                let output = trendflex_with_kernel(&input, kernel)?;

                prop_assert_eq!(output.values.len(), data.len(), "Output length mismatch");

                let first = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
                let warmup = first + period;

                for i in 0..warmup.min(data.len()) {
                    prop_assert!(
                        output.values[i].is_nan(),
                        "Expected NaN in warmup period at index {}, got {}",
                        i,
                        output.values[i]
                    );
                }

                for i in warmup..output.values.len() {
                    prop_assert!(
                        output.values[i].is_finite(),
                        "Output at index {} is not finite: {}",
                        i,
                        output.values[i]
                    );
                }

                // Scaling the input by 10 should leave >90% of outputs within 0.5.
                if data.len() > warmup + 10 {
                    let scale_factor = 10.0;
                    let scaled_data: Vec<f64> = data.iter().map(|&x| x * scale_factor).collect();
                    let scaled_input = TrendFlexInput::from_slice(
                        &scaled_data,
                        TrendFlexParams {
                            period: Some(period),
                        },
                    );
                    let scaled_output = trendflex_with_kernel(&scaled_input, kernel)?;

                    let mut similarity_count = 0;
                    let mut total_compared = 0;
                    for i in warmup..output.values.len() {
                        if output.values[i].is_finite() && scaled_output.values[i].is_finite() {
                            let diff = (output.values[i] - scaled_output.values[i]).abs();

                            if diff < 0.5 {
                                similarity_count += 1;
                            }
                            total_compared += 1;
                        }
                    }

                    if total_compared > 0 {
                        let similarity_ratio = similarity_count as f64 / total_compared as f64;
                        prop_assert!(
                            similarity_ratio > 0.9,
                            "Scale invariance failed: only {:.1}% of values are similar after scaling",
                            similarity_ratio * 100.0
                        );
                    }
                }

                // Strictly monotone stretches right after warm-up should bias the sign.
                if data.len() > warmup + 20 {
                    let mut is_increasing = true;
                    let mut is_decreasing = true;
                    for i in (warmup + 1)..data.len().min(warmup + 50) {
                        if data[i] <= data[i - 1] {
                            is_increasing = false;
                        }
                        if data[i] >= data[i - 1] {
                            is_decreasing = false;
                        }
                    }

                    if is_increasing {
                        let positive_count =
                            output.values[warmup..].iter().filter(|&&v| v > 0.0).count();
                        let total = output.values.len() - warmup;
                        let positive_ratio = positive_count as f64 / total as f64;
                        prop_assert!(
                            positive_ratio > 0.7,
                            "Increasing trend should produce mostly positive values, got {:.1}% positive",
                            positive_ratio * 100.0
                        );
                    } else if is_decreasing {
                        let negative_count =
                            output.values[warmup..].iter().filter(|&&v| v < 0.0).count();
                        let total = output.values.len() - warmup;
                        let negative_ratio = negative_count as f64 / total as f64;
                        prop_assert!(
                            negative_ratio > 0.7,
                            "Decreasing trend should produce mostly negative values, got {:.1}% negative",
                            negative_ratio * 100.0
                        );
                    }
                }

                // Constant input should settle near zero.
                let all_same = data[first..]
                    .windows(2)
                    .all(|w| (w[0] - w[1]).abs() < 1e-10);
                if all_same && data.len() > warmup + 10 {
                    let last_values = &output.values[(data.len() - 5)..];
                    for val in last_values {
                        prop_assert!(
                            val.abs() < 0.1,
                            "Constant input should produce values near 0, got {}",
                            val
                        );
                    }
                }

                if period == 1 {
                    for i in (first + 1)..output.values.len() {
                        prop_assert!(
                            output.values[i].is_finite(),
                            "Period=1 should still produce finite values at index {}",
                            i
                        );
                    }
                }

                if data.len() > 5 && period >= data.len().saturating_sub(5) && data.len() > period {
                    let last_idx = data.len() - 1;
                    if last_idx >= warmup {
                        prop_assert!(
                            output.values[last_idx].is_finite(),
                            "Large period should still produce finite values at the end"
                        );
                    }
                }

                // Cross-kernel agreement with the scalar reference.
                if cfg!(all(feature = "nightly-avx", target_arch = "x86_64")) {
                    let scalar_output = trendflex_with_kernel(&input, Kernel::Scalar)?;

                    for i in 0..output.values.len() {
                        if output.values[i].is_finite() && scalar_output.values[i].is_finite() {
                            prop_assert!(
                                (output.values[i] - scalar_output.values[i]).abs() < 1e-9,
                                "Kernel consistency failed at index {}: {} vs {}",
                                i,
                                output.values[i],
                                scalar_output.values[i]
                            );
                        } else {
                            prop_assert_eq!(
                                output.values[i].is_nan(),
                                scalar_output.values[i].is_nan(),
                                "NaN mismatch between kernels at index {}",
                                i
                            );
                        }
                    }
                }

                Ok(())
            })
            .map_err(|e| e.into())
    }

    #[cfg(feature = "proptest")]
    generate_all_trendflex_tests!(check_trendflex_property);

    // `trendflex_into_slice` must report the specific error variant for each
    // invalid input, and succeed on a valid one.
    #[test]
    fn test_trendflex_into_slice_validation() {
        let data = vec![1.0, 2.0, 3.0];
        let params = TrendFlexParams { period: Some(10) };
        let input = TrendFlexInput::from_slice(&data, params);
        let mut out = vec![0.0; data.len()];

        let result = trendflex_into_slice(&mut out, &input, Kernel::Scalar);
        assert!(result.is_err());
        match result {
            Err(TrendFlexError::TrendFlexPeriodExceedsData { period, data_len }) => {
                assert_eq!(period, 10);
                assert_eq!(data_len, 3);
            }
            _ => panic!("Expected TrendFlexPeriodExceedsData error"),
        }

        let empty_data: Vec<f64> = vec![];
        let params = TrendFlexParams { period: Some(5) };
        let input = TrendFlexInput::from_slice(&empty_data, params);
        let mut out = vec![];

        let result = trendflex_into_slice(&mut out, &input, Kernel::Scalar);
        assert!(result.is_err());
        match result {
            Err(TrendFlexError::NoDataProvided) => {}
            _ => panic!("Expected NoDataProvided error"),
        }

        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let params = TrendFlexParams { period: Some(0) };
        let input = TrendFlexInput::from_slice(&data, params);
        let mut out = vec![0.0; data.len()];

        let result = trendflex_into_slice(&mut out, &input, Kernel::Scalar);
        assert!(result.is_err());
        match result {
            Err(TrendFlexError::ZeroTrendFlexPeriod { period }) => {
                assert_eq!(period, 0);
            }
            _ => panic!("Expected ZeroTrendFlexPeriod error"),
        }

        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let params = TrendFlexParams { period: Some(3) };
        let input = TrendFlexInput::from_slice(&data, params);
        let mut out = vec![0.0; data.len()];

        let result = trendflex_into_slice(&mut out, &input, Kernel::Scalar);
        assert!(result.is_ok());
    }

    // `trendflex_into` must reproduce `trendflex` (NaN positions equal, values within 1e-12)
    // on a synthetic trend-plus-sine series.
    #[test]
    #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
    fn test_trendflex_into_matches_api() -> Result<(), Box<dyn Error>> {
        let n = 512usize;
        let mut data = Vec::with_capacity(n);
        for i in 0..n {
            let t = i as f64;
            data.push(0.01 * t + (t * 0.05).sin());
        }

        let input = TrendFlexInput::from_slice(&data, TrendFlexParams::default());
        let baseline = trendflex(&input)?.values;

        let mut out = vec![0.0f64; n];
        trendflex_into(&input, &mut out)?;

        assert_eq!(baseline.len(), out.len());
        for i in 0..n {
            let a = baseline[i];
            let b = out[i];
            let equal = if a.is_nan() && b.is_nan() {
                true
            } else {
                (a - b).abs() <= 1e-12
            };
            assert!(equal, "divergence at {}: {} vs {}", i, a, b);
        }
        Ok(())
    }

    // The batch entry point rejects single-series kernels with
    // `InvalidKernelForBatch` and accepts `ScalarBatch`.
    #[test]
    fn test_trendflex_batch_kernel_policy() {
        let data = vec![1.0; 50];
        let sweep = TrendFlexBatchRange { period: (5, 10, 1) };

        let result_scalar = trendflex_batch_with_kernel(&data, &sweep, Kernel::Scalar);
        assert!(matches!(
            result_scalar,
            Err(TrendFlexError::InvalidKernelForBatch(Kernel::Scalar))
        ));

        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
        {
            let result_avx2 = trendflex_batch_with_kernel(&data, &sweep, Kernel::Avx2);
            assert!(matches!(
                result_avx2,
                Err(TrendFlexError::InvalidKernelForBatch(Kernel::Avx2))
            ));

            let result_avx512 = trendflex_batch_with_kernel(&data, &sweep, Kernel::Avx512);
            assert!(matches!(
                result_avx512,
                Err(TrendFlexError::InvalidKernelForBatch(Kernel::Avx512))
            ));
        }

        let result_scalar_batch = trendflex_batch_with_kernel(&data, &sweep, Kernel::ScalarBatch);
        assert!(result_scalar_batch.is_ok());
    }

    generate_all_trendflex_tests!(
        check_trendflex_partial_params,
        check_trendflex_accuracy,
        check_trendflex_default_candles,
        check_trendflex_zero_period,
        check_trendflex_period_exceeds_length,
        check_trendflex_very_small_dataset,
        check_trendflex_reinput,
        check_trendflex_nan_handling,
        check_trendflex_streaming,
        check_trendflex_no_poison
    );

    // The builder's default sweep must contain the default-params row, and that
    // row must match the single-series reference values.
    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);

        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;

        let output = TrendFlexBatchBuilder::new()
            .kernel(kernel)
            .apply_candles(&c, "close")?;

        let def = TrendFlexParams::default();
        let row = output.values_for(&def).expect("default row missing");

        assert_eq!(row.len(), c.close.len());

        let expected = [
            -0.19724678008015128,
            -0.1238001236481444,
            -0.10515389737087717,
            -0.1149541079904878,
            -0.16006869484450567,
        ];
        let start = row.len() - 5;
        for (i, &v) in row[start..].iter().enumerate() {
            assert!(
                (v - expected[i]).abs() < 1e-8,
                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
            );
        }
        Ok(())
    }

    // Generates `_scalar`, `_avx2`, `_avx512` and `_auto_detect` #[test] wrappers
    // for a batch check. NOTE(review): as above, `let _ =` discards `Err` results.
    macro_rules! gen_batch_tests {
        ($fn_name:ident) => {
            paste::paste! {
                #[test] fn [<$fn_name _scalar>]() {
                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx2>]() {
                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx512>]() {
                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
                }
                #[test] fn [<$fn_name _auto_detect>]() {
                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
                }
            }
        };
    }

    // Debug builds only: batch variant of the poison scan across several sweeps.
    // Each hit is reported with its row, column and period.
    #[cfg(debug_assertions)]
    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);

        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;

        let test_configs = vec![
            (5, 20, 3),
            (10, 50, 5),
            (20, 100, 10),
            (30, 120, 15),
            (7, 7, 1),
            (80, 80, 1),
            (15, 45, 5),
        ];

        for (start, end, step) in test_configs {
            let output = TrendFlexBatchBuilder::new()
                .kernel(kernel)
                .period_range(start, end, step)
                .apply_candles(&c, "close")?;

            for (idx, &val) in output.values.iter().enumerate() {
                if val.is_nan() {
                    continue;
                }

                let bits = val.to_bits();
                let row = idx / output.cols;
                let col = idx % output.cols;
                let period = output
                    .combos
                    .get(row)
                    .map(|p| p.period.unwrap_or(0))
                    .unwrap_or(0);

                if bits == 0x11111111_11111111 {
                    panic!(
                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at row {} col {} (period {}, flat index {})",
                        test, val, bits, row, col, period, idx
                    );
                }

                if bits == 0x22222222_22222222 {
                    panic!(
                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at row {} col {} (period {}, flat index {})",
                        test, val, bits, row, col, period, idx
                    );
                }

                if bits == 0x33333333_33333333 {
                    panic!(
                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at row {} col {} (period {}, flat index {})",
                        test, val, bits, row, col, period, idx
                    );
                }
            }
        }

        Ok(())
    }

    // Release builds: the batch poison scan is a no-op.
    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
}
2204
2205#[cfg(feature = "python")]
2206use pyo3::exceptions::PyValueError;
2207#[cfg(feature = "python")]
2208use pyo3::prelude::*;
2209
#[cfg(feature = "python")]
#[pyfunction(name = "trendflex")]
#[pyo3(signature = (data, period=None, kernel=None))]
/// Python binding: TrendFlex of a 1-D float64 array.
///
/// `period` is forwarded unchanged in `TrendFlexParams`. `kernel` is checked
/// with `validate_kernel(kernel, false)`. The computation runs with the GIL
/// released, and computation errors are raised as `ValueError`.
pub fn trendflex_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period: Option<usize>,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let prices = data.as_slice()?;
    let kern = validate_kernel(kernel, false)?;
    let input = TrendFlexInput::from_slice(prices, TrendFlexParams { period });

    let computed = py.allow_threads(|| trendflex_with_kernel(&input, kern));
    let output = computed.map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok(output.values.into_pyarray(py))
}
2234
#[cfg(feature = "python")]
#[pyclass(name = "TrendFlexStream")]
/// Python wrapper around the incremental Rust `TrendFlexStream`.
pub struct TrendFlexStreamPy {
    // Wrapped stream state, advanced in place by `update`.
    stream: TrendFlexStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl TrendFlexStreamPy {
    /// Creates a stream. `period` is forwarded unchanged in `TrendFlexParams`;
    /// construction errors are raised as `ValueError`.
    #[new]
    fn new(period: Option<usize>) -> PyResult<Self> {
        TrendFlexStream::try_new(TrendFlexParams { period })
            .map(|stream| TrendFlexStreamPy { stream })
            .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feeds one sample and returns the stream's output for it, or `None` when
    /// the stream yields no value for this sample.
    fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
2256
#[cfg(feature = "python")]
#[pyfunction(name = "trendflex_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
/// Python binding: TrendFlex for every period in `period_range`.
///
/// `period_range` is `(start, end, step)`. Returns a dict with two entries:
/// - `values`: a `(rows, cols)` float64 array, one row per period, with
///   `cols == len(data)`.
/// - `periods`: the period of each row.
///
/// Range-expansion and computation errors are raised as `ValueError`. The
/// computation runs with the GIL released.
pub fn trendflex_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, true)?;

    let sweep = TrendFlexBatchRange {
        period: period_range,
    };

    let combos = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let rows = combos.len();
    let cols = slice_in.len();
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("dimensions overflow"))?;

    // SAFETY: the array is allocated uninitialised and only handed to Python
    // after `trendflex_batch_inner_into` returns `Ok`. NOTE(review): that
    // function is assumed to write every element; confirm there.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    // SAFETY: `out_arr` was created just above, so no other view of it exists.
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| -> Result<Vec<TrendFlexParams>, TrendFlexError> {
            let kernel = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k => k,
            };
            // Map the batch kernel to its per-row counterpart. A non-batch
            // kernel is reported as an error instead of panicking while the
            // GIL is released, matching `trendflex_batch_with_kernel`.
            let simd = match kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                other => return Err(TrendFlexError::InvalidKernelForBatch(other)),
            };

            trendflex_batch_inner_into(slice_in, &sweep, simd, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;

    dict.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap_or(20) as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;

    Ok(dict)
}
2317
/// Python binding: run the TrendFlex period sweep on a CUDA device and keep the result there.
///
/// `period_range` is a `(start, end, step)` tuple. Returns the device array wrapper (which also
/// stores the CUDA context `Arc` and device id) together with a dict holding `"periods"`,
/// the period of each output row.
///
/// # Errors
/// Raises `ValueError` when CUDA is unavailable, when the CUDA wrapper cannot be created,
/// when the batch launch fails, or when synchronisation fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "trendflex_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, device_id=0))]
pub fn trendflex_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    data_f32: numpy::PyReadonlyArray1<'py, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<(TrendFlexDeviceArrayF32Py, Bound<'py, pyo3::types::PyDict>)> {
    use crate::cuda::cuda_available;
    use numpy::IntoPyArray;
    use pyo3::types::PyDict;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let slice_in = data_f32.as_slice()?;
    let sweep = TrendFlexBatchRange {
        period: period_range,
    };

    // Device work runs with the GIL released; we synchronise before leaving the closure so
    // the device buffer is complete when it is handed to Python.
    let (inner, combos, ctx_arc, dev_id) = py.allow_threads(|| {
        let cuda =
            CudaTrendflex::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let (dev, combos) = cuda
            .trendflex_batch_dev(slice_in, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;

        cuda.synchronize()
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, PyErr>((dev, combos, cuda.context_arc_clone(), cuda.device_id()))
    })?;

    let dict = PyDict::new(py);
    // Fall back to 20 (the `TrendFlexParams` default) rather than panicking, matching the
    // CPU `trendflex_batch` binding.
    let periods: Vec<u64> = combos
        .iter()
        .map(|c| c.period.unwrap_or(20) as u64)
        .collect();
    dict.set_item("periods", periods.into_pyarray(py))?;

    Ok((
        TrendFlexDeviceArrayF32Py {
            inner,
            _ctx: ctx_arc,
            device_id: dev_id,
        },
        dict,
    ))
}
2365
/// Python binding: run TrendFlex with a single `period` over many series on a CUDA device.
///
/// `data_tm_f32` must be a contiguous 2-D float32 array. `shape()[1]` and `shape()[0]` are
/// passed to the time-major CUDA entry point, in that order (presumably as
/// `(num_series, series_len)`, i.e. one column per series — confirm against `CudaTrendflex`).
///
/// # Errors
/// Raises `ValueError` when CUDA is unavailable, when the CUDA wrapper cannot be created,
/// when the launch fails, or when synchronisation fails; `as_slice` errors are propagated.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "trendflex_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, device_id=0))]
pub fn trendflex_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<TrendFlexDeviceArrayF32Py> {
    use crate::cuda::cuda_available;
    use numpy::PyUntypedArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let flat_in = data_tm_f32.as_slice()?;
    let rows = data_tm_f32.shape()[0];
    let cols = data_tm_f32.shape()[1];
    let params = TrendFlexParams {
        period: Some(period),
    };

    // Device work runs with the GIL released; synchronise before returning the buffer.
    let (inner, ctx_arc, dev_id) = py.allow_threads(|| {
        let cuda =
            CudaTrendflex::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let dev = cuda
            .trendflex_multi_series_one_param_time_major_dev(flat_in, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        cuda.synchronize()
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, PyErr>((dev, cuda.context_arc_clone(), cuda.device_id()))
    })?;

    Ok(TrendFlexDeviceArrayF32Py {
        inner,
        _ctx: ctx_arc,
        device_id: dev_id,
    })
}
2406
2407#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2408use serde::{Deserialize, Serialize};
2409#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2410use wasm_bindgen::prelude::*;
2411
/// Configuration object accepted by the JS-facing `trendflex_batch` entry point.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Deserialize, Serialize)]
pub struct TrendFlexBatchConfig {
    /// `(start, end, step)` sweep over the TrendFlex period.
    pub period_range: (usize, usize, usize),
}
2417
/// Result object serialised back to JS by the `trendflex_batch` entry point.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Deserialize, Serialize)]
pub struct TrendFlexBatchJsOutput {
    /// Flattened output matrix (presumably row-major, `rows * cols` values — confirm
    /// against `trendflex_batch_inner`).
    pub values: Vec<f64>,
    /// Parameter set used for each row.
    pub combos: Vec<TrendFlexParams>,
    /// Number of rows in `values`.
    pub rows: usize,
    /// Number of columns in `values`.
    pub cols: usize,
}
2426
/// JS binding: compute TrendFlex over `data` with the given `period`.
///
/// Returns the full output series; errors from the core computation are returned to JS
/// as a string `JsValue`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn trendflex_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = TrendFlexInput::from_slice(
        data,
        TrendFlexParams {
            period: Some(period),
        },
    );

    match trendflex_with_kernel(&input, Kernel::Auto) {
        Ok(output) => Ok(output.values),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
2440
/// JS binding: run the TrendFlex period sweep and return the flattened output values.
///
/// The period grid is `(period_start, period_end, period_step)`; errors are returned to JS
/// as a string `JsValue`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn trendflex_batch_js(
    data: &[f64],
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let grid = (period_start, period_end, period_step);
    let sweep = TrendFlexBatchRange { period: grid };

    let batch = trendflex_batch_inner(data, &sweep, Kernel::Auto, false)
        .map_err(|err| JsValue::from_str(&err.to_string()))?;
    Ok(batch.values)
}
2458
/// JS binding: list the period used by each row of a TrendFlex batch, as `f64` values.
///
/// Combinations without a period report 20. Errors from expanding the grid are returned
/// to JS as a string `JsValue`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn trendflex_batch_metadata_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let sweep = TrendFlexBatchRange {
        period: (period_start, period_end, period_step),
    };

    let combos = expand_grid(&sweep).map_err(|err| JsValue::from_str(&err.to_string()))?;

    let mut periods = Vec::with_capacity(combos.len());
    for combo in &combos {
        periods.push(combo.period.unwrap_or(20) as f64);
    }
    Ok(periods)
}
2479
/// JS binding (`trendflex_batch`): run a period sweep described by a config object.
///
/// `config` must deserialize into `TrendFlexBatchConfig`. Returns a serialized
/// `TrendFlexBatchJsOutput` (values, combos, rows, cols). Config, computation and
/// serialization failures are returned to JS as string `JsValue`s.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = trendflex_batch)]
pub fn trendflex_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let parsed: TrendFlexBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = TrendFlexBatchRange {
        period: parsed.period_range,
    };
    let batch = trendflex_batch_inner(data, &sweep, Kernel::Auto, false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let payload = TrendFlexBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };

    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2503
/// Reserve an uninitialised buffer able to hold `len` `f64` values and leak it to the caller.
///
/// The returned pointer is meant to be released with `trendflex_free(ptr, len)` using the
/// same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn trendflex_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` keeps the allocation alive after this function returns.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
2512
/// Release a buffer previously returned by `trendflex_alloc(len)`.
///
/// `ptr` must be exactly the pointer returned by `trendflex_alloc`, and `len` the same length
/// passed to it; anything else is undefined behaviour. A null `ptr` is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn trendflex_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` came from a leaked `Vec::<f64>::with_capacity(len)`, whose capacity is
    // exactly `len`. The length is set to 0 because the buffer may never have been
    // initialised, and `Vec::from_raw_parts` requires the first `length` elements to be
    // initialised. Dropping the Vec frees the allocation without reading any element.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
2520
/// JS binding: compute TrendFlex from `len` values at `in_ptr` into `len` slots at `out_ptr`.
///
/// The caller must ensure `in_ptr` points to `len` initialised `f64`s and `out_ptr` to `len`
/// writable `f64`s. Input and output may alias or overlap; the result is then computed into a
/// temporary buffer first.
///
/// # Errors
/// Returns a string `JsValue` for null pointers, for `period == 0 || period >= len`, or
/// when the core computation fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn trendflex_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to trendflex_into"));
    }
    if period == 0 || period >= len {
        return Err(JsValue::from_str("Invalid period"));
    }

    // Treat any overlap of the two regions (not only identical start pointers) as aliasing:
    // a live `&[f64]` and `&mut [f64]` over shared memory is undefined behaviour, and the
    // kernel could read values it has already overwritten.
    let bytes = len.saturating_mul(std::mem::size_of::<f64>());
    let in_start = in_ptr as usize;
    let out_start = out_ptr as usize;
    let overlaps =
        in_start < out_start.saturating_add(bytes) && out_start < in_start.saturating_add(bytes);

    // SAFETY: the caller guarantees `in_ptr` addresses `len` initialised f64 values.
    let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };
    let input = TrendFlexInput::from_slice(
        data,
        TrendFlexParams {
            period: Some(period),
        },
    );

    if overlaps {
        let mut tmp = vec![0.0; len];
        trendflex_into_slice(&mut tmp, &input, detect_best_kernel())
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // SAFETY: the caller guarantees `out_ptr` addresses `len` writable f64 values; the
        // shared view of the input is not used past this point.
        unsafe { std::slice::from_raw_parts_mut(out_ptr, len) }.copy_from_slice(&tmp);
    } else {
        // SAFETY: as above, and the regions were just shown to be disjoint.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        trendflex_into_slice(out, &input, detect_best_kernel())
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }
    Ok(())
}
2556
/// JS binding: run the TrendFlex period sweep from `len` values at `in_ptr` into `out_ptr`.
///
/// `out_ptr` must address `n_combos * len` writable `f64`s, where `n_combos` is the number
/// of periods in `(period_start, period_end, period_step)`; that count is returned. If the
/// output region overlaps the input, the input is copied before any output is written.
///
/// # Errors
/// Returns a string `JsValue` for null pointers, an invalid period grid, a size overflow,
/// or a failure in the batch computation.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn trendflex_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str(
            "null pointer passed to trendflex_batch_into",
        ));
    }

    let sweep = TrendFlexBatchRange {
        period: (period_start, period_end, period_step),
    };

    let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let n_combos = combos.len();
    let total_size = n_combos
        .checked_mul(len)
        .ok_or_else(|| JsValue::from_str("dimensions overflow"))?;

    // Detect overlap between the input (`len` values) and the output (`total_size` values);
    // a shared and a mutable slice over the same memory would be undefined behaviour.
    let elem = std::mem::size_of::<f64>();
    let in_start = in_ptr as usize;
    let in_end = in_start.saturating_add(len.saturating_mul(elem));
    let out_start = out_ptr as usize;
    let out_end = out_start.saturating_add(total_size.saturating_mul(elem));
    let overlaps = in_start < out_end && out_start < in_end;

    let input_copy: Vec<f64>;
    let data: &[f64] = if overlaps {
        // SAFETY: the caller guarantees `in_ptr` addresses `len` initialised f64 values; the
        // temporary view ends before the output slice is created.
        input_copy = unsafe { std::slice::from_raw_parts(in_ptr, len) }.to_vec();
        &input_copy
    } else {
        // SAFETY: as above; the regions were just shown to be disjoint.
        unsafe { std::slice::from_raw_parts(in_ptr, len) }
    };

    // SAFETY: the caller guarantees `out_ptr` addresses `total_size` writable f64 values.
    let out_slice = unsafe { std::slice::from_raw_parts_mut(out_ptr, total_size) };

    trendflex_batch_inner_into(data, &sweep, Kernel::Auto, false, out_slice)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    Ok(n_combos)
}