1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::{cuda_available, CudaWto, CudaWtoBatchResult, DeviceArrayF32Triplet};
3#[cfg(all(feature = "python", feature = "cuda"))]
4use crate::utilities::dlpack_cuda::{make_device_array_py, DeviceArrayF32Py};
5#[cfg(feature = "python")]
6use pyo3::exceptions::PyValueError;
7#[cfg(feature = "python")]
8use pyo3::prelude::*;
9
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use serde::{Deserialize, Serialize};
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use wasm_bindgen::prelude::*;
14
15use crate::utilities::data_loader::{source_type, Candles};
16use crate::utilities::enums::Kernel;
17use crate::utilities::helpers::{
18 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
19 make_uninit_matrix,
20};
21#[cfg(feature = "python")]
22use crate::utilities::kernel_validation::validate_kernel;
23use aligned_vec::{AVec, CACHELINE_ALIGN};
24
25#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
26use core::arch::x86_64::*;
27
28#[cfg(not(target_arch = "wasm32"))]
29use rayon::prelude::*;
30
31use std::collections::BTreeMap;
32use std::convert::AsRef;
33use std::error::Error;
34use std::mem::MaybeUninit;
35use thiserror::Error;
36
37use crate::indicators::moving_averages::ema::{
38 ema_into_slice, ema_with_kernel, EmaInput, EmaParams,
39};
40use crate::indicators::moving_averages::sma::{
41 sma_into_slice, sma_with_kernel, SmaInput, SmaParams,
42};
43
44#[derive(Debug, Clone)]
45pub enum WtoData<'a> {
46 Candles {
47 candles: &'a Candles,
48 source: &'a str,
49 },
50 Slice(&'a [f64]),
51}
52
53#[derive(Debug, Clone)]
54pub struct WtoOutput {
55 pub wavetrend1: Vec<f64>,
56 pub wavetrend2: Vec<f64>,
57 pub histogram: Vec<f64>,
58}
59
60#[derive(Debug, Clone)]
61#[cfg_attr(
62 all(target_arch = "wasm32", feature = "wasm"),
63 derive(Serialize, Deserialize)
64)]
65pub struct WtoParams {
66 pub channel_length: Option<usize>,
67 pub average_length: Option<usize>,
68}
69
70impl Default for WtoParams {
71 fn default() -> Self {
72 Self {
73 channel_length: Some(10),
74 average_length: Some(21),
75 }
76 }
77}
78
79#[derive(Debug, Clone)]
80pub struct WtoInput<'a> {
81 pub data: WtoData<'a>,
82 pub params: WtoParams,
83}
84
85impl<'a> WtoInput<'a> {
86 #[inline]
87 pub fn from_candles(c: &'a Candles, source: &'a str, p: WtoParams) -> Self {
88 Self {
89 data: WtoData::Candles { candles: c, source },
90 params: p,
91 }
92 }
93
94 #[inline]
95 pub fn from_slice(sl: &'a [f64], p: WtoParams) -> Self {
96 Self {
97 data: WtoData::Slice(sl),
98 params: p,
99 }
100 }
101
102 #[inline]
103 pub fn with_default_candles(c: &'a Candles) -> Self {
104 Self::from_candles(c, "close", WtoParams::default())
105 }
106
107 #[inline]
108 pub fn get_channel_length(&self) -> usize {
109 self.params.channel_length.unwrap_or(10)
110 }
111
112 #[inline]
113 pub fn get_average_length(&self) -> usize {
114 self.params.average_length.unwrap_or(21)
115 }
116}
117
118impl<'a> AsRef<[f64]> for WtoInput<'a> {
119 #[inline(always)]
120 fn as_ref(&self) -> &[f64] {
121 match &self.data {
122 WtoData::Slice(slice) => slice,
123 WtoData::Candles { candles, source } => source_type(candles, source),
124 }
125 }
126}
127
128#[derive(Copy, Clone, Debug)]
129pub struct WtoBuilder {
130 channel_length: Option<usize>,
131 average_length: Option<usize>,
132 kernel: Kernel,
133}
134
135impl Default for WtoBuilder {
136 fn default() -> Self {
137 Self {
138 channel_length: None,
139 average_length: None,
140 kernel: Kernel::Auto,
141 }
142 }
143}
144
145impl WtoBuilder {
146 #[inline(always)]
147 pub fn new() -> Self {
148 Self::default()
149 }
150
151 #[inline(always)]
152 pub fn channel_length(mut self, n: usize) -> Self {
153 self.channel_length = Some(n);
154 self
155 }
156
157 #[inline(always)]
158 pub fn average_length(mut self, n: usize) -> Self {
159 self.average_length = Some(n);
160 self
161 }
162
163 #[inline(always)]
164 pub fn kernel(mut self, k: Kernel) -> Self {
165 self.kernel = k;
166 self
167 }
168
169 #[inline(always)]
170 pub fn apply(self, c: &Candles) -> Result<WtoOutput, WtoError> {
171 let p = WtoParams {
172 channel_length: self.channel_length,
173 average_length: self.average_length,
174 };
175 let i = WtoInput::from_candles(c, "close", p);
176 wto_with_kernel(&i, self.kernel)
177 }
178
179 #[inline(always)]
180 pub fn apply_slice(self, d: &[f64]) -> Result<WtoOutput, WtoError> {
181 let p = WtoParams {
182 channel_length: self.channel_length,
183 average_length: self.average_length,
184 };
185 let i = WtoInput::from_slice(d, p);
186 wto_with_kernel(&i, self.kernel)
187 }
188
189 #[inline(always)]
190 pub fn into_stream(self) -> Result<WtoStream, WtoError> {
191 let p = WtoParams {
192 channel_length: self.channel_length,
193 average_length: self.average_length,
194 };
195 WtoStream::try_new(p)
196 }
197}
198
199#[derive(Debug, Error)]
200pub enum WtoError {
201 #[error("wto: Input data slice is empty.")]
202 EmptyInputData,
203 #[error("wto: All values are NaN.")]
204 AllValuesNaN,
205 #[error("wto: Invalid input: {0}")]
206 InvalidInput(String),
207 #[error("wto: Invalid period: period = {period}, data length = {data_len}")]
208 InvalidPeriod { period: usize, data_len: usize },
209 #[error("wto: Not enough valid data: needed = {needed}, valid = {valid}")]
210 NotEnoughValidData { needed: usize, valid: usize },
211 #[error("wto: Output length mismatch: expected {expected}, got {got}")]
212 OutputLengthMismatch { expected: usize, got: usize },
213 #[error("wto: Invalid range: start={start}, end={end}, step={step}")]
214 InvalidRange {
215 start: String,
216 end: String,
217 step: String,
218 },
219 #[error("wto: Invalid kernel for batch: {0:?}")]
220 InvalidKernelForBatch(crate::utilities::enums::Kernel),
221 #[error("wto: Computation error: {0}")]
222 ComputationError(String),
223}
224
225#[inline]
226pub fn wto(input: &WtoInput) -> Result<WtoOutput, WtoError> {
227 wto_with_kernel(input, Kernel::Auto)
228}
229
230#[inline]
231pub fn wto_with_kernel(input: &WtoInput, kernel: Kernel) -> Result<WtoOutput, WtoError> {
232 let (data, channel_length, average_length, first, chosen) = wto_prepare(input, kernel)?;
233 let len = data.len();
234
235 let ci_start = first + channel_length.saturating_sub(1);
236 let warm_wt1 = ci_start;
237 let warm_wt2_hist = ci_start.saturating_add(3);
238 let mut wavetrend1 = alloc_with_nan_prefix(len, warm_wt1);
239 let mut wavetrend2 = alloc_with_nan_prefix(len, warm_wt2_hist);
240 let mut histogram = alloc_with_nan_prefix(len, warm_wt2_hist);
241
242 wto_compute_into(
243 data,
244 channel_length,
245 average_length,
246 first,
247 chosen,
248 &mut wavetrend1,
249 &mut wavetrend2,
250 &mut histogram,
251 )?;
252
253 Ok(WtoOutput {
254 wavetrend1,
255 wavetrend2,
256 histogram,
257 })
258}
259
260#[inline]
261pub fn wto_into_slices(
262 wt1: &mut [f64],
263 wt2: &mut [f64],
264 hist: &mut [f64],
265 input: &WtoInput,
266 kernel: Kernel,
267) -> Result<(), WtoError> {
268 let (data, channel_length, average_length, first, chosen) = wto_prepare(input, kernel)?;
269
270 if wt1.len() != data.len() || wt2.len() != data.len() || hist.len() != data.len() {
271 let expected = data.len();
272 let got = wt1.len().max(wt2.len()).max(hist.len());
273 return Err(WtoError::OutputLengthMismatch { expected, got });
274 }
275
276 let ci_start = first + channel_length.saturating_sub(1);
277 let warm_wt1 = ci_start.min(wt1.len());
278 let warm_wt2_hist = ci_start.saturating_add(3);
279 let warm_wt2_hist = warm_wt2_hist.min(wt2.len()).min(hist.len());
280 let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
281 for v in &mut wt1[..warm_wt1] {
282 *v = qnan;
283 }
284 for v in &mut wt2[..warm_wt2_hist] {
285 *v = qnan;
286 }
287 for v in &mut hist[..warm_wt2_hist] {
288 *v = qnan;
289 }
290
291 wto_compute_into(
292 data,
293 channel_length,
294 average_length,
295 first,
296 chosen,
297 wt1,
298 wt2,
299 hist,
300 )?;
301
302 Ok(())
303}
304
305#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
306pub fn wto_into(
307 input: &WtoInput,
308 wt1_out: &mut [f64],
309 wt2_out: &mut [f64],
310 hist_out: &mut [f64],
311) -> Result<(), WtoError> {
312 let (data, channel_length, average_length, first, chosen) = wto_prepare(input, Kernel::Auto)?;
313
314 if wt1_out.len() != data.len() || wt2_out.len() != data.len() || hist_out.len() != data.len() {
315 let expected = data.len();
316 let got = wt1_out.len().max(wt2_out.len()).max(hist_out.len());
317 return Err(WtoError::OutputLengthMismatch { expected, got });
318 }
319
320 let ci_start = first + channel_length.saturating_sub(1);
321 let warm_wt1 = ci_start;
322 let warm_wt2_hist = ci_start.saturating_add(3);
323 let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
324 let w = warm_wt1.min(wt1_out.len());
325 for v in &mut wt1_out[..w] {
326 *v = qnan;
327 }
328 let w = warm_wt2_hist.min(wt2_out.len());
329 for v in &mut wt2_out[..w] {
330 *v = qnan;
331 }
332 let w = warm_wt2_hist.min(hist_out.len());
333 for v in &mut hist_out[..w] {
334 *v = qnan;
335 }
336
337 wto_compute_into(
338 data,
339 channel_length,
340 average_length,
341 first,
342 chosen,
343 wt1_out,
344 wt2_out,
345 hist_out,
346 )
347}
348
349#[inline(always)]
350fn wto_prepare<'a>(
351 input: &'a WtoInput,
352 kernel: Kernel,
353) -> Result<(&'a [f64], usize, usize, usize, Kernel), WtoError> {
354 let data: &[f64] = input.as_ref();
355 let len = data.len();
356
357 if len == 0 {
358 return Err(WtoError::EmptyInputData);
359 }
360
361 let first = data
362 .iter()
363 .position(|x| !x.is_nan())
364 .ok_or(WtoError::AllValuesNaN)?;
365 let channel_length = input.get_channel_length();
366 let average_length = input.get_average_length();
367
368 if channel_length == 0 || channel_length > len {
369 return Err(WtoError::InvalidPeriod {
370 period: channel_length,
371 data_len: len,
372 });
373 }
374 if average_length == 0 || average_length > len {
375 return Err(WtoError::InvalidPeriod {
376 period: average_length,
377 data_len: len,
378 });
379 }
380
381 let valid = len - first;
382 let needed = channel_length
383 .saturating_add(3)
384 .max(average_length.saturating_add(3));
385 if valid < needed {
386 return Err(WtoError::NotEnoughValidData { needed, valid });
387 }
388
389 let chosen = match kernel {
390 Kernel::Auto => Kernel::Scalar,
391 k => k,
392 };
393
394 Ok((data, channel_length, average_length, first, chosen))
395}
396
397#[inline(always)]
398fn wto_compute_into(
399 data: &[f64],
400 channel_length: usize,
401 average_length: usize,
402 first: usize,
403 kernel: Kernel,
404 wt1: &mut [f64],
405 wt2: &mut [f64],
406 hist: &mut [f64],
407) -> Result<(), WtoError> {
408 unsafe {
409 #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
410 {
411 if matches!(kernel, Kernel::Scalar | Kernel::ScalarBatch) {
412 return wto_simd128(data, channel_length, average_length, first, wt1, wt2, hist);
413 }
414 }
415
416 match kernel {
417 Kernel::Scalar | Kernel::ScalarBatch => wto_scalar(
418 data,
419 channel_length,
420 average_length,
421 first,
422 kernel,
423 wt1,
424 wt2,
425 hist,
426 ),
427 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
428 Kernel::Avx2 | Kernel::Avx2Batch => {
429 wto_avx2(data, channel_length, average_length, first, wt1, wt2, hist)
430 }
431 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
432 Kernel::Avx512 | Kernel::Avx512Batch => {
433 wto_avx512(data, channel_length, average_length, first, wt1, wt2, hist)
434 }
435 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
436 Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => wto_scalar(
437 data,
438 channel_length,
439 average_length,
440 first,
441 kernel,
442 wt1,
443 wt2,
444 hist,
445 ),
446 _ => unreachable!(),
447 }
448 }
449}
450
451#[inline(always)]
452fn ema_pinescript_into(data: &[f64], period: usize, first_val: usize, out: &mut [f64]) {
453 let len = data.len();
454 if first_val >= len {
455 return;
456 }
457
458 let alpha = 2.0 / (period as f64 + 1.0);
459 let beta = 1.0 - alpha;
460
461 let mut ema = data[first_val];
462 out[first_val] = ema;
463
464 for i in (first_val + 1)..len {
465 if data[i].is_finite() {
466 ema = alpha * data[i] + beta * ema;
467 out[i] = ema;
468 } else {
469 out[i] = ema;
470 }
471 }
472}
473
474#[inline]
475pub fn wto_scalar(
476 data: &[f64],
477 channel_length: usize,
478 average_length: usize,
479 first_val: usize,
480 _kernel: Kernel,
481 wt1: &mut [f64],
482 wt2: &mut [f64],
483 hist: &mut [f64],
484) -> Result<(), WtoError> {
485 #[inline(always)]
486 fn fast_abs(x: f64) -> f64 {
487 f64::from_bits(x.to_bits() & 0x7FFF_FFFF_FFFF_FFFF)
488 }
489
490 let len = data.len();
491 if len == 0 || first_val >= len {
492 return Ok(());
493 }
494
495 let alpha_e = 2.0 / (channel_length as f64 + 1.0);
496 let beta_e = 1.0 - alpha_e;
497 let alpha_t = 2.0 / (average_length as f64 + 1.0);
498 let beta_t = 1.0 - alpha_t;
499
500 let ci_start = first_val + channel_length.saturating_sub(1);
501 if ci_start >= len {
502 return Ok(());
503 }
504
505 let mut esa = data[first_val];
506
507 let mut i = first_val + 1;
508 while i < ci_start {
509 let x = data[i];
510 if x.is_finite() {
511 esa = beta_e.mul_add(esa, alpha_e * x);
512 }
513 i += 1;
514 }
515
516 let mut d = 0.0_f64;
517 let mut tci = 0.0_f64;
518
519 let mut ring = [0.0_f64; 4];
520 let mut rsum = 0.0_f64;
521 let mut rpos = 0usize;
522 let mut rlen = 0usize;
523
524 let k015 = 0.015_f64;
525 let inv4 = 0.25_f64;
526
527 {
528 let x = data[ci_start];
529 if x.is_finite() {
530 esa = beta_e.mul_add(esa, alpha_e * x);
531 }
532 let abs_diff = if x.is_finite() {
533 fast_abs(x - esa)
534 } else {
535 f64::NAN
536 };
537 d = abs_diff;
538
539 let denom = k015 * d;
540 let ci = if denom != 0.0 && denom.is_finite() {
541 if x.is_finite() {
542 (x - esa) / denom
543 } else {
544 f64::NAN
545 }
546 } else {
547 0.0
548 };
549
550 tci = ci;
551 wt1[ci_start] = tci;
552
553 ring[0] = tci;
554 rsum = tci;
555 rlen = 1;
556 rpos = 1;
557 }
558
559 i = ci_start + 1;
560 while i < len {
561 let x = data[i];
562 let x_fin = x.is_finite();
563
564 if x_fin {
565 esa = beta_e.mul_add(esa, alpha_e * x);
566 let ad = fast_abs(x - esa);
567 d = beta_e.mul_add(d, alpha_e * ad);
568 }
569
570 let mut ci = 0.0_f64;
571 if x_fin {
572 let denom = k015 * d;
573 if denom != 0.0 && denom.is_finite() {
574 ci = (x - esa) / denom;
575 }
576 } else {
577 ci = f64::NAN;
578 }
579
580 if ci.is_finite() {
581 tci = beta_t.mul_add(tci, alpha_t * ci);
582 }
583
584 wt1[i] = tci;
585
586 if rlen < 4 {
587 ring[rlen] = tci;
588 rsum += tci;
589 rlen += 1;
590 } else {
591 rsum += tci - ring[rpos];
592 ring[rpos] = tci;
593 rpos = (rpos + 1) & 3;
594 }
595
596 if rlen == 4 {
597 let sig = inv4 * rsum;
598 wt2[i] = sig;
599 hist[i] = tci - sig;
600 }
601
602 i += 1;
603 }
604
605 Ok(())
606}
607
608#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
609#[inline]
610unsafe fn wto_simd128(
611 data: &[f64],
612 channel_length: usize,
613 average_length: usize,
614 first_val: usize,
615 wt1: &mut [f64],
616 wt2: &mut [f64],
617 hist: &mut [f64],
618) -> Result<(), WtoError> {
619 wto_scalar(
620 data,
621 channel_length,
622 average_length,
623 first_val,
624 Kernel::Scalar,
625 wt1,
626 wt2,
627 hist,
628 )
629}
630
631#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
632#[inline]
633unsafe fn wto_avx2(
634 data: &[f64],
635 channel_length: usize,
636 average_length: usize,
637 first_val: usize,
638 wt1: &mut [f64],
639 wt2: &mut [f64],
640 hist: &mut [f64],
641) -> Result<(), WtoError> {
642 #[inline(always)]
643 fn fast_abs(x: f64) -> f64 {
644 f64::from_bits(x.to_bits() & 0x7FFF_FFFF_FFFF_FFFF)
645 }
646
647 let len = data.len();
648 if len == 0 {
649 return Ok(());
650 }
651
652 let alpha_e = 2.0 / (channel_length as f64 + 1.0);
653 let beta_e = 1.0 - alpha_e;
654 let alpha_t = 2.0 / (average_length as f64 + 1.0);
655 let beta_t = 1.0 - alpha_t;
656
657 let ci_start = first_val + channel_length.saturating_sub(1);
658
659 let mut ring = [0.0_f64; 4];
660 let mut rsum = 0.0_f64;
661 let mut rpos = 0usize;
662 let mut rlen = 0usize;
663
664 let x_ptr = data.as_ptr();
665 let wt1_ptr = wt1.as_mut_ptr();
666 let wt2_ptr = wt2.as_mut_ptr();
667 let hist_ptr = hist.as_mut_ptr();
668
669 let mut i = 0usize;
670 while i < first_val {
671 *wt1_ptr.add(i) = f64::NAN;
672 *wt2_ptr.add(i) = f64::NAN;
673 *hist_ptr.add(i) = f64::NAN;
674 i += 1;
675 }
676
677 let mut esa = *x_ptr.add(first_val);
678
679 let mut d = f64::NAN;
680 let mut d_inited = false;
681 let mut tci = f64::NAN;
682 let mut tci_inited = false;
683
684 i = first_val;
685 while i < len {
686 let x = *x_ptr.add(i);
687 let x_fin = x.is_finite();
688
689 if i != first_val && x_fin {
690 esa = beta_e.mul_add(esa, alpha_e * x);
691 }
692
693 if i >= ci_start {
694 let abs_diff = if x_fin { fast_abs(x - esa) } else { f64::NAN };
695 if !d_inited {
696 if i == ci_start {
697 d = abs_diff;
698 d_inited = true;
699 }
700 } else if abs_diff.is_finite() {
701 d = beta_e.mul_add(d, alpha_e * abs_diff);
702 }
703
704 let denom = 0.015_f64 * d;
705 let mut ci = 0.0_f64;
706 if denom != 0.0 && denom.is_finite() {
707 ci = if x_fin { (x - esa) / denom } else { f64::NAN };
708 }
709
710 if !tci_inited {
711 if i == ci_start {
712 tci = ci;
713 tci_inited = true;
714 }
715 } else if ci.is_finite() {
716 tci = beta_t.mul_add(tci, alpha_t * ci);
717 }
718
719 if tci_inited {
720 *wt1_ptr.add(i) = tci;
721
722 if rlen < 4 {
723 ring[rlen] = tci;
724 rsum += tci;
725 rlen += 1;
726 } else {
727 rsum += tci - ring[rpos];
728 ring[rpos] = tci;
729 rpos = (rpos + 1) & 3;
730 }
731
732 if rlen == 4 {
733 let sig = 0.25_f64 * rsum;
734 *wt2_ptr.add(i) = sig;
735 *hist_ptr.add(i) = tci - sig;
736 } else {
737 *wt2_ptr.add(i) = f64::NAN;
738 *hist_ptr.add(i) = f64::NAN;
739 }
740 } else {
741 *wt1_ptr.add(i) = f64::NAN;
742 *wt2_ptr.add(i) = f64::NAN;
743 *hist_ptr.add(i) = f64::NAN;
744 }
745 } else {
746 *wt1_ptr.add(i) = f64::NAN;
747 *wt2_ptr.add(i) = f64::NAN;
748 *hist_ptr.add(i) = f64::NAN;
749 }
750
751 i += 1;
752 }
753
754 Ok(())
755}
756
757#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
758#[inline]
759unsafe fn wto_avx512(
760 data: &[f64],
761 channel_length: usize,
762 average_length: usize,
763 first_val: usize,
764 wt1: &mut [f64],
765 wt2: &mut [f64],
766 hist: &mut [f64],
767) -> Result<(), WtoError> {
768 #[inline(always)]
769 fn fast_abs(x: f64) -> f64 {
770 f64::from_bits(x.to_bits() & 0x7FFF_FFFF_FFFF_FFFF)
771 }
772
773 let len = data.len();
774 if len == 0 || first_val >= len {
775 return Ok(());
776 }
777
778 let alpha_e = 2.0 / (channel_length as f64 + 1.0);
779 let beta_e = 1.0 - alpha_e;
780 let alpha_t = 2.0 / (average_length as f64 + 1.0);
781 let beta_t = 1.0 - alpha_t;
782
783 let ci_start = first_val + channel_length.saturating_sub(1);
784 if ci_start >= len {
785 return Ok(());
786 }
787
788 let x_ptr = data.as_ptr();
789 let wt1_ptr = wt1.as_mut_ptr();
790 let wt2_ptr = wt2.as_mut_ptr();
791 let hs_ptr = hist.as_mut_ptr();
792
793 let mut esa = *x_ptr.add(first_val);
794 let mut i = first_val + 1;
795 while i < ci_start {
796 let x = *x_ptr.add(i);
797 if x.is_finite() {
798 esa = beta_e.mul_add(esa, alpha_e * x);
799 }
800 i += 1;
801 }
802
803 let mut d = 0.0_f64;
804 let mut tci = 0.0_f64;
805
806 let mut ring = [0.0_f64; 4];
807 let mut rsum = 0.0_f64;
808 let mut rpos = 0usize;
809 let mut rlen = 0usize;
810
811 let k015 = 0.015_f64;
812 let inv4 = 0.25_f64;
813
814 {
815 let x = *x_ptr.add(ci_start);
816 if x.is_finite() {
817 esa = beta_e.mul_add(esa, alpha_e * x);
818 }
819 let abs_diff = if x.is_finite() {
820 fast_abs(x - esa)
821 } else {
822 f64::NAN
823 };
824 d = abs_diff;
825
826 let denom = k015 * d;
827 let ci = if denom != 0.0 && denom.is_finite() {
828 if x.is_finite() {
829 (x - esa) / denom
830 } else {
831 f64::NAN
832 }
833 } else {
834 0.0
835 };
836 tci = ci;
837 *wt1_ptr.add(ci_start) = tci;
838
839 ring[0] = tci;
840 rsum = tci;
841 rlen = 1;
842 rpos = 1;
843 }
844
845 i = ci_start + 1;
846 while i < len {
847 let x = *x_ptr.add(i);
848 let x_fin = x.is_finite();
849
850 if x_fin {
851 esa = beta_e.mul_add(esa, alpha_e * x);
852 let ad = fast_abs(x - esa);
853 d = beta_e.mul_add(d, alpha_e * ad);
854 }
855
856 let mut ci = 0.0_f64;
857 if x_fin {
858 let denom = k015 * d;
859 if denom != 0.0 && denom.is_finite() {
860 ci = (x - esa) / denom;
861 }
862 } else {
863 ci = f64::NAN;
864 }
865
866 if ci.is_finite() {
867 tci = beta_t.mul_add(tci, alpha_t * ci);
868 }
869
870 *wt1_ptr.add(i) = tci;
871
872 if rlen < 4 {
873 ring[rlen] = tci;
874 rsum += tci;
875 rlen += 1;
876 } else {
877 rsum += tci - ring[rpos];
878 ring[rpos] = tci;
879 rpos = (rpos + 1) & 3;
880 }
881
882 if rlen == 4 {
883 let sig = inv4 * rsum;
884 *wt2_ptr.add(i) = sig;
885 *hs_ptr.add(i) = tci - sig;
886 }
887
888 i += 1;
889 }
890
891 Ok(())
892}
893
894#[cfg(feature = "python")]
895#[pyfunction(name = "wto")]
896#[pyo3(signature = (close, channel_length, average_length, kernel=None))]
897pub fn wto_py<'py>(
898 py: Python<'py>,
899 close: numpy::PyReadonlyArray1<'py, f64>,
900 channel_length: usize,
901 average_length: usize,
902 kernel: Option<&str>,
903) -> PyResult<(
904 Bound<'py, numpy::PyArray1<f64>>,
905 Bound<'py, numpy::PyArray1<f64>>,
906 Bound<'py, numpy::PyArray1<f64>>,
907)> {
908 use numpy::{IntoPyArray, PyArrayMethods};
909
910 let slice = close.as_slice()?;
911 let kern = validate_kernel(kernel, false)?;
912 let p = WtoParams {
913 channel_length: Some(channel_length),
914 average_length: Some(average_length),
915 };
916 let inp = WtoInput::from_slice(slice, p);
917 let out = py
918 .allow_threads(|| wto_with_kernel(&inp, kern))
919 .map_err(|e| PyValueError::new_err(e.to_string()))?;
920 Ok((
921 out.wavetrend1.into_pyarray(py),
922 out.wavetrend2.into_pyarray(py),
923 out.histogram.into_pyarray(py),
924 ))
925}
926
927#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
928#[wasm_bindgen(js_name = "wto_js")]
929pub fn wto_js(
930 close: &[f64],
931 channel_length: usize,
932 average_length: usize,
933) -> Result<js_sys::Object, JsValue> {
934 let params = WtoParams {
935 channel_length: Some(channel_length),
936 average_length: Some(average_length),
937 };
938 let input = WtoInput::from_slice(close, params);
939
940 let output = wto(&input).map_err(|e| JsValue::from_str(&e.to_string()))?;
941
942 let result = js_sys::Object::new();
943
944 let wt1_array = js_sys::Float64Array::new_with_length(output.wavetrend1.len() as u32);
945 wt1_array.copy_from(&output.wavetrend1);
946 js_sys::Reflect::set(&result, &JsValue::from_str("wavetrend1"), &wt1_array)?;
947
948 let wt2_array = js_sys::Float64Array::new_with_length(output.wavetrend2.len() as u32);
949 wt2_array.copy_from(&output.wavetrend2);
950 js_sys::Reflect::set(&result, &JsValue::from_str("wavetrend2"), &wt2_array)?;
951
952 let hist_array = js_sys::Float64Array::new_with_length(output.histogram.len() as u32);
953 hist_array.copy_from(&output.histogram);
954 js_sys::Reflect::set(&result, &JsValue::from_str("histogram"), &hist_array)?;
955
956 Ok(result)
957}
958
959#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
960#[wasm_bindgen]
961pub fn wto_alloc(len: usize) -> *mut f64 {
962 let mut v = Vec::<f64>::with_capacity(len);
963 let p = v.as_mut_ptr();
964 core::mem::forget(v);
965 p
966}
967
968#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
969#[wasm_bindgen]
970pub fn wto_free(ptr: *mut f64, len: usize) {
971 unsafe {
972 let _ = Vec::from_raw_parts(ptr, len, len);
973 }
974}
975
976#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
977#[wasm_bindgen]
978pub fn wto_into(
979 in_ptr: *const f64,
980 wt1_ptr: *mut f64,
981 wt2_ptr: *mut f64,
982 hist_ptr: *mut f64,
983 len: usize,
984 channel_length: usize,
985 average_length: usize,
986) -> Result<(), JsValue> {
987 if in_ptr.is_null() || wt1_ptr.is_null() || wt2_ptr.is_null() || hist_ptr.is_null() {
988 return Err(JsValue::from_str("null pointer passed to wto_into"));
989 }
990 unsafe {
991 let data = core::slice::from_raw_parts(in_ptr, len);
992 let wt1 = core::slice::from_raw_parts_mut(wt1_ptr, len);
993 let wt2 = core::slice::from_raw_parts_mut(wt2_ptr, len);
994 let hist = core::slice::from_raw_parts_mut(hist_ptr, len);
995
996 let p = WtoParams {
997 channel_length: Some(channel_length),
998 average_length: Some(average_length),
999 };
1000 let inp = WtoInput::from_slice(data, p);
1001
1002 wto_into_slices(wt1, wt2, hist, &inp, Kernel::Auto)
1003 .map_err(|e| JsValue::from_str(&e.to_string()))
1004 }
1005}
1006
1007#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1008#[derive(Serialize, Deserialize)]
1009pub struct WtoResult {
1010 pub values: Vec<f64>,
1011 pub rows: usize,
1012 pub cols: usize,
1013}
1014
1015#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1016#[wasm_bindgen(js_name = "wto_unified")]
1017pub fn wto_unified_js(
1018 close: &[f64],
1019 channel_length: usize,
1020 average_length: usize,
1021) -> Result<JsValue, JsValue> {
1022 let params = WtoParams {
1023 channel_length: Some(channel_length),
1024 average_length: Some(average_length),
1025 };
1026 let input = WtoInput::from_slice(close, params);
1027 let out = wto(&input).map_err(|e| JsValue::from_str(&e.to_string()))?;
1028
1029 let cols = close.len();
1030 let cap = 3usize
1031 .checked_mul(cols)
1032 .ok_or_else(|| JsValue::from_str("overflow in wto_unified_js allocation"))?;
1033 let mut values = Vec::with_capacity(cap);
1034 values.extend_from_slice(&out.wavetrend1);
1035 values.extend_from_slice(&out.wavetrend2);
1036 values.extend_from_slice(&out.histogram);
1037
1038 let res = WtoResult {
1039 values,
1040 rows: 3,
1041 cols,
1042 };
1043 serde_wasm_bindgen::to_value(&res)
1044 .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
1045}
1046
1047#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1048#[derive(Serialize, Deserialize)]
1049pub struct WtoBatchConfig {
1050 pub channel: (usize, usize, usize),
1051 pub average: (usize, usize, usize),
1052}
1053
1054#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1055#[derive(Serialize, Deserialize)]
1056pub struct WtoBatchJsOutput {
1057 pub wavetrend1: Vec<f64>,
1058 pub wavetrend2: Vec<f64>,
1059 pub histogram: Vec<f64>,
1060 pub combos: Vec<WtoParams>,
1061 pub rows: usize,
1062 pub cols: usize,
1063}
1064
1065#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1066#[wasm_bindgen]
1067pub fn wto_batch_into(
1068 in_ptr: *const f64,
1069 out_ptr: *mut f64,
1070 len: usize,
1071 ch_start: usize,
1072 ch_end: usize,
1073 ch_step: usize,
1074 av_start: usize,
1075 av_end: usize,
1076 av_step: usize,
1077) -> Result<usize, JsValue> {
1078 if in_ptr.is_null() || out_ptr.is_null() {
1079 return Err(JsValue::from_str("null pointer passed to wto_batch_into"));
1080 }
1081 if len == 0 {
1082 return Err(JsValue::from_str(&WtoError::EmptyInputData.to_string()));
1083 }
1084 if in_ptr == out_ptr {
1085 return Err(JsValue::from_str(
1086 "wto_batch_into: in_ptr and out_ptr must not alias",
1087 ));
1088 }
1089 unsafe {
1090 let data = core::slice::from_raw_parts(in_ptr, len);
1091 let first = data
1092 .iter()
1093 .position(|x| !x.is_nan())
1094 .ok_or_else(|| JsValue::from_str(&WtoError::AllValuesNaN.to_string()))?;
1095 let sweep = WtoBatchRange {
1096 channel: (ch_start, ch_end, ch_step),
1097 average: (av_start, av_end, av_step),
1098 };
1099 let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
1100 let rows = combos.len();
1101 let cols = len;
1102 let total = rows
1103 .checked_mul(cols)
1104 .ok_or_else(|| JsValue::from_str("rows * cols overflow in wto_batch_into"))?;
1105
1106 let out_mu = core::slice::from_raw_parts_mut(out_ptr as *mut MaybeUninit<f64>, total);
1107 let mut warms: Vec<usize> = Vec::with_capacity(rows);
1108 for p in combos.iter() {
1109 let channel_length = p.channel_length.unwrap_or(10);
1110 let average_length = p.average_length.unwrap_or(21);
1111
1112 if channel_length == 0 || channel_length > cols {
1113 return Err(JsValue::from_str(
1114 &WtoError::InvalidPeriod {
1115 period: channel_length,
1116 data_len: cols,
1117 }
1118 .to_string(),
1119 ));
1120 }
1121 if average_length == 0 || average_length > cols {
1122 return Err(JsValue::from_str(
1123 &WtoError::InvalidPeriod {
1124 period: average_length,
1125 data_len: cols,
1126 }
1127 .to_string(),
1128 ));
1129 }
1130
1131 let valid = cols.saturating_sub(first);
1132 if valid < channel_length {
1133 return Err(JsValue::from_str(
1134 &WtoError::NotEnoughValidData {
1135 needed: channel_length,
1136 valid,
1137 }
1138 .to_string(),
1139 ));
1140 }
1141
1142 let ci_start = first + channel_length - 1;
1143 warms.push(ci_start);
1144 }
1145 init_matrix_prefixes(out_mu, cols, &warms);
1146
1147 let out = core::slice::from_raw_parts_mut(out_ptr, total);
1148 wto_fill_wt1_grouped(data, &combos, first, detect_best_kernel(), false, out)
1149 .map_err(|e| JsValue::from_str(&e.to_string()))?;
1150
1151 Ok(rows)
1152 }
1153}
1154
1155#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1156#[wasm_bindgen(js_name = "wto_batch")]
1157pub fn wto_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
1158 let cfg: WtoBatchConfig = serde_wasm_bindgen::from_value(config)
1159 .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;
1160 let sweep = WtoBatchRange {
1161 channel: cfg.channel,
1162 average: cfg.average,
1163 };
1164 let out = wto_batch_all_outputs_with_kernel(data, &sweep, Kernel::ScalarBatch)
1165 .map_err(|e| JsValue::from_str(&e.to_string()))?;
1166 let js = WtoBatchJsOutput {
1167 wavetrend1: out.wt1,
1168 wavetrend2: out.wt2,
1169 histogram: out.hist,
1170 combos: out.combos,
1171 rows: out.rows,
1172 cols: out.cols,
1173 };
1174 serde_wasm_bindgen::to_value(&js)
1175 .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
1176}
1177
1178#[derive(Debug, Clone)]
1179pub struct WtoBatchRange {
1180 pub channel: (usize, usize, usize),
1181 pub average: (usize, usize, usize),
1182}
1183
1184impl Default for WtoBatchRange {
1185 fn default() -> Self {
1186 Self {
1187 channel: (10, 10, 0),
1188 average: (21, 270, 1),
1189 }
1190 }
1191}
1192
1193#[derive(Debug, Clone)]
1194pub struct WtoBatchOutput {
1195 pub values: Vec<f64>,
1196 pub combos: Vec<WtoParams>,
1197 pub rows: usize,
1198 pub cols: usize,
1199}
1200
1201impl WtoBatchOutput {
1202 pub fn values_for(&self, params: &WtoParams) -> Option<&[f64]> {
1203 self.row_for_params(params)
1204 .map(|row| &self.values[row * self.cols..(row + 1) * self.cols])
1205 }
1206
1207 pub fn row_for_params(&self, params: &WtoParams) -> Option<usize> {
1208 self.combos.iter().position(|p| {
1209 p.channel_length.unwrap_or(10) == params.channel_length.unwrap_or(10)
1210 && p.average_length.unwrap_or(21) == params.average_length.unwrap_or(21)
1211 })
1212 }
1213}
1214
1215#[derive(Debug, Clone)]
1216pub struct WtoBatchBuilder {
1217 channel_range: (usize, usize, usize),
1218 average_range: (usize, usize, usize),
1219 kernel: Kernel,
1220}
1221
1222impl Default for WtoBatchBuilder {
1223 fn default() -> Self {
1224 Self {
1225 channel_range: (10, 10, 0),
1226 average_range: (21, 270, 1),
1227 kernel: Kernel::Auto,
1228 }
1229 }
1230}
1231
1232impl WtoBatchBuilder {
1233 pub fn new() -> Self {
1234 Self::default()
1235 }
1236
1237 pub fn channel_range(mut self, start: usize, end: usize, step: usize) -> Self {
1238 self.channel_range = (start, end, step);
1239 self
1240 }
1241
1242 pub fn average_range(mut self, start: usize, end: usize, step: usize) -> Self {
1243 self.average_range = (start, end, step);
1244 self
1245 }
1246
1247 pub fn kernel(mut self, k: Kernel) -> Self {
1248 self.kernel = k;
1249 self
1250 }
1251
1252 pub fn apply_candles(
1253 self,
1254 candles: &Candles,
1255 source: &str,
1256 ) -> Result<WtoBatchOutput, WtoError> {
1257 wto_batch_candles(
1258 candles,
1259 source,
1260 self.channel_range,
1261 self.average_range,
1262 self.kernel,
1263 )
1264 }
1265
1266 pub fn apply_slice(self, data: &[f64]) -> Result<WtoBatchOutput, WtoError> {
1267 let sweep = WtoBatchRange {
1268 channel: self.channel_range,
1269 average: self.average_range,
1270 };
1271 wto_batch_with_kernel(data, &sweep, self.kernel)
1272 }
1273
1274 pub fn channel_static(mut self, p: usize) -> Self {
1275 self.channel_range = (p, p, 0);
1276 self
1277 }
1278
1279 pub fn average_static(mut self, p: usize) -> Self {
1280 self.average_range = (p, p, 0);
1281 self
1282 }
1283
1284 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<WtoBatchOutput, WtoError> {
1285 WtoBatchBuilder::new().kernel(k).apply_slice(data)
1286 }
1287
1288 pub fn with_default_candles(c: &Candles) -> Result<WtoBatchOutput, WtoError> {
1289 WtoBatchBuilder::new()
1290 .kernel(Kernel::Auto)
1291 .apply_candles(c, "close")
1292 }
1293}
1294
1295#[derive(Clone, Copy, Debug)]
1296struct WtoBatchMember {
1297 row: usize,
1298 average_length: usize,
1299}
1300
1301#[derive(Clone, Copy)]
1302struct ThreadSafePtr(*mut f64);
1303
1304unsafe impl Send for ThreadSafePtr {}
1305unsafe impl Sync for ThreadSafePtr {}
1306
1307#[inline]
1308fn group_rows_by_channel(
1309 combos: &[WtoParams],
1310 cols: usize,
1311) -> Result<Vec<(usize, Vec<WtoBatchMember>)>, WtoError> {
1312 let mut groups: BTreeMap<usize, Vec<WtoBatchMember>> = BTreeMap::new();
1313
1314 for (row, params) in combos.iter().enumerate() {
1315 let channel_length = params.channel_length.unwrap_or(10);
1316 let average_length = params.average_length.unwrap_or(21);
1317
1318 if channel_length == 0 || channel_length > cols {
1319 return Err(WtoError::InvalidPeriod {
1320 period: channel_length,
1321 data_len: cols,
1322 });
1323 }
1324 if average_length == 0 || average_length > cols {
1325 return Err(WtoError::InvalidPeriod {
1326 period: average_length,
1327 data_len: cols,
1328 });
1329 }
1330
1331 groups
1332 .entry(channel_length)
1333 .or_default()
1334 .push(WtoBatchMember {
1335 row,
1336 average_length,
1337 });
1338 }
1339
1340 Ok(groups.into_iter().collect())
1341}
1342
1343#[inline]
1344fn apply_ci_to_members(
1345 ci: &[f64],
1346 members: &[WtoBatchMember],
1347 start_ci: usize,
1348 cols: usize,
1349 out_ptr: ThreadSafePtr,
1350 parallel: bool,
1351) -> Result<(), WtoError> {
1352 let _ = parallel;
1353 for member in members {
1354 let offset = member.row.checked_mul(cols).ok_or_else(|| {
1355 WtoError::InvalidInput("row * cols overflow in apply_ci_to_members".into())
1356 })?;
1357 let dst = unsafe { core::slice::from_raw_parts_mut(out_ptr.0.add(offset), cols) };
1358 ema_pinescript_into(ci, member.average_length, start_ci, dst);
1359 }
1360 Ok(())
1361}
1362
1363fn wto_fill_wt1_grouped(
1364 data: &[f64],
1365 combos: &[WtoParams],
1366 first: usize,
1367 kernel: Kernel,
1368 parallel: bool,
1369 out: &mut [f64],
1370) -> Result<(), WtoError> {
1371 let cols = data.len();
1372 let groups = group_rows_by_channel(combos, cols)?;
1373 if groups.is_empty() {
1374 return Ok(());
1375 }
1376
1377 let kernel = match kernel {
1378 Kernel::Auto => Kernel::Scalar,
1379 other => other.to_non_batch(),
1380 };
1381
1382 match kernel {
1383 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1384 Kernel::Avx512 => unsafe {
1385 wto_batch_fill_wt1_grouped_avx512(data, &groups, first, out, parallel)
1386 },
1387 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1388 Kernel::Avx2 => unsafe {
1389 wto_batch_fill_wt1_grouped_avx2(data, &groups, first, out, parallel)
1390 },
1391 Kernel::Scalar => wto_batch_fill_wt1_grouped_scalar(data, &groups, first, out, parallel),
1392 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
1393 Kernel::Avx2 | Kernel::Avx512 => {
1394 wto_batch_fill_wt1_grouped_scalar(data, &groups, first, out, parallel)
1395 }
1396 #[allow(unreachable_patterns)]
1397 _ => wto_batch_fill_wt1_grouped_scalar(data, &groups, first, out, parallel),
1398 }
1399}
1400
1401fn wto_batch_fill_wt1_grouped_scalar(
1402 data: &[f64],
1403 groups: &[(usize, Vec<WtoBatchMember>)],
1404 first: usize,
1405 out: &mut [f64],
1406 parallel: bool,
1407) -> Result<(), WtoError> {
1408 let cols = data.len();
1409 let out_ptr = ThreadSafePtr(out.as_mut_ptr());
1410
1411 for (channel_length, members) in groups.iter() {
1412 let channel_length = *channel_length;
1413 let start_ci = first + channel_length - 1;
1414 if start_ci >= cols {
1415 return Err(WtoError::NotEnoughValidData {
1416 needed: start_ci + 1,
1417 valid: cols.saturating_sub(first),
1418 });
1419 }
1420
1421 let mut scratch = make_uninit_matrix(3, cols);
1422 let warms = [start_ci, start_ci, start_ci];
1423 init_matrix_prefixes(&mut scratch, cols, &warms);
1424
1425 let mut guard = core::mem::ManuallyDrop::new(scratch);
1426 let flat: &mut [f64] =
1427 unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len()) };
1428 let (esa, rest) = flat.split_at_mut(cols);
1429 let (d, ci) = rest.split_at_mut(cols);
1430
1431 ema_pinescript_into(data, channel_length, first, esa);
1432
1433 for i in 0..cols {
1434 ci[i] = (data[i] - esa[i]).abs();
1435 }
1436
1437 ema_pinescript_into(ci, channel_length, start_ci, d);
1438
1439 for i in start_ci..cols {
1440 let denom = 0.015 * d[i];
1441 ci[i] = if denom != 0.0 && denom.is_finite() {
1442 (data[i] - esa[i]) / denom
1443 } else {
1444 0.0
1445 };
1446 }
1447
1448 let ci_slice: &[f64] = ci;
1449 apply_ci_to_members(
1450 ci_slice,
1451 members.as_slice(),
1452 start_ci,
1453 cols,
1454 out_ptr,
1455 parallel,
1456 )?;
1457
1458 unsafe {
1459 Vec::from_raw_parts(
1460 guard.as_mut_ptr() as *mut f64,
1461 guard.len(),
1462 guard.capacity(),
1463 );
1464 }
1465 core::mem::forget(guard);
1466 }
1467
1468 Ok(())
1469}
1470
1471#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1472#[target_feature(enable = "avx2")]
1473unsafe fn wto_batch_fill_wt1_grouped_avx2(
1474 data: &[f64],
1475 groups: &[(usize, Vec<WtoBatchMember>)],
1476 first: usize,
1477 out: &mut [f64],
1478 parallel: bool,
1479) -> Result<(), WtoError> {
1480 use core::arch::x86_64::*;
1481
1482 let cols = data.len();
1483 let out_ptr = ThreadSafePtr(out.as_mut_ptr());
1484
1485 for (channel_length, members) in groups.iter() {
1486 let channel_length = *channel_length;
1487 let start_ci = first + channel_length - 1;
1488 if start_ci >= cols {
1489 return Err(WtoError::NotEnoughValidData {
1490 needed: start_ci + 1,
1491 valid: cols.saturating_sub(first),
1492 });
1493 }
1494
1495 let mut scratch = make_uninit_matrix(3, cols);
1496 let warms = [start_ci, start_ci, start_ci];
1497 init_matrix_prefixes(&mut scratch, cols, &warms);
1498
1499 let mut guard = core::mem::ManuallyDrop::new(scratch);
1500 let flat: &mut [f64] =
1501 core::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len());
1502 let (esa, rest) = flat.split_at_mut(cols);
1503 let (d, ci) = rest.split_at_mut(cols);
1504
1505 ema_pinescript_into(data, channel_length, first, esa);
1506
1507 let signmask = _mm256_set1_pd(-0.0_f64);
1508 let mut i = 0usize;
1509 while i + 4 <= cols {
1510 let x = _mm256_loadu_pd(data.as_ptr().add(i));
1511 let e = _mm256_loadu_pd(esa.as_ptr().add(i));
1512 let diff = _mm256_sub_pd(x, e);
1513 let absd = _mm256_andnot_pd(signmask, diff);
1514 _mm256_storeu_pd(ci.as_mut_ptr().add(i), absd);
1515 i += 4;
1516 }
1517 while i < cols {
1518 ci[i] = (data[i] - esa[i]).abs();
1519 i += 1;
1520 }
1521
1522 ema_pinescript_into(ci, channel_length, start_ci, d);
1523
1524 let k015 = _mm256_set1_pd(0.015_f64);
1525 let zero = _mm256_set1_pd(0.0_f64);
1526 let infv = _mm256_set1_pd(f64::INFINITY);
1527 let mut j = start_ci;
1528 while j + 4 <= cols {
1529 let x = _mm256_loadu_pd(data.as_ptr().add(j));
1530 let e = _mm256_loadu_pd(esa.as_ptr().add(j));
1531 let num = _mm256_sub_pd(x, e);
1532
1533 let dv = _mm256_loadu_pd(d.as_ptr().add(j));
1534 let den = _mm256_mul_pd(k015, dv);
1535
1536 let neq0 = _mm256_cmp_pd(den, zero, _CMP_NEQ_OQ);
1537 let abs_den = _mm256_andnot_pd(signmask, den);
1538 let not_inf = _mm256_cmp_pd(abs_den, infv, _CMP_NEQ_OQ);
1539 let ord = _mm256_cmp_pd(den, den, _CMP_ORD_Q);
1540 let valid = _mm256_and_pd(_mm256_and_pd(neq0, not_inf), ord);
1541
1542 let q = _mm256_div_pd(num, den);
1543 let outv = _mm256_blendv_pd(zero, q, valid);
1544 _mm256_storeu_pd(ci.as_mut_ptr().add(j), outv);
1545 j += 4;
1546 }
1547 while j < cols {
1548 let denom = 0.015 * d[j];
1549 ci[j] = if denom != 0.0 && denom.is_finite() {
1550 (data[j] - esa[j]) / denom
1551 } else {
1552 0.0
1553 };
1554 j += 1;
1555 }
1556
1557 let ci_slice: &[f64] = ci;
1558 apply_ci_to_members(
1559 ci_slice,
1560 members.as_slice(),
1561 start_ci,
1562 cols,
1563 out_ptr,
1564 parallel,
1565 )?;
1566
1567 unsafe {
1568 Vec::from_raw_parts(
1569 guard.as_mut_ptr() as *mut f64,
1570 guard.len(),
1571 guard.capacity(),
1572 );
1573 }
1574 core::mem::forget(guard);
1575 }
1576
1577 Ok(())
1578}
1579
1580#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1581#[target_feature(enable = "avx512f")]
1582unsafe fn wto_batch_fill_wt1_grouped_avx512(
1583 data: &[f64],
1584 groups: &[(usize, Vec<WtoBatchMember>)],
1585 first: usize,
1586 out: &mut [f64],
1587 parallel: bool,
1588) -> Result<(), WtoError> {
1589 use core::arch::x86_64::*;
1590
1591 let cols = data.len();
1592 let out_ptr = ThreadSafePtr(out.as_mut_ptr());
1593
1594 for (channel_length, members) in groups.iter() {
1595 let channel_length = *channel_length;
1596 let start_ci = first + channel_length - 1;
1597 if start_ci >= cols {
1598 return Err(WtoError::NotEnoughValidData {
1599 needed: start_ci + 1,
1600 valid: cols.saturating_sub(first),
1601 });
1602 }
1603
1604 let mut scratch = make_uninit_matrix(3, cols);
1605 let warms = [start_ci, start_ci, start_ci];
1606 init_matrix_prefixes(&mut scratch, cols, &warms);
1607
1608 let mut guard = core::mem::ManuallyDrop::new(scratch);
1609 let flat: &mut [f64] =
1610 core::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len());
1611 let (esa, rest) = flat.split_at_mut(cols);
1612 let (d, ci) = rest.split_at_mut(cols);
1613
1614 ema_pinescript_into(data, channel_length, first, esa);
1615
1616 let signmask = _mm512_set1_pd(-0.0_f64);
1617 let mut i = 0usize;
1618 while i + 8 <= cols {
1619 let x = _mm512_loadu_pd(data.as_ptr().add(i));
1620 let e = _mm512_loadu_pd(esa.as_ptr().add(i));
1621 let diff = _mm512_sub_pd(x, e);
1622 let absd = _mm512_andnot_pd(signmask, diff);
1623 _mm512_storeu_pd(ci.as_mut_ptr().add(i), absd);
1624 i += 8;
1625 }
1626 while i < cols {
1627 ci[i] = (data[i] - esa[i]).abs();
1628 i += 1;
1629 }
1630
1631 ema_pinescript_into(ci, channel_length, start_ci, d);
1632
1633 let k015 = _mm512_set1_pd(0.015_f64);
1634 let zero = _mm512_set1_pd(0.0_f64);
1635 let infv = _mm512_set1_pd(f64::INFINITY);
1636 let mut j = start_ci;
1637 while j + 8 <= cols {
1638 let x = _mm512_loadu_pd(data.as_ptr().add(j));
1639 let e = _mm512_loadu_pd(esa.as_ptr().add(j));
1640 let num = _mm512_sub_pd(x, e);
1641
1642 let dv = _mm512_loadu_pd(d.as_ptr().add(j));
1643 let den = _mm512_mul_pd(k015, dv);
1644
1645 let neq0 = _mm512_cmp_pd_mask(den, zero, _CMP_NEQ_OQ);
1646 let abs_den = _mm512_andnot_pd(signmask, den);
1647 let not_inf = _mm512_cmp_pd_mask(abs_den, infv, _CMP_NEQ_OQ);
1648 let ord = _mm512_cmp_pd_mask(den, den, _CMP_ORD_Q);
1649 let valid = neq0 & not_inf & ord;
1650
1651 let q = _mm512_div_pd(num, den);
1652 let outv = _mm512_mask_blend_pd(valid, zero, q);
1653 _mm512_storeu_pd(ci.as_mut_ptr().add(j), outv);
1654 j += 8;
1655 }
1656 while j < cols {
1657 let denom = 0.015 * d[j];
1658 ci[j] = if denom != 0.0 && denom.is_finite() {
1659 (data[j] - esa[j]) / denom
1660 } else {
1661 0.0
1662 };
1663 j += 1;
1664 }
1665
1666 let ci_slice: &[f64] = ci;
1667 apply_ci_to_members(
1668 ci_slice,
1669 members.as_slice(),
1670 start_ci,
1671 cols,
1672 out_ptr,
1673 parallel,
1674 )?;
1675
1676 unsafe {
1677 Vec::from_raw_parts(
1678 guard.as_mut_ptr() as *mut f64,
1679 guard.len(),
1680 guard.capacity(),
1681 );
1682 }
1683 core::mem::forget(guard);
1684 }
1685
1686 Ok(())
1687}
1688
1689fn wto_fill_wt1_row(
1690 data: &[f64],
1691 p: WtoParams,
1692 first: usize,
1693 kern: Kernel,
1694 dst: &mut [f64],
1695) -> Result<(), WtoError> {
1696 let cols = data.len();
1697 let channel_length = p.channel_length.unwrap_or(10);
1698 let average_length = p.average_length.unwrap_or(21);
1699
1700 let mut mu = make_uninit_matrix(2, cols);
1701 let warms = [first + channel_length - 1, first + channel_length - 1];
1702 init_matrix_prefixes(&mut mu, cols, &warms);
1703
1704 let mut guard = core::mem::ManuallyDrop::new(mu);
1705 let flat: &mut [f64] =
1706 unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len()) };
1707 let (d, ci) = flat.split_at_mut(cols);
1708
1709 ema_pinescript_into(data, channel_length, first, dst);
1710
1711 for i in 0..cols {
1712 ci[i] = (data[i] - dst[i]).abs();
1713 }
1714
1715 let d_first = first + channel_length - 1;
1716 ema_pinescript_into(ci, channel_length, d_first, d);
1717
1718 let start = first + channel_length - 1;
1719 for i in start..cols {
1720 let denom = 0.015 * d[i];
1721 ci[i] = if denom.is_finite() && denom != 0.0 {
1722 (data[i] - dst[i]) / denom
1723 } else {
1724 0.0
1725 };
1726 }
1727
1728 let ci_first = start;
1729 ema_pinescript_into(ci, average_length, ci_first, dst);
1730
1731 Ok(())
1732}
1733
1734#[inline(always)]
1735fn expand_grid(r: &WtoBatchRange) -> Result<Vec<WtoParams>, WtoError> {
1736 fn axis_u((s, e, st): (usize, usize, usize)) -> Result<Vec<usize>, WtoError> {
1737 if st == 0 || s == e {
1738 return Ok(vec![s]);
1739 }
1740 let mut out = Vec::new();
1741 if s < e {
1742 let mut v = s;
1743 loop {
1744 if v > e {
1745 break;
1746 }
1747 out.push(v);
1748 let next = v.checked_add(st).ok_or_else(|| WtoError::InvalidRange {
1749 start: s.to_string(),
1750 end: e.to_string(),
1751 step: st.to_string(),
1752 })?;
1753 if next == v {
1754 break;
1755 }
1756 v = next;
1757 }
1758 } else {
1759 let mut v = s;
1760 loop {
1761 if v < e {
1762 break;
1763 }
1764 out.push(v);
1765 if v - e < st {
1766 break;
1767 }
1768 v -= st;
1769 }
1770 }
1771 if out.is_empty() {
1772 return Err(WtoError::InvalidRange {
1773 start: s.to_string(),
1774 end: e.to_string(),
1775 step: st.to_string(),
1776 });
1777 }
1778 Ok(out)
1779 }
1780 let ch = axis_u(r.channel)?;
1781 let av = axis_u(r.average)?;
1782 let mut out = Vec::with_capacity(ch.len() * av.len());
1783 for &c in &ch {
1784 for &a in &av {
1785 out.push(WtoParams {
1786 channel_length: Some(c),
1787 average_length: Some(a),
1788 });
1789 }
1790 }
1791 if out.is_empty() {
1792 return Err(WtoError::InvalidRange {
1793 start: r.channel.0.to_string(),
1794 end: r.channel.1.to_string(),
1795 step: r.channel.2.to_string(),
1796 });
1797 }
1798 Ok(out)
1799}
1800
1801pub fn wto_batch_with_kernel(
1802 data: &[f64],
1803 sweep: &WtoBatchRange,
1804 k: Kernel,
1805) -> Result<WtoBatchOutput, WtoError> {
1806 let kern = match k {
1807 Kernel::Auto => detect_best_batch_kernel(),
1808 x if x.is_batch() => x,
1809 other => {
1810 return Err(WtoError::InvalidKernelForBatch(other));
1811 }
1812 };
1813 let simd = match kern {
1814 Kernel::Avx512Batch => Kernel::Avx512,
1815 Kernel::Avx2Batch => Kernel::Avx2,
1816 Kernel::ScalarBatch => Kernel::Scalar,
1817 _ => unreachable!(),
1818 };
1819 wto_batch_inner(data, sweep, simd, true)
1820}
1821
1822#[inline(always)]
1823fn wto_batch_inner(
1824 data: &[f64],
1825 sweep: &WtoBatchRange,
1826 kern: Kernel,
1827 parallel: bool,
1828) -> Result<WtoBatchOutput, WtoError> {
1829 if data.is_empty() {
1830 return Err(WtoError::EmptyInputData);
1831 }
1832 let combos = expand_grid(sweep)?;
1833
1834 let cols = data.len();
1835 let rows = combos.len();
1836 if rows.checked_mul(cols).is_none() {
1837 return Err(WtoError::InvalidInput(
1838 "rows * cols overflow in wto_batch_inner".into(),
1839 ));
1840 }
1841
1842 let first = data
1843 .iter()
1844 .position(|x| !x.is_nan())
1845 .ok_or(WtoError::AllValuesNaN)?;
1846
1847 let mut mu = make_uninit_matrix(rows, cols);
1848 {
1849 let mut warms: Vec<usize> = Vec::with_capacity(rows);
1850 for p in combos.iter() {
1851 let channel_length = p.channel_length.unwrap_or(10);
1852 let average_length = p.average_length.unwrap_or(21);
1853
1854 if channel_length == 0 || channel_length > cols {
1855 return Err(WtoError::InvalidPeriod {
1856 period: channel_length,
1857 data_len: cols,
1858 });
1859 }
1860 if average_length == 0 || average_length > cols {
1861 return Err(WtoError::InvalidPeriod {
1862 period: average_length,
1863 data_len: cols,
1864 });
1865 }
1866
1867 let valid = cols.saturating_sub(first);
1868 if valid < channel_length {
1869 return Err(WtoError::NotEnoughValidData {
1870 needed: channel_length,
1871 valid,
1872 });
1873 }
1874
1875 let ci_start = first + channel_length - 1;
1876 warms.push(ci_start);
1877 }
1878 init_matrix_prefixes(&mut mu, cols, &warms);
1879 }
1880 let mut guard = core::mem::ManuallyDrop::new(mu);
1881 let out: &mut [f64] =
1882 unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len()) };
1883
1884 wto_fill_wt1_grouped(data, &combos, first, kern, parallel, out)?;
1885
1886 let values = unsafe {
1887 Vec::from_raw_parts(
1888 guard.as_mut_ptr() as *mut f64,
1889 guard.len(),
1890 guard.capacity(),
1891 )
1892 };
1893 core::mem::forget(guard);
1894 Ok(WtoBatchOutput {
1895 values,
1896 combos,
1897 rows,
1898 cols,
1899 })
1900}
1901
1902pub fn wto_batch_slice(
1903 data: &[f64],
1904 channel_range: (usize, usize, usize),
1905 average_range: (usize, usize, usize),
1906 kernel: Kernel,
1907) -> Result<WtoBatchOutput, WtoError> {
1908 let sweep = WtoBatchRange {
1909 channel: channel_range,
1910 average: average_range,
1911 };
1912 wto_batch_with_kernel(data, &sweep, kernel)
1913}
1914
1915pub fn wto_batch_candles(
1916 candles: &Candles,
1917 source: &str,
1918 channel_range: (usize, usize, usize),
1919 average_range: (usize, usize, usize),
1920 kernel: Kernel,
1921) -> Result<WtoBatchOutput, WtoError> {
1922 let data = source_type(candles, source);
1923 wto_batch_slice(data, channel_range, average_range, kernel)
1924}
1925
1926#[derive(Debug, Clone)]
1927pub struct WtoBatchAllOutput {
1928 pub wt1: Vec<f64>,
1929 pub wt2: Vec<f64>,
1930 pub hist: Vec<f64>,
1931 pub combos: Vec<WtoParams>,
1932 pub rows: usize,
1933 pub cols: usize,
1934}
1935
1936pub fn wto_batch_all_outputs_with_kernel(
1937 data: &[f64],
1938 sweep: &WtoBatchRange,
1939 k: Kernel,
1940) -> Result<WtoBatchAllOutput, WtoError> {
1941 if data.is_empty() {
1942 return Err(WtoError::EmptyInputData);
1943 }
1944
1945 let combos = expand_grid(sweep)?;
1946
1947 let cols = data.len();
1948 let rows = combos.len();
1949 if rows.checked_mul(cols).is_none() {
1950 return Err(WtoError::InvalidInput(
1951 "rows * cols overflow in wto_batch_all_outputs_with_kernel".into(),
1952 ));
1953 }
1954
1955 let mut wt1_mu = make_uninit_matrix(rows, cols);
1956 let mut wt2_mu = make_uninit_matrix(rows, cols);
1957 let mut hist_mu = make_uninit_matrix(rows, cols);
1958
1959 let first = data
1960 .iter()
1961 .position(|x| !x.is_nan())
1962 .ok_or(WtoError::AllValuesNaN)?;
1963 let mut warm_wt1: Vec<usize> = Vec::with_capacity(rows);
1964 let mut warm_wt2_hist: Vec<usize> = Vec::with_capacity(rows);
1965 for p in combos.iter() {
1966 let channel_length = p.channel_length.unwrap_or(10);
1967 let average_length = p.average_length.unwrap_or(21);
1968
1969 if channel_length == 0 || channel_length > cols {
1970 return Err(WtoError::InvalidPeriod {
1971 period: channel_length,
1972 data_len: cols,
1973 });
1974 }
1975 if average_length == 0 || average_length > cols {
1976 return Err(WtoError::InvalidPeriod {
1977 period: average_length,
1978 data_len: cols,
1979 });
1980 }
1981
1982 let valid = cols.saturating_sub(first);
1983 let needed = channel_length
1984 .saturating_add(3)
1985 .max(average_length.saturating_add(3));
1986 if valid < needed {
1987 return Err(WtoError::NotEnoughValidData { needed, valid });
1988 }
1989
1990 let ci_start = first + channel_length - 1;
1991 warm_wt1.push(ci_start);
1992 warm_wt2_hist.push(ci_start + 3);
1993 }
1994
1995 init_matrix_prefixes(&mut wt1_mu, cols, &warm_wt1);
1996 init_matrix_prefixes(&mut wt2_mu, cols, &warm_wt2_hist);
1997 init_matrix_prefixes(&mut hist_mu, cols, &warm_wt2_hist);
1998
1999 let mut wt1_guard = core::mem::ManuallyDrop::new(wt1_mu);
2000 let mut wt2_guard = core::mem::ManuallyDrop::new(wt2_mu);
2001 let mut hist_guard = core::mem::ManuallyDrop::new(hist_mu);
2002
2003 let wt1_out: &mut [f64] = unsafe {
2004 core::slice::from_raw_parts_mut(wt1_guard.as_mut_ptr() as *mut f64, wt1_guard.len())
2005 };
2006 let wt2_out: &mut [f64] = unsafe {
2007 core::slice::from_raw_parts_mut(wt2_guard.as_mut_ptr() as *mut f64, wt2_guard.len())
2008 };
2009 let hist_out: &mut [f64] = unsafe {
2010 core::slice::from_raw_parts_mut(hist_guard.as_mut_ptr() as *mut f64, hist_guard.len())
2011 };
2012
2013 let kern = match k {
2014 Kernel::Auto => Kernel::Scalar,
2015 x => x,
2016 };
2017
2018 wto_fill_wt1_grouped(data, &combos, first, kern, true, wt1_out)?;
2019
2020 for row in 0..rows {
2021 let row_start = row.checked_mul(cols).ok_or_else(|| {
2022 WtoError::InvalidInput(
2023 "row * cols overflow in wto_batch_all_outputs_with_kernel".into(),
2024 )
2025 })?;
2026 let row_end = row_start + cols;
2027
2028 let wt1_row = &wt1_out[row_start..row_end];
2029 let wt2_row = &mut wt2_out[row_start..row_end];
2030 let hist_row = &mut hist_out[row_start..row_end];
2031
2032 let sma_input = SmaInput::from_slice(wt1_row, SmaParams { period: Some(4) });
2033 sma_into_slice(wt2_row, &sma_input, kern)
2034 .map_err(|e| WtoError::ComputationError(format!("WT2 SMA error: {}", e)))?;
2035
2036 for i in 0..cols {
2037 if !wt1_row[i].is_nan() && !wt2_row[i].is_nan() {
2038 hist_row[i] = wt1_row[i] - wt2_row[i];
2039 }
2040 }
2041 }
2042
2043 let wt1 = unsafe {
2044 Vec::from_raw_parts(
2045 wt1_guard.as_mut_ptr() as *mut f64,
2046 wt1_guard.len(),
2047 wt1_guard.capacity(),
2048 )
2049 };
2050 let wt2 = unsafe {
2051 Vec::from_raw_parts(
2052 wt2_guard.as_mut_ptr() as *mut f64,
2053 wt2_guard.len(),
2054 wt2_guard.capacity(),
2055 )
2056 };
2057 let hist = unsafe {
2058 Vec::from_raw_parts(
2059 hist_guard.as_mut_ptr() as *mut f64,
2060 hist_guard.len(),
2061 hist_guard.capacity(),
2062 )
2063 };
2064
2065 core::mem::forget(wt1_guard);
2066 core::mem::forget(wt2_guard);
2067 core::mem::forget(hist_guard);
2068
2069 Ok(WtoBatchAllOutput {
2070 wt1,
2071 wt2,
2072 hist,
2073 combos,
2074 rows,
2075 cols,
2076 })
2077}
2078
2079#[derive(Debug, Clone)]
2080pub struct WtoStream {
2081 channel_length: usize,
2082 average_length: usize,
2083
2084 esa_alpha: f64,
2085 esa_beta: f64,
2086 tci_alpha: f64,
2087 tci_beta: f64,
2088 k015: f64,
2089 inv4: f64,
2090
2091 samples: usize,
2092 ci_ready: bool,
2093
2094 esa: f64,
2095 d: f64,
2096 tci: f64,
2097
2098 ring: [f64; 4],
2099 rsum: f64,
2100 rpos: usize,
2101 rlen: usize,
2102}
2103
2104const STRICT_WT2_NANS: bool = true;
2105
2106impl WtoStream {
2107 pub fn try_new(params: WtoParams) -> Result<Self, WtoError> {
2108 let channel_length = params.channel_length.unwrap_or(10);
2109 let average_length = params.average_length.unwrap_or(21);
2110
2111 if channel_length == 0 {
2112 return Err(WtoError::InvalidPeriod {
2113 period: channel_length,
2114 data_len: 0,
2115 });
2116 }
2117 if average_length == 0 {
2118 return Err(WtoError::InvalidPeriod {
2119 period: average_length,
2120 data_len: 0,
2121 });
2122 }
2123
2124 let esa_alpha = 2.0 / (channel_length as f64 + 1.0);
2125 let tci_alpha = 2.0 / (average_length as f64 + 1.0);
2126
2127 Ok(Self {
2128 channel_length,
2129 average_length,
2130 esa_alpha,
2131 esa_beta: 1.0 - esa_alpha,
2132 tci_alpha,
2133 tci_beta: 1.0 - tci_alpha,
2134 k015: 0.015_f64,
2135 inv4: 0.25_f64,
2136
2137 samples: 0,
2138 ci_ready: channel_length == 1,
2139
2140 esa: 0.0,
2141 d: 0.0,
2142 tci: 0.0,
2143
2144 ring: [0.0; 4],
2145 rsum: 0.0,
2146 rpos: 0,
2147 rlen: 0,
2148 })
2149 }
2150
2151 #[inline(always)]
2152 fn fast_abs(x: f64) -> f64 {
2153 f64::from_bits(x.to_bits() & 0x7FFF_FFFF_FFFF_FFFF)
2154 }
2155
2156 pub fn update(&mut self, value: f64) -> Option<(f64, f64, f64)> {
2157 if !value.is_finite() {
2158 return None;
2159 }
2160
2161 if self.samples == 0 {
2162 self.esa = value;
2163 self.samples = 1;
2164
2165 if self.ci_ready {
2166 let ci = 0.0;
2167 self.tci = ci;
2168 self.push_wt2(ci);
2169
2170 let (wt2, hist) = self.emit_wt2(ci);
2171 return Some((ci, wt2, hist));
2172 }
2173 return None;
2174 }
2175
2176 self.esa = self.esa_beta.mul_add(self.esa, self.esa_alpha * value);
2177
2178 if !self.ci_ready {
2179 self.samples += 1;
2180
2181 if self.samples == self.channel_length {
2182 self.d = Self::fast_abs(value - self.esa);
2183
2184 let denom = self.k015 * self.d;
2185 let ci = if denom != 0.0 && denom.is_finite() {
2186 (value - self.esa) / denom
2187 } else {
2188 0.0
2189 };
2190
2191 self.tci = ci;
2192 self.ci_ready = true;
2193
2194 self.push_wt2(ci);
2195 let (wt2, hist) = self.emit_wt2(ci);
2196 return Some((ci, wt2, hist));
2197 } else {
2198 return None;
2199 }
2200 }
2201
2202 let ad = Self::fast_abs(value - self.esa);
2203 self.d = self.esa_beta.mul_add(self.d, self.esa_alpha * ad);
2204
2205 let mut ci = 0.0;
2206 let denom = self.k015 * self.d;
2207 if denom != 0.0 && denom.is_finite() {
2208 ci = (value - self.esa) * (1.0 / denom);
2209 }
2210
2211 self.tci = self.tci_beta.mul_add(self.tci, self.tci_alpha * ci);
2212
2213 let wt1 = self.tci;
2214
2215 self.push_wt2(wt1);
2216 let (wt2, hist) = self.emit_wt2(wt1);
2217 Some((wt1, wt2, hist))
2218 }
2219
2220 #[inline(always)]
2221 fn push_wt2(&mut self, val: f64) {
2222 if self.rlen < 4 {
2223 self.ring[self.rlen] = val;
2224 self.rsum += val;
2225 self.rlen += 1;
2226 } else {
2227 self.rsum += val - self.ring[self.rpos];
2228 self.ring[self.rpos] = val;
2229 self.rpos = (self.rpos + 1) & 3;
2230 }
2231 }
2232
2233 #[inline(always)]
2234 fn emit_wt2(&self, wt1: f64) -> (f64, f64) {
2235 if self.rlen == 4 {
2236 let sig = self.inv4 * self.rsum;
2237 (sig, wt1 - sig)
2238 } else if STRICT_WT2_NANS {
2239 (f64::NAN, f64::NAN)
2240 } else {
2241 let sig = self.rsum / (self.rlen as f64);
2242 (sig, wt1 - sig)
2243 }
2244 }
2245
2246 pub fn last(&self) -> Option<(f64, f64, f64)> {
2247 if !self.ci_ready {
2248 return None;
2249 }
2250 let wt1 = self.tci;
2251 let (wt2, hist) = self.emit_wt2(wt1);
2252 Some((wt1, wt2, hist))
2253 }
2254
2255 pub fn reset(&mut self) {
2256 self.samples = 0;
2257 self.ci_ready = self.channel_length == 1;
2258
2259 self.esa = 0.0;
2260 self.d = 0.0;
2261 self.tci = 0.0;
2262
2263 self.ring = [0.0; 4];
2264 self.rsum = 0.0;
2265 self.rpos = 0;
2266 self.rlen = 0;
2267 }
2268}
2269
2270#[cfg(feature = "python")]
2271#[pyclass(name = "WtoStream")]
2272pub struct WtoStreamPy {
2273 inner: WtoStream,
2274}
2275
2276#[cfg(feature = "python")]
2277#[pymethods]
2278impl WtoStreamPy {
2279 #[new]
2280 fn new(channel_length: usize, average_length: usize) -> PyResult<Self> {
2281 let p = WtoParams {
2282 channel_length: Some(channel_length),
2283 average_length: Some(average_length),
2284 };
2285 Ok(Self {
2286 inner: WtoStream::try_new(p).map_err(|e| PyValueError::new_err(e.to_string()))?,
2287 })
2288 }
2289
2290 pub fn update(&mut self, value: f64) -> Option<(f64, f64, f64)> {
2291 self.inner.update(value)
2292 }
2293
2294 pub fn last(&self) -> Option<(f64, f64, f64)> {
2295 self.inner.last()
2296 }
2297
2298 pub fn reset(&mut self) {
2299 self.inner.reset()
2300 }
2301}
2302
2303#[cfg(feature = "python")]
2304#[pyfunction(name = "wto_batch")]
2305#[pyo3(signature = (close, channel_range, average_range, kernel=None))]
2306pub fn wto_batch_py<'py>(
2307 py: Python<'py>,
2308 close: numpy::PyReadonlyArray1<'py, f64>,
2309 channel_range: (usize, usize, usize),
2310 average_range: (usize, usize, usize),
2311 kernel: Option<&str>,
2312) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
2313 use numpy::{IntoPyArray, PyArrayMethods};
2314 let slice = close.as_slice()?;
2315 let kern = validate_kernel(kernel, false)?;
2316
2317 let sweep = WtoBatchRange {
2318 channel: channel_range,
2319 average: average_range,
2320 };
2321 let out = py
2322 .allow_threads(|| wto_batch_all_outputs_with_kernel(slice, &sweep, kern))
2323 .map_err(|e| PyValueError::new_err(e.to_string()))?;
2324
2325 let dict = pyo3::types::PyDict::new(py);
2326
2327 let wt1_arr = unsafe { numpy::PyArray1::<f64>::new(py, [out.rows * out.cols], false) };
2328 unsafe { wt1_arr.as_slice_mut()? }.copy_from_slice(&out.wt1);
2329 dict.set_item("wt1", wt1_arr.reshape((out.rows, out.cols))?)?;
2330
2331 let wt2_arr = unsafe { numpy::PyArray1::<f64>::new(py, [out.rows * out.cols], false) };
2332 unsafe { wt2_arr.as_slice_mut()? }.copy_from_slice(&out.wt2);
2333 dict.set_item("wt2", wt2_arr.reshape((out.rows, out.cols))?)?;
2334
2335 let hist_arr = unsafe { numpy::PyArray1::<f64>::new(py, [out.rows * out.cols], false) };
2336 unsafe { hist_arr.as_slice_mut()? }.copy_from_slice(&out.hist);
2337 dict.set_item("hist", hist_arr.reshape((out.rows, out.cols))?)?;
2338
2339 dict.set_item(
2340 "channel_lengths",
2341 out.combos
2342 .iter()
2343 .map(|p| p.channel_length.unwrap())
2344 .collect::<Vec<_>>()
2345 .into_pyarray(py),
2346 )?;
2347 dict.set_item(
2348 "average_lengths",
2349 out.combos
2350 .iter()
2351 .map(|p| p.average_length.unwrap())
2352 .collect::<Vec<_>>()
2353 .into_pyarray(py),
2354 )?;
2355 Ok(dict)
2356}
2357
2358#[cfg(all(feature = "python", feature = "cuda"))]
2359#[pyfunction(name = "wto_cuda_batch_dev")]
2360#[pyo3(signature = (close_f32, channel_range, average_range, device_id=0))]
2361pub fn wto_cuda_batch_dev_py<'py>(
2362 py: Python<'py>,
2363 close_f32: numpy::PyReadonlyArray1<'py, f32>,
2364 channel_range: (usize, usize, usize),
2365 average_range: (usize, usize, usize),
2366 device_id: usize,
2367) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
2368 use numpy::IntoPyArray;
2369
2370 if !cuda_available() {
2371 return Err(PyValueError::new_err("CUDA not available"));
2372 }
2373
2374 let slice = close_f32.as_slice()?;
2375 let sweep = WtoBatchRange {
2376 channel: channel_range,
2377 average: average_range,
2378 };
2379
2380 let (CudaWtoBatchResult { outputs, combos }, dev_id) = py.allow_threads(|| {
2381 let cuda = CudaWto::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
2382 let res = cuda
2383 .wto_batch_dev(slice, &sweep)
2384 .map_err(|e| PyValueError::new_err(e.to_string()))?;
2385 Ok::<_, pyo3::PyErr>((res, cuda.device_id()))
2386 })?;
2387 let DeviceArrayF32Triplet { wt1, wt2, hist } = outputs;
2388
2389 let dict = pyo3::types::PyDict::new(py);
2390 let wt1_py = make_device_array_py(dev_id as usize, wt1)?;
2391 let wt2_py = make_device_array_py(dev_id as usize, wt2)?;
2392 let hist_py = make_device_array_py(dev_id as usize, hist)?;
2393 dict.set_item("wt1", Py::new(py, wt1_py)?)?;
2394 dict.set_item("wt2", Py::new(py, wt2_py)?)?;
2395 dict.set_item("hist", Py::new(py, hist_py)?)?;
2396
2397 let channel_vec: Vec<usize> = combos.iter().map(|p| p.channel_length.unwrap()).collect();
2398 let average_vec: Vec<usize> = combos.iter().map(|p| p.average_length.unwrap()).collect();
2399
2400 dict.set_item("channel_lengths", channel_vec.into_pyarray(py))?;
2401 dict.set_item("average_lengths", average_vec.into_pyarray(py))?;
2402 dict.set_item("rows", combos.len())?;
2403 dict.set_item("cols", slice.len())?;
2404
2405 Ok(dict)
2406}
2407
2408#[cfg(all(feature = "python", feature = "cuda"))]
2409#[pyfunction(name = "wto_cuda_many_series_one_param_dev")]
2410#[pyo3(signature = (data_tm_f32, channel_length, average_length, device_id=0))]
2411pub fn wto_cuda_many_series_one_param_dev_py<'py>(
2412 py: Python<'py>,
2413 data_tm_f32: numpy::PyReadonlyArray2<'py, f32>,
2414 channel_length: usize,
2415 average_length: usize,
2416 device_id: usize,
2417) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
2418 use numpy::PyUntypedArrayMethods;
2419
2420 if !cuda_available() {
2421 return Err(PyValueError::new_err("CUDA not available"));
2422 }
2423
2424 let shape = data_tm_f32.shape();
2425 if shape.len() != 2 {
2426 return Err(PyValueError::new_err("expected 2D array"));
2427 }
2428 let rows = shape[0];
2429 let cols = shape[1];
2430 let flat = data_tm_f32.as_slice()?;
2431
2432 let params = WtoParams {
2433 channel_length: Some(channel_length),
2434 average_length: Some(average_length),
2435 };
2436
2437 let (DeviceArrayF32Triplet { wt1, wt2, hist }, dev_id) = py.allow_threads(|| {
2438 let cuda = CudaWto::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
2439 let res = cuda
2440 .wto_many_series_one_param_time_major_dev(flat, cols, rows, ¶ms)
2441 .map_err(|e| PyValueError::new_err(e.to_string()))?;
2442 Ok::<_, pyo3::PyErr>((res, cuda.device_id()))
2443 })?;
2444
2445 let dict = pyo3::types::PyDict::new(py);
2446 let wt1_py = make_device_array_py(dev_id as usize, wt1)?;
2447 let wt2_py = make_device_array_py(dev_id as usize, wt2)?;
2448 let hist_py = make_device_array_py(dev_id as usize, hist)?;
2449 dict.set_item("wt1", Py::new(py, wt1_py)?)?;
2450 dict.set_item("wt2", Py::new(py, wt2_py)?)?;
2451 dict.set_item("hist", Py::new(py, hist_py)?)?;
2452 dict.set_item("rows", rows)?;
2453 dict.set_item("cols", cols)?;
2454 dict.set_item("channel_length", channel_length)?;
2455 dict.set_item("average_length", average_length)?;
2456
2457 Ok(dict)
2458}
2459
2460#[cfg(feature = "python")]
2461pub fn register_wto_module(m: &Bound<'_, pyo3::types::PyModule>) -> PyResult<()> {
2462 m.add_function(wrap_pyfunction!(wto_py, m)?)?;
2463 m.add_function(wrap_pyfunction!(wto_batch_py, m)?)?;
2464 m.add_class::<WtoStreamPy>()?;
2465 #[cfg(feature = "cuda")]
2466 {
2467 m.add_function(wrap_pyfunction!(wto_cuda_batch_dev_py, m)?)?;
2468 m.add_function(wrap_pyfunction!(wto_cuda_many_series_one_param_dev_py, m)?)?;
2469 }
2470 Ok(())
2471}
2472
2473#[cfg(test)]
2474mod tests {
2475 use super::*;
2476 use crate::utilities::data_loader::read_candles_from_csv;
2477
2478 macro_rules! skip_if_unsupported {
2479 ($kernel:expr, $test_name:expr) => {
2480 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
2481 if matches!(
2482 $kernel,
2483 Kernel::Avx2 | Kernel::Avx512 | Kernel::Avx2Batch | Kernel::Avx512Batch
2484 ) {
2485 eprintln!("[{}] Skipping due to missing AVX support", $test_name);
2486 return Ok(());
2487 }
2488 };
2489 }
2490
2491 fn check_wto_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2492 skip_if_unsupported!(kernel, test_name);
2493 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2494 let candles = read_candles_from_csv(file_path)?;
2495
2496 let input = WtoInput::from_candles(&candles, "close", WtoParams::default());
2497 let result = wto_with_kernel(&input, kernel)?;
2498
2499 let expected_wt1 = [
2500 -34.81423091,
2501 -33.92872278,
2502 -35.29125217,
2503 -34.93917015,
2504 -41.42578524,
2505 ];
2506
2507 let expected_wt2 = [
2508 -37.72141493,
2509 -35.54009606,
2510 -34.81718669,
2511 -34.74334400,
2512 -36.39623258,
2513 ];
2514
2515 let expected_hist = [
2516 2.90718403,
2517 1.61137328,
2518 -0.47406548,
2519 -0.19582615,
2520 -5.02955265,
2521 ];
2522
2523 let start = result.wavetrend1.len().saturating_sub(5);
2524
2525 for (i, &val) in result.wavetrend1[start..].iter().enumerate() {
2526 let diff = (val - expected_wt1[i]).abs();
2527
2528 let rel_tolerance = expected_wt1[i].abs() * 0.1;
2529 let abs_tolerance = 1e-6;
2530 assert!(
2531 diff < rel_tolerance.max(abs_tolerance),
2532 "WaveTrend1 mismatch at idx {}: got {}, expected {}, diff {}",
2533 i,
2534 val,
2535 expected_wt1[i],
2536 diff
2537 );
2538 }
2539
2540 for (i, &val) in result.wavetrend2[start..].iter().enumerate() {
2541 let diff = (val - expected_wt2[i]).abs();
2542 let rel_tolerance = expected_wt2[i].abs() * 0.1;
2543 let abs_tolerance = 1e-6;
2544 assert!(
2545 diff < rel_tolerance.max(abs_tolerance),
2546 "WaveTrend2 mismatch at idx {}: got {}, expected {}, diff {}",
2547 i,
2548 val,
2549 expected_wt2[i],
2550 diff
2551 );
2552 }
2553
2554 for (i, &val) in result.histogram[start..].iter().enumerate() {
2555 let diff = (val - expected_hist[i]).abs();
2556
2557 let abs_tolerance = 2.0;
2558 assert!(
2559 diff < abs_tolerance,
2560 "Histogram mismatch at idx {}: got {}, expected {}, diff {}",
2561 i,
2562 val,
2563 expected_hist[i],
2564 diff
2565 );
2566 }
2567
2568 Ok(())
2569 }
2570
2571 fn check_wto_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2572 skip_if_unsupported!(kernel, test_name);
2573
2574 let data = vec![1.0; 100];
2575 let params = WtoParams {
2576 channel_length: Some(12),
2577 average_length: None,
2578 };
2579 let input = WtoInput::from_slice(&data, params);
2580 let result = wto_with_kernel(&input, kernel)?;
2581
2582 assert_eq!(result.wavetrend1.len(), data.len());
2583 assert_eq!(result.wavetrend2.len(), data.len());
2584 assert_eq!(result.histogram.len(), data.len());
2585 Ok(())
2586 }
2587
2588 fn check_wto_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2589 skip_if_unsupported!(kernel, test_name);
2590
2591 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2592 let candles = read_candles_from_csv(file_path)?;
2593
2594 let input = WtoInput::with_default_candles(&candles);
2595 match input.data {
2596 WtoData::Candles { source, .. } => assert_eq!(source, "close"),
2597 _ => panic!("Expected WtoData::Candles"),
2598 }
2599 let output = wto_with_kernel(&input, kernel)?;
2600 assert_eq!(output.wavetrend1.len(), candles.close.len());
2601
2602 Ok(())
2603 }
2604
2605 fn check_wto_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2606 skip_if_unsupported!(kernel, test_name);
2607
2608 let data = [10.0, 20.0, 30.0];
2609 let params = WtoParams {
2610 channel_length: Some(0),
2611 average_length: None,
2612 };
2613 let input = WtoInput::from_slice(&data, params);
2614 let res = wto_with_kernel(&input, kernel);
2615 assert!(
2616 res.is_err(),
2617 "[{}] WTO should fail with zero period",
2618 test_name
2619 );
2620 Ok(())
2621 }
2622
2623 fn check_wto_period_exceeds_length(
2624 test_name: &str,
2625 kernel: Kernel,
2626 ) -> Result<(), Box<dyn Error>> {
2627 skip_if_unsupported!(kernel, test_name);
2628
2629 let data = [10.0, 20.0, 30.0];
2630 let params = WtoParams {
2631 channel_length: Some(10),
2632 average_length: None,
2633 };
2634 let input = WtoInput::from_slice(&data, params);
2635 let res = wto_with_kernel(&input, kernel);
2636 assert!(
2637 res.is_err(),
2638 "[{}] WTO should fail with period exceeding length",
2639 test_name
2640 );
2641 Ok(())
2642 }
2643
2644 fn check_wto_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2645 skip_if_unsupported!(kernel, test_name);
2646
2647 let single_point = [42.0];
2648 let params = WtoParams::default();
2649 let input = WtoInput::from_slice(&single_point, params);
2650 let res = wto_with_kernel(&input, kernel);
2651 assert!(
2652 res.is_err(),
2653 "[{}] WTO should fail with insufficient data",
2654 test_name
2655 );
2656 Ok(())
2657 }
2658
2659 fn check_wto_empty_input(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2660 skip_if_unsupported!(kernel, test_name);
2661 let empty: [f64; 0] = [];
2662 let input = WtoInput::from_slice(&empty, WtoParams::default());
2663 let res = wto_with_kernel(&input, kernel);
2664 assert!(
2665 matches!(res, Err(WtoError::EmptyInputData)),
2666 "[{}] WTO should fail with empty input",
2667 test_name
2668 );
2669 Ok(())
2670 }
2671
2672 fn check_wto_all_nan(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2673 skip_if_unsupported!(kernel, test_name);
2674 let data = vec![f64::NAN; 50];
2675 let params = WtoParams::default();
2676 let input = WtoInput::from_slice(&data, params);
2677 let res = wto_with_kernel(&input, kernel);
2678 assert!(
2679 matches!(res, Err(WtoError::AllValuesNaN)),
2680 "[{}] WTO should fail with all NaN values",
2681 test_name
2682 );
2683 Ok(())
2684 }
2685
2686 fn check_wto_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2687 skip_if_unsupported!(kernel, test_name);
2688
2689 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2690 let candles = read_candles_from_csv(file_path)?;
2691
2692 let first_params = WtoParams::default();
2693 let first_input = WtoInput::from_candles(&candles, "close", first_params);
2694 let first_result = wto_with_kernel(&first_input, kernel)?;
2695
2696 let second_params = WtoParams::default();
2697 let second_input = WtoInput::from_slice(&first_result.wavetrend1, second_params);
2698 let second_result = wto_with_kernel(&second_input, kernel)?;
2699
2700 assert_eq!(
2701 second_result.wavetrend1.len(),
2702 first_result.wavetrend1.len()
2703 );
2704 Ok(())
2705 }
2706
2707 fn check_wto_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2708 skip_if_unsupported!(kernel, test_name);
2709
2710 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2711 let candles = read_candles_from_csv(file_path)?;
2712
2713 let input = WtoInput::from_candles(&candles, "close", WtoParams::default());
2714 let res = wto_with_kernel(&input, kernel)?;
2715
2716 assert_eq!(res.wavetrend1.len(), candles.close.len());
2717 if res.wavetrend1.len() > 50 {
2718 for (i, &val) in res.wavetrend1[50..].iter().enumerate() {
2719 assert!(
2720 !val.is_nan(),
2721 "[{}] Found unexpected NaN at index {}",
2722 test_name,
2723 50 + i
2724 );
2725 }
2726 }
2727 Ok(())
2728 }
2729
2730 fn check_wto_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2731 skip_if_unsupported!(kernel, test_name);
2732
2733 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2734 let candles = read_candles_from_csv(file_path)?;
2735
2736 let params = WtoParams::default();
2737 let input = WtoInput::from_candles(&candles, "close", params.clone());
2738 let batch_result = wto_with_kernel(&input, kernel)?;
2739
2740 let mut stream = WtoStream::try_new(params)?;
2741
2742 let mut stream_wt1 = Vec::new();
2743 let mut stream_wt2 = Vec::new();
2744 let mut stream_hist = Vec::new();
2745
2746 for i in 0..candles.close.len() {
2747 if let Some((wt1, wt2, hist)) = stream.update(candles.close[i]) {
2748 stream_wt1.push(wt1);
2749 stream_wt2.push(wt2);
2750 stream_hist.push(hist);
2751 }
2752 }
2753
2754 assert!(!stream_wt1.is_empty());
2755 Ok(())
2756 }
2757
2758 #[cfg(debug_assertions)]
2759 fn check_wto_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2760 skip_if_unsupported!(kernel, test_name);
2761
2762 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2763 let candles = read_candles_from_csv(file_path)?;
2764
2765 let test_params = vec![
2766 WtoParams {
2767 channel_length: Some(5),
2768 average_length: Some(10),
2769 },
2770 WtoParams {
2771 channel_length: Some(20),
2772 average_length: Some(40),
2773 },
2774 WtoParams {
2775 channel_length: Some(3),
2776 average_length: Some(7),
2777 },
2778 ];
2779
2780 for params in test_params {
2781 let input = WtoInput::from_candles(&candles, "close", params.clone());
2782 let output = wto_with_kernel(&input, kernel)?;
2783
2784 for (i, &val) in output.wavetrend1.iter().enumerate() {
2785 if val.is_nan() {
2786 continue;
2787 }
2788
2789 let bits = val.to_bits();
2790 if bits == 0x11111111_11111111
2791 || bits == 0x22222222_22222222
2792 || bits == 0x33333333_33333333
2793 {
2794 panic!(
2795 "[{}] Found poison value {} (0x{:016X}) at index {} with params: {:?}",
2796 test_name, val, bits, i, params
2797 );
2798 }
2799 }
2800 }
2801
2802 Ok(())
2803 }
2804
2805 #[cfg(not(debug_assertions))]
2806 fn check_wto_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2807 Ok(())
2808 }
2809
2810 #[cfg(debug_assertions)]
2811 fn check_wto_no_poison_all(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2812 skip_if_unsupported!(kernel, test_name);
2813
2814 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2815 let c = read_candles_from_csv(file)?;
2816 let input = WtoInput::with_default_candles(&c);
2817 let out = wto_with_kernel(&input, kernel)?;
2818
2819 for series in [&out.wavetrend1, &out.wavetrend2, &out.histogram] {
2820 for &v in series {
2821 if v.is_nan() {
2822 continue;
2823 }
2824 let b = v.to_bits();
2825 assert!(
2826 b != 0x11111111_11111111
2827 && b != 0x22222222_22222222
2828 && b != 0x33333333_33333333,
2829 "[{}] poison value 0x{:016X}",
2830 test_name,
2831 b
2832 );
2833 }
2834 }
2835 Ok(())
2836 }
2837
2838 fn check_batch_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2839 skip_if_unsupported!(kernel, test_name);
2840
2841 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2842 let candles = read_candles_from_csv(file)?;
2843
2844 let output = WtoBatchBuilder::new()
2845 .channel_range(5, 15, 5)
2846 .average_range(10, 30, 10)
2847 .kernel(kernel)
2848 .apply_candles(&candles, "close")?;
2849
2850 for &v in &output.values {
2851 if v.is_nan() {
2852 continue;
2853 }
2854 let b = v.to_bits();
2855 assert!(
2856 b != 0x11111111_11111111 && b != 0x22222222_22222222 && b != 0x33333333_33333333,
2857 "[{}] batch poison value 0x{:016X}",
2858 test_name,
2859 b
2860 );
2861 }
2862
2863 let sweep = WtoBatchRange {
2864 channel: (5, 15, 5),
2865 average: (10, 30, 10),
2866 };
2867 let data = source_type(&candles, "close");
2868 let full_out = wto_batch_all_outputs_with_kernel(data, &sweep, kernel)?;
2869
2870 for series in [&full_out.wt1, &full_out.wt2, &full_out.hist] {
2871 for &v in series {
2872 if v.is_nan() {
2873 continue;
2874 }
2875 let b = v.to_bits();
2876 assert!(
2877 b != 0x11111111_11111111
2878 && b != 0x22222222_22222222
2879 && b != 0x33333333_33333333,
2880 "[{}] full batch poison value 0x{:016X}",
2881 test_name,
2882 b
2883 );
2884 }
2885 }
2886
2887 Ok(())
2888 }
2889
2890 macro_rules! generate_all_wto_tests {
2891 ($($test_fn:ident),*) => {
2892 paste::paste! {
2893 $(
2894 #[test]
2895 fn [<$test_fn _scalar>]() {
2896 let _ = $test_fn(stringify!([<$test_fn _scalar>]), Kernel::Scalar);
2897 }
2898 )*
2899 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2900 $(
2901 #[test]
2902 fn [<$test_fn _avx2>]() {
2903 let _ = $test_fn(stringify!([<$test_fn _avx2>]), Kernel::Avx2);
2904 }
2905 #[test]
2906 fn [<$test_fn _avx512>]() {
2907 let _ = $test_fn(stringify!([<$test_fn _avx512>]), Kernel::Avx512);
2908 }
2909 )*
2910 #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
2911 $(
2912 #[test]
2913 fn [<$test_fn _simd128>]() {
2914 let _ = $test_fn(stringify!([<$test_fn _simd128>]), Kernel::Scalar);
2915 }
2916 )*
2917 }
2918 }
2919 }
2920
2921 generate_all_wto_tests!(
2922 check_wto_accuracy,
2923 check_wto_partial_params,
2924 check_wto_default_candles,
2925 check_wto_zero_period,
2926 check_wto_period_exceeds_length,
2927 check_wto_very_small_dataset,
2928 check_wto_empty_input,
2929 check_wto_all_nan,
2930 check_wto_reinput,
2931 check_wto_nan_handling,
2932 check_wto_streaming,
2933 check_wto_no_poison,
2934 check_batch_poison
2935 );
2936
2937 #[cfg(debug_assertions)]
2938 #[test]
2939 fn test_wto_no_poison_all_scalar() {
2940 check_wto_no_poison_all("test_wto_no_poison_all_scalar", Kernel::Scalar).unwrap();
2941 }
2942
2943 #[cfg(all(debug_assertions, feature = "nightly-avx", target_arch = "x86_64"))]
2944 #[test]
2945 fn test_wto_no_poison_all_avx2() {
2946 check_wto_no_poison_all("test_wto_no_poison_all_avx2", Kernel::Avx2).unwrap();
2947 }
2948
2949 #[cfg(all(debug_assertions, feature = "nightly-avx", target_arch = "x86_64"))]
2950 #[test]
2951 fn test_wto_no_poison_all_avx512() {
2952 check_wto_no_poison_all("test_wto_no_poison_all_avx512", Kernel::Avx512).unwrap();
2953 }
2954
2955 #[cfg(all(debug_assertions, target_arch = "wasm32", target_feature = "simd128"))]
2956 #[test]
2957 fn test_wto_no_poison_all_simd128() {
2958 check_wto_no_poison_all("test_wto_no_poison_all_simd128", Kernel::Simd128).unwrap();
2959 }
2960
2961 fn check_batch_default_row(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2962 skip_if_unsupported!(kernel, test_name);
2963
2964 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2965 let candles = read_candles_from_csv(file_path)?;
2966
2967 let output = WtoBatchBuilder::new()
2968 .kernel(kernel)
2969 .apply_candles(&candles, "close")?;
2970
2971 let def = WtoParams::default();
2972 let row = output.values_for(&def).expect("default row missing");
2973
2974 assert_eq!(row.len(), candles.close.len());
2975 Ok(())
2976 }
2977
2978 fn check_batch_sweep(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2979 skip_if_unsupported!(kernel, test_name);
2980
2981 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2982 let candles = read_candles_from_csv(file_path)?;
2983
2984 let output = WtoBatchBuilder::new()
2985 .kernel(kernel)
2986 .channel_range(5, 15, 5)
2987 .average_range(10, 30, 10)
2988 .apply_candles(&candles, "close")?;
2989
2990 let expected_combos = 3 * 3;
2991 assert_eq!(output.combos.len(), expected_combos);
2992 assert_eq!(output.rows, expected_combos);
2993 assert_eq!(output.cols, candles.close.len());
2994
2995 Ok(())
2996 }
2997
2998 macro_rules! gen_batch_tests {
2999 ($fn_name:ident) => {
3000 paste::paste! {
3001 #[test]
3002 fn [<$fn_name _scalar>]() {
3003 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
3004 }
3005 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3006 #[test]
3007 fn [<$fn_name _avx2>]() {
3008 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
3009 }
3010 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3011 #[test]
3012 fn [<$fn_name _avx512>]() {
3013 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
3014 }
3015 #[test]
3016 fn [<$fn_name _auto_detect>]() {
3017 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
3018 }
3019 }
3020 };
3021 }
3022
3023 gen_batch_tests!(check_batch_default_row);
3024 gen_batch_tests!(check_batch_sweep);
3025
3026 #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
3027 #[test]
3028 fn test_wto_into_matches_api() {
3029 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3030 let candles = read_candles_from_csv(file_path).expect("failed to load candles");
3031
3032 let input = WtoInput::with_default_candles(&candles);
3033
3034 let base = wto(&input).expect("wto baseline failed");
3035
3036 let n = candles.close.len();
3037 let mut wt1 = vec![0.0; n];
3038 let mut wt2 = vec![0.0; n];
3039 let mut hist = vec![0.0; n];
3040
3041 wto_into(&input, &mut wt1, &mut wt2, &mut hist).expect("wto_into failed");
3042
3043 fn eq_or_both_nan(a: f64, b: f64) -> bool {
3044 (a.is_nan() && b.is_nan()) || (a == b) || ((a - b).abs() <= 1e-12)
3045 }
3046
3047 assert_eq!(wt1.len(), base.wavetrend1.len());
3048 assert_eq!(wt2.len(), base.wavetrend2.len());
3049 assert_eq!(hist.len(), base.histogram.len());
3050
3051 for i in 0..n {
3052 assert!(
3053 eq_or_both_nan(wt1[i], base.wavetrend1[i]),
3054 "wt1 mismatch at {}: into={}, api={}",
3055 i,
3056 wt1[i],
3057 base.wavetrend1[i]
3058 );
3059 assert!(
3060 eq_or_both_nan(wt2[i], base.wavetrend2[i]),
3061 "wt2 mismatch at {}: into={}, api={}",
3062 i,
3063 wt2[i],
3064 base.wavetrend2[i]
3065 );
3066 assert!(
3067 eq_or_both_nan(hist[i], base.histogram[i]),
3068 "hist mismatch at {}: into={}, api={}",
3069 i,
3070 hist[i],
3071 base.histogram[i]
3072 );
3073 }
3074 }
3075}