1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::PyDict;
9
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use serde::{Deserialize, Serialize};
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use serde_wasm_bindgen;
14#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
15use wasm_bindgen::prelude::*;
16
17use crate::utilities::data_loader::{source_type, Candles};
18use crate::utilities::enums::Kernel;
19use crate::utilities::helpers::{
20 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
21 make_uninit_matrix,
22};
23#[cfg(feature = "python")]
24use crate::utilities::kernel_validation::validate_kernel;
25#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
26use core::arch::x86_64::*;
27#[cfg(not(target_arch = "wasm32"))]
28use rayon::prelude::*;
29use std::convert::AsRef;
30use std::error::Error;
31use std::mem::MaybeUninit;
32use thiserror::Error;
33
34#[cfg(all(feature = "python", feature = "cuda"))]
35use crate::cuda::cuda_available;
36#[cfg(all(feature = "python", feature = "cuda"))]
37use crate::cuda::moving_averages::DeviceArrayF32;
38#[cfg(all(feature = "python", feature = "cuda"))]
39use crate::cuda::oscillators::mfi_wrapper::CudaMfi;
40#[cfg(all(feature = "python", feature = "cuda"))]
41use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
42#[cfg(all(feature = "python", feature = "cuda"))]
43use cust::context::Context;
44#[cfg(all(feature = "python", feature = "cuda"))]
45use std::sync::Arc;
46
/// Input data accepted by the MFI indicator.
#[derive(Debug, Clone)]
pub enum MfiData<'a> {
    /// Price series resolved via `source_type(candles, source)`; volume is taken from
    /// `candles.volume`.
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// Caller-supplied typical-price and volume series; the two slices must have the same
    /// length (mismatches are rejected during validation).
    Slices {
        typical_price: &'a [f64],
        volume: &'a [f64],
    },
}

/// Result of a single-series MFI computation.
#[derive(Debug, Clone)]
pub struct MfiOutput {
    /// One value per input bar; the warmup prefix (before the first computable bar) is the
    /// prefix produced by `alloc_with_nan_prefix`.
    pub values: Vec<f64>,
}
63
/// Tunable parameters for MFI.
///
/// `period` is the look-back window length in bars; `None` means "use the default of 14".
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct MfiParams {
    pub period: Option<usize>,
}

impl Default for MfiParams {
    /// The conventional 14-bar window.
    fn default() -> Self {
        let period = Some(14);
        MfiParams { period }
    }
}
78
79#[derive(Debug, Clone)]
80pub struct MfiInput<'a> {
81 pub data: MfiData<'a>,
82 pub params: MfiParams,
83}
84
85impl<'a> MfiInput<'a> {
86 #[inline]
87 pub fn from_candles(candles: &'a Candles, source: &'a str, params: MfiParams) -> Self {
88 Self {
89 data: MfiData::Candles { candles, source },
90 params,
91 }
92 }
93
94 #[inline]
95 pub fn from_slices(typical_price: &'a [f64], volume: &'a [f64], params: MfiParams) -> Self {
96 Self {
97 data: MfiData::Slices {
98 typical_price,
99 volume,
100 },
101 params,
102 }
103 }
104
105 #[inline]
106 pub fn with_default_candles(candles: &'a Candles) -> Self {
107 Self::from_candles(candles, "hlc3", MfiParams::default())
108 }
109
110 #[inline]
111 pub fn get_period(&self) -> usize {
112 self.params.period.unwrap_or(14)
113 }
114}
115
/// Errors produced by the MFI single-series, batch, and streaming APIs.
#[derive(Debug, Error)]
pub enum MfiError {
    /// The price series is empty, or the volume series has a different length.
    #[error("mfi: Empty data provided.")]
    EmptyInputData,
    /// `period` is zero or larger than the input length (`data_len` is 0 for streams).
    #[error("mfi: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },
    /// Fewer than `needed` bars remain from the first bar where both series are non-NaN.
    #[error("mfi: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// No bar has both a non-NaN typical price and a non-NaN volume.
    #[error("mfi: All values are NaN.")]
    AllValuesNaN,
    /// Not raised in the code visible here; presumably for caller-supplied output buffers
    /// whose length does not match — confirm.
    #[error("mfi: Output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// A batch period sweep expanded to no periods.
    #[error("mfi: Invalid range: start={start} end={end} step={step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },
    /// A non-batch kernel (other than `Auto`) was passed to the batch entry point.
    #[error("mfi: Invalid kernel for batch path: {0:?}")]
    InvalidKernelForBatch(Kernel),
}
137
138#[inline]
139pub fn mfi(input: &MfiInput) -> Result<MfiOutput, MfiError> {
140 mfi_with_kernel(input, Kernel::Auto)
141}
142
143#[inline(always)]
144fn mfi_prepare<'a>(
145 input: &'a MfiInput<'a>,
146 kernel: Kernel,
147) -> Result<(&'a [f64], &'a [f64], usize, usize, Kernel), MfiError> {
148 let (typical_price, volume): (&[f64], &[f64]) = match &input.data {
149 MfiData::Candles { candles, source } => {
150 (source_type(candles, source), candles.volume.as_slice())
151 }
152 MfiData::Slices {
153 typical_price,
154 volume,
155 } => (*typical_price, *volume),
156 };
157
158 let length = typical_price.len();
159 if length == 0 || volume.len() != length {
160 return Err(MfiError::EmptyInputData);
161 }
162
163 let period = input.get_period();
164 let first_valid_idx = (0..length).find(|&i| !typical_price[i].is_nan() && !volume[i].is_nan());
165 let first_valid_idx = match first_valid_idx {
166 Some(idx) => idx,
167 None => return Err(MfiError::AllValuesNaN),
168 };
169
170 if period == 0 || period > length {
171 return Err(MfiError::InvalidPeriod {
172 period,
173 data_len: length,
174 });
175 }
176 if (length - first_valid_idx) < period {
177 return Err(MfiError::NotEnoughValidData {
178 needed: period,
179 valid: length - first_valid_idx,
180 });
181 }
182
183 let chosen = match kernel {
184 Kernel::Auto => detect_best_kernel(),
185 other => other,
186 };
187
188 Ok((typical_price, volume, period, first_valid_idx, chosen))
189}
190
191#[inline(always)]
192fn mfi_compute_into(
193 typical_price: &[f64],
194 volume: &[f64],
195 period: usize,
196 first: usize,
197 kernel: Kernel,
198 out: &mut [f64],
199) {
200 unsafe {
201 match kernel {
202 Kernel::Scalar | Kernel::ScalarBatch => {
203 mfi_scalar(typical_price, volume, period, first, out)
204 }
205 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
206 Kernel::Avx2 | Kernel::Avx2Batch => mfi_avx2(typical_price, volume, period, first, out),
207 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
208 Kernel::Avx512 | Kernel::Avx512Batch => {
209 mfi_avx512(typical_price, volume, period, first, out)
210 }
211 _ => unreachable!(),
212 }
213 }
214}
215
216pub fn mfi_with_kernel(input: &MfiInput, kernel: Kernel) -> Result<MfiOutput, MfiError> {
217 let (typical_price, volume, period, first_valid_idx, chosen) = mfi_prepare(input, kernel)?;
218
219 let warmup_period = first_valid_idx + period - 1;
220 let mut out = alloc_with_nan_prefix(typical_price.len(), warmup_period);
221
222 mfi_compute_into(
223 typical_price,
224 volume,
225 period,
226 first_valid_idx,
227 chosen,
228 &mut out,
229 );
230
231 Ok(MfiOutput { values: out })
232}
233
/// Scalar Money Flow Index kernel (rolling ring-buffer implementation).
///
/// For every bar `i > first`, the raw money flow is `typical_price[i] * volume[i]`.
/// It counts as positive flow when the typical price rose versus the previous bar and as
/// negative flow when it fell; unchanged bars contribute to neither. Results are written
/// to `out[first + period - 1 .. len]`:
///
/// * the first value (index `first + period - 1`) uses the `period - 1` flows after
///   `first`;
/// * every later value uses the trailing `period` flows.
///
/// Each value is `100 * pos / (pos + neg)`, or `0.0` when `pos + neg < 1e-14`. Entries
/// before `first + period - 1` are never written, so callers pre-fill the warmup region.
///
/// Nothing is written when `typical_price` is empty, `period == 0`, or fewer than
/// `period` bars exist from `first` onward.
///
/// # Panics
/// If `volume` or `out` is shorter than `typical_price`.
///
/// # Safety
/// All raw-pointer accesses are validated by the checks at the top of the function. It
/// stays `unsafe` only for signature compatibility with the SIMD kernels.
#[inline]
pub unsafe fn mfi_scalar(
    typical_price: &[f64],
    volume: &[f64],
    period: usize,
    first: usize,
    out: &mut [f64],
) {
    let len = typical_price.len();
    // Out-of-range geometry would otherwise make the seed loop read past the inputs (and
    // `period == 0` would index an empty ring), so treat it as "nothing to compute".
    if len == 0 || period == 0 || first >= len || len - first < period {
        return;
    }
    assert!(
        volume.len() >= len && out.len() >= len,
        "mfi_scalar: volume/out shorter than typical_price"
    );

    // Branch-free split of `flow` by the sign of `diff`; a NaN `diff` lands in neither side.
    let split = |flow: f64, diff: f64| {
        (
            flow * ((diff > 0.0) as i32 as f64),
            flow * ((diff < 0.0) as i32 as f64),
        )
    };
    let ratio = |pos: f64, neg: f64| {
        let total = pos + neg;
        if total < 1e-14 {
            0.0
        } else {
            100.0 * (pos / total)
        }
    };

    let tp_ptr = typical_price.as_ptr();
    let vol_ptr = volume.as_ptr();
    let out_ptr = out.as_mut_ptr();

    // One allocation holds both rings: positive flows in [0, period), negative in
    // [period, 2 * period).
    let mut ring_buf = vec![0.0f64; period * 2];
    let pos_ptr = ring_buf.as_mut_ptr();
    let neg_ptr = pos_ptr.add(period);

    let mut pos_sum = 0.0f64;
    let mut neg_sum = 0.0f64;
    let mut prev = *tp_ptr.add(first);
    let mut ring = 0usize;

    // SAFETY (both loops): every `i` is < len <= volume.len(), out.len() (checked above),
    // and `ring` stays < period, so all pointer offsets remain inside their allocations.

    // Seed phase: fill period - 1 ring slots. Slot period - 1 stays 0.0 until the first
    // rolling step overwrites it.
    let seed_end = first + period;
    for i in first + 1..seed_end {
        let tp_i = *tp_ptr.add(i);
        let (pos_new, neg_new) = split(tp_i * *vol_ptr.add(i), tp_i - prev);
        prev = tp_i;

        *pos_ptr.add(ring) = pos_new;
        *neg_ptr.add(ring) = neg_new;
        pos_sum += pos_new;
        neg_sum += neg_new;
        ring += 1;
    }
    *out_ptr.add(seed_end - 1) = ratio(pos_sum, neg_sum);

    // Rolling phase: evict the oldest flow, insert the newest, emit.
    for i in seed_end..len {
        pos_sum -= *pos_ptr.add(ring);
        neg_sum -= *neg_ptr.add(ring);

        let tp_i = *tp_ptr.add(i);
        let (pos_new, neg_new) = split(tp_i * *vol_ptr.add(i), tp_i - prev);
        prev = tp_i;

        *pos_ptr.add(ring) = pos_new;
        *neg_ptr.add(ring) = neg_new;
        pos_sum += pos_new;
        neg_sum += neg_new;

        *out_ptr.add(i) = ratio(pos_sum, neg_sum);

        ring += 1;
        if ring == period {
            ring = 0;
        }
    }
}
337
/// AVX2 single-series entry point; currently a thin delegate to [`mfi_scalar`].
///
/// # Safety
/// Same contract as [`mfi_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn mfi_avx2(
    typical_price: &[f64],
    volume: &[f64],
    period: usize,
    first: usize,
    out: &mut [f64],
) {
    mfi_scalar(typical_price, volume, period, first, out)
}

