// streamweave_transformers/running-sum/transformer.rs

use std::fmt::Debug;
use std::ops::Add;
use std::pin::Pin;
use std::sync::Arc;

use async_trait::async_trait;
use futures::StreamExt;
use streamweave::Input;
use streamweave::Output;
use streamweave::{Transformer, TransformerConfig};
use streamweave_error::{ComponentInfo, ErrorAction, ErrorContext, ErrorStrategy, StreamError};
use streamweave_stateful::{InMemoryStateStore, StateStore, StateStoreExt, StatefulTransformer};
use tokio_stream::Stream;

use super::running_sum_transformer::RunningSumTransformer;

17impl<T> Input for RunningSumTransformer<T>
18where
19 T: Add<Output = T> + Default + Debug + Clone + Send + Sync + 'static,
20{
21 type Input = T;
22 type InputStream = Pin<Box<dyn Stream<Item = T> + Send>>;
23}
24
25impl<T> Output for RunningSumTransformer<T>
26where
27 T: Add<Output = T> + Default + Debug + Clone + Send + Sync + 'static,
28{
29 type Output = T;
30 type OutputStream = Pin<Box<dyn Stream<Item = T> + Send>>;
31}
32
33#[derive(Debug)]
35pub struct SharedStateStore<T: Clone + Send + Sync>(pub Arc<InMemoryStateStore<T>>);
36
37impl<T: Clone + Send + Sync> Clone for SharedStateStore<T> {
38 fn clone(&self) -> Self {
39 Self(Arc::clone(&self.0))
40 }
41}
42
43impl<T: Clone + Send + Sync> StateStore<T> for SharedStateStore<T> {
44 fn get(&self) -> streamweave_stateful::StateResult<Option<T>> {
45 self.0.get()
46 }
47
48 fn set(&self, state: T) -> streamweave_stateful::StateResult<()> {
49 self.0.set(state)
50 }
51
52 fn update_with(
53 &self,
54 f: Box<dyn FnOnce(Option<T>) -> T + Send>,
55 ) -> streamweave_stateful::StateResult<T> {
56 self.0.update_with(f)
57 }
58
59 fn reset(&self) -> streamweave_stateful::StateResult<()> {
60 self.0.reset()
61 }
62
63 fn is_initialized(&self) -> bool {
64 self.0.is_initialized()
65 }
66
67 fn initial_state(&self) -> Option<T> {
68 self.0.initial_state()
69 }
70}
71
72#[async_trait]
73impl<T> Transformer for RunningSumTransformer<T>
74where
75 T: Add<Output = T> + Default + Debug + Clone + Send + Sync + 'static,
76{
77 type InputPorts = (T,);
78 type OutputPorts = (T,);
79 fn transform(&mut self, input: Self::InputStream) -> Self::OutputStream {
80 let state_store_clone = Arc::clone(&self.state_store);
81
82 input
83 .map(move |item| {
84 state_store_clone
85 .update(move |current_opt| {
86 let current = current_opt.unwrap_or_default();
87 current + item
88 })
89 .unwrap_or_else(|_| T::default())
90 })
91 .boxed()
92 }
93
94 fn set_config_impl(&mut self, config: TransformerConfig<T>) {
95 self.config = config;
96 }
97
98 fn get_config_impl(&self) -> &TransformerConfig<T> {
99 &self.config
100 }
101
102 fn get_config_mut_impl(&mut self) -> &mut TransformerConfig<T> {
103 &mut self.config
104 }
105
106 fn handle_error(&self, error: &StreamError<T>) -> ErrorAction {
107 match &self.config.error_strategy {
108 ErrorStrategy::Stop => ErrorAction::Stop,
109 ErrorStrategy::Skip => ErrorAction::Skip,
110 ErrorStrategy::Retry(n) if error.retries < *n => ErrorAction::Retry,
111 ErrorStrategy::Custom(handler) => handler(error),
112 _ => ErrorAction::Stop,
113 }
114 }
115
116 fn create_error_context(&self, item: Option<T>) -> ErrorContext<T> {
117 ErrorContext {
118 timestamp: chrono::Utc::now(),
119 item,
120 component_name: self.component_info().name,
121 component_type: std::any::type_name::<Self>().to_string(),
122 }
123 }
124
125 fn component_info(&self) -> ComponentInfo {
126 ComponentInfo {
127 name: self
128 .config
129 .name
130 .clone()
131 .unwrap_or_else(|| "running_sum_transformer".to_string()),
132 type_name: std::any::type_name::<Self>().to_string(),
133 }
134 }
135}
136
137impl<T> StatefulTransformer for RunningSumTransformer<T>
138where
139 T: Add<Output = T> + Default + Debug + Clone + Send + Sync + 'static,
140{
141 type State = T;
142 type Store = SharedStateStore<T>;
143
144 fn state_store(&self) -> &Self::Store {
145 unreachable!("Use state(), set_state(), reset_state() directly")
148 }
149
150 fn state_store_mut(&mut self) -> &mut Self::Store {
151 unreachable!("Use state(), set_state(), reset_state() directly")
152 }
153
154 fn state(&self) -> streamweave_stateful::StateResult<Option<Self::State>> {
155 self.state_store.get()
156 }
157
158 fn set_state(&self, state: Self::State) -> streamweave_stateful::StateResult<()> {
159 self.state_store.set(state)
160 }
161
162 fn reset_state(&self) -> streamweave_stateful::StateResult<()> {
163 self.state_store.reset()
164 }
165
166 fn has_state(&self) -> bool {
167 self.state_store.is_initialized()
168 }
169
170 fn update_state<F>(&self, f: F) -> streamweave_stateful::StateResult<Self::State>
171 where
172 F: FnOnce(Option<Self::State>) -> Self::State + Send + 'static,
173 {
174 self.state_store.update_with(Box::new(f))
175 }
176}
177
#[cfg(test)]
mod tests {
    // Behavioural tests: emitted sums, initial values, and how the running total
    // persists, resets and can be overridden between streams.

