streamweave_transformers/batch/
transformer.rs1use super::batch_transformer::BatchTransformer;
2use async_trait::async_trait;
3use futures::StreamExt;
4use streamweave::{Transformer, TransformerConfig};
5use streamweave_error::{ComponentInfo, ErrorAction, ErrorContext, ErrorStrategy, StreamError};
6
7#[async_trait]
8impl<T> Transformer for BatchTransformer<T>
9where
10 T: std::fmt::Debug + Clone + Send + Sync + 'static,
11{
12 type InputPorts = (T,);
13 type OutputPorts = (Vec<T>,);
14
15 fn transform(&mut self, mut input: Self::InputStream) -> Self::OutputStream {
16 let size = self.size;
17 let mut current_batch: Vec<T> = Vec::with_capacity(size);
18
19 Box::pin(async_stream::stream! {
20 while let Some(item) = input.next().await {
21 current_batch.push(item);
22 if current_batch.len() == size {
23 yield current_batch;
24 current_batch = Vec::with_capacity(size);
25 }
26 }
27 if !current_batch.is_empty() {
28 yield current_batch;
29 }
30 })
31 }
32
33 fn set_config_impl(&mut self, config: TransformerConfig<T>) {
34 self.config = config;
35 }
36
37 fn get_config_impl(&self) -> &TransformerConfig<T> {
38 &self.config
39 }
40
41 fn get_config_mut_impl(&mut self) -> &mut TransformerConfig<T> {
42 &mut self.config
43 }
44
45 fn handle_error(&self, error: &StreamError<T>) -> ErrorAction {
46 match self.config.error_strategy {
47 ErrorStrategy::Stop => ErrorAction::Stop,
48 ErrorStrategy::Skip => ErrorAction::Skip,
49 ErrorStrategy::Retry(n) if error.retries < n => ErrorAction::Retry,
50 _ => ErrorAction::Stop,
51 }
52 }
53
54 fn create_error_context(&self, item: Option<T>) -> ErrorContext<T> {
55 ErrorContext {
56 timestamp: chrono::Utc::now(),
57 item,
58 component_name: self.component_info().name,
59 component_type: std::any::type_name::<Self>().to_string(),
60 }
61 }
62
63 fn component_info(&self) -> ComponentInfo {
64 ComponentInfo {
65 name: self
66 .config
67 .name
68 .clone()
69 .unwrap_or_else(|| "batch_transformer".to_string()),
70 type_name: std::any::type_name::<Self>().to_string(),
71 }
72 }
73}
74
#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream;

    /// Pushes `items` through `transformer` and gathers every batch it emits.
    async fn collect_batches(
        mut transformer: BatchTransformer<i32>,
        items: Vec<i32>,
    ) -> Vec<Vec<i32>> {
        let source = Box::pin(stream::iter(items));
        transformer.transform(source).collect().await
    }

    #[tokio::test]
    async fn test_batch_exact_size() {
        let batcher = BatchTransformer::new(3).unwrap();
        let batches = collect_batches(batcher, vec![1, 2, 3, 4, 5, 6]).await;
        assert_eq!(batches, vec![vec![1, 2, 3], vec![4, 5, 6]]);
    }

    #[tokio::test]
    async fn test_batch_partial_last_chunk() {
        let batcher = BatchTransformer::new(2).unwrap();
        let batches = collect_batches(batcher, vec![1, 2, 3, 4, 5]).await;
        assert_eq!(batches, vec![vec![1, 2], vec![3, 4], vec![5]]);
    }

    #[tokio::test]
    async fn test_batch_empty_input() {
        let batcher = BatchTransformer::new(2).unwrap();
        let batches = collect_batches(batcher, Vec::new()).await;
        assert!(batches.is_empty());
    }

    #[tokio::test]
    async fn test_batch_size_one() {
        let batcher = BatchTransformer::new(1).unwrap();
        let batches = collect_batches(batcher, vec![1, 2, 3]).await;
        assert_eq!(batches, vec![vec![1], vec![2], vec![3]]);
    }

    #[tokio::test]
    async fn test_batch_size_larger_than_input() {
        let batcher = BatchTransformer::new(10).unwrap();
        let batches = collect_batches(batcher, vec![1, 2, 3]).await;
        assert_eq!(batches, vec![vec![1, 2, 3]]);
    }

    #[test]
    fn test_batch_invalid_size() {
        // A zero batch size is rejected by the constructor.
        assert!(BatchTransformer::<i32>::new(0).is_err());
    }

    #[test]
    fn test_error_handling_strategies() {
        // A freshly constructed transformer defaults to stopping on errors.
        let batcher = BatchTransformer::<i32>::new(2).unwrap();
        assert_eq!(batcher.config.error_strategy, ErrorStrategy::Stop);
    }
}