/// AVX-512 single-series entry point.
///
/// Periods up to 32 go to the "short" variant and longer ones to the "long" variant. Both
/// currently delegate to [`mfi_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn mfi_avx512(
    typical_price: &[f64],
    volume: &[f64],
    period: usize,
    first: usize,
    out: &mut [f64],
) {
    // SAFETY: both variants forward to `mfi_scalar`; callers supply validated inputs.
    unsafe {
        match period {
            0..=32 => mfi_avx512_short(typical_price, volume, period, first, out),
            _ => mfi_avx512_long(typical_price, volume, period, first, out),
        }
    }
}

/// Short-period AVX-512 variant; currently forwards to [`mfi_scalar`].
///
/// # Safety
/// Same contract as [`mfi_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn mfi_avx512_short(
    typical_price: &[f64],
    volume: &[f64],
    period: usize,
    first: usize,
    out: &mut [f64],
) {
    mfi_scalar(typical_price, volume, period, first, out)
}

/// Long-period AVX-512 variant; currently forwards to [`mfi_scalar`].
///
/// # Safety
/// Same contract as [`mfi_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn mfi_avx512_long(
    typical_price: &[f64],
    volume: &[f64],
    period: usize,
    first: usize,
    out: &mut [f64],
) {
    mfi_scalar(typical_price, volume, period, first, out)
}
391
392#[derive(Copy, Clone, Debug)]
393pub struct MfiBuilder {
394 period: Option<usize>,
395 kernel: Kernel,
396}
397
398impl Default for MfiBuilder {
399 fn default() -> Self {
400 Self {
401 period: None,
402 kernel: Kernel::Auto,
403 }
404 }
405}
406
407impl MfiBuilder {
408 #[inline(always)]
409 pub fn new() -> Self {
410 Self::default()
411 }
412 #[inline(always)]
413 pub fn period(mut self, n: usize) -> Self {
414 self.period = Some(n);
415 self
416 }
417 #[inline(always)]
418 pub fn kernel(mut self, k: Kernel) -> Self {
419 self.kernel = k;
420 self
421 }
422
423 #[inline(always)]
424 pub fn apply(self, c: &Candles) -> Result<MfiOutput, MfiError> {
425 let p = MfiParams {
426 period: self.period,
427 };
428 let i = MfiInput::from_candles(c, "hlc3", p);
429 mfi_with_kernel(&i, self.kernel)
430 }
431
432 #[inline(always)]
433 pub fn apply_slices(
434 self,
435 typical_price: &[f64],
436 volume: &[f64],
437 ) -> Result<MfiOutput, MfiError> {
438 let p = MfiParams {
439 period: self.period,
440 };
441 let i = MfiInput::from_slices(typical_price, volume, p);
442 mfi_with_kernel(&i, self.kernel)
443 }
444
445 #[inline(always)]
446 pub fn into_stream(self) -> Result<MfiStream, MfiError> {
447 let p = MfiParams {
448 period: self.period,
449 };
450 MfiStream::try_new(p)
451 }
452}
453
454#[derive(Debug, Clone)]
455pub struct MfiStream {
456 period: usize,
457 pos_buf: Vec<f64>,
458 neg_buf: Vec<f64>,
459 head: usize,
460 filled: bool,
461 pos_sum: f64,
462 neg_sum: f64,
463 prev_typical: f64,
464 index: usize,
465}
466
467impl MfiStream {
468 pub fn try_new(params: MfiParams) -> Result<Self, MfiError> {
469 let period = params.period.unwrap_or(14);
470 if period == 0 {
471 return Err(MfiError::InvalidPeriod {
472 period,
473 data_len: 0,
474 });
475 }
476 Ok(Self {
477 period,
478 pos_buf: vec![0.0; period],
479 neg_buf: vec![0.0; period],
480 head: 0,
481 filled: false,
482 pos_sum: 0.0,
483 neg_sum: 0.0,
484 prev_typical: f64::NAN,
485 index: 0,
486 })
487 }
488
489 #[inline(always)]
490 pub fn update(&mut self, typical_price: f64, volume: f64) -> Option<f64> {
491 if self.index == 0 {
492 self.prev_typical = typical_price;
493 self.index = 1;
494 return None;
495 }
496
497 let diff = typical_price - self.prev_typical;
498 self.prev_typical = typical_price;
499
500 let flow = typical_price.mul_add(volume, 0.0);
501
502 let gt = (diff > 0.0) as i32 as f64;
503 let lt = (diff < 0.0) as i32 as f64;
504 let pos_new = flow * gt;
505 let neg_new = flow * lt;
506
507 unsafe {
508 let old_pos = *self.pos_buf.get_unchecked(self.head);
509 let old_neg = *self.neg_buf.get_unchecked(self.head);
510
511 self.pos_sum += pos_new - old_pos;
512 self.neg_sum += neg_new - old_neg;
513
514 *self.pos_buf.get_unchecked_mut(self.head) = pos_new;
515 *self.neg_buf.get_unchecked_mut(self.head) = neg_new;
516 }
517
518 self.head += 1;
519 if self.head == self.period {
520 self.head = 0;
521 self.filled = true;
522 }
523 self.index += 1;
524
525 if !self.filled {
526 return None;
527 }
528
529 let total = self.pos_sum + self.neg_sum;
530 if total <= 1e-14 {
531 Some(0.0)
532 } else {
533 Some(100.0 * self.pos_sum * total.recip())
534 }
535 }
536}
537
538#[derive(Clone, Debug)]
539pub struct MfiBatchRange {
540 pub period: (usize, usize, usize),
541}
542
543impl Default for MfiBatchRange {
544 fn default() -> Self {
545 Self {
546 period: (14, 263, 1),
547 }
548 }
549}
550
551#[derive(Clone, Debug, Default)]
552pub struct MfiBatchBuilder {
553 range: MfiBatchRange,
554 kernel: Kernel,
555}
556
557impl MfiBatchBuilder {
558 pub fn new() -> Self {
559 Self::default()
560 }
561 pub fn kernel(mut self, k: Kernel) -> Self {
562 self.kernel = k;
563 self
564 }
565
566 #[inline]
567 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
568 self.range.period = (start, end, step);
569 self
570 }
571 #[inline]
572 pub fn period_static(mut self, p: usize) -> Self {
573 self.range.period = (p, p, 0);
574 self
575 }
576
577 pub fn apply_slices(
578 self,
579 typical_price: &[f64],
580 volume: &[f64],
581 ) -> Result<MfiBatchOutput, MfiError> {
582 mfi_batch_with_kernel(typical_price, volume, &self.range, self.kernel)
583 }
584
585 pub fn apply_candles(self, c: &Candles) -> Result<MfiBatchOutput, MfiError> {
586 let typical_price = source_type(c, "hlc3");
587 self.apply_slices(typical_price, &c.volume)
588 }
589
590 pub fn with_default_candles(c: &Candles, k: Kernel) -> Result<MfiBatchOutput, MfiError> {
591 MfiBatchBuilder::new().kernel(k).apply_candles(c)
592 }
593}
594
595#[derive(Clone, Debug)]
596pub struct MfiBatchOutput {
597 pub values: Vec<f64>,
598 pub combos: Vec<MfiParams>,
599 pub rows: usize,
600 pub cols: usize,
601}
602impl MfiBatchOutput {
603 pub fn row_for_params(&self, p: &MfiParams) -> Option<usize> {
604 self.combos
605 .iter()
606 .position(|c| c.period.unwrap_or(14) == p.period.unwrap_or(14))
607 }
608 pub fn values_for(&self, p: &MfiParams) -> Option<&[f64]> {
609 self.row_for_params(p).map(|row| {
610 let start = row * self.cols;
611 &self.values[start..start + self.cols]
612 })
613 }
614}
615
616#[inline(always)]
617fn expand_grid(r: &MfiBatchRange) -> Vec<MfiParams> {
618 fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, MfiError> {
619 if step == 0 || start == end {
620 return Ok(vec![start]);
621 }
622 if start < end {
623 return Ok((start..=end).step_by(step.max(1)).collect());
624 }
625 let mut v = Vec::new();
626 let mut x = start as isize;
627 let end_i = end as isize;
628 let st = (step as isize).max(1);
629 while x >= end_i {
630 v.push(x as usize);
631 x -= st;
632 }
633 if v.is_empty() {
634 return Err(MfiError::InvalidRange { start, end, step });
635 }
636 Ok(v)
637 }
638 let periods = match axis_usize(r.period) {
639 Ok(v) => v,
640 Err(_) => return Vec::new(),
641 };
642 let mut out = Vec::with_capacity(periods.len());
643 for &p in &periods {
644 out.push(MfiParams { period: Some(p) });
645 }
646 out
647}
648
649pub fn mfi_batch_with_kernel(
650 typical_price: &[f64],
651 volume: &[f64],
652 sweep: &MfiBatchRange,
653 k: Kernel,
654) -> Result<MfiBatchOutput, MfiError> {
655 let kernel = match k {
656 Kernel::Auto => detect_best_batch_kernel(),
657 other if other.is_batch() => other,
658 _ => return Err(MfiError::InvalidKernelForBatch(k)),
659 };
660 let simd = match kernel {
661 Kernel::Avx512Batch => Kernel::Avx512,
662 Kernel::Avx2Batch => Kernel::Avx2,
663 Kernel::ScalarBatch => Kernel::Scalar,
664 _ => unreachable!(),
665 };
666 mfi_batch_par_slice(typical_price, volume, sweep, simd)
667}
668
669#[inline(always)]
670pub fn mfi_batch_slice(
671 typical_price: &[f64],
672 volume: &[f64],
673 sweep: &MfiBatchRange,
674 kern: Kernel,
675) -> Result<MfiBatchOutput, MfiError> {
676 mfi_batch_inner(typical_price, volume, sweep, kern, false)
677}
678
679#[inline(always)]
680pub fn mfi_batch_par_slice(
681 typical_price: &[f64],
682 volume: &[f64],
683 sweep: &MfiBatchRange,
684 kern: Kernel,
685) -> Result<MfiBatchOutput, MfiError> {
686 mfi_batch_inner(typical_price, volume, sweep, kern, true)
687}
688
/// Round `x` up to the next multiple of 8 (`0` stays `0`).
fn round_up8(x: usize) -> usize {
    (x + 7) / 8 * 8
}
692
693#[inline(always)]
694fn mfi_batch_inner(
695 typical_price: &[f64],
696 volume: &[f64],
697 sweep: &MfiBatchRange,
698 kern: Kernel,
699 parallel: bool,
700) -> Result<MfiBatchOutput, MfiError> {
701 let combos = expand_grid(sweep);
702 if combos.is_empty() {
703 return Err(MfiError::InvalidRange {
704 start: sweep.period.0,
705 end: sweep.period.1,
706 step: sweep.period.2,
707 });
708 }
709 let length = typical_price.len();
710 let first = (0..length)
711 .find(|&i| !typical_price[i].is_nan() && !volume[i].is_nan())
712 .ok_or(MfiError::AllValuesNaN)?;
713
714 let max_p = combos
715 .iter()
716 .map(|c| round_up8(c.period.unwrap()))
717 .max()
718 .unwrap();
719 if length - first < max_p {
720 return Err(MfiError::NotEnoughValidData {
721 needed: max_p,
722 valid: length - first,
723 });
724 }
725
726 let rows = combos.len();
727 let cols = length;
728
729 if volume.len() != cols {
730 return Err(MfiError::EmptyInputData);
731 }
732
733 let mut buf_mu = make_uninit_matrix(rows, cols);
734 let warmup_periods: Vec<usize> = combos
735 .iter()
736 .map(|c| first + c.period.unwrap() - 1)
737 .collect();
738 init_matrix_prefixes(&mut buf_mu, cols, &warmup_periods);
739
740 let mut buf_guard = core::mem::ManuallyDrop::new(buf_mu);
741 let out: &mut [f64] = unsafe {
742 core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
743 };
744
745 let rows = combos.len();
746 let use_prefix = rows >= 8;
747
748 let (pos_prefix, neg_prefix) = if use_prefix {
749 let (pp, np) =
750 unsafe { precompute_flow_prefixes_select(typical_price, volume, first, kern) };
751 (Some(pp), Some(np))
752 } else {
753 (None, None)
754 };
755
756 let do_row = |row: usize, out_row: &mut [f64]| unsafe {
757 let period = combos[row].period.unwrap();
758 if let (Some(ref pp), Some(ref np)) = (pos_prefix.as_ref(), neg_prefix.as_ref()) {
759 mfi_row_from_prefixes(pp, np, first, period, out_row)
760 } else {
761 match kern {
762 Kernel::Scalar => mfi_row_scalar(typical_price, volume, first, period, out_row),
763 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
764 Kernel::Avx2 => mfi_row_avx2(typical_price, volume, first, period, out_row),
765 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
766 Kernel::Avx512 => mfi_row_avx512(typical_price, volume, first, period, out_row),
767 _ => unreachable!(),
768 }
769 }
770 };
771
772 if parallel {
773 #[cfg(not(target_arch = "wasm32"))]
774 {
775 use rayon::prelude::*;
776 out.par_chunks_mut(cols)
777 .enumerate()
778 .for_each(|(row, slice)| do_row(row, slice));
779 }
780
781 #[cfg(target_arch = "wasm32")]
782 {
783 for (row, slice) in out.chunks_mut(cols).enumerate() {
784 do_row(row, slice);
785 }
786 }
787 } else {
788 for (row, slice) in out.chunks_mut(cols).enumerate() {
789 do_row(row, slice);
790 }
791 }
792
793 let values = unsafe {
794 Vec::from_raw_parts(
795 buf_guard.as_mut_ptr() as *mut f64,
796 buf_guard.len(),
797 buf_guard.capacity(),
798 )
799 };
800
801 Ok(MfiBatchOutput {
802 values,
803 combos,
804 rows,
805 cols,
806 })
807}
808
809#[inline(always)]
810fn mfi_batch_inner_into(
811 typical_price: &[f64],
812 volume: &[f64],
813 sweep: &MfiBatchRange,
814 kern: Kernel,
815 parallel: bool,
816 out: &mut [f64],
817) -> Result<Vec<MfiParams>, MfiError> {
818 let combos = expand_grid(sweep);
819 if combos.is_empty() {
820 return Err(MfiError::InvalidRange {
821 start: sweep.period.0,
822 end: sweep.period.1,
823 step: sweep.period.2,
824 });
825 }
826
827 let length = typical_price.len();
828 let first = (0..length)
829 .find(|&i| !typical_price[i].is_nan() && !volume[i].is_nan())
830 .ok_or(MfiError::AllValuesNaN)?;
831
832 let max_p = combos
833 .iter()
834 .map(|c| round_up8(c.period.unwrap()))
835 .max()
836 .unwrap();
837 if length - first < max_p {
838 return Err(MfiError::NotEnoughValidData {
839 needed: max_p,
840 valid: length - first,
841 });
842 }
843
844 let cols = length;
845
846 if volume.len() != cols {
847 return Err(MfiError::EmptyInputData);
848 }
849
850 let rows = combos.len();
851 let use_prefix = rows >= 8;
852 let (pos_prefix, neg_prefix) = if use_prefix {
853 let (pp, np) =
854 unsafe { precompute_flow_prefixes_select(typical_price, volume, first, kern) };
855 (Some(pp), Some(np))
856 } else {
857 (None, None)
858 };
859
860 let do_row = |row: usize, out_row: &mut [f64]| unsafe {
861 let period = combos[row].period.unwrap();
862
863 let warmup_end = first + period - 1;
864 for v in &mut out_row[..warmup_end] {
865 *v = f64::NAN;
866 }
867 if let (Some(ref pp), Some(ref np)) = (pos_prefix.as_ref(), neg_prefix.as_ref()) {
868 mfi_row_from_prefixes(pp, np, first, period, out_row)
869 } else {
870 match kern {
871 Kernel::Scalar => mfi_row_scalar(typical_price, volume, first, period, out_row),
872 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
873 Kernel::Avx2 => mfi_row_avx2(typical_price, volume, first, period, out_row),
874 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
875 Kernel::Avx512 => mfi_row_avx512(typical_price, volume, first, period, out_row),
876 _ => unreachable!(),
877 }
878 }
879 };
880
881 if parallel {
882 #[cfg(not(target_arch = "wasm32"))]
883 {
884 use rayon::prelude::*;
885 out.par_chunks_mut(cols)
886 .enumerate()
887 .for_each(|(row, slice)| do_row(row, slice));
888 }
889
890 #[cfg(target_arch = "wasm32")]
891 {
892 for (row, slice) in out.chunks_mut(cols).enumerate() {
893 do_row(row, slice);
894 }
895 }
896 } else {
897 for (row, slice) in out.chunks_mut(cols).enumerate() {
898 do_row(row, slice);
899 }
900 }
901
902 Ok(combos)
903}
904
/// Build cumulative positive / negative money-flow sums for the batch prefix path.
///
/// `pos[i]` (resp. `neg[i]`) is the sum of the positive (resp. negative) flows of bars
/// `first + 1 ..= i`. Bar `j`'s flow is `typical_price[j] * volume[j]`, and its sign
/// class comes from `typical_price[j] - typical_price[j - 1]`. Entries at or before
/// `first` are `0.0`, so the flow over bars `(a, b]` with `a >= first` is
/// `prefix[b] - prefix[a]`.
///
/// Returns all-zero vectors when `first >= typical_price.len()` (including empty input).
///
/// # Panics
/// If `volume` is shorter than `typical_price` (and there is anything to sum).
///
/// # Safety
/// No extra requirements; every access is bounds-checked. Kept `unsafe` for signature
/// compatibility with the SIMD variant.
#[inline(always)]
unsafe fn precompute_flow_prefixes_scalar(
    typical_price: &[f64],
    volume: &[f64],
    first: usize,
) -> (Vec<f64>, Vec<f64>) {
    let len = typical_price.len();
    let mut pos_prefix = vec![0.0f64; len];
    let mut neg_prefix = vec![0.0f64; len];
    if first >= len {
        return (pos_prefix, neg_prefix);
    }
    // Hoisted so the loop below indexes without per-iteration bounds surprises.
    let volume = &volume[..len];

    let mut pos_sum = 0.0f64;
    let mut neg_sum = 0.0f64;
    let mut prev = typical_price[first];
    for i in first + 1..len {
        let tp_i = typical_price[i];
        let flow = tp_i * volume[i];
        let diff = tp_i - prev;
        prev = tp_i;

        pos_sum += flow * ((diff > 0.0) as i32 as f64);
        neg_sum += flow * ((diff < 0.0) as i32 as f64);
        pos_prefix[i] = pos_sum;
        neg_prefix[i] = neg_sum;
    }

    (pos_prefix, neg_prefix)
}
947
948#[inline(always)]
949unsafe fn precompute_flow_prefixes_select(
950 typical_price: &[f64],
951 volume: &[f64],
952 first: usize,
953 kern: Kernel,
954) -> (Vec<f64>, Vec<f64>) {
955 match kern {
956 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
957 Kernel::Avx2 | Kernel::Avx512 => {
958 precompute_flow_prefixes_avx2(typical_price, volume, first)
959 }
960 _ => precompute_flow_prefixes_scalar(typical_price, volume, first),
961 }
962}
963
/// AVX2-assisted builder of cumulative positive / negative money-flow sums.
///
/// `pos_prefix[i]` / `neg_prefix[i]` hold the running sums of positive / negative flows
/// over bars `first + 1 ..= i`; entries at or before `first` are `0.0`. Flows and sign
/// masks are computed four bars per iteration. The running sums are still accumulated
/// lane by lane in index order.
///
/// NOTE(review): classification uses `_mm256_and_pd(flow, mask)`. A bar whose price
/// difference is NaN therefore contributes `0.0` even when its flow is NaN, whereas a
/// scalar `flow * 0.0` yields NaN, so results can differ from the scalar builder on
/// NaN-containing input.
/// NOTE(review): there is no `#[target_feature(enable = "avx")]` here; this presumably
/// relies on the crate being built with AVX enabled — confirm.
///
/// # Safety
/// Reads `volume` through a raw pointer up to `typical_price.len() - 1`, so `volume` must
/// be at least as long as `typical_price`. The CPU must support AVX.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn precompute_flow_prefixes_avx2(
    typical_price: &[f64],
    volume: &[f64],
    first: usize,
) -> (Vec<f64>, Vec<f64>) {
    use core::arch::x86_64::*;
    let len = typical_price.len();
    let mut pos_prefix = vec![0.0f64; len];
    let mut neg_prefix = vec![0.0f64; len];
    if len == 0 {
        return (pos_prefix, neg_prefix);
    }

    // Running totals carried across vector blocks and into the scalar tail.
    let mut pos_sum = 0.0f64;
    let mut neg_sum = 0.0f64;

    // Already zero from `vec!`; restates that the prefix starts at `first`.
    if first < len {
        pos_prefix[first] = 0.0;
        neg_prefix[first] = 0.0;
    }

    let mut i = first + 1;
    let tp_ptr = typical_price.as_ptr();
    let vol_ptr = volume.as_ptr();
    let zero = _mm256_set1_pd(0.0);

    // Vector body: bars i .. i + 4 per iteration.
    while i + 4 <= len {
        let tp_cur = _mm256_loadu_pd(tp_ptr.add(i));
        // Same lanes shifted back one bar: the "previous" typical price for each lane.
        let tp_prev = _mm256_loadu_pd(tp_ptr.add(i - 1));
        let vol_cur = _mm256_loadu_pd(vol_ptr.add(i));

        let flow = _mm256_mul_pd(tp_cur, vol_cur);

        let diff = _mm256_sub_pd(tp_cur, tp_prev);

        // Ordered comparisons: a NaN difference produces an all-zero mask.
        let m_gt = _mm256_cmp_pd(diff, zero, _CMP_GT_OQ);
        let m_lt = _mm256_cmp_pd(diff, zero, _CMP_LT_OQ);

        // Keep the flow bits where the mask is set, otherwise +0.0.
        let pos_v = _mm256_and_pd(flow, m_gt);
        let neg_v = _mm256_and_pd(flow, m_lt);

        // Spill to scalars so the prefix sums accumulate strictly in index order.
        let mut pos_tmp = [0.0f64; 4];
        let mut neg_tmp = [0.0f64; 4];
        _mm256_storeu_pd(pos_tmp.as_mut_ptr(), pos_v);
        _mm256_storeu_pd(neg_tmp.as_mut_ptr(), neg_v);

        pos_sum += pos_tmp[0];
        neg_sum += neg_tmp[0];
        *pos_prefix.get_unchecked_mut(i) = pos_sum;
        *neg_prefix.get_unchecked_mut(i) = neg_sum;

        pos_sum += pos_tmp[1];
        neg_sum += neg_tmp[1];
        *pos_prefix.get_unchecked_mut(i + 1) = pos_sum;
        *neg_prefix.get_unchecked_mut(i + 1) = neg_sum;

        pos_sum += pos_tmp[2];
        neg_sum += neg_tmp[2];
        *pos_prefix.get_unchecked_mut(i + 2) = pos_sum;
        *neg_prefix.get_unchecked_mut(i + 2) = neg_sum;

        pos_sum += pos_tmp[3];
        neg_sum += neg_tmp[3];
        *pos_prefix.get_unchecked_mut(i + 3) = pos_sum;
        *neg_prefix.get_unchecked_mut(i + 3) = neg_sum;

        i += 4;
    }

    // Scalar tail for the remaining (< 4) bars; uses multiply-by-0/1 like the scalar builder.
    while i < len {
        let tp_i = *tp_ptr.add(i);
        let flow = tp_i * *vol_ptr.add(i);
        let diff = tp_i - *tp_ptr.add(i - 1);
        let gt = (diff > 0.0) as i32 as f64;
        let lt = (diff < 0.0) as i32 as f64;
        pos_sum += flow * gt;
        neg_sum += flow * lt;
        *pos_prefix.get_unchecked_mut(i) = pos_sum;
        *neg_prefix.get_unchecked_mut(i) = neg_sum;
        i += 1;
    }

    (pos_prefix, neg_prefix)
}
1050
/// Fill one batch row from precomputed flow prefix sums.
///
/// Writes `out[first + period - 1 ..]`; earlier entries are left untouched.
///
/// * The first output uses the flows of bars `first + 1 ..= first + period - 1`.
/// * Later outputs use the trailing `period` flows, matching the rolling scalar kernel.
/// * A window whose total flow is below `1e-14` produces `0.0`.
///
/// Does nothing when `period == 0` or the first output index falls outside `out`.
///
/// # Panics
/// If either prefix slice is shorter than `out` (and there is something to write).
///
/// # Safety
/// No extra requirements; every access is bounds-checked. Kept `unsafe` for signature
/// compatibility.
#[inline(always)]
unsafe fn mfi_row_from_prefixes(
    pos_prefix: &[f64],
    neg_prefix: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    let len = out.len();
    // `first + period - 1` would underflow for period 0 at first 0.
    if period == 0 {
        return;
    }
    let idx0 = first + period - 1;
    if idx0 >= len {
        return;
    }
    // Hoisted so the loop indexes slices of a known length.
    let pos = &pos_prefix[..len];
    let neg = &neg_prefix[..len];

    let ratio = |p: f64, n: f64| {
        let total = p + n;
        if total < 1e-14 {
            0.0
        } else {
            100.0 * (p / total)
        }
    };

    out[idx0] = ratio(pos[idx0] - pos[first], neg[idx0] - neg[first]);
    for (i, slot) in out.iter_mut().enumerate().skip(idx0 + 1) {
        let base = i - period;
        *slot = ratio(pos[i] - pos[base], neg[i] - neg[base]);
    }
}
1092
1093#[inline(always)]
1094unsafe fn mfi_row_scalar(
1095 typical_price: &[f64],
1096 volume: &[f64],
1097 first: usize,
1098 period: usize,
1099 out: &mut [f64],
1100) {
1101 mfi_scalar(typical_price, volume, period, first, out)
1102}
1103
1104#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1105#[inline(always)]
1106unsafe fn mfi_row_avx2(
1107 typical_price: &[f64],
1108 volume: &[f64],
1109 first: usize,
1110 period: usize,
1111 out: &mut [f64],
1112) {
1113 mfi_scalar(typical_price, volume, period, first, out)
1114}
1115
1116#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1117#[inline(always)]
1118pub unsafe fn mfi_row_avx512(
1119 typical_price: &[f64],
1120 volume: &[f64],
1121 first: usize,
1122 period: usize,
1123 out: &mut [f64],
1124) {
1125 if period <= 32 {
1126 mfi_row_avx512_short(typical_price, volume, first, period, out)
1127 } else {
1128 mfi_row_avx512_long(typical_price, volume, first, period, out)
1129 }
1130}
1131
1132#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1133#[inline(always)]
1134pub unsafe fn mfi_row_avx512_short(
1135 typical_price: &[f64],
1136 volume: &[f64],
1137 first: usize,
1138 period: usize,
1139 out: &mut [f64],
1140) {
1141 mfi_scalar(typical_price, volume, period, first, out)
1142}
1143
1144#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1145#[inline(always)]
1146pub unsafe fn mfi_row_avx512_long(
1147 typical_price: &[f64],
1148 volume: &[f64],
1149 first: usize,
1150 period: usize,
1151 out: &mut [f64],
1152) {
1153 mfi_scalar(typical_price, volume, period, first, out)
1154}
1155
/// Python binding: `mfi(typical_price, volume, period, kernel=None) -> numpy.ndarray`.
///
/// The computation runs with the GIL released. Errors from `as_slice` (e.g. a
/// non-contiguous array) and from `validate_kernel` are propagated. Any `MfiError` is
/// raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "mfi")]
#[pyo3(signature = (typical_price, volume, period, kernel=None))]
pub fn mfi_py<'py>(
    py: Python<'py>,
    typical_price: PyReadonlyArray1<'py, f64>,
    volume: PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let tp = typical_price.as_slice()?;
    let vol = volume.as_slice()?;
    let kern = validate_kernel(kernel, false)?;

    let input = MfiInput::from_slices(
        tp,
        vol,
        MfiParams {
            period: Some(period),
        },
    );
    let values: Vec<f64> = py
        .allow_threads(|| mfi_with_kernel(&input, kern).map(|o| o.values))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok(values.into_pyarray(py))
}
1181
/// Python-facing wrapper around [`MfiStream`], exposed to Python as `MfiStream`.
#[cfg(feature = "python")]
#[pyclass(name = "MfiStream")]
pub struct MfiStreamPy {
    // The wrapped Rust streaming calculator.
    inner: MfiStream,
}
1187
/// Python handle to an MFI result matrix held in a CUDA device buffer (`DeviceArrayF32`).
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct MfiDeviceArrayF32Py {
    /// Device buffer; becomes `None` once ownership is handed out via `__dlpack__`.
    pub(crate) inner: Option<DeviceArrayF32>,
    /// CUDA context handle; presumably held to keep the context alive while the buffer
    /// exists — confirm.
    pub(crate) ctx: Arc<Context>,
    /// Device ordinal reported by `__dlpack_device__`.
    pub(crate) device_id: i32,
}
1195
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl MfiDeviceArrayF32Py {
    /// CUDA Array Interface description (version 3) of the device buffer.
    ///
    /// Describes a row-major `(rows, cols)` float32 matrix: `typestr` `"<f4"` and byte
    /// strides `(cols * 4, 4)`. `data` is `(device_ptr, false)`; the second element is
    /// presumably the protocol's read-only flag.
    ///
    /// # Errors
    /// `ValueError` if the buffer has already been handed off via `__dlpack__`.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let inner = self
            .inner
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("buffer already exported via __dlpack__"))?;
        let d = PyDict::new(py);
        d.set_item("shape", (inner.rows, inner.cols))?;
        d.set_item("typestr", "<f4")?;
        d.set_item(
            "strides",
            (
                inner.cols * std::mem::size_of::<f32>(),
                std::mem::size_of::<f32>(),
            ),
        )?;
        d.set_item("data", (inner.device_ptr() as usize, false))?;
        d.set_item("version", 3)?;
        Ok(d)
    }

    /// DLPack `(device_type, device_id)` pair. `2` is presumably DLPack's CUDA device type
    /// (`kDLCUDA`) — confirm against the DLPack header.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.device_id)
    }

    /// Export the buffer as a DLPack capsule, moving ownership out of this object.
    ///
    /// * `dl_device`: if it parses as `(type, id)` and differs from this buffer's device,
    ///   a `ValueError` is raised; cross-device copies are not implemented.
    /// * `stream`: accepted but ignored.
    /// * `max_version`: forwarded to `export_f32_cuda_dlpack_2d`.
    ///
    /// NOTE(review): because `stream` is ignored, no producer/consumer stream ordering is
    /// established here — confirm the work that filled the buffer has completed.
    /// NOTE(review): the buffer is taken before export and is not restored if the export
    /// fails; later calls then report it as already exported.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        let (kdl, alloc_dev) = self.__dlpack_device__();
        // Only a `(type, id)` tuple is checked; other `dl_device` values are ignored.
        if let Some(dev_obj) = dl_device.as_ref() {
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(PyValueError::new_err(
                            "device copy not implemented for __dlpack__",
                        ));
                    } else {
                        return Err(PyValueError::new_err("dl_device mismatch for __dlpack__"));
                    }
                }
            }
        }
        let _ = stream;

        // Move the buffer out; afterwards both export paths report "already exported".
        let inner = self
            .inner
            .take()
            .ok_or_else(|| PyValueError::new_err("buffer already exported via __dlpack__"))?;

        let rows = inner.rows;
        let cols = inner.cols;
        let buf = inner.buf;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
