sci_rs/signal/
convolve.rs

1use nalgebra::Complex;
2use num_traits::{Float, FromPrimitive, Signed, Zero};
3use rustfft::{FftNum, FftPlanner};
4
/// Convolution mode determines behavior near edges and output size.
///
/// Each mode selects a contiguous window of the full linear convolution of
/// `in1` and `in2`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConvolveMode {
    /// Full convolution, output size is `in1.len() + in2.len() - 1`
    Full,
    /// Valid convolution, output size is `max(in1.len(), in2.len()) - min(in1.len(), in2.len()) + 1`
    Valid,
    /// Same convolution, output size is `in1.len()`, centered with respect to the `Full` output
    Same,
}
14
15/// Performs FFT-based convolution on two slices of floating point values.
16///
17/// According to Python docs, this is generally much faster than direct convolution
18/// for large arrays (n > ~500), but can be slower when only a few output values are needed.
19/// We only implement the FFT version in Rust for now.
20///
21/// # Arguments
22/// - `in1`: First input signal
23/// - `in2`: Second input signal
24/// - `mode`: Convolution mode (currently only Full is supported)
25///
26/// # Returns
27/// A Vec containing the discrete linear convolution of `in1` with `in2`.
28/// For Full mode, the output length will be `in1.len() + in2.len() - 1`.
29pub fn fftconvolve<F: Float + FftNum>(in1: &[F], in2: &[F], mode: ConvolveMode) -> Vec<F> {
30    // Determine the size of the FFT (next power of 2 for zero-padding)
31    let n1 = in1.len();
32    let n2 = in2.len();
33    let n = n1 + n2 - 1;
34    let fft_size = n.next_power_of_two();
35
36    // Prepare input buffers as Complex<F> with zero-padding to fft_size
37    let mut padded_in1 = vec![Complex::zero(); fft_size];
38    let mut padded_in2 = vec![Complex::zero(); fft_size];
39
40    // Copy input data into zero-padded buffers
41    padded_in1.iter_mut().zip(in1.iter()).for_each(|(p, &v)| {
42        *p = Complex::new(v, F::zero());
43    });
44    padded_in2.iter_mut().zip(in2.iter()).for_each(|(p, &v)| {
45        *p = Complex::new(v, F::zero());
46    });
47
48    // Perform the FFT
49    let mut planner = FftPlanner::new();
50    let fft = planner.plan_fft_forward(fft_size);
51    fft.process(&mut padded_in1);
52    fft.process(&mut padded_in2);
53
54    // Multiply element-wise in the frequency domain
55    let mut result_freq: Vec<Complex<F>> = padded_in1
56        .iter()
57        .zip(&padded_in2)
58        .map(|(a, b)| a * b)
59        .collect();
60
61    // Perform the inverse FFT
62    let ifft = planner.plan_fft_inverse(fft_size);
63    ifft.process(&mut result_freq);
64
65    // Take only the real part, normalize, and truncate to the original output size (n)
66    let fft_size = F::from(fft_size).unwrap();
67    let full_convolution = result_freq
68        .iter()
69        .take(n)
70        .map(|x| x.re / fft_size)
71        .collect();
72
73    // Extract the appropriate slice based on the mode
74    match mode {
75        ConvolveMode::Full => full_convolution,
76        ConvolveMode::Valid => {
77            if n1 >= n2 {
78                full_convolution[(n2 - 1)..(n1)].to_vec()
79            } else {
80                Vec::new()
81            }
82        }
83        ConvolveMode::Same => {
84            let start = (n2 - 1) / 2;
85            let end = start + n1;
86            full_convolution[start..end].to_vec()
87        }
88    }
89}
90
91/// Compute the convolution of two signals using FFT.
92///
93/// # Arguments
94/// * `in1` - First input array
95/// * `in2` - Second input array
96///
97/// # Returns
98/// A Vec containing the convolution of `in1` with `in2`.
99/// With Full mode, the output length will be `in1.len() + in2.len() - 1`.
100pub fn convolve<F: Float + FftNum>(in1: &[F], in2: &[F], mode: ConvolveMode) -> Vec<F> {
101    fftconvolve(in1, in2, mode)
102}
103
104/// Compute the cross-correlation of two signals using FFT.
105///
106/// Cross-correlation is similar to convolution but with flipping one of the signals.
107/// This function uses FFT to compute the correlation efficiently.
108///
109/// # Arguments
110/// * `in1` - First input array
111/// * `in2` - Second input array
112///
113/// # Returns
114/// A Vec containing the cross-correlation of `in1` with `in2`.
115/// With Full mode, the output length will be `in1.len() + in2.len() - 1`.
116pub fn correlate<F: Float + FftNum>(in1: &[F], in2: &[F], mode: ConvolveMode) -> Vec<F> {
117    // For correlation, we need to reverse in2
118    let mut in2_rev = in2.to_vec();
119    in2_rev.reverse();
120    fftconvolve(in1, &in2_rev, mode)
121}
122
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Asserts that `result` has exactly the expected length and matches `expected`
    /// element-wise. The length check matters: `zip` alone stops at the shorter slice,
    /// so a truncated (or empty) result would otherwise pass silently.
    fn assert_all_close(result: &[f64], expected: &[f64]) {
        assert_eq!(
            result.len(),
            expected.len(),
            "output length mismatch: got {:?}, expected {:?}",
            result,
            expected
        );
        for (a, b) in result.iter().zip(expected.iter()) {
            assert_relative_eq!(a, b, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_convolve() {
        let in1: Vec<f64> = vec![1.0, 2.0, 3.0];
        let in2: Vec<f64> = vec![4.0, 5.0, 6.0];
        let result = convolve(&in1, &in2, ConvolveMode::Full);
        assert_all_close(&result, &[4.0, 13.0, 28.0, 27.0, 18.0]);
    }

    #[test]
    fn test_correlate() {
        let in1: Vec<f64> = vec![1.0, 2.0, 3.0];
        let in2: Vec<f64> = vec![4.0, 5.0, 6.0];
        let result = correlate(&in1, &in2, ConvolveMode::Full);
        assert_all_close(&result, &[6.0, 17.0, 32.0, 23.0, 12.0]);
    }

    #[test]
    fn test_convolve_valid() {
        let in1: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0];
        let in2: Vec<f64> = vec![1.0, 2.0];
        let result = convolve(&in1, &in2, ConvolveMode::Valid);
        assert_all_close(&result, &[4.0, 7.0, 10.0]);
    }

    #[test]
    fn test_convolve_same() {
        let in1: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0];
        let in2: Vec<f64> = vec![1.0, 2.0, 1.0];
        let result = convolve(&in1, &in2, ConvolveMode::Same);
        assert_all_close(&result, &[4.0, 8.0, 12.0, 11.0]);
    }

    #[test]
    fn test_scipy_example() {
        use rand::distributions::{Distribution, Standard};
        use rand::thread_rng;

        // Generate 1000 random samples; `Standard` draws f64 values uniformly from
        // [0, 1) (not from a normal distribution).
        let mut rng = thread_rng();
        let sig: Vec<f64> = Standard.sample_iter(&mut rng).take(1000).collect();

        // Compute autocorrelation using correlate directly
        let autocorr = correlate(&sig, &sig, ConvolveMode::Full);

        // Basic sanity checks
        assert_eq!(autocorr.len(), 1999); // Full convolution length should be 2N-1
        assert!(autocorr.iter().all(|&x| !x.is_nan())); // No NaN values

        // The autocorrelation of any real signal peaks at zero lag, which sits in the
        // middle of the Full output (index N - 1 = 999).
        let max_idx = autocorr
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
            .unwrap()
            .0;
        assert!((max_idx as i32 - 999).abs() <= 1); // Should be near index 999

        // NOTE(review): `python_plot` presumably renders the signals via Python; confirm
        // whether this side effect belongs in a unit test (vs. #[ignore] or an example).
        let sig: Vec<f32> = sig.iter().map(|x| *x as f32).collect();
        let autocorr: Vec<f32> = autocorr.iter().map(|x| *x as f32).collect();
        crate::plot::python_plot(vec![&sig, &autocorr]);
    }
}