1use crate::error::{FFTError, FFTResult};
7use rustfft::{num_complex::Complex as RustComplex, FftPlanner};
8use scirs2_core::ndarray::{Array2, ArrayView2};
9use scirs2_core::numeric::Complex64;
10use scirs2_core::numeric::NumCast;
11use std::any::Any;
12use std::fmt::Debug;
13use std::num::NonZeroUsize;
14
15#[allow(dead_code)]
17fn downcast_to_complex<T: 'static>(value: &T) -> Option<Complex64> {
18 if let Some(complex) = (value as &dyn Any).downcast_ref::<Complex64>() {
20 return Some(*complex);
21 }
22
23 if let Some(complex) = (value as &dyn Any).downcast_ref::<scirs2_core::numeric::Complex<f32>>()
25 {
26 return Some(Complex64::new(complex.re as f64, complex.im as f64));
27 }
28
29 if let Some(complex) = (value as &dyn Any).downcast_ref::<RustComplex<f64>>() {
31 return Some(Complex64::new(complex.re, complex.im));
32 }
33
34 if let Some(complex) = (value as &dyn Any).downcast_ref::<RustComplex<f32>>() {
35 return Some(Complex64::new(complex.re as f64, complex.im as f64));
36 }
37
38 None
39}
40
/// Selects which transform direction an FFT routine computes.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FftMode {
    /// Use the forward transform (rustfft `plan_fft_forward`).
    Forward,
    /// Use the inverse transform (rustfft `plan_fft_inverse`).
    Inverse,
}
49
50#[allow(dead_code)]
92pub fn fft_inplace(
93 input: &mut [Complex64],
94 output: &mut [Complex64],
95 mode: FftMode,
96 normalize: bool,
97) -> FFTResult<usize> {
98 let n = input.len();
99
100 if n == 0 {
101 return Err(FFTError::ValueError("Input array is empty".to_string()));
102 }
103
104 if output.len() < n {
105 return Err(FFTError::ValueError(format!(
106 "Output buffer is too small: got {}, need {}",
107 output.len(),
108 n
109 )));
110 }
111
112 let use_simd = n >= 32 && crate::simd_fft::simd_support_available();
114
115 if use_simd {
116 let result = match mode {
118 FftMode::Forward => crate::simd_fft::fft_adaptive(
119 input,
120 if normalize { Some("forward") } else { None },
121 )?,
122 FftMode::Inverse => crate::simd_fft::ifft_adaptive(
123 input,
124 if normalize { Some("backward") } else { None },
125 )?,
126 };
127
128 for (i, &val) in result.iter().enumerate() {
130 input[i] = val;
131 output[i] = val;
132 }
133
134 return Ok(n);
135 }
136
137 let mut planner = FftPlanner::new();
140 let fft = match mode {
141 FftMode::Forward => planner.plan_fft_forward(n),
142 FftMode::Inverse => planner.plan_fft_inverse(n),
143 };
144
145 let mut buffer: Vec<RustComplex<f64>> = input
147 .iter()
148 .map(|&c| RustComplex::new(c.re, c.im))
149 .collect();
150
151 fft.process(&mut buffer);
153
154 let scale = if normalize { 1.0 / (n as f64) } else { 1.0 };
156
157 if scale != 1.0 && use_simd {
158 for (i, &c) in buffer.iter().enumerate() {
160 input[i] = Complex64::new(c.re, c.im);
161 }
162
163 crate::simd_fft::apply_simd_normalization(input, scale);
165
166 output.copy_from_slice(input);
168 } else {
169 for (i, &c) in buffer.iter().enumerate() {
171 input[i] = Complex64::new(c.re * scale, c.im * scale);
172 output[i] = input[i];
173 }
174 }
175
176 Ok(n)
177}
178
179#[allow(dead_code)]
198pub fn process_in_chunks<T, F>(
199 input: &[T],
200 chunk_size: usize,
201 mut op: F,
202) -> FFTResult<Vec<Complex64>>
203where
204 T: NumCast + Copy + Debug + 'static,
205 F: FnMut(&[T]) -> FFTResult<Vec<Complex64>>,
206{
207 if input.len() <= chunk_size {
208 return op(input);
210 }
211
212 let chunk_size_nz = NonZeroUsize::new(chunk_size).unwrap_or(NonZeroUsize::new(1).unwrap());
213 let n_chunks = input.len().div_ceil(chunk_size_nz.get());
214 let mut result = Vec::with_capacity(input.len());
215
216 for i in 0..n_chunks {
217 let start = i * chunk_size;
218 let end = (start + chunk_size).min(input.len());
219 let chunk = &input[start..end];
220
221 let chunk_result = op(chunk)?;
222 result.extend(chunk_result);
223 }
224
225 Ok(result)
226}
227
228#[allow(dead_code)]
248pub fn fft2_efficient<T>(
249 input: &ArrayView2<T>,
250 shape: Option<(usize, usize)>,
251 mode: FftMode,
252 normalize: bool,
253) -> FFTResult<Array2<Complex64>>
254where
255 T: NumCast + Copy + Debug + 'static,
256{
257 let (n_rows, n_cols) = input.dim();
258 let (n_rows_out, n_cols_out) = shape.unwrap_or((n_rows, n_cols));
259
260 if n_rows_out == 0 || n_cols_out == 0 {
262 return Err(FFTError::ValueError(
263 "Output dimensions must be positive".to_string(),
264 ));
265 }
266
267 let mut complex_input = Array2::zeros((n_rows_out, n_cols_out));
269 for r in 0..n_rows.min(n_rows_out) {
270 for c in 0..n_cols.min(n_cols_out) {
271 let val = input[[r, c]];
272 match NumCast::from(val) {
273 Some(val_f64) => {
274 complex_input[[r, c]] = Complex64::new(val_f64, 0.0);
275 }
276 None => {
277 if let Some(complex_val) = downcast_to_complex::<T>(&val) {
279 complex_input[[r, c]] = complex_val;
280 } else {
281 return Err(FFTError::ValueError(format!(
282 "Could not convert {val:?} to f64 or Complex64"
283 )));
284 }
285 }
286 }
287 }
288 }
289
290 let mut buffer = complex_input.as_slice_mut().unwrap().to_vec();
292
293 let mut planner = FftPlanner::new();
295
296 let _row_buffer = vec![Complex64::new(0.0, 0.0); n_cols_out];
298
299 for r in 0..n_rows_out {
301 let row_start = r * n_cols_out;
302 let row_end = row_start + n_cols_out;
303 let row_slice = &mut buffer[row_start..row_end];
304
305 let row_fft = match mode {
306 FftMode::Forward => planner.plan_fft_forward(n_cols_out),
307 FftMode::Inverse => planner.plan_fft_inverse(n_cols_out),
308 };
309
310 let mut row_data: Vec<RustComplex<f64>> = row_slice
312 .iter()
313 .map(|&c| RustComplex::new(c.re, c.im))
314 .collect();
315
316 row_fft.process(&mut row_data);
318
319 for (i, &c) in row_data.iter().enumerate() {
321 row_slice[i] = Complex64::new(c.re, c.im);
322 }
323 }
324
325 let mut transposed = vec![Complex64::new(0.0, 0.0); n_rows_out * n_cols_out];
327
328 for r in 0..n_rows_out {
330 for c in 0..n_cols_out {
331 let src_idx = r * n_cols_out + c;
332 let dst_idx = c * n_rows_out + r;
333 transposed[dst_idx] = buffer[src_idx];
334 }
335 }
336
337 let _col_buffer = vec![Complex64::new(0.0, 0.0); n_rows_out];
339
340 for c in 0..n_cols_out {
342 let col_start = c * n_rows_out;
343 let col_end = col_start + n_rows_out;
344 let col_slice = &mut transposed[col_start..col_end];
345
346 let col_fft = match mode {
347 FftMode::Forward => planner.plan_fft_forward(n_rows_out),
348 FftMode::Inverse => planner.plan_fft_inverse(n_rows_out),
349 };
350
351 let mut col_data: Vec<RustComplex<f64>> = col_slice
353 .iter()
354 .map(|&c| RustComplex::new(c.re, c.im))
355 .collect();
356
357 col_fft.process(&mut col_data);
359
360 for (i, &c) in col_data.iter().enumerate() {
362 col_slice[i] = Complex64::new(c.re, c.im);
363 }
364 }
365
366 let scale = if normalize {
368 1.0 / ((n_rows_out * n_cols_out) as f64)
369 } else {
370 1.0
371 };
372
373 let mut result = Array2::zeros((n_rows_out, n_cols_out));
374
375 for r in 0..n_rows_out {
377 for c in 0..n_cols_out {
378 let src_idx = c * n_rows_out + r;
379 let val = transposed[src_idx];
380 result[[r, c]] = Complex64::new(val.re * scale, val.im * scale);
381 }
382 }
383
384 Ok(result)
385}
386
387#[allow(dead_code)]
407pub fn fft_streaming<T>(
408 input: &[T],
409 n: Option<usize>,
410 mode: FftMode,
411 chunk_size: Option<usize>,
412) -> FFTResult<Vec<Complex64>>
413where
414 T: NumCast + Copy + Debug + 'static,
415{
416 let input_length = input.len();
417 let n_val = n.unwrap_or(input_length);
418 let chunk_size_val = chunk_size.unwrap_or(
419 if input_length > 1_000_000 {
421 1_048_576
423 } else if input_length > 100_000 {
424 65_536
426 } else {
427 input_length
429 },
430 );
431
432 if input_length <= chunk_size_val || n_val <= chunk_size_val {
434 let mut complex_input: Vec<Complex64> = Vec::with_capacity(input_length);
436
437 for &val in input {
438 match NumCast::from(val) {
439 Some(val_f64) => {
440 complex_input.push(Complex64::new(val_f64, 0.0));
441 }
442 None => {
443 if let Some(complex_val) = downcast_to_complex::<T>(&val) {
445 complex_input.push(complex_val);
446 } else {
447 return Err(FFTError::ValueError(format!(
448 "Could not convert {val:?} to f64 or Complex64"
449 )));
450 }
451 }
452 }
453 }
454
455 match n_val.cmp(&complex_input.len()) {
457 std::cmp::Ordering::Less => {
458 complex_input.truncate(n_val);
460 }
461 std::cmp::Ordering::Greater => {
462 complex_input.resize(n_val, Complex64::new(0.0, 0.0));
464 }
465 std::cmp::Ordering::Equal => {
466 }
468 }
469
470 let mut planner = FftPlanner::new();
472 let fft = match mode {
473 FftMode::Forward => planner.plan_fft_forward(n_val),
474 FftMode::Inverse => planner.plan_fft_inverse(n_val),
475 };
476
477 let mut buffer: Vec<RustComplex<f64>> = complex_input
479 .iter()
480 .map(|&c| RustComplex::new(c.re, c.im))
481 .collect();
482
483 fft.process(&mut buffer);
485
486 let scale = if mode == FftMode::Inverse {
488 1.0 / (n_val as f64)
489 } else {
490 1.0
491 };
492
493 let result: Vec<Complex64> = buffer
494 .into_iter()
495 .map(|c| Complex64::new(c.re * scale, c.im * scale))
496 .collect();
497
498 return Ok(result);
499 }
500
501 let chunk_size_nz = NonZeroUsize::new(chunk_size_val).unwrap_or(NonZeroUsize::new(1).unwrap());
503 let n_chunks = n_val.div_ceil(chunk_size_nz.get());
504 let mut result = Vec::with_capacity(n_val);
505
506 for i in 0..n_chunks {
507 let start = i * chunk_size_val;
508 let end = (start + chunk_size_val).min(n_val);
509 let chunk_size = end - start;
510
511 let mut chunk_input = Vec::with_capacity(chunk_size);
513
514 if start < input_length {
515 let input_end = end.min(input_length);
517 for val in input[start..input_end].iter() {
518 match NumCast::from(*val) {
519 Some(val_f64) => {
520 chunk_input.push(Complex64::new(val_f64, 0.0));
521 }
522 None => {
523 if let Some(complex_val) = downcast_to_complex::<T>(val) {
525 chunk_input.push(complex_val);
526 } else {
527 return Err(FFTError::ValueError(format!(
528 "Could not convert {val:?} to f64 or Complex64"
529 )));
530 }
531 }
532 }
533 }
534
535 if input_end < end {
537 chunk_input.resize(chunk_size, Complex64::new(0.0, 0.0));
538 }
539 } else {
540 chunk_input.resize(chunk_size, Complex64::new(0.0, 0.0));
542 }
543
544 let mut planner = FftPlanner::new();
546 let fft = match mode {
547 FftMode::Forward => planner.plan_fft_forward(chunk_size),
548 FftMode::Inverse => planner.plan_fft_inverse(chunk_size),
549 };
550
551 let mut buffer: Vec<RustComplex<f64>> = chunk_input
553 .iter()
554 .map(|&c| RustComplex::new(c.re, c.im))
555 .collect();
556
557 fft.process(&mut buffer);
559
560 let scale = if mode == FftMode::Inverse {
562 1.0 / (chunk_size as f64)
563 } else {
564 1.0
565 };
566
567 let chunk_result: Vec<Complex64> = buffer
568 .into_iter()
569 .map(|c| Complex64::new(c.re * scale, c.im * scale))
570 .collect();
571
572 result.extend(chunk_result);
574 }
575
576 if mode == FftMode::Inverse {
579 let full_scale = 1.0 / (n_val as f64);
580 let chunk_scale = 1.0 / (chunk_size_val as f64);
581 let scale_adjustment = full_scale / chunk_scale;
582
583 for val in &mut result {
584 val.re *= scale_adjustment;
585 val.im *= scale_adjustment;
586 }
587 }
588
589 Ok(result)
590}
591
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// Real-valued samples shared by the 1-D round-trip tests.
    const SAMPLES: [f64; 4] = [1.0, 2.0, 3.0, 4.0];

    /// A forward `fft_inplace` followed by a normalised inverse recovers the samples.
    #[test]
    fn test_fft_inplace() {
        let mut input: Vec<Complex64> = SAMPLES
            .iter()
            .map(|&x| Complex64::new(x, 0.0))
            .collect();
        let mut output = vec![Complex64::new(0.0, 0.0); SAMPLES.len()];

        fft_inplace(&mut input, &mut output, FftMode::Forward, false).unwrap();
        // The zero-frequency bin of an unnormalised FFT is the sum of the samples.
        assert_relative_eq!(input[0].re, 10.0, epsilon = 1e-10);

        fft_inplace(&mut input, &mut output, FftMode::Inverse, true).unwrap();
        for (i, &want) in SAMPLES.iter().enumerate() {
            assert_relative_eq!(input[i].re, want, epsilon = 1e-10);
        }
    }

    /// A 2-D forward transform followed by a normalised inverse recovers the matrix.
    #[test]
    fn test_fft2_efficient() {
        let arr = array![[1.0, 2.0], [3.0, 4.0]];

        let spectrum = fft2_efficient(&arr.view(), None, FftMode::Forward, false).unwrap();
        assert_relative_eq!(spectrum[[0, 0]].re, 10.0, epsilon = 1e-10);

        let recovered = fft2_efficient(&spectrum.view(), None, FftMode::Inverse, true).unwrap();
        let expected = [((0, 0), 1.0), ((0, 1), 2.0), ((1, 0), 3.0), ((1, 1), 4.0)];
        for &((r, c), want) in expected.iter() {
            assert_relative_eq!(recovered[[r, c]].re, want, epsilon = 1e-10);
        }
    }

    /// `fft_streaming` round-trips and gives the same spectrum for an explicit chunk size.
    #[test]
    fn test_fft_streaming() {
        let signal = SAMPLES.to_vec();

        let spectrum = fft_streaming(&signal, None, FftMode::Forward, None).unwrap();
        assert_relative_eq!(spectrum[0].re, 10.0, epsilon = 1e-10);

        let inverse = fft_streaming(&spectrum, None, FftMode::Inverse, None).unwrap();
        for (i, &want) in SAMPLES.iter().enumerate() {
            assert_relative_eq!(inverse[i].re, want, epsilon = 1e-10);
        }

        let chunked =
            fft_streaming(&signal, None, FftMode::Forward, Some(signal.len())).unwrap();
        for (a, b) in spectrum.iter().zip(chunked.iter()) {
            assert_relative_eq!(a.re, b.re, epsilon = 1e-10);
            assert_relative_eq!(a.im, b.im, epsilon = 1e-10);
        }
    }
}