Skip to main content

vector_ta/indicators/moving_averages/
mama.rs

1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{
4    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
5    make_uninit_matrix,
6};
7use crate::utilities::math_functions::atan_fast;
8use aligned_vec::{AVec, CACHELINE_ALIGN};
9#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
10use core::arch::x86_64::*;
11#[cfg(not(target_arch = "wasm32"))]
12use rayon::prelude::*;
13use std::convert::AsRef;
14use std::error::Error;
15use std::f64::consts::PI;
16use std::mem::MaybeUninit;
17use thiserror::Error;
18#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
19use wasm_bindgen::prelude::*;
20
21#[derive(Debug, Clone)]
22pub enum MamaData<'a> {
23    Candles {
24        candles: &'a Candles,
25        source: &'a str,
26    },
27    Slice(&'a [f64]),
28}
29
/// Result of a MAMA computation: one MAMA value and one FAMA value per input sample.
#[derive(Clone, Debug)]
pub struct MamaOutput {
    /// The adaptive moving average line.
    pub mama_values: Vec<f64>,
    /// The "following" line: an EMA of MAMA driven by half of MAMA's alpha.
    pub fama_values: Vec<f64>,
}
35
/// Tunable bounds for MAMA's adaptive smoothing factor (alpha).
///
/// A `None` field falls back to a default when read through `MamaInput`'s getters
/// (0.5 for the fast limit, 0.05 for the slow limit).
#[derive(Clone, Debug)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct MamaParams {
    /// Upper bound for alpha; validation requires a finite, positive value.
    pub fast_limit: Option<f64>,
    /// Lower bound for alpha; validation requires a finite, positive value.
    pub slow_limit: Option<f64>,
}
45
46impl Default for MamaParams {
47    fn default() -> Self {
48        Self {
49            fast_limit: Some(0.5),
50            slow_limit: Some(0.05),
51        }
52    }
53}
54
55#[derive(Debug, Clone)]
56pub struct MamaInput<'a> {
57    pub data: MamaData<'a>,
58    pub params: MamaParams,
59}
60
61impl<'a> AsRef<[f64]> for MamaInput<'a> {
62    #[inline(always)]
63    fn as_ref(&self) -> &[f64] {
64        match &self.data {
65            MamaData::Slice(slice) => slice,
66            MamaData::Candles { candles, source } => source_type(candles, source),
67        }
68    }
69}
70
71impl<'a> MamaInput<'a> {
72    #[inline]
73    pub fn from_candles(c: &'a Candles, s: &'a str, p: MamaParams) -> Self {
74        Self {
75            data: MamaData::Candles {
76                candles: c,
77                source: s,
78            },
79            params: p,
80        }
81    }
82    #[inline]
83    pub fn from_slice(sl: &'a [f64], p: MamaParams) -> Self {
84        Self {
85            data: MamaData::Slice(sl),
86            params: p,
87        }
88    }
89    #[inline]
90    pub fn with_default_candles(c: &'a Candles) -> Self {
91        Self::from_candles(c, "close", MamaParams::default())
92    }
93    #[inline]
94    pub fn get_fast_limit(&self) -> f64 {
95        self.params.fast_limit.unwrap_or(0.5)
96    }
97    #[inline]
98    pub fn get_slow_limit(&self) -> f64 {
99        self.params.slow_limit.unwrap_or(0.05)
100    }
101}
102
103#[derive(Copy, Clone, Debug)]
104pub struct MamaBuilder {
105    fast_limit: Option<f64>,
106    slow_limit: Option<f64>,
107    kernel: Kernel,
108}
109
110impl Default for MamaBuilder {
111    fn default() -> Self {
112        Self {
113            fast_limit: None,
114            slow_limit: None,
115            kernel: Kernel::Auto,
116        }
117    }
118}
119
120impl MamaBuilder {
121    #[inline(always)]
122    pub fn new() -> Self {
123        Self::default()
124    }
125    #[inline(always)]
126    pub fn fast_limit(mut self, n: f64) -> Self {
127        self.fast_limit = Some(n);
128        self
129    }
130    #[inline(always)]
131    pub fn slow_limit(mut self, x: f64) -> Self {
132        self.slow_limit = Some(x);
133        self
134    }
135    #[inline(always)]
136    pub fn kernel(mut self, k: Kernel) -> Self {
137        self.kernel = k;
138        self
139    }
140    #[inline(always)]
141    pub fn apply(self, c: &Candles) -> Result<MamaOutput, MamaError> {
142        let p = MamaParams {
143            fast_limit: self.fast_limit,
144            slow_limit: self.slow_limit,
145        };
146        let i = MamaInput::from_candles(c, "close", p);
147        mama_with_kernel(&i, self.kernel)
148    }
149    #[inline(always)]
150    pub fn apply_slice(self, d: &[f64]) -> Result<MamaOutput, MamaError> {
151        let p = MamaParams {
152            fast_limit: self.fast_limit,
153            slow_limit: self.slow_limit,
154        };
155        let i = MamaInput::from_slice(d, p);
156        mama_with_kernel(&i, self.kernel)
157    }
158    #[inline(always)]
159    pub fn into_stream(self) -> Result<MamaStream, MamaError> {
160        let p = MamaParams {
161            fast_limit: self.fast_limit,
162            slow_limit: self.slow_limit,
163        };
164        MamaStream::try_new(p)
165    }
166}
167
168#[derive(Debug, Error)]
169pub enum MamaError {
170    #[error("mama: empty input data")]
171    EmptyInputData,
172    #[error("mama: all values are NaN")]
173    AllValuesNaN,
174    #[error("mama: not enough valid data: needed {needed}, valid {valid}")]
175    NotEnoughValidData { needed: usize, valid: usize },
176    #[error("mama: Not enough data: needed at least {needed}, found {found}")]
177    NotEnoughData { needed: usize, found: usize },
178    #[error("mama: output length mismatch: expected {expected}, got {got}")]
179    OutputLengthMismatch { expected: usize, got: usize },
180    #[error("mama: invalid range expansion start={start} end={end} step={step}")]
181    InvalidRange { start: f64, end: f64, step: f64 },
182    #[error("mama: invalid kernel for batch path: {0:?}")]
183    InvalidKernelForBatch(Kernel),
184    #[error("mama: Invalid fast limit: {fast_limit}")]
185    InvalidFastLimit { fast_limit: f64 },
186    #[error("mama: Invalid slow limit: {slow_limit}")]
187    InvalidSlowLimit { slow_limit: f64 },
188}
189
190#[inline]
191pub fn mama(input: &MamaInput) -> Result<MamaOutput, MamaError> {
192    mama_with_kernel(input, Kernel::Auto)
193}
194
195#[inline(always)]
196fn mama_prepare<'a>(
197    input: &'a MamaInput,
198    kernel: Kernel,
199) -> Result<(&'a [f64], f64, f64, Kernel), MamaError> {
200    let data = input.as_ref();
201    let len = data.len();
202    if len == 0 {
203        return Err(MamaError::EmptyInputData);
204    }
205    if len < 10 {
206        return Err(MamaError::NotEnoughData {
207            needed: 10,
208            found: len,
209        });
210    }
211
212    let fast_limit = input.get_fast_limit();
213    let slow_limit = input.get_slow_limit();
214    if fast_limit <= 0.0 || fast_limit.is_nan() || fast_limit.is_infinite() {
215        return Err(MamaError::InvalidFastLimit { fast_limit });
216    }
217    if slow_limit <= 0.0 || slow_limit.is_nan() || slow_limit.is_infinite() {
218        return Err(MamaError::InvalidSlowLimit { slow_limit });
219    }
220
221    let chosen = match kernel {
222        Kernel::Auto => Kernel::Scalar,
223        k => k,
224    };
225
226    Ok((data, fast_limit, slow_limit, chosen))
227}
228
229pub fn mama_with_kernel(input: &MamaInput, kernel: Kernel) -> Result<MamaOutput, MamaError> {
230    let (data, fast_limit, slow_limit, chosen) = mama_prepare(input, kernel)?;
231    let len = data.len();
232    const WARM: usize = 10;
233
234    let mut mama_values = alloc_with_nan_prefix(len, WARM);
235    let mut fama_values = alloc_with_nan_prefix(len, WARM);
236
237    unsafe {
238        #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
239        {
240            if matches!(chosen, Kernel::Scalar | Kernel::ScalarBatch) {
241                mama_simd128_inplace(
242                    data,
243                    fast_limit,
244                    slow_limit,
245                    &mut mama_values,
246                    &mut fama_values,
247                );
248
249                for v in &mut mama_values[..WARM] {
250                    *v = f64::NAN;
251                }
252                for v in &mut fama_values[..WARM] {
253                    *v = f64::NAN;
254                }
255                return Ok(MamaOutput {
256                    mama_values,
257                    fama_values,
258                });
259            }
260        }
261
262        match chosen {
263            Kernel::Scalar | Kernel::ScalarBatch => {
264                mama_scalar_inplace(
265                    data,
266                    fast_limit,
267                    slow_limit,
268                    &mut mama_values,
269                    &mut fama_values,
270                );
271            }
272
273            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
274            Kernel::Avx2 | Kernel::Avx2Batch => {
275                mama_scalar_inplace(
276                    data,
277                    fast_limit,
278                    slow_limit,
279                    &mut mama_values,
280                    &mut fama_values,
281                );
282            }
283            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
284            Kernel::Avx2 | Kernel::Avx2Batch => {
285                mama_scalar_inplace(
286                    data,
287                    fast_limit,
288                    slow_limit,
289                    &mut mama_values,
290                    &mut fama_values,
291                );
292            }
293
294            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
295            Kernel::Avx512 | Kernel::Avx512Batch => {
296                mama_scalar_inplace(
297                    data,
298                    fast_limit,
299                    slow_limit,
300                    &mut mama_values,
301                    &mut fama_values,
302                );
303            }
304            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
305            Kernel::Avx512 | Kernel::Avx512Batch => {
306                mama_scalar_inplace(
307                    data,
308                    fast_limit,
309                    slow_limit,
310                    &mut mama_values,
311                    &mut fama_values,
312                );
313            }
314
315            _ => unreachable!("unsupported kernel variant"),
316        }
317    }
318
319    for v in &mut mama_values[..WARM] {
320        *v = f64::NAN;
321    }
322    for v in &mut fama_values[..WARM] {
323        *v = f64::NAN;
324    }
325
326    Ok(MamaOutput {
327        mama_values,
328        fama_values,
329    })
330}
331
332pub fn mama_compute_into(
333    input: &MamaInput,
334    kernel: Kernel,
335    out_mama: &mut [f64],
336    out_fama: &mut [f64],
337) -> Result<(), MamaError> {
338    let (data, fast_limit, slow_limit, chosen) = mama_prepare(input, kernel)?;
339
340    if out_mama.len() != data.len() || out_fama.len() != data.len() {
341        return Err(MamaError::OutputLengthMismatch {
342            expected: data.len(),
343            got: out_mama.len().min(out_fama.len()),
344        });
345    }
346
347    unsafe {
348        #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
349        {
350            if matches!(chosen, Kernel::Scalar | Kernel::ScalarBatch) {
351                mama_simd128_inplace(data, fast_limit, slow_limit, out_mama, out_fama);
352                return Ok(());
353            }
354        }
355
356        match chosen {
357            Kernel::Scalar | Kernel::ScalarBatch => {
358                mama_scalar_inplace(data, fast_limit, slow_limit, out_mama, out_fama);
359            }
360
361            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
362            Kernel::Avx2 | Kernel::Avx2Batch => {
363                mama_avx2_inplace(data, fast_limit, slow_limit, out_mama, out_fama);
364            }
365
366            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
367            Kernel::Avx512 | Kernel::Avx512Batch => {
368                mama_avx512_inplace(data, fast_limit, slow_limit, out_mama, out_fama);
369            }
370
371            _ => unreachable!("unsupported kernel variant"),
372        }
373    }
374
375    Ok(())
376}
377
378#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
379#[inline]
380pub fn mama_into(
381    input: &MamaInput,
382    out_mama: &mut [f64],
383    out_fama: &mut [f64],
384) -> Result<(), MamaError> {
385    let data = input.as_ref();
386    if out_mama.len() != data.len() || out_fama.len() != data.len() {
387        return Err(MamaError::OutputLengthMismatch {
388            expected: data.len(),
389            got: out_mama.len().min(out_fama.len()),
390        });
391    }
392
393    mama_compute_into(input, Kernel::Auto, out_mama, out_fama)?;
394
395    const WARM: usize = 10;
396    let warm = WARM.min(data.len());
397    for v in &mut out_mama[..warm] {
398        *v = f64::NAN;
399    }
400    for v in &mut out_fama[..warm] {
401        *v = f64::NAN;
402    }
403    Ok(())
404}
405
406#[inline]
407pub fn mama_into_slice(
408    dst_mama: &mut [f64],
409    dst_fama: &mut [f64],
410    input: &MamaInput,
411    kern: Kernel,
412) -> Result<(), MamaError> {
413    let (data, _fast, _slow, _chosen) = mama_prepare(input, kern)?;
414    if dst_mama.len() != data.len() || dst_fama.len() != data.len() {
415        return Err(MamaError::OutputLengthMismatch {
416            expected: data.len(),
417            got: dst_mama.len().min(dst_fama.len()),
418        });
419    }
420    mama_compute_into(input, kern, dst_mama, dst_fama)?;
421
422    const WARM: usize = 10;
423    let warm = WARM.min(data.len());
424    for v in &mut dst_mama[..warm] {
425        *v = f64::NAN;
426    }
427    for v in &mut dst_fama[..warm] {
428        *v = f64::NAN;
429    }
430    Ok(())
431}
432
433#[inline(always)]
434pub fn mama_scalar(
435    data: &[f64],
436    fast_limit: f64,
437    slow_limit: f64,
438    out_mama: &mut [f64],
439    out_fama: &mut [f64],
440) -> Result<(), MamaError> {
441    mama_scalar_inplace(data, fast_limit, slow_limit, out_mama, out_fama);
442    Ok(())
443}
444
/// Checked wrapper around `mama_avx2_inplace`.
///
/// # Errors
/// - `EmptyInputData` if `data` is empty (the kernel would otherwise panic on `data[0]`).
/// - `OutputLengthMismatch` if either output's length differs from `data.len()`.
///
/// # Safety
/// The CPU must support AVX2 and FMA.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn mama_avx2(
    data: &[f64],
    fast_limit: f64,
    slow_limit: f64,
    out_mama: &mut [f64],
    out_fama: &mut [f64],
) -> Result<(), MamaError> {
    if data.is_empty() {
        return Err(MamaError::EmptyInputData);
    }
    if out_mama.len() != data.len() || out_fama.len() != data.len() {
        return Err(MamaError::OutputLengthMismatch {
            expected: data.len(),
            got: out_mama.len().min(out_fama.len()),
        });
    }
    mama_avx2_inplace(data, fast_limit, slow_limit, out_mama, out_fama);
    Ok(())
}
457
/// Four-tap Hilbert FIR evaluated in one 512-bit register.
///
/// Computes `0.0962*x0 + 0.5769*x2 - 0.5769*x4 - 0.0962*x6`. The products sit in lanes 0..4
/// (lanes 4..8 are zero padding) and are summed with `_mm512_reduce_add_pd`.
///
/// # Safety
/// Requires a CPU with `avx512f`, `avx512dq` and `fma`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,avx512dq,fma")]
#[inline]
unsafe fn hilbert4_avx512(x0: f64, x2: f64, x4: f64, x6: f64) -> f64 {
    const TAP_0: f64 = 0.096_2;
    const TAP_2: f64 = 0.576_9;
    const TAP_4: f64 = -0.576_9;
    const TAP_6: f64 = -0.096_2;

    // Lanes are listed low-to-high.
    let samples = _mm512_setr_pd(x0, x2, x4, x6, 0.0, 0.0, 0.0, 0.0);
    let taps = _mm512_setr_pd(TAP_0, TAP_2, TAP_4, TAP_6, 0.0, 0.0, 0.0, 0.0);
    _mm512_reduce_add_pd(_mm512_mul_pd(samples, taps))
}
473
/// AVX-512 MAMA/FAMA kernel: writes one MAMA and one FAMA value per input sample.
///
/// The recurrence is the same as `mama_scalar_inplace`; the 4-tap Hilbert FIRs go through
/// `hilbert4_avx512`. No warm-up masking is applied here.
///
/// Changes from the previous version:
/// - Alpha is bounded with the scalar kernel's compare chain instead of `f64::clamp`, which
///   panics when `slow_limit > fast_limit` (a combination `mama_prepare` accepts).
/// - Empty input returns immediately instead of panicking on `data[0]`.
///
/// # Safety
/// The CPU must support `avx512f`, `avx512dq` and `fma`. Both outputs must be at least
/// `data.len()` long (indexing panics otherwise).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,avx512dq,fma")]
#[inline]
pub unsafe fn mama_avx512_inplace(
    data: &[f64],
    fast_limit: f64,
    slow_limit: f64,
    out_mama: &mut [f64],
    out_fama: &mut [f64],
) {
    debug_assert_eq!(data.len(), out_mama.len());
    debug_assert_eq!(data.len(), out_fama.len());

    // Ring buffers hold the last LEN values; lags up to 6 are read via `MASK`.
    const LEN: usize = 8;
    const MASK: usize = LEN - 1;
    const DEG_PER_RAD: f64 = 180.0 / std::f64::consts::PI;

    if data.is_empty() {
        return;
    }

    // Plain arrays: the former `repr(align(64))` wrapper was unwrapped with `.0` immediately,
    // so it never provided alignment, and no aligned vector loads read these buffers.
    let first = data[0];
    let mut smooth = [first; LEN];
    let mut detrender = [first; LEN];
    let mut i1_buf = [first; LEN];
    let mut q1_buf = [first; LEN];

    let (mut idx, mut prev_mesa, mut prev_phase) = (0usize, 0.0, 0.0);
    let (mut prev_mama, mut prev_fama) = (first, first);
    let (mut prev_i2, mut prev_q2) = (0.0, 0.0);
    let (mut prev_re, mut prev_im) = (0.0, 0.0);

    #[inline(always)]
    fn lag(buf: &[f64; LEN], p: usize, k: usize) -> f64 {
        // SAFETY: LEN is a power of two, so `& MASK` always yields an index in 0..LEN.
        unsafe { *buf.get_unchecked(p.wrapping_sub(k) & MASK) }
    }

    for (i, &price) in data.iter().enumerate() {
        // 4-3-2-1 weighted smoothing; the first samples repeat the current price.
        let s1 = if i >= 1 { data[i - 1] } else { price };
        let s2 = if i >= 2 { data[i - 2] } else { price };
        let s3 = if i >= 3 { data[i - 3] } else { price };
        let smooth_val =
            0.1 * (4.0_f64.mul_add(price, 3.0_f64.mul_add(s1, 2.0_f64.mul_add(s2, s3))));
        smooth[idx] = smooth_val;

        // FIR gain depends on the previous period estimate.
        let amp = 0.075_f64.mul_add(prev_mesa, 0.54);
        let dt_val = amp
            * hilbert4_avx512(
                smooth[idx],
                lag(&smooth, idx, 2),
                lag(&smooth, idx, 4),
                lag(&smooth, idx, 6),
            );
        detrender[idx] = dt_val;

        let i1 = lag(&detrender, idx, 3);
        i1_buf[idx] = i1;

        let q1 = amp
            * hilbert4_avx512(
                detrender[idx],
                lag(&detrender, idx, 2),
                lag(&detrender, idx, 4),
                lag(&detrender, idx, 6),
            );
        q1_buf[idx] = q1;

        let j_i = amp
            * hilbert4_avx512(
                i1_buf[idx],
                lag(&i1_buf, idx, 2),
                lag(&i1_buf, idx, 4),
                lag(&i1_buf, idx, 6),
            );
        let j_q = amp
            * hilbert4_avx512(
                q1_buf[idx],
                lag(&q1_buf, idx, 2),
                lag(&q1_buf, idx, 4),
                lag(&q1_buf, idx, 6),
            );

        let i2 = i1 - j_q;
        let q2 = q1 + j_i;
        let old_i2 = prev_i2;
        let old_q2 = prev_q2;
        let i2s = 0.2_f64.mul_add(i2, 0.8 * old_i2);
        let q2s = 0.2_f64.mul_add(q2, 0.8 * old_q2);
        prev_i2 = i2s;
        prev_q2 = q2s;

        let re = 0.2_f64.mul_add(i2s * old_i2 + q2s * old_q2, 0.8 * prev_re);
        let im = 0.2_f64.mul_add(i2s * old_q2 - q2s * old_i2, 0.8 * prev_im);
        prev_re = re;
        prev_im = im;

        let mut mesa = if re != 0.0 && im != 0.0 {
            2.0 * std::f64::consts::PI / atan_fast(im / re)
        } else {
            prev_mesa
        };
        mesa = mesa
            .min(1.5 * prev_mesa)
            .max(0.67 * prev_mesa)
            .max(6.0)
            .min(50.0);
        mesa = 0.2_f64.mul_add(mesa, 0.8 * prev_mesa);
        prev_mesa = mesa;

        let phase = if i1 != 0.0 {
            atan_fast(q1 / i1) * DEG_PER_RAD
        } else {
            prev_phase
        };
        let mut dp = prev_phase - phase;
        if dp < 1.0 {
            dp = 1.0;
        }
        prev_phase = phase;

        // Same bounding as the scalar kernel. `f64::clamp` would panic when
        // `slow_limit > fast_limit`; for `slow_limit <= fast_limit` the result is identical.
        let mut alpha = fast_limit / dp;
        if alpha < slow_limit {
            alpha = slow_limit;
        }
        if alpha > fast_limit {
            alpha = fast_limit;
        }

        let cur_mama = alpha.mul_add(price, (1.0 - alpha) * prev_mama);
        let cur_fama = (0.5 * alpha).mul_add(cur_mama, (1.0 - 0.5 * alpha) * prev_fama);
        prev_mama = cur_mama;
        prev_fama = cur_fama;

        out_mama[i] = cur_mama;
        out_fama[i] = cur_fama;

        idx = (idx + 1) & MASK;
    }
}
607
/// Four-tap Hilbert FIR evaluated in one 256-bit register.
///
/// Computes `0.0962*x0 + 0.5769*x2 - 0.5769*x4 - 0.0962*x6`, summed as
/// `(p0 + p1) + (p2 + p3)` over the lane products.
///
/// # Safety
/// Requires a CPU with AVX2 and FMA.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
#[inline]
unsafe fn hilbert4_avx2(x0: f64, x2: f64, x4: f64, x6: f64) -> f64 {
    const TAP_0: f64 = 0.096_2;
    const TAP_2: f64 = 0.576_9;
    const TAP_4: f64 = -0.576_9;
    const TAP_6: f64 = -0.096_2;

    // Lanes are listed low-to-high.
    let samples = _mm256_setr_pd(x0, x2, x4, x6);
    let taps = _mm256_setr_pd(TAP_0, TAP_2, TAP_4, TAP_6);
    let products = _mm256_mul_pd(samples, taps);

    // [p0+p1, p0+p1, p2+p3, p2+p3]
    let pair_sums = _mm256_hadd_pd(products, products);
    // Swap the 128-bit halves so lane 0 can pick up p2+p3.
    let swapped = _mm256_permute2f128_pd(pair_sums, pair_sums, 0x1);
    _mm256_cvtsd_f64(_mm256_add_pd(pair_sums, swapped))
}
627
/// AVX2 MAMA/FAMA kernel: writes one MAMA and one FAMA value per input sample.
///
/// The recurrence is the same as `mama_scalar_inplace`; the 4-tap Hilbert FIRs go through
/// `hilbert4_avx2`. No warm-up masking is applied here.
///
/// Changes from the previous version:
/// - Alpha is bounded with the scalar kernel's compare chain instead of `f64::clamp`, which
///   panics when `slow_limit > fast_limit`.
/// - Empty input returns immediately instead of panicking on `data[0]`.
/// - Unused constants are removed.
/// - The function carries `#[target_feature(enable = "avx2,fma")]` so `hilbert4_avx2` can
///   inline into it.
///
/// # Safety
/// The CPU must support AVX2 and FMA. Both outputs must be at least `data.len()` long
/// (indexing panics otherwise).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
#[inline]
pub unsafe fn mama_avx2_inplace(
    data: &[f64],
    fast_limit: f64,
    slow_limit: f64,
    out_mama: &mut [f64],
    out_fama: &mut [f64],
) {
    debug_assert_eq!(data.len(), out_mama.len());
    debug_assert_eq!(data.len(), out_fama.len());

    // Ring buffers hold the last RING_LEN values; lags up to 6 are read via `MASK`.
    const RING_LEN: usize = 8;
    const MASK: usize = RING_LEN - 1;

    // 4-3-2-1 smoothing weights (the weight on s3 is 1).
    const W0: f64 = 4.0;
    const W1: f64 = 3.0;
    const W2: f64 = 2.0;

    const DEG_PER_RAD: f64 = 180.0 / std::f64::consts::PI;

    if data.is_empty() {
        return;
    }

    let first = data[0];
    let mut smooth = [first; RING_LEN];
    let mut detrender = [first; RING_LEN];
    let mut i1_buf = [first; RING_LEN];
    let mut q1_buf = [first; RING_LEN];

    let mut idx = 0usize;
    let mut prev_mesa = 0.0;
    let mut prev_phase = 0.0;
    let mut prev_mama = first;
    let mut prev_fama = first;
    let mut prev_i2 = 0.0;
    let mut prev_q2 = 0.0;
    let mut prev_re = 0.0;
    let mut prev_im = 0.0;

    #[inline(always)]
    fn lag(buf: &[f64; RING_LEN], p: usize, k: usize) -> f64 {
        buf[(p.wrapping_sub(k)) & MASK]
    }

    for (i, &price) in data.iter().enumerate() {
        let s1 = if i >= 1 { data[i - 1] } else { price };
        let s2 = if i >= 2 { data[i - 2] } else { price };
        let s3 = if i >= 3 { data[i - 3] } else { price };

        let smooth_val = W0.mul_add(price, W1.mul_add(s1, W2.mul_add(s2, s3))) * 0.1;
        smooth[idx] = smooth_val;

        // FIR gain depends on the previous period estimate.
        let amp = 0.075_f64.mul_add(prev_mesa, 0.54);

        let dt_val = amp
            * hilbert4_avx2(
                smooth[idx],
                lag(&smooth, idx, 2),
                lag(&smooth, idx, 4),
                lag(&smooth, idx, 6),
            );
        detrender[idx] = dt_val;

        let i1 = lag(&detrender, idx, 3);
        i1_buf[idx] = i1;

        let q1 = amp
            * hilbert4_avx2(
                detrender[idx],
                lag(&detrender, idx, 2),
                lag(&detrender, idx, 4),
                lag(&detrender, idx, 6),
            );
        q1_buf[idx] = q1;

        let j_i = amp
            * hilbert4_avx2(
                i1_buf[idx],
                lag(&i1_buf, idx, 2),
                lag(&i1_buf, idx, 4),
                lag(&i1_buf, idx, 6),
            );
        let j_q = amp
            * hilbert4_avx2(
                q1_buf[idx],
                lag(&q1_buf, idx, 2),
                lag(&q1_buf, idx, 4),
                lag(&q1_buf, idx, 6),
            );

        let i2 = i1 - j_q;
        let q2 = q1 + j_i;
        let old_i2 = prev_i2;
        let old_q2 = prev_q2;
        let i2s = 0.2_f64.mul_add(i2, 0.8 * old_i2);
        let q2s = 0.2_f64.mul_add(q2, 0.8 * old_q2);
        prev_i2 = i2s;
        prev_q2 = q2s;

        let re = 0.2_f64.mul_add(i2s * old_i2 + q2s * old_q2, 0.8 * prev_re);
        let im = 0.2_f64.mul_add(i2s * old_q2 - q2s * old_i2, 0.8 * prev_im);
        prev_re = re;
        prev_im = im;

        let mut mesa = if re != 0.0 && im != 0.0 {
            2.0 * std::f64::consts::PI / atan_fast(im / re)
        } else {
            prev_mesa
        };

        mesa = mesa
            .min(1.5 * prev_mesa)
            .max(0.67 * prev_mesa)
            .max(6.0)
            .min(50.0);
        mesa = 0.2_f64.mul_add(mesa, 0.8 * prev_mesa);
        prev_mesa = mesa;

        let phase = if i1 != 0.0 {
            atan_fast(q1 / i1) * DEG_PER_RAD
        } else {
            prev_phase
        };
        let mut dp = prev_phase - phase;
        if dp < 1.0 {
            dp = 1.0;
        }
        prev_phase = phase;

        // Same bounding as the scalar kernel. `f64::clamp` would panic when
        // `slow_limit > fast_limit`; for `slow_limit <= fast_limit` the result is identical.
        let mut alpha = fast_limit / dp;
        if alpha < slow_limit {
            alpha = slow_limit;
        }
        if alpha > fast_limit {
            alpha = fast_limit;
        }

        let cur_mama = alpha.mul_add(price, (1.0 - alpha) * prev_mama);
        let cur_fama = (0.5 * alpha).mul_add(cur_mama, (1.0 - 0.5 * alpha) * prev_fama);
        prev_mama = cur_mama;
        prev_fama = cur_fama;

        out_mama[i] = cur_mama;
        out_fama[i] = cur_fama;

        idx = (idx + 1) & MASK;
    }
}
/// Four-tap Hilbert FIR over samples 0, 2, 4 and 6 bars back.
///
/// Evaluated left to right as `((0.0962*x0 + 0.5769*x2) - 0.5769*x4) - 0.0962*x6`.
#[inline(always)]
fn hilbert(x0: f64, x2: f64, x4: f64, x6: f64) -> f64 {
    let lead = 0.0962 * x0;
    let near = 0.5769 * x2;
    let far = 0.5769 * x4;
    let tail = 0.0962 * x6;
    lead + near - far - tail
}
779
780#[inline]
781pub fn mama_scalar_inplace(
782    data: &[f64],
783    fast_limit: f64,
784    slow_limit: f64,
785    out_mama: &mut [f64],
786    out_fama: &mut [f64],
787) {
788    debug_assert_eq!(data.len(), out_mama.len());
789    debug_assert_eq!(data.len(), out_fama.len());
790    let len = data.len();
791
792    const RING: usize = 8;
793    const MASK: usize = RING - 1;
794
795    const H0: f64 = 0.096_2;
796    const H1: f64 = 0.576_9;
797    const H2: f64 = -0.576_9;
798    const H3: f64 = -0.096_2;
799    const DEG_PER_RAD: f64 = 180.0 / std::f64::consts::PI;
800
801    #[inline(always)]
802    fn hilbert4(x0: f64, x2: f64, x4: f64, x6: f64) -> f64 {
803        H0.mul_add(x0, H1.mul_add(x2, H2.mul_add(x4, H3 * x6)))
804    }
805
806    #[inline(always)]
807    fn lag<const N: usize>(buf: &[f64; N], pos: usize, k: usize) -> f64 {
808        buf[(pos.wrapping_sub(k)) & (N - 1)]
809    }
810
811    let first = data[0];
812
813    let mut smooth = [first; RING];
814    let mut detrender = [first; RING];
815    let mut i1_buf = [first; RING];
816    let mut q1_buf = [first; RING];
817
818    let mut idx = 0usize;
819    let mut prev_mesa = 0.0;
820    let mut prev_phase = 0.0;
821    let mut prev_mama = first;
822    let mut prev_fama = first;
823    let mut prev_i2 = 0.0;
824    let mut prev_q2 = 0.0;
825    let mut prev_re = 0.0;
826    let mut prev_im = 0.0;
827
828    for (i, &price) in data.iter().enumerate() {
829        let s1 = if i >= 1 { data[i - 1] } else { price };
830        let s2 = if i >= 2 { data[i - 2] } else { price };
831        let s3 = if i >= 3 { data[i - 3] } else { price };
832        let smooth_val =
833            0.1 * (4.0_f64.mul_add(price, 3.0_f64.mul_add(s1, 2.0_f64.mul_add(s2, s3))));
834        smooth[idx] = smooth_val;
835
836        let amp = 0.075_f64.mul_add(prev_mesa, 0.54);
837
838        let dt = amp
839            * hilbert4(
840                smooth[idx],
841                lag(&smooth, idx, 2),
842                lag(&smooth, idx, 4),
843                lag(&smooth, idx, 6),
844            );
845        detrender[idx] = dt;
846
847        let i1 = lag(&detrender, idx, 3);
848        i1_buf[idx] = i1;
849        let q1 = amp
850            * hilbert4(
851                detrender[idx],
852                lag(&detrender, idx, 2),
853                lag(&detrender, idx, 4),
854                lag(&detrender, idx, 6),
855            );
856        q1_buf[idx] = q1;
857
858        let j_i = amp
859            * hilbert4(
860                i1_buf[idx],
861                lag(&i1_buf, idx, 2),
862                lag(&i1_buf, idx, 4),
863                lag(&i1_buf, idx, 6),
864            );
865        let j_q = amp
866            * hilbert4(
867                q1_buf[idx],
868                lag(&q1_buf, idx, 2),
869                lag(&q1_buf, idx, 4),
870                lag(&q1_buf, idx, 6),
871            );
872
873        let i2 = i1 - j_q;
874        let q2 = q1 + j_i;
875        let i2s = 0.2_f64.mul_add(i2, 0.8 * prev_i2);
876        let q2s = 0.2_f64.mul_add(q2, 0.8 * prev_q2);
877        let re = 0.2_f64.mul_add(i2s * prev_i2 + q2s * prev_q2, 0.8 * prev_re);
878        let im = 0.2_f64.mul_add(i2s * prev_q2 - q2s * prev_i2, 0.8 * prev_im);
879        prev_i2 = i2s;
880        prev_q2 = q2s;
881        prev_re = re;
882        prev_im = im;
883
884        let mut mesa = if re != 0.0 && im != 0.0 {
885            2.0 * std::f64::consts::PI / atan_fast(im / re)
886        } else {
887            prev_mesa
888        };
889        if mesa > 1.5 * prev_mesa {
890            mesa = 1.5 * prev_mesa;
891        }
892        if mesa < 0.67 * prev_mesa {
893            mesa = 0.67 * prev_mesa;
894        }
895        if mesa < 6.0 {
896            mesa = 6.0;
897        }
898        if mesa > 50.0 {
899            mesa = 50.0;
900        }
901        mesa = 0.2_f64.mul_add(mesa, 0.8 * prev_mesa);
902        prev_mesa = mesa;
903
904        let phase = if i1 != 0.0 {
905            atan_fast(q1 / i1) * DEG_PER_RAD
906        } else {
907            prev_phase
908        };
909        let mut dphi = prev_phase - phase;
910        if dphi < 1.0 {
911            dphi = 1.0;
912        }
913        prev_phase = phase;
914
915        let mut alpha = fast_limit / dphi;
916        if alpha < slow_limit {
917            alpha = slow_limit;
918        }
919        if alpha > fast_limit {
920            alpha = fast_limit;
921        }
922
923        let mama = alpha.mul_add(price, (1.0 - alpha) * prev_mama);
924        let fama = (0.5 * alpha).mul_add(mama, (1.0 - 0.5 * alpha) * prev_fama);
925        prev_mama = mama;
926        prev_fama = fama;
927
928        out_mama[i] = mama;
929        out_fama[i] = fama;
930
931        idx = (idx + 1) & MASK;
932    }
933}
934
935#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
936#[inline]
937unsafe fn mama_simd128_inplace(
938    data: &[f64],
939    fast_limit: f64,
940    slow_limit: f64,
941    out_mama: &mut [f64],
942    out_fama: &mut [f64],
943) {
944    use core::arch::wasm32::*;
945
946    debug_assert_eq!(data.len(), out_mama.len());
947    debug_assert_eq!(data.len(), out_fama.len());
948
949    let len = data.len();
950
951    let mut smooth_buf = [data[0]; 7];
952    let mut detrender_buf = [data[0]; 7];
953    let mut i1_buf = [data[0]; 7];
954    let mut q1_buf = [data[0]; 7];
955
956    let mut prev_mesa_period = 0.0;
957    let mut prev_mama = data[0];
958    let mut prev_fama = data[0];
959    let mut prev_i2_sm = 0.0;
960    let mut prev_q2_sm = 0.0;
961    let mut prev_re = 0.0;
962    let mut prev_im = 0.0;
963    let mut prev_phase = 0.0;
964
965    let hilbert_weights = f64x2(0.0962, 0.5769);
966    let neg_hilbert_weights = f64x2(-0.5769, -0.0962);
967
968    let smooth_weights = f64x2(4.0, 3.0);
969    let smooth_weights2 = f64x2(2.0, 1.0);
970    let smooth_div = f64x2_splat(0.1);
971
972    #[inline(always)]
973    fn hilbert_simd128(
974        x0: f64,
975        x2: f64,
976        x4: f64,
977        x6: f64,
978        weights: v128,
979        neg_weights: v128,
980    ) -> f64 {
981        let v1 = f64x2(x0, x2);
982        let v2 = f64x2(x4, x6);
983
984        let prod1 = f64x2_mul(v1, weights);
985        let prod2 = f64x2_mul(v2, neg_weights);
986        let sum = f64x2_add(prod1, prod2);
987
988        f64x2_extract_lane::<0>(sum) + f64x2_extract_lane::<1>(sum)
989    }
990
991    for i in 0..len {
992        let price = data[i];
993
994        let s1 = if i >= 1 { data[i - 1] } else { price };
995        let s2 = if i >= 2 { data[i - 2] } else { price };
996        let s3 = if i >= 3 { data[i - 3] } else { price };
997
998        let v1 = f64x2(price, s1);
999        let v2 = f64x2(s2, s3);
1000        let prod1 = f64x2_mul(v1, smooth_weights);
1001        let prod2 = f64x2_mul(v2, smooth_weights2);
1002        let sum = f64x2_add(prod1, prod2);
1003        let smooth_val = (f64x2_extract_lane::<0>(sum) + f64x2_extract_lane::<1>(sum)) * 0.1;
1004
1005        let idx = i % 7;
1006        smooth_buf[idx] = smooth_val;
1007
1008        let x0 = smooth_buf[idx];
1009        let x2 = smooth_buf[(idx + 5) % 7];
1010        let x4 = smooth_buf[(idx + 3) % 7];
1011        let x6 = smooth_buf[(idx + 1) % 7];
1012
1013        let mesa_mult = 0.075 * prev_mesa_period + 0.54;
1014        let dt_val =
1015            hilbert_simd128(x0, x2, x4, x6, hilbert_weights, neg_hilbert_weights) * mesa_mult;
1016        detrender_buf[idx] = dt_val;
1017
1018        let i1_val = if i >= 3 {
1019            detrender_buf[(idx + 4) % 7]
1020        } else {
1021            dt_val
1022        };
1023        i1_buf[idx] = i1_val;
1024
1025        let d0 = detrender_buf[idx];
1026        let d2 = detrender_buf[(idx + 5) % 7];
1027        let d4 = detrender_buf[(idx + 3) % 7];
1028        let d6 = detrender_buf[(idx + 1) % 7];
1029        let q1_val =
1030            hilbert_simd128(d0, d2, d4, d6, hilbert_weights, neg_hilbert_weights) * mesa_mult;
1031        q1_buf[idx] = q1_val;
1032
1033        let j_i = {
1034            let i0 = i1_buf[idx];
1035            let i2 = i1_buf[(idx + 5) % 7];
1036            let i4 = i1_buf[(idx + 3) % 7];
1037            let i6 = i1_buf[(idx + 1) % 7];
1038            hilbert_simd128(i0, i2, i4, i6, hilbert_weights, neg_hilbert_weights) * mesa_mult
1039        };
1040        let j_q = {
1041            let q0 = q1_buf[idx];
1042            let q2 = q1_buf[(idx + 5) % 7];
1043            let q4 = q1_buf[(idx + 3) % 7];
1044            let q6 = q1_buf[(idx + 1) % 7];
1045            hilbert_simd128(q0, q2, q4, q6, hilbert_weights, neg_hilbert_weights) * mesa_mult
1046        };
1047
1048        let i2 = i1_val - j_q;
1049        let q2 = q1_val + j_i;
1050        let i2_sm = 0.2 * i2 + 0.8 * prev_i2_sm;
1051        let q2_sm = 0.2 * q2 + 0.8 * prev_q2_sm;
1052        let re = 0.2 * (i2_sm * prev_i2_sm + q2_sm * prev_q2_sm) + 0.8 * prev_re;
1053        let im = 0.2 * (i2_sm * prev_q2_sm - q2_sm * prev_i2_sm) + 0.8 * prev_im;
1054        prev_i2_sm = i2_sm;
1055        prev_q2_sm = q2_sm;
1056        prev_re = re;
1057        prev_im = im;
1058
1059        let mut mesa_period = if re != 0.0 && im != 0.0 {
1060            2.0 * std::f64::consts::PI / atan_fast(im / re)
1061        } else {
1062            prev_mesa_period
1063        };
1064
1065        if mesa_period > 1.5 * prev_mesa_period {
1066            mesa_period = 1.5 * prev_mesa_period;
1067        }
1068        if mesa_period < 0.67 * prev_mesa_period {
1069            mesa_period = 0.67 * prev_mesa_period;
1070        }
1071        if mesa_period < 6.0 {
1072            mesa_period = 6.0;
1073        }
1074        if mesa_period > 50.0 {
1075            mesa_period = 50.0;
1076        }
1077
1078        let phase = if i1_val != 0.0 {
1079            atan_fast(q1_val / i1_val) * 180.0 / std::f64::consts::PI
1080        } else {
1081            prev_phase
1082        };
1083
1084        let mut dp = prev_phase - phase;
1085        if dp < 1.0 {
1086            dp = 1.0;
1087        }
1088        prev_phase = phase;
1089
1090        let mut alpha = fast_limit / dp;
1091        alpha = alpha.clamp(slow_limit, fast_limit);
1092
1093        prev_mesa_period = mesa_period;
1094
1095        let mama_val = alpha * price + (1.0 - alpha) * prev_mama;
1096        let fama_val = 0.5 * alpha * mama_val + (1.0 - 0.5 * alpha) * prev_fama;
1097
1098        out_mama[i] = mama_val;
1099        out_fama[i] = fama_val;
1100
1101        prev_mama = mama_val;
1102        prev_fama = fama_val;
1103    }
1104}
1105
/// Incremental (one price per call) MAMA/FAMA calculator.
///
/// The smoother and Hilbert-transform history lives in fixed 8-slot ring buffers, so each
/// update performs a constant amount of work regardless of how many prices were seen.
#[derive(Debug, Clone)]
pub struct MamaStream {
    /// Upper bound applied to the adaptive alpha.
    fast_limit: f64,
    /// Lower bound applied to the adaptive alpha.
    slow_limit: f64,

