Skip to main content

vector_ta/indicators/moving_averages/
dma.rs

1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::cuda_available;
3#[cfg(all(feature = "python", feature = "cuda"))]
4use crate::cuda::moving_averages::CudaDma;
5#[cfg(all(feature = "python", feature = "cuda"))]
6use crate::cuda::moving_averages::DeviceArrayF32;
7#[cfg(all(feature = "python", feature = "cuda"))]
8use cust::context::Context;
9#[cfg(all(feature = "python", feature = "cuda"))]
10use numpy::PyUntypedArrayMethods;
11#[cfg(feature = "python")]
12use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1, PyReadonlyArray2};
13#[cfg(feature = "python")]
14use pyo3::exceptions::PyValueError;
15#[cfg(feature = "python")]
16use pyo3::prelude::*;
17#[cfg(feature = "python")]
18use pyo3::types::PyDict;
19#[cfg(all(feature = "python", feature = "cuda"))]
20use std::sync::Arc;
21
22#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
23use serde::{Deserialize, Serialize};
24#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
25use wasm_bindgen::prelude::*;
26
27use crate::utilities::data_loader::{source_type, Candles};
28use crate::utilities::enums::Kernel;
29use crate::utilities::helpers::{
30    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
31    make_uninit_matrix,
32};
33#[cfg(feature = "python")]
34use crate::utilities::kernel_validation::validate_kernel;
35use aligned_vec::{AVec, CACHELINE_ALIGN};
36
37#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
38use core::arch::x86_64::*;
39
40#[cfg(not(target_arch = "wasm32"))]
41use rayon::prelude::*;
42
43use std::convert::AsRef;
44use std::error::Error;
45use std::mem::MaybeUninit;
46use thiserror::Error;
47
48impl<'a> AsRef<[f64]> for DmaInput<'a> {
49    #[inline(always)]
50    fn as_ref(&self) -> &[f64] {
51        match &self.data {
52            DmaData::Slice(slice) => slice,
53            DmaData::Candles { candles, source } => source_type(candles, source),
54        }
55    }
56}
57
/// Source of the series DMA is computed over.
#[derive(Debug, Clone)]
pub enum DmaData<'a> {
    /// A candle set plus the name of the column to read (resolved via `source_type`).
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A raw series used directly.
    Slice(&'a [f64]),
}
66
/// Result of a DMA computation.
#[derive(Debug, Clone)]
pub struct DmaOutput {
    /// One value per input element (same length as the input series); warm-up entries are NaN.
    pub values: Vec<f64>,
}
71
/// Tunable DMA parameters.
///
/// Every field is optional; `None` is resolved to the same value `DmaParams::default()`
/// uses (see the `DmaInput` getters).
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct DmaParams {
    /// Window of the Hull-style average (default 7).
    pub hull_length: Option<usize>,
    /// Period of the base EMA and of the error-correcting EMA (default 20).
    pub ema_length: Option<usize>,
    /// Largest gain index searched by the error-correcting EMA; gains step by 0.1 (default 50).
    pub ema_gain_limit: Option<usize>,
    /// Averaging used inside the Hull component: `"WMA"` or `"EMA"` (default `"WMA"`).
    pub hull_ma_type: Option<String>,
}

impl Default for DmaParams {
    fn default() -> Self {
        DmaParams {
            hull_ma_type: Some(String::from("WMA")),
            ema_gain_limit: Some(50),
            ema_length: Some(20),
            hull_length: Some(7),
        }
    }
}
94
95#[derive(Debug, Clone)]
96pub struct DmaInput<'a> {
97    pub data: DmaData<'a>,
98    pub params: DmaParams,
99}
100
101impl<'a> DmaInput<'a> {
102    #[inline]
103    pub fn from_candles(c: &'a Candles, s: &'a str, p: DmaParams) -> Self {
104        Self {
105            data: DmaData::Candles {
106                candles: c,
107                source: s,
108            },
109            params: p,
110        }
111    }
112
113    #[inline]
114    pub fn from_slice(sl: &'a [f64], p: DmaParams) -> Self {
115        Self {
116            data: DmaData::Slice(sl),
117            params: p,
118        }
119    }
120
121    #[inline]
122    pub fn with_default_candles(c: &'a Candles) -> Self {
123        Self::from_candles(c, "close", DmaParams::default())
124    }
125
126    #[inline]
127    pub fn get_hull_length(&self) -> usize {
128        self.params.hull_length.unwrap_or(7)
129    }
130
131    #[inline]
132    pub fn get_ema_length(&self) -> usize {
133        self.params.ema_length.unwrap_or(20)
134    }
135
136    #[inline]
137    pub fn get_ema_gain_limit(&self) -> usize {
138        self.params.ema_gain_limit.unwrap_or(50)
139    }
140
141    #[inline]
142    pub fn get_hull_ma_type(&self) -> String {
143        self.params
144            .hull_ma_type
145            .clone()
146            .unwrap_or_else(|| "WMA".to_string())
147    }
148
149    #[inline]
150    pub fn hull_ma_type_str(&self) -> &str {
151        self.params.hull_ma_type.as_deref().unwrap_or("WMA")
152    }
153}
154
155#[derive(Clone, Debug)]
156pub struct DmaBuilder {
157    hull_length: Option<usize>,
158    ema_length: Option<usize>,
159    ema_gain_limit: Option<usize>,
160    hull_ma_type: Option<String>,
161    kernel: Kernel,
162}
163
164impl Default for DmaBuilder {
165    fn default() -> Self {
166        Self {
167            hull_length: None,
168            ema_length: None,
169            ema_gain_limit: None,
170            hull_ma_type: None,
171            kernel: Kernel::Auto,
172        }
173    }
174}
175
176impl DmaBuilder {
177    #[inline(always)]
178    pub fn new() -> Self {
179        Self::default()
180    }
181
182    #[inline(always)]
183    pub fn hull_length(mut self, val: usize) -> Self {
184        self.hull_length = Some(val);
185        self
186    }
187
188    #[inline(always)]
189    pub fn ema_length(mut self, val: usize) -> Self {
190        self.ema_length = Some(val);
191        self
192    }
193
194    #[inline(always)]
195    pub fn ema_gain_limit(mut self, val: usize) -> Self {
196        self.ema_gain_limit = Some(val);
197        self
198    }
199
200    #[inline(always)]
201    pub fn hull_ma_type(mut self, val: String) -> Self {
202        self.hull_ma_type = Some(val);
203        self
204    }
205
206    #[inline(always)]
207    pub fn kernel(mut self, k: Kernel) -> Self {
208        self.kernel = k;
209        self
210    }
211
212    #[inline(always)]
213    pub fn apply(self, c: &Candles) -> Result<DmaOutput, DmaError> {
214        let p = DmaParams {
215            hull_length: self.hull_length,
216            ema_length: self.ema_length,
217            ema_gain_limit: self.ema_gain_limit,
218            hull_ma_type: self.hull_ma_type,
219        };
220        let i = DmaInput::from_candles(c, "close", p);
221        dma_with_kernel(&i, self.kernel)
222    }
223
224    #[inline(always)]
225    pub fn apply_slice(self, d: &[f64]) -> Result<DmaOutput, DmaError> {
226        let p = DmaParams {
227            hull_length: self.hull_length,
228            ema_length: self.ema_length,
229            ema_gain_limit: self.ema_gain_limit,
230            hull_ma_type: self.hull_ma_type,
231        };
232        let i = DmaInput::from_slice(d, p);
233        dma_with_kernel(&i, self.kernel)
234    }
235
236    #[inline(always)]
237    pub fn into_stream(self) -> Result<DmaStream, DmaError> {
238        let p = DmaParams {
239            hull_length: self.hull_length,
240            ema_length: self.ema_length,
241            ema_gain_limit: self.ema_gain_limit,
242            hull_ma_type: self.hull_ma_type,
243        };
244        DmaStream::try_new(p)
245    }
246}
247
/// Errors produced while validating DMA inputs or writing DMA outputs.
#[derive(Debug, Error)]
pub enum DmaError {
    /// The input series has no elements.
    #[error("dma: Input data slice is empty.")]
    EmptyInputData,

    /// Every element of the input series is NaN.
    #[error("dma: All values are NaN.")]
    AllValuesNaN,

    /// Raised by `dma_prepare` when `hull_length` or `ema_length` is zero or longer than the input.
    #[error("dma: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },

    /// Fewer values remain after the first non-NaN than
    /// `max(hull_length, ema_length) + round(sqrt(hull_length))`.
    #[error("dma: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },

    /// The Hull MA type was neither `"WMA"` nor `"EMA"`.
    #[error("dma: Invalid Hull MA type: {value}. Must be 'WMA' or 'EMA'.")]
    InvalidHullMAType { value: String },

    /// A caller-provided output buffer does not match the input length.
    #[error("dma: Output slice length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },

    /// A parameter range could not be expanded; presumably raised by batch/sweep code
    /// outside this view — TODO confirm.
    #[error("dma: Invalid range expansion: start = {start}, end = {end}, step = {step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },

