time_series_filter/
lib.rs

1#![cfg_attr(not(test), no_std)]
2
3use core::ops::Range;
4use num_traits::{Float, PrimInt};
5
/// Common interface for filters that smooth a time series with an
/// exponentially weighted moving average (EWMA).
pub trait EwmaFilter<T> {
  /// Feeds `new_value`, the next sample of the series, into the filter and
  /// returns the moving average once that sample has been taken into account.
  fn push_sample(&mut self, new_value: T) -> T;

  /// Returns the current moving average without pushing a new sample.
  fn ewma_average(&self) -> T;

  /// Returns the recent ("local") extrema packed as `min..max`.
  ///
  /// The implementations in this crate store the local maximum itself in
  /// `end`, so despite [`Range`] normally being half-open, `end` is a value the
  /// filter has actually tracked rather than one-past-the-maximum.
  fn local_range(&self) -> Range<T>;
}
17
/// Exponentially weighted moving average (EWMA) filter for floating-point
/// time series samples.
///
/// Alongside the average it tracks a "local" minimum and maximum: a sample
/// beyond an extreme replaces it immediately, while an existing extreme drifts
/// toward samples that land on its side of the average, so old outliers fade.
#[derive(Debug, Clone)]
pub struct FloatSeriesEwmaFilter<T> {
  /// Number of samples that have been pushed through the filter.
  sample_count: usize,
  /// Recent minimum value (not the global minimum).
  local_min: T,
  /// Recent maximum value (not the global maximum).
  local_max: T,
  /// Exponentially weighted moving average.
  average: T,
  /// Weighting factor; a bigger alpha fades old values faster.
  alpha: T,
}
32
33impl<T> FloatSeriesEwmaFilter<T>
34  where
35    T: Float + core::ops::AddAssign,
36{
37  pub fn new(alpha: T) -> Self {
38    Self {
39      sample_count: 0,
40      alpha,
41      local_min: T::zero(),
42      local_max: T::zero(),
43      average: T::zero(),
44    }
45  }
46
47  pub fn default() -> Self {
48    Self::new(T::from(0.01).unwrap())
49  }
50}
51
52impl<T> EwmaFilter<T> for FloatSeriesEwmaFilter<T>
53  where
54    T: Float + core::ops::AddAssign,
55{
56  /// Returns exponentially weighted moving average
57  fn push_sample(&mut self, new_value: T) -> T {
58    if self.sample_count == 0 {
59      //seed the EMWA with the initial value
60      self.local_min = new_value;
61      self.local_max = new_value;
62      self.average = new_value;
63    } else {
64      self.average += self.alpha * (new_value - self.average);
65
66      // extrema fade toward average
67      if new_value > self.local_max {
68        self.local_max = new_value;
69      } else if new_value > self.average {
70        self.local_max += self.alpha * (new_value - self.local_max);
71      }
72
73      if new_value < self.local_min {
74        self.local_min = new_value;
75      } else if new_value < self.average {
76        self.local_min += self.alpha * (new_value - self.local_min);
77      }
78    }
79    self.sample_count += 1;
80
81    self.average
82  }
83
84  fn ewma_average(&self) -> T {
85    self.average
86  }
87
88  fn local_range(&self) -> Range<T> {
89    self.local_min..self.local_max
90  }
91}
92
/// Exponentially weighted moving average (EWMA) filter for integer time
/// series samples.
///
/// The weighting factor is the fraction `alpha_numerator / alpha_denominator`,
/// applied with integer arithmetic. Like the float variant, it also tracks a
/// recent minimum and maximum that fade toward newer samples.
#[derive(Debug, Clone)]
pub struct IntSeriesEwmaFilter<T> {
  /// Number of samples that have been pushed through the filter.
  sample_count: usize,

