1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::PyDict;
9
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use serde::{Deserialize, Serialize};
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use wasm_bindgen::prelude::*;
14
15use crate::utilities::data_loader::{source_type, Candles};
16use crate::utilities::enums::Kernel;
17use crate::utilities::helpers::{
18 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
19 make_uninit_matrix,
20};
21#[cfg(feature = "python")]
22use crate::utilities::kernel_validation::validate_kernel;
23
24#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
25use core::arch::x86_64::*;
26#[cfg(not(target_arch = "wasm32"))]
27use rayon::prelude::*;
28use std::collections::VecDeque;
29use std::convert::AsRef;
30use std::error::Error;
31use std::mem::MaybeUninit;
32use thiserror::Error;
33
34#[cfg(all(feature = "python", feature = "cuda"))]
35use crate::cuda::cuda_available;
36#[cfg(all(feature = "python", feature = "cuda"))]
37use crate::cuda::moving_averages::DeviceArrayF32;
38#[cfg(all(feature = "python", feature = "cuda"))]
39use crate::cuda::oscillators::CudaFisher;
40#[cfg(all(feature = "python", feature = "cuda"))]
41use crate::utilities::dlpack_cuda::DeviceArrayF32Py;
42#[cfg(all(feature = "python", feature = "cuda"))]
43use cust::context::Context;
44#[cfg(all(feature = "python", feature = "cuda"))]
45use std::sync::Arc;
46
47impl<'a> FisherInput<'a> {
48 #[inline(always)]
49 pub fn as_ref(&self) -> (&'a [f64], &'a [f64]) {
50 match &self.data {
51 FisherData::Candles { candles } => {
52 (source_type(candles, "high"), source_type(candles, "low"))
53 }
54 FisherData::Slices { high, low } => (*high, *low),
55 }
56 }
57}
58
59#[derive(Debug, Clone)]
60pub enum FisherData<'a> {
61 Candles { candles: &'a Candles },
62 Slices { high: &'a [f64], low: &'a [f64] },
63}
64
/// Result of a single-series Fisher Transform run.
///
/// `signal[i]` holds the Fisher value of the previous bar, so the two series
/// form the usual Fisher / trigger pair.
#[derive(Clone, Debug)]
pub struct FisherOutput {
    /// Fisher Transform value per input bar.
    pub fisher: Vec<f64>,
    /// One-bar-lagged Fisher value per input bar.
    pub signal: Vec<f64>,
}
70
/// Tunable parameters of the Fisher Transform.
#[derive(Clone, Debug)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct FisherParams {
    /// Look-back window for the midpoint min/max; `None` resolves to 9.
    pub period: Option<usize>,
}

impl Default for FisherParams {
    /// A 9-bar look-back window.
    fn default() -> Self {
        FisherParams { period: Some(9) }
    }
}
85
86#[derive(Debug, Clone)]
87pub struct FisherInput<'a> {
88 pub data: FisherData<'a>,
89 pub params: FisherParams,
90}
91
92impl<'a> FisherInput<'a> {
93 #[inline]
94 pub fn from_candles(candles: &'a Candles, params: FisherParams) -> Self {
95 Self {
96 data: FisherData::Candles { candles },
97 params,
98 }
99 }
100
101 #[inline(always)]
102 pub fn get_high_low(&self) -> (&'a [f64], &'a [f64]) {
103 match &self.data {
104 FisherData::Candles { candles } => {
105 let high = source_type(candles, "high");
106 let low = source_type(candles, "low");
107 (high, low)
108 }
109 FisherData::Slices { high, low } => (*high, *low),
110 }
111 }
112 #[inline]
113 pub fn from_slices(high: &'a [f64], low: &'a [f64], params: FisherParams) -> Self {
114 Self {
115 data: FisherData::Slices { high, low },
116 params,
117 }
118 }
119 #[inline]
120 pub fn with_default_candles(candles: &'a Candles) -> Self {
121 Self::from_candles(candles, FisherParams::default())
122 }
123 #[inline]
124 pub fn get_period(&self) -> usize {
125 self.params.period.unwrap_or(9)
126 }
127}
128
129#[derive(Copy, Clone, Debug)]
130pub struct FisherBuilder {
131 period: Option<usize>,
132 kernel: Kernel,
133}
134
135impl Default for FisherBuilder {
136 fn default() -> Self {
137 Self {
138 period: None,
139 kernel: Kernel::Auto,
140 }
141 }
142}
143
144impl FisherBuilder {
145 #[inline(always)]
146 pub fn new() -> Self {
147 Self::default()
148 }
149 #[inline(always)]
150 pub fn period(mut self, n: usize) -> Self {
151 self.period = Some(n);
152 self
153 }
154 #[inline(always)]
155 pub fn kernel(mut self, k: Kernel) -> Self {
156 self.kernel = k;
157 self
158 }
159 #[inline(always)]
160 pub fn apply(self, c: &Candles) -> Result<FisherOutput, FisherError> {
161 let p = FisherParams {
162 period: self.period,
163 };
164 let i = FisherInput::from_candles(c, p);
165 fisher_with_kernel(&i, self.kernel)
166 }
167 #[inline(always)]
168 pub fn apply_slices(self, high: &[f64], low: &[f64]) -> Result<FisherOutput, FisherError> {
169 let p = FisherParams {
170 period: self.period,
171 };
172 let i = FisherInput::from_slices(high, low, p);
173 fisher_with_kernel(&i, self.kernel)
174 }
175 #[inline(always)]
176 pub fn into_stream(self) -> Result<FisherStream, FisherError> {
177 let p = FisherParams {
178 period: self.period,
179 };
180 FisherStream::try_new(p)
181 }
182}
183
184#[derive(Debug, Error)]
185pub enum FisherError {
186 #[error("fisher: Empty data provided.")]
187 EmptyData,
188
189 #[error("fisher: Empty input data.")]
190 EmptyInputData,
191 #[error("fisher: Invalid period: period = {period}, data length = {data_len}")]
192 InvalidPeriod { period: usize, data_len: usize },
193 #[error("fisher: Not enough valid data: needed = {needed}, valid = {valid}")]
194 NotEnoughValidData { needed: usize, valid: usize },
195 #[error("fisher: All values are NaN.")]
196 AllValuesNaN,
197
198 #[error("fisher: Invalid output length: expected = {expected}, actual = {actual}")]
199 InvalidLength { expected: usize, actual: usize },
200
201 #[error("fisher: Output length mismatch: expected = {expected}, got = {got}")]
202 OutputLengthMismatch { expected: usize, got: usize },
203 #[error("fisher: Mismatched data length: high={high}, low={low}")]
204 MismatchedDataLength { high: usize, low: usize },
205
206 #[error("fisher: Invalid range expansion: start={start}, end={end}, step={step}")]
207 InvalidRange {
208 start: usize,
209 end: usize,
210 step: usize,
211 },
212 #[error("fisher: Invalid kernel for batch path: {0:?}")]
213 InvalidKernelForBatch(crate::utilities::enums::Kernel),
214}
215
216#[inline]
217pub fn fisher(input: &FisherInput) -> Result<FisherOutput, FisherError> {
218 fisher_with_kernel(input, Kernel::Auto)
219}
220
221pub fn fisher_with_kernel(
222 input: &FisherInput,
223 kernel: Kernel,
224) -> Result<FisherOutput, FisherError> {
225 let (high, low) = input.get_high_low();
226
227 if high.is_empty() || low.is_empty() {
228 return Err(FisherError::EmptyInputData);
229 }
230 if high.len() != low.len() {
231 return Err(FisherError::MismatchedDataLength {
232 high: high.len(),
233 low: low.len(),
234 });
235 }
236
237 let period = input.get_period();
238 let data_len = high.len();
239 if period == 0 || period > data_len {
240 return Err(FisherError::InvalidPeriod { period, data_len });
241 }
242
243 let mut first = None;
244 for i in 0..data_len {
245 if !high[i].is_nan() && !low[i].is_nan() {
246 first = Some(i);
247 break;
248 }
249 }
250 let first = first.ok_or(FisherError::AllValuesNaN)?;
251
252 if (data_len - first) < period {
253 return Err(FisherError::NotEnoughValidData {
254 needed: period,
255 valid: data_len - first,
256 });
257 }
258
259 let chosen = match kernel {
260 Kernel::Auto => detect_best_kernel(),
261 other => other,
262 };
263
264 let warmup = first + period - 1;
265 let mut fisher_vals = alloc_with_nan_prefix(data_len, warmup);
266 let mut signal_vals = alloc_with_nan_prefix(data_len, warmup);
267
268 unsafe {
269 match chosen {
270 Kernel::Scalar | Kernel::ScalarBatch => {
271 fisher_scalar_into(high, low, period, first, &mut fisher_vals, &mut signal_vals)
272 }
273 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
274 Kernel::Avx2 | Kernel::Avx2Batch => {
275 fisher_avx2_into(high, low, period, first, &mut fisher_vals, &mut signal_vals)
276 }
277 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
278 Kernel::Avx512 | Kernel::Avx512Batch => {
279 fisher_avx512_into(high, low, period, first, &mut fisher_vals, &mut signal_vals)
280 }
281 _ => unreachable!(),
282 }
283 }
284
285 Ok(FisherOutput {
286 fisher: fisher_vals,
287 signal: signal_vals,
288 })
289}
290
291#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
292#[inline]
293pub fn fisher_into(
294 input: &FisherInput,
295 fisher_out: &mut [f64],
296 signal_out: &mut [f64],
297) -> Result<(), FisherError> {
298 fisher_into_slice(fisher_out, signal_out, input, Kernel::Auto)
299}
300
301#[inline]
302pub fn fisher_scalar_into(
303 high: &[f64],
304 low: &[f64],
305 period: usize,
306 first: usize,
307 fisher_out: &mut [f64],
308 signal_out: &mut [f64],
309) {
310 let len = high.len().min(low.len());
311 if period == 0 || first >= len {
312 return;
313 }
314
315 let mut prev_fish = 0.0f64;
316 let mut val1 = 0.0f64;
317 let warm = first + period - 1;
318
319 for i in warm..len {
320 let start = i + 1 - period;
321
322 let (mut min_val, mut max_val) = (f64::MAX, f64::MIN);
323 for j in start..=i {
324 let midpoint = 0.5 * (high[j] + low[j]);
325 if midpoint > max_val {
326 max_val = midpoint;
327 }
328 if midpoint < min_val {
329 min_val = midpoint;
330 }
331 }
332
333 let range = (max_val - min_val).max(0.001);
334 let hl = 0.5 * (high[i] + low[i]);
335 val1 = 0.67f64.mul_add(val1, 0.66 * ((hl - min_val) / range - 0.5));
336 if val1 > 0.99 {
337 val1 = 0.999;
338 } else if val1 < -0.99 {
339 val1 = -0.999;
340 }
341 signal_out[i] = prev_fish;
342 let new_fish = 0.5f64.mul_add(((1.0 + val1) / (1.0 - val1)).ln(), 0.5 * prev_fish);
343 fisher_out[i] = new_fish;
344 prev_fish = new_fish;
345 }
346}
347
/// AVX-512 single-series entry point.
///
/// There is no AVX-512-specific implementation yet; this delegates to
/// `fisher_scalar_into` with identical arguments.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn fisher_avx512_into(
    high: &[f64],
    low: &[f64],
    period: usize,
    first: usize,
    fisher_out: &mut [f64],
    signal_out: &mut [f64],
) {
    fisher_scalar_into(high, low, period, first, fisher_out, signal_out);
}
360
/// AVX2 single-series entry point.
///
/// There is no AVX2-specific implementation yet; this delegates to
/// `fisher_scalar_into` with identical arguments.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn fisher_avx2_into(
    high: &[f64],
    low: &[f64],
    period: usize,
    first: usize,
    fisher_out: &mut [f64],
    signal_out: &mut [f64],
) {
    fisher_scalar_into(high, low, period, first, fisher_out, signal_out);
}
373
374#[inline]
375pub fn fisher_into_slice(
376 fisher_dst: &mut [f64],
377 signal_dst: &mut [f64],
378 input: &FisherInput,
379 kern: Kernel,
380) -> Result<(), FisherError> {
381 let (high, low) = input.as_ref();
382 if high.is_empty() || low.is_empty() {
383 return Err(FisherError::EmptyData);
384 }
385 if high.len() != low.len() {
386 return Err(FisherError::MismatchedDataLength {
387 high: high.len(),
388 low: low.len(),
389 });
390 }
391
392 let data_len = high.len();
393 let period = input.params.period.unwrap_or(9);
394
395 let mut first = None;
396 for i in 0..data_len {
397 if !high[i].is_nan() && !low[i].is_nan() {
398 first = Some(i);
399 break;
400 }
401 }
402 let first = first.ok_or(FisherError::AllValuesNaN)?;
403
404 if period == 0 || period > data_len {
405 return Err(FisherError::InvalidPeriod { period, data_len });
406 }
407 if fisher_dst.len() != data_len || signal_dst.len() != data_len {
408 return Err(FisherError::OutputLengthMismatch {
409 expected: data_len,
410 got: fisher_dst.len().min(signal_dst.len()),
411 });
412 }
413
414 let chosen = if kern == Kernel::Auto {
415 detect_best_kernel()
416 } else {
417 kern
418 };
419
420 match chosen {
421 Kernel::Scalar | Kernel::ScalarBatch => {
422 fisher_scalar_into(high, low, period, first, fisher_dst, signal_dst)
423 }
424 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
425 Kernel::Avx2 | Kernel::Avx2Batch => {
426 fisher_avx2_into(high, low, period, first, fisher_dst, signal_dst)
427 }
428 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
429 Kernel::Avx512 | Kernel::Avx512Batch => {
430 fisher_avx512_into(high, low, period, first, fisher_dst, signal_dst)
431 }
432 _ => unreachable!(),
433 }
434
435 let warmup_end = first + period - 1;
436 for i in 0..warmup_end {
437 fisher_dst[i] = f64::NAN;
438 signal_dst[i] = f64::NAN;
439 }
440
441 Ok(())
442}
443
444#[inline]
445pub fn fisher_batch_with_kernel(
446 high: &[f64],
447 low: &[f64],
448 sweep: &FisherBatchRange,
449 k: Kernel,
450) -> Result<FisherBatchOutput, FisherError> {
451 let kernel = match k {
452 Kernel::Auto => detect_best_batch_kernel(),
453 other if other.is_batch() => other,
454 other => return Err(FisherError::InvalidKernelForBatch(other)),
455 };
456 let simd = match kernel {
457 Kernel::Avx512Batch => Kernel::Avx512,
458 Kernel::Avx2Batch => Kernel::Avx2,
459 Kernel::ScalarBatch => Kernel::Scalar,
460 _ => unreachable!(),
461 };
462 fisher_batch_par_slice(high, low, sweep, simd)
463}
464
/// Inclusive `(start, end, step)` sweep of the Fisher period for batch runs.
///
/// A `step` of 0, or `start == end`, yields the single period `start`.
#[derive(Clone, Debug)]
pub struct FisherBatchRange {
    pub period: (usize, usize, usize),
}

