1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::PyDict;
9#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
10use serde::{Deserialize, Serialize};
11#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
12use wasm_bindgen::prelude::*;
13
14use crate::utilities::data_loader::{source_type, Candles};
15use crate::utilities::enums::Kernel;
16use crate::utilities::helpers::{
17 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
18 make_uninit_matrix,
19};
20#[cfg(feature = "python")]
21use crate::utilities::kernel_validation::validate_kernel;
22#[cfg(not(target_arch = "wasm32"))]
23use rayon::prelude::*;
24use std::collections::VecDeque;
25use std::convert::AsRef;
26use std::error::Error;
27use std::mem::MaybeUninit;
28use thiserror::Error;
29
/// Input source for the MIDPOINT indicator: either a named series taken from
/// [`Candles`] or a caller-provided slice.
#[derive(Debug, Clone)]
pub enum MidpointData<'a> {
    /// A column of `candles`, selected by `source` (e.g. `"close"`) through `source_type`.
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A raw price series.
    Slice(&'a [f64]),
}
38
39impl<'a> AsRef<[f64]> for MidpointInput<'a> {
40 #[inline(always)]
41 fn as_ref(&self) -> &[f64] {
42 match &self.data {
43 MidpointData::Slice(slice) => slice,
44 MidpointData::Candles { candles, source } => source_type(candles, source),
45 }
46 }
47}
48
/// Output of the single-series MIDPOINT computation.
#[derive(Debug, Clone)]
pub struct MidpointOutput {
    /// One value per input sample; entries before the first complete window are NaN.
    pub values: Vec<f64>,
}
53
/// Parameters for the MIDPOINT indicator.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct MidpointParams {
    /// Look-back window length; `None` resolves to 14 (see `MidpointInput::get_period`).
    pub period: Option<usize>,
}

impl Default for MidpointParams {
    /// A 14-sample window.
    fn default() -> Self {
        Self { period: Some(14) }
    }
}
68
/// Input to the MIDPOINT indicator: a data source plus its parameters.
#[derive(Debug, Clone)]
pub struct MidpointInput<'a> {
    /// Where the price series comes from.
    pub data: MidpointData<'a>,
    /// Indicator parameters (window length).
    pub params: MidpointParams,
}
74
75impl<'a> MidpointInput<'a> {
76 #[inline]
77 pub fn from_candles(c: &'a Candles, s: &'a str, p: MidpointParams) -> Self {
78 Self {
79 data: MidpointData::Candles {
80 candles: c,
81 source: s,
82 },
83 params: p,
84 }
85 }
86 #[inline]
87 pub fn from_slice(sl: &'a [f64], p: MidpointParams) -> Self {
88 Self {
89 data: MidpointData::Slice(sl),
90 params: p,
91 }
92 }
93 #[inline]
94 pub fn with_default_candles(c: &'a Candles) -> Self {
95 Self::from_candles(c, "close", MidpointParams::default())
96 }
97 #[inline]
98 pub fn get_period(&self) -> usize {
99 self.params.period.unwrap_or(14)
100 }
101}
102
/// Fluent builder for single-series MIDPOINT computations and streams.
#[derive(Copy, Clone, Debug)]
pub struct MidpointBuilder {
    /// Period to use; `None` defers to the default resolution (14).
    period: Option<usize>,
    /// Kernel handed to `midpoint_with_kernel`.
    kernel: Kernel,
}

impl Default for MidpointBuilder {
    /// No explicit period and `Kernel::Auto`.
    fn default() -> Self {
        Self {
            period: None,
            kernel: Kernel::Auto,
        }
    }
}
117
118impl MidpointBuilder {
119 #[inline(always)]
120 pub fn new() -> Self {
121 Self::default()
122 }
123 #[inline(always)]
124 pub fn period(mut self, n: usize) -> Self {
125 self.period = Some(n);
126 self
127 }
128 #[inline(always)]
129 pub fn kernel(mut self, k: Kernel) -> Self {
130 self.kernel = k;
131 self
132 }
133 #[inline(always)]
134 pub fn apply(self, c: &Candles) -> Result<MidpointOutput, MidpointError> {
135 let p = MidpointParams {
136 period: self.period,
137 };
138 let i = MidpointInput::from_candles(c, "close", p);
139 midpoint_with_kernel(&i, self.kernel)
140 }
141 #[inline(always)]
142 pub fn apply_slice(self, d: &[f64]) -> Result<MidpointOutput, MidpointError> {
143 let p = MidpointParams {
144 period: self.period,
145 };
146 let i = MidpointInput::from_slice(d, p);
147 midpoint_with_kernel(&i, self.kernel)
148 }
149 #[inline(always)]
150 pub fn into_stream(self) -> Result<MidpointStream, MidpointError> {
151 let p = MidpointParams {
152 period: self.period,
153 };
154 MidpointStream::try_new(p)
155 }
156}
157
158#[derive(Debug, Error)]
159pub enum MidpointError {
160 #[error("midpoint: Empty input data (All values are NaN).")]
161 EmptyInputData,
162 #[error("midpoint: All values are NaN.")]
163 AllValuesNaN,
164 #[error("midpoint: Invalid period: period = {period}, data length = {data_len}")]
165 InvalidPeriod { period: usize, data_len: usize },
166 #[error("midpoint: Not enough valid data: needed = {needed}, valid = {valid}")]
167 NotEnoughValidData { needed: usize, valid: usize },
168 #[error("midpoint: output length mismatch: expected = {expected}, got = {got}")]
169 OutputLengthMismatch { expected: usize, got: usize },
170 #[error("midpoint: invalid kernel for batch: {0:?}")]
171 InvalidKernelForBatch(Kernel),
172 #[error("midpoint: invalid range: start={start}, end={end}, step={step}")]
173 InvalidRange {
174 start: usize,
175 end: usize,
176 step: usize,
177 },
178 #[error("midpoint: invalid input: {0}")]
179 InvalidInput(String),
180}
181
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
impl From<MidpointError> for JsValue {
    /// Converts the error into a JS string holding its `Display` message.
    fn from(err: MidpointError) -> Self {
        let message = err.to_string();
        Self::from_str(&message)
    }
}
188
189#[inline]
190pub fn midpoint(input: &MidpointInput) -> Result<MidpointOutput, MidpointError> {
191 midpoint_with_kernel(input, Kernel::Auto)
192}
193
194#[inline(always)]
195fn midpoint_prepare<'a>(
196 input: &'a MidpointInput,
197) -> Result<(&'a [f64], usize, usize, usize), MidpointError> {
198 let data: &[f64] = input.as_ref();
199 if data.is_empty() {
200 return Err(MidpointError::EmptyInputData);
201 }
202
203 let first = data
204 .iter()
205 .position(|x| !x.is_nan())
206 .ok_or(MidpointError::AllValuesNaN)?;
207 let len = data.len();
208 let period = input.get_period();
209
210 if period == 0 || period > len {
211 return Err(MidpointError::InvalidPeriod {
212 period,
213 data_len: len,
214 });
215 }
216 if (len - first) < period {
217 return Err(MidpointError::NotEnoughValidData {
218 needed: period,
219 valid: len - first,
220 });
221 }
222
223 Ok((data, period, first, len))
224}
225
226pub fn midpoint_with_kernel(
227 input: &MidpointInput,
228 kernel: Kernel,
229) -> Result<MidpointOutput, MidpointError> {
230 let (data, period, first, len) = midpoint_prepare(input)?;
231 let mut out = alloc_with_nan_prefix(len, first + period - 1);
232
233 let chosen = match kernel {
234 Kernel::Auto => detect_best_kernel(),
235 other => other,
236 };
237
238 unsafe {
239 match chosen {
240 Kernel::Scalar | Kernel::ScalarBatch => midpoint_scalar(data, period, first, &mut out),
241 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
242 Kernel::Avx2 | Kernel::Avx2Batch => midpoint_avx2(data, period, first, &mut out),
243 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
244 Kernel::Avx512 | Kernel::Avx512Batch => midpoint_avx512(data, period, first, &mut out),
245 _ => unreachable!(),
246 }
247 }
248
249 Ok(MidpointOutput { values: out })
250}
251
252#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
253#[inline]
254pub fn midpoint_into(input: &MidpointInput, out: &mut [f64]) -> Result<(), MidpointError> {
255 midpoint_into_slice(out, input, Kernel::Auto)
256}
257
258#[inline]
259pub fn midpoint_into_slice(
260 out: &mut [f64],
261 input: &MidpointInput,
262 kernel: Kernel,
263) -> Result<(), MidpointError> {
264 let (data, period, first, len) = midpoint_prepare(input)?;
265
266 if out.len() != len {
267 return Err(MidpointError::OutputLengthMismatch {
268 expected: len,
269 got: out.len(),
270 });
271 }
272
273 for i in 0..(first + period - 1) {
274 out[i] = f64::NAN;
275 }
276
277 let chosen = match kernel {
278 Kernel::Auto => detect_best_kernel(),
279 other => other,
280 };
281
282 unsafe {
283 match chosen {
284 Kernel::Scalar | Kernel::ScalarBatch => midpoint_scalar(data, period, first, out),
285 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
286 Kernel::Avx2 | Kernel::Avx2Batch => midpoint_avx2(data, period, first, out),
287 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
288 Kernel::Avx512 | Kernel::Avx512Batch => midpoint_avx512(data, period, first, out),
289 _ => unreachable!(),
290 }
291 }
292
293 Ok(())
294}
295
/// AVX-512 MIDPOINT kernel entry point; forwards to [`midpoint_avx512_long`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn midpoint_avx512(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    // `midpoint_avx512_long` is a safe fn; wrapping it in `unsafe` only triggered the
    // `unused_unsafe` lint.
    midpoint_avx512_long(data, period, first, out)
}
301
/// Scalar MIDPOINT kernel.
///
/// For every index `i` in `first + period - 1 .. data.len()`, writes the average of the
/// highest and lowest values of the window `data[i + 1 - period ..= i]` into `out[i]`.
/// Entries of `out` before the first complete window are left untouched.
///
/// NaN samples inside a window are skipped, because they never win a comparison. A window
/// with no non-NaN sample yields `0.0`, the average of the `f64::MIN` / `f64::MAX`
/// starting extremes.
///
/// The average is `(hi + lo) / 2`. Only when that sum overflows does the kernel fall back
/// to `hi / 2 + lo / 2`. For example, a window of `f64::MAX` values yields `f64::MAX`
/// rather than infinity.
///
/// `period == 0` writes nothing.
///
/// # Panics
/// Panics if at least one window is complete and `out` is shorter than `data`.
#[inline]
pub fn midpoint_scalar(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    let len = data.len();
    if period == 0 {
        return;
    }
    // Index of the last sample in the first complete window.
    let warm = first + period - 1;
    if warm >= len {
        return;
    }

    // `data[first..].windows(period)` yields exactly the windows ending at warm..len.
    for (slot, window) in out[warm..len]
        .iter_mut()
        .zip(data[first..].windows(period))
    {
        let mut highest = f64::MIN;
        let mut lowest = f64::MAX;
        for &val in window {
            if val > highest {
                highest = val;
            }
            if val < lowest {
                lowest = val;
            }
        }
        let sum = highest + lowest;
        *slot = if sum.is_finite() {
            sum / 2.0
        } else {
            // Overflowed (or genuinely infinite) sum: halving first avoids the overflow
            // and still yields inf/NaN when an extreme itself is infinite.
            highest / 2.0 + lowest / 2.0
        };
    }
}
319
/// AVX2 MIDPOINT kernel. No vectorised implementation exists yet, so this simply
/// runs [`midpoint_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn midpoint_avx2(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    midpoint_scalar(data, period, first, out);
}

