scirs2_fft/fft/
algorithms.rs

1/*!
2 * FFT algorithm implementations
3 *
4 * This module provides implementations of the Fast Fourier Transform (FFT)
5 * and its inverse (IFFT) in 1D, 2D, and N-dimensional cases.
6 */
7
8use crate::error::{FFTError, FFTResult};
9use ndarray::{Array2, ArrayD, Axis, IxDyn};
10use num_complex::Complex64;
11use num_traits::NumCast;
12use rustfft::{num_complex::Complex as RustComplex, FftPlanner};
13use scirs2_core::safe_ops::{safe_divide, safe_sqrt};
14use std::fmt::Debug;
15
16// We're using the serial implementation even with parallel feature enabled,
17// since we're not using parallelism at this level
18
/// Normalization mode for FFT operations.
///
/// The string names accepted by the `From<&str>` impl mirror SciPy's `norm`
/// argument (`"backward"`, `"ortho"`, `"forward"`). Any other string maps to
/// [`NormMode::None`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NormMode {
    /// No scaling at all.
    ///
    /// This is the default for forward transforms when no mode is given. It is
    /// also what an unrecognised mode string parses to.
    None,
    /// Inverse transforms are scaled by `1/n`. This is the default for inverse
    /// transforms.
    ///
    /// Under SciPy's `"backward"` convention the matching forward transform is
    /// left unscaled.
    Backward,
    /// Both directions are scaled by `1/sqrt(n)`, making the transform unitary.
    Ortho,
    /// Forward transforms are scaled by `1/n`; inverse transforms apply no
    /// scaling (SciPy's `"forward"` convention).
    Forward,
}

