1use crate::error::{FFTError, FFTResult};
38#[cfg(feature = "oxifft")]
39use crate::oxifft_plan_cache;
40#[cfg(feature = "oxifft")]
41use oxifft::{Complex as OxiComplex, Direction};
42use scirs2_core::numeric::Complex;
43use scirs2_core::numeric::Float;
44
45pub trait RealToComplex<T: Float>: Send + Sync {
50 fn process(&self, input: &[T], output: &mut [Complex<T>]) -> FFTResult<()>;
57
58 fn len(&self) -> usize;
60
61 fn is_empty(&self) -> bool {
63 self.len() == 0
64 }
65
66 fn output_len(&self) -> usize {
68 self.len() / 2 + 1
69 }
70}
71
72pub trait ComplexToReal<T: Float>: Send + Sync {
77 fn process(&self, input: &[Complex<T>], output: &mut [T]) -> FFTResult<()>;
84
85 fn len(&self) -> usize;
87
88 fn is_empty(&self) -> bool {
90 self.len() == 0
91 }
92
93 fn input_len(&self) -> usize {
95 self.len() / 2 + 1
96 }
97}
98
99struct RealFftPlanF64 {
101 length: usize,
102}
103
104impl RealFftPlanF64 {
105 fn new(length: usize) -> Self {
106 Self { length }
107 }
108}
109
110impl RealToComplex<f64> for RealFftPlanF64 {
111 fn process(&self, input: &[f64], output: &mut [Complex<f64>]) -> FFTResult<()> {
112 if input.len() != self.length {
114 return Err(FFTError::ValueError(format!(
115 "Input length {} doesn't match plan length {}",
116 input.len(),
117 self.length
118 )));
119 }
120 if output.len() != self.output_len() {
121 return Err(FFTError::ValueError(format!(
122 "Output length {} doesn't match expected length {}",
123 output.len(),
124 self.output_len()
125 )));
126 }
127
128 #[cfg(feature = "oxifft")]
129 {
130 let input_oxi: Vec<OxiComplex<f64>> =
132 input.iter().map(|&x| OxiComplex::new(x, 0.0)).collect();
133 let mut output_oxi: Vec<OxiComplex<f64>> = vec![OxiComplex::new(0.0, 0.0); self.length];
134
135 oxifft_plan_cache::execute_c2c(&input_oxi, &mut output_oxi, Direction::Forward)?;
136
137 let out_len = self.output_len();
139 for (i, dst) in output.iter_mut().enumerate().take(out_len) {
140 *dst = Complex::new(output_oxi[i].re, output_oxi[i].im);
141 }
142 }
143
144 #[cfg(not(feature = "oxifft"))]
145 {
146 for dst in output.iter_mut() {
148 *dst = Complex::new(0.0, 0.0);
149 }
150 }
151
152 Ok(())
153 }
154
155 fn len(&self) -> usize {
156 self.length
157 }
158}
159
160struct InverseRealFftPlanF64 {
162 length: usize,
163}
164
165impl InverseRealFftPlanF64 {
166 fn new(length: usize) -> Self {
167 Self { length }
168 }
169}
170
171impl ComplexToReal<f64> for InverseRealFftPlanF64 {
172 fn process(&self, input: &[Complex<f64>], output: &mut [f64]) -> FFTResult<()> {
173 if input.len() != self.input_len() {
175 return Err(FFTError::ValueError(format!(
176 "Input length {} doesn't match expected length {}",
177 input.len(),
178 self.input_len()
179 )));
180 }
181 if output.len() != self.length {
182 return Err(FFTError::ValueError(format!(
183 "Output length {} doesn't match plan length {}",
184 output.len(),
185 self.length
186 )));
187 }
188
189 #[cfg(feature = "oxifft")]
190 {
191 let mut buffer_oxi: Vec<OxiComplex<f64>> = Vec::with_capacity(self.length);
193
194 for &c in input.iter() {
196 buffer_oxi.push(OxiComplex::new(c.re, c.im));
197 }
198
199 let start_idx = if self.length % 2 == 0 {
201 input.len() - 1
202 } else {
203 input.len()
204 };
205
206 for i in (1..start_idx).rev() {
207 buffer_oxi.push(OxiComplex::new(input[i].re, -input[i].im));
208 }
209
210 while buffer_oxi.len() < self.length {
212 buffer_oxi.push(OxiComplex::new(0.0, 0.0));
213 }
214
215 let mut out_oxi: Vec<OxiComplex<f64>> = vec![OxiComplex::new(0.0, 0.0); self.length];
216
217 oxifft_plan_cache::execute_c2c(&buffer_oxi, &mut out_oxi, Direction::Backward)?;
218
219 let scale = 1.0 / self.length as f64;
221 for (i, dst) in output.iter_mut().enumerate() {
222 *dst = out_oxi[i].re * scale;
223 }
224 }
225
226 #[cfg(not(feature = "oxifft"))]
227 {
228 for dst in output.iter_mut() {
229 *dst = 0.0;
230 }
231 }
232
233 Ok(())
234 }
235
236 fn len(&self) -> usize {
237 self.length
238 }
239}
240
241struct RealFftPlanF32 {
245 length: usize,
246}
247
248impl RealFftPlanF32 {
249 fn new(length: usize) -> Self {
250 Self { length }
251 }
252}
253
254impl RealToComplex<f32> for RealFftPlanF32 {
255 fn process(&self, input: &[f32], output: &mut [Complex<f32>]) -> FFTResult<()> {
256 if input.len() != self.length {
258 return Err(FFTError::ValueError(format!(
259 "Input length {} doesn't match plan length {}",
260 input.len(),
261 self.length
262 )));
263 }
264 if output.len() != self.output_len() {
265 return Err(FFTError::ValueError(format!(
266 "Output length {} doesn't match expected length {}",
267 output.len(),
268 self.output_len()
269 )));
270 }
271
272 #[cfg(feature = "oxifft")]
273 {
274 let input_oxi: Vec<OxiComplex<f64>> = input
276 .iter()
277 .map(|&x| OxiComplex::new(x as f64, 0.0))
278 .collect();
279 let mut output_oxi: Vec<OxiComplex<f64>> = vec![OxiComplex::new(0.0, 0.0); self.length];
280
281 oxifft_plan_cache::execute_c2c(&input_oxi, &mut output_oxi, Direction::Forward)?;
282
283 let out_len = self.output_len();
285 for (i, dst) in output.iter_mut().enumerate().take(out_len) {
286 *dst = Complex::new(output_oxi[i].re as f32, output_oxi[i].im as f32);
287 }
288 }
289
290 #[cfg(not(feature = "oxifft"))]
291 {
292 for dst in output.iter_mut() {
293 *dst = Complex::new(0.0f32, 0.0f32);
294 }
295 }
296
297 Ok(())
298 }
299
300 fn len(&self) -> usize {
301 self.length
302 }
303}
304
305struct InverseRealFftPlanF32 {
307 length: usize,
308}
309
310impl InverseRealFftPlanF32 {
311 fn new(length: usize) -> Self {
312 Self { length }
313 }
314}
315
316impl ComplexToReal<f32> for InverseRealFftPlanF32 {
317 fn process(&self, input: &[Complex<f32>], output: &mut [f32]) -> FFTResult<()> {
318 if input.len() != self.input_len() {
320 return Err(FFTError::ValueError(format!(
321 "Input length {} doesn't match expected length {}",
322 input.len(),
323 self.input_len()
324 )));
325 }
326 if output.len() != self.length {
327 return Err(FFTError::ValueError(format!(
328 "Output length {} doesn't match plan length {}",
329 output.len(),
330 self.length
331 )));
332 }
333
334 #[cfg(feature = "oxifft")]
335 {
336 let mut buffer_oxi: Vec<OxiComplex<f64>> = Vec::with_capacity(self.length);
338
339 for &c in input.iter() {
340 buffer_oxi.push(OxiComplex::new(c.re as f64, c.im as f64));
341 }
342
343 let start_idx = if self.length % 2 == 0 {
344 input.len() - 1
345 } else {
346 input.len()
347 };
348
349 for i in (1..start_idx).rev() {
350 buffer_oxi.push(OxiComplex::new(input[i].re as f64, -(input[i].im as f64)));
351 }
352
353 while buffer_oxi.len() < self.length {
354 buffer_oxi.push(OxiComplex::new(0.0, 0.0));
355 }
356
357 let mut out_oxi: Vec<OxiComplex<f64>> = vec![OxiComplex::new(0.0, 0.0); self.length];
358
359 oxifft_plan_cache::execute_c2c(&buffer_oxi, &mut out_oxi, Direction::Backward)?;
360
361 let scale = 1.0 / self.length as f64;
363 for (i, dst) in output.iter_mut().enumerate() {
364 *dst = (out_oxi[i].re * scale) as f32;
365 }
366 }
367
368 #[cfg(not(feature = "oxifft"))]
369 {
370 for dst in output.iter_mut() {
371 *dst = 0.0f32;
372 }
373 }
374
375 Ok(())
376 }
377
378 fn len(&self) -> usize {
379 self.length
380 }
381}
382
383pub struct RealFftPlanner<T: Float> {
403 _phantom: std::marker::PhantomData<T>,
404}
405
406impl RealFftPlanner<f64> {
407 pub fn new() -> Self {
409 Self {
410 _phantom: std::marker::PhantomData,
411 }
412 }
413
414 pub fn plan_fft_forward(&mut self, length: usize) -> std::sync::Arc<dyn RealToComplex<f64>> {
424 std::sync::Arc::new(RealFftPlanF64::new(length))
425 }
426
427 pub fn plan_fft_inverse(&mut self, length: usize) -> std::sync::Arc<dyn ComplexToReal<f64>> {
437 std::sync::Arc::new(InverseRealFftPlanF64::new(length))
438 }
439}
440
441impl Default for RealFftPlanner<f64> {
442 fn default() -> Self {
443 Self::new()
444 }
445}
446
447impl RealFftPlanner<f32> {
448 pub fn new() -> Self {
450 Self {
451 _phantom: std::marker::PhantomData,
452 }
453 }
454
455 pub fn plan_fft_forward(&mut self, length: usize) -> std::sync::Arc<dyn RealToComplex<f32>> {
465 std::sync::Arc::new(RealFftPlanF32::new(length))
466 }
467
468 pub fn plan_fft_inverse(&mut self, length: usize) -> std::sync::Arc<dyn ComplexToReal<f32>> {
478 std::sync::Arc::new(InverseRealFftPlanF32::new(length))
479 }
480}
481
482impl Default for RealFftPlanner<f32> {
483 fn default() -> Self {
484 Self::new()
485 }
486}
487
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::numeric::Complex64;
    use std::f64::consts::PI;

    /// Forward + inverse `f64` transform of an 8-sample ramp: the DC bin must
    /// equal the sample sum and the round trip must reproduce the input.
    #[test]
    fn test_real_fft_planner_f64() {
        let mut planner = RealFftPlanner::<f64>::new();
        let fwd = planner.plan_fft_forward(8);
        let inv = planner.plan_fft_inverse(8);

        let signal = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mut bins = vec![Complex64::new(0.0, 0.0); 5];
        fwd.process(&signal, &mut bins).expect("Forward FFT failed");

        // Bin 0 is the plain sum of the samples.
        let total: f64 = signal.iter().sum();
        let dc = &bins[0];
        assert!((dc.re - total).abs() < 1e-10);
        assert!(dc.im.abs() < 1e-10);

        let mut restored = vec![0.0; 8];
        inv.process(&bins, &mut restored)
            .expect("Inverse FFT failed");

        for i in 0..signal.len() {
            let (orig, recov) = (signal[i], restored[i]);
            assert!(
                (orig - recov).abs() < 1e-10,
                "Mismatch at index {}: {} vs {}",
                i,
                orig,
                recov
            );
        }
    }

    /// Forward + inverse `f32` transform of an 8-sample ramp must reproduce
    /// the input within single-precision tolerance.
    #[test]
    fn test_real_fft_planner_f32() {
        let mut planner = RealFftPlanner::<f32>::new();
        let fwd = planner.plan_fft_forward(8);
        let inv = planner.plan_fft_inverse(8);

        let signal = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mut bins = vec![Complex::new(0.0f32, 0.0); 5];
        fwd.process(&signal, &mut bins).expect("Forward FFT failed");

        let mut restored = vec![0.0f32; 8];
        inv.process(&bins, &mut restored)
            .expect("Inverse FFT failed");

        for i in 0..signal.len() {
            let (orig, recov) = (signal[i], restored[i]);
            assert!(
                (orig - recov).abs() < 1e-5,
                "Mismatch at index {}: {} vs {}",
                i,
                orig,
                recov
            );
        }
    }

    /// A pure sine at bin 5 of a 128-sample frame must peak at index 5 of the
    /// magnitude spectrum.
    #[test]
    fn test_sine_wave_fft() {
        let length = 128;
        let freq_index = 5;

        let mut planner = RealFftPlanner::<f64>::new();
        let forward = planner.plan_fft_forward(length);

        let mut input: Vec<f64> = Vec::with_capacity(length);
        for i in 0..length {
            let phase = 2.0 * PI * freq_index as f64 * i as f64 / length as f64;
            input.push(phase.sin());
        }

        let mut spectrum = vec![Complex64::new(0.0, 0.0); length / 2 + 1];
        forward.process(&input, &mut spectrum).expect("FFT failed");

        // Locate the bin with the largest magnitude.
        let (peak_idx, peak_mag) = spectrum
            .iter()
            .map(|c| c.norm())
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).expect("Operation failed"))
            .expect("Operation failed");

        assert_eq!(
            peak_idx, freq_index,
            "Peak should be at frequency index {}",
            freq_index
        );
        assert!(peak_mag > length as f64 / 4.0, "Peak magnitude too small");
    }

    /// A 1024-sample forward plan reports its length, 513 output bins, and is
    /// not empty.
    #[test]
    fn test_plan_properties() {
        let mut planner = RealFftPlanner::<f64>::new();
        let plan = planner.plan_fft_forward(1024);

        assert_eq!(plan.len(), 1024);
        assert_eq!(plan.output_len(), 513);
        assert!(!plan.is_empty());
    }

    /// Exercises the plans the way a caller holding them as trait objects
    /// would: store both directions in a struct and round-trip a frame.
    #[test]
    fn test_voirs_usage_pattern() {
        struct AudioProcessor {
            forward: std::sync::Arc<dyn RealToComplex<f64>>,
            backward: std::sync::Arc<dyn ComplexToReal<f64>>,
        }

        impl AudioProcessor {
            fn new(size: usize) -> Self {
                let mut planner = RealFftPlanner::<f64>::new();
                let forward = planner.plan_fft_forward(size);
                let backward = planner.plan_fft_inverse(size);
                Self { forward, backward }
            }

            fn process(&self, input: &[f64]) -> Vec<f64> {
                let bin_count = self.forward.output_len();
                let mut spectrum = vec![Complex64::new(0.0, 0.0); bin_count];
                self.forward
                    .process(input, &mut spectrum)
                    .expect("Forward FFT failed");

                let sample_count = self.backward.len();
                let mut output = vec![0.0; sample_count];
                self.backward
                    .process(&spectrum, &mut output)
                    .expect("Inverse FFT failed");

                output
            }
        }

        let processor = AudioProcessor::new(16);
        let input = vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
        ];
        let output = processor.process(&input);

        for (i, (orig, recov)) in input.iter().zip(output.iter()).enumerate() {
            assert!((*orig - *recov).abs() < 1e-10, "Mismatch at {}", i);
        }
    }
}