impl Default for FisherBatchRange {
    /// Periods 9 through 258 in steps of 1.
    fn default() -> Self {
        FisherBatchRange {
            period: (9, 258, 1),
        }
    }
}
477
478#[derive(Clone, Debug, Default)]
479pub struct FisherBatchBuilder {
480 range: FisherBatchRange,
481 kernel: Kernel,
482}
483
484impl FisherBatchBuilder {
485 pub fn new() -> Self {
486 Self::default()
487 }
488 pub fn kernel(mut self, k: Kernel) -> Self {
489 self.kernel = k;
490 self
491 }
492 #[inline]
493 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
494 self.range.period = (start, end, step);
495 self
496 }
497 #[inline]
498 pub fn period_static(mut self, p: usize) -> Self {
499 self.range.period = (p, p, 0);
500 self
501 }
502 pub fn apply_slices(self, high: &[f64], low: &[f64]) -> Result<FisherBatchOutput, FisherError> {
503 fisher_batch_with_kernel(high, low, &self.range, self.kernel)
504 }
505 pub fn with_default_slices(
506 high: &[f64],
507 low: &[f64],
508 k: Kernel,
509 ) -> Result<FisherBatchOutput, FisherError> {
510 FisherBatchBuilder::new().kernel(k).apply_slices(high, low)
511 }
512 pub fn apply_candles(self, c: &Candles) -> Result<FisherBatchOutput, FisherError> {
513 let high = source_type(c, "high");
514 let low = source_type(c, "low");
515 self.apply_slices(high, low)
516 }
517 pub fn with_default_candles(c: &Candles) -> Result<FisherBatchOutput, FisherError> {
518 FisherBatchBuilder::new()
519 .kernel(Kernel::Auto)
520 .apply_candles(c)
521 }
522}
523
524#[derive(Clone, Debug)]
525pub struct FisherBatchOutput {
526 pub fisher: Vec<f64>,
527 pub signal: Vec<f64>,
528 pub combos: Vec<FisherParams>,
529 pub rows: usize,
530 pub cols: usize,
531}
532impl FisherBatchOutput {
533 pub fn row_for_params(&self, p: &FisherParams) -> Option<usize> {
534 self.combos
535 .iter()
536 .position(|c| c.period.unwrap_or(9) == p.period.unwrap_or(9))
537 }
538 pub fn fisher_for(&self, p: &FisherParams) -> Option<&[f64]> {
539 self.row_for_params(p).map(|row| {
540 let start = row * self.cols;
541 &self.fisher[start..start + self.cols]
542 })
543 }
544 pub fn signal_for(&self, p: &FisherParams) -> Option<&[f64]> {
545 self.row_for_params(p).map(|row| {
546 let start = row * self.cols;
547 &self.signal[start..start + self.cols]
548 })
549 }
550}
551
552#[inline(always)]
553fn expand_grid(r: &FisherBatchRange) -> Vec<FisherParams> {
554 fn axis_usize((start, end, step): (usize, usize, usize)) -> Vec<usize> {
555 if step == 0 || start == end {
556 return vec![start];
557 }
558 let mut out = Vec::new();
559 if start < end {
560 if step == 0 {
561 return vec![start];
562 }
563 let mut v = start;
564 while v <= end {
565 out.push(v);
566 match v.checked_add(step) {
567 Some(n) => v = n,
568 None => break,
569 }
570 }
571 } else {
572 if step == 0 {
573 return vec![start];
574 }
575 let mut v = start;
576 while v >= end {
577 out.push(v);
578 if v < end + step {
579 break;
580 }
581 v -= step;
582 if v == 0 && end > 0 && step > 0 {
583 if v < end {
584 break;
585 }
586 }
587 }
588 }
589 out
590 }
591 let periods = axis_usize(r.period);
592 let mut out = Vec::with_capacity(periods.len());
593 for &p in &periods {
594 out.push(FisherParams { period: Some(p) });
595 }
596 out
597}
598
599#[inline(always)]
600pub fn fisher_batch_slice(
601 high: &[f64],
602 low: &[f64],
603 sweep: &FisherBatchRange,
604 kern: Kernel,
605) -> Result<FisherBatchOutput, FisherError> {
606 fisher_batch_inner(high, low, sweep, kern, false)
607}
608#[inline(always)]
609pub fn fisher_batch_par_slice(
610 high: &[f64],
611 low: &[f64],
612 sweep: &FisherBatchRange,
613 kern: Kernel,
614) -> Result<FisherBatchOutput, FisherError> {
615 fisher_batch_inner(high, low, sweep, kern, true)
616}
617
618#[inline(always)]
619fn fisher_batch_inner(
620 high: &[f64],
621 low: &[f64],
622 sweep: &FisherBatchRange,
623 kern: Kernel,
624 parallel: bool,
625) -> Result<FisherBatchOutput, FisherError> {
626 let combos = expand_grid(sweep);
627 if combos.is_empty() {
628 let (s, e, st) = sweep.period;
629 return Err(FisherError::InvalidRange {
630 start: s,
631 end: e,
632 step: st,
633 });
634 }
635
636 let data_len = high.len().min(low.len());
637
638 let mut first = None;
639 for i in 0..data_len {
640 if !high[i].is_nan() && !low[i].is_nan() {
641 first = Some(i);
642 break;
643 }
644 }
645 let first = first.ok_or(FisherError::AllValuesNaN)?;
646
647 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
648 if data_len - first < max_p {
649 return Err(FisherError::NotEnoughValidData {
650 needed: max_p,
651 valid: data_len - first,
652 });
653 }
654 let rows = combos.len();
655 let cols = data_len;
656
657 let _ = rows.checked_mul(cols).ok_or(FisherError::InvalidRange {
658 start: sweep.period.0,
659 end: sweep.period.1,
660 step: sweep.period.2,
661 })?;
662
663 let mut fisher_mu = make_uninit_matrix(rows, cols);
664 let mut signal_mu = make_uninit_matrix(rows, cols);
665
666 let mut warmup_periods: Vec<usize> = Vec::with_capacity(combos.len());
667 for c in &combos {
668 let p = c.period.unwrap_or(0);
669 let warm = first
670 .checked_add(p.saturating_sub(1))
671 .ok_or(FisherError::InvalidRange {
672 start: sweep.period.0,
673 end: sweep.period.1,
674 step: sweep.period.2,
675 })?;
676 warmup_periods.push(warm);
677 }
678
679 init_matrix_prefixes(&mut fisher_mu, cols, &warmup_periods);
680 init_matrix_prefixes(&mut signal_mu, cols, &warmup_periods);
681
682 let mut fisher_guard = core::mem::ManuallyDrop::new(fisher_mu);
683 let mut signal_guard = core::mem::ManuallyDrop::new(signal_mu);
684
685 let fisher_slice: &mut [f64] = unsafe {
686 core::slice::from_raw_parts_mut(fisher_guard.as_mut_ptr() as *mut f64, fisher_guard.len())
687 };
688 let signal_slice: &mut [f64] = unsafe {
689 core::slice::from_raw_parts_mut(signal_guard.as_mut_ptr() as *mut f64, signal_guard.len())
690 };
691
692 let hl: Vec<f64> = (0..data_len).map(|i| 0.5 * (high[i] + low[i])).collect();
693
694 let do_row = |row: usize, out_fish: &mut [f64], out_signal: &mut [f64]| {
695 let period = combos[row].period.unwrap();
696 match kern {
697 Kernel::Scalar | Kernel::ScalarBatch => {
698 fisher_row_scalar_from_hl(&hl, first, period, out_fish, out_signal)
699 }
700 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
701 Kernel::Avx2 | Kernel::Avx2Batch => {
702 fisher_row_avx2_direct(high, low, first, period, out_fish, out_signal)
703 }
704 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
705 Kernel::Avx512 | Kernel::Avx512Batch => {
706 fisher_row_avx512_direct(high, low, first, period, out_fish, out_signal)
707 }
708 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
709 Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
710 fisher_row_scalar_direct(high, low, first, period, out_fish, out_signal)
711 }
712 Kernel::Auto => fisher_row_scalar_from_hl(&hl, first, period, out_fish, out_signal),
713 }
714 };
715
716 if parallel {
717 #[cfg(not(target_arch = "wasm32"))]
718 {
719 fisher_slice
720 .par_chunks_mut(cols)
721 .zip(signal_slice.par_chunks_mut(cols))
722 .enumerate()
723 .for_each(|(row, (fish, sig))| do_row(row, fish, sig));
724 }
725
726 #[cfg(target_arch = "wasm32")]
727 {
728 for (row, (fish, sig)) in fisher_slice
729 .chunks_mut(cols)
730 .zip(signal_slice.chunks_mut(cols))
731 .enumerate()
732 {
733 do_row(row, fish, sig);
734 }
735 }
736 } else {
737 for (row, (fish, sig)) in fisher_slice
738 .chunks_mut(cols)
739 .zip(signal_slice.chunks_mut(cols))
740 .enumerate()
741 {
742 do_row(row, fish, sig);
743 }
744 }
745
746 let fisher = unsafe {
747 Vec::from_raw_parts(
748 fisher_guard.as_mut_ptr() as *mut f64,
749 fisher_guard.len(),
750 fisher_guard.capacity(),
751 )
752 };
753 let signal = unsafe {
754 Vec::from_raw_parts(
755 signal_guard.as_mut_ptr() as *mut f64,
756 signal_guard.len(),
757 signal_guard.capacity(),
758 )
759 };
760
761 core::mem::forget(fisher_guard);
762 core::mem::forget(signal_guard);
763
764 Ok(FisherBatchOutput {
765 fisher,
766 signal,
767 combos,
768 rows,
769 cols,
770 })
771}
772
773#[inline(always)]
774fn fisher_batch_inner_into(
775 high: &[f64],
776 low: &[f64],
777 sweep: &FisherBatchRange,
778 kern: Kernel,
779 parallel: bool,
780 fisher_out: &mut [f64],
781 signal_out: &mut [f64],
782) -> Result<Vec<FisherParams>, FisherError> {
783 let combos = expand_grid(sweep);
784 if combos.is_empty() {
785 let (s, e, st) = sweep.period;
786 return Err(FisherError::InvalidRange {
787 start: s,
788 end: e,
789 step: st,
790 });
791 }
792
793 let data_len = high.len().min(low.len());
794
795 let mut first = None;
796 for i in 0..data_len {
797 if !high[i].is_nan() && !low[i].is_nan() {
798 first = Some(i);
799 break;
800 }
801 }
802 let first = first.ok_or(FisherError::AllValuesNaN)?;
803
804 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
805 let _ = combos
806 .len()
807 .checked_mul(max_p)
808 .ok_or(FisherError::InvalidRange {
809 start: sweep.period.0,
810 end: sweep.period.1,
811 step: sweep.period.2,
812 })?;
813 if data_len - first < max_p {
814 return Err(FisherError::NotEnoughValidData {
815 needed: max_p,
816 valid: data_len - first,
817 });
818 }
819
820 let rows = combos.len();
821 let cols = data_len;
822 let _ = rows.checked_mul(cols).ok_or(FisherError::InvalidRange {
823 start: sweep.period.0,
824 end: sweep.period.1,
825 step: sweep.period.2,
826 })?;
827
828 for (row, combo) in combos.iter().enumerate() {
829 let p = combo.period.unwrap_or(0);
830 let warmup = first
831 .checked_add(p.saturating_sub(1))
832 .ok_or(FisherError::InvalidRange {
833 start: sweep.period.0,
834 end: sweep.period.1,
835 step: sweep.period.2,
836 })?;
837 let row_start = row * cols;
838 for i in 0..warmup.min(cols) {
839 fisher_out[row_start + i] = f64::NAN;
840 signal_out[row_start + i] = f64::NAN;
841 }
842 }
843
844 let hl: Vec<f64> = (0..data_len).map(|i| 0.5 * (high[i] + low[i])).collect();
845
846 let do_row = |row: usize, out_fish: &mut [f64], out_signal: &mut [f64]| {
847 let period = combos[row].period.unwrap();
848 match kern {
849 Kernel::Scalar | Kernel::ScalarBatch => {
850 fisher_row_scalar_from_hl(&hl, first, period, out_fish, out_signal)
851 }
852 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
853 Kernel::Avx2 | Kernel::Avx2Batch => {
854 fisher_row_avx2_direct(high, low, first, period, out_fish, out_signal)
855 }
856 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
857 Kernel::Avx512 | Kernel::Avx512Batch => {
858 fisher_row_avx512_direct(high, low, first, period, out_fish, out_signal)
859 }
860 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
861 Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
862 fisher_row_scalar_direct(high, low, first, period, out_fish, out_signal)
863 }
864 Kernel::Auto => fisher_row_scalar_from_hl(&hl, first, period, out_fish, out_signal),
865 }
866 };
867
868 if parallel {
869 #[cfg(not(target_arch = "wasm32"))]
870 {
871 fisher_out
872 .par_chunks_mut(cols)
873 .zip(signal_out.par_chunks_mut(cols))
874 .enumerate()
875 .for_each(|(row, (fish, sig))| do_row(row, fish, sig));
876 }
877
878 #[cfg(target_arch = "wasm32")]
879 {
880 for (row, (fish, sig)) in fisher_out
881 .chunks_mut(cols)
882 .zip(signal_out.chunks_mut(cols))
883 .enumerate()
884 {
885 do_row(row, fish, sig);
886 }
887 }
888 } else {
889 for (row, (fish, sig)) in fisher_out
890 .chunks_mut(cols)
891 .zip(signal_out.chunks_mut(cols))
892 .enumerate()
893 {
894 do_row(row, fish, sig);
895 }
896 }
897
898 Ok(combos)
899}
900
/// Computes one batch row directly from the high/low slices.
///
/// The arithmetic matches `fisher_scalar_into` and `fisher_row_scalar_from_hl`,
/// including the fused `mul_add` steps, so AVX-dispatched batch rows agree
/// bit-for-bit with the scalar and single-series paths. Before this change the
/// row used separate multiply/add, which rounds differently.
///
/// Writes indices `first + period - 1 ..` of `out_fish`/`out_signal`; earlier
/// indices are left untouched. `out_signal[i]` is the previous bar's Fisher
/// value (0.0 on the first computed bar). Returns immediately if
/// `period == 0` or `first` is not below the shorter input's length.
#[inline(always)]
fn fisher_row_scalar_direct(
    high: &[f64],
    low: &[f64],
    first: usize,
    period: usize,
    out_fish: &mut [f64],
    out_signal: &mut [f64],
) {
    let len = high.len().min(low.len());
    if period == 0 || first >= len {
        return;
    }

    let mut prev_fish = 0.0f64;
    let mut val1 = 0.0f64;

    let warm = first + period - 1;

    for i in warm..len {
        let start = i + 1 - period;

        let (mut min_val, mut max_val) = (f64::MAX, f64::MIN);
        for j in start..=i {
            let midpoint = 0.5 * (high[j] + low[j]);
            if midpoint > max_val {
                max_val = midpoint;
            }
            if midpoint < min_val {
                min_val = midpoint;
            }
        }

        // Floor the range so a flat window does not divide by zero.
        let range = (max_val - min_val).max(0.001);
        let hl = 0.5 * (high[i] + low[i]);
        val1 = 0.67f64.mul_add(val1, 0.66 * ((hl - min_val) / range - 0.5));
        // Keep the log argument finite.
        if val1 > 0.99 {
            val1 = 0.999;
        } else if val1 < -0.99 {
            val1 = -0.999;
        }
        out_signal[i] = prev_fish;
        let new_fish = 0.5f64.mul_add(((1.0 + val1) / (1.0 - val1)).ln(), 0.5 * prev_fish);
        out_fish[i] = new_fish;
        prev_fish = new_fish;
    }
}
948
949#[inline(always)]
950fn fisher_row_scalar_from_hl(
951 hl: &[f64],
952 first: usize,
953 period: usize,
954 out_fish: &mut [f64],
955 out_signal: &mut [f64],
956) {
957 let len = hl.len();
958 if period == 0 || first >= len {
959 return;
960 }
961
962 let mut prev_fish = 0.0f64;
963 let mut val1 = 0.0f64;
964 let warm = first + period - 1;
965
966 for i in warm..len {
967 let start = i + 1 - period;
968 let (mut min_val, mut max_val) = (f64::MAX, f64::MIN);
969 for &v in &hl[start..=i] {
970 if v > max_val {
971 max_val = v;
972 }
973 if v < min_val {
974 min_val = v;
975 }
976 }
977
978 let range = (max_val - min_val).max(0.001);
979 let v = hl[i];
980 val1 = 0.67f64.mul_add(val1, 0.66 * ((v - min_val) / range - 0.5));
981 if val1 > 0.99 {
982 val1 = 0.999;
983 } else if val1 < -0.99 {
984 val1 = -0.999;
985 }
986 out_signal[i] = prev_fish;
987 let new_fish = 0.5f64.mul_add(((1.0 + val1) / (1.0 - val1)).ln(), 0.5 * prev_fish);
988 out_fish[i] = new_fish;
989 prev_fish = new_fish;
990 }
991}
992
/// AVX2 batch-row entry point.
///
/// There is no AVX2-specific implementation yet; this forwards to
/// `fisher_row_scalar_direct` with identical arguments.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn fisher_row_avx2_direct(
    high: &[f64],
    low: &[f64],
    first: usize,
    period: usize,
    out_fish: &mut [f64],
    out_signal: &mut [f64],
) {
    fisher_row_scalar_direct(high, low, first, period, out_fish, out_signal);
}
1005
/// AVX-512 batch-row entry point.
///
/// There is no AVX-512-specific implementation yet; this forwards to
/// `fisher_row_scalar_direct` with identical arguments.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn fisher_row_avx512_direct(
    high: &[f64],
    low: &[f64],
    first: usize,
    period: usize,
    out_fish: &mut [f64],
    out_signal: &mut [f64],
) {
    fisher_row_scalar_direct(high, low, first, period, out_fish, out_signal);
}
1018
1019#[derive(Debug, Clone)]
1020pub struct FisherStream {
1021 period: usize,
1022
1023 idx: usize,
1024 filled: bool,
1025
1026 minq: VecDeque<(f64, usize)>,
1027 maxq: VecDeque<(f64, usize)>,
1028 prev_fish: f64,
1029 val1: f64,
1030}
1031
1032impl FisherStream {
1033 pub fn try_new(params: FisherParams) -> Result<Self, FisherError> {
1034 let period = params.period.unwrap_or(9);
1035 if period == 0 {
1036 return Err(FisherError::InvalidPeriod {
1037 period,
1038 data_len: 0,
1039 });
1040 }
1041 Ok(Self {
1042 period,
1043 idx: 0,
1044 filled: false,
1045 minq: VecDeque::with_capacity(period + 1),
1046 maxq: VecDeque::with_capacity(period + 1),
1047 prev_fish: 0.0,
1048 val1: 0.0,
1049 })
1050 }
1051
1052 #[inline(always)]
1053 pub fn update(&mut self, high: f64, low: f64) -> Option<(f64, f64)> {
1054 let v = 0.5 * (high + low);
1055 let k = self.idx;
1056
1057 while let Some(&(last_v, _)) = self.minq.back() {
1058 if last_v >= v {
1059 self.minq.pop_back();
1060 } else {
1061 break;
1062 }
1063 }
1064 self.minq.push_back((v, k));
1065
1066 while let Some(&(last_v, _)) = self.maxq.back() {
1067 if last_v <= v {
1068 self.maxq.pop_back();
1069 } else {
1070 break;
1071 }
1072 }
1073 self.maxq.push_back((v, k));
1074
1075 let start = k.saturating_sub(self.period - 1);
1076 while let Some(&(_, i)) = self.minq.front() {
1077 if i < start {
1078 self.minq.pop_front();
1079 } else {
1080 break;
1081 }
1082 }
1083 while let Some(&(_, i)) = self.maxq.front() {
1084 if i < start {
1085 self.maxq.pop_front();
1086 } else {
1087 break;
1088 }
1089 }
1090
1091 self.idx = k + 1;
1092
1093 if !self.filled {
1094 self.filled = self.idx >= self.period;
1095 if !self.filled {
1096 return None;
1097 }
1098 }
1099
1100 let min_val = self.minq.front().map(|&(x, _)| x).unwrap_or(v);
1101 let max_val = self.maxq.front().map(|&(x, _)| x).unwrap_or(v);
1102
1103 let range = (max_val - min_val).max(0.001);
1104
1105 self.val1 = 0.67f64.mul_add(self.val1, 0.66 * ((v - min_val) / range - 0.5));
1106 if self.val1 > 0.99 {
1107 self.val1 = 0.999;
1108 } else if self.val1 < -0.99 {
1109 self.val1 = -0.999;
1110 }
1111
1112 let signal = self.prev_fish;
1113 let fisher = 0.5f64.mul_add(
1114 ((1.0 + self.val1) / (1.0 - self.val1)).ln(),
1115 0.5 * self.prev_fish,
1116 );
1117
1118 self.prev_fish = fisher;
1119 Some((fisher, signal))
1120 }
1121}
1122
1123#[cfg(test)]
1124mod tests {
1125 use super::*;
1126 use crate::skip_if_unsupported;
1127 use crate::utilities::data_loader::read_candles_from_csv;
1128
1129 fn check_fisher_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1130 skip_if_unsupported!(kernel, test_name);
1131 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1132 let candles = read_candles_from_csv(file_path)?;
1133 let default_params = FisherParams { period: None };
1134 let input = FisherInput::from_candles(&candles, default_params);
1135 let output = fisher_with_kernel(&input, kernel)?;
1136 assert_eq!(output.fisher.len(), candles.close.len());
1137 Ok(())
1138 }
1139
1140 fn check_fisher_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1141 skip_if_unsupported!(kernel, test_name);
1142 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1143 let candles = read_candles_from_csv(file_path)?;
1144 let input = FisherInput::from_candles(&candles, FisherParams::default());
1145 let result = fisher_with_kernel(&input, kernel)?;
1146 let expected_last_five_fisher = [
1147 -0.4720164683904261,
1148 -0.23467530106650444,
1149 -0.14879388501136784,
1150 -0.026651419122953053,
1151 -0.2569225042442664,
1152 ];
1153 let start = result.fisher.len().saturating_sub(5);
1154 for (i, &val) in result.fisher[start..].iter().enumerate() {
1155 let diff = (val - expected_last_five_fisher[i]).abs();
1156 assert!(
1157 diff < 1e-1,
1158 "[{}] Fisher {:?} mismatch at idx {}: got {}, expected {}",
1159 test_name,
1160 kernel,
1161 i,
1162 val,
1163 expected_last_five_fisher[i]
1164 );
1165 }
1166 Ok(())
1167 }
1168
1169 fn check_fisher_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1170 skip_if_unsupported!(kernel, test_name);
1171 let high = [10.0, 20.0, 30.0];
1172 let low = [5.0, 15.0, 25.0];
1173 let params = FisherParams { period: Some(0) };
1174 let input = FisherInput::from_slices(&high, &low, params);
1175 let res = fisher_with_kernel(&input, kernel);
1176 assert!(
1177 res.is_err(),
1178 "[{}] Fisher should fail with zero period",
1179 test_name
1180 );
1181 Ok(())
1182 }
1183
1184 fn check_fisher_period_exceeds_length(
1185 test_name: &str,
1186 kernel: Kernel,
1187 ) -> Result<(), Box<dyn Error>> {
1188 skip_if_unsupported!(kernel, test_name);
1189 let high = [10.0, 20.0, 30.0];
1190 let low = [5.0, 15.0, 25.0];
1191 let params = FisherParams { period: Some(10) };
1192 let input = FisherInput::from_slices(&high, &low, params);
1193 let res = fisher_with_kernel(&input, kernel);
1194 assert!(
1195 res.is_err(),
1196 "[{}] Fisher should fail with period exceeding length",
1197 test_name
1198 );
1199 Ok(())
1200 }
1201
1202 fn check_fisher_very_small_dataset(
1203 test_name: &str,
1204 kernel: Kernel,
1205 ) -> Result<(), Box<dyn Error>> {
1206 skip_if_unsupported!(kernel, test_name);
1207 let high = [10.0];
1208 let low = [5.0];
1209 let params = FisherParams { period: Some(9) };
1210 let input = FisherInput::from_slices(&high, &low, params);
1211 let res = fisher_with_kernel(&input, kernel);
1212 assert!(
1213 res.is_err(),
1214 "[{}] Fisher should fail with insufficient data",
1215 test_name
1216 );
1217 Ok(())
1218 }
1219
1220 fn check_fisher_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1221 skip_if_unsupported!(kernel, test_name);
1222 let high = [10.0, 12.0, 14.0, 16.0, 18.0, 20.0];
1223 let low = [5.0, 7.0, 9.0, 10.0, 13.0, 15.0];
1224 let first_params = FisherParams { period: Some(3) };
1225 let first_input = FisherInput::from_slices(&high, &low, first_params);
1226 let first_result = fisher_with_kernel(&first_input, kernel)?;
1227 let second_params = FisherParams { period: Some(3) };
1228 let second_input =
1229 FisherInput::from_slices(&first_result.fisher, &first_result.signal, second_params);
1230 let second_result = fisher_with_kernel(&second_input, kernel)?;
1231 assert_eq!(first_result.fisher.len(), second_result.fisher.len());
1232 assert_eq!(first_result.signal.len(), second_result.signal.len());
1233 Ok(())
1234 }
1235
1236 fn check_fisher_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1237 skip_if_unsupported!(kernel, test_name);
1238 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1239 let candles = read_candles_from_csv(file_path)?;
1240 let input = FisherInput::from_candles(&candles, FisherParams::default());
1241 let res = fisher_with_kernel(&input, kernel)?;
1242 assert_eq!(res.fisher.len(), candles.close.len());
1243 if res.fisher.len() > 240 {
1244 for (i, &val) in res.fisher[240..].iter().enumerate() {
1245 assert!(
1246 !val.is_nan(),
1247 "[{}] Found unexpected NaN at out-index {}",
1248 test_name,
1249 240 + i
1250 );
1251 }
1252 }
1253 Ok(())
1254 }
1255
1256 fn check_fisher_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1257 skip_if_unsupported!(kernel, test_name);
1258 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1259 let candles = read_candles_from_csv(file_path)?;
1260 let period = 9;
1261 let input = FisherInput::from_candles(
1262 &candles,
1263 FisherParams {
1264 period: Some(period),
1265 },
1266 );
1267 let batch_output = fisher_with_kernel(&input, kernel)?.fisher;
1268
1269 let highs = source_type(&candles, "high");
1270 let lows = source_type(&candles, "low");
1271
1272 let mut stream = FisherStream::try_new(FisherParams {
1273 period: Some(period),
1274 })?;
1275 let mut stream_fisher = Vec::with_capacity(highs.len());
1276 for (&h, &l) in highs.iter().zip(lows.iter()) {
1277 match stream.update(h, l) {
1278 Some((fish, _sig)) => stream_fisher.push(fish),
1279 None => stream_fisher.push(f64::NAN),
1280 }
1281 }
1282
1283 assert_eq!(batch_output.len(), stream_fisher.len());
1284 for (i, (&b, &s)) in batch_output.iter().zip(stream_fisher.iter()).enumerate() {
1285 if b.is_nan() && s.is_nan() {
1286 continue;
1287 }
1288 let diff = (b - s).abs();
1289 assert!(
1290 diff < 1e-9,
1291 "[{}] Fisher streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
1292 test_name,
1293 i,
1294 b,
1295 s,
1296 diff
1297 );
1298 }
1299 Ok(())
1300 }
1301
1302 #[cfg(debug_assertions)]
1303 fn check_fisher_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1304 skip_if_unsupported!(kernel, test_name);
1305
1306 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1307 let candles = read_candles_from_csv(file_path)?;
1308
1309 let test_params = vec![
1310 FisherParams::default(),
1311 FisherParams { period: Some(1) },
1312 FisherParams { period: Some(2) },
1313 FisherParams { period: Some(3) },
1314 FisherParams { period: Some(5) },
1315 FisherParams { period: Some(10) },
1316 FisherParams { period: Some(20) },
1317 FisherParams { period: Some(30) },
1318 FisherParams { period: Some(50) },
1319 FisherParams { period: Some(100) },
1320 FisherParams { period: Some(200) },
1321 FisherParams { period: Some(240) },
1322 ];
1323
1324 for (param_idx, params) in test_params.iter().enumerate() {
1325 let input = FisherInput::from_candles(&candles, params.clone());
1326 let output = fisher_with_kernel(&input, kernel)?;
1327
1328 for (i, &val) in output.fisher.iter().enumerate() {
1329 if val.is_nan() {
1330 continue;
1331 }
1332
1333 let bits = val.to_bits();
1334
1335 if bits == 0x11111111_11111111 {
1336 panic!(
1337 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1338 in fisher output with params: period={} (param set {})",
1339 test_name,
1340 val,
1341 bits,
1342 i,
1343 params.period.unwrap_or(9),
1344 param_idx
1345 );
1346 }
1347
1348 if bits == 0x22222222_22222222 {
1349 panic!(
1350 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1351 in fisher output with params: period={} (param set {})",
1352 test_name,
1353 val,
1354 bits,
1355 i,
1356 params.period.unwrap_or(9),
1357 param_idx
1358 );
1359 }
1360
1361 if bits == 0x33333333_33333333 {
1362 panic!(
1363 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1364 in fisher output with params: period={} (param set {})",
1365 test_name,
1366 val,
1367 bits,
1368 i,
1369 params.period.unwrap_or(9),
1370 param_idx
1371 );
1372 }
1373 }
1374
1375 for (i, &val) in output.signal.iter().enumerate() {
1376 if val.is_nan() {
1377 continue;
1378 }
1379
1380 let bits = val.to_bits();
1381
1382 if bits == 0x11111111_11111111 {
1383 panic!(
1384 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1385 in signal output with params: period={} (param set {})",
1386 test_name,
1387 val,
1388 bits,
1389 i,
1390 params.period.unwrap_or(9),
1391 param_idx
1392 );
1393 }
1394
1395 if bits == 0x22222222_22222222 {
1396 panic!(
1397 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1398 in signal output with params: period={} (param set {})",
1399 test_name,
1400 val,
1401 bits,
1402 i,
1403 params.period.unwrap_or(9),
1404 param_idx
1405 );
1406 }
1407
1408 if bits == 0x33333333_33333333 {
1409 panic!(
1410 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1411 in signal output with params: period={} (param set {})",
1412 test_name,
1413 val,
1414 bits,
1415 i,
1416 params.period.unwrap_or(9),
1417 param_idx
1418 );
1419 }
1420 }
1421 }
1422
1423 Ok(())
1424 }
1425
    // Release-build counterpart of the poison check. It does nothing and always
    // returns Ok(()); the real scan only exists under `debug_assertions`
    // (presumably the fill patterns are only written in debug builds — confirm).
    #[cfg(not(debug_assertions))]
    fn check_fisher_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
