1use serde::{Deserialize, Serialize};
39use std::f64::consts::PI;
40
41use scirs2_core::numeric::Complex64;
42
43use crate::bluestein;
44use crate::butterfly::{direct_dft, generate_twiddle_table};
45use crate::cache_oblivious::{cache_oblivious_fft, cache_oblivious_ifft};
46use crate::error::{FFTError, FFTResult};
47
48#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
54#[non_exhaustive]
55#[derive(Default)]
56pub enum FftAlgorithm {
57 CooleyTukey,
59 SplitRadix,
61 #[default]
63 CacheOblivious,
64 Bluestein,
66 Rader,
68}
69
/// Input parameters for [`create_plan`].
#[derive(Debug, Clone)]
pub struct FftPlanConfig {
    /// Transform length. `create_plan` returns an error when this is 0.
    pub size: usize,
    /// Algorithm family the plan should be built for.
    pub algorithm: FftAlgorithm,
    /// When `true`, `create_plan` computes twiddle factors up front and stores
    /// them in the plan node (for the node kinds that carry a twiddle table);
    /// when `false`, those tables are left empty.
    pub precompute_twiddles: bool,
}
80
81impl Default for FftPlanConfig {
82 fn default() -> Self {
83 Self {
84 size: 0,
85 algorithm: FftAlgorithm::default(),
86 precompute_twiddles: true,
87 }
88 }
89}
90
/// Serde-friendly stand-in for `Complex64`, used to store twiddle factors
/// inside a serialisable plan.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct SerComplex {
    /// Real part.
    re: f64,
    /// Imaginary part.
    im: f64,
}
97
98impl From<Complex64> for SerComplex {
99 fn from(c: Complex64) -> Self {
100 Self { re: c.re, im: c.im }
101 }
102}
103
104impl From<SerComplex> for Complex64 {
105 fn from(s: SerComplex) -> Self {
106 Complex64::new(s.re, s.im)
107 }
108}
109
/// Execution strategy stored at the root of an [`FftPlan`].
///
/// Every variant records the transform length `n` it was built for. Twiddle
/// tables are empty when the plan was created with precomputation disabled.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub enum PlanNode {
    /// Direct DFT evaluation; `create_plan` emits this for cache-oblivious
    /// plans of size 16 or less.
    DirectDft {
        /// Transform length.
        n: usize,
    },
    /// Iterative radix-2 Cooley-Tukey transform.
    CooleyTukey {
        /// Transform length (a power of two when built by `create_plan`).
        n: usize,
        /// Precomputed twiddle table, or empty to compute twiddles on the fly.
        twiddles: Vec<SerComplex>,
    },
    /// Four-step decomposition of a composite length `n = n1 * n2`.
    ///
    /// NOTE(review): the executor in this file dispatches this node to
    /// `cache_oblivious_fft` and does not read `n1`, `n2` or `twiddles`.
    FourStep {
        /// Transform length.
        n: usize,
        /// First factor of `n`.
        n1: usize,
        /// Second factor of `n`.
        n2: usize,
        /// Table of `n1 * n2` twiddles, or empty when not precomputed.
        twiddles: Vec<SerComplex>,
    },
    /// Bluestein transform for arbitrary lengths.
    Bluestein {
        /// Transform length.
        n: usize,
    },
    /// Split-radix transform.
    ///
    /// NOTE(review): the executor in this file does not read `twiddles`.
    SplitRadix {
        /// Transform length (a power of two when built by `create_plan`).
        n: usize,
        /// Precomputed twiddle table, or empty when not precomputed.
        twiddles: Vec<SerComplex>,
    },
    /// Rader transform for prime lengths.
    ///
    /// NOTE(review): the executor in this file currently runs this node
    /// through the Bluestein routine.
    Rader {
        /// Transform length (prime when built by `create_plan`).
        n: usize,
    },
}
155
/// A reusable, serialisable FFT plan produced by [`create_plan`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FftPlan {
    /// Transform length; `execute_plan` rejects data of any other length.
    pub size: usize,
    /// Algorithm family that was requested when the plan was created.
    pub algorithm: FftAlgorithm,
    /// Strategy that `execute_plan` runs.
    pub root: PlanNode,
    /// Copy of the `precompute_twiddles` flag from the originating config.
    /// It is set even for node kinds that carry no twiddle table.
    pub precomputed: bool,
}
168
169pub fn create_plan(config: &FftPlanConfig) -> FFTResult<FftPlan> {
183 let n = config.size;
184 if n == 0 {
185 return Err(FFTError::PlanError("create_plan: size must be >= 1".into()));
186 }
187
188 let root = match config.algorithm {
189 FftAlgorithm::CooleyTukey => {
190 if !n.is_power_of_two() {
191 return Err(FFTError::PlanError(
192 "CooleyTukey requires power-of-two size".into(),
193 ));
194 }
195 let twiddles = if config.precompute_twiddles {
196 generate_twiddle_table(n)?
197 .into_iter()
198 .map(SerComplex::from)
199 .collect()
200 } else {
201 Vec::new()
202 };
203 PlanNode::CooleyTukey { n, twiddles }
204 }
205 FftAlgorithm::SplitRadix => {
206 if !n.is_power_of_two() {
207 return Err(FFTError::PlanError(
208 "SplitRadix requires power-of-two size".into(),
209 ));
210 }
211 let twiddles = if config.precompute_twiddles {
212 generate_twiddle_table(n)?
213 .into_iter()
214 .map(SerComplex::from)
215 .collect()
216 } else {
217 Vec::new()
218 };
219 PlanNode::SplitRadix { n, twiddles }
220 }
221 FftAlgorithm::CacheOblivious => {
222 if n <= 16 {
223 PlanNode::DirectDft { n }
224 } else {
225 let (n1, n2) = find_factor_pair(n);
227 let twiddles = if config.precompute_twiddles {
228 compute_four_step_twiddles(n, n1, n2)
229 .into_iter()
230 .map(SerComplex::from)
231 .collect()
232 } else {
233 Vec::new()
234 };
235 if n1 == 1 || n2 == 1 {
236 PlanNode::Bluestein { n }
238 } else {
239 PlanNode::FourStep {
240 n,
241 n1,
242 n2,
243 twiddles,
244 }
245 }
246 }
247 }
248 FftAlgorithm::Bluestein => PlanNode::Bluestein { n },
249 FftAlgorithm::Rader => {
250 if !is_prime(n) {
251 return Err(FFTError::PlanError(
252 "Rader algorithm requires prime size".into(),
253 ));
254 }
255 PlanNode::Rader { n }
256 }
257 };
258
259 Ok(FftPlan {
260 size: n,
261 algorithm: config.algorithm,
262 root,
263 precomputed: config.precompute_twiddles,
264 })
265}
266
267pub fn execute_plan(plan: &FftPlan, data: &mut [Complex64]) -> FFTResult<()> {
279 if data.len() != plan.size {
280 return Err(FFTError::DimensionError(format!(
281 "execute_plan: expected data of length {}, got {}",
282 plan.size,
283 data.len()
284 )));
285 }
286
287 let result = execute_node(&plan.root, data)?;
288 data.copy_from_slice(&result);
289 Ok(())
290}
291
292fn execute_node(node: &PlanNode, data: &[Complex64]) -> FFTResult<Vec<Complex64>> {
294 match node {
295 PlanNode::DirectDft { n } => {
296 debug_assert_eq!(*n, data.len());
297 if *n <= 1 {
298 Ok(data.to_vec())
299 } else {
300 direct_dft(data)
301 }
302 }
303 PlanNode::CooleyTukey { n, twiddles } => {
304 debug_assert_eq!(*n, data.len());
305 execute_cooley_tukey(data, *n, twiddles)
306 }
307 PlanNode::SplitRadix { n, .. } => {
308 debug_assert_eq!(*n, data.len());
309 let mut buf = data.to_vec();
311 crate::butterfly::split_radix_butterfly(&mut buf)?;
312 Ok(buf)
313 }
314 PlanNode::FourStep { n, n1, n2, .. } => {
315 debug_assert_eq!(*n, data.len());
316 cache_oblivious_fft(data).or_else(|_| {
317 let _ = (n1, n2); direct_dft(data)
320 })
321 }
322 PlanNode::Bluestein { n } => {
323 debug_assert_eq!(*n, data.len());
324 bluestein::bluestein_fft(data)
325 }
326 PlanNode::Rader { n } => {
327 debug_assert_eq!(*n, data.len());
328 bluestein::bluestein_fft(data)
330 }
331 }
332}
333
334fn execute_cooley_tukey(
336 data: &[Complex64],
337 n: usize,
338 precomputed_twiddles: &[SerComplex],
339) -> FFTResult<Vec<Complex64>> {
340 if n == 1 {
341 return Ok(data.to_vec());
342 }
343
344 let mut output = bit_reverse_copy(data, n);
346
347 let mut size = 2;
349 while size <= n {
350 let half = size / 2;
351 let step = n / size;
352
353 for k in 0..half {
354 let twiddle = if !precomputed_twiddles.is_empty() {
355 let idx = k * step;
356 if idx < precomputed_twiddles.len() {
357 Complex64::from(precomputed_twiddles[idx])
358 } else {
359 let angle = -2.0 * PI * (k * step) as f64 / n as f64;
360 Complex64::new(angle.cos(), angle.sin())
361 }
362 } else {
363 let angle = -2.0 * PI * (k * step) as f64 / n as f64;
364 Complex64::new(angle.cos(), angle.sin())
365 };
366
367 let mut j = k;
368 while j < n {
369 let u = output[j];
370 let t = twiddle * output[j + half];
371 output[j] = u + t;
372 output[j + half] = u - t;
373 j += size;
374 }
375 }
376 size *= 2;
377 }
378
379 Ok(output)
380}
381
382fn bit_reverse_copy(data: &[Complex64], n: usize) -> Vec<Complex64> {
384 let bits = n.trailing_zeros();
385 let mut out = vec![Complex64::new(0.0, 0.0); n];
386 for i in 0..n {
387 let rev = reverse_bits(i, bits);
388 out[rev] = data[i];
389 }
390 out
391}
392
/// Reverses the lowest `bits` bits of `x`; any higher bits of `x` are ignored.
fn reverse_bits(x: usize, bits: u32) -> usize {
    // Peel bits off the low end of `rest` and push them onto the low end of
    // `acc`, so the first bit taken ends up in the highest position.
    let (reversed, _) = (0..bits).fold((0usize, x), |(acc, rest), _| {
        ((acc << 1) | (rest & 1), rest >> 1)
    });
    reversed
}
403
404pub fn serialize_plan(plan: &FftPlan) -> FFTResult<Vec<u8>> {
414 serde_json::to_vec(plan).map_err(|e| FFTError::IOError(format!("serialize_plan: {e}")))
415}
416
417pub fn deserialize_plan(data: &[u8]) -> FFTResult<FftPlan> {
423 serde_json::from_slice(data).map_err(|e| FFTError::IOError(format!("deserialize_plan: {e}")))
424}
425
/// Splits `n` into `(a, b)` with `a * b == n`, choosing the largest divisor
/// `a` in `2..=floor(sqrt(n))`.
///
/// Returns `(1, n)` when no such divisor exists, which covers primes as well
/// as 0 and 1 (for those the search range is empty).
fn find_factor_pair(n: usize) -> (usize, usize) {
    let upper = (n as f64).sqrt() as usize;
    (2..=upper)
        .rev()
        .find(|divisor| n % divisor == 0)
        .map_or((1, n), |divisor| (divisor, n / divisor))
}
444
445fn compute_four_step_twiddles(n: usize, n1: usize, n2: usize) -> Vec<Complex64> {
447 let angle_base = -2.0 * PI / n as f64;
448 let mut twiddles = Vec::with_capacity(n1 * n2);
449 for i in 0..n1 {
450 for j in 0..n2 {
451 let angle = angle_base * (i * j) as f64;
452 twiddles.push(Complex64::new(angle.cos(), angle.sin()));
453 }
454 }
455 twiddles
456}
457
/// Trial-division primality test using the 6k +/- 1 wheel.
///
/// Returns `false` for 0 and 1.
fn is_prime(n: usize) -> bool {
    if n < 2 {
        return false;
    }
    if n < 4 {
        // 2 and 3
        return true;
    }
    if n % 2 == 0 || n % 3 == 0 {
        return false;
    }
    // Every remaining prime candidate has the form 6k - 1 or 6k + 1.
    let mut i: usize = 5;
    // `i <= n / i` is equivalent to `i * i <= n` for positive integers, but
    // cannot overflow when `n` is close to `usize::MAX`.
    while i <= n / i {
        if n % i == 0 || n % (i + 2) == 0 {
            return false;
        }
        i += 6;
    }
    true
}
478
#[cfg(test)]
mod tests {
    use super::*;

