1use std::ops::{Add, AddAssign, Div, Mul, Sub};
2use std::time::{Duration, Instant};
3
4use num_traits::{One, Zero};
5
// In `Ewma`, alpha is the weight given to the newest sample, so a larger alpha
// makes the average follow recent values more closely.

/// Default smoothing factor that `Metrics::default` uses for the `short` window.
pub const DEFAULT_EWMA_ALPHA_SHORT: f32 = 0.9;
/// Default smoothing factor that `Metrics::default` uses for the `mid` window.
pub const DEFAULT_EWMA_ALPHA_MID: f32 = 0.7;
/// Default smoothing factor that `Metrics::default` uses for the `long` window.
pub const DEFAULT_EWMA_ALPHA_LONG: f32 = 0.2;
9
/// A pair of [`Metrics`] instances stored under `tx` and `rx`.
///
/// NOTE(review): the field names suggest the transmit and receive directions of a
/// channel — confirm against the callers.
#[derive(Clone, Default)]
pub struct ChannelMetrics {
    /// Metrics tracked under `tx` (presumably outgoing traffic).
    pub tx: Metrics,
    /// Metrics tracked under `rx` (presumably incoming traffic).
    pub rx: Metrics,
}
15
/// Three [`TimeWindow`]s that receive the same samples but are smoothed with
/// different EWMA factors.
#[derive(Clone)]
pub struct Metrics {
    /// Window smoothed with the "short" alpha.
    pub short: TimeWindow,
    /// Window smoothed with the "mid" alpha.
    pub mid: TimeWindow,
    /// Window smoothed with the "long" alpha.
    pub long: TimeWindow,
}
22
23impl Default for Metrics {
24 fn default() -> Self {
25 Self::new(
26 DEFAULT_EWMA_ALPHA_SHORT,
27 DEFAULT_EWMA_ALPHA_MID,
28 DEFAULT_EWMA_ALPHA_LONG,
29 )
30 .unwrap()
31 }
32}
33
34impl Metrics {
35 pub fn new(short_alpha: f32, mid_alpha: f32, long_alpha: f32) -> Result<Self, f32> {
36 let now = Instant::now();
37 let sample_freq = Duration::from_secs(1);
38
39 Ok(Self {
40 short: TimeWindow::new(now, sample_freq, Ewma::new(short_alpha)?),
41 mid: TimeWindow::new(now, sample_freq, Ewma::new(mid_alpha)?),
42 long: TimeWindow::new(now, sample_freq, Ewma::new(long_alpha)?),
43 })
44 }
45
46 pub fn push(&mut self, value: f32) {
47 let now = Instant::now();
48 self.short.push(value, now);
49 self.mid.push(value, now);
50 self.long.push(value, now);
51 }
52}
53
/// A running average over samples of type `T`.
pub trait Average<T>: Clone {
    /// Feeds one more sample into the average.
    fn push(&mut self, value: T);
    /// Returns the current value of the average.
    fn value(&self) -> T;
}
59
/// Exponentially weighted moving average.
///
/// Each pushed sample updates the state to
/// `alpha * sample + (1 - alpha) * previous`; the very first sample is stored
/// unchanged.
#[derive(Clone)]
pub struct Ewma<T = f32>
where
    T: Clone,
{
    /// Current average; `None` until the first sample has been pushed.
    value: Option<T>,
    /// Weight given to the newest sample.
    alpha: T,
    /// Cached `1 - alpha`, the weight given to the previous average.
    one_min_alpha: T,
}
70
71impl<T> Ewma<T>
72where
73 T: Zero + One + Sub<Output = T> + PartialOrd + Clone,
74{
75 pub fn new(alpha: T) -> Result<Self, T> {
76 let zero = T::zero();
77 let one = T::one();
78
79 if alpha < zero || alpha > one {
80 return Err(alpha);
81 }
82
83 let one_min_alpha = one.sub(alpha.clone());
84
85 Ok(Self {
86 value: None,
87 alpha,
88 one_min_alpha,
89 })
90 }
91}
92
93impl<T> Average<T> for Ewma<T>
94where
95 T: Zero + Add<Output = T> + Mul<Output = T> + Clone,
96{
97 fn push(&mut self, value: T) {
98 let new_value = match self.value.take() {
99 Some(v) => self.alpha.clone().mul(value) + self.one_min_alpha.clone().mul(v),
100 None => value,
101 };
102 self.value.replace(new_value);
103 }
104
105 #[inline]
106 fn value(&self) -> T {
107 self.value.clone().unwrap_or_else(T::zero)
108 }
109}
110
/// Groups samples into fixed-size time windows and forwards one aggregated
/// sample per window to an [`Average`].
#[derive(Clone)]
pub struct TimeWindow<V = f32, A = Ewma<V>>
where
    A: Average<V>,
{
    /// Length of one window.
    size: Duration,
    /// Timestamp recorded for the sample most recently forwarded to `average`
    /// (initialised to one window before the start time).
    updated: Instant,
    /// Sum of the values pushed since `updated` that were not forwarded yet.
    acc: V,
    /// Sum of every sample forwarded to `average` so far.
    total: V,
    /// The average that receives the per-window samples.
    average: A,
}
122
123impl<V, A> TimeWindow<V, A>
124where
125 V: Zero,
126 A: Average<V>,
127{
128 pub fn new(start: Instant, size: Duration, average: A) -> Self {
129 Self {
130 size,
131 updated: start.checked_sub(size).unwrap(),
132 acc: V::zero(),
133 total: V::zero(),
134 average,
135 }
136 }
137}
138
139impl<V, A> TimeWindow<V, A>
140where
141 V: Add<Output = V> + PartialEq + Zero + Clone,
142 A: Average<V>,
143{
144 #[inline]
145 pub fn average(&mut self, time: Instant) -> V {
146 self.advance(time);
147 self.average.value()
148 }
149
150 #[inline]
151 pub fn sum(&self) -> V {
152 self.total.clone()
153 }
154
155 pub fn push(&mut self, mut value: V, time: Instant) {
156 if time - self.updated < self.size {
157 self.acc = self.acc.clone() + value;
158 } else {
159 self.advance(time);
160 value = value + std::mem::replace(&mut self.acc, V::zero());
161 self.push_value(value, time);
162 }
163 }
164
165 fn advance(&mut self, time: Instant) {
166 if time <= self.updated {
167 return;
168 }
169
170 let size_ms = self.size.as_millis() as f64;
171 let cycles = ((time - self.updated).as_millis() as f64 + size_ms / 2.) / size_ms;
172 let samples = (cycles as usize).saturating_sub(1);
173 if samples == 0 {
174 return;
175 }
176
177 if !self.acc.eq(&V::zero()) {
178 let value = std::mem::replace(&mut self.acc, V::zero());
179 self.push_value(value, self.updated);
180 }
181 for _ in 0..samples {
182 self.push_value(V::zero(), self.updated + self.size);
183 }
184 }
185
186 fn push_value(&mut self, value: V, time: Instant) {
187 self.total = self.total.clone().add(value.clone());
188 self.average.push(value);
189 self.updated = time;
190 }
191}
192
193impl<V, A> Add for TimeWindow<V, A>
194where
195 V: Zero + Add<Output = V> + AddAssign + PartialEq + Clone,
196 A: Average<V> + Add<Output = A>,
197{
198 type Output = TimeWindow<V, A>;
199
200 fn add(mut self, mut rhs: TimeWindow<V, A>) -> TimeWindow<V, A> {
201 let timestamp = *std::cmp::max(&self.updated, &rhs.updated);
205
206 self.advance(timestamp);
207 rhs.advance(timestamp);
208
209 self.updated = timestamp;
210 self.acc += rhs.acc;
211 self.total += rhs.total;
212 self.average = self.average + rhs.average;
215
216 self
217 }
218}
219
220impl<T> Add for Ewma<T>
221where
222 T: Add<Output = T> + AddAssign + Div<f32, Output = T> + Clone,
223{
224 type Output = Ewma<T>;
225
226 fn add(mut self, rhs: Self) -> Self::Output {
227 self.value = match (self.value, rhs.value) {
232 (Some(val1), Some(val2)) => Some((val1 + val2) / 2.0f32),
233 (Some(val1), None) => Some(val1),
234 (None, Some(val2)) => Some(val2),
235 (None, None) => None,
236 };
237 self
238 }
239}
240
241impl Add for Metrics {
242 type Output = Metrics;
243
244 fn add(mut self, rhs: Self) -> Self::Output {
245 self.long = self.long + rhs.long;
246 self.mid = self.mid + rhs.mid;
247 self.short = self.short + rhs.short;
248
249 self
250 }
251}
252
253impl Add for ChannelMetrics {
254 type Output = ChannelMetrics;
255
256 fn add(mut self, rhs: Self) -> Self::Output {
257 self.rx = self.rx + rhs.rx;
258 self.tx = self.tx + rhs.tx;
259
260 self
261 }
262}
263
/// Tests that drive a `TimeWindow` backed by an `Ewma` with synthetic
/// timestamps (no sleeping) and check the resulting average.
#[cfg(test)]
mod tests {
    use std::time::{Duration, Instant};

