1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::cuda_available;
3#[cfg(all(feature = "python", feature = "cuda"))]
4use crate::cuda::moving_averages::maaq_wrapper::DeviceArrayF32Maaq;
5#[cfg(all(feature = "python", feature = "cuda"))]
6use crate::cuda::moving_averages::CudaMaaq;
7use crate::utilities::data_loader::{source_type, Candles};
8use crate::utilities::enums::Kernel;
9use crate::utilities::helpers::{
10 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
11 make_uninit_matrix,
12};
13use aligned_vec::{AVec, CACHELINE_ALIGN};
14#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
15use core::arch::x86_64::*;
16#[cfg(not(target_arch = "wasm32"))]
17use rayon::prelude::*;
18use std::convert::AsRef;
19use std::error::Error;
20use std::mem::MaybeUninit;
21use thiserror::Error;
22
23impl<'a> AsRef<[f64]> for MaaqInput<'a> {
24 #[inline(always)]
25 fn as_ref(&self) -> &[f64] {
26 match &self.data {
27 MaaqData::Slice(slice) => slice,
28 MaaqData::Candles { candles, source } => source_type(candles, source),
29 }
30 }
31}
32
33#[derive(Debug, Clone)]
34pub enum MaaqData<'a> {
35 Candles {
36 candles: &'a Candles,
37 source: &'a str,
38 },
39 Slice(&'a [f64]),
40}
41
/// Result of a single MAAQ computation.
#[derive(Debug, Clone)]
pub struct MaaqOutput {
    /// Computed series (one entry per input value).
    pub values: Vec<f64>,
}
46
/// Tunable MAAQ parameters; every field is optional.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct MaaqParams {
    /// Look-back window length.
    pub period: Option<usize>,
    /// Period used to derive the fast smoothing constant.
    pub fast_period: Option<usize>,
    /// Period used to derive the slow smoothing constant.
    pub slow_period: Option<usize>,
}
57
58impl Default for MaaqParams {
59 fn default() -> Self {
60 Self {
61 period: Some(11),
62 fast_period: Some(2),
63 slow_period: Some(30),
64 }
65 }
66}
67
68#[derive(Debug, Clone)]
69pub struct MaaqInput<'a> {
70 pub data: MaaqData<'a>,
71 pub params: MaaqParams,
72}
73
74impl<'a> MaaqInput<'a> {
75 #[inline]
76 pub fn from_candles(c: &'a Candles, s: &'a str, p: MaaqParams) -> Self {
77 Self {
78 data: MaaqData::Candles {
79 candles: c,
80 source: s,
81 },
82 params: p,
83 }
84 }
85 #[inline]
86 pub fn from_slice(sl: &'a [f64], p: MaaqParams) -> Self {
87 Self {
88 data: MaaqData::Slice(sl),
89 params: p,
90 }
91 }
92 #[inline]
93 pub fn with_default_candles(c: &'a Candles) -> Self {
94 Self::from_candles(c, "close", MaaqParams::default())
95 }
96 #[inline]
97 pub fn get_period(&self) -> usize {
98 self.params.period.unwrap_or(11)
99 }
100 #[inline]
101 pub fn get_fast_period(&self) -> usize {
102 self.params.fast_period.unwrap_or(2)
103 }
104 #[inline]
105 pub fn get_slow_period(&self) -> usize {
106 self.params.slow_period.unwrap_or(30)
107 }
108}
109
110#[derive(Copy, Clone, Debug)]
111pub struct MaaqBuilder {
112 period: Option<usize>,
113 fast_period: Option<usize>,
114 slow_period: Option<usize>,
115 kernel: Kernel,
116}
117
118impl Default for MaaqBuilder {
119 fn default() -> Self {
120 Self {
121 period: None,
122 fast_period: None,
123 slow_period: None,
124 kernel: Kernel::Auto,
125 }
126 }
127}
128
129impl MaaqBuilder {
130 #[inline(always)]
131 pub fn new() -> Self {
132 Self::default()
133 }
134 #[inline(always)]
135 pub fn period(mut self, n: usize) -> Self {
136 self.period = Some(n);
137 self
138 }
139 #[inline(always)]
140 pub fn fast_period(mut self, n: usize) -> Self {
141 self.fast_period = Some(n);
142 self
143 }
144 #[inline(always)]
145 pub fn slow_period(mut self, n: usize) -> Self {
146 self.slow_period = Some(n);
147 self
148 }
149 #[inline(always)]
150 pub fn kernel(mut self, k: Kernel) -> Self {
151 self.kernel = k;
152 self
153 }
154 #[inline(always)]
155 pub fn apply(self, c: &Candles) -> Result<MaaqOutput, MaaqError> {
156 let p = MaaqParams {
157 period: self.period,
158 fast_period: self.fast_period,
159 slow_period: self.slow_period,
160 };
161 let i = MaaqInput::from_candles(c, "close", p);
162 maaq_with_kernel(&i, self.kernel)
163 }
164 #[inline(always)]
165 pub fn apply_slice(self, d: &[f64]) -> Result<MaaqOutput, MaaqError> {
166 let p = MaaqParams {
167 period: self.period,
168 fast_period: self.fast_period,
169 slow_period: self.slow_period,
170 };
171 let i = MaaqInput::from_slice(d, p);
172 maaq_with_kernel(&i, self.kernel)
173 }
174 #[inline(always)]
175 pub fn into_stream(self) -> Result<MaaqStream, MaaqError> {
176 let p = MaaqParams {
177 period: self.period,
178 fast_period: self.fast_period,
179 slow_period: self.slow_period,
180 };
181 MaaqStream::try_new(p)
182 }
183}
184
185#[derive(Debug, Error)]
186pub enum MaaqError {
187 #[error("maaq: Input data slice is empty.")]
188 EmptyInputData,
189 #[error("maaq: All values are NaN.")]
190 AllValuesNaN,
191 #[error("maaq: Invalid period: period = {period}, data length = {data_len}")]
192 InvalidPeriod { period: usize, data_len: usize },
193 #[error("maaq: Not enough valid data: needed = {needed}, valid = {valid}")]
194 NotEnoughValidData { needed: usize, valid: usize },
195 #[error("maaq: Output length mismatch: expected = {expected}, got = {got}")]
196 OutputLengthMismatch { expected: usize, got: usize },
197 #[error("maaq: Invalid range (start={start}, end={end}, step={step})")]
198 InvalidRange {
199 start: usize,
200 end: usize,
201 step: usize,
202 },
203 #[error("maaq: Non-batch kernel passed to batch path: {0:?}")]
204 InvalidKernelForBatch(Kernel),
205 #[error("maaq: periods cannot be zero: period = {period}, fast = {fast_p}, slow = {slow_p}")]
206 ZeroPeriods {
207 period: usize,
208 fast_p: usize,
209 slow_p: usize,
210 },
211}
212
213#[inline]
214pub fn maaq(input: &MaaqInput) -> Result<MaaqOutput, MaaqError> {
215 maaq_with_kernel(input, Kernel::Auto)
216}
217
218#[inline(always)]
219fn maaq_compute_into(
220 data: &[f64],
221 period: usize,
222 fast_p: usize,
223 slow_p: usize,
224 first: usize,
225 kernel: Kernel,
226 out: &mut [f64],
227) -> Result<(), MaaqError> {
228 if out.len() != data.len() {
229 return Err(MaaqError::OutputLengthMismatch {
230 expected: data.len(),
231 got: out.len(),
232 });
233 }
234 unsafe {
235 if first > 0 {
236 maaq_scalar(data, period, fast_p, slow_p, first, out)?;
237 } else {
238 match kernel {
239 Kernel::Scalar | Kernel::ScalarBatch => {
240 maaq_scalar(data, period, fast_p, slow_p, first, out)?;
241 }
242 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
243 Kernel::Avx2 | Kernel::Avx2Batch => {
244 maaq_avx2(data, period, fast_p, slow_p, first, out)?;
245 }
246 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
247 Kernel::Avx512 | Kernel::Avx512Batch => {
248 maaq_avx512(data, period, fast_p, slow_p, first, out)?;
249 }
250 _ => unreachable!(),
251 }
252 }
253 }
254 Ok(())
255}
256
257#[inline(always)]
258fn maaq_prepare<'a>(
259 input: &'a MaaqInput,
260 kernel: Kernel,
261) -> Result<(&'a [f64], usize, usize, usize, usize, Kernel), MaaqError> {
262 let data: &[f64] = input.as_ref();
263 let len = data.len();
264 if len == 0 {
265 return Err(MaaqError::EmptyInputData);
266 }
267
268 let first = data
269 .iter()
270 .position(|x| !x.is_nan())
271 .ok_or(MaaqError::AllValuesNaN)?;
272
273 let period = input.get_period();
274 let fast_p = input.get_fast_period();
275 let slow_p = input.get_slow_period();
276
277 if period == 0 || fast_p == 0 || slow_p == 0 {
278 return Err(MaaqError::ZeroPeriods {
279 period,
280 fast_p,
281 slow_p,
282 });
283 }
284 if period >= len {
285 return Err(MaaqError::InvalidPeriod {
286 period,
287 data_len: len,
288 });
289 }
290 if len - first < period {
291 return Err(MaaqError::NotEnoughValidData {
292 needed: period,
293 valid: len - first,
294 });
295 }
296
297 let chosen = match kernel {
298 Kernel::Auto => Kernel::Scalar,
299 k => k,
300 };
301
302 Ok((data, period, fast_p, slow_p, first, chosen))
303}
304
305pub fn maaq_with_kernel(input: &MaaqInput, kernel: Kernel) -> Result<MaaqOutput, MaaqError> {
306 let (data, period, fast_p, slow_p, first, chosen) = maaq_prepare(input, kernel)?;
307
308 let warm = first + period - 1;
309 let mut out = alloc_with_nan_prefix(data.len(), warm);
310
311 if out.len() != data.len() {
312 return Err(MaaqError::OutputLengthMismatch {
313 expected: data.len(),
314 got: out.len(),
315 });
316 }
317 maaq_compute_into(data, period, fast_p, slow_p, first, chosen, &mut out)?;
318
319 let warmup_end = first + period - 1;
320 for v in &mut out[..warmup_end] {
321 *v = f64::NAN;
322 }
323
324 Ok(MaaqOutput { values: out })
325}
326
327#[inline]
328pub fn maaq_scalar(
329 data: &[f64],
330 period: usize,
331 fast_p: usize,
332 slow_p: usize,
333 first: usize,
334 out: &mut [f64],
335) -> Result<(), MaaqError> {
336 let len = data.len();
337 let fast_sc = 2.0 / (fast_p as f64 + 1.0);
338 let slow_sc = 2.0 / (slow_p as f64 + 1.0);
339
340 let mut diffs = vec![0.0f64; period];
341 let mut vol_sum = 0.0;
342
343 for j in 1..period {
344 let d = (data[first + j] - data[first + j - 1]).abs();
345 diffs[j] = d;
346 vol_sum += d;
347 }
348
349 let warm_end = (first + period).min(len);
350 if warm_end > first {
351 out[first..warm_end].copy_from_slice(&data[first..warm_end]);
352 }
353
354 let i0 = first + period;
355 if i0 >= len {
356 return Ok(());
357 }
358
359 let new_diff = (data[i0] - data[i0 - 1]).abs();
360 diffs[0] = new_diff;
361 vol_sum += new_diff;
362
363 let mut prev_val = data[i0 - 1];
364 let er0 = if vol_sum > f64::EPSILON {
365 (data[i0] - data[first]).abs() / vol_sum
366 } else {
367 0.0
368 };
369 let mut sc = fast_sc.mul_add(er0, slow_sc);
370 sc *= sc;
371
372 prev_val = sc.mul_add(data[i0] - prev_val, prev_val);
373 out[i0] = prev_val;
374
375 let mut head = if period > 1 { 1usize } else { 0usize };
376
377 for i in (i0 + 1)..len {
378 vol_sum -= diffs[head];
379 let nd = (data[i] - data[i - 1]).abs();
380 diffs[head] = nd;
381 vol_sum += nd;
382 head += 1;
383 if head == period {
384 head = 0;
385 }
386
387 let er = if vol_sum > f64::EPSILON {
388 (data[i] - data[i - period]).abs() / vol_sum
389 } else {
390 0.0
391 };
392 let mut sc = fast_sc.mul_add(er, slow_sc);
393 sc *= sc;
394
395 prev_val = sc.mul_add(data[i] - prev_val, prev_val);
396 out[i] = prev_val;
397 }
398 Ok(())
399}
400
/// AVX2 MAAQ kernel.
///
/// Only the initial fill of the absolute-difference window is vectorised
/// (four lanes at a time); the smoothing recurrence itself is sequential and
/// mirrors [`maaq_scalar`].
///
/// Preconditions (checked only in debug builds): `out.len() == data.len()`,
/// `period > 0` and `first + period <= data.len()`. The body uses unchecked
/// pointer reads and writes that rely on them.
///
/// Always returns `Ok(())`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn maaq_avx2(
    data: &[f64],
    period: usize,
    fast_p: usize,
    slow_p: usize,
    first: usize,
    out: &mut [f64],
) -> Result<(), MaaqError> {
    use core::arch::x86_64::*;

    let len = data.len();
    debug_assert_eq!(len, out.len());
    if len == 0 {
        return Ok(());
    }
    debug_assert!(period > 0 && first + period <= len);

    let fast_sc = 2.0 / (fast_p as f64 + 1.0);
    let slow_sc = 2.0 / (slow_p as f64 + 1.0);

    /// Lane-wise absolute value: clears the sign bit of every lane.
    #[inline(always)]
    unsafe fn vabs_pd(x: __m256d) -> __m256d {
        let sign = _mm256_set1_pd(-0.0f64);
        _mm256_andnot_pd(sign, x)
    }

    /// Scalar absolute value via sign-bit masking.
    #[inline(always)]
    fn fast_abs(x: f64) -> f64 {
        f64::from_bits(x.to_bits() & 0x7FFF_FFFF_FFFF_FFFF)
    }

    // Zero-initialised ring buffer of absolute differences. (Previously the
    // length was set over uninitialised capacity, which `Vec::set_len`
    // forbids.)
    let mut diffs: Vec<f64> = vec![0.0f64; period];

    // SAFETY: every raw access below stays within `data`/`out`/`diffs` as long
    // as the preconditions documented on this function hold.
    unsafe {
        let dp = data.as_ptr();

        let base = first + 1;
        let n = period.saturating_sub(1);

        // Vectorised fill of diffs[1..period].
        let mut accv = _mm256_setzero_pd();
        let mut j = 0usize;
        while j + 4 <= n {
            let a = _mm256_loadu_pd(dp.add(base + j));
            let b = _mm256_loadu_pd(dp.add(base + j - 1));
            let d = vabs_pd(_mm256_sub_pd(a, b));
            _mm256_storeu_pd(diffs.as_mut_ptr().add(1 + j), d);
            accv = _mm256_add_pd(accv, d);
            j += 4;
        }

        // Horizontal sum of the accumulator lanes.
        let mut tmp = [0.0f64; 4];
        _mm256_storeu_pd(tmp.as_mut_ptr(), accv);
        let mut vol_sum = tmp[0] + tmp[1] + tmp[2] + tmp[3];

        // Scalar tail for the remaining (< 4) differences.
        while j < n {
            let k = base + j;
            let d = fast_abs(*dp.add(k) - *dp.add(k - 1));
            *diffs.get_unchecked_mut(1 + j) = d;
            vol_sum += d;
            j += 1;
        }

        // Seed the output with the raw inputs of the first window.
        let warm_end = (first + period).min(len);
        if warm_end > first {
            core::ptr::copy_nonoverlapping(
                dp.add(first),
                out.as_mut_ptr().add(first),
                warm_end - first,
            );
        }

        let i0 = first + period;
        if i0 >= len {
            return Ok(());
        }

        let new_diff = fast_abs(*dp.add(i0) - *dp.add(i0 - 1));
        *diffs.get_unchecked_mut(0) = new_diff;
        vol_sum += new_diff;

        let mut prev_val = *dp.add(i0 - 1);
        let er0 = if vol_sum > f64::EPSILON {
            fast_abs(*dp.add(i0) - *dp.add(first)) / vol_sum
        } else {
            0.0
        };
        let mut sc = fast_sc.mul_add(er0, slow_sc);
        sc *= sc;
        prev_val = sc.mul_add(*dp.add(i0) - prev_val, prev_val);
        *out.get_unchecked_mut(i0) = prev_val;

        // Slot holding the oldest difference still inside the window.
        let mut head = if period > 1 { 1usize } else { 0usize };

        let mut i = i0 + 1;
        let op = out.as_mut_ptr();
        while i < len {
            vol_sum -= *diffs.get_unchecked(head);

            let nd = fast_abs(*dp.add(i) - *dp.add(i - 1));
            *diffs.get_unchecked_mut(head) = nd;
            vol_sum += nd;

            head += 1;
            if head == period {
                head = 0;
            }

            let er = if vol_sum > f64::EPSILON {
                fast_abs(*dp.add(i) - *dp.add(i - period)) / vol_sum
            } else {
                0.0
            };

            let mut sc = fast_sc.mul_add(er, slow_sc);
            sc *= sc;
            prev_val = sc.mul_add(*dp.add(i) - prev_val, prev_val);

            *op.add(i) = prev_val;
            i += 1;
        }
    }

    Ok(())
}
530
/// AVX-512 entry point.
///
/// There is no dedicated AVX-512 implementation here; the call is forwarded
/// unchanged to [`maaq_avx2`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn maaq_avx512(
    data: &[f64],
    period: usize,
    fast_p: usize,
    slow_p: usize,
    first: usize,
    out: &mut [f64],
) -> Result<(), MaaqError> {
    let forwarded = maaq_avx2(data, period, fast_p, slow_p, first, out);
    forwarded
}
543
/// Incremental MAAQ calculator fed one value at a time.
#[derive(Debug, Clone)]
pub struct MaaqStream {
    /// Window length.
    period: usize,
    /// Fast period the fast smoothing constant was derived from.
    fast_period: usize,
    /// Slow period the slow smoothing constant was derived from.
    slow_period: usize,
    /// Ring buffer of the last `period` inputs.
    buffer: Vec<f64>,
    /// Ring buffer of absolute one-step differences, parallel to `buffer`.
    diff: Vec<f64>,
    /// Next ring slot to overwrite.
    head: usize,
    /// Set once `period` values have been consumed.
    filled: bool,
    /// Last emitted value (the raw input while still filling).
    last: f64,
    /// Number of values consumed during the fill phase.
    count: usize,

