Skip to main content

vector_ta/indicators/
medium_ad.rs

1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{alloc_with_nan_prefix, init_matrix_prefixes, make_uninit_matrix};
4#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
5use core::arch::x86_64::*;
6#[cfg(not(target_arch = "wasm32"))]
7use rayon::prelude::*;
8use std::convert::AsRef;
9use thiserror::Error;
10
11impl<'a> AsRef<[f64]> for MediumAdInput<'a> {
12    #[inline(always)]
13    fn as_ref(&self) -> &[f64] {
14        match &self.data {
15            MediumAdData::Slice(slice) => slice,
16            MediumAdData::Candles { candles, source } => source_type(candles, source),
17        }
18    }
19}
20
/// Where the median absolute deviation indicator reads its price series from.
#[derive(Debug, Clone)]
pub enum MediumAdData<'a> {
    /// A named column of a candle set, resolved via `source_type` (e.g. `"close"`).
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// An already-extracted, borrowed price slice.
    Slice(&'a [f64]),
}

/// Result of a single median absolute deviation run.
#[derive(Debug, Clone)]
pub struct MediumAdOutput {
    /// Indicator values, index-aligned with the input series; windows containing NaN
    /// produce NaN.
    pub values: Vec<f64>,
}
34
/// Parameters of the median absolute deviation indicator.
///
/// The serde derives are spelled with full paths so the wasm build does not depend on
/// `Serialize`/`Deserialize` having been imported elsewhere in the module.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct MediumAdParams {
    /// Rolling window length; `None` means "use the default" (5).
    pub period: Option<usize>,
}

impl Default for MediumAdParams {
    /// A 5-sample window.
    fn default() -> Self {
        Self { period: Some(5) }
    }
}
49
50#[derive(Debug, Clone)]
51pub struct MediumAdInput<'a> {
52    pub data: MediumAdData<'a>,
53    pub params: MediumAdParams,
54}
55
56impl<'a> MediumAdInput<'a> {
57    #[inline]
58    pub fn from_candles(c: &'a Candles, s: &'a str, p: MediumAdParams) -> Self {
59        Self {
60            data: MediumAdData::Candles {
61                candles: c,
62                source: s,
63            },
64            params: p,
65        }
66    }
67    #[inline]
68    pub fn from_slice(sl: &'a [f64], p: MediumAdParams) -> Self {
69        Self {
70            data: MediumAdData::Slice(sl),
71            params: p,
72        }
73    }
74    #[inline]
75    pub fn with_default_candles(c: &'a Candles) -> Self {
76        Self::from_candles(c, "close", MediumAdParams::default())
77    }
78    #[inline]
79    pub fn get_period(&self) -> usize {
80        self.params.period.unwrap_or(5)
81    }
82}
83
84#[derive(Copy, Clone, Debug)]
85pub struct MediumAdBuilder {
86    period: Option<usize>,
87    kernel: Kernel,
88}
89
90impl Default for MediumAdBuilder {
91    fn default() -> Self {
92        Self {
93            period: None,
94            kernel: Kernel::Auto,
95        }
96    }
97}
98
99impl MediumAdBuilder {
100    #[inline(always)]
101    pub fn new() -> Self {
102        Self::default()
103    }
104    #[inline(always)]
105    pub fn period(mut self, n: usize) -> Self {
106        self.period = Some(n);
107        self
108    }
109    #[inline(always)]
110    pub fn kernel(mut self, k: Kernel) -> Self {
111        self.kernel = k;
112        self
113    }
114
115    #[inline(always)]
116    pub fn apply(self, c: &Candles) -> Result<MediumAdOutput, MediumAdError> {
117        let p = MediumAdParams {
118            period: self.period,
119        };
120        let i = MediumAdInput::from_candles(c, "close", p);
121        medium_ad_with_kernel(&i, self.kernel)
122    }
123
124    #[inline(always)]
125    pub fn apply_slice(self, d: &[f64]) -> Result<MediumAdOutput, MediumAdError> {
126        let p = MediumAdParams {
127            period: self.period,
128        };
129        let i = MediumAdInput::from_slice(d, p);
130        medium_ad_with_kernel(&i, self.kernel)
131    }
132
133    #[inline(always)]
134    pub fn into_stream(self) -> Result<MediumAdStream, MediumAdError> {
135        let p = MediumAdParams {
136            period: self.period,
137        };
138        MediumAdStream::try_new(p)
139    }
140}
141
/// Errors returned by the median absolute deviation entry points.
#[derive(Debug, Error)]
pub enum MediumAdError {
    /// The input series has no elements.
    #[error("medium_ad: Empty input data slice.")]
    EmptyInputData,

    /// Every value in the input series is NaN.
    #[error("medium_ad: All values are NaN.")]
    AllValuesNaN,

    /// The period is zero or longer than the input series.
    #[error("medium_ad: Invalid period: period = {period}, data_len = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },

    /// Fewer than `needed` samples remain from the first non-NaN value onward.
    #[error("medium_ad: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },

    /// A caller-supplied output buffer does not have the required length.
    #[error("medium_ad: Output length mismatch: expected {expected}, got {got}")]
    OutputLengthMismatch { expected: usize, got: usize },

