// vector_ta/indicators/fisher.rs

1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::PyDict;
9
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use serde::{Deserialize, Serialize};
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use wasm_bindgen::prelude::*;
14
15use crate::utilities::data_loader::{source_type, Candles};
16use crate::utilities::enums::Kernel;
17use crate::utilities::helpers::{
18    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
19    make_uninit_matrix,
20};
21#[cfg(feature = "python")]
22use crate::utilities::kernel_validation::validate_kernel;
23
24#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
25use core::arch::x86_64::*;
26#[cfg(not(target_arch = "wasm32"))]
27use rayon::prelude::*;
28use std::collections::VecDeque;
29use std::convert::AsRef;
30use std::error::Error;
31use std::mem::MaybeUninit;
32use thiserror::Error;
33
34#[cfg(all(feature = "python", feature = "cuda"))]
35use crate::cuda::cuda_available;
36#[cfg(all(feature = "python", feature = "cuda"))]
37use crate::cuda::moving_averages::DeviceArrayF32;
38#[cfg(all(feature = "python", feature = "cuda"))]
39use crate::cuda::oscillators::CudaFisher;
40#[cfg(all(feature = "python", feature = "cuda"))]
41use crate::utilities::dlpack_cuda::DeviceArrayF32Py;
42#[cfg(all(feature = "python", feature = "cuda"))]
43use cust::context::Context;
44#[cfg(all(feature = "python", feature = "cuda"))]
45use std::sync::Arc;
46
47impl<'a> FisherInput<'a> {
48    #[inline(always)]
49    pub fn as_ref(&self) -> (&'a [f64], &'a [f64]) {
50        match &self.data {
51            FisherData::Candles { candles } => {
52                (source_type(candles, "high"), source_type(candles, "low"))
53            }
54            FisherData::Slices { high, low } => (*high, *low),
55        }
56    }
57}
58
59#[derive(Debug, Clone)]
60pub enum FisherData<'a> {
61    Candles { candles: &'a Candles },
62    Slices { high: &'a [f64], low: &'a [f64] },
63}
64
/// Result of a single-series Fisher Transform run.
///
/// Both vectors hold one entry per input bar.
#[derive(Debug, Clone)]
pub struct FisherOutput {
    /// Fisher Transform values.
    pub fisher: Vec<f64>,
    /// Signal line: at each computed bar, the Fisher value of the previous bar
    /// (0.0 on the first computed bar).
    pub signal: Vec<f64>,
}
70
/// Tunable parameters of the Fisher Transform.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct FisherParams {
    /// Rolling window length for the midpoint min/max; `None` is treated as 9
    /// by the consumers in this module.
    pub period: Option<usize>,
}

impl Default for FisherParams {
    /// A period of 9.
    fn default() -> Self {
        let period = Some(9);
        FisherParams { period }
    }
}
85
86#[derive(Debug, Clone)]
87pub struct FisherInput<'a> {
88    pub data: FisherData<'a>,
89    pub params: FisherParams,
90}
91
92impl<'a> FisherInput<'a> {
93    #[inline]
94    pub fn from_candles(candles: &'a Candles, params: FisherParams) -> Self {
95        Self {
96            data: FisherData::Candles { candles },
97            params,
98        }
99    }
100
101    #[inline(always)]
102    pub fn get_high_low(&self) -> (&'a [f64], &'a [f64]) {
103        match &self.data {
104            FisherData::Candles { candles } => {
105                let high = source_type(candles, "high");
106                let low = source_type(candles, "low");
107                (high, low)
108            }
109            FisherData::Slices { high, low } => (*high, *low),
110        }
111    }
112    #[inline]
113    pub fn from_slices(high: &'a [f64], low: &'a [f64], params: FisherParams) -> Self {
114        Self {
115            data: FisherData::Slices { high, low },
116            params,
117        }
118    }
119    #[inline]
120    pub fn with_default_candles(candles: &'a Candles) -> Self {
121        Self::from_candles(candles, FisherParams::default())
122    }
123    #[inline]
124    pub fn get_period(&self) -> usize {
125        self.params.period.unwrap_or(9)
126    }
127}
128
129#[derive(Copy, Clone, Debug)]
130pub struct FisherBuilder {
131    period: Option<usize>,
132    kernel: Kernel,
133}
134
135impl Default for FisherBuilder {
136    fn default() -> Self {
137        Self {
138            period: None,
139            kernel: Kernel::Auto,
140        }
141    }
142}
143
144impl FisherBuilder {
145    #[inline(always)]
146    pub fn new() -> Self {
147        Self::default()
148    }
149    #[inline(always)]
150    pub fn period(mut self, n: usize) -> Self {
151        self.period = Some(n);
152        self
153    }
154    #[inline(always)]
155    pub fn kernel(mut self, k: Kernel) -> Self {
156        self.kernel = k;
157        self
158    }
159    #[inline(always)]
160    pub fn apply(self, c: &Candles) -> Result<FisherOutput, FisherError> {
161        let p = FisherParams {
162            period: self.period,
163        };
164        let i = FisherInput::from_candles(c, p);
165        fisher_with_kernel(&i, self.kernel)
166    }
167    #[inline(always)]
168    pub fn apply_slices(self, high: &[f64], low: &[f64]) -> Result<FisherOutput, FisherError> {
169        let p = FisherParams {
170            period: self.period,
171        };
172        let i = FisherInput::from_slices(high, low, p);
173        fisher_with_kernel(&i, self.kernel)
174    }
175    #[inline(always)]
176    pub fn into_stream(self) -> Result<FisherStream, FisherError> {
177        let p = FisherParams {
178            period: self.period,
179        };
180        FisherStream::try_new(p)
181    }
182}
183
184#[derive(Debug, Error)]
185pub enum FisherError {
186    #[error("fisher: Empty data provided.")]
187    EmptyData,
188
189    #[error("fisher: Empty input data.")]
190    EmptyInputData,
191    #[error("fisher: Invalid period: period = {period}, data length = {data_len}")]
192    InvalidPeriod { period: usize, data_len: usize },
193    #[error("fisher: Not enough valid data: needed = {needed}, valid = {valid}")]
194    NotEnoughValidData { needed: usize, valid: usize },
195    #[error("fisher: All values are NaN.")]
196    AllValuesNaN,
197
198    #[error("fisher: Invalid output length: expected = {expected}, actual = {actual}")]
199    InvalidLength { expected: usize, actual: usize },
200
201    #[error("fisher: Output length mismatch: expected = {expected}, got = {got}")]
202    OutputLengthMismatch { expected: usize, got: usize },
203    #[error("fisher: Mismatched data length: high={high}, low={low}")]
204    MismatchedDataLength { high: usize, low: usize },
205
206    #[error("fisher: Invalid range expansion: start={start}, end={end}, step={step}")]
207    InvalidRange {
208        start: usize,
209        end: usize,
210        step: usize,
211    },
212    #[error("fisher: Invalid kernel for batch path: {0:?}")]
213    InvalidKernelForBatch(crate::utilities::enums::Kernel),
214}
215
216#[inline]
217pub fn fisher(input: &FisherInput) -> Result<FisherOutput, FisherError> {
218    fisher_with_kernel(input, Kernel::Auto)
219}
220
221pub fn fisher_with_kernel(
222    input: &FisherInput,
223    kernel: Kernel,
224) -> Result<FisherOutput, FisherError> {
225    let (high, low) = input.get_high_low();
226
227    if high.is_empty() || low.is_empty() {
228        return Err(FisherError::EmptyInputData);
229    }
230    if high.len() != low.len() {
231        return Err(FisherError::MismatchedDataLength {
232            high: high.len(),
233            low: low.len(),
234        });
235    }
236
237    let period = input.get_period();
238    let data_len = high.len();
239    if period == 0 || period > data_len {
240        return Err(FisherError::InvalidPeriod { period, data_len });
241    }
242
243    let mut first = None;
244    for i in 0..data_len {
245        if !high[i].is_nan() && !low[i].is_nan() {
246            first = Some(i);
247            break;
248        }
249    }
250    let first = first.ok_or(FisherError::AllValuesNaN)?;
251
252    if (data_len - first) < period {
253        return Err(FisherError::NotEnoughValidData {
254            needed: period,
255            valid: data_len - first,
256        });
257    }
258
259    let chosen = match kernel {
260        Kernel::Auto => detect_best_kernel(),
261        other => other,
262    };
263
264    let warmup = first + period - 1;
265    let mut fisher_vals = alloc_with_nan_prefix(data_len, warmup);
266    let mut signal_vals = alloc_with_nan_prefix(data_len, warmup);
267
268    unsafe {
269        match chosen {
270            Kernel::Scalar | Kernel::ScalarBatch => {
271                fisher_scalar_into(high, low, period, first, &mut fisher_vals, &mut signal_vals)
272            }
273            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
274            Kernel::Avx2 | Kernel::Avx2Batch => {
275                fisher_avx2_into(high, low, period, first, &mut fisher_vals, &mut signal_vals)
276            }
277            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
278            Kernel::Avx512 | Kernel::Avx512Batch => {
279                fisher_avx512_into(high, low, period, first, &mut fisher_vals, &mut signal_vals)
280            }
281            _ => unreachable!(),
282        }
283    }
284
285    Ok(FisherOutput {
286        fisher: fisher_vals,
287        signal: signal_vals,
288    })
289}
290
291#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
292#[inline]
293pub fn fisher_into(
294    input: &FisherInput,
295    fisher_out: &mut [f64],
296    signal_out: &mut [f64],
297) -> Result<(), FisherError> {
298    fisher_into_slice(fisher_out, signal_out, input, Kernel::Auto)
299}
300
/// Scalar Fisher Transform kernel writing into caller buffers.
///
/// For every bar `i` from `first + period - 1` up to `min(high.len(), low.len())`:
/// the midpoint `(high + low) / 2` is normalised against the min/max midpoint of
/// the last `period` bars (range floored at 0.001) and smoothed as
/// `v = 0.67 * v + 0.66 * (norm - 0.5)`. The result is clamped to ±0.999 once it
/// passes ±0.99, and then
/// `fisher = 0.5 * ln((1 + v) / (1 - v)) + 0.5 * previous_fisher`.
/// `signal_out[i]` receives the previous bar's Fisher value (0.0 on the first
/// computed bar).
///
/// Entries before the warmup bar are left untouched. Returns immediately when
/// `period == 0` or `first` is past the end of the data.
///
/// # Panics
/// If an output slice is shorter than the processed range.
#[inline]
pub fn fisher_scalar_into(
    high: &[f64],
    low: &[f64],
    period: usize,
    first: usize,
    fisher_out: &mut [f64],
    signal_out: &mut [f64],
) {
    let n = high.len().min(low.len());
    if period == 0 || first >= n {
        return;
    }

    let midpoint = |k: usize| 0.5 * (high[k] + low[k]);
    let mut smoothed = 0.0f64;
    let mut last_fisher = 0.0f64;

    for i in (first + period - 1)..n {
        let mut hi = f64::MIN;
        let mut lo = f64::MAX;
        for j in (i + 1 - period)..=i {
            let m = midpoint(j);
            if m > hi {
                hi = m;
            }
            if m < lo {
                lo = m;
            }
        }

        let range = (hi - lo).max(0.001);
        let norm = (midpoint(i) - lo) / range - 0.5;
        smoothed = 0.67f64.mul_add(smoothed, 0.66 * norm);
        if smoothed > 0.99 {
            smoothed = 0.999;
        } else if smoothed < -0.99 {
            smoothed = -0.999;
        }

        signal_out[i] = last_fisher;
        last_fisher =
            0.5f64.mul_add(((1.0 + smoothed) / (1.0 - smoothed)).ln(), 0.5 * last_fisher);
        fisher_out[i] = last_fisher;
    }
}
347
/// AVX-512 entry point for the single-series Fisher Transform.
///
/// There is no dedicated AVX-512 implementation; the work is delegated to
/// `fisher_scalar_into`, so results are identical to the scalar kernel.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn fisher_avx512_into(
    high: &[f64],
    low: &[f64],
    period: usize,
    first: usize,
    fisher_out: &mut [f64],
    signal_out: &mut [f64],
) {
    fisher_scalar_into(high, low, period, first, fisher_out, signal_out);
}
360
/// AVX2 entry point for the single-series Fisher Transform.
///
/// There is no dedicated AVX2 implementation; the work is delegated to
/// `fisher_scalar_into`, so results are identical to the scalar kernel.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn fisher_avx2_into(
    high: &[f64],
    low: &[f64],
    period: usize,
    first: usize,
    fisher_out: &mut [f64],
    signal_out: &mut [f64],
) {
    fisher_scalar_into(high, low, period, first, fisher_out, signal_out);
}
373
374#[inline]
375pub fn fisher_into_slice(
376    fisher_dst: &mut [f64],
377    signal_dst: &mut [f64],
378    input: &FisherInput,
379    kern: Kernel,
380) -> Result<(), FisherError> {
381    let (high, low) = input.as_ref();
382    if high.is_empty() || low.is_empty() {
383        return Err(FisherError::EmptyData);
384    }
385    if high.len() != low.len() {
386        return Err(FisherError::MismatchedDataLength {
387            high: high.len(),
388            low: low.len(),
389        });
390    }
391
392    let data_len = high.len();
393    let period = input.params.period.unwrap_or(9);
394
395    let mut first = None;
396    for i in 0..data_len {
397        if !high[i].is_nan() && !low[i].is_nan() {
398            first = Some(i);
399            break;
400        }
401    }
402    let first = first.ok_or(FisherError::AllValuesNaN)?;
403
404    if period == 0 || period > data_len {
405        return Err(FisherError::InvalidPeriod { period, data_len });
406    }
407    if fisher_dst.len() != data_len || signal_dst.len() != data_len {
408        return Err(FisherError::OutputLengthMismatch {
409            expected: data_len,
410            got: fisher_dst.len().min(signal_dst.len()),
411        });
412    }
413
414    let chosen = if kern == Kernel::Auto {
415        detect_best_kernel()
416    } else {
417        kern
418    };
419
420    match chosen {
421        Kernel::Scalar | Kernel::ScalarBatch => {
422            fisher_scalar_into(high, low, period, first, fisher_dst, signal_dst)
423        }
424        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
425        Kernel::Avx2 | Kernel::Avx2Batch => {
426            fisher_avx2_into(high, low, period, first, fisher_dst, signal_dst)
427        }
428        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
429        Kernel::Avx512 | Kernel::Avx512Batch => {
430            fisher_avx512_into(high, low, period, first, fisher_dst, signal_dst)
431        }
432        _ => unreachable!(),
433    }
434
435    let warmup_end = first + period - 1;
436    for i in 0..warmup_end {
437        fisher_dst[i] = f64::NAN;
438        signal_dst[i] = f64::NAN;
439    }
440
441    Ok(())
442}
443
444#[inline]
445pub fn fisher_batch_with_kernel(
446    high: &[f64],
447    low: &[f64],
448    sweep: &FisherBatchRange,
449    k: Kernel,
450) -> Result<FisherBatchOutput, FisherError> {
451    let kernel = match k {
452        Kernel::Auto => detect_best_batch_kernel(),
453        other if other.is_batch() => other,
454        other => return Err(FisherError::InvalidKernelForBatch(other)),
455    };
456    let simd = match kernel {
457        Kernel::Avx512Batch => Kernel::Avx512,
458        Kernel::Avx2Batch => Kernel::Avx2,
459        Kernel::ScalarBatch => Kernel::Scalar,
460        _ => unreachable!(),
461    };
462    fisher_batch_par_slice(high, low, sweep, simd)
463}
464
/// Period sweep for batch runs, given as a `(start, end, step)` triple.
#[derive(Clone, Debug)]
pub struct FisherBatchRange {
    /// `(start, end, step)` of the period axis.
    pub period: (usize, usize, usize),
}

