Skip to main content

vector_ta/indicators/
midprice.rs

1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::PyDict;
9
10use crate::utilities::data_loader::{source_type, Candles};
11use crate::utilities::enums::Kernel;
12use crate::utilities::helpers::{
13    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
14    make_uninit_matrix,
15};
16#[cfg(feature = "python")]
17use crate::utilities::kernel_validation::validate_kernel;
18use aligned_vec::{AVec, CACHELINE_ALIGN};
19#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
20use core::arch::x86_64::*;
21#[cfg(not(target_arch = "wasm32"))]
22use rayon::prelude::*;
23#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
24use serde::{Deserialize, Serialize};
25use std::convert::AsRef;
26use std::error::Error;
27use thiserror::Error;
28#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
29use wasm_bindgen::prelude::*;
30
31#[derive(Debug, Clone)]
32pub enum MidpriceData<'a> {
33    Candles {
34        candles: &'a Candles,
35        high_src: &'a str,
36        low_src: &'a str,
37    },
38    Slices {
39        high: &'a [f64],
40        low: &'a [f64],
41    },
42}
43
/// Output of a single-series midprice run: one value per input bar.
#[derive(Debug, Clone)]
pub struct MidpriceOutput {
    pub values: Vec<f64>,
}

/// Parameters for midprice.
#[derive(Debug, Clone)]
pub struct MidpriceParams {
    /// Look-back window length in bars; `None` means "use the default of 14".
    pub period: Option<usize>,
}

impl Default for MidpriceParams {
    fn default() -> Self {
        const DEFAULT_PERIOD: usize = 14;
        MidpriceParams {
            period: Some(DEFAULT_PERIOD),
        }
    }
}
59
60#[derive(Debug, Clone)]
61pub struct MidpriceInput<'a> {
62    pub data: MidpriceData<'a>,
63    pub params: MidpriceParams,
64}
65
66impl<'a> MidpriceInput<'a> {
67    #[inline]
68    pub fn from_candles(
69        candles: &'a Candles,
70        high_src: &'a str,
71        low_src: &'a str,
72        params: MidpriceParams,
73    ) -> Self {
74        Self {
75            data: MidpriceData::Candles {
76                candles,
77                high_src,
78                low_src,
79            },
80            params,
81        }
82    }
83    #[inline]
84    pub fn from_slices(high: &'a [f64], low: &'a [f64], params: MidpriceParams) -> Self {
85        Self {
86            data: MidpriceData::Slices { high, low },
87            params,
88        }
89    }
90    #[inline]
91    pub fn with_default_candles(candles: &'a Candles) -> Self {
92        Self {
93            data: MidpriceData::Candles {
94                candles,
95                high_src: "high",
96                low_src: "low",
97            },
98            params: MidpriceParams::default(),
99        }
100    }
101    #[inline]
102    pub fn get_period(&self) -> usize {
103        self.params.period.unwrap_or(14)
104    }
105}
106
107#[derive(Copy, Clone, Debug)]
108pub struct MidpriceBuilder {
109    period: Option<usize>,
110    kernel: Kernel,
111}
112
113impl Default for MidpriceBuilder {
114    fn default() -> Self {
115        Self {
116            period: None,
117            kernel: Kernel::Auto,
118        }
119    }
120}
121
122impl MidpriceBuilder {
123    #[inline(always)]
124    pub fn new() -> Self {
125        Self::default()
126    }
127    #[inline(always)]
128    pub fn period(mut self, n: usize) -> Self {
129        self.period = Some(n);
130        self
131    }
132    #[inline(always)]
133    pub fn kernel(mut self, k: Kernel) -> Self {
134        self.kernel = k;
135        self
136    }
137    #[inline(always)]
138    pub fn apply(self, c: &Candles) -> Result<MidpriceOutput, MidpriceError> {
139        let p = MidpriceParams {
140            period: self.period,
141        };
142        let i = MidpriceInput::from_candles(c, "high", "low", p);
143        midprice_with_kernel(&i, self.kernel)
144    }
145    #[inline(always)]
146    pub fn apply_slices(self, high: &[f64], low: &[f64]) -> Result<MidpriceOutput, MidpriceError> {
147        let p = MidpriceParams {
148            period: self.period,
149        };
150        let i = MidpriceInput::from_slices(high, low, p);
151        midprice_with_kernel(&i, self.kernel)
152    }
153    #[inline(always)]
154    pub fn into_stream(self) -> Result<MidpriceStream, MidpriceError> {
155        let p = MidpriceParams {
156            period: self.period,
157        };
158        MidpriceStream::try_new(p)
159    }
160}
161
162#[derive(Debug, Error)]
163pub enum MidpriceError {
164    #[error("midprice: Empty data provided.")]
165    EmptyInputData,
166    #[error("midprice: All values are NaN.")]
167    AllValuesNaN,
168    #[error("midprice: Invalid period: period = {period}, data length = {data_len}")]
169    InvalidPeriod { period: usize, data_len: usize },
170    #[error("midprice: Not enough valid data: needed = {needed}, valid = {valid}")]
171    NotEnoughValidData { needed: usize, valid: usize },
172    #[error("midprice: Output length mismatch: expected = {expected}, got = {got}")]
173    OutputLengthMismatch { expected: usize, got: usize },
174    #[error("midprice: Invalid range: start={start}, end={end}, step={step}")]
175    InvalidRange {
176        start: usize,
177        end: usize,
178        step: usize,
179    },
180    #[error("midprice: Invalid kernel for batch: {0:?}")]
181    InvalidKernelForBatch(Kernel),
182    #[error("midprice: Mismatched data length: high_len = {high_len}, low_len = {low_len}")]
183    MismatchedDataLength { high_len: usize, low_len: usize },
184    #[error("midprice: invalid input: {0}")]
185    InvalidInput(&'static str),
186}
187
188#[inline]
189pub fn midprice(input: &MidpriceInput) -> Result<MidpriceOutput, MidpriceError> {
190    midprice_with_kernel(input, Kernel::Auto)
191}
192
193pub fn midprice_with_kernel(
194    input: &MidpriceInput,
195    kernel: Kernel,
196) -> Result<MidpriceOutput, MidpriceError> {
197    let (high, low) = match &input.data {
198        MidpriceData::Candles {
199            candles,
200            high_src,
201            low_src,
202        } => {
203            let h = source_type(candles, high_src);
204            let l = source_type(candles, low_src);
205            (h, l)
206        }
207        MidpriceData::Slices { high, low } => (*high, *low),
208    };
209
210    if high.is_empty() || low.is_empty() {
211        return Err(MidpriceError::EmptyInputData);
212    }
213    if high.len() != low.len() {
214        return Err(MidpriceError::MismatchedDataLength {
215            high_len: high.len(),
216            low_len: low.len(),
217        });
218    }
219    let period = input.get_period();
220    if period == 0 || period > high.len() {
221        return Err(MidpriceError::InvalidPeriod {
222            period,
223            data_len: high.len(),
224        });
225    }
226    let first_valid_idx = match (0..high.len()).find(|&i| !high[i].is_nan() && !low[i].is_nan()) {
227        Some(idx) => idx,
228        None => return Err(MidpriceError::AllValuesNaN),
229    };
230    if (high.len() - first_valid_idx) < period {
231        return Err(MidpriceError::NotEnoughValidData {
232            needed: period,
233            valid: high.len() - first_valid_idx,
234        });
235    }
236    let mut out = alloc_with_nan_prefix(high.len(), first_valid_idx + period - 1);
237    let chosen = match kernel {
238        Kernel::Auto => detect_best_kernel(),
239        other => other,
240    };
241    unsafe {
242        match chosen {
243            Kernel::Scalar | Kernel::ScalarBatch => {
244                midprice_scalar(high, low, period, first_valid_idx, &mut out)
245            }
246            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
247            Kernel::Avx2 | Kernel::Avx2Batch => {
248                midprice_avx2(high, low, period, first_valid_idx, &mut out)
249            }
250            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
251            Kernel::Avx512 | Kernel::Avx512Batch => {
252                midprice_avx512(high, low, period, first_valid_idx, &mut out)
253            }
254            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
255            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
256                midprice_scalar(high, low, period, first_valid_idx, &mut out)
257            }
258            Kernel::Auto => midprice_scalar(high, low, period, first_valid_idx, &mut out),
259        }
260    }
261    Ok(MidpriceOutput { values: out })
262}
263
264#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
265pub fn midprice_into(input: &MidpriceInput, out: &mut [f64]) -> Result<(), MidpriceError> {
266    let (high, low) = match &input.data {
267        MidpriceData::Candles {
268            candles,
269            high_src,
270            low_src,
271        } => {
272            let h = source_type(candles, high_src);
273            let l = source_type(candles, low_src);
274            (h, l)
275        }
276        MidpriceData::Slices { high, low } => (*high, *low),
277    };
278
279    if out.len() != high.len() || high.len() != low.len() {
280        return Err(MidpriceError::OutputLengthMismatch {
281            expected: high.len(),
282            got: out.len(),
283        });
284    }
285
286    midprice_into_slice(out, high, low, &input.params, Kernel::Auto)
287}
288
/// Scalar midprice kernel.
///
/// For every index `i >= first_valid_idx + period - 1`, writes
/// `(max(high[i + 1 - period ..= i]) + min(low[i + 1 - period ..= i])) * 0.5` into
/// `out[i]`. Earlier entries of `out` are left untouched; callers pre-fill them.
/// NaN samples are skipped when taking the max/min. If a window has no non-NaN high
/// (or low), that extreme stays at `-inf` (or `+inf`).
///
/// Windows of up to 64 bars are scanned directly. Longer windows use monotonic
/// deques, so each bar is pushed and popped at most once per side. A `period` of 0
/// writes nothing.
///
/// # Panics
/// Panics if `low` or `out` is shorter than `high`. This is a safe function, so
/// lengths are checked instead of trusted.
#[inline]
pub fn midprice_scalar(
    high: &[f64],
    low: &[f64],
    period: usize,
    first_valid_idx: usize,
    out: &mut [f64],
) {
    let len = high.len();
    assert!(
        low.len() >= len && out.len() >= len,
        "midprice_scalar: low/out must be at least as long as high"
    );
    if period == 0 {
        return;
    }
    let warmup_end = first_valid_idx + period - 1;

    if period <= 64 {
        // Short windows: a direct scan is cheaper than deque bookkeeping.
        for i in warmup_end..len {
            let start = i + 1 - period;
            let mut highest = f64::NEG_INFINITY;
            let mut lowest = f64::INFINITY;
            for (&h, &l) in high[start..=i].iter().zip(&low[start..=i]) {
                // NaN comparisons are false, so NaN samples never win.
                if h > highest {
                    highest = h;
                }
                if l < lowest {
                    lowest = l;
                }
            }
            out[i] = (highest + lowest) * 0.5;
        }
        return;
    }

    // Long windows: the deque fronts hold the indices of the current max high and
    // min low. Values are strictly decreasing (high) / increasing (low) front to back.
    let mut dq_high: std::collections::VecDeque<usize> =
        std::collections::VecDeque::with_capacity(period + 1);
    let mut dq_low: std::collections::VecDeque<usize> =
        std::collections::VecDeque::with_capacity(period + 1);

    for i in first_valid_idx..len {
        let hv = high[i];
        if !hv.is_nan() {
            while let Some(&j) = dq_high.back() {
                if high[j] <= hv {
                    dq_high.pop_back();
                } else {
                    break;
                }
            }
            dq_high.push_back(i);
        }

        let lv = low[i];
        if !lv.is_nan() {
            while let Some(&j) = dq_low.back() {
                if low[j] >= lv {
                    dq_low.pop_back();
                } else {
                    break;
                }
            }
            dq_low.push_back(i);
        }

        if i < warmup_end {
            continue;
        }

        // Evict indices that slid out of the window.
        let start = i + 1 - period;
        while let Some(&j) = dq_high.front() {
            if j < start {
                dq_high.pop_front();
            } else {
                break;
            }
        }
        while let Some(&j) = dq_low.front() {
            if j < start {
                dq_low.pop_front();
            } else {
                break;
            }
        }

        let highest = dq_high
            .front()
            .map(|&j| high[j])
            .unwrap_or(f64::NEG_INFINITY);
        let lowest = dq_low.front().map(|&j| low[j]).unwrap_or(f64::INFINITY);
        out[i] = (highest + lowest) * 0.5;
    }
}
385
/// AVX2 entry point for single-series midprice.
///
/// There is no vectorised implementation yet; this forwards to the scalar kernel.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn midprice_avx2(
    high: &[f64],
    low: &[f64],
    period: usize,
    first_valid_idx: usize,
    out: &mut [f64],
) {
    midprice_scalar(high, low, period, first_valid_idx, out)
}
397
/// AVX-512 entry point: selects the short-window variant for `period <= 32` and the
/// long-window variant otherwise. Both currently forward to the scalar kernel.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn midprice_avx512(
    high: &[f64],
    low: &[f64],
    period: usize,
    first_valid_idx: usize,
    out: &mut [f64],
) {
    let short_window = period <= 32;
    // SAFETY: both variants only call the safe scalar kernel.
    unsafe {
        if short_window {
            midprice_avx512_short(high, low, period, first_valid_idx, out)
        } else {
            midprice_avx512_long(high, low, period, first_valid_idx, out)
        }
    }
}

