Skip to main content

vector_ta/indicators/
wto.rs

1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::{cuda_available, CudaWto, CudaWtoBatchResult, DeviceArrayF32Triplet};
3#[cfg(all(feature = "python", feature = "cuda"))]
4use crate::utilities::dlpack_cuda::{make_device_array_py, DeviceArrayF32Py};
5#[cfg(feature = "python")]
6use pyo3::exceptions::PyValueError;
7#[cfg(feature = "python")]
8use pyo3::prelude::*;
9
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use serde::{Deserialize, Serialize};
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use wasm_bindgen::prelude::*;
14
15use crate::utilities::data_loader::{source_type, Candles};
16use crate::utilities::enums::Kernel;
17use crate::utilities::helpers::{
18    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
19    make_uninit_matrix,
20};
21#[cfg(feature = "python")]
22use crate::utilities::kernel_validation::validate_kernel;
23use aligned_vec::{AVec, CACHELINE_ALIGN};
24
25#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
26use core::arch::x86_64::*;
27
28#[cfg(not(target_arch = "wasm32"))]
29use rayon::prelude::*;
30
31use std::collections::BTreeMap;
32use std::convert::AsRef;
33use std::error::Error;
34use std::mem::MaybeUninit;
35use thiserror::Error;
36
37use crate::indicators::moving_averages::ema::{
38    ema_into_slice, ema_with_kernel, EmaInput, EmaParams,
39};
40use crate::indicators::moving_averages::sma::{
41    sma_into_slice, sma_with_kernel, SmaInput, SmaParams,
42};
43
44#[derive(Debug, Clone)]
45pub enum WtoData<'a> {
46    Candles {
47        candles: &'a Candles,
48        source: &'a str,
49    },
50    Slice(&'a [f64]),
51}
52
/// The three output series of the WaveTrend Oscillator, each as long as the input.
#[derive(Clone, Debug)]
pub struct WtoOutput {
    /// WT1: the EMA-smoothed channel index.
    pub wavetrend1: Vec<f64>,
    /// WT2: the 4-bar simple average of `wavetrend1`.
    pub wavetrend2: Vec<f64>,
    /// `wavetrend1 - wavetrend2`.
    pub histogram: Vec<f64>,
}
59
/// Tunable lengths for the WaveTrend Oscillator.
///
/// A `None` length falls back to the default applied by `WtoInput`'s getters
/// (`channel_length = 10`, `average_length = 21`).
#[derive(Clone, Debug)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct WtoParams {
    /// EMA length of the channel stage (price EMA and mean-deviation EMA).
    pub channel_length: Option<usize>,
    /// EMA length used to smooth the channel index into WT1.
    pub average_length: Option<usize>,
}