    /// A batch period sweep `(start, end, step)` could not be expanded, or its output
    /// size overflowed.
    #[error("medium_ad: Invalid range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: String,
        end: String,
        step: String,
    },

    /// A non-batch kernel was passed to a batch entry point.
    #[error("medium_ad: Invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(Kernel),
}
169
170#[inline]
171pub fn medium_ad(input: &MediumAdInput) -> Result<MediumAdOutput, MediumAdError> {
172    medium_ad_with_kernel(input, Kernel::Auto)
173}
174
175pub fn medium_ad_with_kernel(
176    input: &MediumAdInput,
177    kernel: Kernel,
178) -> Result<MediumAdOutput, MediumAdError> {
179    let data: &[f64] = match &input.data {
180        MediumAdData::Candles { candles, source } => source_type(candles, source),
181        MediumAdData::Slice(sl) => sl,
182    };
183
184    if data.is_empty() {
185        return Err(MediumAdError::EmptyInputData);
186    }
187
188    let first = data
189        .iter()
190        .position(|x| !x.is_nan())
191        .ok_or(MediumAdError::AllValuesNaN)?;
192    let len = data.len();
193    let period = input.get_period();
194
195    if period == 0 || period > len {
196        return Err(MediumAdError::InvalidPeriod {
197            period,
198            data_len: len,
199        });
200    }
201    if (len - first) < period {
202        return Err(MediumAdError::NotEnoughValidData {
203            needed: period,
204            valid: len - first,
205        });
206    }
207
208    let mut out = alloc_with_nan_prefix(len, first + period - 1);
209
210    let chosen = match kernel {
211        Kernel::Auto => Kernel::Scalar,
212        other => other,
213    };
214
215    unsafe {
216        match chosen {
217            Kernel::Scalar | Kernel::ScalarBatch => medium_ad_scalar(data, period, first, &mut out),
218            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
219            Kernel::Avx2 | Kernel::Avx2Batch => medium_ad_avx2(data, period, first, &mut out),
220
221            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
222            Kernel::Avx512 | Kernel::Avx512Batch => medium_ad_scalar(data, period, first, &mut out),
223            _ => unreachable!(),
224        }
225    }
226
227    Ok(MediumAdOutput { values: out })
228}
229
230#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
231#[inline]
232pub fn medium_ad_into(input: &MediumAdInput, out: &mut [f64]) -> Result<(), MediumAdError> {
233    let data: &[f64] = match &input.data {
234        MediumAdData::Candles { candles, source } => source_type(candles, source),
235        MediumAdData::Slice(sl) => sl,
236    };
237
238    let len = data.len();
239    if len == 0 {
240        return Err(MediumAdError::EmptyInputData);
241    }
242
243    let first = data
244        .iter()
245        .position(|x| !x.is_nan())
246        .ok_or(MediumAdError::AllValuesNaN)?;
247
248    let period = input.get_period();
249    if period == 0 || period > len {
250        return Err(MediumAdError::InvalidPeriod {
251            period,
252            data_len: len,
253        });
254    }
255    if (len - first) < period {
256        return Err(MediumAdError::NotEnoughValidData {
257            needed: period,
258            valid: len - first,
259        });
260    }
261
262    if out.len() != len {
263        return Err(MediumAdError::OutputLengthMismatch {
264            expected: len,
265            got: out.len(),
266        });
267    }
268
269    let warm = first + period - 1;
270    let warm_cap = warm.min(len);
271    for v in &mut out[..warm_cap] {
272        *v = f64::from_bits(0x7ff8_0000_0000_0000);
273    }
274
275    let chosen = Kernel::Scalar;
276
277    unsafe {
278        match chosen {
279            Kernel::Scalar | Kernel::ScalarBatch => medium_ad_scalar(data, period, first, out),
280            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
281            Kernel::Avx2 | Kernel::Avx2Batch => medium_ad_avx2(data, period, first, out),
282
283            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
284            Kernel::Avx512 | Kernel::Avx512Batch => medium_ad_scalar(data, period, first, out),
285            _ => unreachable!(),
286        }
287    }
288
289    Ok(())
290}
291
/// Scalar kernel for the rolling median absolute deviation (MAD).
///
/// For every `i` in `first_valid + period - 1 .. data.len()`, writes to `out[i]` the
/// median of `|x - m|` over the window `data[i + 1 - period ..= i]`, where `m` is the
/// median of that same window. Even-length medians are the mean of the two middle values.
/// Any window containing a NaN produces NaN. With `period == 1` every input from
/// `first_valid` on maps to `0.0` (NaN stays NaN).
///
/// Slots of `out` before the first written index are left untouched; callers pre-fill them.
///
/// # Panics
/// Panics if `period == 0` or if `out` is shorter than `data`.
#[inline]
pub fn medium_ad_scalar(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    use core::cmp::Ordering;

    /// `|x|` by clearing the sign bit.
    #[inline(always)]
    fn fast_abs_f64(x: f64) -> f64 {
        f64::from_bits(x.to_bits() & 0x7FFF_FFFF_FFFF_FFFF)
    }

    /// Median of a NaN-free `buf` (reordered in place); `mid` must be `buf.len() / 2`.
    #[inline(always)]
    fn median_from(buf: &mut [f64], mid: usize) -> f64 {
        buf.select_nth_unstable_by(mid, |a, b| {
            if a < b {
                Ordering::Less
            } else if a > b {
                Ordering::Greater
            } else {
                Ordering::Equal
            }
        });
        let upper = buf[mid];
        if buf.len() % 2 == 1 {
            return upper;
        }
        // After selection every element left of `mid` is <= buf[mid]; the lower middle
        // order statistic is the largest of them.
        let mut lower = f64::NEG_INFINITY;
        for &v in &buf[..mid] {
            if v > lower {
                lower = v;
            }
        }
        0.5 * (lower + upper)
    }

    // These make every index below provably in bounds (this is a safe `pub fn`) and rule
    // out the `first_valid + period - 1` underflow.
    assert!(period >= 1, "medium_ad_scalar: period must be >= 1");
    assert!(
        out.len() >= data.len(),
        "medium_ad_scalar: output length {} is shorter than input length {}",
        out.len(),
        data.len()
    );

    let len = data.len();
    if period == 1 {
        // A one-element window always deviates zero from its own median.
        for (dst, &v) in out[..len].iter_mut().zip(data).skip(first_valid) {
            *dst = if v.is_nan() { f64::NAN } else { 0.0 };
        }
        return;
    }

    // Scratch buffer reused for every window; fully overwritten before each use.
    let mut buf = vec![0.0f64; period];
    let mid = period / 2;

    // Window `w` covers `data[w ..= w + period - 1]` and feeds output index `w + period - 1`.
    for (w, window) in data.windows(period).enumerate().skip(first_valid) {
        let i = w + period - 1;
        if window.iter().any(|v| v.is_nan()) {
            out[i] = f64::NAN;
            continue;
        }

        buf.copy_from_slice(window);
        let med = median_from(&mut buf, mid);

        for v in buf.iter_mut() {
            *v = fast_abs_f64(*v - med);
        }
        out[i] = median_from(&mut buf, mid);
    }
}
401
/// AVX-512 variant of [`medium_ad_scalar`] (rolling median absolute deviation).
///
/// Window copies, NaN checks and the `|x - median|` transform run 8 lanes at a time (then
/// 4, then scalar); the two medians use the same selection as the scalar kernel, so the
/// results match it. If the running CPU does not report AVX-512F the scalar kernel is used
/// instead, so calling this safe function is sound on any x86_64 machine.
///
/// Only `out[first_valid + period - 1 ..]` is written; earlier slots are left untouched.
///
/// # Panics
/// Panics if `period == 0` or if `out` is shorter than `data`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn medium_ad_avx512(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    use core::cmp::Ordering;

    /// Median of a NaN-free `buf` (reordered in place); `mid` must be `buf.len() / 2`.
    #[inline(always)]
    fn median_from(buf: &mut [f64], mid: usize) -> f64 {
        buf.select_nth_unstable_by(mid, |a, b| {
            if *a < *b {
                Ordering::Less
            } else if *a > *b {
                Ordering::Greater
            } else {
                Ordering::Equal
            }
        });
        if (buf.len() & 1) == 1 {
            buf[mid]
        } else {
            let mut lo_max = f64::NEG_INFINITY;
            for &v in &buf[..mid] {
                if v > lo_max {
                    lo_max = v;
                }
            }
            0.5 * (lo_max + buf[mid])
        }
    }

    /// SIMD body.
    ///
    /// # Safety
    /// The CPU must support AVX-512F, `period >= 1`, and `out.len() >= data.len()`.
    #[target_feature(enable = "avx512f")]
    unsafe fn kernel(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
        let len = data.len();
        let mid = period >> 1;
        let mut buf = vec![0.0f64; period];
        // andnot with -0.0 clears the sign bit: |x| for the 4-lane tail.
        let sign4 = _mm256_set1_pd(-0.0);

        let warm = first_valid + period - 1;
        for i in warm..len {
            let src = data.as_ptr().add(i + 1 - period);
            let dst = buf.as_mut_ptr();

            // Copy the window into `buf`, flagging NaNs with an unordered self-compare.
            let mut has_nan = false;
            let mut k = 0usize;
            while k + 8 <= period {
                let v = _mm512_loadu_pd(src.add(k));
                _mm512_storeu_pd(dst.add(k), v);
                has_nan |= _mm512_cmp_pd_mask(v, v, _CMP_UNORD_Q) != 0;
                k += 8;
            }
            while k + 4 <= period {
                let v = _mm256_loadu_pd(src.add(k));
                _mm256_storeu_pd(dst.add(k), v);
                has_nan |= _mm256_movemask_pd(_mm256_cmp_pd(v, v, _CMP_UNORD_Q)) != 0;
                k += 4;
            }
            while k < period {
                let v = *src.add(k);
                *dst.add(k) = v;
                has_nan |= v.is_nan();
                k += 1;
            }
            if has_nan {
                *out.get_unchecked_mut(i) = f64::NAN;
                continue;
            }

            let med = median_from(&mut buf, mid);

            // Replace each element with |x - med|. `_mm512_abs_pd` is AVX-512F (unlike
            // `_mm512_andnot_pd`, which needs AVX-512DQ) and also just clears the sign bit.
            let mv8 = _mm512_set1_pd(med);
            let mv4 = _mm256_set1_pd(med);
            let dst = buf.as_mut_ptr();
            let mut k = 0usize;
            while k + 8 <= period {
                let d = _mm512_sub_pd(_mm512_loadu_pd(dst.add(k)), mv8);
                _mm512_storeu_pd(dst.add(k), _mm512_abs_pd(d));
                k += 8;
            }
            while k + 4 <= period {
                let d = _mm256_sub_pd(_mm256_loadu_pd(dst.add(k)), mv4);
                _mm256_storeu_pd(dst.add(k), _mm256_andnot_pd(sign4, d));
                k += 4;
            }
            while k < period {
                let t = *dst.add(k) - med;
                *dst.add(k) = f64::from_bits(t.to_bits() & 0x7FFF_FFFF_FFFF_FFFF);
                k += 1;
            }

            *out.get_unchecked_mut(i) = median_from(&mut buf, mid);
        }
    }

    assert!(period >= 1, "medium_ad_avx512: period must be >= 1");
    assert!(
        out.len() >= data.len(),
        "medium_ad_avx512: output length {} is shorter than input length {}",
        out.len(),
        data.len()
    );

    if !is_x86_feature_detected!("avx512f") {
        medium_ad_scalar(data, period, first_valid, out);
        return;
    }
    // SAFETY: AVX-512F support was verified at runtime just above, and the assertions
    // guarantee `period >= 1` and `out.len() >= data.len()`, so every load stays inside
    // `data[i + 1 - period ..= i]` and every store inside `out`/`buf`.
    unsafe { kernel(data, period, first_valid, out) }
}
504
/// AVX2 variant of [`medium_ad_scalar`] (rolling median absolute deviation).
///
/// Window copies, NaN checks and the `|x - median|` transform run 4 lanes at a time; the
/// two medians use the same selection as the scalar kernel, so results match it. If the
/// running CPU does not report AVX2 the scalar kernel is used instead, so calling this safe
/// function is sound on any x86_64 machine.
///
/// Only `out[first_valid + period - 1 ..]` is written; earlier slots are left untouched.
///
/// # Panics
/// Panics if `period == 0` or if `out` is shorter than `data`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn medium_ad_avx2(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    use core::cmp::Ordering;

    /// Median of a NaN-free `buf` (reordered in place); `mid` must be `buf.len() / 2`.
    #[inline(always)]
    fn median_from(buf: &mut [f64], mid: usize) -> f64 {
        buf.select_nth_unstable_by(mid, |a, b| {
            if *a < *b {
                Ordering::Less
            } else if *a > *b {
                Ordering::Greater
            } else {
                Ordering::Equal
            }
        });
        if (buf.len() & 1) == 1 {
            buf[mid]
        } else {
            let mut lo_max = f64::NEG_INFINITY;
            for &v in &buf[..mid] {
                if v > lo_max {
                    lo_max = v;
                }
            }
            0.5 * (lo_max + buf[mid])
        }
    }

    /// SIMD body.
    ///
    /// # Safety
    /// The CPU must support AVX2, `period >= 1`, and `out.len() >= data.len()`.
    #[target_feature(enable = "avx2")]
    unsafe fn kernel(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
        let len = data.len();
        let mid = period >> 1;
        let mut buf = vec![0.0f64; period];
        // andnot with -0.0 clears the sign bit: |x|.
        let sign_mask = _mm256_set1_pd(-0.0);

        let warm = first_valid + period - 1;
        for i in warm..len {
            let src = data.as_ptr().add(i + 1 - period);
            let dst = buf.as_mut_ptr();

            // Copy the window into `buf`, flagging NaNs with an unordered self-compare.
            let mut has_nan = false;
            let mut k = 0usize;
            while k + 4 <= period {
                let v = _mm256_loadu_pd(src.add(k));
                _mm256_storeu_pd(dst.add(k), v);
                has_nan |= _mm256_movemask_pd(_mm256_cmp_pd(v, v, _CMP_UNORD_Q)) != 0;
                k += 4;
            }
            while k < period {
                let v = *src.add(k);
                *dst.add(k) = v;
                has_nan |= v.is_nan();
                k += 1;
            }
            if has_nan {
                *out.get_unchecked_mut(i) = f64::NAN;
                continue;
            }

            let med = median_from(&mut buf, mid);

            // Replace each element with |x - med|.
            let mv = _mm256_set1_pd(med);
            let dst = buf.as_mut_ptr();
            let mut k = 0usize;
            while k + 4 <= period {
                let d = _mm256_sub_pd(_mm256_loadu_pd(dst.add(k)), mv);
                _mm256_storeu_pd(dst.add(k), _mm256_andnot_pd(sign_mask, d));
                k += 4;
            }
            while k < period {
                let t = *dst.add(k) - med;
                *dst.add(k) = f64::from_bits(t.to_bits() & 0x7FFF_FFFF_FFFF_FFFF);
                k += 1;
            }

            *out.get_unchecked_mut(i) = median_from(&mut buf, mid);
        }
    }

    assert!(period >= 1, "medium_ad_avx2: period must be >= 1");
    assert!(
        out.len() >= data.len(),
        "medium_ad_avx2: output length {} is shorter than input length {}",
        out.len(),
        data.len()
    );

    if !is_x86_feature_detected!("avx2") {
        medium_ad_scalar(data, period, first_valid, out);
        return;
    }
    // SAFETY: AVX2 support was verified at runtime just above, and the assertions
    // guarantee `period >= 1` and `out.len() >= data.len()`, so every load stays inside
    // `data[i + 1 - period ..= i]` and every store inside `out`/`buf`.
    unsafe { kernel(data, period, first_valid, out) }
}
589
/// AVX-512 "short period" entry point; currently delegates to [`medium_ad_scalar`].
///
/// `medium_ad_scalar` is a safe function, so no `unsafe` block is needed here.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn medium_ad_avx512_short(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    medium_ad_scalar(data, period, first_valid, out)
}

/// AVX-512 "long period" entry point; currently delegates to [`medium_ad_scalar`].
///
/// `medium_ad_scalar` is a safe function, so no `unsafe` block is needed here.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn medium_ad_avx512_long(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    medium_ad_scalar(data, period, first_valid, out)
}
601
602#[inline(always)]
603pub fn medium_ad_batch_with_kernel(
604    data: &[f64],
605    sweep: &MediumAdBatchRange,
606    k: Kernel,
607) -> Result<MediumAdBatchOutput, MediumAdError> {
608    let kernel = match k {
609        Kernel::Auto => Kernel::ScalarBatch,
610        other if other.is_batch() => other,
611        other => return Err(MediumAdError::InvalidKernelForBatch(other)),
612    };
613
614    let simd = match kernel {
615        Kernel::Avx512Batch => Kernel::Avx512,
616        Kernel::Avx2Batch => Kernel::Avx2,
617        Kernel::ScalarBatch => Kernel::Scalar,
618        _ => unreachable!(),
619    };
620    medium_ad_batch_par_slice(data, sweep, simd)
621}
622
/// Period sweep for batch runs, expressed as `(start, end, step)`.
#[derive(Clone, Debug)]
pub struct MediumAdBatchRange {
    pub period: (usize, usize, usize),
}