/// Short-window AVX-512 variant (placeholder: forwards to the scalar kernel).
///
/// # Safety
/// Adds no requirements beyond those of `midprice_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn midprice_avx512_short(
    high: &[f64],
    low: &[f64],
    period: usize,
    first_valid_idx: usize,
    out: &mut [f64],
) {
    midprice_scalar(high, low, period, first_valid_idx, out)
}

/// Long-window AVX-512 variant (placeholder: forwards to the scalar kernel).
///
/// # Safety
/// Adds no requirements beyond those of `midprice_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn midprice_avx512_long(
    high: &[f64],
    low: &[f64],
    period: usize,
    first_valid_idx: usize,
    out: &mut [f64],
) {
    midprice_scalar(high, low, period, first_valid_idx, out)
}
437
438#[derive(Debug, Clone)]
439pub struct MidpriceStream {
440    period: usize,
441
442    started: bool,
443
444    seen: usize,
445
446    dq_high: std::collections::VecDeque<(usize, f64)>,
447    dq_low: std::collections::VecDeque<(usize, f64)>,
448}
449
450impl MidpriceStream {
451    pub fn try_new(params: MidpriceParams) -> Result<Self, MidpriceError> {
452        let period = params.period.unwrap_or(14);
453        if period == 0 {
454            return Err(MidpriceError::InvalidPeriod {
455                period,
456                data_len: 0,
457            });
458        }
459        let cap = period + 1;
460        Ok(Self {
461            period,
462            started: false,
463            seen: 0,
464            dq_high: std::collections::VecDeque::with_capacity(cap),
465            dq_low: std::collections::VecDeque::with_capacity(cap),
466        })
467    }
468
469    #[inline(always)]
470    pub fn update(&mut self, high: f64, low: f64) -> Option<f64> {
471        if !self.started {
472            if !(high.is_finite() && low.is_finite()) {
473                return None;
474            }
475            self.started = true;
476            self.seen = 0;
477        }
478
479        let i = self.seen;
480
481        if high.is_finite() {
482            while let Some(&(_, v)) = self.dq_high.back() {
483                if v <= high {
484                    self.dq_high.pop_back();
485                } else {
486                    break;
487                }
488            }
489            self.dq_high.push_back((i, high));
490        }
491
492        if low.is_finite() {
493            while let Some(&(_, v)) = self.dq_low.back() {
494                if v >= low {
495                    self.dq_low.pop_back();
496                } else {
497                    break;
498                }
499            }
500            self.dq_low.push_back((i, low));
501        }
502
503        let start = i.saturating_add(1).saturating_sub(self.period);
504        while let Some(&(idx, _)) = self.dq_high.front() {
505            if idx < start {
506                self.dq_high.pop_front();
507            } else {
508                break;
509            }
510        }
511        while let Some(&(idx, _)) = self.dq_low.front() {
512            if idx < start {
513                self.dq_low.pop_front();
514            } else {
515                break;
516            }
517        }
518
519        self.seen = i + 1;
520
521        if self.seen < self.period {
522            return None;
523        }
524
525        Some(self.calc())
526    }
527
528    #[inline(always)]
529    fn calc(&self) -> f64 {
530        let max_h = self
531            .dq_high
532            .front()
533            .map(|&(_, v)| v)
534            .unwrap_or(f64::NEG_INFINITY);
535        let min_l = self
536            .dq_low
537            .front()
538            .map(|&(_, v)| v)
539            .unwrap_or(f64::INFINITY);
540        (max_h + min_l) / 2.0
541    }
542}
543
544#[derive(Clone, Debug)]
545pub struct MidpriceBatchRange {
546    pub period: (usize, usize, usize),
547}
548
549impl Default for MidpriceBatchRange {
550    fn default() -> Self {
551        Self {
552            period: (14, 263, 1),
553        }
554    }
555}
556
557#[derive(Clone, Debug, Default)]
558pub struct MidpriceBatchBuilder {
559    range: MidpriceBatchRange,
560    kernel: Kernel,
561}
562
563impl MidpriceBatchBuilder {
564    pub fn new() -> Self {
565        Self::default()
566    }
567    pub fn kernel(mut self, k: Kernel) -> Self {
568        self.kernel = k;
569        self
570    }
571    #[inline]
572    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
573        self.range.period = (start, end, step);
574        self
575    }
576    #[inline]
577    pub fn period_static(mut self, p: usize) -> Self {
578        self.range.period = (p, p, 0);
579        self
580    }
581    pub fn apply_slices(
582        self,
583        high: &[f64],
584        low: &[f64],
585    ) -> Result<MidpriceBatchOutput, MidpriceError> {
586        midprice_batch_with_kernel(high, low, &self.range, self.kernel)
587    }
588    pub fn apply_candles(
589        self,
590        c: &Candles,
591        high_src: &str,
592        low_src: &str,
593    ) -> Result<MidpriceBatchOutput, MidpriceError> {
594        let high = source_type(c, high_src);
595        let low = source_type(c, low_src);
596        self.apply_slices(high, low)
597    }
598    pub fn with_default_candles(c: &Candles) -> Result<MidpriceBatchOutput, MidpriceError> {
599        MidpriceBatchBuilder::new()
600            .kernel(Kernel::Auto)
601            .apply_candles(c, "high", "low")
602    }
603}
604
605pub fn midprice_batch_with_kernel(
606    high: &[f64],
607    low: &[f64],
608    sweep: &MidpriceBatchRange,
609    k: Kernel,
610) -> Result<MidpriceBatchOutput, MidpriceError> {
611    let kernel = match k {
612        Kernel::Auto => detect_best_batch_kernel(),
613        other if other.is_batch() => other,
614        _ => return Err(MidpriceError::InvalidKernelForBatch(k)),
615    };
616    let simd = match kernel {
617        Kernel::Avx512Batch => Kernel::Avx512,
618        Kernel::Avx2Batch => Kernel::Avx2,
619        Kernel::ScalarBatch => Kernel::Scalar,
620        _ => Kernel::Scalar,
621    };
622    midprice_batch_par_slice(high, low, sweep, simd)
623}
624
625#[derive(Clone, Debug)]
626pub struct MidpriceBatchOutput {
627    pub values: Vec<f64>,
628    pub combos: Vec<MidpriceParams>,
629    pub rows: usize,
630    pub cols: usize,
631}
632impl MidpriceBatchOutput {
633    pub fn row_for_params(&self, p: &MidpriceParams) -> Option<usize> {
634        self.combos
635            .iter()
636            .position(|c| c.period.unwrap_or(14) == p.period.unwrap_or(14))
637    }
638    pub fn values_for(&self, p: &MidpriceParams) -> Option<&[f64]> {
639        self.row_for_params(p).and_then(|row| {
640            row.checked_mul(self.cols)
641                .map(|start| &self.values[start..start + self.cols])
642        })
643    }
644}
645
646#[inline(always)]
647fn expand_grid(r: &MidpriceBatchRange) -> Result<Vec<MidpriceParams>, MidpriceError> {
648    fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, MidpriceError> {
649        if step == 0 || start == end {
650            return Ok(vec![start]);
651        }
652        let mut v = Vec::new();
653        let s = step.max(1);
654        if start < end {
655            let mut cur = start;
656            while cur <= end {
657                v.push(cur);
658                let next = cur.saturating_add(s);
659                if next == cur {
660                    break;
661                }
662                cur = next;
663            }
664        } else {
665            let mut cur = start;
666            while cur >= end {
667                v.push(cur);
668                let next = cur.saturating_sub(s);
669                if next == cur {
670                    break;
671                }
672                cur = next;
673                if cur == 0 && end > 0 {
674                    break;
675                }
676            }
677            v.retain(|&x| x >= end);
678        }
679        if v.is_empty() {
680            return Err(MidpriceError::InvalidRange { start, end, step });
681        }
682        Ok(v)
683    }
684    let periods = axis_usize(r.period)?;
685    let mut out = Vec::with_capacity(periods.len());
686    for &p in &periods {
687        out.push(MidpriceParams { period: Some(p) });
688    }
689    Ok(out)
690}
691
692#[inline(always)]
693fn validate_batch_periods(combos: &[MidpriceParams], cols: usize) -> Result<usize, MidpriceError> {
694    let mut max_p = 0usize;
695    for c in combos {
696        let p = c.period.unwrap_or(14);
697        if p == 0 || p > cols {
698            return Err(MidpriceError::InvalidPeriod {
699                period: p,
700                data_len: cols,
701            });
702        }
703        if p > max_p {
704            max_p = p;
705        }
706    }
707    Ok(max_p)
708}
709
710#[inline(always)]
711fn batch_warm_prefixes(
712    combos: &[MidpriceParams],
713    first: usize,
714    cols: usize,
715) -> Result<(Vec<usize>, usize), MidpriceError> {
716    let mut max_p = 0usize;
717    let mut warm = Vec::with_capacity(combos.len());
718    for c in combos {
719        let p = c.period.unwrap_or(14);
720        if p == 0 || p > cols {
721            return Err(MidpriceError::InvalidPeriod {
722                period: p,
723                data_len: cols,
724            });
725        }
726        if p > max_p {
727            max_p = p;
728        }
729        warm.push(first + p - 1);
730    }
731    if cols - first < max_p {
732        return Err(MidpriceError::NotEnoughValidData {
733            needed: max_p,
734            valid: cols - first,
735        });
736    }
737    Ok((warm, max_p))
738}
739
740#[inline(always)]
741pub fn midprice_batch_slice(
742    high: &[f64],
743    low: &[f64],
744    sweep: &MidpriceBatchRange,
745    kern: Kernel,
746) -> Result<MidpriceBatchOutput, MidpriceError> {
747    midprice_batch_inner(high, low, sweep, kern, false)
748}
749
750#[inline(always)]
751pub fn midprice_batch_par_slice(
752    high: &[f64],
753    low: &[f64],
754    sweep: &MidpriceBatchRange,
755    kern: Kernel,
756) -> Result<MidpriceBatchOutput, MidpriceError> {
757    midprice_batch_inner(high, low, sweep, kern, true)
758}
759
760#[inline(always)]
761fn midprice_batch_inner_into(
762    high: &[f64],
763    low: &[f64],
764    sweep: &MidpriceBatchRange,
765    kern: Kernel,
766    parallel: bool,
767    out: &mut [f64],
768) -> Result<Vec<MidpriceParams>, MidpriceError> {
769    let combos = expand_grid(sweep)?;
770    if high.is_empty() || low.is_empty() {
771        return Err(MidpriceError::EmptyInputData);
772    }
773    if high.len() != low.len() {
774        return Err(MidpriceError::MismatchedDataLength {
775            high_len: high.len(),
776            low_len: low.len(),
777        });
778    }
779    let first = (0..high.len())
780        .find(|&i| !high[i].is_nan() && !low[i].is_nan())
781        .ok_or(MidpriceError::AllValuesNaN)?;
782    let cols = high.len();
783    let max_p = validate_batch_periods(&combos, cols)?;
784    if cols - first < max_p {
785        return Err(MidpriceError::NotEnoughValidData {
786            needed: max_p,
787            valid: cols - first,
788        });
789    }
790    let rows = combos.len();
791    let expected = rows
792        .checked_mul(cols)
793        .ok_or(MidpriceError::InvalidInput("rows*cols overflow"))?;
794    if out.len() != expected {
795        return Err(MidpriceError::OutputLengthMismatch {
796            expected,
797            got: out.len(),
798        });
799    }
800
801    let do_row = |row: usize, out_row: &mut [f64]| unsafe {
802        let period = combos[row].period.unwrap();
803        match kern {
804            Kernel::Scalar | Kernel::ScalarBatch => {
805                midprice_row_scalar(high, low, first, period, out_row)
806            }
807            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
808            Kernel::Avx2 | Kernel::Avx2Batch => {
809                midprice_row_avx2(high, low, first, period, out_row)
810            }
811            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
812            Kernel::Avx512 | Kernel::Avx512Batch => {
813                midprice_row_avx512(high, low, first, period, out_row)
814            }
815            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
816            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
817                midprice_row_scalar(high, low, first, period, out_row)
818            }
819            Kernel::Auto => midprice_row_scalar(high, low, first, period, out_row),
820        }
821    };
822
823    if parallel {
824        #[cfg(not(target_arch = "wasm32"))]
825        {
826            out.par_chunks_mut(cols)
827                .enumerate()
828                .for_each(|(row, slice)| do_row(row, slice));
829        }
830
831        #[cfg(target_arch = "wasm32")]
832        {
833            for (row, slice) in out.chunks_mut(cols).enumerate() {
834                do_row(row, slice);
835            }
836        }
837    } else {
838        for (row, slice) in out.chunks_mut(cols).enumerate() {
839            do_row(row, slice);
840        }
841    }
842
843    Ok(combos)
844}
845
846#[inline(always)]
847fn midprice_batch_inner(
848    high: &[f64],
849    low: &[f64],
850    sweep: &MidpriceBatchRange,
851    kern: Kernel,
852    parallel: bool,
853) -> Result<MidpriceBatchOutput, MidpriceError> {
854    let combos = expand_grid(sweep)?;
855    if high.is_empty() || low.is_empty() {
856        return Err(MidpriceError::EmptyInputData);
857    }
858    if high.len() != low.len() {
859        return Err(MidpriceError::MismatchedDataLength {
860            high_len: high.len(),
861            low_len: low.len(),
862        });
863    }
864
865    let first = (0..high.len())
866        .find(|&i| !high[i].is_nan() && !low[i].is_nan())
867        .ok_or(MidpriceError::AllValuesNaN)?;
868    let cols = high.len();
869    let (warm, _max_p) = batch_warm_prefixes(&combos, first, cols)?;
870
871    let rows = combos.len();
872    let total_cells = rows
873        .checked_mul(cols)
874        .ok_or(MidpriceError::InvalidInput("rows*cols overflow"))?;
875
876    let mut buf_mu = make_uninit_matrix(rows, cols);
877
878    init_matrix_prefixes(&mut buf_mu, cols, &warm);
879
880    let mut guard = core::mem::ManuallyDrop::new(buf_mu);
881    let out: &mut [f64] =
882        unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len()) };
883    let combos = midprice_batch_inner_into(high, low, sweep, kern, parallel, out)?;
884
885    let values = unsafe {
886        Vec::from_raw_parts(
887            guard.as_mut_ptr() as *mut f64,
888            total_cells,
889            guard.capacity(),
890        )
891    };
892
893    Ok(MidpriceBatchOutput {
894        values,
895        combos,
896        rows,
897        cols,
898    })
899}
900
/// Scalar row kernel for batch sweeps.
///
/// For every index `i >= first + period - 1`, writes
/// `(max(high[i + 1 - period ..= i]) + min(low[i + 1 - period ..= i])) * 0.5` to
/// `out[i]`, leaving earlier entries untouched. NaN samples are ignored when taking
/// the max/min. Windows of up to 64 bars are scanned directly; longer windows use
/// monotonic deques. A `period` of 0 writes nothing.
///
/// # Safety
/// The body uses only bounds-checked indexing, so callers have no memory-safety
/// obligations. The `unsafe` qualifier is retained only for signature compatibility
/// with existing callers.
///
/// # Panics
/// Panics if `low` or `out` is shorter than `high`.
#[inline(always)]
pub unsafe fn midprice_row_scalar(
    high: &[f64],
    low: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    let len = high.len();
    assert!(
        low.len() >= len && out.len() >= len,
        "midprice_row_scalar: low/out must be at least as long as high"
    );
    if period == 0 {
        return;
    }
    let warmup_end = first + period - 1;

    if period <= 64 {
        // Short windows: a direct scan beats deque bookkeeping.
        for i in warmup_end..len {
            let start = i + 1 - period;
            let mut highest = f64::NEG_INFINITY;
            let mut lowest = f64::INFINITY;
            for (&h, &l) in high[start..=i].iter().zip(&low[start..=i]) {
                // NaN comparisons are false, so NaN samples never win.
                if h > highest {
                    highest = h;
                }
                if l < lowest {
                    lowest = l;
                }
            }
            out[i] = (highest + lowest) * 0.5;
        }
        return;
    }

    // Long windows: deque fronts hold the indices of the current max high / min low.
    let mut dq_high: std::collections::VecDeque<usize> =
        std::collections::VecDeque::with_capacity(period + 1);
    let mut dq_low: std::collections::VecDeque<usize> =
        std::collections::VecDeque::with_capacity(period + 1);

    for i in first..len {
        let hv = high[i];
        if !hv.is_nan() {
            while let Some(&j) = dq_high.back() {
                if high[j] <= hv {
                    dq_high.pop_back();
                } else {
                    break;
                }
            }
            dq_high.push_back(i);
        }

        let lv = low[i];
        if !lv.is_nan() {
            while let Some(&j) = dq_low.back() {
                if low[j] >= lv {
                    dq_low.pop_back();
                } else {
                    break;
                }
            }
            dq_low.push_back(i);
        }

        if i < warmup_end {
            continue;
        }

        // Evict indices that slid out of the window.
        let start = i + 1 - period;
        while let Some(&j) = dq_high.front() {
            if j < start {
                dq_high.pop_front();
            } else {
                break;
            }
        }
        while let Some(&j) = dq_low.front() {
            if j < start {
                dq_low.pop_front();
            } else {
                break;
            }
        }

        let highest = dq_high
            .front()
            .map(|&j| high[j])
            .unwrap_or(f64::NEG_INFINITY);
        let lowest = dq_low.front().map(|&j| low[j]).unwrap_or(f64::INFINITY);
        out[i] = (highest + lowest) * 0.5;
    }
}
995
/// AVX2 row kernel for batch sweeps.
///
/// There is no vectorised implementation yet; this forwards to the scalar row kernel.
///
/// # Safety
/// Adds no requirements beyond those of `midprice_row_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn midprice_row_avx2(
    high: &[f64],
    low: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    midprice_row_scalar(high, low, first, period, out)
}
1007
1008#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1009#[inline(always)]
1010pub unsafe fn midprice_row_avx512(
1011    high: &[f64],
1012    low: &[f64],
1013    first: usize,
1014    period: usize,
1015    out: &mut [f64],
1016) {
1017    if period <= 32 {
1018        midprice_row_avx512_short(high, low, first, period, out);
1019    } else {
1020        midprice_row_avx512_long(high, low, first, period, out);
1021    }
1022}
1023
/// AVX-512 row kernel variant selected for short periods (`period <= 32`).
///
/// In this file the body forwards its arguments unchanged to `midprice_row_scalar`;
/// no AVX-512 intrinsics are used here.
///
/// # Safety
/// NOTE(review): this body does not use SIMD intrinsics directly; callers must uphold
/// whatever `midprice_row_scalar` requires of `first`, `period` and `out` — confirm there.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn midprice_row_avx512_short(
    high: &[f64],
    low: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    midprice_row_scalar(high, low, first, period, out);
}
1035
/// AVX-512 row kernel variant selected for long periods (`period > 32`).
///
/// In this file the body forwards its arguments unchanged to `midprice_row_scalar`;
/// no AVX-512 intrinsics are used here.
///
/// # Safety
/// NOTE(review): this body does not use SIMD intrinsics directly; callers must uphold
/// whatever `midprice_row_scalar` requires of `first`, `period` and `out` — confirm there.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn midprice_row_avx512_long(
    high: &[f64],
    low: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    midprice_row_scalar(high, low, first, period, out);
}
1047
// Python binding: `midprice(high, low, period, kernel=None)`.
//
// Borrows both NumPy arrays as contiguous slices, validates the optional kernel name
// (single-series kernels only), runs the computation with the GIL released, and returns
// the result as a new 1-D NumPy array. Indicator errors surface as `ValueError`.
// Plain `//` comments are used so the Python-visible `__doc__` is left untouched.
#[cfg(feature = "python")]
#[pyfunction(name = "midprice")]
#[pyo3(signature = (high, low, period, kernel=None))]
pub fn midprice_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let (high_slice, low_slice) = (high.as_slice()?, low.as_slice()?);
    let kern = validate_kernel(kernel, false)?;

    let input = MidpriceInput::from_slices(
        high_slice,
        low_slice,
        MidpriceParams {
            period: Some(period),
        },
    );

    // The heavy lifting needs no Python objects, so let other Python threads run.
    let values = py
        .allow_threads(|| midprice_with_kernel(&input, kern))
        .map(|output| output.values)
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok(values.into_pyarray(py))
}
1073
// Python binding: `midprice_batch(high, low, period_range, kernel=None)`.
//
// Evaluates MIDPRICE for every period in `period_range` (passed straight through as the
// `MidpriceBatchRange::period` tuple) and returns a dict with:
//   "values"  — a `(rows, cols)` array, one row per period, written in place into a
//               freshly allocated NumPy buffer (warmup prefixes filled first);
//   "periods" — the period used for each row, as `u64`.
// Raises `ValueError` for empty input, mismatched lengths, all-NaN input, an invalid
// sweep, size overflow, or any error reported by the batch kernel.
// Plain `//` comments are used so the Python-visible `__doc__` is left untouched.
#[cfg(feature = "python")]
#[pyfunction(name = "midprice_batch")]
#[pyo3(signature = (high, low, period_range, kernel=None))]
pub fn midprice_batch_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let high_slice = high.as_slice()?;
    let low_slice = low.as_slice()?;
    let kern = validate_kernel(kernel, true)?;
    let sweep = MidpriceBatchRange {
        period: period_range,
    };

    let cols = high_slice.len();
    if cols == 0 {
        return Err(PyValueError::new_err("midprice: empty data"));
    }
    let low_len = low_slice.len();
    if low_len != cols {
        return Err(PyValueError::new_err(format!(
            "midprice: length mismatch: high={}, low={}",
            cols, low_len
        )));
    }

    // Index of the first bar where both series are non-NaN.
    let first = high_slice
        .iter()
        .zip(low_slice.iter())
        .position(|(h, l)| !h.is_nan() && !l.is_nan())
        .ok_or_else(|| PyValueError::new_err("midprice: All values are NaN"))?;

    let combos = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let rows = combos.len();
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("midprice: rows*cols overflow"))?;

    // Uninitialized output buffer; every cell is written below (prefixes, then kernel rows).
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let (warm, _max_p) = batch_warm_prefixes(&combos, first, cols)
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    // SAFETY: `slice_out` spans exactly `total` f64 slots; viewing them as `MaybeUninit`
    // lets the prefix initializer write into memory that has not been initialized yet.
    let out_mu = unsafe {
        std::slice::from_raw_parts_mut(
            slice_out.as_mut_ptr() as *mut std::mem::MaybeUninit<f64>,
            total,
        )
    };
    init_matrix_prefixes(out_mu, cols, &warm);

    let combos = py
        .allow_threads(|| {
            let batch_kernel = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                explicit => explicit,
            };
            // Translate the batch kernel into the per-row kernel; anything that is not an
            // AVX batch kernel (including `ScalarBatch`) runs the scalar rows.
            let row_kernel = match batch_kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                _ => Kernel::Scalar,
            };
            midprice_batch_inner_into(high_slice, low_slice, &sweep, row_kernel, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    let periods: Vec<u64> = combos
        .iter()
        .map(|combo| combo.period.unwrap() as u64)
        .collect();
    dict.set_item("periods", periods.into_pyarray(py))?;
    Ok(dict)
}
1153
// Python wrapper exposing the incremental `MidpriceStream` as `MidpriceStream`.
// Plain `//` comments are used so the Python-visible `__doc__` is left untouched.
#[cfg(feature = "python")]
#[pyclass(name = "MidpriceStream")]
pub struct MidpriceStreamPy {
    // The underlying Rust streaming state.
    stream: MidpriceStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl MidpriceStreamPy {
    // Builds a stream for `period`; construction errors are raised as `ValueError`.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        MidpriceStream::try_new(MidpriceParams {
            period: Some(period),
        })
        .map(|stream| Self { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    // Feeds one (high, low) bar; returns whatever the Rust stream yields (`None` -> Python `None`).
    fn update(&mut self, high: f64, low: f64) -> Option<f64> {
        self.stream.update(high, low)
    }
}
1177
1178pub fn midprice_into_slice(
1179    dst: &mut [f64],
1180    high: &[f64],
1181    low: &[f64],
1182    params: &MidpriceParams,
1183    kern: Kernel,
1184) -> Result<(), MidpriceError> {
1185    if high.is_empty() || low.is_empty() {
1186        return Err(MidpriceError::EmptyInputData);
1187    }
1188    if high.len() != low.len() {
1189        return Err(MidpriceError::MismatchedDataLength {
1190            high_len: high.len(),
1191            low_len: low.len(),
1192        });
1193    }
1194    if dst.len() != high.len() {
1195        return Err(MidpriceError::OutputLengthMismatch {
1196            expected: high.len(),
1197            got: dst.len(),
1198        });
1199    }
1200
1201    let period = params.period.unwrap_or(14);
1202    if period == 0 || period > high.len() {
1203        return Err(MidpriceError::InvalidPeriod {
1204            period,
1205            data_len: high.len(),
1206        });
1207    }
1208
1209    let first_valid_idx = match (0..high.len()).find(|&i| !high[i].is_nan() && !low[i].is_nan()) {
1210        Some(idx) => idx,
1211        None => return Err(MidpriceError::AllValuesNaN),
1212    };
1213
1214    if high.len() - first_valid_idx < period {
1215        return Err(MidpriceError::NotEnoughValidData {
1216            needed: period,
1217            valid: high.len() - first_valid_idx,
1218        });
1219    }
1220
1221    let chosen = match kern {
1222        Kernel::Auto => detect_best_kernel(),
1223        other => other,
1224    };
1225
1226    let warmup_end = first_valid_idx + period - 1;
1227    for v in &mut dst[..warmup_end] {
1228        *v = f64::NAN;
1229    }
1230
1231    match chosen {
1232        Kernel::Scalar | Kernel::ScalarBatch => {
1233            midprice_scalar(high, low, period, first_valid_idx, dst)
1234        }
1235        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1236        Kernel::Avx2 | Kernel::Avx2Batch => midprice_avx2(high, low, period, first_valid_idx, dst),
1237        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1238        Kernel::Avx512 | Kernel::Avx512Batch => {
1239            midprice_avx512(high, low, period, first_valid_idx, dst)
1240        }
1241        #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
1242        Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
1243            midprice_scalar(high, low, period, first_valid_idx, dst)
1244        }
1245        Kernel::Auto => midprice_scalar(high, low, period, first_valid_idx, dst),
1246    }
1247
1248    Ok(())
1249}
1250
/// WASM binding: computes MIDPRICE for `high`/`low` with the given `period`.
///
/// Allocates an output vector the length of `high` and fills it via `midprice_into_slice`
/// with `Kernel::Auto`. Any indicator error is returned to JS as a string value.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn midprice_js(high: &[f64], low: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let mut values = vec![0.0; high.len()];
    let params = MidpriceParams {
        period: Some(period),
    };

    match midprice_into_slice(&mut values, high, low, &params, Kernel::Auto) {
        Ok(()) => Ok(values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1265
/// WASM helper: reserves an uninitialized buffer with capacity for `len` `f64`s and hands
/// ownership to the caller.
///
/// The allocation is intentionally leaked; release it with `midprice_free(ptr, len)`, which
/// rebuilds a `Vec<f64>` of capacity `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn midprice_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` keeps the Vec from freeing its buffer when it goes out of scope.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1274
/// WASM helper: releases a buffer previously returned by `midprice_alloc(len)`.
///
/// `len` must be the same value that was passed to `midprice_alloc`. A null `ptr` is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn midprice_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` was produced by `midprice_alloc(len)`, i.e. it is the buffer of a
    // `Vec<f64>` created with capacity `len`. The Vec is rebuilt with length 0 because the
    // caller may never have written to the buffer; `Vec::from_raw_parts` requires the first
    // `length` elements to be initialized. `f64` has no destructor, so dropping the Vec
    // only deallocates.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
1284
/// WASM binding: computes MIDPRICE from raw input pointers into a raw output pointer.
///
/// `in_high_ptr` and `in_low_ptr` must each be readable for `len` `f64`s and `out_ptr`
/// writable for `len` `f64`s. If `out_ptr` equals either input pointer, the result is first
/// computed into a scratch buffer and then copied out, so inputs are not overwritten while
/// being read. NOTE(review): only exact pointer equality is detected, not partial overlap.
/// Uses `Kernel::Auto`; indicator errors are returned to JS as string values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn midprice_into(
    in_high_ptr: *const f64,
    in_low_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    let any_null = [in_high_ptr, in_low_ptr, out_ptr as *const f64]
        .iter()
        .any(|p| p.is_null());
    if any_null {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    let params = MidpriceParams {
        period: Some(period),
    };
    let to_js = |e: MidpriceError| JsValue::from_str(&e.to_string());

    // SAFETY: all pointers are non-null and, per this function's contract, valid for `len`
    // elements. The mutable output slice is only created over memory that either does not
    // coincide with the inputs or after the inputs have been fully consumed.
    unsafe {
        let high = std::slice::from_raw_parts(in_high_ptr, len);
        let low = std::slice::from_raw_parts(in_low_ptr, len);

        if in_high_ptr == out_ptr || in_low_ptr == out_ptr {
            let mut scratch = vec![0.0; len];
            midprice_into_slice(&mut scratch, high, low, &params, Kernel::Auto).map_err(to_js)?;
            std::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&scratch);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            midprice_into_slice(out, high, low, &params, Kernel::Auto).map_err(to_js)?;
        }
    }
    Ok(())
}
1319
/// Configuration object accepted by the JS `midprice_batch` binding (deserialized via serde).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct MidpriceBatchConfig {
    /// Period sweep as `(start, end, step)`, forwarded unchanged to `MidpriceBatchRange::period`.
    pub period_range: (usize, usize, usize),
}
1325
/// Result object returned to JS by the `midprice_batch` binding (serialized via serde).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct MidpriceBatchJsOutput {
    /// Flattened `rows * cols` result matrix.
    /// NOTE(review): presumably row-major with one row per entry of `periods`, matching the
    /// Python binding's `(rows, cols)` reshape — confirm against `midprice_batch_inner`.
    pub values: Vec<f64>,
    /// Period used for each row (`None` combos are reported as 14).
    pub periods: Vec<usize>,
    /// Number of rows (parameter combinations).
    pub rows: usize,
    /// Number of columns (input length).
    pub cols: usize,
}
1334
/// WASM binding (`midprice_batch`): runs MIDPRICE over a period sweep.
///
/// `config` must deserialize into `MidpriceBatchConfig`. The computation uses
/// `Kernel::Auto` without parallelism. The result is returned as a serialized
/// `MidpriceBatchJsOutput`. Config, computation and serialization errors are all reported
/// to JS as string values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = midprice_batch)]
pub fn midprice_batch_js(high: &[f64], low: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let MidpriceBatchConfig { period_range } =
        serde_wasm_bindgen::from_value::<MidpriceBatchConfig>(config)
            .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let range = MidpriceBatchRange {
        period: period_range,
    };
    let batch = midprice_batch_inner(high, low, &range, Kernel::Auto, false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let periods = batch
        .combos
        .iter()
        .map(|combo| combo.period.unwrap_or(14))
        .collect::<Vec<usize>>();

    serde_wasm_bindgen::to_value(&MidpriceBatchJsOutput {
        values: batch.values,
        periods,
        rows: batch.rows,
        cols: batch.cols,
    })
    .map_err(|e| JsValue::from_str(&e.to_string()))
}
1363
/// WASM binding: writes one MIDPRICE row per period of the sweep into `out_ptr`.
///
/// `in_high_ptr` and `in_low_ptr` must each be readable for `len` `f64`s. `out_ptr` must be
/// writable for `rows * len` `f64`s, where `rows` is the number of periods produced by
/// `(period_start, period_end, period_step)`. Returns `rows` on success.
///
/// Warmup prefixes are written first, then the batch kernel (`Kernel::Auto`, no parallelism)
/// fills the rows. If `out_ptr` equals either input pointer, results are computed into a
/// scratch buffer and copied out afterwards, so the inputs are not overwritten while being
/// read. NOTE(review): only exact pointer equality is detected, not partial overlap.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn midprice_batch_into(
    in_high_ptr: *const f64,
    in_low_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if len == 0 {
        return Err(JsValue::from_str("midprice: Empty data provided."));
    }
    if in_high_ptr.is_null() || in_low_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }
    // SAFETY: pointers are non-null and, per this function's contract, the inputs are valid
    // for `len` reads and the output for `rows * len` writes.
    unsafe {
        let high = std::slice::from_raw_parts(in_high_ptr, len);
        let low = std::slice::from_raw_parts(in_low_ptr, len);

        let range = MidpriceBatchRange {
            period: (period_start, period_end, period_step),
        };
        let combos = expand_grid(&range).map_err(|e| JsValue::from_str(&e.to_string()))?;
        let rows = combos.len();
        let total = rows
            .checked_mul(len)
            .ok_or_else(|| JsValue::from_str("rows*len overflow"))?;

        let first = (0..len)
            .find(|&i| !high[i].is_nan() && !low[i].is_nan())
            .ok_or_else(|| JsValue::from_str("All values are NaN"))?;
        let (warm, _max_p) = batch_warm_prefixes(&combos, first, len)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

        if in_high_ptr == out_ptr || in_low_ptr == out_ptr {
            // The scratch buffer is initialized up front. `Vec::set_len` over
            // uninitialized capacity would violate its safety contract. `vec![0.0; n]`
            // maps to a zeroed allocation, so the extra cost is small.
            let mut temp = vec![0.0_f64; total];
            let mu = std::slice::from_raw_parts_mut(
                temp.as_mut_ptr() as *mut std::mem::MaybeUninit<f64>,
                total,
            );
            init_matrix_prefixes(mu, len, &warm);
            midprice_batch_inner_into(high, low, &range, Kernel::Auto, false, &mut temp)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            std::slice::from_raw_parts_mut(out_ptr, total).copy_from_slice(&temp);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, total);
            let mu = std::slice::from_raw_parts_mut(
                out.as_mut_ptr() as *mut std::mem::MaybeUninit<f64>,
                total,
            );
            init_matrix_prefixes(mu, len, &warm);
            midprice_batch_inner_into(high, low, &range, Kernel::Auto, false, out)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
        Ok(rows)
    }
}
1424
1425#[cfg(test)]
1426mod tests {
1427    use super::*;
1428    use crate::skip_if_unsupported;
1429    use crate::utilities::data_loader::read_candles_from_csv;
1430
1431    fn check_midprice_partial_params(
1432        test_name: &str,
1433        kernel: Kernel,
1434    ) -> Result<(), Box<dyn Error>> {
1435        skip_if_unsupported!(kernel, test_name);
1436        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1437        let candles = read_candles_from_csv(file_path)?;
1438
1439        let default_params = MidpriceParams { period: None };
1440        let input = MidpriceInput::with_default_candles(&candles);
1441        let output = midprice_with_kernel(&input, kernel)?;
1442        assert_eq!(output.values.len(), candles.close.len());
1443        Ok(())
1444    }
1445
1446    fn check_midprice_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1447        skip_if_unsupported!(kernel, test_name);
1448        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1449        let candles = read_candles_from_csv(file_path)?;
1450
1451        let input = MidpriceInput::with_default_candles(&candles);
1452        let result = midprice_with_kernel(&input, kernel)?;
1453        let expected_last_five = [59583.0, 59583.0, 59583.0, 59486.0, 58989.0];
1454        let start = result.values.len().saturating_sub(5);
1455        for (i, &val) in result.values[start..].iter().enumerate() {
1456            let diff = (val - expected_last_five[i]).abs();
1457            assert!(
1458                diff < 1e-1,
1459                "[{}] MIDPRICE {:?} mismatch at idx {}: got {}, expected {}",
1460                test_name,
1461                kernel,
1462                i,
1463                val,
1464                expected_last_five[i]
1465            );
1466        }
1467        Ok(())
1468    }
1469
1470    fn check_midprice_default_candles(
1471        test_name: &str,
1472        kernel: Kernel,
1473    ) -> Result<(), Box<dyn Error>> {
1474        skip_if_unsupported!(kernel, test_name);
1475        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1476        let candles = read_candles_from_csv(file_path)?;
1477        let input = MidpriceInput::with_default_candles(&candles);
1478        let output = midprice_with_kernel(&input, kernel)?;
1479        assert_eq!(output.values.len(), candles.close.len());
1480        Ok(())
1481    }
1482
1483    fn check_midprice_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1484        skip_if_unsupported!(kernel, test_name);
1485        let highs = [10.0, 14.0, 12.0];
1486        let lows = [5.0, 6.0, 7.0];
1487        let params = MidpriceParams { period: Some(0) };
1488        let input = MidpriceInput::from_slices(&highs, &lows, params);
1489        let res = midprice_with_kernel(&input, kernel);
1490        assert!(
1491            res.is_err(),
1492            "[{}] MIDPRICE should fail with zero period",
1493            test_name
1494        );
1495        Ok(())
1496    }
1497
1498    fn check_midprice_period_exceeds_length(
1499        test_name: &str,
1500        kernel: Kernel,
1501    ) -> Result<(), Box<dyn Error>> {
1502        skip_if_unsupported!(kernel, test_name);
1503        let highs = [10.0, 14.0, 12.0];
1504        let lows = [5.0, 6.0, 7.0];
1505        let params = MidpriceParams { period: Some(10) };
1506        let input = MidpriceInput::from_slices(&highs, &lows, params);
1507        let res = midprice_with_kernel(&input, kernel);
1508        assert!(
1509            res.is_err(),
1510            "[{}] MIDPRICE should fail with period exceeding length",
1511            test_name
1512        );
1513        Ok(())
1514    }
1515
1516    fn check_midprice_very_small_dataset(
1517        test_name: &str,
1518        kernel: Kernel,
1519    ) -> Result<(), Box<dyn Error>> {
1520        skip_if_unsupported!(kernel, test_name);
1521        let highs = [42.0];
1522        let lows = [36.0];
1523        let params = MidpriceParams { period: Some(14) };
1524        let input = MidpriceInput::from_slices(&highs, &lows, params);
1525        let res = midprice_with_kernel(&input, kernel);
1526        assert!(
1527            res.is_err(),
1528            "[{}] MIDPRICE should fail with insufficient data",
1529            test_name
1530        );
1531        Ok(())
1532    }
1533
1534    fn check_midprice_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1535        skip_if_unsupported!(kernel, test_name);
1536        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1537        let candles = read_candles_from_csv(file_path)?;
1538
1539        let params = MidpriceParams { period: Some(10) };
1540        let input = MidpriceInput::with_default_candles(&candles);
1541        let first_result = midprice_with_kernel(&input, kernel)?;
1542
1543        let second_input = MidpriceInput::from_slices(
1544            &first_result.values,
1545            &first_result.values,
1546            MidpriceParams { period: Some(10) },
1547        );
1548        let second_result = midprice_with_kernel(&second_input, kernel)?;
1549        assert_eq!(second_result.values.len(), first_result.values.len());
1550        Ok(())
1551    }
1552
1553    fn check_midprice_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1554        skip_if_unsupported!(kernel, test_name);
1555        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1556        let candles = read_candles_from_csv(file_path)?;
1557        let input = MidpriceInput::with_default_candles(&candles);
1558        let res = midprice_with_kernel(&input, kernel)?;
1559        assert_eq!(res.values.len(), candles.close.len());
1560        if res.values.len() > 20 {
1561            for (i, &val) in res.values[20..].iter().enumerate() {
1562                assert!(
1563                    !val.is_nan(),
1564                    "[{}] Found unexpected NaN at out-index {}",
1565                    test_name,
1566                    20 + i
1567                );
1568            }
1569        }
1570        Ok(())
1571    }
1572
1573    fn check_midprice_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1574        skip_if_unsupported!(kernel, test_name);
1575        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1576        let candles = read_candles_from_csv(file_path)?;
1577
1578        let period = 14;
1579        let input = MidpriceInput::with_default_candles(&candles);
1580        let batch_output = midprice_with_kernel(&input, kernel)?.values;
1581
1582        let high = source_type(&candles, "high");
1583        let low = source_type(&candles, "low");
1584        let mut stream = MidpriceStream::try_new(MidpriceParams {
1585            period: Some(period),
1586        })?;
1587        let mut stream_values = Vec::with_capacity(high.len());
1588        for (&h, &l) in high.iter().zip(low.iter()) {
1589            match stream.update(h, l) {
1590                Some(val) => stream_values.push(val),
1591                None => stream_values.push(f64::NAN),
1592            }
1593        }
1594        assert_eq!(batch_output.len(), stream_values.len());
1595        for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
1596            if b.is_nan() && s.is_nan() {
1597                continue;
1598            }
1599            let diff = (b - s).abs();
1600            assert!(
1601                diff < 1e-9,
1602                "[{}] MIDPRICE streaming mismatch at idx {}: batch={}, stream={}, diff={}",
1603                test_name,
1604                i,
1605                b,
1606                s,
1607                diff
1608            );
1609        }
1610        Ok(())
1611    }
1612
1613    fn check_midprice_all_nan(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1614        skip_if_unsupported!(kernel, test_name);
1615        let highs = [f64::NAN, f64::NAN, f64::NAN];
1616        let lows = [f64::NAN, f64::NAN, f64::NAN];
1617        let params = MidpriceParams { period: Some(2) };
1618        let input = MidpriceInput::from_slices(&highs, &lows, params);
1619        let result = midprice_with_kernel(&input, kernel);
1620        assert!(
1621            result.is_err(),
1622            "[{}] Expected error for all NaN values",
1623            test_name
1624        );
1625        if let Err(e) = result {
1626            assert!(e.to_string().contains("All values are NaN"));
1627        }
1628        Ok(())
1629    }
1630
1631    #[cfg(debug_assertions)]
1632    fn check_midprice_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1633        skip_if_unsupported!(kernel, test_name);
1634
1635        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1636        let candles = read_candles_from_csv(file_path)?;
1637
1638        let test_params = vec![
1639            MidpriceParams::default(),
1640            MidpriceParams { period: Some(2) },
1641            MidpriceParams { period: Some(5) },
1642            MidpriceParams { period: Some(7) },
1643            MidpriceParams { period: Some(20) },
1644            MidpriceParams { period: Some(30) },
1645            MidpriceParams { period: Some(50) },
1646            MidpriceParams { period: Some(100) },
1647            MidpriceParams { period: Some(200) },
1648        ];
1649
1650        for (param_idx, params) in test_params.iter().enumerate() {
1651            let input = MidpriceInput::from_candles(&candles, "high", "low", params.clone());
1652            let output = midprice_with_kernel(&input, kernel)?;
1653
1654            for (i, &val) in output.values.iter().enumerate() {
1655                if val.is_nan() {
1656                    continue;
1657                }
1658
1659                let bits = val.to_bits();
1660
1661                if bits == 0x11111111_11111111 {
1662                    panic!(
1663                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1664						 with params: period={} (param set {})",
1665                        test_name,
1666                        val,
1667                        bits,
1668                        i,
1669                        params.period.unwrap_or(14),
1670                        param_idx
1671                    );
1672                }
1673
1674                if bits == 0x22222222_22222222 {
1675                    panic!(
1676                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1677						 with params: period={} (param set {})",
1678                        test_name,
1679                        val,
1680                        bits,
1681                        i,
1682                        params.period.unwrap_or(14),
1683                        param_idx
1684                    );
1685                }
1686
1687                if bits == 0x33333333_33333333 {
1688                    panic!(
1689                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1690						 with params: period={} (param set {})",
1691                        test_name,
1692                        val,
1693                        bits,
1694                        i,
1695                        params.period.unwrap_or(14),
1696                        param_idx
1697                    );
1698                }
1699            }
1700        }
1701
1702        Ok(())
1703    }
1704
1705    #[cfg(not(debug_assertions))]
1706    fn check_midprice_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1707        Ok(())
1708    }
1709
1710    #[cfg(test)]
1711    fn check_midprice_property(
1712        test_name: &str,
1713        kernel: Kernel,
1714    ) -> Result<(), Box<dyn std::error::Error>> {
1715        use proptest::prelude::*;
1716        skip_if_unsupported!(kernel, test_name);
1717
1718        let strat = (2usize..=100)
1719            .prop_flat_map(|period| {
1720                (
1721                    prop::collection::vec(
1722                        prop::strategy::Union::new(vec![
1723                            (0.001f64..0.1f64).boxed(),
1724                            (10f64..10000f64).boxed(),
1725                            (1e6f64..1e8f64).boxed(),
1726                        ])
1727                        .prop_filter("finite", |x| x.is_finite()),
1728                        period..=500,
1729                    ),
1730                    Just(period),
1731                    prop::strategy::Union::new(vec![
1732                        (0.0001f64..0.01f64).boxed(),
1733                        (0.01f64..0.1f64).boxed(),
1734                        (0.1f64..0.3f64).boxed(),
1735                    ]),
1736                    0usize..=9,
1737                )
1738            })
1739            .prop_map(|(base_prices, period, spread_factor, scenario)| {
1740                let len = base_prices.len();
1741                let mut highs = Vec::with_capacity(len);
1742                let mut lows = Vec::with_capacity(len);
1743
1744                match scenario {
1745                    0 => {
1746                        for &base in &base_prices {
1747                            let spread = base * spread_factor;
1748                            highs.push(base + spread / 2.0);
1749                            lows.push(base - spread / 2.0);
1750                        }
1751                    }
1752                    1 => {
1753                        let price = base_prices[0];
1754                        let spread = price * spread_factor;
1755                        for _ in 0..len {
1756                            highs.push(price + spread / 2.0);
1757                            lows.push(price - spread / 2.0);
1758                        }
1759                    }
1760                    2 => {
1761                        let start_price = base_prices[0];
1762                        let increment = 10.0;
1763                        for i in 0..len {
1764                            let price = start_price + (i as f64) * increment;
1765                            let spread = price * spread_factor;
1766                            highs.push(price + spread / 2.0);
1767                            lows.push(price - spread / 2.0);
1768                        }
1769                    }
1770                    3 => {
1771                        let start_price = base_prices[0];
1772                        let decrement = 10.0;
1773                        for i in 0..len {
1774                            let price = (start_price - (i as f64) * decrement).max(10.0);
1775                            let spread = price * spread_factor;
1776                            highs.push(price + spread / 2.0);
1777                            lows.push(price - spread / 2.0);
1778                        }
1779                    }
1780                    4 => {
1781                        for &base in &base_prices {
1782                            let spread = base * 0.01;
1783                            highs.push(base + spread);
1784                            lows.push(base);
1785                        }
1786                    }
1787                    5 => {
1788                        let amplitude = 100.0;
1789                        let offset = base_prices[0];
1790                        for i in 0..len {
1791                            let phase = (i as f64) * 0.1;
1792                            let price = offset + amplitude * phase.sin();
1793                            let spread = price.abs() * spread_factor;
1794                            highs.push(price + spread / 2.0);
1795                            lows.push(price - spread / 2.0);
1796                        }
1797                    }
1798                    6 => {
1799                        let step_size = 100.0;
1800                        let steps = 5;
1801                        for i in 0..len {
1802                            let step = (i * steps / len) as f64;
1803                            let price = base_prices[0] + step * step_size;
1804                            let spread = price * spread_factor;
1805                            highs.push(price + spread / 2.0);
1806                            lows.push(price - spread / 2.0);
1807                        }
1808                    }
1809                    7 => {
1810                        for (i, &base) in base_prices.iter().enumerate() {
1811                            let volatility = ((i as f64 * 0.1).sin() + 1.0) * 0.5;
1812                            let price = base * (1.0 + spread_factor * volatility);
1813                            let spread = price * spread_factor;
1814                            highs.push(price + spread / 2.0);
1815                            lows.push(price - spread / 2.0);
1816                        }
1817                    }
1818                    8 => {
1819                        for i in 0..len {
1820                            let price = base_prices[0] + (i as f64) * 5.0;
1821                            let spread = price * spread_factor.min(0.1);
1822                            highs.push(price + spread / 2.0);
1823                            lows.push(price - spread / 2.0);
1824                        }
1825                    }
1826                    _ => {
1827                        let mid_point = len / 2;
1828                        for i in 0..len {
1829                            let base = if i < mid_point {
1830                                base_prices[0] + (i as f64) * 20.0
1831                            } else {
1832                                base_prices[0] + (mid_point as f64) * 20.0
1833                                    - ((i - mid_point) as f64) * 20.0
1834                            };
1835                            let spread = base * spread_factor.min(0.05);
1836                            highs.push(base + spread / 2.0);
1837                            lows.push(base - spread / 2.0);
1838                        }
1839                    }
1840                }
1841
1842                (highs, lows, period)
1843            });
1844
1845        proptest::test_runner::TestRunner::default()
1846			.run(&strat, |(highs, lows, period)| {
1847				let params = MidpriceParams { period: Some(period) };
1848				let input = MidpriceInput::from_slices(&highs, &lows, params);
1849
1850
1851				let MidpriceOutput { values: out } = midprice_with_kernel(&input, kernel)?;
1852				let MidpriceOutput { values: ref_out } = midprice_with_kernel(&input, Kernel::Scalar)?;
1853
1854
1855				let first_valid = (0..highs.len())
1856					.find(|&i| !highs[i].is_nan() && !lows[i].is_nan())
1857					.unwrap_or(0);
1858
1859
1860				let warmup_end = first_valid + period - 1;
1861				for i in 0..warmup_end.min(out.len()) {
1862					prop_assert!(
1863						out[i].is_nan(),
1864						"Expected NaN during warmup at index {}, got {}",
1865						i, out[i]
1866					);
1867				}
1868
1869
1870				for i in warmup_end..out.len() {
1871					let y = out[i];
1872					let r = ref_out[i];
1873					let window_start = i + 1 - period;
1874
1875
1876					let window_high = &highs[window_start..=i];
1877					let window_low = &lows[window_start..=i];
1878					let max_high = window_high.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
1879					let min_low = window_low.iter().cloned().fold(f64::INFINITY, f64::min);
1880
1881
1882					let magnitude = y.abs().max(1.0);
1883					let tolerance = (magnitude * f64::EPSILON * 10.0).max(1e-9);
1884
1885					prop_assert!(
1886						y.is_nan() || (y >= min_low - tolerance && y <= max_high + tolerance),
1887						"Midprice {} at index {} outside bounds [{}, {}]",
1888						y, i, min_low, max_high
1889					);
1890
1891
1892					if y.is_finite() && r.is_finite() {
1893						let ulp_diff = y.to_bits().abs_diff(r.to_bits());
1894						prop_assert!(
1895							(y - r).abs() <= tolerance || ulp_diff <= 8,
1896							"Kernel mismatch at index {}: {} vs {} (ULP={})",
1897							i, y, r, ulp_diff
1898						);
1899					} else {
1900						prop_assert_eq!(
1901							y.to_bits(), r.to_bits(),
1902							"NaN/finite mismatch at index {}: {} vs {}",
1903							i, y, r
1904						);
1905					}
1906
1907
1908					if period == 1 {
1909						let expected = (highs[i] + lows[i]) / 2.0;
1910						prop_assert!(
1911							(y - expected).abs() <= tolerance,
1912							"Period=1 midprice {} != expected {} at index {}",
1913							y, expected, i
1914						);
1915					}
1916
1917
1918					if window_high.windows(2).all(|w| (w[0] - w[1]).abs() < tolerance) &&
1919					   window_low.windows(2).all(|w| (w[0] - w[1]).abs() < tolerance) {
1920						let expected = (window_high[0] + window_low[0]) / 2.0;
1921						prop_assert!(
1922							(y - expected).abs() <= tolerance,
1923							"Constant data midprice {} != expected {} at index {}",
1924							y, expected, i
1925						);
1926					}
1927
1928
1929					let actual_max_high = window_high.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
1930					let actual_min_low = window_low.iter().cloned().fold(f64::INFINITY, f64::min);
1931					let expected = (actual_max_high + actual_min_low) / 2.0;
1932					prop_assert!(
1933						(y - expected).abs() <= tolerance,
1934						"Midprice {} != expected {} at index {}",
1935						y, expected, i
1936					);
1937
1938
1939
1940
1941
1942					if period == 1 {
1943						prop_assert!(
1944							y >= lows[i] - tolerance && y <= highs[i] + tolerance,
1945							"When period=1, midprice {} should be within current bar's range [{}, {}] at index {}",
1946							y, lows[i], highs[i], i
1947						);
1948					}
1949
1950
1951
1952
1953					if i > warmup_end && window_high.len() > 1 && window_low.len() > 1 {
1954						let high_increasing = window_high.windows(2).all(|w| w[1] >= w[0] - tolerance);
1955						let high_decreasing = window_high.windows(2).all(|w| w[1] <= w[0] + tolerance);
1956						let low_increasing = window_low.windows(2).all(|w| w[1] >= w[0] - tolerance);
1957						let low_decreasing = window_low.windows(2).all(|w| w[1] <= w[0] + tolerance);
1958
1959
1960						if high_increasing && low_increasing {
1961
1962							if i > warmup_end {
1963								let prev_y = out[i - 1];
1964								if prev_y.is_finite() && y.is_finite() {
1965
1966
1967									let allowed_decrease = max_high * 0.1;
1968									prop_assert!(
1969										y >= prev_y - allowed_decrease,
1970										"Midprice should generally increase with monotonic increasing data: {} < {} - {} at index {}",
1971										y, prev_y, allowed_decrease, i
1972									);
1973								}
1974							}
1975						}
1976					}
1977				}
1978
1979				Ok(())
1980			})?;
1981
1982        Ok(())
1983    }
1984
1985    macro_rules! generate_all_midprice_tests {
1986        ($($test_fn:ident),*) => {
1987            paste::paste! {
1988                $(
1989                    #[test]
1990                    fn [<$test_fn _scalar_f64>]() {
1991                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1992                    }
1993                )*
1994                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1995                $(
1996                    #[test]
1997                    fn [<$test_fn _avx2_f64>]() {
1998                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1999                    }
2000                    #[test]
2001                    fn [<$test_fn _avx512_f64>]() {
2002                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
2003                    }
2004                )*
2005            }
2006        }
2007    }
2008
    // Instantiate the per-kernel #[test] wrappers for every single-series midprice check.
    // Note: the macro matcher is `$($test_fn:ident),*`, so the list must not end with a
    // trailing comma.
    generate_all_midprice_tests!(
        check_midprice_partial_params,
        check_midprice_accuracy,
        check_midprice_default_candles,
        check_midprice_zero_period,
        check_midprice_period_exceeds_length,
        check_midprice_very_small_dataset,
        check_midprice_reinput,
        check_midprice_nan_handling,
        check_midprice_streaming,
        check_midprice_all_nan,
        check_midprice_no_poison
    );

    // The proptest-based property check is instantiated on its own, behind a separate
    // `cfg(test)` gate.
    #[cfg(test)]
    generate_all_midprice_tests!(check_midprice_property);
