1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{
4 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
5 make_uninit_matrix,
6};
7use crate::utilities::math_functions::atan_fast;
8use aligned_vec::{AVec, CACHELINE_ALIGN};
9#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
10use core::arch::x86_64::*;
11#[cfg(not(target_arch = "wasm32"))]
12use rayon::prelude::*;
13use std::convert::AsRef;
14use std::error::Error;
15use std::f64::consts::PI;
16use std::mem::MaybeUninit;
17use thiserror::Error;
18#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
19use wasm_bindgen::prelude::*;
20
/// Source of price data for the MAMA indicator: either a candle series plus
/// a source-field selector (e.g. "close"), or a raw price slice.
#[derive(Debug, Clone)]
pub enum MamaData<'a> {
    /// Prices read from `candles`, with the field chosen by `source`
    /// (resolved via `source_type` in the `AsRef` impl below).
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// Prices supplied directly as a slice.
    Slice(&'a [f64]),
}
29
/// Result of a MAMA computation. Both vectors have the same length as the
/// input; the warm-up prefix (first 10 entries) is NaN.
#[derive(Debug, Clone)]
pub struct MamaOutput {
    /// MESA Adaptive Moving Average values.
    pub mama_values: Vec<f64>,
    /// Following Adaptive Moving Average values (slower companion line).
    pub fama_values: Vec<f64>,
}
35
/// Tunable parameters for MAMA. `None` falls back to the Ehlers defaults
/// (fast 0.5, slow 0.05) at computation time.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct MamaParams {
    /// Upper bound on the adaptive smoothing factor alpha (default 0.5).
    pub fast_limit: Option<f64>,
    /// Lower bound on the adaptive smoothing factor alpha (default 0.05).
    pub slow_limit: Option<f64>,
}
45
46impl Default for MamaParams {
47 fn default() -> Self {
48 Self {
49 fast_limit: Some(0.5),
50 slow_limit: Some(0.05),
51 }
52 }
53}
54
/// Bundles a data source with its parameters; the unit of work passed to
/// `mama` / `mama_with_kernel`.
#[derive(Debug, Clone)]
pub struct MamaInput<'a> {
    /// Where the prices come from (candles or a raw slice).
    pub data: MamaData<'a>,
    /// Fast/slow limits (optional; defaults applied on read).
    pub params: MamaParams,
}
60
61impl<'a> AsRef<[f64]> for MamaInput<'a> {
62 #[inline(always)]
63 fn as_ref(&self) -> &[f64] {
64 match &self.data {
65 MamaData::Slice(slice) => slice,
66 MamaData::Candles { candles, source } => source_type(candles, source),
67 }
68 }
69}
70
71impl<'a> MamaInput<'a> {
72 #[inline]
73 pub fn from_candles(c: &'a Candles, s: &'a str, p: MamaParams) -> Self {
74 Self {
75 data: MamaData::Candles {
76 candles: c,
77 source: s,
78 },
79 params: p,
80 }
81 }
82 #[inline]
83 pub fn from_slice(sl: &'a [f64], p: MamaParams) -> Self {
84 Self {
85 data: MamaData::Slice(sl),
86 params: p,
87 }
88 }
89 #[inline]
90 pub fn with_default_candles(c: &'a Candles) -> Self {
91 Self::from_candles(c, "close", MamaParams::default())
92 }
93 #[inline]
94 pub fn get_fast_limit(&self) -> f64 {
95 self.params.fast_limit.unwrap_or(0.5)
96 }
97 #[inline]
98 pub fn get_slow_limit(&self) -> f64 {
99 self.params.slow_limit.unwrap_or(0.05)
100 }
101}
102
/// Fluent builder for one-off MAMA computations and streaming state.
#[derive(Copy, Clone, Debug)]
pub struct MamaBuilder {
    // Unset limits are resolved to the 0.5 / 0.05 defaults when applied.
    fast_limit: Option<f64>,
    slow_limit: Option<f64>,
    // Compute kernel; `Kernel::Auto` lets the library choose.
    kernel: Kernel,
}
109
110impl Default for MamaBuilder {
111 fn default() -> Self {
112 Self {
113 fast_limit: None,
114 slow_limit: None,
115 kernel: Kernel::Auto,
116 }
117 }
118}
119
120impl MamaBuilder {
121 #[inline(always)]
122 pub fn new() -> Self {
123 Self::default()
124 }
125 #[inline(always)]
126 pub fn fast_limit(mut self, n: f64) -> Self {
127 self.fast_limit = Some(n);
128 self
129 }
130 #[inline(always)]
131 pub fn slow_limit(mut self, x: f64) -> Self {
132 self.slow_limit = Some(x);
133 self
134 }
135 #[inline(always)]
136 pub fn kernel(mut self, k: Kernel) -> Self {
137 self.kernel = k;
138 self
139 }
140 #[inline(always)]
141 pub fn apply(self, c: &Candles) -> Result<MamaOutput, MamaError> {
142 let p = MamaParams {
143 fast_limit: self.fast_limit,
144 slow_limit: self.slow_limit,
145 };
146 let i = MamaInput::from_candles(c, "close", p);
147 mama_with_kernel(&i, self.kernel)
148 }
149 #[inline(always)]
150 pub fn apply_slice(self, d: &[f64]) -> Result<MamaOutput, MamaError> {
151 let p = MamaParams {
152 fast_limit: self.fast_limit,
153 slow_limit: self.slow_limit,
154 };
155 let i = MamaInput::from_slice(d, p);
156 mama_with_kernel(&i, self.kernel)
157 }
158 #[inline(always)]
159 pub fn into_stream(self) -> Result<MamaStream, MamaError> {
160 let p = MamaParams {
161 fast_limit: self.fast_limit,
162 slow_limit: self.slow_limit,
163 };
164 MamaStream::try_new(p)
165 }
166}
167
/// Errors produced by the MAMA entry points.
#[derive(Debug, Error)]
pub enum MamaError {
    /// Input slice/series was empty.
    #[error("mama: empty input data")]
    EmptyInputData,
    /// Every input value was NaN.
    #[error("mama: all values are NaN")]
    AllValuesNaN,
    /// Fewer non-NaN values than the algorithm requires.
    #[error("mama: not enough valid data: needed {needed}, valid {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// Input shorter than the minimum length (10).
    #[error("mama: Not enough data: needed at least {needed}, found {found}")]
    NotEnoughData { needed: usize, found: usize },
    /// Caller-provided output buffer length differs from the input length.
    #[error("mama: output length mismatch: expected {expected}, got {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// Malformed parameter-range expansion request.
    #[error("mama: invalid range expansion start={start} end={end} step={step}")]
    InvalidRange { start: f64, end: f64, step: f64 },
    /// Kernel variant not usable on the batch code path.
    #[error("mama: invalid kernel for batch path: {0:?}")]
    InvalidKernelForBatch(Kernel),
    /// Fast limit was non-positive, NaN, or infinite.
    #[error("mama: Invalid fast limit: {fast_limit}")]
    InvalidFastLimit { fast_limit: f64 },
    /// Slow limit was non-positive, NaN, or infinite.
    #[error("mama: Invalid slow limit: {slow_limit}")]
    InvalidSlowLimit { slow_limit: f64 },
}
189
/// Computes MAMA/FAMA with automatic kernel selection.
///
/// Equivalent to `mama_with_kernel(input, Kernel::Auto)`.
#[inline]
pub fn mama(input: &MamaInput) -> Result<MamaOutput, MamaError> {
    mama_with_kernel(input, Kernel::Auto)
}
194
195#[inline(always)]
196fn mama_prepare<'a>(
197 input: &'a MamaInput,
198 kernel: Kernel,
199) -> Result<(&'a [f64], f64, f64, Kernel), MamaError> {
200 let data = input.as_ref();
201 let len = data.len();
202 if len == 0 {
203 return Err(MamaError::EmptyInputData);
204 }
205 if len < 10 {
206 return Err(MamaError::NotEnoughData {
207 needed: 10,
208 found: len,
209 });
210 }
211
212 let fast_limit = input.get_fast_limit();
213 let slow_limit = input.get_slow_limit();
214 if fast_limit <= 0.0 || fast_limit.is_nan() || fast_limit.is_infinite() {
215 return Err(MamaError::InvalidFastLimit { fast_limit });
216 }
217 if slow_limit <= 0.0 || slow_limit.is_nan() || slow_limit.is_infinite() {
218 return Err(MamaError::InvalidSlowLimit { slow_limit });
219 }
220
221 let chosen = match kernel {
222 Kernel::Auto => Kernel::Scalar,
223 k => k,
224 };
225
226 Ok((data, fast_limit, slow_limit, chosen))
227}
228
/// Computes MAMA/FAMA over `input`, allocating fresh output vectors.
///
/// Validation and kernel resolution are delegated to `mama_prepare`. The
/// first 10 entries of both outputs are NaN (filter warm-up); the kernels
/// write every index, so the prefix is re-applied after the kernel runs.
pub fn mama_with_kernel(input: &MamaInput, kernel: Kernel) -> Result<MamaOutput, MamaError> {
    let (data, fast_limit, slow_limit, chosen) = mama_prepare(input, kernel)?;
    let len = data.len();
    // Warm-up length; mama_prepare guarantees len >= 10, so slicing below is safe.
    const WARM: usize = 10;

    let mut mama_values = alloc_with_nan_prefix(len, WARM);
    let mut fama_values = alloc_with_nan_prefix(len, WARM);

    unsafe {
        // wasm32+simd128 builds route the scalar kernels through the
        // simd128 implementation and return early.
        #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
        {
            if matches!(chosen, Kernel::Scalar | Kernel::ScalarBatch) {
                mama_simd128_inplace(
                    data,
                    fast_limit,
                    slow_limit,
                    &mut mama_values,
                    &mut fama_values,
                );

                // Kernel fills every index — restore the NaN warm-up prefix.
                for v in &mut mama_values[..WARM] {
                    *v = f64::NAN;
                }
                for v in &mut fama_values[..WARM] {
                    *v = f64::NAN;
                }
                return Ok(MamaOutput {
                    mama_values,
                    fama_values,
                });
            }
        }

        match chosen {
            Kernel::Scalar | Kernel::ScalarBatch => {
                mama_scalar_inplace(
                    data,
                    fast_limit,
                    slow_limit,
                    &mut mama_values,
                    &mut fama_values,
                );
            }

            // NOTE(review): these AVX arms run the scalar kernel even when
            // AVX support is compiled in, unlike `mama_compute_into` which
            // dispatches to the AVX kernels — confirm this is intentional.
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx2 | Kernel::Avx2Batch => {
                mama_scalar_inplace(
                    data,
                    fast_limit,
                    slow_limit,
                    &mut mama_values,
                    &mut fama_values,
                );
            }
            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
            Kernel::Avx2 | Kernel::Avx2Batch => {
                mama_scalar_inplace(
                    data,
                    fast_limit,
                    slow_limit,
                    &mut mama_values,
                    &mut fama_values,
                );
            }

            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx512 | Kernel::Avx512Batch => {
                mama_scalar_inplace(
                    data,
                    fast_limit,
                    slow_limit,
                    &mut mama_values,
                    &mut fama_values,
                );
            }
            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
            Kernel::Avx512 | Kernel::Avx512Batch => {
                mama_scalar_inplace(
                    data,
                    fast_limit,
                    slow_limit,
                    &mut mama_values,
                    &mut fama_values,
                );
            }

            _ => unreachable!("unsupported kernel variant"),
        }
    }

    // Kernels write all indices; restore the NaN warm-up prefix.
    for v in &mut mama_values[..WARM] {
        *v = f64::NAN;
    }
    for v in &mut fama_values[..WARM] {
        *v = f64::NAN;
    }

    Ok(MamaOutput {
        mama_values,
        fama_values,
    })
}
331
332pub fn mama_compute_into(
333 input: &MamaInput,
334 kernel: Kernel,
335 out_mama: &mut [f64],
336 out_fama: &mut [f64],
337) -> Result<(), MamaError> {
338 let (data, fast_limit, slow_limit, chosen) = mama_prepare(input, kernel)?;
339
340 if out_mama.len() != data.len() || out_fama.len() != data.len() {
341 return Err(MamaError::OutputLengthMismatch {
342 expected: data.len(),
343 got: out_mama.len().min(out_fama.len()),
344 });
345 }
346
347 unsafe {
348 #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
349 {
350 if matches!(chosen, Kernel::Scalar | Kernel::ScalarBatch) {
351 mama_simd128_inplace(data, fast_limit, slow_limit, out_mama, out_fama);
352 return Ok(());
353 }
354 }
355
356 match chosen {
357 Kernel::Scalar | Kernel::ScalarBatch => {
358 mama_scalar_inplace(data, fast_limit, slow_limit, out_mama, out_fama);
359 }
360
361 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
362 Kernel::Avx2 | Kernel::Avx2Batch => {
363 mama_avx2_inplace(data, fast_limit, slow_limit, out_mama, out_fama);
364 }
365
366 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
367 Kernel::Avx512 | Kernel::Avx512Batch => {
368 mama_avx512_inplace(data, fast_limit, slow_limit, out_mama, out_fama);
369 }
370
371 _ => unreachable!("unsupported kernel variant"),
372 }
373 }
374
375 Ok(())
376}
377
378#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
379#[inline]
380pub fn mama_into(
381 input: &MamaInput,
382 out_mama: &mut [f64],
383 out_fama: &mut [f64],
384) -> Result<(), MamaError> {
385 let data = input.as_ref();
386 if out_mama.len() != data.len() || out_fama.len() != data.len() {
387 return Err(MamaError::OutputLengthMismatch {
388 expected: data.len(),
389 got: out_mama.len().min(out_fama.len()),
390 });
391 }
392
393 mama_compute_into(input, Kernel::Auto, out_mama, out_fama)?;
394
395 const WARM: usize = 10;
396 let warm = WARM.min(data.len());
397 for v in &mut out_mama[..warm] {
398 *v = f64::NAN;
399 }
400 for v in &mut out_fama[..warm] {
401 *v = f64::NAN;
402 }
403 Ok(())
404}
405
406#[inline]
407pub fn mama_into_slice(
408 dst_mama: &mut [f64],
409 dst_fama: &mut [f64],
410 input: &MamaInput,
411 kern: Kernel,
412) -> Result<(), MamaError> {
413 let (data, _fast, _slow, _chosen) = mama_prepare(input, kern)?;
414 if dst_mama.len() != data.len() || dst_fama.len() != data.len() {
415 return Err(MamaError::OutputLengthMismatch {
416 expected: data.len(),
417 got: dst_mama.len().min(dst_fama.len()),
418 });
419 }
420 mama_compute_into(input, kern, dst_mama, dst_fama)?;
421
422 const WARM: usize = 10;
423 let warm = WARM.min(data.len());
424 for v in &mut dst_mama[..warm] {
425 *v = f64::NAN;
426 }
427 for v in &mut dst_fama[..warm] {
428 *v = f64::NAN;
429 }
430 Ok(())
431}
432
/// Fallible wrapper around `mama_scalar_inplace`.
///
/// Always returns `Ok(())`; the `Result` exists so the scalar kernel matches
/// the signatures of the other kernel entry points. Output slices must have
/// the same length as `data` (enforced only via `debug_assert!` inside the
/// in-place kernel).
#[inline(always)]
pub fn mama_scalar(
    data: &[f64],
    fast_limit: f64,
    slow_limit: f64,
    out_mama: &mut [f64],
    out_fama: &mut [f64],
) -> Result<(), MamaError> {
    mama_scalar_inplace(data, fast_limit, slow_limit, out_mama, out_fama);
    Ok(())
}
444
/// Fallible wrapper around `mama_avx2_inplace`. Always returns `Ok(())`.
///
/// # Safety
/// Caller must ensure the CPU supports AVX2 and FMA; the wrapped kernel
/// executes `#[target_feature(enable = "avx2,fma")]` code.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn mama_avx2(
    data: &[f64],
    fast_limit: f64,
    slow_limit: f64,
    out_mama: &mut [f64],
    out_fama: &mut [f64],
) -> Result<(), MamaError> {
    mama_avx2_inplace(data, fast_limit, slow_limit, out_mama, out_fama);
    Ok(())
}
457
/// Four-tap Hilbert-transform FIR evaluated with AVX-512:
/// `0.0962*x0 + 0.5769*x2 - 0.5769*x4 - 0.0962*x6`, with the upper four
/// lanes zeroed so the full-register reduce-add yields the 4-tap sum.
///
/// # Safety
/// Caller must ensure the CPU supports AVX-512F/AVX-512DQ/FMA.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,avx512dq,fma")]
#[inline]
unsafe fn hilbert4_avx512(x0: f64, x2: f64, x4: f64, x6: f64) -> f64 {
    // _mm512_set_pd lists lanes high-to-low, so the taps occupy lanes 0..3.
    let v_x = _mm512_set_pd(0.0, 0.0, 0.0, 0.0, x6, x4, x2, x0);

    const H3: f64 = -0.096_2;
    const H2: f64 = -0.576_9;
    const H1: f64 = 0.576_9;
    const H0: f64 = 0.096_2;
    let v_h = _mm512_set_pd(0.0, 0.0, 0.0, 0.0, H3, H2, H1, H0);

    let v_mul = _mm512_mul_pd(v_x, v_h);
    _mm512_reduce_add_pd(v_mul)
}
473
/// AVX-512 MAMA/FAMA kernel, writing one output per input price.
///
/// Same recurrence as `mama_scalar_inplace`, with the Hilbert FIR evaluated
/// via `hilbert4_avx512`. Output slices must match `data` in length
/// (checked with `debug_assert!`); `data` must be non-empty (reads
/// `data[0]` for seeding). Every index is written, including warm-up.
///
/// # Safety
/// Caller must ensure the CPU supports AVX-512F/AVX-512DQ/FMA.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,avx512dq,fma")]
#[inline]
pub unsafe fn mama_avx512_inplace(
    data: &[f64],
    fast_limit: f64,
    slow_limit: f64,
    out_mama: &mut [f64],
    out_fama: &mut [f64],
) {
    debug_assert_eq!(data.len(), out_mama.len());
    debug_assert_eq!(data.len(), out_fama.len());

    // Power-of-two ring so lagged taps use a masked index.
    const LEN: usize = 8;
    const MASK: usize = LEN - 1;

    #[repr(align(64))]
    struct A([f64; LEN]);
    let first = data[0];
    // NOTE(review): moving `.0` out of the aligned temporary produces plain
    // `[f64; 8]` locals — the 64-byte alignment of `A` is not guaranteed to
    // carry over to these bindings. Confirm whether alignment matters here.
    let mut smooth = A([first; LEN]).0;
    let mut detrender = A([first; LEN]).0;
    let mut i1_buf = A([first; LEN]).0;
    let mut q1_buf = A([first; LEN]).0;

    const DEG_PER_RAD: f64 = 180.0 / std::f64::consts::PI;

    let (mut idx, mut prev_mesa, mut prev_phase) = (0usize, 0.0, 0.0);
    let (mut prev_mama, mut prev_fama) = (first, first);
    let (mut prev_i2, mut prev_q2) = (0.0, 0.0);
    let (mut prev_re, mut prev_im) = (0.0, 0.0);

    // Fetch the sample written `k` steps ago from the ring.
    #[inline(always)]
    fn lag(buf: &[f64; LEN], p: usize, k: usize) -> f64 {
        // SAFETY: masking with MASK (= LEN - 1) keeps the index in 0..LEN.
        unsafe { *buf.get_unchecked((p.wrapping_sub(k)) & MASK) }
    }

    for (i, &price) in data.iter().enumerate() {
        // 4-tap weighted price smoother (weights 4/3/2/1, scaled by 0.1).
        let s1 = if i >= 1 { data[i - 1] } else { price };
        let s2 = if i >= 2 { data[i - 2] } else { price };
        let s3 = if i >= 3 { data[i - 3] } else { price };
        let smooth_val =
            0.1 * (4.0_f64.mul_add(price, 3.0_f64.mul_add(s1, 2.0_f64.mul_add(s2, s3))));
        smooth[idx] = smooth_val;

        // Amplitude correction factor driven by the previous MESA period.
        let amp = 0.075_f64.mul_add(prev_mesa, 0.54);
        let dt_val = amp
            * hilbert4_avx512(
                smooth[idx],
                lag(&smooth, idx, 2),
                lag(&smooth, idx, 4),
                lag(&smooth, idx, 6),
            );
        detrender[idx] = dt_val;

        // In-phase component: detrender delayed by 3 bars.
        let i1 = lag(&detrender, idx, 3);
        i1_buf[idx] = i1;

        // Quadrature component: Hilbert of the detrender.
        let q1 = amp
            * hilbert4_avx512(
                detrender[idx],
                lag(&detrender, idx, 2),
                lag(&detrender, idx, 4),
                lag(&detrender, idx, 6),
            );
        q1_buf[idx] = q1;

        // Advance I/Q by 90 degrees via their Hilbert transforms.
        let j_i = amp
            * hilbert4_avx512(
                i1_buf[idx],
                lag(&i1_buf, idx, 2),
                lag(&i1_buf, idx, 4),
                lag(&i1_buf, idx, 6),
            );
        let j_q = amp
            * hilbert4_avx512(
                q1_buf[idx],
                lag(&q1_buf, idx, 2),
                lag(&q1_buf, idx, 4),
                lag(&q1_buf, idx, 6),
            );

        // Homodyne discriminator with 0.2/0.8 exponential smoothing.
        let i2 = i1 - j_q;
        let q2 = q1 + j_i;
        let old_i2 = prev_i2;
        let old_q2 = prev_q2;
        let i2s = 0.2_f64.mul_add(i2, 0.8 * old_i2);
        let q2s = 0.2_f64.mul_add(q2, 0.8 * old_q2);
        prev_i2 = i2s;
        prev_q2 = q2s;

        let re = 0.2_f64.mul_add(i2s * old_i2 + q2s * old_q2, 0.8 * prev_re);
        let im = 0.2_f64.mul_add(i2s * old_q2 - q2s * old_i2, 0.8 * prev_im);
        prev_re = re;
        prev_im = im;

        // Dominant cycle period, rate-limited to [0.67x, 1.5x] of the
        // previous value and clamped to [6, 50] bars, then smoothed.
        let mut mesa = if re != 0.0 && im != 0.0 {
            2.0 * std::f64::consts::PI / atan_fast(im / re)
        } else {
            prev_mesa
        };
        mesa = mesa
            .min(1.5 * prev_mesa)
            .max(0.67 * prev_mesa)
            .max(6.0)
            .min(50.0);
        mesa = 0.2_f64.mul_add(mesa, 0.8 * prev_mesa);
        prev_mesa = mesa;

        // Phase (degrees) and its bar-to-bar decrease (floored at 1).
        let phase = if i1 != 0.0 {
            atan_fast(q1 / i1) * DEG_PER_RAD
        } else {
            prev_phase
        };
        let mut dp = prev_phase - phase;
        if dp < 1.0 {
            dp = 1.0;
        }
        prev_phase = phase;

        // Adaptive alpha: fast when phase moves quickly, bounded by limits.
        let mut alpha = fast_limit / dp;
        alpha = alpha.clamp(slow_limit, fast_limit);

        // MAMA is an alpha-EMA of price; FAMA trails MAMA at half alpha.
        let cur_mama = alpha.mul_add(price, (1.0 - alpha) * prev_mama);
        let cur_fama = (0.5 * alpha).mul_add(cur_mama, (1.0 - 0.5 * alpha) * prev_fama);
        prev_mama = cur_mama;
        prev_fama = cur_fama;

        out_mama[i] = cur_mama;
        out_fama[i] = cur_fama;

        idx = (idx + 1) & MASK;
    }
}
607
/// Four-tap Hilbert-transform FIR evaluated with AVX2:
/// `0.0962*x0 + 0.5769*x2 - 0.5769*x4 - 0.0962*x6`.
///
/// # Safety
/// Caller must ensure the CPU supports AVX2 and FMA.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
#[inline]
unsafe fn hilbert4_avx2(x0: f64, x2: f64, x4: f64, x6: f64) -> f64 {
    // _mm256_set_pd lists lanes high-to-low: lanes are [x0, x2, x4, x6].
    let v_x = _mm256_set_pd(x6, x4, x2, x0);

    const H3: f64 = -0.096_2;
    const H2: f64 = -0.576_9;
    const H1: f64 = 0.576_9;
    const H0: f64 = 0.096_2;
    let v_h = _mm256_set_pd(H3, H2, H1, H0);

    let v_mul = _mm256_mul_pd(v_x, v_h);
    // Horizontal add within 128-bit halves, then fold the halves together.
    let v_sum = _mm256_hadd_pd(v_mul, v_mul);

    let v_fold = _mm256_permute2f128_pd(v_sum, v_sum, 0x1);
    let v_res = _mm256_add_pd(v_sum, v_fold);
    _mm256_cvtsd_f64(v_res)
}
627
/// AVX2 MAMA/FAMA kernel, writing one output per input price.
///
/// Same recurrence as `mama_scalar_inplace`, with the Hilbert FIR evaluated
/// via `hilbert4_avx2`. Output slices must match `data` in length (checked
/// with `debug_assert!`); `data` must be non-empty (reads `data[0]` for
/// seeding). Every index is written, including warm-up.
///
/// The previously-declared local `H0..H3` constants were dead code (the FIR
/// weights live inside `hilbert4_avx2`) and have been removed.
///
/// # Safety
/// Caller must ensure the CPU supports AVX2 and FMA; this function invokes
/// `#[target_feature(enable = "avx2,fma")]` code.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn mama_avx2_inplace(
    data: &[f64],
    fast_limit: f64,
    slow_limit: f64,
    out_mama: &mut [f64],
    out_fama: &mut [f64],
) {
    debug_assert_eq!(data.len(), out_mama.len());
    debug_assert_eq!(data.len(), out_fama.len());

    // Power-of-two ring so lagged taps use a masked index.
    const RING_LEN: usize = 8;
    const MASK: usize = RING_LEN - 1;

    // Price-smoother weights (most recent bar weighted heaviest).
    const W0: f64 = 4.0;
    const W1: f64 = 3.0;
    const W2: f64 = 2.0;
    const W3: f64 = 1.0;

    const DEG_PER_RAD: f64 = 180.0 / std::f64::consts::PI;

    let first = data[0];
    let mut smooth = [first; RING_LEN];
    let mut detrender = [first; RING_LEN];
    let mut i1_buf = [first; RING_LEN];
    let mut q1_buf = [first; RING_LEN];

    let mut idx = 0usize;
    let mut prev_mesa = 0.0;
    let mut prev_phase = 0.0;
    let mut prev_mama = first;
    let mut prev_fama = first;
    let mut prev_i2 = 0.0;
    let mut prev_q2 = 0.0;
    let mut prev_re = 0.0;
    let mut prev_im = 0.0;

    // Fetch the sample written `k` steps ago from the ring.
    #[inline(always)]
    fn lag(buf: &[f64; RING_LEN], p: usize, k: usize) -> f64 {
        buf[(p.wrapping_sub(k)) & MASK]
    }

    for (i, &price) in data.iter().enumerate() {
        let s1 = if i >= 1 { data[i - 1] } else { price };
        let s2 = if i >= 2 { data[i - 2] } else { price };
        let s3 = if i >= 3 { data[i - 3] } else { price };

        // 4-tap weighted price smoother.
        let smooth_val = W0.mul_add(price, W1.mul_add(s1, W2.mul_add(s2, s3))) * 0.1;
        smooth[idx] = smooth_val;

        // Amplitude correction factor driven by the previous MESA period.
        let amp = 0.075_f64.mul_add(prev_mesa, 0.54);

        let dt_val = amp
            * hilbert4_avx2(
                smooth[idx],
                lag(&smooth, idx, 2),
                lag(&smooth, idx, 4),
                lag(&smooth, idx, 6),
            );
        detrender[idx] = dt_val;

        // In-phase component: detrender delayed by 3 bars.
        let i1 = lag(&detrender, idx, 3);
        i1_buf[idx] = i1;

        // Quadrature component: Hilbert of the detrender.
        let q1 = amp
            * hilbert4_avx2(
                detrender[idx],
                lag(&detrender, idx, 2),
                lag(&detrender, idx, 4),
                lag(&detrender, idx, 6),
            );
        q1_buf[idx] = q1;

        // Advance I/Q by 90 degrees via their Hilbert transforms.
        let j_i = amp
            * hilbert4_avx2(
                i1_buf[idx],
                lag(&i1_buf, idx, 2),
                lag(&i1_buf, idx, 4),
                lag(&i1_buf, idx, 6),
            );
        let j_q = amp
            * hilbert4_avx2(
                q1_buf[idx],
                lag(&q1_buf, idx, 2),
                lag(&q1_buf, idx, 4),
                lag(&q1_buf, idx, 6),
            );

        // Homodyne discriminator with 0.2/0.8 exponential smoothing.
        let i2 = i1 - j_q;
        let q2 = q1 + j_i;
        let old_i2 = prev_i2;
        let old_q2 = prev_q2;
        let i2s = 0.2_f64.mul_add(i2, 0.8 * old_i2);
        let q2s = 0.2_f64.mul_add(q2, 0.8 * old_q2);
        prev_i2 = i2s;
        prev_q2 = q2s;

        let re = 0.2_f64.mul_add(i2s * old_i2 + q2s * old_q2, 0.8 * prev_re);
        let im = 0.2_f64.mul_add(i2s * old_q2 - q2s * old_i2, 0.8 * prev_im);
        prev_re = re;
        prev_im = im;

        // Dominant cycle period, rate-limited and clamped to [6, 50] bars.
        let mut mesa = if re != 0.0 && im != 0.0 {
            2.0 * std::f64::consts::PI / atan_fast(im / re)
        } else {
            prev_mesa
        };

        mesa = mesa
            .min(1.5 * prev_mesa)
            .max(0.67 * prev_mesa)
            .max(6.0)
            .min(50.0);
        mesa = 0.2_f64.mul_add(mesa, 0.8 * prev_mesa);
        prev_mesa = mesa;

        // Phase (degrees) and its bar-to-bar decrease (floored at 1).
        let phase = if i1 != 0.0 {
            atan_fast(q1 / i1) * DEG_PER_RAD
        } else {
            prev_phase
        };
        let mut dp = prev_phase - phase;
        if dp < 1.0 {
            dp = 1.0;
        }
        prev_phase = phase;

        // Adaptive alpha bounded by the fast/slow limits.
        let mut alpha = fast_limit / dp;
        alpha = alpha.clamp(slow_limit, fast_limit);

        // MAMA is an alpha-EMA of price; FAMA trails MAMA at half alpha.
        let cur_mama = alpha.mul_add(price, (1.0 - alpha) * prev_mama);
        let cur_fama = (0.5 * alpha).mul_add(cur_mama, (1.0 - 0.5 * alpha) * prev_fama);
        prev_mama = cur_mama;
        prev_fama = cur_fama;

        out_mama[i] = cur_mama;
        out_fama[i] = cur_fama;

        idx = (idx + 1) & MASK;
    }
}
/// Four-tap Hilbert-transform FIR used by the MAMA pipeline:
/// `0.0962*x0 + 0.5769*x2 - 0.5769*x4 - 0.0962*x6`.
#[inline(always)]
fn hilbert(x0: f64, x2: f64, x4: f64, x6: f64) -> f64 {
    let t0 = 0.0962 * x0;
    let t2 = 0.5769 * x2;
    let t4 = 0.5769 * x4;
    let t6 = 0.0962 * x6;
    t0 + t2 - t4 - t6
}
779
780#[inline]
781pub fn mama_scalar_inplace(
782 data: &[f64],
783 fast_limit: f64,
784 slow_limit: f64,
785 out_mama: &mut [f64],
786 out_fama: &mut [f64],
787) {
788 debug_assert_eq!(data.len(), out_mama.len());
789 debug_assert_eq!(data.len(), out_fama.len());
790 let len = data.len();
791
792 const RING: usize = 8;
793 const MASK: usize = RING - 1;
794
795 const H0: f64 = 0.096_2;
796 const H1: f64 = 0.576_9;
797 const H2: f64 = -0.576_9;
798 const H3: f64 = -0.096_2;
799 const DEG_PER_RAD: f64 = 180.0 / std::f64::consts::PI;
800
801 #[inline(always)]
802 fn hilbert4(x0: f64, x2: f64, x4: f64, x6: f64) -> f64 {
803 H0.mul_add(x0, H1.mul_add(x2, H2.mul_add(x4, H3 * x6)))
804 }
805
806 #[inline(always)]
807 fn lag<const N: usize>(buf: &[f64; N], pos: usize, k: usize) -> f64 {
808 buf[(pos.wrapping_sub(k)) & (N - 1)]
809 }
810
811 let first = data[0];
812
813 let mut smooth = [first; RING];
814 let mut detrender = [first; RING];
815 let mut i1_buf = [first; RING];
816 let mut q1_buf = [first; RING];
817
818 let mut idx = 0usize;
819 let mut prev_mesa = 0.0;
820 let mut prev_phase = 0.0;
821 let mut prev_mama = first;
822 let mut prev_fama = first;
823 let mut prev_i2 = 0.0;
824 let mut prev_q2 = 0.0;
825 let mut prev_re = 0.0;
826 let mut prev_im = 0.0;
827
828 for (i, &price) in data.iter().enumerate() {
829 let s1 = if i >= 1 { data[i - 1] } else { price };
830 let s2 = if i >= 2 { data[i - 2] } else { price };
831 let s3 = if i >= 3 { data[i - 3] } else { price };
832 let smooth_val =
833 0.1 * (4.0_f64.mul_add(price, 3.0_f64.mul_add(s1, 2.0_f64.mul_add(s2, s3))));
834 smooth[idx] = smooth_val;
835
836 let amp = 0.075_f64.mul_add(prev_mesa, 0.54);
837
838 let dt = amp
839 * hilbert4(
840 smooth[idx],
841 lag(&smooth, idx, 2),
842 lag(&smooth, idx, 4),
843 lag(&smooth, idx, 6),
844 );
845 detrender[idx] = dt;
846
847 let i1 = lag(&detrender, idx, 3);
848 i1_buf[idx] = i1;
849 let q1 = amp
850 * hilbert4(
851 detrender[idx],
852 lag(&detrender, idx, 2),
853 lag(&detrender, idx, 4),
854 lag(&detrender, idx, 6),
855 );
856 q1_buf[idx] = q1;
857
858 let j_i = amp
859 * hilbert4(
860 i1_buf[idx],
861 lag(&i1_buf, idx, 2),
862 lag(&i1_buf, idx, 4),
863 lag(&i1_buf, idx, 6),
864 );
865 let j_q = amp
866 * hilbert4(
867 q1_buf[idx],
868 lag(&q1_buf, idx, 2),
869 lag(&q1_buf, idx, 4),
870 lag(&q1_buf, idx, 6),
871 );
872
873 let i2 = i1 - j_q;
874 let q2 = q1 + j_i;
875 let i2s = 0.2_f64.mul_add(i2, 0.8 * prev_i2);
876 let q2s = 0.2_f64.mul_add(q2, 0.8 * prev_q2);
877 let re = 0.2_f64.mul_add(i2s * prev_i2 + q2s * prev_q2, 0.8 * prev_re);
878 let im = 0.2_f64.mul_add(i2s * prev_q2 - q2s * prev_i2, 0.8 * prev_im);
879 prev_i2 = i2s;
880 prev_q2 = q2s;
881 prev_re = re;
882 prev_im = im;
883
884 let mut mesa = if re != 0.0 && im != 0.0 {
885 2.0 * std::f64::consts::PI / atan_fast(im / re)
886 } else {
887 prev_mesa
888 };
889 if mesa > 1.5 * prev_mesa {
890 mesa = 1.5 * prev_mesa;
891 }
892 if mesa < 0.67 * prev_mesa {
893 mesa = 0.67 * prev_mesa;
894 }
895 if mesa < 6.0 {
896 mesa = 6.0;
897 }
898 if mesa > 50.0 {
899 mesa = 50.0;
900 }
901 mesa = 0.2_f64.mul_add(mesa, 0.8 * prev_mesa);
902 prev_mesa = mesa;
903
904 let phase = if i1 != 0.0 {
905 atan_fast(q1 / i1) * DEG_PER_RAD
906 } else {
907 prev_phase
908 };
909 let mut dphi = prev_phase - phase;
910 if dphi < 1.0 {
911 dphi = 1.0;
912 }
913 prev_phase = phase;
914
915 let mut alpha = fast_limit / dphi;
916 if alpha < slow_limit {
917 alpha = slow_limit;
918 }
919 if alpha > fast_limit {
920 alpha = fast_limit;
921 }
922
923 let mama = alpha.mul_add(price, (1.0 - alpha) * prev_mama);
924 let fama = (0.5 * alpha).mul_add(mama, (1.0 - 0.5 * alpha) * prev_fama);
925 prev_mama = mama;
926 prev_fama = fama;
927
928 out_mama[i] = mama;
929 out_fama[i] = fama;
930
931 idx = (idx + 1) & MASK;
932 }
933}
934
/// wasm32 simd128 MAMA/FAMA kernel, writing one output per input price.
///
/// Uses 7-slot ring buffers with modulo indexing (`(idx + 7 - k) % 7`
/// expressed as `(idx + 5) % 7` for lag 2, etc.), unlike the 8-slot masked
/// rings in the scalar/AVX kernels. NOTE(review): the i1 seeding here
/// (`if i >= 3`) and the update order of `prev_mesa_period` differ slightly
/// from the scalar kernel — confirm the kernels are meant to be
/// numerically identical.
///
/// # Safety
/// Relies on the simd128 target feature being enabled (guaranteed by the
/// `cfg` gate); `data` must be non-empty (reads `data[0]` for seeding).
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
#[inline]
unsafe fn mama_simd128_inplace(
    data: &[f64],
    fast_limit: f64,
    slow_limit: f64,
    out_mama: &mut [f64],
    out_fama: &mut [f64],
) {
    use core::arch::wasm32::*;

    debug_assert_eq!(data.len(), out_mama.len());
    debug_assert_eq!(data.len(), out_fama.len());

    let len = data.len();

    // 7-slot rings, seeded with the first price.
    let mut smooth_buf = [data[0]; 7];
    let mut detrender_buf = [data[0]; 7];
    let mut i1_buf = [data[0]; 7];
    let mut q1_buf = [data[0]; 7];

    let mut prev_mesa_period = 0.0;
    let mut prev_mama = data[0];
    let mut prev_fama = data[0];
    let mut prev_i2_sm = 0.0;
    let mut prev_q2_sm = 0.0;
    let mut prev_re = 0.0;
    let mut prev_im = 0.0;
    let mut prev_phase = 0.0;

    // Hilbert FIR weights split across two 2-lane vectors.
    let hilbert_weights = f64x2(0.0962, 0.5769);
    let neg_hilbert_weights = f64x2(-0.5769, -0.0962);

    // Price-smoother weights (4/3/2/1, scaled by 0.1).
    let smooth_weights = f64x2(4.0, 3.0);
    let smooth_weights2 = f64x2(2.0, 1.0);
    let smooth_div = f64x2_splat(0.1);

    // Four-tap Hilbert filter: two 2-lane multiplies, then a lane sum.
    #[inline(always)]
    fn hilbert_simd128(
        x0: f64,
        x2: f64,
        x4: f64,
        x6: f64,
        weights: v128,
        neg_weights: v128,
    ) -> f64 {
        let v1 = f64x2(x0, x2);
        let v2 = f64x2(x4, x6);

        let prod1 = f64x2_mul(v1, weights);
        let prod2 = f64x2_mul(v2, neg_weights);
        let sum = f64x2_add(prod1, prod2);

        f64x2_extract_lane::<0>(sum) + f64x2_extract_lane::<1>(sum)
    }

    for i in 0..len {
        let price = data[i];

        let s1 = if i >= 1 { data[i - 1] } else { price };
        let s2 = if i >= 2 { data[i - 2] } else { price };
        let s3 = if i >= 3 { data[i - 3] } else { price };

        // 4-tap weighted price smoother via two 2-lane multiplies.
        let v1 = f64x2(price, s1);
        let v2 = f64x2(s2, s3);
        let prod1 = f64x2_mul(v1, smooth_weights);
        let prod2 = f64x2_mul(v2, smooth_weights2);
        let sum = f64x2_add(prod1, prod2);
        let smooth_val = (f64x2_extract_lane::<0>(sum) + f64x2_extract_lane::<1>(sum)) * 0.1;

        let idx = i % 7;
        smooth_buf[idx] = smooth_val;

        // (idx + 5) % 7 == lag 2, (idx + 3) % 7 == lag 4, (idx + 1) % 7 == lag 6.
        let x0 = smooth_buf[idx];
        let x2 = smooth_buf[(idx + 5) % 7];
        let x4 = smooth_buf[(idx + 3) % 7];
        let x6 = smooth_buf[(idx + 1) % 7];

        // Amplitude correction factor driven by the previous MESA period.
        let mesa_mult = 0.075 * prev_mesa_period + 0.54;
        let dt_val =
            hilbert_simd128(x0, x2, x4, x6, hilbert_weights, neg_hilbert_weights) * mesa_mult;
        detrender_buf[idx] = dt_val;

        // In-phase component: detrender delayed by 3 bars ((idx + 4) % 7).
        let i1_val = if i >= 3 {
            detrender_buf[(idx + 4) % 7]
        } else {
            dt_val
        };
        i1_buf[idx] = i1_val;

        // Quadrature component: Hilbert of the detrender.
        let d0 = detrender_buf[idx];
        let d2 = detrender_buf[(idx + 5) % 7];
        let d4 = detrender_buf[(idx + 3) % 7];
        let d6 = detrender_buf[(idx + 1) % 7];
        let q1_val =
            hilbert_simd128(d0, d2, d4, d6, hilbert_weights, neg_hilbert_weights) * mesa_mult;
        q1_buf[idx] = q1_val;

        // Advance I/Q by 90 degrees via their Hilbert transforms.
        let j_i = {
            let i0 = i1_buf[idx];
            let i2 = i1_buf[(idx + 5) % 7];
            let i4 = i1_buf[(idx + 3) % 7];
            let i6 = i1_buf[(idx + 1) % 7];
            hilbert_simd128(i0, i2, i4, i6, hilbert_weights, neg_hilbert_weights) * mesa_mult
        };
        let j_q = {
            let q0 = q1_buf[idx];
            let q2 = q1_buf[(idx + 5) % 7];
            let q4 = q1_buf[(idx + 3) % 7];
            let q6 = q1_buf[(idx + 1) % 7];
            hilbert_simd128(q0, q2, q4, q6, hilbert_weights, neg_hilbert_weights) * mesa_mult
        };

        // Homodyne discriminator with 0.2/0.8 exponential smoothing.
        let i2 = i1_val - j_q;
        let q2 = q1_val + j_i;
        let i2_sm = 0.2 * i2 + 0.8 * prev_i2_sm;
        let q2_sm = 0.2 * q2 + 0.8 * prev_q2_sm;
        let re = 0.2 * (i2_sm * prev_i2_sm + q2_sm * prev_q2_sm) + 0.8 * prev_re;
        let im = 0.2 * (i2_sm * prev_q2_sm - q2_sm * prev_i2_sm) + 0.8 * prev_im;
        prev_i2_sm = i2_sm;
        prev_q2_sm = q2_sm;
        prev_re = re;
        prev_im = im;

        // Dominant cycle period, rate-limited and clamped to [6, 50] bars.
        let mut mesa_period = if re != 0.0 && im != 0.0 {
            2.0 * std::f64::consts::PI / atan_fast(im / re)
        } else {
            prev_mesa_period
        };

        if mesa_period > 1.5 * prev_mesa_period {
            mesa_period = 1.5 * prev_mesa_period;
        }
        if mesa_period < 0.67 * prev_mesa_period {
            mesa_period = 0.67 * prev_mesa_period;
        }
        if mesa_period < 6.0 {
            mesa_period = 6.0;
        }
        if mesa_period > 50.0 {
            mesa_period = 50.0;
        }

        // Phase (degrees) and its bar-to-bar decrease (floored at 1).
        let phase = if i1_val != 0.0 {
            atan_fast(q1_val / i1_val) * 180.0 / std::f64::consts::PI
        } else {
            prev_phase
        };

        let mut dp = prev_phase - phase;
        if dp < 1.0 {
            dp = 1.0;
        }
        prev_phase = phase;

        // Adaptive alpha bounded by the fast/slow limits.
        let mut alpha = fast_limit / dp;
        alpha = alpha.clamp(slow_limit, fast_limit);

        prev_mesa_period = mesa_period;

        // MAMA is an alpha-EMA of price; FAMA trails MAMA at half alpha.
        let mama_val = alpha * price + (1.0 - alpha) * prev_mama;
        let fama_val = 0.5 * alpha * mama_val + (1.0 - 0.5 * alpha) * prev_fama;

        out_mama[i] = mama_val;
        out_fama[i] = fama_val;

        prev_mama = mama_val;
        prev_fama = fama_val;
    }
}
1105
/// Incremental (one-price-at-a-time) MAMA/FAMA computer, mirroring the
/// state of the batch scalar kernel.
#[derive(Debug, Clone)]
pub struct MamaStream {
    // Effective limits (defaults already applied by try_new).
    fast_limit: f64,
    slow_limit: f64,

