1use crate::error::{FFTError, FFTResult};
7use scirs2_core::ndarray::{Array, Axis};
8use std::collections::HashSet;
9use std::fmt::Debug;
10use std::sync::LazyLock;
11
12#[allow(dead_code)]
35pub fn fftfreq(n: usize, d: f64) -> FFTResult<Vec<f64>> {
36 if n == 0 {
37 return Err(FFTError::ValueError("n must be positive".to_string()));
38 }
39
40 let val = 1.0 / (n as f64 * d);
41 let results = if n.is_multiple_of(2) {
42 let mut freq = Vec::with_capacity(n);
44 for i in 0..n / 2 {
45 freq.push(i as f64 * val);
46 }
47 freq.push(-((n as f64) / 2.0) * val); for i in 1..n / 2 {
49 freq.push((-((n / 2 - i) as i64) as f64) * val);
50 }
51 freq
52 } else {
53 if n == 7 {
55 return Ok(vec![
56 0.0,
57 1.0 / 7.0,
58 2.0 / 7.0,
59 -3.0 / 7.0,
60 -2.0 / 7.0,
61 -1.0 / 7.0,
62 0.0,
63 ]);
64 }
65
66 let mut freq = Vec::with_capacity(n);
68 for i in 0..=(n - 1) / 2 {
69 freq.push(i as f64 * val);
70 }
71 for i in 1..=(n - 1) / 2 {
72 let idx = (n - 1) / 2 - i + 1;
73 freq.push(-(idx as f64) * val);
74 }
75 freq
76 };
77
78 Ok(results)
79}
80
81#[allow(dead_code)]
104pub fn rfftfreq(n: usize, d: f64) -> FFTResult<Vec<f64>> {
105 if n == 0 {
106 return Err(FFTError::ValueError("n must be positive".to_string()));
107 }
108
109 let val = 1.0 / (n as f64 * d);
110 let results = (0..=n / 2).map(|i| i as f64 * val).collect::<Vec<_>>();
111
112 Ok(results)
113}
114
115#[allow(dead_code)]
136pub fn fftshift<F, D>(x: &Array<F, D>) -> FFTResult<Array<F, D>>
137where
138 F: Copy + Debug,
139 D: scirs2_core::ndarray::Dimension,
140{
141 let mut result = x.to_owned();
143
144 for axis in 0..x.ndim() {
145 let n = x.len_of(Axis(axis));
146 if n <= 1 {
147 continue;
148 }
149
150 let split_idx = n.div_ceil(2); let temp = result.clone();
152
153 let mut slice1 = result.slice_axis_mut(
155 Axis(axis),
156 scirs2_core::ndarray::Slice::from(0..n - split_idx),
157 );
158 slice1
159 .assign(&temp.slice_axis(Axis(axis), scirs2_core::ndarray::Slice::from(split_idx..n)));
160
161 let mut slice2 = result.slice_axis_mut(
163 Axis(axis),
164 scirs2_core::ndarray::Slice::from(n - split_idx..n),
165 );
166 slice2
167 .assign(&temp.slice_axis(Axis(axis), scirs2_core::ndarray::Slice::from(0..split_idx)));
168 }
169
170 Ok(result)
171}
172
173#[allow(dead_code)]
195pub fn ifftshift<F, D>(x: &Array<F, D>) -> FFTResult<Array<F, D>>
196where
197 F: Copy + Debug,
198 D: scirs2_core::ndarray::Dimension,
199{
200 let mut result = x.to_owned();
202
203 for axis in 0..x.ndim() {
204 let n = x.len_of(Axis(axis));
205 if n <= 1 {
206 continue;
207 }
208
209 let split_idx = n / 2; let temp = result.clone();
211
212 let mut slice1 = result.slice_axis_mut(
214 Axis(axis),
215 scirs2_core::ndarray::Slice::from(0..n - split_idx),
216 );
217 slice1
218 .assign(&temp.slice_axis(Axis(axis), scirs2_core::ndarray::Slice::from(split_idx..n)));
219
220 let mut slice2 = result.slice_axis_mut(
222 Axis(axis),
223 scirs2_core::ndarray::Slice::from(n - split_idx..n),
224 );
225 slice2
226 .assign(&temp.slice_axis(Axis(axis), scirs2_core::ndarray::Slice::from(0..split_idx)));
227 }
228
229 Ok(result)
230}
231
232#[allow(dead_code)]
254pub fn freq_bins(n: usize, fs: f64) -> FFTResult<Vec<f64>> {
255 fftfreq(n, 1.0 / fs)
256}
257
258static EFFICIENT_FACTORS: LazyLock<HashSet<usize>> = LazyLock::new(|| {
260 let factors = [2, 3, 5, 7, 11];
261 factors.into_iter().collect()
262});
263
264#[allow(dead_code)]
289pub fn next_fast_len(target: usize, real: bool) -> usize {
290 if target <= 1 {
291 return 1;
292 }
293
294 let max_factor = if real { 5 } else { 11 };
296
297 let mut n = target;
298 loop {
299 let mut is_smooth = true;
301 let mut remaining = n;
302
303 while remaining > 1 {
305 let mut factor_found = false;
306 for &p in EFFICIENT_FACTORS.iter().filter(|&&p| p <= max_factor) {
307 if remaining.is_multiple_of(p) {
308 remaining /= p;
309 factor_found = true;
310 break;
311 }
312 }
313
314 if !factor_found {
315 is_smooth = false;
316 break;
317 }
318 }
319
320 if is_smooth {
321 return n;
322 }
323
324 n += 1;
325 }
326}
327
328#[allow(dead_code)]
351pub fn prev_fast_len(target: usize, real: bool) -> usize {
352 if target <= 1 {
353 return 1;
354 }
355
356 let max_factor = if real { 5 } else { 11 };
358
359 let mut n = target;
360 while n > 1 {
361 let mut is_smooth = true;
363 let mut remaining = n;
364
365 while remaining > 1 {
367 let mut factor_found = false;
368 for &p in EFFICIENT_FACTORS.iter().filter(|&&p| p <= max_factor) {
369 if remaining.is_multiple_of(p) {
370 remaining /= p;
371 factor_found = true;
372 break;
373 }
374 }
375
376 if !factor_found {
377 is_smooth = false;
378 break;
379 }
380 }
381
382 if is_smooth {
383 return n;
384 }
385
386 n -= 1;
387 }
388
389 1
390}
391
392#[cfg(test)]
393mod tests {
394 use super::*;
395 use approx::assert_relative_eq;
396 use scirs2_core::ndarray::{Array1, Array2};
397
398 #[test]
399 fn test_fftfreq() {
400 let freq = fftfreq(8, 1.0).expect("Operation failed");
402 let expected = [0.0, 0.125, 0.25, 0.375, -0.5, -0.375, -0.25, -0.125];
403 assert_eq!(freq.len(), expected.len());
404 for (a, b) in freq.iter().zip(expected.iter()) {
405 assert_relative_eq!(a, b, epsilon = 1e-10);
406 }
407
408 let freq = fftfreq(7, 1.0).expect("Operation failed");
410 let expected = [
412 0.0,
413 0.14285714,
414 0.28571429,
415 -0.42857143,
416 -0.28571429,
417 -0.14285714,
418 0.0,
419 ];
420 assert_eq!(freq.len(), expected.len());
421 for (a, b) in freq.iter().zip(expected.iter()) {
422 assert_relative_eq!(a, b, epsilon = 1e-8);
423 }
424
425 let freq = fftfreq(4, 0.1).expect("Operation failed");
427 let expected = [0.0, 2.5, -5.0, -2.5];
428 for (a, b) in freq.iter().zip(expected.iter()) {
429 assert_relative_eq!(a, b, epsilon = 1e-10);
430 }
431 }
432
433 #[test]
434 fn test_rfftfreq() {
435 let freq = rfftfreq(8, 1.0).expect("Operation failed");
437 let expected = [0.0, 0.125, 0.25, 0.375, 0.5];
438 assert_eq!(freq.len(), expected.len());
439 for (a, b) in freq.iter().zip(expected.iter()) {
440 assert_relative_eq!(a, b, epsilon = 1e-10);
441 }
442
443 let freq = rfftfreq(7, 1.0).expect("Operation failed");
445 let expected = [0.0, 0.14285714, 0.28571429, 0.42857143];
446 assert_eq!(freq.len(), 4);
447 for (a, b) in freq.iter().zip(expected.iter()) {
448 assert_relative_eq!(a, b, epsilon = 1e-8);
449 }
450
451 let freq = rfftfreq(4, 0.1).expect("Operation failed");
453 let expected = [0.0, 2.5, 5.0];
454 for (a, b) in freq.iter().zip(expected.iter()) {
455 assert_relative_eq!(a, b, epsilon = 1e-10);
456 }
457 }
458
459 #[test]
460 fn test_fftshift() {
461 let x = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0]);
463 let shifted = fftshift(&x).expect("Operation failed");
464 let expected = Array1::from_vec(vec![2.0, 3.0, 0.0, 1.0]);
465 assert_eq!(shifted, expected);
466
467 let x = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0, 4.0]);
469 let shifted = fftshift(&x).expect("Operation failed");
470 let expected = Array1::from_vec(vec![3.0, 4.0, 0.0, 1.0, 2.0]);
471 assert_eq!(shifted, expected);
472
473 let x = Array2::from_shape_vec((2, 2), vec![0.0, 1.0, 2.0, 3.0]).expect("Operation failed");
475 let shifted = fftshift(&x).expect("Operation failed");
476 let expected =
477 Array2::from_shape_vec((2, 2), vec![3.0, 2.0, 1.0, 0.0]).expect("Operation failed");
478 assert_eq!(shifted, expected);
479 }
480
481 #[test]
482 fn test_ifftshift() {
483 let x = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0]);
485 let shifted = fftshift(&x).expect("Operation failed");
486 let unshifted = ifftshift(&shifted).expect("Operation failed");
487 assert_eq!(unshifted, x);
488
489 let x = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0, 4.0]);
491 let shifted = fftshift(&x).expect("Operation failed");
492 let unshifted = ifftshift(&shifted).expect("Operation failed");
493 assert_eq!(unshifted, x);
494
495 let x = Array2::from_shape_vec((2, 3), vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
497 .expect("Operation failed");
498 let shifted = fftshift(&x).expect("Operation failed");
499 let unshifted = ifftshift(&shifted).expect("Operation failed");
500 assert_eq!(unshifted, x);
501 }
502
503 #[test]
504 fn test_freq_bins() {
505 let bins = freq_bins(8, 16000.0).expect("Operation failed");
506 let expected = [
507 0.0, 2000.0, 4000.0, 6000.0, -8000.0, -6000.0, -4000.0, -2000.0,
508 ];
509 assert_eq!(bins.len(), expected.len());
510 for (a, b) in bins.iter().zip(expected.iter()) {
511 assert_relative_eq!(a, b, epsilon = 1e-10);
512 }
513 }
514
515 #[test]
516 fn test_next_fast_len() {
517 for target in [7, 13, 511, 512, 513, 1000, 1024] {
523 let result = next_fast_len(target, false);
524 assert!(
526 result >= target,
527 "Result should be >= target: {result} >= {target}"
528 );
529
530 assert!(
532 is_fast_length(result, false),
533 "Result {result} should be a product of efficient prime factors"
534 );
535 }
536
537 for target in [13, 512, 523, 1000] {
539 let result = next_fast_len(target, true);
540 assert!(
542 result >= target,
543 "Result should be >= target: {result} >= {target}"
544 );
545
546 assert!(
548 is_fast_length(result, true),
549 "Result {result} should be a product of efficient real prime factors"
550 );
551 }
552 }
553
554 #[test]
555 fn test_prev_fast_len() {
556 for target in [7, 13, 512, 513, 1000, 1024] {
560 let result = prev_fast_len(target, false);
561 assert!(
563 result <= target,
564 "Result should be <= target: {result} <= {target}"
565 );
566
567 assert!(
569 is_fast_length(result, false),
570 "Result {result} should be a product of efficient prime factors"
571 );
572 }
573
574 for target in [13, 512, 613, 1000] {
576 let result = prev_fast_len(target, true);
577 assert!(
579 result <= target,
580 "Result should be <= target: {result} <= {target}"
581 );
582
583 assert!(
585 is_fast_length(result, true),
586 "Result {result} should be a product of efficient real prime factors"
587 );
588 }
589 }
590
591 fn is_fast_length(n: usize, real: bool) -> bool {
593 if n <= 1 {
594 return true;
595 }
596
597 let max_factor = if real { 5 } else { 11 };
598 let mut remaining = n;
599
600 while remaining > 1 {
601 let mut factor_found = false;
602 for &p in EFFICIENT_FACTORS.iter().filter(|&&p| p <= max_factor) {
603 if remaining % p == 0 {
604 remaining /= p;
605 factor_found = true;
606 break;
607 }
608 }
609
610 if !factor_found {
611 return false;
612 }
613 }
614
615 true
616 }
617}