1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::cuda_available;
3#[cfg(all(feature = "python", feature = "cuda"))]
4use crate::cuda::moving_averages::CudaBuffAverages;
5#[cfg(all(feature = "python", feature = "cuda"))]
6use cust::context::Context;
7#[cfg(all(feature = "python", feature = "cuda"))]
8use cust::memory::DeviceBuffer;
9#[cfg(feature = "python")]
10use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
11#[cfg(feature = "python")]
12use pyo3::exceptions::PyValueError;
13#[cfg(feature = "python")]
14use pyo3::prelude::*;
15#[cfg(feature = "python")]
16use pyo3::types::{PyDict, PyList};
17#[cfg(all(feature = "python", feature = "cuda"))]
18use std::sync::Arc;
19
20#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
21use serde::{Deserialize, Serialize};
22#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
23use wasm_bindgen::prelude::*;
24
25use crate::utilities::data_loader::{source_type, Candles};
26use crate::utilities::enums::Kernel;
27use crate::utilities::helpers::{
28 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
29 make_uninit_matrix,
30};
31#[cfg(feature = "python")]
32use crate::utilities::kernel_validation::validate_kernel;
33
34#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
35use core::arch::x86_64::*;
36
37#[cfg(not(target_arch = "wasm32"))]
38use rayon::prelude::*;
39
40use std::convert::AsRef;
41use std::error::Error;
42use std::mem::MaybeUninit;
43use thiserror::Error;
44
/// Python-visible handle to a `rows x cols` f32 matrix stored in a CUDA device buffer.
///
/// The data is exposed through `__cuda_array_interface__` and a one-shot `__dlpack__` export
/// (see the `#[pymethods]` impl for this type).
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct BuffAveragesDeviceArrayF32Py {
    // Device allocation; becomes `None` once ownership has been moved out by `__dlpack__`.
    pub(crate) buf: Option<DeviceBuffer<f32>>,
    // Logical matrix shape reported to Python consumers.
    pub(crate) rows: usize,
    pub(crate) cols: usize,
    // NOTE(review): presumably held to keep the CUDA context alive while `buf` is freed.
    // Rust drops fields in declaration order, so keep `buf` declared before `_ctx`.
    pub(crate) _ctx: Arc<Context>,
    // Device ordinal reported by `__dlpack_device__`.
    pub(crate) device_id: u32,
}
54
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl BuffAveragesDeviceArrayF32Py {
    /// CUDA Array Interface (version 3) description of the device matrix.
    ///
    /// Reports shape `(rows, cols)`, typestr `"<f4"` and row-major byte strides
    /// `(cols * 4, 4)`. Raises `ValueError` once the buffer has been exported via `__dlpack__`.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let d = PyDict::new(py);
        d.set_item("shape", (self.rows, self.cols))?;
        d.set_item("typestr", "<f4")?;
        d.set_item(
            "strides",
            (
                self.cols * std::mem::size_of::<f32>(),
                std::mem::size_of::<f32>(),
            ),
        )?;
        let ptr = self
            .buf
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("buffer already exported via __dlpack__"))?
            .as_device_ptr()
            .as_raw() as usize;
        // `(pointer, read_only)`: the buffer is advertised as writable.
        d.set_item("data", (ptr, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    /// DLPack device tuple `(device_type, device_id)`; `2` is DLPack's `kDLCUDA` code.
    fn __dlpack_device__(&self) -> (i32, i32) {
        // NOTE(review): `as i32` would wrap for ordinals above i32::MAX (not expected in practice).
        (2, self.device_id as i32)
    }

    /// Exports the matrix through DLPack, moving the device buffer out of this object.
    ///
    /// May only be called once. Afterwards `__cuda_array_interface__` also fails.
    /// A `dl_device` that parses as `(int, int)` and names a different device is rejected. The
    /// message differs when `copy=True`, because cross-device copies are not implemented.
    /// A `dl_device` that does not parse as `(int, int)` is ignored. `stream` is accepted but
    /// not used by this method.
    #[pyo3(signature=(stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;

        let (kdl, alloc_dev) = self.__dlpack_device__();
        // Validate the requested device BEFORE taking the buffer, so a rejected request leaves
        // the object usable.
        if let Some(dev_obj) = dl_device.as_ref() {
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    // A non-bool `copy` is treated as `False`.
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(PyValueError::new_err(
                            "device copy not implemented for __dlpack__",
                        ));
                    } else {
                        return Err(PyValueError::new_err("device mismatch for __dlpack__"));
                    }
                }
            }
        }
        // NOTE(review): no stream synchronization happens here. Confirm whether the exporter
        // or the producer guarantees the data is ready on the consumer's stream.
        let _ = stream;

        // Move ownership of the device buffer out; subsequent exports/interfaces will fail.
        let buf = self
            .buf
            .take()
            .ok_or_else(|| PyValueError::new_err("__dlpack__ may only be called once"))?;

        let rows = self.rows;
        let cols = self.cols;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
130
131impl<'a> AsRef<[f64]> for BuffAveragesInput<'a> {
132 #[inline(always)]
133 fn as_ref(&self) -> &[f64] {
134 match &self.data {
135 BuffAveragesData::Slice(slice) => slice,
136 BuffAveragesData::Candles { candles, source } => source_type(candles, source),
137 }
138 }
139}
140
/// Where the indicator reads its price series from.
#[derive(Debug, Clone)]
pub enum BuffAveragesData<'a> {
    /// A column of `candles`, selected by name via `source_type` (e.g. `"close"`).
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A raw price slice; volume must then be supplied separately on the input.
    Slice(&'a [f64]),
}
149
/// Result of the indicator: both series have the same length as the input price series.
#[derive(Debug, Clone)]
pub struct BuffAveragesOutput {
    /// Volume-weighted average over the fast window.
    pub fast_buff: Vec<f64>,
    /// Volume-weighted average over the slow window.
    pub slow_buff: Vec<f64>,
}
155
/// Window lengths for the indicator; `None` falls back to 5 (fast) and 20 (slow).
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct BuffAveragesParams {
    /// Fast window length in samples.
    pub fast_period: Option<usize>,
    /// Slow window length in samples.
    pub slow_period: Option<usize>,
}
165
166impl Default for BuffAveragesParams {
167 fn default() -> Self {
168 Self {
169 fast_period: Some(5),
170 slow_period: Some(20),
171 }
172 }
173}
174
/// Complete input for the indicator: price source, optional volume, and parameters.
#[derive(Debug, Clone)]
pub struct BuffAveragesInput<'a> {
    /// Price source.
    pub data: BuffAveragesData<'a>,
    /// Volume series. Required for `BuffAveragesData::Slice`; for candle input the
    /// validation step reads `candles.volume` directly instead of this field.
    pub volume: Option<&'a [f64]>,
    /// Window lengths.
    pub params: BuffAveragesParams,
}
181
182impl<'a> BuffAveragesInput<'a> {
183 #[inline]
184 pub fn from_candles(c: &'a Candles, s: &'a str, p: BuffAveragesParams) -> Self {
185 Self {
186 data: BuffAveragesData::Candles {
187 candles: c,
188 source: s,
189 },
190 volume: Some(&c.volume),
191 params: p,
192 }
193 }
194
195 #[inline]
196 pub fn from_slices(price: &'a [f64], volume: &'a [f64], p: BuffAveragesParams) -> Self {
197 Self {
198 data: BuffAveragesData::Slice(price),
199 volume: Some(volume),
200 params: p,
201 }
202 }
203
204 #[inline]
205 pub fn from_slice(sl: &'a [f64], p: BuffAveragesParams) -> Self {
206 Self {
207 data: BuffAveragesData::Slice(sl),
208 volume: None,
209 params: p,
210 }
211 }
212
213 #[inline]
214 pub fn with_default_candles(c: &'a Candles) -> Self {
215 Self::from_candles(c, "close", BuffAveragesParams::default())
216 }
217
218 #[inline]
219 pub fn get_fast_period(&self) -> usize {
220 self.params.fast_period.unwrap_or(5)
221 }
222
223 #[inline]
224 pub fn get_slow_period(&self) -> usize {
225 self.params.slow_period.unwrap_or(20)
226 }
227}
228
/// Fluent builder for running the indicator with optional periods and a chosen kernel.
#[derive(Copy, Clone, Debug)]
pub struct BuffAveragesBuilder {
    // `None` defers to the input's default (5).
    fast_period: Option<usize>,
    // `None` defers to the input's default (20).
    slow_period: Option<usize>,
    // Kernel passed to `buff_averages_with_kernel`.
    kernel: Kernel,
}
235
236impl Default for BuffAveragesBuilder {
237 fn default() -> Self {
238 Self {
239 fast_period: None,
240 slow_period: None,
241 kernel: Kernel::Auto,
242 }
243 }
244}
245
246impl BuffAveragesBuilder {
247 #[inline(always)]
248 pub fn new() -> Self {
249 Self::default()
250 }
251
252 #[inline(always)]
253 pub fn fast_period(mut self, val: usize) -> Self {
254 self.fast_period = Some(val);
255 self
256 }
257
258 #[inline(always)]
259 pub fn slow_period(mut self, val: usize) -> Self {
260 self.slow_period = Some(val);
261 self
262 }
263
264 #[inline(always)]
265 pub fn kernel(mut self, k: Kernel) -> Self {
266 self.kernel = k;
267 self
268 }
269
270 #[inline(always)]
271 pub fn apply(self, c: &Candles) -> Result<BuffAveragesOutput, BuffAveragesError> {
272 let p = BuffAveragesParams {
273 fast_period: self.fast_period,
274 slow_period: self.slow_period,
275 };
276 let i = BuffAveragesInput::from_candles(c, "close", p);
277 buff_averages_with_kernel(&i, self.kernel)
278 }
279
280 #[inline(always)]
281 pub fn apply_slices(
282 self,
283 price: &[f64],
284 volume: &[f64],
285 ) -> Result<BuffAveragesOutput, BuffAveragesError> {
286 let p = BuffAveragesParams {
287 fast_period: self.fast_period,
288 slow_period: self.slow_period,
289 };
290 let i = BuffAveragesInput::from_slices(price, volume, p);
291 buff_averages_with_kernel(&i, self.kernel)
292 }
293
294 #[inline(always)]
295 pub fn into_stream(self) -> Result<BuffAveragesStream, BuffAveragesError> {
296 let p = BuffAveragesParams {
297 fast_period: self.fast_period,
298 slow_period: self.slow_period,
299 };
300 BuffAveragesStream::try_new(p)
301 }
302}
303
/// Errors produced by the Buff averages indicator.
#[derive(Debug, Error)]
pub enum BuffAveragesError {
    /// The price series has no elements.
    #[error("buff_averages: Input data slice is empty.")]
    EmptyInputData,

    /// Every price is NaN, so there is no first valid index.
    #[error("buff_averages: All values are NaN.")]
    AllValuesNaN,

    /// A period is zero or cannot be used with a series of length `data_len`.
    #[error("buff_averages: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },

    /// Fewer than `needed` samples remain from the first non-NaN price onwards.
    #[error("buff_averages: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },

    /// Price and volume series differ in length.
    #[error("buff_averages: Price and volume arrays have different lengths: price = {price_len}, volume = {volume_len}")]
    MismatchedDataLength { price_len: usize, volume_len: usize },

    /// Slice input was provided without a volume series.
    #[error("buff_averages: Volume data is required for this indicator")]
    MissingVolumeData,

    /// A caller-provided output buffer does not match the input length.
    #[error("buff_averages: Output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },

    /// NOTE(review): not produced by the single-series code shown here; presumably raised by
    /// the batch API for an invalid parameter sweep.
    #[error("buff_averages: Invalid range: start = {start}, end = {end}, step = {step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },

    /// NOTE(review): presumably raised when a batch entry point receives a non-batch kernel.
    #[error("buff_averages: Invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(Kernel),

