time_series_filter/
lib.rs1#![cfg_attr(not(test), no_std)]
2
3use core::ops::Range;
4use num_traits::{Float, PrimInt};
5
/// Streaming filter that keeps an exponentially weighted moving average
/// (EWMA) of the samples fed to it, together with a local min/max range.
pub trait EwmaFilter<T> {
    /// Feeds `new_value` into the filter and returns the updated average.
    fn push_sample(&mut self, new_value: T) -> T;

    /// Returns the current moving average without changing the filter.
    fn ewma_average(&self) -> T;

    /// Returns the tracked local extremes as `min..max`.
    ///
    /// In this crate's implementations `end` holds the tracked maximum value
    /// itself (not one past it), so callers should read the range as
    /// inclusive even though [`Range`] is nominally half-open.
    fn local_range(&self) -> Range<T>;
}

/// EWMA filter for floating-point samples, smoothed by a fractional `alpha`.
///
/// Besides the moving average it tracks a local minimum and maximum that jump
/// outward to new extremes and relax inward toward in-range samples.
#[derive(Debug, Clone)]
pub struct FloatSeriesEwmaFilter<T> {
    /// Number of samples folded in so far; `0` means the next sample seeds
    /// the state.
    sample_count: usize,
    /// Tracked local minimum.
    local_min: T,
    /// Tracked local maximum.
    local_max: T,
    /// Current moving average.
    average: T,
    /// Smoothing factor applied on every update after the first.
    alpha: T,
}

33impl<T> FloatSeriesEwmaFilter<T>
34 where
35 T: Float + core::ops::AddAssign,
36{
37 pub fn new(alpha: T) -> Self {
38 Self {
39 sample_count: 0,
40 alpha,
41 local_min: T::zero(),
42 local_max: T::zero(),
43 average: T::zero(),
44 }
45 }
46
47 pub fn default() -> Self {
48 Self::new(T::from(0.01).unwrap())
49 }
50}
51
52impl<T> EwmaFilter<T> for FloatSeriesEwmaFilter<T>
53 where
54 T: Float + core::ops::AddAssign,
55{
56 fn push_sample(&mut self, new_value: T) -> T {
58 if self.sample_count == 0 {
59 self.local_min = new_value;
61 self.local_max = new_value;
62 self.average = new_value;
63 } else {
64 self.average += self.alpha * (new_value - self.average);
65
66 if new_value > self.local_max {
68 self.local_max = new_value;
69 } else if new_value > self.average {
70 self.local_max += self.alpha * (new_value - self.local_max);
71 }
72
73 if new_value < self.local_min {
74 self.local_min = new_value;
75 } else if new_value < self.average {
76 self.local_min += self.alpha * (new_value - self.local_min);
77 }
78 }
79 self.sample_count += 1;
80
81 self.average
82 }
83
84 fn ewma_average(&self) -> T {
85 self.average
86 }
87
88 fn local_range(&self) -> Range<T> {
89 self.local_min..self.local_max
90 }
91}
92
/// EWMA filter for integer samples.
///
/// The smoothing factor is the rational `alpha_numerator / alpha_denominator`,
/// so updates use only integer arithmetic (with truncating division).
#[derive(Debug, Clone)]
pub struct IntSeriesEwmaFilter<T> {
    /// Number of samples folded in so far; `0` means the next sample seeds
    /// the state.
    sample_count: usize,
    /// Tracked local minimum.
    local_min: T,
    /// Tracked local maximum.
    local_max: T,
    /// Current moving average.
    average: T,
    /// Numerator of the smoothing factor.
    alpha_numerator: T,
    /// Denominator of the smoothing factor.
    alpha_denominator: T,
}

