Skip to main content

vector_ta/indicators/moving_averages/
hma.rs

1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::moving_averages::CudaHma;
3use crate::utilities::data_loader::{source_type, Candles};
4use crate::utilities::enums::Kernel;
5use crate::utilities::helpers::{
6    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
7    make_uninit_matrix,
8};
9use aligned_vec::{AVec, CACHELINE_ALIGN};
10#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
11use core::arch::x86_64::*;
12#[cfg(all(feature = "python", feature = "cuda"))]
13use cust::memory::DeviceBuffer;
14#[cfg(not(target_arch = "wasm32"))]
15use rayon::prelude::*;
16use std::convert::AsRef;
17use std::error::Error;
18use std::mem::MaybeUninit;
19use thiserror::Error;
20impl<'a> AsRef<[f64]> for HmaInput<'a> {
21    #[inline(always)]
22    fn as_ref(&self) -> &[f64] {
23        match &self.data {
24            HmaData::Slice(slice) => slice,
25            HmaData::Candles { candles, source } => source_type(candles, source),
26        }
27    }
28}
29
/// Where the HMA reads its input series from.
#[derive(Debug, Clone)]
pub enum HmaData<'a> {
    /// A named column of a `Candles` set, resolved with `source_type`
    /// (the convenience constructors use `"close"`).
    Candles {
        /// Candle data to read from.
        candles: &'a Candles,
        /// Column name passed to `source_type`.
        source: &'a str,
    },
    /// A raw slice of values, used as-is.
    Slice(&'a [f64]),
}
38
/// Result of a single-series HMA computation.
#[derive(Debug, Clone)]
pub struct HmaOutput {
    /// One value per input sample; entries before the first complete HMA window are NaN.
    pub values: Vec<f64>,
}
43
/// Tunable parameters of the Hull Moving Average.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct HmaParams {
    /// Length of the full weighted moving average. `None` means "use the default of 5".
    pub period: Option<usize>,
}

impl Default for HmaParams {
    /// Parameters with the conventional default period of 5.
    fn default() -> Self {
        let period = Some(5);
        HmaParams { period }
    }
}
58
59#[derive(Debug, Clone)]
60pub struct HmaInput<'a> {
61    pub data: HmaData<'a>,
62    pub params: HmaParams,
63}
64
65impl<'a> HmaInput<'a> {
66    #[inline]
67    pub fn from_candles(c: &'a Candles, s: &'a str, p: HmaParams) -> Self {
68        Self {
69            data: HmaData::Candles {
70                candles: c,
71                source: s,
72            },
73            params: p,
74        }
75    }
76    #[inline]
77    pub fn from_slice(sl: &'a [f64], p: HmaParams) -> Self {
78        Self {
79            data: HmaData::Slice(sl),
80            params: p,
81        }
82    }
83    #[inline]
84    pub fn with_default_candles(c: &'a Candles) -> Self {
85        Self::from_candles(c, "close", HmaParams::default())
86    }
87    #[inline]
88    pub fn get_period(&self) -> usize {
89        self.params.period.unwrap_or(5)
90    }
91}
92
93#[derive(Copy, Clone, Debug)]
94pub struct HmaBuilder {
95    period: Option<usize>,
96    kernel: Kernel,
97}
98
99impl Default for HmaBuilder {
100    fn default() -> Self {
101        Self {
102            period: None,
103            kernel: Kernel::Auto,
104        }
105    }
106}
107
108impl HmaBuilder {
109    #[inline(always)]
110    pub fn new() -> Self {
111        Self::default()
112    }
113    #[inline(always)]
114    pub fn period(mut self, n: usize) -> Self {
115        self.period = Some(n);
116        self
117    }
118    #[inline(always)]
119    pub fn kernel(mut self, k: Kernel) -> Self {
120        self.kernel = k;
121        self
122    }
123    #[inline(always)]
124    pub fn apply(self, c: &Candles) -> Result<HmaOutput, HmaError> {
125        let p = HmaParams {
126            period: self.period,
127        };
128        let i = HmaInput::from_candles(c, "close", p);
129        hma_with_kernel(&i, self.kernel)
130    }
131    #[inline(always)]
132    pub fn apply_slice(self, d: &[f64]) -> Result<HmaOutput, HmaError> {
133        let p = HmaParams {
134            period: self.period,
135        };
136        let i = HmaInput::from_slice(d, p);
137        hma_with_kernel(&i, self.kernel)
138    }
139
140    #[inline(always)]
141    pub fn into_stream(self) -> Result<HmaStream, HmaError> {
142        let p = HmaParams {
143            period: self.period,
144        };
145        HmaStream::try_new(p)
146    }
147}
148
/// Errors produced by the HMA functions in this module.
#[derive(Debug, Error)]
pub enum HmaError {
    /// The input series is empty.
    #[error("hma: No data provided.")]
    NoData,

    /// Every input value is NaN, so there is no starting point.
    #[error("hma: All values are NaN.")]
    AllValuesNaN,

    /// The period is zero, larger than the data, or (for streams) below 2.
    #[error("hma: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },

    /// A caller-supplied output buffer does not match the input length.
    #[error("hma: Output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },

    /// A batch period sweep expanded to no periods.
    #[error("hma: Invalid range: start = {start}, end = {end}, step = {step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },

    /// A non-batch kernel was passed to a batch entry point.
    #[error("hma: Invalid kernel for batch API: {0:?}")]
    InvalidKernelForBatch(Kernel),

