1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::cuda_available;
3#[cfg(all(feature = "python", feature = "cuda"))]
4use crate::cuda::moving_averages::CudaDma;
5#[cfg(all(feature = "python", feature = "cuda"))]
6use crate::cuda::moving_averages::DeviceArrayF32;
7#[cfg(all(feature = "python", feature = "cuda"))]
8use cust::context::Context;
9#[cfg(all(feature = "python", feature = "cuda"))]
10use numpy::PyUntypedArrayMethods;
11#[cfg(feature = "python")]
12use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1, PyReadonlyArray2};
13#[cfg(feature = "python")]
14use pyo3::exceptions::PyValueError;
15#[cfg(feature = "python")]
16use pyo3::prelude::*;
17#[cfg(feature = "python")]
18use pyo3::types::PyDict;
19#[cfg(all(feature = "python", feature = "cuda"))]
20use std::sync::Arc;
21
22#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
23use serde::{Deserialize, Serialize};
24#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
25use wasm_bindgen::prelude::*;
26
27use crate::utilities::data_loader::{source_type, Candles};
28use crate::utilities::enums::Kernel;
29use crate::utilities::helpers::{
30 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
31 make_uninit_matrix,
32};
33#[cfg(feature = "python")]
34use crate::utilities::kernel_validation::validate_kernel;
35use aligned_vec::{AVec, CACHELINE_ALIGN};
36
37#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
38use core::arch::x86_64::*;
39
40#[cfg(not(target_arch = "wasm32"))]
41use rayon::prelude::*;
42
43use std::convert::AsRef;
44use std::error::Error;
45use std::mem::MaybeUninit;
46use thiserror::Error;
47
48impl<'a> AsRef<[f64]> for DmaInput<'a> {
49 #[inline(always)]
50 fn as_ref(&self) -> &[f64] {
51 match &self.data {
52 DmaData::Slice(slice) => slice,
53 DmaData::Candles { candles, source } => source_type(candles, source),
54 }
55 }
56}
57
58#[derive(Debug, Clone)]
59pub enum DmaData<'a> {
60 Candles {
61 candles: &'a Candles,
62 source: &'a str,
63 },
64 Slice(&'a [f64]),
65}
66
/// Result of a DMA computation.
///
/// `values` has one entry per input sample; warmup positions hold NaN.
#[derive(Debug, Clone)]
pub struct DmaOutput {
    /// Indicator values, index-aligned with the input series.
    pub values: Vec<f64>,
}
71
/// Tunable DMA parameters. `None` fields fall back to the defaults applied by
/// `DmaInput`'s getters.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct DmaParams {
    /// Length of the Hull moving-average stage (default 7).
    pub hull_length: Option<usize>,
    /// Length of the error-correcting EMA stage (default 20).
    pub ema_length: Option<usize>,
    /// Upper bound of the gain search, in tenths (default 50, i.e. gains 0.0..=5.0).
    pub ema_gain_limit: Option<usize>,
    /// Averaging used inside the Hull stage: `"WMA"` or `"EMA"` (default `"WMA"`).
    pub hull_ma_type: Option<String>,
}

