scirs2_fft/fft/
planning.rs

1/*!
2 * Parallel FFT algorithms implementations
3 *
4 * This module provides implementations of parallel Fast Fourier Transform (FFT)
5 * algorithms for multi-threaded execution on multi-core CPUs.
6 */
7
8use crate::error::FFTResult;
9use crate::fft::algorithms::{parse_norm_mode, NormMode};
10use rustfft::{num_complex::Complex as RustComplex, FftPlanner};
11use scirs2_core::ndarray::{Array2, Axis};
12use scirs2_core::numeric::Complex64;
13use scirs2_core::numeric::NumCast;
14
15use scirs2_core::parallel_ops::*;
16
/// Compute a 2D FFT using parallel processing for rows and columns
///
/// The input is converted to `Complex64`, zero-padded or truncated to the
/// requested output shape, transformed along every row and then along every
/// column, and finally scaled according to the normalization mode.
///
/// # Arguments
///
/// * `input` - Input 2D array. Each element is converted with
///   `try_as_complex`; elements it does not recognise are cast to `f64` and
///   used as the real part.
/// * `shape` - Shape `(rows, cols)` of the output (optional). Defaults to the
///   input shape. A larger shape zero-pads, a smaller shape truncates.
/// * `axes` - Axes along which to compute the FFT (optional). Must be `(0, 1)`
///   or `(1, 0)`. The value is only validated: the transform is always applied
///   to the rows first and then to the columns.
/// * `norm` - Normalization mode (optional), resolved by `parse_norm_mode`.
/// * `workers` - Number of worker threads to use (optional). Defaults to
///   `num_threads()` capped at 8. In this function the value only selects the
///   code path: values greater than 1 use the parallel iterators, `0` or `1`
///   run sequentially.
///
/// # Normalization
///
/// With `N = rows * cols` of the output shape, the forward transform is scaled
/// by:
///
/// * `NormMode::None` / `NormMode::Backward` - `1` (no scaling; "backward"
///   puts the `1/N` factor on the inverse transform)
/// * `NormMode::Ortho` - `1 / sqrt(N)`
/// * `NormMode::Forward` - `1 / N`
///
/// This pairs with the scaling used by `ifft2_parallel`, so a forward/inverse
/// round trip with the same mode reproduces the input.
///
/// # Returns
///
/// A 2D array of complex values representing the parallel FFT result
///
/// # Errors
///
/// Returns `FFTError::ValueError` if `axes` is invalid or if an input element
/// can be converted neither to a complex number nor to `f64`.
///
/// # Examples
///
/// ```
/// use scirs2_fft::fft2_parallel;
/// use scirs2_core::ndarray::{array, Array2};
///
/// // Create a simple 2x2 array
/// let input = array![[1.0, 2.0], [3.0, 4.0]];
///
/// // Compute the 2D FFT in parallel
/// let result = fft2_parallel(&input, None, None, None, None).unwrap();
///
/// // The DC component should be the sum of all elements
/// assert!((result[[0, 0]].re - 10.0).abs() < 1e-10);
/// ```
#[cfg(feature = "parallel")]
#[allow(clippy::too_many_arguments)]
#[allow(dead_code)]
pub fn fft2_parallel<T>(
    input: &Array2<T>,
    shape: Option<(usize, usize)>,
    axes: Option<(i32, i32)>,
    norm: Option<&str>,
    workers: Option<usize>,
) -> FFTResult<Array2<Complex64>>
where
    T: NumCast + Copy + std::fmt::Debug + 'static,
{
    // Input and output dimensions
    let (in_rows, in_cols) = input.dim();
    let (out_rows, out_cols) = shape.unwrap_or((in_rows, in_cols));

    // Validate axes: the only legal pairs for a 2D transform are (0, 1) and (1, 0)
    if !matches!(axes.unwrap_or((0, 1)), (0, 1) | (1, 0)) {
        return Err(crate::FFTError::ValueError(
            "Invalid axes for 2D FFT".to_string(),
        ));
    }

    // Parse normalization mode
    let norm_mode = parse_norm_mode(norm, false);

    // Number of workers; only used to choose between the parallel and sequential paths
    let num_workers = workers.unwrap_or_else(|| num_threads().min(8));
    let run_parallel = num_workers > 1;

    // Convert input array to complex numbers (both iterators walk in logical row-major order)
    let mut complex_input = Array2::<Complex64>::zeros((in_rows, in_cols));
    for (dst, &val) in complex_input.iter_mut().zip(input.iter()) {
        *dst = match crate::fft::utility::try_as_complex(val) {
            Some(c) => c,
            None => {
                // Not a complex number: convert to f64 and use a zero imaginary part
                let real: f64 = NumCast::from(val).ok_or_else(|| {
                    crate::FFTError::ValueError(format!("Could not convert {val:?} to f64"))
                })?;
                Complex64::new(real, 0.0)
            }
        };
    }

    // Pad or truncate to match output shape if necessary
    let mut data = if (in_rows, in_cols) != (out_rows, out_cols) {
        let mut padded = Array2::<Complex64>::zeros((out_rows, out_cols));
        let copy_rows = in_rows.min(out_rows);
        let copy_cols = in_cols.min(out_cols);

        for i in 0..copy_rows {
            for j in 0..copy_cols {
                padded[[i, j]] = complex_input[[i, j]];
            }
        }
        padded
    } else {
        complex_input
    };

    // Create FFT planner
    let mut planner = FftPlanner::new();

    // Perform FFT along each row (each row has `out_cols` samples)
    let row_fft = planner.plan_fft_forward(out_cols);

    if run_parallel {
        data.axis_iter_mut(Axis(0))
            .into_par_iter()
            .for_each(|mut row| {
                // Copy into a rustfft compatible buffer
                let mut buffer: Vec<RustComplex<f64>> =
                    row.iter().map(|c| RustComplex::new(c.re, c.im)).collect();

                row_fft.process(&mut buffer);

                // Write the transformed samples back into the row
                for (dst, src) in row.iter_mut().zip(buffer.iter()) {
                    *dst = Complex64::new(src.re, src.im);
                }
            });
    } else {
        // Fall back to sequential processing if at most one worker
        for mut row in data.rows_mut() {
            let mut buffer: Vec<RustComplex<f64>> =
                row.iter().map(|c| RustComplex::new(c.re, c.im)).collect();

            row_fft.process(&mut buffer);

            for (dst, src) in row.iter_mut().zip(buffer.iter()) {
                *dst = Complex64::new(src.re, src.im);
            }
        }
    }

    // Perform FFT along each column (each column has `out_rows` samples)
    let col_fft = planner.plan_fft_forward(out_rows);

    if run_parallel {
        data.axis_iter_mut(Axis(1))
            .into_par_iter()
            .for_each(|mut col| {
                // Copy into a rustfft compatible buffer
                let mut buffer: Vec<RustComplex<f64>> =
                    col.iter().map(|c| RustComplex::new(c.re, c.im)).collect();

                col_fft.process(&mut buffer);

                // Write the transformed samples back into the column
                for (dst, src) in col.iter_mut().zip(buffer.iter()) {
                    *dst = Complex64::new(src.re, src.im);
                }
            });
    } else {
        // Fall back to sequential processing if at most one worker
        for mut col in data.columns_mut() {
            let mut buffer: Vec<RustComplex<f64>> =
                col.iter().map(|c| RustComplex::new(c.re, c.im)).collect();

            col_fft.process(&mut buffer);

            for (dst, src) in col.iter_mut().zip(buffer.iter()) {
                *dst = Complex64::new(src.re, src.im);
            }
        }
    }

    // Apply normalization. "backward" leaves the forward transform unscaled so that
    // it pairs with the 1/N applied by the inverse transform for the same mode.
    let total_elements = (out_rows * out_cols) as f64;
    let scale = match norm_mode {
        NormMode::Backward | NormMode::None => 1.0,
        NormMode::Ortho => 1.0 / total_elements.sqrt(),
        NormMode::Forward => 1.0 / total_elements,
    };

    if scale != 1.0 {
        data.mapv_inplace(|x| x * scale);
    }

    Ok(data)
}
203
204/// Non-parallel fallback implementation of fft2_parallel for when the parallel feature is disabled
205#[cfg(not(feature = "parallel"))]
206#[allow(dead_code)]
207pub fn fft2_parallel<T>(
208    input: &Array2<T>,
209    shape: Option<(usize, usize)>,
210    _axes: Option<(i32, i32)>,
211    _norm: Option<&str>,
212    _workers: Option<usize>,
213) -> FFTResult<Array2<Complex64>>
214where
215    T: NumCast + Copy + std::fmt::Debug + 'static,
216{
217    // When parallel feature is disabled, just use the standard fft2 implementation
218    crate::fft::algorithms::fft2(input, shape, None, None)
219}
220
/// Compute the inverse 2D FFT using parallel processing
///
/// The input is converted to `Complex64`, zero-padded or truncated to the
/// requested output shape, inverse-transformed along every row and then along
/// every column, and finally scaled according to the normalization mode.
///
/// # Arguments
///
/// * `input` - Input 2D array. Each element is converted with
///   `try_as_complex`; elements it does not recognise are cast to `f64` and
///   used as the real part.
/// * `shape` - Shape `(rows, cols)` of the output (optional). Defaults to the
///   input shape. A larger shape zero-pads, a smaller shape truncates.
/// * `axes` - Axes along which to compute the FFT (optional). Must be `(0, 1)`
///   or `(1, 0)`. The value is only validated: rows are always processed
///   first, then columns.
/// * `norm` - Normalization mode (optional), resolved by `parse_norm_mode`
///   with the inverse flag set.
/// * `workers` - Number of worker threads to use (optional). Defaults to
///   `num_threads()` capped at 8. Values greater than 1 select the parallel
///   iterators; `0` or `1` run sequentially.
///
/// # Returns
///
/// A 2D array of complex values representing the inverse FFT result. With
/// `N = rows * cols` of the output shape, the result is scaled by `1 / N` for
/// `NormMode::Backward`, by `1 / sqrt(N)` for `NormMode::Ortho`, and left
/// unscaled for `NormMode::Forward` and `NormMode::None`.
///
/// # Errors
///
/// Returns `FFTError::ValueError` if `axes` is invalid or if an input element
/// can be converted neither to a complex number nor to `f64`.
#[cfg(feature = "parallel")]
#[allow(clippy::too_many_arguments)]
#[allow(dead_code)]
pub fn ifft2_parallel<T>(
    input: &Array2<T>,
    shape: Option<(usize, usize)>,
    axes: Option<(i32, i32)>,
    norm: Option<&str>,
    workers: Option<usize>,
) -> FFTResult<Array2<Complex64>>
where
    T: NumCast + Copy + std::fmt::Debug + 'static,
{
    // Source and destination dimensions
    let (src_rows, src_cols) = input.dim();
    let (dst_rows, dst_cols) = shape.unwrap_or((src_rows, src_cols));

    // Only the two orderings of the axes {0, 1} are acceptable
    if !matches!(axes.unwrap_or((0, 1)), (0, 1) | (1, 0)) {
        return Err(crate::FFTError::ValueError(
            "Invalid axes for 2D IFFT".to_string(),
        ));
    }

    // Resolve the normalization mode (the `true` flag marks this as an inverse transform)
    let norm_mode = parse_norm_mode(norm, true);

    // More than one worker selects the parallel iterators below
    let use_parallel = workers.unwrap_or_else(|| num_threads().min(8)) > 1;

    // Convert every element to Complex64, walking both arrays in logical order
    let mut converted = Array2::<Complex64>::zeros((src_rows, src_cols));
    for (slot, &value) in converted.iter_mut().zip(input.iter()) {
        *slot = match crate::fft::utility::try_as_complex(value) {
            Some(c) => c,
            None => {
                // Real-valued element: cast to f64 and attach a zero imaginary part
                let real: f64 = NumCast::from(value).ok_or_else(|| {
                    crate::FFTError::ValueError(format!("Could not convert {value:?} to f64"))
                })?;
                Complex64::new(real, 0.0)
            }
        };
    }

    // Zero-pad or truncate when the requested shape differs from the input shape
    let mut spectrum = if (src_rows, src_cols) == (dst_rows, dst_cols) {
        converted
    } else {
        let mut resized = Array2::<Complex64>::zeros((dst_rows, dst_cols));
        let keep_rows = src_rows.min(dst_rows);
        let keep_cols = src_cols.min(dst_cols);

        for r in 0..keep_rows {
            for c in 0..keep_cols {
                resized[[r, c]] = converted[[r, c]];
            }
        }
        resized
    };

    let mut planner = FftPlanner::new();

    // Row pass: each row holds `dst_cols` samples
    let row_ifft = planner.plan_fft_inverse(dst_cols);

    if use_parallel {
        spectrum
            .axis_iter_mut(Axis(0))
            .into_par_iter()
            .for_each(|mut row| {
                // Stage the row in a rustfft buffer, transform, then copy back
                let mut scratch: Vec<RustComplex<f64>> =
                    row.iter().map(|c| RustComplex::new(c.re, c.im)).collect();

                row_ifft.process(&mut scratch);

                for (slot, out) in row.iter_mut().zip(scratch.iter()) {
                    *slot = Complex64::new(out.re, out.im);
                }
            });
    } else {
        // Sequential path when at most one worker is requested
        for mut row in spectrum.rows_mut() {
            let mut scratch: Vec<RustComplex<f64>> =
                row.iter().map(|c| RustComplex::new(c.re, c.im)).collect();

            row_ifft.process(&mut scratch);

            for (slot, out) in row.iter_mut().zip(scratch.iter()) {
                *slot = Complex64::new(out.re, out.im);
            }
        }
    }

    // Column pass: each column holds `dst_rows` samples
    let col_ifft = planner.plan_fft_inverse(dst_rows);

    if use_parallel {
        spectrum
            .axis_iter_mut(Axis(1))
            .into_par_iter()
            .for_each(|mut col| {
                // Stage the column in a rustfft buffer, transform, then copy back
                let mut scratch: Vec<RustComplex<f64>> =
                    col.iter().map(|c| RustComplex::new(c.re, c.im)).collect();

                col_ifft.process(&mut scratch);

                for (slot, out) in col.iter_mut().zip(scratch.iter()) {
                    *slot = Complex64::new(out.re, out.im);
                }
            });
    } else {
        // Sequential path when at most one worker is requested
        for mut col in spectrum.columns_mut() {
            let mut scratch: Vec<RustComplex<f64>> =
                col.iter().map(|c| RustComplex::new(c.re, c.im)).collect();

            col_ifft.process(&mut scratch);

            for (slot, out) in col.iter_mut().zip(scratch.iter()) {
                *slot = Complex64::new(out.re, out.im);
            }
        }
    }

    // Scale the result; forward and unnormalized modes leave the inverse untouched
    let element_count = (dst_rows * dst_cols) as f64;
    let scale = match norm_mode {
        NormMode::Backward => 1.0 / element_count,
        NormMode::Ortho => 1.0 / element_count.sqrt(),
        NormMode::Forward | NormMode::None => 1.0,
    };

    if scale != 1.0 {
        spectrum.mapv_inplace(|x| x * scale);
    }

    Ok(spectrum)
}
391
392/// Non-parallel fallback implementation of ifft2_parallel for when the parallel feature is disabled
393#[cfg(not(feature = "parallel"))]
394#[allow(dead_code)]
395pub fn ifft2_parallel<T>(
396    input: &Array2<T>,
397    shape: Option<(usize, usize)>,
398    _axes: Option<(i32, i32)>,
399    _norm: Option<&str>,
400    _workers: Option<usize>,
401) -> FFTResult<Array2<Complex64>>
402where
403    T: NumCast + Copy + std::fmt::Debug + 'static,
404{
405    // When parallel feature is disabled, just use the standard ifft2 implementation
406    crate::fft::algorithms::ifft2(input, shape, None, None)
407}