swiftide_indexing/pipeline.rs

use anyhow::Result;
use futures_util::{StreamExt, TryFutureExt, TryStreamExt};
use swiftide_core::{
    BatchableTransformer, ChunkerTransformer, Loader, NodeCache, Persist, SimplePrompt,
    Transformer, WithBatchIndexingDefaults, WithIndexingDefaults, indexing::IndexingDefaults,
};
use tokio::{sync::mpsc, task};
use tracing::Instrument;

use std::{sync::Arc, time::Duration};

use swiftide_core::indexing::{EmbedMode, IndexingStream, Node};

macro_rules! trace_span {
    ($op:literal, $step:expr) => {
        tracing::trace_span!($op, "otel.name" = format!("{}.{}", $op, $step.name()),)
    };

    ($op:literal) => {
        tracing::trace_span!($op, "otel.name" = format!("{}", $op),)
    };
}

macro_rules! node_trace_log {
    ($step:expr, $node:expr, $msg:literal) => {
        tracing::trace!(
            node = ?$node,
            node_id = ?$node.id(),
            step = $step.name(),
            $msg
        )
    };
}

macro_rules! batch_node_trace_log {
    ($step:expr, $nodes:expr, $msg:literal) => {
        tracing::trace!(batch_size = $nodes.len(), nodes = ?$nodes, step = $step.name(), $msg)
    };
}

/// The default batch size for batch processing.
const DEFAULT_BATCH_SIZE: usize = 256;

/// A pipeline for indexing files, adding metadata, chunking, transforming, embedding, and then
/// storing them.
///
/// The `Pipeline` struct orchestrates the entire file indexing process. It is designed to be
/// flexible and performant, allowing for various stages of data transformation and storage to be
/// configured and executed asynchronously.
///
/// # Fields
///
/// * `stream` - The stream of `Node` items to be processed.
/// * `storage` - The storage backends where processed nodes will be persisted.
/// * `concurrency` - The level of concurrency for processing nodes.
pub struct Pipeline {
    stream: IndexingStream,
    storage: Vec<Arc<dyn Persist>>,
    concurrency: usize,
    indexing_defaults: IndexingDefaults,
    batch_size: usize,
}

impl Default for Pipeline {
    /// Creates a default `Pipeline` with an empty stream, no storage, and a concurrency level equal
    /// to the number of CPUs.
    fn default() -> Self {
        Self {
            stream: IndexingStream::empty(),
            storage: Vec::default(),
            concurrency: num_cpus::get(),
            indexing_defaults: IndexingDefaults::default(),
            batch_size: DEFAULT_BATCH_SIZE,
        }
    }
}

impl Pipeline {
    /// Creates a `Pipeline` from a given loader.
    ///
    /// # Arguments
    ///
    /// * `loader` - A loader that implements the `Loader` trait.
    ///
    /// # Returns
    ///
    /// An instance of `Pipeline` initialized with the provided loader.
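    ///
    /// # Example
    ///
    /// A minimal sketch; `FileLoader` is assumed to be available from this crate's `loaders`
    /// module, but any `Loader` implementation works:
    ///
    /// ```ignore
    /// use swiftide_indexing::{loaders::FileLoader, Pipeline};
    ///
    /// let pipeline = Pipeline::from_loader(FileLoader::new("./src").with_extensions(&["rs"]));
    /// ```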
    pub fn from_loader(loader: impl Loader + 'static) -> Self {
        let stream = loader.into_stream();
        Self {
            stream,
            ..Default::default()
        }
    }

    /// Sets the default LLM client to be used for LLM prompts for all transformers in the
    /// pipeline.
    #[must_use]
    pub fn with_default_llm_client(mut self, client: impl SimplePrompt + 'static) -> Self {
        self.indexing_defaults = IndexingDefaults::from_simple_prompt(Box::new(client));
        self
    }

    /// Creates a `Pipeline` from a given stream.
    ///
    /// # Arguments
    ///
    /// * `stream` - An `IndexingStream` containing the nodes to be processed.
    ///
    /// # Returns
    ///
    /// An instance of `Pipeline` initialized with the provided stream.
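    ///
    /// # Example
    ///
    /// Anything that converts into an [`IndexingStream`] can be used, for example a
    /// `Vec<Result<Node>>` as in the tests below:
    ///
    /// ```ignore
    /// use swiftide_core::indexing::Node;
    ///
    /// let pipeline = Pipeline::from_stream(vec![Ok(Node::new("hello"))]);
    /// ```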
    pub fn from_stream(stream: impl Into<IndexingStream>) -> Self {
        Self {
            stream: stream.into(),
            ..Default::default()
        }
    }

    /// Sets the concurrency level for the pipeline. By default the concurrency is set to the
    /// number of CPUs.
    ///
    /// # Arguments
    ///
    /// * `concurrency` - The desired level of concurrency.
    ///
    /// # Returns
    ///
    /// An instance of `Pipeline` with the updated concurrency level.
    #[must_use]
    pub fn with_concurrency(mut self, concurrency: usize) -> Self {
        self.concurrency = concurrency;
        self
    }

    /// Sets the embed mode for the pipeline. The embed mode controls which fields of a [`Node`],
    /// or which combination of them, are embedded with a vector when transforming with
    /// [`crate::transformers::Embed`].
    ///
    /// See also [`swiftide_core::indexing::EmbedMode`].
    ///
    /// # Arguments
    ///
    /// * `embed_mode` - The desired embed mode.
    ///
    /// # Returns
    ///
    /// An instance of `Pipeline` with the updated embed mode.
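    ///
    /// # Example
    ///
    /// A sketch assuming the [`EmbedMode`] variants exposed by `swiftide_core`, such as
    /// `PerField`:
    ///
    /// ```ignore
    /// use swiftide_core::indexing::EmbedMode;
    ///
    /// let pipeline = pipeline.with_embed_mode(EmbedMode::PerField);
    /// ```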
    #[must_use]
    pub fn with_embed_mode(mut self, embed_mode: EmbedMode) -> Self {
        self.stream = self
            .stream
            .map_ok(move |mut node| {
                node.embed_mode = embed_mode;
                node
            })
            .boxed()
            .into();
        self
    }

    /// Filters out cached nodes using the provided cache.
    ///
    /// # Arguments
    ///
    /// * `cache` - A cache that implements the `NodeCache` trait.
    ///
    /// # Returns
    ///
    /// An instance of `Pipeline` with the updated stream that filters out cached nodes.
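    ///
    /// # Example
    ///
    /// A sketch with a hypothetical `MyCache` implementing `NodeCache`; any implementation
    /// (for example a Redis-backed cache) is wired up the same way:
    ///
    /// ```ignore
    /// let pipeline = pipeline.filter_cached(MyCache::default());
    /// ```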
    #[must_use]
    pub fn filter_cached(mut self, cache: impl NodeCache + 'static) -> Self {
        let cache = Arc::new(cache);
        self.stream = self
            .stream
            .try_filter_map(move |node| {
                let cache = Arc::clone(&cache);
                let span = trace_span!("filter_cached", cache);

                async move {
                    if cache.get(&node).await {
                        node_trace_log!(cache, node, "node in cache, skipping");
                        Ok(None)
                    } else {
                        node_trace_log!(cache, node, "node not in cache, processing");
                        cache.set(&node).await;
                        Ok(Some(node))
                    }
                }
                .instrument(span.or_current())
            })
            .boxed()
            .into();
        self
    }

    /// Adds a transformer to the pipeline.
    ///
    /// Closures can also be provided as transformers.
    ///
    /// # Arguments
    ///
    /// * `transformer` - A transformer that implements the `Transformer` trait.
    ///
    /// # Returns
    ///
    /// An instance of `Pipeline` with the updated stream that applies the transformer to each node.
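    ///
    /// # Example
    ///
    /// A closure taking a [`Node`] and returning an `anyhow::Result<Node>` acts as a
    /// transformer, as in the tests below:
    ///
    /// ```ignore
    /// let pipeline = pipeline.then(|mut node: Node| {
    ///     node.chunk = node.chunk.to_uppercase();
    ///     Ok(node)
    /// });
    /// ```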
    #[must_use]
    pub fn then(
        mut self,
        mut transformer: impl Transformer + WithIndexingDefaults + 'static,
    ) -> Self {
        let concurrency = transformer.concurrency().unwrap_or(self.concurrency);

        transformer.with_indexing_defaults(self.indexing_defaults.clone());

        let transformer = Arc::new(transformer);
        self.stream = self
            .stream
            .map_ok(move |node| {
                let transformer = transformer.clone();
                let span = trace_span!("then", transformer);

                task::spawn(
                    async move {
                        node_trace_log!(transformer, node, "Transforming node");
                        transformer.transform_node(node).await
                    }
                    .instrument(span.or_current()),
                )
                .err_into::<anyhow::Error>()
            })
            .try_buffer_unordered(concurrency)
            // Flatten the join result with the transformer's own result
            .map(|x| x.and_then(|x| x))
            .boxed()
            .into();

        self
    }

    /// Adds a batch transformer to the pipeline.
    ///
    /// If the transformer has a batch size set, the batch size from the transformer is used;
    /// otherwise the pipeline's default batch size ([`DEFAULT_BATCH_SIZE`]) applies.
    ///
    /// # Arguments
    ///
    /// * `transformer` - A transformer that implements the `BatchableTransformer` trait.
    ///
    /// # Returns
    ///
    /// An instance of `Pipeline` with the updated stream that applies the batch transformer to each
    /// batch of nodes.
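    ///
    /// # Example
    ///
    /// A closure taking a `Vec<Node>` and returning an [`IndexingStream`] acts as a batch
    /// transformer, as in the tests below:
    ///
    /// ```ignore
    /// let pipeline = pipeline.then_in_batch(|nodes: Vec<Node>| {
    ///     IndexingStream::iter(nodes.into_iter().map(Ok))
    /// });
    /// ```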
    #[must_use]
    pub fn then_in_batch(
        mut self,
        mut transformer: impl BatchableTransformer + WithBatchIndexingDefaults + 'static,
    ) -> Self {
        let concurrency = transformer.concurrency().unwrap_or(self.concurrency);

        transformer.with_indexing_defaults(self.indexing_defaults.clone());

        let transformer = Arc::new(transformer);
        self.stream = self
            .stream
            .try_chunks(transformer.batch_size().unwrap_or(self.batch_size))
            .map_ok(move |nodes| {
                let transformer = Arc::clone(&transformer);
                let span = trace_span!("then_in_batch", transformer);

                tokio::spawn(
                    async move {
                        batch_node_trace_log!(transformer, nodes, "batch transforming nodes");
                        transformer.batch_transform(nodes).await
                    }
                    .instrument(span.or_current()),
                )
                .map_err(anyhow::Error::from)
            })
            .err_into::<anyhow::Error>()
            .try_buffer_unordered(concurrency) // First get the streams from each future
            .try_flatten_unordered(None) // Then flatten all the streams back into one
            .boxed()
            .into();
        self
    }

    /// Adds a chunker transformer to the pipeline.
    ///
    /// # Arguments
    ///
    /// * `chunker` - A transformer that implements the `ChunkerTransformer` trait.
    ///
    /// # Returns
    ///
    /// An instance of `Pipeline` with the updated stream that applies the chunker transformer to
    /// each node.
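    ///
    /// # Example
    ///
    /// A sketch assuming a code chunker such as `ChunkCode` from this crate's `transformers`
    /// module:
    ///
    /// ```ignore
    /// use swiftide_indexing::transformers::ChunkCode;
    ///
    /// let pipeline = pipeline.then_chunk(ChunkCode::try_for_language("rust")?);
    /// ```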
    #[must_use]
    pub fn then_chunk(mut self, chunker: impl ChunkerTransformer + 'static) -> Self {
        let chunker = Arc::new(chunker);
        let concurrency = chunker.concurrency().unwrap_or(self.concurrency);
        self.stream = self
            .stream
            .map_ok(move |node| {
                let chunker = Arc::clone(&chunker);
                let span = trace_span!("then_chunk", chunker);

                tokio::spawn(
                    async move {
                        node_trace_log!(chunker, node, "Chunking node");
                        chunker.transform_node(node).await
                    }
                    .instrument(span.or_current()),
                )
                .map_err(anyhow::Error::from)
            })
            .err_into::<anyhow::Error>()
            .try_buffer_unordered(concurrency)
            .try_flatten_unordered(None)
            .boxed()
            .into();

        self
    }

    /// Persists indexing nodes using the provided storage backend.
    ///
    /// # Arguments
    ///
    /// * `storage` - A storage backend that implements the `Persist` trait.
    ///
    /// # Returns
    ///
    /// An instance of `Pipeline` with the configured storage backend.
    ///
    /// # Panics
    ///
    /// Panics if batch storage is invoked while no batch size is set. The pipeline only invokes
    /// batch storing when the backend reports a batch size, so this should not happen in practice.
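    ///
    /// # Example
    ///
    /// A sketch using the in-memory backend from this crate's `persist` module, also used in
    /// the tests below:
    ///
    /// ```ignore
    /// use swiftide_indexing::persist::MemoryStorage;
    ///
    /// let pipeline = pipeline.then_store_with(MemoryStorage::default());
    /// ```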
    #[must_use]
    pub fn then_store_with(mut self, storage: impl Persist + 'static) -> Self {
        let storage = Arc::new(storage);
        self.storage.push(storage.clone());
        // add storage to the stream instead of doing it at the end
        if let Some(batch_size) = storage.batch_size() {
            self.stream = self
                .stream
                .try_chunks(batch_size)
                .map_ok(move |nodes| {
                    let storage = Arc::clone(&storage);
                    let span = trace_span!("then_store_with_batched", storage);

                    tokio::spawn(
                        async move {
                            batch_node_trace_log!(storage, nodes, "batch storing nodes");
                            storage.batch_store(nodes).await
                        }
                        .instrument(span.or_current()),
                    )
                    .map_err(anyhow::Error::from)
                })
                .err_into::<anyhow::Error>()
                .try_buffer_unordered(self.concurrency)
                .try_flatten_unordered(None)
                .boxed()
                .into();
        } else {
            self.stream = self
                .stream
                .map_ok(move |node| {
                    let storage = Arc::clone(&storage);
                    let span = trace_span!("then_store_with", storage);

                    tokio::spawn(
                        async move {
                            node_trace_log!(storage, node, "Storing node");

                            storage.store(node).await
                        }
                        .instrument(span.or_current()),
                    )
                    .err_into::<anyhow::Error>()
                })
                .try_buffer_unordered(self.concurrency)
                // Flatten the join result with the storage result
                .map(|x| x.and_then(|x| x))
                .boxed()
                .into();
        }

        self
    }

    /// Splits the stream into two streams based on a predicate.
    ///
    /// Note that this is not lazy. It will start consuming the stream immediately
    /// and send each item to the left or right stream based on the predicate.
    ///
    /// Both resulting streams are buffered, but should be consumed as soon as possible.
    /// The channels behind the resulting streams are bounded, and the parent stream will panic
    /// if sending fails.
    ///
    /// The two streams can be run concurrently, alternated between, or merged back together.
    ///
    /// # Panics
    ///
    /// Panics if a receiving pipeline's buffer is full or no longer available.
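    ///
    /// # Example
    ///
    /// A sketch mirroring the `test_split_and_merge` test below; the `"important"` prefix is
    /// an arbitrary example predicate:
    ///
    /// ```ignore
    /// let (left, right) = pipeline.split_by(|result| {
    ///     result
    ///         .as_ref()
    ///         .map(|node| node.chunk.starts_with("important"))
    ///         .unwrap_or(false)
    /// });
    ///
    /// left.merge(right)
    ///     .then_store_with(MemoryStorage::default())
    ///     .run()
    ///     .await?;
    /// ```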
    #[must_use]
    pub fn split_by<P>(self, predicate: P) -> (Self, Self)
    where
        P: Fn(&Result<Node>) -> bool + Send + Sync + 'static,
    {
        let predicate = Arc::new(predicate);

        let (left_tx, left_rx) = mpsc::channel(1000);
        let (right_tx, right_rx) = mpsc::channel(1000);

        let stream = self.stream;
        let span = trace_span!("split_by");
        tokio::spawn(
            async move {
                stream
                    .for_each_concurrent(self.concurrency, move |item| {
                        let predicate = Arc::clone(&predicate);
                        let left_tx = left_tx.clone();
                        let right_tx = right_tx.clone();
                        async move {
                            if predicate(&item) {
                                tracing::trace!(?item, "Sending to left stream");
                                left_tx
                                    .send(item)
                                    .await
                                    .expect("Failed to send to left stream");
                            } else {
                                tracing::trace!(?item, "Sending to right stream");
                                right_tx
                                    .send(item)
                                    .await
                                    .expect("Failed to send to right stream");
                            }
                        }
                    })
                    .await;
            }
            .instrument(span.or_current()),
        );

        let left_pipeline = Self {
            stream: left_rx.into(),
            storage: self.storage.clone(),
            concurrency: self.concurrency,
            indexing_defaults: self.indexing_defaults.clone(),
            batch_size: self.batch_size,
        };

        let right_pipeline = Self {
            stream: right_rx.into(),
            storage: self.storage.clone(),
            concurrency: self.concurrency,
            indexing_defaults: self.indexing_defaults.clone(),
            batch_size: self.batch_size,
        };

        (left_pipeline, right_pipeline)
    }

    /// Merges two streams into one
    ///
    /// This is useful for merging two streams that have been split using the `split_by` method.
    ///
    /// The full stream can then be processed using the `run` method.
    #[must_use]
    pub fn merge(self, other: Self) -> Self {
        let stream = tokio_stream::StreamExt::merge(self.stream, other.stream);

        Self {
            stream: stream.boxed().into(),
            ..self
        }
    }

    /// Throttles the stream of nodes, limiting the rate to 1 per duration.
    ///
    /// Useful for rate limiting the indexing pipeline. Uses `tokio_stream::StreamExt::throttle`
    /// internally, which has a granularity of 1ms.
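    ///
    /// # Example
    ///
    /// Limiting the pipeline to at most ten nodes per second:
    ///
    /// ```ignore
    /// use std::time::Duration;
    ///
    /// let pipeline = pipeline.throttle(Duration::from_millis(100));
    /// ```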
    #[must_use]
    pub fn throttle(mut self, duration: impl Into<Duration>) -> Self {
        self.stream = tokio_stream::StreamExt::throttle(self.stream, duration.into())
            .boxed()
            .into();
        self
    }

    /// Silently filters out errors encountered by the pipeline.
    ///
    /// This method filters out errors encountered by the pipeline, preventing them from bubbling up
    /// and terminating the stream. Note that errors are not logged.
    #[must_use]
    pub fn filter_errors(mut self) -> Self {
        self.stream = self
            .stream
            .filter_map(|result| async {
                match result {
                    Ok(node) => Some(Ok(node)),
                    Err(_e) => None,
                }
            })
            .boxed()
            .into();
        self
    }

    /// Provide a closure to selectively filter nodes or errors
    ///
    /// This allows you to skip specific errors or nodes, or do ad hoc inspection.
    ///
    /// If the closure returns true, the result is kept, otherwise it is skipped.
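    ///
    /// # Example
    ///
    /// Dropping empty nodes while letting errors through:
    ///
    /// ```ignore
    /// let pipeline = pipeline.filter(|result| match result {
    ///     Ok(node) => !node.chunk.is_empty(),
    ///     Err(_) => true,
    /// });
    /// ```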
    #[must_use]
    pub fn filter<F>(mut self, filter: F) -> Self
    where
        F: Fn(&Result<Node>) -> bool + Send + Sync + 'static,
    {
        self.stream = self
            .stream
            .filter(move |result| {
                let will_retain = filter(result);

                async move { will_retain }
            })
            .boxed()
            .into();
        self
    }

    /// Logs all results processed by the pipeline.
    ///
    /// This method logs errors at the `ERROR` level and processed nodes at the `DEBUG` level.
    #[must_use]
    pub fn log_all(self) -> Self {
        self.log_errors().log_nodes()
    }

    /// Logs all errors encountered by the pipeline.
    ///
    /// This method logs all errors encountered by the pipeline at the `ERROR` level.
    #[must_use]
    pub fn log_errors(mut self) -> Self {
        self.stream = self
            .stream
            .inspect_err(|e| tracing::error!(?e, "Error processing node"))
            .boxed()
            .into();
        self
    }

    /// Logs all nodes processed by the pipeline.
    ///
    /// This method logs all nodes processed by the pipeline at the `DEBUG` level.
    #[must_use]
    pub fn log_nodes(mut self) -> Self {
        self.stream = self
            .stream
            .inspect_ok(|node| tracing::debug!(?node, "Processed node"))
            .boxed()
            .into();
        self
    }

    /// Runs the indexing pipeline.
    ///
    /// This method processes the stream of nodes, applying all configured transformations and
    /// storing the results.
    ///
    /// # Returns
    ///
    /// A `Result` indicating the success or failure of the pipeline execution.
    ///
    /// # Errors
    ///
    /// Returns an error if no storage backend is configured or if any stage of the pipeline fails.
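    ///
    /// # Example
    ///
    /// A minimal end-to-end sketch, assuming `FileLoader` and `MemoryStorage` from this
    /// crate's `loaders` and `persist` modules:
    ///
    /// ```ignore
    /// Pipeline::from_loader(FileLoader::new("./src"))
    ///     .then(|mut node: Node| {
    ///         node.chunk = node.chunk.to_lowercase();
    ///         Ok(node)
    ///     })
    ///     .then_store_with(MemoryStorage::default())
    ///     .run()
    ///     .await?;
    /// ```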
    #[tracing::instrument(skip_all, fields(total_nodes), name = "indexing_pipeline.run")]
    pub async fn run(mut self) -> Result<()> {
        tracing::info!(
            "Starting indexing pipeline with {} concurrency",
            self.concurrency
        );
        let now = std::time::Instant::now();
        if self.storage.is_empty() {
            anyhow::bail!("No storage configured for indexing pipeline");
        }

        // Ensure all storage backends are set up before processing nodes
        let setup_futures = self
            .storage
            .into_iter()
            .map(|storage| async move { storage.setup().await })
            .collect::<Vec<_>>();
        futures_util::future::try_join_all(setup_futures).await?;

        let mut total_nodes = 0;
        while self.stream.try_next().await?.is_some() {
            total_nodes += 1;
        }

        let elapsed_in_seconds = now.elapsed().as_secs();
        tracing::info!(
            elapsed_in_seconds,
            "Processed {} nodes in {} seconds",
            total_nodes,
            elapsed_in_seconds
        );
        tracing::Span::current().record("total_nodes", total_nodes);

        Ok(())
    }
}

#[cfg(test)]
mod tests {

    use super::*;
    use crate::persist::MemoryStorage;
    use mockall::Sequence;
    use swiftide_core::indexing::*;

    /// Tests a simple run of the indexing pipeline.
    #[test_log::test(tokio::test)]
    async fn test_simple_run() {
        let mut loader = MockLoader::new();
        let mut transformer = MockTransformer::new();
        let mut batch_transformer = MockBatchableTransformer::new();
        let mut chunker = MockChunkerTransformer::new();
        let mut storage = MockPersist::new();

        let mut seq = Sequence::new();

        loader
            .expect_into_stream()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| vec![Ok(Node::default())].into());

        transformer.expect_transform_node().returning(|mut node| {
            node.chunk = "transformed".to_string();
            Ok(node)
        });
        transformer.expect_concurrency().returning(|| None);
        transformer.expect_name().returning(|| "transformer");

        batch_transformer
            .expect_batch_transform()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|nodes| IndexingStream::iter(nodes.into_iter().map(Ok)));
        batch_transformer.expect_concurrency().returning(|| None);
        batch_transformer.expect_name().returning(|| "transformer");
        batch_transformer.expect_batch_size().returning(|| None);

        chunker
            .expect_transform_node()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|node| {
                let mut nodes = vec![];
                for i in 0..3 {
                    let mut node = node.clone();
                    node.chunk = format!("transformed_chunk_{i}");
                    nodes.push(Ok(node));
                }
                nodes.into()
            });
        chunker.expect_concurrency().returning(|| None);
        chunker.expect_name().returning(|| "chunker");

        storage.expect_setup().returning(|| Ok(()));
        storage.expect_batch_size().returning(|| None);
        storage
            .expect_store()
            .times(3)
            .in_sequence(&mut seq)
            .withf(|node| node.chunk.starts_with("transformed_chunk_"))
            .returning(Ok);
        storage.expect_name().returning(|| "storage");

        let pipeline = Pipeline::from_loader(loader)
            .then(transformer)
            .then_in_batch(batch_transformer)
            .then_chunk(chunker)
            .then_store_with(storage);

        pipeline.run().await.unwrap();
    }

    #[tokio::test]
    async fn test_skipping_errors() {
        let mut loader = MockLoader::new();
        let mut transformer = MockTransformer::new();
        let mut storage = MockPersist::new();
        let mut seq = Sequence::new();
        loader
            .expect_into_stream()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| vec![Ok(Node::default())].into());
        transformer
            .expect_transform_node()
            .returning(|_node| Err(anyhow::anyhow!("Error transforming node")));
        transformer.expect_concurrency().returning(|| None);
        transformer.expect_name().returning(|| "mock");
        storage.expect_setup().returning(|| Ok(()));
        storage.expect_batch_size().returning(|| None);
        storage.expect_store().times(0).returning(Ok);
        let pipeline = Pipeline::from_loader(loader)
            .then(transformer)
            .then_store_with(storage)
            .filter_errors();
        pipeline.run().await.unwrap();
    }

    #[tokio::test]
    async fn test_concurrent_calls_with_simple_transformer() {
        let mut loader = MockLoader::new();
        let mut transformer = MockTransformer::new();
        let mut storage = MockPersist::new();
        let mut seq = Sequence::new();
        loader
            .expect_into_stream()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| {
                vec![
                    Ok(Node::default()),
                    Ok(Node::default()),
                    Ok(Node::default()),
                ]
                .into()
            });
        transformer
            .expect_transform_node()
            .times(3)
            .in_sequence(&mut seq)
            .returning(|mut node| {
                node.chunk = "transformed".to_string();
                Ok(node)
            });
        transformer.expect_concurrency().returning(|| Some(3));
        transformer.expect_name().returning(|| "transformer");
        storage.expect_setup().returning(|| Ok(()));
        storage.expect_batch_size().returning(|| None);
        storage.expect_store().times(3).returning(Ok);
        storage.expect_name().returning(|| "storage");

        let pipeline = Pipeline::from_loader(loader)
            .then(transformer)
            .then_store_with(storage);
        pipeline.run().await.unwrap();
    }

    #[tokio::test]
    async fn test_arbitrary_closures_as_transformer() {
        let mut loader = MockLoader::new();
        let transformer = |node: Node| {
            let mut node = node;
            node.chunk = "transformed".to_string();
            Ok(node)
        };
        let storage = MemoryStorage::default();
        let mut seq = Sequence::new();
        loader
            .expect_into_stream()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| vec![Ok(Node::default())].into());

        let pipeline = Pipeline::from_loader(loader)
            .then(transformer)
            .then_store_with(storage.clone());
        pipeline.run().await.unwrap();

        dbg!(storage.clone());
        let processed_node = storage.get("0").await.unwrap();
        assert_eq!(processed_node.chunk, "transformed");
    }

    #[tokio::test]
    async fn test_arbitrary_closures_as_batch_transformer() {
        let mut loader = MockLoader::new();
        let batch_transformer = |nodes: Vec<Node>| {
            IndexingStream::iter(nodes.into_iter().map(|mut node| {
                node.chunk = "transformed".to_string();
                Ok(node)
            }))
        };
        let storage = MemoryStorage::default();
        let mut seq = Sequence::new();
        loader
            .expect_into_stream()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| vec![Ok(Node::default())].into());

        let pipeline = Pipeline::from_loader(loader)
            .then_in_batch(batch_transformer)
            .then_store_with(storage.clone());
        pipeline.run().await.unwrap();

        dbg!(storage.clone());
        let processed_node = storage.get("0").await.unwrap();
        assert_eq!(processed_node.chunk, "transformed");
    }

    #[tokio::test]
    async fn test_filter_closure() {
        let mut loader = MockLoader::new();
        let storage = MemoryStorage::default();
        let mut seq = Sequence::new();
        loader
            .expect_into_stream()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| {
                vec![
                    Ok(Node::default()),
                    Ok(Node::new("skip")),
                    Ok(Node::default()),
                ]
                .into()
            });
        let pipeline = Pipeline::from_loader(loader)
            .filter(|result| {
                let node = result.as_ref().unwrap();
                node.chunk != "skip"
            })
            .then_store_with(storage.clone());
        pipeline.run().await.unwrap();
        let nodes = storage.get_all().await;
        assert_eq!(nodes.len(), 2);
    }

    #[test_log::test(tokio::test)]
    async fn test_split_and_merge() {
        let mut loader = MockLoader::new();
        let storage = MemoryStorage::default();
        let mut seq = Sequence::new();
        loader
            .expect_into_stream()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| {
                vec![
                    Ok(Node::default()),
                    Ok(Node::new("will go left")),
                    Ok(Node::default()),
                ]
                .into()
            });

        let pipeline = Pipeline::from_loader(loader);
        let (mut left, mut right) = pipeline.split_by(|node| {
            if let Ok(node) = node {
                node.chunk.starts_with("will go left")
            } else {
                false
            }
        });

        // change the chunk to 'left'
        left = left
            .then(move |mut node: Node| {
                node.chunk = "left".to_string();

                Ok(node)
            })
            .log_all();

        right = right.then(move |mut node: Node| {
            node.chunk = "right".to_string();
            Ok(node)
        });

        left.merge(right)
            .then_store_with(storage.clone())
            .run()
            .await
            .unwrap();
        dbg!(storage.clone());

        let all_nodes = storage.get_all_values().await;
        assert_eq!(
            all_nodes.iter().filter(|node| node.chunk == "left").count(),
            1
        );
        assert_eq!(
            all_nodes
                .iter()
                .filter(|node| node.chunk == "right")
                .count(),
            2
        );
    }

    #[tokio::test]
    async fn test_all_steps_should_work_as_dyn_box() {
        let mut loader = MockLoader::new();
        loader
            .expect_into_stream_boxed()
            .returning(|| vec![Ok(Node::default())].into());

        let mut transformer = MockTransformer::new();
        transformer.expect_transform_node().returning(Ok);
        transformer.expect_concurrency().returning(|| None);
        transformer.expect_name().returning(|| "mock");

        let mut batch_transformer = MockBatchableTransformer::new();
        batch_transformer
            .expect_batch_transform()
            .returning(std::convert::Into::into);
        batch_transformer.expect_concurrency().returning(|| None);
        batch_transformer.expect_name().returning(|| "mock");
        let mut chunker = MockChunkerTransformer::new();
        chunker
            .expect_transform_node()
            .returning(|node| vec![node].into());
        chunker.expect_concurrency().returning(|| None);
        chunker.expect_name().returning(|| "mock");

        let mut storage = MockPersist::new();
        storage.expect_setup().returning(|| Ok(()));
        storage.expect_store().returning(Ok);
        storage.expect_batch_size().returning(|| None);
        storage.expect_name().returning(|| "mock");

        let pipeline = Pipeline::from_loader(Box::new(loader) as Box<dyn Loader>)
            .then(Box::new(transformer) as Box<dyn Transformer>)
            .then_in_batch(Box::new(batch_transformer) as Box<dyn BatchableTransformer>)
            .then_chunk(Box::new(chunker) as Box<dyn ChunkerTransformer>)
            .then_store_with(Box::new(storage) as Box<dyn Persist>);
        pipeline.run().await.unwrap();
    }
}