1use pyo3::exceptions::PyRuntimeError;
9use pyo3::prelude::*;
10
11use scirs2_numpy::{Complex64 as NumpyComplex64, IntoPyArray, PyArray1, PyArrayMethods};
14
15use scirs2_core::{numeric::Complex64, Array1};
17
18use scirs2_fft::{dct, fftfreq, fftshift, idct, ifftshift, next_fast_len, rfftfreq, DCTType};
20
21#[cfg(not(feature = "oxifft"))]
23use scirs2_fft::{fft, ifft, irfft};
24
25#[pyfunction]
32fn fft_py(py: Python, data: &Bound<'_, PyArray1<f64>>) -> PyResult<Py<PyArray1<NumpyComplex64>>> {
33 let binding = data.readonly();
34 let arr = binding.as_array();
35
36 #[cfg(feature = "oxifft")]
38 {
39 let complex_input: scirs2_core::ndarray::Array1<Complex64> =
41 arr.iter().map(|&r| Complex64::new(r, 0.0)).collect();
42
43 let result = scirs2_fft::oxifft_backend::fft_oxifft(&complex_input.view())
44 .map_err(|e| PyRuntimeError::new_err(format!("FFT (OxiFFT) failed: {}", e)))?;
45
46 let complex_result: Vec<NumpyComplex64> = result
48 .iter()
49 .map(|c| NumpyComplex64::new(c.re, c.im))
50 .collect();
51
52 Ok(Array1::from_vec(complex_result).into_pyarray(py).unbind())
53 }
54
55 #[cfg(not(feature = "oxifft"))]
57 {
58 let vec_data: Vec<f64> = arr.to_vec();
59
60 let result = fft(&vec_data, None)
61 .map_err(|e| PyRuntimeError::new_err(format!("FFT failed: {}", e)))?;
62
63 let complex_result: Vec<NumpyComplex64> = result
65 .iter()
66 .map(|c| NumpyComplex64::new(c.re, c.im))
67 .collect();
68
69 return Ok(Array1::from_vec(complex_result).into_pyarray(py).unbind());
70 }
71}
72
73#[pyfunction]
76fn ifft_py(
77 py: Python,
78 data: &Bound<'_, PyArray1<NumpyComplex64>>,
79) -> PyResult<Py<PyArray1<NumpyComplex64>>> {
80 let binding = data.readonly();
81 let arr = binding.as_array();
82
83 #[cfg(feature = "oxifft")]
85 {
86 let complex_input: scirs2_core::ndarray::Array1<Complex64> =
88 arr.iter().map(|c| Complex64::new(c.re, c.im)).collect();
89
90 let result = scirs2_fft::oxifft_backend::ifft_oxifft(&complex_input.view())
91 .map_err(|e| PyRuntimeError::new_err(format!("IFFT (OxiFFT) failed: {}", e)))?;
92
93 let complex_result: Vec<NumpyComplex64> = result
95 .iter()
96 .map(|c| NumpyComplex64::new(c.re, c.im))
97 .collect();
98
99 Ok(Array1::from_vec(complex_result).into_pyarray(py).unbind())
100 }
101
102 #[cfg(not(feature = "oxifft"))]
104 {
105 let complex_input: Vec<Complex64> =
107 arr.iter().map(|c| Complex64::new(c.re, c.im)).collect();
108
109 let result = ifft(&complex_input, None)
110 .map_err(|e| PyRuntimeError::new_err(format!("IFFT failed: {}", e)))?;
111
112 let complex_result: Vec<NumpyComplex64> = result
114 .iter()
115 .map(|c| NumpyComplex64::new(c.re, c.im))
116 .collect();
117
118 return Ok(Array1::from_vec(complex_result).into_pyarray(py).unbind());
119 }
120}
121
122#[pyfunction]
125fn rfft_py(py: Python, data: &Bound<'_, PyArray1<f64>>) -> PyResult<Py<PyArray1<NumpyComplex64>>> {
126 let binding = data.readonly();
127 let arr = binding.as_array();
128
129 #[cfg(feature = "oxifft")]
131 {
132 let result = scirs2_fft::oxifft_backend::rfft_oxifft(&arr)
133 .map_err(|e| PyRuntimeError::new_err(format!("RFFT (OxiFFT) failed: {}", e)))?;
134
135 let complex_result: Vec<NumpyComplex64> = result
137 .iter()
138 .map(|c| NumpyComplex64::new(c.re, c.im))
139 .collect();
140
141 Ok(Array1::from_vec(complex_result).into_pyarray(py).unbind())
142 }
143
144 #[cfg(not(feature = "oxifft"))]
146 {
147 let vec_data: Vec<f64> = arr.to_vec();
148
149 let result = rfft(&vec_data, None)
150 .map_err(|e| PyRuntimeError::new_err(format!("RFFT failed: {}", e)))?;
151
152 let complex_result: Vec<NumpyComplex64> = result
154 .iter()
155 .map(|c| NumpyComplex64::new(c.re, c.im))
156 .collect();
157
158 return Ok(Array1::from_vec(complex_result).into_pyarray(py).unbind());
159 }
160}
161
162#[pyfunction]
165#[pyo3(signature = (data, n=None))]
166fn irfft_py(
167 py: Python,
168 data: &Bound<'_, PyArray1<NumpyComplex64>>,
169 n: Option<usize>,
170) -> PyResult<Py<PyArray1<f64>>> {
171 let binding = data.readonly();
172 let arr = binding.as_array();
173
174 #[cfg(feature = "oxifft")]
176 {
177 let complex_input: scirs2_core::ndarray::Array1<Complex64> =
179 arr.iter().map(|c| Complex64::new(c.re, c.im)).collect();
180
181 let output_len = n.unwrap_or_else(|| 2 * (complex_input.len() - 1));
183
184 let result = scirs2_fft::oxifft_backend::irfft_oxifft(&complex_input.view(), output_len)
185 .map_err(|e| PyRuntimeError::new_err(format!("IRFFT (OxiFFT) failed: {}", e)))?;
186
187 Ok(result.into_pyarray(py).unbind())
188 }
189
190 #[cfg(not(feature = "oxifft"))]
192 {
193 let complex_input: Vec<Complex64> =
195 arr.iter().map(|c| Complex64::new(c.re, c.im)).collect();
196
197 let result = irfft(&complex_input, n)
198 .map_err(|e| PyRuntimeError::new_err(format!("IRFFT failed: {}", e)))?;
199
200 return Ok(Array1::from_vec(result).into_pyarray(py).unbind());
201 }
202}
203
204#[pyfunction]
211#[pyo3(signature = (data, dct_type=2))]
212fn dct_py(
213 py: Python,
214 data: &Bound<'_, PyArray1<f64>>,
215 dct_type: usize,
216) -> PyResult<Py<PyArray1<f64>>> {
217 let binding = data.readonly();
218 let arr = binding.as_array();
219
220 #[cfg(feature = "oxifft")]
222 if dct_type == 2 {
223 let result = scirs2_fft::oxifft_backend::dct2_oxifft(&arr)
224 .map_err(|e| PyRuntimeError::new_err(format!("DCT-II (OxiFFT) failed: {}", e)))?;
225 return Ok(result.into_pyarray(py).unbind());
226 }
227
228 let vec_data: Vec<f64> = arr.to_vec();
230 let dct_type_enum = match dct_type {
231 1 => DCTType::Type1,
232 2 => DCTType::Type2,
233 3 => DCTType::Type3,
234 4 => DCTType::Type4,
235 _ => {
236 return Err(PyRuntimeError::new_err(format!(
237 "Invalid DCT type: {}",
238 dct_type
239 )))
240 }
241 };
242
243 let result = dct(&vec_data, Some(dct_type_enum), None)
244 .map_err(|e| PyRuntimeError::new_err(format!("DCT failed: {}", e)))?;
245
246 Ok(Array1::from_vec(result).into_pyarray(py).unbind())
247}
248
249#[pyfunction]
252#[pyo3(signature = (data, dct_type=2))]
253fn idct_py(
254 py: Python,
255 data: &Bound<'_, PyArray1<f64>>,
256 dct_type: usize,
257) -> PyResult<Py<PyArray1<f64>>> {
258 let binding = data.readonly();
259 let arr = binding.as_array();
260
261 #[cfg(feature = "oxifft")]
263 if dct_type == 2 {
264 let result = scirs2_fft::oxifft_backend::idct2_oxifft(&arr)
265 .map_err(|e| PyRuntimeError::new_err(format!("IDCT-II (OxiFFT) failed: {}", e)))?;
266 return Ok(result.into_pyarray(py).unbind());
267 }
268
269 let vec_data: Vec<f64> = arr.to_vec();
271 let dct_type_enum = match dct_type {
272 1 => DCTType::Type1,
273 2 => DCTType::Type2,
274 3 => DCTType::Type3,
275 4 => DCTType::Type4,
276 _ => {
277 return Err(PyRuntimeError::new_err(format!(
278 "Invalid DCT type: {}",
279 dct_type
280 )))
281 }
282 };
283
284 let result = idct(&vec_data, Some(dct_type_enum), None)
285 .map_err(|e| PyRuntimeError::new_err(format!("IDCT failed: {}", e)))?;
286
287 Ok(Array1::from_vec(result).into_pyarray(py).unbind())
288}
289
290#[pyfunction]
298fn fft2_py(
299 py: Python,
300 data: &Bound<'_, scirs2_numpy::PyArray2<f64>>,
301) -> PyResult<Py<scirs2_numpy::PyArray2<NumpyComplex64>>> {
302 let binding = data.readonly();
303 let arr = binding.as_array();
304
305 #[cfg(feature = "oxifft")]
306 {
307 let (rows, cols) = arr.dim();
308
309 let half_result = scirs2_fft::oxifft_backend::rfft2_oxifft(&arr)
311 .map_err(|e| PyRuntimeError::new_err(format!("FFT2 (OxiFFT) failed: {}", e)))?;
312
313 let half_cols = cols / 2 + 1;
316 let mut full_result: Vec<NumpyComplex64> = Vec::with_capacity(rows * cols);
317
318 for row in 0..rows {
319 for col in 0..half_cols {
321 let c = half_result[[row, col]];
322 full_result.push(NumpyComplex64::new(c.re, c.im));
323 }
324
325 for col in half_cols..cols {
328 let conj_row = if row == 0 { 0 } else { rows - row };
329 let conj_col = cols - col;
330 let c = half_result[[conj_row, conj_col]];
331 full_result.push(NumpyComplex64::new(c.re, -c.im)); }
333 }
334
335 let result_array = scirs2_core::ndarray::Array2::from_shape_vec((rows, cols), full_result)
336 .map_err(|e| PyRuntimeError::new_err(format!("Shape error: {}", e)))?;
337
338 Ok(result_array.into_pyarray(py).unbind())
339 }
340
341 #[cfg(not(feature = "oxifft"))]
342 {
343 let _ = arr;
344 Err(PyRuntimeError::new_err("FFT2 requires oxifft feature"))
345 }
346}
347
348#[pyfunction]
350fn rfft2_py(
351 py: Python,
352 data: &Bound<'_, scirs2_numpy::PyArray2<f64>>,
353) -> PyResult<Py<scirs2_numpy::PyArray2<NumpyComplex64>>> {
354 let binding = data.readonly();
355 let arr = binding.as_array();
356
357 #[cfg(feature = "oxifft")]
358 {
359 let result = scirs2_fft::oxifft_backend::rfft2_oxifft(&arr)
360 .map_err(|e| PyRuntimeError::new_err(format!("RFFT2 (OxiFFT) failed: {}", e)))?;
361
362 let (rows, cols) = result.dim();
364 let complex_result: Vec<NumpyComplex64> = result
365 .iter()
366 .map(|c| NumpyComplex64::new(c.re, c.im))
367 .collect();
368
369 let result_array =
370 scirs2_core::ndarray::Array2::from_shape_vec((rows, cols), complex_result)
371 .map_err(|e| PyRuntimeError::new_err(format!("Shape error: {}", e)))?;
372
373 Ok(result_array.into_pyarray(py).unbind())
374 }
375
376 #[cfg(not(feature = "oxifft"))]
377 {
378 let _ = arr;
379 Err(PyRuntimeError::new_err("RFFT2 requires oxifft feature"))
380 }
381}
382
383#[pyfunction]
386fn ifft2_py(
387 py: Python,
388 data: &Bound<'_, scirs2_numpy::PyArray2<NumpyComplex64>>,
389) -> PyResult<Py<scirs2_numpy::PyArray2<NumpyComplex64>>> {
390 let binding = data.readonly();
391 let arr = binding.as_array();
392
393 #[cfg(feature = "oxifft")]
394 {
395 let (rows, cols) = arr.dim();
396 let n = rows * cols;
397
398 let mut complex_vec: Vec<Complex64> = Vec::with_capacity(n);
400 for c in arr.iter() {
401 complex_vec.push(Complex64::new(c.re, c.im));
402 }
403 let complex_input = scirs2_core::ndarray::Array2::from_shape_vec((rows, cols), complex_vec)
404 .map_err(|e| PyRuntimeError::new_err(format!("Shape error: {}", e)))?;
405
406 let result = scirs2_fft::oxifft_backend::ifft2_oxifft(&complex_input.view())
407 .map_err(|e| PyRuntimeError::new_err(format!("IFFT2 (OxiFFT) failed: {}", e)))?;
408
409 let mut result_vec: Vec<NumpyComplex64> = Vec::with_capacity(n);
411 for c in result.iter() {
412 result_vec.push(NumpyComplex64::new(c.re, c.im));
413 }
414
415 let result_array = scirs2_core::ndarray::Array2::from_shape_vec((rows, cols), result_vec)
416 .map_err(|e| PyRuntimeError::new_err(format!("Shape error: {}", e)))?;
417
418 Ok(result_array.into_pyarray(py).unbind())
419 }
420
421 #[cfg(not(feature = "oxifft"))]
422 {
423 let _ = arr;
424 Err(PyRuntimeError::new_err("IFFT2 requires oxifft feature"))
425 }
426}
427
428#[pyfunction]
430#[pyo3(signature = (data, shape))]
431fn irfft2_py(
432 py: Python,
433 data: &Bound<'_, scirs2_numpy::PyArray2<NumpyComplex64>>,
434 shape: (usize, usize),
435) -> PyResult<Py<scirs2_numpy::PyArray2<f64>>> {
436 let binding = data.readonly();
437 let arr = binding.as_array();
438
439 #[cfg(feature = "oxifft")]
440 {
441 let (in_rows, in_cols) = arr.dim();
442
443 let complex_input: scirs2_core::ndarray::Array2<Complex64> = arr
445 .iter()
446 .map(|c| Complex64::new(c.re, c.im))
447 .collect::<Vec<_>>()
448 .into_iter()
449 .collect::<scirs2_core::ndarray::Array1<_>>()
450 .into_shape_with_order((in_rows, in_cols))
451 .map_err(|e| PyRuntimeError::new_err(format!("Shape error: {}", e)))?;
452
453 let result = scirs2_fft::oxifft_backend::irfft2_oxifft(&complex_input.view(), shape)
454 .map_err(|e| PyRuntimeError::new_err(format!("IRFFT2 (OxiFFT) failed: {}", e)))?;
455
456 Ok(result.into_pyarray(py).unbind())
457 }
458
459 #[cfg(not(feature = "oxifft"))]
460 {
461 let _ = (arr, shape);
462 Err(PyRuntimeError::new_err("IRFFT2 requires oxifft feature"))
463 }
464}
465
466#[pyfunction]
472#[pyo3(signature = (n, d=1.0))]
473fn fftfreq_py(py: Python, n: usize, d: f64) -> PyResult<Py<PyArray1<f64>>> {
474 let result =
475 fftfreq(n, d).map_err(|e| PyRuntimeError::new_err(format!("FFT freq failed: {}", e)))?;
476 Ok(Array1::from_vec(result).into_pyarray(py).unbind())
477}
478
479#[pyfunction]
481#[pyo3(signature = (n, d=1.0))]
482fn rfftfreq_py(py: Python, n: usize, d: f64) -> PyResult<Py<PyArray1<f64>>> {
483 let result =
484 rfftfreq(n, d).map_err(|e| PyRuntimeError::new_err(format!("RFFT freq failed: {}", e)))?;
485 Ok(Array1::from_vec(result).into_pyarray(py).unbind())
486}
487
488#[pyfunction]
490fn fftshift_py(py: Python, data: &Bound<'_, PyArray1<f64>>) -> PyResult<Py<PyArray1<f64>>> {
491 let binding = data.readonly();
492 let arr = binding.as_array().to_owned();
493
494 let result =
495 fftshift(&arr).map_err(|e| PyRuntimeError::new_err(format!("FFT shift failed: {}", e)))?;
496 Ok(result.into_pyarray(py).unbind())
497}
498
499#[pyfunction]
501fn ifftshift_py(py: Python, data: &Bound<'_, PyArray1<f64>>) -> PyResult<Py<PyArray1<f64>>> {
502 let binding = data.readonly();
503 let arr = binding.as_array().to_owned();
504
505 let result = ifftshift(&arr)
506 .map_err(|e| PyRuntimeError::new_err(format!("Inverse FFT shift failed: {}", e)))?;
507 Ok(result.into_pyarray(py).unbind())
508}
509
510#[pyfunction]
513#[pyo3(signature = (n, real=false))]
514fn next_fast_len_py(n: usize, real: bool) -> usize {
515 next_fast_len(n, real)
516}
517
518pub fn register_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
520 m.add_function(wrap_pyfunction!(fft_py, m)?)?;
522 m.add_function(wrap_pyfunction!(ifft_py, m)?)?;
523 m.add_function(wrap_pyfunction!(rfft_py, m)?)?;
524 m.add_function(wrap_pyfunction!(irfft_py, m)?)?;
525
526 m.add_function(wrap_pyfunction!(dct_py, m)?)?;
528 m.add_function(wrap_pyfunction!(idct_py, m)?)?;
529
530 m.add_function(wrap_pyfunction!(fft2_py, m)?)?;
532 m.add_function(wrap_pyfunction!(ifft2_py, m)?)?;
533 m.add_function(wrap_pyfunction!(rfft2_py, m)?)?;
534 m.add_function(wrap_pyfunction!(irfft2_py, m)?)?;
535
536 m.add_function(wrap_pyfunction!(fftfreq_py, m)?)?;
538 m.add_function(wrap_pyfunction!(rfftfreq_py, m)?)?;
539 m.add_function(wrap_pyfunction!(fftshift_py, m)?)?;
540 m.add_function(wrap_pyfunction!(ifftshift_py, m)?)?;
541 m.add_function(wrap_pyfunction!(next_fast_len_py, m)?)?;
542
543 Ok(())
544}