scirs2_fft/ndim.rs
1//! N-Dimensional FFT Utilities
2//!
3//! This module provides convenient wrappers and utilities for N-dimensional
4//! Fourier transforms operating directly on `ndarray` arrays of complex
5//! numbers, as well as 2-D shift helpers and N-D frequency bin generation.
6//!
7//! # Overview
8//!
9//! | Function | Description |
10//! |----------|-------------|
11//! | [`fftn_complex`] | N-D FFT on `ArrayD<Complex<f64>>` |
12//! | [`ifftn_complex`] | N-D inverse FFT on `ArrayD<Complex<f64>>` |
13//! | [`fftshift2`] | Move zero-frequency to the centre of a 2-D array |
14//! | [`ifftshift2`] | Inverse of [`fftshift2`] |
15//! | [`fftfreq_nd`] | Frequency bins for each axis of an N-D transform |
16//!
17//! ## Relationship to existing helpers
18//!
19//! * For generic `D`-dimensional arrays use [`crate::helper::fftshift`] /
20//! [`crate::helper::ifftshift`].
21//! * For standard `ArrayD<T>` with real input see [`crate::fft::fftn`] /
22//! [`crate::fft::ifftn`].
23//! * The functions here operate specifically on *complex-valued* `ArrayD` /
24//! `Array2` and expose a simpler axes-only interface.
25
26use crate::error::{FFTError, FFTResult};
27use crate::fft::{fft, ifft};
28use scirs2_core::ndarray::{Array2, ArrayD, Axis};
29use scirs2_core::numeric::Complex64;
30
31// ─────────────────────────────────────────────────────────────────────────────
32// fftn_complex / ifftn_complex
33// ─────────────────────────────────────────────────────────────────────────────
34
35/// N-dimensional FFT of a complex-valued array.
36///
37/// Applies a 1-D FFT along each axis listed in `axes` (or along all axes when
38/// `axes` is `None`), producing a complex output array of the same shape.
39///
40/// # Arguments
41///
42/// * `x` - Input complex array of any dimensionality.
43/// * `axes` - Axes to transform. `None` → transform all axes.
44///
45/// # Errors
46///
47/// Returns an error if any axis index is out of bounds.
48///
49/// # Examples
50///
51/// ```rust
52/// use scirs2_fft::ndim::fftn_complex;
53/// use scirs2_core::ndarray::{ArrayD, IxDyn};
54/// use scirs2_core::numeric::Complex64;
55///
56/// // 2 × 4 complex array
57/// let data: Vec<Complex64> = (0..8).map(|i| Complex64::new(i as f64, 0.0)).collect();
58/// let x = ArrayD::from_shape_vec(IxDyn(&[2, 4]), data).expect("shape ok");
59///
60/// let spectrum = fftn_complex(&x, None).expect("fftn failed");
61/// assert_eq!(spectrum.shape(), x.shape());
62/// ```
63pub fn fftn_complex(x: &ArrayD<Complex64>, axes: Option<&[usize]>) -> FFTResult<ArrayD<Complex64>> {
64 let ndim = x.ndim();
65 let axes_to_transform: Vec<usize> = match axes {
66 Some(a) => {
67 for &ax in a {
68 if ax >= ndim {
69 return Err(FFTError::ValueError(format!(
70 "axis {ax} out of bounds for array of ndim={ndim}"
71 )));
72 }
73 }
74 a.to_vec()
75 }
76 None => (0..ndim).collect(),
77 };
78
79 let mut result = x.to_owned();
80 for ax in axes_to_transform {
81 apply_fft1d_along_axis(&mut result, ax, false)?;
82 }
83 Ok(result)
84}
85
86/// N-dimensional inverse FFT of a complex-valued array.
87///
88/// Applies a 1-D inverse FFT along each axis listed in `axes` (or along all
89/// axes when `axes` is `None`).
90///
91/// # Arguments
92///
93/// * `x` - Input complex array.
94/// * `axes` - Axes to transform inversely. `None` → transform all axes.
95///
96/// # Errors
97///
98/// Returns an error if any axis index is out of bounds.
99///
100/// # Examples
101///
102/// ```rust
103/// use scirs2_fft::ndim::{fftn_complex, ifftn_complex};
104/// use scirs2_core::ndarray::{ArrayD, IxDyn};
105/// use scirs2_core::numeric::Complex64;
106///
107/// let data: Vec<Complex64> = (0..8).map(|i| Complex64::new(i as f64, 0.0)).collect();
108/// let x = ArrayD::from_shape_vec(IxDyn(&[2, 4]), data).expect("shape ok");
109///
110/// let spectrum = fftn_complex(&x, None).expect("fftn failed");
111/// let recovered = ifftn_complex(&spectrum, None).expect("ifftn failed");
112///
113/// // Round-trip should recover the original (within floating-point tolerance)
114/// for (a, b) in x.iter().zip(recovered.iter()) {
115/// assert!((a.re - b.re).abs() < 1e-10);
116/// assert!((a.im - b.im).abs() < 1e-10);
117/// }
118/// ```
119pub fn ifftn_complex(
120 x: &ArrayD<Complex64>,
121 axes: Option<&[usize]>,
122) -> FFTResult<ArrayD<Complex64>> {
123 let ndim = x.ndim();
124 let axes_to_transform: Vec<usize> = match axes {
125 Some(a) => {
126 for &ax in a {
127 if ax >= ndim {
128 return Err(FFTError::ValueError(format!(
129 "axis {ax} out of bounds for array of ndim={ndim}"
130 )));
131 }
132 }
133 a.to_vec()
134 }
135 None => (0..ndim).collect(),
136 };
137
138 let mut result = x.to_owned();
139 for ax in axes_to_transform {
140 apply_fft1d_along_axis(&mut result, ax, true)?;
141 }
142 Ok(result)
143}
144
145// ─────────────────────────────────────────────────────────────────────────────
146// 2-D shift helpers
147// ─────────────────────────────────────────────────────────────────────────────
148
149/// Shift the zero-frequency component to the centre of a 2-D complex array.
150///
151/// For a 2-D FFT output of shape `(M, N)` the DC component is at `[0, 0]`.
152/// `fftshift2` moves it to the centre position `[M/2, N/2]` (integer division),
153/// which is the natural representation for visualisation.
154///
155/// # Examples
156///
157/// ```rust
158/// use scirs2_fft::ndim::fftshift2;
159/// use scirs2_core::ndarray::Array2;
160/// use scirs2_core::numeric::Complex64;
161///
162/// // 4×4 array where position [0,0] has value 1 (DC component)
163/// let mut data = Array2::<Complex64>::zeros((4, 4));
164/// data[[0, 0]] = Complex64::new(1.0, 0.0);
165///
166/// let shifted = fftshift2(&data);
167/// // After shift the DC component is at [2, 2]
168/// assert!((shifted[[2, 2]].re - 1.0).abs() < 1e-12);
169/// ```
170pub fn fftshift2(x: &Array2<Complex64>) -> Array2<Complex64> {
171 shift2_impl(x, false)
172}
173
174/// Inverse of [`fftshift2`]: move the zero-frequency back to position `[0, 0]`.
175///
176/// # Examples
177///
178/// ```rust
179/// use scirs2_fft::ndim::{fftshift2, ifftshift2};
180/// use scirs2_core::ndarray::Array2;
181/// use scirs2_core::numeric::Complex64;
182///
183/// let mut data = Array2::<Complex64>::zeros((4, 4));
184/// data[[0, 0]] = Complex64::new(1.0, 0.0);
185///
186/// let shifted = fftshift2(&data);
187/// let recovered = ifftshift2(&shifted);
188/// assert!((recovered[[0, 0]].re - 1.0).abs() < 1e-12);
189/// ```
190pub fn ifftshift2(x: &Array2<Complex64>) -> Array2<Complex64> {
191 shift2_impl(x, true)
192}
193
194// ─────────────────────────────────────────────────────────────────────────────
195// Frequency bins for N-D FFT
196// ─────────────────────────────────────────────────────────────────────────────
197
198/// Compute frequency bins for each axis of an N-dimensional FFT.
199///
200/// Returns a vector (one entry per axis) of frequency bin arrays in cycles per
201/// unit, using the per-axis sample spacings supplied in `d`. This generalises
202/// [`crate::helper::fftfreq`] to multiple axes at once.
203///
204/// # Arguments
205///
206/// * `shape` - Shape of the N-D array (one entry per dimension).
207/// * `d` - Sample spacing for each dimension. Must have the same length as
208/// `shape`; a value of `1.0` gives frequencies in cycles/sample.
209///
210/// # Returns
211///
212/// `Vec<Vec<f64>>` where `result[i]` contains the `shape[i]` frequency values
213/// for axis `i`.
214///
215/// # Errors
216///
217/// Returns an error if `shape.len() != d.len()` or if any spacing is ≤ 0.
218///
219/// # Examples
220///
221/// ```rust
222/// use scirs2_fft::ndim::fftfreq_nd;
223///
224/// // 4×8 array, sample spacing 0.5 in first axis and 1.0 in second
225/// let freqs = fftfreq_nd(&[4, 8], &[0.5, 1.0]).expect("fftfreq_nd failed");
226///
227/// assert_eq!(freqs.len(), 2);
228/// assert_eq!(freqs[0].len(), 4);
229/// assert_eq!(freqs[1].len(), 8);
230///
231/// // DC component is always 0
232/// assert_eq!(freqs[0][0], 0.0);
233/// assert_eq!(freqs[1][0], 0.0);
234/// ```
235pub fn fftfreq_nd(shape: &[usize], d: &[f64]) -> FFTResult<Vec<Vec<f64>>> {
236 if shape.len() != d.len() {
237 return Err(FFTError::ValueError(format!(
238 "shape.len()={} must equal d.len()={}",
239 shape.len(),
240 d.len()
241 )));
242 }
243 for (i, &spacing) in d.iter().enumerate() {
244 if spacing <= 0.0 {
245 return Err(FFTError::ValueError(format!(
246 "sample spacing d[{i}]={spacing} must be > 0"
247 )));
248 }
249 }
250
251 shape
252 .iter()
253 .zip(d.iter())
254 .map(|(&n, &spacing)| fftfreq_1d(n, spacing))
255 .collect()
256}
257
258// ─────────────────────────────────────────────────────────────────────────────
259// Private helpers
260// ─────────────────────────────────────────────────────────────────────────────
261
262/// Apply a 1-D FFT or IFFT along the given axis of a dynamic-dim complex array.
263fn apply_fft1d_along_axis(
264 data: &mut ArrayD<Complex64>,
265 axis: usize,
266 inverse: bool,
267) -> FFTResult<()> {
268 let axis_len = data.shape()[axis];
269 let mut buf = vec![Complex64::new(0.0, 0.0); axis_len];
270
271 for mut lane in data.lanes_mut(Axis(axis)) {
272 buf.iter_mut().zip(lane.iter()).for_each(|(b, &x)| *b = x);
273
274 // Pass explicit size to avoid auto-padding to next power of two
275 let n = buf.len();
276 let transformed = if inverse {
277 ifft(&buf, Some(n))?
278 } else {
279 fft(&buf, Some(n))?
280 };
281
282 lane.iter_mut()
283 .zip(transformed.iter())
284 .for_each(|(d, &s)| *d = s);
285 }
286 Ok(())
287}
288
289/// Shared implementation for fftshift2 / ifftshift2.
290///
291/// `inverse = false` → forward shift (DC to centre).
292/// `inverse = true` → inverse shift (centre to DC).
293fn shift2_impl(x: &Array2<Complex64>, inverse: bool) -> Array2<Complex64> {
294 let (rows, cols) = x.dim();
295 let row_shift = if inverse {
296 // For odd n: forward shift by n/2 (floor), inverse by ceil
297 rows - rows / 2
298 } else {
299 rows / 2
300 };
301 let col_shift = if inverse { cols - cols / 2 } else { cols / 2 };
302
303 let mut out = Array2::<Complex64>::zeros((rows, cols));
304 for r in 0..rows {
305 let new_r = (r + row_shift) % rows;
306 for c in 0..cols {
307 let new_c = (c + col_shift) % cols;
308 out[[new_r, new_c]] = x[[r, c]];
309 }
310 }
311 out
312}
313
314/// 1-D fftfreq: frequency values for n samples with spacing d.
315///
316/// Matches the convention of `numpy.fft.fftfreq` / `scipy.fft.fftfreq`:
317/// - Even n: `[0, 1, ..., n/2-1, -n/2, -(n/2-1), ..., -1] / (n * d)`
318/// - Odd n: `[0, 1, ..., (n-1)/2, -((n-1)/2), ..., -1] / (n * d)`
319fn fftfreq_1d(n: usize, d: f64) -> FFTResult<Vec<f64>> {
320 if n == 0 {
321 return Ok(Vec::new());
322 }
323 let scale = 1.0 / (n as f64 * d);
324
325 let mut freqs = Vec::with_capacity(n);
326 let p = (n / 2) as i64; // positive half length (floor(n/2))
327
328 // Positive frequencies: 0, 1, ..., p (for even n, p = n/2; for odd, p = (n-1)/2)
329 // But for even n the Nyquist bin n/2 is represented as *negative* (-n/2)
330 for i in 0..n as i64 {
331 let k = if i <= p - (if n % 2 == 0 { 1 } else { 0 }) as i64 {
332 // Positive frequencies: 0 .. floor((n-1)/2)
333 i
334 } else {
335 // Negative frequencies: -floor(n/2) .. -1
336 i - n as i64
337 };
338 freqs.push(k as f64 * scale);
339 }
340 Ok(freqs)
341}
342
343// ─────────────────────────────────────────────────────────────────────────────
344// Tests
345// ─────────────────────────────────────────────────────────────────────────────
346
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::IxDyn;
    use std::f64::consts::PI;

    // ── Shared fixtures ──────────────────────────────────────────────────────

    /// Deterministic complex array: element `i` (row-major) is `i - 0.5·i·j`.
    fn make_complex_array(shape: &[usize]) -> ArrayD<Complex64> {
        let n: usize = shape.iter().product();
        let data: Vec<Complex64> = (0..n)
            .map(|i| Complex64::new(i as f64, -(i as f64) * 0.5))
            .collect();
        ArrayD::from_shape_vec(IxDyn(shape), data).expect("shape ok")
    }

    /// Assert that two element sequences have the same length and agree in
    /// both real and imaginary parts to within `tol`.
    ///
    /// Unlike a bare `zip`, a length mismatch fails the test instead of being
    /// silently truncated.
    fn assert_complex_close<'a, E, A>(expected: E, actual: A, tol: f64)
    where
        E: IntoIterator<Item = &'a Complex64>,
        A: IntoIterator<Item = &'a Complex64>,
    {
        let mut actual = actual.into_iter();
        for want in expected {
            let got = actual
                .next()
                .expect("actual has fewer elements than expected");
            assert_relative_eq!(want.re, got.re, epsilon = tol);
            assert_relative_eq!(want.im, got.im, epsilon = tol);
        }
        assert!(
            actual.next().is_none(),
            "actual has more elements than expected"
        );
    }

    // ── fftn_complex / ifftn_complex roundtrip ───────────────────────────────

    #[test]
    fn test_fftn_ifftn_roundtrip_1d() {
        let x = make_complex_array(&[16]);
        let s = fftn_complex(&x, None).expect("fftn");
        let r = ifftn_complex(&s, None).expect("ifftn");
        assert_complex_close(x.iter(), r.iter(), 1e-10);
    }

    #[test]
    fn test_fftn_ifftn_roundtrip_2d() {
        let x = make_complex_array(&[4, 8]);
        let s = fftn_complex(&x, None).expect("fftn 2d");
        let r = ifftn_complex(&s, None).expect("ifftn 2d");
        assert_complex_close(x.iter(), r.iter(), 1e-10);
    }

    #[test]
    fn test_fftn_ifftn_roundtrip_3d() {
        let x = make_complex_array(&[2, 3, 4]);
        let s = fftn_complex(&x, None).expect("fftn 3d");
        let r = ifftn_complex(&s, None).expect("ifftn 3d");
        assert_complex_close(x.iter(), r.iter(), 1e-10);
    }

    #[test]
    fn test_fftn_partial_axes() {
        let x = make_complex_array(&[4, 8]);
        // Only transform axis 1
        let s1 = fftn_complex(&x, Some(&[1])).expect("fftn axis 1");
        let r1 = ifftn_complex(&s1, Some(&[1])).expect("ifftn axis 1");
        assert_complex_close(x.iter(), r1.iter(), 1e-10);
    }

    #[test]
    fn test_fftn_out_of_bounds_axis() {
        let x = make_complex_array(&[4, 8]);
        assert!(fftn_complex(&x, Some(&[2])).is_err()); // only 2 axes (0, 1)
        assert!(ifftn_complex(&x, Some(&[5])).is_err());
    }

    #[test]
    fn test_fftn_shape_preserved() {
        let x = make_complex_array(&[3, 5, 7]);
        let s = fftn_complex(&x, None).expect("fftn");
        assert_eq!(s.shape(), x.shape());
    }

    // ── fftshift2 / ifftshift2 ───────────────────────────────────────────────

    #[test]
    fn test_fftshift2_roundtrip_even() {
        let (rows, cols) = (4_usize, 6_usize);
        let data: Vec<Complex64> = (0..rows * cols)
            .map(|i| Complex64::new(i as f64, 0.0))
            .collect();
        let x = Array2::from_shape_vec((rows, cols), data).expect("shape");
        let shifted = fftshift2(&x);
        let recovered = ifftshift2(&shifted);
        assert_eq!(recovered.dim(), x.dim());
        assert_complex_close(x.iter(), recovered.iter(), 1e-12);
    }

    #[test]
    fn test_fftshift2_roundtrip_odd() {
        let (rows, cols) = (5_usize, 7_usize);
        let data: Vec<Complex64> = (0..rows * cols)
            .map(|i| Complex64::new(i as f64, i as f64 * 0.1))
            .collect();
        let x = Array2::from_shape_vec((rows, cols), data).expect("shape");
        let shifted = fftshift2(&x);
        let recovered = ifftshift2(&shifted);
        assert_eq!(recovered.dim(), x.dim());
        assert_complex_close(x.iter(), recovered.iter(), 1e-12);
    }

    #[test]
    fn test_fftshift2_dc_to_centre() {
        let mut data = Array2::<Complex64>::zeros((4, 4));
        data[[0, 0]] = Complex64::new(1.0, 0.0);
        let shifted = fftshift2(&data);
        // For n=4, shift = 2 → DC moves to [2, 2]
        assert_relative_eq!(shifted[[2, 2]].re, 1.0, epsilon = 1e-12);
        assert_relative_eq!(shifted[[0, 0]].re, 0.0, epsilon = 1e-12);
    }

    #[test]
    fn test_ifftshift2_dc_back() {
        let mut data = Array2::<Complex64>::zeros((4, 4));
        data[[0, 0]] = Complex64::new(1.0, 0.0);
        let shifted = fftshift2(&data);
        let recovered = ifftshift2(&shifted);
        assert_relative_eq!(recovered[[0, 0]].re, 1.0, epsilon = 1e-12);
    }

    // ── fftfreq_nd ───────────────────────────────────────────────────────────

    #[test]
    fn test_fftfreq_nd_basic() {
        let freqs = fftfreq_nd(&[4, 8], &[1.0, 1.0]).expect("fftfreq_nd");
        assert_eq!(freqs.len(), 2);
        assert_eq!(freqs[0].len(), 4);
        assert_eq!(freqs[1].len(), 8);
        // DC is always 0
        assert_relative_eq!(freqs[0][0], 0.0, epsilon = 1e-15);
        assert_relative_eq!(freqs[1][0], 0.0, epsilon = 1e-15);
    }

    #[test]
    fn test_fftfreq_nd_matches_1d_fftfreq() {
        // Compare with the scalar fftfreq from crate::helper
        use crate::helper::fftfreq;
        let n = 16;
        let d = 0.5;
        let nd_freqs = fftfreq_nd(&[n], &[d]).expect("nd");
        let scalar_freqs = fftfreq(n, d).expect("1d");
        assert_eq!(nd_freqs[0].len(), scalar_freqs.len());
        for (a, b) in nd_freqs[0].iter().zip(scalar_freqs.iter()) {
            assert_relative_eq!(*a, *b, epsilon = 1e-14);
        }
    }

    #[test]
    fn test_fftfreq_nd_spacing() {
        // With d=0.5 the max positive frequency doubles compared to d=1.0
        let f1 = fftfreq_nd(&[8], &[1.0]).expect("d=1");
        let f2 = fftfreq_nd(&[8], &[0.5]).expect("d=0.5");
        // Max positive freq for n=8, d=1: 3/8; for d=0.5: 3/4
        assert_relative_eq!(f1[0][3], 3.0 / 8.0, epsilon = 1e-14);
        assert_relative_eq!(f2[0][3], 3.0 / 4.0, epsilon = 1e-14);
    }

    #[test]
    fn test_fftfreq_nd_mismatch_error() {
        assert!(fftfreq_nd(&[4, 8], &[1.0]).is_err()); // lengths differ
        assert!(fftfreq_nd(&[4], &[0.0]).is_err()); // zero spacing
        assert!(fftfreq_nd(&[4], &[-1.0]).is_err()); // negative spacing
    }

    #[test]
    fn test_fftfreq_nd_empty_axis() {
        let freqs = fftfreq_nd(&[0, 4], &[1.0, 1.0]).expect("empty axis ok");
        assert_eq!(freqs[0].len(), 0);
        assert_eq!(freqs[1].len(), 4);
    }

    // ── Correctness: 2D FFT shift is consistent with element-wise check ──────

    #[test]
    fn test_fftshift2_known_pattern() {
        // Build a 4×4 array with one marker per quadrant origin
        let mut x = Array2::<Complex64>::zeros((4, 4));
        x[[0, 0]] = Complex64::new(1.0, 0.0); // top-left (DC)
        x[[0, 2]] = Complex64::new(2.0, 0.0); // top-right region
        x[[2, 0]] = Complex64::new(3.0, 0.0); // bottom-left region
        x[[2, 2]] = Complex64::new(4.0, 0.0); // bottom-right region

        let shifted = fftshift2(&x);
        // For n=4 (even), shift = 2 → each element at [r,c] moves to [(r+2)%4, (c+2)%4]
        assert_relative_eq!(shifted[[2, 2]].re, 1.0, epsilon = 1e-12); // was [0,0]
        assert_relative_eq!(shifted[[2, 0]].re, 2.0, epsilon = 1e-12); // was [0,2]
        assert_relative_eq!(shifted[[0, 2]].re, 3.0, epsilon = 1e-12); // was [2,0]
        assert_relative_eq!(shifted[[0, 0]].re, 4.0, epsilon = 1e-12); // was [2,2]
    }

    // ── Integration: fftn + fftshift2 on a sinusoidal image ─────────────────

    #[test]
    fn test_fftn_then_shift_preserves_energy() {
        let n: usize = 8;
        // Simple 2D sinusoid
        let data: Vec<Complex64> = (0..n * n)
            .map(|k| {
                let r = k / n;
                let c = k % n;
                let re =
                    (2.0 * PI * r as f64 / n as f64).cos() * (2.0 * PI * c as f64 / n as f64).cos();
                Complex64::new(re, 0.0)
            })
            .collect();
        let x = ArrayD::from_shape_vec(IxDyn(&[n, n]), data).expect("shape");
        let spec = fftn_complex(&x, None).expect("fftn");

        // Parseval: sum |X[k]|^2 = N * sum |x[j]|^2, with N = n * n samples
        let energy_x: f64 = x.iter().map(|c| c.norm_sqr()).sum();
        let energy_s: f64 = spec.iter().map(|c| c.norm_sqr()).sum();
        let n2 = (n * n) as f64;
        assert_relative_eq!(energy_s, n2 * energy_x, epsilon = 1e-8 * energy_s.max(1.0));

        // The shift only permutes elements, so the spectral energy must survive it.
        let spec2 = Array2::from_shape_vec((n, n), spec.iter().copied().collect())
            .expect("2-D view of spectrum");
        let shifted = fftshift2(&spec2);
        let energy_shifted: f64 = shifted.iter().map(|c| c.norm_sqr()).sum();
        assert_relative_eq!(
            energy_shifted,
            energy_s,
            epsilon = 1e-8 * energy_s.max(1.0)
        );
    }
}