/// AVX-512 kernel variant for short periods; currently runs [`midpoint_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn midpoint_avx512_short(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    midpoint_scalar(data, period, first, out);
}

/// AVX-512 kernel variant for long periods; currently runs [`midpoint_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn midpoint_avx512_long(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    midpoint_scalar(data, period, first, out);
}
337
/// Incremental MIDPOINT calculator fed one sample at a time.
///
/// Uses two monotonic deques of `(index, value)` pairs so the window maximum and
/// minimum are always at the deque fronts.
#[derive(Debug, Clone)]
pub struct MidpointStream {
    /// Window length (non-zero, enforced by `try_new`).
    period: usize,

    /// Candidates for the window maximum; values strictly decrease front to back.
    maxdq: VecDeque<(usize, f64)>,

    /// Candidates for the window minimum; values strictly increase front to back.
    mindq: VecDeque<(usize, f64)>,

    /// Index that will be assigned to the next sample passed to `update`.
    idx: usize,

    /// `None` until the first non-NaN sample arrives; afterwards, the number of further
    /// updates that must pass before output begins.
    warmup_left: Option<usize>,
}
350
351impl MidpointStream {
352 pub fn try_new(params: MidpointParams) -> Result<Self, MidpointError> {
353 let period = params.period.unwrap_or(14);
354 if period == 0 {
355 return Err(MidpointError::InvalidPeriod {
356 period,
357 data_len: 0,
358 });
359 }
360 Ok(Self {
361 period,
362 maxdq: VecDeque::with_capacity(period),
363 mindq: VecDeque::with_capacity(period),
364 idx: 0,
365 warmup_left: None,
366 })
367 }
368
369 #[inline(always)]
370 pub fn update(&mut self, value: f64) -> Option<f64> {
371 let i = self.idx;
372 self.idx = i.wrapping_add(1);
373
374 let start = i.saturating_add(1).saturating_sub(self.period);
375
376 while let Some(&(j, _)) = self.maxdq.front() {
377 if j < start {
378 self.maxdq.pop_front();
379 } else {
380 break;
381 }
382 }
383 while let Some(&(j, _)) = self.mindq.front() {
384 if j < start {
385 self.mindq.pop_front();
386 } else {
387 break;
388 }
389 }
390
391 if !value.is_nan() {
392 if self.warmup_left.is_none() {
393 self.warmup_left = Some(self.period.saturating_sub(1));
394 }
395
396 while let Some(&(_, v)) = self.maxdq.back() {
397 if v <= value {
398 self.maxdq.pop_back();
399 } else {
400 break;
401 }
402 }
403 self.maxdq.push_back((i, value));
404
405 while let Some(&(_, v)) = self.mindq.back() {
406 if v >= value {
407 self.mindq.pop_back();
408 } else {
409 break;
410 }
411 }
412 self.mindq.push_back((i, value));
413 }
414
415 match self.warmup_left {
416 None => return None,
417 Some(k) if k > 0 => {
418 self.warmup_left = Some(k - 1);
419 return None;
420 }
421 _ => {}
422 }
423
424 let hi = self.maxdq.front().map(|&(_, v)| v).unwrap_or(f64::MIN);
425 let lo = self.mindq.front().map(|&(_, v)| v).unwrap_or(f64::MAX);
426 Some(avg2_fast(hi, lo))
427 }
428}
429
/// Average of two values, `(a + b) / 2`, without spurious overflow.
///
/// The plain product `(a + b) * 0.5` is used whenever the sum is finite. Only an
/// overflowing (or genuinely infinite/NaN) sum takes the halves-first path, so large
/// same-sign inputs such as `(f64::MAX, f64::MAX)` return `f64::MAX` instead of infinity.
///
/// The previous `mul_add(0.5, 0.0)` form compiles to a software `fma` call on targets
/// without hardware FMA. It was therefore slower than a plain multiply, and multiplying by
/// 0.5 is exact anyway.
#[inline(always)]
fn avg2_fast(a: f64, b: f64) -> f64 {
    let sum = a + b;
    if sum.is_finite() {
        sum * 0.5
    } else {
        a * 0.5 + b * 0.5
    }
}
434
/// Inclusive period sweep for batch MIDPOINT, given as `(start, end, step)`.
///
/// `expand_grid_checked` expands it as follows:
/// - a `step` of 0 (or `start == end`) yields just `start`;
/// - `start > end` walks downward from `start`.
#[derive(Clone, Debug)]
pub struct MidpointBatchRange {
    pub period: (usize, usize, usize),
}
impl Default for MidpointBatchRange {
    /// Sweeps periods 14 through 263 in steps of 1.
    fn default() -> Self {
        Self {
            period: (14, 263, 1),
        }
    }
}
/// Builder for batch MIDPOINT runs over a period sweep.
#[derive(Clone, Debug, Default)]
pub struct MidpointBatchBuilder {
    /// Periods to compute.
    range: MidpointBatchRange,
    /// Kernel passed to `midpoint_batch_with_kernel`.
    kernel: Kernel,
}
451impl MidpointBatchBuilder {
452 pub fn new() -> Self {
453 Self::default()
454 }
455 pub fn kernel(mut self, k: Kernel) -> Self {
456 self.kernel = k;
457 self
458 }
459 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
460 self.range.period = (start, end, step);
461 self
462 }
463 pub fn period_static(mut self, p: usize) -> Self {
464 self.range.period = (p, p, 0);
465 self
466 }
467 pub fn apply_slice(self, data: &[f64]) -> Result<MidpointBatchOutput, MidpointError> {
468 midpoint_batch_with_kernel(data, &self.range, self.kernel)
469 }
470 pub fn with_default_slice(
471 data: &[f64],
472 k: Kernel,
473 ) -> Result<MidpointBatchOutput, MidpointError> {
474 MidpointBatchBuilder::new().kernel(k).apply_slice(data)
475 }
476 pub fn apply_candles(
477 self,
478 c: &Candles,
479 src: &str,
480 ) -> Result<MidpointBatchOutput, MidpointError> {
481 let slice = source_type(c, src);
482 self.apply_slice(slice)
483 }
484 pub fn with_default_candles(c: &Candles) -> Result<MidpointBatchOutput, MidpointError> {
485 MidpointBatchBuilder::new()
486 .kernel(Kernel::Auto)
487 .apply_candles(c, "close")
488 }
489}
490
491pub fn midpoint_batch_with_kernel(
492 data: &[f64],
493 sweep: &MidpointBatchRange,
494 k: Kernel,
495) -> Result<MidpointBatchOutput, MidpointError> {
496 let kernel = match k {
497 Kernel::Auto => detect_best_batch_kernel(),
498 other if other.is_batch() => other,
499 other => return Err(MidpointError::InvalidKernelForBatch(other)),
500 };
501 let simd = match kernel {
502 Kernel::Avx512Batch => Kernel::Avx512,
503 Kernel::Avx2Batch => Kernel::Avx2,
504 Kernel::ScalarBatch => Kernel::Scalar,
505 _ => unreachable!(),
506 };
507 midpoint_batch_par_slice(data, sweep, simd)
508}
509
/// Result of a batch MIDPOINT run.
#[derive(Clone, Debug)]
pub struct MidpointBatchOutput {
    /// Row-major `rows x cols` matrix; row `r` holds the series computed for `combos[r]`.
    pub values: Vec<f64>,
    /// Parameter set for each row, in row order.
    pub combos: Vec<MidpointParams>,
    /// Number of parameter combinations (rows).
    pub rows: usize,
    /// Length of the input series (columns).
    pub cols: usize,
}
517impl MidpointBatchOutput {
518 pub fn row_for_params(&self, p: &MidpointParams) -> Option<usize> {
519 self.combos
520 .iter()
521 .position(|c| c.period.unwrap_or(14) == p.period.unwrap_or(14))
522 }
523 pub fn values_for(&self, p: &MidpointParams) -> Option<&[f64]> {
524 self.row_for_params(p).map(|row| {
525 let start = row * self.cols;
526 &self.values[start..start + self.cols]
527 })
528 }
529}
530
531#[inline(always)]
532fn expand_grid_checked(r: &MidpointBatchRange) -> Result<Vec<MidpointParams>, MidpointError> {
533 fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, MidpointError> {
534 if step == 0 || start == end {
535 return Ok(vec![start]);
536 }
537 let mut out = Vec::new();
538 if start < end {
539 let mut v = start;
540 loop {
541 if v > end {
542 break;
543 }
544 out.push(v);
545 match v.checked_add(step) {
546 Some(next) => {
547 if next == v {
548 break;
549 }
550 v = next;
551 }
552 None => break,
553 }
554 }
555 } else {
556 let mut v = start;
557 loop {
558 if v < end {
559 break;
560 }
561 out.push(v);
562 if v == end {
563 break;
564 }
565 let next = match v.checked_sub(step) {
566 Some(n) => n,
567 None => break,
568 };
569 if next > v {
570 break;
571 }
572 v = next;
573 }
574 }
575 if out.is_empty() {
576 return Err(MidpointError::InvalidRange { start, end, step });
577 }
578 Ok(out)
579 }
580 let periods = axis_usize(r.period)?;
581 let mut out = Vec::with_capacity(periods.len());
582 for &p in &periods {
583 out.push(MidpointParams { period: Some(p) });
584 }
585 Ok(out)
586}
587
588#[inline(always)]
589pub fn midpoint_batch_slice(
590 data: &[f64],
591 sweep: &MidpointBatchRange,
592 kern: Kernel,
593) -> Result<MidpointBatchOutput, MidpointError> {
594 midpoint_batch_inner(data, sweep, kern, false)
595}
596
597#[inline(always)]
598pub fn midpoint_batch_par_slice(
599 data: &[f64],
600 sweep: &MidpointBatchRange,
601 kern: Kernel,
602) -> Result<MidpointBatchOutput, MidpointError> {
603 midpoint_batch_inner(data, sweep, kern, true)
604}
605
606#[inline(always)]
607fn midpoint_batch_inner(
608 data: &[f64],
609 sweep: &MidpointBatchRange,
610 kern: Kernel,
611 parallel: bool,
612) -> Result<MidpointBatchOutput, MidpointError> {
613 if data.is_empty() {
614 return Err(MidpointError::EmptyInputData);
615 }
616
617 let combos = expand_grid_checked(sweep)?;
618
619 let first = data
620 .iter()
621 .position(|x| !x.is_nan())
622 .ok_or(MidpointError::AllValuesNaN)?;
623 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
624 if data.len() - first < max_p {
625 return Err(MidpointError::NotEnoughValidData {
626 needed: max_p,
627 valid: data.len() - first,
628 });
629 }
630
631 let rows = combos.len();
632 let cols = data.len();
633
634 let mut buf_mu = make_uninit_matrix(rows, cols);
635 let warm_prefixes: Vec<usize> = combos
636 .iter()
637 .map(|c| first + c.period.unwrap() - 1)
638 .collect();
639 init_matrix_prefixes(&mut buf_mu, cols, &warm_prefixes);
640
641 let mut guard = core::mem::ManuallyDrop::new(buf_mu);
642 let out_mu: &mut [MaybeUninit<f64>] =
643 unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr(), guard.len()) };
644
645 let do_row = |row: usize, row_mu: &mut [MaybeUninit<f64>]| unsafe {
646 let period = combos[row].period.unwrap();
647
648 let row_out: &mut [f64] =
649 core::slice::from_raw_parts_mut(row_mu.as_mut_ptr() as *mut f64, row_mu.len());
650
651 match kern {
652 Kernel::Scalar => midpoint_row_scalar(data, first, period, row_out),
653 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
654 Kernel::Avx2 => midpoint_row_avx2(data, first, period, row_out),
655 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
656 Kernel::Avx512 => midpoint_row_avx512(data, first, period, row_out),
657 _ => unreachable!(),
658 }
659 };
660
661 if parallel {
662 #[cfg(not(target_arch = "wasm32"))]
663 {
664 out_mu
665 .par_chunks_mut(cols)
666 .enumerate()
667 .for_each(|(row, slice)| do_row(row, slice));
668 }
669 #[cfg(target_arch = "wasm32")]
670 {
671 for (row, slice) in out_mu.chunks_mut(cols).enumerate() {
672 do_row(row, slice);
673 }
674 }
675 } else {
676 for (row, slice) in out_mu.chunks_mut(cols).enumerate() {
677 do_row(row, slice);
678 }
679 }
680
681 let values: Vec<f64> = unsafe {
682 Vec::from_raw_parts(
683 guard.as_mut_ptr() as *mut f64,
684 guard.len(),
685 guard.capacity(),
686 )
687 };
688
689 Ok(MidpointBatchOutput {
690 values,
691 combos,
692 rows,
693 cols,
694 })
695}
696
697#[inline(always)]
698unsafe fn midpoint_row_scalar(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
699 midpoint_scalar(data, period, first, out)
700}
701#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
702#[inline(always)]
703unsafe fn midpoint_row_avx2(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
704 midpoint_scalar(data, period, first, out)
705}
706#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
707#[inline(always)]
708unsafe fn midpoint_row_avx512(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
709 midpoint_scalar(data, period, first, out)
710}
711#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
712#[inline(always)]
713unsafe fn midpoint_row_avx512_short(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
714 midpoint_scalar(data, period, first, out)
715}
716#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
717#[inline(always)]
718unsafe fn midpoint_row_avx512_long(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
719 midpoint_scalar(data, period, first, out)
720}
721
722#[cfg(test)]
723mod tests {
724 use super::*;
725 use crate::skip_if_unsupported;
726 use crate::utilities::data_loader::read_candles_from_csv;
727 #[cfg(feature = "proptest")]
728 use proptest::prelude::*;
729
730 fn check_midpoint_partial_params(
731 test_name: &str,
732 kernel: Kernel,
733 ) -> Result<(), Box<dyn Error>> {
734 skip_if_unsupported!(kernel, test_name);
735 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
736 let candles = read_candles_from_csv(file_path)?;
737 let default_params = MidpointParams { period: None };
738 let input = MidpointInput::from_candles(&candles, "close", default_params);
739 let output = midpoint_with_kernel(&input, kernel)?;
740 assert_eq!(output.values.len(), candles.close.len());
741 Ok(())
742 }
743 fn check_midpoint_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
744 skip_if_unsupported!(kernel, test_name);
745 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
746 let candles = read_candles_from_csv(file_path)?;
747 let input = MidpointInput::from_candles(&candles, "close", MidpointParams::default());
748 let result = midpoint_with_kernel(&input, kernel)?;
749 let expected_last_five = [59578.5, 59578.5, 59578.5, 58886.0, 58886.0];
750 let start = result.values.len().saturating_sub(5);
751 for (i, &val) in result.values[start..].iter().enumerate() {
752 let diff = (val - expected_last_five[i]).abs();
753 assert!(
754 diff < 1e-1,
755 "[{}] MIDPOINT {:?} mismatch at idx {}: got {}, expected {}",
756 test_name,
757 kernel,
758 i,
759 val,
760 expected_last_five[i]
761 );
762 }
763 Ok(())
764 }
765 fn check_midpoint_default_candles(
766 test_name: &str,
767 kernel: Kernel,
768 ) -> Result<(), Box<dyn Error>> {
769 skip_if_unsupported!(kernel, test_name);
770 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
771 let candles = read_candles_from_csv(file_path)?;
772 let input = MidpointInput::with_default_candles(&candles);
773 match input.data {
774 MidpointData::Candles { source, .. } => assert_eq!(source, "close"),
775 _ => panic!("Expected MidpointData::Candles"),
776 }
777 let output = midpoint_with_kernel(&input, kernel)?;
778 assert_eq!(output.values.len(), candles.close.len());
779 Ok(())
780 }
781 fn check_midpoint_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
782 skip_if_unsupported!(kernel, test_name);
783 let input_data = [10.0, 20.0, 30.0];
784 let params = MidpointParams { period: Some(0) };
785 let input = MidpointInput::from_slice(&input_data, params);
786 let res = midpoint_with_kernel(&input, kernel);
787 assert!(
788 res.is_err(),
789 "[{}] MIDPOINT should fail with zero period",
790 test_name
791 );
792 Ok(())
793 }
794 fn check_midpoint_period_exceeds_length(
795 test_name: &str,
796 kernel: Kernel,
797 ) -> Result<(), Box<dyn Error>> {
798 skip_if_unsupported!(kernel, test_name);
799 let data_small = [10.0, 20.0, 30.0];
800 let params = MidpointParams { period: Some(10) };
801 let input = MidpointInput::from_slice(&data_small, params);
802 let res = midpoint_with_kernel(&input, kernel);
803 assert!(
804 res.is_err(),
805 "[{}] MIDPOINT should fail with period exceeding length",
806 test_name
807 );
808 Ok(())
809 }
810 fn check_midpoint_very_small_dataset(
811 test_name: &str,
812 kernel: Kernel,
813 ) -> Result<(), Box<dyn Error>> {
814 skip_if_unsupported!(kernel, test_name);
815 let single_point = [42.0];
816 let params = MidpointParams { period: Some(9) };
817 let input = MidpointInput::from_slice(&single_point, params);
818 let res = midpoint_with_kernel(&input, kernel);
819 assert!(
820 res.is_err(),
821 "[{}] MIDPOINT should fail with insufficient data",
822 test_name
823 );
824 Ok(())
825 }
826 fn check_midpoint_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
827 skip_if_unsupported!(kernel, test_name);
828 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
829 let candles = read_candles_from_csv(file_path)?;
830 let first_params = MidpointParams { period: Some(14) };
831 let first_input = MidpointInput::from_candles(&candles, "close", first_params);
832 let first_result = midpoint_with_kernel(&first_input, kernel)?;
833 let second_params = MidpointParams { period: Some(14) };
834 let second_input = MidpointInput::from_slice(&first_result.values, second_params);
835 let second_result = midpoint_with_kernel(&second_input, kernel)?;
836 assert_eq!(second_result.values.len(), first_result.values.len());
837 Ok(())
838 }
839 fn check_midpoint_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
840 skip_if_unsupported!(kernel, test_name);
841 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
842 let candles = read_candles_from_csv(file_path)?;
843 let input = MidpointInput::from_candles(&candles, "close", MidpointParams::default());
844 let res = midpoint_with_kernel(&input, kernel)?;
845 assert_eq!(res.values.len(), candles.close.len());
846 if res.values.len() > 240 {
847 for (i, &val) in res.values[240..].iter().enumerate() {
848 assert!(
849 !val.is_nan(),
850 "[{}] Found unexpected NaN at out-index {}",
851 test_name,
852 240 + i
853 );
854 }
855 }
856 Ok(())
857 }
858 fn check_midpoint_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
859 skip_if_unsupported!(kernel, test_name);
860 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
861 let candles = read_candles_from_csv(file_path)?;
862 let period = 14;
863 let input = MidpointInput::from_candles(
864 &candles,
865 "close",
866 MidpointParams {
867 period: Some(period),
868 },
869 );
870 let batch_output = midpoint_with_kernel(&input, kernel)?.values;
871 let mut stream = MidpointStream::try_new(MidpointParams {
872 period: Some(period),
873 })?;
874 let mut stream_values = Vec::with_capacity(candles.close.len());
875 for &price in &candles.close {
876 match stream.update(price) {
877 Some(mid_val) => stream_values.push(mid_val),
878 None => stream_values.push(f64::NAN),
879 }
880 }
881 assert_eq!(batch_output.len(), stream_values.len());
882 for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
883 if b.is_nan() && s.is_nan() {
884 continue;
885 }
886 let diff = (b - s).abs();
887 assert!(
888 diff < 1e-9,
889 "[{}] MIDPOINT streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
890 test_name,
891 i,
892 b,
893 s,
894 diff
895 );
896 }
897 Ok(())
898 }
899
900 fn check_midpoint_constant_data(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
901 skip_if_unsupported!(kernel, test_name);
902
903 let constant_val = 42.5;
904 let data = vec![constant_val; 100];
905 let params = MidpointParams { period: Some(10) };
906 let input = MidpointInput::from_slice(&data, params);
907 let result = midpoint_with_kernel(&input, kernel)?;
908
909 for i in 9..100 {
910 assert!(
911 (result.values[i] - constant_val).abs() < 1e-10,
912 "[{}] Constant data should produce constant midpoint at index {}: expected {}, got {}",
913 test_name, i, constant_val, result.values[i]
914 );
915 }
916 Ok(())
917 }
918
919 fn check_midpoint_monotonic_increasing(
920 test_name: &str,
921 kernel: Kernel,
922 ) -> Result<(), Box<dyn Error>> {
923 skip_if_unsupported!(kernel, test_name);
924
925 let data: Vec<f64> = (0..100).map(|i| i as f64).collect();
926 let period = 10;
927 let params = MidpointParams {
928 period: Some(period),
929 };
930 let input = MidpointInput::from_slice(&data, params);
931 let result = midpoint_with_kernel(&input, kernel)?;
932
933 for i in (period - 1)..data.len() {
934 let expected = (data[i + 1 - period] + data[i]) / 2.0;
935 assert!(
936 (result.values[i] - expected).abs() < 1e-10,
937 "[{}] Monotonic increasing midpoint mismatch at index {}: expected {}, got {}",
938 test_name,
939 i,
940 expected,
941 result.values[i]
942 );
943 }
944 Ok(())
945 }
946
947 fn check_midpoint_monotonic_decreasing(
948 test_name: &str,
949 kernel: Kernel,
950 ) -> Result<(), Box<dyn Error>> {
951 skip_if_unsupported!(kernel, test_name);
952
953 let data: Vec<f64> = (0..100).map(|i| 100.0 - i as f64).collect();
954 let period = 10;
955 let params = MidpointParams {
956 period: Some(period),
957 };
958 let input = MidpointInput::from_slice(&data, params);
959 let result = midpoint_with_kernel(&input, kernel)?;
960
961 for i in (period - 1)..data.len() {
962 let expected = (data[i] + data[i + 1 - period]) / 2.0;
963 assert!(
964 (result.values[i] - expected).abs() < 1e-10,
965 "[{}] Monotonic decreasing midpoint mismatch at index {}: expected {}, got {}",
966 test_name,
967 i,
968 expected,
969 result.values[i]
970 );
971 }
972 Ok(())
973 }
974
975 fn check_midpoint_alternating_extremes(
976 test_name: &str,
977 kernel: Kernel,
978 ) -> Result<(), Box<dyn Error>> {
979 skip_if_unsupported!(kernel, test_name);
980
981 let mut data = Vec::with_capacity(100);
982 for i in 0..100 {
983 data.push(if i % 2 == 0 { 100.0 } else { 0.0 });
984 }
985 let period = 5;
986 let params = MidpointParams {
987 period: Some(period),
988 };
989 let input = MidpointInput::from_slice(&data, params);
990 let result = midpoint_with_kernel(&input, kernel)?;
991
992 for i in (period - 1)..data.len() {
993 assert!(
994 (result.values[i] - 50.0).abs() < 1e-10,
995 "[{}] Alternating extremes midpoint should be 50 at index {}: got {}",
996 test_name,
997 i,
998 result.values[i]
999 );
1000 }
1001 Ok(())
1002 }
1003
1004 fn check_midpoint_single_spike(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1005 skip_if_unsupported!(kernel, test_name);
1006
1007 let mut data = vec![10.0; 100];
1008 data[50] = 100.0;
1009
1010 let period = 5;
1011 let params = MidpointParams {
1012 period: Some(period),
1013 };
1014 let input = MidpointInput::from_slice(&data, params);
1015 let result = midpoint_with_kernel(&input, kernel)?;
1016
1017 for i in (period - 1)..50 {
1018 assert!(
1019 (result.values[i] - 10.0).abs() < 1e-10,
1020 "[{}] Before spike, midpoint should be 10 at index {}: got {}",
1021 test_name,
1022 i,
1023 result.values[i]
1024 );
1025 }
1026
1027 for i in 50..55 {
1028 let expected = 55.0;
1029 assert!(
1030 (result.values[i] - expected).abs() < 1e-10,
1031 "[{}] During spike window, midpoint should be 55 at index {}: got {}",
1032 test_name,
1033 i,
1034 result.values[i]
1035 );
1036 }
1037
1038 for i in 55..data.len() {
1039 assert!(
1040 (result.values[i] - 10.0).abs() < 1e-10,
1041 "[{}] Well after spike, midpoint should be 10 at index {}: got {}",
1042 test_name,
1043 i,
1044 result.values[i]
1045 );
1046 }
1047 Ok(())
1048 }
1049
1050 fn check_midpoint_boundary_values(
1051 test_name: &str,
1052 kernel: Kernel,
1053 ) -> Result<(), Box<dyn Error>> {
1054 skip_if_unsupported!(kernel, test_name);
1055
1056 let data = vec![f64::MAX, f64::MIN, 0.0, 1e-300, 1e300];
1057 let params = MidpointParams { period: Some(2) };
1058 let input = MidpointInput::from_slice(&data, params.clone());
1059 let result = midpoint_with_kernel(&input, kernel)?;
1060
1061 assert!(
1062 result.values[1].is_finite(),
1063 "[{}] Result should be finite for extreme values",
1064 test_name
1065 );
1066
1067 let small_data = vec![1e-300, 2e-300, 3e-300, 4e-300, 5e-300];
1068 let input_small = MidpointInput::from_slice(&small_data, params);
1069 let result_small = midpoint_with_kernel(&input_small, kernel)?;
1070
1071 for i in 1..small_data.len() {
1072 assert!(
1073 result_small.values[i] > 0.0 && result_small.values[i].is_finite(),
1074 "[{}] Should handle very small values at index {}",
1075 test_name,
1076 i
1077 );
1078 }
1079 Ok(())
1080 }
1081
1082 fn check_midpoint_period_one(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1083 skip_if_unsupported!(kernel, test_name);
1084
1085 let data = vec![1.0, 5.0, 3.0, 7.0, 2.0, 9.0, 4.0];
1086 let params = MidpointParams { period: Some(1) };
1087 let input = MidpointInput::from_slice(&data, params);
1088 let result = midpoint_with_kernel(&input, kernel)?;
1089
1090 for i in 0..data.len() {
1091 assert!(
1092 (result.values[i] - data[i]).abs() < 1e-10,
1093 "[{}] Period=1 should return input value at index {}: expected {}, got {}",
1094 test_name,
1095 i,
1096 data[i],
1097 result.values[i]
1098 );
1099 }
1100 Ok(())
1101 }
1102
1103 fn check_midpoint_all_same_except_one(
1104 test_name: &str,
1105 kernel: Kernel,
1106 ) -> Result<(), Box<dyn Error>> {
1107 skip_if_unsupported!(kernel, test_name);
1108
1109 let mut data = vec![50.0; 100];
1110 data[30] = 150.0;
1111
1112 let period = 20;
1113 let params = MidpointParams {
1114 period: Some(period),
1115 };
1116 let input = MidpointInput::from_slice(&data, params);
1117 let result = midpoint_with_kernel(&input, kernel)?;
1118
1119 for i in (period - 1)..30 {
1120 assert!(
1121 (result.values[i] - 50.0).abs() < 1e-10,
1122 "[{}] Before outlier window, midpoint should be 50 at index {}: got {}",
1123 test_name,
1124 i,
1125 result.values[i]
1126 );
1127 }
1128
1129 for i in 30..50 {
1130 let expected = 100.0;
1131 assert!(
1132 (result.values[i] - expected).abs() < 1e-10,
1133 "[{}] With outlier in window, midpoint should be 100 at index {}: got {}",
1134 test_name,
1135 i,
1136 result.values[i]
1137 );
1138 }
1139
1140 for i in 50..data.len() {
1141 assert!(
1142 (result.values[i] - 50.0).abs() < 1e-10,
1143 "[{}] Well after outlier, midpoint should be 50 at index {}: got {}",
1144 test_name,
1145 i,
1146 result.values[i]
1147 );
1148 }
1149 Ok(())
1150 }
1151
1152 #[cfg(debug_assertions)]
1153 fn check_midpoint_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1154 skip_if_unsupported!(kernel, test_name);
1155
1156 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1157 let candles = read_candles_from_csv(file_path)?;
1158
1159 let test_params = vec![
1160 MidpointParams::default(),
1161 MidpointParams { period: Some(2) },
1162 MidpointParams { period: Some(5) },
1163 MidpointParams { period: Some(7) },
1164 MidpointParams { period: Some(10) },
1165 MidpointParams { period: Some(20) },
1166 MidpointParams { period: Some(30) },
1167 MidpointParams { period: Some(50) },
1168 MidpointParams { period: Some(100) },
1169 MidpointParams { period: Some(200) },
1170 ];
1171
1172 for (param_idx, params) in test_params.iter().enumerate() {
1173 let input = MidpointInput::from_candles(&candles, "close", params.clone());
1174 let output = midpoint_with_kernel(&input, kernel)?;
1175
1176 for (i, &val) in output.values.iter().enumerate() {
1177 if val.is_nan() {
1178 continue;
1179 }
1180
1181 let bits = val.to_bits();
1182
1183 if bits == 0x11111111_11111111 {
1184 panic!(
1185 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1186 with params: period={} (param set {})",
1187 test_name,
1188 val,
1189 bits,
1190 i,
1191 params.period.unwrap_or(14),
1192 param_idx
1193 );
1194 }
1195
1196 if bits == 0x22222222_22222222 {
1197 panic!(
1198 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1199 with params: period={} (param set {})",
1200 test_name,
1201 val,
1202 bits,
1203 i,
1204 params.period.unwrap_or(14),
1205 param_idx
1206 );
1207 }
1208
1209 if bits == 0x33333333_33333333 {
1210 panic!(
1211 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1212 with params: period={} (param set {})",
1213 test_name,
1214 val,
1215 bits,
1216 i,
1217 params.period.unwrap_or(14),
1218 param_idx
1219 );
1220 }
1221 }
1222 }
1223
1224 Ok(())
1225 }
1226
    /// Release-build stub for the poison scan, which is compiled only under
    /// `debug_assertions`. It keeps the generated `check_midpoint_no_poison_*` tests
    /// compiling and always succeeds.
    #[cfg(not(debug_assertions))]
    fn check_midpoint_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
