streamweave_transformers/running-sum/
transformer.rs

1//! Transformer implementations for RunningSumTransformer.
2
3use super::running_sum_transformer::RunningSumTransformer;
4use async_trait::async_trait;
5use futures::StreamExt;
6use std::fmt::Debug;
7use std::ops::Add;
8use std::pin::Pin;
9use std::sync::Arc;
10use streamweave::Input;
11use streamweave::Output;
12use streamweave::{Transformer, TransformerConfig};
13use streamweave_error::{ComponentInfo, ErrorAction, ErrorContext, ErrorStrategy, StreamError};
14use streamweave_stateful::{InMemoryStateStore, StateStore, StateStoreExt, StatefulTransformer};
15use tokio_stream::Stream;
16
17impl<T> Input for RunningSumTransformer<T>
18where
19  T: Add<Output = T> + Default + Debug + Clone + Send + Sync + 'static,
20{
21  type Input = T;
22  type InputStream = Pin<Box<dyn Stream<Item = T> + Send>>;
23}
24
25impl<T> Output for RunningSumTransformer<T>
26where
27  T: Add<Output = T> + Default + Debug + Clone + Send + Sync + 'static,
28{
29  type Output = T;
30  type OutputStream = Pin<Box<dyn Stream<Item = T> + Send>>;
31}
32
33/// Wrapper struct to implement StateStore on `Arc<InMemoryStateStore<T>>`
34#[derive(Debug)]
35pub struct SharedStateStore<T: Clone + Send + Sync>(pub Arc<InMemoryStateStore<T>>);
36
37impl<T: Clone + Send + Sync> Clone for SharedStateStore<T> {
38  fn clone(&self) -> Self {
39    Self(Arc::clone(&self.0))
40  }
41}
42
43impl<T: Clone + Send + Sync> StateStore<T> for SharedStateStore<T> {
44  fn get(&self) -> streamweave_stateful::StateResult<Option<T>> {
45    self.0.get()
46  }
47
48  fn set(&self, state: T) -> streamweave_stateful::StateResult<()> {
49    self.0.set(state)
50  }
51
52  fn update_with(
53    &self,
54    f: Box<dyn FnOnce(Option<T>) -> T + Send>,
55  ) -> streamweave_stateful::StateResult<T> {
56    self.0.update_with(f)
57  }
58
59  fn reset(&self) -> streamweave_stateful::StateResult<()> {
60    self.0.reset()
61  }
62
63  fn is_initialized(&self) -> bool {
64    self.0.is_initialized()
65  }
66
67  fn initial_state(&self) -> Option<T> {
68    self.0.initial_state()
69  }
70}
71
72#[async_trait]
73impl<T> Transformer for RunningSumTransformer<T>
74where
75  T: Add<Output = T> + Default + Debug + Clone + Send + Sync + 'static,
76{
77  type InputPorts = (T,);
78  type OutputPorts = (T,);
79  fn transform(&mut self, input: Self::InputStream) -> Self::OutputStream {
80    let state_store_clone = Arc::clone(&self.state_store);
81
82    input
83      .map(move |item| {
84        state_store_clone
85          .update(move |current_opt| {
86            let current = current_opt.unwrap_or_default();
87            current + item
88          })
89          .unwrap_or_else(|_| T::default())
90      })
91      .boxed()
92  }
93
94  fn set_config_impl(&mut self, config: TransformerConfig<T>) {
95    self.config = config;
96  }
97
98  fn get_config_impl(&self) -> &TransformerConfig<T> {
99    &self.config
100  }
101
102  fn get_config_mut_impl(&mut self) -> &mut TransformerConfig<T> {
103    &mut self.config
104  }
105
106  fn handle_error(&self, error: &StreamError<T>) -> ErrorAction {
107    match &self.config.error_strategy {
108      ErrorStrategy::Stop => ErrorAction::Stop,
109      ErrorStrategy::Skip => ErrorAction::Skip,
110      ErrorStrategy::Retry(n) if error.retries < *n => ErrorAction::Retry,
111      ErrorStrategy::Custom(handler) => handler(error),
112      _ => ErrorAction::Stop,
113    }
114  }
115
116  fn create_error_context(&self, item: Option<T>) -> ErrorContext<T> {
117    ErrorContext {
118      timestamp: chrono::Utc::now(),
119      item,
120      component_name: self.component_info().name,
121      component_type: std::any::type_name::<Self>().to_string(),
122    }
123  }
124
125  fn component_info(&self) -> ComponentInfo {
126    ComponentInfo {
127      name: self
128        .config
129        .name
130        .clone()
131        .unwrap_or_else(|| "running_sum_transformer".to_string()),
132      type_name: std::any::type_name::<Self>().to_string(),
133    }
134  }
135}
136
137impl<T> StatefulTransformer for RunningSumTransformer<T>
138where
139  T: Add<Output = T> + Default + Debug + Clone + Send + Sync + 'static,
140{
141  type State = T;
142  type Store = SharedStateStore<T>;
143
144  fn state_store(&self) -> &Self::Store {
145    // This is a bit awkward, but we need to return a reference to SharedStateStore
146    // For now, we'll work around this by implementing the state methods directly
147    unreachable!("Use state(), set_state(), reset_state() directly")
148  }
149
150  fn state_store_mut(&mut self) -> &mut Self::Store {
151    unreachable!("Use state(), set_state(), reset_state() directly")
152  }
153
154  fn state(&self) -> streamweave_stateful::StateResult<Option<Self::State>> {
155    self.state_store.get()
156  }
157
158  fn set_state(&self, state: Self::State) -> streamweave_stateful::StateResult<()> {
159    self.state_store.set(state)
160  }
161
162  fn reset_state(&self) -> streamweave_stateful::StateResult<()> {
163    self.state_store.reset()
164  }
165
166  fn has_state(&self) -> bool {
167    self.state_store.is_initialized()
168  }
169
170  fn update_state<F>(&self, f: F) -> streamweave_stateful::StateResult<Self::State>
171  where
172    F: FnOnce(Option<Self::State>) -> Self::State + Send + 'static,
173  {
174    self.state_store.update_with(Box::new(f))
175  }
176}
177
#[cfg(test)]
mod tests {
  use super::*;
  use futures::stream;