    /// A size computation (e.g. `rows * cols`) overflowed `usize`.
    #[error("hma: arithmetic overflow when computing {what}")]
    ArithmeticOverflow { what: &'static str },

    /// `period / 2` is zero (period 1), so the half-length WMA cannot be formed.
    #[error("hma: Cannot calculate half of period: period = {period}")]
    ZeroHalf { period: usize },

    /// `floor(sqrt(period))` is zero, so the final smoothing WMA cannot be formed.
    #[error("hma: Cannot calculate sqrt of period: period = {period}")]
    ZeroSqrtPeriod { period: usize },

    /// Fewer non-NaN samples (counted from the first non-NaN) than the computation needs.
    #[error("hma: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
}
185
186#[inline]
187pub fn hma(input: &HmaInput) -> Result<HmaOutput, HmaError> {
188    hma_with_kernel(input, Kernel::Auto)
189}
190
191#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
192pub fn hma_into(input: &HmaInput, out: &mut [f64]) -> Result<(), HmaError> {
193    hma_into_internal(input, out)
194}
195
196#[inline]
197fn hma_into_internal(input: &HmaInput, out: &mut [f64]) -> Result<(), HmaError> {
198    hma_with_kernel_into(input, Kernel::Auto, out)
199}
200
201pub fn hma_with_kernel(input: &HmaInput, kernel: Kernel) -> Result<HmaOutput, HmaError> {
202    let data: &[f64] = input.as_ref();
203    let len = data.len();
204    if len == 0 {
205        return Err(HmaError::NoData);
206    }
207    let first = data
208        .iter()
209        .position(|x| !x.is_nan())
210        .ok_or(HmaError::AllValuesNaN)?;
211    let period = input.get_period();
212    if period == 0 || period > len {
213        return Err(HmaError::InvalidPeriod {
214            period,
215            data_len: len,
216        });
217    }
218    if len - first < period {
219        return Err(HmaError::NotEnoughValidData {
220            needed: period,
221            valid: len - first,
222        });
223    }
224    let half = period / 2;
225    if half == 0 {
226        return Err(HmaError::ZeroHalf { period });
227    }
228    let sqrt_len = (period as f64).sqrt().floor() as usize;
229    if sqrt_len == 0 {
230        return Err(HmaError::ZeroSqrtPeriod { period });
231    }
232    if len - first < period + sqrt_len - 1 {
233        return Err(HmaError::NotEnoughValidData {
234            needed: period + sqrt_len - 1,
235            valid: len - first,
236        });
237    }
238    let chosen = match kernel {
239        Kernel::Auto => detect_best_kernel(),
240        other => other,
241    };
242    let first_out = first + period + sqrt_len - 2;
243    let mut out = alloc_with_nan_prefix(len, first_out);
244    unsafe {
245        match chosen {
246            Kernel::Scalar | Kernel::ScalarBatch => hma_scalar(data, period, first, &mut out),
247            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
248            Kernel::Avx2 | Kernel::Avx2Batch => hma_avx2(data, period, first, &mut out),
249            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
250            Kernel::Avx512 | Kernel::Avx512Batch => hma_avx512(data, period, first, &mut out),
251            _ => unreachable!(),
252        }
253    }
254    Ok(HmaOutput { values: out })
255}
256
257fn hma_with_kernel_into(input: &HmaInput, kernel: Kernel, out: &mut [f64]) -> Result<(), HmaError> {
258    let data: &[f64] = input.as_ref();
259    let len = data.len();
260    if len == 0 {
261        return Err(HmaError::NoData);
262    }
263    if out.len() != len {
264        return Err(HmaError::OutputLengthMismatch {
265            expected: len,
266            got: out.len(),
267        });
268    }
269
270    let first = data
271        .iter()
272        .position(|x| !x.is_nan())
273        .ok_or(HmaError::AllValuesNaN)?;
274    let period = input.get_period();
275    if period == 0 || period > len {
276        return Err(HmaError::InvalidPeriod {
277            period,
278            data_len: len,
279        });
280    }
281    if len - first < period {
282        return Err(HmaError::NotEnoughValidData {
283            needed: period,
284            valid: len - first,
285        });
286    }
287
288    let half = period / 2;
289    if half == 0 {
290        return Err(HmaError::ZeroHalf { period });
291    }
292    let sqrt_len = (period as f64).sqrt().floor() as usize;
293    if sqrt_len == 0 {
294        return Err(HmaError::ZeroSqrtPeriod { period });
295    }
296    if len - first < period + sqrt_len - 1 {
297        return Err(HmaError::NotEnoughValidData {
298            needed: period + sqrt_len - 1,
299            valid: len - first,
300        });
301    }
302
303    let chosen = match kernel {
304        Kernel::Auto => detect_best_kernel(),
305        other => other,
306    };
307    unsafe {
308        match chosen {
309            Kernel::Scalar | Kernel::ScalarBatch => hma_scalar(data, period, first, out),
310            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
311            Kernel::Avx2 | Kernel::Avx2Batch => hma_avx2(data, period, first, out),
312            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
313            Kernel::Avx512 | Kernel::Avx512Batch => hma_avx512(data, period, first, out),
314            _ => unreachable!(),
315        }
316    }
317
318    let first_out = first + period + sqrt_len - 2;
319    for v in &mut out[..first_out] {
320        *v = f64::NAN;
321    }
322    Ok(())
323}
324
/// Scalar O(n) Hull Moving Average kernel.
///
/// Computes `HMA = WMA_sq(2 * WMA_half(data) - WMA_full(data))` where `WMA_full` has length
/// `period`, `WMA_half` has length `period / 2` and `WMA_sq` has length `floor(sqrt(period))`.
/// Every WMA uses linear weights `1..=n` (newest sample weighted highest) and is maintained
/// incrementally with a plain running sum and a weighted running sum.
///
/// * `data`   – input series; only `data[first..]` is read.
/// * `period` – full WMA length.
/// * `first`  – first index to use (callers pass the first non-NaN index).
/// * `out`    – written at indices `first + period + floor(sqrt(period)) - 2 ..= data.len() - 1`.
///   Earlier indices are NOT written; the caller must initialise the warm-up prefix.
///
/// Returns without writing anything if `period < 2`, `first >= data.len()`, or the data after
/// `first` is too short for a single output.
///
/// NOTE(review): a NaN after `first` propagates into every later output, because the running
/// sums are never rebuilt.
///
/// # Panics
/// Panics on out-of-bounds indexing if `out.len() < data.len()`.
#[inline]
pub fn hma_scalar(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    let len = data.len();
    if period < 2 || first >= len || period > len - first {
        return;
    }
    let half = period / 2;
    if half == 0 {
        return;
    }
    let sq = (period as f64).sqrt().floor() as usize;
    if sq == 0 {
        return;
    }

    // First output index: `period` samples for WMA_full, then `sq - 1` more for WMA_sq.
    let first_out = first + period + sq - 2;
    if first_out >= len {
        return;
    }

    // Weight totals n*(n+1)/2 used to normalise each weighted sum.
    let ws_half = (half * (half + 1) / 2) as f64;
    let ws_full = (period * (period + 1) / 2) as f64;
    let ws_sqrt = (sq * (sq + 1) / 2) as f64;
    let half_f = half as f64;
    let period_f = period as f64;
    let sq_f = sq as f64;

    let (mut s_half, mut ws_half_acc) = (0.0, 0.0);
    let (mut s_full, mut ws_full_acc) = (0.0, 0.0);
    let (mut wma_half, mut wma_full) = (f64::NAN, f64::NAN);

    // Ring buffer of the last `sq` intermediate values x = 2*WMA_half - WMA_full.
    let mut x_buf = vec![0.0f64; sq];
    let mut x_sum = 0.0;
    let mut x_wsum = 0.0;
    let mut x_head = 0usize;

    let start = first;

    // Stage 1: the first `half` samples seed both WMAs with weights 1..=half.
    for j in 0..half {
        let v = data[start + j];
        let jf = j as f64 + 1.0;
        s_full += v;
        ws_full_acc = jf.mul_add(v, ws_full_acc);
        s_half += v;
        ws_half_acc = jf.mul_add(v, ws_half_acc);
    }
    wma_half = ws_half_acc / ws_half;

    // Stage 2: keep growing WMA_full while WMA_half slides.
    // Sliding update: new_wsum = old_wsum - old_sum + n * newest.
    if period > half + 1 {
        for j in half..(period - 1) {
            let idx = start + j;
            let v = data[idx];

            let jf = j as f64 + 1.0;
            s_full += v;
            ws_full_acc = jf.mul_add(v, ws_full_acc);

            let old_h = data[idx - half];
            let prev = s_half;
            s_half = prev + v - old_h;
            ws_half_acc = half_f.mul_add(v, ws_half_acc - prev);
            wma_half = ws_half_acc / ws_half;
        }
    }

    // Stage 3: sample `period - 1` completes WMA_full and yields the first x.
    {
        let j = period - 1;
        let idx = start + j;
        let v = data[idx];

        let jf = j as f64 + 1.0;
        s_full += v;
        ws_full_acc = jf.mul_add(v, ws_full_acc);
        wma_full = ws_full_acc / ws_full;

        let old_h = data[idx - half];
        let prev = s_half;
        s_half = prev + v - old_h;
        ws_half_acc = half_f.mul_add(v, ws_half_acc - prev);
        wma_half = ws_half_acc / ws_half;

        let x = 2.0 * wma_half - wma_full;
        x_buf[0] = x;
        x_sum += x;
        x_wsum = 1.0f64.mul_add(x, x_wsum);

        // With sq == 1 the final WMA is just x itself, so the first output is available now.
        if sq == 1 {
            out[first_out] = x_wsum / ws_sqrt;
        }
    }

    // Stage 4: fill the remaining `sq - 1` slots of the x window (weights 2..=sq).
    if sq > 1 {
        for j in period..(period + sq - 1) {
            let idx = start + j;
            let v = data[idx];

            let old_f = data[idx - period];
            let prev_f = s_full;
            s_full = prev_f + v - old_f;
            ws_full_acc = period_f.mul_add(v, ws_full_acc - prev_f);
            wma_full = ws_full_acc / ws_full;

            let old_h = data[idx - half];
            let prev_h = s_half;
            s_half = prev_h + v - old_h;
            ws_half_acc = half_f.mul_add(v, ws_half_acc - prev_h);
            wma_half = ws_half_acc / ws_half;

            let x = 2.0 * wma_half - wma_full;
            let pos = j + 1 - period;
            x_buf[pos] = x;
            x_sum += x;
            x_wsum = (pos as f64 + 1.0).mul_add(x, x_wsum);

            if pos + 1 == sq {
                out[first_out] = x_wsum / ws_sqrt;
            }
        }
    }

    // Stage 5: steady state; all three WMAs slide by one sample per step.
    let mut j = period + sq - 1;
    while j < len - start {
        let idx = start + j;
        let v = data[idx];

        let old_f = data[idx - period];
        let prev_f = s_full;
        s_full = prev_f + v - old_f;
        ws_full_acc = period_f.mul_add(v, ws_full_acc - prev_f);
        wma_full = ws_full_acc / ws_full;

        let old_h = data[idx - half];
        let prev_h = s_half;
        s_half = prev_h + v - old_h;
        ws_half_acc = half_f.mul_add(v, ws_half_acc - prev_h);
        wma_half = ws_half_acc / ws_half;

        let x = 2.0 * wma_half - wma_full;
        // `x_head` always points at the oldest x in the ring buffer.
        let old_x = x_buf[x_head];
        x_buf[x_head] = x;
        x_head += 1;
        if x_head == sq {
            x_head = 0;
        }

        let prev_sum = x_sum;
        x_sum = prev_sum + x - old_x;
        x_wsum = sq_f.mul_add(x, x_wsum - prev_sum);

        out[idx] = x_wsum / ws_sqrt;
        j += 1;
    }
}
478
/// "AVX2" HMA kernel.
///
/// There is no dedicated AVX2 implementation: the recurrence is inherently sequential, so this
/// simply delegates to `hma_scalar` and has identical inputs, outputs and early returns.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn hma_avx2(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    hma_scalar(data, period, first, out);
}
484
/// AVX-512 HMA kernel with the same inputs, outputs and early-return rules as `hma_scalar`.
///
/// Warm-up runs one scalar loop over the first `period + sq - 1` samples. The weighted sum of
/// the first complete `sq`-long window of `x = 2*WMA_half - WMA_full` is formed with 8-wide
/// FMA lanes. In the steady-state loop, the full and half running sums and weighted sums are
/// updated together in one 128-bit vector (high lane = full, low lane = half). A prefetch hint
/// is issued 128 samples ahead.
///
/// Results may differ from `hma_scalar` in the last bits: the warm-up uses plain
/// multiply/add rather than `mul_add`, and the first x-window sum is reduced in a different order.
///
/// # Safety
/// The CPU must support `avx512f` and `fma` (enabled via `target_feature`); calling this on a
/// CPU without them is undefined behavior. `out.len()` must be at least `data.len()`, because
/// outputs are written with `get_unchecked_mut`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
#[inline]
pub unsafe fn hma_avx512(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    use aligned_vec::AVec;
    use core::arch::x86_64::*;

    let len = data.len();
    if period < 2 || first >= len || period > len - first {
        return;
    }
    let half = period / 2;
    if half == 0 {
        return;
    }
    let sq = (period as f64).sqrt().floor() as usize;
    debug_assert!(
        sq > 0 && sq <= 65_535,
        "HMA: √period must fit in 16-bit to keep Σw < 2^53"
    );
    if sq == 0 {
        return;
    }
    let first_out = first + period + sq - 2;
    if first_out >= len {
        return;
    }

    // Weight totals n*(n+1)/2 for the three WMAs.
    let ws_half = (half * (half + 1) / 2) as f64;
    let ws_full = (period * (period + 1) / 2) as f64;
    let ws_sqrt = (sq * (sq + 1) / 2) as f64;
    let sq_f = sq as f64;

    let (mut s_half, mut ws_half_acc) = (0.0, 0.0);
    let (mut s_full, mut ws_full_acc) = (0.0, 0.0);
    let (mut wma_half, mut wma_full) = (f64::NAN, f64::NAN);

    // x ring buffer, 64-byte aligned and padded to a multiple of 8 lanes.
    let sq_aligned = (sq + 7) & !7;
    let mut x_buf: AVec<f64> = AVec::with_capacity(64, sq_aligned);
    x_buf.resize(sq_aligned, 0.0);

    let mut x_sum = 0.0;
    let mut x_wsum = 0.0;
    let mut x_head = 0usize;

    // Lane offsets 0..7; adding `off + 1` gives weights off+1 ..= off+8.
    const W_RAMP_ARR: [f64; 8] = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];

    let w_ramp: __m512d = _mm512_loadu_pd(W_RAMP_ARR.as_ptr());

    /// Horizontal sum of all 8 lanes of `z`.
    #[inline(always)]
    unsafe fn horiz_sum(z: __m512d) -> f64 {
        let hi = _mm512_extractf64x4_pd(z, 1);
        let lo = _mm512_castpd512_pd256(z);

        let sum256 = _mm256_add_pd(hi, lo);

        let sum128 = _mm256_hadd_pd(sum256, sum256);

        let hi128 = _mm256_extractf128_pd(sum128, 1);
        let lo128 = _mm256_castpd256_pd128(sum128);
        let final_sum = _mm_add_pd(hi128, lo128);

        _mm_cvtsd_f64(final_sum)
    }

    // Warm-up: grow, then slide, both WMAs; collect the first `sq` x values and emit
    // the first output once the x window is full.
    for j in 0..(period + sq - 1) {
        let idx = first + j;
        let val = *data.get_unchecked(idx);

        if j < period {
            s_full += val;
            ws_full_acc += (j as f64 + 1.0) * val;
        } else {
            let old = *data.get_unchecked(idx - period);
            let prev = s_full;
            s_full = prev + val - old;
            ws_full_acc = ws_full_acc - prev + (period as f64) * val;
        }

        if j < half {
            s_half += val;
            ws_half_acc += (j as f64 + 1.0) * val;
        } else {
            let old = *data.get_unchecked(idx - half);
            let prev = s_half;
            s_half = prev + val - old;
            ws_half_acc = ws_half_acc - prev + (half as f64) * val;
        }

        if j + 1 >= half {
            wma_half = ws_half_acc / ws_half;
        }
        if j + 1 >= period {
            wma_full = ws_full_acc / ws_full;
        }

        if j + 1 >= period {
            let x_val = 2.0 * wma_half - wma_full;
            let pos = (j + 1 - period) as usize;

            if pos < sq {
                *x_buf.get_unchecked_mut(pos) = x_val;
                x_sum += x_val;

                if pos + 1 == sq {
                    // Weighted sum with weights 1..=sq: 8 lanes at a time, then a scalar tail.
                    let mut acc = _mm512_setzero_pd();
                    let mut i = 0usize;
                    let mut off = 0.0;
                    while i + 8 <= sq {
                        let x = _mm512_loadu_pd(x_buf.as_ptr().add(i));

                        let w = _mm512_add_pd(w_ramp, _mm512_set1_pd(off + 1.0));
                        acc = _mm512_fmadd_pd(x, w, acc);
                        i += 8;
                        off += 8.0;
                    }
                    x_wsum = horiz_sum(acc);
                    for k in i..sq {
                        x_wsum += x_buf[k] * (k as f64 + 1.0);
                    }
                    *out.get_unchecked_mut(first_out) = x_wsum / ws_sqrt;
                }
            }
        }
    }

    // Steady state: slide all three WMAs by one sample per step.
    for j in (period + sq - 1)..(len - first) {
        let idx = first + j;
        let val = *data.get_unchecked(idx);

        let old_f = *data.get_unchecked(idx - period);
        let old_h = *data.get_unchecked(idx - half);

        // Pack (full, half) into (high, low) lanes and apply
        // sum' = sum - old + v and wsum' = wsum - sum + n * v to both at once.
        let sum_vec = _mm_set_pd(s_full, s_half);
        let old_vec = _mm_set_pd(old_f, old_h);
        let ws_vec = _mm_set_pd(ws_full_acc, ws_half_acc);
        let weights = _mm_set_pd(period as f64, half as f64);
        let v_val = _mm_set1_pd(val);

        let new_sum_vec = _mm_add_pd(_mm_sub_pd(sum_vec, old_vec), v_val);

        let diff = _mm_sub_pd(ws_vec, sum_vec);
        let new_ws_vec = _mm_fmadd_pd(v_val, weights, diff);

        // Unpack: high lane -> full, low lane -> half.
        s_full = _mm_cvtsd_f64(_mm_unpackhi_pd(new_sum_vec, new_sum_vec));
        s_half = _mm_cvtsd_f64(new_sum_vec);
        ws_full_acc = _mm_cvtsd_f64(_mm_unpackhi_pd(new_ws_vec, new_ws_vec));
        ws_half_acc = _mm_cvtsd_f64(new_ws_vec);

        wma_full = ws_full_acc / ws_full;
        wma_half = ws_half_acc / ws_half;
        let x_val = 2.0 * wma_half - wma_full;

        // `x_head` points at the oldest x; replace it and advance.
        let old_x = *x_buf.get_unchecked(x_head);
        *x_buf.get_unchecked_mut(x_head) = x_val;
        x_head = (x_head + 1) % sq;

        let prev_sum = x_sum;
        x_sum = prev_sum + x_val - old_x;
        x_wsum = sq_f.mul_add(x_val, x_wsum - prev_sum);

        *out.get_unchecked_mut(idx) = x_wsum / ws_sqrt;

        // Hint the sample 128 positions ahead (clamped to the last element) into cache.
        let pf = core::cmp::min(idx + 128, len - 1);
        _mm_prefetch(data.as_ptr().add(pf) as *const i8, _MM_HINT_T1);
    }
}
652
/// Incremental linearly-weighted moving average over a fixed window, used by the HMA stream.
///
/// Weights run `1..=period` from the oldest to the newest sample in the window. Samples live in
/// a ring buffer, and steady-state updates are O(1). NaNs are tracked with a counter: while any
/// NaN is inside the window the output is NaN, and once the last one leaves, the sums are
/// rebuilt from the buffer (see `dirty`).
#[derive(Debug, Clone)]
struct LinWma {
    /// Window length.
    period: usize,
    /// Reciprocal of the weight total `period * (period + 1) / 2`.
    inv_norm: f64,
    /// Ring buffer of the most recent `period` inputs (starts filled with NaN).
    buffer: Vec<f64>,
    /// Next slot to overwrite; after each update it indexes the oldest sample.
    head: usize,
    /// Becomes true once `period` samples have been pushed.
    filled: bool,
    /// Samples pushed during the fill phase (not advanced once `filled`).
    count: usize,

