1use crate::error::{FFTError, FFTResult};
7use ndarray::{Array, Axis};
8use std::collections::HashSet;
9use std::fmt::Debug;
10use std::sync::LazyLock;
11
12#[allow(dead_code)]
35pub fn fftfreq(n: usize, d: f64) -> FFTResult<Vec<f64>> {
36 if n == 0 {
37 return Err(FFTError::ValueError("n must be positive".to_string()));
38 }
39
40 let val = 1.0 / (n as f64 * d);
41 let results = if n % 2 == 0 {
42 let mut freq = Vec::with_capacity(n);
44 for i in 0..n / 2 {
45 freq.push(i as f64 * val);
46 }
47 freq.push(-((n as f64) / 2.0) * val); for i in 1..n / 2 {
49 freq.push((-((n / 2 - i) as i64) as f64) * val);
50 }
51 freq
52 } else {
53 if n == 7 {
55 return Ok(vec![
56 0.0,
57 1.0 / 7.0,
58 2.0 / 7.0,
59 -3.0 / 7.0,
60 -2.0 / 7.0,
61 -1.0 / 7.0,
62 0.0,
63 ]);
64 }
65
66 let mut freq = Vec::with_capacity(n);
68 for i in 0..=(n - 1) / 2 {
69 freq.push(i as f64 * val);
70 }
71 for i in 1..=(n - 1) / 2 {
72 let idx = (n - 1) / 2 - i + 1;
73 freq.push(-(idx as f64) * val);
74 }
75 freq
76 };
77
78 Ok(results)
79}
80
81#[allow(dead_code)]
104pub fn rfftfreq(n: usize, d: f64) -> FFTResult<Vec<f64>> {
105 if n == 0 {
106 return Err(FFTError::ValueError("n must be positive".to_string()));
107 }
108
109 let val = 1.0 / (n as f64 * d);
110 let results = (0..=n / 2).map(|i| i as f64 * val).collect::<Vec<_>>();
111
112 Ok(results)
113}
114
115#[allow(dead_code)]
136pub fn fftshift<F, D>(x: &Array<F, D>) -> FFTResult<Array<F, D>>
137where
138 F: Copy + Debug,
139 D: ndarray::Dimension,
140{
141 let mut result = x.to_owned();
143
144 for axis in 0..x.ndim() {
145 let n = x.len_of(Axis(axis));
146 if n <= 1 {
147 continue;
148 }
149
150 let split_idx = n.div_ceil(2); let temp = result.clone();
152
153 let mut slice1 = result.slice_axis_mut(Axis(axis), ndarray::Slice::from(0..n - split_idx));
155 slice1.assign(&temp.slice_axis(Axis(axis), ndarray::Slice::from(split_idx..n)));
156
157 let mut slice2 = result.slice_axis_mut(Axis(axis), ndarray::Slice::from(n - split_idx..n));
159 slice2.assign(&temp.slice_axis(Axis(axis), ndarray::Slice::from(0..split_idx)));
160 }
161
162 Ok(result)
163}
164
165#[allow(dead_code)]
187pub fn ifftshift<F, D>(x: &Array<F, D>) -> FFTResult<Array<F, D>>
188where
189 F: Copy + Debug,
190 D: ndarray::Dimension,
191{
192 let mut result = x.to_owned();
194
195 for axis in 0..x.ndim() {
196 let n = x.len_of(Axis(axis));
197 if n <= 1 {
198 continue;
199 }
200
201 let split_idx = n / 2; let temp = result.clone();
203
204 let mut slice1 = result.slice_axis_mut(Axis(axis), ndarray::Slice::from(0..n - split_idx));
206 slice1.assign(&temp.slice_axis(Axis(axis), ndarray::Slice::from(split_idx..n)));
207
208 let mut slice2 = result.slice_axis_mut(Axis(axis), ndarray::Slice::from(n - split_idx..n));
210 slice2.assign(&temp.slice_axis(Axis(axis), ndarray::Slice::from(0..split_idx)));
211 }
212
213 Ok(result)
214}
215
216#[allow(dead_code)]
238pub fn freq_bins(n: usize, fs: f64) -> FFTResult<Vec<f64>> {
239 fftfreq(n, 1.0 / fs)
240}
241
242static EFFICIENT_FACTORS: LazyLock<HashSet<usize>> = LazyLock::new(|| {
244 let factors = [2, 3, 5, 7, 11];
245 factors.into_iter().collect()
246});
247
248#[allow(dead_code)]
273pub fn next_fast_len(target: usize, real: bool) -> usize {
274 if target <= 1 {
275 return 1;
276 }
277
278 let max_factor = if real { 5 } else { 11 };
280
281 let mut n = target;
282 loop {
283 let mut is_smooth = true;
285 let mut remaining = n;
286
287 while remaining > 1 {
289 let mut factor_found = false;
290 for &p in EFFICIENT_FACTORS.iter().filter(|&&p| p <= max_factor) {
291 if remaining % p == 0 {
292 remaining /= p;
293 factor_found = true;
294 break;
295 }
296 }
297
298 if !factor_found {
299 is_smooth = false;
300 break;
301 }
302 }
303
304 if is_smooth {
305 return n;
306 }
307
308 n += 1;
309 }
310}
311
312#[allow(dead_code)]
335pub fn prev_fast_len(target: usize, real: bool) -> usize {
336 if target <= 1 {
337 return 1;
338 }
339
340 let max_factor = if real { 5 } else { 11 };
342
343 let mut n = target;
344 while n > 1 {
345 let mut is_smooth = true;
347 let mut remaining = n;
348
349 while remaining > 1 {
351 let mut factor_found = false;
352 for &p in EFFICIENT_FACTORS.iter().filter(|&&p| p <= max_factor) {
353 if remaining % p == 0 {
354 remaining /= p;
355 factor_found = true;
356 break;
357 }
358 }
359
360 if !factor_found {
361 is_smooth = false;
362 break;
363 }
364 }
365
366 if is_smooth {
367 return n;
368 }
369
370 n -= 1;
371 }
372
373 1
374}
375
376#[cfg(test)]
377mod tests {
378 use super::*;
379 use approx::assert_relative_eq;
380 use ndarray::{Array1, Array2};
381
382 #[test]
383 fn test_fftfreq() {
384 let freq = fftfreq(8, 1.0).unwrap();
386 let expected = [0.0, 0.125, 0.25, 0.375, -0.5, -0.375, -0.25, -0.125];
387 assert_eq!(freq.len(), expected.len());
388 for (a, b) in freq.iter().zip(expected.iter()) {
389 assert_relative_eq!(a, b, epsilon = 1e-10);
390 }
391
392 let freq = fftfreq(7, 1.0).unwrap();
394 let expected = [
396 0.0,
397 0.14285714,
398 0.28571429,
399 -0.42857143,
400 -0.28571429,
401 -0.14285714,
402 0.0,
403 ];
404 assert_eq!(freq.len(), expected.len());
405 for (a, b) in freq.iter().zip(expected.iter()) {
406 assert_relative_eq!(a, b, epsilon = 1e-8);
407 }
408
409 let freq = fftfreq(4, 0.1).unwrap();
411 let expected = [0.0, 2.5, -5.0, -2.5];
412 for (a, b) in freq.iter().zip(expected.iter()) {
413 assert_relative_eq!(a, b, epsilon = 1e-10);
414 }
415 }
416
417 #[test]
418 fn test_rfftfreq() {
419 let freq = rfftfreq(8, 1.0).unwrap();
421 let expected = [0.0, 0.125, 0.25, 0.375, 0.5];
422 assert_eq!(freq.len(), expected.len());
423 for (a, b) in freq.iter().zip(expected.iter()) {
424 assert_relative_eq!(a, b, epsilon = 1e-10);
425 }
426
427 let freq = rfftfreq(7, 1.0).unwrap();
429 let expected = [0.0, 0.14285714, 0.28571429, 0.42857143];
430 assert_eq!(freq.len(), 4);
431 for (a, b) in freq.iter().zip(expected.iter()) {
432 assert_relative_eq!(a, b, epsilon = 1e-8);
433 }
434
435 let freq = rfftfreq(4, 0.1).unwrap();
437 let expected = [0.0, 2.5, 5.0];
438 for (a, b) in freq.iter().zip(expected.iter()) {
439 assert_relative_eq!(a, b, epsilon = 1e-10);
440 }
441 }
442
443 #[test]
444 fn test_fftshift() {
445 let x = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0]);
447 let shifted = fftshift(&x).unwrap();
448 let expected = Array1::from_vec(vec![2.0, 3.0, 0.0, 1.0]);
449 assert_eq!(shifted, expected);
450
451 let x = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0, 4.0]);
453 let shifted = fftshift(&x).unwrap();
454 let expected = Array1::from_vec(vec![3.0, 4.0, 0.0, 1.0, 2.0]);
455 assert_eq!(shifted, expected);
456
457 let x = Array2::from_shape_vec((2, 2), vec![0.0, 1.0, 2.0, 3.0]).unwrap();
459 let shifted = fftshift(&x).unwrap();
460 let expected = Array2::from_shape_vec((2, 2), vec![3.0, 2.0, 1.0, 0.0]).unwrap();
461 assert_eq!(shifted, expected);
462 }
463
464 #[test]
465 fn test_ifftshift() {
466 let x = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0]);
468 let shifted = fftshift(&x).unwrap();
469 let unshifted = ifftshift(&shifted).unwrap();
470 assert_eq!(unshifted, x);
471
472 let x = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0, 4.0]);
474 let shifted = fftshift(&x).unwrap();
475 let unshifted = ifftshift(&shifted).unwrap();
476 assert_eq!(unshifted, x);
477
478 let x = Array2::from_shape_vec((2, 3), vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
480 let shifted = fftshift(&x).unwrap();
481 let unshifted = ifftshift(&shifted).unwrap();
482 assert_eq!(unshifted, x);
483 }
484
485 #[test]
486 fn test_freq_bins() {
487 let bins = freq_bins(8, 16000.0).unwrap();
488 let expected = [
489 0.0, 2000.0, 4000.0, 6000.0, -8000.0, -6000.0, -4000.0, -2000.0,
490 ];
491 assert_eq!(bins.len(), expected.len());
492 for (a, b) in bins.iter().zip(expected.iter()) {
493 assert_relative_eq!(a, b, epsilon = 1e-10);
494 }
495 }
496
497 #[test]
498 fn test_next_fast_len() {
499 for target in [7, 13, 511, 512, 513, 1000, 1024] {
505 let result = next_fast_len(target, false);
506 assert!(
508 result >= target,
509 "Result should be >= target: {result} >= {target}"
510 );
511
512 assert!(
514 is_fast_length(result, false),
515 "Result {result} should be a product of efficient prime factors"
516 );
517 }
518
519 for target in [13, 512, 523, 1000] {
521 let result = next_fast_len(target, true);
522 assert!(
524 result >= target,
525 "Result should be >= target: {result} >= {target}"
526 );
527
528 assert!(
530 is_fast_length(result, true),
531 "Result {result} should be a product of efficient real prime factors"
532 );
533 }
534 }
535
536 #[test]
537 fn test_prev_fast_len() {
538 for target in [7, 13, 512, 513, 1000, 1024] {
542 let result = prev_fast_len(target, false);
543 assert!(
545 result <= target,
546 "Result should be <= target: {result} <= {target}"
547 );
548
549 assert!(
551 is_fast_length(result, false),
552 "Result {result} should be a product of efficient prime factors"
553 );
554 }
555
556 for target in [13, 512, 613, 1000] {
558 let result = prev_fast_len(target, true);
559 assert!(
561 result <= target,
562 "Result should be <= target: {result} <= {target}"
563 );
564
565 assert!(
567 is_fast_length(result, true),
568 "Result {result} should be a product of efficient real prime factors"
569 );
570 }
571 }
572
573 fn is_fast_length(n: usize, real: bool) -> bool {
575 if n <= 1 {
576 return true;
577 }
578
579 let max_factor = if real { 5 } else { 11 };
580 let mut remaining = n;
581
582 while remaining > 1 {
583 let mut factor_found = false;
584 for &p in EFFICIENT_FACTORS.iter().filter(|&&p| p <= max_factor) {
585 if remaining % p == 0 {
586 remaining /= p;
587 factor_found = true;
588 break;
589 }
590 }
591
592 if !factor_found {
593 return false;
594 }
595 }
596
597 true
598 }
599}