    use crate::metrics::{Ewma, TimeWindow};

    /// Asserts that `val` lies strictly within 0.01 of `expected`.
    fn assert_approx_eq(val: f64, expected: f64) {
        assert!(val > expected - 0.01);
        assert!(val < expected + 0.01);
    }

    /// Samples arrive ten times per window: the 0.1 values pushed within each
    /// one-second window add up to about 1.0 per window.
    #[test]
    fn time_window_ewma_swift() {
        let mut now = Instant::now();
        let sample_freq = Duration::from_secs(1);
        let until = now + Duration::from_secs(4);

        let avg = Ewma::new(0.8_f64).expect("failed to create an instance of EWMA");
        let mut tw = TimeWindow::new(now, sample_freq, avg);

        while now <= until {
            tw.push(0.1, now);
            // Step by a tenth of the window size.
            now += Duration::from_millis(sample_freq.as_millis() as u64 / 10);
        }
        assert_approx_eq(tw.average(now), 1.);
    }

    /// Exactly one sample per window: the average equals the constant input.
    #[test]
    fn time_window_ewma_steady() {
        let mut now = Instant::now();
        let sample_freq = Duration::from_secs(1);
        let until = now + Duration::from_secs(4);

        let avg = Ewma::new(0.8_f64).expect("failed to create an instance of EWMA");
        let mut tw = TimeWindow::new(now, sample_freq, avg);

        while now <= until {
            tw.push(123., now);
            now += sample_freq;
        }
        assert_approx_eq(tw.average(now), 123.);
    }

    /// One sample every second window: the skipped windows count as zero
    /// samples, so the slowly reacting average settles near half the input.
    #[test]
    fn time_window_ewma_tardy() {
        let mut now = Instant::now();
        let sample_freq = Duration::from_secs(1);
        let until = now + Duration::from_secs(8);

        let avg = Ewma::new(0.2_f64).expect("failed to create an instance of EWMA");
        let mut tw = TimeWindow::new(now, sample_freq, avg);

        while now <= until {
            tw.push(1., now);
            now += sample_freq * 2;
        }

        assert_approx_eq(tw.average(now), 0.5);
    }
}