1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{
4 alloc_with_nan_prefix, detect_best_batch_kernel, init_matrix_prefixes, make_uninit_matrix,
5};
6#[cfg(feature = "python")]
7use crate::utilities::kernel_validation::validate_kernel;
8use aligned_vec::{AVec, CACHELINE_ALIGN};
9#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
10use core::arch::x86_64::*;
11#[cfg(feature = "python")]
12use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
13#[cfg(feature = "python")]
14use pyo3::exceptions::PyValueError;
15#[cfg(feature = "python")]
16use pyo3::prelude::*;
17#[cfg(feature = "python")]
18use pyo3::types::PyDict;
19#[cfg(not(target_arch = "wasm32"))]
20use rayon::prelude::*;
21#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
22use serde::{Deserialize, Serialize};
23use std::convert::AsRef;
24use std::error::Error;
25use std::mem::MaybeUninit;
26use thiserror::Error;
27#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
28use wasm_bindgen::prelude::*;
29
30#[cfg(all(feature = "python", feature = "cuda"))]
31use crate::cuda::CudaMeanAd;
32#[cfg(all(feature = "python", feature = "cuda"))]
33use crate::indicators::moving_averages::alma::{make_device_array_py, DeviceArrayF32Py};
34
35impl<'a> AsRef<[f64]> for MeanAdInput<'a> {
36 #[inline(always)]
37 fn as_ref(&self) -> &[f64] {
38 match &self.data {
39 MeanAdData::Slice(slice) => slice,
40 MeanAdData::Candles { candles, source } => source_type(candles, source),
41 }
42 }
43}
44
45#[derive(Debug, Clone)]
46pub enum MeanAdData<'a> {
47 Candles {
48 candles: &'a Candles,
49 source: &'a str,
50 },
51 Slice(&'a [f64]),
52}
53
54#[derive(Debug, Clone)]
55pub struct MeanAdOutput {
56 pub values: Vec<f64>,
57}
58
59#[derive(Debug, Clone)]
60pub struct MeanAdParams {
61 pub period: Option<usize>,
62}
63
64impl Default for MeanAdParams {
65 fn default() -> Self {
66 Self { period: Some(5) }
67 }
68}
69
70#[derive(Debug, Clone)]
71pub struct MeanAdInput<'a> {
72 pub data: MeanAdData<'a>,
73 pub params: MeanAdParams,
74}
75
76impl<'a> MeanAdInput<'a> {
77 #[inline]
78 pub fn from_candles(c: &'a Candles, s: &'a str, p: MeanAdParams) -> Self {
79 Self {
80 data: MeanAdData::Candles {
81 candles: c,
82 source: s,
83 },
84 params: p,
85 }
86 }
87 #[inline]
88 pub fn from_slice(sl: &'a [f64], p: MeanAdParams) -> Self {
89 Self {
90 data: MeanAdData::Slice(sl),
91 params: p,
92 }
93 }
94 #[inline]
95 pub fn with_default_candles(c: &'a Candles) -> Self {
96 Self::from_candles(c, "close", MeanAdParams::default())
97 }
98 #[inline]
99 pub fn get_period(&self) -> usize {
100 self.params.period.unwrap_or(5)
101 }
102}
103
104#[derive(Copy, Clone, Debug)]
105pub struct MeanAdBuilder {
106 period: Option<usize>,
107 kernel: Kernel,
108}
109
110impl Default for MeanAdBuilder {
111 fn default() -> Self {
112 Self {
113 period: None,
114 kernel: Kernel::Auto,
115 }
116 }
117}
118
119impl MeanAdBuilder {
120 #[inline(always)]
121 pub fn new() -> Self {
122 Self::default()
123 }
124 #[inline(always)]
125 pub fn period(mut self, n: usize) -> Self {
126 self.period = Some(n);
127 self
128 }
129 #[inline(always)]
130 pub fn kernel(mut self, k: Kernel) -> Self {
131 self.kernel = k;
132 self
133 }
134 #[inline(always)]
135 pub fn apply(self, c: &Candles) -> Result<MeanAdOutput, MeanAdError> {
136 let p = MeanAdParams {
137 period: self.period,
138 };
139 let i = MeanAdInput::from_candles(c, "close", p);
140 mean_ad_with_kernel(&i, self.kernel)
141 }
142 #[inline(always)]
143 pub fn apply_slice(self, d: &[f64]) -> Result<MeanAdOutput, MeanAdError> {
144 let p = MeanAdParams {
145 period: self.period,
146 };
147 let i = MeanAdInput::from_slice(d, p);
148 mean_ad_with_kernel(&i, self.kernel)
149 }
150 #[inline(always)]
151 pub fn into_stream(self) -> Result<MeanAdStream, MeanAdError> {
152 let p = MeanAdParams {
153 period: self.period,
154 };
155 MeanAdStream::try_new(p)
156 }
157}
158
159#[derive(Debug, Error)]
160pub enum MeanAdError {
161 #[error("mean_ad: Empty data provided.")]
162 EmptyInputData,
163 #[error("mean_ad: Invalid period: period = {period}, data length = {data_len}")]
164 InvalidPeriod { period: usize, data_len: usize },
165 #[error("mean_ad: Not enough valid data: needed = {needed}, valid = {valid}")]
166 NotEnoughValidData { needed: usize, valid: usize },
167 #[error("mean_ad: All values are NaN.")]
168 AllValuesNaN,
169 #[error("mean_ad: Output length mismatch: expected {expected}, got {got}")]
170 OutputLengthMismatch { expected: usize, got: usize },
171 #[error("mean_ad: Invalid range: start={start}, end={end}, step={step}")]
172 InvalidRange {
173 start: String,
174 end: String,
175 step: String,
176 },
177 #[error("mean_ad: Invalid kernel for batch: {0:?}")]
178 InvalidKernelForBatch(crate::utilities::enums::Kernel),
179}
180
181#[inline]
182pub fn mean_ad(input: &MeanAdInput) -> Result<MeanAdOutput, MeanAdError> {
183 mean_ad_with_kernel(input, Kernel::Auto)
184}
185
186pub fn mean_ad_with_kernel(
187 input: &MeanAdInput,
188 kernel: Kernel,
189) -> Result<MeanAdOutput, MeanAdError> {
190 let data: &[f64] = match &input.data {
191 MeanAdData::Candles { candles, source } => source_type(candles, source),
192 MeanAdData::Slice(sl) => sl,
193 };
194
195 if data.is_empty() {
196 return Err(MeanAdError::EmptyInputData);
197 }
198
199 let period = input.get_period();
200 let first = data
201 .iter()
202 .position(|x| !x.is_nan())
203 .ok_or(MeanAdError::AllValuesNaN)?;
204
205 if period == 0 || period > data.len() {
206 return Err(MeanAdError::InvalidPeriod {
207 period,
208 data_len: data.len(),
209 });
210 }
211 if (data.len() - first) < period {
212 return Err(MeanAdError::NotEnoughValidData {
213 needed: period,
214 valid: data.len() - first,
215 });
216 }
217
218 let chosen = match kernel {
219 Kernel::Auto => Kernel::Scalar,
220 other => other,
221 };
222
223 match chosen {
224 Kernel::Scalar | Kernel::ScalarBatch => mean_ad_scalar(data, period, first),
225 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
226 Kernel::Avx2 | Kernel::Avx2Batch => mean_ad_avx2(data, period, first),
227 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
228 Kernel::Avx512 | Kernel::Avx512Batch => mean_ad_avx512(data, period, first),
229 _ => unreachable!(),
230 }
231}
232
233#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
234#[inline]
235pub fn mean_ad_into(input: &MeanAdInput, out: &mut [f64]) -> Result<(), MeanAdError> {
236 mean_ad_into_slice(out, input, Kernel::Auto)
237}
238
239#[inline]
240pub fn mean_ad_scalar(
241 data: &[f64],
242 period: usize,
243 first: usize,
244) -> Result<MeanAdOutput, MeanAdError> {
245 if period == 0 {
246 return Err(MeanAdError::InvalidPeriod {
247 period: 0,
248 data_len: data.len(),
249 });
250 }
251 if first > data.len() {
252 return Err(MeanAdError::NotEnoughValidData {
253 needed: first,
254 valid: data.len(),
255 });
256 }
257
258 let n = data.len();
259
260 if first + period > n {
261 let out = alloc_with_nan_prefix(n, n);
262 return Ok(MeanAdOutput { values: out });
263 }
264
265 let inv_p = 1.0f64 / (period as f64);
266
267 let warmup_end = first + (period << 1) - 2;
268 let mut out = alloc_with_nan_prefix(n, warmup_end.min(n));
269
270 let mut sum = 0.0f64;
271 for i in first..(first + period) {
272 sum += data[i];
273 }
274 let mut sma = sum * inv_p;
275
276 let mut residual_buffer = vec![0.0f64; period];
277 let mut buffer_index = 0usize;
278 let mut residual_sum = 0.0f64;
279
280 let start_t = first + period - 1;
281 let fill_t_end = (start_t + period - 1).min(n.saturating_sub(1));
282 for t in start_t..=fill_t_end {
283 let residual = (data[t] - sma).abs();
284 residual_buffer[buffer_index] = residual;
285 residual_sum += residual;
286 buffer_index += 1;
287 if buffer_index == period {
288 buffer_index = 0;
289 }
290 if t + 1 < n {
291 sum += data[t + 1] - data[t + 1 - period];
292 sma = sum * inv_p;
293 }
294 }
295
296 if warmup_end < n {
297 out[warmup_end] = residual_sum * inv_p;
298 }
299
300 let mut t = start_t + period;
301 while t < n {
302 let residual = (data[t] - sma).abs();
303
304 let old = residual_buffer[buffer_index];
305 residual_sum += residual - old;
306 residual_buffer[buffer_index] = residual;
307 buffer_index += 1;
308 if buffer_index == period {
309 buffer_index = 0;
310 }
311
312 out[t] = residual_sum * inv_p;
313
314 if t + 1 < n {
315 sum += data[t + 1] - data[t + 1 - period];
316 sma = sum * inv_p;
317 }
318 t += 1;
319 }
320
321 Ok(MeanAdOutput { values: out })
322}
323
/// AVX2 entry point. There is no hand-vectorised kernel yet, so this forwards to the
/// scalar implementation and produces bit-identical results.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn mean_ad_avx2(
    data: &[f64],
    period: usize,
    first: usize,
) -> Result<MeanAdOutput, MeanAdError> {
    let scalar = mean_ad_scalar;
    scalar(data, period, first)
}

