Skip to main content

scirs2_fft/simd_fft/
mod.rs

1//! SIMD-accelerated FFT operations.
2//!
3//! This module provides SIMD-optimised butterfly kernels and delegates
4//! all higher-level FFT operations to scirs2-core when available.
5//!
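//! ## Architecture-specific sub-modules
//!
//! - `avx512`: AVX-512F butterfly kernels for x86_64 (radix-4 / radix-8).
//!   Gated on `#[cfg(target_arch = "x86_64")]` with a runtime
//!   `is_x86_feature_detected!("avx512f")` guard.
//! - `neon`: NEON (and optional SVE) butterfly kernels for AArch64
//!   (radix-4 / radix-8). Gated on `#[cfg(target_arch = "aarch64")]`; the
//!   SVE path is selected at runtime via `is_aarch64_feature_detected!("sve")`.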
11
12/// AVX-512F accelerated radix-4 and radix-8 FFT butterfly kernels.
13///
14/// Available on x86_64 targets only.  Each public function is additionally
15/// wrapped with a runtime check so that it is safe to call the dispatch
16/// entry points on CPUs that do not have AVX-512F.
17#[cfg(target_arch = "x86_64")]
18pub mod avx512;
19
20// Re-export the public dispatch entry points so callers can reach them via
21// `scirs2_fft::simd_fft::radix4_butterfly_dispatch` etc.
22#[cfg(target_arch = "x86_64")]
23pub use avx512::{
24    is_avx512_available, radix4_butterfly_dispatch, radix4_butterfly_scalar,
25    radix8_butterfly_dispatch, radix8_butterfly_scalar,
26};
27
28/// ARM NEON and SVE accelerated radix-4 and radix-8 FFT butterfly kernels.
29///
30/// Available on AArch64 targets only.  NEON is architecturally mandatory on
31/// AArch64, so no runtime capability guard is needed.  SVE is optional and
32/// gated via `is_aarch64_feature_detected!("sve")` at runtime.
33#[cfg(target_arch = "aarch64")]
34pub mod neon;
35
36// Re-export NEON dispatch entry points for AArch64 targets.
37#[cfg(target_arch = "aarch64")]
38pub use neon::{
39    is_neon_available, radix4_butterfly_dispatch, radix4_butterfly_scalar,
40    radix8_butterfly_dispatch, radix8_butterfly_scalar,
41};
42
43use crate::error::FFTResult;
44use crate::fft;
45use scirs2_core::ndarray::{Array2, ArrayD, IxDyn};
46use scirs2_core::numeric::Complex64;
47use scirs2_core::numeric::NumCast;
48use scirs2_core::simd_ops::PlatformCapabilities;
49use std::fmt::Debug;
50
/// Normalization mode for FFT operations.
///
/// NOTE(review): the variant names appear to correspond to the `norm` strings
/// (`"backward"`, `"ortho"`, `"forward"`) accepted by the FFT functions in
/// this module, but nothing in this file converts between the two — confirm
/// the intended semantics before relying on them.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum NormMode {
    None,
    Backward,
    Ortho,
    Forward,
}
59
60/// Check if SIMD support is available
61#[allow(dead_code)]
62pub fn simd_support_available() -> bool {
63    let caps = PlatformCapabilities::detect();
64    caps.simd_available
65}
66
67/// Apply SIMD normalization (stub - not used in current implementation)
68#[allow(dead_code)]
69pub fn apply_simd_normalization(data: &mut [Complex64], scale: f64) {
70    for c in data.iter_mut() {
71        *c *= scale;
72    }
73}
74
75/// SIMD-accelerated 1D FFT
76#[allow(dead_code)]
77pub fn fft_simd<T>(x: &[T], _norm: Option<&str>) -> FFTResult<Vec<Complex64>>
78where
79    T: NumCast + Copy + Debug + 'static,
80{
81    fft::fft(x, None)
82}
83
84/// SIMD-accelerated 1D inverse FFT
85#[allow(dead_code)]
86pub fn ifft_simd<T>(x: &[T], _norm: Option<&str>) -> FFTResult<Vec<Complex64>>
87where
88    T: NumCast + Copy + Debug + 'static,
89{
90    fft::ifft(x, None)
91}
92
93/// SIMD-accelerated 2D FFT
94#[allow(dead_code)]
95pub fn fft2_simd<T>(
96    x: &[T],
97    shape: Option<(usize, usize)>,
98    norm: Option<&str>,
99) -> FFTResult<Array2<Complex64>>
100where
101    T: NumCast + Copy + Debug + 'static,
102{
103    // If no shape is provided, try to infer a square shape
104    let (n_rows, n_cols) = if let Some(s) = shape {
105        s
106    } else {
107        let len = x.len();
108        let size = (len as f64).sqrt() as usize;
109        if size * size != len {
110            return Err(crate::error::FFTError::ValueError(
111                "Cannot infer 2D shape from slice length".to_string(),
112            ));
113        }
114        (size, size)
115    };
116
117    // Check that the slice has the right number of elements
118    if x.len() != n_rows * n_cols {
119        return Err(crate::error::FFTError::ValueError(format!(
120            "Shape ({}, {}) requires {} elements, but slice has {}",
121            n_rows,
122            n_cols,
123            n_rows * n_cols,
124            x.len()
125        )));
126    }
127
128    // Convert slice to 2D array
129    let mut values = Vec::with_capacity(n_rows * n_cols);
130    for &val in x.iter() {
131        values.push(val);
132    }
133    let arr = Array2::from_shape_vec((n_rows, n_cols), values)
134        .map_err(|e| crate::error::FFTError::DimensionError(e.to_string()))?;
135
136    // Use the regular fft2 function
137    crate::fft::fft2(&arr, None, None, norm)
138}
139
140/// SIMD-accelerated 2D inverse FFT
141#[allow(dead_code)]
142pub fn ifft2_simd<T>(
143    x: &[T],
144    shape: Option<(usize, usize)>,
145    norm: Option<&str>,
146) -> FFTResult<Array2<Complex64>>
147where
148    T: NumCast + Copy + Debug + 'static,
149{
150    // If no shape is provided, try to infer a square shape
151    let (n_rows, n_cols) = if let Some(s) = shape {
152        s
153    } else {
154        let len = x.len();
155        let size = (len as f64).sqrt() as usize;
156        if size * size != len {
157            return Err(crate::error::FFTError::ValueError(
158                "Cannot infer 2D shape from slice length".to_string(),
159            ));
160        }
161        (size, size)
162    };
163
164    // Check that the slice has the right number of elements
165    if x.len() != n_rows * n_cols {
166        return Err(crate::error::FFTError::ValueError(format!(
167            "Shape ({}, {}) requires {} elements, but slice has {}",
168            n_rows,
169            n_cols,
170            n_rows * n_cols,
171            x.len()
172        )));
173    }
174
175    // Convert slice to 2D array
176    let mut values = Vec::with_capacity(n_rows * n_cols);
177    for &val in x.iter() {
178        values.push(val);
179    }
180    let arr = Array2::from_shape_vec((n_rows, n_cols), values)
181        .map_err(|e| crate::error::FFTError::DimensionError(e.to_string()))?;
182
183    // Use the regular ifft2 function
184    crate::fft::ifft2(&arr, None, None, norm)
185}
186
187/// SIMD-accelerated N-dimensional FFT
188#[allow(dead_code)]
189pub fn fftn_simd<T>(
190    x: &[T],
191    shape: Option<&[usize]>,
192    axes: Option<&[usize]>,
193    norm: Option<&str>,
194) -> FFTResult<ArrayD<Complex64>>
195where
196    T: NumCast + Copy + Debug + 'static,
197{
198    // Shape is required for N-dimensional FFT from slice
199    let shape = shape.ok_or_else(|| {
200        crate::error::FFTError::ValueError(
201            "Shape is required for N-dimensional FFT from slice".to_string(),
202        )
203    })?;
204
205    // Calculate total number of elements
206    let total_elements: usize = shape.iter().product();
207
208    // Check that the slice has the right number of elements
209    if x.len() != total_elements {
210        return Err(crate::error::FFTError::ValueError(format!(
211            "Shape {:?} requires {} elements, but slice has {}",
212            shape,
213            total_elements,
214            x.len()
215        )));
216    }
217
218    // Convert slice to N-dimensional array
219    let mut values = Vec::with_capacity(total_elements);
220    for &val in x.iter() {
221        values.push(val);
222    }
223    let arr = ArrayD::from_shape_vec(IxDyn(shape), values)
224        .map_err(|e| crate::error::FFTError::DimensionError(e.to_string()))?;
225
226    // Use the regular fftn function
227    crate::fft::fftn(&arr, None, axes.map(|a| a.to_vec()), norm, None, None)
228}
229
230/// SIMD-accelerated N-dimensional inverse FFT
231#[allow(dead_code)]
232pub fn ifftn_simd<T>(
233    x: &[T],
234    shape: Option<&[usize]>,
235    axes: Option<&[usize]>,
236    norm: Option<&str>,
237) -> FFTResult<ArrayD<Complex64>>
238where
239    T: NumCast + Copy + Debug + 'static,
240{
241    // Shape is required for N-dimensional IFFT from slice
242    let shape = shape.ok_or_else(|| {
243        crate::error::FFTError::ValueError(
244            "Shape is required for N-dimensional IFFT from slice".to_string(),
245        )
246    })?;
247
248    // Calculate total number of elements
249    let total_elements: usize = shape.iter().product();
250
251    // Check that the slice has the right number of elements
252    if x.len() != total_elements {
253        return Err(crate::error::FFTError::ValueError(format!(
254            "Shape {:?} requires {} elements, but slice has {}",
255            shape,
256            total_elements,
257            x.len()
258        )));
259    }
260
261    // Convert slice to N-dimensional array
262    let mut values = Vec::with_capacity(total_elements);
263    for &val in x.iter() {
264        values.push(val);
265    }
266    let arr = ArrayD::from_shape_vec(IxDyn(shape), values)
267        .map_err(|e| crate::error::FFTError::DimensionError(e.to_string()))?;
268
269    // Use the regular ifftn function
270    crate::fft::ifftn(&arr, None, axes.map(|a| a.to_vec()), norm, None, None)
271}
272
273/// Adaptive FFT
274#[allow(dead_code)]
275pub fn fft_adaptive<T>(x: &[T], norm: Option<&str>) -> FFTResult<Vec<Complex64>>
276where
277    T: NumCast + Copy + Debug + 'static,
278{
279    fft_simd(x, norm)
280}
281
282/// Adaptive inverse FFT
283#[allow(dead_code)]
284pub fn ifft_adaptive<T>(x: &[T], norm: Option<&str>) -> FFTResult<Vec<Complex64>>
285where
286    T: NumCast + Copy + Debug + 'static,
287{
288    ifft_simd(x, norm)
289}
290
291/// Adaptive 2D FFT
292#[allow(dead_code)]
293pub fn fft2_adaptive<T>(
294    _x: &[T],
295    shape: Option<(usize, usize)>,
296    norm: Option<&str>,
297) -> FFTResult<Array2<Complex64>>
298where
299    T: NumCast + Copy + Debug + 'static,
300{
301    fft2_simd(_x, shape, norm)
302}
303
304/// Adaptive 2D inverse FFT
305#[allow(dead_code)]
306pub fn ifft2_adaptive<T>(
307    _x: &[T],
308    shape: Option<(usize, usize)>,
309    norm: Option<&str>,
310) -> FFTResult<Array2<Complex64>>
311where
312    T: NumCast + Copy + Debug + 'static,
313{
314    ifft2_simd(_x, shape, norm)
315}
316
317/// Adaptive N-dimensional FFT
318#[allow(dead_code)]
319pub fn fftn_adaptive<T>(
320    _x: &[T],
321    shape: Option<&[usize]>,
322    axes: Option<&[usize]>,
323    norm: Option<&str>,
324) -> FFTResult<ArrayD<Complex64>>
325where
326    T: NumCast + Copy + Debug + 'static,
327{
328    fftn_simd(_x, shape, axes, norm)
329}
330
331/// Adaptive N-dimensional inverse FFT
332#[allow(dead_code)]
333pub fn ifftn_adaptive<T>(
334    _x: &[T],
335    shape: Option<&[usize]>,
336    axes: Option<&[usize]>,
337    norm: Option<&str>,
338) -> FFTResult<ArrayD<Complex64>>
339where
340    T: NumCast + Copy + Debug + 'static,
341{
342    ifftn_simd(_x, shape, axes, norm)
343}