    /// A non-batch kernel was passed to a batch entry point (presumably; the raising code is
    /// not in this view).
    #[error("dma: Invalid kernel for batch path: {0:?}")]
    InvalidKernelForBatch(Kernel),
}
278
279#[inline(always)]
280pub fn dma(input: &DmaInput) -> Result<DmaOutput, DmaError> {
281    dma_with_kernel(input, Kernel::Auto)
282}
283
284#[inline(always)]
285pub fn dma_with_kernel(input: &DmaInput, kernel: Kernel) -> Result<DmaOutput, DmaError> {
286    let (data, hull_len, ema_len, ema_gain_limit, hull_ma_type, first, chosen) =
287        dma_prepare(input, kernel)?;
288
289    let sqrt_len = (hull_len as f64).sqrt().round() as usize;
290    let warmup_end = first + hull_len.max(ema_len) + sqrt_len - 1;
291
292    let mut out = alloc_with_nan_prefix(data.len(), warmup_end);
293    dma_compute_into(
294        data,
295        hull_len,
296        ema_len,
297        ema_gain_limit,
298        &hull_ma_type,
299        first,
300        chosen,
301        &mut out,
302    );
303    Ok(DmaOutput { values: out })
304}
305
306#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
307#[inline(always)]
308pub fn dma_into(input: &DmaInput, out: &mut [f64]) -> Result<(), DmaError> {
309    let (data, hull_len, ema_len, ema_gain_limit, hull_ma_type, first, chosen) =
310        dma_prepare(input, Kernel::Auto)?;
311
312    if out.len() != data.len() {
313        return Err(DmaError::OutputLengthMismatch {
314            expected: data.len(),
315            got: out.len(),
316        });
317    }
318
319    let sqrt_len = (hull_len as f64).sqrt().round() as usize;
320    let warmup_end = first + hull_len.max(ema_len) + sqrt_len - 1;
321    let end = warmup_end.min(out.len());
322    let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
323    for v in &mut out[..end] {
324        *v = qnan;
325    }
326
327    dma_compute_into(
328        data,
329        hull_len,
330        ema_len,
331        ema_gain_limit,
332        &hull_ma_type,
333        first,
334        chosen,
335        out,
336    );
337    Ok(())
338}
339
340#[inline(always)]
341pub fn dma_into_slice(dst: &mut [f64], input: &DmaInput, kern: Kernel) -> Result<(), DmaError> {
342    let (data, hull_len, ema_len, ema_gain_limit, hull_ma_type, first, chosen) =
343        dma_prepare(input, kern)?;
344
345    if dst.len() != data.len() {
346        return Err(DmaError::OutputLengthMismatch {
347            expected: data.len(),
348            got: dst.len(),
349        });
350    }
351
352    dma_compute_into(
353        data,
354        hull_len,
355        ema_len,
356        ema_gain_limit,
357        &hull_ma_type,
358        first,
359        chosen,
360        dst,
361    );
362
363    let sqrt_len = (hull_len as f64).sqrt().round() as usize;
364    let warmup_end = first + hull_len.max(ema_len) + sqrt_len - 1;
365    let end = warmup_end.min(dst.len());
366    for v in &mut dst[..end] {
367        *v = f64::NAN;
368    }
369    Ok(())
370}
371
372#[inline(always)]
373fn dma_prepare<'a>(
374    input: &'a DmaInput,
375    kernel: Kernel,
376) -> Result<(&'a [f64], usize, usize, usize, &'a str, usize, Kernel), DmaError> {
377    let data: &[f64] = input.as_ref();
378    let len = data.len();
379    if len == 0 {
380        return Err(DmaError::EmptyInputData);
381    }
382
383    let first = data
384        .iter()
385        .position(|x| !x.is_nan())
386        .ok_or(DmaError::AllValuesNaN)?;
387    let hull_length = input.get_hull_length();
388    let ema_length = input.get_ema_length();
389    let ema_gain_limit = input.get_ema_gain_limit();
390    let hull_ma_type = input.hull_ma_type_str();
391
392    if hull_length == 0 || hull_length > len {
393        return Err(DmaError::InvalidPeriod {
394            period: hull_length,
395            data_len: len,
396        });
397    }
398    if ema_length == 0 || ema_length > len {
399        return Err(DmaError::InvalidPeriod {
400            period: ema_length,
401            data_len: len,
402        });
403    }
404
405    let sqrt_len = (hull_length as f64).sqrt().round() as usize;
406    let needed = hull_length.max(ema_length) + sqrt_len;
407    if len - first < needed {
408        return Err(DmaError::NotEnoughValidData {
409            needed,
410            valid: len - first,
411        });
412    }
413    if hull_ma_type != "WMA" && hull_ma_type != "EMA" {
414        return Err(DmaError::InvalidHullMAType {
415            value: hull_ma_type.to_string(),
416        });
417    }
418    let chosen = match kernel {
419        Kernel::Auto => Kernel::Scalar,
420        k => k,
421    };
422    Ok((
423        data,
424        hull_length,
425        ema_length,
426        ema_gain_limit,
427        hull_ma_type,
428        first,
429        chosen,
430    ))
431}
432
/// Dispatches the DMA computation to the implementation for `kernel`, writing into `out`.
///
/// - On `wasm32` with `simd128`, scalar kernels are routed to `dma_simd128`.
/// - With `nightly-avx` on `x86_64`, AVX2 / AVX-512 kernels (batch or not) run `dma_avx2` /
///   `dma_avx512`; without that feature they fall back to `dma_scalar`.
/// - Any other `Kernel` variant (including `Kernel::Auto`, which `dma_prepare` resolves
///   beforehand) hits `unreachable!()` and panics.
#[inline(always)]
fn dma_compute_into(
    data: &[f64],
    hull_length: usize,
    ema_length: usize,
    ema_gain_limit: usize,
    hull_ma_type: &str,
    first: usize,
    kernel: Kernel,
    out: &mut [f64],
) {
    // `unsafe` is required for the `unsafe fn` SIMD kernels. NOTE(review): nothing here
    // checks that the CPU actually supports AVX2/AVX-512 — confirm callers enforce this.
    unsafe {
        #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
        {
            if matches!(kernel, Kernel::Scalar | Kernel::ScalarBatch) {
                dma_simd128(
                    data,
                    hull_length,
                    ema_length,
                    ema_gain_limit,
                    hull_ma_type,
                    first,
                    out,
                );
                return;
            }
        }

        match kernel {
            Kernel::Scalar | Kernel::ScalarBatch => dma_scalar(
                data,
                hull_length,
                ema_length,
                ema_gain_limit,
                hull_ma_type,
                first,
                out,
            ),
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx2 | Kernel::Avx2Batch => dma_avx2(
                data,
                hull_length,
                ema_length,
                ema_gain_limit,
                hull_ma_type,
                first,
                out,
            ),
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx512 | Kernel::Avx512Batch => dma_avx512(
                data,
                hull_length,
                ema_length,
                ema_gain_limit,
                hull_ma_type,
                first,
                out,
            ),
            // Without the nightly AVX feature, SIMD kernel requests degrade to scalar.
            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => dma_scalar(
                data,
                hull_length,
                ema_length,
                ema_gain_limit,
                hull_ma_type,
                first,
                out,
            ),
            _ => unreachable!(),
        }
    }
}
505
/// Scalar reference implementation of the DMA kernel.
///
/// For each index `i` in `first..data.len()` it tracks:
/// * `e0` — an EMA of `data` with period `ema_length`, seeded with the SMA of the first
///   `ema_length` values starting at `first`;
/// * `ec` — an error-correcting EMA `alpha*(e0 + g*(x - ec_prev)) + (1-alpha)*ec_prev`, where
///   the gain `g` is chosen from `{0.0, 0.1, …, ema_gain_limit * 0.1}` to best match `x`;
/// * `hull` — `2*MA(hull_length/2) - MA(hull_length)` smoothed over
///   `round(sqrt(hull_length))` samples, where MA is a WMA when `hull_ma_type == "WMA"` and
///   an SMA-seeded EMA otherwise.
///
/// It writes `out[i] = 0.5 * (hull + ec)`. Every slot in `first..data.len()` whose result is
/// not finite (warm-up, state poisoned by a NaN in `data`, or `hull_length < 2`, where the
/// half-length MA never becomes ready) is written as `NaN`, so no slot keeps stale data.
/// Indices before `first` are left untouched.
///
/// # Panics
/// Panics if `out` is shorter than `data`.
#[inline]
pub fn dma_scalar(
    data: &[f64],
    hull_length: usize,
    ema_length: usize,
    ema_gain_limit: usize,
    hull_ma_type: &str,
    first: usize,
    out: &mut [f64],
) {
    let n = data.len();
    if n == 0 {
        return;
    }

    // Base EMA (`e0`) and error-corrected EMA (`ec`) state.
    let alpha_e = 2.0 / (ema_length as f64 + 1.0);
    let one_minus_alpha_e = 1.0 - alpha_e;
    let i0_e = first + ema_length.saturating_sub(1);
    let mut e0_prev = 0.0;
    let mut e0_init_done = false;
    let mut ec_prev = 0.0;
    let mut ec_init_done = false;

    // Hull state.
    let half = hull_length / 2;
    let sqrt_len = (hull_length as f64).sqrt().round() as usize;

    let mut hull_val = f64::NAN;

    // Sum of WMA weights 1..=p.
    let wsum = |p: usize| -> f64 { (p * (p + 1)) as f64 / 2.0 };
    let i0_half = first + half.saturating_sub(1);
    let i0_full = first + hull_length.saturating_sub(1);

    // Rolling WMA accumulators: `a_*` is the plain window sum, `s_*` the weighted sum.
    let mut a_half = 0.0;
    let mut s_half = 0.0;
    let mut half_ready = false;

    let mut a_full = 0.0;
    let mut s_full = 0.0;
    let mut full_ready = false;

    // Ring of the most recent `sqrt_len` finite Hull differences (oldest at `diff_pos`).
    let mut diff_ring: Vec<f64> = Vec::with_capacity(sqrt_len.max(1));
    let mut diff_pos: usize = 0;
    let mut diff_filled = 0usize;

    let mut a_diff = 0.0;
    let mut s_diff = 0.0;

    let alpha_sqrt = if sqrt_len > 0 {
        2.0 / (sqrt_len as f64 + 1.0)
    } else {
        0.0
    };
    let mut diff_ema = 0.0;
    let mut diff_sum_seed = 0.0;

    // EMA-flavoured Hull state.
    let mut e_half_prev = 0.0;
    let mut e_half_init_done = false;
    let mut e_full_prev = 0.0;
    let mut e_full_init_done = false;
    let alpha_half = if half > 0 {
        2.0 / (half as f64 + 1.0)
    } else {
        0.0
    };
    let alpha_full = if hull_length > 0 {
        2.0 / (hull_length as f64 + 1.0)
    } else {
        0.0
    };

    let is_wma = hull_ma_type == "WMA";

    for i in first..n {
        let x = data[i];

        // Base EMA: seed with an SMA once `ema_length` values are available, then recurse.
        if !e0_init_done {
            if i >= i0_e {
                let start = i + 1 - ema_length;
                let mut sum = 0.0;
                for k in start..=i {
                    sum += data[k];
                }
                e0_prev = sum / ema_length as f64;
                e0_init_done = true;
            }
        } else {
            e0_prev = x.mul_add(alpha_e, one_minus_alpha_e * e0_prev);
        }

        let mut diff_now = f64::NAN;

        if is_wma {
            if half > 0 {
                if !half_ready {
                    if i >= i0_half {
                        let start = i + 1 - half;
                        let mut sum = 0.0;
                        let mut wsum_local = 0.0;
                        for (j, idx) in (start..=i).enumerate() {
                            let w = (j + 1) as f64;
                            let v = data[idx];
                            sum += v;
                            wsum_local += w * v;
                        }
                        a_half = sum;
                        s_half = wsum_local;
                        half_ready = true;
                    }
                } else {
                    // O(1) roll: add `x` with the top weight, every other weight drops by one.
                    let a_prev = a_half;
                    a_half = a_prev + x - data[i - half];
                    s_half = s_half + (half as f64) * x - a_prev;
                }
            }

            if hull_length > 0 {
                if !full_ready {
                    if i >= i0_full {
                        let start = i + 1 - hull_length;
                        let mut sum = 0.0;
                        let mut wsum_local = 0.0;
                        for (j, idx) in (start..=i).enumerate() {
                            let w = (j + 1) as f64;
                            let v = data[idx];
                            sum += v;
                            wsum_local += w * v;
                        }
                        a_full = sum;
                        s_full = wsum_local;
                        full_ready = true;
                    }
                } else {
                    let a_prev = a_full;
                    a_full = a_prev + x - data[i - hull_length];
                    s_full = s_full + (hull_length as f64) * x - a_prev;
                }
            }

            if half_ready && full_ready {
                let w_half = s_half / wsum(half).max(1.0);
                let w_full = s_full / wsum(hull_length).max(1.0);
                diff_now = 2.0 * w_half - w_full;
            }
        } else {
            if half > 0 {
                if !e_half_init_done {
                    if i >= i0_half {
                        let start = i + 1 - half;
                        let mut sum = 0.0;
                        for k in start..=i {
                            sum += data[k];
                        }
                        e_half_prev = sum / half as f64;
                        e_half_init_done = true;
                    }
                } else {
                    e_half_prev = x.mul_add(alpha_half, (1.0 - alpha_half) * e_half_prev);
                }
            }

            if hull_length > 0 {
                if !e_full_init_done {
                    if i >= i0_full {
                        let start = i + 1 - hull_length;
                        let mut sum = 0.0;
                        for k in start..=i {
                            sum += data[k];
                        }
                        e_full_prev = sum / hull_length as f64;
                        e_full_init_done = true;
                    }
                } else {
                    e_full_prev = x.mul_add(alpha_full, (1.0 - alpha_full) * e_full_prev);
                }
            }

            if e_half_init_done && e_full_init_done {
                diff_now = 2.0 * e_half_prev - e_full_prev;
            }
        }

        // Smooth the Hull difference over `sqrt_len` samples. Non-finite differences are
        // skipped, so `hull_val` keeps its previous value in that case.
        if diff_now.is_finite() && sqrt_len > 0 {
            if diff_filled < sqrt_len {
                diff_ring.push(diff_now);
                diff_sum_seed += diff_now;
                diff_filled += 1;

                if diff_filled == sqrt_len {
                    if is_wma {
                        a_diff = 0.0;
                        s_diff = 0.0;
                        for (j, &v) in diff_ring.iter().enumerate() {
                            let w = (j + 1) as f64;
                            a_diff += v;
                            s_diff += w * v;
                        }
                        hull_val = s_diff / wsum(sqrt_len).max(1.0);
                    } else {
                        diff_ema = diff_sum_seed / sqrt_len as f64;
                        hull_val = diff_ema;
                    }
                }
            } else {
                let old = diff_ring[diff_pos];
                diff_ring[diff_pos] = diff_now;
                diff_pos = (diff_pos + 1) % sqrt_len;

                if is_wma {
                    let a_prev = a_diff;
                    a_diff = a_prev + diff_now - old;
                    s_diff = s_diff + (sqrt_len as f64) * diff_now - a_prev;
                    hull_val = s_diff / wsum(sqrt_len).max(1.0);
                } else {
                    diff_ema = diff_now.mul_add(alpha_sqrt, (1.0 - alpha_sqrt) * diff_ema);
                    hull_val = diff_ema;
                }
            }
        }

        // Error-correcting EMA.
        let mut ec_now = f64::NAN;
        if e0_init_done {
            if !ec_init_done {
                ec_prev = e0_prev;
                ec_init_done = true;
                ec_now = ec_prev;
            } else {
                let dx = x - ec_prev;
                let t = alpha_e * dx;
                let base = e0_prev.mul_add(alpha_e, one_minus_alpha_e * ec_prev);
                let r = x - base;

                // Closed-form nearest gain index to r / t (in tenths), clamped to
                // [0, ema_gain_limit], then the better of its two grid neighbours.
                let g_sel = if t == 0.0 {
                    0.0
                } else {
                    let limit_i = ema_gain_limit as i64;
                    let target = (r / t) * 10.0;
                    let mut i0 = target.floor() as i64;
                    if i0 < 0 {
                        i0 = 0;
                    } else if i0 > limit_i {
                        i0 = limit_i;
                    }
                    let i1 = if i0 < limit_i { i0 + 1 } else { i0 };
                    let g0 = (i0 as f64) * 0.1;
                    let g1 = (i1 as f64) * 0.1;
                    let e0 = (r - t * g0).abs();
                    let e1 = (r - t * g1).abs();
                    if e0 <= e1 {
                        g0
                    } else {
                        g1
                    }
                };

                ec_now = (e0_prev + g_sel * dx).mul_add(alpha_e, one_minus_alpha_e * ec_prev);
                ec_prev = ec_now;
            }
        }

        // Always write the slot: a non-finite result becomes NaN rather than leaving whatever
        // the buffer previously held (which may be uninitialised for freshly allocated output).
        out[i] = if hull_val.is_finite() && ec_now.is_finite() {
            0.5 * (hull_val + ec_now)
        } else {
            f64::NAN
        };
    }
}
774
/// WASM `simd128` entry point for DMA.
///
/// There is currently no vectorised implementation; this forwards to `dma_scalar`. The unused
/// `core::arch::wasm32` glob import was removed (it triggered an unused-import warning).
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
#[inline]
unsafe fn dma_simd128(
    data: &[f64],
    hull_length: usize,
    ema_length: usize,
    ema_gain_limit: usize,
    hull_ma_type: &str,
    first_val: usize,
    out: &mut [f64],
) {
    dma_scalar(
        data,
        hull_length,
        ema_length,
        ema_gain_limit,
        hull_ma_type,
        first_val,
        out,
    );
}
797
/// Horizontal sum of the four lanes of `v`, added in lane order 0, 1, 2, 3.
///
/// # Safety
/// Uses AVX intrinsics; only call where AVX is available.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn hsum256d(v: __m256d) -> f64 {
    let mut buf = [0.0f64; 4];
    _mm256_storeu_pd(buf.as_mut_ptr(), v);
    buf[0] + buf[1] + buf[2] + buf[3]
}
805
/// Horizontal sum of the eight lanes of `v`, accumulated in lane order via `Iterator::sum`.
///
/// # Safety
/// Uses AVX-512F intrinsics; only call where AVX-512F is available.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn hsum512d(v: __m512d) -> f64 {
    let mut buf = [0.0f64; 8];
    _mm512_storeu_pd(buf.as_mut_ptr(), v);
    buf.iter().sum()
}
813
/// Lane-wise absolute value: clears the IEEE-754 sign bit of all four lanes.
///
/// # Safety
/// Uses AVX intrinsics; only call where AVX is available.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn vabs256d(x: __m256d) -> __m256d {
    // -0.0 has only the sign bit set, so AND-NOT with it strips the sign.
    _mm256_andnot_pd(_mm256_set1_pd(-0.0), x)
}
820
/// Lane-wise absolute value: clears the IEEE-754 sign bit of all eight lanes.
///
/// # Safety
/// Uses AVX-512F intrinsics; only call where AVX-512F is available.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn vabs512d(x: __m512d) -> __m512d {
    // i64::MIN is exactly the sign-bit pattern; AND-NOT removes it on the integer view.
    let sign_bits = _mm512_set1_epi64(i64::MIN);
    _mm512_castsi512_pd(_mm512_andnot_si512(sign_bits, _mm512_castpd_si512(x)))
}
829
/// Sum of the `len` consecutive `f64` values starting at `ptr`.
///
/// Four-lane accumulation followed by a scalar tail, so the rounding can differ from a strict
/// left-to-right sum.
///
/// # Safety
/// `ptr` must be valid for `len` reads, and AVX must be available.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn sum_unweighted_avx2(ptr: *const f64, len: usize) -> f64 {
    let mut i = 0usize;
    let mut acc = _mm256_setzero_pd();
    while i + 4 <= len {
        let v = _mm256_loadu_pd(ptr.add(i));
        acc = _mm256_add_pd(acc, v);
        i += 4;
    }
    let mut s = hsum256d(acc);
    // Scalar tail for the remaining len % 4 elements.
    while i < len {
        s += *ptr.add(i);
        i += 1;
    }
    s
}
847
/// Sum of the `len` consecutive `f64` values starting at `ptr`.
///
/// Eight-lane accumulation followed by a scalar tail, so the rounding can differ from a strict
/// left-to-right sum.
///
/// # Safety
/// `ptr` must be valid for `len` reads, and AVX-512F must be available.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn sum_unweighted_avx512(ptr: *const f64, len: usize) -> f64 {
    let mut i = 0usize;
    let mut acc = _mm512_setzero_pd();
    while i + 8 <= len {
        let v = _mm512_loadu_pd(ptr.add(i));
        acc = _mm512_add_pd(acc, v);
        i += 8;
    }
    let mut s = hsum512d(acc);
    // Scalar tail for the remaining len % 8 elements.
    while i < len {
        s += *ptr.add(i);
        i += 1;
    }
    s
}
865
/// Seeds a WMA window: returns `(sum, weighted_sum)` of the `len` values at `ptr`, where the
/// value at offset `k` has weight `k + 1`.
///
/// # Safety
/// `ptr` must be valid for `len` reads, and AVX must be available.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn seed_wma_window_avx2(ptr: *const f64, len: usize) -> (f64, f64) {
    let mut i = 0usize;
    let mut acc_v = _mm256_setzero_pd();
    let mut acc_wv = _mm256_setzero_pd();
    // Lane k of a chunk gets weight `wbase + k`; `_mm256_set_pd` lists lanes high-to-low.
    let inc = _mm256_set_pd(3.0, 2.0, 1.0, 0.0);
    let mut wbase = 1.0f64;
    while i + 4 <= len {
        let v = _mm256_loadu_pd(ptr.add(i));
        let w = _mm256_add_pd(_mm256_set1_pd(wbase), inc);
        acc_v = _mm256_add_pd(acc_v, v);
        acc_wv = _mm256_add_pd(acc_wv, _mm256_mul_pd(w, v));
        wbase += 4.0;
        i += 4;
    }
    let mut s = hsum256d(acc_v);
    let mut sw = hsum256d(acc_wv);
    // Scalar tail keeps the same weighting (offset i -> weight i + 1).
    while i < len {
        let val = *ptr.add(i);
        s += val;
        sw += (i as f64 + 1.0) * val;
        i += 1;
    }
    (s, sw)
}
892
/// Seeds a WMA window: returns `(sum, weighted_sum)` of the `len` values at `ptr`, where the
/// value at offset `k` has weight `k + 1`.
///
/// # Safety
/// `ptr` must be valid for `len` reads, and AVX-512F must be available.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn seed_wma_window_avx512(ptr: *const f64, len: usize) -> (f64, f64) {
    let mut i = 0usize;
    let mut acc_v = _mm512_setzero_pd();
    let mut acc_wv = _mm512_setzero_pd();
    // Lane k of a chunk gets weight `wbase + k`; `_mm512_set_pd` lists lanes high-to-low.
    let inc = _mm512_set_pd(7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0);
    let mut wbase = 1.0f64;
    while i + 8 <= len {
        let v = _mm512_loadu_pd(ptr.add(i));
        let w = _mm512_add_pd(_mm512_set1_pd(wbase), inc);
        acc_v = _mm512_add_pd(acc_v, v);
        acc_wv = _mm512_add_pd(acc_wv, _mm512_mul_pd(w, v));
        wbase += 8.0;
        i += 8;
    }
    let mut s = hsum512d(acc_v);
    let mut sw = hsum512d(acc_wv);
    // Scalar tail keeps the same weighting (offset i -> weight i + 1).
    while i < len {
        let val = *ptr.add(i);
        s += val;
        sw += (i as f64 + 1.0) * val;
        i += 1;
    }
    (s, sw)
}
919
/// Brute-force search for the error-correcting EMA gain, four candidates at a time.
///
/// Tries `g = k * 0.1` for `k = 0..=ema_gain_limit` and returns the `g` minimising
/// `|x - (alpha_e * (e0_prev + g * (x - ec_prev)) + (1 - alpha_e) * ec_prev)|`.
/// Ties resolve to the smallest `g`: strict less-than within each lane, then a smallest-`g`
/// tie-break across lanes. Returns `0.0` if no candidate yields an error below +infinity
/// (e.g. when every error is NaN).
///
/// # Safety
/// Uses AVX and FMA intrinsics; only call where both are available.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn best_gain_search_avx2(
    x: f64,
    e0_prev: f64,
    ec_prev: f64,
    alpha_e: f64,
    ema_gain_limit: usize,
) -> f64 {
    let width = 4usize;
    let dx = _mm256_set1_pd(x - ec_prev);
    let x_v = _mm256_set1_pd(x);
    let e0_v = _mm256_set1_pd(e0_prev);
    let ec_prev_v = _mm256_set1_pd(ec_prev);
    let a_v = _mm256_set1_pd(alpha_e);
    let om_a_v = _mm256_set1_pd(1.0 - alpha_e);
    let inf_v = _mm256_set1_pd(f64::INFINITY);
    let limit_f = ema_gain_limit as f64;
    let limit_v = _mm256_set1_pd(limit_f);
    let scale = _mm256_set1_pd(0.1);

    let mut best_err = _mm256_set1_pd(f64::INFINITY);
    let mut best_g = _mm256_set1_pd(0.0);

    let mut idx = 0usize;
    while idx <= ema_gain_limit {
        let base = _mm256_set1_pd(idx as f64);
        let inc = _mm256_set_pd(3.0, 2.0, 1.0, 0.0);
        let idx_v = _mm256_add_pd(base, inc);

        // Lanes past the limit (final partial chunk) are masked out below.
        let gt_mask = _mm256_cmp_pd(idx_v, limit_v, _CMP_GT_OQ);

        let g = _mm256_mul_pd(idx_v, scale);
        let e0_plus = _mm256_fmadd_pd(g, dx, e0_v);
        let pred = _mm256_fmadd_pd(a_v, e0_plus, _mm256_mul_pd(om_a_v, ec_prev_v));
        let err = vabs256d(_mm256_sub_pd(x_v, pred));

        // Out-of-range lanes get +inf so they can never win.
        let err_masked = _mm256_blendv_pd(err, inf_v, gt_mask);

        let lt = _mm256_cmp_pd(err_masked, best_err, _CMP_LT_OQ);
        best_err = _mm256_blendv_pd(best_err, err_masked, lt);
        best_g = _mm256_blendv_pd(best_g, g, lt);

        idx += width;
    }

    let mut e = [0.0f64; 4];
    let mut g = [0.0f64; 4];
    _mm256_storeu_pd(e.as_mut_ptr(), best_err);
    _mm256_storeu_pd(g.as_mut_ptr(), best_g);

    // Cross-lane reduction: lowest error, then lowest gain on ties.
    let mut best_e = f64::INFINITY;
    let mut best_gg = 0.0;
    for k in 0..4 {
        let ek = e[k];
        let gk = g[k];
        if ek < best_e || (ek == best_e && gk < best_gg) {
            best_e = ek;
            best_gg = gk;
        }
    }
    best_gg
}
983
/// Brute-force search for the error-correcting EMA gain, eight candidates at a time.
///
/// Tries `g = k * 0.1` for `k = 0..=ema_gain_limit` and returns the `g` minimising
/// `|x - (alpha_e * (e0_prev + g * (x - ec_prev)) + (1 - alpha_e) * ec_prev)|`.
/// Ties resolve to the smallest `g`: strict less-than within each lane, then a smallest-`g`
/// tie-break across lanes. Returns `0.0` if no candidate yields an error below +infinity.
///
/// # Safety
/// Uses AVX-512F intrinsics (including masked FMA/compare); only call where AVX-512F is
/// available.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn best_gain_search_avx512(
    x: f64,
    e0_prev: f64,
    ec_prev: f64,
    alpha_e: f64,
    ema_gain_limit: usize,
) -> f64 {
    let width = 8usize;
    let dx = _mm512_set1_pd(x - ec_prev);
    let x_v = _mm512_set1_pd(x);
    let e0_v = _mm512_set1_pd(e0_prev);
    let ec_prev_v = _mm512_set1_pd(ec_prev);
    let a_v = _mm512_set1_pd(alpha_e);
    let om_a_v = _mm512_set1_pd(1.0 - alpha_e);
    let inf_v = _mm512_set1_pd(f64::INFINITY);
    let limit_f = ema_gain_limit as f64;
    let limit_v = _mm512_set1_pd(limit_f);
    let scale = _mm512_set1_pd(0.1);

    let mut best_err = _mm512_set1_pd(f64::INFINITY);
    let mut best_g = _mm512_set1_pd(0.0);

    let mut idx = 0usize;
    while idx <= ema_gain_limit {
        let base = _mm512_set1_pd(idx as f64);
        let inc = _mm512_set_pd(7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0);
        let idx_v = _mm512_add_pd(base, inc);

        // Mask of lanes past the limit (final partial chunk).
        let k_invalid = _mm512_cmp_pd_mask(idx_v, limit_v, _CMP_GT_OQ);

        let g = _mm512_mul_pd(idx_v, scale);
        let e0_plus = _mm512_fmadd_pd(g, dx, e0_v);
        let pred = _mm512_fmadd_pd(a_v, e0_plus, _mm512_mul_pd(om_a_v, ec_prev_v));
        let err = vabs512d(_mm512_sub_pd(x_v, pred));

        // Out-of-range lanes get +inf so they can never win.
        let err_masked = _mm512_mask_mov_pd(err, k_invalid, inf_v);

        let k_lt = _mm512_cmp_pd_mask(err_masked, best_err, _CMP_LT_OQ);
        best_err = _mm512_mask_mov_pd(best_err, k_lt, err_masked);
        best_g = _mm512_mask_mov_pd(best_g, k_lt, g);

        idx += width;
    }

    let mut e = [0.0f64; 8];
    let mut g = [0.0f64; 8];
    _mm512_storeu_pd(e.as_mut_ptr(), best_err);
    _mm512_storeu_pd(g.as_mut_ptr(), best_g);

    // Cross-lane reduction: lowest error, then lowest gain on ties.
    let mut best_e = f64::INFINITY;
    let mut best_gg = 0.0;
    for k in 0..8 {
        let ek = e[k];
        let gk = g[k];
        if ek < best_e || (ek == best_e && gk < best_gg) {
            best_e = ek;
            best_gg = gk;
        }
    }
    best_gg
}
1047
/// Dynamic Moving Average kernel using AVX2 helpers to seed the initial windows.
///
/// For each index `i >= first`, `out[i]` is set to `0.5 * (hull + ec)` once both
/// components are finite; every other entry of `out` is left untouched
/// (NOTE(review): callers presumably pre-fill the warm-up prefix with NaN — confirm).
///
/// * `e0` — EMA of `data` with period `ema_length`, seeded by the SMA of its first window.
/// * `ec` — error-correcting EMA `alpha * (e0 + g * (x - ec_prev)) + (1 - alpha) * ec_prev`,
///   where the gain `g` is chosen from `{0.0, 0.1, ..., ema_gain_limit / 10}` to minimise
///   `|x - ec|`.
/// * `hull` — `2 * MA(half) - MA(hull_length)`, smoothed by an MA of length
///   `round(sqrt(hull_length))`; MA is a WMA when `hull_ma_type == "WMA"`, an EMA otherwise.
///
/// # Safety
/// The CPU must support AVX2 and FMA. NOTE(review): `ema_length` is presumably validated as
/// non-zero by the caller; a zero length would make the SMA seed divide by zero.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
unsafe fn dma_avx2(
    data: &[f64],
    hull_length: usize,
    ema_length: usize,
    ema_gain_limit: usize,
    hull_ma_type: &str,
    first: usize,
    out: &mut [f64],
) {
    let n = data.len();
    if n == 0 {
        return;
    }

    // ---- error-corrected EMA state ----
    let alpha_e = 2.0 / (ema_length as f64 + 1.0);
    let om_alpha_e = 1.0 - alpha_e;
    let limit_i = ema_gain_limit as i64;
    let i0_e = first + ema_length.saturating_sub(1);
    let mut e0_prev = 0.0;
    let mut e0_init_done = false;
    let mut ec_prev = 0.0;
    let mut ec_init_done = false;

    // ---- Hull state ----
    let half = hull_length / 2;
    let sqrt_len = (hull_length as f64).sqrt().round() as usize;

    let mut hull_val = f64::NAN;

    // Total WMA weight 1 + 2 + ... + p; loop-invariant, clamped to 1 to avoid 0 division.
    let wsum = |p: usize| -> f64 { (p * (p + 1)) as f64 / 2.0 };
    let wsum_half = wsum(half).max(1.0);
    let wsum_full = wsum(hull_length).max(1.0);
    let wsum_diff = wsum(sqrt_len).max(1.0);
    let i0_half = first + half.saturating_sub(1);
    let i0_full = first + hull_length.saturating_sub(1);

    // Rolling WMA state: `a_*` is the plain window sum, `s_*` the weighted sum.
    let mut a_half = 0.0;
    let mut s_half = 0.0;
    let mut half_ready = false;

    let mut a_full = 0.0;
    let mut s_full = 0.0;
    let mut full_ready = false;

    // Last `sqrt_len` finite diff values, feeding the final smoothing MA.
    let mut diff_ring: Vec<f64> = Vec::with_capacity(sqrt_len.max(1));
    let mut diff_pos: usize = 0;
    let mut diff_filled = 0usize;

    let mut a_diff = 0.0;
    let mut s_diff = 0.0;

    let alpha_sqrt = if sqrt_len > 0 {
        2.0 / (sqrt_len as f64 + 1.0)
    } else {
        0.0
    };
    let mut diff_ema = 0.0;
    let mut diff_sum_seed = 0.0;

    let mut e_half_prev = 0.0;
    let mut e_half_init_done = false;
    let mut e_full_prev = 0.0;
    let mut e_full_init_done = false;
    let alpha_half = if half > 0 {
        2.0 / (half as f64 + 1.0)
    } else {
        0.0
    };
    let alpha_full = if hull_length > 0 {
        2.0 / (hull_length as f64 + 1.0)
    } else {
        0.0
    };

    let is_wma = hull_ma_type == "WMA";

    for i in first..n {
        let x = data[i];

        if !e0_init_done {
            if i >= i0_e {
                let start = i + 1 - ema_length;
                let sum = sum_unweighted_avx2(data.as_ptr().add(start), ema_length);
                e0_prev = sum / ema_length as f64;
                e0_init_done = true;
            }
        } else {
            e0_prev = x.mul_add(alpha_e, om_alpha_e * e0_prev);
        }

        let mut diff_now = f64::NAN;

        if is_wma {
            if half > 0 {
                if !half_ready {
                    if i >= i0_half {
                        let start = i + 1 - half;
                        let (sum, wsum_local) =
                            seed_wma_window_avx2(data.as_ptr().add(start), half);
                        a_half = sum;
                        s_half = wsum_local;
                        half_ready = true;
                    }
                } else {
                    // O(1) slide: S' = S + n*x - A_old, A' = A_old + x - leaving sample.
                    let a_prev = a_half;
                    a_half = a_prev + x - data[i - half];
                    s_half = s_half + (half as f64) * x - a_prev;
                }
            }

            if hull_length > 0 {
                if !full_ready {
                    if i >= i0_full {
                        let start = i + 1 - hull_length;
                        let (sum, wsum_local) =
                            seed_wma_window_avx2(data.as_ptr().add(start), hull_length);
                        a_full = sum;
                        s_full = wsum_local;
                        full_ready = true;
                    }
                } else {
                    let a_prev = a_full;
                    a_full = a_prev + x - data[i - hull_length];
                    s_full = s_full + (hull_length as f64) * x - a_prev;
                }
            }

            if half_ready && full_ready {
                let w_half = s_half / wsum_half;
                let w_full = s_full / wsum_full;
                diff_now = 2.0 * w_half - w_full;
            }
        } else {
            if half > 0 {
                if !e_half_init_done {
                    if i >= i0_half {
                        let start = i + 1 - half;
                        let sum = sum_unweighted_avx2(data.as_ptr().add(start), half);
                        e_half_prev = sum / half as f64;
                        e_half_init_done = true;
                    }
                } else {
                    e_half_prev = x.mul_add(alpha_half, (1.0 - alpha_half) * e_half_prev);
                }
            }

            if hull_length > 0 {
                if !e_full_init_done {
                    if i >= i0_full {
                        let start = i + 1 - hull_length;
                        let sum = sum_unweighted_avx2(data.as_ptr().add(start), hull_length);
                        e_full_prev = sum / hull_length as f64;
                        e_full_init_done = true;
                    }
                } else {
                    e_full_prev = x.mul_add(alpha_full, (1.0 - alpha_full) * e_full_prev);
                }
            }

            if e_half_init_done && e_full_init_done {
                diff_now = 2.0 * e_half_prev - e_full_prev;
            }
        }

        // Final smoothing of the diff series over `sqrt_len` samples.
        if diff_now.is_finite() && sqrt_len > 0 {
            if diff_filled < sqrt_len {
                diff_ring.push(diff_now);
                diff_sum_seed += diff_now;
                diff_filled += 1;

                if diff_filled == sqrt_len {
                    if is_wma {
                        let (a0, s0) = seed_wma_window_avx2(diff_ring.as_ptr(), sqrt_len);
                        a_diff = a0;
                        s_diff = s0;
                        hull_val = s_diff / wsum_diff;
                    } else {
                        diff_ema = diff_sum_seed / sqrt_len as f64;
                        hull_val = diff_ema;
                    }
                }
            } else {
                // `diff_pos` always points at the oldest stored diff.
                let old = diff_ring[diff_pos];
                diff_ring[diff_pos] = diff_now;
                diff_pos = (diff_pos + 1) % sqrt_len;

                if is_wma {
                    let a_prev = a_diff;
                    a_diff = a_prev + diff_now - old;
                    s_diff = s_diff + (sqrt_len as f64) * diff_now - a_prev;
                    hull_val = s_diff / wsum_diff;
                } else {
                    diff_ema = diff_now.mul_add(alpha_sqrt, (1.0 - alpha_sqrt) * diff_ema);
                    hull_val = diff_ema;
                }
            }
        }

        let mut ec_now = f64::NAN;
        if e0_init_done {
            if !ec_init_done {
                ec_prev = e0_prev;
                ec_init_done = true;
                ec_now = ec_prev;
            } else {
                let dx = x - ec_prev;
                let t = alpha_e * dx;
                let base = e0_prev.mul_add(alpha_e, om_alpha_e * ec_prev);
                let r = x - base;

                // |x - ec(g)| = |r - t*g| is convex in g, so the best grid gain is one of the
                // two grid points bracketing the continuous optimum r / t.
                let g_sel = if t == 0.0 {
                    0.0
                } else {
                    let target = (r / t) * 10.0;
                    let mut i0 = target.floor() as i64;
                    if i0 < 0 {
                        i0 = 0;
                    } else if i0 > limit_i {
                        i0 = limit_i;
                    }
                    let i1 = if i0 < limit_i { i0 + 1 } else { i0 };
                    let g0 = (i0 as f64) * 0.1;
                    let g1 = (i1 as f64) * 0.1;
                    let e0 = (r - t * g0).abs();
                    let e1 = (r - t * g1).abs();
                    if e0 <= e1 {
                        g0
                    } else {
                        g1
                    }
                };

                ec_now = (e0_prev + g_sel * dx).mul_add(alpha_e, om_alpha_e * ec_prev);
                ec_prev = ec_now;
            }
        }

        if hull_val.is_finite() && ec_now.is_finite() {
            out[i] = 0.5 * (hull_val + ec_now);
        }
    }
}
1293
/// Dynamic Moving Average kernel using AVX-512 helpers to seed the initial windows.
///
/// For each index `i >= first`, `out[i]` is set to `0.5 * (hull + ec)` once both
/// components are finite; every other entry of `out` is left untouched
/// (NOTE(review): callers presumably pre-fill the warm-up prefix with NaN — confirm).
///
/// * `e0` — EMA of `data` with period `ema_length`, seeded by the SMA of its first window.
/// * `ec` — error-correcting EMA `alpha * (e0 + g * (x - ec_prev)) + (1 - alpha) * ec_prev`,
///   where the gain `g` is chosen from `{0.0, 0.1, ..., ema_gain_limit / 10}` to minimise
///   `|x - ec|`.
/// * `hull` — `2 * MA(half) - MA(hull_length)`, smoothed by an MA of length
///   `round(sqrt(hull_length))`; MA is a WMA when `hull_ma_type == "WMA"`, an EMA otherwise.
///
/// # Safety
/// The CPU must support AVX-512F and FMA. NOTE(review): `ema_length` is presumably validated
/// as non-zero by the caller; a zero length would make the SMA seed divide by zero.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
unsafe fn dma_avx512(
    data: &[f64],
    hull_length: usize,
    ema_length: usize,
    ema_gain_limit: usize,
    hull_ma_type: &str,
    first: usize,
    out: &mut [f64],
) {
    let n = data.len();
    if n == 0 {
        return;
    }

    // ---- error-corrected EMA state ----
    let alpha_e = 2.0 / (ema_length as f64 + 1.0);
    let om_alpha_e = 1.0 - alpha_e;
    let limit_i = ema_gain_limit as i64;
    let i0_e = first + ema_length.saturating_sub(1);
    let mut e0_prev = 0.0;
    let mut e0_init_done = false;
    let mut ec_prev = 0.0;
    let mut ec_init_done = false;

    // ---- Hull state ----
    let half = hull_length / 2;
    let sqrt_len = (hull_length as f64).sqrt().round() as usize;

    let mut hull_val = f64::NAN;

    // Total WMA weight 1 + 2 + ... + p; loop-invariant, clamped to 1 to avoid 0 division.
    let wsum = |p: usize| -> f64 { (p * (p + 1)) as f64 / 2.0 };
    let wsum_half = wsum(half).max(1.0);
    let wsum_full = wsum(hull_length).max(1.0);
    let wsum_diff = wsum(sqrt_len).max(1.0);
    let i0_half = first + half.saturating_sub(1);
    let i0_full = first + hull_length.saturating_sub(1);

    // Rolling WMA state: `a_*` is the plain window sum, `s_*` the weighted sum.
    let mut a_half = 0.0;
    let mut s_half = 0.0;
    let mut half_ready = false;

    let mut a_full = 0.0;
    let mut s_full = 0.0;
    let mut full_ready = false;

    // Last `sqrt_len` finite diff values, feeding the final smoothing MA.
    let mut diff_ring: Vec<f64> = Vec::with_capacity(sqrt_len.max(1));
    let mut diff_pos: usize = 0;
    let mut diff_filled = 0usize;

    let mut a_diff = 0.0;
    let mut s_diff = 0.0;

    let alpha_sqrt = if sqrt_len > 0 {
        2.0 / (sqrt_len as f64 + 1.0)
    } else {
        0.0
    };
    let mut diff_ema = 0.0;
    let mut diff_sum_seed = 0.0;

    let mut e_half_prev = 0.0;
    let mut e_half_init_done = false;
    let mut e_full_prev = 0.0;
    let mut e_full_init_done = false;
    let alpha_half = if half > 0 {
        2.0 / (half as f64 + 1.0)
    } else {
        0.0
    };
    let alpha_full = if hull_length > 0 {
        2.0 / (hull_length as f64 + 1.0)
    } else {
        0.0
    };

    let is_wma = hull_ma_type == "WMA";

    for i in first..n {
        let x = data[i];

        if !e0_init_done {
            if i >= i0_e {
                let start = i + 1 - ema_length;
                let sum = sum_unweighted_avx512(data.as_ptr().add(start), ema_length);
                e0_prev = sum / ema_length as f64;
                e0_init_done = true;
            }
        } else {
            e0_prev = x.mul_add(alpha_e, om_alpha_e * e0_prev);
        }

        let mut diff_now = f64::NAN;

        if is_wma {
            if half > 0 {
                if !half_ready {
                    if i >= i0_half {
                        let start = i + 1 - half;
                        let (sum, wsum_local) =
                            seed_wma_window_avx512(data.as_ptr().add(start), half);
                        a_half = sum;
                        s_half = wsum_local;
                        half_ready = true;
                    }
                } else {
                    // O(1) slide: S' = S + n*x - A_old, A' = A_old + x - leaving sample.
                    let a_prev = a_half;
                    a_half = a_prev + x - data[i - half];
                    s_half = s_half + (half as f64) * x - a_prev;
                }
            }

            if hull_length > 0 {
                if !full_ready {
                    if i >= i0_full {
                        let start = i + 1 - hull_length;
                        let (sum, wsum_local) =
                            seed_wma_window_avx512(data.as_ptr().add(start), hull_length);
                        a_full = sum;
                        s_full = wsum_local;
                        full_ready = true;
                    }
                } else {
                    let a_prev = a_full;
                    a_full = a_prev + x - data[i - hull_length];
                    s_full = s_full + (hull_length as f64) * x - a_prev;
                }
            }

            if half_ready && full_ready {
                let w_half = s_half / wsum_half;
                let w_full = s_full / wsum_full;
                diff_now = 2.0 * w_half - w_full;
            }
        } else {
            if half > 0 {
                if !e_half_init_done {
                    if i >= i0_half {
                        let start = i + 1 - half;
                        let sum = sum_unweighted_avx512(data.as_ptr().add(start), half);
                        e_half_prev = sum / half as f64;
                        e_half_init_done = true;
                    }
                } else {
                    e_half_prev = x.mul_add(alpha_half, (1.0 - alpha_half) * e_half_prev);
                }
            }

            if hull_length > 0 {
                if !e_full_init_done {
                    if i >= i0_full {
                        let start = i + 1 - hull_length;
                        let sum = sum_unweighted_avx512(data.as_ptr().add(start), hull_length);
                        e_full_prev = sum / hull_length as f64;
                        e_full_init_done = true;
                    }
                } else {
                    e_full_prev = x.mul_add(alpha_full, (1.0 - alpha_full) * e_full_prev);
                }
            }

            if e_half_init_done && e_full_init_done {
                diff_now = 2.0 * e_half_prev - e_full_prev;
            }
        }

        // Final smoothing of the diff series over `sqrt_len` samples.
        if diff_now.is_finite() && sqrt_len > 0 {
            if diff_filled < sqrt_len {
                diff_ring.push(diff_now);
                diff_sum_seed += diff_now;
                diff_filled += 1;

                if diff_filled == sqrt_len {
                    if is_wma {
                        let (a0, s0) = seed_wma_window_avx512(diff_ring.as_ptr(), sqrt_len);
                        a_diff = a0;
                        s_diff = s0;
                        hull_val = s_diff / wsum_diff;
                    } else {
                        diff_ema = diff_sum_seed / sqrt_len as f64;
                        hull_val = diff_ema;
                    }
                }
            } else {
                // `diff_pos` always points at the oldest stored diff.
                let old = diff_ring[diff_pos];
                diff_ring[diff_pos] = diff_now;
                diff_pos = (diff_pos + 1) % sqrt_len;

                if is_wma {
                    let a_prev = a_diff;
                    a_diff = a_prev + diff_now - old;
                    s_diff = s_diff + (sqrt_len as f64) * diff_now - a_prev;
                    hull_val = s_diff / wsum_diff;
                } else {
                    diff_ema = diff_now.mul_add(alpha_sqrt, (1.0 - alpha_sqrt) * diff_ema);
                    hull_val = diff_ema;
                }
            }
        }

        let mut ec_now = f64::NAN;
        if e0_init_done {
            if !ec_init_done {
                ec_prev = e0_prev;
                ec_init_done = true;
                ec_now = ec_prev;
            } else {
                let dx = x - ec_prev;
                let t = alpha_e * dx;
                let base = e0_prev.mul_add(alpha_e, om_alpha_e * ec_prev);
                let r = x - base;

                // |x - ec(g)| = |r - t*g| is convex in g, so the best grid gain is one of the
                // two grid points bracketing the continuous optimum r / t.
                let g_sel = if t == 0.0 {
                    0.0
                } else {
                    let target = (r / t) * 10.0;
                    let mut i0 = target.floor() as i64;
                    if i0 < 0 {
                        i0 = 0;
                    } else if i0 > limit_i {
                        i0 = limit_i;
                    }
                    let i1 = if i0 < limit_i { i0 + 1 } else { i0 };
                    let g0 = (i0 as f64) * 0.1;
                    let g1 = (i1 as f64) * 0.1;
                    let e0 = (r - t * g0).abs();
                    let e1 = (r - t * g1).abs();
                    if e0 <= e1 {
                        g0
                    } else {
                        g1
                    }
                };

                ec_now = (e0_prev + g_sel * dx).mul_add(alpha_e, om_alpha_e * ec_prev);
                ec_prev = ec_now;
            }
        }

        if hull_val.is_finite() && ec_now.is_finite() {
            out[i] = 0.5 * (hull_val + ec_now);
        }
    }
}
1539#[derive(Debug, Clone)]
1540pub struct DmaStream {
1541    ema_length: usize,
1542    ema_gain_limit: usize,
1543    hull_length: usize,
1544    half: usize,
1545    sqrt_len: usize,
1546    is_wma: bool,
1547
1548    cap: usize,
1549    ring: Vec<f64>,
1550    head: usize,
1551    filled: usize,
1552
1553    i: usize,
1554    seen_first: bool,
1555
1556    alpha_e: f64,
1557    sum_e0: f64,
1558    e0_prev: f64,
1559    e0_ready: bool,
1560
1561    ec_prev: f64,
1562    ec_ready: bool,
1563
1564    sum_half: f64,
1565    sum_full: f64,
1566    s_half: f64,
1567    s_full: f64,
1568    half_ready: bool,
1569    full_ready: bool,
1570
1571    alpha_half: f64,
1572    alpha_full: f64,
1573    e_half_prev: f64,
1574    e_full_prev: f64,
1575    e_half_ready: bool,
1576    e_full_ready: bool,
1577
1578    a_diff: f64,
1579    s_diff: f64,
1580    diff_wma_ready: bool,
1581
1582    alpha_sqrt: f64,
1583    diff_ema: f64,
1584    diff_ema_ready: bool,
1585
1586    diff_ring: Vec<f64>,
1587    diff_head: usize,
1588    diff_filled: usize,
1589}
1590
1591impl DmaStream {
1592    pub fn try_new(params: DmaParams) -> Result<Self, DmaError> {
1593        let hull_length = params.hull_length.unwrap_or(7);
1594        let ema_length = params.ema_length.unwrap_or(20);
1595        let ema_gain_limit = params.ema_gain_limit.unwrap_or(50);
1596        let hull_ma_type = params.hull_ma_type.unwrap_or_else(|| "WMA".to_string());
1597        if hull_length == 0 || ema_length == 0 {
1598            return Err(DmaError::InvalidPeriod {
1599                period: hull_length.max(ema_length),
1600                data_len: 0,
1601            });
1602        }
1603        if hull_ma_type != "WMA" && hull_ma_type != "EMA" {
1604            return Err(DmaError::InvalidHullMAType {
1605                value: hull_ma_type,
1606            });
1607        }
1608
1609        let half = hull_length / 2;
1610        let sqrt_len = (hull_length as f64).sqrt().round() as usize;
1611        let cap = hull_length.max(ema_length).max(1);
1612
1613        Ok(Self {
1614            ema_length,
1615            ema_gain_limit,
1616            hull_length,
1617            half,
1618            sqrt_len,
1619            is_wma: hull_ma_type == "WMA",
1620
1621            cap,
1622            ring: vec![f64::NAN; cap],
1623            head: 0,
1624            filled: 0,
1625            i: 0,
1626            seen_first: false,
1627
1628            alpha_e: 2.0 / (ema_length as f64 + 1.0),
1629            sum_e0: 0.0,
1630            e0_prev: 0.0,
1631            e0_ready: false,
1632
1633            ec_prev: 0.0,
1634            ec_ready: false,
1635
1636            sum_half: 0.0,
1637            sum_full: 0.0,
1638            s_half: 0.0,
1639            s_full: 0.0,
1640            half_ready: half == 0,
1641            full_ready: hull_length == 0,
1642
1643            alpha_half: if half > 0 {
1644                2.0 / (half as f64 + 1.0)
1645            } else {
1646                0.0
1647            },
1648            alpha_full: 2.0 / (hull_length as f64 + 1.0),
1649            e_half_prev: 0.0,
1650            e_full_prev: 0.0,
1651            e_half_ready: half == 0,
1652            e_full_ready: hull_length == 0,
1653
1654            a_diff: 0.0,
1655            s_diff: 0.0,
1656            diff_wma_ready: sqrt_len == 0,
1657            alpha_sqrt: if sqrt_len > 0 {
1658                2.0 / (sqrt_len as f64 + 1.0)
1659            } else {
1660                0.0
1661            },
1662            diff_ema: 0.0,
1663            diff_ema_ready: sqrt_len == 0,
1664            diff_ring: vec![f64::NAN; sqrt_len.max(1)],
1665            diff_head: 0,
1666            diff_filled: 0,
1667        })
1668    }
1669
1670    #[inline]
1671    pub fn update(&mut self, x: f64) -> Option<f64> {
1672        if !self.seen_first {
1673            self.i += 1;
1674            if x.is_nan() {
1675                return None;
1676            }
1677            self.seen_first = true;
1678        }
1679
1680        let old_head = self.head;
1681        self.ring[old_head] = x;
1682        self.head = (old_head + 1) % self.cap;
1683        if self.filled < self.cap {
1684            self.filled += 1;
1685        }
1686
1687        #[inline(always)]
1688        fn kback(ring: &[f64], head: usize, cap: usize, k: usize) -> f64 {
1689            let idx = (head + cap - k % cap) % cap;
1690            ring[idx]
1691        }
1692
1693        if self.filled < self.ema_length {
1694            if x.is_finite() {
1695                self.sum_e0 += x;
1696            }
1697        } else {
1698            let out_e = kback(&self.ring, self.head, self.cap, self.ema_length);
1699            if x.is_finite() {
1700                self.sum_e0 += x;
1701            }
1702            if out_e.is_finite() {
1703                self.sum_e0 -= out_e;
1704            }
1705            if !self.e0_ready {
1706                self.e0_prev = self.sum_e0 / self.ema_length as f64;
1707                self.e0_ready = true;
1708            } else {
1709                self.e0_prev = self.alpha_e * x + (1.0 - self.alpha_e) * self.e0_prev;
1710            }
1711        }
1712
1713        let mut diff_now = f64::NAN;
1714
1715        if self.is_wma {
1716            if self.half > 0 {
1717                if self.filled < self.half {
1718                    if x.is_finite() {
1719                        self.sum_half += x;
1720                    }
1721                } else {
1722                    let out_h = kback(&self.ring, self.head, self.cap, self.half);
1723                    if x.is_finite() {
1724                        self.sum_half += x;
1725                    }
1726                    if out_h.is_finite() {
1727                        self.sum_half -= out_h;
1728                    }
1729                    if !self.half_ready {
1730                        self.s_half = 0.0;
1731                        for j in 0..self.half {
1732                            let v = kback(&self.ring, self.head, self.cap, self.half - j);
1733                            self.s_half += (j as f64 + 1.0) * v;
1734                        }
1735                        self.half_ready = true;
1736                    } else {
1737                        let a_prev = self.sum_half + out_h - x;
1738                        self.s_half = self.s_half + (self.half as f64) * x - a_prev;
1739                    }
1740                }
1741            } else {
1742                self.half_ready = true;
1743            }
1744
1745            if self.filled < self.hull_length {
1746                if x.is_finite() {
1747                    self.sum_full += x;
1748                }
1749            } else {
1750                let out_f = kback(&self.ring, self.head, self.cap, self.hull_length);
1751                if x.is_finite() {
1752                    self.sum_full += x;
1753                }
1754                if out_f.is_finite() {
1755                    self.sum_full -= out_f;
1756                }
1757                if !self.full_ready {
1758                    self.s_full = 0.0;
1759                    for j in 0..self.hull_length {
1760                        let v = kback(&self.ring, self.head, self.cap, self.hull_length - j);
1761                        self.s_full += (j as f64 + 1.0) * v;
1762                    }
1763                    self.full_ready = true;
1764                } else {
1765                    let a_prev = self.sum_full + out_f - x;
1766                    self.s_full = self.s_full + (self.hull_length as f64) * x - a_prev;
1767                }
1768            }
1769
1770            if self.half_ready && self.full_ready && self.sqrt_len > 0 {
1771                let wsum = |p: usize| (p * (p + 1)) as f64 / 2.0;
1772                let w_half = self.s_half / wsum(self.half).max(1.0);
1773                let w_full = self.s_full / wsum(self.hull_length).max(1.0);
1774                diff_now = 2.0 * w_half - w_full;
1775            }
1776        } else {
1777            if self.half > 0 {
1778                if self.filled < self.half {
1779                } else if !self.e_half_ready {
1780                    let mut s = 0.0;
1781                    for j in 0..self.half {
1782                        s += kback(&self.ring, self.head, self.cap, self.half - j);
1783                    }
1784                    self.e_half_prev = s / self.half as f64;
1785                    self.e_half_ready = true;
1786                } else {
1787                    self.e_half_prev =
1788                        self.alpha_half * x + (1.0 - self.alpha_half) * self.e_half_prev;
1789                }
1790            } else {
1791                self.e_half_ready = true;
1792            }
1793
1794            if self.filled < self.hull_length {
1795            } else if !self.e_full_ready {
1796                let mut s = 0.0;
1797                for j in 0..self.hull_length {
1798                    s += kback(&self.ring, self.head, self.cap, self.hull_length - j);
1799                }
1800                self.e_full_prev = s / self.hull_length as f64;
1801                self.e_full_ready = true;
1802            } else {
1803                self.e_full_prev = self.alpha_full * x + (1.0 - self.alpha_full) * self.e_full_prev;
1804            }
1805
1806            if self.e_half_ready && self.e_full_ready && self.sqrt_len > 0 {
1807                diff_now = 2.0 * self.e_half_prev - self.e_full_prev;
1808            }
1809        }
1810
1811        let mut hull_val = f64::NAN;
1812        if self.sqrt_len == 0 {
1813            if diff_now.is_finite() {
1814                hull_val = diff_now;
1815            }
1816        } else if diff_now.is_finite() {
1817            let old = self.diff_ring[self.diff_head];
1818            self.diff_ring[self.diff_head] = diff_now;
1819            self.diff_head = (self.diff_head + 1) % self.sqrt_len;
1820            if self.diff_filled < self.sqrt_len {
1821                self.diff_filled += 1;
1822            }
1823
1824            if self.is_wma {
1825                if !self.diff_wma_ready && self.diff_filled == self.sqrt_len {
1826                    self.a_diff = 0.0;
1827                    self.s_diff = 0.0;
1828                    for j in 0..self.sqrt_len {
1829                        let v = self.diff_ring[(self.diff_head + j) % self.sqrt_len];
1830                        self.a_diff += v;
1831                        self.s_diff += (j as f64 + 1.0) * v;
1832                    }
1833                    self.diff_wma_ready = true;
1834                    let wsum = (self.sqrt_len * (self.sqrt_len + 1)) as f64 / 2.0;
1835                    hull_val = self.s_diff / wsum.max(1.0);
1836                } else if self.diff_wma_ready {
1837                    let wsum = (self.sqrt_len * (self.sqrt_len + 1)) as f64 / 2.0;
1838                    let a_prev = self.a_diff;
1839                    self.s_diff = self.s_diff + (self.sqrt_len as f64) * diff_now - a_prev;
1840                    self.a_diff = a_prev + diff_now - old;
1841                    hull_val = self.s_diff / wsum.max(1.0);
1842                }
1843            } else {
1844                if !self.diff_ema_ready && self.diff_filled == self.sqrt_len {
1845                    let mut s = 0.0;
1846                    for j in 0..self.sqrt_len {
1847                        s += self.diff_ring[j];
1848                    }
1849                    self.diff_ema = s / self.sqrt_len as f64;
1850                    self.diff_ema_ready = true;
1851                    hull_val = self.diff_ema;
1852                } else if self.diff_ema_ready {
1853                    self.diff_ema =
1854                        self.alpha_sqrt * diff_now + (1.0 - self.alpha_sqrt) * self.diff_ema;
1855                    hull_val = self.diff_ema;
1856                }
1857            }
1858        }
1859
1860        let mut ec_now = f64::NAN;
1861        if self.e0_ready {
1862            if !self.ec_ready {
1863                self.ec_prev = self.e0_prev;
1864                self.ec_ready = true;
1865                ec_now = self.ec_prev;
1866            } else {
1867                let one_minus_alpha_e = 1.0 - self.alpha_e;
1868                let dx = x - self.ec_prev;
1869                let t = self.alpha_e * dx;
1870                let base = self
1871                    .alpha_e
1872                    .mul_add(self.e0_prev, one_minus_alpha_e * self.ec_prev);
1873                let r = x - base;
1874
1875                let g_sel = if t == 0.0 {
1876                    0.0
1877                } else {
1878                    let target = (r / t) * 10.0;
1879                    let limit_i = self.ema_gain_limit as i64;
1880                    let mut i0 = target.floor() as i64;
1881                    if i0 < 0 {
1882                        i0 = 0;
1883                    } else if i0 > limit_i {
1884                        i0 = limit_i;
1885                    }
1886                    let i1 = if i0 < limit_i { i0 + 1 } else { i0 };
1887                    let g0 = (i0 as f64) * 0.1;
1888                    let g1 = (i1 as f64) * 0.1;
1889                    let e0 = (r - t * g0).abs();
1890                    let e1 = (r - t * g1).abs();
1891                    if e0 <= e1 {
1892                        g0
1893                    } else {
1894                        g1
1895                    }
1896                };
1897
1898                let ec = self.alpha_e.mul_add(
1899                    self.e0_prev + g_sel * dx,
1900                    (1.0 - self.alpha_e) * self.ec_prev,
1901                );
1902                self.ec_prev = ec;
1903                ec_now = ec;
1904            }
1905        }
1906
1907        self.i += 1;
1908
1909        if hull_val.is_finite() && ec_now.is_finite() {
1910            Some(0.5 * (hull_val + ec_now))
1911        } else {
1912            None
1913        }
1914    }
1915}
1916
/// Parameter sweep specification for a batch DMA run.
///
/// Each numeric axis is a `(start, end, step)` triple that `expand_grid_dma` expands into a
/// list of values; a `step` of 0 (or `start == end`) yields just `start`. Every generated
/// combination shares the same `hull_ma_type`.
#[derive(Clone, Debug)]
pub struct DmaBatchRange {
    /// `(start, end, step)` for the hull moving-average length.
    pub hull_length: (usize, usize, usize),
    /// `(start, end, step)` for the EMA length.
    pub ema_length: (usize, usize, usize),
    /// `(start, end, step)` for the EMA gain limit.
    pub ema_gain_limit: (usize, usize, usize),
    /// Hull smoothing type copied into every combination; the default is `"WMA"`.
    pub hull_ma_type: String,
}
1924
1925impl Default for DmaBatchRange {
1926    fn default() -> Self {
1927        Self {
1928            hull_length: (7, 7, 0),
1929            ema_length: (20, 269, 1),
1930            ema_gain_limit: (50, 50, 0),
1931            hull_ma_type: "WMA".to_string(),
1932        }
1933    }
1934}
1935
/// Fluent builder for batch DMA runs.
///
/// Collects a [`DmaBatchRange`] and a [`Kernel`] and hands both to `dma_batch_with_kernel`
/// when one of the `apply_*` methods is called.
#[derive(Clone, Debug, Default)]
pub struct DmaBatchBuilder {
    /// Parameter sweep; starts as `DmaBatchRange::default()`.
    range: DmaBatchRange,
    /// Kernel forwarded to `dma_batch_with_kernel`; starts as `Kernel::default()`.
    kernel: Kernel,
}
1941
1942impl DmaBatchBuilder {
1943    pub fn new() -> Self {
1944        Self::default()
1945    }
1946
1947    pub fn kernel(mut self, k: Kernel) -> Self {
1948        self.kernel = k;
1949        self
1950    }
1951
1952    #[inline]
1953    pub fn hull_length_range(mut self, start: usize, end: usize, step: usize) -> Self {
1954        self.range.hull_length = (start, end, step);
1955        self
1956    }
1957
1958    #[inline]
1959    pub fn hull_length_static(mut self, val: usize) -> Self {
1960        self.range.hull_length = (val, val, 0);
1961        self
1962    }
1963
1964    #[inline]
1965    pub fn ema_length_range(mut self, start: usize, end: usize, step: usize) -> Self {
1966        self.range.ema_length = (start, end, step);
1967        self
1968    }
1969
1970    #[inline]
1971    pub fn ema_length_static(mut self, val: usize) -> Self {
1972        self.range.ema_length = (val, val, 0);
1973        self
1974    }
1975
1976    #[inline]
1977    pub fn ema_gain_limit_range(mut self, start: usize, end: usize, step: usize) -> Self {
1978        self.range.ema_gain_limit = (start, end, step);
1979        self
1980    }
1981
1982    #[inline]
1983    pub fn ema_gain_limit_static(mut self, val: usize) -> Self {
1984        self.range.ema_gain_limit = (val, val, 0);
1985        self
1986    }
1987
1988    #[inline]
1989    pub fn hull_ma_type(mut self, val: String) -> Self {
1990        self.range.hull_ma_type = val;
1991        self
1992    }
1993
1994    pub fn apply_slice(self, data: &[f64]) -> Result<DmaBatchOutput, DmaError> {
1995        dma_batch_with_kernel(data, &self.range, self.kernel)
1996    }
1997
1998    pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<DmaBatchOutput, DmaError> {
1999        DmaBatchBuilder::new().kernel(k).apply_slice(data)
2000    }
2001
2002    pub fn apply_candles(self, c: &Candles, src: &str) -> Result<DmaBatchOutput, DmaError> {
2003        let slice = source_type(c, src);
2004        self.apply_slice(slice)
2005    }
2006
2007    pub fn with_default_candles(c: &Candles) -> Result<DmaBatchOutput, DmaError> {
2008        DmaBatchBuilder::new()
2009            .kernel(Kernel::Auto)
2010            .apply_candles(c, "close")
2011    }
2012}
2013
/// Result of a batch DMA sweep.
///
/// `values` is a row-major `rows × cols` matrix. Row `r` holds the DMA series computed with
/// `combos[r]`, and `cols` is the length of the input series.
#[derive(Clone, Debug)]
pub struct DmaBatchOutput {
    /// Row-major output values (`rows * cols` entries).
    pub values: Vec<f64>,
    /// Parameter combination for each row, in row order.
    pub combos: Vec<DmaParams>,
    /// Number of rows (equal to `combos.len()`).
    pub rows: usize,
    /// Number of columns (equal to the input length).
    pub cols: usize,
}
2021
2022impl DmaBatchOutput {
2023    pub fn row_for_params(&self, p: &DmaParams) -> Option<usize> {
2024        self.combos.iter().position(|c| {
2025            c.hull_length.unwrap_or(7) == p.hull_length.unwrap_or(7)
2026                && c.ema_length.unwrap_or(20) == p.ema_length.unwrap_or(20)
2027                && c.ema_gain_limit.unwrap_or(50) == p.ema_gain_limit.unwrap_or(50)
2028                && c.hull_ma_type.as_ref().unwrap_or(&"WMA".to_string())
2029                    == p.hull_ma_type.as_ref().unwrap_or(&"WMA".to_string())
2030        })
2031    }
2032
2033    pub fn values_for(&self, p: &DmaParams) -> Option<&[f64]> {
2034        self.row_for_params(p).map(|row| {
2035            let start = row * self.cols;
2036            &self.values[start..start + self.cols]
2037        })
2038    }
2039}
2040
2041#[inline(always)]
2042fn expand_grid_dma(r: &DmaBatchRange) -> Vec<DmaParams> {
2043    fn axis_usize((start, end, step): (usize, usize, usize)) -> Vec<usize> {
2044        if step == 0 || start == end {
2045            return vec![start];
2046        }
2047        if start < end {
2048            return (start..=end).step_by(step).collect();
2049        }
2050
2051        let mut v: Vec<usize> = (end..=start).step_by(step).collect();
2052        v.reverse();
2053        v
2054    }
2055
2056    let hull_lengths = axis_usize(r.hull_length);
2057    let ema_lengths = axis_usize(r.ema_length);
2058    let ema_gain_limits = axis_usize(r.ema_gain_limit);
2059
2060    let mut combos = Vec::new();
2061    for &h in &hull_lengths {
2062        for &e in &ema_lengths {
2063            for &g in &ema_gain_limits {
2064                combos.push(DmaParams {
2065                    hull_length: Some(h),
2066                    ema_length: Some(e),
2067                    ema_gain_limit: Some(g),
2068                    hull_ma_type: Some(r.hull_ma_type.clone()),
2069                });
2070            }
2071        }
2072    }
2073    combos
2074}
2075
2076#[inline(always)]
2077pub fn dma_batch_slice(
2078    data: &[f64],
2079    sweep: &DmaBatchRange,
2080    kern: Kernel,
2081) -> Result<DmaBatchOutput, DmaError> {
2082    dma_batch_inner(data, sweep, kern, false)
2083}
2084
2085#[inline(always)]
2086pub fn dma_batch_par_slice(
2087    data: &[f64],
2088    sweep: &DmaBatchRange,
2089    kern: Kernel,
2090) -> Result<DmaBatchOutput, DmaError> {
2091    dma_batch_inner(data, sweep, kern, true)
2092}
2093
2094#[inline(always)]
2095fn dma_batch_inner(
2096    data: &[f64],
2097    sweep: &DmaBatchRange,
2098    kern: Kernel,
2099    parallel: bool,
2100) -> Result<DmaBatchOutput, DmaError> {
2101    let combos = expand_grid_dma(sweep);
2102    let cols = data.len();
2103    let rows = combos.len();
2104    if cols == 0 {
2105        return Err(DmaError::EmptyInputData);
2106    }
2107    if rows == 0 {
2108        return Err(DmaError::InvalidRange {
2109            start: sweep.hull_length.0,
2110            end: sweep.hull_length.1,
2111            step: sweep.hull_length.2,
2112        });
2113    }
2114
2115    let _cap = rows.checked_mul(cols).ok_or(DmaError::InvalidRange {
2116        start: rows,
2117        end: cols,
2118        step: 0,
2119    })?;
2120
2121    let mut buf_mu = make_uninit_matrix(rows, cols);
2122
2123    let first = data
2124        .iter()
2125        .position(|x| !x.is_nan())
2126        .ok_or(DmaError::AllValuesNaN)?;
2127    let warm: Vec<usize> = combos
2128        .iter()
2129        .map(|c| {
2130            let h = c.hull_length.unwrap();
2131            let e = c.ema_length.unwrap();
2132            let sqrt_len = (h as f64).sqrt().round() as usize;
2133            first + h.max(e) + sqrt_len - 1
2134        })
2135        .collect();
2136    init_matrix_prefixes(&mut buf_mu, cols, &warm);
2137
2138    let mut guard = core::mem::ManuallyDrop::new(buf_mu);
2139    let out: &mut [f64] =
2140        unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len()) };
2141
2142    dma_batch_inner_into(data, sweep, kern, parallel, out)?;
2143
2144    let values = unsafe {
2145        Vec::from_raw_parts(
2146            guard.as_mut_ptr() as *mut f64,
2147            guard.len(),
2148            guard.capacity(),
2149        )
2150    };
2151
2152    Ok(DmaBatchOutput {
2153        values,
2154        combos,
2155        rows,
2156        cols,
2157    })
2158}
2159
2160pub fn dma_batch_with_kernel(
2161    data: &[f64],
2162    sweep: &DmaBatchRange,
2163    k: Kernel,
2164) -> Result<DmaBatchOutput, DmaError> {
2165    let kernel = match k {
2166        Kernel::Auto => detect_best_batch_kernel(),
2167        other if other.is_batch() => other,
2168        other => return Err(DmaError::InvalidKernelForBatch(other)),
2169    };
2170
2171    let simd = match kernel {
2172        Kernel::Avx512Batch => Kernel::Avx512,
2173        Kernel::Avx2Batch => Kernel::Avx2,
2174        Kernel::ScalarBatch => Kernel::Scalar,
2175        _ => unreachable!(),
2176    };
2177    dma_batch_par_slice(data, sweep, simd)
2178}
2179
2180#[inline(always)]
2181fn dma_batch_inner_into(
2182    data: &[f64],
2183    sweep: &DmaBatchRange,
2184    k: Kernel,
2185    parallel: bool,
2186    out: &mut [f64],
2187) -> Result<Vec<DmaParams>, DmaError> {
2188    let combos = expand_grid_dma(sweep);
2189    if combos.is_empty() {
2190        return Err(DmaError::InvalidRange {
2191            start: sweep.hull_length.0,
2192            end: sweep.hull_length.1,
2193            step: sweep.hull_length.2,
2194        });
2195    }
2196
2197    let first = data
2198        .iter()
2199        .position(|x| !x.is_nan())
2200        .ok_or(DmaError::AllValuesNaN)?;
2201    let cols = data.len();
2202
2203    let actual = match k {
2204        Kernel::Auto => detect_best_batch_kernel(),
2205        other => other,
2206    };
2207    let simd = match actual {
2208        Kernel::Avx512Batch => Kernel::Avx512,
2209        Kernel::Avx2Batch => Kernel::Avx2,
2210        Kernel::ScalarBatch => Kernel::Scalar,
2211        _ => actual,
2212    };
2213
2214    let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| {
2215        let dst = unsafe {
2216            core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len())
2217        };
2218        let prm = &combos[row];
2219        let hull_len = prm.hull_length.unwrap_or(7);
2220        let ema_len = prm.ema_length.unwrap_or(20);
2221
2222        let sqrt_len = (hull_len as f64).sqrt().round() as usize;
2223        let warmup_end = first + hull_len.max(ema_len) + sqrt_len - 1;
2224        let warmup_end = warmup_end.min(dst.len());
2225
2226        for i in 0..warmup_end {
2227            dst[i] = f64::NAN;
2228        }
2229
2230        dma_compute_into(
2231            data,
2232            hull_len,
2233            ema_len,
2234            prm.ema_gain_limit.unwrap_or(50),
2235            prm.hull_ma_type.as_ref().unwrap_or(&"WMA".to_string()),
2236            first,
2237            simd,
2238            dst,
2239        );
2240    };
2241
2242    let dst_mu = unsafe {
2243        std::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
2244    };
2245
2246    if parallel {
2247        #[cfg(not(target_arch = "wasm32"))]
2248        dst_mu
2249            .par_chunks_mut(cols)
2250            .enumerate()
2251            .for_each(|(r, row)| do_row(r, row));
2252        #[cfg(target_arch = "wasm32")]
2253        for (r, row) in dst_mu.chunks_mut(cols).enumerate() {
2254            do_row(r, row);
2255        }
2256    } else {
2257        for (r, row) in dst_mu.chunks_mut(cols).enumerate() {
2258            do_row(r, row);
2259        }
2260    }
2261
2262    Ok(combos)
2263}
2264
/// Python binding `dma`: computes the DMA of a 1-D float64 array.
///
/// `kernel` is validated with `validate_kernel(kernel, false)`. The computation runs with
/// the GIL released. Any DMA error is raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "dma")]
#[pyo3(signature = (data, hull_length=7, ema_length=20, ema_gain_limit=50, hull_ma_type="WMA", kernel=None))]
pub fn dma_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    hull_length: usize,
    ema_length: usize,
    ema_gain_limit: usize,
    hull_ma_type: &str,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let series = data.as_slice()?;
    let chosen_kernel = validate_kernel(kernel, false)?;
    let input = DmaInput::from_slice(
        series,
        DmaParams {
            hull_length: Some(hull_length),
            ema_length: Some(ema_length),
            ema_gain_limit: Some(ema_gain_limit),
            hull_ma_type: Some(hull_ma_type.to_string()),
        },
    );

    // Only the finished values cross back from the GIL-free section.
    let outcome: Result<Vec<f64>, _> =
        py.allow_threads(|| dma_with_kernel(&input, chosen_kernel).map(|o| o.values));

    match outcome {
        Ok(values) => Ok(values.into_pyarray(py)),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
2293
/// Python class `DmaStream`: a thin wrapper that owns a [`DmaStream`] and forwards
/// single-sample updates to it.
#[cfg(feature = "python")]
#[pyclass(name = "DmaStream")]
pub struct DmaStreamPy {
    /// The wrapped streaming DMA state.
    stream: DmaStream,
}
2299
#[cfg(feature = "python")]
#[pymethods]
impl DmaStreamPy {
    /// Creates a streaming DMA with the given parameters.
    ///
    /// Raises `ValueError` if `DmaStream::try_new` rejects the parameters.
    #[new]
    fn new(
        hull_length: usize,
        ema_length: usize,
        ema_gain_limit: usize,
        hull_ma_type: &str,
    ) -> PyResult<Self> {
        DmaStream::try_new(DmaParams {
            hull_length: Some(hull_length),
            ema_length: Some(ema_length),
            ema_gain_limit: Some(ema_gain_limit),
            hull_ma_type: Some(hull_ma_type.to_string()),
        })
        .map(|stream| DmaStreamPy { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feeds one sample and returns the stream's output for it.
    ///
    /// The value is `None` (Python `None`) whenever the underlying stream has no value yet.
    fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
2325
/// Python binding `dma_batch`: evaluates a DMA parameter sweep over a 1-D float64 array.
///
/// Returns a dict with these entries:
/// - `values`: a `(rows, cols)` array;
/// - `hull_lengths`, `ema_lengths`, `ema_gain_limits`: per-row parameter arrays;
/// - `hull_ma_type`: the requested type;
/// - `hull_ma_types`: a per-row list.
///
/// `kernel` is validated and resolved before any output is allocated. Invalid kernels and
/// DMA errors are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "dma_batch")]
#[pyo3(signature = (data, hull_length_range, ema_length_range, ema_gain_limit_range, hull_ma_type="WMA", kernel=None))]
pub fn dma_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    hull_length_range: (usize, usize, usize),
    ema_length_range: (usize, usize, usize),
    ema_gain_limit_range: (usize, usize, usize),
    hull_ma_type: &str,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let slice_in = data.as_slice()?;

    // Resolve the kernel up front so a rejected kernel never allocates the output array.
    let kern = validate_kernel(kernel, true)?;
    let batch_kernel = match kern {
        Kernel::Auto => detect_best_batch_kernel(),
        k => k,
    };
    let simd = match batch_kernel {
        Kernel::Avx512Batch => Kernel::Avx512,
        Kernel::Avx2Batch => Kernel::Avx2,
        Kernel::ScalarBatch => Kernel::Scalar,
        other => {
            return Err(PyValueError::new_err(
                DmaError::InvalidKernelForBatch(other).to_string(),
            ))
        }
    };

    let sweep = DmaBatchRange {
        hull_length: hull_length_range,
        ema_length: ema_length_range,
        ema_gain_limit: ema_gain_limit_range,
        hull_ma_type: hull_ma_type.to_string(),
    };

    let combos = expand_grid_dma(&sweep);
    let rows = combos.len();
    let cols = slice_in.len();
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;

    // SAFETY: the array is created uninitialized. `dma_batch_inner_into` writes each row's NaN
    // warm-up prefix itself and relies on `dma_compute_into` for the remaining cells.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| dma_batch_inner_into(slice_in, &sweep, simd, true, slice_out))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "hull_lengths",
        combos
            .iter()
            .map(|p| p.hull_length.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "ema_lengths",
        combos
            .iter()
            .map(|p| p.ema_length.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "ema_gain_limits",
        combos
            .iter()
            .map(|p| p.ema_gain_limit.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item("hull_ma_type", hull_ma_type)?;

    dict.set_item(
        "hull_ma_types",
        combos
            .iter()
            .map(|p| p.hull_ma_type.as_deref().unwrap_or("WMA"))
            .collect::<Vec<_>>(),
    )?;

    Ok(dict)
}
2413
/// Python binding `dma_cuda_batch_dev`: evaluates a DMA parameter sweep on a CUDA device.
///
/// `data_f32` is the input series as `f32`. The three `(start, end, step)` ranges and
/// `hull_ma_type` form a [`DmaBatchRange`] that is passed to `CudaDma::dma_batch_dev`. The
/// resulting device array is returned wrapped in `DmaDeviceArrayF32Py`, together with the
/// context and device id reported by the `CudaDma` instance.
///
/// Raises `ValueError` if CUDA is unavailable, if `CudaDma::new` fails, or if
/// `dma_batch_dev` fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "dma_cuda_batch_dev")]
#[pyo3(signature = (data_f32, hull_length_range, ema_length_range, ema_gain_limit_range, hull_ma_type="WMA", device_id=0))]
pub fn dma_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: numpy::PyReadonlyArray1<'_, f32>,
    hull_length_range: (usize, usize, usize),
    ema_length_range: (usize, usize, usize),
    ema_gain_limit_range: (usize, usize, usize),
    hull_ma_type: &str,
    device_id: usize,
) -> PyResult<DmaDeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let sweep = DmaBatchRange {
        hull_length: hull_length_range,
        ema_length: ema_length_range,
        ema_gain_limit: ema_gain_limit_range,
        hull_ma_type: hull_ma_type.to_string(),
    };

    let slice_in = data_f32.as_slice()?;
    // Device setup and the sweep run with the GIL released.
    let (inner, ctx, dev_id) = py.allow_threads(|| {
        let cuda = CudaDma::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context();
        let dev_id = cuda.device_id();
        let arr = cuda
            .dma_batch_dev(slice_in, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, pyo3::PyErr>((arr, ctx, dev_id))
    })?;

    Ok(DmaDeviceArrayF32Py {
        inner: Some(inner),
        // NOTE(review): presumably stored so the CUDA context outlives `inner`; confirm in
        // DmaDeviceArrayF32Py.
        _ctx: ctx,
        device_id: dev_id,
    })
}
2454
/// Python binding `dma_cuda_many_series_one_param_dev`: runs one DMA parameter set across
/// many series on a CUDA device.
///
/// `data_tm_f32` is a 2-D `f32` array whose flat contents are passed, together with its two
/// dimensions and the parameters, to `CudaDma::dma_many_series_one_param_time_major_dev`.
/// The resulting device array is returned wrapped in `DmaDeviceArrayF32Py`.
///
/// Raises `ValueError` if CUDA is unavailable or if the `CudaDma` calls fail.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "dma_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, hull_length, ema_length, ema_gain_limit, hull_ma_type="WMA", device_id=0))]
pub fn dma_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: PyReadonlyArray2<'_, f32>,
    hull_length: usize,
    ema_length: usize,
    ema_gain_limit: usize,
    hull_ma_type: &str,
    device_id: usize,
) -> PyResult<DmaDeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let flat_in: &[f32] = data_tm_f32.as_slice()?;
    // NOTE(review): for a time-major layout shape[0] is presumably time and shape[1] the
    // series count; the callee receives (cols, rows). Confirm against CudaDma's signature.
    let rows = data_tm_f32.shape()[0];
    let cols = data_tm_f32.shape()[1];
    let params = DmaParams {
        hull_length: Some(hull_length),
        ema_length: Some(ema_length),
        ema_gain_limit: Some(ema_gain_limit),
        hull_ma_type: Some(hull_ma_type.to_string()),
    };

    // Device setup and the computation run with the GIL released.
    let (inner, ctx, dev_id) = py.allow_threads(|| {
        let cuda = CudaDma::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context();
        let dev_id = cuda.device_id();
        let arr = cuda
            .dma_many_series_one_param_time_major_dev(flat_in, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, pyo3::PyErr>((arr, ctx, dev_id))
    })?;

    Ok(DmaDeviceArrayF32Py {
        inner: Some(inner),
        // NOTE(review): presumably stored so the CUDA context outlives `inner`; confirm.
        _ctx: ctx,
        device_id: dev_id,
    })
}
2497
/// WASM binding: computes the DMA of `data` with `Kernel::Auto` and returns a new array of
/// the same length. DMA errors are returned as their display string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dma_js(
    data: &[f64],
    hull_length: usize,
    ema_length: usize,
    ema_gain_limit: usize,
    hull_ma_type: &str,
) -> Result<Vec<f64>, JsValue> {
    let input = DmaInput::from_slice(
        data,
        DmaParams {
            hull_length: Some(hull_length),
            ema_length: Some(ema_length),
            ema_gain_limit: Some(ema_gain_limit),
            hull_ma_type: Some(hull_ma_type.to_string()),
        },
    );

    let mut result = vec![0.0; data.len()];
    match dma_into_slice(&mut result, &input, Kernel::Auto) {
        Ok(_) => Ok(result),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
2521
/// WASM helper: reserves room for `len` f64 values and hands ownership to the caller.
///
/// The memory is uninitialized. It must be released with `dma_free(ptr, len)` using the same
/// `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dma_alloc(len: usize) -> *mut f64 {
    // Suppress the Vec's destructor so the allocation survives this call.
    let mut reserved = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    reserved.as_mut_ptr()
}
2530
/// WASM helper: releases a buffer obtained from `dma_alloc(len)`. `len` must match the value
/// passed to `dma_alloc`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dma_free(ptr: *mut f64, len: usize) {
    // SAFETY: the caller guarantees `ptr` came from `dma_alloc(len)`; rebuilding the Vec and
    // dropping it returns the allocation to the allocator.
    let reclaimed = unsafe { Vec::from_raw_parts(ptr, len, len) };
    drop(reclaimed);
}
2538
/// WASM binding: computes the DMA of `len` values at `in_ptr` and writes `len` results to
/// `out_ptr`, using `Kernel::Auto`.
///
/// If `in_ptr == out_ptr`, the result is computed into a scratch buffer and copied back, so
/// the call can run in place. Null pointers and DMA errors are returned as error strings.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dma_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    hull_length: usize,
    ema_length: usize,
    ema_gain_limit: usize,
    hull_ma_type: &str,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to dma_into"));
    }

    // SAFETY: the caller guarantees `in_ptr` points to `len` readable f64 values.
    let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };
    let params = DmaParams {
        hull_length: Some(hull_length),
        ema_length: Some(ema_length),
        ema_gain_limit: Some(ema_gain_limit),
        hull_ma_type: Some(hull_ma_type.to_string()),
    };
    let input = DmaInput::from_slice(data, params);

    let in_place = in_ptr == out_ptr;
    if in_place {
        let mut scratch = vec![0.0; len];
        dma_into_slice(&mut scratch, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // SAFETY: the caller guarantees `out_ptr` points to `len` writable f64 values; the
        // input view is no longer used.
        let dest = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        dest.copy_from_slice(&scratch);
    } else {
        // SAFETY: the caller guarantees `out_ptr` points to `len` writable f64 values.
        let dest = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        dma_into_slice(dest, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(())
}
2580
/// Batch configuration deserialized from the JS object passed to `dma_batch`.
///
/// Each field maps onto the matching field of [`DmaBatchRange`].
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct DmaBatchConfig {
    /// `(start, end, step)` for the hull length.
    pub hull_length_range: (usize, usize, usize),
    /// `(start, end, step)` for the EMA length.
    pub ema_length_range: (usize, usize, usize),
    /// `(start, end, step)` for the EMA gain limit.
    pub ema_gain_limit_range: (usize, usize, usize),
    /// Hull smoothing type applied to every combination.
    pub hull_ma_type: String,
}
2589
/// Result object serialized back to JS by `dma_batch`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct DmaBatchJsOutput {
    /// Row-major `rows * cols` output values.
    pub values: Vec<f64>,
    /// Parameter combination for each row, in row order.
    pub combos: Vec<DmaParams>,
    /// Number of rows (combinations).
    pub rows: usize,
    /// Number of columns (input length).
    pub cols: usize,
}
2598
/// WASM binding `dma_batch`: evaluates a DMA parameter sweep described by a JS config
/// object.
///
/// Returns a serialized `DmaBatchJsOutput`. Errors are returned as strings in these cases:
/// - an invalid config;
/// - an empty sweep;
/// - all-NaN (or empty) input;
/// - any failure from the batch computation or serialization.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = dma_batch)]
pub fn dma_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let cfg: DmaBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {e}")))?;

    let sweep = DmaBatchRange {
        hull_length: cfg.hull_length_range,
        ema_length: cfg.ema_length_range,
        ema_gain_limit: cfg.ema_gain_limit_range,
        hull_ma_type: cfg.hull_ma_type,
    };

    let combos = expand_grid_dma(&sweep);
    let rows = combos.len();
    let cols = data.len();
    if rows == 0 {
        return Err(JsValue::from_str("no parameter combinations"));
    }

    // Validate the input before allocating the output matrix.
    let first = data
        .iter()
        .position(|x| !x.is_nan())
        .ok_or_else(|| JsValue::from_str("All NaN"))?;
    let warm: Vec<usize> = combos
        .iter()
        .map(|c| {
            let h = c.hull_length.unwrap();
            let e = c.ema_length.unwrap();
            let sqrt_len = (h as f64).sqrt().round() as usize;
            // Saturating against underflow; clamped so long periods on short inputs never
            // reach past the end of the row.
            first
                .saturating_add(h.max(e))
                .saturating_add(sqrt_len)
                .saturating_sub(1)
                .min(cols)
        })
        .collect();

    let mut buf_mu = make_uninit_matrix(rows, cols);
    init_matrix_prefixes(&mut buf_mu, cols, &warm);

    // SAFETY: `MaybeUninit<f64>` has the same layout as `f64` and `buf_mu` holds `rows * cols`
    // slots; prefixes are initialized above and the batch computation writes the rest.
    let out: &mut [f64] =
        unsafe { core::slice::from_raw_parts_mut(buf_mu.as_mut_ptr() as *mut f64, buf_mu.len()) };

    // On error `buf_mu` is still an ordinary owned Vec, so it is freed instead of leaked.
    dma_batch_inner_into(data, &sweep, detect_best_kernel(), false, out)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let mut guard = core::mem::ManuallyDrop::new(buf_mu);
    // SAFETY: pointer, length and capacity come from the same Vec, whose destructor is
    // suppressed, so ownership of the allocation moves exactly once.
    let values = unsafe {
        Vec::from_raw_parts(
            guard.as_mut_ptr() as *mut f64,
            guard.len(),
            guard.capacity(),
        )
    };
    let js = DmaBatchJsOutput {
        values,
        combos,
        rows,
        cols,
    };
    serde_wasm_bindgen::to_value(&js)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {e}")))
}
2658
/// WASM binding: evaluates a DMA parameter sweep from raw pointers.
///
/// Reads `len` values at `in_ptr` and writes a row-major `rows × len` matrix to `out_ptr`,
/// which must have room for `rows * len` f64 values. Returns `rows`, the number of parameter
/// combinations.
///
/// If `out_ptr == in_ptr`, the input is copied first so writing row 0 cannot corrupt data
/// that later rows still read; this is the same approach `dma_into` uses. Partially
/// overlapping buffers are not supported.
///
/// Errors are returned as strings in these cases:
/// - null pointers;
/// - `rows * len` overflow;
/// - all-NaN (or empty) input;
/// - any failure from the batch computation.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dma_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    hull_start: usize,
    hull_end: usize,
    hull_step: usize,
    ema_start: usize,
    ema_end: usize,
    ema_step: usize,
    gain_start: usize,
    gain_end: usize,
    gain_step: usize,
    hull_ma_type: &str,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to dma_batch_into"));
    }

    // SAFETY: the caller guarantees `in_ptr` points to `len` readable f64 values.
    let input = unsafe { std::slice::from_raw_parts(in_ptr, len) };
    let copied;
    let data: &[f64] = if in_ptr == out_ptr {
        copied = input.to_vec();
        &copied
    } else {
        input
    };

    let sweep = DmaBatchRange {
        hull_length: (hull_start, hull_end, hull_step),
        ema_length: (ema_start, ema_end, ema_step),
        ema_gain_limit: (gain_start, gain_end, gain_step),
        hull_ma_type: hull_ma_type.to_string(),
    };
    let combos = expand_grid_dma(&sweep);
    let rows = combos.len();
    let cols = len;
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;

    let first = data
        .iter()
        .position(|x| !x.is_nan())
        .ok_or_else(|| JsValue::from_str("All NaN"))?;
    let warm: Vec<usize> = combos
        .iter()
        .map(|c| {
            let h = c.hull_length.unwrap();
            let e = c.ema_length.unwrap();
            let sqrt_len = (h as f64).sqrt().round() as usize;
            // Saturating against underflow; clamped so long periods on short inputs never
            // reach past the end of the row.
            first
                .saturating_add(h.max(e))
                .saturating_add(sqrt_len)
                .saturating_sub(1)
                .min(cols)
        })
        .collect();

    // SAFETY: the caller guarantees `out_ptr` points to `rows * len` writable f64 slots.
    let out_mu =
        unsafe { std::slice::from_raw_parts_mut(out_ptr as *mut MaybeUninit<f64>, total) };
    init_matrix_prefixes(out_mu, cols, &warm);

    // SAFETY: same region as `out_mu`, which is not used again.
    let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };
    dma_batch_inner_into(data, &sweep, detect_best_kernel(), false, out)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    Ok(rows)
}
2714
2715#[cfg(test)]
2716mod tests {
2717    use super::*;
2718    use crate::skip_if_unsupported;
2719    use crate::utilities::data_loader::read_candles_from_csv;
2720    use std::error::Error;
2721
2722    fn check_dma_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2723        skip_if_unsupported!(kernel, test_name);
2724        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2725        let candles = read_candles_from_csv(file_path)?;
2726
2727        let input = DmaInput::from_candles(&candles, "close", DmaParams::default());
2728        let result = dma_with_kernel(&input, kernel)?;
2729
2730        let expected_last_five = [
2731            59404.62489256,
2732            59326.48766951,
2733            59195.35128538,
2734            59153.22811529,
2735            58933.88503421,
2736        ];
2737
2738        let start = result.values.len().saturating_sub(5);
2739        for (i, &val) in result.values[start..].iter().enumerate() {
2740            let diff = (val - expected_last_five[i]).abs();
2741            assert!(
2742                diff < 0.001,
2743                "[{}] DMA {:?} mismatch at idx {}: got {}, expected {}, diff {}",
2744                test_name,
2745                kernel,
2746                i,
2747                val,
2748                expected_last_five[i],
2749                diff
2750            );
2751        }
2752        Ok(())
2753    }
2754
2755    fn check_dma_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2756        skip_if_unsupported!(kernel, test_name);
2757        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2758        let candles = read_candles_from_csv(file_path)?;
2759
2760        let default_params = DmaParams {
2761            hull_length: None,
2762            ema_length: None,
2763            ema_gain_limit: None,
2764            hull_ma_type: None,
2765        };
2766        let input = DmaInput::from_candles(&candles, "close", default_params);
2767        let output = dma_with_kernel(&input, kernel)?;
2768        assert_eq!(output.values.len(), candles.close.len());
2769
2770        Ok(())
2771    }
2772
2773    fn check_dma_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2774        skip_if_unsupported!(kernel, test_name);
2775        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2776        let candles = read_candles_from_csv(file_path)?;
2777
2778        let input = DmaInput::with_default_candles(&candles);
2779        match input.data {
2780            DmaData::Candles { source, .. } => assert_eq!(source, "close"),
2781            _ => panic!("Expected DmaData::Candles"),
2782        }
2783        let output = dma_with_kernel(&input, kernel)?;
2784        assert_eq!(output.values.len(), candles.close.len());
2785
2786        Ok(())
2787    }
2788
2789    fn check_dma_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2790        skip_if_unsupported!(kernel, test_name);
2791        let input_data = [10.0, 20.0, 30.0];
2792        let params = DmaParams {
2793            hull_length: Some(0),
2794            ema_length: None,
2795            ema_gain_limit: None,
2796            hull_ma_type: None,
2797        };
2798        let input = DmaInput::from_slice(&input_data, params);
2799        let res = dma_with_kernel(&input, kernel);
2800        assert!(
2801            res.is_err(),
2802            "[{}] DMA should fail with zero period",
2803            test_name
2804        );
2805        Ok(())
2806    }
2807
2808    fn check_dma_period_exceeds_length(
2809        test_name: &str,
2810        kernel: Kernel,
2811    ) -> Result<(), Box<dyn Error>> {
2812        skip_if_unsupported!(kernel, test_name);
2813        let data_small = [10.0, 20.0, 30.0];
2814        let params = DmaParams {
2815            hull_length: Some(10),
2816            ema_length: None,
2817            ema_gain_limit: None,
2818            hull_ma_type: None,
2819        };
2820        let input = DmaInput::from_slice(&data_small, params);
2821        let res = dma_with_kernel(&input, kernel);
2822        assert!(
2823            res.is_err(),
2824            "[{}] DMA should fail with period exceeding length",
2825            test_name
2826        );
2827        Ok(())
2828    }
2829
2830    fn check_dma_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2831        skip_if_unsupported!(kernel, test_name);
2832        let single_point = [42.0];
2833        let params = DmaParams::default();
2834        let input = DmaInput::from_slice(&single_point, params);
2835        let res = dma_with_kernel(&input, kernel);
2836        assert!(
2837            res.is_err(),
2838            "[{}] DMA should fail with insufficient data",
2839            test_name
2840        );
2841        Ok(())
2842    }
2843
2844    fn check_dma_empty_input(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2845        skip_if_unsupported!(kernel, test_name);
2846        let empty: [f64; 0] = [];
2847        let params = DmaParams::default();
2848        let input = DmaInput::from_slice(&empty, params);
2849        let res = dma_with_kernel(&input, kernel);
2850        assert!(
2851            res.is_err(),
2852            "[{}] DMA should fail with empty input",
2853            test_name
2854        );
2855        Ok(())
2856    }
2857
2858    fn check_dma_all_nan(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2859        skip_if_unsupported!(kernel, test_name);
2860        let nan_data = [f64::NAN, f64::NAN, f64::NAN];
2861        let params = DmaParams::default();
2862        let input = DmaInput::from_slice(&nan_data, params);
2863        let res = dma_with_kernel(&input, kernel);
2864        assert!(
2865            res.is_err(),
2866            "[{}] DMA should fail with all NaN values",
2867            test_name
2868        );
2869        Ok(())
2870    }
2871
2872    fn check_dma_invalid_hull_type(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2873        skip_if_unsupported!(kernel, test_name);
2874        let data = [10.0; 50];
2875        let params = DmaParams {
2876            hull_length: Some(7),
2877            ema_length: Some(20),
2878            ema_gain_limit: Some(50),
2879            hull_ma_type: Some("INVALID".to_string()),
2880        };
2881        let input = DmaInput::from_slice(&data, params);
2882        let res = dma_with_kernel(&input, kernel);
2883        assert!(
2884            res.is_err(),
2885            "[{}] DMA should fail with invalid hull_ma_type",
2886            test_name
2887        );
2888        Ok(())
2889    }
2890
2891    macro_rules! generate_all_dma_tests {
2892        ($($test_fn:ident),*) => {
2893            paste::paste! {
2894                $(
2895                    #[test] fn [<$test_fn _scalar>]() -> Result<(), Box<dyn Error>> { $test_fn(stringify!([<$test_fn _scalar>]), Kernel::Scalar) }
2896                    #[test] fn [<$test_fn _auto>  ]() -> Result<(), Box<dyn Error>> { $test_fn(stringify!([<$test_fn _auto>]),   Kernel::Auto) }
2897                )*
2898                #[cfg(all(feature="nightly-avx", target_arch="x86_64"))]
2899                $(
2900                    #[test] fn [<$test_fn _avx2>  ]() -> Result<(), Box<dyn Error>> { $test_fn(stringify!([<$test_fn _avx2>]),   Kernel::Avx2) }
2901                    #[test] fn [<$test_fn _avx512>]() -> Result<(), Box<dyn Error>> { $test_fn(stringify!([<$test_fn _avx512>]), Kernel::Avx512) }
2902                )*
2903            }
2904        }
2905    }
2906
2907    generate_all_dma_tests!(
2908        check_dma_accuracy,
2909        check_dma_partial_params,
2910        check_dma_default_candles,
2911        check_dma_zero_period,
2912        check_dma_period_exceeds_length,
2913        check_dma_very_small_dataset,
2914        check_dma_empty_input,
2915        check_dma_all_nan,
2916        check_dma_invalid_hull_type
2917    );
2918
2919    macro_rules! generate_dma_batch_tests {
2920        ($($fn_name:ident),*) => {
2921            paste::paste! {
2922                $(
2923                    #[test]
2924                    fn [<$fn_name _scalar_batch>]() -> Result<(), Box<dyn Error>> {
2925                        $fn_name(stringify!([<$fn_name _scalar_batch>]), Kernel::ScalarBatch)
2926                    }
2927                )*
2928                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2929                $(
2930                    #[test]
2931                    fn [<$fn_name _avx2_batch>]() -> Result<(), Box<dyn Error>> {
2932                        $fn_name(stringify!([<$fn_name _avx2_batch>]), Kernel::Avx2Batch)
2933                    }
2934                    #[test]
2935                    fn [<$fn_name _avx512_batch>]() -> Result<(), Box<dyn Error>> {
2936                        $fn_name(stringify!([<$fn_name _avx512_batch>]), Kernel::Avx512Batch)
2937                    }
2938                )*
2939            }
2940        };
2941    }
2942
2943    generate_dma_batch_tests!(check_dma_batch_basic);
2944
2945    macro_rules! gen_batch_tests {
2946        ($fn_name:ident) => {
2947            paste::paste! {
2948                #[test] fn [<$fn_name _scalar>]()      { let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch); }
2949                #[cfg(all(feature="nightly-avx", target_arch="x86_64"))]
2950                #[test] fn [<$fn_name _avx2>]()        { let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch); }
2951                #[cfg(all(feature="nightly-avx", target_arch="x86_64"))]
2952                #[test] fn [<$fn_name _avx512>]()      { let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch); }
2953                #[test] fn [<$fn_name _auto_detect>]() { let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto); }
2954            }
2955        };
2956    }
2957
2958    gen_batch_tests!(check_batch_sweep);
2959
2960    fn check_dma_reinput(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2961        skip_if_unsupported!(kernel, test);
2962        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2963        let c = read_candles_from_csv(file)?;
2964
2965        let first = DmaInput::from_candles(&c, "close", DmaParams::default());
2966        let out1 = dma_with_kernel(&first, kernel)?.values;
2967
2968        let second = DmaInput::from_slice(&out1, DmaParams::default());
2969        let out2 = dma_with_kernel(&second, kernel)?.values;
2970
2971        assert_eq!(out2.len(), out1.len());
2972        Ok(())
2973    }
2974
2975    fn check_dma_nan_handling(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2976        skip_if_unsupported!(kernel, test);
2977        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2978        let c = read_candles_from_csv(file)?;
2979
2980        let p = DmaParams::default();
2981        let input = DmaInput::from_candles(&c, "close", p.clone());
2982        let out = dma_with_kernel(&input, kernel)?.values;
2983
2984        let first = c.close.iter().position(|x| !x.is_nan()).unwrap_or(0);
2985        let sqrt_len = (p.hull_length.unwrap_or(7) as f64).sqrt().round() as usize;
2986        let warm =
2987            first + p.hull_length.unwrap_or(7).max(p.ema_length.unwrap_or(20)) + sqrt_len - 1;
2988        for (i, &v) in out.iter().enumerate().skip(warm.min(out.len())) {
2989            assert!(!v.is_nan(), "[{test}] unexpected NaN at {i}");
2990        }
2991        Ok(())
2992    }
2993
2994    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2995        skip_if_unsupported!(kernel, test);
2996        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2997        let c = read_candles_from_csv(file)?;
2998
2999        let out = DmaBatchBuilder::new()
3000            .kernel(kernel)
3001            .apply_candles(&c, "close")?;
3002        let def = DmaParams::default();
3003        let row = out.values_for(&def).expect("default row missing");
3004        assert_eq!(row.len(), c.close.len());
3005        Ok(())
3006    }
3007
3008    fn check_batch_sweep(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
3009        skip_if_unsupported!(kernel, test);
3010        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3011        let c = read_candles_from_csv(file)?;
3012        let out = DmaBatchBuilder::new()
3013            .kernel(kernel)
3014            .hull_length_range(7, 18, 1)
3015            .ema_length_range(10, 15, 1)
3016            .ema_gain_limit_range(10, 20, 5)
3017            .apply_candles(&c, "close")?;
3018        let expected = 12 * 6 * 3;
3019        assert_eq!(out.combos.len(), expected);
3020        assert_eq!(out.rows, expected);
3021        assert_eq!(out.cols, c.close.len());
3022        Ok(())
3023    }
3024
3025    fn check_dma_streaming(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
3026        skip_if_unsupported!(kernel, test);
3027        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3028        let c = read_candles_from_csv(file)?;
3029        let p = DmaParams::default();
3030
3031        let batch =
3032            dma_with_kernel(&DmaInput::from_candles(&c, "close", p.clone()), kernel)?.values;
3033
3034        let mut s = DmaStream::try_new(p)?;
3035        let mut stream = Vec::with_capacity(c.close.len());
3036        for &x in &c.close {
3037            stream.push(s.update(x).unwrap_or(f64::NAN));
3038        }
3039
3040        assert_eq!(batch.len(), stream.len());
3041        for (i, (&b, &t)) in batch.iter().zip(&stream).enumerate() {
3042            if b.is_nan() && t.is_nan() {
3043                continue;
3044            }
3045            assert!(
3046                (b - t).abs() < 1e-9,
3047                "[{test}] idx {i} diff {}",
3048                (b - t).abs()
3049            );
3050        }
3051        Ok(())
3052    }
3053
3054    macro_rules! gen_added_dma_tests {
3055        ($($f:ident),*) => {
3056            paste::paste! {
3057                $(
3058                    #[test] fn [<$f _scalar>]() -> Result<(), Box<dyn Error>> {
3059                        $f(stringify!([<$f _scalar>]), Kernel::Scalar)
3060                    }
3061                    #[cfg(all(feature="nightly-avx", target_arch="x86_64"))]
3062                    #[test] fn [<$f _avx2>]() -> Result<(), Box<dyn Error>> {
3063                        $f(stringify!([<$f _avx2>]), Kernel::Avx2)
3064                    }
3065                    #[cfg(all(feature="nightly-avx", target_arch="x86_64"))]
3066                    #[test] fn [<$f _avx512>]() -> Result<(), Box<dyn Error>> {
3067                        $f(stringify!([<$f _avx512>]), Kernel::Avx512)
3068                    }
3069                )*
3070            }
3071        }
3072    }
3073
3074    gen_added_dma_tests!(check_dma_reinput, check_dma_nan_handling);
3075
3076    macro_rules! gen_batch_sweep_tests {
3077        ($($f:ident),*) => {
3078            paste::paste! {
3079                $(
3080                    #[test] fn [<$f _scalar_batch>]() -> Result<(), Box<dyn Error>> {
3081                        $f(stringify!([<$f _scalar_batch>]), Kernel::ScalarBatch)
3082                    }
3083                    #[cfg(all(feature="nightly-avx", target_arch="x86_64"))]
3084                    #[test] fn [<$f _avx2_batch>]() -> Result<(), Box<dyn Error>> {
3085                        $f(stringify!([<$f _avx2_batch>]), Kernel::Avx2Batch)
3086                    }
3087                    #[cfg(all(feature="nightly-avx", target_arch="x86_64"))]
3088                    #[test] fn [<$f _avx512_batch>]() -> Result<(), Box<dyn Error>> {
3089                        $f(stringify!([<$f _avx512_batch>]), Kernel::Avx512Batch)
3090                    }
3091                )*
3092            }
3093        }
3094    }
3095
3096    gen_batch_sweep_tests!(check_batch_default_row);
3097
3098    #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
3099    #[test]
3100    fn test_dma_simd128_correctness() {
3101        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
3102        let p = DmaParams::default();
3103        let input = DmaInput::from_slice(&data, p);
3104        let scalar = dma_with_kernel(&input, Kernel::Scalar).unwrap();
3105        let simd = dma_with_kernel(&input, Kernel::Scalar).unwrap();
3106        assert_eq!(scalar.values.len(), simd.values.len());
3107        for (a, b) in scalar.values.iter().zip(simd.values.iter()) {
3108            assert!((a - b).abs() < 1e-10);
3109        }
3110    }
3111
3112    #[cfg(debug_assertions)]
3113    #[test]
3114    fn test_dma_no_poison_values() -> Result<(), Box<dyn Error>> {
3115        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3116        let candles = read_candles_from_csv(file_path)?;
3117
3118        let input = DmaInput::from_candles(&candles, "close", DmaParams::default());
3119        let output = dma(&input)?;
3120
3121        for &v in &output.values {
3122            if v.is_nan() {
3123                continue;
3124            }
3125            let b = v.to_bits();
3126
3127            assert_ne!(
3128                b, 0x11111111_11111111,
3129                "Found poison value 0x11111111_11111111"
3130            );
3131            assert_ne!(
3132                b, 0x22222222_22222222,
3133                "Found poison value 0x22222222_22222222"
3134            );
3135            assert_ne!(
3136                b, 0x33333333_33333333,
3137                "Found poison value 0x33333333_33333333"
3138            );
3139            assert_ne!(
3140                b, 0xDEADBEEF_DEADBEEF,
3141                "Found poison value 0xDEADBEEF_DEADBEEF"
3142            );
3143            assert_ne!(
3144                b, 0xFEEEFEEE_FEEEFEEE,
3145                "Found poison value 0xFEEEFEEE_FEEEFEEE"
3146            );
3147        }
3148        Ok(())
3149    }
3150
3151    fn check_dma_batch_basic(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
3152        skip_if_unsupported!(kernel, test_name);
3153        let data = vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0];
3154
3155        let sweep = DmaBatchRange {
3156            hull_length: (3, 5, 1),
3157            ema_length: (5, 5, 0),
3158            ema_gain_limit: (10, 10, 0),
3159            hull_ma_type: "WMA".to_string(),
3160        };
3161        let output = dma_batch_with_kernel(&data, &sweep, kernel)?;
3162
3163        assert_eq!(
3164            output.rows, 3,
3165            "[{}] Expected 3 rows for hull_length range 3-5",
3166            test_name
3167        );
3168        assert_eq!(output.cols, data.len());
3169        assert_eq!(output.values.len(), output.rows * output.cols);
3170        assert_eq!(output.combos.len(), output.rows);
3171
3172        Ok(())
3173    }
3174
3175    #[test]
3176    fn test_dma_stream_incremental() -> Result<(), Box<dyn Error>> {
3177        let params = DmaParams {
3178            hull_length: Some(3),
3179            ema_length: Some(3),
3180            ema_gain_limit: Some(10),
3181            hull_ma_type: Some("WMA".to_string()),
3182        };
3183
3184        let mut stream = DmaStream::try_new(params)?;
3185        let data = vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0];
3186
3187        let mut results = Vec::new();
3188        for &val in &data {
3189            if let Some(result) = stream.update(val) {
3190                results.push(result);
3191            }
3192        }
3193
3194        assert!(
3195            !results.is_empty(),
3196            "Stream should produce results after warmup"
3197        );
3198
3199        Ok(())
3200    }
3201
3202    #[cfg(debug_assertions)]
3203    #[test]
3204    fn test_dma_batch_no_poison_values() -> Result<(), Box<dyn std::error::Error>> {
3205        use crate::utilities::data_loader::read_candles_from_csv;
3206        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3207        let c = read_candles_from_csv(file)?;
3208        let out = DmaBatchBuilder::new()
3209            .hull_length_range(3, 8, 1)
3210            .ema_length_range(5, 10, 1)
3211            .ema_gain_limit_static(10)
3212            .apply_slice(&c.close)?;
3213        for &v in &out.values {
3214            if v.is_nan() {
3215                continue;
3216            }
3217            let b = v.to_bits();
3218            assert_ne!(b, 0x11111111_11111111);
3219            assert_ne!(b, 0x22222222_22222222);
3220            assert_ne!(b, 0x33333333_33333333);
3221        }
3222        Ok(())
3223    }
3224
3225    #[test]
3226    fn test_dma_into_matches_api() -> Result<(), Box<dyn Error>> {
3227        let mut data = Vec::with_capacity(160);
3228        data.extend_from_slice(&[f64::NAN, f64::NAN, f64::NAN]);
3229        for i in 0..157 {
3230            let x = (i as f64 * 0.15).sin() * 5.0 + (i as f64) * 0.01;
3231            data.push(x);
3232        }
3233
3234        let input = DmaInput::from_slice(&data, DmaParams::default());
3235
3236        let baseline = dma(&input)?;
3237
3238        let mut out = vec![0.0; data.len()];
3239        #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
3240        {
3241            dma_into(&input, &mut out)?;
3242        }
3243        #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
3244        {
3245            dma_into_slice(&mut out, &input, Kernel::Auto)?;
3246        }
3247
3248        assert_eq!(baseline.values.len(), out.len());
3249
3250        for (a, b) in baseline.values.iter().copied().zip(out.iter().copied()) {
3251            let both_nan = a.is_nan() && b.is_nan();
3252            assert!(both_nan || a == b, "mismatch: got {b:?}, expected {a:?}");
3253        }
3254        Ok(())
3255    }
3256}
3257
// Python-facing handle around a 2-D f32 DMA result that lives on a CUDA device.
// The data can be viewed through `__cuda_array_interface__` or handed off exactly
// once through `__dlpack__`; after the hand-off `inner` is `None`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", name = "DmaDeviceArrayF32", unsendable)]
pub struct DmaDeviceArrayF32Py {
    // Device buffer plus its shape; taken (set to `None`) by `__dlpack__`.
    pub(crate) inner: Option<DeviceArrayF32>,
    // Never read here; presumably held so the CUDA context outlives the buffer (TODO confirm).
    pub(crate) _ctx: Arc<Context>,
    // Ordinal of the device the buffer was allocated on, reported by `__dlpack_device__`.
    pub(crate) device_id: u32,
}

