1use crate::block::{Block, BlockRet};
10use crate::stream::{ReadStream, WriteStream};
11use crate::window::{Window, WindowType};
12use crate::{Complex, Float, Result, Sample};
13
/// Finite impulse response (FIR) filter core, independent of any stream.
///
/// The taps are kept in reversed order (see `Fir::new`), so that each output
/// sample is a plain dot product of the stored taps with a window of input.
pub struct Fir<T> {
    // Filter taps, stored reversed relative to the order given to `Fir::new`.
    taps: Vec<T>,
}
18
/// Dot product of two equal-length `f32` slices using AVX.
///
/// The bulk of the slices is processed 8 lanes at a time; the trailing
/// `len % 8` elements are folded in with scalar arithmetic.
///
/// # Panics
///
/// Panics if `vec1` and `vec2` differ in length.
#[cfg(all(
    target_feature = "avx",
    target_feature = "sse3",
    target_feature = "sse"
))]
fn sum_product_avx(vec1: &[f32], vec2: &[f32]) -> f32 {
    // The cfg above also matches 32-bit x86 builds with AVX enabled, where the
    // intrinsics live in `core::arch::x86` rather than `core::arch::x86_64`.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    assert_eq!(vec1.len(), vec2.len());
    // Length of the prefix that splits into whole 8-lane blocks.
    let simd_len = vec1.len() - vec1.len() % 8;

    // SAFETY: AVX/SSE3/SSE availability is guaranteed by the cfg on this
    // function. Every unaligned load reads elements `i..i + 8` where
    // `i + 8 <= simd_len <= len`, and both slices have the same length
    // (asserted above), so all reads are in bounds.
    let partial = unsafe {
        let mut sum = _mm256_setzero_ps();
        for i in (0..simd_len).step_by(8) {
            let a = _mm256_loadu_ps(vec1.as_ptr().add(i));
            let b = _mm256_loadu_ps(vec2.as_ptr().add(i));
            sum = _mm256_add_ps(sum, _mm256_mul_ps(a, b));
        }

        // Horizontal reduction of the 8 accumulator lanes; after the third
        // hadd, lane 0 holds the sum of all eight.
        let low = _mm256_extractf128_ps(sum, 0);
        let high = _mm256_extractf128_ps(sum, 1);
        let m128 = _mm_hadd_ps(low, high);
        let m128 = _mm_hadd_ps(m128, low);
        let m128 = _mm_hadd_ps(m128, low);
        _mm_cvtss_f32(m128)
    };

    // Scalar tail for the elements that didn't fill a whole 8-lane block.
    vec1[simd_len..]
        .iter()
        .zip(&vec2[simd_len..])
        .fold(partial, |acc, (&f, &x)| acc + x * f)
}
72
73impl Fir<Float> {
74 #[must_use]
77 pub fn filter_float(&self, input: &[Float]) -> Float {
78 #[cfg(all(
80 target_feature = "avx",
81 target_feature = "sse3",
82 target_feature = "sse"
83 ))]
84 return sum_product_avx(&self.taps, input);
85 #[cfg(feature = "simd")]
87 #[allow(unreachable_code)]
88 {
89 use std::simd::num::SimdFloat;
90 let batch_n = 8;
91 type Batch = std::simd::f32x8;
93 let partial = input
94 .chunks_exact(batch_n)
95 .zip(self.taps.chunks_exact(batch_n))
96 .map(|(a, b)| Batch::from_slice(a) * Batch::from_slice(b))
97 .fold(Batch::splat(0.0), |acc, x| acc + x)
98 .reduce_sum();
99 let skip = self.taps.len() - self.taps.len() % batch_n;
101 return input[skip..]
102 .iter()
103 .zip(self.taps[skip..].iter())
104 .fold(partial, |acc, (&f, &x)| acc + x * f);
105 }
106 #[allow(unreachable_code)]
107 self.filter(input)
108 }
109}
110
111impl<T> Fir<T>
112where
113 T: Sample + std::ops::Mul<T, Output = T> + std::ops::Add<T, Output = T>,
114{
115 #[must_use]
117 pub fn new(taps: &[T]) -> Self {
118 Self {
119 taps: taps.iter().copied().rev().collect(),
120 }
121 }
122 #[must_use]
125 pub fn filter(&self, input: &[T]) -> T {
126 assert!(
127 input.len() >= self.taps.len(),
128 "input {} < taps {}",
129 input.len(),
130 self.taps.len()
131 );
132 input
133 .iter()
134 .zip(self.taps.iter())
135 .fold(T::default(), |acc, (&f, &x)| acc + x * f)
136 }
137
138 #[must_use]
140 pub fn filter_n(&self, input: &[T], deci: usize) -> Vec<T> {
141 let n = input.len() - self.taps.len();
142 (0..=n)
143 .step_by(deci)
144 .map(|i| self.filter(&input[i..]))
145 .collect()
146 }
147
148 pub fn filter_n_inplace(&self, input: &[T], deci: usize, out: &mut [T]) {
150 out.iter_mut()
151 .enumerate()
152 .for_each(|(i, o)| *o = self.filter(&input[(i * deci)..]));
153 }
154}
155
/// Builder for a [`FirFilter`] block; obtain one from `FirFilter::builder`.
pub struct FirFilterBuilder<T> {
    // Filter taps, in the order the caller supplied them.
    taps: Vec<T>,
    // Decimation factor: keep every `deci`-th output position.
    deci: usize,
}
163
164impl<T> FirFilterBuilder<T>
165where
166 T: Sample + std::ops::Mul<T, Output = T> + std::ops::Add<T, Output = T>,
167{
168 #[must_use]
172 pub fn deci(mut self, deci: usize) -> Self {
173 self.deci = deci;
174 self
175 }
176
177 #[must_use]
179 pub fn build(self, src: ReadStream<T>) -> (FirFilter<T>, ReadStream<T>) {
180 let (mut block, stream) = FirFilter::new(src, &self.taps);
181 block.deci = self.deci;
182 (block, stream)
183 }
184}
185
/// Stream block applying a FIR filter, with optional decimation.
///
/// Reads samples from `src` and writes to `dst` one output for every `deci`
/// input samples it consumes.
#[derive(rustradio_macros::Block)]
#[rustradio(crate)]
pub struct FirFilter<T: Sample> {
    // The filter itself (taps stored reversed).
    fir: Fir<T>,
    // Number of taps; `work` uses it to size the input window it needs.
    ntaps: usize,
    // Decimation factor (1 = no decimation).
    deci: usize,
    // Input stream.
    #[rustradio(in)]
    src: ReadStream<T>,
    // Output stream.
    #[rustradio(out)]
    dst: WriteStream<T>,
}
198
199impl<T> FirFilter<T>
200where
201 T: Sample + std::ops::Mul<T, Output = T> + std::ops::Add<T, Output = T>,
202{
203 pub fn builder(taps: &[T]) -> FirFilterBuilder<T> {
205 FirFilterBuilder {
206 taps: taps.to_vec(),
207 deci: 1,
208 }
209 }
210 pub fn new(src: ReadStream<T>, taps: &[T]) -> (Self, ReadStream<T>) {
212 let (dst, dr) = crate::stream::new_stream();
213 (
214 Self {
215 src,
216 dst,
217 ntaps: taps.len(),
218 deci: 1,
219 fir: Fir::new(taps),
220 },
221 dr,
222 )
223 }
224}
225
impl<T> Block for FirFilter<T>
where
    T: Sample + std::ops::Mul<T, Output = T> + std::ops::Add<T, Output = T>,
{
    /// Filter as much of the available input as the output buffer allows.
    ///
    /// Behaviour per call:
    ///
    /// * Waits for input until at least `ntaps + deci - 1` samples are
    ///   available.
    /// * Waits for output space until at least one sample can be written.
    /// * Consumes a whole multiple of `deci` input samples and produces one
    ///   output per `deci` consumed.
    /// * Forwards the input tags, with positions divided by `deci` when
    ///   decimating.
    fn work(&mut self) -> Result<BlockRet<'_>> {
        let (input, mut tags) = self.src.read_buf()?;

        // `n` = number of input samples to consume this call; always a
        // multiple of `deci`.
        let n = {
            // One full window (`ntaps`) plus enough extra samples that a
            // whole `deci` samples can be consumed.
            let absolute_minimum = self.ntaps + self.deci - 1;
            if input.len() < absolute_minimum {
                return Ok(BlockRet::WaitForStream(&self.src, absolute_minimum));
            }
            // `input.len() - ntaps + 1` start positions have a full window.
            // Round down to a multiple of `deci` so the next call starts on
            // an output position.
            self.deci * ((input.len() - self.ntaps + 1) / self.deci)
        };
        assert_ne!(n, 0);

        // Input samples spanned by windows starting anywhere in `0..n`.
        let need = n + self.ntaps - 1;
        assert!(input.len() >= need, "need {need}, have {}", input.len());

        let mut out = self.dst.write_buf()?;
        // At least one output slot is needed to make progress.
        let need_out = 1;
        if out.len() < need_out {
            return Ok(BlockRet::WaitForStream(&self.dst, need_out));
        }

        // Don't consume more input than the output buffer can absorb; the
        // result stays a multiple of `deci`.
        let n = std::cmp::min(n, out.len() * self.deci);

        assert_eq!(n % self.deci, 0);
        assert_ne!(n, 0, "input: {} out: {}", input.len(), out.len());

        let out_n = n / self.deci;
        self.fir
            .filter_n_inplace(&input.slice()[..need], self.deci, &mut out.slice()[..out_n]);

        // Sanity check; the `[..out_n]` slice above would already have
        // panicked if this did not hold.
        assert!(out_n <= out.len());

        input.consume(n);
        if self.deci == 1 {
            out.produce(out_n, &tags);
        } else {
            // Map input tag positions to output positions (rounding down).
            tags.iter_mut().for_each(|t| t.set_pos(t.pos() / self.deci));
            out.produce(out_n, &tags);
        }
        Ok(BlockRet::Again)
    }
}
283
284#[must_use]
288pub fn multiband(bands: &[(Float, Float)], taps: usize, window: &Window) -> Option<Vec<Complex>> {
289 if taps != window.0.len() {
290 return None;
291 }
292 use rustfft::FftPlanner;
293
294 let mut ideal = vec![Complex::new(0.0, 0.0); taps];
295 let scale = (taps as Float) / 2.0;
296 for (low, high) in bands {
297 let a = (low * scale).floor() as usize;
298 let b = (high * scale).ceil() as usize;
299 for n in a..b {
300 ideal[n] = Complex::new(1.0, 0.0);
301 ideal[taps - n - 1] = Complex::new(1.0, 0.0);
302 }
303 }
304 let fft_size = taps;
305 let mut planner = FftPlanner::new();
306 let ifft = planner.plan_fft_inverse(fft_size);
307 ifft.process(&mut ideal);
308 ideal.rotate_right(taps / 2);
309 let scale = (fft_size as Float).sqrt();
310 Some(
311 ideal
312 .into_iter()
313 .enumerate()
314 .map(|(n, v)| v * window.0[n] / Complex::new(scale, 0.0))
315 .collect(),
316 )
317}
318
319#[must_use]
321pub fn low_pass_complex(
322 samp_rate: Float,
323 cutoff: Float,
324 twidth: Float,
325 window_type: &WindowType,
326) -> Vec<Complex> {
327 low_pass(samp_rate, cutoff, twidth, window_type)
328 .into_iter()
329 .map(|t| Complex::new(t, 0.0))
330 .collect()
331}
332
333fn compute_ntaps(samp_rate: Float, twidth: Float, window_type: &WindowType) -> usize {
334 let a = window_type.max_attenuation();
335 let t = (a * samp_rate / (22.0 * twidth)) as usize;
336 if (t & 1) == 0 { t + 1 } else { t }
337}
338
339#[must_use]
344pub fn low_pass(
345 samp_rate: Float,
346 cutoff: Float,
347 twidth: Float,
348 window_type: &WindowType,
349) -> Vec<Float> {
350 let pi = std::f64::consts::PI as Float;
351 let ntaps = compute_ntaps(samp_rate, twidth, window_type);
352 let window = window_type.make_window(ntaps);
353 let m = (ntaps - 1) / 2;
354 let fwt0 = 2.0 * pi * cutoff / samp_rate;
355 let taps: Vec<_> = window
356 .0
357 .iter()
358 .enumerate()
359 .map(|(nm, win)| {
360 let n = nm as i64 - m as i64;
361 let nf = n as Float;
362 if n == 0 {
363 fwt0 / pi * win
364 } else {
365 ((nf * fwt0).sin() / (nf * pi)) * win
366 }
367 })
368 .collect();
369 let gain = {
370 let gain: Float = 1.0;
371 let mut fmax = taps[m];
372 for n in 1..=m {
373 fmax += 2.0 * taps[n + m];
374 }
375 gain / fmax
376 };
377 taps.into_iter().map(|t| t * gain).collect()
378}
379
380#[must_use]
382pub fn hilbert(window: &Window) -> Vec<Float> {
383 let ntaps = window.0.len();
384 let mid = (ntaps - 1) / 2;
385 let mut gain = 0.0;
386 let mut taps = vec![0.0; ntaps];
387 for i in 1..=mid {
388 if i & 1 == 1 {
389 let x = 1.0 / (i as Float);
390 taps[mid + i] = x * window.0[mid + i];
391 taps[mid - i] = -x * window.0[mid - i];
392 gain = taps[mid + i] - gain;
393 } else {
394 taps[mid + i] = 0.0;
395 taps[mid - i] = 0.0;
396 }
397 }
398 let gain = 1.0 / (2.0 * gain.abs());
399 taps.iter().map(|e| gain * *e).collect()
400}
401
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use crate::Repeat;
    use crate::blocks::VectorSource;
    use crate::stream::{Tag, TagValue};
    use crate::tests::assert_almost_equal_complex;

    /// A single unit tap must pass every `deci`-th sample through unchanged,
    /// with tag positions scaled down by the decimation factor.
    #[test]
    fn test_identity() -> Result<()> {
        let input = vec![
            Complex::new(1.0, 0.0),
            Complex::new(2.0, 0.0),
            Complex::new(3.0, 0.2),
            Complex::new(4.1, 0.0),
            Complex::new(5.0, 0.0),
            Complex::new(6.0, 0.2),
        ];
        let taps = vec![Complex::new(1.0, 0.0)];
        for deci in 1..=(3 * input.len()) {
            // Two copies of `input` (12 samples), so the repeat tags are exercised.
            let (mut src, src_out) = VectorSource::builder(input.clone())
                .repeat(Repeat::finite(2))
                .build()?;
            assert!(matches![src.work()?, BlockRet::Again]);
            assert!(matches![src.work()?, BlockRet::EOF]);

            eprintln!("Testing identity with decimation {deci}");
            let (mut b, os) = FirFilter::builder(&taps).deci(deci).build(src_out);
            // With one tap, an output needs `deci` samples available.
            if deci <= 2 * input.len() {
                assert!(matches![b.work()?, BlockRet::Again]);
            }
            assert!(matches![b.work()?, BlockRet::WaitForStream(_, _)]);
            let (res, tags) = os.read_buf()?;
            let max = 2 * input.len() / deci;
            if !res.is_empty() {
                assert_eq!(
                    &tags,
                    &[
                        Tag::new(0, "VectorSource::start", TagValue::Bool(true)),
                        Tag::new(0, "VectorSource::repeat", TagValue::U64(0)),
                        Tag::new(0, "VectorSource::first", TagValue::Bool(true)),
                        Tag::new(6 / deci, "VectorSource::start", TagValue::Bool(true)),
                        Tag::new(6 / deci, "VectorSource::repeat", TagValue::U64(1)),
                    ]
                );
            }
            assert_almost_equal_complex(
                res.slice(),
                &input
                    .iter()
                    .chain(input.iter())
                    .copied()
                    .step_by(deci)
                    .take(max)
                    .collect::<Vec<_>>(),
            );
        }
        Ok(())
    }

    /// A single tap of -1 must negate every `deci`-th sample.
    #[test]
    fn test_invert() -> Result<()> {
        let input = vec![
            Complex::new(1.0, 0.0),
            Complex::new(2.0, 0.0),
            Complex::new(3.0, 0.2),
            Complex::new(4.1, 0.0),
            Complex::new(5.0, 0.0),
            Complex::new(6.0, 0.2),
        ];
        let taps = vec![Complex::new(-1.0, 0.0)];
        for deci in 1..=(input.len() + 1) {
            let (mut src, src_out) = VectorSource::new(input.clone());
            src.work()?;

            eprintln!("Testing invert with decimation {deci}");
            let (mut b, os) = FirFilter::builder(&taps).deci(deci).build(src_out);
            if deci <= input.len() {
                assert!(matches![b.work()?, BlockRet::Again]);
            }
            assert!(matches![b.work()?, BlockRet::WaitForStream(_, _)]);
            let (res, _) = os.read_buf()?;
            let max = input.len() / deci;
            assert_almost_equal_complex(
                res.slice(),
                &input
                    .iter()
                    .copied()
                    .step_by(deci)
                    .take(max)
                    .map(|v| -v)
                    .collect::<Vec<_>>(),
            );
        }
        Ok(())
    }

    /// Two taps of 0.5 must average each pair of adjacent samples.
    #[test]
    fn moving_avg() -> Result<()> {
        let input = vec![
            Complex::new(1.0, 0.0),
            Complex::new(2.0, 0.0),
            Complex::new(3.0, 0.2),
            Complex::new(4.1, 0.0),
            Complex::new(5.0, 0.0),
            Complex::new(6.0, 0.2),
        ];
        let taps = vec![Complex::new(0.5, 0.0), Complex::new(0.5, 0.0)];
        for deci in 1..=(input.len() + 1) {
            let (mut src, src_out) = VectorSource::new(input.clone());
            src.work()?;

            eprintln!("Testing moving average with decimation {deci}");
            let (mut b, os) = FirFilter::builder(&taps).deci(deci).build(src_out);
            // Two taps need `deci + 1` samples before the first output.
            if deci < input.len() {
                assert!(matches![b.work()?, BlockRet::Again]);
            }
            assert!(matches![b.work()?, BlockRet::WaitForStream(_, _)]);
            let (res, _) = os.read_buf()?;
            let max = (input.len() - 1) / deci;
            assert_almost_equal_complex(
                res.slice(),
                &[
                    Complex::new(1.5, 0.0),
                    Complex::new(2.5, 0.1),
                    Complex::new(3.55, 0.1),
                    Complex::new(4.55, 0.0),
                    Complex::new(5.5, 0.1),
                ]
                .into_iter()
                .step_by(deci)
                .take(max)
                .collect::<Vec<_>>(),
            );
        }
        Ok(())
    }

    /// Complex taps through `Fir::filter_n`, with and without decimation.
    #[test]
    fn test_complex() {
        let input = vec![
            Complex::new(1.0, 0.0),
            Complex::new(2.0, 0.0),
            Complex::new(3.0, 0.2),
            Complex::new(4.1, 0.0),
            Complex::new(5.0, 0.0),
            Complex::new(6.0, 0.2),
        ];
        let taps = vec![
            Complex::new(0.1, 0.0),
            Complex::new(1.0, 0.0),
            Complex::new(0.0, 0.2),
        ];
        let filter = Fir::new(&taps);
        assert_almost_equal_complex(
            &filter.filter_n(&input, 1),
            &[
                Complex::new(2.3, 0.22),
                Complex::new(3.41, 0.6),
                Complex::new(4.56, 0.6),
                Complex::new(5.6, 0.84),
            ],
        );
        assert_almost_equal_complex(
            &filter.filter_n(&input, 2),
            &[Complex::new(2.3, 0.22), Complex::new(4.56, 0.6)],
        );
    }

    /// Pins the low-pass design output for a Hamming window.
    #[test]
    fn test_filter_generator() {
        let taps = low_pass_complex(10000.0, 1000.0, 1000.0, &WindowType::Hamming);
        assert_eq!(taps.len(), 25);
        assert_almost_equal_complex(
            &taps,
            &[
                Complex::new(0.002010403, 0.0),
                Complex::new(0.0016210203, 0.0),
                Complex::new(7.851862e-10, 0.0),
                Complex::new(-0.0044467063, 0.0),
                Complex::new(-0.011685465, 0.0),
                Complex::new(-0.018134259, 0.0),
                Complex::new(-0.016773716, 0.0),
                Complex::new(-3.6538055e-9, 0.0),
                Complex::new(0.0358771, 0.0),
                Complex::new(0.08697697, 0.0),
                Complex::new(0.14148787, 0.0),
                Complex::new(0.18345332, 0.0),
                Complex::new(0.19922684, 0.0),
                Complex::new(0.1834533, 0.0),
                Complex::new(0.14148785, 0.0),
                Complex::new(0.08697697, 0.0),
                Complex::new(0.035877097, 0.0),
                Complex::new(-3.6538053e-9, 0.0),
                Complex::new(-0.016773716, 0.0),
                Complex::new(-0.018134257, 0.0),
                Complex::new(-0.011685458, 0.0),
                Complex::new(-0.0044467044, 0.0),
                Complex::new(7.851859e-10, 0.0),
                Complex::new(0.0016210207, 0.0),
                Complex::new(0.002010403, 0.0),
            ],
        );
    }
}