streamweave_transformers/moving-average/transformer.rs

1//! Transformer implementations for MovingAverageTransformer.
2
3use super::moving_average_transformer::{MovingAverageState, MovingAverageTransformer};
4use async_trait::async_trait;
5use futures::StreamExt;
6use std::pin::Pin;
7use std::sync::Arc;
8use streamweave::Input;
9use streamweave::Output;
10use streamweave::{Transformer, TransformerConfig};
11use streamweave_error::{ComponentInfo, ErrorAction, ErrorContext, ErrorStrategy, StreamError};
12use streamweave_stateful::{InMemoryStateStore, StateStore, StateStoreExt, StatefulTransformer};
13use tokio_stream::Stream;
14
15impl Input for MovingAverageTransformer {
16  type Input = f64;
17  type InputStream = Pin<Box<dyn Stream<Item = f64> + Send>>;
18}
19
20impl Output for MovingAverageTransformer {
21  type Output = f64;
22  type OutputStream = Pin<Box<dyn Stream<Item = f64> + Send>>;
23}
24
25/// Wrapper struct to implement StateStore on `Arc<InMemoryStateStore<MovingAverageState>>`
26#[derive(Debug)]
27pub struct SharedMovingAverageStore(pub Arc<InMemoryStateStore<MovingAverageState>>);
28
29impl Clone for SharedMovingAverageStore {
30  fn clone(&self) -> Self {
31    Self(Arc::clone(&self.0))
32  }
33}
34
35impl StateStore<MovingAverageState> for SharedMovingAverageStore {
36  fn get(&self) -> streamweave_stateful::StateResult<Option<MovingAverageState>> {
37    self.0.get()
38  }
39
40  fn set(&self, state: MovingAverageState) -> streamweave_stateful::StateResult<()> {
41    self.0.set(state)
42  }
43
44  fn update_with(
45    &self,
46    f: Box<dyn FnOnce(Option<MovingAverageState>) -> MovingAverageState + Send>,
47  ) -> streamweave_stateful::StateResult<MovingAverageState> {
48    self.0.update_with(f)
49  }
50
51  fn reset(&self) -> streamweave_stateful::StateResult<()> {
52    self.0.reset()
53  }
54
55  fn is_initialized(&self) -> bool {
56    self.0.is_initialized()
57  }
58
59  fn initial_state(&self) -> Option<MovingAverageState> {
60    self.0.initial_state()
61  }
62}
63
64#[async_trait]
65impl Transformer for MovingAverageTransformer {
66  type InputPorts = (f64,);
67  type OutputPorts = (f64,);
68
69  fn transform(&mut self, input: Self::InputStream) -> Self::OutputStream {
70    let state_store_clone = Arc::clone(&self.state_store);
71    let window_size = self.window_size;
72
73    input
74      .map(move |item| {
75        state_store_clone
76          .update(move |current_opt| {
77            let mut state = current_opt.unwrap_or_else(|| MovingAverageState::new(window_size));
78            state.add_value(item);
79            state
80          })
81          .map(|state| state.average())
82          .unwrap_or(0.0)
83      })
84      .boxed()
85  }
86
87  fn set_config_impl(&mut self, config: TransformerConfig<f64>) {
88    self.config = config;
89  }
90
91  fn get_config_impl(&self) -> &TransformerConfig<f64> {
92    &self.config
93  }
94
95  fn get_config_mut_impl(&mut self) -> &mut TransformerConfig<f64> {
96    &mut self.config
97  }
98
99  fn handle_error(&self, error: &StreamError<f64>) -> ErrorAction {
100    match &self.config.error_strategy {
101      ErrorStrategy::Stop => ErrorAction::Stop,
102      ErrorStrategy::Skip => ErrorAction::Skip,
103      ErrorStrategy::Retry(n) if error.retries < *n => ErrorAction::Retry,
104      ErrorStrategy::Custom(handler) => handler(error),
105      _ => ErrorAction::Stop,
106    }
107  }
108
109  fn create_error_context(&self, item: Option<f64>) -> ErrorContext<f64> {
110    ErrorContext {
111      timestamp: chrono::Utc::now(),
112      item,
113      component_name: self.component_info().name,
114      component_type: std::any::type_name::<Self>().to_string(),
115    }
116  }
117
118  fn component_info(&self) -> ComponentInfo {
119    ComponentInfo {
120      name: self
121        .config
122        .name
123        .clone()
124        .unwrap_or_else(|| "moving_average_transformer".to_string()),
125      type_name: std::any::type_name::<Self>().to_string(),
126    }
127  }
128}
129
130impl StatefulTransformer for MovingAverageTransformer {
131  type State = MovingAverageState;
132  type Store = SharedMovingAverageStore;
133
134  fn state_store(&self) -> &Self::Store {
135    unreachable!("Use state(), set_state(), reset_state() directly")
136  }
137
138  fn state_store_mut(&mut self) -> &mut Self::Store {
139    unreachable!("Use state(), set_state(), reset_state() directly")
140  }
141
142  fn state(&self) -> streamweave_stateful::StateResult<Option<Self::State>> {
143    self.state_store.get()
144  }
145
146  fn set_state(&self, state: Self::State) -> streamweave_stateful::StateResult<()> {
147    self.state_store.set(state)
148  }
149
150  fn reset_state(&self) -> streamweave_stateful::StateResult<()> {
151    self.state_store.reset()
152  }
153
154  fn has_state(&self) -> bool {
155    self.state_store.is_initialized()
156  }
157
158  fn update_state<F>(&self, f: F) -> streamweave_stateful::StateResult<Self::State>
159  where
160    F: FnOnce(Option<Self::State>) -> Self::State + Send + 'static,
161  {
162    self.state_store.update_with(Box::new(f))
163  }
164}
165
#[cfg(test)]
mod tests {
  use super::*;
  use futures::stream;

