1use crate::error::{FFTError, FFTResult};
7use rustfft::{num_complex::Complex as RustComplex, FftPlanner};
8use scirs2_core::ndarray::{Array2, ArrayView2};
9use scirs2_core::numeric::Complex64;
10use scirs2_core::numeric::NumCast;
11use std::any::Any;
12use std::fmt::Debug;
13use std::num::NonZeroUsize;
14
15#[allow(dead_code)]
17fn downcast_to_complex<T: 'static>(value: &T) -> Option<Complex64> {
18 if let Some(complex) = (value as &dyn Any).downcast_ref::<Complex64>() {
20 return Some(*complex);
21 }
22
23 if let Some(complex) = (value as &dyn Any).downcast_ref::<scirs2_core::numeric::Complex<f32>>()
25 {
26 return Some(Complex64::new(complex.re as f64, complex.im as f64));
27 }
28
29 if let Some(complex) = (value as &dyn Any).downcast_ref::<RustComplex<f64>>() {
31 return Some(Complex64::new(complex.re, complex.im));
32 }
33
34 if let Some(complex) = (value as &dyn Any).downcast_ref::<RustComplex<f32>>() {
35 return Some(Complex64::new(complex.re as f64, complex.im as f64));
36 }
37
38 None
39}
40
/// Direction of the transform requested from the FFT helpers in this module.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FftMode {
    /// Forward transform.
    Forward,
    /// Inverse transform.
    Inverse,
}
49
50#[allow(dead_code)]
92pub fn fft_inplace(
93 input: &mut [Complex64],
94 output: &mut [Complex64],
95 mode: FftMode,
96 normalize: bool,
97) -> FFTResult<usize> {
98 let n = input.len();
99
100 if n == 0 {
101 return Err(FFTError::ValueError("Input array is empty".to_string()));
102 }
103
104 if output.len() < n {
105 return Err(FFTError::ValueError(format!(
106 "Output buffer is too small: got {}, need {}",
107 output.len(),
108 n
109 )));
110 }
111
112 let use_simd = n >= 32 && crate::simd_fft::simd_support_available();
114
115 if use_simd {
116 let result = match mode {
118 FftMode::Forward => crate::simd_fft::fft_adaptive(
119 input,
120 if normalize { Some("forward") } else { None },
121 )?,
122 FftMode::Inverse => crate::simd_fft::ifft_adaptive(
123 input,
124 if normalize { Some("backward") } else { None },
125 )?,
126 };
127
128 for (i, &val) in result.iter().enumerate() {
130 input[i] = val;
131 output[i] = val;
132 }
133
134 return Ok(n);
135 }
136
137 let mut planner = FftPlanner::new();
140 let fft = match mode {
141 FftMode::Forward => planner.plan_fft_forward(n),
142 FftMode::Inverse => planner.plan_fft_inverse(n),
143 };
144
145 let mut buffer: Vec<RustComplex<f64>> = input
147 .iter()
148 .map(|&c| RustComplex::new(c.re, c.im))
149 .collect();
150
151 fft.process(&mut buffer);
153
154 let scale = if normalize { 1.0 / (n as f64) } else { 1.0 };
156
157 if scale != 1.0 && use_simd {
158 for (i, &c) in buffer.iter().enumerate() {
160 input[i] = Complex64::new(c.re, c.im);
161 }
162
163 crate::simd_fft::apply_simd_normalization(input, scale);
165
166 output.copy_from_slice(input);
168 } else {
169 for (i, &c) in buffer.iter().enumerate() {
171 input[i] = Complex64::new(c.re * scale, c.im * scale);
172 output[i] = input[i];
173 }
174 }
175
176 Ok(n)
177}
178
179#[allow(dead_code)]
198pub fn process_in_chunks<T, F>(
199 input: &[T],
200 chunk_size: usize,
201 mut op: F,
202) -> FFTResult<Vec<Complex64>>
203where
204 T: NumCast + Copy + Debug + 'static,
205 F: FnMut(&[T]) -> FFTResult<Vec<Complex64>>,
206{
207 if input.len() <= chunk_size {
208 return op(input);
210 }
211
212 let chunk_size_nz =
213 NonZeroUsize::new(chunk_size).unwrap_or(NonZeroUsize::new(1).expect("Operation failed"));
214 let n_chunks = input.len().div_ceil(chunk_size_nz.get());
215 let mut result = Vec::with_capacity(input.len());
216
217 for i in 0..n_chunks {
218 let start = i * chunk_size;
219 let end = (start + chunk_size).min(input.len());
220 let chunk = &input[start..end];
221
222 let chunk_result = op(chunk)?;
223 result.extend(chunk_result);
224 }
225
226 Ok(result)
227}
228
229#[allow(dead_code)]
249pub fn fft2_efficient<T>(
250 input: &ArrayView2<T>,
251 shape: Option<(usize, usize)>,
252 mode: FftMode,
253 normalize: bool,
254) -> FFTResult<Array2<Complex64>>
255where
256 T: NumCast + Copy + Debug + 'static,
257{
258 let (n_rows, n_cols) = input.dim();
259 let (n_rows_out, n_cols_out) = shape.unwrap_or((n_rows, n_cols));
260
261 if n_rows_out == 0 || n_cols_out == 0 {
263 return Err(FFTError::ValueError(
264 "Output dimensions must be positive".to_string(),
265 ));
266 }
267
268 let mut complex_input = Array2::zeros((n_rows_out, n_cols_out));
270 for r in 0..n_rows.min(n_rows_out) {
271 for c in 0..n_cols.min(n_cols_out) {
272 let val = input[[r, c]];
273 match NumCast::from(val) {
274 Some(val_f64) => {
275 complex_input[[r, c]] = Complex64::new(val_f64, 0.0);
276 }
277 None => {
278 if let Some(complex_val) = downcast_to_complex::<T>(&val) {
280 complex_input[[r, c]] = complex_val;
281 } else {
282 return Err(FFTError::ValueError(format!(
283 "Could not convert {val:?} to f64 or Complex64"
284 )));
285 }
286 }
287 }
288 }
289 }
290
291 let mut buffer = complex_input
293 .as_slice_mut()
294 .expect("Operation failed")
295 .to_vec();
296
297 let mut planner = FftPlanner::new();
299
300 let _row_buffer = vec![Complex64::new(0.0, 0.0); n_cols_out];
302
303 for r in 0..n_rows_out {
305 let row_start = r * n_cols_out;
306 let row_end = row_start + n_cols_out;
307 let row_slice = &mut buffer[row_start..row_end];
308
309 let row_fft = match mode {
310 FftMode::Forward => planner.plan_fft_forward(n_cols_out),
311 FftMode::Inverse => planner.plan_fft_inverse(n_cols_out),
312 };
313
314 let mut row_data: Vec<RustComplex<f64>> = row_slice
316 .iter()
317 .map(|&c| RustComplex::new(c.re, c.im))
318 .collect();
319
320 row_fft.process(&mut row_data);
322
323 for (i, &c) in row_data.iter().enumerate() {
325 row_slice[i] = Complex64::new(c.re, c.im);
326 }
327 }
328
329 let mut transposed = vec![Complex64::new(0.0, 0.0); n_rows_out * n_cols_out];
331
332 for r in 0..n_rows_out {
334 for c in 0..n_cols_out {
335 let src_idx = r * n_cols_out + c;
336 let dst_idx = c * n_rows_out + r;
337 transposed[dst_idx] = buffer[src_idx];
338 }
339 }
340
341 let _col_buffer = vec![Complex64::new(0.0, 0.0); n_rows_out];
343
344 for c in 0..n_cols_out {
346 let col_start = c * n_rows_out;
347 let col_end = col_start + n_rows_out;
348 let col_slice = &mut transposed[col_start..col_end];
349
350 let col_fft = match mode {
351 FftMode::Forward => planner.plan_fft_forward(n_rows_out),
352 FftMode::Inverse => planner.plan_fft_inverse(n_rows_out),
353 };
354
355 let mut col_data: Vec<RustComplex<f64>> = col_slice
357 .iter()
358 .map(|&c| RustComplex::new(c.re, c.im))
359 .collect();
360
361 col_fft.process(&mut col_data);
363
364 for (i, &c) in col_data.iter().enumerate() {
366 col_slice[i] = Complex64::new(c.re, c.im);
367 }
368 }
369
370 let scale = if normalize {
372 1.0 / ((n_rows_out * n_cols_out) as f64)
373 } else {
374 1.0
375 };
376
377 let mut result = Array2::zeros((n_rows_out, n_cols_out));
378
379 for r in 0..n_rows_out {
381 for c in 0..n_cols_out {
382 let src_idx = c * n_rows_out + r;
383 let val = transposed[src_idx];
384 result[[r, c]] = Complex64::new(val.re * scale, val.im * scale);
385 }
386 }
387
388 Ok(result)
389}
390
391#[allow(dead_code)]
411pub fn fft_streaming<T>(
412 input: &[T],
413 n: Option<usize>,
414 mode: FftMode,
415 chunk_size: Option<usize>,
416) -> FFTResult<Vec<Complex64>>
417where
418 T: NumCast + Copy + Debug + 'static,
419{
420 let input_length = input.len();
421 let n_val = n.unwrap_or(input_length);
422 let chunk_size_val = chunk_size.unwrap_or(
423 if input_length > 1_000_000 {
425 1_048_576
427 } else if input_length > 100_000 {
428 65_536
430 } else {
431 input_length
433 },
434 );
435
436 if input_length <= chunk_size_val || n_val <= chunk_size_val {
438 let mut complex_input: Vec<Complex64> = Vec::with_capacity(input_length);
440
441 for &val in input {
442 match NumCast::from(val) {
443 Some(val_f64) => {
444 complex_input.push(Complex64::new(val_f64, 0.0));
445 }
446 None => {
447 if let Some(complex_val) = downcast_to_complex::<T>(&val) {
449 complex_input.push(complex_val);
450 } else {
451 return Err(FFTError::ValueError(format!(
452 "Could not convert {val:?} to f64 or Complex64"
453 )));
454 }
455 }
456 }
457 }
458
459 match n_val.cmp(&complex_input.len()) {
461 std::cmp::Ordering::Less => {
462 complex_input.truncate(n_val);
464 }
465 std::cmp::Ordering::Greater => {
466 complex_input.resize(n_val, Complex64::new(0.0, 0.0));
468 }
469 std::cmp::Ordering::Equal => {
470 }
472 }
473
474 let mut planner = FftPlanner::new();
476 let fft = match mode {
477 FftMode::Forward => planner.plan_fft_forward(n_val),
478 FftMode::Inverse => planner.plan_fft_inverse(n_val),
479 };
480
481 let mut buffer: Vec<RustComplex<f64>> = complex_input
483 .iter()
484 .map(|&c| RustComplex::new(c.re, c.im))
485 .collect();
486
487 fft.process(&mut buffer);
489
490 let scale = if mode == FftMode::Inverse {
492 1.0 / (n_val as f64)
493 } else {
494 1.0
495 };
496
497 let result: Vec<Complex64> = buffer
498 .into_iter()
499 .map(|c| Complex64::new(c.re * scale, c.im * scale))
500 .collect();
501
502 return Ok(result);
503 }
504
505 let chunk_size_nz = NonZeroUsize::new(chunk_size_val)
507 .unwrap_or(NonZeroUsize::new(1).expect("Operation failed"));
508 let n_chunks = n_val.div_ceil(chunk_size_nz.get());
509 let mut result = Vec::with_capacity(n_val);
510
511 for i in 0..n_chunks {
512 let start = i * chunk_size_val;
513 let end = (start + chunk_size_val).min(n_val);
514 let chunk_size = end - start;
515
516 let mut chunk_input = Vec::with_capacity(chunk_size);
518
519 if start < input_length {
520 let input_end = end.min(input_length);
522 for val in input[start..input_end].iter() {
523 match NumCast::from(*val) {
524 Some(val_f64) => {
525 chunk_input.push(Complex64::new(val_f64, 0.0));
526 }
527 None => {
528 if let Some(complex_val) = downcast_to_complex::<T>(val) {
530 chunk_input.push(complex_val);
531 } else {
532 return Err(FFTError::ValueError(format!(
533 "Could not convert {val:?} to f64 or Complex64"
534 )));
535 }
536 }
537 }
538 }
539
540 if input_end < end {
542 chunk_input.resize(chunk_size, Complex64::new(0.0, 0.0));
543 }
544 } else {
545 chunk_input.resize(chunk_size, Complex64::new(0.0, 0.0));
547 }
548
549 let mut planner = FftPlanner::new();
551 let fft = match mode {
552 FftMode::Forward => planner.plan_fft_forward(chunk_size),
553 FftMode::Inverse => planner.plan_fft_inverse(chunk_size),
554 };
555
556 let mut buffer: Vec<RustComplex<f64>> = chunk_input
558 .iter()
559 .map(|&c| RustComplex::new(c.re, c.im))
560 .collect();
561
562 fft.process(&mut buffer);
564
565 let scale = if mode == FftMode::Inverse {
567 1.0 / (chunk_size as f64)
568 } else {
569 1.0
570 };
571
572 let chunk_result: Vec<Complex64> = buffer
573 .into_iter()
574 .map(|c| Complex64::new(c.re * scale, c.im * scale))
575 .collect();
576
577 result.extend(chunk_result);
579 }
580
581 if mode == FftMode::Inverse {
584 let full_scale = 1.0 / (n_val as f64);
585 let chunk_scale = 1.0 / (chunk_size_val as f64);
586 let scale_adjustment = full_scale / chunk_scale;
587
588 for val in &mut result {
589 val.re *= scale_adjustment;
590 val.im *= scale_adjustment;
591 }
592 }
593
594 Ok(result)
595}
596
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// A forward transform followed by a normalized inverse restores the
    /// original samples in place.
    #[test]
    fn test_fft_inplace() {
        let samples = [1.0, 2.0, 3.0, 4.0];
        let mut input: Vec<Complex64> = samples
            .iter()
            .map(|&re| Complex64::new(re, 0.0))
            .collect();
        let mut output = vec![Complex64::new(0.0, 0.0); 4];

        fft_inplace(&mut input, &mut output, FftMode::Forward, false).expect("Operation failed");

        // The DC bin is the sum of the samples.
        assert_relative_eq!(input[0].re, 10.0, epsilon = 1e-10);

        fft_inplace(&mut input, &mut output, FftMode::Inverse, true).expect("Operation failed");

        for (value, &expected) in input.iter().zip(samples.iter()) {
            assert_relative_eq!(value.re, expected, epsilon = 1e-10);
        }
    }

    /// A forward 2-D transform followed by a normalized inverse restores the
    /// original matrix.
    #[test]
    fn test_fft2_efficient() {
        let arr = array![[1.0, 2.0], [3.0, 4.0]];

        let spectrum_2d =
            fft2_efficient(&arr.view(), None, FftMode::Forward, false).expect("Operation failed");

        // The DC bin is the sum of all entries.
        assert_relative_eq!(spectrum_2d[[0, 0]].re, 10.0, epsilon = 1e-10);

        let recovered = fft2_efficient(&spectrum_2d.view(), None, FftMode::Inverse, true)
            .expect("Operation failed");

        let expected: [((usize, usize), f64); 4] =
            [((0, 0), 1.0), ((0, 1), 2.0), ((1, 0), 3.0), ((1, 1), 4.0)];
        for &((r, c), want) in expected.iter() {
            assert_relative_eq!(recovered[[r, c]].re, want, epsilon = 1e-10);
        }
    }

    /// The streaming transform round-trips, and an explicit chunk size equal
    /// to the signal length gives the same spectrum as the default.
    #[test]
    fn test_fft_streaming() {
        let signal = vec![1.0, 2.0, 3.0, 4.0];

        let result =
            fft_streaming(&signal, None, FftMode::Forward, None).expect("Operation failed");

        // The DC bin is the sum of the samples.
        assert_relative_eq!(result[0].re, 10.0, epsilon = 1e-10);

        let inverse =
            fft_streaming(&result, None, FftMode::Inverse, None).expect("Operation failed");

        for (idx, &want) in signal.iter().enumerate() {
            assert_relative_eq!(inverse[idx].re, want, epsilon = 1e-10);
        }

        let result_chunked = fft_streaming(&signal, None, FftMode::Forward, Some(signal.len()))
            .expect("Operation failed");

        for (a, b) in result.iter().zip(result_chunked.iter()) {
            assert_relative_eq!(a.re, b.re, epsilon = 1e-10);
            assert_relative_eq!(a.im, b.im, epsilon = 1e-10);
        }
    }
}