108impl<T> IntSeriesEwmaFilter<T>
109 where
110 T: PrimInt + core::ops::AddAssign,
111{
112 pub fn new(alpha_numerator: T, alpha_denominator: T) -> Self {
113 Self {
114 sample_count: 0,
115 alpha_numerator,
116 alpha_denominator,
117 local_min: T::zero(),
118 local_max: T::zero(),
119 average: T::zero(),
120 }
121 }
122
123 pub fn default() -> Self {
124 Self::new(T::one(), T::from(100).unwrap())
125 }
126}
127
128impl<T> EwmaFilter<T> for IntSeriesEwmaFilter<T>
129 where
130 T: PrimInt + core::ops::AddAssign + core::ops::SubAssign + core::fmt::Debug
131{
132 fn push_sample(&mut self, new_value: T) -> T {
134 if self.sample_count == 0 {
135 self.local_min = new_value;
137 self.local_max = new_value;
138 self.average = new_value;
139 }
140 else {
141 if new_value >= self.average {
142 let avg_diff = new_value - self.average;
143 let incr = (self.alpha_numerator * avg_diff) / self.alpha_denominator;
144 self.average += incr;
145 }
146 else {
147 let avg_diff = self.average - new_value;
148 let incr = (self.alpha_numerator * avg_diff) / self.alpha_denominator;
149 self.average -= incr;
150 };
151
152 if new_value > self.local_max {
154 self.local_max = new_value;
155 } else if new_value > self.average {
156 self.local_max -=
157 (self.alpha_numerator * (self.local_max - new_value)) / self.alpha_denominator;
158 }
159
160 if new_value < self.local_min {
161 self.local_min = new_value;
162 } else if new_value < self.average {
163 self.local_min +=
164 (self.alpha_numerator * (new_value - self.local_min)) / self.alpha_denominator;
165 }
166 }
167 self.sample_count += 1;
168
169 self.average
170 }
171
172 fn ewma_average(&self) -> T {
173 self.average
174 }
175
176 fn local_range(&self) -> Range<T> {
177 self.local_min..self.local_max
178 }
179}
180
#[cfg(test)]
mod tests {
    use crate::{EwmaFilter, FloatSeriesEwmaFilter, IntSeriesEwmaFilter};
    use assert_approx_eq::assert_approx_eq;

    /// Pushes every value yielded by `samples` into `filter`, in order.
    fn feed<T, F: EwmaFilter<T>>(filter: &mut F, samples: impl IntoIterator<Item = T>) {
        for sample in samples {
            filter.push_sample(sample);
        }
    }

    #[test]
    fn test_float_basic() {
        // A unit ramp over 0..1000 with alpha 0.01 should leave the average
        // trailing the last sample by roughly 99.
        let ramp = || (0..1000).map(|i| i as f32);

        let mut defaulted: FloatSeriesEwmaFilter<f32> = FloatSeriesEwmaFilter::default();
        feed(&mut defaulted, ramp());
        assert_approx_eq!(defaulted.ewma_average(), 900.0, 1f32);

        let mut explicit: FloatSeriesEwmaFilter<f32> = FloatSeriesEwmaFilter::new(0.01);
        feed(&mut explicit, ramp());
        assert_approx_eq!(explicit.ewma_average(), 900.0, 1f32);

        let bounds = explicit.local_range();
        assert_eq!(bounds.end, 999.0);
        assert_eq!(bounds.start, 0.0);
    }

    #[test]
    fn test_integer_basic() {
        let mut defaulted: IntSeriesEwmaFilter<u32> = IntSeriesEwmaFilter::default();
        feed(&mut defaulted, 0..1000_u32);
        assert_eq!(defaulted.ewma_average(), 900);

        let mut explicit: IntSeriesEwmaFilter<u32> = IntSeriesEwmaFilter::new(1, 100);
        feed(&mut explicit, 0..1000_u32);
        assert_eq!(explicit.ewma_average(), 900);

        let bounds = explicit.local_range();
        assert_eq!(bounds.end, 999);
        assert_eq!(bounds.start, 0);
    }

    #[test]
    fn test_float_up_down() {
        let mut filter: FloatSeriesEwmaFilter<f32> = FloatSeriesEwmaFilter::new(1.0 / 50.0);

        // 100 repetitions of the 500 -> 750 -> 1000 pattern.
        let pattern = [500.0_f32, 750.0, 1000.0];
        feed(&mut filter, pattern.iter().copied().cycle().take(3 * 100));

        assert_approx_eq!(filter.ewma_average(), 750.0, 5.0);

        let bounds = filter.local_range();
        assert_eq!(bounds.start, 500.0);
        assert_eq!(bounds.end, 1000.0);
    }

    #[test]
    fn test_integer_up_down() {
        let mut filter: IntSeriesEwmaFilter<u32> = IntSeriesEwmaFilter::new(1, 50);

        // 100 repetitions of the 500 -> 750 -> 1000 pattern.
        let pattern = [500_u32, 750, 1000];
        feed(&mut filter, pattern.iter().copied().cycle().take(3 * 100));

        assert_eq!(filter.ewma_average(), 750);

        let bounds = filter.local_range();
        assert_eq!(bounds.end, 1000);
        assert_eq!(bounds.start, 500);
    }
}