1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::cuda_available;
3#[cfg(all(feature = "python", feature = "cuda"))]
4use crate::cuda::moving_averages::DeviceArrayF32;
5#[cfg(all(feature = "python", feature = "cuda"))]
6use crate::cuda::wavetrend::CudaWavetrend;
7#[cfg(all(feature = "python", feature = "cuda"))]
8use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
9#[cfg(feature = "python")]
10use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
11#[cfg(feature = "python")]
12use pyo3::exceptions::PyValueError;
13#[cfg(feature = "python")]
14use pyo3::prelude::*;
15#[cfg(feature = "python")]
16use pyo3::types::PyDict;
17
18#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
19use serde::{Deserialize, Serialize};
20#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
21use wasm_bindgen::prelude::*;
22
23use crate::indicators::moving_averages::ema::{ema, EmaError, EmaInput, EmaParams};
24use crate::indicators::moving_averages::sma::{sma, SmaError, SmaInput, SmaParams};
25use crate::utilities::data_loader::{source_type, Candles};
26use crate::utilities::enums::Kernel;
27use crate::utilities::helpers::{
28 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
29 make_uninit_matrix,
30};
31#[cfg(feature = "python")]
32use crate::utilities::kernel_validation::validate_kernel;
33use aligned_vec::{AVec, CACHELINE_ALIGN};
34#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
35use core::arch::x86_64::*;
36#[cfg(not(target_arch = "wasm32"))]
37use rayon::prelude::*;
38use std::convert::AsRef;
39use thiserror::Error;
40
41impl<'a> AsRef<[f64]> for WavetrendInput<'a> {
42 #[inline(always)]
43 fn as_ref(&self) -> &[f64] {
44 match &self.data {
45 WavetrendData::Slice(slice) => slice,
46 WavetrendData::Candles { candles, source } => source_type(candles, source),
47 }
48 }
49}
50
/// Where a Wavetrend computation reads its input series from.
#[derive(Debug, Clone)]
pub enum WavetrendData<'a> {
    /// Borrowed candles plus the name of the series to extract from them
    /// (resolved with `source_type` when the input is viewed as a slice).
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A borrowed, caller-supplied series of values.
    Slice(&'a [f64]),
}
59
/// The three output series of a Wavetrend computation.
///
/// The compute paths in this file write `NaN` into every position before the
/// warm-up index and wherever a value cannot be produced.
#[derive(Debug, Clone)]
pub struct WavetrendOutput {
    /// EMA (over `average_length`) of the normalised channel index.
    pub wt1: Vec<f64>,
    /// SMA (over `ma_length`) of `wt1`.
    pub wt2: Vec<f64>,
    /// `wt2 - wt1`, or `NaN` when either operand is unavailable.
    pub wt_diff: Vec<f64>,
}
66
/// Tunable parameters of the Wavetrend indicator.
///
/// Each field is optional; readers in this file substitute 9, 12, 3 and
/// 0.015 respectively when a field is `None`.
#[derive(Debug, Clone)]
pub struct WavetrendParams {
    /// Period of the two channel EMAs (price EMA and absolute-deviation EMA).
    pub channel_length: Option<usize>,
    /// Period of the EMA that turns the channel index into `wt1`.
    pub average_length: Option<usize>,
    /// Period of the SMA that turns `wt1` into `wt2`.
    pub ma_length: Option<usize>,
    /// Multiplier applied to the deviation EMA in the channel-index denominator.
    pub factor: Option<f64>,
}
74
75impl Default for WavetrendParams {
76 fn default() -> Self {
77 Self {
78 channel_length: Some(9),
79 average_length: Some(12),
80 ma_length: Some(3),
81 factor: Some(0.015),
82 }
83 }
84}
85
/// A Wavetrend request: the data to read plus the parameters to apply.
#[derive(Debug, Clone)]
pub struct WavetrendInput<'a> {
    /// Source of the input series (candles + series name, or a raw slice).
    pub data: WavetrendData<'a>,
    /// Indicator parameters; unset fields fall back to the defaults used by
    /// the `get_*` accessors.
    pub params: WavetrendParams,
}
91
92impl<'a> WavetrendInput<'a> {
93 #[inline]
94 pub fn from_candles(c: &'a Candles, s: &'a str, p: WavetrendParams) -> Self {
95 Self {
96 data: WavetrendData::Candles {
97 candles: c,
98 source: s,
99 },
100 params: p,
101 }
102 }
103 #[inline]
104 pub fn from_slice(sl: &'a [f64], p: WavetrendParams) -> Self {
105 Self {
106 data: WavetrendData::Slice(sl),
107 params: p,
108 }
109 }
110 #[inline]
111 pub fn with_default_candles(c: &'a Candles) -> Self {
112 Self::from_candles(c, "hlc3", WavetrendParams::default())
113 }
114 #[inline]
115 pub fn get_channel_length(&self) -> usize {
116 self.params.channel_length.unwrap_or(9)
117 }
118 #[inline]
119 pub fn get_average_length(&self) -> usize {
120 self.params.average_length.unwrap_or(12)
121 }
122 #[inline]
123 pub fn get_ma_length(&self) -> usize {
124 self.params.ma_length.unwrap_or(3)
125 }
126 #[inline]
127 pub fn get_factor(&self) -> f64 {
128 self.params.factor.unwrap_or(0.015)
129 }
130}
131
/// Fluent builder for running Wavetrend with optional parameter overrides
/// and an explicit kernel choice.
#[derive(Copy, Clone, Debug)]
pub struct WavetrendBuilder {
    // `None` means "use the indicator default" when the parameters are resolved.
    channel_length: Option<usize>,
    average_length: Option<usize>,
    ma_length: Option<usize>,
    factor: Option<f64>,
    // Kernel forwarded to `wavetrend_with_kernel` by `apply` / `apply_slice`.
    kernel: Kernel,
}
140
141impl Default for WavetrendBuilder {
142 fn default() -> Self {
143 Self {
144 channel_length: None,
145 average_length: None,
146 ma_length: None,
147 factor: None,
148 kernel: Kernel::Auto,
149 }
150 }
151}
152
153impl WavetrendBuilder {
154 #[inline(always)]
155 pub fn new() -> Self {
156 Self::default()
157 }
158 #[inline(always)]
159 pub fn channel_length(mut self, n: usize) -> Self {
160 self.channel_length = Some(n);
161 self
162 }
163 #[inline(always)]
164 pub fn average_length(mut self, n: usize) -> Self {
165 self.average_length = Some(n);
166 self
167 }
168 #[inline(always)]
169 pub fn ma_length(mut self, n: usize) -> Self {
170 self.ma_length = Some(n);
171 self
172 }
173 #[inline(always)]
174 pub fn factor(mut self, f: f64) -> Self {
175 self.factor = Some(f);
176 self
177 }
178 #[inline(always)]
179 pub fn kernel(mut self, k: Kernel) -> Self {
180 self.kernel = k;
181 self
182 }
183 #[inline(always)]
184 pub fn apply(self, c: &Candles) -> Result<WavetrendOutput, WavetrendError> {
185 let p = WavetrendParams {
186 channel_length: self.channel_length,
187 average_length: self.average_length,
188 ma_length: self.ma_length,
189 factor: self.factor,
190 };
191 let i = WavetrendInput::from_candles(c, "hlc3", p);
192 wavetrend_with_kernel(&i, self.kernel)
193 }
194 #[inline(always)]
195 pub fn apply_slice(self, d: &[f64]) -> Result<WavetrendOutput, WavetrendError> {
196 let p = WavetrendParams {
197 channel_length: self.channel_length,
198 average_length: self.average_length,
199 ma_length: self.ma_length,
200 factor: self.factor,
201 };
202 let i = WavetrendInput::from_slice(d, p);
203 wavetrend_with_kernel(&i, self.kernel)
204 }
205 #[inline(always)]
206 pub fn into_stream(self) -> Result<WavetrendStream, WavetrendError> {
207 let p = WavetrendParams {
208 channel_length: self.channel_length,
209 average_length: self.average_length,
210 ma_length: self.ma_length,
211 factor: self.factor,
212 };
213 WavetrendStream::try_new(p)
214 }
215}
216
/// Errors reported by the Wavetrend entry points.
#[derive(Debug, Error)]
pub enum WavetrendError {
    /// The input series is empty (returned by the validation in this file).
    #[error("wavetrend: Empty data provided.")]
    EmptyInputData,
    /// Same message as `EmptyInputData`; not produced by the code visible in
    /// this part of the file.
    #[error("wavetrend: Empty data provided.")]
    EmptyData,
    /// No non-NaN value exists in the input series.
    #[error("wavetrend: All values are NaN.")]
    AllValuesNaN,
    /// `channel_length` is zero or larger than the input length
    /// (`data_len` is reported as 0 by the stream constructor).
    #[error("wavetrend: Invalid channel_length = {channel_length}, data length = {data_len}")]
    InvalidChannelLen {
        channel_length: usize,
        data_len: usize,
    },
    /// `average_length` is zero or larger than the input length.
    #[error("wavetrend: Invalid average_length = {average_length}, data length = {data_len}")]
    InvalidAverageLen {
        average_length: usize,
        data_len: usize,
    },
    /// `ma_length` is zero or larger than the input length.
    #[error("wavetrend: Invalid ma_length = {ma_length}, data length = {data_len}")]
    InvalidMaLen { ma_length: usize, data_len: usize },
    /// Fewer values remain from the first non-NaN sample onward (`valid`) than
    /// the largest of the three lengths (`needed`).
    #[error("wavetrend: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// A destination slice passed to `wavetrend_into_slice` does not have the
    /// same length as the input.
    #[error("wavetrend: Output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// Output slice length mismatch; not produced by the code visible in this
    /// part of the file.
    #[error("wavetrend: Output slice length mismatch: expected = {expected}, got = {got}")]
    OutputSliceLengthMismatch { expected: usize, got: usize },
    /// Invalid parameter range; not produced by the code visible in this part
    /// of the file.
    #[error("wavetrend: Invalid range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: String,
        end: String,
        step: String,
    },
    /// Kernel rejected for batch use; not produced by the code visible in this
    /// part of the file.
    #[error("wavetrend: Invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(crate::utilities::enums::Kernel),
    /// Wrapped error from the EMA indicator (converted via `From`).
    #[error("wavetrend: EMA error {0}")]
    EmaError(#[from] EmaError),
    /// Wrapped error from the SMA indicator (converted via `From`).
    #[error("wavetrend: SMA error {0}")]
    SmaError(#[from] SmaError),
}
256
257#[inline]
258pub fn wavetrend(input: &WavetrendInput) -> Result<WavetrendOutput, WavetrendError> {
259 wavetrend_with_kernel(input, Kernel::Auto)
260}
261
262pub fn wavetrend_with_kernel(
263 input: &WavetrendInput,
264 kernel: Kernel,
265) -> Result<WavetrendOutput, WavetrendError> {
266 let data: &[f64] = input.as_ref();
267 if data.is_empty() {
268 return Err(WavetrendError::EmptyInputData);
269 }
270 let channel_len = input.get_channel_length();
271 let average_len = input.get_average_length();
272 let ma_len = input.get_ma_length();
273 let factor = input.get_factor();
274
275 let first = data
276 .iter()
277 .position(|x| !x.is_nan())
278 .ok_or(WavetrendError::AllValuesNaN)?;
279 let needed = *[channel_len, average_len, ma_len].iter().max().unwrap();
280 let valid = data.len() - first;
281
282 if channel_len == 0 || channel_len > data.len() {
283 return Err(WavetrendError::InvalidChannelLen {
284 channel_length: channel_len,
285 data_len: data.len(),
286 });
287 }
288 if average_len == 0 || average_len > data.len() {
289 return Err(WavetrendError::InvalidAverageLen {
290 average_length: average_len,
291 data_len: data.len(),
292 });
293 }
294 if ma_len == 0 || ma_len > data.len() {
295 return Err(WavetrendError::InvalidMaLen {
296 ma_length: ma_len,
297 data_len: data.len(),
298 });
299 }
300 if valid < needed {
301 return Err(WavetrendError::NotEnoughValidData { needed, valid });
302 }
303
304 let chosen = match kernel {
305 Kernel::Auto => detect_best_kernel(),
306
307 Kernel::Avx2 | Kernel::Avx512 => Kernel::Scalar,
308 other => other,
309 };
310
311 unsafe {
312 match chosen {
313 Kernel::Scalar | Kernel::ScalarBatch => {
314 wavetrend_scalar(data, channel_len, average_len, ma_len, factor, first)
315 }
316 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
317 Kernel::Avx2 | Kernel::Avx2Batch => {
318 wavetrend_avx2(data, channel_len, average_len, ma_len, factor, first)
319 }
320 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
321 Kernel::Avx512 | Kernel::Avx512Batch => {
322 wavetrend_avx512(data, channel_len, average_len, ma_len, factor, first)
323 }
324 _ => unreachable!(),
325 }
326 }
327}
328
329fn wavetrend_kernel_dispatch(
330 data: &[f64],
331 channel_len: usize,
332 average_len: usize,
333 ma_len: usize,
334 factor: f64,
335 first: usize,
336 kernel: Kernel,
337) -> Result<WavetrendOutput, WavetrendError> {
338 let warmup_period = first + channel_len - 1 + average_len - 1 + ma_len - 1;
339
340 let mut wt1_final = alloc_with_nan_prefix(data.len(), warmup_period);
341 let mut wt2_final = alloc_with_nan_prefix(data.len(), warmup_period);
342 let mut diff_final = alloc_with_nan_prefix(data.len(), warmup_period);
343
344 wavetrend_compute_into(
345 data,
346 channel_len,
347 average_len,
348 ma_len,
349 factor,
350 first,
351 warmup_period,
352 &mut wt1_final,
353 &mut wt2_final,
354 &mut diff_final,
355 kernel,
356 )?;
357
358 Ok(WavetrendOutput {
359 wt1: wt1_final,
360 wt2: wt2_final,
361 wt_diff: diff_final,
362 })
363}
364
365pub fn wavetrend_scalar(
366 data: &[f64],
367 channel_len: usize,
368 average_len: usize,
369 ma_len: usize,
370 factor: f64,
371 first: usize,
372) -> Result<WavetrendOutput, WavetrendError> {
373 wavetrend_kernel_dispatch(
374 data,
375 channel_len,
376 average_len,
377 ma_len,
378 factor,
379 first,
380 Kernel::Scalar,
381 )
382}
383
384use std::collections::VecDeque;
385
/// AVX2 entry point: allocates NaN-prefixed outputs and fills them with
/// `wavetrend_fused_avx2_into`.
///
/// # Safety
/// The callee is compiled with the `avx2` and `fma` target features enabled,
/// so the caller must ensure the running CPU supports both.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn wavetrend_avx2(
    data: &[f64],
    channel_len: usize,
    average_len: usize,
    ma_len: usize,
    factor: f64,
    first: usize,
) -> Result<WavetrendOutput, WavetrendError> {
    let warmup_period = first + channel_len - 1 + average_len - 1 + ma_len - 1;
    let len = data.len();

    let mut out = WavetrendOutput {
        wt1: alloc_with_nan_prefix(len, warmup_period),
        wt2: alloc_with_nan_prefix(len, warmup_period),
        wt_diff: alloc_with_nan_prefix(len, warmup_period),
    };

    wavetrend_fused_avx2_into(
        data,
        channel_len,
        average_len,
        ma_len,
        factor,
        first,
        warmup_period,
        &mut out.wt1,
        &mut out.wt2,
        &mut out.wt_diff,
    );

    Ok(out)
}
421
/// Single-pass Wavetrend that writes into caller-provided slices, compiled
/// with `avx2` + `fma` enabled and using fused multiply-add for the EMA steps.
///
/// For each index from `first` on it updates the price EMA, the
/// absolute-deviation EMA and the channel-index EMA (`wt1`), then a rolling
/// mean of the last `ma_len` `wt1` values (`wt2`). Only indices at or past
/// `warmup_period` are written; earlier destination entries are left as-is.
///
/// # Safety
/// The running CPU must support AVX2 and FMA.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
#[target_feature(enable = "fma")]
unsafe fn wavetrend_fused_avx2_into(
    data: &[f64],
    channel_len: usize,
    average_len: usize,
    ma_len: usize,
    factor: f64,
    first: usize,
    warmup_period: usize,
    dst_wt1: &mut [f64],
    dst_wt2: &mut [f64],
    dst_wt_diff: &mut [f64],
) {
    let total = data.len();
    if total == 0 {
        return;
    }

    let alpha_ch = 2.0 / (channel_len as f64 + 1.0);
    let beta_ch = 1.0 - alpha_ch;
    let alpha_avg = 2.0 / (average_len as f64 + 1.0);
    let beta_avg = 1.0 - alpha_avg;
    let inv_ma = 1.0 / (ma_len as f64);

    // `None` until the corresponding EMA has been seeded by its first sample.
    let mut esa: Option<f64> = None;
    let mut de: Option<f64> = None;
    let mut wt1_ema: Option<f64> = None;

    // Ring of the last `ma_len` wt1 values; `None` marks a slot whose sample
    // was unavailable and therefore does not contribute to the sum.
    let mut window: Vec<Option<f64>> = vec![None; ma_len];
    let mut slot = 0usize;
    let mut window_sum = 0.0f64;
    let mut window_fill = 0usize;

    for idx in first..total {
        let price = data[idx];
        let mut wt1_now = f64::NAN;

        if price.is_finite() {
            let esa_now = match esa {
                Some(prev) => price.mul_add(alpha_ch, beta_ch * prev),
                None => price,
            };
            esa = Some(esa_now);

            let spread = (price - esa_now).abs();
            let de_now = match de {
                Some(prev) => spread.mul_add(alpha_ch, beta_ch * prev),
                None => spread,
            };
            de = Some(de_now);

            let den = factor * de_now;
            if den != 0.0 && den.is_finite() && esa_now.is_finite() {
                let ci = (price - esa_now) / den;
                if ci.is_finite() {
                    let next = match wt1_ema {
                        Some(prev) => ci.mul_add(alpha_avg, beta_avg * prev),
                        None => ci,
                    };
                    wt1_ema = Some(next);
                    wt1_now = next;
                }
            }
        }

        // Evict whatever occupied this slot, then store the new sample.
        if let Some(evicted) = window[slot] {
            window_sum -= evicted;
            window_fill -= 1;
        }
        window[slot] = if wt1_now.is_finite() {
            window_sum += wt1_now;
            window_fill += 1;
            Some(wt1_now)
        } else {
            None
        };
        slot += 1;
        if slot == ma_len {
            slot = 0;
        }

        let wt2_now = if window_fill == ma_len {
            window_sum * inv_ma
        } else {
            f64::NAN
        };

        if idx >= warmup_period {
            dst_wt1[idx] = wt1_now;
            dst_wt2[idx] = wt2_now;
            dst_wt_diff[idx] = if wt1_now.is_finite() && wt2_now.is_finite() {
                wt2_now - wt1_now
            } else {
                f64::NAN
            };
        }
    }
}
529
/// AVX512 entry point: routes to the "short" variant for `channel_len <= 32`
/// and to the "long" variant otherwise.
///
/// # Safety
/// The selected variant runs the AVX512 kernel; the caller must ensure the
/// CPU supports the features that kernel's helpers are compiled with.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn wavetrend_avx512(
    data: &[f64],
    channel_len: usize,
    average_len: usize,
    ma_len: usize,
    factor: f64,
    first: usize,
) -> Result<WavetrendOutput, WavetrendError> {
    match channel_len {
        0..=32 => wavetrend_avx512_short(data, channel_len, average_len, ma_len, factor, first),
        _ => wavetrend_avx512_long(data, channel_len, average_len, ma_len, factor, first),
    }
}
546
/// AVX512 variant chosen for short channel lengths; forwards to
/// `wavetrend_kernel_dispatch` with `Kernel::Avx512`.
///
/// # Safety
/// The caller must ensure the CPU supports the AVX512 features used by the
/// dispatched kernel.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn wavetrend_avx512_short(
    data: &[f64],
    channel_len: usize,
    average_len: usize,
    ma_len: usize,
    factor: f64,
    first: usize,
) -> Result<WavetrendOutput, WavetrendError> {
    let kernel = Kernel::Avx512;
    wavetrend_kernel_dispatch(data, channel_len, average_len, ma_len, factor, first, kernel)
}
567
/// AVX512 variant chosen for long channel lengths; forwards to
/// `wavetrend_kernel_dispatch` with `Kernel::Avx512`.
///
/// # Safety
/// The caller must ensure the CPU supports the AVX512 features used by the
/// dispatched kernel.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn wavetrend_avx512_long(
    data: &[f64],
    channel_len: usize,
    average_len: usize,
    ma_len: usize,
    factor: f64,
    first: usize,
) -> Result<WavetrendOutput, WavetrendError> {
    let kernel = Kernel::Avx512;
    wavetrend_kernel_dispatch(data, channel_len, average_len, ma_len, factor, first, kernel)
}
588
589#[inline(always)]
590fn wavetrend_prepare<'a>(
591 input: &'a WavetrendInput,
592) -> Result<(&'a [f64], usize, usize, usize, f64, usize, usize), WavetrendError> {
593 let data: &[f64] = input.as_ref();
594 if data.is_empty() {
595 return Err(WavetrendError::EmptyInputData);
596 }
597
598 let first = data
599 .iter()
600 .position(|x| !x.is_nan())
601 .ok_or(WavetrendError::AllValuesNaN)?;
602 let channel_len = input.get_channel_length();
603 let average_len = input.get_average_length();
604 let ma_len = input.get_ma_length();
605 let factor = input.get_factor();
606
607 if channel_len == 0 || channel_len > data.len() {
608 return Err(WavetrendError::InvalidChannelLen {
609 channel_length: channel_len,
610 data_len: data.len(),
611 });
612 }
613 if average_len == 0 || average_len > data.len() {
614 return Err(WavetrendError::InvalidAverageLen {
615 average_length: average_len,
616 data_len: data.len(),
617 });
618 }
619 if ma_len == 0 || ma_len > data.len() {
620 return Err(WavetrendError::InvalidMaLen {
621 ma_length: ma_len,
622 data_len: data.len(),
623 });
624 }
625
626 let max_period = channel_len.max(average_len).max(ma_len);
627 if data.len() - first < max_period {
628 return Err(WavetrendError::NotEnoughValidData {
629 needed: max_period,
630 valid: data.len() - first,
631 });
632 }
633
634 let warmup_period = first + channel_len - 1 + average_len - 1 + ma_len - 1;
635
636 Ok((
637 data,
638 channel_len,
639 average_len,
640 ma_len,
641 factor,
642 first,
643 warmup_period,
644 ))
645}
646
647#[inline(always)]
648fn wavetrend_compute_into(
649 data: &[f64],
650 channel_len: usize,
651 average_len: usize,
652 ma_len: usize,
653 factor: f64,
654 first: usize,
655 warmup_period: usize,
656 dst_wt1: &mut [f64],
657 dst_wt2: &mut [f64],
658 dst_wt_diff: &mut [f64],
659 kernel: Kernel,
660) -> Result<(), WavetrendError> {
661 if matches!(kernel.to_non_batch(), Kernel::Scalar) {
662 let n = data.len();
663 if n == 0 {
664 return Ok(());
665 }
666
667 let alpha_ch = 2.0 / (channel_len as f64 + 1.0);
668 let beta_ch = 1.0 - alpha_ch;
669 let alpha_avg = 2.0 / (average_len as f64 + 1.0);
670 let beta_avg = 1.0 - alpha_avg;
671
672 let mut esa_state: f64 = f64::NAN;
673 let mut de_state: f64 = f64::NAN;
674 let mut wt1_state: f64 = f64::NAN;
675 let mut esa_seeded = false;
676 let mut de_seeded = false;
677 let mut wt1_seeded = false;
678
679 let mut ring_vals = vec![f64::NAN; ma_len];
680 let mut ring_mask = vec![0u8; ma_len];
681 let mut head = 0usize;
682 let mut sma_sum = 0.0f64;
683 let mut sma_count = 0usize;
684
685 for idx in first..n {
686 let x = data[idx];
687
688 let mut wt1_i = f64::NAN;
689 let mut wt2_i = f64::NAN;
690
691 if x.is_finite() {
692 if !esa_seeded {
693 esa_state = x;
694 esa_seeded = true;
695 } else {
696 esa_state = alpha_ch * x + beta_ch * esa_state;
697 }
698
699 let abs_diff = (x - esa_state).abs();
700 if !de_seeded {
701 de_state = abs_diff;
702 de_seeded = true;
703 } else {
704 de_state = alpha_ch * abs_diff + beta_ch * de_state;
705 }
706
707 let den = factor * de_state;
708 if den != 0.0 && den.is_finite() && esa_state.is_finite() {
709 let ci = (x - esa_state) / den;
710 if ci.is_finite() {
711 if !wt1_seeded {
712 wt1_state = ci;
713 wt1_seeded = true;
714 } else {
715 wt1_state = alpha_avg * ci + beta_avg * wt1_state;
716 }
717 wt1_i = wt1_state;
718 }
719 }
720 }
721
722 if ma_len > 0 {
723 if ring_mask[head] != 0 {
724 sma_sum -= ring_vals[head];
725 sma_count -= 1;
726 }
727
728 if wt1_i.is_finite() {
729 ring_vals[head] = wt1_i;
730 ring_mask[head] = 1;
731 sma_sum += wt1_i;
732 sma_count += 1;
733 } else {
734 ring_vals[head] = f64::NAN;
735 ring_mask[head] = 0;
736 }
737 head += 1;
738 if head == ma_len {
739 head = 0;
740 }
741
742 if sma_count == ma_len {
743 wt2_i = sma_sum / (ma_len as f64);
744 }
745 }
746
747 if idx >= warmup_period {
748 dst_wt1[idx] = wt1_i;
749 dst_wt2[idx] = wt2_i;
750 dst_wt_diff[idx] = if wt1_i.is_finite() && wt2_i.is_finite() {
751 wt2_i - wt1_i
752 } else {
753 f64::NAN
754 };
755 }
756 }
757
758 return Ok(());
759 }
760
761 let data_valid = &data[first..];
762 let simd_kernel = kernel.to_non_batch();
763
764 if data_valid.len() <= STACK_LIMIT {
765 let mut esa_buf = [0.0f64; STACK_LIMIT];
766 let mut de_buf = [0.0f64; STACK_LIMIT];
767 let mut ci_buf = [0.0f64; STACK_LIMIT];
768 let mut wt1_buf = [0.0f64; STACK_LIMIT];
769 let mut wt2_buf = [0.0f64; STACK_LIMIT];
770
771 let esa = &mut esa_buf[..data_valid.len()];
772 let de = &mut de_buf[..data_valid.len()];
773 let ci = &mut ci_buf[..data_valid.len()];
774 let wt1 = &mut wt1_buf[..data_valid.len()];
775 let wt2 = &mut wt2_buf[..data_valid.len()];
776
777 wavetrend_core_computation(
778 data_valid,
779 channel_len,
780 average_len,
781 ma_len,
782 factor,
783 esa,
784 de,
785 ci,
786 wt1,
787 wt2,
788 simd_kernel,
789 )?;
790
791 for i in 0..data_valid.len() {
792 let out_idx = i + first;
793 if out_idx >= warmup_period {
794 dst_wt1[out_idx] = wt1[i];
795 dst_wt2[out_idx] = wt2[i];
796 if !wt1[i].is_nan() && !wt2[i].is_nan() {
797 dst_wt_diff[out_idx] = wt2[i] - wt1[i];
798 } else {
799 dst_wt_diff[out_idx] = f64::NAN;
800 }
801 }
802 }
803 } else {
804 let mut esa = vec![0.0; data_valid.len()];
805 let mut de = vec![0.0; data_valid.len()];
806 let mut ci = vec![0.0; data_valid.len()];
807 let mut wt1 = vec![0.0; data_valid.len()];
808 let mut wt2 = vec![0.0; data_valid.len()];
809
810 wavetrend_core_computation(
811 data_valid,
812 channel_len,
813 average_len,
814 ma_len,
815 factor,
816 &mut esa,
817 &mut de,
818 &mut ci,
819 &mut wt1,
820 &mut wt2,
821 simd_kernel,
822 )?;
823
824 for i in 0..data_valid.len() {
825 let out_idx = i + first;
826 if out_idx >= warmup_period {
827 dst_wt1[out_idx] = wt1[i];
828 dst_wt2[out_idx] = wt2[i];
829 if !wt1[i].is_nan() && !wt2[i].is_nan() {
830 dst_wt_diff[out_idx] = wt2[i] - wt1[i];
831 } else {
832 dst_wt_diff[out_idx] = f64::NAN;
833 }
834 }
835 }
836 }
837
838 Ok(())
839}
840
/// Largest series length (in `f64` elements) for which the staged Wavetrend
/// pipeline keeps its intermediate buffers in fixed-size stack arrays; longer
/// inputs use heap-allocated vectors instead.
const STACK_LIMIT: usize = 512;
842
843#[inline(always)]
844fn wavetrend_core_computation(
845 data: &[f64],
846 channel_len: usize,
847 average_len: usize,
848 ma_len: usize,
849 factor: f64,
850 esa: &mut [f64],
851 de: &mut [f64],
852 ci: &mut [f64],
853 wt1: &mut [f64],
854 wt2: &mut [f64],
855 kernel: Kernel,
856) -> Result<(), WavetrendError> {
857 ema_compute_into(data, channel_len, esa);
858
859 if data.len() <= STACK_LIMIT {
860 let mut abs_diff_buf = [0.0f64; STACK_LIMIT];
861 let abs_diff = &mut abs_diff_buf[..data.len()];
862 compute_abs_diff(abs_diff, data, esa, kernel);
863 ema_compute_into(abs_diff, channel_len, de);
864 } else {
865 let mut abs_diff = vec![0.0; data.len()];
866 compute_abs_diff(&mut abs_diff, data, esa, kernel);
867 ema_compute_into(&abs_diff, channel_len, de);
868 }
869
870 compute_ci(ci, data, esa, de, factor, kernel);
871
872 ema_compute_into(ci, average_len, wt1);
873
874 sma_compute_into(wt1, ma_len, wt2);
875
876 Ok(())
877}
878
879#[inline(always)]
880fn compute_abs_diff(out: &mut [f64], data: &[f64], esa: &[f64], kernel: Kernel) {
881 let simd = kernel.to_non_batch();
882 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
883 {
884 match simd {
885 Kernel::Avx512 => unsafe {
886 absdiff_vec_avx512(out, data, esa);
887 return;
888 },
889 Kernel::Avx2 => unsafe {
890 absdiff_vec_avx2(out, data, esa);
891 return;
892 },
893 _ => {}
894 }
895 }
896
897 for i in 0..out.len() {
898 out[i] = (data[i] - esa[i]).abs();
899 }
900}
901
902#[inline(always)]
903fn compute_ci(out: &mut [f64], data: &[f64], esa: &[f64], de: &[f64], factor: f64, kernel: Kernel) {
904 let simd = kernel.to_non_batch();
905 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
906 {
907 match simd {
908 Kernel::Avx512 => unsafe {
909 ci_vec_avx512(out, data, esa, de, factor);
910 return;
911 },
912 Kernel::Avx2 => unsafe {
913 ci_vec_avx2(out, data, esa, de, factor);
914 return;
915 },
916 _ => {}
917 }
918 }
919
920 for i in 0..out.len() {
921 let den = factor * de[i];
922 if den != 0.0 && !data[i].is_nan() && !esa[i].is_nan() && !de[i].is_nan() {
923 out[i] = (data[i] - esa[i]) / den;
924 } else {
925 out[i] = f64::NAN;
926 }
927 }
928}
929
/// AVX2 implementation of `dst[i] = |a[i] - b[i]|`, four lanes at a time with
/// a scalar loop for the remaining elements.
///
/// # Safety
/// The CPU must support AVX2, and `a` and `b` must each hold at least
/// `dst.len()` elements (they are read through raw pointers without bounds
/// checks).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn absdiff_vec_avx2(dst: &mut [f64], a: &[f64], b: &[f64]) {
    const LANES: usize = 4;

    let len = dst.len();
    let (lhs, rhs, out) = (a.as_ptr(), b.as_ptr(), dst.as_mut_ptr());
    // Only the sign bit is set in -0.0; and-not with it clears the sign.
    let sign_bit = _mm256_set1_pd(-0.0f64);

    let vec_end = len - len % LANES;
    let mut off = 0usize;
    while off < vec_end {
        let left = _mm256_loadu_pd(lhs.add(off));
        let right = _mm256_loadu_pd(rhs.add(off));
        let delta = _mm256_sub_pd(left, right);
        _mm256_storeu_pd(out.add(off), _mm256_andnot_pd(sign_bit, delta));
        off += LANES;
    }

    for k in vec_end..len {
        *out.add(k) = (*lhs.add(k) - *rhs.add(k)).abs();
    }
}
952
/// AVX2 implementation of the channel index
/// `dst[i] = (data[i] - esa[i]) / (factor * de[i])`, four lanes at a time with
/// a scalar loop for the remaining elements. Lanes whose denominator is zero
/// or whose inputs contain NaN receive `NaN`.
///
/// NOTE(review): the vector body only rejects NaN inputs, while the scalar
/// tail requires all three inputs to be finite, so an infinite input may be
/// treated differently depending on its position — confirm which is intended.
///
/// # Safety
/// The CPU must support AVX2, and `data`, `esa` and `de` must each hold at
/// least `dst.len()` elements (they are read through raw pointers without
/// bounds checks).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn ci_vec_avx2(dst: &mut [f64], data: &[f64], esa: &[f64], de: &[f64], factor: f64) {
    const LANES: usize = 4;

    let len = dst.len();
    let (src, avg, dev, out) = (data.as_ptr(), esa.as_ptr(), de.as_ptr(), dst.as_mut_ptr());

    let factor_v = _mm256_set1_pd(factor);
    let zero_v = _mm256_set1_pd(0.0);
    let nan_v = _mm256_set1_pd(f64::NAN);

    let vec_end = len - len % LANES;
    let mut off = 0usize;
    while off < vec_end {
        let x = _mm256_loadu_pd(src.add(off));
        let e = _mm256_loadu_pd(avg.add(off));
        let d = _mm256_loadu_pd(dev.add(off));

        let numer = _mm256_sub_pd(x, e);
        let denom = _mm256_mul_pd(factor_v, d);
        let quotient = _mm256_div_pd(numer, denom);

        // A lane compares "ordered" with itself exactly when it is not NaN.
        let x_ok = _mm256_cmp_pd(x, x, _CMP_ORD_Q);
        let e_ok = _mm256_cmp_pd(e, e, _CMP_ORD_Q);
        let d_ok = _mm256_cmp_pd(d, d, _CMP_ORD_Q);
        let no_nan = _mm256_and_pd(x_ok, _mm256_and_pd(e_ok, d_ok));
        let zero_denom = _mm256_cmp_pd(denom, zero_v, _CMP_EQ_OQ);
        let keep = _mm256_andnot_pd(zero_denom, no_nan);

        _mm256_storeu_pd(out.add(off), _mm256_blendv_pd(nan_v, quotient, keep));
        off += LANES;
    }

    for k in vec_end..len {
        let (x, e, d) = (*src.add(k), *avg.add(k), *dev.add(k));
        let denom = factor * d;
        let usable = denom != 0.0 && x.is_finite() && e.is_finite() && d.is_finite();
        *out.add(k) = if usable { (x - e) / denom } else { f64::NAN };
    }
}
1000
/// AVX512 implementation of `dst[i] = |a[i] - b[i]|`, eight lanes at a time
/// with a scalar loop for the remaining elements.
///
/// The sign bit is cleared with the integer and-not (`_mm512_andnot_epi64`) on
/// the bit-cast lanes. The floating-point form `_mm512_andnot_pd` belongs to
/// AVX512DQ, which this function does not enable; the integer form is part of
/// AVX512F and produces the same bits.
///
/// # Safety
/// The CPU must support AVX512F, and `a` and `b` must each hold at least
/// `dst.len()` elements (they are read through raw pointers without bounds
/// checks).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
unsafe fn absdiff_vec_avx512(dst: &mut [f64], a: &[f64], b: &[f64]) {
    let n = dst.len();
    let pa = a.as_ptr();
    let pb = b.as_ptr();
    let pd = dst.as_mut_ptr();
    // Mask with only the IEEE-754 sign bit set in every 64-bit lane.
    let sign = _mm512_set1_epi64(0x8000_0000_0000_0000u64 as i64);

    let mut i = 0usize;
    while i + 8 <= n {
        let va = _mm512_loadu_pd(pa.add(i));
        let vb = _mm512_loadu_pd(pb.add(i));
        let vd = _mm512_sub_pd(va, vb);
        // (!sign) & bits(vd): clears the sign bit, i.e. absolute value.
        let vabs = _mm512_castsi512_pd(_mm512_andnot_epi64(sign, _mm512_castpd_si512(vd)));
        _mm512_storeu_pd(pd.add(i), vabs);
        i += 8;
    }
    while i < n {
        *pd.add(i) = (*pa.add(i) - *pb.add(i)).abs();
        i += 1;
    }
}
1025
/// AVX512 implementation of the channel index
/// `dst[i] = (data[i] - esa[i]) / (factor * de[i])`, eight lanes at a time
/// with a scalar loop for the remaining elements. Lanes whose denominator is
/// zero or whose inputs contain NaN receive `NaN`.
///
/// NOTE(review): the vector body only rejects NaN inputs, while the scalar
/// tail requires all three inputs to be finite, so an infinite input may be
/// treated differently depending on its position — confirm which is intended.
///
/// # Safety
/// The CPU must support AVX512F, and `data`, `esa` and `de` must each hold at
/// least `dst.len()` elements (they are read through raw pointers without
/// bounds checks).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
unsafe fn ci_vec_avx512(dst: &mut [f64], data: &[f64], esa: &[f64], de: &[f64], factor: f64) {
    const LANES: usize = 8;

    let len = dst.len();
    let (src, avg, dev, out) = (data.as_ptr(), esa.as_ptr(), de.as_ptr(), dst.as_mut_ptr());

    let factor_v = _mm512_set1_pd(factor);
    let zero_v = _mm512_set1_pd(0.0);
    let nan_v = _mm512_set1_pd(f64::NAN);

    let vec_end = len - len % LANES;
    let mut off = 0usize;
    while off < vec_end {
        let x = _mm512_loadu_pd(src.add(off));
        let e = _mm512_loadu_pd(avg.add(off));
        let d = _mm512_loadu_pd(dev.add(off));

        let numer = _mm512_sub_pd(x, e);
        let denom = _mm512_mul_pd(factor_v, d);
        let quotient = _mm512_div_pd(numer, denom);

        // A lane compares "ordered" with itself exactly when it is not NaN.
        let x_ok = _mm512_cmp_pd_mask(x, x, _CMP_ORD_Q);
        let e_ok = _mm512_cmp_pd_mask(e, e, _CMP_ORD_Q);
        let d_ok = _mm512_cmp_pd_mask(d, d, _CMP_ORD_Q);
        let zero_denom = _mm512_cmp_pd_mask(denom, zero_v, _CMP_EQ_OQ);
        let keep = (x_ok & e_ok & d_ok) & (!zero_denom);

        // Lanes selected by `keep` take the quotient, the rest stay NaN.
        _mm512_storeu_pd(out.add(off), _mm512_mask_mov_pd(nan_v, keep, quotient));
        off += LANES;
    }

    for k in vec_end..len {
        let (x, e, d) = (*src.add(k), *avg.add(k), *dev.add(k));
        let denom = factor * d;
        let usable = denom != 0.0 && x.is_finite() && e.is_finite() && d.is_finite();
        *out.add(k) = if usable { (x - e) / denom } else { f64::NAN };
    }
}
1073
/// Exponential moving average of `data` written into `out`.
///
/// Uses `alpha = 2 / (period + 1)`. The average is seeded with the first
/// non-NaN sample; NaN samples produce NaN in `out` and leave the running
/// value unchanged. Does nothing when `period` is zero or `data` is empty.
/// `out` must be at least as long as `data`.
#[inline(always)]
fn ema_compute_into(data: &[f64], period: usize, out: &mut [f64]) {
    if data.is_empty() || period == 0 {
        return;
    }

    let alpha = 2.0 / (period as f64 + 1.0);
    let beta = 1.0 - alpha;

    // NaN doubles as the "not yet seeded" marker for the running average.
    let mut state = f64::NAN;
    for (i, &sample) in data.iter().enumerate() {
        out[i] = if sample.is_nan() {
            f64::NAN
        } else {
            state = if state.is_nan() {
                sample
            } else {
                alpha * sample + beta * state
            };
            state
        };
    }
}
1097
/// Simple moving average of `data` over `period` samples, written into `out`.
///
/// `out[i]` is the mean of `data[i + 1 - period..=i]` when every value in
/// that window is non-NaN, and NaN otherwise. Does nothing when `period` is
/// zero or `data` is empty. `out` must be at least as long as `data`.
///
/// The sample leaving the window is evicted on every step, even when the
/// current sample is NaN. Evicting only on non-NaN samples would leave stale
/// values in the running sum after a NaN gap and yield wrong averages.
#[inline(always)]
fn sma_compute_into(data: &[f64], period: usize, out: &mut [f64]) {
    if period == 0 || data.is_empty() {
        return;
    }

    for slot in out.iter_mut() {
        *slot = f64::NAN;
    }

    // Running sum and count of the non-NaN values currently inside the window.
    let mut sum = 0.0f64;
    let mut count = 0usize;

    for (i, &sample) in data.iter().enumerate() {
        if !sample.is_nan() {
            sum += sample;
            count += 1;
        }

        // Drop the value that just slid out of the window.
        if i >= period {
            let leaving = data[i - period];
            if !leaving.is_nan() {
                sum -= leaving;
                count -= 1;
            }
        }

        // A full count means the whole window, including `sample`, is valid.
        if count == period {
            out[i] = sum / period as f64;
        }
    }
}
1129
1130#[inline]
1131pub fn wavetrend_into_slice(
1132 dst_wt1: &mut [f64],
1133 dst_wt2: &mut [f64],
1134 dst_wt_diff: &mut [f64],
1135 input: &WavetrendInput,
1136 kern: Kernel,
1137) -> Result<(), WavetrendError> {
1138 let (data, channel_len, average_len, ma_len, factor, first, warmup_period) =
1139 wavetrend_prepare(input)?;
1140
1141 if dst_wt1.len() != data.len() {
1142 return Err(WavetrendError::OutputLengthMismatch {
1143 expected: data.len(),
1144 got: dst_wt1.len(),
1145 });
1146 }
1147 if dst_wt2.len() != data.len() {
1148 return Err(WavetrendError::OutputLengthMismatch {
1149 expected: data.len(),
1150 got: dst_wt2.len(),
1151 });
1152 }
1153 if dst_wt_diff.len() != data.len() {
1154 return Err(WavetrendError::OutputLengthMismatch {
1155 expected: data.len(),
1156 got: dst_wt_diff.len(),
1157 });
1158 }
1159
1160 for i in 0..warmup_period.min(data.len()) {
1161 dst_wt1[i] = f64::NAN;
1162 dst_wt2[i] = f64::NAN;
1163 dst_wt_diff[i] = f64::NAN;
1164 }
1165
1166 let chosen = match kern {
1167 Kernel::Auto => detect_best_kernel(),
1168 Kernel::ScalarBatch => Kernel::Scalar,
1169 Kernel::Avx2Batch => Kernel::Avx2,
1170 Kernel::Avx512Batch => Kernel::Avx512,
1171 other => other,
1172 };
1173
1174 wavetrend_compute_into(
1175 data,
1176 channel_len,
1177 average_len,
1178 ma_len,
1179 factor,
1180 first,
1181 warmup_period,
1182 dst_wt1,
1183 dst_wt2,
1184 dst_wt_diff,
1185 chosen,
1186 )?;
1187
1188 Ok(())
1189}
1190
/// Incremental (one price at a time) Wavetrend calculator.
///
/// Created with `WavetrendStream::try_new`, which resolves unset parameters
/// to 9 / 12 / 3 / 0.015 and rejects zero lengths.
#[derive(Clone, Debug)]
pub struct WavetrendStream {
    /// Resolved channel length.
    pub channel_length: usize,
    /// Resolved average length.
    pub average_length: usize,
    /// Resolved moving-average length.
    pub ma_length: usize,
    /// Resolved factor used in the channel-index denominator.
    pub factor: f64,

