Skip to main content

scirs2/
fft.rs

1//! Python bindings for scirs2-fft
2//!
3//! This module provides Python bindings for Fast Fourier Transform operations.
4//!
5//! FFT functions return NumPy complex128 arrays for optimal performance and
6//! compatibility with NumPy's FFT functions.
7
8use pyo3::exceptions::PyRuntimeError;
9use pyo3::prelude::*;
10
11// NumPy types for Python array interface (scirs2-numpy with native ndarray 0.17)
12// Complex64 from scirs2_numpy maps to NumPy's complex128 (cdouble)
13use scirs2_numpy::{Complex64 as NumpyComplex64, IntoPyArray, PyArray1, PyArrayMethods};
14
15// ndarray types from scirs2-core
16use scirs2_core::{numeric::Complex64, Array1};
17
18// Direct imports from scirs2-fft (native ndarray 0.17 support)
19use scirs2_fft::{dct, fftfreq, fftshift, idct, ifftshift, next_fast_len, rfftfreq, DCTType};
20
21// Fallback imports for non-OxiFFT builds
22#[cfg(not(feature = "oxifft"))]
23use scirs2_fft::{fft, ifft, irfft};
24
25// ========================================
26// CORE FFT FUNCTIONS
27// ========================================
28
29/// 1D FFT (OxiFFT-optimized for f64!)
30/// Returns NumPy complex128 array (compatible with np.fft.fft output)
31#[pyfunction]
32fn fft_py(py: Python, data: &Bound<'_, PyArray1<f64>>) -> PyResult<Py<PyArray1<NumpyComplex64>>> {
33    let binding = data.readonly();
34    let arr = binding.as_array();
35
36    // Use OxiFFT for f64 (Pure Rust high-performance!)
37    #[cfg(feature = "oxifft")]
38    {
39        // Convert real input to complex for fft_oxifft
40        let complex_input: scirs2_core::ndarray::Array1<Complex64> =
41            arr.iter().map(|&r| Complex64::new(r, 0.0)).collect();
42
43        let result = scirs2_fft::oxifft_backend::fft_oxifft(&complex_input.view())
44            .map_err(|e| PyRuntimeError::new_err(format!("FFT (OxiFFT) failed: {}", e)))?;
45
46        // Convert to NumPy complex128 array
47        let complex_result: Vec<NumpyComplex64> = result
48            .iter()
49            .map(|c| NumpyComplex64::new(c.re, c.im))
50            .collect();
51
52        Ok(Array1::from_vec(complex_result).into_pyarray(py).unbind())
53    }
54
55    // Fallback to pure Rust
56    #[cfg(not(feature = "oxifft"))]
57    {
58        let vec_data: Vec<f64> = arr.to_vec();
59
60        let result = fft(&vec_data, None)
61            .map_err(|e| PyRuntimeError::new_err(format!("FFT failed: {}", e)))?;
62
63        // Convert to NumPy complex128 array
64        let complex_result: Vec<NumpyComplex64> = result
65            .iter()
66            .map(|c| NumpyComplex64::new(c.re, c.im))
67            .collect();
68
69        return Ok(Array1::from_vec(complex_result).into_pyarray(py).unbind());
70    }
71}
72
73/// 1D inverse FFT (OxiFFT-optimized!)
74/// Accepts and returns NumPy complex128 arrays (compatible with np.fft.ifft)
75#[pyfunction]
76fn ifft_py(
77    py: Python,
78    data: &Bound<'_, PyArray1<NumpyComplex64>>,
79) -> PyResult<Py<PyArray1<NumpyComplex64>>> {
80    let binding = data.readonly();
81    let arr = binding.as_array();
82
83    // Use OxiFFT for f64 (Pure Rust high-performance!)
84    #[cfg(feature = "oxifft")]
85    {
86        // Convert NumPy complex to scirs2 Complex64
87        let complex_input: scirs2_core::ndarray::Array1<Complex64> =
88            arr.iter().map(|c| Complex64::new(c.re, c.im)).collect();
89
90        let result = scirs2_fft::oxifft_backend::ifft_oxifft(&complex_input.view())
91            .map_err(|e| PyRuntimeError::new_err(format!("IFFT (OxiFFT) failed: {}", e)))?;
92
93        // Convert back to NumPy complex128
94        let complex_result: Vec<NumpyComplex64> = result
95            .iter()
96            .map(|c| NumpyComplex64::new(c.re, c.im))
97            .collect();
98
99        Ok(Array1::from_vec(complex_result).into_pyarray(py).unbind())
100    }
101
102    // Fallback to pure Rust
103    #[cfg(not(feature = "oxifft"))]
104    {
105        // Convert NumPy complex to Vec<Complex64>
106        let complex_input: Vec<Complex64> =
107            arr.iter().map(|c| Complex64::new(c.re, c.im)).collect();
108
109        let result = ifft(&complex_input, None)
110            .map_err(|e| PyRuntimeError::new_err(format!("IFFT failed: {}", e)))?;
111
112        // Convert back to NumPy complex128
113        let complex_result: Vec<NumpyComplex64> = result
114            .iter()
115            .map(|c| NumpyComplex64::new(c.re, c.im))
116            .collect();
117
118        return Ok(Array1::from_vec(complex_result).into_pyarray(py).unbind());
119    }
120}
121
122/// Real FFT - FFT of real-valued input (OxiFFT-optimized!)
123/// Returns NumPy complex128 array (compatible with np.fft.rfft output)
124#[pyfunction]
125fn rfft_py(py: Python, data: &Bound<'_, PyArray1<f64>>) -> PyResult<Py<PyArray1<NumpyComplex64>>> {
126    let binding = data.readonly();
127    let arr = binding.as_array();
128
129    // Use OxiFFT for f64 (OxiFFT + plan caching = high performance!)
130    #[cfg(feature = "oxifft")]
131    {
132        let result = scirs2_fft::oxifft_backend::rfft_oxifft(&arr)
133            .map_err(|e| PyRuntimeError::new_err(format!("RFFT (OxiFFT) failed: {}", e)))?;
134
135        // Convert to NumPy complex128 array
136        let complex_result: Vec<NumpyComplex64> = result
137            .iter()
138            .map(|c| NumpyComplex64::new(c.re, c.im))
139            .collect();
140
141        Ok(Array1::from_vec(complex_result).into_pyarray(py).unbind())
142    }
143
144    // Fallback to pure Rust
145    #[cfg(not(feature = "oxifft"))]
146    {
147        let vec_data: Vec<f64> = arr.to_vec();
148
149        let result = rfft(&vec_data, None)
150            .map_err(|e| PyRuntimeError::new_err(format!("RFFT failed: {}", e)))?;
151
152        // Convert to NumPy complex128 array
153        let complex_result: Vec<NumpyComplex64> = result
154            .iter()
155            .map(|c| NumpyComplex64::new(c.re, c.im))
156            .collect();
157
158        return Ok(Array1::from_vec(complex_result).into_pyarray(py).unbind());
159    }
160}
161
162/// Inverse real FFT (OxiFFT-optimized!)
163/// Accepts NumPy complex128 array, returns real f64 array (compatible with np.fft.irfft)
164#[pyfunction]
165#[pyo3(signature = (data, n=None))]
166fn irfft_py(
167    py: Python,
168    data: &Bound<'_, PyArray1<NumpyComplex64>>,
169    n: Option<usize>,
170) -> PyResult<Py<PyArray1<f64>>> {
171    let binding = data.readonly();
172    let arr = binding.as_array();
173
174    // Use OxiFFT for f64 (Pure Rust high-performance!)
175    #[cfg(feature = "oxifft")]
176    {
177        // Convert NumPy complex to scirs2 Complex64
178        let complex_input: scirs2_core::ndarray::Array1<Complex64> =
179            arr.iter().map(|c| Complex64::new(c.re, c.im)).collect();
180
181        // Infer output size: if n is None, assume n = 2*(input_len - 1)
182        let output_len = n.unwrap_or_else(|| 2 * (complex_input.len() - 1));
183
184        let result = scirs2_fft::oxifft_backend::irfft_oxifft(&complex_input.view(), output_len)
185            .map_err(|e| PyRuntimeError::new_err(format!("IRFFT (OxiFFT) failed: {}", e)))?;
186
187        Ok(result.into_pyarray(py).unbind())
188    }
189
190    // Fallback to pure Rust
191    #[cfg(not(feature = "oxifft"))]
192    {
193        // Convert NumPy complex to Vec<Complex64>
194        let complex_input: Vec<Complex64> =
195            arr.iter().map(|c| Complex64::new(c.re, c.im)).collect();
196
197        let result = irfft(&complex_input, n)
198            .map_err(|e| PyRuntimeError::new_err(format!("IRFFT failed: {}", e)))?;
199
200        return Ok(Array1::from_vec(result).into_pyarray(py).unbind());
201    }
202}
203
204// ========================================
205// DCT FUNCTIONS
206// ========================================
207
208/// Discrete Cosine Transform (OxiFFT-optimized for Type 2!)
209/// dct_type: 1, 2, 3, or 4
210#[pyfunction]
211#[pyo3(signature = (data, dct_type=2))]
212fn dct_py(
213    py: Python,
214    data: &Bound<'_, PyArray1<f64>>,
215    dct_type: usize,
216) -> PyResult<Py<PyArray1<f64>>> {
217    let binding = data.readonly();
218    let arr = binding.as_array();
219
220    // Use OxiFFT for DCT Type 2 (most common)
221    #[cfg(feature = "oxifft")]
222    if dct_type == 2 {
223        let result = scirs2_fft::oxifft_backend::dct2_oxifft(&arr)
224            .map_err(|e| PyRuntimeError::new_err(format!("DCT-II (OxiFFT) failed: {}", e)))?;
225        return Ok(result.into_pyarray(py).unbind());
226    }
227
228    // Fallback to pure Rust for other types
229    let vec_data: Vec<f64> = arr.to_vec();
230    let dct_type_enum = match dct_type {
231        1 => DCTType::Type1,
232        2 => DCTType::Type2,
233        3 => DCTType::Type3,
234        4 => DCTType::Type4,
235        _ => {
236            return Err(PyRuntimeError::new_err(format!(
237                "Invalid DCT type: {}",
238                dct_type
239            )))
240        }
241    };
242
243    let result = dct(&vec_data, Some(dct_type_enum), None)
244        .map_err(|e| PyRuntimeError::new_err(format!("DCT failed: {}", e)))?;
245
246    Ok(Array1::from_vec(result).into_pyarray(py).unbind())
247}
248
249/// Inverse Discrete Cosine Transform (OxiFFT-optimized for Type 2!)
250/// dct_type: 1, 2, 3, or 4
251#[pyfunction]
252#[pyo3(signature = (data, dct_type=2))]
253fn idct_py(
254    py: Python,
255    data: &Bound<'_, PyArray1<f64>>,
256    dct_type: usize,
257) -> PyResult<Py<PyArray1<f64>>> {
258    let binding = data.readonly();
259    let arr = binding.as_array();
260
261    // Use OxiFFT for IDCT Type 2 (most common)
262    #[cfg(feature = "oxifft")]
263    if dct_type == 2 {
264        let result = scirs2_fft::oxifft_backend::idct2_oxifft(&arr)
265            .map_err(|e| PyRuntimeError::new_err(format!("IDCT-II (OxiFFT) failed: {}", e)))?;
266        return Ok(result.into_pyarray(py).unbind());
267    }
268
269    // Fallback to pure Rust for other types
270    let vec_data: Vec<f64> = arr.to_vec();
271    let dct_type_enum = match dct_type {
272        1 => DCTType::Type1,
273        2 => DCTType::Type2,
274        3 => DCTType::Type3,
275        4 => DCTType::Type4,
276        _ => {
277            return Err(PyRuntimeError::new_err(format!(
278                "Invalid DCT type: {}",
279                dct_type
280            )))
281        }
282    };
283
284    let result = idct(&vec_data, Some(dct_type_enum), None)
285        .map_err(|e| PyRuntimeError::new_err(format!("IDCT failed: {}", e)))?;
286
287    Ok(Array1::from_vec(result).into_pyarray(py).unbind())
288}
289
290// ========================================
291// 2D FFT FUNCTIONS (OxiFFT-optimized!)
292// ========================================
293
294/// 2D FFT - Returns NumPy complex128 2D array (compatible with np.fft.fft2)
295/// Optimized: Uses rfft2 + Hermitian symmetry to reconstruct full spectrum
296/// This avoids the slow real→complex conversion for the input
297#[pyfunction]
298fn fft2_py(
299    py: Python,
300    data: &Bound<'_, scirs2_numpy::PyArray2<f64>>,
301) -> PyResult<Py<scirs2_numpy::PyArray2<NumpyComplex64>>> {
302    let binding = data.readonly();
303    let arr = binding.as_array();
304
305    #[cfg(feature = "oxifft")]
306    {
307        let (rows, cols) = arr.dim();
308
309        // Use rfft2 which is much faster for real input (no complex conversion needed)
310        let half_result = scirs2_fft::oxifft_backend::rfft2_oxifft(&arr)
311            .map_err(|e| PyRuntimeError::new_err(format!("FFT2 (OxiFFT) failed: {}", e)))?;
312
313        // Reconstruct full spectrum using Hermitian symmetry
314        // rfft2 gives us columns 0 to cols/2, we need to fill cols/2+1 to cols-1
315        let half_cols = cols / 2 + 1;
316        let mut full_result: Vec<NumpyComplex64> = Vec::with_capacity(rows * cols);
317
318        for row in 0..rows {
319            // Copy the first half (cols 0 to half_cols-1)
320            for col in 0..half_cols {
321                let c = half_result[[row, col]];
322                full_result.push(NumpyComplex64::new(c.re, c.im));
323            }
324
325            // Reconstruct the second half using Hermitian symmetry
326            // For real input: X[k1, k2] = conj(X[N1-k1, N2-k2])
327            for col in half_cols..cols {
328                let conj_row = if row == 0 { 0 } else { rows - row };
329                let conj_col = cols - col;
330                let c = half_result[[conj_row, conj_col]];
331                full_result.push(NumpyComplex64::new(c.re, -c.im)); // conjugate
332            }
333        }
334
335        let result_array = scirs2_core::ndarray::Array2::from_shape_vec((rows, cols), full_result)
336            .map_err(|e| PyRuntimeError::new_err(format!("Shape error: {}", e)))?;
337
338        Ok(result_array.into_pyarray(py).unbind())
339    }
340
341    #[cfg(not(feature = "oxifft"))]
342    {
343        let _ = arr;
344        Err(PyRuntimeError::new_err("FFT2 requires oxifft feature"))
345    }
346}
347
348/// 2D real FFT - Returns NumPy complex128 2D array (compatible with np.fft.rfft2)
349#[pyfunction]
350fn rfft2_py(
351    py: Python,
352    data: &Bound<'_, scirs2_numpy::PyArray2<f64>>,
353) -> PyResult<Py<scirs2_numpy::PyArray2<NumpyComplex64>>> {
354    let binding = data.readonly();
355    let arr = binding.as_array();
356
357    #[cfg(feature = "oxifft")]
358    {
359        let result = scirs2_fft::oxifft_backend::rfft2_oxifft(&arr)
360            .map_err(|e| PyRuntimeError::new_err(format!("RFFT2 (OxiFFT) failed: {}", e)))?;
361
362        // Convert to NumPy complex128 2D array
363        let (rows, cols) = result.dim();
364        let complex_result: Vec<NumpyComplex64> = result
365            .iter()
366            .map(|c| NumpyComplex64::new(c.re, c.im))
367            .collect();
368
369        let result_array =
370            scirs2_core::ndarray::Array2::from_shape_vec((rows, cols), complex_result)
371                .map_err(|e| PyRuntimeError::new_err(format!("Shape error: {}", e)))?;
372
373        Ok(result_array.into_pyarray(py).unbind())
374    }
375
376    #[cfg(not(feature = "oxifft"))]
377    {
378        let _ = arr;
379        Err(PyRuntimeError::new_err("RFFT2 requires oxifft feature"))
380    }
381}
382
383/// 2D inverse FFT - Accepts and returns NumPy complex128 2D arrays (compatible with np.fft.ifft2)
384/// Optimized: Direct allocation without intermediate conversions
385#[pyfunction]
386fn ifft2_py(
387    py: Python,
388    data: &Bound<'_, scirs2_numpy::PyArray2<NumpyComplex64>>,
389) -> PyResult<Py<scirs2_numpy::PyArray2<NumpyComplex64>>> {
390    let binding = data.readonly();
391    let arr = binding.as_array();
392
393    #[cfg(feature = "oxifft")]
394    {
395        let (rows, cols) = arr.dim();
396        let n = rows * cols;
397
398        // Optimized: Direct allocation with capacity (no intermediate Array1)
399        let mut complex_vec: Vec<Complex64> = Vec::with_capacity(n);
400        for c in arr.iter() {
401            complex_vec.push(Complex64::new(c.re, c.im));
402        }
403        let complex_input = scirs2_core::ndarray::Array2::from_shape_vec((rows, cols), complex_vec)
404            .map_err(|e| PyRuntimeError::new_err(format!("Shape error: {}", e)))?;
405
406        let result = scirs2_fft::oxifft_backend::ifft2_oxifft(&complex_input.view())
407            .map_err(|e| PyRuntimeError::new_err(format!("IFFT2 (OxiFFT) failed: {}", e)))?;
408
409        // Direct allocation for output
410        let mut result_vec: Vec<NumpyComplex64> = Vec::with_capacity(n);
411        for c in result.iter() {
412            result_vec.push(NumpyComplex64::new(c.re, c.im));
413        }
414
415        let result_array = scirs2_core::ndarray::Array2::from_shape_vec((rows, cols), result_vec)
416            .map_err(|e| PyRuntimeError::new_err(format!("Shape error: {}", e)))?;
417
418        Ok(result_array.into_pyarray(py).unbind())
419    }
420
421    #[cfg(not(feature = "oxifft"))]
422    {
423        let _ = arr;
424        Err(PyRuntimeError::new_err("IFFT2 requires oxifft feature"))
425    }
426}
427
428/// 2D inverse real FFT - Accepts NumPy complex128 2D array, returns real 2D array
429#[pyfunction]
430#[pyo3(signature = (data, shape))]
431fn irfft2_py(
432    py: Python,
433    data: &Bound<'_, scirs2_numpy::PyArray2<NumpyComplex64>>,
434    shape: (usize, usize),
435) -> PyResult<Py<scirs2_numpy::PyArray2<f64>>> {
436    let binding = data.readonly();
437    let arr = binding.as_array();
438
439    #[cfg(feature = "oxifft")]
440    {
441        let (in_rows, in_cols) = arr.dim();
442
443        // Convert NumPy complex to scirs2 Complex64
444        let complex_input: scirs2_core::ndarray::Array2<Complex64> = arr
445            .iter()
446            .map(|c| Complex64::new(c.re, c.im))
447            .collect::<Vec<_>>()
448            .into_iter()
449            .collect::<scirs2_core::ndarray::Array1<_>>()
450            .into_shape_with_order((in_rows, in_cols))
451            .map_err(|e| PyRuntimeError::new_err(format!("Shape error: {}", e)))?;
452
453        let result = scirs2_fft::oxifft_backend::irfft2_oxifft(&complex_input.view(), shape)
454            .map_err(|e| PyRuntimeError::new_err(format!("IRFFT2 (OxiFFT) failed: {}", e)))?;
455
456        Ok(result.into_pyarray(py).unbind())
457    }
458
459    #[cfg(not(feature = "oxifft"))]
460    {
461        let _ = (arr, shape);
462        Err(PyRuntimeError::new_err("IRFFT2 requires oxifft feature"))
463    }
464}
465
466// ========================================
467// HELPER FUNCTIONS
468// ========================================
469
470/// FFT sample frequencies
471#[pyfunction]
472#[pyo3(signature = (n, d=1.0))]
473fn fftfreq_py(py: Python, n: usize, d: f64) -> PyResult<Py<PyArray1<f64>>> {
474    let result =
475        fftfreq(n, d).map_err(|e| PyRuntimeError::new_err(format!("FFT freq failed: {}", e)))?;
476    Ok(Array1::from_vec(result).into_pyarray(py).unbind())
477}
478
479/// Real FFT sample frequencies
480#[pyfunction]
481#[pyo3(signature = (n, d=1.0))]
482fn rfftfreq_py(py: Python, n: usize, d: f64) -> PyResult<Py<PyArray1<f64>>> {
483    let result =
484        rfftfreq(n, d).map_err(|e| PyRuntimeError::new_err(format!("RFFT freq failed: {}", e)))?;
485    Ok(Array1::from_vec(result).into_pyarray(py).unbind())
486}
487
488/// FFT shift - shift zero-frequency component to center
489#[pyfunction]
490fn fftshift_py(py: Python, data: &Bound<'_, PyArray1<f64>>) -> PyResult<Py<PyArray1<f64>>> {
491    let binding = data.readonly();
492    let arr = binding.as_array().to_owned();
493
494    let result =
495        fftshift(&arr).map_err(|e| PyRuntimeError::new_err(format!("FFT shift failed: {}", e)))?;
496    Ok(result.into_pyarray(py).unbind())
497}
498
499/// Inverse FFT shift
500#[pyfunction]
501fn ifftshift_py(py: Python, data: &Bound<'_, PyArray1<f64>>) -> PyResult<Py<PyArray1<f64>>> {
502    let binding = data.readonly();
503    let arr = binding.as_array().to_owned();
504
505    let result = ifftshift(&arr)
506        .map_err(|e| PyRuntimeError::new_err(format!("Inverse FFT shift failed: {}", e)))?;
507    Ok(result.into_pyarray(py).unbind())
508}
509
510/// Find next fast length for FFT
511/// Returns the smallest size >= n that can be efficiently transformed
512#[pyfunction]
513#[pyo3(signature = (n, real=false))]
514fn next_fast_len_py(n: usize, real: bool) -> usize {
515    next_fast_len(n, real)
516}
517
518/// Python module registration
519pub fn register_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
520    // Core FFT functions
521    m.add_function(wrap_pyfunction!(fft_py, m)?)?;
522    m.add_function(wrap_pyfunction!(ifft_py, m)?)?;
523    m.add_function(wrap_pyfunction!(rfft_py, m)?)?;
524    m.add_function(wrap_pyfunction!(irfft_py, m)?)?;
525
526    // DCT functions
527    m.add_function(wrap_pyfunction!(dct_py, m)?)?;
528    m.add_function(wrap_pyfunction!(idct_py, m)?)?;
529
530    // 2D FFT functions
531    m.add_function(wrap_pyfunction!(fft2_py, m)?)?;
532    m.add_function(wrap_pyfunction!(ifft2_py, m)?)?;
533    m.add_function(wrap_pyfunction!(rfft2_py, m)?)?;
534    m.add_function(wrap_pyfunction!(irfft2_py, m)?)?;
535
536    // Helper functions
537    m.add_function(wrap_pyfunction!(fftfreq_py, m)?)?;
538    m.add_function(wrap_pyfunction!(rfftfreq_py, m)?)?;
539    m.add_function(wrap_pyfunction!(fftshift_py, m)?)?;
540    m.add_function(wrap_pyfunction!(ifftshift_py, m)?)?;
541    m.add_function(wrap_pyfunction!(next_fast_len_py, m)?)?;
542
543    Ok(())
544}