    // 8-slot ring buffers of intermediate series, masked-index addressed.
    smooth: [f64; 8],
    detrender: [f64; 8],
    i1_buf: [f64; 8],
    q1_buf: [f64; 8],
    // Current write position in the rings.
    idx: usize,

    // Filter recurrence state carried between updates.
    prev_mesa: f64,
    prev_phase: f64,
    prev_mama: f64,
    prev_fama: f64,
    prev_i2: f64,
    prev_q2: f64,
    prev_re: f64,
    prev_im: f64,

    // Last three raw prices, used by the 4-tap price smoother.
    last1: f64,
    last2: f64,
    last3: f64,

    // True once the first price has seeded the state.
    seeded: bool,
    // Number of updates processed after seeding; gates the warm-up.
    seen: usize,
}
1133
1134impl MamaStream {
1135 #[inline]
1136 pub fn try_new(params: MamaParams) -> Result<Self, MamaError> {
1137 let fast_limit = params.fast_limit.unwrap_or(0.5);
1138 let slow_limit = params.slow_limit.unwrap_or(0.05);
1139 if fast_limit <= 0.0 || !fast_limit.is_finite() {
1140 return Err(MamaError::InvalidFastLimit { fast_limit });
1141 }
1142 if slow_limit <= 0.0 || !slow_limit.is_finite() {
1143 return Err(MamaError::InvalidSlowLimit { slow_limit });
1144 }
1145
1146 Ok(Self {
1147 fast_limit,
1148 slow_limit,
1149 smooth: [f64::NAN; 8],
1150 detrender: [f64::NAN; 8],
1151 i1_buf: [f64::NAN; 8],
1152 q1_buf: [f64::NAN; 8],
1153 idx: 0,
1154
1155 prev_mesa: 0.0,
1156 prev_phase: 0.0,
1157 prev_mama: f64::NAN,
1158 prev_fama: f64::NAN,
1159 prev_i2: 0.0,
1160 prev_q2: 0.0,
1161 prev_re: 0.0,
1162 prev_im: 0.0,
1163
1164 last1: f64::NAN,
1165 last2: f64::NAN,
1166 last3: f64::NAN,
1167
1168 seeded: false,
1169 seen: 0,
1170 })
1171 }
1172
    /// Feeds one price into the stream.
    ///
    /// The very first call seeds all internal state from `price` and returns
    /// `None`; subsequent calls return `None` until 10 post-seed updates
    /// have been processed (matching the batch kernels' 10-sample NaN
    /// warm-up), then `Some((mama, fama))`.
    #[inline]
    pub fn update(&mut self, price: f64) -> Option<(f64, f64)> {
        const RING: usize = 8;
        const MASK: usize = RING - 1;
        // Hilbert transform FIR coefficients (Ehlers).
        const H0: f64 = 0.096_2;
        const H1: f64 = 0.576_9;
        const H2: f64 = -0.576_9;
        const H3: f64 = -0.096_2;
        const DEG_PER_RAD: f64 = 180.0 / std::f64::consts::PI;

        // Four-tap Hilbert filter over samples lagged 0, 2, 4, 6.
        #[inline(always)]
        fn hilbert4(x0: f64, x2: f64, x4: f64, x6: f64) -> f64 {
            H0.mul_add(x0, H1.mul_add(x2, H2.mul_add(x4, H3 * x6)))
        }
        // Fetch the sample written `k` steps ago from a power-of-two ring.
        #[inline(always)]
        fn lag<const N: usize>(buf: &[f64; N], pos: usize, k: usize) -> f64 {
            buf[(pos.wrapping_sub(k)) & (N - 1)]
        }

        if !self.seeded {
            // First price: seed every ring and recurrence value, process it
            // once to prime the filter, and report nothing yet.
            self.smooth = [price; RING];
            self.detrender = [price; RING];
            self.i1_buf = [price; RING];
            self.q1_buf = [price; RING];
            self.idx = 0;

            self.prev_mesa = 0.0;
            self.prev_phase = 0.0;
            self.prev_mama = price;
            self.prev_fama = price;
            self.prev_i2 = 0.0;
            self.prev_q2 = 0.0;
            self.prev_re = 0.0;
            self.prev_im = 0.0;

            self.last1 = price;
            self.last2 = price;
            self.last3 = price;

            self.seeded = true;

            let _ = self.process_one(price, hilbert4, lag::<RING>, DEG_PER_RAD);

            return None;
        }

        let (mama, fama) = self.process_one(price, hilbert4, lag::<RING>, DEG_PER_RAD);

        // Suppress output during warm-up (first 10 post-seed updates).
        self.seen += 1;
        if self.seen < 10 {
            return None;
        }
        Some((mama, fama))
    }
1227
    /// One full MESA step: price smoothing, Hilbert-transform quadrature,
    /// dominant-cycle (`mesa`) estimation, phase-rate-driven adaptive alpha,
    /// then the MAMA/FAMA exponential averages. Returns the new
    /// `(mama, fama)` pair and advances all internal state, including the
    /// ring index and the last-three-price history.
    #[inline(always)]
    fn process_one(
        &mut self,
        price: f64,
        hilbert4: impl Fn(f64, f64, f64, f64) -> f64,
        lag: impl Fn(&[f64; 8], usize, usize) -> f64,
        deg_per_rad: f64,
    ) -> (f64, f64) {
        const MASK: usize = 7;
        let i = self.idx;

        // 4-bar weighted smoother: (4*p + 3*p1 + 2*p2 + p3) / 10.
        // Missing history (seen < k) falls back to the current price.
        let s1 = if self.seen >= 1 { self.last1 } else { price };
        let s2 = if self.seen >= 2 { self.last2 } else { price };
        let s3 = if self.seen >= 3 { self.last3 } else { price };
        let smooth_val =
            0.1 * (4.0_f64.mul_add(price, 3.0_f64.mul_add(s1, 2.0_f64.mul_add(s2, s3))));
        self.smooth[i] = smooth_val;

        // Period-dependent gain applied to every Hilbert output.
        let amp = 0.075_f64.mul_add(self.prev_mesa, 0.54);

        // Detrend the smoothed price via the Hilbert FIR.
        let dt = amp
            * hilbert4(
                self.smooth[i],
                lag(&self.smooth, i, 2),
                lag(&self.smooth, i, 4),
                lag(&self.smooth, i, 6),
            );
        self.detrender[i] = dt;

        // In-phase component: detrender delayed 3 bars.
        let i1 = lag(&self.detrender, i, 3);
        self.i1_buf[i] = i1;

        // Quadrature component: Hilbert transform of the detrender.
        let q1 = amp
            * hilbert4(
                self.detrender[i],
                lag(&self.detrender, i, 2),
                lag(&self.detrender, i, 4),
                lag(&self.detrender, i, 6),
            );
        self.q1_buf[i] = q1;

        // Advance phase of I1/Q1 by 90 degrees (jI, jQ).
        let j_i = amp
            * hilbert4(
                self.i1_buf[i],
                lag(&self.i1_buf, i, 2),
                lag(&self.i1_buf, i, 4),
                lag(&self.i1_buf, i, 6),
            );
        let j_q = amp
            * hilbert4(
                self.q1_buf[i],
                lag(&self.q1_buf, i, 2),
                lag(&self.q1_buf, i, 4),
                lag(&self.q1_buf, i, 6),
            );

        // Homodyne components.
        let i2 = i1 - j_q;
        let q2 = q1 + j_i;

        let old_i2 = self.prev_i2;
        let old_q2 = self.prev_q2;

        // Smooth I2/Q2 with a 0.2/0.8 EMA.
        let i2s = 0.2_f64.mul_add(i2, 0.8 * old_i2);
        let q2s = 0.2_f64.mul_add(q2, 0.8 * old_q2);
        self.prev_i2 = i2s;
        self.prev_q2 = q2s;

        // Homodyne discriminator: real/imag parts, also EMA-smoothed.
        let re = 0.2_f64.mul_add(i2s * old_i2 + q2s * old_q2, 0.8 * self.prev_re);
        let im = 0.2_f64.mul_add(i2s * old_q2 - q2s * old_i2, 0.8 * self.prev_im);
        self.prev_re = re;
        self.prev_im = im;

        // Dominant cycle period from the discriminator angle.
        let mut mesa = if re != 0.0 && im != 0.0 {
            2.0 * std::f64::consts::PI / atan_fast(im / re)
        } else {
            self.prev_mesa
        };

        // Limit period change to [0.67x, 1.5x] of previous, clamp to [6, 50]
        // bars, then EMA-smooth. Note the ordering: the absolute clamp is
        // applied after the relative one.
        mesa = mesa
            .min(1.5 * self.prev_mesa)
            .max(0.67 * self.prev_mesa)
            .max(6.0)
            .min(50.0);
        mesa = 0.2_f64.mul_add(mesa, 0.8 * self.prev_mesa);
        self.prev_mesa = mesa;

        // Instantaneous phase in degrees; rate of change drives alpha.
        let phase = if i1 != 0.0 {
            atan_fast(q1 / i1) * deg_per_rad
        } else {
            self.prev_phase
        };
        let mut dphi = self.prev_phase - phase;
        if dphi < 1.0 {
            dphi = 1.0;
        }
        self.prev_phase = phase;

        // Adaptive alpha = clamp(fast / dphi, slow, fast).
        let mut alpha = self.fast_limit / dphi;
        if alpha < self.slow_limit {
            alpha = self.slow_limit;
        }
        if alpha > self.fast_limit {
            alpha = self.fast_limit;
        }

        // MAMA follows price at alpha; FAMA follows MAMA at alpha/2.
        let one_minus_alpha = 1.0 - alpha;
        let mama = alpha.mul_add(price, one_minus_alpha * self.prev_mama);

        let half_alpha = 0.5 * alpha;
        let fama = half_alpha.mul_add(mama, (1.0 - half_alpha) * self.prev_fama);

        self.prev_mama = mama;
        self.prev_fama = fama;

        // Advance ring index and shift the 3-bar price history.
        self.idx = (self.idx + 1) & MASK;
        self.last3 = self.last2;
        self.last2 = self.last1;
        self.last1 = price;

        (mama, fama)
    }
1349}
1350
/// Parameter sweep for batch MAMA runs; each axis is an inclusive
/// `(start, end, step)` range expanded by `expand_grid`.
#[derive(Clone, Debug)]
pub struct MamaBatchRange {
    /// Fast-limit axis as `(start, end, step)`; `step == 0.0` pins one value.
    pub fast_limit: (f64, f64, f64),
    /// Slow-limit axis as `(start, end, step)`; `step == 0.0` pins one value.
    pub slow_limit: (f64, f64, f64),
}
1356
impl Default for MamaBatchRange {
    /// Default sweep: fast_limit stepped over [0.5, 0.749] in 0.001
    /// increments; slow_limit pinned at 0.05 (zero step = single value).
    fn default() -> Self {
        Self {
            fast_limit: (0.5, 0.749, 0.001),
            slow_limit: (0.05, 0.05, 0.0),
        }
    }
}
1365
/// Fluent builder for batch MAMA sweeps: configure the parameter ranges and
/// the execution kernel, then call `apply_slice`/`apply_candles`.
#[derive(Clone, Debug, Default)]
pub struct MamaBatchBuilder {
    // Parameter sweep to expand into combos.
    range: MamaBatchRange,
    // Requested execution kernel (validated at run time).
    kernel: Kernel,
}
1371
impl MamaBatchBuilder {
    /// New builder with the default sweep range and default kernel.
    pub fn new() -> Self {
        Self::default()
    }
    /// Selects the execution kernel for the batch run.
    pub fn kernel(mut self, k: Kernel) -> Self {
        self.kernel = k;
        self
    }
    /// Sweeps `fast_limit` over an inclusive `(start, end, step)` range.
    #[inline]
    pub fn fast_limit_range(mut self, start: f64, end: f64, step: f64) -> Self {
        self.range.fast_limit = (start, end, step);
        self
    }
    /// Pins `fast_limit` to a single value.
    #[inline]
    pub fn fast_limit_static(mut self, x: f64) -> Self {
        self.range.fast_limit = (x, x, 0.0);
        self
    }
    /// Sweeps `slow_limit` over an inclusive `(start, end, step)` range.
    #[inline]
    pub fn slow_limit_range(mut self, start: f64, end: f64, step: f64) -> Self {
        self.range.slow_limit = (start, end, step);
        self
    }
    /// Pins `slow_limit` to a single value.
    #[inline]
    pub fn slow_limit_static(mut self, x: f64) -> Self {
        self.range.slow_limit = (x, x, 0.0);
        self
    }
    /// Runs the configured sweep over a raw price slice.
    pub fn apply_slice(self, data: &[f64]) -> Result<MamaBatchOutput, MamaError> {
        mama_batch_with_kernel(data, &self.range, self.kernel)
    }
    /// Convenience: default sweep over `data` with kernel `k`.
    pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<MamaBatchOutput, MamaError> {
        MamaBatchBuilder::new().kernel(k).apply_slice(data)
    }
    /// Runs the configured sweep over the `src` series of `c`.
    pub fn apply_candles(self, c: &Candles, src: &str) -> Result<MamaBatchOutput, MamaError> {
        let slice = source_type(c, src);
        self.apply_slice(slice)
    }
    /// Convenience: default sweep over the close series with kernel auto-detection.
    pub fn with_default_candles(c: &Candles) -> Result<MamaBatchOutput, MamaError> {
        MamaBatchBuilder::new()
            .kernel(Kernel::Auto)
            .apply_candles(c, "close")
    }
}
1416
/// Row-major batch results: one row per parameter combo, `cols` samples wide.
#[derive(Clone, Debug)]
pub struct MamaBatchOutput {
    /// MAMA values, `rows * cols` long, row-major.
    pub mama_values: Vec<f64>,
    /// FAMA values, `rows * cols` long, row-major.
    pub fama_values: Vec<f64>,
    /// Parameter combo for each row, in row order.
    pub combos: Vec<MamaParams>,
    /// Number of parameter combos (rows).
    pub rows: usize,
    /// Samples per row (equals the input length).
    pub cols: usize,
}
1425
1426impl MamaBatchOutput {
1427 pub fn row_for_params(&self, p: &MamaParams) -> Option<usize> {
1428 self.combos.iter().position(|c| {
1429 (c.fast_limit.unwrap_or(0.5) - p.fast_limit.unwrap_or(0.5)).abs() < 1e-12
1430 && (c.slow_limit.unwrap_or(0.05) - p.slow_limit.unwrap_or(0.05)).abs() < 1e-12
1431 })
1432 }
1433 pub fn mama_for(&self, p: &MamaParams) -> Option<&[f64]> {
1434 self.row_for_params(p).map(|row| {
1435 let start = row * self.cols;
1436 &self.mama_values[start..start + self.cols]
1437 })
1438 }
1439 pub fn fama_for(&self, p: &MamaParams) -> Option<&[f64]> {
1440 self.row_for_params(p).map(|row| {
1441 let start = row * self.cols;
1442 &self.fama_values[start..start + self.cols]
1443 })
1444 }
1445}
1446
/// Expands a `MamaBatchRange` into the cartesian product of its two axes,
/// fast-limit major (slow values cycle fastest).
///
/// An axis whose step is ~0 or whose start and end coincide collapses to a
/// single value; otherwise the step sign is corrected to walk from start
/// toward end inclusively (within a 1e-12 epsilon).
#[inline(always)]
pub fn expand_grid(r: &MamaBatchRange) -> Result<Vec<MamaParams>, MamaError> {
    // Expands one (start, end, step) axis into its list of values.
    fn axis_f64((start, end, step): (f64, f64, f64)) -> Result<Vec<f64>, MamaError> {
        // Degenerate axis: a single point.
        if step.abs() < 1e-12 || (start - end).abs() < 1e-12 {
            return Ok(vec![start]);
        }

        // Force the step to point from start toward end.
        let mut step_signed = step;
        if end < start && step_signed > 0.0 {
            step_signed = -step_signed;
        } else if end > start && step_signed < 0.0 {
            step_signed = -step_signed;
        }

        // NOTE(review): repeated `x += step` accumulates floating-point
        // drift; for long axes the inclusion of the final endpoint depends
        // on that drift relative to the 1e-12 epsilon. Verify the combo
        // count if exact endpoints matter.
        let mut v = Vec::new();
        let eps = 1e-12_f64;
        let mut x = start;
        if step_signed > 0.0 {
            while x <= end + eps {
                v.push(x);
                x += step_signed;
            }
        } else {
            while x >= end - eps {
                v.push(x);
                x += step_signed;
            }
        }

        if v.is_empty() {
            return Err(MamaError::InvalidRange { start, end, step });
        }
        Ok(v)
    }

    let fast_limits = axis_f64(r.fast_limit)?;
    let slow_limits = axis_f64(r.slow_limit)?;

    // Guard against overflow when sizing the combo list.
    let cap = fast_limits
        .len()
        .checked_mul(slow_limits.len())
        .ok_or(MamaError::InvalidRange {
            start: r.fast_limit.0,
            end: r.fast_limit.1,
            step: r.fast_limit.2,
        })?;

    let mut out = Vec::with_capacity(cap);
    for &f in &fast_limits {
        for &s in &slow_limits {
            out.push(MamaParams {
                fast_limit: Some(f),
                slow_limit: Some(s),
            });
        }
    }
    Ok(out)
}
1505
1506pub fn mama_batch_with_kernel(
1507 data: &[f64],
1508 sweep: &MamaBatchRange,
1509 k: Kernel,
1510) -> Result<MamaBatchOutput, MamaError> {
1511 let kernel = match k {
1512 Kernel::Auto => Kernel::ScalarBatch,
1513 other if other.is_batch() => other,
1514 other => return Err(MamaError::InvalidKernelForBatch(other)),
1515 };
1516
1517 let simd = Kernel::Scalar;
1518 mama_batch_par_slice(data, sweep, simd)
1519}
1520
/// Single-threaded batch sweep of `sweep` over `data` using kernel `kern`.
#[inline(always)]
pub fn mama_batch_slice(
    data: &[f64],
    sweep: &MamaBatchRange,
    kern: Kernel,
) -> Result<MamaBatchOutput, MamaError> {
    mama_batch_inner(data, sweep, kern, false)
}
1529
/// Parallel (rayon, when not targeting wasm32) batch sweep of `sweep` over
/// `data` using kernel `kern`.
#[inline(always)]
pub fn mama_batch_par_slice(
    data: &[f64],
    sweep: &MamaBatchRange,
    kern: Kernel,
) -> Result<MamaBatchOutput, MamaError> {
    mama_batch_inner(data, sweep, kern, true)
}
1538
/// Shared batch implementation.
///
/// Strategy: the MESA phase-rate series (`delta_phase`) does not depend on
/// the (fast, slow) limits, so it is computed ONCE over `data`, then each
/// parameter combo only needs a cheap adaptive-EMA scan over that shared
/// series — one row per combo, optionally filled in parallel.
///
/// NOTE(review): `kern` is accepted but never consulted here — both passes
/// below are scalar-only (contrast `mama_batch_inner_into`, which dispatches
/// on the kernel). Confirm whether this is intentional.
fn mama_batch_inner(
    data: &[f64],
    sweep: &MamaBatchRange,
    kern: Kernel,
    parallel: bool,
) -> Result<MamaBatchOutput, MamaError> {
    let combos = expand_grid(sweep)?;
    if combos.is_empty() {
        return Err(MamaError::InvalidRange {
            start: sweep.fast_limit.0,
            end: sweep.fast_limit.1,
            step: sweep.fast_limit.2,
        });
    }
    // 10 samples is the minimum: the first 10 outputs are the NaN warmup.
    if data.len() < 10 {
        return Err(MamaError::NotEnoughData {
            needed: 10,
            found: data.len(),
        });
    }

    // Reject non-positive / non-finite limits before doing any work.
    for combo in &combos {
        let fast_limit = combo.fast_limit.unwrap_or(0.5);
        let slow_limit = combo.slow_limit.unwrap_or(0.05);

        if fast_limit <= 0.0 || fast_limit.is_nan() || fast_limit.is_infinite() {
            return Err(MamaError::InvalidFastLimit { fast_limit });
        }
        if slow_limit <= 0.0 || slow_limit.is_nan() || slow_limit.is_infinite() {
            return Err(MamaError::InvalidSlowLimit { slow_limit });
        }
    }

    let rows = combos.len();
    let cols = data.len();

    // Uninitialized rows*cols matrices; every element is written below.
    let mut raw_mama = make_uninit_matrix(rows, cols);
    let mut raw_fama = make_uninit_matrix(rows, cols);

    // Pre-fill the 10-element warmup prefix of each row.
    let warm_prefixes = vec![10; rows];
    // SAFETY: each row is `cols` wide and the 10-element prefix fits because
    // data.len() >= 10 was checked above. Assumes init_matrix_prefixes
    // writes (only) the first `warm` elements of each row — its definition
    // is outside this chunk; TODO confirm.
    unsafe {
        init_matrix_prefixes(&mut raw_mama, cols, &warm_prefixes);
        init_matrix_prefixes(&mut raw_fama, cols, &warm_prefixes);
    }

    // Pass 1: parameter-independent delta-phase series, computed once.
    let delta_phase: Vec<f64> = {
        const RING: usize = 8;
        const MASK: usize = RING - 1;
        // Ehlers 4-tap Hilbert transformer coefficients.
        const H0: f64 = 0.096_2;
        const H1: f64 = 0.576_9;
        const H2: f64 = -0.576_9;
        const H3: f64 = -0.096_2;
        const DEG_PER_RAD: f64 = 180.0 / std::f64::consts::PI;

        #[inline(always)]
        fn hilbert4(x0: f64, x2: f64, x4: f64, x6: f64) -> f64 {
            H0.mul_add(x0, H1.mul_add(x2, H2.mul_add(x4, H3 * x6)))
        }
        // Ring-buffer lookback: value k steps behind `pos`.
        #[inline(always)]
        fn lag<const N: usize>(buf: &[f64; N], pos: usize, k: usize) -> f64 {
            buf[(pos.wrapping_sub(k)) & (N - 1)]
        }

        // dphi defaults to 1.0 (the clamp floor below).
        let mut out = vec![1.0; cols];
        if cols == 0 {
            out
        } else {
            // Seed all ring buffers with the first price (flat history).
            let first = data[0];
            let mut smooth = [first; RING];
            let mut detrender = [first; RING];
            let mut i1_buf = [first; RING];
            let mut q1_buf = [first; RING];

            let mut idx = 0usize;
            let mut prev_mesa = 0.0;
            let mut prev_phase = 0.0;
            let mut prev_i2 = 0.0;
            let mut prev_q2 = 0.0;
            let mut prev_re = 0.0;
            let mut prev_im = 0.0;

            for (i, &price) in data.iter().enumerate() {
                // 4-bar weighted smoother (weights 4/3/2/1, scaled by 0.1).
                let s1 = if i >= 1 { data[i - 1] } else { price };
                let s2 = if i >= 2 { data[i - 2] } else { price };
                let s3 = if i >= 3 { data[i - 3] } else { price };
                let smooth_val =
                    0.1 * (4.0_f64.mul_add(price, 3.0_f64.mul_add(s1, 2.0_f64.mul_add(s2, s3))));
                smooth[idx] = smooth_val;

                // Period-dependent gain on every Hilbert output.
                let amp = 0.075_f64.mul_add(prev_mesa, 0.54);
                let dt = amp
                    * hilbert4(
                        smooth[idx],
                        lag(&smooth, idx, 2),
                        lag(&smooth, idx, 4),
                        lag(&smooth, idx, 6),
                    );
                detrender[idx] = dt;

                // In-phase (3-bar delayed detrender) and quadrature parts.
                let i1 = lag(&detrender, idx, 3);
                i1_buf[idx] = i1;
                let q1 = amp
                    * hilbert4(
                        detrender[idx],
                        lag(&detrender, idx, 2),
                        lag(&detrender, idx, 4),
                        lag(&detrender, idx, 6),
                    );
                q1_buf[idx] = q1;

                // 90-degree phase advance of I1/Q1.
                let j_i = amp
                    * hilbert4(
                        i1_buf[idx],
                        lag(&i1_buf, idx, 2),
                        lag(&i1_buf, idx, 4),
                        lag(&i1_buf, idx, 6),
                    );
                let j_q = amp
                    * hilbert4(
                        q1_buf[idx],
                        lag(&q1_buf, idx, 2),
                        lag(&q1_buf, idx, 4),
                        lag(&q1_buf, idx, 6),
                    );

                // Homodyne discriminator with 0.2/0.8 EMA smoothing.
                let i2 = i1 - j_q;
                let q2 = q1 + j_i;
                let old_i2 = prev_i2;
                let old_q2 = prev_q2;
                let i2s = 0.2_f64.mul_add(i2, 0.8 * old_i2);
                let q2s = 0.2_f64.mul_add(q2, 0.8 * old_q2);
                prev_i2 = i2s;
                prev_q2 = q2s;
                let re = 0.2_f64.mul_add(i2s * old_i2 + q2s * old_q2, 0.8 * prev_re);
                let im = 0.2_f64.mul_add(i2s * old_q2 - q2s * old_i2, 0.8 * prev_im);
                prev_re = re;
                prev_im = im;

                // Dominant cycle period; relative then absolute clamps,
                // then EMA-smooth (mirrors the streaming process_one).
                let mut mesa = if re != 0.0 && im != 0.0 {
                    2.0 * std::f64::consts::PI / atan_fast(im / re)
                } else {
                    prev_mesa
                };
                if mesa > 1.5 * prev_mesa {
                    mesa = 1.5 * prev_mesa;
                }
                if mesa < 0.67 * prev_mesa {
                    mesa = 0.67 * prev_mesa;
                }
                if mesa < 6.0 {
                    mesa = 6.0;
                }
                if mesa > 50.0 {
                    mesa = 50.0;
                }
                mesa = 0.2_f64.mul_add(mesa, 0.8 * prev_mesa);
                prev_mesa = mesa;

                // Phase (degrees) and its clamped rate of change.
                let phase = if i1 != 0.0 {
                    atan_fast(q1 / i1) * DEG_PER_RAD
                } else {
                    prev_phase
                };
                let mut dphi = prev_phase - phase;
                if dphi < 1.0 {
                    dphi = 1.0;
                }
                prev_phase = phase;
                out[i] = dphi;

                idx = (idx + 1) & MASK;
            }
            out
        }
    };

    // Pass 2: fill one row per combo. Only the adaptive-alpha EMA depends on
    // (fast, slow), so each row is a cheap scan over the shared delta_phase.
    let do_row = |row: usize, dst_m: &mut [MaybeUninit<f64>], dst_f: &mut [MaybeUninit<f64>]| unsafe {
        let prm = &combos[row];
        let fast = prm.fast_limit.unwrap_or(0.5);
        let slow = prm.slow_limit.unwrap_or(0.05);

        // SAFETY: MaybeUninit<f64> and f64 have identical layout; the loop
        // below writes every element of both rows before anything reads them.
        let out_m = core::slice::from_raw_parts_mut(dst_m.as_mut_ptr() as *mut f64, dst_m.len());
        let out_f = core::slice::from_raw_parts_mut(dst_f.as_mut_ptr() as *mut f64, dst_f.len());

        let mut prev_mama = data[0];
        let mut prev_fama = data[0];
        for i in 0..cols {
            let price = data[i];
            // alpha = clamp(fast / dphi, slow, fast).
            let mut alpha = fast / delta_phase[i];
            if alpha < slow {
                alpha = slow;
            }
            if alpha > fast {
                alpha = fast;
            }

            let mama = alpha.mul_add(price, (1.0 - alpha) * prev_mama);
            let fama = (0.5 * alpha).mul_add(mama, (1.0 - 0.5 * alpha) * prev_fama);
            prev_mama = mama;
            prev_fama = fama;
            out_m[i] = mama;
            out_f[i] = fama;
        }

        // Mask the 10-bar warmup region with NaN.
        for j in 0..10.min(out_m.len()) {
            out_m[j] = f64::NAN;
            out_f[j] = f64::NAN;
        }
    };

    if parallel {
        #[cfg(not(target_arch = "wasm32"))]
        {
            raw_mama
                .par_chunks_mut(cols)
                .zip(raw_fama.par_chunks_mut(cols))
                .enumerate()
                .for_each(|(row, (m_row, f_row))| do_row(row, m_row, f_row));
        }

        // wasm32 has no rayon: fall back to a serial sweep.
        #[cfg(target_arch = "wasm32")]
        {
            for (row, (m_row, f_row)) in raw_mama
                .chunks_mut(cols)
                .zip(raw_fama.chunks_mut(cols))
                .enumerate()
            {
                do_row(row, m_row, f_row);
            }
        }
    } else {
        for (row, (m_row, f_row)) in raw_mama
            .chunks_mut(cols)
            .zip(raw_fama.chunks_mut(cols))
            .enumerate()
        {
            do_row(row, m_row, f_row);
        }
    }

    // Re-own the (now fully initialized) buffers as Vec<f64> without copying.
    let mut guard_m = core::mem::ManuallyDrop::new(raw_mama);
    let mut guard_f = core::mem::ManuallyDrop::new(raw_fama);

    // SAFETY: every element was initialized (prefix init + do_row), the
    // layouts of MaybeUninit<f64> and f64 coincide, and ManuallyDrop stops
    // the original containers from freeing the allocation we take over.
    // NOTE(review): this assumes make_uninit_matrix allocates via the global
    // allocator with Vec-compatible layout; if it returns an aligned AVec,
    // Vec::from_raw_parts here would be unsound — confirm its definition.
    let mama_values = unsafe {
        Vec::from_raw_parts(
            guard_m.as_mut_ptr() as *mut f64,
            guard_m.len(),
            guard_m.capacity(),
        )
    };
    let fama_values = unsafe {
        Vec::from_raw_parts(
            guard_f.as_mut_ptr() as *mut f64,
            guard_f.len(),
            guard_f.capacity(),
        )
    };

    Ok(MamaBatchOutput {
        mama_values,
        fama_values,
        combos,
        rows,
        cols,
    })
}
1805
/// Batch sweep writing into caller-provided buffers instead of allocating.
///
/// `out_mama`/`out_fama` must each be exactly `rows * cols` long (rows =
/// expanded combo count, cols = data.len()); otherwise
/// `MamaError::OutputLengthMismatch` is returned. Unlike `mama_batch_inner`,
/// each row here runs the full per-parameter kernel selected by `kern`.
pub fn mama_batch_inner_into(
    data: &[f64],
    sweep: &MamaBatchRange,
    kern: Kernel,
    parallel: bool,
    out_mama: &mut [f64],
    out_fama: &mut [f64],
) -> Result<Vec<MamaParams>, MamaError> {
    let combos = expand_grid(sweep)?;
    if combos.is_empty() {
        return Err(MamaError::InvalidRange {
            start: sweep.fast_limit.0,
            end: sweep.fast_limit.1,
            step: sweep.fast_limit.2,
        });
    }
    // 10 samples is the minimum: the first 10 outputs are the NaN warmup.
    if data.len() < 10 {
        return Err(MamaError::NotEnoughData {
            needed: 10,
            found: data.len(),
        });
    }

    // Reject non-positive / non-finite limits before doing any work.
    for combo in &combos {
        let fast_limit = combo.fast_limit.unwrap_or(0.5);
        let slow_limit = combo.slow_limit.unwrap_or(0.05);

        if fast_limit <= 0.0 || fast_limit.is_nan() || fast_limit.is_infinite() {
            return Err(MamaError::InvalidFastLimit { fast_limit });
        }
        if slow_limit <= 0.0 || slow_limit.is_nan() || slow_limit.is_infinite() {
            return Err(MamaError::InvalidSlowLimit { slow_limit });
        }
    }

    let rows = combos.len();
    let cols = data.len();

    // Exact-size check (with overflow guard) on both output buffers.
    let expected = rows.checked_mul(cols).ok_or(MamaError::InvalidRange {
        start: sweep.fast_limit.0,
        end: sweep.fast_limit.1,
        step: sweep.fast_limit.2,
    })?;
    if out_mama.len() != expected || out_fama.len() != expected {
        return Err(MamaError::OutputLengthMismatch {
            expected,
            got: out_mama.len().min(out_fama.len()),
        });
    }

    // View the caller's &mut [f64] as MaybeUninit so the shared prefix
    // initializer can be reused. SAFETY: f64 -> MaybeUninit<f64> only
    // weakens the initialization guarantee; same layout, same length.
    let out_mama_uninit = unsafe {
        std::slice::from_raw_parts_mut(
            out_mama.as_mut_ptr() as *mut MaybeUninit<f64>,
            out_mama.len(),
        )
    };
    let out_fama_uninit = unsafe {
        std::slice::from_raw_parts_mut(
            out_fama.as_mut_ptr() as *mut MaybeUninit<f64>,
            out_fama.len(),
        )
    };

    // Pre-fill the 10-element warmup prefix of each row.
    let warm_prefixes = vec![10; rows];
    // SAFETY: each row is `cols` wide and the 10-element prefix fits because
    // data.len() >= 10 was checked above.
    unsafe {
        init_matrix_prefixes(out_mama_uninit, cols, &warm_prefixes);
        init_matrix_prefixes(out_fama_uninit, cols, &warm_prefixes);
    }

    // Per-row worker: run the selected kernel, then NaN-mask the warmup.
    let do_row = |row: usize, dst_m: &mut [MaybeUninit<f64>], dst_f: &mut [MaybeUninit<f64>]| unsafe {
        let prm = &combos[row];
        let fast = prm.fast_limit.unwrap_or(0.5);
        let slow = prm.slow_limit.unwrap_or(0.05);

        // SAFETY: these rows alias the caller's initialized &mut [f64], so
        // reading/writing them as f64 is valid.
        let out_m = core::slice::from_raw_parts_mut(dst_m.as_mut_ptr() as *mut f64, dst_m.len());
        let out_f = core::slice::from_raw_parts_mut(dst_f.as_mut_ptr() as *mut f64, dst_f.len());

        // NOTE(review): any kernel other than the arms below (e.g. a batch
        // or Auto kernel, or AVX kernels with the feature disabled) panics
        // via unreachable!() — callers must pre-validate `kern`.
        match kern {
            Kernel::Scalar => mama_row_scalar(data, fast, slow, out_m, out_f),
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx2 => mama_row_avx2(data, fast, slow, out_m, out_f),
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx512 => mama_row_avx512(data, fast, slow, out_m, out_f),
            _ => unreachable!(),
        }

        // Mask the 10-bar warmup region with NaN.
        for j in 0..10.min(out_m.len()) {
            out_m[j] = f64::NAN;
            out_f[j] = f64::NAN;
        }
    };

    if parallel {
        #[cfg(not(target_arch = "wasm32"))]
        {
            out_mama_uninit
                .par_chunks_mut(cols)
                .zip(out_fama_uninit.par_chunks_mut(cols))
                .enumerate()
                .for_each(|(row, (m_row, f_row))| do_row(row, m_row, f_row));
        }

        // wasm32 has no rayon: fall back to a serial sweep.
        #[cfg(target_arch = "wasm32")]
        {
            for (row, (m_row, f_row)) in out_mama_uninit
                .chunks_mut(cols)
                .zip(out_fama_uninit.chunks_mut(cols))
                .enumerate()
            {
                do_row(row, m_row, f_row);
            }
        }
    } else {
        for (row, (m_row, f_row)) in out_mama_uninit
            .chunks_mut(cols)
            .zip(out_fama_uninit.chunks_mut(cols))
            .enumerate()
        {
            do_row(row, m_row, f_row);
        }
    }

    Ok(combos)
}
1930
/// Scalar row kernel: computes one (fast, slow) combo over `data`.
///
/// # Safety
/// Thin delegation to `mama_scalar_inplace`; the caller must uphold that
/// function's contract — presumably `out_mama`/`out_fama` sized to
/// `data.len()`. Its definition is outside this chunk; TODO confirm.
#[inline(always)]
pub unsafe fn mama_row_scalar(
    data: &[f64],
    fast_limit: f64,
    slow_limit: f64,
    out_mama: &mut [f64],
    out_fama: &mut [f64],
) {
    mama_scalar_inplace(data, fast_limit, slow_limit, out_mama, out_fama);
}
1941
/// AVX2 row kernel (nightly-avx feature, x86_64 only).
///
/// # Safety
/// Thin delegation to `mama_avx2_inplace`; the caller must uphold that
/// function's contract (output slices sized to `data.len()`, AVX2 available
/// at runtime — TODO confirm against its definition, outside this chunk).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn mama_row_avx2(
    data: &[f64],
    fast_limit: f64,
    slow_limit: f64,
    out_mama: &mut [f64],
    out_fama: &mut [f64],
) {
    mama_avx2_inplace(data, fast_limit, slow_limit, out_mama, out_fama);
}
1953
/// AVX-512 row kernel (nightly-avx feature, x86_64 only).
///
/// # Safety
/// Thin delegation to `mama_avx512_inplace`; the caller must uphold that
/// function's contract (output slices sized to `data.len()`, AVX-512
/// available at runtime — TODO confirm against its definition, outside this
/// chunk).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn mama_row_avx512(
    data: &[f64],
    fast_limit: f64,
    slow_limit: f64,
    out_mama: &mut [f64],
    out_fama: &mut [f64],
) {
    mama_avx512_inplace(data, fast_limit, slow_limit, out_mama, out_fama);
}
1965
1966#[cfg(test)]
1967mod tests {
1968 use super::*;
1969 use crate::skip_if_unsupported;
1970 use crate::utilities::data_loader::read_candles_from_csv;
1971 use paste::paste;
1972 use proptest::prelude::*;
1973
1974 fn check_mama_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1975 skip_if_unsupported!(kernel, test_name);
1976 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1977 let candles = read_candles_from_csv(file_path)?;
1978 let default_params = MamaParams {
1979 fast_limit: None,
1980 slow_limit: None,
1981 };
1982 let input = MamaInput::from_candles(&candles, "close", default_params);
1983 let output = mama_with_kernel(&input, kernel)?;
1984 assert_eq!(output.mama_values.len(), candles.close.len());
1985 assert_eq!(output.fama_values.len(), candles.close.len());
1986 Ok(())
1987 }
1988
1989 fn check_mama_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1990 skip_if_unsupported!(kernel, test_name);
1991 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1992 let candles = read_candles_from_csv(file_path)?;
1993 let input = MamaInput::from_candles(&candles, "close", MamaParams::default());
1994 let result = mama_with_kernel(&input, kernel)?;
1995 assert_eq!(result.mama_values.len(), candles.close.len());
1996 assert_eq!(result.fama_values.len(), candles.close.len());
1997 Ok(())
1998 }
1999
2000 fn check_mama_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2001 skip_if_unsupported!(kernel, test_name);
2002 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2003 let candles = read_candles_from_csv(file_path)?;
2004 let input = MamaInput::with_default_candles(&candles);
2005 match input.data {
2006 MamaData::Candles { source, .. } => assert_eq!(source, "close"),
2007 _ => panic!("Expected MamaData::Candles"),
2008 }
2009 let output = mama_with_kernel(&input, kernel)?;
2010 assert_eq!(output.mama_values.len(), candles.close.len());
2011 assert_eq!(output.fama_values.len(), candles.close.len());
2012 Ok(())
2013 }
2014
2015 fn check_mama_with_insufficient_data(
2016 test_name: &str,
2017 kernel: Kernel,
2018 ) -> Result<(), Box<dyn Error>> {
2019 skip_if_unsupported!(kernel, test_name);
2020 let input_data = [100.0; 9];
2021 let params = MamaParams::default();
2022 let input = MamaInput::from_slice(&input_data, params);
2023 let res = mama_with_kernel(&input, kernel);
2024 assert!(res.is_err());
2025 Ok(())
2026 }
2027
2028 fn check_mama_very_small_dataset(
2029 test_name: &str,
2030 kernel: Kernel,
2031 ) -> Result<(), Box<dyn Error>> {
2032 skip_if_unsupported!(kernel, test_name);
2033 let input_data = [42.0; 10];
2034 let params = MamaParams::default();
2035 let input = MamaInput::from_slice(&input_data, params);
2036 let result = mama_with_kernel(&input, kernel)?;
2037 assert_eq!(result.mama_values.len(), input_data.len());
2038 assert_eq!(result.fama_values.len(), input_data.len());
2039 Ok(())
2040 }
2041
2042 fn check_mama_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2043 skip_if_unsupported!(kernel, test_name);
2044 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2045 let candles = read_candles_from_csv(file_path)?;
2046 let first_params = MamaParams::default();
2047 let first_input = MamaInput::from_candles(&candles, "close", first_params);
2048 let first_result = mama_with_kernel(&first_input, kernel)?;
2049 let second_params = MamaParams {
2050 fast_limit: Some(0.7),
2051 slow_limit: Some(0.1),
2052 };
2053 let second_input = MamaInput::from_slice(&first_result.mama_values, second_params);
2054 let second_result = mama_with_kernel(&second_input, kernel)?;
2055 assert_eq!(
2056 second_result.mama_values.len(),
2057 first_result.mama_values.len()
2058 );
2059 assert_eq!(
2060 second_result.fama_values.len(),
2061 first_result.mama_values.len()
2062 );
2063 Ok(())
2064 }
2065
2066 fn check_mama_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2067 skip_if_unsupported!(kernel, test_name);
2068 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2069 let candles = read_candles_from_csv(file_path)?;
2070 let params = MamaParams::default();
2071 let input = MamaInput::from_candles(&candles, "close", params);
2072 let result = mama_with_kernel(&input, kernel)?;
2073 for (i, &val) in result.mama_values.iter().enumerate() {
2074 if i > 20 {
2075 assert!(val.is_finite());
2076 }
2077 }
2078 for (i, &val) in result.fama_values.iter().enumerate() {
2079 if i > 20 {
2080 assert!(val.is_finite());
2081 }
2082 }
2083 Ok(())
2084 }
2085
    // Generates `#[test]` wrappers for each checker function passed in:
    // a `_scalar_f64` variant unconditionally, plus `_avx2_f64` and
    // `_avx512_f64` variants when the nightly-avx feature is enabled on
    // x86_64. Each wrapper forwards its own generated name (for skip
    // logging) and the matching Kernel to the checker, discarding the
    // Result.
    macro_rules! generate_all_mama_tests {
        ($($test_fn:ident),*) => {
            paste! {
                $(
                    #[test]
                    fn [<$test_fn _scalar_f64>]() {
                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
                    }
                )*
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                $(
                    #[test]
                    fn [<$test_fn _avx2_f64>]() {
                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
                    }
                    #[test]
                    fn [<$test_fn _avx512_f64>]() {
                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
                    }
                )*
            }
        }
    }