impl Default for FisherBatchRange {
    /// Periods 9 through 258 in steps of 1.
    fn default() -> Self {
        let (start, end, step) = (9, 258, 1);
        FisherBatchRange {
            period: (start, end, step),
        }
    }
}
477
478#[derive(Clone, Debug, Default)]
479pub struct FisherBatchBuilder {
480    range: FisherBatchRange,
481    kernel: Kernel,
482}
483
484impl FisherBatchBuilder {
485    pub fn new() -> Self {
486        Self::default()
487    }
488    pub fn kernel(mut self, k: Kernel) -> Self {
489        self.kernel = k;
490        self
491    }
492    #[inline]
493    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
494        self.range.period = (start, end, step);
495        self
496    }
497    #[inline]
498    pub fn period_static(mut self, p: usize) -> Self {
499        self.range.period = (p, p, 0);
500        self
501    }
502    pub fn apply_slices(self, high: &[f64], low: &[f64]) -> Result<FisherBatchOutput, FisherError> {
503        fisher_batch_with_kernel(high, low, &self.range, self.kernel)
504    }
505    pub fn with_default_slices(
506        high: &[f64],
507        low: &[f64],
508        k: Kernel,
509    ) -> Result<FisherBatchOutput, FisherError> {
510        FisherBatchBuilder::new().kernel(k).apply_slices(high, low)
511    }
512    pub fn apply_candles(self, c: &Candles) -> Result<FisherBatchOutput, FisherError> {
513        let high = source_type(c, "high");
514        let low = source_type(c, "low");
515        self.apply_slices(high, low)
516    }
517    pub fn with_default_candles(c: &Candles) -> Result<FisherBatchOutput, FisherError> {
518        FisherBatchBuilder::new()
519            .kernel(Kernel::Auto)
520            .apply_candles(c)
521    }
522}
523
524#[derive(Clone, Debug)]
525pub struct FisherBatchOutput {
526    pub fisher: Vec<f64>,
527    pub signal: Vec<f64>,
528    pub combos: Vec<FisherParams>,
529    pub rows: usize,
530    pub cols: usize,
531}
532impl FisherBatchOutput {
533    pub fn row_for_params(&self, p: &FisherParams) -> Option<usize> {
534        self.combos
535            .iter()
536            .position(|c| c.period.unwrap_or(9) == p.period.unwrap_or(9))
537    }
538    pub fn fisher_for(&self, p: &FisherParams) -> Option<&[f64]> {
539        self.row_for_params(p).map(|row| {
540            let start = row * self.cols;
541            &self.fisher[start..start + self.cols]
542        })
543    }
544    pub fn signal_for(&self, p: &FisherParams) -> Option<&[f64]> {
545        self.row_for_params(p).map(|row| {
546            let start = row * self.cols;
547            &self.signal[start..start + self.cols]
548        })
549    }
550}
551
552#[inline(always)]
553fn expand_grid(r: &FisherBatchRange) -> Vec<FisherParams> {
554    fn axis_usize((start, end, step): (usize, usize, usize)) -> Vec<usize> {
555        if step == 0 || start == end {
556            return vec![start];
557        }
558        let mut out = Vec::new();
559        if start < end {
560            if step == 0 {
561                return vec![start];
562            }
563            let mut v = start;
564            while v <= end {
565                out.push(v);
566                match v.checked_add(step) {
567                    Some(n) => v = n,
568                    None => break,
569                }
570            }
571        } else {
572            if step == 0 {
573                return vec![start];
574            }
575            let mut v = start;
576            while v >= end {
577                out.push(v);
578                if v < end + step {
579                    break;
580                }
581                v -= step;
582                if v == 0 && end > 0 && step > 0 {
583                    if v < end {
584                        break;
585                    }
586                }
587            }
588        }
589        out
590    }
591    let periods = axis_usize(r.period);
592    let mut out = Vec::with_capacity(periods.len());
593    for &p in &periods {
594        out.push(FisherParams { period: Some(p) });
595    }
596    out
597}
598
599#[inline(always)]
600pub fn fisher_batch_slice(
601    high: &[f64],
602    low: &[f64],
603    sweep: &FisherBatchRange,
604    kern: Kernel,
605) -> Result<FisherBatchOutput, FisherError> {
606    fisher_batch_inner(high, low, sweep, kern, false)
607}
608#[inline(always)]
609pub fn fisher_batch_par_slice(
610    high: &[f64],
611    low: &[f64],
612    sweep: &FisherBatchRange,
613    kern: Kernel,
614) -> Result<FisherBatchOutput, FisherError> {
615    fisher_batch_inner(high, low, sweep, kern, true)
616}
617
618#[inline(always)]
619fn fisher_batch_inner(
620    high: &[f64],
621    low: &[f64],
622    sweep: &FisherBatchRange,
623    kern: Kernel,
624    parallel: bool,
625) -> Result<FisherBatchOutput, FisherError> {
626    let combos = expand_grid(sweep);
627    if combos.is_empty() {
628        let (s, e, st) = sweep.period;
629        return Err(FisherError::InvalidRange {
630            start: s,
631            end: e,
632            step: st,
633        });
634    }
635
636    let data_len = high.len().min(low.len());
637
638    let mut first = None;
639    for i in 0..data_len {
640        if !high[i].is_nan() && !low[i].is_nan() {
641            first = Some(i);
642            break;
643        }
644    }
645    let first = first.ok_or(FisherError::AllValuesNaN)?;
646
647    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
648    if data_len - first < max_p {
649        return Err(FisherError::NotEnoughValidData {
650            needed: max_p,
651            valid: data_len - first,
652        });
653    }
654    let rows = combos.len();
655    let cols = data_len;
656
657    let _ = rows.checked_mul(cols).ok_or(FisherError::InvalidRange {
658        start: sweep.period.0,
659        end: sweep.period.1,
660        step: sweep.period.2,
661    })?;
662
663    let mut fisher_mu = make_uninit_matrix(rows, cols);
664    let mut signal_mu = make_uninit_matrix(rows, cols);
665
666    let mut warmup_periods: Vec<usize> = Vec::with_capacity(combos.len());
667    for c in &combos {
668        let p = c.period.unwrap_or(0);
669        let warm = first
670            .checked_add(p.saturating_sub(1))
671            .ok_or(FisherError::InvalidRange {
672                start: sweep.period.0,
673                end: sweep.period.1,
674                step: sweep.period.2,
675            })?;
676        warmup_periods.push(warm);
677    }
678
679    init_matrix_prefixes(&mut fisher_mu, cols, &warmup_periods);
680    init_matrix_prefixes(&mut signal_mu, cols, &warmup_periods);
681
682    let mut fisher_guard = core::mem::ManuallyDrop::new(fisher_mu);
683    let mut signal_guard = core::mem::ManuallyDrop::new(signal_mu);
684
685    let fisher_slice: &mut [f64] = unsafe {
686        core::slice::from_raw_parts_mut(fisher_guard.as_mut_ptr() as *mut f64, fisher_guard.len())
687    };
688    let signal_slice: &mut [f64] = unsafe {
689        core::slice::from_raw_parts_mut(signal_guard.as_mut_ptr() as *mut f64, signal_guard.len())
690    };
691
692    let hl: Vec<f64> = (0..data_len).map(|i| 0.5 * (high[i] + low[i])).collect();
693
694    let do_row = |row: usize, out_fish: &mut [f64], out_signal: &mut [f64]| {
695        let period = combos[row].period.unwrap();
696        match kern {
697            Kernel::Scalar | Kernel::ScalarBatch => {
698                fisher_row_scalar_from_hl(&hl, first, period, out_fish, out_signal)
699            }
700            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
701            Kernel::Avx2 | Kernel::Avx2Batch => {
702                fisher_row_avx2_direct(high, low, first, period, out_fish, out_signal)
703            }
704            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
705            Kernel::Avx512 | Kernel::Avx512Batch => {
706                fisher_row_avx512_direct(high, low, first, period, out_fish, out_signal)
707            }
708            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
709            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
710                fisher_row_scalar_direct(high, low, first, period, out_fish, out_signal)
711            }
712            Kernel::Auto => fisher_row_scalar_from_hl(&hl, first, period, out_fish, out_signal),
713        }
714    };
715
716    if parallel {
717        #[cfg(not(target_arch = "wasm32"))]
718        {
719            fisher_slice
720                .par_chunks_mut(cols)
721                .zip(signal_slice.par_chunks_mut(cols))
722                .enumerate()
723                .for_each(|(row, (fish, sig))| do_row(row, fish, sig));
724        }
725
726        #[cfg(target_arch = "wasm32")]
727        {
728            for (row, (fish, sig)) in fisher_slice
729                .chunks_mut(cols)
730                .zip(signal_slice.chunks_mut(cols))
731                .enumerate()
732            {
733                do_row(row, fish, sig);
734            }
735        }
736    } else {
737        for (row, (fish, sig)) in fisher_slice
738            .chunks_mut(cols)
739            .zip(signal_slice.chunks_mut(cols))
740            .enumerate()
741        {
742            do_row(row, fish, sig);
743        }
744    }
745
746    let fisher = unsafe {
747        Vec::from_raw_parts(
748            fisher_guard.as_mut_ptr() as *mut f64,
749            fisher_guard.len(),
750            fisher_guard.capacity(),
751        )
752    };
753    let signal = unsafe {
754        Vec::from_raw_parts(
755            signal_guard.as_mut_ptr() as *mut f64,
756            signal_guard.len(),
757            signal_guard.capacity(),
758        )
759    };
760
761    core::mem::forget(fisher_guard);
762    core::mem::forget(signal_guard);
763
764    Ok(FisherBatchOutput {
765        fisher,
766        signal,
767        combos,
768        rows,
769        cols,
770    })
771}
772
773#[inline(always)]
774fn fisher_batch_inner_into(
775    high: &[f64],
776    low: &[f64],
777    sweep: &FisherBatchRange,
778    kern: Kernel,
779    parallel: bool,
780    fisher_out: &mut [f64],
781    signal_out: &mut [f64],
782) -> Result<Vec<FisherParams>, FisherError> {
783    let combos = expand_grid(sweep);
784    if combos.is_empty() {
785        let (s, e, st) = sweep.period;
786        return Err(FisherError::InvalidRange {
787            start: s,
788            end: e,
789            step: st,
790        });
791    }
792
793    let data_len = high.len().min(low.len());
794
795    let mut first = None;
796    for i in 0..data_len {
797        if !high[i].is_nan() && !low[i].is_nan() {
798            first = Some(i);
799            break;
800        }
801    }
802    let first = first.ok_or(FisherError::AllValuesNaN)?;
803
804    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
805    let _ = combos
806        .len()
807        .checked_mul(max_p)
808        .ok_or(FisherError::InvalidRange {
809            start: sweep.period.0,
810            end: sweep.period.1,
811            step: sweep.period.2,
812        })?;
813    if data_len - first < max_p {
814        return Err(FisherError::NotEnoughValidData {
815            needed: max_p,
816            valid: data_len - first,
817        });
818    }
819
820    let rows = combos.len();
821    let cols = data_len;
822    let _ = rows.checked_mul(cols).ok_or(FisherError::InvalidRange {
823        start: sweep.period.0,
824        end: sweep.period.1,
825        step: sweep.period.2,
826    })?;
827
828    for (row, combo) in combos.iter().enumerate() {
829        let p = combo.period.unwrap_or(0);
830        let warmup = first
831            .checked_add(p.saturating_sub(1))
832            .ok_or(FisherError::InvalidRange {
833                start: sweep.period.0,
834                end: sweep.period.1,
835                step: sweep.period.2,
836            })?;
837        let row_start = row * cols;
838        for i in 0..warmup.min(cols) {
839            fisher_out[row_start + i] = f64::NAN;
840            signal_out[row_start + i] = f64::NAN;
841        }
842    }
843
844    let hl: Vec<f64> = (0..data_len).map(|i| 0.5 * (high[i] + low[i])).collect();
845
846    let do_row = |row: usize, out_fish: &mut [f64], out_signal: &mut [f64]| {
847        let period = combos[row].period.unwrap();
848        match kern {
849            Kernel::Scalar | Kernel::ScalarBatch => {
850                fisher_row_scalar_from_hl(&hl, first, period, out_fish, out_signal)
851            }
852            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
853            Kernel::Avx2 | Kernel::Avx2Batch => {
854                fisher_row_avx2_direct(high, low, first, period, out_fish, out_signal)
855            }
856            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
857            Kernel::Avx512 | Kernel::Avx512Batch => {
858                fisher_row_avx512_direct(high, low, first, period, out_fish, out_signal)
859            }
860            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
861            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
862                fisher_row_scalar_direct(high, low, first, period, out_fish, out_signal)
863            }
864            Kernel::Auto => fisher_row_scalar_from_hl(&hl, first, period, out_fish, out_signal),
865        }
866    };
867
868    if parallel {
869        #[cfg(not(target_arch = "wasm32"))]
870        {
871            fisher_out
872                .par_chunks_mut(cols)
873                .zip(signal_out.par_chunks_mut(cols))
874                .enumerate()
875                .for_each(|(row, (fish, sig))| do_row(row, fish, sig));
876        }
877
878        #[cfg(target_arch = "wasm32")]
879        {
880            for (row, (fish, sig)) in fisher_out
881                .chunks_mut(cols)
882                .zip(signal_out.chunks_mut(cols))
883                .enumerate()
884            {
885                do_row(row, fish, sig);
886            }
887        }
888    } else {
889        for (row, (fish, sig)) in fisher_out
890            .chunks_mut(cols)
891            .zip(signal_out.chunks_mut(cols))
892            .enumerate()
893        {
894            do_row(row, fish, sig);
895        }
896    }
897
898    Ok(combos)
899}
900
/// Scalar Fisher Transform for one batch row, reading `high`/`low` directly.
///
/// Uses the same recurrence as `fisher_scalar_into` and
/// `fisher_row_scalar_from_hl`, including their fused multiply-add form. Every
/// kernel therefore yields bit-identical rows for the same period; the previous
/// plain `a * b + c` form differed in the last bits.
///
/// Writes bars `first + period - 1 .. min(high.len(), low.len())`; earlier
/// entries are left untouched. Returns immediately when `period == 0` or
/// `first` is past the end of the data.
#[inline(always)]
fn fisher_row_scalar_direct(
    high: &[f64],
    low: &[f64],
    first: usize,
    period: usize,
    out_fish: &mut [f64],
    out_signal: &mut [f64],
) {
    let len = high.len().min(low.len());
    if period == 0 || first >= len {
        return;
    }

    let mut prev_fish = 0.0f64;
    let mut val1 = 0.0f64;
    let warm = first + period - 1;

    for i in warm..len {
        let start = i + 1 - period;

        let (mut min_val, mut max_val) = (f64::MAX, f64::MIN);
        for j in start..=i {
            let midpoint = 0.5 * (high[j] + low[j]);
            if midpoint > max_val {
                max_val = midpoint;
            }
            if midpoint < min_val {
                min_val = midpoint;
            }
        }

        let range = (max_val - min_val).max(0.001);
        let hl = 0.5 * (high[i] + low[i]);
        // FMA form, matching the other scalar kernels bit for bit.
        val1 = 0.67f64.mul_add(val1, 0.66 * ((hl - min_val) / range - 0.5));
        if val1 > 0.99 {
            val1 = 0.999;
        } else if val1 < -0.99 {
            val1 = -0.999;
        }
        out_signal[i] = prev_fish;
        let new_fish = 0.5f64.mul_add(((1.0 + val1) / (1.0 - val1)).ln(), 0.5 * prev_fish);
        out_fish[i] = new_fish;
        prev_fish = new_fish;
    }
}
948
/// Scalar Fisher Transform for one batch row, from precomputed midpoints.
///
/// `hl[k]` is expected to be the midpoint `(high[k] + low[k]) / 2` of bar `k`,
/// as built by the batch drivers. Bars `first + period - 1 .. hl.len()` are
/// written; earlier entries are left untouched. Does nothing when
/// `period == 0` or `first >= hl.len()`.
#[inline(always)]
fn fisher_row_scalar_from_hl(
    hl: &[f64],
    first: usize,
    period: usize,
    out_fish: &mut [f64],
    out_signal: &mut [f64],
) {
    let n = hl.len();
    if period == 0 || first >= n {
        return;
    }

    let (mut smoothed, mut last) = (0.0f64, 0.0f64);

    for (i, &price) in hl.iter().enumerate().skip(first + period - 1) {
        let window = &hl[i + 1 - period..=i];
        let (lo, hi) = window.iter().fold((f64::MAX, f64::MIN), |(lo, hi), &x| {
            (if x < lo { x } else { lo }, if x > hi { x } else { hi })
        });

        let range = (hi - lo).max(0.001);
        smoothed = 0.67f64.mul_add(smoothed, 0.66 * ((price - lo) / range - 0.5));
        smoothed = if smoothed > 0.99 {
            0.999
        } else if smoothed < -0.99 {
            -0.999
        } else {
            smoothed
        };

        out_signal[i] = last;
        let fish = 0.5f64.mul_add(((1.0 + smoothed) / (1.0 - smoothed)).ln(), 0.5 * last);
        out_fish[i] = fish;
        last = fish;
    }
}
992
/// AVX2 entry point for a single batch row.
///
/// There is no dedicated AVX2 implementation; this forwards to
/// `fisher_row_scalar_direct`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn fisher_row_avx2_direct(
    high: &[f64],
    low: &[f64],
    first: usize,
    period: usize,
    out_fish: &mut [f64],
    out_signal: &mut [f64],
) {
    fisher_row_scalar_direct(high, low, first, period, out_fish, out_signal);
}
1005
/// AVX-512 entry point for a single batch row.
///
/// There is no dedicated AVX-512 implementation; this forwards to
/// `fisher_row_scalar_direct`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn fisher_row_avx512_direct(
    high: &[f64],
    low: &[f64],
    first: usize,
    period: usize,
    out_fish: &mut [f64],
    out_signal: &mut [f64],
) {
    fisher_row_scalar_direct(high, low, first, period, out_fish, out_signal);
}
1018
/// Incremental (bar-by-bar) Fisher Transform state.
///
/// Bars are fed one `(high, low)` pair at a time through `update`. The window
/// min/max of the mid price `0.5 * (high + low)` is kept in two monotonic
/// deques. Each value is pushed and popped at most once, so an update is
/// amortised O(1) instead of rescanning `period` bars.
#[derive(Debug, Clone)]
pub struct FisherStream {
    /// Look-back window length (rejected by `try_new` when zero).
    period: usize,