    /// Largest element-wise distance between two complex sequences.
    fn max_abs_err(a: &[Complex64], b: &[Complex64]) -> f64 {
        a.iter()
            .zip(b.iter())
            .map(|(x, y)| (x - y).norm())
            .fold(0.0_f64, f64::max)
    }

    /// Reference transform used to validate plan execution.
    fn reference_dft(data: &[Complex64]) -> Vec<Complex64> {
        direct_dft(data).expect("direct_dft failed")
    }

    // ---- plan creation ----

    #[test]
    fn test_create_plan_cooley_tukey() {
        let config = FftPlanConfig {
            size: 16,
            algorithm: FftAlgorithm::CooleyTukey,
            precompute_twiddles: true,
        };
        let plan = create_plan(&config).expect("create_plan failed");
        assert_eq!(plan.size, 16);
        assert_eq!(plan.algorithm, FftAlgorithm::CooleyTukey);
        assert!(plan.precomputed);
    }

    #[test]
    fn test_create_plan_cooley_tukey_non_pow2_fails() {
        let config = FftPlanConfig {
            size: 12,
            algorithm: FftAlgorithm::CooleyTukey,
            precompute_twiddles: false,
        };
        assert!(create_plan(&config).is_err());
    }

    #[test]
    fn test_create_plan_cache_oblivious() {
        for &n in &[8, 16, 32, 64, 12, 15, 24] {
            let config = FftPlanConfig {
                size: n,
                algorithm: FftAlgorithm::CacheOblivious,
                precompute_twiddles: true,
            };
            let plan = create_plan(&config).expect("create_plan failed");
            assert_eq!(plan.size, n);
        }
    }

