Skip to main content

vector_ta/indicators/
mfi.rs

1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::PyDict;
9
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use serde::{Deserialize, Serialize};
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use serde_wasm_bindgen;
14#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
15use wasm_bindgen::prelude::*;
16
17use crate::utilities::data_loader::{source_type, Candles};
18use crate::utilities::enums::Kernel;
19use crate::utilities::helpers::{
20    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
21    make_uninit_matrix,
22};
23#[cfg(feature = "python")]
24use crate::utilities::kernel_validation::validate_kernel;
25#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
26use core::arch::x86_64::*;
27#[cfg(not(target_arch = "wasm32"))]
28use rayon::prelude::*;
29use std::convert::AsRef;
30use std::error::Error;
31use std::mem::MaybeUninit;
32use thiserror::Error;
33
34#[cfg(all(feature = "python", feature = "cuda"))]
35use crate::cuda::cuda_available;
36#[cfg(all(feature = "python", feature = "cuda"))]
37use crate::cuda::moving_averages::DeviceArrayF32;
38#[cfg(all(feature = "python", feature = "cuda"))]
39use crate::cuda::oscillators::mfi_wrapper::CudaMfi;
40#[cfg(all(feature = "python", feature = "cuda"))]
41use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
42#[cfg(all(feature = "python", feature = "cuda"))]
43use cust::context::Context;
44#[cfg(all(feature = "python", feature = "cuda"))]
45use std::sync::Arc;
46
47#[derive(Debug, Clone)]
48pub enum MfiData<'a> {
49    Candles {
50        candles: &'a Candles,
51        source: &'a str,
52    },
53    Slices {
54        typical_price: &'a [f64],
55        volume: &'a [f64],
56    },
57}
58
/// Result of a single-series MFI run: one value per input bar, with the
/// warm-up region left as NaN.
#[derive(Debug, Clone)]
pub struct MfiOutput {
    /// MFI values aligned index-for-index with the input series.
    pub values: Vec<f64>,
}
63
/// Tunable parameters for the Money Flow Index.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct MfiParams {
    /// Look-back window length in bars; `None` means the default of 14.
    pub period: Option<usize>,
}

impl Default for MfiParams {
    /// A 14-bar period.
    fn default() -> Self {
        MfiParams { period: Some(14) }
    }
}
78
79#[derive(Debug, Clone)]
80pub struct MfiInput<'a> {
81    pub data: MfiData<'a>,
82    pub params: MfiParams,
83}
84
85impl<'a> MfiInput<'a> {
86    #[inline]
87    pub fn from_candles(candles: &'a Candles, source: &'a str, params: MfiParams) -> Self {
88        Self {
89            data: MfiData::Candles { candles, source },
90            params,
91        }
92    }
93
94    #[inline]
95    pub fn from_slices(typical_price: &'a [f64], volume: &'a [f64], params: MfiParams) -> Self {
96        Self {
97            data: MfiData::Slices {
98                typical_price,
99                volume,
100            },
101            params,
102        }
103    }
104
105    #[inline]
106    pub fn with_default_candles(candles: &'a Candles) -> Self {
107        Self::from_candles(candles, "hlc3", MfiParams::default())
108    }
109
110    #[inline]
111    pub fn get_period(&self) -> usize {
112        self.params.period.unwrap_or(14)
113    }
114}
115
116#[derive(Debug, Error)]
117pub enum MfiError {
118    #[error("mfi: Empty data provided.")]
119    EmptyInputData,
120    #[error("mfi: Invalid period: period = {period}, data length = {data_len}")]
121    InvalidPeriod { period: usize, data_len: usize },
122    #[error("mfi: Not enough valid data: needed = {needed}, valid = {valid}")]
123    NotEnoughValidData { needed: usize, valid: usize },
124    #[error("mfi: All values are NaN.")]
125    AllValuesNaN,
126    #[error("mfi: Output length mismatch: expected = {expected}, got = {got}")]
127    OutputLengthMismatch { expected: usize, got: usize },
128    #[error("mfi: Invalid range: start={start} end={end} step={step}")]
129    InvalidRange {
130        start: usize,
131        end: usize,
132        step: usize,
133    },
134    #[error("mfi: Invalid kernel for batch path: {0:?}")]
135    InvalidKernelForBatch(Kernel),
136}
137
138#[inline]
139pub fn mfi(input: &MfiInput) -> Result<MfiOutput, MfiError> {
140    mfi_with_kernel(input, Kernel::Auto)
141}
142
143#[inline(always)]
144fn mfi_prepare<'a>(
145    input: &'a MfiInput<'a>,
146    kernel: Kernel,
147) -> Result<(&'a [f64], &'a [f64], usize, usize, Kernel), MfiError> {
148    let (typical_price, volume): (&[f64], &[f64]) = match &input.data {
149        MfiData::Candles { candles, source } => {
150            (source_type(candles, source), candles.volume.as_slice())
151        }
152        MfiData::Slices {
153            typical_price,
154            volume,
155        } => (*typical_price, *volume),
156    };
157
158    let length = typical_price.len();
159    if length == 0 || volume.len() != length {
160        return Err(MfiError::EmptyInputData);
161    }
162
163    let period = input.get_period();
164    let first_valid_idx = (0..length).find(|&i| !typical_price[i].is_nan() && !volume[i].is_nan());
165    let first_valid_idx = match first_valid_idx {
166        Some(idx) => idx,
167        None => return Err(MfiError::AllValuesNaN),
168    };
169
170    if period == 0 || period > length {
171        return Err(MfiError::InvalidPeriod {
172            period,
173            data_len: length,
174        });
175    }
176    if (length - first_valid_idx) < period {
177        return Err(MfiError::NotEnoughValidData {
178            needed: period,
179            valid: length - first_valid_idx,
180        });
181    }
182
183    let chosen = match kernel {
184        Kernel::Auto => detect_best_kernel(),
185        other => other,
186    };
187
188    Ok((typical_price, volume, period, first_valid_idx, chosen))
189}
190
191#[inline(always)]
192fn mfi_compute_into(
193    typical_price: &[f64],
194    volume: &[f64],
195    period: usize,
196    first: usize,
197    kernel: Kernel,
198    out: &mut [f64],
199) {
200    unsafe {
201        match kernel {
202            Kernel::Scalar | Kernel::ScalarBatch => {
203                mfi_scalar(typical_price, volume, period, first, out)
204            }
205            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
206            Kernel::Avx2 | Kernel::Avx2Batch => mfi_avx2(typical_price, volume, period, first, out),
207            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
208            Kernel::Avx512 | Kernel::Avx512Batch => {
209                mfi_avx512(typical_price, volume, period, first, out)
210            }
211            _ => unreachable!(),
212        }
213    }
214}
215
216pub fn mfi_with_kernel(input: &MfiInput, kernel: Kernel) -> Result<MfiOutput, MfiError> {
217    let (typical_price, volume, period, first_valid_idx, chosen) = mfi_prepare(input, kernel)?;
218
219    let warmup_period = first_valid_idx + period - 1;
220    let mut out = alloc_with_nan_prefix(typical_price.len(), warmup_period);
221
222    mfi_compute_into(
223        typical_price,
224        volume,
225        period,
226        first_valid_idx,
227        chosen,
228        &mut out,
229    );
230
231    Ok(MfiOutput { values: out })
232}
233
/// Scalar Money Flow Index kernel.
///
/// For every bar `i > first`, the raw money flow is `typical_price[i] * volume[i]`.
/// It counts as positive when the typical price rose versus bar `i - 1`, as
/// negative when it fell, and not at all when it is unchanged. A ring buffer
/// holds the last `period` flows and keeps running sums. Each output is
/// `100 * positive / (positive + negative)`, or `0.0` when that total is below
/// `1e-14`.
///
/// The first value is written at `first + period - 1`. It is built from
/// `period - 1` flows, because bar `first` has no predecessor; every later index
/// uses a full window of `period` flows. Entries before `first + period - 1`
/// are not touched, so callers pre-fill the warm-up region themselves.
///
/// The function returns without writing anything if `period` is zero, `first`
/// is out of range, or fewer than `period` bars remain from `first`.
///
/// # Safety
/// `volume` and `out` must each be at least `typical_price.len()` long.
/// Element reads and writes are unchecked.
#[inline]
pub unsafe fn mfi_scalar(
    typical_price: &[f64],
    volume: &[f64],
    period: usize,
    first: usize,
    out: &mut [f64],
) {
    let len = typical_price.len();
    // Reject degenerate arguments up front. `period == 0` would leave the ring
    // buffer empty and underflow `seed_end - 1`, and a window that runs past the
    // data would make the seed loop read out of bounds.
    if len == 0 || period == 0 || first >= len || period > len - first {
        return;
    }
    debug_assert!(
        volume.len() >= len,
        "mfi_scalar: volume shorter than typical_price"
    );
    debug_assert!(out.len() >= len, "mfi_scalar: out shorter than typical_price");

    // Positive flows live in ring_buf[..period], negative flows in ring_buf[period..].
    let mut ring_buf = vec![0.0f64; period * 2];

    let tp_ptr = typical_price.as_ptr();
    let vol_ptr = volume.as_ptr();
    let out_ptr = out.as_mut_ptr();
    let pos_ptr = ring_buf.as_mut_ptr();
    let neg_ptr = pos_ptr.add(period);

    let mut pos_sum = 0.0f64;
    let mut neg_sum = 0.0f64;

    let mut prev = *tp_ptr.add(first);
    let mut ring = 0usize;

    // Seed the window with the flows for bars first+1 ..= first+period-1.
    let seed_end = first + period;
    let mut i = first + 1;
    while i < seed_end {
        let tp_i = *tp_ptr.add(i);
        let flow = tp_i * *vol_ptr.add(i);
        let diff = tp_i - prev;
        prev = tp_i;

        let gt = (diff > 0.0) as i32 as f64;
        let lt = (diff < 0.0) as i32 as f64;
        let pos_new = flow * gt;
        let neg_new = flow * lt;

        *pos_ptr.add(ring) = pos_new;
        *neg_ptr.add(ring) = neg_new;
        pos_sum += pos_new;
        neg_sum += neg_new;

        ring += 1;
        if ring == period {
            ring = 0;
        }
        i += 1;
    }

    // The guard above ensures `seed_end <= len`, so this index is in bounds.
    let idx0 = seed_end - 1;
    let total = pos_sum + neg_sum;
    *out_ptr.add(idx0) = if total < 1e-14 {
        0.0
    } else {
        100.0 * (pos_sum / total)
    };

    // Slide the window: retire the oldest flow, add the newest, emit a value.
    i = seed_end;
    while i < len {
        let old_pos = *pos_ptr.add(ring);
        let old_neg = *neg_ptr.add(ring);
        pos_sum -= old_pos;
        neg_sum -= old_neg;

        let tp_i = *tp_ptr.add(i);
        let flow = tp_i * *vol_ptr.add(i);
        let diff = tp_i - prev;
        prev = tp_i;

        let gt = (diff > 0.0) as i32 as f64;
        let lt = (diff < 0.0) as i32 as f64;
        let pos_new = flow * gt;
        let neg_new = flow * lt;

        *pos_ptr.add(ring) = pos_new;
        *neg_ptr.add(ring) = neg_new;
        pos_sum += pos_new;
        neg_sum += neg_new;

        let total = pos_sum + neg_sum;
        *out_ptr.add(i) = if total < 1e-14 {
            0.0
        } else {
            100.0 * (pos_sum / total)
        };

        ring += 1;
        if ring == period {
            ring = 0;
        }
        i += 1;
    }
}
337
/// AVX2 MFI entry point; it currently forwards to [`mfi_scalar`] unchanged.
///
/// # Safety
/// Same requirements as [`mfi_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn mfi_avx2(
    typical_price: &[f64],
    volume: &[f64],
    period: usize,
    first: usize,
    out: &mut [f64],
) {
    mfi_scalar(typical_price, volume, period, first, out)
}