2109
2110 #[cfg(debug_assertions)]
2111 fn check_mama_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2112 skip_if_unsupported!(kernel, test_name);
2113
2114 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2115 let candles = read_candles_from_csv(file_path)?;
2116
2117 let test_cases = vec![
2118 MamaParams::default(),
2119 MamaParams {
2120 fast_limit: Some(0.3),
2121 slow_limit: Some(0.03),
2122 },
2123 MamaParams {
2124 fast_limit: Some(0.4),
2125 slow_limit: Some(0.04),
2126 },
2127 MamaParams {
2128 fast_limit: Some(0.5),
2129 slow_limit: Some(0.05),
2130 },
2131 MamaParams {
2132 fast_limit: Some(0.6),
2133 slow_limit: Some(0.06),
2134 },
2135 MamaParams {
2136 fast_limit: Some(0.7),
2137 slow_limit: Some(0.07),
2138 },
2139 MamaParams {
2140 fast_limit: Some(0.8),
2141 slow_limit: Some(0.01),
2142 },
2143 MamaParams {
2144 fast_limit: Some(0.2),
2145 slow_limit: Some(0.1),
2146 },
2147 MamaParams {
2148 fast_limit: Some(0.9),
2149 slow_limit: Some(0.02),
2150 },
2151 ];
2152
2153 for params in test_cases {
2154 let input = MamaInput::from_candles(&candles, "close", params.clone());
2155 let output = mama_with_kernel(&input, kernel)?;
2156
2157 for (i, &val) in output.mama_values.iter().enumerate() {
2158 if val.is_nan() {
2159 continue;
2160 }
2161
2162 let bits = val.to_bits();
2163
2164 if bits == 0x11111111_11111111 {
2165 panic!(
2166 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} in mama_values with params fast_limit={:?}, slow_limit={:?}",
2167 test_name, val, bits, i, params.fast_limit, params.slow_limit
2168 );
2169 }
2170
2171 if bits == 0x22222222_22222222 {
2172 panic!(
2173 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} in mama_values with params fast_limit={:?}, slow_limit={:?}",
2174 test_name, val, bits, i, params.fast_limit, params.slow_limit
2175 );
2176 }
2177
2178 if bits == 0x33333333_33333333 {
2179 panic!(
2180 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} in mama_values with params fast_limit={:?}, slow_limit={:?}",
2181 test_name, val, bits, i, params.fast_limit, params.slow_limit
2182 );
2183 }
2184 }
2185
2186 for (i, &val) in output.fama_values.iter().enumerate() {
2187 if val.is_nan() {
2188 continue;
2189 }
2190
2191 let bits = val.to_bits();
2192
2193 if bits == 0x11111111_11111111 {
2194 panic!(
2195 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} in fama_values with params fast_limit={:?}, slow_limit={:?}",
2196 test_name, val, bits, i, params.fast_limit, params.slow_limit
2197 );
2198 }
2199
2200 if bits == 0x22222222_22222222 {
2201 panic!(
2202 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} in fama_values with params fast_limit={:?}, slow_limit={:?}",
2203 test_name, val, bits, i, params.fast_limit, params.slow_limit
2204 );
2205 }
2206
2207 if bits == 0x33333333_33333333 {
2208 panic!(
2209 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} in fama_values with params fast_limit={:?}, slow_limit={:?}",
2210 test_name, val, bits, i, params.fast_limit, params.slow_limit
2211 );
2212 }
2213 }
2214 }
2215
2216 Ok(())
2217 }
2218
    // Release-build stub: the allocation helpers only write poison patterns
    // in debug builds, so there is nothing to scan for here.
    #[cfg(not(debug_assertions))]
    fn check_mama_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
