// sac_base/operations/math/fir.rs
use alloc::boxed::Box;
use alloc::collections::VecDeque;
use alloc::vec::Vec;
use core::any::Any;
use core::default::Default;
use core::ops;

use crate::network::node::Node;
8
/// State of a finite impulse response (FIR) filter.
///
/// Holds the filter coefficients together with a delay line of the most recent
/// input samples (newest sample at the front). The two are kept at the same
/// length so that every coefficient has a matching sample.
#[derive(Debug, Clone)]
pub struct FIR<T> {
    /// Filter coefficients; `fir[i]` weights the sample received `i` steps ago.
    fir: Vec<T>,
    /// Delay line of past input samples, most recent first.
    fir_samples: VecDeque<T>,
}
19
20impl<T> FIR<T>
21 where T: Copy + Clone + 'static + ops::Div<Output=T> + ops::Mul<Output=T> + ops::Add<Output=T> + ops::Sub<Output=T> + Default
22{
23 pub fn new(fir: Vec<T>) -> Node<T>
46 {
47 let fir_len = fir.len();
48 let mut internal: FIR<T> = FIR {
49 fir: fir,
50 fir_samples: VecDeque::with_capacity(fir_len),
51 };
52
53 for _i in 0..fir_len {
55 internal.fir_samples.push_back(T::default());
56 }
57
58 let storage = Box::new(internal) as Box<dyn Any>;
59 Node::new(storage, | data_input, data_container, output| {
60 let fir: &mut FIR<T> = data_container.downcast_mut::<FIR<T>>().unwrap();
61 let (inputs, max_len) = data_input;
62 output.clear();
63 inputs.into_iter().take(1).into_iter().for_each(|data| {
64 data.into_iter().take(max_len).into_iter().for_each(|v| {
65 let y_0 = FIR::calculate(*v, fir);
66 output.push(y_0);
67 })
68 })
69 })
70 }
71
72 #[inline(always)]
73 pub fn calculate(sample: T, data: &mut FIR<T>) -> T {
74 data.fir_samples.pop_back();
76 data.fir_samples.push_front(sample);
78
79 let mut y_0 = T::default();
80
81 for i in 0..data.fir.len() {
83 let x = *data.fir_samples.get(i).unwrap();
84 let b = *data.fir.get(i).unwrap();
85 y_0 = y_0 + b * x;
86 }
87
88 return y_0;
89 }
90}
91
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec::Vec;

    /// Converts integer literals into the `f64` sample vectors used below.
    fn as_f64(values: &[i32]) -> Vec<f64> {
        values.iter().map(|&v| v as f64).collect()
    }

    /// Two-tap moving sum `[1, 1]`, fed one sample per `process` call so the
    /// delay line has to carry state between calls.
    #[test]
    fn test_fir_simple() {
        let test_vector = as_f64(&[0, 1, 1, 1, 1, 1, 1, 1, 1]);
        let kernel: Vec<f64> = [1.0, 1.0].iter().copied().collect();
        let mut fir = FIR::new(kernel);

        let mut result = Vec::new();
        for val in test_vector.iter() {
            let sample = [*val];
            let mut input: Vec<&[f64]> = Vec::new();
            input.push(&sample[..]);
            fir.process((input, 1));
            result.push(fir.data.get(0).unwrap().clone());
        }

        assert_eq!(result, as_f64(&[0, 1, 2, 2, 2, 2, 2, 2, 2]));
    }

    /// Two-tap first difference `[1, -1]` (`y[n] = x[n] - x[n-1]`) applied to a
    /// unit ramp in a single `process` call. After the first sample, the output
    /// is a constant 1.
    #[test]
    fn test_fir_first_difference() {
        let mut test_vector: Vec<&[f64]> = Vec::new();
        test_vector.push(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0][..]);

        let kernel: Vec<f64> = [1.0, -1.0].iter().copied().collect();
        let mut fir = FIR::new(kernel);

        fir.process((test_vector, 9));
        let result = Vec::from(fir.poll());

        assert_eq!([0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], result[..]);
    }

    /// The delay line must persist across `process` calls. The second call sees
    /// the last sample of the first call as `x[n-1]`.
    #[test]
    fn test_fir_state_across_calls() {
        let kernel: Vec<f64> = [1.0, 1.0].iter().copied().collect();
        let mut fir = FIR::new(kernel);

        let mut first: Vec<&[f64]> = Vec::new();
        first.push(&[1.0, 2.0][..]);
        fir.process((first, 2));
        let result = Vec::from(fir.poll());
        assert_eq!([1.0, 3.0], result[..]);

        let mut second: Vec<&[f64]> = Vec::new();
        second.push(&[3.0][..]);
        fir.process((second, 1));
        let result = Vec::from(fir.poll());
        assert_eq!([5.0], result[..]);
    }
}
137