impl Default for DmaParams {
    /// Every parameter set explicitly to its standard value.
    fn default() -> Self {
        let hull_ma_type = Some(String::from("WMA"));
        DmaParams {
            hull_length: Some(7),
            ema_length: Some(20),
            ema_gain_limit: Some(50),
            hull_ma_type,
        }
    }
}
94
95#[derive(Debug, Clone)]
96pub struct DmaInput<'a> {
97 pub data: DmaData<'a>,
98 pub params: DmaParams,
99}
100
101impl<'a> DmaInput<'a> {
102 #[inline]
103 pub fn from_candles(c: &'a Candles, s: &'a str, p: DmaParams) -> Self {
104 Self {
105 data: DmaData::Candles {
106 candles: c,
107 source: s,
108 },
109 params: p,
110 }
111 }
112
113 #[inline]
114 pub fn from_slice(sl: &'a [f64], p: DmaParams) -> Self {
115 Self {
116 data: DmaData::Slice(sl),
117 params: p,
118 }
119 }
120
121 #[inline]
122 pub fn with_default_candles(c: &'a Candles) -> Self {
123 Self::from_candles(c, "close", DmaParams::default())
124 }
125
126 #[inline]
127 pub fn get_hull_length(&self) -> usize {
128 self.params.hull_length.unwrap_or(7)
129 }
130
131 #[inline]
132 pub fn get_ema_length(&self) -> usize {
133 self.params.ema_length.unwrap_or(20)
134 }
135
136 #[inline]
137 pub fn get_ema_gain_limit(&self) -> usize {
138 self.params.ema_gain_limit.unwrap_or(50)
139 }
140
141 #[inline]
142 pub fn get_hull_ma_type(&self) -> String {
143 self.params
144 .hull_ma_type
145 .clone()
146 .unwrap_or_else(|| "WMA".to_string())
147 }
148
149 #[inline]
150 pub fn hull_ma_type_str(&self) -> &str {
151 self.params.hull_ma_type.as_deref().unwrap_or("WMA")
152 }
153}
154
155#[derive(Clone, Debug)]
156pub struct DmaBuilder {
157 hull_length: Option<usize>,
158 ema_length: Option<usize>,
159 ema_gain_limit: Option<usize>,
160 hull_ma_type: Option<String>,
161 kernel: Kernel,
162}
163
164impl Default for DmaBuilder {
165 fn default() -> Self {
166 Self {
167 hull_length: None,
168 ema_length: None,
169 ema_gain_limit: None,
170 hull_ma_type: None,
171 kernel: Kernel::Auto,
172 }
173 }
174}
175
176impl DmaBuilder {
177 #[inline(always)]
178 pub fn new() -> Self {
179 Self::default()
180 }
181
182 #[inline(always)]
183 pub fn hull_length(mut self, val: usize) -> Self {
184 self.hull_length = Some(val);
185 self
186 }
187
188 #[inline(always)]
189 pub fn ema_length(mut self, val: usize) -> Self {
190 self.ema_length = Some(val);
191 self
192 }
193
194 #[inline(always)]
195 pub fn ema_gain_limit(mut self, val: usize) -> Self {
196 self.ema_gain_limit = Some(val);
197 self
198 }
199
200 #[inline(always)]
201 pub fn hull_ma_type(mut self, val: String) -> Self {
202 self.hull_ma_type = Some(val);
203 self
204 }
205
206 #[inline(always)]
207 pub fn kernel(mut self, k: Kernel) -> Self {
208 self.kernel = k;
209 self
210 }
211
212 #[inline(always)]
213 pub fn apply(self, c: &Candles) -> Result<DmaOutput, DmaError> {
214 let p = DmaParams {
215 hull_length: self.hull_length,
216 ema_length: self.ema_length,
217 ema_gain_limit: self.ema_gain_limit,
218 hull_ma_type: self.hull_ma_type,
219 };
220 let i = DmaInput::from_candles(c, "close", p);
221 dma_with_kernel(&i, self.kernel)
222 }
223
224 #[inline(always)]
225 pub fn apply_slice(self, d: &[f64]) -> Result<DmaOutput, DmaError> {
226 let p = DmaParams {
227 hull_length: self.hull_length,
228 ema_length: self.ema_length,
229 ema_gain_limit: self.ema_gain_limit,
230 hull_ma_type: self.hull_ma_type,
231 };
232 let i = DmaInput::from_slice(d, p);
233 dma_with_kernel(&i, self.kernel)
234 }
235
236 #[inline(always)]
237 pub fn into_stream(self) -> Result<DmaStream, DmaError> {
238 let p = DmaParams {
239 hull_length: self.hull_length,
240 ema_length: self.ema_length,
241 ema_gain_limit: self.ema_gain_limit,
242 hull_ma_type: self.hull_ma_type,
243 };
244 DmaStream::try_new(p)
245 }
246}
247
248#[derive(Debug, Error)]
249pub enum DmaError {
250 #[error("dma: Input data slice is empty.")]
251 EmptyInputData,
252
253 #[error("dma: All values are NaN.")]
254 AllValuesNaN,
255
256 #[error("dma: Invalid period: period = {period}, data length = {data_len}")]
257 InvalidPeriod { period: usize, data_len: usize },
258
259 #[error("dma: Not enough valid data: needed = {needed}, valid = {valid}")]
260 NotEnoughValidData { needed: usize, valid: usize },
261
262 #[error("dma: Invalid Hull MA type: {value}. Must be 'WMA' or 'EMA'.")]
263 InvalidHullMAType { value: String },
264
265 #[error("dma: Output slice length mismatch: expected = {expected}, got = {got}")]
266 OutputLengthMismatch { expected: usize, got: usize },
267
268 #[error("dma: Invalid range expansion: start = {start}, end = {end}, step = {step}")]
269 InvalidRange {
270 start: usize,
271 end: usize,
272 step: usize,
273 },
274
275 #[error("dma: Invalid kernel for batch path: {0:?}")]
276 InvalidKernelForBatch(Kernel),
277}
278
279#[inline(always)]
280pub fn dma(input: &DmaInput) -> Result<DmaOutput, DmaError> {
281 dma_with_kernel(input, Kernel::Auto)
282}
283
284#[inline(always)]
285pub fn dma_with_kernel(input: &DmaInput, kernel: Kernel) -> Result<DmaOutput, DmaError> {
286 let (data, hull_len, ema_len, ema_gain_limit, hull_ma_type, first, chosen) =
287 dma_prepare(input, kernel)?;
288
289 let sqrt_len = (hull_len as f64).sqrt().round() as usize;
290 let warmup_end = first + hull_len.max(ema_len) + sqrt_len - 1;
291
292 let mut out = alloc_with_nan_prefix(data.len(), warmup_end);
293 dma_compute_into(
294 data,
295 hull_len,
296 ema_len,
297 ema_gain_limit,
298 &hull_ma_type,
299 first,
300 chosen,
301 &mut out,
302 );
303 Ok(DmaOutput { values: out })
304}
305
306#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
307#[inline(always)]
308pub fn dma_into(input: &DmaInput, out: &mut [f64]) -> Result<(), DmaError> {
309 let (data, hull_len, ema_len, ema_gain_limit, hull_ma_type, first, chosen) =
310 dma_prepare(input, Kernel::Auto)?;
311
312 if out.len() != data.len() {
313 return Err(DmaError::OutputLengthMismatch {
314 expected: data.len(),
315 got: out.len(),
316 });
317 }
318
319 let sqrt_len = (hull_len as f64).sqrt().round() as usize;
320 let warmup_end = first + hull_len.max(ema_len) + sqrt_len - 1;
321 let end = warmup_end.min(out.len());
322 let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
323 for v in &mut out[..end] {
324 *v = qnan;
325 }
326
327 dma_compute_into(
328 data,
329 hull_len,
330 ema_len,
331 ema_gain_limit,
332 &hull_ma_type,
333 first,
334 chosen,
335 out,
336 );
337 Ok(())
338}
339
340#[inline(always)]
341pub fn dma_into_slice(dst: &mut [f64], input: &DmaInput, kern: Kernel) -> Result<(), DmaError> {
342 let (data, hull_len, ema_len, ema_gain_limit, hull_ma_type, first, chosen) =
343 dma_prepare(input, kern)?;
344
345 if dst.len() != data.len() {
346 return Err(DmaError::OutputLengthMismatch {
347 expected: data.len(),
348 got: dst.len(),
349 });
350 }
351
352 dma_compute_into(
353 data,
354 hull_len,
355 ema_len,
356 ema_gain_limit,
357 &hull_ma_type,
358 first,
359 chosen,
360 dst,
361 );
362
363 let sqrt_len = (hull_len as f64).sqrt().round() as usize;
364 let warmup_end = first + hull_len.max(ema_len) + sqrt_len - 1;
365 let end = warmup_end.min(dst.len());
366 for v in &mut dst[..end] {
367 *v = f64::NAN;
368 }
369 Ok(())
370}
371
372#[inline(always)]
373fn dma_prepare<'a>(
374 input: &'a DmaInput,
375 kernel: Kernel,
376) -> Result<(&'a [f64], usize, usize, usize, &'a str, usize, Kernel), DmaError> {
377 let data: &[f64] = input.as_ref();
378 let len = data.len();
379 if len == 0 {
380 return Err(DmaError::EmptyInputData);
381 }
382
383 let first = data
384 .iter()
385 .position(|x| !x.is_nan())
386 .ok_or(DmaError::AllValuesNaN)?;
387 let hull_length = input.get_hull_length();
388 let ema_length = input.get_ema_length();
389 let ema_gain_limit = input.get_ema_gain_limit();
390 let hull_ma_type = input.hull_ma_type_str();
391
392 if hull_length == 0 || hull_length > len {
393 return Err(DmaError::InvalidPeriod {
394 period: hull_length,
395 data_len: len,
396 });
397 }
398 if ema_length == 0 || ema_length > len {
399 return Err(DmaError::InvalidPeriod {
400 period: ema_length,
401 data_len: len,
402 });
403 }
404
405 let sqrt_len = (hull_length as f64).sqrt().round() as usize;
406 let needed = hull_length.max(ema_length) + sqrt_len;
407 if len - first < needed {
408 return Err(DmaError::NotEnoughValidData {
409 needed,
410 valid: len - first,
411 });
412 }
413 if hull_ma_type != "WMA" && hull_ma_type != "EMA" {
414 return Err(DmaError::InvalidHullMAType {
415 value: hull_ma_type.to_string(),
416 });
417 }
418 let chosen = match kernel {
419 Kernel::Auto => Kernel::Scalar,
420 k => k,
421 };
422 Ok((
423 data,
424 hull_length,
425 ema_length,
426 ema_gain_limit,
427 hull_ma_type,
428 first,
429 chosen,
430 ))
431}
432
433#[inline(always)]
434fn dma_compute_into(
435 data: &[f64],
436 hull_length: usize,
437 ema_length: usize,
438 ema_gain_limit: usize,
439 hull_ma_type: &str,
440 first: usize,
441 kernel: Kernel,
442 out: &mut [f64],
443) {
444 unsafe {
445 #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
446 {
447 if matches!(kernel, Kernel::Scalar | Kernel::ScalarBatch) {
448 dma_simd128(
449 data,
450 hull_length,
451 ema_length,
452 ema_gain_limit,
453 hull_ma_type,
454 first,
455 out,
456 );
457 return;
458 }
459 }
460
461 match kernel {
462 Kernel::Scalar | Kernel::ScalarBatch => dma_scalar(
463 data,
464 hull_length,
465 ema_length,
466 ema_gain_limit,
467 hull_ma_type,
468 first,
469 out,
470 ),
471 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
472 Kernel::Avx2 | Kernel::Avx2Batch => dma_avx2(
473 data,
474 hull_length,
475 ema_length,
476 ema_gain_limit,
477 hull_ma_type,
478 first,
479 out,
480 ),
481 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
482 Kernel::Avx512 | Kernel::Avx512Batch => dma_avx512(
483 data,
484 hull_length,
485 ema_length,
486 ema_gain_limit,
487 hull_ma_type,
488 first,
489 out,
490 ),
491 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
492 Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => dma_scalar(
493 data,
494 hull_length,
495 ema_length,
496 ema_gain_limit,
497 hull_ma_type,
498 first,
499 out,
500 ),
501 _ => unreachable!(),
502 }
503 }
504}
505
/// Scalar reference kernel for the DMA indicator.
///
/// For every index `i` in `first..data.len()` at which both components are finite,
/// writes `out[i] = 0.5 * (hull + ec)`:
///
/// - `hull` is a Hull-style average. It computes `2 * MA(half) - MA(hull_length)` with
///   `half = hull_length / 2`, then smooths that difference with an MA of length
///   `round(sqrt(hull_length))`. The MAs are WMAs when `hull_ma_type == "WMA"` and
///   SMA-seeded EMAs for any other value.
/// - `ec` is an error-correcting EMA of length `ema_length`. At each step its gain is
///   chosen from `{0.0, 0.1, ..., 0.1 * ema_gain_limit}` to minimise `|x - ec|`.
///
/// Indices that are never written keep whatever `out` already held. These are the
/// warmup region, plus every index after a NaN sample past `first` has propagated into
/// the recursive state. Callers are therefore expected to pre-fill `out`; the public
/// entry points fill a NaN prefix.
///
/// # Panics
/// Panics if `out` is shorter than `data` and an index past `out.len()` gets written.
#[inline]
pub fn dma_scalar(
    data: &[f64],
    hull_length: usize,
    ema_length: usize,
    ema_gain_limit: usize,
    hull_ma_type: &str,
    first: usize,
    out: &mut [f64],
) {
    let n = data.len();
    if n == 0 {
        return;
    }

    // Left-to-right plain sum (kept explicit so the rounding order is fixed).
    let plain_sum = |w: &[f64]| -> f64 {
        let mut s = 0.0;
        for &v in w {
            s += v;
        }
        s
    };
    // Plain and linearly-weighted (weights 1..=len, oldest first) sums of a window.
    let wma_seed = |w: &[f64]| -> (f64, f64) {
        let mut s = 0.0;
        let mut ws = 0.0;
        for (j, &v) in w.iter().enumerate() {
            s += v;
            ws += (j + 1) as f64 * v;
        }
        (s, ws)
    };

    // Error-correcting EMA: `e0` is the plain EMA, `ec` the gain-corrected one.
    let alpha_e = 2.0 / (ema_length as f64 + 1.0);
    let one_minus_alpha_e = 1.0 - alpha_e;
    let i0_e = first + ema_length.saturating_sub(1);
    let mut e0_prev = 0.0;
    let mut e0_init_done = false;
    let mut ec_prev = 0.0;
    let mut ec_init_done = false;

    // Hull stage geometry: the two source MAs become ready at `i0_half` / `i0_full`.
    let half = hull_length / 2;
    let sqrt_len = (hull_length as f64).sqrt().round() as usize;
    let i0_half = first + half.saturating_sub(1);
    let i0_full = first + hull_length.saturating_sub(1);
    let is_wma = hull_ma_type == "WMA";

    // WMA normalisers (sum of weights 1..=p), computed once instead of per sample.
    let wsum = |p: usize| -> f64 { (p * (p + 1)) as f64 / 2.0 };
    let norm_half = wsum(half).max(1.0);
    let norm_full = wsum(hull_length).max(1.0);
    let norm_sqrt = wsum(sqrt_len).max(1.0);

    // EMA smoothing factors for the "EMA" Hull variant and the sqrt-length stage.
    let alpha_half = if half > 0 {
        2.0 / (half as f64 + 1.0)
    } else {
        0.0
    };
    let one_minus_alpha_half = 1.0 - alpha_half;
    let alpha_full = if hull_length > 0 {
        2.0 / (hull_length as f64 + 1.0)
    } else {
        0.0
    };
    let one_minus_alpha_full = 1.0 - alpha_full;
    let alpha_sqrt = if sqrt_len > 0 {
        2.0 / (sqrt_len as f64 + 1.0)
    } else {
        0.0
    };
    let one_minus_alpha_sqrt = 1.0 - alpha_sqrt;

    // Running WMA state: `a_*` is the plain window sum, `s_*` the weighted sum.
    let mut a_half = 0.0;
    let mut s_half = 0.0;
    let mut half_ready = false;
    let mut a_full = 0.0;
    let mut s_full = 0.0;
    let mut full_ready = false;

    // Running EMA state for the "EMA" Hull variant.
    let mut e_half_prev = 0.0;
    let mut e_half_init_done = false;
    let mut e_full_prev = 0.0;
    let mut e_full_init_done = false;

    // The last `sqrt_len` raw Hull differences (a ring once full) and their MA state.
    let mut diff_ring: Vec<f64> = Vec::with_capacity(sqrt_len.max(1));
    let mut diff_pos: usize = 0;
    let mut diff_filled = 0usize;
    let mut a_diff = 0.0;
    let mut s_diff = 0.0;
    let mut diff_ema = 0.0;
    let mut diff_sum_seed = 0.0;

    let mut hull_val = f64::NAN;

    for i in first..n {
        let x = data[i];

        // Plain EMA, seeded with the SMA of the first `ema_length` samples.
        if !e0_init_done {
            if i >= i0_e {
                e0_prev = plain_sum(&data[i + 1 - ema_length..=i]) / ema_length as f64;
                e0_init_done = true;
            }
        } else {
            e0_prev = x.mul_add(alpha_e, one_minus_alpha_e * e0_prev);
        }

        let mut diff_now = f64::NAN;

        if is_wma {
            if half > 0 {
                if !half_ready {
                    if i >= i0_half {
                        let (sum, weighted) = wma_seed(&data[i + 1 - half..=i]);
                        a_half = sum;
                        s_half = weighted;
                        half_ready = true;
                    }
                } else {
                    // O(1) slide: S' = S - A + p*x, A' = A + x - x[i-p].
                    let a_prev = a_half;
                    a_half = a_prev + x - data[i - half];
                    s_half = s_half + (half as f64) * x - a_prev;
                }
            }

            if hull_length > 0 {
                if !full_ready {
                    if i >= i0_full {
                        let (sum, weighted) = wma_seed(&data[i + 1 - hull_length..=i]);
                        a_full = sum;
                        s_full = weighted;
                        full_ready = true;
                    }
                } else {
                    let a_prev = a_full;
                    a_full = a_prev + x - data[i - hull_length];
                    s_full = s_full + (hull_length as f64) * x - a_prev;
                }
            }

            if half_ready && full_ready {
                let w_half = s_half / norm_half;
                let w_full = s_full / norm_full;
                diff_now = 2.0 * w_half - w_full;
            }
        } else {
            if half > 0 {
                if !e_half_init_done {
                    if i >= i0_half {
                        e_half_prev = plain_sum(&data[i + 1 - half..=i]) / half as f64;
                        e_half_init_done = true;
                    }
                } else {
                    e_half_prev = x.mul_add(alpha_half, one_minus_alpha_half * e_half_prev);
                }
            }

            if hull_length > 0 {
                if !e_full_init_done {
                    if i >= i0_full {
                        e_full_prev =
                            plain_sum(&data[i + 1 - hull_length..=i]) / hull_length as f64;
                        e_full_init_done = true;
                    }
                } else {
                    e_full_prev = x.mul_add(alpha_full, one_minus_alpha_full * e_full_prev);
                }
            }

            if e_half_init_done && e_full_init_done {
                diff_now = 2.0 * e_half_prev - e_full_prev;
            }
        }

        // Smooth the raw Hull difference over the last `sqrt_len` samples.
        if diff_now.is_finite() && sqrt_len > 0 {
            if diff_filled < sqrt_len {
                diff_ring.push(diff_now);
                diff_sum_seed += diff_now;
                diff_filled += 1;

                if diff_filled == sqrt_len {
                    if is_wma {
                        let (sum, weighted) = wma_seed(&diff_ring);
                        a_diff = sum;
                        s_diff = weighted;
                        hull_val = s_diff / norm_sqrt;
                    } else {
                        diff_ema = diff_sum_seed / sqrt_len as f64;
                        hull_val = diff_ema;
                    }
                }
            } else {
                // `diff_pos` always points at the oldest difference in the ring.
                let old = diff_ring[diff_pos];
                diff_ring[diff_pos] = diff_now;
                diff_pos = (diff_pos + 1) % sqrt_len;

                if is_wma {
                    let a_prev = a_diff;
                    a_diff = a_prev + diff_now - old;
                    s_diff = s_diff + (sqrt_len as f64) * diff_now - a_prev;
                    hull_val = s_diff / norm_sqrt;
                } else {
                    diff_ema = diff_now.mul_add(alpha_sqrt, one_minus_alpha_sqrt * diff_ema);
                    hull_val = diff_ema;
                }
            }
        }

        // Error-correcting EMA. ec(g) = base + t*g, so |x - ec(g)| = |r - t*g| is
        // V-shaped in g; the best point on the 0.1 grid is the floor or the ceiling of
        // the continuous optimum 10*r/t, clamped to [0, ema_gain_limit].
        let mut ec_now = f64::NAN;
        if e0_init_done {
            if !ec_init_done {
                ec_prev = e0_prev;
                ec_init_done = true;
                ec_now = ec_prev;
            } else {
                let dx = x - ec_prev;
                let t = alpha_e * dx;
                let base = e0_prev.mul_add(alpha_e, one_minus_alpha_e * ec_prev);
                let r = x - base;

                let g_sel = if t == 0.0 {
                    0.0
                } else {
                    let limit_i = ema_gain_limit as i64;
                    let target = (r / t) * 10.0;
                    let mut i0 = target.floor() as i64;
                    if i0 < 0 {
                        i0 = 0;
                    } else if i0 > limit_i {
                        i0 = limit_i;
                    }
                    let i1 = if i0 < limit_i { i0 + 1 } else { i0 };
                    let g0 = (i0 as f64) * 0.1;
                    let g1 = (i1 as f64) * 0.1;
                    let e0 = (r - t * g0).abs();
                    let e1 = (r - t * g1).abs();
                    // Ties go to the smaller gain.
                    if e0 <= e1 {
                        g0
                    } else {
                        g1
                    }
                };

                ec_now = (e0_prev + g_sel * dx).mul_add(alpha_e, one_minus_alpha_e * ec_prev);
                ec_prev = ec_now;
            }
        }

        if hull_val.is_finite() && ec_now.is_finite() {
            out[i] = 0.5 * (hull_val + ec_now);
        }
    }
}
774
/// wasm32 `simd128` entry point for the DMA kernel.
///
/// There is currently no SIMD implementation; this forwards directly to `dma_scalar`
/// with identical arguments, so results match the scalar kernel exactly.
///
/// # Safety
/// The body performs no unsafe operations. The function stays `unsafe` so that its
/// signature matches the other target-specific kernels.
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
#[inline]
unsafe fn dma_simd128(
    data: &[f64],
    hull_length: usize,
    ema_length: usize,
    ema_gain_limit: usize,
    hull_ma_type: &str,
    first_val: usize,
    out: &mut [f64],
) {
    dma_scalar(
        data,
        hull_length,
        ema_length,
        ema_gain_limit,
        hull_ma_type,
        first_val,
        out,
    );
}
797
/// Horizontal sum of the four lanes of `v`, added in lane order 0, 1, 2, 3.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn hsum256d(v: __m256d) -> f64 {
    let mut lanes = [0.0f64; 4];
    _mm256_storeu_pd(lanes.as_mut_ptr(), v);
    let [l0, l1, l2, l3] = lanes;
    l0 + l1 + l2 + l3
}
805
/// Horizontal sum of the eight lanes of `v`, folded in lane order.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn hsum512d(v: __m512d) -> f64 {
    let mut lanes = [0.0f64; 8];
    _mm512_storeu_pd(lanes.as_mut_ptr(), v);
    lanes.iter().copied().sum()
}
813
/// Lane-wise absolute value: clears the IEEE-754 sign bit of every lane.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn vabs256d(x: __m256d) -> __m256d {
    // -0.0 has only the sign bit set, so ANDNOT with it strips the sign.
    _mm256_andnot_pd(_mm256_set1_pd(-0.0), x)
}
820
/// Lane-wise absolute value: clears the sign bit of each 64-bit lane via integer ops.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn vabs512d(x: __m512d) -> __m512d {
    // i64::MIN is exactly the sign-bit pattern 0x8000_0000_0000_0000.
    let sign_bits = _mm512_set1_epi64(i64::MIN);
    _mm512_castsi512_pd(_mm512_andnot_si512(sign_bits, _mm512_castpd_si512(x)))
}
829
/// Sum of the `len` values starting at `ptr`, accumulated in 4-lane AVX2 chunks plus a
/// scalar tail.
///
/// # Safety
/// `ptr..ptr + len` must be readable `f64`s, and the CPU must support AVX2.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn sum_unweighted_avx2(ptr: *const f64, len: usize) -> f64 {
    const LANES: usize = 4;
    let full_chunks = len / LANES;
    let mut acc = _mm256_setzero_pd();
    for chunk in 0..full_chunks {
        acc = _mm256_add_pd(acc, _mm256_loadu_pd(ptr.add(chunk * LANES)));
    }
    let mut total = hsum256d(acc);
    for k in full_chunks * LANES..len {
        total += *ptr.add(k);
    }
    total
}
847
/// Sum of the `len` values starting at `ptr`, accumulated in 8-lane AVX-512 chunks plus
/// a scalar tail.
///
/// # Safety
/// `ptr..ptr + len` must be readable `f64`s, and the CPU must support AVX-512F.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn sum_unweighted_avx512(ptr: *const f64, len: usize) -> f64 {
    const LANES: usize = 8;
    let full_chunks = len / LANES;
    let mut acc = _mm512_setzero_pd();
    for chunk in 0..full_chunks {
        acc = _mm512_add_pd(acc, _mm512_loadu_pd(ptr.add(chunk * LANES)));
    }
    let mut total = hsum512d(acc);
    for k in full_chunks * LANES..len {
        total += *ptr.add(k);
    }
    total
}
865
/// Returns `(sum, weighted_sum)` of the `len` values at `ptr`. The weights are
/// `1..=len`, with weight 1 on the first value.
///
/// # Safety
/// `ptr..ptr + len` must be readable `f64`s, and the CPU must support AVX2.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn seed_wma_window_avx2(ptr: *const f64, len: usize) -> (f64, f64) {
    const LANES: usize = 4;
    // `_mm256_set_pd` lists lanes high-to-low, so lane 0 gets offset 0.0.
    let lane_offsets = _mm256_set_pd(3.0, 2.0, 1.0, 0.0);
    let full_chunks = len / LANES;
    let mut sum_v = _mm256_setzero_pd();
    let mut wsum_v = _mm256_setzero_pd();
    for chunk in 0..full_chunks {
        let base = chunk * LANES;
        let vals = _mm256_loadu_pd(ptr.add(base));
        let weights = _mm256_add_pd(_mm256_set1_pd(base as f64 + 1.0), lane_offsets);
        sum_v = _mm256_add_pd(sum_v, vals);
        wsum_v = _mm256_add_pd(wsum_v, _mm256_mul_pd(weights, vals));
    }
    let mut sum = hsum256d(sum_v);
    let mut wsum = hsum256d(wsum_v);
    for k in full_chunks * LANES..len {
        let val = *ptr.add(k);
        sum += val;
        wsum += (k as f64 + 1.0) * val;
    }
    (sum, wsum)
}
892
/// Returns `(sum, weighted_sum)` of the `len` values at `ptr`. The weights are
/// `1..=len`, with weight 1 on the first value.
///
/// # Safety
/// `ptr..ptr + len` must be readable `f64`s, and the CPU must support AVX-512F.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn seed_wma_window_avx512(ptr: *const f64, len: usize) -> (f64, f64) {
    const LANES: usize = 8;
    // `_mm512_set_pd` lists lanes high-to-low, so lane 0 gets offset 0.0.
    let lane_offsets = _mm512_set_pd(7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0);
    let full_chunks = len / LANES;
    let mut sum_v = _mm512_setzero_pd();
    let mut wsum_v = _mm512_setzero_pd();
    for chunk in 0..full_chunks {
        let base = chunk * LANES;
        let vals = _mm512_loadu_pd(ptr.add(base));
        let weights = _mm512_add_pd(_mm512_set1_pd(base as f64 + 1.0), lane_offsets);
        sum_v = _mm512_add_pd(sum_v, vals);
        wsum_v = _mm512_add_pd(wsum_v, _mm512_mul_pd(weights, vals));
    }
    let mut sum = hsum512d(sum_v);
    let mut wsum = hsum512d(wsum_v);
    for k in full_chunks * LANES..len {
        let val = *ptr.add(k);
        sum += val;
        wsum += (k as f64 + 1.0) * val;
    }
    (sum, wsum)
}
919
/// Brute-force search, four candidates at a time, for the gain
/// `g = k / 10` (`k` in `0..=ema_gain_limit`) that minimises
/// `|x - (alpha_e * (e0_prev + g * (x - ec_prev)) + (1 - alpha_e) * ec_prev)|`.
///
/// On equal error the smaller gain wins. If every error is NaN, returns 0.0.
///
/// # Safety
/// The CPU must support AVX2 and FMA.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn best_gain_search_avx2(
    x: f64,
    e0_prev: f64,
    ec_prev: f64,
    alpha_e: f64,
    ema_gain_limit: usize,
) -> f64 {
    const LANES: usize = 4;
    let dx = _mm256_set1_pd(x - ec_prev);
    let x_v = _mm256_set1_pd(x);
    let e0_v = _mm256_set1_pd(e0_prev);
    let alpha_v = _mm256_set1_pd(alpha_e);
    // (1 - alpha) * ec_prev does not depend on the candidate gain.
    let carry = _mm256_mul_pd(_mm256_set1_pd(1.0 - alpha_e), _mm256_set1_pd(ec_prev));
    let limit_v = _mm256_set1_pd(ema_gain_limit as f64);
    let tenth = _mm256_set1_pd(0.1);
    let inf_v = _mm256_set1_pd(f64::INFINITY);
    let lane_offsets = _mm256_set_pd(3.0, 2.0, 1.0, 0.0);

    let mut best_err = inf_v;
    let mut best_g = _mm256_setzero_pd();

    for start in (0..=ema_gain_limit).step_by(LANES) {
        let idx_v = _mm256_add_pd(_mm256_set1_pd(start as f64), lane_offsets);
        // Lanes past the limit in the last chunk are forced to +inf error.
        let out_of_range = _mm256_cmp_pd(idx_v, limit_v, _CMP_GT_OQ);
        let g = _mm256_mul_pd(idx_v, tenth);
        let pred = _mm256_fmadd_pd(alpha_v, _mm256_fmadd_pd(g, dx, e0_v), carry);
        let err = _mm256_blendv_pd(vabs256d(_mm256_sub_pd(x_v, pred)), inf_v, out_of_range);
        let better = _mm256_cmp_pd(err, best_err, _CMP_LT_OQ);
        best_err = _mm256_blendv_pd(best_err, err, better);
        best_g = _mm256_blendv_pd(best_g, g, better);
    }

    let mut errs = [0.0f64; LANES];
    let mut gains = [0.0f64; LANES];
    _mm256_storeu_pd(errs.as_mut_ptr(), best_err);
    _mm256_storeu_pd(gains.as_mut_ptr(), best_g);

    let mut pick_err = f64::INFINITY;
    let mut pick_gain = 0.0;
    for (&ek, &gk) in errs.iter().zip(gains.iter()) {
        if ek < pick_err || (ek == pick_err && gk < pick_gain) {
            pick_err = ek;
            pick_gain = gk;
        }
    }
    pick_gain
}
983
/// AVX-512 counterpart of `best_gain_search_avx2`, with eight candidates per step.
///
/// Searches `g = k / 10` (`k` in `0..=ema_gain_limit`) for the minimum of
/// `|x - (alpha_e * (e0_prev + g * (x - ec_prev)) + (1 - alpha_e) * ec_prev)|`. On
/// equal error the smaller gain wins; if every error is NaN, returns 0.0.
///
/// # Safety
/// The CPU must support AVX-512F and FMA.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn best_gain_search_avx512(
    x: f64,
    e0_prev: f64,
    ec_prev: f64,
    alpha_e: f64,
    ema_gain_limit: usize,
) -> f64 {
    const LANES: usize = 8;
    let dx = _mm512_set1_pd(x - ec_prev);
    let x_v = _mm512_set1_pd(x);
    let e0_v = _mm512_set1_pd(e0_prev);
    let alpha_v = _mm512_set1_pd(alpha_e);
    // (1 - alpha) * ec_prev does not depend on the candidate gain.
    let carry = _mm512_mul_pd(_mm512_set1_pd(1.0 - alpha_e), _mm512_set1_pd(ec_prev));
    let limit_v = _mm512_set1_pd(ema_gain_limit as f64);
    let tenth = _mm512_set1_pd(0.1);
    let inf_v = _mm512_set1_pd(f64::INFINITY);
    let lane_offsets = _mm512_set_pd(7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0);

    let mut best_err = inf_v;
    let mut best_g = _mm512_setzero_pd();

    for start in (0..=ema_gain_limit).step_by(LANES) {
        let idx_v = _mm512_add_pd(_mm512_set1_pd(start as f64), lane_offsets);
        // Lanes past the limit in the last chunk are forced to +inf error.
        let out_of_range = _mm512_cmp_pd_mask(idx_v, limit_v, _CMP_GT_OQ);
        let g = _mm512_mul_pd(idx_v, tenth);
        let pred = _mm512_fmadd_pd(alpha_v, _mm512_fmadd_pd(g, dx, e0_v), carry);
        let err = _mm512_mask_mov_pd(vabs512d(_mm512_sub_pd(x_v, pred)), out_of_range, inf_v);
        let better = _mm512_cmp_pd_mask(err, best_err, _CMP_LT_OQ);
        best_err = _mm512_mask_mov_pd(best_err, better, err);
        best_g = _mm512_mask_mov_pd(best_g, better, g);
    }

    let mut errs = [0.0f64; LANES];
    let mut gains = [0.0f64; LANES];
    _mm512_storeu_pd(errs.as_mut_ptr(), best_err);
    _mm512_storeu_pd(gains.as_mut_ptr(), best_g);

    let mut pick_err = f64::INFINITY;
    let mut pick_gain = 0.0;
    for (&ek, &gk) in errs.iter().zip(gains.iter()) {
        if ek < pick_err || (ek == pick_err && gk < pick_gain) {
            pick_err = ek;
            pick_gain = gk;
        }
    }
    pick_gain
}
1047
/// AVX2 + FMA variant of `dma_scalar`.
///
/// This runs the same recurrences as the scalar kernel. Only the one-off window
/// seeding sums go through `sum_unweighted_avx2` / `seed_wma_window_avx2`. Those
/// helpers add values in 4-lane order, so results may differ from `dma_scalar` in the
/// last few bits. As in the scalar kernel, indices that are never written (warmup, or
/// after a NaN sample poisons the recursive state) keep their previous contents.
///
/// # Safety
/// The CPU must support AVX2 and FMA. The raw seeding reads cover
/// `data[i + 1 - len..=i]`, which lies inside `data[first..]` whenever they run;
/// every other access is bounds-checked.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
unsafe fn dma_avx2(
    data: &[f64],
    hull_length: usize,
    ema_length: usize,
    ema_gain_limit: usize,
    hull_ma_type: &str,
    first: usize,
    out: &mut [f64],
) {
    let n = data.len();
    if n == 0 {
        return;
    }

    // Error-correcting EMA: `e0` is the plain EMA, `ec` the gain-corrected one.
    let alpha_e = 2.0 / (ema_length as f64 + 1.0);
    let one_minus_alpha_e = 1.0 - alpha_e;
    let i0_e = first + ema_length.saturating_sub(1);
    let mut e0_prev = 0.0;
    let mut e0_init_done = false;
    let mut ec_prev = 0.0;
    let mut ec_init_done = false;

    // Hull stage geometry.
    let half = hull_length / 2;
    let sqrt_len = (hull_length as f64).sqrt().round() as usize;
    let i0_half = first + half.saturating_sub(1);
    let i0_full = first + hull_length.saturating_sub(1);
    let is_wma = hull_ma_type == "WMA";

    // WMA normalisers (sum of weights 1..=p), computed once instead of per sample.
    let wsum = |p: usize| -> f64 { (p * (p + 1)) as f64 / 2.0 };
    let norm_half = wsum(half).max(1.0);
    let norm_full = wsum(hull_length).max(1.0);
    let norm_sqrt = wsum(sqrt_len).max(1.0);

    // EMA smoothing factors for the "EMA" Hull variant and the sqrt-length stage.
    let alpha_half = if half > 0 {
        2.0 / (half as f64 + 1.0)
    } else {
        0.0
    };
    let one_minus_alpha_half = 1.0 - alpha_half;
    let alpha_full = if hull_length > 0 {
        2.0 / (hull_length as f64 + 1.0)
    } else {
        0.0
    };
    let one_minus_alpha_full = 1.0 - alpha_full;
    let alpha_sqrt = if sqrt_len > 0 {
        2.0 / (sqrt_len as f64 + 1.0)
    } else {
        0.0
    };
    let one_minus_alpha_sqrt = 1.0 - alpha_sqrt;

    // Running WMA state: `a_*` is the plain window sum, `s_*` the weighted sum.
    let mut a_half = 0.0;
    let mut s_half = 0.0;
    let mut half_ready = false;
    let mut a_full = 0.0;
    let mut s_full = 0.0;
    let mut full_ready = false;

    // Running EMA state for the "EMA" Hull variant.
    let mut e_half_prev = 0.0;
    let mut e_half_init_done = false;
    let mut e_full_prev = 0.0;
    let mut e_full_init_done = false;

    // The last `sqrt_len` raw Hull differences (a ring once full) and their MA state.
    let mut diff_ring: Vec<f64> = Vec::with_capacity(sqrt_len.max(1));
    let mut diff_pos: usize = 0;
    let mut diff_filled = 0usize;
    let mut a_diff = 0.0;
    let mut s_diff = 0.0;
    let mut diff_ema = 0.0;
    let mut diff_sum_seed = 0.0;

    let mut hull_val = f64::NAN;

    for i in first..n {
        let x = data[i];

        // Plain EMA, seeded with the SMA of the first `ema_length` samples.
        if !e0_init_done {
            if i >= i0_e {
                let start = i + 1 - ema_length;
                let sum = sum_unweighted_avx2(data.as_ptr().add(start), ema_length);
                e0_prev = sum / ema_length as f64;
                e0_init_done = true;
            }
        } else {
            e0_prev = x.mul_add(alpha_e, one_minus_alpha_e * e0_prev);
        }

        let mut diff_now = f64::NAN;

        if is_wma {
            if half > 0 {
                if !half_ready {
                    if i >= i0_half {
                        let start = i + 1 - half;
                        let (sum, weighted) =
                            seed_wma_window_avx2(data.as_ptr().add(start), half);
                        a_half = sum;
                        s_half = weighted;
                        half_ready = true;
                    }
                } else {
                    // O(1) slide: S' = S - A + p*x, A' = A + x - x[i-p].
                    let a_prev = a_half;
                    a_half = a_prev + x - data[i - half];
                    s_half = s_half + (half as f64) * x - a_prev;
                }
            }

            if hull_length > 0 {
                if !full_ready {
                    if i >= i0_full {
                        let start = i + 1 - hull_length;
                        let (sum, weighted) =
                            seed_wma_window_avx2(data.as_ptr().add(start), hull_length);
                        a_full = sum;
                        s_full = weighted;
                        full_ready = true;
                    }
                } else {
                    let a_prev = a_full;
                    a_full = a_prev + x - data[i - hull_length];
                    s_full = s_full + (hull_length as f64) * x - a_prev;
                }
            }

            if half_ready && full_ready {
                let w_half = s_half / norm_half;
                let w_full = s_full / norm_full;
                diff_now = 2.0 * w_half - w_full;
            }
        } else {
            if half > 0 {
                if !e_half_init_done {
                    if i >= i0_half {
                        let start = i + 1 - half;
                        let sum = sum_unweighted_avx2(data.as_ptr().add(start), half);
                        e_half_prev = sum / half as f64;
                        e_half_init_done = true;
                    }
                } else {
                    e_half_prev = x.mul_add(alpha_half, one_minus_alpha_half * e_half_prev);
                }
            }

            if hull_length > 0 {
                if !e_full_init_done {
                    if i >= i0_full {
                        let start = i + 1 - hull_length;
                        let sum = sum_unweighted_avx2(data.as_ptr().add(start), hull_length);
                        e_full_prev = sum / hull_length as f64;
                        e_full_init_done = true;
                    }
                } else {
                    e_full_prev = x.mul_add(alpha_full, one_minus_alpha_full * e_full_prev);
                }
            }

            if e_half_init_done && e_full_init_done {
                diff_now = 2.0 * e_half_prev - e_full_prev;
            }
        }

        // Smooth the raw Hull difference over the last `sqrt_len` samples.
        if diff_now.is_finite() && sqrt_len > 0 {
            if diff_filled < sqrt_len {
                diff_ring.push(diff_now);
                diff_sum_seed += diff_now;
                diff_filled += 1;

                if diff_filled == sqrt_len {
                    if is_wma {
                        let (a0, s0) = seed_wma_window_avx2(diff_ring.as_ptr(), sqrt_len);
                        a_diff = a0;
                        s_diff = s0;
                        hull_val = s_diff / norm_sqrt;
                    } else {
                        diff_ema = diff_sum_seed / sqrt_len as f64;
                        hull_val = diff_ema;
                    }
                }
            } else {
                // `diff_pos` always points at the oldest difference in the ring.
                let old = diff_ring[diff_pos];
                diff_ring[diff_pos] = diff_now;
                diff_pos = (diff_pos + 1) % sqrt_len;

                if is_wma {
                    let a_prev = a_diff;
                    a_diff = a_prev + diff_now - old;
                    s_diff = s_diff + (sqrt_len as f64) * diff_now - a_prev;
                    hull_val = s_diff / norm_sqrt;
                } else {
                    diff_ema = diff_now.mul_add(alpha_sqrt, one_minus_alpha_sqrt * diff_ema);
                    hull_val = diff_ema;
                }
            }
        }

        // Error-correcting EMA with the closed-form grid search (see `dma_scalar`).
        let mut ec_now = f64::NAN;
        if e0_init_done {
            if !ec_init_done {
                ec_prev = e0_prev;
                ec_init_done = true;
                ec_now = ec_prev;
            } else {
                let dx = x - ec_prev;
                let t = alpha_e * dx;
                let base = e0_prev.mul_add(alpha_e, one_minus_alpha_e * ec_prev);
                let r = x - base;

                let g_sel = if t == 0.0 {
                    0.0
                } else {
                    let limit_i = ema_gain_limit as i64;
                    let target = (r / t) * 10.0;
                    let mut i0 = target.floor() as i64;
                    if i0 < 0 {
                        i0 = 0;
                    } else if i0 > limit_i {
                        i0 = limit_i;
                    }
                    let i1 = if i0 < limit_i { i0 + 1 } else { i0 };
                    let g0 = (i0 as f64) * 0.1;
                    let g1 = (i1 as f64) * 0.1;
                    let e0 = (r - t * g0).abs();
                    let e1 = (r - t * g1).abs();
                    // Ties go to the smaller gain.
                    if e0 <= e1 {
                        g0
                    } else {
                        g1
                    }
                };

                ec_now = (e0_prev + g_sel * dx).mul_add(alpha_e, one_minus_alpha_e * ec_prev);
                ec_prev = ec_now;
            }
        }

        if hull_val.is_finite() && ec_now.is_finite() {
            out[i] = 0.5 * (hull_val + ec_now);
        }
    }
}
1293
/// AVX-512 + FMA kernel for the DMA (Dynamic Moving Average) of `data`.
///
/// For every index `i` in `first..data.len()` the kernel maintains:
/// * `e0_prev`: an EMA of `data` with period `ema_length` (factor `2 / (n + 1)`),
///   seeded with the simple mean of the window `data[first..first + ema_length]`;
/// * `diff_now`: a Hull-style difference `2 * MA(half) - MA(hull_length)` with
///   `half = hull_length / 2`. MA is a linearly weighted MA when
///   `hull_ma_type == "WMA"`; for any other string it is an EMA seeded with a simple mean;
/// * `hull_val`: `diff_now` smoothed over `sqrt_len = round(sqrt(hull_length))`
///   samples, again by WMA or by a mean-seeded EMA;
/// * `ec_now`: an error-correcting EMA whose gain is picked from the grid
///   `0.0, 0.1, ..., ema_gain_limit / 10` so that `ec_now` lands closest to `x`.
///
/// `out[i]` is written as `0.5 * (hull_val + ec_now)` only when both are finite.
/// Every other element of `out` is left untouched, so the caller is responsible
/// for initialising warm-up / non-finite positions.
///
/// NOTE(review): when `half == 0` (i.e. `hull_length < 2`) the half-window average
/// never becomes ready in either mode, so this kernel writes nothing to `out`.
///
/// # Safety
/// * The CPU must support the `avx512f` and `fma` target features.
/// * `sum_unweighted_avx512` / `seed_wma_window_avx512` receive raw pointers into
///   `data` (or `diff_ring`) plus a length. The windows built here start at or after
///   `first` and end at `i`, so they stay inside `data`. This presumes those helpers
///   read exactly `len` values from the pointer — confirm against their definitions.
/// * `data[..]` and `out[i]` use checked indexing, so an `out` shorter than `data`
///   panics rather than causing UB.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
unsafe fn dma_avx512(
    data: &[f64],
    hull_length: usize,
    ema_length: usize,
    ema_gain_limit: usize,
    hull_ma_type: &str,
    first: usize,
    out: &mut [f64],
) {
    let n = data.len();
    if n == 0 {
        return;
    }

    // Base EMA of the input; seeded once a full `ema_length` window is available.
    let alpha_e = 2.0 / (ema_length as f64 + 1.0);
    // First index whose trailing `ema_length` window starts at `first`.
    let i0_e = first + ema_length.saturating_sub(1);
    let mut e0_prev = 0.0;
    let mut e0_init_done = false;
    // Error-correcting EMA state.
    let mut ec_prev = 0.0;
    let mut ec_init_done = false;

    let half = hull_length / 2;
    let sqrt_len = (hull_length as f64).sqrt().round() as usize;

    // Last smoothed Hull value; NaN until the difference series is smoothed once.
    let mut hull_val = f64::NAN;

    // Sum of the linear weights 1..=p, i.e. the WMA denominator.
    let wsum = |p: usize| -> f64 { (p * (p + 1)) as f64 / 2.0 };
    // First indices at which full `half` / `hull_length` windows exist.
    let i0_half = first + half.saturating_sub(1);
    let i0_full = first + hull_length.saturating_sub(1);

    // WMA-mode window state: `a_*` is the plain window sum, `s_*` the weighted sum.
    let mut a_half = 0.0;
    let mut s_half = 0.0;
    let mut half_ready = false;

    let mut a_full = 0.0;
    let mut s_full = 0.0;
    let mut full_ready = false;

    // Ring of the most recent `sqrt_len` Hull differences (filled by `push`, then
    // overwritten in place starting at index 0, which holds the oldest entry).
    let mut diff_ring: Vec<f64> = Vec::with_capacity(sqrt_len.max(1));
    let mut diff_pos: usize = 0;
    let mut diff_filled = 0usize;

    let mut a_diff = 0.0;
    let mut s_diff = 0.0;
    // NOTE(review): assigned below but never read in this function.
    let mut diff_wma_init_done = false;

    let alpha_sqrt = if sqrt_len > 0 {
        2.0 / (sqrt_len as f64 + 1.0)
    } else {
        0.0
    };
    let mut diff_ema = 0.0;
    // NOTE(review): assigned below but never read in this function.
    let mut diff_ema_init_done = false;
    // Running sum of the first `sqrt_len` differences, used to seed `diff_ema`.
    let mut diff_sum_seed = 0.0;

    // EMA-mode Hull state.
    let mut e_half_prev = 0.0;
    let mut e_half_init_done = false;
    let mut e_full_prev = 0.0;
    let mut e_full_init_done = false;
    let alpha_half = if half > 0 {
        2.0 / (half as f64 + 1.0)
    } else {
        0.0
    };
    let alpha_full = if hull_length > 0 {
        2.0 / (hull_length as f64 + 1.0)
    } else {
        0.0
    };

    // Any value other than "WMA" selects the EMA variant.
    let is_wma = hull_ma_type == "WMA";

    for i in first..n {
        let x = data[i];

        // Base EMA: seed with the mean of the first full window, then recurse.
        if !e0_init_done {
            if i >= i0_e {
                let start = i + 1 - ema_length;
                let sum = sum_unweighted_avx512(data.as_ptr().add(start), ema_length);
                e0_prev = sum / ema_length as f64;
                e0_init_done = true;
            }
        } else {
            e0_prev = x.mul_add(alpha_e, (1.0 - alpha_e) * e0_prev);
        }

        let mut diff_now = f64::NAN;

        if is_wma {
            // Two sliding WMAs updated in O(1) per step:
            //   S' = S + p * x - A_prev,   A' = A_prev + x - x_out.
            if half > 0 {
                if !half_ready {
                    if i >= i0_half {
                        let start = i + 1 - half;
                        let (sum, wsum_local) =
                            seed_wma_window_avx512(data.as_ptr().add(start), half);
                        a_half = sum;
                        s_half = wsum_local;
                        half_ready = true;
                    }
                } else {
                    let a_prev = a_half;
                    a_half = a_prev + x - data[i - half];
                    s_half = s_half + (half as f64) * x - a_prev;
                }
            }

            if hull_length > 0 {
                if !full_ready {
                    if i >= i0_full {
                        let start = i + 1 - hull_length;
                        let (sum, wsum_local) =
                            seed_wma_window_avx512(data.as_ptr().add(start), hull_length);
                        a_full = sum;
                        s_full = wsum_local;
                        full_ready = true;
                    }
                } else {
                    let a_prev = a_full;
                    a_full = a_prev + x - data[i - hull_length];
                    s_full = s_full + (hull_length as f64) * x - a_prev;
                }
            }

            if half_ready && full_ready {
                let w_half = s_half / wsum(half).max(1.0);
                let w_full = s_full / wsum(hull_length).max(1.0);
                diff_now = 2.0 * w_half - w_full;
            }
        } else {
            // EMA variant: each EMA is seeded with the mean of its first full window.
            if half > 0 {
                if !e_half_init_done {
                    if i >= i0_half {
                        let start = i + 1 - half;
                        let sum = sum_unweighted_avx512(data.as_ptr().add(start), half);
                        e_half_prev = sum / half as f64;
                        e_half_init_done = true;
                    }
                } else {
                    e_half_prev = x.mul_add(alpha_half, (1.0 - alpha_half) * e_half_prev);
                }
            }

            if hull_length > 0 {
                if !e_full_init_done {
                    if i >= i0_full {
                        let start = i + 1 - hull_length;
                        let sum = sum_unweighted_avx512(data.as_ptr().add(start), hull_length);
                        e_full_prev = sum / hull_length as f64;
                        e_full_init_done = true;
                    }
                } else {
                    e_full_prev = x.mul_add(alpha_full, (1.0 - alpha_full) * e_full_prev);
                }
            }

            if e_half_init_done && e_full_init_done {
                diff_now = 2.0 * e_half_prev - e_full_prev;
            }
        }

        // Smooth the Hull difference over `sqrt_len` samples. Non-finite differences
        // are skipped, so `hull_val` keeps its previous value in that case.
        if diff_now.is_finite() && sqrt_len > 0 {
            if diff_filled < sqrt_len {
                // Still collecting the seed window.
                diff_ring.push(diff_now);
                diff_sum_seed += diff_now;
                diff_filled += 1;

                if diff_filled == sqrt_len {
                    if is_wma {
                        let (a0, s0) = seed_wma_window_avx512(diff_ring.as_ptr(), sqrt_len);
                        a_diff = a0;
                        s_diff = s0;
                        diff_wma_init_done = true;
                        let wsum_d = (sqrt_len * (sqrt_len + 1)) as f64 / 2.0;
                        hull_val = s_diff / wsum_d.max(1.0);
                    } else {
                        diff_ema = diff_sum_seed / sqrt_len as f64;
                        diff_ema_init_done = true;
                        hull_val = diff_ema;
                    }
                }
            } else {
                // Ring full: replace the oldest difference and update incrementally.
                let old = diff_ring[diff_pos];
                diff_ring[diff_pos] = diff_now;
                diff_pos = (diff_pos + 1) % sqrt_len;

                if is_wma {
                    let a_prev = a_diff;
                    a_diff = a_prev + diff_now - old;
                    s_diff = s_diff + (sqrt_len as f64) * diff_now - a_prev;
                    let wsum_d = (sqrt_len * (sqrt_len + 1)) as f64 / 2.0;
                    hull_val = s_diff / wsum_d.max(1.0);
                } else {
                    diff_ema = diff_now.mul_add(alpha_sqrt, (1.0 - alpha_sqrt) * diff_ema);
                    hull_val = diff_ema;
                }
            }
        }

        // Error-correcting EMA. With a = alpha_e, dx = x - ec_prev and t = a * dx:
        //   ec = a * (e0 + g * dx) + (1 - a) * ec_prev = base + g * t,
        // so |r - t * g| == |x - ec|. The two 0.1-step grid gains bracketing r / t
        // (clamped to 0..=ema_gain_limit) are compared and the closer one is kept.
        let mut ec_now = f64::NAN;
        if e0_init_done {
            if !ec_init_done {
                ec_prev = e0_prev;
                ec_init_done = true;
                ec_now = ec_prev;
            } else {
                let dx = x - ec_prev;
                let t = alpha_e * dx;
                let base = e0_prev.mul_add(alpha_e, (1.0 - alpha_e) * ec_prev);
                let r = x - base;

                let g_sel = if t == 0.0 {
                    0.0
                } else {
                    let limit_i = ema_gain_limit as i64;
                    let target = (r / t) * 10.0;
                    let mut i0 = target.floor() as i64;
                    if i0 < 0 {
                        i0 = 0;
                    } else if i0 > limit_i {
                        i0 = limit_i;
                    }
                    let i1 = if i0 < limit_i { i0 + 1 } else { i0 };
                    let g0 = (i0 as f64) * 0.1;
                    let g1 = (i1 as f64) * 0.1;
                    let e0 = (r - t * g0).abs();
                    let e1 = (r - t * g1).abs();
                    if e0 <= e1 {
                        g0
                    } else {
                        g1
                    }
                };

                ec_now = (e0_prev + g_sel * dx).mul_add(alpha_e, (1.0 - alpha_e) * ec_prev);
                ec_prev = ec_now;
            }
        }

        // Only finite results are written; other positions of `out` are not touched.
        if hull_val.is_finite() && ec_now.is_finite() {
            out[i] = 0.5 * (hull_val + ec_now);
        }
    }
}
/// Incremental (one sample at a time) DMA calculator; see `DmaStream::try_new`
/// and `DmaStream::update`.
#[derive(Debug, Clone)]
pub struct DmaStream {
    /// Period of the base EMA.
    ema_length: usize,
    /// Upper bound (in tenths) of the error-correction gain grid.
    ema_gain_limit: usize,
    /// Long Hull window length.
    hull_length: usize,
    /// Short Hull window length (`hull_length / 2`).
    half: usize,
    /// Smoothing length of the Hull difference (`round(sqrt(hull_length))`).
    sqrt_len: usize,
    /// `true` for WMA smoothing, `false` for EMA smoothing.
    is_wma: bool,

