Skip to main content

vector_ta/indicators/
wavetrend.rs

1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::cuda_available;
3#[cfg(all(feature = "python", feature = "cuda"))]
4use crate::cuda::moving_averages::DeviceArrayF32;
5#[cfg(all(feature = "python", feature = "cuda"))]
6use crate::cuda::wavetrend::CudaWavetrend;
7#[cfg(all(feature = "python", feature = "cuda"))]
8use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
9#[cfg(feature = "python")]
10use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
11#[cfg(feature = "python")]
12use pyo3::exceptions::PyValueError;
13#[cfg(feature = "python")]
14use pyo3::prelude::*;
15#[cfg(feature = "python")]
16use pyo3::types::PyDict;
17
18#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
19use serde::{Deserialize, Serialize};
20#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
21use wasm_bindgen::prelude::*;
22
23use crate::indicators::moving_averages::ema::{ema, EmaError, EmaInput, EmaParams};
24use crate::indicators::moving_averages::sma::{sma, SmaError, SmaInput, SmaParams};
25use crate::utilities::data_loader::{source_type, Candles};
26use crate::utilities::enums::Kernel;
27use crate::utilities::helpers::{
28    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
29    make_uninit_matrix,
30};
31#[cfg(feature = "python")]
32use crate::utilities::kernel_validation::validate_kernel;
33use aligned_vec::{AVec, CACHELINE_ALIGN};
34#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
35use core::arch::x86_64::*;
36#[cfg(not(target_arch = "wasm32"))]
37use rayon::prelude::*;
38use std::convert::AsRef;
39use thiserror::Error;
40
41impl<'a> AsRef<[f64]> for WavetrendInput<'a> {
42    #[inline(always)]
43    fn as_ref(&self) -> &[f64] {
44        match &self.data {
45            WavetrendData::Slice(slice) => slice,
46            WavetrendData::Candles { candles, source } => source_type(candles, source),
47        }
48    }
49}
50
/// Where a WaveTrend calculation reads its input series from.
#[derive(Debug, Clone)]
pub enum WavetrendData<'a> {
    /// A candle set together with the source name that is passed to
    /// `source_type` to pick the series.
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A caller-supplied slice that is used directly.
    Slice(&'a [f64]),
}
59
/// The three series produced by a WaveTrend calculation.
#[derive(Debug, Clone)]
pub struct WavetrendOutput {
    /// Exponentially smoothed channel index (smoothed with `average_length`).
    pub wt1: Vec<f64>,
    /// Simple moving average of `wt1` over `ma_length` points.
    pub wt2: Vec<f64>,
    /// `wt2 - wt1`; NaN wherever either operand is unavailable.
    pub wt_diff: Vec<f64>,
}
66
/// Tunable WaveTrend parameters.
///
/// Every field is optional; the `WavetrendInput::get_*` accessors substitute
/// 9, 12, 3 and 0.015 respectively when a field is `None`.
#[derive(Debug, Clone)]
pub struct WavetrendParams {
    /// Period of the exponential smoothing applied to the series and to its
    /// absolute deviation.
    pub channel_length: Option<usize>,
    /// Period of the exponential smoothing applied to the channel index.
    pub average_length: Option<usize>,
    /// Window of the simple moving average that yields `wt2`.
    pub ma_length: Option<usize>,
    /// Multiplier applied to the smoothed deviation in the channel-index
    /// denominator.
    pub factor: Option<f64>,
}
74
75impl Default for WavetrendParams {
76    fn default() -> Self {
77        Self {
78            channel_length: Some(9),
79            average_length: Some(12),
80            ma_length: Some(3),
81            factor: Some(0.015),
82        }
83    }
84}
85
/// A WaveTrend request: the data to read plus the parameters to apply.
#[derive(Debug, Clone)]
pub struct WavetrendInput<'a> {
    /// Source of the input series (candles + source name, or a raw slice).
    pub data: WavetrendData<'a>,
    /// Parameters; unset fields fall back to defaults in the `get_*` accessors.
    pub params: WavetrendParams,
}
91
92impl<'a> WavetrendInput<'a> {
93    #[inline]
94    pub fn from_candles(c: &'a Candles, s: &'a str, p: WavetrendParams) -> Self {
95        Self {
96            data: WavetrendData::Candles {
97                candles: c,
98                source: s,
99            },
100            params: p,
101        }
102    }
103    #[inline]
104    pub fn from_slice(sl: &'a [f64], p: WavetrendParams) -> Self {
105        Self {
106            data: WavetrendData::Slice(sl),
107            params: p,
108        }
109    }
110    #[inline]
111    pub fn with_default_candles(c: &'a Candles) -> Self {
112        Self::from_candles(c, "hlc3", WavetrendParams::default())
113    }
114    #[inline]
115    pub fn get_channel_length(&self) -> usize {
116        self.params.channel_length.unwrap_or(9)
117    }
118    #[inline]
119    pub fn get_average_length(&self) -> usize {
120        self.params.average_length.unwrap_or(12)
121    }
122    #[inline]
123    pub fn get_ma_length(&self) -> usize {
124        self.params.ma_length.unwrap_or(3)
125    }
126    #[inline]
127    pub fn get_factor(&self) -> f64 {
128        self.params.factor.unwrap_or(0.015)
129    }
130}
131
/// Fluent builder that collects WaveTrend parameters and a kernel choice
/// before running the indicator or creating a stream.
#[derive(Copy, Clone, Debug)]
pub struct WavetrendBuilder {
    /// Optional override for `WavetrendParams::channel_length`.
    channel_length: Option<usize>,
    /// Optional override for `WavetrendParams::average_length`.
    average_length: Option<usize>,
    /// Optional override for `WavetrendParams::ma_length`.
    ma_length: Option<usize>,
    /// Optional override for `WavetrendParams::factor`.
    factor: Option<f64>,
    /// Kernel forwarded to `wavetrend_with_kernel`.
    kernel: Kernel,
}
140
141impl Default for WavetrendBuilder {
142    fn default() -> Self {
143        Self {
144            channel_length: None,
145            average_length: None,
146            ma_length: None,
147            factor: None,
148            kernel: Kernel::Auto,
149        }
150    }
151}
152
153impl WavetrendBuilder {
154    #[inline(always)]
155    pub fn new() -> Self {
156        Self::default()
157    }
158    #[inline(always)]
159    pub fn channel_length(mut self, n: usize) -> Self {
160        self.channel_length = Some(n);
161        self
162    }
163    #[inline(always)]
164    pub fn average_length(mut self, n: usize) -> Self {
165        self.average_length = Some(n);
166        self
167    }
168    #[inline(always)]
169    pub fn ma_length(mut self, n: usize) -> Self {
170        self.ma_length = Some(n);
171        self
172    }
173    #[inline(always)]
174    pub fn factor(mut self, f: f64) -> Self {
175        self.factor = Some(f);
176        self
177    }
178    #[inline(always)]
179    pub fn kernel(mut self, k: Kernel) -> Self {
180        self.kernel = k;
181        self
182    }
183    #[inline(always)]
184    pub fn apply(self, c: &Candles) -> Result<WavetrendOutput, WavetrendError> {
185        let p = WavetrendParams {
186            channel_length: self.channel_length,
187            average_length: self.average_length,
188            ma_length: self.ma_length,
189            factor: self.factor,
190        };
191        let i = WavetrendInput::from_candles(c, "hlc3", p);
192        wavetrend_with_kernel(&i, self.kernel)
193    }
194    #[inline(always)]
195    pub fn apply_slice(self, d: &[f64]) -> Result<WavetrendOutput, WavetrendError> {
196        let p = WavetrendParams {
197            channel_length: self.channel_length,
198            average_length: self.average_length,
199            ma_length: self.ma_length,
200            factor: self.factor,
201        };
202        let i = WavetrendInput::from_slice(d, p);
203        wavetrend_with_kernel(&i, self.kernel)
204    }
205    #[inline(always)]
206    pub fn into_stream(self) -> Result<WavetrendStream, WavetrendError> {
207        let p = WavetrendParams {
208            channel_length: self.channel_length,
209            average_length: self.average_length,
210            ma_length: self.ma_length,
211            factor: self.factor,
212        };
213        WavetrendStream::try_new(p)
214    }
215}
216
/// Errors reported by the WaveTrend indicator.
#[derive(Debug, Error)]
pub enum WavetrendError {
    /// The input series has length zero.
    #[error("wavetrend: Empty data provided.")]
    EmptyInputData,
    /// Carries the same message as `EmptyInputData`.
    #[error("wavetrend: Empty data provided.")]
    EmptyData,
    /// No element of the input is non-NaN.
    #[error("wavetrend: All values are NaN.")]
    AllValuesNaN,
    /// `channel_length` is zero or larger than the input length.
    #[error("wavetrend: Invalid channel_length = {channel_length}, data length = {data_len}")]
    InvalidChannelLen {
        channel_length: usize,
        data_len: usize,
    },
    /// `average_length` is zero or larger than the input length.
    #[error("wavetrend: Invalid average_length = {average_length}, data length = {data_len}")]
    InvalidAverageLen {
        average_length: usize,
        data_len: usize,
    },
    /// `ma_length` is zero or larger than the input length.
    #[error("wavetrend: Invalid ma_length = {ma_length}, data length = {data_len}")]
    InvalidMaLen { ma_length: usize, data_len: usize },
    /// Fewer elements remain after the first non-NaN value than the largest
    /// of the three lengths.
    #[error("wavetrend: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// An output had a different length than expected.
    #[error("wavetrend: Output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// A caller-provided output slice had a different length than expected.
    #[error("wavetrend: Output slice length mismatch: expected = {expected}, got = {got}")]
    OutputSliceLengthMismatch { expected: usize, got: usize },
    /// A parameter range (start/end/step, carried as text) was rejected.
    #[error("wavetrend: Invalid range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: String,
        end: String,
        step: String,
    },
    /// The given kernel cannot be used for batch execution.
    #[error("wavetrend: Invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(crate::utilities::enums::Kernel),
    /// Wraps an error from the EMA indicator (converted via `From`).
    #[error("wavetrend: EMA error {0}")]
    EmaError(#[from] EmaError),
    /// Wraps an error from the SMA indicator (converted via `From`).
    #[error("wavetrend: SMA error {0}")]
    SmaError(#[from] SmaError),
}
256
257#[inline]
258pub fn wavetrend(input: &WavetrendInput) -> Result<WavetrendOutput, WavetrendError> {
259    wavetrend_with_kernel(input, Kernel::Auto)
260}
261
262pub fn wavetrend_with_kernel(
263    input: &WavetrendInput,
264    kernel: Kernel,
265) -> Result<WavetrendOutput, WavetrendError> {
266    let data: &[f64] = input.as_ref();
267    if data.is_empty() {
268        return Err(WavetrendError::EmptyInputData);
269    }
270    let channel_len = input.get_channel_length();
271    let average_len = input.get_average_length();
272    let ma_len = input.get_ma_length();
273    let factor = input.get_factor();
274
275    let first = data
276        .iter()
277        .position(|x| !x.is_nan())
278        .ok_or(WavetrendError::AllValuesNaN)?;
279    let needed = *[channel_len, average_len, ma_len].iter().max().unwrap();
280    let valid = data.len() - first;
281
282    if channel_len == 0 || channel_len > data.len() {
283        return Err(WavetrendError::InvalidChannelLen {
284            channel_length: channel_len,
285            data_len: data.len(),
286        });
287    }
288    if average_len == 0 || average_len > data.len() {
289        return Err(WavetrendError::InvalidAverageLen {
290            average_length: average_len,
291            data_len: data.len(),
292        });
293    }
294    if ma_len == 0 || ma_len > data.len() {
295        return Err(WavetrendError::InvalidMaLen {
296            ma_length: ma_len,
297            data_len: data.len(),
298        });
299    }
300    if valid < needed {
301        return Err(WavetrendError::NotEnoughValidData { needed, valid });
302    }
303
304    let chosen = match kernel {
305        Kernel::Auto => detect_best_kernel(),
306
307        Kernel::Avx2 | Kernel::Avx512 => Kernel::Scalar,
308        other => other,
309    };
310
311    unsafe {
312        match chosen {
313            Kernel::Scalar | Kernel::ScalarBatch => {
314                wavetrend_scalar(data, channel_len, average_len, ma_len, factor, first)
315            }
316            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
317            Kernel::Avx2 | Kernel::Avx2Batch => {
318                wavetrend_avx2(data, channel_len, average_len, ma_len, factor, first)
319            }
320            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
321            Kernel::Avx512 | Kernel::Avx512Batch => {
322                wavetrend_avx512(data, channel_len, average_len, ma_len, factor, first)
323            }
324            _ => unreachable!(),
325        }
326    }
327}
328
329fn wavetrend_kernel_dispatch(
330    data: &[f64],
331    channel_len: usize,
332    average_len: usize,
333    ma_len: usize,
334    factor: f64,
335    first: usize,
336    kernel: Kernel,
337) -> Result<WavetrendOutput, WavetrendError> {
338    let warmup_period = first + channel_len - 1 + average_len - 1 + ma_len - 1;
339
340    let mut wt1_final = alloc_with_nan_prefix(data.len(), warmup_period);
341    let mut wt2_final = alloc_with_nan_prefix(data.len(), warmup_period);
342    let mut diff_final = alloc_with_nan_prefix(data.len(), warmup_period);
343
344    wavetrend_compute_into(
345        data,
346        channel_len,
347        average_len,
348        ma_len,
349        factor,
350        first,
351        warmup_period,
352        &mut wt1_final,
353        &mut wt2_final,
354        &mut diff_final,
355        kernel,
356    )?;
357
358    Ok(WavetrendOutput {
359        wt1: wt1_final,
360        wt2: wt2_final,
361        wt_diff: diff_final,
362    })
363}
364
365pub fn wavetrend_scalar(
366    data: &[f64],
367    channel_len: usize,
368    average_len: usize,
369    ma_len: usize,
370    factor: f64,
371    first: usize,
372) -> Result<WavetrendOutput, WavetrendError> {
373    wavetrend_kernel_dispatch(
374        data,
375        channel_len,
376        average_len,
377        ma_len,
378        factor,
379        first,
380        Kernel::Scalar,
381    )
382}
383
384use std::collections::VecDeque;
385
/// AVX2 WaveTrend entry point: allocates the outputs and fills them with
/// `wavetrend_fused_avx2_into`.
///
/// # Safety
/// Calls a function compiled with the `avx2` and `fma` target features, so
/// the running CPU must support both.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn wavetrend_avx2(
    data: &[f64],
    channel_len: usize,
    average_len: usize,
    ma_len: usize,
    factor: f64,
    first: usize,
) -> Result<WavetrendOutput, WavetrendError> {
    let len = data.len();
    let warmup_period = first + channel_len - 1 + average_len - 1 + ma_len - 1;

    let mut out = WavetrendOutput {
        wt1: alloc_with_nan_prefix(len, warmup_period),
        wt2: alloc_with_nan_prefix(len, warmup_period),
        wt_diff: alloc_with_nan_prefix(len, warmup_period),
    };

    wavetrend_fused_avx2_into(
        data,
        channel_len,
        average_len,
        ma_len,
        factor,
        first,
        warmup_period,
        &mut out.wt1,
        &mut out.wt2,
        &mut out.wt_diff,
    );

    Ok(out)
}
421
/// Single-pass WaveTrend that writes `wt1`, `wt2` and their difference
/// directly into the destination slices.
///
/// The body uses no explicit SIMD intrinsics; it is the scalar recurrence
/// compiled with the `avx2` and `fma` target features, using `mul_add` for
/// the exponential-smoothing updates.
///
/// Only indices `>= warmup_period` are written; earlier destination elements
/// are left untouched. Destination slices are indexed with the same indices
/// as `data`, so they must be at least as long as `data` or the function
/// panics.
///
/// # Safety
/// The running CPU must support AVX2 and FMA.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
#[target_feature(enable = "fma")]
unsafe fn wavetrend_fused_avx2_into(
    data: &[f64],
    channel_len: usize,
    average_len: usize,
    ma_len: usize,
    factor: f64,
    first: usize,
    warmup_period: usize,
    dst_wt1: &mut [f64],
    dst_wt2: &mut [f64],
    dst_wt_diff: &mut [f64],
) {
    let n = data.len();
    if n == 0 {
        return;
    }

    // Smoothing coefficients: alpha = 2 / (period + 1), beta = 1 - alpha.
    let alpha_ch = 2.0 / (channel_len as f64 + 1.0);
    let beta_ch = 1.0 - alpha_ch;
    let alpha_avg = 2.0 / (average_len as f64 + 1.0);
    let beta_avg = 1.0 - alpha_avg;

    // Running state of the three exponential averages; each is seeded with
    // the first value it receives.
    let mut esa_state: f64 = f64::NAN;
    let mut de_state: f64 = f64::NAN;
    let mut wt1_state: f64 = f64::NAN;
    let mut esa_seeded = false;
    let mut de_seeded = false;
    let mut wt1_seeded = false;

    // Ring buffer of the last `ma_len` wt1 values. `ring_mask` marks which
    // slots hold a finite value, `sma_count` counts them.
    // Note: `ma_len == 0` would panic on the first `ring_mask[head]` access.
    let mut ring_vals = vec![f64::NAN; ma_len];
    let mut ring_mask = vec![0u8; ma_len];
    let mut head = 0usize;
    let mut sma_sum = 0.0f64;
    let mut sma_count = 0usize;
    let inv_ma = 1.0 / (ma_len as f64);

    for idx in first..n {
        let x = data[idx];
        let mut wt1_i = f64::NAN;
        let mut wt2_i = f64::NAN;

        // Non-finite inputs leave all smoothing state unchanged and yield
        // NaN for this index.
        if x.is_finite() {
            if !esa_seeded {
                esa_state = x;
                esa_seeded = true;
            } else {
                esa_state = x.mul_add(alpha_ch, beta_ch * esa_state);
            }

            // Smoothed absolute deviation of x from its own smoothed value.
            let abs_diff = (x - esa_state).abs();
            if !de_seeded {
                de_state = abs_diff;
                de_seeded = true;
            } else {
                de_state = abs_diff.mul_add(alpha_ch, beta_ch * de_state);
            }

            // Channel index; skipped when the denominator is zero or
            // non-finite, in which case wt1 stays NaN for this index.
            let den = factor * de_state;
            if den != 0.0 && den.is_finite() && esa_state.is_finite() {
                let ci = (x - esa_state) / den;
                if ci.is_finite() {
                    if !wt1_seeded {
                        wt1_state = ci;
                        wt1_seeded = true;
                    } else {
                        wt1_state = ci.mul_add(alpha_avg, beta_avg * wt1_state);
                    }
                    wt1_i = wt1_state;
                }
            }
        }

        // Evict the slot about to be overwritten from the running sum.
        if ring_mask[head] != 0 {
            sma_sum -= ring_vals[head];
            sma_count -= 1;
        }
        // Store the new wt1; only finite values contribute to the sum.
        if wt1_i.is_finite() {
            ring_vals[head] = wt1_i;
            ring_mask[head] = 1;
            sma_sum += wt1_i;
            sma_count += 1;
        } else {
            ring_vals[head] = f64::NAN;
            ring_mask[head] = 0;
        }
        head += 1;
        if head == ma_len {
            head = 0;
        }
        // wt2 exists only when every slot of the window holds a finite value.
        if sma_count == ma_len {
            wt2_i = sma_sum * inv_ma;
        }

        if idx >= warmup_period {
            dst_wt1[idx] = wt1_i;
            dst_wt2[idx] = wt2_i;
            dst_wt_diff[idx] = if wt1_i.is_finite() && wt2_i.is_finite() {
                wt2_i - wt1_i
            } else {
                f64::NAN
            };
        }
    }
}
529
/// AVX-512 WaveTrend entry point: routes to the "short" variant when
/// `channel_len` is at most 32 and to the "long" variant otherwise.
///
/// # Safety
/// Inherits the requirements of `wavetrend_avx512_short` and
/// `wavetrend_avx512_long`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn wavetrend_avx512(
    data: &[f64],
    channel_len: usize,
    average_len: usize,
    ma_len: usize,
    factor: f64,
    first: usize,
) -> Result<WavetrendOutput, WavetrendError> {
    match channel_len {
        0..=32 => wavetrend_avx512_short(data, channel_len, average_len, ma_len, factor, first),
        _ => wavetrend_avx512_long(data, channel_len, average_len, ma_len, factor, first),
    }
}
546
/// AVX-512 variant used for short channel lengths; forwards to
/// `wavetrend_kernel_dispatch` with `Kernel::Avx512`.
///
/// # Safety
/// NOTE(review): the body itself performs no unsafe operation; the `unsafe`
/// marker presumably reflects the CPU-feature requirement of the AVX-512
/// helpers reached through the dispatch — confirm.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn wavetrend_avx512_short(
    data: &[f64],
    channel_len: usize,
    average_len: usize,
    ma_len: usize,
    factor: f64,
    first: usize,
) -> Result<WavetrendOutput, WavetrendError> {
    let kernel = Kernel::Avx512;
    wavetrend_kernel_dispatch(data, channel_len, average_len, ma_len, factor, first, kernel)
}
567
/// AVX-512 variant used for long channel lengths; forwards to
/// `wavetrend_kernel_dispatch` with `Kernel::Avx512`.
///
/// # Safety
/// NOTE(review): the body itself performs no unsafe operation; the `unsafe`
/// marker presumably reflects the CPU-feature requirement of the AVX-512
/// helpers reached through the dispatch — confirm.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn wavetrend_avx512_long(
    data: &[f64],
    channel_len: usize,
    average_len: usize,
    ma_len: usize,
    factor: f64,
    first: usize,
) -> Result<WavetrendOutput, WavetrendError> {
    let kernel = Kernel::Avx512;
    wavetrend_kernel_dispatch(data, channel_len, average_len, ma_len, factor, first, kernel)
}
588
589#[inline(always)]
590fn wavetrend_prepare<'a>(
591    input: &'a WavetrendInput,
592) -> Result<(&'a [f64], usize, usize, usize, f64, usize, usize), WavetrendError> {
593    let data: &[f64] = input.as_ref();
594    if data.is_empty() {
595        return Err(WavetrendError::EmptyInputData);
596    }
597
598    let first = data
599        .iter()
600        .position(|x| !x.is_nan())
601        .ok_or(WavetrendError::AllValuesNaN)?;
602    let channel_len = input.get_channel_length();
603    let average_len = input.get_average_length();
604    let ma_len = input.get_ma_length();
605    let factor = input.get_factor();
606
607    if channel_len == 0 || channel_len > data.len() {
608        return Err(WavetrendError::InvalidChannelLen {
609            channel_length: channel_len,
610            data_len: data.len(),
611        });
612    }
613    if average_len == 0 || average_len > data.len() {
614        return Err(WavetrendError::InvalidAverageLen {
615            average_length: average_len,
616            data_len: data.len(),
617        });
618    }
619    if ma_len == 0 || ma_len > data.len() {
620        return Err(WavetrendError::InvalidMaLen {
621            ma_length: ma_len,
622            data_len: data.len(),
623        });
624    }
625
626    let max_period = channel_len.max(average_len).max(ma_len);
627    if data.len() - first < max_period {
628        return Err(WavetrendError::NotEnoughValidData {
629            needed: max_period,
630            valid: data.len() - first,
631        });
632    }
633
634    let warmup_period = first + channel_len - 1 + average_len - 1 + ma_len - 1;
635
636    Ok((
637        data,
638        channel_len,
639        average_len,
640        ma_len,
641        factor,
642        first,
643        warmup_period,
644    ))
645}
646
647#[inline(always)]
648fn wavetrend_compute_into(
649    data: &[f64],
650    channel_len: usize,
651    average_len: usize,
652    ma_len: usize,
653    factor: f64,
654    first: usize,
655    warmup_period: usize,
656    dst_wt1: &mut [f64],
657    dst_wt2: &mut [f64],
658    dst_wt_diff: &mut [f64],
659    kernel: Kernel,
660) -> Result<(), WavetrendError> {
661    if matches!(kernel.to_non_batch(), Kernel::Scalar) {
662        let n = data.len();
663        if n == 0 {
664            return Ok(());
665        }
666
667        let alpha_ch = 2.0 / (channel_len as f64 + 1.0);
668        let beta_ch = 1.0 - alpha_ch;
669        let alpha_avg = 2.0 / (average_len as f64 + 1.0);
670        let beta_avg = 1.0 - alpha_avg;
671
672        let mut esa_state: f64 = f64::NAN;
673        let mut de_state: f64 = f64::NAN;
674        let mut wt1_state: f64 = f64::NAN;
675        let mut esa_seeded = false;
676        let mut de_seeded = false;
677        let mut wt1_seeded = false;
678
679        let mut ring_vals = vec![f64::NAN; ma_len];
680        let mut ring_mask = vec![0u8; ma_len];
681        let mut head = 0usize;
682        let mut sma_sum = 0.0f64;
683        let mut sma_count = 0usize;
684
685        for idx in first..n {
686            let x = data[idx];
687
688            let mut wt1_i = f64::NAN;
689            let mut wt2_i = f64::NAN;
690
691            if x.is_finite() {
692                if !esa_seeded {
693                    esa_state = x;
694                    esa_seeded = true;
695                } else {
696                    esa_state = alpha_ch * x + beta_ch * esa_state;
697                }
698
699                let abs_diff = (x - esa_state).abs();
700                if !de_seeded {
701                    de_state = abs_diff;
702                    de_seeded = true;
703                } else {
704                    de_state = alpha_ch * abs_diff + beta_ch * de_state;
705                }
706
707                let den = factor * de_state;
708                if den != 0.0 && den.is_finite() && esa_state.is_finite() {
709                    let ci = (x - esa_state) / den;
710                    if ci.is_finite() {
711                        if !wt1_seeded {
712                            wt1_state = ci;
713                            wt1_seeded = true;
714                        } else {
715                            wt1_state = alpha_avg * ci + beta_avg * wt1_state;
716                        }
717                        wt1_i = wt1_state;
718                    }
719                }
720            }
721
722            if ma_len > 0 {
723                if ring_mask[head] != 0 {
724                    sma_sum -= ring_vals[head];
725                    sma_count -= 1;
726                }
727
728                if wt1_i.is_finite() {
729                    ring_vals[head] = wt1_i;
730                    ring_mask[head] = 1;
731                    sma_sum += wt1_i;
732                    sma_count += 1;
733                } else {
734                    ring_vals[head] = f64::NAN;
735                    ring_mask[head] = 0;
736                }
737                head += 1;
738                if head == ma_len {
739                    head = 0;
740                }
741
742                if sma_count == ma_len {
743                    wt2_i = sma_sum / (ma_len as f64);
744                }
745            }
746
747            if idx >= warmup_period {
748                dst_wt1[idx] = wt1_i;
749                dst_wt2[idx] = wt2_i;
750                dst_wt_diff[idx] = if wt1_i.is_finite() && wt2_i.is_finite() {
751                    wt2_i - wt1_i
752                } else {
753                    f64::NAN
754                };
755            }
756        }
757
758        return Ok(());
759    }
760
761    let data_valid = &data[first..];
762    let simd_kernel = kernel.to_non_batch();
763
764    if data_valid.len() <= STACK_LIMIT {
765        let mut esa_buf = [0.0f64; STACK_LIMIT];
766        let mut de_buf = [0.0f64; STACK_LIMIT];
767        let mut ci_buf = [0.0f64; STACK_LIMIT];
768        let mut wt1_buf = [0.0f64; STACK_LIMIT];
769        let mut wt2_buf = [0.0f64; STACK_LIMIT];
770
771        let esa = &mut esa_buf[..data_valid.len()];
772        let de = &mut de_buf[..data_valid.len()];
773        let ci = &mut ci_buf[..data_valid.len()];
774        let wt1 = &mut wt1_buf[..data_valid.len()];
775        let wt2 = &mut wt2_buf[..data_valid.len()];
776
777        wavetrend_core_computation(
778            data_valid,
779            channel_len,
780            average_len,
781            ma_len,
782            factor,
783            esa,
784            de,
785            ci,
786            wt1,
787            wt2,
788            simd_kernel,
789        )?;
790
791        for i in 0..data_valid.len() {
792            let out_idx = i + first;
793            if out_idx >= warmup_period {
794                dst_wt1[out_idx] = wt1[i];
795                dst_wt2[out_idx] = wt2[i];
796                if !wt1[i].is_nan() && !wt2[i].is_nan() {
797                    dst_wt_diff[out_idx] = wt2[i] - wt1[i];
798                } else {
799                    dst_wt_diff[out_idx] = f64::NAN;
800                }
801            }
802        }
803    } else {
804        let mut esa = vec![0.0; data_valid.len()];
805        let mut de = vec![0.0; data_valid.len()];
806        let mut ci = vec![0.0; data_valid.len()];
807        let mut wt1 = vec![0.0; data_valid.len()];
808        let mut wt2 = vec![0.0; data_valid.len()];
809
810        wavetrend_core_computation(
811            data_valid,
812            channel_len,
813            average_len,
814            ma_len,
815            factor,
816            &mut esa,
817            &mut de,
818            &mut ci,
819            &mut wt1,
820            &mut wt2,
821            simd_kernel,
822        )?;
823
824        for i in 0..data_valid.len() {
825            let out_idx = i + first;
826            if out_idx >= warmup_period {
827                dst_wt1[out_idx] = wt1[i];
828                dst_wt2[out_idx] = wt2[i];
829                if !wt1[i].is_nan() && !wt2[i].is_nan() {
830                    dst_wt_diff[out_idx] = wt2[i] - wt1[i];
831                } else {
832                    dst_wt_diff[out_idx] = f64::NAN;
833                }
834            }
835        }
836    }
837
838    Ok(())
839}
840
/// Largest series length (in elements) for which scratch buffers are
/// fixed-size stack arrays; longer series use heap-allocated vectors.
const STACK_LIMIT: usize = 512;
842
843#[inline(always)]
844fn wavetrend_core_computation(
845    data: &[f64],
846    channel_len: usize,
847    average_len: usize,
848    ma_len: usize,
849    factor: f64,
850    esa: &mut [f64],
851    de: &mut [f64],
852    ci: &mut [f64],
853    wt1: &mut [f64],
854    wt2: &mut [f64],
855    kernel: Kernel,
856) -> Result<(), WavetrendError> {
857    ema_compute_into(data, channel_len, esa);
858
859    if data.len() <= STACK_LIMIT {
860        let mut abs_diff_buf = [0.0f64; STACK_LIMIT];
861        let abs_diff = &mut abs_diff_buf[..data.len()];
862        compute_abs_diff(abs_diff, data, esa, kernel);
863        ema_compute_into(abs_diff, channel_len, de);
864    } else {
865        let mut abs_diff = vec![0.0; data.len()];
866        compute_abs_diff(&mut abs_diff, data, esa, kernel);
867        ema_compute_into(&abs_diff, channel_len, de);
868    }
869
870    compute_ci(ci, data, esa, de, factor, kernel);
871
872    ema_compute_into(ci, average_len, wt1);
873
874    sma_compute_into(wt1, ma_len, wt2);
875
876    Ok(())
877}
878
879#[inline(always)]
880fn compute_abs_diff(out: &mut [f64], data: &[f64], esa: &[f64], kernel: Kernel) {
881    let simd = kernel.to_non_batch();
882    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
883    {
884        match simd {
885            Kernel::Avx512 => unsafe {
886                absdiff_vec_avx512(out, data, esa);
887                return;
888            },
889            Kernel::Avx2 => unsafe {
890                absdiff_vec_avx2(out, data, esa);
891                return;
892            },
893            _ => {}
894        }
895    }
896
897    for i in 0..out.len() {
898        out[i] = (data[i] - esa[i]).abs();
899    }
900}
901
902#[inline(always)]
903fn compute_ci(out: &mut [f64], data: &[f64], esa: &[f64], de: &[f64], factor: f64, kernel: Kernel) {
904    let simd = kernel.to_non_batch();
905    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
906    {
907        match simd {
908            Kernel::Avx512 => unsafe {
909                ci_vec_avx512(out, data, esa, de, factor);
910                return;
911            },
912            Kernel::Avx2 => unsafe {
913                ci_vec_avx2(out, data, esa, de, factor);
914                return;
915            },
916            _ => {}
917        }
918    }
919
920    for i in 0..out.len() {
921        let den = factor * de[i];
922        if den != 0.0 && !data[i].is_nan() && !esa[i].is_nan() && !de[i].is_nan() {
923            out[i] = (data[i] - esa[i]) / den;
924        } else {
925            out[i] = f64::NAN;
926        }
927    }
928}
929
/// AVX2 element-wise `dst[i] = |a[i] - b[i]|` for every index of `dst`.
///
/// # Safety
/// * The running CPU must support AVX2.
/// * `a` and `b` are read through raw pointers at every index below
///   `dst.len()` without bounds checks, so both must be at least as long as
///   `dst`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn absdiff_vec_avx2(dst: &mut [f64], a: &[f64], b: &[f64]) {
    let n = dst.len();
    let pa = a.as_ptr();
    let pb = b.as_ptr();
    let pd = dst.as_mut_ptr();
    // -0.0 has only the sign bit set; and-not with it clears the sign bit.
    let sign = _mm256_set1_pd(-0.0f64);
    let mut i = 0usize;
    // Main loop: four lanes per iteration, unaligned loads and stores.
    while i + 4 <= n {
        let va = _mm256_loadu_pd(pa.add(i));
        let vb = _mm256_loadu_pd(pb.add(i));
        let vd = _mm256_sub_pd(va, vb);
        let vabs = _mm256_andnot_pd(sign, vd);
        _mm256_storeu_pd(pd.add(i), vabs);
        i += 4;
    }
    // Scalar tail for the remaining `n % 4` elements.
    while i < n {
        *pd.add(i) = (*pa.add(i) - *pb.add(i)).abs();
        i += 1;
    }
}
952
/// AVX2 channel index: `dst[i] = (data[i] - esa[i]) / (factor * de[i])`,
/// or NaN when the denominator equals zero or any of `data[i]`, `esa[i]`,
/// `de[i]` is NaN.
///
/// The validity rule is the same for vectorised lanes and for the scalar
/// tail, so the result for an element does not depend on its position in
/// the slice.
///
/// # Safety
/// * The running CPU must support AVX2.
/// * `data`, `esa` and `de` are read through raw pointers at every index
///   below `dst.len()` without bounds checks, so all three must be at least
///   as long as `dst`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn ci_vec_avx2(dst: &mut [f64], data: &[f64], esa: &[f64], de: &[f64], factor: f64) {
    let n = dst.len();
    let px = data.as_ptr();
    let pe = esa.as_ptr();
    let pd = de.as_ptr();
    let pr = dst.as_mut_ptr();

    let vf = _mm256_set1_pd(factor);
    let vzero = _mm256_set1_pd(0.0);
    let vnan = _mm256_set1_pd(f64::NAN);

    let mut i = 0usize;
    // Main loop: four lanes per iteration, unaligned loads and stores.
    while i + 4 <= n {
        let vx = _mm256_loadu_pd(px.add(i));
        let ve = _mm256_loadu_pd(pe.add(i));
        let vd = _mm256_loadu_pd(pd.add(i));

        let vnum = _mm256_sub_pd(vx, ve);
        let vden = _mm256_mul_pd(vf, vd);
        let vci = _mm256_div_pd(vnum, vden);

        // A value compared with itself is "ordered" exactly when it is not NaN.
        let ord_x = _mm256_cmp_pd(vx, vx, _CMP_ORD_Q);
        let ord_e = _mm256_cmp_pd(ve, ve, _CMP_ORD_Q);
        let ord_d = _mm256_cmp_pd(vd, vd, _CMP_ORD_Q);
        let ord_all = _mm256_and_pd(ord_x, _mm256_and_pd(ord_e, ord_d));
        let den_zero = _mm256_cmp_pd(vden, vzero, _CMP_EQ_OQ);
        // valid = all inputs non-NaN AND denominator != 0.
        let valid = _mm256_andnot_pd(den_zero, ord_all);

        // Keep the quotient in valid lanes, NaN elsewhere.
        let vres = _mm256_blendv_pd(vnan, vci, valid);
        _mm256_storeu_pd(pr.add(i), vres);
        i += 4;
    }
    // Scalar tail for the remaining `n % 4` elements. It uses the same
    // NaN-only validity test as the vector lanes above (and as the scalar
    // fallback in `compute_ci`); testing `is_finite()` here would turn
    // infinite inputs into NaN only in the tail.
    while i < n {
        let x = *px.add(i);
        let e = *pe.add(i);
        let d = *pd.add(i);
        let den = factor * d;
        *pr.add(i) = if den != 0.0 && !x.is_nan() && !e.is_nan() && !d.is_nan() {
            (x - e) / den
        } else {
            f64::NAN
        };
        i += 1;
    }
}
1000
/// AVX-512 element-wise `dst[i] = |a[i] - b[i]|` for every index of `dst`.
///
/// # Safety
/// * The running CPU must support AVX-512F.
/// * `a` and `b` are read through raw pointers at every index below
///   `dst.len()` without bounds checks, so both must be at least as long as
///   `dst`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
unsafe fn absdiff_vec_avx512(dst: &mut [f64], a: &[f64], b: &[f64]) {
    let n = dst.len();
    let pa = a.as_ptr();
    let pb = b.as_ptr();
    let pd = dst.as_mut_ptr();
    // Mask with only the f64 sign bit set, built as integers and reinterpreted.
    let sign = _mm512_set1_epi64(0x8000_0000_0000_0000u64 as i64);
    let sign_pd = _mm512_castsi512_pd(sign);

    let mut i = 0usize;
    // Main loop: eight lanes per iteration, unaligned loads and stores.
    while i + 8 <= n {
        let va = _mm512_loadu_pd(pa.add(i));
        let vb = _mm512_loadu_pd(pb.add(i));
        let vd = _mm512_sub_pd(va, vb);
        // and-not with the sign mask clears the sign bit, i.e. takes |vd|.
        let vabs = _mm512_andnot_pd(sign_pd, vd);
        _mm512_storeu_pd(pd.add(i), vabs);
        i += 8;
    }
    // Scalar tail for the remaining `n % 8` elements.
    while i < n {
        *pd.add(i) = (*pa.add(i) - *pb.add(i)).abs();
        i += 1;
    }
}
1025
/// AVX-512 kernel for the channel index:
/// `dst[i] = (data[i] - esa[i]) / (factor * de[i])`.
///
/// An element is written as NaN when any of `data[i]`, `esa[i]`, `de[i]` is
/// not finite, or when the denominator `factor * de[i]` compares equal to
/// zero. The 8-lane vector path and the scalar tail apply the same rule, so
/// the result for an element does not depend on its position in the slice.
///
/// # Safety
/// The CPU must support AVX-512F, and `data`, `esa` and `de` must each hold
/// at least `dst.len()` elements: all reads go through raw pointers and are
/// not bounds-checked.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
unsafe fn ci_vec_avx512(dst: &mut [f64], data: &[f64], esa: &[f64], de: &[f64], factor: f64) {
    let n = dst.len();
    debug_assert!(data.len() >= n && esa.len() >= n && de.len() >= n);

    let px = data.as_ptr();
    let pe = esa.as_ptr();
    let pdv = de.as_ptr();
    let pr = dst.as_mut_ptr();

    let vf = _mm512_set1_pd(factor);
    let vzero = _mm512_set1_pd(0.0);
    let vnan = _mm512_set1_pd(f64::NAN);
    let vinf = _mm512_set1_pd(f64::INFINITY);

    let mut i = 0usize;
    while i + 8 <= n {
        let vx = _mm512_loadu_pd(px.add(i));
        let ve = _mm512_loadu_pd(pe.add(i));
        let vd = _mm512_loadu_pd(pdv.add(i));

        let vnum = _mm512_sub_pd(vx, ve);
        let vden = _mm512_mul_pd(vf, vd);
        let vci = _mm512_div_pd(vnum, vden);

        // |v| < inf is an ordered compare: false for NaN and for +/-inf,
        // which is exactly `f64::is_finite` as used by the scalar tail.
        let fin_x = _mm512_cmp_pd_mask(_mm512_abs_pd(vx), vinf, _CMP_LT_OQ);
        let fin_e = _mm512_cmp_pd_mask(_mm512_abs_pd(ve), vinf, _CMP_LT_OQ);
        let fin_d = _mm512_cmp_pd_mask(_mm512_abs_pd(vd), vinf, _CMP_LT_OQ);
        let fin_all = fin_x & fin_e & fin_d;
        let den_zero = _mm512_cmp_pd_mask(vden, vzero, _CMP_EQ_OQ);
        let valid = fin_all & (!den_zero);

        // Lanes whose mask bit is set take the quotient; the rest take NaN.
        let vres = _mm512_mask_mov_pd(vnan, valid, vci);
        _mm512_storeu_pd(pr.add(i), vres);
        i += 8;
    }
    // Scalar tail for the remaining (< 8) elements.
    while i < n {
        let x = *px.add(i);
        let e = *pe.add(i);
        let d = *pdv.add(i);
        let den = factor * d;
        *pr.add(i) = if den != 0.0 && x.is_finite() && e.is_finite() && d.is_finite() {
            (x - e) / den
        } else {
            f64::NAN
        };
        i += 1;
    }
}
1073
/// Writes an exponential moving average of `data` into `out`.
///
/// The smoothing factor is `2 / (period + 1)`. The average is seeded with the
/// first non-NaN sample; a NaN sample produces a NaN output at that index and
/// leaves the running state untouched. When `period` is zero or `data` is
/// empty the function returns without writing to `out`.
///
/// Panics if `out` is shorter than `data` (indexing is bounds-checked).
#[inline(always)]
fn ema_compute_into(data: &[f64], period: usize, out: &mut [f64]) {
    if data.is_empty() || period == 0 {
        return;
    }

    let alpha = 2.0 / (period as f64 + 1.0);
    let beta = 1.0 - alpha;

    // NaN doubles as the "not seeded yet" marker.
    let mut state = f64::NAN;
    for (idx, &sample) in data.iter().enumerate() {
        if sample.is_nan() {
            out[idx] = f64::NAN;
            continue;
        }
        state = if state.is_nan() {
            sample
        } else {
            alpha * sample + beta * state
        };
        out[idx] = state;
    }
}
1097
/// Writes a rolling simple moving average of `data` into `out`.
///
/// `out` is first filled with NaN. `out[i]` then receives the mean of
/// `data[i + 1 - period..=i]` only when every sample in that window is
/// non-NaN; otherwise it stays NaN. When `period` is zero or `data` is empty
/// the function returns without touching `out`.
///
/// The sample that slides out of the window is evicted on every step, even
/// when the incoming sample is NaN. Previously eviction was skipped on NaN
/// inputs, so a stale sample stayed in the running sum and count forever and
/// every later average was wrong.
///
/// Panics if a result has to be written at an index beyond `out.len()`.
#[inline(always)]
fn sma_compute_into(data: &[f64], period: usize, out: &mut [f64]) {
    if period == 0 || data.is_empty() {
        return;
    }

    out.fill(f64::NAN);

    let mut sum = 0.0;
    // Number of non-NaN samples currently inside the window.
    let mut count = 0usize;

    for (i, &x) in data.iter().enumerate() {
        let current_valid = !x.is_nan();
        if current_valid {
            sum += x;
            count += 1;
        }

        // Evict unconditionally so the window always covers exactly the last
        // `period` samples. A non-NaN leaving sample was counted on entry and
        // has not been evicted yet, so `count` cannot underflow here.
        if i >= period {
            let leaving = data[i - period];
            if !leaving.is_nan() {
                sum -= leaving;
                count -= 1;
            }
        }

        if current_valid && count >= period {
            out[i] = sum / period as f64;
        }
    }
}
1129
1130#[inline]
1131pub fn wavetrend_into_slice(
1132    dst_wt1: &mut [f64],
1133    dst_wt2: &mut [f64],
1134    dst_wt_diff: &mut [f64],
1135    input: &WavetrendInput,
1136    kern: Kernel,
1137) -> Result<(), WavetrendError> {
1138    let (data, channel_len, average_len, ma_len, factor, first, warmup_period) =
1139        wavetrend_prepare(input)?;
1140
1141    if dst_wt1.len() != data.len() {
1142        return Err(WavetrendError::OutputLengthMismatch {
1143            expected: data.len(),
1144            got: dst_wt1.len(),
1145        });
1146    }
1147    if dst_wt2.len() != data.len() {
1148        return Err(WavetrendError::OutputLengthMismatch {
1149            expected: data.len(),
1150            got: dst_wt2.len(),
1151        });
1152    }
1153    if dst_wt_diff.len() != data.len() {
1154        return Err(WavetrendError::OutputLengthMismatch {
1155            expected: data.len(),
1156            got: dst_wt_diff.len(),
1157        });
1158    }
1159
1160    for i in 0..warmup_period.min(data.len()) {
1161        dst_wt1[i] = f64::NAN;
1162        dst_wt2[i] = f64::NAN;
1163        dst_wt_diff[i] = f64::NAN;
1164    }
1165
1166    let chosen = match kern {
1167        Kernel::Auto => detect_best_kernel(),
1168        Kernel::ScalarBatch => Kernel::Scalar,
1169        Kernel::Avx2Batch => Kernel::Avx2,
1170        Kernel::Avx512Batch => Kernel::Avx512,
1171        other => other,
1172    };
1173
1174    wavetrend_compute_into(
1175        data,
1176        channel_len,
1177        average_len,
1178        ma_len,
1179        factor,
1180        first,
1181        warmup_period,
1182        dst_wt1,
1183        dst_wt2,
1184        dst_wt_diff,
1185        chosen,
1186    )?;
1187
1188    Ok(())
1189}
1190
1191#[derive(Clone, Debug)]
1192pub struct WavetrendStream {
1193    pub channel_length: usize,
1194    pub average_length: usize,
1195    pub ma_length: usize,
1196    pub factor: f64,
1197
1198    esa_buf: VecDeque<f64>,
1199    last_esa: Option<f64>,
1200    alpha_ch: f64,
1201
1202    beta_ch: f64,
1203
1204    de_buf: VecDeque<f64>,
1205    last_de: Option<f64>,
1206
1207    ci_buf: VecDeque<f64>,
1208    last_wt1: Option<f64>,
1209    alpha_avg: f64,
1210
1211    beta_avg: f64,
1212
1213    wt1_buf: VecDeque<f64>,
1214    running_sum: f64,
1215
1216    sma_count: usize,
1217
1218    inv_ma: f64,
1219
1220    pub history: Vec<f64>,
1221}
1222
1223impl WavetrendStream {
1224    pub fn try_new(p: WavetrendParams) -> Result<Self, WavetrendError> {
1225        let channel_length = p.channel_length.unwrap_or(9);
1226        let average_length = p.average_length.unwrap_or(12);
1227        let ma_length = p.ma_length.unwrap_or(3);
1228        let factor = p.factor.unwrap_or(0.015);
1229
1230        if channel_length == 0 {
1231            return Err(WavetrendError::InvalidChannelLen {
1232                channel_length,
1233                data_len: 0,
1234            });
1235        }
1236        if average_length == 0 {
1237            return Err(WavetrendError::InvalidAverageLen {
1238                average_length,
1239                data_len: 0,
1240            });
1241        }
1242        if ma_length == 0 {
1243            return Err(WavetrendError::InvalidMaLen {
1244                ma_length,
1245                data_len: 0,
1246            });
1247        }
1248
1249        let alpha_ch = 2.0 / (channel_length as f64 + 1.0);
1250        let alpha_avg = 2.0 / (average_length as f64 + 1.0);
1251
1252        Ok(Self {
1253            channel_length,
1254            average_length,
1255            ma_length,
1256            factor,
1257
1258            esa_buf: VecDeque::with_capacity(channel_length),
1259            last_esa: None,
1260            alpha_ch,
1261            beta_ch: 1.0 - alpha_ch,
1262
1263            de_buf: VecDeque::with_capacity(channel_length),
1264            last_de: None,
1265
1266            ci_buf: VecDeque::with_capacity(average_length),
1267            last_wt1: None,
1268            alpha_avg,
1269            beta_avg: 1.0 - alpha_avg,
1270
1271            wt1_buf: VecDeque::with_capacity(ma_length),
1272            running_sum: 0.0,
1273            sma_count: 0,
1274            inv_ma: 1.0 / (ma_length as f64),
1275
1276            history: Vec::new(),
1277        })
1278    }
1279
1280    #[inline(always)]
1281    pub fn update(&mut self, price: f64) -> Option<(f64, f64, f64)> {
1282        self.history.push(price);
1283
1284        let mut wt1_val = f64::NAN;
1285
1286        if price.is_finite() {
1287            if let Some(prev) = self.last_esa {
1288                let new_esa = ema_step(prev, price, self.alpha_ch, self.beta_ch);
1289                self.last_esa = Some(new_esa);
1290            } else {
1291                self.last_esa = Some(price);
1292            }
1293
1294            if let Some(esa_now) = self.last_esa {
1295                let abs_diff = fast_abs_f64(price - esa_now);
1296                if let Some(prev_de) = self.last_de {
1297                    let new_de = ema_step(prev_de, abs_diff, self.alpha_ch, self.beta_ch);
1298                    self.last_de = Some(new_de);
1299                } else {
1300                    self.last_de = Some(abs_diff);
1301                }
1302            }
1303
1304            if let (Some(esa_now), Some(de_now)) = (self.last_esa, self.last_de) {
1305                let den = self.factor * de_now;
1306                if den != 0.0 && den.is_finite() && esa_now.is_finite() {
1307                    let ci = (price - esa_now) / den;
1308                    if ci.is_finite() {
1309                        if let Some(prev_wt1) = self.last_wt1 {
1310                            let new_wt1 = ema_step(prev_wt1, ci, self.alpha_avg, self.beta_avg);
1311                            self.last_wt1 = Some(new_wt1);
1312                        } else {
1313                            self.last_wt1 = Some(ci);
1314                        }
1315                        if let Some(v) = self.last_wt1 {
1316                            wt1_val = v;
1317                        }
1318                    }
1319                }
1320            }
1321        }
1322
1323        if self.wt1_buf.len() == self.ma_length {
1324            if let Some(leaving) = self.wt1_buf.pop_front() {
1325                if leaving.is_finite() {
1326                    self.running_sum -= leaving;
1327                    if self.sma_count > 0 {
1328                        self.sma_count -= 1;
1329                    }
1330                }
1331            }
1332        }
1333
1334        self.wt1_buf.push_back(wt1_val);
1335        if wt1_val.is_finite() {
1336            self.running_sum += wt1_val;
1337            self.sma_count += 1;
1338        }
1339
1340        if self.wt1_buf.len() == self.ma_length && self.sma_count == self.ma_length {
1341            let wt1 = wt1_val;
1342            let wt2 = self.running_sum * self.inv_ma;
1343            let diff = wt2 - wt1;
1344            Some((wt1, wt2, diff))
1345        } else {
1346            None
1347        }
1348    }
1349}
1350
/// One EMA update: `alpha * x + beta * prev`, with the `x * alpha` product
/// and the addition fused into a single rounding step.
#[inline(always)]
fn ema_step(prev: f64, x: f64, alpha: f64, beta: f64) -> f64 {
    let carried = beta * prev;
    f64::mul_add(x, alpha, carried)
}
1355
/// Absolute value obtained by clearing the IEEE-754 sign bit.
#[inline(always)]
fn fast_abs_f64(x: f64) -> f64 {
    const SIGN_BIT: u64 = 1u64 << 63;
    let magnitude_bits = x.to_bits() & !SIGN_BIT;
    f64::from_bits(magnitude_bits)
}
1360
1361#[cfg(all(feature = "python", feature = "cuda"))]
1362use cust::context::Context;
1363#[cfg(all(feature = "python", feature = "cuda"))]
1364use std::sync::Arc;
1365
// Python-visible wrapper around a 2-D f32 device array.
// Plain `//` comments are used throughout this block so that the Python
// docstrings generated by pyo3 stay unchanged.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(
    module = "ta_indicators.cuda",
    name = "WavetrendDeviceArrayF32",
    unsendable
)]
pub struct WavetrendDeviceArrayF32Py {
    // Device buffer plus its row/column counts.
    pub(crate) inner: DeviceArrayF32,
    // Not read in this block; presumably held so the CUDA context outlives
    // the buffer -- TODO confirm against the constructor.
    pub(crate) _ctx: Arc<Context>,
    // Device ordinal reported through `__dlpack_device__`.
    pub(crate) device_id: u32,
}