    /// Ring buffer of 4-tap smoothed prices, written at `idx` and read at lags 2, 4 and 6.
    smooth: [f64; 8],
    /// Ring buffer of detrender outputs (also read at lag 3 to form I1).
    detrender: [f64; 8],
    /// Ring buffer of in-phase (I1) components.
    i1_buf: [f64; 8],
    /// Ring buffer of quadrature (Q1) components.
    q1_buf: [f64; 8],
    /// Next write position in the ring buffers; advanced modulo 8 after every sample.
    idx: usize,

    /// Previous smoothed MESA period (feeds the `0.075 * period + 0.54` amplitude factor).
    prev_mesa: f64,
    /// Previous phase, in degrees.
    prev_phase: f64,
    /// Previous MAMA value.
    prev_mama: f64,
    /// Previous FAMA value.
    prev_fama: f64,
    /// Previous smoothed I2 component.
    prev_i2: f64,
    /// Previous smoothed Q2 component.
    prev_q2: f64,
    /// Previous smoothed real part of the I2/Q2 product.
    prev_re: f64,
    /// Previous smoothed imaginary part of the I2/Q2 product.
    prev_im: f64,

    /// Most recent previous price (t-1), used by the 4-tap price smoother.
    last1: f64,
    /// Price at t-2.
    last2: f64,
    /// Price at t-3.
    last3: f64,