    /// Number of bars pushed into the window so far; also the index assigned
    /// to the next pushed bar.
    idx: usize,
    /// Becomes true once `period` bars have been pushed. Until then `update`
    /// returns `None`.
    filled: bool,

    /// `(value, bar index)` candidates for the window minimum. Values increase
    /// from front to back, so the front is the current minimum.
    minq: VecDeque<(f64, usize)>,
    /// `(value, bar index)` candidates for the window maximum. Values decrease
    /// from front to back, so the front is the current maximum.
    maxq: VecDeque<(f64, usize)>,
    /// Fisher value produced by the previous emitted update; reported as the
    /// signal of the next one.
    prev_fish: f64,
    /// Smoothed, clamped normalised price carried between updates.
    val1: f64,
}
1031
1032impl FisherStream {
1033    pub fn try_new(params: FisherParams) -> Result<Self, FisherError> {
1034        let period = params.period.unwrap_or(9);
1035        if period == 0 {
1036            return Err(FisherError::InvalidPeriod {
1037                period,
1038                data_len: 0,
1039            });
1040        }
1041        Ok(Self {
1042            period,
1043            idx: 0,
1044            filled: false,
1045            minq: VecDeque::with_capacity(period + 1),
1046            maxq: VecDeque::with_capacity(period + 1),
1047            prev_fish: 0.0,
1048            val1: 0.0,
1049        })
1050    }
1051
1052    #[inline(always)]
1053    pub fn update(&mut self, high: f64, low: f64) -> Option<(f64, f64)> {
1054        let v = 0.5 * (high + low);
1055        let k = self.idx;
1056
1057        while let Some(&(last_v, _)) = self.minq.back() {
1058            if last_v >= v {
1059                self.minq.pop_back();
1060            } else {
1061                break;
1062            }
1063        }
1064        self.minq.push_back((v, k));
1065
1066        while let Some(&(last_v, _)) = self.maxq.back() {
1067            if last_v <= v {
1068                self.maxq.pop_back();
1069            } else {
1070                break;
1071            }
1072        }
1073        self.maxq.push_back((v, k));
1074
1075        let start = k.saturating_sub(self.period - 1);
1076        while let Some(&(_, i)) = self.minq.front() {
1077            if i < start {
1078                self.minq.pop_front();
1079            } else {
1080                break;
1081            }
1082        }
1083        while let Some(&(_, i)) = self.maxq.front() {
1084            if i < start {
1085                self.maxq.pop_front();
1086            } else {
1087                break;
1088            }
1089        }
1090
1091        self.idx = k + 1;
1092
1093        if !self.filled {
1094            self.filled = self.idx >= self.period;
1095            if !self.filled {
1096                return None;
1097            }
1098        }
1099
1100        let min_val = self.minq.front().map(|&(x, _)| x).unwrap_or(v);
1101        let max_val = self.maxq.front().map(|&(x, _)| x).unwrap_or(v);
1102
1103        let range = (max_val - min_val).max(0.001);
1104
1105        self.val1 = 0.67f64.mul_add(self.val1, 0.66 * ((v - min_val) / range - 0.5));
1106        if self.val1 > 0.99 {
1107            self.val1 = 0.999;
1108        } else if self.val1 < -0.99 {
1109            self.val1 = -0.999;
1110        }
1111
1112        let signal = self.prev_fish;
1113        let fisher = 0.5f64.mul_add(
1114            ((1.0 + self.val1) / (1.0 - self.val1)).ln(),
1115            0.5 * self.prev_fish,
1116        );
1117
1118        self.prev_fish = fisher;
1119        Some((fisher, signal))
1120    }
1121}
1122
1123#[cfg(test)]
1124mod tests {
1125    use super::*;
1126    use crate::skip_if_unsupported;
1127    use crate::utilities::data_loader::read_candles_from_csv;
1128
1129    fn check_fisher_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1130        skip_if_unsupported!(kernel, test_name);
1131        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1132        let candles = read_candles_from_csv(file_path)?;
1133        let default_params = FisherParams { period: None };
1134        let input = FisherInput::from_candles(&candles, default_params);
1135        let output = fisher_with_kernel(&input, kernel)?;
1136        assert_eq!(output.fisher.len(), candles.close.len());
1137        Ok(())
1138    }
1139
1140    fn check_fisher_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1141        skip_if_unsupported!(kernel, test_name);
1142        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1143        let candles = read_candles_from_csv(file_path)?;
1144        let input = FisherInput::from_candles(&candles, FisherParams::default());
1145        let result = fisher_with_kernel(&input, kernel)?;
1146        let expected_last_five_fisher = [
1147            -0.4720164683904261,
1148            -0.23467530106650444,
1149            -0.14879388501136784,
1150            -0.026651419122953053,
1151            -0.2569225042442664,
1152        ];
1153        let start = result.fisher.len().saturating_sub(5);
1154        for (i, &val) in result.fisher[start..].iter().enumerate() {
1155            let diff = (val - expected_last_five_fisher[i]).abs();
1156            assert!(
1157                diff < 1e-1,
1158                "[{}] Fisher {:?} mismatch at idx {}: got {}, expected {}",
1159                test_name,
1160                kernel,
1161                i,
1162                val,
1163                expected_last_five_fisher[i]
1164            );
1165        }
1166        Ok(())
1167    }
1168
1169    fn check_fisher_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1170        skip_if_unsupported!(kernel, test_name);
1171        let high = [10.0, 20.0, 30.0];
1172        let low = [5.0, 15.0, 25.0];
1173        let params = FisherParams { period: Some(0) };
1174        let input = FisherInput::from_slices(&high, &low, params);
1175        let res = fisher_with_kernel(&input, kernel);
1176        assert!(
1177            res.is_err(),
1178            "[{}] Fisher should fail with zero period",
1179            test_name
1180        );
1181        Ok(())
1182    }
1183
1184    fn check_fisher_period_exceeds_length(
1185        test_name: &str,
1186        kernel: Kernel,
1187    ) -> Result<(), Box<dyn Error>> {
1188        skip_if_unsupported!(kernel, test_name);
1189        let high = [10.0, 20.0, 30.0];
1190        let low = [5.0, 15.0, 25.0];
1191        let params = FisherParams { period: Some(10) };
1192        let input = FisherInput::from_slices(&high, &low, params);
1193        let res = fisher_with_kernel(&input, kernel);
1194        assert!(
1195            res.is_err(),
1196            "[{}] Fisher should fail with period exceeding length",
1197            test_name
1198        );
1199        Ok(())
1200    }
1201
1202    fn check_fisher_very_small_dataset(
1203        test_name: &str,
1204        kernel: Kernel,
1205    ) -> Result<(), Box<dyn Error>> {
1206        skip_if_unsupported!(kernel, test_name);
1207        let high = [10.0];
1208        let low = [5.0];
1209        let params = FisherParams { period: Some(9) };
1210        let input = FisherInput::from_slices(&high, &low, params);
1211        let res = fisher_with_kernel(&input, kernel);
1212        assert!(
1213            res.is_err(),
1214            "[{}] Fisher should fail with insufficient data",
1215            test_name
1216        );
1217        Ok(())
1218    }
1219
1220    fn check_fisher_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1221        skip_if_unsupported!(kernel, test_name);
1222        let high = [10.0, 12.0, 14.0, 16.0, 18.0, 20.0];
1223        let low = [5.0, 7.0, 9.0, 10.0, 13.0, 15.0];
1224        let first_params = FisherParams { period: Some(3) };
1225        let first_input = FisherInput::from_slices(&high, &low, first_params);
1226        let first_result = fisher_with_kernel(&first_input, kernel)?;
1227        let second_params = FisherParams { period: Some(3) };
1228        let second_input =
1229            FisherInput::from_slices(&first_result.fisher, &first_result.signal, second_params);
1230        let second_result = fisher_with_kernel(&second_input, kernel)?;
1231        assert_eq!(first_result.fisher.len(), second_result.fisher.len());
1232        assert_eq!(first_result.signal.len(), second_result.signal.len());
1233        Ok(())
1234    }
1235
1236    fn check_fisher_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1237        skip_if_unsupported!(kernel, test_name);
1238        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1239        let candles = read_candles_from_csv(file_path)?;
1240        let input = FisherInput::from_candles(&candles, FisherParams::default());
1241        let res = fisher_with_kernel(&input, kernel)?;
1242        assert_eq!(res.fisher.len(), candles.close.len());
1243        if res.fisher.len() > 240 {
1244            for (i, &val) in res.fisher[240..].iter().enumerate() {
1245                assert!(
1246                    !val.is_nan(),
1247                    "[{}] Found unexpected NaN at out-index {}",
1248                    test_name,
1249                    240 + i
1250                );
1251            }
1252        }
1253        Ok(())
1254    }
1255
1256    fn check_fisher_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1257        skip_if_unsupported!(kernel, test_name);
1258        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1259        let candles = read_candles_from_csv(file_path)?;
1260        let period = 9;
1261        let input = FisherInput::from_candles(
1262            &candles,
1263            FisherParams {
1264                period: Some(period),
1265            },
1266        );
1267        let batch_output = fisher_with_kernel(&input, kernel)?.fisher;
1268
1269        let highs = source_type(&candles, "high");
1270        let lows = source_type(&candles, "low");
1271
1272        let mut stream = FisherStream::try_new(FisherParams {
1273            period: Some(period),
1274        })?;
1275        let mut stream_fisher = Vec::with_capacity(highs.len());
1276        for (&h, &l) in highs.iter().zip(lows.iter()) {
1277            match stream.update(h, l) {
1278                Some((fish, _sig)) => stream_fisher.push(fish),
1279                None => stream_fisher.push(f64::NAN),
1280            }
1281        }
1282
1283        assert_eq!(batch_output.len(), stream_fisher.len());
1284        for (i, (&b, &s)) in batch_output.iter().zip(stream_fisher.iter()).enumerate() {
1285            if b.is_nan() && s.is_nan() {
1286                continue;
1287            }
1288            let diff = (b - s).abs();
1289            assert!(
1290                diff < 1e-9,
1291                "[{}] Fisher streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
1292                test_name,
1293                i,
1294                b,
1295                s,
1296                diff
1297            );
1298        }
1299        Ok(())
1300    }
1301
1302    #[cfg(debug_assertions)]
1303    fn check_fisher_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1304        skip_if_unsupported!(kernel, test_name);
1305
1306        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1307        let candles = read_candles_from_csv(file_path)?;
1308
1309        let test_params = vec![
1310            FisherParams::default(),
1311            FisherParams { period: Some(1) },
1312            FisherParams { period: Some(2) },
1313            FisherParams { period: Some(3) },
1314            FisherParams { period: Some(5) },
1315            FisherParams { period: Some(10) },
1316            FisherParams { period: Some(20) },
1317            FisherParams { period: Some(30) },
1318            FisherParams { period: Some(50) },
1319            FisherParams { period: Some(100) },
1320            FisherParams { period: Some(200) },
1321            FisherParams { period: Some(240) },
1322        ];
1323
1324        for (param_idx, params) in test_params.iter().enumerate() {
1325            let input = FisherInput::from_candles(&candles, params.clone());
1326            let output = fisher_with_kernel(&input, kernel)?;
1327
1328            for (i, &val) in output.fisher.iter().enumerate() {
1329                if val.is_nan() {
1330                    continue;
1331                }
1332
1333                let bits = val.to_bits();
1334
1335                if bits == 0x11111111_11111111 {
1336                    panic!(
1337                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1338						 in fisher output with params: period={} (param set {})",
1339                        test_name,
1340                        val,
1341                        bits,
1342                        i,
1343                        params.period.unwrap_or(9),
1344                        param_idx
1345                    );
1346                }
1347
1348                if bits == 0x22222222_22222222 {
1349                    panic!(
1350                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1351						 in fisher output with params: period={} (param set {})",
1352                        test_name,
1353                        val,
1354                        bits,
1355                        i,
1356                        params.period.unwrap_or(9),
1357                        param_idx
1358                    );
1359                }
1360
1361                if bits == 0x33333333_33333333 {
1362                    panic!(
1363                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1364						 in fisher output with params: period={} (param set {})",
1365                        test_name,
1366                        val,
1367                        bits,
1368                        i,
1369                        params.period.unwrap_or(9),
1370                        param_idx
1371                    );
1372                }
1373            }
1374
1375            for (i, &val) in output.signal.iter().enumerate() {
1376                if val.is_nan() {
1377                    continue;
1378                }
1379
1380                let bits = val.to_bits();
1381
1382                if bits == 0x11111111_11111111 {
1383                    panic!(
1384                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1385						 in signal output with params: period={} (param set {})",
1386                        test_name,
1387                        val,
1388                        bits,
1389                        i,
1390                        params.period.unwrap_or(9),
1391                        param_idx
1392                    );
1393                }
1394
1395                if bits == 0x22222222_22222222 {
1396                    panic!(
1397                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1398						 in signal output with params: period={} (param set {})",
1399                        test_name,
1400                        val,
1401                        bits,
1402                        i,
1403                        params.period.unwrap_or(9),
1404                        param_idx
1405                    );
1406                }
1407
1408                if bits == 0x33333333_33333333 {
1409                    panic!(
1410                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1411						 in signal output with params: period={} (param set {})",
1412                        test_name,
1413                        val,
1414                        bits,
1415                        i,
1416                        params.period.unwrap_or(9),
1417                        param_idx
1418                    );
1419                }
1420            }
1421        }
1422
1423        Ok(())
1424    }
1425
1426    #[cfg(not(debug_assertions))]
1427    fn check_fisher_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1428        Ok(())
1429    }
1430
1431    #[cfg(feature = "proptest")]
1432    #[allow(clippy::float_cmp)]
1433    fn check_fisher_property(
1434        test_name: &str,
1435        kernel: Kernel,
1436    ) -> Result<(), Box<dyn std::error::Error>> {
1437        use proptest::prelude::*;
1438        skip_if_unsupported!(kernel, test_name);
1439
1440        let strat = (2usize..=50).prop_flat_map(|period| {
1441            (
1442                (100f64..10000f64, 0.01f64..0.05f64, period + 10..400)
1443                    .prop_flat_map(move |(base_price, volatility, data_len)| {
1444                        (
1445                            Just(base_price),
1446                            Just(volatility),
1447                            Just(data_len),
1448                            prop::collection::vec((-1f64..1f64), data_len),
1449                            prop::collection::vec(prop::bool::ANY, data_len),
1450                        )
1451                    })
1452                    .prop_map(
1453                        move |(
1454                            base_price,
1455                            volatility,
1456                            data_len,
1457                            price_changes,
1458                            zero_spread_flags,
1459                        )| {
1460                            let mut high = Vec::with_capacity(data_len);
1461                            let mut low = Vec::with_capacity(data_len);
1462                            let mut current_price = base_price;
1463
1464                            for i in 0..data_len {
1465                                let change = price_changes[i] * volatility * current_price;
1466                                current_price = (current_price + change).max(10.0);
1467
1468                                if zero_spread_flags[i] && i % 5 == 0 {
1469                                    high.push(current_price);
1470                                    low.push(current_price);
1471                                } else {
1472                                    let spread =
1473                                        current_price * 0.01 * (0.1 + price_changes[i].abs());
1474                                    high.push(current_price + spread);
1475                                    low.push((current_price - spread).max(10.0));
1476                                }
1477                            }
1478
1479                            (high, low)
1480                        },
1481                    ),
1482                Just(period),
1483            )
1484        });
1485
1486        proptest::test_runner::TestRunner::default().run(&strat, |((high, low), period)| {
1487            let params = FisherParams {
1488                period: Some(period),
1489            };
1490            let input = FisherInput::from_slices(&high, &low, params);
1491
1492            let FisherOutput {
1493                fisher: out,
1494                signal: sig,
1495            } = fisher_with_kernel(&input, kernel)?;
1496            let FisherOutput {
1497                fisher: ref_out,
1498                signal: ref_sig,
1499            } = fisher_with_kernel(&input, Kernel::Scalar)?;
1500
1501            prop_assert_eq!(
1502                out.len(),
1503                high.len(),
1504                "[{}] Fisher output length mismatch",
1505                test_name
1506            );
1507            prop_assert_eq!(
1508                sig.len(),
1509                high.len(),
1510                "[{}] Signal output length mismatch",
1511                test_name
1512            );
1513
1514            let mut first_valid = None;
1515            for i in 0..high.len() {
1516                if !high[i].is_nan() && !low[i].is_nan() {
1517                    first_valid = Some(i);
1518                    break;
1519                }
1520            }
1521
1522            if let Some(first) = first_valid {
1523                let warmup_end = first + period - 1;
1524                for i in 0..warmup_end.min(out.len()) {
1525                    prop_assert!(
1526                        out[i].is_nan(),
1527                        "[{}] Expected NaN at index {} during warmup",
1528                        test_name,
1529                        i
1530                    );
1531                    prop_assert!(
1532                        sig[i].is_nan(),
1533                        "[{}] Expected NaN at signal index {} during warmup",
1534                        test_name,
1535                        i
1536                    );
1537                }
1538
1539                if warmup_end < out.len() {
1540                    prop_assert!(
1541                        !out[warmup_end].is_nan(),
1542                        "[{}] Expected valid value at index {} after warmup",
1543                        test_name,
1544                        warmup_end
1545                    );
1546                }
1547            }
1548
1549            if let Some(first) = first_valid {
1550                let warmup_end = first + period - 1;
1551
1552                for window_start in warmup_end..out.len().saturating_sub(period * 2) {
1553                    let window_end = (window_start + period).min(out.len());
1554
1555                    let mut is_constant = true;
1556                    let first_hl = (high[window_start] + low[window_start]) / 2.0;
1557
1558                    for i in window_start..window_end {
1559                        let current_hl = (high[i] + low[i]) / 2.0;
1560                        if (current_hl - first_hl).abs() > 0.001 * first_hl {
1561                            is_constant = false;
1562                            break;
1563                        }
1564                    }
1565
1566                    if is_constant && window_end > window_start + 3 {
1567                        let fisher_start = out[window_start].abs();
1568                        let fisher_end = out[window_end - 1].abs();
1569
1570                        if fisher_start > 0.1 {
1571                            prop_assert!(
1572									fisher_end <= fisher_start * 1.1,
1573									"[{}] Fisher not trending to zero in constant period [{}, {}]: start={}, end={}",
1574									test_name, window_start, window_end, fisher_start, fisher_end
1575								);
1576                        }
1577                    }
1578                }
1579            }
1580
1581            for i in 1..out.len() {
1582                if !out[i - 1].is_nan() && !sig[i].is_nan() {
1583                    prop_assert!(
1584                        (sig[i] - out[i - 1]).abs() < 1e-9,
1585                        "[{}] Signal at {} ({}) doesn't match previous Fisher ({})",
1586                        test_name,
1587                        i,
1588                        sig[i],
1589                        out[i - 1]
1590                    );
1591                }
1592            }
1593
1594            if let Some(first) = first_valid {
1595                let warmup_end = first + period - 1;
1596                if warmup_end < out.len() && !out[warmup_end].is_nan() {
1597                    prop_assert!(
1598                        out[warmup_end].abs() < 5.0,
1599                        "[{}] First Fisher value {} seems incorrect (should start from zero state)",
1600                        test_name,
1601                        out[warmup_end]
1602                    );
1603                }
1604            }
1605
1606            for i in 0..out.len() {
1607                let y = out[i];
1608                let r = ref_out[i];
1609                let s = sig[i];
1610                let rs = ref_sig[i];
1611
1612                if y.is_nan() || r.is_nan() {
1613                    prop_assert_eq!(
1614                        y.is_nan(),
1615                        r.is_nan(),
1616                        "[{}] NaN mismatch at index {}",
1617                        test_name,
1618                        i
1619                    );
1620                    continue;
1621                }
1622
1623                if s.is_nan() || rs.is_nan() {
1624                    prop_assert_eq!(
1625                        s.is_nan(),
1626                        rs.is_nan(),
1627                        "[{}] Signal NaN mismatch at index {}",
1628                        test_name,
1629                        i
1630                    );
1631                    continue;
1632                }
1633
1634                let y_bits = y.to_bits();
1635                let r_bits = r.to_bits();
1636                let s_bits = s.to_bits();
1637                let rs_bits = rs.to_bits();
1638
1639                let ulp_diff_fisher: u64 = y_bits.abs_diff(r_bits);
1640                let ulp_diff_signal: u64 = s_bits.abs_diff(rs_bits);
1641
1642                prop_assert!(
1643                    (y - r).abs() <= 1e-9 || ulp_diff_fisher <= 4,
1644                    "[{}] Fisher mismatch idx {}: {} vs {} (ULP={})",
1645                    test_name,
1646                    i,
1647                    y,
1648                    r,
1649                    ulp_diff_fisher
1650                );
1651
1652                prop_assert!(
1653                    (s - rs).abs() <= 1e-9 || ulp_diff_signal <= 4,
1654                    "[{}] Signal mismatch idx {}: {} vs {} (ULP={})",
1655                    test_name,
1656                    i,
1657                    s,
1658                    rs,
1659                    ulp_diff_signal
1660                );
1661            }
1662
1663            Ok(())
1664        })?;
1665
1666        Ok(())
1667    }
1668
1669    macro_rules! generate_all_fisher_tests {
1670        ($($test_fn:ident),*) => {
1671            paste::paste! {
1672                $(
1673                    #[test]
1674                    fn [<$test_fn _scalar_f64>]() {
1675                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1676                    }
1677                )*
1678                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1679                $(
1680                    #[test]
1681                    fn [<$test_fn _avx2_f64>]() {
1682                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1683                    }
1684                    #[test]
1685                    fn [<$test_fn _avx512_f64>]() {
1686                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1687                    }
1688                )*
1689            }
1690        }
1691    }
1692
1693    generate_all_fisher_tests!(
1694        check_fisher_partial_params,
1695        check_fisher_accuracy,
1696        check_fisher_zero_period,
1697        check_fisher_period_exceeds_length,
1698        check_fisher_very_small_dataset,
1699        check_fisher_reinput,
1700        check_fisher_nan_handling,
1701        check_fisher_streaming,
1702        check_fisher_no_poison
1703    );
1704
1705    #[cfg(feature = "proptest")]
1706    generate_all_fisher_tests!(check_fisher_property);
1707
1708    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1709        skip_if_unsupported!(kernel, test);
1710
1711        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1712        let c = read_candles_from_csv(file)?;
1713        let output = FisherBatchBuilder::new().kernel(kernel).apply_candles(&c)?;
1714
1715        let def = FisherParams::default();
1716        let row = output.fisher_for(&def).expect("default row missing");
1717
1718        assert_eq!(row.len(), c.close.len());
1719
1720        let expected_last_five = [
1721            -0.4720164683904261,
1722            -0.23467530106650444,
1723            -0.14879388501136784,
1724            -0.026651419122953053,
1725            -0.2569225042442664,
1726        ];
1727        let start = row.len().saturating_sub(5);
1728        for (i, &val) in row[start..].iter().enumerate() {
1729            let diff = (val - expected_last_five[i]).abs();
1730            assert!(
1731                diff < 1e-1,
1732                "[{test}] default-row mismatch at idx {i}: {val} vs {expected_last_five:?}"
1733            );
1734        }
1735        Ok(())
1736    }
1737
1738    #[cfg(debug_assertions)]
1739    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1740        skip_if_unsupported!(kernel, test);
1741
1742        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1743        let c = read_candles_from_csv(file)?;
1744
1745        let test_configs = vec![
1746            (1, 10, 1),
1747            (2, 20, 2),
1748            (5, 50, 5),
1749            (10, 100, 10),
1750            (20, 240, 20),
1751            (9, 9, 0),
1752            (50, 200, 50),
1753            (1, 5, 1),
1754            (100, 240, 40),
1755            (3, 30, 3),
1756        ];
1757
1758        for (cfg_idx, &(p_start, p_end, p_step)) in test_configs.iter().enumerate() {
1759            let output = FisherBatchBuilder::new()
1760                .kernel(kernel)
1761                .period_range(p_start, p_end, p_step)
1762                .apply_candles(&c)?;
1763
1764            for (idx, &val) in output.fisher.iter().enumerate() {
1765                if val.is_nan() {
1766                    continue;
1767                }
1768
1769                let bits = val.to_bits();
1770                let row = idx / output.cols;
1771                let col = idx % output.cols;
1772                let combo = &output.combos[row];
1773
1774                if bits == 0x11111111_11111111 {
1775                    panic!(
1776                        "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
1777						 at row {} col {} (flat index {}) in fisher output with params: period={}",
1778                        test,
1779                        cfg_idx,
1780                        val,
1781                        bits,
1782                        row,
1783                        col,
1784                        idx,
1785                        combo.period.unwrap_or(9)
1786                    );
1787                }
1788
1789                if bits == 0x22222222_22222222 {
1790                    panic!(
1791                        "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
1792						 at row {} col {} (flat index {}) in fisher output with params: period={}",
1793                        test,
1794                        cfg_idx,
1795                        val,
1796                        bits,
1797                        row,
1798                        col,
1799                        idx,
1800                        combo.period.unwrap_or(9)
1801                    );
1802                }
1803
1804                if bits == 0x33333333_33333333 {
1805                    panic!(
1806                        "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
1807						 at row {} col {} (flat index {}) in fisher output with params: period={}",
1808                        test,
1809                        cfg_idx,
1810                        val,
1811                        bits,
1812                        row,
1813                        col,
1814                        idx,
1815                        combo.period.unwrap_or(9)
1816                    );
1817                }
1818            }
1819
1820            for (idx, &val) in output.signal.iter().enumerate() {
1821                if val.is_nan() {
1822                    continue;
1823                }
1824
1825                let bits = val.to_bits();
1826                let row = idx / output.cols;
1827                let col = idx % output.cols;
1828                let combo = &output.combos[row];
1829
1830                if bits == 0x11111111_11111111 {
1831                    panic!(
1832                        "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
1833						 at row {} col {} (flat index {}) in signal output with params: period={}",
1834                        test,
1835                        cfg_idx,
1836                        val,
1837                        bits,
1838                        row,
1839                        col,
1840                        idx,
1841                        combo.period.unwrap_or(9)
1842                    );
1843                }
1844
1845                if bits == 0x22222222_22222222 {
1846                    panic!(
1847                        "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
1848						 at row {} col {} (flat index {}) in signal output with params: period={}",
1849                        test,
1850                        cfg_idx,
1851                        val,
1852                        bits,
1853                        row,
1854                        col,
1855                        idx,
1856                        combo.period.unwrap_or(9)
1857                    );
1858                }
1859
1860                if bits == 0x33333333_33333333 {
1861                    panic!(
1862                        "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
1863						 at row {} col {} (flat index {}) in signal output with params: period={}",
1864                        test,
1865                        cfg_idx,
1866                        val,
1867                        bits,
1868                        row,
1869                        col,
1870                        idx,
1871                        combo.period.unwrap_or(9)
1872                    );
1873                }
1874            }
1875        }
1876
1877        Ok(())
1878    }
1879
    /// Release-build stand-in for the poison-value checker.
    ///
    /// The debug-assertions variant scans every batch output cell for the
    /// `alloc_with_nan_prefix` / `init_matrix_prefixes` / `make_uninit_matrix`
    /// poison bit patterns. Without debug assertions there is nothing to scan, so
    /// this version always succeeds; the arguments are unused.
    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