1231
    /// Property-based check of the midpoint indicator for `kernel`.
    ///
    /// Each generated case combines a period in `1..50`, a length in `50..400`, one of
    /// ten data scenarios and a seed for a small LCG. For every case it verifies:
    /// - output length equals input length and the warm-up prefix is NaN;
    /// - every post-warm-up value equals `(max + min) / 2` of a naive window scan;
    /// - `kernel` agrees with `Kernel::Scalar` (NaN positions must match too);
    /// - period 1 reproduces the input; (near-)constant data and windows reproduce the value;
    /// - for the monotonic ramp scenarios, the midpoint is the mean of the window endpoints;
    /// - (debug builds) no allocation-poison bit pattern appears in the output.
    ///
    /// A failing case is returned as `Err` through the trailing `?`.
    #[cfg(feature = "proptest")]
    // NOTE(review): no direct float `==` is visible in this body; this allow may be stale.
    #[allow(clippy::float_cmp)]
    fn check_midpoint_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);

        let strat = (1usize..50, 50usize..400, 0usize..9, any::<u64>()).prop_map(
            |(period, len, scenario, seed)| {
                // Deterministic LCG seeded per case. Each draw is
                // ((state / 65536) % 1_000_000) / 10_000 - 50, i.e. in [-50, 50) in 1e-4 steps.
                let mut lcg = seed;
                let mut rng = || {
                    lcg = lcg.wrapping_mul(1103515245).wrapping_add(12345);
                    ((lcg / 65536) % 1000000) as f64 / 10000.0 - 50.0
                };

                let data = match scenario {
                    // Independent random values.
                    0 => (0..len).map(|_| rng()).collect(),
                    // A single constant value.
                    1 => {
                        let val = rng();
                        vec![val; len]
                    }
                    // Non-decreasing ramp.
                    2 => {
                        let start = rng();
                        let step = rng().abs() / 100.0;
                        (0..len).map(|i| start + (i as f64) * step).collect()
                    }
                    // Non-increasing ramp.
                    3 => {
                        let start = rng();
                        let step = rng().abs() / 100.0;
                        (0..len).map(|i| start - (i as f64) * step).collect()
                    }
                    // Alternating around +1000 / -1000.
                    4 => (0..len)
                        .map(|i| {
                            if i % 2 == 0 {
                                1000.0 + rng()
                            } else {
                                -1000.0 + rng()
                            }
                        })
                        .collect(),
                    // Sine wave with random amplitude (>= 10) and offset.
                    5 => {
                        let amplitude = rng().abs() + 10.0;
                        let offset = rng();
                        (0..len)
                            .map(|i| offset + amplitude * (i as f64 * 0.1).sin())
                            .collect()
                    }
                    // Large magnitudes.
                    6 => (0..len).map(|_| rng() * 1e6).collect(),
                    // Small magnitudes.
                    7 => (0..len).map(|_| rng() * 1e-3).collect(),
                    // Mixed magnitudes, cycling large / small / unit every three bars.
                    8 => (0..len)
                        .map(|i| {
                            if i % 3 == 0 {
                                rng() * 1e6
                            } else if i % 3 == 1 {
                                rng() * 1e-3
                            } else {
                                rng()
                            }
                        })
                        .collect(),
                    // Triangle wave with a 20-bar cycle.
                    _ => {
                        let amplitude = rng().abs() + 10.0;
                        let period_len = 20;
                        (0..len)
                            .map(|i| {
                                let phase = (i % period_len) as f64 / period_len as f64;
                                if phase < 0.5 {
                                    amplitude * (2.0 * phase)
                                } else {
                                    amplitude * (2.0 - 2.0 * phase)
                                }
                            })
                            .collect()
                    }
                };

                (data, period, scenario)
            },
        );

        proptest::test_runner::TestRunner::default().run(&strat, |(data, period, scenario)| {
            let params = MidpointParams {
                period: Some(period),
            };
            let input = MidpointInput::from_slice(&data, params);

            let result = midpoint_with_kernel(&input, kernel)?;
            let scalar_result = midpoint_with_kernel(&input, Kernel::Scalar)?;

            // Relative tolerance of 1e-12 with an absolute floor of 1e-10.
            let tolerance = |expected: f64| -> f64 { (expected.abs() * 1e-12).max(1e-10) };

            prop_assert_eq!(result.values.len(), data.len());

            // Warm-up ends `period - 1` bars after the first non-NaN input.
            let first = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
            let warmup_end = first + period - 1;

            for i in 0..warmup_end.min(result.values.len()) {
                prop_assert!(
                    result.values[i].is_nan(),
                    "Expected NaN at index {} during warmup (warmup_end={})",
                    i,
                    warmup_end
                );
            }

            // Reference: a naive scan of each trailing window for its max and min.
            for i in warmup_end..data.len() {
                let window = &data[(i + 1 - period)..=i];
                let mut highest = f64::MIN;
                let mut lowest = f64::MAX;

                for &val in window {
                    if val > highest {
                        highest = val;
                    }
                    if val < lowest {
                        lowest = val;
                    }
                }

                let expected = (highest + lowest) / 2.0;
                let actual = result.values[i];
                let tol = tolerance(expected);

                prop_assert!(
                    (actual - expected).abs() < tol,
                    "Mathematical accuracy failed at index {}: expected {}, got {}, tolerance {}",
                    i,
                    expected,
                    actual,
                    tol
                );
            }

            // Kernel vs. scalar. A NaN on only one side makes the difference NaN, which
            // fails the `<` comparison.
            for i in 0..result.values.len() {
                let kernel_val = result.values[i];
                let scalar_val = scalar_result.values[i];

                if kernel_val.is_nan() && scalar_val.is_nan() {
                    continue;
                }

                let tol = tolerance(scalar_val);
                prop_assert!(
                    (kernel_val - scalar_val).abs() < tol,
                    "Kernel consistency failed at index {}: kernel={}, scalar={}, tolerance={}",
                    i,
                    kernel_val,
                    scalar_val,
                    tol
                );
            }

            if period == 1 {
                for i in first..data.len() {
                    let tol = tolerance(data[i]);
                    prop_assert!(
                        (result.values[i] - data[i]).abs() < tol,
                        "Period=1 should equal input at index {}: {} vs {}, tolerance {}",
                        i,
                        result.values[i],
                        data[i],
                        tol
                    );
                }
            }

            // Whole-series (near-)constant data must produce that constant after warm-up.
            if !data.is_empty() {
                let first_val = data[first];
                let val_tol = tolerance(first_val);
                if data.windows(2).all(|w| (w[0] - w[1]).abs() < val_tol) {
                    for i in warmup_end..data.len() {
                        prop_assert!(
                            (result.values[i] - first_val).abs() < val_tol,
                            "Constant data should produce constant output at index {}",
                            i
                        );
                    }
                }
            }

            // Same property applied per window.
            for i in warmup_end..data.len() {
                let window = &data[(i + 1 - period)..=i];
                if !window.is_empty() {
                    let window_val = window[0];
                    let win_tol = tolerance(window_val);
                    if window.windows(2).all(|w| (w[0] - w[1]).abs() < win_tol) {
                        prop_assert!(
                            (result.values[i] - window_val).abs() < win_tol,
                            "Window with identical values should produce that value at index {}",
                            i
                        );
                    }
                }
            }

            // Ramps: the extremes of each window are its endpoints.
            if scenario == 2 || scenario == 3 {
                for i in warmup_end..data.len() {
                    let window_start = data[i + 1 - period];
                    let window_end = data[i];
                    let expected_midpoint = (window_start + window_end) / 2.0;
                    let tol = tolerance(expected_midpoint);

                    prop_assert!(
                        (result.values[i] - expected_midpoint).abs() < tol,
                        "Monotonic data midpoint mismatch at index {}: expected {}, got {}, tolerance {}",
                        i, expected_midpoint, result.values[i], tol
                    );
                }
            }

            // The same poison bit patterns checked by the dedicated no-poison tests.
            #[cfg(debug_assertions)]
            {
                for (i, &val) in result.values.iter().enumerate() {
                    if val.is_nan() {
                        continue;
                    }

                    let bits = val.to_bits();
                    prop_assert!(
                        bits != 0x11111111_11111111
                            && bits != 0x22222222_22222222
                            && bits != 0x33333333_33333333,
                        "Found poison value at index {}: {} (0x{:016X})",
                        i,
                        val,
                        bits
                    );
                }
            }

            Ok(())
        })?;

        Ok(())
    }