/// AVX-512 entry point; currently forwards to the scalar implementation.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn mean_ad_avx512(
    data: &[f64],
    period: usize,
    first: usize,
) -> Result<MeanAdOutput, MeanAdError> {
    let scalar = mean_ad_scalar;
    scalar(data, period, first)
}

/// AVX-512 variant reserved for short periods; currently forwards to the scalar implementation.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn mean_ad_avx512_short(
    data: &[f64],
    period: usize,
    first: usize,
) -> Result<MeanAdOutput, MeanAdError> {
    let scalar = mean_ad_scalar;
    scalar(data, period, first)
}

/// AVX-512 variant reserved for long periods; currently forwards to the scalar implementation.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn mean_ad_avx512_long(
    data: &[f64],
    period: usize,
    first: usize,
) -> Result<MeanAdOutput, MeanAdError> {
    let scalar = mean_ad_scalar;
    scalar(data, period, first)
}
363
364#[inline(always)]
365pub fn mean_ad_batch_with_kernel(
366 data: &[f64],
367 sweep: &MeanAdBatchRange,
368 k: Kernel,
369) -> Result<MeanAdBatchOutput, MeanAdError> {
370 let kernel = match k {
371 Kernel::Auto => Kernel::ScalarBatch,
372 other if other.is_batch() => other,
373 other => return Err(MeanAdError::InvalidKernelForBatch(other)),
374 };
375 let simd = match kernel {
376 Kernel::Avx512Batch => Kernel::Avx512,
377 Kernel::Avx2Batch => Kernel::Avx2,
378 Kernel::ScalarBatch => Kernel::Scalar,
379 _ => unreachable!(),
380 };
381 mean_ad_batch_par_slice(data, sweep, simd)
382}
383
/// Period sweep specification for batch evaluation.
#[derive(Clone, Debug)]
pub struct MeanAdBatchRange {
    /// `(start, end, step)` over window lengths; see `expand_grid` for how it is expanded.
    pub period: (usize, usize, usize),
}

