Skip to main content

vector_ta/indicators/
alphatrend.rs

1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::{PyDict, PyList};
9
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use serde::{Deserialize, Serialize};
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use wasm_bindgen::prelude::*;
14
15use crate::utilities::data_loader::{source_type, CandleFieldFlags, Candles};
16use crate::utilities::enums::Kernel;
17use crate::utilities::helpers::{
18    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
19    make_uninit_matrix,
20};
21#[cfg(feature = "python")]
22use crate::utilities::kernel_validation::validate_kernel;
23use aligned_vec::{AVec, CACHELINE_ALIGN};
24
25#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
26use core::arch::x86_64::*;
27
28#[cfg(not(target_arch = "wasm32"))]
29use rayon::prelude::*;
30
31use std::convert::AsRef;
32use std::error::Error;
33use std::mem::MaybeUninit;
34use thiserror::Error;
35
36#[cfg(all(feature = "python", feature = "cuda"))]
37use crate::cuda::alphatrend_wrapper::CudaAlphaTrend;
38use crate::indicators::mfi::{mfi_with_kernel, MfiInput, MfiParams};
39use crate::indicators::rsi::{rsi_with_kernel, RsiInput, RsiParams};
40#[cfg(all(feature = "python", feature = "cuda"))]
41use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
42#[cfg(all(feature = "python", feature = "cuda"))]
43use cust::context::Context;
44#[cfg(all(feature = "python", feature = "cuda"))]
45use cust::memory::DeviceBuffer;
46#[cfg(all(feature = "python", feature = "cuda"))]
47use std::sync::Arc;
48
49impl<'a> AsRef<[f64]> for AlphaTrendInput<'a> {
50    #[inline(always)]
51    fn as_ref(&self) -> &[f64] {
52        match &self.data {
53            AlphaTrendData::Slices { close, .. } => close,
54            AlphaTrendData::Candles { candles, .. } => &candles.close,
55        }
56    }
57}
58
/// Where AlphaTrend reads its OHLCV series from.
#[derive(Debug, Clone)]
pub enum AlphaTrendData<'a> {
    /// Borrow all five series from a `Candles` container.
    Candles {
        candles: &'a Candles,
    },
    /// Borrow each series from a caller-supplied slice. All five must have the same
    /// length (checked by `alphatrend_prepare`); `open` is carried along but is not
    /// read by the compute kernels.
    Slices {
        open: &'a [f64],
        high: &'a [f64],
        low: &'a [f64],
        close: &'a [f64],
        volume: &'a [f64],
    },
}
72
/// Result of an AlphaTrend computation; both vectors have the input's length.
#[derive(Debug, Clone)]
pub struct AlphaTrendOutput {
    /// The AlphaTrend line. The first `first + period - 1` entries are NaN, where
    /// `first` is the index of the first non-NaN close.
    pub k1: Vec<f64>,
    /// `k1` delayed by two bars, so it has two more leading NaN entries than `k1`.
    pub k2: Vec<f64>,
}
78
/// AlphaTrend parameters. `None` fields fall back to defaults when they are read.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct AlphaTrendParams {
    /// Multiplier on the average true range used to build the bands.
    /// Default `1.0`; must be finite and `> 0`.
    pub coeff: Option<f64>,
    /// Lookback for the true-range average and for the momentum indicator.
    /// Default `14`; must be `>= 1` and no larger than the data length.
    pub period: Option<usize>,
    /// When `true`, momentum is the RSI of close. Otherwise it is the MFI of
    /// `(high + low + close) / 3` with volume. Default `false`.
    pub no_volume: Option<bool>,
}
89
90impl Default for AlphaTrendParams {
91    fn default() -> Self {
92        Self {
93            coeff: Some(1.0),
94            period: Some(14),
95            no_volume: Some(false),
96        }
97    }
98}
99
/// Data plus parameters for a single AlphaTrend computation.
#[derive(Debug, Clone)]
pub struct AlphaTrendInput<'a> {
    /// Source series (candles or individual slices).
    pub data: AlphaTrendData<'a>,
    /// Parameters; unset entries are resolved by the `get_*` accessors.
    pub params: AlphaTrendParams,
}
105
106impl<'a> AlphaTrendInput<'a> {
107    #[inline]
108    pub fn from_candles(c: &'a Candles, p: AlphaTrendParams) -> Self {
109        Self {
110            data: AlphaTrendData::Candles { candles: c },
111            params: p,
112        }
113    }
114
115    #[inline]
116    pub fn from_slices(
117        open: &'a [f64],
118        high: &'a [f64],
119        low: &'a [f64],
120        close: &'a [f64],
121        volume: &'a [f64],
122        p: AlphaTrendParams,
123    ) -> Self {
124        Self {
125            data: AlphaTrendData::Slices {
126                open,
127                high,
128                low,
129                close,
130                volume,
131            },
132            params: p,
133        }
134    }
135
136    #[inline]
137    pub fn with_default_candles(c: &'a Candles) -> Self {
138        Self::from_candles(c, AlphaTrendParams::default())
139    }
140
141    #[inline]
142    pub fn get_coeff(&self) -> f64 {
143        self.params.coeff.unwrap_or(1.0)
144    }
145
146    #[inline]
147    pub fn get_period(&self) -> usize {
148        self.params.period.unwrap_or(14)
149    }
150
151    #[inline]
152    pub fn get_no_volume(&self) -> bool {
153        self.params.no_volume.unwrap_or(false)
154    }
155}
156
/// Fluent configuration for AlphaTrend.
///
/// Parameters left unset fall back to the defaults applied by `AlphaTrendInput`'s
/// accessors or `AlphaTrendStream::try_new` (`coeff = 1.0`, `period = 14`,
/// `no_volume = false`). `kernel` starts as `Kernel::Auto`.
#[derive(Copy, Clone, Debug)]
pub struct AlphaTrendBuilder {
    coeff: Option<f64>,
    period: Option<usize>,
    no_volume: Option<bool>,
    kernel: Kernel,
}
164
165impl Default for AlphaTrendBuilder {
166    fn default() -> Self {
167        Self {
168            coeff: None,
169            period: None,
170            no_volume: None,
171            kernel: Kernel::Auto,
172        }
173    }
174}
175
176impl AlphaTrendBuilder {
177    #[inline(always)]
178    pub fn new() -> Self {
179        Self::default()
180    }
181
182    #[inline(always)]
183    pub fn coeff(mut self, val: f64) -> Self {
184        self.coeff = Some(val);
185        self
186    }
187
188    #[inline(always)]
189    pub fn period(mut self, val: usize) -> Self {
190        self.period = Some(val);
191        self
192    }
193
194    #[inline(always)]
195    pub fn no_volume(mut self, val: bool) -> Self {
196        self.no_volume = Some(val);
197        self
198    }
199
200    #[inline(always)]
201    pub fn kernel(mut self, k: Kernel) -> Self {
202        self.kernel = k;
203        self
204    }
205
206    #[inline(always)]
207    pub fn apply(self, c: &Candles) -> Result<AlphaTrendOutput, AlphaTrendError> {
208        let p = AlphaTrendParams {
209            coeff: self.coeff,
210            period: self.period,
211            no_volume: self.no_volume,
212        };
213        let i = AlphaTrendInput::from_candles(c, p);
214        alphatrend_with_kernel(&i, self.kernel)
215    }
216
217    #[inline(always)]
218    pub fn apply_slice(
219        self,
220        open: &[f64],
221        high: &[f64],
222        low: &[f64],
223        close: &[f64],
224        volume: &[f64],
225    ) -> Result<AlphaTrendOutput, AlphaTrendError> {
226        let p = AlphaTrendParams {
227            coeff: self.coeff,
228            period: self.period,
229            no_volume: self.no_volume,
230        };
231        let i = AlphaTrendInput::from_slices(open, high, low, close, volume, p);
232        alphatrend_with_kernel(&i, self.kernel)
233    }
234
235    #[inline(always)]
236    pub fn into_stream(self) -> Result<AlphaTrendStream, AlphaTrendError> {
237        let p = AlphaTrendParams {
238            coeff: self.coeff,
239            period: self.period,
240            no_volume: self.no_volume,
241        };
242        AlphaTrendStream::try_new(p)
243    }
244}
245
/// Errors reported by AlphaTrend validation, computation and output handling.
#[derive(Debug, Error)]
pub enum AlphaTrendError {
    /// The close series has length zero.
    #[error("alphatrend: Input data slice is empty.")]
    EmptyInputData,

    /// Every close value is NaN, so there is no starting index.
    #[error("alphatrend: All values are NaN.")]
    AllValuesNaN,

    /// `period` is zero or exceeds the data length. The streaming constructor reports
    /// `data_len: 0` because it has no data yet.
    #[error("alphatrend: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },

    /// Fewer than `needed` (= `period`) values remain from the first non-NaN close onward.
    #[error("alphatrend: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },

    /// The open/high/low/volume slices do not all match the close length.
    #[error("alphatrend: Inconsistent data lengths")]
    InconsistentDataLengths,

    /// A caller-provided output slice does not have the input's length.
    #[error("alphatrend: Output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },

    /// The coefficient is zero, negative, or not finite.
    #[error("alphatrend: Invalid coefficient: {coeff}")]
    InvalidCoeff { coeff: f64 },

    /// The RSI used as momentum (when `no_volume` is set) failed; `msg` is its message.
    #[error("alphatrend: RSI calculation failed: {msg}")]
    RsiError { msg: String },

    /// The MFI used as momentum (the default mode) failed; `msg` is its message.
    #[error("alphatrend: MFI calculation failed: {msg}")]
    MfiError { msg: String },

    /// NOTE(review): not constructed in the code visible here; presumably reported by
    /// batch parameter sweeps over integer ranges — confirm.
    #[error("alphatrend: Invalid range (usize): start={start} end={end} step={step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },

    /// NOTE(review): not constructed in the code visible here; presumably reported by
    /// batch parameter sweeps over floating-point ranges — confirm.
    #[error("alphatrend: Invalid range (f64): start={start} end={end} step={step}")]
    InvalidRangeF64 { start: f64, end: f64, step: f64 },

    /// NOTE(review): not constructed in the code visible here; presumably raised when a
    /// non-batch kernel is passed to a batch entry point — confirm.
    #[error("alphatrend: Invalid kernel for batch path: {0:?}")]
    InvalidKernelForBatch(Kernel),