    /// Prime sizes above the direct-DFT threshold have no non-trivial factor
    /// pair, so the cache-oblivious planner must hand them to Bluestein.
    #[test]
    fn test_create_plan_cache_oblivious_prime_uses_bluestein() {
        for &n in &[17usize, 31, 97] {
            let config = FftPlanConfig {
                size: n,
                algorithm: FftAlgorithm::CacheOblivious,
                precompute_twiddles: true,
            };
            let plan = create_plan(&config).expect("create_plan failed");
            assert!(
                matches!(plan.root, PlanNode::Bluestein { n: m } if m == n),
                "n={n}: unexpected root {:?}",
                plan.root
            );
        }
    }

    #[test]
    fn test_create_plan_bluestein() {
        for &n in &[7, 11, 13, 17, 100] {
            let config = FftPlanConfig {
                size: n,
                algorithm: FftAlgorithm::Bluestein,
                precompute_twiddles: false,
            };
            let plan = create_plan(&config).expect("create_plan failed");
            assert_eq!(plan.size, n);
        }
    }

    #[test]
    fn test_create_plan_rader_prime() {
        let config = FftPlanConfig {
            size: 7,
            algorithm: FftAlgorithm::Rader,
            precompute_twiddles: false,
        };
        let plan = create_plan(&config).expect("create_plan failed");
        assert_eq!(plan.size, 7);
    }

