1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::PyDict;
9
10use crate::utilities::data_loader::{source_type, Candles};
11use crate::utilities::enums::Kernel;
12use crate::utilities::helpers::{
13 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
14 make_uninit_matrix,
15};
16#[cfg(feature = "python")]
17use crate::utilities::kernel_validation::validate_kernel;
18use aligned_vec::{AVec, CACHELINE_ALIGN};
19#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
20use core::arch::x86_64::*;
21#[cfg(not(target_arch = "wasm32"))]
22use rayon::prelude::*;
23#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
24use serde::{Deserialize, Serialize};
25use std::convert::AsRef;
26use std::error::Error;
27use thiserror::Error;
28#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
29use wasm_bindgen::prelude::*;
30
/// Input source for the midprice indicator: either named series of a
/// [`Candles`] set, or a pair of caller-provided `high` / `low` slices.
#[derive(Debug, Clone)]
pub enum MidpriceData<'a> {
    /// Resolve the high and low series from `candles` via `source_type`,
    /// using `high_src` / `low_src` as the source names.
    Candles {
        candles: &'a Candles,
        high_src: &'a str,
        low_src: &'a str,
    },
    /// Use the given slices directly. Their lengths must match; the computing
    /// functions reject mismatched lengths with an error.
    Slices {
        high: &'a [f64],
        low: &'a [f64],
    },
}
43
/// Result of a single-series midprice computation.
#[derive(Debug, Clone)]
pub struct MidpriceOutput {
    /// One value per input bar. The warmup prefix (indices before
    /// `first_valid + period - 1`) is NaN.
    pub values: Vec<f64>,
}
48
/// Tunable parameters for the midprice indicator.
#[derive(Debug, Clone)]
pub struct MidpriceParams {
    /// Look-back window length in bars; `None` falls back to 14.
    pub period: Option<usize>,
}
53
54impl Default for MidpriceParams {
55 fn default() -> Self {
56 Self { period: Some(14) }
57 }
58}
59
/// A midprice computation request: where the data comes from plus the
/// parameters to apply to it.
#[derive(Debug, Clone)]
pub struct MidpriceInput<'a> {
    /// Candle-backed or slice-backed high/low series.
    pub data: MidpriceData<'a>,
    /// Indicator parameters (window length).
    pub params: MidpriceParams,
}
65
66impl<'a> MidpriceInput<'a> {
67 #[inline]
68 pub fn from_candles(
69 candles: &'a Candles,
70 high_src: &'a str,
71 low_src: &'a str,
72 params: MidpriceParams,
73 ) -> Self {
74 Self {
75 data: MidpriceData::Candles {
76 candles,
77 high_src,
78 low_src,
79 },
80 params,
81 }
82 }
83 #[inline]
84 pub fn from_slices(high: &'a [f64], low: &'a [f64], params: MidpriceParams) -> Self {
85 Self {
86 data: MidpriceData::Slices { high, low },
87 params,
88 }
89 }
90 #[inline]
91 pub fn with_default_candles(candles: &'a Candles) -> Self {
92 Self {
93 data: MidpriceData::Candles {
94 candles,
95 high_src: "high",
96 low_src: "low",
97 },
98 params: MidpriceParams::default(),
99 }
100 }
101 #[inline]
102 pub fn get_period(&self) -> usize {
103 self.params.period.unwrap_or(14)
104 }
105}
106
/// Fluent builder for single-series midprice computations and streams.
#[derive(Copy, Clone, Debug)]
pub struct MidpriceBuilder {
    /// Window length; `None` lets downstream code fall back to 14.
    period: Option<usize>,
    /// Kernel passed to `midprice_with_kernel` (`Kernel::Auto` by default).
    kernel: Kernel,
}
112
113impl Default for MidpriceBuilder {
114 fn default() -> Self {
115 Self {
116 period: None,
117 kernel: Kernel::Auto,
118 }
119 }
120}
121
122impl MidpriceBuilder {
123 #[inline(always)]
124 pub fn new() -> Self {
125 Self::default()
126 }
127 #[inline(always)]
128 pub fn period(mut self, n: usize) -> Self {
129 self.period = Some(n);
130 self
131 }
132 #[inline(always)]
133 pub fn kernel(mut self, k: Kernel) -> Self {
134 self.kernel = k;
135 self
136 }
137 #[inline(always)]
138 pub fn apply(self, c: &Candles) -> Result<MidpriceOutput, MidpriceError> {
139 let p = MidpriceParams {
140 period: self.period,
141 };
142 let i = MidpriceInput::from_candles(c, "high", "low", p);
143 midprice_with_kernel(&i, self.kernel)
144 }
145 #[inline(always)]
146 pub fn apply_slices(self, high: &[f64], low: &[f64]) -> Result<MidpriceOutput, MidpriceError> {
147 let p = MidpriceParams {
148 period: self.period,
149 };
150 let i = MidpriceInput::from_slices(high, low, p);
151 midprice_with_kernel(&i, self.kernel)
152 }
153 #[inline(always)]
154 pub fn into_stream(self) -> Result<MidpriceStream, MidpriceError> {
155 let p = MidpriceParams {
156 period: self.period,
157 };
158 MidpriceStream::try_new(p)
159 }
160}
161
/// Errors returned by the midprice functions in this module.
#[derive(Debug, Error)]
pub enum MidpriceError {
    /// `high` or `low` contains no elements.
    #[error("midprice: Empty data provided.")]
    EmptyInputData,
    /// No index has both a non-NaN high and a non-NaN low.
    #[error("midprice: All values are NaN.")]
    AllValuesNaN,
    /// `period` is 0 or larger than the number of bars.
    #[error("midprice: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },
    /// Fewer than `period` bars remain after the first valid bar.
    #[error("midprice: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// A caller-provided output buffer has the wrong length.
    #[error("midprice: Output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// A batch `(start, end, step)` sweep could not be expanded into any period.
    #[error("midprice: Invalid range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },
    /// A non-batch kernel (other than `Kernel::Auto`) was passed to a batch entry point.
    #[error("midprice: Invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(Kernel),
    /// `high` and `low` have different lengths.
    #[error("midprice: Mismatched data length: high_len = {high_len}, low_len = {low_len}")]
    MismatchedDataLength { high_len: usize, low_len: usize },
    /// Any other invalid input, e.g. a `rows * cols` batch size that overflows.
    #[error("midprice: invalid input: {0}")]
    InvalidInput(&'static str),
}
187
188#[inline]
189pub fn midprice(input: &MidpriceInput) -> Result<MidpriceOutput, MidpriceError> {
190 midprice_with_kernel(input, Kernel::Auto)
191}
192
193pub fn midprice_with_kernel(
194 input: &MidpriceInput,
195 kernel: Kernel,
196) -> Result<MidpriceOutput, MidpriceError> {
197 let (high, low) = match &input.data {
198 MidpriceData::Candles {
199 candles,
200 high_src,
201 low_src,
202 } => {
203 let h = source_type(candles, high_src);
204 let l = source_type(candles, low_src);
205 (h, l)
206 }
207 MidpriceData::Slices { high, low } => (*high, *low),
208 };
209
210 if high.is_empty() || low.is_empty() {
211 return Err(MidpriceError::EmptyInputData);
212 }
213 if high.len() != low.len() {
214 return Err(MidpriceError::MismatchedDataLength {
215 high_len: high.len(),
216 low_len: low.len(),
217 });
218 }
219 let period = input.get_period();
220 if period == 0 || period > high.len() {
221 return Err(MidpriceError::InvalidPeriod {
222 period,
223 data_len: high.len(),
224 });
225 }
226 let first_valid_idx = match (0..high.len()).find(|&i| !high[i].is_nan() && !low[i].is_nan()) {
227 Some(idx) => idx,
228 None => return Err(MidpriceError::AllValuesNaN),
229 };
230 if (high.len() - first_valid_idx) < period {
231 return Err(MidpriceError::NotEnoughValidData {
232 needed: period,
233 valid: high.len() - first_valid_idx,
234 });
235 }
236 let mut out = alloc_with_nan_prefix(high.len(), first_valid_idx + period - 1);
237 let chosen = match kernel {
238 Kernel::Auto => detect_best_kernel(),
239 other => other,
240 };
241 unsafe {
242 match chosen {
243 Kernel::Scalar | Kernel::ScalarBatch => {
244 midprice_scalar(high, low, period, first_valid_idx, &mut out)
245 }
246 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
247 Kernel::Avx2 | Kernel::Avx2Batch => {
248 midprice_avx2(high, low, period, first_valid_idx, &mut out)
249 }
250 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
251 Kernel::Avx512 | Kernel::Avx512Batch => {
252 midprice_avx512(high, low, period, first_valid_idx, &mut out)
253 }
254 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
255 Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
256 midprice_scalar(high, low, period, first_valid_idx, &mut out)
257 }
258 Kernel::Auto => midprice_scalar(high, low, period, first_valid_idx, &mut out),
259 }
260 }
261 Ok(MidpriceOutput { values: out })
262}
263
264#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
265pub fn midprice_into(input: &MidpriceInput, out: &mut [f64]) -> Result<(), MidpriceError> {
266 let (high, low) = match &input.data {
267 MidpriceData::Candles {
268 candles,
269 high_src,
270 low_src,
271 } => {
272 let h = source_type(candles, high_src);
273 let l = source_type(candles, low_src);
274 (h, l)
275 }
276 MidpriceData::Slices { high, low } => (*high, *low),
277 };
278
279 if out.len() != high.len() || high.len() != low.len() {
280 return Err(MidpriceError::OutputLengthMismatch {
281 expected: high.len(),
282 got: out.len(),
283 });
284 }
285
286 midprice_into_slice(out, high, low, &input.params, Kernel::Auto)
287}
288
/// Scalar single-series midprice kernel.
///
/// For every index `i` in `first_valid_idx + period - 1 .. high.len()`, this
/// writes `(max(high[i + 1 - period..=i]) + min(low[i + 1 - period..=i])) / 2`
/// into `out[i]`. Entries of `out` before that warmup index are not touched;
/// callers pre-fill them (e.g. with NaN).
///
/// NaN samples inside a window are ignored. A window with no non-NaN high (or
/// low) produces `-inf + inf`, i.e. NaN.
///
/// Windows of up to 64 bars are rescanned directly, which is cheap for short
/// periods. Longer windows use monotonic deques so the total work is O(n).
///
/// # Panics
/// Panics if `period == 0`, or if `low` or `out` is shorter than `high`.
#[inline]
pub fn midprice_scalar(
    high: &[f64],
    low: &[f64],
    period: usize,
    first_valid_idx: usize,
    out: &mut [f64],
) {
    let len = high.len();
    assert!(period > 0, "midprice_scalar: period must be greater than zero");
    assert!(
        low.len() >= len && out.len() >= len,
        "midprice_scalar: `low` and `out` must be at least as long as `high`"
    );
    let warmup_end = first_valid_idx + period - 1;

    if period <= 64 {
        // Short windows: rescan each window. `h > acc` / `l < acc` are false
        // for NaN, so a NaN sample never replaces the running extreme.
        for i in warmup_end..len {
            let start = i + 1 - period;
            let highest = high[start..=i]
                .iter()
                .fold(f64::NEG_INFINITY, |acc, &h| if h > acc { h } else { acc });
            let lowest = low[start..=i]
                .iter()
                .fold(f64::INFINITY, |acc, &l| if l < acc { l } else { acc });
            out[i] = (highest + lowest) * 0.5;
        }
        return;
    }

    // Long windows: monotonic deques of indices. `dq_high` keeps highs strictly
    // decreasing front-to-back and `dq_low` keeps lows strictly increasing, so
    // each front holds the current window extreme.
    let mut dq_high: std::collections::VecDeque<usize> =
        std::collections::VecDeque::with_capacity(period + 1);
    let mut dq_low: std::collections::VecDeque<usize> =
        std::collections::VecDeque::with_capacity(period + 1);

    for i in first_valid_idx..len {
        let hv = high[i];
        if !hv.is_nan() {
            while let Some(&j) = dq_high.back() {
                if high[j] <= hv {
                    dq_high.pop_back();
                } else {
                    break;
                }
            }
            dq_high.push_back(i);
        }

        let lv = low[i];
        if !lv.is_nan() {
            while let Some(&j) = dq_low.back() {
                if low[j] >= lv {
                    dq_low.pop_back();
                } else {
                    break;
                }
            }
            dq_low.push_back(i);
        }

        if i < warmup_end {
            continue;
        }

        // Evict indices that slid out of the window [i + 1 - period, i].
        let start = i + 1 - period;
        while let Some(&j) = dq_high.front() {
            if j < start {
                dq_high.pop_front();
            } else {
                break;
            }
        }
        while let Some(&j) = dq_low.front() {
            if j < start {
                dq_low.pop_front();
            } else {
                break;
            }
        }

        let highest = dq_high
            .front()
            .map(|&j| high[j])
            .unwrap_or(f64::NEG_INFINITY);
        let lowest = dq_low.front().map(|&j| low[j]).unwrap_or(f64::INFINITY);
        out[i] = (highest + lowest) * 0.5;
    }
}
385
/// AVX2 single-series kernel.
///
/// There is no hand-vectorized AVX2 implementation yet. This forwards to
/// [`midprice_scalar`] and has exactly its semantics.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn midprice_avx2(
    high: &[f64],
    low: &[f64],
    period: usize,
    first_valid_idx: usize,
    out: &mut [f64],
) {
    midprice_scalar(high, low, period, first_valid_idx, out)
}

