1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{alloc_with_nan_prefix, init_matrix_prefixes, make_uninit_matrix};
4#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
5use core::arch::x86_64::*;
6#[cfg(not(target_arch = "wasm32"))]
7use rayon::prelude::*;
8use std::convert::AsRef;
9use thiserror::Error;
10
11impl<'a> AsRef<[f64]> for MediumAdInput<'a> {
12 #[inline(always)]
13 fn as_ref(&self) -> &[f64] {
14 match &self.data {
15 MediumAdData::Slice(slice) => slice,
16 MediumAdData::Candles { candles, source } => source_type(candles, source),
17 }
18 }
19}
20
/// Where the input series for the indicator comes from.
#[derive(Debug, Clone)]
pub enum MediumAdData<'a> {
    /// A candle set plus the name of the field to read (resolved via
    /// `source_type`).
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A caller-provided slice of values used directly.
    Slice(&'a [f64]),
}
29
/// Result of a single-series computation.
#[derive(Debug, Clone)]
pub struct MediumAdOutput {
    /// One value per input element; the leading warm-up region is NaN.
    pub values: Vec<f64>,
}
34
/// Tunable parameters of the indicator.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct MediumAdParams {
    /// Window length; `None` is treated as 5 by the consumers in this file.
    pub period: Option<usize>,
}
43
44impl Default for MediumAdParams {
45 fn default() -> Self {
46 Self { period: Some(5) }
47 }
48}
49
/// Bundles the data source with the parameters for one computation.
#[derive(Debug, Clone)]
pub struct MediumAdInput<'a> {
    /// Series to process (candle field or raw slice).
    pub data: MediumAdData<'a>,
    /// Indicator parameters.
    pub params: MediumAdParams,
}
55
56impl<'a> MediumAdInput<'a> {
57 #[inline]
58 pub fn from_candles(c: &'a Candles, s: &'a str, p: MediumAdParams) -> Self {
59 Self {
60 data: MediumAdData::Candles {
61 candles: c,
62 source: s,
63 },
64 params: p,
65 }
66 }
67 #[inline]
68 pub fn from_slice(sl: &'a [f64], p: MediumAdParams) -> Self {
69 Self {
70 data: MediumAdData::Slice(sl),
71 params: p,
72 }
73 }
74 #[inline]
75 pub fn with_default_candles(c: &'a Candles) -> Self {
76 Self::from_candles(c, "close", MediumAdParams::default())
77 }
78 #[inline]
79 pub fn get_period(&self) -> usize {
80 self.params.period.unwrap_or(5)
81 }
82}
83
/// Fluent builder that collects a period and a kernel choice before running
/// the indicator.
#[derive(Copy, Clone, Debug)]
pub struct MediumAdBuilder {
    // Window length; left as `None` until `period()` is called.
    period: Option<usize>,
    // Kernel forwarded to `medium_ad_with_kernel`.
    kernel: Kernel,
}
89
90impl Default for MediumAdBuilder {
91 fn default() -> Self {
92 Self {
93 period: None,
94 kernel: Kernel::Auto,
95 }
96 }
97}
98
99impl MediumAdBuilder {
100 #[inline(always)]
101 pub fn new() -> Self {
102 Self::default()
103 }
104 #[inline(always)]
105 pub fn period(mut self, n: usize) -> Self {
106 self.period = Some(n);
107 self
108 }
109 #[inline(always)]
110 pub fn kernel(mut self, k: Kernel) -> Self {
111 self.kernel = k;
112 self
113 }
114
115 #[inline(always)]
116 pub fn apply(self, c: &Candles) -> Result<MediumAdOutput, MediumAdError> {
117 let p = MediumAdParams {
118 period: self.period,
119 };
120 let i = MediumAdInput::from_candles(c, "close", p);
121 medium_ad_with_kernel(&i, self.kernel)
122 }
123
124 #[inline(always)]
125 pub fn apply_slice(self, d: &[f64]) -> Result<MediumAdOutput, MediumAdError> {
126 let p = MediumAdParams {
127 period: self.period,
128 };
129 let i = MediumAdInput::from_slice(d, p);
130 medium_ad_with_kernel(&i, self.kernel)
131 }
132
133 #[inline(always)]
134 pub fn into_stream(self) -> Result<MediumAdStream, MediumAdError> {
135 let p = MediumAdParams {
136 period: self.period,
137 };
138 MediumAdStream::try_new(p)
139 }
140}
141
/// Errors reported by the functions in this file.
#[derive(Debug, Error)]
pub enum MediumAdError {
    /// The input series has length zero.
    #[error("medium_ad: Empty input data slice.")]
    EmptyInputData,

    /// No non-NaN element was found in the input.
    #[error("medium_ad: All values are NaN.")]
    AllValuesNaN,

    /// The period is zero or larger than the input length.
    #[error("medium_ad: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },

    /// Fewer elements remain after the first non-NaN value than the period needs.
    #[error("medium_ad: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },

    /// A caller-supplied output buffer does not match the input length.
    #[error("medium_ad: Output length mismatch: expected {expected}, got {got}")]
    OutputLengthMismatch { expected: usize, got: usize },

