sci_rs/signal/
convolve.rs1use nalgebra::Complex;
2use num_traits::{Float, FromPrimitive, Signed, Zero};
3use rustfft::{FftNum, FftPlanner};
4
/// Selects which portion of a linear convolution (or correlation) is returned.
///
/// The variant names mirror the `mode` argument of SciPy's `scipy.signal.convolve`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConvolveMode {
    /// The full discrete linear convolution: `in1.len() + in2.len() - 1` samples.
    Full,
    /// Only the samples that are computed without relying on zero-padding.
    Valid,
    /// A result with the same length as the first input, centred with respect to
    /// the `Full` output.
    Same,
}
14
15pub fn fftconvolve<F: Float + FftNum>(in1: &[F], in2: &[F], mode: ConvolveMode) -> Vec<F> {
30 let n1 = in1.len();
32 let n2 = in2.len();
33 let n = n1 + n2 - 1;
34 let fft_size = n.next_power_of_two();
35
36 let mut padded_in1 = vec![Complex::zero(); fft_size];
38 let mut padded_in2 = vec![Complex::zero(); fft_size];
39
40 padded_in1.iter_mut().zip(in1.iter()).for_each(|(p, &v)| {
42 *p = Complex::new(v, F::zero());
43 });
44 padded_in2.iter_mut().zip(in2.iter()).for_each(|(p, &v)| {
45 *p = Complex::new(v, F::zero());
46 });
47
48 let mut planner = FftPlanner::new();
50 let fft = planner.plan_fft_forward(fft_size);
51 fft.process(&mut padded_in1);
52 fft.process(&mut padded_in2);
53
54 let mut result_freq: Vec<Complex<F>> = padded_in1
56 .iter()
57 .zip(&padded_in2)
58 .map(|(a, b)| a * b)
59 .collect();
60
61 let ifft = planner.plan_fft_inverse(fft_size);
63 ifft.process(&mut result_freq);
64
65 let fft_size = F::from(fft_size).unwrap();
67 let full_convolution = result_freq
68 .iter()
69 .take(n)
70 .map(|x| x.re / fft_size)
71 .collect();
72
73 match mode {
75 ConvolveMode::Full => full_convolution,
76 ConvolveMode::Valid => {
77 if n1 >= n2 {
78 full_convolution[(n2 - 1)..(n1)].to_vec()
79 } else {
80 Vec::new()
81 }
82 }
83 ConvolveMode::Same => {
84 let start = (n2 - 1) / 2;
85 let end = start + n1;
86 full_convolution[start..end].to_vec()
87 }
88 }
89}
90
91pub fn convolve<F: Float + FftNum>(in1: &[F], in2: &[F], mode: ConvolveMode) -> Vec<F> {
101 fftconvolve(in1, in2, mode)
102}
103
104pub fn correlate<F: Float + FftNum>(in1: &[F], in2: &[F], mode: ConvolveMode) -> Vec<F> {
117 let mut in2_rev = in2.to_vec();
119 in2_rev.reverse();
120 fftconvolve(in1, &in2_rev, mode)
121}
122
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Assert that `actual` and `expected` have the same length and agree element-wise.
    ///
    /// A bare `zip` stops at the shorter slice, so without the length check a
    /// truncated (or empty) result would pass silently.
    fn assert_all_close(actual: &[f64], expected: &[f64]) {
        assert_eq!(
            actual.len(),
            expected.len(),
            "length mismatch: {actual:?} vs {expected:?}"
        );
        for (a, b) in actual.iter().zip(expected) {
            assert_relative_eq!(*a, *b, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_convolve() {
        let in1 = vec![1.0, 2.0, 3.0];
        let in2 = vec![4.0, 5.0, 6.0];
        let result = convolve(&in1, &in2, ConvolveMode::Full);
        assert_all_close(&result, &[4.0, 13.0, 28.0, 27.0, 18.0]);
    }

    #[test]
    fn test_correlate() {
        let in1 = vec![1.0, 2.0, 3.0];
        let in2 = vec![4.0, 5.0, 6.0];
        let result = correlate(&in1, &in2, ConvolveMode::Full);
        assert_all_close(&result, &[6.0, 17.0, 32.0, 23.0, 12.0]);
    }

    #[test]
    fn test_convolve_valid() {
        let in1 = vec![1.0, 2.0, 3.0, 4.0];
        let in2 = vec![1.0, 2.0];
        let result = convolve(&in1, &in2, ConvolveMode::Valid);
        assert_all_close(&result, &[4.0, 7.0, 10.0]);
    }

    #[test]
    fn test_convolve_same() {
        let in1 = vec![1.0, 2.0, 3.0, 4.0];
        let in2 = vec![1.0, 2.0, 1.0];
        let result = convolve(&in1, &in2, ConvolveMode::Same);
        assert_all_close(&result, &[4.0, 8.0, 12.0, 11.0]);
    }

    #[test]
    fn test_scipy_example() {
        use rand::distributions::{Distribution, Standard};
        use rand::rngs::StdRng;
        use rand::SeedableRng;

        // Fixed seed so that any failure is reproducible.
        let mut rng = StdRng::seed_from_u64(42);
        let sig: Vec<f64> = Standard.sample_iter(&mut rng).take(1000).collect();

        let autocorr = correlate(&sig, &sig, ConvolveMode::Full);

        assert_eq!(autocorr.len(), 1999);
        assert!(autocorr.iter().all(|x| !x.is_nan()));

        // The autocorrelation peaks at zero lag, which is index len - 1 = 999 in Full mode.
        let max_idx = autocorr
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.total_cmp(b))
            .map(|(i, _)| i)
            .unwrap();
        assert!(max_idx.abs_diff(999) <= 1);

        // NOTE(review): plotting from a unit test needs the plot backend to be
        // available wherever tests run; consider gating this behind a feature.
        let sig: Vec<f32> = sig.iter().map(|&x| x as f32).collect();
        let autocorr: Vec<f32> = autocorr.iter().map(|&x| x as f32).collect();
        crate::plot::python_plot(vec![&sig, &autocorr]);
    }
}