2223
2224 fn check_mama_property(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2225 skip_if_unsupported!(kernel, test_name);
2226
2227 let strat = (10usize..=200).prop_flat_map(|len| {
2228 (
2229 prop::collection::vec(
2230 (-1e5f64..1e5f64).prop_filter("finite", |x| x.is_finite()),
2231 len,
2232 ),
2233 (0.01f64..0.99f64).prop_filter("valid fast_limit", |x| x.is_finite() && *x > 0.0),
2234 (0.001f64..0.5f64).prop_filter("valid slow_limit", |x| x.is_finite() && *x > 0.0),
2235 )
2236 });
2237
2238 proptest::test_runner::TestRunner::default()
2239 .run(&strat, |(data, fast_limit, slow_limit)| {
2240 let slow = slow_limit.min(fast_limit * 0.9);
2241
2242 let params = MamaParams {
2243 fast_limit: Some(fast_limit),
2244 slow_limit: Some(slow),
2245 };
2246 let input = MamaInput::from_slice(&data, params);
2247
2248 let result = mama_with_kernel(&input, kernel).unwrap();
2249 let mama_out = &result.mama_values;
2250 let fama_out = &result.fama_values;
2251
2252 let ref_result = mama_with_kernel(&input, Kernel::Scalar).unwrap();
2253 let ref_mama = &ref_result.mama_values;
2254 let ref_fama = &ref_result.fama_values;
2255
2256 prop_assert_eq!(mama_out.len(), data.len(), "MAMA output length mismatch");
2257 prop_assert_eq!(fama_out.len(), data.len(), "FAMA output length mismatch");
2258
2259 const WARMUP: usize = 10;
2260 for i in 0..data.len() {
2261 if i < WARMUP {
2262 prop_assert!(
2263 mama_out[i].is_nan(),
2264 "MAMA should have NaN warmup at index {}, got {}",
2265 i,
2266 mama_out[i]
2267 );
2268 prop_assert!(
2269 fama_out[i].is_nan(),
2270 "FAMA should have NaN warmup at index {}, got {}",
2271 i,
2272 fama_out[i]
2273 );
2274 } else {
2275 prop_assert!(
2276 mama_out[i].is_finite(),
2277 "MAMA should output finite values at index {}, got {}",
2278 i,
2279 mama_out[i]
2280 );
2281 prop_assert!(
2282 fama_out[i].is_finite(),
2283 "FAMA should output finite values at index {}, got {}",
2284 i,
2285 fama_out[i]
2286 );
2287 }
2288 }
2289
2290 let data_min = data.iter().cloned().fold(f64::INFINITY, f64::min);
2291 let data_max = data.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
2292 let data_range = data_max - data_min;
2293
2294 let tolerance = data_range * 0.2 + 10.0;
2295
2296 for i in WARMUP..data.len() {
2297 prop_assert!(
2298 mama_out[i] >= data_min - tolerance && mama_out[i] <= data_max + tolerance,
2299 "MAMA at index {} ({}) outside bounds [{}, {}]",
2300 i,
2301 mama_out[i],
2302 data_min - tolerance,
2303 data_max + tolerance
2304 );
2305 prop_assert!(
2306 fama_out[i] >= data_min - tolerance && fama_out[i] <= data_max + tolerance,
2307 "FAMA at index {} ({}) outside bounds [{}, {}]",
2308 i,
2309 fama_out[i],
2310 data_min - tolerance,
2311 data_max + tolerance
2312 );
2313 }
2314
2315 if data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-9) {
2316 let constant_val = data[0];
2317
2318 for i in 10..data.len() {
2319 prop_assert!(
2320 (mama_out[i] - constant_val).abs() < 1e-6,
2321 "MAMA should converge to constant value {} at index {}, got {}",
2322 constant_val,
2323 i,
2324 mama_out[i]
2325 );
2326 prop_assert!(
2327 (fama_out[i] - constant_val).abs() < 1e-6,
2328 "FAMA should converge to constant value {} at index {}, got {}",
2329 constant_val,
2330 i,
2331 fama_out[i]
2332 );
2333 }
2334 }
2335
2336 if data.len() > 30 {
2337 let mama_variance = variance(&mama_out[10..]);
2338 let fama_variance = variance(&fama_out[10..]);
2339
2340 prop_assert!(
2341 mama_variance >= 0.0 && mama_variance.is_finite(),
2342 "MAMA variance should be finite and non-negative: {}",
2343 mama_variance
2344 );
2345 prop_assert!(
2346 fama_variance >= 0.0 && fama_variance.is_finite(),
2347 "FAMA variance should be finite and non-negative: {}",
2348 fama_variance
2349 );
2350
2351 let data_variance = variance(&data);
2352 if data_variance > 1e-6 {
2353 prop_assert!(
2354 mama_variance < data_variance * 100.0,
2355 "MAMA variance ({}) too large relative to data variance ({})",
2356 mama_variance,
2357 data_variance
2358 );
2359 prop_assert!(
2360 fama_variance < data_variance * 100.0,
2361 "FAMA variance ({}) too large relative to data variance ({})",
2362 fama_variance,
2363 data_variance
2364 );
2365 }
2366 }
2367
2368 for i in WARMUP..data.len() {
2369 prop_assert!(
2370 mama_out[i].is_finite(),
2371 "MAMA kernel {:?} produced non-finite value at idx {}: {}",
2372 kernel,
2373 i,
2374 mama_out[i]
2375 );
2376 prop_assert!(
2377 fama_out[i].is_finite(),
2378 "FAMA kernel {:?} produced non-finite value at idx {}: {}",
2379 kernel,
2380 i,
2381 fama_out[i]
2382 );
2383 }
2384
2385 if data.len() > 50 && fast_limit > slow * 2.0 && variance(&data) > 1e-6 {
2386 let alt_params = MamaParams {
2387 fast_limit: Some(fast_limit * 0.5),
2388 slow_limit: Some(slow),
2389 };
2390 let alt_input = MamaInput::from_slice(&data, alt_params);
2391 if let Ok(alt_result) = mama_with_kernel(&alt_input, kernel) {
2392 let mama_var = variance(&mama_out[20..]);
2393 let alt_var = variance(&alt_result.mama_values[20..]);
2394
2395 if mama_var > 1e-6 && alt_var > 1e-6 {
2396 prop_assert!(
2397 (mama_var - alt_var).abs() > 1e-12,
2398 "MAMA should be sensitive to fast_limit parameter"
2399 );
2400 }
2401 }
2402 }
2403
2404 if (fast_limit - slow).abs() < 0.01 && data.len() > 20 {
2405 for i in 10..data.len() {
2406 prop_assert!(
2407 mama_out[i].is_finite() && fama_out[i].is_finite(),
2408 "MAMA/FAMA should remain finite even with close limits at idx {}",
2409 i
2410 );
2411
2412 prop_assert!(
2413 mama_out[i].abs() < data_max.abs() * 100.0 + 1000.0,
2414 "MAMA should not diverge with close limits"
2415 );
2416 prop_assert!(
2417 fama_out[i].abs() < data_max.abs() * 100.0 + 1000.0,
2418 "FAMA should not diverge with close limits"
2419 );
2420 }
2421 }
2422
2423 let is_monotonic_inc = data.windows(2).all(|w| w[1] >= w[0] - 1e-9);
2424 let is_monotonic_dec = data.windows(2).all(|w| w[1] <= w[0] + 1e-9);
2425
2426 if (is_monotonic_inc || is_monotonic_dec) && data.len() > 20 {
2427 for i in 11..data.len() {
2428 if is_monotonic_inc {
2429 prop_assert!(
2430 mama_out[i] >= mama_out[i - 10] - tolerance * 0.1,
2431 "MAMA should follow increasing trend at idx {}",
2432 i
2433 );
2434 }
2435 if is_monotonic_dec {
2436 prop_assert!(
2437 mama_out[i] <= mama_out[i - 10] + tolerance * 0.1,
2438 "MAMA should follow decreasing trend at idx {}",
2439 i
2440 );
2441 }
2442 }
2443 }
2444
2445 Ok(())
2446 })
2447 .unwrap();
2448
2449 Ok(())
2450 }
2451
2452 fn variance(data: &[f64]) -> f64 {
2453 if data.is_empty() {
2454 return 0.0;
2455 }
2456 let mean = data.iter().sum::<f64>() / data.len() as f64;
2457 data.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / data.len() as f64
2458 }
2459
    // Expand each listed check function into its per-kernel #[test] variants
    // (scalar / AVX2 / AVX-512 / auto-detect) via the shared test macro.
    generate_all_mama_tests!(
        check_mama_partial_params,
        check_mama_accuracy,
        check_mama_default_candles,
        check_mama_with_insufficient_data,
        check_mama_very_small_dataset,
        check_mama_reinput,
        check_mama_nan_handling,
        check_mama_no_poison,
        check_mama_property
    );