    /// Whether the first price has already initialised the buffers and filter state.
    seeded: bool,
    /// Sample counter used by `update` for the warm-up gate and to decide which of
    /// `last1..last3` hold real history.
    seen: usize,
}
1133
1134impl MamaStream {
1135    #[inline]
1136    pub fn try_new(params: MamaParams) -> Result<Self, MamaError> {
1137        let fast_limit = params.fast_limit.unwrap_or(0.5);
1138        let slow_limit = params.slow_limit.unwrap_or(0.05);
1139        if fast_limit <= 0.0 || !fast_limit.is_finite() {
1140            return Err(MamaError::InvalidFastLimit { fast_limit });
1141        }
1142        if slow_limit <= 0.0 || !slow_limit.is_finite() {
1143            return Err(MamaError::InvalidSlowLimit { slow_limit });
1144        }
1145
1146        Ok(Self {
1147            fast_limit,
1148            slow_limit,
1149            smooth: [f64::NAN; 8],
1150            detrender: [f64::NAN; 8],
1151            i1_buf: [f64::NAN; 8],
1152            q1_buf: [f64::NAN; 8],
1153            idx: 0,
1154
1155            prev_mesa: 0.0,
1156            prev_phase: 0.0,
1157            prev_mama: f64::NAN,
1158            prev_fama: f64::NAN,
1159            prev_i2: 0.0,
1160            prev_q2: 0.0,
1161            prev_re: 0.0,
1162            prev_im: 0.0,
1163
1164            last1: f64::NAN,
1165            last2: f64::NAN,
1166            last3: f64::NAN,
1167
1168            seeded: false,
1169            seen: 0,
1170        })
1171    }
1172
1173    #[inline]
1174    pub fn update(&mut self, price: f64) -> Option<(f64, f64)> {
1175        const RING: usize = 8;
1176        const MASK: usize = RING - 1;
1177        const H0: f64 = 0.096_2;
1178        const H1: f64 = 0.576_9;
1179        const H2: f64 = -0.576_9;
1180        const H3: f64 = -0.096_2;
1181        const DEG_PER_RAD: f64 = 180.0 / std::f64::consts::PI;
1182
1183        #[inline(always)]
1184        fn hilbert4(x0: f64, x2: f64, x4: f64, x6: f64) -> f64 {
1185            H0.mul_add(x0, H1.mul_add(x2, H2.mul_add(x4, H3 * x6)))
1186        }
1187        #[inline(always)]
1188        fn lag<const N: usize>(buf: &[f64; N], pos: usize, k: usize) -> f64 {
1189            buf[(pos.wrapping_sub(k)) & (N - 1)]
1190        }
1191
1192        if !self.seeded {
1193            self.smooth = [price; RING];
1194            self.detrender = [price; RING];
1195            self.i1_buf = [price; RING];
1196            self.q1_buf = [price; RING];
1197            self.idx = 0;
1198
1199            self.prev_mesa = 0.0;
1200            self.prev_phase = 0.0;
1201            self.prev_mama = price;
1202            self.prev_fama = price;
1203            self.prev_i2 = 0.0;
1204            self.prev_q2 = 0.0;
1205            self.prev_re = 0.0;
1206            self.prev_im = 0.0;
1207
1208            self.last1 = price;
1209            self.last2 = price;
1210            self.last3 = price;
1211
1212            self.seeded = true;
1213
1214            let _ = self.process_one(price, hilbert4, lag::<RING>, DEG_PER_RAD);
1215
1216            return None;
1217        }
1218
1219        let (mama, fama) = self.process_one(price, hilbert4, lag::<RING>, DEG_PER_RAD);
1220
1221        self.seen += 1;
1222        if self.seen < 10 {
1223            return None;
1224        }
1225        Some((mama, fama))
1226    }
1227
1228    #[inline(always)]
1229    fn process_one(
1230        &mut self,
1231        price: f64,
1232        hilbert4: impl Fn(f64, f64, f64, f64) -> f64,
1233        lag: impl Fn(&[f64; 8], usize, usize) -> f64,
1234        deg_per_rad: f64,
1235    ) -> (f64, f64) {
1236        const MASK: usize = 7;
1237        let i = self.idx;
1238
1239        let s1 = if self.seen >= 1 { self.last1 } else { price };
1240        let s2 = if self.seen >= 2 { self.last2 } else { price };
1241        let s3 = if self.seen >= 3 { self.last3 } else { price };
1242        let smooth_val =
1243            0.1 * (4.0_f64.mul_add(price, 3.0_f64.mul_add(s1, 2.0_f64.mul_add(s2, s3))));
1244        self.smooth[i] = smooth_val;
1245
1246        let amp = 0.075_f64.mul_add(self.prev_mesa, 0.54);
1247
1248        let dt = amp
1249            * hilbert4(
1250                self.smooth[i],
1251                lag(&self.smooth, i, 2),
1252                lag(&self.smooth, i, 4),
1253                lag(&self.smooth, i, 6),
1254            );
1255        self.detrender[i] = dt;
1256
1257        let i1 = lag(&self.detrender, i, 3);
1258        self.i1_buf[i] = i1;
1259
1260        let q1 = amp
1261            * hilbert4(
1262                self.detrender[i],
1263                lag(&self.detrender, i, 2),
1264                lag(&self.detrender, i, 4),
1265                lag(&self.detrender, i, 6),
1266            );
1267        self.q1_buf[i] = q1;
1268
1269        let j_i = amp
1270            * hilbert4(
1271                self.i1_buf[i],
1272                lag(&self.i1_buf, i, 2),
1273                lag(&self.i1_buf, i, 4),
1274                lag(&self.i1_buf, i, 6),
1275            );
1276        let j_q = amp
1277            * hilbert4(
1278                self.q1_buf[i],
1279                lag(&self.q1_buf, i, 2),
1280                lag(&self.q1_buf, i, 4),
1281                lag(&self.q1_buf, i, 6),
1282            );
1283
1284        let i2 = i1 - j_q;
1285        let q2 = q1 + j_i;
1286
1287        let old_i2 = self.prev_i2;
1288        let old_q2 = self.prev_q2;
1289
1290        let i2s = 0.2_f64.mul_add(i2, 0.8 * old_i2);
1291        let q2s = 0.2_f64.mul_add(q2, 0.8 * old_q2);
1292        self.prev_i2 = i2s;
1293        self.prev_q2 = q2s;
1294
1295        let re = 0.2_f64.mul_add(i2s * old_i2 + q2s * old_q2, 0.8 * self.prev_re);
1296        let im = 0.2_f64.mul_add(i2s * old_q2 - q2s * old_i2, 0.8 * self.prev_im);
1297        self.prev_re = re;
1298        self.prev_im = im;
1299
1300        let mut mesa = if re != 0.0 && im != 0.0 {
1301            2.0 * std::f64::consts::PI / atan_fast(im / re)
1302        } else {
1303            self.prev_mesa
1304        };
1305
1306        mesa = mesa
1307            .min(1.5 * self.prev_mesa)
1308            .max(0.67 * self.prev_mesa)
1309            .max(6.0)
1310            .min(50.0);
1311        mesa = 0.2_f64.mul_add(mesa, 0.8 * self.prev_mesa);
1312        self.prev_mesa = mesa;
1313
1314        let phase = if i1 != 0.0 {
1315            atan_fast(q1 / i1) * deg_per_rad
1316        } else {
1317            self.prev_phase
1318        };
1319        let mut dphi = self.prev_phase - phase;
1320        if dphi < 1.0 {
1321            dphi = 1.0;
1322        }
1323        self.prev_phase = phase;
1324
1325        let mut alpha = self.fast_limit / dphi;
1326        if alpha < self.slow_limit {
1327            alpha = self.slow_limit;
1328        }
1329        if alpha > self.fast_limit {
1330            alpha = self.fast_limit;
1331        }
1332
1333        let one_minus_alpha = 1.0 - alpha;
1334        let mama = alpha.mul_add(price, one_minus_alpha * self.prev_mama);
1335
1336        let half_alpha = 0.5 * alpha;
1337        let fama = half_alpha.mul_add(mama, (1.0 - half_alpha) * self.prev_fama);
1338
1339        self.prev_mama = mama;
1340        self.prev_fama = fama;
1341
1342        self.idx = (self.idx + 1) & MASK;
1343        self.last3 = self.last2;
1344        self.last2 = self.last1;
1345        self.last1 = price;
1346
1347        (mama, fama)
1348    }
1349}
1350
/// Parameter sweep for batch MAMA evaluation.
///
/// Each axis is a `(start, end, step)` tuple expanded by `expand_grid`; a zero step (or
/// `start == end`) yields the single value `start`.
#[derive(Clone, Debug)]
pub struct MamaBatchRange {
    /// `(start, end, step)` for the fast (upper) alpha limit.
    pub fast_limit: (f64, f64, f64),
    /// `(start, end, step)` for the slow (lower) alpha limit.
    pub slow_limit: (f64, f64, f64),
}
1356
1357impl Default for MamaBatchRange {
1358    fn default() -> Self {
1359        Self {
1360            fast_limit: (0.5, 0.749, 0.001),
1361            slow_limit: (0.05, 0.05, 0.0),
1362        }
1363    }
1364}
1365
/// Fluent builder for batch MAMA runs: collects the parameter sweep and the kernel, then
/// dispatches to `mama_batch_with_kernel` via the `apply_*` methods.
#[derive(Clone, Debug, Default)]
pub struct MamaBatchBuilder {
    /// Sweep of `fast_limit` / `slow_limit` values to evaluate.
    range: MamaBatchRange,
    /// Kernel handed to `mama_batch_with_kernel`.
    kernel: Kernel,
}
1371
1372impl MamaBatchBuilder {
1373    pub fn new() -> Self {
1374        Self::default()
1375    }
1376    pub fn kernel(mut self, k: Kernel) -> Self {
1377        self.kernel = k;
1378        self
1379    }
1380    #[inline]
1381    pub fn fast_limit_range(mut self, start: f64, end: f64, step: f64) -> Self {
1382        self.range.fast_limit = (start, end, step);
1383        self
1384    }
1385    #[inline]
1386    pub fn fast_limit_static(mut self, x: f64) -> Self {
1387        self.range.fast_limit = (x, x, 0.0);
1388        self
1389    }
1390    #[inline]
1391    pub fn slow_limit_range(mut self, start: f64, end: f64, step: f64) -> Self {
1392        self.range.slow_limit = (start, end, step);
1393        self
1394    }
1395    #[inline]
1396    pub fn slow_limit_static(mut self, x: f64) -> Self {
1397        self.range.slow_limit = (x, x, 0.0);
1398        self
1399    }
1400    pub fn apply_slice(self, data: &[f64]) -> Result<MamaBatchOutput, MamaError> {
1401        mama_batch_with_kernel(data, &self.range, self.kernel)
1402    }
1403    pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<MamaBatchOutput, MamaError> {
1404        MamaBatchBuilder::new().kernel(k).apply_slice(data)
1405    }
1406    pub fn apply_candles(self, c: &Candles, src: &str) -> Result<MamaBatchOutput, MamaError> {
1407        let slice = source_type(c, src);
1408        self.apply_slice(slice)
1409    }
1410    pub fn with_default_candles(c: &Candles) -> Result<MamaBatchOutput, MamaError> {
1411        MamaBatchBuilder::new()
1412            .kernel(Kernel::Auto)
1413            .apply_candles(c, "close")
1414    }
1415}
1416
/// Result of a batch MAMA run: one output row per parameter combination.
#[derive(Clone, Debug)]
pub struct MamaBatchOutput {
    /// MAMA values, row-major (`rows * cols`); row `r` starts at `r * cols`.
    pub mama_values: Vec<f64>,
    /// FAMA values, same layout as `mama_values`.
    pub fama_values: Vec<f64>,
    /// Parameters used for each row; `combos[r]` produced row `r`.
    pub combos: Vec<MamaParams>,
    /// Number of parameter combinations (rows).
    pub rows: usize,
    /// Number of samples per row (the input length).
    pub cols: usize,
}
1425
1426impl MamaBatchOutput {
1427    pub fn row_for_params(&self, p: &MamaParams) -> Option<usize> {
1428        self.combos.iter().position(|c| {
1429            (c.fast_limit.unwrap_or(0.5) - p.fast_limit.unwrap_or(0.5)).abs() < 1e-12
1430                && (c.slow_limit.unwrap_or(0.05) - p.slow_limit.unwrap_or(0.05)).abs() < 1e-12
1431        })
1432    }
1433    pub fn mama_for(&self, p: &MamaParams) -> Option<&[f64]> {
1434        self.row_for_params(p).map(|row| {
1435            let start = row * self.cols;
1436            &self.mama_values[start..start + self.cols]
1437        })
1438    }
1439    pub fn fama_for(&self, p: &MamaParams) -> Option<&[f64]> {
1440        self.row_for_params(p).map(|row| {
1441            let start = row * self.cols;
1442            &self.fama_values[start..start + self.cols]
1443        })
1444    }
1445}
1446
1447#[inline(always)]
1448pub fn expand_grid(r: &MamaBatchRange) -> Result<Vec<MamaParams>, MamaError> {
1449    fn axis_f64((start, end, step): (f64, f64, f64)) -> Result<Vec<f64>, MamaError> {
1450        if step.abs() < 1e-12 || (start - end).abs() < 1e-12 {
1451            return Ok(vec![start]);
1452        }
1453
1454        let mut step_signed = step;
1455        if end < start && step_signed > 0.0 {
1456            step_signed = -step_signed;
1457        } else if end > start && step_signed < 0.0 {
1458            step_signed = -step_signed;
1459        }
1460
1461        let mut v = Vec::new();
1462        let eps = 1e-12_f64;
1463        let mut x = start;
1464        if step_signed > 0.0 {
1465            while x <= end + eps {
1466                v.push(x);
1467                x += step_signed;
1468            }
1469        } else {
1470            while x >= end - eps {
1471                v.push(x);
1472                x += step_signed;
1473            }
1474        }
1475
1476        if v.is_empty() {
1477            return Err(MamaError::InvalidRange { start, end, step });
1478        }
1479        Ok(v)
1480    }
1481
1482    let fast_limits = axis_f64(r.fast_limit)?;
1483    let slow_limits = axis_f64(r.slow_limit)?;
1484
1485    let cap = fast_limits
1486        .len()
1487        .checked_mul(slow_limits.len())
1488        .ok_or(MamaError::InvalidRange {
1489            start: r.fast_limit.0,
1490            end: r.fast_limit.1,
1491            step: r.fast_limit.2,
1492        })?;
1493
1494    let mut out = Vec::with_capacity(cap);
1495    for &f in &fast_limits {
1496        for &s in &slow_limits {
1497            out.push(MamaParams {
1498                fast_limit: Some(f),
1499                slow_limit: Some(s),
1500            });
1501        }
1502    }
1503    Ok(out)
1504}
1505
1506pub fn mama_batch_with_kernel(
1507    data: &[f64],
1508    sweep: &MamaBatchRange,
1509    k: Kernel,
1510) -> Result<MamaBatchOutput, MamaError> {
1511    let kernel = match k {
1512        Kernel::Auto => Kernel::ScalarBatch,
1513        other if other.is_batch() => other,
1514        other => return Err(MamaError::InvalidKernelForBatch(other)),
1515    };
1516
1517    let simd = Kernel::Scalar;
1518    mama_batch_par_slice(data, sweep, simd)
1519}
1520
1521#[inline(always)]
1522pub fn mama_batch_slice(
1523    data: &[f64],
1524    sweep: &MamaBatchRange,
1525    kern: Kernel,
1526) -> Result<MamaBatchOutput, MamaError> {
1527    mama_batch_inner(data, sweep, kern, false)
1528}
1529
1530#[inline(always)]
1531pub fn mama_batch_par_slice(
1532    data: &[f64],
1533    sweep: &MamaBatchRange,
1534    kern: Kernel,
1535) -> Result<MamaBatchOutput, MamaError> {
1536    mama_batch_inner(data, sweep, kern, true)
1537}
1538
1539fn mama_batch_inner(
1540    data: &[f64],
1541    sweep: &MamaBatchRange,
1542    kern: Kernel,
1543    parallel: bool,
1544) -> Result<MamaBatchOutput, MamaError> {
1545    let combos = expand_grid(sweep)?;
1546    if combos.is_empty() {
1547        return Err(MamaError::InvalidRange {
1548            start: sweep.fast_limit.0,
1549            end: sweep.fast_limit.1,
1550            step: sweep.fast_limit.2,
1551        });
1552    }
1553    if data.len() < 10 {
1554        return Err(MamaError::NotEnoughData {
1555            needed: 10,
1556            found: data.len(),
1557        });
1558    }
1559
1560    for combo in &combos {
1561        let fast_limit = combo.fast_limit.unwrap_or(0.5);
1562        let slow_limit = combo.slow_limit.unwrap_or(0.05);
1563
1564        if fast_limit <= 0.0 || fast_limit.is_nan() || fast_limit.is_infinite() {
1565            return Err(MamaError::InvalidFastLimit { fast_limit });
1566        }
1567        if slow_limit <= 0.0 || slow_limit.is_nan() || slow_limit.is_infinite() {
1568            return Err(MamaError::InvalidSlowLimit { slow_limit });
1569        }
1570    }
1571
1572    let rows = combos.len();
1573    let cols = data.len();
1574
1575    let mut raw_mama = make_uninit_matrix(rows, cols);
1576    let mut raw_fama = make_uninit_matrix(rows, cols);
1577
1578    let warm_prefixes = vec![10; rows];
1579    unsafe {
1580        init_matrix_prefixes(&mut raw_mama, cols, &warm_prefixes);
1581        init_matrix_prefixes(&mut raw_fama, cols, &warm_prefixes);
1582    }
1583
1584    let delta_phase: Vec<f64> = {
1585        const RING: usize = 8;
1586        const MASK: usize = RING - 1;
1587        const H0: f64 = 0.096_2;
1588        const H1: f64 = 0.576_9;
1589        const H2: f64 = -0.576_9;
1590        const H3: f64 = -0.096_2;
1591        const DEG_PER_RAD: f64 = 180.0 / std::f64::consts::PI;
1592
1593        #[inline(always)]
1594        fn hilbert4(x0: f64, x2: f64, x4: f64, x6: f64) -> f64 {
1595            H0.mul_add(x0, H1.mul_add(x2, H2.mul_add(x4, H3 * x6)))
1596        }
1597        #[inline(always)]
1598        fn lag<const N: usize>(buf: &[f64; N], pos: usize, k: usize) -> f64 {
1599            buf[(pos.wrapping_sub(k)) & (N - 1)]
1600        }
1601
1602        let mut out = vec![1.0; cols];
1603        if cols == 0 {
1604            out
1605        } else {
1606            let first = data[0];
1607            let mut smooth = [first; RING];
1608            let mut detrender = [first; RING];
1609            let mut i1_buf = [first; RING];
1610            let mut q1_buf = [first; RING];
1611
1612            let mut idx = 0usize;
1613            let mut prev_mesa = 0.0;
1614            let mut prev_phase = 0.0;
1615            let mut prev_i2 = 0.0;
1616            let mut prev_q2 = 0.0;
1617            let mut prev_re = 0.0;
1618            let mut prev_im = 0.0;
1619
1620            for (i, &price) in data.iter().enumerate() {
1621                let s1 = if i >= 1 { data[i - 1] } else { price };
1622                let s2 = if i >= 2 { data[i - 2] } else { price };
1623                let s3 = if i >= 3 { data[i - 3] } else { price };
1624                let smooth_val =
1625                    0.1 * (4.0_f64.mul_add(price, 3.0_f64.mul_add(s1, 2.0_f64.mul_add(s2, s3))));
1626                smooth[idx] = smooth_val;
1627
1628                let amp = 0.075_f64.mul_add(prev_mesa, 0.54);
1629                let dt = amp
1630                    * hilbert4(
1631                        smooth[idx],
1632                        lag(&smooth, idx, 2),
1633                        lag(&smooth, idx, 4),
1634                        lag(&smooth, idx, 6),
1635                    );
1636                detrender[idx] = dt;
1637
1638                let i1 = lag(&detrender, idx, 3);
1639                i1_buf[idx] = i1;
1640                let q1 = amp
1641                    * hilbert4(
1642                        detrender[idx],
1643                        lag(&detrender, idx, 2),
1644                        lag(&detrender, idx, 4),
1645                        lag(&detrender, idx, 6),
1646                    );
1647                q1_buf[idx] = q1;
1648
1649                let j_i = amp
1650                    * hilbert4(
1651                        i1_buf[idx],
1652                        lag(&i1_buf, idx, 2),
1653                        lag(&i1_buf, idx, 4),
1654                        lag(&i1_buf, idx, 6),
1655                    );
1656                let j_q = amp
1657                    * hilbert4(
1658                        q1_buf[idx],
1659                        lag(&q1_buf, idx, 2),
1660                        lag(&q1_buf, idx, 4),
1661                        lag(&q1_buf, idx, 6),
1662                    );
1663
1664                let i2 = i1 - j_q;
1665                let q2 = q1 + j_i;
1666                let old_i2 = prev_i2;
1667                let old_q2 = prev_q2;
1668                let i2s = 0.2_f64.mul_add(i2, 0.8 * old_i2);
1669                let q2s = 0.2_f64.mul_add(q2, 0.8 * old_q2);
1670                prev_i2 = i2s;
1671                prev_q2 = q2s;
1672                let re = 0.2_f64.mul_add(i2s * old_i2 + q2s * old_q2, 0.8 * prev_re);
1673                let im = 0.2_f64.mul_add(i2s * old_q2 - q2s * old_i2, 0.8 * prev_im);
1674                prev_re = re;
1675                prev_im = im;
1676
1677                let mut mesa = if re != 0.0 && im != 0.0 {
1678                    2.0 * std::f64::consts::PI / atan_fast(im / re)
1679                } else {
1680                    prev_mesa
1681                };
1682                if mesa > 1.5 * prev_mesa {
1683                    mesa = 1.5 * prev_mesa;
1684                }
1685                if mesa < 0.67 * prev_mesa {
1686                    mesa = 0.67 * prev_mesa;
1687                }
1688                if mesa < 6.0 {
1689                    mesa = 6.0;
1690                }
1691                if mesa > 50.0 {
1692                    mesa = 50.0;
1693                }
1694                mesa = 0.2_f64.mul_add(mesa, 0.8 * prev_mesa);
1695                prev_mesa = mesa;
1696
1697                let phase = if i1 != 0.0 {
1698                    atan_fast(q1 / i1) * DEG_PER_RAD
1699                } else {
1700                    prev_phase
1701                };
1702                let mut dphi = prev_phase - phase;
1703                if dphi < 1.0 {
1704                    dphi = 1.0;
1705                }
1706                prev_phase = phase;
1707                out[i] = dphi;
1708
1709                idx = (idx + 1) & MASK;
1710            }
1711            out
1712        }
1713    };
1714
1715    let do_row = |row: usize, dst_m: &mut [MaybeUninit<f64>], dst_f: &mut [MaybeUninit<f64>]| unsafe {
1716        let prm = &combos[row];
1717        let fast = prm.fast_limit.unwrap_or(0.5);
1718        let slow = prm.slow_limit.unwrap_or(0.05);
1719
1720        let out_m = core::slice::from_raw_parts_mut(dst_m.as_mut_ptr() as *mut f64, dst_m.len());
1721        let out_f = core::slice::from_raw_parts_mut(dst_f.as_mut_ptr() as *mut f64, dst_f.len());
1722
1723        let mut prev_mama = data[0];
1724        let mut prev_fama = data[0];
1725        for i in 0..cols {
1726            let price = data[i];
1727            let mut alpha = fast / delta_phase[i];
1728            if alpha < slow {
1729                alpha = slow;
1730            }
1731            if alpha > fast {
1732                alpha = fast;
1733            }
1734
1735            let mama = alpha.mul_add(price, (1.0 - alpha) * prev_mama);
1736            let fama = (0.5 * alpha).mul_add(mama, (1.0 - 0.5 * alpha) * prev_fama);
1737            prev_mama = mama;
1738            prev_fama = fama;
1739            out_m[i] = mama;
1740            out_f[i] = fama;
1741        }
1742
1743        for j in 0..10.min(out_m.len()) {
1744            out_m[j] = f64::NAN;
1745            out_f[j] = f64::NAN;
1746        }
1747    };
1748
1749    if parallel {
1750        #[cfg(not(target_arch = "wasm32"))]
1751        {
1752            raw_mama
1753                .par_chunks_mut(cols)
1754                .zip(raw_fama.par_chunks_mut(cols))
1755                .enumerate()
1756                .for_each(|(row, (m_row, f_row))| do_row(row, m_row, f_row));
1757        }
1758
1759        #[cfg(target_arch = "wasm32")]
1760        {
1761            for (row, (m_row, f_row)) in raw_mama
1762                .chunks_mut(cols)
1763                .zip(raw_fama.chunks_mut(cols))
1764                .enumerate()
1765            {
1766                do_row(row, m_row, f_row);
1767            }
1768        }
1769    } else {
1770        for (row, (m_row, f_row)) in raw_mama
1771            .chunks_mut(cols)
1772            .zip(raw_fama.chunks_mut(cols))
1773            .enumerate()
1774        {
1775            do_row(row, m_row, f_row);
1776        }
1777    }
1778
1779    let mut guard_m = core::mem::ManuallyDrop::new(raw_mama);
1780    let mut guard_f = core::mem::ManuallyDrop::new(raw_fama);
1781
1782    let mama_values = unsafe {
1783        Vec::from_raw_parts(
1784            guard_m.as_mut_ptr() as *mut f64,
1785            guard_m.len(),
1786            guard_m.capacity(),
1787        )
1788    };
1789    let fama_values = unsafe {
1790        Vec::from_raw_parts(
1791            guard_f.as_mut_ptr() as *mut f64,
1792            guard_f.len(),
1793            guard_f.capacity(),
1794        )
1795    };
1796
1797    Ok(MamaBatchOutput {
1798        mama_values,
1799        fama_values,
1800        combos,
1801        rows,
1802        cols,
1803    })
1804}
1805
1806pub fn mama_batch_inner_into(
1807    data: &[f64],
1808    sweep: &MamaBatchRange,
1809    kern: Kernel,
1810    parallel: bool,
1811    out_mama: &mut [f64],
1812    out_fama: &mut [f64],
1813) -> Result<Vec<MamaParams>, MamaError> {
1814    let combos = expand_grid(sweep)?;
1815    if combos.is_empty() {
1816        return Err(MamaError::InvalidRange {
1817            start: sweep.fast_limit.0,
1818            end: sweep.fast_limit.1,
1819            step: sweep.fast_limit.2,
1820        });
1821    }
1822    if data.len() < 10 {
1823        return Err(MamaError::NotEnoughData {
1824            needed: 10,
1825            found: data.len(),
1826        });
1827    }
1828
1829    for combo in &combos {
1830        let fast_limit = combo.fast_limit.unwrap_or(0.5);
1831        let slow_limit = combo.slow_limit.unwrap_or(0.05);
1832
1833        if fast_limit <= 0.0 || fast_limit.is_nan() || fast_limit.is_infinite() {
1834            return Err(MamaError::InvalidFastLimit { fast_limit });
1835        }
1836        if slow_limit <= 0.0 || slow_limit.is_nan() || slow_limit.is_infinite() {
1837            return Err(MamaError::InvalidSlowLimit { slow_limit });
1838        }
1839    }
1840
1841    let rows = combos.len();
1842    let cols = data.len();
1843
1844    let expected = rows.checked_mul(cols).ok_or(MamaError::InvalidRange {
1845        start: sweep.fast_limit.0,
1846        end: sweep.fast_limit.1,
1847        step: sweep.fast_limit.2,
1848    })?;
1849    if out_mama.len() != expected || out_fama.len() != expected {
1850        return Err(MamaError::OutputLengthMismatch {
1851            expected,
1852            got: out_mama.len().min(out_fama.len()),
1853        });
1854    }
1855
1856    let out_mama_uninit = unsafe {
1857        std::slice::from_raw_parts_mut(
1858            out_mama.as_mut_ptr() as *mut MaybeUninit<f64>,
1859            out_mama.len(),
1860        )
1861    };
1862    let out_fama_uninit = unsafe {
1863        std::slice::from_raw_parts_mut(
1864            out_fama.as_mut_ptr() as *mut MaybeUninit<f64>,
1865            out_fama.len(),
1866        )
1867    };
1868
1869    let warm_prefixes = vec![10; rows];
1870    unsafe {
1871        init_matrix_prefixes(out_mama_uninit, cols, &warm_prefixes);
1872        init_matrix_prefixes(out_fama_uninit, cols, &warm_prefixes);
1873    }
1874
1875    let do_row = |row: usize, dst_m: &mut [MaybeUninit<f64>], dst_f: &mut [MaybeUninit<f64>]| unsafe {
1876        let prm = &combos[row];
1877        let fast = prm.fast_limit.unwrap_or(0.5);
1878        let slow = prm.slow_limit.unwrap_or(0.05);
1879
1880        let out_m = core::slice::from_raw_parts_mut(dst_m.as_mut_ptr() as *mut f64, dst_m.len());
1881        let out_f = core::slice::from_raw_parts_mut(dst_f.as_mut_ptr() as *mut f64, dst_f.len());
1882
1883        match kern {
1884            Kernel::Scalar => mama_row_scalar(data, fast, slow, out_m, out_f),
1885            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1886            Kernel::Avx2 => mama_row_avx2(data, fast, slow, out_m, out_f),
1887            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1888            Kernel::Avx512 => mama_row_avx512(data, fast, slow, out_m, out_f),
1889            _ => unreachable!(),
1890        }
1891
1892        for j in 0..10.min(out_m.len()) {
1893            out_m[j] = f64::NAN;
1894            out_f[j] = f64::NAN;
1895        }
1896    };
1897
1898    if parallel {
1899        #[cfg(not(target_arch = "wasm32"))]
1900        {
1901            out_mama_uninit
1902                .par_chunks_mut(cols)
1903                .zip(out_fama_uninit.par_chunks_mut(cols))
1904                .enumerate()
1905                .for_each(|(row, (m_row, f_row))| do_row(row, m_row, f_row));
1906        }
1907
1908        #[cfg(target_arch = "wasm32")]
1909        {
1910            for (row, (m_row, f_row)) in out_mama_uninit
1911                .chunks_mut(cols)
1912                .zip(out_fama_uninit.chunks_mut(cols))
1913                .enumerate()
1914            {
1915                do_row(row, m_row, f_row);
1916            }
1917        }
1918    } else {
1919        for (row, (m_row, f_row)) in out_mama_uninit
1920            .chunks_mut(cols)
1921            .zip(out_fama_uninit.chunks_mut(cols))
1922            .enumerate()
1923        {
1924            do_row(row, m_row, f_row);
1925        }
1926    }
1927
1928    Ok(combos)
1929}
1930
1931#[inline(always)]
1932pub unsafe fn mama_row_scalar(
1933    data: &[f64],
1934    fast_limit: f64,
1935    slow_limit: f64,
1936    out_mama: &mut [f64],
1937    out_fama: &mut [f64],
1938) {
1939    mama_scalar_inplace(data, fast_limit, slow_limit, out_mama, out_fama);
1940}
1941
/// Computes one MAMA/FAMA output row with the AVX2 kernel.
///
/// This is a thin forwarder to `mama_avx2_inplace`, only compiled with `nightly-avx` on
/// x86_64.
///
/// # Safety
///
/// NOTE(review): the contract is inherited from `mama_avx2_inplace`, which is defined
/// elsewhere; presumably the CPU must support AVX2. Callers in this module pass output rows
/// whose length equals `data.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn mama_row_avx2(
    data: &[f64],
    fast_limit: f64,
    slow_limit: f64,
    out_mama: &mut [f64],
    out_fama: &mut [f64],
) {
    // Delegate the whole row to the single-series AVX2 implementation.
    mama_avx2_inplace(data, fast_limit, slow_limit, out_mama, out_fama);
}
1953
/// Computes one MAMA/FAMA output row with the AVX-512 kernel.
///
/// This is a thin forwarder to `mama_avx512_inplace`, only compiled with `nightly-avx` on
/// x86_64.
///
/// # Safety
///
/// NOTE(review): the contract is inherited from `mama_avx512_inplace`, which is defined
/// elsewhere; presumably the CPU must support AVX-512F. Callers in this module pass output
/// rows whose length equals `data.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn mama_row_avx512(
    data: &[f64],
    fast_limit: f64,
    slow_limit: f64,
    out_mama: &mut [f64],
    out_fama: &mut [f64],
) {
    // Delegate the whole row to the single-series AVX-512 implementation.
    mama_avx512_inplace(data, fast_limit, slow_limit, out_mama, out_fama);
}
1965
#[cfg(test)]
mod tests {
    use super::*;
    use crate::skip_if_unsupported;
    use crate::utilities::data_loader::read_candles_from_csv;
    use paste::paste;
    use proptest::prelude::*;