    #[test]
    fn test_create_plan_rader_non_prime_fails() {
        let config = FftPlanConfig {
            size: 12,
            algorithm: FftAlgorithm::Rader,
            precompute_twiddles: false,
        };
        assert!(create_plan(&config).is_err());
    }

    #[test]
    fn test_create_plan_zero_size_fails() {
        let config = FftPlanConfig {
            size: 0,
            algorithm: FftAlgorithm::CacheOblivious,
            precompute_twiddles: false,
        };
        assert!(create_plan(&config).is_err());
    }

    // ---- plan execution ----

    #[test]
    fn test_execute_plan_cooley_tukey() {
        let config = FftPlanConfig {
            size: 8,
            algorithm: FftAlgorithm::CooleyTukey,
            precompute_twiddles: true,
        };
        let plan = create_plan(&config).expect("create_plan failed");
        let input: Vec<Complex64> = (0..8).map(|k| Complex64::new(k as f64, 0.0)).collect();
        let expected = reference_dft(&input);

        let mut data = input;
        execute_plan(&plan, &mut data).expect("execute_plan failed");
        let err = max_abs_err(&data, &expected);
        assert!(err < 1e-10, "CT execution error = {err}");
    }

    #[test]
    fn test_execute_plan_bluestein() {
        let config = FftPlanConfig {
            size: 7,
            algorithm: FftAlgorithm::Bluestein,
            precompute_twiddles: false,
        };
        let plan = create_plan(&config).expect("create_plan failed");
        let input: Vec<Complex64> = (0..7).map(|k| Complex64::new(k as f64, 0.0)).collect();
        let expected = reference_dft(&input);

        let mut data = input;
        execute_plan(&plan, &mut data).expect("execute_plan failed");
        let err = max_abs_err(&data, &expected);
        assert!(err < 1e-10, "Bluestein execution error = {err}");
    }