    use super::*;
    use futures::stream;

    /// Streams `items` through `transformer` and collects every emitted running total.
    async fn run<T>(transformer: &mut RunningSumTransformer<T>, items: Vec<T>) -> Vec<T>
    where
        T: Add<Output = T> + Default + Debug + Clone + Send + Sync + 'static,
    {
        let input: Pin<Box<dyn Stream<Item = T> + Send>> = Box::pin(stream::iter(items));
        transformer.transform(input).collect::<Vec<T>>().await
    }

    #[tokio::test]
    async fn test_running_sum_basic() {
        let mut transformer = RunningSumTransformer::<i32>::new();
        let results = run(&mut transformer, vec![1, 2, 3, 4, 5]).await;
        assert_eq!(results, vec![1, 3, 6, 10, 15]);
    }

    #[tokio::test]
    async fn test_running_sum_with_initial_value() {
        let mut transformer = RunningSumTransformer::<i32>::with_initial(100);
        let results = run(&mut transformer, vec![1, 2, 3]).await;
        assert_eq!(results, vec![101, 103, 106]);
    }

    #[tokio::test]
    async fn test_running_sum_empty_input() {
        let mut transformer = RunningSumTransformer::<i32>::new();
        let results = run(&mut transformer, Vec::new()).await;
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_running_sum_floats() {
        let mut transformer = RunningSumTransformer::<f64>::new();
        let results = run(&mut transformer, vec![1.5, 2.5, 3.0]).await;
        // Every value here is exactly representable in binary, so exact comparison is safe.
        assert_eq!(results, vec![1.5, 4.0, 7.0]);
    }

    #[tokio::test]
    async fn test_running_sum_state_persistence() {
        let mut transformer = RunningSumTransformer::<i32>::new();
        assert_eq!(run(&mut transformer, vec![1, 2, 3]).await, vec![1, 3, 6]);
        // The second stream continues from the total left behind by the first.
        assert_eq!(run(&mut transformer, vec![4, 5]).await, vec![10, 15]);
    }

    #[tokio::test]
    async fn test_running_sum_state_reset() {
        let mut transformer = RunningSumTransformer::<i32>::new();
        let _ = run(&mut transformer, vec![1, 2, 3]).await;

        transformer.reset_state().unwrap();

        assert_eq!(run(&mut transformer, vec![10, 20]).await, vec![10, 30]);
    }

    #[tokio::test]
    async fn test_running_sum_get_state() {
        let mut transformer = RunningSumTransformer::<i32>::new();
        let _ = run(&mut transformer, vec![1, 2, 3, 4, 5]).await;

        let final_state = transformer.state().unwrap().unwrap();
        assert_eq!(final_state, 15);
    }

    // Plain #[test]: nothing is awaited, so no async runtime is needed.
    #[test]
    fn test_running_sum_component_info() {
        let transformer =
            RunningSumTransformer::<i32>::new().with_name("my_running_sum".to_string());

        let info = transformer.component_info();
        assert_eq!(info.name, "my_running_sum");
        assert!(info.type_name.contains("RunningSumTransformer"));
    }

    #[tokio::test]
    async fn test_running_sum_negative_numbers() {
        let mut transformer = RunningSumTransformer::<i32>::new();
        let results = run(&mut transformer, vec![10, -3, 5, -7]).await;
        assert_eq!(results, vec![10, 7, 12, 5]);
    }

    // Plain #[test]: nothing is awaited, so no async runtime is needed.
    #[test]
    fn test_running_sum_has_state() {
        let transformer = RunningSumTransformer::<i32>::new();
        assert!(transformer.has_state());
    }

    #[tokio::test]
    async fn test_running_sum_set_state() {
        let mut transformer = RunningSumTransformer::<i32>::new();
        transformer.set_state(100).unwrap();

        assert_eq!(run(&mut transformer, vec![1, 2, 3]).await, vec![101, 103, 106]);
    }

    #[tokio::test]
    async fn test_running_sum_update_state() {
        let mut transformer = RunningSumTransformer::<i32>::new();

        let updated = transformer
            .update_state(|current| current.unwrap_or_default() + 5)
            .unwrap();
        assert_eq!(updated, 5);

        // The next stream builds on the externally updated total.
        assert_eq!(run(&mut transformer, vec![1, 2]).await, vec![6, 8]);
    }
}
316}