    /// A batch sweep description could not be used (also returned when
    /// `rows * cols` overflows).
    #[error("medium_ad: Invalid range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: String,
        end: String,
        step: String,
    },

    /// A non-batch kernel was passed to a batch entry point.
    #[error("medium_ad: Invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(Kernel),
}
169
170#[inline]
171pub fn medium_ad(input: &MediumAdInput) -> Result<MediumAdOutput, MediumAdError> {
172 medium_ad_with_kernel(input, Kernel::Auto)
173}
174
175pub fn medium_ad_with_kernel(
176 input: &MediumAdInput,
177 kernel: Kernel,
178) -> Result<MediumAdOutput, MediumAdError> {
179 let data: &[f64] = match &input.data {
180 MediumAdData::Candles { candles, source } => source_type(candles, source),
181 MediumAdData::Slice(sl) => sl,
182 };
183
184 if data.is_empty() {
185 return Err(MediumAdError::EmptyInputData);
186 }
187
188 let first = data
189 .iter()
190 .position(|x| !x.is_nan())
191 .ok_or(MediumAdError::AllValuesNaN)?;
192 let len = data.len();
193 let period = input.get_period();
194
195 if period == 0 || period > len {
196 return Err(MediumAdError::InvalidPeriod {
197 period,
198 data_len: len,
199 });
200 }
201 if (len - first) < period {
202 return Err(MediumAdError::NotEnoughValidData {
203 needed: period,
204 valid: len - first,
205 });
206 }
207
208 let mut out = alloc_with_nan_prefix(len, first + period - 1);
209
210 let chosen = match kernel {
211 Kernel::Auto => Kernel::Scalar,
212 other => other,
213 };
214
215 unsafe {
216 match chosen {
217 Kernel::Scalar | Kernel::ScalarBatch => medium_ad_scalar(data, period, first, &mut out),
218 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
219 Kernel::Avx2 | Kernel::Avx2Batch => medium_ad_avx2(data, period, first, &mut out),
220
221 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
222 Kernel::Avx512 | Kernel::Avx512Batch => medium_ad_scalar(data, period, first, &mut out),
223 _ => unreachable!(),
224 }
225 }
226
227 Ok(MediumAdOutput { values: out })
228}
229
230#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
231#[inline]
232pub fn medium_ad_into(input: &MediumAdInput, out: &mut [f64]) -> Result<(), MediumAdError> {
233 let data: &[f64] = match &input.data {
234 MediumAdData::Candles { candles, source } => source_type(candles, source),
235 MediumAdData::Slice(sl) => sl,
236 };
237
238 let len = data.len();
239 if len == 0 {
240 return Err(MediumAdError::EmptyInputData);
241 }
242
243 let first = data
244 .iter()
245 .position(|x| !x.is_nan())
246 .ok_or(MediumAdError::AllValuesNaN)?;
247
248 let period = input.get_period();
249 if period == 0 || period > len {
250 return Err(MediumAdError::InvalidPeriod {
251 period,
252 data_len: len,
253 });
254 }
255 if (len - first) < period {
256 return Err(MediumAdError::NotEnoughValidData {
257 needed: period,
258 valid: len - first,
259 });
260 }
261
262 if out.len() != len {
263 return Err(MediumAdError::OutputLengthMismatch {
264 expected: len,
265 got: out.len(),
266 });
267 }
268
269 let warm = first + period - 1;
270 let warm_cap = warm.min(len);
271 for v in &mut out[..warm_cap] {
272 *v = f64::from_bits(0x7ff8_0000_0000_0000);
273 }
274
275 let chosen = Kernel::Scalar;
276
277 unsafe {
278 match chosen {
279 Kernel::Scalar | Kernel::ScalarBatch => medium_ad_scalar(data, period, first, out),
280 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
281 Kernel::Avx2 | Kernel::Avx2Batch => medium_ad_avx2(data, period, first, out),
282
283 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
284 Kernel::Avx512 | Kernel::Avx512Batch => medium_ad_scalar(data, period, first, out),
285 _ => unreachable!(),
286 }
287 }
288
289 Ok(())
290}
291
/// Scalar kernel: rolling median absolute deviation over windows of `period`.
///
/// For every index `i >= first_valid + period - 1` the window
/// `data[i + 1 - period ..= i]` is examined:
/// * if it contains a NaN, `out[i]` is set to NaN;
/// * otherwise `out[i]` is the median of `|x - median(window)|`.
///
/// With `period == 1` the result is `0.0` for every non-NaN element from
/// `first_valid` on, and NaN for NaN elements. Elements of `out` before the
/// first computed index are left untouched (callers pre-fill them with NaN).
///
/// This is a safe function, so it must not misbehave on odd arguments:
/// only indices valid for both `data` and `out` are processed, and
/// `period == 0` is a no-op.
#[inline]
pub fn medium_ad_scalar(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    use core::cmp::Ordering;

    /// Median of `buf`, which is reordered in place. `mid` must be
    /// `buf.len() / 2`. Incomparable pairs are treated as equal.
    #[inline(always)]
    fn median_from(buf: &mut [f64], mid: usize) -> f64 {
        buf.select_nth_unstable_by(mid, |a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
        let upper = buf[mid];
        if buf.len() % 2 == 1 {
            return upper;
        }
        // Even length: after the selection everything left of `mid` is
        // <= buf[mid], so the lower middle element is the max of that part.
        let mut lo_max = f64::NEG_INFINITY;
        for &v in &buf[..mid] {
            if v > lo_max {
                lo_max = v;
            }
        }
        0.5 * (lo_max + upper)
    }

    // Never touch an index that is not valid for both slices.
    let len = data.len().min(out.len());
    if period == 0 || first_valid >= len {
        return;
    }

    if period == 1 {
        // A single-element window deviates from its own median by zero.
        for (dst, &v) in out[first_valid..len].iter_mut().zip(&data[first_valid..len]) {
            *dst = if v.is_nan() { f64::NAN } else { 0.0 };
        }
        return;
    }

    // Zero-initialised scratch buffer; the old `set_len` on uninitialised
    // capacity violated `Vec::set_len`'s contract.
    let mut buf = vec![0.0_f64; period];
    let mid = period >> 1;
    let warm = first_valid.saturating_add(period - 1);

    for i in warm..len {
        let window = &data[i + 1 - period..=i];
        if window.iter().any(|v| v.is_nan()) {
            out[i] = f64::NAN;
            continue;
        }

        buf.copy_from_slice(window);
        let med = median_from(&mut buf, mid);

        // Reuse the buffer for the absolute deviations; element order does
        // not matter for the second median.
        for v in buf.iter_mut() {
            *v = (*v - med).abs();
        }
        out[i] = median_from(&mut buf, mid);
    }
}
401
/// AVX-512 kernel: same contract as the scalar kernel, with the window copy,
/// NaN scan and absolute-deviation pass done 8 (then 4) lanes at a time.
///
/// Only indices valid for both `data` and `out` are processed, and
/// `period == 0` is a no-op.
///
/// NOTE(review): the function carries no `#[target_feature]` attribute, so
/// the caller must only invoke it on a CPU that supports the AVX-512
/// instructions used here (and AVX for the 4-lane tail).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn medium_ad_avx512(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    use core::cmp::Ordering;

    /// Median of `buf` (reordered in place); `mid` must be `buf.len() / 2`.
    #[inline(always)]
    fn median_from(buf: &mut [f64], mid: usize) -> f64 {
        buf.select_nth_unstable_by(mid, |a, b| {
            if *a < *b {
                Ordering::Less
            } else if *a > *b {
                Ordering::Greater
            } else {
                Ordering::Equal
            }
        });
        if (buf.len() & 1) == 1 {
            buf[mid]
        } else {
            let mut lo_max = f64::NEG_INFINITY;
            for &v in &buf[..mid] {
                if v > lo_max {
                    lo_max = v;
                }
            }
            0.5 * (lo_max + buf[mid])
        }
    }

    if period == 0 {
        return;
    }
    // Never touch an index that is not valid for both slices.
    let len = data.len().min(out.len());
    let mid = period >> 1;
    let warm = first_valid.saturating_add(period - 1);
    // Zero-initialised scratch instead of `set_len` over uninitialised capacity.
    let mut buf = vec![0.0_f64; period];

    // SAFETY: for every `i` in `warm..len` we have `i < data.len()`,
    // `i < out.len()` and `i + 1 >= period`, so `start = i + 1 - period` is
    // valid and every load reads inside `data[start..=i]`. `buf` holds exactly
    // `period` elements and each vector access stays below `period`.
    unsafe {
        let sign512 = _mm512_set1_pd(-0.0);
        let sign256 = _mm256_set1_pd(-0.0);

        for i in warm..len {
            let start = i + 1 - period;

            let mut has_nan = false;
            let mut k = 0usize;
            while k + 8 <= period {
                let v = _mm512_loadu_pd(data.as_ptr().add(start + k));
                _mm512_storeu_pd(buf.as_mut_ptr().add(k), v);
                // Predicate 0x03 is "unordered": set for lanes holding NaN.
                let m = _mm512_cmp_pd_mask(v, v, 0x03);
                if m != 0 {
                    has_nan = true;
                }
                k += 8;
            }

            while k + 4 <= period {
                let v = _mm256_loadu_pd(data.as_ptr().add(start + k));
                _mm256_storeu_pd(buf.as_mut_ptr().add(k), v);
                let nan_mask = _mm256_cmp_pd(v, v, 0x03);
                if _mm256_movemask_pd(nan_mask) != 0 {
                    has_nan = true;
                }
                k += 4;
            }
            while k < period {
                let val = *data.get_unchecked(start + k);
                *buf.get_unchecked_mut(k) = val;
                has_nan |= val != val;
                k += 1;
            }
            if has_nan {
                *out.get_unchecked_mut(i) = f64::NAN;
                continue;
            }

            let med = median_from(&mut buf, mid);

            // |x - med| for every element; andnot with -0.0 clears the sign bit.
            let mv = _mm512_set1_pd(med);
            let mv4 = _mm256_set1_pd(med);
            let mut k = 0usize;
            while k + 8 <= period {
                let x = _mm512_loadu_pd(buf.as_ptr().add(k));
                let d = _mm512_sub_pd(x, mv);
                let ad = _mm512_andnot_pd(sign512, d);
                _mm512_storeu_pd(buf.as_mut_ptr().add(k), ad);
                k += 8;
            }
            while k + 4 <= period {
                let x = _mm256_loadu_pd(buf.as_ptr().add(k));
                let d = _mm256_sub_pd(x, mv4);
                let ad = _mm256_andnot_pd(sign256, d);
                _mm256_storeu_pd(buf.as_mut_ptr().add(k), ad);
                k += 4;
            }
            while k < period {
                let t = *buf.get_unchecked(k) - med;
                *buf.get_unchecked_mut(k) = f64::from_bits(t.to_bits() & 0x7FFF_FFFF_FFFF_FFFF);
                k += 1;
            }

            *out.get_unchecked_mut(i) = median_from(&mut buf, mid);
        }
    }
}
504
/// AVX2 kernel: same contract as the scalar kernel, with the window copy, NaN
/// scan and absolute-deviation pass done 4 lanes at a time.
///
/// Only indices valid for both `data` and `out` are processed, and
/// `period == 0` is a no-op.
///
/// NOTE(review): the function carries no `#[target_feature]` attribute, so
/// the caller must only invoke it on a CPU that supports the 256-bit
/// instructions used here.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn medium_ad_avx2(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    use core::cmp::Ordering;

    /// Median of `buf` (reordered in place); `mid` must be `buf.len() / 2`.
    #[inline(always)]
    fn median_from(buf: &mut [f64], mid: usize) -> f64 {
        buf.select_nth_unstable_by(mid, |a, b| {
            if *a < *b {
                Ordering::Less
            } else if *a > *b {
                Ordering::Greater
            } else {
                Ordering::Equal
            }
        });
        if (buf.len() & 1) == 1 {
            buf[mid]
        } else {
            let mut lo_max = f64::NEG_INFINITY;
            for &v in &buf[..mid] {
                if v > lo_max {
                    lo_max = v;
                }
            }
            0.5 * (lo_max + buf[mid])
        }
    }

    if period == 0 {
        return;
    }
    // Never touch an index that is not valid for both slices.
    let len = data.len().min(out.len());
    let mid = period >> 1;
    let warm = first_valid.saturating_add(period - 1);
    // Zero-initialised scratch instead of `set_len` over uninitialised capacity.
    let mut buf = vec![0.0_f64; period];

    // SAFETY: for every `i` in `warm..len` we have `i < data.len()`,
    // `i < out.len()` and `i + 1 >= period`, so `start = i + 1 - period` is
    // valid and every load reads inside `data[start..=i]`. `buf` holds exactly
    // `period` elements and each vector access stays below `period`.
    unsafe {
        let sign_mask = _mm256_set1_pd(-0.0);

        for i in warm..len {
            let start = i + 1 - period;

            let mut has_nan = false;
            let mut k = 0usize;
            while k + 4 <= period {
                let v = _mm256_loadu_pd(data.as_ptr().add(start + k));
                _mm256_storeu_pd(buf.as_mut_ptr().add(k), v);

                // Predicate 0x03 is "unordered": set for lanes holding NaN.
                let nan_mask = _mm256_cmp_pd(v, v, 0x03);
                if _mm256_movemask_pd(nan_mask) != 0 {
                    has_nan = true;
                }
                k += 4;
            }
            while k < period {
                let val = *data.get_unchecked(start + k);
                *buf.get_unchecked_mut(k) = val;
                has_nan |= val != val;
                k += 1;
            }
            if has_nan {
                *out.get_unchecked_mut(i) = f64::NAN;
                continue;
            }

            let med = median_from(&mut buf, mid);

            // |x - med| for every element; andnot with -0.0 clears the sign bit.
            let mv = _mm256_set1_pd(med);
            let mut k = 0usize;
            while k + 4 <= period {
                let x = _mm256_loadu_pd(buf.as_ptr().add(k));
                let d = _mm256_sub_pd(x, mv);
                let ad = _mm256_andnot_pd(sign_mask, d);
                _mm256_storeu_pd(buf.as_mut_ptr().add(k), ad);
                k += 4;
            }
            while k < period {
                let t = *buf.get_unchecked(k) - med;
                *buf.get_unchecked_mut(k) = f64::from_bits(t.to_bits() & 0x7FFF_FFFF_FFFF_FFFF);
                k += 1;
            }

            *out.get_unchecked_mut(i) = median_from(&mut buf, mid);
        }
    }
}
589
/// "Short period" AVX-512 entry point; currently forwards to the scalar kernel.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn medium_ad_avx512_short(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    // The callee is a safe function, so no `unsafe` block is required.
    medium_ad_scalar(data, period, first_valid, out);
}
595
/// "Long period" AVX-512 entry point; currently forwards to the scalar kernel.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn medium_ad_avx512_long(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    // The callee is a safe function, so no `unsafe` block is required.
    medium_ad_scalar(data, period, first_valid, out);
}
601
602#[inline(always)]
603pub fn medium_ad_batch_with_kernel(
604 data: &[f64],
605 sweep: &MediumAdBatchRange,
606 k: Kernel,
607) -> Result<MediumAdBatchOutput, MediumAdError> {
608 let kernel = match k {
609 Kernel::Auto => Kernel::ScalarBatch,
610 other if other.is_batch() => other,
611 other => return Err(MediumAdError::InvalidKernelForBatch(other)),
612 };
613
614 let simd = match kernel {
615 Kernel::Avx512Batch => Kernel::Avx512,
616 Kernel::Avx2Batch => Kernel::Avx2,
617 Kernel::ScalarBatch => Kernel::Scalar,
618 _ => unreachable!(),
619 };
620 medium_ad_batch_par_slice(data, sweep, simd)
621}
622
/// Describes the set of periods evaluated by a batch run.
#[derive(Clone, Debug)]
pub struct MediumAdBatchRange {
    /// `(start, end, step)`; expanded into individual periods by `expand_grid`.
    /// A step of 0 or `start == end` yields the single period `start`.
    pub period: (usize, usize, usize),
}
627
628impl Default for MediumAdBatchRange {
629 fn default() -> Self {
630 Self {
631 period: (5, 254, 1),
632 }
633 }
634}
635
/// Fluent builder for batch (parameter sweep) runs.
#[derive(Clone, Debug, Default)]
pub struct MediumAdBatchBuilder {
    // Period sweep to evaluate.
    range: MediumAdBatchRange,
    // Kernel forwarded to `medium_ad_batch_with_kernel`.
    kernel: Kernel,
}
641
642impl MediumAdBatchBuilder {
643 pub fn new() -> Self {
644 Self::default()
645 }
646 pub fn kernel(mut self, k: Kernel) -> Self {
647 self.kernel = k;
648 self
649 }
650
651 #[inline]
652 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
653 self.range.period = (start, end, step);
654 self
655 }
656 #[inline]
657 pub fn period_static(mut self, p: usize) -> Self {
658 self.range.period = (p, p, 0);
659 self
660 }
661
662 pub fn apply_slice(self, data: &[f64]) -> Result<MediumAdBatchOutput, MediumAdError> {
663 medium_ad_batch_with_kernel(data, &self.range, self.kernel)
664 }
665
666 pub fn with_default_slice(
667 data: &[f64],
668 k: Kernel,
669 ) -> Result<MediumAdBatchOutput, MediumAdError> {
670 MediumAdBatchBuilder::new().kernel(k).apply_slice(data)
671 }
672
673 pub fn apply_candles(
674 self,
675 c: &Candles,
676 src: &str,
677 ) -> Result<MediumAdBatchOutput, MediumAdError> {
678 let slice = source_type(c, src);
679 self.apply_slice(slice)
680 }
681
682 pub fn with_default_candles(c: &Candles) -> Result<MediumAdBatchOutput, MediumAdError> {
683 MediumAdBatchBuilder::new()
684 .kernel(Kernel::Auto)
685 .apply_candles(c, "close")
686 }
687}
688
/// Result of a batch run: a row-major `rows x cols` matrix plus the parameter
/// set that produced each row.
#[derive(Clone, Debug)]
pub struct MediumAdBatchOutput {
    /// Flattened matrix; row `r` occupies `values[r * cols .. (r + 1) * cols]`.
    pub values: Vec<f64>,
    /// Parameters for each row, in row order.
    pub combos: Vec<MediumAdParams>,
    /// Number of parameter combinations.
    pub rows: usize,
    /// Length of the input series.
    pub cols: usize,
}
696impl MediumAdBatchOutput {
697 pub fn row_for_params(&self, p: &MediumAdParams) -> Option<usize> {
698 self.combos
699 .iter()
700 .position(|c| c.period.unwrap_or(5) == p.period.unwrap_or(5))
701 }
702
703 pub fn values_for(&self, p: &MediumAdParams) -> Option<&[f64]> {
704 self.row_for_params(p).and_then(|row| {
705 let start = row.checked_mul(self.cols)?;
706 let end = start.checked_add(self.cols)?;
707 self.values.get(start..end)
708 })
709 }
710}
711
712#[inline(always)]
713fn expand_grid(r: &MediumAdBatchRange) -> Result<Vec<MediumAdParams>, MediumAdError> {
714 fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, MediumAdError> {
715 if step == 0 || start == end {
716 return Ok(vec![start]);
717 }
718 if start < end {
719 return Ok((start..=end).step_by(step.max(1)).collect());
720 }
721 let mut v = Vec::new();
722 let mut x = start as isize;
723 let end_i = end as isize;
724 let st = (step as isize).max(1);
725 while x >= end_i {
726 v.push(x as usize);
727 x = x.saturating_sub(st);
728 if x < 0 {
729 break;
730 }
731 }
732 if v.is_empty() {
733 return Err(MediumAdError::InvalidRange {
734 start: start.to_string(),
735 end: end.to_string(),
736 step: step.to_string(),
737 });
738 }
739 Ok(v)
740 }
741
742 let periods = axis_usize(r.period)?;
743 if periods.is_empty() {
744 return Err(MediumAdError::InvalidRange {
745 start: r.period.0.to_string(),
746 end: r.period.1.to_string(),
747 step: r.period.2.to_string(),
748 });
749 }
750
751 let mut out = Vec::with_capacity(periods.len());
752 for &p in &periods {
753 out.push(MediumAdParams { period: Some(p) });
754 }
755 Ok(out)
756}
757
758#[inline(always)]
759pub fn medium_ad_batch_slice(
760 data: &[f64],
761 sweep: &MediumAdBatchRange,
762 kern: Kernel,
763) -> Result<MediumAdBatchOutput, MediumAdError> {
764 medium_ad_batch_inner(data, sweep, kern, false)
765}
766
767#[inline(always)]
768pub fn medium_ad_batch_par_slice(
769 data: &[f64],
770 sweep: &MediumAdBatchRange,
771 kern: Kernel,
772) -> Result<MediumAdBatchOutput, MediumAdError> {
773 medium_ad_batch_inner(data, sweep, kern, true)
774}
775
776#[inline(always)]
777fn medium_ad_batch_inner(
778 data: &[f64],
779 sweep: &MediumAdBatchRange,
780 kern: Kernel,
781 parallel: bool,
782) -> Result<MediumAdBatchOutput, MediumAdError> {
783 let combos = expand_grid(sweep)?;
784
785 let cols = data.len();
786 if cols == 0 {
787 return Err(MediumAdError::AllValuesNaN);
788 }
789
790 let first = data
791 .iter()
792 .position(|x| !x.is_nan())
793 .ok_or(MediumAdError::AllValuesNaN)?;
794 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
795 if cols - first < max_p {
796 return Err(MediumAdError::NotEnoughValidData {
797 needed: max_p,
798 valid: cols - first,
799 });
800 }
801
802 let rows = combos.len();
803
804 let _total_elems = rows.checked_mul(cols).ok_or(MediumAdError::InvalidRange {
805 start: sweep.period.0.to_string(),
806 end: sweep.period.1.to_string(),
807 step: sweep.period.2.to_string(),
808 })?;
809 let mut buf_mu = make_uninit_matrix(rows, cols);
810 let warm: Vec<usize> = combos
811 .iter()
812 .map(|c| first + c.period.unwrap() - 1)
813 .collect();
814 init_matrix_prefixes(&mut buf_mu, cols, &warm);
815
816 let out_mu = buf_mu.as_mut_slice();
817
818 let do_row = |row: usize, dst_mu: &mut [core::mem::MaybeUninit<f64>]| {
819 let dst = unsafe {
820 core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len())
821 };
822 let period = combos[row].period.unwrap();
823
824 unsafe {
825 match kern {
826 Kernel::Scalar | Kernel::Auto => medium_ad_row_scalar(data, first, period, dst),
827 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
828 Kernel::Avx2 => medium_ad_row_avx2(data, first, period, dst),
829 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
830 Kernel::Avx512 => medium_ad_row_avx512(data, first, period, dst),
831 _ => unreachable!(),
832 }
833 }
834 };
835
836 if parallel {
837 #[cfg(not(target_arch = "wasm32"))]
838 {
839 use rayon::prelude::*;
840 out_mu
841 .par_chunks_mut(cols)
842 .enumerate()
843 .for_each(|(row, slice_mu)| do_row(row, slice_mu));
844 }
845 #[cfg(target_arch = "wasm32")]
846 {
847 for (row, slice_mu) in out_mu.chunks_mut(cols).enumerate() {
848 do_row(row, slice_mu);
849 }
850 }
851 } else {
852 for (row, slice_mu) in out_mu.chunks_mut(cols).enumerate() {
853 do_row(row, slice_mu);
854 }
855 }
856
857 let mut guard = core::mem::ManuallyDrop::new(buf_mu);
858 let values = unsafe {
859 Vec::from_raw_parts(
860 guard.as_mut_ptr() as *mut f64,
861 guard.len(),
862 guard.capacity(),
863 )
864 };
865
866 Ok(MediumAdBatchOutput {
867 values,
868 combos,
869 rows,
870 cols,
871 })
872}
873
874#[inline(always)]
875fn medium_ad_batch_inner_into(
876 data: &[f64],
877 sweep: &MediumAdBatchRange,
878 kern: Kernel,
879 parallel: bool,
880 out: &mut [f64],
881) -> Result<Vec<MediumAdParams>, MediumAdError> {
882 let combos = expand_grid(sweep)?;
883
884 let first = data
885 .iter()
886 .position(|x| !x.is_nan())
887 .ok_or(MediumAdError::AllValuesNaN)?;
888 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
889 if data.len() - first < max_p {
890 return Err(MediumAdError::NotEnoughValidData {
891 needed: max_p,
892 valid: data.len() - first,
893 });
894 }
895
896 let cols = data.len();
897
898 let _total_elems = combos
899 .len()
900 .checked_mul(cols)
901 .ok_or(MediumAdError::InvalidRange {
902 start: sweep.period.0.to_string(),
903 end: sweep.period.1.to_string(),
904 step: sweep.period.2.to_string(),
905 })?;
906
907 for (row, combo) in combos.iter().enumerate() {
908 let warmup = first + combo.period.unwrap() - 1;
909 let row_start = match row.checked_mul(cols) {
910 Some(v) => v,
911 None => {
912 return Err(MediumAdError::InvalidRange {
913 start: sweep.period.0.to_string(),
914 end: sweep.period.1.to_string(),
915 step: sweep.period.2.to_string(),
916 })
917 }
918 };
919 for i in 0..warmup.min(cols) {
920 out[row_start + i] = f64::NAN;
921 }
922 }
923
924 let do_row = |row: usize, out_row: &mut [f64]| unsafe {
925 let period = combos[row].period.unwrap();
926 match kern {
927 Kernel::Scalar => medium_ad_row_scalar(data, first, period, out_row),
928 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
929 Kernel::Avx2 => medium_ad_row_avx2(data, first, period, out_row),
930 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
931 Kernel::Avx512 => medium_ad_row_avx512(data, first, period, out_row),
932 _ => unreachable!(),
933 }
934 };
935
936 if parallel {
937 #[cfg(not(target_arch = "wasm32"))]
938 {
939 out.par_chunks_mut(cols)
940 .enumerate()
941 .for_each(|(row, slice)| do_row(row, slice));
942 }
943
944 #[cfg(target_arch = "wasm32")]
945 {
946 for (row, slice) in out.chunks_mut(cols).enumerate() {
947 do_row(row, slice);
948 }
949 }
950 } else {
951 for (row, slice) in out.chunks_mut(cols).enumerate() {
952 do_row(row, slice);
953 }
954 }
955
956 Ok(combos)
957}
958
/// Scalar row kernel used by the batch paths.
///
/// For every index `i >= first + period - 1` the window
/// `data[i + 1 - period ..= i]` is examined: a window containing a NaN yields
/// NaN, otherwise `out[i]` is the median of `|x - median(window)|`. With
/// `period == 1` the result is `0.0` for non-NaN elements from `first` on and
/// NaN for NaN elements. Elements of `out` before the first computed index
/// are left untouched.
///
/// # Safety
/// The body now uses only bounds-checked accesses and processes only indices
/// valid for both `data` and `out`, so there is no obligation beyond passing
/// valid slices. The `unsafe` qualifier is kept so existing call sites
/// compile unchanged.
#[inline(always)]
unsafe fn medium_ad_row_scalar(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    use core::cmp::Ordering;

    /// Median of `buf`, which is reordered in place. `mid` must be
    /// `buf.len() / 2`. Incomparable pairs are treated as equal.
    #[inline(always)]
    fn median_from(buf: &mut [f64], mid: usize) -> f64 {
        buf.select_nth_unstable_by(mid, |a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
        let upper = buf[mid];
        if buf.len() % 2 == 1 {
            return upper;
        }
        // Even length: the lower middle element is the max of the left part.
        let mut lo_max = f64::NEG_INFINITY;
        for &v in &buf[..mid] {
            if v > lo_max {
                lo_max = v;
            }
        }
        0.5 * (lo_max + upper)
    }

    // Never touch an index that is not valid for both slices.
    let len = data.len().min(out.len());
    if period == 0 || first >= len {
        return;
    }

    if period == 1 {
        // A single-element window deviates from its own median by zero.
        for (dst, &v) in out[first..len].iter_mut().zip(&data[first..len]) {
            *dst = if v.is_nan() { f64::NAN } else { 0.0 };
        }
        return;
    }

    // Zero-initialised scratch instead of `set_len` over uninitialised capacity.
    let mut buf = vec![0.0_f64; period];
    let mid = period >> 1;
    let warm = first.saturating_add(period - 1);

    for i in warm..len {
        let window = &data[i + 1 - period..=i];
        if window.iter().any(|v| v.is_nan()) {
            out[i] = f64::NAN;
            continue;
        }

        buf.copy_from_slice(window);
        let med = median_from(&mut buf, mid);

        // Reuse the buffer for the absolute deviations.
        for v in buf.iter_mut() {
            *v = (*v - med).abs();
        }
        out[i] = median_from(&mut buf, mid);
    }
}
1060
/// Row adapter for the AVX2 kernel; only reorders the arguments.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn medium_ad_row_avx2(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    // The kernel takes `(data, period, first_valid, out)`.
    medium_ad_avx2(data, period, first, out);
}
1066
/// Row adapter for AVX-512: periods up to 32 go to the "short" variant,
/// longer ones to the "long" variant.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn medium_ad_row_avx512(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    match period {
        0..=32 => medium_ad_row_avx512_short(data, first, period, out),
        _ => medium_ad_row_avx512_long(data, first, period, out),
    }
}
1076
/// Short-period AVX-512 row adapter; forwards to `medium_ad_avx512`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn medium_ad_row_avx512_short(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    // The kernel takes `(data, period, first_valid, out)`.
    medium_ad_avx512(data, period, first, out);
}
1082
/// Long-period AVX-512 row adapter; forwards to `medium_ad_avx512`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn medium_ad_row_avx512_long(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    // The kernel takes `(data, period, first_valid, out)`.
    medium_ad_avx512(data, period, first, out);
}
1088
/// Streaming calculator: values are fed one at a time through `update`.
#[derive(Debug, Clone)]
pub struct MediumAdStream {
    // Window length (never zero; checked in `try_new`).
    period: usize,

