1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::moving_averages::CudaHma;
3use crate::utilities::data_loader::{source_type, Candles};
4use crate::utilities::enums::Kernel;
5use crate::utilities::helpers::{
6 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
7 make_uninit_matrix,
8};
9use aligned_vec::{AVec, CACHELINE_ALIGN};
10#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
11use core::arch::x86_64::*;
12#[cfg(all(feature = "python", feature = "cuda"))]
13use cust::memory::DeviceBuffer;
14#[cfg(not(target_arch = "wasm32"))]
15use rayon::prelude::*;
16use std::convert::AsRef;
17use std::error::Error;
18use std::mem::MaybeUninit;
19use thiserror::Error;
20impl<'a> AsRef<[f64]> for HmaInput<'a> {
21 #[inline(always)]
22 fn as_ref(&self) -> &[f64] {
23 match &self.data {
24 HmaData::Slice(slice) => slice,
25 HmaData::Candles { candles, source } => source_type(candles, source),
26 }
27 }
28}
29
/// Where the input series for an HMA computation comes from.
#[derive(Debug, Clone)]
pub enum HmaData<'a> {
    /// A candle set plus the name of the series to read from it
    /// (resolved through `source_type` when the input is viewed as a slice).
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A caller-provided slice of values, used as-is.
    Slice(&'a [f64]),
}
38
/// Result of a single-series HMA computation.
#[derive(Debug, Clone)]
pub struct HmaOutput {
    /// Output series; `hma_with_kernel` allocates it with the same length as the input.
    pub values: Vec<f64>,
}
43
/// Parameters of the HMA indicator.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct HmaParams {
    /// Window length. Consumers in this file treat `None` as 5.
    pub period: Option<usize>,
}
52
53impl Default for HmaParams {
54 fn default() -> Self {
55 Self { period: Some(5) }
56 }
57}
58
/// Bundles the data source and the parameters for one HMA computation.
#[derive(Debug, Clone)]
pub struct HmaInput<'a> {
    /// Series to process (candles + source name, or a raw slice).
    pub data: HmaData<'a>,
    /// Indicator parameters.
    pub params: HmaParams,
}
64
65impl<'a> HmaInput<'a> {
66 #[inline]
67 pub fn from_candles(c: &'a Candles, s: &'a str, p: HmaParams) -> Self {
68 Self {
69 data: HmaData::Candles {
70 candles: c,
71 source: s,
72 },
73 params: p,
74 }
75 }
76 #[inline]
77 pub fn from_slice(sl: &'a [f64], p: HmaParams) -> Self {
78 Self {
79 data: HmaData::Slice(sl),
80 params: p,
81 }
82 }
83 #[inline]
84 pub fn with_default_candles(c: &'a Candles) -> Self {
85 Self::from_candles(c, "close", HmaParams::default())
86 }
87 #[inline]
88 pub fn get_period(&self) -> usize {
89 self.params.period.unwrap_or(5)
90 }
91}
92
/// Fluent builder for single-series HMA computations and streams.
#[derive(Copy, Clone, Debug)]
pub struct HmaBuilder {
    /// Requested period; `None` leaves the choice to the downstream default.
    period: Option<usize>,
    /// Kernel selection forwarded to `hma_with_kernel`.
    kernel: Kernel,
}
98
99impl Default for HmaBuilder {
100 fn default() -> Self {
101 Self {
102 period: None,
103 kernel: Kernel::Auto,
104 }
105 }
106}
107
108impl HmaBuilder {
109 #[inline(always)]
110 pub fn new() -> Self {
111 Self::default()
112 }
113 #[inline(always)]
114 pub fn period(mut self, n: usize) -> Self {
115 self.period = Some(n);
116 self
117 }
118 #[inline(always)]
119 pub fn kernel(mut self, k: Kernel) -> Self {
120 self.kernel = k;
121 self
122 }
123 #[inline(always)]
124 pub fn apply(self, c: &Candles) -> Result<HmaOutput, HmaError> {
125 let p = HmaParams {
126 period: self.period,
127 };
128 let i = HmaInput::from_candles(c, "close", p);
129 hma_with_kernel(&i, self.kernel)
130 }
131 #[inline(always)]
132 pub fn apply_slice(self, d: &[f64]) -> Result<HmaOutput, HmaError> {
133 let p = HmaParams {
134 period: self.period,
135 };
136 let i = HmaInput::from_slice(d, p);
137 hma_with_kernel(&i, self.kernel)
138 }
139
140 #[inline(always)]
141 pub fn into_stream(self) -> Result<HmaStream, HmaError> {
142 let p = HmaParams {
143 period: self.period,
144 };
145 HmaStream::try_new(p)
146 }
147}
148
/// Errors reported by the HMA entry points in this file.
#[derive(Debug, Error)]
pub enum HmaError {
    /// The input series is empty.
    #[error("hma: No data provided.")]
    NoData,

    /// No non-NaN value was found in the input.
    #[error("hma: All values are NaN.")]
    AllValuesNaN,

    /// The period is zero, exceeds the data length, or (for streams) is below 2.
    #[error("hma: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },

    /// A caller-supplied output buffer does not have the required length.
    #[error("hma: Output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },

    /// The batch period range expanded to no parameter combinations.
    #[error("hma: Invalid range: start = {start}, end = {end}, step = {step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },

    /// A non-batch kernel was passed to the batch API.
    #[error("hma: Invalid kernel for batch API: {0:?}")]
    InvalidKernelForBatch(Kernel),