1267
#[cfg(feature = "python")]
#[pymethods]
impl MfiStreamPy {
    /// Creates a streaming MFI for the given `period`.
    ///
    /// Raises `ValueError` if `MfiStream::try_new` rejects the parameters.
    #[new]
    pub fn new(period: usize) -> PyResult<Self> {
        MfiStream::try_new(MfiParams {
            period: Some(period),
        })
        .map(|inner| Self { inner })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feeds one `(typical_price, volume)` observation and returns whatever
    /// `MfiStream::update` yields (`None` presumably while still warming up).
    pub fn update(&mut self, typical_price: f64, volume: f64) -> Option<f64> {
        self.inner.update(typical_price, volume)
    }
}
1284
/// Computes MFI for every period in `period_range = (start, end, step)`.
///
/// Returns a dict with:
/// * `"values"`: a `(rows, cols)` float64 array, one row per period combination;
/// * `"periods"`: the period used for each row (uint64).
///
/// Raises `ValueError` on mismatched input lengths, an invalid `kernel`, an output
/// size (`rows * cols`) that overflows `usize`, or any error from the batch routine.
#[cfg(feature = "python")]
#[pyfunction(name = "mfi_batch")]
#[pyo3(signature = (typical_price, volume, period_range, kernel=None))]
pub fn mfi_batch_py<'py>(
    py: Python<'py>,
    typical_price: PyReadonlyArray1<'py, f64>,
    volume: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let tp = typical_price.as_slice()?;
    let vol = volume.as_slice()?;
    if tp.len() != vol.len() {
        return Err(PyValueError::new_err(
            "mfi_batch: typical_price and volume length mismatch",
        ));
    }