    /// Capacity of `ring`.
    cap: usize,
    /// Circular buffer of recent raw samples (NaN until written).
    ring: Vec<f64>,
    /// Next slot of `ring` to be written.
    head: usize,
    /// Number of `ring` slots written so far, saturating at `cap`.
    filled: usize,

    /// Update counter; NOTE(review): not read by the visible stream code.
    i: usize,
    /// Set once the first non-NaN sample has arrived (leading NaNs are skipped).
    seen_first: bool,

    /// Base EMA smoothing factor `2 / (ema_length + 1)`.
    alpha_e: f64,
    /// Rolling sum used to seed the base EMA with a simple mean.
    sum_e0: f64,
    /// Current base EMA value.
    e0_prev: f64,
    /// Whether `e0_prev` has been seeded.
    e0_ready: bool,

    /// Current error-correcting EMA value.
    ec_prev: f64,
    /// Whether `ec_prev` has been initialised.
    ec_ready: bool,

    /// Rolling plain sums over the short / long Hull windows (WMA mode).
    sum_half: f64,
    sum_full: f64,
    /// Linearly weighted sums (WMA numerators) over the short / long windows.
    s_half: f64,
    s_full: f64,
    /// Whether the weighted sums have been seeded.
    half_ready: bool,
    full_ready: bool,