    /// NOTE(review): presumably raised when `rows * cols` overflows while sizing a batch matrix.
    #[error("buff_averages: size overflow for rows = {rows}, cols = {cols}")]
    SizeOverflow { rows: usize, cols: usize },
}
340
341#[inline]
342pub fn buff_averages(input: &BuffAveragesInput) -> Result<BuffAveragesOutput, BuffAveragesError> {
343 buff_averages_with_kernel(input, Kernel::Auto)
344}
345
346pub fn buff_averages_with_kernel(
347 input: &BuffAveragesInput,
348 kernel: Kernel,
349) -> Result<BuffAveragesOutput, BuffAveragesError> {
350 let (price, volume, fast_period, slow_period, first, chosen) =
351 buff_averages_prepare(input, kernel)?;
352
353 let warm = first + slow_period - 1;
354
355 let mut fast_buff = alloc_with_nan_prefix(price.len(), warm);
356 let mut slow_buff = alloc_with_nan_prefix(price.len(), warm);
357
358 buff_averages_compute_into(
359 price,
360 volume,
361 fast_period,
362 slow_period,
363 first,
364 chosen,
365 &mut fast_buff,
366 &mut slow_buff,
367 );
368
369 Ok(BuffAveragesOutput {
370 fast_buff,
371 slow_buff,
372 })
373}
374
375#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
376#[inline]
377pub fn buff_averages_into(
378 input: &BuffAveragesInput,
379 fast_out: &mut [f64],
380 slow_out: &mut [f64],
381) -> Result<(), BuffAveragesError> {
382 let (price, volume, fast_period, slow_period, first, chosen) =
383 buff_averages_prepare(input, Kernel::Auto)?;
384
385 if fast_out.len() != price.len() || slow_out.len() != price.len() {
386 return Err(BuffAveragesError::OutputLengthMismatch {
387 expected: price.len(),
388 got: core::cmp::min(fast_out.len(), slow_out.len()),
389 });
390 }
391
392 let warm = first + slow_period - 1;
393 let nan = f64::from_bits(0x7ff8_0000_0000_0000);
394 let warmup_len = warm.min(price.len());
395 for v in &mut fast_out[..warmup_len] {
396 *v = nan;
397 }
398 for v in &mut slow_out[..warmup_len] {
399 *v = nan;
400 }
401
402 buff_averages_compute_into(
403 price,
404 volume,
405 fast_period,
406 slow_period,
407 first,
408 chosen,
409 fast_out,
410 slow_out,
411 );
412
413 Ok(())
414}
415
416#[inline]
417pub fn buff_averages_into_slices(
418 fast_dst: &mut [f64],
419 slow_dst: &mut [f64],
420 input: &BuffAveragesInput,
421 kern: Kernel,
422) -> Result<(), BuffAveragesError> {
423 let (price, volume, fast_p, slow_p, first, chosen) = buff_averages_prepare(input, kern)?;
424
425 if fast_dst.len() != price.len() || slow_dst.len() != price.len() {
426 return Err(BuffAveragesError::OutputLengthMismatch {
427 expected: price.len(),
428 got: core::cmp::min(fast_dst.len(), slow_dst.len()),
429 });
430 }
431
432 buff_averages_compute_into(
433 price, volume, fast_p, slow_p, first, chosen, fast_dst, slow_dst,
434 );
435
436 let warm = first + slow_p - 1;
437 for x in &mut fast_dst[..warm] {
438 *x = f64::NAN;
439 }
440 for x in &mut slow_dst[..warm] {
441 *x = f64::NAN;
442 }
443
444 Ok(())
445}
446
447#[inline(always)]
448fn buff_averages_prepare<'a>(
449 input: &'a BuffAveragesInput,
450 kernel: Kernel,
451) -> Result<(&'a [f64], &'a [f64], usize, usize, usize, Kernel), BuffAveragesError> {
452 let price: &[f64] = input.as_ref();
453 let len = price.len();
454
455 if len == 0 {
456 return Err(BuffAveragesError::EmptyInputData);
457 }
458
459 let volume = match &input.data {
460 BuffAveragesData::Candles { candles, .. } => &candles.volume,
461 BuffAveragesData::Slice(_) => input.volume.ok_or(BuffAveragesError::MissingVolumeData)?,
462 };
463
464 if price.len() != volume.len() {
465 return Err(BuffAveragesError::MismatchedDataLength {
466 price_len: price.len(),
467 volume_len: volume.len(),
468 });
469 }
470
471 let first = price
472 .iter()
473 .position(|x| !x.is_nan())
474 .ok_or(BuffAveragesError::AllValuesNaN)?;
475
476 let fast_period = input.get_fast_period();
477 let slow_period = input.get_slow_period();
478
479 if fast_period == 0 || fast_period > len {
480 return Err(BuffAveragesError::InvalidPeriod {
481 period: fast_period,
482 data_len: len,
483 });
484 }
485
486 if slow_period == 0 || slow_period > len {
487 return Err(BuffAveragesError::InvalidPeriod {
488 period: slow_period,
489 data_len: len,
490 });
491 }
492
493 if len - first < slow_period {
494 return Err(BuffAveragesError::NotEnoughValidData {
495 needed: slow_period,
496 valid: len - first,
497 });
498 }
499
500 let chosen = match kernel {
501 Kernel::Auto => Kernel::Scalar,
502 k => k,
503 };
504
505 Ok((price, volume, fast_period, slow_period, first, chosen))
506}
507
/// Dispatches to the kernel implementation selected by `kernel`.
///
/// `kernel` is expected to be already resolved; callers in this file pass the value returned
/// by `buff_averages_prepare`, which maps `Kernel::Auto` to `Kernel::Scalar`. Any variant not
/// matched below hits `unreachable!()`.
///
/// - On wasm32 with `simd128`, scalar kernels are routed to `buff_averages_simd128`.
/// - Without the `nightly-avx` feature (or off x86_64), AVX requests fall back to the scalar
///   kernel.
///
/// NOTE(review): no runtime CPU-feature detection happens here. With `nightly-avx` enabled,
/// requesting `Avx2`/`Avx512` on a CPU lacking those features runs unsupported instructions.
/// Confirm that callers only pass kernels the host supports.
#[inline(always)]
fn buff_averages_compute_into(
    price: &[f64],
    volume: &[f64],
    fast_period: usize,
    slow_period: usize,
    first: usize,
    kernel: Kernel,
    fast_out: &mut [f64],
    slow_out: &mut [f64],
) {
    unsafe {
        // wasm32 + simd128: scalar requests use the simd128 entry point and return early.
        #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
        {
            if matches!(kernel, Kernel::Scalar | Kernel::ScalarBatch) {
                buff_averages_simd128(
                    price,
                    volume,
                    fast_period,
                    slow_period,
                    first,
                    fast_out,
                    slow_out,
                );
                return;
            }
        }

        match kernel {
            Kernel::Scalar | Kernel::ScalarBatch => buff_averages_scalar(
                price,
                volume,
                fast_period,
                slow_period,
                first,
                fast_out,
                slow_out,
            ),
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx2 | Kernel::Avx2Batch => buff_averages_avx2(
                price,
                volume,
                fast_period,
                slow_period,
                first,
                fast_out,
                slow_out,
            ),
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx512 | Kernel::Avx512Batch => buff_averages_avx512(
                price,
                volume,
                fast_period,
                slow_period,
                first,
                fast_out,
                slow_out,
            ),
            // SIMD kernels not compiled in: fall back to scalar.
            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
                buff_averages_scalar(
                    price,
                    volume,
                    fast_period,
                    slow_period,
                    first,
                    fast_out,
                    slow_out,
                )
            }
            // Reached only by an unresolved kernel (e.g. `Kernel::Auto` passed directly).
            _ => unreachable!(),
        }
    }
}
582
/// Scalar kernel: rolling volume-weighted means of `price` over a fast and a slow window.
///
/// For every index `i >= first + slow_period - 1`, this writes `sum(p * v) / sum(v)` over the
/// trailing `slow_period` samples into `slow_out[i]`, and likewise over the trailing
/// `fast_period` samples into `fast_out[i]`. Samples where either the price or the volume is
/// NaN are skipped. A window whose volume sum is exactly `0.0` yields `0.0`. Entries before
/// the warm-up index are never written, so callers pre-fill them (e.g. with NaN).
///
/// Nothing is written when any of these hold:
/// - the input is empty;
/// - either period is zero;
/// - fewer than `slow_period` samples exist from `first` onwards;
/// - the fast window ending at the first output would start before index 0
///   (`fast_period > first + slow_period`).
///
/// # Panics
///
/// Panics if there is something to compute and `volume`, `fast_out` or `slow_out` is shorter
/// than `price`.
#[inline]
pub fn buff_averages_scalar(
    price: &[f64],
    volume: &[f64],
    fast_period: usize,
    slow_period: usize,
    first: usize,
    fast_out: &mut [f64],
    slow_out: &mut [f64],
) {
    /// Weighted mean of a window, or `0.0` when the window carries no volume.
    #[inline(always)]
    fn weighted_mean(num: f64, den: f64) -> f64 {
        if den != 0.0 {
            num / den
        } else {
            0.0
        }
    }

    /// `(sum(p * v), sum(v))` over `price[start..=end]`, in index order, skipping NaN samples.
    #[inline(always)]
    fn window_sums(price: &[f64], volume: &[f64], start: usize, end: usize) -> (f64, f64) {
        price[start..=end]
            .iter()
            .zip(&volume[start..=end])
            .filter(|&(p, v)| !p.is_nan() && !v.is_nan())
            .fold((0.0, 0.0), |(num, den), (&p, &v)| (num + p * v, den + v))
    }

    let len = price.len();
    // Zero periods would underflow the index arithmetic below.
    if len == 0 || fast_period == 0 || slow_period == 0 {
        return;
    }

    // First index with a complete slow window.
    let warm = first + slow_period - 1;
    // The fast window ending at `warm` must not reach before index 0.
    if warm >= len || fast_period > warm + 1 {
        return;
    }

    let (mut slow_num, mut slow_den) = window_sums(price, volume, warm + 1 - slow_period, warm);
    let (mut fast_num, mut fast_den) = window_sums(price, volume, warm + 1 - fast_period, warm);
    slow_out[warm] = weighted_mean(slow_num, slow_den);
    fast_out[warm] = weighted_mean(fast_num, fast_den);

    for i in (warm + 1)..len {
        let (new_p, new_v) = (price[i], volume[i]);
        let new_valid = !new_p.is_nan() && !new_v.is_nan();

        // Slide the slow window: drop sample `i - slow_period`, add sample `i`.
        let (old_p, old_v) = (price[i - slow_period], volume[i - slow_period]);
        if !old_p.is_nan() && !old_v.is_nan() {
            slow_num -= old_p * old_v;
            slow_den -= old_v;
        }
        if new_valid {
            slow_num += new_p * new_v;
            slow_den += new_v;
        }
        slow_out[i] = weighted_mean(slow_num, slow_den);

        // Slide the fast window: drop sample `i - fast_period`, add sample `i`.
        let (old_p, old_v) = (price[i - fast_period], volume[i - fast_period]);
        if !old_p.is_nan() && !old_v.is_nan() {
            fast_num -= old_p * old_v;
            fast_den -= old_v;
        }
        if new_valid {
            fast_num += new_p * new_v;
            fast_den += new_v;
        }
        fast_out[i] = weighted_mean(fast_num, fast_den);
    }
}
681
/// wasm32 `simd128` entry point for the kernel; currently delegates to
/// [`buff_averages_scalar`] with identical arguments.
///
/// # Safety
///
/// Performs no unsafe operations itself. NOTE(review): presumably `unsafe` only for parity
/// with the other SIMD kernels. Callers must meet the same preconditions as
/// `buff_averages_scalar`.
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
#[inline]
unsafe fn buff_averages_simd128(
    price: &[f64],
    volume: &[f64],
    fast_period: usize,
    slow_period: usize,
    first: usize,
    fast_out: &mut [f64],
    slow_out: &mut [f64],
) {
    buff_averages_scalar(
        price,
        volume,
        fast_period,
        slow_period,
        first,
        fast_out,
        slow_out,
    );
}
703
/// Horizontal sum of the four f64 lanes of `v`, computed as `(v0 + v2) + (v1 + v3)`.
///
/// # Safety
///
/// Uses AVX intrinsics; the CPU must support AVX (it is only called from AVX2/AVX-512 kernels).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn hsum256_pd(v: __m256d) -> f64 {
    // Fold the upper 128-bit half onto the lower one: [v0 + v2, v1 + v3].
    let pair: __m128d = _mm_add_pd(_mm256_castpd256_pd128(v), _mm256_extractf128_pd::<1>(v));
    // Add the high element onto the low element and read the low lane.
    _mm_cvtsd_f64(_mm_add_sd(pair, _mm_unpackhi_pd(pair, pair)))
}
714
/// AVX2/FMA variant of [`buff_averages_scalar`].
///
/// The two initial window sums are accumulated four lanes at a time. NaN lanes are zeroed
/// via an ordered-compare mask and products are fused with FMA. The rolling updates then
/// proceed one element at a time, exactly like the scalar kernel. Because the initial sums are
/// reassociated across lanes and use FMA, results can differ from the scalar kernel in the
/// last bits.
///
/// # Safety
///
/// - The CPU must support AVX2 and FMA.
/// - `volume.len() >= price.len()`: `volume` is read with unchecked loads/indexing up to
///   `price.len() - 1`.
/// - Both periods must be non-zero and `fast_period <= first + slow_period`; otherwise the
///   window-start arithmetic can underflow.
/// - `fast_out` / `slow_out` use checked indexing and panic if shorter than `price.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
unsafe fn buff_averages_avx2(
    price: &[f64],
    volume: &[f64],
    fast_period: usize,
    slow_period: usize,
    first: usize,
    fast_out: &mut [f64],
    slow_out: &mut [f64],
) {
    let len = price.len();
    if len == 0 {
        return;
    }
    // First index with a complete slow window; nothing to write if it is past the end.
    let warm = first + slow_period - 1;
    if warm >= len {
        return;
    }

    // Initial slow-window sums over [slow_start, end), 4 lanes at a time.
    let slow_start = warm + 1 - slow_period;
    let mut i = slow_start;
    let end = warm + 1;
    let mut slow_num_v = _mm256_setzero_pd();
    let mut slow_den_v = _mm256_setzero_pd();

    while i + 4 <= end {
        let p = _mm256_loadu_pd(price.as_ptr().add(i));
        let v = _mm256_loadu_pd(volume.as_ptr().add(i));

        // All-ones lanes where both price and volume are non-NaN.
        let mp = _mm256_cmp_pd(p, p, _CMP_ORD_Q);
        let mv = _mm256_cmp_pd(v, v, _CMP_ORD_Q);
        let m = _mm256_and_pd(mp, mv);

        // Zero out lanes containing a NaN so they contribute nothing.
        let pz = _mm256_and_pd(p, m);
        let vz = _mm256_and_pd(v, m);
        slow_num_v = _mm256_fmadd_pd(pz, vz, slow_num_v);
        slow_den_v = _mm256_add_pd(slow_den_v, vz);
        i += 4;
    }

    let mut slow_numerator = hsum256_pd(slow_num_v);
    let mut slow_denominator = hsum256_pd(slow_den_v);

    // Scalar tail of the slow window.
    while i < end {
        let p = *price.get_unchecked(i);
        let v = *volume.get_unchecked(i);
        if !p.is_nan() && !v.is_nan() {
            slow_numerator += p * v;
            slow_denominator += v;
        }
        i += 1;
    }

    // Initial fast-window sums over [fast_start, end), same scheme.
    let fast_start = warm + 1 - fast_period;
    let mut j = fast_start;
    let mut fast_num_v = _mm256_setzero_pd();
    let mut fast_den_v = _mm256_setzero_pd();

    while j + 4 <= end {
        let p = _mm256_loadu_pd(price.as_ptr().add(j));
        let v = _mm256_loadu_pd(volume.as_ptr().add(j));
        let mp = _mm256_cmp_pd(p, p, _CMP_ORD_Q);
        let mv = _mm256_cmp_pd(v, v, _CMP_ORD_Q);
        let m = _mm256_and_pd(mp, mv);
        let pz = _mm256_and_pd(p, m);
        let vz = _mm256_and_pd(v, m);
        fast_num_v = _mm256_fmadd_pd(pz, vz, fast_num_v);
        fast_den_v = _mm256_add_pd(fast_den_v, vz);
        j += 4;
    }

    let mut fast_numerator = hsum256_pd(fast_num_v);
    let mut fast_denominator = hsum256_pd(fast_den_v);

    // Scalar tail of the fast window.
    while j < end {
        let p = *price.get_unchecked(j);
        let v = *volume.get_unchecked(j);
        if !p.is_nan() && !v.is_nan() {
            fast_numerator += p * v;
            fast_denominator += v;
        }
        j += 1;
    }

    // A window with zero total volume reports 0.0.
    slow_out[warm] = if slow_denominator != 0.0 {
        slow_numerator / slow_denominator
    } else {
        0.0
    };
    fast_out[warm] = if fast_denominator != 0.0 {
        fast_numerator / fast_denominator
    } else {
        0.0
    };

    // Rolling updates: drop the sample leaving each window, add the new one (NaNs skipped).
    for k in (warm + 1)..len {
        let old_slow = k - slow_period;
        let new_p = *price.get_unchecked(k);
        let new_v = *volume.get_unchecked(k);
        let old_p = *price.get_unchecked(old_slow);
        let old_v = *volume.get_unchecked(old_slow);

        if !old_p.is_nan() && !old_v.is_nan() {
            slow_numerator -= old_p * old_v;
            slow_denominator -= old_v;
        }
        if !new_p.is_nan() && !new_v.is_nan() {
            slow_numerator += new_p * new_v;
            slow_denominator += new_v;
        }
        slow_out[k] = if slow_denominator != 0.0 {
            slow_numerator / slow_denominator
        } else {
            0.0
        };

        let old_fast = k - fast_period;
        let old_pf = *price.get_unchecked(old_fast);
        let old_vf = *volume.get_unchecked(old_fast);
        if !old_pf.is_nan() && !old_vf.is_nan() {
            fast_numerator -= old_pf * old_vf;
            fast_denominator -= old_vf;
        }
        if !new_p.is_nan() && !new_v.is_nan() {
            fast_numerator += new_p * new_v;
            fast_denominator += new_v;
        }
        fast_out[k] = if fast_denominator != 0.0 {
            fast_numerator / fast_denominator
        } else {
            0.0
        };
    }
}
850
/// AVX-512/FMA variant of [`buff_averages_scalar`].
///
/// The two initial window sums are accumulated eight lanes at a time. NaN lanes are zeroed
/// with a `__mmask8` built from ordered compares, and products are fused with FMA. The rolling
/// updates then run one element at a time like the scalar kernel. Lane reassociation and FMA
/// can make results differ from the scalar kernel in the last bits.
///
/// # Safety
///
/// - The CPU must support AVX-512F and FMA.
/// - `volume.len() >= price.len()`: `volume` is read with unchecked loads/indexing up to
///   `price.len() - 1`.
/// - Both periods must be non-zero and `fast_period <= first + slow_period`; otherwise the
///   window-start arithmetic can underflow.
/// - `fast_out` / `slow_out` use checked indexing and panic if shorter than `price.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
unsafe fn buff_averages_avx512(
    price: &[f64],
    volume: &[f64],
    fast_period: usize,
    slow_period: usize,
    first: usize,
    fast_out: &mut [f64],
    slow_out: &mut [f64],
) {
    let len = price.len();
    if len == 0 {
        return;
    }
    // First index with a complete slow window; nothing to write if it is past the end.
    let warm = first + slow_period - 1;
    if warm >= len {
        return;
    }

    // Initial slow-window sums over [slow_start, end), 8 lanes at a time.
    let slow_start = warm + 1 - slow_period;
    let mut i = slow_start;
    let end = warm + 1;
    let mut slow_num_v = _mm512_setzero_pd();
    let mut slow_den_v = _mm512_setzero_pd();

    while i + 8 <= end {
        let p = _mm512_loadu_pd(price.as_ptr().add(i));
        let v = _mm512_loadu_pd(volume.as_ptr().add(i));

        // Bit set for lanes where both price and volume are non-NaN.
        let mp: __mmask8 = _mm512_cmp_pd_mask(p, p, _CMP_ORD_Q);
        let mv: __mmask8 = _mm512_cmp_pd_mask(v, v, _CMP_ORD_Q);
        let m: __mmask8 = mp & mv;

        // Lanes outside the mask are zeroed.
        let pz = _mm512_maskz_mov_pd(m, p);
        let vz = _mm512_maskz_mov_pd(m, v);
        slow_num_v = _mm512_fmadd_pd(pz, vz, slow_num_v);
        slow_den_v = _mm512_add_pd(slow_den_v, vz);

        i += 8;
    }

    let mut slow_numerator = _mm512_reduce_add_pd(slow_num_v);
    let mut slow_denominator = _mm512_reduce_add_pd(slow_den_v);

    // Scalar tail of the slow window.
    while i < end {
        let p = *price.get_unchecked(i);
        let v = *volume.get_unchecked(i);
        if !p.is_nan() && !v.is_nan() {
            slow_numerator += p * v;
            slow_denominator += v;
        }
        i += 1;
    }

    // Initial fast-window sums over [fast_start, end), same scheme.
    let fast_start = warm + 1 - fast_period;
    let mut j = fast_start;
    let mut fast_num_v = _mm512_setzero_pd();
    let mut fast_den_v = _mm512_setzero_pd();

    while j + 8 <= end {
        let p = _mm512_loadu_pd(price.as_ptr().add(j));
        let v = _mm512_loadu_pd(volume.as_ptr().add(j));
        let mp: __mmask8 = _mm512_cmp_pd_mask(p, p, _CMP_ORD_Q);
        let mv: __mmask8 = _mm512_cmp_pd_mask(v, v, _CMP_ORD_Q);
        let m: __mmask8 = mp & mv;

        let pz = _mm512_maskz_mov_pd(m, p);
        let vz = _mm512_maskz_mov_pd(m, v);
        fast_num_v = _mm512_fmadd_pd(pz, vz, fast_num_v);
        fast_den_v = _mm512_add_pd(fast_den_v, vz);

        j += 8;
    }

    let mut fast_numerator = _mm512_reduce_add_pd(fast_num_v);
    let mut fast_denominator = _mm512_reduce_add_pd(fast_den_v);

    // Scalar tail of the fast window.
    while j < end {
        let p = *price.get_unchecked(j);
        let v = *volume.get_unchecked(j);
        if !p.is_nan() && !v.is_nan() {
            fast_numerator += p * v;
            fast_denominator += v;
        }
        j += 1;
    }

    // A window with zero total volume reports 0.0.
    slow_out[warm] = if slow_denominator != 0.0 {
        slow_numerator / slow_denominator
    } else {
        0.0
    };
    fast_out[warm] = if fast_denominator != 0.0 {
        fast_numerator / fast_denominator
    } else {
        0.0
    };

    // Rolling updates: drop the sample leaving each window, add the new one (NaNs skipped).
    for k in (warm + 1)..len {
        let old_slow = k - slow_period;
        let new_p = *price.get_unchecked(k);
        let new_v = *volume.get_unchecked(k);
        let old_p = *price.get_unchecked(old_slow);
        let old_v = *volume.get_unchecked(old_slow);

        if !old_p.is_nan() && !old_v.is_nan() {
            slow_numerator -= old_p * old_v;
            slow_denominator -= old_v;
        }
        if !new_p.is_nan() && !new_v.is_nan() {
            slow_numerator += new_p * new_v;
            slow_denominator += new_v;
        }
        slow_out[k] = if slow_denominator != 0.0 {
            slow_numerator / slow_denominator
        } else {
            0.0
        };

        let old_fast = k - fast_period;
        let old_pf = *price.get_unchecked(old_fast);
        let old_vf = *volume.get_unchecked(old_fast);
        if !old_pf.is_nan() && !old_vf.is_nan() {
            fast_numerator -= old_pf * old_vf;
            fast_denominator -= old_vf;
        }
        if !new_p.is_nan() && !new_v.is_nan() {
            fast_numerator += new_p * new_v;
            fast_denominator += new_v;
        }
        fast_out[k] = if fast_denominator != 0.0 {
            fast_numerator / fast_denominator
        } else {
            0.0
        };
    }
}
989
/// Precomputes masked per-sample terms for the rolling sums (AVX2).
///
/// Writes `pv[i] = price[i] * volume[i]` and `vv[i] = volume[i]`. Both are `0.0` wherever
/// `price[i]` or `volume[i]` is NaN. The results can therefore be summed without further
/// NaN checks.
///
/// # Safety
///
/// - The CPU must support AVX2 (the function is compiled with `avx2,fma`).
/// - `volume`, `pv` and `vv` must each be at least `price.len()` long. This is only checked by
///   `debug_assert_eq!` (which requires equality) and is unchecked in release builds.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
pub unsafe fn build_masked_pv_v_avx2(
    price: &[f64],
    volume: &[f64],
    pv: &mut [f64],
    vv: &mut [f64],
) {
    debug_assert_eq!(price.len(), volume.len());
    debug_assert_eq!(pv.len(), price.len());
    debug_assert_eq!(vv.len(), price.len());

    let n = price.len();
    let mut i = 0usize;

    while i + 4 <= n {
        let p = _mm256_loadu_pd(price.as_ptr().add(i));
        let v = _mm256_loadu_pd(volume.as_ptr().add(i));

        // All-ones lanes where both inputs are non-NaN.
        let mp = _mm256_cmp_pd(p, p, _CMP_ORD_Q);
        let mv = _mm256_cmp_pd(v, v, _CMP_ORD_Q);
        let m = _mm256_and_pd(mp, mv);

        // Multiply first, then zero the lanes excluded by the mask.
        let pvv = _mm256_mul_pd(p, v);
        let pv_masked = _mm256_and_pd(pvv, m);
        let vv_masked = _mm256_and_pd(v, m);

        _mm256_storeu_pd(pv.as_mut_ptr().add(i), pv_masked);
        _mm256_storeu_pd(vv.as_mut_ptr().add(i), vv_masked);

        i += 4;
    }

    // Scalar tail with the same masking rule.
    while i < n {
        let p = *price.get_unchecked(i);
        let v = *volume.get_unchecked(i);
        if !p.is_nan() && !v.is_nan() {
            *pv.get_unchecked_mut(i) = p * v;
            *vv.get_unchecked_mut(i) = v;
        } else {
            *pv.get_unchecked_mut(i) = 0.0;
            *vv.get_unchecked_mut(i) = 0.0;
        }
        i += 1;
    }
}
1036
/// Precomputes masked per-sample terms for the rolling sums (AVX-512).
///
/// Writes `pv[i] = price[i] * volume[i]` and `vv[i] = volume[i]`. Both are `0.0` wherever
/// `price[i]` or `volume[i]` is NaN.
///
/// # Safety
///
/// - The CPU must support AVX-512F (the function is compiled with `avx512f,fma`).
/// - `volume`, `pv` and `vv` must each be at least `price.len()` long. This is only checked by
///   `debug_assert_eq!` (which requires equality) and is unchecked in release builds.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
pub unsafe fn build_masked_pv_v_avx512(
    price: &[f64],
    volume: &[f64],
    pv: &mut [f64],
    vv: &mut [f64],
) {
    debug_assert_eq!(price.len(), volume.len());
    debug_assert_eq!(pv.len(), price.len());
    debug_assert_eq!(vv.len(), price.len());

    let n = price.len();
    let mut i = 0usize;

    while i + 8 <= n {
        let p = _mm512_loadu_pd(price.as_ptr().add(i));
        let v = _mm512_loadu_pd(volume.as_ptr().add(i));

        // Bit set for lanes where both inputs are non-NaN.
        let mp: __mmask8 = _mm512_cmp_pd_mask(p, p, _CMP_ORD_Q);
        let mv: __mmask8 = _mm512_cmp_pd_mask(v, v, _CMP_ORD_Q);
        let m: __mmask8 = mp & mv;

        // Multiply first, then zero the lanes outside the mask.
        let pvv = _mm512_mul_pd(p, v);
        let pv_masked = _mm512_maskz_mov_pd(m, pvv);
        let vv_masked = _mm512_maskz_mov_pd(m, v);

        _mm512_storeu_pd(pv.as_mut_ptr().add(i), pv_masked);
        _mm512_storeu_pd(vv.as_mut_ptr().add(i), vv_masked);

        i += 8;
    }

    // Scalar tail with the same masking rule.
    while i < n {
        let p = *price.get_unchecked(i);
        let v = *volume.get_unchecked(i);
        if !p.is_nan() && !v.is_nan() {
            *pv.get_unchecked_mut(i) = p * v;
            *vv.get_unchecked_mut(i) = v;
        } else {
            *pv.get_unchecked_mut(i) = 0.0;
            *vv.get_unchecked_mut(i) = 0.0;
        }
        i += 1;
    }
}
1083
/// Computes one fast/slow output pair from pre-masked terms (AVX2).
///
/// `pv` and `vv` are expected to be NaN-free masked products and volumes, as written by
/// `build_masked_pv_v_avx2` / `build_masked_pv_v_avx512`. The rolling updates below add and
/// subtract them unconditionally, so a NaN entry would poison the running sums.
/// NOTE(review): presumably used by a batch path to share the masking work across parameter
/// rows; TODO confirm. Output semantics otherwise match [`buff_averages_scalar`]: entries from
/// `first + slow_period - 1` on are written, and zero-volume windows give `0.0`.
///
/// # Safety
///
/// - The CPU must support AVX2 (compiled with `avx2,fma`).
/// - `vv.len() >= pv.len()`: `vv` is read with unchecked loads/indexing up to `pv.len() - 1`.
/// - Both periods must be non-zero and `fast_period <= first + slow_period`; otherwise the
///   window-start arithmetic can underflow.
/// - `fast_out` / `slow_out` use checked indexing and panic if shorter than `pv.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
pub unsafe fn buff_averages_row_avx2_from_masked(
    pv: &[f64],
    vv: &[f64],
    fast_period: usize,
    slow_period: usize,
    first: usize,
    fast_out: &mut [f64],
    slow_out: &mut [f64],
) {
    let len = pv.len();
    if len == 0 {
        return;
    }

    // First index with a complete slow window; nothing to write if it is past the end.
    let warm = first + slow_period - 1;
    if warm >= len {
        return;
    }

    let end = warm + 1;

    // Initial slow-window sums (plain adds; masking already happened upstream).
    let mut i = end - slow_period;
    let mut s_num_v = _mm256_setzero_pd();
    let mut s_den_v = _mm256_setzero_pd();
    while i + 4 <= end {
        let pv4 = _mm256_loadu_pd(pv.as_ptr().add(i));
        let v4 = _mm256_loadu_pd(vv.as_ptr().add(i));
        s_num_v = _mm256_add_pd(s_num_v, pv4);
        s_den_v = _mm256_add_pd(s_den_v, v4);
        i += 4;
    }
    let mut slow_numerator = hsum256_pd(s_num_v);
    let mut slow_denominator = hsum256_pd(s_den_v);
    while i < end {
        slow_numerator += *pv.get_unchecked(i);
        slow_denominator += *vv.get_unchecked(i);
        i += 1;
    }

    // Initial fast-window sums.
    let mut j = end - fast_period;
    let mut f_num_v = _mm256_setzero_pd();
    let mut f_den_v = _mm256_setzero_pd();
    while j + 4 <= end {
        let pv4 = _mm256_loadu_pd(pv.as_ptr().add(j));
        let v4 = _mm256_loadu_pd(vv.as_ptr().add(j));
        f_num_v = _mm256_add_pd(f_num_v, pv4);
        f_den_v = _mm256_add_pd(f_den_v, v4);
        j += 4;
    }
    let mut fast_numerator = hsum256_pd(f_num_v);
    let mut fast_denominator = hsum256_pd(f_den_v);
    while j < end {
        fast_numerator += *pv.get_unchecked(j);
        fast_denominator += *vv.get_unchecked(j);
        j += 1;
    }

    // A window with zero total volume reports 0.0.
    slow_out[warm] = if slow_denominator != 0.0 {
        slow_numerator / slow_denominator
    } else {
        0.0
    };
    fast_out[warm] = if fast_denominator != 0.0 {
        fast_numerator / fast_denominator
    } else {
        0.0
    };

    // Rolling updates without NaN checks: masked entries are 0.0 and so are neutral.
    for k in (warm + 1)..len {
        let old_s = k - slow_period;
        let new_pv = *pv.get_unchecked(k);
        let new_vv = *vv.get_unchecked(k);
        let old_pv = *pv.get_unchecked(old_s);
        let old_vv = *vv.get_unchecked(old_s);
        slow_numerator -= old_pv;
        slow_denominator -= old_vv;
        slow_numerator += new_pv;
        slow_denominator += new_vv;
        slow_out[k] = if slow_denominator != 0.0 {
            slow_numerator / slow_denominator
        } else {
            0.0
        };

        let old_f = k - fast_period;
        let old_fp = *pv.get_unchecked(old_f);
        let old_fv = *vv.get_unchecked(old_f);
        fast_numerator -= old_fp;
        fast_denominator -= old_fv;
        fast_numerator += new_pv;
        fast_denominator += new_vv;
        fast_out[k] = if fast_denominator != 0.0 {
            fast_numerator / fast_denominator
        } else {
            0.0
        };
    }
}
1184
/// AVX-512 row kernel for the batch path: writes one fast row and one slow row
/// of windowed `sum(pv) / sum(vv)` ratios, reading from precomputed buffers.
///
/// In the batch driver `pv` / `vv` come from `build_masked_pv_v_avx512`
/// (presumably `price * volume` and `volume` with invalid bars zeroed — TODO
/// confirm against that helper). Output starts at
/// `warm = first + slow_period - 1`; indices before `warm` are never written,
/// so the caller must have initialised them. A window whose `vv` sum is
/// exactly `0.0` produces `0.0`. Returns early without writing anything when
/// `pv` is empty or `warm >= pv.len()`. Callers are expected to pass non-zero
/// periods.
///
/// # Safety
/// - The CPU must support AVX-512F (the function is compiled with
///   `avx512f,fma` enabled).
/// - `vv.len() >= pv.len()`: both buffers are read with unchecked loads at
///   indices up to `pv.len() - 1`.
/// - `fast_period <= first + slow_period`: otherwise `end - fast_period`
///   underflows and the unchecked loads leave the buffers.
///
/// # Panics
/// If `fast_out` or `slow_out` is shorter than `pv.len()` (they are written
/// with bounds-checked indexing).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
pub unsafe fn buff_averages_row_avx512_from_masked(
    pv: &[f64],
    vv: &[f64],
    fast_period: usize,
    slow_period: usize,
    first: usize,
    fast_out: &mut [f64],
    slow_out: &mut [f64],
) {
    let len = pv.len();
    if len == 0 {
        return;
    }

    // First index at which a full slow window is available.
    let warm = first + slow_period - 1;
    if warm >= len {
        return;
    }

    // Exclusive end of the initial windows (both windows end at `warm`).
    let end = warm + 1;

    // Initial slow window [end - slow_period, end): 8 lanes at a time, then a scalar tail.
    let mut i = end - slow_period;
    let mut s_num_v = _mm512_setzero_pd();
    let mut s_den_v = _mm512_setzero_pd();
    while i + 8 <= end {
        let pv8 = _mm512_loadu_pd(pv.as_ptr().add(i));
        let v8 = _mm512_loadu_pd(vv.as_ptr().add(i));
        s_num_v = _mm512_add_pd(s_num_v, pv8);
        s_den_v = _mm512_add_pd(s_den_v, v8);
        i += 8;
    }
    let mut slow_numerator = _mm512_reduce_add_pd(s_num_v);
    let mut slow_denominator = _mm512_reduce_add_pd(s_den_v);
    while i < end {
        slow_numerator += *pv.get_unchecked(i);
        slow_denominator += *vv.get_unchecked(i);
        i += 1;
    }

    // Initial fast window [end - fast_period, end), summed the same way.
    let mut j = end - fast_period;
    let mut f_num_v = _mm512_setzero_pd();
    let mut f_den_v = _mm512_setzero_pd();
    while j + 8 <= end {
        let pv8 = _mm512_loadu_pd(pv.as_ptr().add(j));
        let v8 = _mm512_loadu_pd(vv.as_ptr().add(j));
        f_num_v = _mm512_add_pd(f_num_v, pv8);
        f_den_v = _mm512_add_pd(f_den_v, v8);
        j += 8;
    }
    let mut fast_numerator = _mm512_reduce_add_pd(f_num_v);
    let mut fast_denominator = _mm512_reduce_add_pd(f_den_v);
    while j < end {
        fast_numerator += *pv.get_unchecked(j);
        fast_denominator += *vv.get_unchecked(j);
        j += 1;
    }

    // A zero denominator (e.g. no volume in the window) yields 0.0 rather than NaN/inf.
    slow_out[warm] = if slow_denominator != 0.0 {
        slow_numerator / slow_denominator
    } else {
        0.0
    };
    fast_out[warm] = if fast_denominator != 0.0 {
        fast_numerator / fast_denominator
    } else {
        0.0
    };

    // Sliding update: subtract the bar leaving each window, then add the new bar.
    for k in (warm + 1)..len {
        let old_s = k - slow_period;
        let new_pv = *pv.get_unchecked(k);
        let new_vv = *vv.get_unchecked(k);
        let old_pv = *pv.get_unchecked(old_s);
        let old_vv = *vv.get_unchecked(old_s);
        slow_numerator -= old_pv;
        slow_denominator -= old_vv;
        slow_numerator += new_pv;
        slow_denominator += new_vv;
        slow_out[k] = if slow_denominator != 0.0 {
            slow_numerator / slow_denominator
        } else {
            0.0
        };

        let old_f = k - fast_period;
        let old_fp = *pv.get_unchecked(old_f);
        let old_fv = *vv.get_unchecked(old_f);
        fast_numerator -= old_fp;
        fast_denominator -= old_fv;
        fast_numerator += new_pv;
        fast_denominator += new_vv;
        fast_out[k] = if fast_denominator != 0.0 {
            fast_numerator / fast_denominator
        } else {
            0.0
        };
    }
}
1285
1286#[derive(Debug, Clone)]
1287pub struct BuffAveragesStream {
1288 ring_pv: Vec<f64>,
1289 ring_vv: Vec<f64>,
1290
1291 cap: usize,
1292
1293 fast_period: usize,
1294 slow_period: usize,
1295
1296 fast_num: f64,
1297 fast_den: f64,
1298 slow_num: f64,
1299 slow_den: f64,
1300
1301 index: usize,
1302
1303 warm_target_count: Option<usize>,
1304}
1305
1306impl BuffAveragesStream {
1307 #[inline]
1308 pub fn try_new(params: BuffAveragesParams) -> Result<Self, BuffAveragesError> {
1309 let fast_period = params.fast_period.unwrap_or(5);
1310 let slow_period = params.slow_period.unwrap_or(20);
1311
1312 if fast_period == 0 {
1313 return Err(BuffAveragesError::InvalidPeriod {
1314 period: fast_period,
1315 data_len: 0,
1316 });
1317 }
1318 if slow_period == 0 {
1319 return Err(BuffAveragesError::InvalidPeriod {
1320 period: slow_period,
1321 data_len: 0,
1322 });
1323 }
1324
1325 let cap = core::cmp::max(fast_period, slow_period);
1326
1327 Ok(Self {
1328 ring_pv: vec![0.0; cap],
1329 ring_vv: vec![0.0; cap],
1330 cap,
1331 fast_period,
1332 slow_period,
1333 fast_num: 0.0,
1334 fast_den: 0.0,
1335 slow_num: 0.0,
1336 slow_den: 0.0,
1337 index: 0,
1338 warm_target_count: None,
1339 })
1340 }
1341
1342 #[inline]
1343 pub fn update(&mut self, price: f64, volume: f64) -> Option<(f64, f64)> {
1344 let n = self.index;
1345 let write_idx = n % self.cap;
1346
1347 if self.warm_target_count.is_none() && !price.is_nan() {
1348 self.warm_target_count = Some(n + self.slow_period);
1349 }
1350
1351 let valid = !price.is_nan() && !volume.is_nan();
1352 let pv_new = if valid {
1353 price.mul_add(volume, 0.0)
1354 } else {
1355 0.0
1356 };
1357 let vv_new = if valid { volume } else { 0.0 };
1358
1359 if n >= self.slow_period {
1360 let idx_out_slow = (n + self.cap - self.slow_period) % self.cap;
1361 let old_pv = unsafe { *self.ring_pv.get_unchecked(idx_out_slow) };
1362 let old_vv = unsafe { *self.ring_vv.get_unchecked(idx_out_slow) };
1363 self.slow_num -= old_pv;
1364 self.slow_den -= old_vv;
1365 }
1366 if n >= self.fast_period {
1367 let idx_out_fast = (n + self.cap - self.fast_period) % self.cap;
1368 let old_pv = unsafe { *self.ring_pv.get_unchecked(idx_out_fast) };
1369 let old_vv = unsafe { *self.ring_vv.get_unchecked(idx_out_fast) };
1370 self.fast_num -= old_pv;
1371 self.fast_den -= old_vv;
1372 }
1373
1374 unsafe {
1375 *self.ring_pv.get_unchecked_mut(write_idx) = pv_new;
1376 *self.ring_vv.get_unchecked_mut(write_idx) = vv_new;
1377 }
1378
1379 self.slow_num += pv_new;
1380 self.slow_den += vv_new;
1381 self.fast_num += pv_new;
1382 self.fast_den += vv_new;
1383
1384 self.index = n + 1;
1385
1386 if let Some(warm) = self.warm_target_count {
1387 if self.index >= warm {
1388 let slow = if self.slow_den != 0.0 {
1389 self.slow_num / self.slow_den
1390 } else {
1391 0.0
1392 };
1393 let fast = if self.fast_den != 0.0 {
1394 self.fast_num / self.fast_den
1395 } else {
1396 0.0
1397 };
1398 return Some((fast, slow));
1399 }
1400 }
1401 None
1402 }
1403}
1404
/// Parameter sweep for batch Buff Averages: each axis is `(start, end, step)`.
///
/// A `step` of 0 (or `start == end`) pins the axis to the single value `start`.
#[derive(Clone, Debug)]
pub struct BuffAveragesBatchRange {
    pub fast_period: (usize, usize, usize),
    pub slow_period: (usize, usize, usize),
}