    let sweep = MfiBatchRange {
        period: period_range,
    };
    let kern = validate_kernel(kernel, true)?;

    let rows = expand_grid(&sweep).len();
    let cols = tp.len();
    // Guard the flat allocation size (the wasm `mfi_batch_into` does the same).
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("mfi_batch: rows*cols overflow"))?;

    // SAFETY: the array is created uninitialized; `mfi_batch_inner_into` is relied
    // upon to write every element before the array is returned to Python. On error
    // the array is dropped without being exposed.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    // SAFETY: `out_arr` was just created here, so no other view of it exists.
    let out_slice = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            let batch_kernel = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                other => other,
            };
            // Map batch kernels to their per-row counterparts before calling the
            // inner routine.
            let row_kernel = match batch_kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                other => other,
            };
            mfi_batch_inner_into(tp, vol, &sweep, row_kernel, true, out_slice)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    Ok(dict)
}
1348
/// Runs the MFI period sweep on the GPU and returns the result as a device array.
///
/// Inputs are `f32` series of equal length; `period_range` is `(start, end, step)`.
/// Raises `ValueError` if CUDA is unavailable, the lengths differ, or the CUDA
/// wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "mfi_cuda_batch_dev")]
#[pyo3(signature = (typical_price, volume, period_range, device_id=0))]
pub fn mfi_cuda_batch_dev_py(
    py: Python<'_>,
    typical_price: PyReadonlyArray1<'_, f32>,
    volume: PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<MfiDeviceArrayF32Py> {
    /// Converts any displayable CUDA-side error into a Python `ValueError`.
    fn to_py_err<E: ToString>(err: E) -> PyErr {
        PyValueError::new_err(err.to_string())
    }

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let tp = typical_price.as_slice()?;
    let vol = volume.as_slice()?;
    if tp.len() != vol.len() {
        return Err(PyValueError::new_err("mismatched input lengths"));
    }
    let sweep = MfiBatchRange {
        period: period_range,
    };

    let (buffer, context, ordinal) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaMfi::new(device_id).map_err(to_py_err)?;
        let context = cuda.context_arc();
        let ordinal = cuda.device_id() as i32;
        let (buffer, _combos) = cuda.mfi_batch_dev(tp, vol, &sweep).map_err(to_py_err)?;
        Ok((buffer, context, ordinal))
    })?;

    Ok(MfiDeviceArrayF32Py {
        inner: Some(buffer),
        ctx: context,
        device_id: ordinal,
    })
}
1385
/// Computes MFI with a single `period` over many series at once on the GPU.
///
/// `typical_price_tm` and `volume_tm` are flattened `cols * rows` matrices in
/// time-major order (NOTE(review): presumably element `t * cols + series` — confirm
/// against `CudaMfi::mfi_many_series_one_param_time_major_dev`).
///
/// Raises `ValueError` if CUDA is unavailable, the inputs differ in length,
/// `cols * rows` overflows, the inputs are not exactly `cols * rows` long, or the
/// CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "mfi_cuda_many_series_one_param_dev")]
#[pyo3(signature = (typical_price_tm, volume_tm, cols, rows, period, device_id=0))]
pub fn mfi_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    typical_price_tm: PyReadonlyArray1<'_, f32>,
    volume_tm: PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    period: usize,
    device_id: usize,
) -> PyResult<MfiDeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let tp = typical_price_tm.as_slice()?;
    let vol = volume_tm.as_slice()?;
    if tp.len() != vol.len() {
        return Err(PyValueError::new_err("mismatched input lengths"));
    }
    // `cols` and `rows` come straight from Python: an unchecked product could wrap
    // in release builds and let inconsistent dimensions through to the GPU code.
    let expected_len = cols
        .checked_mul(rows)
        .ok_or_else(|| PyValueError::new_err("cols*rows overflow"))?;
    if tp.len() != expected_len {
        return Err(PyValueError::new_err("unexpected matrix size"));
    }
    let (inner, ctx, dev_id) = py.allow_threads(|| {
        let cuda = CudaMfi::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id() as i32;
        let arr = cuda
            .mfi_many_series_one_param_time_major_dev(tp, vol, cols, rows, period)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, PyErr>((arr, ctx, dev_id))
    })?;
    Ok(MfiDeviceArrayF32Py {
        inner: Some(inner),
        ctx,
        device_id: dev_id,
    })
}
1424
1425#[inline]
1426pub fn mfi_into_slice(dst: &mut [f64], input: &MfiInput, kern: Kernel) -> Result<(), MfiError> {
1427 let (typical_price, volume, period, first_valid_idx, chosen) = mfi_prepare(input, kern)?;
1428
1429 if dst.len() != typical_price.len() {
1430 return Err(MfiError::OutputLengthMismatch {
1431 expected: typical_price.len(),
1432 got: dst.len(),
1433 });
1434 }
1435
1436 mfi_compute_into(typical_price, volume, period, first_valid_idx, chosen, dst);
1437
1438 let warmup_period = first_valid_idx + period - 1;
1439
1440 let nan_q = f64::from_bits(0x7ff8_0000_0000_0000);
1441 for v in &mut dst[..warmup_period] {
1442 *v = nan_q;
1443 }
1444
1445 Ok(())
1446}
1447
1448#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1449pub fn mfi_into(input: &MfiInput, out: &mut [f64]) -> Result<(), MfiError> {
1450 mfi_into_slice(out, input, Kernel::Auto)
1451}
1452
/// JS binding: computes MFI over `typical_price`/`volume` with the given `period`.
///
/// Errors from the computation are returned to JS as string `JsValue`s.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn mfi_js(typical_price: &[f64], volume: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = MfiInput::from_slices(
        typical_price,
        volume,
        MfiParams {
            period: Some(period),
        },
    );
    mfi_with_kernel(&input, detect_best_kernel())
        .map(|output| output.values)
        .map_err(|err| JsValue::from_str(&err.to_string()))
}
1466
/// JS binding: computes MFI from raw pointers into WASM linear memory.
///
/// `typical_price_ptr` and `volume_ptr` must each point to `len` readable `f64`s,
/// and `out_ptr` to `len` writable `f64`s. If the output range overlaps either
/// input range at all (not just exact in-place use), the result is computed into a
/// temporary and then copied. This ensures a mutable slice never aliases a live
/// input slice.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn mfi_into(
    typical_price_ptr: *const f64,
    volume_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    /// True when the `len`-element `f64` ranges starting at `a` and `b` share any byte.
    fn overlaps(a: *const f64, b: *const f64, len: usize) -> bool {
        let bytes = len.saturating_mul(std::mem::size_of::<f64>());
        let (a, b) = (a as usize, b as usize);
        a < b.saturating_add(bytes) && b < a.saturating_add(bytes)
    }

    if typical_price_ptr.is_null() || volume_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    let out_const = out_ptr as *const f64;
    let aliased = typical_price_ptr == out_const
        || volume_ptr == out_const
        || overlaps(typical_price_ptr, out_const, len)
        || overlaps(volume_ptr, out_const, len);

    // SAFETY: the JS caller guarantees both input pointers are valid for `len` reads.
    let typical_price = unsafe { std::slice::from_raw_parts(typical_price_ptr, len) };
    let volume = unsafe { std::slice::from_raw_parts(volume_ptr, len) };
    let params = MfiParams {
        period: Some(period),
    };
    let input = MfiInput::from_slices(typical_price, volume, params);

    if aliased {
        let result = mfi_with_kernel(&input, detect_best_kernel())
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // SAFETY: `out_ptr` is valid for `len` writes; the inputs are not read again.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        out.copy_from_slice(&result.values);
    } else {
        // SAFETY: `out_ptr` is valid for `len` writes and overlaps neither input.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        mfi_into_slice(out, &input, detect_best_kernel())
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }
    Ok(())
}
1501
/// JS binding: reserves space for `len` `f64`s and leaks it to the caller.
///
/// The memory is uninitialized. Release it with `mfi_free(ptr, len)` using the
/// same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn mfi_alloc(len: usize) -> *mut f64 {
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1510
/// JS binding: releases a buffer previously returned by `mfi_alloc(len)`.
///
/// `len` must be the same value that was passed to `mfi_alloc`. A null pointer is
/// ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn mfi_free(ptr: *mut f64, len: usize) {
    if !ptr.is_null() {
        // SAFETY: `ptr` came from `Vec::<f64>::with_capacity(len)` in `mfi_alloc`.
        // The length is 0 because the buffer may never have been initialized:
        // `from_raw_parts` requires the first `length` elements to be initialized,
        // and only the capacity matters for deallocation.
        unsafe {
            drop(Vec::from_raw_parts(ptr, 0, len));
        }
    }
}
1520
/// Configuration object accepted by the JS `mfi_batch` binding.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct MfiBatchConfig {
    /// Period sweep given as `(start, end, step)`.
    pub period_range: (usize, usize, usize),
}
1526
/// Result object returned to JS by the `mfi_batch` binding.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct MfiBatchJsOutput {
    /// Flat `rows * cols` matrix of MFI values, one row per entry of `combos`.
    pub values: Vec<f64>,
    /// Parameter set used for each row.
    pub combos: Vec<MfiParams>,
    /// Number of parameter combinations.
    pub rows: usize,
    /// Number of samples per row (input series length).
    pub cols: usize,
}
1535
/// JS binding `mfi_batch(typical_price, volume, config)`.
///
/// `config` is deserialized into [`MfiBatchConfig`] (`period_range` as
/// `[start, end, step]`). The result is an [`MfiBatchJsOutput`] serialized to a JS
/// object. Errors are returned as string `JsValue`s: an invalid config, input series
/// of different lengths (mirroring the Python `mfi_batch` check), a batch computation
/// error, or a serialization failure.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = mfi_batch)]
pub fn mfi_batch_unified_js(
    typical_price: &[f64],
    volume: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let config: MfiBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    if typical_price.len() != volume.len() {
        return Err(JsValue::from_str(
            "mfi_batch: typical_price and volume length mismatch",
        ));
    }

    let sweep = MfiBatchRange {
        period: config.period_range,
    };

    let output = mfi_batch_inner(typical_price, volume, &sweep, detect_best_kernel(), false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let js_output = MfiBatchJsOutput {
        values: output.values,
        combos: output.combos,
        rows: output.rows,
        cols: output.cols,
    };

    serde_wasm_bindgen::to_value(&js_output)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
1563
/// JS binding: runs the MFI period sweep `(period_start, period_end, period_step)`
/// from raw pointers and writes the flat `rows * len` result to `out_ptr`.
///
/// Both inputs must be valid for `len` reads, and `out_ptr` for `rows * len` writes,
/// where `rows` is the number of period combinations. `rows` is returned.
/// If the output range overlaps either input range, the sweep is computed into a
/// scratch buffer and then copied. This ensures a mutable slice never aliases the
/// inputs.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn mfi_batch_into(
    typical_price_ptr: *const f64,
    volume_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    /// True when byte ranges `[a, a + a_bytes)` and `[b, b + b_bytes)` intersect.
    fn overlaps(a: usize, a_bytes: usize, b: usize, b_bytes: usize) -> bool {
        a < b.saturating_add(b_bytes) && b < a.saturating_add(a_bytes)
    }

    if typical_price_ptr.is_null() || volume_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to mfi_batch_into"));
    }

    let sweep = MfiBatchRange {
        period: (period_start, period_end, period_step),
    };
    let rows = expand_grid(&sweep).len();
    let cols = len;
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("mfi_batch_into: rows*cols overflow"))?;

    let elem = std::mem::size_of::<f64>();
    let in_bytes = len.saturating_mul(elem);
    let out_bytes = total.saturating_mul(elem);
    let out_addr = out_ptr as usize;
    let aliased = overlaps(typical_price_ptr as usize, in_bytes, out_addr, out_bytes)
        || overlaps(volume_ptr as usize, in_bytes, out_addr, out_bytes);

    // SAFETY: the JS caller guarantees both input pointers are valid for `len` reads.
    let tp = unsafe { std::slice::from_raw_parts(typical_price_ptr, len) };
    let vol = unsafe { std::slice::from_raw_parts(volume_ptr, len) };
    let kernel = detect_best_kernel();

    if aliased {
        let mut scratch = vec![0.0f64; total];
        mfi_batch_inner_into(tp, vol, &sweep, kernel, false, &mut scratch)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // SAFETY: `out_ptr` is valid for `total` writes; the inputs are not read again.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };
        out.copy_from_slice(&scratch);
    } else {
        // SAFETY: `out_ptr` is valid for `total` writes and overlaps neither input.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };
        mfi_batch_inner_into(tp, vol, &sweep, kernel, false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(rows)
}
1600
#[cfg(test)]
mod tests {
    use super::*;
    use crate::skip_if_unsupported;
    use crate::utilities::data_loader::read_candles_from_csv;
    use paste::paste;
    use std::error::Error;

    /// `mfi_into` must agree element-for-element (NaN matching NaN) with `mfi`.
    #[test]
    fn test_mfi_into_matches_api() -> Result<(), Box<dyn Error>> {
        let n = 256usize;
        let mut tp = Vec::with_capacity(n);
        let mut vol = Vec::with_capacity(n);
        for i in 0..n {
            let i_f = i as f64;

            let price = 100.0 + 0.123 * i_f + ((i % 7) as f64 - 3.0) * 0.05;
            tp.push(price);

            vol.push(1_000.0 + ((i * 37) % 113) as f64);
        }

        let input = MfiInput::from_slices(&tp, &vol, MfiParams::default());

        let baseline = mfi(&input)?.values;

        let mut out = vec![0.0f64; n];
        #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
        {
            mfi_into(&input, &mut out)?;
        }
        #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
        {
            mfi_into_slice(&mut out, &input, Kernel::Auto)?;
        }

        assert_eq!(baseline.len(), out.len());

        fn eq_or_both_nan(a: f64, b: f64) -> bool {
            (a.is_nan() && b.is_nan()) || (a == b)
        }
        for (i, (a, b)) in baseline.iter().zip(out.iter()).enumerate() {
            assert!(
                eq_or_both_nan(*a, *b),
                "mismatch at {}: baseline={} vs into={}",
                i,
                a,
                b
            );
        }

        Ok(())
    }

    /// Omitting `period` must still produce a full-length output.
    fn check_mfi_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let default_params = MfiParams { period: None };
        let input = MfiInput::from_candles(&candles, "hlc3", default_params);
        let output = mfi_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());
        Ok(())
    }

    /// Last five values for period 14 on the reference data set.
    fn check_mfi_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let params = MfiParams { period: Some(14) };
        let input = MfiInput::from_candles(&candles, "hlc3", params);
        let mfi_result = mfi_with_kernel(&input, kernel)?;
        let expected_last_five_mfi = [
            38.13874339324763,
            37.44139770113819,
            31.02039511395131,
            28.092605898618896,
            25.905204729397813,
        ];
        let start_index = mfi_result.values.len() - 5;
        for (i, &value) in mfi_result.values[start_index..].iter().enumerate() {
            let expected_value = expected_last_five_mfi[i];
            let diff = (value - expected_value).abs();
            assert!(
                diff < 1e-1,
                "MFI mismatch at index {}: expected {}, got {}",
                i,
                expected_value,
                value
            );
        }
        Ok(())
    }

    /// `with_default_candles` must produce a full-length output.
    fn check_mfi_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let input = MfiInput::with_default_candles(&candles);
        let output = mfi_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());
        Ok(())
    }

    /// A zero period must be rejected.
    fn check_mfi_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let params = MfiParams { period: Some(0) };
        let input = MfiInput::from_candles(&candles, "hlc3", params);
        let result = mfi_with_kernel(&input, kernel);
        assert!(result.is_err());
        Ok(())
    }

    /// A period longer than the data must be rejected.
    fn check_mfi_period_exceeds_length(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let input_high = [1.0, 2.0, 3.0];
        let input_low = [0.5, 1.5, 2.5];
        let input_close = [0.8, 1.8, 2.8];
        let input_volume = [100.0, 200.0, 300.0];

        let typical_price: Vec<f64> = input_high
            .iter()
            .zip(&input_low)
            .zip(&input_close)
            .map(|((h, l), c)| (h + l + c) / 3.0)
            .collect();
        let params = MfiParams { period: Some(10) };
        let input = MfiInput::from_slices(&typical_price, &input_volume, params);
        let result = mfi_with_kernel(&input, kernel);
        assert!(result.is_err());
        Ok(())
    }

    /// A single-sample series with period 14 must be rejected.
    fn check_mfi_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let input_high = [1.0];
        let input_low = [0.5];
        let input_close = [0.8];
        let input_volume = [100.0];

        let typical_price = [(input_high[0] + input_low[0] + input_close[0]) / 3.0];
        let params = MfiParams { period: Some(14) };
        let input = MfiInput::from_slices(&typical_price, &input_volume, params);
        let result = mfi_with_kernel(&input, kernel);
        assert!(result.is_err());
        Ok(())
    }

    /// Feeding MFI output (with its NaN warm-up) back in as a price series must work.
    fn check_mfi_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let first_params = MfiParams { period: Some(7) };
        let first_input = MfiInput::from_candles(&candles, "hlc3", first_params);
        let first_result = mfi_with_kernel(&first_input, kernel)?;
        let second_params = MfiParams { period: Some(7) };

        let typical_price_values: Vec<f64> = first_result.values.clone();
        let volume_values: Vec<f64> = vec![10_000.0; first_result.values.len()];
        let second_input =
            MfiInput::from_slices(&typical_price_values, &volume_values, second_params);
        let second_result = mfi_with_kernel(&second_input, kernel)?;
        assert_eq!(second_result.values.len(), first_result.values.len());
        Ok(())
    }

    /// Debug builds only: no allocator poison pattern may leak into the output.
    #[cfg(debug_assertions)]
    fn check_mfi_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);

        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let test_params = vec![
            MfiParams::default(),
            MfiParams { period: Some(2) },
            MfiParams { period: Some(5) },
            MfiParams { period: Some(7) },
            MfiParams { period: Some(10) },
            MfiParams { period: Some(14) },
            MfiParams { period: Some(20) },
            MfiParams { period: Some(30) },
            MfiParams { period: Some(50) },
            MfiParams { period: Some(100) },
            MfiParams { period: Some(200) },
        ];

        for (param_idx, params) in test_params.iter().enumerate() {
            let input = MfiInput::from_candles(&candles, "hlc3", params.clone());
            let output = mfi_with_kernel(&input, kernel)?;

            for (i, &val) in output.values.iter().enumerate() {
                if val.is_nan() {
                    continue;
                }

                let bits = val.to_bits();

                if bits == 0x11111111_11111111 {
                    panic!(
                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
						 with params: period={} (param set {})",
                        test_name,
                        val,
                        bits,
                        i,
                        params.period.unwrap_or(14),
                        param_idx
                    );
                }

                if bits == 0x22222222_22222222 {
                    panic!(
                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
						 with params: period={} (param set {})",
                        test_name,
                        val,
                        bits,
                        i,
                        params.period.unwrap_or(14),
                        param_idx
                    );
                }

                if bits == 0x33333333_33333333 {
                    panic!(
                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
						 with params: period={} (param set {})",
                        test_name,
                        val,
                        bits,
                        i,
                        params.period.unwrap_or(14),
                        param_idx
                    );
                }
            }
        }

        Ok(())
    }

    #[cfg(not(debug_assertions))]
    fn check_mfi_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    /// Property test: warm-up, bounds, degenerate inputs, flow dominance, a direct
    /// formula check, and kernel-vs-scalar agreement.
    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_mfi_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        let strat = (2usize..=50).prop_flat_map(|period| {
            (period..=400).prop_flat_map(move |data_len| {
                prop_oneof![
                    // Random walk; volume loosely tracks the size of each move.
                    6 => (
                        (10.0f64..10000.0f64),
                        (0.01f64..0.2f64),
                        (1000.0f64..1_000_000.0f64),
                        prop::collection::vec(-1.0f64..1.0f64, data_len),
                        prop::collection::vec(0.0f64..1.0f64, data_len),
                    ).prop_map(move |(base_price, volatility, volume_mult, changes, vol_factors)| {
                        let mut typical_price = Vec::with_capacity(data_len);
                        let mut volume = Vec::with_capacity(data_len);
                        let mut price = base_price;

                        for i in 0..data_len {
                            let change = changes[i] * volatility;
                            price *= 1.0 + change;
                            price = price.max(0.01);
                            typical_price.push(price);

                            let vol = volume_mult * (0.5 + vol_factors[i] + change.abs() * 2.0);
                            volume.push(vol.max(0.0));
                        }

                        (typical_price, volume, period)
                    }),

                    // Constant price, constant volume.
                    15 => prop::collection::vec(100.0f64..1000.0f64, 1..=1)
                        .prop_map(move |prices| {
                            let price = prices[0];
                            let typical_price = vec![price; data_len];
                            let volume = vec![10000.0; data_len];
                            (typical_price, volume, period)
                        }),

                    // Linear up- or down-trend with rising volume.
                    15 => prop::bool::ANY.prop_map(move |uptrend| {
                        let mut typical_price = Vec::with_capacity(data_len);
                        let mut volume = Vec::with_capacity(data_len);
                        let start_price = 100.0;

                        for i in 0..data_len {
                            let trend_factor = if uptrend {
                                1.0 + (i as f64 / data_len as f64) * 2.0
                            } else {
                                1.0 - (i as f64 / data_len as f64) * 0.7
                            };
                            typical_price.push(start_price * trend_factor);

                            volume.push(10000.0 * (1.0 + i as f64 / data_len as f64) * 2.0);
                        }

                        (typical_price, volume, period)
                    }),

                    // Strictly rising price with zero volume.
                    1 => Just((
                        (0..data_len).map(|i| 100.0 + (i as f64)).collect::<Vec<_>>(),
                        vec![0.0; data_len],
                        period
                    )),
                ]
            })
        });

        proptest::test_runner::TestRunner::default().run(
            &strat,
            |(typical_price, volume, period)| {
                let params = MfiParams {
                    period: Some(period),
                };
                let input = MfiInput::from_slices(&typical_price, &volume, params.clone());

                let MfiOutput { values: out } = mfi_with_kernel(&input, kernel)?;

                let MfiOutput { values: ref_out } = mfi_with_kernel(&input, Kernel::Scalar)?;

                prop_assert_eq!(out.len(), typical_price.len(), "Output length mismatch");

                let first_valid_idx = (0..typical_price.len())
                    .find(|&i| !typical_price[i].is_nan() && !volume[i].is_nan())
                    .unwrap_or(0);

                let expected_warmup = first_valid_idx + period - 1;

                for i in 0..out.len() {
                    if i < expected_warmup {
                        prop_assert!(
                            out[i].is_nan(),
                            "Expected NaN during warmup at index {}, got {}",
                            i,
                            out[i]
                        );
                    } else if i == expected_warmup {
                        prop_assert!(
                            !out[i].is_nan(),
                            "Expected first non-NaN at index {} but got NaN",
                            i
                        );
                    }
                }

                for (i, &val) in out.iter().enumerate().skip(expected_warmup) {
                    if !val.is_nan() {
                        // Tiny tolerance for rounding residue in rolling sums.
                        prop_assert!(
                            (-1e-9..=100.0 + 1e-9).contains(&val),
                            "MFI out of bounds at index {}: {}",
                            i,
                            val
                        );
                    }
                }

                let is_constant = typical_price
                    .windows(2)
                    .all(|w| (w[0] - w[1]).abs() < 1e-10);
                if is_constant && expected_warmup < out.len() {
                    for i in expected_warmup..out.len() {
                        if !out[i].is_nan() {
                            prop_assert!(
                                out[i].abs() < 1e-3,
                                "Constant price MFI should be ~0, got {} at index {}",
                                out[i],
                                i
                            );
                        }
                    }
                }

                let all_zero_volume = volume.iter().all(|&v| v.abs() < 1e-14);
                if all_zero_volume && expected_warmup < out.len() {
                    for i in expected_warmup..out.len() {
                        if !out[i].is_nan() {
                            prop_assert!(
                                out[i].abs() < 1e-3,
                                "Zero volume MFI should be 0, got {} at index {}",
                                out[i],
                                i
                            );
                        }
                    }
                }

                if expected_warmup + period < typical_price.len() {
                    let check_idx = expected_warmup + period;

                    // Sum flows over the `period` bars ending at `check_idx` inclusive,
                    // matching the window assumed by the formula check below. The
                    // previous exclusive end dropped the bar at `check_idx` itself.
                    let window_start = check_idx + 1 - period;

                    let mut up_volume = 0.0;
                    let mut down_volume = 0.0;

                    for i in window_start..=check_idx {
                        if i > 0 && i < typical_price.len() {
                            let price_change = typical_price[i] - typical_price[i - 1];
                            if price_change > 0.0 {
                                up_volume += volume[i] * typical_price[i];
                            } else if price_change < 0.0 {
                                down_volume += volume[i] * typical_price[i];
                            }
                        }
                    }

                    if up_volume > down_volume * 2.0 && check_idx < out.len() {
                        let mfi_val = out[check_idx];
                        if !mfi_val.is_nan() && (up_volume + down_volume) > 1e-10 {
                            prop_assert!(
                                mfi_val > 50.0,
                                "MFI should be > 50 when up money flow dominates (up: {}, down: {}), got {}",
                                up_volume,
                                down_volume,
                                mfi_val
                            );
                        }
                    }

                    if down_volume > up_volume * 2.0 && check_idx < out.len() {
                        let mfi_val = out[check_idx];
                        if !mfi_val.is_nan() && (up_volume + down_volume) > 1e-10 {
                            prop_assert!(
                                mfi_val < 50.0,
                                "MFI should be < 50 when down money flow dominates (up: {}, down: {}), got {}",
                                up_volume,
                                down_volume,
                                mfi_val
                            );
                        }
                    }
                }

                if expected_warmup + 5 < typical_price.len() {
                    let verify_idx = expected_warmup + 5;

                    let mut pos_sum = 0.0;
                    let mut neg_sum = 0.0;

                    let window_start = verify_idx - period + 1;

                    for i in window_start..=verify_idx {
                        if i > 0 && i < typical_price.len() {
                            let price_diff = typical_price[i] - typical_price[i - 1];
                            let money_flow = typical_price[i] * volume[i];

                            if price_diff > 0.0 {
                                pos_sum += money_flow;
                            } else if price_diff < 0.0 {
                                neg_sum += money_flow;
                            }
                        }
                    }

                    let total = pos_sum + neg_sum;
                    let expected_mfi = if total < 1e-14 {
                        0.0
                    } else {
                        100.0 * (pos_sum / total)
                    };

                    let actual_mfi = out[verify_idx];
                    if !actual_mfi.is_nan() {
                        prop_assert!(
                            (actual_mfi - expected_mfi).abs() < 0.1,
                            "MFI formula verification failed at index {}: expected {} (pos: {}, neg: {}), got {}",
                            verify_idx,
                            expected_mfi,
                            pos_sum,
                            neg_sum,
                            actual_mfi
                        );
                    }
                }

                if period >= 5 && period <= 20 {
                    let test_len = period * 3;
                    let mut prices = Vec::with_capacity(test_len);
                    let mut increasing_vol = Vec::with_capacity(test_len);
                    let mut decreasing_vol = Vec::with_capacity(test_len);

                    for i in 0..test_len {
                        prices.push(100.0 + i as f64);

                        increasing_vol.push(1000.0 * (1.0 + i as f64));

                        decreasing_vol.push(1000.0 * (test_len as f64 - i as f64));
                    }

                    let input_inc = MfiInput::from_slices(&prices, &increasing_vol, params.clone());
                    let MfiOutput { values: out_inc } = mfi_with_kernel(&input_inc, kernel)?;

                    let input_dec = MfiInput::from_slices(&prices, &decreasing_vol, params.clone());
                    let MfiOutput { values: out_dec } = mfi_with_kernel(&input_dec, kernel)?;

                    let check_idx = period * 2;
                    if check_idx < out_inc.len() {
                        let mfi_inc = out_inc[check_idx];
                        let mfi_dec = out_dec[check_idx];

                        if !mfi_inc.is_nan() && !mfi_dec.is_nan() {
                            prop_assert!(
                                mfi_inc > 90.0,
                                "MFI with increasing volume on uptrend should be > 90, got {}",
                                mfi_inc
                            );
                            prop_assert!(
                                mfi_dec > 90.0,
                                "MFI with decreasing volume on uptrend should be > 90, got {}",
                                mfi_dec
                            );

                            // A strictly rising series has no negative money flow, so both
                            // runs saturate at 100 whatever the volume profile. The former
                            // strict `mfi_inc > mfi_dec` could never hold.
                            prop_assert!(
                                (mfi_inc - mfi_dec).abs() < 1e-6,
                                "MFI on a strictly rising series should not depend on the volume profile: inc={}, dec={}",
                                mfi_inc,
                                mfi_dec
                            );
                        }
                    }
                }

                for i in 0..out.len() {
                    let y = out[i];
                    let r = ref_out[i];

                    if y.is_nan() || r.is_nan() {
                        prop_assert_eq!(
                            y.is_nan(),
                            r.is_nan(),
                            "NaN mismatch at index {}: kernel={}, scalar={}",
                            i,
                            y,
                            r
                        );
                        continue;
                    }

                    let y_bits = y.to_bits();
                    let r_bits = r.to_bits();
                    let ulp_diff = y_bits.abs_diff(r_bits);

                    prop_assert!(
                        (y - r).abs() <= 1e-9 || ulp_diff <= 5,
                        "Kernel mismatch at index {}: {} vs {} (ULP={})",
                        i,
                        y,
                        r,
                        ulp_diff
                    );
                }

                Ok(())
            },
        )?;

        Ok(())
    }

    // Each generated test returns the check's `Result`, so an `Err` fails the test
    // instead of being silently discarded. The AVX gate is repeated on every
    // generated fn: a single attribute in front of a `$( … )*` repetition only
    // applies to the first item it expands to.
    macro_rules! generate_all_mfi_tests {
        ($($test_fn:ident),*) => {
            paste! {
                $(
                    #[test]
                    fn [<$test_fn _scalar_f64>]() -> Result<(), Box<dyn Error>> {
                        $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar)
                    }
                )*
                $(
                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx2_f64>]() -> Result<(), Box<dyn Error>> {
                        $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2)
                    }

                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx512_f64>]() -> Result<(), Box<dyn Error>> {
                        $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512)
                    }
                )*
            }
        }
    }
    generate_all_mfi_tests!(
        check_mfi_partial_params,
        check_mfi_accuracy,
        check_mfi_default_candles,
        check_mfi_zero_period,
        check_mfi_period_exceeds_length,
        check_mfi_very_small_dataset,
        check_mfi_reinput,
        check_mfi_no_poison
    );

    #[cfg(feature = "proptest")]
    generate_all_mfi_tests!(check_mfi_property);

    /// The default-params row of a batch run must match the single-run reference values.
    fn check_batch_default_row(
        test: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test);
        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;

        let output = MfiBatchBuilder::new().kernel(kernel).apply_candles(&c)?;

        let def = MfiParams::default();
        let row = output.values_for(&def).expect("default row missing");

        assert_eq!(row.len(), c.close.len());

        let expected = [
            38.13874339324763,
            37.44139770113819,
            31.02039511395131,
            28.092605898618896,
            25.905204729397813,
        ];
        let start = row.len().saturating_sub(5);
        for (i, &v) in row[start..].iter().enumerate() {
            assert!(
                (v - expected[i]).abs() < 1e-1,
                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
            );
        }
        Ok(())
    }

    /// Debug builds only: no allocator poison pattern may leak into any batch row.
    #[cfg(debug_assertions)]
    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);

        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;

        let test_configs = vec![
            (2, 10, 2),
            (5, 25, 5),
            (30, 60, 15),
            (2, 5, 1),
            (10, 50, 10),
            (7, 21, 7),
            (14, 14, 0),
        ];

        for (cfg_idx, &(p_start, p_end, p_step)) in test_configs.iter().enumerate() {
            let output = MfiBatchBuilder::new()
                .kernel(kernel)
                .period_range(p_start, p_end, p_step)
                .apply_candles(&c)?;

            for (idx, &val) in output.values.iter().enumerate() {
                if val.is_nan() {
                    continue;
                }

                let bits = val.to_bits();
                let row = idx / output.cols;
                let col = idx % output.cols;
                let combo = &output.combos[row];

                if bits == 0x11111111_11111111 {
                    panic!(
                        "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
						 at row {} col {} (flat index {}) with params: period={}",
                        test,
                        cfg_idx,
                        val,
                        bits,
                        row,
                        col,
                        idx,
                        combo.period.unwrap_or(14)
                    );
                }

                if bits == 0x22222222_22222222 {
                    panic!(
                        "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
						 at row {} col {} (flat index {}) with params: period={}",
                        test,
                        cfg_idx,
                        val,
                        bits,
                        row,
                        col,
                        idx,
                        combo.period.unwrap_or(14)
                    );
                }

                if bits == 0x33333333_33333333 {
                    panic!(
                        "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
						 at row {} col {} (flat index {}) with params: period={}",
                        test,
                        cfg_idx,
                        val,
                        bits,
                        row,
                        col,
                        idx,
                        combo.period.unwrap_or(14)
                    );
                }
            }
        }

        Ok(())
    }

    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    // Like `generate_all_mfi_tests!`: generated tests propagate the check's `Result`.
    macro_rules! gen_batch_tests {
        ($fn_name:ident) => {
            paste! {
                #[test] fn [<$fn_name _scalar>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch)
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx2>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch)
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx512>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch)
                }
                #[test] fn [<$fn_name _auto_detect>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto)
                }
            }
        };
    }

    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
}