2025
2026    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2027        skip_if_unsupported!(kernel, test);
2028
2029        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2030        let c = read_candles_from_csv(file)?;
2031
2032        let output = MidpriceBatchBuilder::new()
2033            .kernel(kernel)
2034            .apply_candles(&c, "high", "low")?;
2035
2036        let def = MidpriceParams::default();
2037        let row = output.values_for(&def).expect("default row missing");
2038
2039        assert_eq!(row.len(), c.close.len());
2040
2041        let expected = [59583.0, 59583.0, 59583.0, 59486.0, 58989.0];
2042        let start = row.len() - 5;
2043        for (i, &v) in row[start..].iter().enumerate() {
2044            assert!(
2045                (v - expected[i]).abs() < 1e-1,
2046                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
2047            );
2048        }
2049        Ok(())
2050    }
2051
2052    #[cfg(debug_assertions)]
2053    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2054        skip_if_unsupported!(kernel, test);
2055
2056        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2057        let c = read_candles_from_csv(file)?;
2058
2059        let test_configs = vec![
2060            (2, 10, 2),
2061            (5, 25, 5),
2062            (30, 60, 15),
2063            (2, 5, 1),
2064            (10, 10, 0),
2065            (14, 14, 0),
2066            (50, 100, 25),
2067        ];
2068
2069        for (cfg_idx, &(period_start, period_end, period_step)) in test_configs.iter().enumerate() {
2070            let output = MidpriceBatchBuilder::new()
2071                .kernel(kernel)
2072                .period_range(period_start, period_end, period_step)
2073                .apply_candles(&c, "high", "low")?;
2074
2075            for (idx, &val) in output.values.iter().enumerate() {
2076                if val.is_nan() {
2077                    continue;
2078                }
2079
2080                let bits = val.to_bits();
2081                let row = idx / output.cols;
2082                let col = idx % output.cols;
2083                let combo = &output.combos[row];
2084
2085                if bits == 0x11111111_11111111 {
2086                    panic!(
2087                        "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
2088						 at row {} col {} (flat index {}) with params: period={}",
2089                        test,
2090                        cfg_idx,
2091                        val,
2092                        bits,
2093                        row,
2094                        col,
2095                        idx,
2096                        combo.period.unwrap_or(14)
2097                    );
2098                }
2099
2100                if bits == 0x22222222_22222222 {
2101                    panic!(
2102                        "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
2103						 at row {} col {} (flat index {}) with params: period={}",
2104                        test,
2105                        cfg_idx,
2106                        val,
2107                        bits,
2108                        row,
2109                        col,
2110                        idx,
2111                        combo.period.unwrap_or(14)
2112                    );
2113                }
2114
2115                if bits == 0x33333333_33333333 {
2116                    panic!(
2117                        "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
2118						 at row {} col {} (flat index {}) with params: period={}",
2119                        test,
2120                        cfg_idx,
2121                        val,
2122                        bits,
2123                        row,
2124                        col,
2125                        idx,
2126                        combo.period.unwrap_or(14)
2127                    );
2128                }
2129            }
2130        }
2131
2132        Ok(())
2133    }
2134
2135    #[cfg(not(debug_assertions))]
2136    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2137        Ok(())
2138    }
2139
2140    macro_rules! gen_batch_tests {
2141        ($fn_name:ident) => {
2142            paste::paste! {
2143                #[test] fn [<$fn_name _scalar>]()      {
2144                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
2145                }
2146                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2147                #[test] fn [<$fn_name _avx2>]()        {
2148                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
2149                }
2150                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2151                #[test] fn [<$fn_name _avx512>]()      {
2152                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
2153                }
2154                #[test] fn [<$fn_name _auto_detect>]() {
2155                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
2156                }
2157            }
2158        };
2159    }
2160    gen_batch_tests!(check_batch_default_row);
2161    gen_batch_tests!(check_batch_no_poison);
2162
2163    #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
2164    #[test]
2165    fn test_midprice_into_matches_api() -> Result<(), Box<dyn Error>> {
2166        let n = 256usize;
2167        let mut high = Vec::with_capacity(n);
2168        let mut low = Vec::with_capacity(n);
2169        for i in 0..n {
2170            let t = i as f64;
2171            let h = 100.0 + 0.1 * t + (0.03 * t).sin();
2172            let l = h - (2.0 + ((i % 7) as f64));
2173            high.push(h);
2174            low.push(l);
2175        }
2176
2177        let params = MidpriceParams::default();
2178        let input = MidpriceInput::from_slices(&high, &low, params.clone());
2179
2180        let baseline = midprice(&input)?.values;
2181
2182        let mut out = vec![0.0; n];
2183        midprice_into(&input, &mut out)?;
2184
2185        assert_eq!(baseline.len(), out.len());
2186
2187        fn eq_or_both_nan(a: f64, b: f64) -> bool {
2188            (a.is_nan() && b.is_nan()) || (a - b).abs() <= 1e-12
2189        }
2190        for i in 0..n {
2191            assert!(
2192                eq_or_both_nan(baseline[i], out[i]),
2193                "Mismatch at index {}: baseline={}, into={}",
2194                i,
2195                baseline[i],
2196                out[i]
2197            );
2198        }
2199        Ok(())
2200    }
2201}
2202
2203#[cfg(feature = "python")]
2204pub fn register_midprice_module(m: &Bound<'_, pyo3::types::PyModule>) -> PyResult<()> {
2205    m.add_function(wrap_pyfunction!(midprice_py, m)?)?;
2206    m.add_function(wrap_pyfunction!(midprice_batch_py, m)?)?;
2207    Ok(())
2208}