impl Default for MediumAdBatchRange {
    /// Periods 5 through 254 in steps of 1.
    fn default() -> Self {
        let (start, end, step) = (5, 254, 1);
        Self {
            period: (start, end, step),
        }
    }
}
635
636#[derive(Clone, Debug, Default)]
637pub struct MediumAdBatchBuilder {
638    range: MediumAdBatchRange,
639    kernel: Kernel,
640}
641
642impl MediumAdBatchBuilder {
643    pub fn new() -> Self {
644        Self::default()
645    }
646    pub fn kernel(mut self, k: Kernel) -> Self {
647        self.kernel = k;
648        self
649    }
650
651    #[inline]
652    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
653        self.range.period = (start, end, step);
654        self
655    }
656    #[inline]
657    pub fn period_static(mut self, p: usize) -> Self {
658        self.range.period = (p, p, 0);
659        self
660    }
661
662    pub fn apply_slice(self, data: &[f64]) -> Result<MediumAdBatchOutput, MediumAdError> {
663        medium_ad_batch_with_kernel(data, &self.range, self.kernel)
664    }
665
666    pub fn with_default_slice(
667        data: &[f64],
668        k: Kernel,
669    ) -> Result<MediumAdBatchOutput, MediumAdError> {
670        MediumAdBatchBuilder::new().kernel(k).apply_slice(data)
671    }
672
673    pub fn apply_candles(
674        self,
675        c: &Candles,
676        src: &str,
677    ) -> Result<MediumAdBatchOutput, MediumAdError> {
678        let slice = source_type(c, src);
679        self.apply_slice(slice)
680    }
681
682    pub fn with_default_candles(c: &Candles) -> Result<MediumAdBatchOutput, MediumAdError> {
683        MediumAdBatchBuilder::new()
684            .kernel(Kernel::Auto)
685            .apply_candles(c, "close")
686    }
687}
688
689#[derive(Clone, Debug)]
690pub struct MediumAdBatchOutput {
691    pub values: Vec<f64>,
692    pub combos: Vec<MediumAdParams>,
693    pub rows: usize,
694    pub cols: usize,
695}
696impl MediumAdBatchOutput {
697    pub fn row_for_params(&self, p: &MediumAdParams) -> Option<usize> {
698        self.combos
699            .iter()
700            .position(|c| c.period.unwrap_or(5) == p.period.unwrap_or(5))
701    }
702
703    pub fn values_for(&self, p: &MediumAdParams) -> Option<&[f64]> {
704        self.row_for_params(p).and_then(|row| {
705            let start = row.checked_mul(self.cols)?;
706            let end = start.checked_add(self.cols)?;
707            self.values.get(start..end)
708        })
709    }
710}
711
712#[inline(always)]
713fn expand_grid(r: &MediumAdBatchRange) -> Result<Vec<MediumAdParams>, MediumAdError> {
714    fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, MediumAdError> {
715        if step == 0 || start == end {
716            return Ok(vec![start]);
717        }
718        if start < end {
719            return Ok((start..=end).step_by(step.max(1)).collect());
720        }
721        let mut v = Vec::new();
722        let mut x = start as isize;
723        let end_i = end as isize;
724        let st = (step as isize).max(1);
725        while x >= end_i {
726            v.push(x as usize);
727            x = x.saturating_sub(st);
728            if x < 0 {
729                break;
730            }
731        }
732        if v.is_empty() {
733            return Err(MediumAdError::InvalidRange {
734                start: start.to_string(),
735                end: end.to_string(),
736                step: step.to_string(),
737            });
738        }
739        Ok(v)
740    }
741
742    let periods = axis_usize(r.period)?;
743    if periods.is_empty() {
744        return Err(MediumAdError::InvalidRange {
745            start: r.period.0.to_string(),
746            end: r.period.1.to_string(),
747            step: r.period.2.to_string(),
748        });
749    }
750
751    let mut out = Vec::with_capacity(periods.len());
752    for &p in &periods {
753        out.push(MediumAdParams { period: Some(p) });
754    }
755    Ok(out)
756}
757
758#[inline(always)]
759pub fn medium_ad_batch_slice(
760    data: &[f64],
761    sweep: &MediumAdBatchRange,
762    kern: Kernel,
763) -> Result<MediumAdBatchOutput, MediumAdError> {
764    medium_ad_batch_inner(data, sweep, kern, false)
765}
766
767#[inline(always)]
768pub fn medium_ad_batch_par_slice(
769    data: &[f64],
770    sweep: &MediumAdBatchRange,
771    kern: Kernel,
772) -> Result<MediumAdBatchOutput, MediumAdError> {
773    medium_ad_batch_inner(data, sweep, kern, true)
774}
775
776#[inline(always)]
777fn medium_ad_batch_inner(
778    data: &[f64],
779    sweep: &MediumAdBatchRange,
780    kern: Kernel,
781    parallel: bool,
782) -> Result<MediumAdBatchOutput, MediumAdError> {
783    let combos = expand_grid(sweep)?;
784
785    let cols = data.len();
786    if cols == 0 {
787        return Err(MediumAdError::AllValuesNaN);
788    }
789
790    let first = data
791        .iter()
792        .position(|x| !x.is_nan())
793        .ok_or(MediumAdError::AllValuesNaN)?;
794    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
795    if cols - first < max_p {
796        return Err(MediumAdError::NotEnoughValidData {
797            needed: max_p,
798            valid: cols - first,
799        });
800    }
801
802    let rows = combos.len();
803
804    let _total_elems = rows.checked_mul(cols).ok_or(MediumAdError::InvalidRange {
805        start: sweep.period.0.to_string(),
806        end: sweep.period.1.to_string(),
807        step: sweep.period.2.to_string(),
808    })?;
809    let mut buf_mu = make_uninit_matrix(rows, cols);
810    let warm: Vec<usize> = combos
811        .iter()
812        .map(|c| first + c.period.unwrap() - 1)
813        .collect();
814    init_matrix_prefixes(&mut buf_mu, cols, &warm);
815
816    let out_mu = buf_mu.as_mut_slice();
817
818    let do_row = |row: usize, dst_mu: &mut [core::mem::MaybeUninit<f64>]| {
819        let dst = unsafe {
820            core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len())
821        };
822        let period = combos[row].period.unwrap();
823
824        unsafe {
825            match kern {
826                Kernel::Scalar | Kernel::Auto => medium_ad_row_scalar(data, first, period, dst),
827                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
828                Kernel::Avx2 => medium_ad_row_avx2(data, first, period, dst),
829                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
830                Kernel::Avx512 => medium_ad_row_avx512(data, first, period, dst),
831                _ => unreachable!(),
832            }
833        }
834    };
835
836    if parallel {
837        #[cfg(not(target_arch = "wasm32"))]
838        {
839            use rayon::prelude::*;
840            out_mu
841                .par_chunks_mut(cols)
842                .enumerate()
843                .for_each(|(row, slice_mu)| do_row(row, slice_mu));
844        }
845        #[cfg(target_arch = "wasm32")]
846        {
847            for (row, slice_mu) in out_mu.chunks_mut(cols).enumerate() {
848                do_row(row, slice_mu);
849            }
850        }
851    } else {
852        for (row, slice_mu) in out_mu.chunks_mut(cols).enumerate() {
853            do_row(row, slice_mu);
854        }
855    }
856
857    let mut guard = core::mem::ManuallyDrop::new(buf_mu);
858    let values = unsafe {
859        Vec::from_raw_parts(
860            guard.as_mut_ptr() as *mut f64,
861            guard.len(),
862            guard.capacity(),
863        )
864    };
865
866    Ok(MediumAdBatchOutput {
867        values,
868        combos,
869        rows,
870        cols,
871    })
872}
873
874#[inline(always)]
875fn medium_ad_batch_inner_into(
876    data: &[f64],
877    sweep: &MediumAdBatchRange,
878    kern: Kernel,
879    parallel: bool,
880    out: &mut [f64],
881) -> Result<Vec<MediumAdParams>, MediumAdError> {
882    let combos = expand_grid(sweep)?;
883
884    let first = data
885        .iter()
886        .position(|x| !x.is_nan())
887        .ok_or(MediumAdError::AllValuesNaN)?;
888    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
889    if data.len() - first < max_p {
890        return Err(MediumAdError::NotEnoughValidData {
891            needed: max_p,
892            valid: data.len() - first,
893        });
894    }
895
896    let cols = data.len();
897
898    let _total_elems = combos
899        .len()
900        .checked_mul(cols)
901        .ok_or(MediumAdError::InvalidRange {
902            start: sweep.period.0.to_string(),
903            end: sweep.period.1.to_string(),
904            step: sweep.period.2.to_string(),
905        })?;
906
907    for (row, combo) in combos.iter().enumerate() {
908        let warmup = first + combo.period.unwrap() - 1;
909        let row_start = match row.checked_mul(cols) {
910            Some(v) => v,
911            None => {
912                return Err(MediumAdError::InvalidRange {
913                    start: sweep.period.0.to_string(),
914                    end: sweep.period.1.to_string(),
915                    step: sweep.period.2.to_string(),
916                })
917            }
918        };
919        for i in 0..warmup.min(cols) {
920            out[row_start + i] = f64::NAN;
921        }
922    }
923
924    let do_row = |row: usize, out_row: &mut [f64]| unsafe {
925        let period = combos[row].period.unwrap();
926        match kern {
927            Kernel::Scalar => medium_ad_row_scalar(data, first, period, out_row),
928            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
929            Kernel::Avx2 => medium_ad_row_avx2(data, first, period, out_row),
930            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
931            Kernel::Avx512 => medium_ad_row_avx512(data, first, period, out_row),
932            _ => unreachable!(),
933        }
934    };
935
936    if parallel {
937        #[cfg(not(target_arch = "wasm32"))]
938        {
939            out.par_chunks_mut(cols)
940                .enumerate()
941                .for_each(|(row, slice)| do_row(row, slice));
942        }
943
944        #[cfg(target_arch = "wasm32")]
945        {
946            for (row, slice) in out.chunks_mut(cols).enumerate() {
947                do_row(row, slice);
948            }
949        }
950    } else {
951        for (row, slice) in out.chunks_mut(cols).enumerate() {
952            do_row(row, slice);
953        }
954    }
955
956    Ok(combos)
957}
958
/// Scalar row kernel for the rolling median absolute deviation (MAD).
///
/// For every window `data[s ..= s + period - 1]` with `s >= first`, writes to
/// `out[s + period - 1]` the median of `|x - m|` over the window, where `m` is the
/// window median (an even-length median averages the two middle values). A window
/// containing any NaN produces NaN. With `period == 1`, every sample from `first` on
/// maps to `0.0`, or NaN if the sample is NaN. Entries of `out` before the first
/// written index are left untouched.
///
/// # Safety
/// Kept `unsafe` for signature parity with the SIMD row kernels. The body only uses
/// checked indexing, so an `out` shorter than `data` panics instead of being UB.
/// `period` must be non-zero; callers presumably validate this before dispatching.
#[inline(always)]
unsafe fn medium_ad_row_scalar(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    use core::cmp::Ordering;

    /// `|x|` by clearing the IEEE-754 sign bit.
    #[inline(always)]
    fn fast_abs_f64(x: f64) -> f64 {
        f64::from_bits(x.to_bits() & 0x7FFF_FFFF_FFFF_FFFF)
    }

    /// Median of a NaN-free `buf`, partially reordering it in place.
    /// `mid` must equal `buf.len() / 2`.
    #[inline(always)]
    fn median_from(buf: &mut [f64], mid: usize) -> f64 {
        // Input is NaN-free, so `partial_cmp` is total here; `Equal` is only a fallback.
        buf.select_nth_unstable_by(mid, |a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
        let upper = buf[mid];
        if buf.len() & 1 == 1 {
            return upper;
        }
        // After selection every element left of `mid` is <= `upper`, so the lower
        // middle value is the maximum of that left part.
        let lower = buf[..mid]
            .iter()
            .fold(f64::NEG_INFINITY, |acc, &v| if v > acc { v } else { acc });
        0.5 * (lower + upper)
    }

    if period == 1 {
        // A one-sample window deviates from its own median by exactly zero.
        for (i, &v) in data.iter().enumerate().skip(first) {
            out[i] = if v.is_nan() { f64::NAN } else { 0.0 };
        }
        return;
    }

    // Zero-initialised scratch window. `Vec::set_len` requires initialised elements,
    // so the buffer is filled up front rather than exposing uninitialised storage.
    let mut buf = vec![0.0_f64; period];
    let mid = period >> 1;

    // Window starting at `start` ends at `start + period - 1`; starts before `first`
    // are skipped, so the first output is at `first + period - 1`.
    for (start, window) in data.windows(period).enumerate().skip(first) {
        let i = start + period - 1;

        // Non-short-circuiting OR keeps the scan branch-free.
        let has_nan = window.iter().fold(false, |acc, v| acc | v.is_nan());
        if has_nan {
            out[i] = f64::NAN;
            continue;
        }

        buf.copy_from_slice(window);
        let med = median_from(&mut buf, mid);
        for v in buf.iter_mut() {
            *v = fast_abs_f64(*v - med);
        }
        out[i] = median_from(&mut buf, mid);
    }
}
1060
/// AVX2 row kernel: forwards the whole row to `medium_ad_avx2`.
///
/// Note the argument order at the call site: `period` is passed before `first`, the
/// reverse of this row signature. NOTE(review): presumably this matches
/// `medium_ad_avx2`'s parameter order — confirm if that signature changes.
///
/// # Safety
/// Inherits the callee's contract. NOTE(review): AVX2 availability is assumed to be
/// checked by the kernel dispatcher — confirm.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn medium_ad_row_avx2(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    medium_ad_avx2(data, period, first, out)
}
1066
/// AVX-512 row kernel dispatcher.
///
/// Windows of at most 32 samples go to the "short" variant, and wider windows go to the
/// "long" variant. Both variants currently forward to the same `medium_ad_avx512` kernel.
///
/// # Safety
/// Inherits the contract of the AVX-512 kernels.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn medium_ad_row_avx512(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    match period {
        0..=32 => medium_ad_row_avx512_short(data, first, period, out),
        _ => medium_ad_row_avx512_long(data, first, period, out),
    }
}
1076
/// AVX-512 row kernel for windows of at most 32 samples (selected by
/// `medium_ad_row_avx512`).
///
/// Forwards to `medium_ad_avx512`, passing `period` before `first`. It is currently
/// identical to `medium_ad_row_avx512_long`.
///
/// # Safety
/// Inherits the callee's contract (AVX-512 availability assumed checked upstream — confirm).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn medium_ad_row_avx512_short(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    medium_ad_avx512(data, period, first, out)
}
1082
/// AVX-512 row kernel for windows longer than 32 samples (selected by
/// `medium_ad_row_avx512`).
///
/// Forwards to `medium_ad_avx512`, passing `period` before `first`. It is currently
/// identical to `medium_ad_row_avx512_short`.
///
/// # Safety
/// Inherits the callee's contract (AVX-512 availability assumed checked upstream — confirm).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn medium_ad_row_avx512_long(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    medium_ad_avx512(data, period, first, out)
}
1088
/// Streaming (one sample at a time) rolling median absolute deviation over the last
/// `period` samples.
///
/// Samples are kept in a fixed-size ring. The non-NaN ones are also stored in an
/// order-statistic treap so that medians can be read by rank.
#[derive(Debug, Clone)]
pub struct MediumAdStream {
    /// Window length; `try_new` rejects 0.
    period: usize,

