scirs2_fft/
nufft.rs

1//! Non-Uniform Fast Fourier Transform module
2//!
3//! This module provides implementations of the Non-Uniform Fast Fourier Transform (NUFFT),
4//! which computes the Fourier transform of data sampled at non-uniform intervals.
5//!
6//! # Overview
7//!
8//! NUFFT is an extension of the FFT for data that is not sampled on a uniform grid.
9//! This implementation uses a grid-based approach with interpolation to approximate
10//! the non-uniform Fourier transform efficiently.
11//!
12//! # Types of NUFFT
13//!
//! * Type 1 (Non-Uniform to Uniform): samples taken at non-uniform points are transformed
//!   to coefficients on a uniform frequency grid
//! * Type 2 (Uniform to Non-Uniform): coefficients given on a uniform frequency grid are
//!   evaluated at non-uniform points
16
17use crate::error::{FFTError, FFTResult};
18use scirs2_core::numeric::Complex64;
19use scirs2_core::numeric::Zero;
20use std::f64::consts::PI;
21
/// Kernel used by the NUFFT routines to move values between the non-uniform
/// points and the oversampled uniform grid.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum InterpolationType {
    /// Triangular (linear) kernel with a fixed width, independent of `epsilon`.
    Linear,
    /// Gaussian kernel whose width grows as the requested precision `epsilon` shrinks.
    Gaussian,
    /// Gaussian kernel with a fixed, minimal width, independent of `epsilon`.
    MinGaussian,
}
32
33/// Performs Non-Uniform Fast Fourier Transform (NUFFT) Type 1.
34///
35/// This function computes the FFT of data sampled at non-uniform locations.
36/// NUFFT Type 1 transforms from non-uniform samples to a uniform frequency grid.
37///
38/// # Arguments
39///
40/// * `x` - Non-uniform sample points (must be in range [-π, π])
41/// * `samples` - Complex sample values at the non-uniform points
42/// * `m` - Size of the output uniform grid
43/// * `interp_type` - Interpolation type to use
44/// * `epsilon` - Desired precision (typically 1e-6 to 1e-12)
45///
46/// # Returns
47///
48/// * A complex-valued array containing the NUFFT on a uniform frequency grid
49///
50/// # Errors
51///
52/// Returns an error if the computation fails or inputs are invalid.
53///
54/// # Examples
55///
56/// ```
57/// use scirs2_fft::nufft::{nufft_type1, InterpolationType};
58/// use scirs2_core::numeric::Complex64;
59/// use std::f64::consts::PI;
60///
61/// // Create non-uniform sample points in [-π, π]
62/// let n = 100;
63/// let x: Vec<f64> = (0..n).map(|i| -PI + 1.8 * PI * i as f64 / (n as f64)).collect();
64///
65/// // Create sample values for a simple function (e.g., a Gaussian)
66/// let samples: Vec<Complex64> = x.iter()
67///     .map(|&xi| {
68///         let real = (-xi.powi(2) / 2.0).exp();
69///         Complex64::new(real, 0.0)
70///     })
71///     .collect();
72///
73/// // Compute NUFFT
74/// let m = 128;  // Output grid size
75/// let result = nufft_type1(&x, &samples, m, InterpolationType::Gaussian, 1e-6).unwrap();
76///
77/// // The transform of a Gaussian is another Gaussian
78/// assert!(result.len() == m);
79/// ```
80///
81/// # Notes
82///
83/// This is a basic implementation. For performance-critical applications,
84/// consider using a more optimized NUFFT library.
85#[allow(dead_code)]
86pub fn nufft_type1(
87    x: &[f64],
88    samples: &[Complex64],
89    m: usize,
90    interp_type: InterpolationType,
91    epsilon: f64,
92) -> FFTResult<Vec<Complex64>> {
93    // Check inputs
94    if x.len() != samples.len() {
95        return Err(FFTError::DimensionError(
96            "Sample points and values must have the same length".to_string(),
97        ));
98    }
99
100    if epsilon <= 0.0 {
101        return Err(FFTError::ValueError(
102            "Precision parameter epsilon must be positive".to_string(),
103        ));
104    }
105
106    // Check if x values are in the correct range [-π, π]
107    for &xi in x {
108        if !(-PI..=PI).contains(&xi) {
109            return Err(FFTError::ValueError(
110                "Sample points must be in the range [-π, π]".to_string(),
111            ));
112        }
113    }
114
115    // Estimate parameters for the algorithm
116    let tau = 2.0; // Oversampling factor, usually in range [1.5, 2.5]
117    let n_grid = tau as usize * m; // Size of the oversampled grid
118
119    // Determine the width parameter based on the chosen interpolation _type
120    let sigma = match interp_type {
121        InterpolationType::Linear => 2.0,
122        InterpolationType::Gaussian => 2.0 * (-epsilon.ln()).sqrt(),
123        InterpolationType::MinGaussian => 1.0,
124    };
125
126    // Compute the spreading width (kernel half-width)
127    let width = (sigma * sigma * (-epsilon.ln()) / PI).ceil() as usize;
128    let width = width.max(2); // At least 2 for stability
129
130    // Grid spacing
131    let h_grid = 2.0 * PI / n_grid as f64;
132
133    // Initialize the oversampled grid
134    let mut grid_data = vec![Complex64::zero(); n_grid];
135
136    // Spread the non-uniform data onto the uniform grid using the chosen kernel
137    for (&xi, &sample) in x.iter().zip(samples.iter()) {
138        // Map the x value to the grid index
139        let x_grid = (xi + PI) / h_grid;
140        let i_grid = x_grid.floor() as isize;
141
142        // Spread the sample to nearby grid points
143        for j in (-(width as isize))..=(width as isize) {
144            let idx = (i_grid + j).rem_euclid(n_grid as isize) as usize;
145            let kernel_arg = (x_grid - (i_grid + j) as f64) / sigma;
146
147            let kernel_value = match interp_type {
148                InterpolationType::Linear => {
149                    if kernel_arg.abs() <= 1.0 {
150                        1.0 - kernel_arg.abs()
151                    } else {
152                        0.0
153                    }
154                }
155                InterpolationType::Gaussian | InterpolationType::MinGaussian => {
156                    (-kernel_arg * kernel_arg).exp()
157                }
158            };
159
160            grid_data[idx] += sample * kernel_value;
161        }
162    }
163
164    // Compute the FFT of the grid data
165    let grid_fft = fft_backend(&grid_data)?;
166
167    // Extract the desired frequency components
168    let mut result = Vec::with_capacity(m);
169
170    for i in 0..m {
171        if i <= m / 2 {
172            // Positive frequencies
173            result.push(grid_fft[i]);
174        } else {
175            // Negative frequencies
176            result.push(grid_fft[n_grid - (m - i)]);
177        }
178    }
179
180    Ok(result)
181}
182
183/// Performs Non-Uniform Fast Fourier Transform (NUFFT) Type 2.
184///
185/// This function computes the FFT from a uniform grid to non-uniform frequencies.
186/// NUFFT Type 2 is essentially the adjoint of Type 1.
187///
188/// # Arguments
189///
190/// * `spectrum` - Input spectrum on a uniform grid
191/// * `x` - Non-uniform frequency points where output is desired (must be in [-π, π])
192/// * `interp_type` - Interpolation type to use
193/// * `epsilon` - Desired precision (typically 1e-6 to 1e-12)
194///
195/// # Returns
196///
197/// * A complex-valued array containing the NUFFT at the specified non-uniform points
198///
199/// # Errors
200///
201/// Returns an error if the computation fails or inputs are invalid.
202///
203/// # Examples
204///
205/// ```
206/// use scirs2_fft::nufft::{nufft_type2, InterpolationType};
207/// use scirs2_core::numeric::Complex64;
208/// use std::f64::consts::PI;
209///
210/// // Create a spectrum on a uniform grid
211/// let m = 128;
212/// let spectrum: Vec<Complex64> = (0..m)
213///     .map(|i| {
214///         // Simple Gaussian in frequency domain
215///         let f = i as f64 - m as f64 / 2.0;
216///         let val = (-f * f / (2.0 * 10.0)).exp();
217///         Complex64::new(val, 0.0)
218///     })
219///     .collect();
220///
221/// // Define non-uniform points where we want to evaluate the transform
222/// // Ensure all points are in the range [-π, π]
223/// let n = 100;
224/// let x: Vec<f64> = (0..n).map(|i| -PI + 1.99 * PI * i as f64 / (n as f64 - 1.0)).collect();
225///
226/// // Compute NUFFT Type 2
227/// let result = nufft_type2(&spectrum, &x, InterpolationType::Gaussian, 1e-6).unwrap();
228///
229/// // The output should have the same length as the non-uniform points
230/// assert_eq!(result.len(), x.len());
231/// ```
232///
233/// # Notes
234///
235/// This is a basic implementation. For performance-critical applications,
236/// consider using a more optimized NUFFT library.
237#[allow(dead_code)]
238pub fn nufft_type2(
239    spectrum: &[Complex64],
240    x: &[f64],
241    interp_type: InterpolationType,
242    epsilon: f64,
243) -> FFTResult<Vec<Complex64>> {
244    // Check inputs
245    if epsilon <= 0.0 {
246        return Err(FFTError::ValueError(
247            "Precision parameter epsilon must be positive".to_string(),
248        ));
249    }
250
251    // Check if x values are in the correct range [-π, π]
252    for &xi in x {
253        if !(-PI..=PI).contains(&xi) {
254            return Err(FFTError::ValueError(
255                "Output points must be in the range [-π, π]".to_string(),
256            ));
257        }
258    }
259
260    let m = spectrum.len();
261    let tau = 2.0; // Oversampling factor
262    let n_grid = tau as usize * m; // Size of the oversampled grid
263
264    // Determine the width parameter
265    let sigma = match interp_type {
266        InterpolationType::Linear => 2.0,
267        InterpolationType::Gaussian => 2.0 * (-epsilon.ln()).sqrt(),
268        InterpolationType::MinGaussian => 1.0,
269    };
270
271    // Compute the spreading width (kernel half-width)
272    let width = (sigma * sigma * (-epsilon.ln()) / PI).ceil() as usize;
273    let width = width.max(2); // At least 2 for stability
274
275    // Prepare the input for the inverse FFT
276    let mut padded_spectrum = vec![Complex64::zero(); n_grid];
277
278    // Copy the spectrum to the padded array
279    for i in 0..m {
280        if i <= m / 2 {
281            // Positive frequencies
282            padded_spectrum[i] = spectrum[i];
283        } else {
284            // Negative frequencies
285            padded_spectrum[n_grid - (m - i)] = spectrum[i];
286        }
287    }
288
289    // Compute the inverse FFT
290    let grid_ifft = ifft_backend(&padded_spectrum)?;
291
292    // Grid spacing
293    let h_grid = 2.0 * PI / n_grid as f64;
294
295    // Interpolate at the non-uniform points
296    let mut result = vec![Complex64::zero(); x.len()];
297
298    for (i, &xi) in x.iter().enumerate() {
299        // Map the x value to the grid index
300        let x_grid = (xi + PI) / h_grid;
301        let i_grid = x_grid.floor() as isize;
302
303        // Interpolate from nearby grid points
304        for j in (-(width as isize))..=(width as isize) {
305            let idx = (i_grid + j).rem_euclid(n_grid as isize) as usize;
306            let kernel_arg = (x_grid - (i_grid + j) as f64) / sigma;
307
308            let kernel_value = match interp_type {
309                InterpolationType::Linear => {
310                    if kernel_arg.abs() <= 1.0 {
311                        1.0 - kernel_arg.abs()
312                    } else {
313                        0.0
314                    }
315                }
316                InterpolationType::Gaussian | InterpolationType::MinGaussian => {
317                    (-kernel_arg * kernel_arg).exp()
318                }
319            };
320
321            result[i] += grid_ifft[idx] * kernel_value;
322        }
323    }
324
325    Ok(result)
326}
327
328/// Helper function for FFT computation used in NUFFT implementations
329#[allow(dead_code)]
330fn fft_backend(data: &[Complex64]) -> FFTResult<Vec<Complex64>> {
331    use rustfft::{num_complex::Complex, FftPlanner};
332
333    let n = data.len();
334    let mut planner = FftPlanner::new();
335    let fft = planner.plan_fft_forward(n);
336
337    // Convert to rustfft's Complex type
338    let mut buffer: Vec<Complex<f64>> = data.iter().map(|&c| Complex::new(c.re, c.im)).collect();
339
340    // Perform the FFT
341    fft.process(&mut buffer);
342
343    // Convert back to scirs2_core::numeric::Complex64
344    Ok(buffer
345        .into_iter()
346        .map(|c| Complex64::new(c.re, c.im))
347        .collect())
348}
349
350/// Helper function for IFFT computation used in NUFFT implementations
351#[allow(dead_code)]
352fn ifft_backend(data: &[Complex64]) -> FFTResult<Vec<Complex64>> {
353    use rustfft::{num_complex::Complex, FftPlanner};
354
355    let n = data.len();
356    let mut planner = FftPlanner::new();
357    let ifft = planner.plan_fft_inverse(n);
358
359    // Convert to rustfft's Complex type
360    let mut buffer: Vec<Complex<f64>> = data.iter().map(|&c| Complex::new(c.re, c.im)).collect();
361
362    // Perform the IFFT
363    ifft.process(&mut buffer);
364
365    // Convert back to scirs2_core::numeric::Complex64 and normalize
366    let scale = 1.0 / n as f64;
367    Ok(buffer
368        .into_iter()
369        .map(|c| Complex64::new(c.re * scale, c.im * scale))
370        .collect())
371}
372
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// `count` points starting at -π and covering 90% of the period.
    fn spread_points(count: usize) -> Vec<f64> {
        (0..count)
            .map(|i| -PI + 1.8 * PI * i as f64 / (count as f64))
            .collect()
    }

    #[test]
    fn test_nufft_type1_gaussian() {
        // Gaussian bump sampled on the shared point set.
        let x = spread_points(100);
        let samples: Vec<Complex64> = x
            .iter()
            .map(|&xi| Complex64::new((-xi.powi(2) / 2.0).exp(), 0.0))
            .collect();

        let m = 128;
        let result = nufft_type1(&x, &samples, m, InterpolationType::Gaussian, 1e-6).unwrap();

        // Right length and not identically zero.
        assert_eq!(result.len(), m);
        assert!(result.iter().any(|&c| c.norm() > 1e-10));

        // The output is an approximation, so only check that the spectrum has
        // some dynamic range rather than its exact shape.
        let norms: Vec<f64> = result.iter().map(|c| c.norm()).collect();
        let max_val = norms.iter().copied().fold(0.0, f64::max);
        let min_val = norms.iter().copied().fold(f64::INFINITY, f64::min);

        assert!(max_val > 0.0);
        assert!(min_val >= 0.0);
        assert!(max_val > min_val * 2.0);
    }

    #[test]
    fn test_nufft_type2_consistency() {
        // A single impulse in the middle of the spectrum.
        let m = 32;
        let mut spectrum = vec![Complex64::new(0.0, 0.0); m];
        spectrum[m / 2] = Complex64::new(1.0, 0.0);

        let n = 50;
        let x = spread_points(n);

        let result = nufft_type2(&spectrum, &x, InterpolationType::Gaussian, 1e-6).unwrap();
        assert_eq!(result.len(), n);

        // A lone frequency should give roughly constant magnitude everywhere.
        let avg_magnitude = result.iter().map(|c| c.norm()).sum::<f64>() / n as f64;
        for c in &result {
            assert_relative_eq!(c.norm(), avg_magnitude, epsilon = 0.2);
        }
    }

    #[test]
    fn test_nufft_type1_linear_interp() {
        // cos(x) should put its energy at k = ±1.
        let x = spread_points(120);
        let samples: Vec<Complex64> = x.iter().map(|&xi| Complex64::new(xi.cos(), 0.0)).collect();

        let m = 64;
        let result = nufft_type1(&x, &samples, m, InterpolationType::Linear, 1e-6).unwrap();
        assert_eq!(result.len(), m);

        // Rank bins by magnitude, largest first.
        let mut ranked: Vec<(usize, f64)> = result.iter().map(|c| c.norm()).enumerate().collect();
        ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

        let peak1 = ranked[0].0;
        let peak2 = ranked[1].0;

        // Accept the pair {1, m - 1} in either order, or any pair two bins apart.
        let mut pair = [peak1, peak2];
        pair.sort_unstable();
        let matches_expected = pair == [1, m - 1];
        let gap = (peak1 as i32 - peak2 as i32).abs();

        assert!(matches_expected || gap == 2);
    }

    #[test]
    fn test_nufft_errors() {
        let one_sample = [Complex64::new(1.0, 0.0)];

        // Mismatched point/sample lengths.
        assert!(nufft_type1(&[0.0, 1.0], &one_sample, 8, InterpolationType::Gaussian, 1e-6).is_err());

        // Non-positive precision.
        assert!(nufft_type1(&[0.0], &one_sample, 8, InterpolationType::Gaussian, -1.0).is_err());

        // Point outside [-π, π].
        assert!(nufft_type1(&[4.0], &one_sample, 8, InterpolationType::Gaussian, 1e-6).is_err());
    }
}