/// AVX-512 MFI entry point that picks a short-period (`<= 32`) or long-period
/// variant; both currently forward to [`mfi_scalar`].
///
/// NOTE(review): this is a safe `fn` that calls `unsafe` kernels without
/// checking their slice-length contract. Callers are expected to pass inputs
/// validated the way `mfi_prepare` does.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn mfi_avx512(
    typical_price: &[f64],
    volume: &[f64],
    period: usize,
    first: usize,
    out: &mut [f64],
) {
    const SHORT_PERIOD_LIMIT: usize = 32;
    // SAFETY: relies on caller-validated inputs (see the note on this function).
    unsafe {
        if period > SHORT_PERIOD_LIMIT {
            mfi_avx512_long(typical_price, volume, period, first, out)
        } else {
            mfi_avx512_short(typical_price, volume, period, first, out)
        }
    }
}

/// AVX-512 variant for periods up to 32; currently forwards to [`mfi_scalar`].
///
/// # Safety
/// Same requirements as [`mfi_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn mfi_avx512_short(
    typical_price: &[f64],
    volume: &[f64],
    period: usize,
    first: usize,
    out: &mut [f64],
) {
    mfi_scalar(typical_price, volume, period, first, out)
}

/// AVX-512 variant for periods above 32; currently forwards to [`mfi_scalar`].
///
/// # Safety
/// Same requirements as [`mfi_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn mfi_avx512_long(
    typical_price: &[f64],
    volume: &[f64],
    period: usize,
    first: usize,
    out: &mut [f64],
) {
    mfi_scalar(typical_price, volume, period, first, out)
}
391
392#[derive(Copy, Clone, Debug)]
393pub struct MfiBuilder {
394    period: Option<usize>,
395    kernel: Kernel,
396}
397
398impl Default for MfiBuilder {
399    fn default() -> Self {
400        Self {
401            period: None,
402            kernel: Kernel::Auto,
403        }
404    }
405}
406
407impl MfiBuilder {
408    #[inline(always)]
409    pub fn new() -> Self {
410        Self::default()
411    }
412    #[inline(always)]
413    pub fn period(mut self, n: usize) -> Self {
414        self.period = Some(n);
415        self
416    }
417    #[inline(always)]
418    pub fn kernel(mut self, k: Kernel) -> Self {
419        self.kernel = k;
420        self
421    }
422
423    #[inline(always)]
424    pub fn apply(self, c: &Candles) -> Result<MfiOutput, MfiError> {
425        let p = MfiParams {
426            period: self.period,
427        };
428        let i = MfiInput::from_candles(c, "hlc3", p);
429        mfi_with_kernel(&i, self.kernel)
430    }
431
432    #[inline(always)]
433    pub fn apply_slices(
434        self,
435        typical_price: &[f64],
436        volume: &[f64],
437    ) -> Result<MfiOutput, MfiError> {
438        let p = MfiParams {
439            period: self.period,
440        };
441        let i = MfiInput::from_slices(typical_price, volume, p);
442        mfi_with_kernel(&i, self.kernel)
443    }
444
445    #[inline(always)]
446    pub fn into_stream(self) -> Result<MfiStream, MfiError> {
447        let p = MfiParams {
448            period: self.period,
449        };
450        MfiStream::try_new(p)
451    }
452}
453
454#[derive(Debug, Clone)]
455pub struct MfiStream {
456    period: usize,
457    pos_buf: Vec<f64>,
458    neg_buf: Vec<f64>,
459    head: usize,
460    filled: bool,
461    pos_sum: f64,
462    neg_sum: f64,
463    prev_typical: f64,
464    index: usize,
465}
466
467impl MfiStream {
468    pub fn try_new(params: MfiParams) -> Result<Self, MfiError> {
469        let period = params.period.unwrap_or(14);
470        if period == 0 {
471            return Err(MfiError::InvalidPeriod {
472                period,
473                data_len: 0,
474            });
475        }
476        Ok(Self {
477            period,
478            pos_buf: vec![0.0; period],
479            neg_buf: vec![0.0; period],
480            head: 0,
481            filled: false,
482            pos_sum: 0.0,
483            neg_sum: 0.0,
484            prev_typical: f64::NAN,
485            index: 0,
486        })
487    }
488
489    #[inline(always)]
490    pub fn update(&mut self, typical_price: f64, volume: f64) -> Option<f64> {
491        if self.index == 0 {
492            self.prev_typical = typical_price;
493            self.index = 1;
494            return None;
495        }
496
497        let diff = typical_price - self.prev_typical;
498        self.prev_typical = typical_price;
499
500        let flow = typical_price.mul_add(volume, 0.0);
501
502        let gt = (diff > 0.0) as i32 as f64;
503        let lt = (diff < 0.0) as i32 as f64;
504        let pos_new = flow * gt;
505        let neg_new = flow * lt;
506
507        unsafe {
508            let old_pos = *self.pos_buf.get_unchecked(self.head);
509            let old_neg = *self.neg_buf.get_unchecked(self.head);
510
511            self.pos_sum += pos_new - old_pos;
512            self.neg_sum += neg_new - old_neg;
513
514            *self.pos_buf.get_unchecked_mut(self.head) = pos_new;
515            *self.neg_buf.get_unchecked_mut(self.head) = neg_new;
516        }
517
518        self.head += 1;
519        if self.head == self.period {
520            self.head = 0;
521            self.filled = true;
522        }
523        self.index += 1;
524
525        if !self.filled {
526            return None;
527        }
528
529        let total = self.pos_sum + self.neg_sum;
530        if total <= 1e-14 {
531            Some(0.0)
532        } else {
533            Some(100.0 * self.pos_sum * total.recip())
534        }
535    }
536}
537
/// Inclusive `(start, end, step)` sweep of MFI periods for batch runs.
///
/// A `step` of 0 (or `start == end`) yields just `start`; `start > end` sweeps downward.
#[derive(Clone, Debug)]
pub struct MfiBatchRange {
    /// `(start, end, step)` for the period axis.
    pub period: (usize, usize, usize),
}