1430
    /// Property test. It runs `kernel` and the `Kernel::Scalar` reference on
    /// randomly generated high/low series and checks, per generated case:
    /// - both fisher and signal have the input length;
    /// - every index before `first + period - 1` is NaN in fisher and signal, and
    ///   fisher at that index is not NaN;
    /// - heuristic: in a window whose mid-price stays within 0.1% of its first
    ///   value, |fisher| at the window end is at most 1.1x its start value
    ///   (only checked when the start magnitude exceeds 0.1);
    /// - `signal[i]` equals `fisher[i - 1]` (within 1e-9) wherever both are defined;
    /// - the first post-warmup fisher value has magnitude below 5.0;
    /// - fisher and signal agree with the scalar reference within 1e-9 absolute
    ///   or 4 ULPs, and NaN positions match.
    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_fisher_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        // Strategy:
        // - period in 2..=50 and length in `period + 10..400`;
        // - a random-walk price starting at `base_price`, moving by up to
        //   +/- `volatility` (relative) per bar and floored at 10.0.
        let strat = (2usize..=50).prop_flat_map(|period| {
            (
                (100f64..10000f64, 0.01f64..0.05f64, period + 10..400)
                    .prop_flat_map(move |(base_price, volatility, data_len)| {
                        (
                            Just(base_price),
                            Just(volatility),
                            Just(data_len),
                            prop::collection::vec((-1f64..1f64), data_len),
                            prop::collection::vec(prop::bool::ANY, data_len),
                        )
                    })
                    .prop_map(
                        move |(
                            base_price,
                            volatility,
                            data_len,
                            price_changes,
                            zero_spread_flags,
                        )| {
                            let mut high = Vec::with_capacity(data_len);
                            let mut low = Vec::with_capacity(data_len);
                            let mut current_price = base_price;

                            for i in 0..data_len {
                                let change = price_changes[i] * volatility * current_price;
                                current_price = (current_price + change).max(10.0);

                                // Flagged bars at indices divisible by 5 get
                                // high == low (zero spread). All others get a
                                // spread of 0.1%..1.1% of price on each side,
                                // with low floored at 10.0.
                                if zero_spread_flags[i] && i % 5 == 0 {
                                    high.push(current_price);
                                    low.push(current_price);
                                } else {
                                    let spread =
                                        current_price * 0.01 * (0.1 + price_changes[i].abs());
                                    high.push(current_price + spread);
                                    low.push((current_price - spread).max(10.0));
                                }
                            }

                            (high, low)
                        },
                    ),
                Just(period),
            )
        });

        proptest::test_runner::TestRunner::default().run(&strat, |((high, low), period)| {
            let params = FisherParams {
                period: Some(period),
            };
            let input = FisherInput::from_slices(&high, &low, params);

            // Output under test and the scalar reference for the same input.
            let FisherOutput {
                fisher: out,
                signal: sig,
            } = fisher_with_kernel(&input, kernel)?;
            let FisherOutput {
                fisher: ref_out,
                signal: ref_sig,
            } = fisher_with_kernel(&input, Kernel::Scalar)?;

            prop_assert_eq!(
                out.len(),
                high.len(),
                "[{}] Fisher output length mismatch",
                test_name
            );
            prop_assert_eq!(
                sig.len(),
                high.len(),
                "[{}] Signal output length mismatch",
                test_name
            );

            // The generator never emits NaN, so this should be Some(0). The
            // scan keeps the warmup arithmetic general anyway.
            let mut first_valid = None;
            for i in 0..high.len() {
                if !high[i].is_nan() && !low[i].is_nan() {
                    first_valid = Some(i);
                    break;
                }
            }

            // Warmup: NaN strictly before `first + period - 1`, a value at it.
            if let Some(first) = first_valid {
                let warmup_end = first + period - 1;
                for i in 0..warmup_end.min(out.len()) {
                    prop_assert!(
                        out[i].is_nan(),
                        "[{}] Expected NaN at index {} during warmup",
                        test_name,
                        i
                    );
                    prop_assert!(
                        sig[i].is_nan(),
                        "[{}] Expected NaN at signal index {} during warmup",
                        test_name,
                        i
                    );
                }

                if warmup_end < out.len() {
                    prop_assert!(
                        !out[warmup_end].is_nan(),
                        "[{}] Expected valid value at index {} after warmup",
                        test_name,
                        warmup_end
                    );
                }
            }

            // Heuristic: across a period-long window whose mid-price is nearly
            // constant (within 0.1% of the first bar), |fisher| must not grow
            // by more than 10% when it starts above 0.1.
            if let Some(first) = first_valid {
                let warmup_end = first + period - 1;

                for window_start in warmup_end..out.len().saturating_sub(period * 2) {
                    let window_end = (window_start + period).min(out.len());

                    let mut is_constant = true;
                    let first_hl = (high[window_start] + low[window_start]) / 2.0;

                    for i in window_start..window_end {
                        let current_hl = (high[i] + low[i]) / 2.0;
                        if (current_hl - first_hl).abs() > 0.001 * first_hl {
                            is_constant = false;
                            break;
                        }
                    }

                    if is_constant && window_end > window_start + 3 {
                        let fisher_start = out[window_start].abs();
                        let fisher_end = out[window_end - 1].abs();

                        if fisher_start > 0.1 {
                            prop_assert!(
                                fisher_end <= fisher_start * 1.1,
                                "[{}] Fisher not trending to zero in constant period [{}, {}]: start={}, end={}",
                                test_name, window_start, window_end, fisher_start, fisher_end
                            );
                        }
                    }
                }
            }

            // Signal is the fisher series lagged by one bar.
            for i in 1..out.len() {
                if !out[i - 1].is_nan() && !sig[i].is_nan() {
                    prop_assert!(
                        (sig[i] - out[i - 1]).abs() < 1e-9,
                        "[{}] Signal at {} ({}) doesn't match previous Fisher ({})",
                        test_name,
                        i,
                        sig[i],
                        out[i - 1]
                    );
                }
            }

            // Sanity bound on the very first emitted fisher value.
            if let Some(first) = first_valid {
                let warmup_end = first + period - 1;
                if warmup_end < out.len() && !out[warmup_end].is_nan() {
                    prop_assert!(
                        out[warmup_end].abs() < 5.0,
                        "[{}] First Fisher value {} seems incorrect (should start from zero state)",
                        test_name,
                        out[warmup_end]
                    );
                }
            }

            // Kernel vs scalar reference: NaN positions must match; finite
            // values may differ by at most 1e-9 absolute or 4 ULPs (raw bit
            // distance of the f64 encodings).
            for i in 0..out.len() {
                let y = out[i];
                let r = ref_out[i];
                let s = sig[i];
                let rs = ref_sig[i];

                if y.is_nan() || r.is_nan() {
                    prop_assert_eq!(
                        y.is_nan(),
                        r.is_nan(),
                        "[{}] NaN mismatch at index {}",
                        test_name,
                        i
                    );
                    continue;
                }

                if s.is_nan() || rs.is_nan() {
                    prop_assert_eq!(
                        s.is_nan(),
                        rs.is_nan(),
                        "[{}] Signal NaN mismatch at index {}",
                        test_name,
                        i
                    );
                    continue;
                }

                let y_bits = y.to_bits();
                let r_bits = r.to_bits();
                let s_bits = s.to_bits();
                let rs_bits = rs.to_bits();

                let ulp_diff_fisher: u64 = y_bits.abs_diff(r_bits);
                let ulp_diff_signal: u64 = s_bits.abs_diff(rs_bits);

                prop_assert!(
                    (y - r).abs() <= 1e-9 || ulp_diff_fisher <= 4,
                    "[{}] Fisher mismatch idx {}: {} vs {} (ULP={})",
                    test_name,
                    i,
                    y,
                    r,
                    ulp_diff_fisher
                );

                prop_assert!(
                    (s - rs).abs() <= 1e-9 || ulp_diff_signal <= 4,
                    "[{}] Signal mismatch idx {}: {} vs {} (ULP={})",
                    test_name,
                    i,
                    s,
                    rs,
                    ulp_diff_signal
                );
            }

            Ok(())
        })?;

        Ok(())
    }