    #[test]
    fn test_execute_plan_wrong_size() {
        let config = FftPlanConfig {
            size: 8,
            algorithm: FftAlgorithm::CooleyTukey,
            precompute_twiddles: false,
        };
        let plan = create_plan(&config).expect("create_plan failed");
        let mut data = vec![Complex64::new(1.0, 0.0); 16];
        assert!(execute_plan(&plan, &mut data).is_err());
    }

    #[test]
    fn test_execute_plan_matches_direct_fft() {
        let test_cases = vec![
            (8, FftAlgorithm::CooleyTukey),
            (16, FftAlgorithm::CooleyTukey),
            (8, FftAlgorithm::SplitRadix),
            (7, FftAlgorithm::Bluestein),
            (13, FftAlgorithm::Bluestein),
        ];
        for (n, algo) in test_cases {
            let config = FftPlanConfig {
                size: n,
                algorithm: algo,
                precompute_twiddles: true,
            };
            let plan = create_plan(&config).expect("plan creation failed");
            let input: Vec<Complex64> = (0..n)
                .map(|k| Complex64::new((k as f64 * 0.5).sin(), (k as f64 * 0.3).cos()))
                .collect();
            let expected = reference_dft(&input);
            let mut data = input;
            execute_plan(&plan, &mut data).expect("execution failed");
            let err = max_abs_err(&data, &expected);
            assert!(err < 1e-8, "{algo:?} n={n}: error = {err}");
        }
    }

    // ---- serialisation ----