impl Default for WtoParams {
    /// Both lengths set explicitly: channel 10, average 21.
    fn default() -> Self {
        let (channel, average) = (10, 21);
        WtoParams {
            channel_length: Some(channel),
            average_length: Some(average),
        }
    }
}
78
79#[derive(Debug, Clone)]
80pub struct WtoInput<'a> {
81    pub data: WtoData<'a>,
82    pub params: WtoParams,
83}
84
85impl<'a> WtoInput<'a> {
86    #[inline]
87    pub fn from_candles(c: &'a Candles, source: &'a str, p: WtoParams) -> Self {
88        Self {
89            data: WtoData::Candles { candles: c, source },
90            params: p,
91        }
92    }
93
94    #[inline]
95    pub fn from_slice(sl: &'a [f64], p: WtoParams) -> Self {
96        Self {
97            data: WtoData::Slice(sl),
98            params: p,
99        }
100    }
101
102    #[inline]
103    pub fn with_default_candles(c: &'a Candles) -> Self {
104        Self::from_candles(c, "close", WtoParams::default())
105    }
106
107    #[inline]
108    pub fn get_channel_length(&self) -> usize {
109        self.params.channel_length.unwrap_or(10)
110    }
111
112    #[inline]
113    pub fn get_average_length(&self) -> usize {
114        self.params.average_length.unwrap_or(21)
115    }
116}
117
118impl<'a> AsRef<[f64]> for WtoInput<'a> {
119    #[inline(always)]
120    fn as_ref(&self) -> &[f64] {
121        match &self.data {
122            WtoData::Slice(slice) => slice,
123            WtoData::Candles { candles, source } => source_type(candles, source),
124        }
125    }
126}
127
128#[derive(Copy, Clone, Debug)]
129pub struct WtoBuilder {
130    channel_length: Option<usize>,
131    average_length: Option<usize>,
132    kernel: Kernel,
133}
134
135impl Default for WtoBuilder {
136    fn default() -> Self {
137        Self {
138            channel_length: None,
139            average_length: None,
140            kernel: Kernel::Auto,
141        }
142    }
143}
144
145impl WtoBuilder {
146    #[inline(always)]
147    pub fn new() -> Self {
148        Self::default()
149    }
150
151    #[inline(always)]
152    pub fn channel_length(mut self, n: usize) -> Self {
153        self.channel_length = Some(n);
154        self
155    }
156
157    #[inline(always)]
158    pub fn average_length(mut self, n: usize) -> Self {
159        self.average_length = Some(n);
160        self
161    }
162
163    #[inline(always)]
164    pub fn kernel(mut self, k: Kernel) -> Self {
165        self.kernel = k;
166        self
167    }
168
169    #[inline(always)]
170    pub fn apply(self, c: &Candles) -> Result<WtoOutput, WtoError> {
171        let p = WtoParams {
172            channel_length: self.channel_length,
173            average_length: self.average_length,
174        };
175        let i = WtoInput::from_candles(c, "close", p);
176        wto_with_kernel(&i, self.kernel)
177    }
178
179    #[inline(always)]
180    pub fn apply_slice(self, d: &[f64]) -> Result<WtoOutput, WtoError> {
181        let p = WtoParams {
182            channel_length: self.channel_length,
183            average_length: self.average_length,
184        };
185        let i = WtoInput::from_slice(d, p);
186        wto_with_kernel(&i, self.kernel)
187    }
188
189    #[inline(always)]
190    pub fn into_stream(self) -> Result<WtoStream, WtoError> {
191        let p = WtoParams {
192            channel_length: self.channel_length,
193            average_length: self.average_length,
194        };
195        WtoStream::try_new(p)
196    }
197}
198
199#[derive(Debug, Error)]
200pub enum WtoError {
201    #[error("wto: Input data slice is empty.")]
202    EmptyInputData,
203    #[error("wto: All values are NaN.")]
204    AllValuesNaN,
205    #[error("wto: Invalid input: {0}")]
206    InvalidInput(String),
207    #[error("wto: Invalid period: period = {period}, data length = {data_len}")]
208    InvalidPeriod { period: usize, data_len: usize },
209    #[error("wto: Not enough valid data: needed = {needed}, valid = {valid}")]
210    NotEnoughValidData { needed: usize, valid: usize },
211    #[error("wto: Output length mismatch: expected {expected}, got {got}")]
212    OutputLengthMismatch { expected: usize, got: usize },
213    #[error("wto: Invalid range: start={start}, end={end}, step={step}")]
214    InvalidRange {
215        start: String,
216        end: String,
217        step: String,
218    },
219    #[error("wto: Invalid kernel for batch: {0:?}")]
220    InvalidKernelForBatch(crate::utilities::enums::Kernel),
221    #[error("wto: Computation error: {0}")]
222    ComputationError(String),
223}
224
225#[inline]
226pub fn wto(input: &WtoInput) -> Result<WtoOutput, WtoError> {
227    wto_with_kernel(input, Kernel::Auto)
228}
229
230#[inline]
231pub fn wto_with_kernel(input: &WtoInput, kernel: Kernel) -> Result<WtoOutput, WtoError> {
232    let (data, channel_length, average_length, first, chosen) = wto_prepare(input, kernel)?;
233    let len = data.len();
234
235    let ci_start = first + channel_length.saturating_sub(1);
236    let warm_wt1 = ci_start;
237    let warm_wt2_hist = ci_start.saturating_add(3);
238    let mut wavetrend1 = alloc_with_nan_prefix(len, warm_wt1);
239    let mut wavetrend2 = alloc_with_nan_prefix(len, warm_wt2_hist);
240    let mut histogram = alloc_with_nan_prefix(len, warm_wt2_hist);
241
242    wto_compute_into(
243        data,
244        channel_length,
245        average_length,
246        first,
247        chosen,
248        &mut wavetrend1,
249        &mut wavetrend2,
250        &mut histogram,
251    )?;
252
253    Ok(WtoOutput {
254        wavetrend1,
255        wavetrend2,
256        histogram,
257    })
258}
259
260#[inline]
261pub fn wto_into_slices(
262    wt1: &mut [f64],
263    wt2: &mut [f64],
264    hist: &mut [f64],
265    input: &WtoInput,
266    kernel: Kernel,
267) -> Result<(), WtoError> {
268    let (data, channel_length, average_length, first, chosen) = wto_prepare(input, kernel)?;
269
270    if wt1.len() != data.len() || wt2.len() != data.len() || hist.len() != data.len() {
271        let expected = data.len();
272        let got = wt1.len().max(wt2.len()).max(hist.len());
273        return Err(WtoError::OutputLengthMismatch { expected, got });
274    }
275
276    let ci_start = first + channel_length.saturating_sub(1);
277    let warm_wt1 = ci_start.min(wt1.len());
278    let warm_wt2_hist = ci_start.saturating_add(3);
279    let warm_wt2_hist = warm_wt2_hist.min(wt2.len()).min(hist.len());
280    let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
281    for v in &mut wt1[..warm_wt1] {
282        *v = qnan;
283    }
284    for v in &mut wt2[..warm_wt2_hist] {
285        *v = qnan;
286    }
287    for v in &mut hist[..warm_wt2_hist] {
288        *v = qnan;
289    }
290
291    wto_compute_into(
292        data,
293        channel_length,
294        average_length,
295        first,
296        chosen,
297        wt1,
298        wt2,
299        hist,
300    )?;
301
302    Ok(())
303}
304
305#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
306pub fn wto_into(
307    input: &WtoInput,
308    wt1_out: &mut [f64],
309    wt2_out: &mut [f64],
310    hist_out: &mut [f64],
311) -> Result<(), WtoError> {
312    let (data, channel_length, average_length, first, chosen) = wto_prepare(input, Kernel::Auto)?;
313
314    if wt1_out.len() != data.len() || wt2_out.len() != data.len() || hist_out.len() != data.len() {
315        let expected = data.len();
316        let got = wt1_out.len().max(wt2_out.len()).max(hist_out.len());
317        return Err(WtoError::OutputLengthMismatch { expected, got });
318    }
319
320    let ci_start = first + channel_length.saturating_sub(1);
321    let warm_wt1 = ci_start;
322    let warm_wt2_hist = ci_start.saturating_add(3);
323    let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
324    let w = warm_wt1.min(wt1_out.len());
325    for v in &mut wt1_out[..w] {
326        *v = qnan;
327    }
328    let w = warm_wt2_hist.min(wt2_out.len());
329    for v in &mut wt2_out[..w] {
330        *v = qnan;
331    }
332    let w = warm_wt2_hist.min(hist_out.len());
333    for v in &mut hist_out[..w] {
334        *v = qnan;
335    }
336
337    wto_compute_into(
338        data,
339        channel_length,
340        average_length,
341        first,
342        chosen,
343        wt1_out,
344        wt2_out,
345        hist_out,
346    )
347}
348
349#[inline(always)]
350fn wto_prepare<'a>(
351    input: &'a WtoInput,
352    kernel: Kernel,
353) -> Result<(&'a [f64], usize, usize, usize, Kernel), WtoError> {
354    let data: &[f64] = input.as_ref();
355    let len = data.len();
356
357    if len == 0 {
358        return Err(WtoError::EmptyInputData);
359    }
360
361    let first = data
362        .iter()
363        .position(|x| !x.is_nan())
364        .ok_or(WtoError::AllValuesNaN)?;
365    let channel_length = input.get_channel_length();
366    let average_length = input.get_average_length();
367
368    if channel_length == 0 || channel_length > len {
369        return Err(WtoError::InvalidPeriod {
370            period: channel_length,
371            data_len: len,
372        });
373    }
374    if average_length == 0 || average_length > len {
375        return Err(WtoError::InvalidPeriod {
376            period: average_length,
377            data_len: len,
378        });
379    }
380
381    let valid = len - first;
382    let needed = channel_length
383        .saturating_add(3)
384        .max(average_length.saturating_add(3));
385    if valid < needed {
386        return Err(WtoError::NotEnoughValidData { needed, valid });
387    }
388
389    let chosen = match kernel {
390        Kernel::Auto => Kernel::Scalar,
391        k => k,
392    };
393
394    Ok((data, channel_length, average_length, first, chosen))
395}
396
397#[inline(always)]
398fn wto_compute_into(
399    data: &[f64],
400    channel_length: usize,
401    average_length: usize,
402    first: usize,
403    kernel: Kernel,
404    wt1: &mut [f64],
405    wt2: &mut [f64],
406    hist: &mut [f64],
407) -> Result<(), WtoError> {
408    unsafe {
409        #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
410        {
411            if matches!(kernel, Kernel::Scalar | Kernel::ScalarBatch) {
412                return wto_simd128(data, channel_length, average_length, first, wt1, wt2, hist);
413            }
414        }
415
416        match kernel {
417            Kernel::Scalar | Kernel::ScalarBatch => wto_scalar(
418                data,
419                channel_length,
420                average_length,
421                first,
422                kernel,
423                wt1,
424                wt2,
425                hist,
426            ),
427            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
428            Kernel::Avx2 | Kernel::Avx2Batch => {
429                wto_avx2(data, channel_length, average_length, first, wt1, wt2, hist)
430            }
431            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
432            Kernel::Avx512 | Kernel::Avx512Batch => {
433                wto_avx512(data, channel_length, average_length, first, wt1, wt2, hist)
434            }
435            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
436            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => wto_scalar(
437                data,
438                channel_length,
439                average_length,
440                first,
441                kernel,
442                wt1,
443                wt2,
444                hist,
445            ),
446            _ => unreachable!(),
447        }
448    }
449}
450
/// Pine-Script-style EMA with smoothing `2 / (period + 1)`, seeded with `data[first_val]`.
///
/// Writes `out[first_val..data.len()]` and leaves earlier entries untouched. Nothing is
/// written when `first_val >= data.len()`. Non-finite inputs do not update the average;
/// the previous EMA value is repeated instead.
///
/// # Panics
/// If `out` is shorter than `data` while `first_val < data.len()`.
#[inline(always)]
fn ema_pinescript_into(data: &[f64], period: usize, first_val: usize, out: &mut [f64]) {
    if first_val >= data.len() {
        return;
    }

    let alpha = 2.0 / (period as f64 + 1.0);
    let beta = 1.0 - alpha;

    let mut state = data[first_val];
    out[first_val] = state;

    for (idx, &value) in data.iter().enumerate().skip(first_val + 1) {
        if value.is_finite() {
            state = alpha * value + beta * state;
        }
        out[idx] = state;
    }
}
473
474#[inline]
475pub fn wto_scalar(
476    data: &[f64],
477    channel_length: usize,
478    average_length: usize,
479    first_val: usize,
480    _kernel: Kernel,
481    wt1: &mut [f64],
482    wt2: &mut [f64],
483    hist: &mut [f64],
484) -> Result<(), WtoError> {
485    #[inline(always)]
486    fn fast_abs(x: f64) -> f64 {
487        f64::from_bits(x.to_bits() & 0x7FFF_FFFF_FFFF_FFFF)
488    }
489
490    let len = data.len();
491    if len == 0 || first_val >= len {
492        return Ok(());
493    }
494
495    let alpha_e = 2.0 / (channel_length as f64 + 1.0);
496    let beta_e = 1.0 - alpha_e;
497    let alpha_t = 2.0 / (average_length as f64 + 1.0);
498    let beta_t = 1.0 - alpha_t;
499
500    let ci_start = first_val + channel_length.saturating_sub(1);
501    if ci_start >= len {
502        return Ok(());
503    }
504
505    let mut esa = data[first_val];
506
507    let mut i = first_val + 1;
508    while i < ci_start {
509        let x = data[i];
510        if x.is_finite() {
511            esa = beta_e.mul_add(esa, alpha_e * x);
512        }
513        i += 1;
514    }
515
516    let mut d = 0.0_f64;
517    let mut tci = 0.0_f64;
518
519    let mut ring = [0.0_f64; 4];
520    let mut rsum = 0.0_f64;
521    let mut rpos = 0usize;
522    let mut rlen = 0usize;
523
524    let k015 = 0.015_f64;
525    let inv4 = 0.25_f64;
526
527    {
528        let x = data[ci_start];
529        if x.is_finite() {
530            esa = beta_e.mul_add(esa, alpha_e * x);
531        }
532        let abs_diff = if x.is_finite() {
533            fast_abs(x - esa)
534        } else {
535            f64::NAN
536        };
537        d = abs_diff;
538
539        let denom = k015 * d;
540        let ci = if denom != 0.0 && denom.is_finite() {
541            if x.is_finite() {
542                (x - esa) / denom
543            } else {
544                f64::NAN
545            }
546        } else {
547            0.0
548        };
549
550        tci = ci;
551        wt1[ci_start] = tci;
552
553        ring[0] = tci;
554        rsum = tci;
555        rlen = 1;
556        rpos = 1;
557    }
558
559    i = ci_start + 1;
560    while i < len {
561        let x = data[i];
562        let x_fin = x.is_finite();
563
564        if x_fin {
565            esa = beta_e.mul_add(esa, alpha_e * x);
566            let ad = fast_abs(x - esa);
567            d = beta_e.mul_add(d, alpha_e * ad);
568        }
569
570        let mut ci = 0.0_f64;
571        if x_fin {
572            let denom = k015 * d;
573            if denom != 0.0 && denom.is_finite() {
574                ci = (x - esa) / denom;
575            }
576        } else {
577            ci = f64::NAN;
578        }
579
580        if ci.is_finite() {
581            tci = beta_t.mul_add(tci, alpha_t * ci);
582        }
583
584        wt1[i] = tci;
585
586        if rlen < 4 {
587            ring[rlen] = tci;
588            rsum += tci;
589            rlen += 1;
590        } else {
591            rsum += tci - ring[rpos];
592            ring[rpos] = tci;
593            rpos = (rpos + 1) & 3;
594        }
595
596        if rlen == 4 {
597            let sig = inv4 * rsum;
598            wt2[i] = sig;
599            hist[i] = tci - sig;
600        }
601
602        i += 1;
603    }
604
605    Ok(())
606}
607
/// `wasm32` + `simd128` entry point; currently delegates unchanged to `wto_scalar`.
///
/// # Safety
/// No requirements beyond `wto_scalar`'s; this body performs no unsafe operations.
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
#[inline]
unsafe fn wto_simd128(
    data: &[f64],
    channel_length: usize,
    average_length: usize,
    first_val: usize,
    wt1: &mut [f64],
    wt2: &mut [f64],
    hist: &mut [f64],
) -> Result<(), WtoError> {
    let scalar = Kernel::Scalar;
    wto_scalar(data, channel_length, average_length, first_val, scalar, wt1, wt2, hist)
}
630
/// AVX2-dispatch WTO kernel (currently a scalar loop over raw pointers).
///
/// Same recurrences as `wto_scalar`. Unlike that kernel, it writes NaN into every warm-up
/// slot itself: indices `< ci_start` of all three outputs and indices `< ci_start + 3` of
/// `wt2`/`hist`.
///
/// NOTE(review): when a price is non-finite *and* `0.015 * d` is zero or non-finite, this
/// kernel feeds `ci = 0` into the WT1 EMA, whereas `wto_scalar` skips the update; the two
/// kernels can diverge in that corner case.
///
/// # Safety
/// `wt1`, `wt2` and `hist` must each hold at least `data.len()` elements; they are written
/// through raw pointers without bounds checks.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
unsafe fn wto_avx2(
    data: &[f64],
    channel_length: usize,
    average_length: usize,
    first_val: usize,
    wt1: &mut [f64],
    wt2: &mut [f64],
    hist: &mut [f64],
) -> Result<(), WtoError> {
    #[inline(always)]
    fn fast_abs(x: f64) -> f64 {
        f64::from_bits(x.to_bits() & 0x7FFF_FFFF_FFFF_FFFF)
    }

    let len = data.len();
    // Same guard as the scalar/AVX-512 kernels: the seed read of `data[first_val]` below
    // would otherwise be out of bounds.
    if len == 0 || first_val >= len {
        return Ok(());
    }
    debug_assert!(wt1.len() >= len && wt2.len() >= len && hist.len() >= len);

    let alpha_e = 2.0 / (channel_length as f64 + 1.0);
    let beta_e = 1.0 - alpha_e;
    let alpha_t = 2.0 / (average_length as f64 + 1.0);
    let beta_t = 1.0 - alpha_t;

    let ci_start = first_val + channel_length.saturating_sub(1);

    // Running sum over the last four WT1 values; `rpos` is the oldest slot once full.
    let mut ring = [0.0_f64; 4];
    let mut rsum = 0.0_f64;
    let mut rpos = 0usize;
    let mut rlen = 0usize;

    let x_ptr = data.as_ptr();
    let wt1_ptr = wt1.as_mut_ptr();
    let wt2_ptr = wt2.as_mut_ptr();
    let hist_ptr = hist.as_mut_ptr();

    let mut i = 0usize;
    while i < first_val {
        *wt1_ptr.add(i) = f64::NAN;
        *wt2_ptr.add(i) = f64::NAN;
        *hist_ptr.add(i) = f64::NAN;
        i += 1;
    }

    let mut esa = *x_ptr.add(first_val);

    let mut d = f64::NAN;
    let mut d_inited = false;
    let mut tci = f64::NAN;
    let mut tci_inited = false;

    i = first_val;
    while i < len {
        let x = *x_ptr.add(i);
        let x_fin = x.is_finite();

        if i != first_val && x_fin {
            esa = beta_e.mul_add(esa, alpha_e * x);
        }

        if i >= ci_start {
            let abs_diff = if x_fin { fast_abs(x - esa) } else { f64::NAN };
            if !d_inited {
                if i == ci_start {
                    d = abs_diff;
                    d_inited = true;
                }
            } else if abs_diff.is_finite() {
                d = beta_e.mul_add(d, alpha_e * abs_diff);
            }

            let denom = 0.015_f64 * d;
            let mut ci = 0.0_f64;
            if denom != 0.0 && denom.is_finite() {
                ci = if x_fin { (x - esa) / denom } else { f64::NAN };
            }

            if !tci_inited {
                if i == ci_start {
                    tci = ci;
                    tci_inited = true;
                }
            } else if ci.is_finite() {
                tci = beta_t.mul_add(tci, alpha_t * ci);
            }

            if tci_inited {
                *wt1_ptr.add(i) = tci;

                if rlen < 4 {
                    ring[rlen] = tci;
                    rsum += tci;
                    rlen += 1;
                } else {
                    rsum += tci - ring[rpos];
                    ring[rpos] = tci;
                    rpos = (rpos + 1) & 3;
                }

                if rlen == 4 {
                    let sig = 0.25_f64 * rsum;
                    *wt2_ptr.add(i) = sig;
                    *hist_ptr.add(i) = tci - sig;
                } else {
                    *wt2_ptr.add(i) = f64::NAN;
                    *hist_ptr.add(i) = f64::NAN;
                }
            } else {
                *wt1_ptr.add(i) = f64::NAN;
                *wt2_ptr.add(i) = f64::NAN;
                *hist_ptr.add(i) = f64::NAN;
            }
        } else {
            *wt1_ptr.add(i) = f64::NAN;
            *wt2_ptr.add(i) = f64::NAN;
            *hist_ptr.add(i) = f64::NAN;
        }

        i += 1;
    }

    Ok(())
}
756
/// AVX-512-dispatch WTO kernel (currently a scalar loop over raw pointers).
///
/// Mirrors `wto_scalar`: identical recurrences. It writes only `wt1[ci_start..]` and
/// `wt2`/`hist` from `ci_start + 3`, leaving warm-up prefixes to the caller.
///
/// # Safety
/// `wt1`, `wt2` and `hist` must each hold at least `data.len()` elements; they are written
/// through raw pointers without bounds checks.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
unsafe fn wto_avx512(
    data: &[f64],
    channel_length: usize,
    average_length: usize,
    first_val: usize,
    wt1: &mut [f64],
    wt2: &mut [f64],
    hist: &mut [f64],
) -> Result<(), WtoError> {
    #[inline(always)]
    fn fast_abs(x: f64) -> f64 {
        f64::from_bits(x.to_bits() & 0x7FFF_FFFF_FFFF_FFFF)
    }

    const K015: f64 = 0.015;
    const INV4: f64 = 0.25;

    let len = data.len();
    if len == 0 || first_val >= len {
        return Ok(());
    }
    debug_assert!(wt1.len() >= len && wt2.len() >= len && hist.len() >= len);

    let alpha_e = 2.0 / (channel_length as f64 + 1.0);
    let beta_e = 1.0 - alpha_e;
    let alpha_t = 2.0 / (average_length as f64 + 1.0);
    let beta_t = 1.0 - alpha_t;

    let ci_start = first_val + channel_length.saturating_sub(1);
    if ci_start >= len {
        return Ok(());
    }

    let x_ptr = data.as_ptr();
    let wt1_ptr = wt1.as_mut_ptr();
    let wt2_ptr = wt2.as_mut_ptr();
    let hs_ptr = hist.as_mut_ptr();

    // Warm the price EMA up to (but not including) `ci_start`.
    let mut esa = *x_ptr.add(first_val);
    let mut i = first_val + 1;
    while i < ci_start {
        let x = *x_ptr.add(i);
        if x.is_finite() {
            esa = beta_e.mul_add(esa, alpha_e * x);
        }
        i += 1;
    }

    // Seed `d`, `tci` and the WT2 window at `ci_start`.
    let x0 = *x_ptr.add(ci_start);
    if x0.is_finite() {
        esa = beta_e.mul_add(esa, alpha_e * x0);
    }
    let mut d = if x0.is_finite() {
        fast_abs(x0 - esa)
    } else {
        f64::NAN
    };
    let denom0 = K015 * d;
    let mut tci = if denom0 != 0.0 && denom0.is_finite() {
        if x0.is_finite() {
            (x0 - esa) / denom0
        } else {
            f64::NAN
        }
    } else {
        0.0
    };
    *wt1_ptr.add(ci_start) = tci;

    // Running sum over the last four `tci` values (the WT2 window).
    let mut ring = [0.0_f64; 4];
    ring[0] = tci;
    let mut rsum = tci;
    let mut rlen = 1usize;
    // Slot of the oldest sample once the window is full. The window fills in order
    // (ring[0], ring[1], ...), so the oldest is ring[0]; starting at 1 would evict the
    // wrong sample and corrupt the first three full-window signal values.
    let mut rpos = 0usize;

    i = ci_start + 1;
    while i < len {
        let x = *x_ptr.add(i);
        let x_fin = x.is_finite();

        if x_fin {
            esa = beta_e.mul_add(esa, alpha_e * x);
            let ad = fast_abs(x - esa);
            d = beta_e.mul_add(d, alpha_e * ad);
        }

        let ci = if x_fin {
            let denom = K015 * d;
            if denom != 0.0 && denom.is_finite() {
                (x - esa) / denom
            } else {
                0.0
            }
        } else {
            f64::NAN
        };
        if ci.is_finite() {
            tci = beta_t.mul_add(tci, alpha_t * ci);
        }

        *wt1_ptr.add(i) = tci;

        if rlen < 4 {
            ring[rlen] = tci;
            rsum += tci;
            rlen += 1;
        } else {
            rsum += tci - ring[rpos];
            ring[rpos] = tci;
            rpos = (rpos + 1) & 3;
        }

        if rlen == 4 {
            let sig = INV4 * rsum;
            *wt2_ptr.add(i) = sig;
            *hs_ptr.add(i) = tci - sig;
        }

        i += 1;
    }

    Ok(())
}
893
/// Python binding: `wto(close, channel_length, average_length, kernel=None)`.
///
/// Returns `(wavetrend1, wavetrend2, histogram)` as NumPy arrays. The computation runs with
/// the GIL released, and any `WtoError` is raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "wto")]
#[pyo3(signature = (close, channel_length, average_length, kernel=None))]
pub fn wto_py<'py>(
    py: Python<'py>,
    close: numpy::PyReadonlyArray1<'py, f64>,
    channel_length: usize,
    average_length: usize,
    kernel: Option<&str>,
) -> PyResult<(
    Bound<'py, numpy::PyArray1<f64>>,
    Bound<'py, numpy::PyArray1<f64>>,
    Bound<'py, numpy::PyArray1<f64>>,
)> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let prices = close.as_slice()?;
    let kern = validate_kernel(kernel, false)?;
    let params = WtoParams {
        channel_length: Some(channel_length),
        average_length: Some(average_length),
    };
    let input = WtoInput::from_slice(prices, params);

    let WtoOutput {
        wavetrend1,
        wavetrend2,
        histogram,
    } = py
        .allow_threads(|| wto_with_kernel(&input, kern))
        .map_err(|err| PyValueError::new_err(err.to_string()))?;

    Ok((
        wavetrend1.into_pyarray(py),
        wavetrend2.into_pyarray(py),
        histogram.into_pyarray(py),
    ))
}
926
/// WASM binding returning a JS object with `wavetrend1`, `wavetrend2` and `histogram`
/// `Float64Array` properties. Errors are returned as JS strings.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "wto_js")]
pub fn wto_js(
    close: &[f64],
    channel_length: usize,
    average_length: usize,
) -> Result<js_sys::Object, JsValue> {
    let input = WtoInput::from_slice(
        close,
        WtoParams {
            channel_length: Some(channel_length),
            average_length: Some(average_length),
        },
    );
    let output = wto(&input).map_err(|e| JsValue::from_str(&e.to_string()))?;

    let result = js_sys::Object::new();
    let series: [(&str, &[f64]); 3] = [
        ("wavetrend1", &output.wavetrend1),
        ("wavetrend2", &output.wavetrend2),
        ("histogram", &output.histogram),
    ];
    for (key, values) in series.iter() {
        let array = js_sys::Float64Array::new_with_length(values.len() as u32);
        array.copy_from(values);
        js_sys::Reflect::set(&result, &JsValue::from_str(key), &array)?;
    }

    Ok(result)
}
958
/// Allocates an uninitialised `f64` buffer with capacity `len` and hands ownership to the
/// caller; release it with `wto_free(ptr, len)`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn wto_alloc(len: usize) -> *mut f64 {
    let mut buffer = core::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