1668
1669 macro_rules! generate_all_fisher_tests {
1670 ($($test_fn:ident),*) => {
1671 paste::paste! {
1672 $(
1673 #[test]
1674 fn [<$test_fn _scalar_f64>]() {
1675 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1676 }
1677 )*
1678 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1679 $(
1680 #[test]
1681 fn [<$test_fn _avx2_f64>]() {
1682 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1683 }
1684 #[test]
1685 fn [<$test_fn _avx512_f64>]() {
1686 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1687 }
1688 )*
1689 }
1690 }
1691 }
1692
    // Per-kernel tests for the single-series checks defined above.
    generate_all_fisher_tests!(
        check_fisher_partial_params,
        check_fisher_accuracy,
        check_fisher_zero_period,
        check_fisher_period_exceeds_length,
        check_fisher_very_small_dataset,
        check_fisher_reinput,
        check_fisher_nan_handling,
        check_fisher_streaming,
        check_fisher_no_poison
    );

    // The property test is generated separately because it, and therefore its
    // generated tests, only exist with the `proptest` feature.
    #[cfg(feature = "proptest")]
    generate_all_fisher_tests!(check_fisher_property);
1707
1708 fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1709 skip_if_unsupported!(kernel, test);
1710
1711 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1712 let c = read_candles_from_csv(file)?;
1713 let output = FisherBatchBuilder::new().kernel(kernel).apply_candles(&c)?;
1714
1715 let def = FisherParams::default();
1716 let row = output.fisher_for(&def).expect("default row missing");
1717
1718 assert_eq!(row.len(), c.close.len());
1719
1720 let expected_last_five = [
1721 -0.4720164683904261,
1722 -0.23467530106650444,
1723 -0.14879388501136784,
1724 -0.026651419122953053,
1725 -0.2569225042442664,
1726 ];
1727 let start = row.len().saturating_sub(5);
1728 for (i, &val) in row[start..].iter().enumerate() {
1729 let diff = (val - expected_last_five[i]).abs();
1730 assert!(
1731 diff < 1e-1,
1732 "[{test}] default-row mismatch at idx {i}: {val} vs {expected_last_five:?}"
1733 );
1734 }
1735 Ok(())
1736 }
1737
1738 #[cfg(debug_assertions)]
1739 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1740 skip_if_unsupported!(kernel, test);
1741
1742 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1743 let c = read_candles_from_csv(file)?;
1744
1745 let test_configs = vec![
1746 (1, 10, 1),
1747 (2, 20, 2),
1748 (5, 50, 5),
1749 (10, 100, 10),
1750 (20, 240, 20),
1751 (9, 9, 0),
1752 (50, 200, 50),
1753 (1, 5, 1),
1754 (100, 240, 40),
1755 (3, 30, 3),
1756 ];
1757
1758 for (cfg_idx, &(p_start, p_end, p_step)) in test_configs.iter().enumerate() {
1759 let output = FisherBatchBuilder::new()
1760 .kernel(kernel)
1761 .period_range(p_start, p_end, p_step)
1762 .apply_candles(&c)?;
1763
1764 for (idx, &val) in output.fisher.iter().enumerate() {
1765 if val.is_nan() {
1766 continue;
1767 }
1768
1769 let bits = val.to_bits();
1770 let row = idx / output.cols;
1771 let col = idx % output.cols;
1772 let combo = &output.combos[row];
1773
1774 if bits == 0x11111111_11111111 {
1775 panic!(
1776 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
1777 at row {} col {} (flat index {}) in fisher output with params: period={}",
1778 test,
1779 cfg_idx,
1780 val,
1781 bits,
1782 row,
1783 col,
1784 idx,
1785 combo.period.unwrap_or(9)
1786 );
1787 }
1788
1789 if bits == 0x22222222_22222222 {
1790 panic!(
1791 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
1792 at row {} col {} (flat index {}) in fisher output with params: period={}",
1793 test,
1794 cfg_idx,
1795 val,
1796 bits,
1797 row,
1798 col,
1799 idx,
1800 combo.period.unwrap_or(9)
1801 );
1802 }
1803
1804 if bits == 0x33333333_33333333 {
1805 panic!(
1806 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
1807 at row {} col {} (flat index {}) in fisher output with params: period={}",
1808 test,
1809 cfg_idx,
1810 val,
1811 bits,
1812 row,
1813 col,
1814 idx,
1815 combo.period.unwrap_or(9)
1816 );
1817 }
1818 }
1819
1820 for (idx, &val) in output.signal.iter().enumerate() {
1821 if val.is_nan() {
1822 continue;
1823 }
1824
1825 let bits = val.to_bits();
1826 let row = idx / output.cols;
1827 let col = idx % output.cols;
1828 let combo = &output.combos[row];
1829
1830 if bits == 0x11111111_11111111 {
1831 panic!(
1832 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
1833 at row {} col {} (flat index {}) in signal output with params: period={}",
1834 test,
1835 cfg_idx,
1836 val,
1837 bits,
1838 row,
1839 col,
1840 idx,
1841 combo.period.unwrap_or(9)
1842 );
1843 }
1844
1845 if bits == 0x22222222_22222222 {
1846 panic!(
1847 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
1848 at row {} col {} (flat index {}) in signal output with params: period={}",
1849 test,
1850 cfg_idx,
1851 val,
1852 bits,
1853 row,
1854 col,
1855 idx,
1856 combo.period.unwrap_or(9)
1857 );
1858 }
1859
1860 if bits == 0x33333333_33333333 {
1861 panic!(
1862 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
1863 at row {} col {} (flat index {}) in signal output with params: period={}",
1864 test,
1865 cfg_idx,
1866 val,
1867 bits,
1868 row,
1869 col,
1870 idx,
1871 combo.period.unwrap_or(9)
1872 );
1873 }
1874 }
1875 }
1876
1877 Ok(())
1878 }
1879
    // Release-build counterpart of the batch poison check. It does nothing and
    // always returns Ok(()); the real scan only exists under `debug_assertions`.
    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