/// AVX-512 single-series kernel dispatcher.
///
/// Windows longer than 32 bars go to `midprice_avx512_long`, and the rest go
/// to `midprice_avx512_short`. Both currently forward to [`midprice_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn midprice_avx512(
    high: &[f64],
    low: &[f64],
    period: usize,
    first_valid_idx: usize,
    out: &mut [f64],
) {
    let use_long_variant = period > 32;
    // SAFETY: both variants only forward their arguments to `midprice_scalar`.
    unsafe {
        if use_long_variant {
            midprice_avx512_long(high, low, period, first_valid_idx, out)
        } else {
            midprice_avx512_short(high, low, period, first_valid_idx, out)
        }
    }
}

/// AVX-512 kernel for windows of at most 32 bars (currently scalar).
///
/// # Safety
/// Forwards directly to [`midprice_scalar`]. Callers should pass `period > 0`
/// and `low` / `out` at least as long as `high`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn midprice_avx512_short(
    high: &[f64],
    low: &[f64],
    period: usize,
    first_valid_idx: usize,
    out: &mut [f64],
) {
    midprice_scalar(high, low, period, first_valid_idx, out)
}

/// AVX-512 kernel for windows longer than 32 bars (currently scalar).
///
/// # Safety
/// Forwards directly to [`midprice_scalar`]. Callers should pass `period > 0`
/// and `low` / `out` at least as long as `high`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn midprice_avx512_long(
    high: &[f64],
    low: &[f64],
    period: usize,
    first_valid_idx: usize,
    out: &mut [f64],
) {
    midprice_scalar(high, low, period, first_valid_idx, out)
}
437
/// Incremental (bar-by-bar) midprice calculator backed by monotonic deques.
#[derive(Debug, Clone)]
pub struct MidpriceStream {
    /// Window length in bars (always > 0; enforced by `try_new`).
    period: usize,

    /// Becomes `true` at the first bar whose high and low are both finite;
    /// bars before that are ignored entirely.
    started: bool,

    /// Number of bars consumed since `started`; also the index assigned to the
    /// next incoming bar.
    seen: usize,