    /// A size computation (named by `what`) overflowed `usize`.
    #[error("hma: arithmetic overflow when computing {what}")]
    ArithmeticOverflow { what: &'static str },

    /// `period / 2` is zero, so the half-length WMA has no window.
    #[error("hma: Cannot calculate half of period: period = {period}")]
    ZeroHalf { period: usize },

    /// `floor(sqrt(period))` is zero, so the smoothing WMA has no window.
    #[error("hma: Cannot calculate sqrt of period: period = {period}")]
    ZeroSqrtPeriod { period: usize },

    /// Fewer values remain after the first non-NaN element than the computation needs.
    #[error("hma: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
}
185
186#[inline]
187pub fn hma(input: &HmaInput) -> Result<HmaOutput, HmaError> {
188 hma_with_kernel(input, Kernel::Auto)
189}
190
191#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
192pub fn hma_into(input: &HmaInput, out: &mut [f64]) -> Result<(), HmaError> {
193 hma_into_internal(input, out)
194}
195
196#[inline]
197fn hma_into_internal(input: &HmaInput, out: &mut [f64]) -> Result<(), HmaError> {
198 hma_with_kernel_into(input, Kernel::Auto, out)
199}
200
201pub fn hma_with_kernel(input: &HmaInput, kernel: Kernel) -> Result<HmaOutput, HmaError> {
202 let data: &[f64] = input.as_ref();
203 let len = data.len();
204 if len == 0 {
205 return Err(HmaError::NoData);
206 }
207 let first = data
208 .iter()
209 .position(|x| !x.is_nan())
210 .ok_or(HmaError::AllValuesNaN)?;
211 let period = input.get_period();
212 if period == 0 || period > len {
213 return Err(HmaError::InvalidPeriod {
214 period,
215 data_len: len,
216 });
217 }
218 if len - first < period {
219 return Err(HmaError::NotEnoughValidData {
220 needed: period,
221 valid: len - first,
222 });
223 }
224 let half = period / 2;
225 if half == 0 {
226 return Err(HmaError::ZeroHalf { period });
227 }
228 let sqrt_len = (period as f64).sqrt().floor() as usize;
229 if sqrt_len == 0 {
230 return Err(HmaError::ZeroSqrtPeriod { period });
231 }
232 if len - first < period + sqrt_len - 1 {
233 return Err(HmaError::NotEnoughValidData {
234 needed: period + sqrt_len - 1,
235 valid: len - first,
236 });
237 }
238 let chosen = match kernel {
239 Kernel::Auto => detect_best_kernel(),
240 other => other,
241 };
242 let first_out = first + period + sqrt_len - 2;
243 let mut out = alloc_with_nan_prefix(len, first_out);
244 unsafe {
245 match chosen {
246 Kernel::Scalar | Kernel::ScalarBatch => hma_scalar(data, period, first, &mut out),
247 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
248 Kernel::Avx2 | Kernel::Avx2Batch => hma_avx2(data, period, first, &mut out),
249 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
250 Kernel::Avx512 | Kernel::Avx512Batch => hma_avx512(data, period, first, &mut out),
251 _ => unreachable!(),
252 }
253 }
254 Ok(HmaOutput { values: out })
255}
256
257fn hma_with_kernel_into(input: &HmaInput, kernel: Kernel, out: &mut [f64]) -> Result<(), HmaError> {
258 let data: &[f64] = input.as_ref();
259 let len = data.len();
260 if len == 0 {
261 return Err(HmaError::NoData);
262 }
263 if out.len() != len {
264 return Err(HmaError::OutputLengthMismatch {
265 expected: len,
266 got: out.len(),
267 });
268 }
269
270 let first = data
271 .iter()
272 .position(|x| !x.is_nan())
273 .ok_or(HmaError::AllValuesNaN)?;
274 let period = input.get_period();
275 if period == 0 || period > len {
276 return Err(HmaError::InvalidPeriod {
277 period,
278 data_len: len,
279 });
280 }
281 if len - first < period {
282 return Err(HmaError::NotEnoughValidData {
283 needed: period,
284 valid: len - first,
285 });
286 }
287
288 let half = period / 2;
289 if half == 0 {
290 return Err(HmaError::ZeroHalf { period });
291 }
292 let sqrt_len = (period as f64).sqrt().floor() as usize;
293 if sqrt_len == 0 {
294 return Err(HmaError::ZeroSqrtPeriod { period });
295 }
296 if len - first < period + sqrt_len - 1 {
297 return Err(HmaError::NotEnoughValidData {
298 needed: period + sqrt_len - 1,
299 valid: len - first,
300 });
301 }
302
303 let chosen = match kernel {
304 Kernel::Auto => detect_best_kernel(),
305 other => other,
306 };
307 unsafe {
308 match chosen {
309 Kernel::Scalar | Kernel::ScalarBatch => hma_scalar(data, period, first, out),
310 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
311 Kernel::Avx2 | Kernel::Avx2Batch => hma_avx2(data, period, first, out),
312 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
313 Kernel::Avx512 | Kernel::Avx512Batch => hma_avx512(data, period, first, out),
314 _ => unreachable!(),
315 }
316 }
317
318 let first_out = first + period + sqrt_len - 2;
319 for v in &mut out[..first_out] {
320 *v = f64::NAN;
321 }
322 Ok(())
323}
324
/// Scalar HMA kernel: writes `WMA_sqrt(2 * WMA_half - WMA_full)` into `out`.
///
/// * `data`   – input series; elements from `first` onward are consumed.
/// * `period` – full window length; `half = period / 2` and
///   `sq = floor(sqrt(period))` are derived from it.
/// * `first`  – index of the first element to use (the caller's first non-NaN index).
/// * `out`    – destination; only indices `first + period + sq - 2 .. data.len()`
///   are written, everything else is left untouched.
///
/// The function returns without writing anything when `period < 2`, when
/// `first` is out of range, or when fewer than `period + sq - 1` elements are
/// available from `first`.
///
/// All three weighted averages are maintained incrementally: for a window of
/// length `n` with plain sum `S` and weighted sum `W`, sliding by one sample
/// `v` gives `W' = W - S + n * v` and `S' = S + v - oldest`.
///
/// # Panics
/// Panics on indexing if `out` is shorter than `data`.
#[inline]
pub fn hma_scalar(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    let len = data.len();
    if period < 2 || first >= len || period > len - first {
        return;
    }
    // `period >= 2` guarantees both derived window lengths are at least 1.
    let half = period / 2;
    let sq = (period as f64).sqrt().floor() as usize;

    let first_out = first + period + sq - 2;
    if first_out >= len {
        return;
    }

    // Normalisers: sum of the weights 1..=n for each window.
    let ws_half = (half * (half + 1) / 2) as f64;
    let ws_full = (period * (period + 1) / 2) as f64;
    let ws_sqrt = (sq * (sq + 1) / 2) as f64;
    let half_f = half as f64;
    let period_f = period as f64;
    let sq_f = sq as f64;

    let (mut s_half, mut ws_half_acc) = (0.0f64, 0.0f64);
    let (mut s_full, mut ws_full_acc) = (0.0f64, 0.0f64);

    // Ring buffer of the last `sq` raw values x = 2*WMA_half - WMA_full.
    let mut x_buf = vec![0.0f64; sq];
    let mut x_sum = 0.0f64;
    let mut x_wsum = 0.0f64;
    let mut x_head = 0usize;

    // Phase 1: both windows are still filling.
    for j in 0..half {
        let v = data[first + j];
        let w = j as f64 + 1.0;
        s_full += v;
        ws_full_acc = w.mul_add(v, ws_full_acc);
        s_half += v;
        ws_half_acc = w.mul_add(v, ws_half_acc);
    }

    // Phase 2: the half window slides while the full window finishes filling.
    for j in half..period {
        let idx = first + j;
        let v = data[idx];

        s_full += v;
        ws_full_acc = (j as f64 + 1.0).mul_add(v, ws_full_acc);

        let prev_h = s_half;
        s_half = prev_h + v - data[idx - half];
        ws_half_acc = half_f.mul_add(v, ws_half_acc - prev_h);
    }

    // Both windows are primed: first raw value, at index `first + period - 1`.
    let x0 = 2.0 * (ws_half_acc / ws_half) - ws_full_acc / ws_full;
    x_buf[0] = x0;
    x_sum += x0;
    x_wsum = 1.0f64.mul_add(x0, x_wsum);

    // Phase 3: fill the remaining `sq - 1` slots of the smoothing window.
    for pos in 1..sq {
        let idx = first + period - 1 + pos;
        let v = data[idx];

        let prev_f = s_full;
        s_full = prev_f + v - data[idx - period];
        ws_full_acc = period_f.mul_add(v, ws_full_acc - prev_f);

        let prev_h = s_half;
        s_half = prev_h + v - data[idx - half];
        ws_half_acc = half_f.mul_add(v, ws_half_acc - prev_h);

        let x = 2.0 * (ws_half_acc / ws_half) - ws_full_acc / ws_full;
        x_buf[pos] = x;
        x_sum += x;
        x_wsum = (pos as f64 + 1.0).mul_add(x, x_wsum);
    }
    // The smoothing window is complete exactly at `first_out`.
    out[first_out] = x_wsum / ws_sqrt;

    // Phase 4: steady state, all three windows slide by one sample per step.
    for idx in (first_out + 1)..len {
        let v = data[idx];

        let prev_f = s_full;
        s_full = prev_f + v - data[idx - period];
        ws_full_acc = period_f.mul_add(v, ws_full_acc - prev_f);

        let prev_h = s_half;
        s_half = prev_h + v - data[idx - half];
        ws_half_acc = half_f.mul_add(v, ws_half_acc - prev_h);

        let x = 2.0 * (ws_half_acc / ws_half) - ws_full_acc / ws_full;

        // `x_head` always points at the oldest raw value.
        let old_x = x_buf[x_head];
        x_buf[x_head] = x;
        x_head += 1;
        if x_head == sq {
            x_head = 0;
        }

        let prev_sum = x_sum;
        x_sum = prev_sum + x - old_x;
        x_wsum = sq_f.mul_add(x, x_wsum - prev_sum);

        out[idx] = x_wsum / ws_sqrt;
    }
}
478
/// AVX2 entry point for HMA.
///
/// There is no dedicated AVX2 implementation here: the call is forwarded to
/// [`hma_scalar`] with the same arguments.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn hma_avx2(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    hma_scalar(data, period, first, out);
}
484
/// AVX-512 HMA kernel: writes `WMA_sqrt(2 * WMA_half - WMA_full)` into `out`.
///
/// Uses the same early-return conditions and the same output range
/// (`first + period + sq - 2 .. data.len()`) as the checks at the top of the body.
/// The sliding sums are kept in scalars; SIMD is used for the one-off weighted
/// sum over the smoothing buffer and for the paired full/half window update.
///
/// # Safety
/// * The CPU must support the `avx512f` and `fma` target features.
/// * `out` is written with `get_unchecked_mut` at indices up to `data.len() - 1`,
///   so it must be at least as long as `data`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
#[inline]
pub unsafe fn hma_avx512(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    use aligned_vec::AVec;
    use core::arch::x86_64::*;

    let len = data.len();
    if period < 2 || first >= len || period > len - first {
        return;
    }
    let half = period / 2;
    if half == 0 {
        return;
    }
    let sq = (period as f64).sqrt().floor() as usize;
    debug_assert!(
        sq > 0 && sq <= 65_535,
        "HMA: √period must fit in 16-bit to keep Σw < 2^53"
    );
    if sq == 0 {
        return;
    }
    // Index of the first output element; nothing before it is written.
    let first_out = first + period + sq - 2;
    if first_out >= len {
        return;
    }

    // Normalisers: sum of the weights 1..=n for each window.
    let ws_half = (half * (half + 1) / 2) as f64;
    let ws_full = (period * (period + 1) / 2) as f64;
    let ws_sqrt = (sq * (sq + 1) / 2) as f64;
    let sq_f = sq as f64;

    let (mut s_half, mut ws_half_acc) = (0.0, 0.0);
    let (mut s_full, mut ws_full_acc) = (0.0, 0.0);
    let (mut wma_half, mut wma_full) = (f64::NAN, f64::NAN);

    // Smoothing buffer, rounded up to a multiple of 8 lanes and zero-filled.
    let sq_aligned = (sq + 7) & !7;
    let mut x_buf: AVec<f64> = AVec::with_capacity(64, sq_aligned);
    x_buf.resize(sq_aligned, 0.0);

    let mut x_sum = 0.0;
    let mut x_wsum = 0.0;
    let mut x_head = 0usize;

    // Lane offsets 0..7; adding `off + 1` yields the weights for one 8-wide chunk.
    const W_RAMP_ARR: [f64; 8] = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];

    let w_ramp: __m512d = _mm512_loadu_pd(W_RAMP_ARR.as_ptr());

    /// Sums the eight lanes of `z` into a single `f64`.
    #[inline(always)]
    unsafe fn horiz_sum(z: __m512d) -> f64 {
        let hi = _mm512_extractf64x4_pd(z, 1);
        let lo = _mm512_castpd512_pd256(z);

        let sum256 = _mm256_add_pd(hi, lo);

        let sum128 = _mm256_hadd_pd(sum256, sum256);

        let hi128 = _mm256_extractf128_pd(sum128, 1);
        let lo128 = _mm256_castpd256_pd128(sum128);
        let final_sum = _mm_add_pd(hi128, lo128);

        _mm_cvtsd_f64(final_sum)
    }

    // Warm-up: fill (then slide) the full and half windows and collect the first
    // `sq` raw values; the first output is produced on the last iteration.
    for j in 0..(period + sq - 1) {
        let idx = first + j;
        let val = *data.get_unchecked(idx);

        if j < period {
            s_full += val;
            ws_full_acc += (j as f64 + 1.0) * val;
        } else {
            let old = *data.get_unchecked(idx - period);
            let prev = s_full;
            s_full = prev + val - old;
            ws_full_acc = ws_full_acc - prev + (period as f64) * val;
        }

        if j < half {
            s_half += val;
            ws_half_acc += (j as f64 + 1.0) * val;
        } else {
            let old = *data.get_unchecked(idx - half);
            let prev = s_half;
            s_half = prev + val - old;
            ws_half_acc = ws_half_acc - prev + (half as f64) * val;
        }

        if j + 1 >= half {
            wma_half = ws_half_acc / ws_half;
        }
        if j + 1 >= period {
            wma_full = ws_full_acc / ws_full;
        }

        if j + 1 >= period {
            let x_val = 2.0 * wma_half - wma_full;
            let pos = (j + 1 - period) as usize;

            if pos < sq {
                *x_buf.get_unchecked_mut(pos) = x_val;
                x_sum += x_val;

                // Smoothing buffer just became full: compute its weighted sum
                // from scratch, 8 lanes at a time plus a scalar tail.
                if pos + 1 == sq {
                    let mut acc = _mm512_setzero_pd();
                    let mut i = 0usize;
                    let mut off = 0.0;
                    while i + 8 <= sq {
                        let x = _mm512_loadu_pd(x_buf.as_ptr().add(i));

                        let w = _mm512_add_pd(w_ramp, _mm512_set1_pd(off + 1.0));
                        acc = _mm512_fmadd_pd(x, w, acc);
                        i += 8;
                        off += 8.0;
                    }
                    x_wsum = horiz_sum(acc);
                    for k in i..sq {
                        x_wsum += x_buf[k] * (k as f64 + 1.0);
                    }
                    *out.get_unchecked_mut(first_out) = x_wsum / ws_sqrt;
                }
            }
        }
    }

    // Steady state: slide all three windows by one sample per iteration.
    for j in (period + sq - 1)..(len - first) {
        let idx = first + j;
        let val = *data.get_unchecked(idx);

        let old_f = *data.get_unchecked(idx - period);
        let old_h = *data.get_unchecked(idx - half);

        // Pack full (high lane) and half (low lane) state to update both at once.
        let sum_vec = _mm_set_pd(s_full, s_half);
        let old_vec = _mm_set_pd(old_f, old_h);
        let ws_vec = _mm_set_pd(ws_full_acc, ws_half_acc);
        let weights = _mm_set_pd(period as f64, half as f64);
        let v_val = _mm_set1_pd(val);

        // S' = S - oldest + v
        let new_sum_vec = _mm_add_pd(_mm_sub_pd(sum_vec, old_vec), v_val);

        // W' = v * n + (W - S)
        let diff = _mm_sub_pd(ws_vec, sum_vec);
        let new_ws_vec = _mm_fmadd_pd(v_val, weights, diff);

        s_full = _mm_cvtsd_f64(_mm_unpackhi_pd(new_sum_vec, new_sum_vec));
        s_half = _mm_cvtsd_f64(new_sum_vec);
        ws_full_acc = _mm_cvtsd_f64(_mm_unpackhi_pd(new_ws_vec, new_ws_vec));
        ws_half_acc = _mm_cvtsd_f64(new_ws_vec);

        wma_full = ws_full_acc / ws_full;
        wma_half = ws_half_acc / ws_half;
        let x_val = 2.0 * wma_half - wma_full;

        // Replace the oldest raw value in the ring buffer.
        let old_x = *x_buf.get_unchecked(x_head);
        *x_buf.get_unchecked_mut(x_head) = x_val;
        x_head = (x_head + 1) % sq;

        let prev_sum = x_sum;
        x_sum = prev_sum + x_val - old_x;
        x_wsum = sq_f.mul_add(x_val, x_wsum - prev_sum);

        *out.get_unchecked_mut(idx) = x_wsum / ws_sqrt;

        // Prefetch input 128 elements ahead, clamped to the last valid index.
        let pf = core::cmp::min(idx + 128, len - 1);
        _mm_prefetch(data.as_ptr().add(pf) as *const i8, _MM_HINT_T1);
    }
}
652
/// Streaming linearly-weighted moving average over a fixed-size ring buffer.
///
/// The newest sample carries weight `period`, the oldest weight 1. While any
/// NaN is inside the window the output is NaN and the running sums are marked
/// stale; they are rebuilt from the buffer once the window is clean again.
#[derive(Debug, Clone)]
struct LinWma {
    /// Window length.
    period: usize,
    /// Reciprocal of the weight total `period * (period + 1) / 2`.
    inv_norm: f64,
    /// Ring buffer of the last `period` samples.
    buffer: Vec<f64>,
    /// Next write position, which is also the oldest sample once filled.
    head: usize,
    /// Whether `period` samples have been received.
    filled: bool,
    /// Number of samples received during warm-up.
    count: usize,