1884
1885 macro_rules! gen_batch_tests {
1886 ($fn_name:ident) => {
1887 paste::paste! {
1888 #[test] fn [<$fn_name _scalar>]() {
1889 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
1890 }
1891 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1892 #[test] fn [<$fn_name _avx2>]() {
1893 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
1894 }
1895 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1896 #[test] fn [<$fn_name _avx512>]() {
1897 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
1898 }
1899 #[test] fn [<$fn_name _auto_detect>]() {
1900 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
1901 }
1902 }
1903 };
1904 }
    // Batch-kernel tests for the batch checks defined above.
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
1907
1908 #[test]
1909 fn check_batch_kernel_dispatch() -> Result<(), Box<dyn Error>> {
1910 let high = vec![10.0, 12.0, 14.0, 16.0, 18.0, 20.0, 22.0, 24.0, 26.0, 28.0];
1911 let low = vec![5.0, 7.0, 9.0, 10.0, 13.0, 15.0, 17.0, 19.0, 21.0, 23.0];
1912 let sweep = FisherBatchRange { period: (3, 5, 1) };
1913
1914 let scalar_result = fisher_batch_slice(&high, &low, &sweep, Kernel::Scalar)?;
1915
1916 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1917 if is_x86_feature_detected!("avx2") {
1918 let avx2_result = fisher_batch_slice(&high, &low, &sweep, Kernel::Avx2)?;
1919
1920 for i in 0..scalar_result.fisher.len() {
1921 let diff = (scalar_result.fisher[i] - avx2_result.fisher[i]).abs();
1922 assert!(
1923 diff < 1e-10
1924 || (scalar_result.fisher[i].is_nan() && avx2_result.fisher[i].is_nan()),
1925 "Fisher mismatch at {}: scalar={}, avx2={}",
1926 i,
1927 scalar_result.fisher[i],
1928 avx2_result.fisher[i]
1929 );
1930 }
1931 }
1932
1933 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1934 if is_x86_feature_detected!("avx512f") {
1935 let avx512_result = fisher_batch_slice(&high, &low, &sweep, Kernel::Avx512)?;
1936
1937 for i in 0..scalar_result.fisher.len() {
1938 let diff = (scalar_result.fisher[i] - avx512_result.fisher[i]).abs();
1939 assert!(
1940 diff < 1e-10
1941 || (scalar_result.fisher[i].is_nan() && avx512_result.fisher[i].is_nan()),
1942 "Fisher mismatch at {}: scalar={}, avx512={}",
1943 i,
1944 scalar_result.fisher[i],
1945 avx512_result.fisher[i]
1946 );
1947 }
1948 }
1949
1950 Ok(())
1951 }
1952
1953 #[test]
1954 fn test_fisher_into_matches_api() -> Result<(), Box<dyn Error>> {
1955 let n = 256usize;
1956 let mut ts = Vec::with_capacity(n);
1957 let mut open = Vec::with_capacity(n);
1958 let mut high = Vec::with_capacity(n);
1959 let mut low = Vec::with_capacity(n);
1960 let mut close = Vec::with_capacity(n);
1961 let mut volume = Vec::with_capacity(n);
1962
1963 for i in 0..n {
1964 ts.push(i as i64);
1965 let base = 1000.0 + (i as f64) * 0.1;
1966 let wiggle = ((i as f64) * 0.15).sin() * 2.0;
1967 let h = base + 5.0 + wiggle;
1968 let l = base - 5.0 - 0.5 * wiggle;
1969 let o = base - 1.0;
1970 let c = base + 1.0;
1971 open.push(o);
1972 high.push(h);
1973 low.push(l);
1974 close.push(c);
1975 volume.push(100.0 + (i % 10) as f64);
1976 }
1977
1978 let candles = crate::utilities::data_loader::Candles::new(
1979 ts,
1980 open,
1981 high.clone(),
1982 low.clone(),
1983 close,
1984 volume,
1985 );
1986 let input = FisherInput::from_candles(&candles, FisherParams::default());
1987
1988 let base = fisher(&input)?;
1989
1990 let mut out_fish = vec![0.0; n];
1991 let mut out_sig = vec![0.0; n];
1992
1993 #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1994 {
1995 fisher_into(&input, &mut out_fish, &mut out_sig)?;
1996 }
1997
1998 fn eq_or_both_nan(a: f64, b: f64) -> bool {
1999 (a.is_nan() && b.is_nan()) || (a == b) || ((a - b).abs() <= 1e-12)
2000 }
2001
2002 assert_eq!(out_fish.len(), base.fisher.len());
2003 assert_eq!(out_sig.len(), base.signal.len());
2004 for i in 0..n {
2005 assert!(
2006 eq_or_both_nan(out_fish[i], base.fisher[i]),
2007 "fisher mismatch at {}: {} vs {}",
2008 i,
2009 out_fish[i],
2010 base.fisher[i]
2011 );
2012 assert!(
2013 eq_or_both_nan(out_sig[i], base.signal[i]),
2014 "signal mismatch at {}: {} vs {}",
2015 i,
2016 out_sig[i],
2017 base.signal[i]
2018 );
2019 }
2020
2021 Ok(())
2022 }
2023}
2024
/// Python binding for the Fisher Transform.
///
/// Returns `(fisher, signal)` as two new 1-D float64 arrays of length
/// `min(len(high), len(low))`. `kernel` is resolved via
/// `validate_kernel(kernel, false)`; the `false` presumably selects
/// single-series kernels — confirm.
///
/// Raises `ValueError` if kernel validation or the computation fails. Also
/// fails if either input is not contiguous.
#[cfg(feature = "python")]
#[pyfunction(name = "fisher")]
#[pyo3(signature = (high, low, period, kernel=None))]
pub fn fisher_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<(Bound<'py, PyArray1<f64>>, Bound<'py, PyArray1<f64>>)> {
    // `PyArray1` / `PyArrayMethods` come from the file-level `numpy` import. The
    // former function-local re-import was redundant, and its `IntoPyArray` was
    // unused.
    let high_slice = high.as_slice()?;
    let low_slice = low.as_slice()?;
    let kern = validate_kernel(kernel, false)?;

    let data_len = high_slice.len().min(low_slice.len());

    // SAFETY: the arrays are freshly allocated, contiguous, and not yet visible
    // to Python, so taking exclusive mutable slices over them is sound.
    // NOTE(review): the contents start uninitialized; this relies on
    // `fisher_into_slice` writing every element, warmup prefix included —
    // confirm.
    let fisher_arr = unsafe { PyArray1::<f64>::new(py, [data_len], false) };
    let signal_arr = unsafe { PyArray1::<f64>::new(py, [data_len], false) };
    let fisher_slice = unsafe { fisher_arr.as_slice_mut()? };
    let signal_slice = unsafe { signal_arr.as_slice_mut()? };

    let params = FisherParams {
        period: Some(period),
    };
    let input = FisherInput::from_slices(high_slice, low_slice, params);

    // The computation only touches Rust slices, so release the GIL while it runs.
    py.allow_threads(|| fisher_into_slice(fisher_slice, signal_slice, &input, kern))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok((fisher_arr, signal_arr))
}
2058
/// Python-facing wrapper (exposed as `FisherStream`) around the incremental
/// `FisherStream`, which is fed one high/low bar per `update` call.
#[cfg(feature = "python")]
#[pyclass(name = "FisherStream")]
pub struct FisherStreamPy {
    // Underlying Rust streaming state.
    stream: FisherStream,
}
2064
#[cfg(feature = "python")]
#[pymethods]
impl FisherStreamPy {
    /// Builds a stream for `period`.
    ///
    /// Raises `ValueError` if `FisherStream::try_new` rejects the parameters.
    #[new]
    pub fn new(period: usize) -> PyResult<Self> {
        FisherStream::try_new(FisherParams {
            period: Some(period),
        })
        .map(|stream| FisherStreamPy { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Pushes one bar. Returns `(fisher, signal)` once the stream produces a
    /// value, or `None` (Python `None`) otherwise.
    pub fn update(&mut self, high: f64, low: f64) -> Option<(f64, f64)> {
        let Self { stream } = self;
        stream.update(high, low)
    }
}
2082
/// Python binding: Fisher Transform over a sweep of periods.
///
/// `period_range` becomes `FisherBatchRange { period }` and is expanded with
/// `expand_grid`. Each expanded combo is one row of `len(high)` columns.
/// Returns a dict with:
/// - `"fisher"` and `"signal"`: 2-D `(rows, cols)` arrays;
/// - `"periods"`: the period of each row as `uint64`.
///
/// Raises `ValueError` when:
/// - the `high`/`low` lengths differ;
/// - `rows * cols` or a warmup index overflows;
/// - no bar has both `high` and `low` non-NaN;
/// - the batch computation reports an error.
#[cfg(feature = "python")]
#[pyfunction(name = "fisher_batch")]
#[pyo3(signature = (high, low, period_range, kernel=None))]
pub fn fisher_batch_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    // `PyArray1`, `PyArrayMethods`, `IntoPyArray` and `PyDict` all come from the
    // file-level imports, so the former function-local re-imports are dropped.
    let high_slice = high.as_slice()?;
    let low_slice = low.as_slice()?;
    if high_slice.len() != low_slice.len() {
        return Err(PyValueError::new_err(format!(
            "Mismatched data length: high={}, low={}",
            high_slice.len(),
            low_slice.len()
        )));
    }

    let kern = validate_kernel(kernel, true)?;
    let sweep = FisherBatchRange {
        period: period_range,
    };

    let combos = expand_grid(&sweep);
    let rows = combos.len();
    let cols = high_slice.len();
    let total_len = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("Fisher batch: rows*cols overflow"))?;

    let first = (0..cols)
        .find(|&i| !high_slice[i].is_nan() && !low_slice[i].is_nan())
        .ok_or_else(|| PyValueError::new_err("All values are NaN"))?;

    // Per-row NaN prefix length. It is clamped to `cols`: a period longer than
    // the data would otherwise ask `init_matrix_prefixes` for a prefix past the
    // end of its row. That runs *before* the batch routine gets a chance to
    // reject the period with a proper error.
    let warmups = combos
        .iter()
        .map(|c| {
            let p = c.period.unwrap_or(0);
            first
                .checked_add(p.saturating_sub(1))
                .map(|warm| warm.min(cols))
                .ok_or_else(|| PyValueError::new_err("Fisher batch: warmup overflow"))
        })
        .collect::<PyResult<Vec<usize>>>()?;

    // SAFETY: freshly allocated, contiguous arrays that Python cannot see yet.
    let fisher_arr = unsafe { PyArray1::<f64>::new(py, [total_len], false) };
    let signal_arr = unsafe { PyArray1::<f64>::new(py, [total_len], false) };
    let fisher_ptr = unsafe { fisher_arr.as_slice_mut()? }.as_mut_ptr();
    let signal_ptr = unsafe { signal_arr.as_slice_mut()? }.as_mut_ptr();

    // SAFETY: `MaybeUninit<f64>` has the same layout as `f64`. Each pointer spans
    // `total_len` elements of a live array, and the two arrays do not overlap.
    unsafe {
        init_matrix_prefixes(
            std::slice::from_raw_parts_mut(fisher_ptr as *mut MaybeUninit<f64>, total_len),
            cols,
            &warmups,
        );
        init_matrix_prefixes(
            std::slice::from_raw_parts_mut(signal_ptr as *mut MaybeUninit<f64>, total_len),
            cols,
            &warmups,
        );
    }

    // Raw pointers are not `Send`, so their addresses cross into the GIL-free
    // closure as plain integers.
    let fisher_addr = fisher_ptr as usize;
    let signal_addr = signal_ptr as usize;

    py.allow_threads(move || {
        let batch_kernel = match kern {
            Kernel::Auto => detect_best_batch_kernel(),
            k => k,
        };
        // Map batch kernel variants to their per-row counterparts for the inner
        // routine.
        let simd = match batch_kernel {
            Kernel::Avx512Batch => Kernel::Avx512,
            Kernel::Avx2Batch => Kernel::Avx2,
            Kernel::ScalarBatch => Kernel::Scalar,
            _ => batch_kernel,
        };

        // SAFETY: the addresses belong to the two arrays allocated above, each
        // `total_len` f64s long. The arrays outlive this call, and nothing else
        // touches them until `allow_threads` returns.
        unsafe {
            let fisher_slice = std::slice::from_raw_parts_mut(fisher_addr as *mut f64, total_len);
            let signal_slice = std::slice::from_raw_parts_mut(signal_addr as *mut f64, total_len);
            fisher_batch_inner_into(
                high_slice,
                low_slice,
                &sweep,
                simd,
                true,
                fisher_slice,
                signal_slice,
            )
        }
    })
    .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("fisher", fisher_arr.reshape((rows, cols))?)?;
    dict.set_item("signal", signal_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;

    Ok(dict)
}
2191
/// Python handle (`ta_indicators.cuda.FisherDeviceArrayF32`) for one CUDA-resident
/// f32 output buffer of the Fisher kernels. Declared `unsendable` for pyo3.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(
    module = "ta_indicators.cuda",
    name = "FisherDeviceArrayF32",
    unsendable
)]
pub struct FisherDeviceArrayF32Py {
    // `Some` until `__dlpack__` takes it; afterwards every accessor raises.
    pub(crate) inner: Option<DeviceArrayF32Py>,
}
2201
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl FisherDeviceArrayF32Py {
    /// CUDA Array Interface dict, delegated to the wrapped `DeviceArrayF32Py`.
    ///
    /// Raises `ValueError` once the buffer has been exported via `__dlpack__`.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let inner = self
            .inner
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("buffer already exported via __dlpack__"))?;
        inner.__cuda_array_interface__(py)
    }

    /// `(device_type, device_id)` pair, delegated to the wrapped buffer.
    ///
    /// Raises `ValueError` once the buffer has been exported via `__dlpack__`.
    fn __dlpack_device__(&self) -> PyResult<(i32, i32)> {
        let inner = self
            .inner
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("buffer already exported via __dlpack__"))?;
        inner.__dlpack_device__()
    }

    /// One-shot DLPack export: moves the wrapped buffer out and returns the
    /// capsule produced by the inner `__dlpack__`.
    ///
    /// An integer `stream` equal to 0 is rejected before anything is taken.
    /// Values that do not extract as `usize` (e.g. `None` or negative numbers)
    /// are passed through unchanged. A second call raises `ValueError`.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<PyObject>,
        max_version: Option<PyObject>,
        dl_device: Option<PyObject>,
        copy: Option<PyObject>,
    ) -> PyResult<PyObject> {
        if let Some(ref s_obj) = stream {
            if let Ok(s) = s_obj.extract::<usize>(py) {
                if s == 0 {
                    return Err(PyValueError::new_err(
                        "__dlpack__ stream=0 is invalid for CUDA",
                    ));
                }
            }
        }

        // NOTE(review): `inner` is taken before delegating. If the inner
        // `__dlpack__` returns an error, the buffer is dropped and this wrapper
        // stays empty, so later calls report "may only be called once" —
        // confirm this is intended.
        let mut inner = self
            .inner
            .take()
            .ok_or_else(|| PyValueError::new_err("__dlpack__ may only be called once"))?;
        let capsule = inner.__dlpack__(py, stream, max_version, dl_device, copy)?;
        Ok(capsule)
    }
}
2249
#[cfg(all(feature = "python", feature = "cuda"))]
impl FisherDeviceArrayF32Py {
    /// Wraps a device buffer together with the CUDA context guard and device id
    /// it belongs to, ready to hand to Python.
    pub fn new_from_rust(inner: DeviceArrayF32, ctx_guard: Arc<Context>, device_id: u32) -> Self {
        let wrapped = DeviceArrayF32Py {
            inner,
            _ctx: Some(ctx_guard),
            device_id: Some(device_id),
        };
        Self {
            inner: Some(wrapped),
        }
    }
}
2262
/// CUDA batch Fisher over a period sweep, with results left on the device.
///
/// Returns a dict with:
/// - `"fisher"` / `"signal"`: `FisherDeviceArrayF32` handles;
/// - `"periods"`: per-row periods as `uint64`;
/// - `"rows"` and `"cols"`.
///
/// Raises `ValueError` when CUDA is unavailable, when `high_f32` and `low_f32`
/// differ in length (matching the CPU `fisher_batch`), or when the CUDA wrapper
/// reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "fisher_cuda_batch_dev")]
#[pyo3(signature = (high_f32, low_f32, period_range, device_id=0))]
pub fn fisher_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    high_f32: numpy::PyReadonlyArray1<'py, f32>,
    low_f32: numpy::PyReadonlyArray1<'py, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    // `IntoPyArray` comes from the file-level `numpy` import.
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let high = high_f32.as_slice()?;
    let low = low_f32.as_slice()?;
    // Reject mismatched inputs up front, like the CPU batch binding; "cols" below
    // is reported from `high`.
    if high.len() != low.len() {
        return Err(PyValueError::new_err(format!(
            "Mismatched data length: high={}, low={}",
            high.len(),
            low.len()
        )));
    }
    let sweep = FisherBatchRange {
        period: period_range,
    };
    let cuda = CudaFisher::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let ctx_guard = cuda.context_arc();
    let dev_id = cuda.device_id();
    let (pair, combos) = py.allow_threads(|| {
        cuda.fisher_batch_dev(high, low, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;
    let dict = pyo3::types::PyDict::new(py);
    dict.set_item(
        "fisher",
        Py::new(
            py,
            FisherDeviceArrayF32Py::new_from_rust(pair.fisher, Arc::clone(&ctx_guard), dev_id),
        )?,
    )?;
    dict.set_item(
        "signal",
        Py::new(
            py,
            FisherDeviceArrayF32Py::new_from_rust(pair.signal, ctx_guard, dev_id),
        )?,
    )?;
    dict.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item("rows", combos.len())?;
    dict.set_item("cols", high.len())?;
    Ok(dict)
}
2316
/// CUDA Fisher for many series sharing one `period`. The inputs are flat arrays
/// passed with explicit `cols`/`rows` to the time-major CUDA routine.
///
/// Returns a dict with `"fisher"` / `"signal"` device handles plus `"rows"` and
/// `"cols"`. Raises `ValueError` when CUDA is unavailable or the CUDA wrapper
/// reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "fisher_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm_f32, low_tm_f32, cols, rows, period, device_id=0))]
pub fn fisher_cuda_many_series_one_param_dev_py<'py>(
    py: Python<'py>,
    high_tm_f32: numpy::PyReadonlyArray1<'py, f32>,
    low_tm_f32: numpy::PyReadonlyArray1<'py, f32>,
    cols: usize,
    rows: usize,
    period: usize,
    device_id: usize,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let high_tm = high_tm_f32.as_slice()?;
    let low_tm = low_tm_f32.as_slice()?;

    let engine = CudaFisher::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let ctx = engine.context_arc();
    let dev_id = engine.device_id();

    let outputs = py.allow_threads(|| {
        engine
            .fisher_many_series_one_param_time_major_dev(high_tm, low_tm, cols, rows, period)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    let dict = pyo3::types::PyDict::new(py);
    let fisher_handle = Py::new(
        py,
        FisherDeviceArrayF32Py::new_from_rust(outputs.fisher, Arc::clone(&ctx), dev_id),
    )?;
    dict.set_item("fisher", fisher_handle)?;
    let signal_handle = Py::new(
        py,
        FisherDeviceArrayF32Py::new_from_rust(outputs.signal, ctx, dev_id),
    )?;
    dict.set_item("signal", signal_handle)?;
    dict.set_item("rows", rows)?;
    dict.set_item("cols", cols)?;
    Ok(dict)
}
2360
/// Flattened result handed back to JavaScript by `fisher_js`.
///
/// `values` is laid out row-major with `rows * cols` entries; `fisher_js` fills it
/// with the Fisher line in the first row and the signal line in the second
/// (`rows == 2`, `cols == input length`).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub struct FisherResult {
    // Row-major output buffer (`rows * cols` values).
    values: Vec<f64>,
    // Number of output rows stored in `values`.
    rows: usize,
    // Number of values per row.
    cols: usize,
}
2368
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
impl FisherResult {
    /// Returns an owned copy of the row-major output buffer.
    #[wasm_bindgen(getter)]
    pub fn values(&self) -> Vec<f64> {
        self.values.to_vec()
    }

    /// Number of rows stored in `values`.
    #[wasm_bindgen(getter)]
    pub fn rows(&self) -> usize {
        let Self { rows, .. } = self;
        *rows
    }

    /// Number of values in each row.
    #[wasm_bindgen(getter)]
    pub fn cols(&self) -> usize {
        let Self { cols, .. } = self;
        *cols
    }
}
2387
/// Computes the Fisher Transform and its signal line for one `period`.
///
/// Returns a [`FisherResult`] with two rows: row 0 is the Fisher line and row 1 is
/// the signal line, each `high.len()` long. Errors if `high` and `low` differ in
/// length, if `2 * len` overflows, or if the core computation rejects the input.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn fisher_js(high: &[f64], low: &[f64], period: usize) -> Result<FisherResult, JsValue> {
    let cols = high.len();
    if low.len() != cols {
        return Err(JsValue::from_str("Mismatched data length"));
    }
    let total = match cols.checked_mul(2) {
        Some(n) => n,
        None => return Err(JsValue::from_str("fisher_js: len*2 overflow")),
    };

    let input = FisherInput::from_slices(
        high,
        low,
        FisherParams {
            period: Some(period),
        },
    );

    let mut values = vec![0.0_f64; total];
    {
        // First half receives the Fisher line, second half the signal line.
        let (fisher_half, signal_half) = values.split_at_mut(cols);
        fisher_into_slice(fisher_half, signal_half, &input, detect_best_kernel())
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(FisherResult {
        values,
        rows: 2,
        cols,
    })
}
2415
/// Computes the Fisher Transform into a caller-owned buffer (pointer API for JS).
///
/// Caller requirements (not checked beyond null tests):
/// - `high_ptr` and `low_ptr` point to `len` readable `f64` values;
/// - `out_ptr` points to `2 * len` writable `f64` values. The first `len` receive the
///   Fisher line and the next `len` receive the signal line.
///
/// The output buffer may overlap the inputs (for example, in-place use from JS). In
/// that case the result is computed into scratch memory and copied out afterwards, so
/// a `&mut` slice never coexists with a `&` slice over the same bytes.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn fisher_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    /// True when the two non-empty `f64` ranges share at least one byte.
    fn ranges_overlap(a: *const f64, a_len: usize, b: *const f64, b_len: usize) -> bool {
        if a_len == 0 || b_len == 0 {
            return false;
        }
        let width = std::mem::size_of::<f64>();
        let (a0, b0) = (a as usize, b as usize);
        let a1 = a0.saturating_add(a_len.saturating_mul(width));
        let b1 = b0.saturating_add(b_len.saturating_mul(width));
        a0 < b1 && b0 < a1
    }

    if high_ptr.is_null() || low_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to fisher_into"));
    }
    // Validate the output size before forming any slice from the raw pointers.
    let total = len
        .checked_mul(2)
        .ok_or_else(|| JsValue::from_str("fisher_into: len*2 overflow"))?;

    let out_const = out_ptr as *const f64;
    let aliased = ranges_overlap(high_ptr, len, out_const, total)
        || ranges_overlap(low_ptr, len, out_const, total);

    // SAFETY: the caller guarantees both input pointers are valid for `len` reads.
    let (high, low) = unsafe {
        (
            std::slice::from_raw_parts(high_ptr, len),
            std::slice::from_raw_parts(low_ptr, len),
        )
    };
    let params = FisherParams {
        period: Some(period),
    };
    let input = FisherInput::from_slices(high, low, params);

    if aliased {
        let mut scratch = vec![0.0_f64; total];
        {
            let (fisher_out, signal_out) = scratch.split_at_mut(len);
            fisher_into_slice(fisher_out, signal_out, &input, detect_best_kernel())
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
        // SAFETY: `out_ptr` is valid for `total` writes (caller contract), and `scratch`
        // is a fresh heap allocation, so the two ranges cannot overlap. The input
        // slices are no longer used past this point.
        unsafe { std::ptr::copy_nonoverlapping(scratch.as_ptr(), out_ptr, total) };
        Ok(())
    } else {
        // SAFETY: `out_ptr` is valid for `total` writes and does not overlap either input
        // range (checked above), so this `&mut` does not alias the live `&` slices.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };
        let (fisher_out, signal_out) = out.split_at_mut(len);
        fisher_into_slice(fisher_out, signal_out, &input, detect_best_kernel())
            .map_err(|e| JsValue::from_str(&e.to_string()))
    }
}
2449
/// Reserves space for `len` `f64` values and leaks it to the caller (JS side).
///
/// The memory is uninitialized. Release it with `fisher_free(ptr, len)` using the
/// same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn fisher_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop keeps the allocation alive once this function returns.
    let mut storage = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    storage.as_mut_ptr()
}
2458
/// Releases a buffer previously obtained from `fisher_alloc(len)`.
///
/// `len` must be the same value passed to `fisher_alloc`. A null pointer is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn fisher_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` came from a `Vec<f64>` with capacity `len` (see `fisher_alloc`).
    // The length is 0 because the buffer may never have been initialized;
    // `Vec::from_raw_parts` requires the first `length` elements to be initialized.
    // `f64` has no destructor, so only the allocation itself is freed.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