    /// Running plain sum of the window; only meaningful while `dirty` is false.
    sum: f64,
    /// Running weighted sum of the window; only meaningful while `dirty` is false.
    wsum: f64,
    /// Number of NaN samples currently inside the window.
    nan_count: usize,
    /// Set when a NaN was seen, so `sum`/`wsum` must be rebuilt before reuse.
    dirty: bool,
}

impl LinWma {
    /// Creates an empty WMA of length `period`.
    #[inline(always)]
    fn new(period: usize) -> Self {
        let norm = (period as f64) * ((period as f64) + 1.0) * 0.5;
        Self {
            period,
            inv_norm: 1.0 / norm,
            buffer: vec![f64::NAN; period],
            head: 0,
            filled: false,
            count: 0,
            sum: 0.0,
            wsum: 0.0,
            nan_count: 0,
            dirty: false,
        }
    }

    /// Recomputes `sum`, `wsum` and `nan_count` from the buffer, walking oldest to newest
    /// (starting at `head`) with weights `1..=period`. Only called when the window is NaN-free.
    #[inline(always)]
    fn rebuild(&mut self) {
        self.sum = 0.0;
        self.wsum = 0.0;
        self.nan_count = 0;

        let mut idx = self.head;
        for i in 0..self.period {
            let v = self.buffer[idx];
            if v.is_nan() {
                self.nan_count += 1;
            } else {
                self.sum += v;
                self.wsum = (i as f64 + 1.0).mul_add(v, self.wsum);
            }
            idx = if idx + 1 == self.period { 0 } else { idx + 1 };
        }
        self.dirty = self.nan_count != 0;
        debug_assert!(self.nan_count == 0, "rebuild expected clean window");
    }

    /// Pushes `value` and returns the WMA of the current window.
    ///
    /// Returns `None` until `period` values have been pushed, then `Some(wma)` on every call.
    /// The result is `Some(NaN)` while any NaN is inside the window.
    #[inline(always)]
    fn update(&mut self, value: f64) -> Option<f64> {
        let n = self.period as f64;

        // Evict the oldest sample (meaningless during fill) and store the new one.
        let old = self.buffer[self.head];
        self.buffer[self.head] = value;
        self.head = if self.head + 1 == self.period {
            0
        } else {
            self.head + 1
        };

        // Fill phase: accumulate with weight = position in the window (1-based).
        if !self.filled {
            self.count += 1;

            if value.is_nan() {
                self.nan_count += 1;
                self.dirty = true;
            } else {
                self.sum += value;
                self.wsum = (self.count as f64).mul_add(value, self.wsum);
            }

            if self.count == self.period {
                self.filled = true;

                return Some(if self.nan_count > 0 {
                    f64::NAN
                } else {
                    self.wsum * self.inv_norm
                });
            }
            return None;
        }

        // Track NaNs entering and leaving the window.
        if old.is_nan() {
            self.nan_count = self.nan_count.saturating_sub(1);
        }
        if value.is_nan() {
            self.nan_count += 1;
        }

        if self.nan_count > 0 {
            self.dirty = true;
            return Some(f64::NAN);
        }

        // The last NaN just left: running sums are stale, so recompute them from scratch.
        if self.dirty {
            self.rebuild();
            self.dirty = false;
            debug_assert_eq!(self.nan_count, 0);
            return Some(self.wsum * self.inv_norm);
        }

        // O(1) slide: wsum' = wsum - sum + n * newest; sum' = sum - oldest + newest.
        let prev_sum = self.sum;
        self.sum = prev_sum + value - old;
        self.wsum = n.mul_add(value, self.wsum - prev_sum);

        Some(self.wsum * self.inv_norm)
    }
}
768
769#[derive(Debug, Clone)]
770pub struct HmaStream {
771    wma_half: LinWma,
772    wma_full: LinWma,
773    wma_sqrt: LinWma,
774}
775
776impl HmaStream {
777    pub fn try_new(params: HmaParams) -> Result<Self, HmaError> {
778        let period = params.period.unwrap_or(5);
779        if period < 2 {
780            return Err(HmaError::InvalidPeriod {
781                period,
782                data_len: 0,
783            });
784        }
785        let half = period / 2;
786        if half == 0 {
787            return Err(HmaError::ZeroHalf { period });
788        }
789        let sqrt_len = (period as f64).sqrt().floor() as usize;
790        if sqrt_len == 0 {
791            return Err(HmaError::ZeroSqrtPeriod { period });
792        }
793
794        Ok(Self {
795            wma_half: LinWma::new(half),
796            wma_full: LinWma::new(period),
797            wma_sqrt: LinWma::new(sqrt_len),
798        })
799    }
800
801    #[inline(always)]
802    pub fn update(&mut self, value: f64) -> Option<f64> {
803        let full = self.wma_full.update(value);
804        let half = self.wma_half.update(value);
805
806        if let (Some(f), Some(h)) = (full, half) {
807            let x = 2.0f64.mul_add(h, -f);
808            self.wma_sqrt.update(x)
809        } else {
810            None
811        }
812    }
813}
814
/// Period sweep for the batch HMA APIs.
#[derive(Clone, Debug)]
pub struct HmaBatchRange {
    /// `(start, end, step)`, inclusive; a zero step or `start == end` means a single period.
    pub period: (usize, usize, usize),
}