    // Circular buffer of the last `period` inputs; a slot is `None` while it
    // has not been written yet or when the value written there was NaN.
    ring: Vec<Option<Entry>>,
    // Index of the slot the next value will overwrite.
    head: usize,
    // Set once `head` has wrapped around to 0 for the first time.
    filled: bool,

    // Order-statistic structure holding the non-NaN entries currently in `ring`.
    os: OrderStatTree,
    // Id given to the next inserted entry (incremented with wrapping).
    next_id: u64,
}
1100
/// A value stored in the stream window, tagged with an insertion id.
#[derive(Copy, Clone, Debug)]
struct Entry {
    // The sample value (NaN values are never stored by `MediumAdStream::update`).
    val: f64,
    // Insertion id; `less` uses it to order entries with equal values.
    id: u64,
}
1106
impl MediumAdStream {
    /// Creates a stream for `params.period` (5 when unset).
    ///
    /// # Errors
    /// `InvalidPeriod` (with `data_len: 0`) when the period is zero.
    pub fn try_new(params: MediumAdParams) -> Result<Self, MediumAdError> {
        let period = params.period.unwrap_or(5);
        if period == 0 {
            return Err(MediumAdError::InvalidPeriod {
                period,
                data_len: 0,
            });
        }
        Ok(Self {
            period,
            ring: vec![None; period],
            head: 0,
            filled: false,
            os: OrderStatTree::new(),
            next_id: 1,
        })
    }

