rustradio/
fir.rs

1//! Finite impulse response filter.
2//!
3//! If using many taps, [`FftFilter`](crate::blocks::FftFilter) probably has
4//! better performance.
5/*
6 * TODO:
7 * * Only handles case where input, output, and tap type are all the same.
8 */
9use crate::block::{Block, BlockRet};
10use crate::stream::{ReadStream, WriteStream};
11use crate::window::{Window, WindowType};
12use crate::{Complex, Float, Result, Sample};
13
/// Finite impulse response (FIR) filter.
///
/// The taps are kept in reverse order (see [`Fir::new`]), so that running the
/// filter is a plain element-wise dot product against a window of input.
pub struct Fir<T> {
    /// Filter taps, reversed relative to the order they were supplied in.
    taps: Vec<T>,
}
18
/// Compute the dot product `sum(vec1[i] * vec2[i])` using AVX.
///
/// Full groups of 8 elements are multiplied and accumulated in a 256-bit
/// register. The remaining `len % 8` elements are folded in with scalar code.
///
/// # Panics
///
/// Panics if `vec1` and `vec2` have different lengths.
#[cfg(all(
    target_feature = "avx",
    target_feature = "sse3",
    target_feature = "sse"
))]
fn sum_product_avx(vec1: &[f32], vec2: &[f32]) -> f32 {
    // The cfg above only checks CPU features, not the architecture. The same
    // intrinsics live in `core::arch::x86` on 32-bit x86 and in
    // `core::arch::x86_64` on 64-bit x86.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    assert_eq!(vec1.len(), vec2.len());
    // Number of leading elements handled by the vectorized loop.
    let simd_len = vec1.len() - vec1.len() % 8;

    // SAFETY: The `cfg` on this function guarantees that AVX, SSE3 and SSE are
    // enabled at compile time, so the intrinsics are available. Each load
    // reads 8 f32s starting at index `i`, and `i + 8 <= simd_len <= len` for
    // both slices (equal lengths are asserted above). `_mm256_loadu_ps` has no
    // alignment requirement.
    let partial = unsafe {
        let mut sum = _mm256_setzero_ps();
        for i in (0..simd_len).step_by(8) {
            let a = _mm256_loadu_ps(vec1.as_ptr().add(i));
            let b = _mm256_loadu_ps(vec2.as_ptr().add(i));
            sum = _mm256_add_ps(sum, _mm256_mul_ps(a, b));
        }

        // Horizontal reduction of the 8 lanes down to one value.
        let low = _mm256_extractf128_ps(sum, 0);
        let high = _mm256_extractf128_ps(sum, 1);
        // [l0+l1, l2+l3, h0+h1, h2+h3]
        let m128 = _mm_hadd_ps(low, high);
        // [l0+l1+l2+l3, h0+h1+h2+h3, ...]
        let m128 = _mm_hadd_ps(m128, low);
        // Lane 0 now holds the sum of all 8 lanes.
        let m128 = _mm_hadd_ps(m128, low);
        _mm_cvtss_f32(m128)
    };

    // Scalar tail for the elements that don't fill a full vector.
    vec1[simd_len..]
        .iter()
        .zip(vec2[simd_len..].iter())
        .fold(partial, |acc, (&f, &x)| acc + x * f)
}
72
73impl Fir<Float> {
74    /// Run filter once, creating one sample from the taps and an
75    /// equal number of input samples.
76    #[must_use]
77    pub fn filter_float(&self, input: &[Float]) -> Float {
78        // AVX is faster, when available.
79        #[cfg(all(
80            target_feature = "avx",
81            target_feature = "sse3",
82            target_feature = "sse"
83        ))]
84        return sum_product_avx(&self.taps, input);
85        // Second fastest is generic simd.
86        #[cfg(feature = "simd")]
87        #[allow(unreachable_code)]
88        {
89            use std::simd::num::SimdFloat;
90            let batch_n = 8;
91            // How will this work if Float is f64?
92            type Batch = std::simd::f32x8;
93            let partial = input
94                .chunks_exact(batch_n)
95                .zip(self.taps.chunks_exact(batch_n))
96                .map(|(a, b)| Batch::from_slice(a) * Batch::from_slice(b))
97                .fold(Batch::splat(0.0), |acc, x| acc + x)
98                .reduce_sum();
99            // Maybe even faster if doing a second round with f32x4.
100            let skip = self.taps.len() - self.taps.len() % batch_n;
101            return input[skip..]
102                .iter()
103                .zip(self.taps[skip..].iter())
104                .fold(partial, |acc, (&f, &x)| acc + x * f);
105        }
106        #[allow(unreachable_code)]
107        self.filter(input)
108    }
109}
110
111impl<T> Fir<T>
112where
113    T: Sample + std::ops::Mul<T, Output = T> + std::ops::Add<T, Output = T>,
114{
115    /// Create new Fir.
116    #[must_use]
117    pub fn new(taps: &[T]) -> Self {
118        Self {
119            taps: taps.iter().copied().rev().collect(),
120        }
121    }
122    /// Run filter once, creating one sample from the taps and an
123    /// equal number of input samples.
124    #[must_use]
125    pub fn filter(&self, input: &[T]) -> T {
126        assert!(
127            input.len() >= self.taps.len(),
128            "input {} < taps {}",
129            input.len(),
130            self.taps.len()
131        );
132        input
133            .iter()
134            .zip(self.taps.iter())
135            .fold(T::default(), |acc, (&f, &x)| acc + x * f)
136    }
137
138    /// Call `filter()` multiple times, across an input range.
139    #[must_use]
140    pub fn filter_n(&self, input: &[T], deci: usize) -> Vec<T> {
141        let n = input.len() - self.taps.len();
142        (0..=n)
143            .step_by(deci)
144            .map(|i| self.filter(&input[i..]))
145            .collect()
146    }
147
148    /// Like `filter_n`, but avoids a copy when there's a destination in mind.
149    pub fn filter_n_inplace(&self, input: &[T], deci: usize, out: &mut [T]) {
150        out.iter_mut()
151            .enumerate()
152            .for_each(|(i, o)| *o = self.filter(&input[(i * deci)..]));
153    }
154}
155
/// Builder for a FIR filter block.
///
/// Use a builder (from [`FirFilter::builder`]) when the block should
/// decimate; [`FirFilter::new`] always creates a non-decimating block.
pub struct FirFilterBuilder<T> {
    /// Filter taps, in the order they were given to [`FirFilter::builder`].
    taps: Vec<T>,
    /// Decimation factor: one output sample per `deci` input samples.
    deci: usize,
}
163
164impl<T> FirFilterBuilder<T>
165where
166    T: Sample + std::ops::Mul<T, Output = T> + std::ops::Add<T, Output = T>,
167{
168    /// Set the decimation to the given value.
169    ///
170    /// The default is 1, meaning no decimation.
171    #[must_use]
172    pub fn deci(mut self, deci: usize) -> Self {
173        self.deci = deci;
174        self
175    }
176
177    /// Build a `FirFilter` with the provided settings.
178    #[must_use]
179    pub fn build(self, src: ReadStream<T>) -> (FirFilter<T>, ReadStream<T>) {
180        let (mut block, stream) = FirFilter::new(src, &self.taps);
181        block.deci = self.deci;
182        (block, stream)
183    }
184}
185
/// Finite impulse response filter block.
///
/// Reads samples from `src`, runs them through the FIR filter (optionally
/// decimating), and writes the result to `dst`. Create it with
/// [`FirFilter::new`], or with [`FirFilter::builder`] when decimation is
/// needed.
#[derive(rustradio_macros::Block)]
#[rustradio(crate)]
pub struct FirFilter<T: Sample> {
    // The filter itself (taps stored reversed by `Fir::new`).
    fir: Fir<T>,
    // Number of taps, i.e. `taps.len()` as passed to `new`.
    ntaps: usize,
    // Decimation factor: one output sample per `deci` consumed input samples.
    deci: usize,
    // Input stream.
    #[rustradio(in)]
    src: ReadStream<T>,
    // Output stream.
    #[rustradio(out)]
    dst: WriteStream<T>,
}
198
199impl<T> FirFilter<T>
200where
201    T: Sample + std::ops::Mul<T, Output = T> + std::ops::Add<T, Output = T>,
202{
203    /// Create new `FirFilterBuilder`, with the supplied taps.
204    pub fn builder(taps: &[T]) -> FirFilterBuilder<T> {
205        FirFilterBuilder {
206            taps: taps.to_vec(),
207            deci: 1,
208        }
209    }
210    /// Create Fir block given taps.
211    pub fn new(src: ReadStream<T>, taps: &[T]) -> (Self, ReadStream<T>) {
212        let (dst, dr) = crate::stream::new_stream();
213        (
214            Self {
215                src,
216                dst,
217                ntaps: taps.len(),
218                deci: 1,
219                fir: Fir::new(taps),
220            },
221            dr,
222        )
223    }
224}
225
impl<T> Block for FirFilter<T>
where
    T: Sample + std::ops::Mul<T, Output = T> + std::ops::Add<T, Output = T>,
{
    /// Filter as much of the available input as fits in the output buffer.
    ///
    /// Consumes a multiple of `deci` input samples, and produces one output
    /// sample per `deci` samples consumed. Tags read from the input are
    /// passed on with their positions divided by `deci`.
    ///
    /// Returns `WaitForStream` when there is not enough input for one output
    /// sample, or no room in the output buffer. Otherwise returns `Again`.
    fn work(&mut self) -> Result<BlockRet<'_>> {
        let (input, mut tags) = self.src.read_buf()?;

        // Get number of input samples we intend to consume.
        let n = {
            // Carefully avoid underflow.
            // Consuming `deci` samples (one output) needs a window of `ntaps`
            // samples starting at each consumed position, hence
            // `ntaps + deci - 1` samples at minimum.
            let absolute_minimum = self.ntaps + self.deci - 1;
            if input.len() < absolute_minimum {
                return Ok(BlockRet::WaitForStream(&self.src, absolute_minimum));
            }
            // `len - ntaps + 1` full windows fit in the input; round that
            // count down to a multiple of `deci`.
            self.deci * ((input.len() - self.ntaps + 1) / self.deci)
        };
        assert_ne!(n, 0);

        // To consume `n`, we may need more input samples than that.
        let need = n + self.ntaps - 1;
        assert!(input.len() >= need, "need {need}, have {}", input.len());

        // Output must have room for at least one sample.
        let mut out = self.dst.write_buf()?;
        let need_out = 1;
        if out.len() < need_out {
            return Ok(BlockRet::WaitForStream(&self.dst, need_out));
        }

        // Cap by output capacity.
        // `need` is not recomputed after capping: it can only shrink, and the
        // filter only reads the samples each window actually needs.
        let n = std::cmp::min(n, out.len() * self.deci);

        // Final `n` (samples to consume) calculated. Sanity check it.
        assert_eq!(n % self.deci, 0);
        assert_ne!(n, 0, "input: {} out: {}", input.len(), out.len());

        // Run the FIR.
        let out_n = n / self.deci;
        self.fir
            .filter_n_inplace(&input.slice()[..need], self.deci, &mut out.slice()[..out_n]);

        // Sanity check the generated output.
        assert!(out_n <= out.len());

        input.consume(n);
        // NOTE(review): `tags` covers everything returned by `read_buf()`,
        // including positions past the `n` samples consumed here. Presumably
        // `produce` discards tags beyond `out_n` — confirm.
        if self.deci == 1 {
            out.produce(out_n, &tags);
        } else {
            // Map each input position to the output sample it falls into.
            tags.iter_mut().for_each(|t| t.set_pos(t.pos() / self.deci));
            out.produce(out_n, &tags);
        }
        // While we could keep track of which stream is the constraining factor,
        // the code is simpler if work() is just called again, and the right
        // WaitForStream is returned above instead.
        Ok(BlockRet::Again)
    }
}
283
284/// Create a multiband filter.
285///
286/// TODO: this is untested.
287#[must_use]
288pub fn multiband(bands: &[(Float, Float)], taps: usize, window: &Window) -> Option<Vec<Complex>> {
289    if taps != window.0.len() {
290        return None;
291    }
292    use rustfft::FftPlanner;
293
294    let mut ideal = vec![Complex::new(0.0, 0.0); taps];
295    let scale = (taps as Float) / 2.0;
296    for (low, high) in bands {
297        let a = (low * scale).floor() as usize;
298        let b = (high * scale).ceil() as usize;
299        for n in a..b {
300            ideal[n] = Complex::new(1.0, 0.0);
301            ideal[taps - n - 1] = Complex::new(1.0, 0.0);
302        }
303    }
304    let fft_size = taps;
305    let mut planner = FftPlanner::new();
306    let ifft = planner.plan_fft_inverse(fft_size);
307    ifft.process(&mut ideal);
308    ideal.rotate_right(taps / 2);
309    let scale = (fft_size as Float).sqrt();
310    Some(
311        ideal
312            .into_iter()
313            .enumerate()
314            .map(|(n, v)| v * window.0[n] / Complex::new(scale, 0.0))
315            .collect(),
316    )
317}
318
319/// Create taps for a low pass filter as complex taps.
320#[must_use]
321pub fn low_pass_complex(
322    samp_rate: Float,
323    cutoff: Float,
324    twidth: Float,
325    window_type: &WindowType,
326) -> Vec<Complex> {
327    low_pass(samp_rate, cutoff, twidth, window_type)
328        .into_iter()
329        .map(|t| Complex::new(t, 0.0))
330        .collect()
331}
332
333fn compute_ntaps(samp_rate: Float, twidth: Float, window_type: &WindowType) -> usize {
334    let a = window_type.max_attenuation();
335    let t = (a * samp_rate / (22.0 * twidth)) as usize;
336    if (t & 1) == 0 { t + 1 } else { t }
337}
338
339/// Create taps for a low pass filter.
340///
341/// TODO: this could be faster if we supported filtering a Complex by a Float.
342/// A low pass filter doesn't actually need complex taps.
343#[must_use]
344pub fn low_pass(
345    samp_rate: Float,
346    cutoff: Float,
347    twidth: Float,
348    window_type: &WindowType,
349) -> Vec<Float> {
350    let pi = std::f64::consts::PI as Float;
351    let ntaps = compute_ntaps(samp_rate, twidth, window_type);
352    let window = window_type.make_window(ntaps);
353    let m = (ntaps - 1) / 2;
354    let fwt0 = 2.0 * pi * cutoff / samp_rate;
355    let taps: Vec<_> = window
356        .0
357        .iter()
358        .enumerate()
359        .map(|(nm, win)| {
360            let n = nm as i64 - m as i64;
361            let nf = n as Float;
362            if n == 0 {
363                fwt0 / pi * win
364            } else {
365                ((nf * fwt0).sin() / (nf * pi)) * win
366            }
367        })
368        .collect();
369    let gain = {
370        let gain: Float = 1.0;
371        let mut fmax = taps[m];
372        for n in 1..=m {
373            fmax += 2.0 * taps[n + m];
374        }
375        gain / fmax
376    };
377    taps.into_iter().map(|t| t * gain).collect()
378}
379
380/// Generate hilbert transformer filter.
381#[must_use]
382pub fn hilbert(window: &Window) -> Vec<Float> {
383    let ntaps = window.0.len();
384    let mid = (ntaps - 1) / 2;
385    let mut gain = 0.0;
386    let mut taps = vec![0.0; ntaps];
387    for i in 1..=mid {
388        if i & 1 == 1 {
389            let x = 1.0 / (i as Float);
390            taps[mid + i] = x * window.0[mid + i];
391            taps[mid - i] = -x * window.0[mid - i];
392            gain = taps[mid + i] - gain;
393        } else {
394            taps[mid + i] = 0.0;
395            taps[mid - i] = 0.0;
396        }
397    }
398    let gain = 1.0 / (2.0 * gain.abs());
399    taps.iter().map(|e| gain * *e).collect()
400}
401
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use crate::Repeat;
    use crate::blocks::VectorSource;
    use crate::stream::{Tag, TagValue};
    use crate::tests::assert_almost_equal_complex;

    /// A single unit tap passes the input through unchanged, keeping every
    /// `deci`:th sample. Also checks that tags are carried through.
    #[test]
    fn test_identity() -> Result<()> {
        let input = vec![
            Complex::new(1.0, 0.0),
            Complex::new(2.0, 0.0),
            Complex::new(3.0, 0.2),
            Complex::new(4.1, 0.0),
            Complex::new(5.0, 0.0),
            Complex::new(6.0, 0.2),
        ];
        let taps = vec![Complex::new(1.0, 0.0)];
        for deci in 1..=(3 * input.len()) {
            let (mut src, src_out) = VectorSource::builder(input.clone())
                .repeat(Repeat::finite(2))
                .build()?;
            assert!(matches![src.work()?, BlockRet::Again]);
            assert!(matches![src.work()?, BlockRet::EOF]);

            eprintln!("Testing identity with decimation {deci}");
            let (mut b, os) = FirFilter::builder(&taps).deci(deci).build(src_out);
            if deci <= 2 * input.len() {
                assert!(matches![b.work()?, BlockRet::Again]);
            }
            assert!(matches![b.work()?, BlockRet::WaitForStream(_, _)]);
            let (res, tags) = os.read_buf()?;
            let max = 2 * input.len() / deci;
            if !res.is_empty() {
                assert_eq!(
                    &tags,
                    &[
                        Tag::new(0, "VectorSource::start", TagValue::Bool(true)),
                        Tag::new(0, "VectorSource::repeat", TagValue::U64(0)),
                        Tag::new(0, "VectorSource::first", TagValue::Bool(true)),
                        Tag::new(6 / deci, "VectorSource::start", TagValue::Bool(true)),
                        Tag::new(6 / deci, "VectorSource::repeat", TagValue::U64(1)),
                    ]
                );
            }
            assert_almost_equal_complex(
                res.slice(),
                &input
                    .iter()
                    .chain(input.iter())
                    .copied()
                    .step_by(deci)
                    .take(max)
                    .collect::<Vec<_>>(),
            );
        }
        Ok(())
    }

    /// A single tap of -1 negates the input, keeping every `deci`:th sample.
    #[test]
    fn test_invert() -> Result<()> {
        let input = vec![
            Complex::new(1.0, 0.0),
            Complex::new(2.0, 0.0),
            Complex::new(3.0, 0.2),
            Complex::new(4.1, 0.0),
            Complex::new(5.0, 0.0),
            Complex::new(6.0, 0.2),
        ];
        let taps = vec![Complex::new(-1.0, 0.0)];
        for deci in 1..=(input.len() + 1) {
            let (mut src, src_out) = VectorSource::new(input.clone());
            src.work()?;

            eprintln!("Testing invert with decimation {deci}");
            let (mut b, os) = FirFilter::builder(&taps).deci(deci).build(src_out);
            if deci <= input.len() {
                assert!(matches![b.work()?, BlockRet::Again]);
            }
            assert!(matches![b.work()?, BlockRet::WaitForStream(_, _)]);
            let (res, _) = os.read_buf()?;
            let max = input.len() / deci;
            assert_almost_equal_complex(
                res.slice(),
                &input
                    .iter()
                    .copied()
                    .step_by(deci)
                    .take(max)
                    .map(|v| -v)
                    .collect::<Vec<_>>(),
            );
        }
        Ok(())
    }

    /// Two taps of 0.5 average each pair of adjacent samples.
    #[test]
    fn moving_avg() -> Result<()> {
        let input = vec![
            Complex::new(1.0, 0.0),
            Complex::new(2.0, 0.0),
            Complex::new(3.0, 0.2),
            Complex::new(4.1, 0.0),
            Complex::new(5.0, 0.0),
            Complex::new(6.0, 0.2),
        ];
        let taps = vec![Complex::new(0.5, 0.0), Complex::new(0.5, 0.0)];
        for deci in 1..=(input.len() + 1) {
            let (mut src, src_out) = VectorSource::new(input.clone());
            src.work()?;

            eprintln!("Testing moving average with decimation {deci}");
            let (mut b, os) = FirFilter::builder(&taps).deci(deci).build(src_out);
            if deci < input.len() {
                assert!(matches![b.work()?, BlockRet::Again]);
            }
            assert!(matches![b.work()?, BlockRet::WaitForStream(_, _)]);
            let (res, _) = os.read_buf()?;
            let max = (input.len() - 1) / deci;
            assert_almost_equal_complex(
                res.slice(),
                &[
                    Complex::new(1.5, 0.0),
                    Complex::new(2.5, 0.1),
                    Complex::new(3.55, 0.1),
                    Complex::new(4.55, 0.0),
                    Complex::new(5.5, 0.1),
                ]
                .into_iter()
                .step_by(deci)
                .take(max)
                .collect::<Vec<_>>(),
            );
        }
        Ok(())
    }

    /// Complex taps applied directly through `Fir::filter_n`.
    #[test]
    fn test_complex() {
        let input = vec![
            Complex::new(1.0, 0.0),
            Complex::new(2.0, 0.0),
            Complex::new(3.0, 0.2),
            Complex::new(4.1, 0.0),
            Complex::new(5.0, 0.0),
            Complex::new(6.0, 0.2),
        ];
        let taps = vec![
            Complex::new(0.1, 0.0),
            Complex::new(1.0, 0.0),
            Complex::new(0.0, 0.2),
        ];
        let filter = Fir::new(&taps);
        assert_almost_equal_complex(
            &filter.filter_n(&input, 1),
            &[
                Complex::new(2.3, 0.22),
                Complex::new(3.41, 0.6),
                Complex::new(4.56, 0.6),
                Complex::new(5.6, 0.84),
            ],
        );
        assert_almost_equal_complex(
            &filter.filter_n(&input, 2),
            &[Complex::new(2.3, 0.22), Complex::new(4.56, 0.6)],
        );
    }

    /// Regression test for the low pass tap generator.
    #[test]
    fn test_filter_generator() {
        let taps = low_pass_complex(10000.0, 1000.0, 1000.0, &WindowType::Hamming);
        assert_eq!(taps.len(), 25);
        assert_almost_equal_complex(
            &taps,
            &[
                Complex::new(0.002010403, 0.0),
                Complex::new(0.0016210203, 0.0),
                Complex::new(7.851862e-10, 0.0),
                Complex::new(-0.0044467063, 0.0),
                Complex::new(-0.011685465, 0.0),
                Complex::new(-0.018134259, 0.0),
                Complex::new(-0.016773716, 0.0),
                Complex::new(-3.6538055e-9, 0.0),
                Complex::new(0.0358771, 0.0),
                Complex::new(0.08697697, 0.0),
                Complex::new(0.14148787, 0.0),
                Complex::new(0.18345332, 0.0),
                Complex::new(0.19922684, 0.0),
                Complex::new(0.1834533, 0.0),
                Complex::new(0.14148785, 0.0),
                Complex::new(0.08697697, 0.0),
                Complex::new(0.035877097, 0.0),
                Complex::new(-3.6538053e-9, 0.0),
                Complex::new(-0.016773716, 0.0),
                Complex::new(-0.018134257, 0.0),
                Complex::new(-0.011685458, 0.0),
                Complex::new(-0.0044467044, 0.0),
                Complex::new(7.851859e-10, 0.0),
                Complex::new(0.0016210207, 0.0),
                Complex::new(0.002010403, 0.0),
            ],
        );
    }
}