#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl WavetrendDeviceArrayF32Py {
    // Builds the `__cuda_array_interface__` dict (version 3): shape
    // (rows, cols), little-endian f32, byte strides for a row-major layout,
    // and the raw device pointer with the read-only flag set to false.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let d = PyDict::new(py);
        d.set_item("shape", (self.inner.rows, self.inner.cols))?;
        d.set_item("typestr", "<f4")?;
        d.set_item(
            "strides",
            (
                self.inner.cols * std::mem::size_of::<f32>(),
                std::mem::size_of::<f32>(),
            ),
        )?;
        d.set_item("data", (self.inner.device_ptr() as usize, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    // Returns (device_type, device_id); 2 is the DLPack code for a CUDA device.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.device_id as i32)
    }

    // Exports the buffer as a DLPack capsule.
    //
    // If `dl_device` parses as an (i32, i32) pair that differs from this
    // array's device, a ValueError is raised (the message depends on whether
    // `copy` is truthy). A `dl_device` that does not parse is ignored, and
    // `stream` is accepted but unused.
    //
    // The device buffer is moved out of `self`: afterwards `self.inner` is an
    // empty 0x0 placeholder, so the export can hand over the real data only
    // once per object.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        let (kdl, alloc_dev) = self.__dlpack_device__();
        if let Some(dev_obj) = dl_device.as_ref() {
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    // A non-bool or missing `copy` counts as "no copy requested".
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(PyValueError::new_err(
                            "device copy not implemented for __dlpack__",
                        ));
                    } else {
                        return Err(PyValueError::new_err("dl_device mismatch for __dlpack__"));
                    }
                }
            }
        }
        let _ = stream;

        // Swap in an empty buffer so ownership of the real one can be passed
        // to the exporter.
        let dummy = cust::memory::DeviceBuffer::from_slice(&[])
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        let inner = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32 {
                buf: dummy,
                rows: 0,
                cols: 0,
            },
        );

        let rows = inner.rows;
        let cols = inner.cols;
        let buf = inner.buf;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