967
/// Releases a buffer previously returned by `wto_alloc(len)`.
///
/// A null pointer is ignored. `len` must be the value originally passed to `wto_alloc`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn wto_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` came from `wto_alloc(len)`, which leaked a `Vec<f64>` created with
    // `Vec::with_capacity(len)` (capacity exactly `len`). The length is rebuilt as 0 so the
    // possibly-uninitialised contents are never treated as initialised values.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
975
/// WASM pointer API: computes WTO from `len` values at `in_ptr` into three caller-owned
/// output buffers of `len` elements each (for example, buffers from `wto_alloc`).
///
/// The input may share storage with one output buffer (in-place use); in that case the
/// input is copied before any output is written. The output buffers must be distinct from
/// each other. Only exact pointer equality is detected; partially overlapping buffers are
/// not supported.
///
/// # Errors
/// A JS string error for null pointers, aliased output pointers, or any `WtoError`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn wto_into(
    in_ptr: *const f64,
    wt1_ptr: *mut f64,
    wt2_ptr: *mut f64,
    hist_ptr: *mut f64,
    len: usize,
    channel_length: usize,
    average_length: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || wt1_ptr.is_null() || wt2_ptr.is_null() || hist_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to wto_into"));
    }
    // Two `&mut` slices over the same memory would be undefined behaviour.
    if wt1_ptr == wt2_ptr || wt1_ptr == hist_ptr || wt2_ptr == hist_ptr {
        return Err(JsValue::from_str("wto_into: output pointers must not alias"));
    }

    let input_aliases_output = [wt1_ptr, wt2_ptr, hist_ptr]
        .iter()
        .any(|&p| p as *const f64 == in_ptr);

    let p = WtoParams {
        channel_length: Some(channel_length),
        average_length: Some(average_length),
    };

    // SAFETY: the caller guarantees that `in_ptr` and each output pointer address `len`
    // valid `f64`s. When the input shares storage with an output, it is copied into an
    // owned buffer first, so no shared slice overlaps the `&mut` output slices below.
    unsafe {
        let owned_input: Option<Vec<f64>> = if input_aliases_output {
            Some(core::slice::from_raw_parts(in_ptr, len).to_vec())
        } else {
            None
        };
        let data: &[f64] = match &owned_input {
            Some(copy) => copy,
            None => core::slice::from_raw_parts(in_ptr, len),
        };
        let wt1 = core::slice::from_raw_parts_mut(wt1_ptr, len);
        let wt2 = core::slice::from_raw_parts_mut(wt2_ptr, len);
        let hist = core::slice::from_raw_parts_mut(hist_ptr, len);

        let inp = WtoInput::from_slice(data, p);

        wto_into_slices(wt1, wt2, hist, &inp, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))
    }
}
1006
/// Flattened `rows x cols` result serialised to JS.
///
/// As built by `wto_unified`: three rows (WT1, WT2, histogram) laid out back to back,
/// each `cols` long.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct WtoResult {
    /// Row-major values, `rows * cols` long.
    pub values: Vec<f64>,
    /// Number of rows.
    pub rows: usize,
    /// Number of columns (input length).
    pub cols: usize,
}
1014
/// WASM binding returning a serialised `WtoResult` whose three rows are WT1, WT2 and the
/// histogram (in that order), each `close.len()` long.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "wto_unified")]
pub fn wto_unified_js(
    close: &[f64],
    channel_length: usize,
    average_length: usize,
) -> Result<JsValue, JsValue> {
    let input = WtoInput::from_slice(
        close,
        WtoParams {
            channel_length: Some(channel_length),
            average_length: Some(average_length),
        },
    );
    let WtoOutput {
        wavetrend1,
        wavetrend2,
        histogram,
    } = wto(&input).map_err(|e| JsValue::from_str(&e.to_string()))?;

    let cols = close.len();
    let capacity = cols
        .checked_mul(3)
        .ok_or_else(|| JsValue::from_str("overflow in wto_unified_js allocation"))?;
    let mut values = Vec::with_capacity(capacity);
    for row in [&wavetrend1, &wavetrend2, &histogram].iter() {
        values.extend_from_slice(row);
    }

    let res = WtoResult {
        values,
        rows: 3,
        cols,
    };
    serde_wasm_bindgen::to_value(&res)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
1046
/// JS-facing batch sweep configuration.
///
/// NOTE(review): each tuple presumably mirrors `WtoBatchRange`'s `(start, end, step)`
/// layout; confirm against the batch entry point that consumes it.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct WtoBatchConfig {
    /// Channel-length sweep.
    pub channel: (usize, usize, usize),
    /// Average-length sweep.
    pub average: (usize, usize, usize),
}
1053
/// JS-facing batch result: one output row per parameter combination.
///
/// NOTE(review): the three series are presumably flattened row-major (`rows * cols`), with
/// row `r` computed from `combos[r]`; confirm against the producer.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct WtoBatchJsOutput {
    /// WT1 values for every combination.
    pub wavetrend1: Vec<f64>,
    /// WT2 values for every combination.
    pub wavetrend2: Vec<f64>,
    /// Histogram values for every combination.
    pub histogram: Vec<f64>,
    /// Parameter combination behind each row.
    pub combos: Vec<WtoParams>,
    /// Number of combinations.
    pub rows: usize,
    /// Input length.
    pub cols: usize,
}
1064
/// Computes the WT1 line for every `(channel, average)` combination of a sweep and writes it
/// into a caller-provided buffer in WASM linear memory.
///
/// * `in_ptr` must point to `len` readable, initialised `f64` values.
/// * `out_ptr` must point to at least `rows * len` writable `f64` values, where `rows` is the
///   number of combinations produced by the `(start, end, step)` ranges (the return value).
/// * The two buffers must not overlap at all.
///
/// Output is row-major: row `r` occupies `out[r * len..(r + 1) * len]`. Each row's warm-up
/// prefix (`first_valid + channel - 1` entries) is filled by `init_matrix_prefixes`. Only WT1
/// is produced by this entry point; WT2 and the histogram are not written.
///
/// # Errors
/// Returns a string `JsValue` for null pointers, empty input, overlapping buffers, all-NaN
/// input, an invalid range, or periods that do not fit the data.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn wto_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    ch_start: usize,
    ch_end: usize,
    ch_step: usize,
    av_start: usize,
    av_end: usize,
    av_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to wto_batch_into"));
    }
    if len == 0 {
        return Err(JsValue::from_str(&WtoError::EmptyInputData.to_string()));
    }
    if in_ptr == out_ptr {
        return Err(JsValue::from_str(
            "wto_batch_into: in_ptr and out_ptr must not alias",
        ));
    }
    // SAFETY: `in_ptr` is non-null and, per the contract above, addresses `len` initialised
    // f64 values. `out_ptr` is only turned into slices after the overlap check below has
    // proven the two regions are disjoint.
    unsafe {
        let data = core::slice::from_raw_parts(in_ptr, len);
        let first = data
            .iter()
            .position(|x| !x.is_nan())
            .ok_or_else(|| JsValue::from_str(&WtoError::AllValuesNaN.to_string()))?;
        let sweep = WtoBatchRange {
            channel: (ch_start, ch_end, ch_step),
            average: (av_start, av_end, av_step),
        };
        let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
        let rows = combos.len();
        let cols = len;
        let total = rows
            .checked_mul(cols)
            .ok_or_else(|| JsValue::from_str("rows * cols overflow in wto_batch_into"))?;

        // Exact aliasing was rejected above, but a partial overlap is just as unsound:
        // `data` is a shared slice and `out` a mutable one, so their byte ranges must be
        // disjoint.
        let elem = core::mem::size_of::<f64>();
        let in_start = in_ptr as usize;
        let in_end = in_start.saturating_add(len.saturating_mul(elem));
        let out_start = out_ptr as usize;
        let out_end = out_start.saturating_add(total.saturating_mul(elem));
        if in_start < out_end && out_start < in_end {
            return Err(JsValue::from_str(
                "wto_batch_into: in_ptr and out_ptr must not alias",
            ));
        }

        let mut warms: Vec<usize> = Vec::with_capacity(rows);
        for p in combos.iter() {
            let channel_length = p.channel_length.unwrap_or(10);
            let average_length = p.average_length.unwrap_or(21);

            if channel_length == 0 || channel_length > cols {
                return Err(JsValue::from_str(
                    &WtoError::InvalidPeriod {
                        period: channel_length,
                        data_len: cols,
                    }
                    .to_string(),
                ));
            }
            if average_length == 0 || average_length > cols {
                return Err(JsValue::from_str(
                    &WtoError::InvalidPeriod {
                        period: average_length,
                        data_len: cols,
                    }
                    .to_string(),
                ));
            }

            let valid = cols.saturating_sub(first);
            if valid < channel_length {
                return Err(JsValue::from_str(
                    &WtoError::NotEnoughValidData {
                        needed: channel_length,
                        valid,
                    }
                    .to_string(),
                ));
            }

            // WT1 for this row starts where the first full channel window ends.
            warms.push(first + channel_length - 1);
        }

        // The warm-up prefixes are initialised through a `MaybeUninit` view first; the
        // remainder of each row is then written by the grouped fill through a plain f64 view.
        let out_mu = core::slice::from_raw_parts_mut(out_ptr as *mut MaybeUninit<f64>, total);
        init_matrix_prefixes(out_mu, cols, &warms);

        let out = core::slice::from_raw_parts_mut(out_ptr, total);
        wto_fill_wt1_grouped(data, &combos, first, detect_best_kernel(), false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

        Ok(rows)
    }
}
1154
/// JS entry point `wto_batch`: runs a WaveTrend sweep and returns WT1, WT2, the histogram
/// and the parameter combinations as one serialized object.
///
/// `config` must deserialize into `WtoBatchConfig`. The sweep is always computed with
/// `Kernel::ScalarBatch`.
///
/// # Errors
/// A string `JsValue` for an invalid config, a failed computation, or a serialization
/// failure.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "wto_batch")]
pub fn wto_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let WtoBatchConfig { channel, average } =
        serde_wasm_bindgen::from_value::<WtoBatchConfig>(config)
            .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let range = WtoBatchRange { channel, average };
    let WtoBatchAllOutput {
        wt1,
        wt2,
        hist,
        combos,
        rows,
        cols,
    } = wto_batch_all_outputs_with_kernel(data, &range, Kernel::ScalarBatch)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let payload = WtoBatchJsOutput {
        wavetrend1: wt1,
        wavetrend2: wt2,
        histogram: hist,
        combos,
        rows,
        cols,
    };
    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
1177
/// Parameter sweep for the batch WaveTrend entry points.
///
/// Each axis is `(start, end, step)` and is expanded by `expand_grid`. A `step` of 0 (or
/// `start == end`) yields just `start`; otherwise values walk from `start` toward `end`
/// (upward or downward) in increments of `step`.
#[derive(Debug, Clone)]
pub struct WtoBatchRange {
    /// Channel-length axis.
    pub channel: (usize, usize, usize),
    /// Average-length axis.
    pub average: (usize, usize, usize),
}

