rusty_brain/
fft.rs

1use nalgebra::Complex;
2use ndarray::s;
3use ndarray::Array1;
4use ndarray::ArrayView1;
5use std::f32::consts::PI;
6
7pub trait FastFourierTransform<T> {
8    fn fft(&self) -> Array1<Complex<T>>;
9}
10
// Implements `FastFourierTransform` for `ArrayView1` over one sample type.
//
// Every generated `fft` computes the forward discrete Fourier transform
//
//     X[k] = sum_j x[j] * exp(-2*pi*i*j*k / n),    k = 0..n
//
// Even-length blocks are split recursively into their even- and odd-indexed
// samples (radix-2 decimation in time). Blocks of odd length cannot be split
// that way and are evaluated with the direct O(n^2) sum, so inputs of any
// length -- including the empty one -- produce a full, correct spectrum.
macro_rules! impl_fft_for {
    // Special treatment for `i16` vectors, required for full compatibility with
    // BrainVision Core Data Format 1.0.
    (i16) => {
        impl<'a> FastFourierTransform<i16> for ArrayView1<'a, i16> {
            // The spectrum is computed in `f64` and converted back to `i16`
            // only once, at the very end: each component is rounded to the
            // nearest integer and saturated to the `i16` range. Doing the
            // butterflies directly in `i16` would overflow on the first
            // twiddle multiplication.
            fn fft(&self) -> Array1<Complex<i16>> {
                // Recursive worker operating on widened values.
                fn transform(samples: ArrayView1<'_, i16>) -> Vec<Complex<f64>> {
                    let n = samples.len();

                    // Zero or one sample: the spectrum is the input itself.
                    // Handling `n == 0` here also terminates the recursion
                    // for empty input.
                    if n <= 1 {
                        return samples
                            .iter()
                            .map(|&sample| Complex::new(f64::from(sample), 0.0_f64))
                            .collect();
                    }

                    // Phase increment of the forward kernel.
                    let step: f64 = -2.0 * std::f64::consts::PI / n as f64;

                    if n % 2 != 0 {
                        // Odd length: no even/odd split possible, use the
                        // direct sum.
                        let mut bins = Vec::with_capacity(n);
                        for k in 0..n {
                            let mut re = 0.0_f64;
                            let mut im = 0.0_f64;
                            for (j, &sample) in samples.iter().enumerate() {
                                // The kernel is periodic in `j * k` with
                                // period `n`; reducing keeps the angle small.
                                let angle = step * ((k * j) % n) as f64;
                                let value = f64::from(sample);
                                re += value * angle.cos();
                                im += value * angle.sin();
                            }
                            bins.push(Complex::new(re, im));
                        }
                        return bins;
                    }

                    let even = transform(samples.slice(s![..; 2]));
                    let odd = transform(samples.slice(s![1..; 2]));

                    let half = n / 2;
                    let mut bins = vec![Complex::new(0.0_f64, 0.0_f64); n];
                    for k in 0..half {
                        let angle = step * k as f64;
                        let rotated = Complex::new(angle.cos(), angle.sin()) * odd[k];
                        bins[k] = even[k] + rotated;
                        bins[k + half] = even[k] - rotated;
                    }
                    bins
                }

                transform(self.view())
                    .into_iter()
                    // Float-to-integer `as` casts saturate at the bounds of
                    // the target type, so oversized bins clamp, not wrap.
                    .map(|bin| Complex::new(bin.re.round() as i16, bin.im.round() as i16))
                    .collect()
            }
        }
    };
    // Implement FFT for floating-point types.
    //
    // NOTE: the angle is built from the `f32` constant `PI` cast to the
    // target type, so a wider float type would only get `f32`-accurate
    // twiddle factors.
    ($float_t: ty) => {
        impl<'a> FastFourierTransform<$float_t> for ArrayView1<'a, $float_t> {
            fn fft(&self) -> Array1<Complex<$float_t>> {
                let n = self.len();

                // Empty input has an empty spectrum; without this guard the
                // recursion below would never reach a base case.
                if n == 0 {
                    return Array1::zeros(0);
                }
                // A single sample is its own transform.
                if n == 1 {
                    return Array1::from_elem(1, Complex::from(self[0]));
                }

                // Phase increment of the forward kernel.
                let step: $float_t = -2.0 * PI as $float_t / n as $float_t;
                let mut result: Array1<Complex<$float_t>> = Array1::zeros(n);

                if n % 2 != 0 {
                    // Odd length: no even/odd split possible, use the direct
                    // sum.
                    for k in 0..n {
                        let mut re: $float_t = 0.0;
                        let mut im: $float_t = 0.0;
                        for (j, &sample) in self.iter().enumerate() {
                            // The kernel is periodic in `j * k` with period
                            // `n`; reducing keeps the angle small.
                            let angle = step * ((k * j) % n) as $float_t;
                            re += sample * angle.cos();
                            im += sample * angle.sin();
                        }
                        result[k] = Complex::new(re, im);
                    }
                    return result;
                }

                let fft_even = self.slice(s![..; 2]).fft();
                let fft_odd = self.slice(s![1..; 2]).fft();

                let half = n / 2;
                for k in 0..half {
                    let angle = step * k as $float_t;
                    let rotated = Complex::new(angle.cos(), angle.sin()) * fft_odd[k];
                    result[k] = fft_even[k] + rotated;
                    result[k + half] = fft_even[k] - rotated;
                }

                result
            }
        }
    };
}
74
75impl_fft_for!(i16);
76impl_fft_for!(f32);