    /// Running sum of the entries in `diff`.
    vol_sum: f64,
    /// `2 / (fast_period + 1)`.
    fast_sc: f64,
    /// `2 / (slow_period + 1)`.
    slow_sc: f64,
}
560
561impl MaaqStream {
562 pub fn try_new(params: MaaqParams) -> Result<Self, MaaqError> {
563 let period = params.period.unwrap_or(11);
564 let fast_p = params.fast_period.unwrap_or(2);
565 let slow_p = params.slow_period.unwrap_or(30);
566
567 if period == 0 || fast_p == 0 || slow_p == 0 {
568 return Err(MaaqError::ZeroPeriods {
569 period,
570 fast_p,
571 slow_p,
572 });
573 }
574
575 let fast_sc = 2.0 / (fast_p as f64 + 1.0);
576 let slow_sc = 2.0 / (slow_p as f64 + 1.0);
577
578 Ok(Self {
579 period,
580 fast_period: fast_p,
581 slow_period: slow_p,
582 buffer: vec![0.0; period],
583 diff: vec![0.0; period],
584 head: 0,
585 filled: false,
586 last: f64::NAN,
587 count: 0,
588 vol_sum: 0.0,
589 fast_sc,
590 slow_sc,
591 })
592 }
593
594 #[inline(always)]
595 pub fn update(&mut self, value: f64) -> Option<f64> {
596 if !self.filled {
597 let prev = if self.count > 0 {
598 let idx_prev = (self.head + self.period - 1) % self.period;
599 self.buffer[idx_prev]
600 } else {
601 value
602 };
603 let d = (value - prev).abs();
604
605 self.buffer[self.head] = value;
606 self.diff[self.head] = d;
607 self.vol_sum += d;
608
609 self.head += 1;
610 if self.head == self.period {
611 self.head = 0;
612 }
613
614 self.count += 1;
615 self.last = value;
616
617 if self.count == self.period {
618 self.filled = true;
619 }
620 return Some(value);
621 }
622
623 let idx_prev = (self.head + self.period - 1) % self.period;
624 let prev_input = self.buffer[idx_prev];
625
626 let old_diff = self.diff[self.head];
627 self.vol_sum -= old_diff;
628
629 let new_diff = (value - prev_input).abs();
630 self.diff[self.head] = new_diff;
631 self.vol_sum += new_diff;
632
633 let old_value = self.buffer[self.head];
634
635 self.buffer[self.head] = value;
636 self.head += 1;
637 if self.head == self.period {
638 self.head = 0;
639 }
640
641 let er = if self.vol_sum > f64::EPSILON {
642 (value - old_value).abs() / self.vol_sum
643 } else {
644 0.0
645 };
646
647 let mut sc = self.fast_sc.mul_add(er, self.slow_sc);
648 sc *= sc;
649
650 let out = sc.mul_add(value - self.last, self.last);
651 self.last = out;
652 Some(out)
653 }
654}
655
/// Parameter sweep for batch runs; each axis is `(start, end, step)`.
#[derive(Clone, Debug)]
pub struct MaaqBatchRange {
    /// Sweep over the window length.
    pub period: (usize, usize, usize),
    /// Sweep over the fast period.
    pub fast_period: (usize, usize, usize),
    /// Sweep over the slow period.
    pub slow_period: (usize, usize, usize),
}
662
663impl Default for MaaqBatchRange {
664 fn default() -> Self {
665 Self {
666 period: (11, 260, 1),
667 fast_period: (2, 2, 0),
668 slow_period: (30, 30, 0),
669 }
670 }
671}
672
673#[derive(Clone, Debug, Default)]
674pub struct MaaqBatchBuilder {
675 range: MaaqBatchRange,
676 kernel: Kernel,
677}
678
679impl MaaqBatchBuilder {
680 pub fn new() -> Self {
681 Self::default()
682 }
683 pub fn kernel(mut self, k: Kernel) -> Self {
684 self.kernel = k;
685 self
686 }
687 #[inline]
688 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
689 self.range.period = (start, end, step);
690 self
691 }
692 #[inline]
693 pub fn period_static(mut self, p: usize) -> Self {
694 self.range.period = (p, p, 0);
695 self
696 }
697 #[inline]
698 pub fn fast_period_range(mut self, start: usize, end: usize, step: usize) -> Self {
699 self.range.fast_period = (start, end, step);
700 self
701 }
702 #[inline]
703 pub fn fast_period_static(mut self, x: usize) -> Self {
704 self.range.fast_period = (x, x, 0);
705 self
706 }
707 #[inline]
708 pub fn slow_period_range(mut self, start: usize, end: usize, step: usize) -> Self {
709 self.range.slow_period = (start, end, step);
710 self
711 }
712 #[inline]
713 pub fn slow_period_static(mut self, s: usize) -> Self {
714 self.range.slow_period = (s, s, 0);
715 self
716 }
717 pub fn apply_slice(self, data: &[f64]) -> Result<MaaqBatchOutput, MaaqError> {
718 maaq_batch_with_kernel(data, &self.range, self.kernel)
719 }
720 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<MaaqBatchOutput, MaaqError> {
721 MaaqBatchBuilder::new().kernel(k).apply_slice(data)
722 }
723 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<MaaqBatchOutput, MaaqError> {
724 let slice = source_type(c, src);
725 self.apply_slice(slice)
726 }
727 pub fn with_default_candles(c: &Candles) -> Result<MaaqBatchOutput, MaaqError> {
728 MaaqBatchBuilder::new()
729 .kernel(Kernel::Auto)
730 .apply_candles(c, "close")
731 }
732}
733
734pub fn maaq_batch_with_kernel(
735 data: &[f64],
736 sweep: &MaaqBatchRange,
737 k: Kernel,
738) -> Result<MaaqBatchOutput, MaaqError> {
739 let kernel = match k {
740 Kernel::Auto => Kernel::ScalarBatch,
741 other if other.is_batch() => other,
742 _ => return Err(MaaqError::InvalidKernelForBatch(k)),
743 };
744
745 let simd = match kernel {
746 Kernel::Avx512Batch => Kernel::Avx512,
747 Kernel::Avx2Batch => Kernel::Avx2,
748 Kernel::ScalarBatch => Kernel::Scalar,
749 _ => unreachable!(),
750 };
751 maaq_batch_par_slice(data, sweep, simd)
752}
753
754#[derive(Clone, Debug)]
755pub struct MaaqBatchOutput {
756 pub values: Vec<f64>,
757 pub combos: Vec<MaaqParams>,
758 pub rows: usize,
759 pub cols: usize,
760}
761
762impl MaaqBatchOutput {
763 pub fn row_for_params(&self, p: &MaaqParams) -> Option<usize> {
764 self.combos.iter().position(|c| {
765 c.period.unwrap_or(11) == p.period.unwrap_or(11)
766 && c.fast_period.unwrap_or(2) == p.fast_period.unwrap_or(2)
767 && c.slow_period.unwrap_or(30) == p.slow_period.unwrap_or(30)
768 })
769 }
770 pub fn values_for(&self, p: &MaaqParams) -> Option<&[f64]> {
771 self.row_for_params(p).map(|row| {
772 let start = row * self.cols;
773 &self.values[start..start + self.cols]
774 })
775 }
776}
777
778#[inline(always)]
779pub fn expand_grid(r: &MaaqBatchRange) -> Vec<MaaqParams> {
780 fn axis_usize((start, end, step): (usize, usize, usize)) -> Vec<usize> {
781 if start == end || step == 0 {
782 return vec![start];
783 }
784 let mut v = Vec::new();
785 if start < end {
786 let mut x = start;
787 while x <= end {
788 v.push(x);
789 match x.checked_add(step) {
790 Some(nx) if nx > x => x = nx,
791 _ => break,
792 }
793 }
794 } else {
795 let mut x = start;
796 while x >= end {
797 v.push(x);
798 match x.checked_sub(step) {
799 Some(nx) if nx < x => x = nx,
800 _ => break,
801 }
802 if x == 0 {
803 break;
804 }
805 }
806 }
807 v
808 }
809 let periods = axis_usize(r.period);
810 let fasts = axis_usize(r.fast_period);
811 let slows = axis_usize(r.slow_period);
812 let mut out = Vec::with_capacity(periods.len() * fasts.len() * slows.len());
813 for &p in &periods {
814 for &f in &fasts {
815 for &s in &slows {
816 out.push(MaaqParams {
817 period: Some(p),
818 fast_period: Some(f),
819 slow_period: Some(s),
820 });
821 }
822 }
823 }
824 out
825}
826
827#[inline(always)]
828pub fn maaq_batch_slice(
829 data: &[f64],
830 sweep: &MaaqBatchRange,
831 kern: Kernel,
832) -> Result<MaaqBatchOutput, MaaqError> {
833 maaq_batch_inner(data, sweep, kern, false)
834}
835
836#[inline(always)]
837pub fn maaq_batch_par_slice(
838 data: &[f64],
839 sweep: &MaaqBatchRange,
840 kern: Kernel,
841) -> Result<MaaqBatchOutput, MaaqError> {
842 maaq_batch_inner(data, sweep, kern, true)
843}
844
845#[inline(always)]
846fn maaq_batch_inner(
847 data: &[f64],
848 sweep: &MaaqBatchRange,
849 kern: Kernel,
850 parallel: bool,
851) -> Result<MaaqBatchOutput, MaaqError> {
852 let combos = expand_grid(sweep);
853 if combos.is_empty() {
854 return Err(MaaqError::InvalidRange {
855 start: 0,
856 end: 0,
857 step: 0,
858 });
859 }
860 let first = data
861 .iter()
862 .position(|x| !x.is_nan())
863 .ok_or(MaaqError::AllValuesNaN)?;
864 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
865 if data.len() - first < max_p {
866 return Err(MaaqError::NotEnoughValidData {
867 needed: max_p,
868 valid: data.len() - first,
869 });
870 }
871 let rows = combos.len();
872 let cols = data.len();
873
874 if rows.checked_mul(cols).is_none() {
875 return Err(MaaqError::InvalidRange {
876 start: rows,
877 end: cols,
878 step: 0,
879 });
880 }
881
882 let warm: Vec<usize> = combos
883 .iter()
884 .map(|c| first + c.period.unwrap() - 1)
885 .collect();
886
887 let mut raw = make_uninit_matrix(rows, cols);
888 unsafe { init_matrix_prefixes(&mut raw, cols, &warm) };
889
890 let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
891 let period = combos[row].period.unwrap();
892 let fast_p = combos[row].fast_period.unwrap();
893 let slow_p = combos[row].slow_period.unwrap();
894
895 let out_row =
896 core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
897
898 match kern {
899 Kernel::Scalar => maaq_row_scalar(data, first, period, fast_p, slow_p, out_row),
900 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
901 Kernel::Avx2 => maaq_row_avx2(data, first, period, fast_p, slow_p, out_row),
902 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
903 Kernel::Avx512 => maaq_row_avx512(data, first, period, fast_p, slow_p, out_row),
904 _ => unreachable!(),
905 }
906 };
907
908 if parallel {
909 #[cfg(not(target_arch = "wasm32"))]
910 {
911 raw.par_chunks_mut(cols)
912 .enumerate()
913 .for_each(|(row, slice)| do_row(row, slice));
914 }
915
916 #[cfg(target_arch = "wasm32")]
917 {
918 for (row, slice) in raw.chunks_mut(cols).enumerate() {
919 do_row(row, slice);
920 }
921 }
922 } else {
923 for (row, slice) in raw.chunks_mut(cols).enumerate() {
924 do_row(row, slice);
925 }
926 }
927
928 let mut guard = core::mem::ManuallyDrop::new(raw);
929 let values: Vec<f64> = unsafe {
930 Vec::from_raw_parts(
931 guard.as_mut_ptr() as *mut f64,
932 guard.len(),
933 guard.capacity(),
934 )
935 };
936
937 Ok(MaaqBatchOutput {
938 values,
939 combos,
940 rows,
941 cols,
942 })
943}
944
945pub fn maaq_batch_inner_into(
946 data: &[f64],
947 sweep: &MaaqBatchRange,
948 kern: Kernel,
949 parallel: bool,
950 out: &mut [f64],
951) -> Result<Vec<MaaqParams>, MaaqError> {
952 let combos = expand_grid(sweep);
953 if combos.is_empty() {
954 return Err(MaaqError::InvalidRange {
955 start: 0,
956 end: 0,
957 step: 0,
958 });
959 }
960 let first = data
961 .iter()
962 .position(|x| !x.is_nan())
963 .ok_or(MaaqError::AllValuesNaN)?;
964 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
965 if data.len() - first < max_p {
966 return Err(MaaqError::NotEnoughValidData {
967 needed: max_p,
968 valid: data.len() - first,
969 });
970 }
971 let rows = combos.len();
972 let cols = data.len();
973
974 let expected = rows.checked_mul(cols).ok_or(MaaqError::InvalidRange {
975 start: rows,
976 end: cols,
977 step: 0,
978 })?;
979 if out.len() != expected {
980 return Err(MaaqError::OutputLengthMismatch {
981 expected,
982 got: out.len(),
983 });
984 }
985
986 let out_uninit = unsafe {
987 std::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
988 };
989
990 let warm: Vec<usize> = combos
991 .iter()
992 .map(|c| first + c.period.unwrap() - 1)
993 .collect();
994
995 unsafe { init_matrix_prefixes(out_uninit, cols, &warm) };
996
997 let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
998 let period = combos[row].period.unwrap();
999 let fast_p = combos[row].fast_period.unwrap();
1000 let slow_p = combos[row].slow_period.unwrap();
1001
1002 let out_row =
1003 core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
1004
1005 match kern {
1006 Kernel::Scalar | Kernel::ScalarBatch => {
1007 maaq_row_scalar(data, first, period, fast_p, slow_p, out_row)
1008 }
1009 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1010 Kernel::Avx2 | Kernel::Avx2Batch => {
1011 maaq_row_avx2(data, first, period, fast_p, slow_p, out_row)
1012 }
1013 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1014 Kernel::Avx512 | Kernel::Avx512Batch => {
1015 maaq_row_avx512(data, first, period, fast_p, slow_p, out_row)
1016 }
1017 _ => unreachable!(),
1018 }
1019 };
1020
1021 if parallel {
1022 #[cfg(not(target_arch = "wasm32"))]
1023 {
1024 out_uninit
1025 .par_chunks_mut(cols)
1026 .enumerate()
1027 .for_each(|(row, slice)| do_row(row, slice));
1028 }
1029
1030 #[cfg(target_arch = "wasm32")]
1031 {
1032 for (row, slice) in out_uninit.chunks_mut(cols).enumerate() {
1033 do_row(row, slice);
1034 }
1035 }
1036 } else {
1037 for (row, slice) in out_uninit.chunks_mut(cols).enumerate() {
1038 do_row(row, slice);
1039 }
1040 }
1041
1042 Ok(combos)
1043}
1044
1045#[inline(always)]
1046unsafe fn maaq_row_scalar(
1047 data: &[f64],
1048 first: usize,
1049 period: usize,
1050 fast_p: usize,
1051 slow_p: usize,
1052 out: &mut [f64],
1053) {
1054 maaq_scalar(data, period, fast_p, slow_p, first, out);
1055}
1056
/// Batch row adapter for the AVX2 kernel (note the different argument
/// order). The kernel's `Result` is discarded.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn maaq_row_avx2(
    data: &[f64],
    first: usize,
    period: usize,
    fast_p: usize,
    slow_p: usize,
    out: &mut [f64],
) {
    let _ = maaq_avx2(data, period, fast_p, slow_p, first, out);
}
1069
/// Batch row adapter for AVX-512 builds; it calls the AVX2 kernel directly.
/// The kernel's `Result` is discarded.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn maaq_row_avx512(
    data: &[f64],
    first: usize,
    period: usize,
    fast_p: usize,
    slow_p: usize,
    out: &mut [f64],
) {
    let _ = maaq_avx2(data, period, fast_p, slow_p, first, out);
}
1082
/// "Short" AVX-512 row variant; simply delegates to the scalar row adapter.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn maaq_row_avx512_short(
    data: &[f64],
    first: usize,
    period: usize,
    fast_p: usize,
    slow_p: usize,
    out: &mut [f64],
) {
    maaq_row_scalar(data, first, period, fast_p, slow_p, out);
}
1095
/// "Long" AVX-512 row variant; simply delegates to the scalar row adapter.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn maaq_row_avx512_long(
    data: &[f64],
    first: usize,
    period: usize,
    fast_p: usize,
    slow_p: usize,
    out: &mut [f64],
) {
    maaq_row_scalar(data, first, period, fast_p, slow_p, out);
}
1108
1109#[cfg(test)]
1110mod tests {
1111 use super::*;
1112 use crate::skip_if_unsupported;
1113 use crate::utilities::data_loader::read_candles_from_csv;
1114
1115 fn check_maaq_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1116 skip_if_unsupported!(kernel, test_name);
1117 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1118 let candles = read_candles_from_csv(file_path)?;
1119 let default_params = MaaqParams {
1120 period: None,
1121 fast_period: None,
1122 slow_period: None,
1123 };
1124 let input = MaaqInput::from_candles(&candles, "close", default_params);
1125 let output = maaq_with_kernel(&input, kernel)?;
1126 assert_eq!(output.values.len(), candles.close.len());
1127 Ok(())
1128 }
1129
1130 fn check_maaq_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1131 skip_if_unsupported!(kernel, test_name);
1132 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1133 let candles = read_candles_from_csv(file_path)?;
1134 let input = MaaqInput::from_candles(&candles, "close", MaaqParams::default());
1135 let result = maaq_with_kernel(&input, kernel)?;
1136 let expected_last_five = [
1137 59747.657115949725,
1138 59740.803138018055,
1139 59724.24153333905,
1140 59720.60576365108,
1141 59673.9954445178,
1142 ];
1143 let start = result.values.len().saturating_sub(5);
1144 for (i, &val) in result.values[start..].iter().enumerate() {
1145 let diff = (val - expected_last_five[i]).abs();
1146 assert!(
1147 diff < 1e-2,
1148 "[{}] MAAQ {:?} mismatch at idx {}: got {}, expected {}",
1149 test_name,
1150 kernel,
1151 i,
1152 val,
1153 expected_last_five[i]
1154 );
1155 }
1156 Ok(())
1157 }
1158
1159 fn check_maaq_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1160 skip_if_unsupported!(kernel, test_name);
1161 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1162 let candles = read_candles_from_csv(file_path)?;
1163 let input = MaaqInput::with_default_candles(&candles);
1164 match input.data {
1165 MaaqData::Candles { source, .. } => assert_eq!(source, "close"),
1166 _ => panic!("Expected MaaqData::Candles"),
1167 }
1168 let output = maaq_with_kernel(&input, kernel)?;
1169 assert_eq!(output.values.len(), candles.close.len());
1170 Ok(())
1171 }
1172
1173 fn check_maaq_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1174 skip_if_unsupported!(kernel, test_name);
1175 let input_data = [10.0, 20.0, 30.0];
1176 let params = MaaqParams {
1177 period: Some(0),
1178 fast_period: Some(0),
1179 slow_period: Some(0),
1180 };
1181 let input = MaaqInput::from_slice(&input_data, params);
1182 let res = maaq_with_kernel(&input, kernel);
1183 assert!(
1184 res.is_err(),
1185 "[{}] MAAQ should fail with zero periods",
1186 test_name
1187 );
1188 Ok(())
1189 }
1190
1191 fn check_maaq_period_exceeds_length(
1192 test_name: &str,
1193 kernel: Kernel,
1194 ) -> Result<(), Box<dyn Error>> {
1195 skip_if_unsupported!(kernel, test_name);
1196 let data_small = [10.0, 20.0, 30.0];
1197 let params = MaaqParams {
1198 period: Some(10),
1199 fast_period: Some(2),
1200 slow_period: Some(10),
1201 };
1202 let input = MaaqInput::from_slice(&data_small, params);
1203 let res = maaq_with_kernel(&input, kernel);
1204 assert!(
1205 res.is_err(),
1206 "[{}] MAAQ should fail with period exceeding length",
1207 test_name
1208 );
1209 Ok(())
1210 }
1211
1212 fn check_maaq_very_small_dataset(
1213 test_name: &str,
1214 kernel: Kernel,
1215 ) -> Result<(), Box<dyn Error>> {
1216 skip_if_unsupported!(kernel, test_name);
1217 let single_point = [42.0];
1218 let params = MaaqParams {
1219 period: Some(9),
1220 fast_period: Some(2),
1221 slow_period: Some(10),
1222 };
1223 let input = MaaqInput::from_slice(&single_point, params);
1224 let res = maaq_with_kernel(&input, kernel);
1225 assert!(
1226 res.is_err(),
1227 "[{}] MAAQ should fail with insufficient data",
1228 test_name
1229 );
1230 Ok(())
1231 }
1232
1233 fn check_maaq_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1234 skip_if_unsupported!(kernel, test_name);
1235 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1236 let candles = read_candles_from_csv(file_path)?;
1237 let first_params = MaaqParams {
1238 period: Some(11),
1239 fast_period: Some(2),
1240 slow_period: Some(30),
1241 };
1242 let first_input = MaaqInput::from_candles(&candles, "close", first_params);
1243 let first_result = maaq_with_kernel(&first_input, kernel)?;
1244 let second_params = MaaqParams {
1245 period: Some(5),
1246 fast_period: Some(2),
1247 slow_period: Some(10),
1248 };
1249 let second_input = MaaqInput::from_slice(&first_result.values, second_params);
1250 let second_result = maaq_with_kernel(&second_input, kernel)?;
1251 assert_eq!(second_result.values.len(), first_result.values.len());
1252 Ok(())
1253 }
1254
1255 fn check_maaq_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1256 skip_if_unsupported!(kernel, test_name);
1257 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1258 let candles = read_candles_from_csv(file_path)?;
1259 let input = MaaqInput::from_candles(
1260 &candles,
1261 "close",
1262 MaaqParams {
1263 period: Some(11),
1264 fast_period: Some(2),
1265 slow_period: Some(30),
1266 },
1267 );
1268 let res = maaq_with_kernel(&input, kernel)?;
1269 assert_eq!(res.values.len(), candles.close.len());
1270 if res.values.len() > 240 {
1271 for (i, &val) in res.values[240..].iter().enumerate() {
1272 assert!(
1273 !val.is_nan(),
1274 "[{}] Found unexpected NaN at out-index {}",
1275 test_name,
1276 240 + i
1277 );
1278 }
1279 }
1280 Ok(())
1281 }
1282
1283 fn check_maaq_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1284 skip_if_unsupported!(kernel, test_name);
1285 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1286 let candles = read_candles_from_csv(file_path)?;
1287 let period = 11;
1288 let fast_p = 2;
1289 let slow_p = 30;
1290 let input = MaaqInput::from_candles(
1291 &candles,
1292 "close",
1293 MaaqParams {
1294 period: Some(period),
1295 fast_period: Some(fast_p),
1296 slow_period: Some(slow_p),
1297 },
1298 );
1299 let batch_output = maaq_with_kernel(&input, kernel)?.values;
1300 let mut stream = MaaqStream::try_new(MaaqParams {
1301 period: Some(period),
1302 fast_period: Some(fast_p),
1303 slow_period: Some(slow_p),
1304 })?;
1305 let mut stream_values = Vec::with_capacity(candles.close.len());
1306 for &price in &candles.close {
1307 match stream.update(price) {
1308 Some(maaq_val) => stream_values.push(maaq_val),
1309 None => stream_values.push(f64::NAN),
1310 }
1311 }
1312 assert_eq!(batch_output.len(), stream_values.len());
1313
1314 for i in period..batch_output.len() {
1315 let b = batch_output[i];
1316 let s = stream_values[i];
1317 if b.is_nan() && s.is_nan() {
1318 continue;
1319 }
1320 let diff = (b - s).abs();
1321 assert!(
1322 diff < 1e-9,
1323 "[{}] MAAQ streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
1324 test_name,
1325 i,
1326 b,
1327 s,
1328 diff
1329 );
1330 }
1331 Ok(())
1332 }
1333
/// Expands every listed `check_*` function into `#[test]` wrappers, one per
/// kernel: `<name>_scalar_f64` always, plus `<name>_avx2_f64` and
/// `<name>_avx512_f64` when the `nightly-avx` feature is enabled on x86_64.
///
/// The `#[cfg]` attribute is attached to each AVX wrapper individually. An
/// attribute written in front of a `$( ... )*` repetition only applies to the
/// first item the repetition expands to, which would leave every other AVX
/// wrapper ungated.
///
/// NOTE(review): each wrapper discards the `Result` returned by the check
/// function (`let _ = ...`), so an `Err` return does not fail the test; only
/// a panic does. Confirm whether that is intentional.
macro_rules! generate_all_maaq_tests {
    ($($test_fn:ident),* $(,)?) => {
        paste::paste! {
            $(
                #[test]
                fn [<$test_fn _scalar_f64>]() {
                    let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
                }

                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$test_fn _avx2_f64>]() {
                    let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
                }

                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$test_fn _avx512_f64>]() {
                    let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
                }
            )*
        }
    }
}
1357
1358 #[cfg(debug_assertions)]
1359 fn check_maaq_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1360 skip_if_unsupported!(kernel, test_name);
1361
1362 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1363 let candles = read_candles_from_csv(file_path)?;
1364
1365 let test_cases = vec![
1366 MaaqParams::default(),
1367 MaaqParams {
1368 period: Some(5),
1369 fast_period: Some(2),
1370 slow_period: Some(10),
1371 },
1372 MaaqParams {
1373 period: Some(8),
1374 fast_period: Some(3),
1375 slow_period: Some(20),
1376 },
1377 MaaqParams {
1378 period: Some(11),
1379 fast_period: Some(2),
1380 slow_period: Some(30),
1381 },
1382 MaaqParams {
1383 period: Some(15),
1384 fast_period: Some(4),
1385 slow_period: Some(40),
1386 },
1387 MaaqParams {
1388 period: Some(20),
1389 fast_period: Some(5),
1390 slow_period: Some(50),
1391 },
1392 MaaqParams {
1393 period: Some(30),
1394 fast_period: Some(6),
1395 slow_period: Some(60),
1396 },
1397 MaaqParams {
1398 period: Some(10),
1399 fast_period: Some(8),
1400 slow_period: Some(30),
1401 },
1402 MaaqParams {
1403 period: Some(25),
1404 fast_period: Some(1),
1405 slow_period: Some(100),
1406 },
1407 ];
1408
1409 for params in test_cases {
1410 let input = MaaqInput::from_candles(&candles, "close", params.clone());
1411 let output = maaq_with_kernel(&input, kernel)?;
1412
1413 for (i, &val) in output.values.iter().enumerate() {
1414 if val.is_nan() {
1415 continue;
1416 }
1417
1418 let bits = val.to_bits();
1419
1420 if bits == 0x11111111_11111111 {
1421 panic!(
1422 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} with params period={:?}, fast_period={:?}, slow_period={:?}",
1423 test_name, val, bits, i, params.period, params.fast_period, params.slow_period
1424 );
1425 }
1426
1427 if bits == 0x22222222_22222222 {
1428 panic!(
1429 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} with params period={:?}, fast_period={:?}, slow_period={:?}",
1430 test_name, val, bits, i, params.period, params.fast_period, params.slow_period
1431 );
1432 }
1433
1434 if bits == 0x33333333_33333333 {
1435 panic!(
1436 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} with params period={:?}, fast_period={:?}, slow_period={:?}",
1437 test_name, val, bits, i, params.period, params.fast_period, params.slow_period
1438 );
1439 }
1440 }
1441 }
1442
1443 Ok(())
1444 }
1445
/// Build without `debug_assertions`: the poison-pattern check is a no-op that
/// always returns `Ok(())`, so the generated test wrappers still compile.
#[cfg(not(debug_assertions))]
fn check_maaq_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
    Ok(())
}
1450
// Instantiate the per-kernel `#[test]` wrappers for every single-series check.
generate_all_maaq_tests!(
    check_maaq_partial_params,
    check_maaq_accuracy,
    check_maaq_default_candles,
    check_maaq_zero_period,
    check_maaq_period_exceeds_length,
    check_maaq_very_small_dataset,
    check_maaq_reinput,
    check_maaq_nan_handling,
    check_maaq_streaming,
    check_maaq_no_poison
);
1463
1464 fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1465 skip_if_unsupported!(kernel, test);
1466 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1467 let c = read_candles_from_csv(file)?;
1468 let output = MaaqBatchBuilder::new()
1469 .kernel(kernel)
1470 .apply_candles(&c, "close")?;
1471 let def = MaaqParams::default();
1472 let row = output.values_for(&def).expect("default row missing");
1473 assert_eq!(row.len(), c.close.len());
1474 Ok(())
1475 }
1476
/// Expands one batch check function into `#[test]` wrappers:
/// `<name>_scalar` and `<name>_auto_detect` always, plus `<name>_avx2` and
/// `<name>_avx512` when `nightly-avx` is enabled on x86_64.
///
/// Each wrapper discards the `Result` returned by the check function.
macro_rules! gen_batch_tests {
    ($check:ident) => {
        paste::paste! {
            #[test]
            fn [<$check _scalar>]() {
                let _ = $check(stringify!([<$check _scalar>]), Kernel::ScalarBatch);
            }

            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            #[test]
            fn [<$check _avx2>]() {
                let _ = $check(stringify!([<$check _avx2>]), Kernel::Avx2Batch);
            }

            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            #[test]
            fn [<$check _avx512>]() {
                let _ = $check(stringify!([<$check _avx512>]), Kernel::Avx512Batch);
            }

            #[test]
            fn [<$check _auto_detect>]() {
                let _ = $check(stringify!([<$check _auto_detect>]), Kernel::Auto);
            }
        }
    };
}
1497
1498 #[cfg(debug_assertions)]
1499 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1500 skip_if_unsupported!(kernel, test);
1501
1502 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1503 let c = read_candles_from_csv(file)?;
1504
1505 let test_configs = vec![
1506 ((5, 10, 2), (2, 4, 1), (10, 30, 5)),
1507 ((10, 20, 5), (2, 6, 2), (20, 50, 10)),
1508 ((20, 30, 5), (4, 8, 2), (40, 80, 20)),
1509 ((10, 15, 5), (5, 10, 5), (30, 60, 30)),
1510 ((8, 12, 1), (2, 5, 1), (15, 25, 5)),
1511 ];
1512
1513 for (period_range, fast_range, slow_range) in test_configs {
1514 let output = MaaqBatchBuilder::new()
1515 .kernel(kernel)
1516 .period_range(period_range.0, period_range.1, period_range.2)
1517 .fast_period_range(fast_range.0, fast_range.1, fast_range.2)
1518 .slow_period_range(slow_range.0, slow_range.1, slow_range.2)
1519 .apply_candles(&c, "close")?;
1520
1521 for (idx, &val) in output.values.iter().enumerate() {
1522 if val.is_nan() {
1523 continue;
1524 }
1525
1526 let bits = val.to_bits();
1527 let row = idx / output.cols;
1528 let col = idx % output.cols;
1529 let params = &output.combos[row];
1530
1531 if bits == 0x11111111_11111111 {
1532 panic!(
1533 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at row {} col {} (params: period={:?}, fast_period={:?}, slow_period={:?})",
1534 test, val, bits, row, col, params.period, params.fast_period, params.slow_period
1535 );
1536 }
1537
1538 if bits == 0x22222222_22222222 {
1539 panic!(
1540 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at row {} col {} (params: period={:?}, fast_period={:?}, slow_period={:?})",
1541 test, val, bits, row, col, params.period, params.fast_period, params.slow_period
1542 );
1543 }
1544
1545 if bits == 0x33333333_33333333 {
1546 panic!(
1547 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at row {} col {} (params: period={:?}, fast_period={:?}, slow_period={:?})",
1548 test, val, bits, row, col, params.period, params.fast_period, params.slow_period
1549 );
1550 }
1551 }
1552 }
1553
1554 Ok(())
1555 }
1556
/// Build without `debug_assertions`: the batch poison-pattern check is a
/// no-op that always returns `Ok(())`.
#[cfg(not(debug_assertions))]
fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
    Ok(())
}
1561
// Instantiate the per-kernel `#[test]` wrappers for the batch checks.
gen_batch_tests!(check_batch_default_row);
gen_batch_tests!(check_batch_no_poison);
1564
/// Property-based checks for MAAQ, run through four independent proptest
/// runners:
///
/// 1. random series and parameters: output length, NaN warm-up, agreement
///    with the `Kernel::Scalar` result, loose value bounds, constant input;
/// 2. efficiency-ratio and smoothing-constant bounds recomputed from the
///    input, plus a tracking-error bound on strongly moving windows;
/// 3. a two-level step series: the last ten outputs must have closed most of
///    the gap to the final level;
/// 4. very short series: warm-up outputs compared against the input values.
///
/// Any failing case is returned as `Err` via `?` on the runner result.
///
/// NOTE(review): case 1 asserts that warm-up outputs are NaN, while case 4
/// asserts that warm-up outputs equal the input values. For `period == 2`
/// both cannot hold for index 0. The wrappers generated by
/// `generate_all_maaq_tests!` discard this function's `Result`, so such a
/// failure would not fail the test. Confirm which warm-up contract is intended.
#[cfg(feature = "proptest")]
#[allow(clippy::float_cmp)]
fn check_maaq_property(
    test_name: &str,
    kernel: Kernel,
) -> Result<(), Box<dyn std::error::Error>> {
    use proptest::prelude::*;
    skip_if_unsupported!(kernel, test_name);

    // Case 1 strategy: 20..200 finite values in (-1e6, 1e6), period 2..30,
    // fast period 1..10, slow period 10..50, filtered so that the period fits
    // in the data and fast < slow.
    let main_strat = (
        proptest::collection::vec(
            (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
            20..200,
        ),
        2usize..30,
        1usize..10,
        10usize..50,
    )
        .prop_filter("valid params", |(data, period, fast_p, slow_p)| {
            *period <= data.len() && *fast_p < *slow_p
        });

    proptest::test_runner::TestRunner::default().run(
        &main_strat,
        |(data, period, fast_p, slow_p)| {
            let params = MaaqParams {
                period: Some(period),
                fast_period: Some(fast_p),
                slow_period: Some(slow_p),
            };
            let input = MaaqInput::from_slice(&data, params.clone());

            // Kernel under test versus the scalar kernel as reference.
            let result = maaq_with_kernel(&input, kernel)?;
            let reference = maaq_with_kernel(&input, Kernel::Scalar)?;

            prop_assert_eq!(
                result.values.len(),
                data.len(),
                "Output length {} doesn't match input length {}",
                result.values.len(),
                data.len()
            );

            // Indices before `first_valid + period - 1` are expected to be NaN.
            let first_valid = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
            let warmup_end = first_valid + period - 1;
            for i in 0..warmup_end.min(data.len()) {
                prop_assert!(
                    result.values[i].is_nan(),
                    "Warmup value at {} should be NaN, got {}",
                    i,
                    result.values[i]
                );
            }

            for i in 0..result.values.len() {
                let y = result.values[i];
                let r = reference.values[i];

                // Non-finite values must match the reference bit for bit.
                if !y.is_finite() || !r.is_finite() {
                    prop_assert_eq!(
                        y.to_bits(),
                        r.to_bits(),
                        "NaN/Inf mismatch at {}: {} vs {}",
                        i,
                        y,
                        r
                    );
                    continue;
                }

                // Finite values pass with either a small absolute difference
                // or a small distance between the raw bit patterns.
                let ulp_diff = y.to_bits().abs_diff(r.to_bits());
                prop_assert!(
                    (y - r).abs() <= 1e-9 || ulp_diff <= 5,
                    "SIMD mismatch at {}: {} vs {} (ULP={})",
                    i,
                    y,
                    r,
                    ulp_diff
                );
            }

            // Finite outputs at index >= period must stay within the input's
            // min/max, widened by 2% of the input range on each side.
            let data_min = data.iter().copied().fold(f64::INFINITY, f64::min);
            let data_max = data.iter().copied().fold(f64::NEG_INFINITY, f64::max);
            let range = (data_max - data_min).abs();
            let tolerance = range * 0.02;

            for (i, &val) in result.values.iter().enumerate() {
                if val.is_finite() && i >= period {
                    prop_assert!(
                        val >= data_min - tolerance && val <= data_max + tolerance,
                        "Value {} at index {} outside bounds [{}, {}]",
                        val,
                        i,
                        data_min - tolerance,
                        data_max + tolerance
                    );
                }
            }

            // (Near-)constant input: outputs from `period` on must equal it.
            if data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10) && data.len() > period {
                let constant_val = data[0];
                for (i, &val) in result.values[period..].iter().enumerate() {
                    prop_assert!(
                        (val - constant_val).abs() < 1e-8,
                        "Constant data should produce constant output, got {} at index {}",
                        val,
                        i + period
                    );
                }
            }

            Ok(())
        },
    )?;

    // Case 2 strategy: 50..100 values in (-100, 100), period 5..15,
    // fast period 1..5, slow period 20..40.
    let maaq_strat = (
        proptest::collection::vec(
            (-100f64..100f64).prop_filter("finite", |x| x.is_finite()),
            50..100,
        ),
        5usize..15,
        1usize..5,
        20usize..40,
    )
        .prop_filter("valid maaq params", |(_data, period, fast_p, slow_p)| {
            *fast_p < *slow_p
        });

    proptest::test_runner::TestRunner::default().run(
        &maaq_strat,
        |(data, period, fast_p, slow_p)| {
            let params = MaaqParams {
                period: Some(period),
                fast_period: Some(fast_p),
                slow_period: Some(slow_p),
            };
            let input = MaaqInput::from_slice(&data, params);
            let result = maaq_with_kernel(&input, kernel)?;

            let fast_sc = 2.0 / (fast_p as f64 + 1.0);
            let slow_sc = 2.0 / (slow_p as f64 + 1.0);

            // The ratio and smoothing constant below are recomputed from the
            // input only; `result` is not consulted in this loop.
            for i in (period + 1)..data.len() {
                // signal: net move over the window; noise: sum of the
                // absolute one-step moves over the same window.
                let signal = (data[i] - data[i - period]).abs();
                let noise: f64 = (1..=period)
                    .map(|j| (data[i - j + 1] - data[i - j]).abs())
                    .sum();

                if noise > f64::EPSILON {
                    let er = signal / noise;
                    prop_assert!(
                        er >= 0.0 && er <= 1.0 + 1e-10,
                        "Efficiency ratio {} out of bounds at index {}",
                        er,
                        i
                    );

                    let sc = (er * fast_sc + slow_sc).powi(2);

                    let min_sc = slow_sc.powi(2);
                    let max_sc = (fast_sc + slow_sc).powi(2);
                    prop_assert!(
                        sc >= min_sc - 1e-10 && sc <= max_sc + 1e-10,
                        "Smoothing constant {} out of bounds [{}..{}] at index {}",
                        sc,
                        min_sc,
                        max_sc,
                        i
                    );
                }
            }

            if data.len() >= period * 3 {
                // Indices whose value moved by more than 10 over `period` steps.
                let trending_indices: Vec<usize> = (period..data.len())
                    .filter(|&i| {
                        let signal = (data[i] - data[i.saturating_sub(period)]).abs();
                        signal > 10.0
                    })
                    .collect();

                // Only the first five such indices are checked.
                for &i in trending_indices.iter().take(5) {
                    let tracking_error = (result.values[i] - data[i]).abs();
                    let price_range = data[i.saturating_sub(period)..=i]
                        .iter()
                        .fold((f64::INFINITY, f64::NEG_INFINITY), |(min, max), &v| {
                            (min.min(v), max.max(v))
                        });
                    let local_range = (price_range.1 - price_range.0).abs();

                    prop_assert!(
                        tracking_error <= local_range * 0.2 + 1.0,
                        "Poor tracking in trend at {}: error {} > 20% of range {}",
                        i,
                        tracking_error,
                        local_range
                    );
                }
            }

            Ok(())
        },
    )?;

    // Case 3 strategy: period 10..30, fast 2..5, slow 20..40 and two levels
    // in (-100, 100) that differ by more than 1.
    let step_strat = (
        10usize..30,
        2usize..5,
        20usize..40,
        -100f64..100f64,
        -100f64..100f64,
    )
        .prop_filter("different levels", |(_p, _f, _s, init, final_level)| {
            (init - final_level).abs() > 1.0
        });

    proptest::test_runner::TestRunner::default().run(
        &step_strat,
        |(period, fast_p, slow_p, initial, final_level)| {
            // 50 points at the initial level followed by 50 at the final level.
            let mut data = vec![initial; 50];
            data.extend(vec![final_level; 50]);

            let params = MaaqParams {
                period: Some(period),
                fast_period: Some(fast_p),
                slow_period: Some(slow_p),
            };
            let input = MaaqInput::from_slice(&data, params);
            let result = maaq_with_kernel(&input, kernel)?;

            // The last ten outputs must be within 30% of the step size of
            // the final level.
            let last_values = &result.values[90..];
            let convergence_target = final_level;

            for &val in last_values {
                let distance_to_target = (val - convergence_target).abs();
                let initial_distance = (initial - final_level).abs();

                prop_assert!(
                    distance_to_target < initial_distance * 0.3,
                    "Failed to converge: value {} too far from target {}",
                    val,
                    convergence_target
                );
            }

            Ok(())
        },
    )?;

    // Case 4 strategy: 1..5 values in (-100, 100), period 1..3, fast 1..3,
    // slow 3..6, filtered so that the period fits in the data.
    let small_strat = (
        proptest::collection::vec(
            (-100f64..100f64).prop_filter("finite", |x| x.is_finite()),
            1..5,
        ),
        1usize..3,
        1usize..3,
        3usize..6,
    )
        .prop_filter("valid small params", |(data, period, _fast_p, _slow_p)| {
            *period <= data.len()
        });

    proptest::test_runner::TestRunner::default().run(
        &small_strat,
        |(data, period, fast_p, slow_p)| {
            let params = MaaqParams {
                period: Some(period),
                fast_period: Some(fast_p),
                slow_period: Some(slow_p),
            };
            let input = MaaqInput::from_slice(&data, params);

            let result = maaq_with_kernel(&input, kernel)?;

            prop_assert_eq!(result.values.len(), data.len());

            // Expects the first `period` outputs to reproduce the input.
            // NOTE(review): this conflicts with the NaN warm-up asserted in
            // the first runner — confirm the intended behaviour.
            for i in 0..period.min(data.len()) {
                if data[i].is_finite() {
                    prop_assert!(
                        (result.values[i] - data[i]).abs() < 1e-10,
                        "Small data warmup mismatch at {}",
                        i
                    );
                }
            }

            Ok(())
        },
    )?;

    Ok(())
}
1855
// Per-kernel wrappers for the property test; only built with the `proptest` feature.
#[cfg(feature = "proptest")]
generate_all_maaq_tests!(check_maaq_property);
1858
/// Checks that `maaq_into` writes the same values into a caller-provided
/// buffer as `maaq` returns, on a synthetic series that starts with three
/// NaNs. Values match when both are NaN or differ by at most 1e-12.
#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
#[test]
fn test_maaq_into_matches_api() -> Result<(), Box<dyn Error>> {
    // Three leading NaNs followed by 256 deterministic samples.
    let mut data = vec![f64::NAN; 3];
    data.extend((0..256u32).map(|i| {
        (i as f64).sin() * 0.5 + (i as f64) * 0.1 + ((i % 7) as f64) * 0.01
    }));

    let input = MaaqInput::from_slice(&data, MaaqParams::default());
    let baseline = maaq(&input)?.values;

    let mut out = vec![0.0; data.len()];
    super::maaq_into(&input, &mut out)?;

    assert_eq!(baseline.len(), out.len());

    for (idx, (a, b)) in baseline.iter().zip(out.iter()).enumerate() {
        let both_nan = a.is_nan() && b.is_nan();
        let equal = both_nan || ((a - b).abs() <= 1e-12);
        assert!(equal, "Mismatch at {}: {} vs {}", idx, a, b);
    }

    Ok(())
}
1884}
1885
1886#[cfg(feature = "python")]
1887use crate::utilities::kernel_validation::validate_kernel;
1888#[cfg(feature = "python")]
1889use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1, PyReadonlyArray2};
1890#[cfg(feature = "python")]
1891use pyo3::exceptions::PyValueError;
1892#[cfg(feature = "python")]
1893use pyo3::prelude::*;
1894
/// Guard around a retained CUDA primary context.
///
/// `new` retains the primary context of a device through
/// `cuDevicePrimaryCtxRetain`; dropping the guard calls
/// `cuDevicePrimaryCtxRelease_v2` for the same device.
#[cfg(all(feature = "python", feature = "cuda"))]
pub struct PrimaryCtxGuard {
    // Device ordinal passed to the retain/release calls.
    dev: i32,
    // Raw context handle filled in by `cuDevicePrimaryCtxRetain`.
    ctx: cust::sys::CUcontext,
}
1900
#[cfg(all(feature = "python", feature = "cuda"))]
impl PrimaryCtxGuard {
    /// Retains the primary context of `device_id`.
    ///
    /// # Errors
    /// Returns `CudaError::UnknownError` whenever the driver call reports
    /// anything other than `CUDA_SUCCESS`; the original code is not kept.
    fn new(device_id: u32) -> Result<Self, cust::error::CudaError> {
        let dev = device_id as i32;
        let mut ctx: cust::sys::CUcontext = core::ptr::null_mut();
        // The driver writes the retained context handle through the pointer.
        let rc = unsafe { cust::sys::cuDevicePrimaryCtxRetain(&mut ctx as *mut _, dev) };
        if rc == cust::sys::CUresult::CUDA_SUCCESS {
            Ok(PrimaryCtxGuard { dev, ctx })
        } else {
            Err(cust::error::CudaError::UnknownError)
        }
    }