impl Default for BuffAveragesBatchRange {
    /// Fixed fast period 5, slow periods 20 through 269 in steps of 1.
    fn default() -> Self {
        const DEFAULT_FAST: (usize, usize, usize) = (5, 5, 0);
        const DEFAULT_SLOW: (usize, usize, usize) = (20, 269, 1);
        Self {
            fast_period: DEFAULT_FAST,
            slow_period: DEFAULT_SLOW,
        }
    }
}
1419
/// Result of a batch (parameter-sweep) Buff Averages run.
#[derive(Clone, Debug)]
pub struct BuffAveragesBatchOutput {
    /// Fast-window averages, row-major: `rows` rows of `cols` values.
    pub fast: Vec<f64>,
    /// Slow-window averages, same layout as `fast`.
    pub slow: Vec<f64>,
    /// `(fast_period, slow_period)` of each row, in row order.
    pub combos: Vec<(usize, usize)>,
    /// Number of parameter combinations (rows).
    pub rows: usize,
    /// Number of input bars (columns).
    pub cols: usize,
}
1428
1429#[derive(Clone, Debug, Default)]
1430pub struct BuffAveragesBatchBuilder {
1431 range: BuffAveragesBatchRange,
1432 kernel: Kernel,
1433}
1434
1435impl BuffAveragesBatchBuilder {
1436 pub fn new() -> Self {
1437 Self::default()
1438 }
1439
1440 pub fn kernel(mut self, k: Kernel) -> Self {
1441 self.kernel = k;
1442 self
1443 }
1444
1445 #[inline]
1446 pub fn fast_period_range(mut self, start: usize, end: usize, step: usize) -> Self {
1447 self.range.fast_period = (start, end, step);
1448 self
1449 }
1450
1451 #[inline]
1452 pub fn fast_period_static(mut self, val: usize) -> Self {
1453 self.range.fast_period = (val, val, 0);
1454 self
1455 }
1456
1457 #[inline]
1458 pub fn slow_period_range(mut self, start: usize, end: usize, step: usize) -> Self {
1459 self.range.slow_period = (start, end, step);
1460 self
1461 }
1462
1463 #[inline]
1464 pub fn slow_period_static(mut self, val: usize) -> Self {
1465 self.range.slow_period = (val, val, 0);
1466 self
1467 }
1468
1469 pub fn apply_candles(
1470 self,
1471 candles: &Candles,
1472 ) -> Result<BuffAveragesBatchOutput, BuffAveragesError> {
1473 let price = source_type(candles, "close");
1474 let volume = &candles.volume;
1475 buff_averages_batch_with_kernel(price, volume, &self.range, self.kernel)
1476 }
1477
1478 pub fn apply_slices(
1479 self,
1480 price: &[f64],
1481 volume: &[f64],
1482 ) -> Result<BuffAveragesBatchOutput, BuffAveragesError> {
1483 buff_averages_batch_with_kernel(price, volume, &self.range, self.kernel)
1484 }
1485}
1486
1487fn expand_grid_ba(r: &BuffAveragesBatchRange) -> Vec<(usize, usize)> {
1488 fn axis((a, b, s): (usize, usize, usize)) -> Vec<usize> {
1489 if s == 0 || a == b {
1490 return vec![a];
1491 }
1492 let (lo, hi) = if a <= b { (a, b) } else { (b, a) };
1493 (lo..=hi).step_by(s).collect()
1494 }
1495
1496 let fasts = axis(r.fast_period);
1497 let slows = axis(r.slow_period);
1498 let mut v = Vec::with_capacity(fasts.len() * slows.len());
1499
1500 for &f in &fasts {
1501 for &s in &slows {
1502 v.push((f, s));
1503 }
1504 }
1505 v
1506}
1507
1508#[inline]
1509pub fn buff_averages_batch_inner_into(
1510 price: &[f64],
1511 volume: &[f64],
1512 sweep: &BuffAveragesBatchRange,
1513 kern: Kernel,
1514 fast_out: &mut [f64],
1515 slow_out: &mut [f64],
1516) -> Result<Vec<(usize, usize)>, BuffAveragesError> {
1517 buff_averages_batch_inner_into_parallel(price, volume, sweep, kern, fast_out, slow_out, false)
1518}
1519
1520#[inline]
1521fn buff_averages_batch_inner_into_parallel(
1522 price: &[f64],
1523 volume: &[f64],
1524 sweep: &BuffAveragesBatchRange,
1525 kern: Kernel,
1526 fast_out: &mut [f64],
1527 slow_out: &mut [f64],
1528 parallel: bool,
1529) -> Result<Vec<(usize, usize)>, BuffAveragesError> {
1530 let combos = expand_grid_ba(sweep);
1531 if combos.is_empty() {
1532 let (fs, fe, fp) = sweep.fast_period;
1533 return Err(BuffAveragesError::InvalidRange {
1534 start: fs.min(fe),
1535 end: fs.max(fe),
1536 step: fp,
1537 });
1538 }
1539
1540 if price.len() != volume.len() || price.is_empty() {
1541 return Err(BuffAveragesError::MismatchedDataLength {
1542 price_len: price.len(),
1543 volume_len: volume.len(),
1544 });
1545 }
1546
1547 let first = price
1548 .iter()
1549 .position(|x| !x.is_nan())
1550 .ok_or(BuffAveragesError::AllValuesNaN)?;
1551
1552 let max_slow = combos.iter().map(|&(_, s)| s).max().unwrap();
1553 if price.len() - first < max_slow {
1554 return Err(BuffAveragesError::NotEnoughValidData {
1555 needed: max_slow,
1556 valid: price.len() - first,
1557 });
1558 }
1559
1560 let rows = combos.len();
1561 let cols = price.len();
1562 if rows.checked_mul(cols).is_none() {
1563 return Err(BuffAveragesError::SizeOverflow { rows, cols });
1564 }
1565 let expected = rows * cols;
1566 if fast_out.len() != expected || slow_out.len() != expected {
1567 return Err(BuffAveragesError::OutputLengthMismatch {
1568 expected,
1569 got: core::cmp::min(fast_out.len(), slow_out.len()),
1570 });
1571 }
1572
1573 let fast_mu = unsafe {
1574 core::slice::from_raw_parts_mut(
1575 fast_out.as_mut_ptr() as *mut core::mem::MaybeUninit<f64>,
1576 fast_out.len(),
1577 )
1578 };
1579 let slow_mu = unsafe {
1580 core::slice::from_raw_parts_mut(
1581 slow_out.as_mut_ptr() as *mut core::mem::MaybeUninit<f64>,
1582 slow_out.len(),
1583 )
1584 };
1585
1586 let warms: Vec<usize> = combos.iter().map(|&(_, slow)| first + slow - 1).collect();
1587 init_matrix_prefixes(fast_mu, cols, &warms);
1588 init_matrix_prefixes(slow_mu, cols, &warms);
1589
1590 match kern {
1591 Kernel::Auto | Kernel::ScalarBatch | Kernel::Avx2Batch | Kernel::Avx512Batch => {}
1592 other => return Err(BuffAveragesError::InvalidKernelForBatch(other)),
1593 }
1594
1595 let simd = match match kern {
1596 Kernel::Auto => Kernel::ScalarBatch,
1597 k => k,
1598 } {
1599 Kernel::ScalarBatch => Kernel::Scalar,
1600 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1601 Kernel::Avx2Batch => Kernel::Avx2,
1602 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1603 Kernel::Avx512Batch => Kernel::Avx512,
1604 _ => Kernel::Scalar,
1605 };
1606
1607 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1608 let masked_buffers: Option<(Vec<f64>, Vec<f64>)> =
1609 if rows > 1 && matches!(simd, Kernel::Avx2 | Kernel::Avx512) {
1610 let mut pv = vec![0.0; price.len()];
1611 let mut vv = vec![0.0; price.len()];
1612 unsafe {
1613 match simd {
1614 Kernel::Avx2 => build_masked_pv_v_avx2(price, volume, &mut pv, &mut vv),
1615 Kernel::Avx512 => build_masked_pv_v_avx512(price, volume, &mut pv, &mut vv),
1616 _ => {}
1617 }
1618 }
1619 Some((pv, vv))
1620 } else {
1621 None
1622 };
1623
1624 if parallel {
1625 #[cfg(not(target_arch = "wasm32"))]
1626 {
1627 use rayon::prelude::*;
1628
1629 fast_out
1630 .par_chunks_mut(cols)
1631 .zip(slow_out.par_chunks_mut(cols))
1632 .enumerate()
1633 .for_each(|(row, (fr, sr))| {
1634 let (fp, sp) = combos[row];
1635 let handled = {
1636 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1637 {
1638 if let Some((pv, vv)) = masked_buffers.as_ref() {
1639 unsafe {
1640 match simd {
1641 Kernel::Avx2 => {
1642 buff_averages_row_avx2_from_masked(
1643 pv, vv, fp, sp, first, fr, sr,
1644 );
1645 true
1646 }
1647 Kernel::Avx512 => {
1648 buff_averages_row_avx512_from_masked(
1649 pv, vv, fp, sp, first, fr, sr,
1650 );
1651 true
1652 }
1653 _ => false,
1654 }
1655 }
1656 } else {
1657 false
1658 }
1659 }
1660 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
1661 {
1662 false
1663 }
1664 };
1665
1666 if !handled {
1667 buff_averages_compute_into(price, volume, fp, sp, first, simd, fr, sr);
1668 }
1669 });
1670 }
1671
1672 #[cfg(target_arch = "wasm32")]
1673 {
1674 for (row, &(fp, sp)) in combos.iter().enumerate() {
1675 let fr = &mut fast_out[row * cols..(row + 1) * cols];
1676 let sr = &mut slow_out[row * cols..(row + 1) * cols];
1677 let handled = {
1678 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1679 {
1680 if let Some((pv, vv)) = masked_buffers.as_ref() {
1681 unsafe {
1682 match simd {
1683 Kernel::Avx2 => {
1684 buff_averages_row_avx2_from_masked(
1685 pv, vv, fp, sp, first, fr, sr,
1686 );
1687 true
1688 }
1689 Kernel::Avx512 => {
1690 buff_averages_row_avx512_from_masked(
1691 pv, vv, fp, sp, first, fr, sr,
1692 );
1693 true
1694 }
1695 _ => false,
1696 }
1697 }
1698 } else {
1699 false
1700 }
1701 }
1702 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
1703 {
1704 false
1705 }
1706 };
1707
1708 if !handled {
1709 buff_averages_compute_into(price, volume, fp, sp, first, simd, fr, sr);
1710 }
1711 }
1712 }
1713 } else {
1714 for (row, &(fp, sp)) in combos.iter().enumerate() {
1715 let fr = &mut fast_out[row * cols..(row + 1) * cols];
1716 let sr = &mut slow_out[row * cols..(row + 1) * cols];
1717 let handled = {
1718 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1719 {
1720 if let Some((pv, vv)) = masked_buffers.as_ref() {
1721 unsafe {
1722 match simd {
1723 Kernel::Avx2 => {
1724 buff_averages_row_avx2_from_masked(
1725 pv, vv, fp, sp, first, fr, sr,
1726 );
1727 true
1728 }
1729 Kernel::Avx512 => {
1730 buff_averages_row_avx512_from_masked(
1731 pv, vv, fp, sp, first, fr, sr,
1732 );
1733 true
1734 }
1735 _ => false,
1736 }
1737 }
1738 } else {
1739 false
1740 }
1741 }
1742 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
1743 {
1744 false
1745 }
1746 };
1747
1748 if !handled {
1749 buff_averages_compute_into(price, volume, fp, sp, first, simd, fr, sr);
1750 }
1751 }
1752 }
1753
1754 Ok(combos)
1755}
1756
1757pub fn buff_averages_batch_with_kernel(
1758 price: &[f64],
1759 volume: &[f64],
1760 sweep: &BuffAveragesBatchRange,
1761 k: Kernel,
1762) -> Result<BuffAveragesBatchOutput, BuffAveragesError> {
1763 buff_averages_batch_inner(price, volume, sweep, k, false)
1764}
1765
1766#[inline(always)]
1767pub fn buff_averages_batch_par_slice(
1768 price: &[f64],
1769 volume: &[f64],
1770 sweep: &BuffAveragesBatchRange,
1771 k: Kernel,
1772) -> Result<BuffAveragesBatchOutput, BuffAveragesError> {
1773 buff_averages_batch_inner(price, volume, sweep, k, true)
1774}
1775
1776#[inline(always)]
1777fn buff_averages_batch_inner(
1778 price: &[f64],
1779 volume: &[f64],
1780 sweep: &BuffAveragesBatchRange,
1781 k: Kernel,
1782 parallel: bool,
1783) -> Result<BuffAveragesBatchOutput, BuffAveragesError> {
1784 if price.is_empty() {
1785 return Err(BuffAveragesError::EmptyInputData);
1786 }
1787 if price.len() != volume.len() {
1788 return Err(BuffAveragesError::MismatchedDataLength {
1789 price_len: price.len(),
1790 volume_len: volume.len(),
1791 });
1792 }
1793 let first = price
1794 .iter()
1795 .position(|x| !x.is_nan())
1796 .ok_or(BuffAveragesError::AllValuesNaN)?;
1797 let combos = expand_grid_ba(sweep);
1798 if combos.is_empty() {
1799 let (fs, fe, fp) = sweep.fast_period;
1800 return Err(BuffAveragesError::InvalidRange {
1801 start: fs.min(fe),
1802 end: fs.max(fe),
1803 step: fp,
1804 });
1805 }
1806 let max_slow = combos.iter().map(|&(_, s)| s).max().unwrap();
1807 if price.len() - first < max_slow {
1808 return Err(BuffAveragesError::NotEnoughValidData {
1809 needed: max_slow,
1810 valid: price.len() - first,
1811 });
1812 }
1813
1814 let rows = combos.len();
1815 let cols = price.len();
1816 if rows.checked_mul(cols).is_none() {
1817 return Err(BuffAveragesError::SizeOverflow { rows, cols });
1818 }
1819
1820 let mut fast_mu = make_uninit_matrix(rows, cols);
1821 let mut slow_mu = make_uninit_matrix(rows, cols);
1822
1823 let fast_slice =
1824 unsafe { core::slice::from_raw_parts_mut(fast_mu.as_mut_ptr() as *mut f64, fast_mu.len()) };
1825 let slow_slice =
1826 unsafe { core::slice::from_raw_parts_mut(slow_mu.as_mut_ptr() as *mut f64, slow_mu.len()) };
1827
1828 buff_averages_batch_inner_into_parallel(
1829 price, volume, sweep, k, fast_slice, slow_slice, parallel,
1830 )?;
1831
1832 let fast = unsafe {
1833 let ptr = fast_mu.as_mut_ptr() as *mut f64;
1834 let len = fast_mu.len();
1835 let cap = fast_mu.capacity();
1836 core::mem::forget(fast_mu);
1837 Vec::from_raw_parts(ptr, len, cap)
1838 };
1839 let slow = unsafe {
1840 let ptr = slow_mu.as_mut_ptr() as *mut f64;
1841 let len = slow_mu.len();
1842 let cap = slow_mu.capacity();
1843 core::mem::forget(slow_mu);
1844 Vec::from_raw_parts(ptr, len, cap)
1845 };
1846
1847 Ok(BuffAveragesBatchOutput {
1848 fast,
1849 slow,
1850 combos,
1851 rows,
1852 cols,
1853 })
1854}
1855
/// Python binding: computes Buff Averages for one price/volume series and
/// returns `(fast, slow)` NumPy arrays. Computation runs with the GIL released.
#[cfg(feature = "python")]
#[pyfunction(name = "buff_averages")]
#[pyo3(signature = (price, volume, fast_period=5, slow_period=20, kernel=None))]
pub fn buff_averages_py<'py>(
    py: Python<'py>,
    price: PyReadonlyArray1<'py, f64>,
    volume: PyReadonlyArray1<'py, f64>,
    fast_period: usize,
    slow_period: usize,
    kernel: Option<&str>,
) -> PyResult<(Bound<'py, PyArray1<f64>>, Bound<'py, PyArray1<f64>>)> {
    let price_slice = price.as_slice()?;
    let volume_slice = volume.as_slice()?;
    let kern = validate_kernel(kernel, false)?;

    let input = BuffAveragesInput::from_slices(
        price_slice,
        volume_slice,
        BuffAveragesParams {
            fast_period: Some(fast_period),
            slow_period: Some(slow_period),
        },
    );

    let output = py
        .allow_threads(|| buff_averages_with_kernel(&input, kern))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let fast = output.fast_buff.into_pyarray(py);
    let slow = output.slow_buff.into_pyarray(py);
    Ok((fast, slow))
}
1885
/// Python binding for the batch sweep.
///
/// Returns a dict with `"fast"` and `"slow"` `(rows, cols)` arrays plus
/// `"fast_periods"` / `"slow_periods"` (one entry per row). The output arrays
/// are allocated by NumPy and filled in place with the GIL released.
///
/// # Errors
/// `ValueError` if `rows * cols` overflows, or for any batch validation error.
#[cfg(feature = "python")]
#[pyfunction(name = "buff_averages_batch")]
#[pyo3(signature = (price, volume, fast_range, slow_range, kernel=None))]
pub fn buff_averages_batch_py<'py>(
    py: Python<'py>,
    price: PyReadonlyArray1<'py, f64>,
    volume: PyReadonlyArray1<'py, f64>,
    fast_range: (usize, usize, usize),
    slow_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::IntoPyArray;
    let p = price.as_slice()?;
    let v = volume.as_slice()?;
    let sweep = BuffAveragesBatchRange {
        fast_period: fast_range,
        slow_period: slow_range,
    };
    let kern = validate_kernel(kernel, true)?;

    let rows = expand_grid_ba(&sweep).len();
    let cols = p.len();
    // Guard before allocating: an unchecked product panics in debug builds and
    // would size the arrays incorrectly in release builds.
    let total = rows.checked_mul(cols).ok_or_else(|| {
        PyValueError::new_err(BuffAveragesError::SizeOverflow { rows, cols }.to_string())
    })?;
    let fast_arr = unsafe { numpy::PyArray1::<f64>::new(py, [total], false) };
    let slow_arr = unsafe { numpy::PyArray1::<f64>::new(py, [total], false) };

    // SAFETY: the arrays were just created and are not shared with Python yet;
    // the batch driver writes every element before they are exposed.
    let fast_slice = unsafe { fast_arr.as_slice_mut()? };
    let slow_slice = unsafe { slow_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            buff_averages_batch_inner_into(p, v, &sweep, kern, fast_slice, slow_slice)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let d = PyDict::new(py);
    d.set_item("fast", fast_arr.reshape((rows, cols))?)?;
    d.set_item("slow", slow_arr.reshape((rows, cols))?)?;
    d.set_item(
        "fast_periods",
        combos
            .iter()
            .map(|c| c.0 as i64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    d.set_item(
        "slow_periods",
        combos
            .iter()
            .map(|c| c.1 as i64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    Ok(d)
}
1942
/// Python binding: runs the batch sweep on CUDA device `device_id` and returns
/// the fast and slow matrices as device-resident `f32` buffers that share one
/// CUDA context.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "buff_averages_cuda_batch_dev")]
#[pyo3(signature = (price_f32, volume_f32, fast_range, slow_range, device_id=0))]
pub fn buff_averages_cuda_batch_dev_py(
    py: Python<'_>,
    price_f32: PyReadonlyArray1<'_, f32>,
    volume_f32: PyReadonlyArray1<'_, f32>,
    fast_range: (usize, usize, usize),
    slow_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<(BuffAveragesDeviceArrayF32Py, BuffAveragesDeviceArrayF32Py)> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let price = price_f32.as_slice()?;
    let volume = volume_f32.as_slice()?;
    let sweep = BuffAveragesBatchRange {
        fast_period: fast_range,
        slow_period: slow_range,
    };

    let (fast, slow, rows, cols, ctx, dev) = py.allow_threads(|| -> PyResult<_> {
        let cuda =
            CudaBuffAverages::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let (f, s) = cuda
            .buff_averages_batch_dev(price, volume, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((f.buf, s.buf, f.rows, f.cols, cuda.context_arc(), cuda.device_id()))
    })?;

    // Both handles keep the CUDA context alive for as long as their buffers exist.
    let wrap = |buf| BuffAveragesDeviceArrayF32Py {
        buf: Some(buf),
        rows,
        cols,
        _ctx: Arc::clone(&ctx),
        device_id: dev,
    };
    Ok((wrap(fast), wrap(slow)))
}
1995
/// Python binding: computes one `(fast_period, slow_period)` pair over many
/// series stored time-major (`rows` x `cols` flattened) on CUDA device
/// `device_id`, returning device-resident fast and slow `f32` buffers.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "buff_averages_cuda_many_series_one_param_dev")]
#[pyo3(signature = (prices_tm_f32, volumes_tm_f32, cols, rows, fast_period, slow_period, device_id=0))]
pub fn buff_averages_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    prices_tm_f32: PyReadonlyArray1<'_, f32>,
    volumes_tm_f32: PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    fast_period: usize,
    slow_period: usize,
    device_id: usize,
) -> PyResult<(BuffAveragesDeviceArrayF32Py, BuffAveragesDeviceArrayF32Py)> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let prices = prices_tm_f32.as_slice()?;
    let volumes = volumes_tm_f32.as_slice()?;

    let (fast, slow, out_rows, out_cols, ctx, dev) = py.allow_threads(|| -> PyResult<_> {
        let cuda =
            CudaBuffAverages::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let (f, s) = cuda
            .buff_averages_many_series_one_param_time_major_dev(
                prices,
                volumes,
                cols,
                rows,
                fast_period,
                slow_period,
            )
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((f.buf, s.buf, f.rows, f.cols, cuda.context_arc(), cuda.device_id()))
    })?;

    // Both handles keep the CUDA context alive for as long as their buffers exist.
    let wrap = |buf| BuffAveragesDeviceArrayF32Py {
        buf: Some(buf),
        rows: out_rows,
        cols: out_cols,
        _ctx: Arc::clone(&ctx),
        device_id: dev,
    };
    Ok((wrap(fast), wrap(slow)))
}
2051
/// Python wrapper around `BuffAveragesStream` (exposed as `BuffAveragesStream`).
#[cfg(feature = "python")]
#[pyclass(name = "BuffAveragesStream")]
pub struct BuffAveragesStreamPy {
    stream: BuffAveragesStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl BuffAveragesStreamPy {
    /// Creates a stream; raises `ValueError` if either period is zero.
    #[new]
    fn new(fast_period: usize, slow_period: usize) -> PyResult<Self> {
        BuffAveragesStream::try_new(BuffAveragesParams {
            fast_period: Some(fast_period),
            slow_period: Some(slow_period),
        })
        .map(|stream| Self { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feeds one bar; returns `(fast, slow)` once warmed up, else `None`.
    fn update(&mut self, price: f64, volume: f64) -> Option<(f64, f64)> {
        let Self { stream } = self;
        stream.update(price, volume)
    }
}
2076
/// Serialized result of the single-series wasm API (`buff_averages`).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct BuffAveragesJsResult {
    /// Flattened `rows x cols` matrix: the fast row followed by the slow row.
    pub values: Vec<f64>,
    /// Always 2 (fast, slow).
    pub rows: usize,
    /// Number of input bars.
    pub cols: usize,
}
2084
/// wasm binding (`buff_averages`): computes both averages and returns
/// `{ values, rows: 2, cols }`, where `values` holds the fast row followed by
/// the slow row.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = buff_averages)]
pub fn buff_averages_unified_js(
    price: &[f64],
    volume: &[f64],
    fast_period: usize,
    slow_period: usize,
) -> Result<JsValue, JsValue> {
    // `buff_averages_js` performs exactly the same preparation and computation,
    // producing the flat `[fast..., slow...]` buffer.
    let values = buff_averages_js(price, volume, fast_period, slow_period)?;

    let result = BuffAveragesJsResult {
        values,
        rows: 2,
        cols: price.len(),
    };
    serde_wasm_bindgen::to_value(&result)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2130
/// wasm binding: returns a flat `2 * len` vector holding the fast row
/// followed by the slow row.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn buff_averages_js(
    price: &[f64],
    volume: &[f64],
    fast_period: usize,
    slow_period: usize,
) -> Result<Vec<f64>, JsValue> {
    fn to_js<E: core::fmt::Display>(err: E) -> JsValue {
        JsValue::from_str(&err.to_string())
    }

    let cols = price.len();
    let input = BuffAveragesInput::from_slices(
        price,
        volume,
        BuffAveragesParams {
            fast_period: Some(fast_period),
            slow_period: Some(slow_period),
        },
    );

    let mut mat = make_uninit_matrix(2, cols);
    let (_, _, _, resolved_slow, first, _) =
        buff_averages_prepare(&input, Kernel::Auto).map_err(to_js)?;
    let warm = first + resolved_slow - 1;
    init_matrix_prefixes(&mut mat, cols, &[warm, warm]);

    // SAFETY: `MaybeUninit<f64>` and `f64` share a layout; the view covers
    // exactly `mat`'s elements and the call below fills the rows.
    unsafe {
        let flat = core::slice::from_raw_parts_mut(mat.as_mut_ptr() as *mut f64, mat.len());
        let (fast_out, slow_out) = flat.split_at_mut(cols);
        buff_averages_into_slices(fast_out, slow_out, &input, Kernel::Auto).map_err(to_js)?;
    }

    // SAFETY: all elements are initialised; the allocation is handed over
    // unchanged (same pointer, length and capacity).
    let mut mat = core::mem::ManuallyDrop::new(mat);
    Ok(unsafe { Vec::from_raw_parts(mat.as_mut_ptr() as *mut f64, mat.len(), mat.capacity()) })
}
2168
/// wasm "fast API": writes the fast row into `out_ptr[0..len]` and the slow
/// row into `out_ptr[len..2 * len]`.
///
/// The output may overlap either input (e.g. in-place use). In that case the
/// result is computed into a temporary buffer and copied out afterwards, so
/// the inputs are never read through memory that is being written.
///
/// # Errors
/// A null pointer, a `len` whose `2 * len` overflows, or any computation error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn buff_averages_into(
    price_ptr: *const f64,
    volume_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    fast_period: usize,
    slow_period: usize,
) -> Result<(), JsValue> {
    /// `true` when the `f64` ranges `[a, a + a_len)` and `[b, b + b_len)` intersect.
    fn ranges_overlap(a: *const f64, a_len: usize, b: *const f64, b_len: usize) -> bool {
        let elem = core::mem::size_of::<f64>();
        let a_start = a as usize;
        let b_start = b as usize;
        let a_end = a_start.saturating_add(a_len.saturating_mul(elem));
        let b_end = b_start.saturating_add(b_len.saturating_mul(elem));
        a_len != 0 && b_len != 0 && a_start < b_end && b_start < a_end
    }

    if price_ptr.is_null() || volume_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str(
            "null pointer passed to buff_averages_into",
        ));
    }
    let out_len = len
        .checked_mul(2)
        .ok_or_else(|| JsValue::from_str("len too large for buff_averages_into"))?;

    let aliased = ranges_overlap(out_ptr, out_len, price_ptr, len)
        || ranges_overlap(out_ptr, out_len, volume_ptr, len);

    // SAFETY (caller contract): `price_ptr` / `volume_ptr` point to `len`
    // readable f64s and `out_ptr` to `2 * len` writable f64s.
    let price = unsafe { core::slice::from_raw_parts(price_ptr, len) };
    let volume = unsafe { core::slice::from_raw_parts(volume_ptr, len) };
    let params = BuffAveragesParams {
        fast_period: Some(fast_period),
        slow_period: Some(slow_period),
    };
    let input = BuffAveragesInput::from_slices(price, volume, params);

    if aliased {
        let mut tmp = vec![f64::NAN; out_len];
        let (fast_tmp, slow_tmp) = tmp.split_at_mut(len);
        unsafe {
            buff_averages_into_slices(fast_tmp, slow_tmp, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
        // SAFETY: `tmp` is a fresh allocation and cannot overlap `out_ptr`;
        // the inputs are not read again after this point.
        unsafe { core::ptr::copy_nonoverlapping(tmp.as_ptr(), out_ptr, out_len) };
        Ok(())
    } else {
        // SAFETY: the output does not overlap either input (checked above).
        unsafe {
            let (fast_out, slow_out) =
                core::slice::from_raw_parts_mut(out_ptr, out_len).split_at_mut(len);
            buff_averages_into_slices(fast_out, slow_out, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))
        }
    }
}
2201
/// Allocates uninitialised room for `2 * len` f64 values (sized for the
/// fast + slow output of `buff_averages_into`) and returns the raw pointer.
/// Release it with `buff_averages_free(ptr, len)` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn buff_averages_alloc(len: usize) -> *mut f64 {
    let mut buf = core::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(2 * len));
    buf.as_mut_ptr()
}
2210
/// Releases a buffer obtained from `buff_averages_alloc(len)`.
///
/// `len` must be the value that was passed to `buff_averages_alloc`; the
/// allocation holds `2 * len` f64 slots. A null `ptr` is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn buff_averages_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` came from `Vec::<f64>::with_capacity(2 * len)`, whose
    // capacity is exactly `2 * len`. The length is 0 because the slots may never
    // have been initialised, and `Vec::from_raw_parts` requires the first
    // `length` elements to be initialised. `f64` has no destructor, so only the
    // allocation itself is released.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, 2 * len));
    }
}
2218
/// Serialized result of the wasm batch API (`buff_averages_batch`).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct BuffAveragesBatchJsOutput {
    /// All fast rows (row-major) followed by all slow rows.
    pub values: Vec<f64>,
    /// Total row count: twice the number of parameter combinations.
    pub rows: usize,
    /// Number of input bars.
    pub cols: usize,
    /// Fast period of each combination, in row order.
    pub fast_periods: Vec<usize>,
    /// Slow period of each combination, in row order.
    pub slow_periods: Vec<usize>,
}
2228
/// wasm binding (`buff_averages_batch`): runs the sweep described by two
/// `[start, end, step]` arrays and returns a `BuffAveragesBatchJsOutput`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = buff_averages_batch)]
pub fn buff_averages_batch_unified_js(
    price: &[f64],
    volume: &[f64],
    fast_range: Vec<usize>,
    slow_range: Vec<usize>,
) -> Result<JsValue, JsValue> {
    let (fast_period, slow_period) = match (fast_range.as_slice(), slow_range.as_slice()) {
        (&[fs, fe, fstep], &[ss, se, sstep]) => ((fs, fe, fstep), (ss, se, sstep)),
        _ => {
            return Err(JsValue::from_str(
                "fast_range and slow_range must each have 3 elements [start, end, step]",
            ))
        }
    };

    let sweep = BuffAveragesBatchRange {
        fast_period,
        slow_period,
    };
    let out = buff_averages_batch_with_kernel(price, volume, &sweep, detect_best_batch_kernel())
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    // Fast matrix first, then slow matrix, in one flat buffer.
    let values: Vec<f64> = out.fast.iter().chain(out.slow.iter()).copied().collect();
    let (fast_periods, slow_periods): (Vec<usize>, Vec<usize>) =
        out.combos.iter().copied().unzip();

    let payload = BuffAveragesBatchJsOutput {
        values,
        rows: 2 * out.rows,
        cols: out.cols,
        fast_periods,
        slow_periods,
    };
    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2266
/// wasm "fast API" for the batch sweep: writes `rows * len` fast values to
/// `out_fast_ptr` and `rows * len` slow values to `out_slow_ptr`, returning
/// `rows` (the number of parameter combinations).
///
/// Outputs may overlap the inputs. In that case the rows are computed into
/// temporaries and copied out afterwards. The two outputs must not overlap
/// each other.
///
/// # Errors
/// A null pointer, overlapping outputs, `rows * len` overflow, or any batch
/// validation error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn buff_averages_batch_into(
    price_ptr: *const f64,
    volume_ptr: *const f64,
    out_fast_ptr: *mut f64,
    out_slow_ptr: *mut f64,
    len: usize,
    fast_start: usize,
    fast_end: usize,
    fast_step: usize,
    slow_start: usize,
    slow_end: usize,
    slow_step: usize,
) -> Result<usize, JsValue> {
    /// `true` when the `f64` ranges `[a, a + a_len)` and `[b, b + b_len)` intersect.
    fn ranges_overlap(a: *const f64, a_len: usize, b: *const f64, b_len: usize) -> bool {
        let elem = core::mem::size_of::<f64>();
        let a_start = a as usize;
        let b_start = b as usize;
        let a_end = a_start.saturating_add(a_len.saturating_mul(elem));
        let b_end = b_start.saturating_add(b_len.saturating_mul(elem));
        a_len != 0 && b_len != 0 && a_start < b_end && b_start < a_end
    }

    if price_ptr.is_null()
        || volume_ptr.is_null()
        || out_fast_ptr.is_null()
        || out_slow_ptr.is_null()
    {
        return Err(JsValue::from_str(
            "null pointer passed to buff_averages_batch_into",
        ));
    }

    let sweep = BuffAveragesBatchRange {
        fast_period: (fast_start, fast_end, fast_step),
        slow_period: (slow_start, slow_end, slow_step),
    };
    let rows = expand_grid_ba(&sweep).len();
    let total = rows.checked_mul(len).ok_or_else(|| {
        JsValue::from_str(&BuffAveragesError::SizeOverflow { rows, cols: len }.to_string())
    })?;

    if ranges_overlap(out_fast_ptr, total, out_slow_ptr, total) {
        return Err(JsValue::from_str(
            "out_fast_ptr and out_slow_ptr must not overlap in buff_averages_batch_into",
        ));
    }
    let aliased = [out_fast_ptr, out_slow_ptr].iter().any(|&out| {
        ranges_overlap(out, total, price_ptr, len) || ranges_overlap(out, total, volume_ptr, len)
    });

    // SAFETY (caller contract): the inputs point to `len` readable f64s and
    // each output to `rows * len` writable f64s.
    let price = unsafe { core::slice::from_raw_parts(price_ptr, len) };
    let volume = unsafe { core::slice::from_raw_parts(volume_ptr, len) };
    let kernel = detect_best_batch_kernel();

    let combos = if aliased {
        let mut fast_tmp = vec![f64::NAN; total];
        let mut slow_tmp = vec![f64::NAN; total];
        let combos = buff_averages_batch_inner_into(
            price,
            volume,
            &sweep,
            kernel,
            &mut fast_tmp,
            &mut slow_tmp,
        )
        .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // SAFETY: the temporaries are fresh allocations and cannot overlap the
        // outputs; the inputs are not read again after this point.
        unsafe {
            core::ptr::copy_nonoverlapping(fast_tmp.as_ptr(), out_fast_ptr, total);
            core::ptr::copy_nonoverlapping(slow_tmp.as_ptr(), out_slow_ptr, total);
        }
        combos
    } else {
        // SAFETY: outputs overlap neither the inputs nor each other (checked above).
        let fast_out = unsafe { core::slice::from_raw_parts_mut(out_fast_ptr, total) };
        let slow_out = unsafe { core::slice::from_raw_parts_mut(out_slow_ptr, total) };
        buff_averages_batch_inner_into(price, volume, &sweep, kernel, fast_out, slow_out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?
    };
    Ok(combos.len())
}
2316
/// Deprecated wasm handle that stores a fast/slow period pair for repeated
/// `update_into` calls.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
#[deprecated(
    since = "1.0.0",
    note = "For weight reuse patterns, use the fast/unsafe API with persistent buffers"
)]
pub struct BuffAveragesContext {
    /// Fast window length (validated non-zero by the constructor).
    fast_period: usize,
    /// Slow window length (validated non-zero by the constructor).
    slow_period: usize,
    /// Kernel used for computation; the constructor sets `Kernel::Auto`.
    kernel: Kernel,
}
2328
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
#[allow(deprecated)]
impl BuffAveragesContext {
    /// Creates a context for the given fast and slow periods.
    ///
    /// Returns an error if either period is zero. The kernel is fixed to `Kernel::Auto`.
    #[wasm_bindgen(constructor)]
    #[deprecated(
        since = "1.0.0",
        note = "For performance patterns, use the fast/unsafe API with persistent buffers"
    )]
    pub fn new(fast_period: usize, slow_period: usize) -> Result<BuffAveragesContext, JsValue> {
        if fast_period == 0 {
            return Err(JsValue::from_str(&format!(
                "Invalid fast period: {}",
                fast_period
            )));
        }
        if slow_period == 0 {
            return Err(JsValue::from_str(&format!(
                "Invalid slow period: {}",
                slow_period
            )));
        }

        Ok(BuffAveragesContext {
            fast_period,
            slow_period,
            kernel: Kernel::Auto,
        })
    }

    /// Computes fast and slow Buff averages from raw wasm-memory buffers.
    ///
    /// All four pointers must refer to at least `len` f64 values. The caller guarantees
    /// validity; only null-ness is checked here.
    ///
    /// Errors:
    /// - `len < slow_period`;
    /// - any pointer is null;
    /// - the underlying computation fails.
    ///
    /// When any output range overlaps an input range (in-place use), results are
    /// computed into temporary buffers and copied out afterwards.
    pub fn update_into(
        &self,
        price_ptr: *const f64,
        volume_ptr: *const f64,
        fast_out_ptr: *mut f64,
        slow_out_ptr: *mut f64,
        len: usize,
    ) -> Result<(), JsValue> {
        if len < self.slow_period {
            return Err(JsValue::from_str("Data length less than slow period"));
        }

        if price_ptr.is_null()
            || volume_ptr.is_null()
            || fast_out_ptr.is_null()
            || slow_out_ptr.is_null()
        {
            return Err(JsValue::from_str("null pointer passed to update_into"));
        }

        // Byte ranges [start, end) of each buffer; catches partial overlaps as well as
        // identical start pointers.
        let bytes = len.saturating_mul(std::mem::size_of::<f64>());
        let span = |p: *const f64| {
            let start = p as usize;
            (start, start.saturating_add(bytes))
        };
        let overlaps = |a: (usize, usize), b: (usize, usize)| a.0 < b.1 && b.0 < a.1;
        let inputs = [span(price_ptr), span(volume_ptr)];
        let outputs = [
            span(fast_out_ptr as *const f64),
            span(slow_out_ptr as *const f64),
        ];
        let needs_temp = inputs
            .iter()
            .any(|&i| outputs.iter().any(|&o| overlaps(i, o)));

        let params = BuffAveragesParams {
            fast_period: Some(self.fast_period),
            slow_period: Some(self.slow_period),
        };

        // SAFETY: pointers are non-null (checked above) and the caller guarantees `len`
        // valid f64s behind each. Mutable output views are never alive at the same time
        // as shared views of overlapping input memory. In the temp path they are created
        // only after the input views' scope has ended.
        unsafe {
            if needs_temp {
                let mut temp_fast = vec![0.0; len];
                let mut temp_slow = vec![0.0; len];
                {
                    let input = BuffAveragesInput::from_slices(
                        std::slice::from_raw_parts(price_ptr, len),
                        std::slice::from_raw_parts(volume_ptr, len),
                        params,
                    );
                    buff_averages_into_slices(
                        &mut temp_fast,
                        &mut temp_slow,
                        &input,
                        self.kernel,
                    )
                    .map_err(|e| JsValue::from_str(&e.to_string()))?;
                }
                std::slice::from_raw_parts_mut(fast_out_ptr, len).copy_from_slice(&temp_fast);
                std::slice::from_raw_parts_mut(slow_out_ptr, len).copy_from_slice(&temp_slow);
            } else {
                let input = BuffAveragesInput::from_slices(
                    std::slice::from_raw_parts(price_ptr, len),
                    std::slice::from_raw_parts(volume_ptr, len),
                    params,
                );
                buff_averages_into_slices(
                    std::slice::from_raw_parts_mut(fast_out_ptr, len),
                    std::slice::from_raw_parts_mut(slow_out_ptr, len),
                    &input,
                    self.kernel,
                )
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            }
        }

        Ok(())
    }

    /// Returns `slow_period - 1`.
    ///
    /// This is derived only from the configured period; leading NaNs in any particular
    /// input are not considered. The constructor rejects `slow_period == 0`, so this
    /// cannot underflow.
    pub fn get_warmup_period(&self) -> usize {
        self.slow_period - 1
    }

    /// Computes both averages and returns them concatenated as `[fast..., slow...]`.
    #[wasm_bindgen]
    pub fn compute(&self, price: &[f64], volume: &[f64]) -> Result<Vec<f64>, JsValue> {
        let params = BuffAveragesParams {
            fast_period: Some(self.fast_period),
            slow_period: Some(self.slow_period),
        };
        let input = BuffAveragesInput::from_slices(price, volume, params);
        let result = buff_averages_with_kernel(&input, self.kernel)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

        let mut output = Vec::with_capacity(price.len() * 2);
        output.extend_from_slice(&result.fast_buff);
        output.extend_from_slice(&result.slow_buff);
        Ok(output)
    }
}
2434
#[cfg(test)]
mod tests {
    // Unit tests for Buff Averages.
    //
    // Each `check_*` helper takes a kernel and returns a `Result`. The macros at the
    // bottom generate one `#[test]` per kernel and PROPAGATE that `Result`, so a `?`
    // failure inside a helper fails the test instead of being silently discarded.
    use super::*;
    use crate::skip_if_unsupported;
    use crate::utilities::data_loader::read_candles_from_csv;
    #[cfg(feature = "proptest")]
    use proptest::prelude::*;
    use std::error::Error;