impl Default for HmaBatchRange {
    /// Sweeps periods 5 through 254 in steps of 1.
    fn default() -> Self {
        let (start, end, step) = (5, 254, 1);
        HmaBatchRange {
            period: (start, end, step),
        }
    }
}
827
828#[derive(Clone, Debug, Default)]
829pub struct HmaBatchBuilder {
830    range: HmaBatchRange,
831    kernel: Kernel,
832}
833
834impl HmaBatchBuilder {
835    pub fn new() -> Self {
836        Self::default()
837    }
838    pub fn kernel(mut self, k: Kernel) -> Self {
839        self.kernel = k;
840        self
841    }
842    #[inline]
843    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
844        self.range.period = (start, end, step);
845        self
846    }
847    #[inline]
848    pub fn period_static(mut self, p: usize) -> Self {
849        self.range.period = (p, p, 0);
850        self
851    }
852    pub fn apply_slice(self, data: &[f64]) -> Result<HmaBatchOutput, HmaError> {
853        hma_batch_with_kernel(data, &self.range, self.kernel)
854    }
855    pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<HmaBatchOutput, HmaError> {
856        HmaBatchBuilder::new().kernel(k).apply_slice(data)
857    }
858    pub fn apply_candles(self, c: &Candles, src: &str) -> Result<HmaBatchOutput, HmaError> {
859        let slice = source_type(c, src);
860        self.apply_slice(slice)
861    }
862    pub fn with_default_candles(c: &Candles) -> Result<HmaBatchOutput, HmaError> {
863        HmaBatchBuilder::new()
864            .kernel(Kernel::Auto)
865            .apply_candles(c, "close")
866    }
867}
868
869pub fn hma_batch_with_kernel(
870    data: &[f64],
871    sweep: &HmaBatchRange,
872    k: Kernel,
873) -> Result<HmaBatchOutput, HmaError> {
874    let kernel = match k {
875        Kernel::Auto => detect_best_batch_kernel(),
876        other if other.is_batch() => other,
877        other => return Err(HmaError::InvalidKernelForBatch(other)),
878    };
879    let simd = match kernel {
880        Kernel::Avx512Batch => Kernel::Avx512,
881        Kernel::Avx2Batch => Kernel::Avx2,
882        Kernel::ScalarBatch => Kernel::Scalar,
883        _ => unreachable!(),
884    };
885    hma_batch_par_slice(data, sweep, simd)
886}
887
888#[derive(Clone, Debug)]
889pub struct HmaBatchOutput {
890    pub values: Vec<f64>,
891    pub combos: Vec<HmaParams>,
892    pub rows: usize,
893    pub cols: usize,
894}
895
896impl HmaBatchOutput {
897    pub fn row_for_params(&self, p: &HmaParams) -> Option<usize> {
898        self.combos
899            .iter()
900            .position(|c| c.period.unwrap_or(5) == p.period.unwrap_or(5))
901    }
902    pub fn values_for(&self, p: &HmaParams) -> Option<&[f64]> {
903        self.row_for_params(p).map(|row| {
904            let start = row * self.cols;
905            &self.values[start..start + self.cols]
906        })
907    }
908}
909
910#[inline(always)]
911fn expand_grid(r: &HmaBatchRange) -> Vec<HmaParams> {
912    fn axis_usize((start, end, step): (usize, usize, usize)) -> Vec<usize> {
913        if step == 0 || start == end {
914            return vec![start];
915        }
916
917        let (lo, hi) = if start <= end {
918            (start, end)
919        } else {
920            (end, start)
921        };
922        let mut v = Vec::new();
923        let mut x = lo;
924        while x <= hi {
925            v.push(x);
926            match x.checked_add(step) {
927                Some(nx) => x = nx,
928                None => break,
929            }
930        }
931        v
932    }
933    let periods = axis_usize(r.period);
934    let mut out = Vec::with_capacity(periods.len());
935    for &p in &periods {
936        out.push(HmaParams { period: Some(p) });
937    }
938    out
939}
940
941#[inline(always)]
942pub fn hma_batch_slice(
943    data: &[f64],
944    sweep: &HmaBatchRange,
945    kern: Kernel,
946) -> Result<HmaBatchOutput, HmaError> {
947    hma_batch_inner(data, sweep, kern, false)
948}
949
950#[inline(always)]
951pub fn hma_batch_par_slice(
952    data: &[f64],
953    sweep: &HmaBatchRange,
954    kern: Kernel,
955) -> Result<HmaBatchOutput, HmaError> {
956    hma_batch_inner(data, sweep, kern, true)
957}
958
959#[inline(always)]
960fn hma_batch_inner(
961    data: &[f64],
962    sweep: &HmaBatchRange,
963    kern: Kernel,
964    parallel: bool,
965) -> Result<HmaBatchOutput, HmaError> {
966    let combos = expand_grid(sweep);
967    if combos.is_empty() {
968        let (s, e, t) = sweep.period;
969        return Err(HmaError::InvalidRange {
970            start: s,
971            end: e,
972            step: t,
973        });
974    }
975    let first = data
976        .iter()
977        .position(|x| !x.is_nan())
978        .ok_or(HmaError::AllValuesNaN)?;
979    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
980    if data.len() - first < max_p {
981        return Err(HmaError::NotEnoughValidData {
982            needed: max_p,
983            valid: data.len() - first,
984        });
985    }
986    let rows = combos.len();
987    let cols = data.len();
988
989    let warm: Vec<usize> = combos
990        .iter()
991        .map(|c| {
992            let p = c.period.unwrap();
993            let s = (p as f64).sqrt().floor() as usize;
994            first + p + s - 2
995        })
996        .collect();
997
998    let _ = rows
999        .checked_mul(cols)
1000        .ok_or(HmaError::ArithmeticOverflow { what: "rows*cols" })?;
1001    let mut raw = make_uninit_matrix(rows, cols);
1002    unsafe { init_matrix_prefixes(&mut raw, cols, &warm) };
1003
1004    let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
1005        let period = combos[row].period.unwrap();
1006
1007        let out_row =
1008            core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
1009
1010        match kern {
1011            Kernel::Scalar | Kernel::ScalarBatch => hma_row_scalar(data, first, period, out_row),
1012            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1013            Kernel::Avx2 | Kernel::Avx2Batch => hma_row_avx2(data, first, period, out_row),
1014            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1015            Kernel::Avx512 | Kernel::Avx512Batch => hma_row_avx512(data, first, period, out_row),
1016            _ => unreachable!(),
1017        }
1018    };
1019
1020    if parallel {
1021        #[cfg(not(target_arch = "wasm32"))]
1022        {
1023            raw.par_chunks_mut(cols)
1024                .enumerate()
1025                .for_each(|(row, slice)| do_row(row, slice));
1026        }
1027
1028        #[cfg(target_arch = "wasm32")]
1029        {
1030            for (row, slice) in raw.chunks_mut(cols).enumerate() {
1031                do_row(row, slice);
1032            }
1033        }
1034    } else {
1035        for (row, slice) in raw.chunks_mut(cols).enumerate() {
1036            do_row(row, slice);
1037        }
1038    }
1039
1040    let rows = combos.len();
1041    let cols = data.len();
1042    let _ = rows
1043        .checked_mul(cols)
1044        .ok_or(HmaError::ArithmeticOverflow { what: "rows*cols" })?;
1045
1046    let mut guard = core::mem::ManuallyDrop::new(raw);
1047    let values: Vec<f64> = unsafe {
1048        Vec::from_raw_parts(
1049            guard.as_mut_ptr() as *mut f64,
1050            guard.len(),
1051            guard.capacity(),
1052        )
1053    };
1054
1055    Ok(HmaBatchOutput {
1056        values,
1057        combos,
1058        rows,
1059        cols,
1060    })
1061}
1062
1063#[inline(always)]
1064fn hma_batch_inner_into(
1065    data: &[f64],
1066    sweep: &HmaBatchRange,
1067    kern: Kernel,
1068    parallel: bool,
1069    out: &mut [f64],
1070) -> Result<(Vec<HmaParams>, usize, usize), HmaError> {
1071    let combos = expand_grid(sweep);
1072    if combos.is_empty() {
1073        let (s, e, t) = sweep.period;
1074        return Err(HmaError::InvalidRange {
1075            start: s,
1076            end: e,
1077            step: t,
1078        });
1079    }
1080    let first = data
1081        .iter()
1082        .position(|x| !x.is_nan())
1083        .ok_or(HmaError::AllValuesNaN)?;
1084    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1085    if data.len() - first < max_p {
1086        return Err(HmaError::NotEnoughValidData {
1087            needed: max_p,
1088            valid: data.len() - first,
1089        });
1090    }
1091    let rows = combos.len();
1092    let cols = data.len();
1093
1094    let expected = rows
1095        .checked_mul(cols)
1096        .ok_or(HmaError::ArithmeticOverflow { what: "rows*cols" })?;
1097    if out.len() != expected {
1098        return Err(HmaError::OutputLengthMismatch {
1099            expected,
1100            got: out.len(),
1101        });
1102    }
1103
1104    let warm: Vec<usize> = combos
1105        .iter()
1106        .map(|c| {
1107            let p = c.period.unwrap();
1108            let s = (p as f64).sqrt().floor() as usize;
1109            first + p + s - 2
1110        })
1111        .collect();
1112
1113    let out_uninit = unsafe {
1114        std::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
1115    };
1116    unsafe { init_matrix_prefixes(out_uninit, cols, &warm) };
1117
1118    let do_row = |row: usize, out_row: &mut [f64]| unsafe {
1119        let period = combos[row].period.unwrap();
1120
1121        match kern {
1122            Kernel::Scalar | Kernel::ScalarBatch => hma_row_scalar(data, first, period, out_row),
1123            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1124            Kernel::Avx2 | Kernel::Avx2Batch => hma_row_avx2(data, first, period, out_row),
1125            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1126            Kernel::Avx512 | Kernel::Avx512Batch => hma_row_avx512(data, first, period, out_row),
1127            _ => unreachable!(),
1128        }
1129    };
1130
1131    if parallel {
1132        #[cfg(not(target_arch = "wasm32"))]
1133        {
1134            out.par_chunks_mut(cols)
1135                .enumerate()
1136                .for_each(|(row, slice)| do_row(row, slice));
1137        }
1138        #[cfg(target_arch = "wasm32")]
1139        {
1140            for (row, slice) in out.chunks_mut(cols).enumerate() {
1141                do_row(row, slice);
1142            }
1143        }
1144    } else {
1145        for (row, slice) in out.chunks_mut(cols).enumerate() {
1146            do_row(row, slice);
1147        }
1148    }
1149
1150    Ok((combos, rows, cols))
1151}
1152
1153#[inline(always)]
1154pub unsafe fn hma_row_scalar(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
1155    hma_scalar(data, period, first, out)
1156}
1157
/// Batch row kernel that fills a single output row using the AVX2 HMA.
///
/// Adapts the batch calling convention `(data, first, period, out)` to the argument
/// order `hma_avx2` expects, `(data, period, first, out)`.
///
/// # Safety
/// The running CPU must support the instruction set `hma_avx2` is compiled for
/// (presumably AVX2/FMA — TODO confirm). The batch drivers pass an `out` row of length
/// `data.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn hma_row_avx2(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    let (series, row) = (data, out);
    hma_avx2(series, period, first, row);
}
1163
/// Batch row kernel that fills a single output row using the AVX-512 HMA.
///
/// Adapts the batch calling convention `(data, first, period, out)` to the argument
/// order `hma_avx512` expects, `(data, period, first, out)`.
///
/// # Safety
/// The running CPU must support the instruction set `hma_avx512` is compiled for
/// (presumably AVX-512F — TODO confirm). The batch drivers pass an `out` row of length
/// `data.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn hma_row_avx512(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    let (series, row) = (data, out);
    hma_avx512(series, period, first, row);
}
1169
1170#[inline(always)]
1171fn expand_grid_hma(r: &HmaBatchRange) -> Vec<HmaParams> {
1172    expand_grid(r)
1173}
1174
1175#[cfg(feature = "python")]
1176use crate::utilities::kernel_validation::validate_kernel;
1177#[cfg(feature = "python")]
1178use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
1179#[cfg(feature = "python")]
1180use pyo3::exceptions::PyValueError;
1181#[cfg(feature = "python")]
1182use pyo3::prelude::*;
1183#[cfg(feature = "python")]
1184use pyo3::types::PyDict;
1185
/// Python binding: computes the HMA of a 1-D float64 array for a single `period`.
///
/// `kernel` is validated with `validate_kernel(kernel, false)`. The computation runs
/// with the GIL released. Any `HmaError` is raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "hma")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn hma_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let series = data.as_slice()?;
    let chosen = validate_kernel(kernel, false)?;
    let input = HmaInput::from_slice(
        series,
        HmaParams {
            period: Some(period),
        },
    );

    let computed = py.allow_threads(|| hma_with_kernel(&input, chosen).map(|res| res.values));
    match computed {
        Ok(values) => Ok(values.into_pyarray(py)),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
1211
/// Python binding: computes HMA for every period in `period_range = (start, end, step)`.
///
/// Returns a dict with:
/// * `"values"` – a `(rows, cols)` float64 array, one row per period.
/// * `"periods"` – the period used for each row, as `uint64`.
///
/// The kernel is resolved (`Auto` → best batch kernel) and mapped to its per-row
/// variant before the GIL is released. An unsupported kernel therefore raises
/// `ValueError` instead of panicking inside `allow_threads`. An output size that would
/// overflow `usize` is also reported as `ValueError`. Any `HmaError` from the
/// computation is raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "hma_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn hma_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, true)?;
    let sweep = HmaBatchRange {
        period: period_range,
    };

    let combos = expand_grid(&sweep);
    let rows = combos.len();
    let cols = slice_in.len();
    // Unchecked `rows * cols` could wrap and allocate a too-small array.
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("hma_batch: rows*cols overflows usize"))?;

    // Resolve `Auto`, then map batch kernels to the per-row kernels used by
    // `hma_batch_inner_into`. Plain row kernels are accepted as-is.
    let resolved = match kern {
        Kernel::Auto => detect_best_batch_kernel(),
        k => k,
    };
    let simd = match resolved {
        Kernel::Avx512Batch | Kernel::Avx512 => Kernel::Avx512,
        Kernel::Avx2Batch | Kernel::Avx2 => Kernel::Avx2,
        Kernel::ScalarBatch | Kernel::Scalar => Kernel::Scalar,
        _ => return Err(PyValueError::new_err("hma_batch: unsupported kernel")),
    };

    // Uninitialized on purpose: `hma_batch_inner_into` writes every cell (NaN prefixes
    // plus row kernels) before returning Ok. On error the array is simply dropped.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            hma_batch_inner_into(slice_in, &sweep, simd, true, slice_out)
                .map(|(combos, _, _)| combos)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;

    Ok(dict)
}
1267
/// Python binding: runs the HMA period sweep on a CUDA device and returns the result
/// without copying it back to the host.
///
/// Returns `(device_array, meta)`. `meta` holds the `"periods"` of each row, CAI-style
/// descriptors (`cai_version`, `cai_typestr`, `cai_shape`, `cai_strides_bytes`) and the
/// `"stream"` handle. The GPU work runs with the GIL released. CUDA errors are raised as
/// `ValueError`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "hma_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, device_id=0))]
pub fn hma_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    data_f32: numpy::PyReadonlyArray1<'py, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<(DeviceArrayF32HmaPy, Bound<'py, PyDict>)> {
    use crate::cuda::cuda_available;
    use numpy::IntoPyArray;
    use pyo3::types::PyDict;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let host = data_f32.as_slice()?;
    let sweep = HmaBatchRange {
        period: period_range,
    };

    let (dev_arr, combos, stream, ctx, dev_id) = py.allow_threads(|| {
        let engine =
            CudaHma::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = engine.ctx();
        let dev_id = engine.device_id();
        let result = engine
            .hma_batch_dev(host, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        let stream = engine.stream_handle_u64();
        Ok::<_, PyErr>((result.0, result.1, stream, ctx, dev_id))
    })?;

    let meta = PyDict::new(py);
    let periods: Vec<u64> = combos
        .iter()
        .map(|combo| combo.period.unwrap() as u64)
        .collect();
    meta.set_item("periods", periods.into_pyarray(py))?;

    let (nrows, ncols) = (dev_arr.rows as u64, dev_arr.cols as u64);
    meta.set_item("cai_version", 3u64)?;
    meta.set_item("cai_typestr", "<f4")?;
    meta.set_item("cai_shape", (nrows, ncols))?;
    meta.set_item("cai_strides_bytes", (ncols * 4u64, 4u64))?;
    meta.set_item("stream", stream)?;

    let handle = DeviceArrayF32HmaPy::new(dev_arr, ctx, dev_id, stream);
    Ok((handle, meta))
}
1315
/// Python binding: HMA with a single `period` over many series stored time-major in a
/// 2-D float32 array, computed on a CUDA device.
///
/// The array's shape is read as `(rows, cols)` and forwarded to the CUDA wrapper as
/// `(cols, rows)`. The array must be contiguous (`as_slice` fails otherwise). The GPU
/// work runs with the GIL released. CUDA errors are raised as `ValueError`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "hma_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, device_id=0))]
pub fn hma_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32HmaPy> {
    use crate::cuda::cuda_available;
    use numpy::PyUntypedArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let flat = data_tm_f32.as_slice()?;
    let dims = data_tm_f32.shape();
    let (rows, cols) = (dims[0], dims[1]);
    let params = HmaParams {
        period: Some(period),
    };

    let (dev_arr, ctx, dev_id, stream) = py.allow_threads(|| {
        let engine =
            CudaHma::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = engine.ctx();
        let dev_id = engine.device_id();
        let computed = engine
            .hma_multi_series_one_param_time_major_dev(flat, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        let stream = engine.stream_handle_u64();
        Ok::<_, PyErr>((computed, ctx, dev_id, stream))
    })?;

    Ok(DeviceArrayF32HmaPy::new(dev_arr, ctx, dev_id, stream))
}
1351
/// Python-visible wrapper (`HmaStream`) around the incremental `HmaStream` calculator.
#[cfg(feature = "python")]
#[pyclass(name = "HmaStream")]
pub struct HmaStreamPy {
    /// The underlying Rust streaming state; all Python calls delegate to it.
    inner: HmaStream,
}
1357
/// Python methods for `HmaStream`.
#[cfg(feature = "python")]
#[pymethods]
impl HmaStreamPy {
    /// Creates a stream for `period`. Raises `ValueError` if `HmaStream::try_new`
    /// rejects the parameters.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        HmaStream::try_new(HmaParams {
            period: Some(period),
        })
        .map(|inner| HmaStreamPy { inner })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feeds one value. Returns `None` while the stream is still warming up, otherwise
    /// the latest HMA value.
    fn update(&mut self, value: f64) -> Option<f64> {
        let state = &mut self.inner;
        state.update(value)
    }
}
1375
/// Python handle (`ta_indicators.cuda.DeviceArrayF32Hma`) to an HMA result stored in
/// CUDA device memory, exposed via `__cuda_array_interface__` and DLPack.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", name = "DeviceArrayF32Hma", unsendable)]
pub struct DeviceArrayF32HmaPy {
    /// Device buffer plus its `(rows, cols)` shape.
    pub(crate) inner: crate::cuda::moving_averages::DeviceArrayF32,
    /// Holds a reference to the CUDA context the buffer was allocated in.
    /// Field order matters: Rust drops fields in declaration order, so `inner` (the
    /// device buffer) is dropped before this context reference. Presumably the buffer
    /// must be freed while its context is still alive — keep `inner` declared first.
    _ctx_guard: std::sync::Arc<cust::context::Context>,
    /// CUDA device ordinal; reported by `__dlpack_device__`.
    _device_id: u32,
    /// Raw handle of the stream the producing kernels ran on (stored as `u64`).
    _stream: u64,
}
1384
/// Python protocol methods for `DeviceArrayF32Hma`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl DeviceArrayF32HmaPy {
    /// Always fails: instances are only created by the CUDA factory functions.
    #[new]
    fn py_new() -> PyResult<Self> {
        Err(pyo3::exceptions::PyTypeError::new_err(
            "use factory methods from CUDA functions",
        ))
    }

    /// CUDA Array Interface (version 3) description of the device buffer.
    ///
    /// Describes a row-major `(rows, cols)` float32 array with explicit byte strides.
    /// `data` is `(0, false)` when the array is empty (e.g. after `__dlpack__` has
    /// moved the buffer out).
    #[getter]
    fn __cuda_array_interface__<'py>(
        &self,
        py: Python<'py>,
    ) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
        let d = pyo3::types::PyDict::new(py);
        let itemsize = std::mem::size_of::<f32>();
        d.set_item("shape", (self.inner.rows, self.inner.cols))?;
        d.set_item("typestr", "<f4")?;

        d.set_item("strides", (self.inner.cols * itemsize, itemsize))?;
        let size = self.inner.rows.saturating_mul(self.inner.cols);
        let ptr_val: usize = if size == 0 {
            0
        } else {
            self.inner.buf.as_device_ptr().as_raw() as usize
        };
        d.set_item("data", (ptr_val, false))?;

        // CAI v3 disallows `stream == 0` because it is ambiguous. A null CUstream
        // handle refers to the default stream, so report it with the spec's value for
        // the legacy default stream (1) instead of the forbidden 0.
        let stream = if self._stream == 0 { 1 } else { self._stream };
        d.set_item("stream", stream)?;
        d.set_item("version", 3)?;
        Ok(d)
    }

    /// DLPack device tuple: `(2, device_id)`, where 2 is DLPack's `kDLCUDA`.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self._device_id as i32)
    }

    /// Exports the buffer as a DLPack capsule, moving ownership out of this object.
    ///
    /// Fails with `ValueError` if `dl_device` names a different device, since
    /// cross-device copies are not implemented. Before export, the producer stream is
    /// synchronized on a best-effort basis (its error code is ignored). The
    /// consumer-supplied `stream` argument is not used. Afterwards this object holds an
    /// empty 0x0 buffer.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<PyObject>,
        max_version: Option<PyObject>,
        dl_device: Option<PyObject>,
        copy: Option<PyObject>,
    ) -> PyResult<pyo3::PyObject> {
        use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;

        let (kdl, alloc_dev) = self.__dlpack_device__();
        if let Some(dev_obj) = dl_device.as_ref() {
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(PyValueError::new_err(
                            "device copy not implemented for __dlpack__",
                        ));
                    } else {
                        return Err(PyValueError::new_err("dl_device mismatch for __dlpack__"));
                    }
                }
            }
        }

        // SAFETY: `_stream` is the raw stream handle recorded by the factory function
        // that produced `inner`. Synchronizing it ensures the data is complete before a
        // consumer reads it.
        unsafe {
            let st = self._stream as cust::sys::CUstream;
            let _ = cust::sys::cuStreamSynchronize(st);
        }
        let _ = stream;

        // Swap in an empty buffer so this object stays valid after ownership moves out.
        let dummy =
            DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let inner = std::mem::replace(
            &mut self.inner,
            crate::cuda::moving_averages::DeviceArrayF32 {
                buf: dummy,
                rows: 0,
                cols: 0,
            },
        );
        let rows = inner.rows;
        let cols = inner.cols;
        let buf = inner.buf;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
