1use scirs2_core::numeric::Complex64;
20use std::f64::consts::PI;
21
22use crate::error::{FFTError, FFTResult};
23
24pub fn generate_twiddle_table(n: usize) -> FFTResult<Vec<Complex64>> {
36 if n == 0 {
37 return Err(FFTError::ValueError(
38 "generate_twiddle_table: n must be > 0".into(),
39 ));
40 }
41 if n == 1 {
42 return Ok(vec![Complex64::new(1.0, 0.0)]);
43 }
44 let inv_n = -2.0 * PI / n as f64;
45 Ok((0..n)
46 .map(|k| {
47 let angle = inv_n * k as f64;
48 Complex64::new(angle.cos(), angle.sin())
49 })
50 .collect())
51}
52
53pub fn generate_inverse_twiddle_table(n: usize) -> FFTResult<Vec<Complex64>> {
61 if n == 0 {
62 return Err(FFTError::ValueError(
63 "generate_inverse_twiddle_table: n must be > 0".into(),
64 ));
65 }
66 if n == 1 {
67 return Ok(vec![Complex64::new(1.0, 0.0)]);
68 }
69 let inv_n = 2.0 * PI / n as f64;
70 Ok((0..n)
71 .map(|k| {
72 let angle = inv_n * k as f64;
73 Complex64::new(angle.cos(), angle.sin())
74 })
75 .collect())
76}
77
78#[inline(always)]
93pub fn butterfly2(a: &mut Complex64, b: &mut Complex64, twiddle: Complex64) {
94 let t = twiddle * *b;
95 let new_a = *a + t;
96 let new_b = *a - t;
97 *a = new_a;
98 *b = new_b;
99}
100
101#[inline]
121pub fn butterfly4(a: &mut [Complex64; 4], twiddles: &[Complex64; 3]) {
122 let x0 = a[0];
131 let x1 = a[1];
132 let x2 = a[2];
133 let x3 = a[3];
134
135 a[0] = x0 + x1 + x2 + x3;
137
138 a[1] = x0 + twiddles[0] * x1 + twiddles[1] * x2 + twiddles[2] * x3;
140
141 let w2 = twiddles[1]; let w4 = w2 * w2; let w6 = w4 * w2; a[2] = x0 + w2 * x1 + w4 * x2 + w6 * x3;
147
148 let w3 = twiddles[2]; let w9 = w3 * w3 * w3; a[3] = x0 + w3 * x1 + w6 * x2 + w9 * x3;
152}
153
154#[inline]
163pub fn butterfly8(a: &mut [Complex64; 8], twiddles: &[Complex64; 7]) {
164 let w = [
167 Complex64::new(1.0, 0.0), twiddles[0], twiddles[1], twiddles[2], twiddles[3], twiddles[4], twiddles[5], twiddles[6], ];
176
177 let input = *a;
178 for k in 0..8 {
179 let mut sum = Complex64::new(0.0, 0.0);
180 for n in 0..8 {
181 let idx = (n * k) % 8;
182 sum += input[n] * w[idx];
183 }
184 a[k] = sum;
185 }
186}
187
188pub fn split_radix_butterfly(data: &mut [Complex64]) -> FFTResult<()> {
202 let n = data.len();
203 if n < 4 {
204 return Err(FFTError::ValueError(
205 "split_radix_butterfly: length must be >= 4".into(),
206 ));
207 }
208 if !n.is_power_of_two() {
209 return Err(FFTError::ValueError(
210 "split_radix_butterfly: length must be a power of two".into(),
211 ));
212 }
213
214 let bits = n.trailing_zeros();
216 for i in 0..n {
217 let j = reverse_bits(i, bits);
218 if i < j {
219 data.swap(i, j);
220 }
221 }
222
223 let mut size = 2;
225 while size <= n {
226 let half = size / 2;
227 let angle_step = -2.0 * PI / size as f64;
228
229 let mut group_start = 0;
230 while group_start < n {
231 for k in 0..half {
232 let angle = angle_step * k as f64;
233 let twiddle = Complex64::new(angle.cos(), angle.sin());
234
235 let i = group_start + k;
236 let j = i + half;
237
238 let t = twiddle * data[j];
239 data[j] = data[i] - t;
240 data[i] = data[i] + t;
241 }
242 group_start += size;
243 }
244 size *= 2;
245 }
246
247 Ok(())
248}
249
/// Reverses the low `bits` bits of `x`. Bits of `x` at or above position
/// `bits` are ignored.
///
/// For example, `reverse_bits(0b0011, 4) == 0b1100`. Returns `0` when
/// `bits == 0`. `bits` must not exceed `usize::BITS`.
fn reverse_bits(x: usize, bits: u32) -> usize {
    debug_assert!(bits <= usize::BITS, "reverse_bits: bits out of range");
    // Reverse the whole word with std, then shift the reversed low bits back
    // down. `checked_shr` covers `bits == 0`, where the shift would equal the
    // word width and a plain `>>` would overflow.
    x.reverse_bits()
        .checked_shr(usize::BITS - bits)
        .unwrap_or(0)
}
260
261pub fn direct_dft(data: &[Complex64]) -> FFTResult<Vec<Complex64>> {
274 let n = data.len();
275 if n == 0 {
276 return Err(FFTError::ValueError("direct_dft: empty input".into()));
277 }
278 if n == 1 {
279 return Ok(data.to_vec());
280 }
281
282 let angle_base = -2.0 * PI / n as f64;
283 let mut result = vec![Complex64::new(0.0, 0.0); n];
284 for k in 0..n {
285 let mut sum = Complex64::new(0.0, 0.0);
286 for j in 0..n {
287 let angle = angle_base * (k * j) as f64;
288 let w = Complex64::new(angle.cos(), angle.sin());
289 sum += data[j] * w;
290 }
291 result[k] = sum;
292 }
293 Ok(result)
294}
295
296pub fn direct_idft(data: &[Complex64]) -> FFTResult<Vec<Complex64>> {
302 let n = data.len();
303 if n == 0 {
304 return Err(FFTError::ValueError("direct_idft: empty input".into()));
305 }
306 if n == 1 {
307 return Ok(data.to_vec());
308 }
309
310 let angle_base = 2.0 * PI / n as f64;
311 let inv_n = 1.0 / n as f64;
312 let mut result = vec![Complex64::new(0.0, 0.0); n];
313 for k in 0..n {
314 let mut sum = Complex64::new(0.0, 0.0);
315 for j in 0..n {
316 let angle = angle_base * (k * j) as f64;
317 let w = Complex64::new(angle.cos(), angle.sin());
318 sum += data[j] * w;
319 }
320 result[k] = sum * inv_n;
321 }
322 Ok(result)
323}
324
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Maximum element-wise distance `|a[i] - b[i]|` between two sequences.
    ///
    /// Panics on a length mismatch. Otherwise `zip` would stop at the shorter
    /// slice and silently ignore the tail.
    fn max_abs_err(a: &[Complex64], b: &[Complex64]) -> f64 {
        assert_eq!(a.len(), b.len(), "max_abs_err: length mismatch");
        a.iter()
            .zip(b.iter())
            .map(|(x, y)| (x - y).norm())
            .fold(0.0_f64, f64::max)
    }

    #[test]
    fn test_twiddle_table_size_1() {
        let tw = generate_twiddle_table(1).expect("should succeed");
        assert_eq!(tw.len(), 1);
        assert_relative_eq!(tw[0].re, 1.0, epsilon = 1e-15);
        assert_relative_eq!(tw[0].im, 0.0, epsilon = 1e-15);
    }

    #[test]
    fn test_twiddle_table_values() {
        let n = 8;
        let tw = generate_twiddle_table(n).expect("should succeed");
        assert_eq!(tw.len(), n);

        // W^0 = 1
        assert_relative_eq!(tw[0].re, 1.0, epsilon = 1e-14);
        assert_relative_eq!(tw[0].im, 0.0, epsilon = 1e-14);

        // W^(n/4) = -i
        assert_relative_eq!(tw[n / 4].re, 0.0, epsilon = 1e-14);
        assert_relative_eq!(tw[n / 4].im, -1.0, epsilon = 1e-14);

        // W^(n/2) = -1
        assert_relative_eq!(tw[n / 2].re, -1.0, epsilon = 1e-14);
        assert_relative_eq!(tw[n / 2].im, 0.0, epsilon = 1e-14);

        // Every twiddle lies on the unit circle.
        for w in &tw {
            assert_relative_eq!(w.norm(), 1.0, epsilon = 1e-14);
        }
    }

    #[test]
    fn test_twiddle_table_error_on_zero() {
        assert!(generate_twiddle_table(0).is_err());
    }

    #[test]
    fn test_inverse_twiddle_table_error_on_zero() {
        assert!(generate_inverse_twiddle_table(0).is_err());
    }

    #[test]
    fn test_butterfly2_trivial_twiddle() {
        let mut a = Complex64::new(3.0, 0.0);
        let mut b = Complex64::new(1.0, 0.0);
        butterfly2(&mut a, &mut b, Complex64::new(1.0, 0.0));
        assert_relative_eq!(a.re, 4.0, epsilon = 1e-14);
        assert_relative_eq!(b.re, 2.0, epsilon = 1e-14);
    }

    #[test]
    fn test_butterfly2_with_twiddle() {
        let mut a = Complex64::new(5.0, 0.0);
        let mut b = Complex64::new(3.0, 0.0);
        butterfly2(&mut a, &mut b, Complex64::new(-1.0, 0.0));
        assert_relative_eq!(a.re, 2.0, epsilon = 1e-14);
        assert_relative_eq!(b.re, 8.0, epsilon = 1e-14);
    }

    #[test]
    fn test_butterfly2_complex_twiddle() {
        // t = (-i) * 2 = -2i, so a' = (1+i) - 2i and b' = (1+i) + 2i.
        let mut a = Complex64::new(1.0, 1.0);
        let mut b = Complex64::new(2.0, 0.0);
        butterfly2(&mut a, &mut b, Complex64::new(0.0, -1.0));
        assert_relative_eq!(a.re, 1.0, epsilon = 1e-14);
        assert_relative_eq!(a.im, -1.0, epsilon = 1e-14);
        assert_relative_eq!(b.re, 1.0, epsilon = 1e-14);
        assert_relative_eq!(b.im, 3.0, epsilon = 1e-14);
    }

    #[test]
    fn test_butterfly4_matches_direct_dft() {
        let input = [
            Complex64::new(1.0, 0.0),
            Complex64::new(2.0, 0.0),
            Complex64::new(3.0, 0.0),
            Complex64::new(4.0, 0.0),
        ];
        let expected = direct_dft(&input).expect("direct_dft failed");

        // [W, W^2, W^3] with W = exp(-2πi/4) = -i.
        let twiddles = [
            Complex64::new(0.0, -1.0),
            Complex64::new(-1.0, 0.0),
            Complex64::new(0.0, 1.0),
        ];
        let mut data = input;
        butterfly4(&mut data, &twiddles);

        let err = max_abs_err(&data, &expected);
        assert!(err < 1e-12, "butterfly4 error = {err}");
    }

    #[test]
    fn test_butterfly4_inverse_twiddles_match_scaled_idft() {
        let input = [
            Complex64::new(1.0, -1.0),
            Complex64::new(0.5, 2.0),
            Complex64::new(-3.0, 0.25),
            Complex64::new(2.0, 1.0),
        ];
        // With W = +i, butterfly4 is the unnormalized inverse, i.e. n * IDFT.
        let expected: Vec<Complex64> = direct_idft(&input)
            .expect("direct_idft failed")
            .into_iter()
            .map(|z| z * 4.0)
            .collect();

        let twiddles = [
            Complex64::new(0.0, 1.0),
            Complex64::new(-1.0, 0.0),
            Complex64::new(0.0, -1.0),
        ];
        let mut data = input;
        butterfly4(&mut data, &twiddles);

        let err = max_abs_err(&data, &expected);
        assert!(err < 1e-12, "butterfly4 inverse error = {err}");
    }

    #[test]
    fn test_butterfly8_matches_direct_dft() {
        let input: [Complex64; 8] = [
            Complex64::new(1.0, 0.0),
            Complex64::new(2.0, -1.0),
            Complex64::new(0.5, 0.5),
            Complex64::new(3.0, 0.0),
            Complex64::new(-1.0, 1.0),
            Complex64::new(0.0, 2.0),
            Complex64::new(1.5, -0.5),
            Complex64::new(-0.5, 0.0),
        ];
        let expected = direct_dft(&input).expect("direct_dft failed");

        // twiddles[p - 1] = W^p with W = exp(-2πi/8).
        let twiddles: [Complex64; 7] = std::array::from_fn(|k| {
            let angle = -2.0 * PI * (k + 1) as f64 / 8.0;
            Complex64::new(angle.cos(), angle.sin())
        });

        let mut data = input;
        butterfly8(&mut data, &twiddles);

        let err = max_abs_err(&data, &expected);
        assert!(err < 1e-10, "butterfly8 error = {err}");
    }

    #[test]
    fn test_direct_dft_known_result() {
        // The DFT of a constant signal is a single spike at DC.
        let input = vec![Complex64::new(1.0, 0.0); 4];
        let result = direct_dft(&input).expect("direct_dft failed");
        assert_relative_eq!(result[0].re, 4.0, epsilon = 1e-12);
        for k in 1..4 {
            assert!(result[k].norm() < 1e-12, "non-zero at k={k}");
        }
    }

    #[test]
    fn test_direct_dft_single_element_passthrough() {
        let input = [Complex64::new(2.5, -1.0)];
        let out = direct_dft(&input).expect("direct_dft failed");
        assert_eq!(out, input.to_vec());
    }

    #[test]
    fn test_direct_dft_idft_roundtrip() {
        let input = vec![
            Complex64::new(1.0, 2.0),
            Complex64::new(3.0, -1.0),
            Complex64::new(0.5, 0.5),
            Complex64::new(-2.0, 1.5),
        ];
        let spectrum = direct_dft(&input).expect("dft failed");
        let recovered = direct_idft(&spectrum).expect("idft failed");
        let err = max_abs_err(&input, &recovered);
        assert!(err < 1e-12, "roundtrip error = {err}");
    }

    #[test]
    fn test_direct_dft_empty() {
        assert!(direct_dft(&[]).is_err());
    }

    #[test]
    fn test_direct_idft_empty() {
        assert!(direct_idft(&[]).is_err());
    }

    #[test]
    fn test_split_radix_butterfly_size_4() {
        let input = vec![
            Complex64::new(1.0, 0.0),
            Complex64::new(0.0, 1.0),
            Complex64::new(-1.0, 0.0),
            Complex64::new(0.0, -1.0),
        ];
        let expected = direct_dft(&input).expect("dft failed");
        let mut data = input;
        split_radix_butterfly(&mut data).expect("split_radix failed");
        let err = max_abs_err(&data, &expected);
        assert!(err < 1e-10, "split_radix error (n=4) = {err}");
    }

    #[test]
    fn test_split_radix_butterfly_size_8() {
        let input: Vec<Complex64> = (0..8)
            .map(|k| Complex64::new(k as f64, -(k as f64) * 0.5))
            .collect();
        let expected = direct_dft(&input).expect("dft failed");
        let mut data = input;
        split_radix_butterfly(&mut data).expect("split_radix failed");
        let err = max_abs_err(&data, &expected);
        assert!(err < 1e-10, "split_radix error (n=8) = {err}");
    }

    #[test]
    fn test_split_radix_butterfly_size_16() {
        let input: Vec<Complex64> = (0..16)
            .map(|k| Complex64::new((k as f64 * 0.5).sin(), (k as f64 * 0.3).cos()))
            .collect();
        let expected = direct_dft(&input).expect("dft failed");
        let mut data = input;
        split_radix_butterfly(&mut data).expect("split_radix failed");
        let err = max_abs_err(&data, &expected);
        assert!(err < 1e-10, "split_radix error (n=16) = {err}");
    }

    #[test]
    fn test_split_radix_butterfly_size_64() {
        let input: Vec<Complex64> = (0..64)
            .map(|k| {
                let t = k as f64;
                Complex64::new((0.37 * t).sin(), (0.11 * t).cos())
            })
            .collect();
        let expected = direct_dft(&input).expect("dft failed");
        let mut data = input;
        split_radix_butterfly(&mut data).expect("split_radix failed");
        let err = max_abs_err(&data, &expected);
        assert!(err < 1e-9, "split_radix error (n=64) = {err}");
    }

    #[test]
    fn test_split_radix_butterfly_not_power_of_two() {
        let mut data = vec![Complex64::new(1.0, 0.0); 6];
        assert!(split_radix_butterfly(&mut data).is_err());
    }

    #[test]
    fn test_split_radix_butterfly_too_small() {
        let mut data = vec![Complex64::new(1.0, 0.0); 2];
        assert!(split_radix_butterfly(&mut data).is_err());
    }

    #[test]
    fn test_reverse_bits_permutation() {
        let perm: Vec<usize> = (0..8).map(|i| reverse_bits(i, 3)).collect();
        assert_eq!(perm, vec![0, 4, 2, 6, 1, 5, 3, 7]);
        assert_eq!(reverse_bits(5, 0), 0);
    }

    #[test]
    fn test_inverse_twiddle_table() {
        let n = 8;
        let fw = generate_twiddle_table(n).expect("forward failed");
        let inv = generate_inverse_twiddle_table(n).expect("inverse failed");
        // Forward and inverse factors are conjugates, so each product is 1.
        for k in 0..n {
            let product = fw[k] * inv[k];
            assert_relative_eq!(product.re, 1.0, epsilon = 1e-14);
            assert_relative_eq!(product.im, 0.0, epsilon = 1e-14);
        }
    }
}