    /// Pushes `value` into the window and returns the current result.
    ///
    /// Returns `None` until the ring has wrapped once, and also whenever the
    /// tree holds fewer than `period` entries (a NaN input leaves its slot
    /// empty). For `period == 1` the result is always `Some(0.0)` once a
    /// non-NaN value is present.
    #[inline(always)]
    pub fn update(&mut self, value: f64) -> Option<f64> {
        // Evict whatever currently occupies the slot about to be overwritten.
        if let Some(old) = self.ring[self.head] {
            self.os.remove(old);
        }

        // NOTE(review): `_inserted` is computed but never read.
        let _inserted = if value.is_nan() {
            self.ring[self.head] = None;
            false
        } else {
            let e = Entry {
                val: value,
                id: self.next_id,
            };
            self.next_id = self.next_id.wrapping_add(1);
            self.os.insert(e);
            self.ring[self.head] = Some(e);
            true
        };

        self.head = (self.head + 1) % self.period;
        if !self.filled && self.head == 0 {
            self.filled = true;
        }

        if !self.filled || self.os.len() != self.period {
            return None;
        }

        if self.period == 1 {
            return Some(0.0);
        }

        // Median from the order statistics: the middle element for odd `n`,
        // the mean of the two middle elements for even `n`.
        let n = self.period;
        let left_sz = n >> 1;
        let median = if (n & 1) == 1 {
            self.os.kth(left_sz).val
        } else {
            let lo = self.os.kth(left_sz - 1).val;
            let hi = self.os.kth(left_sz).val;
            0.5 * (lo + hi)
        };

        Some(self.mad_from_tree(median))
    }

    /// `median - x` for the entry at rank `left_sz - 1 - i`, i.e. walking
    /// downwards from the rank just below `left_sz`.
    ///
    /// NOTE(review): not called by any code visible in this part of the file.
    #[inline(always)]
    fn ldist(&self, i: usize, median: f64, left_sz: usize) -> f64 {
        let idx = left_sz - 1 - i;
        let x = self.os.kth(idx).val;

        median - x
    }

    /// `x - median` for the entry at rank `left_sz + j`, i.e. walking upwards
    /// from rank `left_sz`.
    ///
    /// NOTE(review): not called by any code visible in this part of the file.
    #[inline(always)]
    fn rdist(&self, j: usize, median: f64, left_sz: usize) -> f64 {
        let idx = left_sz + j;
        let x = self.os.kth(idx).val;
        x - median
    }

    /// Binary-searches for a split `i + j == k` between the `ldist` and
    /// `rdist` sequences such that `ldist(i - 1) <= rdist(j)` and
    /// `rdist(j - 1) <= ldist(i)`, and returns the larger of `ldist(i - 1)`
    /// and `rdist(j - 1)`. Out-of-range neighbours are replaced by
    /// negative or positive infinity.
    ///
    /// NOTE(review): this looks like a k-th order statistic over the merged
    /// deviation sequences - confirm the intended indexing of `k` before use.
    /// It is not called by any code visible in this part of the file.
    #[inline(always)]
    fn kth_absdev_union(&self, k: usize, median: f64, left_sz: usize) -> f64 {
        let right_sz = self.period - left_sz;

        // `i` counts elements taken from the left side; the bounds keep
        // `j = k - i` within `0..=right_sz`.
        let mut lo = if k > right_sz { k - right_sz } else { 0 };
        let mut hi = k.min(left_sz);

        while lo <= hi {
            let i = (lo + hi) >> 1;
            let j = k - i;

            let l_im1 = if i == 0 {
                f64::NEG_INFINITY
            } else {
                self.ldist(i - 1, median, left_sz)
            };
            let l_i = if i == left_sz {
                f64::INFINITY
            } else {
                self.ldist(i, median, left_sz)
            };

            let r_jm1 = if j == 0 {
                f64::NEG_INFINITY
            } else {
                self.rdist(j - 1, median, left_sz)
            };
            let r_j = if j == right_sz {
                f64::INFINITY
            } else {
                self.rdist(j, median, left_sz)
            };

            if l_im1 <= r_j && r_jm1 <= l_i {
                return if l_im1 > r_jm1 { l_im1 } else { r_jm1 };
            } else if l_im1 > r_j {
                // `i == 0` cannot reach this branch: `l_im1` is then -inf.
                hi = i - 1;
            } else {
                lo = i + 1;
            }
        }
        debug_assert!(false, "kth_absdev_union: unreachable");
        0.0
    }