impl Default for MeanAdBatchRange {
    /// Periods 5 through 254 in steps of 1.
    fn default() -> Self {
        MeanAdBatchRange {
            period: (5, 254, 1),
        }
    }
}
396
397#[derive(Clone, Debug, Default)]
398pub struct MeanAdBatchBuilder {
399 range: MeanAdBatchRange,
400 kernel: Kernel,
401}
402
403impl MeanAdBatchBuilder {
404 pub fn new() -> Self {
405 Self::default()
406 }
407 pub fn kernel(mut self, k: Kernel) -> Self {
408 self.kernel = k;
409 self
410 }
411 #[inline]
412 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
413 self.range.period = (start, end, step);
414 self
415 }
416 #[inline]
417 pub fn period_static(mut self, p: usize) -> Self {
418 self.range.period = (p, p, 0);
419 self
420 }
421 pub fn apply_slice(self, data: &[f64]) -> Result<MeanAdBatchOutput, MeanAdError> {
422 mean_ad_batch_with_kernel(data, &self.range, self.kernel)
423 }
424 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<MeanAdBatchOutput, MeanAdError> {
425 MeanAdBatchBuilder::new().kernel(k).apply_slice(data)
426 }
427 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<MeanAdBatchOutput, MeanAdError> {
428 let slice = source_type(c, src);
429 self.apply_slice(slice)
430 }
431 pub fn with_default_candles(c: &Candles) -> Result<MeanAdBatchOutput, MeanAdError> {
432 MeanAdBatchBuilder::new()
433 .kernel(Kernel::Auto)
434 .apply_candles(c, "close")
435 }
436}
437
438#[derive(Clone, Debug)]
439pub struct MeanAdBatchOutput {
440 pub values: Vec<f64>,
441 pub combos: Vec<MeanAdParams>,
442 pub rows: usize,
443 pub cols: usize,
444}
445impl MeanAdBatchOutput {
446 pub fn row_for_params(&self, p: &MeanAdParams) -> Option<usize> {
447 self.combos
448 .iter()
449 .position(|c| c.period.unwrap_or(5) == p.period.unwrap_or(5))
450 }
451 pub fn values_for(&self, p: &MeanAdParams) -> Option<&[f64]> {
452 self.row_for_params(p).and_then(|row| {
453 let start = row.checked_mul(self.cols)?;
454 self.values.get(start..start + self.cols)
455 })
456 }
457}
458
459#[inline(always)]
460fn expand_grid(r: &MeanAdBatchRange) -> Result<Vec<MeanAdParams>, MeanAdError> {
461 fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, MeanAdError> {
462 if step == 0 || start == end {
463 return Ok(vec![start]);
464 }
465 if start < end {
466 let st = step.max(1);
467 let v: Vec<usize> = (start..=end).step_by(st).collect();
468 if v.is_empty() {
469 return Err(MeanAdError::InvalidRange {
470 start: start.to_string(),
471 end: end.to_string(),
472 step: step.to_string(),
473 });
474 }
475 return Ok(v);
476 }
477
478 let mut v = Vec::new();
479 let mut x = start as isize;
480 let end_i = end as isize;
481 let st = (step as isize).max(1);
482 while x >= end_i {
483 v.push(x as usize);
484 x -= st;
485 }
486 if v.is_empty() {
487 return Err(MeanAdError::InvalidRange {
488 start: start.to_string(),
489 end: end.to_string(),
490 step: step.to_string(),
491 });
492 }
493 Ok(v)
494 }
495
496 let periods = axis_usize(r.period)?;
497 if periods.is_empty() {
498 return Err(MeanAdError::InvalidRange {
499 start: r.period.0.to_string(),
500 end: r.period.1.to_string(),
501 step: r.period.2.to_string(),
502 });
503 }
504
505 let mut out = Vec::with_capacity(periods.len());
506 for p in periods {
507 out.push(MeanAdParams { period: Some(p) });
508 }
509 Ok(out)
510}
511
512#[inline(always)]
513pub fn mean_ad_batch_slice(
514 data: &[f64],
515 sweep: &MeanAdBatchRange,
516 kern: Kernel,
517) -> Result<MeanAdBatchOutput, MeanAdError> {
518 mean_ad_batch_inner(data, sweep, kern, false)
519}
520
521#[inline(always)]
522pub fn mean_ad_batch_par_slice(
523 data: &[f64],
524 sweep: &MeanAdBatchRange,
525 kern: Kernel,
526) -> Result<MeanAdBatchOutput, MeanAdError> {
527 mean_ad_batch_inner(data, sweep, kern, true)
528}
529
530#[inline(always)]
531fn mean_ad_batch_inner_into(
532 data: &[f64],
533 sweep: &MeanAdBatchRange,
534 kern: Kernel,
535 parallel: bool,
536 out: &mut [f64],
537) -> Result<Vec<MeanAdParams>, MeanAdError> {
538 if data.is_empty() {
539 return Err(MeanAdError::EmptyInputData);
540 }
541 let combos = expand_grid(sweep)?;
542 if combos.iter().any(|c| c.period.unwrap_or(0) == 0) {
543 return Err(MeanAdError::InvalidPeriod {
544 period: 0,
545 data_len: data.len(),
546 });
547 }
548 let expected =
549 combos
550 .len()
551 .checked_mul(data.len())
552 .ok_or_else(|| MeanAdError::OutputLengthMismatch {
553 expected: usize::MAX,
554 got: out.len(),
555 })?;
556 if out.len() != expected {
557 return Err(MeanAdError::OutputLengthMismatch {
558 expected,
559 got: out.len(),
560 });
561 }
562 let first = data
563 .iter()
564 .position(|x| !x.is_nan())
565 .ok_or(MeanAdError::AllValuesNaN)?;
566 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
567 if max_p > data.len() {
568 return Err(MeanAdError::InvalidPeriod {
569 period: max_p,
570 data_len: data.len(),
571 });
572 }
573 if data.len() - first < max_p {
574 return Err(MeanAdError::NotEnoughValidData {
575 needed: max_p,
576 valid: data.len() - first,
577 });
578 }
579 let rows = combos.len();
580 let cols = data.len();
581
582 let chosen = match kern {
583 Kernel::Auto => Kernel::Scalar,
584 other => other,
585 };
586
587 let do_row = |row: usize, out_row: &mut [f64]| {
588 let period = combos[row].period.unwrap();
589 let warmup_end = first + 2 * period - 2;
590 for i in 0..warmup_end.min(out_row.len()) {
591 out_row[i] = f64::NAN;
592 }
593 match chosen {
594 Kernel::Scalar => mean_ad_row_scalar(data, first, period, out_row),
595 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
596 Kernel::Avx2 => mean_ad_row_avx2(data, first, period, out_row),
597 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
598 Kernel::Avx512 => mean_ad_row_avx512(data, first, period, out_row),
599 _ => mean_ad_row_scalar(data, first, period, out_row),
600 }
601 };
602
603 if parallel {
604 #[cfg(not(target_arch = "wasm32"))]
605 {
606 out.par_chunks_mut(cols)
607 .enumerate()
608 .for_each(|(row, slice)| do_row(row, slice));
609 }
610
611 #[cfg(target_arch = "wasm32")]
612 {
613 for (row, slice) in out.chunks_mut(cols).enumerate() {
614 do_row(row, slice);
615 }
616 }
617 } else {
618 for (row, slice) in out.chunks_mut(cols).enumerate() {
619 do_row(row, slice);
620 }
621 }
622
623 Ok(combos)
624}
625
/// Runs a period sweep over `data` into a freshly allocated row-major matrix.
///
/// One row of `data.len()` values is produced per period returned by `expand_grid`. The
/// `parallel` flag selects rayon row-parallelism on non-wasm targets; it is ignored on
/// wasm32, where rows always run sequentially.
///
/// # Errors
/// * `EmptyInputData` – `data` is empty.
/// * `InvalidRange` – the sweep expands to nothing.
/// * `InvalidPeriod` – a period is 0 or exceeds `data.len()`.
/// * `AllValuesNaN` – every sample is NaN.
/// * `NotEnoughValidData` – too few samples follow the first non-NaN.
#[inline(always)]
fn mean_ad_batch_inner(
    data: &[f64],
    sweep: &MeanAdBatchRange,
    kern: Kernel,
    parallel: bool,
) -> Result<MeanAdBatchOutput, MeanAdError> {
    if data.is_empty() {
        return Err(MeanAdError::EmptyInputData);
    }
    let combos = expand_grid(sweep)?;
    // A `None` period is treated as invalid here (unlike the single-series path, which
    // defaults it to 5); `expand_grid` only ever produces `Some`.
    if combos.iter().any(|c| c.period.unwrap_or(0) == 0) {
        return Err(MeanAdError::InvalidPeriod {
            period: 0,
            data_len: data.len(),
        });
    }
    let first = data
        .iter()
        .position(|x| !x.is_nan())
        .ok_or(MeanAdError::AllValuesNaN)?;
    // Validating the longest period suffices: every shorter period also fits.
    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
    if max_p > data.len() {
        return Err(MeanAdError::InvalidPeriod {
            period: max_p,
            data_len: data.len(),
        });
    }
    if data.len() - first < max_p {
        return Err(MeanAdError::NotEnoughValidData {
            needed: max_p,
            valid: data.len() - first,
        });
    }
    let rows = combos.len();
    let cols = data.len();

    // Per-row index of the first defined output; this may exceed `cols` for long periods.
    let warmup_periods: Vec<usize> = combos
        .iter()
        .map(|c| first + 2 * c.period.unwrap() - 2)
        .collect();

    // NOTE(review): presumably `make_uninit_matrix` returns an uninitialised
    // `Vec<MaybeUninit<f64>>` and `init_matrix_prefixes` NaN-fills each row's warm-up
    // prefix; both helpers are defined elsewhere, so confirm.
    let mut buf_mu = make_uninit_matrix(rows, cols);
    init_matrix_prefixes(&mut buf_mu, cols, &warmup_periods);

    // ManuallyDrop hands ownership to the `Vec<f64>` rebuilt at the end. If a row panics
    // before that, the buffer is leaked rather than freed.
    let mut buf_guard = core::mem::ManuallyDrop::new(buf_mu);
    let values_ptr = buf_guard.as_mut_ptr() as *mut f64;
    // SAFETY: the pointer and length come from the live allocation in `buf_guard`, and
    // `MaybeUninit<f64>` has the same layout as `f64`.
    // NOTE(review): elements past each row's prefix are still uninitialised when this
    // `&mut [f64]` is formed. They are only written before being read, but creating a
    // reference to uninitialised `f64`s is UB under Rust's validity rules; operating on
    // `&mut [MaybeUninit<f64>]` would be the sound alternative.
    let values_slice: &mut [f64] =
        unsafe { core::slice::from_raw_parts_mut(values_ptr, buf_guard.len()) };

    let chosen = match kern {
        Kernel::Auto => Kernel::Scalar,
        other => other,
    };

    // Fills one row completely. The NaN prefix covers `[0, warmup_end)` (the whole row
    // when `warmup_end >= cols`). Because `first + period <= cols` is guaranteed above,
    // the row kernel then writes every index from `warmup_end` to the end of the row.
    let do_row = |row: usize, out_row: &mut [f64]| {
        let period = combos[row].period.unwrap();

        let warmup_end = first + 2 * period - 2;
        for i in 0..warmup_end.min(out_row.len()) {
            out_row[i] = f64::NAN;
        }
        match chosen {
            Kernel::Scalar => mean_ad_row_scalar(data, first, period, out_row),
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx2 => mean_ad_row_avx2(data, first, period, out_row),
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx512 => mean_ad_row_avx512(data, first, period, out_row),
            // Kernels without a compiled-in row implementation fall back to scalar.
            _ => mean_ad_row_scalar(data, first, period, out_row),
        }
    };
    if parallel {
        #[cfg(not(target_arch = "wasm32"))]
        {
            values_slice
                .par_chunks_mut(cols)
                .enumerate()
                .for_each(|(row, slice)| do_row(row, slice));
        }

        #[cfg(target_arch = "wasm32")]
        {
            for (row, slice) in values_slice.chunks_mut(cols).enumerate() {
                do_row(row, slice);
            }
        }
    } else {
        for (row, slice) in values_slice.chunks_mut(cols).enumerate() {
            do_row(row, slice);
        }
    }

    // SAFETY: every element has now been written by `do_row`. Pointer, length and
    // capacity are taken unchanged from the original allocation, whose element type has
    // the same size and alignment as `f64`.
    let values = unsafe {
        Vec::from_raw_parts(
            buf_guard.as_mut_ptr() as *mut f64,
            buf_guard.len(),
            buf_guard.capacity(),
        )
    };

    Ok(MeanAdBatchOutput {
        values,
        combos,
        rows,
        cols,
    })
}
733
/// Writes the rolling mean absolute deviation of `data` into `out` for one period.
///
/// The residual at `t` is `|data[t] - SMA(t)|`, where `SMA(t)` is the mean of the
/// `period` samples ending at `t`. `out[t]` is the mean of the last `period` residuals.
///
/// Only indices `first + 2 * period - 2 .. data.len()` are written; earlier positions are
/// left untouched so the caller can pre-fill them (e.g. with NaN). Whenever that range is
/// non-empty, `out` must be at least `data.len()` long.
///
/// Nothing is written when `period == 0`, or when fewer than `period` samples follow
/// `first`. A zero period previously underflowed `first + period - 1`.
#[inline(always)]
pub fn mean_ad_row_scalar(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    let n = data.len();
    if period == 0 || first + period > n {
        return;
    }

    let inv_p = 1.0f64 / (period as f64);

    // Seed the running window sum with the first `period` valid samples.
    let mut sum = 0.0f64;
    for &x in &data[first..first + period] {
        sum += x;
    }
    let mut sma = sum * inv_p;

    // Ring buffer of the most recent `period` residuals and their running sum.
    let mut residual_buffer = vec![0.0f64; period];
    let mut buffer_index = 0usize;
    let mut residual_sum = 0.0f64;

    // Phase 1: collect the first `period` residuals (no output yet). After each step, the
    // SMA slides forward by one sample.
    let start_t = first + period - 1;
    let fill_t_end = (start_t + period - 1).min(n.saturating_sub(1));
    for t in start_t..=fill_t_end {
        let residual = (data[t] - sma).abs();
        residual_buffer[buffer_index] = residual;
        residual_sum += residual;
        buffer_index += 1;
        if buffer_index == period {
            buffer_index = 0;
        }
        if t + 1 < n {
            sum += data[t + 1] - data[t + 1 - period];
            sma = sum * inv_p;
        }
    }

    let first_output = first + (period << 1) - 2;
    if first_output < n {
        out[first_output] = residual_sum * inv_p;
    }

    // Phase 2: each new residual replaces the oldest one in the ring.
    let mut t = start_t + period;
    while t < n {
        let residual = (data[t] - sma).abs();

        let old = residual_buffer[buffer_index];
        residual_sum += residual - old;
        residual_buffer[buffer_index] = residual;
        buffer_index += 1;
        if buffer_index == period {
            buffer_index = 0;
        }

        out[t] = residual_sum * inv_p;

        if t + 1 < n {
            sum += data[t + 1] - data[t + 1 - period];
            sma = sum * inv_p;
        }
        t += 1;
    }
}
795
/// AVX2 row kernel; not yet vectorised, so it forwards to `mean_ad_row_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn mean_ad_row_avx2(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    let scalar_row = mean_ad_row_scalar;
    scalar_row(data, first, period, out);
}