    /// Plain sum of the window.
    sum: f64,
    /// Weighted sum of the window.
    wsum: f64,
    /// Number of NaN samples currently tracked in the window.
    nan_count: usize,
    /// Set when `sum`/`wsum` no longer reflect the buffer contents.
    dirty: bool,
}

impl LinWma {
    /// Creates an empty window of length `period`.
    #[inline(always)]
    fn new(period: usize) -> Self {
        let p = period as f64;
        let norm = p * (p + 1.0) * 0.5;
        Self {
            period,
            inv_norm: 1.0 / norm,
            buffer: vec![f64::NAN; period],
            head: 0,
            filled: false,
            count: 0,
            sum: 0.0,
            wsum: 0.0,
            nan_count: 0,
            dirty: false,
        }
    }

    /// Recomputes `sum`, `wsum` and `nan_count` from the buffer, oldest sample first.
    #[inline(always)]
    fn rebuild(&mut self) {
        let mut sum = 0.0;
        let mut wsum = 0.0;
        let mut nans = 0usize;

        // Oldest-to-newest order is `buffer[head..]` followed by `buffer[..head]`.
        let (wrapped, oldest) = self.buffer.split_at(self.head);
        for (i, &v) in oldest.iter().chain(wrapped.iter()).enumerate() {
            if v.is_nan() {
                nans += 1;
            } else {
                sum += v;
                wsum = (i as f64 + 1.0).mul_add(v, wsum);
            }
        }

        self.sum = sum;
        self.wsum = wsum;
        self.nan_count = nans;
        self.dirty = nans != 0;
        debug_assert!(self.nan_count == 0, "rebuild expected clean window");
    }

