simd_kernels/kernels/scientific/
fft.rs

1// Copyright Peter Bower 2025. All Rights Reserved.
2// Licensed under Mozilla Public License (MPL) 2.0.
3
4//! # **Fast Fourier Transform Module** - *High-Performance Frequency Domain Analysis*
5//!
//! This module implements Fast Fourier Transform (FFT) kernels for frequency-domain
//! analysis and signal processing. It provides fixed-size kernels for small transforms
//! (4- and 8-point) and an iterative radix-2 transform for arbitrary power-of-two lengths,
//! for use in scientific computing and digital signal processing workflows.
10//!
11//! ## Use cases
12//!
13//! The Fast Fourier Transform is fundamental to numerous computational domains:
14//! - **Digital Signal Processing**: Spectral analysis, filtering, and convolution
15//! - **Image Processing**: Frequency domain transformations and enhancement
16//! - **Scientific Computing**: Numerical solution of PDEs via spectral methods
17//! - **Audio Processing**: Frequency analysis and synthesis
18//! - **Telecommunications**: Modulation, demodulation, and channel analysis
19//! - **Machine Learning**: Feature extraction and data preprocessing
20
21use minarrow::enums::error::KernelError;
22use minarrow::{FloatArray, Vec64};
23use num_complex::Complex64;
24
25#[inline(always)]
26pub fn butterfly_radix8(buf: &mut [Complex64]) {
27    debug_assert_eq!(buf.len(), 8);
28
29    // Split into even/odd halves and reuse temporaries to minimise loads
30    let (x0, x1, x2, x3, x4, x5, x6, x7) = (
31        buf[0], buf[1], buf[2], buf[3], buf[4], buf[5], buf[6], buf[7],
32    );
33
34    // First layer (radix-2 butterflies)
35    let a04 = x0 + x4;
36    let s04 = x0 - x4;
37    let a26 = x2 + x6;
38    let s26 = x2 - x6;
39    let a15 = x1 + x5;
40    let s15 = x1 - x5;
41    let a37 = x3 + x7;
42    let s37 = x3 - x7;
43
44    // Second layer (radix-4)
45    let a04a26 = a04 + a26;
46    let a04s26 = a04 - a26;
47    let a15a37 = a15 + a37;
48    let a15s37 = a15 - a37;
49
50    // Calculate ±i·(something) once
51    const J: Complex64 = Complex64 { re: 0.0, im: 1.0 };
52
53    // Radix-8 output
54    buf[0] = a04a26 + a15a37;
55    buf[4] = a04a26 - a15a37;
56
57    let t0 = s04 + J * s26;
58    let t1 = s15 + J * s37;
59    buf[2] = t0 + Complex64::new(0.0, -1.0) * t1; //  e^{-jπ/2}
60    buf[6] = t0 + Complex64::new(0.0, 1.0) * t1; //  e^{ jπ/2}
61
62    let u0 = a04s26;
63    let u1 = Complex64::new(0.0, -1.0) * a15s37;
64    buf[1] = u0 + u1; //  e^{-jπ/4} merged factor
65    buf[5] = u0 - u1; //  e^{ 3jπ/4}
66
67    let v0 = s04 - J * s26;
68    let v1 = s15 - J * s37;
69    buf[3] = v0 - Complex64::new(0.0, 1.0) * v1; //  e^{ jπ/2}
70    buf[7] = v0 - Complex64::new(0.0, -1.0) * v1; //  e^{-jπ/2}
71}
72
73// In-place radix-4 DIT for 4 points.
74#[inline(always)]
75fn fft4_in_place(x: &mut [Complex64; 4]) {
76    let x0 = x[0];
77    let x1 = x[1];
78    let x2 = x[2];
79    let x3 = x[3];
80
81    let a = x0 + x2; // (0)+(2)
82    let b = x0 - x2; // (0)-(2)
83    let c = x1 + x3; // (1)+(3)
84    let d = (x1 - x3) * Complex64::new(0.0, -1.0); // (1)-(3) times -j
85
86    x[0] = a + c; // k=0
87    x[2] = a - c; // k=2
88    x[1] = b + d; // k=1
89    x[3] = b - d; // k=3
90}
91
92/// 8-point FFT
93/// 8-point FFT (radix-2/4 DIT): split evens/odds -> FFT4 each -> twiddle & combine.
94#[inline(always)]
95pub fn fft8_radix(
96    buf: &mut [Complex64; 8],
97) -> Result<(FloatArray<f64>, FloatArray<f64>), KernelError> {
98    // Split into evens and odds
99    let mut even = [buf[0], buf[2], buf[4], buf[6]];
100    let mut odd = [buf[1], buf[3], buf[5], buf[7]];
101
102    // 4-point FFTs
103    fft4_in_place(&mut even);
104    fft4_in_place(&mut odd);
105
106    // Twiddles W8^k = exp(-j*2π*k/8)
107    // W8^0 = 1 + 0j
108    // W8^1 =  √2/2 - j√2/2
109    // W8^2 =  0   - j
110    // W8^3 = -√2/2 - j√2/2
111    let s = std::f64::consts::FRAC_1_SQRT_2;
112    let w1 = Complex64::new(s, -s);
113    let w2 = Complex64::new(0.0, -1.0);
114    let w3 = Complex64::new(-s, -s);
115
116    let t0 = odd[0]; // W8^0 * odd[0]
117    let t1 = w1 * odd[1]; // W8^1 * odd[1]
118    let t2 = w2 * odd[2]; // W8^2 * odd[2]
119    let t3 = w3 * odd[3]; // W8^3 * odd[3]
120
121    buf[0] = even[0] + t0;
122    buf[4] = even[0] - t0;
123
124    buf[1] = even[1] + t1;
125    buf[5] = even[1] - t1;
126
127    buf[2] = even[2] + t2;
128    buf[6] = even[2] - t2;
129
130    buf[3] = even[3] + t3;
131    buf[7] = even[3] - t3;
132
133    // Package outputs (same as your function did)
134    let mut real = Vec64::with_capacity(8);
135    let mut imag = Vec64::with_capacity(8);
136    for &z in buf.iter() {
137        real.push(z.re);
138        imag.push(z.im);
139    }
140    Ok((FloatArray::new(real, None), FloatArray::new(imag, None)))
141}
142
143/// Power-of-two, in-place FFT (≥8, radix-2 stages, radix-8 leaf).
144#[inline]
145pub fn block_fft(
146    data: &mut [Complex64],
147) -> Result<(FloatArray<f64>, FloatArray<f64>), KernelError> {
148    let n = data.len();
149    if n < 2 || (n & (n - 1)) != 0 {
150        return Err(KernelError::InvalidArguments(
151            "block_fft: N must be power-of-two and ≥2".into(),
152        ));
153    }
154
155    // bit-reversal permutation
156    let bits = n.trailing_zeros();
157    for i in 0..n {
158        let rev = i.reverse_bits() >> (usize::BITS - bits);
159        if i < rev {
160            data.swap(i, rev);
161        }
162    }
163
164    // iterative radix-2 DIT
165    let mut m = 2;
166    while m <= n {
167        let half = m / 2;
168        let theta = -2.0 * std::f64::consts::PI / (m as f64);
169        let w_m = Complex64::from_polar(1.0, theta);
170
171        for k in (0..n).step_by(m) {
172            let mut w = Complex64::new(1.0, 0.0);
173            for j in 0..half {
174                let t = w * data[k + j + half];
175                let u = data[k + j];
176                data[k + j] = u + t;
177                data[k + j + half] = u - t;
178                w *= w_m;
179            }
180        }
181        m <<= 1;
182    }
183
184    let mut real = Vec64::with_capacity(n);
185    let mut imag = Vec64::with_capacity(n);
186    for &z in data.iter() {
187        real.push(z.re);
188        imag.push(z.im);
189    }
190    Ok((FloatArray::new(real, None), FloatArray::new(imag, None)))
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use num_complex::Complex64;
    use rand::Rng;

    // ---- Reference spectra (as produced by SciPy/NumPy `fft`) ----

    /// Spectrum of the real ramp 0, 1, ..., 7.
    fn scipy_fft_ref_8_seq_0_7() -> [Complex64; 8] {
        [
            (28.0, 0.0),
            (-4.0, 9.6568542494923797),
            (-4.0, 4.0),
            (-4.0, 1.6568542494923806),
            (-4.0, 0.0),
            (-4.0, -1.6568542494923806),
            (-4.0, -4.0),
            (-4.0, -9.6568542494923797),
        ]
        .map(|(re, im)| Complex64::new(re, im))
    }

    /// Spectrum of the real ramp 0, 1, ..., 15.
    fn scipy_fft_ref_16_seq_0_15() -> [Complex64; 16] {
        [
            (120.0, 0.0),
            (-7.9999999999999991, 40.218715937006785),
            (-8.0, 19.313708498984759),
            (-7.9999999999999991, 11.972846101323913),
            (-8.0, 8.0),
            (-8.0, 5.345429103354391),
            (-8.0, 3.3137084989847612),
            (-8.0, 1.5912989390372658),
            (-8.0, 0.0),
            (-7.9999999999999991, -1.5912989390372658),
            (-8.0, -3.3137084989847612),
            (-7.9999999999999991, -5.3454291033543946),
            (-8.0, -8.0),
            (-8.0, -11.97284610132391),
            (-8.0, -19.313708498984759),
            (-8.0, -40.218715937006785),
        ]
        .map(|(re, im)| Complex64::new(re, im))
    }

    /// Real ramp `0, 1, ..., len-1` as complex samples.
    fn ramp(len: usize) -> Vec<Complex64> {
        (0..len).map(|v| Complex64::new(v as f64, 0.0)).collect()
    }

    /// Reference O(N²) DFT: `X[k] = Σ_n x[n]·e^{-j2πkn/N}`.
    fn dft_naive(x: &[Complex64]) -> Vec<Complex64> {
        let len = x.len() as f64;
        (0..x.len())
            .map(|k| {
                x.iter()
                    .enumerate()
                    .fold(Complex64::new(0.0, 0.0), |acc, (idx, &val)| {
                        let angle =
                            -2.0 * std::f64::consts::PI * (k as f64) * (idx as f64) / len;
                        acc + val * Complex64::from_polar(1.0, angle)
                    })
            })
            .collect()
    }

    /// Asserts equal lengths and element-wise `|a[i] - b[i]| < eps`.
    fn assert_vec_close(a: &[Complex64], b: &[Complex64], eps: f64) {
        assert_eq!(a.len(), b.len());
        a.iter().zip(b).for_each(|(x, y)| {
            assert!((x - y).norm() < eps, "mismatch: x={:?}, y={:?}", x, y);
        });
    }

    #[test]
    fn butterfly_radix8_impulse_all_ones() {
        // Unit impulse at n = 0 has a flat spectrum.
        let mut buf = [Complex64::new(0.0, 0.0); 8];
        buf[0] = Complex64::new(1.0, 0.0);
        butterfly_radix8(&mut buf);
        assert_vec_close(&buf, &[Complex64::new(1.0, 0.0); 8], 1e-15);
    }

    #[test]
    fn fft8_radix_matches_scipy_seq0_7() {
        let mut buf: [Complex64; 8] = std::array::from_fn(|i| Complex64::new(i as f64, 0.0));
        fft8_radix(&mut buf).unwrap();
        assert_vec_close(&buf, &scipy_fft_ref_8_seq_0_7(), 1e-12);
    }

    #[test]
    fn block_fft_matches_scipy_seq0_7() {
        let mut data = ramp(8);
        block_fft(&mut data).unwrap();
        assert_vec_close(&data, &scipy_fft_ref_8_seq_0_7(), 1e-12);
    }

    #[test]
    fn block_fft_matches_scipy_seq0_15() {
        let mut data = ramp(16);
        block_fft(&mut data).unwrap();
        assert_vec_close(&data, &scipy_fft_ref_16_seq_0_15(), 1e-11);
    }

    #[test]
    fn radix8_exact() {
        let input: [Complex64; 8] = std::array::from_fn(|i| Complex64::new(i as f64, 0.0));
        let mut buf = input;
        fft8_radix(&mut buf).unwrap();
        assert_vec_close(&buf, &dft_naive(&input), 1e-12);
    }

    #[test]
    fn block_fft_random_lengths() {
        let mut rng = rand::rng();
        // Lengths 8, 16, ..., 1024.
        for shift in 3..=10 {
            let n = 1usize << shift;
            let original: Vec<Complex64> = (0..n)
                .map(|_| Complex64::new(rng.random(), rng.random()))
                .collect();
            let mut data = original.clone();
            block_fft(&mut data).unwrap();
            // Loose tolerance: the naive reference DFT's own error grows with n.
            assert_vec_close(&data, &dft_naive(&original), 1e-9);
        }
    }

    #[test]
    fn block_fft_power_of_two_check() {
        // 12 = 4·3 is not a power of two.
        let mut bad = [Complex64::new(0.0, 0.0); 12];
        assert!(block_fft(&mut bad).is_err());
    }
}