/// AVX-512 row kernel: periods up to 32 take the "short" variant, longer ones the "long" variant.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn mean_ad_row_avx512(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    match period {
        0..=32 => mean_ad_row_avx512_short(data, first, period, out),
        _ => mean_ad_row_avx512_long(data, first, period, out),
    }
}

/// Short-period AVX-512 row kernel; currently forwards to `mean_ad_row_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn mean_ad_row_avx512_short(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    let scalar_row = mean_ad_row_scalar;
    scalar_row(data, first, period, out);
}

/// Long-period AVX-512 row kernel; currently forwards to `mean_ad_row_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn mean_ad_row_avx512_long(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    let scalar_row = mean_ad_row_scalar;
    scalar_row(data, first, period, out);
}
823
/// Incremental (one sample at a time) mean absolute deviation.
///
/// Its first value arrives on the `2 * period - 1`-th update, the same position as the
/// batch output when the series has no leading NaNs.
#[derive(Debug, Clone)]
pub struct MeanAdStream {
    /// Window length.
    period: usize,
    /// Ring of the last `period` input samples.
    buffer: Vec<f64>,
    /// Next write slot in `buffer`.
    head: usize,
    /// Set once `buffer` has received `period` samples.
    filled: bool,
    /// Ring of the last `period` absolute residuals (despite the name, not means).
    mean_buffer: Vec<f64>,
    /// Next write slot in `mean_buffer`.
    mean_head: usize,
    /// Set once `mean_buffer` has received `period` residuals.
    mean_filled: bool,
    /// Running SUM of the price window (divided by `period` when used), not the mean itself.
    mean: f64,
    /// Running SUM of the residual window; the emitted value is `mad / period`.
    mad: f64,
}
836
impl MeanAdStream {
    /// Creates an empty stream; a `None` period defaults to 5.
    ///
    /// # Errors
    /// `InvalidPeriod` when the period is 0.
    pub fn try_new(params: MeanAdParams) -> Result<Self, MeanAdError> {
        let period = params.period.unwrap_or(5);
        if period == 0 {
            return Err(MeanAdError::InvalidPeriod {
                period,
                data_len: 0,
            });
        }
        Ok(Self {
            period,
            buffer: vec![f64::NAN; period],
            head: 0,
            filled: false,
            mean_buffer: vec![f64::NAN; period],
            mean_head: 0,
            mean_filled: false,
            mean: 0.0,
            mad: 0.0,
        })
    }
    /// Pushes one sample and returns the current mean absolute deviation once available.
    ///
    /// Returns `None` for the first `2 * period - 2` updates, while the price window and
    /// then the residual window fill up. No NaN filtering is done here: a NaN input enters
    /// the running sums and keeps the output NaN from then on.
    #[inline(always)]
    pub fn update(&mut self, value: f64) -> Option<f64> {
        let p = self.period;
        let inv_p = 1.0f64 / (p as f64);

        // Read the slots about to be overwritten before writing the new sample.
        let price_idx = self.head;
        let old_x = self.buffer[price_idx];

        let resid_idx = self.mean_head;
        let old_r = self.mean_buffer[resid_idx];

        self.buffer[price_idx] = value;
        let next_price_idx = price_idx + 1;
        let wrapped_prices = next_price_idx == p;
        self.head = if wrapped_prices { 0 } else { next_price_idx };

        // The `period`-th sample completes the first price window.
        let just_filled_prices = !self.filled && wrapped_prices;
        if just_filled_prices {
            self.filled = true;
        }

        // Still warming up the price window: accumulate the sum only.
        if !self.filled {
            self.mean += value;
            return None;
        }

        // On the completing update there is no sample to evict (`old_x` is the initial NaN).
        let sum_t = if just_filled_prices {
            self.mean + value
        } else {
            self.mean + value - old_x
        };
        let mean_t = sum_t * inv_p;

        self.mean = sum_t;

        let resid_t = (value - mean_t).abs();

        self.mean_buffer[resid_idx] = resid_t;
        let next_resid_idx = resid_idx + 1;
        let wrapped_resids = next_resid_idx == p;
        self.mean_head = if wrapped_resids { 0 } else { next_resid_idx };

        // Warming up the residual window: accumulate until `period` residuals exist, then
        // emit the first value.
        if !self.mean_filled {
            self.mad += resid_t;
            if wrapped_resids {
                self.mean_filled = true;
                return Some(self.mad * inv_p);
            }
            return None;
        }

        // Steady state: swap the oldest residual for the newest.
        self.mad = self.mad + resid_t - old_r;
        Some(self.mad * inv_p)
    }
}
913
914#[cfg(test)]
915mod tests {
916 use super::*;
917 use crate::skip_if_unsupported;
918 use crate::utilities::data_loader::read_candles_from_csv;
919 use paste::paste;
920
921 fn check_mean_ad_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
922 skip_if_unsupported!(kernel, test_name);
923 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
924 let candles = read_candles_from_csv(file_path)?;
925 let default_params = MeanAdParams { period: None };
926 let input = MeanAdInput::from_candles(&candles, "close", default_params);
927 let output = mean_ad_with_kernel(&input, kernel)?;
928 assert_eq!(output.values.len(), candles.close.len());
929 Ok(())
930 }
931
932 fn check_mean_ad_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
933 skip_if_unsupported!(kernel, test_name);
934 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
935 let candles = read_candles_from_csv(file_path)?;
936 let input = MeanAdInput::from_candles(&candles, "hl2", MeanAdParams { period: Some(5) });
937 let result = mean_ad_with_kernel(&input, kernel)?;
938 assert_eq!(result.values.len(), candles.close.len());
939 let expected_last_five = [
940 199.71999999999971,
941 104.14000000000087,
942 133.4,
943 100.54000000000087,
944 117.98000000000029,
945 ];
946 let start = result.values.len().saturating_sub(5);
947 for (i, &val) in result.values[start..].iter().enumerate() {
948 let diff = (val - expected_last_five[i]).abs();
949 assert!(
950 diff < 1e-1,
951 "[{}] MeanAd {:?} mismatch at idx {}: got {}, expected {}",
952 test_name,
953 kernel,
954 i,
955 val,
956 expected_last_five[i]
957 );
958 }
959 Ok(())
960 }
961
962 fn check_mean_ad_default_candles(
963 test_name: &str,
964 kernel: Kernel,
965 ) -> Result<(), Box<dyn Error>> {
966 skip_if_unsupported!(kernel, test_name);
967 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
968 let candles = read_candles_from_csv(file_path)?;
969 let input = MeanAdInput::with_default_candles(&candles);
970 match input.data {
971 MeanAdData::Candles { source, .. } => assert_eq!(source, "close"),
972 _ => panic!("Expected MeanAdData::Candles"),
973 }
974 let output = mean_ad_with_kernel(&input, kernel)?;
975 assert_eq!(output.values.len(), candles.close.len());
976 Ok(())
977 }
978
979 fn check_mean_ad_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
980 skip_if_unsupported!(kernel, test_name);
981 let input_data = [10.0, 20.0, 30.0];
982 let params = MeanAdParams { period: Some(0) };
983 let input = MeanAdInput::from_slice(&input_data, params);
984 let res = mean_ad_with_kernel(&input, kernel);
985 assert!(
986 res.is_err(),
987 "[{}] MeanAd should fail with zero period",
988 test_name
989 );
990 Ok(())
991 }
992
993 fn check_mean_ad_period_exceeds_length(
994 test_name: &str,
995 kernel: Kernel,
996 ) -> Result<(), Box<dyn Error>> {
997 skip_if_unsupported!(kernel, test_name);
998 let data_small = [10.0, 20.0, 30.0];
999 let params = MeanAdParams { period: Some(10) };
1000 let input = MeanAdInput::from_slice(&data_small, params);
1001 let res = mean_ad_with_kernel(&input, kernel);
1002 assert!(
1003 res.is_err(),
1004 "[{}] MeanAd should fail with period exceeding length",
1005 test_name
1006 );
1007 Ok(())
1008 }
1009
1010 fn check_mean_ad_very_small_dataset(
1011 test_name: &str,
1012 kernel: Kernel,
1013 ) -> Result<(), Box<dyn Error>> {
1014 skip_if_unsupported!(kernel, test_name);
1015 let single_point = [42.0];
1016 let params = MeanAdParams { period: Some(5) };
1017 let input = MeanAdInput::from_slice(&single_point, params);
1018 let res = mean_ad_with_kernel(&input, kernel);
1019 assert!(
1020 res.is_err(),
1021 "[{}] MeanAd should fail with insufficient data",
1022 test_name
1023 );
1024 Ok(())
1025 }
1026
1027 fn check_mean_ad_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1028 skip_if_unsupported!(kernel, test_name);
1029 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1030 let candles = read_candles_from_csv(file_path)?;
1031 let first_params = MeanAdParams { period: Some(5) };
1032 let first_input = MeanAdInput::from_candles(&candles, "close", first_params);
1033 let first_result = mean_ad_with_kernel(&first_input, kernel)?;
1034 let params2 = MeanAdParams { period: Some(3) };
1035 let second_input = MeanAdInput::from_slice(&first_result.values, params2);
1036 let second_result = mean_ad_with_kernel(&second_input, kernel)?;
1037 assert_eq!(second_result.values.len(), first_result.values.len());
1038 Ok(())
1039 }
1040
1041 fn check_mean_ad_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1042 skip_if_unsupported!(kernel, test_name);
1043 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1044 let candles = read_candles_from_csv(file_path)?;
1045 let input = MeanAdInput::from_candles(&candles, "close", MeanAdParams { period: Some(5) });
1046 let res = mean_ad_with_kernel(&input, kernel)?;
1047 assert_eq!(res.values.len(), candles.close.len());
1048 if res.values.len() > 240 {
1049 for (i, &val) in res.values[240..].iter().enumerate() {
1050 assert!(
1051 !val.is_nan(),
1052 "[{}] Found unexpected NaN at out-index {}",
1053 test_name,
1054 240 + i
1055 );
1056 }
1057 }
1058 Ok(())
1059 }
1060
1061 fn check_mean_ad_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1062 skip_if_unsupported!(kernel, test_name);
1063 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1064 let candles = read_candles_from_csv(file_path)?;
1065 let period = 5;
1066 let input = MeanAdInput::from_candles(
1067 &candles,
1068 "close",
1069 MeanAdParams {
1070 period: Some(period),
1071 },
1072 );
1073 let batch_output = mean_ad_with_kernel(&input, kernel)?.values;
1074 let mut stream = MeanAdStream::try_new(MeanAdParams {
1075 period: Some(period),
1076 })?;
1077
1078 let mut stream_uninit: Vec<MaybeUninit<f64>> = Vec::with_capacity(candles.close.len());
1079 unsafe {
1080 stream_uninit.set_len(candles.close.len());
1081 }
1082
1083 for (i, &price) in candles.close.iter().enumerate() {
1084 let val = match stream.update(price) {
1085 Some(mean_ad_val) => mean_ad_val,
1086 None => f64::NAN,
1087 };
1088 stream_uninit[i] = MaybeUninit::new(val);
1089 }
1090
1091 let stream_values = unsafe {
1092 let ptr = stream_uninit.as_mut_ptr() as *mut f64;
1093 let len = stream_uninit.len();
1094 let cap = stream_uninit.capacity();
1095 std::mem::forget(stream_uninit);
1096 Vec::from_raw_parts(ptr, len, cap)
1097 };
1098 assert_eq!(batch_output.len(), stream_values.len());
1099 for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
1100 if b.is_nan() && s.is_nan() {
1101 continue;
1102 }
1103 let diff = (b - s).abs();
1104 assert!(
1105 diff < 1e-9,
1106 "[{}] MeanAd streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
1107 test_name,
1108 i,
1109 b,
1110 s,
1111 diff
1112 );
1113 }
1114 Ok(())
1115 }
1116
1117 #[cfg(debug_assertions)]
1118 fn check_mean_ad_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1119 skip_if_unsupported!(kernel, test_name);
1120
1121 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1122 let candles = read_candles_from_csv(file_path)?;
1123
1124 let test_params = vec![
1125 MeanAdParams::default(),
1126 MeanAdParams { period: Some(2) },
1127 MeanAdParams { period: Some(3) },
1128 MeanAdParams { period: Some(5) },
1129 MeanAdParams { period: Some(7) },
1130 MeanAdParams { period: Some(10) },
1131 MeanAdParams { period: Some(14) },
1132 MeanAdParams { period: Some(20) },
1133 MeanAdParams { period: Some(30) },
1134 MeanAdParams { period: Some(50) },
1135 MeanAdParams { period: Some(100) },
1136 MeanAdParams { period: Some(200) },
1137 ];
1138
1139 for (param_idx, params) in test_params.iter().enumerate() {
1140 let input = MeanAdInput::from_candles(&candles, "close", params.clone());
1141 let output = mean_ad_with_kernel(&input, kernel)?;
1142
1143 for (i, &val) in output.values.iter().enumerate() {
1144 if val.is_nan() {
1145 continue;
1146 }
1147
1148 let bits = val.to_bits();
1149
1150 if bits == 0x11111111_11111111 {
1151 panic!(
1152 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1153 with params: period={} (param set {})",
1154 test_name,
1155 val,
1156 bits,
1157 i,
1158 params.period.unwrap_or(5),
1159 param_idx
1160 );
1161 }
1162
1163 if bits == 0x22222222_22222222 {
1164 panic!(
1165 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1166 with params: period={} (param set {})",
1167 test_name,
1168 val,
1169 bits,
1170 i,
1171 params.period.unwrap_or(5),
1172 param_idx
1173 );
1174 }
1175
1176 if bits == 0x33333333_33333333 {
1177 panic!(
1178 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1179 with params: period={} (param set {})",
1180 test_name,
1181 val,
1182 bits,
1183 i,
1184 params.period.unwrap_or(5),
1185 param_idx
1186 );
1187 }
1188 }
1189 }
1190
1191 Ok(())
1192 }
1193
1194 #[cfg(not(debug_assertions))]
1195 fn check_mean_ad_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1196 Ok(())
1197 }
1198
1199 macro_rules! generate_all_mean_ad_tests {
1200 ($($test_fn:ident),*) => {
1201 paste! {
1202 $(
1203 #[test]
1204 fn [<$test_fn _scalar_f64>]() {
1205 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1206 }
1207 )*
1208 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1209 $(
1210 #[test]
1211 fn [<$test_fn _avx2_f64>]() {
1212 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1213 }
1214 #[test]
1215 fn [<$test_fn _avx512_f64>]() {
1216 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1217 }
1218 )*
1219 }
1220 }
1221 }
1222 #[cfg(feature = "proptest")]
1223 #[allow(clippy::float_cmp)]
1224 fn check_mean_ad_property(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1225 use proptest::prelude::*;
1226 skip_if_unsupported!(kernel, test_name);
1227
1228 let strat = (2usize..=64)
1229 .prop_flat_map(|period| {
1230 (
1231 prop::collection::vec(
1232 (10.0f64..1000.0f64).prop_filter("finite", |x| x.is_finite()),
1233 period..400,
1234 ),
1235 Just(period),
1236 prop::bool::weighted(0.1),
1237 )
1238 })
1239 .prop_map(|(mut data, period, make_constant)| {
1240 if make_constant && data.len() > 0 {
1241 let constant_val = data[0];
1242 data.iter_mut().for_each(|v| *v = constant_val);
1243 }
1244 (data, period)
1245 });
1246
1247 proptest::test_runner::TestRunner::default().run(&strat, |(data, period)| {
1248 let params = MeanAdParams {
1249 period: Some(period),
1250 };
1251 let input = MeanAdInput::from_slice(&data, params);
1252
1253 let MeanAdOutput { values: out } = mean_ad_with_kernel(&input, kernel)?;
1254 let MeanAdOutput { values: ref_out } = mean_ad_with_kernel(&input, Kernel::Scalar)?;
1255
1256 prop_assert_eq!(
1257 out.len(),
1258 data.len(),
1259 "[{}] Output length mismatch",
1260 test_name
1261 );
1262
1263 let first_valid = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
1264 let warmup_period = first_valid + 2 * period - 2;
1265
1266 for i in 0..warmup_period.min(out.len()) {
1267 prop_assert!(
1268 out[i].is_nan(),
1269 "[{}] Expected NaN during warmup at index {}, got {}",
1270 test_name,
1271 i,
1272 out[i]
1273 );
1274 }
1275
1276 for i in warmup_period..out.len() {
1277 if !out[i].is_nan() {
1278 prop_assert!(
1279 out[i] >= -1e-10,
1280 "[{}] MAD should be non-negative at index {}: got {}",
1281 test_name,
1282 i,
1283 out[i]
1284 );
1285 }
1286 }
1287
1288 for i in 0..out.len() {
1289 let y = out[i];
1290 let r = ref_out[i];
1291
1292 if y.is_nan() || r.is_nan() {
1293 prop_assert!(
1294 y.is_nan() && r.is_nan(),
1295 "[{}] NaN mismatch at index {}: kernel={}, scalar={}",
1296 test_name,
1297 i,
1298 y,
1299 r
1300 );
1301 continue;
1302 }
1303
1304 let y_bits = y.to_bits();
1305 let r_bits = r.to_bits();
1306 let ulp_diff = y_bits.abs_diff(r_bits);
1307
1308 prop_assert!(
1309 (y - r).abs() <= 1e-9 || ulp_diff <= 5,
1310 "[{}] Kernel mismatch at index {}: kernel={}, scalar={}, diff={}, ULP={}",
1311 test_name,
1312 i,
1313 y,
1314 r,
1315 (y - r).abs(),
1316 ulp_diff
1317 );
1318 }
1319
1320 let is_constant = data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10);
1321 if is_constant && out.len() > warmup_period {
1322 for i in warmup_period..out.len() {
1323 if !out[i].is_nan() {
1324 prop_assert!(
1325 out[i].abs() <= 1e-9,
1326 "[{}] MAD should be ~0 for constant data at index {}: got {}",
1327 test_name,
1328 i,
1329 out[i]
1330 );
1331 }
1332 }
1333 }
1334
1335 let is_linear_monotonic = if data.len() >= 3 {
1336 let diffs: Vec<f64> = data.windows(2).map(|w| w[1] - w[0]).collect();
1337 let first_diff = diffs[0];
1338 diffs.iter().all(|&d| (d - first_diff).abs() < 1e-9)
1339 } else {
1340 false
1341 };
1342
1343 if is_linear_monotonic && out.len() > warmup_period + period {
1344 for i in (warmup_period + 1)..out.len() {
1345 if !out[i].is_nan() && !out[i - 1].is_nan() && out[i - 1] > 1e-10 {
1346 let change_ratio = (out[i] - out[i - 1]).abs() / out[i - 1];
1347
1348 prop_assert!(
1349 change_ratio <= 0.1,
1350 "[{}] MAD changes too much for linear data at index {}: {} -> {} ({:.2}% change)",
1351 test_name,
1352 i,
1353 out[i-1],
1354 out[i],
1355 change_ratio * 100.0
1356 );
1357 }
1358 }
1359 }
1360
1361 for i in warmup_period..out.len() {
1362 if out[i].is_nan() || i < period {
1363 continue;
1364 }
1365
1366 let window_start = i + 1 - period;
1367 let window = &data[window_start..=i];
1368
1369 let window_min = window.iter().cloned().fold(f64::INFINITY, f64::min);
1370 let window_max = window.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
1371 let window_range = window_max - window_min;
1372
1373 prop_assert!(
1374 out[i] <= window_range / 2.0 + 1e-9,
1375 "[{}] MAD exceeds half window range at index {}: MAD={}, window_range/2={}",
1376 test_name,
1377 i,
1378 out[i],
1379 window_range / 2.0
1380 );
1381 }
1382
1383 if period == data.len() && out.len() > warmup_period {
1384 let non_nan_count = out.iter().filter(|&&v| !v.is_nan()).count();
1385 prop_assert!(
1386 non_nan_count <= 1,
1387 "[{}] With period={}, expected at most 1 non-NaN value, got {}",
1388 test_name,
1389 period,
1390 non_nan_count
1391 );
1392 }
1393
1394 if period == 2 && out.len() > warmup_period {
1395 for i in warmup_period..out.len() {
1396 if !out[i].is_nan() && i >= 1 {
1397 let expected = (data[i] - data[i - 1]).abs() / 2.0;
1398 prop_assert!(
1399 (out[i] - expected).abs() <= 1e-9,
1400 "[{}] For period=2 at index {}, MAD should be {}, got {}",
1401 test_name,
1402 i,
1403 expected,
1404 out[i]
1405 );
1406 }
1407 }
1408 }
1409
1410 Ok(())
1411 })?;
1412
1413 Ok(())
1414 }
1415
    // Per-kernel `#[test]` wrappers for every single-series check defined above.
    generate_all_mean_ad_tests!(
        check_mean_ad_partial_params,
        check_mean_ad_accuracy,
        check_mean_ad_default_candles,
        check_mean_ad_zero_period,
        check_mean_ad_period_exceeds_length,
        check_mean_ad_very_small_dataset,
        check_mean_ad_reinput,
        check_mean_ad_nan_handling,
        check_mean_ad_streaming,
        check_mean_ad_no_poison
    );

    // The property test is only compiled with the optional `proptest` feature.
    #[cfg(feature = "proptest")]
    generate_all_mean_ad_tests!(check_mean_ad_property);