1452
/// Parameter sweep description for batch runs.
///
/// Every axis is a `(start, end, step)` triple.
#[derive(Clone, Debug)]
pub struct WavetrendBatchRange {
    pub channel_length: (usize, usize, usize),
    pub average_length: (usize, usize, usize),
    pub ma_length: (usize, usize, usize),
    pub factor: (f64, f64, f64),
}

impl Default for WavetrendBatchRange {
    /// Channel 9, ma 3 and factor 0.015 are held fixed; the average length
    /// sweeps from 12 to 261 in steps of 1.
    fn default() -> Self {
        let fixed = |value: usize| (value, value, 0usize);
        WavetrendBatchRange {
            channel_length: fixed(9),
            average_length: (12, 261, 1),
            ma_length: fixed(3),
            factor: (0.015, 0.015, 0.0),
        }
    }
}
1471
1472#[derive(Clone, Debug, Default)]
1473pub struct WavetrendBatchBuilder {
1474    range: WavetrendBatchRange,
1475    kernel: Kernel,
1476}
1477
1478impl WavetrendBatchBuilder {
1479    pub fn new() -> Self {
1480        Self::default()
1481    }
1482    pub fn kernel(mut self, k: Kernel) -> Self {
1483        self.kernel = k;
1484        self
1485    }
1486    pub fn channel_range(mut self, start: usize, end: usize, step: usize) -> Self {
1487        self.range.channel_length = (start, end, step);
1488        self
1489    }
1490    pub fn channel_static(mut self, x: usize) -> Self {
1491        self.range.channel_length = (x, x, 0);
1492        self
1493    }
1494    pub fn avg_range(mut self, start: usize, end: usize, step: usize) -> Self {
1495        self.range.average_length = (start, end, step);
1496        self
1497    }
1498    pub fn avg_static(mut self, x: usize) -> Self {
1499        self.range.average_length = (x, x, 0);
1500        self
1501    }
1502    pub fn ma_range(mut self, start: usize, end: usize, step: usize) -> Self {
1503        self.range.ma_length = (start, end, step);
1504        self
1505    }
1506    pub fn ma_static(mut self, x: usize) -> Self {
1507        self.range.ma_length = (x, x, 0);
1508        self
1509    }
1510    pub fn factor_range(mut self, start: f64, end: f64, step: f64) -> Self {
1511        self.range.factor = (start, end, step);
1512        self
1513    }
1514    pub fn factor_static(mut self, x: f64) -> Self {
1515        self.range.factor = (x, x, 0.0);
1516        self
1517    }
1518    pub fn apply_slice(self, data: &[f64]) -> Result<WavetrendBatchOutput, WavetrendError> {
1519        wavetrend_batch_with_kernel(data, &self.range, self.kernel)
1520    }
1521    pub fn with_default_slice(
1522        data: &[f64],
1523        k: Kernel,
1524    ) -> Result<WavetrendBatchOutput, WavetrendError> {
1525        WavetrendBatchBuilder::new().kernel(k).apply_slice(data)
1526    }
1527    pub fn apply_candles(
1528        self,
1529        c: &Candles,
1530        src: &str,
1531    ) -> Result<WavetrendBatchOutput, WavetrendError> {
1532        let slice = source_type(c, src);
1533        self.apply_slice(slice)
1534    }
1535    pub fn with_default_candles(c: &Candles) -> Result<WavetrendBatchOutput, WavetrendError> {
1536        WavetrendBatchBuilder::new()
1537            .kernel(Kernel::Auto)
1538            .apply_candles(c, "hlc3")
1539    }
1540}
1541
1542pub fn wavetrend_batch_with_kernel(
1543    data: &[f64],
1544    sweep: &WavetrendBatchRange,
1545    k: Kernel,
1546) -> Result<WavetrendBatchOutput, WavetrendError> {
1547    let kernel = match k {
1548        Kernel::Auto => Kernel::ScalarBatch,
1549
1550        Kernel::Avx2Batch | Kernel::Avx512Batch => Kernel::ScalarBatch,
1551        other if other.is_batch() => other,
1552        _ => {
1553            return Err(WavetrendError::InvalidKernelForBatch(k));
1554        }
1555    };
1556    let simd = match kernel {
1557        Kernel::Avx512Batch => Kernel::Avx512,
1558        Kernel::Avx2Batch => Kernel::Avx2,
1559        Kernel::ScalarBatch => Kernel::Scalar,
1560        _ => unreachable!(),
1561    };
1562    wavetrend_batch_par_slice(data, sweep, simd)
1563}
1564
1565#[derive(Clone, Debug)]
1566pub struct WavetrendBatchOutput {
1567    pub wt1: Vec<f64>,
1568    pub wt2: Vec<f64>,
1569    pub wt_diff: Vec<f64>,
1570    pub combos: Vec<WavetrendParams>,
1571    pub rows: usize,
1572    pub cols: usize,
1573}
1574impl WavetrendBatchOutput {
1575    pub fn row_for_params(&self, p: &WavetrendParams) -> Option<usize> {
1576        self.combos.iter().position(|c| {
1577            c.channel_length.unwrap_or(9) == p.channel_length.unwrap_or(9)
1578                && c.average_length.unwrap_or(12) == p.average_length.unwrap_or(12)
1579                && c.ma_length.unwrap_or(3) == p.ma_length.unwrap_or(3)
1580                && (c.factor.unwrap_or(0.015) - p.factor.unwrap_or(0.015)).abs() < 1e-12
1581        })
1582    }
1583    pub fn values_for(&self, p: &WavetrendParams) -> Option<(&[f64], &[f64], &[f64])> {
1584        self.row_for_params(p).map(|row| {
1585            let start = row * self.cols;
1586            (
1587                &self.wt1[start..start + self.cols],
1588                &self.wt2[start..start + self.cols],
1589                &self.wt_diff[start..start + self.cols],
1590            )
1591        })
1592    }
1593}
1594
1595#[inline(always)]
1596fn expand_grid(r: &WavetrendBatchRange) -> Result<Vec<WavetrendParams>, WavetrendError> {
1597    fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, WavetrendError> {
1598        if step == 0 || start == end {
1599            return Ok(vec![start]);
1600        }
1601        if start < end {
1602            let st = step.max(1);
1603            return Ok((start..=end).step_by(st).collect());
1604        }
1605
1606        let st = step.max(1) as isize;
1607        let mut v = Vec::new();
1608        let mut x = start as isize;
1609        let end_i = end as isize;
1610        while x >= end_i {
1611            v.push(x as usize);
1612            x -= st;
1613        }
1614        if v.is_empty() {
1615            return Err(WavetrendError::InvalidRange {
1616                start: start.to_string(),
1617                end: end.to_string(),
1618                step: step.to_string(),
1619            });
1620        }
1621        Ok(v)
1622    }
1623    fn axis_f64((start, end, step): (f64, f64, f64)) -> Result<Vec<f64>, WavetrendError> {
1624        if step.abs() < 1e-12 || (start - end).abs() < 1e-12 {
1625            return Ok(vec![start]);
1626        }
1627        if start < end {
1628            let mut v = Vec::new();
1629            let mut x = start;
1630            let st = step.abs();
1631            while x <= end + 1e-12 {
1632                v.push(x);
1633                x += st;
1634            }
1635            if v.is_empty() {
1636                return Err(WavetrendError::InvalidRange {
1637                    start: start.to_string(),
1638                    end: end.to_string(),
1639                    step: step.to_string(),
1640                });
1641            }
1642            return Ok(v);
1643        }
1644        let mut v = Vec::new();
1645        let mut x = start;
1646        let st = step.abs();
1647        while x + 1e-12 >= end {
1648            v.push(x);
1649            x -= st;
1650        }
1651        if v.is_empty() {
1652            return Err(WavetrendError::InvalidRange {
1653                start: start.to_string(),
1654                end: end.to_string(),
1655                step: step.to_string(),
1656            });
1657        }
1658        Ok(v)
1659    }
1660
1661    let chs = axis_usize(r.channel_length)?;
1662    let avgs = axis_usize(r.average_length)?;
1663    let mas = axis_usize(r.ma_length)?;
1664    let factors = axis_f64(r.factor)?;
1665
1666    let cap = chs
1667        .len()
1668        .checked_mul(avgs.len())
1669        .and_then(|x| x.checked_mul(mas.len()))
1670        .and_then(|x| x.checked_mul(factors.len()))
1671        .ok_or_else(|| WavetrendError::InvalidRange {
1672            start: "cap".into(),
1673            end: "overflow".into(),
1674            step: "mul".into(),
1675        })?;
1676
1677    let mut out = Vec::with_capacity(cap);
1678    for &c in &chs {
1679        for &a in &avgs {
1680            for &m in &mas {
1681                for &f in &factors {
1682                    out.push(WavetrendParams {
1683                        channel_length: Some(c),
1684                        average_length: Some(a),
1685                        ma_length: Some(m),
1686                        factor: Some(f),
1687                    });
1688                }
1689            }
1690        }
1691    }
1692    Ok(out)
1693}
1694
1695#[inline(always)]
1696pub fn wavetrend_batch_slice(
1697    data: &[f64],
1698    sweep: &WavetrendBatchRange,
1699    kern: Kernel,
1700) -> Result<WavetrendBatchOutput, WavetrendError> {
1701    wavetrend_batch_inner(data, sweep, kern, false)
1702}
1703#[inline(always)]
1704pub fn wavetrend_batch_par_slice(
1705    data: &[f64],
1706    sweep: &WavetrendBatchRange,
1707    kern: Kernel,
1708) -> Result<WavetrendBatchOutput, WavetrendError> {
1709    wavetrend_batch_inner(data, sweep, kern, true)
1710}
1711
/// Core of the batch API: evaluates every parameter combination of `sweep`
/// over `data` and returns three row-major `rows x cols` matrices.
///
/// `kern` is the per-row kernel (`Avx512` / `Avx2` when built with
/// `nightly-avx` on x86_64, anything else falls back to scalar). `parallel`
/// distributes rows with rayon, except on wasm32 where rows are always
/// processed sequentially.
///
/// # Errors
/// - `InvalidRange` when the sweep expands to nothing or `rows * cols`
///   overflows `usize`.
/// - `AllValuesNaN` when `data` contains no non-NaN value.
/// - `InvalidChannelLen` / `InvalidAverageLen` / `InvalidMaLen` when a combo
///   has a zero length.
/// - `NotEnoughValidData` when fewer values remain after the first non-NaN
///   one than the largest single length in the sweep.
///
/// # Panics
/// Panics if a row kernel returns an error.
#[inline(always)]
fn wavetrend_batch_inner(
    data: &[f64],
    sweep: &WavetrendBatchRange,
    kern: Kernel,
    parallel: bool,
) -> Result<WavetrendBatchOutput, WavetrendError> {
    let combos = expand_grid(sweep)?;
    if combos.is_empty() {
        return Err(WavetrendError::InvalidRange {
            start: "range".into(),
            end: "range".into(),
            step: "empty".into(),
        });
    }
    // Index of the first usable sample; everything before it is NaN.
    let first = data
        .iter()
        .position(|x| !x.is_nan())
        .ok_or(WavetrendError::AllValuesNaN)?;

    // Validate every combo and record its warm-up length.
    let mut max_p = 0usize;
    let mut warmup_periods = Vec::with_capacity(combos.len());
    for c in combos.iter() {
        let channel_length = c.channel_length.unwrap();
        if channel_length == 0 {
            return Err(WavetrendError::InvalidChannelLen {
                channel_length,
                data_len: data.len(),
            });
        }
        let average_length = c.average_length.unwrap();
        if average_length == 0 {
            return Err(WavetrendError::InvalidAverageLen {
                average_length,
                data_len: data.len(),
            });
        }
        let ma_length = c.ma_length.unwrap();
        if ma_length == 0 {
            return Err(WavetrendError::InvalidMaLen {
                ma_length,
                data_len: data.len(),
            });
        }

        max_p = max_p.max(channel_length).max(average_length).max(ma_length);
        // Each of the three stages contributes (length - 1) leading samples.
        warmup_periods.push(first + channel_length - 1 + average_length - 1 + ma_length - 1);
    }
    if data.len() - first < max_p {
        return Err(WavetrendError::NotEnoughValidData {
            needed: max_p,
            valid: data.len() - first,
        });
    }
    let rows = combos.len();
    let cols = data.len();

    // Only the overflow check matters here; the product itself is unused.
    let _ = rows
        .checked_mul(cols)
        .ok_or_else(|| WavetrendError::InvalidRange {
            start: rows.to_string(),
            end: cols.to_string(),
            step: "rows*cols".into(),
        })?;

    let mut wt1_mu = make_uninit_matrix(rows, cols);
    let mut wt2_mu = make_uninit_matrix(rows, cols);
    let mut wt_diff_mu = make_uninit_matrix(rows, cols);

    // Presumably writes NaN into the first `warmup_periods[row]` cells of
    // each row -- helper is defined elsewhere, TODO confirm.
    init_matrix_prefixes(&mut wt1_mu, cols, &warmup_periods);
    init_matrix_prefixes(&mut wt2_mu, cols, &warmup_periods);
    init_matrix_prefixes(&mut wt_diff_mu, cols, &warmup_periods);

    // Suppress the destructors: ownership of each allocation is handed to a
    // `Vec<f64>` via `from_raw_parts` at the end of the function.
    let mut wt1_guard = core::mem::ManuallyDrop::new(wt1_mu);
    let mut wt2_guard = core::mem::ManuallyDrop::new(wt2_mu);
    let mut wt_diff_guard = core::mem::ManuallyDrop::new(wt_diff_mu);

    // View the matrices as plain f64 slices for the row kernels to fill.
    // NOTE(review): cells beyond each row's warm-up prefix are expected to be
    // written by the row kernel before they are read -- verify.
    let wt1: &mut [f64] = unsafe {
        core::slice::from_raw_parts_mut(wt1_guard.as_mut_ptr() as *mut f64, wt1_guard.len())
    };
    let wt2: &mut [f64] = unsafe {
        core::slice::from_raw_parts_mut(wt2_guard.as_mut_ptr() as *mut f64, wt2_guard.len())
    };
    let wt_diff: &mut [f64] = unsafe {
        core::slice::from_raw_parts_mut(wt_diff_guard.as_mut_ptr() as *mut f64, wt_diff_guard.len())
    };

    // Computes one row by dispatching on the requested kernel; the SIMD arms
    // exist only in `nightly-avx` x86_64 builds.
    let do_row = |row: usize, w1: &mut [f64], w2: &mut [f64], wd: &mut [f64]| unsafe {
        let p = &combos[row];
        let row_kernel = match kern {
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx512 => wavetrend_row_avx512(
                data,
                first,
                p.channel_length.unwrap(),
                p.average_length.unwrap(),
                p.ma_length.unwrap(),
                p.factor.unwrap_or(0.015),
                w1,
                w2,
                wd,
            ),
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx2 => wavetrend_row_avx2(
                data,
                first,
                p.channel_length.unwrap(),
                p.average_length.unwrap(),
                p.ma_length.unwrap(),
                p.factor.unwrap_or(0.015),
                w1,
                w2,
                wd,
            ),
            _ => wavetrend_row_scalar(
                data,
                first,
                p.channel_length.unwrap(),
                p.average_length.unwrap(),
                p.ma_length.unwrap(),
                p.factor.unwrap_or(0.015),
                w1,
                w2,
                wd,
            ),
        };
        // A failing row aborts the whole batch with a panic.
        if let Err(e) = row_kernel {
            panic!("wavetrend row error: {:?}", e);
        }
    };

    if parallel {
        #[cfg(not(target_arch = "wasm32"))]
        {
            wt1.par_chunks_mut(cols)
                .zip(wt2.par_chunks_mut(cols))
                .zip(wt_diff.par_chunks_mut(cols))
                .enumerate()
                .for_each(|(row, ((w1, w2), wd))| do_row(row, w1, w2, wd));
        }

        // On wasm32 the parallel request runs as a sequential loop.
        #[cfg(target_arch = "wasm32")]
        {
            for (row, (((w1, w2), wd))) in wt1
                .chunks_mut(cols)
                .zip(wt2.chunks_mut(cols))
                .zip(wt_diff.chunks_mut(cols))
                .enumerate()
            {
                do_row(row, w1, w2, wd);
            }
        }
    } else {
        for (row, (((w1, w2), wd))) in wt1
            .chunks_mut(cols)
            .zip(wt2.chunks_mut(cols))
            .zip(wt_diff.chunks_mut(cols))
            .enumerate()
        {
            do_row(row, w1, w2, wd);
        }
    }

    // Reassemble owning `Vec<f64>`s from the guarded allocations.
    let wt1_vec = unsafe {
        Vec::from_raw_parts(
            wt1_guard.as_mut_ptr() as *mut f64,
            wt1_guard.len(),
            wt1_guard.capacity(),
        )
    };
    let wt2_vec = unsafe {
        Vec::from_raw_parts(
            wt2_guard.as_mut_ptr() as *mut f64,
            wt2_guard.len(),
            wt2_guard.capacity(),
        )
    };
    let wt_diff_vec = unsafe {
        Vec::from_raw_parts(
            wt_diff_guard.as_mut_ptr() as *mut f64,
            wt_diff_guard.len(),
            wt_diff_guard.capacity(),
        )
    };

    Ok(WavetrendBatchOutput {
        wt1: wt1_vec,
        wt2: wt2_vec,
        wt_diff: wt_diff_vec,
        combos,
        rows,
        cols,
    })
}
1906
1907#[inline(always)]
1908fn wavetrend_batch_inner_into(
1909    data: &[f64],
1910    sweep: &WavetrendBatchRange,
1911    kern: Kernel,
1912    parallel: bool,
1913    out_wt1: &mut [f64],
1914    out_wt2: &mut [f64],
1915    out_wt_diff: &mut [f64],
1916) -> Result<Vec<WavetrendParams>, WavetrendError> {
1917    let combos = expand_grid(sweep)?;
1918    if combos.is_empty() {
1919        return Err(WavetrendError::InvalidRange {
1920            start: "range".into(),
1921            end: "range".into(),
1922            step: "empty".into(),
1923        });
1924    }
1925    let first = data
1926        .iter()
1927        .position(|x| !x.is_nan())
1928        .ok_or(WavetrendError::AllValuesNaN)?;
1929
1930    let mut max_p = 0usize;
1931    for c in combos.iter() {
1932        let channel_length = c.channel_length.unwrap();
1933        if channel_length == 0 {
1934            return Err(WavetrendError::InvalidChannelLen {
1935                channel_length,
1936                data_len: data.len(),
1937            });
1938        }
1939        let average_length = c.average_length.unwrap();
1940        if average_length == 0 {
1941            return Err(WavetrendError::InvalidAverageLen {
1942                average_length,
1943                data_len: data.len(),
1944            });
1945        }
1946        let ma_length = c.ma_length.unwrap();
1947        if ma_length == 0 {
1948            return Err(WavetrendError::InvalidMaLen {
1949                ma_length,
1950                data_len: data.len(),
1951            });
1952        }
1953
1954        max_p = max_p.max(channel_length).max(average_length).max(ma_length);
1955    }
1956    if data.len() - first < max_p {
1957        return Err(WavetrendError::NotEnoughValidData {
1958            needed: max_p,
1959            valid: data.len() - first,
1960        });
1961    }
1962    let rows = combos.len();
1963    let cols = data.len();
1964
1965    let total = rows
1966        .checked_mul(cols)
1967        .ok_or_else(|| WavetrendError::InvalidRange {
1968            start: rows.to_string(),
1969            end: cols.to_string(),
1970            step: "rows*cols".into(),
1971        })?;
1972    if out_wt1.len() != total {
1973        return Err(WavetrendError::OutputSliceLengthMismatch {
1974            expected: total,
1975            got: out_wt1.len(),
1976        });
1977    }
1978    if out_wt2.len() != total {
1979        return Err(WavetrendError::OutputSliceLengthMismatch {
1980            expected: total,
1981            got: out_wt2.len(),
1982        });
1983    }
1984    if out_wt_diff.len() != total {
1985        return Err(WavetrendError::OutputSliceLengthMismatch {
1986            expected: total,
1987            got: out_wt_diff.len(),
1988        });
1989    }
1990
1991    for (row, combo) in combos.iter().enumerate() {
1992        let warmup = first + combo.channel_length.unwrap() - 1 + combo.average_length.unwrap() - 1
1993            + combo.ma_length.unwrap()
1994            - 1;
1995        let row_start = row * cols;
1996        for i in 0..warmup.min(cols) {
1997            out_wt1[row_start + i] = f64::NAN;
1998            out_wt2[row_start + i] = f64::NAN;
1999            out_wt_diff[row_start + i] = f64::NAN;
2000        }
2001    }
2002
2003    let do_row = |row: usize, w1: &mut [f64], w2: &mut [f64], wd: &mut [f64]| unsafe {
2004        let p = &combos[row];
2005        let r = wavetrend_row_scalar(
2006            data,
2007            first,
2008            p.channel_length.unwrap(),
2009            p.average_length.unwrap(),
2010            p.ma_length.unwrap(),
2011            p.factor.unwrap_or(0.015),
2012            w1,
2013            w2,
2014            wd,
2015        );
2016        if let Err(e) = r {
2017            panic!("wavetrend row error: {:?}", e);
2018        }
2019    };
2020
2021    if parallel {
2022        #[cfg(not(target_arch = "wasm32"))]
2023        {
2024            out_wt1
2025                .par_chunks_mut(cols)
2026                .zip(out_wt2.par_chunks_mut(cols))
2027                .zip(out_wt_diff.par_chunks_mut(cols))
2028                .enumerate()
2029                .for_each(|(row, ((w1, w2), wd))| do_row(row, w1, w2, wd));
2030        }
2031
2032        #[cfg(target_arch = "wasm32")]
2033        {
2034            for (row, (((w1, w2), wd))) in out_wt1
2035                .chunks_mut(cols)
2036                .zip(out_wt2.chunks_mut(cols))
2037                .zip(out_wt_diff.chunks_mut(cols))
2038                .enumerate()
2039            {
2040                do_row(row, w1, w2, wd);
2041            }
2042        }
2043    } else {
2044        for (row, (((w1, w2), wd))) in out_wt1
2045            .chunks_mut(cols)
2046            .zip(out_wt2.chunks_mut(cols))
2047            .zip(out_wt_diff.chunks_mut(cols))
2048            .enumerate()
2049        {
2050            do_row(row, w1, w2, wd);
2051        }
2052    }
2053    Ok(combos)
2054}
2055
2056#[inline(always)]
2057unsafe fn wavetrend_row_scalar(
2058    data: &[f64],
2059    first: usize,
2060    channel_len: usize,
2061    average_len: usize,
2062    ma_len: usize,
2063    factor: f64,
2064    wt1: &mut [f64],
2065    wt2: &mut [f64],
2066    wd: &mut [f64],
2067) -> Result<(), WavetrendError> {
2068    wavetrend_row_with_kernel(
2069        data,
2070        first,
2071        channel_len,
2072        average_len,
2073        ma_len,
2074        factor,
2075        wt1,
2076        wt2,
2077        wd,
2078        Kernel::Scalar,
2079    )
2080}
2081
/// Computes one batch row with the AVX2 kernel selected.
///
/// Thin adapter: every argument is forwarded unchanged to
/// `wavetrend_row_with_kernel` with `Kernel::Avx2`.
///
/// # Safety
///
/// NOTE(review): the safety contract is not spelled out in this part of the
/// file. The callee debug-asserts that `wt1`, `wt2` and `wd` are each exactly
/// `data.len()` long; presumably the CPU must also support AVX2 - confirm.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn wavetrend_row_avx2(
    data: &[f64],
    first: usize,
    channel_len: usize,
    average_len: usize,
    ma_len: usize,
    factor: f64,
    wt1: &mut [f64],
    wt2: &mut [f64],
    wd: &mut [f64],
) -> Result<(), WavetrendError> {
    let kernel = Kernel::Avx2;
    wavetrend_row_with_kernel(
        data,
        first,
        channel_len,
        average_len,
        ma_len,
        factor,
        wt1,
        wt2,
        wd,
        kernel,
    )
}
/// Computes one batch row with the AVX-512 kernel selected.
///
/// Thin adapter: every argument is forwarded unchanged to
/// `wavetrend_row_with_kernel` with `Kernel::Avx512`.
///
/// # Safety
///
/// NOTE(review): the safety contract is not spelled out in this part of the
/// file. The callee debug-asserts that `wt1`, `wt2` and `wd` are each exactly
/// `data.len()` long; presumably the CPU must also support AVX-512 - confirm.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn wavetrend_row_avx512(
    data: &[f64],
    first: usize,
    channel_len: usize,
    average_len: usize,
    ma_len: usize,
    factor: f64,
    wt1: &mut [f64],
    wt2: &mut [f64],
    wd: &mut [f64],
) -> Result<(), WavetrendError> {
    let kernel = Kernel::Avx512;
    wavetrend_row_with_kernel(
        data,
        first,
        channel_len,
        average_len,
        ma_len,
        factor,
        wt1,
        wt2,
        wd,
        kernel,
    )
}
/// "Short" AVX-512 row entry point.
///
/// Forwards every argument unchanged to `wavetrend_row_with_kernel` with
/// `Kernel::Avx512`, i.e. it currently selects the same kernel as
/// `wavetrend_row_avx512`.
///
/// # Safety
///
/// NOTE(review): the safety contract is not spelled out in this part of the
/// file. The callee debug-asserts that `wt1`, `wt2` and `wd` are each exactly
/// `data.len()` long; presumably the CPU must also support AVX-512 - confirm.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn wavetrend_row_avx512_short(
    data: &[f64],
    first: usize,
    channel_len: usize,
    average_len: usize,
    ma_len: usize,
    factor: f64,
    wt1: &mut [f64],
    wt2: &mut [f64],
    wd: &mut [f64],
) -> Result<(), WavetrendError> {
    let kernel = Kernel::Avx512;
    wavetrend_row_with_kernel(
        data,
        first,
        channel_len,
        average_len,
        ma_len,
        factor,
        wt1,
        wt2,
        wd,
        kernel,
    )
}
/// "Long" AVX-512 row entry point.
///
/// Forwards every argument unchanged to `wavetrend_row_with_kernel` with
/// `Kernel::Avx512`, i.e. it currently selects the same kernel as
/// `wavetrend_row_avx512`.
///
/// # Safety
///
/// NOTE(review): the safety contract is not spelled out in this part of the
/// file. The callee debug-asserts that `wt1`, `wt2` and `wd` are each exactly
/// `data.len()` long; presumably the CPU must also support AVX-512 - confirm.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn wavetrend_row_avx512_long(
    data: &[f64],
    first: usize,
    channel_len: usize,
    average_len: usize,
    ma_len: usize,
    factor: f64,
    wt1: &mut [f64],
    wt2: &mut [f64],
    wd: &mut [f64],
) -> Result<(), WavetrendError> {
    let kernel = Kernel::Avx512;
    wavetrend_row_with_kernel(
        data,
        first,
        channel_len,
        average_len,
        ma_len,
        factor,
        wt1,
        wt2,
        wd,
        kernel,
    )
}
2186
2187#[inline(always)]
2188unsafe fn wavetrend_row_with_kernel(
2189    data: &[f64],
2190    first: usize,
2191    channel_len: usize,
2192    average_len: usize,
2193    ma_len: usize,
2194    factor: f64,
2195    wt1: &mut [f64],
2196    wt2: &mut [f64],
2197    wd: &mut [f64],
2198    kernel: Kernel,
2199) -> Result<(), WavetrendError> {
2200    debug_assert_eq!(wt1.len(), data.len());
2201    debug_assert_eq!(wt2.len(), data.len());
2202    debug_assert_eq!(wd.len(), data.len());
2203
2204    let warmup = first + channel_len - 1 + average_len - 1 + ma_len - 1;
2205
2206    wavetrend_compute_into(
2207        data,
2208        channel_len,
2209        average_len,
2210        ma_len,
2211        factor,
2212        first,
2213        warmup,
2214        wt1,
2215        wt2,
2216        wd,
2217        kernel,
2218    )
2219}
2220#[cfg(test)]
2221mod tests {
2222    use super::*;
2223    use crate::skip_if_unsupported;
2224    use crate::utilities::data_loader::read_candles_from_csv;
2225    use crate::utilities::enums::Kernel;
2226
2227    fn check_wavetrend_partial_params(
2228        test_name: &str,
2229        kernel: Kernel,
2230    ) -> Result<(), Box<dyn std::error::Error>> {
2231        skip_if_unsupported!(kernel, test_name);
2232        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2233        let candles = read_candles_from_csv(file_path)?;
2234        let default_params = WavetrendParams {
2235            channel_length: None,
2236            average_length: None,
2237            ma_length: None,
2238            factor: None,
2239        };
2240        let input = WavetrendInput::from_candles(&candles, "hlc3", default_params);
2241        let output = wavetrend_with_kernel(&input, kernel)?;
2242        assert_eq!(output.wt1.len(), candles.close.len());
2243        Ok(())
2244    }
2245
2246    fn check_wavetrend_accuracy(
2247        test_name: &str,
2248        kernel: Kernel,
2249    ) -> Result<(), Box<dyn std::error::Error>> {
2250        skip_if_unsupported!(kernel, test_name);
2251        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2252        let candles = read_candles_from_csv(file_path)?;
2253        let input = WavetrendInput::from_candles(&candles, "hlc3", WavetrendParams::default());
2254        let result = wavetrend_with_kernel(&input, kernel)?;
2255        let len = result.wt1.len();
2256        let expected_wt1 = [
2257            -29.02058232514538,
2258            -28.207769813591664,
2259            -31.991808642927193,
2260            -31.9218051759519,
2261            -44.956245952893866,
2262        ];
2263        let expected_wt2 = [
2264            -30.651043230696555,
2265            -28.686329669808583,
2266            -29.740053593887932,
2267            -30.707127877490105,
2268            -36.2899532572575,
2269        ];
2270        for (i, &val) in result.wt1[len - 5..].iter().enumerate() {
2271            let diff = (val - expected_wt1[i]).abs();
2272            assert!(
2273                diff < 1e-6,
2274                "[{}] Wavetrend {:?} WT1 mismatch at idx {}: got {}, expected {}",
2275                test_name,
2276                kernel,
2277                i,
2278                val,
2279                expected_wt1[i]
2280            );
2281        }
2282        for (i, &val) in result.wt2[len - 5..].iter().enumerate() {
2283            let diff = (val - expected_wt2[i]).abs();
2284            assert!(
2285                diff < 1e-6,
2286                "[{}] Wavetrend {:?} WT2 mismatch at idx {}: got {}, expected {}",
2287                test_name,
2288                kernel,
2289                i,
2290                val,
2291                expected_wt2[i]
2292            );
2293        }
2294        let last_five_diff = &result.wt_diff[len - 5..];
2295        for i in 0..5 {
2296            let expected = expected_wt2[i] - expected_wt1[i];
2297            let diff = (last_five_diff[i] - expected).abs();
2298            assert!(
2299                diff < 1e-6,
2300                "[{}] Wavetrend {:?} WT_DIFF mismatch at idx {}: got {}, expected {}",
2301                test_name,
2302                kernel,
2303                i,
2304                last_five_diff[i],
2305                expected
2306            );
2307        }
2308        Ok(())
2309    }
2310
2311    fn check_wavetrend_default_candles(
2312        test_name: &str,
2313        kernel: Kernel,
2314    ) -> Result<(), Box<dyn std::error::Error>> {
2315        skip_if_unsupported!(kernel, test_name);
2316        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2317        let candles = read_candles_from_csv(file_path)?;
2318        let input = WavetrendInput::with_default_candles(&candles);
2319        match input.data {
2320            WavetrendData::Candles { source, .. } => assert_eq!(source, "hlc3"),
2321            _ => panic!("Expected WavetrendData::Candles"),
2322        }
2323        let output = wavetrend_with_kernel(&input, kernel)?;
2324        assert_eq!(output.wt1.len(), candles.close.len());
2325        Ok(())
2326    }
2327
2328    fn check_wavetrend_zero_channel(
2329        test_name: &str,
2330        kernel: Kernel,
2331    ) -> Result<(), Box<dyn std::error::Error>> {
2332        skip_if_unsupported!(kernel, test_name);
2333        let input_data = [10.0, 20.0, 30.0];
2334        let params = WavetrendParams {
2335            channel_length: Some(0),
2336            average_length: Some(12),
2337            ma_length: Some(3),
2338            factor: Some(0.015),
2339        };
2340        let input = WavetrendInput::from_slice(&input_data, params);
2341        let res = wavetrend_with_kernel(&input, kernel);
2342        assert!(
2343            res.is_err(),
2344            "[{}] Wavetrend should fail with zero channel_length",
2345            test_name
2346        );
2347        Ok(())
2348    }
2349
2350    fn check_wavetrend_channel_exceeds_length(
2351        test_name: &str,
2352        kernel: Kernel,
2353    ) -> Result<(), Box<dyn std::error::Error>> {
2354        skip_if_unsupported!(kernel, test_name);
2355        let data_small = [10.0, 20.0, 30.0];
2356        let params = WavetrendParams {
2357            channel_length: Some(10),
2358            average_length: Some(12),
2359            ma_length: Some(3),
2360            factor: Some(0.015),
2361        };
2362        let input = WavetrendInput::from_slice(&data_small, params);
2363        let res = wavetrend_with_kernel(&input, kernel);
2364        assert!(
2365            res.is_err(),
2366            "[{}] Wavetrend should fail with channel_length exceeding length",
2367            test_name
2368        );
2369        Ok(())
2370    }
2371
2372    fn check_wavetrend_very_small_dataset(
2373        test_name: &str,
2374        kernel: Kernel,
2375    ) -> Result<(), Box<dyn std::error::Error>> {
2376        skip_if_unsupported!(kernel, test_name);
2377        let single_point = [42.0];
2378        let params = WavetrendParams::default();
2379        let input = WavetrendInput::from_slice(&single_point, params);
2380        let res = wavetrend_with_kernel(&input, kernel);
2381        assert!(
2382            res.is_err(),
2383            "[{}] Wavetrend should fail with insufficient data",
2384            test_name
2385        );
2386        Ok(())
2387    }
2388
2389    fn check_wavetrend_nan_handling(
2390        test_name: &str,
2391        kernel: Kernel,
2392    ) -> Result<(), Box<dyn std::error::Error>> {
2393        skip_if_unsupported!(kernel, test_name);
2394        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2395        let candles = read_candles_from_csv(file_path)?;
2396        let input = WavetrendInput::from_candles(
2397            &candles,
2398            "hlc3",
2399            WavetrendParams {
2400                channel_length: Some(9),
2401                average_length: Some(12),
2402                ma_length: Some(3),
2403                factor: Some(0.015),
2404            },
2405        );
2406        let res = wavetrend_with_kernel(&input, kernel)?;
2407        assert_eq!(res.wt1.len(), candles.close.len());
2408        if res.wt1.len() > 240 {
2409            for (i, &val) in res.wt1[240..].iter().enumerate() {
2410                assert!(
2411                    !val.is_nan(),
2412                    "[{}] Found unexpected NaN at out-index {}",
2413                    test_name,
2414                    240 + i
2415                );
2416            }
2417        }
2418        Ok(())
2419    }
2420
2421    #[cfg(debug_assertions)]
2422    fn check_wavetrend_no_poison(
2423        test_name: &str,
2424        kernel: Kernel,
2425    ) -> Result<(), Box<dyn std::error::Error>> {
2426        skip_if_unsupported!(kernel, test_name);
2427
2428        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2429        let candles = read_candles_from_csv(file_path)?;
2430
2431        let test_params = vec![
2432            WavetrendParams::default(),
2433            WavetrendParams {
2434                channel_length: Some(1),
2435                average_length: Some(1),
2436                ma_length: Some(1),
2437                factor: Some(0.001),
2438            },
2439            WavetrendParams {
2440                channel_length: Some(2),
2441                average_length: Some(2),
2442                ma_length: Some(2),
2443                factor: Some(0.005),
2444            },
2445            WavetrendParams {
2446                channel_length: Some(5),
2447                average_length: Some(7),
2448                ma_length: Some(3),
2449                factor: Some(0.01),
2450            },
2451            WavetrendParams {
2452                channel_length: Some(10),
2453                average_length: Some(15),
2454                ma_length: Some(5),
2455                factor: Some(0.02),
2456            },
2457            WavetrendParams {
2458                channel_length: Some(20),
2459                average_length: Some(25),
2460                ma_length: Some(7),
2461                factor: Some(0.025),
2462            },
2463            WavetrendParams {
2464                channel_length: Some(30),
2465                average_length: Some(40),
2466                ma_length: Some(10),
2467                factor: Some(0.03),
2468            },
2469            WavetrendParams {
2470                channel_length: Some(50),
2471                average_length: Some(60),
2472                ma_length: Some(15),
2473                factor: Some(0.04),
2474            },
2475            WavetrendParams {
2476                channel_length: Some(100),
2477                average_length: Some(120),
2478                ma_length: Some(20),
2479                factor: Some(0.05),
2480            },
2481            WavetrendParams {
2482                channel_length: Some(7),
2483                average_length: Some(11),
2484                ma_length: Some(3),
2485                factor: Some(0.013),
2486            },
2487            WavetrendParams {
2488                channel_length: Some(13),
2489                average_length: Some(17),
2490                ma_length: Some(5),
2491                factor: Some(0.017),
2492            },
2493            WavetrendParams {
2494                channel_length: Some(9),
2495                average_length: Some(3),
2496                ma_length: Some(12),
2497                factor: Some(0.015),
2498            },
2499            WavetrendParams {
2500                channel_length: Some(15),
2501                average_length: Some(15),
2502                ma_length: Some(15),
2503                factor: Some(0.015),
2504            },
2505            WavetrendParams {
2506                channel_length: Some(9),
2507                average_length: Some(12),
2508                ma_length: Some(3),
2509                factor: Some(0.0001),
2510            },
2511            WavetrendParams {
2512                channel_length: Some(9),
2513                average_length: Some(12),
2514                ma_length: Some(3),
2515                factor: Some(1.0),
2516            },
2517            WavetrendParams {
2518                channel_length: Some(3),
2519                average_length: Some(5),
2520                ma_length: Some(1),
2521                factor: Some(0.008),
2522            },
2523            WavetrendParams {
2524                channel_length: Some(8),
2525                average_length: Some(13),
2526                ma_length: Some(2),
2527                factor: Some(0.021),
2528            },
2529            WavetrendParams {
2530                channel_length: Some(21),
2531                average_length: Some(34),
2532                ma_length: Some(8),
2533                factor: Some(0.034),
2534            },
2535        ];
2536
2537        for (param_idx, params) in test_params.iter().enumerate() {
2538            let input = WavetrendInput::from_candles(&candles, "hlc3", params.clone());
2539            let output = wavetrend_with_kernel(&input, kernel)?;
2540
2541            for (i, &val) in output.wt1.iter().enumerate() {
2542                if val.is_nan() {
2543                    continue;
2544                }
2545
2546                let bits = val.to_bits();
2547
2548                if bits == 0x11111111_11111111 {
2549                    panic!(
2550						"[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
2551						 in wt1 output with params: channel_length={}, average_length={}, ma_length={}, factor={} (param set {})",
2552						test_name, val, bits, i,
2553						params.channel_length.unwrap_or(9),
2554						params.average_length.unwrap_or(12),
2555						params.ma_length.unwrap_or(3),
2556						params.factor.unwrap_or(0.015),
2557						param_idx
2558					);
2559                }
2560
2561                if bits == 0x22222222_22222222 {
2562                    panic!(
2563						"[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
2564						 in wt1 output with params: channel_length={}, average_length={}, ma_length={}, factor={} (param set {})",
2565						test_name, val, bits, i,
2566						params.channel_length.unwrap_or(9),
2567						params.average_length.unwrap_or(12),
2568						params.ma_length.unwrap_or(3),
2569						params.factor.unwrap_or(0.015),
2570						param_idx
2571					);
2572                }
2573
2574                if bits == 0x33333333_33333333 {
2575                    panic!(
2576						"[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
2577						 in wt1 output with params: channel_length={}, average_length={}, ma_length={}, factor={} (param set {})",
2578						test_name, val, bits, i,
2579						params.channel_length.unwrap_or(9),
2580						params.average_length.unwrap_or(12),
2581						params.ma_length.unwrap_or(3),
2582						params.factor.unwrap_or(0.015),
2583						param_idx
2584					);
2585                }
2586            }
2587
2588            for (i, &val) in output.wt2.iter().enumerate() {
2589                if val.is_nan() {
2590                    continue;
2591                }
2592
2593                let bits = val.to_bits();
2594
2595                if bits == 0x11111111_11111111 {
2596                    panic!(
2597						"[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
2598						 in wt2 output with params: channel_length={}, average_length={}, ma_length={}, factor={} (param set {})",
2599						test_name, val, bits, i,
2600						params.channel_length.unwrap_or(9),
2601						params.average_length.unwrap_or(12),
2602						params.ma_length.unwrap_or(3),
2603						params.factor.unwrap_or(0.015),
2604						param_idx
2605					);
2606                }
2607
2608                if bits == 0x22222222_22222222 {
2609                    panic!(
2610						"[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
2611						 in wt2 output with params: channel_length={}, average_length={}, ma_length={}, factor={} (param set {})",
2612						test_name, val, bits, i,
2613						params.channel_length.unwrap_or(9),
2614						params.average_length.unwrap_or(12),
2615						params.ma_length.unwrap_or(3),
2616						params.factor.unwrap_or(0.015),
2617						param_idx
2618					);
2619                }
2620
2621                if bits == 0x33333333_33333333 {
2622                    panic!(
2623						"[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
2624						 in wt2 output with params: channel_length={}, average_length={}, ma_length={}, factor={} (param set {})",
2625						test_name, val, bits, i,
2626						params.channel_length.unwrap_or(9),
2627						params.average_length.unwrap_or(12),
2628						params.ma_length.unwrap_or(3),
2629						params.factor.unwrap_or(0.015),
2630						param_idx
2631					);
2632                }
2633            }
2634
2635            for (i, &val) in output.wt_diff.iter().enumerate() {
2636                if val.is_nan() {
2637                    continue;
2638                }
2639
2640                let bits = val.to_bits();
2641
2642                if bits == 0x11111111_11111111 {
2643                    panic!(
2644						"[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
2645						 in wt_diff output with params: channel_length={}, average_length={}, ma_length={}, factor={} (param set {})",
2646						test_name, val, bits, i,
2647						params.channel_length.unwrap_or(9),
2648						params.average_length.unwrap_or(12),
2649						params.ma_length.unwrap_or(3),
2650						params.factor.unwrap_or(0.015),
2651						param_idx
2652					);
2653                }
2654
2655                if bits == 0x22222222_22222222 {
2656                    panic!(
2657						"[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
2658						 in wt_diff output with params: channel_length={}, average_length={}, ma_length={}, factor={} (param set {})",
2659						test_name, val, bits, i,
2660						params.channel_length.unwrap_or(9),
2661						params.average_length.unwrap_or(12),
2662						params.ma_length.unwrap_or(3),
2663						params.factor.unwrap_or(0.015),
2664						param_idx
2665					);
2666                }
2667
2668                if bits == 0x33333333_33333333 {
2669                    panic!(
2670						"[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
2671						 in wt_diff output with params: channel_length={}, average_length={}, ma_length={}, factor={} (param set {})",
2672						test_name, val, bits, i,
2673						params.channel_length.unwrap_or(9),
2674						params.average_length.unwrap_or(12),
2675						params.ma_length.unwrap_or(3),
2676						params.factor.unwrap_or(0.015),
2677						param_idx
2678					);
2679                }
2680            }
2681        }
2682
2683        Ok(())
2684    }
2685
2686    #[cfg(not(debug_assertions))]
2687    fn check_wavetrend_no_poison(
2688        _test_name: &str,
2689        _kernel: Kernel,
2690    ) -> Result<(), Box<dyn std::error::Error>> {
2691        Ok(())
2692    }
2693
2694    fn check_wavetrend_streaming(
2695        test_name: &str,
2696        kernel: Kernel,
2697    ) -> Result<(), Box<dyn std::error::Error>> {
2698        skip_if_unsupported!(kernel, test_name);
2699
2700        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2701        let candles = read_candles_from_csv(file_path)?;
2702
2703        let channel_length = 9;
2704        let average_length = 12;
2705        let ma_length = 3;
2706        let factor = 0.015;
2707
2708        let input = WavetrendInput::from_candles(
2709            &candles,
2710            "hlc3",
2711            WavetrendParams {
2712                channel_length: Some(channel_length),
2713                average_length: Some(average_length),
2714                ma_length: Some(ma_length),
2715                factor: Some(factor),
2716            },
2717        );
2718        let full_output = wavetrend_with_kernel(&input, kernel)?;
2719
2720        let mut stream = WavetrendStream::try_new(WavetrendParams {
2721            channel_length: Some(channel_length),
2722            average_length: Some(average_length),
2723            ma_length: Some(ma_length),
2724            factor: Some(factor),
2725        })?;
2726
2727        let mut wt1_stream = Vec::with_capacity(candles.hlc3.len());
2728        let mut wt2_stream = Vec::with_capacity(candles.hlc3.len());
2729        let mut diff_stream = Vec::with_capacity(candles.hlc3.len());
2730        for &price in &candles.hlc3 {
2731            match stream.update(price) {
2732                Some((wt1, wt2, diff)) => {
2733                    wt1_stream.push(wt1);
2734                    wt2_stream.push(wt2);
2735                    diff_stream.push(diff);
2736                }
2737                None => {
2738                    wt1_stream.push(f64::NAN);
2739                    wt2_stream.push(f64::NAN);
2740                    diff_stream.push(f64::NAN);
2741                }
2742            }
2743        }
2744
2745        let mut first_non_nan = None;
2746        for (i, &b) in full_output.wt1.iter().enumerate() {
2747            if !b.is_nan() {
2748                first_non_nan = Some(i);
2749                break;
2750            }
2751        }
2752        let start = first_non_nan.unwrap_or(0);
2753        assert_eq!(full_output.wt1.len(), wt1_stream.len());
2754        for (i, (&b, &s)) in full_output
2755            .wt1
2756            .iter()
2757            .zip(wt1_stream.iter())
2758            .enumerate()
2759            .skip(start)
2760        {
2761            if b.is_nan() || s.is_nan() {
2762                continue;
2763            }
2764            let diff = (b - s).abs();
2765            assert!(
2766                diff < 1e-9,
2767                "[{}] Wavetrend streaming wt1 f64 mismatch at idx {}: full={}, stream={}, diff={}",
2768                test_name,
2769                i,
2770                b,
2771                s,
2772                diff
2773            );
2774        }
2775        for (i, (&b, &s)) in full_output.wt2.iter().zip(wt2_stream.iter()).enumerate() {
2776            if b.is_nan() || s.is_nan() {
2777                continue;
2778            }
2779            let diff = (b - s).abs();
2780            assert!(
2781                diff < 1e-9,
2782                "[{}] Wavetrend streaming wt2 f64 mismatch at idx {}: full={}, stream={}, diff={}",
2783                test_name,
2784                i,
2785                b,
2786                s,
2787                diff
2788            );
2789        }
2790        for (i, (&b, &s)) in full_output
2791            .wt_diff
2792            .iter()
2793            .zip(diff_stream.iter())
2794            .enumerate()
2795        {
2796            if b.is_nan() || s.is_nan() {
2797                continue;
2798            }
2799            let diff = (b - s).abs();
2800            assert!(
2801				diff < 1e-9,
2802				"[{}] Wavetrend streaming wt_diff f64 mismatch at idx {}: full={}, stream={}, diff={}",
2803				test_name,
2804				i,
2805				b,
2806				s,
2807				diff
2808			);
2809        }
2810        Ok(())
2811    }
2812
2813    #[cfg(feature = "proptest")]
2814    fn check_wavetrend_property(
2815        test_name: &str,
2816        kernel: Kernel,
2817    ) -> Result<(), Box<dyn std::error::Error>> {
2818        use proptest::prelude::*;
2819        skip_if_unsupported!(kernel, test_name);
2820
2821        let strat = (2usize..=30, 2usize..=30, 1usize..=10, 0.001f64..1.0f64).prop_flat_map(
2822            |(channel_len, average_len, ma_len, factor)| {
2823                let min_len = channel_len + average_len + ma_len + 20;
2824                (min_len..400).prop_flat_map(move |data_len| {
2825                    (
2826                        prop::collection::vec(
2827                            (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
2828                            data_len,
2829                        ),
2830                        Just(channel_len),
2831                        Just(average_len),
2832                        Just(ma_len),
2833                        Just(factor),
2834                    )
2835                })
2836            },
2837        );
2838
2839        proptest::test_runner::TestRunner::default()
2840            .run(
2841                &strat,
2842                |(data, channel_len, average_len, ma_len, factor)| {
2843                    let params = WavetrendParams {
2844                        channel_length: Some(channel_len),
2845                        average_length: Some(average_len),
2846                        ma_length: Some(ma_len),
2847                        factor: Some(factor),
2848                    };
2849                    let input = WavetrendInput::from_slice(&data, params);
2850
2851                    let output = wavetrend_with_kernel(&input, kernel).unwrap();
2852                    let ref_output = wavetrend_with_kernel(&input, Kernel::Scalar).unwrap();
2853
2854                    let first_valid = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
2855                    let expected_warmup =
2856                        first_valid + channel_len - 1 + average_len - 1 + ma_len - 1;
2857
2858                    for i in expected_warmup.min(data.len())..data.len() {
2859                        if output.wt1[i].is_finite() && output.wt2[i].is_finite() {
2860                            let expected_diff = output.wt2[i] - output.wt1[i];
2861                            let actual_diff = output.wt_diff[i];
2862                            prop_assert!(
2863                                (actual_diff - expected_diff).abs() <= 1e-9,
2864                                "WT_DIFF mismatch at idx {}: expected {}, got {}",
2865                                i,
2866                                expected_diff,
2867                                actual_diff
2868                            );
2869                        }
2870                    }
2871
2872                    let valid_start = expected_warmup.min(data.len());
2873                    let valid_wt1: Vec<f64> = output.wt1[valid_start..]
2874                        .iter()
2875                        .filter(|&&x| x.is_finite())
2876                        .copied()
2877                        .collect();
2878                    let valid_wt2: Vec<f64> = output.wt2[valid_start..]
2879                        .iter()
2880                        .filter(|&&x| x.is_finite())
2881                        .copied()
2882                        .collect();
2883
2884                    if valid_wt1.len() > 10 && valid_wt2.len() > 10 && ma_len > 1 {
2885                        let mut wt1_changes = 0.0;
2886                        let mut wt2_changes = 0.0;
2887                        for i in 1..valid_wt1.len().min(valid_wt2.len()) {
2888                            wt1_changes += (valid_wt1[i] - valid_wt1[i - 1]).abs();
2889                            wt2_changes += (valid_wt2[i] - valid_wt2[i - 1]).abs();
2890                        }
2891
2892                        if wt1_changes > 1e-6 {
2893                            prop_assert!(
2894                                wt2_changes <= wt1_changes * 1.1,
2895                                "WT2 should be smoother: wt1_changes={}, wt2_changes={}",
2896                                wt1_changes,
2897                                wt2_changes
2898                            );
2899                        }
2900                    }
2901
2902                    if data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-9)
2903                        && data.len() > valid_start + 10
2904                    {
2905                        let last_10_wt1: Vec<f64> = output.wt1[output.wt1.len() - 10..]
2906                            .iter()
2907                            .filter(|&&x| x.is_finite())
2908                            .copied()
2909                            .collect();
2910                        if last_10_wt1.len() >= 5 {
2911                            let avg_wt1: f64 =
2912                                last_10_wt1.iter().sum::<f64>() / last_10_wt1.len() as f64;
2913                            prop_assert!(
2914                                avg_wt1.abs() <= 1.0,
2915                                "Constant price should give near-zero oscillator: avg_wt1={}",
2916                                avg_wt1
2917                            );
2918                        }
2919                    }
2920
2921                    if factor < 0.5 && valid_start < data.len() {
2922                        let params_double = WavetrendParams {
2923                            channel_length: Some(channel_len),
2924                            average_length: Some(average_len),
2925                            ma_length: Some(ma_len),
2926                            factor: Some(factor * 2.0),
2927                        };
2928                        let input_double = WavetrendInput::from_slice(&data, params_double);
2929                        let output_double = wavetrend_with_kernel(&input_double, kernel).unwrap();
2930
2931                        let check_end = data.len().min(valid_start + 20);
2932                        let mut checked_count = 0;
2933                        for i in valid_start..check_end {
2934                            if output.wt1[i].is_finite()
2935                                && output_double.wt1[i].is_finite()
2936                                && output.wt1[i].abs() > 0.1
2937                            {
2938                                let ratio = output_double.wt1[i] / output.wt1[i];
2939
2940                                prop_assert!(
2941								(ratio - 0.5).abs() <= 0.35,
2942								"Factor doubling should roughly halve WT1 at idx {}: original={}, doubled={}, ratio={}",
2943								i, output.wt1[i], output_double.wt1[i], ratio
2944							);
2945                                checked_count += 1;
2946                                if checked_count >= 5 {
2947                                    break;
2948                                }
2949                            }
2950                        }
2951                    }
2952
2953                    if ma_len == 1 {
2954                        for i in valid_start..data.len() {
2955                            if output.wt1[i].is_finite() && output.wt2[i].is_finite() {
2956                                prop_assert!(
2957                                    (output.wt1[i] - output.wt2[i]).abs() <= 1e-9,
2958                                    "When ma_len=1, WT2 should equal WT1 at idx {}: wt1={}, wt2={}",
2959                                    i,
2960                                    output.wt1[i],
2961                                    output.wt2[i]
2962                                );
2963                            }
2964                        }
2965                    }
2966
2967                    for i in 0..data.len() {
2968                        let wt1 = output.wt1[i];
2969                        let wt1_ref = ref_output.wt1[i];
2970                        let wt2 = output.wt2[i];
2971                        let wt2_ref = ref_output.wt2[i];
2972                        let diff = output.wt_diff[i];
2973                        let diff_ref = ref_output.wt_diff[i];
2974
2975                        if wt1.is_nan() || wt1_ref.is_nan() {
2976                            prop_assert!(
2977                                wt1.is_nan() && wt1_ref.is_nan(),
2978                                "NaN mismatch for WT1 at idx {}: kernel={:?}, ref={:?}",
2979                                i,
2980                                wt1,
2981                                wt1_ref
2982                            );
2983                        } else {
2984                            let wt1_bits = wt1.to_bits();
2985                            let wt1_ref_bits = wt1_ref.to_bits();
2986                            let ulp_diff = wt1_bits.abs_diff(wt1_ref_bits);
2987                            prop_assert!(
2988                                (wt1 - wt1_ref).abs() <= 1e-9 || ulp_diff <= 4,
2989                                "WT1 mismatch at idx {}: kernel={}, ref={} (ULP={})",
2990                                i,
2991                                wt1,
2992                                wt1_ref,
2993                                ulp_diff
2994                            );
2995                        }
2996
2997                        if wt2.is_nan() || wt2_ref.is_nan() {
2998                            prop_assert!(
2999                                wt2.is_nan() && wt2_ref.is_nan(),
3000                                "NaN mismatch for WT2 at idx {}: kernel={:?}, ref={:?}",
3001                                i,
3002                                wt2,
3003                                wt2_ref
3004                            );
3005                        } else {
3006                            let wt2_bits = wt2.to_bits();
3007                            let wt2_ref_bits = wt2_ref.to_bits();
3008                            let ulp_diff = wt2_bits.abs_diff(wt2_ref_bits);
3009                            prop_assert!(
3010                                (wt2 - wt2_ref).abs() <= 1e-9 || ulp_diff <= 4,
3011                                "WT2 mismatch at idx {}: kernel={}, ref={} (ULP={})",
3012                                i,
3013                                wt2,
3014                                wt2_ref,
3015                                ulp_diff
3016                            );
3017                        }
3018
3019                        if diff.is_nan() || diff_ref.is_nan() {
3020                            prop_assert!(
3021                                diff.is_nan() && diff_ref.is_nan(),
3022                                "NaN mismatch for WT_DIFF at idx {}: kernel={:?}, ref={:?}",
3023                                i,
3024                                diff,
3025                                diff_ref
3026                            );
3027                        } else {
3028                            let diff_bits = diff.to_bits();
3029                            let diff_ref_bits = diff_ref.to_bits();
3030                            let ulp_diff = diff_bits.abs_diff(diff_ref_bits);
3031                            prop_assert!(
3032                                (diff - diff_ref).abs() <= 1e-9 || ulp_diff <= 4,
3033                                "WT_DIFF mismatch at idx {}: kernel={}, ref={} (ULP={})",
3034                                i,
3035                                diff,
3036                                diff_ref,
3037                                ulp_diff
3038                            );
3039                        }
3040                    }
3041
3042                    for i in 0..expected_warmup.min(data.len()) {
3043                        prop_assert!(
3044                            output.wt1[i].is_nan(),
3045                            "WT1 should be NaN during warmup at idx {}: got {}",
3046                            i,
3047                            output.wt1[i]
3048                        );
3049                        prop_assert!(
3050                            output.wt2[i].is_nan(),
3051                            "WT2 should be NaN during warmup at idx {}: got {}",
3052                            i,
3053                            output.wt2[i]
3054                        );
3055                        prop_assert!(
3056                            output.wt_diff[i].is_nan(),
3057                            "WT_DIFF should be NaN during warmup at idx {}: got {}",
3058                            i,
3059                            output.wt_diff[i]
3060                        );
3061                    }
3062
3063                    Ok(())
3064                },
3065            )
3066            .unwrap();
3067
3068        Ok(())
3069    }
3070
// Generates `#[test]` wrappers for single-series check functions.
//
// For every `$test_fn` in the comma-separated list (a trailing comma is
// accepted) this expands to three zero-argument tests:
//   * `<test_fn>_scalar_f64`  -> `$test_fn(<name>, Kernel::Scalar)`
//   * `<test_fn>_avx2_f64`    -> `$test_fn(<name>, Kernel::Avx2)`
//   * `<test_fn>_avx512_f64`  -> `$test_fn(<name>, Kernel::Avx512)`
// where `<name>` is the generated test name as a string.
//
// The AVX wrappers carry their own `#[cfg]` each: an attribute written in
// front of a `$( ... )*` repetition attaches only to the first item the
// repetition expands to, so a single attribute outside the repetition would
// leave every other AVX wrapper ungated.
//
// NOTE(review): the wrapper discards the value returned by `$test_fn`, so a
// check that returns `Err` (rather than panicking) does not fail the test.
// Confirm that this is intended.
macro_rules! generate_all_wavetrend_tests {
    ($($test_fn:ident),* $(,)?) => {
        paste::paste! {
            $(
                #[test]
                fn [<$test_fn _scalar_f64>]() {
                    let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
                }

                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$test_fn _avx2_f64>]() {
                    let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
                }

                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$test_fn _avx512_f64>]() {
                    let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
                }
            )*
        }
    }
}
3094
3095    generate_all_wavetrend_tests!(
3096        check_wavetrend_partial_params,
3097        check_wavetrend_accuracy,
3098        check_wavetrend_default_candles,
3099        check_wavetrend_zero_channel,
3100        check_wavetrend_channel_exceeds_length,
3101        check_wavetrend_very_small_dataset,
3102        check_wavetrend_nan_handling,
3103        check_wavetrend_streaming,
3104        check_wavetrend_no_poison
3105    );
3106
3107    #[cfg(feature = "proptest")]
3108    generate_all_wavetrend_tests!(check_wavetrend_property);
3109
3110    fn check_batch_default_row(
3111        test: &str,
3112        kernel: Kernel,
3113    ) -> Result<(), Box<dyn std::error::Error>> {
3114        skip_if_unsupported!(kernel, test);
3115        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3116        let c = read_candles_from_csv(file)?;
3117
3118        let output = WavetrendBatchBuilder::new()
3119            .kernel(kernel)
3120            .apply_candles(&c, "hlc3")?;
3121
3122        let def = WavetrendParams::default();
3123        let (wt1_row, wt2_row, diff_row) = output.values_for(&def).expect("default row missing");
3124
3125        assert_eq!(wt1_row.len(), c.close.len());
3126        assert_eq!(wt2_row.len(), c.close.len());
3127        assert_eq!(diff_row.len(), c.close.len());
3128
3129        let expected_wt1 = [
3130            -29.02058232514538,
3131            -28.207769813591664,
3132            -31.991808642927193,
3133            -31.9218051759519,
3134            -44.956245952893866,
3135        ];
3136        let expected_wt2 = [
3137            -30.651043230696555,
3138            -28.686329669808583,
3139            -29.740053593887932,
3140            -30.707127877490105,
3141            -36.2899532572575,
3142        ];
3143
3144        let start = wt1_row.len().saturating_sub(5);
3145        for (i, &v) in wt1_row[start..].iter().enumerate() {
3146            assert!(
3147                (v - expected_wt1[i]).abs() < 1e-8,
3148                "[{test}] default-row WT1 mismatch at idx {i}: {v} vs {expected}",
3149                test = test,
3150                i = i,
3151                v = v,
3152                expected = expected_wt1[i]
3153            );
3154        }
3155        for (i, &v) in wt2_row[start..].iter().enumerate() {
3156            assert!(
3157                (v - expected_wt2[i]).abs() < 1e-6,
3158                "[{test}] default-row WT2 mismatch at idx {i}: {v} vs {expected}",
3159                test = test,
3160                i = i,
3161                v = v,
3162                expected = expected_wt2[i]
3163            );
3164        }
3165        Ok(())
3166    }
3167
3168    #[cfg(debug_assertions)]
3169    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn std::error::Error>> {
3170        skip_if_unsupported!(kernel, test);
3171
3172        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3173        let c = read_candles_from_csv(file)?;
3174
3175        let test_configs = vec![
3176            (2, 10, 2, 3, 12, 3, 1, 5, 1, 0.005, 0.015, 0.005),
3177            (5, 25, 5, 10, 30, 5, 2, 8, 2, 0.01, 0.03, 0.01),
3178            (20, 60, 10, 25, 75, 10, 5, 15, 5, 0.02, 0.05, 0.015),
3179            (2, 5, 1, 2, 5, 1, 1, 3, 1, 0.001, 0.005, 0.001),
3180            (10, 30, 10, 15, 45, 15, 3, 9, 3, 0.015, 0.045, 0.015),
3181            (50, 100, 25, 60, 120, 30, 10, 20, 5, 0.03, 0.06, 0.03),
3182            (9, 9, 0, 12, 12, 0, 3, 3, 0, 0.015, 0.015, 0.0),
3183            (1, 3, 1, 1, 3, 1, 1, 2, 1, 0.001, 0.003, 0.001),
3184        ];
3185
3186        for (cfg_idx, config) in test_configs.iter().enumerate() {
3187            let output = WavetrendBatchBuilder::new()
3188                .kernel(kernel)
3189                .channel_range(config.0, config.1, config.2)
3190                .avg_range(config.3, config.4, config.5)
3191                .ma_range(config.6, config.7, config.8)
3192                .factor_range(config.9, config.10, config.11)
3193                .apply_candles(&c, "hlc3")?;
3194
3195            for (idx, &val) in output.wt1.iter().enumerate() {
3196                if val.is_nan() {
3197                    continue;
3198                }
3199
3200                let bits = val.to_bits();
3201                let row = idx / output.cols;
3202                let col = idx % output.cols;
3203                let combo = &output.combos[row];
3204
3205                if bits == 0x11111111_11111111 {
3206                    panic!(
3207						"[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
3208						 at row {} col {} (flat index {}) in wt1 output with params: channel_length={}, average_length={}, ma_length={}, factor={}",
3209						test, cfg_idx, val, bits, row, col, idx,
3210						combo.channel_length.unwrap_or(9),
3211						combo.average_length.unwrap_or(12),
3212						combo.ma_length.unwrap_or(3),
3213						combo.factor.unwrap_or(0.015)
3214					);
3215                }
3216
3217                if bits == 0x22222222_22222222 {
3218                    panic!(
3219						"[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
3220						 at row {} col {} (flat index {}) in wt1 output with params: channel_length={}, average_length={}, ma_length={}, factor={}",
3221						test, cfg_idx, val, bits, row, col, idx,
3222						combo.channel_length.unwrap_or(9),
3223						combo.average_length.unwrap_or(12),
3224						combo.ma_length.unwrap_or(3),
3225						combo.factor.unwrap_or(0.015)
3226					);
3227                }
3228
3229                if bits == 0x33333333_33333333 {
3230                    panic!(
3231						"[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
3232						 at row {} col {} (flat index {}) in wt1 output with params: channel_length={}, average_length={}, ma_length={}, factor={}",
3233						test, cfg_idx, val, bits, row, col, idx,
3234						combo.channel_length.unwrap_or(9),
3235						combo.average_length.unwrap_or(12),
3236						combo.ma_length.unwrap_or(3),
3237						combo.factor.unwrap_or(0.015)
3238					);
3239                }
3240            }
3241
3242            for (idx, &val) in output.wt2.iter().enumerate() {
3243                if val.is_nan() {
3244                    continue;
3245                }
3246
3247                let bits = val.to_bits();
3248                let row = idx / output.cols;
3249                let col = idx % output.cols;
3250                let combo = &output.combos[row];
3251
3252                if bits == 0x11111111_11111111 {
3253                    panic!(
3254						"[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
3255						 at row {} col {} (flat index {}) in wt2 output with params: channel_length={}, average_length={}, ma_length={}, factor={}",
3256						test, cfg_idx, val, bits, row, col, idx,
3257						combo.channel_length.unwrap_or(9),
3258						combo.average_length.unwrap_or(12),
3259						combo.ma_length.unwrap_or(3),
3260						combo.factor.unwrap_or(0.015)
3261					);
3262                }
3263
3264                if bits == 0x22222222_22222222 {
3265                    panic!(
3266						"[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
3267						 at row {} col {} (flat index {}) in wt2 output with params: channel_length={}, average_length={}, ma_length={}, factor={}",
3268						test, cfg_idx, val, bits, row, col, idx,
3269						combo.channel_length.unwrap_or(9),
3270						combo.average_length.unwrap_or(12),
3271						combo.ma_length.unwrap_or(3),
3272						combo.factor.unwrap_or(0.015)
3273					);
3274                }
3275
3276                if bits == 0x33333333_33333333 {
3277                    panic!(
3278						"[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
3279						 at row {} col {} (flat index {}) in wt2 output with params: channel_length={}, average_length={}, ma_length={}, factor={}",
3280						test, cfg_idx, val, bits, row, col, idx,
3281						combo.channel_length.unwrap_or(9),
3282						combo.average_length.unwrap_or(12),
3283						combo.ma_length.unwrap_or(3),
3284						combo.factor.unwrap_or(0.015)
3285					);
3286                }
3287            }
3288
3289            for (idx, &val) in output.wt_diff.iter().enumerate() {
3290                if val.is_nan() {
3291                    continue;
3292                }
3293
3294                let bits = val.to_bits();
3295                let row = idx / output.cols;
3296                let col = idx % output.cols;
3297                let combo = &output.combos[row];
3298
3299                if bits == 0x11111111_11111111 {
3300                    panic!(
3301						"[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
3302						 at row {} col {} (flat index {}) in wt_diff output with params: channel_length={}, average_length={}, ma_length={}, factor={}",
3303						test, cfg_idx, val, bits, row, col, idx,
3304						combo.channel_length.unwrap_or(9),
3305						combo.average_length.unwrap_or(12),
3306						combo.ma_length.unwrap_or(3),
3307						combo.factor.unwrap_or(0.015)
3308					);
3309                }
3310
3311                if bits == 0x22222222_22222222 {
3312                    panic!(
3313						"[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
3314						 at row {} col {} (flat index {}) in wt_diff output with params: channel_length={}, average_length={}, ma_length={}, factor={}",
3315						test, cfg_idx, val, bits, row, col, idx,
3316						combo.channel_length.unwrap_or(9),
3317						combo.average_length.unwrap_or(12),
3318						combo.ma_length.unwrap_or(3),
3319						combo.factor.unwrap_or(0.015)
3320					);
3321                }
3322
3323                if bits == 0x33333333_33333333 {
3324                    panic!(
3325						"[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
3326						 at row {} col {} (flat index {}) in wt_diff output with params: channel_length={}, average_length={}, ma_length={}, factor={}",
3327						test, cfg_idx, val, bits, row, col, idx,
3328						combo.channel_length.unwrap_or(9),
3329						combo.average_length.unwrap_or(12),
3330						combo.ma_length.unwrap_or(3),
3331						combo.factor.unwrap_or(0.015)
3332					);
3333                }
3334            }
3335        }
3336
3337        Ok(())
3338    }
3339
3340    #[cfg(not(debug_assertions))]
3341    fn check_batch_no_poison(
3342        _test: &str,
3343        _kernel: Kernel,
3344    ) -> Result<(), Box<dyn std::error::Error>> {
3345        Ok(())
3346    }
3347
3348    macro_rules! gen_batch_tests {
3349        ($fn_name:ident) => {
3350            paste::paste! {
3351                #[test] fn [<$fn_name _scalar>]()      {
3352                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
3353                }
3354                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3355                #[test] fn [<$fn_name _avx2>]()        {
3356                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
3357                }
3358                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3359                #[test] fn [<$fn_name _avx512>]()      {
3360                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
3361                }
3362                #[test] fn [<$fn_name _auto_detect>]() {
3363                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
3364                }
3365            }
3366        };
3367    }
3368    gen_batch_tests!(check_batch_default_row);
3369    gen_batch_tests!(check_batch_no_poison);
3370}
3371
// Python binding `wavetrend(data, channel_length, average_length, ma_length,
// factor, kernel=None)`.
//
// Borrows the 1-D f64 input as a slice, validates the optional kernel name
// (non-batch mode), runs `wavetrend_with_kernel` with the GIL released, and
// returns the `(wt1, wt2, wt_diff)` vectors as three NumPy arrays.
// A computation error is converted to `ValueError` carrying its message;
// failures from `as_slice` and `validate_kernel` are propagated as they are.
#[cfg(feature = "python")]
#[pyfunction(name = "wavetrend")]
#[pyo3(signature = (data, channel_length, average_length, ma_length, factor, kernel=None))]
pub fn wavetrend_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    channel_length: usize,
    average_length: usize,
    ma_length: usize,
    factor: f64,
    kernel: Option<&str>,
) -> PyResult<(
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
)> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let prices = data.as_slice()?;
    let chosen_kernel = validate_kernel(kernel, false)?;

    let input = WavetrendInput::from_slice(
        prices,
        WavetrendParams {
            channel_length: Some(channel_length),
            average_length: Some(average_length),
            ma_length: Some(ma_length),
            factor: Some(factor),
        },
    );

    // Only the three result vectors leave the GIL-free section.
    let computed = py.allow_threads(|| {
        wavetrend_with_kernel(&input, chosen_kernel).map(|out| (out.wt1, out.wt2, out.wt_diff))
    });

    let (wt1, wt2, wt_diff) = match computed {
        Ok(series) => series,
        Err(err) => return Err(PyValueError::new_err(err.to_string())),
    };

    Ok((
        wt1.into_pyarray(py),
        wt2.into_pyarray(py),
        wt_diff.into_pyarray(py),
    ))
}
3411
// Python class `WavetrendStream`: a thin wrapper that owns a Rust
// `WavetrendStream` and forwards calls to it.
#[cfg(feature = "python")]
#[pyclass(name = "WavetrendStream")]
pub struct WavetrendStreamPy {
    stream: WavetrendStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl WavetrendStreamPy {
    // Python constructor. Builds the inner stream from the four parameters;
    // a failure from `WavetrendStream::try_new` is raised as `ValueError`
    // carrying the error's message.
    #[new]
    fn new(
        channel_length: usize,
        average_length: usize,
        ma_length: usize,
        factor: f64,
    ) -> PyResult<Self> {
        let built = WavetrendStream::try_new(WavetrendParams {
            channel_length: Some(channel_length),
            average_length: Some(average_length),
            ma_length: Some(ma_length),
            factor: Some(factor),
        });

        match built {
            Ok(stream) => Ok(Self { stream }),
            Err(err) => Err(PyValueError::new_err(err.to_string())),
        }
    }

    // Feeds one value to the inner stream and returns whatever it yields:
    // `None`, or a triple of f64 (presumably wt1, wt2, wt_diff; confirm
    // against `WavetrendStream::update`).
    fn update(&mut self, value: f64) -> Option<(f64, f64, f64)> {
        let Self { stream } = self;
        stream.update(value)
    }
}
3443
// Python binding `wavetrend_batch(data, channel_length_range,
// average_length_range, ma_length_range, factor_range, kernel=None)`.
//
// Expands the four ranges into a parameter grid, allocates three flat
// NumPy buffers of `rows * cols` f64 (rows = grid size, cols = input length),
// fills them through `wavetrend_batch_inner_into` with the GIL released, and
// returns a dict with:
//   "wt1", "wt2", "wt_diff"  -> the buffers reshaped to (rows, cols)
//   "channel_lengths", "average_lengths", "ma_lengths" -> u64 per row
//   "factors"                -> f64 per row
// Grid-expansion and computation errors become `ValueError`, as does an
// overflow of `rows * cols`.
#[cfg(feature = "python")]
#[pyfunction(name = "wavetrend_batch")]
#[pyo3(signature = (data, channel_length_range, average_length_range, ma_length_range, factor_range, kernel=None))]
pub fn wavetrend_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    channel_length_range: (usize, usize, usize),
    average_length_range: (usize, usize, usize),
    ma_length_range: (usize, usize, usize),
    factor_range: (f64, f64, f64),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let prices = data.as_slice()?;
    let requested = validate_kernel(kernel, true)?;