1884
1885    macro_rules! gen_batch_tests {
1886        ($fn_name:ident) => {
1887            paste::paste! {
1888                #[test] fn [<$fn_name _scalar>]()      {
1889                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
1890                }
1891                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1892                #[test] fn [<$fn_name _avx2>]()        {
1893                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
1894                }
1895                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1896                #[test] fn [<$fn_name _avx512>]()      {
1897                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
1898                }
1899                #[test] fn [<$fn_name _auto_detect>]() {
1900                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
1901                }
1902            }
1903        };
1904    }
1905    gen_batch_tests!(check_batch_default_row);
1906    gen_batch_tests!(check_batch_no_poison);
1907
1908    #[test]
1909    fn check_batch_kernel_dispatch() -> Result<(), Box<dyn Error>> {
1910        let high = vec![10.0, 12.0, 14.0, 16.0, 18.0, 20.0, 22.0, 24.0, 26.0, 28.0];
1911        let low = vec![5.0, 7.0, 9.0, 10.0, 13.0, 15.0, 17.0, 19.0, 21.0, 23.0];
1912        let sweep = FisherBatchRange { period: (3, 5, 1) };
1913
1914        let scalar_result = fisher_batch_slice(&high, &low, &sweep, Kernel::Scalar)?;
1915
1916        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1917        if is_x86_feature_detected!("avx2") {
1918            let avx2_result = fisher_batch_slice(&high, &low, &sweep, Kernel::Avx2)?;
1919
1920            for i in 0..scalar_result.fisher.len() {
1921                let diff = (scalar_result.fisher[i] - avx2_result.fisher[i]).abs();
1922                assert!(
1923                    diff < 1e-10
1924                        || (scalar_result.fisher[i].is_nan() && avx2_result.fisher[i].is_nan()),
1925                    "Fisher mismatch at {}: scalar={}, avx2={}",
1926                    i,
1927                    scalar_result.fisher[i],
1928                    avx2_result.fisher[i]
1929                );
1930            }
1931        }
1932
1933        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1934        if is_x86_feature_detected!("avx512f") {
1935            let avx512_result = fisher_batch_slice(&high, &low, &sweep, Kernel::Avx512)?;
1936
1937            for i in 0..scalar_result.fisher.len() {
1938                let diff = (scalar_result.fisher[i] - avx512_result.fisher[i]).abs();
1939                assert!(
1940                    diff < 1e-10
1941                        || (scalar_result.fisher[i].is_nan() && avx512_result.fisher[i].is_nan()),
1942                    "Fisher mismatch at {}: scalar={}, avx512={}",
1943                    i,
1944                    scalar_result.fisher[i],
1945                    avx512_result.fisher[i]
1946                );
1947            }
1948        }
1949
1950        Ok(())
1951    }
1952
1953    #[test]
1954    fn test_fisher_into_matches_api() -> Result<(), Box<dyn Error>> {
1955        let n = 256usize;
1956        let mut ts = Vec::with_capacity(n);
1957        let mut open = Vec::with_capacity(n);
1958        let mut high = Vec::with_capacity(n);
1959        let mut low = Vec::with_capacity(n);
1960        let mut close = Vec::with_capacity(n);
1961        let mut volume = Vec::with_capacity(n);
1962
1963        for i in 0..n {
1964            ts.push(i as i64);
1965            let base = 1000.0 + (i as f64) * 0.1;
1966            let wiggle = ((i as f64) * 0.15).sin() * 2.0;
1967            let h = base + 5.0 + wiggle;
1968            let l = base - 5.0 - 0.5 * wiggle;
1969            let o = base - 1.0;
1970            let c = base + 1.0;
1971            open.push(o);
1972            high.push(h);
1973            low.push(l);
1974            close.push(c);
1975            volume.push(100.0 + (i % 10) as f64);
1976        }
1977
1978        let candles = crate::utilities::data_loader::Candles::new(
1979            ts,
1980            open,
1981            high.clone(),
1982            low.clone(),
1983            close,
1984            volume,
1985        );
1986        let input = FisherInput::from_candles(&candles, FisherParams::default());
1987
1988        let base = fisher(&input)?;
1989
1990        let mut out_fish = vec![0.0; n];
1991        let mut out_sig = vec![0.0; n];
1992
1993        #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1994        {
1995            fisher_into(&input, &mut out_fish, &mut out_sig)?;
1996        }
1997
1998        fn eq_or_both_nan(a: f64, b: f64) -> bool {
1999            (a.is_nan() && b.is_nan()) || (a == b) || ((a - b).abs() <= 1e-12)
2000        }
2001
2002        assert_eq!(out_fish.len(), base.fisher.len());
2003        assert_eq!(out_sig.len(), base.signal.len());
2004        for i in 0..n {
2005            assert!(
2006                eq_or_both_nan(out_fish[i], base.fisher[i]),
2007                "fisher mismatch at {}: {} vs {}",
2008                i,
2009                out_fish[i],
2010                base.fisher[i]
2011            );
2012            assert!(
2013                eq_or_both_nan(out_sig[i], base.signal[i]),
2014                "signal mismatch at {}: {} vs {}",
2015                i,
2016                out_sig[i],
2017                base.signal[i]
2018            );
2019        }
2020
2021        Ok(())
2022    }
2023}
2024
/// Python entry point: Fisher Transform of `high`/`low` for a single `period`.
///
/// Returns `(fisher, signal)` as two NumPy arrays whose length is the shorter of
/// the two inputs. The computation runs with the GIL released; any indicator error
/// is raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "fisher")]
#[pyo3(signature = (high, low, period, kernel=None))]
pub fn fisher_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<(Bound<'py, PyArray1<f64>>, Bound<'py, PyArray1<f64>>)> {
    let (high_slice, low_slice) = (high.as_slice()?, low.as_slice()?);
    let kern = validate_kernel(kernel, false)?;

    let n = high_slice.len().min(low_slice.len());

    // SAFETY: the arrays are left uninitialised here and are fully written by
    // `fisher_into_slice` before being handed back to Python.
    let (fisher_arr, signal_arr) = unsafe {
        (
            PyArray1::<f64>::new(py, [n], false),
            PyArray1::<f64>::new(py, [n], false),
        )
    };
    // SAFETY: both arrays are freshly created, contiguous, and not yet shared.
    let fisher_out = unsafe { fisher_arr.as_slice_mut()? };
    let signal_out = unsafe { signal_arr.as_slice_mut()? };

    let input = FisherInput::from_slices(
        high_slice,
        low_slice,
        FisherParams {
            period: Some(period),
        },
    );

    py.allow_threads(|| fisher_into_slice(fisher_out, signal_out, &input, kern))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok((fisher_arr, signal_arr))
}
2058
/// Python wrapper around the incremental `FisherStream`.
#[cfg(feature = "python")]
#[pyclass(name = "FisherStream")]
pub struct FisherStreamPy {
    stream: FisherStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl FisherStreamPy {
    /// Builds a stream for `period`; invalid parameters raise `ValueError`.
    #[new]
    pub fn new(period: usize) -> PyResult<Self> {
        FisherStream::try_new(FisherParams {
            period: Some(period),
        })
        .map(|stream| FisherStreamPy { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feeds one `(high, low)` bar and returns whatever the inner stream yields
    /// (`None` maps to Python `None`).
    pub fn update(&mut self, high: f64, low: f64) -> Option<(f64, f64)> {
        let Self { stream } = self;
        stream.update(high, low)
    }
}
2082
/// Python entry point: Fisher Transform over a sweep of periods.
///
/// `period_range` is forwarded as `FisherBatchRange::period`. Returns a dict with
/// `"fisher"` and `"signal"` as `(rows, cols)` arrays (one row per period) and
/// `"periods"` listing the period of each row.
///
/// Raises `ValueError` when `high`/`low` lengths differ, when every bar is NaN, on
/// size overflow, or when the batch kernel reports an error.
#[cfg(feature = "python")]
#[pyfunction(name = "fisher_batch")]
#[pyo3(signature = (high, low, period_range, kernel=None))]
pub fn fisher_batch_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let high_slice = high.as_slice()?;
    let low_slice = low.as_slice()?;
    if high_slice.len() != low_slice.len() {
        return Err(PyValueError::new_err(format!(
            "Mismatched data length: high={}, low={}",
            high_slice.len(),
            low_slice.len()
        )));
    }

    let kern = validate_kernel(kernel, true)?;
    let sweep = FisherBatchRange {
        period: period_range,
    };

    let combos = expand_grid(&sweep);
    let rows = combos.len();
    let cols = high_slice.len();
    let total_len = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("Fisher batch: rows*cols overflow"))?;

    // SAFETY: left uninitialised; prefixes are filled below and the kernel writes
    // the remainder of every row.
    let fisher_arr = unsafe { PyArray1::<f64>::new(py, [total_len], false) };
    let signal_arr = unsafe { PyArray1::<f64>::new(py, [total_len], false) };

    let first = (0..cols)
        .find(|&i| !high_slice[i].is_nan() && !low_slice[i].is_nan())
        .ok_or_else(|| PyValueError::new_err("All values are NaN"))?;

    let mut warmups: Vec<usize> = Vec::with_capacity(rows);
    for c in &combos {
        let p = c.period.unwrap_or(0);
        let warm = first
            .checked_add(p.saturating_sub(1))
            .ok_or_else(|| PyValueError::new_err("Fisher batch: warmup overflow"))?;
        // A period longer than the data would otherwise ask for a prefix past the end
        // of the row; cap it so the kernel, not a slice panic, reports the bad period.
        warmups.push(warm.min(cols));
    }

    // SAFETY: both arrays are freshly created, contiguous, and not yet visible to Python.
    let fisher_out: &mut [f64] = unsafe { fisher_arr.as_slice_mut()? };
    let signal_out: &mut [f64] = unsafe { signal_arr.as_slice_mut()? };

    // SAFETY: `MaybeUninit<f64>` has the same layout as `f64`; these temporary views
    // are only used for the prefix fill and are dead before `fisher_out`/`signal_out`
    // are used again.
    unsafe {
        let fisher_mu = std::slice::from_raw_parts_mut(
            fisher_out.as_mut_ptr() as *mut MaybeUninit<f64>,
            total_len,
        );
        let signal_mu = std::slice::from_raw_parts_mut(
            signal_out.as_mut_ptr() as *mut MaybeUninit<f64>,
            total_len,
        );
        init_matrix_prefixes(fisher_mu, cols, &warmups);
        init_matrix_prefixes(signal_mu, cols, &warmups);
    }

    py.allow_threads(move || {
        let batch_kernel = match kern {
            Kernel::Auto => detect_best_batch_kernel(),
            k => k,
        };
        // The inner routine takes the per-row SIMD kernel, not the batch variant.
        let simd = match batch_kernel {
            Kernel::Avx512Batch => Kernel::Avx512,
            Kernel::Avx2Batch => Kernel::Avx2,
            Kernel::ScalarBatch => Kernel::Scalar,
            other => other,
        };
        fisher_batch_inner_into(
            high_slice, low_slice, &sweep, simd, true, fisher_out, signal_out,
        )
    })
    .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("fisher", fisher_arr.reshape((rows, cols))?)?;
    dict.set_item("signal", signal_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;

    Ok(dict)
}
2191
/// Python handle for one Fisher output buffer resident on a CUDA device.
///
/// The wrapped buffer can be exported exactly once through `__dlpack__`; after
/// that the handle is empty and every accessor raises `ValueError`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(
    module = "ta_indicators.cuda",
    name = "FisherDeviceArrayF32",
    unsendable
)]
pub struct FisherDeviceArrayF32Py {
    pub(crate) inner: Option<DeviceArrayF32Py>,
}

#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl FisherDeviceArrayF32Py {
    /// `__cuda_array_interface__` of the wrapped buffer, if it has not been exported.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        match &self.inner {
            Some(buf) => buf.__cuda_array_interface__(py),
            None => Err(PyValueError::new_err(
                "buffer already exported via __dlpack__",
            )),
        }
    }

    /// DLPack `(device_type, device_id)` of the wrapped buffer.
    fn __dlpack_device__(&self) -> PyResult<(i32, i32)> {
        match &self.inner {
            Some(buf) => buf.__dlpack_device__(),
            None => Err(PyValueError::new_err(
                "buffer already exported via __dlpack__",
            )),
        }
    }

    /// Exports the buffer as a DLPack capsule, consuming this handle's ownership.
    ///
    /// An integer `stream` of 0 is rejected before anything is consumed.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<PyObject>,
        max_version: Option<PyObject>,
        dl_device: Option<PyObject>,
        copy: Option<PyObject>,
    ) -> PyResult<PyObject> {
        let stream_is_zero = matches!(
            stream.as_ref().map(|obj| obj.extract::<usize>(py)),
            Some(Ok(0))
        );
        if stream_is_zero {
            return Err(PyValueError::new_err(
                "__dlpack__ stream=0 is invalid for CUDA",
            ));
        }

        match self.inner.take() {
            Some(mut buf) => {
                let capsule = buf.__dlpack__(py, stream, max_version, dl_device, copy)?;
                Ok(capsule)
            }
            None => Err(PyValueError::new_err("__dlpack__ may only be called once")),
        }
    }
}

