Skip to main content

scirs2_fft/
ndim.rs

//! N-Dimensional FFT Utilities
//!
//! This module provides convenient wrappers and utilities for N-dimensional
//! Fourier transforms operating directly on `ndarray` arrays of complex
//! numbers, as well as 2-D shift helpers and N-D frequency bin generation.
//!
//! # Overview
//!
//! | Function | Description |
//! |----------|-------------|
//! | [`fftn_complex`]  | N-D FFT on `ArrayD<Complex<f64>>` |
//! | [`ifftn_complex`] | N-D inverse FFT on `ArrayD<Complex<f64>>` |
//! | [`fftshift2`]     | Move zero-frequency to the centre of a 2-D array |
//! | [`ifftshift2`]    | Inverse of [`fftshift2`] |
//! | [`fftfreq_nd`]    | Frequency bins for each axis of an N-D transform |
//!
//! ## Relationship to existing helpers
//!
//! * For generic `D`-dimensional arrays use [`crate::helper::fftshift`] /
//!   [`crate::helper::ifftshift`].
//! * For standard `ArrayD<T>` with real input see [`crate::fft::fftn`] /
//!   [`crate::fft::ifftn`].
//! * The functions here operate specifically on *complex-valued* `ArrayD` /
//!   `Array2` and expose a simpler axes-only interface.
25
26use crate::error::{FFTError, FFTResult};
27use crate::fft::{fft, ifft};
28use scirs2_core::ndarray::{Array2, ArrayD, Axis};
29use scirs2_core::numeric::Complex64;
30
// ─────────────────────────────────────────────────────────────────────────────
//  fftn_complex / ifftn_complex
// ─────────────────────────────────────────────────────────────────────────────
34
35/// N-dimensional FFT of a complex-valued array.
36///
37/// Applies a 1-D FFT along each axis listed in `axes` (or along all axes when
38/// `axes` is `None`), producing a complex output array of the same shape.
39///
40/// # Arguments
41///
42/// * `x`    - Input complex array of any dimensionality.
43/// * `axes` - Axes to transform.  `None` → transform all axes.
44///
45/// # Errors
46///
47/// Returns an error if any axis index is out of bounds.
48///
49/// # Examples
50///
51/// ```rust
52/// use scirs2_fft::ndim::fftn_complex;
53/// use scirs2_core::ndarray::{ArrayD, IxDyn};
54/// use scirs2_core::numeric::Complex64;
55///
56/// // 2 × 4 complex array
57/// let data: Vec<Complex64> = (0..8).map(|i| Complex64::new(i as f64, 0.0)).collect();
58/// let x = ArrayD::from_shape_vec(IxDyn(&[2, 4]), data).expect("shape ok");
59///
60/// let spectrum = fftn_complex(&x, None).expect("fftn failed");
61/// assert_eq!(spectrum.shape(), x.shape());
62/// ```
63pub fn fftn_complex(x: &ArrayD<Complex64>, axes: Option<&[usize]>) -> FFTResult<ArrayD<Complex64>> {
64    let ndim = x.ndim();
65    let axes_to_transform: Vec<usize> = match axes {
66        Some(a) => {
67            for &ax in a {
68                if ax >= ndim {
69                    return Err(FFTError::ValueError(format!(
70                        "axis {ax} out of bounds for array of ndim={ndim}"
71                    )));
72                }
73            }
74            a.to_vec()
75        }
76        None => (0..ndim).collect(),
77    };
78
79    let mut result = x.to_owned();
80    for ax in axes_to_transform {
81        apply_fft1d_along_axis(&mut result, ax, false)?;
82    }
83    Ok(result)
84}
85
86/// N-dimensional inverse FFT of a complex-valued array.
87///
88/// Applies a 1-D inverse FFT along each axis listed in `axes` (or along all
89/// axes when `axes` is `None`).
90///
91/// # Arguments
92///
93/// * `x`    - Input complex array.
94/// * `axes` - Axes to transform inversely.  `None` → transform all axes.
95///
96/// # Errors
97///
98/// Returns an error if any axis index is out of bounds.
99///
100/// # Examples
101///
102/// ```rust
103/// use scirs2_fft::ndim::{fftn_complex, ifftn_complex};
104/// use scirs2_core::ndarray::{ArrayD, IxDyn};
105/// use scirs2_core::numeric::Complex64;
106///
107/// let data: Vec<Complex64> = (0..8).map(|i| Complex64::new(i as f64, 0.0)).collect();
108/// let x = ArrayD::from_shape_vec(IxDyn(&[2, 4]), data).expect("shape ok");
109///
110/// let spectrum  = fftn_complex(&x, None).expect("fftn failed");
111/// let recovered = ifftn_complex(&spectrum, None).expect("ifftn failed");
112///
113/// // Round-trip should recover the original (within floating-point tolerance)
114/// for (a, b) in x.iter().zip(recovered.iter()) {
115///     assert!((a.re - b.re).abs() < 1e-10);
116///     assert!((a.im - b.im).abs() < 1e-10);
117/// }
118/// ```
119pub fn ifftn_complex(
120    x: &ArrayD<Complex64>,
121    axes: Option<&[usize]>,
122) -> FFTResult<ArrayD<Complex64>> {
123    let ndim = x.ndim();
124    let axes_to_transform: Vec<usize> = match axes {
125        Some(a) => {
126            for &ax in a {
127                if ax >= ndim {
128                    return Err(FFTError::ValueError(format!(
129                        "axis {ax} out of bounds for array of ndim={ndim}"
130                    )));
131                }
132            }
133            a.to_vec()
134        }
135        None => (0..ndim).collect(),
136    };
137
138    let mut result = x.to_owned();
139    for ax in axes_to_transform {
140        apply_fft1d_along_axis(&mut result, ax, true)?;
141    }
142    Ok(result)
143}
144
// ─────────────────────────────────────────────────────────────────────────────
//  2-D shift helpers
// ─────────────────────────────────────────────────────────────────────────────
148
149/// Shift the zero-frequency component to the centre of a 2-D complex array.
150///
151/// For a 2-D FFT output of shape `(M, N)` the DC component is at `[0, 0]`.
152/// `fftshift2` moves it to the centre position `[M/2, N/2]` (integer division),
153/// which is the natural representation for visualisation.
154///
155/// # Examples
156///
157/// ```rust
158/// use scirs2_fft::ndim::fftshift2;
159/// use scirs2_core::ndarray::Array2;
160/// use scirs2_core::numeric::Complex64;
161///
162/// // 4×4 array where position [0,0] has value 1 (DC component)
163/// let mut data = Array2::<Complex64>::zeros((4, 4));
164/// data[[0, 0]] = Complex64::new(1.0, 0.0);
165///
166/// let shifted = fftshift2(&data);
167/// // After shift the DC component is at [2, 2]
168/// assert!((shifted[[2, 2]].re - 1.0).abs() < 1e-12);
169/// ```
170pub fn fftshift2(x: &Array2<Complex64>) -> Array2<Complex64> {
171    shift2_impl(x, false)
172}
173
174/// Inverse of [`fftshift2`]: move the zero-frequency back to position `[0, 0]`.
175///
176/// # Examples
177///
178/// ```rust
179/// use scirs2_fft::ndim::{fftshift2, ifftshift2};
180/// use scirs2_core::ndarray::Array2;
181/// use scirs2_core::numeric::Complex64;
182///
183/// let mut data = Array2::<Complex64>::zeros((4, 4));
184/// data[[0, 0]] = Complex64::new(1.0, 0.0);
185///
186/// let shifted   = fftshift2(&data);
187/// let recovered = ifftshift2(&shifted);
188/// assert!((recovered[[0, 0]].re - 1.0).abs() < 1e-12);
189/// ```
190pub fn ifftshift2(x: &Array2<Complex64>) -> Array2<Complex64> {
191    shift2_impl(x, true)
192}
193
// ─────────────────────────────────────────────────────────────────────────────
//  Frequency bins for N-D FFT
// ─────────────────────────────────────────────────────────────────────────────
197
198/// Compute frequency bins for each axis of an N-dimensional FFT.
199///
200/// Returns a vector (one entry per axis) of frequency bin arrays in cycles per
201/// unit, using the per-axis sample spacings supplied in `d`.  This generalises
202/// [`crate::helper::fftfreq`] to multiple axes at once.
203///
204/// # Arguments
205///
206/// * `shape` - Shape of the N-D array (one entry per dimension).
207/// * `d`     - Sample spacing for each dimension.  Must have the same length as
208/// `shape`; a value of `1.0` gives frequencies in cycles/sample.
209///
210/// # Returns
211///
212/// `Vec<Vec<f64>>` where `result[i]` contains the `shape[i]` frequency values
213/// for axis `i`.
214///
215/// # Errors
216///
217/// Returns an error if `shape.len() != d.len()` or if any spacing is ≤ 0.
218///
219/// # Examples
220///
221/// ```rust
222/// use scirs2_fft::ndim::fftfreq_nd;
223///
224/// // 4×8 array, sample spacing 0.5 in first axis and 1.0 in second
225/// let freqs = fftfreq_nd(&[4, 8], &[0.5, 1.0]).expect("fftfreq_nd failed");
226///
227/// assert_eq!(freqs.len(), 2);
228/// assert_eq!(freqs[0].len(), 4);
229/// assert_eq!(freqs[1].len(), 8);
230///
231/// // DC component is always 0
232/// assert_eq!(freqs[0][0], 0.0);
233/// assert_eq!(freqs[1][0], 0.0);
234/// ```
235pub fn fftfreq_nd(shape: &[usize], d: &[f64]) -> FFTResult<Vec<Vec<f64>>> {
236    if shape.len() != d.len() {
237        return Err(FFTError::ValueError(format!(
238            "shape.len()={} must equal d.len()={}",
239            shape.len(),
240            d.len()
241        )));
242    }
243    for (i, &spacing) in d.iter().enumerate() {
244        if spacing <= 0.0 {
245            return Err(FFTError::ValueError(format!(
246                "sample spacing d[{i}]={spacing} must be > 0"
247            )));
248        }
249    }
250
251    shape
252        .iter()
253        .zip(d.iter())
254        .map(|(&n, &spacing)| fftfreq_1d(n, spacing))
255        .collect()
256}
257
// ─────────────────────────────────────────────────────────────────────────────
//  Private helpers
// ─────────────────────────────────────────────────────────────────────────────
261
262/// Apply a 1-D FFT or IFFT along the given axis of a dynamic-dim complex array.
263fn apply_fft1d_along_axis(
264    data: &mut ArrayD<Complex64>,
265    axis: usize,
266    inverse: bool,
267) -> FFTResult<()> {
268    let axis_len = data.shape()[axis];
269    let mut buf = vec![Complex64::new(0.0, 0.0); axis_len];
270
271    for mut lane in data.lanes_mut(Axis(axis)) {
272        buf.iter_mut().zip(lane.iter()).for_each(|(b, &x)| *b = x);
273
274        // Pass explicit size to avoid auto-padding to next power of two
275        let n = buf.len();
276        let transformed = if inverse {
277            ifft(&buf, Some(n))?
278        } else {
279            fft(&buf, Some(n))?
280        };
281
282        lane.iter_mut()
283            .zip(transformed.iter())
284            .for_each(|(d, &s)| *d = s);
285    }
286    Ok(())
287}
288
289/// Shared implementation for fftshift2 / ifftshift2.
290///
291/// `inverse = false` → forward shift (DC to centre).
292/// `inverse = true`  → inverse shift (centre to DC).
293fn shift2_impl(x: &Array2<Complex64>, inverse: bool) -> Array2<Complex64> {
294    let (rows, cols) = x.dim();
295    let row_shift = if inverse {
296        // For odd n: forward shift by n/2 (floor), inverse by ceil
297        rows - rows / 2
298    } else {
299        rows / 2
300    };
301    let col_shift = if inverse { cols - cols / 2 } else { cols / 2 };
302
303    let mut out = Array2::<Complex64>::zeros((rows, cols));
304    for r in 0..rows {
305        let new_r = (r + row_shift) % rows;
306        for c in 0..cols {
307            let new_c = (c + col_shift) % cols;
308            out[[new_r, new_c]] = x[[r, c]];
309        }
310    }
311    out
312}
313
314/// 1-D fftfreq: frequency values for n samples with spacing d.
315///
316/// Matches the convention of `numpy.fft.fftfreq` / `scipy.fft.fftfreq`:
317/// - Even n: `[0, 1, ..., n/2-1, -n/2, -(n/2-1), ..., -1] / (n * d)`
318/// - Odd  n: `[0, 1, ..., (n-1)/2, -((n-1)/2), ..., -1] / (n * d)`
319fn fftfreq_1d(n: usize, d: f64) -> FFTResult<Vec<f64>> {
320    if n == 0 {
321        return Ok(Vec::new());
322    }
323    let scale = 1.0 / (n as f64 * d);
324
325    let mut freqs = Vec::with_capacity(n);
326    let p = (n / 2) as i64; // positive half length (floor(n/2))
327
328    // Positive frequencies: 0, 1, ..., p  (for even n, p = n/2; for odd, p = (n-1)/2)
329    // But for even n the Nyquist bin n/2 is represented as *negative* (-n/2)
330    for i in 0..n as i64 {
331        let k = if i <= p - (if n % 2 == 0 { 1 } else { 0 }) as i64 {
332            // Positive frequencies: 0 .. floor((n-1)/2)
333            i
334        } else {
335            // Negative frequencies: -floor(n/2) .. -1
336            i - n as i64
337        };
338        freqs.push(k as f64 * scale);
339    }
340    Ok(freqs)
341}
342
// ─────────────────────────────────────────────────────────────────────────────
//  Tests
// ─────────────────────────────────────────────────────────────────────────────
346
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::IxDyn;
    use std::f64::consts::PI;

    // ── Shared helpers ───────────────────────────────────────────────────────

    /// Builds an array of the given shape whose k-th element (in row-major
    /// order) is `k - 0.5·k·i`, so both real and imaginary parts are exercised.
    fn make_complex_array(shape: &[usize]) -> ArrayD<Complex64> {
        let n: usize = shape.iter().product();
        let data: Vec<Complex64> = (0..n)
            .map(|i| Complex64::new(i as f64, -(i as f64) * 0.5))
            .collect();
        ArrayD::from_shape_vec(IxDyn(shape), data).expect("shape ok")
    }

    /// Asserts that two complex arrays have the same shape and agree
    /// element-wise to within `1e-10`.
    fn assert_complex_close(expected: &ArrayD<Complex64>, actual: &ArrayD<Complex64>) {
        // Guard against `zip` silently truncating on a shape mismatch.
        assert_eq!(expected.shape(), actual.shape());
        for (a, b) in expected.iter().zip(actual.iter()) {
            assert_relative_eq!(a.re, b.re, epsilon = 1e-10);
            assert_relative_eq!(a.im, b.im, epsilon = 1e-10);
        }
    }

    // ── fftn_complex / ifftn_complex roundtrip ───────────────────────────────

    #[test]
    fn test_fftn_ifftn_roundtrip_1d() {
        let x = make_complex_array(&[16]);
        let s = fftn_complex(&x, None).expect("fftn");
        let r = ifftn_complex(&s, None).expect("ifftn");
        assert_complex_close(&x, &r);
    }

    #[test]
    fn test_fftn_ifftn_roundtrip_2d() {
        let x = make_complex_array(&[4, 8]);
        let s = fftn_complex(&x, None).expect("fftn 2d");
        let r = ifftn_complex(&s, None).expect("ifftn 2d");
        assert_complex_close(&x, &r);
    }

    #[test]
    fn test_fftn_ifftn_roundtrip_3d() {
        // Includes a non-power-of-two axis (3).
        let x = make_complex_array(&[2, 3, 4]);
        let s = fftn_complex(&x, None).expect("fftn 3d");
        let r = ifftn_complex(&s, None).expect("ifftn 3d");
        assert_complex_close(&x, &r);
    }

    #[test]
    fn test_fftn_partial_axes() {
        let x = make_complex_array(&[4, 8]);
        // Only transform axis 1
        let s1 = fftn_complex(&x, Some(&[1])).expect("fftn axis 1");
        let r1 = ifftn_complex(&s1, Some(&[1])).expect("ifftn axis 1");
        assert_complex_close(&x, &r1);
    }

    #[test]
    fn test_fftn_out_of_bounds_axis() {
        let x = make_complex_array(&[4, 8]);
        assert!(fftn_complex(&x, Some(&[2])).is_err()); // only 2 axes (0, 1)
        assert!(ifftn_complex(&x, Some(&[5])).is_err());
    }

    #[test]
    fn test_fftn_shape_preserved() {
        let x = make_complex_array(&[3, 5, 7]);
        let s = fftn_complex(&x, None).expect("fftn");
        assert_eq!(s.shape(), x.shape());
    }

    // ── fftshift2 / ifftshift2 ───────────────────────────────────────────────

    #[test]
    fn test_fftshift2_roundtrip_even() {
        let rows = 4;
        let cols = 6;
        let data: Vec<Complex64> = (0..(rows * cols) as i32)
            .map(|i| Complex64::new(i as f64, 0.0))
            .collect();
        let x = Array2::from_shape_vec((rows, cols), data).expect("shape");
        let shifted = fftshift2(&x);
        let recovered = ifftshift2(&shifted);
        for r in 0..rows {
            for c in 0..cols {
                assert_relative_eq!(x[[r, c]].re, recovered[[r, c]].re, epsilon = 1e-12);
            }
        }
    }

    #[test]
    fn test_fftshift2_roundtrip_odd() {
        let rows = 5;
        let cols = 7;
        let data: Vec<Complex64> = (0..(rows * cols) as i32)
            .map(|i| Complex64::new(i as f64, i as f64 * 0.1))
            .collect();
        let x = Array2::from_shape_vec((rows, cols), data).expect("shape");
        let shifted = fftshift2(&x);
        let recovered = ifftshift2(&shifted);
        for r in 0..rows {
            for c in 0..cols {
                assert_relative_eq!(x[[r, c]].re, recovered[[r, c]].re, epsilon = 1e-12);
                assert_relative_eq!(x[[r, c]].im, recovered[[r, c]].im, epsilon = 1e-12);
            }
        }
    }

    #[test]
    fn test_fftshift2_dc_to_centre() {
        let mut data = Array2::<Complex64>::zeros((4, 4));
        data[[0, 0]] = Complex64::new(1.0, 0.0);
        let shifted = fftshift2(&data);
        // For n=4, shift = 2 → DC moves to [2, 2]
        assert_relative_eq!(shifted[[2, 2]].re, 1.0, epsilon = 1e-12);
        assert_relative_eq!(shifted[[0, 0]].re, 0.0, epsilon = 1e-12);
    }

    #[test]
    fn test_fftshift2_dc_to_centre_odd() {
        let mut data = Array2::<Complex64>::zeros((5, 7));
        data[[0, 0]] = Complex64::new(1.0, 0.0);
        let shifted = fftshift2(&data);
        // Forward shift is floor(len / 2): 5 → 2 rows, 7 → 3 columns
        assert_relative_eq!(shifted[[2, 3]].re, 1.0, epsilon = 1e-12);
        assert_relative_eq!(shifted[[0, 0]].re, 0.0, epsilon = 1e-12);
    }

    #[test]
    fn test_ifftshift2_dc_back() {
        let mut data = Array2::<Complex64>::zeros((4, 4));
        data[[0, 0]] = Complex64::new(1.0, 0.0);
        let shifted = fftshift2(&data);
        let recovered = ifftshift2(&shifted);
        assert_relative_eq!(recovered[[0, 0]].re, 1.0, epsilon = 1e-12);
    }

    // ── fftfreq_nd ───────────────────────────────────────────────────────────

    #[test]
    fn test_fftfreq_nd_basic() {
        let freqs = fftfreq_nd(&[4, 8], &[1.0, 1.0]).expect("fftfreq_nd");
        assert_eq!(freqs.len(), 2);
        assert_eq!(freqs[0].len(), 4);
        assert_eq!(freqs[1].len(), 8);
        // DC is always 0
        assert_relative_eq!(freqs[0][0], 0.0, epsilon = 1e-15);
        assert_relative_eq!(freqs[1][0], 0.0, epsilon = 1e-15);
    }

    #[test]
    fn test_fftfreq_nd_matches_1d_fftfreq() {
        // Compare with the scalar fftfreq from crate::helper
        use crate::helper::fftfreq;
        let n = 16;
        let d = 0.5;
        let nd_freqs = fftfreq_nd(&[n], &[d]).expect("nd");
        let scalar_freqs = fftfreq(n, d).expect("1d");
        assert_eq!(nd_freqs[0].len(), scalar_freqs.len());
        for (a, b) in nd_freqs[0].iter().zip(scalar_freqs.iter()) {
            assert_relative_eq!(*a, *b, epsilon = 1e-14);
        }
    }

    #[test]
    fn test_fftfreq_nd_spacing() {
        // With d=0.5 the max positive frequency doubles compared to d=1.0
        let f1 = fftfreq_nd(&[8], &[1.0]).expect("d=1");
        let f2 = fftfreq_nd(&[8], &[0.5]).expect("d=0.5");
        // Max positive freq for n=8, d=1: 3/8; for d=0.5: 3/4
        assert_relative_eq!(f1[0][3], 3.0 / 8.0, epsilon = 1e-14);
        assert_relative_eq!(f2[0][3], 3.0 / 4.0, epsilon = 1e-14);
    }

    #[test]
    fn test_fftfreq_nd_mismatch_error() {
        assert!(fftfreq_nd(&[4, 8], &[1.0]).is_err()); // lengths differ
        assert!(fftfreq_nd(&[4], &[0.0]).is_err()); // zero spacing
        assert!(fftfreq_nd(&[4], &[-1.0]).is_err()); // negative spacing
    }

    #[test]
    fn test_fftfreq_nd_empty_axis() {
        let freqs = fftfreq_nd(&[0, 4], &[1.0, 1.0]).expect("empty axis ok");
        assert_eq!(freqs[0].len(), 0);
        assert_eq!(freqs[1].len(), 4);
    }

    // ── Correctness: 2D FFT shift is consistent with element-wise check ──────

    #[test]
    fn test_fftshift2_known_pattern() {
        // Build a 4×4 array with known values at corners
        let rows = 4;
        let cols = 4;
        let mut x = Array2::<Complex64>::zeros((rows, cols));
        x[[0, 0]] = Complex64::new(1.0, 0.0); // top-left (DC)
        x[[0, 2]] = Complex64::new(2.0, 0.0); // top-right region
        x[[2, 0]] = Complex64::new(3.0, 0.0); // bottom-left region
        x[[2, 2]] = Complex64::new(4.0, 0.0); // bottom-right region

        let shifted = fftshift2(&x);
        // For n=4 (even), shift = 2 → each element at [r,c] moves to [(r+2)%4, (c+2)%4]
        assert_relative_eq!(shifted[[2, 2]].re, 1.0, epsilon = 1e-12); // was [0,0]
        assert_relative_eq!(shifted[[2, 0]].re, 2.0, epsilon = 1e-12); // was [0,2]
        assert_relative_eq!(shifted[[0, 2]].re, 3.0, epsilon = 1e-12); // was [2,0]
        assert_relative_eq!(shifted[[0, 0]].re, 4.0, epsilon = 1e-12); // was [2,2]
    }

    // ── Integration: fftn + fftshift2 on a sinusoidal image ─────────────────

    #[test]
    fn test_fftn_then_shift_preserves_energy() {
        let n = 8;
        // Simple 2D sinusoid: one cycle along each axis
        let data: Vec<Complex64> = (0..n * n)
            .map(|k| {
                let r = k / n;
                let c = k % n;
                let re =
                    (2.0 * PI * r as f64 / n as f64).cos() * (2.0 * PI * c as f64 / n as f64).cos();
                Complex64::new(re, 0.0)
            })
            .collect();
        let x = ArrayD::from_shape_vec(IxDyn(&[n, n]), data).expect("shape");
        let spec = fftn_complex(&x, None).expect("fftn");

        // Parseval: sum |X[k]|^2 = n^2 * sum |x[n]|^2
        let energy_x: f64 = x.iter().map(|c| c.norm_sqr()).sum();
        let energy_s: f64 = spec.iter().map(|c| c.norm_sqr()).sum();
        let n2 = (n * n) as f64;
        assert_relative_eq!(energy_s, n2 * energy_x, epsilon = 1e-8 * energy_s.max(1.0));

        // The shift only permutes elements, so the energy must be unchanged.
        let flat: Vec<Complex64> = spec.iter().copied().collect();
        let spec2 = Array2::from_shape_vec((n, n), flat).expect("2-D shape");
        let shifted = fftshift2(&spec2);
        let energy_shifted: f64 = shifted.iter().map(|c| c.norm_sqr()).sum();
        assert_relative_eq!(
            energy_shifted,
            energy_s,
            epsilon = 1e-8 * energy_s.max(1.0)
        );

        // cos·cos puts four peaks of magnitude n²/4 at bins (±1, ±1); after the
        // shift by n/2 = 4 they sit symmetrically around the centre [4, 4].
        let peak = n2 / 4.0;
        for &(r, c) in &[(3, 3), (3, 5), (5, 3), (5, 5)] {
            assert_relative_eq!(shifted[[r, c]].norm(), peak, epsilon = 1e-9);
        }
    }
}