Skip to main content

vector_ta/indicators/moving_averages/
trendflex.rs

1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::moving_averages::alma_wrapper::DeviceArrayF32;
3#[cfg(all(feature = "python", feature = "cuda"))]
4use crate::cuda::moving_averages::CudaTrendflex;
5use crate::utilities::data_loader::{source_type, Candles};
6#[cfg(all(feature = "python", feature = "cuda"))]
7use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
8use crate::utilities::enums::Kernel;
9use crate::utilities::helpers::{
10    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
11    make_uninit_matrix,
12};
13#[cfg(feature = "python")]
14use crate::utilities::kernel_validation::validate_kernel;
15use aligned_vec::{AVec, ConstAlign, CACHELINE_ALIGN};
16#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
17use core::arch::x86_64::*;
18#[cfg(all(feature = "python", feature = "cuda"))]
19use cust::context::Context;
20#[cfg(all(feature = "python", feature = "cuda"))]
21use cust::memory::DeviceBuffer;
22#[cfg(not(target_arch = "wasm32"))]
23use rayon::prelude::*;
24use std::convert::AsRef;
25use std::error::Error;
26use std::mem::MaybeUninit;
27#[cfg(all(feature = "python", feature = "cuda"))]
28use std::sync::Arc;
29use thiserror::Error;
30
31impl<'a> AsRef<[f64]> for TrendFlexInput<'a> {
32    #[inline(always)]
33    fn as_ref(&self) -> &[f64] {
34        match &self.data {
35            TrendFlexData::Slice(slice) => slice,
36            TrendFlexData::Candles { candles, source } => source_type(candles, source),
37        }
38    }
39}
40
/// Where a TrendFlex computation reads its price series from.
#[derive(Debug, Clone)]
pub enum TrendFlexData<'a> {
    /// A candle set plus the name of the column to use; the column is looked
    /// up with `source_type` when the input is read through `AsRef<[f64]>`.
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A caller-provided price series, used as-is.
    Slice(&'a [f64]),
}
49
/// Result of a whole-series TrendFlex run.
#[derive(Debug, Clone)]
pub struct TrendFlexOutput {
    /// One value per input sample; the warm-up prefix (leading NaNs of the
    /// input plus the look-back window) is `NaN`.
    pub values: Vec<f64>,
}
54
/// Tunable parameters for TrendFlex.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct TrendFlexParams {
    /// Look-back window length; `None` falls back to 20. The super-smoother
    /// period is derived from it as `round(period / 2)`.
    pub period: Option<usize>,
}
63
64impl Default for TrendFlexParams {
65    fn default() -> Self {
66        Self { period: Some(20) }
67    }
68}
69
/// Input to the TrendFlex entry points: a price source plus parameters.
#[derive(Debug, Clone)]
pub struct TrendFlexInput<'a> {
    /// Where the price series comes from.
    pub data: TrendFlexData<'a>,
    /// Indicator parameters; an unset period is treated as 20 by `get_period`.
    pub params: TrendFlexParams,
}
75
// Python handle for a TrendFlex result held in CUDA device memory, exposed as
// `ta_indicators.cuda.TrendFlexDeviceArrayF32`. `unsendable` makes PyO3 restrict
// the object to the thread that created it.
// (Plain `//` comments on purpose: PyO3 would publish `///` text as the class `__doc__`.)
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyo3::prelude::pyclass(
    module = "ta_indicators.cuda",
    name = "TrendFlexDeviceArrayF32",
    unsendable
)]
pub struct TrendFlexDeviceArrayF32Py {
    // Device buffer plus its (rows, cols) shape; the array-interface methods
    // describe it as a row-major f32 matrix.
    pub(crate) inner: DeviceArrayF32,
    // Shared CUDA context handle. NOTE(review): presumably held so the context
    // outlives `inner` — confirm.
    pub(crate) _ctx: Arc<Context>,
    // Device ordinal reported through `__dlpack_device__`.
    pub(crate) device_id: u32,
}
87
// Python-visible export methods for the device array: CUDA Array Interface and DLPack.
// (Plain `//` comments on purpose: PyO3 turns `///` doc comments into Python `__doc__` strings.)
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyo3::prelude::pymethods]
impl TrendFlexDeviceArrayF32Py {
    // CUDA Array Interface (version 3) descriptor: shape (rows, cols),
    // little-endian float32 ("<f4"), byte strides for a row-major layout,
    // and the raw device pointer paired with read_only = false.
    #[getter]
    fn __cuda_array_interface__<'py>(
        &self,
        py: pyo3::prelude::Python<'py>,
    ) -> pyo3::PyResult<pyo3::prelude::Bound<'py, pyo3::types::PyDict>> {
        let d = pyo3::types::PyDict::new(py);
        d.set_item("shape", (self.inner.rows, self.inner.cols))?;
        d.set_item("typestr", "<f4")?;
        // Row stride = cols * 4 bytes, element stride = 4 bytes (row-major).
        d.set_item(
            "strides",
            (
                self.inner.cols * std::mem::size_of::<f32>(),
                std::mem::size_of::<f32>(),
            ),
        )?;
        // (pointer, read_only) pair; the buffer is exposed as writable.
        d.set_item("data", (self.inner.device_ptr() as usize, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    // DLPack device tuple: device type 2 (kDLCUDA in the DLPack spec) and the
    // stored device ordinal.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.device_id as i32)
    }

    // Exports the buffer as a 2-D f32 DLPack capsule and transfers ownership of it.
    //
    // - A `dl_device` naming a different device is rejected; with `copy=True`
    //   the error says cross-device copy is not implemented. A `dl_device`
    //   that does not extract as an `(int, int)` tuple is ignored.
    // - `stream` is accepted but ignored; no synchronization with the
    //   consumer's stream happens here. NOTE(review): presumably the producing
    //   work is already synchronized — confirm.
    // - Afterwards this handle holds an empty 0x0 placeholder buffer, so any
    //   later export or array-interface query sees that empty buffer.
    #[pyo3(signature=(stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: pyo3::prelude::Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> pyo3::PyResult<pyo3::prelude::PyObject> {
        let (kdl, alloc_dev) = self.__dlpack_device__();
        if let Some(dev_obj) = dl_device.as_ref() {
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    // Cross-device export is never performed; only the error text
                    // depends on whether the caller asked for a copy.
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(pyo3::exceptions::PyValueError::new_err(
                            "device copy not implemented for __dlpack__",
                        ));
                    } else {
                        return Err(pyo3::exceptions::PyValueError::new_err(
                            "dl_device mismatch for __dlpack__",
                        ));
                    }
                }
            }
        }
        let _ = stream;

        // Swap an empty buffer in so the real allocation can be moved into the capsule.
        let dummy = DeviceBuffer::from_slice(&[])
            .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;
        let inner = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32 {
                buf: dummy,
                rows: 0,
                cols: 0,
            },
        );

        let rows = inner.rows;
        let cols = inner.cols;
        let buf = inner.buf;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