    /// The last `period` samples; `None` marks a NaN sample or a not-yet-written slot.
    ring: Vec<Option<Entry>>,
    /// Ring slot that the next sample overwrites.
    head: usize,
    /// Becomes `true` once `head` wraps back to 0, i.e. after `period` pushes.
    filled: bool,

    /// Order-statistic tree of the non-NaN samples currently in the window.
    os: OrderStatTree,
    /// Id for the next inserted sample (incremented with wrapping); breaks value ties.
    next_id: u64,
}
1100
/// A non-NaN sample as stored in the order-statistic tree.
///
/// `id` makes otherwise-equal values distinct. It breaks ties in the tree order and
/// identifies the exact node to delete when the sample leaves the window.
#[derive(Copy, Clone, Debug)]
struct Entry {
    /// Sample value.
    val: f64,
    /// Sequence number assigned when the sample was pushed.
    id: u64,
}
1106
1107impl MediumAdStream {
1108    pub fn try_new(params: MediumAdParams) -> Result<Self, MediumAdError> {
1109        let period = params.period.unwrap_or(5);
1110        if period == 0 {
1111            return Err(MediumAdError::InvalidPeriod {
1112                period,
1113                data_len: 0,
1114            });
1115        }
1116        Ok(Self {
1117            period,
1118            ring: vec![None; period],
1119            head: 0,
1120            filled: false,
1121            os: OrderStatTree::new(),
1122            next_id: 1,
1123        })
1124    }
1125
1126    #[inline(always)]
1127    pub fn update(&mut self, value: f64) -> Option<f64> {
1128        if let Some(old) = self.ring[self.head] {
1129            self.os.remove(old);
1130        }
1131
1132        let _inserted = if value.is_nan() {
1133            self.ring[self.head] = None;
1134            false
1135        } else {
1136            let e = Entry {
1137                val: value,
1138                id: self.next_id,
1139            };
1140            self.next_id = self.next_id.wrapping_add(1);
1141            self.os.insert(e);
1142            self.ring[self.head] = Some(e);
1143            true
1144        };
1145
1146        self.head = (self.head + 1) % self.period;
1147        if !self.filled && self.head == 0 {
1148            self.filled = true;
1149        }
1150
1151        if !self.filled || self.os.len() != self.period {
1152            return None;
1153        }
1154
1155        if self.period == 1 {
1156            return Some(0.0);
1157        }
1158
1159        let n = self.period;
1160        let left_sz = n >> 1;
1161        let median = if (n & 1) == 1 {
1162            self.os.kth(left_sz).val
1163        } else {
1164            let lo = self.os.kth(left_sz - 1).val;
1165            let hi = self.os.kth(left_sz).val;
1166            0.5 * (lo + hi)
1167        };
1168
1169        Some(self.mad_from_tree(median))
1170    }
1171
1172    #[inline(always)]
1173    fn ldist(&self, i: usize, median: f64, left_sz: usize) -> f64 {
1174        let idx = left_sz - 1 - i;
1175        let x = self.os.kth(idx).val;
1176
1177        median - x
1178    }
1179
1180    #[inline(always)]
1181    fn rdist(&self, j: usize, median: f64, left_sz: usize) -> f64 {
1182        let idx = left_sz + j;
1183        let x = self.os.kth(idx).val;
1184        x - median
1185    }
1186
1187    #[inline(always)]
1188    fn kth_absdev_union(&self, k: usize, median: f64, left_sz: usize) -> f64 {
1189        let right_sz = self.period - left_sz;
1190
1191        let mut lo = if k > right_sz { k - right_sz } else { 0 };
1192        let mut hi = k.min(left_sz);
1193
1194        while lo <= hi {
1195            let i = (lo + hi) >> 1;
1196            let j = k - i;
1197
1198            let l_im1 = if i == 0 {
1199                f64::NEG_INFINITY
1200            } else {
1201                self.ldist(i - 1, median, left_sz)
1202            };
1203            let l_i = if i == left_sz {
1204                f64::INFINITY
1205            } else {
1206                self.ldist(i, median, left_sz)
1207            };
1208
1209            let r_jm1 = if j == 0 {
1210                f64::NEG_INFINITY
1211            } else {
1212                self.rdist(j - 1, median, left_sz)
1213            };
1214            let r_j = if j == right_sz {
1215                f64::INFINITY
1216            } else {
1217                self.rdist(j, median, left_sz)
1218            };
1219
1220            if l_im1 <= r_j && r_jm1 <= l_i {
1221                return if l_im1 > r_jm1 { l_im1 } else { r_jm1 };
1222            } else if l_im1 > r_j {
1223                hi = i - 1;
1224            } else {
1225                lo = i + 1;
1226            }
1227        }
1228        debug_assert!(false, "kth_absdev_union: unreachable");
1229        0.0
1230    }
1231
1232    #[inline(always)]
1233    fn mad_from_tree(&self, median: f64) -> f64 {
1234        let n = self.period;
1235        let mid = n >> 1;
1236        let mut buf = Vec::with_capacity(n);
1237        for i in 0..n {
1238            let x = self.os.kth(i).val;
1239            buf.push((x - median).abs());
1240        }
1241        use core::cmp::Ordering;
1242        buf.select_nth_unstable_by(mid, |a, b| {
1243            if *a < *b {
1244                Ordering::Less
1245            } else if *a > *b {
1246                Ordering::Greater
1247            } else {
1248                Ordering::Equal
1249            }
1250        });
1251        if (n & 1) == 1 {
1252            buf[mid]
1253        } else {
1254            let mut lo_max = f64::NEG_INFINITY;
1255            for &v in &buf[..mid] {
1256                if v > lo_max {
1257                    lo_max = v;
1258                }
1259            }
1260            0.5 * (lo_max + buf[mid])
1261        }
1262    }
1263}
1264
/// Order-statistic treap over [`Entry`] values, ordered by `(val, id)` (see `less`).
///
/// Supports insert, delete and rank (`kth`) queries. Each node caches its subtree size.
#[derive(Default, Debug, Clone)]
struct OrderStatTree {
    /// Root of the treap; `None` when empty.
    root: Link,
}
1269
/// Owning, possibly-empty pointer to a treap subtree.
type Link = Option<Box<Node>>;
1271
/// Treap node: search-tree ordered by `key` (via `less`), heap-ordered by `prio`.
#[derive(Debug, Clone)]
struct Node {
    /// Sample stored at this node.
    key: Entry,
    /// Heap priority; a child whose `prio` exceeds its parent's is rotated above it.
    prio: u64,
    /// Number of nodes in this subtree, including this one (refreshed by `upd`).
    size: usize,
    /// Subtree of keys ordered before `key`.
    left: Link,
    /// Subtree of keys ordered at or after `key`.
    right: Link,
}
1280
1281impl OrderStatTree {
1282    #[inline(always)]
1283    fn new() -> Self {
1284        Self { root: None }
1285    }
1286
1287    #[inline(always)]
1288    fn len(&self) -> usize {
1289        size_of(&self.root)
1290    }
1291
1292    #[inline(always)]
1293    fn insert(&mut self, key: Entry) {
1294        let prio = treap_priority(key);
1295        self.root = insert_rec(self.root.take(), key, prio);
1296    }
1297
1298    #[inline(always)]
1299    fn remove(&mut self, key: Entry) {
1300        self.root = remove_rec(self.root.take(), key);
1301    }
1302
1303    #[inline(always)]
1304    fn kth(&self, k: usize) -> Entry {
1305        kth_rec(&self.root, k)
1306    }
1307}
1308
1309#[inline(always)]
1310fn size_of(n: &Link) -> usize {
1311    n.as_ref().map_or(0, |b| b.size)
1312}
1313
1314#[inline(always)]
1315fn upd(node: &mut Box<Node>) {
1316    node.size = 1 + size_of(&node.left) + size_of(&node.right);
1317}
1318
1319#[inline(always)]
1320fn less(a: Entry, b: Entry) -> bool {
1321    if a.val < b.val {
1322        true
1323    } else if a.val > b.val {
1324        false
1325    } else {
1326        a.id < b.id
1327    }
1328}
1329
1330#[inline(always)]
1331fn rotate_left(mut x: Box<Node>) -> Box<Node> {
1332    let mut y = x.right.take().expect("rotate_left requires right child");
1333    x.right = y.left.take();
1334    upd(&mut x);
1335    y.left = Some(x);
1336    upd(&mut y);
1337    y
1338}
1339
1340#[inline(always)]
1341fn rotate_right(mut y: Box<Node>) -> Box<Node> {
1342    let mut x = y.left.take().expect("rotate_right requires left child");
1343    y.left = x.right.take();
1344    upd(&mut y);
1345    x.right = Some(y);
1346    upd(&mut x);
1347    x
1348}
1349
1350fn insert_rec(node: Link, key: Entry, prio: u64) -> Link {
1351    match node {
1352        None => Some(Box::new(Node {
1353            key,
1354            prio,
1355            size: 1,
1356            left: None,
1357            right: None,
1358        })),
1359        Some(mut n) => {
1360            if less(key, n.key) {
1361                n.left = insert_rec(n.left.take(), key, prio);
1362                if n.left.as_ref().unwrap().prio > n.prio {
1363                    n = rotate_right(n);
1364                }
1365            } else {
1366                n.right = insert_rec(n.right.take(), key, prio);
1367                if n.right.as_ref().unwrap().prio > n.prio {
1368                    n = rotate_left(n);
1369                }
1370            }
1371            upd(&mut n);
1372            Some(n)
1373        }
1374    }
1375}
1376
1377fn remove_rec(node: Link, key: Entry) -> Link {
1378    match node {
1379        None => None,
1380        Some(mut n) => {
1381            if n.key.id == key.id {
1382                return match (n.left.take(), n.right.take()) {
1383                    (None, r) => r,
1384                    (l, None) => l,
1385                    (Some(lc), Some(rc)) => {
1386                        let (mut n2, left_is_higher) = if lc.prio > rc.prio {
1387                            let mut n2 = Box::new(Node {
1388                                key: n.key,
1389                                prio: n.prio,
1390                                size: 0,
1391                                left: Some(lc),
1392                                right: Some(rc),
1393                            });
1394                            n2 = rotate_right(n2);
1395                            (n2, true)
1396                        } else {
1397                            let mut n2 = Box::new(Node {
1398                                key: n.key,
1399                                prio: n.prio,
1400                                size: 0,
1401                                left: Some(lc),
1402                                right: Some(rc),
1403                            });
1404                            n2 = rotate_left(n2);
1405                            (n2, false)
1406                        };
1407                        if left_is_higher {
1408                            n2.right = remove_rec(n2.right.take(), key);
1409                        } else {
1410                            n2.left = remove_rec(n2.left.take(), key);
1411                        }
1412                        upd(&mut n2);
1413                        Some(n2)
1414                    }
1415                };
1416            }
1417            if less(key, n.key) {
1418                n.left = remove_rec(n.left.take(), key);
1419            } else {
1420                n.right = remove_rec(n.right.take(), key);
1421            }
1422            upd(&mut n);
1423            Some(n)
1424        }
1425    }
1426}
1427
1428fn kth_rec(node: &Link, mut k: usize) -> Entry {
1429    let n = node.as_ref().expect("kth_rec on empty tree");
1430    let ls = size_of(&n.left);
1431    if k < ls {
1432        kth_rec(&n.left, k)
1433    } else if k == ls {
1434        n.key
1435    } else {
1436        k -= ls + 1;
1437        kth_rec(&n.right, k)
1438    }
1439}
1440
1441#[inline(always)]
1442fn treap_priority(e: Entry) -> u64 {
1443    let mut z = e.id ^ e.val.to_bits();
1444    z = z.wrapping_add(0x9E37_79B9_7F4A_7C15);
1445    z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
1446    z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
1447    z ^ (z >> 31)
1448}
1449
1450#[cfg(test)]
1451mod tests {
1452    use super::*;
1453    use crate::skip_if_unsupported;
1454    use crate::utilities::data_loader::read_candles_from_csv;
1455    #[cfg(feature = "proptest")]
1456    use proptest::prelude::*;
1457
1458    fn check_medium_ad_partial_params(
1459        test_name: &str,
1460        kernel: Kernel,
1461    ) -> Result<(), Box<dyn std::error::Error>> {
1462        skip_if_unsupported!(kernel, test_name);
1463        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1464        let candles = read_candles_from_csv(file_path)?;
1465
1466        let default_params = MediumAdParams { period: None };
1467        let input = MediumAdInput::from_candles(&candles, "close", default_params);
1468        let output = medium_ad_with_kernel(&input, kernel)?;
1469        assert_eq!(output.values.len(), candles.close.len());
1470        Ok(())
1471    }
1472
1473    fn check_medium_ad_accuracy(
1474        test_name: &str,
1475        kernel: Kernel,
1476    ) -> Result<(), Box<dyn std::error::Error>> {
1477        skip_if_unsupported!(kernel, test_name);
1478        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1479        let candles = read_candles_from_csv(file_path)?;
1480
1481        let params = MediumAdParams { period: Some(5) };
1482        let input = MediumAdInput::from_candles(&candles, "hl2", params);
1483        let result = medium_ad_with_kernel(&input, kernel)?;
1484        let expected_last_five = [220.0, 78.5, 126.5, 48.0, 28.5];
1485        let start = result.values.len().saturating_sub(5);
1486        for (i, &val) in result.values[start..].iter().enumerate() {
1487            let diff = (val - expected_last_five[i]).abs();
1488            assert!(
1489                diff < 1e-1,
1490                "[{}] MEDIUM_AD {:?} mismatch at idx {}: got {}, expected {}",
1491                test_name,
1492                kernel,
1493                i,
1494                val,
1495                expected_last_five[i]
1496            );
1497        }
1498        Ok(())
1499    }
1500
1501    fn check_medium_ad_default_candles(
1502        test_name: &str,
1503        kernel: Kernel,
1504    ) -> Result<(), Box<dyn std::error::Error>> {
1505        skip_if_unsupported!(kernel, test_name);
1506        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1507        let candles = read_candles_from_csv(file_path)?;
1508        let input = MediumAdInput::with_default_candles(&candles);
1509        match input.data {
1510            MediumAdData::Candles { source, .. } => assert_eq!(source, "close"),
1511            _ => panic!("Expected MediumAdData::Candles"),
1512        }
1513        let output = medium_ad_with_kernel(&input, kernel)?;
1514        assert_eq!(output.values.len(), candles.close.len());
1515        Ok(())
1516    }
1517
1518    fn check_medium_ad_zero_period(
1519        test_name: &str,
1520        kernel: Kernel,
1521    ) -> Result<(), Box<dyn std::error::Error>> {
1522        skip_if_unsupported!(kernel, test_name);
1523        let input_data = [10.0, 20.0, 30.0];
1524        let params = MediumAdParams { period: Some(0) };
1525        let input = MediumAdInput::from_slice(&input_data, params);
1526        let res = medium_ad_with_kernel(&input, kernel);
1527        assert!(
1528            res.is_err(),
1529            "[{}] MEDIUM_AD should fail with zero period",
1530            test_name
1531        );
1532        Ok(())
1533    }
1534
1535    fn check_medium_ad_period_exceeds_length(
1536        test_name: &str,
1537        kernel: Kernel,
1538    ) -> Result<(), Box<dyn std::error::Error>> {
1539        skip_if_unsupported!(kernel, test_name);
1540        let data_small = [10.0, 20.0, 30.0];
1541        let params = MediumAdParams { period: Some(10) };
1542        let input = MediumAdInput::from_slice(&data_small, params);
1543        let res = medium_ad_with_kernel(&input, kernel);
1544        assert!(
1545            res.is_err(),
1546            "[{}] MEDIUM_AD should fail with period exceeding length",
1547            test_name
1548        );
1549        Ok(())
1550    }
1551
1552    fn check_medium_ad_very_small_dataset(
1553        test_name: &str,
1554        kernel: Kernel,
1555    ) -> Result<(), Box<dyn std::error::Error>> {
1556        skip_if_unsupported!(kernel, test_name);
1557        let single_point = [42.0];
1558        let params = MediumAdParams { period: Some(5) };
1559        let input = MediumAdInput::from_slice(&single_point, params);
1560        let res = medium_ad_with_kernel(&input, kernel);
1561        assert!(
1562            res.is_err(),
1563            "[{}] MEDIUM_AD should fail with insufficient data",
1564            test_name
1565        );
1566        Ok(())
1567    }
1568
1569    fn check_medium_ad_reinput(
1570        test_name: &str,
1571        kernel: Kernel,
1572    ) -> Result<(), Box<dyn std::error::Error>> {
1573        skip_if_unsupported!(kernel, test_name);
1574        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1575        let candles = read_candles_from_csv(file_path)?;
1576
1577        let first_params = MediumAdParams { period: Some(5) };
1578        let first_input = MediumAdInput::from_candles(&candles, "close", first_params);
1579        let first_result = medium_ad_with_kernel(&first_input, kernel)?;
1580
1581        let second_params = MediumAdParams { period: Some(5) };
1582        let second_input = MediumAdInput::from_slice(&first_result.values, second_params);
1583        let second_result = medium_ad_with_kernel(&second_input, kernel)?;
1584
1585        assert_eq!(second_result.values.len(), first_result.values.len());
1586        Ok(())
1587    }
1588
1589    fn check_medium_ad_nan_handling(
1590        test_name: &str,
1591        kernel: Kernel,
1592    ) -> Result<(), Box<dyn std::error::Error>> {
1593        skip_if_unsupported!(kernel, test_name);
1594        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1595        let candles = read_candles_from_csv(file_path)?;
1596        let input =
1597            MediumAdInput::from_candles(&candles, "close", MediumAdParams { period: Some(5) });
1598        let res = medium_ad_with_kernel(&input, kernel)?;
1599        assert_eq!(res.values.len(), candles.close.len());
1600        if res.values.len() > 60 {
1601            for (i, &val) in res.values[60..].iter().enumerate() {
1602                assert!(
1603                    !val.is_nan(),
1604                    "[{}] Found unexpected NaN at out-index {}",
1605                    test_name,
1606                    60 + i
1607                );
1608            }
1609        }
1610        Ok(())
1611    }
1612
1613    fn check_medium_ad_streaming(
1614        test_name: &str,
1615        kernel: Kernel,
1616    ) -> Result<(), Box<dyn std::error::Error>> {
1617        skip_if_unsupported!(kernel, test_name);
1618
1619        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1620        let candles = read_candles_from_csv(file_path)?;
1621
1622        let period = 5;
1623        let input = MediumAdInput::from_candles(
1624            &candles,
1625            "close",
1626            MediumAdParams {
1627                period: Some(period),
1628            },
1629        );
1630        let batch_output = medium_ad_with_kernel(&input, kernel)?.values;
1631
1632        let mut stream = MediumAdStream::try_new(MediumAdParams {
1633            period: Some(period),
1634        })?;
1635
1636        let mut stream_values = Vec::with_capacity(candles.close.len());
1637        for &price in &candles.close {
1638            match stream.update(price) {
1639                Some(mad_val) => stream_values.push(mad_val),
1640                None => stream_values.push(f64::NAN),
1641            }
1642        }
1643
1644        assert_eq!(batch_output.len(), stream_values.len());
1645        for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
1646            if b.is_nan() && s.is_nan() {
1647                continue;
1648            }
1649            let diff = (b - s).abs();
1650            assert!(
1651                diff < 1e-9,
1652                "[{}] MEDIUM_AD streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
1653                test_name,
1654                i,
1655                b,
1656                s,
1657                diff
1658            );
1659        }
1660        Ok(())
1661    }
1662
1663    #[cfg(feature = "proptest")]
1664    #[allow(clippy::float_cmp)]
1665    fn check_medium_ad_property(
1666        test_name: &str,
1667        kernel: Kernel,
1668    ) -> Result<(), Box<dyn std::error::Error>> {
1669        skip_if_unsupported!(kernel, test_name);
1670
1671        let strat = (1usize..=64).prop_flat_map(|period| {
1672            (
1673                prop::collection::vec(
1674                    (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
1675                    period..400,
1676                ),
1677                Just(period),
1678            )
1679        });
1680
1681        proptest::test_runner::TestRunner::default()
1682            .run(&strat, |(data, period)| {
1683                let params = MediumAdParams {
1684                    period: Some(period),
1685                };
1686                let input = MediumAdInput::from_slice(&data, params);
1687
1688                let MediumAdOutput { values: out } = medium_ad_with_kernel(&input, kernel).unwrap();
1689
1690                let MediumAdOutput { values: ref_out } =
1691                    medium_ad_with_kernel(&input, Kernel::Scalar).unwrap();
1692
1693                for i in 0..data.len() {
1694                    let y = out[i];
1695                    let r = ref_out[i];
1696
1697                    if y.is_nan() {
1698                        prop_assert!(r.is_nan(), "Kernel consistency: NaN mismatch at idx {}", i);
1699                    } else if r.is_nan() {
1700                        prop_assert!(y.is_nan(), "Kernel consistency: NaN mismatch at idx {}", i);
1701                    } else {
1702                        let ulp_diff = y.to_bits().abs_diff(r.to_bits());
1703                        prop_assert!(
1704                            (y - r).abs() <= 1e-9 || ulp_diff <= 4,
1705                            "Kernel mismatch at idx {}: {} vs {} (ULP={})",
1706                            i,
1707                            y,
1708                            r,
1709                            ulp_diff
1710                        );
1711                    }
1712                }
1713
1714                for i in 0..(period - 1) {
1715                    prop_assert!(
1716                        out[i].is_nan(),
1717                        "Expected NaN during warmup at idx {}, got {}",
1718                        i,
1719                        out[i]
1720                    );
1721                }
1722
1723                for i in (period - 1)..data.len() {
1724                    let mad = out[i];
1725                    prop_assert!(
1726                        mad.is_finite() && mad >= 0.0,
1727                        "MAD at idx {} is not finite or negative: {}",
1728                        i,
1729                        mad
1730                    );
1731                }
1732
1733                if data.windows(2).all(|w| (w[0] - w[1]).abs() < f64::EPSILON)
1734                    && data.len() >= period
1735                {
1736                    for i in (period - 1)..data.len() {
1737                        prop_assert!(
1738                            out[i].abs() < 1e-9,
1739                            "Constant data should have MAD=0.0, got {} at idx {}",
1740                            out[i],
1741                            i
1742                        );
1743                    }
1744                }
1745
1746                for i in (period - 1)..data.len() {
1747                    let window = &data[i + 1 - period..=i];
1748                    let min_val = window.iter().cloned().fold(f64::INFINITY, f64::min);
1749                    let max_val = window.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
1750                    let range = max_val - min_val;
1751                    let mad = out[i];
1752
1753                    prop_assert!(
1754                        mad <= range * 0.5 + 1e-9,
1755                        "MAD {} exceeds theoretical maximum (50% of range {}) at idx {}",
1756                        mad,
1757                        range * 0.5,
1758                        i
1759                    );
1760                }
1761
1762                if period == 1 {
1763                    for i in 0..data.len() {
1764                        if !out[i].is_nan() {
1765                            prop_assert!(
1766                                out[i].abs() < f64::EPSILON,
1767                                "Period=1 should have MAD=0.0, got {} at idx {}",
1768                                out[i],
1769                                i
1770                            );
1771                        }
1772                    }
1773                }
1774
1775                let neg_data: Vec<f64> = data.iter().map(|&x| -x).collect();
1776                let neg_input = MediumAdInput::from_slice(
1777                    &neg_data,
1778                    MediumAdParams {
1779                        period: Some(period),
1780                    },
1781                );
1782                let MediumAdOutput { values: neg_out } =
1783                    medium_ad_with_kernel(&neg_input, kernel).unwrap();
1784
1785                for i in (period - 1)..data.len() {
1786                    let mad = out[i];
1787                    let neg_mad = neg_out[i];
1788                    prop_assert!(
1789                        (mad - neg_mad).abs() < 1e-9,
1790                        "Symmetry failed at idx {}: {} vs {}",
1791                        i,
1792                        mad,
1793                        neg_mad
1794                    );
1795                }
1796
1797                let scale_factors = [2.0, -3.0, 0.5];
1798                for &scale in &scale_factors {
1799                    let scaled_data: Vec<f64> = data.iter().map(|&x| x * scale).collect();
1800                    let scaled_input = MediumAdInput::from_slice(
1801                        &scaled_data,
1802                        MediumAdParams {
1803                            period: Some(period),
1804                        },
1805                    );
1806                    let MediumAdOutput { values: scaled_out } =
1807                        medium_ad_with_kernel(&scaled_input, kernel).unwrap();
1808
1809                    for i in (period - 1)..data.len() {
1810                        let original_mad = out[i];
1811                        let scaled_mad = scaled_out[i];
1812                        let expected_scaled_mad = original_mad * scale.abs();
1813
1814                        prop_assert!(
1815                            (scaled_mad - expected_scaled_mad).abs() < 1e-9,
1816                            "Scale invariance failed at idx {} with scale {}: {} vs expected {}",
1817                            i,
1818                            scale,
1819                            scaled_mad,
1820                            expected_scaled_mad
1821                        );
1822                    }
1823                }
1824
1825                if period >= 5 && data.len() >= period + 10 {
1826                    let mut outlier_data = data.clone();
1827
1828                    for test_idx in (period + 4..data.len().min(period + 20)).step_by(5) {
1829                        let window = &data[test_idx + 1 - period..=test_idx];
1830                        let win_min = window.iter().cloned().fold(f64::INFINITY, f64::min);
1831                        let win_max = window.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
1832                        let win_range = win_max - win_min;
1833
1834                        let outlier_idx = test_idx - period / 2;
1835                        let original_value = outlier_data[outlier_idx];
1836                        outlier_data[outlier_idx] = win_max + win_range * 10.0;
1837
1838                        let outlier_input = MediumAdInput::from_slice(
1839                            &outlier_data,
1840                            MediumAdParams {
1841                                period: Some(period),
1842                            },
1843                        );
1844                        let MediumAdOutput {
1845                            values: outlier_out,
1846                        } = medium_ad_with_kernel(&outlier_input, kernel).unwrap();
1847
1848                        let original_mad = out[test_idx];
1849                        let outlier_mad = outlier_out[test_idx];
1850
1851                        let outlier_effect = win_range * 10.0;
1852                        prop_assert!(
1853							outlier_mad <= original_mad * 10.0 + outlier_effect * 0.1,
1854							"MAD not robust enough to outliers at idx {}: original {}, with outlier {}",
1855							test_idx, original_mad, outlier_mad
1856						);
1857
1858                        if original_mad > win_range * 0.05 {
1859                            let mad_ratio = outlier_mad / original_mad;
1860                            prop_assert!(
1861                                mad_ratio <= 25.0,
1862                                "MAD ratio too high with outlier at idx {}: ratio {}",
1863                                test_idx,
1864                                mad_ratio
1865                            );
1866                        }
1867
1868                        outlier_data[outlier_idx] = original_value;
1869                    }
1870                }
1871
1872                if period >= 3 && period <= 20 {
1873                    let sequential: Vec<f64> = (1..=period).map(|i| i as f64).collect();
1874                    let seq_input = MediumAdInput::from_slice(
1875                        &sequential,
1876                        MediumAdParams {
1877                            period: Some(period),
1878                        },
1879                    );
1880                    let MediumAdOutput { values: seq_out } =
1881                        medium_ad_with_kernel(&seq_input, kernel).unwrap();
1882
1883                    let median = if period % 2 == 1 {
1884                        (period / 2 + 1) as f64
1885                    } else {
1886                        (period / 2) as f64 + 0.5
1887                    };
1888
1889                    if period - 1 < sequential.len() {
1890                        let calculated_mad = seq_out[period - 1];
1891                        let seq_range = (period - 1) as f64;
1892
1893                        prop_assert!(
1894							calculated_mad > 0.0 && calculated_mad <= seq_range * 0.5,
1895							"MAD for sequential data with period {} out of bounds: got {}, range is {}",
1896							period, calculated_mad, seq_range
1897						);
1898
1899                        if period == 3 {
1900                            prop_assert!(
1901                                (calculated_mad - 1.0).abs() < 1e-9,
1902                                "MAD for [1,2,3] should be 1.0, got {}",
1903                                calculated_mad
1904                            );
1905                        } else if period == 5 {
1906                            prop_assert!(
1907                                (calculated_mad - 1.0).abs() < 1e-9,
1908                                "MAD for [1,2,3,4,5] should be 1.0, got {}",
1909                                calculated_mad
1910                            );
1911                        }
1912                    }
1913                }
1914
1915                if period >= 4 && period % 2 == 0 {
1916                    let mut extreme_data = vec![0.0; period];
1917                    for i in 0..period / 2 {
1918                        extreme_data[i] = 100.0;
1919                    }
1920
1921                    let extreme_input = MediumAdInput::from_slice(
1922                        &extreme_data,
1923                        MediumAdParams {
1924                            period: Some(period),
1925                        },
1926                    );
1927                    let MediumAdOutput {
1928                        values: extreme_out,
1929                    } = medium_ad_with_kernel(&extreme_input, kernel).unwrap();
1930
1931                    let expected_extreme_mad = 50.0;
1932
1933                    if period - 1 < extreme_data.len() {
1934                        let calculated_extreme_mad = extreme_out[period - 1];
1935                        prop_assert!(
1936							(calculated_extreme_mad - expected_extreme_mad).abs() < 1e-9,
1937							"MAD mismatch for extreme data pattern with period {}: got {}, expected {}",
1938							period, calculated_extreme_mad, expected_extreme_mad
1939						);
1940                    }
1941                }
1942
1943                Ok(())
1944            })
1945            .unwrap();
1946
1947        Ok(())
1948    }
1949
1950    macro_rules! generate_all_medium_ad_tests {
1951        ($($test_fn:ident),*) => {
1952            paste::paste! {
1953                $(
1954                    #[test]
1955                    fn [<$test_fn _scalar_f64>]() {
1956                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1957                    }
1958                )*
1959                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1960                $(
1961                    #[test]
1962                    fn [<$test_fn _avx2_f64>]() {
1963                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1964                    }
1965                    #[test]
1966                    fn [<$test_fn _avx512_f64>]() {
1967                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1968                    }
1969                )*
1970            }
1971        }
1972    }
1973
    // Instantiate the per-kernel unit tests for every MediumAD check.
    // NOTE: the macro pattern is `$($test_fn:ident),*`, so this list must not
    // end with a trailing comma.
    generate_all_medium_ad_tests!(
        check_medium_ad_partial_params,
        check_medium_ad_accuracy,
        check_medium_ad_default_candles,
        check_medium_ad_zero_period,
        check_medium_ad_period_exceeds_length,
        check_medium_ad_very_small_dataset,
        check_medium_ad_reinput,
        check_medium_ad_nan_handling,
        check_medium_ad_streaming
    );

    // The property-based check is generated only when the `proptest` feature is enabled.
    #[cfg(feature = "proptest")]
    generate_all_medium_ad_tests!(check_medium_ad_property);
1988
1989    fn check_batch_default_row(
1990        test: &str,
1991        kernel: Kernel,
1992    ) -> Result<(), Box<dyn std::error::Error>> {
1993        skip_if_unsupported!(kernel, test);
1994
1995        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1996        let c = read_candles_from_csv(file)?;
1997
1998        let output = MediumAdBatchBuilder::new()
1999            .kernel(kernel)
2000            .apply_candles(&c, "close")?;
2001
2002        let def = MediumAdParams::default();
2003        let row = output.values_for(&def).expect("default row missing");
2004
2005        assert_eq!(row.len(), c.close.len());
2006        Ok(())
2007    }
2008
2009    macro_rules! gen_batch_tests {
2010        ($fn_name:ident) => {
2011            paste::paste! {
2012                #[test] fn [<$fn_name _scalar>]()      {
2013                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
2014                }
2015                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2016                #[test] fn [<$fn_name _avx2>]()        {
2017                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
2018                }
2019                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2020                #[test] fn [<$fn_name _avx512>]()      {
2021                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
2022                }
2023                #[test] fn [<$fn_name _auto_detect>]() {
2024                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
2025                }
2026            }
2027        };
2028    }
    // Generate the scalar / AVX2 / AVX-512 / auto-detect variants of the batch default-row check.
    gen_batch_tests!(check_batch_default_row);
