1use crate::error::{FFTError, FFTResult};
7use crate::fft::{fft, ifft};
8use scirs2_core::numeric::Complex64;
9use scirs2_core::numeric::NumCast;
10use std::fmt::Debug;
11
/// Shape of the pass band selected by a [`FilterSpec`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FilterType {
    /// Keeps frequency bins at or below `cutoff`.
    LowPass,
    /// Keeps frequency bins at or above `cutoff`.
    HighPass,
    /// Keeps frequency bins between `cutoff` and `cutoff_high` (inclusive).
    BandPass,
    /// Keeps frequency bins at or below `cutoff` and at or above `cutoff_high`.
    BandStop,
    /// Uses the frequency response supplied in `custom_coeffs`.
    Custom,
}
26
/// Taper applied when designing a filter.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FilterWindow {
    /// No taper (all weights 1).
    Rectangular,
    /// Hamming window: `0.54 - 0.46 cos(2πk/M)`.
    Hamming,
    /// Hann ("Hanning") window: `0.5 (1 - cos(2πk/M))`.
    Hanning,
    /// Blackman window: `0.42 - 0.5 cos(x) + 0.08 cos(2x)`, `x = 2πk/M`.
    Blackman,
    /// Kaiser window; its shape parameter is `FilterSpec::kaiser_beta` (3.0 when unset).
    Kaiser,
}
41
42#[derive(Debug, Clone)]
44pub struct FilterSpec {
45 pub filter_type: FilterType,
47 pub order: usize,
49 pub cutoff: f64,
51 pub cutoff_high: Option<f64>,
53 pub window: FilterWindow,
55 pub kaiser_beta: Option<f64>,
57 pub custom_coeffs: Option<Vec<f64>>,
59}
60
61impl Default for FilterSpec {
62 fn default() -> Self {
63 Self {
64 filter_type: FilterType::LowPass,
65 order: 64,
66 cutoff: 0.25,
67 cutoff_high: None,
68 window: FilterWindow::Hamming,
69 kaiser_beta: None,
70 custom_coeffs: None,
71 }
72 }
73}
74
75#[allow(dead_code)]
86pub fn frequency_filter<T>(signal: &[T], filterspec: &FilterSpec) -> FFTResult<Vec<f64>>
87where
88 T: NumCast + Copy + Debug,
89{
90 let max_size = 1024;
92 let limit = signal.len().min(max_size);
93
94 let input: Vec<f64> = signal
96 .iter()
97 .take(limit)
98 .map(|&val| {
99 NumCast::from(val)
100 .ok_or_else(|| FFTError::ValueError(format!("Could not convert {:?} to f64", val)))
101 })
102 .collect::<FFTResult<Vec<_>>>()?;
103
104 let spectrum = fft(&input, None)?;
106
107 let freq_response = design_frequency_response(filterspec, spectrum.len())?;
109
110 let filtered_spectrum: Vec<Complex64> = spectrum
112 .iter()
113 .zip(&freq_response)
114 .map(|(&s, &r)| s * r)
115 .collect();
116
117 let result = ifft(&filtered_spectrum, None)?;
119
120 let filtered: Vec<f64> = result.iter().map(|c| c.re).collect();
122
123 Ok(filtered)
124}
125
126#[allow(dead_code)]
137fn design_frequency_response(filter_spec: &FilterSpec, size: usize) -> FFTResult<Vec<f64>> {
138 if let Some(ref coeffs) = filter_spec.custom_coeffs {
139 if filter_spec.filter_type == FilterType::Custom {
140 return Ok(coeffs.clone());
142 }
143 }
144
145 let mut response = vec![0.0; size];
146
147 let cutoff_idx = (filter_spec.cutoff * size as f64) as usize;
149 let cutoff_high_idx = filter_spec
150 .cutoff_high
151 .map(|c| (c * size as f64) as usize)
152 .unwrap_or(cutoff_idx);
153
154 match filter_spec.filter_type {
155 FilterType::LowPass => {
156 for i in 0..=cutoff_idx.min(size / 2) {
157 response[i] = 1.0;
158 if i > 0 && i < size / 2 {
159 response[size - i] = 1.0;
160 }
161 }
162 }
163 FilterType::HighPass => {
164 for i in cutoff_idx..=size / 2 {
165 response[i] = 1.0;
166 if i > 0 && i < size / 2 {
167 response[size - i] = 1.0;
168 }
169 }
170 }
171 FilterType::BandPass => {
172 for i in cutoff_idx..=cutoff_high_idx.min(size / 2) {
173 response[i] = 1.0;
174 if i > 0 && i < size / 2 {
175 response[size - i] = 1.0;
176 }
177 }
178 }
179 FilterType::BandStop => {
180 for i in 0..=size / 2 {
181 if i <= cutoff_idx || i >= cutoff_high_idx {
182 response[i] = 1.0;
183 if i > 0 && i < size / 2 {
184 response[size - i] = 1.0;
185 }
186 }
187 }
188 }
189 FilterType::Custom => {
190 return Err(FFTError::ValueError(
191 "Custom filter type requires custom_coeffs to be provided".to_string(),
192 ));
193 }
194 }
195
196 apply_window_to_response(&mut response, filter_spec);
198
199 Ok(response)
200}
201
202#[allow(dead_code)]
209fn apply_window_to_response(response: &mut [f64], filterspec: &FilterSpec) {
210 let size = response.len();
212
213 match filterspec.window {
214 FilterWindow::Rectangular => {
215 }
217 FilterWindow::Hamming => {
218 for i in 0..size {
219 if response[i] > 0.0 {
220 let window_val =
221 0.54 - 0.46 * (2.0 * std::f64::consts::PI * i as f64 / size as f64).cos();
222 response[i] *= window_val;
223 }
224 }
225 }
226 FilterWindow::Hanning => {
227 for i in 0..size {
228 if response[i] > 0.0 {
229 let window_val =
230 0.5 * (1.0 - (2.0 * std::f64::consts::PI * i as f64 / size as f64).cos());
231 response[i] *= window_val;
232 }
233 }
234 }
235 FilterWindow::Blackman => {
236 for i in 0..size {
237 if response[i] > 0.0 {
238 let x = 2.0 * std::f64::consts::PI * i as f64 / size as f64;
239 let window_val = 0.42 - 0.5 * x.cos() + 0.08 * (2.0 * x).cos();
240 response[i] *= window_val;
241 }
242 }
243 }
244 FilterWindow::Kaiser => {
245 let beta = filterspec.kaiser_beta.unwrap_or(3.0);
246 for i in 0..size {
248 if response[i] > 0.0 {
249 let x = 2.0 * i as f64 / size as f64 - 1.0;
250 let window_val = bessel_i0(beta * (1.0 - x * x).sqrt()) / bessel_i0(beta);
251 response[i] *= window_val;
252 }
253 }
254 }
255 }
256}
257
/// Modified Bessel function of the first kind of order zero, `I0(x)`.
///
/// Uses the polynomial approximations of Abramowitz & Stegun (9.8.1 for
/// `|x| < 3.75`, 9.8.2 otherwise), evaluated with Horner's rule. `I0` is even,
/// so the sign of `x` does not matter.
#[allow(dead_code)]
fn bessel_i0(x: f64) -> f64 {
    // Coefficients of t^2, t^4, ..., t^12 after the leading 1, with t = x / 3.75.
    const NEAR: [f64; 6] = [
        3.5156229, 3.0899424, 1.2067492, 0.2659732, 0.0360768, 0.0045813,
    ];
    // Coefficients of s^0 .. s^8 for sqrt(|x|) e^(-|x|) I0(x), with s = 3.75 / |x|.
    const FAR: [f64; 9] = [
        0.39894228,
        0.01328592,
        0.00225319,
        -0.00157565,
        0.00916281,
        -0.02057706,
        0.02635537,
        -0.01647633,
        0.00392377,
    ];

    // Evaluates c[0] + t*(c[1] + t*(... + t*c[last])) from the innermost term out.
    fn horner(coeffs: &[f64], t: f64) -> f64 {
        let mut terms = coeffs.iter().rev();
        let mut acc = terms.next().copied().unwrap_or(0.0);
        for &c in terms {
            acc = c + t * acc;
        }
        acc
    }

    let magnitude = x.abs();
    if magnitude < 3.75 {
        let t = (x / 3.75).powi(2);
        1.0 + t * horner(&NEAR, t)
    } else {
        let s = 3.75 / magnitude;
        (magnitude.exp() / magnitude.sqrt()) * horner(&FAR, s)
    }
}
290
291#[allow(dead_code)]
302pub fn convolve<T, U>(signal: &[T], kernel: &[U]) -> FFTResult<Vec<f64>>
303where
304 T: NumCast + Copy + Debug,
305 U: NumCast + Copy + Debug,
306{
307 let max_size = 512;
309 let signal_len = signal.len().min(max_size);
310 let kernel_len = kernel.len().min(max_size);
311 let result_len = signal_len + kernel_len - 1;
312
313 let signal_f64: Vec<f64> = signal
315 .iter()
316 .take(signal_len)
317 .map(|&val| {
318 NumCast::from(val)
319 .ok_or_else(|| FFTError::ValueError(format!("Could not convert {:?} to f64", val)))
320 })
321 .collect::<FFTResult<Vec<_>>>()?;
322
323 let kernel_f64: Vec<f64> = kernel
324 .iter()
325 .take(kernel_len)
326 .map(|&val| {
327 NumCast::from(val)
328 .ok_or_else(|| FFTError::ValueError(format!("Could not convert {:?} to f64", val)))
329 })
330 .collect::<FFTResult<Vec<_>>>()?;
331
332 let mut signal_padded = signal_f64;
334 signal_padded.resize(result_len, 0.0);
335
336 let mut kernel_padded = kernel_f64;
337 kernel_padded.resize(result_len, 0.0);
338
339 let signal_fft = fft(&signal_padded, None)?;
341 let kernel_fft = fft(&kernel_padded, None)?;
342
343 let mut result_fft = Vec::with_capacity(result_len);
345 for i in 0..result_len {
346 result_fft.push(signal_fft[i] * kernel_fft[i]);
347 }
348
349 let result_complex = ifft(&result_fft, None)?;
351
352 let result: Vec<f64> = result_complex.iter().map(|c| c.re).collect();
354
355 Ok(result)
356}
357
358#[allow(dead_code)]
369pub fn cross_correlate<T, U>(signal1: &[T], signal2: &[U]) -> FFTResult<Vec<f64>>
370where
371 T: NumCast + Copy + Debug,
372 U: NumCast + Copy + Debug,
373{
374 let signal1_len = signal1.len();
376 let signal2_len = signal2.len();
377 let result_len = signal1_len + signal2_len - 1;
378
379 let signal1_f64: Vec<f64> = signal1
381 .iter()
382 .map(|&val| {
383 NumCast::from(val)
384 .ok_or_else(|| FFTError::ValueError(format!("Could not convert {:?} to f64", val)))
385 })
386 .collect::<FFTResult<Vec<_>>>()?;
387
388 let signal2_f64: Vec<f64> = signal2
389 .iter()
390 .map(|&val| {
391 NumCast::from(val)
392 .ok_or_else(|| FFTError::ValueError(format!("Could not convert {:?} to f64", val)))
393 })
394 .collect::<FFTResult<Vec<_>>>()?;
395
396 let mut signal1_padded = signal1_f64.clone();
398 signal1_padded.resize(result_len, 0.0);
399
400 let mut signal2_padded = signal2_f64.clone();
401 signal2_padded.resize(result_len, 0.0);
402
403 let signal1_fft = fft(&signal1_padded, None)?;
405 let signal2_fft = fft(&signal2_padded, None)?;
406
407 let mut result_fft = Vec::with_capacity(result_len);
409 for i in 0..result_len {
410 result_fft.push(signal1_fft[i] * signal2_fft[i].conj());
411 }
412
413 let result_complex = ifft(&result_fft, None)?;
415
416 let result: Vec<f64> = result_complex.iter().map(|c| c.re).collect();
418
419 Ok(result)
420}
421
422#[allow(dead_code)]
432pub fn design_fir_filter(filter_spec: &FilterSpec) -> FFTResult<Vec<f64>> {
433 let order = filter_spec.order;
434
435 let adjusted_order = if order % 2 == 0 { order + 1 } else { order };
437
438 let n_freqs = 2048; let freq_response = design_frequency_response(filter_spec, n_freqs)?;
441
442 let mut complex_response = vec![Complex64::new(0.0, 0.0); n_freqs];
444 for i in 0..n_freqs {
445 complex_response[i] = Complex64::new(freq_response[i], 0.0);
446 }
447
448 let impulse_response = ifft(&complex_response, None)?;
449
450 let half_order = adjusted_order / 2;
452 let mut coeffs = vec![0.0; adjusted_order];
453
454 for i in 0..adjusted_order {
455 let idx = (i + n_freqs - half_order) % n_freqs;
456 coeffs[i] = impulse_response[idx].re;
457 }
458
459 let mut window = vec![0.0; adjusted_order];
461 match filter_spec.window {
462 FilterWindow::Rectangular => {
463 window.iter_mut().for_each(|w| *w = 1.0);
464 }
465 FilterWindow::Hamming => {
466 for i in 0..adjusted_order {
467 window[i] = 0.54
468 - 0.46
469 * (2.0 * std::f64::consts::PI * i as f64 / (adjusted_order - 1) as f64)
470 .cos();
471 }
472 }
473 FilterWindow::Hanning => {
474 for i in 0..adjusted_order {
475 window[i] = 0.5
476 * (1.0
477 - (2.0 * std::f64::consts::PI * i as f64 / (adjusted_order - 1) as f64)
478 .cos());
479 }
480 }
481 FilterWindow::Blackman => {
482 for i in 0..adjusted_order {
483 let x = 2.0 * std::f64::consts::PI * i as f64 / (adjusted_order - 1) as f64;
484 window[i] = 0.42 - 0.5 * x.cos() + 0.08 * (2.0 * x).cos();
485 }
486 }
487 FilterWindow::Kaiser => {
488 let beta = filter_spec.kaiser_beta.unwrap_or(3.0);
489 for i in 0..adjusted_order {
490 let x = 2.0 * i as f64 / (adjusted_order - 1) as f64 - 1.0;
491 window[i] = bessel_i0(beta * (1.0 - x * x).sqrt()) / bessel_i0(beta);
492 }
493 }
494 }
495
496 for i in 0..adjusted_order {
498 coeffs[i] *= window[i];
499 }
500
501 let dc_gain: f64 = coeffs.iter().sum();
503 if dc_gain != 0.0 {
504 for coeff in &mut coeffs {
505 *coeff /= dc_gain;
506 }
507 }
508
509 Ok(coeffs)
510}
511
512#[allow(dead_code)]
523pub fn fir_filter<T>(signal: &[T], filtercoeffs: &[f64]) -> FFTResult<Vec<f64>>
524where
525 T: NumCast + Copy + Debug,
526{
527 convolve(signal, filtercoeffs)
528}
529
#[cfg(test)]
// NOTE(review): this module only compiles when a `never` feature is enabled,
// presumably to keep these tests switched off. Confirm the intent.
#[cfg(feature = "never")]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use std::f64::consts::PI;

    #[test]
    fn test_frequency_filter_lowpass() {
        let n = 128;

        // A unit tone at 2 cycles per window plus a half-amplitude tone at 10.
        let signal: Vec<f64> = (0..n)
            .map(|i| {
                let slow = (2.0 * PI * 2.0 * i as f64 / n as f64).sin();
                let fast = 0.5 * (2.0 * PI * 10.0 * i as f64 / n as f64).sin();
                slow + fast
            })
            .collect();

        let spec = FilterSpec {
            filter_type: FilterType::LowPass,
            order: 32,
            cutoff: 0.25,
            window: FilterWindow::Hamming,
            ..Default::default()
        };

        let filtered = frequency_filter(&signal, &spec).expect("Operation failed");

        assert_eq!(filtered.len(), signal.len());
    }

    #[test]
    fn test_convolution() {
        let signal = vec![1.0, 2.0, 3.0];
        let kernel = vec![0.5, 0.5];

        let result = convolve(&signal, &kernel).expect("Operation failed");

        assert_eq!(result.len(), signal.len() + kernel.len() - 1);
        let expected = [0.5, 1.5, 2.5, 1.5];
        for (got, want) in result.iter().zip(expected.iter()) {
            assert_relative_eq!(*got, *want, epsilon = 1e-10);
        }
    }
}