167
168impl<'a> TrendFlexInput<'a> {
169    #[inline]
170    pub fn from_candles(c: &'a Candles, s: &'a str, p: TrendFlexParams) -> Self {
171        Self {
172            data: TrendFlexData::Candles {
173                candles: c,
174                source: s,
175            },
176            params: p,
177        }
178    }
179    #[inline]
180    pub fn from_slice(sl: &'a [f64], p: TrendFlexParams) -> Self {
181        Self {
182            data: TrendFlexData::Slice(sl),
183            params: p,
184        }
185    }
186    #[inline]
187    pub fn with_default_candles(c: &'a Candles) -> Self {
188        Self::from_candles(c, "close", TrendFlexParams::default())
189    }
190    #[inline]
191    pub fn get_period(&self) -> usize {
192        self.params.period.unwrap_or(20)
193    }
194}
195
/// Fluent configuration for TrendFlex: an optional period plus a kernel choice.
#[derive(Copy, Clone, Debug)]
pub struct TrendFlexBuilder {
    /// Window length; `None` defers to the default of 20.
    period: Option<usize>,
    /// Kernel handed to `trendflex_with_kernel` by `apply` / `apply_slice`.
    kernel: Kernel,
}
201
202impl Default for TrendFlexBuilder {
203    fn default() -> Self {
204        Self {
205            period: None,
206            kernel: Kernel::Auto,
207        }
208    }
209}
210
211impl TrendFlexBuilder {
212    #[inline(always)]
213    pub fn new() -> Self {
214        Self::default()
215    }
216    #[inline(always)]
217    pub fn period(mut self, n: usize) -> Self {
218        self.period = Some(n);
219        self
220    }
221    #[inline(always)]
222    pub fn kernel(mut self, k: Kernel) -> Self {
223        self.kernel = k;
224        self
225    }
226    #[inline(always)]
227    pub fn apply(self, c: &Candles) -> Result<TrendFlexOutput, TrendFlexError> {
228        let p = TrendFlexParams {
229            period: self.period,
230        };
231        let i = TrendFlexInput::from_candles(c, "close", p);
232        trendflex_with_kernel(&i, self.kernel)
233    }
234    #[inline(always)]
235    pub fn apply_slice(self, d: &[f64]) -> Result<TrendFlexOutput, TrendFlexError> {
236        let p = TrendFlexParams {
237            period: self.period,
238        };
239        let i = TrendFlexInput::from_slice(d, p);
240        trendflex_with_kernel(&i, self.kernel)
241    }
242    #[inline(always)]
243    pub fn into_stream(self) -> Result<TrendFlexStream, TrendFlexError> {
244        let p = TrendFlexParams {
245            period: self.period,
246        };
247        TrendFlexStream::try_new(p)
248    }
249}
250
/// Errors produced by the TrendFlex entry points, builder and stream.
#[derive(Debug, Error)]
pub enum TrendFlexError {
    /// The input series is empty.
    #[error("trendflex: No data provided.")]
    NoDataProvided,
    /// Every input sample is `NaN`.
    #[error("trendflex: All values are NaN.")]
    AllValuesNaN,
    /// A period of 0 was requested; `period` carries the rejected value.
    #[error("trendflex: period = 0")]
    ZeroTrendFlexPeriod { period: usize },
    /// The period is not smaller than the data length.
    /// NOTE(review): the message says "period > data len", but the checks also
    /// fire when `period == data_len`.
    #[error("trendflex: period > data len: period = {period}, data_len = {data_len}")]
    TrendFlexPeriodExceedsData { period: usize, data_len: usize },
    /// The derived smoother period `round(period / 2)` does not fit the
    /// available data. `TrendFlexStream::try_new` also uses it, with
    /// `data_len: 0`, if the smoother period rounds to 0.
    #[error(
        "trendflex: smoother period > data len: ss_period = {ss_period}, data_len = {data_len}"
    )]
    SmootherPeriodExceedsData { ss_period: usize, data_len: usize },
    /// The destination slice length differs from the input length.
    #[error("trendflex: output length mismatch: expected {expected}, got {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// Fewer than `needed` samples remain after the leading NaNs.
    #[error("trendflex: not enough valid data: needed {needed}, valid {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// A batch period sweep expanded to no parameter combinations.
    #[error("trendflex: invalid range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },
    /// NOTE(review): presumably a non-batch kernel passed to a batch entry
    /// point; raised outside this section — confirm.
    #[error("trendflex: invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(Kernel),
    /// NOTE(review): presumably `rows * cols` overflowing when sizing a batch
    /// output; raised outside this section — confirm.
    #[error("trendflex: dimensions overflow: rows={rows}, cols={cols}")]
    DimensionsOverflow { rows: usize, cols: usize },
}
280
281#[inline]
282pub fn trendflex(input: &TrendFlexInput) -> Result<TrendFlexOutput, TrendFlexError> {
283    trendflex_with_kernel(input, Kernel::Auto)
284}
285
286pub fn trendflex_with_kernel(
287    input: &TrendFlexInput,
288    kernel: Kernel,
289) -> Result<TrendFlexOutput, TrendFlexError> {
290    let data: &[f64] = input.as_ref();
291    let len = data.len();
292    if len == 0 {
293        return Err(TrendFlexError::NoDataProvided);
294    }
295
296    let period = input.get_period();
297    if period == 0 {
298        return Err(TrendFlexError::ZeroTrendFlexPeriod { period });
299    }
300    if period >= len {
301        return Err(TrendFlexError::TrendFlexPeriodExceedsData {
302            period,
303            data_len: len,
304        });
305    }
306
307    let first = data
308        .iter()
309        .position(|x| !x.is_nan())
310        .ok_or(TrendFlexError::AllValuesNaN)?;
311    let ss_period = ((period as f64) / 2.0).round() as usize;
312
313    let valid = len - first;
314    if valid < period {
315        return Err(TrendFlexError::NotEnoughValidData {
316            needed: period,
317            valid,
318        });
319    }
320    if ss_period > len {
321        return Err(TrendFlexError::SmootherPeriodExceedsData {
322            ss_period,
323            data_len: len,
324        });
325    }
326
327    let warm = first + period;
328    let mut out = alloc_with_nan_prefix(len, warm);
329
330    let chosen = match kernel {
331        Kernel::Auto => detect_best_kernel(),
332        k => k,
333    };
334
335    unsafe {
336        match chosen {
337            Kernel::Scalar | Kernel::ScalarBatch => {
338                trendflex_scalar_into(data, period, ss_period, first, &mut out)?
339            }
340            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
341            Kernel::Avx2 | Kernel::Avx2Batch => {
342                trendflex_avx2_into(data, period, ss_period, first, &mut out)?
343            }
344            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
345            Kernel::Avx512 | Kernel::Avx512Batch => {
346                trendflex_avx512_into(data, period, ss_period, first, &mut out)?
347            }
348            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
349            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
350                trendflex_scalar_into(data, period, ss_period, first, &mut out)?
351            }
352            Kernel::Auto => unreachable!(),
353        }
354    }
355
356    Ok(TrendFlexOutput { values: out })
357}
358
359pub fn trendflex_into_slice(
360    dst: &mut [f64],
361    input: &TrendFlexInput,
362    kernel: Kernel,
363) -> Result<(), TrendFlexError> {
364    let data: &[f64] = input.as_ref();
365    let len = data.len();
366    if dst.len() != len {
367        return Err(TrendFlexError::OutputLengthMismatch {
368            expected: len,
369            got: dst.len(),
370        });
371    }
372    if len == 0 {
373        return Err(TrendFlexError::NoDataProvided);
374    }
375    let period = input.get_period();
376    if period == 0 {
377        return Err(TrendFlexError::ZeroTrendFlexPeriod { period });
378    }
379    if period >= len {
380        return Err(TrendFlexError::TrendFlexPeriodExceedsData {
381            period,
382            data_len: len,
383        });
384    }
385    let first = data
386        .iter()
387        .position(|x| !x.is_nan())
388        .ok_or(TrendFlexError::AllValuesNaN)?;
389    let ss_period = ((period as f64) / 2.0).round() as usize;
390    let valid = len - first;
391    if valid < period {
392        return Err(TrendFlexError::NotEnoughValidData {
393            needed: period,
394            valid,
395        });
396    }
397    if ss_period > data.len() {
398        return Err(TrendFlexError::SmootherPeriodExceedsData {
399            ss_period,
400            data_len: data.len(),
401        });
402    }
403
404    let chosen = match kernel {
405        Kernel::Auto => detect_best_kernel(),
406        k => k,
407    };
408
409    unsafe {
410        match chosen {
411            Kernel::Scalar | Kernel::ScalarBatch => {
412                trendflex_scalar_into(data, period, ss_period, first, dst)?
413            }
414            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
415            Kernel::Avx2 | Kernel::Avx2Batch => {
416                trendflex_avx2_into(data, period, ss_period, first, dst)?
417            }
418            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
419            Kernel::Avx512 | Kernel::Avx512Batch => {
420                trendflex_avx512_into(data, period, ss_period, first, dst)?
421            }
422            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
423            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
424                trendflex_scalar_into(data, period, ss_period, first, dst)?
425            }
426            Kernel::Auto => unreachable!(),
427        }
428    }
429
430    let warmup_end = first + period;
431    for v in &mut dst[..warmup_end] {
432        *v = f64::NAN;
433    }
434    Ok(())
435}
436
437#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
438#[inline]
439pub fn trendflex_into(input: &TrendFlexInput, out: &mut [f64]) -> Result<(), TrendFlexError> {
440    trendflex_into_slice(out, input, Kernel::Auto)
441}
442
443#[inline]
444unsafe fn trendflex_scalar_into(
445    data: &[f64],
446    period: usize,
447    ss_period: usize,
448    first_valid: usize,
449    out: &mut [f64],
450) -> Result<(), TrendFlexError> {
451    use std::f64::consts::PI;
452
453    let len = data.len();
454    let warm = first_valid + period;
455
456    for i in 0..warm.min(out.len()) {
457        out[i] = f64::NAN;
458    }
459
460    if first_valid >= len {
461        return Ok(());
462    }
463
464    let a = (-1.414_f64 * PI / ss_period as f64).exp();
465    let a_sq = a * a;
466    let b = 2.0 * a * (1.414_f64 * PI / ss_period as f64).cos();
467
468    let c = (1.0 + a_sq - b) * 0.5;
469
470    let m = len - first_valid;
471    if m < period {
472        return Ok(());
473    }
474    if m < ss_period {
475        return Err(TrendFlexError::SmootherPeriodExceedsData {
476            ss_period,
477            data_len: m,
478        });
479    }
480
481    let x = &data[first_valid..];
482
483    let mut prev2 = x[0];
484    let mut prev1 = if m > 1 { x[1] } else { x[0] };
485
486    let mut ring = vec![0.0f64; period];
487    let mut head = 0usize;
488    let mut sum = 0.0f64;
489
490    ring[head] = prev2;
491    sum += prev2;
492    head = (head + 1) % period;
493    if m > 1 {
494        ring[head] = prev1;
495        sum += prev1;
496        head = (head + 1) % period;
497    }
498
499    let tp_f = period as f64;
500    let inv_tp = 1.0 / tp_f;
501    let mut ms_prev = 0.0f64;
502
503    let mut i = 2usize;
504    while i < m && i < period {
505        let cur = (-a_sq).mul_add(prev2, b.mul_add(prev1, c * (x[i] + x[i - 1])));
506        prev2 = prev1;
507        prev1 = cur;
508
509        sum += cur;
510        ring[head] = cur;
511        head = (head + 1) % period;
512        i += 1;
513    }
514
515    while i < m {
516        let cur = (-a_sq).mul_add(prev2, b.mul_add(prev1, c * (x[i] + x[i - 1])));
517        prev2 = prev1;
518        prev1 = cur;
519
520        let my_sum = (tp_f * cur - sum) * inv_tp;
521
522        let ms_current = 0.04f64.mul_add(my_sum * my_sum, 0.96f64 * ms_prev);
523        ms_prev = ms_current;
524
525        let out_val = if ms_current != 0.0 {
526            my_sum / ms_current.sqrt()
527        } else {
528            0.0
529        };
530        out[first_valid + i] = out_val;
531
532        let old = ring[head];
533        sum += cur - old;
534        ring[head] = cur;
535        head = (head + 1) % period;
536
537        i += 1;
538    }
539
540    Ok(())
541}
542
/// TrendFlex kernel compiled with AVX2 + FMA enabled.
///
/// The super-smoother recurrence is strictly sequential, so there is no data
/// parallelism to exploit. This path performs per-sample arithmetic identical to
/// the previous version (the former one-lane `_mm_*_sd` intrinsics are the same
/// IEEE-754 operations as the plain `f64` ops used here) and adds a
/// software prefetch.
///
/// Fills `out[..first_valid + max(period, 2)]` (clamped to `out.len()`) with
/// `NaN` and every later index with TrendFlex values.
///
/// # Safety
/// The caller must ensure the running CPU supports AVX2 and FMA, as required
/// by `#[target_feature]`. `out` should be at least `data.len()` long; a
/// shorter slice panics instead of being written out of bounds.
///
/// # Errors
/// `SmootherPeriodExceedsData` when at least `period` but fewer than
/// `ss_period` samples follow `first_valid`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx2,fma")]
unsafe fn trendflex_avx2_into(
    data: &[f64],
    period: usize,
    ss_period: usize,
    first_valid: usize,
    out: &mut [f64],
) -> Result<(), TrendFlexError> {
    use std::f64::consts::PI;

    /// Runs TrendFlex over the NaN-free tail `x`, writing `out[i]` for every
    /// `i >= max(period, 2)`; `out` must be at least `x.len()` long.
    #[inline(always)]
    unsafe fn run_series_avx2(
        x: &[f64],
        period: usize,
        a_sq: f64,
        b: f64,
        c: f64,
        out: &mut [f64],
    ) {
        let n = x.len();
        if n == 0 {
            return;
        }
        // One bounds check up front; every `out[i]` below has `i < n`.
        let out = &mut out[..n];

        let mut prev2 = x[0];
        let mut prev1 = if n > 1 { x[1] } else { x[0] };

        let mut ring = vec![0.0f64; period];
        let mut sum = 0.0f64;
        let mut head = 0usize;

        ring[head] = prev2;
        sum += prev2;
        head = (head + 1) % period;
        if n > 1 {
            if period > 1 {
                ring[head] = prev1;
                sum += prev1;
                head = (head + 1) % period;
            } else {
                // One-slot window: the second seed replaces the first instead
                // of leaving x[0] in `sum` forever (matches TrendFlexStream).
                ring[0] = prev1;
                sum = prev1;
            }
        }

        let tp_f = period as f64;
        let inv_tp = 1.0 / tp_f;
        let mut ms_prev = 0.0f64;

        let mut i = 2usize;
        while i < n && i < period {
            let cur = c * (x[i] + x[i - 1]) + b * prev1 - a_sq * prev2;
            prev2 = prev1;
            prev1 = cur;
            sum += cur;
            ring[head] = cur;
            head = (head + 1) % period;
            i += 1;
        }

        while i < n {
            // Prefetch is only a hint. `wrapping_add` keeps the address
            // computation defined when `i + 16` runs past the end of `x`,
            // where `pointer::add` would be undefined behaviour.
            _mm_prefetch(x.as_ptr().wrapping_add(i + 16).cast(), _MM_HINT_T0);

            let cur = c * (x[i] + x[i - 1]) + b * prev1 - a_sq * prev2;
            prev2 = prev1;
            prev1 = cur;

            let my_sum = (tp_f * cur - sum) * inv_tp;
            let ms_current = 0.04 * (my_sum * my_sum) + 0.96 * ms_prev;
            ms_prev = ms_current;

            // Ordinary store. The former `_mm_stream_sd` (MOVNTSD) is an SSE4a,
            // AMD-only instruction that raises #UD on Intel CPUs, and its
            // non-temporal store was never followed by the required `_mm_sfence`.
            out[i] = if ms_current != 0.0 {
                my_sum / ms_current.sqrt()
            } else {
                0.0
            };

            let old = ring[head];
            sum += cur - old;
            ring[head] = cur;
            head = (head + 1) % period;

            i += 1;
        }
    }

    let len = data.len();
    // The smoother consumes two seed samples before it yields a value, so even
    // period == 1 cannot emit before first_valid + 2; without the `max`,
    // out[first_valid + 1] was left unwritten for period == 1.
    let warm = first_valid + period.max(2);
    let nan_end = warm.min(out.len());
    out[..nan_end].fill(f64::NAN);

    if first_valid >= len {
        return Ok(());
    }

    let m = len - first_valid;
    if m < period {
        return Ok(());
    }
    if m < ss_period {
        return Err(TrendFlexError::SmootherPeriodExceedsData {
            ss_period,
            data_len: m,
        });
    }

    // Two-pole super-smoother coefficients (1.414 ~ sqrt(2)).
    let a = (-1.414_f64 * PI / ss_period as f64).exp();
    let a_sq = a * a;
    let b = 2.0 * a * (1.414_f64 * PI / ss_period as f64).cos();
    let c = (1.0 + a_sq - b) * 0.5;

    run_series_avx2(
        &data[first_valid..],
        period,
        a_sq,
        b,
        c,
        &mut out[first_valid..],
    );
    Ok(())
}
672
/// TrendFlex kernel compiled with AVX-512F/DQ + FMA enabled.
///
/// The super-smoother recurrence is strictly sequential, so this path
/// evaluates per-sample arithmetic identical to the previous version. The former
/// `_mm_fmadd_sd` / `_mm_sqrt_sd` one-lane intrinsics are expressed as
/// `f64::mul_add` / `f64::sqrt`, which are the same IEEE-754 operations. The
/// former 2x manual unroll exposed no independent work, because each step
/// depends on the previous one, so it was dropped. The size-gated
/// `_mm_stream_sd` store was also removed (see the comment at the store).
///
/// Fills `out[..first_valid + max(period, 2)]` (clamped to `out.len()`) with
/// `NaN` and every later index with TrendFlex values.
///
/// # Safety
/// The caller must ensure the running CPU supports AVX-512F, AVX-512DQ and
/// FMA, as required by `#[target_feature]`. `out` should be at least
/// `data.len()` long; a shorter slice panics instead of being written out of
/// bounds.
///
/// # Errors
/// `SmootherPeriodExceedsData` when at least `period` but fewer than
/// `ss_period` samples follow `first_valid`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f,avx512dq,fma")]
unsafe fn trendflex_avx512_into(
    data: &[f64],
    period: usize,
    ss_period: usize,
    first_valid: usize,
    out: &mut [f64],
) -> Result<(), TrendFlexError> {
    use std::f64::consts::PI;

    /// Runs TrendFlex over the NaN-free tail `x`, writing `out[i]` for every
    /// `i >= max(period, 2)`; `out` must be at least `x.len()` long.
    #[inline(always)]
    unsafe fn run_series_avx512(
        x: &[f64],
        period: usize,
        a_sq: f64,
        b: f64,
        c: f64,
        out: &mut [f64],
    ) {
        let n = x.len();
        if n == 0 {
            return;
        }
        // One bounds check up front; every `out[i]` below has `i < n`.
        let out = &mut out[..n];

        let mut prev2 = x[0];
        let mut prev1 = if n > 1 { x[1] } else { x[0] };

        let mut ring = vec![0.0f64; period];
        let mut sum = 0.0f64;
        let mut head = 0usize;

        ring[head] = prev2;
        sum += prev2;
        head += 1;
        if head == period {
            head = 0;
        }
        if n > 1 {
            if period > 1 {
                ring[head] = prev1;
                sum += prev1;
                head += 1;
                if head == period {
                    head = 0;
                }
            } else {
                // One-slot window: the second seed replaces the first instead
                // of leaving x[0] in `sum` forever (matches TrendFlexStream).
                ring[0] = prev1;
                sum = prev1;
            }
        }

        let tp_f = period as f64;
        let inv_tp = 1.0 / tp_f;
        let mut ms_prev = 0.0f64;

        let mut i = 2usize;
        while i < n && i < period {
            let cur = c * (x[i] + x[i - 1]) + b * prev1 - a_sq * prev2;
            prev2 = prev1;
            prev1 = cur;
            sum += cur;
            ring[head] = cur;
            head += 1;
            if head == period {
                head = 0;
            }
            i += 1;
        }

        while i < n {
            // Prefetch is only a hint. `wrapping_add` keeps the address
            // computation defined when `i + 32` runs past the end of `x`,
            // where `pointer::add` would be undefined behaviour.
            _mm_prefetch(x.as_ptr().wrapping_add(i + 32).cast(), _MM_HINT_T0);

            let cur = c * (x[i] + x[i - 1]) + b * prev1 - a_sq * prev2;
            prev2 = prev1;
            prev1 = cur;

            let my_sum = (tp_f * cur - sum) * inv_tp;
            // Single-rounding fused multiply-add, same result as `_mm_fmadd_sd`.
            let ms = 0.04f64.mul_add(my_sum * my_sum, 0.96 * ms_prev);
            ms_prev = ms;

            // Ordinary store. The former large-input `_mm_stream_sd` (MOVNTSD)
            // is an SSE4a, AMD-only instruction that raises #UD on Intel CPUs,
            // and its non-temporal store was never fenced with `_mm_sfence`.
            out[i] = if ms != 0.0 { my_sum / ms.sqrt() } else { 0.0 };

            let old = ring[head];
            sum += cur - old;
            ring[head] = cur;
            head += 1;
            if head == period {
                head = 0;
            }

            i += 1;
        }
    }

    let len = data.len();
    // The smoother consumes two seed samples before it yields a value, so even
    // period == 1 cannot emit before first_valid + 2; without the `max`,
    // out[first_valid + 1] was left unwritten for period == 1.
    let warm = first_valid + period.max(2);
    let nan_end = warm.min(out.len());
    out[..nan_end].fill(f64::NAN);

    if first_valid >= len {
        return Ok(());
    }

    let m = len - first_valid;
    if m < period {
        return Ok(());
    }
    if m < ss_period {
        return Err(TrendFlexError::SmootherPeriodExceedsData {
            ss_period,
            data_len: m,
        });
    }

    // Two-pole super-smoother coefficients (1.414 ~ sqrt(2)).
    let a = (-1.414_f64 * PI / ss_period as f64).exp();
    let a_sq = a * a;
    let b = 2.0 * a * (1.414_f64 * PI / ss_period as f64).cos();
    let c = (1.0 + a_sq - b) * 0.5;

    run_series_avx512(
        &data[first_valid..],
        period,
        a_sq,
        b,
        c,
        &mut out[first_valid..],
    );
    Ok(())
}
909
/// "Short input" AVX-512 entry point. Despite its name, it simply forwards
/// to the scalar kernel, and nothing in this section calls it.
///
/// # Safety
/// Same contract as `trendflex_scalar_into`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
unsafe fn trendflex_avx512_short_into(
    data: &[f64],
    period: usize,
    ss_period: usize,
    first_valid: usize,
    out: &mut [f64],
) -> Result<(), TrendFlexError> {
    trendflex_scalar_into(data, period, ss_period, first_valid, out)?;
    Ok(())
}
921
/// "Long input" AVX-512 entry point. Despite its name, it simply forwards
/// to the scalar kernel, and nothing in this section calls it.
///
/// # Safety
/// Same contract as `trendflex_scalar_into`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
unsafe fn trendflex_avx512_long_into(
    data: &[f64],
    period: usize,
    ss_period: usize,
    first_valid: usize,
    out: &mut [f64],
) -> Result<(), TrendFlexError> {
    trendflex_scalar_into(data, period, ss_period, first_valid, out)?;
    Ok(())
}
933
/// Incremental TrendFlex: feed one sample at a time through `update`.
#[derive(Debug, Clone)]
pub struct TrendFlexStream {
    /// Window length: how many previous window values are averaged.
    period: usize,
    /// Super-smoother period, `round(period / 2)`; stored but not read by `update`.
    ss_period: usize,