2030}
2031
2032#[cfg(feature = "python")]
2033use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
2034#[cfg(feature = "python")]
2035use pyo3::exceptions::{PyBufferError, PyValueError};
2036#[cfg(feature = "python")]
2037use pyo3::prelude::*;
2038#[cfg(feature = "python")]
2039use pyo3::types::PyDict;
2040
2041#[cfg(feature = "python")]
2042use crate::utilities::kernel_validation::validate_kernel;
2043
/// Python binding: MediumAD of a 1-D float64 array for a single `period`.
///
/// `kernel` is an optional kernel name validated by `validate_kernel`. Any
/// computation error is raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "medium_ad")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn medium_ad_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let series = data.as_slice()?;
    let chosen_kernel = validate_kernel(kernel, false)?;

    let input = MediumAdInput::from_slice(
        series,
        MediumAdParams {
            period: Some(period),
        },
    );

    // The numeric work runs with the GIL released.
    let values = py
        .allow_threads(|| medium_ad_with_kernel(&input, chosen_kernel))
        .map(|output| output.values)
        .map_err(|err| PyValueError::new_err(err.to_string()))?;

    Ok(values.into_pyarray(py))
}
2067
/// Python wrapper around the incremental `MediumAdStream`.
#[cfg(feature = "python")]
#[pyclass(name = "MediumAdStream")]
pub struct MediumAdStreamPy {
    stream: MediumAdStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl MediumAdStreamPy {
    /// Create a stream for `period`; invalid parameters raise `ValueError`.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        MediumAdStream::try_new(MediumAdParams {
            period: Some(period),
        })
        .map(|stream| MediumAdStreamPy { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Push one value and return the stream's output (`None` while it has no value to report).
    fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
2091
/// Python binding: MediumAD for every period in `period_range = (start, end, step)`.
///
/// Returns a dict with:
/// - `values`: a `(rows, cols)` float64 array, one row per expanded period.
/// - `periods`: the matching uint64 periods.
///
/// A kernel resolved to `Kernel::Auto` runs the scalar batch path.
#[cfg(feature = "python")]
#[pyfunction(name = "medium_ad_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn medium_ad_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let series = data.as_slice()?;
    let requested = validate_kernel(kernel, true)?;

    let sweep = MediumAdBatchRange {
        period: period_range,
    };

    let grid = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let (rows, cols) = (grid.len(), series.len());

    let total = match rows.checked_mul(cols) {
        Some(n) => n,
        None => {
            return Err(PyValueError::new_err(
                "medium_ad: batch output size overflow",
            ))
        }
    };

    // SAFETY: the array is created uninitialised; `medium_ad_batch_inner_into`
    // is relied on to fill all `total` cells before the array is handed to Python.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let out_cells = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            // Translate the batch kernel into the per-row kernel the inner routine
            // expects; `Auto` is treated as the scalar batch path.
            let row_kernel = match requested {
                Kernel::Auto | Kernel::ScalarBatch => Kernel::Scalar,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::Avx512Batch => Kernel::Avx512,
                _ => unreachable!(),
            };
            medium_ad_batch_inner_into(series, &sweep, row_kernel, true, out_cells)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;

    let periods: Vec<u64> = combos
        .iter()
        .map(|combo| combo.period.unwrap() as u64)
        .collect();
    dict.set_item("periods", periods.into_pyarray(py))?;

    Ok(dict)
}
2147
2148#[cfg(all(feature = "python", feature = "cuda"))]
2149use crate::cuda::medium_ad_wrapper::CudaMediumAd;
2150#[cfg(all(feature = "python", feature = "cuda"))]
2151use crate::cuda::moving_averages::DeviceArrayF32;
2152#[cfg(all(feature = "python", feature = "cuda"))]
2153use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
2154#[cfg(all(feature = "python", feature = "cuda"))]
2155use cust::context::Context;
2156#[cfg(all(feature = "python", feature = "cuda"))]
2157use cust::memory::DeviceBuffer;
2158#[cfg(all(feature = "python", feature = "cuda"))]
2159use std::sync::Arc;
2160
/// Python handle to a MediumAD result held in CUDA device memory.
///
/// The buffer is a C-contiguous `rows x cols` float32 matrix. It is exposed
/// through `__cuda_array_interface__` (v3) and the DLPack protocol.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct MediumAdDeviceArrayF32Py {
    /// Device buffer plus its logical shape.
    pub(crate) inner: DeviceArrayF32,
    // Not read by any method here; presumably held so the CUDA context outlives
    // `inner` — TODO confirm.
    ctx: Arc<Context>,
    /// Device ordinal reported by the `CudaMediumAd` instance that produced `inner`.
    device_id: u32,
}

