1use minarrow::enums::error::KernelError;
22use minarrow::{FloatArray, Vec64};
23use num_complex::Complex64;
24
25#[inline(always)]
26pub fn butterfly_radix8(buf: &mut [Complex64]) {
27 debug_assert_eq!(buf.len(), 8);
28
29 let (x0, x1, x2, x3, x4, x5, x6, x7) = (
31 buf[0], buf[1], buf[2], buf[3], buf[4], buf[5], buf[6], buf[7],
32 );
33
34 let a04 = x0 + x4;
36 let s04 = x0 - x4;
37 let a26 = x2 + x6;
38 let s26 = x2 - x6;
39 let a15 = x1 + x5;
40 let s15 = x1 - x5;
41 let a37 = x3 + x7;
42 let s37 = x3 - x7;
43
44 let a04a26 = a04 + a26;
46 let a04s26 = a04 - a26;
47 let a15a37 = a15 + a37;
48 let a15s37 = a15 - a37;
49
50 const J: Complex64 = Complex64 { re: 0.0, im: 1.0 };
52
53 buf[0] = a04a26 + a15a37;
55 buf[4] = a04a26 - a15a37;
56
57 let t0 = s04 + J * s26;
58 let t1 = s15 + J * s37;
59 buf[2] = t0 + Complex64::new(0.0, -1.0) * t1; buf[6] = t0 + Complex64::new(0.0, 1.0) * t1; let u0 = a04s26;
63 let u1 = Complex64::new(0.0, -1.0) * a15s37;
64 buf[1] = u0 + u1; buf[5] = u0 - u1; let v0 = s04 - J * s26;
68 let v1 = s15 - J * s37;
69 buf[3] = v0 - Complex64::new(0.0, 1.0) * v1; buf[7] = v0 - Complex64::new(0.0, -1.0) * v1; }
72
73#[inline(always)]
75fn fft4_in_place(x: &mut [Complex64; 4]) {
76 let x0 = x[0];
77 let x1 = x[1];
78 let x2 = x[2];
79 let x3 = x[3];
80
81 let a = x0 + x2; let b = x0 - x2; let c = x1 + x3; let d = (x1 - x3) * Complex64::new(0.0, -1.0); x[0] = a + c; x[2] = a - c; x[1] = b + d; x[3] = b - d; }
91
92#[inline(always)]
95pub fn fft8_radix(
96 buf: &mut [Complex64; 8],
97) -> Result<(FloatArray<f64>, FloatArray<f64>), KernelError> {
98 let mut even = [buf[0], buf[2], buf[4], buf[6]];
100 let mut odd = [buf[1], buf[3], buf[5], buf[7]];
101
102 fft4_in_place(&mut even);
104 fft4_in_place(&mut odd);
105
106 let s = std::f64::consts::FRAC_1_SQRT_2;
112 let w1 = Complex64::new(s, -s);
113 let w2 = Complex64::new(0.0, -1.0);
114 let w3 = Complex64::new(-s, -s);
115
116 let t0 = odd[0]; let t1 = w1 * odd[1]; let t2 = w2 * odd[2]; let t3 = w3 * odd[3]; buf[0] = even[0] + t0;
122 buf[4] = even[0] - t0;
123
124 buf[1] = even[1] + t1;
125 buf[5] = even[1] - t1;
126
127 buf[2] = even[2] + t2;
128 buf[6] = even[2] - t2;
129
130 buf[3] = even[3] + t3;
131 buf[7] = even[3] - t3;
132
133 let mut real = Vec64::with_capacity(8);
135 let mut imag = Vec64::with_capacity(8);
136 for &z in buf.iter() {
137 real.push(z.re);
138 imag.push(z.im);
139 }
140 Ok((FloatArray::new(real, None), FloatArray::new(imag, None)))
141}
142
143#[inline]
145pub fn block_fft(
146 data: &mut [Complex64],
147) -> Result<(FloatArray<f64>, FloatArray<f64>), KernelError> {
148 let n = data.len();
149 if n < 2 || (n & (n - 1)) != 0 {
150 return Err(KernelError::InvalidArguments(
151 "block_fft: N must be power-of-two and ≥2".into(),
152 ));
153 }
154
155 let bits = n.trailing_zeros();
157 for i in 0..n {
158 let rev = i.reverse_bits() >> (usize::BITS - bits);
159 if i < rev {
160 data.swap(i, rev);
161 }
162 }
163
164 let mut m = 2;
166 while m <= n {
167 let half = m / 2;
168 let theta = -2.0 * std::f64::consts::PI / (m as f64);
169 let w_m = Complex64::from_polar(1.0, theta);
170
171 for k in (0..n).step_by(m) {
172 let mut w = Complex64::new(1.0, 0.0);
173 for j in 0..half {
174 let t = w * data[k + j + half];
175 let u = data[k + j];
176 data[k + j] = u + t;
177 data[k + j + half] = u - t;
178 w *= w_m;
179 }
180 }
181 m <<= 1;
182 }
183
184 let mut real = Vec64::with_capacity(n);
185 let mut imag = Vec64::with_capacity(n);
186 for &z in data.iter() {
187 real.push(z.re);
188 imag.push(z.im);
189 }
190 Ok((FloatArray::new(real, None), FloatArray::new(imag, None)))
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use num_complex::Complex64;
    use rand::Rng;

    /// Shorthand constructor that keeps the reference tables compact.
    fn c(re: f64, im: f64) -> Complex64 {
        Complex64::new(re, im)
    }

    /// The real-valued ramp `0, 1, ..., 7` as a fixed-size array.
    fn ramp8() -> [Complex64; 8] {
        let mut out = [Complex64::new(0.0, 0.0); 8];
        for (i, slot) in out.iter_mut().enumerate() {
            *slot = Complex64::new(i as f64, 0.0);
        }
        out
    }

    /// The real-valued ramp `0, 1, ..., len - 1` as a vector.
    fn ramp(len: usize) -> Vec<Complex64> {
        (0..len).map(|v| Complex64::new(v as f64, 0.0)).collect()
    }

    /// Asserts that `a` and `b` have equal length and agree element-wise
    /// to within `eps` in complex modulus.
    fn assert_vec_close(a: &[Complex64], b: &[Complex64], eps: f64) {
        assert_eq!(a.len(), b.len());
        a.iter().zip(b).for_each(|(x, y)| {
            assert!((x - y).norm() < eps, "mismatch: x={:?}, y={:?}", x, y);
        });
    }

    /// Direct O(N²) evaluation of the forward DFT, used as ground truth.
    fn dft_naive(x: &[Complex64]) -> Vec<Complex64> {
        let len = x.len() as f64;
        let mut bins = Vec::with_capacity(x.len());
        for k in 0..x.len() {
            let mut acc = Complex64::new(0.0, 0.0);
            for (idx, &sample) in x.iter().enumerate() {
                let angle = -2.0 * std::f64::consts::PI * (k as f64) * (idx as f64) / len;
                acc += sample * Complex64::from_polar(1.0, angle);
            }
            bins.push(acc);
        }
        bins
    }

    /// Reference spectrum of the ramp `0..8`.
    fn scipy_fft_ref_8_seq_0_7() -> [Complex64; 8] {
        [
            c(28.0, 0.0),
            c(-4.0, 9.6568542494923797),
            c(-4.0, 4.0),
            c(-4.0, 1.6568542494923806),
            c(-4.0, 0.0),
            c(-4.0, -1.6568542494923806),
            c(-4.0, -4.0),
            c(-4.0, -9.6568542494923797),
        ]
    }

    /// Reference spectrum of the ramp `0..16`.
    fn scipy_fft_ref_16_seq_0_15() -> [Complex64; 16] {
        [
            c(120.0, 0.0),
            c(-7.9999999999999991, 40.218715937006785),
            c(-8.0, 19.313708498984759),
            c(-7.9999999999999991, 11.972846101323913),
            c(-8.0, 8.0),
            c(-8.0, 5.345429103354391),
            c(-8.0, 3.3137084989847612),
            c(-8.0, 1.5912989390372658),
            c(-8.0, 0.0),
            c(-7.9999999999999991, -1.5912989390372658),
            c(-8.0, -3.3137084989847612),
            c(-7.9999999999999991, -5.3454291033543946),
            c(-8.0, -8.0),
            c(-8.0, -11.97284610132391),
            c(-8.0, -19.313708498984759),
            c(-8.0, -40.218715937006785),
        ]
    }

    #[test]
    fn butterfly_radix8_impulse_all_ones() {
        // A unit impulse at index 0 must transform to a flat spectrum of ones.
        let mut buf = [Complex64::new(0.0, 0.0); 8];
        buf[0] = Complex64::new(1.0, 0.0);
        butterfly_radix8(&mut buf);
        let ones = [Complex64::new(1.0, 0.0); 8];
        assert_vec_close(&buf, &ones, 1e-15);
    }

    #[test]
    fn fft8_radix_matches_scipy_seq0_7() {
        let mut buf = ramp8();
        let (_re, _im) = fft8_radix(&mut buf).unwrap();
        let expected = scipy_fft_ref_8_seq_0_7();
        assert_vec_close(&buf, &expected, 1e-12);
    }

    #[test]
    fn block_fft_matches_scipy_seq0_7() {
        let mut data = ramp(8);
        let (_re, _im) = block_fft(&mut data).unwrap();
        let expected = scipy_fft_ref_8_seq_0_7();
        assert_vec_close(&data, &expected, 1e-12);
    }

    #[test]
    fn block_fft_matches_scipy_seq0_15() {
        let mut data = ramp(16);
        let (_re, _im) = block_fft(&mut data).unwrap();
        let expected = scipy_fft_ref_16_seq_0_15();
        assert_vec_close(&data, &expected, 1e-11);
    }

    #[test]
    fn radix8_exact() {
        // Compare the fixed-size kernel against the naive DFT of the same input.
        let input = ramp8();
        let mut buf = input;
        let (_, _) = fft8_radix(&mut buf).unwrap();
        let expected = dft_naive(&input);
        assert_vec_close(&buf, &expected, 1e-12);
    }

    #[test]
    fn block_fft_random_lengths() {
        const SIZES: [usize; 8] = [8, 16, 32, 64, 128, 256, 512, 1024];
        let mut rng = rand::rng();
        for &n in SIZES.iter() {
            let mut data: Vec<Complex64> = (0..n)
                .map(|_| Complex64::new(rng.random(), rng.random()))
                .collect();
            let expected = dft_naive(&data);
            let (_, _) = block_fft(&mut data).unwrap();
            assert_vec_close(&data, &expected, 1e-9);
        }
    }

    #[test]
    fn block_fft_power_of_two_check() {
        // 12 is not a power of two and must be rejected.
        let mut bad = vec![Complex64::new(0.0, 0.0); 12];
        assert!(block_fft(&mut bad).is_err());
    }
}