    /// Pushes `value` and returns the current WMA.
    ///
    /// Returns `None` until `period` samples have been seen, and `Some(NaN)`
    /// while a NaN sample is inside the window.
    #[inline(always)]
    fn update(&mut self, value: f64) -> Option<f64> {
        let slot = self.head;
        let old = std::mem::replace(&mut self.buffer[slot], value);
        self.head = if slot + 1 == self.period { 0 } else { slot + 1 };

        if !self.filled {
            // Warm-up: the k-th sample enters with weight k.
            self.count += 1;
            if value.is_nan() {
                self.nan_count += 1;
                self.dirty = true;
            } else {
                self.sum += value;
                self.wsum = (self.count as f64).mul_add(value, self.wsum);
            }

            if self.count != self.period {
                return None;
            }
            self.filled = true;
            let first = if self.nan_count > 0 {
                f64::NAN
            } else {
                self.wsum * self.inv_norm
            };
            return Some(first);
        }

        if old.is_nan() {
            self.nan_count = self.nan_count.saturating_sub(1);
        }
        if value.is_nan() {
            self.nan_count += 1;
        }

        if self.nan_count > 0 {
            // The running sums are not updated here, so flag them for a rebuild.
            self.dirty = true;
            return Some(f64::NAN);
        }

        if self.dirty {
            self.rebuild();
            self.dirty = false;
            debug_assert_eq!(self.nan_count, 0);
            return Some(self.wsum * self.inv_norm);
        }

        // Clean slide: W' = W - S + n * v, S' = S + v - oldest.
        let n = self.period as f64;
        let prev_sum = self.sum;
        self.sum = prev_sum + value - old;
        self.wsum = n.mul_add(value, self.wsum - prev_sum);

        Some(self.wsum * self.inv_norm)
    }
}
768
/// Streaming HMA calculator built from three linearly-weighted moving averages.
#[derive(Debug, Clone)]
pub struct HmaStream {
    /// WMA over `period / 2` input samples.
    wma_half: LinWma,
    /// WMA over `period` input samples.
    wma_full: LinWma,
    /// WMA over `floor(sqrt(period))` values of `2 * half - full`.
    wma_sqrt: LinWma,
}
775
776impl HmaStream {
777 pub fn try_new(params: HmaParams) -> Result<Self, HmaError> {
778 let period = params.period.unwrap_or(5);
779 if period < 2 {
780 return Err(HmaError::InvalidPeriod {
781 period,
782 data_len: 0,
783 });
784 }
785 let half = period / 2;
786 if half == 0 {
787 return Err(HmaError::ZeroHalf { period });
788 }
789 let sqrt_len = (period as f64).sqrt().floor() as usize;
790 if sqrt_len == 0 {
791 return Err(HmaError::ZeroSqrtPeriod { period });
792 }
793
794 Ok(Self {
795 wma_half: LinWma::new(half),
796 wma_full: LinWma::new(period),
797 wma_sqrt: LinWma::new(sqrt_len),
798 })
799 }
800
801 #[inline(always)]
802 pub fn update(&mut self, value: f64) -> Option<f64> {
803 let full = self.wma_full.update(value);
804 let half = self.wma_half.update(value);
805
806 if let (Some(f), Some(h)) = (full, half) {
807 let x = 2.0f64.mul_add(h, -f);
808 self.wma_sqrt.update(x)
809 } else {
810 None
811 }
812 }
813}
814
/// Parameter sweep description for batch HMA.
#[derive(Clone, Debug)]
pub struct HmaBatchRange {
    /// `(start, end, step)` for the period axis; see `expand_grid` for how it is expanded.
    pub period: (usize, usize, usize),
}
819
820impl Default for HmaBatchRange {
821 fn default() -> Self {
822 Self {
823 period: (5, 254, 1),
824 }
825 }
826}
827
/// Fluent builder for batch (multi-period) HMA computations.
#[derive(Clone, Debug, Default)]
pub struct HmaBatchBuilder {
    /// Period sweep to evaluate.
    range: HmaBatchRange,
    /// Kernel selection forwarded to `hma_batch_with_kernel`.
    kernel: Kernel,
}
833
834impl HmaBatchBuilder {
835 pub fn new() -> Self {
836 Self::default()
837 }
838 pub fn kernel(mut self, k: Kernel) -> Self {
839 self.kernel = k;
840 self
841 }
842 #[inline]
843 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
844 self.range.period = (start, end, step);
845 self
846 }
847 #[inline]
848 pub fn period_static(mut self, p: usize) -> Self {
849 self.range.period = (p, p, 0);
850 self
851 }
852 pub fn apply_slice(self, data: &[f64]) -> Result<HmaBatchOutput, HmaError> {
853 hma_batch_with_kernel(data, &self.range, self.kernel)
854 }
855 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<HmaBatchOutput, HmaError> {
856 HmaBatchBuilder::new().kernel(k).apply_slice(data)
857 }
858 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<HmaBatchOutput, HmaError> {
859 let slice = source_type(c, src);
860 self.apply_slice(slice)
861 }
862 pub fn with_default_candles(c: &Candles) -> Result<HmaBatchOutput, HmaError> {
863 HmaBatchBuilder::new()
864 .kernel(Kernel::Auto)
865 .apply_candles(c, "close")
866 }
867}
868
869pub fn hma_batch_with_kernel(
870 data: &[f64],
871 sweep: &HmaBatchRange,
872 k: Kernel,
873) -> Result<HmaBatchOutput, HmaError> {
874 let kernel = match k {
875 Kernel::Auto => detect_best_batch_kernel(),
876 other if other.is_batch() => other,
877 other => return Err(HmaError::InvalidKernelForBatch(other)),
878 };
879 let simd = match kernel {
880 Kernel::Avx512Batch => Kernel::Avx512,
881 Kernel::Avx2Batch => Kernel::Avx2,
882 Kernel::ScalarBatch => Kernel::Scalar,
883 _ => unreachable!(),
884 };
885 hma_batch_par_slice(data, sweep, simd)
886}
887
/// Result of a batch HMA sweep, stored as a row-major `rows x cols` matrix.
#[derive(Clone, Debug)]
pub struct HmaBatchOutput {
    /// Flattened matrix; row `r` occupies `values[r * cols .. (r + 1) * cols]`.
    pub values: Vec<f64>,
    /// Parameters of each row, in row order.
    pub combos: Vec<HmaParams>,
    /// Number of parameter combinations.
    pub rows: usize,
    /// Length of the input series.
    pub cols: usize,
}
895
896impl HmaBatchOutput {
897 pub fn row_for_params(&self, p: &HmaParams) -> Option<usize> {
898 self.combos
899 .iter()
900 .position(|c| c.period.unwrap_or(5) == p.period.unwrap_or(5))
901 }
902 pub fn values_for(&self, p: &HmaParams) -> Option<&[f64]> {
903 self.row_for_params(p).map(|row| {
904 let start = row * self.cols;
905 &self.values[start..start + self.cols]
906 })
907 }
908}
909
910#[inline(always)]
911fn expand_grid(r: &HmaBatchRange) -> Vec<HmaParams> {
912 fn axis_usize((start, end, step): (usize, usize, usize)) -> Vec<usize> {
913 if step == 0 || start == end {
914 return vec![start];
915 }
916
917 let (lo, hi) = if start <= end {
918 (start, end)
919 } else {
920 (end, start)
921 };
922 let mut v = Vec::new();
923 let mut x = lo;
924 while x <= hi {
925 v.push(x);
926 match x.checked_add(step) {
927 Some(nx) => x = nx,
928 None => break,
929 }
930 }
931 v
932 }
933 let periods = axis_usize(r.period);
934 let mut out = Vec::with_capacity(periods.len());
935 for &p in &periods {
936 out.push(HmaParams { period: Some(p) });
937 }
938 out
939}
940
941#[inline(always)]
942pub fn hma_batch_slice(
943 data: &[f64],
944 sweep: &HmaBatchRange,
945 kern: Kernel,
946) -> Result<HmaBatchOutput, HmaError> {
947 hma_batch_inner(data, sweep, kern, false)
948}
949
950#[inline(always)]
951pub fn hma_batch_par_slice(
952 data: &[f64],
953 sweep: &HmaBatchRange,
954 kern: Kernel,
955) -> Result<HmaBatchOutput, HmaError> {
956 hma_batch_inner(data, sweep, kern, true)
957}
958
959#[inline(always)]
960fn hma_batch_inner(
961 data: &[f64],
962 sweep: &HmaBatchRange,
963 kern: Kernel,
964 parallel: bool,
965) -> Result<HmaBatchOutput, HmaError> {
966 let combos = expand_grid(sweep);
967 if combos.is_empty() {
968 let (s, e, t) = sweep.period;
969 return Err(HmaError::InvalidRange {
970 start: s,
971 end: e,
972 step: t,
973 });
974 }
975 let first = data
976 .iter()
977 .position(|x| !x.is_nan())
978 .ok_or(HmaError::AllValuesNaN)?;
979 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
980 if data.len() - first < max_p {
981 return Err(HmaError::NotEnoughValidData {
982 needed: max_p,
983 valid: data.len() - first,
984 });
985 }
986 let rows = combos.len();
987 let cols = data.len();
988
989 let warm: Vec<usize> = combos
990 .iter()
991 .map(|c| {
992 let p = c.period.unwrap();
993 let s = (p as f64).sqrt().floor() as usize;
994 first + p + s - 2
995 })
996 .collect();
997
998 let _ = rows
999 .checked_mul(cols)
1000 .ok_or(HmaError::ArithmeticOverflow { what: "rows*cols" })?;
1001 let mut raw = make_uninit_matrix(rows, cols);
1002 unsafe { init_matrix_prefixes(&mut raw, cols, &warm) };
1003
1004 let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
1005 let period = combos[row].period.unwrap();
1006
1007 let out_row =
1008 core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
1009
1010 match kern {
1011 Kernel::Scalar | Kernel::ScalarBatch => hma_row_scalar(data, first, period, out_row),
1012 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1013 Kernel::Avx2 | Kernel::Avx2Batch => hma_row_avx2(data, first, period, out_row),
1014 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1015 Kernel::Avx512 | Kernel::Avx512Batch => hma_row_avx512(data, first, period, out_row),
1016 _ => unreachable!(),
1017 }
1018 };
1019
1020 if parallel {
1021 #[cfg(not(target_arch = "wasm32"))]
1022 {
1023 raw.par_chunks_mut(cols)
1024 .enumerate()
1025 .for_each(|(row, slice)| do_row(row, slice));
1026 }
1027
1028 #[cfg(target_arch = "wasm32")]
1029 {
1030 for (row, slice) in raw.chunks_mut(cols).enumerate() {
1031 do_row(row, slice);
1032 }
1033 }
1034 } else {
1035 for (row, slice) in raw.chunks_mut(cols).enumerate() {
1036 do_row(row, slice);
1037 }
1038 }
1039
1040 let rows = combos.len();
1041 let cols = data.len();
1042 let _ = rows
1043 .checked_mul(cols)
1044 .ok_or(HmaError::ArithmeticOverflow { what: "rows*cols" })?;
1045
1046 let mut guard = core::mem::ManuallyDrop::new(raw);
1047 let values: Vec<f64> = unsafe {
1048 Vec::from_raw_parts(
1049 guard.as_mut_ptr() as *mut f64,
1050 guard.len(),
1051 guard.capacity(),
1052 )
1053 };
1054
1055 Ok(HmaBatchOutput {
1056 values,
1057 combos,
1058 rows,
1059 cols,
1060 })
1061}
1062
1063#[inline(always)]
1064fn hma_batch_inner_into(
1065 data: &[f64],
1066 sweep: &HmaBatchRange,
1067 kern: Kernel,
1068 parallel: bool,
1069 out: &mut [f64],
1070) -> Result<(Vec<HmaParams>, usize, usize), HmaError> {
1071 let combos = expand_grid(sweep);
1072 if combos.is_empty() {
1073 let (s, e, t) = sweep.period;
1074 return Err(HmaError::InvalidRange {
1075 start: s,
1076 end: e,
1077 step: t,
1078 });
1079 }
1080 let first = data
1081 .iter()
1082 .position(|x| !x.is_nan())
1083 .ok_or(HmaError::AllValuesNaN)?;
1084 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1085 if data.len() - first < max_p {
1086 return Err(HmaError::NotEnoughValidData {
1087 needed: max_p,
1088 valid: data.len() - first,
1089 });
1090 }
1091 let rows = combos.len();
1092 let cols = data.len();
1093
1094 let expected = rows
1095 .checked_mul(cols)
1096 .ok_or(HmaError::ArithmeticOverflow { what: "rows*cols" })?;
1097 if out.len() != expected {
1098 return Err(HmaError::OutputLengthMismatch {
1099 expected,
1100 got: out.len(),
1101 });
1102 }
1103
1104 let warm: Vec<usize> = combos
1105 .iter()
1106 .map(|c| {
1107 let p = c.period.unwrap();
1108 let s = (p as f64).sqrt().floor() as usize;
1109 first + p + s - 2
1110 })
1111 .collect();
1112
1113 let out_uninit = unsafe {
1114 std::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
1115 };
1116 unsafe { init_matrix_prefixes(out_uninit, cols, &warm) };
1117
1118 let do_row = |row: usize, out_row: &mut [f64]| unsafe {
1119 let period = combos[row].period.unwrap();
1120
1121 match kern {
1122 Kernel::Scalar | Kernel::ScalarBatch => hma_row_scalar(data, first, period, out_row),
1123 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1124 Kernel::Avx2 | Kernel::Avx2Batch => hma_row_avx2(data, first, period, out_row),
1125 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1126 Kernel::Avx512 | Kernel::Avx512Batch => hma_row_avx512(data, first, period, out_row),
1127 _ => unreachable!(),
1128 }
1129 };
1130
1131 if parallel {
1132 #[cfg(not(target_arch = "wasm32"))]
1133 {
1134 out.par_chunks_mut(cols)
1135 .enumerate()
1136 .for_each(|(row, slice)| do_row(row, slice));
1137 }
1138 #[cfg(target_arch = "wasm32")]
1139 {
1140 for (row, slice) in out.chunks_mut(cols).enumerate() {
1141 do_row(row, slice);
1142 }
1143 }
1144 } else {
1145 for (row, slice) in out.chunks_mut(cols).enumerate() {
1146 do_row(row, slice);
1147 }
1148 }
1149
1150 Ok((combos, rows, cols))
1151}
1152
1153#[inline(always)]
1154pub unsafe fn hma_row_scalar(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
1155 hma_scalar(data, period, first, out)
1156}
1157
/// Batch row adapter for the AVX2 entry point.
///
/// Takes `(data, first, period, out)` and forwards to [`hma_avx2`], which
/// expects `period` before `first`.
///
/// # Safety
/// The body only calls the safe `hma_avx2`; the `unsafe` marker keeps the
/// signature uniform with the other row kernels.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn hma_row_avx2(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    hma_avx2(data, period, first, out)
}
1163
/// Batch row adapter for the AVX-512 kernel.
///
/// Takes `(data, first, period, out)` and forwards to [`hma_avx512`], which
/// expects `period` before `first`.
///
/// # Safety
/// Same requirements as [`hma_avx512`]: the CPU must support the features that
/// kernel enables, and `out` must be at least as long as `data`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn hma_row_avx512(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    hma_avx512(data, period, first, out)
}
1169
1170#[inline(always)]
1171fn expand_grid_hma(r: &HmaBatchRange) -> Vec<HmaParams> {
1172 expand_grid(r)
1173}
1174
1175#[cfg(feature = "python")]
1176use crate::utilities::kernel_validation::validate_kernel;
1177#[cfg(feature = "python")]
1178use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
1179#[cfg(feature = "python")]
1180use pyo3::exceptions::PyValueError;
1181#[cfg(feature = "python")]
1182use pyo3::prelude::*;
1183#[cfg(feature = "python")]
1184use pyo3::types::PyDict;
1185
/// Python binding `hma(data, period, kernel=None)`.
///
/// Computes the HMA of a 1-D float64 array and returns a new 1-D array.
/// The computation runs with the GIL released. Failures from the array view,
/// from kernel validation, or from the HMA computation are returned as Python
/// exceptions; HMA errors become `ValueError` carrying the error's message.
#[cfg(feature = "python")]
#[pyfunction(name = "hma")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn hma_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let series = data.as_slice()?;
    let kern = validate_kernel(kernel, false)?;

    let input = HmaInput::from_slice(
        series,
        HmaParams {
            period: Some(period),
        },
    );

    let computed = py.allow_threads(|| hma_with_kernel(&input, kern).map(|o| o.values));
    match computed {
        Ok(values) => Ok(values.into_pyarray(py)),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
1211
/// Python binding `hma_batch(data, period_range, kernel=None)`.
///
/// Returns a dict with:
/// * `"values"`  – a `(rows, cols)` float64 array, one row per period;
/// * `"periods"` – the period of each row as an unsigned 64-bit array.
///
/// The output array is allocated uninitialised and filled by
/// `hma_batch_inner_into` with the GIL released. HMA errors are raised as
/// `ValueError` carrying the error's message.
#[cfg(feature = "python")]
#[pyfunction(name = "hma_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn hma_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, true)?;
    let sweep = HmaBatchRange {
        period: period_range,
    };

    // Expand the sweep up front only to size the output; the inner routine
    // expands it again and returns the authoritative list.
    let combos = expand_grid(&sweep);
    let rows = combos.len();
    let cols = slice_in.len();

    // Uninitialised flat buffer of rows * cols elements, written below.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [rows * cols], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            let kernel = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k => k,
            };
            // Map the batch kernel to its per-row counterpart.
            // NOTE(review): panics via `unreachable!()` if `validate_kernel` can
            // return a non-batch kernel here — confirm against that helper.
            let simd = match kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                _ => unreachable!(),
            };
            hma_batch_inner_into(slice_in, &sweep, simd, true, slice_out)
                .map(|(combos, _, _)| combos)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;

    Ok(dict)
}
1267
// Python entry point `hma_cuda_batch_dev(data_f32, period_range, device_id=0)`.
//
// Runs the batch HMA sweep on a CUDA device and returns a pair:
//   * a `DeviceArrayF32Hma` wrapping the device-resident result, and
//   * a dict carrying "periods" plus layout metadata ("cai_version",
//     "cai_typestr", "cai_shape", "cai_strides_bytes", "stream").
//
// Raises `ValueError` when CUDA is unavailable or any CUDA step fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "hma_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, device_id=0))]
pub fn hma_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    data_f32: numpy::PyReadonlyArray1<'py, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<(DeviceArrayF32HmaPy, Bound<'py, PyDict>)> {
    use crate::cuda::cuda_available;
    use numpy::IntoPyArray;
    use pyo3::types::PyDict;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let host_values = data_f32.as_slice()?;
    let sweep = HmaBatchRange {
        period: period_range,
    };

    // Device setup and the kernel launch run with the GIL released.
    // NOTE(review): `engine` is dropped when this closure returns while its
    // raw stream handle is kept; confirm the handle stays valid afterwards.
    let (inner, combos, stream_u64, ctx, dev_id) = py.allow_threads(|| {
        let engine =
            CudaHma::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = engine.ctx();
        let dev_id = engine.device_id();
        let res = engine
            .hma_batch_dev(host_values, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, PyErr>((res.0, res.1, engine.stream_handle_u64(), ctx, dev_id))
    })?;

    // Row-major f32 layout: one row per parameter combination.
    let shape = (inner.rows as u64, inner.cols as u64);
    let strides_bytes = ((inner.cols as u64) * 4u64, 4u64);
    let periods: Vec<u64> = combos.iter().map(|c| c.period.unwrap() as u64).collect();

    let meta = PyDict::new(py);
    meta.set_item("periods", periods.into_pyarray(py))?;
    meta.set_item("cai_version", 3u64)?;
    meta.set_item("cai_typestr", "<f4")?;
    meta.set_item("cai_shape", shape)?;
    meta.set_item("cai_strides_bytes", strides_bytes)?;
    meta.set_item("stream", stream_u64)?;

    let handle = DeviceArrayF32HmaPy::new(inner, ctx, dev_id, stream_u64);
    Ok((handle, meta))
}
1315
// Python entry point
// `hma_cuda_many_series_one_param_dev(data_tm_f32, period, device_id=0)`.
//
// Applies a single HMA period to every series of a 2-D f32 array on a CUDA
// device and returns the device-resident result as `DeviceArrayF32Hma`.
// The array's first dimension is passed to the CUDA routine as `rows` and
// the second as `cols` (the routine's name suggests a time-major layout;
// confirm the argument meaning against `CudaHma`).
//
// Raises `ValueError` when CUDA is unavailable, the input is not contiguous,
// or any CUDA step fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "hma_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, device_id=0))]
pub fn hma_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32HmaPy> {
    use crate::cuda::cuda_available;
    use numpy::PyUntypedArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let flat_in = data_tm_f32.as_slice()?;
    let rows = data_tm_f32.shape()[0];
    let cols = data_tm_f32.shape()[1];
    let params = HmaParams {
        period: Some(period),
    };

    // Device work runs with the GIL released.
    let (inner, ctx, dev_id, stream_u64) = py.allow_threads(|| {
        let cuda = CudaHma::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.ctx();
        let dev_id = cuda.device_id();
        // Fixed: the parameter reference had been mangled into `¶ms`
        // (an HTML-entity corruption of `&params`), which does not compile.
        let arr = cuda
            .hma_multi_series_one_param_time_major_dev(flat_in, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, PyErr>((arr, ctx, dev_id, cuda.stream_handle_u64()))
    })?;

    Ok(DeviceArrayF32HmaPy::new(inner, ctx, dev_id, stream_u64))
}
1351
// Python class `HmaStream`: wraps the Rust `HmaStream` so values can be fed
// one at a time from Python.
#[cfg(feature = "python")]
#[pyclass(name = "HmaStream")]
pub struct HmaStreamPy {
    // The wrapped streaming state.
    inner: HmaStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl HmaStreamPy {
    // Constructor: builds the stream for `period`; a failure from
    // `HmaStream::try_new` is raised as `ValueError`.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        let requested = HmaParams {
            period: Some(period),
        };
        match HmaStream::try_new(requested) {
            Ok(inner) => Ok(HmaStreamPy { inner }),
            Err(e) => Err(PyValueError::new_err(e.to_string())),
        }
    }

    // Feeds one value and returns whatever the inner stream yields
    // (`None` on the Python side when the stream produces no value).
    fn update(&mut self, value: f64) -> Option<f64> {
        let stream = &mut self.inner;
        stream.update(value)
    }
}
1375
// Python-visible handle (`ta_indicators.cuda.DeviceArrayF32Hma`) to an f32
// matrix that lives in CUDA device memory. Marked `unsendable`, so PyO3
// restricts use to the thread that created it.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", name = "DeviceArrayF32Hma", unsendable)]
pub struct DeviceArrayF32HmaPy {
    // Device buffer together with its `rows` / `cols` dimensions.
    pub(crate) inner: crate::cuda::moving_averages::DeviceArrayF32,
    // Shared CUDA context handle stored alongside the buffer (presumably to
    // keep the context alive for as long as the buffer exists; confirm).
    _ctx_guard: std::sync::Arc<cust::context::Context>,
    // Device ordinal reported through `__dlpack_device__`.
    _device_id: u32,
    // Raw stream handle reported through `__cuda_array_interface__` and
    // synchronized in `__dlpack__`.
    _stream: u64,
}
1384
// Python protocol methods for `DeviceArrayF32Hma`: CUDA Array Interface and
// DLPack export.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl DeviceArrayF32HmaPy {
    // Direct construction from Python is rejected with `TypeError`; instances
    // are produced by the CUDA entry points only.
    #[new]
    fn py_new() -> PyResult<Self> {
        Err(pyo3::exceptions::PyTypeError::new_err(
            "use factory methods from CUDA functions",
        ))
    }

    // Builds the `__cuda_array_interface__` dict (version 3) describing a
    // row-major little-endian f32 matrix of shape (rows, cols).
    #[getter]
    fn __cuda_array_interface__<'py>(
        &self,
        py: Python<'py>,
    ) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
        let d = pyo3::types::PyDict::new(py);
        let itemsize = std::mem::size_of::<f32>();
        d.set_item("shape", (self.inner.rows, self.inner.cols))?;
        d.set_item("typestr", "<f4")?;

        // Byte strides: one full row, then one element.
        d.set_item("strides", (self.inner.cols * itemsize, itemsize))?;
        // An empty array reports a null data pointer instead of querying the
        // buffer.
        let size = self.inner.rows.saturating_mul(self.inner.cols);
        let ptr_val: usize = if size == 0 {
            0
        } else {
            self.inner.buf.as_device_ptr().as_raw() as usize
        };
        // Second tuple element is the read-only flag (false = writable).
        d.set_item("data", (ptr_val, false))?;

        d.set_item("stream", self._stream)?;
        d.set_item("version", 3)?;
        Ok(d)
    }

    // Returns `(device_type, device_id)`; the constant 2 is presumably
    // DLPack's CUDA device type (the caller below names it `kdl`).
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self._device_id as i32)
    }

    // Exports the buffer as a DLPack capsule.
    //
    // The device buffer is MOVED out of `self` (replaced by an empty one with
    // rows = cols = 0), so the export can effectively happen once; afterwards
    // this object describes an empty array.
    //
    // `dl_device`, when it parses as an `(int, int)` pair and differs from
    // this array's device, yields `ValueError` (with a dedicated message when
    // `copy` is truthy). The `stream` argument is accepted but not used.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<PyObject>,
        max_version: Option<PyObject>,
        dl_device: Option<PyObject>,
        copy: Option<PyObject>,
    ) -> PyResult<pyo3::PyObject> {
        use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;

        let (kdl, alloc_dev) = self.__dlpack_device__();
        if let Some(dev_obj) = dl_device.as_ref() {
            // A `dl_device` that does not parse as a pair is silently ignored.
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(PyValueError::new_err(
                            "device copy not implemented for __dlpack__",
                        ));
                    } else {
                        return Err(PyValueError::new_err("dl_device mismatch for __dlpack__"));
                    }
                }
            }
        }

        // SAFETY: `_stream` holds the raw stream handle recorded when this
        // object was created. NOTE(review): the synchronize result is
        // deliberately discarded here; confirm that the handle is still valid
        // at this point and that ignoring a failure is intended.
        unsafe {
            let st = self._stream as cust::sys::CUstream;
            let _ = cust::sys::cuStreamSynchronize(st);
        }
        // The consumer-provided stream is intentionally unused.
        let _ = stream;

        // Swap an empty placeholder into `self` so the real buffer can be
        // handed to the exporter by value.
        let dummy =
            DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let inner = std::mem::replace(
            &mut self.inner,
            crate::cuda::moving_averages::DeviceArrayF32 {
                buf: dummy,
                rows: 0,
                cols: 0,
            },
        );
        let rows = inner.rows;
        let cols = inner.cols;
        let buf = inner.buf;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
