streamweave_transformers/batch/transformer.rs

1use super::batch_transformer::BatchTransformer;
2use async_trait::async_trait;
3use futures::StreamExt;
4use streamweave::{Transformer, TransformerConfig};
5use streamweave_error::{ComponentInfo, ErrorAction, ErrorContext, ErrorStrategy, StreamError};
6
7#[async_trait]
8impl<T> Transformer for BatchTransformer<T>
9where
10  T: std::fmt::Debug + Clone + Send + Sync + 'static,
11{
12  type InputPorts = (T,);
13  type OutputPorts = (Vec<T>,);
14
15  fn transform(&mut self, mut input: Self::InputStream) -> Self::OutputStream {
16    let size = self.size;
17    let mut current_batch: Vec<T> = Vec::with_capacity(size);
18
19    Box::pin(async_stream::stream! {
20      while let Some(item) = input.next().await {
21        current_batch.push(item);
22        if current_batch.len() == size {
23          yield current_batch;
24          current_batch = Vec::with_capacity(size);
25        }
26      }
27      if !current_batch.is_empty() {
28        yield current_batch;
29      }
30    })
31  }
32
33  fn set_config_impl(&mut self, config: TransformerConfig<T>) {
34    self.config = config;
35  }
36
37  fn get_config_impl(&self) -> &TransformerConfig<T> {
38    &self.config
39  }
40
41  fn get_config_mut_impl(&mut self) -> &mut TransformerConfig<T> {
42    &mut self.config
43  }
44
45  fn handle_error(&self, error: &StreamError<T>) -> ErrorAction {
46    match self.config.error_strategy {
47      ErrorStrategy::Stop => ErrorAction::Stop,
48      ErrorStrategy::Skip => ErrorAction::Skip,
49      ErrorStrategy::Retry(n) if error.retries < n => ErrorAction::Retry,
50      _ => ErrorAction::Stop,
51    }
52  }
53
54  fn create_error_context(&self, item: Option<T>) -> ErrorContext<T> {
55    ErrorContext {
56      timestamp: chrono::Utc::now(),
57      item,
58      component_name: self.component_info().name,
59      component_type: std::any::type_name::<Self>().to_string(),
60    }
61  }
62
63  fn component_info(&self) -> ComponentInfo {
64    ComponentInfo {
65      name: self
66        .config
67        .name
68        .clone()
69        .unwrap_or_else(|| "batch_transformer".to_string()),
70      type_name: std::any::type_name::<Self>().to_string(),
71    }
72  }
73}
74
#[cfg(test)]
mod tests {
  use super::*;
  use futures::stream;

  /// Feeds `items` through a fresh `BatchTransformer` of the given `size`.
  /// Returns every batch it emits, in order.
  async fn collect_batches(size: usize, items: Vec<i32>) -> Vec<Vec<i32>> {
    let mut transformer = BatchTransformer::<i32>::new(size).unwrap();
    let source = Box::pin(stream::iter(items));
    transformer.transform(source).collect::<Vec<_>>().await
  }

  #[tokio::test]
  async fn test_batch_exact_size() {
    let batches = collect_batches(3, vec![1, 2, 3, 4, 5, 6]).await;
    assert_eq!(batches, vec![vec![1, 2, 3], vec![4, 5, 6]]);
  }

  #[tokio::test]
  async fn test_batch_partial_last_chunk() {
    let batches = collect_batches(2, vec![1, 2, 3, 4, 5]).await;
    assert_eq!(batches, vec![vec![1, 2], vec![3, 4], vec![5]]);
  }

  #[tokio::test]
  async fn test_batch_empty_input() {
    let batches = collect_batches(2, Vec::new()).await;
    assert!(batches.is_empty());
  }

  #[tokio::test]
  async fn test_batch_size_one() {
    let batches = collect_batches(1, vec![1, 2, 3]).await;
    assert_eq!(batches, vec![vec![1], vec![2], vec![3]]);
  }

  #[tokio::test]
  async fn test_batch_size_larger_than_input() {
    let batches = collect_batches(10, vec![1, 2, 3]).await;
    assert_eq!(batches, vec![vec![1, 2, 3]]);
  }

  #[test]
  fn test_batch_invalid_size() {
    // A batch size of zero is meaningless and must be rejected by the constructor.
    assert!(BatchTransformer::<i32>::new(0).is_err());
  }

  #[test]
  fn test_error_handling_strategies() {
    let transformer = BatchTransformer::<i32>::new(2).unwrap();
    assert_eq!(transformer.config.error_strategy, ErrorStrategy::Stop);
  }
}