    /// `(bar index, high)` pairs with highs strictly decreasing front-to-back;
    /// the front is the maximum high of the current window.
    dq_high: std::collections::VecDeque<(usize, f64)>,
    /// `(bar index, low)` pairs with lows strictly increasing front-to-back;
    /// the front is the minimum low of the current window.
    dq_low: std::collections::VecDeque<(usize, f64)>,
}
449
450impl MidpriceStream {
451 pub fn try_new(params: MidpriceParams) -> Result<Self, MidpriceError> {
452 let period = params.period.unwrap_or(14);
453 if period == 0 {
454 return Err(MidpriceError::InvalidPeriod {
455 period,
456 data_len: 0,
457 });
458 }
459 let cap = period + 1;
460 Ok(Self {
461 period,
462 started: false,
463 seen: 0,
464 dq_high: std::collections::VecDeque::with_capacity(cap),
465 dq_low: std::collections::VecDeque::with_capacity(cap),
466 })
467 }
468
469 #[inline(always)]
470 pub fn update(&mut self, high: f64, low: f64) -> Option<f64> {
471 if !self.started {
472 if !(high.is_finite() && low.is_finite()) {
473 return None;
474 }
475 self.started = true;
476 self.seen = 0;
477 }
478
479 let i = self.seen;
480
481 if high.is_finite() {
482 while let Some(&(_, v)) = self.dq_high.back() {
483 if v <= high {
484 self.dq_high.pop_back();
485 } else {
486 break;
487 }
488 }
489 self.dq_high.push_back((i, high));
490 }
491
492 if low.is_finite() {
493 while let Some(&(_, v)) = self.dq_low.back() {
494 if v >= low {
495 self.dq_low.pop_back();
496 } else {
497 break;
498 }
499 }
500 self.dq_low.push_back((i, low));
501 }
502
503 let start = i.saturating_add(1).saturating_sub(self.period);
504 while let Some(&(idx, _)) = self.dq_high.front() {
505 if idx < start {
506 self.dq_high.pop_front();
507 } else {
508 break;
509 }
510 }
511 while let Some(&(idx, _)) = self.dq_low.front() {
512 if idx < start {
513 self.dq_low.pop_front();
514 } else {
515 break;
516 }
517 }
518
519 self.seen = i + 1;
520
521 if self.seen < self.period {
522 return None;
523 }
524
525 Some(self.calc())
526 }
527
528 #[inline(always)]
529 fn calc(&self) -> f64 {
530 let max_h = self
531 .dq_high
532 .front()
533 .map(|&(_, v)| v)
534 .unwrap_or(f64::NEG_INFINITY);
535 let min_l = self
536 .dq_low
537 .front()
538 .map(|&(_, v)| v)
539 .unwrap_or(f64::INFINITY);
540 (max_h + min_l) / 2.0
541 }
542}
543
/// Parameter sweep for batch midprice computation.
#[derive(Clone, Debug)]
pub struct MidpriceBatchRange {
    /// Inclusive `(start, end, step)` sweep of window lengths. A `step` of 0,
    /// or `start == end`, yields the single period `start`. `start > end`
    /// counts downwards.
    pub period: (usize, usize, usize),
}
548
549impl Default for MidpriceBatchRange {
550 fn default() -> Self {
551 Self {
552 period: (14, 263, 1),
553 }
554 }
555}
556
/// Fluent builder for batch (parameter-sweep) midprice computations.
///
/// `Default` uses `MidpriceBatchRange::default()` and `Kernel::default()`.
#[derive(Clone, Debug, Default)]
pub struct MidpriceBatchBuilder {
    /// Period sweep to evaluate.
    range: MidpriceBatchRange,
    /// Batch kernel passed to `midprice_batch_with_kernel`.
    kernel: Kernel,
}
562
563impl MidpriceBatchBuilder {
564 pub fn new() -> Self {
565 Self::default()
566 }
567 pub fn kernel(mut self, k: Kernel) -> Self {
568 self.kernel = k;
569 self
570 }
571 #[inline]
572 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
573 self.range.period = (start, end, step);
574 self
575 }
576 #[inline]
577 pub fn period_static(mut self, p: usize) -> Self {
578 self.range.period = (p, p, 0);
579 self
580 }
581 pub fn apply_slices(
582 self,
583 high: &[f64],
584 low: &[f64],
585 ) -> Result<MidpriceBatchOutput, MidpriceError> {
586 midprice_batch_with_kernel(high, low, &self.range, self.kernel)
587 }
588 pub fn apply_candles(
589 self,
590 c: &Candles,
591 high_src: &str,
592 low_src: &str,
593 ) -> Result<MidpriceBatchOutput, MidpriceError> {
594 let high = source_type(c, high_src);
595 let low = source_type(c, low_src);
596 self.apply_slices(high, low)
597 }
598 pub fn with_default_candles(c: &Candles) -> Result<MidpriceBatchOutput, MidpriceError> {
599 MidpriceBatchBuilder::new()
600 .kernel(Kernel::Auto)
601 .apply_candles(c, "high", "low")
602 }
603}
604
605pub fn midprice_batch_with_kernel(
606 high: &[f64],
607 low: &[f64],
608 sweep: &MidpriceBatchRange,
609 k: Kernel,
610) -> Result<MidpriceBatchOutput, MidpriceError> {
611 let kernel = match k {
612 Kernel::Auto => detect_best_batch_kernel(),
613 other if other.is_batch() => other,
614 _ => return Err(MidpriceError::InvalidKernelForBatch(k)),
615 };
616 let simd = match kernel {
617 Kernel::Avx512Batch => Kernel::Avx512,
618 Kernel::Avx2Batch => Kernel::Avx2,
619 Kernel::ScalarBatch => Kernel::Scalar,
620 _ => Kernel::Scalar,
621 };
622 midprice_batch_par_slice(high, low, sweep, simd)
623}
624
/// Result of a batch midprice sweep.
#[derive(Clone, Debug)]
pub struct MidpriceBatchOutput {
    /// Row-major `rows x cols` matrix; row `r` occupies
    /// `values[r * cols..(r + 1) * cols]`.
    pub values: Vec<f64>,
    /// Parameters for each row, in row order.
    pub combos: Vec<MidpriceParams>,
    /// Number of rows (one per swept period).
    pub rows: usize,
    /// Number of columns (input series length).
    pub cols: usize,
}
632impl MidpriceBatchOutput {
633 pub fn row_for_params(&self, p: &MidpriceParams) -> Option<usize> {
634 self.combos
635 .iter()
636 .position(|c| c.period.unwrap_or(14) == p.period.unwrap_or(14))
637 }
638 pub fn values_for(&self, p: &MidpriceParams) -> Option<&[f64]> {
639 self.row_for_params(p).and_then(|row| {
640 row.checked_mul(self.cols)
641 .map(|start| &self.values[start..start + self.cols])
642 })
643 }
644}
645
646#[inline(always)]
647fn expand_grid(r: &MidpriceBatchRange) -> Result<Vec<MidpriceParams>, MidpriceError> {
648 fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, MidpriceError> {
649 if step == 0 || start == end {
650 return Ok(vec![start]);
651 }
652 let mut v = Vec::new();
653 let s = step.max(1);
654 if start < end {
655 let mut cur = start;
656 while cur <= end {
657 v.push(cur);
658 let next = cur.saturating_add(s);
659 if next == cur {
660 break;
661 }
662 cur = next;
663 }
664 } else {
665 let mut cur = start;
666 while cur >= end {
667 v.push(cur);
668 let next = cur.saturating_sub(s);
669 if next == cur {
670 break;
671 }
672 cur = next;
673 if cur == 0 && end > 0 {
674 break;
675 }
676 }
677 v.retain(|&x| x >= end);
678 }
679 if v.is_empty() {
680 return Err(MidpriceError::InvalidRange { start, end, step });
681 }
682 Ok(v)
683 }
684 let periods = axis_usize(r.period)?;
685 let mut out = Vec::with_capacity(periods.len());
686 for &p in &periods {
687 out.push(MidpriceParams { period: Some(p) });
688 }
689 Ok(out)
690}
691
692#[inline(always)]
693fn validate_batch_periods(combos: &[MidpriceParams], cols: usize) -> Result<usize, MidpriceError> {
694 let mut max_p = 0usize;
695 for c in combos {
696 let p = c.period.unwrap_or(14);
697 if p == 0 || p > cols {
698 return Err(MidpriceError::InvalidPeriod {
699 period: p,
700 data_len: cols,
701 });
702 }
703 if p > max_p {
704 max_p = p;
705 }
706 }
707 Ok(max_p)
708}
709
710#[inline(always)]
711fn batch_warm_prefixes(
712 combos: &[MidpriceParams],
713 first: usize,
714 cols: usize,
715) -> Result<(Vec<usize>, usize), MidpriceError> {
716 let mut max_p = 0usize;
717 let mut warm = Vec::with_capacity(combos.len());
718 for c in combos {
719 let p = c.period.unwrap_or(14);
720 if p == 0 || p > cols {
721 return Err(MidpriceError::InvalidPeriod {
722 period: p,
723 data_len: cols,
724 });
725 }
726 if p > max_p {
727 max_p = p;
728 }
729 warm.push(first + p - 1);
730 }
731 if cols - first < max_p {
732 return Err(MidpriceError::NotEnoughValidData {
733 needed: max_p,
734 valid: cols - first,
735 });
736 }
737 Ok((warm, max_p))
738}
739
740#[inline(always)]
741pub fn midprice_batch_slice(
742 high: &[f64],
743 low: &[f64],
744 sweep: &MidpriceBatchRange,
745 kern: Kernel,
746) -> Result<MidpriceBatchOutput, MidpriceError> {
747 midprice_batch_inner(high, low, sweep, kern, false)
748}
749
750#[inline(always)]
751pub fn midprice_batch_par_slice(
752 high: &[f64],
753 low: &[f64],
754 sweep: &MidpriceBatchRange,
755 kern: Kernel,
756) -> Result<MidpriceBatchOutput, MidpriceError> {
757 midprice_batch_inner(high, low, sweep, kern, true)
758}
759
/// Fills a caller-provided `rows x cols` buffer with one midprice row per
/// period in `sweep`, and returns the expanded parameter combinations.
///
/// Only cells from each row's `first + period - 1` onwards are written; the
/// warmup prefix of every row is left as-is. Callers must initialize it first
/// (the visible callers `midprice_batch_inner` and `midprice_batch_py` use
/// `init_matrix_prefixes`).
///
/// `kern` selects the per-row kernel. AVX variants fall back to the scalar row
/// kernel unless `nightly-avx` is enabled on x86_64, and `Kernel::Auto` is
/// treated as scalar here. When `parallel` is set, rows are processed with
/// rayon on non-wasm32 targets and serially on wasm32.
///
/// # Errors
/// Invalid sweep, empty or mismatched inputs, all-NaN input, invalid periods,
/// insufficient valid data, or an `out` length other than `rows * cols`.
#[inline(always)]
fn midprice_batch_inner_into(
    high: &[f64],
    low: &[f64],
    sweep: &MidpriceBatchRange,
    kern: Kernel,
    parallel: bool,
    out: &mut [f64],
) -> Result<Vec<MidpriceParams>, MidpriceError> {
    let combos = expand_grid(sweep)?;
    if high.is_empty() || low.is_empty() {
        return Err(MidpriceError::EmptyInputData);
    }
    if high.len() != low.len() {
        return Err(MidpriceError::MismatchedDataLength {
            high_len: high.len(),
            low_len: low.len(),
        });
    }
    // First bar where both series are non-NaN; all rows share it.
    let first = (0..high.len())
        .find(|&i| !high[i].is_nan() && !low[i].is_nan())
        .ok_or(MidpriceError::AllValuesNaN)?;
    let cols = high.len();
    let max_p = validate_batch_periods(&combos, cols)?;
    // The longest window must fit in the valid tail, otherwise that row would
    // have no outputs at all.
    if cols - first < max_p {
        return Err(MidpriceError::NotEnoughValidData {
            needed: max_p,
            valid: cols - first,
        });
    }
    let rows = combos.len();
    let expected = rows
        .checked_mul(cols)
        .ok_or(MidpriceError::InvalidInput("rows*cols overflow"))?;
    if out.len() != expected {
        return Err(MidpriceError::OutputLengthMismatch {
            expected,
            got: out.len(),
        });
    }

    // Computes row `row` (period `combos[row]`) into its `cols`-long slice.
    // The row kernels are `unsafe fn`s; `out_row` always has `cols` elements,
    // the same length as `high` and `low`.
    let do_row = |row: usize, out_row: &mut [f64]| unsafe {
        let period = combos[row].period.unwrap();
        match kern {
            Kernel::Scalar | Kernel::ScalarBatch => {
                midprice_row_scalar(high, low, first, period, out_row)
            }
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx2 | Kernel::Avx2Batch => {
                midprice_row_avx2(high, low, first, period, out_row)
            }
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx512 | Kernel::Avx512Batch => {
                midprice_row_avx512(high, low, first, period, out_row)
            }
            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
                midprice_row_scalar(high, low, first, period, out_row)
            }
            Kernel::Auto => midprice_row_scalar(high, low, first, period, out_row),
        }
    };

    if parallel {
        #[cfg(not(target_arch = "wasm32"))]
        {
            out.par_chunks_mut(cols)
                .enumerate()
                .for_each(|(row, slice)| do_row(row, slice));
        }

        // rayon is not imported on wasm32, so "parallel" degrades to serial.
        #[cfg(target_arch = "wasm32")]
        {
            for (row, slice) in out.chunks_mut(cols).enumerate() {
                do_row(row, slice);
            }
        }
    } else {
        for (row, slice) in out.chunks_mut(cols).enumerate() {
            do_row(row, slice);
        }
    }

    Ok(combos)
}
845
846#[inline(always)]
847fn midprice_batch_inner(
848 high: &[f64],
849 low: &[f64],
850 sweep: &MidpriceBatchRange,
851 kern: Kernel,
852 parallel: bool,
853) -> Result<MidpriceBatchOutput, MidpriceError> {
854 let combos = expand_grid(sweep)?;
855 if high.is_empty() || low.is_empty() {
856 return Err(MidpriceError::EmptyInputData);
857 }
858 if high.len() != low.len() {
859 return Err(MidpriceError::MismatchedDataLength {
860 high_len: high.len(),
861 low_len: low.len(),
862 });
863 }
864
865 let first = (0..high.len())
866 .find(|&i| !high[i].is_nan() && !low[i].is_nan())
867 .ok_or(MidpriceError::AllValuesNaN)?;
868 let cols = high.len();
869 let (warm, _max_p) = batch_warm_prefixes(&combos, first, cols)?;
870
871 let rows = combos.len();
872 let total_cells = rows
873 .checked_mul(cols)
874 .ok_or(MidpriceError::InvalidInput("rows*cols overflow"))?;
875
876 let mut buf_mu = make_uninit_matrix(rows, cols);
877
878 init_matrix_prefixes(&mut buf_mu, cols, &warm);
879
880 let mut guard = core::mem::ManuallyDrop::new(buf_mu);
881 let out: &mut [f64] =
882 unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len()) };
883 let combos = midprice_batch_inner_into(high, low, sweep, kern, parallel, out)?;
884
885 let values = unsafe {
886 Vec::from_raw_parts(
887 guard.as_mut_ptr() as *mut f64,
888 total_cells,
889 guard.capacity(),
890 )
891 };
892
893 Ok(MidpriceBatchOutput {
894 values,
895 combos,
896 rows,
897 cols,
898 })
899}
900
901#[inline(always)]
902pub unsafe fn midprice_row_scalar(
903 high: &[f64],
904 low: &[f64],
905 first: usize,
906 period: usize,
907 out: &mut [f64],
908) {
909 if period <= 64 {
910 let len = high.len();
911 let warmup_end = first + period - 1;
912 let high_ptr = high.as_ptr();
913 let low_ptr = low.as_ptr();
914 let out_ptr = out.as_mut_ptr();
915 for i in warmup_end..len {
916 let window_start = i + 1 - period;
917 let mut highest = f64::NEG_INFINITY;
918 let mut lowest = f64::INFINITY;
919 let mut j = window_start;
920 while j <= i {
921 let h = *high_ptr.add(j);
922 if h > highest {
923 highest = h;
924 }
925 let l = *low_ptr.add(j);
926 if l < lowest {
927 lowest = l;
928 }
929 j += 1;
930 }
931 *out_ptr.add(i) = (highest + lowest) * 0.5;
932 }
933 return;
934 }
935
936 let warmup_end = first + period - 1;
937 let mut dq_high: std::collections::VecDeque<usize> =
938 std::collections::VecDeque::with_capacity(period + 1);
939 let mut dq_low: std::collections::VecDeque<usize> =
940 std::collections::VecDeque::with_capacity(period + 1);
941
942 for i in first..high.len() {
943 let hv = high[i];
944 if !hv.is_nan() {
945 while let Some(&j) = dq_high.back() {
946 if high[j] <= hv {
947 dq_high.pop_back();
948 } else {
949 break;
950 }
951 }
952 dq_high.push_back(i);
953 }
954
955 let lv = low[i];
956 if !lv.is_nan() {
957 while let Some(&j) = dq_low.back() {
958 if low[j] >= lv {
959 dq_low.pop_back();
960 } else {
961 break;
962 }
963 }
964 dq_low.push_back(i);
965 }
966
967 if i < warmup_end {
968 continue;
969 }
970
971 let start = i + 1 - period;
972 while let Some(&j) = dq_high.front() {
973 if j < start {
974 dq_high.pop_front();
975 } else {
976 break;
977 }
978 }
979 while let Some(&j) = dq_low.front() {
980 if j < start {
981 dq_low.pop_front();
982 } else {
983 break;
984 }
985 }
986
987 let highest = dq_high
988 .front()
989 .map(|&j| high[j])
990 .unwrap_or(f64::NEG_INFINITY);
991 let lowest = dq_low.front().map(|&j| low[j]).unwrap_or(f64::INFINITY);
992 out[i] = (highest + lowest) * 0.5;
993 }
994}
995
/// AVX2 batch-row kernel.
///
/// No dedicated AVX2 implementation exists yet; this forwards to
/// `midprice_row_scalar`.
///
/// # Safety
/// Same contract as `midprice_row_scalar`: `period > 0` and `low` / `out` at
/// least as long as `high`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn midprice_row_avx2(
    high: &[f64],
    low: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    midprice_row_scalar(high, low, first, period, out)
}