1468
1469 macro_rules! generate_all_midpoint_tests {
1470 ($($test_fn:ident),*) => {
1471 paste::paste! {
1472 $( #[test] fn [<$test_fn _scalar_f64>]() { let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar); } )*
1473 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1474 $( #[test] fn [<$test_fn _avx2_f64>]() { let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2); }
1475 #[test] fn [<$test_fn _avx512_f64>]() { let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512); }
1476 )*
1477 }
1478 }
1479 }
1480 generate_all_midpoint_tests!(
1481 check_midpoint_partial_params,
1482 check_midpoint_accuracy,
1483 check_midpoint_default_candles,
1484 check_midpoint_zero_period,
1485 check_midpoint_period_exceeds_length,
1486 check_midpoint_very_small_dataset,
1487 check_midpoint_reinput,
1488 check_midpoint_nan_handling,
1489 check_midpoint_streaming,
1490 check_midpoint_no_poison,
1491 check_midpoint_constant_data,
1492 check_midpoint_monotonic_increasing,
1493 check_midpoint_monotonic_decreasing,
1494 check_midpoint_alternating_extremes,
1495 check_midpoint_single_spike,
1496 check_midpoint_boundary_values,
1497 check_midpoint_period_one,
1498 check_midpoint_all_same_except_one
1499 );
1500
1501 #[cfg(feature = "proptest")]
1502 generate_all_midpoint_tests!(check_midpoint_property);
1503
1504 fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1505 skip_if_unsupported!(kernel, test);
1506 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1507 let c = read_candles_from_csv(file)?;
1508 let output = MidpointBatchBuilder::new()
1509 .kernel(kernel)
1510 .apply_candles(&c, "close")?;
1511 let def = MidpointParams::default();
1512 let row = output.values_for(&def).expect("default row missing");
1513 assert_eq!(row.len(), c.close.len());
1514 let expected = [59578.5, 59578.5, 59578.5, 58886.0, 58886.0];
1515 let start = row.len() - 5;
1516 for (i, &v) in row[start..].iter().enumerate() {
1517 assert!(
1518 (v - expected[i]).abs() < 1e-1,
1519 "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
1520 );
1521 }
1522 Ok(())
1523 }
1524
1525 fn check_batch_multiple_periods(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1526 skip_if_unsupported!(kernel, test);
1527 let data: Vec<f64> = (0..100).map(|i| i as f64).collect();
1528
1529 let output = MidpointBatchBuilder::new()
1530 .kernel(kernel)
1531 .period_range(5, 15, 5)
1532 .apply_slice(&data)?;
1533
1534 assert_eq!(output.rows, 3);
1535 assert_eq!(output.cols, 100);
1536 assert_eq!(output.combos.len(), 3);
1537
1538 let periods = [5, 10, 15];
1539 for (row_idx, &period) in periods.iter().enumerate() {
1540 let row = output
1541 .values_for(&MidpointParams {
1542 period: Some(period),
1543 })
1544 .expect(&format!("Missing row for period {}", period));
1545
1546 for i in 0..(period - 1) {
1547 assert!(
1548 row[i].is_nan(),
1549 "[{}] Expected NaN at index {} for period {}",
1550 test,
1551 i,
1552 period
1553 );
1554 }
1555
1556 for i in (period - 1)..100 {
1557 let expected = (data[i + 1 - period] + data[i]) / 2.0;
1558 assert!(
1559 (row[i] - expected).abs() < 1e-10,
1560 "[{}] Period {} mismatch at index {}: expected {}, got {}",
1561 test,
1562 period,
1563 i,
1564 expected,
1565 row[i]
1566 );
1567 }
1568 }
1569 Ok(())
1570 }
1571
1572 fn check_batch_edge_cases(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1573 skip_if_unsupported!(kernel, test);
1574
1575 let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
1576 let output = MidpointBatchBuilder::new()
1577 .kernel(kernel)
1578 .period_range(3, 5, 10)
1579 .apply_slice(&data)?;
1580
1581 assert_eq!(
1582 output.rows, 1,
1583 "[{}] Step > range should give single row",
1584 test
1585 );
1586 assert_eq!(output.combos[0].period, Some(3));
1587
1588 let output2 = MidpointBatchBuilder::new()
1589 .kernel(kernel)
1590 .period_static(3)
1591 .apply_slice(&data)?;
1592
1593 assert_eq!(
1594 output2.rows, 1,
1595 "[{}] Static period should give single row",
1596 test
1597 );
1598 assert_eq!(output2.combos[0].period, Some(3));
1599
1600 let empty_data: Vec<f64> = vec![];
1601 let result = MidpointBatchBuilder::new()
1602 .kernel(kernel)
1603 .period_static(3)
1604 .apply_slice(&empty_data);
1605
1606 assert!(result.is_err(), "[{}] Empty data should fail", test);
1607
1608 Ok(())
1609 }
1610
1611 #[cfg(debug_assertions)]
1612 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1613 skip_if_unsupported!(kernel, test);
1614
1615 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1616 let c = read_candles_from_csv(file)?;
1617
1618 let test_configs = vec![
1619 (2, 10, 2),
1620 (5, 25, 5),
1621 (30, 60, 15),
1622 (2, 5, 1),
1623 (10, 20, 2),
1624 (50, 100, 10),
1625 (14, 14, 0),
1626 (3, 7, 1),
1627 ];
1628
1629 for (cfg_idx, &(period_start, period_end, period_step)) in test_configs.iter().enumerate() {
1630 let output = MidpointBatchBuilder::new()
1631 .kernel(kernel)
1632 .period_range(period_start, period_end, period_step)
1633 .apply_candles(&c, "close")?;
1634
1635 for (idx, &val) in output.values.iter().enumerate() {
1636 if val.is_nan() {
1637 continue;
1638 }
1639
1640 let bits = val.to_bits();
1641 let row = idx / output.cols;
1642 let col = idx % output.cols;
1643 let combo = &output.combos[row];
1644
1645 if bits == 0x11111111_11111111 {
1646 panic!(
1647 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
1648 at row {} col {} (flat index {}) with params: period={}",
1649 test,
1650 cfg_idx,
1651 val,
1652 bits,
1653 row,
1654 col,
1655 idx,
1656 combo.period.unwrap_or(14)
1657 );
1658 }
1659
1660 if bits == 0x22222222_22222222 {
1661 panic!(
1662 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
1663 at row {} col {} (flat index {}) with params: period={}",
1664 test,
1665 cfg_idx,
1666 val,
1667 bits,
1668 row,
1669 col,
1670 idx,
1671 combo.period.unwrap_or(14)
1672 );
1673 }
1674
1675 if bits == 0x33333333_33333333 {
1676 panic!(
1677 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
1678 at row {} col {} (flat index {}) with params: period={}",
1679 test,
1680 cfg_idx,
1681 val,
1682 bits,
1683 row,
1684 col,
1685 idx,
1686 combo.period.unwrap_or(14)
1687 );
1688 }
1689 }
1690 }
1691
1692 Ok(())
1693 }
1694
    /// Release-build stub for the batch poison scan, which is compiled only under
    /// `debug_assertions`. It keeps the generated `check_batch_no_poison_*` tests
    /// compiling and always succeeds.
    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