#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl MediumAdDeviceArrayF32Py {
    /// CUDA Array Interface (version 3) for the device buffer.
    ///
    /// Describes a row-major `(rows, cols)` float32 array. The data pointer is
    /// reported as 0 for an empty array.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let inner = &self.inner;
        let d = PyDict::new(py);

        d.set_item("shape", (inner.rows, inner.cols))?;

        d.set_item("typestr", "<f4")?;

        // Byte strides of a C-contiguous f32 matrix.
        d.set_item(
            "strides",
            (
                inner.cols * std::mem::size_of::<f32>(),
                std::mem::size_of::<f32>(),
            ),
        )?;
        let ptr_val = if inner.rows == 0 || inner.cols == 0 {
            0usize
        } else {
            inner.device_ptr() as usize
        };
        // `(pointer, read_only)`: the buffer is exported as writable.
        d.set_item("data", (ptr_val, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    /// DLPack device tuple `(2, ordinal)`, where 2 is `kDLCUDA`.
    ///
    /// The ordinal is first queried from the driver for the buffer pointer.
    /// If that query fails, the ordinal recorded when this array was created
    /// is used. The thread's current context may belong to another device, or
    /// there may be none.
    fn __dlpack_device__(&self) -> PyResult<(i32, i32)> {
        // Device ordinals are small, so this u32 -> i32 cast does not truncate in practice.
        let mut device_ordinal = self.device_id as i32;
        // SAFETY: `value` is a valid out-pointer for one `i32` and is read only
        // after the driver reports success.
        unsafe {
            let attr = cust::sys::CUpointer_attribute::CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL;
            let mut value = std::mem::MaybeUninit::<i32>::uninit();
            let err = cust::sys::cuPointerGetAttribute(
                value.as_mut_ptr() as *mut std::ffi::c_void,
                attr,
                self.inner.buf.as_device_ptr().as_raw(),
            );
            if err == cust::sys::CUresult::CUDA_SUCCESS {
                device_ordinal = value.assume_init();
            }
        }
        Ok((2, device_ordinal))
    }

    /// Export the buffer as a DLPack capsule, transferring ownership to the consumer.
    ///
    /// After a successful call this object holds an empty `0 x 0` placeholder.
    ///
    /// Errors:
    /// - `ValueError` if `stream` is the integer 0 (disallowed for CUDA).
    /// - `BufferError` if `dl_device` names a different device, or if
    ///   `copy=True` is requested (copies are not supported).
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        if let Some(obj) = &stream {
            if let Ok(i) = obj.extract::<i64>(py) {
                if i == 0 {
                    return Err(PyValueError::new_err(
                        "__dlpack__: stream 0 is disallowed for CUDA",
                    ));
                }
            }
        }

        // An absent or non-bool `copy` means no copy was requested.
        let wants_copy = copy
            .as_ref()
            .and_then(|c| c.extract::<bool>(py).ok())
            .unwrap_or(false);

        let (kdl, alloc_dev) = self.__dlpack_device__()?;
        if let Some(dev_obj) = dl_device.as_ref() {
            if let Ok((dev_type, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_type != kdl || dev_id != alloc_dev {
                    let msg = if wants_copy {
                        "__dlpack__: copy semantics are not supported for medium_ad"
                    } else {
                        "__dlpack__: requested device does not match allocation device"
                    };
                    return Err(PyBufferError::new_err(msg));
                }
            }
        }

        if wants_copy {
            return Err(PyBufferError::new_err(
                "__dlpack__: copy semantics are not supported for medium_ad",
            ));
        }

        // Swap in an empty placeholder so the exported buffer is owned solely by
        // the capsule.
        let dummy =
            DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let inner = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32 {
                buf: dummy,
                rows: 0,
                cols: 0,
            },
        );

        let rows = inner.rows;
        let cols = inner.cols;
        let buf = inner.buf;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));
        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