  // All expected averages below are exact binary fractions, so `assert_eq!`
  // on `f64` is safe here.

  /// Pushes `samples` through `transformer` and gathers every emitted average.
  async fn averages_of(transformer: &mut MovingAverageTransformer, samples: Vec<f64>) -> Vec<f64> {
    transformer
      .transform(Box::pin(stream::iter(samples)))
      .collect()
      .await
  }

  #[tokio::test]
  async fn test_moving_average_basic() {
    let mut transformer = MovingAverageTransformer::new(3);
    let results = averages_of(&mut transformer, vec![1.0, 2.0, 3.0, 4.0, 5.0]).await;
    // Windows: [1] [1,2] [1,2,3] [2,3,4] [3,4,5]
    assert_eq!(results, vec![1.0, 1.5, 2.0, 3.0, 4.0]);
  }

  #[tokio::test]
  async fn test_moving_average_window_size_1() {
    let mut transformer = MovingAverageTransformer::new(1);
    let results = averages_of(&mut transformer, vec![5.0, 10.0, 15.0, 20.0]).await;
    // A one-element window simply echoes its input.
    assert_eq!(results, vec![5.0, 10.0, 15.0, 20.0]);
  }

  #[tokio::test]
  async fn test_moving_average_large_window() {
    let mut transformer = MovingAverageTransformer::new(10);
    let results = averages_of(&mut transformer, vec![1.0, 2.0, 3.0]).await;
    // The window never fills, so these are running (cumulative) averages.
    assert_eq!(results, vec![1.0, 1.5, 2.0]);
  }

  #[tokio::test]
  async fn test_moving_average_empty_input() {
    let mut transformer = MovingAverageTransformer::new(3);
    let results = averages_of(&mut transformer, Vec::new()).await;
    assert!(results.is_empty());
  }

  #[tokio::test]
  async fn test_moving_average_state_persistence() {
    let mut transformer = MovingAverageTransformer::new(3);

    let first = averages_of(&mut transformer, vec![3.0, 6.0, 9.0]).await;
    assert_eq!(first, vec![3.0, 4.5, 6.0]);

    // The window [3, 6, 9] survives into the next call and becomes [6, 9, 12].
    let second = averages_of(&mut transformer, vec![12.0]).await;
    assert_eq!(second, vec![9.0]);
  }

  #[tokio::test]
  async fn test_moving_average_state_reset() {
    let mut transformer = MovingAverageTransformer::new(3);
    let _ = averages_of(&mut transformer, vec![10.0, 20.0, 30.0]).await;

    transformer.reset_state().unwrap();

    // After a reset the window starts out empty again.
    let fresh = averages_of(&mut transformer, vec![1.0, 2.0]).await;
    assert_eq!(fresh, vec![1.0, 1.5]);
  }

  #[tokio::test]
  async fn test_moving_average_get_state() {
    let mut transformer = MovingAverageTransformer::new(3);
    let _ = averages_of(&mut transformer, vec![2.0, 4.0, 6.0, 8.0]).await;

    // Only the last three samples [4, 6, 8] remain in the window.
    let final_state = transformer.state().unwrap().unwrap();
    assert_eq!(final_state.window.len(), 3);
    assert_eq!(final_state.average(), 6.0);
  }

  #[test]
  fn test_moving_average_component_info() {
    let transformer = MovingAverageTransformer::new(5).with_name("my_moving_avg".to_string());
    let info = transformer.component_info();
    assert_eq!(info.name, "my_moving_avg");
    assert!(info.type_name.contains("MovingAverageTransformer"));
  }

  #[test]
  fn test_moving_average_window_size() {
    assert_eq!(MovingAverageTransformer::new(7).window_size(), 7);
  }

  #[tokio::test]
  async fn test_moving_average_negative_values() {
    let mut transformer = MovingAverageTransformer::new(2);
    let results = averages_of(&mut transformer, vec![10.0, -4.0, 6.0]).await;
    // Windows: [10] [10,-4] [-4,6]
    assert_eq!(results, vec![10.0, 3.0, 1.0]);
  }

  #[test]
  fn test_moving_average_has_state() {
    assert!(MovingAverageTransformer::new(3).has_state());
  }

  #[test]
  #[should_panic(expected = "Window size must be greater than 0")]
  fn test_moving_average_zero_window_panics() {
    let _ = MovingAverageTransformer::new(0);
  }

  #[test]
  fn test_moving_average_state_struct() {
    let mut state = MovingAverageState::new(3);
    // An empty window averages to zero.
    assert_eq!(state.average(), 0.0);

    let steps = [(6.0, 6.0), (12.0, 9.0), (9.0, 9.0), (3.0, 8.0)];
    for (sample, expected) in steps {
      state.add_value(sample);
      assert_eq!(state.average(), expected);
    }
  }
}