1use crate::error::{FFTError, FFTResult};
7use crate::rfft::{irfft as irfft_basic, rfft as rfft_basic};
8use scirs2_core::numeric::Complex64;
9use scirs2_core::numeric::NumCast;
10use scirs2_core::simd_ops::{AutoOptimizer, PlatformCapabilities};
11use std::fmt::Debug;
12
13#[allow(dead_code)]
45pub fn rfft_simd<T>(input: &[T], n: Option<usize>, norm: Option<&str>) -> FFTResult<Vec<Complex64>>
46where
47 T: NumCast + Copy + Debug + 'static,
48{
49 let result = rfft_basic(input, n)?;
51
52 if let Some(norm_str) = norm {
54 let mut result_mut = result;
55 let n = input.len();
56 match norm_str {
57 "backward" => {
58 let scale = 1.0 / (n as f64);
59 result_mut.iter_mut().for_each(|c| *c *= scale);
60 }
61 "ortho" => {
62 let scale = 1.0 / (n as f64).sqrt();
63 result_mut.iter_mut().for_each(|c| *c *= scale);
64 }
65 "forward" => {
66 let scale = 1.0 / (n as f64);
67 result_mut.iter_mut().for_each(|c| *c *= scale);
68 }
69 _ => {} }
71 return Ok(result_mut);
72 }
73
74 Ok(result)
75}
76
77#[allow(dead_code)]
112pub fn irfft_simd<T>(input: &[T], n: Option<usize>, norm: Option<&str>) -> FFTResult<Vec<f64>>
113where
114 T: NumCast + Copy + Debug + 'static,
115{
116 let result = irfft_basic(input, n)?;
118
119 if let Some(norm_str) = norm {
121 let mut result_mut = result;
122 let n = input.len();
123 match norm_str {
124 "backward" => {
125 let scale = 1.0 / (n as f64);
126 result_mut.iter_mut().for_each(|c| *c *= scale);
127 }
128 "ortho" => {
129 let scale = 1.0 / (n as f64).sqrt();
130 result_mut.iter_mut().for_each(|c| *c *= scale);
131 }
132 "forward" => {
133 let scale = 1.0 / (n as f64);
134 result_mut.iter_mut().for_each(|c| *c *= scale);
135 }
136 _ => {} }
138 return Ok(result_mut);
139 }
140
141 Ok(result)
142}
143
144#[allow(dead_code)]
146pub fn rfft_adaptive<T>(
147 input: &[T],
148 n: Option<usize>,
149 norm: Option<&str>,
150) -> FFTResult<Vec<Complex64>>
151where
152 T: NumCast + Copy + Debug + 'static,
153{
154 let optimizer = AutoOptimizer::new();
155 let caps = PlatformCapabilities::detect();
156 let size = n.unwrap_or(input.len());
157
158 if caps.gpu_available && optimizer.should_use_gpu(size) {
159 match rfft_gpu(input, n, norm) {
161 Ok(result) => Ok(result),
162 Err(_) => {
163 rfft_simd(input, n, norm)
165 }
166 }
167 } else {
168 rfft_simd(input, n, norm)
169 }
170}
171
172#[allow(dead_code)]
174pub fn irfft_adaptive<T>(input: &[T], n: Option<usize>, norm: Option<&str>) -> FFTResult<Vec<f64>>
175where
176 T: NumCast + Copy + Debug + 'static,
177{
178 let optimizer = AutoOptimizer::new();
179 let caps = PlatformCapabilities::detect();
180 let size = n.unwrap_or_else(|| input.len() * 2 - 2);
181
182 if caps.gpu_available && optimizer.should_use_gpu(size) {
183 match irfft_gpu(input, n, norm) {
185 Ok(result) => Ok(result),
186 Err(_) => {
187 irfft_simd(input, n, norm)
189 }
190 }
191 } else {
192 irfft_simd(input, n, norm)
193 }
194}
195
#[cfg(feature = "cuda")]
#[allow(dead_code)]
/// GPU backend for the real-input FFT (builds with the `cuda` feature).
///
/// This is a placeholder: it ignores its arguments and always returns
/// `FFTError::NotImplementedError`. [`rfft_adaptive`] reacts to that error by
/// falling back to the SIMD/CPU path.
fn rfft_gpu<T>(_input: &[T], _n: Option<usize>, _norm: Option<&str>) -> FFTResult<Vec<Complex64>>
where
    T: NumCast + Copy + Debug + 'static,
{
    let reason = String::from("GPU-accelerated RFFT is not yet fully implemented");
    Err(FFTError::NotImplementedError(reason))
}
209
#[cfg(feature = "cuda")]
#[allow(dead_code)]
/// GPU backend for the inverse real-input FFT (builds with the `cuda` feature).
///
/// This is a placeholder: it ignores its arguments and always returns
/// `FFTError::NotImplementedError`. [`irfft_adaptive`] reacts to that error by
/// falling back to the SIMD/CPU path.
fn irfft_gpu<T>(_input: &[T], _n: Option<usize>, _norm: Option<&str>) -> FFTResult<Vec<f64>>
where
    T: NumCast + Copy + Debug + 'static,
{
    let reason = String::from("GPU-accelerated IRFFT is not yet fully implemented");
    Err(FFTError::NotImplementedError(reason))
}
223
224#[cfg(not(feature = "cuda"))]
226#[allow(dead_code)]
227fn rfft_gpu<T>(_input: &[T], _n: Option<usize>, _norm: Option<&str>) -> FFTResult<Vec<Complex64>>
228where
229 T: NumCast + Copy + Debug + 'static,
230{
231 Err(crate::error::FFTError::NotImplementedError(
232 "GPU FFT not compiled".to_string(),
233 ))
234}
235
236#[cfg(not(feature = "cuda"))]
237#[allow(dead_code)]
238fn irfft_gpu<T>(_input: &[T], _n: Option<usize>, _norm: Option<&str>) -> FFTResult<Vec<f64>>
239where
240 T: NumCast + Copy + Debug + 'static,
241{
242 Err(crate::error::FFTError::NotImplementedError(
243 "GPU FFT not compiled".to_string(),
244 ))
245}
246
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// The unnormalized DC bin of a short signal equals the sum of its samples.
    #[test]
    fn test_rfft_simd_simple() {
        let samples: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0];
        let bins = rfft_simd(&samples, None, None).unwrap();

        // A real FFT of length N yields N / 2 + 1 one-sided bins.
        assert_eq!(bins.len(), samples.len() / 2 + 1);

        // DC bin: 1 + 2 + 3 + 4 = 10, purely real.
        let dc = &bins[0];
        assert_abs_diff_eq!(dc.re, 10.0, epsilon = 1e-10);
        assert_abs_diff_eq!(dc.im, 0.0, epsilon = 1e-10);
    }

    /// Forward then inverse transform (default normalization) restores the signal.
    #[test]
    fn test_rfft_irfft_roundtrip() {
        let original: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let spectrum = rfft_simd(&original, None, None).unwrap();
        let restored = irfft_simd(&spectrum, Some(original.len()), None).unwrap();

        let first_mismatch = original
            .iter()
            .zip(&restored)
            .enumerate()
            .find(|&(_, (a, b))| (a - b).abs() > 1e-10);

        if let Some((i, (orig, rec))) = first_mismatch {
            panic!("Mismatch at index {i}: {orig} != {rec}");
        }
    }

    /// The adaptive entry points produce correctly sized outputs in both directions.
    #[test]
    fn test_adaptive_selection() {
        let ones: Vec<f64> = vec![1.0; 1000];

        let spectrum = rfft_adaptive(&ones, None, None).unwrap();
        assert_eq!(spectrum.len(), ones.len() / 2 + 1);

        let back = irfft_adaptive(&spectrum, Some(ones.len()), None).unwrap();
        assert_eq!(back.len(), ones.len());
    }
}