    /// Median of `|x - median|` over all entries in the tree.
    ///
    /// Reads the `period` entries by rank into a temporary buffer and selects
    /// the middle element (mean of the two middle elements for even `n`).
    #[inline(always)]
    fn mad_from_tree(&self, median: f64) -> f64 {
        let n = self.period;
        let mid = n >> 1;
        let mut buf = Vec::with_capacity(n);
        for i in 0..n {
            let x = self.os.kth(i).val;
            buf.push((x - median).abs());
        }
        use core::cmp::Ordering;
        buf.select_nth_unstable_by(mid, |a, b| {
            if *a < *b {
                Ordering::Less
            } else if *a > *b {
                Ordering::Greater
            } else {
                Ordering::Equal
            }
        });
        if (n & 1) == 1 {
            buf[mid]
        } else {
            // After the selection everything left of `mid` is <= buf[mid];
            // the lower middle element is the maximum of that part.
            let mut lo_max = f64::NEG_INFINITY;
            for &v in &buf[..mid] {
                if v > lo_max {
                    lo_max = v;
                }
            }
            0.5 * (lo_max + buf[mid])
        }
    }
}
1264
1265#[derive(Default, Debug, Clone)]
1266struct OrderStatTree {
1267 root: Link,
1268}
1269
1270type Link = Option<Box<Node>>;
1271
1272#[derive(Debug, Clone)]
1273struct Node {
1274 key: Entry,
1275 prio: u64,
1276 size: usize,
1277 left: Link,
1278 right: Link,
1279}
1280
1281impl OrderStatTree {
1282 #[inline(always)]
1283 fn new() -> Self {
1284 Self { root: None }
1285 }
1286
1287 #[inline(always)]
1288 fn len(&self) -> usize {
1289 size_of(&self.root)
1290 }
1291
1292 #[inline(always)]
1293 fn insert(&mut self, key: Entry) {
1294 let prio = treap_priority(key);
1295 self.root = insert_rec(self.root.take(), key, prio);
1296 }
1297
1298 #[inline(always)]
1299 fn remove(&mut self, key: Entry) {
1300 self.root = remove_rec(self.root.take(), key);
1301 }
1302
1303 #[inline(always)]
1304 fn kth(&self, k: usize) -> Entry {
1305 kth_rec(&self.root, k)
1306 }
1307}
1308
1309#[inline(always)]
1310fn size_of(n: &Link) -> usize {
1311 n.as_ref().map_or(0, |b| b.size)
1312}
1313
1314#[inline(always)]
1315fn upd(node: &mut Box<Node>) {
1316 node.size = 1 + size_of(&node.left) + size_of(&node.right);
1317}
1318
1319#[inline(always)]
1320fn less(a: Entry, b: Entry) -> bool {
1321 if a.val < b.val {
1322 true
1323 } else if a.val > b.val {
1324 false
1325 } else {
1326 a.id < b.id
1327 }
1328}
1329
1330#[inline(always)]
1331fn rotate_left(mut x: Box<Node>) -> Box<Node> {
1332 let mut y = x.right.take().expect("rotate_left requires right child");
1333 x.right = y.left.take();
1334 upd(&mut x);
1335 y.left = Some(x);
1336 upd(&mut y);
1337 y
1338}
1339
1340#[inline(always)]
1341fn rotate_right(mut y: Box<Node>) -> Box<Node> {
1342 let mut x = y.left.take().expect("rotate_right requires left child");
1343 y.left = x.right.take();
1344 upd(&mut y);
1345 x.right = Some(y);
1346 upd(&mut x);
1347 x
1348}
1349
1350fn insert_rec(node: Link, key: Entry, prio: u64) -> Link {
1351 match node {
1352 None => Some(Box::new(Node {
1353 key,
1354 prio,
1355 size: 1,
1356 left: None,
1357 right: None,
1358 })),
1359 Some(mut n) => {
1360 if less(key, n.key) {
1361 n.left = insert_rec(n.left.take(), key, prio);
1362 if n.left.as_ref().unwrap().prio > n.prio {
1363 n = rotate_right(n);
1364 }
1365 } else {
1366 n.right = insert_rec(n.right.take(), key, prio);
1367 if n.right.as_ref().unwrap().prio > n.prio {
1368 n = rotate_left(n);
1369 }
1370 }
1371 upd(&mut n);
1372 Some(n)
1373 }
1374 }
1375}
1376
1377fn remove_rec(node: Link, key: Entry) -> Link {
1378 match node {
1379 None => None,
1380 Some(mut n) => {
1381 if n.key.id == key.id {
1382 return match (n.left.take(), n.right.take()) {
1383 (None, r) => r,
1384 (l, None) => l,
1385 (Some(lc), Some(rc)) => {
1386 let (mut n2, left_is_higher) = if lc.prio > rc.prio {
1387 let mut n2 = Box::new(Node {
1388 key: n.key,
1389 prio: n.prio,
1390 size: 0,
1391 left: Some(lc),
1392 right: Some(rc),
1393 });
1394 n2 = rotate_right(n2);
1395 (n2, true)
1396 } else {
1397 let mut n2 = Box::new(Node {
1398 key: n.key,
1399 prio: n.prio,
1400 size: 0,
1401 left: Some(lc),
1402 right: Some(rc),
1403 });
1404 n2 = rotate_left(n2);
1405 (n2, false)
1406 };
1407 if left_is_higher {
1408 n2.right = remove_rec(n2.right.take(), key);
1409 } else {
1410 n2.left = remove_rec(n2.left.take(), key);
1411 }
1412 upd(&mut n2);
1413 Some(n2)
1414 }
1415 };
1416 }
1417 if less(key, n.key) {
1418 n.left = remove_rec(n.left.take(), key);
1419 } else {
1420 n.right = remove_rec(n.right.take(), key);
1421 }
1422 upd(&mut n);
1423 Some(n)
1424 }
1425 }
1426}
1427
1428fn kth_rec(node: &Link, mut k: usize) -> Entry {
1429 let n = node.as_ref().expect("kth_rec on empty tree");
1430 let ls = size_of(&n.left);
1431 if k < ls {
1432 kth_rec(&n.left, k)
1433 } else if k == ls {
1434 n.key
1435 } else {
1436 k -= ls + 1;
1437 kth_rec(&n.right, k)
1438 }
1439}
1440
1441#[inline(always)]
1442fn treap_priority(e: Entry) -> u64 {
1443 let mut z = e.id ^ e.val.to_bits();
1444 z = z.wrapping_add(0x9E37_79B9_7F4A_7C15);
1445 z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
1446 z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
1447 z ^ (z >> 31)
1448}
1449
1450#[cfg(test)]
1451mod tests {
1452 use super::*;
1453 use crate::skip_if_unsupported;
1454 use crate::utilities::data_loader::read_candles_from_csv;
1455 #[cfg(feature = "proptest")]
1456 use proptest::prelude::*;
1457
1458 fn check_medium_ad_partial_params(
1459 test_name: &str,
1460 kernel: Kernel,
1461 ) -> Result<(), Box<dyn std::error::Error>> {
1462 skip_if_unsupported!(kernel, test_name);
1463 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1464 let candles = read_candles_from_csv(file_path)?;
1465
1466 let default_params = MediumAdParams { period: None };
1467 let input = MediumAdInput::from_candles(&candles, "close", default_params);
1468 let output = medium_ad_with_kernel(&input, kernel)?;
1469 assert_eq!(output.values.len(), candles.close.len());
1470 Ok(())
1471 }
1472
1473 fn check_medium_ad_accuracy(
1474 test_name: &str,
1475 kernel: Kernel,
1476 ) -> Result<(), Box<dyn std::error::Error>> {
1477 skip_if_unsupported!(kernel, test_name);
1478 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1479 let candles = read_candles_from_csv(file_path)?;
1480
1481 let params = MediumAdParams { period: Some(5) };
1482 let input = MediumAdInput::from_candles(&candles, "hl2", params);
1483 let result = medium_ad_with_kernel(&input, kernel)?;
1484 let expected_last_five = [220.0, 78.5, 126.5, 48.0, 28.5];
1485 let start = result.values.len().saturating_sub(5);
1486 for (i, &val) in result.values[start..].iter().enumerate() {
1487 let diff = (val - expected_last_five[i]).abs();
1488 assert!(
1489 diff < 1e-1,
1490 "[{}] MEDIUM_AD {:?} mismatch at idx {}: got {}, expected {}",
1491 test_name,
1492 kernel,
1493 i,
1494 val,
1495 expected_last_five[i]
1496 );
1497 }
1498 Ok(())
1499 }
1500
1501 fn check_medium_ad_default_candles(
1502 test_name: &str,
1503 kernel: Kernel,
1504 ) -> Result<(), Box<dyn std::error::Error>> {
1505 skip_if_unsupported!(kernel, test_name);
1506 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1507 let candles = read_candles_from_csv(file_path)?;
1508 let input = MediumAdInput::with_default_candles(&candles);
1509 match input.data {
1510 MediumAdData::Candles { source, .. } => assert_eq!(source, "close"),
1511 _ => panic!("Expected MediumAdData::Candles"),
1512 }
1513 let output = medium_ad_with_kernel(&input, kernel)?;
1514 assert_eq!(output.values.len(), candles.close.len());
1515 Ok(())
1516 }
1517
1518 fn check_medium_ad_zero_period(
1519 test_name: &str,
1520 kernel: Kernel,
1521 ) -> Result<(), Box<dyn std::error::Error>> {
1522 skip_if_unsupported!(kernel, test_name);
1523 let input_data = [10.0, 20.0, 30.0];
1524 let params = MediumAdParams { period: Some(0) };
1525 let input = MediumAdInput::from_slice(&input_data, params);
1526 let res = medium_ad_with_kernel(&input, kernel);
1527 assert!(
1528 res.is_err(),
1529 "[{}] MEDIUM_AD should fail with zero period",
1530 test_name
1531 );
1532 Ok(())
1533 }
1534
1535 fn check_medium_ad_period_exceeds_length(
1536 test_name: &str,
1537 kernel: Kernel,
1538 ) -> Result<(), Box<dyn std::error::Error>> {
1539 skip_if_unsupported!(kernel, test_name);
1540 let data_small = [10.0, 20.0, 30.0];
1541 let params = MediumAdParams { period: Some(10) };
1542 let input = MediumAdInput::from_slice(&data_small, params);
1543 let res = medium_ad_with_kernel(&input, kernel);
1544 assert!(
1545 res.is_err(),
1546 "[{}] MEDIUM_AD should fail with period exceeding length",
1547 test_name
1548 );
1549 Ok(())
1550 }
1551
1552 fn check_medium_ad_very_small_dataset(
1553 test_name: &str,
1554 kernel: Kernel,
1555 ) -> Result<(), Box<dyn std::error::Error>> {
1556 skip_if_unsupported!(kernel, test_name);
1557 let single_point = [42.0];
1558 let params = MediumAdParams { period: Some(5) };
1559 let input = MediumAdInput::from_slice(&single_point, params);
1560 let res = medium_ad_with_kernel(&input, kernel);
1561 assert!(
1562 res.is_err(),
1563 "[{}] MEDIUM_AD should fail with insufficient data",
1564 test_name
1565 );
1566 Ok(())
1567 }
1568
1569 fn check_medium_ad_reinput(
1570 test_name: &str,
1571 kernel: Kernel,
1572 ) -> Result<(), Box<dyn std::error::Error>> {
1573 skip_if_unsupported!(kernel, test_name);
1574 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1575 let candles = read_candles_from_csv(file_path)?;
1576
1577 let first_params = MediumAdParams { period: Some(5) };
1578 let first_input = MediumAdInput::from_candles(&candles, "close", first_params);
1579 let first_result = medium_ad_with_kernel(&first_input, kernel)?;
1580
1581 let second_params = MediumAdParams { period: Some(5) };
1582 let second_input = MediumAdInput::from_slice(&first_result.values, second_params);
1583 let second_result = medium_ad_with_kernel(&second_input, kernel)?;
1584
1585 assert_eq!(second_result.values.len(), first_result.values.len());
1586 Ok(())
1587 }
1588
1589 fn check_medium_ad_nan_handling(
1590 test_name: &str,
1591 kernel: Kernel,
1592 ) -> Result<(), Box<dyn std::error::Error>> {
1593 skip_if_unsupported!(kernel, test_name);
1594 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1595 let candles = read_candles_from_csv(file_path)?;
1596 let input =
1597 MediumAdInput::from_candles(&candles, "close", MediumAdParams { period: Some(5) });
1598 let res = medium_ad_with_kernel(&input, kernel)?;
1599 assert_eq!(res.values.len(), candles.close.len());
1600 if res.values.len() > 60 {
1601 for (i, &val) in res.values[60..].iter().enumerate() {
1602 assert!(
1603 !val.is_nan(),
1604 "[{}] Found unexpected NaN at out-index {}",
1605 test_name,
1606 60 + i
1607 );
1608 }
1609 }
1610 Ok(())
1611 }
1612
1613 fn check_medium_ad_streaming(
1614 test_name: &str,
1615 kernel: Kernel,
1616 ) -> Result<(), Box<dyn std::error::Error>> {
1617 skip_if_unsupported!(kernel, test_name);
1618
1619 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1620 let candles = read_candles_from_csv(file_path)?;
1621
1622 let period = 5;
1623 let input = MediumAdInput::from_candles(
1624 &candles,
1625 "close",
1626 MediumAdParams {
1627 period: Some(period),
1628 },
1629 );
1630 let batch_output = medium_ad_with_kernel(&input, kernel)?.values;
1631
1632 let mut stream = MediumAdStream::try_new(MediumAdParams {
1633 period: Some(period),
1634 })?;
1635
1636 let mut stream_values = Vec::with_capacity(candles.close.len());
1637 for &price in &candles.close {
1638 match stream.update(price) {
1639 Some(mad_val) => stream_values.push(mad_val),
1640 None => stream_values.push(f64::NAN),
1641 }
1642 }
1643
1644 assert_eq!(batch_output.len(), stream_values.len());
1645 for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
1646 if b.is_nan() && s.is_nan() {
1647 continue;
1648 }
1649 let diff = (b - s).abs();
1650 assert!(
1651 diff < 1e-9,
1652 "[{}] MEDIUM_AD streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
1653 test_name,
1654 i,
1655 b,
1656 s,
1657 diff
1658 );
1659 }
1660 Ok(())
1661 }
1662
/// Property test for the median-absolute-deviation kernel.
///
/// For random periods in `1..=64` and random finite data of at least `period`
/// values it checks: agreement with the scalar kernel, NaN warm-up of
/// `period - 1` values, finite non-negative results afterwards, zero MAD on
/// constant data, MAD bounded by half the window range, zero MAD for period 1,
/// invariance under negation, linear scaling, bounded reaction to a single
/// outlier, and known results on two hand-built series.
///
/// The scale check uses a tolerance that grows with the magnitude of the
/// scaled data: multiplying by a factor such as `-3.0` rounds every sample,
/// and on values near `3e6` that rounding alone can exceed an absolute `1e-9`.
#[cfg(feature = "proptest")]
#[allow(clippy::float_cmp)]
fn check_medium_ad_property(
    test_name: &str,
    kernel: Kernel,
) -> Result<(), Box<dyn std::error::Error>> {
    skip_if_unsupported!(kernel, test_name);

    let strat = (1usize..=64).prop_flat_map(|period| {
        (
            prop::collection::vec(
                (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
                period..400,
            ),
            Just(period),
        )
    });

    proptest::test_runner::TestRunner::default()
        .run(&strat, |(data, period)| {
            // Runs the kernel under test over `series` with the sampled period.
            let mad_of = |series: &[f64]| -> Vec<f64> {
                let input = MediumAdInput::from_slice(
                    series,
                    MediumAdParams {
                        period: Some(period),
                    },
                );
                medium_ad_with_kernel(&input, kernel).unwrap().values
            };

            let first_valid = period - 1;
            let out = mad_of(data.as_slice());
            let ref_out = {
                let input = MediumAdInput::from_slice(
                    &data,
                    MediumAdParams {
                        period: Some(period),
                    },
                );
                medium_ad_with_kernel(&input, Kernel::Scalar).unwrap().values
            };

            // Kernel under test must agree with the scalar reference.
            for i in 0..data.len() {
                let y = out[i];
                let r = ref_out[i];

                if y.is_nan() || r.is_nan() {
                    prop_assert!(
                        y.is_nan() && r.is_nan(),
                        "Kernel consistency: NaN mismatch at idx {}",
                        i
                    );
                } else {
                    let ulp_diff = y.to_bits().abs_diff(r.to_bits());
                    prop_assert!(
                        (y - r).abs() <= 1e-9 || ulp_diff <= 4,
                        "Kernel mismatch at idx {}: {} vs {} (ULP={})",
                        i,
                        y,
                        r,
                        ulp_diff
                    );
                }
            }

            for i in 0..first_valid {
                prop_assert!(
                    out[i].is_nan(),
                    "Expected NaN during warmup at idx {}, got {}",
                    i,
                    out[i]
                );
            }

            for i in first_valid..data.len() {
                let mad = out[i];
                prop_assert!(
                    mad.is_finite() && mad >= 0.0,
                    "MAD at idx {} is not finite or negative: {}",
                    i,
                    mad
                );
            }

            // The strategy guarantees `data.len() >= period`, so only the
            // constancy of the data needs checking here.
            if data.windows(2).all(|w| (w[0] - w[1]).abs() < f64::EPSILON) {
                for i in first_valid..data.len() {
                    prop_assert!(
                        out[i].abs() < 1e-9,
                        "Constant data should have MAD=0.0, got {} at idx {}",
                        out[i],
                        i
                    );
                }
            }

            for i in first_valid..data.len() {
                let window = &data[i + 1 - period..=i];
                let min_val = window.iter().cloned().fold(f64::INFINITY, f64::min);
                let max_val = window.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
                let range = max_val - min_val;
                let mad = out[i];

                prop_assert!(
                    mad <= range * 0.5 + 1e-9,
                    "MAD {} exceeds theoretical maximum (50% of range {}) at idx {}",
                    mad,
                    range * 0.5,
                    i
                );
            }

            if period == 1 {
                for i in 0..data.len() {
                    if !out[i].is_nan() {
                        prop_assert!(
                            out[i].abs() < f64::EPSILON,
                            "Period=1 should have MAD=0.0, got {} at idx {}",
                            out[i],
                            i
                        );
                    }
                }
            }

            // Negating the data must not change the MAD.
            let neg_data: Vec<f64> = data.iter().map(|&x| -x).collect();
            let neg_out = mad_of(neg_data.as_slice());
            for i in first_valid..data.len() {
                let mad = out[i];
                let neg_mad = neg_out[i];
                prop_assert!(
                    (mad - neg_mad).abs() < 1e-9,
                    "Symmetry failed at idx {}: {} vs {}",
                    i,
                    mad,
                    neg_mad
                );
            }

            // Scaling the data by `s` must scale the MAD by `|s|`.
            let scale_factors = [2.0, -3.0, 0.5];
            for &scale in &scale_factors {
                let scaled_data: Vec<f64> = data.iter().map(|&x| x * scale).collect();
                let scaled_out = mad_of(scaled_data.as_slice());

                // Rounding error grows with the size of the scaled samples, so
                // the tolerance does too; it never drops below the old 1e-9.
                let magnitude = scaled_data.iter().fold(1.0f64, |m, &x| m.max(x.abs()));
                let tolerance = (16.0 * f64::EPSILON * magnitude).max(1e-9);

                for i in first_valid..data.len() {
                    let original_mad = out[i];
                    let scaled_mad = scaled_out[i];
                    let expected_scaled_mad = original_mad * scale.abs();

                    prop_assert!(
                        (scaled_mad - expected_scaled_mad).abs() <= tolerance,
                        "Scale invariance failed at idx {} with scale {}: {} vs expected {}",
                        i,
                        scale,
                        scaled_mad,
                        expected_scaled_mad
                    );
                }
            }

            // A single extreme value inside the window must not blow up the MAD.
            if period >= 5 && data.len() >= period + 10 {
                let mut outlier_data = data.clone();

                for test_idx in (period + 4..data.len().min(period + 20)).step_by(5) {
                    let window = &data[test_idx + 1 - period..=test_idx];
                    let win_min = window.iter().cloned().fold(f64::INFINITY, f64::min);
                    let win_max = window.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
                    let win_range = win_max - win_min;

                    let outlier_idx = test_idx - period / 2;
                    let original_value = outlier_data[outlier_idx];
                    outlier_data[outlier_idx] = win_max + win_range * 10.0;

                    let outlier_out = mad_of(outlier_data.as_slice());

                    let original_mad = out[test_idx];
                    let outlier_mad = outlier_out[test_idx];

                    let outlier_effect = win_range * 10.0;
                    prop_assert!(
                        outlier_mad <= original_mad * 10.0 + outlier_effect * 0.1,
                        "MAD not robust enough to outliers at idx {}: original {}, with outlier {}",
                        test_idx, original_mad, outlier_mad
                    );

                    if original_mad > win_range * 0.05 {
                        let mad_ratio = outlier_mad / original_mad;
                        prop_assert!(
                            mad_ratio <= 25.0,
                            "MAD ratio too high with outlier at idx {}: ratio {}",
                            test_idx,
                            mad_ratio
                        );
                    }

                    // Restore the sample so the next round starts from clean data.
                    outlier_data[outlier_idx] = original_value;
                }
            }

            // 1, 2, ..., period: the MAD must be positive and at most half the range.
            if (3..=20).contains(&period) {
                let sequential: Vec<f64> = (1..=period).map(|i| i as f64).collect();
                let calculated_mad = mad_of(sequential.as_slice())[period - 1];
                let seq_range = (period - 1) as f64;

                prop_assert!(
                    calculated_mad > 0.0 && calculated_mad <= seq_range * 0.5,
                    "MAD for sequential data with period {} out of bounds: got {}, range is {}",
                    period, calculated_mad, seq_range
                );

                if period == 3 {
                    prop_assert!(
                        (calculated_mad - 1.0).abs() < 1e-9,
                        "MAD for [1,2,3] should be 1.0, got {}",
                        calculated_mad
                    );
                } else if period == 5 {
                    prop_assert!(
                        (calculated_mad - 1.0).abs() < 1e-9,
                        "MAD for [1,2,3,4,5] should be 1.0, got {}",
                        calculated_mad
                    );
                }
            }

            // Half the window at 100 and half at 0: median 50, every deviation 50.
            if period >= 4 && period % 2 == 0 {
                let mut extreme_data = vec![0.0; period];
                for slot in extreme_data.iter_mut().take(period / 2) {
                    *slot = 100.0;
                }

                let calculated_extreme_mad = mad_of(extreme_data.as_slice())[period - 1];
                let expected_extreme_mad = 50.0;
                prop_assert!(
                    (calculated_extreme_mad - expected_extreme_mad).abs() < 1e-9,
                    "MAD mismatch for extreme data pattern with period {}: got {}, expected {}",
                    period, calculated_extreme_mad, expected_extreme_mad
                );
            }

            Ok(())
        })
        .unwrap();

    Ok(())
}
1949
1950 macro_rules! generate_all_medium_ad_tests {
1951 ($($test_fn:ident),*) => {
1952 paste::paste! {
1953 $(
1954 #[test]
1955 fn [<$test_fn _scalar_f64>]() {
1956 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1957 }
1958 )*
1959 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1960 $(
1961 #[test]
1962 fn [<$test_fn _avx2_f64>]() {
1963 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1964 }
1965 #[test]
1966 fn [<$test_fn _avx512_f64>]() {
1967 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1968 }
1969 )*
1970 }
1971 }
1972 }
1973
1974 generate_all_medium_ad_tests!(
1975 check_medium_ad_partial_params,
1976 check_medium_ad_accuracy,
1977 check_medium_ad_default_candles,
1978 check_medium_ad_zero_period,
1979 check_medium_ad_period_exceeds_length,
1980 check_medium_ad_very_small_dataset,
1981 check_medium_ad_reinput,
1982 check_medium_ad_nan_handling,
1983 check_medium_ad_streaming
1984 );
1985
1986 #[cfg(feature = "proptest")]
1987 generate_all_medium_ad_tests!(check_medium_ad_property);
1988
1989 fn check_batch_default_row(
1990 test: &str,
1991 kernel: Kernel,
1992 ) -> Result<(), Box<dyn std::error::Error>> {
1993 skip_if_unsupported!(kernel, test);
1994
1995 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1996 let c = read_candles_from_csv(file)?;
1997
1998 let output = MediumAdBatchBuilder::new()
1999 .kernel(kernel)
2000 .apply_candles(&c, "close")?;
2001
2002 let def = MediumAdParams::default();
2003 let row = output.values_for(&def).expect("default row missing");
2004
2005 assert_eq!(row.len(), c.close.len());
2006 Ok(())
2007 }
2008
2009 macro_rules! gen_batch_tests {
2010 ($fn_name:ident) => {
2011 paste::paste! {
2012 #[test] fn [<$fn_name _scalar>]() {
2013 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
2014 }
2015 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2016 #[test] fn [<$fn_name _avx2>]() {
2017 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
2018 }
2019 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2020 #[test] fn [<$fn_name _avx512>]() {
2021 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
2022 }
2023 #[test] fn [<$fn_name _auto_detect>]() {
2024 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
2025 }
2026 }
2027 };
2028 }
2029 gen_batch_tests!(check_batch_default_row);
2030}
2031
2032#[cfg(feature = "python")]
2033use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
2034#[cfg(feature = "python")]
2035use pyo3::exceptions::{PyBufferError, PyValueError};
2036#[cfg(feature = "python")]
2037use pyo3::prelude::*;
2038#[cfg(feature = "python")]
2039use pyo3::types::PyDict;
2040
2041#[cfg(feature = "python")]
2042use crate::utilities::kernel_validation::validate_kernel;
2043
/// Python entry point `medium_ad(data, period, kernel=None)`.
///
/// Reads `data` as a contiguous slice, validates the optional kernel name,
/// runs the indicator with the GIL released and returns the values as a new
/// NumPy array. Indicator errors are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "medium_ad")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn medium_ad_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, false)?;

    let input = MediumAdInput::from_slice(
        slice_in,
        MediumAdParams {
            period: Some(period),
        },
    );

    let computed = py.allow_threads(|| medium_ad_with_kernel(&input, kern));
    match computed {
        Ok(out) => Ok(out.values.into_pyarray(py)),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
2067
/// Python wrapper around the incremental [`MediumAdStream`], exposed to
/// Python as `MediumAdStream`.
#[cfg(feature = "python")]
#[pyclass(name = "MediumAdStream")]
pub struct MediumAdStreamPy {
    stream: MediumAdStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl MediumAdStreamPy {
    /// Creates a stream for the given period; construction errors are raised
    /// as `ValueError`.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        let params = MediumAdParams {
            period: Some(period),
        };
        match MediumAdStream::try_new(params) {
            Ok(stream) => Ok(MediumAdStreamPy { stream }),
            Err(e) => Err(PyValueError::new_err(e.to_string())),
        }
    }

    /// Feeds one value and returns whatever the underlying stream returns:
    /// a value, or `None`.
    fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
2091
/// Python entry point `medium_ad_batch(data, period_range, kernel=None)`.
///
/// Expands `period_range` into parameter combinations, computes one output
/// row per combination with the GIL released, and returns a dict with
/// `"values"` (a `rows x cols` array) and `"periods"` (the period of each
/// row). Errors are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "medium_ad_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn medium_ad_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, true)?;

    let sweep = MediumAdBatchRange {
        period: period_range,
    };

    // Expanded up front only to size the output buffer.
    let rows = expand_grid(&sweep)
        .map_err(|e| PyValueError::new_err(e.to_string()))?
        .len();
    let cols = slice_in.len();
    let total = match rows.checked_mul(cols) {
        Some(n) => n,
        None => {
            return Err(PyValueError::new_err(
                "medium_ad: batch output size overflow",
            ))
        }
    };

    // NOTE(review): the array is allocated without initialisation; presumably
    // `medium_ad_batch_inner_into` writes every element — confirm.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            // `Auto` resolves to the scalar batch kernel; each batch kernel
            // maps to its per-row counterpart.
            let simd = match kern {
                Kernel::Auto | Kernel::ScalarBatch => Kernel::Scalar,
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                _ => unreachable!(),
            };
            medium_ad_batch_inner_into(slice_in, &sweep, simd, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let periods: Vec<u64> = combos.iter().map(|p| p.period.unwrap() as u64).collect();

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item("periods", periods.into_pyarray(py))?;

    Ok(dict)
}
2147
2148#[cfg(all(feature = "python", feature = "cuda"))]
2149use crate::cuda::medium_ad_wrapper::CudaMediumAd;
2150#[cfg(all(feature = "python", feature = "cuda"))]
2151use crate::cuda::moving_averages::DeviceArrayF32;
2152#[cfg(all(feature = "python", feature = "cuda"))]
2153use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
2154#[cfg(all(feature = "python", feature = "cuda"))]
2155use cust::context::Context;
2156#[cfg(all(feature = "python", feature = "cuda"))]
2157use cust::memory::DeviceBuffer;
2158#[cfg(all(feature = "python", feature = "cuda"))]
2159use std::sync::Arc;
2160
/// Python handle to a 2-D `f32` result that lives in CUDA device memory.
///
/// Besides the device array it stores the context and device id handed over
/// by the CUDA wrapper that produced it.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct MediumAdDeviceArrayF32Py {
    pub(crate) inner: DeviceArrayF32,
    ctx: Arc<Context>,
    device_id: u32,
}