1482
#[cfg(all(feature = "python", feature = "cuda"))]
impl DeviceArrayF32HmaPy {
    /// Rust-side constructor used by the CUDA entry points.
    ///
    /// Stores the device array together with the context handle, the device
    /// ordinal and the raw stream handle it was produced with.
    pub fn new(
        inner: crate::cuda::moving_averages::DeviceArrayF32,
        ctx_guard: std::sync::Arc<cust::context::Context>,
        device_id: u32,
        stream_u64: u64,
    ) -> Self {
        let _ctx_guard = ctx_guard;
        let _device_id = device_id;
        let _stream = stream_u64;
        DeviceArrayF32HmaPy {
            inner,
            _ctx_guard,
            _device_id,
            _stream,
        }
    }
}
1499
1500#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1501use serde::{Deserialize, Serialize};
1502#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1503use wasm_bindgen::prelude::*;
1504
1505#[inline]
1506pub fn hma_into_slice(dst: &mut [f64], input: &HmaInput, kern: Kernel) -> Result<(), HmaError> {
1507 let data: &[f64] = input.as_ref();
1508
1509 if dst.len() != data.len() {
1510 return Err(HmaError::OutputLengthMismatch {
1511 expected: data.len(),
1512 got: dst.len(),
1513 });
1514 }
1515
1516 hma_with_kernel_into(input, kern, dst)
1517}
1518
/// JavaScript entry point: computes the HMA of `data` for `period` and
/// returns a new vector of the same length.
///
/// # Errors
/// * `"All NaN"` when `data` contains no non-NaN value (including when it is
///   empty).
/// * `"Invalid or insufficient data"` when `period` is 0, `floor(sqrt(period))`
///   is 0, or fewer than `period + floor(sqrt(period)) - 1` values remain from
///   the first non-NaN value onwards.
/// * The stringified error of `hma_into_slice` otherwise.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn hma_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = HmaInput::from_slice(
        data,
        HmaParams {
            period: Some(period),
        },
    );

    // Index of the first usable sample.
    let first = match data.iter().position(|v| !v.is_nan()) {
        Some(idx) => idx,
        None => return Err(JsValue::from_str("All NaN")),
    };

    let sqrt_len = (period as f64).sqrt().floor() as usize;
    // The zero checks must come first: they keep `period + sqrt_len - 1`
    // from underflowing.
    if period == 0 || sqrt_len == 0 || data.len() - first < period + sqrt_len - 1 {
        return Err(JsValue::from_str("Invalid or insufficient data"));
    }

    // Number of leading slots pre-filled with NaN by the allocator helper.
    let first_out = first + period + sqrt_len - 2;
    let mut output = alloc_with_nan_prefix(data.len(), first_out);

    match hma_into_slice(&mut output, &input, Kernel::Auto) {
        Ok(()) => Ok(output),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1542
/// Configuration object accepted by the JavaScript `hma_batch` function,
/// deserialized from a JS value.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct HmaBatchConfig {
    /// Period sweep; copied verbatim into `HmaBatchRange::period`.
    pub period_range: (usize, usize, usize),
}
1548
/// Result object returned (serialized) by the JavaScript `hma_batch`
/// function; its fields mirror the batch output they are copied from.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct HmaBatchJsOutput {
    /// Flat buffer holding all computed values.
    pub values: Vec<f64>,
    /// Parameter sets that were evaluated.
    pub combos: Vec<HmaParams>,
    /// Row count reported by the batch computation.
    pub rows: usize,
    /// Column count reported by the batch computation.
    pub cols: usize,
}
1557
/// JavaScript `hma_batch(data, config)`: runs a period sweep described by a
/// config object and returns `{ values, combos, rows, cols }`.
///
/// # Errors
/// Returns a JS string error when the config cannot be deserialized
/// (`"Invalid config: ..."`), when the batch computation fails, or when the
/// result cannot be serialized (`"Serialization error: ..."`).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = hma_batch)]
pub fn hma_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let parsed: HmaBatchConfig = match serde_wasm_bindgen::from_value(config) {
        Ok(cfg) => cfg,
        Err(e) => return Err(JsValue::from_str(&format!("Invalid config: {}", e))),
    };

    let sweep = HmaBatchRange {
        period: parsed.period_range,
    };

    // On wasm32 the scalar batch kernel is selected explicitly.
    let kernel = if cfg!(target_arch = "wasm32") {
        Kernel::ScalarBatch
    } else {
        Kernel::Auto
    };

    let batch = hma_batch_inner(data, &sweep, kernel, false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let payload = HmaBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };

    match serde_wasm_bindgen::to_value(&payload) {
        Ok(js) => Ok(js),
        Err(e) => Err(JsValue::from_str(&format!("Serialization error: {}", e))),
    }
}
1587
/// JavaScript entry point: runs a period sweep
/// `(period_start, period_end, period_step)` over `data` and returns only the
/// flat value buffer of the batch result.
///
/// # Errors
/// Returns the stringified batch error as a JS value.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn hma_batch_js(
    data: &[f64],
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let sweep = HmaBatchRange {
        period: (period_start, period_end, period_step),
    };

    // On wasm32 the scalar batch kernel is selected explicitly.
    let kernel = if cfg!(target_arch = "wasm32") {
        Kernel::ScalarBatch
    } else {
        Kernel::Auto
    };

    match hma_batch_inner(data, &sweep, kernel, false) {
        Ok(batch) => Ok(batch.values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1611
/// JavaScript entry point: lists the periods of a sweep as `f64` values.
///
/// A zero step or `period_start == period_end` yields just `period_start`;
/// otherwise the inclusive range `period_start..=period_end` is walked in
/// increments of `period_step` (empty when `period_start > period_end`).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn hma_batch_metadata_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Vec<f64> {
    // Degenerate sweep: a single period.
    if period_step == 0 || period_start == period_end {
        return vec![period_start as f64];
    }

    (period_start..=period_end)
        .step_by(period_step)
        .map(|p| p as f64)
        .collect()
}
1627
/// Allocates room for `len` `f64` values and hands the raw pointer to the
/// JavaScript caller.
///
/// The allocation is intentionally leaked here; ownership passes to the
/// caller, who releases it with `hma_free` using the same `len`. The memory
/// is not initialized.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn hma_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` suppresses the destructor so the buffer outlives this call.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1636
/// Releases a buffer previously obtained from `hma_alloc`.
///
/// `ptr` must come from `hma_alloc(len)` with the same `len`, and must not be
/// used or freed again afterwards. A null pointer is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn hma_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: per the contract above, `ptr` was allocated by a
    // `Vec<f64>` with capacity `len`. The vector is rebuilt with length 0
    // rather than `len`: `hma_alloc` returns uninitialized memory, and
    // `Vec::from_raw_parts` requires the first `length` elements to be
    // initialized. Dropping the vector frees the same allocation either way.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
1646
/// JavaScript entry point: computes the HMA of the `len` values at `in_ptr`
/// and writes `len` results to `out_ptr`.
///
/// `in_ptr` and `out_ptr` may be the same pointer (in-place operation); in
/// that case the result is computed into a temporary buffer and copied back.
/// Both pointers must reference at least `len` valid `f64` slots; partially
/// overlapping buffers are not handled.
///
/// # Errors
/// * `"null pointer passed to hma_into"` when either pointer is null.
/// * `"Invalid period"` when `period` is 0 or greater than `len`.
/// * The stringified error of `hma_into_slice` otherwise.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn hma_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to hma_into"));
    }
    if period == 0 || period > len {
        return Err(JsValue::from_str("Invalid period"));
    }

    // SAFETY: both pointers are non-null (checked above); the caller
    // guarantees that each references `len` valid `f64` values.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let input = HmaInput::from_slice(
            data,
            HmaParams {
                period: Some(period),
            },
        );

        if in_ptr != out_ptr {
            // Distinct buffers: write straight into the destination.
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            hma_into_slice(out, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        } else {
            // In-place request: compute into scratch space first so the
            // input is not overwritten while it is still being read.
            let mut scratch = vec![0.0; len];
            hma_into_slice(&mut scratch, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;

            std::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&scratch);
        }
    }

    Ok(())
}
1687
/// JavaScript entry point: runs a period sweep over the `len` values at
/// `in_ptr` and writes the results to `out_ptr`.
///
/// `out_ptr` must reference at least `rows * len` `f64` slots, where `rows`
/// is the number of parameter combinations of the sweep (the value returned
/// on success).
///
/// # Errors
/// * `"null pointer passed to hma_batch_into"` when either pointer is null.
/// * `"hma_batch_into: rows * cols overflows usize"` when the output size
///   cannot be represented.
/// * The stringified error of `hma_batch_inner_into` otherwise.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn hma_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to hma_batch_into"));
    }

    let sweep = HmaBatchRange {
        period: (period_start, period_end, period_step),
    };

    let rows = expand_grid(&sweep).len();
    let cols = len;

    // A wrapped product would build an output slice shorter than the real
    // result; reject it instead (usize is only 32 bits wide on wasm32).
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("hma_batch_into: rows * cols overflows usize"))?;

    // On wasm32 the scalar batch kernel is selected explicitly.
    let kernel = if cfg!(target_arch = "wasm32") {
        Kernel::ScalarBatch
    } else {
        Kernel::Auto
    };

    // SAFETY: both pointers are non-null (checked above); the caller
    // guarantees `in_ptr` references `len` values and `out_ptr` references
    // `rows * len` writable slots that do not overlap the input.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let out = std::slice::from_raw_parts_mut(out_ptr, total);

        hma_batch_inner_into(data, &sweep, kernel, false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(rows)
}
1727
/// Registers the HMA Python API on module `m`: the `hma` and `hma_batch`
/// functions, the `HmaStream` class and, when the `cuda` feature is enabled,
/// the device-array class plus the two CUDA functions.
#[cfg(feature = "python")]
pub fn register_hma_module(m: &Bound<'_, pyo3::types::PyModule>) -> PyResult<()> {
    // CPU API.
    m.add_function(wrap_pyfunction!(hma_py, m)?)?;
    m.add_function(wrap_pyfunction!(hma_batch_py, m)?)?;
    m.add_class::<HmaStreamPy>()?;

    // CUDA API, only present in CUDA-enabled builds.
    #[cfg(feature = "cuda")]
    {
        m.add_class::<DeviceArrayF32HmaPy>()?;
        m.add_function(wrap_pyfunction!(hma_cuda_batch_dev_py, m)?)?;
        m.add_function(wrap_pyfunction!(hma_cuda_many_series_one_param_dev_py, m)?)?;
    }

    Ok(())
}
1741
1742#[cfg(test)]
1743mod tests {
1744 use super::*;
1745 use crate::skip_if_unsupported;
1746 use crate::utilities::data_loader::read_candles_from_csv;
1747 use proptest::prelude::*;
1748
/// `hma_into` must produce exactly the values returned by `hma` for the
/// same input (NaN positions included).
#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
#[test]
fn test_hma_into_matches_api() -> Result<(), Box<dyn Error>> {
    // Deterministic synthetic series: sine wave on top of a linear trend.
    let mut data: Vec<f64> = Vec::with_capacity(512);
    for i in 0..512 {
        data.push(((i as f64).sin() * 123.456789) + (i as f64) * 0.25);
    }

    let input = HmaInput::from_slice(&data, HmaParams::default());
    let baseline = hma(&input)?.values;

    let mut out = vec![0.0; data.len()];
    hma_into(&input, &mut out)?;

    assert_eq!(baseline.len(), out.len());
    for idx in 0..baseline.len() {
        let (a, b) = (&baseline[idx], &out[idx]);
        // NaN never compares equal, so matching NaNs are accepted explicitly.
        let equal = (a.is_nan() && b.is_nan()) || (a == b);
        assert!(equal, "Mismatch: a={:?}, b={:?}", a, b);
    }

    Ok(())
}
1771
1772 fn check_hma_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1773 skip_if_unsupported!(kernel, test_name);
1774 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1775 let candles = read_candles_from_csv(file_path)?;
1776 let default_params = HmaParams { period: None };
1777 let input_default = HmaInput::from_candles(&candles, "close", default_params);
1778 let output_default = hma_with_kernel(&input_default, kernel)?;
1779 assert_eq!(output_default.values.len(), candles.close.len());
1780 Ok(())
1781 }
1782
1783 fn check_hma_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1784 skip_if_unsupported!(kernel, test_name);
1785 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1786 let candles = read_candles_from_csv(file_path)?;
1787 let input = HmaInput::with_default_candles(&candles);
1788 let result = hma_with_kernel(&input, kernel)?;
1789 let expected_last_five = [
1790 59334.13333336847,
1791 59201.4666667018,
1792 59047.77777781293,
1793 59048.71111114628,
1794 58803.44444447962,
1795 ];
1796 assert!(result.values.len() >= 5);
1797 assert_eq!(result.values.len(), candles.close.len());
1798 let start = result.values.len() - 5;
1799 let last_five = &result.values[start..];
1800 for (i, &val) in last_five.iter().enumerate() {
1801 let exp = expected_last_five[i];
1802 assert!(
1803 (val - exp).abs() < 1e-3,
1804 "[{}] idx {}: got {}, expected {}",
1805 test_name,
1806 i,
1807 val,
1808 exp
1809 );
1810 }
1811 Ok(())
1812 }
1813
1814 fn check_hma_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1815 skip_if_unsupported!(kernel, test_name);
1816 let input_data = [10.0, 20.0, 30.0];
1817 let params = HmaParams { period: Some(0) };
1818 let input = HmaInput::from_slice(&input_data, params);
1819 let result = hma_with_kernel(&input, kernel);
1820 assert!(
1821 result.is_err(),
1822 "[{}] HMA should fail with zero period",
1823 test_name
1824 );
1825 Ok(())
1826 }
1827
1828 fn check_hma_period_exceeds_length(
1829 test_name: &str,
1830 kernel: Kernel,
1831 ) -> Result<(), Box<dyn Error>> {
1832 skip_if_unsupported!(kernel, test_name);
1833 let input_data = [10.0, 20.0, 30.0];
1834 let params = HmaParams { period: Some(10) };
1835 let input = HmaInput::from_slice(&input_data, params);
1836 let result = hma_with_kernel(&input, kernel);
1837 assert!(
1838 result.is_err(),
1839 "[{}] HMA should fail with period exceeding length",
1840 test_name
1841 );
1842 Ok(())
1843 }
1844
1845 fn check_hma_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1846 skip_if_unsupported!(kernel, test_name);
1847 let input_data = [42.0];
1848 let params = HmaParams { period: Some(5) };
1849 let input = HmaInput::from_slice(&input_data, params);
1850 let result = hma_with_kernel(&input, kernel);
1851 assert!(
1852 result.is_err(),
1853 "[{}] HMA should fail with insufficient data",
1854 test_name
1855 );
1856 Ok(())
1857 }
1858
1859 fn check_hma_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1860 skip_if_unsupported!(kernel, test_name);
1861 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1862 let candles = read_candles_from_csv(file_path)?;
1863 let first_params = HmaParams { period: Some(5) };
1864 let first_input = HmaInput::from_candles(&candles, "close", first_params);
1865 let first_result = hma_with_kernel(&first_input, kernel)?;
1866 let second_params = HmaParams { period: Some(3) };
1867 let second_input = HmaInput::from_slice(&first_result.values, second_params);
1868 let second_result = hma_with_kernel(&second_input, kernel)?;
1869 assert_eq!(second_result.values.len(), first_result.values.len());
1870 if second_result.values.len() > 240 {
1871 for i in 240..second_result.values.len() {
1872 assert!(!second_result.values[i].is_nan());
1873 }
1874 }
1875 Ok(())
1876 }
1877
1878 fn check_hma_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1879 skip_if_unsupported!(kernel, test_name);
1880 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1881 let candles = read_candles_from_csv(file_path)?;
1882 let params = HmaParams::default();
1883 let period = params.period.unwrap_or(5) * 2;
1884 let input = HmaInput::from_candles(&candles, "close", params);
1885 let result = hma_with_kernel(&input, kernel)?;
1886 assert_eq!(result.values.len(), candles.close.len());
1887 if result.values.len() > period {
1888 for i in period..result.values.len() {
1889 assert!(!result.values[i].is_nan());
1890 }
1891 }
1892 Ok(())
1893 }
1894
1895 fn check_hma_empty_input(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1896 skip_if_unsupported!(kernel, test_name);
1897 let empty: [f64; 0] = [];
1898 let input = HmaInput::from_slice(&empty, HmaParams::default());
1899 let res = hma_with_kernel(&input, kernel);
1900 assert!(
1901 matches!(res, Err(HmaError::NoData)),
1902 "[{}] expected NoData",
1903 test_name
1904 );
1905 Ok(())
1906 }
1907
1908 fn check_hma_not_enough_valid(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1909 skip_if_unsupported!(kernel, test_name);
1910 let data = [f64::NAN, f64::NAN, 1.0, 2.0];
1911 let params = HmaParams { period: Some(3) };
1912 let input = HmaInput::from_slice(&data, params);
1913 let res = hma_with_kernel(&input, kernel);
1914 assert!(
1915 matches!(res, Err(HmaError::NotEnoughValidData { .. })),
1916 "[{}] expected NotEnoughValidData",
1917 test_name
1918 );
1919 Ok(())
1920 }
1921
1922 fn check_hma_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1923 skip_if_unsupported!(kernel, test_name);
1924 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1925 let candles = read_candles_from_csv(file_path)?;
1926 let period = 5;
1927 let input = HmaInput::from_candles(
1928 &candles,
1929 "close",
1930 HmaParams {
1931 period: Some(period),
1932 },
1933 );
1934 let batch_output = hma_with_kernel(&input, kernel)?.values;
1935
1936 let mut stream = HmaStream::try_new(HmaParams {
1937 period: Some(period),
1938 })?;
1939 let mut stream_vals = Vec::with_capacity(candles.close.len());
1940 for &p in &candles.close {
1941 match stream.update(p) {
1942 Some(v) => stream_vals.push(v),
1943 None => stream_vals.push(f64::NAN),
1944 }
1945 }
1946
1947 assert_eq!(batch_output.len(), stream_vals.len());
1948 for (i, (&b, &s)) in batch_output.iter().zip(stream_vals.iter()).enumerate() {
1949 if b.is_nan() && s.is_nan() {
1950 continue;
1951 }
1952 let diff = (b - s).abs();
1953 assert!(
1954 diff < 1e-4,
1955 "[{}] HMA streaming mismatch at idx {}: batch={}, stream={}, diff={}",
1956 test_name,
1957 i,
1958 b,
1959 s,
1960 diff
1961 );
1962 }
1963 Ok(())
1964 }
1965
1966 fn check_hma_property(
1967 test_name: &str,
1968 kernel: Kernel,
1969 ) -> Result<(), Box<dyn std::error::Error>> {
1970 skip_if_unsupported!(kernel, test_name);
1971
1972 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1973 let candles = read_candles_from_csv(file_path)?;
1974 let close_data = &candles.close;
1975
1976 let strat = (
1977 2usize..=100,
1978 0usize..close_data.len().saturating_sub(500),
1979 200usize..=500,
1980 );
1981
1982 proptest::test_runner::TestRunner::default()
1983 .run(&strat, |(period, start_idx, slice_len)| {
1984 let end_idx = (start_idx + slice_len).min(close_data.len());
1985 if end_idx <= start_idx || end_idx - start_idx < period + 10 {
1986 return Ok(());
1987 }
1988
1989 let data_slice = &close_data[start_idx..end_idx];
1990 let params = HmaParams {
1991 period: Some(period),
1992 };
1993 let input = HmaInput::from_slice(data_slice, params);
1994
1995 let result = hma_with_kernel(&input, kernel);
1996
1997 let scalar_result = hma_with_kernel(&input, Kernel::Scalar);
1998
1999 match (result, scalar_result) {
2000 (Ok(HmaOutput { values: out }), Ok(HmaOutput { values: ref_out })) => {
2001 prop_assert_eq!(out.len(), data_slice.len());
2002 prop_assert_eq!(ref_out.len(), data_slice.len());
2003
2004 let sqrt_period = (period as f64).sqrt().floor() as usize;
2005 let expected_warmup = period + sqrt_period - 1;
2006
2007 let first_valid = out.iter().position(|x| !x.is_nan());
2008 if let Some(first_idx) = first_valid {
2009 prop_assert!(
2010 first_idx >= expected_warmup - 1,
2011 "First valid at {} but expected warmup is {}",
2012 first_idx,
2013 expected_warmup
2014 );
2015
2016 for i in 0..first_idx {
2017 prop_assert!(
2018 out[i].is_nan(),
2019 "Expected NaN at index {} during warmup, got {}",
2020 i,
2021 out[i]
2022 );
2023 }
2024 }
2025
2026 for i in 0..out.len() {
2027 let y = out[i];
2028 let r = ref_out[i];
2029
2030 if y.is_nan() {
2031 prop_assert!(
2032 r.is_nan(),
2033 "Kernel mismatch at {}: {} vs {}",
2034 i,
2035 y,
2036 r
2037 );
2038 continue;
2039 }
2040
2041 prop_assert!(y.is_finite(), "Non-finite value at index {}: {}", i, y);
2042
2043 let y_bits = y.to_bits();
2044 let r_bits = r.to_bits();
2045 let ulp_diff = y_bits.abs_diff(r_bits);
2046
2047 let ulp_tolerance = if matches!(kernel, Kernel::Avx512) {
2048 20000
2049 } else {
2050 8
2051 };
2052 prop_assert!(
2053 (y - r).abs() <= 1e-8 || ulp_diff <= ulp_tolerance,
2054 "Kernel mismatch at {}: {} vs {} (ULP={})",
2055 i,
2056 y,
2057 r,
2058 ulp_diff
2059 );
2060 }
2061
2062 for i in expected_warmup..out.len() {
2063 let y = out[i];
2064 if y.is_nan() {
2065 continue;
2066 }
2067
2068 prop_assert!(y.is_finite(), "HMA output at {} is not finite: {}", i, y);
2069
2070 if i >= period * 2 {
2071 let window_start = i.saturating_sub(period);
2072 let window = &data_slice[window_start..=i];
2073 let is_constant =
2074 window.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10);
2075 if is_constant {
2076 let constant_val = window[0];
2077 prop_assert!(
2078 (y - constant_val).abs() <= 1e-6,
2079 "HMA should converge to {} for constant data, got {} at index {}",
2080 constant_val,
2081 y,
2082 i
2083 );
2084 }
2085 }
2086 }
2087
2088 if period == 2 {
2089 let min_valid_idx = expected_warmup;
2090 if out.len() > min_valid_idx {
2091 prop_assert!(
2092 out[min_valid_idx].is_finite(),
2093 "HMA with period=2 should produce valid output at index {}",
2094 min_valid_idx
2095 );
2096 }
2097 }
2098
2099 Ok(())
2100 }
2101 (Err(e1), Err(_e2)) => {
2102 prop_assert!(
2103 format!("{:?}", e1).contains("NotEnoughValidData")
2104 || format!("{:?}", e1).contains("InvalidPeriod"),
2105 "Unexpected error type: {:?}",
2106 e1
2107 );
2108 Ok(())
2109 }
2110 (Ok(_), Err(e)) | (Err(e), Ok(_)) => {
2111 prop_assert!(
2112 false,
2113 "Kernel consistency failure: one succeeded, one failed with {:?}",
2114 e
2115 );
2116 Ok(())
2117 }
2118 }
2119 })
2120 .unwrap();
2121
2122 let edge_cases = vec![
2123 (vec![1.0, 2.0, 3.0, 4.0, 5.0], 2),
2124 (vec![42.0; 100], 10),
2125 ((0..100).map(|i| i as f64).collect::<Vec<_>>(), 15),
2126 ((0..100).map(|i| 100.0 - i as f64).collect::<Vec<_>>(), 20),
2127 ];
2128
2129 for (case_idx, (data, period)) in edge_cases.into_iter().enumerate() {
2130 let params = HmaParams {
2131 period: Some(period),
2132 };
2133 let input = HmaInput::from_slice(&data, params);
2134
2135 match hma_with_kernel(&input, kernel) {
2136 Ok(out) => {
2137 let has_valid = out.values.iter().any(|&x| x.is_finite() && !x.is_nan());
2138 assert!(
2139 has_valid || data.len() < period + 2,
2140 "[{}] Edge case {} produced no valid output",
2141 test_name,
2142 case_idx
2143 );
2144 }
2145 Err(_) => {}
2146 }
2147 }
2148
2149 Ok(())
2150 }
2151
2152 #[cfg(debug_assertions)]
2153 fn check_hma_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2154 skip_if_unsupported!(kernel, test_name);
2155
2156 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2157 let candles = read_candles_from_csv(file_path)?;
2158
2159 let test_params = vec![
2160 HmaParams::default(),
2161 HmaParams { period: Some(2) },
2162 HmaParams { period: Some(3) },
2163 HmaParams { period: Some(4) },
2164 HmaParams { period: Some(5) },
2165 HmaParams { period: Some(7) },
2166 HmaParams { period: Some(10) },
2167 HmaParams { period: Some(14) },
2168 HmaParams { period: Some(20) },
2169 HmaParams { period: Some(30) },
2170 HmaParams { period: Some(50) },
2171 HmaParams { period: Some(100) },
2172 HmaParams { period: Some(200) },
2173 HmaParams { period: Some(1) },
2174 HmaParams { period: Some(250) },
2175 ];
2176
2177 for (param_idx, params) in test_params.iter().enumerate() {
2178 let input = HmaInput::from_candles(&candles, "close", params.clone());
2179 let output = hma_with_kernel(&input, kernel)?;
2180
2181 for (i, &val) in output.values.iter().enumerate() {
2182 if val.is_nan() {
2183 continue;
2184 }
2185
2186 let bits = val.to_bits();
2187
2188 if bits == 0x11111111_11111111 {
2189 panic!(
2190 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
2191 with params: period={}",
2192 test_name,
2193 val,
2194 bits,
2195 i,
2196 params.period.unwrap_or(5)
2197 );
2198 }
2199
2200 if bits == 0x22222222_22222222 {
2201 panic!(
2202 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
2203 with params: period={}",
2204 test_name,
2205 val,
2206 bits,
2207 i,
2208 params.period.unwrap_or(5)
2209 );
2210 }
2211
2212 if bits == 0x33333333_33333333 {
2213 panic!(
2214 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
2215 with params: period={}",
2216 test_name,
2217 val,
2218 bits,
2219 i,
2220 params.period.unwrap_or(5)
2221 );
2222 }
2223 }
2224 }
2225
2226 Ok(())
2227 }
2228
2229 #[cfg(not(debug_assertions))]
2230 fn check_hma_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2231 Ok(())
2232 }
2233
2234 macro_rules! generate_all_hma_tests {
2235 ($($test_fn:ident),*) => {
2236 paste::paste! {
2237 $(
2238 #[test]
2239 fn [<$test_fn _scalar_f64>]() {
2240 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
2241 }
2242 )*
2243 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2244 $(
2245 #[test]
2246 fn [<$test_fn _avx2_f64>]() {
2247 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
2248 }
2249 #[test]
2250 fn [<$test_fn _avx512_f64>]() {
2251 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
2252 }
2253 )*
2254 }
2255 }
2256 }
2257
2258 generate_all_hma_tests!(
2259 check_hma_partial_params,
2260 check_hma_accuracy,
2261 check_hma_zero_period,
2262 check_hma_period_exceeds_length,
2263 check_hma_very_small_dataset,
2264 check_hma_reinput,
2265 check_hma_nan_handling,
2266 check_hma_empty_input,
2267 check_hma_not_enough_valid,
2268 check_hma_streaming,
2269 check_hma_property,
2270 check_hma_no_poison
2271 );
2272
2273 fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2274 skip_if_unsupported!(kernel, test);
2275 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2276 let c = read_candles_from_csv(file)?;
2277 let output = HmaBatchBuilder::new()
2278 .kernel(kernel)
2279 .apply_candles(&c, "close")?;
2280 let def = HmaParams::default();
2281 let row = output.values_for(&def).expect("default row missing");
2282 assert_eq!(row.len(), c.close.len());
2283 let expected = [
2284 59334.13333336847,
2285 59201.4666667018,
2286 59047.77777781293,
2287 59048.71111114628,
2288 58803.44444447962,
2289 ];
2290 let start = row.len() - 5;
2291 for (i, &v) in row[start..].iter().enumerate() {
2292 assert!(
2293 (v - expected[i]).abs() < 1e-3,
2294 "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
2295 );
2296 }
2297 Ok(())
2298 }
2299
2300 #[cfg(debug_assertions)]
2301 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2302 skip_if_unsupported!(kernel, test);
2303
2304 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2305 let c = read_candles_from_csv(file)?;
2306
2307 let test_configs = vec![
2308 (2, 5, 1),
2309 (5, 25, 5),
2310 (10, 50, 10),
2311 (2, 4, 1),
2312 (50, 150, 25),
2313 (10, 30, 2),
2314 (10, 30, 10),
2315 (100, 300, 50),
2316 ];
2317
2318 for (cfg_idx, &(p_start, p_end, p_step)) in test_configs.iter().enumerate() {
2319 let output = HmaBatchBuilder::new()
2320 .kernel(kernel)
2321 .period_range(p_start, p_end, p_step)
2322 .apply_candles(&c, "close")?;
2323
2324 for (idx, &val) in output.values.iter().enumerate() {
2325 if val.is_nan() {
2326 continue;
2327 }
2328
2329 let bits = val.to_bits();
2330 let row = idx / output.cols;
2331 let col = idx % output.cols;
2332 let combo = &output.combos[row];
2333
2334 if bits == 0x11111111_11111111 {
2335 panic!(
2336 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
2337 at row {} col {} (flat index {}) with params: period={}",
2338 test,
2339 cfg_idx,
2340 val,
2341 bits,
2342 row,
2343 col,
2344 idx,
2345 combo.period.unwrap_or(5)
2346 );
2347 }
2348
2349 if bits == 0x22222222_22222222 {
2350 panic!(
2351 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
2352 at row {} col {} (flat index {}) with params: period={}",
2353 test,
2354 cfg_idx,
2355 val,
2356 bits,
2357 row,
2358 col,
2359 idx,
2360 combo.period.unwrap_or(5)
2361 );
2362 }
2363
2364 if bits == 0x33333333_33333333 {
2365 panic!(
2366 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
2367 at row {} col {} (flat index {}) with params: period={}",
2368 test,
2369 cfg_idx,
2370 val,
2371 bits,
2372 row,
2373 col,
2374 idx,
2375 combo.period.unwrap_or(5)
2376 );
2377 }
2378 }
2379 }
2380
2381 Ok(())
2382 }
2383
/// Release-build counterpart of the batch poison scan: does no work and always succeeds.
///
/// Compiled only when `debug_assertions` is disabled. Both arguments are ignored; they are kept
/// so the function can still be handed to `gen_batch_tests!`, which calls it with a test name
/// and a `Kernel`.
#[cfg(not(debug_assertions))]
fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
    // NOTE(review): presumably the poison bit patterns are only written in debug builds, so
    // there is nothing to look for here — confirm against the allocation helpers.
    Result::Ok(())
}
2388
2389 macro_rules! gen_batch_tests {
2390 ($fn_name:ident) => {
2391 paste::paste! {
2392 #[test] fn [<$fn_name _scalar>]() {
2393 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
2394 }
2395 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2396 #[test] fn [<$fn_name _avx2>]() {
2397 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
2398 }
2399 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2400 #[test] fn [<$fn_name _avx512>]() {
2401 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
2402 }
2403 #[test] fn [<$fn_name _auto_detect>]() {
2404 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]),
2405 Kernel::Auto);
2406 }
2407 }
2408 };
2409 }
2410 gen_batch_tests!(check_batch_default_row);
2411 gen_batch_tests!(check_batch_no_poison);
2412}