#[cfg(all(feature = "python", feature = "cuda"))]
impl FisherDeviceArrayF32Py {
    /// Wraps a device buffer together with the context guard that keeps it alive.
    pub fn new_from_rust(inner: DeviceArrayF32, ctx_guard: Arc<Context>, device_id: u32) -> Self {
        let wrapped = DeviceArrayF32Py {
            inner,
            _ctx: Some(ctx_guard),
            device_id: Some(device_id),
        };
        Self {
            inner: Some(wrapped),
        }
    }
}
2262
/// Python entry point: Fisher period sweep on the GPU, results left on device.
///
/// Returns a dict with device handles `"fisher"` and `"signal"`, the per-row
/// `"periods"`, and the matrix shape as `"rows"`/`"cols"`.
///
/// Raises `ValueError` when CUDA is unavailable, when `high_f32`/`low_f32` differ
/// in length, or when the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "fisher_cuda_batch_dev")]
#[pyo3(signature = (high_f32, low_f32, period_range, device_id=0))]
pub fn fisher_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    high_f32: numpy::PyReadonlyArray1<'py, f32>,
    low_f32: numpy::PyReadonlyArray1<'py, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let high = high_f32.as_slice()?;
    let low = low_f32.as_slice()?;
    // Same contract as the CPU `fisher_batch`: "cols" below is only meaningful
    // when both series have the same length.
    if high.len() != low.len() {
        return Err(PyValueError::new_err(format!(
            "Mismatched data length: high={}, low={}",
            high.len(),
            low.len()
        )));
    }
    let sweep = FisherBatchRange {
        period: period_range,
    };
    let cuda = CudaFisher::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let ctx_guard = cuda.context_arc();
    let dev_id = cuda.device_id();
    let (pair, combos) = py.allow_threads(|| {
        cuda.fisher_batch_dev(high, low, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;
    let dict = pyo3::types::PyDict::new(py);
    dict.set_item(
        "fisher",
        Py::new(
            py,
            FisherDeviceArrayF32Py::new_from_rust(pair.fisher, ctx_guard.clone(), dev_id),
        )?,
    )?;
    dict.set_item(
        "signal",
        Py::new(
            py,
            FisherDeviceArrayF32Py::new_from_rust(pair.signal, ctx_guard, dev_id),
        )?,
    )?;
    dict.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item("rows", combos.len())?;
    dict.set_item("cols", high.len())?;
    Ok(dict)
}
2316
/// Python entry point: Fisher for many series sharing one `period`, on the GPU.
///
/// `high_tm_f32`/`low_tm_f32` are passed unchanged, together with `cols`/`rows`,
/// to the CUDA wrapper's time-major routine. Returns a dict with device handles
/// `"fisher"` and `"signal"` plus the echoed `"rows"`/`"cols"`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "fisher_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm_f32, low_tm_f32, cols, rows, period, device_id=0))]
pub fn fisher_cuda_many_series_one_param_dev_py<'py>(
    py: Python<'py>,
    high_tm_f32: numpy::PyReadonlyArray1<'py, f32>,
    low_tm_f32: numpy::PyReadonlyArray1<'py, f32>,
    cols: usize,
    rows: usize,
    period: usize,
    device_id: usize,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let (high_tm, low_tm) = (high_tm_f32.as_slice()?, low_tm_f32.as_slice()?);

    let cuda = CudaFisher::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let (ctx_guard, dev_id) = (cuda.context_arc(), cuda.device_id());

    // The device work runs with the GIL released; the error is converted to a
    // PyErr inside the closure so only PyResult crosses the boundary.
    let pair = py.allow_threads(|| {
        cuda.fisher_many_series_one_param_time_major_dev(high_tm, low_tm, cols, rows, period)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    let dict = pyo3::types::PyDict::new(py);
    for (key, buffer) in [("fisher", pair.fisher), ("signal", pair.signal)] {
        let handle =
            FisherDeviceArrayF32Py::new_from_rust(buffer, Arc::clone(&ctx_guard), dev_id);
        dict.set_item(key, Py::new(py, handle)?)?;
    }
    dict.set_item("rows", rows)?;
    dict.set_item("cols", cols)?;
    Ok(dict)
}
2360
/// Flat result returned to JS by `fisher_js`: `values` holds `rows` rows of
/// `cols` values each (in `fisher_js`, the Fisher row followed by the signal row).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub struct FisherResult {
    values: Vec<f64>,
    rows: usize,
    cols: usize,
}

#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
impl FisherResult {
    /// Copy of the flattened row-major buffer.
    #[wasm_bindgen(getter)]
    pub fn values(&self) -> Vec<f64> {
        self.values.to_vec()
    }

    /// Number of rows in `values`.
    #[wasm_bindgen(getter)]
    pub fn rows(&self) -> usize {
        let Self { rows, .. } = self;
        *rows
    }

    /// Number of columns per row.
    #[wasm_bindgen(getter)]
    pub fn cols(&self) -> usize {
        let Self { cols, .. } = self;
        *cols
    }
}
2387
/// JS entry point: Fisher Transform of equal-length `high`/`low` for one `period`.
///
/// The returned `FisherResult` has 2 rows of `high.len()` values: the Fisher
/// series first, then the signal series.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn fisher_js(high: &[f64], low: &[f64], period: usize) -> Result<FisherResult, JsValue> {
    if high.len() != low.len() {
        return Err(JsValue::from_str("Mismatched data length"));
    }
    let n = high.len();
    let input = FisherInput::from_slices(
        high,
        low,
        FisherParams {
            period: Some(period),
        },
    );

    let total = match n.checked_mul(2) {
        Some(t) => t,
        None => return Err(JsValue::from_str("fisher_js: len*2 overflow")),
    };

    let mut values = vec![0.0_f64; total];
    {
        let (fisher_half, signal_half) = values.split_at_mut(n);
        fisher_into_slice(fisher_half, signal_half, &input, detect_best_kernel())
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(FisherResult {
        values,
        rows: 2,
        cols: n,
    })
}
2415
/// Pointer-based JS entry point: writes Fisher then signal (`2 * len` values) to
/// `out_ptr`.
///
/// `out_ptr` may overlap `high_ptr`/`low_ptr` (e.g. JS reusing an input buffer):
/// in that case the result is computed into a scratch buffer and copied out
/// afterwards, so inputs are never read after being overwritten.
///
/// # Errors
/// Null pointers, `len * 2` overflow, or any error from the Fisher kernel.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn fisher_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if high_ptr.is_null() || low_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to fisher_into"));
    }
    let total = len
        .checked_mul(2)
        .ok_or_else(|| JsValue::from_str("fisher_into: len*2 overflow"))?;

    /// True when the f64 ranges `[a, a + a_len)` and `[b, b + b_len)` share a byte.
    fn overlaps(a: *const f64, a_len: usize, b: *const f64, b_len: usize) -> bool {
        let elem = std::mem::size_of::<f64>();
        let (a0, b0) = (a as usize, b as usize);
        let a1 = a0.saturating_add(a_len.saturating_mul(elem));
        let b1 = b0.saturating_add(b_len.saturating_mul(elem));
        a0 < b1 && b0 < a1
    }

    let out_start = out_ptr as *const f64;
    let aliased =
        overlaps(out_start, total, high_ptr, len) || overlaps(out_start, total, low_ptr, len);

    let params = FisherParams {
        period: Some(period),
    };
    let kernel = detect_best_kernel();

    if aliased {
        let mut scratch = vec![0.0_f64; total];
        {
            // SAFETY: the caller guarantees both inputs hold `len` readable f64s; the
            // borrows end with this block, before the output is written.
            let (high, low) = unsafe {
                (
                    std::slice::from_raw_parts(high_ptr, len),
                    std::slice::from_raw_parts(low_ptr, len),
                )
            };
            let input = FisherInput::from_slices(high, low, params);
            let (fisher_half, signal_half) = scratch.split_at_mut(len);
            fisher_into_slice(fisher_half, signal_half, &input, kernel)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
        // SAFETY: `out_ptr` is valid for `total` writes (caller contract) and the
        // scratch buffer is a fresh allocation that cannot overlap it.
        unsafe { std::ptr::copy_nonoverlapping(scratch.as_ptr(), out_ptr, total) };
        return Ok(());
    }

    // SAFETY: caller contract as above, and the output region was just checked not to
    // overlap either input, so the shared and mutable slices are disjoint.
    unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let out = std::slice::from_raw_parts_mut(out_ptr, total);
        let (fisher_out, signal_out) = out.split_at_mut(len);
        let input = FisherInput::from_slices(high, low, params);
        fisher_into_slice(fisher_out, signal_out, &input, kernel)
            .map_err(|e| JsValue::from_str(&e.to_string()))
    }
}
2449
/// Allocates room for `len` f64 values in wasm memory and leaks it to the caller.
///
/// The buffer is uninitialised; release it with `fisher_free(ptr, len)` using the
/// same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn fisher_alloc(len: usize) -> *mut f64 {
    let mut buf = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buf.as_mut_ptr()
}