#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl MediumAdDeviceArrayF32Py {
    /// Builds the `__cuda_array_interface__` dict: shape, `<f4` type string,
    /// row-major byte strides, device pointer and interface version 3.
    /// An array with zero rows or columns reports a null pointer.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let elem_bytes = std::mem::size_of::<f32>();
        let arr = &self.inner;
        let d = PyDict::new(py);

        d.set_item("shape", (arr.rows, arr.cols))?;
        d.set_item("typestr", "<f4")?;
        d.set_item("strides", (arr.cols * elem_bytes, elem_bytes))?;

        let is_empty = arr.rows == 0 || arr.cols == 0;
        let ptr_val = if is_empty {
            0usize
        } else {
            arr.device_ptr() as usize
        };
        d.set_item("data", (ptr_val, false))?;
        d.set_item("version", 3)?;
        Ok(d)
    }

    /// Returns the DLPack `(device_type, device_id)` pair.
    ///
    /// The device ordinal is queried from the buffer's pointer attributes; if
    /// that query fails, the current context's device is used, and `0` if
    /// that fails too. The device type is always `2`.
    fn __dlpack_device__(&self) -> PyResult<(i32, i32)> {
        let mut device_ordinal: i32 = 0;
        unsafe {
            let mut queried = std::mem::MaybeUninit::<i32>::uninit();
            let status = cust::sys::cuPointerGetAttribute(
                queried.as_mut_ptr() as *mut std::ffi::c_void,
                cust::sys::CUpointer_attribute::CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
                self.inner.buf.as_device_ptr().as_raw(),
            );
            if status == cust::sys::CUresult::CUDA_SUCCESS {
                // Only read after the driver reported that it wrote the value.
                device_ordinal = queried.assume_init();
            } else {
                let _ = cust::sys::cuCtxGetDevice(&mut device_ordinal);
            }
        }
        Ok((2, device_ordinal))
    }

    /// Exports the buffer as a DLPack capsule and leaves this object holding
    /// an empty `0 x 0` array.
    ///
    /// Rejects `stream == 0`, any request for a copy, and a `dl_device` that
    /// differs from the device the buffer lives on.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        let stream_is_zero = stream
            .as_ref()
            .and_then(|s| s.extract::<i64>(py).ok())
            .map_or(false, |s| s == 0);
        if stream_is_zero {
            return Err(PyValueError::new_err(
                "__dlpack__: stream 0 is disallowed for CUDA",
            ));
        }

        let (kdl, alloc_dev) = self.__dlpack_device__()?;

        // A copy request is refused whether or not the device matches, so it
        // is checked before the device comparison.
        let wants_copy = copy
            .as_ref()
            .and_then(|c| c.extract::<bool>(py).ok())
            .unwrap_or(false);
        if wants_copy {
            return Err(PyBufferError::new_err(
                "__dlpack__: copy semantics are not supported for medium_ad",
            ));
        }

        let device_mismatch = dl_device
            .as_ref()
            .and_then(|d| d.extract::<(i32, i32)>(py).ok())
            .map_or(false, |(dev_type, dev_id)| {
                dev_type != kdl || dev_id != alloc_dev
            });
        if device_mismatch {
            return Err(PyBufferError::new_err(
                "__dlpack__: requested device does not match allocation device",
            ));
        }

        // Swap in an empty placeholder so ownership of the real buffer can
        // move into the capsule.
        let placeholder =
            DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let taken = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32 {
                buf: placeholder,
                rows: 0,
                cols: 0,
            },
        );
        let DeviceArrayF32 { buf, rows, cols } = taken;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));
        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