    /// EMA-mode smoothing factors and state for the short / long Hull windows.
    alpha_half: f64,
    alpha_full: f64,
    e_half_prev: f64,
    e_full_prev: f64,
    e_half_ready: bool,
    e_full_ready: bool,

    /// Plain and weighted sums of the Hull-difference window (WMA mode).
    a_diff: f64,
    s_diff: f64,
    diff_wma_ready: bool,

    /// EMA smoothing of the Hull difference (EMA mode).
    alpha_sqrt: f64,
    diff_ema: f64,
    diff_ema_ready: bool,

    /// Circular buffer of the last `sqrt_len` Hull differences.
    diff_ring: Vec<f64>,
    /// Next slot of `diff_ring` to be written.
    diff_head: usize,
    /// Number of differences stored so far, saturating at `sqrt_len`.
    diff_filled: usize,
}
1590
1591impl DmaStream {
1592 pub fn try_new(params: DmaParams) -> Result<Self, DmaError> {
1593 let hull_length = params.hull_length.unwrap_or(7);
1594 let ema_length = params.ema_length.unwrap_or(20);
1595 let ema_gain_limit = params.ema_gain_limit.unwrap_or(50);
1596 let hull_ma_type = params.hull_ma_type.unwrap_or_else(|| "WMA".to_string());
1597 if hull_length == 0 || ema_length == 0 {
1598 return Err(DmaError::InvalidPeriod {
1599 period: hull_length.max(ema_length),
1600 data_len: 0,
1601 });
1602 }
1603 if hull_ma_type != "WMA" && hull_ma_type != "EMA" {
1604 return Err(DmaError::InvalidHullMAType {
1605 value: hull_ma_type,
1606 });
1607 }
1608
1609 let half = hull_length / 2;
1610 let sqrt_len = (hull_length as f64).sqrt().round() as usize;
1611 let cap = hull_length.max(ema_length).max(1);
1612
1613 Ok(Self {
1614 ema_length,
1615 ema_gain_limit,
1616 hull_length,
1617 half,
1618 sqrt_len,
1619 is_wma: hull_ma_type == "WMA",
1620
1621 cap,
1622 ring: vec![f64::NAN; cap],
1623 head: 0,
1624 filled: 0,
1625 i: 0,
1626 seen_first: false,
1627
1628 alpha_e: 2.0 / (ema_length as f64 + 1.0),
1629 sum_e0: 0.0,
1630 e0_prev: 0.0,
1631 e0_ready: false,
1632
1633 ec_prev: 0.0,
1634 ec_ready: false,
1635
1636 sum_half: 0.0,
1637 sum_full: 0.0,
1638 s_half: 0.0,
1639 s_full: 0.0,
1640 half_ready: half == 0,
1641 full_ready: hull_length == 0,
1642
1643 alpha_half: if half > 0 {
1644 2.0 / (half as f64 + 1.0)
1645 } else {
1646 0.0
1647 },
1648 alpha_full: 2.0 / (hull_length as f64 + 1.0),
1649 e_half_prev: 0.0,
1650 e_full_prev: 0.0,
1651 e_half_ready: half == 0,
1652 e_full_ready: hull_length == 0,
1653
1654 a_diff: 0.0,
1655 s_diff: 0.0,
1656 diff_wma_ready: sqrt_len == 0,
1657 alpha_sqrt: if sqrt_len > 0 {
1658 2.0 / (sqrt_len as f64 + 1.0)
1659 } else {
1660 0.0
1661 },
1662 diff_ema: 0.0,
1663 diff_ema_ready: sqrt_len == 0,
1664 diff_ring: vec![f64::NAN; sqrt_len.max(1)],
1665 diff_head: 0,
1666 diff_filled: 0,
1667 })
1668 }
1669
1670 #[inline]
1671 pub fn update(&mut self, x: f64) -> Option<f64> {
1672 if !self.seen_first {
1673 self.i += 1;
1674 if x.is_nan() {
1675 return None;
1676 }
1677 self.seen_first = true;
1678 }
1679
1680 let old_head = self.head;
1681 self.ring[old_head] = x;
1682 self.head = (old_head + 1) % self.cap;
1683 if self.filled < self.cap {
1684 self.filled += 1;
1685 }
1686
1687 #[inline(always)]
1688 fn kback(ring: &[f64], head: usize, cap: usize, k: usize) -> f64 {
1689 let idx = (head + cap - k % cap) % cap;
1690 ring[idx]
1691 }
1692
1693 if self.filled < self.ema_length {
1694 if x.is_finite() {
1695 self.sum_e0 += x;
1696 }
1697 } else {
1698 let out_e = kback(&self.ring, self.head, self.cap, self.ema_length);
1699 if x.is_finite() {
1700 self.sum_e0 += x;
1701 }
1702 if out_e.is_finite() {
1703 self.sum_e0 -= out_e;
1704 }
1705 if !self.e0_ready {
1706 self.e0_prev = self.sum_e0 / self.ema_length as f64;
1707 self.e0_ready = true;
1708 } else {
1709 self.e0_prev = self.alpha_e * x + (1.0 - self.alpha_e) * self.e0_prev;
1710 }
1711 }
1712
1713 let mut diff_now = f64::NAN;
1714
1715 if self.is_wma {
1716 if self.half > 0 {
1717 if self.filled < self.half {
1718 if x.is_finite() {
1719 self.sum_half += x;
1720 }
1721 } else {
1722 let out_h = kback(&self.ring, self.head, self.cap, self.half);
1723 if x.is_finite() {
1724 self.sum_half += x;
1725 }
1726 if out_h.is_finite() {
1727 self.sum_half -= out_h;
1728 }
1729 if !self.half_ready {
1730 self.s_half = 0.0;
1731 for j in 0..self.half {
1732 let v = kback(&self.ring, self.head, self.cap, self.half - j);
1733 self.s_half += (j as f64 + 1.0) * v;
1734 }
1735 self.half_ready = true;
1736 } else {
1737 let a_prev = self.sum_half + out_h - x;
1738 self.s_half = self.s_half + (self.half as f64) * x - a_prev;
1739 }
1740 }
1741 } else {
1742 self.half_ready = true;
1743 }
1744
1745 if self.filled < self.hull_length {
1746 if x.is_finite() {
1747 self.sum_full += x;
1748 }
1749 } else {
1750 let out_f = kback(&self.ring, self.head, self.cap, self.hull_length);
1751 if x.is_finite() {
1752 self.sum_full += x;
1753 }
1754 if out_f.is_finite() {
1755 self.sum_full -= out_f;
1756 }
1757 if !self.full_ready {
1758 self.s_full = 0.0;
1759 for j in 0..self.hull_length {
1760 let v = kback(&self.ring, self.head, self.cap, self.hull_length - j);
1761 self.s_full += (j as f64 + 1.0) * v;
1762 }
1763 self.full_ready = true;
1764 } else {
1765 let a_prev = self.sum_full + out_f - x;
1766 self.s_full = self.s_full + (self.hull_length as f64) * x - a_prev;
1767 }
1768 }
1769
1770 if self.half_ready && self.full_ready && self.sqrt_len > 0 {
1771 let wsum = |p: usize| (p * (p + 1)) as f64 / 2.0;
1772 let w_half = self.s_half / wsum(self.half).max(1.0);
1773 let w_full = self.s_full / wsum(self.hull_length).max(1.0);
1774 diff_now = 2.0 * w_half - w_full;
1775 }
1776 } else {
1777 if self.half > 0 {
1778 if self.filled < self.half {
1779 } else if !self.e_half_ready {
1780 let mut s = 0.0;
1781 for j in 0..self.half {
1782 s += kback(&self.ring, self.head, self.cap, self.half - j);
1783 }
1784 self.e_half_prev = s / self.half as f64;
1785 self.e_half_ready = true;
1786 } else {
1787 self.e_half_prev =
1788 self.alpha_half * x + (1.0 - self.alpha_half) * self.e_half_prev;
1789 }
1790 } else {
1791 self.e_half_ready = true;
1792 }
1793
1794 if self.filled < self.hull_length {
1795 } else if !self.e_full_ready {
1796 let mut s = 0.0;
1797 for j in 0..self.hull_length {
1798 s += kback(&self.ring, self.head, self.cap, self.hull_length - j);
1799 }
1800 self.e_full_prev = s / self.hull_length as f64;
1801 self.e_full_ready = true;
1802 } else {
1803 self.e_full_prev = self.alpha_full * x + (1.0 - self.alpha_full) * self.e_full_prev;
1804 }
1805
1806 if self.e_half_ready && self.e_full_ready && self.sqrt_len > 0 {
1807 diff_now = 2.0 * self.e_half_prev - self.e_full_prev;
1808 }
1809 }
1810
1811 let mut hull_val = f64::NAN;
1812 if self.sqrt_len == 0 {
1813 if diff_now.is_finite() {
1814 hull_val = diff_now;
1815 }
1816 } else if diff_now.is_finite() {
1817 let old = self.diff_ring[self.diff_head];
1818 self.diff_ring[self.diff_head] = diff_now;
1819 self.diff_head = (self.diff_head + 1) % self.sqrt_len;
1820 if self.diff_filled < self.sqrt_len {
1821 self.diff_filled += 1;
1822 }
1823
1824 if self.is_wma {
1825 if !self.diff_wma_ready && self.diff_filled == self.sqrt_len {
1826 self.a_diff = 0.0;
1827 self.s_diff = 0.0;
1828 for j in 0..self.sqrt_len {
1829 let v = self.diff_ring[(self.diff_head + j) % self.sqrt_len];
1830 self.a_diff += v;
1831 self.s_diff += (j as f64 + 1.0) * v;
1832 }
1833 self.diff_wma_ready = true;
1834 let wsum = (self.sqrt_len * (self.sqrt_len + 1)) as f64 / 2.0;
1835 hull_val = self.s_diff / wsum.max(1.0);
1836 } else if self.diff_wma_ready {
1837 let wsum = (self.sqrt_len * (self.sqrt_len + 1)) as f64 / 2.0;
1838 let a_prev = self.a_diff;
1839 self.s_diff = self.s_diff + (self.sqrt_len as f64) * diff_now - a_prev;
1840 self.a_diff = a_prev + diff_now - old;
1841 hull_val = self.s_diff / wsum.max(1.0);
1842 }
1843 } else {
1844 if !self.diff_ema_ready && self.diff_filled == self.sqrt_len {
1845 let mut s = 0.0;
1846 for j in 0..self.sqrt_len {
1847 s += self.diff_ring[j];
1848 }
1849 self.diff_ema = s / self.sqrt_len as f64;
1850 self.diff_ema_ready = true;
1851 hull_val = self.diff_ema;
1852 } else if self.diff_ema_ready {
1853 self.diff_ema =
1854 self.alpha_sqrt * diff_now + (1.0 - self.alpha_sqrt) * self.diff_ema;
1855 hull_val = self.diff_ema;
1856 }
1857 }
1858 }
1859
1860 let mut ec_now = f64::NAN;
1861 if self.e0_ready {
1862 if !self.ec_ready {
1863 self.ec_prev = self.e0_prev;
1864 self.ec_ready = true;
1865 ec_now = self.ec_prev;
1866 } else {
1867 let one_minus_alpha_e = 1.0 - self.alpha_e;
1868 let dx = x - self.ec_prev;
1869 let t = self.alpha_e * dx;
1870 let base = self
1871 .alpha_e
1872 .mul_add(self.e0_prev, one_minus_alpha_e * self.ec_prev);
1873 let r = x - base;
1874
1875 let g_sel = if t == 0.0 {
1876 0.0
1877 } else {
1878 let target = (r / t) * 10.0;
1879 let limit_i = self.ema_gain_limit as i64;
1880 let mut i0 = target.floor() as i64;
1881 if i0 < 0 {
1882 i0 = 0;
1883 } else if i0 > limit_i {
1884 i0 = limit_i;
1885 }
1886 let i1 = if i0 < limit_i { i0 + 1 } else { i0 };
1887 let g0 = (i0 as f64) * 0.1;
1888 let g1 = (i1 as f64) * 0.1;
1889 let e0 = (r - t * g0).abs();
1890 let e1 = (r - t * g1).abs();
1891 if e0 <= e1 {
1892 g0
1893 } else {
1894 g1
1895 }
1896 };
1897
1898 let ec = self.alpha_e.mul_add(
1899 self.e0_prev + g_sel * dx,
1900 (1.0 - self.alpha_e) * self.ec_prev,
1901 );
1902 self.ec_prev = ec;
1903 ec_now = ec;
1904 }
1905 }
1906
1907 self.i += 1;
1908
1909 if hull_val.is_finite() && ec_now.is_finite() {
1910 Some(0.5 * (hull_val + ec_now))
1911 } else {
1912 None
1913 }
1914 }
1915}
1916
/// Parameter sweep for batch DMA computation.
///
/// Each numeric axis is a `(start, end, step)` triple that `expand_grid_dma` turns
/// into a list of values. `hull_ma_type` is shared by every generated combination.
#[derive(Clone, Debug)]
pub struct DmaBatchRange {
    /// Sweep over the Hull window length.
    pub hull_length: (usize, usize, usize),
    /// Sweep over the base EMA length.
    pub ema_length: (usize, usize, usize),
    /// Sweep over the error-correction gain limit (in tenths).
    pub ema_gain_limit: (usize, usize, usize),
    /// Hull smoothing type applied to every combination (`"WMA"` or `"EMA"`).
    pub hull_ma_type: String,
}