1699 macro_rules! gen_batch_tests {
1700 ($fn_name:ident) => {
1701 paste::paste! {
1702 #[test] fn [<$fn_name _scalar>]() { let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch); }
1703 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1704 #[test] fn [<$fn_name _avx2>]() { let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch); }
1705 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1706 #[test] fn [<$fn_name _avx512>]() { let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch); }
1707 #[test] fn [<$fn_name _auto_detect>]() { let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto); }
1708 }
1709 };
1710 }
1711 gen_batch_tests!(check_batch_default_row);
1712 gen_batch_tests!(check_batch_no_poison);
1713 gen_batch_tests!(check_batch_multiple_periods);
1714 gen_batch_tests!(check_batch_edge_cases);
1715
1716 #[test]
1717 #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1718 fn test_midpoint_into_matches_api() -> Result<(), Box<dyn Error>> {
1719 let len = 512usize;
1720 let mut data = vec![0.0f64; len];
1721
1722 data[0] = f64::NAN;
1723 data[1] = f64::NAN;
1724 data[2] = f64::NAN;
1725 for i in 3..len {
1726 let x = i as f64;
1727
1728 data[i] = 100.0 + 0.05 * x + (0.01 * x).sin() * 2.0;
1729 }
1730
1731 let input = MidpointInput::from_slice(&data, MidpointParams::default());
1732
1733 let baseline = midpoint(&input)?.values;
1734
1735 let mut out = vec![0.0f64; len];
1736 midpoint_into(&input, &mut out)?;
1737
1738 assert_eq!(baseline.len(), out.len());
1739
1740 #[inline]
1741 fn eq_or_both_nan(a: f64, b: f64) -> bool {
1742 (a.is_nan() && b.is_nan()) || (a - b).abs() <= 1e-12
1743 }
1744
1745 for (i, (&a, &b)) in baseline.iter().zip(out.iter()).enumerate() {
1746 assert!(
1747 eq_or_both_nan(a, b),
1748 "mismatch at {}: baseline={} into={}",
1749 i,
1750 a,
1751 b
1752 );
1753 }
1754
1755 Ok(())
1756 }
1757}
1758
1759#[inline(always)]
1760pub fn midpoint_batch_inner_into(
1761 data: &[f64],
1762 sweep: &MidpointBatchRange,
1763 kern: Kernel,
1764 parallel: bool,
1765 out: &mut [f64],
1766) -> Result<Vec<MidpointParams>, MidpointError> {
1767 if data.is_empty() {
1768 return Err(MidpointError::EmptyInputData);
1769 }
1770
1771 let combos = expand_grid_checked(sweep)?;
1772
1773 let first = data
1774 .iter()
1775 .position(|x| !x.is_nan())
1776 .ok_or(MidpointError::AllValuesNaN)?;
1777 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1778 if data.len() - first < max_p {
1779 return Err(MidpointError::NotEnoughValidData {
1780 needed: max_p,
1781 valid: data.len() - first,
1782 });
1783 }
1784
1785 let rows = combos.len();
1786 let cols = data.len();
1787 let _total = rows.checked_mul(cols).ok_or_else(|| {
1788 MidpointError::InvalidInput("rows*cols overflow in midpoint_batch_inner".into())
1789 })?;
1790 let expected = rows.checked_mul(cols).ok_or_else(|| {
1791 MidpointError::InvalidInput("rows*cols overflow in midpoint_batch_inner_into".into())
1792 })?;
1793 if out.len() != expected {
1794 return Err(MidpointError::OutputLengthMismatch {
1795 expected,
1796 got: out.len(),
1797 });
1798 }
1799
1800 let out_mu: &mut [MaybeUninit<f64>] = unsafe {
1801 core::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
1802 };
1803 let warm_prefixes: Vec<usize> = combos
1804 .iter()
1805 .map(|c| first + c.period.unwrap() - 1)
1806 .collect();
1807 init_matrix_prefixes(out_mu, cols, &warm_prefixes);
1808
1809 let do_row = |row: usize, row_mu: &mut [MaybeUninit<f64>]| unsafe {
1810 let period = combos[row].period.unwrap();
1811 let row_out: &mut [f64] =
1812 core::slice::from_raw_parts_mut(row_mu.as_mut_ptr() as *mut f64, row_mu.len());
1813 match kern {
1814 Kernel::Scalar => midpoint_row_scalar(data, first, period, row_out),
1815 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1816 Kernel::Avx2 => midpoint_row_avx2(data, first, period, row_out),
1817 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1818 Kernel::Avx512 => midpoint_row_avx512(data, first, period, row_out),
1819 _ => unreachable!(),
1820 }
1821 };
1822
1823 if parallel {
1824 #[cfg(not(target_arch = "wasm32"))]
1825 {
1826 out_mu
1827 .par_chunks_mut(cols)
1828 .enumerate()
1829 .for_each(|(row, slice)| do_row(row, slice));
1830 }
1831 #[cfg(target_arch = "wasm32")]
1832 {
1833 for (row, slice) in out_mu.chunks_mut(cols).enumerate() {
1834 do_row(row, slice);
1835 }
1836 }
1837 } else {
1838 for (row, slice) in out_mu.chunks_mut(cols).enumerate() {
1839 do_row(row, slice);
1840 }
1841 }
1842
1843 Ok(combos)
1844}
1845
/// Python entry point `midpoint(data, period=None, kernel=None)`.
///
/// Computes the midpoint of a 1-D float64 NumPy array with the GIL released and returns
/// a new array of the same length. Indicator errors are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "midpoint")]
#[pyo3(signature = (data, period=None, kernel=None))]
pub fn midpoint_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period: Option<usize>,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let values = data.as_slice()?;
    let chosen = validate_kernel(kernel, false)?;
    let input = MidpointInput::from_slice(values, MidpointParams { period });

    // Only plain Rust data crosses into the GIL-free closure.
    match py.allow_threads(|| midpoint_with_kernel(&input, chosen)) {
        Ok(output) => Ok(output.values.into_pyarray(py)),
        Err(err) => Err(PyValueError::new_err(err.to_string())),
    }
}
1869
/// Python class `MidpointStream`: a thin wrapper that owns a Rust `MidpointStream` and
/// forwards calls to it.
#[cfg(feature = "python")]
#[pyclass(name = "MidpointStream")]
pub struct MidpointStreamPy {
    // The wrapped Rust streaming state.
    stream: MidpointStream,
}
1875
#[cfg(feature = "python")]
#[pymethods]
impl MidpointStreamPy {
    /// Builds a stream via `MidpointStream::try_new`; construction errors become `ValueError`.
    #[new]
    fn new(period: Option<usize>) -> PyResult<Self> {
        match MidpointStream::try_new(MidpointParams { period }) {
            Ok(stream) => Ok(Self { stream }),
            Err(err) => Err(PyValueError::new_err(err.to_string())),
        }
    }