2287
/// Python entry point `medium_ad_cuda_batch_dev(data_f32, period_range, device_id=0)`.
///
/// Runs the period sweep on the selected CUDA device with the GIL released
/// and returns a handle to the device-resident result. Raises `ValueError`
/// when CUDA is unavailable or the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "medium_ad_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, device_id=0))]
pub fn medium_ad_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: numpy::PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<MediumAdDeviceArrayF32Py> {
    if !crate::cuda::cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let host_data = data_f32.as_slice()?;
    let sweep = MediumAdBatchRange {
        period: period_range,
    };

    let (inner, ctx, dev_id) =
        py.allow_threads(|| -> PyResult<(DeviceArrayF32, Arc<Context>, u32)> {
            let cuda =
                CudaMediumAd::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
            let ctx = cuda.context_arc();
            let dev_id = cuda.device_id();
            match cuda.medium_ad_batch_dev(host_data, &sweep) {
                Ok(arr) => Ok((arr, ctx, dev_id)),
                Err(e) => Err(PyValueError::new_err(e.to_string())),
            }
        })?;

    Ok(MediumAdDeviceArrayF32Py {
        inner,
        ctx,
        device_id: dev_id,
    })
}
2322
/// Python entry point
/// `medium_ad_cuda_many_series_one_param_dev(data_tm_f32, cols, rows, period, device_id=0)`.
///
/// Passes the flat input together with `cols`, `rows` and `period` to the
/// CUDA wrapper's time-major many-series routine, with the GIL released, and
/// returns a handle to the device-resident result. Raises `ValueError` when
/// CUDA is unavailable or the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "medium_ad_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, cols, rows, period, device_id=0))]
pub fn medium_ad_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    period: usize,
    device_id: usize,
) -> PyResult<MediumAdDeviceArrayF32Py> {
    if !crate::cuda::cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let host_data = data_tm_f32.as_slice()?;

    let (inner, ctx, dev_id) =
        py.allow_threads(|| -> PyResult<(DeviceArrayF32, Arc<Context>, u32)> {
            let cuda =
                CudaMediumAd::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
            let ctx = cuda.context_arc();
            let dev_id = cuda.device_id();
            match cuda.medium_ad_many_series_one_param_time_major_dev(host_data, cols, rows, period)
            {
                Ok(arr) => Ok((arr, ctx, dev_id)),
                Err(e) => Err(PyValueError::new_err(e.to_string())),
            }
        })?;

    Ok(MediumAdDeviceArrayF32Py {
        inner,
        ctx,
        device_id: dev_id,
    })
}
2356
2357#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2358use serde::{Deserialize, Serialize};
2359#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2360use wasm_bindgen::prelude::*;
2361
2362pub fn medium_ad_into_slice(
2363 dst: &mut [f64],
2364 input: &MediumAdInput,
2365 kern: Kernel,
2366) -> Result<(), MediumAdError> {
2367 let data = match &input.data {
2368 MediumAdData::Candles { candles, source } => source_type(candles, source),
2369 MediumAdData::Slice(s) => s,
2370 };
2371 let period = input.params.period.unwrap_or(5);
2372
2373 if period == 0 || period > data.len() {
2374 return Err(MediumAdError::InvalidPeriod {
2375 period,
2376 data_len: data.len(),
2377 });
2378 }
2379
2380 if dst.len() != data.len() {
2381 return Err(MediumAdError::OutputLengthMismatch {
2382 expected: data.len(),
2383 got: dst.len(),
2384 });
2385 }
2386
2387 let first = data.iter().position(|&x| !x.is_nan()).unwrap_or(0);
2388 let chosen = if kern == Kernel::Auto {
2389 Kernel::Scalar
2390 } else {
2391 kern
2392 };
2393
2394 match chosen {
2395 Kernel::Scalar => medium_ad_scalar(data, period, first, dst),
2396 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2397 Kernel::Avx2 => medium_ad_avx2(data, period, first, dst),
2398
2399 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2400 Kernel::Avx512 => medium_ad_scalar(data, period, first, dst),
2401 _ => unreachable!(),
2402 }
2403
2404 let warmup_end = first + period - 1;
2405 for v in &mut dst[..warmup_end] {
2406 *v = f64::NAN;
2407 }
2408
2409 Ok(())
2410}
2411
/// WASM entry point: computes the indicator over `data` with the given
/// `period` and returns a freshly allocated vector of the same length.
///
/// Errors from the underlying computation are converted to a `JsValue`
/// string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn medium_ad_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = MediumAdInput::from_slice(
        data,
        MediumAdParams {
            period: Some(period),
        },
    );

    // Zero-initialised scratch output; the computation writes into it.
    let mut values = vec![0.0; data.len()];

    match medium_ad_into_slice(&mut values, &input, Kernel::Auto) {
        Ok(()) => Ok(values),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
2427
/// Allocates an uninitialised buffer with capacity for `len` `f64` values and
/// hands ownership of it to the caller as a raw pointer.
///
/// The buffer is intentionally not dropped here; release it with
/// `medium_ad_free` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn medium_ad_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` suppresses the destructor so the allocation outlives
    // this call.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
2436
/// Releases a buffer of `len` `f64` values behind `ptr`.
///
/// A null `ptr` is ignored. Otherwise the memory is reclaimed by rebuilding a
/// `Vec` with length and capacity `len` and dropping it.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn medium_ad_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: relies on the caller passing a pointer/length pair that
    // satisfies `Vec::from_raw_parts` (presumably one obtained from
    // `medium_ad_alloc` with the same `len`).
    unsafe {
        drop(Vec::from_raw_parts(ptr, len, len));
    }
}
2446
/// WASM entry point: computes the indicator over the `len` values at `in_ptr`
/// and writes `len` results to `out_ptr`.
///
/// `in_ptr` and `out_ptr` may be the same pointer; in that case the result is
/// computed into a temporary buffer and copied back afterwards.
///
/// Returns an error if either pointer is null, if `period` is `0` or greater
/// than `len`, or if the underlying computation fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn medium_ad_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }
    if period == 0 || period > len {
        return Err(JsValue::from_str("Invalid period"));
    }

    let to_js = |e: MediumAdError| JsValue::from_str(&e.to_string());

    // SAFETY: relies on the caller guaranteeing that both pointers reference
    // at least `len` valid `f64` values.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let input = MediumAdInput::from_slice(
            data,
            MediumAdParams {
                period: Some(period),
            },
        );

        if in_ptr == out_ptr {
            // In-place request: compute into scratch memory first so the
            // input is not overwritten while it is still being read.
            let mut scratch = vec![0.0; len];
            medium_ad_into_slice(&mut scratch, &input, Kernel::Auto).map_err(to_js)?;
            std::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&scratch);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            medium_ad_into_slice(out, &input, Kernel::Auto).map_err(to_js)?;
        }
    }

    Ok(())
}
2486
/// Configuration object accepted by the WASM `medium_ad_batch` entry point.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct MediumAdBatchConfig {
    /// Period sweep, forwarded as-is to `MediumAdBatchRange::period`
    /// (start, end, step).
    pub period_range: (usize, usize, usize),
}
2492
/// Serializable result returned to JavaScript by the WASM `medium_ad_batch`
/// entry point; each field is copied from the batch computation's output.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct MediumAdBatchJsOutput {
    /// Flat buffer of computed values (presumably `rows * cols` entries,
    /// row-major — confirm against the batch implementation).
    pub values: Vec<f64>,
    /// Parameter combinations that were evaluated.
    pub combos: Vec<MediumAdParams>,
    /// Row count reported by the batch computation.
    pub rows: usize,
    /// Column count reported by the batch computation.
    pub cols: usize,
}
2501
/// WASM entry point exported to JavaScript as `medium_ad_batch`.
///
/// Deserializes `config` into a `MediumAdBatchConfig`, runs the batch
/// computation over `data` for the requested period sweep, and returns the
/// result serialized as a `MediumAdBatchJsOutput`.
///
/// Deserialization, computation and serialization failures are all reported
/// as `JsValue` strings.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = medium_ad_batch)]
pub fn medium_ad_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let MediumAdBatchConfig { period_range } =
        serde_wasm_bindgen::from_value::<MediumAdBatchConfig>(config)
            .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = MediumAdBatchRange {
        period: period_range,
    };

    let batch = match medium_ad_batch_inner(data, &sweep, Kernel::Auto, false) {
        Ok(batch) => batch,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };

    let payload = MediumAdBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };

    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2525