    /// Super-smoother coefficients computed in `try_new`. `update` reads
    /// `a_sq`, `b` and `c`; `a` itself is stored but not read there.
    a: f64,
    a_sq: f64,
    b: f64,
    c: f64,

    /// Ring buffer of the last `period` window values (the two raw seeds first,
    /// then smoother outputs).
    buf: Vec<f64>,
    /// Running sum of the values in `buf`.
    sum: f64,
    /// Next slot of `buf` to overwrite.
    head: usize,

    /// Most recent and second most recent smoother outputs (the raw seeds for
    /// the first two samples).
    prev1_ssf: f64,
    prev2_ssf: f64,
    /// Previous raw input, used in the smoother term `c * (x + last_raw)`.
    last_raw: f64,

    /// Number of samples consumed so far.
    n_ssf: usize,

    /// Previous exponential mean square of the detrended value.
    ms_prev: f64,

    /// Cached `1.0 / period`.
    inv_p: f64,
}
958
959impl TrendFlexStream {
960    pub fn try_new(params: TrendFlexParams) -> Result<Self, TrendFlexError> {
961        let period = params.period.unwrap_or(20);
962        if period == 0 {
963            return Err(TrendFlexError::ZeroTrendFlexPeriod { period });
964        }
965
966        let ss_period = ((period as f64) / 2.0).round() as usize;
967        if ss_period == 0 {
968            return Err(TrendFlexError::SmootherPeriodExceedsData {
969                ss_period,
970                data_len: 0,
971            });
972        }
973
974        use std::f64::consts::PI;
975        let a = (-1.414_f64 * PI / (ss_period as f64)).exp();
976        let a_sq = a * a;
977        let b = 2.0 * a * (1.414_f64 * PI / (ss_period as f64)).cos();
978        let c = (1.0 + a_sq - b) * 0.5;
979
980        Ok(Self {
981            period,
982            ss_period,
983            a,
984            a_sq,
985            b,
986            c,
987            buf: vec![0.0; period],
988            sum: 0.0,
989            head: 0,
990            prev1_ssf: 0.0,
991            prev2_ssf: 0.0,
992            last_raw: 0.0,
993            n_ssf: 0,
994            ms_prev: 0.0,
995            inv_p: 1.0 / (period as f64),
996        })
997    }
998
999    #[inline(always)]
1000    pub fn update(&mut self, x: f64) -> Option<f64> {
1001        if self.n_ssf == 0 {
1002            self.prev2_ssf = x;
1003            self.last_raw = x;
1004
1005            self.buf[self.head] = x;
1006            self.sum += x;
1007            self.head = if self.period > 1 { 1 } else { 0 };
1008            self.n_ssf = 1;
1009            return None;
1010        }
1011
1012        if self.n_ssf == 1 {
1013            self.prev1_ssf = x;
1014            self.last_raw = x;
1015
1016            if self.period > 1 {
1017                self.buf[self.head] = x;
1018                self.sum += x;
1019                self.head = (self.head + 1) % self.period;
1020            } else {
1021                self.buf[0] = x;
1022                self.sum = x;
1023            }
1024            self.n_ssf = 2;
1025            return None;
1026        }
1027
1028        let cur = (-self.a_sq).mul_add(
1029            self.prev2_ssf,
1030            self.b.mul_add(self.prev1_ssf, self.c * (x + self.last_raw)),
1031        );
1032
1033        let tp_cur_minus_sum = (self.period as f64).mul_add(cur, -self.sum);
1034        let my_sum = self.inv_p * tp_cur_minus_sum;
1035
1036        let will_emit = self.n_ssf + 1 > self.period;
1037
1038        let out_val = if will_emit {
1039            let sq = my_sum * my_sum;
1040            let ms_current = 0.04f64.mul_add(sq, 0.96f64 * self.ms_prev);
1041            self.ms_prev = ms_current;
1042            if ms_current > 0.0 {
1043                my_sum / ms_current.sqrt()
1044            } else {
1045                0.0
1046            }
1047        } else {
1048            0.0
1049        };
1050
1051        let old = self.buf[self.head];
1052        self.sum += cur - old;
1053        self.buf[self.head] = cur;
1054        self.head = (self.head + 1) % self.period;
1055
1056        self.prev2_ssf = self.prev1_ssf;
1057        self.prev1_ssf = cur;
1058        self.last_raw = x;
1059        self.n_ssf += 1;
1060
1061        if will_emit {
1062            Some(out_val)
1063        } else {
1064            None
1065        }
1066    }
1067}
1068
1069#[inline(always)]
1070pub fn trendflex_batch_inner_into(
1071    data: &[f64],
1072    sweep: &TrendFlexBatchRange,
1073    kern: Kernel,
1074    parallel: bool,
1075    out: &mut [f64],
1076) -> Result<Vec<TrendFlexParams>, TrendFlexError> {
1077    let combos = expand_grid(sweep)?;
1078    if combos.is_empty() {
1079        return Err(TrendFlexError::InvalidRange {
1080            start: sweep.period.0,
1081            end: sweep.period.1,
1082            step: sweep.period.2,
1083        });
1084    }
1085
1086    let first = data
1087        .iter()
1088        .position(|x| !x.is_nan())
1089        .ok_or(TrendFlexError::AllValuesNaN)?;
1090    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1091    if data.len() - first < max_p {
1092        return Err(TrendFlexError::TrendFlexPeriodExceedsData {
1093            period: max_p,
1094            data_len: data.len() - first,
1095        });
1096    }
1097
1098    let rows = combos.len();
1099    let cols = data.len();
1100    let expected = rows
1101        .checked_mul(cols)
1102        .ok_or(TrendFlexError::DimensionsOverflow { rows, cols })?;
1103    if out.len() != expected {
1104        return Err(TrendFlexError::OutputLengthMismatch {
1105            expected,
1106            got: out.len(),
1107        });
1108    }
1109
1110    let warm: Vec<usize> = combos.iter().map(|c| first + c.period.unwrap()).collect();
1111
1112    for (row, &warmup) in warm.iter().enumerate() {
1113        let start = row * cols;
1114        let end = start + warmup;
1115        out[start..end].fill(f64::NAN);
1116    }
1117
1118    let actual_kern = match kern {
1119        Kernel::Auto => detect_best_batch_kernel(),
1120        k => k,
1121    };
1122
1123    let do_row = |row: usize, out_row: &mut [f64]| unsafe {
1124        let period = combos[row].period.unwrap();
1125
1126        match actual_kern {
1127            Kernel::Scalar | Kernel::ScalarBatch => {
1128                trendflex_row_scalar(data, first, period, out_row)
1129            }
1130            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1131            Kernel::Avx2 | Kernel::Avx2Batch => trendflex_row_avx2(data, first, period, out_row),
1132            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1133            Kernel::Avx512 | Kernel::Avx512Batch => {
1134                trendflex_row_avx512(data, first, period, out_row)
1135            }
1136            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
1137            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
1138                trendflex_row_scalar(data, first, period, out_row)
1139            }
1140            Kernel::Auto => unreachable!("Auto kernel should have been resolved"),
1141        }
1142    };
1143
1144    if parallel {
1145        #[cfg(not(target_arch = "wasm32"))]
1146        {
1147            use rayon::prelude::*;
1148            out.par_chunks_mut(cols)
1149                .enumerate()
1150                .for_each(|(row, slice)| do_row(row, slice));
1151        }
1152
1153        #[cfg(target_arch = "wasm32")]
1154        {
1155            for (row, slice) in out.chunks_mut(cols).enumerate() {
1156                do_row(row, slice);
1157            }
1158        }
1159    } else {
1160        for (row, slice) in out.chunks_mut(cols).enumerate() {
1161            do_row(row, slice);
1162        }
1163    }
1164
1165    Ok(combos)
1166}
1167
/// Period sweep for batch TrendFlex runs, given as `(start, end, step)`.
///
/// `expand_grid` treats `step == 0` (or `start == end`) as the single period
/// `start`, and walks downward by `step` when `start > end`.
#[derive(Debug, Clone)]
pub struct TrendFlexBatchRange {
    /// `(start, end, step)` for the `period` parameter; `end` is inclusive.
    pub period: (usize, usize, usize),
}