2287
/// Python binding: MediumAD on CUDA for every period in `period_range`.
///
/// The input is a float32 series; the result stays on the device and is
/// returned as a `MediumAdDeviceArrayF32Py`. Raises `ValueError` if CUDA is
/// unavailable or the GPU wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "medium_ad_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, device_id=0))]
pub fn medium_ad_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: numpy::PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<MediumAdDeviceArrayF32Py> {
    use crate::cuda::cuda_available;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let prices = data_f32.as_slice()?;
    let sweep = MediumAdBatchRange {
        period: period_range,
    };

    // GPU setup and the launch both run with the GIL released.
    let launched = py.allow_threads(|| -> PyResult<(DeviceArrayF32, Arc<Context>, u32)> {
        let engine =
            CudaMediumAd::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let context = engine.context_arc();
        let ordinal = engine.device_id();
        engine
            .medium_ad_batch_dev(prices, &sweep)
            .map(|device_out| (device_out, context, ordinal))
            .map_err(|e| PyValueError::new_err(e.to_string()))
    });
    let (inner, ctx, ordinal) = launched?;

    Ok(MediumAdDeviceArrayF32Py {
        inner,
        ctx,
        device_id: ordinal,
    })
}
2322
/// Python binding: MediumAD on CUDA for many series that share one `period`.
///
/// `data_tm_f32` is a flat float32 buffer. It is assumed to be time-major,
/// with `rows` time steps and `cols` series, as the wrapper name suggests —
/// TODO confirm. The result stays on the device. Raises `ValueError` if CUDA
/// is unavailable or the GPU wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "medium_ad_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, cols, rows, period, device_id=0))]
pub fn medium_ad_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    period: usize,
    device_id: usize,
) -> PyResult<MediumAdDeviceArrayF32Py> {
    use crate::cuda::cuda_available;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let matrix = data_tm_f32.as_slice()?;

    // GPU setup and the launch both run with the GIL released.
    let launched = py.allow_threads(|| -> PyResult<(DeviceArrayF32, Arc<Context>, u32)> {
        let engine =
            CudaMediumAd::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let context = engine.context_arc();
        let ordinal = engine.device_id();
        engine
            .medium_ad_many_series_one_param_time_major_dev(matrix, cols, rows, period)
            .map(|device_out| (device_out, context, ordinal))
            .map_err(|e| PyValueError::new_err(e.to_string()))
    });
    let (inner, ctx, ordinal) = launched?;

    Ok(MediumAdDeviceArrayF32Py {
        inner,
        ctx,
        device_id: ordinal,
    })
}
2356
2357#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2358use serde::{Deserialize, Serialize};
2359#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2360use wasm_bindgen::prelude::*;
2361
2362pub fn medium_ad_into_slice(
2363    dst: &mut [f64],
2364    input: &MediumAdInput,
2365    kern: Kernel,
2366) -> Result<(), MediumAdError> {
2367    let data = match &input.data {
2368        MediumAdData::Candles { candles, source } => source_type(candles, source),
2369        MediumAdData::Slice(s) => s,
2370    };
2371    let period = input.params.period.unwrap_or(5);
2372
2373    if period == 0 || period > data.len() {
2374        return Err(MediumAdError::InvalidPeriod {
2375            period,
2376            data_len: data.len(),
2377        });
2378    }
2379
2380    if dst.len() != data.len() {
2381        return Err(MediumAdError::OutputLengthMismatch {
2382            expected: data.len(),
2383            got: dst.len(),
2384        });
2385    }
2386
2387    let first = data.iter().position(|&x| !x.is_nan()).unwrap_or(0);
2388    let chosen = if kern == Kernel::Auto {
2389        Kernel::Scalar
2390    } else {
2391        kern
2392    };
2393
2394    match chosen {
2395        Kernel::Scalar => medium_ad_scalar(data, period, first, dst),
2396        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2397        Kernel::Avx2 => medium_ad_avx2(data, period, first, dst),
2398
2399        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2400        Kernel::Avx512 => medium_ad_scalar(data, period, first, dst),
2401        _ => unreachable!(),
2402    }
2403
2404    let warmup_end = first + period - 1;
2405    for v in &mut dst[..warmup_end] {
2406        *v = f64::NAN;
2407    }
2408
2409    Ok(())
2410}
2411
/// JS binding: MediumAD of `data` for a single `period`, returned as a new array.
///
/// Errors are reported as string `JsValue`s.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn medium_ad_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = MediumAdInput::from_slice(
        data,
        MediumAdParams {
            period: Some(period),
        },
    );

    let mut values = vec![0.0; data.len()];
    match medium_ad_into_slice(&mut values, &input, Kernel::Auto) {
        Ok(()) => Ok(values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
2427
/// Allocate uninitialised space for `len` f64 values in wasm linear memory.
///
/// Ownership passes to the caller, who releases it with `medium_ad_free(ptr, len)`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn medium_ad_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` keeps the allocation alive after this function returns.
    let mut storage = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    storage.as_mut_ptr()
}
2436
/// Release a buffer obtained from `medium_ad_alloc(len)`.
///
/// `len` must be the value that was passed to `medium_ad_alloc`. A null
/// pointer is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn medium_ad_free(ptr: *mut f64, len: usize) {
    if !ptr.is_null() {
        // SAFETY: `ptr` came from a `Vec<f64>` with capacity `len` (see
        // `medium_ad_alloc`). The caller may never have written the elements,
        // and `Vec::from_raw_parts` requires the first `length` elements to be
        // initialised, so the length is reconstructed as 0. `f64` has no drop
        // glue, so only the allocation itself is released.
        unsafe {
            drop(Vec::from_raw_parts(ptr, 0, len));
        }
    }
}
2446
/// Raw-pointer JS binding: MediumAD of `len` values at `in_ptr`, written to `out_ptr`.
///
/// `in_ptr == out_ptr` is supported: results are computed into scratch space
/// and then copied back. Errors are reported as string `JsValue`s.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn medium_ad_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }
    if period == 0 || period > len {
        return Err(JsValue::from_str("Invalid period"));
    }

    // SAFETY: the caller guarantees `in_ptr` addresses `len` readable f64 values.
    let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };
    let input = MediumAdInput::from_slice(
        data,
        MediumAdParams {
            period: Some(period),
        },
    );

    let outcome = if in_ptr == out_ptr {
        // In-place request: compute into scratch first so the kernel never reads
        // values it has already overwritten.
        let mut scratch = vec![0.0; len];
        medium_ad_into_slice(&mut scratch, &input, Kernel::Auto).map(|()| {
            // SAFETY: `out_ptr` addresses `len` writable f64 values, and `data`
            // is not read again after this point.
            let dst = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
            dst.copy_from_slice(&scratch);
        })
    } else {
        // SAFETY: the caller guarantees `out_ptr` addresses `len` writable f64 values.
        let dst = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        medium_ad_into_slice(dst, &input, Kernel::Auto)
    };

    outcome.map_err(|e| JsValue::from_str(&e.to_string()))
}
2486
/// Configuration object accepted by the JS `medium_ad_batch` entry point.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct MediumAdBatchConfig {
    /// `(start, end, step)` period sweep, used as `MediumAdBatchRange::period`.
    pub period_range: (usize, usize, usize),
}