2471
2472 fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2473 skip_if_unsupported!(kernel, test);
2474 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2475 let c = read_candles_from_csv(file)?;
2476 let output = MamaBatchBuilder::new()
2477 .kernel(kernel)
2478 .apply_candles(&c, "close")?;
2479 let def = MamaParams::default();
2480 let mama_row = output.mama_for(&def).expect("default row missing");
2481 assert_eq!(mama_row.len(), c.close.len());
2482 Ok(())
2483 }
2484
    /// For one batch check function, emit `#[test]` wrappers covering every
    /// batch kernel variant: scalar, AVX2/AVX-512 (behind `nightly-avx` on
    /// x86_64), and auto-detection. Test names are derived via `paste!`.
    macro_rules! gen_batch_tests {
        ($fn_name:ident) => {
            paste! {
                #[test] fn [<$fn_name _scalar>]() {
                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx2>]() {
                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx512>]() {
                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
                }
                #[test] fn [<$fn_name _auto_detect>]() {
                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
                }
            }
        };
    }
2505
2506 #[cfg(debug_assertions)]
2507 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2508 skip_if_unsupported!(kernel, test);
2509
2510 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2511 let c = read_candles_from_csv(file)?;
2512
2513 let test_configs = vec![
2514 ((0.2, 0.4, 0.1), (0.02, 0.04, 0.01)),
2515 ((0.3, 0.7, 0.2), (0.03, 0.07, 0.02)),
2516 ((0.4, 0.9, 0.1), (0.01, 0.09, 0.02)),
2517 ((0.5, 0.8, 0.15), (0.01, 0.03, 0.01)),
2518 ((0.2, 0.6, 0.05), (0.02, 0.08, 0.01)),
2519 ];
2520
2521 for (fast_range, slow_range) in test_configs {
2522 let output = MamaBatchBuilder::new()
2523 .kernel(kernel)
2524 .fast_limit_range(fast_range.0, fast_range.1, fast_range.2)
2525 .slow_limit_range(slow_range.0, slow_range.1, slow_range.2)
2526 .apply_candles(&c, "close")?;
2527
2528 for (idx, &val) in output.mama_values.iter().enumerate() {
2529 if val.is_nan() {
2530 continue;
2531 }
2532
2533 let bits = val.to_bits();
2534 let row = idx / output.cols;
2535 let col = idx % output.cols;
2536 let params = &output.combos[row];
2537
2538 if bits == 0x11111111_11111111 {
2539 panic!(
2540 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at row {} col {} in mama_values (params: fast_limit={:?}, slow_limit={:?})",
2541 test, val, bits, row, col, params.fast_limit, params.slow_limit
2542 );
2543 }
2544
2545 if bits == 0x22222222_22222222 {
2546 panic!(
2547 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at row {} col {} in mama_values (params: fast_limit={:?}, slow_limit={:?})",
2548 test, val, bits, row, col, params.fast_limit, params.slow_limit
2549 );
2550 }
2551
2552 if bits == 0x33333333_33333333 {
2553 panic!(
2554 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at row {} col {} in mama_values (params: fast_limit={:?}, slow_limit={:?})",
2555 test, val, bits, row, col, params.fast_limit, params.slow_limit
2556 );
2557 }
2558 }
2559
2560 for (idx, &val) in output.fama_values.iter().enumerate() {
2561 if val.is_nan() {
2562 continue;
2563 }
2564
2565 let bits = val.to_bits();
2566 let row = idx / output.cols;
2567 let col = idx % output.cols;
2568 let params = &output.combos[row];
2569
2570 if bits == 0x11111111_11111111 {
2571 panic!(
2572 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at row {} col {} in fama_values (params: fast_limit={:?}, slow_limit={:?})",
2573 test, val, bits, row, col, params.fast_limit, params.slow_limit
2574 );
2575 }
2576
2577 if bits == 0x22222222_22222222 {
2578 panic!(
2579 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at row {} col {} in fama_values (params: fast_limit={:?}, slow_limit={:?})",
2580 test, val, bits, row, col, params.fast_limit, params.slow_limit
2581 );
2582 }
2583
2584 if bits == 0x33333333_33333333 {
2585 panic!(
2586 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at row {} col {} in fama_values (params: fast_limit={:?}, slow_limit={:?})",
2587 test, val, bits, row, col, params.fast_limit, params.slow_limit
2588 );
2589 }
2590 }
2591 }
2592
2593 Ok(())
2594 }
2595
    /// Release-build stub: the poison patterns are only written by the
    /// debug-only fill helpers, so there is nothing to scan for here.
    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