1431
1432 fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1433 skip_if_unsupported!(kernel, test);
1434
1435 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1436 let c = read_candles_from_csv(file)?;
1437 let output = MeanAdBatchBuilder::new()
1438 .kernel(kernel)
1439 .apply_candles(&c, "hl2")?;
1440
1441 let def = MeanAdParams::default();
1442 let row = output.values_for(&def).expect("default row missing");
1443
1444 assert_eq!(row.len(), c.close.len());
1445
1446 let expected = [
1447 199.71999999999971,
1448 104.14000000000087,
1449 133.4,
1450 100.54000000000087,
1451 117.98000000000029,
1452 ];
1453 let start = row.len() - 5;
1454 for (i, &v) in row[start..].iter().enumerate() {
1455 assert!(
1456 (v - expected[i]).abs() < 1e-1,
1457 "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
1458 );
1459 }
1460 Ok(())
1461 }
1462
1463 #[cfg(debug_assertions)]
1464 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1465 skip_if_unsupported!(kernel, test);
1466
1467 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1468 let c = read_candles_from_csv(file)?;
1469
1470 let test_configs = vec![
1471 (2, 10, 2),
1472 (5, 25, 5),
1473 (30, 60, 15),
1474 (2, 5, 1),
1475 (10, 50, 10),
1476 (3, 15, 3),
1477 (20, 30, 2),
1478 (7, 21, 7),
1479 (100, 200, 50),
1480 ];
1481
1482 for (cfg_idx, &(p_start, p_end, p_step)) in test_configs.iter().enumerate() {
1483 let output = MeanAdBatchBuilder::new()
1484 .kernel(kernel)
1485 .period_range(p_start, p_end, p_step)
1486 .apply_candles(&c, "close")?;
1487
1488 for (idx, &val) in output.values.iter().enumerate() {
1489 if val.is_nan() {
1490 continue;
1491 }
1492
1493 let bits = val.to_bits();
1494 let row = idx / output.cols;
1495 let col = idx % output.cols;
1496 let combo = &output.combos[row];
1497
1498 if bits == 0x11111111_11111111 {
1499 panic!(
1500 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
1501 at row {} col {} (flat index {}) with params: period={}",
1502 test,
1503 cfg_idx,
1504 val,
1505 bits,
1506 row,
1507 col,
1508 idx,
1509 combo.period.unwrap_or(5)
1510 );
1511 }
1512
1513 if bits == 0x22222222_22222222 {
1514 panic!(
1515 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
1516 at row {} col {} (flat index {}) with params: period={}",
1517 test,
1518 cfg_idx,
1519 val,
1520 bits,
1521 row,
1522 col,
1523 idx,
1524 combo.period.unwrap_or(5)
1525 );
1526 }
1527
1528 if bits == 0x33333333_33333333 {
1529 panic!(
1530 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
1531 at row {} col {} (flat index {}) with params: period={}",
1532 test,
1533 cfg_idx,
1534 val,
1535 bits,
1536 row,
1537 col,
1538 idx,
1539 combo.period.unwrap_or(5)
1540 );
1541 }
1542 }
1543 }
1544
1545 Ok(())
1546 }
1547
1548 #[cfg(not(debug_assertions))]
1549 fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1550 Ok(())
1551 }
1552
1553 #[test]
1554 fn test_mean_ad_into_matches_api() -> Result<(), Box<dyn Error>> {
1555 let n = 256usize;
1556 let mut data = Vec::with_capacity(n);
1557 data.push(f64::NAN);
1558 data.push(f64::NAN);
1559 data.push(f64::NAN);
1560 for i in 3..n {
1561 let t = i as f64;
1562 let v = (t * 0.01).mul_add((t * 0.1).sin() * 100.0, (t * 0.05).cos() * 0.5);
1563 data.push(v);
1564 }
1565
1566 let input = MeanAdInput::from_slice(&data, MeanAdParams::default());
1567 let baseline = mean_ad(&input)?.values;
1568
1569 let mut into_out = vec![0.0; baseline.len()];
1570 #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1571 {
1572 mean_ad_into(&input, &mut into_out)?;
1573 }
1574 #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1575 {
1576 mean_ad_into_slice(&mut into_out, &input, Kernel::Auto)?;
1577 }
1578
1579 assert_eq!(baseline.len(), into_out.len());
1580 #[inline]
1581 fn eq_or_both_nan(a: f64, b: f64) -> bool {
1582 (a.is_nan() && b.is_nan()) || (a == b) || (a - b).abs() <= 1e-12
1583 }
1584 for (i, (a, b)) in baseline.iter().zip(into_out.iter()).enumerate() {
1585 assert!(
1586 eq_or_both_nan(*a, *b),
1587 "divergence at idx {}: api={}, into={}",
1588 i,
1589 a,
1590 b
1591 );
1592 }
1593 Ok(())
1594 }
1595
1596 macro_rules! gen_batch_tests {
1597 ($fn_name:ident) => {
1598 paste! {
1599 #[test] fn [<$fn_name _scalar>]() {
1600 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
1601 }
1602 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1603 #[test] fn [<$fn_name _avx2>]() {
1604 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
1605 }
1606 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1607 #[test] fn [<$fn_name _avx512>]() {
1608 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
1609 }
1610 #[test] fn [<$fn_name _auto_detect>]() {
1611 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
1612 }
1613 }
1614 };
1615 }
1616 gen_batch_tests!(check_batch_default_row);
1617 gen_batch_tests!(check_batch_no_poison);
1618}
1619
// Python binding: computes MeanAd over a 1-D float64 NumPy array with the given
// period and optional kernel name; any computation error becomes ValueError.
#[cfg(feature = "python")]
#[pyfunction(name = "mean_ad")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn mean_ad_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    // Borrow the NumPy buffer before validating the kernel name (error precedence).
    let values = data.as_slice()?;
    let chosen = validate_kernel(kernel, false)?;

    let input = MeanAdInput::from_slice(
        values,
        MeanAdParams {
            period: Some(period),
        },
    );

    // The numeric work runs with the GIL released.
    let computed = py.allow_threads(|| mean_ad_with_kernel(&input, chosen));
    match computed {
        Ok(output) => Ok(output.values.into_pyarray(py)),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
1643
// Python binding for a period sweep: returns a dict with "values" as a
// (rows, cols) float64 matrix (one row per period) and "periods" as uint64.
#[cfg(feature = "python")]
#[pyfunction(name = "mean_ad_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn mean_ad_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let input = data.as_slice()?;
    let requested = validate_kernel(kernel, true)?;

    let sweep = MeanAdBatchRange {
        period: period_range,
    };

    // Shape the output from the expanded grid before doing any numeric work.
    let rows = expand_grid(&sweep)
        .map_err(|e| PyValueError::new_err(e.to_string()))?
        .len();
    let cols = input.len();
    let total = match rows.checked_mul(cols) {
        Some(count) => count,
        None => return Err(PyValueError::new_err("mean_ad_batch_py: size overflow")),
    };

    // SAFETY: the array is created uninitialized. NOTE(review): soundness relies
    // on `mean_ad_batch_inner_into` writing every element before Python can see
    // the array — confirm in that function.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let out_slice = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            // Resolve `Auto`, then map the batch kernel to its per-row kernel.
            let batch = match requested {
                Kernel::Auto => detect_best_batch_kernel(),
                explicit => explicit,
            };
            let row_kernel = match batch {
                Kernel::ScalarBatch => Kernel::Scalar,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::Avx512Batch => Kernel::Avx512,
                _ => unreachable!(),
            };
            mean_ad_batch_inner_into(input, &sweep, row_kernel, true, out_slice)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    let periods: Vec<u64> = combos.iter().map(|c| c.period.unwrap() as u64).collect();
    dict.set_item("periods", periods.into_pyarray(py))?;

    Ok(dict)
}
1699
// Python binding: runs a MeanAd period sweep on the given CUDA device over f32
// input and returns the device-resident result wrapper.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "mean_ad_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, device_id=0))]
pub fn mean_ad_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    data_f32: numpy::PyReadonlyArray1<'py, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    if !crate::cuda::cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let host = data_f32.as_slice()?;
    let sweep = MeanAdBatchRange {
        period: period_range,
    };

    // Device setup and launch happen with the GIL released.
    let device_buf = py.allow_threads(|| -> PyResult<_> {
        let ctx = CudaMeanAd::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        ctx.mean_ad_batch_dev(host, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    make_device_array_py(device_id, device_buf)
}
1723
// Python binding: runs one MeanAd parameter set over many series on the given
// CUDA device. The input is a flat f32 buffer described by (cols, rows); the
// callee name indicates a time-major layout (TODO confirm against CudaMeanAd).
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "mean_ad_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, cols, rows, period, device_id=0))]
pub fn mean_ad_cuda_many_series_one_param_dev_py<'py>(
    py: Python<'py>,
    data_tm_f32: numpy::PyReadonlyArray1<'py, f32>,
    cols: usize,
    rows: usize,
    period: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    if !crate::cuda::cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let slice_in = data_tm_f32.as_slice()?;
    let params = MeanAdParams {
        period: Some(period),
    };
    // Device setup and launch happen with the GIL released.
    let inner = py.allow_threads(|| {
        let cuda = CudaMeanAd::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        cuda.mean_ad_many_series_one_param_time_major_dev(slice_in, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;
    make_device_array_py(device_id, inner)
}
1749
// Python wrapper around the incremental `MeanAdStream`.
#[cfg(feature = "python")]
#[pyclass(name = "MeanAdStream")]
pub struct MeanAdStreamPy {
    // Underlying Rust streaming state.
    stream: MeanAdStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl MeanAdStreamPy {
    // Builds a stream for `period`; invalid parameters raise ValueError.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        MeanAdStream::try_new(MeanAdParams {
            period: Some(period),
        })
        .map(|stream| MeanAdStreamPy { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    // Feeds one value; forwards the inner stream's `Option` result unchanged.
    fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
1773
1774pub fn mean_ad_into_slice(
1775 dst: &mut [f64],
1776 input: &MeanAdInput,
1777 kern: Kernel,
1778) -> Result<(), MeanAdError> {
1779 let data = input.as_ref();
1780 let period = match input.params.period {
1781 Some(p) if p > 0 => p,
1782 _ => {
1783 return Err(MeanAdError::InvalidPeriod {
1784 period: input.params.period.unwrap_or(0),
1785 data_len: data.len(),
1786 })
1787 }
1788 };
1789
1790 if data.is_empty() {
1791 return Err(MeanAdError::EmptyInputData);
1792 }
1793 if period > data.len() {
1794 return Err(MeanAdError::InvalidPeriod {
1795 period,
1796 data_len: data.len(),
1797 });
1798 }
1799 if dst.len() != data.len() {
1800 return Err(MeanAdError::OutputLengthMismatch {
1801 expected: data.len(),
1802 got: dst.len(),
1803 });
1804 }
1805
1806 let first = data
1807 .iter()
1808 .position(|x| !x.is_nan())
1809 .ok_or(MeanAdError::AllValuesNaN)?;
1810
1811 if (data.len() - first) < period {
1812 return Err(MeanAdError::NotEnoughValidData {
1813 needed: period,
1814 valid: data.len() - first,
1815 });
1816 }
1817
1818 let chosen = match kern {
1819 Kernel::Auto => Kernel::Scalar,
1820 other => other,
1821 };
1822
1823 let warmup_end = first + (period << 1) - 2;
1824 let warmup_end = warmup_end.min(dst.len());
1825 if warmup_end > 0 {
1826 dst[..warmup_end].fill(f64::NAN);
1827 }
1828
1829 match chosen {
1830 Kernel::Scalar | Kernel::ScalarBatch => mean_ad_row_scalar(data, first, period, dst),
1831 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1832 Kernel::Avx2 | Kernel::Avx2Batch => mean_ad_row_avx2(data, first, period, dst),
1833 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1834 Kernel::Avx512 | Kernel::Avx512Batch => mean_ad_row_avx512(data, first, period, dst),
1835 _ => unreachable!(),
1836 }
1837
1838 Ok(())
1839}
1840
/// WASM binding: computes MeanAd over `data` and returns a new array; errors
/// are surfaced as JS string values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn mean_ad_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = MeanAdInput::from_slice(
        data,
        MeanAdParams {
            period: Some(period),
        },
    );

    let mut values = vec![0.0; data.len()];
    match mean_ad_into_slice(&mut values, &input, Kernel::Auto) {
        Ok(()) => Ok(values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1855
/// WASM raw-pointer binding: reads `len` f64 values at `in_ptr` and writes the
/// MeanAd series to `len` slots at `out_ptr`.
///
/// If the input and output byte ranges overlap at all (exact aliasing or a
/// partial overlap), the result is computed in a scratch buffer and copied
/// out. `mean_ad_into_slice` fills the NaN warmup prefix of its output before
/// the row kernel reads the input, so writing directly into an overlapping
/// buffer would corrupt input values that are still needed.
///
/// Both pointers must be non-null and reference `len` valid, aligned f64 slots;
/// the input slots must be initialized.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn mean_ad_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to mean_ad_into"));
    }

    // Byte-range overlap test; saturating math keeps bogus lengths from wrapping.
    let span = len.saturating_mul(std::mem::size_of::<f64>());
    let in_start = in_ptr as usize;
    let out_start = out_ptr as usize;
    let overlaps =
        in_start < out_start.saturating_add(span) && out_start < in_start.saturating_add(span);

    // SAFETY: caller guarantees `in_ptr` addresses `len` initialized f64 values.
    let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };
    let params = MeanAdParams {
        period: Some(period),
    };
    let input = MeanAdInput::from_slice(data, params);

    if overlaps {
        let mut temp = vec![0.0; len];
        mean_ad_into_slice(&mut temp, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // SAFETY: caller guarantees `out_ptr` addresses `len` writable f64 slots;
        // `data`/`input` are not used after this point.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        out.copy_from_slice(&temp);
    } else {
        // SAFETY: caller guarantees `out_ptr` addresses `len` writable f64 slots,
        // and the ranges were shown above not to overlap.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        mean_ad_into_slice(out, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }
    Ok(())
}
1889
/// WASM helper: allocates an uninitialized buffer with room for `len` f64
/// values and returns its pointer. Release it with `mean_ad_free(ptr, len)`
/// using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn mean_ad_alloc(len: usize) -> *mut f64 {
    let mut vec = Vec::<f64>::with_capacity(len);
    let ptr = vec.as_mut_ptr();
    std::mem::forget(vec);
    ptr
}

/// WASM helper: frees a buffer previously returned by `mean_ad_alloc(len)`.
/// Null pointers are ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn mean_ad_free(ptr: *mut f64, len: usize) {
    if !ptr.is_null() {
        // SAFETY: `ptr` came from `mean_ad_alloc(len)`, whose `Vec::with_capacity`
        // allocation has capacity `len`. The length is 0 because the buffer may
        // never have been fully written, and `from_raw_parts` requires the first
        // `length` elements to be initialized; freeing only needs the capacity.
        unsafe {
            drop(Vec::from_raw_parts(ptr, 0, len));
        }
    }
}
1908
/// Configuration object accepted by the JS `mean_ad_batch` binding.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct MeanAdBatchConfig {
    /// Period sweep forwarded as `MeanAdBatchRange::period`; the tuple order
    /// presumably is (start, end, step) — TODO confirm against `expand_grid`.
    pub period_range: (usize, usize, usize),
}

/// Serialized result of the JS `mean_ad_batch` binding.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct MeanAdBatchJsOutput {
    /// Flat output copied from the batch result's `values` (`rows * cols` entries
    /// assumed row-major — TODO confirm).
    pub values: Vec<f64>,
    /// Period of each output row, in row order.
    pub periods: Vec<usize>,
    /// Number of rows (parameter combinations).
    pub rows: usize,
    /// Number of columns (input length).
    pub cols: usize,
}
1923
/// WASM binding: runs a MeanAd period sweep described by a JS config object
/// and returns `{ values, periods, rows, cols }`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = mean_ad_batch)]
pub fn mean_ad_batch_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let parsed: MeanAdBatchConfig = match serde_wasm_bindgen::from_value(config) {
        Ok(cfg) => cfg,
        Err(e) => return Err(JsValue::from_str(&format!("Invalid config: {}", e))),
    };

    let sweep = MeanAdBatchRange {
        period: parsed.period_range,
    };
    let batch = mean_ad_batch_inner(data, &sweep, Kernel::Auto, false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let periods: Vec<usize> = batch.combos.iter().map(|p| p.period.unwrap()).collect();
    let payload = MeanAdBatchJsOutput {
        values: batch.values,
        periods,
        rows: batch.rows,
        cols: batch.cols,
    };

    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
1947
/// WASM raw-pointer binding for a period sweep: reads `len` f64 values at
/// `in_ptr`, writes `rows * len` values at `out_ptr`, and returns `rows`.
///
/// Both pointers must be non-null and 8-byte aligned; the alignment is checked
/// before any slice is formed, because `slice::from_raw_parts{,_mut}` require
/// aligned pointers. If the input and output byte ranges overlap, the result is
/// computed in a scratch buffer and copied out, since the batch kernel may write
/// output cells before it has finished reading the input.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn mean_ad_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str(
            "null pointer passed to mean_ad_batch_into",
        ));
    }

    let sweep = MeanAdBatchRange {
        period: (period_start, period_end, period_step),
    };

    let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let cols = len;

    if in_ptr as usize % 8 != 0 {
        return Err(JsValue::from_str("Input pointer must be 8-byte aligned"));
    }
    if out_ptr as usize % 8 != 0 {
        return Err(JsValue::from_str("Output pointer must be 8-byte aligned"));
    }

    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("mean_ad_batch_into: size overflow"))?;

    let kernel = detect_best_batch_kernel();
    let simd = match kernel {
        Kernel::Avx512Batch => Kernel::Avx512,
        Kernel::Avx2Batch => Kernel::Avx2,
        Kernel::ScalarBatch => Kernel::Scalar,
        _ => unreachable!(),
    };

    // Byte-range overlap test; saturating math keeps bogus lengths from wrapping.
    let elem = std::mem::size_of::<f64>();
    let in_start = in_ptr as usize;
    let out_start = out_ptr as usize;
    let overlaps = in_start < out_start.saturating_add(total.saturating_mul(elem))
        && out_start < in_start.saturating_add(len.saturating_mul(elem));

    // SAFETY: caller guarantees `in_ptr` addresses `len` initialized f64 values;
    // non-null and alignment were checked above.
    let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };

    if overlaps {
        let mut scratch = vec![0.0; total];
        mean_ad_batch_inner_into(data, &sweep, simd, false, scratch.as_mut_slice())
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // SAFETY: caller guarantees `out_ptr` addresses `total` writable f64 slots;
        // `data` is not used after this point.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };
        out.copy_from_slice(&scratch);
    } else {
        // SAFETY: caller guarantees `out_ptr` addresses `total` writable f64 slots,
        // and the ranges were shown above not to overlap.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };
        mean_ad_batch_inner_into(data, &sweep, simd, false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(rows)
}