1482
/// Rust-side constructor used by the CUDA factory bindings.
#[cfg(all(feature = "python", feature = "cuda"))]
impl DeviceArrayF32HmaPy {
    /// Bundles a device array with the CUDA context, device ordinal and raw stream
    /// handle it was produced on.
    pub fn new(
        inner: crate::cuda::moving_averages::DeviceArrayF32,
        ctx_guard: std::sync::Arc<cust::context::Context>,
        device_id: u32,
        stream_u64: u64,
    ) -> Self {
        Self {
            _stream: stream_u64,
            _device_id: device_id,
            _ctx_guard: ctx_guard,
            inner,
        }
    }
}
1499
1500#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1501use serde::{Deserialize, Serialize};
1502#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1503use wasm_bindgen::prelude::*;
1504
1505#[inline]
1506pub fn hma_into_slice(dst: &mut [f64], input: &HmaInput, kern: Kernel) -> Result<(), HmaError> {
1507    let data: &[f64] = input.as_ref();
1508
1509    if dst.len() != data.len() {
1510        return Err(HmaError::OutputLengthMismatch {
1511            expected: data.len(),
1512            got: dst.len(),
1513        });
1514    }
1515
1516    hma_with_kernel_into(input, kern, dst)
1517}
1518
/// JS binding: computes HMA for `data` with a single `period` and returns a new array.
///
/// The first `first_non_nan + period + floor(sqrt(period)) - 2` entries are NaN.
/// Fails with "All NaN" if no value is finite-or-infinite (non-NaN). Fails with
/// "Invalid or insufficient data" if `period == 0` or fewer than
/// `period + floor(sqrt(period)) - 1` values follow the first non-NaN.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn hma_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = HmaInput::from_slice(
        data,
        HmaParams {
            period: Some(period),
        },
    );

    let first = match data.iter().position(|v| !v.is_nan()) {
        Some(idx) => idx,
        None => return Err(JsValue::from_str("All NaN")),
    };

    let root = (period as f64).sqrt().floor() as usize;
    let available = data.len() - first;
    // `period == 0` is tested first so `period + root - 1` cannot underflow.
    let unusable = period == 0 || root == 0 || available < period + root - 1;
    if unusable {
        return Err(JsValue::from_str("Invalid or insufficient data"));
    }

    let warmup = first + period + root - 2;
    let mut values = alloc_with_nan_prefix(data.len(), warmup);
    hma_into_slice(&mut values, &input, Kernel::Auto)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;
    Ok(values)
}
1542
/// Configuration object accepted by the JS `hma_batch` binding
/// (deserialized with `serde_wasm_bindgen`).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct HmaBatchConfig {
    /// Period sweep as `(start, end, step)`; forwarded unchanged to `HmaBatchRange::period`.
    pub period_range: (usize, usize, usize),
}
1548
/// Result object returned by the JS `hma_batch` binding (serialized with
/// `serde_wasm_bindgen`). Field names form part of the JS-facing shape.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct HmaBatchJsOutput {
    /// Flattened `rows x cols` matrix, one row per entry of `combos`, row after row.
    pub values: Vec<f64>,
    /// Parameters used for each row, in row order.
    pub combos: Vec<HmaParams>,
    /// Number of rows (parameter combinations).
    pub rows: usize,
    /// Number of columns (input length).
    pub cols: usize,
}
1557
/// JS binding (`hma_batch`): runs an HMA period sweep described by a config object.
///
/// `config` must deserialize into `HmaBatchConfig`. Returns an `HmaBatchJsOutput`
/// object with the flattened values, per-row parameters and matrix dimensions.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = hma_batch)]
pub fn hma_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let HmaBatchConfig { period_range } =
        serde_wasm_bindgen::from_value::<HmaBatchConfig>(config)
            .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    // This binding is only compiled for wasm32, where the scalar batch kernel is used.
    let batch = hma_batch_inner(
        data,
        &HmaBatchRange {
            period: period_range,
        },
        Kernel::ScalarBatch,
        false,
    )
    .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let payload = HmaBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };

    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
