1#![allow(dead_code)]
4#![allow(missing_docs)]
5
6use super::*;
7use crate::error::Result;
8use std::marker::PhantomData;
9
10pub struct PipelineBuilder<I, O> {
12 stages: Vec<Box<dyn PipelineStage>>,
13 config: PipelineConfig,
14 _input: PhantomData<I>,
15 _output: PhantomData<O>,
16}
17
18impl<I, O> Default for PipelineBuilder<I, O>
19where
20 I: 'static + Send + Sync,
21 O: 'static + Send + Sync,
22{
23 fn default() -> Self {
24 Self::new()
25 }
26}
27
28impl<I, O> PipelineBuilder<I, O>
29where
30 I: 'static + Send + Sync,
31 O: 'static + Send + Sync,
32{
33 pub fn new() -> Self {
35 Self {
36 stages: Vec::new(),
37 config: PipelineConfig::default(),
38 _input: PhantomData,
39 _output: PhantomData,
40 }
41 }
42
43 pub fn parallel(mut self, enabled: bool) -> Self {
45 self.config.parallel = enabled;
46 self
47 }
48
49 pub fn num_threads(mut self, threads: usize) -> Self {
51 self.config.num_threads = Some(threads);
52 self
53 }
54
55 pub fn with_cache(mut self, cache_dir: impl AsRef<Path>) -> Self {
57 self.config.enable_cache = true;
58 self.config.cache_dir = Some(cache_dir.as_ref().to_path_buf());
59 self
60 }
61
62 pub fn with_checkpoints(mut self, interval: Duration) -> Self {
64 self.config.checkpoint = true;
65 self.config.checkpoint_interval = interval;
66 self
67 }
68
69 pub fn memory_limit(mut self, bytes: usize) -> Self {
71 self.config.max_memory = Some(bytes);
72 self
73 }
74
75 pub fn transform<T, U, F>(mut self, name: &str, f: F) -> PipelineBuilder<I, U>
77 where
78 T: 'static + Send + Sync,
79 U: 'static + Send + Sync,
80 F: Fn(T) -> Result<U> + Send + Sync + 'static,
81 {
82 self.stages.push(function_stage(name, f));
83 PipelineBuilder {
84 stages: self.stages,
85 config: self.config,
86 _input: self._input,
87 _output: PhantomData,
88 }
89 }
90
91 pub fn filter<T, F>(mut self, name: &str, predicate: F) -> Self
93 where
94 T: 'static + Send + Sync + Clone,
95 F: Fn(&T) -> bool + Send + Sync + 'static,
96 {
97 let stage = function_stage(name, move |input: T| {
98 if predicate(&input) {
99 Ok(input)
100 } else {
101 Err(IoError::Other("Filtered out".to_string()))
102 }
103 });
104 self.stages.push(stage);
105 self
106 }
107
108 pub fn map_each<T, U, F>(mut self, name: &str, f: F) -> PipelineBuilder<I, Vec<U>>
110 where
111 T: 'static + Send + Sync,
112 U: 'static + Send + Sync,
113 F: Fn(T) -> Result<U> + Send + Sync + 'static + Clone,
114 O: IntoIterator<Item = T>,
115 {
116 let stage = function_stage(name, move |input: O| {
117 let results: Result<Vec<U>> = input.into_iter().map(|item| f.clone()(item)).collect();
118 results
119 });
120 self.stages.push(stage);
121 PipelineBuilder {
122 stages: self.stages,
123 config: self.config,
124 _input: self._input,
125 _output: PhantomData,
126 }
127 }
128
129 pub fn stage(mut self, stage: Box<dyn PipelineStage>) -> Self {
131 self.stages.push(stage);
132 self
133 }
134
135 pub fn tap<T, F>(mut self, name: &str, f: F) -> Self
137 where
138 T: 'static + Send + Sync + Clone,
139 F: Fn(&T) -> Result<()> + Send + Sync + 'static,
140 {
141 let stage = function_stage(name, move |input: T| {
142 f(&input)?;
143 Ok(input)
144 });
145 self.stages.push(stage);
146 self
147 }
148
149 pub fn inspect<T>(mut self, name: &str) -> Self
151 where
152 T: 'static + Send + Sync + Clone + std::fmt::Debug,
153 {
154 let name_owned = name.to_string();
155 let stage = function_stage(name, move |input: T| {
156 println!("[{name_owned}] {input:?}");
157 Ok(input)
158 });
159 self.stages.push(stage);
160 self
161 }
162
163 pub fn build(self) -> Pipeline<I, O> {
165 Pipeline {
166 stages: self.stages,
167 config: self.config,
168 _input: PhantomData,
169 _output: PhantomData,
170 }
171 }
172}
173
174pub struct BranchingPipelineBuilder<I> {
176 branches: Vec<(String, Box<dyn PipelineStage>)>,
177 selector: Box<dyn Fn(&I) -> String + Send + Sync>,
178 config: PipelineConfig,
179}
180
181impl<I> BranchingPipelineBuilder<I>
182where
183 I: 'static + Send + Sync,
184{
185 pub fn new<F>(selector: F) -> Self
187 where
188 F: Fn(&I) -> String + Send + Sync + 'static,
189 {
190 Self {
191 branches: Vec::new(),
192 selector: Box::new(selector),
193 config: PipelineConfig::default(),
194 }
195 }
196
197 pub fn branch<O, P>(mut self, name: &str, pipeline: Pipeline<I, O>) -> Self
199 where
200 O: 'static + Send + Sync,
201 {
202 self.branches.push((
203 name.to_string(),
204 Box::new(BranchStage {
205 name: name.to_string(),
206 pipeline: Box::new(pipeline),
207 }),
208 ));
209 self
210 }
211
212 pub fn build<O>(self) -> Pipeline<I, O>
214 where
215 O: 'static + Send + Sync,
216 {
217 Pipeline::new().add_stage(Box::new(BranchingStage {
218 branches: self.branches.into_iter().collect(),
219 selector: self.selector,
220 }))
221 }
222}
223
224struct BranchStage {
225 name: String,
226 pipeline: Box<dyn Any + Send + Sync>,
227}
228
229impl PipelineStage for BranchStage {
230 fn execute(
231 &self,
232 input: PipelineData<Box<dyn Any + Send + Sync>>,
233 ) -> Result<PipelineData<Box<dyn Any + Send + Sync>>> {
234 let mut output = input;
236 output.metadata.set("branch_executed", self.name.clone());
237 output
238 .metadata
239 .set("branch_timestamp", chrono::Utc::now().to_rfc3339());
240 Ok(output)
241 }
242
243 fn name(&self) -> String {
244 self.name.clone()
245 }
246
247 fn stage_type(&self) -> String {
248 "branch".to_string()
249 }
250}
251
252struct BranchingStage<I> {
253 branches: HashMap<String, Box<dyn PipelineStage>>,
254 selector: Box<dyn Fn(&I) -> String + Send + Sync>,
255}
256
257impl<I> PipelineStage for BranchingStage<I>
258where
259 I: 'static + Send + Sync,
260{
261 fn execute(
262 &self,
263 input: PipelineData<Box<dyn Any + Send + Sync>>,
264 ) -> Result<PipelineData<Box<dyn Any + Send + Sync>>> {
265 let typed_input = input
266 .data
267 .downcast_ref::<I>()
268 .ok_or_else(|| IoError::Other("Type mismatch in branching stage".to_string()))?;
269
270 let branch_name = (self.selector)(typed_input);
271
272 if let Some(branch) = self.branches.get(&branch_name) {
273 branch.execute(input)
274 } else {
275 Err(IoError::Other(format!("Unknown branch: {}", branch_name)))
276 }
277 }
278
279 fn name(&self) -> String {
280 "branching".to_string()
281 }
282}
283
284pub struct ParallelPipelineBuilder<I, O> {
286 pipelines: Vec<Pipeline<I, O>>,
287 combiner: Box<dyn Fn(Vec<O>) -> Result<O> + Send + Sync>,
288 config: PipelineConfig,
289}
290
291impl<I, O> ParallelPipelineBuilder<I, O>
292where
293 I: 'static + Send + Sync + Clone,
294 O: 'static + Send + Sync,
295{
296 pub fn new<F>(combiner: F) -> Self
298 where
299 F: Fn(Vec<O>) -> Result<O> + Send + Sync + 'static,
300 {
301 Self {
302 pipelines: Vec::new(),
303 combiner: Box::new(combiner),
304 config: PipelineConfig::default(),
305 }
306 }
307
308 pub fn pipeline(mut self, pipeline: Pipeline<I, O>) -> Self {
310 self.pipelines.push(pipeline);
311 self
312 }
313
314 pub fn build(self) -> Pipeline<I, O> {
316 Pipeline::new().add_stage(Box::new(ParallelStage {
317 pipelines: self.pipelines,
318 combiner: self.combiner,
319 }))
320 }
321}
322
323struct ParallelStage<I, O> {
324 pipelines: Vec<Pipeline<I, O>>,
325 combiner: Box<dyn Fn(Vec<O>) -> Result<O> + Send + Sync>,
326}
327
328impl<I, O> PipelineStage for ParallelStage<I, O>
329where
330 I: 'static + Send + Sync + Clone,
331 O: 'static + Send + Sync,
332{
333 fn execute(
334 &self,
335 input: PipelineData<Box<dyn Any + Send + Sync>>,
336 ) -> Result<PipelineData<Box<dyn Any + Send + Sync>>> {
337 let typed_input = input
338 .data
339 .downcast::<I>()
340 .map_err(|_| IoError::Other("Type mismatch in parallel stage".to_string()))?;
341
342 let results: Result<Vec<O>> = self
344 .pipelines
345 .par_iter()
346 .map(|pipeline| pipeline.execute((*typed_input).clone()))
347 .collect();
348
349 let combined = (self.combiner)(results?)?;
350
351 Ok(PipelineData {
352 data: Box::new(combined) as Box<dyn Any + Send + Sync>,
353 metadata: input.metadata,
354 context: input.context,
355 })
356 }
357
358 fn name(&self) -> String {
359 "parallel".to_string()
360 }
361}
362
#[cfg(test)]
mod tests {
    use super::*;