    // Allocated with capacity `channel_length`; its use is outside the part of
    // the file reviewed here.
    esa_buf: VecDeque<f64>,
    // Latest price EMA; `None` until the first finite price arrives.
    last_esa: Option<f64>,
    // 2 / (channel_length + 1).
    alpha_ch: f64,

    // 1 - alpha_ch.
    beta_ch: f64,

    // Allocated with capacity `channel_length`; use not visible here.
    de_buf: VecDeque<f64>,
    // Latest EMA of |price - esa|; `None` until seeded.
    last_de: Option<f64>,

    // Allocated with capacity `average_length`; use not visible here.
    ci_buf: VecDeque<f64>,
    // Latest EMA of the channel index (wt1); `None` until seeded.
    last_wt1: Option<f64>,
    // 2 / (average_length + 1).
    alpha_avg: f64,

    // 1 - alpha_avg.
    beta_avg: f64,

    // Allocated with capacity `ma_length`; presumably the wt1 window behind
    // wt2 — TODO confirm against the rest of `update`.
    wt1_buf: VecDeque<f64>,
    // Starts at 0.0; presumably the running sum over `wt1_buf` — TODO confirm.
    running_sum: f64,

    // Starts at 0; presumably the number of valid entries in the window —
    // TODO confirm.
    sma_count: usize,