1587
/// JS binding: runs an HMA period sweep and returns only the flattened `rows x cols`
/// values, one row per period, row after row.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn hma_batch_js(
    data: &[f64],
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let range = HmaBatchRange {
        period: (period_start, period_end, period_step),
    };

    // This binding is only compiled for wasm32, where the scalar batch kernel is used.
    hma_batch_inner(data, &range, Kernel::ScalarBatch, false)
        .map(|batch| batch.values)
        .map_err(|e| JsValue::from_str(&e.to_string()))
}
1611
/// JS binding: returns the period used for each row of an HMA batch sweep, as `f64`.
///
/// Delegates to `expand_grid`, the same expansion the batch functions use. The
/// metadata therefore always lines up with the rows they produce; the previous
/// hand-rolled expansion could disagree for edge-case ranges. An empty result means
/// the range expands to no periods (the batch functions reject such a range with
/// `InvalidRange`).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn hma_batch_metadata_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Vec<f64> {
    let sweep = HmaBatchRange {
        period: (period_start, period_end, period_step),
    };

    expand_grid(&sweep)
        .iter()
        .map(|combo| combo.period.unwrap() as f64)
        .collect()
}
1627
/// JS helper: reserves uninitialized space for `len` `f64`s in wasm linear memory and
/// returns a pointer to it. The allocation is intentionally leaked to the caller and
/// should be released with `hma_free(ptr, len)`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn hma_alloc(len: usize) -> *mut f64 {
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1636
/// JS helper: releases a buffer previously returned by `hma_alloc(len)`.
///
/// A null `ptr` is ignored. `ptr`/`len` must be exactly the pair passed to and
/// returned from `hma_alloc`, and the buffer must not be used afterwards.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn hma_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` came from a `Vec::<f64>::with_capacity(len)` leaked by `hma_alloc`,
    // so capacity `len` matches the allocation. The length is 0 because the buffer may
    // never have been written: `Vec::from_raw_parts` requires the first `length`
    // elements to be initialized, and 0 makes no such claim. Dropping the Vec only
    // frees the allocation.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