/// Releases a buffer obtained from `fisher_alloc(len)`; a null pointer is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn fisher_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr`/`len` must come from `fisher_alloc`, which allocated exactly
    // `len` capacity for f64; rebuilding the Vec and dropping it frees that block.
    drop(unsafe { Vec::from_raw_parts(ptr, len, len) });
}
2468
/// Configuration object deserialized from JS by the `fisher_batch` binding.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct FisherBatchConfig {
    /// Period sweep as `(start, end, step)`; forwarded unchanged to `FisherBatchRange::period`.
    pub period_range: (usize, usize, usize),
}

/// Result serialized back to JS by the `fisher_batch` binding.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct FisherBatchJsOutput {
    /// Row-major values. `fisher_batch` interleaves rows per combination:
    /// fisher(combo 0), signal(combo 0), fisher(combo 1), signal(combo 1), ...
    pub values: Vec<f64>,
    /// Number of rows in `values` (two per parameter combination in `fisher_batch`).
    pub rows: usize,
    /// Values per row (the input length).
    pub cols: usize,
    /// Parameter combination for each fisher/signal row pair, in row order.
    pub combos: Vec<FisherParams>,
}
2483
/// JS entry point: Fisher period sweep driven by a `{ period_range }` config.
///
/// Returns a `FisherBatchJsOutput` whose `values` interleave, for every period,
/// the Fisher row followed by the signal row (hence `rows = 2 * combos`).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = fisher_batch)]
pub fn fisher_batch_unified_js(
    high: &[f64],
    low: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    if high.len() != low.len() {
        return Err(JsValue::from_str("Mismatched data length"));
    }
    let FisherBatchConfig { period_range } =
        serde_wasm_bindgen::from_value::<FisherBatchConfig>(config)
            .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = FisherBatchRange {
        period: period_range,
    };
    let combos = expand_grid(&sweep);
    let (rows, cols) = (combos.len(), high.len());

    let cells = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("fisher_batch_unified_js: rows*cols overflow"))?;

    let mut fisher_buf = vec![0.0; cells];
    let mut signal_buf = vec![0.0; cells];
    fisher_batch_inner_into(
        high,
        low,
        &sweep,
        detect_best_kernel(),
        false,
        &mut fisher_buf,
        &mut signal_buf,
    )
    .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let capacity = cells
        .checked_mul(2)
        .ok_or_else(|| JsValue::from_str("fisher_batch_unified_js: values capacity overflow"))?;
    let mut values = Vec::with_capacity(capacity);
    for row in 0..rows {
        let span = row * cols..(row + 1) * cols;
        values.extend_from_slice(&fisher_buf[span.clone()]);
        values.extend_from_slice(&signal_buf[span]);
    }

    let payload = FisherBatchJsOutput {
        values,
        rows: 2 * rows,
        cols,
        combos,
    };
    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2540