    // 1 / ma_length.
    inv_ma: f64,

    /// Every price passed to `update`, in arrival order.
    pub history: Vec<f64>,
}
1222
1223impl WavetrendStream {
1224 pub fn try_new(p: WavetrendParams) -> Result<Self, WavetrendError> {
1225 let channel_length = p.channel_length.unwrap_or(9);
1226 let average_length = p.average_length.unwrap_or(12);
1227 let ma_length = p.ma_length.unwrap_or(3);
1228 let factor = p.factor.unwrap_or(0.015);
1229
1230 if channel_length == 0 {
1231 return Err(WavetrendError::InvalidChannelLen {
1232 channel_length,
1233 data_len: 0,
1234 });
1235 }
1236 if average_length == 0 {
1237 return Err(WavetrendError::InvalidAverageLen {
1238 average_length,
1239 data_len: 0,
1240 });
1241 }
1242 if ma_length == 0 {
1243 return Err(WavetrendError::InvalidMaLen {
1244 ma_length,
1245 data_len: 0,
1246 });
1247 }
1248
1249 let alpha_ch = 2.0 / (channel_length as f64 + 1.0);
1250 let alpha_avg = 2.0 / (average_length as f64 + 1.0);
1251
1252 Ok(Self {
1253 channel_length,
1254 average_length,
1255 ma_length,
1256 factor,
1257
1258 esa_buf: VecDeque::with_capacity(channel_length),
1259 last_esa: None,
1260 alpha_ch,
1261 beta_ch: 1.0 - alpha_ch,
1262
1263 de_buf: VecDeque::with_capacity(channel_length),
1264 last_de: None,
1265
1266 ci_buf: VecDeque::with_capacity(average_length),
1267 last_wt1: None,
1268 alpha_avg,
1269 beta_avg: 1.0 - alpha_avg,
1270
1271 wt1_buf: VecDeque::with_capacity(ma_length),
1272 running_sum: 0.0,
1273 sma_count: 0,
1274 inv_ma: 1.0 / (ma_length as f64),
1275
1276 history: Vec::new(),
1277 })
1278 }
1279
1280 #[inline(always)]
1281 pub fn update(&mut self, price: f64) -> Option<(f64, f64, f64)> {
1282 self.history.push(price);
1283
1284 let mut wt1_val = f64::NAN;
1285
1286 if price.is_finite() {
1287 if let Some(prev) = self.last_esa {
1288 let new_esa = ema_step(prev, price, self.alpha_ch, self.beta_ch);
1289 self.last_esa = Some(new_esa);
1290 } else {
1291 self.last_esa = Some(price);
1292 }
1293
1294 if let Some(esa_now) = self.last_esa {
1295 let abs_diff = fast_abs_f64(price - esa_now);
1296 if let Some(prev_de) = self.last_de {
1297 let new_de = ema_step(prev_de, abs_diff, self.alpha_ch, self.beta_ch);
1298 self.last_de = Some(new_de);
1299 } else {
1300 self.last_de = Some(abs_diff);
1301 }
1302 }
1303
1304 if let (Some(esa_now), Some(de_now)) = (self.last_esa, self.last_de) {
1305 let den = self.factor * de_now;
1306 if den != 0.0 && den.is_finite() && esa_now.is_finite() {
1307 let ci = (price - esa_now) / den;
1308 if ci.is_finite() {
1309 if let Some(prev_wt1) = self.last_wt1 {
1310 let new_wt1 = ema_step(prev_wt1, ci, self.alpha_avg, self.beta_avg);
1311 self.last_wt1 = Some(new_wt1);
1312 } else {
1313 self.last_wt1 = Some(ci);
1314 }
1315 if let Some(v) = self.last_wt1 {
1316 wt1_val = v;
1317 }
1318 }
1319 }
1320 }
1321 }
1322
1323 if self.wt1_buf.len() == self.ma_length {
1324 if let Some(leaving) = self.wt1_buf.pop_front() {
1325 if leaving.is_finite() {
1326 self.running_sum -= leaving;
1327 if self.sma_count > 0 {
1328 self.sma_count -= 1;
1329 }
1330 }
1331 }
1332 }
1333
1334 self.wt1_buf.push_back(wt1_val);
1335 if wt1_val.is_finite() {
1336 self.running_sum += wt1_val;
1337 self.sma_count += 1;
1338 }
1339
1340 if self.wt1_buf.len() == self.ma_length && self.sma_count == self.ma_length {
1341 let wt1 = wt1_val;
1342 let wt2 = self.running_sum * self.inv_ma;
1343 let diff = wt2 - wt1;
1344 Some((wt1, wt2, diff))
1345 } else {
1346 None
1347 }
1348 }
1349}
1350
/// One exponential-moving-average step: `alpha * x + beta * prev`.
///
/// The `alpha * x` product and the final addition are fused into a single
/// `mul_add`, so only `beta * prev` is rounded separately.
#[inline(always)]
fn ema_step(prev: f64, x: f64, alpha: f64, beta: f64) -> f64 {
    let carried = beta * prev;
    f64::mul_add(x, alpha, carried)
}
1355
/// Absolute value of `x`, obtained by clearing the IEEE-754 sign bit.
///
/// Every other bit is left untouched, so NaN payloads and infinities pass
/// through unchanged apart from their sign.
#[inline(always)]
fn fast_abs_f64(x: f64) -> f64 {
    const SIGN_BIT: u64 = 1 << 63;
    let magnitude_bits = x.to_bits() & !SIGN_BIT;
    f64::from_bits(magnitude_bits)
}
1360
1361#[cfg(all(feature = "python", feature = "cuda"))]
1362use cust::context::Context;
1363#[cfg(all(feature = "python", feature = "cuda"))]
1364use std::sync::Arc;
1365
/// Python-visible handle to a 2-D `f32` array living in CUDA device memory.
///
/// Exposed to Python as `ta_indicators.cuda.WavetrendDeviceArrayF32` and
/// declared `unsendable`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(
    module = "ta_indicators.cuda",
    name = "WavetrendDeviceArrayF32",
    unsendable
)]
pub struct WavetrendDeviceArrayF32Py {
    /// Device buffer together with its `rows` x `cols` shape.
    pub(crate) inner: DeviceArrayF32,
    /// Shared CUDA context handle.
    // NOTE(review): presumably held only to keep the context alive while `inner` exists — confirm.
    pub(crate) _ctx: Arc<Context>,
    /// Device ordinal reported through `__dlpack_device__`.
    pub(crate) device_id: u32,
}
1377
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl WavetrendDeviceArrayF32Py {
    /// Builds the `__cuda_array_interface__` (version 3) dictionary: shape
    /// `(rows, cols)`, little-endian `f32`, row-major byte strides and a
    /// writable device pointer.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let elem_size = std::mem::size_of::<f32>();
        let shape = (self.inner.rows, self.inner.cols);
        let strides = (self.inner.cols * elem_size, elem_size);
        // The second tuple element is the interface's read-only flag.
        let data = (self.inner.device_ptr() as usize, false);

        let iface = PyDict::new(py);
        iface.set_item("shape", shape)?;
        iface.set_item("typestr", "<f4")?;
        iface.set_item("strides", strides)?;
        iface.set_item("data", data)?;

        iface.set_item("version", 3)?;
        Ok(iface)
    }

    /// Returns the DLPack device tuple `(2, device_id)`.
    fn __dlpack_device__(&self) -> (i32, i32) {
        let device_type = 2;
        (device_type, self.device_id as i32)
    }

    /// Exports the array as a DLPack capsule, transferring the buffer out of
    /// this object.
    ///
    /// After a successful call `self` holds an empty `0 x 0` placeholder.
    /// `stream` is accepted but ignored. `dl_device` is only checked when it
    /// can be read as an `(i32, i32)` tuple.
    ///
    /// # Errors
    /// Raises `ValueError` when `dl_device` names a different device (the
    /// message depends on whether `copy` was truthy), or when the placeholder
    /// buffer cannot be allocated.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        let (kdl, alloc_dev) = self.__dlpack_device__();

        // A `dl_device` that is absent or not an (i32, i32) tuple is ignored.
        let requested = dl_device
            .as_ref()
            .and_then(|dev_obj| dev_obj.extract::<(i32, i32)>(py).ok());
        if let Some((dev_ty, dev_id)) = requested {
            if dev_ty != kdl || dev_id != alloc_dev {
                let wants_copy = copy
                    .as_ref()
                    .and_then(|c| c.extract::<bool>(py).ok())
                    .unwrap_or(false);
                let message = if wants_copy {
                    "device copy not implemented for __dlpack__"
                } else {
                    "dl_device mismatch for __dlpack__"
                };
                return Err(PyValueError::new_err(message));
            }
        }
        let _ = stream;

        // Swap an empty buffer in so the real one can be moved into the capsule.
        let dummy = cust::memory::DeviceBuffer::from_slice(&[])
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        let placeholder = DeviceArrayF32 {
            buf: dummy,
            rows: 0,
            cols: 0,
        };
        let taken = std::mem::replace(&mut self.inner, placeholder);

        let rows = taken.rows;
        let cols = taken.cols;
        let buf = taken.buf;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