2600
2601 #[test]
2602 fn test_mama_into_matches_api() -> Result<(), Box<dyn Error>> {
2603 let n = 256usize;
2604 let data: Vec<f64> = (0..n)
2605 .map(|i| {
2606 let t = i as f64;
2607 (t * 0.013).sin() * 10.0 + (t * 0.01)
2608 })
2609 .collect();
2610
2611 let input = MamaInput::from_slice(&data, MamaParams::default());
2612
2613 let baseline = mama(&input)?;
2614
2615 let mut out_mama = vec![0.0; n];
2616 let mut out_fama = vec![0.0; n];
2617 #[allow(unused_variables)]
2618 {
2619 #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
2620 {
2621 super::mama_into(&input, &mut out_mama, &mut out_fama)?;
2622 }
2623 #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2624 {
2625 super::mama_into_slice(&mut out_mama, &mut out_fama, &input, Kernel::Auto)?;
2626 }
2627 }
2628
2629 fn eq_or_both_nan(a: f64, b: f64) -> bool {
2630 (a.is_nan() && b.is_nan()) || (a == b)
2631 }
2632
2633 assert_eq!(baseline.mama_values.len(), out_mama.len());
2634 assert_eq!(baseline.fama_values.len(), out_fama.len());
2635 for i in 0..n {
2636 assert!(
2637 eq_or_both_nan(baseline.mama_values[i], out_mama[i]),
2638 "mama mismatch at {}: left={} right={}",
2639 i,
2640 baseline.mama_values[i],
2641 out_mama[i]
2642 );
2643 assert!(
2644 eq_or_both_nan(baseline.fama_values[i], out_fama[i]),
2645 "fama mismatch at {}: left={} right={}",
2646 i,
2647 baseline.fama_values[i],
2648 out_fama[i]
2649 );
2650 }
2651 Ok(())
2652 }
2653
    // Instantiate the per-kernel #[test] wrappers for the batch checks above.
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
2656}
2657
2658#[cfg(feature = "python")]
2659mod python_bindings {
2660 use super::*;
2661 #[cfg(feature = "cuda")]
2662 use crate::cuda::cuda_available;
2663 #[cfg(feature = "cuda")]
2664 use crate::cuda::moving_averages::{CudaMama, DeviceMamaPair};
2665 use crate::utilities::kernel_validation::validate_kernel;
2666 #[cfg(feature = "cuda")]
2667 use cust::context::Context;
2668 #[cfg(feature = "cuda")]
2669 use cust::memory::DeviceBuffer;
2670 #[cfg(feature = "cuda")]
2671 use numpy::PyReadonlyArray2;
2672 use numpy::{IntoPyArray, PyArray1, PyArray2, PyArrayMethods, PyReadonlyArray1};
2673 use pyo3::exceptions::PyValueError;
2674 use pyo3::prelude::*;
2675 use pyo3::types::PyDictMethods;
2676 #[cfg(feature = "cuda")]
2677 use std::os::raw::c_void;
2678 #[cfg(feature = "cuda")]
2679 use std::sync::Arc;
2680
2681 use pyo3::types::PyDict;
2682 use pyo3::{pyclass, pymethods};
2683 use std::collections::HashMap;
2684
    /// Python entry point: compute MAMA/FAMA for a 1-D contiguous f64 array,
    /// returning a `(mama, fama)` pair of numpy arrays of the same length.
    ///
    /// Output arrays are allocated up front and filled in place while the GIL
    /// is released.
    #[pyfunction]
    #[pyo3(name = "mama")]
    #[pyo3(signature = (data, fast_limit, slow_limit, kernel=None))]
    pub fn mama_py<'py>(
        py: Python<'py>,
        data: PyReadonlyArray1<'py, f64>,
        fast_limit: f64,
        slow_limit: f64,
        kernel: Option<&str>,
    ) -> PyResult<(Bound<'py, PyArray1<f64>>, Bound<'py, PyArray1<f64>>)> {
        let slice_in = data.as_slice()?;
        let params = MamaParams {
            fast_limit: Some(fast_limit),
            slow_limit: Some(slow_limit),
        };
        let input = MamaInput::from_slice(slice_in, params);
        // Second flag presumably selects batch-kernel validation; it is
        // `false` here and `true` in `mama_batch_py` — confirm in
        // `validate_kernel` if touching this.
        let kern = validate_kernel(kernel, false)?;

        let len = slice_in.len();

        // SAFETY: the arrays are created uninitialised; `mama_into_slice`
        // below writes the full length (or the call errors before returning).
        let out_m = unsafe { PyArray1::<f64>::new(py, [len], false) };
        let out_f = unsafe { PyArray1::<f64>::new(py, [len], false) };
        let sm = unsafe { out_m.as_slice_mut()? };
        let sf = unsafe { out_f.as_slice_mut()? };

        // Release the GIL for the numeric work; the slices point into the
        // numpy buffers allocated above.
        py.allow_threads(|| mama_into_slice(sm, sf, &input, kern))
            .map_err(|e| PyValueError::new_err(e.to_string()))?;

        Ok((out_m, out_f))
    }