impl Default for TrendFlexBatchRange {
    /// Periods 20 through 269 inclusive, one at a time.
    fn default() -> Self {
        let (lo, hi, stride) = (20, 269, 1);
        TrendFlexBatchRange {
            period: (lo, hi, stride),
        }
    }
}
1180
1181#[derive(Clone, Debug, Default)]
1182pub struct TrendFlexBatchBuilder {
1183    range: TrendFlexBatchRange,
1184    kernel: Kernel,
1185}
1186
1187impl TrendFlexBatchBuilder {
1188    pub fn new() -> Self {
1189        Self::default()
1190    }
1191    pub fn kernel(mut self, k: Kernel) -> Self {
1192        self.kernel = k;
1193        self
1194    }
1195    #[inline]
1196    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
1197        self.range.period = (start, end, step);
1198        self
1199    }
1200    #[inline]
1201    pub fn period_static(mut self, p: usize) -> Self {
1202        self.range.period = (p, p, 0);
1203        self
1204    }
1205
1206    pub fn apply_slice(self, data: &[f64]) -> Result<TrendFlexBatchOutput, TrendFlexError> {
1207        trendflex_batch_with_kernel(data, &self.range, self.kernel)
1208    }
1209    pub fn with_default_slice(
1210        data: &[f64],
1211        k: Kernel,
1212    ) -> Result<TrendFlexBatchOutput, TrendFlexError> {
1213        TrendFlexBatchBuilder::new().kernel(k).apply_slice(data)
1214    }
1215    pub fn apply_candles(
1216        self,
1217        c: &Candles,
1218        src: &str,
1219    ) -> Result<TrendFlexBatchOutput, TrendFlexError> {
1220        let slice = source_type(c, src);
1221        self.apply_slice(slice)
1222    }
1223    pub fn with_default_candles(c: &Candles) -> Result<TrendFlexBatchOutput, TrendFlexError> {
1224        TrendFlexBatchBuilder::new()
1225            .kernel(Kernel::Auto)
1226            .apply_candles(c, "close")
1227    }
1228}
1229
1230pub fn trendflex_batch_with_kernel(
1231    data: &[f64],
1232    sweep: &TrendFlexBatchRange,
1233    k: Kernel,
1234) -> Result<TrendFlexBatchOutput, TrendFlexError> {
1235    let kernel = match k {
1236        Kernel::Auto => detect_best_batch_kernel(),
1237        other if other.is_batch() => other,
1238        _ => return Err(TrendFlexError::InvalidKernelForBatch(k)),
1239    };
1240
1241    let simd = match kernel {
1242        Kernel::Avx512Batch => Kernel::Avx512,
1243        Kernel::Avx2Batch => Kernel::Avx2,
1244        Kernel::ScalarBatch => Kernel::Scalar,
1245        _ => unreachable!(),
1246    };
1247    trendflex_batch_par_slice(data, sweep, simd)
1248}
1249
1250#[derive(Clone, Debug)]
1251pub struct TrendFlexBatchOutput {
1252    pub values: Vec<f64>,
1253    pub combos: Vec<TrendFlexParams>,
1254    pub rows: usize,
1255    pub cols: usize,
1256}
1257
1258impl TrendFlexBatchOutput {
1259    pub fn row_for_params(&self, p: &TrendFlexParams) -> Option<usize> {
1260        self.combos
1261            .iter()
1262            .position(|c| c.period.unwrap_or(20) == p.period.unwrap_or(20))
1263    }
1264    pub fn values_for(&self, p: &TrendFlexParams) -> Option<&[f64]> {
1265        self.row_for_params(p).map(|row| {
1266            let start = row * self.cols;
1267            &self.values[start..start + self.cols]
1268        })
1269    }
1270}
1271
1272#[inline(always)]
1273fn expand_grid(r: &TrendFlexBatchRange) -> Result<Vec<TrendFlexParams>, TrendFlexError> {
1274    fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, TrendFlexError> {
1275        if step == 0 || start == end {
1276            return Ok(vec![start]);
1277        }
1278        if start < end {
1279            let v: Vec<usize> = (start..=end).step_by(step).collect();
1280            if v.is_empty() {
1281                return Err(TrendFlexError::InvalidRange { start, end, step });
1282            }
1283            return Ok(v);
1284        }
1285
1286        let mut v = Vec::new();
1287        let mut cur = start;
1288        while cur >= end {
1289            v.push(cur);
1290            if let Some(next) = cur.checked_sub(step) {
1291                cur = next;
1292            } else {
1293                break;
1294            }
1295            if cur == usize::MAX {
1296                break;
1297            }
1298        }
1299        if v.is_empty() {
1300            return Err(TrendFlexError::InvalidRange { start, end, step });
1301        }
1302        Ok(v)
1303    }
1304
1305    let periods = axis_usize(r.period)?;
1306    let mut out = Vec::with_capacity(periods.len());
1307    for &p in &periods {
1308        out.push(TrendFlexParams { period: Some(p) });
1309    }
1310    Ok(out)
1311}
1312
1313#[inline(always)]
1314pub fn expand_grid_trendflex(r: &TrendFlexBatchRange) -> Vec<TrendFlexParams> {
1315    expand_grid(r).unwrap_or_default()
1316}
1317
1318#[inline(always)]
1319pub fn expand_grid_trendflex_checked(
1320    r: &TrendFlexBatchRange,
1321) -> Result<Vec<TrendFlexParams>, TrendFlexError> {
1322    expand_grid(r)
1323}
1324
1325#[inline(always)]
1326pub fn trendflex_batch_slice(
1327    data: &[f64],
1328    sweep: &TrendFlexBatchRange,
1329    kern: Kernel,
1330) -> Result<TrendFlexBatchOutput, TrendFlexError> {
1331    trendflex_batch_inner(data, sweep, kern, false)
1332}
1333#[inline(always)]
1334pub fn trendflex_batch_par_slice(
1335    data: &[f64],
1336    sweep: &TrendFlexBatchRange,
1337    kern: Kernel,
1338) -> Result<TrendFlexBatchOutput, TrendFlexError> {
1339    trendflex_batch_inner(data, sweep, kern, true)
1340}
1341
1342#[inline(always)]
1343fn trendflex_batch_inner(
1344    data: &[f64],
1345    sweep: &TrendFlexBatchRange,
1346    kern: Kernel,
1347    parallel: bool,
1348) -> Result<TrendFlexBatchOutput, TrendFlexError> {
1349    let combos = expand_grid(sweep)?;
1350    if combos.is_empty() {
1351        return Err(TrendFlexError::InvalidRange {
1352            start: sweep.period.0,
1353            end: sweep.period.1,
1354            step: sweep.period.2,
1355        });
1356    }
1357    let first = data
1358        .iter()
1359        .position(|x| !x.is_nan())
1360        .ok_or(TrendFlexError::AllValuesNaN)?;
1361    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1362    if data.len() - first < max_p {
1363        return Err(TrendFlexError::TrendFlexPeriodExceedsData {
1364            period: max_p,
1365            data_len: data.len() - first,
1366        });
1367    }
1368    let rows = combos.len();
1369    let cols = data.len();
1370
1371    rows.checked_mul(cols)
1372        .ok_or(TrendFlexError::DimensionsOverflow { rows, cols })?;
1373
1374    let warm: Vec<usize> = combos.iter().map(|c| first + c.period.unwrap()).collect();
1375    let mut raw = make_uninit_matrix(rows, cols);
1376
1377    unsafe {
1378        init_matrix_prefixes(&mut raw, cols, &warm);
1379    }
1380
1381    let actual_kern = match kern {
1382        Kernel::Auto => detect_best_batch_kernel(),
1383        k => k,
1384    };
1385
1386    let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
1387        let period = combos[row].period.unwrap();
1388
1389        let out_row =
1390            core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
1391
1392        match actual_kern {
1393            Kernel::Scalar | Kernel::ScalarBatch => {
1394                trendflex_row_scalar(data, first, period, out_row)
1395            }
1396            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1397            Kernel::Avx2 | Kernel::Avx2Batch => trendflex_row_avx2(data, first, period, out_row),
1398            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1399            Kernel::Avx512 | Kernel::Avx512Batch => {
1400                trendflex_row_avx512(data, first, period, out_row)
1401            }
1402            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
1403            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
1404                trendflex_row_scalar(data, first, period, out_row)
1405            }
1406            Kernel::Auto => unreachable!("Auto kernel should have been resolved"),
1407        }
1408    };
1409
1410    if parallel {
1411        #[cfg(not(target_arch = "wasm32"))]
1412        {
1413            raw.par_chunks_mut(cols)
1414                .enumerate()
1415                .for_each(|(row, slice)| do_row(row, slice));
1416        }
1417
1418        #[cfg(target_arch = "wasm32")]
1419        {
1420            for (row, slice) in raw.chunks_mut(cols).enumerate() {
1421                do_row(row, slice);
1422            }
1423        }
1424    } else {
1425        for (row, slice) in raw.chunks_mut(cols).enumerate() {
1426            do_row(row, slice);
1427        }
1428    }
1429
1430    use core::mem::ManuallyDrop;
1431    let mut guard = ManuallyDrop::new(raw);
1432    let values: Vec<f64> = unsafe {
1433        Vec::from_raw_parts(
1434            guard.as_mut_ptr() as *mut f64,
1435            guard.len(),
1436            guard.capacity(),
1437        )
1438    };
1439
1440    Ok(TrendFlexBatchOutput {
1441        values,
1442        combos,
1443        rows,
1444        cols,
1445    })
1446}
1447
1448#[inline(always)]
1449unsafe fn trendflex_row_scalar(data: &[f64], first: usize, period: usize, out_row: &mut [f64]) {
1450    let ss_period = ((period as f64) / 2.0).round() as usize;
1451    let _ = trendflex_scalar_into(data, period, ss_period, first, out_row);
1452}
/// AVX2 row kernel for batch sweeps (requires `nightly-avx` on x86_64).
///
/// The smoother length is `round(period / 2)`. The status returned by
/// `trendflex_avx2_into` is ignored.
///
/// # Safety
/// Presumably the caller must ensure AVX2 is available and meet the slice
/// requirements of `trendflex_avx2_into` — TODO confirm.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn trendflex_row_avx2(data: &[f64], first: usize, period: usize, out_row: &mut [f64]) {
    let half_period = (period as f64) / 2.0;
    let smoother_len = half_period.round() as usize;
    let _status = trendflex_avx2_into(data, period, smoother_len, first, out_row);
}
/// AVX-512 row kernel for batch sweeps (requires `nightly-avx` on x86_64).
///
/// The smoother length is `round(period / 2)`. The status returned by
/// `trendflex_avx512_into` is ignored.
///
/// # Safety
/// Presumably the caller must ensure AVX-512 is available and meet the slice
/// requirements of `trendflex_avx512_into` — TODO confirm.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn trendflex_row_avx512(data: &[f64], first: usize, period: usize, out_row: &mut [f64]) {
    let half_period = (period as f64) / 2.0;
    let smoother_len = half_period.round() as usize;
    let _status = trendflex_avx512_into(data, period, smoother_len, first, out_row);
}
1465
1466#[cfg(test)]
1467mod tests {
1468    use super::*;
1469    use crate::skip_if_unsupported;
1470    use crate::utilities::data_loader::read_candles_from_csv;
1471
1472    fn check_trendflex_partial_params(
1473        test_name: &str,
1474        kernel: Kernel,
1475    ) -> Result<(), Box<dyn Error>> {
1476        skip_if_unsupported!(kernel, test_name);
1477        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1478        let candles = read_candles_from_csv(file_path)?;
1479
1480        let default_params = TrendFlexParams { period: None };
1481        let input = TrendFlexInput::from_candles(&candles, "close", default_params);
1482        let output = trendflex_with_kernel(&input, kernel)?;
1483        assert_eq!(output.values.len(), candles.close.len());
1484
1485        Ok(())
1486    }
1487
1488    fn check_trendflex_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1489        skip_if_unsupported!(kernel, test_name);
1490        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1491        let candles = read_candles_from_csv(file_path)?;
1492
1493        let params = TrendFlexParams { period: Some(20) };
1494        let input = TrendFlexInput::from_candles(&candles, "close", params);
1495        let result = trendflex_with_kernel(&input, kernel)?;
1496        let expected_last_five = [
1497            -0.19724678008015128,
1498            -0.1238001236481444,
1499            -0.10515389737087717,
1500            -0.1149541079904878,
1501            -0.16006869484450567,
1502        ];
1503        let start = result.values.len().saturating_sub(5);
1504        for (i, &val) in result.values[start..].iter().enumerate() {
1505            let diff = (val - expected_last_five[i]).abs();
1506            assert!(
1507                diff < 1e-8,
1508                "[{}] TrendFlex {:?} mismatch at idx {}: got {}, expected {}",
1509                test_name,
1510                kernel,
1511                i,
1512                val,
1513                expected_last_five[i]
1514            );
1515        }
1516        Ok(())
1517    }
1518
1519    fn check_trendflex_default_candles(
1520        test_name: &str,
1521        kernel: Kernel,
1522    ) -> Result<(), Box<dyn Error>> {
1523        skip_if_unsupported!(kernel, test_name);
1524        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1525        let candles = read_candles_from_csv(file_path)?;
1526
1527        let input = TrendFlexInput::with_default_candles(&candles);
1528        match input.data {
1529            TrendFlexData::Candles { source, .. } => assert_eq!(source, "close"),
1530            _ => panic!("Expected TrendFlexData::Candles"),
1531        }
1532        let output = trendflex_with_kernel(&input, kernel)?;
1533        assert_eq!(output.values.len(), candles.close.len());
1534
1535        Ok(())
1536    }
1537
1538    fn check_trendflex_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1539        skip_if_unsupported!(kernel, test_name);
1540        let input_data = [10.0, 20.0, 30.0];
1541        let params = TrendFlexParams { period: Some(0) };
1542        let input = TrendFlexInput::from_slice(&input_data, params);
1543        let res = trendflex_with_kernel(&input, kernel);
1544        assert!(
1545            res.is_err(),
1546            "[{}] TrendFlex should fail with zero period",
1547            test_name
1548        );
1549        Ok(())
1550    }
1551
1552    fn check_trendflex_period_exceeds_length(
1553        test_name: &str,
1554        kernel: Kernel,
1555    ) -> Result<(), Box<dyn Error>> {
1556        skip_if_unsupported!(kernel, test_name);
1557        let data_small = [10.0, 20.0, 30.0];
1558        let params = TrendFlexParams { period: Some(10) };
1559        let input = TrendFlexInput::from_slice(&data_small, params);
1560        let res = trendflex_with_kernel(&input, kernel);
1561        assert!(
1562            res.is_err(),
1563            "[{}] TrendFlex should fail with period exceeding length",
1564            test_name
1565        );
1566        Ok(())
1567    }
1568
1569    fn check_trendflex_very_small_dataset(
1570        test_name: &str,
1571        kernel: Kernel,
1572    ) -> Result<(), Box<dyn Error>> {
1573        skip_if_unsupported!(kernel, test_name);
1574        let single_point = [42.0];
1575        let params = TrendFlexParams { period: Some(9) };
1576        let input = TrendFlexInput::from_slice(&single_point, params);
1577        let res = trendflex_with_kernel(&input, kernel);
1578        assert!(
1579            res.is_err(),
1580            "[{}] TrendFlex should fail with insufficient data",
1581            test_name
1582        );
1583        Ok(())
1584    }
1585
1586    fn check_trendflex_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1587        skip_if_unsupported!(kernel, test_name);
1588        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1589        let candles = read_candles_from_csv(file_path)?;
1590
1591        let first_params = TrendFlexParams { period: Some(20) };
1592        let first_input = TrendFlexInput::from_candles(&candles, "close", first_params);
1593        let first_result = trendflex_with_kernel(&first_input, kernel)?;
1594
1595        let second_params = TrendFlexParams { period: Some(10) };
1596        let second_input = TrendFlexInput::from_slice(&first_result.values, second_params);
1597        let second_result = trendflex_with_kernel(&second_input, kernel)?;
1598
1599        assert_eq!(second_result.values.len(), first_result.values.len());
1600        if second_result.values.len() > 240 {
1601            for (i, &val) in second_result.values[240..].iter().enumerate() {
1602                assert!(
1603                    !val.is_nan(),
1604                    "[{}] Found unexpected NaN at out-index {}",
1605                    test_name,
1606                    240 + i
1607                );
1608            }
1609        }
1610        Ok(())
1611    }
1612
1613    fn check_trendflex_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1614        skip_if_unsupported!(kernel, test_name);
1615        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1616        let candles = read_candles_from_csv(file_path)?;
1617
1618        let input =
1619            TrendFlexInput::from_candles(&candles, "close", TrendFlexParams { period: Some(20) });
1620        let res = trendflex_with_kernel(&input, kernel)?;
1621        assert_eq!(res.values.len(), candles.close.len());
1622        if res.values.len() > 240 {
1623            for (i, &val) in res.values[240..].iter().enumerate() {
1624                assert!(
1625                    !val.is_nan(),
1626                    "[{}] Found unexpected NaN at out-index {}",
1627                    test_name,
1628                    240 + i
1629                );
1630            }
1631        }
1632        Ok(())
1633    }
1634
1635    fn check_trendflex_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1636        skip_if_unsupported!(kernel, test_name);
1637
1638        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1639        let candles = read_candles_from_csv(file_path)?;
1640
1641        let period = 20;
1642
1643        let input = TrendFlexInput::from_candles(
1644            &candles,
1645            "close",
1646            TrendFlexParams {
1647                period: Some(period),
1648            },
1649        );
1650        let batch_output = trendflex_with_kernel(&input, kernel)?.values;
1651
1652        let mut stream = TrendFlexStream::try_new(TrendFlexParams {
1653            period: Some(period),
1654        })?;
1655
1656        let mut stream_values = Vec::with_capacity(candles.close.len());
1657        for &price in &candles.close {
1658            match stream.update(price) {
1659                Some(tf_val) => stream_values.push(tf_val),
1660                None => stream_values.push(f64::NAN),
1661            }
1662        }
1663
1664        assert_eq!(batch_output.len(), stream_values.len());
1665        for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
1666            if b.is_nan() && s.is_nan() {
1667                continue;
1668            }
1669            let diff = (b - s).abs();
1670            assert!(
1671                diff < 1e-9,
1672                "[{}] TrendFlex streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
1673                test_name,
1674                i,
1675                b,
1676                s,
1677                diff
1678            );
1679        }
1680        Ok(())
1681    }
1682
1683    macro_rules! generate_all_trendflex_tests {
1684        ($($test_fn:ident),*) => {
1685            paste::paste! {
1686                $(
1687                    #[test]
1688                    fn [<$test_fn _scalar_f64>]() {
1689                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1690                    }
1691                )*
1692                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1693                $(
1694                    #[test]
1695                    fn [<$test_fn _avx2_f64>]() {
1696                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1697                    }
1698                    #[test]
1699                    fn [<$test_fn _avx512_f64>]() {
1700                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1701                    }
1702                )*
1703            }
1704        }
1705    }
1706
1707    #[cfg(debug_assertions)]
1708    fn check_trendflex_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1709        skip_if_unsupported!(kernel, test_name);
1710
1711        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1712        let candles = read_candles_from_csv(file_path)?;
1713
1714        let test_periods = vec![5, 10, 20, 30, 50, 80, 100, 150];
1715
1716        for &period in &test_periods {
1717            let params = TrendFlexParams {
1718                period: Some(period),
1719            };
1720            let input = TrendFlexInput::from_candles(&candles, "close", params);
1721
1722            if candles.close.len() < period {
1723                continue;
1724            }
1725
1726            let output = match trendflex_with_kernel(&input, kernel) {
1727                Ok(o) => o,
1728                Err(_) => continue,
1729            };
1730
1731            for (i, &val) in output.values.iter().enumerate() {
1732                if val.is_nan() {
1733                    continue;
1734                }
1735
1736                let bits = val.to_bits();
1737
1738                if bits == 0x11111111_11111111 {
1739                    panic!(
1740						"[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} with period {}",
1741						test_name, val, bits, i, period
1742					);
1743                }
1744
1745                if bits == 0x22222222_22222222 {
1746                    panic!(
1747						"[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} with period {}",
1748						test_name, val, bits, i, period
1749					);
1750                }
1751
1752                if bits == 0x33333333_33333333 {
1753                    panic!(
1754						"[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} with period {}",
1755						test_name, val, bits, i, period
1756					);
1757                }
1758            }
1759        }
1760
1761        Ok(())
1762    }
1763
1764    #[cfg(not(debug_assertions))]
1765    fn check_trendflex_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1766        Ok(())
1767    }
1768
1769    #[cfg(feature = "proptest")]
1770    #[allow(clippy::float_cmp)]
1771    fn check_trendflex_property(
1772        test_name: &str,
1773        kernel: Kernel,
1774    ) -> Result<(), Box<dyn std::error::Error>> {
1775        use proptest::prelude::*;
1776        skip_if_unsupported!(kernel, test_name);
1777
1778        let strat = (1usize..=64).prop_flat_map(|period| {
1779            (
1780                prop::collection::vec(
1781                    (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
1782                    period..400,
1783                ),
1784                Just(period),
1785            )
1786        });
1787
1788        proptest::test_runner::TestRunner::default()
1789            .run(&strat, |(data, period)| {
1790                let input = TrendFlexInput::from_slice(
1791                    &data,
1792                    TrendFlexParams {
1793                        period: Some(period),
1794                    },
1795                );
1796                let output = trendflex_with_kernel(&input, kernel)?;
1797
1798                prop_assert_eq!(output.values.len(), data.len(), "Output length mismatch");
1799
1800                let first = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
1801                let warmup = first + period;
1802
1803                for i in 0..warmup.min(data.len()) {
1804                    prop_assert!(
1805                        output.values[i].is_nan(),
1806                        "Expected NaN in warmup period at index {}, got {}",
1807                        i,
1808                        output.values[i]
1809                    );
1810                }
1811
1812                for i in warmup..output.values.len() {
1813                    prop_assert!(
1814                        output.values[i].is_finite(),
1815                        "Output at index {} is not finite: {}",
1816                        i,
1817                        output.values[i]
1818                    );
1819                }
1820
1821                if data.len() > warmup + 10 {
1822                    let scale_factor = 10.0;
1823                    let scaled_data: Vec<f64> = data.iter().map(|&x| x * scale_factor).collect();
1824                    let scaled_input = TrendFlexInput::from_slice(
1825                        &scaled_data,
1826                        TrendFlexParams {
1827                            period: Some(period),
1828                        },
1829                    );
1830                    let scaled_output = trendflex_with_kernel(&scaled_input, kernel)?;
1831
1832                    let mut similarity_count = 0;
1833                    let mut total_compared = 0;
1834                    for i in warmup..output.values.len() {
1835                        if output.values[i].is_finite() && scaled_output.values[i].is_finite() {
1836                            let diff = (output.values[i] - scaled_output.values[i]).abs();
1837
1838                            if diff < 0.5 {
1839                                similarity_count += 1;
1840                            }
1841                            total_compared += 1;
1842                        }
1843                    }
1844
1845                    if total_compared > 0 {
1846                        let similarity_ratio = similarity_count as f64 / total_compared as f64;
1847                        prop_assert!(
1848							similarity_ratio > 0.9,
1849							"Scale invariance failed: only {:.1}% of values are similar after scaling",
1850							similarity_ratio * 100.0
1851						);
1852                    }
1853                }
1854
1855                if data.len() > warmup + 20 {
1856                    let mut is_increasing = true;
1857                    let mut is_decreasing = true;
1858                    for i in (warmup + 1)..data.len().min(warmup + 50) {
1859                        if data[i] <= data[i - 1] {
1860                            is_increasing = false;
1861                        }
1862                        if data[i] >= data[i - 1] {
1863                            is_decreasing = false;
1864                        }
1865                    }
1866
1867                    if is_increasing {
1868                        let positive_count =
1869                            output.values[warmup..].iter().filter(|&&v| v > 0.0).count();
1870                        let total = output.values.len() - warmup;
1871                        let positive_ratio = positive_count as f64 / total as f64;
1872                        prop_assert!(
1873							positive_ratio > 0.7,
1874							"Increasing trend should produce mostly positive values, got {:.1}% positive",
1875							positive_ratio * 100.0
1876						);
1877                    } else if is_decreasing {
1878                        let negative_count =
1879                            output.values[warmup..].iter().filter(|&&v| v < 0.0).count();
1880                        let total = output.values.len() - warmup;
1881                        let negative_ratio = negative_count as f64 / total as f64;
1882                        prop_assert!(
1883							negative_ratio > 0.7,
1884							"Decreasing trend should produce mostly negative values, got {:.1}% negative",
1885							negative_ratio * 100.0
1886						);
1887                    }
1888                }
1889
1890                let all_same = data[first..]
1891                    .windows(2)
1892                    .all(|w| (w[0] - w[1]).abs() < 1e-10);
1893                if all_same && data.len() > warmup + 10 {
1894                    let last_values = &output.values[(data.len() - 5)..];
1895                    for val in last_values {
1896                        prop_assert!(
1897                            val.abs() < 0.1,
1898                            "Constant input should produce values near 0, got {}",
1899                            val
1900                        );
1901                    }
1902                }
1903
1904                if period == 1 {
1905                    for i in (first + 1)..output.values.len() {
1906                        prop_assert!(
1907                            output.values[i].is_finite(),
1908                            "Period=1 should still produce finite values at index {}",
1909                            i
1910                        );
1911                    }
1912                }
1913
1914                if data.len() > 5 && period >= data.len().saturating_sub(5) && data.len() > period {
1915                    let last_idx = data.len() - 1;
1916                    if last_idx >= warmup {
1917                        prop_assert!(
1918                            output.values[last_idx].is_finite(),
1919                            "Large period should still produce finite values at the end"
1920                        );
1921                    }
1922                }
1923
1924                if cfg!(all(feature = "nightly-avx", target_arch = "x86_64")) {
1925                    let scalar_output = trendflex_with_kernel(&input, Kernel::Scalar)?;
1926
1927                    for i in 0..output.values.len() {
1928                        if output.values[i].is_finite() && scalar_output.values[i].is_finite() {
1929                            prop_assert!(
1930                                (output.values[i] - scalar_output.values[i]).abs() < 1e-9,
1931                                "Kernel consistency failed at index {}: {} vs {}",
1932                                i,
1933                                output.values[i],
1934                                scalar_output.values[i]
1935                            );
1936                        } else {
1937                            prop_assert_eq!(
1938                                output.values[i].is_nan(),
1939                                scalar_output.values[i].is_nan(),
1940                                "NaN mismatch between kernels at index {}",
1941                                i
1942                            );
1943                        }
1944                    }
1945                }
1946
1947                Ok(())
1948            })
1949            .map_err(|e| e.into())
1950    }
1951
    // Property-based check, compiled only when the `proptest` feature is enabled.
    // `generate_all_trendflex_tests!` is defined outside this view; presumably it expands the
    // check into one `#[test]` per kernel — TODO confirm against the macro definition.
    #[cfg(feature = "proptest")]
    generate_all_trendflex_tests!(check_trendflex_property);
1954
1955    #[test]
1956    fn test_trendflex_into_slice_validation() {
1957        let data = vec![1.0, 2.0, 3.0];
1958        let params = TrendFlexParams { period: Some(10) };
1959        let input = TrendFlexInput::from_slice(&data, params);
1960        let mut out = vec![0.0; data.len()];
1961
1962        let result = trendflex_into_slice(&mut out, &input, Kernel::Scalar);
1963        assert!(result.is_err());
1964        match result {
1965            Err(TrendFlexError::TrendFlexPeriodExceedsData { period, data_len }) => {
1966                assert_eq!(period, 10);
1967                assert_eq!(data_len, 3);
1968            }
1969            _ => panic!("Expected TrendFlexPeriodExceedsData error"),
1970        }
1971
1972        let empty_data: Vec<f64> = vec![];
1973        let params = TrendFlexParams { period: Some(5) };
1974        let input = TrendFlexInput::from_slice(&empty_data, params);
1975        let mut out = vec![];
1976
1977        let result = trendflex_into_slice(&mut out, &input, Kernel::Scalar);
1978        assert!(result.is_err());
1979        match result {
1980            Err(TrendFlexError::NoDataProvided) => {}
1981            _ => panic!("Expected NoDataProvided error"),
1982        }
1983
1984        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
1985        let params = TrendFlexParams { period: Some(0) };
1986        let input = TrendFlexInput::from_slice(&data, params);
1987        let mut out = vec![0.0; data.len()];
1988
1989        let result = trendflex_into_slice(&mut out, &input, Kernel::Scalar);
1990        assert!(result.is_err());
1991        match result {
1992            Err(TrendFlexError::ZeroTrendFlexPeriod { period }) => {
1993                assert_eq!(period, 0);
1994            }
1995            _ => panic!("Expected ZeroTrendFlexPeriod error"),
1996        }
1997
1998        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
1999        let params = TrendFlexParams { period: Some(3) };
2000        let input = TrendFlexInput::from_slice(&data, params);
2001        let mut out = vec![0.0; data.len()];
2002
2003        let result = trendflex_into_slice(&mut out, &input, Kernel::Scalar);
2004        assert!(result.is_ok());
2005    }
2006
2007    #[test]
2008    #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
2009    fn test_trendflex_into_matches_api() -> Result<(), Box<dyn Error>> {
2010        let n = 512usize;
2011        let mut data = Vec::with_capacity(n);
2012        for i in 0..n {
2013            let t = i as f64;
2014            data.push(0.01 * t + (t * 0.05).sin());
2015        }
2016
2017        let input = TrendFlexInput::from_slice(&data, TrendFlexParams::default());
2018        let baseline = trendflex(&input)?.values;
2019
2020        let mut out = vec![0.0f64; n];
2021        trendflex_into(&input, &mut out)?;
2022
2023        assert_eq!(baseline.len(), out.len());
2024        for i in 0..n {
2025            let a = baseline[i];
2026            let b = out[i];
2027            let equal = if a.is_nan() && b.is_nan() {
2028                true
2029            } else {
2030                (a - b).abs() <= 1e-12
2031            };
2032            assert!(equal, "divergence at {}: {} vs {}", i, a, b);
2033        }
2034        Ok(())
2035    }
2036
2037    #[test]
2038    fn test_trendflex_batch_kernel_policy() {
2039        let data = vec![1.0; 50];
2040        let sweep = TrendFlexBatchRange { period: (5, 10, 1) };
2041
2042        let result_scalar = trendflex_batch_with_kernel(&data, &sweep, Kernel::Scalar);
2043        assert!(matches!(
2044            result_scalar,
2045            Err(TrendFlexError::InvalidKernelForBatch(Kernel::Scalar))
2046        ));
2047
2048        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2049        {
2050            let result_avx2 = trendflex_batch_with_kernel(&data, &sweep, Kernel::Avx2);
2051            assert!(matches!(
2052                result_avx2,
2053                Err(TrendFlexError::InvalidKernelForBatch(Kernel::Avx2))
2054            ));
2055
2056            let result_avx512 = trendflex_batch_with_kernel(&data, &sweep, Kernel::Avx512);
2057            assert!(matches!(
2058                result_avx512,
2059                Err(TrendFlexError::InvalidKernelForBatch(Kernel::Avx512))
2060            ));
2061        }
2062
2063        let result_scalar_batch = trendflex_batch_with_kernel(&data, &sweep, Kernel::ScalarBatch);
2064        assert!(result_scalar_batch.is_ok());
2065    }
2066
    // Instantiate the single-series checks via `generate_all_trendflex_tests!` (defined outside
    // this view). Presumably each listed `check_*` function is expanded into per-kernel
    // `#[test]` wrappers — TODO confirm against the macro definition.
    generate_all_trendflex_tests!(
        check_trendflex_partial_params,
        check_trendflex_accuracy,
        check_trendflex_default_candles,
        check_trendflex_zero_period,
        check_trendflex_period_exceeds_length,
        check_trendflex_very_small_dataset,
        check_trendflex_reinput,
        check_trendflex_nan_handling,
        check_trendflex_streaming,
        check_trendflex_no_poison
    );
2079
2080    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2081        skip_if_unsupported!(kernel, test);
2082
2083        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2084        let c = read_candles_from_csv(file)?;
2085
2086        let output = TrendFlexBatchBuilder::new()
2087            .kernel(kernel)
2088            .apply_candles(&c, "close")?;
2089
2090        let def = TrendFlexParams::default();
2091        let row = output.values_for(&def).expect("default row missing");
2092
2093        assert_eq!(row.len(), c.close.len());
2094
2095        let expected = [
2096            -0.19724678008015128,
2097            -0.1238001236481444,
2098            -0.10515389737087717,
2099            -0.1149541079904878,
2100            -0.16006869484450567,
2101        ];
2102        let start = row.len() - 5;
2103        for (i, &v) in row[start..].iter().enumerate() {
2104            assert!(
2105                (v - expected[i]).abs() < 1e-8,
2106                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
2107            );
2108        }
2109        Ok(())
2110    }
2111
2112    macro_rules! gen_batch_tests {
2113        ($fn_name:ident) => {
2114            paste::paste! {
2115                #[test] fn [<$fn_name _scalar>]()      {
2116                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
2117                }
2118                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2119                #[test] fn [<$fn_name _avx2>]()        {
2120                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
2121                }
2122                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2123                #[test] fn [<$fn_name _avx512>]()      {
2124                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
2125                }
2126                #[test] fn [<$fn_name _auto_detect>]() {
2127                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
2128                }
2129            }
2130        };
2131    }
2132
2133    #[cfg(debug_assertions)]
2134    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2135        skip_if_unsupported!(kernel, test);
2136
2137        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2138        let c = read_candles_from_csv(file)?;
2139
2140        let test_configs = vec![
2141            (5, 20, 3),
2142            (10, 50, 5),
2143            (20, 100, 10),
2144            (30, 120, 15),
2145            (7, 7, 1),
2146            (80, 80, 1),
2147            (15, 45, 5),
2148        ];
2149
2150        for (start, end, step) in test_configs {
2151            let output = TrendFlexBatchBuilder::new()
2152                .kernel(kernel)
2153                .period_range(start, end, step)
2154                .apply_candles(&c, "close")?;
2155
2156            for (idx, &val) in output.values.iter().enumerate() {
2157                if val.is_nan() {
2158                    continue;
2159                }
2160
2161                let bits = val.to_bits();
2162                let row = idx / output.cols;
2163                let col = idx % output.cols;
2164                let period = output
2165                    .combos
2166                    .get(row)
2167                    .map(|p| p.period.unwrap_or(0))
2168                    .unwrap_or(0);
2169
2170                if bits == 0x11111111_11111111 {
2171                    panic!(
2172                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at row {} col {} (period {}, flat index {})",
2173                        test, val, bits, row, col, period, idx
2174                    );
2175                }
2176
2177                if bits == 0x22222222_22222222 {
2178                    panic!(
2179                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at row {} col {} (period {}, flat index {})",
2180                        test, val, bits, row, col, period, idx
2181                    );
2182                }
2183
2184                if bits == 0x33333333_33333333 {
2185                    panic!(
2186                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at row {} col {} (period {}, flat index {})",
2187                        test, val, bits, row, col, period, idx
2188                    );
2189                }
2190            }
2191        }
2192
2193        Ok(())
2194    }
2195
2196    #[cfg(not(debug_assertions))]
2197    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2198        Ok(())
2199    }
2200
    // Register the batch checks with `gen_batch_tests!`. It emits `_scalar` and `_auto_detect`
    // `#[test]` wrappers for each check, plus `_avx2` / `_avx512` on nightly-avx x86_64 builds.
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
2203}
2204
2205#[cfg(feature = "python")]
2206use pyo3::exceptions::PyValueError;
2207#[cfg(feature = "python")]
2208use pyo3::prelude::*;
2209
// Python entry point `trendflex(data, period=None, kernel=None)`.
// The kernel name is checked with `validate_kernel(.., false)` and the computation runs with
// the GIL released. Any `TrendFlexError` is raised as a Python `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "trendflex")]
#[pyo3(signature = (data, period=None, kernel=None))]
pub fn trendflex_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period: Option<usize>,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let series = data.as_slice()?;
    let chosen_kernel = validate_kernel(kernel, false)?;
    let input = TrendFlexInput::from_slice(series, TrendFlexParams { period });

    let computed = py.allow_threads(|| trendflex_with_kernel(&input, chosen_kernel));
    match computed {
        Ok(output) => Ok(output.values.into_pyarray(py)),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
2234
// Python-facing streaming wrapper exposed as `TrendFlexStream`. It forwards each value to the
// native `TrendFlexStream` and relays its output.
#[cfg(feature = "python")]
#[pyclass(name = "TrendFlexStream")]
pub struct TrendFlexStreamPy {
    stream: TrendFlexStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl TrendFlexStreamPy {
    // Builds the native stream from an optional period; construction errors become ValueError.
    #[new]
    fn new(period: Option<usize>) -> PyResult<Self> {
        TrendFlexStream::try_new(TrendFlexParams { period })
            .map(|stream| Self { stream })
            .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    // Feeds one sample and returns exactly what the native stream's `update` returns.
    fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
2256
// Python entry point `trendflex_batch(data, period_range, kernel=None)`.
//
// Expands `period_range` (start, end, step) into parameter combinations. Fills a flat
// `rows * cols` float64 array (one row per combination, one column per input sample) with the
// GIL released. Returns a dict with `values` reshaped to `(rows, cols)` and `periods` (one entry
// per row).
//
// Errors (invalid range, size overflow, non-batch kernel, computation failure) are raised as
// `ValueError`. A kernel that does not resolve to a batch kernel now yields
// `TrendFlexError::InvalidKernelForBatch`, matching `trendflex_batch_with_kernel`. Previously it
// hit `unreachable!()` and panicked inside `allow_threads`.
#[cfg(feature = "python")]
#[pyfunction(name = "trendflex_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn trendflex_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, true)?;

    let sweep = TrendFlexBatchRange {
        period: period_range,
    };

    let combos = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let rows = combos.len();
    let cols = slice_in.len();
    // Bind the checked product so the allocation below uses the verified size.
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("dimensions overflow"))?;

    // The array is created uninitialised; this relies on `trendflex_batch_inner_into` writing
    // every element on success — NOTE(review): confirm against its implementation.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| -> Result<Vec<TrendFlexParams>, TrendFlexError> {
            let batch_kernel = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k => k,
            };
            // Map the batch kernel to the per-row SIMD kernel expected by the inner routine.
            let simd = match batch_kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                other => return Err(TrendFlexError::InvalidKernelForBatch(other)),
            };

            trendflex_batch_inner_into(slice_in, &sweep, simd, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;

    dict.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap_or(20) as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;

    Ok(dict)
}
2317
// Python entry point `trendflex_cuda_batch_dev(data_f32, period_range, device_id=0)`.
//
// Runs the period sweep through `CudaTrendflex::trendflex_batch_dev` and calls `synchronize()`
// before returning. The work runs with the GIL released. Returns the device array wrapper plus
// a dict whose `periods` entry lists the period of each combination. Raises `ValueError` if CUDA
// is unavailable or any CUDA step fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "trendflex_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, device_id=0))]
pub fn trendflex_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    data_f32: numpy::PyReadonlyArray1<'py, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<(TrendFlexDeviceArrayF32Py, Bound<'py, pyo3::types::PyDict>)> {
    use crate::cuda::cuda_available;
    use numpy::IntoPyArray;
    use pyo3::types::PyDict;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let host = data_f32.as_slice()?;
    let sweep = TrendFlexBatchRange {
        period: period_range,
    };

    let (inner, combos, ctx, dev) = py.allow_threads(|| {
        let engine =
            CudaTrendflex::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let (device_out, combos) = engine
            .trendflex_batch_dev(host, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        engine
            .synchronize()
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, PyErr>((
            device_out,
            combos,
            engine.context_arc_clone(),
            engine.device_id(),
        ))
    })?;

    let periods: Vec<u64> = combos.iter().map(|c| c.period.unwrap() as u64).collect();
    let meta = PyDict::new(py);
    meta.set_item("periods", periods.into_pyarray(py))?;

    let handle = TrendFlexDeviceArrayF32Py {
        inner,
        _ctx: ctx,
        device_id: dev,
    };
    Ok((handle, meta))
}
2365
// Python entry point `trendflex_cuda_many_series_one_param_dev(data_tm_f32, period, device_id=0)`.
//
// Takes a 2-D f32 array; its shape is read as `(rows, cols)` and the contiguous buffer is passed
// as time-major data. The kernel is called with `(cols, rows)`, presumably (num_series,
// series_len) — TODO confirm. A single `period` is applied, `synchronize()` is called, and the
// device array wrapper is returned. Raises `ValueError` when CUDA is unavailable or a CUDA step
// fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "trendflex_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, device_id=0))]
pub fn trendflex_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<TrendFlexDeviceArrayF32Py> {
    use crate::cuda::cuda_available;
    use numpy::PyUntypedArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let flat_in = data_tm_f32.as_slice()?;
    let (rows, cols) = {
        let shape = data_tm_f32.shape();
        (shape[0], shape[1])
    };
    let params = TrendFlexParams {
        period: Some(period),
    };

    let (inner, ctx, dev) = py.allow_threads(|| {
        let engine =
            CudaTrendflex::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let device_out = engine
            .trendflex_multi_series_one_param_time_major_dev(flat_in, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        engine
            .synchronize()
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, PyErr>((device_out, engine.context_arc_clone(), engine.device_id()))
    })?;

    Ok(TrendFlexDeviceArrayF32Py {
        inner,
        _ctx: ctx,
        device_id: dev,
    })
}
2406
2407#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2408use serde::{Deserialize, Serialize};
2409#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2410use wasm_bindgen::prelude::*;
2411
/// Configuration object accepted by the JS `trendflex_batch` export.
/// It is deserialized from a JS value with `serde_wasm_bindgen::from_value`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct TrendFlexBatchConfig {
    /// `(start, end, step)` period sweep, forwarded unchanged as `TrendFlexBatchRange::period`.
    pub period_range: (usize, usize, usize),
}
2417
/// Result object returned by the JS `trendflex_batch` export.
/// It is serialized with `serde_wasm_bindgen::to_value`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct TrendFlexBatchJsOutput {
    /// Flattened output of `rows * cols` values, presumably row-major with one row per
    /// parameter combination — confirm against the batch implementation.
    pub values: Vec<f64>,
    /// Parameter set for each row, copied from the batch output.
    pub combos: Vec<TrendFlexParams>,
    /// Row count reported by the batch output.
    pub rows: usize,
    /// Column count reported by the batch output (presumably the input length).
    pub cols: usize,
}
2426
/// Computes TrendFlex over `data` with the given `period`, letting the library pick the kernel.
/// Errors are returned to JS as their display string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn trendflex_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = TrendFlexInput::from_slice(data, TrendFlexParams { period: Some(period) });
    match trendflex_with_kernel(&input, Kernel::Auto) {
        Ok(output) => Ok(output.values),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
2440
/// Runs a TrendFlex period sweep `(period_start, period_end, period_step)` over `data`.
/// Returns only the flattened values; pair it with `trendflex_batch_metadata_js` to recover the
/// period of each row. Errors are returned as their display string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn trendflex_batch_js(
    data: &[f64],
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let sweep = TrendFlexBatchRange {
        period: (period_start, period_end, period_step),
    };

    match trendflex_batch_inner(data, &sweep, Kernel::Auto, false) {
        Ok(batch) => Ok(batch.values),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
2458
/// Lists, as `f64`, the period of every combination produced by expanding
/// `(period_start, period_end, period_step)`. A combination without a period reports 20.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn trendflex_batch_metadata_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let sweep = TrendFlexBatchRange {
        period: (period_start, period_end, period_step),
    };

    let grid = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    Ok(grid
        .iter()
        .map(|combo| combo.period.unwrap_or(20) as f64)
        .collect())
}
2479
/// JS `trendflex_batch(data, config)`.
///
/// `config` must deserialize into [`TrendFlexBatchConfig`]. The result is a serialized
/// [`TrendFlexBatchJsOutput`] with values, combos, rows and cols. Deserialization, computation and
/// serialization failures are returned as strings.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = trendflex_batch)]
pub fn trendflex_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let TrendFlexBatchConfig { period_range } =
        serde_wasm_bindgen::from_value::<TrendFlexBatchConfig>(config)
            .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;
    let sweep = TrendFlexBatchRange {
        period: period_range,
    };

    let batch = trendflex_batch_inner(data, &sweep, Kernel::Auto, false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let payload = TrendFlexBatchJsOutput {
        rows: batch.rows,
        cols: batch.cols,
        values: batch.values,
        combos: batch.combos,
    };

    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2503
/// Reserves an uninitialised buffer for `len` f64 values and hands ownership to the caller.
/// Release it with `trendflex_free(ptr, len)` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn trendflex_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop keeps the allocation alive past this function, exactly like `mem::forget`.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
2512
/// Releases a buffer obtained from `trendflex_alloc(len)`.
///
/// `len` must be the value originally passed to `trendflex_alloc`. A null pointer is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn trendflex_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr`/`len` must describe an allocation made by `trendflex_alloc(len)`, i.e. a
    // leaked `Vec<f64>` of capacity `len`. The length is set to 0 so no element is assumed to be
    // initialised (the buffer may never have been written, and `from_raw_parts` requires the
    // first `length` elements to be initialised). `f64` has no destructor, so dropping the Vec
    // only deallocates.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
2520
/// Pointer-based TrendFlex for JS callers that manage wasm memory directly.
///
/// Reads `len` values at `in_ptr` and writes `len` results at `out_ptr`. In-place use
/// (`in_ptr == out_ptr`) goes through a temporary buffer. Errors on null pointers, on
/// `period == 0`, and on `period >= len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn trendflex_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to trendflex_into"));
    }
    if period == 0 || period >= len {
        return Err(JsValue::from_str("Invalid period"));
    }

    let to_js = |e: TrendFlexError| JsValue::from_str(&e.to_string());

    // SAFETY: the caller guarantees `in_ptr` points to `len` readable f64 values and `out_ptr`
    // to `len` writable f64 values.
    unsafe {
        let series = std::slice::from_raw_parts(in_ptr, len);
        let input = TrendFlexInput::from_slice(
            series,
            TrendFlexParams {
                period: Some(period),
            },
        );

        if in_ptr != out_ptr {
            let dst = std::slice::from_raw_parts_mut(out_ptr, len);
            trendflex_into_slice(dst, &input, detect_best_kernel()).map_err(to_js)?;
        } else {
            // Input and output alias: compute into scratch, then copy over the shared buffer.
            let mut scratch = vec![0.0; len];
            trendflex_into_slice(&mut scratch, &input, detect_best_kernel()).map_err(to_js)?;
            std::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&scratch);
        }
    }
    Ok(())
}
2556
/// Pointer-based TrendFlex batch sweep for JS callers.
///
/// Expands `(period_start, period_end, period_step)` into combinations and writes
/// `combinations * len` values at `out_ptr`. Returns the number of combinations. Errors on null
/// pointers, an invalid range, size overflow, or a computation failure.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn trendflex_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str(
            "null pointer passed to trendflex_batch_into",
        ));
    }

    let sweep = TrendFlexBatchRange {
        period: (period_start, period_end, period_step),
    };
    let n_combos = expand_grid(&sweep)
        .map_err(|e| JsValue::from_str(&e.to_string()))?
        .len();
    let total_size = n_combos
        .checked_mul(len)
        .ok_or_else(|| JsValue::from_str("dimensions overflow"))?;

    // SAFETY: the caller guarantees `in_ptr` addresses `len` readable f64 values and `out_ptr`
    // addresses `n_combos * len` writable f64 values.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let out_slice = std::slice::from_raw_parts_mut(out_ptr, total_size);
        trendflex_batch_inner_into(data, &sweep, Kernel::Auto, false, out_slice)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(n_combos)
}