    /// Feeds one value and returns whatever the underlying stream yields (`None` maps to
    /// Python `None`).
    fn update(&mut self, value: f64) -> Option<f64> {
        let stream = &mut self.stream;
        stream.update(value)
    }
}
1891
/// Python entry point `midpoint_batch(data, period_range, kernel=None)`.
///
/// Expands `period_range = (start, end, step)`, computes every row into a freshly
/// allocated NumPy buffer with the GIL released, and returns a dict with:
/// - `"values"`: a `(rows, len(data))` array;
/// - `"periods"`: a `uint64` array of the periods, in row order.
#[cfg(feature = "python")]
#[pyfunction(name = "midpoint_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn midpoint_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let values = data.as_slice()?;
    let requested = validate_kernel(kernel, true)?;

    let sweep = MidpointBatchRange {
        period: period_range,
    };

    // Size the output from the grid before allocating it.
    let grid = expand_grid_checked(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let (rows, cols) = (grid.len(), values.len());
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;

    // SAFETY: the array is allocated uninitialised and is only handed to Python after
    // `midpoint_batch_inner_into` has returned Ok; on error it is dropped unexposed.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let out_slice = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            let batch_kernel = if let Kernel::Auto = requested {
                detect_best_batch_kernel()
            } else {
                requested
            };
            let row_kernel = match batch_kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                _ => unreachable!(),
            };
            midpoint_batch_inner_into(values, &sweep, row_kernel, true, out_slice)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    let periods: Vec<u64> = combos.iter().map(|p| p.period.unwrap() as u64).collect();
    dict.set_item("periods", periods.into_pyarray(py))?;

    Ok(dict)
}
1949
/// WASM entry point: returns the midpoint series for `data` with the given `period`,
/// using `Kernel::Auto`. Errors are converted to JS strings.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn midpoint_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = MidpointInput::from_slice(
        data,
        MidpointParams {
            period: Some(period),
        },
    );

    let mut values = vec![0.0; data.len()];
    match midpoint_into_slice(&mut values, &input, Kernel::Auto) {
        Ok(_) => Ok(values),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
1964
/// Allocates an uninitialised buffer with room for `len` f64 values and leaks it to the
/// caller. It is meant to be released with `midpoint_free` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn midpoint_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` keeps the allocation alive after this function returns.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1973
/// Releases a buffer previously returned by `midpoint_alloc(len)`.
///
/// Null pointers are ignored. `len` must be the value passed to `midpoint_alloc`, because
/// it is used as the allocation's capacity.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn midpoint_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` came from `Vec::<f64>::with_capacity(len)` in `midpoint_alloc`, so
    // `(ptr, capacity = len)` describes that allocation. The length is 0 because
    // `Vec::from_raw_parts` requires the first `length` elements to be initialised, and
    // the buffer may never have been written. `f64` has no drop glue, so freeing the
    // memory is all that is needed.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
1983
/// WASM raw-pointer entry point: reads `len` values from `in_ptr` and writes their
/// midpoint series (for `period`, using `Kernel::Auto`) to `out_ptr`.
///
/// In-place and otherwise overlapping buffers are supported. Previously only exact
/// pointer equality was detected; a partial overlap created a `&mut [f64]` aliasing the
/// `&[f64]` input (undefined behaviour). Any overlap now goes through a scratch buffer.
///
/// Both pointers must be valid for `len` f64 values. Null pointers and indicator errors
/// are reported as JS strings.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn midpoint_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // Half-open byte ranges [start, start + bytes) overlap iff each starts before the
    // other ends.
    let bytes = len.saturating_mul(std::mem::size_of::<f64>());
    let in_start = in_ptr as usize;
    let out_start = out_ptr as usize;
    let overlaps =
        in_start < out_start.saturating_add(bytes) && out_start < in_start.saturating_add(bytes);

    // SAFETY: the caller guarantees both pointers are valid for `len` f64 values. The
    // mutable output slice is only created directly over `out_ptr` when it cannot alias
    // the input slice; otherwise results are staged in `temp` and copied afterwards.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let params = MidpointParams {
            period: Some(period),
        };
        let input = MidpointInput::from_slice(data, params);

        if overlaps {
            let mut temp = vec![0.0; len];
            midpoint_into_slice(&mut temp, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            out.copy_from_slice(&temp);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            midpoint_into_slice(out, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }
    Ok(())
}
2017
/// WASM raw-pointer batch entry point.
///
/// Expands `(period_start, period_end, period_step)` and writes the `rows x len` result
/// matrix to `out_ptr`, which must be valid for `rows * len` f64 values. Returns `rows`.
///
/// If the output region overlaps the `len`-value input at `in_ptr`, results are computed
/// into a scratch buffer first. Previously no aliasing check existed, so overlapping
/// buffers produced a `&mut [f64]` aliasing the `&[f64]` input.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn midpoint_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str(
            "null pointer passed to midpoint_batch_into",
        ));
    }

    let sweep = MidpointBatchRange {
        period: (period_start, period_end, period_step),
    };

    let combos = expand_grid_checked(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let cols = len;
    let expected = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow in midpoint_batch_into"))?;

    // Byte ranges of the input ([in, in + len)) and output ([out, out + rows*len)).
    let elem = std::mem::size_of::<f64>();
    let in_start = in_ptr as usize;
    let out_start = out_ptr as usize;
    let in_end = in_start.saturating_add(len.saturating_mul(elem));
    let out_end = out_start.saturating_add(expected.saturating_mul(elem));
    let overlaps = in_start < out_end && out_start < in_end;

    // SAFETY: the caller guarantees `in_ptr` is valid for `len` values and `out_ptr` for
    // `rows * len` values. A mutable slice over `out_ptr` is only created while the input
    // slice is in use when the two regions are disjoint.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);

        if overlaps {
            let mut temp = vec![0.0; expected];
            midpoint_batch_inner_into(data, &sweep, detect_best_kernel(), false, &mut temp)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            let out = std::slice::from_raw_parts_mut(out_ptr, expected);
            out.copy_from_slice(&temp);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, expected);
            midpoint_batch_inner_into(data, &sweep, detect_best_kernel(), false, out)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(rows)
}
2056
/// Configuration object accepted by the JS `midpoint_batch` export, deserialised with
/// `serde_wasm_bindgen`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct MidpointBatchConfig {
    // Copied verbatim into `MidpointBatchRange::period`; presumably (start, end, step).
    pub period_range: (usize, usize, usize),
}
2062
/// Result object returned to JS by `midpoint_batch`, built from the fields of the batch
/// output and serialised with `serde_wasm_bindgen`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct MidpointBatchJsOutput {
    // Flattened result matrix; presumably row-major `rows x cols` — confirm against
    // `midpoint_batch_with_kernel`.
    pub values: Vec<f64>,
    // Parameter combination for each row.
    pub combos: Vec<MidpointParams>,
    // Number of rows (parameter combinations).
    pub rows: usize,
    // Number of columns (input length).
    pub cols: usize,
}
2071
/// JS export `midpoint_batch(data, config)`.
///
/// `config` must deserialise into `MidpointBatchConfig`. Runs the batch with
/// `Kernel::Auto` and returns a serialised `MidpointBatchJsOutput`. Errors are returned
/// as JS strings.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = midpoint_batch)]
pub fn midpoint_batch_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let MidpointBatchConfig { period_range } = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let range = MidpointBatchRange {
        period: period_range,
    };
    let batch = midpoint_batch_with_kernel(data, &range, Kernel::Auto)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let js_output = MidpointBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };
    serde_wasm_bindgen::to_value(&js_output).map_err(|e| JsValue::from_str(&e.to_string()))
}