impl Default for MfiBatchRange {
    /// Periods 14 through 263 in steps of 1.
    fn default() -> Self {
        let period = (14, 263, 1);
        MfiBatchRange { period }
    }
}
550
551#[derive(Clone, Debug, Default)]
552pub struct MfiBatchBuilder {
553    range: MfiBatchRange,
554    kernel: Kernel,
555}
556
557impl MfiBatchBuilder {
558    pub fn new() -> Self {
559        Self::default()
560    }
561    pub fn kernel(mut self, k: Kernel) -> Self {
562        self.kernel = k;
563        self
564    }
565
566    #[inline]
567    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
568        self.range.period = (start, end, step);
569        self
570    }
571    #[inline]
572    pub fn period_static(mut self, p: usize) -> Self {
573        self.range.period = (p, p, 0);
574        self
575    }
576
577    pub fn apply_slices(
578        self,
579        typical_price: &[f64],
580        volume: &[f64],
581    ) -> Result<MfiBatchOutput, MfiError> {
582        mfi_batch_with_kernel(typical_price, volume, &self.range, self.kernel)
583    }
584
585    pub fn apply_candles(self, c: &Candles) -> Result<MfiBatchOutput, MfiError> {
586        let typical_price = source_type(c, "hlc3");
587        self.apply_slices(typical_price, &c.volume)
588    }
589
590    pub fn with_default_candles(c: &Candles, k: Kernel) -> Result<MfiBatchOutput, MfiError> {
591        MfiBatchBuilder::new().kernel(k).apply_candles(c)
592    }
593}
594
595#[derive(Clone, Debug)]
596pub struct MfiBatchOutput {
597    pub values: Vec<f64>,
598    pub combos: Vec<MfiParams>,
599    pub rows: usize,
600    pub cols: usize,
601}
602impl MfiBatchOutput {
603    pub fn row_for_params(&self, p: &MfiParams) -> Option<usize> {
604        self.combos
605            .iter()
606            .position(|c| c.period.unwrap_or(14) == p.period.unwrap_or(14))
607    }
608    pub fn values_for(&self, p: &MfiParams) -> Option<&[f64]> {
609        self.row_for_params(p).map(|row| {
610            let start = row * self.cols;
611            &self.values[start..start + self.cols]
612        })
613    }
614}
615
616#[inline(always)]
617fn expand_grid(r: &MfiBatchRange) -> Vec<MfiParams> {
618    fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, MfiError> {
619        if step == 0 || start == end {
620            return Ok(vec![start]);
621        }
622        if start < end {
623            return Ok((start..=end).step_by(step.max(1)).collect());
624        }
625        let mut v = Vec::new();
626        let mut x = start as isize;
627        let end_i = end as isize;
628        let st = (step as isize).max(1);
629        while x >= end_i {
630            v.push(x as usize);
631            x -= st;
632        }
633        if v.is_empty() {
634            return Err(MfiError::InvalidRange { start, end, step });
635        }
636        Ok(v)
637    }
638    let periods = match axis_usize(r.period) {
639        Ok(v) => v,
640        Err(_) => return Vec::new(),
641    };
642    let mut out = Vec::with_capacity(periods.len());
643    for &p in &periods {
644        out.push(MfiParams { period: Some(p) });
645    }
646    out
647}
648
649pub fn mfi_batch_with_kernel(
650    typical_price: &[f64],
651    volume: &[f64],
652    sweep: &MfiBatchRange,
653    k: Kernel,
654) -> Result<MfiBatchOutput, MfiError> {
655    let kernel = match k {
656        Kernel::Auto => detect_best_batch_kernel(),
657        other if other.is_batch() => other,
658        _ => return Err(MfiError::InvalidKernelForBatch(k)),
659    };
660    let simd = match kernel {
661        Kernel::Avx512Batch => Kernel::Avx512,
662        Kernel::Avx2Batch => Kernel::Avx2,
663        Kernel::ScalarBatch => Kernel::Scalar,
664        _ => unreachable!(),
665    };
666    mfi_batch_par_slice(typical_price, volume, sweep, simd)
667}
668
669#[inline(always)]
670pub fn mfi_batch_slice(
671    typical_price: &[f64],
672    volume: &[f64],
673    sweep: &MfiBatchRange,
674    kern: Kernel,
675) -> Result<MfiBatchOutput, MfiError> {
676    mfi_batch_inner(typical_price, volume, sweep, kern, false)
677}
678
679#[inline(always)]
680pub fn mfi_batch_par_slice(
681    typical_price: &[f64],
682    volume: &[f64],
683    sweep: &MfiBatchRange,
684    kern: Kernel,
685) -> Result<MfiBatchOutput, MfiError> {
686    mfi_batch_inner(typical_price, volume, sweep, kern, true)
687}
688
/// Rounds `x` up to the next multiple of 8 (`x` itself when already a multiple).
// Allowed: this private helper may have no remaining callers in some builds.
#[allow(dead_code)]
fn round_up8(x: usize) -> usize {
    let mask: usize = 7;
    (x + mask) & !mask
}
692
693#[inline(always)]
694fn mfi_batch_inner(
695    typical_price: &[f64],
696    volume: &[f64],
697    sweep: &MfiBatchRange,
698    kern: Kernel,
699    parallel: bool,
700) -> Result<MfiBatchOutput, MfiError> {
701    let combos = expand_grid(sweep);
702    if combos.is_empty() {
703        return Err(MfiError::InvalidRange {
704            start: sweep.period.0,
705            end: sweep.period.1,
706            step: sweep.period.2,
707        });
708    }
709    let length = typical_price.len();
710    let first = (0..length)
711        .find(|&i| !typical_price[i].is_nan() && !volume[i].is_nan())
712        .ok_or(MfiError::AllValuesNaN)?;
713
714    let max_p = combos
715        .iter()
716        .map(|c| round_up8(c.period.unwrap()))
717        .max()
718        .unwrap();
719    if length - first < max_p {
720        return Err(MfiError::NotEnoughValidData {
721            needed: max_p,
722            valid: length - first,
723        });
724    }
725
726    let rows = combos.len();
727    let cols = length;
728
729    if volume.len() != cols {
730        return Err(MfiError::EmptyInputData);
731    }
732
733    let mut buf_mu = make_uninit_matrix(rows, cols);
734    let warmup_periods: Vec<usize> = combos
735        .iter()
736        .map(|c| first + c.period.unwrap() - 1)
737        .collect();
738    init_matrix_prefixes(&mut buf_mu, cols, &warmup_periods);
739
740    let mut buf_guard = core::mem::ManuallyDrop::new(buf_mu);
741    let out: &mut [f64] = unsafe {
742        core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
743    };
744
745    let rows = combos.len();
746    let use_prefix = rows >= 8;
747
748    let (pos_prefix, neg_prefix) = if use_prefix {
749        let (pp, np) =
750            unsafe { precompute_flow_prefixes_select(typical_price, volume, first, kern) };
751        (Some(pp), Some(np))
752    } else {
753        (None, None)
754    };
755
756    let do_row = |row: usize, out_row: &mut [f64]| unsafe {
757        let period = combos[row].period.unwrap();
758        if let (Some(ref pp), Some(ref np)) = (pos_prefix.as_ref(), neg_prefix.as_ref()) {
759            mfi_row_from_prefixes(pp, np, first, period, out_row)
760        } else {
761            match kern {
762                Kernel::Scalar => mfi_row_scalar(typical_price, volume, first, period, out_row),
763                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
764                Kernel::Avx2 => mfi_row_avx2(typical_price, volume, first, period, out_row),
765                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
766                Kernel::Avx512 => mfi_row_avx512(typical_price, volume, first, period, out_row),
767                _ => unreachable!(),
768            }
769        }
770    };
771
772    if parallel {
773        #[cfg(not(target_arch = "wasm32"))]
774        {
775            use rayon::prelude::*;
776            out.par_chunks_mut(cols)
777                .enumerate()
778                .for_each(|(row, slice)| do_row(row, slice));
779        }
780
781        #[cfg(target_arch = "wasm32")]
782        {
783            for (row, slice) in out.chunks_mut(cols).enumerate() {
784                do_row(row, slice);
785            }
786        }
787    } else {
788        for (row, slice) in out.chunks_mut(cols).enumerate() {
789            do_row(row, slice);
790        }
791    }
792
793    let values = unsafe {
794        Vec::from_raw_parts(
795            buf_guard.as_mut_ptr() as *mut f64,
796            buf_guard.len(),
797            buf_guard.capacity(),
798        )
799    };
800
801    Ok(MfiBatchOutput {
802        values,
803        combos,
804        rows,
805        cols,
806    })
807}
808
809#[inline(always)]
810fn mfi_batch_inner_into(
811    typical_price: &[f64],
812    volume: &[f64],
813    sweep: &MfiBatchRange,
814    kern: Kernel,
815    parallel: bool,
816    out: &mut [f64],
817) -> Result<Vec<MfiParams>, MfiError> {
818    let combos = expand_grid(sweep);
819    if combos.is_empty() {
820        return Err(MfiError::InvalidRange {
821            start: sweep.period.0,
822            end: sweep.period.1,
823            step: sweep.period.2,
824        });
825    }
826
827    let length = typical_price.len();
828    let first = (0..length)
829        .find(|&i| !typical_price[i].is_nan() && !volume[i].is_nan())
830        .ok_or(MfiError::AllValuesNaN)?;
831
832    let max_p = combos
833        .iter()
834        .map(|c| round_up8(c.period.unwrap()))
835        .max()
836        .unwrap();
837    if length - first < max_p {
838        return Err(MfiError::NotEnoughValidData {
839            needed: max_p,
840            valid: length - first,
841        });
842    }
843
844    let cols = length;
845
846    if volume.len() != cols {
847        return Err(MfiError::EmptyInputData);
848    }
849
850    let rows = combos.len();
851    let use_prefix = rows >= 8;
852    let (pos_prefix, neg_prefix) = if use_prefix {
853        let (pp, np) =
854            unsafe { precompute_flow_prefixes_select(typical_price, volume, first, kern) };
855        (Some(pp), Some(np))
856    } else {
857        (None, None)
858    };
859
860    let do_row = |row: usize, out_row: &mut [f64]| unsafe {
861        let period = combos[row].period.unwrap();
862
863        let warmup_end = first + period - 1;
864        for v in &mut out_row[..warmup_end] {
865            *v = f64::NAN;
866        }
867        if let (Some(ref pp), Some(ref np)) = (pos_prefix.as_ref(), neg_prefix.as_ref()) {
868            mfi_row_from_prefixes(pp, np, first, period, out_row)
869        } else {
870            match kern {
871                Kernel::Scalar => mfi_row_scalar(typical_price, volume, first, period, out_row),
872                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
873                Kernel::Avx2 => mfi_row_avx2(typical_price, volume, first, period, out_row),
874                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
875                Kernel::Avx512 => mfi_row_avx512(typical_price, volume, first, period, out_row),
876                _ => unreachable!(),
877            }
878        }
879    };
880
881    if parallel {
882        #[cfg(not(target_arch = "wasm32"))]
883        {
884            use rayon::prelude::*;
885            out.par_chunks_mut(cols)
886                .enumerate()
887                .for_each(|(row, slice)| do_row(row, slice));
888        }
889
890        #[cfg(target_arch = "wasm32")]
891        {
892            for (row, slice) in out.chunks_mut(cols).enumerate() {
893                do_row(row, slice);
894            }
895        }
896    } else {
897        for (row, slice) in out.chunks_mut(cols).enumerate() {
898            do_row(row, slice);
899        }
900    }
901
902    Ok(combos)
903}
904
/// Builds cumulative positive/negative money-flow sums starting at `first`.
///
/// `pos_prefix[i]` is the sum of positive flows for bars `first+1 ..= i`, and
/// likewise `neg_prefix[i]` for negative flows. Entries at or before `first`
/// are zero. A flow `typical_price[i] * volume[i]` is positive when the typical
/// price rose versus bar `i - 1` and negative when it fell.
///
/// When `first` is not a valid index, including for empty input, both prefixes
/// are returned as all zeros.
///
/// # Safety
/// Kept `unsafe` for interface compatibility; the body uses checked indexing
/// and panics if `volume` is shorter than `typical_price`.
#[inline(always)]
unsafe fn precompute_flow_prefixes_scalar(
    typical_price: &[f64],
    volume: &[f64],
    first: usize,
) -> (Vec<f64>, Vec<f64>) {
    let len = typical_price.len();
    let mut pos_prefix = vec![0.0f64; len];
    let mut neg_prefix = vec![0.0f64; len];

    // Nothing to accumulate; also avoids reading `typical_price[first]` out of bounds.
    if first >= len {
        return (pos_prefix, neg_prefix);
    }

    let mut prev = typical_price[first];
    let mut pos_acc = 0.0f64;
    let mut neg_acc = 0.0f64;
    for i in (first + 1)..len {
        let tp_i = typical_price[i];
        let flow = tp_i * volume[i];
        let diff = tp_i - prev;
        prev = tp_i;

        let gt = (diff > 0.0) as i32 as f64;
        let lt = (diff < 0.0) as i32 as f64;
        pos_acc += flow * gt;
        neg_acc += flow * lt;

        pos_prefix[i] = pos_acc;
        neg_prefix[i] = neg_acc;
    }

    (pos_prefix, neg_prefix)
}
947
948#[inline(always)]
949unsafe fn precompute_flow_prefixes_select(
950    typical_price: &[f64],
951    volume: &[f64],
952    first: usize,
953    kern: Kernel,
954) -> (Vec<f64>, Vec<f64>) {
955    match kern {
956        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
957        Kernel::Avx2 | Kernel::Avx512 => {
958            precompute_flow_prefixes_avx2(typical_price, volume, first)
959        }
960        _ => precompute_flow_prefixes_scalar(typical_price, volume, first),
961    }
962}
963
/// AVX2 version of `precompute_flow_prefixes_scalar`.
///
/// Flows and up/down classification are computed four lanes at a time.
/// Comparison masks are turned into 0.0/1.0 factors and multiplied into the flow,
/// exactly like the scalar `flow * ((diff > 0.0) as f64)`. NaN or infinite flows
/// therefore give the same results as the scalar builder. The running sums are
/// still accumulated sequentially, in index order, so the prefix values match too.
///
/// When `first` is not a valid index, including for empty input, both prefixes
/// are returned as all zeros.
///
/// # Safety
/// `volume` must be at least as long as `typical_price`, and the CPU must
/// support AVX2.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn precompute_flow_prefixes_avx2(
    typical_price: &[f64],
    volume: &[f64],
    first: usize,
) -> (Vec<f64>, Vec<f64>) {
    use core::arch::x86_64::*;
    let len = typical_price.len();
    let mut pos_prefix = vec![0.0f64; len];
    let mut neg_prefix = vec![0.0f64; len];
    if first >= len {
        return (pos_prefix, neg_prefix);
    }
    debug_assert!(volume.len() >= len);

    let tp_ptr = typical_price.as_ptr();
    let vol_ptr = volume.as_ptr();
    let zero = _mm256_setzero_pd();
    let one = _mm256_set1_pd(1.0);

    let mut pos_sum = 0.0f64;
    let mut neg_sum = 0.0f64;
    let mut pos_tmp = [0.0f64; 4];
    let mut neg_tmp = [0.0f64; 4];

    let mut i = first + 1;
    while i + 4 <= len {
        let tp_cur = _mm256_loadu_pd(tp_ptr.add(i));
        let tp_prev = _mm256_loadu_pd(tp_ptr.add(i - 1));
        let vol_cur = _mm256_loadu_pd(vol_ptr.add(i));

        let flow = _mm256_mul_pd(tp_cur, vol_cur);
        let diff = _mm256_sub_pd(tp_cur, tp_prev);

        // Mask -> 0.0/1.0, then multiply. AND-ing the mask straight into `flow`
        // would zero NaN/inf flows, whereas the scalar kernel keeps them.
        let gt = _mm256_and_pd(_mm256_cmp_pd(diff, zero, _CMP_GT_OQ), one);
        let lt = _mm256_and_pd(_mm256_cmp_pd(diff, zero, _CMP_LT_OQ), one);
        _mm256_storeu_pd(pos_tmp.as_mut_ptr(), _mm256_mul_pd(flow, gt));
        _mm256_storeu_pd(neg_tmp.as_mut_ptr(), _mm256_mul_pd(flow, lt));

        // Prefix sums are inherently sequential: accumulate lane by lane.
        for lane in 0..4 {
            pos_sum += pos_tmp[lane];
            neg_sum += neg_tmp[lane];
            *pos_prefix.get_unchecked_mut(i + lane) = pos_sum;
            *neg_prefix.get_unchecked_mut(i + lane) = neg_sum;
        }

        i += 4;
    }

    // Scalar tail for the last (len - i) < 4 bars.
    while i < len {
        let tp_i = *tp_ptr.add(i);
        let flow = tp_i * *vol_ptr.add(i);
        let diff = tp_i - *tp_ptr.add(i - 1);
        let gt = (diff > 0.0) as i32 as f64;
        let lt = (diff < 0.0) as i32 as f64;
        pos_sum += flow * gt;
        neg_sum += flow * lt;
        *pos_prefix.get_unchecked_mut(i) = pos_sum;
        *neg_prefix.get_unchecked_mut(i) = neg_sum;
        i += 1;
    }

    (pos_prefix, neg_prefix)
}
1050
/// Fills one batch row of MFI values from precomputed flow prefix sums.
///
/// The first value at `first + period - 1` covers flows `first+1 ..= first+period-1`.
/// Each later index `i` covers the window `i-period+1 ..= i`. This matches the
/// output layout of `mfi_scalar`. A window total below `1e-14` yields `0.0`.
/// Entries before `first + period - 1` are not written.
///
/// The function returns without writing when `out` is empty, `period` is zero,
/// or the first output index falls outside `out`.
///
/// # Safety
/// Kept `unsafe` for interface compatibility. The body uses checked indexing and
/// panics if either prefix slice is shorter than `out`.
#[inline(always)]
unsafe fn mfi_row_from_prefixes(
    pos_prefix: &[f64],
    neg_prefix: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    let len = out.len();
    // `period == 0` would underflow the first-output index below.
    if len == 0 || period == 0 {
        return;
    }
    let idx0 = match first.checked_add(period - 1) {
        Some(idx) if idx < len => idx,
        _ => return,
    };

    let ratio = |pos: f64, neg: f64| {
        let total = pos + neg;
        if total < 1e-14 {
            0.0
        } else {
            100.0 * (pos / total)
        }
    };

    out[idx0] = ratio(
        pos_prefix[idx0] - pos_prefix[first],
        neg_prefix[idx0] - neg_prefix[first],
    );

    for i in (idx0 + 1)..len {
        let base = i - period;
        out[i] = ratio(
            pos_prefix[i] - pos_prefix[base],
            neg_prefix[i] - neg_prefix[base],
        );
    }
}
1092
1093#[inline(always)]
1094unsafe fn mfi_row_scalar(
1095    typical_price: &[f64],
1096    volume: &[f64],
1097    first: usize,
1098    period: usize,
1099    out: &mut [f64],
1100) {
1101    mfi_scalar(typical_price, volume, period, first, out)
1102}
1103
/// AVX2 row kernel slot.
///
/// There is no dedicated AVX2 implementation here: it delegates to
/// `mfi_scalar`, swapping the row-kernel `(first, period)` order into
/// `mfi_scalar`'s `(period, first)` order.
///
/// # Safety
/// Adds no preconditions of its own beyond those of `mfi_scalar`
/// (NOTE(review): confirm).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn mfi_row_avx2(
    typical_price: &[f64],
    volume: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    let (window, warm_start) = (period, first);
    mfi_scalar(typical_price, volume, window, warm_start, out)
}
1115
/// AVX-512 row kernel entry point.
///
/// Dispatches on the window length: periods up to 32 go to
/// `mfi_row_avx512_short`, longer periods to `mfi_row_avx512_long`.
///
/// # Safety
/// Same contract as whichever variant is selected.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn mfi_row_avx512(
    typical_price: &[f64],
    volume: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    match period {
        0..=32 => mfi_row_avx512_short(typical_price, volume, first, period, out),
        _ => mfi_row_avx512_long(typical_price, volume, first, period, out),
    }
}
1131
/// AVX-512 row kernel for short periods (selected when `period <= 32`).
///
/// Currently delegates to `mfi_scalar`, swapping the row-kernel
/// `(first, period)` order into `mfi_scalar`'s `(period, first)` order.
///
/// # Safety
/// Adds no preconditions of its own beyond those of `mfi_scalar`
/// (NOTE(review): confirm).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn mfi_row_avx512_short(
    typical_price: &[f64],
    volume: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    let (window, warm_start) = (period, first);
    mfi_scalar(typical_price, volume, window, warm_start, out)
}
1143
/// AVX-512 row kernel for long periods (selected when `period > 32`).
///
/// Currently delegates to `mfi_scalar`, swapping the row-kernel
/// `(first, period)` order into `mfi_scalar`'s `(period, first)` order.
///
/// # Safety
/// Adds no preconditions of its own beyond those of `mfi_scalar`
/// (NOTE(review): confirm).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn mfi_row_avx512_long(
    typical_price: &[f64],
    volume: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    let (window, warm_start) = (period, first);
    mfi_scalar(typical_price, volume, window, warm_start, out)
}
1155
/// Python binding: computes MFI for one series of typical prices and volumes.
///
/// `kernel` optionally names the compute kernel; it is checked by
/// `validate_kernel`, and any error from that check propagates unchanged.
/// The computation runs with the GIL released, and its errors are raised as
/// `ValueError`. Returns a new 1-D float64 array.
#[cfg(feature = "python")]
#[pyfunction(name = "mfi")]
#[pyo3(signature = (typical_price, volume, period, kernel=None))]
pub fn mfi_py<'py>(
    py: Python<'py>,
    typical_price: PyReadonlyArray1<'py, f64>,
    volume: PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let tp = typical_price.as_slice()?;
    let vol = volume.as_slice()?;
    let chosen = validate_kernel(kernel, false)?;

    let input = MfiInput::from_slices(
        tp,
        vol,
        MfiParams {
            period: Some(period),
        },
    );

    let values = match py.allow_threads(|| mfi_with_kernel(&input, chosen).map(|o| o.values)) {
        Ok(v) => v,
        Err(e) => return Err(PyValueError::new_err(e.to_string())),
    };

    Ok(values.into_pyarray(py))
}
1181
// Python-facing wrapper, exported to Python as `MfiStream`, around the native
// streaming MFI calculator.
#[cfg(feature = "python")]
#[pyclass(name = "MfiStream")]
pub struct MfiStreamPy {
    // Native stream state; the Python methods delegate to it.
    inner: MfiStream,
}
1187
// Python handle for an f32 device matrix produced by the MFI CUDA bindings.
// The buffer is exposed through `__cuda_array_interface__` and `__dlpack__`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct MfiDeviceArrayF32Py {
    // Device buffer; becomes `None` once `__dlpack__` has taken ownership of it.
    pub(crate) inner: Option<DeviceArrayF32>,
    // Not read by the methods defined alongside this type; presumably retained so
    // the CUDA context outlives the device allocation. NOTE(review): confirm.
    pub(crate) ctx: Arc<Context>,
    // CUDA device ordinal, reported by `__dlpack_device__`.
    pub(crate) device_id: i32,
}
1195
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl MfiDeviceArrayF32Py {
    /// CUDA Array Interface (version 3) description of the row-major f32 matrix.
    ///
    /// Raises `ValueError` if the buffer was already handed off via `__dlpack__`.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let arr = match self.inner.as_ref() {
            Some(a) => a,
            None => {
                return Err(PyValueError::new_err(
                    "buffer already exported via __dlpack__",
                ))
            }
        };
        let elem = std::mem::size_of::<f32>();
        let iface = PyDict::new(py);
        iface.set_item("shape", (arr.rows, arr.cols))?;
        iface.set_item("typestr", "<f4")?;
        // Row-major strides in bytes: one row spans `cols` elements.
        iface.set_item("strides", (arr.cols * elem, elem))?;
        // `(device pointer, read_only)`; the buffer is reported as writable.
        iface.set_item("data", (arr.device_ptr() as usize, false))?;
        iface.set_item("version", 3)?;
        Ok(iface)
    }

    /// DLPack device tuple `(device_type, device_id)`; the type is always 2.
    fn __dlpack_device__(&self) -> (i32, i32) {
        const DEVICE_TYPE: i32 = 2;
        (DEVICE_TYPE, self.device_id)
    }

    /// Hands the device buffer to a DLPack consumer; afterwards this handle no
    /// longer owns it.
    ///
    /// If `dl_device` is a `(type, id)` tuple that differs from this buffer's
    /// device, `ValueError` is raised (cross-device copies are not implemented).
    /// `stream` is accepted but ignored. A second export raises `ValueError`.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        let own_device = self.__dlpack_device__();
        // Only a well-formed `(i32, i32)` request is checked; anything else is ignored.
        let requested = dl_device
            .as_ref()
            .and_then(|obj| obj.extract::<(i32, i32)>(py).ok());
        if let Some(target) = requested {
            if target != own_device {
                let copy_requested = matches!(
                    copy.as_ref().map(|c| c.extract::<bool>(py)),
                    Some(Ok(true))
                );
                let msg = if copy_requested {
                    "device copy not implemented for __dlpack__"
                } else {
                    "dl_device mismatch for __dlpack__"
                };
                return Err(PyValueError::new_err(msg));
            }
        }
        // NOTE(review): the consumer's stream is not synchronised against here.
        let _ = stream;

        let arr = match self.inner.take() {
            Some(a) => a,
            None => {
                return Err(PyValueError::new_err(
                    "buffer already exported via __dlpack__",
                ))
            }
        };
        let (rows, cols, buf) = (arr.rows, arr.cols, arr.buf);
        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, own_device.1, max_version_bound)
    }
}
1267
#[cfg(feature = "python")]
#[pymethods]
impl MfiStreamPy {
    /// Creates a stream for the given `period`.
    ///
    /// Raises `ValueError` when the native stream rejects the parameters.
    #[new]
    pub fn new(period: usize) -> PyResult<Self> {
        MfiStream::try_new(MfiParams {
            period: Some(period),
        })
        .map(|inner| Self { inner })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feeds one `(typical_price, volume)` sample into the stream and returns
    /// whatever value the native stream produces for it, if any.
    pub fn update(&mut self, typical_price: f64, volume: f64) -> Option<f64> {
        let stream = &mut self.inner;
        stream.update(typical_price, volume)
    }
}
1284
/// Python binding: batch MFI over a sweep of periods.
///
/// `period_range` is forwarded as `MfiBatchRange::period` (presumably
/// `(start, end, step)`). Returns a dict with two entries:
/// - `"values"`: a `(rows, cols)` float64 array, one row per expanded period
///   and one column per input sample.
/// - `"periods"`: the period used for each row.
///
/// Raises `ValueError` when the inputs differ in length, when `rows * cols`
/// overflows, or when the batch computation fails. Errors from
/// `validate_kernel` propagate unchanged.
#[cfg(feature = "python")]
#[pyfunction(name = "mfi_batch")]
#[pyo3(signature = (typical_price, volume, period_range, kernel=None))]
pub fn mfi_batch_py<'py>(
    py: Python<'py>,
    typical_price: PyReadonlyArray1<'py, f64>,
    volume: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let tp = typical_price.as_slice()?;
    let vol = volume.as_slice()?;
    if tp.len() != vol.len() {
        return Err(PyValueError::new_err(
            "mfi_batch: typical_price and volume length mismatch",
        ));
    }