1646
/// JS binding: raw-pointer HMA for callers that manage wasm memory themselves.
///
/// Reads `len` values from `in_ptr` and writes `len` HMA values to `out_ptr`. The two
/// buffers may be identical or partially overlapping. In that case the result is
/// computed into a scratch vector and copied out afterwards, so the input is never read
/// through memory that is being written.
///
/// # Errors
/// Returns a JS error string for a null pointer, for `period == 0 || period > len`,
/// or for any `HmaError` raised by the computation.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn hma_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to hma_into"));
    }
    if period == 0 || period > len {
        return Err(JsValue::from_str("Invalid period"));
    }

    // Any intersection of the two byte ranges means `out` aliases `data`. Checking only
    // `in_ptr == out_ptr` would miss partial overlaps.
    let span = len.saturating_mul(std::mem::size_of::<f64>());
    let (in_lo, out_lo) = (in_ptr as usize, out_ptr as usize);
    let overlaps = in_lo < out_lo.saturating_add(span) && out_lo < in_lo.saturating_add(span);

    // SAFETY: the caller guarantees `in_ptr` is valid for `len` reads and `out_ptr` for
    // `len` writes. When the buffers overlap, the mutable `out` view is only created
    // after the last read through `data`.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let params = HmaParams {
            period: Some(period),
        };
        let input = HmaInput::from_slice(data, params);

        if overlaps {
            let mut temp = vec![0.0; len];
            hma_into_slice(&mut temp, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;

            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            out.copy_from_slice(&temp);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            hma_into_slice(out, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(())
}
1687
/// JS binding: raw-pointer HMA period sweep written into a caller-managed buffer.
///
/// Reads `len` values from `in_ptr`. Writes a row-major `rows x len` matrix to
/// `out_ptr`, where `rows` is the number of periods in
/// `(period_start, period_end, period_step)`, and returns `rows`. `out_ptr` must be
/// valid for `rows * len` `f64` writes. If the output region overlaps the input, the
/// input is copied to scratch first, because the NaN prefixes and row writes would
/// otherwise clobber values still to be read.
///
/// # Errors
/// Returns a JS error string for a null pointer, if `rows * len` overflows `usize`,
/// or for any `HmaError` raised by the computation.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn hma_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to hma_batch_into"));
    }

    let sweep = HmaBatchRange {
        period: (period_start, period_end, period_step),
    };
    let rows = expand_grid(&sweep).len();
    let cols = len;
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow in hma_batch_into"))?;

    let elem = std::mem::size_of::<f64>();
    let (in_lo, out_lo) = (in_ptr as usize, out_ptr as usize);
    let overlaps = in_lo < out_lo.saturating_add(total.saturating_mul(elem))
        && out_lo < in_lo.saturating_add(len.saturating_mul(elem));

    // SAFETY: the caller guarantees `in_ptr` is valid for `len` reads and `out_ptr` for
    // `rows * len` writes. On overlap the input is copied out before `out` is created,
    // so the computation never reads through the region it writes.
    unsafe {
        let input = std::slice::from_raw_parts(in_ptr, len);
        let scratch: Vec<f64>;
        let data: &[f64] = if overlaps {
            scratch = input.to_vec();
            &scratch
        } else {
            input
        };

        let out = std::slice::from_raw_parts_mut(out_ptr, total);

        // Only compiled for wasm32, where the scalar batch kernel is the one used.
        hma_batch_inner_into(data, &sweep, Kernel::ScalarBatch, false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(rows)
}
1727
/// Registers the HMA Python functions and classes on module `m`.
///
/// Always adds `hma`, `hma_batch` and `HmaStream`. With the `cuda` feature it also adds
/// `DeviceArrayF32Hma`, `hma_cuda_batch_dev` and `hma_cuda_many_series_one_param_dev`.
#[cfg(feature = "python")]
pub fn register_hma_module(m: &Bound<'_, pyo3::types::PyModule>) -> PyResult<()> {
    let single = wrap_pyfunction!(hma_py, m)?;
    m.add_function(single)?;
    let sweep = wrap_pyfunction!(hma_batch_py, m)?;
    m.add_function(sweep)?;
    m.add_class::<HmaStreamPy>()?;

    #[cfg(feature = "cuda")]
    {
        m.add_class::<DeviceArrayF32HmaPy>()?;
        let gpu_sweep = wrap_pyfunction!(hma_cuda_batch_dev_py, m)?;
        m.add_function(gpu_sweep)?;
        let gpu_many = wrap_pyfunction!(hma_cuda_many_series_one_param_dev_py, m)?;
        m.add_function(gpu_many)?;
    }

    Ok(())
}
1741
1742#[cfg(test)]
1743mod tests {
1744    use super::*;
1745    use crate::skip_if_unsupported;
1746    use crate::utilities::data_loader::read_candles_from_csv;
1747    use proptest::prelude::*;
1748
1749    #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1750    #[test]
1751    fn test_hma_into_matches_api() -> Result<(), Box<dyn Error>> {
1752        let data: Vec<f64> = (0..512)
1753            .map(|i| ((i as f64).sin() * 123.456789) + (i as f64) * 0.25)
1754            .collect();
1755
1756        let input = HmaInput::from_slice(&data, HmaParams::default());
1757
1758        let baseline = hma(&input)?.values;
1759
1760        let mut out = vec![0.0; data.len()];
1761        hma_into(&input, &mut out)?;
1762
1763        assert_eq!(baseline.len(), out.len());
1764        for (a, b) in baseline.iter().zip(out.iter()) {
1765            let equal = (a.is_nan() && b.is_nan()) || (a == b);
1766            assert!(equal, "Mismatch: a={:?}, b={:?}", a, b);
1767        }
1768
1769        Ok(())
1770    }
1771
1772    fn check_hma_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1773        skip_if_unsupported!(kernel, test_name);
1774        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1775        let candles = read_candles_from_csv(file_path)?;
1776        let default_params = HmaParams { period: None };
1777        let input_default = HmaInput::from_candles(&candles, "close", default_params);
1778        let output_default = hma_with_kernel(&input_default, kernel)?;
1779        assert_eq!(output_default.values.len(), candles.close.len());
1780        Ok(())
1781    }
1782
1783    fn check_hma_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1784        skip_if_unsupported!(kernel, test_name);
1785        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1786        let candles = read_candles_from_csv(file_path)?;
1787        let input = HmaInput::with_default_candles(&candles);
1788        let result = hma_with_kernel(&input, kernel)?;
1789        let expected_last_five = [
1790            59334.13333336847,
1791            59201.4666667018,
1792            59047.77777781293,
1793            59048.71111114628,
1794            58803.44444447962,
1795        ];
1796        assert!(result.values.len() >= 5);
1797        assert_eq!(result.values.len(), candles.close.len());
1798        let start = result.values.len() - 5;
1799        let last_five = &result.values[start..];
1800        for (i, &val) in last_five.iter().enumerate() {
1801            let exp = expected_last_five[i];
1802            assert!(
1803                (val - exp).abs() < 1e-3,
1804                "[{}] idx {}: got {}, expected {}",
1805                test_name,
1806                i,
1807                val,
1808                exp
1809            );
1810        }
1811        Ok(())
1812    }
1813
1814    fn check_hma_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1815        skip_if_unsupported!(kernel, test_name);
1816        let input_data = [10.0, 20.0, 30.0];
1817        let params = HmaParams { period: Some(0) };
1818        let input = HmaInput::from_slice(&input_data, params);
1819        let result = hma_with_kernel(&input, kernel);
1820        assert!(
1821            result.is_err(),
1822            "[{}] HMA should fail with zero period",
1823            test_name
1824        );
1825        Ok(())
1826    }
1827
1828    fn check_hma_period_exceeds_length(
1829        test_name: &str,
1830        kernel: Kernel,
1831    ) -> Result<(), Box<dyn Error>> {
1832        skip_if_unsupported!(kernel, test_name);
1833        let input_data = [10.0, 20.0, 30.0];
1834        let params = HmaParams { period: Some(10) };
1835        let input = HmaInput::from_slice(&input_data, params);
1836        let result = hma_with_kernel(&input, kernel);
1837        assert!(
1838            result.is_err(),
1839            "[{}] HMA should fail with period exceeding length",
1840            test_name
1841        );
1842        Ok(())
1843    }
1844
1845    fn check_hma_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1846        skip_if_unsupported!(kernel, test_name);
1847        let input_data = [42.0];
1848        let params = HmaParams { period: Some(5) };
1849        let input = HmaInput::from_slice(&input_data, params);
1850        let result = hma_with_kernel(&input, kernel);
1851        assert!(
1852            result.is_err(),
1853            "[{}] HMA should fail with insufficient data",
1854            test_name
1855        );
1856        Ok(())
1857    }
1858
1859    fn check_hma_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1860        skip_if_unsupported!(kernel, test_name);
1861        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1862        let candles = read_candles_from_csv(file_path)?;
1863        let first_params = HmaParams { period: Some(5) };
1864        let first_input = HmaInput::from_candles(&candles, "close", first_params);
1865        let first_result = hma_with_kernel(&first_input, kernel)?;
1866        let second_params = HmaParams { period: Some(3) };
1867        let second_input = HmaInput::from_slice(&first_result.values, second_params);
1868        let second_result = hma_with_kernel(&second_input, kernel)?;
1869        assert_eq!(second_result.values.len(), first_result.values.len());
1870        if second_result.values.len() > 240 {
1871            for i in 240..second_result.values.len() {
1872                assert!(!second_result.values[i].is_nan());
1873            }
1874        }
1875        Ok(())
1876    }
1877
1878    fn check_hma_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1879        skip_if_unsupported!(kernel, test_name);
1880        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1881        let candles = read_candles_from_csv(file_path)?;
1882        let params = HmaParams::default();
1883        let period = params.period.unwrap_or(5) * 2;
1884        let input = HmaInput::from_candles(&candles, "close", params);
1885        let result = hma_with_kernel(&input, kernel)?;
1886        assert_eq!(result.values.len(), candles.close.len());
1887        if result.values.len() > period {
1888            for i in period..result.values.len() {
1889                assert!(!result.values[i].is_nan());
1890            }
1891        }
1892        Ok(())
1893    }
1894
1895    fn check_hma_empty_input(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1896        skip_if_unsupported!(kernel, test_name);
1897        let empty: [f64; 0] = [];
1898        let input = HmaInput::from_slice(&empty, HmaParams::default());
1899        let res = hma_with_kernel(&input, kernel);
1900        assert!(
1901            matches!(res, Err(HmaError::NoData)),
1902            "[{}] expected NoData",
1903            test_name
1904        );
1905        Ok(())
1906    }
1907
1908    fn check_hma_not_enough_valid(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1909        skip_if_unsupported!(kernel, test_name);
1910        let data = [f64::NAN, f64::NAN, 1.0, 2.0];
1911        let params = HmaParams { period: Some(3) };
1912        let input = HmaInput::from_slice(&data, params);
1913        let res = hma_with_kernel(&input, kernel);
1914        assert!(
1915            matches!(res, Err(HmaError::NotEnoughValidData { .. })),
1916            "[{}] expected NotEnoughValidData",
1917            test_name
1918        );
1919        Ok(())
1920    }
1921
1922    fn check_hma_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1923        skip_if_unsupported!(kernel, test_name);
1924        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1925        let candles = read_candles_from_csv(file_path)?;
1926        let period = 5;
1927        let input = HmaInput::from_candles(
1928            &candles,
1929            "close",
1930            HmaParams {
1931                period: Some(period),
1932            },
1933        );
1934        let batch_output = hma_with_kernel(&input, kernel)?.values;
1935
1936        let mut stream = HmaStream::try_new(HmaParams {
1937            period: Some(period),
1938        })?;
1939        let mut stream_vals = Vec::with_capacity(candles.close.len());
1940        for &p in &candles.close {
1941            match stream.update(p) {
1942                Some(v) => stream_vals.push(v),
1943                None => stream_vals.push(f64::NAN),
1944            }
1945        }
1946
1947        assert_eq!(batch_output.len(), stream_vals.len());
1948        for (i, (&b, &s)) in batch_output.iter().zip(stream_vals.iter()).enumerate() {
1949            if b.is_nan() && s.is_nan() {
1950                continue;
1951            }
1952            let diff = (b - s).abs();
1953            assert!(
1954                diff < 1e-4,
1955                "[{}] HMA streaming mismatch at idx {}: batch={}, stream={}, diff={}",
1956                test_name,
1957                i,
1958                b,
1959                s,
1960                diff
1961            );
1962        }
1963        Ok(())
1964    }
1965
1966    fn check_hma_property(
1967        test_name: &str,
1968        kernel: Kernel,
1969    ) -> Result<(), Box<dyn std::error::Error>> {
1970        skip_if_unsupported!(kernel, test_name);
1971
1972        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1973        let candles = read_candles_from_csv(file_path)?;
1974        let close_data = &candles.close;
1975
1976        let strat = (
1977            2usize..=100,
1978            0usize..close_data.len().saturating_sub(500),
1979            200usize..=500,
1980        );
1981
1982        proptest::test_runner::TestRunner::default()
1983            .run(&strat, |(period, start_idx, slice_len)| {
1984                let end_idx = (start_idx + slice_len).min(close_data.len());
1985                if end_idx <= start_idx || end_idx - start_idx < period + 10 {
1986                    return Ok(());
1987                }
1988
1989                let data_slice = &close_data[start_idx..end_idx];
1990                let params = HmaParams {
1991                    period: Some(period),
1992                };
1993                let input = HmaInput::from_slice(data_slice, params);
1994
1995                let result = hma_with_kernel(&input, kernel);
1996
1997                let scalar_result = hma_with_kernel(&input, Kernel::Scalar);
1998
1999                match (result, scalar_result) {
2000                    (Ok(HmaOutput { values: out }), Ok(HmaOutput { values: ref_out })) => {
2001                        prop_assert_eq!(out.len(), data_slice.len());
2002                        prop_assert_eq!(ref_out.len(), data_slice.len());
2003
2004                        let sqrt_period = (period as f64).sqrt().floor() as usize;
2005                        let expected_warmup = period + sqrt_period - 1;
2006
2007                        let first_valid = out.iter().position(|x| !x.is_nan());
2008                        if let Some(first_idx) = first_valid {
2009                            prop_assert!(
2010                                first_idx >= expected_warmup - 1,
2011                                "First valid at {} but expected warmup is {}",
2012                                first_idx,
2013                                expected_warmup
2014                            );
2015
2016                            for i in 0..first_idx {
2017                                prop_assert!(
2018                                    out[i].is_nan(),
2019                                    "Expected NaN at index {} during warmup, got {}",
2020                                    i,
2021                                    out[i]
2022                                );
2023                            }
2024                        }
2025
2026                        for i in 0..out.len() {
2027                            let y = out[i];
2028                            let r = ref_out[i];
2029
2030                            if y.is_nan() {
2031                                prop_assert!(
2032                                    r.is_nan(),
2033                                    "Kernel mismatch at {}: {} vs {}",
2034                                    i,
2035                                    y,
2036                                    r
2037                                );
2038                                continue;
2039                            }
2040
2041                            prop_assert!(y.is_finite(), "Non-finite value at index {}: {}", i, y);
2042
2043                            let y_bits = y.to_bits();
2044                            let r_bits = r.to_bits();
2045                            let ulp_diff = y_bits.abs_diff(r_bits);
2046
2047                            let ulp_tolerance = if matches!(kernel, Kernel::Avx512) {
2048                                20000
2049                            } else {
2050                                8
2051                            };
2052                            prop_assert!(
2053                                (y - r).abs() <= 1e-8 || ulp_diff <= ulp_tolerance,
2054                                "Kernel mismatch at {}: {} vs {} (ULP={})",
2055                                i,
2056                                y,
2057                                r,
2058                                ulp_diff
2059                            );
2060                        }
2061
2062                        for i in expected_warmup..out.len() {
2063                            let y = out[i];
2064                            if y.is_nan() {
2065                                continue;
2066                            }
2067
2068                            prop_assert!(y.is_finite(), "HMA output at {} is not finite: {}", i, y);
2069
2070                            if i >= period * 2 {
2071                                let window_start = i.saturating_sub(period);
2072                                let window = &data_slice[window_start..=i];
2073                                let is_constant =
2074                                    window.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10);
2075                                if is_constant {
2076                                    let constant_val = window[0];
2077                                    prop_assert!(
2078										(y - constant_val).abs() <= 1e-6,
2079										"HMA should converge to {} for constant data, got {} at index {}",
2080										constant_val,
2081										y,
2082										i
2083									);
2084                                }
2085                            }
2086                        }
2087
2088                        if period == 2 {
2089                            let min_valid_idx = expected_warmup;
2090                            if out.len() > min_valid_idx {
2091                                prop_assert!(
2092                                    out[min_valid_idx].is_finite(),
2093                                    "HMA with period=2 should produce valid output at index {}",
2094                                    min_valid_idx
2095                                );
2096                            }
2097                        }
2098
2099                        Ok(())
2100                    }
2101                    (Err(e1), Err(_e2)) => {
2102                        prop_assert!(
2103                            format!("{:?}", e1).contains("NotEnoughValidData")
2104                                || format!("{:?}", e1).contains("InvalidPeriod"),
2105                            "Unexpected error type: {:?}",
2106                            e1
2107                        );
2108                        Ok(())
2109                    }
2110                    (Ok(_), Err(e)) | (Err(e), Ok(_)) => {
2111                        prop_assert!(
2112                            false,
2113                            "Kernel consistency failure: one succeeded, one failed with {:?}",
2114                            e
2115                        );
2116                        Ok(())
2117                    }
2118                }
2119            })
2120            .unwrap();
2121
2122        let edge_cases = vec![
2123            (vec![1.0, 2.0, 3.0, 4.0, 5.0], 2),
2124            (vec![42.0; 100], 10),
2125            ((0..100).map(|i| i as f64).collect::<Vec<_>>(), 15),
2126            ((0..100).map(|i| 100.0 - i as f64).collect::<Vec<_>>(), 20),
2127        ];
2128
2129        for (case_idx, (data, period)) in edge_cases.into_iter().enumerate() {
2130            let params = HmaParams {
2131                period: Some(period),
2132            };
2133            let input = HmaInput::from_slice(&data, params);
2134
2135            match hma_with_kernel(&input, kernel) {
2136                Ok(out) => {
2137                    let has_valid = out.values.iter().any(|&x| x.is_finite() && !x.is_nan());
2138                    assert!(
2139                        has_valid || data.len() < period + 2,
2140                        "[{}] Edge case {} produced no valid output",
2141                        test_name,
2142                        case_idx
2143                    );
2144                }
2145                Err(_) => {}
2146            }
2147        }
2148
2149        Ok(())
2150    }
2151
2152    #[cfg(debug_assertions)]
2153    fn check_hma_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2154        skip_if_unsupported!(kernel, test_name);
2155
2156        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2157        let candles = read_candles_from_csv(file_path)?;
2158
2159        let test_params = vec![
2160            HmaParams::default(),
2161            HmaParams { period: Some(2) },
2162            HmaParams { period: Some(3) },
2163            HmaParams { period: Some(4) },
2164            HmaParams { period: Some(5) },
2165            HmaParams { period: Some(7) },
2166            HmaParams { period: Some(10) },
2167            HmaParams { period: Some(14) },
2168            HmaParams { period: Some(20) },
2169            HmaParams { period: Some(30) },
2170            HmaParams { period: Some(50) },
2171            HmaParams { period: Some(100) },
2172            HmaParams { period: Some(200) },
2173            HmaParams { period: Some(1) },
2174            HmaParams { period: Some(250) },
2175        ];
2176
2177        for (param_idx, params) in test_params.iter().enumerate() {
2178            let input = HmaInput::from_candles(&candles, "close", params.clone());
2179            let output = hma_with_kernel(&input, kernel)?;
2180
2181            for (i, &val) in output.values.iter().enumerate() {
2182                if val.is_nan() {
2183                    continue;
2184                }
2185
2186                let bits = val.to_bits();
2187
2188                if bits == 0x11111111_11111111 {
2189                    panic!(
2190                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
2191                        with params: period={}",
2192                        test_name,
2193                        val,
2194                        bits,
2195                        i,
2196                        params.period.unwrap_or(5)
2197                    );
2198                }
2199
2200                if bits == 0x22222222_22222222 {
2201                    panic!(
2202                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
2203                        with params: period={}",
2204                        test_name,
2205                        val,
2206                        bits,
2207                        i,
2208                        params.period.unwrap_or(5)
2209                    );
2210                }
2211
2212                if bits == 0x33333333_33333333 {
2213                    panic!(
2214                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
2215                        with params: period={}",
2216                        test_name,
2217                        val,
2218                        bits,
2219                        i,
2220                        params.period.unwrap_or(5)
2221                    );
2222                }
2223            }
2224        }
2225
2226        Ok(())
2227    }
2228
2229    #[cfg(not(debug_assertions))]
2230    fn check_hma_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2231        Ok(())
2232    }
2233
2234    macro_rules! generate_all_hma_tests {
2235        ($($test_fn:ident),*) => {
2236            paste::paste! {
2237                $(
2238                    #[test]
2239                    fn [<$test_fn _scalar_f64>]() {
2240                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
2241                    }
2242                )*
2243                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2244                $(
2245                    #[test]
2246                    fn [<$test_fn _avx2_f64>]() {
2247                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
2248                    }
2249                    #[test]
2250                    fn [<$test_fn _avx512_f64>]() {
2251                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
2252                    }
2253                )*
2254            }
2255        }
2256    }
2257
2258    generate_all_hma_tests!(
2259        check_hma_partial_params,
2260        check_hma_accuracy,
2261        check_hma_zero_period,
2262        check_hma_period_exceeds_length,
2263        check_hma_very_small_dataset,
2264        check_hma_reinput,
2265        check_hma_nan_handling,
2266        check_hma_empty_input,
2267        check_hma_not_enough_valid,
2268        check_hma_streaming,
2269        check_hma_property,
2270        check_hma_no_poison
2271    );
2272
2273    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2274        skip_if_unsupported!(kernel, test);
2275        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2276        let c = read_candles_from_csv(file)?;
2277        let output = HmaBatchBuilder::new()
2278            .kernel(kernel)
2279            .apply_candles(&c, "close")?;
2280        let def = HmaParams::default();
2281        let row = output.values_for(&def).expect("default row missing");
2282        assert_eq!(row.len(), c.close.len());
2283        let expected = [
2284            59334.13333336847,
2285            59201.4666667018,
2286            59047.77777781293,
2287            59048.71111114628,
2288            58803.44444447962,
2289        ];
2290        let start = row.len() - 5;
2291        for (i, &v) in row[start..].iter().enumerate() {
2292            assert!(
2293                (v - expected[i]).abs() < 1e-3,
2294                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
2295            );
2296        }
2297        Ok(())
2298    }
2299
2300    #[cfg(debug_assertions)]
2301    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2302        skip_if_unsupported!(kernel, test);
2303
2304        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2305        let c = read_candles_from_csv(file)?;
2306
2307        let test_configs = vec![
2308            (2, 5, 1),
2309            (5, 25, 5),
2310            (10, 50, 10),
2311            (2, 4, 1),
2312            (50, 150, 25),
2313            (10, 30, 2),
2314            (10, 30, 10),
2315            (100, 300, 50),
2316        ];
2317
2318        for (cfg_idx, &(p_start, p_end, p_step)) in test_configs.iter().enumerate() {
2319            let output = HmaBatchBuilder::new()
2320                .kernel(kernel)
2321                .period_range(p_start, p_end, p_step)
2322                .apply_candles(&c, "close")?;
2323
2324            for (idx, &val) in output.values.iter().enumerate() {
2325                if val.is_nan() {
2326                    continue;
2327                }
2328
2329                let bits = val.to_bits();
2330                let row = idx / output.cols;
2331                let col = idx % output.cols;
2332                let combo = &output.combos[row];
2333
2334                if bits == 0x11111111_11111111 {
2335                    panic!(
2336                        "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
2337                        at row {} col {} (flat index {}) with params: period={}",
2338                        test,
2339                        cfg_idx,
2340                        val,
2341                        bits,
2342                        row,
2343                        col,
2344                        idx,
2345                        combo.period.unwrap_or(5)
2346                    );
2347                }
2348
2349                if bits == 0x22222222_22222222 {
2350                    panic!(
2351                        "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
2352                        at row {} col {} (flat index {}) with params: period={}",
2353                        test,
2354                        cfg_idx,
2355                        val,
2356                        bits,
2357                        row,
2358                        col,
2359                        idx,
2360                        combo.period.unwrap_or(5)
2361                    );
2362                }
2363
2364                if bits == 0x33333333_33333333 {
2365                    panic!(
2366                        "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
2367                        at row {} col {} (flat index {}) with params: period={}",
2368                        test,
2369                        cfg_idx,
2370                        val,
2371                        bits,
2372                        row,
2373                        col,
2374                        idx,
2375                        combo.period.unwrap_or(5)
2376                    );
2377                }
2378            }
2379        }
2380
2381        Ok(())
2382    }
2383
2384    #[cfg(not(debug_assertions))]
2385    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2386        Ok(())
2387    }
2388
2389    macro_rules! gen_batch_tests {
2390        ($fn_name:ident) => {
2391            paste::paste! {
2392                #[test] fn [<$fn_name _scalar>]()      {
2393                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
2394                }
2395                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2396                #[test] fn [<$fn_name _avx2>]()        {
2397                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
2398                }
2399                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2400                #[test] fn [<$fn_name _avx512>]()      {
2401                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
2402                }
2403                #[test] fn [<$fn_name _auto_detect>]() {
2404                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]),
2405                                     Kernel::Auto);
2406                }
2407            }
2408        };
2409    }
2410    gen_batch_tests!(check_batch_default_row);
2411    gen_batch_tests!(check_batch_no_poison);
2412}