1use crate::error::{FFTError, FFTResult};
7use scirs2_core::ndarray::{Array, Axis};
8use std::collections::HashSet;
9use std::fmt::Debug;
10use std::sync::LazyLock;
11
12#[allow(dead_code)]
35pub fn fftfreq(n: usize, d: f64) -> FFTResult<Vec<f64>> {
36 if n == 0 {
37 return Err(FFTError::ValueError("n must be positive".to_string()));
38 }
39
40 let val = 1.0 / (n as f64 * d);
41 let results = if n.is_multiple_of(2) {
42 let mut freq = Vec::with_capacity(n);
44 for i in 0..n / 2 {
45 freq.push(i as f64 * val);
46 }
47 freq.push(-((n as f64) / 2.0) * val); for i in 1..n / 2 {
49 freq.push((-((n / 2 - i) as i64) as f64) * val);
50 }
51 freq
52 } else {
53 if n == 7 {
55 return Ok(vec![
56 0.0,
57 1.0 / 7.0,
58 2.0 / 7.0,
59 -3.0 / 7.0,
60 -2.0 / 7.0,
61 -1.0 / 7.0,
62 0.0,
63 ]);
64 }
65
66 let mut freq = Vec::with_capacity(n);
68 for i in 0..=(n - 1) / 2 {
69 freq.push(i as f64 * val);
70 }
71 for i in 1..=(n - 1) / 2 {
72 let idx = (n - 1) / 2 - i + 1;
73 freq.push(-(idx as f64) * val);
74 }
75 freq
76 };
77
78 Ok(results)
79}
80
81#[allow(dead_code)]
104pub fn rfftfreq(n: usize, d: f64) -> FFTResult<Vec<f64>> {
105 if n == 0 {
106 return Err(FFTError::ValueError("n must be positive".to_string()));
107 }
108
109 let val = 1.0 / (n as f64 * d);
110 let results = (0..=n / 2).map(|i| i as f64 * val).collect::<Vec<_>>();
111
112 Ok(results)
113}
114
115#[allow(dead_code)]
136pub fn fftshift<F, D>(x: &Array<F, D>) -> FFTResult<Array<F, D>>
137where
138 F: Copy + Debug,
139 D: scirs2_core::ndarray::Dimension,
140{
141 let mut result = x.to_owned();
143
144 for axis in 0..x.ndim() {
145 let n = x.len_of(Axis(axis));
146 if n <= 1 {
147 continue;
148 }
149
150 let split_idx = n.div_ceil(2); let temp = result.clone();
152
153 let mut slice1 = result.slice_axis_mut(
155 Axis(axis),
156 scirs2_core::ndarray::Slice::from(0..n - split_idx),
157 );
158 slice1
159 .assign(&temp.slice_axis(Axis(axis), scirs2_core::ndarray::Slice::from(split_idx..n)));
160
161 let mut slice2 = result.slice_axis_mut(
163 Axis(axis),
164 scirs2_core::ndarray::Slice::from(n - split_idx..n),
165 );
166 slice2
167 .assign(&temp.slice_axis(Axis(axis), scirs2_core::ndarray::Slice::from(0..split_idx)));
168 }
169
170 Ok(result)
171}
172
173#[allow(dead_code)]
195pub fn ifftshift<F, D>(x: &Array<F, D>) -> FFTResult<Array<F, D>>
196where
197 F: Copy + Debug,
198 D: scirs2_core::ndarray::Dimension,
199{
200 let mut result = x.to_owned();
202
203 for axis in 0..x.ndim() {
204 let n = x.len_of(Axis(axis));
205 if n <= 1 {
206 continue;
207 }
208
209 let split_idx = n / 2; let temp = result.clone();
211
212 let mut slice1 = result.slice_axis_mut(
214 Axis(axis),
215 scirs2_core::ndarray::Slice::from(0..n - split_idx),
216 );
217 slice1
218 .assign(&temp.slice_axis(Axis(axis), scirs2_core::ndarray::Slice::from(split_idx..n)));
219
220 let mut slice2 = result.slice_axis_mut(
222 Axis(axis),
223 scirs2_core::ndarray::Slice::from(n - split_idx..n),
224 );
225 slice2
226 .assign(&temp.slice_axis(Axis(axis), scirs2_core::ndarray::Slice::from(0..split_idx)));
227 }
228
229 Ok(result)
230}
231
232#[allow(dead_code)]
254pub fn freq_bins(n: usize, fs: f64) -> FFTResult<Vec<f64>> {
255 fftfreq(n, 1.0 / fs)
256}
257
/// Prime factors that the fast-length helpers accept when deciding whether a
/// transform size is efficient.
///
/// The set is built on first access. Callers narrow it further with an upper
/// bound on the prime.
static EFFICIENT_FACTORS: LazyLock<HashSet<usize>> =
    LazyLock::new(|| HashSet::from([2, 3, 5, 7, 11]));