impl Default for DmaBatchRange {
    /// Hull length 7, EMA lengths 20 through 269 (step 1), gain limit 50, `"WMA"`.
    fn default() -> Self {
        let fixed = |v: usize| (v, v, 0);
        Self {
            hull_length: fixed(7),
            ema_length: (20, 269, 1),
            ema_gain_limit: fixed(50),
            hull_ma_type: String::from("WMA"),
        }
    }
}
1935
1936#[derive(Clone, Debug, Default)]
1937pub struct DmaBatchBuilder {
1938 range: DmaBatchRange,
1939 kernel: Kernel,
1940}
1941
1942impl DmaBatchBuilder {
1943 pub fn new() -> Self {
1944 Self::default()
1945 }
1946
1947 pub fn kernel(mut self, k: Kernel) -> Self {
1948 self.kernel = k;
1949 self
1950 }
1951
1952 #[inline]
1953 pub fn hull_length_range(mut self, start: usize, end: usize, step: usize) -> Self {
1954 self.range.hull_length = (start, end, step);
1955 self
1956 }
1957
1958 #[inline]
1959 pub fn hull_length_static(mut self, val: usize) -> Self {
1960 self.range.hull_length = (val, val, 0);
1961 self
1962 }
1963
1964 #[inline]
1965 pub fn ema_length_range(mut self, start: usize, end: usize, step: usize) -> Self {
1966 self.range.ema_length = (start, end, step);
1967 self
1968 }
1969
1970 #[inline]
1971 pub fn ema_length_static(mut self, val: usize) -> Self {
1972 self.range.ema_length = (val, val, 0);
1973 self
1974 }
1975
1976 #[inline]
1977 pub fn ema_gain_limit_range(mut self, start: usize, end: usize, step: usize) -> Self {
1978 self.range.ema_gain_limit = (start, end, step);
1979 self
1980 }
1981
1982 #[inline]
1983 pub fn ema_gain_limit_static(mut self, val: usize) -> Self {
1984 self.range.ema_gain_limit = (val, val, 0);
1985 self
1986 }
1987
1988 #[inline]
1989 pub fn hull_ma_type(mut self, val: String) -> Self {
1990 self.range.hull_ma_type = val;
1991 self
1992 }
1993
1994 pub fn apply_slice(self, data: &[f64]) -> Result<DmaBatchOutput, DmaError> {
1995 dma_batch_with_kernel(data, &self.range, self.kernel)
1996 }
1997
1998 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<DmaBatchOutput, DmaError> {
1999 DmaBatchBuilder::new().kernel(k).apply_slice(data)
2000 }
2001
2002 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<DmaBatchOutput, DmaError> {
2003 let slice = source_type(c, src);
2004 self.apply_slice(slice)
2005 }
2006
2007 pub fn with_default_candles(c: &Candles) -> Result<DmaBatchOutput, DmaError> {
2008 DmaBatchBuilder::new()
2009 .kernel(Kernel::Auto)
2010 .apply_candles(c, "close")
2011 }
2012}
2013
2014#[derive(Clone, Debug)]
2015pub struct DmaBatchOutput {
2016 pub values: Vec<f64>,
2017 pub combos: Vec<DmaParams>,
2018 pub rows: usize,
2019 pub cols: usize,
2020}
2021
2022impl DmaBatchOutput {
2023 pub fn row_for_params(&self, p: &DmaParams) -> Option<usize> {
2024 self.combos.iter().position(|c| {
2025 c.hull_length.unwrap_or(7) == p.hull_length.unwrap_or(7)
2026 && c.ema_length.unwrap_or(20) == p.ema_length.unwrap_or(20)
2027 && c.ema_gain_limit.unwrap_or(50) == p.ema_gain_limit.unwrap_or(50)
2028 && c.hull_ma_type.as_ref().unwrap_or(&"WMA".to_string())
2029 == p.hull_ma_type.as_ref().unwrap_or(&"WMA".to_string())
2030 })
2031 }
2032
2033 pub fn values_for(&self, p: &DmaParams) -> Option<&[f64]> {
2034 self.row_for_params(p).map(|row| {
2035 let start = row * self.cols;
2036 &self.values[start..start + self.cols]
2037 })
2038 }
2039}
2040
2041#[inline(always)]
2042fn expand_grid_dma(r: &DmaBatchRange) -> Vec<DmaParams> {
2043 fn axis_usize((start, end, step): (usize, usize, usize)) -> Vec<usize> {
2044 if step == 0 || start == end {
2045 return vec![start];
2046 }
2047 if start < end {
2048 return (start..=end).step_by(step).collect();
2049 }
2050
2051 let mut v: Vec<usize> = (end..=start).step_by(step).collect();
2052 v.reverse();
2053 v
2054 }
2055
2056 let hull_lengths = axis_usize(r.hull_length);
2057 let ema_lengths = axis_usize(r.ema_length);
2058 let ema_gain_limits = axis_usize(r.ema_gain_limit);
2059
2060 let mut combos = Vec::new();
2061 for &h in &hull_lengths {
2062 for &e in &ema_lengths {
2063 for &g in &ema_gain_limits {
2064 combos.push(DmaParams {
2065 hull_length: Some(h),
2066 ema_length: Some(e),
2067 ema_gain_limit: Some(g),
2068 hull_ma_type: Some(r.hull_ma_type.clone()),
2069 });
2070 }
2071 }
2072 }
2073 combos
2074}
2075
2076#[inline(always)]
2077pub fn dma_batch_slice(
2078 data: &[f64],
2079 sweep: &DmaBatchRange,
2080 kern: Kernel,
2081) -> Result<DmaBatchOutput, DmaError> {
2082 dma_batch_inner(data, sweep, kern, false)
2083}
2084
2085#[inline(always)]
2086pub fn dma_batch_par_slice(
2087 data: &[f64],
2088 sweep: &DmaBatchRange,
2089 kern: Kernel,
2090) -> Result<DmaBatchOutput, DmaError> {
2091 dma_batch_inner(data, sweep, kern, true)
2092}
2093
2094#[inline(always)]
2095fn dma_batch_inner(
2096 data: &[f64],
2097 sweep: &DmaBatchRange,
2098 kern: Kernel,
2099 parallel: bool,
2100) -> Result<DmaBatchOutput, DmaError> {
2101 let combos = expand_grid_dma(sweep);
2102 let cols = data.len();
2103 let rows = combos.len();
2104 if cols == 0 {
2105 return Err(DmaError::EmptyInputData);
2106 }
2107 if rows == 0 {
2108 return Err(DmaError::InvalidRange {
2109 start: sweep.hull_length.0,
2110 end: sweep.hull_length.1,
2111 step: sweep.hull_length.2,
2112 });
2113 }
2114
2115 let _cap = rows.checked_mul(cols).ok_or(DmaError::InvalidRange {
2116 start: rows,
2117 end: cols,
2118 step: 0,
2119 })?;
2120
2121 let mut buf_mu = make_uninit_matrix(rows, cols);
2122
2123 let first = data
2124 .iter()
2125 .position(|x| !x.is_nan())
2126 .ok_or(DmaError::AllValuesNaN)?;
2127 let warm: Vec<usize> = combos
2128 .iter()
2129 .map(|c| {
2130 let h = c.hull_length.unwrap();
2131 let e = c.ema_length.unwrap();
2132 let sqrt_len = (h as f64).sqrt().round() as usize;
2133 first + h.max(e) + sqrt_len - 1
2134 })
2135 .collect();
2136 init_matrix_prefixes(&mut buf_mu, cols, &warm);
2137
2138 let mut guard = core::mem::ManuallyDrop::new(buf_mu);
2139 let out: &mut [f64] =
2140 unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len()) };
2141
2142 dma_batch_inner_into(data, sweep, kern, parallel, out)?;
2143
2144 let values = unsafe {
2145 Vec::from_raw_parts(
2146 guard.as_mut_ptr() as *mut f64,
2147 guard.len(),
2148 guard.capacity(),
2149 )
2150 };
2151
2152 Ok(DmaBatchOutput {
2153 values,
2154 combos,
2155 rows,
2156 cols,
2157 })
2158}
2159
2160pub fn dma_batch_with_kernel(
2161 data: &[f64],
2162 sweep: &DmaBatchRange,
2163 k: Kernel,
2164) -> Result<DmaBatchOutput, DmaError> {
2165 let kernel = match k {
2166 Kernel::Auto => detect_best_batch_kernel(),
2167 other if other.is_batch() => other,
2168 other => return Err(DmaError::InvalidKernelForBatch(other)),
2169 };
2170
2171 let simd = match kernel {
2172 Kernel::Avx512Batch => Kernel::Avx512,
2173 Kernel::Avx2Batch => Kernel::Avx2,
2174 Kernel::ScalarBatch => Kernel::Scalar,
2175 _ => unreachable!(),
2176 };
2177 dma_batch_par_slice(data, sweep, simd)
2178}
2179
2180#[inline(always)]
2181fn dma_batch_inner_into(
2182 data: &[f64],
2183 sweep: &DmaBatchRange,
2184 k: Kernel,
2185 parallel: bool,
2186 out: &mut [f64],
2187) -> Result<Vec<DmaParams>, DmaError> {
2188 let combos = expand_grid_dma(sweep);
2189 if combos.is_empty() {
2190 return Err(DmaError::InvalidRange {
2191 start: sweep.hull_length.0,
2192 end: sweep.hull_length.1,
2193 step: sweep.hull_length.2,
2194 });
2195 }
2196
2197 let first = data
2198 .iter()
2199 .position(|x| !x.is_nan())
2200 .ok_or(DmaError::AllValuesNaN)?;
2201 let cols = data.len();
2202
2203 let actual = match k {
2204 Kernel::Auto => detect_best_batch_kernel(),
2205 other => other,
2206 };
2207 let simd = match actual {
2208 Kernel::Avx512Batch => Kernel::Avx512,
2209 Kernel::Avx2Batch => Kernel::Avx2,
2210 Kernel::ScalarBatch => Kernel::Scalar,
2211 _ => actual,
2212 };
2213
2214 let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| {
2215 let dst = unsafe {
2216 core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len())
2217 };
2218 let prm = &combos[row];
2219 let hull_len = prm.hull_length.unwrap_or(7);
2220 let ema_len = prm.ema_length.unwrap_or(20);
2221
2222 let sqrt_len = (hull_len as f64).sqrt().round() as usize;
2223 let warmup_end = first + hull_len.max(ema_len) + sqrt_len - 1;
2224 let warmup_end = warmup_end.min(dst.len());
2225
2226 for i in 0..warmup_end {
2227 dst[i] = f64::NAN;
2228 }
2229
2230 dma_compute_into(
2231 data,
2232 hull_len,
2233 ema_len,
2234 prm.ema_gain_limit.unwrap_or(50),
2235 prm.hull_ma_type.as_ref().unwrap_or(&"WMA".to_string()),
2236 first,
2237 simd,
2238 dst,
2239 );
2240 };
2241
2242 let dst_mu = unsafe {
2243 std::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
2244 };
2245
2246 if parallel {
2247 #[cfg(not(target_arch = "wasm32"))]
2248 dst_mu
2249 .par_chunks_mut(cols)
2250 .enumerate()
2251 .for_each(|(r, row)| do_row(r, row));
2252 #[cfg(target_arch = "wasm32")]
2253 for (r, row) in dst_mu.chunks_mut(cols).enumerate() {
2254 do_row(r, row);
2255 }
2256 } else {
2257 for (r, row) in dst_mu.chunks_mut(cols).enumerate() {
2258 do_row(r, row);
2259 }
2260 }
2261
2262 Ok(combos)
2263}
2264
#[cfg(feature = "python")]
#[pyfunction(name = "dma")]
#[pyo3(signature = (data, hull_length=7, ema_length=20, ema_gain_limit=50, hull_ma_type="WMA", kernel=None))]
pub fn dma_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    hull_length: usize,
    ema_length: usize,
    ema_gain_limit: usize,
    hull_ma_type: &str,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    // Python entry point for a single DMA series. The input must be viewable as a
    // slice (`as_slice` fails otherwise). The computation runs with the GIL
    // released, and a DmaError surfaces as ValueError.
    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, false)?;
    let input = DmaInput::from_slice(
        slice_in,
        DmaParams {
            hull_length: Some(hull_length),
            ema_length: Some(ema_length),
            ema_gain_limit: Some(ema_gain_limit),
            hull_ma_type: Some(hull_ma_type.to_string()),
        },
    );

    match py.allow_threads(|| dma_with_kernel(&input, kern).map(|o| o.values)) {
        Ok(values) => Ok(values.into_pyarray(py)),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
2293
// Python-facing wrapper (exposed as `DmaStream`) around the Rust `DmaStream`
// incremental calculator. Plain comments are used so the Python `__doc__` is unchanged.
#[cfg(feature = "python")]
#[pyclass(name = "DmaStream")]
pub struct DmaStreamPy {
    // The wrapped streaming state; every Python `update` call forwards to it.
    stream: DmaStream,
}
2299
#[cfg(feature = "python")]
#[pymethods]
impl DmaStreamPy {
    // Python constructor: DmaStream(hull_length, ema_length, ema_gain_limit, hull_ma_type).
    // Parameters rejected by `DmaStream::try_new` raise ValueError.
    #[new]
    fn new(
        hull_length: usize,
        ema_length: usize,
        ema_gain_limit: usize,
        hull_ma_type: &str,
    ) -> PyResult<Self> {
        let params = DmaParams {
            hull_length: Some(hull_length),
            ema_length: Some(ema_length),
            ema_gain_limit: Some(ema_gain_limit),
            hull_ma_type: Some(hull_ma_type.to_string()),
        };
        DmaStream::try_new(params)
            .map(|stream| DmaStreamPy { stream })
            .map_err(|err| PyValueError::new_err(err.to_string()))
    }

    // Pushes one sample; returns None until a DMA value is available.
    fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
2325
#[cfg(feature = "python")]
#[pyfunction(name = "dma_batch")]
#[pyo3(signature = (data, hull_length_range, ema_length_range, ema_gain_limit_range, hull_ma_type="WMA", kernel=None))]
pub fn dma_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    hull_length_range: (usize, usize, usize),
    ema_length_range: (usize, usize, usize),
    ema_gain_limit_range: (usize, usize, usize),
    hull_ma_type: &str,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    // Python entry point for a batch DMA sweep. Returns a dict with:
    //   "values"        - float64 array shaped (rows, len(data)), one row per combo;
    //   "hull_lengths", "ema_lengths", "ema_gain_limits" - per-row parameters (uint64);
    //   "hull_ma_type"  - the requested smoothing type;
    //   "hull_ma_types" - per-row smoothing type strings.
    // Raises ValueError for an invalid kernel, a size overflow, or any DmaError.
    //
    // `IntoPyArray`, `PyArray1` and `PyArrayMethods` come from the module-level
    // numpy import; the former function-local `use` only shadowed them.
    let slice_in = data.as_slice()?;
    // Validate the kernel before allocating the (possibly large) output array.
    let kern = validate_kernel(kernel, true)?;

    let sweep = DmaBatchRange {
        hull_length: hull_length_range,
        ema_length: ema_length_range,
        ema_gain_limit: ema_gain_limit_range,
        hull_ma_type: hull_ma_type.to_string(),
    };

    let combos = expand_grid_dma(&sweep);
    let rows = combos.len();
    let cols = slice_in.len();
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;

    // SAFETY: the array is created uninitialised and is only exposed to Python after
    // `dma_batch_inner_into` has written it. NOTE(review): this relies on that
    // function initialising every element — confirm when changing it.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    // SAFETY: `out_arr` was just created here, so no other view of its buffer exists.
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            let kernel = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k => k,
            };
            let simd = match kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                other => return Err(DmaError::InvalidKernelForBatch(other)),
            };
            dma_batch_inner_into(slice_in, &sweep, simd, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "hull_lengths",
        combos
            .iter()
            .map(|p| p.hull_length.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "ema_lengths",
        combos
            .iter()
            .map(|p| p.ema_length.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "ema_gain_limits",
        combos
            .iter()
            .map(|p| p.ema_gain_limit.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item("hull_ma_type", hull_ma_type)?;

    dict.set_item(
        "hull_ma_types",
        combos
            .iter()
            .map(|p| p.hull_ma_type.as_deref().unwrap_or("WMA"))
            .collect::<Vec<_>>(),
    )?;

    Ok(dict)
}
2413
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "dma_cuda_batch_dev")]
#[pyo3(signature = (data_f32, hull_length_range, ema_length_range, ema_gain_limit_range, hull_ma_type="WMA", device_id=0))]
pub fn dma_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: numpy::PyReadonlyArray1<'_, f32>,
    hull_length_range: (usize, usize, usize),
    ema_length_range: (usize, usize, usize),
    ema_gain_limit_range: (usize, usize, usize),
    hull_ma_type: &str,
    device_id: usize,
) -> PyResult<DmaDeviceArrayF32Py> {
    // Runs the DMA parameter sweep on a CUDA device over float32 input and returns
    // the device-resident result together with the CUDA context and device id.
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let sweep = DmaBatchRange {
        hull_length: hull_length_range,
        ema_length: ema_length_range,
        ema_gain_limit: ema_gain_limit_range,
        hull_ma_type: hull_ma_type.to_string(),
    };
    let slice_in = data_f32.as_slice()?;

    let (inner, ctx, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaDma::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context();
        let dev_id = cuda.device_id();
        cuda.dma_batch_dev(slice_in, &sweep)
            .map(|arr| (arr, ctx, dev_id))
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    Ok(DmaDeviceArrayF32Py {
        inner: Some(inner),
        _ctx: ctx,
        device_id: dev_id,
    })
}
2454
/// Runs DMA with one parameter set over many series on the GPU.
///
/// `data_tm_f32` is read as a 2-D array of shape `(rows, cols)`. Its `cols`
/// and `rows` are forwarded to the time-major CUDA entry point, presumably
/// as (number of series, series length); confirm against `CudaDma`. The input
/// must be contiguous, otherwise `as_slice` raises.
///
/// The CUDA work runs with the GIL released. CUDA failures are surfaced as
/// `ValueError`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "dma_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, hull_length, ema_length, ema_gain_limit, hull_ma_type="WMA", device_id=0))]
pub fn dma_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: PyReadonlyArray2<'_, f32>,
    hull_length: usize,
    ema_length: usize,
    ema_gain_limit: usize,
    hull_ma_type: &str,
    device_id: usize,
) -> PyResult<DmaDeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let flat_in: &[f32] = data_tm_f32.as_slice()?;
    let rows = data_tm_f32.shape()[0];
    let cols = data_tm_f32.shape()[1];
    let params = DmaParams {
        hull_length: Some(hull_length),
        ema_length: Some(ema_length),
        ema_gain_limit: Some(ema_gain_limit),
        hull_ma_type: Some(hull_ma_type.to_string()),
    };

    let (inner, ctx, dev_id) = py.allow_threads(|| {
        let cuda = CudaDma::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context();
        let dev_id = cuda.device_id();
        // Parameters are passed by reference; see the function docs for the
        // (cols, rows) argument order.
        let arr = cuda
            .dma_many_series_one_param_time_major_dev(flat_in, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, pyo3::PyErr>((arr, ctx, dev_id))
    })?;

    Ok(DmaDeviceArrayF32Py {
        inner: Some(inner),
        _ctx: ctx,
        device_id: dev_id,
    })
}
2497
/// WASM entry point: computes DMA over `data` and returns a new vector of the
/// same length, using automatic kernel selection.
///
/// Errors from the computation are returned to JS as string values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dma_js(
    data: &[f64],
    hull_length: usize,
    ema_length: usize,
    ema_gain_limit: usize,
    hull_ma_type: &str,
) -> Result<Vec<f64>, JsValue> {
    let input = DmaInput::from_slice(
        data,
        DmaParams {
            hull_length: Some(hull_length),
            ema_length: Some(ema_length),
            ema_gain_limit: Some(ema_gain_limit),
            hull_ma_type: Some(hull_ma_type.to_string()),
        },
    );

    let mut values = vec![0.0; data.len()];
    match dma_into_slice(&mut values, &input, Kernel::Auto) {
        Ok(_) => Ok(values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
2521
/// Allocates room for `len` f64 values in WASM linear memory and returns the
/// raw pointer.
///
/// The memory is uninitialized and ownership passes to the caller. Release it
/// with `dma_free(ptr, len)` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dma_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop suppresses the destructor so the allocation outlives this call.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
2530
/// Releases a buffer previously obtained from `dma_alloc(len)`.
///
/// `len` must be the same value passed to `dma_alloc`. A null pointer is
/// ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dma_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` was produced by `dma_alloc(len)`, i.e. by
    // `Vec::<f64>::with_capacity(len)`, so `len` is its exact capacity. The
    // length is reported as 0 because the contents may never have been
    // initialized. `f64` has no destructor, so nothing is skipped by not
    // dropping elements.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
2538
/// WASM pointer API: computes DMA over `len` values at `in_ptr` and writes
/// `len` results to `out_ptr`.
///
/// The input and output ranges may overlap (including full in-place use).
/// In that case the result is staged in a temporary buffer before being
/// copied out, so no aliasing shared/mutable slices are ever created.
///
/// # Errors
/// Returns a JS string error for null pointers or when the DMA computation
/// fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dma_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    hull_length: usize,
    ema_length: usize,
    ema_gain_limit: usize,
    hull_ma_type: &str,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to dma_into"));
    }

    let params = DmaParams {
        hull_length: Some(hull_length),
        ema_length: Some(ema_length),
        ema_gain_limit: Some(ema_gain_limit),
        hull_ma_type: Some(hull_ma_type.to_string()),
    };

    // Detect any byte-range overlap, not just identical start pointers.
    let bytes = len.saturating_mul(std::mem::size_of::<f64>());
    let in_start = in_ptr as usize;
    let out_start = out_ptr as usize;
    let overlaps =
        in_start < out_start.saturating_add(bytes) && out_start < in_start.saturating_add(bytes);

    // SAFETY: the caller guarantees `in_ptr` is valid for `len` f64 reads for
    // the duration of this call. It was checked non-null above.
    let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };
    let input = DmaInput::from_slice(data, params);

    if overlaps {
        let mut temp = vec![0.0; len];
        dma_into_slice(&mut temp, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // SAFETY: `out_ptr` is valid for `len` f64 writes. The input slice is
        // not used after this point.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        out.copy_from_slice(&temp);
    } else {
        // SAFETY: `out_ptr` is valid for `len` f64 writes and, per the check
        // above, does not overlap the input range.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        dma_into_slice(out, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(())
}
2580
/// Configuration object accepted by the JS `dma_batch` entry point.
///
/// Each range is a `(start, end, step)` triple that is copied into a
/// `DmaBatchRange` and expanded into parameter combinations.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct DmaBatchConfig {
    /// `(start, end, step)` for the hull length.
    pub hull_length_range: (usize, usize, usize),
    /// `(start, end, step)` for the EMA length.
    pub ema_length_range: (usize, usize, usize),
    /// `(start, end, step)` for the EMA gain limit.
    pub ema_gain_limit_range: (usize, usize, usize),
    /// Hull moving-average type name, shared by all combinations.
    pub hull_ma_type: String,
}
2589
/// Serialized result of the JS `dma_batch` entry point.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct DmaBatchJsOutput {
    /// Flattened `rows * cols` matrix, laid out row by row; row `i` corresponds to `combos[i]`.
    pub values: Vec<f64>,
    /// Parameter combination used for each row.
    pub combos: Vec<DmaParams>,
    /// Number of parameter combinations (rows).
    pub rows: usize,
    /// Length of the input series (columns).
    pub cols: usize,
}
2598
/// JS entry point (`dma_batch`) that runs a DMA parameter sweep over `data`.
///
/// `config` must deserialize into `DmaBatchConfig`. The returned value is a
/// serialized `DmaBatchJsOutput` whose `values` hold `rows * cols` entries,
/// one row per parameter combination.
///
/// # Errors
/// Returns a JS string error when the config cannot be parsed, the sweep
/// yields no combinations, the input is entirely NaN (or empty), the batch
/// computation fails, or serialization fails. The output buffer is freed on
/// every error path.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = dma_batch)]
pub fn dma_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let cfg: DmaBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {e}")))?;

    let sweep = DmaBatchRange {
        hull_length: cfg.hull_length_range,
        ema_length: cfg.ema_length_range,
        ema_gain_limit: cfg.ema_gain_limit_range,
        hull_ma_type: cfg.hull_ma_type,
    };

    let combos = expand_grid_dma(&sweep);
    let rows = combos.len();
    let cols = data.len();
    if rows == 0 {
        return Err(JsValue::from_str("no parameter combinations"));
    }

    // Validate the input before allocating the output matrix.
    let first = data
        .iter()
        .position(|x| !x.is_nan())
        .ok_or_else(|| JsValue::from_str("All NaN"))?;

    // Per-row NaN warm-up prefix length. Saturating and clamped to `cols` so
    // degenerate parameters (e.g. zero lengths, or lengths longer than the
    // data) are reported by the kernel as errors instead of panicking here.
    let warm: Vec<usize> = combos
        .iter()
        .map(|c| {
            let h = c.hull_length.unwrap();
            let e = c.ema_length.unwrap();
            let sqrt_len = (h as f64).sqrt().round() as usize;
            (first + h.max(e) + sqrt_len).saturating_sub(1).min(cols)
        })
        .collect();

    let mut buf_mu = make_uninit_matrix(rows, cols);
    init_matrix_prefixes(&mut buf_mu, cols, &warm);

    {
        // SAFETY: `buf_mu` owns `buf_mu.len()` contiguous `MaybeUninit<f64>`
        // slots, which have the same layout as `f64`. The kernel writes every
        // slot past each row's prefix. `buf_mu` stays owned here, so it is
        // dropped (not leaked) if the kernel returns an error.
        let out: &mut [f64] = unsafe {
            core::slice::from_raw_parts_mut(buf_mu.as_mut_ptr() as *mut f64, buf_mu.len())
        };
        dma_batch_inner_into(data, &sweep, detect_best_kernel(), false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    // Transfer the now fully-initialized allocation into a `Vec<f64>`.
    let mut guard = core::mem::ManuallyDrop::new(buf_mu);
    // SAFETY: pointer, length and capacity come from the same allocation,
    // whose element layout matches `f64`, and every element is initialized.
    let values = unsafe {
        Vec::from_raw_parts(
            guard.as_mut_ptr() as *mut f64,
            guard.len(),
            guard.capacity(),
        )
    };

    let js = DmaBatchJsOutput {
        values,
        combos,
        rows,
        cols,
    };
    serde_wasm_bindgen::to_value(&js)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {e}")))
}
2658
/// WASM pointer API for a DMA parameter sweep.
///
/// Reads `len` values from `in_ptr` and writes a `rows * len` matrix, laid
/// out row by row, to `out_ptr`. Returns the number of rows, i.e. parameter
/// combinations. The caller must size `out_ptr` for `rows * len` f64 values.
///
/// # Errors
/// Returns a JS string error for null pointers, when `rows * len` overflows,
/// when the input is entirely NaN (or empty), or when the batch computation
/// fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dma_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    hull_start: usize,
    hull_end: usize,
    hull_step: usize,
    ema_start: usize,
    ema_end: usize,
    ema_step: usize,
    gain_start: usize,
    gain_end: usize,
    gain_step: usize,
    hull_ma_type: &str,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to dma_batch_into"));
    }

    let sweep = DmaBatchRange {
        hull_length: (hull_start, hull_end, hull_step),
        ema_length: (ema_start, ema_end, ema_step),
        ema_gain_limit: (gain_start, gain_end, gain_step),
        hull_ma_type: hull_ma_type.to_string(),
    };
    let combos = expand_grid_dma(&sweep);
    let rows = combos.len();
    let cols = len;
    // A wrapped product would size the raw output slice incorrectly.
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("rows * cols overflow in dma_batch_into"))?;

    // SAFETY: the caller guarantees `in_ptr` is valid for `len` f64 reads.
    let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };
    let first = data
        .iter()
        .position(|x| !x.is_nan())
        .ok_or_else(|| JsValue::from_str("All NaN"))?;

    // Per-row NaN warm-up prefix length, saturating and clamped to `cols` so
    // degenerate parameters surface as kernel errors rather than panics.
    let warm: Vec<usize> = combos
        .iter()
        .map(|c| {
            let h = c.hull_length.unwrap();
            let e = c.ema_length.unwrap();
            let sqrt_len = (h as f64).sqrt().round() as usize;
            (first + h.max(e) + sqrt_len).saturating_sub(1).min(cols)
        })
        .collect();

    // SAFETY: the caller guarantees `out_ptr` is valid for `total` f64 writes.
    // `MaybeUninit<f64>` has the same layout as `f64`.
    let out_mu =
        unsafe { std::slice::from_raw_parts_mut(out_ptr as *mut MaybeUninit<f64>, total) };
    init_matrix_prefixes(out_mu, cols, &warm);

    // SAFETY: same region as `out_mu`, which is not used past this point.
    let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };
    dma_batch_inner_into(data, &sweep, detect_best_kernel(), false, out)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    Ok(rows)
}
2714
#[cfg(test)]
mod tests {
    use super::*;
    use crate::skip_if_unsupported;
    use crate::utilities::data_loader::read_candles_from_csv;
    use std::error::Error;