1452
/// Parameter sweep for a batch run; every axis is a `(start, end, step)` triple.
///
/// A zero step (or `start == end`) collapses the axis to the single value
/// `start` when the grid is expanded.
#[derive(Clone, Debug)]
pub struct WavetrendBatchRange {
    /// Sweep over the channel length.
    pub channel_length: (usize, usize, usize),
    /// Sweep over the average length.
    pub average_length: (usize, usize, usize),
    /// Sweep over the moving-average length.
    pub ma_length: (usize, usize, usize),
    /// Sweep over the scaling factor.
    pub factor: (f64, f64, f64),
}
1460
1461impl Default for WavetrendBatchRange {
1462 fn default() -> Self {
1463 Self {
1464 channel_length: (9, 9, 0),
1465 average_length: (12, 261, 1),
1466 ma_length: (3, 3, 0),
1467 factor: (0.015, 0.015, 0.0),
1468 }
1469 }
1470}
1471
1472#[derive(Clone, Debug, Default)]
1473pub struct WavetrendBatchBuilder {
1474 range: WavetrendBatchRange,
1475 kernel: Kernel,
1476}
1477
1478impl WavetrendBatchBuilder {
1479 pub fn new() -> Self {
1480 Self::default()
1481 }
1482 pub fn kernel(mut self, k: Kernel) -> Self {
1483 self.kernel = k;
1484 self
1485 }
1486 pub fn channel_range(mut self, start: usize, end: usize, step: usize) -> Self {
1487 self.range.channel_length = (start, end, step);
1488 self
1489 }
1490 pub fn channel_static(mut self, x: usize) -> Self {
1491 self.range.channel_length = (x, x, 0);
1492 self
1493 }
1494 pub fn avg_range(mut self, start: usize, end: usize, step: usize) -> Self {
1495 self.range.average_length = (start, end, step);
1496 self
1497 }
1498 pub fn avg_static(mut self, x: usize) -> Self {
1499 self.range.average_length = (x, x, 0);
1500 self
1501 }
1502 pub fn ma_range(mut self, start: usize, end: usize, step: usize) -> Self {
1503 self.range.ma_length = (start, end, step);
1504 self
1505 }
1506 pub fn ma_static(mut self, x: usize) -> Self {
1507 self.range.ma_length = (x, x, 0);
1508 self
1509 }
1510 pub fn factor_range(mut self, start: f64, end: f64, step: f64) -> Self {
1511 self.range.factor = (start, end, step);
1512 self
1513 }
1514 pub fn factor_static(mut self, x: f64) -> Self {
1515 self.range.factor = (x, x, 0.0);
1516 self
1517 }
1518 pub fn apply_slice(self, data: &[f64]) -> Result<WavetrendBatchOutput, WavetrendError> {
1519 wavetrend_batch_with_kernel(data, &self.range, self.kernel)
1520 }
1521 pub fn with_default_slice(
1522 data: &[f64],
1523 k: Kernel,
1524 ) -> Result<WavetrendBatchOutput, WavetrendError> {
1525 WavetrendBatchBuilder::new().kernel(k).apply_slice(data)
1526 }
1527 pub fn apply_candles(
1528 self,
1529 c: &Candles,
1530 src: &str,
1531 ) -> Result<WavetrendBatchOutput, WavetrendError> {
1532 let slice = source_type(c, src);
1533 self.apply_slice(slice)
1534 }
1535 pub fn with_default_candles(c: &Candles) -> Result<WavetrendBatchOutput, WavetrendError> {
1536 WavetrendBatchBuilder::new()
1537 .kernel(Kernel::Auto)
1538 .apply_candles(c, "hlc3")
1539 }
1540}
1541
1542pub fn wavetrend_batch_with_kernel(
1543 data: &[f64],
1544 sweep: &WavetrendBatchRange,
1545 k: Kernel,
1546) -> Result<WavetrendBatchOutput, WavetrendError> {
1547 let kernel = match k {
1548 Kernel::Auto => Kernel::ScalarBatch,
1549
1550 Kernel::Avx2Batch | Kernel::Avx512Batch => Kernel::ScalarBatch,
1551 other if other.is_batch() => other,
1552 _ => {
1553 return Err(WavetrendError::InvalidKernelForBatch(k));
1554 }
1555 };
1556 let simd = match kernel {
1557 Kernel::Avx512Batch => Kernel::Avx512,
1558 Kernel::Avx2Batch => Kernel::Avx2,
1559 Kernel::ScalarBatch => Kernel::Scalar,
1560 _ => unreachable!(),
1561 };
1562 wavetrend_batch_par_slice(data, sweep, simd)
1563}
1564
1565#[derive(Clone, Debug)]
1566pub struct WavetrendBatchOutput {
1567 pub wt1: Vec<f64>,
1568 pub wt2: Vec<f64>,
1569 pub wt_diff: Vec<f64>,
1570 pub combos: Vec<WavetrendParams>,
1571 pub rows: usize,
1572 pub cols: usize,
1573}
1574impl WavetrendBatchOutput {
1575 pub fn row_for_params(&self, p: &WavetrendParams) -> Option<usize> {
1576 self.combos.iter().position(|c| {
1577 c.channel_length.unwrap_or(9) == p.channel_length.unwrap_or(9)
1578 && c.average_length.unwrap_or(12) == p.average_length.unwrap_or(12)
1579 && c.ma_length.unwrap_or(3) == p.ma_length.unwrap_or(3)
1580 && (c.factor.unwrap_or(0.015) - p.factor.unwrap_or(0.015)).abs() < 1e-12
1581 })
1582 }
1583 pub fn values_for(&self, p: &WavetrendParams) -> Option<(&[f64], &[f64], &[f64])> {
1584 self.row_for_params(p).map(|row| {
1585 let start = row * self.cols;
1586 (
1587 &self.wt1[start..start + self.cols],
1588 &self.wt2[start..start + self.cols],
1589 &self.wt_diff[start..start + self.cols],
1590 )
1591 })
1592 }
1593}
1594
1595#[inline(always)]
1596fn expand_grid(r: &WavetrendBatchRange) -> Result<Vec<WavetrendParams>, WavetrendError> {
1597 fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, WavetrendError> {
1598 if step == 0 || start == end {
1599 return Ok(vec![start]);
1600 }
1601 if start < end {
1602 let st = step.max(1);
1603 return Ok((start..=end).step_by(st).collect());
1604 }
1605
1606 let st = step.max(1) as isize;
1607 let mut v = Vec::new();
1608 let mut x = start as isize;
1609 let end_i = end as isize;
1610 while x >= end_i {
1611 v.push(x as usize);
1612 x -= st;
1613 }
1614 if v.is_empty() {
1615 return Err(WavetrendError::InvalidRange {
1616 start: start.to_string(),
1617 end: end.to_string(),
1618 step: step.to_string(),
1619 });
1620 }
1621 Ok(v)
1622 }
1623 fn axis_f64((start, end, step): (f64, f64, f64)) -> Result<Vec<f64>, WavetrendError> {
1624 if step.abs() < 1e-12 || (start - end).abs() < 1e-12 {
1625 return Ok(vec![start]);
1626 }
1627 if start < end {
1628 let mut v = Vec::new();
1629 let mut x = start;
1630 let st = step.abs();
1631 while x <= end + 1e-12 {
1632 v.push(x);
1633 x += st;
1634 }
1635 if v.is_empty() {
1636 return Err(WavetrendError::InvalidRange {
1637 start: start.to_string(),
1638 end: end.to_string(),
1639 step: step.to_string(),
1640 });
1641 }
1642 return Ok(v);
1643 }
1644 let mut v = Vec::new();
1645 let mut x = start;
1646 let st = step.abs();
1647 while x + 1e-12 >= end {
1648 v.push(x);
1649 x -= st;
1650 }
1651 if v.is_empty() {
1652 return Err(WavetrendError::InvalidRange {
1653 start: start.to_string(),
1654 end: end.to_string(),
1655 step: step.to_string(),
1656 });
1657 }
1658 Ok(v)
1659 }
1660
1661 let chs = axis_usize(r.channel_length)?;
1662 let avgs = axis_usize(r.average_length)?;
1663 let mas = axis_usize(r.ma_length)?;
1664 let factors = axis_f64(r.factor)?;
1665
1666 let cap = chs
1667 .len()
1668 .checked_mul(avgs.len())
1669 .and_then(|x| x.checked_mul(mas.len()))
1670 .and_then(|x| x.checked_mul(factors.len()))
1671 .ok_or_else(|| WavetrendError::InvalidRange {
1672 start: "cap".into(),
1673 end: "overflow".into(),
1674 step: "mul".into(),
1675 })?;
1676
1677 let mut out = Vec::with_capacity(cap);
1678 for &c in &chs {
1679 for &a in &avgs {
1680 for &m in &mas {
1681 for &f in &factors {
1682 out.push(WavetrendParams {
1683 channel_length: Some(c),
1684 average_length: Some(a),
1685 ma_length: Some(m),
1686 factor: Some(f),
1687 });
1688 }
1689 }
1690 }
1691 }
1692 Ok(out)
1693}
1694
1695#[inline(always)]
1696pub fn wavetrend_batch_slice(
1697 data: &[f64],
1698 sweep: &WavetrendBatchRange,
1699 kern: Kernel,
1700) -> Result<WavetrendBatchOutput, WavetrendError> {
1701 wavetrend_batch_inner(data, sweep, kern, false)
1702}
1703#[inline(always)]
1704pub fn wavetrend_batch_par_slice(
1705 data: &[f64],
1706 sweep: &WavetrendBatchRange,
1707 kern: Kernel,
1708) -> Result<WavetrendBatchOutput, WavetrendError> {
1709 wavetrend_batch_inner(data, sweep, kern, true)
1710}
1711
1712#[inline(always)]
1713fn wavetrend_batch_inner(
1714 data: &[f64],
1715 sweep: &WavetrendBatchRange,
1716 kern: Kernel,
1717 parallel: bool,
1718) -> Result<WavetrendBatchOutput, WavetrendError> {
1719 let combos = expand_grid(sweep)?;
1720 if combos.is_empty() {
1721 return Err(WavetrendError::InvalidRange {
1722 start: "range".into(),
1723 end: "range".into(),
1724 step: "empty".into(),
1725 });
1726 }
1727 let first = data
1728 .iter()
1729 .position(|x| !x.is_nan())
1730 .ok_or(WavetrendError::AllValuesNaN)?;
1731
1732 let mut max_p = 0usize;
1733 let mut warmup_periods = Vec::with_capacity(combos.len());
1734 for c in combos.iter() {
1735 let channel_length = c.channel_length.unwrap();
1736 if channel_length == 0 {
1737 return Err(WavetrendError::InvalidChannelLen {
1738 channel_length,
1739 data_len: data.len(),
1740 });
1741 }
1742 let average_length = c.average_length.unwrap();
1743 if average_length == 0 {
1744 return Err(WavetrendError::InvalidAverageLen {
1745 average_length,
1746 data_len: data.len(),
1747 });
1748 }
1749 let ma_length = c.ma_length.unwrap();
1750 if ma_length == 0 {
1751 return Err(WavetrendError::InvalidMaLen {
1752 ma_length,
1753 data_len: data.len(),
1754 });
1755 }
1756
1757 max_p = max_p.max(channel_length).max(average_length).max(ma_length);
1758 warmup_periods.push(first + channel_length - 1 + average_length - 1 + ma_length - 1);
1759 }
1760 if data.len() - first < max_p {
1761 return Err(WavetrendError::NotEnoughValidData {
1762 needed: max_p,
1763 valid: data.len() - first,
1764 });
1765 }
1766 let rows = combos.len();
1767 let cols = data.len();
1768
1769 let _ = rows
1770 .checked_mul(cols)
1771 .ok_or_else(|| WavetrendError::InvalidRange {
1772 start: rows.to_string(),
1773 end: cols.to_string(),
1774 step: "rows*cols".into(),
1775 })?;
1776
1777 let mut wt1_mu = make_uninit_matrix(rows, cols);
1778 let mut wt2_mu = make_uninit_matrix(rows, cols);
1779 let mut wt_diff_mu = make_uninit_matrix(rows, cols);
1780
1781 init_matrix_prefixes(&mut wt1_mu, cols, &warmup_periods);
1782 init_matrix_prefixes(&mut wt2_mu, cols, &warmup_periods);
1783 init_matrix_prefixes(&mut wt_diff_mu, cols, &warmup_periods);
1784
1785 let mut wt1_guard = core::mem::ManuallyDrop::new(wt1_mu);
1786 let mut wt2_guard = core::mem::ManuallyDrop::new(wt2_mu);
1787 let mut wt_diff_guard = core::mem::ManuallyDrop::new(wt_diff_mu);
1788
1789 let wt1: &mut [f64] = unsafe {
1790 core::slice::from_raw_parts_mut(wt1_guard.as_mut_ptr() as *mut f64, wt1_guard.len())
1791 };
1792 let wt2: &mut [f64] = unsafe {
1793 core::slice::from_raw_parts_mut(wt2_guard.as_mut_ptr() as *mut f64, wt2_guard.len())
1794 };
1795 let wt_diff: &mut [f64] = unsafe {
1796 core::slice::from_raw_parts_mut(wt_diff_guard.as_mut_ptr() as *mut f64, wt_diff_guard.len())
1797 };
1798
1799 let do_row = |row: usize, w1: &mut [f64], w2: &mut [f64], wd: &mut [f64]| unsafe {
1800 let p = &combos[row];
1801 let row_kernel = match kern {
1802 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1803 Kernel::Avx512 => wavetrend_row_avx512(
1804 data,
1805 first,
1806 p.channel_length.unwrap(),
1807 p.average_length.unwrap(),
1808 p.ma_length.unwrap(),
1809 p.factor.unwrap_or(0.015),
1810 w1,
1811 w2,
1812 wd,
1813 ),
1814 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1815 Kernel::Avx2 => wavetrend_row_avx2(
1816 data,
1817 first,
1818 p.channel_length.unwrap(),
1819 p.average_length.unwrap(),
1820 p.ma_length.unwrap(),
1821 p.factor.unwrap_or(0.015),
1822 w1,
1823 w2,
1824 wd,
1825 ),
1826 _ => wavetrend_row_scalar(
1827 data,
1828 first,
1829 p.channel_length.unwrap(),
1830 p.average_length.unwrap(),
1831 p.ma_length.unwrap(),
1832 p.factor.unwrap_or(0.015),
1833 w1,
1834 w2,
1835 wd,
1836 ),
1837 };
1838 if let Err(e) = row_kernel {
1839 panic!("wavetrend row error: {:?}", e);
1840 }
1841 };
1842
1843 if parallel {
1844 #[cfg(not(target_arch = "wasm32"))]
1845 {
1846 wt1.par_chunks_mut(cols)
1847 .zip(wt2.par_chunks_mut(cols))
1848 .zip(wt_diff.par_chunks_mut(cols))
1849 .enumerate()
1850 .for_each(|(row, ((w1, w2), wd))| do_row(row, w1, w2, wd));
1851 }
1852
1853 #[cfg(target_arch = "wasm32")]
1854 {
1855 for (row, (((w1, w2), wd))) in wt1
1856 .chunks_mut(cols)
1857 .zip(wt2.chunks_mut(cols))
1858 .zip(wt_diff.chunks_mut(cols))
1859 .enumerate()
1860 {
1861 do_row(row, w1, w2, wd);
1862 }
1863 }
1864 } else {
1865 for (row, (((w1, w2), wd))) in wt1
1866 .chunks_mut(cols)
1867 .zip(wt2.chunks_mut(cols))
1868 .zip(wt_diff.chunks_mut(cols))
1869 .enumerate()
1870 {
1871 do_row(row, w1, w2, wd);
1872 }
1873 }
1874
1875 let wt1_vec = unsafe {
1876 Vec::from_raw_parts(
1877 wt1_guard.as_mut_ptr() as *mut f64,
1878 wt1_guard.len(),
1879 wt1_guard.capacity(),
1880 )
1881 };
1882 let wt2_vec = unsafe {
1883 Vec::from_raw_parts(
1884 wt2_guard.as_mut_ptr() as *mut f64,
1885 wt2_guard.len(),
1886 wt2_guard.capacity(),
1887 )
1888 };
1889 let wt_diff_vec = unsafe {
1890 Vec::from_raw_parts(
1891 wt_diff_guard.as_mut_ptr() as *mut f64,
1892 wt_diff_guard.len(),
1893 wt_diff_guard.capacity(),
1894 )
1895 };
1896
1897 Ok(WavetrendBatchOutput {
1898 wt1: wt1_vec,
1899 wt2: wt2_vec,
1900 wt_diff: wt_diff_vec,
1901 combos,
1902 rows,
1903 cols,
1904 })
1905}
1906
1907#[inline(always)]
1908fn wavetrend_batch_inner_into(
1909 data: &[f64],
1910 sweep: &WavetrendBatchRange,
1911 kern: Kernel,
1912 parallel: bool,
1913 out_wt1: &mut [f64],
1914 out_wt2: &mut [f64],
1915 out_wt_diff: &mut [f64],
1916) -> Result<Vec<WavetrendParams>, WavetrendError> {
1917 let combos = expand_grid(sweep)?;
1918 if combos.is_empty() {
1919 return Err(WavetrendError::InvalidRange {
1920 start: "range".into(),
1921 end: "range".into(),
1922 step: "empty".into(),
1923 });
1924 }
1925 let first = data
1926 .iter()
1927 .position(|x| !x.is_nan())
1928 .ok_or(WavetrendError::AllValuesNaN)?;
1929
1930 let mut max_p = 0usize;
1931 for c in combos.iter() {
1932 let channel_length = c.channel_length.unwrap();
1933 if channel_length == 0 {
1934 return Err(WavetrendError::InvalidChannelLen {
1935 channel_length,
1936 data_len: data.len(),
1937 });
1938 }
1939 let average_length = c.average_length.unwrap();
1940 if average_length == 0 {
1941 return Err(WavetrendError::InvalidAverageLen {
1942 average_length,
1943 data_len: data.len(),
1944 });
1945 }
1946 let ma_length = c.ma_length.unwrap();
1947 if ma_length == 0 {
1948 return Err(WavetrendError::InvalidMaLen {
1949 ma_length,
1950 data_len: data.len(),
1951 });
1952 }
1953
1954 max_p = max_p.max(channel_length).max(average_length).max(ma_length);
1955 }
1956 if data.len() - first < max_p {
1957 return Err(WavetrendError::NotEnoughValidData {
1958 needed: max_p,
1959 valid: data.len() - first,
1960 });
1961 }
1962 let rows = combos.len();
1963 let cols = data.len();
1964
1965 let total = rows
1966 .checked_mul(cols)
1967 .ok_or_else(|| WavetrendError::InvalidRange {
1968 start: rows.to_string(),
1969 end: cols.to_string(),
1970 step: "rows*cols".into(),
1971 })?;
1972 if out_wt1.len() != total {
1973 return Err(WavetrendError::OutputSliceLengthMismatch {
1974 expected: total,
1975 got: out_wt1.len(),
1976 });
1977 }
1978 if out_wt2.len() != total {
1979 return Err(WavetrendError::OutputSliceLengthMismatch {
1980 expected: total,
1981 got: out_wt2.len(),
1982 });
1983 }
1984 if out_wt_diff.len() != total {
1985 return Err(WavetrendError::OutputSliceLengthMismatch {
1986 expected: total,
1987 got: out_wt_diff.len(),
1988 });
1989 }
1990
1991 for (row, combo) in combos.iter().enumerate() {
1992 let warmup = first + combo.channel_length.unwrap() - 1 + combo.average_length.unwrap() - 1
1993 + combo.ma_length.unwrap()
1994 - 1;
1995 let row_start = row * cols;
1996 for i in 0..warmup.min(cols) {
1997 out_wt1[row_start + i] = f64::NAN;
1998 out_wt2[row_start + i] = f64::NAN;
1999 out_wt_diff[row_start + i] = f64::NAN;
2000 }
2001 }
2002
2003 let do_row = |row: usize, w1: &mut [f64], w2: &mut [f64], wd: &mut [f64]| unsafe {
2004 let p = &combos[row];
2005 let r = wavetrend_row_scalar(
2006 data,
2007 first,
2008 p.channel_length.unwrap(),
2009 p.average_length.unwrap(),
2010 p.ma_length.unwrap(),
2011 p.factor.unwrap_or(0.015),
2012 w1,
2013 w2,
2014 wd,
2015 );
2016 if let Err(e) = r {
2017 panic!("wavetrend row error: {:?}", e);
2018 }
2019 };
2020
2021 if parallel {
2022 #[cfg(not(target_arch = "wasm32"))]
2023 {
2024 out_wt1
2025 .par_chunks_mut(cols)
2026 .zip(out_wt2.par_chunks_mut(cols))
2027 .zip(out_wt_diff.par_chunks_mut(cols))
2028 .enumerate()
2029 .for_each(|(row, ((w1, w2), wd))| do_row(row, w1, w2, wd));
2030 }
2031
2032 #[cfg(target_arch = "wasm32")]
2033 {
2034 for (row, (((w1, w2), wd))) in out_wt1
2035 .chunks_mut(cols)
2036 .zip(out_wt2.chunks_mut(cols))
2037 .zip(out_wt_diff.chunks_mut(cols))
2038 .enumerate()
2039 {
2040 do_row(row, w1, w2, wd);
2041 }
2042 }
2043 } else {
2044 for (row, (((w1, w2), wd))) in out_wt1
2045 .chunks_mut(cols)
2046 .zip(out_wt2.chunks_mut(cols))
2047 .zip(out_wt_diff.chunks_mut(cols))
2048 .enumerate()
2049 {
2050 do_row(row, w1, w2, wd);
2051 }
2052 }
2053 Ok(combos)
2054}
2055
2056#[inline(always)]
2057unsafe fn wavetrend_row_scalar(
2058 data: &[f64],
2059 first: usize,
2060 channel_len: usize,
2061 average_len: usize,
2062 ma_len: usize,
2063 factor: f64,
2064 wt1: &mut [f64],
2065 wt2: &mut [f64],
2066 wd: &mut [f64],
2067) -> Result<(), WavetrendError> {
2068 wavetrend_row_with_kernel(
2069 data,
2070 first,
2071 channel_len,
2072 average_len,
2073 ma_len,
2074 factor,
2075 wt1,
2076 wt2,
2077 wd,
2078 Kernel::Scalar,
2079 )
2080}
2081
/// Computes one batch row by forwarding to `wavetrend_row_with_kernel` with
/// `Kernel::Avx2`.
///
/// # Safety
/// NOTE(review): the body contains no unsafe operation of its own; the
/// contract is whatever `wavetrend_row_with_kernel` requires, which
/// debug-asserts that `wt1`, `wt2` and `wd` are each `data.len()` long.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn wavetrend_row_avx2(
    data: &[f64],
    first: usize,
    channel_len: usize,
    average_len: usize,
    ma_len: usize,
    factor: f64,
    wt1: &mut [f64],
    wt2: &mut [f64],
    wd: &mut [f64],
) -> Result<(), WavetrendError> {
    let kernel = Kernel::Avx2;
    wavetrend_row_with_kernel(
        data,
        first,
        channel_len,
        average_len,
        ma_len,
        factor,
        wt1,
        wt2,
        wd,
        kernel,
    )
}
/// Computes one batch row by forwarding to `wavetrend_row_with_kernel` with
/// `Kernel::Avx512`.
///
/// # Safety
/// NOTE(review): the body contains no unsafe operation of its own; the
/// contract is whatever `wavetrend_row_with_kernel` requires, which
/// debug-asserts that `wt1`, `wt2` and `wd` are each `data.len()` long.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn wavetrend_row_avx512(
    data: &[f64],
    first: usize,
    channel_len: usize,
    average_len: usize,
    ma_len: usize,
    factor: f64,
    wt1: &mut [f64],
    wt2: &mut [f64],
    wd: &mut [f64],
) -> Result<(), WavetrendError> {
    let kernel = Kernel::Avx512;
    wavetrend_row_with_kernel(
        data,
        first,
        channel_len,
        average_len,
        ma_len,
        factor,
        wt1,
        wt2,
        wd,
        kernel,
    )
}
/// "Short" AVX-512 row entry point; it forwards to
/// `wavetrend_row_with_kernel` with `Kernel::Avx512`, exactly like
/// `wavetrend_row_avx512`.
///
/// # Safety
/// NOTE(review): the body contains no unsafe operation of its own; the
/// contract is whatever `wavetrend_row_with_kernel` requires, which
/// debug-asserts that `wt1`, `wt2` and `wd` are each `data.len()` long.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn wavetrend_row_avx512_short(
    data: &[f64],
    first: usize,
    channel_len: usize,
    average_len: usize,
    ma_len: usize,
    factor: f64,
    wt1: &mut [f64],
    wt2: &mut [f64],
    wd: &mut [f64],
) -> Result<(), WavetrendError> {
    let kernel = Kernel::Avx512;
    wavetrend_row_with_kernel(
        data,
        first,
        channel_len,
        average_len,
        ma_len,
        factor,
        wt1,
        wt2,
        wd,
        kernel,
    )
}
/// "Long" AVX-512 row entry point; it forwards to
/// `wavetrend_row_with_kernel` with `Kernel::Avx512`, exactly like
/// `wavetrend_row_avx512`.
///
/// # Safety
/// NOTE(review): the body contains no unsafe operation of its own; the
/// contract is whatever `wavetrend_row_with_kernel` requires, which
/// debug-asserts that `wt1`, `wt2` and `wd` are each `data.len()` long.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn wavetrend_row_avx512_long(
    data: &[f64],
    first: usize,
    channel_len: usize,
    average_len: usize,
    ma_len: usize,
    factor: f64,
    wt1: &mut [f64],
    wt2: &mut [f64],
    wd: &mut [f64],
) -> Result<(), WavetrendError> {
    let kernel = Kernel::Avx512;
    wavetrend_row_with_kernel(
        data,
        first,
        channel_len,
        average_len,
        ma_len,
        factor,
        wt1,
        wt2,
        wd,
        kernel,
    )
}
2186
2187#[inline(always)]
2188unsafe fn wavetrend_row_with_kernel(
2189 data: &[f64],
2190 first: usize,
2191 channel_len: usize,
2192 average_len: usize,
2193 ma_len: usize,
2194 factor: f64,
2195 wt1: &mut [f64],
2196 wt2: &mut [f64],
2197 wd: &mut [f64],
2198 kernel: Kernel,
2199) -> Result<(), WavetrendError> {
2200 debug_assert_eq!(wt1.len(), data.len());
2201 debug_assert_eq!(wt2.len(), data.len());
2202 debug_assert_eq!(wd.len(), data.len());
2203
2204 let warmup = first + channel_len - 1 + average_len - 1 + ma_len - 1;
2205
2206 wavetrend_compute_into(
2207 data,
2208 channel_len,
2209 average_len,
2210 ma_len,
2211 factor,
2212 first,
2213 warmup,
2214 wt1,
2215 wt2,
2216 wd,
2217 kernel,
2218 )
2219}
2220#[cfg(test)]
2221mod tests {
2222 use super::*;
2223 use crate::skip_if_unsupported;
2224 use crate::utilities::data_loader::read_candles_from_csv;
2225 use crate::utilities::enums::Kernel;
2226
2227 fn check_wavetrend_partial_params(
2228 test_name: &str,
2229 kernel: Kernel,
2230 ) -> Result<(), Box<dyn std::error::Error>> {
2231 skip_if_unsupported!(kernel, test_name);
2232 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2233 let candles = read_candles_from_csv(file_path)?;
2234 let default_params = WavetrendParams {
2235 channel_length: None,
2236 average_length: None,
2237 ma_length: None,
2238 factor: None,
2239 };
2240 let input = WavetrendInput::from_candles(&candles, "hlc3", default_params);
2241 let output = wavetrend_with_kernel(&input, kernel)?;
2242 assert_eq!(output.wt1.len(), candles.close.len());
2243 Ok(())
2244 }
2245
2246 fn check_wavetrend_accuracy(
2247 test_name: &str,
2248 kernel: Kernel,
2249 ) -> Result<(), Box<dyn std::error::Error>> {
2250 skip_if_unsupported!(kernel, test_name);
2251 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2252 let candles = read_candles_from_csv(file_path)?;
2253 let input = WavetrendInput::from_candles(&candles, "hlc3", WavetrendParams::default());
2254 let result = wavetrend_with_kernel(&input, kernel)?;
2255 let len = result.wt1.len();
2256 let expected_wt1 = [
2257 -29.02058232514538,
2258 -28.207769813591664,
2259 -31.991808642927193,
2260 -31.9218051759519,
2261 -44.956245952893866,
2262 ];
2263 let expected_wt2 = [
2264 -30.651043230696555,
2265 -28.686329669808583,
2266 -29.740053593887932,
2267 -30.707127877490105,
2268 -36.2899532572575,
2269 ];
2270 for (i, &val) in result.wt1[len - 5..].iter().enumerate() {
2271 let diff = (val - expected_wt1[i]).abs();
2272 assert!(
2273 diff < 1e-6,
2274 "[{}] Wavetrend {:?} WT1 mismatch at idx {}: got {}, expected {}",
2275 test_name,
2276 kernel,
2277 i,
2278 val,
2279 expected_wt1[i]
2280 );
2281 }
2282 for (i, &val) in result.wt2[len - 5..].iter().enumerate() {
2283 let diff = (val - expected_wt2[i]).abs();
2284 assert!(
2285 diff < 1e-6,
2286 "[{}] Wavetrend {:?} WT2 mismatch at idx {}: got {}, expected {}",
2287 test_name,
2288 kernel,
2289 i,
2290 val,
2291 expected_wt2[i]
2292 );
2293 }
2294 let last_five_diff = &result.wt_diff[len - 5..];
2295 for i in 0..5 {
2296 let expected = expected_wt2[i] - expected_wt1[i];
2297 let diff = (last_five_diff[i] - expected).abs();
2298 assert!(
2299 diff < 1e-6,
2300 "[{}] Wavetrend {:?} WT_DIFF mismatch at idx {}: got {}, expected {}",
2301 test_name,
2302 kernel,
2303 i,
2304 last_five_diff[i],
2305 expected
2306 );
2307 }
2308 Ok(())
2309 }
2310
2311 fn check_wavetrend_default_candles(
2312 test_name: &str,
2313 kernel: Kernel,
2314 ) -> Result<(), Box<dyn std::error::Error>> {
2315 skip_if_unsupported!(kernel, test_name);
2316 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2317 let candles = read_candles_from_csv(file_path)?;
2318 let input = WavetrendInput::with_default_candles(&candles);
2319 match input.data {
2320 WavetrendData::Candles { source, .. } => assert_eq!(source, "hlc3"),
2321 _ => panic!("Expected WavetrendData::Candles"),
2322 }
2323 let output = wavetrend_with_kernel(&input, kernel)?;
2324 assert_eq!(output.wt1.len(), candles.close.len());
2325 Ok(())
2326 }
2327
2328 fn check_wavetrend_zero_channel(
2329 test_name: &str,
2330 kernel: Kernel,
2331 ) -> Result<(), Box<dyn std::error::Error>> {
2332 skip_if_unsupported!(kernel, test_name);
2333 let input_data = [10.0, 20.0, 30.0];
2334 let params = WavetrendParams {
2335 channel_length: Some(0),
2336 average_length: Some(12),
2337 ma_length: Some(3),
2338 factor: Some(0.015),
2339 };
2340 let input = WavetrendInput::from_slice(&input_data, params);
2341 let res = wavetrend_with_kernel(&input, kernel);
2342 assert!(
2343 res.is_err(),
2344 "[{}] Wavetrend should fail with zero channel_length",
2345 test_name
2346 );
2347 Ok(())
2348 }
2349
2350 fn check_wavetrend_channel_exceeds_length(
2351 test_name: &str,
2352 kernel: Kernel,
2353 ) -> Result<(), Box<dyn std::error::Error>> {
2354 skip_if_unsupported!(kernel, test_name);
2355 let data_small = [10.0, 20.0, 30.0];
2356 let params = WavetrendParams {
2357 channel_length: Some(10),
2358 average_length: Some(12),
2359 ma_length: Some(3),
2360 factor: Some(0.015),
2361 };
2362 let input = WavetrendInput::from_slice(&data_small, params);
2363 let res = wavetrend_with_kernel(&input, kernel);
2364 assert!(
2365 res.is_err(),
2366 "[{}] Wavetrend should fail with channel_length exceeding length",
2367 test_name
2368 );
2369 Ok(())
2370 }
2371
2372 fn check_wavetrend_very_small_dataset(
2373 test_name: &str,
2374 kernel: Kernel,
2375 ) -> Result<(), Box<dyn std::error::Error>> {
2376 skip_if_unsupported!(kernel, test_name);
2377 let single_point = [42.0];
2378 let params = WavetrendParams::default();
2379 let input = WavetrendInput::from_slice(&single_point, params);
2380 let res = wavetrend_with_kernel(&input, kernel);
2381 assert!(
2382 res.is_err(),
2383 "[{}] Wavetrend should fail with insufficient data",
2384 test_name
2385 );
2386 Ok(())
2387 }
2388
2389 fn check_wavetrend_nan_handling(
2390 test_name: &str,
2391 kernel: Kernel,
2392 ) -> Result<(), Box<dyn std::error::Error>> {
2393 skip_if_unsupported!(kernel, test_name);
2394 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2395 let candles = read_candles_from_csv(file_path)?;
2396 let input = WavetrendInput::from_candles(
2397 &candles,
2398 "hlc3",
2399 WavetrendParams {
2400 channel_length: Some(9),
2401 average_length: Some(12),
2402 ma_length: Some(3),
2403 factor: Some(0.015),
2404 },
2405 );
2406 let res = wavetrend_with_kernel(&input, kernel)?;
2407 assert_eq!(res.wt1.len(), candles.close.len());
2408 if res.wt1.len() > 240 {
2409 for (i, &val) in res.wt1[240..].iter().enumerate() {
2410 assert!(
2411 !val.is_nan(),
2412 "[{}] Found unexpected NaN at out-index {}",
2413 test_name,
2414 240 + i
2415 );
2416 }
2417 }
2418 Ok(())
2419 }
2420
2421 #[cfg(debug_assertions)]
2422 fn check_wavetrend_no_poison(
2423 test_name: &str,
2424 kernel: Kernel,
2425 ) -> Result<(), Box<dyn std::error::Error>> {
2426 skip_if_unsupported!(kernel, test_name);
2427
2428 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2429 let candles = read_candles_from_csv(file_path)?;
2430
2431 let test_params = vec![
2432 WavetrendParams::default(),
2433 WavetrendParams {
2434 channel_length: Some(1),
2435 average_length: Some(1),
2436 ma_length: Some(1),
2437 factor: Some(0.001),
2438 },
2439 WavetrendParams {
2440 channel_length: Some(2),
2441 average_length: Some(2),
2442 ma_length: Some(2),
2443 factor: Some(0.005),
2444 },
2445 WavetrendParams {
2446 channel_length: Some(5),
2447 average_length: Some(7),
2448 ma_length: Some(3),
2449 factor: Some(0.01),
2450 },
2451 WavetrendParams {
2452 channel_length: Some(10),
2453 average_length: Some(15),
2454 ma_length: Some(5),
2455 factor: Some(0.02),
2456 },
2457 WavetrendParams {
2458 channel_length: Some(20),
2459 average_length: Some(25),
2460 ma_length: Some(7),
2461 factor: Some(0.025),
2462 },
2463 WavetrendParams {
2464 channel_length: Some(30),
2465 average_length: Some(40),
2466 ma_length: Some(10),
2467 factor: Some(0.03),
2468 },
2469 WavetrendParams {
2470 channel_length: Some(50),
2471 average_length: Some(60),
2472 ma_length: Some(15),
2473 factor: Some(0.04),
2474 },
2475 WavetrendParams {
2476 channel_length: Some(100),
2477 average_length: Some(120),
2478 ma_length: Some(20),
2479 factor: Some(0.05),
2480 },
2481 WavetrendParams {
2482 channel_length: Some(7),
2483 average_length: Some(11),
2484 ma_length: Some(3),
2485 factor: Some(0.013),
2486 },
2487 WavetrendParams {
2488 channel_length: Some(13),
2489 average_length: Some(17),
2490 ma_length: Some(5),
2491 factor: Some(0.017),
2492 },
2493 WavetrendParams {
2494 channel_length: Some(9),
2495 average_length: Some(3),
2496 ma_length: Some(12),
2497 factor: Some(0.015),
2498 },
2499 WavetrendParams {
2500 channel_length: Some(15),
2501 average_length: Some(15),
2502 ma_length: Some(15),
2503 factor: Some(0.015),
2504 },
2505 WavetrendParams {
2506 channel_length: Some(9),
2507 average_length: Some(12),
2508 ma_length: Some(3),
2509 factor: Some(0.0001),
2510 },
2511 WavetrendParams {
2512 channel_length: Some(9),
2513 average_length: Some(12),
2514 ma_length: Some(3),
2515 factor: Some(1.0),
2516 },
2517 WavetrendParams {
2518 channel_length: Some(3),
2519 average_length: Some(5),
2520 ma_length: Some(1),
2521 factor: Some(0.008),
2522 },
2523 WavetrendParams {
2524 channel_length: Some(8),
2525 average_length: Some(13),
2526 ma_length: Some(2),
2527 factor: Some(0.021),
2528 },
2529 WavetrendParams {
2530 channel_length: Some(21),
2531 average_length: Some(34),
2532 ma_length: Some(8),
2533 factor: Some(0.034),
2534 },
2535 ];
2536
2537 for (param_idx, params) in test_params.iter().enumerate() {
2538 let input = WavetrendInput::from_candles(&candles, "hlc3", params.clone());
2539 let output = wavetrend_with_kernel(&input, kernel)?;
2540
2541 for (i, &val) in output.wt1.iter().enumerate() {
2542 if val.is_nan() {
2543 continue;
2544 }
2545
2546 let bits = val.to_bits();
2547
2548 if bits == 0x11111111_11111111 {
2549 panic!(
2550 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
2551 in wt1 output with params: channel_length={}, average_length={}, ma_length={}, factor={} (param set {})",
2552 test_name, val, bits, i,
2553 params.channel_length.unwrap_or(9),
2554 params.average_length.unwrap_or(12),
2555 params.ma_length.unwrap_or(3),
2556 params.factor.unwrap_or(0.015),
2557 param_idx
2558 );
2559 }
2560
2561 if bits == 0x22222222_22222222 {
2562 panic!(
2563 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
2564 in wt1 output with params: channel_length={}, average_length={}, ma_length={}, factor={} (param set {})",
2565 test_name, val, bits, i,
2566 params.channel_length.unwrap_or(9),
2567 params.average_length.unwrap_or(12),
2568 params.ma_length.unwrap_or(3),
2569 params.factor.unwrap_or(0.015),
2570 param_idx
2571 );
2572 }
2573
2574 if bits == 0x33333333_33333333 {
2575 panic!(
2576 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
2577 in wt1 output with params: channel_length={}, average_length={}, ma_length={}, factor={} (param set {})",
2578 test_name, val, bits, i,
2579 params.channel_length.unwrap_or(9),
2580 params.average_length.unwrap_or(12),
2581 params.ma_length.unwrap_or(3),
2582 params.factor.unwrap_or(0.015),
2583 param_idx
2584 );
2585 }
2586 }
2587
2588 for (i, &val) in output.wt2.iter().enumerate() {
2589 if val.is_nan() {
2590 continue;
2591 }
2592
2593 let bits = val.to_bits();
2594
2595 if bits == 0x11111111_11111111 {
2596 panic!(
2597 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
2598 in wt2 output with params: channel_length={}, average_length={}, ma_length={}, factor={} (param set {})",
2599 test_name, val, bits, i,
2600 params.channel_length.unwrap_or(9),
2601 params.average_length.unwrap_or(12),
2602 params.ma_length.unwrap_or(3),
2603 params.factor.unwrap_or(0.015),
2604 param_idx
2605 );
2606 }
2607
2608 if bits == 0x22222222_22222222 {
2609 panic!(
2610 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
2611 in wt2 output with params: channel_length={}, average_length={}, ma_length={}, factor={} (param set {})",
2612 test_name, val, bits, i,
2613 params.channel_length.unwrap_or(9),
2614 params.average_length.unwrap_or(12),
2615 params.ma_length.unwrap_or(3),
2616 params.factor.unwrap_or(0.015),
2617 param_idx
2618 );
2619 }
2620
2621 if bits == 0x33333333_33333333 {
2622 panic!(
2623 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
2624 in wt2 output with params: channel_length={}, average_length={}, ma_length={}, factor={} (param set {})",
2625 test_name, val, bits, i,
2626 params.channel_length.unwrap_or(9),
2627 params.average_length.unwrap_or(12),
2628 params.ma_length.unwrap_or(3),
2629 params.factor.unwrap_or(0.015),
2630 param_idx
2631 );
2632 }
2633 }
2634
2635 for (i, &val) in output.wt_diff.iter().enumerate() {
2636 if val.is_nan() {
2637 continue;
2638 }
2639
2640 let bits = val.to_bits();
2641
2642 if bits == 0x11111111_11111111 {
2643 panic!(
2644 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
2645 in wt_diff output with params: channel_length={}, average_length={}, ma_length={}, factor={} (param set {})",
2646 test_name, val, bits, i,
2647 params.channel_length.unwrap_or(9),
2648 params.average_length.unwrap_or(12),
2649 params.ma_length.unwrap_or(3),
2650 params.factor.unwrap_or(0.015),
2651 param_idx
2652 );
2653 }
2654
2655 if bits == 0x22222222_22222222 {
2656 panic!(
2657 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
2658 in wt_diff output with params: channel_length={}, average_length={}, ma_length={}, factor={} (param set {})",
2659 test_name, val, bits, i,
2660 params.channel_length.unwrap_or(9),
2661 params.average_length.unwrap_or(12),
2662 params.ma_length.unwrap_or(3),
2663 params.factor.unwrap_or(0.015),
2664 param_idx
2665 );
2666 }
2667
2668 if bits == 0x33333333_33333333 {
2669 panic!(
2670 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
2671 in wt_diff output with params: channel_length={}, average_length={}, ma_length={}, factor={} (param set {})",
2672 test_name, val, bits, i,
2673 params.channel_length.unwrap_or(9),
2674 params.average_length.unwrap_or(12),
2675 params.ma_length.unwrap_or(3),
2676 params.factor.unwrap_or(0.015),
2677 param_idx
2678 );
2679 }
2680 }
2681 }
2682
2683 Ok(())
2684 }
2685
2686 #[cfg(not(debug_assertions))]
2687 fn check_wavetrend_no_poison(
2688 _test_name: &str,
2689 _kernel: Kernel,
2690 ) -> Result<(), Box<dyn std::error::Error>> {
2691 Ok(())
2692 }
2693
2694 fn check_wavetrend_streaming(
2695 test_name: &str,
2696 kernel: Kernel,
2697 ) -> Result<(), Box<dyn std::error::Error>> {
2698 skip_if_unsupported!(kernel, test_name);
2699
2700 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2701 let candles = read_candles_from_csv(file_path)?;
2702
2703 let channel_length = 9;
2704 let average_length = 12;
2705 let ma_length = 3;
2706 let factor = 0.015;
2707
2708 let input = WavetrendInput::from_candles(
2709 &candles,
2710 "hlc3",
2711 WavetrendParams {
2712 channel_length: Some(channel_length),
2713 average_length: Some(average_length),
2714 ma_length: Some(ma_length),
2715 factor: Some(factor),
2716 },
2717 );
2718 let full_output = wavetrend_with_kernel(&input, kernel)?;
2719
2720 let mut stream = WavetrendStream::try_new(WavetrendParams {
2721 channel_length: Some(channel_length),
2722 average_length: Some(average_length),
2723 ma_length: Some(ma_length),
2724 factor: Some(factor),
2725 })?;
2726
2727 let mut wt1_stream = Vec::with_capacity(candles.hlc3.len());
2728 let mut wt2_stream = Vec::with_capacity(candles.hlc3.len());
2729 let mut diff_stream = Vec::with_capacity(candles.hlc3.len());
2730 for &price in &candles.hlc3 {
2731 match stream.update(price) {
2732 Some((wt1, wt2, diff)) => {
2733 wt1_stream.push(wt1);
2734 wt2_stream.push(wt2);
2735 diff_stream.push(diff);
2736 }
2737 None => {
2738 wt1_stream.push(f64::NAN);
2739 wt2_stream.push(f64::NAN);
2740 diff_stream.push(f64::NAN);
2741 }
2742 }
2743 }
2744
2745 let mut first_non_nan = None;
2746 for (i, &b) in full_output.wt1.iter().enumerate() {
2747 if !b.is_nan() {
2748 first_non_nan = Some(i);
2749 break;
2750 }
2751 }
2752 let start = first_non_nan.unwrap_or(0);
2753 assert_eq!(full_output.wt1.len(), wt1_stream.len());
2754 for (i, (&b, &s)) in full_output
2755 .wt1
2756 .iter()
2757 .zip(wt1_stream.iter())
2758 .enumerate()
2759 .skip(start)
2760 {
2761 if b.is_nan() || s.is_nan() {
2762 continue;
2763 }
2764 let diff = (b - s).abs();
2765 assert!(
2766 diff < 1e-9,
2767 "[{}] Wavetrend streaming wt1 f64 mismatch at idx {}: full={}, stream={}, diff={}",
2768 test_name,
2769 i,
2770 b,
2771 s,
2772 diff
2773 );
2774 }
2775 for (i, (&b, &s)) in full_output.wt2.iter().zip(wt2_stream.iter()).enumerate() {
2776 if b.is_nan() || s.is_nan() {
2777 continue;
2778 }
2779 let diff = (b - s).abs();
2780 assert!(
2781 diff < 1e-9,
2782 "[{}] Wavetrend streaming wt2 f64 mismatch at idx {}: full={}, stream={}, diff={}",
2783 test_name,
2784 i,
2785 b,
2786 s,
2787 diff
2788 );
2789 }
2790 for (i, (&b, &s)) in full_output
2791 .wt_diff
2792 .iter()
2793 .zip(diff_stream.iter())
2794 .enumerate()
2795 {
2796 if b.is_nan() || s.is_nan() {
2797 continue;
2798 }
2799 let diff = (b - s).abs();
2800 assert!(
2801 diff < 1e-9,
2802 "[{}] Wavetrend streaming wt_diff f64 mismatch at idx {}: full={}, stream={}, diff={}",
2803 test_name,
2804 i,
2805 b,
2806 s,
2807 diff
2808 );
2809 }
2810 Ok(())
2811 }
2812
/// Property-based checks of the indicator on randomly generated finite series.
///
/// Generated inputs: `channel_length` and `average_length` in `2..=30`, `ma_length` in
/// `1..=10`, `factor` in `0.001..1.0`, and a series at least
/// `channel + average + ma + 20` long (below 400) with values in `-1e6..1e6`.
///
/// Checked, in this order, for each case:
/// 1. `wt_diff == wt2 - wt1` after the expected warm-up,
/// 2. total absolute change of `wt2` is at most 1.1x that of `wt1` (when `ma_length > 1`),
/// 3. a constant series gives a near-zero `wt1` average over the last ten values,
/// 4. doubling `factor` roughly halves `wt1` (up to five sampled positions),
/// 5. `wt2 == wt1` when `ma_length == 1`,
/// 6. the kernel under test agrees with `Kernel::Scalar` (1e-9 absolute or 4 ULP),
/// 7. all three outputs are NaN throughout the expected warm-up.
#[cfg(feature = "proptest")]
fn check_wavetrend_property(
    test_name: &str,
    kernel: Kernel,
) -> Result<(), Box<dyn std::error::Error>> {
    use proptest::prelude::*;
    skip_if_unsupported!(kernel, test_name);

    // Parameters are drawn first so the series length can be bounded below by them.
    let strategy = (2usize..=30, 2usize..=30, 1usize..=10, 0.001f64..1.0f64).prop_flat_map(
        |(channel_len, average_len, ma_len, factor)| {
            let shortest = channel_len + average_len + ma_len + 20;
            (shortest..400).prop_flat_map(move |data_len| {
                let series = prop::collection::vec(
                    (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
                    data_len,
                );
                (
                    series,
                    Just(channel_len),
                    Just(average_len),
                    Just(ma_len),
                    Just(factor),
                )
            })
        },
    );

    let mut runner = proptest::test_runner::TestRunner::default();
    runner
        .run(
            &strategy,
            |(data, channel_len, average_len, ma_len, factor)| {
                let n = data.len();
                let input = WavetrendInput::from_slice(
                    &data,
                    WavetrendParams {
                        channel_length: Some(channel_len),
                        average_length: Some(average_len),
                        ma_length: Some(ma_len),
                        factor: Some(factor),
                    },
                );

                let output = wavetrend_with_kernel(&input, kernel).unwrap();
                let ref_output = wavetrend_with_kernel(&input, Kernel::Scalar).unwrap();

                let first_valid = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
                let expected_warmup =
                    first_valid + channel_len - 1 + average_len - 1 + ma_len - 1;
                // Clamped so every range below stays inside the series.
                let valid_start = expected_warmup.min(n);

                // 1. wt_diff must equal wt2 - wt1 wherever both are finite.
                for i in valid_start..n {
                    let (wt1_here, wt2_here) = (output.wt1[i], output.wt2[i]);
                    if !(wt1_here.is_finite() && wt2_here.is_finite()) {
                        continue;
                    }
                    let expected_diff = wt2_here - wt1_here;
                    let actual_diff = output.wt_diff[i];
                    prop_assert!(
                        (actual_diff - expected_diff).abs() <= 1e-9,
                        "WT_DIFF mismatch at idx {}: expected {}, got {}",
                        i,
                        expected_diff,
                        actual_diff
                    );
                }

                // 2. wt2 should not move more than wt1 overall.
                let valid_wt1: Vec<f64> = output.wt1[valid_start..]
                    .iter()
                    .copied()
                    .filter(|x| x.is_finite())
                    .collect();
                let valid_wt2: Vec<f64> = output.wt2[valid_start..]
                    .iter()
                    .copied()
                    .filter(|x| x.is_finite())
                    .collect();

                if valid_wt1.len() > 10 && valid_wt2.len() > 10 && ma_len > 1 {
                    let shared = valid_wt1.len().min(valid_wt2.len());
                    let wt1_changes = valid_wt1[..shared]
                        .windows(2)
                        .fold(0.0, |acc, w| acc + (w[1] - w[0]).abs());
                    let wt2_changes = valid_wt2[..shared]
                        .windows(2)
                        .fold(0.0, |acc, w| acc + (w[1] - w[0]).abs());

                    if wt1_changes > 1e-6 {
                        prop_assert!(
                            wt2_changes <= wt1_changes * 1.1,
                            "WT2 should be smoother: wt1_changes={}, wt2_changes={}",
                            wt1_changes,
                            wt2_changes
                        );
                    }
                }

                // 3. A flat input should leave the oscillator near zero at the end.
                let is_flat = data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-9);
                if is_flat && n > valid_start + 10 {
                    let tail_from = output.wt1.len() - 10;
                    let last_10_wt1: Vec<f64> = output.wt1[tail_from..]
                        .iter()
                        .filter(|&&x| x.is_finite())
                        .copied()
                        .collect();
                    if last_10_wt1.len() >= 5 {
                        let avg_wt1: f64 =
                            last_10_wt1.iter().sum::<f64>() / last_10_wt1.len() as f64;
                        prop_assert!(
                            avg_wt1.abs() <= 1.0,
                            "Constant price should give near-zero oscillator: avg_wt1={}",
                            avg_wt1
                        );
                    }
                }

                // 4. Doubling the factor should roughly halve wt1.
                if factor < 0.5 && valid_start < n {
                    let input_double = WavetrendInput::from_slice(
                        &data,
                        WavetrendParams {
                            channel_length: Some(channel_len),
                            average_length: Some(average_len),
                            ma_length: Some(ma_len),
                            factor: Some(factor * 2.0),
                        },
                    );
                    let output_double = wavetrend_with_kernel(&input_double, kernel).unwrap();

                    let check_end = n.min(valid_start + 20);
                    let mut checked_count = 0;
                    for i in valid_start..check_end {
                        let base = output.wt1[i];
                        let doubled = output_double.wt1[i];
                        if base.is_finite() && doubled.is_finite() && base.abs() > 0.1 {
                            let ratio = doubled / base;

                            prop_assert!(
                                (ratio - 0.5).abs() <= 0.35,
                                "Factor doubling should roughly halve WT1 at idx {}: original={}, doubled={}, ratio={}",
                                i, base, doubled, ratio
                            );
                            checked_count += 1;
                            // Five sampled positions are enough for this heuristic.
                            if checked_count >= 5 {
                                break;
                            }
                        }
                    }
                }

                // 5. With a one-period signal average, wt2 is wt1.
                if ma_len == 1 {
                    for i in valid_start..n {
                        let (wt1_here, wt2_here) = (output.wt1[i], output.wt2[i]);
                        if wt1_here.is_finite() && wt2_here.is_finite() {
                            prop_assert!(
                                (wt1_here - wt2_here).abs() <= 1e-9,
                                "When ma_len=1, WT2 should equal WT1 at idx {}: wt1={}, wt2={}",
                                i,
                                wt1_here,
                                wt2_here
                            );
                        }
                    }
                }

                // 6. Kernel under test versus the scalar reference, element by element.
                for i in 0..n {
                    let (wt1, wt1_ref) = (output.wt1[i], ref_output.wt1[i]);
                    let (wt2, wt2_ref) = (output.wt2[i], ref_output.wt2[i]);
                    let (diff, diff_ref) = (output.wt_diff[i], ref_output.wt_diff[i]);

                    if wt1.is_nan() || wt1_ref.is_nan() {
                        prop_assert!(
                            wt1.is_nan() && wt1_ref.is_nan(),
                            "NaN mismatch for WT1 at idx {}: kernel={:?}, ref={:?}",
                            i,
                            wt1,
                            wt1_ref
                        );
                    } else {
                        let ulp_diff = wt1.to_bits().abs_diff(wt1_ref.to_bits());
                        prop_assert!(
                            (wt1 - wt1_ref).abs() <= 1e-9 || ulp_diff <= 4,
                            "WT1 mismatch at idx {}: kernel={}, ref={} (ULP={})",
                            i,
                            wt1,
                            wt1_ref,
                            ulp_diff
                        );
                    }

                    if wt2.is_nan() || wt2_ref.is_nan() {
                        prop_assert!(
                            wt2.is_nan() && wt2_ref.is_nan(),
                            "NaN mismatch for WT2 at idx {}: kernel={:?}, ref={:?}",
                            i,
                            wt2,
                            wt2_ref
                        );
                    } else {
                        let ulp_diff = wt2.to_bits().abs_diff(wt2_ref.to_bits());
                        prop_assert!(
                            (wt2 - wt2_ref).abs() <= 1e-9 || ulp_diff <= 4,
                            "WT2 mismatch at idx {}: kernel={}, ref={} (ULP={})",
                            i,
                            wt2,
                            wt2_ref,
                            ulp_diff
                        );
                    }

                    if diff.is_nan() || diff_ref.is_nan() {
                        prop_assert!(
                            diff.is_nan() && diff_ref.is_nan(),
                            "NaN mismatch for WT_DIFF at idx {}: kernel={:?}, ref={:?}",
                            i,
                            diff,
                            diff_ref
                        );
                    } else {
                        let ulp_diff = diff.to_bits().abs_diff(diff_ref.to_bits());
                        prop_assert!(
                            (diff - diff_ref).abs() <= 1e-9 || ulp_diff <= 4,
                            "WT_DIFF mismatch at idx {}: kernel={}, ref={} (ULP={})",
                            i,
                            diff,
                            diff_ref,
                            ulp_diff
                        );
                    }
                }

                // 7. Everything before the expected warm-up must still be NaN.
                for i in 0..valid_start {
                    prop_assert!(
                        output.wt1[i].is_nan(),
                        "WT1 should be NaN during warmup at idx {}: got {}",
                        i,
                        output.wt1[i]
                    );
                    prop_assert!(
                        output.wt2[i].is_nan(),
                        "WT2 should be NaN during warmup at idx {}: got {}",
                        i,
                        output.wt2[i]
                    );
                    prop_assert!(
                        output.wt_diff[i].is_nan(),
                        "WT_DIFF should be NaN during warmup at idx {}: got {}",
                        i,
                        output.wt_diff[i]
                    );
                }

                Ok(())
            },
        )
        .unwrap();

    Ok(())
}
3070
3071 macro_rules! generate_all_wavetrend_tests {
3072 ($($test_fn:ident),*) => {
3073 paste::paste! {
3074 $(
3075 #[test]
3076 fn [<$test_fn _scalar_f64>]() {
3077 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
3078 }
3079 )*
3080 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3081 $(
3082 #[test]
3083 fn [<$test_fn _avx2_f64>]() {
3084 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
3085 }
3086 #[test]
3087 fn [<$test_fn _avx512_f64>]() {
3088 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
3089 }
3090 )*
3091 }
3092 }
3093 }
3094
3095 generate_all_wavetrend_tests!(
3096 check_wavetrend_partial_params,
3097 check_wavetrend_accuracy,
3098 check_wavetrend_default_candles,
3099 check_wavetrend_zero_channel,
3100 check_wavetrend_channel_exceeds_length,
3101 check_wavetrend_very_small_dataset,
3102 check_wavetrend_nan_handling,
3103 check_wavetrend_streaming,
3104 check_wavetrend_no_poison
3105 );
3106
3107 #[cfg(feature = "proptest")]
3108 generate_all_wavetrend_tests!(check_wavetrend_property);
3109
3110 fn check_batch_default_row(
3111 test: &str,
3112 kernel: Kernel,
3113 ) -> Result<(), Box<dyn std::error::Error>> {
3114 skip_if_unsupported!(kernel, test);
3115 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3116 let c = read_candles_from_csv(file)?;
3117
3118 let output = WavetrendBatchBuilder::new()
3119 .kernel(kernel)
3120 .apply_candles(&c, "hlc3")?;
3121
3122 let def = WavetrendParams::default();
3123 let (wt1_row, wt2_row, diff_row) = output.values_for(&def).expect("default row missing");
3124
3125 assert_eq!(wt1_row.len(), c.close.len());
3126 assert_eq!(wt2_row.len(), c.close.len());
3127 assert_eq!(diff_row.len(), c.close.len());
3128
3129 let expected_wt1 = [
3130 -29.02058232514538,
3131 -28.207769813591664,
3132 -31.991808642927193,
3133 -31.9218051759519,
3134 -44.956245952893866,
3135 ];
3136 let expected_wt2 = [
3137 -30.651043230696555,
3138 -28.686329669808583,
3139 -29.740053593887932,
3140 -30.707127877490105,
3141 -36.2899532572575,
3142 ];
3143
3144 let start = wt1_row.len().saturating_sub(5);
3145 for (i, &v) in wt1_row[start..].iter().enumerate() {
3146 assert!(
3147 (v - expected_wt1[i]).abs() < 1e-8,
3148 "[{test}] default-row WT1 mismatch at idx {i}: {v} vs {expected}",
3149 test = test,
3150 i = i,
3151 v = v,
3152 expected = expected_wt1[i]
3153 );
3154 }
3155 for (i, &v) in wt2_row[start..].iter().enumerate() {
3156 assert!(
3157 (v - expected_wt2[i]).abs() < 1e-6,
3158 "[{test}] default-row WT2 mismatch at idx {i}: {v} vs {expected}",
3159 test = test,
3160 i = i,
3161 v = v,
3162 expected = expected_wt2[i]
3163 );
3164 }
3165 Ok(())
3166 }
3167
3168 #[cfg(debug_assertions)]
3169 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn std::error::Error>> {
3170 skip_if_unsupported!(kernel, test);
3171
3172 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3173 let c = read_candles_from_csv(file)?;
3174
3175 let test_configs = vec![
3176 (2, 10, 2, 3, 12, 3, 1, 5, 1, 0.005, 0.015, 0.005),
3177 (5, 25, 5, 10, 30, 5, 2, 8, 2, 0.01, 0.03, 0.01),
3178 (20, 60, 10, 25, 75, 10, 5, 15, 5, 0.02, 0.05, 0.015),
3179 (2, 5, 1, 2, 5, 1, 1, 3, 1, 0.001, 0.005, 0.001),
3180 (10, 30, 10, 15, 45, 15, 3, 9, 3, 0.015, 0.045, 0.015),
3181 (50, 100, 25, 60, 120, 30, 10, 20, 5, 0.03, 0.06, 0.03),
3182 (9, 9, 0, 12, 12, 0, 3, 3, 0, 0.015, 0.015, 0.0),
3183 (1, 3, 1, 1, 3, 1, 1, 2, 1, 0.001, 0.003, 0.001),
3184 ];
3185
3186 for (cfg_idx, config) in test_configs.iter().enumerate() {
3187 let output = WavetrendBatchBuilder::new()
3188 .kernel(kernel)
3189 .channel_range(config.0, config.1, config.2)
3190 .avg_range(config.3, config.4, config.5)
3191 .ma_range(config.6, config.7, config.8)
3192 .factor_range(config.9, config.10, config.11)
3193 .apply_candles(&c, "hlc3")?;
3194
3195 for (idx, &val) in output.wt1.iter().enumerate() {
3196 if val.is_nan() {
3197 continue;
3198 }
3199
3200 let bits = val.to_bits();
3201 let row = idx / output.cols;
3202 let col = idx % output.cols;
3203 let combo = &output.combos[row];
3204
3205 if bits == 0x11111111_11111111 {
3206 panic!(
3207 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
3208 at row {} col {} (flat index {}) in wt1 output with params: channel_length={}, average_length={}, ma_length={}, factor={}",
3209 test, cfg_idx, val, bits, row, col, idx,
3210 combo.channel_length.unwrap_or(9),
3211 combo.average_length.unwrap_or(12),
3212 combo.ma_length.unwrap_or(3),
3213 combo.factor.unwrap_or(0.015)
3214 );
3215 }
3216
3217 if bits == 0x22222222_22222222 {
3218 panic!(
3219 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
3220 at row {} col {} (flat index {}) in wt1 output with params: channel_length={}, average_length={}, ma_length={}, factor={}",
3221 test, cfg_idx, val, bits, row, col, idx,
3222 combo.channel_length.unwrap_or(9),
3223 combo.average_length.unwrap_or(12),
3224 combo.ma_length.unwrap_or(3),
3225 combo.factor.unwrap_or(0.015)
3226 );
3227 }
3228
3229 if bits == 0x33333333_33333333 {
3230 panic!(
3231 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
3232 at row {} col {} (flat index {}) in wt1 output with params: channel_length={}, average_length={}, ma_length={}, factor={}",
3233 test, cfg_idx, val, bits, row, col, idx,
3234 combo.channel_length.unwrap_or(9),
3235 combo.average_length.unwrap_or(12),
3236 combo.ma_length.unwrap_or(3),
3237 combo.factor.unwrap_or(0.015)
3238 );
3239 }
3240 }
3241
3242 for (idx, &val) in output.wt2.iter().enumerate() {
3243 if val.is_nan() {
3244 continue;
3245 }
3246
3247 let bits = val.to_bits();
3248 let row = idx / output.cols;
3249 let col = idx % output.cols;
3250 let combo = &output.combos[row];
3251
3252 if bits == 0x11111111_11111111 {
3253 panic!(
3254 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
3255 at row {} col {} (flat index {}) in wt2 output with params: channel_length={}, average_length={}, ma_length={}, factor={}",
3256 test, cfg_idx, val, bits, row, col, idx,
3257 combo.channel_length.unwrap_or(9),
3258 combo.average_length.unwrap_or(12),
3259 combo.ma_length.unwrap_or(3),
3260 combo.factor.unwrap_or(0.015)
3261 );
3262 }
3263
3264 if bits == 0x22222222_22222222 {
3265 panic!(
3266 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
3267 at row {} col {} (flat index {}) in wt2 output with params: channel_length={}, average_length={}, ma_length={}, factor={}",
3268 test, cfg_idx, val, bits, row, col, idx,
3269 combo.channel_length.unwrap_or(9),
3270 combo.average_length.unwrap_or(12),
3271 combo.ma_length.unwrap_or(3),
3272 combo.factor.unwrap_or(0.015)
3273 );
3274 }
3275
3276 if bits == 0x33333333_33333333 {
3277 panic!(
3278 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
3279 at row {} col {} (flat index {}) in wt2 output with params: channel_length={}, average_length={}, ma_length={}, factor={}",
3280 test, cfg_idx, val, bits, row, col, idx,
3281 combo.channel_length.unwrap_or(9),
3282 combo.average_length.unwrap_or(12),
3283 combo.ma_length.unwrap_or(3),
3284 combo.factor.unwrap_or(0.015)
3285 );
3286 }
3287 }
3288
3289 for (idx, &val) in output.wt_diff.iter().enumerate() {
3290 if val.is_nan() {
3291 continue;
3292 }
3293
3294 let bits = val.to_bits();
3295 let row = idx / output.cols;
3296 let col = idx % output.cols;
3297 let combo = &output.combos[row];
3298
3299 if bits == 0x11111111_11111111 {
3300 panic!(
3301 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
3302 at row {} col {} (flat index {}) in wt_diff output with params: channel_length={}, average_length={}, ma_length={}, factor={}",
3303 test, cfg_idx, val, bits, row, col, idx,
3304 combo.channel_length.unwrap_or(9),
3305 combo.average_length.unwrap_or(12),
3306 combo.ma_length.unwrap_or(3),
3307 combo.factor.unwrap_or(0.015)
3308 );
3309 }
3310
3311 if bits == 0x22222222_22222222 {
3312 panic!(
3313 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
3314 at row {} col {} (flat index {}) in wt_diff output with params: channel_length={}, average_length={}, ma_length={}, factor={}",
3315 test, cfg_idx, val, bits, row, col, idx,
3316 combo.channel_length.unwrap_or(9),
3317 combo.average_length.unwrap_or(12),
3318 combo.ma_length.unwrap_or(3),
3319 combo.factor.unwrap_or(0.015)
3320 );
3321 }
3322
3323 if bits == 0x33333333_33333333 {
3324 panic!(
3325 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
3326 at row {} col {} (flat index {}) in wt_diff output with params: channel_length={}, average_length={}, ma_length={}, factor={}",
3327 test, cfg_idx, val, bits, row, col, idx,
3328 combo.channel_length.unwrap_or(9),
3329 combo.average_length.unwrap_or(12),
3330 combo.ma_length.unwrap_or(3),
3331 combo.factor.unwrap_or(0.015)
3332 );
3333 }
3334 }
3335 }
3336
3337 Ok(())
3338 }
3339
3340 #[cfg(not(debug_assertions))]
3341 fn check_batch_no_poison(
3342 _test: &str,
3343 _kernel: Kernel,
3344 ) -> Result<(), Box<dyn std::error::Error>> {
3345 Ok(())
3346 }
3347
3348 macro_rules! gen_batch_tests {
3349 ($fn_name:ident) => {
3350 paste::paste! {
3351 #[test] fn [<$fn_name _scalar>]() {
3352 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
3353 }
3354 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3355 #[test] fn [<$fn_name _avx2>]() {
3356 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
3357 }
3358 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3359 #[test] fn [<$fn_name _avx512>]() {
3360 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
3361 }
3362 #[test] fn [<$fn_name _auto_detect>]() {
3363 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
3364 }
3365 }
3366 };
3367 }
3368 gen_batch_tests!(check_batch_default_row);
3369 gen_batch_tests!(check_batch_no_poison);
3370}
3371
// Python entry point `wavetrend(data, channel_length, average_length, ma_length, factor,
// kernel=None)`.
//
// Borrows the NumPy input as a contiguous slice, validates the optional kernel name,
// runs the indicator inside `allow_threads`, and returns `(wt1, wt2, wt_diff)` as three
// new 1-D NumPy arrays. Indicator errors are surfaced as `ValueError` carrying the
// error's display text.
//
// Plain `//` comments are used on purpose so the Python-visible docstring is unchanged.
#[cfg(feature = "python")]
#[pyfunction(name = "wavetrend")]
#[pyo3(signature = (data, channel_length, average_length, ma_length, factor, kernel=None))]
pub fn wavetrend_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    channel_length: usize,
    average_length: usize,
    ma_length: usize,
    factor: f64,
    kernel: Option<&str>,
) -> PyResult<(
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
)> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, false)?;

    let input = WavetrendInput::from_slice(
        slice_in,
        WavetrendParams {
            channel_length: Some(channel_length),
            average_length: Some(average_length),
            ma_length: Some(ma_length),
            factor: Some(factor),
        },
    );

    // Only the three output vectors leave the closure.
    let computed = py.allow_threads(|| match wavetrend_with_kernel(&input, kern) {
        Ok(out) => Ok((out.wt1, out.wt2, out.wt_diff)),
        Err(err) => Err(err),
    });

    let (wt1_vec, wt2_vec, wt_diff_vec) = match computed {
        Ok(vectors) => vectors,
        Err(err) => return Err(PyValueError::new_err(err.to_string())),
    };

    Ok((
        wt1_vec.into_pyarray(py),
        wt2_vec.into_pyarray(py),
        wt_diff_vec.into_pyarray(py),
    ))
}
3411
// Python class `WavetrendStream`: a thin wrapper that owns a Rust `WavetrendStream`.
//
// Plain `//` comments are used on purpose so the Python-visible docstrings are unchanged.
#[cfg(feature = "python")]
#[pyclass(name = "WavetrendStream")]
pub struct WavetrendStreamPy {
    stream: WavetrendStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl WavetrendStreamPy {
    // Builds the underlying stream from fully specified parameters; a construction
    // failure is reported as `ValueError` with the error's display text.
    #[new]
    fn new(
        channel_length: usize,
        average_length: usize,
        ma_length: usize,
        factor: f64,
    ) -> PyResult<Self> {
        WavetrendStream::try_new(WavetrendParams {
            channel_length: Some(channel_length),
            average_length: Some(average_length),
            ma_length: Some(ma_length),
            factor: Some(factor),
        })
        .map(|stream| Self { stream })
        .map_err(|err| PyValueError::new_err(err.to_string()))
    }

    // Forwards one value to the wrapped stream and returns whatever it yields:
    // `Some((wt1, wt2, wt_diff))` or `None`.
    fn update(&mut self, value: f64) -> Option<(f64, f64, f64)> {
        let stream = &mut self.stream;
        stream.update(value)
    }
}
3443
// Python entry point `wavetrend_batch(data, channel_length_range, average_length_range,
// ma_length_range, factor_range, kernel=None)`.
//
// Each range is a `(start, end, step)` tuple. The function expands the sweep to learn the
// number of parameter combinations (rows), allocates three flat NumPy arrays of
// `rows * len(data)` elements, lets `wavetrend_batch_inner_into` fill them inside
// `allow_threads`, and returns a dict with:
//   "wt1", "wt2", "wt_diff"  -> arrays reshaped to (rows, cols)
//   "channel_lengths", "average_lengths", "ma_lengths" -> u64 arrays, one entry per row
//   "factors"                -> f64 array, one entry per row
//
// Grid-expansion errors, size overflow and indicator errors are raised as `ValueError`.
// Plain `//` comments are used on purpose so the Python-visible docstring is unchanged.
#[cfg(feature = "python")]
#[pyfunction(name = "wavetrend_batch")]
#[pyo3(signature = (data, channel_length_range, average_length_range, ma_length_range, factor_range, kernel=None))]
pub fn wavetrend_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    channel_length_range: (usize, usize, usize),
    average_length_range: (usize, usize, usize),
    ma_length_range: (usize, usize, usize),
    factor_range: (f64, f64, f64),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, true)?;

    let sweep = WavetrendBatchRange {
        channel_length: channel_length_range,
        average_length: average_length_range,
        ma_length: ma_length_range,
        factor: factor_range,
    };

    // The grid is expanded here only to size the output; the combinations that are
    // reported back come from the batch routine below.
    let rows = expand_grid(&sweep)
        .map_err(|e| PyValueError::new_err(e.to_string()))?
        .len();
    let cols = slice_in.len();

    let total = match rows.checked_mul(cols) {
        Some(count) => count,
        None => {
            return Err(PyValueError::new_err(
                "rows*cols overflow for wavetrend_batch",
            ))
        }
    };

    // NOTE(review): these arrays are created without initialisation and are handed to
    // `wavetrend_batch_inner_into` as mutable slices; this presumably relies on that
    // routine writing every element — confirm.
    let wt1_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let wt2_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let wt_diff_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };

    let slice_wt1 = unsafe { wt1_arr.as_slice_mut()? };
    let slice_wt2 = unsafe { wt2_arr.as_slice_mut()? };
    let slice_wt_diff = unsafe { wt_diff_arr.as_slice_mut()? };

    let computed = py.allow_threads(|| {
        let batch_kernel = if matches!(kern, Kernel::Auto) {
            detect_best_batch_kernel()
        } else {
            kern
        };
        // Map the batch kernel onto the per-row kernel passed to the batch routine.
        let simd = match batch_kernel {
            Kernel::Avx512Batch => Kernel::Avx512,
            Kernel::Avx2Batch => Kernel::Avx2,
            Kernel::ScalarBatch => Kernel::Scalar,
            _ => unreachable!(),
        };
        wavetrend_batch_inner_into(
            slice_in,
            &sweep,
            simd,
            true,
            slice_wt1,
            slice_wt2,
            slice_wt_diff,
        )
    });
    let combos = match computed {
        Ok(list) => list,
        Err(e) => return Err(PyValueError::new_err(e.to_string())),
    };

    let dict = PyDict::new(py);
    dict.set_item("wt1", wt1_arr.reshape((rows, cols))?)?;
    dict.set_item("wt2", wt2_arr.reshape((rows, cols))?)?;
    dict.set_item("wt_diff", wt_diff_arr.reshape((rows, cols))?)?;

    let channel_lengths: Vec<u64> = combos
        .iter()
        .map(|p| p.channel_length.unwrap() as u64)
        .collect();
    dict.set_item("channel_lengths", channel_lengths.into_pyarray(py))?;

    let average_lengths: Vec<u64> = combos
        .iter()
        .map(|p| p.average_length.unwrap() as u64)
        .collect();
    dict.set_item("average_lengths", average_lengths.into_pyarray(py))?;

    let ma_lengths: Vec<u64> = combos
        .iter()
        .map(|p| p.ma_length.unwrap() as u64)
        .collect();
    dict.set_item("ma_lengths", ma_lengths.into_pyarray(py))?;

    let factors: Vec<f64> = combos.iter().map(|p| p.factor.unwrap()).collect();
    dict.set_item("factors", factors.into_pyarray(py))?;

    Ok(dict)
}
3546
/// Runs a WaveTrend parameter sweep on a CUDA device and returns device-resident results.
///
/// Exposed to Python as `wavetrend_cuda_batch_dev`.
///
/// # Arguments
/// * `data_f32` - one-dimensional `f32` input series (must be viewable as a contiguous slice).
/// * `channel_length_range`, `average_length_range`, `ma_length_range`, `factor_range` -
///   range triples forwarded unchanged into a `WavetrendBatchRange`.
/// * `device_id` - CUDA device ordinal handed to `CudaWavetrend::new` (defaults to 0).
///
/// # Returns
/// A dict with the device arrays `wt1`, `wt2`, `wt_diff` (each wrapped in a
/// `WavetrendDeviceArrayF32Py` that keeps the CUDA context alive) plus the per-axis
/// value lists `channel_lengths`, `average_lengths`, `ma_lengths` and `factors`.
/// The axis lists are rebuilt here from the range triples, read as `(start, end, step)`;
/// they are not taken from the CUDA wrapper's own expansion.
///
/// # Errors
/// Raises `ValueError` when CUDA is unavailable, when the input is not contiguous, or
/// when the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "wavetrend_cuda_batch_dev")]
#[pyo3(signature = (data_f32, channel_length_range, average_length_range, ma_length_range, factor_range, device_id=0))]
pub fn wavetrend_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    data_f32: numpy::PyReadonlyArray1<'py, f32>,
    channel_length_range: (usize, usize, usize),
    average_length_range: (usize, usize, usize),
    ma_length_range: (usize, usize, usize),
    factor_range: (f64, f64, f64),
    device_id: usize,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::IntoPyArray;

    // Inclusive integer axis; a zero step collapses to the single start value.
    fn integer_axis((start, end, step): (usize, usize, usize)) -> Vec<usize> {
        if step == 0 {
            vec![start]
        } else {
            (start..=end).step_by(step).collect()
        }
    }

    // Inclusive floating-point axis.
    //
    // The walk uses the step magnitude and stops as soon as adding the step no longer
    // yields a strictly larger finite value. Previously a negative step with
    // `start < end`, an infinite step, or a step too small to change the running value
    // made the loop spin forever (pushing into the vector while holding the GIL).
    // For every input on which the old loop terminated the produced values are the same,
    // including the empty axis for `start > end`.
    fn float_axis((start, end, step): (f64, f64, f64)) -> Vec<f64> {
        if step.abs() < f64::EPSILON || (start - end).abs() < f64::EPSILON {
            return vec![start];
        }
        let stride = step.abs();
        let limit = end + stride * 1e-12;
        let mut axis = Vec::new();
        let mut value = start;
        while value <= limit {
            axis.push(value);
            let next = value + stride;
            if !(next > value) || !next.is_finite() {
                break;
            }
            value = next;
        }
        axis
    }

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let slice_in = data_f32.as_slice()?;
    let sweep = WavetrendBatchRange {
        channel_length: channel_length_range,
        average_length: average_length_range,
        ma_length: ma_length_range,
        factor: factor_range,
    };

    // The GPU work runs with the GIL released.
    let (batch, ctx, dev_id) = py.allow_threads(|| {
        let cuda =
            CudaWavetrend::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id();
        cuda.wavetrend_batch_dev(slice_in, &sweep)
            .map(|b| (b, ctx, dev_id))
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    let dict = PyDict::new(py);
    dict.set_item(
        "wt1",
        Py::new(
            py,
            WavetrendDeviceArrayF32Py {
                inner: batch.wt1,
                _ctx: ctx.clone(),
                device_id: dev_id,
            },
        )?,
    )?;
    dict.set_item(
        "wt2",
        Py::new(
            py,
            WavetrendDeviceArrayF32Py {
                inner: batch.wt2,
                _ctx: ctx.clone(),
                device_id: dev_id,
            },
        )?,
    )?;
    dict.set_item(
        "wt_diff",
        Py::new(
            py,
            WavetrendDeviceArrayF32Py {
                inner: batch.wt_diff,
                _ctx: ctx,
                device_id: dev_id,
            },
        )?,
    )?;

    dict.set_item(
        "channel_lengths",
        integer_axis(channel_length_range).into_pyarray(py),
    )?;
    dict.set_item(
        "average_lengths",
        integer_axis(average_length_range).into_pyarray(py),
    )?;
    dict.set_item("ma_lengths", integer_axis(ma_length_range).into_pyarray(py))?;
    dict.set_item("factors", float_axis(factor_range).into_pyarray(py))?;

    Ok(dict)
}
3655
/// Computes WaveTrend for many series with one parameter set on a CUDA device.
///
/// Exposed to Python as `wavetrend_cuda_many_series_one_param_dev`.
///
/// # Arguments
/// * `data_tm_f32` - two-dimensional `f32` array; its flattened contents are handed to
///   `wavetrend_many_series_one_param_time_major_dev` together with `shape[1]` and
///   `shape[0]` (in that order).
/// * `channel_length`, `average_length`, `ma_length`, `factor` - the single parameter set.
/// * `device_id` - CUDA device ordinal handed to `CudaWavetrend::new` (defaults to 0).
///
/// # Returns
/// A dict holding the device arrays `wt1`, `wt2`, `wt_diff` followed by the echoed
/// `rows`, `cols`, `channel_length`, `average_length`, `ma_length` and `factor`.
///
/// # Errors
/// Raises `ValueError` when CUDA is unavailable, the array is not 2-D or not contiguous,
/// or the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "wavetrend_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, channel_length, average_length, ma_length, factor, device_id=0))]
pub fn wavetrend_cuda_many_series_one_param_dev_py<'py>(
    py: Python<'py>,
    data_tm_f32: numpy::PyReadonlyArray2<'py, f32>,
    channel_length: usize,
    average_length: usize,
    ma_length: usize,
    factor: f64,
    device_id: usize,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::PyUntypedArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let (rows, cols) = match data_tm_f32.shape() {
        &[r, c] => (r, c),
        _ => return Err(PyValueError::new_err("expected 2D array (rows x cols)")),
    };
    let flat = data_tm_f32.as_slice()?;

    let params = WavetrendParams {
        channel_length: Some(channel_length),
        average_length: Some(average_length),
        ma_length: Some(ma_length),
        factor: Some(factor),
    };

    // Device work happens with the GIL released; errors are converted to ValueError.
    let (wt1, wt2, wt_diff, ctx, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let cuda =
            CudaWavetrend::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id();
        let (first, second, third) = cuda
            .wavetrend_many_series_one_param_time_major_dev(flat, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((first, second, third, ctx, dev_id))
    })?;

    let dict = PyDict::new(py);

    // Each output buffer gets its own handle on the CUDA context.
    for (key, inner) in [("wt1", wt1), ("wt2", wt2), ("wt_diff", wt_diff)] {
        let wrapped = Py::new(
            py,
            WavetrendDeviceArrayF32Py {
                inner,
                _ctx: ctx.clone(),
                device_id: dev_id,
            },
        )?;
        dict.set_item(key, wrapped)?;
    }

    dict.set_item("rows", rows)?;
    dict.set_item("cols", cols)?;
    dict.set_item("channel_length", channel_length)?;
    dict.set_item("average_length", average_length)?;
    dict.set_item("ma_length", ma_length)?;
    dict.set_item("factor", factor)?;

    Ok(dict)
}
3742
/// Computes WaveTrend for `data` and returns all three outputs in one flat vector.
///
/// The result has `3 * data.len()` elements. It is split into three consecutive
/// segments of `data.len()` values each, which are passed to `wavetrend_into_slice`
/// as the wt1, wt2 and wt_diff destinations (in that order).
///
/// # Errors
/// Any error from `wavetrend_into_slice` is returned as a string `JsValue`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn wavetrend_js(
    data: &[f64],
    channel_length: usize,
    average_length: usize,
    ma_length: usize,
    factor: f64,
) -> Result<Vec<f64>, JsValue> {
    let n = data.len();
    let input = WavetrendInput::from_slice(
        data,
        WavetrendParams {
            channel_length: Some(channel_length),
            average_length: Some(average_length),
            ma_length: Some(ma_length),
            factor: Some(factor),
        },
    );

    let mut packed = vec![0.0; n * 3];
    {
        // Carve the flat buffer into the three per-output windows.
        let (wt1_seg, tail) = packed.split_at_mut(n);
        let (wt2_seg, diff_seg) = tail.split_at_mut(n);

        if let Err(err) = wavetrend_into_slice(wt1_seg, wt2_seg, diff_seg, &input, Kernel::Auto) {
            return Err(JsValue::from_str(&err.to_string()));
        }
    }

    Ok(packed)
}
3769
/// Computes WaveTrend from a raw input buffer into three raw output buffers.
///
/// # Arguments
/// * `in_ptr` - pointer to `len` readable `f64` values.
/// * `wt1_ptr`, `wt2_ptr`, `wt_diff_ptr` - pointers to `len` writable `f64` slots each.
/// * `len` - number of elements in every buffer.
/// * `channel_length`, `average_length`, `ma_length`, `factor` - indicator parameters.
///
/// # Caller contract
/// All four pointers must be valid, properly aligned and cover `len` elements.
/// Buffers may overlap: whenever any two of the four regions share memory the result
/// is computed into scratch storage first and then copied out in the order
/// wt1, wt2, wt_diff (so for aliased outputs the later copy wins).
///
/// # Errors
/// Returns a string `JsValue` when a pointer is null, when `len * 3` overflows, or when
/// `wavetrend_into_slice` fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn wavetrend_into(
    in_ptr: *const f64,
    wt1_ptr: *mut f64,
    wt2_ptr: *mut f64,
    wt_diff_ptr: *mut f64,
    len: usize,
    channel_length: usize,
    average_length: usize,
    ma_length: usize,
    factor: f64,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || wt1_ptr.is_null() || wt2_ptr.is_null() || wt_diff_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // Detect any pairwise overlap between the four regions (not just identical start
    // addresses): materialising overlapping `&`/`&mut` slices would be undefined
    // behaviour, so such calls are served through scratch storage instead.
    let span = len.saturating_mul(std::mem::size_of::<f64>());
    let starts = [
        in_ptr as usize,
        wt1_ptr as usize,
        wt2_ptr as usize,
        wt_diff_ptr as usize,
    ];
    let clash = |a: usize, b: usize| {
        a == b || (a < b.saturating_add(span) && b < a.saturating_add(span))
    };
    let needs_temp = starts
        .iter()
        .enumerate()
        .any(|(i, &a)| starts[i + 1..].iter().any(|&b| clash(a, b)));

    // SAFETY: the pointers are non-null and the caller guarantees each covers `len`
    // elements. Mutable slices are only created when no two regions overlap; otherwise
    // outputs are written with raw-pointer copies from a freshly allocated buffer.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let params = WavetrendParams {
            channel_length: Some(channel_length),
            average_length: Some(average_length),
            ma_length: Some(ma_length),
            factor: Some(factor),
        };
        let input = WavetrendInput::from_slice(data, params);

        if needs_temp {
            let total = len
                .checked_mul(3)
                .ok_or_else(|| JsValue::from_str("Length overflow in wavetrend_into"))?;
            let mut temp = vec![0.0_f64; total];
            {
                let (temp_wt1, rest) = temp.split_at_mut(len);
                let (temp_wt2, temp_wt_diff) = rest.split_at_mut(len);

                wavetrend_into_slice(temp_wt1, temp_wt2, temp_wt_diff, &input, Kernel::Auto)
                    .map_err(|e| JsValue::from_str(&e.to_string()))?;
            }

            // `temp` is a separate heap allocation, so it never overlaps a destination.
            // `input` is not touched past this point.
            let src = temp.as_ptr();
            std::ptr::copy_nonoverlapping(src, wt1_ptr, len);
            std::ptr::copy_nonoverlapping(src.add(len), wt2_ptr, len);
            std::ptr::copy_nonoverlapping(src.add(2 * len), wt_diff_ptr, len);
        } else {
            let wt1_out = std::slice::from_raw_parts_mut(wt1_ptr, len);
            let wt2_out = std::slice::from_raw_parts_mut(wt2_ptr, len);
            let wt_diff_out = std::slice::from_raw_parts_mut(wt_diff_ptr, len);

            wavetrend_into_slice(wt1_out, wt2_out, wt_diff_out, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }

        Ok(())
    }
}
3828
/// Allocates room for `len` `f64` values and hands the raw pointer to the caller.
///
/// The memory is uninitialised and is intentionally not released when this function
/// returns; pass the pointer and the same `len` to `wavetrend_free` to release it.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn wavetrend_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop suppresses the Vec destructor so the allocation outlives this call.
    let mut storage = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    storage.as_mut_ptr()
}
3837
/// Releases a buffer previously obtained from `wavetrend_alloc`.
///
/// `ptr` must be a pointer returned by `wavetrend_alloc` and `len` the length that was
/// passed to that call; a null `ptr` is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn wavetrend_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: the caller passes a pointer/capacity pair produced by `wavetrend_alloc`.
    // The Vec is rebuilt with length 0 because that allocation is handed out
    // uninitialised; claiming `len` initialised elements would break the
    // `Vec::from_raw_parts` contract. Dropping it frees the same allocation.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