/// Pointer-based JS entry point for the period sweep.
///
/// Writes `rows * len` Fisher values to `fisher_ptr` and `rows * len` signal
/// values to `signal_ptr` (row-major, one row per period) and returns `rows`.
///
/// The sweep is computed into owned buffers before any destination is touched,
/// so the outputs may alias the inputs. If the two outputs alias each other, the
/// signal values are written last.
///
/// # Errors
/// Null pointers, `rows * len` overflow, a kernel error, or an unexpected output size.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn fisher_batch_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    fisher_ptr: *mut f64,
    signal_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if high_ptr.is_null() || low_ptr.is_null() || fisher_ptr.is_null() || signal_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    let sweep = FisherBatchRange {
        period: (period_start, period_end, period_step),
    };
    let rows = expand_grid(&sweep).len();
    let total_size = rows
        .checked_mul(len)
        .ok_or_else(|| JsValue::from_str("fisher_batch_into: rows*len overflow"))?;

    let output = {
        // SAFETY: the caller guarantees both inputs hold `len` readable f64s. No
        // mutable view of the destinations exists while these borrows are alive.
        let (high, low) = unsafe {
            (
                std::slice::from_raw_parts(high_ptr, len),
                std::slice::from_raw_parts(low_ptr, len),
            )
        };
        fisher_batch_inner(high, low, &sweep, Kernel::Auto, false)
            .map_err(|e| JsValue::from_str(&e.to_string()))?
    };

    // Guard the raw copies below; a size disagreement would otherwise over- or
    // under-write the caller's buffers.
    if output.fisher.len() != total_size || output.signal.len() != total_size {
        return Err(JsValue::from_str("fisher_batch_into: output size mismatch"));
    }

    // SAFETY: the caller guarantees both destinations are valid for `total_size`
    // writes; the sources are freshly allocated Vecs and cannot overlap them.
    unsafe {
        std::ptr::copy_nonoverlapping(output.fisher.as_ptr(), fisher_ptr, total_size);
        std::ptr::copy_nonoverlapping(output.signal.as_ptr(), signal_ptr, total_size);
    }

    Ok(rows)
}
2583
/// Deprecated reusable wasm handle that remembers a period and a kernel chosen once
/// at construction.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
#[deprecated(
    since = "1.0.0",
    note = "For streaming patterns, use the fast/unsafe API with persistent buffers"
)]
pub struct FisherContext {
    period: usize,
    kernel: Kernel,
}

