Skip to main content

vector_ta/indicators/
mean_ad.rs

1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{
4    alloc_with_nan_prefix, detect_best_batch_kernel, init_matrix_prefixes, make_uninit_matrix,
5};
6#[cfg(feature = "python")]
7use crate::utilities::kernel_validation::validate_kernel;
8use aligned_vec::{AVec, CACHELINE_ALIGN};
9#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
10use core::arch::x86_64::*;
11#[cfg(feature = "python")]
12use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
13#[cfg(feature = "python")]
14use pyo3::exceptions::PyValueError;
15#[cfg(feature = "python")]
16use pyo3::prelude::*;
17#[cfg(feature = "python")]
18use pyo3::types::PyDict;
19#[cfg(not(target_arch = "wasm32"))]
20use rayon::prelude::*;
21#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
22use serde::{Deserialize, Serialize};
23use std::convert::AsRef;
24use std::error::Error;
25use std::mem::MaybeUninit;
26use thiserror::Error;
27#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
28use wasm_bindgen::prelude::*;
29
30#[cfg(all(feature = "python", feature = "cuda"))]
31use crate::cuda::CudaMeanAd;
32#[cfg(all(feature = "python", feature = "cuda"))]
33use crate::indicators::moving_averages::alma::{make_device_array_py, DeviceArrayF32Py};
34
35impl<'a> AsRef<[f64]> for MeanAdInput<'a> {
36    #[inline(always)]
37    fn as_ref(&self) -> &[f64] {
38        match &self.data {
39            MeanAdData::Slice(slice) => slice,
40            MeanAdData::Candles { candles, source } => source_type(candles, source),
41        }
42    }
43}
44
45#[derive(Debug, Clone)]
46pub enum MeanAdData<'a> {
47    Candles {
48        candles: &'a Candles,
49        source: &'a str,
50    },
51    Slice(&'a [f64]),
52}
53
54#[derive(Debug, Clone)]
55pub struct MeanAdOutput {
56    pub values: Vec<f64>,
57}
58
59#[derive(Debug, Clone)]
60pub struct MeanAdParams {
61    pub period: Option<usize>,
62}
63
64impl Default for MeanAdParams {
65    fn default() -> Self {
66        Self { period: Some(5) }
67    }
68}
69
70#[derive(Debug, Clone)]
71pub struct MeanAdInput<'a> {
72    pub data: MeanAdData<'a>,
73    pub params: MeanAdParams,
74}
75
76impl<'a> MeanAdInput<'a> {
77    #[inline]
78    pub fn from_candles(c: &'a Candles, s: &'a str, p: MeanAdParams) -> Self {
79        Self {
80            data: MeanAdData::Candles {
81                candles: c,
82                source: s,
83            },
84            params: p,
85        }
86    }
87    #[inline]
88    pub fn from_slice(sl: &'a [f64], p: MeanAdParams) -> Self {
89        Self {
90            data: MeanAdData::Slice(sl),
91            params: p,
92        }
93    }
94    #[inline]
95    pub fn with_default_candles(c: &'a Candles) -> Self {
96        Self::from_candles(c, "close", MeanAdParams::default())
97    }
98    #[inline]
99    pub fn get_period(&self) -> usize {
100        self.params.period.unwrap_or(5)
101    }
102}
103
104#[derive(Copy, Clone, Debug)]
105pub struct MeanAdBuilder {
106    period: Option<usize>,
107    kernel: Kernel,
108}
109
110impl Default for MeanAdBuilder {
111    fn default() -> Self {
112        Self {
113            period: None,
114            kernel: Kernel::Auto,
115        }
116    }
117}
118
119impl MeanAdBuilder {
120    #[inline(always)]
121    pub fn new() -> Self {
122        Self::default()
123    }
124    #[inline(always)]
125    pub fn period(mut self, n: usize) -> Self {
126        self.period = Some(n);
127        self
128    }
129    #[inline(always)]
130    pub fn kernel(mut self, k: Kernel) -> Self {
131        self.kernel = k;
132        self
133    }
134    #[inline(always)]
135    pub fn apply(self, c: &Candles) -> Result<MeanAdOutput, MeanAdError> {
136        let p = MeanAdParams {
137            period: self.period,
138        };
139        let i = MeanAdInput::from_candles(c, "close", p);
140        mean_ad_with_kernel(&i, self.kernel)
141    }
142    #[inline(always)]
143    pub fn apply_slice(self, d: &[f64]) -> Result<MeanAdOutput, MeanAdError> {
144        let p = MeanAdParams {
145            period: self.period,
146        };
147        let i = MeanAdInput::from_slice(d, p);
148        mean_ad_with_kernel(&i, self.kernel)
149    }
150    #[inline(always)]
151    pub fn into_stream(self) -> Result<MeanAdStream, MeanAdError> {
152        let p = MeanAdParams {
153            period: self.period,
154        };
155        MeanAdStream::try_new(p)
156    }
157}
158
159#[derive(Debug, Error)]
160pub enum MeanAdError {
161    #[error("mean_ad: Empty data provided.")]
162    EmptyInputData,
163    #[error("mean_ad: Invalid period: period = {period}, data length = {data_len}")]
164    InvalidPeriod { period: usize, data_len: usize },
165    #[error("mean_ad: Not enough valid data: needed = {needed}, valid = {valid}")]
166    NotEnoughValidData { needed: usize, valid: usize },
167    #[error("mean_ad: All values are NaN.")]
168    AllValuesNaN,
169    #[error("mean_ad: Output length mismatch: expected {expected}, got {got}")]
170    OutputLengthMismatch { expected: usize, got: usize },
171    #[error("mean_ad: Invalid range: start={start}, end={end}, step={step}")]
172    InvalidRange {
173        start: String,
174        end: String,
175        step: String,
176    },
177    #[error("mean_ad: Invalid kernel for batch: {0:?}")]
178    InvalidKernelForBatch(crate::utilities::enums::Kernel),
179}
180
181#[inline]
182pub fn mean_ad(input: &MeanAdInput) -> Result<MeanAdOutput, MeanAdError> {
183    mean_ad_with_kernel(input, Kernel::Auto)
184}
185
186pub fn mean_ad_with_kernel(
187    input: &MeanAdInput,
188    kernel: Kernel,
189) -> Result<MeanAdOutput, MeanAdError> {
190    let data: &[f64] = match &input.data {
191        MeanAdData::Candles { candles, source } => source_type(candles, source),
192        MeanAdData::Slice(sl) => sl,
193    };
194
195    if data.is_empty() {
196        return Err(MeanAdError::EmptyInputData);
197    }
198
199    let period = input.get_period();
200    let first = data
201        .iter()
202        .position(|x| !x.is_nan())
203        .ok_or(MeanAdError::AllValuesNaN)?;
204
205    if period == 0 || period > data.len() {
206        return Err(MeanAdError::InvalidPeriod {
207            period,
208            data_len: data.len(),
209        });
210    }
211    if (data.len() - first) < period {
212        return Err(MeanAdError::NotEnoughValidData {
213            needed: period,
214            valid: data.len() - first,
215        });
216    }
217
218    let chosen = match kernel {
219        Kernel::Auto => Kernel::Scalar,
220        other => other,
221    };
222
223    match chosen {
224        Kernel::Scalar | Kernel::ScalarBatch => mean_ad_scalar(data, period, first),
225        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
226        Kernel::Avx2 | Kernel::Avx2Batch => mean_ad_avx2(data, period, first),
227        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
228        Kernel::Avx512 | Kernel::Avx512Batch => mean_ad_avx512(data, period, first),
229        _ => unreachable!(),
230    }
231}
232
233#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
234#[inline]
235pub fn mean_ad_into(input: &MeanAdInput, out: &mut [f64]) -> Result<(), MeanAdError> {
236    mean_ad_into_slice(out, input, Kernel::Auto)
237}
238
239#[inline]
240pub fn mean_ad_scalar(
241    data: &[f64],
242    period: usize,
243    first: usize,
244) -> Result<MeanAdOutput, MeanAdError> {
245    if period == 0 {
246        return Err(MeanAdError::InvalidPeriod {
247            period: 0,
248            data_len: data.len(),
249        });
250    }
251    if first > data.len() {
252        return Err(MeanAdError::NotEnoughValidData {
253            needed: first,
254            valid: data.len(),
255        });
256    }
257
258    let n = data.len();
259
260    if first + period > n {
261        let out = alloc_with_nan_prefix(n, n);
262        return Ok(MeanAdOutput { values: out });
263    }
264
265    let inv_p = 1.0f64 / (period as f64);
266
267    let warmup_end = first + (period << 1) - 2;
268    let mut out = alloc_with_nan_prefix(n, warmup_end.min(n));
269
270    let mut sum = 0.0f64;
271    for i in first..(first + period) {
272        sum += data[i];
273    }
274    let mut sma = sum * inv_p;
275
276    let mut residual_buffer = vec![0.0f64; period];
277    let mut buffer_index = 0usize;
278    let mut residual_sum = 0.0f64;
279
280    let start_t = first + period - 1;
281    let fill_t_end = (start_t + period - 1).min(n.saturating_sub(1));
282    for t in start_t..=fill_t_end {
283        let residual = (data[t] - sma).abs();
284        residual_buffer[buffer_index] = residual;
285        residual_sum += residual;
286        buffer_index += 1;
287        if buffer_index == period {
288            buffer_index = 0;
289        }
290        if t + 1 < n {
291            sum += data[t + 1] - data[t + 1 - period];
292            sma = sum * inv_p;
293        }
294    }
295
296    if warmup_end < n {
297        out[warmup_end] = residual_sum * inv_p;
298    }
299
300    let mut t = start_t + period;
301    while t < n {
302        let residual = (data[t] - sma).abs();
303
304        let old = residual_buffer[buffer_index];
305        residual_sum += residual - old;
306        residual_buffer[buffer_index] = residual;
307        buffer_index += 1;
308        if buffer_index == period {
309            buffer_index = 0;
310        }
311
312        out[t] = residual_sum * inv_p;
313
314        if t + 1 < n {
315            sum += data[t + 1] - data[t + 1 - period];
316            sma = sum * inv_p;
317        }
318        t += 1;
319    }
320
321    Ok(MeanAdOutput { values: out })
322}
323
324#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
325#[inline]
326pub fn mean_ad_avx2(
327    data: &[f64],
328    period: usize,
329    first: usize,
330) -> Result<MeanAdOutput, MeanAdError> {
331    mean_ad_scalar(data, period, first)
332}
333
334#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
335#[inline]
336pub fn mean_ad_avx512(
337    data: &[f64],
338    period: usize,
339    first: usize,
340) -> Result<MeanAdOutput, MeanAdError> {
341    mean_ad_scalar(data, period, first)
342}
343
344#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
345#[inline]
346pub fn mean_ad_avx512_short(
347    data: &[f64],
348    period: usize,
349    first: usize,
350) -> Result<MeanAdOutput, MeanAdError> {
351    mean_ad_scalar(data, period, first)
352}
353
354#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
355#[inline]
356pub fn mean_ad_avx512_long(
357    data: &[f64],
358    period: usize,
359    first: usize,
360) -> Result<MeanAdOutput, MeanAdError> {
361    mean_ad_scalar(data, period, first)
362}
363
364#[inline(always)]
365pub fn mean_ad_batch_with_kernel(
366    data: &[f64],
367    sweep: &MeanAdBatchRange,
368    k: Kernel,
369) -> Result<MeanAdBatchOutput, MeanAdError> {
370    let kernel = match k {
371        Kernel::Auto => Kernel::ScalarBatch,
372        other if other.is_batch() => other,
373        other => return Err(MeanAdError::InvalidKernelForBatch(other)),
374    };
375    let simd = match kernel {
376        Kernel::Avx512Batch => Kernel::Avx512,
377        Kernel::Avx2Batch => Kernel::Avx2,
378        Kernel::ScalarBatch => Kernel::Scalar,
379        _ => unreachable!(),
380    };
381    mean_ad_batch_par_slice(data, sweep, simd)
382}
383
384#[derive(Clone, Debug)]
385pub struct MeanAdBatchRange {
386    pub period: (usize, usize, usize),
387}
388
389impl Default for MeanAdBatchRange {
390    fn default() -> Self {
391        Self {
392            period: (5, 254, 1),
393        }
394    }
395}
396
397#[derive(Clone, Debug, Default)]
398pub struct MeanAdBatchBuilder {
399    range: MeanAdBatchRange,
400    kernel: Kernel,
401}
402
403impl MeanAdBatchBuilder {
404    pub fn new() -> Self {
405        Self::default()
406    }
407    pub fn kernel(mut self, k: Kernel) -> Self {
408        self.kernel = k;
409        self
410    }
411    #[inline]
412    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
413        self.range.period = (start, end, step);
414        self
415    }
416    #[inline]
417    pub fn period_static(mut self, p: usize) -> Self {
418        self.range.period = (p, p, 0);
419        self
420    }
421    pub fn apply_slice(self, data: &[f64]) -> Result<MeanAdBatchOutput, MeanAdError> {
422        mean_ad_batch_with_kernel(data, &self.range, self.kernel)
423    }
424    pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<MeanAdBatchOutput, MeanAdError> {
425        MeanAdBatchBuilder::new().kernel(k).apply_slice(data)
426    }
427    pub fn apply_candles(self, c: &Candles, src: &str) -> Result<MeanAdBatchOutput, MeanAdError> {
428        let slice = source_type(c, src);
429        self.apply_slice(slice)
430    }
431    pub fn with_default_candles(c: &Candles) -> Result<MeanAdBatchOutput, MeanAdError> {
432        MeanAdBatchBuilder::new()
433            .kernel(Kernel::Auto)
434            .apply_candles(c, "close")
435    }
436}
437
438#[derive(Clone, Debug)]
439pub struct MeanAdBatchOutput {
440    pub values: Vec<f64>,
441    pub combos: Vec<MeanAdParams>,
442    pub rows: usize,
443    pub cols: usize,
444}
445impl MeanAdBatchOutput {
446    pub fn row_for_params(&self, p: &MeanAdParams) -> Option<usize> {
447        self.combos
448            .iter()
449            .position(|c| c.period.unwrap_or(5) == p.period.unwrap_or(5))
450    }
451    pub fn values_for(&self, p: &MeanAdParams) -> Option<&[f64]> {
452        self.row_for_params(p).and_then(|row| {
453            let start = row.checked_mul(self.cols)?;
454            self.values.get(start..start + self.cols)
455        })
456    }
457}
458
459#[inline(always)]
460fn expand_grid(r: &MeanAdBatchRange) -> Result<Vec<MeanAdParams>, MeanAdError> {
461    fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, MeanAdError> {
462        if step == 0 || start == end {
463            return Ok(vec![start]);
464        }
465        if start < end {
466            let st = step.max(1);
467            let v: Vec<usize> = (start..=end).step_by(st).collect();
468            if v.is_empty() {
469                return Err(MeanAdError::InvalidRange {
470                    start: start.to_string(),
471                    end: end.to_string(),
472                    step: step.to_string(),
473                });
474            }
475            return Ok(v);
476        }
477
478        let mut v = Vec::new();
479        let mut x = start as isize;
480        let end_i = end as isize;
481        let st = (step as isize).max(1);
482        while x >= end_i {
483            v.push(x as usize);
484            x -= st;
485        }
486        if v.is_empty() {
487            return Err(MeanAdError::InvalidRange {
488                start: start.to_string(),
489                end: end.to_string(),
490                step: step.to_string(),
491            });
492        }
493        Ok(v)
494    }
495
496    let periods = axis_usize(r.period)?;
497    if periods.is_empty() {
498        return Err(MeanAdError::InvalidRange {
499            start: r.period.0.to_string(),
500            end: r.period.1.to_string(),
501            step: r.period.2.to_string(),
502        });
503    }
504
505    let mut out = Vec::with_capacity(periods.len());
506    for p in periods {
507        out.push(MeanAdParams { period: Some(p) });
508    }
509    Ok(out)
510}
511
512#[inline(always)]
513pub fn mean_ad_batch_slice(
514    data: &[f64],
515    sweep: &MeanAdBatchRange,
516    kern: Kernel,
517) -> Result<MeanAdBatchOutput, MeanAdError> {
518    mean_ad_batch_inner(data, sweep, kern, false)
519}
520
521#[inline(always)]
522pub fn mean_ad_batch_par_slice(
523    data: &[f64],
524    sweep: &MeanAdBatchRange,
525    kern: Kernel,
526) -> Result<MeanAdBatchOutput, MeanAdError> {
527    mean_ad_batch_inner(data, sweep, kern, true)
528}
529
530#[inline(always)]
531fn mean_ad_batch_inner_into(
532    data: &[f64],
533    sweep: &MeanAdBatchRange,
534    kern: Kernel,
535    parallel: bool,
536    out: &mut [f64],
537) -> Result<Vec<MeanAdParams>, MeanAdError> {
538    if data.is_empty() {
539        return Err(MeanAdError::EmptyInputData);
540    }
541    let combos = expand_grid(sweep)?;
542    if combos.iter().any(|c| c.period.unwrap_or(0) == 0) {
543        return Err(MeanAdError::InvalidPeriod {
544            period: 0,
545            data_len: data.len(),
546        });
547    }
548    let expected =
549        combos
550            .len()
551            .checked_mul(data.len())
552            .ok_or_else(|| MeanAdError::OutputLengthMismatch {
553                expected: usize::MAX,
554                got: out.len(),
555            })?;
556    if out.len() != expected {
557        return Err(MeanAdError::OutputLengthMismatch {
558            expected,
559            got: out.len(),
560        });
561    }
562    let first = data
563        .iter()
564        .position(|x| !x.is_nan())
565        .ok_or(MeanAdError::AllValuesNaN)?;
566    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
567    if max_p > data.len() {
568        return Err(MeanAdError::InvalidPeriod {
569            period: max_p,
570            data_len: data.len(),
571        });
572    }
573    if data.len() - first < max_p {
574        return Err(MeanAdError::NotEnoughValidData {
575            needed: max_p,
576            valid: data.len() - first,
577        });
578    }
579    let rows = combos.len();
580    let cols = data.len();
581
582    let chosen = match kern {
583        Kernel::Auto => Kernel::Scalar,
584        other => other,
585    };
586
587    let do_row = |row: usize, out_row: &mut [f64]| {
588        let period = combos[row].period.unwrap();
589        let warmup_end = first + 2 * period - 2;
590        for i in 0..warmup_end.min(out_row.len()) {
591            out_row[i] = f64::NAN;
592        }
593        match chosen {
594            Kernel::Scalar => mean_ad_row_scalar(data, first, period, out_row),
595            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
596            Kernel::Avx2 => mean_ad_row_avx2(data, first, period, out_row),
597            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
598            Kernel::Avx512 => mean_ad_row_avx512(data, first, period, out_row),
599            _ => mean_ad_row_scalar(data, first, period, out_row),
600        }
601    };
602
603    if parallel {
604        #[cfg(not(target_arch = "wasm32"))]
605        {
606            out.par_chunks_mut(cols)
607                .enumerate()
608                .for_each(|(row, slice)| do_row(row, slice));
609        }
610
611        #[cfg(target_arch = "wasm32")]
612        {
613            for (row, slice) in out.chunks_mut(cols).enumerate() {
614                do_row(row, slice);
615            }
616        }
617    } else {
618        for (row, slice) in out.chunks_mut(cols).enumerate() {
619            do_row(row, slice);
620        }
621    }
622
623    Ok(combos)
624}
625
626#[inline(always)]
627fn mean_ad_batch_inner(
628    data: &[f64],
629    sweep: &MeanAdBatchRange,
630    kern: Kernel,
631    parallel: bool,
632) -> Result<MeanAdBatchOutput, MeanAdError> {
633    if data.is_empty() {
634        return Err(MeanAdError::EmptyInputData);
635    }
636    let combos = expand_grid(sweep)?;
637    if combos.iter().any(|c| c.period.unwrap_or(0) == 0) {
638        return Err(MeanAdError::InvalidPeriod {
639            period: 0,
640            data_len: data.len(),
641        });
642    }
643    let first = data
644        .iter()
645        .position(|x| !x.is_nan())
646        .ok_or(MeanAdError::AllValuesNaN)?;
647    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
648    if max_p > data.len() {
649        return Err(MeanAdError::InvalidPeriod {
650            period: max_p,
651            data_len: data.len(),
652        });
653    }
654    if data.len() - first < max_p {
655        return Err(MeanAdError::NotEnoughValidData {
656            needed: max_p,
657            valid: data.len() - first,
658        });
659    }
660    let rows = combos.len();
661    let cols = data.len();
662
663    let warmup_periods: Vec<usize> = combos
664        .iter()
665        .map(|c| first + 2 * c.period.unwrap() - 2)
666        .collect();
667
668    let mut buf_mu = make_uninit_matrix(rows, cols);
669    init_matrix_prefixes(&mut buf_mu, cols, &warmup_periods);
670
671    let mut buf_guard = core::mem::ManuallyDrop::new(buf_mu);
672    let values_ptr = buf_guard.as_mut_ptr() as *mut f64;
673    let values_slice: &mut [f64] =
674        unsafe { core::slice::from_raw_parts_mut(values_ptr, buf_guard.len()) };
675
676    let chosen = match kern {
677        Kernel::Auto => Kernel::Scalar,
678        other => other,
679    };
680
681    let do_row = |row: usize, out_row: &mut [f64]| {
682        let period = combos[row].period.unwrap();
683
684        let warmup_end = first + 2 * period - 2;
685        for i in 0..warmup_end.min(out_row.len()) {
686            out_row[i] = f64::NAN;
687        }
688        match chosen {
689            Kernel::Scalar => mean_ad_row_scalar(data, first, period, out_row),
690            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
691            Kernel::Avx2 => mean_ad_row_avx2(data, first, period, out_row),
692            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
693            Kernel::Avx512 => mean_ad_row_avx512(data, first, period, out_row),
694            _ => mean_ad_row_scalar(data, first, period, out_row),
695        }
696    };
697    if parallel {
698        #[cfg(not(target_arch = "wasm32"))]
699        {
700            values_slice
701                .par_chunks_mut(cols)
702                .enumerate()
703                .for_each(|(row, slice)| do_row(row, slice));
704        }
705
706        #[cfg(target_arch = "wasm32")]
707        {
708            for (row, slice) in values_slice.chunks_mut(cols).enumerate() {
709                do_row(row, slice);
710            }
711        }
712    } else {
713        for (row, slice) in values_slice.chunks_mut(cols).enumerate() {
714            do_row(row, slice);
715        }
716    }
717
718    let values = unsafe {
719        Vec::from_raw_parts(
720            buf_guard.as_mut_ptr() as *mut f64,
721            buf_guard.len(),
722            buf_guard.capacity(),
723        )
724    };
725
726    Ok(MeanAdBatchOutput {
727        values,
728        combos,
729        rows,
730        cols,
731    })
732}
733
734#[inline(always)]
735pub fn mean_ad_row_scalar(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
736    if first + period > data.len() {
737        return;
738    }
739
740    let n = data.len();
741    let inv_p = 1.0f64 / (period as f64);
742
743    let mut sum = 0.0f64;
744    for i in first..(first + period) {
745        sum += data[i];
746    }
747    let mut sma = sum * inv_p;
748
749    let mut residual_buffer = vec![0.0f64; period];
750    let mut buffer_index = 0usize;
751    let mut residual_sum = 0.0f64;
752
753    let start_t = first + period - 1;
754    let fill_t_end = (start_t + period - 1).min(n.saturating_sub(1));
755    for t in start_t..=fill_t_end {
756        let residual = (data[t] - sma).abs();
757        residual_buffer[buffer_index] = residual;
758        residual_sum += residual;
759        buffer_index += 1;
760        if buffer_index == period {
761            buffer_index = 0;
762        }
763        if t + 1 < n {
764            sum += data[t + 1] - data[t + 1 - period];
765            sma = sum * inv_p;
766        }
767    }
768
769    let first_output = first + (period << 1) - 2;
770    if first_output < n {
771        out[first_output] = residual_sum * inv_p;
772    }
773
774    let mut t = start_t + period;
775    while t < n {
776        let residual = (data[t] - sma).abs();
777
778        let old = residual_buffer[buffer_index];
779        residual_sum += residual - old;
780        residual_buffer[buffer_index] = residual;
781        buffer_index += 1;
782        if buffer_index == period {
783            buffer_index = 0;
784        }
785
786        out[t] = residual_sum * inv_p;
787
788        if t + 1 < n {
789            sum += data[t + 1] - data[t + 1 - period];
790            sma = sum * inv_p;
791        }
792        t += 1;
793    }
794}
795
796#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
797#[inline(always)]
798pub fn mean_ad_row_avx2(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
799    mean_ad_row_scalar(data, first, period, out);
800}
801
802#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
803#[inline(always)]
804pub fn mean_ad_row_avx512(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
805    if period <= 32 {
806        mean_ad_row_avx512_short(data, first, period, out);
807    } else {
808        mean_ad_row_avx512_long(data, first, period, out);
809    }
810}
811
812#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
813#[inline(always)]
814pub fn mean_ad_row_avx512_short(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
815    mean_ad_row_scalar(data, first, period, out);
816}
817
818#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
819#[inline(always)]
820pub fn mean_ad_row_avx512_long(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
821    mean_ad_row_scalar(data, first, period, out);
822}
823
824#[derive(Debug, Clone)]
825pub struct MeanAdStream {
826    period: usize,
827    buffer: Vec<f64>,
828    head: usize,
829    filled: bool,
830    mean_buffer: Vec<f64>,
831    mean_head: usize,
832    mean_filled: bool,
833    mean: f64,
834    mad: f64,
835}
836
837impl MeanAdStream {
838    pub fn try_new(params: MeanAdParams) -> Result<Self, MeanAdError> {
839        let period = params.period.unwrap_or(5);
840        if period == 0 {
841            return Err(MeanAdError::InvalidPeriod {
842                period,
843                data_len: 0,
844            });
845        }
846        Ok(Self {
847            period,
848            buffer: vec![f64::NAN; period],
849            head: 0,
850            filled: false,
851            mean_buffer: vec![f64::NAN; period],
852            mean_head: 0,
853            mean_filled: false,
854            mean: 0.0,
855            mad: 0.0,
856        })
857    }
858    #[inline(always)]
859    pub fn update(&mut self, value: f64) -> Option<f64> {
860        let p = self.period;
861        let inv_p = 1.0f64 / (p as f64);
862
863        let price_idx = self.head;
864        let old_x = self.buffer[price_idx];
865
866        let resid_idx = self.mean_head;
867        let old_r = self.mean_buffer[resid_idx];
868
869        self.buffer[price_idx] = value;
870        let next_price_idx = price_idx + 1;
871        let wrapped_prices = next_price_idx == p;
872        self.head = if wrapped_prices { 0 } else { next_price_idx };
873
874        let just_filled_prices = !self.filled && wrapped_prices;
875        if just_filled_prices {
876            self.filled = true;
877        }
878
879        if !self.filled {
880            self.mean += value;
881            return None;
882        }
883
884        let sum_t = if just_filled_prices {
885            self.mean + value
886        } else {
887            self.mean + value - old_x
888        };
889        let mean_t = sum_t * inv_p;
890
891        self.mean = sum_t;
892
893        let resid_t = (value - mean_t).abs();
894
895        self.mean_buffer[resid_idx] = resid_t;
896        let next_resid_idx = resid_idx + 1;
897        let wrapped_resids = next_resid_idx == p;
898        self.mean_head = if wrapped_resids { 0 } else { next_resid_idx };
899
900        if !self.mean_filled {
901            self.mad += resid_t;
902            if wrapped_resids {
903                self.mean_filled = true;
904                return Some(self.mad * inv_p);
905            }
906            return None;
907        }
908
909        self.mad = self.mad + resid_t - old_r;
910        Some(self.mad * inv_p)
911    }
912}
913
914#[cfg(test)]
915mod tests {
916    use super::*;
917    use crate::skip_if_unsupported;
918    use crate::utilities::data_loader::read_candles_from_csv;
919    use paste::paste;
920
921    fn check_mean_ad_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
922        skip_if_unsupported!(kernel, test_name);
923        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
924        let candles = read_candles_from_csv(file_path)?;
925        let default_params = MeanAdParams { period: None };
926        let input = MeanAdInput::from_candles(&candles, "close", default_params);
927        let output = mean_ad_with_kernel(&input, kernel)?;
928        assert_eq!(output.values.len(), candles.close.len());
929        Ok(())
930    }
931
932    fn check_mean_ad_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
933        skip_if_unsupported!(kernel, test_name);
934        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
935        let candles = read_candles_from_csv(file_path)?;
936        let input = MeanAdInput::from_candles(&candles, "hl2", MeanAdParams { period: Some(5) });
937        let result = mean_ad_with_kernel(&input, kernel)?;
938        assert_eq!(result.values.len(), candles.close.len());
939        let expected_last_five = [
940            199.71999999999971,
941            104.14000000000087,
942            133.4,
943            100.54000000000087,
944            117.98000000000029,
945        ];
946        let start = result.values.len().saturating_sub(5);
947        for (i, &val) in result.values[start..].iter().enumerate() {
948            let diff = (val - expected_last_five[i]).abs();
949            assert!(
950                diff < 1e-1,
951                "[{}] MeanAd {:?} mismatch at idx {}: got {}, expected {}",
952                test_name,
953                kernel,
954                i,
955                val,
956                expected_last_five[i]
957            );
958        }
959        Ok(())
960    }
961
962    fn check_mean_ad_default_candles(
963        test_name: &str,
964        kernel: Kernel,
965    ) -> Result<(), Box<dyn Error>> {
966        skip_if_unsupported!(kernel, test_name);
967        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
968        let candles = read_candles_from_csv(file_path)?;
969        let input = MeanAdInput::with_default_candles(&candles);
970        match input.data {
971            MeanAdData::Candles { source, .. } => assert_eq!(source, "close"),
972            _ => panic!("Expected MeanAdData::Candles"),
973        }
974        let output = mean_ad_with_kernel(&input, kernel)?;
975        assert_eq!(output.values.len(), candles.close.len());
976        Ok(())
977    }
978
979    fn check_mean_ad_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
980        skip_if_unsupported!(kernel, test_name);
981        let input_data = [10.0, 20.0, 30.0];
982        let params = MeanAdParams { period: Some(0) };
983        let input = MeanAdInput::from_slice(&input_data, params);
984        let res = mean_ad_with_kernel(&input, kernel);
985        assert!(
986            res.is_err(),
987            "[{}] MeanAd should fail with zero period",
988            test_name
989        );
990        Ok(())
991    }
992
993    fn check_mean_ad_period_exceeds_length(
994        test_name: &str,
995        kernel: Kernel,
996    ) -> Result<(), Box<dyn Error>> {
997        skip_if_unsupported!(kernel, test_name);
998        let data_small = [10.0, 20.0, 30.0];
999        let params = MeanAdParams { period: Some(10) };
1000        let input = MeanAdInput::from_slice(&data_small, params);
1001        let res = mean_ad_with_kernel(&input, kernel);
1002        assert!(
1003            res.is_err(),
1004            "[{}] MeanAd should fail with period exceeding length",
1005            test_name
1006        );
1007        Ok(())
1008    }
1009
1010    fn check_mean_ad_very_small_dataset(
1011        test_name: &str,
1012        kernel: Kernel,
1013    ) -> Result<(), Box<dyn Error>> {
1014        skip_if_unsupported!(kernel, test_name);
1015        let single_point = [42.0];
1016        let params = MeanAdParams { period: Some(5) };
1017        let input = MeanAdInput::from_slice(&single_point, params);
1018        let res = mean_ad_with_kernel(&input, kernel);
1019        assert!(
1020            res.is_err(),
1021            "[{}] MeanAd should fail with insufficient data",
1022            test_name
1023        );
1024        Ok(())
1025    }
1026
1027    fn check_mean_ad_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1028        skip_if_unsupported!(kernel, test_name);
1029        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1030        let candles = read_candles_from_csv(file_path)?;
1031        let first_params = MeanAdParams { period: Some(5) };
1032        let first_input = MeanAdInput::from_candles(&candles, "close", first_params);
1033        let first_result = mean_ad_with_kernel(&first_input, kernel)?;
1034        let params2 = MeanAdParams { period: Some(3) };
1035        let second_input = MeanAdInput::from_slice(&first_result.values, params2);
1036        let second_result = mean_ad_with_kernel(&second_input, kernel)?;
1037        assert_eq!(second_result.values.len(), first_result.values.len());
1038        Ok(())
1039    }
1040
1041    fn check_mean_ad_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1042        skip_if_unsupported!(kernel, test_name);
1043        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1044        let candles = read_candles_from_csv(file_path)?;
1045        let input = MeanAdInput::from_candles(&candles, "close", MeanAdParams { period: Some(5) });
1046        let res = mean_ad_with_kernel(&input, kernel)?;
1047        assert_eq!(res.values.len(), candles.close.len());
1048        if res.values.len() > 240 {
1049            for (i, &val) in res.values[240..].iter().enumerate() {
1050                assert!(
1051                    !val.is_nan(),
1052                    "[{}] Found unexpected NaN at out-index {}",
1053                    test_name,
1054                    240 + i
1055                );
1056            }
1057        }
1058        Ok(())
1059    }
1060
1061    fn check_mean_ad_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1062        skip_if_unsupported!(kernel, test_name);
1063        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1064        let candles = read_candles_from_csv(file_path)?;
1065        let period = 5;
1066        let input = MeanAdInput::from_candles(
1067            &candles,
1068            "close",
1069            MeanAdParams {
1070                period: Some(period),
1071            },
1072        );
1073        let batch_output = mean_ad_with_kernel(&input, kernel)?.values;
1074        let mut stream = MeanAdStream::try_new(MeanAdParams {
1075            period: Some(period),
1076        })?;
1077
1078        let mut stream_uninit: Vec<MaybeUninit<f64>> = Vec::with_capacity(candles.close.len());
1079        unsafe {
1080            stream_uninit.set_len(candles.close.len());
1081        }
1082
1083        for (i, &price) in candles.close.iter().enumerate() {
1084            let val = match stream.update(price) {
1085                Some(mean_ad_val) => mean_ad_val,
1086                None => f64::NAN,
1087            };
1088            stream_uninit[i] = MaybeUninit::new(val);
1089        }
1090
1091        let stream_values = unsafe {
1092            let ptr = stream_uninit.as_mut_ptr() as *mut f64;
1093            let len = stream_uninit.len();
1094            let cap = stream_uninit.capacity();
1095            std::mem::forget(stream_uninit);
1096            Vec::from_raw_parts(ptr, len, cap)
1097        };
1098        assert_eq!(batch_output.len(), stream_values.len());
1099        for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
1100            if b.is_nan() && s.is_nan() {
1101                continue;
1102            }
1103            let diff = (b - s).abs();
1104            assert!(
1105                diff < 1e-9,
1106                "[{}] MeanAd streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
1107                test_name,
1108                i,
1109                b,
1110                s,
1111                diff
1112            );
1113        }
1114        Ok(())
1115    }
1116
1117    #[cfg(debug_assertions)]
1118    fn check_mean_ad_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1119        skip_if_unsupported!(kernel, test_name);
1120
1121        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1122        let candles = read_candles_from_csv(file_path)?;
1123
1124        let test_params = vec![
1125            MeanAdParams::default(),
1126            MeanAdParams { period: Some(2) },
1127            MeanAdParams { period: Some(3) },
1128            MeanAdParams { period: Some(5) },
1129            MeanAdParams { period: Some(7) },
1130            MeanAdParams { period: Some(10) },
1131            MeanAdParams { period: Some(14) },
1132            MeanAdParams { period: Some(20) },
1133            MeanAdParams { period: Some(30) },
1134            MeanAdParams { period: Some(50) },
1135            MeanAdParams { period: Some(100) },
1136            MeanAdParams { period: Some(200) },
1137        ];
1138
1139        for (param_idx, params) in test_params.iter().enumerate() {
1140            let input = MeanAdInput::from_candles(&candles, "close", params.clone());
1141            let output = mean_ad_with_kernel(&input, kernel)?;
1142
1143            for (i, &val) in output.values.iter().enumerate() {
1144                if val.is_nan() {
1145                    continue;
1146                }
1147
1148                let bits = val.to_bits();
1149
1150                if bits == 0x11111111_11111111 {
1151                    panic!(
1152                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1153						 with params: period={} (param set {})",
1154                        test_name,
1155                        val,
1156                        bits,
1157                        i,
1158                        params.period.unwrap_or(5),
1159                        param_idx
1160                    );
1161                }
1162
1163                if bits == 0x22222222_22222222 {
1164                    panic!(
1165                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1166						 with params: period={} (param set {})",
1167                        test_name,
1168                        val,
1169                        bits,
1170                        i,
1171                        params.period.unwrap_or(5),
1172                        param_idx
1173                    );
1174                }
1175
1176                if bits == 0x33333333_33333333 {
1177                    panic!(
1178                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1179						 with params: period={} (param set {})",
1180                        test_name,
1181                        val,
1182                        bits,
1183                        i,
1184                        params.period.unwrap_or(5),
1185                        param_idx
1186                    );
1187                }
1188            }
1189        }
1190
1191        Ok(())
1192    }
1193
1194    #[cfg(not(debug_assertions))]
1195    fn check_mean_ad_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1196        Ok(())
1197    }
1198
1199    macro_rules! generate_all_mean_ad_tests {
1200        ($($test_fn:ident),*) => {
1201            paste! {
1202                $(
1203                    #[test]
1204                    fn [<$test_fn _scalar_f64>]() {
1205                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1206                    }
1207                )*
1208                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1209                $(
1210                    #[test]
1211                    fn [<$test_fn _avx2_f64>]() {
1212                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1213                    }
1214                    #[test]
1215                    fn [<$test_fn _avx512_f64>]() {
1216                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1217                    }
1218                )*
1219            }
1220        }
1221    }
1222    #[cfg(feature = "proptest")]
1223    #[allow(clippy::float_cmp)]
1224    fn check_mean_ad_property(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1225        use proptest::prelude::*;
1226        skip_if_unsupported!(kernel, test_name);
1227
1228        let strat = (2usize..=64)
1229            .prop_flat_map(|period| {
1230                (
1231                    prop::collection::vec(
1232                        (10.0f64..1000.0f64).prop_filter("finite", |x| x.is_finite()),
1233                        period..400,
1234                    ),
1235                    Just(period),
1236                    prop::bool::weighted(0.1),
1237                )
1238            })
1239            .prop_map(|(mut data, period, make_constant)| {
1240                if make_constant && data.len() > 0 {
1241                    let constant_val = data[0];
1242                    data.iter_mut().for_each(|v| *v = constant_val);
1243                }
1244                (data, period)
1245            });
1246
1247        proptest::test_runner::TestRunner::default().run(&strat, |(data, period)| {
1248            let params = MeanAdParams {
1249                period: Some(period),
1250            };
1251            let input = MeanAdInput::from_slice(&data, params);
1252
1253            let MeanAdOutput { values: out } = mean_ad_with_kernel(&input, kernel)?;
1254            let MeanAdOutput { values: ref_out } = mean_ad_with_kernel(&input, Kernel::Scalar)?;
1255
1256            prop_assert_eq!(
1257                out.len(),
1258                data.len(),
1259                "[{}] Output length mismatch",
1260                test_name
1261            );
1262
1263            let first_valid = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
1264            let warmup_period = first_valid + 2 * period - 2;
1265
1266            for i in 0..warmup_period.min(out.len()) {
1267                prop_assert!(
1268                    out[i].is_nan(),
1269                    "[{}] Expected NaN during warmup at index {}, got {}",
1270                    test_name,
1271                    i,
1272                    out[i]
1273                );
1274            }
1275
1276            for i in warmup_period..out.len() {
1277                if !out[i].is_nan() {
1278                    prop_assert!(
1279                        out[i] >= -1e-10,
1280                        "[{}] MAD should be non-negative at index {}: got {}",
1281                        test_name,
1282                        i,
1283                        out[i]
1284                    );
1285                }
1286            }
1287
1288            for i in 0..out.len() {
1289                let y = out[i];
1290                let r = ref_out[i];
1291
1292                if y.is_nan() || r.is_nan() {
1293                    prop_assert!(
1294                        y.is_nan() && r.is_nan(),
1295                        "[{}] NaN mismatch at index {}: kernel={}, scalar={}",
1296                        test_name,
1297                        i,
1298                        y,
1299                        r
1300                    );
1301                    continue;
1302                }
1303
1304                let y_bits = y.to_bits();
1305                let r_bits = r.to_bits();
1306                let ulp_diff = y_bits.abs_diff(r_bits);
1307
1308                prop_assert!(
1309                    (y - r).abs() <= 1e-9 || ulp_diff <= 5,
1310                    "[{}] Kernel mismatch at index {}: kernel={}, scalar={}, diff={}, ULP={}",
1311                    test_name,
1312                    i,
1313                    y,
1314                    r,
1315                    (y - r).abs(),
1316                    ulp_diff
1317                );
1318            }
1319
1320            let is_constant = data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10);
1321            if is_constant && out.len() > warmup_period {
1322                for i in warmup_period..out.len() {
1323                    if !out[i].is_nan() {
1324                        prop_assert!(
1325                            out[i].abs() <= 1e-9,
1326                            "[{}] MAD should be ~0 for constant data at index {}: got {}",
1327                            test_name,
1328                            i,
1329                            out[i]
1330                        );
1331                    }
1332                }
1333            }
1334
1335            let is_linear_monotonic = if data.len() >= 3 {
1336                let diffs: Vec<f64> = data.windows(2).map(|w| w[1] - w[0]).collect();
1337                let first_diff = diffs[0];
1338                diffs.iter().all(|&d| (d - first_diff).abs() < 1e-9)
1339            } else {
1340                false
1341            };
1342
1343            if is_linear_monotonic && out.len() > warmup_period + period {
1344                for i in (warmup_period + 1)..out.len() {
1345                    if !out[i].is_nan() && !out[i - 1].is_nan() && out[i - 1] > 1e-10 {
1346                        let change_ratio = (out[i] - out[i - 1]).abs() / out[i - 1];
1347
1348                        prop_assert!(
1349								change_ratio <= 0.1,
1350								"[{}] MAD changes too much for linear data at index {}: {} -> {} ({:.2}% change)",
1351								test_name,
1352								i,
1353								out[i-1],
1354								out[i],
1355								change_ratio * 100.0
1356							);
1357                    }
1358                }
1359            }
1360
1361            for i in warmup_period..out.len() {
1362                if out[i].is_nan() || i < period {
1363                    continue;
1364                }
1365
1366                let window_start = i + 1 - period;
1367                let window = &data[window_start..=i];
1368
1369                let window_min = window.iter().cloned().fold(f64::INFINITY, f64::min);
1370                let window_max = window.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
1371                let window_range = window_max - window_min;
1372
1373                prop_assert!(
1374                    out[i] <= window_range / 2.0 + 1e-9,
1375                    "[{}] MAD exceeds half window range at index {}: MAD={}, window_range/2={}",
1376                    test_name,
1377                    i,
1378                    out[i],
1379                    window_range / 2.0
1380                );
1381            }
1382
1383            if period == data.len() && out.len() > warmup_period {
1384                let non_nan_count = out.iter().filter(|&&v| !v.is_nan()).count();
1385                prop_assert!(
1386                    non_nan_count <= 1,
1387                    "[{}] With period={}, expected at most 1 non-NaN value, got {}",
1388                    test_name,
1389                    period,
1390                    non_nan_count
1391                );
1392            }
1393
1394            if period == 2 && out.len() > warmup_period {
1395                for i in warmup_period..out.len() {
1396                    if !out[i].is_nan() && i >= 1 {
1397                        let expected = (data[i] - data[i - 1]).abs() / 2.0;
1398                        prop_assert!(
1399                            (out[i] - expected).abs() <= 1e-9,
1400                            "[{}] For period=2 at index {}, MAD should be {}, got {}",
1401                            test_name,
1402                            i,
1403                            expected,
1404                            out[i]
1405                        );
1406                    }
1407                }
1408            }
1409
1410            Ok(())
1411        })?;
1412
1413        Ok(())
1414    }
1415
1416    generate_all_mean_ad_tests!(
1417        check_mean_ad_partial_params,
1418        check_mean_ad_accuracy,
1419        check_mean_ad_default_candles,
1420        check_mean_ad_zero_period,
1421        check_mean_ad_period_exceeds_length,
1422        check_mean_ad_very_small_dataset,
1423        check_mean_ad_reinput,
1424        check_mean_ad_nan_handling,
1425        check_mean_ad_streaming,
1426        check_mean_ad_no_poison
1427    );
1428
1429    #[cfg(feature = "proptest")]
1430    generate_all_mean_ad_tests!(check_mean_ad_property);
1431
1432    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1433        skip_if_unsupported!(kernel, test);
1434
1435        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1436        let c = read_candles_from_csv(file)?;
1437        let output = MeanAdBatchBuilder::new()
1438            .kernel(kernel)
1439            .apply_candles(&c, "hl2")?;
1440
1441        let def = MeanAdParams::default();
1442        let row = output.values_for(&def).expect("default row missing");
1443
1444        assert_eq!(row.len(), c.close.len());
1445
1446        let expected = [
1447            199.71999999999971,
1448            104.14000000000087,
1449            133.4,
1450            100.54000000000087,
1451            117.98000000000029,
1452        ];
1453        let start = row.len() - 5;
1454        for (i, &v) in row[start..].iter().enumerate() {
1455            assert!(
1456                (v - expected[i]).abs() < 1e-1,
1457                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
1458            );
1459        }
1460        Ok(())
1461    }
1462
1463    #[cfg(debug_assertions)]
1464    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1465        skip_if_unsupported!(kernel, test);
1466
1467        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1468        let c = read_candles_from_csv(file)?;
1469
1470        let test_configs = vec![
1471            (2, 10, 2),
1472            (5, 25, 5),
1473            (30, 60, 15),
1474            (2, 5, 1),
1475            (10, 50, 10),
1476            (3, 15, 3),
1477            (20, 30, 2),
1478            (7, 21, 7),
1479            (100, 200, 50),
1480        ];
1481
1482        for (cfg_idx, &(p_start, p_end, p_step)) in test_configs.iter().enumerate() {
1483            let output = MeanAdBatchBuilder::new()
1484                .kernel(kernel)
1485                .period_range(p_start, p_end, p_step)
1486                .apply_candles(&c, "close")?;
1487
1488            for (idx, &val) in output.values.iter().enumerate() {
1489                if val.is_nan() {
1490                    continue;
1491                }
1492
1493                let bits = val.to_bits();
1494                let row = idx / output.cols;
1495                let col = idx % output.cols;
1496                let combo = &output.combos[row];
1497
1498                if bits == 0x11111111_11111111 {
1499                    panic!(
1500                        "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
1501						 at row {} col {} (flat index {}) with params: period={}",
1502                        test,
1503                        cfg_idx,
1504                        val,
1505                        bits,
1506                        row,
1507                        col,
1508                        idx,
1509                        combo.period.unwrap_or(5)
1510                    );
1511                }
1512
1513                if bits == 0x22222222_22222222 {
1514                    panic!(
1515                        "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
1516						 at row {} col {} (flat index {}) with params: period={}",
1517                        test,
1518                        cfg_idx,
1519                        val,
1520                        bits,
1521                        row,
1522                        col,
1523                        idx,
1524                        combo.period.unwrap_or(5)
1525                    );
1526                }
1527
1528                if bits == 0x33333333_33333333 {
1529                    panic!(
1530                        "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
1531						 at row {} col {} (flat index {}) with params: period={}",
1532                        test,
1533                        cfg_idx,
1534                        val,
1535                        bits,
1536                        row,
1537                        col,
1538                        idx,
1539                        combo.period.unwrap_or(5)
1540                    );
1541                }
1542            }
1543        }
1544
1545        Ok(())
1546    }
1547
1548    #[cfg(not(debug_assertions))]
1549    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1550        Ok(())
1551    }
1552
1553    #[test]
1554    fn test_mean_ad_into_matches_api() -> Result<(), Box<dyn Error>> {
1555        let n = 256usize;
1556        let mut data = Vec::with_capacity(n);
1557        data.push(f64::NAN);
1558        data.push(f64::NAN);
1559        data.push(f64::NAN);
1560        for i in 3..n {
1561            let t = i as f64;
1562            let v = (t * 0.01).mul_add((t * 0.1).sin() * 100.0, (t * 0.05).cos() * 0.5);
1563            data.push(v);
1564        }
1565
1566        let input = MeanAdInput::from_slice(&data, MeanAdParams::default());
1567        let baseline = mean_ad(&input)?.values;
1568
1569        let mut into_out = vec![0.0; baseline.len()];
1570        #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1571        {
1572            mean_ad_into(&input, &mut into_out)?;
1573        }
1574        #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1575        {
1576            mean_ad_into_slice(&mut into_out, &input, Kernel::Auto)?;
1577        }
1578
1579        assert_eq!(baseline.len(), into_out.len());
1580        #[inline]
1581        fn eq_or_both_nan(a: f64, b: f64) -> bool {
1582            (a.is_nan() && b.is_nan()) || (a == b) || (a - b).abs() <= 1e-12
1583        }
1584        for (i, (a, b)) in baseline.iter().zip(into_out.iter()).enumerate() {
1585            assert!(
1586                eq_or_both_nan(*a, *b),
1587                "divergence at idx {}: api={}, into={}",
1588                i,
1589                a,
1590                b
1591            );
1592        }
1593        Ok(())
1594    }
1595
1596    macro_rules! gen_batch_tests {
1597        ($fn_name:ident) => {
1598            paste! {
1599                #[test] fn [<$fn_name _scalar>]()      {
1600                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
1601                }
1602                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1603                #[test] fn [<$fn_name _avx2>]()        {
1604                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
1605                }
1606                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1607                #[test] fn [<$fn_name _avx512>]()      {
1608                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
1609                }
1610                #[test] fn [<$fn_name _auto_detect>]() {
1611                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
1612                }
1613            }
1614        };
1615    }
1616    gen_batch_tests!(check_batch_default_row);
1617    gen_batch_tests!(check_batch_no_poison);
1618}
1619
1620#[cfg(feature = "python")]
1621#[pyfunction(name = "mean_ad")]
1622#[pyo3(signature = (data, period, kernel=None))]
1623pub fn mean_ad_py<'py>(
1624    py: Python<'py>,
1625    data: PyReadonlyArray1<'py, f64>,
1626    period: usize,
1627    kernel: Option<&str>,
1628) -> PyResult<Bound<'py, PyArray1<f64>>> {
1629    let slice_in = data.as_slice()?;
1630    let kern = validate_kernel(kernel, false)?;
1631
1632    let params = MeanAdParams {
1633        period: Some(period),
1634    };
1635    let input = MeanAdInput::from_slice(slice_in, params);
1636
1637    let result_vec: Vec<f64> = py
1638        .allow_threads(|| mean_ad_with_kernel(&input, kern).map(|o| o.values))
1639        .map_err(|e| PyValueError::new_err(e.to_string()))?;
1640
1641    Ok(result_vec.into_pyarray(py))
1642}
1643
1644#[cfg(feature = "python")]
1645#[pyfunction(name = "mean_ad_batch")]
1646#[pyo3(signature = (data, period_range, kernel=None))]
1647pub fn mean_ad_batch_py<'py>(
1648    py: Python<'py>,
1649    data: PyReadonlyArray1<'py, f64>,
1650    period_range: (usize, usize, usize),
1651    kernel: Option<&str>,
1652) -> PyResult<Bound<'py, PyDict>> {
1653    let slice_in = data.as_slice()?;
1654    let kern = validate_kernel(kernel, true)?;
1655
1656    let sweep = MeanAdBatchRange {
1657        period: period_range,
1658    };
1659
1660    let combos_for_shape = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
1661    let rows = combos_for_shape.len();
1662    let cols = slice_in.len();
1663
1664    let total = rows
1665        .checked_mul(cols)
1666        .ok_or_else(|| PyValueError::new_err("mean_ad_batch_py: size overflow"))?;
1667    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
1668    let slice_out = unsafe { out_arr.as_slice_mut()? };
1669
1670    let combos = py
1671        .allow_threads(|| {
1672            let kernel = match kern {
1673                Kernel::Auto => detect_best_batch_kernel(),
1674                k => k,
1675            };
1676            let simd = match kernel {
1677                Kernel::Avx512Batch => Kernel::Avx512,
1678                Kernel::Avx2Batch => Kernel::Avx2,
1679                Kernel::ScalarBatch => Kernel::Scalar,
1680                _ => unreachable!(),
1681            };
1682            mean_ad_batch_inner_into(slice_in, &sweep, simd, true, slice_out)
1683        })
1684        .map_err(|e| PyValueError::new_err(e.to_string()))?;
1685
1686    let dict = PyDict::new(py);
1687    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
1688    dict.set_item(
1689        "periods",
1690        combos
1691            .iter()
1692            .map(|p| p.period.unwrap() as u64)
1693            .collect::<Vec<_>>()
1694            .into_pyarray(py),
1695    )?;
1696
1697    Ok(dict)
1698}
1699
1700#[cfg(all(feature = "python", feature = "cuda"))]
1701#[pyfunction(name = "mean_ad_cuda_batch_dev")]
1702#[pyo3(signature = (data_f32, period_range, device_id=0))]
1703pub fn mean_ad_cuda_batch_dev_py<'py>(
1704    py: Python<'py>,
1705    data_f32: numpy::PyReadonlyArray1<'py, f32>,
1706    period_range: (usize, usize, usize),
1707    device_id: usize,
1708) -> PyResult<DeviceArrayF32Py> {
1709    if !crate::cuda::cuda_available() {
1710        return Err(PyValueError::new_err("CUDA not available"));
1711    }
1712    let slice_in = data_f32.as_slice()?;
1713    let sweep = MeanAdBatchRange {
1714        period: period_range,
1715    };
1716    let inner = py.allow_threads(|| {
1717        let cuda = CudaMeanAd::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
1718        cuda.mean_ad_batch_dev(slice_in, &sweep)
1719            .map_err(|e| PyValueError::new_err(e.to_string()))
1720    })?;
1721    make_device_array_py(device_id, inner)
1722}
1723
1724#[cfg(all(feature = "python", feature = "cuda"))]
1725#[pyfunction(name = "mean_ad_cuda_many_series_one_param_dev")]
1726#[pyo3(signature = (data_tm_f32, cols, rows, period, device_id=0))]
1727pub fn mean_ad_cuda_many_series_one_param_dev_py<'py>(
1728    py: Python<'py>,
1729    data_tm_f32: numpy::PyReadonlyArray1<'py, f32>,
1730    cols: usize,
1731    rows: usize,
1732    period: usize,
1733    device_id: usize,
1734) -> PyResult<DeviceArrayF32Py> {
1735    if !crate::cuda::cuda_available() {
1736        return Err(PyValueError::new_err("CUDA not available"));
1737    }
1738    let slice_in = data_tm_f32.as_slice()?;
1739    let params = MeanAdParams {
1740        period: Some(period),
1741    };
1742    let inner = py.allow_threads(|| {
1743        let cuda = CudaMeanAd::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
1744        cuda.mean_ad_many_series_one_param_time_major_dev(slice_in, cols, rows, &params)
1745            .map_err(|e| PyValueError::new_err(e.to_string()))
1746    })?;
1747    make_device_array_py(device_id, inner)
1748}
1749
1750#[cfg(feature = "python")]
1751#[pyclass(name = "MeanAdStream")]
1752pub struct MeanAdStreamPy {
1753    stream: MeanAdStream,
1754}
1755
1756#[cfg(feature = "python")]
1757#[pymethods]
1758impl MeanAdStreamPy {
1759    #[new]
1760    fn new(period: usize) -> PyResult<Self> {
1761        let params = MeanAdParams {
1762            period: Some(period),
1763        };
1764        let stream =
1765            MeanAdStream::try_new(params).map_err(|e| PyValueError::new_err(e.to_string()))?;
1766        Ok(MeanAdStreamPy { stream })
1767    }
1768
1769    fn update(&mut self, value: f64) -> Option<f64> {
1770        self.stream.update(value)
1771    }
1772}
1773
1774pub fn mean_ad_into_slice(
1775    dst: &mut [f64],
1776    input: &MeanAdInput,
1777    kern: Kernel,
1778) -> Result<(), MeanAdError> {
1779    let data = input.as_ref();
1780    let period = match input.params.period {
1781        Some(p) if p > 0 => p,
1782        _ => {
1783            return Err(MeanAdError::InvalidPeriod {
1784                period: input.params.period.unwrap_or(0),
1785                data_len: data.len(),
1786            })
1787        }
1788    };
1789
1790    if data.is_empty() {
1791        return Err(MeanAdError::EmptyInputData);
1792    }
1793    if period > data.len() {
1794        return Err(MeanAdError::InvalidPeriod {
1795            period,
1796            data_len: data.len(),
1797        });
1798    }
1799    if dst.len() != data.len() {
1800        return Err(MeanAdError::OutputLengthMismatch {
1801            expected: data.len(),
1802            got: dst.len(),
1803        });
1804    }
1805
1806    let first = data
1807        .iter()
1808        .position(|x| !x.is_nan())
1809        .ok_or(MeanAdError::AllValuesNaN)?;
1810
1811    if (data.len() - first) < period {
1812        return Err(MeanAdError::NotEnoughValidData {
1813            needed: period,
1814            valid: data.len() - first,
1815        });
1816    }
1817
1818    let chosen = match kern {
1819        Kernel::Auto => Kernel::Scalar,
1820        other => other,
1821    };
1822
1823    let warmup_end = first + (period << 1) - 2;
1824    let warmup_end = warmup_end.min(dst.len());
1825    if warmup_end > 0 {
1826        dst[..warmup_end].fill(f64::NAN);
1827    }
1828
1829    match chosen {
1830        Kernel::Scalar | Kernel::ScalarBatch => mean_ad_row_scalar(data, first, period, dst),
1831        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1832        Kernel::Avx2 | Kernel::Avx2Batch => mean_ad_row_avx2(data, first, period, dst),
1833        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1834        Kernel::Avx512 | Kernel::Avx512Batch => mean_ad_row_avx512(data, first, period, dst),
1835        _ => unreachable!(),
1836    }
1837
1838    Ok(())
1839}
1840
1841#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1842#[wasm_bindgen]
1843pub fn mean_ad_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
1844    let params = MeanAdParams {
1845        period: Some(period),
1846    };
1847    let input = MeanAdInput::from_slice(data, params);
1848
1849    let mut output = vec![0.0; data.len()];
1850    mean_ad_into_slice(&mut output, &input, Kernel::Auto)
1851        .map_err(|e| JsValue::from_str(&e.to_string()))?;
1852
1853    Ok(output)
1854}
1855
1856#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1857#[wasm_bindgen]
1858pub fn mean_ad_into(
1859    in_ptr: *const f64,
1860    out_ptr: *mut f64,
1861    len: usize,
1862    period: usize,
1863) -> Result<(), JsValue> {
1864    if in_ptr.is_null() || out_ptr.is_null() {
1865        return Err(JsValue::from_str("null pointer passed to mean_ad_into"));
1866    }
1867
1868    unsafe {
1869        let data = std::slice::from_raw_parts(in_ptr, len);
1870        let params = MeanAdParams {
1871            period: Some(period),
1872        };
1873        let input = MeanAdInput::from_slice(data, params);
1874
1875        if in_ptr == out_ptr {
1876            let mut temp = vec![0.0; len];
1877            mean_ad_into_slice(&mut temp, &input, Kernel::Auto)
1878                .map_err(|e| JsValue::from_str(&e.to_string()))?;
1879            let out = std::slice::from_raw_parts_mut(out_ptr, len);
1880            out.copy_from_slice(&temp);
1881        } else {
1882            let out = std::slice::from_raw_parts_mut(out_ptr, len);
1883            mean_ad_into_slice(out, &input, Kernel::Auto)
1884                .map_err(|e| JsValue::from_str(&e.to_string()))?;
1885        }
1886        Ok(())
1887    }
1888}
1889
1890#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1891#[wasm_bindgen]
1892pub fn mean_ad_alloc(len: usize) -> *mut f64 {
1893    let mut vec = Vec::<f64>::with_capacity(len);
1894    let ptr = vec.as_mut_ptr();
1895    std::mem::forget(vec);
1896    ptr
1897}
1898
1899#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1900#[wasm_bindgen]
1901pub fn mean_ad_free(ptr: *mut f64, len: usize) {
1902    if !ptr.is_null() {
1903        unsafe {
1904            let _ = Vec::from_raw_parts(ptr, len, len);
1905        }
1906    }
1907}
1908
1909#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1910#[derive(Serialize, Deserialize)]
1911pub struct MeanAdBatchConfig {
1912    pub period_range: (usize, usize, usize),
1913}
1914
1915#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1916#[derive(Serialize, Deserialize)]
1917pub struct MeanAdBatchJsOutput {
1918    pub values: Vec<f64>,
1919    pub periods: Vec<usize>,
1920    pub rows: usize,
1921    pub cols: usize,
1922}
1923
1924#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1925#[wasm_bindgen(js_name = mean_ad_batch)]
1926pub fn mean_ad_batch_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
1927    let config: MeanAdBatchConfig = serde_wasm_bindgen::from_value(config)
1928        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;
1929
1930    let sweep = MeanAdBatchRange {
1931        period: config.period_range,
1932    };
1933
1934    let output = mean_ad_batch_inner(data, &sweep, Kernel::Auto, false)
1935        .map_err(|e| JsValue::from_str(&e.to_string()))?;
1936
1937    let js_output = MeanAdBatchJsOutput {
1938        values: output.values,
1939        periods: output.combos.iter().map(|p| p.period.unwrap()).collect(),
1940        rows: output.rows,
1941        cols: output.cols,
1942    };
1943
1944    serde_wasm_bindgen::to_value(&js_output)
1945        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
1946}
1947
1948#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1949#[wasm_bindgen]
1950pub fn mean_ad_batch_into(
1951    in_ptr: *const f64,
1952    out_ptr: *mut f64,
1953    len: usize,
1954    period_start: usize,
1955    period_end: usize,
1956    period_step: usize,
1957) -> Result<usize, JsValue> {
1958    if in_ptr.is_null() || out_ptr.is_null() {
1959        return Err(JsValue::from_str(
1960            "null pointer passed to mean_ad_batch_into",
1961        ));
1962    }
1963
1964    unsafe {
1965        let data = std::slice::from_raw_parts(in_ptr, len);
1966
1967        let sweep = MeanAdBatchRange {
1968            period: (period_start, period_end, period_step),
1969        };
1970
1971        let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
1972        let rows = combos.len();
1973        let cols = len;
1974
1975        if out_ptr as usize % 8 != 0 {
1976            return Err(JsValue::from_str("Output pointer must be 8-byte aligned"));
1977        }
1978
1979        let total = rows
1980            .checked_mul(cols)
1981            .ok_or_else(|| JsValue::from_str("mean_ad_batch_into: size overflow"))?;
1982        let out = std::slice::from_raw_parts_mut(out_ptr, total);
1983
1984        let kernel = detect_best_batch_kernel();
1985        let simd = match kernel {
1986            Kernel::Avx512Batch => Kernel::Avx512,
1987            Kernel::Avx2Batch => Kernel::Avx2,
1988            Kernel::ScalarBatch => Kernel::Scalar,
1989            _ => unreachable!(),
1990        };
1991
1992        mean_ad_batch_inner_into(data, &sweep, simd, false, out)
1993            .map_err(|e| JsValue::from_str(&e.to_string()))?;
1994
1995        Ok(rows)
1996    }
1997}