/// WASM entry point: runs the batch computation over the `len` values at
/// `in_ptr` for every period in the `(period_start, period_end, period_step)`
/// sweep and writes the results to `out_ptr`.
///
/// The output region is treated as `rows * len` values, where `rows` is the
/// number of parameter combinations produced by the sweep. On success the
/// number of rows is returned.
///
/// Returns an error if either pointer is null, if the sweep cannot be
/// expanded, if `rows * len` overflows, or if the computation fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn medium_ad_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str(
            "null pointer passed to medium_ad_batch_into",
        ));
    }

    let sweep = MediumAdBatchRange {
        period: (period_start, period_end, period_step),
    };

    // One output row per parameter combination, one column per input value.
    let rows = expand_grid(&sweep)
        .map_err(|e| JsValue::from_str(&e.to_string()))?
        .len();
    let cols = len;

    let total = rows.checked_mul(cols).ok_or_else(|| {
        JsValue::from_str("medium_ad_batch_into: rows*cols overflow in output size")
    })?;

    // SAFETY: relies on the caller guaranteeing that `in_ptr` references `len`
    // valid values and `out_ptr` references at least `rows * len` writable
    // values.
    let (data, out) = unsafe {
        (
            std::slice::from_raw_parts(in_ptr, len),
            std::slice::from_raw_parts_mut(out_ptr, total),
        )
    };

    medium_ad_batch_inner_into(data, &sweep, Kernel::Auto, false, out)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    Ok(rows)
}
2570
#[cfg(all(test, not(all(target_arch = "wasm32", feature = "wasm"))))]
mod tests_into_parity {
    use super::*;

    /// Two values match when both are NaN, exactly equal, or within `1e-12`.
    #[inline]
    fn eq_or_nan(a: f64, b: f64) -> bool {
        if a.is_nan() && b.is_nan() {
            return true;
        }
        a == b || (a - b).abs() <= 1e-12
    }

    /// The `medium_ad_into` output must match the allocating `medium_ad` API
    /// element for element, including the NaN warm-up prefix.
    #[test]
    fn test_medium_ad_into_matches_api() -> Result<(), Box<dyn std::error::Error>> {
        // Three leading NaNs followed by 253 deterministic synthetic samples.
        let data: Vec<f64> = std::iter::repeat(f64::NAN)
            .take(3)
            .chain(
                (0..253usize)
                    .map(|i| (i as f64 * 0.019).sin() * 3.0 + 42.0 + ((i % 11) as f64) * 0.07),
            )
            .collect();

        let input = MediumAdInput::from_slice(&data, MediumAdParams::default());

        let baseline = medium_ad(&input)?.values;

        let mut out = vec![0.0; data.len()];
        medium_ad_into(&input, &mut out)?;

        assert_eq!(baseline.len(), out.len());

        for (i, (a, b)) in baseline.iter().zip(out.iter()).enumerate() {
            assert!(
                eq_or_nan(*a, *b),
                "medium_ad_into mismatch at idx {}: baseline={} into={}",
                i,
                a,
                b
            );
        }

        Ok(())
    }
}