impl From<&str> for NormMode {
    /// Parses a SciPy-style normalization name.
    ///
    /// Only the exact lowercase names `"backward"`, `"ortho"` and `"forward"`
    /// are recognised. Every other string, including differently-cased
    /// spellings, falls back to [`NormMode::None`].
    fn from(s: &str) -> Self {
        match s {
            "backward" => NormMode::Backward,
            "ortho" => NormMode::Ortho,
            "forward" => NormMode::Forward,
            _ => NormMode::None,
        }
    }
}
42
43/// Convert a normalization mode string to NormMode enum
44#[allow(dead_code)]
45pub fn parse_norm_mode(_norm: Option<&str>, isinverse: bool) -> NormMode {
46    match _norm {
47        Some(s) => NormMode::from(s),
48        None if isinverse => NormMode::Backward, // Default for _inverse transforms
49        None => NormMode::None,                  // Default for forward transforms
50    }
51}
52
53/// Apply normalization to FFT results based on the specified mode
54#[allow(dead_code)]
55fn apply_normalization(data: &mut [Complex64], n: usize, mode: NormMode) -> FFTResult<()> {
56    match mode {
57        NormMode::None => {} // No normalization
58        NormMode::Backward => {
59            let n_f64 = n as f64;
60            let scale = safe_divide(1.0, n_f64).map_err(|_| {
61                FFTError::ValueError(
62                    "Division by zero in backward normalization: FFT size is zero".to_string(),
63                )
64            })?;
65            data.iter_mut().for_each(|c| *c *= scale);
66        }
67        NormMode::Ortho => {
68            let n_f64 = n as f64;
69            let sqrt_n = safe_sqrt(n_f64).map_err(|_| {
70                FFTError::ComputationError(
71                    "Invalid square root in orthogonal normalization".to_string(),
72                )
73            })?;
74            let scale = safe_divide(1.0, sqrt_n).map_err(|_| {
75                FFTError::ValueError("Division by zero in orthogonal normalization".to_string())
76            })?;
77            data.iter_mut().for_each(|c| *c *= scale);
78        }
79        NormMode::Forward => {
80            let n_f64 = n as f64;
81            let scale = safe_divide(1.0, n_f64).map_err(|_| {
82                FFTError::ValueError(
83                    "Division by zero in forward normalization: FFT size is zero".to_string(),
84                )
85            })?;
86            data.iter_mut().for_each(|c| *c *= scale);
87        }
88    }
89    Ok(())
90}
91
92/// Convert a single value to Complex64
93#[allow(dead_code)]
94fn convert_to_complex<T>(val: T) -> FFTResult<Complex64>
95where
96    T: NumCast + Copy + Debug + 'static,
97{
98    // First try to cast directly to f64 (for real numbers)
99    if let Some(real) = num_traits::cast::<T, f64>(val) {
100        return Ok(Complex64::new(real, 0.0));
101    }
102
103    // If direct casting fails, check if it's already a Complex64
104    use std::any::Any;
105    if let Some(complex) = (&val as &dyn Any).downcast_ref::<Complex64>() {
106        return Ok(*complex);
107    }
108
109    // Try to handle f32 complex numbers
110    if let Some(complex32) = (&val as &dyn Any).downcast_ref::<num_complex::Complex<f32>>() {
111        return Ok(Complex64::new(complex32.re as f64, complex32.im as f64));
112    }
113
114    Err(FFTError::ValueError(format!(
115        "Could not convert {val:?} to numeric type"
116    )))
117}
118
119/// Convert input data to complex values
120#[allow(dead_code)]
121fn to_complex<T>(input: &[T]) -> FFTResult<Vec<Complex64>>
122where
123    T: NumCast + Copy + Debug + 'static,
124{
125    input.iter().map(|&val| convert_to_complex(val)).collect()
126}
127
128/// Compute the 1-dimensional Fast Fourier Transform
129///
130/// # Arguments
131///
132/// * `input` - Input data array
133/// * `n` - Length of the output (optional)
134///
135/// # Returns
136///
137/// A vector of complex values representing the FFT result
138///
139/// # Examples
140///
141/// ```
142/// use scirs2_fft::fft;
143/// use num_complex::Complex64;
144///
145/// // Generate a simple signal
146/// let signal = vec![1.0, 2.0, 3.0, 4.0];
147///
148/// // Compute the FFT
149/// let spectrum = fft(&signal, None).unwrap();
150///
151/// // The DC component should be the sum of the input
152/// assert!((spectrum[0].re - 10.0).abs() < 1e-10);
153/// assert!(spectrum[0].im.abs() < 1e-10);
154/// ```
155#[allow(dead_code)]
156pub fn fft<T>(input: &[T], n: Option<usize>) -> FFTResult<Vec<Complex64>>
157where
158    T: NumCast + Copy + Debug + 'static,
159{
160    // Input validation
161    if input.is_empty() {
162        return Err(FFTError::ValueError("Input cannot be empty".to_string()));
163    }
164
165    // Determine the FFT size (n or next power of 2 if n is None)
166    let input_len = input.len();
167    let fft_size = n.unwrap_or_else(|| input_len.next_power_of_two());
168
169    // Convert _input to complex numbers
170    let mut data = to_complex(input)?;
171
172    // Pad or truncate data to match fft_size
173    if fft_size != input_len {
174        if fft_size > input_len {
175            // Pad with zeros
176            data.resize(fft_size, Complex64::new(0.0, 0.0));
177        } else {
178            // Truncate
179            data.truncate(fft_size);
180        }
181    }
182
183    // Use rustfft library for computation
184    let mut planner = FftPlanner::new();
185    let fft = planner.plan_fft_forward(fft_size);
186
187    // Convert to rustfft-compatible complex type
188    let mut buffer: Vec<RustComplex<f64>> =
189        data.iter().map(|c| RustComplex::new(c.re, c.im)).collect();
190
191    // Perform FFT in-place
192    fft.process(&mut buffer);
193
194    // Convert back to our Complex64 type
195    let result: Vec<Complex64> = buffer
196        .into_iter()
197        .map(|c| Complex64::new(c.re, c.im))
198        .collect();
199
200    Ok(result)
201}
202
203/// Compute the inverse 1-dimensional Fast Fourier Transform
204///
205/// # Arguments
206///
207/// * `input` - Input data array
208/// * `n` - Length of the output (optional)
209///
210/// # Returns
211///
212/// A vector of complex values representing the inverse FFT result
213///
214/// # Examples
215///
216/// ```
217/// use scirs2_fft::{fft, ifft};
218/// use num_complex::Complex64;
219///
220/// // Generate a simple signal
221/// let signal = vec![1.0, 2.0, 3.0, 4.0];
222///
223/// // Compute the FFT
224/// let spectrum = fft(&signal, None).unwrap();
225///
226/// // Compute the inverse FFT
227/// let reconstructed = ifft(&spectrum, None).unwrap();
228///
229/// // The reconstructed signal should match the original
230/// for (i, val) in signal.iter().enumerate() {
231///     assert!((*val - reconstructed[i].re).abs() < 1e-10);
232///     assert!(reconstructed[i].im.abs() < 1e-10);
233/// }
234/// ```
235#[allow(dead_code)]
236pub fn ifft<T>(input: &[T], n: Option<usize>) -> FFTResult<Vec<Complex64>>
237where
238    T: NumCast + Copy + Debug + 'static,
239{
240    // Input validation
241    if input.is_empty() {
242        return Err(FFTError::ValueError("Input cannot be empty".to_string()));
243    }
244
245    // Determine the FFT size
246    let input_len = input.len();
247    let fft_size = n.unwrap_or_else(|| input_len.next_power_of_two());
248
249    // Convert _input to complex numbers
250    let mut data = to_complex(input)?;
251
252    // Pad or truncate data to match fft_size
253    if fft_size != input_len {
254        if fft_size > input_len {
255            // Pad with zeros
256            data.resize(fft_size, Complex64::new(0.0, 0.0));
257        } else {
258            // Truncate
259            data.truncate(fft_size);
260        }
261    }
262
263    // Create FFT planner and plan
264    let mut planner = FftPlanner::new();
265    let ifft = planner.plan_fft_inverse(fft_size);
266
267    // Convert to rustfft-compatible complex type
268    let mut buffer: Vec<RustComplex<f64>> =
269        data.iter().map(|c| RustComplex::new(c.re, c.im)).collect();
270
271    // Perform inverse FFT in-place
272    ifft.process(&mut buffer);
273
274    // Convert back to our Complex64 type with normalization
275    let mut result: Vec<Complex64> = buffer
276        .into_iter()
277        .map(|c| Complex64::new(c.re, c.im))
278        .collect();
279
280    // Apply 1/N normalization (standard for IFFT)
281    apply_normalization(&mut result, fft_size, NormMode::Backward)?;
282
283    // Truncate if necessary to match the original _input length
284    if n.is_none() && fft_size > input_len {
285        result.truncate(input_len);
286    }
287
288    Ok(result)
289}
290
291/// Compute the 2-dimensional Fast Fourier Transform
292///
293/// # Arguments
294///
295/// * `input` - Input 2D array
296/// * `shape` - Shape of the output (optional)
297/// * `axes` - Axes along which to compute the FFT (optional)
298/// * `norm` - Normalization mode: "backward", "ortho", or "forward" (optional)
299///
300/// # Returns
301///
302/// A 2D array of complex values representing the FFT result
303///
304/// # Examples
305///
306/// ```
307/// use scirs2_fft::fft2;
308/// use ndarray::{array, Array2};
309///
310/// // Create a simple 2x2 array
311/// let input = array![[1.0, 2.0], [3.0, 4.0]];
312///
313/// // Compute the 2D FFT
314/// let result = fft2(&input, None, None, None).unwrap();
315///
316/// // The DC component should be the sum of all elements
317/// assert!((result[[0, 0]].re - 10.0).abs() < 1e-10);
318/// ```
319#[allow(dead_code)]
320pub fn fft2<T>(
321    input: &Array2<T>,
322    shape: Option<(usize, usize)>,
323    axes: Option<(i32, i32)>,
324    norm: Option<&str>,
325) -> FFTResult<Array2<Complex64>>
326where
327    T: NumCast + Copy + Debug + 'static,
328{
329    // Get input array shape
330    let inputshape = input.shape();
331
332    // Determine output shape
333    let outputshape = shape.unwrap_or((inputshape[0], inputshape[1]));
334
335    // Determine axes to perform FFT on
336    let axes = axes.unwrap_or((0, 1));
337
338    // Validate axes
339    if axes.0 < 0 || axes.0 > 1 || axes.1 < 0 || axes.1 > 1 || axes.0 == axes.1 {
340        return Err(FFTError::ValueError("Invalid axes for 2D FFT".to_string()));
341    }
342
343    // Parse normalization mode
344    let norm_mode = parse_norm_mode(norm, false);
345
346    // Create the output array
347    let mut output = Array2::<Complex64>::zeros(outputshape);
348
349    // Convert input array to complex numbers
350    let mut complex_input = Array2::<Complex64>::zeros((inputshape[0], inputshape[1]));
351    for i in 0..inputshape[0] {
352        for j in 0..inputshape[1] {
353            let val = input[[i, j]];
354
355            // Convert using the unified conversion function
356            complex_input[[i, j]] = convert_to_complex(val)?;
357        }
358    }
359
360    // Pad or truncate to match output shape if necessary
361    let mut padded_input = if inputshape != [outputshape.0, outputshape.1] {
362        let mut padded = Array2::<Complex64>::zeros((outputshape.0, outputshape.1));
363        let copy_rows = std::cmp::min(inputshape[0], outputshape.0);
364        let copy_cols = std::cmp::min(inputshape[1], outputshape.1);
365
366        for i in 0..copy_rows {
367            for j in 0..copy_cols {
368                padded[[i, j]] = complex_input[[i, j]];
369            }
370        }
371        padded
372    } else {
373        complex_input
374    };
375
376    // Create FFT planner
377    let mut planner = FftPlanner::new();
378
379    // Perform FFT along each row
380    let row_fft = planner.plan_fft_forward(outputshape.1);
381    for mut row in padded_input.rows_mut() {
382        // Convert to rustfft compatible format
383        let mut buffer: Vec<RustComplex<f64>> =
384            row.iter().map(|&c| RustComplex::new(c.re, c.im)).collect();
385
386        // Perform FFT
387        row_fft.process(&mut buffer);
388
389        // Update row with FFT result
390        for (i, val) in buffer.iter().enumerate() {
391            row[i] = Complex64::new(val.re, val.im);
392        }
393    }
394
395    // Perform FFT along each column
396    let col_fft = planner.plan_fft_forward(outputshape.0);
397    for mut col in padded_input.columns_mut() {
398        // Convert to rustfft compatible format
399        let mut buffer: Vec<RustComplex<f64>> =
400            col.iter().map(|&c| RustComplex::new(c.re, c.im)).collect();
401
402        // Perform FFT
403        col_fft.process(&mut buffer);
404
405        // Update column with FFT result
406        for (i, val) in buffer.iter().enumerate() {
407            col[i] = Complex64::new(val.re, val.im);
408        }
409    }
410
411    // Apply normalization if needed
412    if norm_mode != NormMode::None {
413        let total_elements = outputshape.0 * outputshape.1;
414        let scale = match norm_mode {
415            NormMode::Backward => 1.0 / (total_elements as f64),
416            NormMode::Ortho => 1.0 / (total_elements as f64).sqrt(),
417            NormMode::Forward => 1.0 / (total_elements as f64),
418            NormMode::None => 1.0, // Should not happen due to earlier check
419        };
420
421        padded_input.mapv_inplace(|x| x * scale);
422    }
423
424    // Copy result to output
425    output.assign(&padded_input);
426
427    Ok(output)
428}
429
430/// Compute the inverse 2-dimensional Fast Fourier Transform
431///
432/// # Arguments
433///
434/// * `input` - Input 2D array
435/// * `shape` - Shape of the output (optional)
436/// * `axes` - Axes along which to compute the inverse FFT (optional)
437/// * `norm` - Normalization mode: "backward", "ortho", or "forward" (optional)
438///
439/// # Returns
440///
441/// A 2D array of complex values representing the inverse FFT result
442///
443/// # Examples
444///
445/// ```
446/// use scirs2_fft::{fft2, ifft2};
447/// use ndarray::{array, Array2};
448///
449/// // Create a simple 2x2 array
450/// let input = array![[1.0, 2.0], [3.0, 4.0]];
451///
452/// // Compute the 2D FFT
453/// let spectrum = fft2(&input, None, None, None).unwrap();
454///
455/// // Compute the inverse 2D FFT
456/// let reconstructed = ifft2(&spectrum, None, None, None).unwrap();
457///
458/// // The reconstructed signal should match the original
459/// for i in 0..2 {
460///     for j in 0..2 {
461///         assert!((input[[i, j]] - reconstructed[[i, j]].re).abs() < 1e-10);
462///         assert!(reconstructed[[i, j]].im.abs() < 1e-10);
463///     }
464/// }
465/// ```
466#[allow(dead_code)]
467pub fn ifft2<T>(
468    input: &Array2<T>,
469    shape: Option<(usize, usize)>,
470    axes: Option<(i32, i32)>,
471    norm: Option<&str>,
472) -> FFTResult<Array2<Complex64>>
473where
474    T: NumCast + Copy + Debug + 'static,
475{
476    // Get input array shape
477    let inputshape = input.shape();
478
479    // Determine output shape
480    let outputshape = shape.unwrap_or((inputshape[0], inputshape[1]));
481
482    // Determine axes to perform FFT on
483    let axes = axes.unwrap_or((0, 1));
484
485    // Validate axes
486    if axes.0 < 0 || axes.0 > 1 || axes.1 < 0 || axes.1 > 1 || axes.0 == axes.1 {
487        return Err(FFTError::ValueError("Invalid axes for 2D IFFT".to_string()));
488    }
489
490    // Parse normalization mode (default is "backward" for inverse FFT)
491    let norm_mode = parse_norm_mode(norm, true);
492
493    // Convert input to complex and copy to output shape
494    let mut complex_input = Array2::<Complex64>::zeros((inputshape[0], inputshape[1]));
495    for i in 0..inputshape[0] {
496        for j in 0..inputshape[1] {
497            let val = input[[i, j]];
498
499            // Convert using the unified conversion function
500            complex_input[[i, j]] = convert_to_complex(val)?;
501        }
502    }
503
504    // Pad or truncate to match output shape if necessary
505    let mut padded_input = if inputshape != [outputshape.0, outputshape.1] {
506        let mut padded = Array2::<Complex64>::zeros((outputshape.0, outputshape.1));
507        let copy_rows = std::cmp::min(inputshape[0], outputshape.0);
508        let copy_cols = std::cmp::min(inputshape[1], outputshape.1);
509
510        for i in 0..copy_rows {
511            for j in 0..copy_cols {
512                padded[[i, j]] = complex_input[[i, j]];
513            }
514        }
515        padded
516    } else {
517        complex_input
518    };
519
520    // Create FFT planner
521    let mut planner = FftPlanner::new();
522
523    // Perform inverse FFT along each row
524    let row_ifft = planner.plan_fft_inverse(outputshape.1);
525    for mut row in padded_input.rows_mut() {
526        // Convert to rustfft compatible format
527        let mut buffer: Vec<RustComplex<f64>> =
528            row.iter().map(|&c| RustComplex::new(c.re, c.im)).collect();
529
530        // Perform inverse FFT
531        row_ifft.process(&mut buffer);
532
533        // Update row with IFFT result
534        for (i, val) in buffer.iter().enumerate() {
535            row[i] = Complex64::new(val.re, val.im);
536        }
537    }
538
539    // Perform inverse FFT along each column
540    let col_ifft = planner.plan_fft_inverse(outputshape.0);
541    for mut col in padded_input.columns_mut() {
542        // Convert to rustfft compatible format
543        let mut buffer: Vec<RustComplex<f64>> =
544            col.iter().map(|&c| RustComplex::new(c.re, c.im)).collect();
545
546        // Perform inverse FFT
547        col_ifft.process(&mut buffer);
548
549        // Update column with IFFT result
550        for (i, val) in buffer.iter().enumerate() {
551            col[i] = Complex64::new(val.re, val.im);
552        }
553    }
554
555    // Apply appropriate normalization
556    let total_elements = outputshape.0 * outputshape.1;
557    let scale = match norm_mode {
558        NormMode::Backward => 1.0 / (total_elements as f64),
559        NormMode::Ortho => 1.0 / (total_elements as f64).sqrt(),
560        NormMode::Forward => 1.0, // No additional normalization for forward mode in IFFT
561        NormMode::None => 1.0,    // No normalization
562    };
563
564    if scale != 1.0 {
565        padded_input.mapv_inplace(|x| x * scale);
566    }
567
568    Ok(padded_input)
569}
570
571/// Compute the N-dimensional Fast Fourier Transform
572///
573/// # Arguments
574///
575/// * `input` - Input N-dimensional array
576/// * `shape` - Shape of the output (optional)
577/// * `axes` - Axes along which to compute the FFT (optional)
578/// * `norm` - Normalization mode: "backward", "ortho", or "forward" (optional)
579/// * `overwrite_x` - Whether to overwrite the input array (optional)
580/// * `workers` - Number of worker threads to use (optional)
581///
582/// # Returns
583///
584/// An N-dimensional array of complex values representing the FFT result
585///
586/// # Examples
587///
588/// ```
589/// use scirs2_fft::fftn;
590/// use ndarray::{Array, IxDyn};
591///
592/// // Create a 3D array
593/// let mut data = Array::zeros(IxDyn(&[2, 2, 2]));
594/// data[[0, 0, 0]] = 1.0;
595/// data[[1, 1, 1]] = 1.0;
596///
597/// // Compute the N-dimensional FFT
598/// let result = fftn(&data, None, None, None, None, None).unwrap();
599///
600/// // Check dimensions
601/// assert_eq!(result.shape(), &[2, 2, 2]);
602/// ```
603#[allow(clippy::too_many_arguments)]
604#[allow(dead_code)]
605pub fn fftn<T>(
606    input: &ArrayD<T>,
607    shape: Option<Vec<usize>>,
608    axes: Option<Vec<usize>>,
609    norm: Option<&str>,
610    _overwrite_x: Option<bool>,
611    _workers: Option<usize>,
612) -> FFTResult<ArrayD<Complex64>>
613where
614    T: NumCast + Copy + Debug + 'static,
615{
616    let inputshape = input.shape().to_vec();
617    let input_ndim = inputshape.len();
618
619    // Determine output shape
620    let outputshape = shape.unwrap_or_else(|| inputshape.clone());
621
622    // Validate output shape
623    if outputshape.len() != input_ndim {
624        return Err(FFTError::ValueError(
625            "Output shape must have the same number of dimensions as input".to_string(),
626        ));
627    }
628
629    // Determine axes to perform FFT on
630    let axes = axes.unwrap_or_else(|| (0..input_ndim).collect());
631
632    // Validate axes
633    for &axis in &axes {
634        if axis >= input_ndim {
635            return Err(FFTError::ValueError(format!(
636                "Axis {axis} out of bounds for array of dimension {input_ndim}"
637            )));
638        }
639    }
640
641    // Parse normalization mode
642    let norm_mode = parse_norm_mode(norm, false);
643
644    // Convert input array to complex
645    let mut complex_input = ArrayD::<Complex64>::zeros(IxDyn(&inputshape));
646    for (idx, &val) in input.iter().enumerate() {
647        let mut idx_vec = Vec::with_capacity(input_ndim);
648        let mut remaining = idx;
649
650        for &dim in input.shape().iter().rev() {
651            idx_vec.push(remaining % dim);
652            remaining /= dim;
653        }
654
655        idx_vec.reverse();
656
657        complex_input[IxDyn(&idx_vec)] = convert_to_complex(val)?;
658    }
659
660    // Pad or truncate to match output shape if necessary
661    let mut result = if inputshape != outputshape {
662        let mut padded = ArrayD::<Complex64>::zeros(IxDyn(&outputshape));
663
664        // Copy all elements that fit within both arrays
665        for (idx, &val) in complex_input.iter().enumerate() {
666            let mut idx_vec = Vec::with_capacity(input_ndim);
667            let mut remaining = idx;
668
669            for &dim in input.shape().iter().rev() {
670                idx_vec.push(remaining % dim);
671                remaining /= dim;
672            }
673
674            idx_vec.reverse();
675
676            let mut in_bounds = true;
677            for (dim, &idx_val) in idx_vec.iter().enumerate() {
678                if idx_val >= outputshape[dim] {
679                    in_bounds = false;
680                    break;
681                }
682            }
683
684            if in_bounds {
685                padded[IxDyn(&idx_vec)] = val;
686            }
687        }
688
689        padded
690    } else {
691        complex_input
692    };
693
694    // Create FFT planner
695    let mut planner = FftPlanner::new();
696
697    // Perform FFT along each axis
698    for &axis in &axes {
699        let axis_len = outputshape[axis];
700        let fft = planner.plan_fft_forward(axis_len);
701
702        // For each slice along the current axis
703        let axis = Axis(axis);
704
705        for mut lane in result.lanes_mut(axis) {
706            // Convert to rustfft compatible format
707            let mut buffer: Vec<RustComplex<f64>> =
708                lane.iter().map(|&c| RustComplex::new(c.re, c.im)).collect();
709
710            // Perform FFT
711            fft.process(&mut buffer);
712
713            // Update lane with FFT result
714            for (i, val) in buffer.iter().enumerate() {
715                lane[i] = Complex64::new(val.re, val.im);
716            }
717        }
718    }
719
720    // Apply normalization if needed
721    if norm_mode != NormMode::None {
722        let total_elements: usize = outputshape.iter().product();
723        let scale = match norm_mode {
724            NormMode::Backward => 1.0 / (total_elements as f64),
725            NormMode::Ortho => 1.0 / (total_elements as f64).sqrt(),
726            NormMode::Forward => 1.0 / (total_elements as f64),
727            NormMode::None => 1.0, // Should not happen due to earlier check
728        };
729
730        result.mapv_inplace(|_x| _x * scale);
731    }
732
733    Ok(result)
734}
735
736/// Compute the inverse N-dimensional Fast Fourier Transform
737///
738/// # Arguments
739///
740/// * `input` - Input N-dimensional array
741/// * `shape` - Shape of the output (optional)
742/// * `axes` - Axes along which to compute the inverse FFT (optional)
743/// * `norm` - Normalization mode: "backward", "ortho", or "forward" (optional)
744/// * `overwrite_x` - Whether to overwrite the input array (optional)
745/// * `workers` - Number of worker threads to use (optional)
746///
747/// # Returns
748///
749/// An N-dimensional array of complex values representing the inverse FFT result
750///
751/// # Examples
752///
753/// ```
754/// use scirs2_fft::{fftn, ifftn};
755/// use ndarray::{Array, IxDyn};
756/// use num_complex::Complex64;
757///
758/// // Create a 3D array
759/// let mut data = Array::zeros(IxDyn(&[2, 2, 2]));
760/// data[[0, 0, 0]] = 1.0;
761/// data[[1, 1, 1]] = 1.0;
762///
763/// // Compute the N-dimensional FFT
764/// let spectrum = fftn(&data, None, None, None, None, None).unwrap();
765///
766/// // Compute the inverse N-dimensional FFT
767/// let result = ifftn(&spectrum, None, None, None, None, None).unwrap();
768///
769/// // Check if the original data is recovered
770/// for i in 0..2 {
771///     for j in 0..2 {
772///         for k in 0..2 {
773///             let expected = if (i == 0 && j == 0 && k == 0) || (i == 1 && j == 1 && k == 1) {
774///                 1.0
775///             } else {
776///                 0.0
777///             };
778///             assert!((result[[i, j, k]].re - expected).abs() < 1e-10);
779///             assert!(result[[i, j, k]].im.abs() < 1e-10);
780///         }
781///     }
782/// }
783/// ```
784#[allow(clippy::too_many_arguments)]
785#[allow(dead_code)]
786pub fn ifftn<T>(
787    input: &ArrayD<T>,
788    shape: Option<Vec<usize>>,
789    axes: Option<Vec<usize>>,
790    norm: Option<&str>,
791    _overwrite_x: Option<bool>,
792    _workers: Option<usize>,
793) -> FFTResult<ArrayD<Complex64>>
794where
795    T: NumCast + Copy + Debug + 'static,
796{
797    let inputshape = input.shape().to_vec();
798    let input_ndim = inputshape.len();
799
800    // Determine output shape
801    let outputshape = shape.unwrap_or_else(|| inputshape.clone());
802
803    // Validate output shape
804    if outputshape.len() != input_ndim {
805        return Err(FFTError::ValueError(
806            "Output shape must have the same number of dimensions as input".to_string(),
807        ));
808    }
809
810    // Determine axes to perform FFT on
811    let axes = axes.unwrap_or_else(|| (0..input_ndim).collect());
812
813    // Validate axes
814    for &axis in &axes {
815        if axis >= input_ndim {
816            return Err(FFTError::ValueError(format!(
817                "Axis {axis} out of bounds for array of dimension {input_ndim}"
818            )));
819        }
820    }
821
822    // Parse normalization mode (default is "backward" for inverse FFT)
823    let norm_mode = parse_norm_mode(norm, true);
824
825    // Create workspace array - convert input to complex first
826    let mut complex_input = ArrayD::<Complex64>::zeros(IxDyn(&inputshape));
827    for (idx, &val) in input.iter().enumerate() {
828        let mut idx_vec = Vec::with_capacity(input_ndim);
829        let mut remaining = idx;
830
831        for &dim in input.shape().iter().rev() {
832            idx_vec.push(remaining % dim);
833            remaining /= dim;
834        }
835
836        idx_vec.reverse();
837
838        // Try to convert to Complex64
839        complex_input[IxDyn(&idx_vec)] = convert_to_complex(val)?;
840    }
841
842    // Now handle padding/resizing if needed
843    let mut result = if inputshape != outputshape {
844        let mut padded = ArrayD::<Complex64>::zeros(IxDyn(&outputshape));
845
846        // Copy all elements that fit within both arrays
847        for (idx, &val) in complex_input.iter().enumerate() {
848            let mut idx_vec = Vec::with_capacity(input_ndim);
849            let mut remaining = idx;
850
851            for &dim in input.shape().iter().rev() {
852                idx_vec.push(remaining % dim);
853                remaining /= dim;
854            }
855
856            idx_vec.reverse();
857
858            let mut in_bounds = true;
859            for (dim, &idx_val) in idx_vec.iter().enumerate() {
860                if idx_val >= outputshape[dim] {
861                    in_bounds = false;
862                    break;
863                }
864            }
865
866            if in_bounds {
867                padded[IxDyn(&idx_vec)] = val;
868            }
869        }
870
871        padded
872    } else {
873        complex_input
874    };
875
876    // Create FFT planner
877    let mut planner = FftPlanner::new();
878
879    // Perform inverse FFT along each axis
880    for &axis in &axes {
881        let axis_len = outputshape[axis];
882        let ifft = planner.plan_fft_inverse(axis_len);
883
884        // For each slice along the current axis
885        let axis = Axis(axis);
886
887        for mut lane in result.lanes_mut(axis) {
888            // Convert to rustfft compatible format
889            let mut buffer: Vec<RustComplex<f64>> =
890                lane.iter().map(|&c| RustComplex::new(c.re, c.im)).collect();
891
892            // Perform inverse FFT
893            ifft.process(&mut buffer);
894
895            // Update lane with IFFT result
896            for (i, val) in buffer.iter().enumerate() {
897                lane[i] = Complex64::new(val.re, val.im);
898            }
899        }
900    }
901
902    // Apply appropriate normalization
903    if norm_mode != NormMode::None {
904        let total_elements: usize = axes.iter().map(|&a| outputshape[a]).product();
905        let scale = match norm_mode {
906            NormMode::Backward => 1.0 / (total_elements as f64),
907            NormMode::Ortho => 1.0 / (total_elements as f64).sqrt(),
908            NormMode::Forward => 1.0, // No additional normalization
909            NormMode::None => 1.0,    // No normalization
910        };
911
912        if scale != 1.0 {
913            result.mapv_inplace(|_x| _x * scale);
914        }
915    }
916
917    Ok(result)
918}