/// AVX-512 batch-row dispatcher.
///
/// Windows longer than 32 bars go to the "long" variant, and the rest go to
/// the "short" variant.
///
/// # Safety
/// Same contract as `midprice_row_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn midprice_row_avx512(
    high: &[f64],
    low: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    let use_long_variant = period > 32;
    if use_long_variant {
        midprice_row_avx512_long(high, low, first, period, out)
    } else {
        midprice_row_avx512_short(high, low, first, period, out)
    }
}

/// AVX-512 batch-row kernel for windows of at most 32 bars (currently scalar).
///
/// # Safety
/// Same contract as `midprice_row_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn midprice_row_avx512_short(
    high: &[f64],
    low: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    midprice_row_scalar(high, low, first, period, out)
}

/// AVX-512 batch-row kernel for windows longer than 32 bars (currently scalar).
///
/// # Safety
/// Same contract as `midprice_row_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn midprice_row_avx512_long(
    high: &[f64],
    low: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    midprice_row_scalar(high, low, first, period, out)
}
1047
/// Python binding: computes the midprice of `high` / `low` over `period`.
///
/// `kernel` is validated with `validate_kernel(kernel, false)`. The
/// computation runs with the GIL released. Any `MidpriceError` is raised as
/// `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "midprice")]
#[pyo3(signature = (high, low, period, kernel=None))]
pub fn midprice_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let (high_slice, low_slice) = (high.as_slice()?, low.as_slice()?);
    let kern = validate_kernel(kernel, false)?;

    let input = MidpriceInput::from_slices(
        high_slice,
        low_slice,
        MidpriceParams {
            period: Some(period),
        },
    );

    let values = py
        .allow_threads(|| midprice_with_kernel(&input, kern))
        .map(|output| output.values)
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok(values.into_pyarray(py))
}
1073
/// Python binding: computes the midprice for every period in `period_range`.
///
/// Returns a dict with `"values"` (a `rows x cols` float64 array, one row per
/// period) and `"periods"` (the swept periods as a uint64 array).
/// `period_range` is an inclusive `(start, end, step)` sweep. `kernel` is
/// validated with `validate_kernel(kernel, true)`. Errors (empty or
/// mismatched input, all-NaN input, invalid sweep or period, insufficient
/// valid data, size overflow) are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "midprice_batch")]
#[pyo3(signature = (high, low, period_range, kernel=None))]
pub fn midprice_batch_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let high_slice = high.as_slice()?;
    let low_slice = low.as_slice()?;
    let kern = validate_kernel(kernel, true)?;
    let sweep = MidpriceBatchRange {
        period: period_range,
    };

    let cols = high_slice.len();
    if cols == 0 {
        return Err(PyValueError::new_err("midprice: empty data"));
    }
    if cols != low_slice.len() {
        return Err(PyValueError::new_err(format!(
            "midprice: length mismatch: high={}, low={}",
            cols,
            low_slice.len()
        )));
    }
    // First bar where both series are non-NaN; each row's warmup prefix is
    // measured from here.
    let first = (0..cols)
        .find(|&i| !high_slice[i].is_nan() && !low_slice[i].is_nan())
        .ok_or_else(|| PyValueError::new_err("midprice: All values are NaN"))?;

    let combos = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let rows = combos.len();
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("midprice: rows*cols overflow"))?;

    // Uninitialized flat output array owned by Python; it is fully written
    // below (prefixes, then rows) and reshaped to (rows, cols) at the end.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let (warm, _max_p) = batch_warm_prefixes(&combos, first, cols)
        .map_err(|e| PyValueError::new_err(e.to_string()))?;
    // View the same memory as `MaybeUninit<f64>` so `init_matrix_prefixes` can
    // write each row's warmup prefix; the row kernels only write after it.
    // NOTE(review): `out_mu` and `slice_out` are simultaneously live `&mut`
    // views of one buffer; consider scoping `out_mu` to avoid aliasing.
    let out_mu = unsafe {
        std::slice::from_raw_parts_mut(
            slice_out.as_mut_ptr() as *mut std::mem::MaybeUninit<f64>,
            total,
        )
    };
    init_matrix_prefixes(out_mu, cols, &warm);

    // Release the GIL for the (parallel) row computation.
    let combos = py
        .allow_threads(|| {
            let kernel = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k => k,
            };
            // Map the batch kernel to the per-row kernel it runs.
            let simd = match kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                _ => Kernel::Scalar,
            };
            midprice_batch_inner_into(high_slice, low_slice, &sweep, simd, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    Ok(dict)
}
1153
/// Python wrapper exposing [`MidpriceStream`] as the class `MidpriceStream`.
#[cfg(feature = "python")]
#[pyclass(name = "MidpriceStream")]
pub struct MidpriceStreamPy {
    /// The wrapped Rust streaming calculator.
    stream: MidpriceStream,
}
1159
#[cfg(feature = "python")]
#[pymethods]
impl MidpriceStreamPy {
    /// Creates a stream with the given window length.
    ///
    /// Raises `ValueError` if `period` is 0.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        MidpriceStream::try_new(MidpriceParams {
            period: Some(period),
        })
        .map(|stream| MidpriceStreamPy { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feeds one bar. Returns the midprice once the window is full, otherwise `None`.
    fn update(&mut self, high: f64, low: f64) -> Option<f64> {
        self.stream.update(high, low)
    }
}
1177
1178pub fn midprice_into_slice(
1179 dst: &mut [f64],
1180 high: &[f64],
1181 low: &[f64],
1182 params: &MidpriceParams,
1183 kern: Kernel,
1184) -> Result<(), MidpriceError> {
1185 if high.is_empty() || low.is_empty() {
1186 return Err(MidpriceError::EmptyInputData);
1187 }
1188 if high.len() != low.len() {
1189 return Err(MidpriceError::MismatchedDataLength {
1190 high_len: high.len(),
1191 low_len: low.len(),
1192 });
1193 }
1194 if dst.len() != high.len() {
1195 return Err(MidpriceError::OutputLengthMismatch {
1196 expected: high.len(),
1197 got: dst.len(),
1198 });
1199 }
1200
1201 let period = params.period.unwrap_or(14);
1202 if period == 0 || period > high.len() {
1203 return Err(MidpriceError::InvalidPeriod {
1204 period,
1205 data_len: high.len(),
1206 });
1207 }
1208
1209 let first_valid_idx = match (0..high.len()).find(|&i| !high[i].is_nan() && !low[i].is_nan()) {
1210 Some(idx) => idx,
1211 None => return Err(MidpriceError::AllValuesNaN),
1212 };
1213
1214 if high.len() - first_valid_idx < period {
1215 return Err(MidpriceError::NotEnoughValidData {
1216 needed: period,
1217 valid: high.len() - first_valid_idx,
1218 });
1219 }
1220
1221 let chosen = match kern {
1222 Kernel::Auto => detect_best_kernel(),
1223 other => other,
1224 };
1225
1226 let warmup_end = first_valid_idx + period - 1;
1227 for v in &mut dst[..warmup_end] {
1228 *v = f64::NAN;
1229 }
1230
1231 match chosen {
1232 Kernel::Scalar | Kernel::ScalarBatch => {
1233 midprice_scalar(high, low, period, first_valid_idx, dst)
1234 }
1235 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1236 Kernel::Avx2 | Kernel::Avx2Batch => midprice_avx2(high, low, period, first_valid_idx, dst),
1237 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1238 Kernel::Avx512 | Kernel::Avx512Batch => {
1239 midprice_avx512(high, low, period, first_valid_idx, dst)
1240 }
1241 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
1242 Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
1243 midprice_scalar(high, low, period, first_valid_idx, dst)
1244 }
1245 Kernel::Auto => midprice_scalar(high, low, period, first_valid_idx, dst),
1246 }
1247
1248 Ok(())
1249}
1250
/// JS entry point: compute the midprice of `high`/`low` with the given `period`.
///
/// Returns a freshly allocated array of the same length as `high`. Any error from
/// `midprice_into_slice` is converted into a JS string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn midprice_js(high: &[f64], low: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let mut values = vec![0.0; high.len()];
    let params = MidpriceParams {
        period: Some(period),
    };

    match midprice_into_slice(&mut values, high, low, &params, Kernel::Auto) {
        Ok(()) => Ok(values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1265
1266#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1267#[wasm_bindgen]
/// Allocate room for `len` f64 values and hand ownership of the buffer to the caller.
///
/// The memory is not initialised. It is intended to be released with
/// `midprice_free(ptr, len)` using the same `len`.
pub fn midprice_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop keeps the Vec from freeing the allocation when it goes out of scope.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1274
1275#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1276#[wasm_bindgen]
/// Release a buffer previously returned by `midprice_alloc(len)`.
///
/// `len` must be the same value that was passed to `midprice_alloc`. A null `ptr` is
/// ignored.
pub fn midprice_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` comes from `Vec::<f64>::with_capacity(len)` (which allocates exactly
    // `len` elements), so `len` is its capacity. The length is 0 because the caller may
    // never have written every element, and `from_raw_parts` requires the first `length`
    // elements to be initialised. `f64` has no destructor, so this only frees the memory.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
1284
/// Raw-pointer JS entry point: compute the midprice of `len` bars into `out_ptr`.
///
/// `in_high_ptr`/`in_low_ptr` must each reference `len` readable f64 values and `out_ptr`
/// `len` writable ones (for example buffers obtained from `midprice_alloc`). If the output
/// region shares any memory with either input, the result is first computed into a scratch
/// buffer and then copied, so in-place and partially overlapping calls are safe.
///
/// # Errors
///
/// Returns a JS string error for null pointers or for any `midprice_into_slice` error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn midprice_into(
    in_high_ptr: *const f64,
    in_low_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_high_ptr.is_null() || in_low_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // True when two f64 regions (start pointer, element count) share any address.
    // Identical start pointers always count as overlapping, matching the old equality test.
    fn regions_overlap(a: *const f64, a_len: usize, b: *const f64, b_len: usize) -> bool {
        let elem = std::mem::size_of::<f64>();
        let a0 = a as usize;
        let b0 = b as usize;
        let a1 = a0.saturating_add(a_len.saturating_mul(elem));
        let b1 = b0.saturating_add(b_len.saturating_mul(elem));
        a0 == b0 || (a0 < b1 && b0 < a1)
    }

    let params = MidpriceParams {
        period: Some(period),
    };
    let to_js = |e: MidpriceError| JsValue::from_str(&e.to_string());

    // SAFETY: the caller guarantees both input pointers reference `len` readable f64s.
    let (high, low) = unsafe {
        (
            std::slice::from_raw_parts(in_high_ptr, len),
            std::slice::from_raw_parts(in_low_ptr, len),
        )
    };

    let aliased = regions_overlap(in_high_ptr, len, out_ptr, len)
        || regions_overlap(in_low_ptr, len, out_ptr, len);

    if aliased {
        let mut temp = vec![0.0; len];
        midprice_into_slice(&mut temp, high, low, &params, Kernel::Auto).map_err(to_js)?;
        // SAFETY: the caller guarantees `out_ptr` references `len` writable f64s; the input
        // slices are not read after this point.
        unsafe { std::slice::from_raw_parts_mut(out_ptr, len) }.copy_from_slice(&temp);
    } else {
        // SAFETY: the caller guarantees `out_ptr` references `len` writable f64s, and the
        // region was just checked to be disjoint from both inputs.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        midprice_into_slice(out, high, low, &params, Kernel::Auto).map_err(to_js)?;
    }
    Ok(())
}
1319
/// Configuration object accepted by the JS `midprice_batch` binding
/// (deserialized from the JS value passed as `config`).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct MidpriceBatchConfig {
    /// `(start, end, step)` for the period sweep; forwarded as `MidpriceBatchRange::period`.
    pub period_range: (usize, usize, usize),
}
1325
/// Result object returned to JS by the `midprice_batch` binding.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct MidpriceBatchJsOutput {
    /// Flattened batch matrix (`rows * cols` values), one row per period.
    pub values: Vec<f64>,
    /// The period used for each row, in row order.
    pub periods: Vec<usize>,
    /// Number of rows (one per period combination).
    pub rows: usize,
    /// Number of columns (one per input bar).
    pub cols: usize,
}
1334
/// JS `midprice_batch` entry point: sweep the periods described by `config` and return
/// the flattened result matrix together with its shape and the period of every row.
///
/// Errors (bad config, computation failure, serialization failure) are returned as JS
/// strings.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = midprice_batch)]
pub fn midprice_batch_js(high: &[f64], low: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let MidpriceBatchConfig { period_range } = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = MidpriceBatchRange {
        period: period_range,
    };
    let batch = midprice_batch_inner(high, low, &sweep, Kernel::Auto, false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let mut periods: Vec<usize> = Vec::with_capacity(batch.combos.len());
    for combo in batch.combos.iter() {
        periods.push(combo.period.unwrap_or(14));
    }

    let rows = batch.rows;
    let cols = batch.cols;
    let payload = MidpriceBatchJsOutput {
        values: batch.values,
        periods,
        rows,
        cols,
    };

    serde_wasm_bindgen::to_value(&payload).map_err(|e| JsValue::from_str(&e.to_string()))
}
1363
/// Raw-pointer JS batch entry point: evaluate every period produced by
/// `(period_start, period_end, period_step)` and write a row-major `rows x len` matrix
/// into `out_ptr`, returning `rows`.
///
/// `in_high_ptr`/`in_low_ptr` must reference `len` readable f64 values, and `out_ptr` at
/// least `rows * len` writable ones. If the output region shares any memory with an input,
/// the matrix is built in a scratch buffer and copied afterwards.
///
/// # Errors
///
/// Returns a JS string error for empty data, null pointers, an invalid range, size
/// overflow, all-NaN input, or any failure from the batch computation.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn midprice_batch_into(
    in_high_ptr: *const f64,
    in_low_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if len == 0 {
        return Err(JsValue::from_str("midprice: Empty data provided."));
    }
    if in_high_ptr.is_null() || in_low_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // True when two f64 regions (start pointer, element count) share any address.
    // Identical start pointers always count as overlapping, matching the old equality test.
    fn regions_overlap(a: *const f64, a_len: usize, b: *const f64, b_len: usize) -> bool {
        let elem = std::mem::size_of::<f64>();
        let a0 = a as usize;
        let b0 = b as usize;
        let a1 = a0.saturating_add(a_len.saturating_mul(elem));
        let b1 = b0.saturating_add(b_len.saturating_mul(elem));
        a0 == b0 || (a0 < b1 && b0 < a1)
    }

    // SAFETY: the caller guarantees the input pointers reference `len` readable f64s and
    // `out_ptr` references `rows * len` writable f64s. An output that overlaps an input is
    // staged through `temp`, so no `&mut` slice is written while input is still being read.
    unsafe {
        let high = std::slice::from_raw_parts(in_high_ptr, len);
        let low = std::slice::from_raw_parts(in_low_ptr, len);

        let range = MidpriceBatchRange {
            period: (period_start, period_end, period_step),
        };
        let combos = expand_grid(&range).map_err(|e| JsValue::from_str(&e.to_string()))?;
        let rows = combos.len();
        let total = rows
            .checked_mul(len)
            .ok_or_else(|| JsValue::from_str("rows*len overflow"))?;

        let first = (0..len)
            .find(|&i| !high[i].is_nan() && !low[i].is_nan())
            .ok_or_else(|| JsValue::from_str("All values are NaN"))?;
        let (warm, _max_p) = batch_warm_prefixes(&combos, first, len)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

        let aliased = regions_overlap(in_high_ptr, len, out_ptr, total)
            || regions_overlap(in_low_ptr, len, out_ptr, total);

        if aliased {
            // Fully initialised scratch buffer. `with_capacity` + `set_len` would expose
            // uninitialised f64s through `&mut [f64]`, which `Vec::set_len` forbids.
            let mut temp = vec![f64::NAN; total];
            let mu = std::slice::from_raw_parts_mut(
                temp.as_mut_ptr() as *mut std::mem::MaybeUninit<f64>,
                total,
            );
            init_matrix_prefixes(mu, len, &warm);
            midprice_batch_inner_into(high, low, &range, Kernel::Auto, false, &mut temp)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            std::slice::from_raw_parts_mut(out_ptr, total).copy_from_slice(&temp);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, total);
            let mu = std::slice::from_raw_parts_mut(
                out.as_mut_ptr() as *mut std::mem::MaybeUninit<f64>,
                total,
            );
            init_matrix_prefixes(mu, len, &warm);
            midprice_batch_inner_into(high, low, &range, Kernel::Auto, false, out)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
        Ok(rows)
    }
}
1424
#[cfg(test)]
mod tests {
    use super::*;
    use crate::skip_if_unsupported;
    use crate::utilities::data_loader::read_candles_from_csv;

    /// Candle fixture shared by every data-driven test in this module.
    const CANDLE_CSV: &str = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";

    fn check_midprice_partial_params(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLE_CSV)?;

        // `period: None` must fall back to the default period rather than fail.
        let default_params = MidpriceParams { period: None };
        let input = MidpriceInput::from_candles(&candles, "high", "low", default_params);
        let output = midprice_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());
        Ok(())
    }

    fn check_midprice_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLE_CSV)?;

        let input = MidpriceInput::with_default_candles(&candles);
        let result = midprice_with_kernel(&input, kernel)?;
        let expected_last_five = [59583.0, 59583.0, 59583.0, 59486.0, 58989.0];
        let start = result.values.len().saturating_sub(5);
        for (i, &val) in result.values[start..].iter().enumerate() {
            let diff = (val - expected_last_five[i]).abs();
            assert!(
                diff < 1e-1,
                "[{}] MIDPRICE {:?} mismatch at idx {}: got {}, expected {}",
                test_name,
                kernel,
                i,
                val,
                expected_last_five[i]
            );
        }
        Ok(())
    }

    fn check_midprice_default_candles(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLE_CSV)?;
        let input = MidpriceInput::with_default_candles(&candles);
        let output = midprice_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());
        Ok(())
    }

    fn check_midprice_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let highs = [10.0, 14.0, 12.0];
        let lows = [5.0, 6.0, 7.0];
        let params = MidpriceParams { period: Some(0) };
        let input = MidpriceInput::from_slices(&highs, &lows, params);
        let res = midprice_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] MIDPRICE should fail with zero period",
            test_name
        );
        Ok(())
    }

    fn check_midprice_period_exceeds_length(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let highs = [10.0, 14.0, 12.0];
        let lows = [5.0, 6.0, 7.0];
        let params = MidpriceParams { period: Some(10) };
        let input = MidpriceInput::from_slices(&highs, &lows, params);
        let res = midprice_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] MIDPRICE should fail with period exceeding length",
            test_name
        );
        Ok(())
    }

    fn check_midprice_very_small_dataset(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let highs = [42.0];
        let lows = [36.0];
        let params = MidpriceParams { period: Some(14) };
        let input = MidpriceInput::from_slices(&highs, &lows, params);
        let res = midprice_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] MIDPRICE should fail with insufficient data",
            test_name
        );
        Ok(())
    }

    fn check_midprice_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLE_CSV)?;

        // Both passes use the same explicit period; the second pass runs on the first
        // pass's output (including its NaN warmup prefix).
        let params = MidpriceParams { period: Some(10) };
        let input = MidpriceInput::from_candles(&candles, "high", "low", params.clone());
        let first_result = midprice_with_kernel(&input, kernel)?;

        let second_input =
            MidpriceInput::from_slices(&first_result.values, &first_result.values, params);
        let second_result = midprice_with_kernel(&second_input, kernel)?;
        assert_eq!(second_result.values.len(), first_result.values.len());
        Ok(())
    }

    fn check_midprice_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLE_CSV)?;
        let input = MidpriceInput::with_default_candles(&candles);
        let res = midprice_with_kernel(&input, kernel)?;
        assert_eq!(res.values.len(), candles.close.len());
        if res.values.len() > 20 {
            for (i, &val) in res.values[20..].iter().enumerate() {
                assert!(
                    !val.is_nan(),
                    "[{}] Found unexpected NaN at out-index {}",
                    test_name,
                    20 + i
                );
            }
        }
        Ok(())
    }

    fn check_midprice_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLE_CSV)?;

        let period = 14;
        let input = MidpriceInput::with_default_candles(&candles);
        let batch_output = midprice_with_kernel(&input, kernel)?.values;

        let high = source_type(&candles, "high");
        let low = source_type(&candles, "low");
        let mut stream = MidpriceStream::try_new(MidpriceParams {
            period: Some(period),
        })?;
        let stream_values: Vec<f64> = high
            .iter()
            .zip(low.iter())
            .map(|(&h, &l)| stream.update(h, l).unwrap_or(f64::NAN))
            .collect();

        assert_eq!(batch_output.len(), stream_values.len());
        for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
            if b.is_nan() && s.is_nan() {
                continue;
            }
            let diff = (b - s).abs();
            assert!(
                diff < 1e-9,
                "[{}] MIDPRICE streaming mismatch at idx {}: batch={}, stream={}, diff={}",
                test_name,
                i,
                b,
                s,
                diff
            );
        }
        Ok(())
    }

    fn check_midprice_all_nan(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let highs = [f64::NAN, f64::NAN, f64::NAN];
        let lows = [f64::NAN, f64::NAN, f64::NAN];
        let params = MidpriceParams { period: Some(2) };
        let input = MidpriceInput::from_slices(&highs, &lows, params);
        let result = midprice_with_kernel(&input, kernel);
        assert!(
            result.is_err(),
            "[{}] Expected error for all NaN values",
            test_name
        );
        if let Err(e) = result {
            assert!(e.to_string().contains("All values are NaN"));
        }
        Ok(())
    }

    /// Names the debug-allocation helper whose poison pattern `bits` matches, if any.
    #[cfg(debug_assertions)]
    fn poison_source(bits: u64) -> Option<&'static str> {
        match bits {
            0x11111111_11111111 => Some("alloc_with_nan_prefix"),
            0x22222222_22222222 => Some("init_matrix_prefixes"),
            0x33333333_33333333 => Some("make_uninit_matrix"),
            _ => None,
        }
    }

    #[cfg(debug_assertions)]
    fn check_midprice_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);

        let candles = read_candles_from_csv(CANDLE_CSV)?;

        let test_params = vec![
            MidpriceParams::default(),
            MidpriceParams { period: Some(2) },
            MidpriceParams { period: Some(5) },
            MidpriceParams { period: Some(7) },
            MidpriceParams { period: Some(20) },
            MidpriceParams { period: Some(30) },
            MidpriceParams { period: Some(50) },
            MidpriceParams { period: Some(100) },
            MidpriceParams { period: Some(200) },
        ];

        for (param_idx, params) in test_params.iter().enumerate() {
            let input = MidpriceInput::from_candles(&candles, "high", "low", params.clone());
            let output = midprice_with_kernel(&input, kernel)?;

            for (i, &val) in output.values.iter().enumerate() {
                if val.is_nan() {
                    continue;
                }
                let bits = val.to_bits();
                if let Some(source) = poison_source(bits) {
                    panic!(
                        "[{}] Found {} poison value {} (0x{:016X}) at index {} \
                         with params: period={} (param set {})",
                        test_name,
                        source,
                        val,
                        bits,
                        i,
                        params.period.unwrap_or(14),
                        param_idx
                    );
                }
            }
        }

        Ok(())
    }

    #[cfg(not(debug_assertions))]
    fn check_midprice_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    fn check_midprice_property(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        /// `true` when every element is >= its predecessor (exact comparison).
        fn non_decreasing(s: &[f64]) -> bool {
            s.windows(2).all(|w| w[1] >= w[0])
        }

        let strat = (2usize..=100)
            .prop_flat_map(|period| {
                (
                    prop::collection::vec(
                        prop::strategy::Union::new(vec![
                            (0.001f64..0.1f64).boxed(),
                            (10f64..10000f64).boxed(),
                            (1e6f64..1e8f64).boxed(),
                        ])
                        .prop_filter("finite", |x| x.is_finite()),
                        period..=500,
                    ),
                    Just(period),
                    prop::strategy::Union::new(vec![
                        (0.0001f64..0.01f64).boxed(),
                        (0.01f64..0.1f64).boxed(),
                        (0.1f64..0.3f64).boxed(),
                    ]),
                    0usize..=9,
                )
            })
            .prop_map(|(base_prices, period, spread_factor, scenario)| {
                let len = base_prices.len();
                let mut highs = Vec::with_capacity(len);
                let mut lows = Vec::with_capacity(len);

                match scenario {
                    0 => {
                        for &base in &base_prices {
                            let spread = base * spread_factor;
                            highs.push(base + spread / 2.0);
                            lows.push(base - spread / 2.0);
                        }
                    }
                    1 => {
                        let price = base_prices[0];
                        let spread = price * spread_factor;
                        for _ in 0..len {
                            highs.push(price + spread / 2.0);
                            lows.push(price - spread / 2.0);
                        }
                    }
                    2 => {
                        let start_price = base_prices[0];
                        let increment = 10.0;
                        for i in 0..len {
                            let price = start_price + (i as f64) * increment;
                            let spread = price * spread_factor;
                            highs.push(price + spread / 2.0);
                            lows.push(price - spread / 2.0);
                        }
                    }
                    3 => {
                        let start_price = base_prices[0];
                        let decrement = 10.0;
                        for i in 0..len {
                            let price = (start_price - (i as f64) * decrement).max(10.0);
                            let spread = price * spread_factor;
                            highs.push(price + spread / 2.0);
                            lows.push(price - spread / 2.0);
                        }
                    }
                    4 => {
                        for &base in &base_prices {
                            let spread = base * 0.01;
                            highs.push(base + spread);
                            lows.push(base);
                        }
                    }
                    5 => {
                        let amplitude = 100.0;
                        let offset = base_prices[0];
                        for i in 0..len {
                            let phase = (i as f64) * 0.1;
                            let price = offset + amplitude * phase.sin();
                            let spread = price.abs() * spread_factor;
                            highs.push(price + spread / 2.0);
                            lows.push(price - spread / 2.0);
                        }
                    }
                    6 => {
                        let step_size = 100.0;
                        let steps = 5;
                        for i in 0..len {
                            let step = (i * steps / len) as f64;
                            let price = base_prices[0] + step * step_size;
                            let spread = price * spread_factor;
                            highs.push(price + spread / 2.0);
                            lows.push(price - spread / 2.0);
                        }
                    }
                    7 => {
                        for (i, &base) in base_prices.iter().enumerate() {
                            let volatility = ((i as f64 * 0.1).sin() + 1.0) * 0.5;
                            let price = base * (1.0 + spread_factor * volatility);
                            let spread = price * spread_factor;
                            highs.push(price + spread / 2.0);
                            lows.push(price - spread / 2.0);
                        }
                    }
                    8 => {
                        for i in 0..len {
                            let price = base_prices[0] + (i as f64) * 5.0;
                            let spread = price * spread_factor.min(0.1);
                            highs.push(price + spread / 2.0);
                            lows.push(price - spread / 2.0);
                        }
                    }
                    _ => {
                        let mid_point = len / 2;
                        for i in 0..len {
                            let base = if i < mid_point {
                                base_prices[0] + (i as f64) * 20.0
                            } else {
                                base_prices[0] + (mid_point as f64) * 20.0
                                    - ((i - mid_point) as f64) * 20.0
                            };
                            let spread = base * spread_factor.min(0.05);
                            highs.push(base + spread / 2.0);
                            lows.push(base - spread / 2.0);
                        }
                    }
                }

                (highs, lows, period)
            });

        proptest::test_runner::TestRunner::default().run(&strat, |(highs, lows, period)| {
            let params = MidpriceParams {
                period: Some(period),
            };
            let input = MidpriceInput::from_slices(&highs, &lows, params);

            let MidpriceOutput { values: out } = midprice_with_kernel(&input, kernel)?;
            let MidpriceOutput { values: ref_out } = midprice_with_kernel(&input, Kernel::Scalar)?;

            // The generators never emit NaN, but derive the first valid bar the same way the
            // indicator does so the warmup expectation stays correct if that changes.
            let first_valid = (0..highs.len())
                .find(|&i| !highs[i].is_nan() && !lows[i].is_nan())
                .unwrap_or(0);
            let warmup_end = first_valid + period - 1;

            for i in 0..warmup_end.min(out.len()) {
                prop_assert!(
                    out[i].is_nan(),
                    "Expected NaN during warmup at index {}, got {}",
                    i,
                    out[i]
                );
            }

            for i in warmup_end..out.len() {
                let y = out[i];
                let r = ref_out[i];
                let window_start = i + 1 - period;
                let window_high = &highs[window_start..=i];
                let window_low = &lows[window_start..=i];
                let max_high = window_high.iter().copied().fold(f64::NEG_INFINITY, f64::max);
                let min_low = window_low.iter().copied().fold(f64::INFINITY, f64::min);

                let magnitude = y.abs().max(1.0);
                let tolerance = (magnitude * f64::EPSILON * 10.0).max(1e-9);

                // The midprice must lie within the window's price range.
                prop_assert!(
                    y.is_nan() || (y >= min_low - tolerance && y <= max_high + tolerance),
                    "Midprice {} at index {} outside bounds [{}, {}]",
                    y,
                    i,
                    min_low,
                    max_high
                );

                // The kernel under test must agree with the scalar reference.
                if y.is_finite() && r.is_finite() {
                    let ulp_diff = y.to_bits().abs_diff(r.to_bits());
                    prop_assert!(
                        (y - r).abs() <= tolerance || ulp_diff <= 8,
                        "Kernel mismatch at index {}: {} vs {} (ULP={})",
                        i,
                        y,
                        r,
                        ulp_diff
                    );
                } else {
                    prop_assert_eq!(
                        y.to_bits(),
                        r.to_bits(),
                        "NaN/finite mismatch at index {}: {} vs {}",
                        i,
                        y,
                        r
                    );
                }

                // Exact definition: (highest high + lowest low) / 2 over the window.
                let expected = (max_high + min_low) / 2.0;
                prop_assert!(
                    (y - expected).abs() <= tolerance,
                    "Midprice {} != expected {} at index {}",
                    y,
                    expected,
                    i
                );

                // Monotonicity: when the bars from the one that just left the window up to
                // `i` are non-decreasing in both series, the rolling max and min cannot fall,
                // so the midprice cannot fall either. (Checking only the current window is
                // unsound: the departing bar may have been higher.)
                if i > warmup_end {
                    let span_start = i - period;
                    if non_decreasing(&highs[span_start..=i])
                        && non_decreasing(&lows[span_start..=i])
                    {
                        let prev_y = out[i - 1];
                        prop_assert!(
                            y >= prev_y - tolerance,
                            "Midprice decreased on non-decreasing data: {} < {} at index {}",
                            y,
                            prev_y,
                            i
                        );
                    }
                }
            }

            Ok(())
        })?;

        Ok(())
    }

    // Each generated test returns the checker's Result so that I/O failures, indicator
    // errors and proptest failures fail the test instead of being discarded.
    macro_rules! generate_all_midprice_tests {
        ($($test_fn:ident),*) => {
            paste::paste! {
                $(
                    #[test]
                    fn [<$test_fn _scalar_f64>]() -> Result<(), Box<dyn Error>> {
                        $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar)
                    }

                    // The cfg sits on each generated function; placed before a `$( … )*`
                    // repetition it would only gate the first emitted item.
                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx2_f64>]() -> Result<(), Box<dyn Error>> {
                        $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2)
                    }

                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx512_f64>]() -> Result<(), Box<dyn Error>> {
                        $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512)
                    }
                )*
            }
        };
    }

    generate_all_midprice_tests!(
        check_midprice_partial_params,
        check_midprice_accuracy,
        check_midprice_default_candles,
        check_midprice_zero_period,
        check_midprice_period_exceeds_length,
        check_midprice_very_small_dataset,
        check_midprice_reinput,
        check_midprice_nan_handling,
        check_midprice_streaming,
        check_midprice_all_nan,
        check_midprice_no_poison
    );

    generate_all_midprice_tests!(check_midprice_property);

    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);

        let c = read_candles_from_csv(CANDLE_CSV)?;

        let output = MidpriceBatchBuilder::new()
            .kernel(kernel)
            .apply_candles(&c, "high", "low")?;

        let def = MidpriceParams::default();
        let row = output.values_for(&def).expect("default row missing");

        assert_eq!(row.len(), c.close.len());

        let expected = [59583.0, 59583.0, 59583.0, 59486.0, 58989.0];
        let start = row.len() - 5;
        for (i, &v) in row[start..].iter().enumerate() {
            assert!(
                (v - expected[i]).abs() < 1e-1,
                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
            );
        }
        Ok(())
    }

    #[cfg(debug_assertions)]
    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);

        let c = read_candles_from_csv(CANDLE_CSV)?;

        let test_configs = vec![
            (2, 10, 2),
            (5, 25, 5),
            (30, 60, 15),
            (2, 5, 1),
            (10, 10, 0),
            (14, 14, 0),
            (50, 100, 25),
        ];

        for (cfg_idx, &(period_start, period_end, period_step)) in test_configs.iter().enumerate() {
            let output = MidpriceBatchBuilder::new()
                .kernel(kernel)
                .period_range(period_start, period_end, period_step)
                .apply_candles(&c, "high", "low")?;

            for (idx, &val) in output.values.iter().enumerate() {
                if val.is_nan() {
                    continue;
                }
                let bits = val.to_bits();
                if let Some(source) = poison_source(bits) {
                    let row = idx / output.cols;
                    let col = idx % output.cols;
                    panic!(
                        "[{}] Config {}: Found {} poison value {} (0x{:016X}) \
                         at row {} col {} (flat index {}) with params: period={}",
                        test,
                        cfg_idx,
                        source,
                        val,
                        bits,
                        row,
                        col,
                        idx,
                        output.combos[row].period.unwrap_or(14)
                    );
                }
            }
        }

        Ok(())
    }

    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    // Batch variants of the test generator; like the single-series one, errors propagate.
    macro_rules! gen_batch_tests {
        ($fn_name:ident) => {
            paste::paste! {
                #[test]
                fn [<$fn_name _scalar>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch)
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$fn_name _avx2>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch)
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$fn_name _avx512>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch)
                }
                #[test]
                fn [<$fn_name _auto_detect>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto)
                }
            }
        };
    }
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);

    #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
    #[test]
    fn test_midprice_into_matches_api() -> Result<(), Box<dyn Error>> {
        let n = 256usize;
        let mut high = Vec::with_capacity(n);
        let mut low = Vec::with_capacity(n);
        for i in 0..n {
            let t = i as f64;
            let h = 100.0 + 0.1 * t + (0.03 * t).sin();
            let l = h - (2.0 + ((i % 7) as f64));
            high.push(h);
            low.push(l);
        }

        let params = MidpriceParams::default();
        let input = MidpriceInput::from_slices(&high, &low, params.clone());

        let baseline = midprice(&input)?.values;

        let mut out = vec![0.0; n];
        midprice_into(&input, &mut out)?;

        assert_eq!(baseline.len(), out.len());

        fn eq_or_both_nan(a: f64, b: f64) -> bool {
            (a.is_nan() && b.is_nan()) || (a - b).abs() <= 1e-12
        }
        for i in 0..n {
            assert!(
                eq_or_both_nan(baseline[i], out[i]),
                "Mismatch at index {}: baseline={}, into={}",
                i,
                baseline[i],
                out[i]
            );
        }
        Ok(())
    }
}
2202
/// Register this indicator's Python functions (`midprice_py` and `midprice_batch_py`)
/// on the given module.
#[cfg(feature = "python")]
pub fn register_midprice_module(m: &Bound<'_, pyo3::types::PyModule>) -> PyResult<()> {
    let single = wrap_pyfunction!(midprice_py, m)?;
    m.add_function(single)?;

    let batch = wrap_pyfunction!(midprice_batch_py, m)?;
    m.add_function(batch)?;

    Ok(())
}