2715
    /// Python entry point: run MAMA over every (fast_limit, slow_limit) combo
    /// expanded from the given `(start, end, step)` ranges.
    ///
    /// Returns a dict with 2-D `mama`/`fama` arrays of shape
    /// `(n_combos, len(data))` plus 1-D `fast_limits`/`slow_limits` axes.
    #[pyfunction]
    #[pyo3(name = "mama_batch")]
    #[pyo3(signature = (data, fast_limit_range, slow_limit_range, kernel=None))]
    pub fn mama_batch_py<'py>(
        py: Python<'py>,
        data: PyReadonlyArray1<'py, f64>,
        fast_limit_range: (f64, f64, f64),
        slow_limit_range: (f64, f64, f64),
        kernel: Option<&str>,
    ) -> PyResult<Bound<'py, PyDict>> {
        let slice_in = data.as_slice()?;
        let sweep = MamaBatchRange {
            fast_limit: fast_limit_range,
            slow_limit: slow_limit_range,
        };

        // Expand the grid once up front so the flat output arrays can be
        // sized before the GIL is released.
        let combos = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let rows = combos.len();
        let cols = slice_in.len();
        let total = rows
            .checked_mul(cols)
            .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;

        // SAFETY: created uninitialised; `mama_batch_inner_into` is expected
        // to fill every element before the arrays are returned to Python.
        let mama_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
        let fama_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
        let mama_slice = unsafe { mama_arr.as_slice_mut()? };
        let fama_slice = unsafe { fama_arr.as_slice_mut()? };

        let kern = validate_kernel(kernel, true)?;

        // Shadowing is intentional: the inner call returns the combos it
        // actually computed, which feed the metadata axes below.
        let combos = py
            .allow_threads(|| -> Result<Vec<MamaParams>, MamaError> {
                // Map the requested batch kernel to its single-series SIMD
                // flavour for the per-row computations.
                let simd = match kern {
                    Kernel::Auto | Kernel::ScalarBatch => Kernel::Scalar,
                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    Kernel::Avx512Batch => Kernel::Avx512,
                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    Kernel::Avx2Batch => Kernel::Avx2,

                    _ => Kernel::Scalar,
                };

                mama_batch_inner_into(slice_in, &sweep, simd, true, mama_slice, fama_slice)
            })
            .map_err(|e| PyValueError::new_err(e.to_string()))?;

        let dict = PyDict::new(py);
        dict.set_item("mama", mama_arr.reshape((rows, cols))?)?;
        dict.set_item("fama", fama_arr.reshape((rows, cols))?)?;
        dict.set_item(
            "fast_limits",
            combos
                .iter()
                .map(|p| p.fast_limit.unwrap_or(0.5))
                .collect::<Vec<_>>()
                .into_pyarray(py),
        )?;
        dict.set_item(
            "slow_limits",
            combos
                .iter()
                .map(|p| p.slow_limit.unwrap_or(0.05))
                .collect::<Vec<_>>()
                .into_pyarray(py),
        )?;

        Ok(dict)
    }
2784
    /// CUDA batch sweep: computes MAMA/FAMA for the whole parameter grid on
    /// the GPU and returns two device-resident f32 matrices (no host copy).
    #[cfg(feature = "cuda")]
    #[pyfunction(name = "mama_cuda_batch_dev")]
    #[pyo3(signature = (data_f32, fast_limit_range, slow_limit_range, device_id=0))]
    pub fn mama_cuda_batch_dev_py(
        py: Python<'_>,
        data_f32: PyReadonlyArray1<'_, f32>,
        fast_limit_range: (f64, f64, f64),
        slow_limit_range: (f64, f64, f64),
        device_id: usize,
    ) -> PyResult<(DeviceArrayF32Py, DeviceArrayF32Py)> {
        if !cuda_available() {
            return Err(PyValueError::new_err("CUDA not available"));
        }

        let slice_in = data_f32.as_slice()?;
        let sweep = MamaBatchRange {
            fast_limit: fast_limit_range,
            slow_limit: slow_limit_range,
        };

        // Run without the GIL; keep the context Arc and device id so the
        // returned buffers outlive the temporary CudaMama instance.
        let (pair, ctx, dev_id) = py.allow_threads(|| {
            let cuda =
                CudaMama::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
            let ctx = cuda.context_arc();
            let dev_id = cuda.device_id();
            let pair = cuda
                .mama_batch_dev(slice_in, &sweep)
                .map_err(|e| PyValueError::new_err(e.to_string()))?;
            Ok::<_, PyErr>((pair, ctx, dev_id))
        })?;

        let DeviceMamaPair { mama, fama } = pair;
        Ok((
            DeviceArrayF32Py {
                buf: Some(mama.buf),
                rows: mama.rows,
                cols: mama.cols,
                _ctx: ctx.clone(),
                device_id: dev_id,
            },
            DeviceArrayF32Py {
                buf: Some(fama.buf),
                rows: fama.rows,
                cols: fama.cols,
                _ctx: ctx,
                device_id: dev_id,
            },
        ))
    }
2834
    /// Single-parameter MAMA across many series stored time-major in a 2-D
    /// f32 array, computed on the GPU; results stay resident on the device.
    #[cfg(feature = "cuda")]
    #[pyfunction(name = "mama_cuda_many_series_one_param_dev")]
    #[pyo3(signature = (data_tm_f32, fast_limit, slow_limit, device_id=0))]
    pub fn mama_cuda_many_series_one_param_dev_py(
        py: Python<'_>,
        data_tm_f32: PyReadonlyArray2<'_, f32>,
        fast_limit: f64,
        slow_limit: f64,
        device_id: usize,
    ) -> PyResult<(DeviceArrayF32Py, DeviceArrayF32Py)> {
        use numpy::PyUntypedArrayMethods;

        if !cuda_available() {
            return Err(PyValueError::new_err("CUDA not available"));
        }

        let shape = data_tm_f32.shape();
        if shape.len() != 2 {
            return Err(PyValueError::new_err("expected 2D array"));
        }
        let rows = shape[0];
        let cols = shape[1];
        let flat = data_tm_f32.as_slice()?;

        // The device kernel works in f32; narrow the limits once here.
        let fast = fast_limit as f32;
        let slow = slow_limit as f32;

        let (pair, ctx, dev_id) = py.allow_threads(|| {
            let cuda =
                CudaMama::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
            let ctx = cuda.context_arc();
            let dev_id = cuda.device_id();
            // NOTE(review): arguments are passed as (flat, cols, rows), i.e.
            // the array's second dimension first — confirm this matches the
            // device API's expected (series, timesteps) order for time-major.
            let pair = cuda
                .mama_many_series_one_param_time_major_dev(flat, cols, rows, fast, slow)
                .map_err(|e| PyValueError::new_err(e.to_string()))?;
            Ok::<_, PyErr>((pair, ctx, dev_id))
        })?;

        let DeviceMamaPair { mama, fama } = pair;
        Ok((
            DeviceArrayF32Py {
                buf: Some(mama.buf),
                rows: mama.rows,
                cols: mama.cols,
                _ctx: ctx.clone(),
                device_id: dev_id,
            },
            DeviceArrayF32Py {
                buf: Some(fama.buf),
                rows: fama.rows,
                cols: fama.cols,
                _ctx: ctx,
                device_id: dev_id,
            },
        ))
    }
2891
    /// Streaming MAMA state machine exposed to Python as `MamaStream`.
    #[pyclass]
    #[pyo3(name = "MamaStream")]
    pub struct MamaStreamPy {
        inner: MamaStream,
    }
2897
    #[pymethods]
    impl MamaStreamPy {
        /// Build a stream with fixed fast/slow limits; parameter validation
        /// failures surface as `ValueError`.
        #[new]
        pub fn new(fast_limit: f64, slow_limit: f64) -> PyResult<Self> {
            let params = MamaParams {
                fast_limit: Some(fast_limit),
                slow_limit: Some(slow_limit),
            };
            let stream =
                MamaStream::try_new(params).map_err(|e| PyValueError::new_err(e.to_string()))?;
            Ok(Self { inner: stream })
        }

        /// Push one value; returns `(mama, fama)` when available, `None`
        /// otherwise (presumably during warm-up — see `MamaStream::update`).
        pub fn update(&mut self, value: f64) -> Option<(f64, f64)> {
            self.inner.update(value)
        }
    }
2915}
2916
2917#[cfg(feature = "python")]
2918pub use python_bindings::{mama_batch_py, mama_py, MamaStreamPy};
2919#[cfg(all(feature = "python", feature = "cuda"))]
2920pub use python_bindings::{mama_cuda_batch_dev_py, mama_cuda_many_series_one_param_dev_py};
2921
2922#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2923use serde::{Deserialize, Serialize};
2924
/// JS-facing single-run result: `values` holds the MAMA row followed by the
/// FAMA row (`rows` = 2, `cols` = input length).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct MamaResult {
    pub values: Vec<f64>, // row 0 = MAMA, row 1 = FAMA
    pub rows: usize,
    pub cols: usize,
}
2932
/// Compute MAMA/FAMA for `data` and return a serialized 2×len `MamaResult`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "mama")]
pub fn mama_js(data: &[f64], fast_limit: f64, slow_limit: f64) -> Result<JsValue, JsValue> {
    let len = data.len();
    let input = MamaInput::from_slice(
        data,
        MamaParams {
            fast_limit: Some(fast_limit),
            slow_limit: Some(slow_limit),
        },
    );

    let mut mama_row = vec![0.0; len];
    let mut fama_row = vec![0.0; len];
    mama_into_slice(&mut mama_row, &mut fama_row, &input, detect_best_kernel())
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    // Stack the two rows into one buffer: MAMA first, then FAMA.
    let mut values = mama_row;
    values.extend_from_slice(&fama_row);

    serde_wasm_bindgen::to_value(&MamaResult {
        values,
        rows: 2,
        cols: len,
    })
    .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2958
/// Zero-copy MAMA/FAMA into caller-provided wasm buffers.
///
/// # Safety (JS contract)
/// Each pointer must reference at least `len` f64s in wasm linear memory
/// (e.g. obtained via `mama_alloc`) and the three regions must not overlap.
/// Null pointers are rejected; lengths and aliasing cannot be checked here.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "mama_into")]
pub fn mama_into(
    in_ptr: *const f64,
    out_m_ptr: *mut f64,
    out_f_ptr: *mut f64,
    len: usize,
    fast_limit: f64,
    slow_limit: f64,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_m_ptr.is_null() || out_f_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to mama_into"));
    }
    unsafe {
        // SAFETY: caller guarantees each pointer covers `len` valid f64s.
        let data = core::slice::from_raw_parts(in_ptr, len);
        let out_m = core::slice::from_raw_parts_mut(out_m_ptr, len);
        let out_f = core::slice::from_raw_parts_mut(out_f_ptr, len);
        let params = MamaParams {
            fast_limit: Some(fast_limit),
            slow_limit: Some(slow_limit),
        };
        let input = MamaInput::from_slice(data, params);
        mama_into_slice(out_m, out_f, &input, detect_best_kernel())
            .map_err(|e| JsValue::from_str(&e.to_string()))
    }
}
2985
/// Serialized batch result for JS: flat `rows * cols` MAMA/FAMA buffers plus
/// the parameter combination that produced each row.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct MamaBatchJsOutput {
    pub mama: Vec<f64>, // rows * cols elements
    pub fama: Vec<f64>, // rows * cols elements
    pub combos: Vec<MamaParams>, // one entry per row
    pub rows: usize,
    pub cols: usize,
}
2995
/// Run a MAMA parameter sweep over the (fast, slow) limit grids and return
/// every row plus combo metadata, serialized as a `MamaBatchJsOutput`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "mama_batch")]
pub fn mama_batch_js(
    data: &[f64],
    fast_start: f64,
    fast_end: f64,
    fast_step: f64,
    slow_start: f64,
    slow_end: f64,
    slow_step: f64,
) -> Result<JsValue, JsValue> {
    let sweep = MamaBatchRange {
        fast_limit: (fast_start, fast_end, fast_step),
        slow_limit: (slow_start, slow_end, slow_step),
    };
    let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let cols = data.len();
    let total = rows
        .checked_mul(cols)
        .ok_or(JsValue::from_str("rows*cols overflow"))?;

    let mut mama_values = vec![0.0; total];
    let mut fama_values = vec![0.0; total];
    mama_batch_inner_into(
        data,
        &sweep,
        detect_best_kernel(),
        false,
        &mut mama_values,
        &mut fama_values,
    )
    .map_err(|e| JsValue::from_str(&e.to_string()))?;

    serde_wasm_bindgen::to_value(&MamaBatchJsOutput {
        mama: mama_values,
        fama: fama_values,
        combos,
        rows,
        cols,
    })
    .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
3042
/// Flattened `[fast_limit, slow_limit, ...]` pairs for every combo the given
/// sweep would expand to (two entries per combo, in grid order).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn mama_batch_metadata_js(
    fast_limit_start: f64,
    fast_limit_end: f64,
    fast_limit_step: f64,
    slow_limit_start: f64,
    slow_limit_end: f64,
    slow_limit_step: f64,
) -> Vec<f64> {
    let range = MamaBatchRange {
        fast_limit: (fast_limit_start, fast_limit_end, fast_limit_step),
        slow_limit: (slow_limit_start, slow_limit_end, slow_limit_step),
    };

    // An invalid range degenerates to an empty combo list (empty output).
    expand_grid(&range)
        .unwrap_or_default()
        .into_iter()
        .flat_map(|combo| {
            [
                combo.fast_limit.unwrap_or(0.5),
                combo.slow_limit.unwrap_or(0.05),
            ]
        })
        .collect()
}
3068
/// Shape helper for JS callers sizing batch buffers: `[n_combos, data_len]`
/// for the given sweep; an invalid range yields zero rows.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn mama_batch_rows_cols_js(
    fast_limit_start: f64,
    fast_limit_end: f64,
    fast_limit_step: f64,
    slow_limit_start: f64,
    slow_limit_end: f64,
    slow_limit_step: f64,
    data_len: usize,
) -> Vec<usize> {
    let range = MamaBatchRange {
        fast_limit: (fast_limit_start, fast_limit_end, fast_limit_step),
        slow_limit: (slow_limit_start, slow_limit_end, slow_limit_step),
    };

    let combo_count = expand_grid(&range).map(|combos| combos.len()).unwrap_or(0);
    vec![combo_count, data_len]
}
3088
/// Allocate a zero-initialised f64 buffer of exactly `len` elements for JS.
///
/// The allocation's length and capacity are both exactly `len` (guaranteed by
/// `into_boxed_slice`), so `mama_free(ptr, len)` can soundly rebuild it with
/// `Vec::from_raw_parts(ptr, len, len)`. The previous implementation used
/// `Vec::with_capacity(len)`, which only guarantees capacity >= `len` — a
/// capacity mismatch in `from_raw_parts` is undefined behaviour — and handed
/// JS uninitialised memory.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn mama_alloc(len: usize) -> *mut f64 {
    let buf: Box<[f64]> = vec![0.0; len].into_boxed_slice();
    // Ownership passes to JS; reclaimed later via `mama_free`.
    Box::into_raw(buf) as *mut f64
}
3097
/// Free a buffer previously returned by `mama_alloc`.
///
/// # Safety (JS contract)
/// `ptr` must come from `mama_alloc(len)` with this exact `len`, and must not
/// be freed twice: the rebuilt Vec assumes length == capacity == `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn mama_free(ptr: *mut f64, len: usize) {
    if !ptr.is_null() {
        unsafe {
            // SAFETY: caller guarantees (ptr, len) matches the original
            // allocation; rebuilding the Vec and dropping it frees it.
            let _ = Vec::from_raw_parts(ptr, len, len);
        }
    }
}
3107
/// Batch MAMA sweep writing into caller-provided wasm buffers; returns the
/// number of parameter rows written.
///
/// # Safety (JS contract)
/// `in_ptr` must reference `len` f64s. Each output buffer must hold
/// `rows * len` f64s, where `rows` is obtainable up front from
/// `mama_batch_rows_cols_js` with the same ranges. Null pointers are
/// rejected, but buffer sizes cannot be verified here.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn mama_batch_into(
    in_ptr: *const f64,
    out_mama_ptr: *mut f64,
    out_fama_ptr: *mut f64,
    len: usize,
    fast_limit_start: f64,
    fast_limit_end: f64,
    fast_limit_step: f64,
    slow_limit_start: f64,
    slow_limit_end: f64,
    slow_limit_step: f64,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_mama_ptr.is_null() || out_fama_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to mama_batch_into"));
    }

    unsafe {
        // SAFETY: caller guarantees `in_ptr` covers `len` valid f64s.
        let data = std::slice::from_raw_parts(in_ptr, len);

        let range = MamaBatchRange {
            fast_limit: (fast_limit_start, fast_limit_end, fast_limit_step),
            slow_limit: (slow_limit_start, slow_limit_end, slow_limit_step),
        };

        // Compute into a temporary matrix, then copy into the JS buffers.
        let batch_output = mama_batch_with_kernel(data, &range, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

        let rows = batch_output.combos.len();
        let cols = len;
        let total_elements = rows * cols;

        // SAFETY: caller guarantees both outputs hold rows*cols elements.
        let out_mama = std::slice::from_raw_parts_mut(out_mama_ptr, total_elements);
        out_mama.copy_from_slice(&batch_output.mama_values);

        let out_fama = std::slice::from_raw_parts_mut(out_fama_ptr, total_elements);
        out_fama.copy_from_slice(&batch_output.fama_values);

        Ok(rows)
    }
}
3150
/// A (rows × cols) f32 matrix living in CUDA device memory, exposed to Python
/// via `__cuda_array_interface__` and the DLPack protocol.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyo3::pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct DeviceArrayF32Py {
    pub(crate) buf: Option<cust::memory::DeviceBuffer<f32>>, // None once exported via __dlpack__
    pub(crate) rows: usize,
    pub(crate) cols: usize,
    // Keeps the owning CUDA context alive as long as this wrapper is.
    pub(crate) _ctx: std::sync::Arc<cust::context::Context>,
    pub(crate) device_id: u32,
}
3160
3161#[cfg(all(feature = "python", feature = "cuda"))]
3162#[pyo3::pymethods]
3163impl DeviceArrayF32Py {
    /// Build the `__cuda_array_interface__` (version 3) dict so consumers
    /// like CuPy/Numba can adopt this device buffer without a copy.
    #[getter]
    fn __cuda_array_interface__<'py>(
        &self,
        py: pyo3::Python<'py>,
    ) -> pyo3::PyResult<pyo3::prelude::Bound<'py, pyo3::types::PyDict>> {
        let d = pyo3::types::PyDict::new(py);
        pyo3::types::PyDictMethods::set_item(&d, "shape", (self.rows, self.cols))?;
        // "<f4": little-endian 32-bit float.
        pyo3::types::PyDictMethods::set_item(&d, "typestr", "<f4")?;
        // Strides are expressed in BYTES for a dense row-major layout.
        pyo3::types::PyDictMethods::set_item(
            &d,
            "strides",
            (
                self.cols * std::mem::size_of::<f32>(),
                std::mem::size_of::<f32>(),
            ),
        )?;
        // The buffer is gone once it has been handed off through __dlpack__.
        let ptr = self
            .buf
            .as_ref()
            .ok_or_else(|| {
                pyo3::exceptions::PyValueError::new_err("buffer already exported via __dlpack__")
            })?
            .as_device_ptr()
            .as_raw() as usize;
        // (device pointer, read_only = false).
        pyo3::types::PyDictMethods::set_item(&d, "data", (ptr, false))?;
        pyo3::types::PyDictMethods::set_item(&d, "version", 3)?;
        Ok(d)
    }