    #[test]
    fn test_serialize_deserialize_roundtrip() {
        let config = FftPlanConfig {
            size: 16,
            algorithm: FftAlgorithm::CooleyTukey,
            precompute_twiddles: true,
        };
        let plan = create_plan(&config).expect("create_plan failed");
        let bytes = serialize_plan(&plan).expect("serialize failed");
        assert!(!bytes.is_empty());

        let plan2 = deserialize_plan(&bytes).expect("deserialize failed");
        assert_eq!(plan.size, plan2.size);
        assert_eq!(plan.algorithm, plan2.algorithm);
        assert_eq!(plan.precomputed, plan2.precomputed);
    }

    #[test]
    fn test_serialized_plan_produces_same_result() {
        let config = FftPlanConfig {
            size: 8,
            algorithm: FftAlgorithm::CooleyTukey,
            precompute_twiddles: true,
        };
        let plan = create_plan(&config).expect("create_plan failed");
        let bytes = serialize_plan(&plan).expect("serialize failed");
        let plan2 = deserialize_plan(&bytes).expect("deserialize failed");

        let input: Vec<Complex64> = (0..8).map(|k| Complex64::new(k as f64, 0.0)).collect();

        let mut data1 = input.clone();
        let mut data2 = input;
        execute_plan(&plan, &mut data1).expect("exec1 failed");
        execute_plan(&plan2, &mut data2).expect("exec2 failed");

        let err = max_abs_err(&data1, &data2);
        assert!(err < 1e-14, "serialized plan diverges: {err}");
    }

    #[test]
    fn test_deserialize_invalid_data() {
        assert!(deserialize_plan(b"not json").is_err());
    }

    // ---- plan reuse ----

    #[test]
    fn test_plan_reuse_same_result() {
        let config = FftPlanConfig {
            size: 16,
            algorithm: FftAlgorithm::CooleyTukey,
            precompute_twiddles: true,
        };
        let plan = create_plan(&config).expect("create_plan failed");

        let input1: Vec<Complex64> = (0..16).map(|k| Complex64::new(k as f64, 0.0)).collect();
        let input2: Vec<Complex64> = (0..16).map(|k| Complex64::new(0.0, k as f64)).collect();

        let mut data1 = input1.clone();
        let mut data2 = input2.clone();
        execute_plan(&plan, &mut data1).expect("exec1 failed");
        execute_plan(&plan, &mut data2).expect("exec2 failed");

        let expected1 = reference_dft(&input1);
        let expected2 = reference_dft(&input2);

        assert!(max_abs_err(&data1, &expected1) < 1e-10);
        assert!(max_abs_err(&data2, &expected2) < 1e-10);
    }

    // ---- helpers ----

    #[test]
    fn test_is_prime() {
        assert!(!is_prime(0));
        assert!(!is_prime(1));
        assert!(is_prime(2));
        assert!(is_prime(3));
        assert!(!is_prime(4));
        assert!(is_prime(5));
        assert!(is_prime(7));
        assert!(!is_prime(9));
        assert!(is_prime(11));
        assert!(is_prime(13));
        assert!(!is_prime(15));
        assert!(is_prime(17));
    }

    #[test]
    fn test_find_factor_pair() {
        let (a, b) = find_factor_pair(12);
        assert_eq!(a * b, 12);
        assert!(a >= 2);

        let (a, b) = find_factor_pair(16);
        assert_eq!(a * b, 16);
        assert!(a >= 2);

        // Primes have no non-trivial factor pair.
        let (a, b) = find_factor_pair(13);
        assert_eq!(a, 1);
        assert_eq!(b, 13);
    }

    #[test]
    fn test_bit_reverse() {
        assert_eq!(reverse_bits(0b000, 3), 0b000);
        assert_eq!(reverse_bits(0b001, 3), 0b100);
        assert_eq!(reverse_bits(0b010, 3), 0b010);
        assert_eq!(reverse_bits(0b011, 3), 0b110);
        assert_eq!(reverse_bits(0b100, 3), 0b001);
    }
}