#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
#[allow(deprecated)]
impl FisherContext {
    /// Creates a context for `period`, picking the kernel via `detect_best_kernel()`.
    ///
    /// # Errors
    /// Returns an error when `period` is 0.
    #[wasm_bindgen(constructor)]
    #[deprecated(
        since = "1.0.0",
        note = "For streaming patterns, use the fast/unsafe API with persistent buffers"
    )]
    pub fn new(period: usize) -> Result<FisherContext, JsValue> {
        if period == 0 {
            return Err(JsValue::from_str("Invalid period: 0"));
        }

        Ok(FisherContext {
            period,
            kernel: detect_best_kernel(),
        })
    }

    /// Computes Fisher and signal for `len` bars into `fisher_ptr`/`signal_ptr`.
    ///
    /// When an output buffer overlaps an input (or the two outputs overlap each
    /// other), the result is computed into scratch buffers and copied out
    /// afterwards, so inputs are never read after being overwritten. In the
    /// outputs-overlap case the signal values are written last.
    pub fn update_into(
        &self,
        high_ptr: *const f64,
        low_ptr: *const f64,
        fisher_ptr: *mut f64,
        signal_ptr: *mut f64,
        len: usize,
    ) -> Result<(), JsValue> {
        if high_ptr.is_null() || low_ptr.is_null() || fisher_ptr.is_null() || signal_ptr.is_null() {
            return Err(JsValue::from_str("Null pointer provided"));
        }

        /// True when the `len`-element f64 ranges starting at `a` and `b` share a byte.
        fn overlaps(a: *const f64, b: *const f64, len: usize) -> bool {
            let bytes = len.saturating_mul(std::mem::size_of::<f64>());
            let (a, b) = (a as usize, b as usize);
            a < b.saturating_add(bytes) && b < a.saturating_add(bytes)
        }

        let fisher_dst = fisher_ptr as *const f64;
        let signal_dst = signal_ptr as *const f64;
        let needs_scratch = overlaps(fisher_dst, signal_dst, len)
            || [high_ptr, low_ptr]
                .iter()
                .any(|&src| overlaps(src, fisher_dst, len) || overlaps(src, signal_dst, len));

        let params = FisherParams {
            period: Some(self.period),
        };

        if needs_scratch {
            let mut fisher_tmp = vec![0.0_f64; len];
            let mut signal_tmp = vec![0.0_f64; len];
            {
                // SAFETY: the caller guarantees both inputs hold `len` readable f64s;
                // the borrows end with this block, before any output is written.
                let (high, low) = unsafe {
                    (
                        std::slice::from_raw_parts(high_ptr, len),
                        std::slice::from_raw_parts(low_ptr, len),
                    )
                };
                let input = FisherInput::from_slices(high, low, params);
                fisher_into_slice(&mut fisher_tmp, &mut signal_tmp, &input, self.kernel)
                    .map_err(|e| JsValue::from_str(&e.to_string()))?;
            }
            // SAFETY: destinations are valid for `len` writes (caller contract); the
            // scratch buffers are fresh allocations and cannot overlap them.
            unsafe {
                std::ptr::copy_nonoverlapping(fisher_tmp.as_ptr(), fisher_ptr, len);
                std::ptr::copy_nonoverlapping(signal_tmp.as_ptr(), signal_ptr, len);
            }
            return Ok(());
        }

        // SAFETY: caller contract as above, and all four regions were just checked to
        // be pairwise disjoint wherever a mutable slice is involved.
        unsafe {
            let high = std::slice::from_raw_parts(high_ptr, len);
            let low = std::slice::from_raw_parts(low_ptr, len);
            let fisher_out = std::slice::from_raw_parts_mut(fisher_ptr, len);
            let signal_out = std::slice::from_raw_parts_mut(signal_ptr, len);

            let input = FisherInput::from_slices(high, low, params);

            fisher_into_slice(fisher_out, signal_out, &input, self.kernel)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }

        Ok(())
    }

    /// The period this context was created with.
    pub fn get_period(&self) -> usize {
        self.period
    }

    /// `period - 1`; this does not account for any leading NaN bars in the input.
    /// `new` rejects `period == 0`, so this cannot underflow.
    pub fn get_warmup_period(&self) -> usize {
        self.period - 1
    }
}