  /// Feeds `items` through `transformer` and collects every emitted running total.
  async fn run<T>(transformer: &mut RunningSumTransformer<T>, items: Vec<T>) -> Vec<T>
  where
    T: Add<Output = T> + Default + Debug + Clone + Send + Sync + 'static,
  {
    let input: Pin<Box<dyn Stream<Item = T> + Send>> = Box::pin(stream::iter(items));
    transformer.transform(input).collect().await
  }

  #[tokio::test]
  async fn test_running_sum_basic() {
    let mut transformer = RunningSumTransformer::<i32>::new();
    assert_eq!(
      run(&mut transformer, vec![1, 2, 3, 4, 5]).await,
      vec![1, 3, 6, 10, 15]
    );
  }

  #[tokio::test]
  async fn test_running_sum_with_initial_value() {
    let mut transformer = RunningSumTransformer::<i32>::with_initial(100);
    assert_eq!(run(&mut transformer, vec![1, 2, 3]).await, vec![101, 103, 106]);
  }

  #[tokio::test]
  async fn test_running_sum_empty_input() {
    let mut transformer = RunningSumTransformer::<i32>::new();
    assert!(run(&mut transformer, Vec::new()).await.is_empty());
  }

  #[tokio::test]
  async fn test_running_sum_floats() {
    let mut transformer = RunningSumTransformer::<f64>::new();
    // All values are exactly representable, so exact float comparison is safe here.
    assert_eq!(
      run(&mut transformer, vec![1.5, 2.5, 3.0]).await,
      vec![1.5, 4.0, 7.0]
    );
  }

  #[tokio::test]
  async fn test_running_sum_state_persistence() {
    let mut transformer = RunningSumTransformer::<i32>::new();

    // First batch
    assert_eq!(run(&mut transformer, vec![1, 2, 3]).await, vec![1, 3, 6]);

    // State should persist, so second batch continues from 6
    assert_eq!(run(&mut transformer, vec![4, 5]).await, vec![10, 15]);
  }

  #[tokio::test]
  async fn test_running_sum_state_reset() {
    let mut transformer = RunningSumTransformer::<i32>::new();
    let _ = run(&mut transformer, vec![1, 2, 3]).await;

    transformer.reset_state().unwrap();

    // Should start from 0 again
    assert_eq!(run(&mut transformer, vec![10, 20]).await, vec![10, 30]);
  }

  #[tokio::test]
  async fn test_running_sum_get_state() {
    let mut transformer = RunningSumTransformer::<i32>::new();
    let _ = run(&mut transformer, vec![1, 2, 3, 4, 5]).await;

    let final_state = transformer.state().unwrap().unwrap();
    assert_eq!(final_state, 15);
  }

  #[test]
  fn test_running_sum_component_info() {
    let transformer = RunningSumTransformer::<i32>::new().with_name("my_running_sum".to_string());

    let info = transformer.component_info();
    assert_eq!(info.name, "my_running_sum");
    assert!(info.type_name.contains("RunningSumTransformer"));
  }

  #[tokio::test]
  async fn test_running_sum_negative_numbers() {
    let mut transformer = RunningSumTransformer::<i32>::new();
    assert_eq!(
      run(&mut transformer, vec![10, -3, 5, -7]).await,
      vec![10, 7, 12, 5]
    );
  }

  #[test]
  fn test_running_sum_has_state() {
    let transformer = RunningSumTransformer::<i32>::new();
    assert!(transformer.has_state());
  }

  #[tokio::test]
  async fn test_running_sum_set_state() {
    let mut transformer = RunningSumTransformer::<i32>::new();

    transformer.set_state(100).unwrap();

    assert_eq!(run(&mut transformer, vec![1, 2, 3]).await, vec![101, 103, 106]);
  }

  #[tokio::test]
  async fn test_running_sum_update_state() {
    let mut transformer = RunningSumTransformer::<i32>::new();

    // Whether the fresh store holds `None` or `Some(0)`, the update yields 7.
    transformer
      .update_state(|current| current.unwrap_or_default() + 7)
      .unwrap();
    assert_eq!(transformer.state().unwrap(), Some(7));

    // Streaming continues from the updated total.
    assert_eq!(run(&mut transformer, vec![1, 2]).await, vec![8, 10]);
  }
}