    /// Candle fixture shared by every CSV-backed test in this module.
    const CANDLE_FILE: &str = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";

    /// `None` limits must fall back to defaults and still produce full-length outputs.
    fn check_mama_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLE_FILE)?;
        let default_params = MamaParams {
            fast_limit: None,
            slow_limit: None,
        };
        let input = MamaInput::from_candles(&candles, "close", default_params);
        let output = mama_with_kernel(&input, kernel)?;
        assert_eq!(output.mama_values.len(), candles.close.len());
        assert_eq!(output.fama_values.len(), candles.close.len());
        Ok(())
    }

    /// Runs default params over the fixture.
    ///
    /// NOTE(review): despite the name, this only checks output lengths; no
    /// reference values are pinned yet.
    fn check_mama_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLE_FILE)?;
        let input = MamaInput::from_candles(&candles, "close", MamaParams::default());
        let result = mama_with_kernel(&input, kernel)?;
        assert_eq!(result.mama_values.len(), candles.close.len());
        assert_eq!(result.fama_values.len(), candles.close.len());
        Ok(())
    }

    /// `with_default_candles` must select the `"close"` source.
    fn check_mama_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLE_FILE)?;
        let input = MamaInput::with_default_candles(&candles);
        match input.data {
            MamaData::Candles { source, .. } => assert_eq!(source, "close"),
            _ => panic!("Expected MamaData::Candles"),
        }
        let output = mama_with_kernel(&input, kernel)?;
        assert_eq!(output.mama_values.len(), candles.close.len());
        assert_eq!(output.fama_values.len(), candles.close.len());
        Ok(())
    }

    /// Nine points is below the minimum the indicator accepts, so it must error.
    fn check_mama_with_insufficient_data(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let input_data = [100.0; 9];
        let params = MamaParams::default();
        let input = MamaInput::from_slice(&input_data, params);
        let res = mama_with_kernel(&input, kernel);
        assert!(res.is_err());
        Ok(())
    }

    /// Ten points is accepted and yields full-length outputs.
    fn check_mama_very_small_dataset(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let input_data = [42.0; 10];
        let params = MamaParams::default();
        let input = MamaInput::from_slice(&input_data, params);
        let result = mama_with_kernel(&input, kernel)?;
        assert_eq!(result.mama_values.len(), input_data.len());
        assert_eq!(result.fama_values.len(), input_data.len());
        Ok(())
    }

    /// Feeding MAMA output back in as input must still produce full-length series.
    fn check_mama_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLE_FILE)?;
        let first_params = MamaParams::default();
        let first_input = MamaInput::from_candles(&candles, "close", first_params);
        let first_result = mama_with_kernel(&first_input, kernel)?;
        let second_params = MamaParams {
            fast_limit: Some(0.7),
            slow_limit: Some(0.1),
        };
        let second_input = MamaInput::from_slice(&first_result.mama_values, second_params);
        let second_result = mama_with_kernel(&second_input, kernel)?;
        assert_eq!(
            second_result.mama_values.len(),
            first_result.mama_values.len()
        );
        assert_eq!(
            second_result.fama_values.len(),
            first_result.fama_values.len()
        );
        Ok(())
    }

    /// Past index 20 every MAMA/FAMA value on the fixture must be finite.
    fn check_mama_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLE_FILE)?;
        let params = MamaParams::default();
        let input = MamaInput::from_candles(&candles, "close", params);
        let result = mama_with_kernel(&input, kernel)?;
        for (i, &val) in result.mama_values.iter().enumerate() {
            if i > 20 {
                assert!(val.is_finite());
            }
        }
        for (i, &val) in result.fama_values.iter().enumerate() {
            if i > 20 {
                assert!(val.is_finite());
            }
        }
        Ok(())
    }

    /// Emits one `#[test]` per check function and kernel.
    ///
    /// The `cfg` is repeated on every SIMD test. A single attribute placed in
    /// front of a `$(...)*` repetition would only apply to the first emitted
    /// item. Errors returned by a check are unwrapped so they fail the test
    /// instead of being silently discarded.
    macro_rules! generate_all_mama_tests {
        ($($test_fn:ident),*) => {
            paste! {
                $(
                    #[test]
                    fn [<$test_fn _scalar_f64>]() {
                        $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar).unwrap();
                    }
                )*
                $(
                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx2_f64>]() {
                        $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2).unwrap();
                    }
                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx512_f64>]() {
                        $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512).unwrap();
                    }
                )*
            }
        }
    }

    /// Names the allocation helper whose debug-build poison pattern equals `bits`, if any.
    #[cfg(debug_assertions)]
    fn poison_source(bits: u64) -> Option<&'static str> {
        match bits {
            0x11111111_11111111 => Some("alloc_with_nan_prefix"),
            0x22222222_22222222 => Some("init_matrix_prefixes"),
            0x33333333_33333333 => Some("make_uninit_matrix"),
            _ => None,
        }
    }

    /// In debug builds, no output slot may still hold an allocator poison pattern.
    #[cfg(debug_assertions)]
    fn check_mama_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);

        let candles = read_candles_from_csv(CANDLE_FILE)?;

        let test_cases = vec![
            MamaParams::default(),
            MamaParams {
                fast_limit: Some(0.3),
                slow_limit: Some(0.03),
            },
            MamaParams {
                fast_limit: Some(0.4),
                slow_limit: Some(0.04),
            },
            MamaParams {
                fast_limit: Some(0.5),
                slow_limit: Some(0.05),
            },
            MamaParams {
                fast_limit: Some(0.6),
                slow_limit: Some(0.06),
            },
            MamaParams {
                fast_limit: Some(0.7),
                slow_limit: Some(0.07),
            },
            MamaParams {
                fast_limit: Some(0.8),
                slow_limit: Some(0.01),
            },
            MamaParams {
                fast_limit: Some(0.2),
                slow_limit: Some(0.1),
            },
            MamaParams {
                fast_limit: Some(0.9),
                slow_limit: Some(0.02),
            },
        ];

        for params in test_cases {
            let input = MamaInput::from_candles(&candles, "close", params.clone());
            let output = mama_with_kernel(&input, kernel)?;

            for (series, values) in [
                ("mama_values", &output.mama_values),
                ("fama_values", &output.fama_values),
            ] {
                for (i, &val) in values.iter().enumerate() {
                    if val.is_nan() {
                        continue;
                    }
                    let bits = val.to_bits();
                    if let Some(source) = poison_source(bits) {
                        panic!(
                            "[{}] Found {} poison value {} (0x{:016X}) at index {} in {} with params fast_limit={:?}, slow_limit={:?}",
                            test_name, source, val, bits, i, series, params.fast_limit, params.slow_limit
                        );
                    }
                }
            }
        }

        Ok(())
    }

    #[cfg(not(debug_assertions))]
    fn check_mama_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    /// Property test over random series and limits, cross-checked against the scalar kernel.
    fn check_mama_property(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);

        let strat = (10usize..=200).prop_flat_map(|len| {
            (
                prop::collection::vec(
                    (-1e5f64..1e5f64).prop_filter("finite", |x| x.is_finite()),
                    len,
                ),
                (0.01f64..0.99f64).prop_filter("valid fast_limit", |x| x.is_finite() && *x > 0.0),
                (0.001f64..0.5f64).prop_filter("valid slow_limit", |x| x.is_finite() && *x > 0.0),
            )
        });

        proptest::test_runner::TestRunner::default()
            .run(&strat, |(data, fast_limit, slow_limit)| {
                // Keep slow strictly below fast.
                let slow = slow_limit.min(fast_limit * 0.9);

                let params = MamaParams {
                    fast_limit: Some(fast_limit),
                    slow_limit: Some(slow),
                };
                let input = MamaInput::from_slice(&data, params);

                let result = mama_with_kernel(&input, kernel).unwrap();
                let mama_out = &result.mama_values;
                let fama_out = &result.fama_values;

                let ref_result = mama_with_kernel(&input, Kernel::Scalar).unwrap();
                let ref_mama = &ref_result.mama_values;
                let ref_fama = &ref_result.fama_values;

                prop_assert_eq!(mama_out.len(), data.len(), "MAMA output length mismatch");
                prop_assert_eq!(fama_out.len(), data.len(), "FAMA output length mismatch");

                const WARMUP: usize = 10;
                for i in 0..data.len() {
                    if i < WARMUP {
                        prop_assert!(
                            mama_out[i].is_nan(),
                            "MAMA should have NaN warmup at index {}, got {}",
                            i,
                            mama_out[i]
                        );
                        prop_assert!(
                            fama_out[i].is_nan(),
                            "FAMA should have NaN warmup at index {}, got {}",
                            i,
                            fama_out[i]
                        );
                    } else {
                        prop_assert!(
                            mama_out[i].is_finite(),
                            "MAMA should output finite values at index {}, got {}",
                            i,
                            mama_out[i]
                        );
                        prop_assert!(
                            fama_out[i].is_finite(),
                            "FAMA should output finite values at index {}, got {}",
                            i,
                            fama_out[i]
                        );
                    }
                }

                let data_min = data.iter().cloned().fold(f64::INFINITY, f64::min);
                let data_max = data.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
                let data_range = data_max - data_min;

                let tolerance = data_range * 0.2 + 10.0;

                for i in WARMUP..data.len() {
                    prop_assert!(
                        mama_out[i] >= data_min - tolerance && mama_out[i] <= data_max + tolerance,
                        "MAMA at index {} ({}) outside bounds [{}, {}]",
                        i,
                        mama_out[i],
                        data_min - tolerance,
                        data_max + tolerance
                    );
                    prop_assert!(
                        fama_out[i] >= data_min - tolerance && fama_out[i] <= data_max + tolerance,
                        "FAMA at index {} ({}) outside bounds [{}, {}]",
                        i,
                        fama_out[i],
                        data_min - tolerance,
                        data_max + tolerance
                    );
                }

                if data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-9) {
                    let constant_val = data[0];

                    for i in WARMUP..data.len() {
                        prop_assert!(
                            (mama_out[i] - constant_val).abs() < 1e-6,
                            "MAMA should converge to constant value {} at index {}, got {}",
                            constant_val,
                            i,
                            mama_out[i]
                        );
                        prop_assert!(
                            (fama_out[i] - constant_val).abs() < 1e-6,
                            "FAMA should converge to constant value {} at index {}, got {}",
                            constant_val,
                            i,
                            fama_out[i]
                        );
                    }
                }

                if data.len() > 30 {
                    let mama_variance = variance(&mama_out[WARMUP..]);
                    let fama_variance = variance(&fama_out[WARMUP..]);

                    prop_assert!(
                        mama_variance >= 0.0 && mama_variance.is_finite(),
                        "MAMA variance should be finite and non-negative: {}",
                        mama_variance
                    );
                    prop_assert!(
                        fama_variance >= 0.0 && fama_variance.is_finite(),
                        "FAMA variance should be finite and non-negative: {}",
                        fama_variance
                    );

                    let data_variance = variance(&data);
                    if data_variance > 1e-6 {
                        prop_assert!(
                            mama_variance < data_variance * 100.0,
                            "MAMA variance ({}) too large relative to data variance ({})",
                            mama_variance,
                            data_variance
                        );
                        prop_assert!(
                            fama_variance < data_variance * 100.0,
                            "FAMA variance ({}) too large relative to data variance ({})",
                            fama_variance,
                            data_variance
                        );
                    }
                }

                for i in WARMUP..data.len() {
                    prop_assert!(
                        mama_out[i].is_finite(),
                        "MAMA kernel {:?} produced non-finite value at idx {}: {}",
                        kernel,
                        i,
                        mama_out[i]
                    );
                    prop_assert!(
                        fama_out[i].is_finite(),
                        "FAMA kernel {:?} produced non-finite value at idx {}: {}",
                        kernel,
                        i,
                        fama_out[i]
                    );
                }

                if data.len() > 50 && fast_limit > slow * 2.0 && variance(&data) > 1e-6 {
                    let alt_params = MamaParams {
                        fast_limit: Some(fast_limit * 0.5),
                        slow_limit: Some(slow),
                    };
                    let alt_input = MamaInput::from_slice(&data, alt_params);
                    if let Ok(alt_result) = mama_with_kernel(&alt_input, kernel) {
                        let mama_var = variance(&mama_out[20..]);
                        let alt_var = variance(&alt_result.mama_values[20..]);

                        if mama_var > 1e-6 && alt_var > 1e-6 {
                            prop_assert!(
                                (mama_var - alt_var).abs() > 1e-12,
                                "MAMA should be sensitive to fast_limit parameter"
                            );
                        }
                    }
                }

                if (fast_limit - slow).abs() < 0.01 && data.len() > 20 {
                    for i in WARMUP..data.len() {
                        prop_assert!(
                            mama_out[i].is_finite() && fama_out[i].is_finite(),
                            "MAMA/FAMA should remain finite even with close limits at idx {}",
                            i
                        );

                        prop_assert!(
                            mama_out[i].abs() < data_max.abs() * 100.0 + 1000.0,
                            "MAMA should not diverge with close limits"
                        );
                        prop_assert!(
                            fama_out[i].abs() < data_max.abs() * 100.0 + 1000.0,
                            "FAMA should not diverge with close limits"
                        );
                    }
                }

                let is_monotonic_inc = data.windows(2).all(|w| w[1] >= w[0] - 1e-9);
                let is_monotonic_dec = data.windows(2).all(|w| w[1] <= w[0] + 1e-9);

                if (is_monotonic_inc || is_monotonic_dec) && data.len() > 20 {
                    // Compare against the value 10 bars back. Start only once `i - 10`
                    // is itself past the NaN warmup; otherwise the comparison is
                    // against NaN and always fails.
                    for i in (WARMUP + 10)..data.len() {
                        if is_monotonic_inc {
                            prop_assert!(
                                mama_out[i] >= mama_out[i - 10] - tolerance * 0.1,
                                "MAMA should follow increasing trend at idx {}",
                                i
                            );
                        }
                        if is_monotonic_dec {
                            prop_assert!(
                                mama_out[i] <= mama_out[i - 10] + tolerance * 0.1,
                                "MAMA should follow decreasing trend at idx {}",
                                i
                            );
                        }
                    }
                }

                // Reference outputs are computed so kernels can be compared; keep them used.
                let _ = (ref_mama, ref_fama);

                Ok(())
            })
            .unwrap();

        Ok(())
    }

    /// Population variance; `0.0` for an empty slice.
    fn variance(data: &[f64]) -> f64 {
        if data.is_empty() {
            return 0.0;
        }
        let mean = data.iter().sum::<f64>() / data.len() as f64;
        data.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / data.len() as f64
    }

    generate_all_mama_tests!(
        check_mama_partial_params,
        check_mama_accuracy,
        check_mama_default_candles,
        check_mama_with_insufficient_data,
        check_mama_very_small_dataset,
        check_mama_reinput,
        check_mama_nan_handling,
        check_mama_no_poison,
        check_mama_property
    );

    /// The default batch sweep must contain a row for `MamaParams::default()`.
    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);
        let c = read_candles_from_csv(CANDLE_FILE)?;
        let output = MamaBatchBuilder::new()
            .kernel(kernel)
            .apply_candles(&c, "close")?;
        let def = MamaParams::default();
        let mama_row = output.mama_for(&def).expect("default row missing");
        assert_eq!(mama_row.len(), c.close.len());
        Ok(())
    }

    /// Emits one `#[test]` per batch kernel.
    ///
    /// Errors returned by the check are unwrapped so they fail the test instead
    /// of being silently discarded.
    macro_rules! gen_batch_tests {
        ($fn_name:ident) => {
            paste! {
                #[test]
                fn [<$fn_name _scalar>]() {
                    $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch).unwrap();
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$fn_name _avx2>]() {
                    $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch).unwrap();
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$fn_name _avx512>]() {
                    $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch).unwrap();
                }
                #[test]
                fn [<$fn_name _auto_detect>]() {
                    $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto).unwrap();
                }
            }
        };
    }

    /// In debug builds, no batch output slot may still hold an allocator poison pattern.
    #[cfg(debug_assertions)]
    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);

        let c = read_candles_from_csv(CANDLE_FILE)?;

        let test_configs = vec![
            ((0.2, 0.4, 0.1), (0.02, 0.04, 0.01)),
            ((0.3, 0.7, 0.2), (0.03, 0.07, 0.02)),
            ((0.4, 0.9, 0.1), (0.01, 0.09, 0.02)),
            ((0.5, 0.8, 0.15), (0.01, 0.03, 0.01)),
            ((0.2, 0.6, 0.05), (0.02, 0.08, 0.01)),
        ];

        for (fast_range, slow_range) in test_configs {
            let output = MamaBatchBuilder::new()
                .kernel(kernel)
                .fast_limit_range(fast_range.0, fast_range.1, fast_range.2)
                .slow_limit_range(slow_range.0, slow_range.1, slow_range.2)
                .apply_candles(&c, "close")?;

            for (series, values) in [
                ("mama_values", &output.mama_values),
                ("fama_values", &output.fama_values),
            ] {
                for (idx, &val) in values.iter().enumerate() {
                    if val.is_nan() {
                        continue;
                    }
                    let bits = val.to_bits();
                    if let Some(source) = poison_source(bits) {
                        let row = idx / output.cols;
                        let col = idx % output.cols;
                        let params = &output.combos[row];
                        panic!(
                            "[{}] Found {} poison value {} (0x{:016X}) at row {} col {} in {} (params: fast_limit={:?}, slow_limit={:?})",
                            test, source, val, bits, row, col, series, params.fast_limit, params.slow_limit
                        );
                    }
                }
            }
        }

        Ok(())
    }

    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    /// The zero-allocation `*_into` path must match the allocating `mama` API
    /// bit-for-bit, treating NaN == NaN.
    #[test]
    fn test_mama_into_matches_api() -> Result<(), Box<dyn Error>> {
        let n = 256usize;
        let data: Vec<f64> = (0..n)
            .map(|i| {
                let t = i as f64;
                (t * 0.013).sin() * 10.0 + (t * 0.01)
            })
            .collect();

        let input = MamaInput::from_slice(&data, MamaParams::default());

        let baseline = mama(&input)?;

        let mut out_mama = vec![0.0; n];
        let mut out_fama = vec![0.0; n];
        #[allow(unused_variables)]
        {
            #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
            {
                super::mama_into(&input, &mut out_mama, &mut out_fama)?;
            }
            #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
            {
                super::mama_into_slice(&mut out_mama, &mut out_fama, &input, Kernel::Auto)?;
            }
        }

        fn eq_or_both_nan(a: f64, b: f64) -> bool {
            (a.is_nan() && b.is_nan()) || (a == b)
        }

        assert_eq!(baseline.mama_values.len(), out_mama.len());
        assert_eq!(baseline.fama_values.len(), out_fama.len());
        for i in 0..n {
            assert!(
                eq_or_both_nan(baseline.mama_values[i], out_mama[i]),
                "mama mismatch at {}: left={} right={}",
                i,
                baseline.mama_values[i],
                out_mama[i]
            );
            assert!(
                eq_or_both_nan(baseline.fama_values[i], out_fama[i]),
                "fama mismatch at {}: left={} right={}",
                i,
                baseline.fama_values[i],
                out_fama[i]
            );
        }
        Ok(())
    }

    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
}
2657
/// Python (PyO3) bindings for the MAMA / FAMA indicator.
///
/// Exposes `mama`, `mama_batch` and the `MamaStream` class; with the `cuda`
/// feature it also exposes `mama_cuda_batch_dev` and
/// `mama_cuda_many_series_one_param_dev`, which return device-resident
/// `DeviceArrayF32Py` handles.
#[cfg(feature = "python")]
mod python_bindings {
    use super::*;
    #[cfg(feature = "cuda")]
    use crate::cuda::cuda_available;
    #[cfg(feature = "cuda")]
    use crate::cuda::moving_averages::{CudaMama, DeviceMamaPair};
    use crate::utilities::kernel_validation::validate_kernel;
    #[cfg(feature = "cuda")]
    use cust::context::Context;
    #[cfg(feature = "cuda")]
    use cust::memory::DeviceBuffer;
    #[cfg(feature = "cuda")]
    use numpy::PyReadonlyArray2;
    use numpy::{IntoPyArray, PyArray1, PyArray2, PyArrayMethods, PyReadonlyArray1};
    use pyo3::exceptions::PyValueError;
    use pyo3::prelude::*;
    use pyo3::types::PyDictMethods;
    #[cfg(feature = "cuda")]
    use std::os::raw::c_void;
    #[cfg(feature = "cuda")]
    use std::sync::Arc;

