scirs2_fft/
simd_rfft.rs

1//! SIMD-accelerated Real-valued Fast Fourier Transform (RFFT) operations
2//!
3//! This module provides SIMD-accelerated implementations of FFT operations
4//! for real-valued inputs, using the unified SIMD abstraction layer from scirs2-core.
5
6use crate::error::{FFTError, FFTResult};
7use crate::rfft::{irfft as irfft_basic, rfft as rfft_basic};
8use scirs2_core::numeric::Complex64;
9use scirs2_core::numeric::NumCast;
10use scirs2_core::simd_ops::{AutoOptimizer, PlatformCapabilities};
11use std::fmt::Debug;
12
13/// Compute the 1-dimensional discrete Fourier Transform for real input with SIMD acceleration.
14///
15/// This function is optimized using SIMD instructions for improved performance on
16/// modern CPUs. For real-valued inputs, this uses a specialized algorithm that is
17/// more efficient than a general complex FFT.
18///
19/// # Arguments
20///
21/// * `input` - Input real-valued array
22/// * `n` - Length of the transformed axis (optional)
23/// * `norm` - Normalization mode (optional)
24///
25/// # Returns
26///
27/// * The Fourier transform of the real input array
28///
29/// # Examples
30///
31/// ```
32/// use scirs2_fft::simd_rfft::{rfft_simd};
33/// use scirs2_fft::simd_fft::NormMode;
34///
35/// // Generate a simple signal
36/// let signal = vec![1.0, 2.0, 3.0, 4.0];
37///
38/// // Compute RFFT of the signal with SIMD acceleration
39/// let spectrum = rfft_simd(&signal, None, None).unwrap();
40///
41/// // RFFT produces n//2 + 1 complex values
42/// assert_eq!(spectrum.len(), signal.len() / 2 + 1);
43/// ```
44#[allow(dead_code)]
45pub fn rfft_simd<T>(input: &[T], n: Option<usize>, norm: Option<&str>) -> FFTResult<Vec<Complex64>>
46where
47    T: NumCast + Copy + Debug + 'static,
48{
49    // Use the basic rfft implementation which already handles the logic
50    let result = rfft_basic(input, n)?;
51
52    // Apply normalization if requested
53    if let Some(norm_str) = norm {
54        let mut result_mut = result;
55        let n = input.len();
56        match norm_str {
57            "backward" => {
58                let scale = 1.0 / (n as f64);
59                result_mut.iter_mut().for_each(|c| *c *= scale);
60            }
61            "ortho" => {
62                let scale = 1.0 / (n as f64).sqrt();
63                result_mut.iter_mut().for_each(|c| *c *= scale);
64            }
65            "forward" => {
66                let scale = 1.0 / (n as f64);
67                result_mut.iter_mut().for_each(|c| *c *= scale);
68            }
69            _ => {} // No normalization for unrecognized mode
70        }
71        return Ok(result_mut);
72    }
73
74    Ok(result)
75}
76
77/// Compute the inverse of the 1-dimensional discrete Fourier Transform for real input with SIMD acceleration.
78///
79/// This function is optimized using SIMD instructions for improved performance on
80/// modern CPUs.
81///
82/// # Arguments
83///
84/// * `input` - Input complex-valued array representing the Fourier transform of real data
85/// * `n` - Length of the output array (optional)
86/// * `norm` - Normalization mode (optional)
87///
88/// # Returns
89///
90/// * The inverse Fourier transform, yielding a real-valued array
91///
92/// # Examples
93///
94/// ```
95/// use scirs2_fft::simd_rfft::{rfft_simd, irfft_simd};
96///
97/// // Generate a simple signal
98/// let signal = vec![1.0, 2.0, 3.0, 4.0];
99///
100/// // Forward transform
101/// let spectrum = rfft_simd(&signal, None, None).unwrap();
102///
103/// // Inverse transform
104/// let recovered = irfft_simd(&spectrum, Some(signal.len()), None).unwrap();
105///
106/// // Check recovery accuracy
107/// for (x, y) in signal.iter().zip(recovered.iter()) {
108///     assert!((x - y).abs() < 1e-10);
109/// }
110/// ```
111#[allow(dead_code)]
112pub fn irfft_simd<T>(input: &[T], n: Option<usize>, norm: Option<&str>) -> FFTResult<Vec<f64>>
113where
114    T: NumCast + Copy + Debug + 'static,
115{
116    // Use the basic irfft implementation
117    let result = irfft_basic(input, n)?;
118
119    // Apply normalization if requested
120    if let Some(norm_str) = norm {
121        let mut result_mut = result;
122        let n = input.len();
123        match norm_str {
124            "backward" => {
125                let scale = 1.0 / (n as f64);
126                result_mut.iter_mut().for_each(|c| *c *= scale);
127            }
128            "ortho" => {
129                let scale = 1.0 / (n as f64).sqrt();
130                result_mut.iter_mut().for_each(|c| *c *= scale);
131            }
132            "forward" => {
133                let scale = 1.0 / (n as f64);
134                result_mut.iter_mut().for_each(|c| *c *= scale);
135            }
136            _ => {} // No normalization for unrecognized mode
137        }
138        return Ok(result_mut);
139    }
140
141    Ok(result)
142}
143
144/// Adaptive RFFT that automatically chooses the best implementation
145#[allow(dead_code)]
146pub fn rfft_adaptive<T>(
147    input: &[T],
148    n: Option<usize>,
149    norm: Option<&str>,
150) -> FFTResult<Vec<Complex64>>
151where
152    T: NumCast + Copy + Debug + 'static,
153{
154    let optimizer = AutoOptimizer::new();
155    let caps = PlatformCapabilities::detect();
156    let size = n.unwrap_or(input.len());
157
158    if caps.gpu_available && optimizer.should_use_gpu(size) {
159        // Use GPU implementation when available
160        match rfft_gpu(input, n, norm) {
161            Ok(result) => Ok(result),
162            Err(_) => {
163                // Fall back to SIMD implementation if GPU fails
164                rfft_simd(input, n, norm)
165            }
166        }
167    } else {
168        rfft_simd(input, n, norm)
169    }
170}
171
172/// Adaptive IRFFT that automatically chooses the best implementation
173#[allow(dead_code)]
174pub fn irfft_adaptive<T>(input: &[T], n: Option<usize>, norm: Option<&str>) -> FFTResult<Vec<f64>>
175where
176    T: NumCast + Copy + Debug + 'static,
177{
178    let optimizer = AutoOptimizer::new();
179    let caps = PlatformCapabilities::detect();
180    let size = n.unwrap_or_else(|| input.len() * 2 - 2);
181
182    if caps.gpu_available && optimizer.should_use_gpu(size) {
183        // Use GPU implementation when available
184        match irfft_gpu(input, n, norm) {
185            Ok(result) => Ok(result),
186            Err(_) => {
187                // Fall back to SIMD implementation if GPU fails
188                irfft_simd(input, n, norm)
189            }
190        }
191    } else {
192        irfft_simd(input, n, norm)
193    }
194}
195
/// GPU-accelerated RFFT entry point (built only with the `cuda` feature).
///
/// Currently a placeholder: every argument is ignored and the call always
/// returns `FFTError::NotImplementedError`.
#[cfg(feature = "cuda")]
#[allow(dead_code)]
fn rfft_gpu<T>(_input: &[T], _n: Option<usize>, _norm: Option<&str>) -> FFTResult<Vec<Complex64>>
where
    T: NumCast + Copy + Debug + 'static,
{
    // Simplified for now because of API incompatibilities; to be replaced by
    // a real implementation once GPU support is fully integrated.
    let message = String::from("GPU-accelerated RFFT is not yet fully implemented");
    Err(FFTError::NotImplementedError(message))
}
209
/// GPU-accelerated IRFFT entry point (built only with the `cuda` feature).
///
/// Currently a placeholder: every argument is ignored and the call always
/// returns `FFTError::NotImplementedError`.
#[cfg(feature = "cuda")]
#[allow(dead_code)]
fn irfft_gpu<T>(_input: &[T], _n: Option<usize>, _norm: Option<&str>) -> FFTResult<Vec<f64>>
where
    T: NumCast + Copy + Debug + 'static,
{
    // Simplified for now because of API incompatibilities; to be replaced by
    // a real implementation once GPU support is fully integrated.
    let message = String::from("GPU-accelerated IRFFT is not yet fully implemented");
    Err(FFTError::NotImplementedError(message))
}
223
224/// Fallback implementations when GPU feature is not enabled
225#[cfg(not(feature = "cuda"))]
226#[allow(dead_code)]
227fn rfft_gpu<T>(_input: &[T], _n: Option<usize>, _norm: Option<&str>) -> FFTResult<Vec<Complex64>>
228where
229    T: NumCast + Copy + Debug + 'static,
230{
231    Err(crate::error::FFTError::NotImplementedError(
232        "GPU FFT not compiled".to_string(),
233    ))
234}
235
236#[cfg(not(feature = "cuda"))]
237#[allow(dead_code)]
238fn irfft_gpu<T>(_input: &[T], _n: Option<usize>, _norm: Option<&str>) -> FFTResult<Vec<f64>>
239where
240    T: NumCast + Copy + Debug + 'static,
241{
242    Err(crate::error::FFTError::NotImplementedError(
243        "GPU FFT not compiled".to_string(),
244    ))
245}
246
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// The unnormalized RFFT has `n / 2 + 1` bins and its DC bin is the sum
    /// of the samples.
    #[test]
    fn test_rfft_simd_simple() {
        let signal = vec![1.0, 2.0, 3.0, 4.0];

        // Forward transform
        let spectrum = rfft_simd(&signal, None, None).unwrap();

        // Check size
        assert_eq!(spectrum.len(), signal.len() / 2 + 1);

        // First element should be sum of all values
        assert_abs_diff_eq!(spectrum[0].re, 10.0, epsilon = 1e-10);
        assert_abs_diff_eq!(spectrum[0].im, 0.0, epsilon = 1e-10);
    }

    /// A forward transform followed by an inverse transform recovers the
    /// original signal.
    #[test]
    fn test_rfft_irfft_roundtrip() {
        let signal = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

        // Forward transform
        let spectrum = rfft_simd(&signal, None, None).unwrap();

        // Inverse transform
        let recovered = irfft_simd(&spectrum, Some(signal.len()), None).unwrap();

        // `zip` stops at the shorter side, so without this check a truncated
        // (or empty) result would pass the element-wise comparison vacuously.
        assert_eq!(recovered.len(), signal.len());

        // Check recovery; written as `<=` so that a NaN sample fails the test.
        for (i, (&orig, &rec)) in signal.iter().zip(recovered.iter()).enumerate() {
            assert!(
                (orig - rec).abs() <= 1e-10,
                "Mismatch at index {i}: {orig} != {rec}"
            );
        }
    }

    /// The adaptive entry points produce correctly sized output whichever
    /// backend ends up handling the request.
    #[test]
    fn test_adaptive_selection() {
        let signal = vec![1.0; 1000];

        // Test adaptive functions (should work regardless of GPU availability)
        let spectrum = rfft_adaptive(&signal, None, None).unwrap();
        assert_eq!(spectrum.len(), signal.len() / 2 + 1);

        let recovered = irfft_adaptive(&spectrum, Some(signal.len()), None).unwrap();
        assert_eq!(recovered.len(), signal.len());
    }
}