streamweave_transformers/moving-average/
transformer.rs1use super::moving_average_transformer::{MovingAverageState, MovingAverageTransformer};
4use async_trait::async_trait;
5use futures::StreamExt;
6use std::pin::Pin;
7use std::sync::Arc;
8use streamweave::Input;
9use streamweave::Output;
10use streamweave::{Transformer, TransformerConfig};
11use streamweave_error::{ComponentInfo, ErrorAction, ErrorContext, ErrorStrategy, StreamError};
12use streamweave_stateful::{InMemoryStateStore, StateStore, StateStoreExt, StatefulTransformer};
13use tokio_stream::Stream;
14
15impl Input for MovingAverageTransformer {
16 type Input = f64;
17 type InputStream = Pin<Box<dyn Stream<Item = f64> + Send>>;
18}
19
20impl Output for MovingAverageTransformer {
21 type Output = f64;
22 type OutputStream = Pin<Box<dyn Stream<Item = f64> + Send>>;
23}
24
25#[derive(Debug)]
27pub struct SharedMovingAverageStore(pub Arc<InMemoryStateStore<MovingAverageState>>);
28
29impl Clone for SharedMovingAverageStore {
30 fn clone(&self) -> Self {
31 Self(Arc::clone(&self.0))
32 }
33}
34
35impl StateStore<MovingAverageState> for SharedMovingAverageStore {
36 fn get(&self) -> streamweave_stateful::StateResult<Option<MovingAverageState>> {
37 self.0.get()
38 }
39
40 fn set(&self, state: MovingAverageState) -> streamweave_stateful::StateResult<()> {
41 self.0.set(state)
42 }
43
44 fn update_with(
45 &self,
46 f: Box<dyn FnOnce(Option<MovingAverageState>) -> MovingAverageState + Send>,
47 ) -> streamweave_stateful::StateResult<MovingAverageState> {
48 self.0.update_with(f)
49 }
50
51 fn reset(&self) -> streamweave_stateful::StateResult<()> {
52 self.0.reset()
53 }
54
55 fn is_initialized(&self) -> bool {
56 self.0.is_initialized()
57 }
58
59 fn initial_state(&self) -> Option<MovingAverageState> {
60 self.0.initial_state()
61 }
62}
63
64#[async_trait]
65impl Transformer for MovingAverageTransformer {
66 type InputPorts = (f64,);
67 type OutputPorts = (f64,);
68
69 fn transform(&mut self, input: Self::InputStream) -> Self::OutputStream {
70 let state_store_clone = Arc::clone(&self.state_store);
71 let window_size = self.window_size;
72
73 input
74 .map(move |item| {
75 state_store_clone
76 .update(move |current_opt| {
77 let mut state = current_opt.unwrap_or_else(|| MovingAverageState::new(window_size));
78 state.add_value(item);
79 state
80 })
81 .map(|state| state.average())
82 .unwrap_or(0.0)
83 })
84 .boxed()
85 }
86
87 fn set_config_impl(&mut self, config: TransformerConfig<f64>) {
88 self.config = config;
89 }
90
91 fn get_config_impl(&self) -> &TransformerConfig<f64> {
92 &self.config
93 }
94
95 fn get_config_mut_impl(&mut self) -> &mut TransformerConfig<f64> {
96 &mut self.config
97 }
98
99 fn handle_error(&self, error: &StreamError<f64>) -> ErrorAction {
100 match &self.config.error_strategy {
101 ErrorStrategy::Stop => ErrorAction::Stop,
102 ErrorStrategy::Skip => ErrorAction::Skip,
103 ErrorStrategy::Retry(n) if error.retries < *n => ErrorAction::Retry,
104 ErrorStrategy::Custom(handler) => handler(error),
105 _ => ErrorAction::Stop,
106 }
107 }
108
109 fn create_error_context(&self, item: Option<f64>) -> ErrorContext<f64> {
110 ErrorContext {
111 timestamp: chrono::Utc::now(),
112 item,
113 component_name: self.component_info().name,
114 component_type: std::any::type_name::<Self>().to_string(),
115 }
116 }
117
118 fn component_info(&self) -> ComponentInfo {
119 ComponentInfo {
120 name: self
121 .config
122 .name
123 .clone()
124 .unwrap_or_else(|| "moving_average_transformer".to_string()),
125 type_name: std::any::type_name::<Self>().to_string(),
126 }
127 }
128}
129
130impl StatefulTransformer for MovingAverageTransformer {
131 type State = MovingAverageState;
132 type Store = SharedMovingAverageStore;
133
134 fn state_store(&self) -> &Self::Store {
135 unreachable!("Use state(), set_state(), reset_state() directly")
136 }
137
138 fn state_store_mut(&mut self) -> &mut Self::Store {
139 unreachable!("Use state(), set_state(), reset_state() directly")
140 }
141
142 fn state(&self) -> streamweave_stateful::StateResult<Option<Self::State>> {
143 self.state_store.get()
144 }
145
146 fn set_state(&self, state: Self::State) -> streamweave_stateful::StateResult<()> {
147 self.state_store.set(state)
148 }
149
150 fn reset_state(&self) -> streamweave_stateful::StateResult<()> {
151 self.state_store.reset()
152 }
153
154 fn has_state(&self) -> bool {
155 self.state_store.is_initialized()
156 }
157
158 fn update_state<F>(&self, f: F) -> streamweave_stateful::StateResult<Self::State>
159 where
160 F: FnOnce(Option<Self::State>) -> Self::State + Send + 'static,
161 {
162 self.state_store.update_with(Box::new(f))
163 }
164}
165
#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream;

    /// Feeds `values` through `transformer` and collects every emitted average.
    async fn run(transformer: &mut MovingAverageTransformer, values: Vec<f64>) -> Vec<f64> {
        transformer
            .transform(Box::pin(stream::iter(values)))
            .collect()
            .await
    }

    #[tokio::test]
    async fn test_moving_average_basic() {
        let mut transformer = MovingAverageTransformer::new(3);

        let results = run(&mut transformer, vec![1.0, 2.0, 3.0, 4.0, 5.0]).await;

        assert_eq!(results, vec![1.0, 1.5, 2.0, 3.0, 4.0]);
    }

    #[tokio::test]
    async fn test_moving_average_window_size_1() {
        let mut transformer = MovingAverageTransformer::new(1);

        let results = run(&mut transformer, vec![5.0, 10.0, 15.0, 20.0]).await;

        assert_eq!(results, vec![5.0, 10.0, 15.0, 20.0]);
    }

    #[tokio::test]
    async fn test_moving_average_large_window() {
        let mut transformer = MovingAverageTransformer::new(10);

        let results = run(&mut transformer, vec![1.0, 2.0, 3.0]).await;

        assert_eq!(results, vec![1.0, 1.5, 2.0]);
    }

    #[tokio::test]
    async fn test_moving_average_empty_input() {
        let mut transformer = MovingAverageTransformer::new(3);

        let results = run(&mut transformer, Vec::new()).await;

        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_moving_average_state_persistence() {
        let mut transformer = MovingAverageTransformer::new(3);

        let first = run(&mut transformer, vec![3.0, 6.0, 9.0]).await;
        assert_eq!(first, vec![3.0, 4.5, 6.0]);

        // The window [6, 9, 12] spans both calls.
        let second = run(&mut transformer, vec![12.0]).await;
        assert_eq!(second, vec![9.0]);
    }

    #[tokio::test]
    async fn test_moving_average_state_reset() {
        let mut transformer = MovingAverageTransformer::new(3);

        let _ = run(&mut transformer, vec![10.0, 20.0, 30.0]).await;
        transformer.reset_state().unwrap();

        let after_reset = run(&mut transformer, vec![1.0, 2.0]).await;
        assert_eq!(after_reset, vec![1.0, 1.5]);
    }

    #[tokio::test]
    async fn test_moving_average_get_state() {
        let mut transformer = MovingAverageTransformer::new(3);

        let _ = run(&mut transformer, vec![2.0, 4.0, 6.0, 8.0]).await;

        let final_state = transformer.state().unwrap().unwrap();
        assert_eq!(final_state.window.len(), 3);
        assert_eq!(final_state.average(), 6.0);
    }

    #[tokio::test]
    async fn test_moving_average_component_info() {
        let transformer = MovingAverageTransformer::new(5).with_name("my_moving_avg".to_string());

        let info = transformer.component_info();

        assert_eq!(info.name, "my_moving_avg");
        assert!(info.type_name.contains("MovingAverageTransformer"));
    }

    #[tokio::test]
    async fn test_moving_average_window_size() {
        let transformer = MovingAverageTransformer::new(7);
        assert_eq!(transformer.window_size(), 7);
    }

    #[tokio::test]
    async fn test_moving_average_negative_values() {
        let mut transformer = MovingAverageTransformer::new(2);

        let results = run(&mut transformer, vec![10.0, -4.0, 6.0]).await;

        assert_eq!(results, vec![10.0, 3.0, 1.0]);
    }

    #[tokio::test]
    async fn test_moving_average_has_state() {
        let transformer = MovingAverageTransformer::new(3);
        assert!(transformer.has_state());
    }

    #[tokio::test]
    #[should_panic(expected = "Window size must be greater than 0")]
    async fn test_moving_average_zero_window_panics() {
        let _ = MovingAverageTransformer::new(0);
    }

    #[test]
    fn test_moving_average_state_struct() {
        let mut state = MovingAverageState::new(3);
        assert_eq!(state.average(), 0.0);

        state.add_value(6.0);
        assert_eq!(state.average(), 6.0);

        state.add_value(12.0);
        assert_eq!(state.average(), 9.0);

        state.add_value(9.0);
        assert_eq!(state.average(), 9.0);

        // The window is full, so 6.0 is evicted: [12, 9, 3].
        state.add_value(3.0);
        assert_eq!(state.average(), 8.0);
    }
}