/// Default sweep: a fixed channel length of 10 and average lengths 21..=270 in steps of 1.
impl Default for WtoBatchRange {
    fn default() -> Self {
        Self {
            channel: (10, 10, 0),
            average: (21, 270, 1),
        }
    }
}
1192
1193#[derive(Debug, Clone)]
1194pub struct WtoBatchOutput {
1195    pub values: Vec<f64>,
1196    pub combos: Vec<WtoParams>,
1197    pub rows: usize,
1198    pub cols: usize,
1199}
1200
1201impl WtoBatchOutput {
1202    pub fn values_for(&self, params: &WtoParams) -> Option<&[f64]> {
1203        self.row_for_params(params)
1204            .map(|row| &self.values[row * self.cols..(row + 1) * self.cols])
1205    }
1206
1207    pub fn row_for_params(&self, params: &WtoParams) -> Option<usize> {
1208        self.combos.iter().position(|p| {
1209            p.channel_length.unwrap_or(10) == params.channel_length.unwrap_or(10)
1210                && p.average_length.unwrap_or(21) == params.average_length.unwrap_or(21)
1211        })
1212    }
1213}
1214
/// Builder for WT1 batch sweeps over a slice or a candle source.
#[derive(Debug, Clone)]
pub struct WtoBatchBuilder {
    /// Channel-length axis as `(start, end, step)`.
    channel_range: (usize, usize, usize),
    /// Average-length axis as `(start, end, step)`.
    average_range: (usize, usize, usize),
    /// Kernel forwarded to `wto_batch_with_kernel` / `wto_batch_candles`.
    kernel: Kernel,
}

/// Same sweep as `WtoBatchRange::default()` (channel 10; averages 21..=270 step 1), with
/// `Kernel::Auto`.
impl Default for WtoBatchBuilder {
    fn default() -> Self {
        Self {
            channel_range: (10, 10, 0),
            average_range: (21, 270, 1),
            kernel: Kernel::Auto,
        }
    }
}
1231
1232impl WtoBatchBuilder {
1233    pub fn new() -> Self {
1234        Self::default()
1235    }
1236
1237    pub fn channel_range(mut self, start: usize, end: usize, step: usize) -> Self {
1238        self.channel_range = (start, end, step);
1239        self
1240    }
1241
1242    pub fn average_range(mut self, start: usize, end: usize, step: usize) -> Self {
1243        self.average_range = (start, end, step);
1244        self
1245    }
1246
1247    pub fn kernel(mut self, k: Kernel) -> Self {
1248        self.kernel = k;
1249        self
1250    }
1251
1252    pub fn apply_candles(
1253        self,
1254        candles: &Candles,
1255        source: &str,
1256    ) -> Result<WtoBatchOutput, WtoError> {
1257        wto_batch_candles(
1258            candles,
1259            source,
1260            self.channel_range,
1261            self.average_range,
1262            self.kernel,
1263        )
1264    }
1265
1266    pub fn apply_slice(self, data: &[f64]) -> Result<WtoBatchOutput, WtoError> {
1267        let sweep = WtoBatchRange {
1268            channel: self.channel_range,
1269            average: self.average_range,
1270        };
1271        wto_batch_with_kernel(data, &sweep, self.kernel)
1272    }
1273
1274    pub fn channel_static(mut self, p: usize) -> Self {
1275        self.channel_range = (p, p, 0);
1276        self
1277    }
1278
1279    pub fn average_static(mut self, p: usize) -> Self {
1280        self.average_range = (p, p, 0);
1281        self
1282    }
1283
1284    pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<WtoBatchOutput, WtoError> {
1285        WtoBatchBuilder::new().kernel(k).apply_slice(data)
1286    }
1287
1288    pub fn with_default_candles(c: &Candles) -> Result<WtoBatchOutput, WtoError> {
1289        WtoBatchBuilder::new()
1290            .kernel(Kernel::Auto)
1291            .apply_candles(c, "close")
1292    }
1293}
1294
/// One output row inside a group of rows that share the same channel length.
#[derive(Clone, Copy, Debug)]
struct WtoBatchMember {
    /// Row index into the row-major output matrix.
    row: usize,
    /// EMA length applied to the group's shared CI series to produce this row.
    average_length: usize,
}

/// Raw pointer to the start of a row-major batch output buffer, marked `Send`/`Sync`.
#[derive(Clone, Copy)]
struct ThreadSafePtr(*mut f64);

