streamweave_transformers/moving-average/
moving_average_transformer.rs1use std::collections::VecDeque;
4use std::sync::Arc;
5use streamweave::TransformerConfig;
6use streamweave_error::ErrorStrategy;
7use streamweave_stateful::InMemoryStateStore;
8
/// Rolling state for a simple moving average: the most recent values seen,
/// ordered oldest (front) to newest (back).
#[derive(Debug, Clone)]
pub struct MovingAverageState {
    /// Values currently inside the window, oldest at the front.
    pub window: VecDeque<f64>,
    /// Maximum number of values the window retains after [`MovingAverageState::add_value`].
    /// A value of `0` is treated as `1` (only the latest value is kept).
    pub window_size: usize,
}

impl MovingAverageState {
    /// Creates an empty state whose window holds at most `window_size` values.
    ///
    /// The backing buffer is preallocated for `window_size` elements.
    pub fn new(window_size: usize) -> Self {
        Self {
            window: VecDeque::with_capacity(window_size),
            window_size,
        }
    }

    /// Appends `value` as the newest element, evicting the oldest values so the
    /// window holds at most `window_size` elements afterwards.
    ///
    /// `window_size` is a public field and may be lowered after values have been
    /// added, so eviction removes as many old values as needed rather than at
    /// most one. A `window_size` of `0` behaves like `1`.
    pub fn add_value(&mut self, value: f64) {
        // Always leave room for the incoming value; `max(1)` also guarantees the
        // loop terminates when `window_size` is 0.
        let capacity = self.window_size.max(1);
        while self.window.len() >= capacity {
            self.window.pop_front();
        }
        self.window.push_back(value);
    }

    /// Returns the arithmetic mean of the values in the window, or `0.0` when the
    /// window is empty.
    ///
    /// The sum is recomputed from the window on every call (linear in the window
    /// length) instead of being maintained incrementally.
    pub fn average(&self) -> f64 {
        if self.window.is_empty() {
            return 0.0;
        }
        let sum: f64 = self.window.iter().sum();
        sum / self.window.len() as f64
    }
}
46
47#[derive(Debug)]
70pub struct MovingAverageTransformer {
71 pub(crate) config: TransformerConfig<f64>,
73 pub(crate) state_store: Arc<InMemoryStateStore<MovingAverageState>>,
75 pub(crate) window_size: usize,
77}
78
79impl Clone for MovingAverageTransformer {
80 fn clone(&self) -> Self {
81 Self {
82 config: self.config.clone(),
83 state_store: Arc::clone(&self.state_store),
84 window_size: self.window_size,
85 }
86 }
87}
88
89impl MovingAverageTransformer {
90 pub fn new(window_size: usize) -> Self {
100 assert!(window_size > 0, "Window size must be greater than 0");
101 Self {
102 config: TransformerConfig::default(),
103 state_store: Arc::new(InMemoryStateStore::new(MovingAverageState::new(
104 window_size,
105 ))),
106 window_size,
107 }
108 }
109
110 pub fn with_name(mut self, name: String) -> Self {
112 self.config.name = Some(name);
113 self
114 }
115
116 pub fn with_error_strategy(mut self, strategy: ErrorStrategy<f64>) -> Self {
118 self.config.error_strategy = strategy;
119 self
120 }
121
122 pub fn window_size(&self) -> usize {
124 self.window_size
125 }
126}
127
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_moving_average_state_new() {
        let state = MovingAverageState::new(5);
        assert_eq!(state.window_size, 5);
        assert!(state.window.is_empty());
    }

    #[test]
    fn test_moving_average_state_add_value() {
        let mut state = MovingAverageState::new(3);
        state.add_value(10.0);
        assert_eq!(state.window.len(), 1);
        assert_eq!(state.window[0], 10.0);
    }

    #[test]
    fn test_moving_average_state_add_value_overflow() {
        let mut state = MovingAverageState::new(2);
        state.add_value(1.0);
        state.add_value(2.0);
        state.add_value(3.0);
        // The oldest value (1.0) is evicted once the window is full.
        assert_eq!(state.window.len(), 2);
        assert_eq!(state.window[0], 2.0);
        assert_eq!(state.window[1], 3.0);
    }

    #[test]
    fn test_moving_average_state_average_empty() {
        let state = MovingAverageState::new(3);
        assert_eq!(state.average(), 0.0);
    }

    #[test]
    fn test_moving_average_state_average_single() {
        let mut state = MovingAverageState::new(3);
        state.add_value(10.0);
        assert_eq!(state.average(), 10.0);
    }

    #[test]
    fn test_moving_average_state_average_multiple() {
        let mut state = MovingAverageState::new(3);
        state.add_value(10.0);
        state.add_value(20.0);
        state.add_value(30.0);
        assert_eq!(state.average(), 20.0);
    }

    #[test]
    fn test_moving_average_state_average_after_eviction() {
        let mut state = MovingAverageState::new(2);
        state.add_value(1.0);
        state.add_value(2.0);
        state.add_value(3.0);
        // Only [2.0, 3.0] remain; the evicted 1.0 must not contribute.
        assert_eq!(state.average(), 2.5);
    }

    #[test]
    fn test_moving_average_state_clone() {
        let mut state1 = MovingAverageState::new(3);
        state1.add_value(5.0);
        state1.add_value(10.0);

        let state2 = state1.clone();
        assert_eq!(state1.window, state2.window);
        assert_eq!(state1.window_size, state2.window_size);
        assert_eq!(state1.average(), state2.average());
    }

    #[test]
    fn test_moving_average_transformer_new() {
        let transformer = MovingAverageTransformer::new(5);
        assert_eq!(transformer.window_size(), 5);
    }

    #[test]
    #[should_panic(expected = "Window size must be greater than 0")]
    fn test_moving_average_transformer_new_zero_panics() {
        let _ = MovingAverageTransformer::new(0);
    }

    #[test]
    fn test_moving_average_transformer_with_name() {
        let transformer =
            MovingAverageTransformer::new(3).with_name("test_moving_avg".to_string());
        assert_eq!(transformer.config.name, Some("test_moving_avg".to_string()));
    }

    #[test]
    fn test_moving_average_transformer_with_error_strategy() {
        let transformer =
            MovingAverageTransformer::new(3).with_error_strategy(ErrorStrategy::<f64>::Skip);
        assert!(matches!(
            transformer.config.error_strategy,
            ErrorStrategy::Skip
        ));
    }

    #[test]
    fn test_moving_average_transformer_clone() {
        let transformer1 = MovingAverageTransformer::new(5);
        let transformer2 = transformer1.clone();

        assert_eq!(transformer1.window_size(), transformer2.window_size());
        // Clones must share one state store rather than deep-copying it.
        assert!(Arc::ptr_eq(
            &transformer1.state_store,
            &transformer2.state_store
        ));
    }

    #[test]
    fn test_moving_average_transformer_chaining() {
        let transformer = MovingAverageTransformer::new(7)
            .with_error_strategy(ErrorStrategy::<f64>::Retry(3))
            .with_name("chained_moving_avg".to_string());

        assert_eq!(transformer.window_size(), 7);
        assert!(matches!(
            transformer.config.error_strategy,
            ErrorStrategy::Retry(3)
        ));
        assert_eq!(
            transformer.config.name,
            Some("chained_moving_avg".to_string())
        );
    }
}