#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl DmaDeviceArrayF32Py {
    // Builds a CUDA Array Interface (version 3) dict describing the buffer as a
    // row-major `(rows, cols)` array of little-endian f32. Errors if the buffer was
    // already exported through `__dlpack__`.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let arr = match self.inner.as_ref() {
            Some(arr) => arr,
            None => {
                return Err(PyValueError::new_err(
                    "buffer already exported via __dlpack__",
                ))
            }
        };

        let elem_bytes = std::mem::size_of::<f32>();
        let iface = PyDict::new(py);
        iface.set_item("shape", (arr.rows, arr.cols))?;
        iface.set_item("typestr", "<f4")?;
        // Byte strides for C-contiguous layout: one row spans `cols` elements.
        iface.set_item("strides", (arr.cols * elem_bytes, elem_bytes))?;
        // (device pointer, read_only = false)
        iface.set_item("data", (arr.device_ptr() as usize, false))?;
        iface.set_item("version", 3)?;
        Ok(iface)
    }

    // DLPack device tuple: device type 2 (CUDA) and this buffer's device ordinal.
    fn __dlpack_device__(&self) -> (i32, i32) {
        const DL_DEVICE_CUDA: i32 = 2;
        (DL_DEVICE_CUDA, self.device_id as i32)
    }

    // Exports the buffer as a DLPack capsule, consuming it: a second call fails.
    // A `dl_device` that parses as `(type, id)` and names a different device is
    // rejected, with a distinct message when `copy=True` was requested. The `stream`
    // hint is accepted but unused.
    #[pyo3(signature=(stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<pyo3::PyObject> {
        use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;

        let (own_type, own_id) = self.__dlpack_device__();

        // Only a successfully parsed `(i32, i32)` request is checked. Anything else
        // is ignored.
        let requested = dl_device
            .as_ref()
            .and_then(|obj| obj.extract::<(i32, i32)>(py).ok());
        if let Some((req_type, req_id)) = requested {
            if req_type != own_type || req_id != own_id {
                let copy_requested = matches!(
                    copy.as_ref().map(|c| c.extract::<bool>(py)),
                    Some(Ok(true))
                );
                let message = if copy_requested {
                    "device copy not implemented for __dlpack__"
                } else {
                    "dl_device mismatch for __dlpack__"
                };
                return Err(PyValueError::new_err(message));
            }
        }

        let _ = stream;

        let arr = match self.inner.take() {
            Some(arr) => arr,
            None => {
                return Err(PyValueError::new_err(
                    "__dlpack__ may only be called once",
                ))
            }
        };

        let (rows, cols) = (arr.rows, arr.cols);
        let buf = arr.buf;
        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, own_id, max_version_bound)
    }
}
3340}