    /// Free-form invalid-input error (not constructed in the code visible here).
    #[error("alphatrend: invalid input: {0}")]
    InvalidInput(String),
}
291
292#[inline]
293pub fn alphatrend(input: &AlphaTrendInput) -> Result<AlphaTrendOutput, AlphaTrendError> {
294    alphatrend_with_kernel(input, Kernel::Auto)
295}
296
297pub fn alphatrend_with_kernel(
298    input: &AlphaTrendInput,
299    kernel: Kernel,
300) -> Result<AlphaTrendOutput, AlphaTrendError> {
301    let (open, high, low, close, volume, coeff, period, no_volume, first, chosen) =
302        alphatrend_prepare(input, kernel)?;
303
304    let len = close.len();
305    let warm = first + period - 1;
306
307    let mut k1 = alloc_with_nan_prefix(len, warm);
308    let mut k2 = alloc_with_nan_prefix(len, warm + 2);
309
310    alphatrend_compute_into(
311        open, high, low, close, volume, coeff, period, no_volume, first, chosen, &mut k1, &mut k2,
312    )?;
313
314    Ok(AlphaTrendOutput { k1, k2 })
315}
316
317#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
318#[inline]
319pub fn alphatrend_into(
320    input: &AlphaTrendInput,
321    out_k1: &mut [f64],
322    out_k2: &mut [f64],
323) -> Result<(), AlphaTrendError> {
324    alphatrend_into_slices(out_k1, out_k2, input, Kernel::Auto)
325}
326
327#[inline]
328pub fn alphatrend_into_slices(
329    dst_k1: &mut [f64],
330    dst_k2: &mut [f64],
331    input: &AlphaTrendInput,
332    kern: Kernel,
333) -> Result<(), AlphaTrendError> {
334    let (open, high, low, close, volume, coeff, period, no_volume, first, chosen) =
335        alphatrend_prepare(input, kern)?;
336
337    if dst_k1.len() != close.len() {
338        return Err(AlphaTrendError::OutputLengthMismatch {
339            expected: close.len(),
340            got: dst_k1.len(),
341        });
342    }
343    if dst_k2.len() != close.len() {
344        return Err(AlphaTrendError::OutputLengthMismatch {
345            expected: close.len(),
346            got: dst_k2.len(),
347        });
348    }
349
350    let warm = first + period - 1;
351    let k1_warm_end = warm.min(dst_k1.len());
352    let k2_warm_end = (warm + 2).min(dst_k2.len());
353    for v in &mut dst_k1[..k1_warm_end] {
354        *v = f64::NAN;
355    }
356    for v in &mut dst_k2[..k2_warm_end] {
357        *v = f64::NAN;
358    }
359
360    alphatrend_compute_into(
361        open, high, low, close, volume, coeff, period, no_volume, first, chosen, dst_k1, dst_k2,
362    )?;
363
364    Ok(())
365}
366
367#[inline(always)]
368fn alphatrend_prepare<'a>(
369    input: &'a AlphaTrendInput,
370    kernel: Kernel,
371) -> Result<
372    (
373        &'a [f64],
374        &'a [f64],
375        &'a [f64],
376        &'a [f64],
377        &'a [f64],
378        f64,
379        usize,
380        bool,
381        usize,
382        Kernel,
383    ),
384    AlphaTrendError,
385> {
386    let (open, high, low, close, volume) = match &input.data {
387        AlphaTrendData::Candles { candles } => (
388            &candles.open[..],
389            &candles.high[..],
390            &candles.low[..],
391            &candles.close[..],
392            &candles.volume[..],
393        ),
394        AlphaTrendData::Slices {
395            open,
396            high,
397            low,
398            close,
399            volume,
400        } => (*open, *high, *low, *close, *volume),
401    };
402
403    let len = close.len();
404
405    if len == 0 {
406        return Err(AlphaTrendError::EmptyInputData);
407    }
408
409    if open.len() != len || high.len() != len || low.len() != len || volume.len() != len {
410        return Err(AlphaTrendError::InconsistentDataLengths);
411    }
412
413    let first = close
414        .iter()
415        .position(|x| !x.is_nan())
416        .ok_or(AlphaTrendError::AllValuesNaN)?;
417
418    let coeff = input.get_coeff();
419    let period = input.get_period();
420    let no_volume = input.get_no_volume();
421
422    if period == 0 || period > len {
423        return Err(AlphaTrendError::InvalidPeriod {
424            period,
425            data_len: len,
426        });
427    }
428
429    if len - first < period {
430        return Err(AlphaTrendError::NotEnoughValidData {
431            needed: period,
432            valid: len - first,
433        });
434    }
435
436    if coeff <= 0.0 || !coeff.is_finite() {
437        return Err(AlphaTrendError::InvalidCoeff { coeff });
438    }
439
440    let chosen = match kernel {
441        Kernel::Auto => Kernel::Scalar,
442        k => k,
443    };
444
445    Ok((
446        open, high, low, close, volume, coeff, period, no_volume, first, chosen,
447    ))
448}
449
/// Dispatches to the AlphaTrend kernel selected by `kernel`.
///
/// Expects arguments already validated by `alphatrend_prepare`. In particular, `kernel`
/// must not be `Kernel::Auto`: that is mapped to `Kernel::Scalar` during preparation,
/// and any variant not listed below reaches `unreachable!()` and panics.
#[inline(always)]
fn alphatrend_compute_into(
    open: &[f64],
    high: &[f64],
    low: &[f64],
    close: &[f64],
    volume: &[f64],
    coeff: f64,
    period: usize,
    no_volume: bool,
    first: usize,
    kernel: Kernel,
    out_k1: &mut [f64],
    out_k2: &mut [f64],
) -> Result<(), AlphaTrendError> {
    // NOTE(review): the AVX kernels are `unsafe fn`s that assume the host CPU supports the
    // requested instruction set; nothing here checks that at runtime — confirm callers only
    // request kernels the host supports.
    unsafe {
        // On wasm32 with simd128 enabled, scalar requests go through the simd128 entry point.
        #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
        {
            if matches!(kernel, Kernel::Scalar | Kernel::ScalarBatch) {
                return alphatrend_simd128(
                    open, high, low, close, volume, coeff, period, no_volume, first, out_k1, out_k2,
                );
            }
        }

        match kernel {
            Kernel::Scalar | Kernel::ScalarBatch => alphatrend_scalar(
                open, high, low, close, volume, coeff, period, no_volume, first, out_k1, out_k2,
            ),
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx2 | Kernel::Avx2Batch => alphatrend_avx2(
                open, high, low, close, volume, coeff, period, no_volume, first, out_k1, out_k2,
            ),
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx512 | Kernel::Avx512Batch => alphatrend_avx512(
                open, high, low, close, volume, coeff, period, no_volume, first, out_k1, out_k2,
            ),
            // Without `nightly-avx` on x86_64, SIMD kernel requests fall back to scalar.
            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
                alphatrend_scalar(
                    open, high, low, close, volume, coeff, period, no_volume, first, out_k1, out_k2,
                )
            }
            // Remaining variants (e.g. `Kernel::Auto`) must be resolved by the caller.
            _ => unreachable!(),
        }
    }
}
497
498#[inline]
499pub fn alphatrend_scalar(
500    _open: &[f64],
501    high: &[f64],
502    low: &[f64],
503    close: &[f64],
504    volume: &[f64],
505    coeff: f64,
506    period: usize,
507    no_volume: bool,
508    first_val: usize,
509    out_k1: &mut [f64],
510    out_k2: &mut [f64],
511) -> Result<(), AlphaTrendError> {
512    let len = close.len();
513    let warmup = first_val + period - 1;
514
515    let mut tr_mu = make_uninit_matrix(1, len);
516    let tr: &mut [f64] =
517        unsafe { core::slice::from_raw_parts_mut(tr_mu.as_mut_ptr() as *mut f64, len) };
518
519    if first_val < len {
520        tr[first_val] = high[first_val] - low[first_val];
521    }
522    for i in (first_val + 1)..len {
523        let hl = high[i] - low[i];
524        let hc = (high[i] - close[i - 1]).abs();
525        let lc = (low[i] - close[i - 1]).abs();
526        tr[i] = hl.max(hc).max(lc);
527    }
528
529    let momentum_values: Vec<f64> = if no_volume {
530        let rsi_params = RsiParams {
531            period: Some(period),
532        };
533        let rsi_input = RsiInput::from_slice(close, rsi_params);
534        rsi_with_kernel(&rsi_input, Kernel::Scalar)
535            .map_err(|e| AlphaTrendError::RsiError { msg: e.to_string() })?
536            .values
537    } else {
538        let mut hlc3_mu = make_uninit_matrix(1, len);
539        let hlc3: &mut [f64] =
540            unsafe { core::slice::from_raw_parts_mut(hlc3_mu.as_mut_ptr() as *mut f64, len) };
541        for i in 0..len {
542            hlc3[i] = (high[i] + low[i] + close[i]) / 3.0;
543        }
544        let mfi_params = MfiParams {
545            period: Some(period),
546        };
547        let mfi_input = MfiInput::from_slices(hlc3, volume, mfi_params);
548        mfi_with_kernel(&mfi_input, Kernel::Scalar)
549            .map_err(|e| AlphaTrendError::MfiError { msg: e.to_string() })?
550            .values
551    };
552
553    if warmup < len {
554        let mut sum = 0.0f64;
555        for j in first_val..=warmup {
556            sum += tr[j];
557        }
558
559        let mut prev_alpha = f64::NAN;
560        let mut prev1 = f64::NAN;
561        let mut prev2 = f64::NAN;
562
563        for i in warmup..len {
564            let a = sum / period as f64;
565
566            let up_t = low[i] - a * coeff;
567            let down_t = high[i] + a * coeff;
568            let m_check = momentum_values[i] >= 50.0;
569
570            let cur = if i == warmup {
571                if m_check {
572                    up_t
573                } else {
574                    down_t
575                }
576            } else if m_check {
577                if up_t < prev_alpha {
578                    prev_alpha
579                } else {
580                    up_t
581                }
582            } else {
583                if down_t > prev_alpha {
584                    prev_alpha
585                } else {
586                    down_t
587                }
588            };
589
590            out_k1[i] = cur;
591            if i >= warmup + 2 {
592                out_k2[i] = prev2;
593            }
594
595            prev2 = prev1;
596            prev1 = cur;
597            prev_alpha = cur;
598
599            if i + 1 < len {
600                sum += tr[i + 1] - tr[i + 1 - period];
601            }
602        }
603    }
604
605    for v in &mut out_k1[..warmup.min(len)] {
606        *v = f64::NAN;
607    }
608    for v in &mut out_k2[..(warmup + 2).min(len)] {
609        *v = f64::NAN;
610    }
611
612    Ok(())
613}
614
/// AVX2 AlphaTrend kernel.
///
/// Runs the same recurrence as `alphatrend_scalar`, with the true-range and `hlc3` passes
/// vectorised four lanes at a time.
///
/// Results may differ from the scalar kernel:
/// * in the last bits, because `hlc3` multiplies by `1.0 / 3.0` rather than dividing and
///   the bands use fused multiply-add;
/// * when the previous line value is NaN: the ratchet here uses `>=`/`<=` comparisons,
///   so a NaN previous value propagates.
///
/// If `first_val + period - 1 >= close.len()`, both outputs become all-NaN. This matches
/// the scalar kernel.
///
/// # Safety
/// * The CPU must support AVX2 and FMA.
/// * `high`, `low`, `out_k1` and `out_k2` are accessed without bounds checks up to
///   `close.len()`, so each must be at least that long.
/// * `period >= 1`.
/// * The RSI/MFI momentum series is indexed unchecked up to `close.len()`; this assumes
///   those indicators return one value per input — NOTE(review): confirm.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
unsafe fn alphatrend_avx2(
    _open: &[f64],
    high: &[f64],
    low: &[f64],
    close: &[f64],
    volume: &[f64],
    coeff: f64,
    period: usize,
    no_volume: bool,
    first_val: usize,
    out_k1: &mut [f64],
    out_k2: &mut [f64],
) -> Result<(), AlphaTrendError> {
    use core::arch::x86_64::*;

    let len = close.len();
    let warmup = first_val + period - 1;
    let p_f = period as f64;

    // Initialised buffer: forming `&mut [f64]` over uninitialised memory is undefined
    // behaviour even though the slots before `first_val` are never read.
    let mut tr = vec![0.0f64; len];

    if first_val < len {
        *tr.get_unchecked_mut(first_val) =
            *high.get_unchecked(first_val) - *low.get_unchecked(first_val);
    }

    // |x| is computed by clearing the sign bit. The mask is hoisted out of the loop, and
    // the intrinsics are called directly from this avx2-enabled function; a nested fn
    // item would not inherit `#[target_feature]`.
    let sign_mask = _mm256_set1_pd(-0.0);
    let mut i = first_val + 1;
    while i + 4 <= len {
        let hv = _mm256_loadu_pd(high.as_ptr().add(i));
        let lv = _mm256_loadu_pd(low.as_ptr().add(i));
        let pc = _mm256_loadu_pd(close.as_ptr().add(i - 1));

        let hl = _mm256_sub_pd(hv, lv);
        let hc = _mm256_andnot_pd(sign_mask, _mm256_sub_pd(hv, pc));
        let lc = _mm256_andnot_pd(sign_mask, _mm256_sub_pd(lv, pc));

        let m1 = _mm256_max_pd(hl, hc);
        let m = _mm256_max_pd(m1, lc);
        _mm256_storeu_pd(tr.as_mut_ptr().add(i), m);
        i += 4;
    }
    while i < len {
        let hi = *high.get_unchecked(i);
        let lo = *low.get_unchecked(i);
        let pc = *close.get_unchecked(i - 1);
        let hl = hi - lo;
        let hc = (hi - pc).abs();
        let lc = (lo - pc).abs();
        let m = if hl >= hc { hl } else { hc };
        *tr.get_unchecked_mut(i) = if m >= lc { m } else { lc };
        i += 1;
    }

    let momentum_values: Vec<f64> = if no_volume {
        let rsi_params = RsiParams {
            period: Some(period),
        };
        let rsi_input = RsiInput::from_slice(close, rsi_params);
        rsi_with_kernel(&rsi_input, Kernel::Avx2)
            .map_err(|e| AlphaTrendError::RsiError { msg: e.to_string() })?
            .values
    } else {
        let mut hlc3 = vec![0.0f64; len];

        let inv3 = _mm256_set1_pd(1.0 / 3.0);
        let mut j = 0usize;
        while j + 4 <= len {
            let hv = _mm256_loadu_pd(high.as_ptr().add(j));
            let lv = _mm256_loadu_pd(low.as_ptr().add(j));
            let cv = _mm256_loadu_pd(close.as_ptr().add(j));
            let s = _mm256_add_pd(_mm256_add_pd(hv, lv), cv);
            let h3 = _mm256_mul_pd(s, inv3);
            _mm256_storeu_pd(hlc3.as_mut_ptr().add(j), h3);
            j += 4;
        }
        while j < len {
            *hlc3.get_unchecked_mut(j) =
                (*high.get_unchecked(j) + *low.get_unchecked(j) + *close.get_unchecked(j))
                    * (1.0 / 3.0);
            j += 1;
        }

        let mfi_params = MfiParams {
            period: Some(period),
        };
        let mfi_input = MfiInput::from_slices(&hlc3[..], volume, mfi_params);
        mfi_with_kernel(&mfi_input, Kernel::Avx2)
            .map_err(|e| AlphaTrendError::MfiError { msg: e.to_string() })?
            .values
    };

    // Seed the rolling TR sum. Guarded like the scalar kernel, so the unchecked reads
    // never go past `len`.
    let mut sum = 0.0f64;
    if warmup < len {
        let mut j = first_val;
        while j <= warmup {
            sum += *tr.get_unchecked(j);
            j += 1;
        }
    }

    #[inline(always)]
    fn fast_max(a: f64, b: f64) -> f64 {
        if a >= b {
            a
        } else {
            b
        }
    }
    #[inline(always)]
    fn fast_min(a: f64, b: f64) -> f64 {
        if a <= b {
            a
        } else {
            b
        }
    }

    let mut prev2 = f64::NAN;
    let mut prev1 = f64::NAN;
    let mut prev_alpha = f64::NAN;

    // Does not execute when `warmup >= len`.
    let mut k = warmup;
    while k < len {
        let a = sum / p_f;
        let hi = *high.get_unchecked(k);
        let lo = *low.get_unchecked(k);
        let up = (-coeff).mul_add(a, lo);
        let dn = coeff.mul_add(a, hi);
        let m_ge_50 = *momentum_values.get_unchecked(k) >= 50.0;

        let alpha = if k == warmup {
            if m_ge_50 {
                up
            } else {
                dn
            }
        } else if m_ge_50 {
            fast_max(up, prev_alpha)
        } else {
            fast_min(dn, prev_alpha)
        };

        *out_k1.get_unchecked_mut(k) = alpha;
        if k >= warmup + 2 {
            *out_k2.get_unchecked_mut(k) = prev2;
        }

        prev2 = prev1;
        prev1 = alpha;
        prev_alpha = alpha;

        let nxt = k + 1;
        if nxt < len {
            sum += *tr.get_unchecked(nxt) - *tr.get_unchecked(nxt - period);
        }
        k += 1;
    }

    for v in &mut out_k1[..warmup.min(len)] {
        *v = f64::NAN;
    }
    for v in &mut out_k2[..(warmup + 2).min(len)] {
        *v = f64::NAN;
    }

    Ok(())
}
793
/// AVX-512 AlphaTrend kernel.
///
/// Runs the same recurrence as `alphatrend_scalar`, with the true-range and `hlc3` passes
/// vectorised eight lanes at a time.
///
/// Numeric differences from the scalar kernel are the same as for the AVX2 kernel:
/// multiply by `1.0 / 3.0`, fused multiply-add bands, and NaN-propagating `>=`/`<=`
/// ratchet comparisons.
///
/// Only AVX512F intrinsics are used. `|x|` comes from `_mm512_abs_pd` (AVX512F) rather
/// than `_mm512_andnot_pd`, which is an AVX512DQ instruction not covered by this
/// function's `target_feature` list.
///
/// If `first_val + period - 1 >= close.len()`, both outputs become all-NaN. This matches
/// the scalar kernel.
///
/// # Safety
/// * The CPU must support AVX512F and FMA.
/// * `high`, `low`, `out_k1` and `out_k2` are accessed without bounds checks up to
///   `close.len()`, so each must be at least that long.
/// * `period >= 1`.
/// * The RSI/MFI momentum series is indexed unchecked up to `close.len()`; this assumes
///   those indicators return one value per input — NOTE(review): confirm.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
unsafe fn alphatrend_avx512(
    _open: &[f64],
    high: &[f64],
    low: &[f64],
    close: &[f64],
    volume: &[f64],
    coeff: f64,
    period: usize,
    no_volume: bool,
    first_val: usize,
    out_k1: &mut [f64],
    out_k2: &mut [f64],
) -> Result<(), AlphaTrendError> {
    use core::arch::x86_64::*;

    let len = close.len();
    let warmup = first_val + period - 1;
    let p_f = period as f64;

    // Initialised buffer: forming `&mut [f64]` over uninitialised memory is undefined
    // behaviour even though the slots before `first_val` are never read.
    let mut tr = vec![0.0f64; len];

    if first_val < len {
        *tr.get_unchecked_mut(first_val) =
            *high.get_unchecked(first_val) - *low.get_unchecked(first_val);
    }

    let mut i = first_val + 1;
    while i + 8 <= len {
        let hv = _mm512_loadu_pd(high.as_ptr().add(i));
        let lv = _mm512_loadu_pd(low.as_ptr().add(i));
        let pc = _mm512_loadu_pd(close.as_ptr().add(i - 1));

        let hl = _mm512_sub_pd(hv, lv);
        let hc = _mm512_abs_pd(_mm512_sub_pd(hv, pc));
        let lc = _mm512_abs_pd(_mm512_sub_pd(lv, pc));

        let m1 = _mm512_max_pd(hl, hc);
        let m = _mm512_max_pd(m1, lc);
        _mm512_storeu_pd(tr.as_mut_ptr().add(i), m);
        i += 8;
    }
    while i < len {
        let hi = *high.get_unchecked(i);
        let lo = *low.get_unchecked(i);
        let pc = *close.get_unchecked(i - 1);
        let hl = hi - lo;
        let hc = (hi - pc).abs();
        let lc = (lo - pc).abs();
        let m = if hl >= hc { hl } else { hc };
        *tr.get_unchecked_mut(i) = if m >= lc { m } else { lc };
        i += 1;
    }

    let momentum_values: Vec<f64> = if no_volume {
        let rsi_params = RsiParams {
            period: Some(period),
        };
        let rsi_input = RsiInput::from_slice(close, rsi_params);
        rsi_with_kernel(&rsi_input, Kernel::Avx512)
            .map_err(|e| AlphaTrendError::RsiError { msg: e.to_string() })?
            .values
    } else {
        let mut hlc3 = vec![0.0f64; len];

        let inv3 = _mm512_set1_pd(1.0 / 3.0);
        let mut j = 0usize;
        while j + 8 <= len {
            let hv = _mm512_loadu_pd(high.as_ptr().add(j));
            let lv = _mm512_loadu_pd(low.as_ptr().add(j));
            let cv = _mm512_loadu_pd(close.as_ptr().add(j));
            let s = _mm512_add_pd(_mm512_add_pd(hv, lv), cv);
            let h3 = _mm512_mul_pd(s, inv3);
            _mm512_storeu_pd(hlc3.as_mut_ptr().add(j), h3);
            j += 8;
        }
        while j < len {
            *hlc3.get_unchecked_mut(j) =
                (*high.get_unchecked(j) + *low.get_unchecked(j) + *close.get_unchecked(j))
                    * (1.0 / 3.0);
            j += 1;
        }

        let mfi_params = MfiParams {
            period: Some(period),
        };
        let mfi_input = MfiInput::from_slices(&hlc3[..], volume, mfi_params);
        mfi_with_kernel(&mfi_input, Kernel::Avx512)
            .map_err(|e| AlphaTrendError::MfiError { msg: e.to_string() })?
            .values
    };

    // Seed the rolling TR sum. Guarded like the scalar kernel, so the unchecked reads
    // never go past `len`.
    let mut sum = 0.0f64;
    if warmup < len {
        let mut j = first_val;
        while j <= warmup {
            sum += *tr.get_unchecked(j);
            j += 1;
        }
    }

    #[inline(always)]
    fn fast_max(a: f64, b: f64) -> f64 {
        if a >= b {
            a
        } else {
            b
        }
    }
    #[inline(always)]
    fn fast_min(a: f64, b: f64) -> f64 {
        if a <= b {
            a
        } else {
            b
        }
    }

    let mut prev2 = f64::NAN;
    let mut prev1 = f64::NAN;
    let mut prev_alpha = f64::NAN;

    // Does not execute when `warmup >= len`.
    let mut k = warmup;
    while k < len {
        let a = sum / p_f;
        let hi = *high.get_unchecked(k);
        let lo = *low.get_unchecked(k);
        let up = (-coeff).mul_add(a, lo);
        let dn = coeff.mul_add(a, hi);
        let m_ge_50 = *momentum_values.get_unchecked(k) >= 50.0;

        let alpha = if k == warmup {
            if m_ge_50 {
                up
            } else {
                dn
            }
        } else if m_ge_50 {
            fast_max(up, prev_alpha)
        } else {
            fast_min(dn, prev_alpha)
        };

        *out_k1.get_unchecked_mut(k) = alpha;
        if k >= warmup + 2 {
            *out_k2.get_unchecked_mut(k) = prev2;
        }

        prev2 = prev1;
        prev1 = alpha;
        prev_alpha = alpha;

        let nxt = k + 1;
        if nxt < len {
            sum += *tr.get_unchecked(nxt) - *tr.get_unchecked(nxt - period);
        }
        k += 1;
    }

    for v in &mut out_k1[..warmup.min(len)] {
        *v = f64::NAN;
    }
    for v in &mut out_k2[..(warmup + 2).min(len)] {
        *v = f64::NAN;
    }

    Ok(())
}
972
/// wasm32 `simd128` entry point for AlphaTrend.
///
/// There is no dedicated simd128 implementation: the band recurrence is inherently
/// sequential, so this forwards to `alphatrend_scalar`.
///
/// # Safety
/// No requirements beyond those of `alphatrend_scalar`. The `unsafe` qualifier
/// presumably exists for signature parity with the other SIMD kernels — confirm.
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
#[inline]
unsafe fn alphatrend_simd128(
    open: &[f64],
    high: &[f64],
    low: &[f64],
    close: &[f64],
    volume: &[f64],
    coeff: f64,
    period: usize,
    no_volume: bool,
    first_val: usize,
    out_k1: &mut [f64],
    out_k2: &mut [f64],
) -> Result<(), AlphaTrendError> {
    alphatrend_scalar(
        open, high, low, close, volume, coeff, period, no_volume, first_val, out_k1, out_k2,
    )
}
994
/// Incremental (bar-by-bar) AlphaTrend state.
///
/// Created with `AlphaTrendStream::try_new` (or `AlphaTrendBuilder::into_stream`) and
/// fed one bar at a time through `update(high, low, close, volume)`.
#[derive(Debug, Clone)]
pub struct AlphaTrendStream {
    /// Multiplier on the average true range for the bands (validated finite and `> 0`).
    coeff: f64,
    /// Window length; the TR and money-flow rings are allocated with this many slots.
    period: usize,
    /// Cached `1.0 / period`.
    inv_period: f64,
    /// RSI-vs-MFI momentum switch — NOTE(review): presumably mirrors the batch kernels
    /// (RSI of close when `true`); confirm in `update`.
    no_volume: bool,

