scir_fft/
lib.rs

1//! FFT utilities for SciR.
2#![deny(missing_docs)]
3
4use ndarray::Array1;
5use num_complex::Complex64;
6use realfft::RealFftPlanner;
7use rustfft::FftPlanner;
8
9/// Compute the forward FFT of a real-valued array.
10///
11/// # Examples
12/// ```
13/// use ndarray::Array1;
14/// let x = Array1::from_vec(vec![0.0, 1.0, 0.0, -1.0]);
15/// let y = scir_fft::fft(&x);
16/// assert_eq!(y.len(), x.len());
17/// ```
18pub fn fft(input: &Array1<f64>) -> Array1<Complex64> {
19    let mut planner = FftPlanner::<f64>::new();
20    let fft = planner.plan_fft_forward(input.len());
21    let mut buffer: Vec<Complex64> = input.iter().map(|&x| Complex64::new(x, 0.0)).collect();
22    fft.process(&mut buffer);
23    Array1::from_vec(buffer)
24}
25
26/// Compute the inverse FFT of a complex-valued array.
27///
28/// # Examples
29/// ```
30/// use ndarray::Array1;
31/// use num_complex::Complex64;
32/// let x = Array1::from_vec(vec![Complex64::new(1.0,0.0); 4]);
33/// let y = scir_fft::ifft(&x);
34/// assert_eq!(y.len(), x.len());
35/// ```
36pub fn ifft(input: &Array1<Complex64>) -> Array1<Complex64> {
37    let mut planner = FftPlanner::<f64>::new();
38    let fft = planner.plan_fft_inverse(input.len());
39    let mut buffer: Vec<Complex64> = input.to_vec();
40    fft.process(&mut buffer);
41    let n = input.len() as f64;
42    Array1::from_vec(buffer.into_iter().map(|v| v / n).collect())
43}
44
45/// Compute the forward real FFT of a real-valued array.
46///
47/// # Examples
48/// ```
49/// use ndarray::Array1;
50/// let x = Array1::from_vec(vec![0.0, 1.0, 0.0, -1.0]);
51/// let y = scir_fft::rfft(&x);
52/// assert!(y.len() >= 1);
53/// ```
54pub fn rfft(input: &Array1<f64>) -> Array1<Complex64> {
55    let mut planner = RealFftPlanner::<f64>::new();
56    let rfft = planner.plan_fft_forward(input.len());
57    let mut buffer = input.to_vec();
58    let mut spectrum = rfft.make_output_vec();
59    rfft.process(&mut buffer, &mut spectrum).unwrap();
60    Array1::from_vec(spectrum)
61}
62
63/// Compute the inverse real FFT producing a real-valued array.
64///
65/// # Examples
66/// ```
67/// use ndarray::Array1;
68/// use num_complex::Complex64;
69/// let spec = Array1::from_vec(vec![Complex64::new(1.0,0.0); 3]);
70/// let x = scir_fft::irfft(&spec);
71/// assert!(x.len() >= 2);
72/// ```
73pub fn irfft(input: &Array1<Complex64>) -> Array1<f64> {
74    let n = (input.len() - 1) * 2;
75    let mut planner = RealFftPlanner::<f64>::new();
76    let irfft = planner.plan_fft_inverse(n);
77    let mut buffer = input.to_vec();
78    let mut output = irfft.make_output_vec();
79    irfft.process(&mut buffer, &mut output).unwrap();
80    let n_f64 = n as f64;
81    Array1::from_vec(output.into_iter().map(|v| v / n_f64).collect())
82}
83
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray_npy::ReadNpyExt;
    use scir_core::assert_close;
    use std::{
        fs::File,
        io::ErrorKind,
        path::{Path, PathBuf},
    };

    /// Fixture sizes exercised by every test below.
    const SIZES: [usize; 2] = [8, 16];

    /// Locate the shared fixtures directory, or `None` if it is absent.
    fn fixtures_base() -> Option<PathBuf> {
        let base = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../fixtures");
        base.exists().then_some(base)
    }

    /// Load one `.npy` fixture.
    ///
    /// Returns `None` (after logging) only when the file does not exist. Any
    /// other I/O error, or a decode failure such as a dtype mismatch, panics.
    /// Silently skipping in those cases would let a broken fixture hide a
    /// regression.
    fn load_fixture<T: ReadNpyExt>(path: &Path) -> Option<T> {
        let file = match File::open(path) {
            Ok(file) => file,
            Err(err) if err.kind() == ErrorKind::NotFound => {
                eprintln!("[scir-fft] missing {}; skipping", path.display());
                return None;
            }
            Err(err) => panic!("[scir-fft] cannot open {}: {err}", path.display()),
        };
        let value = T::read_npy(file)
            .unwrap_or_else(|err| panic!("[scir-fft] cannot decode {}: {err}", path.display()));
        Some(value)
    }

    /// Load an (input, expected-output) fixture pair; `None` if either file is missing.
    fn load_pair<I: ReadNpyExt, O: ReadNpyExt>(
        base: &Path,
        input: &str,
        expected: &str,
    ) -> Option<(I, O)> {
        let input = load_fixture(&base.join(input))?;
        let expected = load_fixture(&base.join(expected))?;
        Some((input, expected))
    }

    #[test]
    fn fft_matches_fixtures() {
        let Some(base) = fixtures_base() else {
            eprintln!("[scir-fft] fixtures missing; skipping fft_matches_fixtures");
            return;
        };
        for n in SIZES {
            let Some((input, expected)) = load_pair::<Array1<f64>, Array1<Complex64>>(
                &base,
                &format!("fft_input_{n}.npy"),
                &format!("fft_output_{n}.npy"),
            ) else {
                continue;
            };
            let result = fft(&input);
            assert_close!(&result, &expected, complex_array, atol = 1e-9, rtol = 1e-9);
        }
    }

    #[test]
    fn ifft_matches_fixtures() {
        let Some(base) = fixtures_base() else {
            eprintln!("[scir-fft] fixtures missing; skipping ifft_matches_fixtures");
            return;
        };
        for n in SIZES {
            let Some((input, expected)) = load_pair::<Array1<Complex64>, Array1<Complex64>>(
                &base,
                &format!("fft_output_{n}.npy"),
                &format!("ifft_output_{n}.npy"),
            ) else {
                continue;
            };
            let result = ifft(&input);
            assert_close!(&result, &expected, complex_array, atol = 1e-9, rtol = 1e-9);
        }
    }

    #[test]
    fn rfft_matches_fixtures() {
        let Some(base) = fixtures_base() else {
            eprintln!("[scir-fft] fixtures missing; skipping rfft_matches_fixtures");
            return;
        };
        for n in SIZES {
            let Some((input, expected)) = load_pair::<Array1<f64>, Array1<Complex64>>(
                &base,
                &format!("fft_input_{n}.npy"),
                &format!("rfft_output_{n}.npy"),
            ) else {
                continue;
            };
            let result = rfft(&input);
            assert_close!(&result, &expected, complex_array, atol = 1e-9, rtol = 1e-9);
        }
    }

    #[test]
    fn irfft_matches_fixtures() {
        let Some(base) = fixtures_base() else {
            eprintln!("[scir-fft] fixtures missing; skipping irfft_matches_fixtures");
            return;
        };
        for n in SIZES {
            let Some((input, expected)) = load_pair::<Array1<Complex64>, Array1<f64>>(
                &base,
                &format!("rfft_output_{n}.npy"),
                &format!("fft_input_{n}.npy"),
            ) else {
                continue;
            };
            let result = irfft(&input);
            assert_close!(&result, &expected, array, atol = 1e-9, rtol = 1e-9);
        }
    }
}