    /// Chained `transform` stages: each stage's output feeds the next.
    #[test]
    fn test_pipeline_builder() {
        let pipeline: Pipeline<i32, String> = PipelineBuilder::<i32, String>::new()
            .transform("double", |x: i32| Ok(x * 2))
            .transform("to_string", |x: i32| Ok(x.to_string()))
            .build();

        let result = pipeline.execute(21).unwrap();
        assert_eq!(result, "42");
    }

    /// Filtering a collection inside a `transform` stage. This does not exercise
    /// `PipelineBuilder::filter`; `test_filter_stage` covers that.
    #[test]
    fn test_pipeline_with_filter() {
        let pipeline: Pipeline<Vec<i32>, Vec<i32>> = PipelineBuilder::<Vec<i32>, Vec<i32>>::new()
            .transform("filter_even", |nums: Vec<i32>| {
                Ok(nums.into_iter().filter(|&x| x % 2 == 0).collect())
            })
            .build();

        let result = pipeline.execute(vec![1, 2, 3, 4, 5, 6]).unwrap();
        assert_eq!(result, vec![2, 4, 6]);
    }

    /// `filter` passes accepted values through unchanged and rejects others with an error.
    #[test]
    fn test_filter_stage() {
        let pipeline: Pipeline<Vec<i32>, Vec<i32>> = PipelineBuilder::<Vec<i32>, Vec<i32>>::new()
            .filter("non_empty", |nums: &Vec<i32>| !nums.is_empty())
            .build();

        assert_eq!(pipeline.execute(vec![1, 2, 3]).unwrap(), vec![1, 2, 3]);
        assert!(pipeline.execute(Vec::new()).is_err());
    }

    /// `map_each` applies the function to every element and stops at the first error.
    #[test]
    fn test_map_each() {
        let pipeline: Pipeline<Vec<i32>, Vec<String>> =
            PipelineBuilder::<Vec<i32>, Vec<i32>>::new()
                .map_each("stringify", |x: i32| Ok(x.to_string()))
                .build();
        assert_eq!(pipeline.execute(vec![1, 2, 3]).unwrap(), vec!["1", "2", "3"]);

        let checked: Pipeline<Vec<i32>, Vec<i32>> = PipelineBuilder::<Vec<i32>, Vec<i32>>::new()
            .map_each("non_negative", |x: i32| {
                if x < 0 {
                    Err(IoError::Other("negative".to_string()))
                } else {
                    Ok(x)
                }
            })
            .build();
        assert!(checked.execute(vec![1, -2, 3]).is_err());
    }
}