1use crate::error::{FFTError, FFTResult};
7use ndarray::{Array2, ArrayView2};
8use num_complex::Complex64;
9use num_traits::NumCast;
10use rustfft::{num_complex::Complex as RustComplex, FftPlanner};
11use std::any::Any;
12use std::fmt::Debug;
13use std::num::NonZeroUsize;
14
15#[allow(dead_code)]
17fn downcast_to_complex<T: 'static>(value: &T) -> Option<Complex64> {
18 if let Some(complex) = (value as &dyn Any).downcast_ref::<Complex64>() {
20 return Some(*complex);
21 }
22
23 if let Some(complex) = (value as &dyn Any).downcast_ref::<num_complex::Complex<f32>>() {
25 return Some(Complex64::new(complex.re as f64, complex.im as f64));
26 }
27
28 if let Some(complex) = (value as &dyn Any).downcast_ref::<RustComplex<f64>>() {
30 return Some(Complex64::new(complex.re, complex.im));
31 }
32
33 if let Some(complex) = (value as &dyn Any).downcast_ref::<RustComplex<f32>>() {
34 return Some(Complex64::new(complex.re as f64, complex.im as f64));
35 }
36
37 None
38}
39
/// Direction of a discrete Fourier transform.
///
/// Whether a `1 / n` scale factor is applied is decided by each function that takes
/// this mode, not by the mode itself.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FftMode {
    /// Transform samples into a spectrum (planned with `plan_fft_forward`).
    Forward,
    /// Transform a spectrum back into samples (planned with `plan_fft_inverse`).
    Inverse,
}
48
49#[allow(dead_code)]
91pub fn fft_inplace(
92 input: &mut [Complex64],
93 output: &mut [Complex64],
94 mode: FftMode,
95 normalize: bool,
96) -> FFTResult<usize> {
97 let n = input.len();
98
99 if n == 0 {
100 return Err(FFTError::ValueError("Input array is empty".to_string()));
101 }
102
103 if output.len() < n {
104 return Err(FFTError::ValueError(format!(
105 "Output buffer is too small: got {}, need {}",
106 output.len(),
107 n
108 )));
109 }
110
111 let use_simd = n >= 32 && crate::simd_fft::simd_support_available();
113
114 if use_simd {
115 let result = match mode {
117 FftMode::Forward => crate::simd_fft::fft_adaptive(
118 input,
119 if normalize { Some("forward") } else { None },
120 )?,
121 FftMode::Inverse => crate::simd_fft::ifft_adaptive(
122 input,
123 if normalize { Some("backward") } else { None },
124 )?,
125 };
126
127 for (i, &val) in result.iter().enumerate() {
129 input[i] = val;
130 output[i] = val;
131 }
132
133 return Ok(n);
134 }
135
136 let mut planner = FftPlanner::new();
139 let fft = match mode {
140 FftMode::Forward => planner.plan_fft_forward(n),
141 FftMode::Inverse => planner.plan_fft_inverse(n),
142 };
143
144 let mut buffer: Vec<RustComplex<f64>> = input
146 .iter()
147 .map(|&c| RustComplex::new(c.re, c.im))
148 .collect();
149
150 fft.process(&mut buffer);
152
153 let scale = if normalize { 1.0 / (n as f64) } else { 1.0 };
155
156 if scale != 1.0 && use_simd {
157 for (i, &c) in buffer.iter().enumerate() {
159 input[i] = Complex64::new(c.re, c.im);
160 }
161
162 crate::simd_fft::apply_simd_normalization(input, scale);
164
165 output.copy_from_slice(input);
167 } else {
168 for (i, &c) in buffer.iter().enumerate() {
170 input[i] = Complex64::new(c.re * scale, c.im * scale);
171 output[i] = input[i];
172 }
173 }
174
175 Ok(n)
176}
177
178#[allow(dead_code)]
197pub fn process_in_chunks<T, F>(
198 input: &[T],
199 chunk_size: usize,
200 mut op: F,
201) -> FFTResult<Vec<Complex64>>
202where
203 T: NumCast + Copy + Debug + 'static,
204 F: FnMut(&[T]) -> FFTResult<Vec<Complex64>>,
205{
206 if input.len() <= chunk_size {
207 return op(input);
209 }
210
211 let chunk_size_nz = NonZeroUsize::new(chunk_size).unwrap_or(NonZeroUsize::new(1).unwrap());
212 let n_chunks = input.len().div_ceil(chunk_size_nz.get());
213 let mut result = Vec::with_capacity(input.len());
214
215 for i in 0..n_chunks {
216 let start = i * chunk_size;
217 let end = (start + chunk_size).min(input.len());
218 let chunk = &input[start..end];
219
220 let chunk_result = op(chunk)?;
221 result.extend(chunk_result);
222 }
223
224 Ok(result)
225}
226
227#[allow(dead_code)]
247pub fn fft2_efficient<T>(
248 input: &ArrayView2<T>,
249 shape: Option<(usize, usize)>,
250 mode: FftMode,
251 normalize: bool,
252) -> FFTResult<Array2<Complex64>>
253where
254 T: NumCast + Copy + Debug + 'static,
255{
256 let (n_rows, n_cols) = input.dim();
257 let (n_rows_out, n_cols_out) = shape.unwrap_or((n_rows, n_cols));
258
259 if n_rows_out == 0 || n_cols_out == 0 {
261 return Err(FFTError::ValueError(
262 "Output dimensions must be positive".to_string(),
263 ));
264 }
265
266 let mut complex_input = Array2::zeros((n_rows_out, n_cols_out));
268 for r in 0..n_rows.min(n_rows_out) {
269 for c in 0..n_cols.min(n_cols_out) {
270 let val = input[[r, c]];
271 match num_traits::cast::cast::<T, f64>(val) {
272 Some(val_f64) => {
273 complex_input[[r, c]] = Complex64::new(val_f64, 0.0);
274 }
275 None => {
276 if let Some(complex_val) = downcast_to_complex::<T>(&val) {
278 complex_input[[r, c]] = complex_val;
279 } else {
280 return Err(FFTError::ValueError(format!(
281 "Could not convert {val:?} to f64 or Complex64"
282 )));
283 }
284 }
285 }
286 }
287 }
288
289 let mut buffer = complex_input.as_slice_mut().unwrap().to_vec();
291
292 let mut planner = FftPlanner::new();
294
295 let _row_buffer = vec![Complex64::new(0.0, 0.0); n_cols_out];
297
298 for r in 0..n_rows_out {
300 let row_start = r * n_cols_out;
301 let row_end = row_start + n_cols_out;
302 let row_slice = &mut buffer[row_start..row_end];
303
304 let row_fft = match mode {
305 FftMode::Forward => planner.plan_fft_forward(n_cols_out),
306 FftMode::Inverse => planner.plan_fft_inverse(n_cols_out),
307 };
308
309 let mut row_data: Vec<RustComplex<f64>> = row_slice
311 .iter()
312 .map(|&c| RustComplex::new(c.re, c.im))
313 .collect();
314
315 row_fft.process(&mut row_data);
317
318 for (i, &c) in row_data.iter().enumerate() {
320 row_slice[i] = Complex64::new(c.re, c.im);
321 }
322 }
323
324 let mut transposed = vec![Complex64::new(0.0, 0.0); n_rows_out * n_cols_out];
326
327 for r in 0..n_rows_out {
329 for c in 0..n_cols_out {
330 let src_idx = r * n_cols_out + c;
331 let dst_idx = c * n_rows_out + r;
332 transposed[dst_idx] = buffer[src_idx];
333 }
334 }
335
336 let _col_buffer = vec![Complex64::new(0.0, 0.0); n_rows_out];
338
339 for c in 0..n_cols_out {
341 let col_start = c * n_rows_out;
342 let col_end = col_start + n_rows_out;
343 let col_slice = &mut transposed[col_start..col_end];
344
345 let col_fft = match mode {
346 FftMode::Forward => planner.plan_fft_forward(n_rows_out),
347 FftMode::Inverse => planner.plan_fft_inverse(n_rows_out),
348 };
349
350 let mut col_data: Vec<RustComplex<f64>> = col_slice
352 .iter()
353 .map(|&c| RustComplex::new(c.re, c.im))
354 .collect();
355
356 col_fft.process(&mut col_data);
358
359 for (i, &c) in col_data.iter().enumerate() {
361 col_slice[i] = Complex64::new(c.re, c.im);
362 }
363 }
364
365 let scale = if normalize {
367 1.0 / ((n_rows_out * n_cols_out) as f64)
368 } else {
369 1.0
370 };
371
372 let mut result = Array2::zeros((n_rows_out, n_cols_out));
373
374 for r in 0..n_rows_out {
376 for c in 0..n_cols_out {
377 let src_idx = c * n_rows_out + r;
378 let val = transposed[src_idx];
379 result[[r, c]] = Complex64::new(val.re * scale, val.im * scale);
380 }
381 }
382
383 Ok(result)
384}
385
386#[allow(dead_code)]
406pub fn fft_streaming<T>(
407 input: &[T],
408 n: Option<usize>,
409 mode: FftMode,
410 chunk_size: Option<usize>,
411) -> FFTResult<Vec<Complex64>>
412where
413 T: NumCast + Copy + Debug + 'static,
414{
415 let input_length = input.len();
416 let n_val = n.unwrap_or(input_length);
417 let chunk_size_val = chunk_size.unwrap_or(
418 if input_length > 1_000_000 {
420 1_048_576
422 } else if input_length > 100_000 {
423 65_536
425 } else {
426 input_length
428 },
429 );
430
431 if input_length <= chunk_size_val || n_val <= chunk_size_val {
433 let mut complex_input: Vec<Complex64> = Vec::with_capacity(input_length);
435
436 for &val in input {
437 match num_traits::cast::cast::<T, f64>(val) {
438 Some(val_f64) => {
439 complex_input.push(Complex64::new(val_f64, 0.0));
440 }
441 None => {
442 if let Some(complex_val) = downcast_to_complex::<T>(&val) {
444 complex_input.push(complex_val);
445 } else {
446 return Err(FFTError::ValueError(format!(
447 "Could not convert {val:?} to f64 or Complex64"
448 )));
449 }
450 }
451 }
452 }
453
454 match n_val.cmp(&complex_input.len()) {
456 std::cmp::Ordering::Less => {
457 complex_input.truncate(n_val);
459 }
460 std::cmp::Ordering::Greater => {
461 complex_input.resize(n_val, Complex64::new(0.0, 0.0));
463 }
464 std::cmp::Ordering::Equal => {
465 }
467 }
468
469 let mut planner = FftPlanner::new();
471 let fft = match mode {
472 FftMode::Forward => planner.plan_fft_forward(n_val),
473 FftMode::Inverse => planner.plan_fft_inverse(n_val),
474 };
475
476 let mut buffer: Vec<RustComplex<f64>> = complex_input
478 .iter()
479 .map(|&c| RustComplex::new(c.re, c.im))
480 .collect();
481
482 fft.process(&mut buffer);
484
485 let scale = if mode == FftMode::Inverse {
487 1.0 / (n_val as f64)
488 } else {
489 1.0
490 };
491
492 let result: Vec<Complex64> = buffer
493 .into_iter()
494 .map(|c| Complex64::new(c.re * scale, c.im * scale))
495 .collect();
496
497 return Ok(result);
498 }
499
500 let chunk_size_nz = NonZeroUsize::new(chunk_size_val).unwrap_or(NonZeroUsize::new(1).unwrap());
502 let n_chunks = n_val.div_ceil(chunk_size_nz.get());
503 let mut result = Vec::with_capacity(n_val);
504
505 for i in 0..n_chunks {
506 let start = i * chunk_size_val;
507 let end = (start + chunk_size_val).min(n_val);
508 let chunk_size = end - start;
509
510 let mut chunk_input = Vec::with_capacity(chunk_size);
512
513 if start < input_length {
514 let input_end = end.min(input_length);
516 for val in input[start..input_end].iter() {
517 match num_traits::cast::cast::<T, f64>(*val) {
518 Some(val_f64) => {
519 chunk_input.push(Complex64::new(val_f64, 0.0));
520 }
521 None => {
522 if let Some(complex_val) = downcast_to_complex::<T>(val) {
524 chunk_input.push(complex_val);
525 } else {
526 return Err(FFTError::ValueError(format!(
527 "Could not convert {val:?} to f64 or Complex64"
528 )));
529 }
530 }
531 }
532 }
533
534 if input_end < end {
536 chunk_input.resize(chunk_size, Complex64::new(0.0, 0.0));
537 }
538 } else {
539 chunk_input.resize(chunk_size, Complex64::new(0.0, 0.0));
541 }
542
543 let mut planner = FftPlanner::new();
545 let fft = match mode {
546 FftMode::Forward => planner.plan_fft_forward(chunk_size),
547 FftMode::Inverse => planner.plan_fft_inverse(chunk_size),
548 };
549
550 let mut buffer: Vec<RustComplex<f64>> = chunk_input
552 .iter()
553 .map(|&c| RustComplex::new(c.re, c.im))
554 .collect();
555
556 fft.process(&mut buffer);
558
559 let scale = if mode == FftMode::Inverse {
561 1.0 / (chunk_size as f64)
562 } else {
563 1.0
564 };
565
566 let chunk_result: Vec<Complex64> = buffer
567 .into_iter()
568 .map(|c| Complex64::new(c.re * scale, c.im * scale))
569 .collect();
570
571 result.extend(chunk_result);
573 }
574
575 if mode == FftMode::Inverse {
578 let full_scale = 1.0 / (n_val as f64);
579 let chunk_scale = 1.0 / (chunk_size_val as f64);
580 let scale_adjustment = full_scale / chunk_scale;
581
582 for val in &mut result {
583 val.re *= scale_adjustment;
584 val.im *= scale_adjustment;
585 }
586 }
587
588 Ok(result)
589}
590
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    /// Round-trips a short signal through `fft_inplace`: forward, then normalised inverse.
    #[test]
    fn test_fft_inplace() {
        let samples = [1.0, 2.0, 3.0, 4.0];
        let mut input: Vec<Complex64> = samples
            .iter()
            .map(|&x| Complex64::new(x, 0.0))
            .collect();
        let mut output = vec![Complex64::new(0.0, 0.0); samples.len()];

        fft_inplace(&mut input, &mut output, FftMode::Forward, false).unwrap();
        // The DC bin of an unnormalised forward FFT is the sum of the samples.
        assert_relative_eq!(input[0].re, 10.0, epsilon = 1e-10);

        fft_inplace(&mut input, &mut output, FftMode::Inverse, true).unwrap();
        for (got, want) in input.iter().zip(samples) {
            assert_relative_eq!(got.re, want, epsilon = 1e-10);
        }
    }

    /// Round-trips a 2x2 array through `fft2_efficient`.
    #[test]
    fn test_fft2_efficient() {
        let arr = array![[1.0, 2.0], [3.0, 4.0]];

        let spectrum_2d = fft2_efficient(&arr.view(), None, FftMode::Forward, false).unwrap();
        assert_relative_eq!(spectrum_2d[[0, 0]].re, 10.0, epsilon = 1e-10);

        let recovered =
            fft2_efficient(&spectrum_2d.view(), None, FftMode::Inverse, true).unwrap();
        for ((r, c), &want) in arr.indexed_iter() {
            assert_relative_eq!(recovered[[r, c]].re, want, epsilon = 1e-10);
        }
    }

    /// Round-trips a signal through `fft_streaming`, and checks that an explicit chunk size
    /// equal to the signal length matches the default.
    #[test]
    fn test_fft_streaming() {
        let signal = vec![1.0, 2.0, 3.0, 4.0];

        let result = fft_streaming(&signal, None, FftMode::Forward, None).unwrap();
        assert_relative_eq!(result[0].re, 10.0, epsilon = 1e-10);

        let inverse = fft_streaming(&result, None, FftMode::Inverse, None).unwrap();
        for (got, &want) in inverse.iter().zip(&signal) {
            assert_relative_eq!(got.re, want, epsilon = 1e-10);
        }

        let result_chunked =
            fft_streaming(&signal, None, FftMode::Forward, Some(signal.len())).unwrap();
        for (a, b) in result.iter().zip(&result_chunked) {
            assert_relative_eq!(a.re, b.re, epsilon = 1e-10);
            assert_relative_eq!(a.im, b.im, epsilon = 1e-10);
        }
    }
}