    /// Ring buffer of the most recent true-range values (`period` slots, initially 0.0).
    tr_ring: Vec<f64>,
    /// Running sum of the values currently in `tr_ring`.
    tr_sum: f64,
    /// Next slot of `tr_ring` to overwrite.
    tr_idx: usize,
    /// Number of `tr_ring` slots filled so far; stops growing at `period`.
    tr_filled: usize,

    /// RSI accumulator state: seed flag, seed sums, sample count, running averages.
    /// NOTE(review): presumably used only when `no_volume` is set; confirm in `update`.
    rsi_seeded: bool,
    rsi_init_gains: f64,
    rsi_init_losses: f64,
    rsi_count: usize,
    rsi_avg_gain: f64,
    rsi_avg_loss: f64,

    /// Rolling positive/negative windows (`period` slots each), their running sums, the
    /// ring index, and the fill count — presumably MFI money-flow state; confirm in
    /// `update`.
    mfi_pos_ring: Vec<f64>,
    mfi_neg_ring: Vec<f64>,
    mfi_pos_sum: f64,
    mfi_neg_sum: f64,
    mfi_idx: usize,
    mfi_filled: usize,
    /// Previous typical price (starts as NaN) — presumably for the MFI flow direction.
    prev_tp: f64,

    /// Previous bar's close, used for the true range once `have_prev` is set.
    prev_close: f64,
    /// Whether a previous bar exists. Until it does, the true range is `high - low`.
    have_prev: bool,

    /// AlphaTrend recurrence state: last line value, the two most recent outputs, and a
    /// count of produced values. NOTE(review): presumably mirrors `prev_alpha`/`prev1`/
    /// `prev2` in the batch kernels (for the 2-bar-lagged `k2`); confirm in `update`.
    prev_alpha: f64,
    prev1: f64,
    prev2: f64,
    alpha_count: usize,
}
1030
1031impl AlphaTrendStream {
1032    pub fn try_new(params: AlphaTrendParams) -> Result<Self, AlphaTrendError> {
1033        let coeff = params.coeff.unwrap_or(1.0);
1034        let period = params.period.unwrap_or(14);
1035        let no_volume = params.no_volume.unwrap_or(false);
1036
1037        if period == 0 {
1038            return Err(AlphaTrendError::InvalidPeriod {
1039                period,
1040                data_len: 0,
1041            });
1042        }
1043        if coeff <= 0.0 || !coeff.is_finite() {
1044            return Err(AlphaTrendError::InvalidCoeff { coeff });
1045        }
1046
1047        Ok(Self {
1048            coeff,
1049            period,
1050            inv_period: 1.0 / (period as f64),
1051            no_volume,
1052
1053            tr_ring: vec![0.0; period],
1054            tr_sum: 0.0,
1055            tr_idx: 0,
1056            tr_filled: 0,
1057
1058            rsi_seeded: false,
1059            rsi_init_gains: 0.0,
1060            rsi_init_losses: 0.0,
1061            rsi_count: 0,
1062            rsi_avg_gain: 0.0,
1063            rsi_avg_loss: 0.0,
1064
1065            mfi_pos_ring: vec![0.0; period],
1066            mfi_neg_ring: vec![0.0; period],
1067            mfi_pos_sum: 0.0,
1068            mfi_neg_sum: 0.0,
1069            mfi_idx: 0,
1070            mfi_filled: 0,
1071            prev_tp: f64::NAN,
1072
1073            prev_close: f64::NAN,
1074            have_prev: false,
1075
1076            prev_alpha: f64::NAN,
1077            prev1: f64::NAN,
1078            prev2: f64::NAN,
1079            alpha_count: 0,
1080        })
1081    }
1082
1083    #[inline]
1084    pub fn update(&mut self, high: f64, low: f64, close: f64, volume: f64) -> Option<(f64, f64)> {
1085        if !(high.is_finite() && low.is_finite() && close.is_finite() && volume.is_finite()) {
1086            return None;
1087        }
1088        if high < low {
1089            return None;
1090        }
1091
1092        let tr = if self.have_prev {
1093            let hl = high - low;
1094            let hc = (high - self.prev_close).abs();
1095            let lc = (low - self.prev_close).abs();
1096            if hl >= hc {
1097                if hl >= lc {
1098                    hl
1099                } else {
1100                    lc
1101                }
1102            } else {
1103                if hc >= lc {
1104                    hc
1105                } else {
1106                    lc
1107                }
1108            }
1109        } else {
1110            high - low
1111        };
1112
1113        if self.tr_filled < self.period {
1114            self.tr_ring[self.tr_idx] = tr;
1115            self.tr_sum += tr;
1116            self.tr_filled += 1;
1117            self.tr_idx = (self.tr_idx + 1) % self.period;
1118        } else {
1119            let old = self.tr_ring[self.tr_idx];
1120            self.tr_ring[self.tr_idx] = tr;
1121            self.tr_sum += tr - old;
1122            self.tr_idx = (self.tr_idx + 1) % self.period;
1123        }
1124        let atr_ready = self.tr_filled == self.period;
1125        let atr = if atr_ready {
1126            self.tr_sum * self.inv_period
1127        } else {
1128            f64::NAN
1129        };
1130
1131        let mut m_ge_50 = false;
1132
1133        if self.no_volume {
1134            let (gain, loss) = if self.have_prev {
1135                let d = close - self.prev_close;
1136                if d >= 0.0 {
1137                    (d, 0.0)
1138                } else {
1139                    (0.0, -d)
1140                }
1141            } else {
1142                (0.0, 0.0)
1143            };
1144
1145            if !self.rsi_seeded {
1146                self.rsi_init_gains += gain;
1147                self.rsi_init_losses += loss;
1148                self.rsi_count += 1;
1149                if self.rsi_count >= self.period {
1150                    self.rsi_avg_gain = self.rsi_init_gains * self.inv_period;
1151                    self.rsi_avg_loss = self.rsi_init_losses * self.inv_period;
1152                    self.rsi_seeded = true;
1153                }
1154            } else {
1155                let n1 = (self.period as f64) - 1.0;
1156                self.rsi_avg_gain = (self.rsi_avg_gain * n1 + gain) * self.inv_period;
1157                self.rsi_avg_loss = (self.rsi_avg_loss * n1 + loss) * self.inv_period;
1158            }
1159
1160            if self.rsi_seeded {
1161                if self.rsi_avg_loss == 0.0 {
1162                    m_ge_50 = self.rsi_avg_gain >= 0.0;
1163                } else if self.rsi_avg_gain == 0.0 {
1164                    m_ge_50 = false;
1165                } else {
1166                    m_ge_50 = self.rsi_avg_gain >= self.rsi_avg_loss;
1167                }
1168            } else {
1169                m_ge_50 = false;
1170            }
1171        } else {
1172            let tp = (high + low + close) / 3.0;
1173            if self.have_prev {
1174                let mf = (tp * volume).max(0.0);
1175                let (pos, neg) = if tp > self.prev_tp {
1176                    (mf, 0.0)
1177                } else if tp < self.prev_tp {
1178                    (0.0, mf)
1179                } else {
1180                    (0.0, 0.0)
1181                };
1182
1183                if self.mfi_filled < self.period {
1184                    self.mfi_pos_sum += pos;
1185                    self.mfi_neg_sum += neg;
1186                    self.mfi_pos_ring[self.mfi_idx] = pos;
1187                    self.mfi_neg_ring[self.mfi_idx] = neg;
1188                    self.mfi_idx = (self.mfi_idx + 1) % self.period;
1189                    self.mfi_filled += 1;
1190                } else {
1191                    let op = self.mfi_pos_ring[self.mfi_idx];
1192                    let on = self.mfi_neg_ring[self.mfi_idx];
1193                    self.mfi_pos_ring[self.mfi_idx] = pos;
1194                    self.mfi_neg_ring[self.mfi_idx] = neg;
1195                    self.mfi_pos_sum += pos - op;
1196                    self.mfi_neg_sum += neg - on;
1197                    self.mfi_idx = (self.mfi_idx + 1) % self.period;
1198                }
1199            }
1200
1201            if self.mfi_filled == self.period {
1202                if self.mfi_neg_sum == 0.0 {
1203                    m_ge_50 = self.mfi_pos_sum >= 0.0;
1204                } else if self.mfi_pos_sum == 0.0 {
1205                    m_ge_50 = false;
1206                } else {
1207                    m_ge_50 = self.mfi_pos_sum >= self.mfi_neg_sum;
1208                }
1209            } else {
1210                m_ge_50 = false;
1211            }
1212            self.prev_tp = tp;
1213        }
1214
1215        let mut emitted = false;
1216        let mut cur = f64::NAN;
1217        let mut k2_out = f64::NAN;
1218
1219        if atr_ready {
1220            let up = (-self.coeff).mul_add(atr, low);
1221            let dn = self.coeff.mul_add(atr, high);
1222
1223            cur = if self.alpha_count == 0 {
1224                if m_ge_50 {
1225                    up
1226                } else {
1227                    dn
1228                }
1229            } else if m_ge_50 {
1230                if up < self.prev_alpha {
1231                    self.prev_alpha
1232                } else {
1233                    up
1234                }
1235            } else {
1236                if dn > self.prev_alpha {
1237                    self.prev_alpha
1238                } else {
1239                    dn
1240                }
1241            };
1242
1243            let k2_emit = self.prev2;
1244            self.prev2 = self.prev1;
1245            self.prev1 = cur;
1246            self.prev_alpha = cur;
1247            self.alpha_count += 1;
1248            emitted = true;
1249            k2_out = k2_emit;
1250        }
1251
1252        self.prev_close = close;
1253        self.have_prev = true;
1254
1255        if emitted && self.alpha_count >= 3 {
1256            Some((cur, k2_out))
1257        } else {
1258            None
1259        }
1260    }
1261
1262    #[inline(always)]
1263    pub fn get_warmup_period(&self) -> usize {
1264        self.period - 1
1265    }
1266}
1267
/// Python binding for AlphaTrend.
///
/// Returns the `(k1, k2)` lines as NumPy arrays. The computation runs inside
/// `allow_threads`, and any AlphaTrend error is raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "alphatrend")]
#[pyo3(signature = (open, high, low, close, volume, coeff=1.0, period=14, no_volume=false, kernel=None))]
pub fn alphatrend_py<'py>(
    py: Python<'py>,
    open: PyReadonlyArray1<'py, f64>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    volume: PyReadonlyArray1<'py, f64>,
    coeff: f64,
    period: usize,
    no_volume: bool,
    kernel: Option<&str>,
) -> PyResult<(Bound<'py, PyArray1<f64>>, Bound<'py, PyArray1<f64>>)> {
    // Borrow every array as a slice, in argument order.
    let (o, h, l, c, v) = (
        open.as_slice()?,
        high.as_slice()?,
        low.as_slice()?,
        close.as_slice()?,
        volume.as_slice()?,
    );

    let kern = validate_kernel(kernel, false)?;
    let input = AlphaTrendInput::from_slices(
        o,
        h,
        l,
        c,
        v,
        AlphaTrendParams {
            coeff: Some(coeff),
            period: Some(period),
            no_volume: Some(no_volume),
        },
    );

    let computed = py.allow_threads(|| alphatrend_with_kernel(&input, kern));
    let output = computed.map_err(|e| PyValueError::new_err(e.to_string()))?;

    let k1 = output.k1.into_pyarray(py);
    let k2 = output.k2.into_pyarray(py);
    Ok((k1, k2))
}
1310
/// Python-facing wrapper around the Rust `AlphaTrendStream`.
///
/// It is exposed to Python under the name `AlphaTrendStream`.
#[cfg(feature = "python")]
#[pyclass(name = "AlphaTrendStream")]
pub struct AlphaTrendStreamPy {
    /// Wrapped streaming state.
    stream: AlphaTrendStream,
}
1316
#[cfg(feature = "python")]
#[pymethods]
impl AlphaTrendStreamPy {
    /// Creates a new stream.
    ///
    /// The defaults match the `alphatrend` function: `coeff=1.0`, `period=14`,
    /// `no_volume=False`. Raises `ValueError` when `period` is 0 or `coeff` is not a
    /// finite positive number.
    #[new]
    #[pyo3(signature = (coeff=1.0, period=14, no_volume=false))]
    fn new(coeff: f64, period: usize, no_volume: bool) -> PyResult<Self> {
        let params = AlphaTrendParams {
            coeff: Some(coeff),
            period: Some(period),
            no_volume: Some(no_volume),
        };
        let stream =
            AlphaTrendStream::try_new(params).map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok(AlphaTrendStreamPy { stream })
    }