    let sweep = MfiBatchRange {
        period: period_range,
    };
    let kern = validate_kernel(kernel, true)?;

    let combos = expand_grid(&sweep);
    let rows = combos.len();
    let cols = tp.len();
    // An unchecked `rows * cols` wraps in release builds, which would allocate an
    // output that no longer matches the `(rows, cols)` reshape below.
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("mfi_batch: rows*cols overflow"))?;

    // The array starts uninitialised; `mfi_batch_inner_into` is assumed to write
    // every element. On error the array is dropped without being exposed.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let out_slice = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            let batch_kernel = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k => k,
            };
            // The batch routine is handed the per-row (non-batch) kernel variant.
            let row_kernel = match batch_kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                other => other,
            };
            mfi_batch_inner_into(tp, vol, &sweep, row_kernel, true, out_slice)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    Ok(dict)
}
1348
/// Python binding: batch MFI over a period sweep on a CUDA device.
///
/// Returns a device-resident f32 matrix handle. Raises `ValueError` when CUDA
/// is unavailable, the inputs differ in length, or the CUDA wrapper reports
/// an error. Device work runs with the GIL released.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "mfi_cuda_batch_dev")]
#[pyo3(signature = (typical_price, volume, period_range, device_id=0))]
pub fn mfi_cuda_batch_dev_py(
    py: Python<'_>,
    typical_price: PyReadonlyArray1<'_, f32>,
    volume: PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<MfiDeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let (tp, vol) = (typical_price.as_slice()?, volume.as_slice()?);
    if tp.len() != vol.len() {
        return Err(PyValueError::new_err("mismatched input lengths"));
    }
    let sweep = MfiBatchRange {
        period: period_range,
    };

    let computed = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaMfi::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let context = cuda.context_arc();
        let ordinal = cuda.device_id() as i32;
        let (matrix, _combos) = cuda
            .mfi_batch_dev(tp, vol, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((matrix, context, ordinal))
    });
    let (matrix, context, ordinal) = computed?;

    Ok(MfiDeviceArrayF32Py {
        inner: Some(matrix),
        ctx: context,
        device_id: ordinal,
    })
}
1385
/// Python binding: MFI with a single `period` across many series on a CUDA
/// device.
///
/// `typical_price_tm` and `volume_tm` are flattened matrices of exactly
/// `cols * rows` elements. The CUDA method name indicates a time-major layout
/// (presumably `rows` time steps by `cols` series; confirm against `CudaMfi`).
///
/// Raises `ValueError` when CUDA is unavailable, the inputs differ in length,
/// their length is not `cols * rows` (including when that product overflows),
/// or the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "mfi_cuda_many_series_one_param_dev")]
#[pyo3(signature = (typical_price_tm, volume_tm, cols, rows, period, device_id=0))]
pub fn mfi_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    typical_price_tm: PyReadonlyArray1<'_, f32>,
    volume_tm: PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    period: usize,
    device_id: usize,
) -> PyResult<MfiDeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let tp = typical_price_tm.as_slice()?;
    let vol = volume_tm.as_slice()?;
    if tp.len() != vol.len() {
        return Err(PyValueError::new_err("mismatched input lengths"));
    }
    // `checked_mul` so that a huge `cols * rows` cannot wrap, in release builds,
    // to a value that happens to equal the input length and slip past validation.
    if cols.checked_mul(rows) != Some(tp.len()) {
        return Err(PyValueError::new_err("unexpected matrix size"));
    }
    let (inner, ctx, dev_id) = py.allow_threads(|| {
        let cuda = CudaMfi::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id() as i32;
        let arr = cuda
            .mfi_many_series_one_param_time_major_dev(tp, vol, cols, rows, period)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, PyErr>((arr, ctx, dev_id))
    })?;
    Ok(MfiDeviceArrayF32Py {
        inner: Some(inner),
        ctx,
        device_id: dev_id,
    })
}
1424
1425#[inline]
1426pub fn mfi_into_slice(dst: &mut [f64], input: &MfiInput, kern: Kernel) -> Result<(), MfiError> {
1427    let (typical_price, volume, period, first_valid_idx, chosen) = mfi_prepare(input, kern)?;
1428
1429    if dst.len() != typical_price.len() {
1430        return Err(MfiError::OutputLengthMismatch {
1431            expected: typical_price.len(),
1432            got: dst.len(),
1433        });
1434    }
1435
1436    mfi_compute_into(typical_price, volume, period, first_valid_idx, chosen, dst);
1437
1438    let warmup_period = first_valid_idx + period - 1;
1439
1440    let nan_q = f64::from_bits(0x7ff8_0000_0000_0000);
1441    for v in &mut dst[..warmup_period] {
1442        *v = nan_q;
1443    }
1444
1445    Ok(())
1446}
1447
1448#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1449pub fn mfi_into(input: &MfiInput, out: &mut [f64]) -> Result<(), MfiError> {
1450    mfi_into_slice(out, input, Kernel::Auto)
1451}
1452
/// WASM binding: computes MFI with the best detected kernel and returns a new
/// vector. Errors are returned to JS as string `JsValue`s.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn mfi_js(typical_price: &[f64], volume: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let params = MfiParams {
        period: Some(period),
    };
    mfi_with_kernel(
        &MfiInput::from_slices(typical_price, volume, params),
        detect_best_kernel(),
    )
    .map(|output| output.values)
    .map_err(|e| JsValue::from_str(&e.to_string()))
}
1466
/// WASM binding: computes MFI from raw pointers into `out_ptr`.
///
/// All three pointers must be non-null and are assumed to address `len`
/// `f64`s (`out_ptr` writable). If the output region overlaps either input, in
/// whole or in part, the result is computed into a temporary and copied out,
/// so the inputs are never read while they are being overwritten.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn mfi_into(
    typical_price_ptr: *const f64,
    volume_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if typical_price_ptr.is_null() || volume_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // Byte extent of each `len`-element buffer. Comparing full ranges (not just
    // base pointers) catches partial overlap, which would otherwise create
    // overlapping `&[f64]` / `&mut [f64]` slices.
    let span = len.saturating_mul(std::mem::size_of::<f64>());
    let out_start = out_ptr as usize;
    let overlaps_out = |p: *const f64| {
        let start = p as usize;
        start < out_start.saturating_add(span) && out_start < start.saturating_add(span)
    };
    let aliased = overlaps_out(typical_price_ptr) || overlaps_out(volume_ptr);

    // SAFETY: assumes the JS caller passes pointers to `len` valid f64s. In the
    // aliased branch, `out` is only created after the inputs have been read.
    unsafe {
        let typical_price = std::slice::from_raw_parts(typical_price_ptr, len);
        let volume = std::slice::from_raw_parts(volume_ptr, len);
        let params = MfiParams {
            period: Some(period),
        };
        let input = MfiInput::from_slices(typical_price, volume, params);

        if aliased {
            let result = mfi_with_kernel(&input, detect_best_kernel())
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            out.copy_from_slice(&result.values);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            mfi_into_slice(out, &input, detect_best_kernel())
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
        Ok(())
    }
}
1501
/// WASM helper: allocates an uninitialised buffer with capacity for `len`
/// `f64`s and returns its pointer. The allocation is not freed here;
/// `mfi_free(ptr, len)` is the matching release function.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn mfi_alloc(len: usize) -> *mut f64 {
    let mut buf = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buf.as_mut_ptr()
}
1510
/// WASM helper: releases a buffer obtained from `mfi_alloc(len)`.
///
/// `len` must be the same value that was passed to `mfi_alloc`. A null
/// pointer is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn mfi_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` came from `Vec::<f64>::with_capacity(len)` in `mfi_alloc`, so
    // `(ptr, capacity = len)` describes that allocation. The length is 0 because
    // the contents may never have been initialised, and `from_raw_parts`
    // requires the first `length` elements to be initialised. `f64` has no
    // destructor, so only the deallocation matters.
    unsafe {
        drop(Vec::<f64>::from_raw_parts(ptr, 0, len));
    }
}
1520
/// Configuration object accepted by the JS `mfi_batch` binding, deserialised
/// from a JS value with `serde_wasm_bindgen`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct MfiBatchConfig {
    /// Period sweep forwarded as `MfiBatchRange::period`, in the same
    /// `(start, end, step)` order that `mfi_batch_into` uses.
    pub period_range: (usize, usize, usize),
}
1526
/// Result object serialised back to JS by the `mfi_batch` binding; each field
/// is copied from the batch computation's output.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct MfiBatchJsOutput {
    /// Flattened output values; presumably `rows * cols` entries, row-major
    /// (NOTE(review): confirm against `mfi_batch_inner`).
    pub values: Vec<f64>,
    /// Parameter set for each output row (presumably one per row).
    pub combos: Vec<MfiParams>,
    /// Number of output rows.
    pub rows: usize,
    /// Number of output columns (presumably the input length).
    pub cols: usize,
}
1535
/// WASM binding (exported to JS as `mfi_batch`): runs a batch MFI sweep
/// described by a `MfiBatchConfig` object and returns a serialised
/// `MfiBatchJsOutput`. Errors are returned to JS as string `JsValue`s.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = mfi_batch)]
pub fn mfi_batch_unified_js(
    typical_price: &[f64],
    volume: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let parsed: MfiBatchConfig = match serde_wasm_bindgen::from_value(config) {
        Ok(c) => c,
        Err(e) => return Err(JsValue::from_str(&format!("Invalid config: {}", e))),
    };

    let sweep = MfiBatchRange {
        period: parsed.period_range,
    };

    let batch =
        match mfi_batch_inner(typical_price, volume, &sweep, detect_best_kernel(), false) {
            Ok(b) => b,
            Err(e) => return Err(JsValue::from_str(&e.to_string())),
        };

    let payload = MfiBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };

    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