    let sweep = WavetrendBatchRange {
        channel_length: channel_length_range,
        average_length: average_length_range,
        ma_length: ma_length_range,
        factor: factor_range,
    };

    // The grid is expanded here only to size the output buffers.
    let rows = expand_grid(&sweep)
        .map_err(|e| PyValueError::new_err(e.to_string()))?
        .len();
    let cols = prices.len();

    let total = match rows.checked_mul(cols) {
        Some(cells) => cells,
        None => {
            return Err(PyValueError::new_err(
                "rows*cols overflow for wavetrend_batch",
            ))
        }
    };

    // The buffers are allocated without initialisation; they are expected to
    // be written by `wavetrend_batch_inner_into` below before being returned.
    let wt1_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let wt2_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let wt_diff_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };

    // The arrays were created just above and have not been handed to Python
    // yet, so these are the only views onto them.
    let out_wt1 = unsafe { wt1_arr.as_slice_mut()? };
    let out_wt2 = unsafe { wt2_arr.as_slice_mut()? };
    let out_wt_diff = unsafe { wt_diff_arr.as_slice_mut()? };

    let filled = py.allow_threads(|| {
        let batch_kernel = match requested {
            Kernel::Auto => detect_best_batch_kernel(),
            explicit => explicit,
        };
        // Map the batch kernel onto the per-row kernel passed to the inner routine.
        let row_kernel = match batch_kernel {
            Kernel::ScalarBatch => Kernel::Scalar,
            Kernel::Avx2Batch => Kernel::Avx2,
            Kernel::Avx512Batch => Kernel::Avx512,
            _ => unreachable!(),
        };
        wavetrend_batch_inner_into(
            prices,
            &sweep,
            row_kernel,
            true,
            out_wt1,
            out_wt2,
            out_wt_diff,
        )
    });
    let combos = match filled {
        Ok(combos) => combos,
        Err(err) => return Err(PyValueError::new_err(err.to_string())),
    };

    let result = PyDict::new(py);
    result.set_item("wt1", wt1_arr.reshape((rows, cols))?)?;
    result.set_item("wt2", wt2_arr.reshape((rows, cols))?)?;
    result.set_item("wt_diff", wt_diff_arr.reshape((rows, cols))?)?;

    let channel_lengths: Vec<u64> = combos
        .iter()
        .map(|combo| combo.channel_length.unwrap() as u64)
        .collect();
    result.set_item("channel_lengths", channel_lengths.into_pyarray(py))?;

    let average_lengths: Vec<u64> = combos
        .iter()
        .map(|combo| combo.average_length.unwrap() as u64)
        .collect();
    result.set_item("average_lengths", average_lengths.into_pyarray(py))?;

    let ma_lengths: Vec<u64> = combos
        .iter()
        .map(|combo| combo.ma_length.unwrap() as u64)
        .collect();
    result.set_item("ma_lengths", ma_lengths.into_pyarray(py))?;

    let factors: Vec<f64> = combos.iter().map(|combo| combo.factor.unwrap()).collect();
    result.set_item("factors", factors.into_pyarray(py))?;

    Ok(result)
}
3546
// Python binding `wavetrend_cuda_batch_dev(data_f32, channel_length_range,
// average_length_range, ma_length_range, factor_range, device_id=0)`.
//
// Fails with `ValueError("CUDA not available")` when `cuda_available()` is
// false. Otherwise it creates a `CudaWavetrend` for `device_id`, runs the
// batch with the GIL released, and returns a dict with:
//   "wt1", "wt2", "wt_diff" -> `WavetrendDeviceArrayF32Py` objects, each
//                              holding a clone of the CUDA context handle
//   "channel_lengths", "average_lengths", "ma_lengths", "factors"
//                           -> the axis values of each range
// The axis arrays are recomputed here from the `(start, end, step)` tuples;
// they are not read back from the batch result.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "wavetrend_cuda_batch_dev")]
#[pyo3(signature = (data_f32, channel_length_range, average_length_range, ma_length_range, factor_range, device_id=0))]
pub fn wavetrend_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    data_f32: numpy::PyReadonlyArray1<'py, f32>,
    channel_length_range: (usize, usize, usize),
    average_length_range: (usize, usize, usize),
    ma_length_range: (usize, usize, usize),
    factor_range: (f64, f64, f64),
    device_id: usize,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::IntoPyArray;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let prices = data_f32.as_slice()?;
    let sweep = WavetrendBatchRange {
        channel_length: channel_length_range,
        average_length: average_length_range,
        ma_length: ma_length_range,
        factor: factor_range,
    };

    let (batch, ctx, dev_id) = py.allow_threads(|| {
        let engine = match CudaWavetrend::new(device_id) {
            Ok(engine) => engine,
            Err(err) => return Err(PyValueError::new_err(err.to_string())),
        };
        let ctx = engine.context_arc();
        let dev_id = engine.device_id();
        match engine.wavetrend_batch_dev(prices, &sweep) {
            Ok(batch) => Ok((batch, ctx, dev_id)),
            Err(err) => Err(PyValueError::new_err(err.to_string())),
        }
    })?;

    let dict = PyDict::new(py);

    // Wrap each device buffer together with its own handle on the context.
    let device_outputs = [
        ("wt1", batch.wt1),
        ("wt2", batch.wt2),
        ("wt_diff", batch.wt_diff),
    ];
    for (key, inner) in device_outputs {
        let wrapped = Py::new(
            py,
            WavetrendDeviceArrayF32Py {
                inner,
                _ctx: ctx.clone(),
                device_id: dev_id,
            },
        )?;
        dict.set_item(key, wrapped)?;
    }

    // Integer axis: a zero step means "just the start value".
    let usize_axis = |(start, end, step): (usize, usize, usize)| -> Vec<usize> {
        if step == 0 {
            vec![start]
        } else {
            (start..=end).step_by(step).collect()
        }
    };
    let channel_axis = usize_axis(channel_length_range);
    let average_axis = usize_axis(average_length_range);
    let ma_axis = usize_axis(ma_length_range);

    let (f0, f1, fstep) = factor_range;
    let single_factor = fstep.abs() < f64::EPSILON || (f0 - f1).abs() < f64::EPSILON;
    let factor_axis: Vec<f64> = if single_factor {
        vec![f0]
    } else {
        // Small slack on the upper bound so accumulated rounding does not
        // drop the last step.
        // NOTE(review): with a negative `fstep` and `f0 < f1` this loop never
        // terminates. Confirm that the CUDA call above rejects such a sweep.
        let upper = f1 + fstep.abs() * 1e-12;
        let mut axis = Vec::new();
        let mut current = f0;
        while current <= upper {
            axis.push(current);
            current += fstep;
        }
        axis
    };

    dict.set_item("channel_lengths", channel_axis.into_pyarray(py))?;
    dict.set_item("average_lengths", average_axis.into_pyarray(py))?;
    dict.set_item("ma_lengths", ma_axis.into_pyarray(py))?;
    dict.set_item("factors", factor_axis.into_pyarray(py))?;

    Ok(dict)
}
3655
/// Python binding: runs WaveTrend on the GPU for many series sharing one
/// parameter set and returns the results as device-resident arrays.
///
/// `data_tm_f32` must be a 2-D `f32` array; its flat buffer is handed to
/// `CudaWavetrend::wavetrend_many_series_one_param_time_major_dev` together
/// with `(cols, rows)`.
/// NOTE(review): the name suggests a time-major layout (rows = time steps,
/// cols = series) — confirm against the CUDA wrapper.
///
/// The returned dict holds the device arrays under `"wt1"`, `"wt2"` and
/// `"wt_diff"`, plus the echo values `"rows"`, `"cols"`, `"channel_length"`,
/// `"average_length"`, `"ma_length"` and `"factor"`.
///
/// # Errors
/// Raises `ValueError` when CUDA is unavailable, when the array is not 2-D,
/// or when the CUDA wrapper reports an error; a failure of `as_slice` is
/// propagated unchanged.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "wavetrend_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, channel_length, average_length, ma_length, factor, device_id=0))]
pub fn wavetrend_cuda_many_series_one_param_dev_py<'py>(
    py: Python<'py>,
    data_tm_f32: numpy::PyReadonlyArray2<'py, f32>,
    channel_length: usize,
    average_length: usize,
    ma_length: usize,
    factor: f64,
    device_id: usize,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::PyUntypedArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let (rows, cols) = match data_tm_f32.shape() {
        &[r, c] => (r, c),
        _ => return Err(PyValueError::new_err("expected 2D array (rows x cols)")),
    };
    let flat = data_tm_f32.as_slice()?;

    let params = WavetrendParams {
        channel_length: Some(channel_length),
        average_length: Some(average_length),
        ma_length: Some(ma_length),
        factor: Some(factor),
    };

    // The GPU work runs with the GIL released.
    let gpu_result = py.allow_threads(|| {
        let engine =
            CudaWavetrend::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = engine.context_arc();
        let dev_id = engine.device_id();
        match engine.wavetrend_many_series_one_param_time_major_dev(flat, cols, rows, &params) {
            Ok((a, b, c)) => Ok((a, b, c, ctx, dev_id)),
            Err(e) => Err(PyValueError::new_err(e.to_string())),
        }
    });
    let (wt1, wt2, wt_diff, ctx, dev_id) = gpu_result?;

    let out = PyDict::new(py);

    // Each device buffer is wrapped together with its own handle on the CUDA
    // context before being exposed to Python.
    for (key, inner) in [("wt1", wt1), ("wt2", wt2), ("wt_diff", wt_diff)] {
        let wrapped = WavetrendDeviceArrayF32Py {
            inner,
            _ctx: ctx.clone(),
            device_id: dev_id,
        };
        out.set_item(key, Py::new(py, wrapped)?)?;
    }

    out.set_item("rows", rows)?;
    out.set_item("cols", cols)?;
    out.set_item("channel_length", channel_length)?;
    out.set_item("average_length", average_length)?;
    out.set_item("ma_length", ma_length)?;
    out.set_item("factor", factor)?;

    Ok(out)
}
3742
/// WASM binding: computes WaveTrend for `data` and returns one flat vector
/// of length `3 * data.len()` laid out as `[wt1 | wt2 | wt_diff]`.
///
/// # Errors
/// Returns the indicator error, rendered as a string, when
/// `wavetrend_into_slice` fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn wavetrend_js(
    data: &[f64],
    channel_length: usize,
    average_length: usize,
    ma_length: usize,
    factor: f64,
) -> Result<Vec<f64>, JsValue> {
    let n = data.len();
    let input = WavetrendInput::from_slice(
        data,
        WavetrendParams {
            channel_length: Some(channel_length),
            average_length: Some(average_length),
            ma_length: Some(ma_length),
            factor: Some(factor),
        },
    );

    // One allocation, carved into three equally sized output lanes.
    let mut packed = vec![0.0; n * 3];
    let (head, wt_diff_lane) = packed.split_at_mut(2 * n);
    let (wt1_lane, wt2_lane) = head.split_at_mut(n);

    let status = wavetrend_into_slice(wt1_lane, wt2_lane, wt_diff_lane, &input, Kernel::Auto);
    match status {
        Ok(_) => Ok(packed),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
3769
/// WASM binding: computes WaveTrend from a raw input buffer into three raw
/// output buffers, each holding `len` `f64` values.
///
/// If any two of the four buffers overlap (fully or partially), the result
/// is first computed into a scratch allocation and then copied out in the
/// order `wt1`, `wt2`, `wt_diff`. Aliasing the input with an output is
/// therefore well defined; if outputs alias each other, the later copy wins.
///
/// # Safety contract (caller side)
/// Every pointer must be non-null, aligned for `f64` and valid for `len`
/// elements; `in_ptr` must be readable and the outputs writable.
///
/// # Errors
/// Returns `"Null pointer provided"` for a null pointer,
/// `"Buffer length overflow"` when `len * 3` does not fit in `usize`, and the
/// stringified indicator error when `wavetrend_into_slice` fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn wavetrend_into(
    in_ptr: *const f64,
    wt1_ptr: *mut f64,
    wt2_ptr: *mut f64,
    wt_diff_ptr: *mut f64,
    len: usize,
    channel_length: usize,
    average_length: usize,
    ma_length: usize,
    factor: f64,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || wt1_ptr.is_null() || wt2_ptr.is_null() || wt_diff_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    let params = WavetrendParams {
        channel_length: Some(channel_length),
        average_length: Some(average_length),
        ma_length: Some(ma_length),
        factor: Some(factor),
    };

    // Byte-range overlap test across all four buffers. Exact pointer equality
    // is not enough: partially overlapping buffers, or two outputs sharing
    // memory, would otherwise yield overlapping `&mut` slices.
    let span = len.saturating_mul(std::mem::size_of::<f64>());
    let starts = [
        in_ptr as usize,
        wt1_ptr as usize,
        wt2_ptr as usize,
        wt_diff_ptr as usize,
    ];
    let needs_temp = starts.iter().enumerate().any(|(i, &a)| {
        starts[i + 1..]
            .iter()
            .any(|&b| a < b.saturating_add(span) && b < a.saturating_add(span))
    });

    if needs_temp {
        let total = len
            .checked_mul(3)
            .ok_or_else(|| JsValue::from_str("Buffer length overflow"))?;
        let mut temp = vec![0.0_f64; total];

        {
            // SAFETY: the caller guarantees `in_ptr` is valid for `len` reads.
            // The shared borrow ends with this scope, before any output
            // memory (which may alias the input) is written.
            let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };
            let input = WavetrendInput::from_slice(data, params);

            let (temp_wt1, rest) = temp.split_at_mut(len);
            let (temp_wt2, temp_wt_diff) = rest.split_at_mut(len);

            wavetrend_into_slice(temp_wt1, temp_wt2, temp_wt_diff, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }

        // SAFETY: `temp` is a fresh allocation of `3 * len` elements, so each
        // source lane is in bounds and cannot overlap a caller buffer. Raw
        // copies avoid creating `&mut` slices over possibly aliasing outputs.
        unsafe {
            let src = temp.as_ptr();
            std::ptr::copy_nonoverlapping(src, wt1_ptr, len);
            std::ptr::copy_nonoverlapping(src.add(len), wt2_ptr, len);
            std::ptr::copy_nonoverlapping(src.add(2 * len), wt_diff_ptr, len);
        }
    } else {
        // SAFETY: the caller guarantees validity for `len` elements, and the
        // overlap test above proved the four regions are pairwise disjoint,
        // so the shared and exclusive slices do not alias.
        unsafe {
            let data = std::slice::from_raw_parts(in_ptr, len);
            let input = WavetrendInput::from_slice(data, params);

            let wt1_out = std::slice::from_raw_parts_mut(wt1_ptr, len);
            let wt2_out = std::slice::from_raw_parts_mut(wt2_ptr, len);
            let wt_diff_out = std::slice::from_raw_parts_mut(wt_diff_ptr, len);

            wavetrend_into_slice(wt1_out, wt2_out, wt_diff_out, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(())
}
3828
/// WASM binding: reserves uninitialized storage for `len` `f64` values and
/// returns its address, deliberately leaking the backing `Vec`.
///
/// The caller owns the memory afterwards and is expected to hand it back to
/// `wavetrend_free` with the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn wavetrend_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` keeps the destructor from running, so the allocation
    // stays alive after this function returns.
    let mut storage = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    storage.as_mut_ptr()
}
3837
/// WASM binding: releases a buffer previously obtained from
/// `wavetrend_alloc`. A null `ptr` is ignored.
///
/// # Safety contract (caller side)
/// `ptr` must come from `wavetrend_alloc(len)` with the same `len`, and must
/// not be used or freed again afterwards.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn wavetrend_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: per the contract above, `ptr` owns an allocation with capacity
    // `len`. The vector is rebuilt with length 0 because the buffer was handed
    // out uninitialized and nothing guarantees every element was written;
    // `f64` has no destructor, so deallocation is unaffected.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
3847
/// Parameter sweep description deserialized from the JS `config` object
/// passed to `wavetrend_batch`.
///
/// Each field is forwarded unchanged into the matching field of
/// `WavetrendBatchRange`.
/// NOTE(review): the triples look like `(start, end, step)` — confirm
/// against `WavetrendBatchRange`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WavetrendBatchConfig {
    /// Sweep triple for the channel length.
    pub channel_length_range: (usize, usize, usize),
    /// Sweep triple for the average length.
    pub average_length_range: (usize, usize, usize),
    /// Sweep triple for the moving-average length.
    pub ma_length_range: (usize, usize, usize),
    /// Sweep triple for the factor.
    pub factor_range: (f64, f64, f64),
}
3856
/// Result object serialized back to JS by `wavetrend_batch`.
///
/// The `*_values` vectors are the flat batch outputs; the four parameter
/// vectors hold one entry per evaluated parameter combination.
/// NOTE(review): the value vectors are presumably row-major `rows x cols` —
/// confirm against the batch kernel.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WavetrendBatchJsOutput {
    /// Flat `wt1` output of the batch run.
    pub wt1_values: Vec<f64>,
    /// Flat `wt2` output of the batch run.
    pub wt2_values: Vec<f64>,
    /// Flat `wt_diff` output of the batch run.
    pub wt_diff_values: Vec<f64>,
    /// Channel length of each parameter combination.
    pub channel_lengths: Vec<usize>,
    /// Average length of each parameter combination.
    pub average_lengths: Vec<usize>,
    /// Moving-average length of each parameter combination.
    pub ma_lengths: Vec<usize>,
    /// Factor of each parameter combination.
    pub factors: Vec<f64>,
    /// Number of parameter combinations.
    pub rows: usize,
    /// Length of the input series.
    pub cols: usize,
}
3870
/// WASM binding (exported as `wavetrend_batch`): evaluates WaveTrend over a
/// parameter sweep and returns a serialized `WavetrendBatchJsOutput`.
///
/// `config` is deserialized into a `WavetrendBatchConfig`, whose four range
/// triples are passed straight through as the sweep definition.
///
/// # Errors
/// Returns a string error when `config` cannot be deserialized, when the
/// batch computation fails, or when the result cannot be serialized.
///
/// # Panics
/// Panics if a returned parameter combination has an unset field.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = wavetrend_batch)]
pub fn wavetrend_batch_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let parsed: WavetrendBatchConfig =
        serde_wasm_bindgen::from_value(config).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let WavetrendBatchConfig {
        channel_length_range,
        average_length_range,
        ma_length_range,
        factor_range,
    } = parsed;

    let sweep = WavetrendBatchRange {
        channel_length: channel_length_range,
        average_length: average_length_range,
        ma_length: ma_length_range,
        factor: factor_range,
    };

    let result = match wavetrend_batch_with_kernel(data, &sweep, Kernel::Auto) {
        Ok(r) => r,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };

    // Collect the per-row parameter axes in a single pass over the combos.
    let rows = result.combos.len();
    let mut channel_lengths = Vec::with_capacity(rows);
    let mut average_lengths = Vec::with_capacity(rows);
    let mut ma_lengths = Vec::with_capacity(rows);
    let mut factors = Vec::with_capacity(rows);
    for combo in result.combos.iter() {
        channel_lengths.push(combo.channel_length.unwrap());
        average_lengths.push(combo.average_length.unwrap());
        ma_lengths.push(combo.ma_length.unwrap());
        factors.push(combo.factor.unwrap());
    }

    let payload = WavetrendBatchJsOutput {
        wt1_values: result.wt1,
        wt2_values: result.wt2,
        wt_diff_values: result.wt_diff,
        channel_lengths,
        average_lengths,
        ma_lengths,
        factors,
        rows,
        cols: data.len(),
    };

    match serde_wasm_bindgen::to_value(&payload) {
        Ok(value) => Ok(value),
        Err(e) => Err(JsValue::from_str(&format!("Serialization error: {}", e))),
    }
}