/// Serialized result of the JS `medium_ad_batch` entry point.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct MediumAdBatchJsOutput {
    // Flat output values; presumably row-major `rows x cols` — TODO confirm.
    pub values: Vec<f64>,
    /// Parameter set for each output row.
    pub combos: Vec<MediumAdParams>,
    /// Number of parameter combinations (output rows).
    pub rows: usize,
    /// Length of each output row.
    pub cols: usize,
}
2501
/// JS binding: MediumAD over a period sweep described by a `MediumAdBatchConfig`.
///
/// Returns a serialized `MediumAdBatchJsOutput`. Config, computation and
/// serialization errors are reported as string `JsValue`s.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = medium_ad_batch)]
pub fn medium_ad_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let MediumAdBatchConfig { period_range } =
        serde_wasm_bindgen::from_value::<MediumAdBatchConfig>(config)
            .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = MediumAdBatchRange {
        period: period_range,
    };

    let batch = medium_ad_batch_inner(data, &sweep, Kernel::Auto, false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let payload = MediumAdBatchJsOutput {
        rows: batch.rows,
        cols: batch.cols,
        values: batch.values,
        combos: batch.combos,
    };

    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2525
/// Raw-pointer JS batch binding.
///
/// Reads `len` f64 values from `in_ptr` and expands
/// `(period_start, period_end, period_step)` into period combinations. It
/// writes `rows * len` results to `out_ptr` and returns `rows`. `out_ptr`
/// must have room for `rows * len` values.
///
/// If the input and output regions overlap (including `in_ptr == out_ptr`),
/// the input is copied first. Rows are then never computed from data that an
/// earlier row already overwrote.
///
/// Errors are reported as string `JsValue`s.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn medium_ad_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str(
            "null pointer passed to medium_ad_batch_into",
        ));
    }

    let sweep = MediumAdBatchRange {
        period: (period_start, period_end, period_step),
    };

    let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let cols = len;

    let total = match rows.checked_mul(cols) {
        Some(v) => v,
        None => {
            return Err(JsValue::from_str(
                "medium_ad_batch_into: rows*cols overflow in output size",
            ))
        }
    };

    // Byte ranges of both buffers. Saturating arithmetic makes an absurd size
    // err on the side of "overlapping", which only costs an extra copy.
    let elem = std::mem::size_of::<f64>();
    let in_start = in_ptr as usize;
    let in_end = in_start.saturating_add(len.saturating_mul(elem));
    let out_start = out_ptr as usize;
    let out_end = out_start.saturating_add(total.saturating_mul(elem));
    let overlaps = in_start < out_end && out_start < in_end;

    unsafe {
        // SAFETY: the caller guarantees `in_ptr` addresses `len` readable f64s
        // and `out_ptr` addresses `total` writable f64s. When the regions
        // overlap, the input is copied into an owned buffer before `out` is
        // created, so the shared and mutable slices never alias.
        let owned_copy;
        let data: &[f64] = if overlaps {
            owned_copy = std::slice::from_raw_parts(in_ptr, len).to_vec();
            &owned_copy
        } else {
            std::slice::from_raw_parts(in_ptr, len)
        };

        let out = std::slice::from_raw_parts_mut(out_ptr, total);

        medium_ad_batch_inner_into(data, &sweep, Kernel::Auto, false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(rows)
}
2570
#[cfg(all(test, not(all(target_arch = "wasm32", feature = "wasm"))))]
mod tests_into_parity {
    use super::*;

    /// Tolerant equality: both NaN, exactly equal, or within 1e-12.
    #[inline]
    fn same_or_both_nan(a: f64, b: f64) -> bool {
        let both_nan = a.is_nan() && b.is_nan();
        both_nan || a == b || (a - b).abs() <= 1e-12
    }

    /// `medium_ad_into` must reproduce the allocating `medium_ad` API value-for-value.
    #[test]
    fn test_medium_ad_into_matches_api() -> Result<(), Box<dyn std::error::Error>> {
        // Three leading NaNs followed by 253 smooth, slightly jittered samples.
        let data: Vec<f64> = std::iter::repeat(f64::NAN)
            .take(3)
            .chain((0..253usize).map(|i| {
                (i as f64 * 0.019).sin() * 3.0 + 42.0 + ((i % 11) as f64) * 0.07
            }))
            .collect();

        let input = MediumAdInput::from_slice(&data, MediumAdParams::default());
        let expected = medium_ad(&input)?.values;

        let mut actual = vec![0.0; data.len()];
        medium_ad_into(&input, &mut actual)?;

        assert_eq!(expected.len(), actual.len());

        for (idx, (&want, &got)) in expected.iter().zip(actual.iter()).enumerate() {
            assert!(
                same_or_both_nan(want, got),
                "medium_ad_into mismatch at idx {}: baseline={} into={}",
                idx,
                want,
                got
            );
        }

        Ok(())
    }
}