1use crate::error::FFTResult;
9use crate::fft::algorithms::{parse_norm_mode, NormMode};
10use ndarray::{Array2, Axis};
11use num_complex::Complex64;
12use num_traits::NumCast;
13use rustfft::{num_complex::Complex as RustComplex, FftPlanner};
14
15use scirs2_core::parallel_ops::*;
16
/// Computes the forward two-dimensional FFT of `input` by running a 1-D FFT
/// over every row and then over every column, optionally in parallel.
///
/// # Arguments
///
/// * `input` - Matrix to transform. Each element is converted to `Complex64`:
///   first through `try_as_complex`, and if that yields `None`, by casting the
///   value to `f64` and pairing it with a zero imaginary part.
/// * `shape` - Output shape `(rows, cols)`; defaults to the shape of `input`.
///   When it differs from the input shape, the input is cropped and/or
///   zero-padded, keeping the block anchored at index `[0, 0]`.
/// * `axes` - Must name axes 0 and 1 (in either order); defaults to `(0, 1)`.
///   The value is only validated: the transform always covers both axes and
///   `shape` is always read as `(rows, cols)`.
/// * `norm` - Normalization selector, interpreted by `parse_norm_mode`.
/// * `workers` - Values greater than 1 select the parallel iterator path, any
///   other value the sequential one. Defaults to `min(num_threads(), 8)`. The
///   number is only an on/off switch here; no thread pool of that size is built.
///
/// # Normalization
///
/// With `N = rows * cols` of the output: `NormMode::Forward` scales by `1 / N`,
/// `NormMode::Ortho` by `1 / sqrt(N)`, and `NormMode::Backward` /
/// `NormMode::None` leave the result unscaled.
///
/// # Errors
///
/// Returns `FFTError::ValueError` when `axes` is invalid or when an element of
/// `input` cannot be converted to `f64`.
#[cfg(feature = "parallel")]
#[allow(clippy::too_many_arguments)]
#[allow(dead_code)]
pub fn fft2_parallel<T>(
    input: &Array2<T>,
    shape: Option<(usize, usize)>,
    axes: Option<(i32, i32)>,
    norm: Option<&str>,
    workers: Option<usize>,
) -> FFTResult<Array2<Complex64>>
where
    T: NumCast + Copy + std::fmt::Debug + 'static,
{
    // Transforms a single row or column in place. The lane is copied into a
    // contiguous scratch buffer because `process` works on a slice, while a
    // column view of the matrix is strided.
    fn transform_lane(
        plan: &dyn rustfft::Fft<f64>,
        mut lane: ndarray::ArrayViewMut1<'_, Complex64>,
    ) {
        let mut scratch: Vec<RustComplex<f64>> = lane
            .iter()
            .map(|z| RustComplex::new(z.re, z.im))
            .collect();

        plan.process(&mut scratch);

        for (dst, src) in lane.iter_mut().zip(&scratch) {
            *dst = Complex64::new(src.re, src.im);
        }
    }

    // Applies `plan` to every lane produced by iterating `data` along `axis`
    // (axis 0 yields the rows, axis 1 the columns).
    fn transform_along(
        data: &mut Array2<Complex64>,
        axis: Axis,
        plan: &dyn rustfft::Fft<f64>,
        parallel: bool,
    ) {
        if parallel {
            data.axis_iter_mut(axis)
                .into_par_iter()
                .for_each(|lane| transform_lane(plan, lane));
        } else {
            for lane in data.axis_iter_mut(axis) {
                transform_lane(plan, lane);
            }
        }
    }

    let (in_rows, in_cols) = input.dim();
    let (out_rows, out_cols) = shape.unwrap_or((in_rows, in_cols));

    // The two axes must be 0 and 1, each used exactly once.
    let (first_axis, second_axis) = axes.unwrap_or((0, 1));
    let axes_valid = (0..=1).contains(&first_axis)
        && (0..=1).contains(&second_axis)
        && first_axis != second_axis;
    if !axes_valid {
        return Err(crate::FFTError::ValueError(
            "Invalid axes for 2D FFT".to_string(),
        ));
    }

    let norm_mode = parse_norm_mode(norm, false);
    let run_parallel = workers.unwrap_or_else(|| num_threads().min(8)) > 1;

    // Convert every input element, including ones a smaller `shape` would crop
    // away, so that a conversion failure is always reported.
    let mut converted = Array2::<Complex64>::zeros((in_rows, in_cols));
    for ((r, c), &value) in input.indexed_iter() {
        converted[[r, c]] = match crate::fft::utility::try_as_complex(value) {
            Some(z) => z,
            None => {
                let real = num_traits::cast::<T, f64>(value).ok_or_else(|| {
                    crate::FFTError::ValueError(format!("Could not convert {value:?} to f64"))
                })?;
                Complex64::new(real, 0.0)
            }
        };
    }

    // Crop / zero-pad to the requested output shape.
    let mut data = if (in_rows, in_cols) != (out_rows, out_cols) {
        Array2::from_shape_fn((out_rows, out_cols), |(r, c)| {
            if r < in_rows && c < in_cols {
                converted[[r, c]]
            } else {
                Complex64::new(0.0, 0.0)
            }
        })
    } else {
        converted
    };

    let mut planner = FftPlanner::new();

    // Pass 1: each row has `out_cols` samples.
    let row_fft = planner.plan_fft_forward(out_cols);
    transform_along(&mut data, Axis(0), &*row_fft, run_parallel);

    // Pass 2: each column has `out_rows` samples.
    let col_fft = planner.plan_fft_forward(out_rows);
    transform_along(&mut data, Axis(1), &*col_fft, run_parallel);

    // `Backward` puts the whole 1/N factor on the inverse transform (see
    // `ifft2_parallel`), so the forward result must stay unscaled in that mode;
    // scaling here as well would make a forward/inverse round trip return x / N.
    // NOTE(review): confirm the serial `fft2` treats `Backward` the same way.
    let total = (out_rows * out_cols) as f64;
    let scale = match norm_mode {
        NormMode::Forward => 1.0 / total,
        NormMode::Ortho => 1.0 / total.sqrt(),
        NormMode::Backward | NormMode::None => 1.0,
    };
    if scale != 1.0 {
        data.mapv_inplace(|z| z * scale);
    }

    Ok(data)
}
203
204#[cfg(not(feature = "parallel"))]
206#[allow(dead_code)]
207pub fn fft2_parallel<T>(
208 input: &Array2<T>,
209 shape: Option<(usize, usize)>,
210 _axes: Option<(i32, i32)>,
211 _norm: Option<&str>,
212 _workers: Option<usize>,
213) -> FFTResult<Array2<Complex64>>
214where
215 T: NumCast + Copy + std::fmt::Debug + 'static,
216{
217 crate::fft::algorithms::fft2(input, shape, None, None)
219}
220
/// Computes the inverse two-dimensional FFT of `input` by running an inverse
/// 1-D FFT over every row and then over every column, optionally in parallel.
///
/// # Arguments
///
/// * `input` - Matrix to transform. Each element is converted to `Complex64`:
///   first through `try_as_complex`, and if that yields `None`, by casting the
///   value to `f64` and pairing it with a zero imaginary part.
/// * `shape` - Output shape `(rows, cols)`; defaults to the shape of `input`.
///   When it differs from the input shape, the input is cropped and/or
///   zero-padded, keeping the block anchored at index `[0, 0]`.
/// * `axes` - Must name axes 0 and 1 (in either order); defaults to `(0, 1)`.
///   The value is only validated: the transform always covers both axes.
/// * `norm` - Normalization selector, interpreted by `parse_norm_mode`.
/// * `workers` - Values greater than 1 select the parallel iterator path, any
///   other value the sequential one. Defaults to `min(num_threads(), 8)`. The
///   number is only an on/off switch here; no thread pool of that size is built.
///
/// # Normalization
///
/// With `N = rows * cols` of the output: `NormMode::Backward` scales by
/// `1 / N`, `NormMode::Ortho` by `1 / sqrt(N)`, and `NormMode::Forward` /
/// `NormMode::None` leave the result unscaled.
///
/// # Errors
///
/// Returns `FFTError::ValueError` when `axes` is invalid or when an element of
/// `input` cannot be converted to `f64`.
#[cfg(feature = "parallel")]
#[allow(clippy::too_many_arguments)]
#[allow(dead_code)]
pub fn ifft2_parallel<T>(
    input: &Array2<T>,
    shape: Option<(usize, usize)>,
    axes: Option<(i32, i32)>,
    norm: Option<&str>,
    workers: Option<usize>,
) -> FFTResult<Array2<Complex64>>
where
    T: NumCast + Copy + std::fmt::Debug + 'static,
{
    // Transforms a single row or column in place. The lane is copied into a
    // contiguous scratch buffer because `process` works on a slice, while a
    // column view of the matrix is strided.
    fn transform_lane(
        plan: &dyn rustfft::Fft<f64>,
        mut lane: ndarray::ArrayViewMut1<'_, Complex64>,
    ) {
        let mut scratch: Vec<RustComplex<f64>> = lane
            .iter()
            .map(|z| RustComplex::new(z.re, z.im))
            .collect();

        plan.process(&mut scratch);

        for (dst, src) in lane.iter_mut().zip(&scratch) {
            *dst = Complex64::new(src.re, src.im);
        }
    }

    // Applies `plan` to every lane produced by iterating `data` along `axis`
    // (axis 0 yields the rows, axis 1 the columns).
    fn transform_along(
        data: &mut Array2<Complex64>,
        axis: Axis,
        plan: &dyn rustfft::Fft<f64>,
        parallel: bool,
    ) {
        if parallel {
            data.axis_iter_mut(axis)
                .into_par_iter()
                .for_each(|lane| transform_lane(plan, lane));
        } else {
            for lane in data.axis_iter_mut(axis) {
                transform_lane(plan, lane);
            }
        }
    }

    let (in_rows, in_cols) = input.dim();
    let (out_rows, out_cols) = shape.unwrap_or((in_rows, in_cols));

    // The two axes must be 0 and 1, each used exactly once.
    let (first_axis, second_axis) = axes.unwrap_or((0, 1));
    let axes_valid = (0..=1).contains(&first_axis)
        && (0..=1).contains(&second_axis)
        && first_axis != second_axis;
    if !axes_valid {
        return Err(crate::FFTError::ValueError(
            "Invalid axes for 2D IFFT".to_string(),
        ));
    }

    let norm_mode = parse_norm_mode(norm, true);
    let run_parallel = workers.unwrap_or_else(|| num_threads().min(8)) > 1;

    // Convert every input element, including ones a smaller `shape` would crop
    // away, so that a conversion failure is always reported.
    let mut converted = Array2::<Complex64>::zeros((in_rows, in_cols));
    for ((r, c), &value) in input.indexed_iter() {
        converted[[r, c]] = match crate::fft::utility::try_as_complex(value) {
            Some(z) => z,
            None => {
                let real = num_traits::cast::<T, f64>(value).ok_or_else(|| {
                    crate::FFTError::ValueError(format!("Could not convert {value:?} to f64"))
                })?;
                Complex64::new(real, 0.0)
            }
        };
    }

    // Crop / zero-pad to the requested output shape.
    let mut data = if (in_rows, in_cols) != (out_rows, out_cols) {
        Array2::from_shape_fn((out_rows, out_cols), |(r, c)| {
            if r < in_rows && c < in_cols {
                converted[[r, c]]
            } else {
                Complex64::new(0.0, 0.0)
            }
        })
    } else {
        converted
    };

    let mut planner = FftPlanner::new();

    // Pass 1: each row has `out_cols` samples.
    let row_ifft = planner.plan_fft_inverse(out_cols);
    transform_along(&mut data, Axis(0), &*row_ifft, run_parallel);

    // Pass 2: each column has `out_rows` samples.
    let col_ifft = planner.plan_fft_inverse(out_rows);
    transform_along(&mut data, Axis(1), &*col_ifft, run_parallel);

    // The planned inverse transforms apply no scaling of their own here, so
    // the factor selected by the normalization mode is applied explicitly.
    let total = (out_rows * out_cols) as f64;
    let scale = match norm_mode {
        NormMode::Backward => 1.0 / total,
        NormMode::Ortho => 1.0 / total.sqrt(),
        NormMode::Forward | NormMode::None => 1.0,
    };
    if scale != 1.0 {
        data.mapv_inplace(|z| z * scale);
    }

    Ok(data)
}
391
392#[cfg(not(feature = "parallel"))]
394#[allow(dead_code)]
395pub fn ifft2_parallel<T>(
396 input: &Array2<T>,
397 shape: Option<(usize, usize)>,
398 _axes: Option<(i32, i32)>,
399 _norm: Option<&str>,
400 _workers: Option<usize>,
401) -> FFTResult<Array2<Complex64>>
402where
403 T: NumCast + Copy + std::fmt::Debug + 'static,
404{
405 crate::fft::algorithms::ifft2(input, shape, None, None)
407}