2468
/// Config object accepted by the JS `fisher_batch` entry point.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct FisherBatchConfig {
    /// Period sweep as `(start, end, step)`, forwarded as `FisherBatchRange::period`.
    pub period_range: (usize, usize, usize),
}
2474
/// Serialized result of the JS `fisher_batch` entry point.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct FisherBatchJsOutput {
    /// Row-major values. For each combo, the Fisher row is followed by its signal row.
    pub values: Vec<f64>,
    /// Total number of rows in `values` (two per combo).
    pub rows: usize,
    /// Length of each row (the input length).
    pub cols: usize,
    /// Parameter combination for each pair of rows, in output order.
    pub combos: Vec<FisherParams>,
}
2483
/// Runs a Fisher Transform period sweep and returns a serialized `FisherBatchJsOutput`.
///
/// For every period combination, `values` holds the Fisher row followed immediately by
/// its signal row, so `rows` is twice the number of combinations. Errors on mismatched
/// input lengths, an invalid config, size overflow, computation failure, or
/// serialization failure.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = fisher_batch)]
pub fn fisher_batch_unified_js(
    high: &[f64],
    low: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    if low.len() != high.len() {
        return Err(JsValue::from_str("Mismatched data length"));
    }
    let FisherBatchConfig { period_range } =
        serde_wasm_bindgen::from_value::<FisherBatchConfig>(config)
            .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = FisherBatchRange {
        period: period_range,
    };
    let combos = expand_grid(&sweep);
    let (n_rows, n_cols) = (combos.len(), high.len());
    let per_series = n_rows
        .checked_mul(n_cols)
        .ok_or_else(|| JsValue::from_str("fisher_batch_unified_js: rows*cols overflow"))?;

    let mut fisher_buf = vec![0.0; per_series];
    let mut signal_buf = vec![0.0; per_series];
    fisher_batch_inner_into(
        high,
        low,
        &sweep,
        detect_best_kernel(),
        false,
        &mut fisher_buf,
        &mut signal_buf,
    )
    .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let capacity = per_series
        .checked_mul(2)
        .ok_or_else(|| JsValue::from_str("fisher_batch_unified_js: values capacity overflow"))?;
    let mut interleaved = Vec::with_capacity(capacity);
    for row in 0..n_rows {
        let span = row * n_cols..(row + 1) * n_cols;
        interleaved.extend_from_slice(&fisher_buf[span.clone()]);
        interleaved.extend_from_slice(&signal_buf[span]);
    }

    serde_wasm_bindgen::to_value(&FisherBatchJsOutput {
        values: interleaved,
        rows: 2 * n_rows,
        cols: n_cols,
        combos,
    })
    .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2540
/// Runs a Fisher Transform period sweep into caller-owned buffers (pointer API for JS).
///
/// Caller requirements (not checked beyond null tests):
/// - `high_ptr` and `low_ptr` point to `len` readable `f64` values;
/// - `fisher_ptr` and `signal_ptr` each point to `rows * len` writable `f64` values,
///   where `rows` is the number of combinations in `(period_start, period_end, period_step)`.
///
/// Returns the number of rows written. The sweep is computed into owned memory first.
/// Each output is then written through its own short-lived slice, so outputs that
/// overlap the inputs or each other never produce aliasing `&`/`&mut` references.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn fisher_batch_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    fisher_ptr: *mut f64,
    signal_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if high_ptr.is_null() || low_ptr.is_null() || fisher_ptr.is_null() || signal_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    let sweep = FisherBatchRange {
        period: (period_start, period_end, period_step),
    };
    let rows = expand_grid(&sweep).len();
    let total_size = rows
        .checked_mul(len)
        .ok_or_else(|| JsValue::from_str("fisher_batch_into: rows*len overflow"))?;

    let output = {
        // SAFETY: the caller guarantees both input pointers are valid for `len` reads.
        // These shared slices are confined to this block and are dead before any output
        // slice is created below.
        let high = unsafe { std::slice::from_raw_parts(high_ptr, len) };
        let low = unsafe { std::slice::from_raw_parts(low_ptr, len) };
        fisher_batch_inner(high, low, &sweep, Kernel::Auto, false)
            .map_err(|e| JsValue::from_str(&e.to_string()))?
    };

    {
        // SAFETY: `fisher_ptr` is valid for `total_size` writes (caller contract). This is
        // the only live reference to that memory while it exists.
        let fisher_out = unsafe { std::slice::from_raw_parts_mut(fisher_ptr, total_size) };
        fisher_out.copy_from_slice(&output.fisher);
    }
    {
        // SAFETY: as above, for `signal_ptr`. The Fisher slice above is already dead.
        let signal_out = unsafe { std::slice::from_raw_parts_mut(signal_ptr, total_size) };
        signal_out.copy_from_slice(&output.signal);
    }

    Ok(rows)
}
2583
/// Reusable Fisher Transform configuration for repeated pointer-based updates from JS.
///
/// Deprecated in favour of the fast/unsafe pointer API with persistent buffers.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
#[deprecated(
    since = "1.0.0",
    note = "For streaming patterns, use the fast/unsafe API with persistent buffers"
)]
pub struct FisherContext {
    // Lookback period. The constructor rejects 0.
    period: usize,
    // Kernel chosen once at construction via `detect_best_kernel()`.
    kernel: Kernel,
}
2594
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
#[allow(deprecated)]
impl FisherContext {
    /// Creates a context for `period`, choosing the kernel once via `detect_best_kernel()`.
    ///
    /// Errors if `period` is 0.
    #[wasm_bindgen(constructor)]
    #[deprecated(
        since = "1.0.0",
        note = "For streaming patterns, use the fast/unsafe API with persistent buffers"
    )]
    pub fn new(period: usize) -> Result<FisherContext, JsValue> {
        if period == 0 {
            return Err(JsValue::from_str("Invalid period: 0"));
        }

        Ok(FisherContext {
            period,
            kernel: detect_best_kernel(),
        })
    }

    /// Computes the Fisher and signal lines for `len` inputs into caller-owned buffers.
    ///
    /// Caller requirements (not checked beyond null tests): `high_ptr`/`low_ptr` are
    /// valid for `len` reads, and `fisher_ptr`/`signal_ptr` are valid for `len` writes.
    /// Buffers may overlap, for example in-place use or the same output pointer twice.
    /// In that case the result goes through scratch memory so that no `&mut` slice
    /// aliases another live reference.
    pub fn update_into(
        &self,
        high_ptr: *const f64,
        low_ptr: *const f64,
        fisher_ptr: *mut f64,
        signal_ptr: *mut f64,
        len: usize,
    ) -> Result<(), JsValue> {
        /// True when two non-empty `f64` ranges of `len` elements share a byte.
        fn overlaps(a: *const f64, b: *const f64, len: usize) -> bool {
            if len == 0 {
                return false;
            }
            let bytes = len.saturating_mul(std::mem::size_of::<f64>());
            let (a, b) = (a as usize, b as usize);
            a < b.saturating_add(bytes) && b < a.saturating_add(bytes)
        }

        if high_ptr.is_null() || low_ptr.is_null() || fisher_ptr.is_null() || signal_ptr.is_null() {
            return Err(JsValue::from_str("Null pointer provided"));
        }

        let fisher_c = fisher_ptr as *const f64;
        let signal_c = signal_ptr as *const f64;
        let needs_scratch = overlaps(high_ptr, fisher_c, len)
            || overlaps(high_ptr, signal_c, len)
            || overlaps(low_ptr, fisher_c, len)
            || overlaps(low_ptr, signal_c, len)
            || overlaps(fisher_c, signal_c, len);

        // SAFETY: the caller guarantees both input pointers are valid for `len` reads.
        let (high, low) = unsafe {
            (
                std::slice::from_raw_parts(high_ptr, len),
                std::slice::from_raw_parts(low_ptr, len),
            )
        };
        let params = FisherParams {
            period: Some(self.period),
        };
        let input = FisherInput::from_slices(high, low, params);

        if needs_scratch {
            let mut fisher_tmp = vec![0.0_f64; len];
            let mut signal_tmp = vec![0.0_f64; len];
            fisher_into_slice(&mut fisher_tmp, &mut signal_tmp, &input, self.kernel)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            // SAFETY: both output pointers are valid for `len` writes (caller contract),
            // and the scratch vectors are fresh heap allocations that cannot overlap them.
            // The input slices are not used after this point.
            unsafe {
                std::ptr::copy_nonoverlapping(fisher_tmp.as_ptr(), fisher_ptr, len);
                std::ptr::copy_nonoverlapping(signal_tmp.as_ptr(), signal_ptr, len);
            }
        } else {
            // SAFETY: both outputs are valid for `len` writes and, as checked above, overlap
            // neither each other nor the inputs, so these `&mut` slices are unique.
            let (fisher_out, signal_out) = unsafe {
                (
                    std::slice::from_raw_parts_mut(fisher_ptr, len),
                    std::slice::from_raw_parts_mut(signal_ptr, len),
                )
            };
            fisher_into_slice(fisher_out, signal_out, &input, self.kernel)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }

        Ok(())
    }

    /// The lookback period this context was built with.
    pub fn get_period(&self) -> usize {
        self.period
    }

    /// Returns `period - 1`. The subtraction cannot underflow because `new` rejects 0.
    pub fn get_warmup_period(&self) -> usize {
        self.period - 1
    }
}