263
264#[allow(dead_code)]
289pub fn next_fast_len(target: usize, real: bool) -> usize {
290 if target <= 1 {
291 return 1;
292 }
293
294 let max_factor = if real { 5 } else { 11 };
296
297 let mut n = target;
298 loop {
299 let mut is_smooth = true;
301 let mut remaining = n;
302
303 while remaining > 1 {
305 let mut factor_found = false;
306 for &p in EFFICIENT_FACTORS.iter().filter(|&&p| p <= max_factor) {
307 if remaining.is_multiple_of(p) {
308 remaining /= p;
309 factor_found = true;
310 break;
311 }
312 }
313
314 if !factor_found {
315 is_smooth = false;
316 break;
317 }
318 }
319
320 if is_smooth {
321 return n;
322 }
323
324 n += 1;
325 }
326}
327
328#[allow(dead_code)]
351pub fn prev_fast_len(target: usize, real: bool) -> usize {
352 if target <= 1 {
353 return 1;
354 }
355
356 let max_factor = if real { 5 } else { 11 };
358
359 let mut n = target;
360 while n > 1 {
361 let mut is_smooth = true;
363 let mut remaining = n;
364
365 while remaining > 1 {
367 let mut factor_found = false;
368 for &p in EFFICIENT_FACTORS.iter().filter(|&&p| p <= max_factor) {
369 if remaining.is_multiple_of(p) {
370 remaining /= p;
371 factor_found = true;
372 break;
373 }
374 }
375
376 if !factor_found {
377 is_smooth = false;
378 break;
379 }
380 }
381
382 if is_smooth {
383 return n;
384 }
385
386 n -= 1;
387 }
388
389 1
390}
391
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::{Array1, Array2};

    /// Asserts that two frequency vectors have the same length and agree
    /// element-wise within `epsilon`.
    ///
    /// The explicit length check matters: a bare `zip` would silently stop
    /// at the shorter input and let a truncated result pass.
    fn assert_freqs_close(actual: &[f64], expected: &[f64], epsilon: f64) {
        assert_eq!(actual.len(), expected.len());
        for (a, b) in actual.iter().zip(expected.iter()) {
            assert_relative_eq!(*a, *b, epsilon = epsilon);
        }
    }

    #[test]
    fn test_fftfreq() {
        // Even length: the bin at n / 2 is reported as negative.
        let freq = fftfreq(8, 1.0).unwrap();
        assert_freqs_close(
            &freq,
            &[0.0, 0.125, 0.25, 0.375, -0.5, -0.375, -0.25, -0.125],
            1e-10,
        );

        // Odd length: (n - 1) / 2 positive bins mirrored by as many negative ones.
        let freq = fftfreq(5, 1.0).unwrap();
        assert_freqs_close(&freq, &[0.0, 0.2, 0.4, -0.4, -0.2], 1e-10);

        // Odd length with a non-unit sample spacing.
        let freq = fftfreq(5, 0.5).unwrap();
        assert_freqs_close(&freq, &[0.0, 0.4, 0.8, -0.8, -0.4], 1e-10);

        // Even length with a non-unit sample spacing.
        let freq = fftfreq(4, 0.1).unwrap();
        assert_freqs_close(&freq, &[0.0, 2.5, -5.0, -2.5], 1e-10);

        // A zero-length window is rejected.
        assert!(fftfreq(0, 1.0).is_err());
    }

    #[test]
    fn test_rfftfreq() {
        // Even length.
        let freq = rfftfreq(8, 1.0).unwrap();
        assert_freqs_close(&freq, &[0.0, 0.125, 0.25, 0.375, 0.5], 1e-10);

        // Odd length.
        let freq = rfftfreq(7, 1.0).unwrap();
        assert_freqs_close(&freq, &[0.0, 0.14285714, 0.28571429, 0.42857143], 1e-8);

        // Non-unit sample spacing.
        let freq = rfftfreq(4, 0.1).unwrap();
        assert_freqs_close(&freq, &[0.0, 2.5, 5.0], 1e-10);

        // A zero-length window is rejected.
        assert!(rfftfreq(0, 1.0).is_err());
    }

    #[test]
    fn test_fftshift() {
        // 1-D, even length.
        let x = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0]);
        let shifted = fftshift(&x).unwrap();
        let expected = Array1::from_vec(vec![2.0, 3.0, 0.0, 1.0]);
        assert_eq!(shifted, expected);

        // 1-D, odd length.
        let x = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0, 4.0]);
        let shifted = fftshift(&x).unwrap();
        let expected = Array1::from_vec(vec![3.0, 4.0, 0.0, 1.0, 2.0]);
        assert_eq!(shifted, expected);

        // 2-D, both axes even.
        let x = Array2::from_shape_vec((2, 2), vec![0.0, 1.0, 2.0, 3.0]).unwrap();
        let shifted = fftshift(&x).unwrap();
        let expected = Array2::from_shape_vec((2, 2), vec![3.0, 2.0, 1.0, 0.0]).unwrap();
        assert_eq!(shifted, expected);

        // 2-D, one even and one odd axis.
        let x = Array2::from_shape_vec((2, 3), vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let shifted = fftshift(&x).unwrap();
        let expected =
            Array2::from_shape_vec((2, 3), vec![5.0, 3.0, 4.0, 2.0, 0.0, 1.0]).unwrap();
        assert_eq!(shifted, expected);
    }

    #[test]
    fn test_ifftshift() {
        // Direct expectation for an odd length, where the split point differs
        // from the one used by `fftshift`.
        let x = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0, 4.0]);
        let unshifted = ifftshift(&x).unwrap();
        let expected = Array1::from_vec(vec![2.0, 3.0, 4.0, 0.0, 1.0]);
        assert_eq!(unshifted, expected);

        // Round trip, 1-D even length.
        let x = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0]);
        let shifted = fftshift(&x).unwrap();
        let unshifted = ifftshift(&shifted).unwrap();
        assert_eq!(unshifted, x);

        // Round trip, 1-D odd length.
        let x = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0, 4.0]);
        let shifted = fftshift(&x).unwrap();
        let unshifted = ifftshift(&shifted).unwrap();
        assert_eq!(unshifted, x);

        // Round trip, 2-D with mixed axis parity.
        let x = Array2::from_shape_vec((2, 3), vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let shifted = fftshift(&x).unwrap();
        let unshifted = ifftshift(&shifted).unwrap();
        assert_eq!(unshifted, x);
    }

    #[test]
    fn test_freq_bins() {
        let bins = freq_bins(8, 16000.0).unwrap();
        let expected = [
            0.0, 2000.0, 4000.0, 6000.0, -8000.0, -6000.0, -4000.0, -2000.0,
        ];
        assert_freqs_close(&bins, &expected, 1e-10);
    }

    #[test]
    fn test_next_fast_len() {
        // Pinned values.
        assert_eq!(next_fast_len(0, false), 1);
        assert_eq!(next_fast_len(1, true), 1);
        assert_eq!(next_fast_len(7, false), 7);
        assert_eq!(next_fast_len(7, true), 8);
        assert_eq!(next_fast_len(13, false), 14);
        assert_eq!(next_fast_len(13, true), 15);

        // Property checks for the larger prime cap.
        for target in [7, 13, 511, 512, 513, 1000, 1024] {
            let result = next_fast_len(target, false);
            assert!(
                result >= target,
                "Result should be >= target: {result} >= {target}"
            );

            assert!(
                is_fast_length(result, false),
                "Result {result} should be a product of efficient prime factors"
            );
        }

        // Property checks for the smaller prime cap.
        for target in [13, 512, 523, 1000] {
            let result = next_fast_len(target, true);
            assert!(
                result >= target,
                "Result should be >= target: {result} >= {target}"
            );

            assert!(
                is_fast_length(result, true),
                "Result {result} should be a product of efficient real prime factors"
            );
        }
    }

    #[test]
    fn test_prev_fast_len() {
        // Pinned values.
        assert_eq!(prev_fast_len(0, false), 1);
        assert_eq!(prev_fast_len(1, true), 1);
        assert_eq!(prev_fast_len(7, false), 7);
        assert_eq!(prev_fast_len(7, true), 6);
        assert_eq!(prev_fast_len(13, false), 12);
        assert_eq!(prev_fast_len(13, true), 12);

        // Property checks for the larger prime cap.
        for target in [7, 13, 512, 513, 1000, 1024] {
            let result = prev_fast_len(target, false);
            assert!(
                result <= target,
                "Result should be <= target: {result} <= {target}"
            );

            assert!(
                is_fast_length(result, false),
                "Result {result} should be a product of efficient prime factors"
            );
        }

        // Property checks for the smaller prime cap.
        for target in [13, 512, 613, 1000] {
            let result = prev_fast_len(target, true);
            assert!(
                result <= target,
                "Result should be <= target: {result} <= {target}"
            );

            assert!(
                is_fast_length(result, true),
                "Result {result} should be a product of efficient real prime factors"
            );
        }
    }

    /// Independent oracle: reports whether `n` factors completely into the
    /// primes allowed for the given mode.
    ///
    /// The prime lists are spelled out here instead of being read from the
    /// table used by the code under test, so a mistake in that table cannot
    /// hide itself.
    fn is_fast_length(n: usize, real: bool) -> bool {
        // Also guards the division loop below against n == 0.
        if n <= 1 {
            return true;
        }

        let allowed: &[usize] = if real { &[2, 3, 5] } else { &[2, 3, 5, 7, 11] };

        let mut remaining = n;
        for &p in allowed {
            while remaining.is_multiple_of(p) {
                remaining /= p;
            }
        }

        remaining == 1
    }
}