3192
    /// DLPack device tuple `(device_type, device_id)`; 2 is `kDLCUDA`.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.device_id as i32)
    }
3196
3197 #[cfg(feature = "mama_legacy_dlpack")]
3198 #[pyo3(signature=(stream=None, max_version=None, dl_device=None, copy=None))]
3199 fn __dlpack_legacy__<'py>(
3200 &mut self,
3201 py: pyo3::Python<'py>,
3202 stream: Option<&pyo3::types::PyAny>,
3203 max_version: Option<&pyo3::types::PyAny>,
3204 dl_device: Option<&pyo3::types::PyAny>,
3205 copy: Option<&pyo3::types::PyAny>,
3206 ) -> pyo3::PyResult<pyo3::PyObject> {
3207 use std::os::raw::c_char;
3208
3209 let buf = self.buf.take().ok_or_else(|| {
3210 pyo3::exceptions::PyValueError::new_err("__dlpack__ may only be called once")
3211 })?;
3212
3213 #[repr(C)]
3214 struct DLDevice {
3215 device_type: i32,
3216 device_id: i32,
3217 }
3218 #[repr(C)]
3219 struct DLDataType {
3220 code: u8,
3221 bits: u8,
3222 lanes: u16,
3223 }
3224 #[repr(C)]
3225 struct DLTensor {
3226 data: *mut std::ffi::c_void,
3227 device: DLDevice,
3228 ndim: i32,
3229 dtype: DLDataType,
3230 shape: *mut i64,
3231 strides: *mut i64,
3232 byte_offset: u64,
3233 }
3234 #[repr(C)]
3235 struct DLManagedTensor {
3236 dl_tensor: DLTensor,
3237 manager_ctx: *mut std::ffi::c_void,
3238 deleter: Option<extern "C" fn(*mut DLManagedTensor)>,
3239 }
3240 #[repr(C)]
3241 struct DLVersion {
3242 major: i32,
3243 minor: i32,
3244 }
3245 #[repr(C)]
3246 struct DLManagedTensorVersioned {
3247 dl_managed_tensor: DLManagedTensor,
3248 version: DLVersion,
3249 }
3250
3251 struct HolderLegacy {
3252 managed: DLManagedTensor,
3253 shape: [i64; 2],
3254 strides: [i64; 2],
3255 buf: cust::memory::DeviceBuffer<f32>,
3256 retained: cust::sys::CUcontext,
3257 device_id: i32,
3258 }
3259 struct HolderV1 {
3260 managed: DLManagedTensorVersioned,
3261 shape: [i64; 2],
3262 strides: [i64; 2],
3263 buf: cust::memory::DeviceBuffer<f32>,
3264 retained: cust::sys::CUcontext,
3265 device_id: i32,
3266 }
3267
3268 unsafe extern "C" fn deleter_legacy(p: *mut DLManagedTensor) {
3269 if p.is_null() {
3270 return;
3271 }
3272 let holder = (*p).manager_ctx as *mut HolderLegacy;
3273 if !holder.is_null() {
3274 let ctx = (*holder).retained;
3275 if !ctx.is_null() {
3276 let _ = cust::sys::cuCtxPushCurrent(ctx);
3277 let dev = (*holder).device_id;
3278 drop(Box::from_raw(holder));
3279 let mut _out: cust::sys::CUcontext = std::ptr::null_mut();
3280 let _ = cust::sys::cuCtxPopCurrent(&mut _out);
3281 let _ = cust::sys::cuDevicePrimaryCtxRelease(dev);
3282 }
3283 }
3284 drop(Box::from_raw(p));
3285 }
3286 unsafe extern "C" fn deleter_v1(p: *mut DLManagedTensorVersioned) {
3287 if p.is_null() {
3288 return;
3289 }
3290 let holder = (*p).dl_managed_tensor.manager_ctx as *mut HolderV1;
3291 if !holder.is_null() {
3292 let ctx = (*holder).retained;
3293 if !ctx.is_null() {
3294 let _ = cust::sys::cuCtxPushCurrent(ctx);
3295 let dev = (*holder).device_id;
3296 drop(Box::from_raw(holder));
3297 let mut _out: cust::sys::CUcontext = std::ptr::null_mut();
3298 let _ = cust::sys::cuCtxPopCurrent(&mut _out);
3299 let _ = cust::sys::cuDevicePrimaryCtxRelease(dev);
3300 }
3301 }
3302 drop(Box::from_raw(p));
3303 }
3304
3305 unsafe extern "C" fn cap_destructor_legacy(capsule: *mut pyo3::ffi::PyObject) {
3306 let name = b"dltensor\0";
3307 let ptr = pyo3::ffi::PyCapsule_GetPointer(capsule, name.as_ptr() as *const c_char)
3308 as *mut DLManagedTensor;
3309 if !ptr.is_null() {
3310 if let Some(del) = (*ptr).deleter {
3311 del(ptr);
3312 }
3313 let used = b"used_dltensor\0";
3314 pyo3::ffi::PyCapsule_SetName(capsule, used.as_ptr() as *const _);
3315 }
3316 }
3317 unsafe extern "C" fn cap_destructor_v1(capsule: *mut pyo3::ffi::PyObject) {
3318 let name = b"dltensor_versioned\0";
3319 let ptr = pyo3::ffi::PyCapsule_GetPointer(capsule, name.as_ptr() as *const c_char)
3320 as *mut DLManagedTensorVersioned;
3321 if !ptr.is_null() {
3322 let mt = &mut (*ptr).dl_managed_tensor;
3323 if let Some(del) = mt.deleter {
3324 del(mt);
3325 }
3326 let used = b"used_dltensor_versioned\0";
3327 pyo3::ffi::PyCapsule_SetName(capsule, used.as_ptr() as *const _);
3328 }
3329 }
3330
3331 let alloc_dev = self.device_id as i32;
3332 let mut retained: cust::sys::CUcontext = std::ptr::null_mut();
3333 unsafe {
3334 let _ = cust::sys::cuDevicePrimaryCtxRetain(&mut retained, alloc_dev);
3335 }
3336
3337 let rows = self.rows as i64;
3338 let cols = self.cols as i64;
3339 let data_ptr: *mut std::ffi::c_void = if self.rows == 0 || self.cols == 0 {
3340 std::ptr::null_mut()
3341 } else {
3342 buf.as_device_ptr().as_raw() as *mut std::ffi::c_void
3343 };
3344
3345 let want_v1 = if let Some(v) = max_version {
3346 v.getattr("__iter")
3347 .ok()
3348 .and_then(|_| v.extract::<(i32, i32)>().ok())
3349 .map(|(maj, _)| maj >= 1)
3350 .unwrap_or(false)
3351 } else {
3352 false
3353 };
3354
3355 if want_v1 {
3356 let mut holder = Box::new(HolderV1 {
3357 managed: DLManagedTensorVersioned {
3358 dl_managed_tensor: DLManagedTensor {
3359 dl_tensor: DLTensor {
3360 data: data_ptr,
3361 device: DLDevice {
3362 device_type: 2,
3363 device_id: alloc_dev,
3364 },
3365 ndim: 2,
3366 dtype: DLDataType {
3367 code: 2,
3368 bits: 32,
3369 lanes: 1,
3370 },
3371 shape: std::ptr::null_mut(),
3372 strides: std::ptr::null_mut(),
3373 byte_offset: 0,
3374 },
3375 manager_ctx: std::ptr::null_mut(),
3376 deleter: Some(|mt| {
3377 if !mt.is_null() {
3378 let outer = (mt as *mut u8)
3379 .offset(-(std::mem::size_of::<DLVersion>() as isize))
3380 as *mut DLManagedTensorVersioned;
3381 deleter_v1(outer);
3382 }
3383 }),
3384 },
3385 version: DLVersion { major: 1, minor: 0 },
3386 },
3387 shape: [rows, cols],
3388 strides: [cols, 1],
3389 buf,
3390 retained,
3391 device_id: alloc_dev,
3392 });
3393 holder.managed.dl_managed_tensor.dl_tensor.shape = holder.shape.as_mut_ptr();
3394 holder.managed.dl_managed_tensor.dl_tensor.strides = holder.strides.as_mut_ptr();
3395 holder.managed.dl_managed_tensor.manager_ctx =
3396 &mut *holder as *mut HolderV1 as *mut std::ffi::c_void;
3397 let mt_ptr: *mut DLManagedTensorVersioned = &mut holder.managed;
3398 let _leak = Box::into_raw(holder);
3399 let name = b"dltensor_versioned\0";
3400 let cap = unsafe {
3401 pyo3::ffi::PyCapsule_New(
3402 mt_ptr as *mut std::ffi::c_void,
3403 name.as_ptr() as *const c_char,
3404 Some(cap_destructor_v1),
3405 )
3406 };
3407 if cap.is_null() {
3408 return Err(pyo3::exceptions::PyValueError::new_err(
3409 "failed to create DLPack capsule",
3410 ));
3411 }
3412 Ok(unsafe { pyo3::PyObject::from_owned_ptr(py, cap) })
3413 } else {
3414 let mut holder = Box::new(HolderLegacy {
3415 managed: DLManagedTensor {
3416 dl_tensor: DLTensor {
3417 data: data_ptr,
3418 device: DLDevice {
3419 device_type: 2,
3420 device_id: alloc_dev,
3421 },
3422 ndim: 2,
3423 dtype: DLDataType {
3424 code: 2,
3425 bits: 32,
3426 lanes: 1,
3427 },
3428 shape: std::ptr::null_mut(),
3429 strides: std::ptr::null_mut(),
3430 byte_offset: 0,
3431 },
3432 manager_ctx: std::ptr::null_mut(),
3433 deleter: Some(deleter_legacy),
3434 },
3435 shape: [rows, cols],
3436 strides: [cols, 1],
3437 buf,
3438 retained,
3439 device_id: alloc_dev,
3440 });
3441 holder.managed.dl_tensor.shape = holder.shape.as_mut_ptr();
3442 holder.managed.dl_tensor.strides = holder.strides.as_mut_ptr();
3443 holder.managed.manager_ctx = &mut *holder as *mut HolderLegacy as *mut std::ffi::c_void;
3444 let mt_ptr: *mut DLManagedTensor = &mut holder.managed;
3445 let _leak = Box::into_raw(holder);
3446 let name = b"dltensor\0";
3447 let cap = unsafe {
3448 pyo3::ffi::PyCapsule_New(
3449 mt_ptr as *mut std::ffi::c_void,
3450 name.as_ptr() as *const c_char,
3451 Some(cap_destructor_legacy),
3452 )
3453 };
3454 if cap.is_null() {
3455 return Err(pyo3::exceptions::PyValueError::new_err(
3456 "failed to create DLPack capsule",
3457 ));
3458 }
3459 Ok(unsafe { pyo3::PyObject::from_owned_ptr(py, cap) })
3460 }
3461 }
3462
3463 #[pyo3(signature=(stream=None, max_version=None, dl_device=None, copy=None))]
3464 fn __dlpack__<'py>(
3465 &mut self,
3466 py: pyo3::Python<'py>,
3467 stream: Option<pyo3::PyObject>,
3468 max_version: Option<pyo3::PyObject>,
3469 dl_device: Option<pyo3::PyObject>,
3470 copy: Option<pyo3::PyObject>,
3471 ) -> pyo3::PyResult<pyo3::PyObject> {
3472 use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
3473
3474 let (kdl, alloc_dev) = self.__dlpack_device__();
3475 if let Some(dev_obj) = dl_device.as_ref() {
3476 if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
3477 if dev_ty != kdl || dev_id != alloc_dev {
3478 let wants_copy = copy
3479 .as_ref()
3480 .and_then(|c| c.extract::<bool>(py).ok())
3481 .unwrap_or(false);
3482 if wants_copy {
3483 return Err(pyo3::exceptions::PyValueError::new_err(
3484 "device copy not implemented for __dlpack__",
3485 ));
3486 } else {
3487 return Err(pyo3::exceptions::PyValueError::new_err(
3488 "dl_device mismatch for __dlpack__",
3489 ));
3490 }
3491 }
3492 }
3493 }
3494 let _ = stream;
3495
3496 let buf = self.buf.take().ok_or_else(|| {
3497 pyo3::exceptions::PyValueError::new_err("__dlpack__ may only be called once")
3498 })?;
3499
3500 let rows = self.rows;
3501 let cols = self.cols;
3502
3503 let max_version_bound = max_version.map(|obj| obj.into_bound(py));
3504
3505 export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
3506 }
3507}