1563
/// WASM binding: batch MFI over a period sweep, written into `out_ptr`.
///
/// Inputs are assumed to address `len` `f64`s each. The output is assumed to
/// address `rows * len` writable `f64`s, where `rows` is the number of
/// expanded periods; that row count is returned. If the output region
/// overlaps either input, results are computed into scratch storage first and
/// then copied, so the inputs are never read while being overwritten.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn mfi_batch_into(
    typical_price_ptr: *const f64,
    volume_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if typical_price_ptr.is_null() || volume_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to mfi_batch_into"));
    }
    unsafe {
        let tp = std::slice::from_raw_parts(typical_price_ptr, len);
        let vol = std::slice::from_raw_parts(volume_ptr, len);

        let sweep = MfiBatchRange {
            period: (period_start, period_end, period_step),
        };
        let combos = expand_grid(&sweep);
        let rows = combos.len();
        let cols = len;
        let total = rows
            .checked_mul(cols)
            .ok_or_else(|| JsValue::from_str("mfi_batch_into: rows*cols overflow"))?;

        // Byte-range overlap test between the output and each input. Overlap
        // would otherwise create aliasing `&[f64]` / `&mut [f64]` slices.
        let elem = std::mem::size_of::<f64>();
        let in_span = len.saturating_mul(elem);
        let out_span = total.saturating_mul(elem);
        let out_start = out_ptr as usize;
        let overlaps_out = |p: *const f64| {
            let start = p as usize;
            start < out_start.saturating_add(out_span) && out_start < start.saturating_add(in_span)
        };

        if overlaps_out(typical_price_ptr) || overlaps_out(volume_ptr) {
            let mut scratch = vec![0.0f64; total];
            mfi_batch_inner_into(tp, vol, &sweep, detect_best_kernel(), false, &mut scratch)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            std::slice::from_raw_parts_mut(out_ptr, total).copy_from_slice(&scratch);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, total);
            mfi_batch_inner_into(tp, vol, &sweep, detect_best_kernel(), false, out)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }

        Ok(rows)
    }
}
1600
1601#[cfg(test)]
1602mod tests {
1603    use super::*;
1604    use crate::skip_if_unsupported;
1605    use crate::utilities::data_loader::read_candles_from_csv;
1606    use paste::paste;
1607    use std::error::Error;
1608
1609    #[test]
1610    fn test_mfi_into_matches_api() -> Result<(), Box<dyn Error>> {
1611        let n = 256usize;
1612        let mut tp = Vec::with_capacity(n);
1613        let mut vol = Vec::with_capacity(n);
1614        for i in 0..n {
1615            let i_f = i as f64;
1616
1617            let price = 100.0 + 0.123 * i_f + ((i % 7) as f64 - 3.0) * 0.05;
1618            tp.push(price);
1619
1620            vol.push(1_000.0 + ((i * 37) % 113) as f64);
1621        }
1622
1623        let input = MfiInput::from_slices(&tp, &vol, MfiParams::default());
1624
1625        let baseline = mfi(&input)?.values;
1626
1627        let mut out = vec![0.0f64; n];
1628        #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1629        {
1630            mfi_into(&input, &mut out)?;
1631        }
1632        #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1633        {
1634            mfi_into_slice(&mut out, &input, Kernel::Auto)?;
1635        }
1636
1637        assert_eq!(baseline.len(), out.len());
1638
1639        fn eq_or_both_nan(a: f64, b: f64) -> bool {
1640            (a.is_nan() && b.is_nan()) || (a == b)
1641        }
1642        for (i, (a, b)) in baseline.iter().zip(out.iter()).enumerate() {
1643            assert!(
1644                eq_or_both_nan(*a, *b),
1645                "mismatch at {}: baseline={} vs into={}",
1646                i,
1647                a,
1648                b
1649            );
1650        }
1651
1652        Ok(())
1653    }
1654
1655    fn check_mfi_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1656        skip_if_unsupported!(kernel, test_name);
1657        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1658        let candles = read_candles_from_csv(file_path)?;
1659        let default_params = MfiParams { period: None };
1660        let input = MfiInput::from_candles(&candles, "hlc3", default_params);
1661        let output = mfi_with_kernel(&input, kernel)?;
1662        assert_eq!(output.values.len(), candles.close.len());
1663        Ok(())
1664    }
1665
1666    fn check_mfi_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1667        skip_if_unsupported!(kernel, test_name);
1668        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1669        let candles = read_candles_from_csv(file_path)?;
1670        let params = MfiParams { period: Some(14) };
1671        let input = MfiInput::from_candles(&candles, "hlc3", params);
1672        let mfi_result = mfi_with_kernel(&input, kernel)?;
1673        let expected_last_five_mfi = [
1674            38.13874339324763,
1675            37.44139770113819,
1676            31.02039511395131,
1677            28.092605898618896,
1678            25.905204729397813,
1679        ];
1680        let start_index = mfi_result.values.len() - 5;
1681        for (i, &value) in mfi_result.values[start_index..].iter().enumerate() {
1682            let expected_value = expected_last_five_mfi[i];
1683            let diff = (value - expected_value).abs();
1684            assert!(
1685                diff < 1e-1,
1686                "MFI mismatch at index {}: expected {}, got {}",
1687                i,
1688                expected_value,
1689                value
1690            );
1691        }
1692        Ok(())
1693    }
1694
1695    fn check_mfi_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1696        skip_if_unsupported!(kernel, test_name);
1697        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1698        let candles = read_candles_from_csv(file_path)?;
1699        let input = MfiInput::with_default_candles(&candles);
1700        let output = mfi_with_kernel(&input, kernel)?;
1701        assert_eq!(output.values.len(), candles.close.len());
1702        Ok(())
1703    }
1704
1705    fn check_mfi_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1706        skip_if_unsupported!(kernel, test_name);
1707        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1708        let candles = read_candles_from_csv(file_path)?;
1709        let params = MfiParams { period: Some(0) };
1710        let input = MfiInput::from_candles(&candles, "hlc3", params);
1711        let result = mfi_with_kernel(&input, kernel);
1712        assert!(result.is_err());
1713        Ok(())
1714    }
1715
1716    fn check_mfi_period_exceeds_length(
1717        test_name: &str,
1718        kernel: Kernel,
1719    ) -> Result<(), Box<dyn Error>> {
1720        skip_if_unsupported!(kernel, test_name);
1721        let input_high = [1.0, 2.0, 3.0];
1722        let input_low = [0.5, 1.5, 2.5];
1723        let input_close = [0.8, 1.8, 2.8];
1724        let input_volume = [100.0, 200.0, 300.0];
1725
1726        let typical_price: Vec<f64> = input_high
1727            .iter()
1728            .zip(&input_low)
1729            .zip(&input_close)
1730            .map(|((h, l), c)| (h + l + c) / 3.0)
1731            .collect();
1732        let params = MfiParams { period: Some(10) };
1733        let input = MfiInput::from_slices(&typical_price, &input_volume, params);
1734        let result = mfi_with_kernel(&input, kernel);
1735        assert!(result.is_err());
1736        Ok(())
1737    }
1738
1739    fn check_mfi_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1740        skip_if_unsupported!(kernel, test_name);
1741        let input_high = [1.0];
1742        let input_low = [0.5];
1743        let input_close = [0.8];
1744        let input_volume = [100.0];
1745
1746        let typical_price = [(input_high[0] + input_low[0] + input_close[0]) / 3.0];
1747        let params = MfiParams { period: Some(14) };
1748        let input = MfiInput::from_slices(&typical_price, &input_volume, params);
1749        let result = mfi_with_kernel(&input, kernel);
1750        assert!(result.is_err());
1751        Ok(())
1752    }
1753
1754    fn check_mfi_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1755        skip_if_unsupported!(kernel, test_name);
1756        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1757        let candles = read_candles_from_csv(file_path)?;
1758        let first_params = MfiParams { period: Some(7) };
1759        let first_input = MfiInput::from_candles(&candles, "hlc3", first_params);
1760        let first_result = mfi_with_kernel(&first_input, kernel)?;
1761        let second_params = MfiParams { period: Some(7) };
1762
1763        let typical_price_values: Vec<f64> = first_result.values.clone();
1764        let volume_values: Vec<f64> = vec![10_000.0; first_result.values.len()];
1765        let second_input =
1766            MfiInput::from_slices(&typical_price_values, &volume_values, second_params);
1767        let second_result = mfi_with_kernel(&second_input, kernel)?;
1768        assert_eq!(second_result.values.len(), first_result.values.len());
1769        Ok(())
1770    }
1771
1772    #[cfg(debug_assertions)]
1773    fn check_mfi_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1774        skip_if_unsupported!(kernel, test_name);
1775
1776        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1777        let candles = read_candles_from_csv(file_path)?;
1778
1779        let test_params = vec![
1780            MfiParams::default(),
1781            MfiParams { period: Some(2) },
1782            MfiParams { period: Some(5) },
1783            MfiParams { period: Some(7) },
1784            MfiParams { period: Some(10) },
1785            MfiParams { period: Some(14) },
1786            MfiParams { period: Some(20) },
1787            MfiParams { period: Some(30) },
1788            MfiParams { period: Some(50) },
1789            MfiParams { period: Some(100) },
1790            MfiParams { period: Some(200) },
1791        ];
1792
1793        for (param_idx, params) in test_params.iter().enumerate() {
1794            let input = MfiInput::from_candles(&candles, "hlc3", params.clone());
1795            let output = mfi_with_kernel(&input, kernel)?;
1796
1797            for (i, &val) in output.values.iter().enumerate() {
1798                if val.is_nan() {
1799                    continue;
1800                }
1801
1802                let bits = val.to_bits();
1803
1804                if bits == 0x11111111_11111111 {
1805                    panic!(
1806                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1807						with params: period={} (param set {})",
1808                        test_name,
1809                        val,
1810                        bits,
1811                        i,
1812                        params.period.unwrap_or(14),
1813                        param_idx
1814                    );
1815                }
1816
1817                if bits == 0x22222222_22222222 {
1818                    panic!(
1819                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1820						with params: period={} (param set {})",
1821                        test_name,
1822                        val,
1823                        bits,
1824                        i,
1825                        params.period.unwrap_or(14),
1826                        param_idx
1827                    );
1828                }
1829
1830                if bits == 0x33333333_33333333 {
1831                    panic!(
1832                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1833						with params: period={} (param set {})",
1834                        test_name,
1835                        val,
1836                        bits,
1837                        i,
1838                        params.period.unwrap_or(14),
1839                        param_idx
1840                    );
1841                }
1842            }
1843        }
1844
1845        Ok(())
1846    }
1847
    // Release-build stand-in for the poison scan, which is only compiled with
    // `debug_assertions`. Presumably the allocation helpers only write poison
    // patterns in debug builds (NOTE(review): confirm). Keeping the same name
    // lets callers compile in both build modes.
    #[cfg(not(debug_assertions))]
    fn check_mfi_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