    /// Feeds one bar. Returns `(k1, k2)` once both lines are available, otherwise `None`.
    ///
    /// Invalid bars also return `None` and do not advance the stream.
    fn update(&mut self, high: f64, low: f64, close: f64, volume: f64) -> Option<(f64, f64)> {
        self.stream.update(high, low, close, volume)
    }
}
1336
/// Serialized result returned by `alphatrend_js`.
///
/// `values` holds the `k1` row followed by the `k2` row, each `cols` entries long.
/// `rows` is therefore 2.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct AlphaTrendJsOutput {
    /// Flattened `[k1..., k2...]` data.
    pub values: Vec<f64>,
    /// Number of rows in `values`.
    pub rows: usize,
    /// Length of each row (the input length).
    pub cols: usize,
}
1344
/// JS binding for AlphaTrend.
///
/// Returns an object `{ values, rows: 2, cols }`, where `values` is the `k1` series
/// followed by the `k2` series.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn alphatrend_js(
    open: &[f64],
    high: &[f64],
    low: &[f64],
    close: &[f64],
    volume: &[f64],
    coeff: f64,
    period: usize,
    no_volume: bool,
) -> Result<JsValue, JsValue> {
    let cols = close.len();
    let input = AlphaTrendInput::from_slices(
        open,
        high,
        low,
        close,
        volume,
        AlphaTrendParams {
            coeff: Some(coeff),
            period: Some(period),
            no_volume: Some(no_volume),
        },
    );

    let (mut k1, mut k2) = (vec![0.0; cols], vec![0.0; cols]);
    if let Err(e) = alphatrend_into_slices(&mut k1, &mut k2, &input, Kernel::Auto) {
        return Err(JsValue::from_str(&e.to_string()));
    }

    let payload = AlphaTrendJsOutput {
        values: [k1, k2].concat(),
        rows: 2,
        cols,
    };
    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
1382
/// Allocates an uninitialised buffer with room for `2 * n` `f64`s.
///
/// The buffer is laid out as the `k1` half followed by the `k2` half, for use with
/// `alphatrend_into_flat`. Release it with `alphatrend_free_flat(ptr, n)` using the
/// same `n`.
///
/// Returns a null pointer when `2 * n` overflows `usize` or the byte size would exceed
/// `isize::MAX`. Previously `2 * n` could wrap in release builds and hand out an
/// undersized buffer.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn alphatrend_alloc_flat(n: usize) -> *mut f64 {
    let max_elems = isize::MAX as usize / core::mem::size_of::<f64>();
    let total = match n.checked_mul(2) {
        Some(t) if t <= max_elems => t,
        _ => return core::ptr::null_mut(),
    };
    // `Vec::with_capacity` guarantees exactly `total` capacity, which
    // `alphatrend_free_flat` relies on when rebuilding the Vec.
    let mut v = Vec::<f64>::with_capacity(total);
    let p = v.as_mut_ptr();
    core::mem::forget(v);
    p
}
1391
/// Releases a buffer obtained from `alphatrend_alloc_flat(n)`.
///
/// A null `ptr` is ignored, as is an `n` whose `2 * n` overflows. Neither can come from
/// a successful allocation.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn alphatrend_free_flat(ptr: *mut f64, n: usize) {
    if ptr.is_null() {
        return;
    }
    let total = match n.checked_mul(2) {
        Some(t) => t,
        None => return,
    };
    // SAFETY: `ptr` was produced by `alphatrend_alloc_flat(n)` via
    // `Vec::<f64>::with_capacity(2 * n)`, so the capacity matches. The length is 0
    // because the contents may never have been initialised, and `f64` needs no drop.
    unsafe {
        drop(Vec::<f64>::from_raw_parts(ptr, 0, total));
    }
}
1399
/// Zero-copy AlphaTrend for wasm callers that work on linear-memory pointers.
///
/// Reads `len` values from each input pointer and writes `2 * len` values to
/// `out_flat_ptr`: `k1` occupies `[0, len)` and `k2` occupies `[len, 2 * len)`.
///
/// The output may overlap the inputs, for example when a buffer is reused. In that
/// case the result is computed into temporaries and then copied out, because aliasing
/// `&[f64]` inputs with a `&mut [f64]` output is undefined behaviour.
///
/// # Errors
/// Returns a JS error string for:
/// * null pointers;
/// * a `2 * len` overflow;
/// * any AlphaTrend validation failure.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn alphatrend_into_flat(
    open_ptr: *const f64,
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    volume_ptr: *const f64,
    out_flat_ptr: *mut f64,
    len: usize,
    coeff: f64,
    period: usize,
    no_volume: bool,
) -> Result<(), JsValue> {
    let input_ptrs = [open_ptr, high_ptr, low_ptr, close_ptr, volume_ptr];
    if input_ptrs.iter().any(|p| p.is_null()) || out_flat_ptr.is_null() {
        return Err(JsValue::from_str("null pointer"));
    }
    let out_len = len
        .checked_mul(2)
        .ok_or_else(|| JsValue::from_str("length overflow"))?;

    // Byte ranges used to detect whether any input overlaps the output buffer.
    let elem = core::mem::size_of::<f64>();
    let out_lo = out_flat_ptr as usize;
    let out_hi = out_lo.saturating_add(out_len.saturating_mul(elem));
    let in_bytes = len.saturating_mul(elem);
    let overlaps_output = input_ptrs.iter().any(|&p| {
        let lo = p as usize;
        let hi = lo.saturating_add(in_bytes);
        lo < out_hi && out_lo < hi
    });

    let params = AlphaTrendParams {
        coeff: Some(coeff),
        period: Some(period),
        no_volume: Some(no_volume),
    };

    if !overlaps_output {
        // SAFETY: the caller provides `len` readable f64s behind each input pointer and
        // `2 * len` writable f64s behind `out_flat_ptr`. The check above shows the output
        // does not overlap any input, and the two output halves are disjoint.
        return unsafe {
            let input = AlphaTrendInput::from_slices(
                core::slice::from_raw_parts(open_ptr, len),
                core::slice::from_raw_parts(high_ptr, len),
                core::slice::from_raw_parts(low_ptr, len),
                core::slice::from_raw_parts(close_ptr, len),
                core::slice::from_raw_parts(volume_ptr, len),
                params,
            );
            let k1 = core::slice::from_raw_parts_mut(out_flat_ptr, len);
            let k2 = core::slice::from_raw_parts_mut(out_flat_ptr.add(len), len);
            alphatrend_into_slices(k1, k2, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))
        };
    }

    let mut k1 = vec![0.0; len];
    let mut k2 = vec![0.0; len];
    {
        // SAFETY: read-only views over the caller's inputs (same contract as above).
        // They are confined to this block, so no shared view is alive when the output
        // buffer is written below.
        let input = unsafe {
            AlphaTrendInput::from_slices(
                core::slice::from_raw_parts(open_ptr, len),
                core::slice::from_raw_parts(high_ptr, len),
                core::slice::from_raw_parts(low_ptr, len),
                core::slice::from_raw_parts(close_ptr, len),
                core::slice::from_raw_parts(volume_ptr, len),
                params,
            )
        };
        alphatrend_into_slices(&mut k1, &mut k2, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    // SAFETY: `out_flat_ptr` addresses `2 * len` writable f64s (caller contract), and no
    // input view is alive any more.
    let out = unsafe { core::slice::from_raw_parts_mut(out_flat_ptr, out_len) };
    out[..len].copy_from_slice(&k1);
    out[len..].copy_from_slice(&k2);
    Ok(())
}
1443
/// Legacy allocator kept only for ABI compatibility.
///
/// It allocates nothing and always returns a null pointer.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
#[deprecated(note = "Use alphatrend_alloc_flat/alphatrend_into_flat")]
pub fn alphatrend_alloc(_len: usize) -> *mut f64 {
    std::ptr::null_mut::<f64>()
}
1450
/// Legacy deallocator kept only for ABI compatibility; it does nothing.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
#[deprecated(note = "Use alphatrend_free_flat")]
pub fn alphatrend_free(_ptr: *mut f64, _len: usize) {
    // Intentionally empty: `alphatrend_alloc` never hands out memory.
}
1455
/// Parameter sweep for batch AlphaTrend runs.
///
/// `coeff` and `period` are `(start, end, step)` triples that are expanded into a grid.
/// `no_volume` applies to every generated combination.
#[derive(Clone, Debug)]
pub struct AlphaTrendBatchRange {
    /// Band multiplier sweep `(start, end, step)`.
    pub coeff: (f64, f64, f64),
    /// Window length sweep `(start, end, step)`.
    pub period: (usize, usize, usize),
    /// `true` selects the RSI filter; `false` selects MFI.
    pub no_volume: bool,
}

impl Default for AlphaTrendBatchRange {
    /// Default sweep:
    /// * a single `coeff = 1.0`;
    /// * periods 14 through 263 in steps of 1;
    /// * MFI mode (`no_volume = false`).
    fn default() -> Self {
        AlphaTrendBatchRange {
            no_volume: false,
            period: (14, 263, 1),
            coeff: (1.0, 1.0, 0.0),
        }
    }
}
1472
1473#[derive(Clone, Debug, Default)]
1474pub struct AlphaTrendBatchBuilder {
1475    range: AlphaTrendBatchRange,
1476    kernel: Kernel,
1477}
1478
1479impl AlphaTrendBatchBuilder {
1480    pub fn new() -> Self {
1481        Self::default()
1482    }
1483
1484    pub fn kernel(mut self, k: Kernel) -> Self {
1485        self.kernel = k;
1486        self
1487    }
1488
1489    #[inline]
1490    pub fn coeff_range(mut self, start: f64, end: f64, step: f64) -> Self {
1491        self.range.coeff = (start, end, step);
1492        self
1493    }
1494
1495    #[inline]
1496    pub fn coeff_static(mut self, val: f64) -> Self {
1497        self.range.coeff = (val, val, 0.0);
1498        self
1499    }
1500
1501    #[inline]
1502    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
1503        self.range.period = (start, end, step);
1504        self
1505    }
1506
1507    #[inline]
1508    pub fn period_static(mut self, val: usize) -> Self {
1509        self.range.period = (val, val, 0);
1510        self
1511    }
1512
1513    #[inline]
1514    pub fn no_volume(mut self, val: bool) -> Self {
1515        self.range.no_volume = val;
1516        self
1517    }
1518
1519    pub fn apply_candles(self, c: &Candles) -> Result<AlphaTrendBatchOutput, AlphaTrendError> {
1520        alphatrend_batch_with_kernel(c, &self.range, self.kernel)
1521    }
1522
1523    pub fn apply_slices(
1524        self,
1525        open: &[f64],
1526        high: &[f64],
1527        low: &[f64],
1528        close: &[f64],
1529        volume: &[f64],
1530    ) -> Result<AlphaTrendBatchOutput, AlphaTrendError> {
1531        let len = close.len();
1532        if open.len() != len || high.len() != len || low.len() != len || volume.len() != len {
1533            return Err(AlphaTrendError::InconsistentDataLengths);
1534        }
1535
1536        let candles = Candles {
1537            timestamp: vec![0; len],
1538            open: open.to_vec(),
1539            high: high.to_vec(),
1540            low: low.to_vec(),
1541            close: close.to_vec(),
1542            volume: volume.to_vec(),
1543            fields: CandleFieldFlags {
1544                open: true,
1545                high: true,
1546                low: true,
1547                close: true,
1548                volume: true,
1549            },
1550            hl2: vec![],
1551            hlc3: vec![],
1552            ohlc4: vec![],
1553            hlcc4: vec![],
1554        };
1555
1556        alphatrend_batch_with_kernel(&candles, &self.range, self.kernel)
1557    }
1558
1559    pub fn with_default_candles(c: &Candles) -> Result<AlphaTrendBatchOutput, AlphaTrendError> {
1560        AlphaTrendBatchBuilder::new()
1561            .kernel(Kernel::Auto)
1562            .apply_candles(c)
1563    }
1564
1565    pub fn with_default_slices(
1566        open: &[f64],
1567        high: &[f64],
1568        low: &[f64],
1569        close: &[f64],
1570        volume: &[f64],
1571        k: Kernel,
1572    ) -> Result<AlphaTrendBatchOutput, AlphaTrendError> {
1573        AlphaTrendBatchBuilder::new()
1574            .kernel(k)
1575            .apply_slices(open, high, low, close, volume)
1576    }
1577}
1578
1579#[derive(Clone, Debug)]
1580pub struct AlphaTrendBatchOutput {
1581    pub values_k1: Vec<f64>,
1582    pub values_k2: Vec<f64>,
1583    pub combos: Vec<AlphaTrendParams>,
1584    pub rows: usize,
1585    pub cols: usize,
1586}
1587
1588impl AlphaTrendBatchOutput {
1589    pub fn row_for_params(&self, p: &AlphaTrendParams) -> Option<usize> {
1590        self.combos.iter().position(|c| {
1591            (c.coeff.unwrap_or(1.0) - p.coeff.unwrap_or(1.0)).abs() < 1e-12
1592                && c.period.unwrap_or(14) == p.period.unwrap_or(14)
1593                && c.no_volume.unwrap_or(false) == p.no_volume.unwrap_or(false)
1594        })
1595    }
1596
1597    pub fn values_for(&self, p: &AlphaTrendParams) -> Option<(&[f64], &[f64])> {
1598        self.row_for_params(p).map(|row| {
1599            let start = row * self.cols;
1600            let end = start + self.cols;
1601            (&self.values_k1[start..end], &self.values_k2[start..end])
1602        })
1603    }
1604}
1605
1606#[inline(always)]
1607fn expand_grid_alphatrend(
1608    r: &AlphaTrendBatchRange,
1609) -> Result<Vec<AlphaTrendParams>, AlphaTrendError> {
1610    fn axis_usize(
1611        (start, end, step): (usize, usize, usize),
1612    ) -> Result<Vec<usize>, AlphaTrendError> {
1613        if step == 0 || start == end {
1614            return Ok(vec![start]);
1615        }
1616        let mut v = Vec::new();
1617        if start < end {
1618            let mut cur = start;
1619            while cur <= end {
1620                v.push(cur);
1621                cur = cur.saturating_add(step);
1622                if cur == *v.last().unwrap() {
1623                    break;
1624                }
1625            }
1626        } else {
1627            let mut cur = start;
1628            while cur >= end {
1629                v.push(cur);
1630                let next = cur.saturating_sub(step);
1631                if next == cur {
1632                    break;
1633                }
1634                cur = next;
1635                if cur == 0 && end > 0 {
1636                    break;
1637                }
1638            }
1639        }
1640        if v.is_empty() {
1641            return Err(AlphaTrendError::InvalidRange { start, end, step });
1642        }
1643        Ok(v)
1644    }
1645    fn axis_f64((start, end, step): (f64, f64, f64)) -> Result<Vec<f64>, AlphaTrendError> {
1646        if step.abs() < 1e-12 || (start - end).abs() < 1e-12 {
1647            return Ok(vec![start]);
1648        }
1649        let mut out = Vec::new();
1650        if start < end {
1651            let st = if step > 0.0 { step } else { -step };
1652            let mut x = start;
1653            while x <= end + 1e-12 {
1654                out.push(x);
1655                x += st;
1656            }
1657        } else {
1658            let st = if step > 0.0 { -step } else { step };
1659            if st.abs() < 1e-12 {
1660                return Ok(vec![start]);
1661            }
1662            let mut x = start;
1663            while x >= end - 1e-12 {
1664                out.push(x);
1665                x += st;
1666            }
1667        }
1668        if out.is_empty() {
1669            return Err(AlphaTrendError::InvalidRangeF64 { start, end, step });
1670        }
1671        Ok(out)
1672    }
1673
1674    let coeffs = axis_f64(r.coeff)?;
1675    let periods = axis_usize(r.period)?;
1676
1677    let mut out = Vec::with_capacity(coeffs.len().saturating_mul(periods.len()));
1678    for &c in &coeffs {
1679        for &p in &periods {
1680            out.push(AlphaTrendParams {
1681                coeff: Some(c),
1682                period: Some(p),
1683                no_volume: Some(r.no_volume),
1684            });
1685        }
1686    }
1687    Ok(out)
1688}
1689
1690pub fn alphatrend_batch_with_kernel(
1691    candles: &Candles,
1692    sweep: &AlphaTrendBatchRange,
1693    k: Kernel,
1694) -> Result<AlphaTrendBatchOutput, AlphaTrendError> {
1695    let kernel = match k {
1696        Kernel::Auto => detect_best_batch_kernel(),
1697        other if other.is_batch() => other,
1698        other => return Err(AlphaTrendError::InvalidKernelForBatch(other)),
1699    };
1700
1701    let simd = match kernel {
1702        Kernel::Avx512Batch => Kernel::Avx512,
1703        Kernel::Avx2Batch => Kernel::Avx2,
1704        Kernel::ScalarBatch => Kernel::Scalar,
1705        _ => unreachable!(),
1706    };
1707
1708    alphatrend_batch_inner(candles, sweep, simd, true)
1709}
1710
1711#[inline(always)]
1712pub fn alphatrend_batch_slice(
1713    open: &[f64],
1714    high: &[f64],
1715    low: &[f64],
1716    close: &[f64],
1717    volume: &[f64],
1718    sweep: &AlphaTrendBatchRange,
1719    kern: Kernel,
1720) -> Result<AlphaTrendBatchOutput, AlphaTrendError> {
1721    alphatrend_batch_inner_from_slices(open, high, low, close, volume, sweep, kern, false)
1722}
1723
1724#[inline(always)]
1725pub fn alphatrend_batch_par_slice(
1726    open: &[f64],
1727    high: &[f64],
1728    low: &[f64],
1729    close: &[f64],
1730    volume: &[f64],
1731    sweep: &AlphaTrendBatchRange,
1732    kern: Kernel,
1733) -> Result<AlphaTrendBatchOutput, AlphaTrendError> {
1734    alphatrend_batch_inner_from_slices(open, high, low, close, volume, sweep, kern, true)
1735}
1736
1737#[inline(always)]
1738fn alphatrend_batch_inner_from_slices(
1739    open: &[f64],
1740    high: &[f64],
1741    low: &[f64],
1742    close: &[f64],
1743    volume: &[f64],
1744    sweep: &AlphaTrendBatchRange,
1745    kern: Kernel,
1746    parallel: bool,
1747) -> Result<AlphaTrendBatchOutput, AlphaTrendError> {
1748    let combos = expand_grid_alphatrend(sweep)?;
1749    let cols = close.len();
1750    let rows = combos.len();
1751    let _elems = rows
1752        .checked_mul(cols)
1753        .ok_or_else(|| AlphaTrendError::InvalidInput("rows*cols overflow".into()))?;
1754    if cols == 0 {
1755        return Err(AlphaTrendError::EmptyInputData);
1756    }
1757
1758    let first = close
1759        .iter()
1760        .position(|x| !x.is_nan())
1761        .ok_or(AlphaTrendError::AllValuesNaN)?;
1762    let warm_k1: Vec<usize> = combos
1763        .iter()
1764        .map(|p| first + p.period.unwrap_or(14) - 1)
1765        .collect();
1766    let warm_k2: Vec<usize> = warm_k1.iter().map(|&w| w.saturating_add(2)).collect();
1767
1768    let mut k1_mu = make_uninit_matrix(rows, cols);
1769    let mut k2_mu = make_uninit_matrix(rows, cols);
1770    init_matrix_prefixes(&mut k1_mu, cols, &warm_k1);
1771    init_matrix_prefixes(&mut k2_mu, cols, &warm_k2);
1772
1773    let mut k1_guard = core::mem::ManuallyDrop::new(k1_mu);
1774    let mut k2_guard = core::mem::ManuallyDrop::new(k2_mu);
1775    let out_k1: &mut [f64] = unsafe {
1776        core::slice::from_raw_parts_mut(k1_guard.as_mut_ptr() as *mut f64, k1_guard.len())
1777    };
1778    let out_k2: &mut [f64] = unsafe {
1779        core::slice::from_raw_parts_mut(k2_guard.as_mut_ptr() as *mut f64, k2_guard.len())
1780    };
1781
1782    let actual = match kern {
1783        Kernel::Auto => detect_best_batch_kernel(),
1784        k => k,
1785    };
1786    alphatrend_batch_inner_into_slices(
1787        open, high, low, close, volume, sweep, actual, parallel, out_k1, out_k2,
1788    )?;
1789
1790    let values_k1 = unsafe {
1791        Vec::from_raw_parts(
1792            k1_guard.as_mut_ptr() as *mut f64,
1793            k1_guard.len(),
1794            k1_guard.capacity(),
1795        )
1796    };
1797    let values_k2 = unsafe {
1798        Vec::from_raw_parts(
1799            k2_guard.as_mut_ptr() as *mut f64,
1800            k2_guard.len(),
1801            k2_guard.capacity(),
1802        )
1803    };
1804    core::mem::forget(k1_guard);
1805    core::mem::forget(k2_guard);
1806
1807    Ok(AlphaTrendBatchOutput {
1808        values_k1,
1809        values_k2,
1810        combos,
1811        rows,
1812        cols,
1813    })
1814}
1815
1816#[inline(always)]
1817fn alphatrend_batch_inner(
1818    candles: &Candles,
1819    sweep: &AlphaTrendBatchRange,
1820    kern: Kernel,
1821    parallel: bool,
1822) -> Result<AlphaTrendBatchOutput, AlphaTrendError> {
1823    let combos = expand_grid_alphatrend(sweep)?;
1824    let cols = candles.close.len();
1825    let rows = combos.len();
1826    let _elems = rows
1827        .checked_mul(cols)
1828        .ok_or_else(|| AlphaTrendError::InvalidInput("rows*cols overflow".into()))?;
1829    if cols == 0 {
1830        return Err(AlphaTrendError::EmptyInputData);
1831    }
1832
1833    let mut k1_mu = make_uninit_matrix(rows, cols);
1834    let mut k2_mu = make_uninit_matrix(rows, cols);
1835
1836    let first = candles
1837        .close
1838        .iter()
1839        .position(|x| !x.is_nan())
1840        .ok_or(AlphaTrendError::AllValuesNaN)?;
1841    let warm_k1: Vec<usize> = combos
1842        .iter()
1843        .map(|p| first + p.period.unwrap_or(14) - 1)
1844        .collect();
1845    let warm_k2: Vec<usize> = warm_k1.iter().map(|&w| w.saturating_add(2)).collect();
1846
1847    init_matrix_prefixes(&mut k1_mu, cols, &warm_k1);
1848    init_matrix_prefixes(&mut k2_mu, cols, &warm_k2);
1849
1850    let mut k1_guard = core::mem::ManuallyDrop::new(k1_mu);
1851    let mut k2_guard = core::mem::ManuallyDrop::new(k2_mu);
1852    let out_k1: &mut [f64] = unsafe {
1853        core::slice::from_raw_parts_mut(k1_guard.as_mut_ptr() as *mut f64, k1_guard.len())
1854    };
1855    let out_k2: &mut [f64] = unsafe {
1856        core::slice::from_raw_parts_mut(k2_guard.as_mut_ptr() as *mut f64, k2_guard.len())
1857    };
1858
1859    let actual = match kern {
1860        Kernel::Auto => detect_best_batch_kernel(),
1861        k => k,
1862    };
1863
1864    let do_row =
1865        |row: usize, k1_row: &mut [f64], k2_row: &mut [f64]| -> Result<(), AlphaTrendError> {
1866            let p = &combos[row];
1867            let input = AlphaTrendInput::from_candles(candles, p.clone());
1868            alphatrend_into_slices(k1_row, k2_row, &input, actual)
1869        };
1870
1871    #[cfg(not(target_arch = "wasm32"))]
1872    if parallel {
1873        use rayon::prelude::*;
1874
1875        out_k1
1876            .par_chunks_mut(cols)
1877            .zip(out_k2.par_chunks_mut(cols))
1878            .enumerate()
1879            .try_for_each(|(row, (k1r, k2r))| do_row(row, k1r, k2r))?;
1880    } else {
1881        for (row, (k1r, k2r)) in out_k1
1882            .chunks_mut(cols)
1883            .zip(out_k2.chunks_mut(cols))
1884            .enumerate()
1885        {
1886            do_row(row, k1r, k2r)?;
1887        }
1888    }
1889
1890    #[cfg(target_arch = "wasm32")]
1891    for (row, (k1r, k2r)) in out_k1
1892        .chunks_mut(cols)
1893        .zip(out_k2.chunks_mut(cols))
1894        .enumerate()
1895    {
1896        do_row(row, k1r, k2r)?;
1897    }
1898
1899    let values_k1 = unsafe {
1900        Vec::from_raw_parts(
1901            k1_guard.as_mut_ptr() as *mut f64,
1902            k1_guard.len(),
1903            k1_guard.capacity(),
1904        )
1905    };
1906    let values_k2 = unsafe {
1907        Vec::from_raw_parts(
1908            k2_guard.as_mut_ptr() as *mut f64,
1909            k2_guard.len(),
1910            k2_guard.capacity(),
1911        )
1912    };
1913
1914    Ok(AlphaTrendBatchOutput {
1915        values_k1,
1916        values_k2,
1917        combos,
1918        rows,
1919        cols,
1920    })
1921}
1922
1923#[inline(always)]
1924pub fn alphatrend_batch_inner_into_slices(
1925    open: &[f64],
1926    high: &[f64],
1927    low: &[f64],
1928    close: &[f64],
1929    volume: &[f64],
1930    sweep: &AlphaTrendBatchRange,
1931    kern: Kernel,
1932    parallel: bool,
1933    k1_slice: &mut [f64],
1934    k2_slice: &mut [f64],
1935) -> Result<(), AlphaTrendError> {
1936    let combos = expand_grid_alphatrend(sweep)?;
1937    let cols = close.len();
1938    let rows = combos.len();
1939
1940    if open.len() != cols || high.len() != cols || low.len() != cols || volume.len() != cols {
1941        return Err(AlphaTrendError::InconsistentDataLengths);
1942    }
1943
1944    if cols == 0 {
1945        return Err(AlphaTrendError::EmptyInputData);
1946    }
1947
1948    let total = rows
1949        .checked_mul(cols)
1950        .ok_or_else(|| AlphaTrendError::InvalidInput("rows*cols overflow".into()))?;
1951    if k1_slice.len() != total {
1952        return Err(AlphaTrendError::OutputLengthMismatch {
1953            expected: total,
1954            got: k1_slice.len(),
1955        });
1956    }
1957    if k2_slice.len() != total {
1958        return Err(AlphaTrendError::OutputLengthMismatch {
1959            expected: total,
1960            got: k2_slice.len(),
1961        });
1962    }
1963
1964    let actual = match kern {
1965        Kernel::Auto => detect_best_batch_kernel(),
1966        k => k,
1967    };
1968    let simd_kernel = match actual {
1969        Kernel::Avx512Batch => Kernel::Avx512,
1970        Kernel::Avx2Batch => Kernel::Avx2,
1971        Kernel::ScalarBatch => Kernel::Scalar,
1972        _ => detect_best_kernel(),
1973    };
1974
1975    let first = close
1976        .iter()
1977        .position(|x| !x.is_nan())
1978        .ok_or(AlphaTrendError::AllValuesNaN)?;
1979
1980    let warm_k1: Vec<usize> = combos
1981        .iter()
1982        .map(|p| first + p.period.unwrap_or(14) - 1)
1983        .collect();
1984    let warm_k2: Vec<usize> = warm_k1.iter().map(|&w| w.saturating_add(2)).collect();
1985
1986    let k1_uninit: &mut [MaybeUninit<f64>] = unsafe {
1987        core::slice::from_raw_parts_mut(
1988            k1_slice.as_mut_ptr() as *mut MaybeUninit<f64>,
1989            k1_slice.len(),
1990        )
1991    };
1992    let k2_uninit: &mut [MaybeUninit<f64>] = unsafe {
1993        core::slice::from_raw_parts_mut(
1994            k2_slice.as_mut_ptr() as *mut MaybeUninit<f64>,
1995            k2_slice.len(),
1996        )
1997    };
1998    init_matrix_prefixes(k1_uninit, cols, &warm_k1);
1999    init_matrix_prefixes(k2_uninit, cols, &warm_k2);
2000
2001    let mut tr_mu = make_uninit_matrix(1, cols);
2002    let tr: &mut [f64] =
2003        unsafe { core::slice::from_raw_parts_mut(tr_mu.as_mut_ptr() as *mut f64, cols) };
2004    if first < cols {
2005        tr[first] = high[first] - low[first];
2006    }
2007    for i in (first + 1)..cols {
2008        let hl = high[i] - low[i];
2009        let hc = (high[i] - close[i - 1]).abs();
2010        let lc = (low[i] - close[i - 1]).abs();
2011        tr[i] = hl.max(hc).max(lc);
2012    }
2013
2014    let use_rsi = sweep.no_volume;
2015    let hlc3_opt: Option<Vec<f64>> = if use_rsi {
2016        None
2017    } else {
2018        let mut hlc3_mu = make_uninit_matrix(1, cols);
2019        let hlc3: &mut [f64] =
2020            unsafe { core::slice::from_raw_parts_mut(hlc3_mu.as_mut_ptr() as *mut f64, cols) };
2021        for i in 0..cols {
2022            hlc3[i] = (high[i] + low[i] + close[i]) / 3.0;
2023        }
2024
2025        let v = unsafe {
2026            Vec::from_raw_parts(
2027                hlc3_mu.as_mut_ptr() as *mut f64,
2028                hlc3_mu.len(),
2029                hlc3_mu.capacity(),
2030            )
2031        };
2032        core::mem::forget(hlc3_mu);
2033        Some(v)
2034    };
2035
2036    use std::collections::HashMap;
2037    let mut unique_periods: Vec<usize> = combos.iter().map(|p| p.period.unwrap_or(14)).collect();
2038    unique_periods.sort_unstable();
2039    unique_periods.dedup();
2040
2041    let mut momentum_map: HashMap<usize, Vec<f64>> = HashMap::with_capacity(unique_periods.len());
2042    for &p in &unique_periods {
2043        if p == 0 || p > cols {
2044            return Err(AlphaTrendError::InvalidPeriod {
2045                period: p,
2046                data_len: cols,
2047            });
2048        }
2049        if use_rsi {
2050            let rsi_params = RsiParams { period: Some(p) };
2051            let rsi_input = RsiInput::from_slice(close, rsi_params);
2052            let mv = rsi_with_kernel(&rsi_input, simd_kernel)
2053                .map_err(|e| AlphaTrendError::RsiError { msg: e.to_string() })?
2054                .values;
2055            momentum_map.insert(p, mv);
2056        } else {
2057            let hlc3 = hlc3_opt.as_ref().expect("hlc3 precomputed");
2058            let mfi_params = MfiParams { period: Some(p) };
2059            let mfi_input = MfiInput::from_slices(hlc3, volume, mfi_params);
2060            let mv = mfi_with_kernel(&mfi_input, simd_kernel)
2061                .map_err(|e| AlphaTrendError::MfiError { msg: e.to_string() })?
2062                .values;
2063            momentum_map.insert(p, mv);
2064        }
2065    }
2066
2067    let do_row =
2068        |row: usize, k1_row: &mut [f64], k2_row: &mut [f64]| -> Result<(), AlphaTrendError> {
2069            let params = &combos[row];
2070            let coeff = params.coeff.unwrap_or(1.0);
2071            if !coeff.is_finite() || coeff <= 0.0 {
2072                return Err(AlphaTrendError::InvalidCoeff { coeff });
2073            }
2074            let period = params.period.unwrap_or(14);
2075            if period == 0 || period > cols {
2076                return Err(AlphaTrendError::InvalidPeriod {
2077                    period,
2078                    data_len: cols,
2079                });
2080            }
2081            let warmup = first + period - 1;
2082            if warmup >= cols {
2083                return Ok(());
2084            }
2085
2086            let mom = momentum_map.get(&period).expect("momentum precomputed");
2087
2088            let mut sum = 0.0f64;
2089            for j in first..=warmup {
2090                sum += tr[j];
2091            }
2092
2093            let mut prev_alpha = f64::NAN;
2094            let mut prev1 = f64::NAN;
2095            let mut prev2 = f64::NAN;
2096
2097            for i in warmup..cols {
2098                let a = sum / period as f64;
2099                let up = low[i] - a * coeff;
2100                let dn = high[i] + a * coeff;
2101                let m_ge_50 = mom[i] >= 50.0;
2102
2103                let cur = if i == warmup {
2104                    if m_ge_50 {
2105                        up
2106                    } else {
2107                        dn
2108                    }
2109                } else if m_ge_50 {
2110                    if up < prev_alpha {
2111                        prev_alpha
2112                    } else {
2113                        up
2114                    }
2115                } else {
2116                    if dn > prev_alpha {
2117                        prev_alpha
2118                    } else {
2119                        dn
2120                    }
2121                };
2122
2123                k1_row[i] = cur;
2124                if i >= warmup + 2 {
2125                    k2_row[i] = prev2;
2126                }
2127
2128                prev2 = prev1;
2129                prev1 = cur;
2130                prev_alpha = cur;
2131
2132                if i + 1 < cols {
2133                    sum += tr[i + 1] - tr[i + 1 - period];
2134                }
2135            }
2136            Ok(())
2137        };
2138
2139    #[cfg(not(target_arch = "wasm32"))]
2140    if parallel {
2141        use rayon::prelude::*;
2142        k1_slice
2143            .par_chunks_mut(cols)
2144            .zip(k2_slice.par_chunks_mut(cols))
2145            .enumerate()
2146            .try_for_each(|(row, (k1r, k2r))| do_row(row, k1r, k2r))?;
2147    } else {
2148        for (row, (k1r, k2r)) in k1_slice
2149            .chunks_mut(cols)
2150            .zip(k2_slice.chunks_mut(cols))
2151            .enumerate()
2152        {
2153            do_row(row, k1r, k2r)?;
2154        }
2155    }
2156
2157    #[cfg(target_arch = "wasm32")]
2158    for (row, (k1r, k2r)) in k1_slice
2159        .chunks_mut(cols)
2160        .zip(k2_slice.chunks_mut(cols))
2161        .enumerate()
2162    {
2163        do_row(row, k1r, k2r)?;
2164    }
2165
2166    Ok(())
2167}
2168
/// Python binding: evaluates AlphaTrend over a `(coeff, period)` parameter grid.
///
/// `coeff_range` and `period_range` are `(start, end, step)` triples; `no_volume`
/// is forwarded to every combo. `kernel` is an optional kernel name checked by
/// `validate_kernel`.
///
/// Returns a dict with:
/// - `"k1"`, `"k2"`: 2-D float64 arrays of shape `(rows, cols)`,
/// - `"rows"`, `"cols"`: the grid size and the input length,
/// - `"combos"`: a list of `{coeff, period, no_volume}` dicts, one per row.
///
/// Raises `ValueError` on mismatched input lengths, an invalid sweep, or any
/// computation error.
#[cfg(feature = "python")]
#[pyfunction(name = "alphatrend_batch")]
#[pyo3(signature = (open, high, low, close, volume, coeff_range, period_range, no_volume=false, kernel=None))]
pub fn alphatrend_batch_py<'py>(
    py: Python<'py>,
    open: PyReadonlyArray1<'py, f64>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    volume: PyReadonlyArray1<'py, f64>,
    coeff_range: (f64, f64, f64),
    period_range: (usize, usize, usize),
    no_volume: bool,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let (o, h, l, c, v) = (
        open.as_slice()?,
        high.as_slice()?,
        low.as_slice()?,
        close.as_slice()?,
        volume.as_slice()?,
    );
    let len = c.len();
    if o.len() != len || h.len() != len || l.len() != len || v.len() != len {
        return Err(PyValueError::new_err("Inconsistent data lengths"));
    }

    let sweep = AlphaTrendBatchRange {
        coeff: coeff_range,
        period: period_range,
        no_volume,
    };
    let kern = validate_kernel(kernel, true)?;

    // Expand the grid once and size the outputs from it. Re-deriving the row
    // count with separate arithmetic can disagree with the expander, e.g. on
    // reversed ranges or float steps that do not divide evenly. The output
    // buffers would then not match the rows the kernel writes.
    let combos =
        expand_grid_alphatrend(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let rows = combos.len();
    let total = rows
        .checked_mul(len)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;

    // Uninitialised allocations; the batch kernel is responsible for filling
    // every element (warmup prefixes included).
    let out_k1 = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let out_k2 = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let k1_slice = unsafe { out_k1.as_slice_mut()? };
    let k2_slice = unsafe { out_k2.as_slice_mut()? };

    py.allow_threads(|| {
        alphatrend_batch_inner_into_slices(o, h, l, c, v, &sweep, kern, true, k1_slice, k2_slice)
    })
    .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("k1", out_k1.reshape([rows, len])?)?;
    dict.set_item("k2", out_k2.reshape([rows, len])?)?;
    dict.set_item("rows", rows)?;
    dict.set_item("cols", len)?;

    // Build the per-row parameter dicts, propagating any Python error
    // instead of panicking inside an iterator closure.
    let mut combo_dicts = Vec::with_capacity(rows);
    for p in &combos {
        let d = PyDict::new(py);
        d.set_item("coeff", p.coeff.unwrap_or(1.0))?;
        d.set_item("period", p.period.unwrap_or(14))?;
        d.set_item("no_volume", p.no_volume.unwrap_or(false))?;
        combo_dicts.push(d);
    }
    dict.set_item("combos", PyList::new(py, combo_dicts)?)?;

    Ok(dict)
}
2256
/// Python binding: GPU batch AlphaTrend over a `(coeff, period)` grid on f32 inputs.
///
/// Returns a dict holding device-resident `k1` / `k2` matrices (as
/// `AtDeviceArrayF32Py`), host arrays `coeffs` (float64) and `periods` (uint64)
/// describing each row, and the `rows` / `cols` dimensions.
///
/// Raises `ValueError` if CUDA is unavailable, the inputs differ in length, or
/// the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "alphatrend_cuda_batch_dev")]
#[pyo3(signature = (high_f32, low_f32, close_f32, volume_f32, coeff_range, period_range, no_volume=false, device_id=0))]
pub fn alphatrend_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    high_f32: PyReadonlyArray1<'py, f32>,
    low_f32: PyReadonlyArray1<'py, f32>,
    close_f32: PyReadonlyArray1<'py, f32>,
    volume_f32: PyReadonlyArray1<'py, f32>,
    coeff_range: (f64, f64, f64),
    period_range: (usize, usize, usize),
    no_volume: bool,
    device_id: usize,
) -> PyResult<Bound<'py, PyDict>> {
    use crate::cuda::cuda_available;
    use numpy::IntoPyArray;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let high = high_f32.as_slice()?;
    let low = low_f32.as_slice()?;
    let close = close_f32.as_slice()?;
    let volume = volume_f32.as_slice()?;
    let n = high.len();
    if [low.len(), close.len(), volume.len()].iter().any(|&m| m != n) {
        return Err(PyValueError::new_err("Inconsistent data lengths"));
    }

    let sweep = AlphaTrendBatchRange {
        coeff: coeff_range,
        period: period_range,
        no_volume,
    };

    // All device work happens with the GIL released.
    let (batch, coeffs, periods, ctx, dev_id) = py.allow_threads(|| {
        let cuda =
            CudaAlphaTrend::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let out = cuda
            .alphatrend_batch_dev(high, low, close, volume, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;

        let mut coeffs = Vec::with_capacity(out.combos.len());
        let mut periods = Vec::with_capacity(out.combos.len());
        for combo in &out.combos {
            coeffs.push(combo.coeff.unwrap_or(1.0));
            periods.push(combo.period.unwrap_or(14) as u64);
        }

        let ctx = cuda.context_arc();
        let dev = cuda.device_id();
        Ok::<_, PyErr>((out, coeffs, periods, ctx, dev))
    })?;

    let rows = batch.k1.rows;
    let cols = batch.k1.cols;

    // Both outputs share the same shape, device and context handle.
    let wrap = |buf: DeviceBuffer<f32>, ctx: Arc<Context>| AtDeviceArrayF32Py {
        buf: Some(buf),
        rows,
        cols,
        _ctx: ctx,
        device_id: dev_id,
    };
    let k1_dev = wrap(batch.k1.buf, ctx.clone());
    let k2_dev = wrap(batch.k2.buf, ctx);

    let result = PyDict::new(py);
    result.set_item("k1", Py::new(py, k1_dev)?)?;
    result.set_item("k2", Py::new(py, k2_dev)?)?;
    result.set_item("coeffs", coeffs.into_pyarray(py))?;
    result.set_item("periods", periods.into_pyarray(py))?;
    result.set_item("rows", rows)?;
    result.set_item("cols", cols)?;
    Ok(result)
}
2331
/// Python binding: GPU AlphaTrend with a single `(coeff, period, no_volume)`
/// setting over many series stored as flattened time-major f32 matrices.
///
/// Each input must contain exactly `cols * rows` elements. As the names suggest,
/// `cols` is presumably the number of series and `rows` the number of time steps;
/// TODO confirm against the CUDA wrapper.
///
/// Returns `(k1, k2)` as device-resident `AtDeviceArrayF32Py` handles.
/// Raises `ValueError` if CUDA is unavailable, the shape product overflows,
/// any input has the wrong length, or the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "alphatrend_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm_f32, low_tm_f32, close_tm_f32, volume_tm_f32, cols, rows, coeff=1.0, period=14, no_volume=false, device_id=0))]
pub fn alphatrend_cuda_many_series_one_param_dev_py<'py>(
    py: Python<'py>,
    high_tm_f32: PyReadonlyArray1<'py, f32>,
    low_tm_f32: PyReadonlyArray1<'py, f32>,
    close_tm_f32: PyReadonlyArray1<'py, f32>,
    volume_tm_f32: PyReadonlyArray1<'py, f32>,
    cols: usize,
    rows: usize,
    coeff: f64,
    period: usize,
    no_volume: bool,
    device_id: usize,
) -> PyResult<(AtDeviceArrayF32Py, AtDeviceArrayF32Py)> {
    use crate::cuda::cuda_available;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let (h, l, c, v) = (
        high_tm_f32.as_slice()?,
        low_tm_f32.as_slice()?,
        close_tm_f32.as_slice()?,
        volume_tm_f32.as_slice()?,
    );
    // Checked: an unchecked product could wrap in release builds and let
    // buffers that do not match the declared shape through to the kernel.
    let expected = cols
        .checked_mul(rows)
        .ok_or_else(|| PyValueError::new_err("cols*rows overflow"))?;
    if [h.len(), l.len(), c.len(), v.len()]
        .iter()
        .any(|&n| n != expected)
    {
        return Err(PyValueError::new_err("Inconsistent time-major shapes"));
    }
    let (k1, k2, ctx_guard, dev_id) = py.allow_threads(|| {
        let cuda =
            CudaAlphaTrend::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let out = cuda
            .alphatrend_many_series_one_param_time_major_dev(
                h, l, c, v, cols, rows, coeff, period, no_volume,
            )
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, PyErr>((out.0, out.1, cuda.context_arc(), cuda.device_id()))
    })?;
    Ok((
        AtDeviceArrayF32Py {
            buf: Some(k1.buf),
            rows: k1.rows,
            cols: k1.cols,
            _ctx: ctx_guard.clone(),
            device_id: dev_id,
        },
        AtDeviceArrayF32Py {
            buf: Some(k2.buf),
            rows: k2.rows,
            cols: k2.cols,
            _ctx: ctx_guard,
            device_id: dev_id,
        },
    ))
}
2392
#[cfg(all(feature = "python", feature = "cuda"))]
/// Python-visible handle to a row-major `(rows, cols)` f32 buffer in CUDA device memory.
///
/// The buffer is exposed through `__cuda_array_interface__` and `__dlpack__`.
/// The class is PyO3 `unsendable`, so it may only be used from the thread that
/// created it.
///
/// NOTE: field order matters. Rust drops struct fields in declaration order, so
/// `buf` is released before `_ctx`. Keep `buf` declared above `_ctx`.
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct AtDeviceArrayF32Py {
    /// Device allocation; becomes `None` once `__dlpack__` has taken ownership of it.
    pub(crate) buf: Option<DeviceBuffer<f32>>,
    /// Number of rows (leading dimension).
    pub(crate) rows: usize,
    /// Number of columns; also the row pitch in elements.
    pub(crate) cols: usize,
    /// Presumably held only to keep the CUDA context alive while `buf` exists.
    pub(crate) _ctx: Arc<Context>,
    /// Device ordinal reported as the second element of `__dlpack_device__`.
    pub(crate) device_id: u32,
}
2402
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl AtDeviceArrayF32Py {
    /// `__cuda_array_interface__` (version 3) describing the buffer as a
    /// C-contiguous `(rows, cols)` float32 array.
    ///
    /// Fails once the buffer has been handed off through `__dlpack__`.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let elem = std::mem::size_of::<f32>();
        let iface = PyDict::new(py);
        iface.set_item("shape", (self.rows, self.cols))?;
        iface.set_item("typestr", "<f4")?;
        // Row-major layout: rows are `cols` elements apart, columns one element apart.
        iface.set_item("strides", (self.cols * elem, elem))?;
        let device_ptr = match self.buf.as_ref() {
            Some(buf) => buf.as_device_ptr().as_raw() as usize,
            None => {
                return Err(PyValueError::new_err(
                    "buffer already exported via __dlpack__",
                ))
            }
        };
        // Second tuple element is the read-only flag.
        iface.set_item("data", (device_ptr, false))?;

        iface.set_item("version", 3)?;
        Ok(iface)
    }

    /// DLPack device tuple `(2, device_id)`; 2 is DLPack's `kDLCUDA`.
    fn __dlpack_device__(&self) -> (i32, i32) {
        let kdl_cuda = 2;
        (kdl_cuda, self.device_id as i32)
    }

    /// Exports the buffer as a DLPack capsule, transferring ownership of the
    /// device allocation. Any later call, including `__cuda_array_interface__`,
    /// fails.
    ///
    /// `stream` is accepted but ignored. A `dl_device` different from this
    /// buffer's device is rejected, because cross-device copies are not implemented.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<PyObject>,
        max_version: Option<PyObject>,
        dl_device: Option<PyObject>,
        copy: Option<PyObject>,
    ) -> PyResult<PyObject> {
        let _ = stream;
        let own_device = self.__dlpack_device__();

        // Unparseable `dl_device` values are ignored rather than rejected.
        let requested = dl_device
            .as_ref()
            .and_then(|obj| obj.extract::<(i32, i32)>(py).ok());
        if let Some(target) = requested {
            if target != own_device {
                let copy_requested = copy
                    .as_ref()
                    .and_then(|flag| flag.extract::<bool>(py).ok())
                    .unwrap_or(false);
                let msg = if copy_requested {
                    "device copy not implemented for __dlpack__"
                } else {
                    "dl_device mismatch for __dlpack__"
                };
                return Err(PyValueError::new_err(msg));
            }
        }

        let owned = self
            .buf
            .take()
            .ok_or_else(|| PyValueError::new_err("__dlpack__ may only be called once"))?;
        let version_hint = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, owned, self.rows, self.cols, own_device.1, version_hint)
    }
}
2475
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
/// Serialized result returned to JS by the wasm batch entry points.
///
/// As built by the batch JS functions in this file, `values` holds `rows * cols`
/// numbers. For each parameter combo it contains that combo's K1 row followed
/// immediately by its K2 row.
#[derive(Serialize, Deserialize)]
pub struct AlphaTrendBatchJsOutput {
    /// Interleaved K1/K2 rows, row-major.
    pub values: Vec<f64>,
    /// Parameter combination for each K1/K2 row pair, in output order.
    pub combos: Vec<AlphaTrendParams>,
    /// Number of rows in `values` (twice `combos.len()`).
    pub rows: usize,
    /// Number of columns (the input series length).
    pub cols: usize,
}
2484
/// Batch AlphaTrend for JS over `coeff` and `period` sweeps, each given as a
/// `(start, end, step)` triple.
///
/// The returned object's `values` holds, for every parameter combo in order,
/// its K1 row immediately followed by its K2 row. `rows` is therefore twice the
/// number of combos, and `cols` is the input length.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = alphatrend_batch)]
pub fn alphatrend_batch_js(
    open: &[f64],
    high: &[f64],
    low: &[f64],
    close: &[f64],
    volume: &[f64],
    coeff_start: f64,
    coeff_end: f64,
    coeff_step: f64,
    period_start: usize,
    period_end: usize,
    period_step: usize,
    no_volume: bool,
) -> Result<JsValue, JsValue> {
    let sweep = AlphaTrendBatchRange {
        coeff: (coeff_start, coeff_end, coeff_step),
        period: (period_start, period_end, period_step),
        no_volume,
    };
    let combos = match expand_grid_alphatrend(&sweep) {
        Ok(grid) => grid,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };
    let n_combos = combos.len();
    let n_cols = close.len();

    let plane = match n_combos.checked_mul(n_cols) {
        Some(size) => size,
        None => return Err(JsValue::from_str("rows*cols overflow")),
    };
    let mut k1_plane = vec![f64::NAN; plane];
    let mut k2_plane = vec![f64::NAN; plane];

    if let Err(e) = alphatrend_batch_inner_into_slices(
        open,
        high,
        low,
        close,
        volume,
        &sweep,
        detect_best_batch_kernel(),
        true,
        &mut k1_plane,
        &mut k2_plane,
    ) {
        return Err(JsValue::from_str(&e.to_string()));
    }

    let interleaved_len = match n_combos
        .checked_mul(2)
        .and_then(|doubled| doubled.checked_mul(n_cols))
    {
        Some(size) => size,
        None => return Err(JsValue::from_str("rows*2*cols overflow")),
    };
    let mut values = Vec::with_capacity(interleaved_len);
    // `chunks` panics on a zero size; with no columns there is nothing to copy anyway.
    if n_cols > 0 {
        for (k1_row, k2_row) in k1_plane.chunks(n_cols).zip(k2_plane.chunks(n_cols)) {
            values.extend_from_slice(k1_row);
            values.extend_from_slice(k2_row);
        }
    }

    let payload = AlphaTrendBatchJsOutput {
        values,
        combos,
        rows: n_combos * 2,
        cols: n_cols,
    };
    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2550
/// Batch AlphaTrend that writes directly into caller-provided wasm memory.
///
/// Each input pointer must address `len` readable f64s. `out_ptr` must address
/// `2 * rows * len` writable f64s, where `rows` is the number of parameter combos
/// (also the return value). The first `rows * len` values receive K1 (one row per
/// combo) and the next `rows * len` receive K2.
///
/// If the output region overlaps any input, the results are first computed into
/// scratch buffers and then copied out. Overlapping shared and mutable slices
/// would otherwise be undefined behaviour.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn alphatrend_batch_into_flat(
    open_ptr: *const f64,
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    volume_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    coeff_start: f64,
    coeff_end: f64,
    coeff_step: f64,
    period_start: usize,
    period_end: usize,
    period_step: usize,
    no_volume: bool,
) -> Result<usize, JsValue> {
    let inputs = [open_ptr, high_ptr, low_ptr, close_ptr, volume_ptr];
    if inputs.iter().any(|&p| p.is_null()) || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer"));
    }

    let sweep = AlphaTrendBatchRange {
        coeff: (coeff_start, coeff_end, coeff_step),
        period: (period_start, period_end, period_step),
        no_volume,
    };
    let combos =
        expand_grid_alphatrend(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let cols = len;
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;
    // K1 and K2 are laid out back to back, so the output spans `2 * total` values.
    let out_len = total
        .checked_mul(2)
        .ok_or_else(|| JsValue::from_str("rows*2*cols overflow"))?;

    // Byte-range overlap test between the output region and each input.
    let elem = core::mem::size_of::<f64>();
    let out_start = out_ptr as usize;
    let out_end = out_start.saturating_add(out_len.saturating_mul(elem));
    let in_bytes = len.saturating_mul(elem);
    let aliased = out_len != 0
        && in_bytes != 0
        && inputs.iter().any(|&p| {
            let start = p as usize;
            start < out_end && out_start < start.saturating_add(in_bytes)
        });

    let kernel = detect_best_batch_kernel();
    if aliased {
        let mut k1 = vec![f64::NAN; total];
        let mut k2 = vec![f64::NAN; total];
        {
            // SAFETY: the caller guarantees each input pointer addresses `len`
            // readable f64s. These shared slices go out of scope before the
            // output region is borrowed mutably below.
            let (open, high, low, close, volume) = unsafe {
                (
                    core::slice::from_raw_parts(open_ptr, len),
                    core::slice::from_raw_parts(high_ptr, len),
                    core::slice::from_raw_parts(low_ptr, len),
                    core::slice::from_raw_parts(close_ptr, len),
                    core::slice::from_raw_parts(volume_ptr, len),
                )
            };
            alphatrend_batch_inner_into_slices(
                open, high, low, close, volume, &sweep, kernel, false, &mut k1, &mut k2,
            )
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
        // SAFETY: the caller guarantees `out_ptr` addresses `out_len` writable f64s,
        // and no input slice is live at this point.
        let out = unsafe { core::slice::from_raw_parts_mut(out_ptr, out_len) };
        let (k1_out, k2_out) = out.split_at_mut(total);
        k1_out.copy_from_slice(&k1);
        k2_out.copy_from_slice(&k2);
    } else {
        // SAFETY: the caller guarantees the pointer/length contracts above, and the
        // overlap test established that the output region is disjoint from every input.
        let (open, high, low, close, volume, out) = unsafe {
            (
                core::slice::from_raw_parts(open_ptr, len),
                core::slice::from_raw_parts(high_ptr, len),
                core::slice::from_raw_parts(low_ptr, len),
                core::slice::from_raw_parts(close_ptr, len),
                core::slice::from_raw_parts(volume_ptr, len),
                core::slice::from_raw_parts_mut(out_ptr, out_len),
            )
        };
        let (k1_out, k2_out) = out.split_at_mut(total);
        alphatrend_batch_inner_into_slices(
            open, high, low, close, volume, &sweep, kernel, false, k1_out, k2_out,
        )
        .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(rows)
}
2617
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
/// JS-side configuration object accepted by `alphatrend_batch_unified`.
#[derive(Serialize, Deserialize)]
pub struct AlphaTrendBatchConfig {
    /// `(start, end, step)` sweep for `coeff`, the multiplier applied to the averaged true range.
    pub coeff_range: (f64, f64, f64),
    /// `(start, end, step)` sweep for the lookback `period`.
    pub period_range: (usize, usize, usize),
    /// In the batch path, `true` uses RSI on close as momentum; `false` uses MFI on HLC3 and volume.
    pub no_volume: bool,
}
2625
/// Batch AlphaTrend driven by a JS config object
/// `{ coeff_range: [start, end, step], period_range: [start, end, step], no_volume }`.
///
/// Returns `{ values, combos, rows, cols }`. `values` holds, for each parameter
/// combo, its K1 row followed by its K2 row, so `rows` is twice the combo count.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = alphatrend_batch_unified)]
pub fn alphatrend_batch_unified_js(
    open: &[f64],
    high: &[f64],
    low: &[f64],
    close: &[f64],
    volume: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let config: AlphaTrendBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = AlphaTrendBatchRange {
        coeff: config.coeff_range,
        period: config.period_range,
        no_volume: config.no_volume,
    };

    // Batch entry points take a batch kernel, matching `alphatrend_batch`.
    // The batch code maps only `*Batch` kernels to their SIMD variant and
    // ignores anything else.
    let output = alphatrend_batch_slice(
        open,
        high,
        low,
        close,
        volume,
        &sweep,
        detect_best_batch_kernel(),
    )
    .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let cols = output.cols;
    let rows2 = output
        .rows
        .checked_mul(2)
        .ok_or_else(|| JsValue::from_str("rows2*cols overflow"))?;
    let total_values = rows2
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("rows2*cols overflow"))?;
    let mut values = Vec::with_capacity(total_values);
    for r in 0..output.rows {
        let base = r * cols;
        values.extend_from_slice(&output.values_k1[base..base + cols]);
        values.extend_from_slice(&output.values_k2[base..base + cols]);
    }

    let js_output = AlphaTrendBatchJsOutput {
        values,
        combos: output.combos,
        rows: rows2,
        cols,
    };

    serde_wasm_bindgen::to_value(&js_output)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2671
/// Single-parameter AlphaTrend that writes K1/K2 into caller-provided wasm memory.
///
/// `in_ptr` is the close series. Every input and output pointer must address
/// `len` f64s. Inputs may overlap the outputs, for example when computing in
/// place over `close`. In that case the result is computed into scratch buffers
/// and copied out afterwards, because overlapping shared and mutable slices would
/// be undefined behaviour.
///
/// The two output buffers must not overlap each other; that case returns an error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn alphatrend_into(
    in_ptr: *const f64,
    out_k1_ptr: *mut f64,
    out_k2_ptr: *mut f64,
    len: usize,
    open_ptr: *const f64,
    high_ptr: *const f64,
    low_ptr: *const f64,
    volume_ptr: *const f64,
    coeff: f64,
    period: usize,
    no_volume: bool,
) -> Result<(), JsValue> {
    if in_ptr.is_null()
        || out_k1_ptr.is_null()
        || out_k2_ptr.is_null()
        || open_ptr.is_null()
        || high_ptr.is_null()
        || low_ptr.is_null()
        || volume_ptr.is_null()
    {
        return Err(JsValue::from_str("Null pointer passed to alphatrend_into"));
    }

    // Two `len`-element f64 regions overlap iff their byte ranges intersect.
    let span = len.saturating_mul(std::mem::size_of::<f64>());
    let overlaps = |a: *const f64, b: *const f64| -> bool {
        let (a0, b0) = (a as usize, b as usize);
        span != 0 && a0 < b0.saturating_add(span) && b0 < a0.saturating_add(span)
    };
    let (k1_dst, k2_dst) = (out_k1_ptr as *const f64, out_k2_ptr as *const f64);
    if overlaps(k1_dst, k2_dst) {
        return Err(JsValue::from_str(
            "Overlapping output buffers passed to alphatrend_into",
        ));
    }
    let needs_scratch = [open_ptr, high_ptr, low_ptr, in_ptr, volume_ptr]
        .iter()
        .any(|&src| overlaps(src, k1_dst) || overlaps(src, k2_dst));

    let params = AlphaTrendParams {
        coeff: Some(coeff),
        period: Some(period),
        no_volume: Some(no_volume),
    };

    if needs_scratch {
        let mut k1 = vec![f64::NAN; len];
        let mut k2 = vec![f64::NAN; len];
        {
            // SAFETY: all input pointers are non-null and, per the caller contract,
            // address `len` readable f64s. These borrows end with this scope, before
            // any output memory is written.
            let (open, high, low, close, volume) = unsafe {
                (
                    std::slice::from_raw_parts(open_ptr, len),
                    std::slice::from_raw_parts(high_ptr, len),
                    std::slice::from_raw_parts(low_ptr, len),
                    std::slice::from_raw_parts(in_ptr, len),
                    std::slice::from_raw_parts(volume_ptr, len),
                )
            };
            let input = AlphaTrendInput::from_slices(open, high, low, close, volume, params);
            alphatrend_into_slices(&mut k1, &mut k2, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
        // SAFETY: the output pointers address `len` writable f64s, do not overlap
        // each other (checked above), and no input slice is live any more.
        unsafe {
            std::slice::from_raw_parts_mut(out_k1_ptr, len).copy_from_slice(&k1);
            std::slice::from_raw_parts_mut(out_k2_ptr, len).copy_from_slice(&k2);
        }
    } else {
        // SAFETY: the caller contract covers all pointer/length pairs, and the
        // overlap checks above established that every output is disjoint from
        // every input and from the other output.
        unsafe {
            let open = std::slice::from_raw_parts(open_ptr, len);
            let high = std::slice::from_raw_parts(high_ptr, len);
            let low = std::slice::from_raw_parts(low_ptr, len);
            let close = std::slice::from_raw_parts(in_ptr, len);
            let volume = std::slice::from_raw_parts(volume_ptr, len);
            let input = AlphaTrendInput::from_slices(open, high, low, close, volume, params);

            let out_k1 = std::slice::from_raw_parts_mut(out_k1_ptr, len);
            let out_k2 = std::slice::from_raw_parts_mut(out_k2_ptr, len);

            alphatrend_into_slices(out_k1, out_k2, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(())
}
2721
/// Reusable AlphaTrend configuration for repeated `update_into` calls from JS.
///
/// Deprecated in favour of the pointer-based `alphatrend_into` API with
/// persistent buffers.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
#[deprecated(
    since = "1.0.0",
    note = "For weight reuse patterns, use the fast/unsafe API with persistent buffers"
)]
pub struct AlphaTrendContext {
    /// Lookback length (validated non-zero by the constructor).
    period: usize,
    /// Multiplier for the averaged true range (validated finite and positive).
    coeff: f64,
    /// Momentum-source selector forwarded as `AlphaTrendParams::no_volume`.
    no_volume: bool,
    /// Kernel forwarded to the compute routine.
    kernel: Kernel,
}
2734
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
#[allow(deprecated)]
impl AlphaTrendContext {
    /// Creates a context for a fixed `(coeff, period, no_volume)` configuration.
    ///
    /// Errors if `period == 0` or if `coeff` is not a finite positive number.
    #[wasm_bindgen(constructor)]
    #[deprecated(
        since = "1.0.0",
        note = "For weight reuse patterns, use the fast/unsafe API with persistent buffers"
    )]
    pub fn new(coeff: f64, period: usize, no_volume: bool) -> Result<AlphaTrendContext, JsValue> {
        if period == 0 {
            return Err(JsValue::from_str("Invalid period: 0"));
        }
        if coeff <= 0.0 || !coeff.is_finite() {
            return Err(JsValue::from_str(&format!(
                "Invalid coefficient: {}",
                coeff
            )));
        }

        Ok(AlphaTrendContext {
            coeff,
            period,
            no_volume,
            kernel: Kernel::Auto,
        })
    }

    /// Computes K1/K2 for `len` bars read from the input pointers and writes them
    /// to `out_k1_ptr` / `out_k2_ptr`. Every pointer must address `len` f64s.
    ///
    /// Any input may overlap either output, for example in-place over `close`.
    /// In that case the result is computed into scratch buffers and copied out.
    /// The two outputs must not overlap each other.
    ///
    /// Errors if `len < period`, if any pointer is null, if the outputs overlap,
    /// or if the computation fails.
    pub fn update_into(
        &self,
        open_ptr: *const f64,
        high_ptr: *const f64,
        low_ptr: *const f64,
        close_ptr: *const f64,
        volume_ptr: *const f64,
        out_k1_ptr: *mut f64,
        out_k2_ptr: *mut f64,
        len: usize,
    ) -> Result<(), JsValue> {
        if len < self.period {
            return Err(JsValue::from_str("Data length less than period"));
        }
        // Building slices from null pointers is undefined behaviour even for len > 0.
        if [open_ptr, high_ptr, low_ptr, close_ptr, volume_ptr]
            .iter()
            .any(|p| p.is_null())
            || out_k1_ptr.is_null()
            || out_k2_ptr.is_null()
        {
            return Err(JsValue::from_str("Null pointer passed to update_into"));
        }

        // Two `len`-element f64 regions overlap iff their byte ranges intersect.
        // This also catches partial overlap, not just identical pointers.
        let span = len.saturating_mul(std::mem::size_of::<f64>());
        let overlaps = |a: *const f64, b: *const f64| -> bool {
            let (a0, b0) = (a as usize, b as usize);
            span != 0 && a0 < b0.saturating_add(span) && b0 < a0.saturating_add(span)
        };
        let (k1_dst, k2_dst) = (out_k1_ptr as *const f64, out_k2_ptr as *const f64);
        if overlaps(k1_dst, k2_dst) {
            return Err(JsValue::from_str(
                "Overlapping output buffers passed to update_into",
            ));
        }
        let needs_scratch = [open_ptr, high_ptr, low_ptr, close_ptr, volume_ptr]
            .iter()
            .any(|&src| overlaps(src, k1_dst) || overlaps(src, k2_dst));

        let params = AlphaTrendParams {
            coeff: Some(self.coeff),
            period: Some(self.period),
            no_volume: Some(self.no_volume),
        };

        if needs_scratch {
            let mut temp_k1 = vec![0.0; len];
            let mut temp_k2 = vec![0.0; len];
            {
                // SAFETY: inputs are non-null and address `len` readable f64s per the
                // caller contract. These borrows end before any output is written.
                let (open, high, low, close, volume) = unsafe {
                    (
                        std::slice::from_raw_parts(open_ptr, len),
                        std::slice::from_raw_parts(high_ptr, len),
                        std::slice::from_raw_parts(low_ptr, len),
                        std::slice::from_raw_parts(close_ptr, len),
                        std::slice::from_raw_parts(volume_ptr, len),
                    )
                };
                let input = AlphaTrendInput::from_slices(open, high, low, close, volume, params);
                alphatrend_into_slices(&mut temp_k1, &mut temp_k2, &input, self.kernel)
                    .map_err(|e| JsValue::from_str(&e.to_string()))?;
            }
            // SAFETY: outputs are non-null, address `len` writable f64s, do not
            // overlap each other, and no input slice is live any more.
            unsafe {
                std::slice::from_raw_parts_mut(out_k1_ptr, len).copy_from_slice(&temp_k1);
                std::slice::from_raw_parts_mut(out_k2_ptr, len).copy_from_slice(&temp_k2);
            }
        } else {
            // SAFETY: all pointers satisfy the caller contract, and every output
            // region is disjoint from every input and from the other output.
            unsafe {
                let open = std::slice::from_raw_parts(open_ptr, len);
                let high = std::slice::from_raw_parts(high_ptr, len);
                let low = std::slice::from_raw_parts(low_ptr, len);
                let close = std::slice::from_raw_parts(close_ptr, len);
                let volume = std::slice::from_raw_parts(volume_ptr, len);
                let out_k1 = std::slice::from_raw_parts_mut(out_k1_ptr, len);
                let out_k2 = std::slice::from_raw_parts_mut(out_k2_ptr, len);

                let input = AlphaTrendInput::from_slices(open, high, low, close, volume, params);
                alphatrend_into_slices(out_k1, out_k2, &input, self.kernel)
                    .map_err(|e| JsValue::from_str(&e.to_string()))?;
            }
        }

        Ok(())
    }

    /// Returns `period - 1` (never underflows: the constructor rejects `period == 0`).
    ///
    /// NOTE(review): this ignores leading NaNs in the input. The batch path in this
    /// file offsets its warmup by the first non-NaN index and lags K2 by two more
    /// bars, so treat this value as a lower bound.
    pub fn get_warmup_period(&self) -> usize {
        self.period - 1
    }
}
2816
2817#[cfg(test)]
2818mod tests {
2819    use super::*;
2820    use crate::utilities::data_loader::read_candles_from_csv;
2821    use std::error::Error;
2822
2823    fn check_alphatrend_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2824        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2825        let candles = read_candles_from_csv(file_path)?;
2826
2827        let input = AlphaTrendInput::from_candles(&candles, AlphaTrendParams::default());
2828        let result = alphatrend_with_kernel(&input, kernel)?;
2829
2830        let expected_k1 = [
2831            60243.00,
2832            60243.00,
2833            60138.92857143,
2834            60088.42857143,
2835            59937.21428571,
2836        ];
2837
2838        let expected_k2 = [
2839            60542.42857143,
2840            60454.14285714,
2841            60243.00,
2842            60243.00,
2843            60138.92857143,
2844        ];
2845
2846        let start = result.k1.len().saturating_sub(5);
2847
2848        for (i, &val) in result.k1[start..].iter().enumerate() {
2849            let diff = (val - expected_k1[i]).abs();
2850            assert!(
2851                diff < 1e-6,
2852                "[{}] AlphaTrend K1 {:?} mismatch at idx {}: got {}, expected {} (diff: {})",
2853                test_name,
2854                kernel,
2855                i,
2856                val,
2857                expected_k1[i],
2858                diff
2859            );
2860        }
2861
2862        for (i, &val) in result.k2[start..].iter().enumerate() {
2863            let diff = (val - expected_k2[i]).abs();
2864            assert!(
2865                diff < 1e-6,
2866                "[{}] AlphaTrend K2 {:?} mismatch at idx {}: got {}, expected {} (diff: {})",
2867                test_name,
2868                kernel,
2869                i,
2870                val,
2871                expected_k2[i],
2872                diff
2873            );
2874        }
2875
2876        Ok(())
2877    }
2878
2879    fn check_alphatrend_partial_params(
2880        test_name: &str,
2881        kernel: Kernel,
2882    ) -> Result<(), Box<dyn Error>> {
2883        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2884        let candles = read_candles_from_csv(file_path)?;
2885
2886        let default_params = AlphaTrendParams {
2887            coeff: None,
2888            period: None,
2889            no_volume: None,
2890        };
2891        let input = AlphaTrendInput::from_candles(&candles, default_params);
2892        let output = alphatrend_with_kernel(&input, kernel)?;
2893        assert_eq!(output.k1.len(), candles.close.len());
2894        assert_eq!(output.k2.len(), candles.close.len());
2895
2896        Ok(())
2897    }
2898
2899    fn check_alphatrend_default_candles(
2900        test_name: &str,
2901        kernel: Kernel,
2902    ) -> Result<(), Box<dyn Error>> {
2903        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2904        let candles = read_candles_from_csv(file_path)?;
2905
2906        let input = AlphaTrendInput::with_default_candles(&candles);
2907        let output = alphatrend_with_kernel(&input, kernel)?;
2908        assert_eq!(output.k1.len(), candles.close.len());
2909        assert_eq!(output.k2.len(), candles.close.len());
2910
2911        Ok(())
2912    }
2913
2914    fn check_alphatrend_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2915        let open = vec![10.0, 20.0, 30.0];
2916        let high = vec![12.0, 22.0, 32.0];
2917        let low = vec![8.0, 18.0, 28.0];
2918        let close = vec![11.0, 21.0, 31.0];
2919        let volume = vec![100.0, 200.0, 300.0];
2920
2921        let params = AlphaTrendParams {
2922            coeff: Some(1.0),
2923            period: Some(0),
2924            no_volume: Some(false),
2925        };
2926        let input = AlphaTrendInput::from_slices(&open, &high, &low, &close, &volume, params);
2927        let res = alphatrend_with_kernel(&input, kernel);
2928        assert!(
2929            res.is_err(),
2930            "[{}] AlphaTrend should fail with zero period",
2931            test_name
2932        );
2933        Ok(())
2934    }
2935
2936    fn check_alphatrend_empty_input(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2937        let empty: [f64; 0] = [];
2938        let params = AlphaTrendParams::default();
2939        let input = AlphaTrendInput::from_slices(&empty, &empty, &empty, &empty, &empty, params);
2940        let res = alphatrend_with_kernel(&input, kernel);
2941        assert!(
2942            res.is_err(),
2943            "[{}] AlphaTrend should fail with empty input",
2944            test_name
2945        );
2946        Ok(())
2947    }
2948
2949    fn check_alphatrend_all_nan(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2950        let nan_data = [f64::NAN, f64::NAN, f64::NAN];
2951        let params = AlphaTrendParams::default();
2952        let input = AlphaTrendInput::from_slices(
2953            &nan_data, &nan_data, &nan_data, &nan_data, &nan_data, params,
2954        );
2955        let res = alphatrend_with_kernel(&input, kernel);
2956        assert!(
2957            res.is_err(),
2958            "[{}] AlphaTrend should fail with all NaN values",
2959            test_name
2960        );
2961        Ok(())
2962    }
2963
2964    fn check_alphatrend_period_exceeds_length(
2965        test_name: &str,
2966        kernel: Kernel,
2967    ) -> Result<(), Box<dyn Error>> {
2968        let data_small = [10.0, 20.0, 30.0];
2969        let params = AlphaTrendParams {
2970            coeff: Some(1.0),
2971            period: Some(10),
2972            no_volume: Some(false),
2973        };
2974        let input = AlphaTrendInput::from_slices(
2975            &data_small,
2976            &data_small,
2977            &data_small,
2978            &data_small,
2979            &data_small,
2980            params,
2981        );
2982        let res = alphatrend_with_kernel(&input, kernel);
2983        assert!(
2984            res.is_err(),
2985            "[{}] AlphaTrend should fail with period exceeding length",
2986            test_name
2987        );
2988        Ok(())
2989    }
2990
2991    fn check_alphatrend_very_small_dataset(
2992        test_name: &str,
2993        kernel: Kernel,
2994    ) -> Result<(), Box<dyn Error>> {
2995        let single_point = [42.0];
2996        let params = AlphaTrendParams {
2997            coeff: Some(1.0),
2998            period: Some(14),
2999            no_volume: Some(false),
3000        };
3001        let input = AlphaTrendInput::from_slices(
3002            &single_point,
3003            &single_point,
3004            &single_point,
3005            &single_point,
3006            &single_point,
3007            params,
3008        );
3009        let res = alphatrend_with_kernel(&input, kernel);
3010        assert!(
3011            res.is_err(),
3012            "[{}] AlphaTrend should fail with insufficient data",
3013            test_name
3014        );
3015        Ok(())
3016    }
3017
3018    fn check_alphatrend_invalid_coeff(
3019        test_name: &str,
3020        kernel: Kernel,
3021    ) -> Result<(), Box<dyn Error>> {
3022        let data = vec![1.0; 20];
3023        let params = AlphaTrendParams {
3024            coeff: Some(-1.0),
3025            period: Some(14),
3026            no_volume: Some(false),
3027        };
3028        let input = AlphaTrendInput::from_slices(&data, &data, &data, &data, &data, params);
3029        let res = alphatrend_with_kernel(&input, kernel);
3030        assert!(
3031            matches!(res, Err(AlphaTrendError::InvalidCoeff { .. })),
3032            "[{}] AlphaTrend should fail with invalid coefficient",
3033            test_name
3034        );
3035        Ok(())
3036    }
3037
3038    fn check_alphatrend_inconsistent_lengths(
3039        test_name: &str,
3040        kernel: Kernel,
3041    ) -> Result<(), Box<dyn Error>> {
3042        let open = vec![10.0, 20.0, 30.0];
3043        let high = vec![12.0, 22.0];
3044        let low = vec![8.0, 18.0, 28.0];
3045        let close = vec![11.0, 21.0, 31.0];
3046        let volume = vec![100.0, 200.0, 300.0];
3047
3048        let params = AlphaTrendParams::default();
3049        let input = AlphaTrendInput::from_slices(&open, &high, &low, &close, &volume, params);
3050        let res = alphatrend_with_kernel(&input, kernel);
3051        assert!(
3052            matches!(res, Err(AlphaTrendError::InconsistentDataLengths)),
3053            "[{}] AlphaTrend should fail with inconsistent data lengths",
3054            test_name
3055        );
3056        Ok(())
3057    }
3058
3059    fn check_alphatrend_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
3060        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3061        let candles = read_candles_from_csv(file_path)?;
3062
3063        let first_params = AlphaTrendParams {
3064            coeff: Some(1.0),
3065            period: Some(14),
3066            no_volume: Some(false),
3067        };
3068        let first_input = AlphaTrendInput::from_candles(&candles, first_params);
3069        let first_result = alphatrend_with_kernel(&first_input, kernel)?;
3070
3071        let second_params = AlphaTrendParams {
3072            coeff: Some(1.0),
3073            period: Some(14),
3074            no_volume: Some(true),
3075        };
3076
3077        let k1 = &first_result.k1;
3078        let synthetic_high: Vec<f64> = k1
3079            .iter()
3080            .map(|&v| if v.is_nan() { v } else { v + 10.0 })
3081            .collect();
3082        let synthetic_low: Vec<f64> = k1
3083            .iter()
3084            .map(|&v| if v.is_nan() { v } else { v - 10.0 })
3085            .collect();
3086        let synthetic_volume = vec![1000.0; k1.len()];
3087
3088        let second_input = AlphaTrendInput::from_slices(
3089            k1,
3090            &synthetic_high,
3091            &synthetic_low,
3092            k1,
3093            &synthetic_volume,
3094            second_params,
3095        );
3096        let second_result = alphatrend_with_kernel(&second_input, kernel)?;
3097
3098        assert_eq!(second_result.k1.len(), first_result.k1.len());
3099        assert_eq!(second_result.k2.len(), first_result.k2.len());
3100
3101        Ok(())
3102    }
3103
3104    fn check_alphatrend_nan_handling(
3105        test_name: &str,
3106        kernel: Kernel,
3107    ) -> Result<(), Box<dyn Error>> {
3108        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3109        let candles = read_candles_from_csv(file_path)?;
3110
3111        let input = AlphaTrendInput::from_candles(
3112            &candles,
3113            AlphaTrendParams {
3114                coeff: Some(1.0),
3115                period: Some(14),
3116                no_volume: Some(false),
3117            },
3118        );
3119        let res = alphatrend_with_kernel(&input, kernel)?;
3120        assert_eq!(res.k1.len(), candles.close.len());
3121        assert_eq!(res.k2.len(), candles.close.len());
3122
3123        if res.k1.len() > 240 {
3124            for (i, &val) in res.k1[240..].iter().enumerate() {
3125                assert!(
3126                    !val.is_nan(),
3127                    "[{}] Found unexpected NaN in K1 at out-index {}",
3128                    test_name,
3129                    240 + i
3130                );
3131            }
3132        }
3133        Ok(())
3134    }
3135
3136    fn check_alphatrend_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
3137        let params = AlphaTrendParams {
3138            coeff: Some(1.0),
3139            period: Some(14),
3140            no_volume: Some(false),
3141        };
3142
3143        let mut stream = AlphaTrendStream::try_new(params)?;
3144        let warmup = stream.get_warmup_period();
3145
3146        for i in 0..30 {
3147            let high = 100.0 + i as f64 + 2.0;
3148            let low = 100.0 + i as f64 - 2.0;
3149            let close = 100.0 + i as f64;
3150            let volume = 1000.0 + i as f64 * 10.0;
3151
3152            let result = stream.update(high, low, close, volume);
3153            if i + 1 >= warmup + 3 {
3154                let some = result.expect("streaming should emit after warmup+2");
3155                assert!(
3156                    some.0.is_finite() && some.1.is_finite(),
3157                    "[{}] Non-finite streaming outputs at i={}",
3158                    test_name,
3159                    i
3160                );
3161            } else {
3162                assert!(
3163                    result.is_none(),
3164                    "[{}] Should not emit before warmup+2 at i={}",
3165                    test_name,
3166                    i
3167                );
3168            }
3169        }
3170        Ok(())
3171    }
3172
3173    #[cfg(debug_assertions)]
3174    fn check_alphatrend_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
3175        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3176        let candles = read_candles_from_csv(file_path)?;
3177
3178        let test_params = vec![
3179            AlphaTrendParams::default(),
3180            AlphaTrendParams {
3181                coeff: Some(0.5),
3182                period: Some(7),
3183                no_volume: Some(false),
3184            },
3185            AlphaTrendParams {
3186                coeff: Some(2.0),
3187                period: Some(21),
3188                no_volume: Some(true),
3189            },
3190            AlphaTrendParams {
3191                coeff: Some(1.5),
3192                period: Some(10),
3193                no_volume: Some(false),
3194            },
3195        ];
3196
3197        for (param_idx, params) in test_params.iter().enumerate() {
3198            let input = AlphaTrendInput::from_candles(&candles, params.clone());
3199            let output = alphatrend_with_kernel(&input, kernel)?;
3200
3201            for (i, &val) in output.k1.iter().chain(output.k2.iter()).enumerate() {
3202                if val.is_nan() {
3203                    continue;
3204                }
3205
3206                let bits = val.to_bits();
3207
3208                if bits == 0x11111111_11111111 {
3209                    panic!(
3210                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
3211                        with params: coeff={}, period={}, no_volume={}",
3212                        test_name,
3213                        val,
3214                        bits,
3215                        i,
3216                        params.coeff.unwrap_or(1.0),
3217                        params.period.unwrap_or(14),
3218                        params.no_volume.unwrap_or(false)
3219                    );
3220                }
3221
3222                if bits == 0x22222222_22222222 {
3223                    panic!(
3224                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
3225                        with params: coeff={}, period={}, no_volume={}",
3226                        test_name,
3227                        val,
3228                        bits,
3229                        i,
3230                        params.coeff.unwrap_or(1.0),
3231                        params.period.unwrap_or(14),
3232                        params.no_volume.unwrap_or(false)
3233                    );
3234                }
3235
3236                if bits == 0x33333333_33333333 {
3237                    panic!(
3238                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
3239                        with params: coeff={}, period={}, no_volume={}",
3240                        test_name,
3241                        val,
3242                        bits,
3243                        i,
3244                        params.coeff.unwrap_or(1.0),
3245                        params.period.unwrap_or(14),
3246                        params.no_volume.unwrap_or(false)
3247                    );
3248                }
3249            }
3250        }
3251
3252        Ok(())
3253    }
3254
3255    #[cfg(not(debug_assertions))]
3256    fn check_alphatrend_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
3257        Ok(())
3258    }
3259
3260    #[cfg(feature = "proptest")]
3261    #[allow(clippy::float_cmp)]
3262    fn check_alphatrend_property(
3263        test_name: &str,
3264        kernel: Kernel,
3265    ) -> Result<(), Box<dyn std::error::Error>> {
3266        use proptest::prelude::*;
3267
3268        let strat = (1usize..=50).prop_flat_map(|period| {
3269            (
3270                prop::collection::vec(
3271                    (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
3272                    period..400,
3273                ),
3274                Just(period),
3275                0.1f64..5.0f64,
3276                any::<bool>(),
3277            )
3278        });
3279
3280        proptest::test_runner::TestRunner::default()
3281            .run(&strat, |(close_data, period, coeff, no_volume)| {
3282                let high: Vec<f64> = close_data.iter().map(|&c| c + 5.0).collect();
3283                let low: Vec<f64> = close_data.iter().map(|&c| c - 5.0).collect();
3284                let open = close_data.clone();
3285                let volume = vec![1000.0; close_data.len()];
3286
3287                let params = AlphaTrendParams {
3288                    coeff: Some(coeff),
3289                    period: Some(period),
3290                    no_volume: Some(no_volume),
3291                };
3292                let input =
3293                    AlphaTrendInput::from_slices(&open, &high, &low, &close_data, &volume, params);
3294
3295                let result = alphatrend_with_kernel(&input, kernel).unwrap();
3296                let ref_result = alphatrend_with_kernel(&input, Kernel::Scalar).unwrap();
3297
3298                for i in 0..close_data.len() {
3299                    let y = result.k1[i];
3300                    let r = ref_result.k1[i];
3301
3302                    if !y.is_finite() || !r.is_finite() {
3303                        prop_assert!(
3304                            y.to_bits() == r.to_bits(),
3305                            "K1 finite/NaN mismatch idx {i}: {y} vs {r}"
3306                        );
3307                        continue;
3308                    }
3309
3310                    let ulp_diff: u64 = y.to_bits().abs_diff(r.to_bits());
3311                    prop_assert!(
3312                        (y - r).abs() <= 1e-9 || ulp_diff <= 4,
3313                        "K1 mismatch idx {i}: {y} vs {r} (ULP={ulp_diff})"
3314                    );
3315                }
3316
3317                for i in 0..close_data.len() {
3318                    let y = result.k2[i];
3319                    let r = ref_result.k2[i];
3320
3321                    if !y.is_finite() || !r.is_finite() {
3322                        prop_assert!(
3323                            y.to_bits() == r.to_bits(),
3324                            "K2 finite/NaN mismatch idx {i}: {y} vs {r}"
3325                        );
3326                        continue;
3327                    }
3328
3329                    let ulp_diff: u64 = y.to_bits().abs_diff(r.to_bits());
3330                    prop_assert!(
3331                        (y - r).abs() <= 1e-9 || ulp_diff <= 4,
3332                        "K2 mismatch idx {i}: {y} vs {r} (ULP={ulp_diff})"
3333                    );
3334                }
3335
3336                Ok(())
3337            })
3338            .unwrap();
3339
3340        Ok(())
3341    }
3342
3343    macro_rules! generate_all_alphatrend_tests {
3344        ($($test_fn:ident),*) => {
3345            paste::paste! {
3346                $(
3347                    #[test]
3348                    fn [<$test_fn _scalar_f64>]() {
3349                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
3350                    }
3351                )*
3352                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3353                $(
3354                    #[test]
3355                    fn [<$test_fn _avx2_f64>]() {
3356                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
3357                    }
3358                    #[test]
3359                    fn [<$test_fn _avx512_f64>]() {
3360                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
3361                    }
3362                )*
3363                #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
3364                $(
3365                    #[test]
3366                    fn [<$test_fn _simd128_f64>]() {
3367                        let _ = $test_fn(stringify!([<$test_fn _simd128_f64>]), Kernel::Scalar);
3368                    }
3369                )*
3370            }
3371        }
3372    }
3373
    // Instantiate the per-kernel #[test] wrappers for every single-series check
    // function defined above.
    generate_all_alphatrend_tests!(
        check_alphatrend_accuracy,
        check_alphatrend_partial_params,
        check_alphatrend_default_candles,
        check_alphatrend_zero_period,
        check_alphatrend_empty_input,
        check_alphatrend_all_nan,
        check_alphatrend_period_exceeds_length,
        check_alphatrend_very_small_dataset,
        check_alphatrend_invalid_coeff,
        check_alphatrend_inconsistent_lengths,
        check_alphatrend_reinput,
        check_alphatrend_nan_handling,
        check_alphatrend_streaming,
        check_alphatrend_no_poison
    );

    // The property-based kernel-vs-scalar comparison exists only when the
    // `proptest` feature is enabled, so its wrappers are generated separately.
    #[cfg(feature = "proptest")]
    generate_all_alphatrend_tests!(check_alphatrend_property);
3393
3394    #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
3395    #[test]
3396    fn test_alphatrend_into_matches_api() -> Result<(), Box<dyn Error>> {
3397        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3398        let candles = read_candles_from_csv(file_path)?;
3399
3400        let input = AlphaTrendInput::from_candles(&candles, AlphaTrendParams::default());
3401
3402        let baseline = alphatrend(&input)?;
3403
3404        let mut out_k1 = vec![0.0; candles.close.len()];
3405        let mut out_k2 = vec![0.0; candles.close.len()];
3406        #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
3407        {
3408            alphatrend_into(&input, &mut out_k1, &mut out_k2)?;
3409        }
3410
3411        fn eq_or_both_nan(a: f64, b: f64) -> bool {
3412            if a.is_nan() && b.is_nan() {
3413                true
3414            } else {
3415                a == b || (a - b).abs() <= 1e-12
3416            }
3417        }
3418
3419        assert_eq!(baseline.k1.len(), out_k1.len());
3420        assert_eq!(baseline.k2.len(), out_k2.len());
3421
3422        for i in 0..out_k1.len() {
3423            assert!(
3424                eq_or_both_nan(baseline.k1[i], out_k1[i]),
3425                "k1 mismatch at idx {i}: api={} into={}",
3426                baseline.k1[i],
3427                out_k1[i]
3428            );
3429            assert!(
3430                eq_or_both_nan(baseline.k2[i], out_k2[i]),
3431                "k2 mismatch at idx {i}: api={} into={}",
3432                baseline.k2[i],
3433                out_k2[i]
3434            );
3435        }
3436
3437        Ok(())
3438    }
3439
3440    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
3441        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3442        let c = read_candles_from_csv(file)?;
3443
3444        let sweep = AlphaTrendBatchRange::default();
3445        let output = alphatrend_batch_with_kernel(&c, &sweep, kernel)?;
3446
3447        let def = AlphaTrendParams::default();
3448        let row = output.row_for_params(&def).expect("default row missing");
3449
3450        assert_eq!(output.cols, c.close.len());
3451
3452        let k1_start = row * output.cols;
3453        let k2_start = row * output.cols;
3454
3455        let expected_k1 = [
3456            60243.00,
3457            60243.00,
3458            60138.92857143,
3459            60088.42857143,
3460            59937.21428571,
3461        ];
3462
3463        let start = output.cols - 5;
3464        for (i, &expected) in expected_k1.iter().enumerate() {
3465            let actual = output.values_k1[k1_start + start + i];
3466            assert!(
3467                (actual - expected).abs() < 1e-6,
3468                "[{test}] default-row K1 mismatch at idx {i}: {actual} vs {expected}"
3469            );
3470        }
3471        Ok(())
3472    }
3473
3474    fn check_batch_sweep(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
3475        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3476        let c = read_candles_from_csv(file)?;
3477
3478        let sweep = AlphaTrendBatchRange {
3479            coeff: (1.0, 2.0, 0.5),
3480            period: (10, 20, 5),
3481            no_volume: false,
3482        };
3483
3484        let output = alphatrend_batch_with_kernel(&c, &sweep, kernel)?;
3485
3486        let coeff_count = ((2.0 - 1.0) / 0.5) as usize + 1;
3487        let period_count = ((20 - 10) / 5) as usize + 1;
3488        let expected_combos = coeff_count * period_count;
3489
3490        assert_eq!(output.combos.len(), expected_combos);
3491        assert_eq!(output.rows, expected_combos);
3492        assert_eq!(output.cols, c.close.len());
3493
3494        Ok(())
3495    }
3496
3497    #[cfg(debug_assertions)]
3498    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
3499        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3500        let c = read_candles_from_csv(file)?;
3501
3502        let test_configs = vec![
3503            (1.0, 1.0, 0.0, 10, 15, 5, false),
3504            (0.5, 2.0, 0.5, 14, 14, 0, true),
3505            (1.5, 1.5, 0.0, 7, 21, 7, false),
3506        ];
3507
3508        for (cfg_idx, &(c_start, c_end, c_step, p_start, p_end, p_step, no_vol)) in
3509            test_configs.iter().enumerate()
3510        {
3511            let sweep = AlphaTrendBatchRange {
3512                coeff: (c_start, c_end, c_step),
3513                period: (p_start, p_end, p_step),
3514                no_volume: no_vol,
3515            };
3516
3517            let output = alphatrend_batch_with_kernel(&c, &sweep, kernel)?;
3518
3519            for (idx, &val) in output
3520                .values_k1
3521                .iter()
3522                .chain(output.values_k2.iter())
3523                .enumerate()
3524            {
3525                if val.is_nan() {
3526                    continue;
3527                }
3528
3529                let bits = val.to_bits();
3530                let row = idx / output.cols;
3531                let col = idx % output.cols;
3532
3533                if bits == 0x11111111_11111111
3534                    || bits == 0x22222222_22222222
3535                    || bits == 0x33333333_33333333
3536                {
3537                    let combo = if row < output.combos.len() {
3538                        &output.combos[row]
3539                    } else {
3540                        &output.combos[row - output.combos.len()]
3541                    };
3542
3543                    panic!(
3544                        "[{}] Config {}: Found poison value {} (0x{:016X}) \
3545                        at row {} col {} (flat index {}) with params: coeff={}, period={}, no_volume={}",
3546                        test,
3547                        cfg_idx,
3548                        val,
3549                        bits,
3550                        row,
3551                        col,
3552                        idx,
3553                        combo.coeff.unwrap_or(1.0),
3554                        combo.period.unwrap_or(14),
3555                        combo.no_volume.unwrap_or(false)
3556                    );
3557                }
3558            }
3559        }
3560
3561        Ok(())
3562    }
3563
3564    #[cfg(not(debug_assertions))]
3565    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
3566        Ok(())
3567    }
3568
3569    macro_rules! gen_batch_tests {
3570        ($fn_name:ident) => {
3571            paste::paste! {
3572                #[test] fn [<$fn_name _scalar>]() {
3573                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
3574                }
3575                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3576                #[test] fn [<$fn_name _avx2>]() {
3577                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
3578                }
3579                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3580                #[test] fn [<$fn_name _avx512>]() {
3581                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
3582                }
3583                #[test] fn [<$fn_name _auto_detect>]() {
3584                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
3585                }
3586            }
3587        };
3588    }
3589
    // Generate the per-kernel batch tests (scalar batch, AVX2/AVX-512 batch when
    // `nightly-avx` is enabled on x86_64, and auto-detect) for each batch check.
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_sweep);
    gen_batch_tests!(check_batch_no_poison);
3593}