    /// Compares five recent values (ending one bar before the last) against reference
    /// values computed with default params on the 4h close series.
    fn check_buff_averages_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let input =
            BuffAveragesInput::from_candles(&candles, "close", BuffAveragesParams::default());
        let result = buff_averages_with_kernel(&input, kernel)?;

        let expected_fast = [
            58740.30855637,
            59132.28418702,
            59309.76658172,
            59266.10492431,
            59194.11908892,
        ];

        let expected_slow = [
            59209.26229392,
            59201.87047432,
            59217.15739355,
            59195.74527194,
            59196.26139533,
        ];

        let start = result.fast_buff.len().saturating_sub(6);

        for (i, (&fast_val, &slow_val)) in result.fast_buff[start..]
            .iter()
            .take(5)
            .zip(result.slow_buff[start..].iter())
            .enumerate()
        {
            let fast_diff = (fast_val - expected_fast[i]).abs();
            let slow_diff = (slow_val - expected_slow[i]).abs();
            assert!(
                fast_diff < 1e-3,
                "[{}] Buff Averages {:?} fast mismatch at idx {}: got {}, expected {}",
                test_name,
                kernel,
                i,
                fast_val,
                expected_fast[i]
            );
            assert!(
                slow_diff < 1e-3,
                "[{}] Buff Averages {:?} slow mismatch at idx {}: got {}, expected {}",
                test_name,
                kernel,
                i,
                slow_val,
                expected_slow[i]
            );
        }
        Ok(())
    }

    /// Debug-only: asserts no output value carries one of the poison bit patterns.
    #[cfg(debug_assertions)]
    fn check_buff_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;
        let out = buff_averages_with_kernel(&BuffAveragesInput::with_default_candles(&c), kernel)?;

        for (i, &v) in out.fast_buff.iter().enumerate() {
            if v.is_nan() {
                continue;
            }
            let b = v.to_bits();
            assert!(
                b != 0x11111111_11111111 && b != 0x22222222_22222222 && b != 0x33333333_33333333,
                "[{}] poison in fast at {}: {:#x}",
                test_name,
                i,
                b
            );
        }

        for (i, &v) in out.slow_buff.iter().enumerate() {
            if v.is_nan() {
                continue;
            }
            let b = v.to_bits();
            assert!(
                b != 0x11111111_11111111 && b != 0x22222222_22222222 && b != 0x33333333_33333333,
                "[{}] poison in slow at {}: {:#x}",
                test_name,
                i,
                b
            );
        }
        Ok(())
    }

    /// Asserts outputs are NaN before `first + slow_period - 1` and finite afterwards.
    fn check_buff_nan_prefix(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;
        let input = BuffAveragesInput::with_default_candles(&c);

        let (_, _, _, slow_p, first, _) = buff_averages_prepare(&input, kernel)?;
        let warm = first + slow_p - 1;

        let out = buff_averages_with_kernel(&input, kernel)?;

        assert!(
            out.fast_buff[..warm].iter().all(|x| x.is_nan()),
            "[{}] fast warmup not NaN",
            test_name
        );
        assert!(
            out.slow_buff[..warm].iter().all(|x| x.is_nan()),
            "[{}] slow warmup not NaN",
            test_name
        );
        assert!(
            out.fast_buff[warm..].iter().all(|x| x.is_finite()),
            "[{}] fast post-warm has NaN",
            test_name
        );
        assert!(
            out.slow_buff[warm..].iter().all(|x| x.is_finite()),
            "[{}] slow post-warm has NaN",
            test_name
        );
        Ok(())
    }

    /// `None` params must fall back to defaults and produce full-length outputs.
    fn check_buff_averages_partial_params(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let default_params = BuffAveragesParams {
            fast_period: None,
            slow_period: None,
        };
        let input = BuffAveragesInput::from_candles(&candles, "close", default_params);
        let output = buff_averages_with_kernel(&input, kernel)?;
        assert_eq!(output.fast_buff.len(), candles.close.len());
        assert_eq!(output.slow_buff.len(), candles.close.len());

        Ok(())
    }

    /// `with_default_candles` must select the "close" source.
    fn check_buff_averages_default_candles(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let input = BuffAveragesInput::with_default_candles(&candles);
        match input.data {
            BuffAveragesData::Candles { source, .. } => assert_eq!(source, "close"),
            _ => panic!("Expected BuffAveragesData::Candles"),
        }
        let output = buff_averages_with_kernel(&input, kernel)?;
        assert_eq!(output.fast_buff.len(), candles.close.len());
        assert_eq!(output.slow_buff.len(), candles.close.len());

        Ok(())
    }

    fn check_buff_averages_zero_period(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let input_data = [10.0, 20.0, 30.0];
        let volume_data = [100.0, 200.0, 300.0];
        let params = BuffAveragesParams {
            fast_period: Some(0),
            slow_period: Some(10),
        };
        let input = BuffAveragesInput::from_slices(&input_data, &volume_data, params);
        let res = buff_averages_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] Buff Averages should fail with zero period",
            test_name
        );
        Ok(())
    }

    fn check_buff_averages_period_exceeds_length(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let data_small = [10.0, 20.0, 30.0];
        let volume_small = [100.0, 200.0, 300.0];
        let params = BuffAveragesParams {
            fast_period: Some(5),
            slow_period: Some(10),
        };
        let input = BuffAveragesInput::from_slices(&data_small, &volume_small, params);
        let res = buff_averages_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] Buff Averages should fail with period exceeding length",
            test_name
        );
        Ok(())
    }

    fn check_buff_averages_very_small_dataset(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let single_point = [42.0];
        let single_volume = [100.0];
        let params = BuffAveragesParams::default();
        let input = BuffAveragesInput::from_slices(&single_point, &single_volume, params);
        let res = buff_averages_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] Buff Averages should fail with insufficient data",
            test_name
        );
        Ok(())
    }

    fn check_buff_averages_empty_input(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let empty: [f64; 0] = [];
        let params = BuffAveragesParams::default();
        let input = BuffAveragesInput::from_slices(&empty, &empty, params);
        let res = buff_averages_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] Buff Averages should fail with empty input",
            test_name
        );
        Ok(())
    }

    fn check_buff_averages_all_nan(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let nan_data = [f64::NAN, f64::NAN, f64::NAN];
        let nan_volume = [f64::NAN, f64::NAN, f64::NAN];
        let params = BuffAveragesParams::default();
        let input = BuffAveragesInput::from_slices(&nan_data, &nan_volume, params);
        let res = buff_averages_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] Buff Averages should fail with all NaN values",
            test_name
        );
        Ok(())
    }

    fn check_buff_averages_mismatched_lengths(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let price_data = [10.0, 20.0, 30.0];
        let volume_data = [100.0, 200.0];
        let params = BuffAveragesParams::default();
        let input = BuffAveragesInput::from_slices(&price_data, &volume_data, params);
        let res = buff_averages_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] Buff Averages should fail with mismatched data lengths",
            test_name
        );
        Ok(())
    }

    fn check_buff_averages_missing_volume(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let price_data = [
            10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0, 110.0, 120.0, 130.0,
            140.0, 150.0, 160.0, 170.0, 180.0, 190.0, 200.0,
        ];
        let params = BuffAveragesParams::default();

        let input = BuffAveragesInput {
            data: BuffAveragesData::Slice(&price_data),
            params,
            volume: None,
        };

        let res = buff_averages_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] Buff Averages should fail with missing volume data",
            test_name
        );
        Ok(())
    }

    /// A one-combination batch must match the single-series API (5/20 periods).
    fn check_buff_averages_batch_single(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let range = BuffAveragesBatchRange {
            fast_period: (5, 5, 0),
            slow_period: (20, 20, 0),
        };

        let price = source_type(&candles, "close");
        let volume = &candles.volume;

        let batch_result = buff_averages_batch_with_kernel(price, volume, &range, kernel)?;

        let single_input =
            BuffAveragesInput::from_slices(price, volume, BuffAveragesParams::default());
        let single_result = buff_averages_with_kernel(&single_input, kernel)?;

        assert_eq!(batch_result.rows, 1, "[{}] Expected 1 row", test_name);
        assert_eq!(
            batch_result.combos.len(),
            1,
            "[{}] Expected 1 combination",
            test_name
        );

        for i in 0..price.len() {
            let batch_fast = batch_result.fast[i];
            let single_fast = single_result.fast_buff[i];
            if batch_fast.is_finite() && single_fast.is_finite() {
                assert!(
                    (batch_fast - single_fast).abs() < 1e-10,
                    "[{}] Fast mismatch at {}: batch={}, single={}",
                    test_name,
                    i,
                    batch_fast,
                    single_fast
                );
            }

            let batch_slow = batch_result.slow[i];
            let single_slow = single_result.slow_buff[i];
            if batch_slow.is_finite() && single_slow.is_finite() {
                assert!(
                    (batch_slow - single_slow).abs() < 1e-10,
                    "[{}] Slow mismatch at {}: batch={}, single={}",
                    test_name,
                    i,
                    batch_slow,
                    single_slow
                );
            }
        }
        Ok(())
    }

    /// A 3x3 grid must yield 9 rows in fast-major order with correctly sized buffers.
    fn check_buff_averages_batch_grid(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let range = BuffAveragesBatchRange {
            fast_period: (3, 7, 2),
            slow_period: (18, 22, 2),
        };

        let price = source_type(&candles, "close");
        let volume = &candles.volume;

        let result = buff_averages_batch_with_kernel(price, volume, &range, kernel)?;

        assert_eq!(result.rows, 9, "[{}] Expected 9 rows", test_name);
        assert_eq!(
            result.cols,
            candles.close.len(),
            "[{}] Cols mismatch",
            test_name
        );
        assert_eq!(
            result.combos.len(),
            9,
            "[{}] Expected 9 combinations",
            test_name
        );
        assert_eq!(
            result.fast.len(),
            9 * candles.close.len(),
            "[{}] Fast size mismatch",
            test_name
        );
        assert_eq!(
            result.slow.len(),
            9 * candles.close.len(),
            "[{}] Slow size mismatch",
            test_name
        );

        let expected_combos = vec![
            (3, 18),
            (3, 20),
            (3, 22),
            (5, 18),
            (5, 20),
            (5, 22),
            (7, 18),
            (7, 20),
            (7, 22),
        ];
        assert_eq!(
            result.combos, expected_combos,
            "[{}] Combinations mismatch",
            test_name
        );

        Ok(())
    }

    fn check_buff_averages_batch_empty(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let price: [f64; 0] = [];
        let volume: [f64; 0] = [];

        let range = BuffAveragesBatchRange {
            fast_period: (5, 10, 1),
            slow_period: (15, 20, 1),
        };

        let res = buff_averages_batch_with_kernel(&price, &volume, &range, kernel);
        assert!(
            res.is_err(),
            "[{}] Batch should fail with empty input",
            test_name
        );
        Ok(())
    }

    /// Sequential and parallel batch paths must agree element-wise (including NaN positions).
    fn check_buff_averages_batch_parallel(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let range = BuffAveragesBatchRange {
            fast_period: (3, 7, 2),
            slow_period: (18, 22, 2),
        };

        let price = source_type(&candles, "close");
        let volume = &candles.volume;

        let seq_result = buff_averages_batch_with_kernel(price, volume, &range, kernel)?;

        let par_result = buff_averages_batch_par_slice(price, volume, &range, kernel)?;

        assert_eq!(
            seq_result.rows, par_result.rows,
            "[{}] Row count mismatch",
            test_name
        );
        assert_eq!(
            seq_result.cols, par_result.cols,
            "[{}] Col count mismatch",
            test_name
        );
        assert_eq!(
            seq_result.combos, par_result.combos,
            "[{}] Combos mismatch",
            test_name
        );

        for i in 0..seq_result.fast.len() {
            let seq_fast = seq_result.fast[i];
            let par_fast = par_result.fast[i];
            if seq_fast.is_finite() && par_fast.is_finite() {
                assert!(
                    (seq_fast - par_fast).abs() < 1e-10,
                    "[{}] Fast parallel mismatch at {}: seq={}, par={}",
                    test_name,
                    i,
                    seq_fast,
                    par_fast
                );
            } else {
                assert_eq!(
                    seq_fast.is_nan(),
                    par_fast.is_nan(),
                    "[{}] Fast NaN mismatch at {}",
                    test_name,
                    i
                );
            }
        }

        for i in 0..seq_result.slow.len() {
            let seq_slow = seq_result.slow[i];
            let par_slow = par_result.slow[i];
            if seq_slow.is_finite() && par_slow.is_finite() {
                assert!(
                    (seq_slow - par_slow).abs() < 1e-10,
                    "[{}] Slow parallel mismatch at {}: seq={}, par={}",
                    test_name,
                    i,
                    seq_slow,
                    par_slow
                );
            } else {
                assert_eq!(
                    seq_slow.is_nan(),
                    par_slow.is_nan(),
                    "[{}] Slow NaN mismatch at {}",
                    test_name,
                    i
                );
            }
        }

        Ok(())
    }

    #[test]
    fn test_buff_averages_stream() -> Result<(), Box<dyn Error>> {
        let params = BuffAveragesParams::default();
        let mut stream = BuffAveragesStream::try_new(params)?;

        let test_data = vec![
            (100.0, 1000.0),
            (110.0, 1100.0),
            (120.0, 1200.0),
            (130.0, 1300.0),
            (140.0, 1400.0),
            (150.0, 1500.0),
            (160.0, 1600.0),
            (170.0, 1700.0),
            (180.0, 1800.0),
            (190.0, 1900.0),
            (200.0, 2000.0),
            (210.0, 2100.0),
            (220.0, 2200.0),
            (230.0, 2300.0),
            (240.0, 2400.0),
            (250.0, 2500.0),
            (260.0, 2600.0),
            (270.0, 2700.0),
            (280.0, 2800.0),
            (290.0, 2900.0),
            (300.0, 3000.0),
        ];

        let mut results = Vec::new();
        for (price, volume) in test_data {
            if let Some(result) = stream.update(price, volume) {
                results.push(result);
            }
        }

        assert!(!results.is_empty(), "Stream should produce results");

        Ok(())
    }

    // Generates `_scalar`, `_avx2` and `_avx512` tests for each single-series check.
    // The cfg is attached to EACH AVX fn. An attribute placed before a `$(...)*`
    // repetition would only apply to the first expanded item.
    macro_rules! generate_buff_averages_tests {
        ($($test_fn:ident),*) => {
            paste::paste! {
                $(
                    #[test]
                    fn [<$test_fn _scalar>]() -> Result<(), Box<dyn Error>> {
                        $test_fn(stringify!([<$test_fn _scalar>]), Kernel::Scalar)
                    }

                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx2>]() -> Result<(), Box<dyn Error>> {
                        $test_fn(stringify!([<$test_fn _avx2>]), Kernel::Avx2)
                    }

                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx512>]() -> Result<(), Box<dyn Error>> {
                        $test_fn(stringify!([<$test_fn _avx512>]), Kernel::Avx512)
                    }
                )*
            }
        };
    }

    // Generates per-kernel tests for a batch check; failures are propagated, not dropped.
    macro_rules! gen_batch_tests {
        ($fn_name:ident) => {
            paste::paste! {
                #[test]
                fn [<$fn_name _scalar>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch)
                }

                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$fn_name _avx2>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch)
                }

                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$fn_name _avx512>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch)
                }

                #[test]
                fn [<$fn_name _auto>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _auto>]), Kernel::Auto)
                }
            }
        };
    }

    generate_buff_averages_tests!(
        check_buff_averages_accuracy,
        check_buff_averages_partial_params,
        check_buff_averages_default_candles,
        check_buff_averages_zero_period,
        check_buff_averages_period_exceeds_length,
        check_buff_averages_very_small_dataset,
        check_buff_averages_empty_input,
        check_buff_averages_all_nan,
        check_buff_averages_mismatched_lengths,
        check_buff_averages_missing_volume,
        check_buff_nan_prefix
    );

    #[cfg(debug_assertions)]
    generate_buff_averages_tests!(check_buff_no_poison);

    gen_batch_tests!(check_buff_averages_batch_single);
    gen_batch_tests!(check_buff_averages_batch_grid);
    gen_batch_tests!(check_buff_averages_batch_empty);
    gen_batch_tests!(check_buff_averages_batch_parallel);

    #[cfg(feature = "proptest")]
    proptest! {
        #[test]
        fn prop_buff_averages_length_preserved(
            len in 50usize..100,
            fast_period in 2usize..10,
            slow_period in 11usize..30
        ) {

            let data: Vec<f64> = (0..len).map(|i| (i as f64 + 1.0) * 10.0).collect();
            let volume: Vec<f64> = (0..len).map(|i| (i as f64 + 1.0) * 100.0).collect();

            prop_assume!(data.len() > slow_period);

            let params = BuffAveragesParams {
                fast_period: Some(fast_period),
                slow_period: Some(slow_period),
            };
            let input = BuffAveragesInput::from_slices(&data, &volume, params);

            if let Ok(output) = buff_averages(&input) {
                prop_assert_eq!(output.fast_buff.len(), data.len());
                prop_assert_eq!(output.slow_buff.len(), data.len());
            }
        }

        #[test]
        fn prop_buff_averages_nan_handling(
            len in 50usize..100
        ) {

            let mut data: Vec<f64> = (0..len).map(|i| (i as f64 + 1.0) * 10.0).collect();
            let mut volume: Vec<f64> = (0..len).map(|i| (i as f64 + 1.0) * 100.0).collect();

            // Sprinkle NaNs at indices 0, 10, 20, 30, 40; the call must not panic.
            for i in (0..5).map(|x| x * 10) {
                if i < data.len() {
                    data[i] = f64::NAN;
                    volume[i] = f64::NAN;
                }
            }

            let params = BuffAveragesParams::default();
            let input = BuffAveragesInput::from_slices(&data, &volume, params);

            let _ = buff_averages(&input);
        }
    }

    /// The `_into` API must produce the same values as the allocating API.
    #[test]
    fn test_buff_averages_into_matches_api() -> Result<(), Box<dyn Error>> {
        let len = 256usize;
        let price: Vec<f64> = (0..len).map(|i| (i as f64) * 1.5 + 10.0).collect();
        let volume: Vec<f64> = (0..len).map(|i| (i as f64) * 2.0 + 100.0).collect();

        let params = BuffAveragesParams::default();
        let input = BuffAveragesInput::from_slices(&price, &volume, params);

        let base = buff_averages(&input)?;

        let mut out_fast = vec![0.0; len];
        let mut out_slow = vec![0.0; len];
        super::buff_averages_into(&input, &mut out_fast, &mut out_slow)?;

        fn eq_nan_or_close(a: f64, b: f64) -> bool {
            (a.is_nan() && b.is_nan()) || (a == b) || ((a - b).abs() <= 1e-12)
        }

        assert_eq!(base.fast_buff.len(), out_fast.len());
        assert_eq!(base.slow_buff.len(), out_slow.len());
        for i in 0..len {
            assert!(
                eq_nan_or_close(base.fast_buff[i], out_fast[i]),
                "fast mismatch at {}: {} vs {}",
                i,
                base.fast_buff[i],
                out_fast[i]
            );
            assert!(
                eq_nan_or_close(base.slow_buff[i], out_slow[i]),
                "slow mismatch at {}: {} vs {}",
                i,
                base.slow_buff[i],
                out_slow[i]
            );
        }
        Ok(())
    }
}