  /// Recent minimum value (not the global minimum).
  local_min: T,
  /// Recent maximum value (not the global maximum).
  local_max: T,
  /// Exponentially weighted moving average.
  average: T,
  /// Numerator of the weighting factor; a bigger alpha fades old values faster.
  alpha_numerator: T,
  /// Denominator of the weighting factor.
  alpha_denominator: T,
}
107
108impl<T> IntSeriesEwmaFilter<T>
109  where
110    T: PrimInt + core::ops::AddAssign,
111{
112  pub fn new(alpha_numerator: T, alpha_denominator: T) -> Self {
113    Self {
114      sample_count: 0,
115      alpha_numerator,
116      alpha_denominator,
117      local_min: T::zero(),
118      local_max: T::zero(),
119      average: T::zero(),
120    }
121  }
122
123  pub fn default() -> Self {
124    Self::new(T::one(), T::from(100).unwrap())
125  }
126}
127
128impl<T> EwmaFilter<T> for IntSeriesEwmaFilter<T>
129  where
130    T: PrimInt + core::ops::AddAssign + core::ops::SubAssign + core::fmt::Debug
131{
132  /// Returns exponentially weighted moving average
133  fn push_sample(&mut self, new_value: T) -> T {
134    if self.sample_count == 0 {
135      //seed the EMWA with the initial value
136      self.local_min = new_value;
137      self.local_max = new_value;
138      self.average = new_value;
139    }
140    else {
141      if new_value >= self.average {
142        let avg_diff = new_value - self.average;
143        let incr = (self.alpha_numerator * avg_diff) / self.alpha_denominator;
144         self.average += incr;
145      }
146      else {
147        let avg_diff = self.average - new_value;
148        let incr = (self.alpha_numerator * avg_diff) / self.alpha_denominator;
149        self.average -= incr;
150      };
151
152      // extrema fade toward average
153      if new_value > self.local_max {
154        self.local_max = new_value;
155      } else if new_value > self.average {
156        self.local_max -=
157          (self.alpha_numerator * (self.local_max - new_value)) / self.alpha_denominator;
158      }
159
160      if new_value < self.local_min {
161        self.local_min = new_value;
162      } else if new_value < self.average {
163        self.local_min +=
164          (self.alpha_numerator * (new_value - self.local_min)) / self.alpha_denominator;
165      }
166    }
167    self.sample_count += 1;
168
169    self.average
170  }
171
172  fn ewma_average(&self) -> T {
173    self.average
174  }
175
176  fn local_range(&self) -> Range<T> {
177    self.local_min..self.local_max
178  }
179}
180
#[cfg(test)]
mod tests {
  use crate::{EwmaFilter, FloatSeriesEwmaFilter, IntSeriesEwmaFilter};
  use assert_approx_eq::assert_approx_eq;

  /// Pushes every value yielded by `samples` into `filter`, ignoring the
  /// intermediate averages.
  fn feed<T, F, I>(filter: &mut F, samples: I)
  where
    F: EwmaFilter<T>,
    I: IntoIterator<Item = T>,
  {
    samples.into_iter().for_each(|sample| {
      filter.push_sample(sample);
    });
  }

  #[test]
  fn test_float_basic() {
    // On the ramp 0, 1, ..., 999 an EWMA with alpha = 0.01 trails the newest
    // sample by roughly (1 - alpha) / alpha, i.e. about 99.
    let mut via_default: FloatSeriesEwmaFilter<f32> = FloatSeriesEwmaFilter::default();
    feed(&mut via_default, (0..1000).map(|i| i as f32));
    assert_approx_eq!(via_default.ewma_average(), 900.0, 1f32);

    let mut via_new: FloatSeriesEwmaFilter<f32> = FloatSeriesEwmaFilter::new(0.01);
    feed(&mut via_new, (0..1000).map(|i| i as f32));
    assert_approx_eq!(via_new.ewma_average(), 900.0, 1f32);

    // A steadily rising series never lets the extrema fade.
    let range = via_new.local_range();
    assert_eq!(range.end, 999.0);
    assert_eq!(range.start, 0.0);
  }

  #[test]
  fn test_integer_basic() {
    let mut via_default: IntSeriesEwmaFilter<u32> = IntSeriesEwmaFilter::default();
    feed(&mut via_default, 0..1000u32);
    assert_eq!(via_default.ewma_average(), 900);

    let mut via_new: IntSeriesEwmaFilter<u32> = IntSeriesEwmaFilter::new(1, 100);
    feed(&mut via_new, 0..1000u32);
    assert_eq!(via_new.ewma_average(), 900);

    let range = via_new.local_range();
    assert_eq!(range.end, 999);
    assert_eq!(range.start, 0);
  }

  #[test]
  fn test_float_up_down() {
    let mut filter: FloatSeriesEwmaFilter<f32> = FloatSeriesEwmaFilter::new(1.0 / 50.0);
    let pattern = [500.0f32, 750.0, 1000.0];
    feed(&mut filter, pattern.iter().copied().cycle().take(pattern.len() * 100));

    assert_approx_eq!(filter.ewma_average(), 750.0, 5.0);

    let range = filter.local_range();
    assert_eq!(range.start, 500.0);
    assert_eq!(range.end, 1000.0);
  }

  #[test]
  fn test_integer_up_down() {
    let mut filter: IntSeriesEwmaFilter<u32> = IntSeriesEwmaFilter::new(1, 50);
    let pattern = [500u32, 750, 1000];
    feed(&mut filter, pattern.iter().copied().cycle().take(pattern.len() * 100));

    assert_eq!(filter.ewma_average(), 750);

    let range = filter.local_range();
    assert_eq!(range.end, 1000);
    assert_eq!(range.start, 500);
  }
}