1852
1853    #[cfg(feature = "proptest")]
1854    #[allow(clippy::float_cmp)]
1855    fn check_mfi_property(
1856        test_name: &str,
1857        kernel: Kernel,
1858    ) -> Result<(), Box<dyn std::error::Error>> {
1859        use proptest::prelude::*;
1860        skip_if_unsupported!(kernel, test_name);
1861
1862        let strat = (2usize..=50).prop_flat_map(|period| {
1863            (period..=400).prop_flat_map(move |data_len| {
1864                prop_oneof![
1865
1866                    6 => (
1867
1868                        (10.0f64..10000.0f64),
1869
1870                        (0.01f64..0.2f64),
1871
1872                        (1000.0f64..1_000_000.0f64),
1873
1874                        prop::collection::vec(-1.0f64..1.0f64, data_len),
1875
1876                        prop::collection::vec(0.0f64..1.0f64, data_len),
1877                    ).prop_map(move |(base_price, volatility, volume_mult, changes, vol_factors)| {
1878                        let mut typical_price = Vec::with_capacity(data_len);
1879                        let mut volume = Vec::with_capacity(data_len);
1880                        let mut price = base_price;
1881
1882                        for i in 0..data_len {
1883
1884                            let change = changes[i] * volatility;
1885                            price *= 1.0 + change;
1886                            price = price.max(0.01);
1887                            typical_price.push(price);
1888
1889
1890                            let vol = volume_mult * (0.5 + vol_factors[i] + change.abs() * 2.0);
1891                            volume.push(vol.max(0.0));
1892                        }
1893
1894                        (typical_price, volume, period)
1895                    }),
1896
1897
1898                    15 => prop::collection::vec(100.0f64..1000.0f64, 1..=1)
1899                        .prop_map(move |prices| {
1900                            let price = prices[0];
1901                            let typical_price = vec![price; data_len];
1902                            let volume = vec![10000.0; data_len];
1903                            (typical_price, volume, period)
1904                        }),
1905
1906
1907                    15 => prop::bool::ANY.prop_map(move |uptrend| {
1908                        let mut typical_price = Vec::with_capacity(data_len);
1909                        let mut volume = Vec::with_capacity(data_len);
1910                        let start_price = 100.0;
1911
1912                        for i in 0..data_len {
1913                            let trend_factor = if uptrend {
1914                                1.0 + (i as f64 / data_len as f64) * 2.0
1915                            } else {
1916                                1.0 - (i as f64 / data_len as f64) * 0.7
1917                            };
1918                            typical_price.push(start_price * trend_factor);
1919
1920                            volume.push(10000.0 * (1.0 + i as f64 / data_len as f64) * 2.0);
1921                        }
1922
1923                        (typical_price, volume, period)
1924                    }),
1925
1926
1927                    1 => Just((
1928                        (0..data_len).map(|i| 100.0 + (i as f64)).collect::<Vec<_>>(),
1929                        vec![0.0; data_len],
1930                        period
1931                    )),
1932                ]
1933            })
1934        });
1935
1936        proptest::test_runner::TestRunner::default().run(
1937            &strat,
1938            |(typical_price, volume, period)| {
1939                let params = MfiParams {
1940                    period: Some(period),
1941                };
1942                let input = MfiInput::from_slices(&typical_price, &volume, params.clone());
1943
1944                let MfiOutput { values: out } = mfi_with_kernel(&input, kernel)?;
1945
1946                let MfiOutput { values: ref_out } = mfi_with_kernel(&input, Kernel::Scalar)?;
1947
1948                prop_assert_eq!(out.len(), typical_price.len(), "Output length mismatch");
1949
1950                let first_valid_idx = (0..typical_price.len())
1951                    .find(|&i| !typical_price[i].is_nan() && !volume[i].is_nan())
1952                    .unwrap_or(0);
1953
1954                let expected_warmup = first_valid_idx + period - 1;
1955
1956                for i in 0..out.len() {
1957                    if i < expected_warmup {
1958                        prop_assert!(
1959                            out[i].is_nan(),
1960                            "Expected NaN during warmup at index {}, got {}",
1961                            i,
1962                            out[i]
1963                        );
1964                    } else if i == expected_warmup {
1965                        prop_assert!(
1966                            !out[i].is_nan(),
1967                            "Expected first non-NaN at index {} but got NaN",
1968                            i
1969                        );
1970                    }
1971                }
1972
1973                for (i, &val) in out.iter().enumerate().skip(expected_warmup) {
1974                    if !val.is_nan() {
1975                        prop_assert!(
1976                            val >= 0.0 && val <= 100.0,
1977                            "MFI out of bounds at index {}: {}",
1978                            i,
1979                            val
1980                        );
1981                    }
1982                }
1983
1984                let is_constant = typical_price
1985                    .windows(2)
1986                    .all(|w| (w[0] - w[1]).abs() < 1e-10);
1987                if is_constant && expected_warmup < out.len() {
1988                    for i in expected_warmup..out.len() {
1989                        if !out[i].is_nan() {
1990                            prop_assert!(
1991                                out[i].abs() < 1e-3,
1992                                "Constant price MFI should be ~0, got {} at index {}",
1993                                out[i],
1994                                i
1995                            );
1996                        }
1997                    }
1998                }
1999
2000                let all_zero_volume = volume.iter().all(|&v| v.abs() < 1e-14);
2001                if all_zero_volume && expected_warmup < out.len() {
2002                    for i in expected_warmup..out.len() {
2003                        if !out[i].is_nan() {
2004                            prop_assert!(
2005                                out[i].abs() < 1e-3,
2006                                "Zero volume MFI should be 0, got {} at index {}",
2007                                out[i],
2008                                i
2009                            );
2010                        }
2011                    }
2012                }
2013
2014                if expected_warmup + period < typical_price.len() {
2015                    let check_idx = expected_warmup + period;
2016
2017                    let window_start = check_idx - period + 1;
2018                    let window_end = check_idx;
2019
2020                    let mut up_volume = 0.0;
2021                    let mut down_volume = 0.0;
2022
2023                    for i in window_start..window_end {
2024                        if i > 0 && i < typical_price.len() {
2025                            let price_change = typical_price[i] - typical_price[i - 1];
2026                            if price_change > 0.0 {
2027                                up_volume += volume[i] * typical_price[i];
2028                            } else if price_change < 0.0 {
2029                                down_volume += volume[i] * typical_price[i];
2030                            }
2031                        }
2032                    }
2033
2034                    if up_volume > down_volume * 2.0 && check_idx < out.len() {
2035                        let mfi_val = out[check_idx];
2036                        if !mfi_val.is_nan() && (up_volume + down_volume) > 1e-10 {
2037                            prop_assert!(
2038								mfi_val > 50.0,
2039								"MFI should be > 50 when up money flow dominates (up: {}, down: {}), got {}",
2040								up_volume,
2041								down_volume,
2042								mfi_val
2043							);
2044                        }
2045                    }
2046
2047                    if down_volume > up_volume * 2.0 && check_idx < out.len() {
2048                        let mfi_val = out[check_idx];
2049                        if !mfi_val.is_nan() && (up_volume + down_volume) > 1e-10 {
2050                            prop_assert!(
2051								mfi_val < 50.0,
2052								"MFI should be < 50 when down money flow dominates (up: {}, down: {}), got {}",
2053								up_volume,
2054								down_volume,
2055								mfi_val
2056							);
2057                        }
2058                    }
2059                }
2060
2061                if expected_warmup + 5 < typical_price.len() {
2062                    let verify_idx = expected_warmup + 5;
2063
2064                    let mut pos_sum = 0.0;
2065                    let mut neg_sum = 0.0;
2066
2067                    let window_start = verify_idx - period + 1;
2068
2069                    for i in window_start..=verify_idx {
2070                        if i > 0 && i < typical_price.len() {
2071                            let price_diff = typical_price[i] - typical_price[i - 1];
2072                            let money_flow = typical_price[i] * volume[i];
2073
2074                            if price_diff > 0.0 {
2075                                pos_sum += money_flow;
2076                            } else if price_diff < 0.0 {
2077                                neg_sum += money_flow;
2078                            }
2079                        }
2080                    }
2081
2082                    let total = pos_sum + neg_sum;
2083                    let expected_mfi = if total < 1e-14 {
2084                        0.0
2085                    } else {
2086                        100.0 * (pos_sum / total)
2087                    };
2088
2089                    let actual_mfi = out[verify_idx];
2090                    if !actual_mfi.is_nan() {
2091                        prop_assert!(
2092							(actual_mfi - expected_mfi).abs() < 0.1,
2093							"MFI formula verification failed at index {}: expected {} (pos: {}, neg: {}), got {}",
2094							verify_idx,
2095							expected_mfi,
2096							pos_sum,
2097							neg_sum,
2098							actual_mfi
2099						);
2100                    }
2101                }
2102
2103                if period >= 5 && period <= 20 {
2104                    let test_len = period * 3;
2105                    let mut prices = Vec::with_capacity(test_len);
2106                    let mut increasing_vol = Vec::with_capacity(test_len);
2107                    let mut decreasing_vol = Vec::with_capacity(test_len);
2108
2109                    for i in 0..test_len {
2110                        prices.push(100.0 + i as f64);
2111
2112                        increasing_vol.push(1000.0 * (1.0 + i as f64));
2113
2114                        decreasing_vol.push(1000.0 * (test_len as f64 - i as f64));
2115                    }
2116
2117                    let input_inc = MfiInput::from_slices(&prices, &increasing_vol, params.clone());
2118                    let MfiOutput { values: out_inc } = mfi_with_kernel(&input_inc, kernel)?;
2119
2120                    let input_dec = MfiInput::from_slices(&prices, &decreasing_vol, params.clone());
2121                    let MfiOutput { values: out_dec } = mfi_with_kernel(&input_dec, kernel)?;
2122
2123                    let check_idx = period * 2;
2124                    if check_idx < out_inc.len() {
2125                        let mfi_inc = out_inc[check_idx];
2126                        let mfi_dec = out_dec[check_idx];
2127
2128                        if !mfi_inc.is_nan() && !mfi_dec.is_nan() {
2129                            prop_assert!(
2130                                mfi_inc > 90.0,
2131                                "MFI with increasing volume on uptrend should be > 90, got {}",
2132                                mfi_inc
2133                            );
2134                            prop_assert!(
2135                                mfi_dec > 90.0,
2136                                "MFI with decreasing volume on uptrend should be > 90, got {}",
2137                                mfi_dec
2138                            );
2139
2140                            prop_assert!(
2141								mfi_inc > mfi_dec,
2142								"MFI with increasing volume ({}) should be > MFI with decreasing volume ({}) on uptrend",
2143								mfi_inc,
2144								mfi_dec
2145							);
2146                        }
2147                    }
2148                }
2149
2150                for i in 0..out.len() {
2151                    let y = out[i];
2152                    let r = ref_out[i];
2153
2154                    if y.is_nan() || r.is_nan() {
2155                        prop_assert_eq!(
2156                            y.is_nan(),
2157                            r.is_nan(),
2158                            "NaN mismatch at index {}: kernel={}, scalar={}",
2159                            i,
2160                            y,
2161                            r
2162                        );
2163                        continue;
2164                    }
2165
2166                    let y_bits = y.to_bits();
2167                    let r_bits = r.to_bits();
2168                    let ulp_diff = y_bits.abs_diff(r_bits);
2169
2170                    prop_assert!(
2171                        (y - r).abs() <= 1e-9 || ulp_diff <= 5,
2172                        "Kernel mismatch at index {}: {} vs {} (ULP={})",
2173                        i,
2174                        y,
2175                        r,
2176                        ulp_diff
2177                    );
2178                }
2179
2180                Ok(())
2181            },
2182        )?;
2183
2184        Ok(())
2185    }
2186
2187    macro_rules! generate_all_mfi_tests {
2188        ($($test_fn:ident),*) => {
2189            paste! {
2190                $(
2191                    #[test]
2192                    fn [<$test_fn _scalar_f64>]() {
2193                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
2194                    }
2195                )*
2196                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2197                $(
2198                    #[test]
2199                    fn [<$test_fn _avx2_f64>]() {
2200                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
2201                    }
2202                    #[test]
2203                    fn [<$test_fn _avx512_f64>]() {
2204                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
2205                    }
2206                )*
2207            }
2208        }
2209    }
2210    generate_all_mfi_tests!(
2211        check_mfi_partial_params,
2212        check_mfi_accuracy,
2213        check_mfi_default_candles,
2214        check_mfi_zero_period,
2215        check_mfi_period_exceeds_length,
2216        check_mfi_very_small_dataset,
2217        check_mfi_reinput,
2218        check_mfi_no_poison
2219    );
2220
2221    #[cfg(feature = "proptest")]
2222    generate_all_mfi_tests!(check_mfi_property);
2223    fn check_batch_default_row(
2224        test: &str,
2225        kernel: Kernel,
2226    ) -> Result<(), Box<dyn std::error::Error>> {
2227        skip_if_unsupported!(kernel, test);
2228        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2229        let c = read_candles_from_csv(file)?;
2230
2231        let output = MfiBatchBuilder::new().kernel(kernel).apply_candles(&c)?;
2232
2233        let def = MfiParams::default();
2234        let row = output.values_for(&def).expect("default row missing");
2235
2236        assert_eq!(row.len(), c.close.len());
2237
2238        let expected = [
2239            38.13874339324763,
2240            37.44139770113819,
2241            31.02039511395131,
2242            28.092605898618896,
2243            25.905204729397813,
2244        ];
2245        let start = row.len().saturating_sub(5);
2246        for (i, &v) in row[start..].iter().enumerate() {
2247            assert!(
2248                (v - expected[i]).abs() < 1e-1,
2249                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
2250            );
2251        }
2252        Ok(())
2253    }
2254
2255    #[cfg(debug_assertions)]
2256    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2257        skip_if_unsupported!(kernel, test);
2258
2259        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2260        let c = read_candles_from_csv(file)?;
2261
2262        let test_configs = vec![
2263            (2, 10, 2),
2264            (5, 25, 5),
2265            (30, 60, 15),
2266            (2, 5, 1),
2267            (10, 50, 10),
2268            (7, 21, 7),
2269            (14, 14, 0),
2270        ];
2271
2272        for (cfg_idx, &(p_start, p_end, p_step)) in test_configs.iter().enumerate() {
2273            let output = MfiBatchBuilder::new()
2274                .kernel(kernel)
2275                .period_range(p_start, p_end, p_step)
2276                .apply_candles(&c)?;
2277
2278            for (idx, &val) in output.values.iter().enumerate() {
2279                if val.is_nan() {
2280                    continue;
2281                }
2282
2283                let bits = val.to_bits();
2284                let row = idx / output.cols;
2285                let col = idx % output.cols;
2286                let combo = &output.combos[row];
2287
2288                if bits == 0x11111111_11111111 {
2289                    panic!(
2290                        "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
2291						at row {} col {} (flat index {}) with params: period={}",
2292                        test,
2293                        cfg_idx,
2294                        val,
2295                        bits,
2296                        row,
2297                        col,
2298                        idx,
2299                        combo.period.unwrap_or(14)
2300                    );
2301                }
2302
2303                if bits == 0x22222222_22222222 {
2304                    panic!(
2305                        "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
2306						at row {} col {} (flat index {}) with params: period={}",
2307                        test,
2308                        cfg_idx,
2309                        val,
2310                        bits,
2311                        row,
2312                        col,
2313                        idx,
2314                        combo.period.unwrap_or(14)
2315                    );
2316                }
2317
2318                if bits == 0x33333333_33333333 {
2319                    panic!(
2320                        "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
2321						at row {} col {} (flat index {}) with params: period={}",
2322                        test,
2323                        cfg_idx,
2324                        val,
2325                        bits,
2326                        row,
2327                        col,
2328                        idx,
2329                        combo.period.unwrap_or(14)
2330                    );
2331                }
2332            }
2333        }
2334
2335        Ok(())
2336    }
2337
2338    #[cfg(not(debug_assertions))]
2339    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2340        Ok(())
2341    }
2342
2343    macro_rules! gen_batch_tests {
2344        ($fn_name:ident) => {
2345            paste! {
2346                #[test] fn [<$fn_name _scalar>]()      {
2347                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
2348                }
2349                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2350                #[test] fn [<$fn_name _avx2>]()        {
2351                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
2352                }
2353                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2354                #[test] fn [<$fn_name _avx512>]()      {
2355                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
2356                }
2357                #[test] fn [<$fn_name _auto_detect>]() {
2358                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
2359                }
2360            }
2361        };
2362    }
2363
2364    gen_batch_tests!(check_batch_default_row);
2365    gen_batch_tests!(check_batch_no_poison);
2366}