scirs2_fft/fft/
algorithms.rs

1/*!
2 * FFT algorithm implementations
3 *
4 * This module provides implementations of the Fast Fourier Transform (FFT)
5 * and its inverse (IFFT) in 1D, 2D, and N-dimensional cases.
6 */
7
8use crate::error::{FFTError, FFTResult};
9use rustfft::{num_complex::Complex as RustComplex, FftPlanner};
10use scirs2_core::ndarray::{Array2, ArrayD, Axis, IxDyn};
11use scirs2_core::numeric::Complex64;
12use scirs2_core::numeric::NumCast;
13use scirs2_core::safe_ops::{safe_divide, safe_sqrt};
14use std::fmt::Debug;
15
16// We're using the serial implementation even with parallel feature enabled,
17// since we're not using parallelism at this level
18
/// Normalization mode for FFT operations.
///
/// The named modes are intended to follow the NumPy/SciPy `norm` convention,
/// where the name says which direction of the transform pair carries the
/// scaling. With `n` the number of points transformed:
///
/// | mode       | forward scale | inverse scale |
/// |------------|---------------|---------------|
/// | `None`     | 1             | 1             |
/// | `Backward` | 1             | 1/n           |
/// | `Ortho`    | 1/sqrt(n)     | 1/sqrt(n)     |
/// | `Forward`  | 1/n           | 1             |
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NormMode {
    /// No normalization in either direction (what forward transforms use when
    /// no mode is given).
    None,
    /// Scale only the inverse transform by 1/n (default for inverse transforms).
    Backward,
    /// Scale both directions by 1/sqrt(n), making the transform pair unitary.
    Ortho,
    /// Scale only the forward transform by 1/n; the inverse is left unscaled.
    Forward,
}
31
32impl From<&str> for NormMode {
33    fn from(s: &str) -> Self {
34        match s {
35            "backward" => NormMode::Backward,
36            "ortho" => NormMode::Ortho,
37            "forward" => NormMode::Forward,
38            _ => NormMode::None,
39        }
40    }
41}
42
43/// Convert a normalization mode string to NormMode enum
44#[allow(dead_code)]
45pub fn parse_norm_mode(_norm: Option<&str>, isinverse: bool) -> NormMode {
46    match _norm {
47        Some(s) => NormMode::from(s),
48        None if isinverse => NormMode::Backward, // Default for _inverse transforms
49        None => NormMode::None,                  // Default for forward transforms
50    }
51}
52
53/// Apply normalization to FFT results based on the specified mode
54#[allow(dead_code)]
55fn apply_normalization(data: &mut [Complex64], n: usize, mode: NormMode) -> FFTResult<()> {
56    match mode {
57        NormMode::None => {} // No normalization
58        NormMode::Backward => {
59            let n_f64 = n as f64;
60            let scale = safe_divide(1.0, n_f64).map_err(|_| {
61                FFTError::ValueError(
62                    "Division by zero in backward normalization: FFT size is zero".to_string(),
63                )
64            })?;
65            data.iter_mut().for_each(|c| *c *= scale);
66        }
67        NormMode::Ortho => {
68            let n_f64 = n as f64;
69            let sqrt_n = safe_sqrt(n_f64).map_err(|_| {
70                FFTError::ComputationError(
71                    "Invalid square root in orthogonal normalization".to_string(),
72                )
73            })?;
74            let scale = safe_divide(1.0, sqrt_n).map_err(|_| {
75                FFTError::ValueError("Division by zero in orthogonal normalization".to_string())
76            })?;
77            data.iter_mut().for_each(|c| *c *= scale);
78        }
79        NormMode::Forward => {
80            let n_f64 = n as f64;
81            let scale = safe_divide(1.0, n_f64).map_err(|_| {
82                FFTError::ValueError(
83                    "Division by zero in forward normalization: FFT size is zero".to_string(),
84                )
85            })?;
86            data.iter_mut().for_each(|c| *c *= scale);
87        }
88    }
89    Ok(())
90}
91
92/// Convert a single value to Complex64
93#[allow(dead_code)]
94fn convert_to_complex<T>(val: T) -> FFTResult<Complex64>
95where
96    T: NumCast + Copy + Debug + 'static,
97{
98    // First try to cast directly to f64 (for real numbers)
99    if let Some(real) = NumCast::from(val) {
100        return Ok(Complex64::new(real, 0.0));
101    }
102
103    // If direct casting fails, check if it's already a Complex64
104    use std::any::Any;
105    if let Some(complex) = (&val as &dyn Any).downcast_ref::<Complex64>() {
106        return Ok(*complex);
107    }
108
109    // Try to handle f32 complex numbers
110    if let Some(complex32) = (&val as &dyn Any).downcast_ref::<scirs2_core::numeric::Complex<f32>>()
111    {
112        return Ok(Complex64::new(complex32.re as f64, complex32.im as f64));
113    }
114
115    Err(FFTError::ValueError(format!(
116        "Could not convert {val:?} to numeric type"
117    )))
118}
119
120/// Convert input data to complex values
121#[allow(dead_code)]
122fn to_complex<T>(input: &[T]) -> FFTResult<Vec<Complex64>>
123where
124    T: NumCast + Copy + Debug + 'static,
125{
126    input.iter().map(|&val| convert_to_complex(val)).collect()
127}
128
129/// Compute the 1-dimensional Fast Fourier Transform
130///
131/// # Arguments
132///
133/// * `input` - Input data array
134/// * `n` - Length of the output (optional)
135///
136/// # Returns
137///
138/// A vector of complex values representing the FFT result
139///
140/// # Examples
141///
142/// ```
143/// use scirs2_fft::fft;
144/// use scirs2_core::numeric::Complex64;
145///
146/// // Generate a simple signal
147/// let signal = vec![1.0, 2.0, 3.0, 4.0];
148///
149/// // Compute the FFT
150/// let spectrum = fft(&signal, None).unwrap();
151///
152/// // The DC component should be the sum of the input
153/// assert!((spectrum[0].re - 10.0).abs() < 1e-10);
154/// assert!(spectrum[0].im.abs() < 1e-10);
155/// ```
156#[allow(dead_code)]
157pub fn fft<T>(input: &[T], n: Option<usize>) -> FFTResult<Vec<Complex64>>
158where
159    T: NumCast + Copy + Debug + 'static,
160{
161    // Input validation
162    if input.is_empty() {
163        return Err(FFTError::ValueError("Input cannot be empty".to_string()));
164    }
165
166    // Determine the FFT size (n or next power of 2 if n is None)
167    let input_len = input.len();
168    let fft_size = n.unwrap_or_else(|| input_len.next_power_of_two());
169
170    // Convert _input to complex numbers
171    let mut data = to_complex(input)?;
172
173    // Pad or truncate data to match fft_size
174    if fft_size != input_len {
175        if fft_size > input_len {
176            // Pad with zeros
177            data.resize(fft_size, Complex64::new(0.0, 0.0));
178        } else {
179            // Truncate
180            data.truncate(fft_size);
181        }
182    }
183
184    // Use rustfft library for computation
185    let mut planner = FftPlanner::new();
186    let fft = planner.plan_fft_forward(fft_size);
187
188    // Convert to rustfft-compatible complex type
189    let mut buffer: Vec<RustComplex<f64>> =
190        data.iter().map(|c| RustComplex::new(c.re, c.im)).collect();
191
192    // Perform FFT in-place
193    fft.process(&mut buffer);
194
195    // Convert back to our Complex64 type
196    let result: Vec<Complex64> = buffer
197        .into_iter()
198        .map(|c| Complex64::new(c.re, c.im))
199        .collect();
200
201    Ok(result)
202}
203
204/// Compute the inverse 1-dimensional Fast Fourier Transform
205///
206/// # Arguments
207///
208/// * `input` - Input data array
209/// * `n` - Length of the output (optional)
210///
211/// # Returns
212///
213/// A vector of complex values representing the inverse FFT result
214///
215/// # Examples
216///
217/// ```
218/// use scirs2_fft::{fft, ifft};
219/// use scirs2_core::numeric::Complex64;
220///
221/// // Generate a simple signal
222/// let signal = vec![1.0, 2.0, 3.0, 4.0];
223///
224/// // Compute the FFT
225/// let spectrum = fft(&signal, None).unwrap();
226///
227/// // Compute the inverse FFT
228/// let reconstructed = ifft(&spectrum, None).unwrap();
229///
230/// // The reconstructed signal should match the original
231/// for (i, val) in signal.iter().enumerate() {
232///     assert!((*val - reconstructed[i].re).abs() < 1e-10);
233///     assert!(reconstructed[i].im.abs() < 1e-10);
234/// }
235/// ```
236#[allow(dead_code)]
237pub fn ifft<T>(input: &[T], n: Option<usize>) -> FFTResult<Vec<Complex64>>
238where
239    T: NumCast + Copy + Debug + 'static,
240{
241    // Input validation
242    if input.is_empty() {
243        return Err(FFTError::ValueError("Input cannot be empty".to_string()));
244    }
245
246    // Determine the FFT size
247    let input_len = input.len();
248    let fft_size = n.unwrap_or_else(|| input_len.next_power_of_two());
249
250    // Convert _input to complex numbers
251    let mut data = to_complex(input)?;
252
253    // Pad or truncate data to match fft_size
254    if fft_size != input_len {
255        if fft_size > input_len {
256            // Pad with zeros
257            data.resize(fft_size, Complex64::new(0.0, 0.0));
258        } else {
259            // Truncate
260            data.truncate(fft_size);
261        }
262    }
263
264    // Create FFT planner and plan
265    let mut planner = FftPlanner::new();
266    let ifft = planner.plan_fft_inverse(fft_size);
267
268    // Convert to rustfft-compatible complex type
269    let mut buffer: Vec<RustComplex<f64>> =
270        data.iter().map(|c| RustComplex::new(c.re, c.im)).collect();
271
272    // Perform inverse FFT in-place
273    ifft.process(&mut buffer);
274
275    // Convert back to our Complex64 type with normalization
276    let mut result: Vec<Complex64> = buffer
277        .into_iter()
278        .map(|c| Complex64::new(c.re, c.im))
279        .collect();
280
281    // Apply 1/N normalization (standard for IFFT)
282    apply_normalization(&mut result, fft_size, NormMode::Backward)?;
283
284    // Truncate if necessary to match the original _input length
285    if n.is_none() && fft_size > input_len {
286        result.truncate(input_len);
287    }
288
289    Ok(result)
290}
291
292/// Compute the 2-dimensional Fast Fourier Transform
293///
294/// # Arguments
295///
296/// * `input` - Input 2D array
297/// * `shape` - Shape of the output (optional)
298/// * `axes` - Axes along which to compute the FFT (optional)
299/// * `norm` - Normalization mode: "backward", "ortho", or "forward" (optional)
300///
301/// # Returns
302///
303/// A 2D array of complex values representing the FFT result
304///
305/// # Examples
306///
307/// ```
308/// use scirs2_fft::fft2;
309/// use scirs2_core::ndarray::{array, Array2};
310///
311/// // Create a simple 2x2 array
312/// let input = array![[1.0, 2.0], [3.0, 4.0]];
313///
314/// // Compute the 2D FFT
315/// let result = fft2(&input, None, None, None).unwrap();
316///
317/// // The DC component should be the sum of all elements
318/// assert!((result[[0, 0]].re - 10.0).abs() < 1e-10);
319/// ```
320#[allow(dead_code)]
321pub fn fft2<T>(
322    input: &Array2<T>,
323    shape: Option<(usize, usize)>,
324    axes: Option<(i32, i32)>,
325    norm: Option<&str>,
326) -> FFTResult<Array2<Complex64>>
327where
328    T: NumCast + Copy + Debug + 'static,
329{
330    // Get input array shape
331    let inputshape = input.shape();
332
333    // Determine output shape
334    let outputshape = shape.unwrap_or((inputshape[0], inputshape[1]));
335
336    // Determine axes to perform FFT on
337    let axes = axes.unwrap_or((0, 1));
338
339    // Validate axes
340    if axes.0 < 0 || axes.0 > 1 || axes.1 < 0 || axes.1 > 1 || axes.0 == axes.1 {
341        return Err(FFTError::ValueError("Invalid axes for 2D FFT".to_string()));
342    }
343
344    // Parse normalization mode
345    let norm_mode = parse_norm_mode(norm, false);
346
347    // Create the output array
348    let mut output = Array2::<Complex64>::zeros(outputshape);
349
350    // Convert input array to complex numbers
351    let mut complex_input = Array2::<Complex64>::zeros((inputshape[0], inputshape[1]));
352    for i in 0..inputshape[0] {
353        for j in 0..inputshape[1] {
354            let val = input[[i, j]];
355
356            // Convert using the unified conversion function
357            complex_input[[i, j]] = convert_to_complex(val)?;
358        }
359    }
360
361    // Pad or truncate to match output shape if necessary
362    let mut padded_input = if inputshape != [outputshape.0, outputshape.1] {
363        let mut padded = Array2::<Complex64>::zeros((outputshape.0, outputshape.1));
364        let copy_rows = std::cmp::min(inputshape[0], outputshape.0);
365        let copy_cols = std::cmp::min(inputshape[1], outputshape.1);
366
367        for i in 0..copy_rows {
368            for j in 0..copy_cols {
369                padded[[i, j]] = complex_input[[i, j]];
370            }
371        }
372        padded
373    } else {
374        complex_input
375    };
376
377    // Create FFT planner
378    let mut planner = FftPlanner::new();
379
380    // Perform FFT along each row
381    let row_fft = planner.plan_fft_forward(outputshape.1);
382    for mut row in padded_input.rows_mut() {
383        // Convert to rustfft compatible format
384        let mut buffer: Vec<RustComplex<f64>> =
385            row.iter().map(|&c| RustComplex::new(c.re, c.im)).collect();
386
387        // Perform FFT
388        row_fft.process(&mut buffer);
389
390        // Update row with FFT result
391        for (i, val) in buffer.iter().enumerate() {
392            row[i] = Complex64::new(val.re, val.im);
393        }
394    }
395
396    // Perform FFT along each column
397    let col_fft = planner.plan_fft_forward(outputshape.0);
398    for mut col in padded_input.columns_mut() {
399        // Convert to rustfft compatible format
400        let mut buffer: Vec<RustComplex<f64>> =
401            col.iter().map(|&c| RustComplex::new(c.re, c.im)).collect();
402
403        // Perform FFT
404        col_fft.process(&mut buffer);
405
406        // Update column with FFT result
407        for (i, val) in buffer.iter().enumerate() {
408            col[i] = Complex64::new(val.re, val.im);
409        }
410    }
411
412    // Apply normalization if needed
413    if norm_mode != NormMode::None {
414        let total_elements = outputshape.0 * outputshape.1;
415        let scale = match norm_mode {
416            NormMode::Backward => 1.0 / (total_elements as f64),
417            NormMode::Ortho => 1.0 / (total_elements as f64).sqrt(),
418            NormMode::Forward => 1.0 / (total_elements as f64),
419            NormMode::None => 1.0, // Should not happen due to earlier check
420        };
421
422        padded_input.mapv_inplace(|x| x * scale);
423    }
424
425    // Copy result to output
426    output.assign(&padded_input);
427
428    Ok(output)
429}
430
431/// Compute the inverse 2-dimensional Fast Fourier Transform
432///
433/// # Arguments
434///
435/// * `input` - Input 2D array
436/// * `shape` - Shape of the output (optional)
437/// * `axes` - Axes along which to compute the inverse FFT (optional)
438/// * `norm` - Normalization mode: "backward", "ortho", or "forward" (optional)
439///
440/// # Returns
441///
442/// A 2D array of complex values representing the inverse FFT result
443///
444/// # Examples
445///
446/// ```
447/// use scirs2_fft::{fft2, ifft2};
448/// use scirs2_core::ndarray::{array, Array2};
449///
450/// // Create a simple 2x2 array
451/// let input = array![[1.0, 2.0], [3.0, 4.0]];
452///
453/// // Compute the 2D FFT
454/// let spectrum = fft2(&input, None, None, None).unwrap();
455///
456/// // Compute the inverse 2D FFT
457/// let reconstructed = ifft2(&spectrum, None, None, None).unwrap();
458///
459/// // The reconstructed signal should match the original
460/// for i in 0..2 {
461///     for j in 0..2 {
462///         assert!((input[[i, j]] - reconstructed[[i, j]].re).abs() < 1e-10);
463///         assert!(reconstructed[[i, j]].im.abs() < 1e-10);
464///     }
465/// }
466/// ```
467#[allow(dead_code)]
468pub fn ifft2<T>(
469    input: &Array2<T>,
470    shape: Option<(usize, usize)>,
471    axes: Option<(i32, i32)>,
472    norm: Option<&str>,
473) -> FFTResult<Array2<Complex64>>
474where
475    T: NumCast + Copy + Debug + 'static,
476{
477    // Get input array shape
478    let inputshape = input.shape();
479
480    // Determine output shape
481    let outputshape = shape.unwrap_or((inputshape[0], inputshape[1]));
482
483    // Determine axes to perform FFT on
484    let axes = axes.unwrap_or((0, 1));
485
486    // Validate axes
487    if axes.0 < 0 || axes.0 > 1 || axes.1 < 0 || axes.1 > 1 || axes.0 == axes.1 {
488        return Err(FFTError::ValueError("Invalid axes for 2D IFFT".to_string()));
489    }
490
491    // Parse normalization mode (default is "backward" for inverse FFT)
492    let norm_mode = parse_norm_mode(norm, true);
493
494    // Convert input to complex and copy to output shape
495    let mut complex_input = Array2::<Complex64>::zeros((inputshape[0], inputshape[1]));
496    for i in 0..inputshape[0] {
497        for j in 0..inputshape[1] {
498            let val = input[[i, j]];
499
500            // Convert using the unified conversion function
501            complex_input[[i, j]] = convert_to_complex(val)?;
502        }
503    }
504
505    // Pad or truncate to match output shape if necessary
506    let mut padded_input = if inputshape != [outputshape.0, outputshape.1] {
507        let mut padded = Array2::<Complex64>::zeros((outputshape.0, outputshape.1));
508        let copy_rows = std::cmp::min(inputshape[0], outputshape.0);
509        let copy_cols = std::cmp::min(inputshape[1], outputshape.1);
510
511        for i in 0..copy_rows {
512            for j in 0..copy_cols {
513                padded[[i, j]] = complex_input[[i, j]];
514            }
515        }
516        padded
517    } else {
518        complex_input
519    };
520
521    // Create FFT planner
522    let mut planner = FftPlanner::new();
523
524    // Perform inverse FFT along each row
525    let row_ifft = planner.plan_fft_inverse(outputshape.1);
526    for mut row in padded_input.rows_mut() {
527        // Convert to rustfft compatible format
528        let mut buffer: Vec<RustComplex<f64>> =
529            row.iter().map(|&c| RustComplex::new(c.re, c.im)).collect();
530
531        // Perform inverse FFT
532        row_ifft.process(&mut buffer);
533
534        // Update row with IFFT result
535        for (i, val) in buffer.iter().enumerate() {
536            row[i] = Complex64::new(val.re, val.im);
537        }
538    }
539
540    // Perform inverse FFT along each column
541    let col_ifft = planner.plan_fft_inverse(outputshape.0);
542    for mut col in padded_input.columns_mut() {
543        // Convert to rustfft compatible format
544        let mut buffer: Vec<RustComplex<f64>> =
545            col.iter().map(|&c| RustComplex::new(c.re, c.im)).collect();
546
547        // Perform inverse FFT
548        col_ifft.process(&mut buffer);
549
550        // Update column with IFFT result
551        for (i, val) in buffer.iter().enumerate() {
552            col[i] = Complex64::new(val.re, val.im);
553        }
554    }
555
556    // Apply appropriate normalization
557    let total_elements = outputshape.0 * outputshape.1;
558    let scale = match norm_mode {
559        NormMode::Backward => 1.0 / (total_elements as f64),
560        NormMode::Ortho => 1.0 / (total_elements as f64).sqrt(),
561        NormMode::Forward => 1.0, // No additional normalization for forward mode in IFFT
562        NormMode::None => 1.0,    // No normalization
563    };
564
565    if scale != 1.0 {
566        padded_input.mapv_inplace(|x| x * scale);
567    }
568
569    Ok(padded_input)
570}
571
572/// Compute the N-dimensional Fast Fourier Transform
573///
574/// # Arguments
575///
576/// * `input` - Input N-dimensional array
577/// * `shape` - Shape of the output (optional)
578/// * `axes` - Axes along which to compute the FFT (optional)
579/// * `norm` - Normalization mode: "backward", "ortho", or "forward" (optional)
580/// * `overwrite_x` - Whether to overwrite the input array (optional)
581/// * `workers` - Number of worker threads to use (optional)
582///
583/// # Returns
584///
585/// An N-dimensional array of complex values representing the FFT result
586///
587/// # Examples
588///
589/// ```
590/// use scirs2_fft::fftn;
591/// use scirs2_core::ndarray::{Array, IxDyn};
592///
593/// // Create a 3D array
594/// let mut data = Array::zeros(IxDyn(&[2, 2, 2]));
595/// data[[0, 0, 0]] = 1.0;
596/// data[[1, 1, 1]] = 1.0;
597///
598/// // Compute the N-dimensional FFT
599/// let result = fftn(&data, None, None, None, None, None).unwrap();
600///
601/// // Check dimensions
602/// assert_eq!(result.shape(), &[2, 2, 2]);
603/// ```
604#[allow(clippy::too_many_arguments)]
605#[allow(dead_code)]
606pub fn fftn<T>(
607    input: &ArrayD<T>,
608    shape: Option<Vec<usize>>,
609    axes: Option<Vec<usize>>,
610    norm: Option<&str>,
611    _overwrite_x: Option<bool>,
612    _workers: Option<usize>,
613) -> FFTResult<ArrayD<Complex64>>
614where
615    T: NumCast + Copy + Debug + 'static,
616{
617    let inputshape = input.shape().to_vec();
618    let input_ndim = inputshape.len();
619
620    // Determine output shape
621    let outputshape = shape.unwrap_or_else(|| inputshape.clone());
622
623    // Validate output shape
624    if outputshape.len() != input_ndim {
625        return Err(FFTError::ValueError(
626            "Output shape must have the same number of dimensions as input".to_string(),
627        ));
628    }
629
630    // Determine axes to perform FFT on
631    let axes = axes.unwrap_or_else(|| (0..input_ndim).collect());
632
633    // Validate axes
634    for &axis in &axes {
635        if axis >= input_ndim {
636            return Err(FFTError::ValueError(format!(
637                "Axis {axis} out of bounds for array of dimension {input_ndim}"
638            )));
639        }
640    }
641
642    // Parse normalization mode
643    let norm_mode = parse_norm_mode(norm, false);
644
645    // Convert input array to complex
646    let mut complex_input = ArrayD::<Complex64>::zeros(IxDyn(&inputshape));
647    for (idx, &val) in input.iter().enumerate() {
648        let mut idx_vec = Vec::with_capacity(input_ndim);
649        let mut remaining = idx;
650
651        for &dim in input.shape().iter().rev() {
652            idx_vec.push(remaining % dim);
653            remaining /= dim;
654        }
655
656        idx_vec.reverse();
657
658        complex_input[IxDyn(&idx_vec)] = convert_to_complex(val)?;
659    }
660
661    // Pad or truncate to match output shape if necessary
662    let mut result = if inputshape != outputshape {
663        let mut padded = ArrayD::<Complex64>::zeros(IxDyn(&outputshape));
664
665        // Copy all elements that fit within both arrays
666        for (idx, &val) in complex_input.iter().enumerate() {
667            let mut idx_vec = Vec::with_capacity(input_ndim);
668            let mut remaining = idx;
669
670            for &dim in input.shape().iter().rev() {
671                idx_vec.push(remaining % dim);
672                remaining /= dim;
673            }
674
675            idx_vec.reverse();
676
677            let mut in_bounds = true;
678            for (dim, &idx_val) in idx_vec.iter().enumerate() {
679                if idx_val >= outputshape[dim] {
680                    in_bounds = false;
681                    break;
682                }
683            }
684
685            if in_bounds {
686                padded[IxDyn(&idx_vec)] = val;
687            }
688        }
689
690        padded
691    } else {
692        complex_input
693    };
694
695    // Create FFT planner
696    let mut planner = FftPlanner::new();
697
698    // Perform FFT along each axis
699    for &axis in &axes {
700        let axis_len = outputshape[axis];
701        let fft = planner.plan_fft_forward(axis_len);
702
703        // For each slice along the current axis
704        let axis = Axis(axis);
705
706        for mut lane in result.lanes_mut(axis) {
707            // Convert to rustfft compatible format
708            let mut buffer: Vec<RustComplex<f64>> =
709                lane.iter().map(|&c| RustComplex::new(c.re, c.im)).collect();
710
711            // Perform FFT
712            fft.process(&mut buffer);
713
714            // Update lane with FFT result
715            for (i, val) in buffer.iter().enumerate() {
716                lane[i] = Complex64::new(val.re, val.im);
717            }
718        }
719    }
720
721    // Apply normalization if needed
722    if norm_mode != NormMode::None {
723        let total_elements: usize = outputshape.iter().product();
724        let scale = match norm_mode {
725            NormMode::Backward => 1.0 / (total_elements as f64),
726            NormMode::Ortho => 1.0 / (total_elements as f64).sqrt(),
727            NormMode::Forward => 1.0 / (total_elements as f64),
728            NormMode::None => 1.0, // Should not happen due to earlier check
729        };
730
731        result.mapv_inplace(|_x| _x * scale);
732    }
733
734    Ok(result)
735}
736
737/// Compute the inverse N-dimensional Fast Fourier Transform
738///
739/// # Arguments
740///
741/// * `input` - Input N-dimensional array
742/// * `shape` - Shape of the output (optional)
743/// * `axes` - Axes along which to compute the inverse FFT (optional)
744/// * `norm` - Normalization mode: "backward", "ortho", or "forward" (optional)
745/// * `overwrite_x` - Whether to overwrite the input array (optional)
746/// * `workers` - Number of worker threads to use (optional)
747///
748/// # Returns
749///
750/// An N-dimensional array of complex values representing the inverse FFT result
751///
752/// # Examples
753///
754/// ```
755/// use scirs2_fft::{fftn, ifftn};
756/// use scirs2_core::ndarray::{Array, IxDyn};
757/// use scirs2_core::numeric::Complex64;
758///
759/// // Create a 3D array
760/// let mut data = Array::zeros(IxDyn(&[2, 2, 2]));
761/// data[[0, 0, 0]] = 1.0;
762/// data[[1, 1, 1]] = 1.0;
763///
764/// // Compute the N-dimensional FFT
765/// let spectrum = fftn(&data, None, None, None, None, None).unwrap();
766///
767/// // Compute the inverse N-dimensional FFT
768/// let result = ifftn(&spectrum, None, None, None, None, None).unwrap();
769///
770/// // Check if the original data is recovered
771/// for i in 0..2 {
772///     for j in 0..2 {
773///         for k in 0..2 {
774///             let expected = if (i == 0 && j == 0 && k == 0) || (i == 1 && j == 1 && k == 1) {
775///                 1.0
776///             } else {
777///                 0.0
778///             };
779///             assert!((result[[i, j, k]].re - expected).abs() < 1e-10);
780///             assert!(result[[i, j, k]].im.abs() < 1e-10);
781///         }
782///     }
783/// }
784/// ```
785#[allow(clippy::too_many_arguments)]
786#[allow(dead_code)]
787pub fn ifftn<T>(
788    input: &ArrayD<T>,
789    shape: Option<Vec<usize>>,
790    axes: Option<Vec<usize>>,
791    norm: Option<&str>,
792    _overwrite_x: Option<bool>,
793    _workers: Option<usize>,
794) -> FFTResult<ArrayD<Complex64>>
795where
796    T: NumCast + Copy + Debug + 'static,
797{
798    let inputshape = input.shape().to_vec();
799    let input_ndim = inputshape.len();
800
801    // Determine output shape
802    let outputshape = shape.unwrap_or_else(|| inputshape.clone());
803
804    // Validate output shape
805    if outputshape.len() != input_ndim {
806        return Err(FFTError::ValueError(
807            "Output shape must have the same number of dimensions as input".to_string(),
808        ));
809    }
810
811    // Determine axes to perform FFT on
812    let axes = axes.unwrap_or_else(|| (0..input_ndim).collect());
813
814    // Validate axes
815    for &axis in &axes {
816        if axis >= input_ndim {
817            return Err(FFTError::ValueError(format!(
818                "Axis {axis} out of bounds for array of dimension {input_ndim}"
819            )));
820        }
821    }
822
823    // Parse normalization mode (default is "backward" for inverse FFT)
824    let norm_mode = parse_norm_mode(norm, true);
825
826    // Create workspace array - convert input to complex first
827    let mut complex_input = ArrayD::<Complex64>::zeros(IxDyn(&inputshape));
828    for (idx, &val) in input.iter().enumerate() {
829        let mut idx_vec = Vec::with_capacity(input_ndim);
830        let mut remaining = idx;
831
832        for &dim in input.shape().iter().rev() {
833            idx_vec.push(remaining % dim);
834            remaining /= dim;
835        }
836
837        idx_vec.reverse();
838
839        // Try to convert to Complex64
840        complex_input[IxDyn(&idx_vec)] = convert_to_complex(val)?;
841    }
842
843    // Now handle padding/resizing if needed
844    let mut result = if inputshape != outputshape {
845        let mut padded = ArrayD::<Complex64>::zeros(IxDyn(&outputshape));
846
847        // Copy all elements that fit within both arrays
848        for (idx, &val) in complex_input.iter().enumerate() {
849            let mut idx_vec = Vec::with_capacity(input_ndim);
850            let mut remaining = idx;
851
852            for &dim in input.shape().iter().rev() {
853                idx_vec.push(remaining % dim);
854                remaining /= dim;
855            }
856
857            idx_vec.reverse();
858
859            let mut in_bounds = true;
860            for (dim, &idx_val) in idx_vec.iter().enumerate() {
861                if idx_val >= outputshape[dim] {
862                    in_bounds = false;
863                    break;
864                }
865            }
866
867            if in_bounds {
868                padded[IxDyn(&idx_vec)] = val;
869            }
870        }
871
872        padded
873    } else {
874        complex_input
875    };
876
877    // Create FFT planner
878    let mut planner = FftPlanner::new();
879
880    // Perform inverse FFT along each axis
881    for &axis in &axes {
882        let axis_len = outputshape[axis];
883        let ifft = planner.plan_fft_inverse(axis_len);
884
885        // For each slice along the current axis
886        let axis = Axis(axis);
887
888        for mut lane in result.lanes_mut(axis) {
889            // Convert to rustfft compatible format
890            let mut buffer: Vec<RustComplex<f64>> =
891                lane.iter().map(|&c| RustComplex::new(c.re, c.im)).collect();
892
893            // Perform inverse FFT
894            ifft.process(&mut buffer);
895
896            // Update lane with IFFT result
897            for (i, val) in buffer.iter().enumerate() {
898                lane[i] = Complex64::new(val.re, val.im);
899            }
900        }
901    }
902
903    // Apply appropriate normalization
904    if norm_mode != NormMode::None {
905        let total_elements: usize = axes.iter().map(|&a| outputshape[a]).product();
906        let scale = match norm_mode {
907            NormMode::Backward => 1.0 / (total_elements as f64),
908            NormMode::Ortho => 1.0 / (total_elements as f64).sqrt(),
909            NormMode::Forward => 1.0, // No additional normalization
910            NormMode::None => 1.0,    // No normalization
911        };
912
913        if scale != 1.0 {
914            result.mapv_inplace(|_x| _x * scale);
915        }
916    }
917
918    Ok(result)
919}