    use pyo3::types::PyDict;
    use pyo3::{pyclass, pymethods};
    use std::collections::HashMap;

    /// Compute MAMA and FAMA for a 1-D float64 array.
    ///
    /// `kernel` is an optional kernel name checked by `validate_kernel`
    /// (called with `false`, presumably selecting the non-batch kernel set —
    /// confirm in `kernel_validation`). Returns `(mama, fama)`, each the same
    /// length as `data`. Computation errors are raised as `ValueError`.
    #[pyfunction]
    #[pyo3(name = "mama")]
    #[pyo3(signature = (data, fast_limit, slow_limit, kernel=None))]
    pub fn mama_py<'py>(
        py: Python<'py>,
        data: PyReadonlyArray1<'py, f64>,
        fast_limit: f64,
        slow_limit: f64,
        kernel: Option<&str>,
    ) -> PyResult<(Bound<'py, PyArray1<f64>>, Bound<'py, PyArray1<f64>>)> {
        // Borrow the NumPy data without copying; `as_slice` errors if the array
        // is not contiguous.
        let slice_in = data.as_slice()?;
        let params = MamaParams {
            fast_limit: Some(fast_limit),
            slow_limit: Some(slow_limit),
        };
        let input = MamaInput::from_slice(slice_in, params);
        let kern = validate_kernel(kernel, false)?;

        let len = slice_in.len();

        // NOTE(review): `PyArray1::new` does not initialize its storage; this
        // relies on `mama_into_slice` writing every element of both outputs.
        let out_m = unsafe { PyArray1::<f64>::new(py, [len], false) };
        let out_f = unsafe { PyArray1::<f64>::new(py, [len], false) };
        let sm = unsafe { out_m.as_slice_mut()? };
        let sf = unsafe { out_f.as_slice_mut()? };

        // Release the GIL while the indicator writes straight into the NumPy buffers.
        py.allow_threads(|| mama_into_slice(sm, sf, &input, kern))
            .map_err(|e| PyValueError::new_err(e.to_string()))?;

        Ok((out_m, out_f))
    }

    /// Run MAMA over every `(fast_limit, slow_limit)` combination produced by
    /// expanding the `(start, end, step)` ranges.
    ///
    /// Returns a dict with:
    /// - `"mama"`, `"fama"`: arrays reshaped to `(rows, cols)`, one row per
    ///   parameter combination and one column per input sample;
    /// - `"fast_limits"`, `"slow_limits"`: the parameters used for each row.
    #[pyfunction]
    #[pyo3(name = "mama_batch")]
    #[pyo3(signature = (data, fast_limit_range, slow_limit_range, kernel=None))]
    pub fn mama_batch_py<'py>(
        py: Python<'py>,
        data: PyReadonlyArray1<'py, f64>,
        fast_limit_range: (f64, f64, f64),
        slow_limit_range: (f64, f64, f64),
        kernel: Option<&str>,
    ) -> PyResult<Bound<'py, PyDict>> {
        let slice_in = data.as_slice()?;
        let sweep = MamaBatchRange {
            fast_limit: fast_limit_range,
            slow_limit: slow_limit_range,
        };

        let combos = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let rows = combos.len();
        let cols = slice_in.len();
        // Guard the flat output size against usize overflow.
        let total = rows
            .checked_mul(cols)
            .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;

        // NOTE(review): uninitialized storage; relies on `mama_batch_inner_into`
        // filling every element of both buffers — confirm.
        let mama_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
        let fama_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
        let mama_slice = unsafe { mama_arr.as_slice_mut()? };
        let fama_slice = unsafe { fama_arr.as_slice_mut()? };

        let kern = validate_kernel(kernel, true)?;

        // The combos returned by the batch run shadow the grid expanded above
        // and are what gets reported back to Python.
        let combos = py
            .allow_threads(|| -> Result<Vec<MamaParams>, MamaError> {
                // Translate the requested batch kernel into a per-row kernel.
                // NOTE(review): `Auto` maps to `Scalar` here, so this binding
                // never auto-selects SIMD — confirm this is intentional.
                let simd = match kern {
                    Kernel::Auto | Kernel::ScalarBatch => Kernel::Scalar,
                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    Kernel::Avx512Batch => Kernel::Avx512,
                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    Kernel::Avx2Batch => Kernel::Avx2,

                    _ => Kernel::Scalar,
                };

                // NOTE(review): the `true` flag presumably enables parallel row
                // evaluation — confirm against `mama_batch_inner_into`.
                mama_batch_inner_into(slice_in, &sweep, simd, true, mama_slice, fama_slice)
            })
            .map_err(|e| PyValueError::new_err(e.to_string()))?;

        let dict = PyDict::new(py);
        // Return the flat buffers reshaped to (rows, cols).
        dict.set_item("mama", mama_arr.reshape((rows, cols))?)?;
        dict.set_item("fama", fama_arr.reshape((rows, cols))?)?;
        // The `unwrap_or` fallbacks equal `MamaParams::default()` (0.5 / 0.05).
        dict.set_item(
            "fast_limits",
            combos
                .iter()
                .map(|p| p.fast_limit.unwrap_or(0.5))
                .collect::<Vec<_>>()
                .into_pyarray(py),
        )?;
        dict.set_item(
            "slow_limits",
            combos
                .iter()
                .map(|p| p.slow_limit.unwrap_or(0.05))
                .collect::<Vec<_>>()
                .into_pyarray(py),
        )?;

        Ok(dict)
    }

    /// Batch parameter sweep on a CUDA device for float32 input.
    ///
    /// Returns `(mama, fama)` as `DeviceArrayF32Py` handles whose buffers stay
    /// on the GPU. Raises `ValueError` if CUDA is unavailable or the device run
    /// fails.
    #[cfg(feature = "cuda")]
    #[pyfunction(name = "mama_cuda_batch_dev")]
    #[pyo3(signature = (data_f32, fast_limit_range, slow_limit_range, device_id=0))]
    pub fn mama_cuda_batch_dev_py(
        py: Python<'_>,
        data_f32: PyReadonlyArray1<'_, f32>,
        fast_limit_range: (f64, f64, f64),
        slow_limit_range: (f64, f64, f64),
        device_id: usize,
    ) -> PyResult<(DeviceArrayF32Py, DeviceArrayF32Py)> {
        if !cuda_available() {
            return Err(PyValueError::new_err("CUDA not available"));
        }

        let slice_in = data_f32.as_slice()?;
        let sweep = MamaBatchRange {
            fast_limit: fast_limit_range,
            slow_limit: slow_limit_range,
        };

        // Device setup and the launch run with the GIL released.
        let (pair, ctx, dev_id) = py.allow_threads(|| {
            let cuda =
                CudaMama::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
            let ctx = cuda.context_arc();
            let dev_id = cuda.device_id();
            let pair = cuda
                .mama_batch_dev(slice_in, &sweep)
                .map_err(|e| PyValueError::new_err(e.to_string()))?;
            Ok::<_, PyErr>((pair, ctx, dev_id))
        })?;

        // Both handles hold a clone of the same context `Arc`, so the context
        // stays alive while either buffer exists.
        let DeviceMamaPair { mama, fama } = pair;
        Ok((
            DeviceArrayF32Py {
                buf: Some(mama.buf),
                rows: mama.rows,
                cols: mama.cols,
                _ctx: ctx.clone(),
                device_id: dev_id,
            },
            DeviceArrayF32Py {
                buf: Some(fama.buf),
                rows: fama.rows,
                cols: fama.cols,
                _ctx: ctx,
                device_id: dev_id,
            },
        ))
    }

    /// MAMA for many series with one parameter pair on a CUDA device.
    ///
    /// `data_tm_f32` is a 2-D float32 array in time-major layout.
    /// NOTE(review): the device routine is called with `(cols, rows)`, which
    /// assumes `shape = (time_steps, num_series)` — confirm against
    /// `CudaMama::mama_many_series_one_param_time_major_dev`. The limits are
    /// narrowed from f64 to f32 before the launch.
    #[cfg(feature = "cuda")]
    #[pyfunction(name = "mama_cuda_many_series_one_param_dev")]
    #[pyo3(signature = (data_tm_f32, fast_limit, slow_limit, device_id=0))]
    pub fn mama_cuda_many_series_one_param_dev_py(
        py: Python<'_>,
        data_tm_f32: PyReadonlyArray2<'_, f32>,
        fast_limit: f64,
        slow_limit: f64,
        device_id: usize,
    ) -> PyResult<(DeviceArrayF32Py, DeviceArrayF32Py)> {
        use numpy::PyUntypedArrayMethods;

        if !cuda_available() {
            return Err(PyValueError::new_err("CUDA not available"));
        }

        let shape = data_tm_f32.shape();
        if shape.len() != 2 {
            return Err(PyValueError::new_err("expected 2D array"));
        }
        let rows = shape[0];
        let cols = shape[1];
        let flat = data_tm_f32.as_slice()?;

        let fast = fast_limit as f32;
        let slow = slow_limit as f32;

        let (pair, ctx, dev_id) = py.allow_threads(|| {
            let cuda =
                CudaMama::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
            let ctx = cuda.context_arc();
            let dev_id = cuda.device_id();
            let pair = cuda
                .mama_many_series_one_param_time_major_dev(flat, cols, rows, fast, slow)
                .map_err(|e| PyValueError::new_err(e.to_string()))?;
            Ok::<_, PyErr>((pair, ctx, dev_id))
        })?;

        // Both handles share the same context `Arc`.
        let DeviceMamaPair { mama, fama } = pair;
        Ok((
            DeviceArrayF32Py {
                buf: Some(mama.buf),
                rows: mama.rows,
                cols: mama.cols,
                _ctx: ctx.clone(),
                device_id: dev_id,
            },
            DeviceArrayF32Py {
                buf: Some(fama.buf),
                rows: fama.rows,
                cols: fama.cols,
                _ctx: ctx,
                device_id: dev_id,
            },
        ))
    }

    /// Incremental MAMA/FAMA calculator, exposed to Python as `MamaStream`.
    #[pyclass]
    #[pyo3(name = "MamaStream")]
    pub struct MamaStreamPy {
        inner: MamaStream,
    }

    #[pymethods]
    impl MamaStreamPy {
        /// Create a stream. Limits rejected by `MamaStream::try_new` raise
        /// `ValueError`.
        #[new]
        pub fn new(fast_limit: f64, slow_limit: f64) -> PyResult<Self> {
            let params = MamaParams {
                fast_limit: Some(fast_limit),
                slow_limit: Some(slow_limit),
            };
            let stream =
                MamaStream::try_new(params).map_err(|e| PyValueError::new_err(e.to_string()))?;
            Ok(Self { inner: stream })
        }

        /// Feed one value. Returns `(mama, fama)` or `None`, exactly as
        /// `MamaStream::update` does.
        pub fn update(&mut self, value: f64) -> Option<(f64, f64)> {
            self.inner.update(value)
        }
    }
}
2916
2917#[cfg(feature = "python")]
2918pub use python_bindings::{mama_batch_py, mama_py, MamaStreamPy};
2919#[cfg(all(feature = "python", feature = "cuda"))]
2920pub use python_bindings::{mama_cuda_batch_dev_py, mama_cuda_many_series_one_param_dev_py};
2921
2922#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2923use serde::{Deserialize, Serialize};
2924
2925#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2926#[derive(Serialize, Deserialize)]
2927pub struct MamaResult {
2928    pub values: Vec<f64>,
2929    pub rows: usize,
2930    pub cols: usize,
2931}
2932
2933#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2934#[wasm_bindgen(js_name = "mama")]
2935pub fn mama_js(data: &[f64], fast_limit: f64, slow_limit: f64) -> Result<JsValue, JsValue> {
2936    let params = MamaParams {
2937        fast_limit: Some(fast_limit),
2938        slow_limit: Some(slow_limit),
2939    };
2940    let input = MamaInput::from_slice(data, params);
2941
2942    let mut mama = vec![0.0; data.len()];
2943    let mut fama = vec![0.0; data.len()];
2944    mama_into_slice(&mut mama, &mut fama, &input, detect_best_kernel())
2945        .map_err(|e| JsValue::from_str(&e.to_string()))?;
2946
2947    let mut values = mama;
2948    values.extend_from_slice(&fama);
2949
2950    let out = MamaResult {
2951        values,
2952        rows: 2,
2953        cols: data.len(),
2954    };
2955    serde_wasm_bindgen::to_value(&out)
2956        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
2957}
2958
2959#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2960#[wasm_bindgen(js_name = "mama_into")]
2961pub fn mama_into(
2962    in_ptr: *const f64,
2963    out_m_ptr: *mut f64,
2964    out_f_ptr: *mut f64,
2965    len: usize,
2966    fast_limit: f64,
2967    slow_limit: f64,
2968) -> Result<(), JsValue> {
2969    if in_ptr.is_null() || out_m_ptr.is_null() || out_f_ptr.is_null() {
2970        return Err(JsValue::from_str("null pointer passed to mama_into"));
2971    }
2972    unsafe {
2973        let data = core::slice::from_raw_parts(in_ptr, len);
2974        let out_m = core::slice::from_raw_parts_mut(out_m_ptr, len);
2975        let out_f = core::slice::from_raw_parts_mut(out_f_ptr, len);
2976        let params = MamaParams {
2977            fast_limit: Some(fast_limit),
2978            slow_limit: Some(slow_limit),
2979        };
2980        let input = MamaInput::from_slice(data, params);
2981        mama_into_slice(out_m, out_f, &input, detect_best_kernel())
2982            .map_err(|e| JsValue::from_str(&e.to_string()))
2983    }
2984}
2985
2986#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2987#[derive(Serialize, Deserialize)]
2988pub struct MamaBatchJsOutput {
2989    pub mama: Vec<f64>,
2990    pub fama: Vec<f64>,
2991    pub combos: Vec<MamaParams>,
2992    pub rows: usize,
2993    pub cols: usize,
2994}
2995
2996#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2997#[wasm_bindgen(js_name = "mama_batch")]
2998pub fn mama_batch_js(
2999    data: &[f64],
3000    fast_start: f64,
3001    fast_end: f64,
3002    fast_step: f64,
3003    slow_start: f64,
3004    slow_end: f64,
3005    slow_step: f64,
3006) -> Result<JsValue, JsValue> {
3007    let sweep = MamaBatchRange {
3008        fast_limit: (fast_start, fast_end, fast_step),
3009        slow_limit: (slow_start, slow_end, slow_step),
3010    };
3011    let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
3012    let rows = combos.len();
3013    let cols = data.len();
3014    let total = rows
3015        .checked_mul(cols)
3016        .ok_or(JsValue::from_str("rows*cols overflow"))?;
3017
3018    let mut mama_values = vec![0.0; total];
3019    let mut fama_values = vec![0.0; total];
3020
3021    let kern = detect_best_kernel();
3022    mama_batch_inner_into(
3023        data,
3024        &sweep,
3025        kern,
3026        false,
3027        &mut mama_values,
3028        &mut fama_values,
3029    )
3030    .map_err(|e| JsValue::from_str(&e.to_string()))?;
3031
3032    let out = MamaBatchJsOutput {
3033        mama: mama_values,
3034        fama: fama_values,
3035        combos,
3036        rows,
3037        cols,
3038    };
3039    serde_wasm_bindgen::to_value(&out)
3040        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
3041}
3042
3043#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
3044#[wasm_bindgen]
3045pub fn mama_batch_metadata_js(
3046    fast_limit_start: f64,
3047    fast_limit_end: f64,
3048    fast_limit_step: f64,
3049    slow_limit_start: f64,
3050    slow_limit_end: f64,
3051    slow_limit_step: f64,
3052) -> Vec<f64> {
3053    let range = MamaBatchRange {
3054        fast_limit: (fast_limit_start, fast_limit_end, fast_limit_step),
3055        slow_limit: (slow_limit_start, slow_limit_end, slow_limit_step),
3056    };
3057
3058    let combos = expand_grid(&range).unwrap_or_else(|_| Vec::new());
3059    let mut metadata = Vec::with_capacity(combos.len() * 2);
3060
3061    for combo in combos {
3062        metadata.push(combo.fast_limit.unwrap_or(0.5));
3063        metadata.push(combo.slow_limit.unwrap_or(0.05));
3064    }
3065
3066    metadata
3067}
3068
3069#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
3070#[wasm_bindgen]
3071pub fn mama_batch_rows_cols_js(
3072    fast_limit_start: f64,
3073    fast_limit_end: f64,
3074    fast_limit_step: f64,
3075    slow_limit_start: f64,
3076    slow_limit_end: f64,
3077    slow_limit_step: f64,
3078    data_len: usize,
3079) -> Vec<usize> {
3080    let range = MamaBatchRange {
3081        fast_limit: (fast_limit_start, fast_limit_end, fast_limit_step),
3082        slow_limit: (slow_limit_start, slow_limit_end, slow_limit_step),
3083    };
3084
3085    let combos = expand_grid(&range).unwrap_or_else(|_| Vec::new());
3086    vec![combos.len(), data_len]
3087}
3088
3089#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
3090#[wasm_bindgen]
3091pub fn mama_alloc(len: usize) -> *mut f64 {
3092    let mut vec = Vec::<f64>::with_capacity(len);
3093    let ptr = vec.as_mut_ptr();
3094    std::mem::forget(vec);
3095    ptr
3096}
3097
3098#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
3099#[wasm_bindgen]
3100pub fn mama_free(ptr: *mut f64, len: usize) {
3101    if !ptr.is_null() {
3102        unsafe {
3103            let _ = Vec::from_raw_parts(ptr, len, len);
3104        }
3105    }
3106}
3107
3108#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
3109#[wasm_bindgen]
3110pub fn mama_batch_into(
3111    in_ptr: *const f64,
3112    out_mama_ptr: *mut f64,
3113    out_fama_ptr: *mut f64,
3114    len: usize,
3115    fast_limit_start: f64,
3116    fast_limit_end: f64,
3117    fast_limit_step: f64,
3118    slow_limit_start: f64,
3119    slow_limit_end: f64,
3120    slow_limit_step: f64,
3121) -> Result<usize, JsValue> {
3122    if in_ptr.is_null() || out_mama_ptr.is_null() || out_fama_ptr.is_null() {
3123        return Err(JsValue::from_str("null pointer passed to mama_batch_into"));
3124    }
3125
3126    unsafe {
3127        let data = std::slice::from_raw_parts(in_ptr, len);
3128
3129        let range = MamaBatchRange {
3130            fast_limit: (fast_limit_start, fast_limit_end, fast_limit_step),
3131            slow_limit: (slow_limit_start, slow_limit_end, slow_limit_step),
3132        };
3133
3134        let batch_output = mama_batch_with_kernel(data, &range, Kernel::Auto)
3135            .map_err(|e| JsValue::from_str(&e.to_string()))?;
3136
3137        let rows = batch_output.combos.len();
3138        let cols = len;
3139        let total_elements = rows * cols;
3140
3141        let out_mama = std::slice::from_raw_parts_mut(out_mama_ptr, total_elements);
3142        out_mama.copy_from_slice(&batch_output.mama_values);
3143
3144        let out_fama = std::slice::from_raw_parts_mut(out_fama_ptr, total_elements);
3145        out_fama.copy_from_slice(&batch_output.fama_values);
3146
3147        Ok(rows)
3148    }
3149}
3150
3151#[cfg(all(feature = "python", feature = "cuda"))]
3152#[pyo3::pyclass(module = "ta_indicators.cuda", unsendable)]
3153pub struct DeviceArrayF32Py {
3154    pub(crate) buf: Option<cust::memory::DeviceBuffer<f32>>,
3155    pub(crate) rows: usize,
3156    pub(crate) cols: usize,
3157    pub(crate) _ctx: std::sync::Arc<cust::context::Context>,
3158    pub(crate) device_id: u32,
3159}
3160
3161#[cfg(all(feature = "python", feature = "cuda"))]
3162#[pyo3::pymethods]
3163impl DeviceArrayF32Py {
3164    #[getter]
3165    fn __cuda_array_interface__<'py>(
3166        &self,
3167        py: pyo3::Python<'py>,
3168    ) -> pyo3::PyResult<pyo3::prelude::Bound<'py, pyo3::types::PyDict>> {
3169        let d = pyo3::types::PyDict::new(py);
3170        pyo3::types::PyDictMethods::set_item(&d, "shape", (self.rows, self.cols))?;
3171        pyo3::types::PyDictMethods::set_item(&d, "typestr", "<f4")?;
3172        pyo3::types::PyDictMethods::set_item(
3173            &d,
3174            "strides",
3175            (
3176                self.cols * std::mem::size_of::<f32>(),
3177                std::mem::size_of::<f32>(),
3178            ),
3179        )?;
3180        let ptr = self
3181            .buf
3182            .as_ref()
3183            .ok_or_else(|| {
3184                pyo3::exceptions::PyValueError::new_err("buffer already exported via __dlpack__")
3185            })?
3186            .as_device_ptr()
3187            .as_raw() as usize;
3188        pyo3::types::PyDictMethods::set_item(&d, "data", (ptr, false))?;
3189        pyo3::types::PyDictMethods::set_item(&d, "version", 3)?;
3190        Ok(d)
3191    }
3192
3193    fn __dlpack_device__(&self) -> (i32, i32) {
3194        (2, self.device_id as i32)
3195    }
3196
3197    #[cfg(feature = "mama_legacy_dlpack")]
3198    #[pyo3(signature=(stream=None, max_version=None, dl_device=None, copy=None))]
3199    fn __dlpack_legacy__<'py>(
3200        &mut self,
3201        py: pyo3::Python<'py>,
3202        stream: Option<&pyo3::types::PyAny>,
3203        max_version: Option<&pyo3::types::PyAny>,
3204        dl_device: Option<&pyo3::types::PyAny>,
3205        copy: Option<&pyo3::types::PyAny>,
3206    ) -> pyo3::PyResult<pyo3::PyObject> {
3207        use std::os::raw::c_char;
3208
3209        let buf = self.buf.take().ok_or_else(|| {
3210            pyo3::exceptions::PyValueError::new_err("__dlpack__ may only be called once")
3211        })?;
3212
3213        #[repr(C)]
3214        struct DLDevice {
3215            device_type: i32,
3216            device_id: i32,
3217        }
3218        #[repr(C)]
3219        struct DLDataType {
3220            code: u8,
3221            bits: u8,
3222            lanes: u16,
3223        }
3224        #[repr(C)]
3225        struct DLTensor {
3226            data: *mut std::ffi::c_void,
3227            device: DLDevice,
3228            ndim: i32,
3229            dtype: DLDataType,
3230            shape: *mut i64,
3231            strides: *mut i64,
3232            byte_offset: u64,
3233        }
3234        #[repr(C)]
3235        struct DLManagedTensor {
3236            dl_tensor: DLTensor,
3237            manager_ctx: *mut std::ffi::c_void,
3238            deleter: Option<extern "C" fn(*mut DLManagedTensor)>,
3239        }
3240        #[repr(C)]
3241        struct DLVersion {
3242            major: i32,
3243            minor: i32,
3244        }
3245        #[repr(C)]
3246        struct DLManagedTensorVersioned {
3247            dl_managed_tensor: DLManagedTensor,
3248            version: DLVersion,
3249        }
3250
3251        struct HolderLegacy {
3252            managed: DLManagedTensor,
3253            shape: [i64; 2],
3254            strides: [i64; 2],
3255            buf: cust::memory::DeviceBuffer<f32>,
3256            retained: cust::sys::CUcontext,
3257            device_id: i32,
3258        }
3259        struct HolderV1 {
3260            managed: DLManagedTensorVersioned,
3261            shape: [i64; 2],
3262            strides: [i64; 2],
3263            buf: cust::memory::DeviceBuffer<f32>,
3264            retained: cust::sys::CUcontext,
3265            device_id: i32,
3266        }
3267
3268        unsafe extern "C" fn deleter_legacy(p: *mut DLManagedTensor) {
3269            if p.is_null() {
3270                return;
3271            }
3272            let holder = (*p).manager_ctx as *mut HolderLegacy;
3273            if !holder.is_null() {
3274                let ctx = (*holder).retained;
3275                if !ctx.is_null() {
3276                    let _ = cust::sys::cuCtxPushCurrent(ctx);
3277                    let dev = (*holder).device_id;
3278                    drop(Box::from_raw(holder));
3279                    let mut _out: cust::sys::CUcontext = std::ptr::null_mut();
3280                    let _ = cust::sys::cuCtxPopCurrent(&mut _out);
3281                    let _ = cust::sys::cuDevicePrimaryCtxRelease(dev);
3282                }
3283            }
3284            drop(Box::from_raw(p));
3285        }
3286        unsafe extern "C" fn deleter_v1(p: *mut DLManagedTensorVersioned) {
3287            if p.is_null() {
3288                return;
3289            }
3290            let holder = (*p).dl_managed_tensor.manager_ctx as *mut HolderV1;
3291            if !holder.is_null() {
3292                let ctx = (*holder).retained;
3293                if !ctx.is_null() {
3294                    let _ = cust::sys::cuCtxPushCurrent(ctx);
3295                    let dev = (*holder).device_id;
3296                    drop(Box::from_raw(holder));
3297                    let mut _out: cust::sys::CUcontext = std::ptr::null_mut();
3298                    let _ = cust::sys::cuCtxPopCurrent(&mut _out);
3299                    let _ = cust::sys::cuDevicePrimaryCtxRelease(dev);
3300                }
3301            }
3302            drop(Box::from_raw(p));
3303        }
3304
3305        unsafe extern "C" fn cap_destructor_legacy(capsule: *mut pyo3::ffi::PyObject) {
3306            let name = b"dltensor\0";
3307            let ptr = pyo3::ffi::PyCapsule_GetPointer(capsule, name.as_ptr() as *const c_char)
3308                as *mut DLManagedTensor;
3309            if !ptr.is_null() {
3310                if let Some(del) = (*ptr).deleter {
3311                    del(ptr);
3312                }
3313                let used = b"used_dltensor\0";
3314                pyo3::ffi::PyCapsule_SetName(capsule, used.as_ptr() as *const _);
3315            }
3316        }
3317        unsafe extern "C" fn cap_destructor_v1(capsule: *mut pyo3::ffi::PyObject) {
3318            let name = b"dltensor_versioned\0";
3319            let ptr = pyo3::ffi::PyCapsule_GetPointer(capsule, name.as_ptr() as *const c_char)
3320                as *mut DLManagedTensorVersioned;
3321            if !ptr.is_null() {
3322                let mt = &mut (*ptr).dl_managed_tensor;
3323                if let Some(del) = mt.deleter {
3324                    del(mt);
3325                }
3326                let used = b"used_dltensor_versioned\0";
3327                pyo3::ffi::PyCapsule_SetName(capsule, used.as_ptr() as *const _);
3328            }
3329        }
3330
3331        let alloc_dev = self.device_id as i32;
3332        let mut retained: cust::sys::CUcontext = std::ptr::null_mut();
3333        unsafe {
3334            let _ = cust::sys::cuDevicePrimaryCtxRetain(&mut retained, alloc_dev);
3335        }
3336
3337        let rows = self.rows as i64;
3338        let cols = self.cols as i64;
3339        let data_ptr: *mut std::ffi::c_void = if self.rows == 0 || self.cols == 0 {
3340            std::ptr::null_mut()
3341        } else {
3342            buf.as_device_ptr().as_raw() as *mut std::ffi::c_void
3343        };
3344
3345        let want_v1 = if let Some(v) = max_version {
3346            v.getattr("__iter")
3347                .ok()
3348                .and_then(|_| v.extract::<(i32, i32)>().ok())
3349                .map(|(maj, _)| maj >= 1)
3350                .unwrap_or(false)
3351        } else {
3352            false
3353        };
3354
3355        if want_v1 {
3356            let mut holder = Box::new(HolderV1 {
3357                managed: DLManagedTensorVersioned {
3358                    dl_managed_tensor: DLManagedTensor {
3359                        dl_tensor: DLTensor {
3360                            data: data_ptr,
3361                            device: DLDevice {
3362                                device_type: 2,
3363                                device_id: alloc_dev,
3364                            },
3365                            ndim: 2,
3366                            dtype: DLDataType {
3367                                code: 2,
3368                                bits: 32,
3369                                lanes: 1,
3370                            },
3371                            shape: std::ptr::null_mut(),
3372                            strides: std::ptr::null_mut(),
3373                            byte_offset: 0,
3374                        },
3375                        manager_ctx: std::ptr::null_mut(),
3376                        deleter: Some(|mt| {
3377                            if !mt.is_null() {
3378                                let outer = (mt as *mut u8)
3379                                    .offset(-(std::mem::size_of::<DLVersion>() as isize))
3380                                    as *mut DLManagedTensorVersioned;
3381                                deleter_v1(outer);
3382                            }
3383                        }),
3384                    },
3385                    version: DLVersion { major: 1, minor: 0 },
3386                },
3387                shape: [rows, cols],
3388                strides: [cols, 1],
3389                buf,
3390                retained,
3391                device_id: alloc_dev,
3392            });
3393            holder.managed.dl_managed_tensor.dl_tensor.shape = holder.shape.as_mut_ptr();
3394            holder.managed.dl_managed_tensor.dl_tensor.strides = holder.strides.as_mut_ptr();
3395            holder.managed.dl_managed_tensor.manager_ctx =
3396                &mut *holder as *mut HolderV1 as *mut std::ffi::c_void;
3397            let mt_ptr: *mut DLManagedTensorVersioned = &mut holder.managed;
3398            let _leak = Box::into_raw(holder);
3399            let name = b"dltensor_versioned\0";
3400            let cap = unsafe {
3401                pyo3::ffi::PyCapsule_New(
3402                    mt_ptr as *mut std::ffi::c_void,
3403                    name.as_ptr() as *const c_char,
3404                    Some(cap_destructor_v1),
3405                )
3406            };
3407            if cap.is_null() {
3408                return Err(pyo3::exceptions::PyValueError::new_err(
3409                    "failed to create DLPack capsule",
3410                ));
3411            }
3412            Ok(unsafe { pyo3::PyObject::from_owned_ptr(py, cap) })
3413        } else {
3414            let mut holder = Box::new(HolderLegacy {
3415                managed: DLManagedTensor {
3416                    dl_tensor: DLTensor {
3417                        data: data_ptr,
3418                        device: DLDevice {
3419                            device_type: 2,
3420                            device_id: alloc_dev,
3421                        },
3422                        ndim: 2,
3423                        dtype: DLDataType {
3424                            code: 2,
3425                            bits: 32,
3426                            lanes: 1,
3427                        },
3428                        shape: std::ptr::null_mut(),
3429                        strides: std::ptr::null_mut(),
3430                        byte_offset: 0,
3431                    },
3432                    manager_ctx: std::ptr::null_mut(),
3433                    deleter: Some(deleter_legacy),
3434                },
3435                shape: [rows, cols],
3436                strides: [cols, 1],
3437                buf,
3438                retained,
3439                device_id: alloc_dev,
3440            });
3441            holder.managed.dl_tensor.shape = holder.shape.as_mut_ptr();
3442            holder.managed.dl_tensor.strides = holder.strides.as_mut_ptr();
3443            holder.managed.manager_ctx = &mut *holder as *mut HolderLegacy as *mut std::ffi::c_void;
3444            let mt_ptr: *mut DLManagedTensor = &mut holder.managed;
3445            let _leak = Box::into_raw(holder);
3446            let name = b"dltensor\0";
3447            let cap = unsafe {
3448                pyo3::ffi::PyCapsule_New(
3449                    mt_ptr as *mut std::ffi::c_void,
3450                    name.as_ptr() as *const c_char,
3451                    Some(cap_destructor_legacy),
3452                )
3453            };
3454            if cap.is_null() {
3455                return Err(pyo3::exceptions::PyValueError::new_err(
3456                    "failed to create DLPack capsule",
3457                ));
3458            }
3459            Ok(unsafe { pyo3::PyObject::from_owned_ptr(py, cap) })
3460        }
3461    }
3462
3463    #[pyo3(signature=(stream=None, max_version=None, dl_device=None, copy=None))]
3464    fn __dlpack__<'py>(
3465        &mut self,
3466        py: pyo3::Python<'py>,
3467        stream: Option<pyo3::PyObject>,
3468        max_version: Option<pyo3::PyObject>,
3469        dl_device: Option<pyo3::PyObject>,
3470        copy: Option<pyo3::PyObject>,
3471    ) -> pyo3::PyResult<pyo3::PyObject> {
3472        use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
3473
3474        let (kdl, alloc_dev) = self.__dlpack_device__();
3475        if let Some(dev_obj) = dl_device.as_ref() {
3476            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
3477                if dev_ty != kdl || dev_id != alloc_dev {
3478                    let wants_copy = copy
3479                        .as_ref()
3480                        .and_then(|c| c.extract::<bool>(py).ok())
3481                        .unwrap_or(false);
3482                    if wants_copy {
3483                        return Err(pyo3::exceptions::PyValueError::new_err(
3484                            "device copy not implemented for __dlpack__",
3485                        ));
3486                    } else {
3487                        return Err(pyo3::exceptions::PyValueError::new_err(
3488                            "dl_device mismatch for __dlpack__",
3489                        ));
3490                    }
3491                }
3492            }
3493        }
3494        let _ = stream;
3495
3496        let buf = self.buf.take().ok_or_else(|| {
3497            pyo3::exceptions::PyValueError::new_err("__dlpack__ may only be called once")
3498        })?;
3499
3500        let rows = self.rows;
3501        let cols = self.cols;
3502
3503        let max_version_bound = max_version.map(|obj| obj.into_bound(py));
3504
3505        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
3506    }
3507}