swiftide_indexing/pipeline.rs

use anyhow::Result;
use futures_util::{StreamExt, TryFutureExt, TryStreamExt};
use swiftide_core::{
    indexing::IndexingDefaults, BatchableTransformer, ChunkerTransformer, Loader, NodeCache,
    Persist, SimplePrompt, Transformer, WithBatchIndexingDefaults, WithIndexingDefaults,
};
use tokio::{sync::mpsc, task};
use tracing::Instrument;

use std::{sync::Arc, time::Duration};

use swiftide_core::indexing::{EmbedMode, IndexingStream, Node};

/// The default batch size for batch processing.
const DEFAULT_BATCH_SIZE: usize = 256;

/// A pipeline for indexing files, adding metadata, chunking, transforming, embedding, and then
/// storing them.
///
/// The `Pipeline` struct orchestrates the entire file indexing process. It is designed to be
/// flexible and performant, allowing for various stages of data transformation and storage to be
/// configured and executed asynchronously.
///
/// # Fields
///
/// * `stream` - The stream of `Node` items to be processed.
/// * `storage` - Storage backends where the processed nodes will be stored.
/// * `concurrency` - The level of concurrency for processing nodes.
/// * `indexing_defaults` - Defaults, such as a default LLM client, passed on to transformers.
/// * `batch_size` - The default batch size for batch transformers.
pub struct Pipeline {
    stream: IndexingStream,
    storage: Vec<Arc<dyn Persist>>,
    concurrency: usize,
    indexing_defaults: IndexingDefaults,
    batch_size: usize,
}

impl Default for Pipeline {
    /// Creates a default `Pipeline` with an empty stream, no storage, and a concurrency level
    /// equal to the number of CPUs.
    fn default() -> Self {
        Self {
            stream: IndexingStream::empty(),
            storage: Vec::default(),
            concurrency: num_cpus::get(),
            indexing_defaults: IndexingDefaults::default(),
            batch_size: DEFAULT_BATCH_SIZE,
        }
    }
}

impl Pipeline {
    /// Creates a `Pipeline` from a given loader.
    ///
    /// # Arguments
    ///
    /// * `loader` - A loader that implements the `Loader` trait.
    ///
    /// # Returns
    ///
    /// An instance of `Pipeline` initialized with the provided loader.
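    ///
    /// # Example
    ///
    /// A minimal sketch; `FileLoader` is assumed here as a `Loader` implementation, like the one
    /// shipped in `swiftide_indexing::loaders`:
    ///
    /// ```ignore
    /// use swiftide_indexing::{loaders::FileLoader, Pipeline};
    ///
    /// // Stream all Rust files under the current directory into the pipeline.
    /// let pipeline = Pipeline::from_loader(FileLoader::new(".").with_extensions(&["rs"]));
    /// ```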
    pub fn from_loader(loader: impl Loader + 'static) -> Self {
        let stream = loader.into_stream();
        Self {
            stream,
            ..Default::default()
        }
    }

    /// Sets the default LLM client to be used for LLM prompts for all transformers in the
    /// pipeline.
    #[must_use]
    pub fn with_default_llm_client(mut self, client: impl SimplePrompt + 'static) -> Self {
        self.indexing_defaults = IndexingDefaults::from_simple_prompt(Box::new(client));
        self
    }

    /// Creates a `Pipeline` from a given stream.
    ///
    /// # Arguments
    ///
    /// * `stream` - An `IndexingStream` containing the nodes to be processed.
    ///
    /// # Returns
    ///
    /// An instance of `Pipeline` initialized with the provided stream.
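    ///
    /// # Example
    ///
    /// A minimal sketch; anything that converts into an `IndexingStream`, such as a
    /// `Vec<Result<Node>>`, can seed the pipeline:
    ///
    /// ```ignore
    /// use swiftide_core::indexing::Node;
    ///
    /// let nodes = vec![Ok(Node::new("first chunk")), Ok(Node::new("second chunk"))];
    /// let pipeline = Pipeline::from_stream(nodes);
    /// ```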
    pub fn from_stream(stream: impl Into<IndexingStream>) -> Self {
        Self {
            stream: stream.into(),
            ..Default::default()
        }
    }

    /// Sets the concurrency level for the pipeline. By default the concurrency is set to the
    /// number of CPUs.
    ///
    /// # Arguments
    ///
    /// * `concurrency` - The desired level of concurrency.
    ///
    /// # Returns
    ///
    /// An instance of `Pipeline` with the updated concurrency level.
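    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `loader` implements `Loader`:
    ///
    /// ```ignore
    /// // Process at most two nodes at a time, regardless of the CPU count.
    /// let pipeline = Pipeline::from_loader(loader).with_concurrency(2);
    /// ```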
    #[must_use]
    pub fn with_concurrency(mut self, concurrency: usize) -> Self {
        self.concurrency = concurrency;
        self
    }

    /// Sets the embed mode for the pipeline. The embed mode controls which (combination of)
    /// fields of a [`Node`] are embedded with a vector when transforming with
    /// [`crate::transformers::Embed`].
    ///
    /// See also [`swiftide_core::indexing::EmbedMode`].
    ///
    /// # Arguments
    ///
    /// * `embed_mode` - The desired embed mode.
    ///
    /// # Returns
    ///
    /// An instance of `Pipeline` with the updated embed mode.
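    ///
    /// # Example
    ///
    /// A minimal sketch; `EmbedMode::PerField` is one of the variants of
    /// [`swiftide_core::indexing::EmbedMode`]:
    ///
    /// ```ignore
    /// use swiftide_core::indexing::EmbedMode;
    ///
    /// // Embed each node field separately rather than as one combined text.
    /// let pipeline = Pipeline::from_loader(loader).with_embed_mode(EmbedMode::PerField);
    /// ```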
    #[must_use]
    pub fn with_embed_mode(mut self, embed_mode: EmbedMode) -> Self {
        self.stream = self
            .stream
            .map_ok(move |mut node| {
                node.embed_mode = embed_mode;
                node
            })
            .boxed()
            .into();
        self
    }

    /// Filters out cached nodes using the provided cache.
    ///
    /// # Arguments
    ///
    /// * `cache` - A cache that implements the `NodeCache` trait.
    ///
    /// # Returns
    ///
    /// An instance of `Pipeline` with the updated stream that filters out cached nodes.
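    ///
    /// # Example
    ///
    /// A minimal sketch; `Redis` stands in for any `NodeCache` implementation, such as the one in
    /// swiftide's Redis integration:
    ///
    /// ```ignore
    /// // Skip nodes that were already indexed in a previous run.
    /// let pipeline = Pipeline::from_loader(loader)
    ///     .filter_cached(Redis::try_from_url(redis_url, "my-index")?);
    /// ```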
    #[must_use]
    pub fn filter_cached(mut self, cache: impl NodeCache + 'static) -> Self {
        let cache = Arc::new(cache);
        self.stream = self
            .stream
            .try_filter_map(move |node| {
                let cache = Arc::clone(&cache);
                let span =
                    tracing::trace_span!("filter_cached", node_cache = ?cache, node = ?node);
                async move {
                    if cache.get(&node).await {
                        tracing::debug!(node = ?node, node_cache = cache.name(), "Node in cache, skipping");
                        Ok(None)
                    } else {
                        cache.set(&node).await;
                        tracing::debug!(node = ?node, node_cache = cache.name(), "Node not in cache, processing");
                        Ok(Some(node))
                    }
                }
                .instrument(span.or_current())
            })
            .boxed()
            .into();
        self
    }

    /// Adds a transformer to the pipeline.
    ///
    /// Closures can also be provided as transformers.
    ///
    /// # Arguments
    ///
    /// * `transformer` - A transformer that implements the `Transformer` trait.
    ///
    /// # Returns
    ///
    /// An instance of `Pipeline` with the updated stream that applies the transformer to each node.
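    ///
    /// # Example
    ///
    /// A minimal sketch using a closure transformer, as exercised in the tests below:
    ///
    /// ```ignore
    /// use swiftide_core::indexing::Node;
    ///
    /// // Uppercase the chunk of every node.
    /// let pipeline = Pipeline::from_loader(loader).then(|mut node: Node| {
    ///     node.chunk = node.chunk.to_uppercase();
    ///     Ok(node)
    /// });
    /// ```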
    #[must_use]
    pub fn then(
        mut self,
        mut transformer: impl Transformer + WithIndexingDefaults + 'static,
    ) -> Self {
        let concurrency = transformer.concurrency().unwrap_or(self.concurrency);

        transformer.with_indexing_defaults(self.indexing_defaults.clone());

        let transformer = Arc::new(transformer);
        self.stream = self
            .stream
            .map_ok(move |node| {
                let transformer = transformer.clone();
                let span = tracing::trace_span!("then", node = ?node);

                task::spawn(
                    async move {
                        tracing::debug!(node = ?node, transformer = transformer.name(), "Transforming node");
                        transformer.transform_node(node).await
                    }
                    .instrument(span.or_current()),
                )
                .err_into::<anyhow::Error>()
            })
            .try_buffer_unordered(concurrency)
            // Flatten the join result with the inner transform result
            .map(|x| x.and_then(|x| x))
            .boxed()
            .into();

        self
    }

    /// Adds a batch transformer to the pipeline.
    ///
    /// If the transformer has a batch size set, the batch size from the transformer is used;
    /// otherwise the pipeline default batch size ([`DEFAULT_BATCH_SIZE`]) applies.
    ///
    /// # Arguments
    ///
    /// * `transformer` - A transformer that implements the `BatchableTransformer` trait.
    ///
    /// # Returns
    ///
    /// An instance of `Pipeline` with the updated stream that applies the batch transformer to
    /// each batch of nodes.
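    ///
    /// # Example
    ///
    /// A minimal sketch using a closure, mirroring the batch-closure test below:
    ///
    /// ```ignore
    /// use swiftide_core::indexing::{IndexingStream, Node};
    ///
    /// // Transform each batch of nodes and stream the results back out.
    /// let pipeline = Pipeline::from_loader(loader).then_in_batch(|nodes: Vec<Node>| {
    ///     IndexingStream::iter(nodes.into_iter().map(|mut node| {
    ///         node.chunk = format!("batched: {}", node.chunk);
    ///         Ok(node)
    ///     }))
    /// });
    /// ```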
    #[must_use]
    pub fn then_in_batch(
        mut self,
        mut transformer: impl BatchableTransformer + WithBatchIndexingDefaults + 'static,
    ) -> Self {
        let concurrency = transformer.concurrency().unwrap_or(self.concurrency);

        transformer.with_indexing_defaults(self.indexing_defaults.clone());

        let transformer = Arc::new(transformer);
        self.stream = self
            .stream
            .try_chunks(transformer.batch_size().unwrap_or(self.batch_size))
            .map_ok(move |nodes| {
                let transformer = Arc::clone(&transformer);
                let span = tracing::trace_span!("then_in_batch", nodes = ?nodes);

                tokio::spawn(
                    async move {
                        tracing::debug!(
                            batch_transformer = transformer.name(),
                            num_nodes = nodes.len(),
                            "Batch transforming nodes"
                        );
                        transformer.batch_transform(nodes).await
                    }
                    .instrument(span.or_current()),
                )
                .map_err(anyhow::Error::from)
            })
            .err_into::<anyhow::Error>()
            .try_buffer_unordered(concurrency) // First get the streams from each future
            .try_flatten_unordered(None) // Then flatten all the streams back into one
            .boxed()
            .into();
        self
    }

    /// Adds a chunker transformer to the pipeline.
    ///
    /// # Arguments
    ///
    /// * `chunker` - A transformer that implements the `ChunkerTransformer` trait.
    ///
    /// # Returns
    ///
    /// An instance of `Pipeline` with the updated stream that applies the chunker transformer to
    /// each node.
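    ///
    /// # Example
    ///
    /// A minimal sketch; `ChunkCode` is assumed here as a `ChunkerTransformer`, like the one in
    /// `swiftide_indexing::transformers`:
    ///
    /// ```ignore
    /// use swiftide_indexing::transformers::ChunkCode;
    ///
    /// // Split each node into code chunks of 10 to 2048 bytes.
    /// let pipeline = Pipeline::from_loader(loader)
    ///     .then_chunk(ChunkCode::try_for_language_and_chunk_size("rust", 10..2048)?);
    /// ```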
    #[must_use]
    pub fn then_chunk(mut self, chunker: impl ChunkerTransformer + 'static) -> Self {
        let chunker = Arc::new(chunker);
        let concurrency = chunker.concurrency().unwrap_or(self.concurrency);
        self.stream = self
            .stream
            .map_ok(move |node| {
                let chunker = Arc::clone(&chunker);
                let span = tracing::trace_span!("then_chunk", chunker = ?chunker, node = ?node);

                tokio::spawn(
                    async move {
                        tracing::debug!(chunker = chunker.name(), "Chunking node");
                        chunker.transform_node(node).await
                    }
                    .instrument(span.or_current()),
                )
                .map_err(anyhow::Error::from)
            })
            .err_into::<anyhow::Error>()
            .try_buffer_unordered(concurrency)
            .try_flatten_unordered(None)
            .boxed()
            .into();

        self
    }

    /// Persists indexing nodes using the provided storage backend.
    ///
    /// # Arguments
    ///
    /// * `storage` - A storage backend that implements the `Persist` trait.
    ///
    /// # Returns
    ///
    /// An instance of `Pipeline` with the configured storage backend.
    ///
    /// # Panics
    ///
    /// Panics if the storage backend reports a batch size when checked but returns `None` when
    /// batch storing is invoked. The pipeline only performs batch storing when a batch size is
    /// set, so this should not occur in practice.
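    ///
    /// # Example
    ///
    /// A minimal sketch using the in-memory backend from the tests below:
    ///
    /// ```ignore
    /// use swiftide_indexing::persist::MemoryStorage;
    ///
    /// // Store processed nodes in memory; batch storing is used when the backend
    /// // reports a batch size.
    /// let pipeline = Pipeline::from_loader(loader).then_store_with(MemoryStorage::default());
    /// ```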
    #[must_use]
    pub fn then_store_with(mut self, storage: impl Persist + 'static) -> Self {
        let storage = Arc::new(storage);
        self.storage.push(storage.clone());
        // Add storage to the stream instead of doing it at the end
        if storage.batch_size().is_some() {
            self.stream = self
                .stream
                .try_chunks(storage.batch_size().unwrap())
                .map_ok(move |nodes| {
                    let storage = Arc::clone(&storage);
                    let span = tracing::trace_span!("then_store_with_batched", storage = ?storage, nodes = ?nodes);

                    tokio::spawn(
                        async move {
                            tracing::debug!(storage = storage.name(), num_nodes = nodes.len(), "Batch storing nodes");
                            storage.batch_store(nodes).await
                        }
                        .instrument(span.or_current()),
                    )
                    .map_err(anyhow::Error::from)
                })
                .err_into::<anyhow::Error>()
                .try_buffer_unordered(self.concurrency)
                .try_flatten_unordered(None)
                .boxed()
                .into();
        } else {
            self.stream = self
                .stream
                .map_ok(move |node| {
                    let storage = Arc::clone(&storage);
                    let span =
                        tracing::trace_span!("then_store_with", storage = ?storage, node = ?node);

                    tokio::spawn(
                        async move {
                            tracing::debug!(storage = storage.name(), "Storing node");

                            storage.store(node).await
                        }
                        .instrument(span.or_current()),
                    )
                    .err_into::<anyhow::Error>()
                })
                .try_buffer_unordered(self.concurrency)
                // Flatten the join result with the inner store result
                .map(|x| x.and_then(|x| x))
                .boxed()
                .into();
        }

        self
    }

    /// Splits the stream into two streams based on a predicate.
    ///
    /// Note that this is not lazy: it starts consuming the stream immediately,
    /// sending each item to the left or right stream based on the predicate.
    ///
    /// The resulting streams are buffered, but should be started as soon as possible.
    /// Their channels are bounded, and the parent stream will panic if sending fails.
    ///
    /// The resulting pipelines can be run concurrently, alternated between, or merged back
    /// together.
    ///
    /// # Panics
    ///
    /// Panics if the receiving pipelines' buffers are full or unavailable.
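    ///
    /// # Example
    ///
    /// A minimal sketch, following the split-and-merge test below:
    ///
    /// ```ignore
    /// let (code, other) = pipeline.split_by(|result| {
    ///     result
    ///         .as_ref()
    ///         .map(|node| node.chunk.contains("fn "))
    ///         .unwrap_or(false)
    /// });
    /// // Run `code` and `other` concurrently, or merge them back together.
    /// ```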
    #[must_use]
    pub fn split_by<P>(self, predicate: P) -> (Self, Self)
    where
        P: Fn(&Result<Node>) -> bool + Send + Sync + 'static,
    {
        let predicate = Arc::new(predicate);

        let (left_tx, left_rx) = mpsc::channel(1000);
        let (right_tx, right_rx) = mpsc::channel(1000);

        let stream = self.stream;
        let span = tracing::trace_span!("split_by");
        tokio::spawn(
            async move {
                stream
                    .for_each_concurrent(self.concurrency, move |item| {
                        let predicate = Arc::clone(&predicate);
                        let left_tx = left_tx.clone();
                        let right_tx = right_tx.clone();
                        async move {
                            if predicate(&item) {
                                tracing::debug!(?item, "Sending to left stream");
                                left_tx
                                    .send(item)
                                    .await
                                    .expect("Failed to send to left stream");
                            } else {
                                tracing::debug!(?item, "Sending to right stream");
                                right_tx
                                    .send(item)
                                    .await
                                    .expect("Failed to send to right stream");
                            }
                        }
                    })
                    .await;
            }
            .instrument(span.or_current()),
        );

        let left_pipeline = Self {
            stream: left_rx.into(),
            storage: self.storage.clone(),
            concurrency: self.concurrency,
            indexing_defaults: self.indexing_defaults.clone(),
            batch_size: self.batch_size,
        };

        let right_pipeline = Self {
            stream: right_rx.into(),
            storage: self.storage.clone(),
            concurrency: self.concurrency,
            indexing_defaults: self.indexing_defaults.clone(),
            batch_size: self.batch_size,
        };

        (left_pipeline, right_pipeline)
    }

    /// Merges two streams into one.
    ///
    /// This is useful for merging two streams that have been split using the `split_by` method.
    ///
    /// The full stream can then be processed using the `run` method.
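    ///
    /// # Example
    ///
    /// A minimal sketch, continuing the `split_by` example above:
    ///
    /// ```ignore
    /// let (left, right) = pipeline.split_by(|result| result.is_ok());
    /// // ... configure `left` and `right` separately ...
    /// left.merge(right).run().await?;
    /// ```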
    #[must_use]
    pub fn merge(self, other: Self) -> Self {
        let stream = tokio_stream::StreamExt::merge(self.stream, other.stream);

        Self {
            stream: stream.boxed().into(),
            ..self
        }
    }

    /// Throttles the stream of nodes, limiting the rate to one node per duration.
    ///
    /// Useful for rate limiting the indexing pipeline. Uses `tokio_stream::StreamExt::throttle`
    /// internally, which has a granularity of 1ms.
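    ///
    /// # Example
    ///
    /// ```ignore
    /// use std::time::Duration;
    ///
    /// // Process at most one node every 100 milliseconds.
    /// let pipeline = Pipeline::from_loader(loader).throttle(Duration::from_millis(100));
    /// ```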
    #[must_use]
    pub fn throttle(mut self, duration: impl Into<Duration>) -> Self {
        self.stream = tokio_stream::StreamExt::throttle(self.stream, duration.into())
            .boxed()
            .into();
        self
    }

    /// Silently filters out errors encountered by the pipeline.
    ///
    /// This method filters out errors encountered by the pipeline, preventing them from bubbling
    /// up and terminating the stream. Note that errors are not logged.
    #[must_use]
    pub fn filter_errors(mut self) -> Self {
        self.stream = self
            .stream
            .filter_map(|result| async {
                match result {
                    Ok(node) => Some(Ok(node)),
                    Err(_e) => None,
                }
            })
            .boxed()
            .into();
        self
    }

    /// Provide a closure to selectively filter nodes or errors.
    ///
    /// This allows you to skip specific errors or nodes, or do ad hoc inspection.
    ///
    /// If the closure returns true, the result is kept; otherwise it is skipped.
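    ///
    /// # Example
    ///
    /// A minimal sketch, mirroring the filter test below:
    ///
    /// ```ignore
    /// // Drop nodes whose chunk is marked to be skipped; keep errors so they surface later.
    /// let pipeline = pipeline.filter(|result| match result {
    ///     Ok(node) => node.chunk != "skip",
    ///     Err(_) => true,
    /// });
    /// ```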
    #[must_use]
    pub fn filter<F>(mut self, filter: F) -> Self
    where
        F: Fn(&Result<Node>) -> bool + Send + Sync + 'static,
    {
        self.stream = self
            .stream
            .filter(move |result| {
                let will_retain = filter(result);

                async move { will_retain }
            })
            .boxed()
            .into();
        self
    }

    /// Logs all results processed by the pipeline.
    ///
    /// Errors are logged at the `ERROR` level and nodes at the `DEBUG` level.
    #[must_use]
    pub fn log_all(self) -> Self {
        self.log_errors().log_nodes()
    }

    /// Logs all errors encountered by the pipeline.
    ///
    /// This method logs all errors encountered by the pipeline at the `ERROR` level.
    #[must_use]
    pub fn log_errors(mut self) -> Self {
        self.stream = self
            .stream
            .inspect_err(|e| tracing::error!("Error processing node: {:?}", e))
            .boxed()
            .into();
        self
    }

    /// Logs all nodes processed by the pipeline.
    ///
    /// This method logs all nodes processed by the pipeline at the `DEBUG` level.
    #[must_use]
    pub fn log_nodes(mut self) -> Self {
        self.stream = self
            .stream
            .inspect_ok(|node| tracing::debug!("Processed node: {:?}", node))
            .boxed()
            .into();
        self
    }

    /// Runs the indexing pipeline.
    ///
    /// This method processes the stream of nodes, applying all configured transformations and
    /// storing the results.
    ///
    /// # Returns
    ///
    /// A `Result` indicating the success or failure of the pipeline execution.
    ///
    /// # Errors
    ///
    /// Returns an error if no storage backend is configured or if any stage of the pipeline fails.
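    ///
    /// # Example
    ///
    /// A minimal sketch; at least one storage backend must be configured before running:
    ///
    /// ```ignore
    /// Pipeline::from_loader(loader)
    ///     .then_store_with(MemoryStorage::default())
    ///     .run()
    ///     .await?;
    /// ```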
    #[tracing::instrument(skip_all, fields(total_nodes), name = "indexing_pipeline.run")]
    pub async fn run(mut self) -> Result<()> {
        tracing::info!(
            "Starting indexing pipeline with {} concurrency",
            self.concurrency
        );
        let now = std::time::Instant::now();
        if self.storage.is_empty() {
            anyhow::bail!("No storage configured for indexing pipeline");
        }

        // Ensure all storage backends are set up before processing nodes
        let setup_futures = self
            .storage
            .into_iter()
            .map(|storage| async move { storage.setup().await })
            .collect::<Vec<_>>();
        futures_util::future::try_join_all(setup_futures).await?;

        let mut total_nodes = 0;
        while self.stream.try_next().await?.is_some() {
            total_nodes += 1;
        }

        let elapsed_in_seconds = now.elapsed().as_secs();
        tracing::warn!(
            elapsed_in_seconds,
            "Processed {} nodes in {} seconds",
            total_nodes,
            elapsed_in_seconds
        );
        tracing::Span::current().record("total_nodes", total_nodes);

        Ok(())
    }
}

#[cfg(test)]
mod tests {

    use super::*;
    use crate::persist::MemoryStorage;
    use mockall::Sequence;
    use swiftide_core::indexing::*;

    /// Tests a simple run of the indexing pipeline.
    #[test_log::test(tokio::test)]
    async fn test_simple_run() {
        let mut loader = MockLoader::new();
        let mut transformer = MockTransformer::new();
        let mut batch_transformer = MockBatchableTransformer::new();
        let mut chunker = MockChunkerTransformer::new();
        let mut storage = MockPersist::new();

        let mut seq = Sequence::new();

        loader
            .expect_into_stream()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| vec![Ok(Node::default())].into());

        transformer.expect_transform_node().returning(|mut node| {
            node.chunk = "transformed".to_string();
            Ok(node)
        });
        transformer.expect_concurrency().returning(|| None);
        transformer.expect_name().returning(|| "transformer");

        batch_transformer
            .expect_batch_transform()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|nodes| IndexingStream::iter(nodes.into_iter().map(Ok)));
        batch_transformer.expect_concurrency().returning(|| None);
        batch_transformer.expect_name().returning(|| "transformer");
        batch_transformer.expect_batch_size().returning(|| None);

        chunker
            .expect_transform_node()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|node| {
                let mut nodes = vec![];
                for i in 0..3 {
                    let mut node = node.clone();
                    node.chunk = format!("transformed_chunk_{i}");
                    nodes.push(Ok(node));
                }
                nodes.into()
            });
        chunker.expect_concurrency().returning(|| None);
        chunker.expect_name().returning(|| "chunker");

        storage.expect_setup().returning(|| Ok(()));
        storage.expect_batch_size().returning(|| None);
        storage
            .expect_store()
            .times(3)
            .in_sequence(&mut seq)
            .withf(|node| node.chunk.starts_with("transformed_chunk_"))
            .returning(Ok);
        storage.expect_name().returning(|| "storage");

        let pipeline = Pipeline::from_loader(loader)
            .then(transformer)
            .then_in_batch(batch_transformer)
            .then_chunk(chunker)
            .then_store_with(storage);

        pipeline.run().await.unwrap();
    }

    #[tokio::test]
    async fn test_skipping_errors() {
        let mut loader = MockLoader::new();
        let mut transformer = MockTransformer::new();
        let mut storage = MockPersist::new();
        let mut seq = Sequence::new();
        loader
            .expect_into_stream()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| vec![Ok(Node::default())].into());
        transformer
            .expect_transform_node()
            .returning(|_node| Err(anyhow::anyhow!("Error transforming node")));
        transformer.expect_concurrency().returning(|| None);
        storage.expect_setup().returning(|| Ok(()));
        storage.expect_batch_size().returning(|| None);
        storage.expect_store().times(0).returning(Ok);
        let pipeline = Pipeline::from_loader(loader)
            .then(transformer)
            .then_store_with(storage)
            .filter_errors();
        pipeline.run().await.unwrap();
    }

    #[tokio::test]
    async fn test_concurrent_calls_with_simple_transformer() {
        let mut loader = MockLoader::new();
        let mut transformer = MockTransformer::new();
        let mut storage = MockPersist::new();
        let mut seq = Sequence::new();
        loader
            .expect_into_stream()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| {
                vec![
                    Ok(Node::default()),
                    Ok(Node::default()),
                    Ok(Node::default()),
                ]
                .into()
            });
        transformer
            .expect_transform_node()
            .times(3)
            .in_sequence(&mut seq)
            .returning(|mut node| {
                node.chunk = "transformed".to_string();
                Ok(node)
            });
        transformer.expect_concurrency().returning(|| Some(3));
        transformer.expect_name().returning(|| "transformer");
        storage.expect_setup().returning(|| Ok(()));
        storage.expect_batch_size().returning(|| None);
        storage.expect_store().times(3).returning(Ok);
        storage.expect_name().returning(|| "storage");

        let pipeline = Pipeline::from_loader(loader)
            .then(transformer)
            .then_store_with(storage);
        pipeline.run().await.unwrap();
    }

    #[tokio::test]
    async fn test_arbitrary_closures_as_transformer() {
        let mut loader = MockLoader::new();
        let transformer = |mut node: Node| {
            node.chunk = "transformed".to_string();
            Ok(node)
        };
        let storage = MemoryStorage::default();
        let mut seq = Sequence::new();
        loader
            .expect_into_stream()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| vec![Ok(Node::default())].into());

        let pipeline = Pipeline::from_loader(loader)
            .then(transformer)
            .then_store_with(storage.clone());
        pipeline.run().await.unwrap();

        dbg!(storage.clone());
        let processed_node = storage.get("0").await.unwrap();
        assert_eq!(processed_node.chunk, "transformed");
    }

    #[tokio::test]
    async fn test_arbitrary_closures_as_batch_transformer() {
        let mut loader = MockLoader::new();
        let batch_transformer = |nodes: Vec<Node>| {
            IndexingStream::iter(nodes.into_iter().map(|mut node| {
                node.chunk = "transformed".to_string();
                Ok(node)
            }))
        };
        let storage = MemoryStorage::default();
        let mut seq = Sequence::new();
        loader
            .expect_into_stream()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| vec![Ok(Node::default())].into());

        let pipeline = Pipeline::from_loader(loader)
            .then_in_batch(batch_transformer)
            .then_store_with(storage.clone());
        pipeline.run().await.unwrap();

        dbg!(storage.clone());
        let processed_node = storage.get("0").await.unwrap();
        assert_eq!(processed_node.chunk, "transformed");
    }

    #[tokio::test]
    async fn test_filter_closure() {
        let mut loader = MockLoader::new();
        let storage = MemoryStorage::default();
        let mut seq = Sequence::new();
        loader
            .expect_into_stream()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| {
                vec![
                    Ok(Node::default()),
                    Ok(Node::new("skip")),
                    Ok(Node::default()),
                ]
                .into()
            });
        let pipeline = Pipeline::from_loader(loader)
            .filter(|result| {
                let node = result.as_ref().unwrap();
                node.chunk != "skip"
            })
            .then_store_with(storage.clone());
        pipeline.run().await.unwrap();
        let nodes = storage.get_all().await;
        assert_eq!(nodes.len(), 2);
    }

    #[test_log::test(tokio::test)]
    async fn test_split_and_merge() {
        let mut loader = MockLoader::new();
        let storage = MemoryStorage::default();
        let mut seq = Sequence::new();
        loader
            .expect_into_stream()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| {
                vec![
                    Ok(Node::default()),
                    Ok(Node::new("will go left")),
                    Ok(Node::default()),
                ]
                .into()
            });

        let pipeline = Pipeline::from_loader(loader);
        let (mut left, mut right) = pipeline.split_by(|node| {
            if let Ok(node) = node {
                node.chunk.starts_with("will go left")
            } else {
                false
            }
        });

        // change the chunk to 'left'
        left = left
            .then(move |mut node: Node| {
                node.chunk = "left".to_string();

                Ok(node)
            })
            .log_all();

        right = right.then(move |mut node: Node| {
            node.chunk = "right".to_string();
            Ok(node)
        });

        left.merge(right)
            .then_store_with(storage.clone())
            .run()
            .await
            .unwrap();
        dbg!(storage.clone());

        let all_nodes = storage.get_all_values().await;
        assert_eq!(
            all_nodes.iter().filter(|node| node.chunk == "left").count(),
            1
        );
        assert_eq!(
            all_nodes
                .iter()
                .filter(|node| node.chunk == "right")
                .count(),
            2
        );
    }

    #[tokio::test]
    async fn test_all_steps_should_work_as_dyn_box() {
        let mut loader = MockLoader::new();
        loader
            .expect_into_stream_boxed()
            .returning(|| vec![Ok(Node::default())].into());

        let mut transformer = MockTransformer::new();
        transformer.expect_transform_node().returning(Ok);
        transformer.expect_concurrency().returning(|| None);

        let mut batch_transformer = MockBatchableTransformer::new();
        batch_transformer
            .expect_batch_transform()
            .returning(std::convert::Into::into);
        batch_transformer.expect_concurrency().returning(|| None);
        let mut chunker = MockChunkerTransformer::new();
        chunker
            .expect_transform_node()
            .returning(|node| vec![node].into());
        chunker.expect_concurrency().returning(|| None);

        let mut storage = MockPersist::new();
        storage.expect_setup().returning(|| Ok(()));
        storage.expect_store().returning(Ok);
        storage.expect_batch_size().returning(|| None);

        let pipeline = Pipeline::from_loader(Box::new(loader) as Box<dyn Loader>)
            .then(Box::new(transformer) as Box<dyn Transformer>)
            .then_in_batch(Box::new(batch_transformer) as Box<dyn BatchableTransformer>)
            .then_chunk(Box::new(chunker) as Box<dyn ChunkerTransformer>)
            .then_store_with(Box::new(storage) as Box<dyn Persist>);
        pipeline.run().await.unwrap();
    }
}