3847
/// Sweep configuration deserialised from the JS `config` object by `wavetrend_batch_js`.
///
/// Each field is copied unchanged into the matching `WavetrendBatchRange` field.
/// NOTE(review): the triples are presumably `(start, end, step)` — confirm against
/// the range expansion used by the batch kernel.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct WavetrendBatchConfig {
    /// Range triple for the channel length.
    pub channel_length_range: (usize, usize, usize),
    /// Range triple for the average length.
    pub average_length_range: (usize, usize, usize),
    /// Range triple for the moving-average length.
    pub ma_length_range: (usize, usize, usize),
    /// Range triple for the factor.
    pub factor_range: (f64, f64, f64),
}
3856
/// Result object serialised back to JS by `wavetrend_batch_js`.
///
/// NOTE(review): the value vectors are presumably flattened `rows x cols` matrices in
/// row-major order — confirm against `wavetrend_batch_with_kernel`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct WavetrendBatchJsOutput {
    /// `wt1` values of the batch output.
    pub wt1_values: Vec<f64>,
    /// `wt2` values of the batch output.
    pub wt2_values: Vec<f64>,
    /// `wt_diff` values of the batch output.
    pub wt_diff_values: Vec<f64>,
    /// Channel length of each parameter combination, one entry per combination.
    pub channel_lengths: Vec<usize>,
    /// Average length of each parameter combination, one entry per combination.
    pub average_lengths: Vec<usize>,
    /// Moving-average length of each parameter combination, one entry per combination.
    pub ma_lengths: Vec<usize>,
    /// Factor of each parameter combination, one entry per combination.
    pub factors: Vec<f64>,
    /// Number of parameter combinations.
    pub rows: usize,
    /// Length of the input series.
    pub cols: usize,
}
3870
/// Runs a WaveTrend parameter sweep over `data`; exported to JS as `wavetrend_batch`.
///
/// `config` is deserialised into a `WavetrendBatchConfig`, whose ranges are forwarded to
/// `wavetrend_batch_with_kernel` with `Kernel::Auto`. The result is serialised as a
/// `WavetrendBatchJsOutput` carrying the three value vectors, the parameters of every
/// combination, `rows` (number of combinations) and `cols` (input length).
///
/// # Errors
/// Returns a string `JsValue` when `config` cannot be deserialised, when the batch
/// computation fails, or when the output cannot be serialised.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = wavetrend_batch)]
pub fn wavetrend_batch_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let parsed: WavetrendBatchConfig = match serde_wasm_bindgen::from_value(config) {
        Ok(cfg) => cfg,
        Err(err) => return Err(JsValue::from_str(&err.to_string())),
    };

    // The range triples are plain `Copy` tuples, so they can be handed over directly.
    let sweep = WavetrendBatchRange {
        channel_length: parsed.channel_length_range,
        average_length: parsed.average_length_range,
        ma_length: parsed.ma_length_range,
        factor: parsed.factor_range,
    };

    let result = match wavetrend_batch_with_kernel(data, &sweep, Kernel::Auto) {
        Ok(out) => out,
        Err(err) => return Err(JsValue::from_str(&err.to_string())),
    };

    // Gather the per-combination parameters in a single pass.
    let rows = result.combos.len();
    let mut channel_lengths = Vec::with_capacity(rows);
    let mut average_lengths = Vec::with_capacity(rows);
    let mut ma_lengths = Vec::with_capacity(rows);
    let mut factors = Vec::with_capacity(rows);
    for combo in &result.combos {
        channel_lengths.push(combo.channel_length.unwrap());
        average_lengths.push(combo.average_length.unwrap());
        ma_lengths.push(combo.ma_length.unwrap());
        factors.push(combo.factor.unwrap());
    }

    let payload = WavetrendBatchJsOutput {
        wt1_values: result.wt1,
        wt2_values: result.wt2,
        wt_diff_values: result.wt_diff,
        channel_lengths,
        average_lengths,
        ma_lengths,
        factors,
        rows,
        cols: data.len(),
    };

    match serde_wasm_bindgen::to_value(&payload) {
        Ok(value) => Ok(value),
        Err(e) => Err(JsValue::from_str(&format!("Serialization error: {}", e))),
    }
}