// SAFETY: NOTE(review) — presumably intended for handing the buffer to workers that each
// write disjoint rows while the buffer outlives them; in the visible code rows are written
// sequentially by `apply_ci_to_members`. Any parallel use must uphold disjointness.
unsafe impl Send for ThreadSafePtr {}
unsafe impl Sync for ThreadSafePtr {}
1306
1307#[inline]
1308fn group_rows_by_channel(
1309    combos: &[WtoParams],
1310    cols: usize,
1311) -> Result<Vec<(usize, Vec<WtoBatchMember>)>, WtoError> {
1312    let mut groups: BTreeMap<usize, Vec<WtoBatchMember>> = BTreeMap::new();
1313
1314    for (row, params) in combos.iter().enumerate() {
1315        let channel_length = params.channel_length.unwrap_or(10);
1316        let average_length = params.average_length.unwrap_or(21);
1317
1318        if channel_length == 0 || channel_length > cols {
1319            return Err(WtoError::InvalidPeriod {
1320                period: channel_length,
1321                data_len: cols,
1322            });
1323        }
1324        if average_length == 0 || average_length > cols {
1325            return Err(WtoError::InvalidPeriod {
1326                period: average_length,
1327                data_len: cols,
1328            });
1329        }
1330
1331        groups
1332            .entry(channel_length)
1333            .or_default()
1334            .push(WtoBatchMember {
1335                row,
1336                average_length,
1337            });
1338    }
1339
1340    Ok(groups.into_iter().collect())
1341}
1342
1343#[inline]
1344fn apply_ci_to_members(
1345    ci: &[f64],
1346    members: &[WtoBatchMember],
1347    start_ci: usize,
1348    cols: usize,
1349    out_ptr: ThreadSafePtr,
1350    parallel: bool,
1351) -> Result<(), WtoError> {
1352    let _ = parallel;
1353    for member in members {
1354        let offset = member.row.checked_mul(cols).ok_or_else(|| {
1355            WtoError::InvalidInput("row * cols overflow in apply_ci_to_members".into())
1356        })?;
1357        let dst = unsafe { core::slice::from_raw_parts_mut(out_ptr.0.add(offset), cols) };
1358        ema_pinescript_into(ci, member.average_length, start_ci, dst);
1359    }
1360    Ok(())
1361}
1362
1363fn wto_fill_wt1_grouped(
1364    data: &[f64],
1365    combos: &[WtoParams],
1366    first: usize,
1367    kernel: Kernel,
1368    parallel: bool,
1369    out: &mut [f64],
1370) -> Result<(), WtoError> {
1371    let cols = data.len();
1372    let groups = group_rows_by_channel(combos, cols)?;
1373    if groups.is_empty() {
1374        return Ok(());
1375    }
1376
1377    let kernel = match kernel {
1378        Kernel::Auto => Kernel::Scalar,
1379        other => other.to_non_batch(),
1380    };
1381
1382    match kernel {
1383        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1384        Kernel::Avx512 => unsafe {
1385            wto_batch_fill_wt1_grouped_avx512(data, &groups, first, out, parallel)
1386        },
1387        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1388        Kernel::Avx2 => unsafe {
1389            wto_batch_fill_wt1_grouped_avx2(data, &groups, first, out, parallel)
1390        },
1391        Kernel::Scalar => wto_batch_fill_wt1_grouped_scalar(data, &groups, first, out, parallel),
1392        #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
1393        Kernel::Avx2 | Kernel::Avx512 => {
1394            wto_batch_fill_wt1_grouped_scalar(data, &groups, first, out, parallel)
1395        }
1396        #[allow(unreachable_patterns)]
1397        _ => wto_batch_fill_wt1_grouped_scalar(data, &groups, first, out, parallel),
1398    }
1399}
1400
1401fn wto_batch_fill_wt1_grouped_scalar(
1402    data: &[f64],
1403    groups: &[(usize, Vec<WtoBatchMember>)],
1404    first: usize,
1405    out: &mut [f64],
1406    parallel: bool,
1407) -> Result<(), WtoError> {
1408    let cols = data.len();
1409    let out_ptr = ThreadSafePtr(out.as_mut_ptr());
1410
1411    for (channel_length, members) in groups.iter() {
1412        let channel_length = *channel_length;
1413        let start_ci = first + channel_length - 1;
1414        if start_ci >= cols {
1415            return Err(WtoError::NotEnoughValidData {
1416                needed: start_ci + 1,
1417                valid: cols.saturating_sub(first),
1418            });
1419        }
1420
1421        let mut scratch = make_uninit_matrix(3, cols);
1422        let warms = [start_ci, start_ci, start_ci];
1423        init_matrix_prefixes(&mut scratch, cols, &warms);
1424
1425        let mut guard = core::mem::ManuallyDrop::new(scratch);
1426        let flat: &mut [f64] =
1427            unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len()) };
1428        let (esa, rest) = flat.split_at_mut(cols);
1429        let (d, ci) = rest.split_at_mut(cols);
1430
1431        ema_pinescript_into(data, channel_length, first, esa);
1432
1433        for i in 0..cols {
1434            ci[i] = (data[i] - esa[i]).abs();
1435        }
1436
1437        ema_pinescript_into(ci, channel_length, start_ci, d);
1438
1439        for i in start_ci..cols {
1440            let denom = 0.015 * d[i];
1441            ci[i] = if denom != 0.0 && denom.is_finite() {
1442                (data[i] - esa[i]) / denom
1443            } else {
1444                0.0
1445            };
1446        }
1447
1448        let ci_slice: &[f64] = ci;
1449        apply_ci_to_members(
1450            ci_slice,
1451            members.as_slice(),
1452            start_ci,
1453            cols,
1454            out_ptr,
1455            parallel,
1456        )?;
1457
1458        unsafe {
1459            Vec::from_raw_parts(
1460                guard.as_mut_ptr() as *mut f64,
1461                guard.len(),
1462                guard.capacity(),
1463            );
1464        }
1465        core::mem::forget(guard);
1466    }
1467
1468    Ok(())
1469}
1470
/// AVX2 variant of the grouped WT1 fill.
///
/// Computes the same `esa` / `d` / `ci` series as the scalar variant for each channel-length
/// group. The `|data - esa|` pass and the `(data - esa) / (0.015 * d)` pass are processed four
/// lanes at a time with scalar tails, and each member row is then written via
/// `apply_ci_to_members`.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2. `out` must be large enough that every
/// member row (`row * data.len()` .. `+ data.len()`) is in bounds, because rows are written
/// through a raw pointer without bounds checks.
///
/// NOTE(review): if `apply_ci_to_members` returns `Err`, the `?` exits before the scratch
/// buffer held in `ManuallyDrop` is reclaimed, so it leaks on that path.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn wto_batch_fill_wt1_grouped_avx2(
    data: &[f64],
    groups: &[(usize, Vec<WtoBatchMember>)],
    first: usize,
    out: &mut [f64],
    parallel: bool,
) -> Result<(), WtoError> {
    use core::arch::x86_64::*;

    let cols = data.len();
    let out_ptr = ThreadSafePtr(out.as_mut_ptr());

    for (channel_length, members) in groups.iter() {
        let channel_length = *channel_length;
        // First index at which `d` (and therefore `ci`) is defined for this group.
        let start_ci = first + channel_length - 1;
        if start_ci >= cols {
            return Err(WtoError::NotEnoughValidData {
                needed: start_ci + 1,
                valid: cols.saturating_sub(first),
            });
        }

        // Scratch rows, each `cols` long: esa | d | ci.
        let mut scratch = make_uninit_matrix(3, cols);
        let warms = [start_ci, start_ci, start_ci];
        init_matrix_prefixes(&mut scratch, cols, &warms);

        let mut guard = core::mem::ManuallyDrop::new(scratch);
        let flat: &mut [f64] =
            core::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len());
        let (esa, rest) = flat.split_at_mut(cols);
        let (d, ci) = rest.split_at_mut(cols);

        ema_pinescript_into(data, channel_length, first, esa);

        // -0.0 has only the sign bit set; andnot with it clears the sign bit, i.e. |x|.
        let signmask = _mm256_set1_pd(-0.0_f64);
        let mut i = 0usize;
        while i + 4 <= cols {
            let x = _mm256_loadu_pd(data.as_ptr().add(i));
            let e = _mm256_loadu_pd(esa.as_ptr().add(i));
            let diff = _mm256_sub_pd(x, e);
            let absd = _mm256_andnot_pd(signmask, diff);
            _mm256_storeu_pd(ci.as_mut_ptr().add(i), absd);
            i += 4;
        }
        // Scalar tail for the remaining (< 4) elements.
        while i < cols {
            ci[i] = (data[i] - esa[i]).abs();
            i += 1;
        }

        ema_pinescript_into(ci, channel_length, start_ci, d);

        let k015 = _mm256_set1_pd(0.015_f64);
        let zero = _mm256_set1_pd(0.0_f64);
        let infv = _mm256_set1_pd(f64::INFINITY);
        let mut j = start_ci;
        while j + 4 <= cols {
            let x = _mm256_loadu_pd(data.as_ptr().add(j));
            let e = _mm256_loadu_pd(esa.as_ptr().add(j));
            let num = _mm256_sub_pd(x, e);

            let dv = _mm256_loadu_pd(d.as_ptr().add(j));
            let den = _mm256_mul_pd(k015, dv);

            // A lane is valid when den != 0, |den| != inf and den is not NaN, mirroring the
            // scalar `denom != 0.0 && denom.is_finite()` test.
            let neq0 = _mm256_cmp_pd(den, zero, _CMP_NEQ_OQ);
            let abs_den = _mm256_andnot_pd(signmask, den);
            let not_inf = _mm256_cmp_pd(abs_den, infv, _CMP_NEQ_OQ);
            let ord = _mm256_cmp_pd(den, den, _CMP_ORD_Q);
            let valid = _mm256_and_pd(_mm256_and_pd(neq0, not_inf), ord);

            // Take the quotient in valid lanes and 0.0 elsewhere.
            let q = _mm256_div_pd(num, den);
            let outv = _mm256_blendv_pd(zero, q, valid);
            _mm256_storeu_pd(ci.as_mut_ptr().add(j), outv);
            j += 4;
        }
        // Scalar tail for the remaining (< 4) elements.
        while j < cols {
            let denom = 0.015 * d[j];
            ci[j] = if denom != 0.0 && denom.is_finite() {
                (data[j] - esa[j]) / denom
            } else {
                0.0
            };
            j += 1;
        }

        let ci_slice: &[f64] = ci;
        apply_ci_to_members(
            ci_slice,
            members.as_slice(),
            start_ci,
            cols,
            out_ptr,
            parallel,
        )?;

        // Rebuild the Vec from the ManuallyDrop'd parts and drop it immediately, freeing the
        // scratch allocation for this group.
        unsafe {
            Vec::from_raw_parts(
                guard.as_mut_ptr() as *mut f64,
                guard.len(),
                guard.capacity(),
            );
        }
        core::mem::forget(guard);
    }

    Ok(())
}
1579
/// AVX-512 variant of the grouped WT1 fill.
///
/// Computes the same `esa` / `d` / `ci` series as the scalar variant for each channel-length
/// group. The `|data - esa|` pass and the `(data - esa) / (0.015 * d)` pass are processed
/// eight lanes at a time with scalar tails, and each member row is then written via
/// `apply_ci_to_members`.
///
/// Absolute values use `_mm512_abs_pd` (AVX-512F). The bitwise `_mm512_andnot_pd` trick
/// needs AVX-512DQ, which this function does not enable; `_mm512_abs_pd` gives the same
/// sign-bit-clearing result.
///
/// # Safety
/// The caller must ensure the CPU supports AVX-512F. `out` must be large enough that every
/// member row (`row * data.len()` .. `+ data.len()`) is in bounds, because rows are written
/// through a raw pointer without bounds checks.
///
/// # Errors
/// [`WtoError::NotEnoughValidData`] if `first + channel_length - 1` is past the end of
/// `data`, or any error returned by `apply_ci_to_members`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
unsafe fn wto_batch_fill_wt1_grouped_avx512(
    data: &[f64],
    groups: &[(usize, Vec<WtoBatchMember>)],
    first: usize,
    out: &mut [f64],
    parallel: bool,
) -> Result<(), WtoError> {
    use core::arch::x86_64::*;

    let cols = data.len();
    let out_ptr = ThreadSafePtr(out.as_mut_ptr());

    for (channel_length, members) in groups.iter() {
        let channel_length = *channel_length;
        let start_ci = first + channel_length - 1;
        if start_ci >= cols {
            return Err(WtoError::NotEnoughValidData {
                needed: start_ci + 1,
                valid: cols.saturating_sub(first),
            });
        }

        // Scratch rows esa | d | ci. `scratch` keeps ownership of the allocation, so it is
        // freed at the end of every iteration, including the early `?` return below.
        let mut scratch = make_uninit_matrix(3, cols);
        let warms = [start_ci, start_ci, start_ci];
        init_matrix_prefixes(&mut scratch, cols, &warms);

        // `MaybeUninit<f64>` and `f64` share layout. Warm-up prefixes are initialised above,
        // and the rest is written before being read (NOTE(review): relies on
        // `ema_pinescript_into` writing every index from its start argument onward).
        let flat: &mut [f64] =
            core::slice::from_raw_parts_mut(scratch.as_mut_ptr() as *mut f64, scratch.len());
        let (esa, rest) = flat.split_at_mut(cols);
        let (d, ci) = rest.split_at_mut(cols);

        ema_pinescript_into(data, channel_length, first, esa);

        let mut i = 0usize;
        while i + 8 <= cols {
            let x = _mm512_loadu_pd(data.as_ptr().add(i));
            let e = _mm512_loadu_pd(esa.as_ptr().add(i));
            let absd = _mm512_abs_pd(_mm512_sub_pd(x, e));
            _mm512_storeu_pd(ci.as_mut_ptr().add(i), absd);
            i += 8;
        }
        // Scalar tail for the remaining (< 8) elements.
        while i < cols {
            ci[i] = (data[i] - esa[i]).abs();
            i += 1;
        }

        ema_pinescript_into(ci, channel_length, start_ci, d);

        let k015 = _mm512_set1_pd(0.015_f64);
        let zero = _mm512_set1_pd(0.0_f64);
        let infv = _mm512_set1_pd(f64::INFINITY);
        let mut j = start_ci;
        while j + 8 <= cols {
            let x = _mm512_loadu_pd(data.as_ptr().add(j));
            let e = _mm512_loadu_pd(esa.as_ptr().add(j));
            let num = _mm512_sub_pd(x, e);

            let dv = _mm512_loadu_pd(d.as_ptr().add(j));
            let den = _mm512_mul_pd(k015, dv);

            // A lane is valid when den != 0, |den| != inf and den is not NaN, mirroring the
            // scalar `denom != 0.0 && denom.is_finite()` test.
            let neq0 = _mm512_cmp_pd_mask(den, zero, _CMP_NEQ_OQ);
            let not_inf = _mm512_cmp_pd_mask(_mm512_abs_pd(den), infv, _CMP_NEQ_OQ);
            let ord = _mm512_cmp_pd_mask(den, den, _CMP_ORD_Q);
            let valid = neq0 & not_inf & ord;

            // Take the quotient in valid lanes and 0.0 elsewhere.
            let q = _mm512_div_pd(num, den);
            let outv = _mm512_mask_blend_pd(valid, zero, q);
            _mm512_storeu_pd(ci.as_mut_ptr().add(j), outv);
            j += 8;
        }
        // Scalar tail for the remaining (< 8) elements.
        while j < cols {
            let denom = 0.015 * d[j];
            ci[j] = if denom != 0.0 && denom.is_finite() {
                (data[j] - esa[j]) / denom
            } else {
                0.0
            };
            j += 1;
        }

        apply_ci_to_members(ci, members, start_ci, cols, out_ptr, parallel)?;
    }

    Ok(())
}
1688
1689fn wto_fill_wt1_row(
1690    data: &[f64],
1691    p: WtoParams,
1692    first: usize,
1693    kern: Kernel,
1694    dst: &mut [f64],
1695) -> Result<(), WtoError> {
1696    let cols = data.len();
1697    let channel_length = p.channel_length.unwrap_or(10);
1698    let average_length = p.average_length.unwrap_or(21);
1699
1700    let mut mu = make_uninit_matrix(2, cols);
1701    let warms = [first + channel_length - 1, first + channel_length - 1];
1702    init_matrix_prefixes(&mut mu, cols, &warms);
1703
1704    let mut guard = core::mem::ManuallyDrop::new(mu);
1705    let flat: &mut [f64] =
1706        unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len()) };
1707    let (d, ci) = flat.split_at_mut(cols);
1708
1709    ema_pinescript_into(data, channel_length, first, dst);
1710
1711    for i in 0..cols {
1712        ci[i] = (data[i] - dst[i]).abs();
1713    }
1714
1715    let d_first = first + channel_length - 1;
1716    ema_pinescript_into(ci, channel_length, d_first, d);
1717
1718    let start = first + channel_length - 1;
1719    for i in start..cols {
1720        let denom = 0.015 * d[i];
1721        ci[i] = if denom.is_finite() && denom != 0.0 {
1722            (data[i] - dst[i]) / denom
1723        } else {
1724            0.0
1725        };
1726    }
1727
1728    let ci_first = start;
1729    ema_pinescript_into(ci, average_length, ci_first, dst);
1730
1731    Ok(())
1732}
1733
1734#[inline(always)]
1735fn expand_grid(r: &WtoBatchRange) -> Result<Vec<WtoParams>, WtoError> {
1736    fn axis_u((s, e, st): (usize, usize, usize)) -> Result<Vec<usize>, WtoError> {
1737        if st == 0 || s == e {
1738            return Ok(vec![s]);
1739        }
1740        let mut out = Vec::new();
1741        if s < e {
1742            let mut v = s;
1743            loop {
1744                if v > e {
1745                    break;
1746                }
1747                out.push(v);
1748                let next = v.checked_add(st).ok_or_else(|| WtoError::InvalidRange {
1749                    start: s.to_string(),
1750                    end: e.to_string(),
1751                    step: st.to_string(),
1752                })?;
1753                if next == v {
1754                    break;
1755                }
1756                v = next;
1757            }
1758        } else {
1759            let mut v = s;
1760            loop {
1761                if v < e {
1762                    break;
1763                }
1764                out.push(v);
1765                if v - e < st {
1766                    break;
1767                }
1768                v -= st;
1769            }
1770        }
1771        if out.is_empty() {
1772            return Err(WtoError::InvalidRange {
1773                start: s.to_string(),
1774                end: e.to_string(),
1775                step: st.to_string(),
1776            });
1777        }
1778        Ok(out)
1779    }
1780    let ch = axis_u(r.channel)?;
1781    let av = axis_u(r.average)?;
1782    let mut out = Vec::with_capacity(ch.len() * av.len());
1783    for &c in &ch {
1784        for &a in &av {
1785            out.push(WtoParams {
1786                channel_length: Some(c),
1787                average_length: Some(a),
1788            });
1789        }
1790    }
1791    if out.is_empty() {
1792        return Err(WtoError::InvalidRange {
1793            start: r.channel.0.to_string(),
1794            end: r.channel.1.to_string(),
1795            step: r.channel.2.to_string(),
1796        });
1797    }
1798    Ok(out)
1799}
1800
1801pub fn wto_batch_with_kernel(
1802    data: &[f64],
1803    sweep: &WtoBatchRange,
1804    k: Kernel,
1805) -> Result<WtoBatchOutput, WtoError> {
1806    let kern = match k {
1807        Kernel::Auto => detect_best_batch_kernel(),
1808        x if x.is_batch() => x,
1809        other => {
1810            return Err(WtoError::InvalidKernelForBatch(other));
1811        }
1812    };
1813    let simd = match kern {
1814        Kernel::Avx512Batch => Kernel::Avx512,
1815        Kernel::Avx2Batch => Kernel::Avx2,
1816        Kernel::ScalarBatch => Kernel::Scalar,
1817        _ => unreachable!(),
1818    };
1819    wto_batch_inner(data, sweep, simd, true)
1820}
1821
1822#[inline(always)]
1823fn wto_batch_inner(
1824    data: &[f64],
1825    sweep: &WtoBatchRange,
1826    kern: Kernel,
1827    parallel: bool,
1828) -> Result<WtoBatchOutput, WtoError> {
1829    if data.is_empty() {
1830        return Err(WtoError::EmptyInputData);
1831    }
1832    let combos = expand_grid(sweep)?;
1833
1834    let cols = data.len();
1835    let rows = combos.len();
1836    if rows.checked_mul(cols).is_none() {
1837        return Err(WtoError::InvalidInput(
1838            "rows * cols overflow in wto_batch_inner".into(),
1839        ));
1840    }
1841
1842    let first = data
1843        .iter()
1844        .position(|x| !x.is_nan())
1845        .ok_or(WtoError::AllValuesNaN)?;
1846
1847    let mut mu = make_uninit_matrix(rows, cols);
1848    {
1849        let mut warms: Vec<usize> = Vec::with_capacity(rows);
1850        for p in combos.iter() {
1851            let channel_length = p.channel_length.unwrap_or(10);
1852            let average_length = p.average_length.unwrap_or(21);
1853
1854            if channel_length == 0 || channel_length > cols {
1855                return Err(WtoError::InvalidPeriod {
1856                    period: channel_length,
1857                    data_len: cols,
1858                });
1859            }
1860            if average_length == 0 || average_length > cols {
1861                return Err(WtoError::InvalidPeriod {
1862                    period: average_length,
1863                    data_len: cols,
1864                });
1865            }
1866
1867            let valid = cols.saturating_sub(first);
1868            if valid < channel_length {
1869                return Err(WtoError::NotEnoughValidData {
1870                    needed: channel_length,
1871                    valid,
1872                });
1873            }
1874
1875            let ci_start = first + channel_length - 1;
1876            warms.push(ci_start);
1877        }
1878        init_matrix_prefixes(&mut mu, cols, &warms);
1879    }
1880    let mut guard = core::mem::ManuallyDrop::new(mu);
1881    let out: &mut [f64] =
1882        unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len()) };
1883
1884    wto_fill_wt1_grouped(data, &combos, first, kern, parallel, out)?;
1885
1886    let values = unsafe {
1887        Vec::from_raw_parts(
1888            guard.as_mut_ptr() as *mut f64,
1889            guard.len(),
1890            guard.capacity(),
1891        )
1892    };
1893    core::mem::forget(guard);
1894    Ok(WtoBatchOutput {
1895        values,
1896        combos,
1897        rows,
1898        cols,
1899    })
1900}
1901
1902pub fn wto_batch_slice(
1903    data: &[f64],
1904    channel_range: (usize, usize, usize),
1905    average_range: (usize, usize, usize),
1906    kernel: Kernel,
1907) -> Result<WtoBatchOutput, WtoError> {
1908    let sweep = WtoBatchRange {
1909        channel: channel_range,
1910        average: average_range,
1911    };
1912    wto_batch_with_kernel(data, &sweep, kernel)
1913}
1914
1915pub fn wto_batch_candles(
1916    candles: &Candles,
1917    source: &str,
1918    channel_range: (usize, usize, usize),
1919    average_range: (usize, usize, usize),
1920    kernel: Kernel,
1921) -> Result<WtoBatchOutput, WtoError> {
1922    let data = source_type(candles, source);
1923    wto_batch_slice(data, channel_range, average_range, kernel)
1924}
1925
/// All three WaveTrend outputs for a parameter sweep.
///
/// `wt1`, `wt2` and `hist` are row-major `rows x cols` matrices. Row `r` was computed with
/// `combos[r]`, and `cols` is the input length.
#[derive(Debug, Clone)]
pub struct WtoBatchAllOutput {
    /// WT1 line per combination.
    pub wt1: Vec<f64>,
    /// WT2 line: 4-period SMA of the WT1 row.
    pub wt2: Vec<f64>,
    /// Histogram: WT1 minus WT2.
    pub hist: Vec<f64>,
    pub combos: Vec<WtoParams>,
    pub rows: usize,
    pub cols: usize,
}
1935
1936pub fn wto_batch_all_outputs_with_kernel(
1937    data: &[f64],
1938    sweep: &WtoBatchRange,
1939    k: Kernel,
1940) -> Result<WtoBatchAllOutput, WtoError> {
1941    if data.is_empty() {
1942        return Err(WtoError::EmptyInputData);
1943    }
1944
1945    let combos = expand_grid(sweep)?;
1946
1947    let cols = data.len();
1948    let rows = combos.len();
1949    if rows.checked_mul(cols).is_none() {
1950        return Err(WtoError::InvalidInput(
1951            "rows * cols overflow in wto_batch_all_outputs_with_kernel".into(),
1952        ));
1953    }
1954
1955    let mut wt1_mu = make_uninit_matrix(rows, cols);
1956    let mut wt2_mu = make_uninit_matrix(rows, cols);
1957    let mut hist_mu = make_uninit_matrix(rows, cols);
1958
1959    let first = data
1960        .iter()
1961        .position(|x| !x.is_nan())
1962        .ok_or(WtoError::AllValuesNaN)?;
1963    let mut warm_wt1: Vec<usize> = Vec::with_capacity(rows);
1964    let mut warm_wt2_hist: Vec<usize> = Vec::with_capacity(rows);
1965    for p in combos.iter() {
1966        let channel_length = p.channel_length.unwrap_or(10);
1967        let average_length = p.average_length.unwrap_or(21);
1968
1969        if channel_length == 0 || channel_length > cols {
1970            return Err(WtoError::InvalidPeriod {
1971                period: channel_length,
1972                data_len: cols,
1973            });
1974        }
1975        if average_length == 0 || average_length > cols {
1976            return Err(WtoError::InvalidPeriod {
1977                period: average_length,
1978                data_len: cols,
1979            });
1980        }
1981
1982        let valid = cols.saturating_sub(first);
1983        let needed = channel_length
1984            .saturating_add(3)
1985            .max(average_length.saturating_add(3));
1986        if valid < needed {
1987            return Err(WtoError::NotEnoughValidData { needed, valid });
1988        }
1989
1990        let ci_start = first + channel_length - 1;
1991        warm_wt1.push(ci_start);
1992        warm_wt2_hist.push(ci_start + 3);
1993    }
1994
1995    init_matrix_prefixes(&mut wt1_mu, cols, &warm_wt1);
1996    init_matrix_prefixes(&mut wt2_mu, cols, &warm_wt2_hist);
1997    init_matrix_prefixes(&mut hist_mu, cols, &warm_wt2_hist);
1998
1999    let mut wt1_guard = core::mem::ManuallyDrop::new(wt1_mu);
2000    let mut wt2_guard = core::mem::ManuallyDrop::new(wt2_mu);
2001    let mut hist_guard = core::mem::ManuallyDrop::new(hist_mu);
2002
2003    let wt1_out: &mut [f64] = unsafe {
2004        core::slice::from_raw_parts_mut(wt1_guard.as_mut_ptr() as *mut f64, wt1_guard.len())
2005    };
2006    let wt2_out: &mut [f64] = unsafe {
2007        core::slice::from_raw_parts_mut(wt2_guard.as_mut_ptr() as *mut f64, wt2_guard.len())
2008    };
2009    let hist_out: &mut [f64] = unsafe {
2010        core::slice::from_raw_parts_mut(hist_guard.as_mut_ptr() as *mut f64, hist_guard.len())
2011    };
2012
2013    let kern = match k {
2014        Kernel::Auto => Kernel::Scalar,
2015        x => x,
2016    };
2017
2018    wto_fill_wt1_grouped(data, &combos, first, kern, true, wt1_out)?;
2019
2020    for row in 0..rows {
2021        let row_start = row.checked_mul(cols).ok_or_else(|| {
2022            WtoError::InvalidInput(
2023                "row * cols overflow in wto_batch_all_outputs_with_kernel".into(),
2024            )
2025        })?;
2026        let row_end = row_start + cols;
2027
2028        let wt1_row = &wt1_out[row_start..row_end];
2029        let wt2_row = &mut wt2_out[row_start..row_end];
2030        let hist_row = &mut hist_out[row_start..row_end];
2031
2032        let sma_input = SmaInput::from_slice(wt1_row, SmaParams { period: Some(4) });
2033        sma_into_slice(wt2_row, &sma_input, kern)
2034            .map_err(|e| WtoError::ComputationError(format!("WT2 SMA error: {}", e)))?;
2035
2036        for i in 0..cols {
2037            if !wt1_row[i].is_nan() && !wt2_row[i].is_nan() {
2038                hist_row[i] = wt1_row[i] - wt2_row[i];
2039            }
2040        }
2041    }
2042
2043    let wt1 = unsafe {
2044        Vec::from_raw_parts(
2045            wt1_guard.as_mut_ptr() as *mut f64,
2046            wt1_guard.len(),
2047            wt1_guard.capacity(),
2048        )
2049    };
2050    let wt2 = unsafe {
2051        Vec::from_raw_parts(
2052            wt2_guard.as_mut_ptr() as *mut f64,
2053            wt2_guard.len(),
2054            wt2_guard.capacity(),
2055        )
2056    };
2057    let hist = unsafe {
2058        Vec::from_raw_parts(
2059            hist_guard.as_mut_ptr() as *mut f64,
2060            hist_guard.len(),
2061            hist_guard.capacity(),
2062        )
2063    };
2064
2065    core::mem::forget(wt1_guard);
2066    core::mem::forget(wt2_guard);
2067    core::mem::forget(hist_guard);
2068
2069    Ok(WtoBatchAllOutput {
2070        wt1,
2071        wt2,
2072        hist,
2073        combos,
2074        rows,
2075        cols,
2076    })
2077}
2078
/// Incremental (one sample at a time) WaveTrend oscillator state.
///
/// Each accepted sample advances:
/// - `esa`: an EMA of the input with weight `esa_alpha = 2 / (channel_length + 1)`,
/// - `d`: an EMA (same weight) of the absolute deviation `|input - esa|`,
/// - `tci`: an EMA with weight `tci_alpha = 2 / (average_length + 1)` of the channel
///   index `ci = (input - esa) / (0.015 * d)`; `tci` is the emitted WT1 value,
/// - a 4-slot ring buffer of WT1 values whose mean is the emitted WT2 value.
#[derive(Debug, Clone)]
pub struct WtoStream {
    channel_length: usize,
    average_length: usize,