    fn check_dma_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let input = DmaInput::from_candles(&candles, "close", DmaParams::default());
        let result = dma_with_kernel(&input, kernel)?;

        let expected_last_five = [
            59404.62489256,
            59326.48766951,
            59195.35128538,
            59153.22811529,
            58933.88503421,
        ];

        let start = result.values.len().saturating_sub(5);
        for (i, &val) in result.values[start..].iter().enumerate() {
            let diff = (val - expected_last_five[i]).abs();
            assert!(
                diff < 0.001,
                "[{}] DMA {:?} mismatch at idx {}: got {}, expected {}, diff {}",
                test_name,
                kernel,
                i,
                val,
                expected_last_five[i],
                diff
            );
        }
        Ok(())
    }

    fn check_dma_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let default_params = DmaParams {
            hull_length: None,
            ema_length: None,
            ema_gain_limit: None,
            hull_ma_type: None,
        };
        let input = DmaInput::from_candles(&candles, "close", default_params);
        let output = dma_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());

        Ok(())
    }

    fn check_dma_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let input = DmaInput::with_default_candles(&candles);
        match input.data {
            DmaData::Candles { source, .. } => assert_eq!(source, "close"),
            _ => panic!("Expected DmaData::Candles"),
        }
        let output = dma_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());

        Ok(())
    }

    fn check_dma_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let input_data = [10.0, 20.0, 30.0];
        let params = DmaParams {
            hull_length: Some(0),
            ema_length: None,
            ema_gain_limit: None,
            hull_ma_type: None,
        };
        let input = DmaInput::from_slice(&input_data, params);
        let res = dma_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] DMA should fail with zero period",
            test_name
        );
        Ok(())
    }

    fn check_dma_period_exceeds_length(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let data_small = [10.0, 20.0, 30.0];
        let params = DmaParams {
            hull_length: Some(10),
            ema_length: None,
            ema_gain_limit: None,
            hull_ma_type: None,
        };
        let input = DmaInput::from_slice(&data_small, params);
        let res = dma_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] DMA should fail with period exceeding length",
            test_name
        );
        Ok(())
    }

    fn check_dma_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let single_point = [42.0];
        let params = DmaParams::default();
        let input = DmaInput::from_slice(&single_point, params);
        let res = dma_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] DMA should fail with insufficient data",
            test_name
        );
        Ok(())
    }

    fn check_dma_empty_input(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let empty: [f64; 0] = [];
        let params = DmaParams::default();
        let input = DmaInput::from_slice(&empty, params);
        let res = dma_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] DMA should fail with empty input",
            test_name
        );
        Ok(())
    }

    fn check_dma_all_nan(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let nan_data = [f64::NAN, f64::NAN, f64::NAN];
        let params = DmaParams::default();
        let input = DmaInput::from_slice(&nan_data, params);
        let res = dma_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] DMA should fail with all NaN values",
            test_name
        );
        Ok(())
    }

    fn check_dma_invalid_hull_type(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let data = [10.0; 50];
        let params = DmaParams {
            hull_length: Some(7),
            ema_length: Some(20),
            ema_gain_limit: Some(50),
            hull_ma_type: Some("INVALID".to_string()),
        };
        let input = DmaInput::from_slice(&data, params);
        let res = dma_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] DMA should fail with invalid hull_ma_type",
            test_name
        );
        Ok(())
    }

    // Generates `_scalar`, `_auto`, `_avx2` and `_avx512` variants of each
    // single-series check. The cfg gate is attached to every AVX fn
    // individually: one attribute placed before a `$( ... )*` repetition
    // would only apply to the first generated item.
    macro_rules! generate_all_dma_tests {
        ($($test_fn:ident),*) => {
            paste::paste! {
                $(
                    #[test] fn [<$test_fn _scalar>]() -> Result<(), Box<dyn Error>> { $test_fn(stringify!([<$test_fn _scalar>]), Kernel::Scalar) }
                    #[test] fn [<$test_fn _auto> ]() -> Result<(), Box<dyn Error>> { $test_fn(stringify!([<$test_fn _auto>]), Kernel::Auto) }
                    #[cfg(all(feature="nightly-avx", target_arch="x86_64"))]
                    #[test] fn [<$test_fn _avx2> ]() -> Result<(), Box<dyn Error>> { $test_fn(stringify!([<$test_fn _avx2>]), Kernel::Avx2) }
                    #[cfg(all(feature="nightly-avx", target_arch="x86_64"))]
                    #[test] fn [<$test_fn _avx512>]() -> Result<(), Box<dyn Error>> { $test_fn(stringify!([<$test_fn _avx512>]), Kernel::Avx512) }
                )*
            }
        }
    }

    generate_all_dma_tests!(
        check_dma_accuracy,
        check_dma_partial_params,
        check_dma_default_candles,
        check_dma_zero_period,
        check_dma_period_exceeds_length,
        check_dma_very_small_dataset,
        check_dma_empty_input,
        check_dma_all_nan,
        check_dma_invalid_hull_type
    );

    // Batch-kernel variants; the AVX cfg is applied per generated fn (see above).
    macro_rules! generate_dma_batch_tests {
        ($($fn_name:ident),*) => {
            paste::paste! {
                $(
                    #[test]
                    fn [<$fn_name _scalar_batch>]() -> Result<(), Box<dyn Error>> {
                        $fn_name(stringify!([<$fn_name _scalar_batch>]), Kernel::ScalarBatch)
                    }
                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$fn_name _avx2_batch>]() -> Result<(), Box<dyn Error>> {
                        $fn_name(stringify!([<$fn_name _avx2_batch>]), Kernel::Avx2Batch)
                    }
                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$fn_name _avx512_batch>]() -> Result<(), Box<dyn Error>> {
                        $fn_name(stringify!([<$fn_name _avx512_batch>]), Kernel::Avx512Batch)
                    }
                )*
            }
        };
    }

    generate_dma_batch_tests!(check_dma_batch_basic);

    // Sweep tests propagate the check's `Result` so errors (e.g. a failed
    // `apply_candles`) fail the test instead of being silently discarded.
    macro_rules! gen_batch_tests {
        ($fn_name:ident) => {
            paste::paste! {
                #[test] fn [<$fn_name _scalar>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch)
                }
                #[cfg(all(feature="nightly-avx", target_arch="x86_64"))]
                #[test] fn [<$fn_name _avx2>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch)
                }
                #[cfg(all(feature="nightly-avx", target_arch="x86_64"))]
                #[test] fn [<$fn_name _avx512>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch)
                }
                #[test] fn [<$fn_name _auto_detect>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto)
                }
            }
        };
    }

    gen_batch_tests!(check_batch_sweep);

    fn check_dma_reinput(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);
        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;

        let first = DmaInput::from_candles(&c, "close", DmaParams::default());
        let out1 = dma_with_kernel(&first, kernel)?.values;

        let second = DmaInput::from_slice(&out1, DmaParams::default());
        let out2 = dma_with_kernel(&second, kernel)?.values;

        assert_eq!(out2.len(), out1.len());
        Ok(())
    }

    fn check_dma_nan_handling(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);
        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;

        let p = DmaParams::default();
        let input = DmaInput::from_candles(&c, "close", p.clone());
        let out = dma_with_kernel(&input, kernel)?.values;

        let first = c.close.iter().position(|x| !x.is_nan()).unwrap_or(0);
        let sqrt_len = (p.hull_length.unwrap_or(7) as f64).sqrt().round() as usize;
        let warm =
            first + p.hull_length.unwrap_or(7).max(p.ema_length.unwrap_or(20)) + sqrt_len - 1;
        for (i, &v) in out.iter().enumerate().skip(warm.min(out.len())) {
            assert!(!v.is_nan(), "[{test}] unexpected NaN at {i}");
        }
        Ok(())
    }

    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);
        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;

        let out = DmaBatchBuilder::new()
            .kernel(kernel)
            .apply_candles(&c, "close")?;
        let def = DmaParams::default();
        let row = out.values_for(&def).expect("default row missing");
        assert_eq!(row.len(), c.close.len());
        Ok(())
    }

    fn check_batch_sweep(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);
        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;
        let out = DmaBatchBuilder::new()
            .kernel(kernel)
            .hull_length_range(7, 18, 1)
            .ema_length_range(10, 15, 1)
            .ema_gain_limit_range(10, 20, 5)
            .apply_candles(&c, "close")?;
        // 12 hull lengths x 6 EMA lengths x 3 gain limits.
        let expected = 12 * 6 * 3;
        assert_eq!(out.combos.len(), expected);
        assert_eq!(out.rows, expected);
        assert_eq!(out.cols, c.close.len());
        Ok(())
    }

    fn check_dma_streaming(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);
        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;
        let p = DmaParams::default();

        let batch =
            dma_with_kernel(&DmaInput::from_candles(&c, "close", p.clone()), kernel)?.values;

        let mut s = DmaStream::try_new(p)?;
        let mut stream = Vec::with_capacity(c.close.len());
        for &x in &c.close {
            stream.push(s.update(x).unwrap_or(f64::NAN));
        }

        assert_eq!(batch.len(), stream.len());
        for (i, (&b, &t)) in batch.iter().zip(&stream).enumerate() {
            if b.is_nan() && t.is_nan() {
                continue;
            }
            assert!(
                (b - t).abs() < 1e-9,
                "[{test}] idx {i} diff {}",
                (b - t).abs()
            );
        }
        Ok(())
    }

    macro_rules! gen_added_dma_tests {
        ($($f:ident),*) => {
            paste::paste! {
                $(
                    #[test] fn [<$f _scalar>]() -> Result<(), Box<dyn Error>> {
                        $f(stringify!([<$f _scalar>]), Kernel::Scalar)
                    }
                    #[cfg(all(feature="nightly-avx", target_arch="x86_64"))]
                    #[test] fn [<$f _avx2>]() -> Result<(), Box<dyn Error>> {
                        $f(stringify!([<$f _avx2>]), Kernel::Avx2)
                    }
                    #[cfg(all(feature="nightly-avx", target_arch="x86_64"))]
                    #[test] fn [<$f _avx512>]() -> Result<(), Box<dyn Error>> {
                        $f(stringify!([<$f _avx512>]), Kernel::Avx512)
                    }
                )*
            }
        }
    }

    gen_added_dma_tests!(check_dma_reinput, check_dma_nan_handling);

    macro_rules! gen_batch_sweep_tests {
        ($($f:ident),*) => {
            paste::paste! {
                $(
                    #[test] fn [<$f _scalar_batch>]() -> Result<(), Box<dyn Error>> {
                        $f(stringify!([<$f _scalar_batch>]), Kernel::ScalarBatch)
                    }
                    #[cfg(all(feature="nightly-avx", target_arch="x86_64"))]
                    #[test] fn [<$f _avx2_batch>]() -> Result<(), Box<dyn Error>> {
                        $f(stringify!([<$f _avx2_batch>]), Kernel::Avx2Batch)
                    }
                    #[cfg(all(feature="nightly-avx", target_arch="x86_64"))]
                    #[test] fn [<$f _avx512_batch>]() -> Result<(), Box<dyn Error>> {
                        $f(stringify!([<$f _avx512_batch>]), Kernel::Avx512Batch)
                    }
                )*
            }
        }
    }

    gen_batch_sweep_tests!(check_batch_default_row);

    #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
    #[test]
    fn test_dma_simd128_correctness() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let p = DmaParams::default();
        let input = DmaInput::from_slice(&data, p);
        // NOTE(review): both runs request Kernel::Scalar. Presumably the wasm
        // build routes Scalar through simd128; confirm, otherwise this only
        // compares a kernel with itself.
        let scalar = dma_with_kernel(&input, Kernel::Scalar).unwrap();
        let simd = dma_with_kernel(&input, Kernel::Scalar).unwrap();
        assert_eq!(scalar.values.len(), simd.values.len());
        for (a, b) in scalar.values.iter().zip(simd.values.iter()) {
            assert!((a - b).abs() < 1e-10);
        }
    }

    #[cfg(debug_assertions)]
    #[test]
    fn test_dma_no_poison_values() -> Result<(), Box<dyn Error>> {
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let input = DmaInput::from_candles(&candles, "close", DmaParams::default());
        let output = dma(&input)?;

        for &v in &output.values {
            if v.is_nan() {
                continue;
            }
            let b = v.to_bits();

            assert_ne!(
                b, 0x11111111_11111111,
                "Found poison value 0x11111111_11111111"
            );
            assert_ne!(
                b, 0x22222222_22222222,
                "Found poison value 0x22222222_22222222"
            );
            assert_ne!(
                b, 0x33333333_33333333,
                "Found poison value 0x33333333_33333333"
            );
            assert_ne!(
                b, 0xDEADBEEF_DEADBEEF,
                "Found poison value 0xDEADBEEF_DEADBEEF"
            );
            assert_ne!(
                b, 0xFEEEFEEE_FEEEFEEE,
                "Found poison value 0xFEEEFEEE_FEEEFEEE"
            );
        }
        Ok(())
    }

    fn check_dma_batch_basic(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let data = vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0];

        let sweep = DmaBatchRange {
            hull_length: (3, 5, 1),
            ema_length: (5, 5, 0),
            ema_gain_limit: (10, 10, 0),
            hull_ma_type: "WMA".to_string(),
        };
        let output = dma_batch_with_kernel(&data, &sweep, kernel)?;

        assert_eq!(
            output.rows, 3,
            "[{}] Expected 3 rows for hull_length range 3-5",
            test_name
        );
        assert_eq!(output.cols, data.len());
        assert_eq!(output.values.len(), output.rows * output.cols);
        assert_eq!(output.combos.len(), output.rows);

        Ok(())
    }

    #[test]
    fn test_dma_stream_incremental() -> Result<(), Box<dyn Error>> {
        let params = DmaParams {
            hull_length: Some(3),
            ema_length: Some(3),
            ema_gain_limit: Some(10),
            hull_ma_type: Some("WMA".to_string()),
        };

        let mut stream = DmaStream::try_new(params)?;
        let data = vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0];

        let mut results = Vec::new();
        for &val in &data {
            if let Some(result) = stream.update(val) {
                results.push(result);
            }
        }

        assert!(
            !results.is_empty(),
            "Stream should produce results after warmup"
        );

        Ok(())
    }

    #[cfg(debug_assertions)]
    #[test]
    fn test_dma_batch_no_poison_values() -> Result<(), Box<dyn std::error::Error>> {
        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;
        let out = DmaBatchBuilder::new()
            .hull_length_range(3, 8, 1)
            .ema_length_range(5, 10, 1)
            .ema_gain_limit_static(10)
            .apply_slice(&c.close)?;
        for &v in &out.values {
            if v.is_nan() {
                continue;
            }
            let b = v.to_bits();
            assert_ne!(b, 0x11111111_11111111);
            assert_ne!(b, 0x22222222_22222222);
            assert_ne!(b, 0x33333333_33333333);
        }
        Ok(())
    }

    #[test]
    fn test_dma_into_matches_api() -> Result<(), Box<dyn Error>> {
        let mut data = Vec::with_capacity(160);
        data.extend_from_slice(&[f64::NAN, f64::NAN, f64::NAN]);
        for i in 0..157 {
            let x = (i as f64 * 0.15).sin() * 5.0 + (i as f64) * 0.01;
            data.push(x);
        }

        let input = DmaInput::from_slice(&data, DmaParams::default());

        let baseline = dma(&input)?;

        let mut out = vec![0.0; data.len()];
        #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
        {
            dma_into(&input, &mut out)?;
        }
        #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
        {
            dma_into_slice(&mut out, &input, Kernel::Auto)?;
        }

        assert_eq!(baseline.values.len(), out.len());

        for (a, b) in baseline.values.iter().copied().zip(out.iter().copied()) {
            let both_nan = a.is_nan() && b.is_nan();
            assert!(both_nan || a == b, "mismatch: got {b:?}, expected {a:?}");
        }
        Ok(())
    }
}
3257
/// Python wrapper around a device-resident f32 DMA result matrix.
///
/// Exposes `__cuda_array_interface__` and the DLPack protocol
/// (`__dlpack__` / `__dlpack_device__`). Declared `unsendable`, so pyo3 ties
/// instances to the thread that created them.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", name = "DmaDeviceArrayF32", unsendable)]
pub struct DmaDeviceArrayF32Py {
    /// Device buffer; `None` once it has been moved out by `__dlpack__`.
    pub(crate) inner: Option<DeviceArrayF32>,
    /// CUDA context handle, presumably held so the context outlives the device buffer. TODO confirm.
    pub(crate) _ctx: Arc<Context>,
    /// Device ordinal reported by `__dlpack_device__` and passed to the DLPack exporter.
    pub(crate) device_id: u32,
}
3265
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl DmaDeviceArrayF32Py {
    /// CUDA Array Interface (version 3) for the wrapped buffer, described as a
    /// row-major `(rows, cols)` float32 matrix with explicit byte strides.
    ///
    /// Raises `ValueError` once the buffer has been handed off via `__dlpack__`.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let inner = match self.inner.as_ref() {
            Some(arr) => arr,
            None => {
                return Err(PyValueError::new_err(
                    "buffer already exported via __dlpack__",
                ))
            }
        };

        let item_bytes = std::mem::size_of::<f32>();
        let row_bytes = inner.cols * item_bytes;

        let iface = PyDict::new(py);
        iface.set_item("shape", (inner.rows, inner.cols))?;
        iface.set_item("typestr", "<f4")?;
        iface.set_item("strides", (row_bytes, item_bytes))?;
        // (pointer, read_only=false)
        iface.set_item("data", (inner.device_ptr() as usize, false))?;
        iface.set_item("version", 3)?;
        Ok(iface)
    }

    /// DLPack `(device_type, device_id)` pair; device type 2 is `kDLCUDA`.
    fn __dlpack_device__(&self) -> (i32, i32) {
        const KDL_CUDA: i32 = 2;
        (KDL_CUDA, self.device_id as i32)
    }

    /// Exports the buffer as a DLPack capsule, moving ownership out of this
    /// wrapper. It can therefore succeed only once.
    ///
    /// A `dl_device` that parses as `(type, id)` and differs from the
    /// allocation device is rejected before anything is consumed. `stream` is
    /// accepted but not used.
    #[pyo3(signature=(stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<pyo3::PyObject> {
        use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;

        let (kdl, alloc_dev) = self.__dlpack_device__();

        let requested = dl_device
            .as_ref()
            .and_then(|obj| obj.extract::<(i32, i32)>(py).ok());
        if let Some(target) = requested {
            if target != (kdl, alloc_dev) {
                let wants_copy = matches!(
                    copy.as_ref().map(|c| c.extract::<bool>(py)),
                    Some(Ok(true))
                );
                let msg = if wants_copy {
                    "device copy not implemented for __dlpack__"
                } else {
                    "dl_device mismatch for __dlpack__"
                };
                return Err(PyValueError::new_err(msg));
            }
        }

        let _ = stream;

        // Validation is done; only now take ownership of the device buffer.
        let inner = match self.inner.take() {
            Some(arr) => arr,
            None => {
                return Err(PyValueError::new_err(
                    "__dlpack__ may only be called once",
                ))
            }
        };

        let rows = inner.rows;
        let cols = inner.cols;
        let buf = inner.buf;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}