    /// Makes the retained context current via `cuCtxSetCurrent`.
    /// The driver's return code is ignored.
    #[inline]
    unsafe fn push_current(&self) {
        let _ = cust::sys::cuCtxSetCurrent(self.ctx);
    }
}
1919
#[cfg(all(feature = "python", feature = "cuda"))]
impl Drop for PrimaryCtxGuard {
    /// Releases the primary context retained in `new`. The driver's return
    /// code is ignored because `drop` cannot report errors.
    fn drop(&mut self) {
        unsafe {
            let _ = cust::sys::cuDevicePrimaryCtxRelease_v2(self.dev);
        }
    }
}
1928
// Python-visible handle to a 2-D f32 device buffer produced by the CUDA MAAQ
// entry points. Exposes `__cuda_array_interface__` and `__dlpack__`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", name = "DeviceArrayF32Maaq", unsendable)]
pub struct DeviceArrayF32MaaqPy {
    // `None` once the buffer has been handed out through `__dlpack__`.
    pub(crate) inner: Option<DeviceArrayF32Maaq>,
    // Device ordinal reported by `__dlpack_device__`.
    device_id: u32,
    // Primary-context guard; made current when this object is dropped.
    pc_guard: Option<PrimaryCtxGuard>,
}
1936
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl DeviceArrayF32MaaqPy {
    // Builds the `__cuda_array_interface__` dict (version 3) describing the
    // buffer as a row-major `(rows, cols)` array of little-endian f32.
    // Fails once the buffer has been exported through `__dlpack__`.
    #[getter]
    fn __cuda_array_interface__<'py>(
        &self,
        py: Python<'py>,
    ) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
        let Some(inner) = self.inner.as_ref() else {
            return Err(PyValueError::new_err(
                "buffer already exported via __dlpack__",
            ));
        };

        let elem_bytes = std::mem::size_of::<f32>();
        let iface = pyo3::types::PyDict::new(py);
        iface.set_item("shape", (inner.rows, inner.cols))?;
        iface.set_item("typestr", "<f4")?;
        iface.set_item("strides", (inner.cols * elem_bytes, elem_bytes))?;

        // An empty array reports a null data pointer.
        let ptr_val: usize = if inner.rows != 0 && inner.cols != 0 {
            inner.device_ptr() as usize
        } else {
            0
        };
        // The second tuple element is the interface's read-only flag.
        iface.set_item("data", (ptr_val, false))?;

        iface.set_item("version", 3)?;
        Ok(iface)
    }

    // Returns `(2, device_id)`; 2 is the device-type code this class reports
    // to DLPack consumers.
    fn __dlpack_device__(&self) -> PyResult<(i32, i32)> {
        Ok((2, self.device_id as i32))
    }

    // Exports the buffer as a DLPack capsule and gives up ownership of it:
    // afterwards `inner` is `None` and further exports fail.
    //
    // A `_dl_device` that parses as `(type, id)` and differs from this
    // buffer's device is rejected; `_copy` only selects the error message.
    // `_stream` is accepted but unused.
    #[pyo3(signature=(_stream=None, max_version=None, _dl_device=None, _copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        _stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        _dl_device: Option<pyo3::PyObject>,
        _copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;

        let (kdl, alloc_dev) = self.__dlpack_device__()?;

        // A requested device that cannot be parsed as `(i32, i32)` is ignored.
        let requested = _dl_device
            .as_ref()
            .and_then(|dev_obj| dev_obj.extract::<(i32, i32)>(py).ok());
        if let Some((dev_ty, dev_id)) = requested {
            if dev_ty != kdl || dev_id != alloc_dev {
                let wants_copy = _copy
                    .as_ref()
                    .and_then(|c| c.extract::<bool>(py).ok())
                    .unwrap_or(false);
                let reason = if wants_copy {
                    "device copy not implemented for __dlpack__"
                } else {
                    "dl_device mismatch for __dlpack__"
                };
                return Err(PyValueError::new_err(reason));
            }
        }
        let _ = _stream;

        let Some(inner) = self.inner.take() else {
            return Err(PyValueError::new_err(
                "buffer already exported via __dlpack__",
            ));
        };
        let (rows, cols, buf) = (inner.rows, inner.cols, inner.buf);

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
2018
#[cfg(all(feature = "python", feature = "cuda"))]
impl Drop for DeviceArrayF32MaaqPy {
    /// Makes the guarded primary context current (when a guard is present)
    /// before the struct's fields are dropped.
    fn drop(&mut self) {
        let Some(guard) = self.pc_guard.as_ref() else {
            return;
        };
        unsafe {
            guard.push_current();
        }
    }
}
2029
// Python entry point `maaq(data, period, fast_period, slow_period, kernel=None)`.
//
// Computes MAAQ over a 1-D float64 array and returns a new NumPy array.
// Non-contiguous input is copied into an owned array first. The computation
// runs with the GIL released; any error is raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "maaq")]
#[pyo3(signature = (data, period, fast_period, slow_period, kernel=None))]
pub fn maaq_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period: usize,
    fast_period: usize,
    slow_period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let kern = validate_kernel(kernel, false)?;
    let params = MaaqParams {
        period: Some(period),
        fast_period: Some(fast_period),
        slow_period: Some(slow_period),
    };

    // `owned` is only initialised (and only borrowed) on the copy path.
    let owned;
    let slice_in: &[f64] = match data.as_slice() {
        Ok(contiguous) => contiguous,
        Err(_) => {
            owned = data.as_array().to_owned();
            owned.as_slice().expect("owned array should be contiguous")
        }
    };

    let input = MaaqInput::from_slice(slice_in, params);
    let result_vec: Vec<f64> = py
        .allow_threads(|| maaq_with_kernel(&input, kern).map(|o| o.values))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok(result_vec.into_pyarray(py))
}
2064
// Python entry point
// `maaq_batch(data, period_range, fast_period_range, slow_period_range, kernel=None)`.
//
// Evaluates MAAQ for every parameter combination produced by `expand_grid`
// and returns a dict with:
//   "values"       - float64 array reshaped to (combinations, len(data))
//   "periods"      - uint64 array, one entry per combination
//   "fast_periods" - uint64 array, one entry per combination
//   "slow_periods" - uint64 array, one entry per combination
//
// Non-contiguous input is copied into an owned array first, matching `maaq`.
// The computation runs with the GIL released; errors are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "maaq_batch")]
#[pyo3(signature = (data, period_range, fast_period_range, slow_period_range, kernel=None))]
pub fn maaq_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    fast_period_range: (usize, usize, usize),
    slow_period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    // Accept strided views by falling back to an owned copy; `owned` is only
    // initialised (and only borrowed) on that path.
    let owned;
    let slice_in: &[f64] = match data.as_slice() {
        Ok(contiguous) => contiguous,
        Err(_) => {
            owned = data.as_array().to_owned();
            owned.as_slice().expect("owned array should be contiguous")
        }
    };
    let kern = validate_kernel(kernel, true)?;

    let sweep = MaaqBatchRange {
        period: period_range,
        fast_period: fast_period_range,
        slow_period: slow_period_range,
    };

    // The grid is expanded here only to size the output buffer; the
    // combinations reported to Python are the ones returned by the batch call.
    let combos = expand_grid(&sweep);
    let rows = combos.len();
    let cols = slice_in.len();

    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;
    // SAFETY: the array is allocated uninitialised and is only handed back to
    // Python after `maaq_batch_inner_into` has returned `Ok`; on error it is
    // dropped without being read.
    // NOTE(review): this relies on the batch routine writing every element.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    // SAFETY: `out_arr` was created just above and has not been shared, so
    // this is the only live view of its data.
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            let kernel = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k => k,
            };
            // Map the batch kernel onto the per-row kernel passed down.
            let simd = match kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                _ => kernel,
            };
            maaq_batch_inner_into(slice_in, &sweep, simd, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "fast_periods",
        combos
            .iter()
            .map(|p| p.fast_period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "slow_periods",
        combos
            .iter()
            .map(|p| p.slow_period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;

    Ok(dict)
}
2143
// Python entry point
// `maaq_cuda_batch_dev(data, period_range, fast_period_range, slow_period_range, device_id=0)`.
//
// Converts the float64 input to f32, runs the CUDA batch sweep with the GIL
// released and returns the device-resident result. Raises `ValueError` when
// CUDA is unavailable, the input is not contiguous, or the run fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "maaq_cuda_batch_dev")]
#[pyo3(signature = (data, period_range, fast_period_range, slow_period_range, device_id=0))]
pub fn maaq_cuda_batch_dev_py(
    py: Python<'_>,
    data: numpy::PyReadonlyArray1<'_, f64>,
    period_range: (usize, usize, usize),
    fast_period_range: (usize, usize, usize),
    slow_period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<DeviceArrayF32MaaqPy> {
    use numpy::PyArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let host_f64 = data.as_slice()?;
    let sweep = MaaqBatchRange {
        period: period_range,
        fast_period: fast_period_range,
        slow_period: slow_period_range,
    };
    // The CUDA path works on f32, so narrow the input once up front.
    let data_f32: Vec<f32> = host_f64.iter().copied().map(|v| v as f32).collect();

    let inner = py.allow_threads(|| {
        CudaMaaq::new(device_id)
            .map_err(|e| PyValueError::new_err(e.to_string()))
            .and_then(|cuda| {
                cuda.maaq_batch_dev_ex(&data_f32, &sweep)
                    .map_err(|e| PyValueError::new_err(e.to_string()))
            })
    })?;

    // Retain the device's primary context for the lifetime of the returned object.
    let device = device_id as u32;
    let pc = PrimaryCtxGuard::new(device).map_err(|e| PyValueError::new_err(e.to_string()))?;
    Ok(DeviceArrayF32MaaqPy {
        inner: Some(inner),
        device_id: device,
        pc_guard: Some(pc),
    })
}
2183
// Python entry point
// `maaq_cuda_many_series_one_param_dev(data_tm_f32, period, fast_period, slow_period, device_id=0)`.
//
// Runs one MAAQ parameter set over a 2-D f32 array on the GPU with the GIL
// released and returns the device-resident result. The array's first
// dimension is passed to the CUDA routine as `rows` and the second as `cols`.
// Raises `ValueError` when CUDA is unavailable, the input is not contiguous,
// or the run fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "maaq_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, fast_period, slow_period, device_id=0))]
pub fn maaq_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    period: usize,
    fast_period: usize,
    slow_period: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32MaaqPy> {
    use numpy::PyUntypedArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let flat_in = data_tm_f32.as_slice()?;
    let rows = data_tm_f32.shape()[0];
    let cols = data_tm_f32.shape()[1];
    let params = MaaqParams {
        period: Some(period),
        fast_period: Some(fast_period),
        slow_period: Some(slow_period),
    };

    let inner = py.allow_threads(|| {
        let cuda = CudaMaaq::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        // The last argument is a borrow of `params` (the source had it
        // mangled into a non-compiling `¶ms` token).
        cuda.maaq_multi_series_one_param_time_major_dev_ex(flat_in, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    // Retain the device's primary context for the lifetime of the returned object.
    let pc =
        PrimaryCtxGuard::new(device_id as u32).map_err(|e| PyValueError::new_err(e.to_string()))?;
    Ok(DeviceArrayF32MaaqPy {
        inner: Some(inner),
        device_id: device_id as u32,
        pc_guard: Some(pc),
    })
}
2224
// Python wrapper (`MaaqStream`) around the Rust `MaaqStream`.
#[cfg(feature = "python")]
#[pyclass(name = "MaaqStream")]
pub struct MaaqStreamPy {
    // Wrapped stream; every `update` call is forwarded to it.
    stream: MaaqStream,
}
2230
#[cfg(feature = "python")]
#[pymethods]
impl MaaqStreamPy {
    // Creates a stream for the given periods. Parameters rejected by
    // `MaaqStream::try_new` are raised as `ValueError`.
    #[new]
    pub fn new(period: usize, fast_period: usize, slow_period: usize) -> PyResult<Self> {
        MaaqStream::try_new(MaaqParams {
            period: Some(period),
            fast_period: Some(fast_period),
            slow_period: Some(slow_period),
        })
        .map(|stream| Self { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    // Feeds one value to the wrapped stream and returns whatever it returns
    // (`None` maps to Python `None`).
    pub fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
2250
2251#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2252use wasm_bindgen::prelude::*;
2253
2254#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2255use serde::{Deserialize, Serialize};
2256
// Batch configuration deserialised from the JS `config` object. Each tuple
// is copied unchanged into the matching `MaaqBatchRange` field.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct MaaqBatchConfig {
    pub period_range: (usize, usize, usize),
    pub fast_period_range: (usize, usize, usize),
    pub slow_period_range: (usize, usize, usize),
}
2264
// Batch result serialised back to JS by `maaq_batch`; the fields are copied
// one-to-one from the Rust batch output.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct MaaqBatchJsOutput {
    // Flattened output values for all combinations.
    pub values: Vec<f64>,
    // Parameter sets, one per combination.
    pub combos: Vec<MaaqParams>,
    pub rows: usize,
    pub cols: usize,
}
2273
// WASM entry point: computes MAAQ for `data` with `Kernel::Auto` and returns
// a new vector of the same length. Errors are returned as a `JsValue` string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn maaq_js(
    data: &[f64],
    period: usize,
    fast_period: usize,
    slow_period: usize,
) -> Result<Vec<f64>, JsValue> {
    let input = MaaqInput::from_slice(
        data,
        MaaqParams {
            period: Some(period),
            fast_period: Some(fast_period),
            slow_period: Some(slow_period),
        },
    );

    // Zero-filled buffer that `maaq_into_slice` writes into.
    let mut output = vec![0.0; data.len()];
    match maaq_into_slice(&mut output, &input, Kernel::Auto) {
        Ok(()) => Ok(output),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
2295
// WASM entry point: runs a MAAQ batch described by a JS config object and
// returns only the flattened values. A config that fails to deserialise is
// reported as "Invalid config: …".
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn maaq_batch_js(data: &[f64], config: JsValue) -> Result<Vec<f64>, JsValue> {
    let parsed: MaaqBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let range = MaaqBatchRange {
        period: parsed.period_range,
        fast_period: parsed.fast_period_range,
        slow_period: parsed.slow_period_range,
    };

    maaq_batch_with_kernel(data, &range, Kernel::Auto)
        .map(|output| output.values)
        .map_err(|e| JsValue::from_str(&e.to_string()))
}
2313
// WASM entry point exported as `maaq_batch`: runs a MAAQ batch described by
// a JS config object and returns values, parameter combinations and
// dimensions as one serialised object.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = maaq_batch)]
pub fn maaq_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let parsed: MaaqBatchConfig = match serde_wasm_bindgen::from_value(config) {
        Ok(cfg) => cfg,
        Err(e) => return Err(JsValue::from_str(&format!("Invalid config: {}", e))),
    };

    let range = MaaqBatchRange {
        period: parsed.period_range,
        fast_period: parsed.fast_period_range,
        slow_period: parsed.slow_period_range,
    };

    let batch = match maaq_batch_with_kernel(data, &range, Kernel::Auto) {
        Ok(out) => out,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };

    let js_output = MaaqBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };

    match serde_wasm_bindgen::to_value(&js_output) {
        Ok(value) => Ok(value),
        Err(e) => Err(JsValue::from_str(&format!("Serialization error: {}", e))),
    }
}
2339
/// Expands the three `(start, end, step)` sweeps into parameter combinations
/// and returns them flattened as `[period, fast_period, slow_period, ...]`,
/// three `f64` entries per combination.
///
/// A parameter that is `None` in a combination is reported as 11, 2 or 30
/// respectively.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn maaq_batch_metadata_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
    fast_period_start: usize,
    fast_period_end: usize,
    fast_period_step: usize,
    slow_period_start: usize,
    slow_period_end: usize,
    slow_period_step: usize,
) -> Vec<f64> {
    let grid = expand_grid(&MaaqBatchRange {
        period: (period_start, period_end, period_step),
        fast_period: (fast_period_start, fast_period_end, fast_period_step),
        slow_period: (slow_period_start, slow_period_end, slow_period_step),
    });

    // Three entries per combination, in period / fast / slow order.
    let mut flat = Vec::with_capacity(grid.len() * 3);
    for combo in grid {
        flat.extend_from_slice(&[
            combo.period.unwrap_or(11) as f64,
            combo.fast_period.unwrap_or(2) as f64,
            combo.slow_period.unwrap_or(30) as f64,
        ]);
    }
    flat
}
2370
2371#[inline]
2372pub fn maaq_into_slice(dst: &mut [f64], input: &MaaqInput, kern: Kernel) -> Result<(), MaaqError> {
2373 let (data, period, fast_p, slow_p, first, chosen) = maaq_prepare(input, kern)?;
2374
2375 if dst.len() != data.len() {
2376 return Err(MaaqError::OutputLengthMismatch {
2377 expected: data.len(),
2378 got: dst.len(),
2379 });
2380 }
2381
2382 maaq_compute_into(data, period, fast_p, slow_p, first, chosen, dst)?;
2383
2384 let warmup_end = first + period - 1;
2385 for v in &mut dst[..warmup_end] {
2386 *v = f64::NAN;
2387 }
2388
2389 Ok(())
2390}
2391
2392#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
2393#[inline]
2394pub fn maaq_into(input: &MaaqInput, out: &mut [f64]) -> Result<(), MaaqError> {
2395 maaq_into_slice(out, input, Kernel::Auto)
2396}
2397
/// Allocates room for `len` `f64` values and hands the raw pointer to the
/// caller.
///
/// The memory is not initialized and is intentionally not freed here; the
/// caller owns it and releases it by passing the same pointer and `len` to
/// `maaq_free`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn maaq_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` keeps the Vec's destructor from running, so the
    // allocation outlives this function.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
2406
/// Releases a buffer previously obtained from `maaq_alloc`.
///
/// A null `ptr` is ignored. Otherwise `ptr` and `len` must be exactly the
/// pointer returned by `maaq_alloc` and the `len` it was called with, and the
/// buffer must not be used or freed again afterwards.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn maaq_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }

    // SAFETY: the caller guarantees `ptr` came from `maaq_alloc(len)`, i.e. a
    // `Vec<f64>` allocation with capacity `len`. The Vec is rebuilt with
    // length 0 rather than `len`, because `Vec::from_raw_parts` requires the
    // first `length` elements to be initialized and the buffer may never have
    // been fully written. `f64` has no destructor, so dropping the Vec only
    // deallocates the `len`-capacity allocation.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
2416
/// Pointer-based WASM entry point: computes MAAQ over `len` values at `in_ptr`
/// and writes `len` results to `out_ptr`.
///
/// `in_ptr` and `out_ptr` may be the same pointer; in that case the result is
/// computed into a temporary buffer and copied back afterwards.
///
/// The caller must ensure both pointers reference at least `len` `f64` values.
///
/// # Errors
///
/// Returns a `JsValue` string when either pointer is null, when `period` is
/// zero or larger than `len`, when `fast_period` or `slow_period` is zero, or
/// when the computation itself fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn maaq_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
    fast_period: usize,
    slow_period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to maaq_into"));
    }
    if period == 0 || period > len {
        return Err(JsValue::from_str("Invalid period"));
    }
    if fast_period == 0 {
        return Err(JsValue::from_str("Invalid fast_period"));
    }
    if slow_period == 0 {
        return Err(JsValue::from_str("Invalid slow_period"));
    }

    // SAFETY: `in_ptr` is non-null and the caller guarantees it points to
    // `len` readable `f64` values.
    let series = unsafe { std::slice::from_raw_parts(in_ptr, len) };
    let input = MaaqInput::from_slice(
        series,
        MaaqParams {
            period: Some(period),
            fast_period: Some(fast_period),
            slow_period: Some(slow_period),
        },
    );

    if in_ptr == out_ptr {
        // In-place request: compute into scratch space first so the input is
        // not overwritten while it is still being read.
        let mut scratch = vec![0.0; len];
        maaq_into_slice(&mut scratch, &input, Kernel::Auto)
            .map_err(|err| JsValue::from_str(&err.to_string()))?;

        // SAFETY: `out_ptr` is non-null and the caller guarantees it points to
        // `len` writable `f64` values; `series` is not used past this point.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        out.copy_from_slice(&scratch);
    } else {
        // SAFETY: `out_ptr` is non-null and the caller guarantees it points to
        // `len` writable `f64` values distinct from the input.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        maaq_into_slice(out, &input, Kernel::Auto)
            .map_err(|err| JsValue::from_str(&err.to_string()))?;
    }

    Ok(())
}
2467
/// Pointer-based WASM batch entry point: evaluates every parameter combination
/// described by `config` over the `len` values at `in_ptr` and writes
/// `combos * len` results to `out_ptr`.
///
/// `config` must deserialize into [`MaaqBatchConfig`]. `in_ptr` and `out_ptr`
/// may be the same pointer; in that case results are computed into a temporary
/// buffer and copied back afterwards.
///
/// The caller must ensure `in_ptr` references at least `len` values and
/// `out_ptr` references at least `combos * len` values.
///
/// # Errors
///
/// Returns a `JsValue` string when either pointer is null, when `config` is
/// invalid, when the total output size does not fit in addressable memory, or
/// when the batch computation fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn maaq_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    config: JsValue,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to maaq_batch_into"));
    }

    let config: MaaqBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let range = MaaqBatchRange {
        period: config.period_range,
        fast_period: config.fast_period_range,
        slow_period: config.slow_period_range,
    };

    let combos = expand_grid(&range);

    // `usize` is 32 bits on wasm32, so an unchecked product can silently wrap
    // and yield a too-small output slice. Raw slices must also span at most
    // `isize::MAX` bytes, so both limits are enforced up front.
    let max_elems = isize::MAX as usize / std::mem::size_of::<f64>();
    let total_size = combos
        .len()
        .checked_mul(len)
        .filter(|&n| n <= max_elems)
        .ok_or_else(|| JsValue::from_str("maaq_batch_into: output size overflow"))?;

    // SAFETY: `in_ptr` is non-null and the caller guarantees it points to
    // `len` readable `f64` values.
    let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };

    if in_ptr == out_ptr {
        // In-place request: compute into scratch space first so the input is
        // not overwritten while it is still being read.
        let mut temp = vec![0.0; total_size];
        maaq_batch_inner_into(data, &range, Kernel::Auto, false, &mut temp)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

        // SAFETY: `out_ptr` is non-null, the caller guarantees it points to
        // `total_size` writable `f64` values, and `total_size` was bounded
        // above; `data` is not used past this point.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total_size) };
        out.copy_from_slice(&temp);
    } else {
        // SAFETY: `out_ptr` is non-null, the caller guarantees it points to
        // `total_size` writable `f64` values distinct from the input, and
        // `total_size` was bounded above.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total_size) };
        maaq_batch_inner_into(data, &range, Kernel::Auto, false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(())
}