    // EMA weights: `*_alpha` multiplies the new value, `*_beta` (= 1 - alpha) the previous state.
    esa_alpha: f64,
    esa_beta: f64,
    tci_alpha: f64,
    tci_beta: f64,
    // Scale factor 0.015 applied to `d` in the channel-index denominator.
    k015: f64,
    // 1/4: converts the ring-buffer sum into the WT2 mean.
    inv4: f64,

    // Number of finite samples consumed during warm-up (not advanced once `ci_ready`).
    samples: usize,
    // Set once a channel index (and therefore WT1) can be produced.
    ci_ready: bool,

    // Running ESA, deviation EMA and TCI (= WT1) state.
    esa: f64,
    d: f64,
    tci: f64,

    // Last (up to) four WT1 values, their running sum, next slot to overwrite once full,
    // and how many slots are filled.
    ring: [f64; 4],
    rsum: f64,
    rpos: usize,
    rlen: usize,
}
2103
/// Controls how `WtoStream` reports WT2/histogram before its 4-value WT1 window is full:
/// `true` yields NaN for both, `false` uses the mean of the WT1 values collected so far.
const STRICT_WT2_NANS: bool = true;
2105
2106impl WtoStream {
2107    pub fn try_new(params: WtoParams) -> Result<Self, WtoError> {
2108        let channel_length = params.channel_length.unwrap_or(10);
2109        let average_length = params.average_length.unwrap_or(21);
2110
2111        if channel_length == 0 {
2112            return Err(WtoError::InvalidPeriod {
2113                period: channel_length,
2114                data_len: 0,
2115            });
2116        }
2117        if average_length == 0 {
2118            return Err(WtoError::InvalidPeriod {
2119                period: average_length,
2120                data_len: 0,
2121            });
2122        }
2123
2124        let esa_alpha = 2.0 / (channel_length as f64 + 1.0);
2125        let tci_alpha = 2.0 / (average_length as f64 + 1.0);
2126
2127        Ok(Self {
2128            channel_length,
2129            average_length,
2130            esa_alpha,
2131            esa_beta: 1.0 - esa_alpha,
2132            tci_alpha,
2133            tci_beta: 1.0 - tci_alpha,
2134            k015: 0.015_f64,
2135            inv4: 0.25_f64,
2136
2137            samples: 0,
2138            ci_ready: channel_length == 1,
2139
2140            esa: 0.0,
2141            d: 0.0,
2142            tci: 0.0,
2143
2144            ring: [0.0; 4],
2145            rsum: 0.0,
2146            rpos: 0,
2147            rlen: 0,
2148        })
2149    }
2150
2151    #[inline(always)]
2152    fn fast_abs(x: f64) -> f64 {
2153        f64::from_bits(x.to_bits() & 0x7FFF_FFFF_FFFF_FFFF)
2154    }
2155
2156    pub fn update(&mut self, value: f64) -> Option<(f64, f64, f64)> {
2157        if !value.is_finite() {
2158            return None;
2159        }
2160
2161        if self.samples == 0 {
2162            self.esa = value;
2163            self.samples = 1;
2164
2165            if self.ci_ready {
2166                let ci = 0.0;
2167                self.tci = ci;
2168                self.push_wt2(ci);
2169
2170                let (wt2, hist) = self.emit_wt2(ci);
2171                return Some((ci, wt2, hist));
2172            }
2173            return None;
2174        }
2175
2176        self.esa = self.esa_beta.mul_add(self.esa, self.esa_alpha * value);
2177
2178        if !self.ci_ready {
2179            self.samples += 1;
2180
2181            if self.samples == self.channel_length {
2182                self.d = Self::fast_abs(value - self.esa);
2183
2184                let denom = self.k015 * self.d;
2185                let ci = if denom != 0.0 && denom.is_finite() {
2186                    (value - self.esa) / denom
2187                } else {
2188                    0.0
2189                };
2190
2191                self.tci = ci;
2192                self.ci_ready = true;
2193
2194                self.push_wt2(ci);
2195                let (wt2, hist) = self.emit_wt2(ci);
2196                return Some((ci, wt2, hist));
2197            } else {
2198                return None;
2199            }
2200        }
2201
2202        let ad = Self::fast_abs(value - self.esa);
2203        self.d = self.esa_beta.mul_add(self.d, self.esa_alpha * ad);
2204
2205        let mut ci = 0.0;
2206        let denom = self.k015 * self.d;
2207        if denom != 0.0 && denom.is_finite() {
2208            ci = (value - self.esa) * (1.0 / denom);
2209        }
2210
2211        self.tci = self.tci_beta.mul_add(self.tci, self.tci_alpha * ci);
2212
2213        let wt1 = self.tci;
2214
2215        self.push_wt2(wt1);
2216        let (wt2, hist) = self.emit_wt2(wt1);
2217        Some((wt1, wt2, hist))
2218    }
2219
2220    #[inline(always)]
2221    fn push_wt2(&mut self, val: f64) {
2222        if self.rlen < 4 {
2223            self.ring[self.rlen] = val;
2224            self.rsum += val;
2225            self.rlen += 1;
2226        } else {
2227            self.rsum += val - self.ring[self.rpos];
2228            self.ring[self.rpos] = val;
2229            self.rpos = (self.rpos + 1) & 3;
2230        }
2231    }
2232
2233    #[inline(always)]
2234    fn emit_wt2(&self, wt1: f64) -> (f64, f64) {
2235        if self.rlen == 4 {
2236            let sig = self.inv4 * self.rsum;
2237            (sig, wt1 - sig)
2238        } else if STRICT_WT2_NANS {
2239            (f64::NAN, f64::NAN)
2240        } else {
2241            let sig = self.rsum / (self.rlen as f64);
2242            (sig, wt1 - sig)
2243        }
2244    }
2245
2246    pub fn last(&self) -> Option<(f64, f64, f64)> {
2247        if !self.ci_ready {
2248            return None;
2249        }
2250        let wt1 = self.tci;
2251        let (wt2, hist) = self.emit_wt2(wt1);
2252        Some((wt1, wt2, hist))
2253    }
2254
2255    pub fn reset(&mut self) {
2256        self.samples = 0;
2257        self.ci_ready = self.channel_length == 1;
2258
2259        self.esa = 0.0;
2260        self.d = 0.0;
2261        self.tci = 0.0;
2262
2263        self.ring = [0.0; 4];
2264        self.rsum = 0.0;
2265        self.rpos = 0;
2266        self.rlen = 0;
2267    }
2268}
2269
/// Python wrapper, exposed to Python as `WtoStream`, around the Rust `WtoStream`
/// streaming calculator.
#[cfg(feature = "python")]
#[pyclass(name = "WtoStream")]
pub struct WtoStreamPy {
    // Wrapped Rust streaming state; every Python method delegates to it.
    inner: WtoStream,
}
2275
#[cfg(feature = "python")]
#[pymethods]
impl WtoStreamPy {
    // Python constructor `WtoStream(channel_length, average_length)`; any error from
    // `WtoStream::try_new` (e.g. a zero length) is raised as `ValueError`.
    #[new]
    fn new(channel_length: usize, average_length: usize) -> PyResult<Self> {
        let params = WtoParams {
            channel_length: Some(channel_length),
            average_length: Some(average_length),
        };
        WtoStream::try_new(params)
            .map(|inner| Self { inner })
            .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feeds one value; returns `(wt1, wt2, hist)` or `None` when the inner stream has
    /// nothing to report for this value.
    pub fn update(&mut self, value: f64) -> Option<(f64, f64, f64)> {
        let stream = &mut self.inner;
        stream.update(value)
    }

    /// Returns the most recent `(wt1, wt2, hist)` reported by the inner stream, if any.
    pub fn last(&self) -> Option<(f64, f64, f64)> {
        let stream = &self.inner;
        stream.last()
    }

    /// Clears the inner stream's state.
    pub fn reset(&mut self) {
        let stream = &mut self.inner;
        stream.reset()
    }
}
2302
/// Computes WTO for every `(channel_length, average_length)` combination in the sweep.
///
/// Returns a dict with:
/// - `wt1`, `wt2`, `hist`: 2-D `float64` arrays of shape `(rows, cols)`, one row per
///   combination and one column per input sample;
/// - `channel_lengths`, `average_lengths`: the parameters of each row.
///
/// Raises `ValueError` if the computation fails.
#[cfg(feature = "python")]
#[pyfunction(name = "wto_batch")]
#[pyo3(signature = (close, channel_range, average_range, kernel=None))]
pub fn wto_batch_py<'py>(
    py: Python<'py>,
    close: numpy::PyReadonlyArray1<'py, f64>,
    channel_range: (usize, usize, usize),
    average_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArrayMethods};
    let slice = close.as_slice()?;
    let kern = validate_kernel(kernel, false)?;

    let sweep = WtoBatchRange {
        channel: channel_range,
        average: average_range,
    };
    let WtoBatchAllOutput {
        wt1,
        wt2,
        hist,
        combos,
        rows,
        cols,
    } = py
        .allow_threads(|| wto_batch_all_outputs_with_kernel(slice, &sweep, kern))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = pyo3::types::PyDict::new(py);

    // Hand the row-major Rust buffers to NumPy by ownership (no copy, no uninitialised
    // allocation), then view each as (rows, cols).
    dict.set_item("wt1", wt1.into_pyarray(py).reshape((rows, cols))?)?;
    dict.set_item("wt2", wt2.into_pyarray(py).reshape((rows, cols))?)?;
    dict.set_item("hist", hist.into_pyarray(py).reshape((rows, cols))?)?;

    dict.set_item(
        "channel_lengths",
        combos
            .iter()
            .map(|p| p.channel_length.unwrap())
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "average_lengths",
        combos
            .iter()
            .map(|p| p.average_length.unwrap())
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    Ok(dict)
}
2357
/// Runs the WTO parameter sweep on a CUDA device and returns device-resident outputs.
///
/// The returned dict holds `wt1`, `wt2` and `hist` as device arrays, the per-row
/// `channel_lengths` / `average_lengths`, and `rows` (number of combinations) and
/// `cols` (input length).
///
/// Raises `ValueError` if CUDA is unavailable or the CUDA backend reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "wto_cuda_batch_dev")]
#[pyo3(signature = (close_f32, channel_range, average_range, device_id=0))]
pub fn wto_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    close_f32: numpy::PyReadonlyArray1<'py, f32>,
    channel_range: (usize, usize, usize),
    average_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::IntoPyArray;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let prices = close_f32.as_slice()?;
    let range = WtoBatchRange {
        channel: channel_range,
        average: average_range,
    };

    // The GPU work runs with the GIL released; backend errors become ValueError.
    let (batch, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let engine =
            CudaWto::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let batch = engine
            .wto_batch_dev(prices, &range)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((batch, engine.device_id()))
    })?;

    let CudaWtoBatchResult { outputs, combos } = batch;
    let DeviceArrayF32Triplet { wt1, wt2, hist } = outputs;
    let device = dev_id as usize;

    let result = pyo3::types::PyDict::new(py);
    let wt1_handle = make_device_array_py(device, wt1)?;
    let wt2_handle = make_device_array_py(device, wt2)?;
    let hist_handle = make_device_array_py(device, hist)?;
    result.set_item("wt1", Py::new(py, wt1_handle)?)?;
    result.set_item("wt2", Py::new(py, wt2_handle)?)?;
    result.set_item("hist", Py::new(py, hist_handle)?)?;

    let (channel_lengths, average_lengths): (Vec<usize>, Vec<usize>) = combos
        .iter()
        .map(|p| (p.channel_length.unwrap(), p.average_length.unwrap()))
        .unzip();

    result.set_item("channel_lengths", channel_lengths.into_pyarray(py))?;
    result.set_item("average_lengths", average_lengths.into_pyarray(py))?;
    result.set_item("rows", combos.len())?;
    result.set_item("cols", prices.len())?;

    Ok(result)
}
2407
/// Runs WTO with a single parameter set over a contiguous 2-D `float32` array on a CUDA
/// device and returns device-resident `wt1` / `wt2` / `hist` plus the input shape and
/// parameters.
///
/// The array shape is reported as `rows` x `cols`; given the "time_major" name, rows are
/// presumably time steps and columns series. NOTE(review): confirm against
/// `CudaWto::wto_many_series_one_param_time_major_dev`, which receives `(cols, rows)`.
///
/// Raises `ValueError` if CUDA is unavailable, the input is not 2-D, or the backend fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "wto_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, channel_length, average_length, device_id=0))]
pub fn wto_cuda_many_series_one_param_dev_py<'py>(
    py: Python<'py>,
    data_tm_f32: numpy::PyReadonlyArray2<'py, f32>,
    channel_length: usize,
    average_length: usize,
    device_id: usize,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::PyUntypedArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let (rows, cols) = match data_tm_f32.shape() {
        &[r, c] => (r, c),
        _ => return Err(PyValueError::new_err("expected 2D array")),
    };
    let flat = data_tm_f32.as_slice()?;

    let params = WtoParams {
        channel_length: Some(channel_length),
        average_length: Some(average_length),
    };

    // The GPU work runs with the GIL released; backend errors become ValueError.
    let (triplet, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let engine =
            CudaWto::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let triplet = engine
            .wto_many_series_one_param_time_major_dev(flat, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((triplet, engine.device_id()))
    })?;
    let DeviceArrayF32Triplet { wt1, wt2, hist } = triplet;
    let device = dev_id as usize;

    let result = pyo3::types::PyDict::new(py);
    let wt1_handle = make_device_array_py(device, wt1)?;
    let wt2_handle = make_device_array_py(device, wt2)?;
    let hist_handle = make_device_array_py(device, hist)?;
    result.set_item("wt1", Py::new(py, wt1_handle)?)?;
    result.set_item("wt2", Py::new(py, wt2_handle)?)?;
    result.set_item("hist", Py::new(py, hist_handle)?)?;

    // Shape and parameter metadata, inserted in a fixed order.
    for (key, value) in [
        ("rows", rows),
        ("cols", cols),
        ("channel_length", channel_length),
        ("average_length", average_length),
    ] {
        result.set_item(key, value)?;
    }

    Ok(result)
}
2459
/// Registers the WTO Python functions and the `WtoStream` class on `m`; the CUDA
/// entry points are added only when the `cuda` feature is enabled.
#[cfg(feature = "python")]
pub fn register_wto_module(m: &Bound<'_, pyo3::types::PyModule>) -> PyResult<()> {
    let single_fn = wrap_pyfunction!(wto_py, m)?;
    m.add_function(single_fn)?;
    let batch_fn = wrap_pyfunction!(wto_batch_py, m)?;
    m.add_function(batch_fn)?;
    m.add_class::<WtoStreamPy>()?;

    #[cfg(feature = "cuda")]
    {
        let cuda_batch_fn = wrap_pyfunction!(wto_cuda_batch_dev_py, m)?;
        m.add_function(cuda_batch_fn)?;
        let cuda_many_series_fn = wrap_pyfunction!(wto_cuda_many_series_one_param_dev_py, m)?;
        m.add_function(cuda_many_series_fn)?;
    }

    Ok(())
}
2472
2473#[cfg(test)]
2474mod tests {
2475    use super::*;
2476    use crate::utilities::data_loader::read_candles_from_csv;
2477
2478    macro_rules! skip_if_unsupported {
2479        ($kernel:expr, $test_name:expr) => {
2480            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
2481            if matches!(
2482                $kernel,
2483                Kernel::Avx2 | Kernel::Avx512 | Kernel::Avx2Batch | Kernel::Avx512Batch
2484            ) {
2485                eprintln!("[{}] Skipping due to missing AVX support", $test_name);
2486                return Ok(());
2487            }
2488        };
2489    }
2490
2491    fn check_wto_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2492        skip_if_unsupported!(kernel, test_name);
2493        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2494        let candles = read_candles_from_csv(file_path)?;
2495
2496        let input = WtoInput::from_candles(&candles, "close", WtoParams::default());
2497        let result = wto_with_kernel(&input, kernel)?;
2498
2499        let expected_wt1 = [
2500            -34.81423091,
2501            -33.92872278,
2502            -35.29125217,
2503            -34.93917015,
2504            -41.42578524,
2505        ];
2506
2507        let expected_wt2 = [
2508            -37.72141493,
2509            -35.54009606,
2510            -34.81718669,
2511            -34.74334400,
2512            -36.39623258,
2513        ];
2514
2515        let expected_hist = [
2516            2.90718403,
2517            1.61137328,
2518            -0.47406548,
2519            -0.19582615,
2520            -5.02955265,
2521        ];
2522
2523        let start = result.wavetrend1.len().saturating_sub(5);
2524
2525        for (i, &val) in result.wavetrend1[start..].iter().enumerate() {
2526            let diff = (val - expected_wt1[i]).abs();
2527
2528            let rel_tolerance = expected_wt1[i].abs() * 0.1;
2529            let abs_tolerance = 1e-6;
2530            assert!(
2531                diff < rel_tolerance.max(abs_tolerance),
2532                "WaveTrend1 mismatch at idx {}: got {}, expected {}, diff {}",
2533                i,
2534                val,
2535                expected_wt1[i],
2536                diff
2537            );
2538        }
2539
2540        for (i, &val) in result.wavetrend2[start..].iter().enumerate() {
2541            let diff = (val - expected_wt2[i]).abs();
2542            let rel_tolerance = expected_wt2[i].abs() * 0.1;
2543            let abs_tolerance = 1e-6;
2544            assert!(
2545                diff < rel_tolerance.max(abs_tolerance),
2546                "WaveTrend2 mismatch at idx {}: got {}, expected {}, diff {}",
2547                i,
2548                val,
2549                expected_wt2[i],
2550                diff
2551            );
2552        }
2553
2554        for (i, &val) in result.histogram[start..].iter().enumerate() {
2555            let diff = (val - expected_hist[i]).abs();
2556
2557            let abs_tolerance = 2.0;
2558            assert!(
2559                diff < abs_tolerance,
2560                "Histogram mismatch at idx {}: got {}, expected {}, diff {}",
2561                i,
2562                val,
2563                expected_hist[i],
2564                diff
2565            );
2566        }
2567
2568        Ok(())
2569    }
2570
2571    fn check_wto_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2572        skip_if_unsupported!(kernel, test_name);
2573
2574        let data = vec![1.0; 100];
2575        let params = WtoParams {
2576            channel_length: Some(12),
2577            average_length: None,
2578        };
2579        let input = WtoInput::from_slice(&data, params);
2580        let result = wto_with_kernel(&input, kernel)?;
2581
2582        assert_eq!(result.wavetrend1.len(), data.len());
2583        assert_eq!(result.wavetrend2.len(), data.len());
2584        assert_eq!(result.histogram.len(), data.len());
2585        Ok(())
2586    }
2587
2588    fn check_wto_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2589        skip_if_unsupported!(kernel, test_name);
2590
2591        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2592        let candles = read_candles_from_csv(file_path)?;
2593
2594        let input = WtoInput::with_default_candles(&candles);
2595        match input.data {
2596            WtoData::Candles { source, .. } => assert_eq!(source, "close"),
2597            _ => panic!("Expected WtoData::Candles"),
2598        }
2599        let output = wto_with_kernel(&input, kernel)?;
2600        assert_eq!(output.wavetrend1.len(), candles.close.len());
2601
2602        Ok(())
2603    }
2604
2605    fn check_wto_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2606        skip_if_unsupported!(kernel, test_name);
2607
2608        let data = [10.0, 20.0, 30.0];
2609        let params = WtoParams {
2610            channel_length: Some(0),
2611            average_length: None,
2612        };
2613        let input = WtoInput::from_slice(&data, params);
2614        let res = wto_with_kernel(&input, kernel);
2615        assert!(
2616            res.is_err(),
2617            "[{}] WTO should fail with zero period",
2618            test_name
2619        );
2620        Ok(())
2621    }
2622
2623    fn check_wto_period_exceeds_length(
2624        test_name: &str,
2625        kernel: Kernel,
2626    ) -> Result<(), Box<dyn Error>> {
2627        skip_if_unsupported!(kernel, test_name);
2628
2629        let data = [10.0, 20.0, 30.0];
2630        let params = WtoParams {
2631            channel_length: Some(10),
2632            average_length: None,
2633        };
2634        let input = WtoInput::from_slice(&data, params);
2635        let res = wto_with_kernel(&input, kernel);
2636        assert!(
2637            res.is_err(),
2638            "[{}] WTO should fail with period exceeding length",
2639            test_name
2640        );
2641        Ok(())
2642    }
2643
2644    fn check_wto_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2645        skip_if_unsupported!(kernel, test_name);
2646
2647        let single_point = [42.0];
2648        let params = WtoParams::default();
2649        let input = WtoInput::from_slice(&single_point, params);
2650        let res = wto_with_kernel(&input, kernel);
2651        assert!(
2652            res.is_err(),
2653            "[{}] WTO should fail with insufficient data",
2654            test_name
2655        );
2656        Ok(())
2657    }
2658
2659    fn check_wto_empty_input(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2660        skip_if_unsupported!(kernel, test_name);
2661        let empty: [f64; 0] = [];
2662        let input = WtoInput::from_slice(&empty, WtoParams::default());
2663        let res = wto_with_kernel(&input, kernel);
2664        assert!(
2665            matches!(res, Err(WtoError::EmptyInputData)),
2666            "[{}] WTO should fail with empty input",
2667            test_name
2668        );
2669        Ok(())
2670    }
2671
2672    fn check_wto_all_nan(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2673        skip_if_unsupported!(kernel, test_name);
2674        let data = vec![f64::NAN; 50];
2675        let params = WtoParams::default();
2676        let input = WtoInput::from_slice(&data, params);
2677        let res = wto_with_kernel(&input, kernel);
2678        assert!(
2679            matches!(res, Err(WtoError::AllValuesNaN)),
2680            "[{}] WTO should fail with all NaN values",
2681            test_name
2682        );
2683        Ok(())
2684    }
2685
2686    fn check_wto_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2687        skip_if_unsupported!(kernel, test_name);
2688
2689        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2690        let candles = read_candles_from_csv(file_path)?;
2691
2692        let first_params = WtoParams::default();
2693        let first_input = WtoInput::from_candles(&candles, "close", first_params);
2694        let first_result = wto_with_kernel(&first_input, kernel)?;
2695
2696        let second_params = WtoParams::default();
2697        let second_input = WtoInput::from_slice(&first_result.wavetrend1, second_params);
2698        let second_result = wto_with_kernel(&second_input, kernel)?;
2699
2700        assert_eq!(
2701            second_result.wavetrend1.len(),
2702            first_result.wavetrend1.len()
2703        );
2704        Ok(())
2705    }
2706
2707    fn check_wto_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2708        skip_if_unsupported!(kernel, test_name);
2709
2710        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2711        let candles = read_candles_from_csv(file_path)?;
2712
2713        let input = WtoInput::from_candles(&candles, "close", WtoParams::default());
2714        let res = wto_with_kernel(&input, kernel)?;
2715
2716        assert_eq!(res.wavetrend1.len(), candles.close.len());
2717        if res.wavetrend1.len() > 50 {
2718            for (i, &val) in res.wavetrend1[50..].iter().enumerate() {
2719                assert!(
2720                    !val.is_nan(),
2721                    "[{}] Found unexpected NaN at index {}",
2722                    test_name,
2723                    50 + i
2724                );
2725            }
2726        }
2727        Ok(())
2728    }
2729
2730    fn check_wto_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2731        skip_if_unsupported!(kernel, test_name);
2732
2733        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2734        let candles = read_candles_from_csv(file_path)?;
2735
2736        let params = WtoParams::default();
2737        let input = WtoInput::from_candles(&candles, "close", params.clone());
2738        let batch_result = wto_with_kernel(&input, kernel)?;
2739
2740        let mut stream = WtoStream::try_new(params)?;
2741
2742        let mut stream_wt1 = Vec::new();
2743        let mut stream_wt2 = Vec::new();
2744        let mut stream_hist = Vec::new();
2745
2746        for i in 0..candles.close.len() {
2747            if let Some((wt1, wt2, hist)) = stream.update(candles.close[i]) {
2748                stream_wt1.push(wt1);
2749                stream_wt2.push(wt2);
2750                stream_hist.push(hist);
2751            }
2752        }
2753
2754        assert!(!stream_wt1.is_empty());
2755        Ok(())
2756    }
2757
2758    #[cfg(debug_assertions)]
2759    fn check_wto_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2760        skip_if_unsupported!(kernel, test_name);
2761
2762        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2763        let candles = read_candles_from_csv(file_path)?;
2764
2765        let test_params = vec![
2766            WtoParams {
2767                channel_length: Some(5),
2768                average_length: Some(10),
2769            },
2770            WtoParams {
2771                channel_length: Some(20),
2772                average_length: Some(40),
2773            },
2774            WtoParams {
2775                channel_length: Some(3),
2776                average_length: Some(7),
2777            },
2778        ];
2779
2780        for params in test_params {
2781            let input = WtoInput::from_candles(&candles, "close", params.clone());
2782            let output = wto_with_kernel(&input, kernel)?;
2783
2784            for (i, &val) in output.wavetrend1.iter().enumerate() {
2785                if val.is_nan() {
2786                    continue;
2787                }
2788
2789                let bits = val.to_bits();
2790                if bits == 0x11111111_11111111
2791                    || bits == 0x22222222_22222222
2792                    || bits == 0x33333333_33333333
2793                {
2794                    panic!(
2795                        "[{}] Found poison value {} (0x{:016X}) at index {} with params: {:?}",
2796                        test_name, val, bits, i, params
2797                    );
2798                }
2799            }
2800        }
2801
2802        Ok(())
2803    }
2804
2805    #[cfg(not(debug_assertions))]
2806    fn check_wto_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2807        Ok(())
2808    }
2809
2810    #[cfg(debug_assertions)]
2811    fn check_wto_no_poison_all(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2812        skip_if_unsupported!(kernel, test_name);
2813
2814        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2815        let c = read_candles_from_csv(file)?;
2816        let input = WtoInput::with_default_candles(&c);
2817        let out = wto_with_kernel(&input, kernel)?;
2818
2819        for series in [&out.wavetrend1, &out.wavetrend2, &out.histogram] {
2820            for &v in series {
2821                if v.is_nan() {
2822                    continue;
2823                }
2824                let b = v.to_bits();
2825                assert!(
2826                    b != 0x11111111_11111111
2827                        && b != 0x22222222_22222222
2828                        && b != 0x33333333_33333333,
2829                    "[{}] poison value 0x{:016X}",
2830                    test_name,
2831                    b
2832                );
2833            }
2834        }
2835        Ok(())
2836    }
2837
2838    fn check_batch_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2839        skip_if_unsupported!(kernel, test_name);
2840
2841        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2842        let candles = read_candles_from_csv(file)?;
2843
2844        let output = WtoBatchBuilder::new()
2845            .channel_range(5, 15, 5)
2846            .average_range(10, 30, 10)
2847            .kernel(kernel)
2848            .apply_candles(&candles, "close")?;
2849
2850        for &v in &output.values {
2851            if v.is_nan() {
2852                continue;
2853            }
2854            let b = v.to_bits();
2855            assert!(
2856                b != 0x11111111_11111111 && b != 0x22222222_22222222 && b != 0x33333333_33333333,
2857                "[{}] batch poison value 0x{:016X}",
2858                test_name,
2859                b
2860            );
2861        }
2862
2863        let sweep = WtoBatchRange {
2864            channel: (5, 15, 5),
2865            average: (10, 30, 10),
2866        };
2867        let data = source_type(&candles, "close");
2868        let full_out = wto_batch_all_outputs_with_kernel(data, &sweep, kernel)?;
2869
2870        for series in [&full_out.wt1, &full_out.wt2, &full_out.hist] {
2871            for &v in series {
2872                if v.is_nan() {
2873                    continue;
2874                }
2875                let b = v.to_bits();
2876                assert!(
2877                    b != 0x11111111_11111111
2878                        && b != 0x22222222_22222222
2879                        && b != 0x33333333_33333333,
2880                    "[{}] full batch poison value 0x{:016X}",
2881                    test_name,
2882                    b
2883                );
2884            }
2885        }
2886
2887        Ok(())
2888    }
2889
2890    macro_rules! generate_all_wto_tests {
2891        ($($test_fn:ident),*) => {
2892            paste::paste! {
2893                $(
2894                    #[test]
2895                    fn [<$test_fn _scalar>]() {
2896                        let _ = $test_fn(stringify!([<$test_fn _scalar>]), Kernel::Scalar);
2897                    }
2898                )*
2899                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2900                $(
2901                    #[test]
2902                    fn [<$test_fn _avx2>]() {
2903                        let _ = $test_fn(stringify!([<$test_fn _avx2>]), Kernel::Avx2);
2904                    }
2905                    #[test]
2906                    fn [<$test_fn _avx512>]() {
2907                        let _ = $test_fn(stringify!([<$test_fn _avx512>]), Kernel::Avx512);
2908                    }
2909                )*
2910                #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
2911                $(
2912                    #[test]
2913                    fn [<$test_fn _simd128>]() {
2914                        let _ = $test_fn(stringify!([<$test_fn _simd128>]), Kernel::Scalar);
2915                    }
2916                )*
2917            }
2918        }
2919    }
2920
2921    generate_all_wto_tests!(
2922        check_wto_accuracy,
2923        check_wto_partial_params,
2924        check_wto_default_candles,
2925        check_wto_zero_period,
2926        check_wto_period_exceeds_length,
2927        check_wto_very_small_dataset,
2928        check_wto_empty_input,
2929        check_wto_all_nan,
2930        check_wto_reinput,
2931        check_wto_nan_handling,
2932        check_wto_streaming,
2933        check_wto_no_poison,
2934        check_batch_poison
2935    );
2936
2937    #[cfg(debug_assertions)]
2938    #[test]
2939    fn test_wto_no_poison_all_scalar() {
2940        check_wto_no_poison_all("test_wto_no_poison_all_scalar", Kernel::Scalar).unwrap();
2941    }
2942
2943    #[cfg(all(debug_assertions, feature = "nightly-avx", target_arch = "x86_64"))]
2944    #[test]
2945    fn test_wto_no_poison_all_avx2() {
2946        check_wto_no_poison_all("test_wto_no_poison_all_avx2", Kernel::Avx2).unwrap();
2947    }
2948
2949    #[cfg(all(debug_assertions, feature = "nightly-avx", target_arch = "x86_64"))]
2950    #[test]
2951    fn test_wto_no_poison_all_avx512() {
2952        check_wto_no_poison_all("test_wto_no_poison_all_avx512", Kernel::Avx512).unwrap();
2953    }
2954
2955    #[cfg(all(debug_assertions, target_arch = "wasm32", target_feature = "simd128"))]
2956    #[test]
2957    fn test_wto_no_poison_all_simd128() {
2958        check_wto_no_poison_all("test_wto_no_poison_all_simd128", Kernel::Simd128).unwrap();
2959    }
2960
2961    fn check_batch_default_row(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2962        skip_if_unsupported!(kernel, test_name);
2963
2964        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2965        let candles = read_candles_from_csv(file_path)?;
2966
2967        let output = WtoBatchBuilder::new()
2968            .kernel(kernel)
2969            .apply_candles(&candles, "close")?;
2970
2971        let def = WtoParams::default();
2972        let row = output.values_for(&def).expect("default row missing");
2973
2974        assert_eq!(row.len(), candles.close.len());
2975        Ok(())
2976    }
2977
2978    fn check_batch_sweep(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2979        skip_if_unsupported!(kernel, test_name);
2980
2981        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2982        let candles = read_candles_from_csv(file_path)?;
2983
2984        let output = WtoBatchBuilder::new()
2985            .kernel(kernel)
2986            .channel_range(5, 15, 5)
2987            .average_range(10, 30, 10)
2988            .apply_candles(&candles, "close")?;
2989
2990        let expected_combos = 3 * 3;
2991        assert_eq!(output.combos.len(), expected_combos);
2992        assert_eq!(output.rows, expected_combos);
2993        assert_eq!(output.cols, candles.close.len());
2994
2995        Ok(())
2996    }
2997
2998    macro_rules! gen_batch_tests {
2999        ($fn_name:ident) => {
3000            paste::paste! {
3001                #[test]
3002                fn [<$fn_name _scalar>]() {
3003                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
3004                }
3005                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3006                #[test]
3007                fn [<$fn_name _avx2>]() {
3008                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
3009                }
3010                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3011                #[test]
3012                fn [<$fn_name _avx512>]() {
3013                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
3014                }
3015                #[test]
3016                fn [<$fn_name _auto_detect>]() {
3017                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
3018                }
3019            }
3020        };
3021    }
3022
3023    gen_batch_tests!(check_batch_default_row);
3024    gen_batch_tests!(check_batch_sweep);
3025
3026    #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
3027    #[test]
3028    fn test_wto_into_matches_api() {
3029        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3030        let candles = read_candles_from_csv(file_path).expect("failed to load candles");
3031
3032        let input = WtoInput::with_default_candles(&candles);
3033
3034        let base = wto(&input).expect("wto baseline failed");
3035
3036        let n = candles.close.len();
3037        let mut wt1 = vec![0.0; n];
3038        let mut wt2 = vec![0.0; n];
3039        let mut hist = vec![0.0; n];
3040
3041        wto_into(&input, &mut wt1, &mut wt2, &mut hist).expect("wto_into failed");
3042
3043        fn eq_or_both_nan(a: f64, b: f64) -> bool {
3044            (a.is_nan() && b.is_nan()) || (a == b) || ((a - b).abs() <= 1e-12)
3045        }
3046
3047        assert_eq!(wt1.len(), base.wavetrend1.len());
3048        assert_eq!(wt2.len(), base.wavetrend2.len());
3049        assert_eq!(hist.len(), base.histogram.len());
3050
3051        for i in 0..n {
3052            assert!(
3053                eq_or_both_nan(wt1[i], base.wavetrend1[i]),
3054                "wt1 mismatch at {}: into={}, api={}",
3055                i,
3056                wt1[i],
3057                base.wavetrend1[i]
3058            );
3059            assert!(
3060                eq_or_both_nan(wt2[i], base.wavetrend2[i]),
3061                "wt2 mismatch at {}: into={}, api={}",
3062                i,
3063                wt2[i],
3064                base.wavetrend2[i]
3065            );
3066            assert!(
3067                eq_or_both_nan(hist[i], base.histogram[i]),
3068                "hist mismatch at {}: into={}, api={}",
3069                i,
3070                hist[i],
3071                base.histogram[i]
3072            );
3073        }
3074    }
3075}