swiftide_indexing/
pipeline.rs

1use anyhow::Result;
2use futures_util::{StreamExt, TryFutureExt, TryStreamExt};
3use swiftide_core::{
4    BatchableTransformer, ChunkerTransformer, Loader, NodeCache, Persist, SimplePrompt,
5    Transformer, WithBatchIndexingDefaults, WithIndexingDefaults,
6    indexing::{Chunk, IndexingDefaults},
7};
8use tokio::{
9    sync::{Mutex, mpsc},
10    task,
11};
12use tracing::Instrument;
13
14use std::{pin::Pin, sync::Arc, time::Duration};
15
16use swiftide_core::indexing::{EmbedMode, IndexingStream, Node};
17
/// Creates a `TRACE`-level span whose name is the literal `$operation`.
///
/// When a pipeline step is passed as the second argument, its `name()` is appended to the
/// operation to form the `otel.name` field (`"<operation>.<step name>"`); with a single
/// argument, `otel.name` is just the operation itself.
macro_rules! trace_span {
    ($operation:literal, $stage:expr) => {
        tracing::trace_span!(
            $operation,
            "otel.name" = format!("{op}.{name}", op = $operation, name = $stage.name()),
        )
    };

    ($operation:literal) => {
        tracing::trace_span!($operation, "otel.name" = $operation.to_string(),)
    };
}
27
/// Emits a `TRACE` event about a single node.
///
/// The event records the node (`Debug`), its `id()` (`Debug`) and the `name()` of the pipeline
/// step handling it, followed by the literal message.
macro_rules! node_trace_log {
    ($stage:expr, $item:expr, $message:literal) => {
        tracing::trace!(
            node = ?$item,
            node_id = ?$item.id(),
            step = $stage.name(),
            $message
        )
    };
}
38
/// Emits a `TRACE` event about a batch of nodes.
///
/// The event records the batch length, the batch itself (`Debug`) and the `name()` of the
/// pipeline step handling it, followed by the literal message.
macro_rules! batch_node_trace_log {
    ($stage:expr, $batch:expr, $message:literal) => {
        tracing::trace!(
            batch_size = $batch.len(),
            nodes = ?$batch,
            step = $stage.name(),
            $message
        )
    };
}
44
/// Builds a new `Pipeline` around `$new_stream`, copying every other setting from `$source`.
///
/// The settings copied are the storage setup callbacks, concurrency, indexing defaults and
/// batch size. Fields are accessed one by one rather than through a borrow of the whole value,
/// so the macro still works after `$source.stream` has been moved out.
macro_rules! pipeline_with_new_stream {
    ($source:expr, $new_stream:expr) => {
        Pipeline {
            concurrency: $source.concurrency,
            batch_size: $source.batch_size,
            indexing_defaults: $source.indexing_defaults.clone(),
            storage_setup_fns: $source.storage_setup_fns.clone(),
            stream: $new_stream.into(),
        }
    };
}
56
57/// The default batch size for batch processing.
58const DEFAULT_BATCH_SIZE: usize = 256;
59
60/// A pipeline for indexing files, adding metadata, chunking, transforming, embedding, and then
61/// storing them.
62///
63/// The `Pipeline` struct orchestrates the entire file indexing process. It is designed to be
64/// flexible and performant, allowing for various stages of data transformation and storage to be
65/// configured and executed asynchronously.
66///
67/// # Fields
68///
69/// * `stream` - The stream of `Node` items to be processed.
70/// * `storage` - Optional storage backend where the processed nodes will be stored.
71/// * `concurrency` - The level of concurrency for processing nodes.
72pub struct Pipeline<T: Chunk> {
73    stream: IndexingStream<T>,
74    // storage: Vec<Arc<dyn Persist<Input = T, Output = T>>>,
75    storage_setup_fns: Vec<DynStorageSetupFn>,
76    concurrency: usize,
77    indexing_defaults: IndexingDefaults,
78    batch_size: usize,
79}
80
81type DynStorageSetupFn =
82    Arc<dyn Fn() -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync>;
83
84impl<T: Chunk> Default for Pipeline<T> {
85    /// Creates a default `Pipeline` with an empty stream, no storage, and a concurrency level equal
86    /// to the number of CPUs.
87    fn default() -> Self {
88        Self {
89            stream: IndexingStream::<T>::empty(),
90            storage_setup_fns: Vec::new(),
91            concurrency: num_cpus::get(),
92            indexing_defaults: IndexingDefaults::default(),
93            batch_size: DEFAULT_BATCH_SIZE,
94        }
95    }
96}
97
98impl<T: Chunk> Pipeline<T> {
99    /// Creates a `Pipeline` from a given loader.
100    ///
101    /// # Arguments
102    ///
103    /// * `loader` - A loader that implements the `Loader` trait.
104    ///
105    /// # Returns
106    ///
107    /// An instance of `Pipeline` initialized with the provided loader.
108    pub fn from_loader(loader: impl Loader<Output = T> + 'static) -> Self {
109        let stream = loader.into_stream();
110        Self {
111            stream,
112            ..Default::default()
113        }
114    }
115
116    /// Sets the default LLM client to be used for LLM prompts for all transformers in the
117    /// pipeline.
118    #[must_use]
119    pub fn with_default_llm_client(mut self, client: impl SimplePrompt + 'static) -> Self {
120        self.indexing_defaults = IndexingDefaults::from_simple_prompt(Box::new(client));
121        self
122    }
123
124    /// Creates a `Pipeline` from a given stream.
125    ///
126    /// # Arguments
127    ///
128    /// * `stream` - An `IndexingStream` containing the nodes to be processed.
129    ///
130    /// # Returns
131    ///
132    /// An instance of `Pipeline` initialized with the provided stream.
133    pub fn from_stream(stream: impl Into<IndexingStream<T>>) -> Self {
134        Self {
135            stream: stream.into(),
136            ..Default::default()
137        }
138    }
139
140    /// Sets the concurrency level for the pipeline. By default the concurrency is set to the
141    /// number of cpus.
142    ///
143    /// # Arguments
144    ///
145    /// * `concurrency` - The desired level of concurrency.
146    ///
147    /// # Returns
148    ///
149    /// An instance of `Pipeline` with the updated concurrency level.
150    #[must_use]
151    pub fn with_concurrency(mut self, concurrency: usize) -> Self {
152        self.concurrency = concurrency;
153        self
154    }
155
156    /// Sets the embed mode for the pipeline. The embed mode controls what (combination) fields of a
157    /// [`Node`] be embedded with a vector when transforming with [`crate::transformers::Embed`]
158    ///
159    /// See also [`swiftide_core::indexing::EmbedMode`].
160    ///
161    /// # Arguments
162    ///
163    /// * `embed_mode` - The desired embed mode.
164    ///
165    /// # Returns
166    ///
167    /// An instance of `Pipeline` with the updated embed mode.
168    #[must_use]
169    pub fn with_embed_mode(mut self, embed_mode: EmbedMode) -> Self {
170        self.stream = self
171            .stream
172            .map_ok(move |mut node| {
173                node.embed_mode = embed_mode;
174                node
175            })
176            .boxed()
177            .into();
178        self
179    }
180
181    /// Filters out cached nodes using the provided cache.
182    ///
183    /// # Arguments
184    ///
185    /// * `cache` - A cache that implements the `NodeCache` trait.
186    ///
187    /// # Returns
188    ///
189    /// An instance of `Pipeline` with the updated stream that filters out cached nodes.
190    #[must_use]
191    pub fn filter_cached(mut self, cache: impl NodeCache<Input = T> + 'static) -> Self {
192        let cache = Arc::new(cache);
193        self.stream = self
194            .stream
195            .try_filter_map(move |node| {
196                let cache = Arc::clone(&cache);
197                let span = trace_span!("filter_cached", cache);
198
199                async move {
200                    if cache.get(&node).await {
201                        node_trace_log!(cache, node, "node in cache, skipping");
202                        Ok(None)
203                    } else {
204                        node_trace_log!(cache, node, "node not in cache, processing");
205                        cache.set(&node).await;
206                        Ok(Some(node))
207                    }
208                }
209                .instrument(span.or_current())
210            })
211            .boxed()
212            .into();
213        self
214    }
215
216    /// Adds a transformer to the pipeline.
217    ///
218    /// Closures can also be provided as transformers.
219    ///
220    /// # Arguments
221    ///
222    /// * `transformer` - A transformer that implements the `Transformer` trait.
223    ///
224    /// # Returns
225    ///
226    /// An instance of `Pipeline` with the updated stream that applies the transformer to each node.
227    #[must_use]
228    pub fn then<Output: Chunk>(
229        self,
230        mut transformer: impl Transformer<Input = T, Output = Output> + WithIndexingDefaults + 'static,
231    ) -> Pipeline<Output> {
232        let concurrency = transformer.concurrency().unwrap_or(self.concurrency);
233
234        transformer.with_indexing_defaults(self.indexing_defaults.clone());
235
236        let transformer = Arc::new(transformer);
237        let stream = self
238            .stream
239            .map_ok(move |node| {
240                let transformer = transformer.clone();
241                let span = trace_span!("then", transformer);
242
243                task::spawn(
244                    async move {
245                        node_trace_log!(transformer, node, "Transforming node");
246                        transformer.transform_node(node).await
247                    }
248                    .instrument(span.or_current()),
249                )
250                .err_into::<anyhow::Error>()
251            })
252            .try_buffer_unordered(concurrency)
253            .map(|x| x.and_then(|x| x));
254
255        pipeline_with_new_stream!(self, stream.boxed())
256    }
257
258    /// Adds a batch transformer to the pipeline.
259    ///
260    /// If the transformer has a batch size set, the batch size from the transformer is used,
261    /// otherwise the pipeline default batch size ([`DEFAULT_BATCH_SIZE`]).
262    ///
263    /// # Arguments
264    ///
265    /// * `transformer` - A transformer that implements the `BatchableTransformer` trait.
266    ///
267    /// # Returns
268    ///
269    /// An instance of `Pipeline` with the updated stream that applies the batch transformer to each
270    /// batch of nodes.
271    #[must_use]
272    pub fn then_in_batch<Output: Chunk>(
273        self,
274        mut transformer: impl BatchableTransformer<Input = T, Output = Output>
275        + WithBatchIndexingDefaults
276        + 'static,
277    ) -> Pipeline<Output> {
278        let concurrency = transformer.concurrency().unwrap_or(self.concurrency);
279
280        transformer.with_indexing_defaults(self.indexing_defaults.clone());
281
282        let transformer = Arc::new(transformer);
283        let stream = self
284            .stream
285            .try_chunks(transformer.batch_size().unwrap_or(self.batch_size))
286            .map_ok(move |nodes| {
287                let transformer = Arc::clone(&transformer);
288                let span = trace_span!("then_in_batch", transformer);
289
290                tokio::spawn(
291                    async move {
292                        batch_node_trace_log!(transformer, nodes, "batch transforming nodes");
293                        transformer.batch_transform(nodes).await
294                    }
295                    .instrument(span.or_current()),
296                )
297                .map_err(anyhow::Error::from)
298            })
299            .err_into::<anyhow::Error>()
300            .try_buffer_unordered(concurrency) // First get the streams from each future
301            .try_flatten_unordered(None) // Then flatten the streams into a single stream
302            .boxed();
303
304        pipeline_with_new_stream!(self, stream)
305    }
306
307    /// Adds a chunker transformer to the pipeline.
308    ///
309    /// # Arguments
310    ///
311    /// * `chunker` - A transformer that implements the `ChunkerTransformer` trait.
312    ///
313    /// # Returns
314    ///
315    /// An instance of `Pipeline` with the updated stream that applies the chunker transformer to
316    /// each node.
317    #[must_use]
318    pub fn then_chunk<Output: Chunk>(
319        self,
320        chunker: impl ChunkerTransformer<Input = T, Output = Output> + 'static,
321    ) -> Pipeline<Output> {
322        let chunker = Arc::new(chunker);
323        let concurrency = chunker.concurrency().unwrap_or(self.concurrency);
324        let stream = self
325            .stream
326            .map_ok(move |node| {
327                let chunker = Arc::clone(&chunker);
328                let span = trace_span!("then_chunk", chunker);
329
330                tokio::spawn(
331                    async move {
332                        node_trace_log!(chunker, node, "Chunking node");
333                        chunker.transform_node(node).await
334                    }
335                    .instrument(span.or_current()),
336                )
337                .map_err(anyhow::Error::from)
338            })
339            .err_into::<anyhow::Error>()
340            .try_buffer_unordered(concurrency)
341            .try_flatten_unordered(None);
342
343        pipeline_with_new_stream!(self, stream.boxed())
344    }
345
346    /// Persists indexing nodes using the provided storage backend.
347    ///
348    /// # Arguments
349    ///
350    /// * `storage` - A storage backend that implements the `Storage` trait.
351    ///
352    /// # Returns
353    ///
354    /// An instance of `Pipeline` with the configured storage backend.
355    ///
356    /// # Panics
357    ///
358    /// Panics if batch size turns out to be not set and batch storage is still invoked.
359    /// Pipeline only invokes batch storing if the batch size is set, so should be alright.
360    #[must_use]
361    pub fn then_store_with<Output: Chunk>(
362        mut self,
363        storage: impl Persist<Input = T, Output = Output> + 'static,
364    ) -> Pipeline<Output> {
365        let storage = Arc::new(storage);
366
367        let storage_closure = storage.clone();
368
369        // Ensure we run the setup function only once.
370        let completed = Arc::new(Mutex::new(false));
371        let setup_fn: DynStorageSetupFn = Arc::new(move || {
372            let completed = Arc::clone(&completed);
373            let storage_closure = Arc::clone(&storage_closure);
374            Box::pin(async move {
375                let mut lock = completed.lock().await;
376
377                tracing::trace!(?storage_closure, "Setting up storage");
378                storage_closure.setup().await?;
379                *lock = true;
380                Ok(())
381            })
382        });
383        self.storage_setup_fns.push(setup_fn);
384
385        // add storage to the stream instead of doing it at the end
386        let stream = if storage.batch_size().is_some() {
387            self.stream
388                .try_chunks(storage.batch_size().unwrap())
389                .map_ok(move |nodes| {
390                    let storage = Arc::clone(&storage);
391                    let span = trace_span!("then_store_with_batched", storage);
392
393                    tokio::spawn(
394                        async move {
395                            batch_node_trace_log!(storage, nodes, "batch storing nodes");
396                            storage.batch_store(nodes).await
397                        }
398                        .instrument(span.or_current()),
399                    )
400                    .map_err(anyhow::Error::from)
401                })
402                .err_into::<anyhow::Error>()
403                .try_buffer_unordered(self.concurrency)
404                .try_flatten_unordered(None)
405                .boxed()
406        } else {
407            self.stream
408                .map_ok(move |node| {
409                    let storage = Arc::clone(&storage);
410                    let span = trace_span!("then_store_with", storage);
411
412                    tokio::spawn(
413                        async move {
414                            node_trace_log!(storage, node, "Storing node");
415
416                            storage.store(node).await
417                        }
418                        .instrument(span.or_current()),
419                    )
420                    .err_into::<anyhow::Error>()
421                })
422                .try_buffer_unordered(self.concurrency)
423                .map(|x| x.and_then(|x| x))
424                .boxed()
425        };
426
427        pipeline_with_new_stream!(self, stream)
428    }
429
430    /// Splits the stream into two streams based on a predicate.
431    ///
432    /// Note that this is not lazy. It will start consuming the stream immediately
433    /// and send each item to the left or right stream based on the predicate.
434    ///
435    /// The other streams have a buffer, but should be started as soon as possible.
436    /// The channels of the resulting streams are bounded and the parent stream will panic
437    /// if sending fails.
438    ///
439    /// They can either be run concurrently, alternated between or merged back together.
440    ///
441    /// # Panics
442    ///
443    /// Panics if the receiving pipelines buffers are full or unavailable.
444    #[must_use]
445    pub fn split_by<P>(self, predicate: P) -> (Self, Self)
446    where
447        P: Fn(&Result<Node<T>>) -> bool + Send + Sync + 'static,
448    {
449        let predicate = Arc::new(predicate);
450
451        let (left_tx, left_rx) = mpsc::channel(1000);
452        let (right_tx, right_rx) = mpsc::channel(1000);
453
454        let stream = self.stream;
455        let span = trace_span!("split_by");
456        tokio::spawn(
457            async move {
458                stream
459                    .for_each_concurrent(self.concurrency, move |item| {
460                        let predicate = Arc::clone(&predicate);
461                        let left_tx = left_tx.clone();
462                        let right_tx = right_tx.clone();
463                        async move {
464                            if predicate(&item) {
465                                tracing::trace!(?item, "Sending to left stream");
466                                left_tx
467                                    .send(item)
468                                    .await
469                                    .expect("Failed to send to left stream");
470                            } else {
471                                tracing::trace!(?item, "Sending to right stream");
472                                right_tx
473                                    .send(item)
474                                    .await
475                                    .expect("Failed to send to right stream");
476                            }
477                        }
478                    })
479                    .await;
480            }
481            .instrument(span.or_current()),
482        );
483
484        let left_pipeline = pipeline_with_new_stream!(self, left_rx);
485
486        let right_pipeline = pipeline_with_new_stream!(self, right_rx);
487
488        (left_pipeline, right_pipeline)
489    }
490
491    /// Merges two streams into one
492    ///
493    /// This is useful for merging two streams that have been split using the `split_by` method.
494    ///
495    /// The full stream can then be processed using the `run` method.
496    #[must_use]
497    pub fn merge(self, other: Self) -> Self {
498        let stream = tokio_stream::StreamExt::merge(self.stream, other.stream);
499
500        Self {
501            stream: stream.boxed().into(),
502            ..self
503        }
504    }
505
506    /// Throttles the stream of nodes, limiting the rate to 1 per duration.
507    ///
508    /// Useful for rate limiting the indexing pipeline. Uses `tokio_stream::StreamExt::throttle`
509    /// internally which has a granualarity of 1ms.
510    #[must_use]
511    pub fn throttle(mut self, duration: impl Into<Duration>) -> Self {
512        self.stream = tokio_stream::StreamExt::throttle(self.stream, duration.into())
513            .boxed()
514            .into();
515        self
516    }
517
518    // Silently filters out errors encountered by the pipeline.
519    //
520    // This method filters out errors encountered by the pipeline, preventing them from bubbling up
521    // and terminating the stream. Note that errors are not logged.
522    #[must_use]
523    pub fn filter_errors(mut self) -> Self {
524        self.stream = self
525            .stream
526            .filter_map(|result| async {
527                match result {
528                    Ok(node) => Some(Ok(node)),
529                    Err(_e) => None,
530                }
531            })
532            .boxed()
533            .into();
534        self
535    }
536
537    /// Provide a closure to selectively filter nodes or errors
538    ///
539    /// This allows you to skip specific errors or nodes, or do ad hoc inspection.
540    ///
541    /// If the closure returns true, the result is kept, otherwise it is skipped.
542    #[must_use]
543    pub fn filter<F>(mut self, filter: F) -> Self
544    where
545        F: Fn(&Result<Node<T>>) -> bool + Send + Sync + 'static,
546    {
547        self.stream = self
548            .stream
549            .filter(move |result| {
550                let will_retain = filter(result);
551
552                async move { will_retain }
553            })
554            .boxed()
555            .into();
556        self
557    }
558
559    /// Logs all results processed by the pipeline.
560    ///
561    /// This method logs all results processed by the pipeline at the `DEBUG` level.
562    #[must_use]
563    pub fn log_all(self) -> Self {
564        self.log_errors().log_nodes()
565    }
566
567    /// Logs all errors encountered by the pipeline.
568    ///
569    /// This method logs all errors encountered by the pipeline at the `ERROR` level.
570    #[must_use]
571    pub fn log_errors(mut self) -> Self {
572        self.stream = self
573            .stream
574            .inspect_err(|e| tracing::error!(?e, "Error processing node"))
575            .boxed()
576            .into();
577        self
578    }
579
580    /// Logs all nodes processed by the pipeline.
581    ///
582    /// This method logs all nodes processed by the pipeline at the `DEBUG` level.
583    #[must_use]
584    pub fn log_nodes(mut self) -> Self {
585        self.stream = self
586            .stream
587            .inspect_ok(|node| tracing::debug!(?node, "Processed node: {:?}", node))
588            .boxed()
589            .into();
590        self
591    }
592
593    /// Runs the indexing pipeline.
594    ///
595    /// This method processes the stream of nodes, applying all configured transformations and
596    /// storing the results.
597    ///
598    /// # Returns
599    ///
600    /// A `Result` indicating the success or failure of the pipeline execution.
601    ///
602    /// # Errors
603    ///
604    /// Returns an error if no storage backend is configured or if any stage of the pipeline fails.
605    #[tracing::instrument(skip_all, fields(total_nodes), name = "indexing_pipeline.run")]
606    pub async fn run(mut self) -> Result<()> {
607        tracing::info!(
608            "Starting indexing pipeline with {} concurrency",
609            self.concurrency
610        );
611        let now = std::time::Instant::now();
612
613        // TODO: No longer bail if storage is empty. Do whatever you want
614        // if self.storage.is_empty() {
615        //     anyhow::bail!("No storage configured for indexing pipeline");
616        // }
617
618        // Ensure all storage backends are set up before processing nodes
619        let setup_futures = self
620            .storage_setup_fns
621            .into_iter()
622            .map(|func| async move { func().await })
623            .collect::<Vec<_>>();
624        futures_util::future::try_join_all(setup_futures).await?;
625
626        let mut total_nodes = 0;
627        while self.stream.try_next().await?.is_some() {
628            total_nodes += 1;
629        }
630
631        let elapsed_in_seconds = now.elapsed().as_secs();
632        tracing::info!(
633            elapsed_in_seconds,
634            "Processed {} nodes in {} seconds",
635            total_nodes,
636            elapsed_in_seconds
637        );
638        tracing::Span::current().record("total_nodes", total_nodes);
639
640        Ok(())
641    }
642}
643
644#[cfg(test)]
645mod tests {
646
647    use super::*;
648    use crate::persist::MemoryStorage;
649    use mockall::Sequence;
650    use swiftide_core::indexing::*;
651
652    /// Tests a simple run of the indexing pipeline.
653    #[test_log::test(tokio::test)]
654    async fn test_simple_run() {
655        let mut loader = MockLoader::new();
656        let mut transformer = MockTransformer::new();
657        let mut batch_transformer = MockBatchableTransformer::new();
658        let mut chunker = MockChunkerTransformer::new();
659        let mut storage = MockPersist::new();
660
661        let mut seq = Sequence::new();
662
663        loader
664            .expect_into_stream()
665            .times(1)
666            .in_sequence(&mut seq)
667            .returning(|| vec![Ok(Node::default())].into());
668
669        transformer.expect_transform_node().returning(|mut node| {
670            node.chunk = "transformed".to_string();
671            Ok(node)
672        });
673        transformer.expect_concurrency().returning(|| None);
674        transformer.expect_name().returning(|| "transformer");
675
676        batch_transformer
677            .expect_batch_transform()
678            .times(1)
679            .in_sequence(&mut seq)
680            .returning(|nodes| IndexingStream::iter(nodes.into_iter().map(Ok)));
681        batch_transformer.expect_concurrency().returning(|| None);
682        batch_transformer.expect_name().returning(|| "transformer");
683        batch_transformer.expect_batch_size().returning(|| None);
684
685        chunker
686            .expect_transform_node()
687            .times(1)
688            .in_sequence(&mut seq)
689            .returning(|node| {
690                let mut nodes = vec![];
691                for i in 0..3 {
692                    let mut node = node.clone();
693                    node.chunk = format!("transformed_chunk_{i}");
694                    nodes.push(Ok(node));
695                }
696                nodes.into()
697            });
698        chunker.expect_concurrency().returning(|| None);
699        chunker.expect_name().returning(|| "chunker");
700
701        storage.expect_setup().returning(|| Ok(()));
702        storage.expect_batch_size().returning(|| None);
703        storage
704            .expect_store()
705            .times(3)
706            .in_sequence(&mut seq)
707            .withf(|node| node.chunk.starts_with("transformed_chunk_"))
708            .returning(Ok);
709        storage.expect_name().returning(|| "storage");
710
711        let pipeline = Pipeline::from_loader(loader)
712            .then(transformer)
713            .then_in_batch(batch_transformer)
714            .then_chunk(chunker)
715            .then_store_with(storage);
716
717        pipeline.run().await.unwrap();
718    }
719
720    #[tokio::test]
721    async fn test_skipping_errors() {
722        let mut loader = MockLoader::new();
723        let mut transformer = MockTransformer::new();
724        let mut storage = MockPersist::new();
725        let mut seq = Sequence::new();
726        loader
727            .expect_into_stream()
728            .times(1)
729            .in_sequence(&mut seq)
730            .returning(|| vec![Ok(Node::default())].into());
731        transformer
732            .expect_transform_node()
733            .returning(|_node| Err(anyhow::anyhow!("Error transforming node")));
734        transformer.expect_concurrency().returning(|| None);
735        transformer.expect_name().returning(|| "mock");
736        storage.expect_setup().returning(|| Ok(()));
737        storage.expect_batch_size().returning(|| None);
738        storage.expect_store().times(0).returning(Ok);
739        let pipeline = Pipeline::from_loader(loader)
740            .then(transformer)
741            .then_store_with(storage)
742            .filter_errors();
743        pipeline.run().await.unwrap();
744    }
745
746    #[tokio::test]
747    async fn test_concurrent_calls_with_simple_transformer() {
748        let mut loader = MockLoader::new();
749        let mut transformer = MockTransformer::new();
750        let mut storage = MockPersist::new();
751        let mut seq = Sequence::new();
752        loader
753            .expect_into_stream()
754            .times(1)
755            .in_sequence(&mut seq)
756            .returning(|| {
757                vec![
758                    Ok(Node::default()),
759                    Ok(Node::default()),
760                    Ok(Node::default()),
761                ]
762                .into()
763            });
764        transformer
765            .expect_transform_node()
766            .times(3)
767            .in_sequence(&mut seq)
768            .returning(|mut node| {
769                node.chunk = "transformed".to_string();
770                Ok(node)
771            });
772        transformer.expect_concurrency().returning(|| Some(3));
773        transformer.expect_name().returning(|| "transformer");
774        storage.expect_setup().returning(|| Ok(()));
775        storage.expect_batch_size().returning(|| None);
776        storage.expect_store().times(3).returning(Ok);
777        storage.expect_name().returning(|| "storage");
778
779        let pipeline = Pipeline::from_loader(loader)
780            .then(transformer)
781            .then_store_with(storage);
782        pipeline.run().await.unwrap();
783    }
784
785    #[tokio::test]
786    async fn test_arbitrary_closures_as_transformer() {
787        let mut loader = MockLoader::new();
788        let transformer = |node: TextNode| {
789            let mut node = node;
790            node.chunk = "transformed".to_string();
791            Ok(node)
792        };
793        let storage = MemoryStorage::default();
794        let mut seq = Sequence::new();
795        loader
796            .expect_into_stream()
797            .times(1)
798            .in_sequence(&mut seq)
799            .returning(|| vec![Ok(TextNode::default())].into());
800
801        let pipeline = Pipeline::from_loader(loader)
802            .then(transformer)
803            .then_store_with(storage.clone());
804        pipeline.run().await.unwrap();
805
806        dbg!(storage.clone());
807        let processed_node = storage.get("0").await.unwrap();
808        assert_eq!(processed_node.chunk, "transformed");
809    }
810
811    #[tokio::test]
812    async fn test_arbitrary_closures_as_batch_transformer() {
813        let mut loader = MockLoader::new();
814        let batch_transformer = |nodes: Vec<TextNode>| {
815            IndexingStream::iter(nodes.into_iter().map(|mut node| {
816                node.chunk = "transformed".to_string();
817                Ok(node)
818            }))
819        };
820        let storage = MemoryStorage::default();
821        let mut seq = Sequence::new();
822        loader
823            .expect_into_stream()
824            .times(1)
825            .in_sequence(&mut seq)
826            .returning(|| vec![Ok(TextNode::default())].into());
827
828        let pipeline = Pipeline::from_loader(loader)
829            .then_in_batch(batch_transformer)
830            .then_store_with(storage.clone());
831        pipeline.run().await.unwrap();
832
833        dbg!(storage.clone());
834        let processed_node = storage.get("0").await.unwrap();
835        assert_eq!(processed_node.chunk, "transformed");
836    }
837
838    #[tokio::test]
839    async fn test_filter_closure() {
840        let mut loader = MockLoader::new();
841        let storage = MemoryStorage::default();
842        let mut seq = Sequence::new();
843        loader
844            .expect_into_stream()
845            .times(1)
846            .in_sequence(&mut seq)
847            .returning(|| {
848                vec![
849                    Ok(TextNode::default()),
850                    Ok(TextNode::new("skip")),
851                    Ok(TextNode::default()),
852                ]
853                .into()
854            });
855        let pipeline = Pipeline::from_loader(loader)
856            .filter(|result| {
857                let node = result.as_ref().unwrap();
858                node.chunk != "skip"
859            })
860            .then_store_with(storage.clone());
861        pipeline.run().await.unwrap();
862        let nodes = storage.get_all().await;
863        assert_eq!(nodes.len(), 2);
864    }
865
866    #[test_log::test(tokio::test)]
867    async fn test_split_and_merge() {
868        let mut loader = MockLoader::new();
869        let storage = MemoryStorage::default();
870        let mut seq = Sequence::new();
871        loader
872            .expect_into_stream()
873            .times(1)
874            .in_sequence(&mut seq)
875            .returning(|| {
876                vec![
877                    Ok(TextNode::default()),
878                    Ok(TextNode::new("will go left")),
879                    Ok(TextNode::default()),
880                ]
881                .into()
882            });
883
884        let pipeline = Pipeline::from_loader(loader);
885        let (mut left, mut right) = pipeline.split_by(|node| {
886            if let Ok(node) = node {
887                node.chunk.starts_with("will go left")
888            } else {
889                false
890            }
891        });
892
893        // change the chunk to 'left'
894        left = left
895            .then(move |mut node: TextNode| {
896                node.chunk = "left".to_string();
897
898                Ok(node)
899            })
900            .log_all();
901
902        right = right.then(move |mut node: TextNode| {
903            node.chunk = "right".to_string();
904            Ok(node)
905        });
906
907        left.merge(right)
908            .then_store_with(storage.clone())
909            .run()
910            .await
911            .unwrap();
912        dbg!(storage.clone());
913
914        let all_nodes = storage.get_all_values().await;
915        assert_eq!(
916            all_nodes.iter().filter(|node| node.chunk == "left").count(),
917            1
918        );
919        assert_eq!(
920            all_nodes
921                .iter()
922                .filter(|node| node.chunk == "right")
923                .count(),
924            2
925        );
926    }
927
928    #[tokio::test]
929    async fn test_all_steps_should_work_as_dyn_box() {
930        let mut loader = MockLoader::new();
931        loader
932            .expect_into_stream_boxed()
933            .returning(|| vec![Ok(TextNode::default())].into());
934
935        let mut transformer = MockTransformer::new();
936        transformer.expect_transform_node().returning(Ok);
937        transformer.expect_concurrency().returning(|| None);
938        transformer.expect_name().returning(|| "mock");
939
940        let mut batch_transformer = MockBatchableTransformer::new();
941        batch_transformer
942            .expect_batch_transform()
943            .returning(std::convert::Into::into);
944        batch_transformer.expect_concurrency().returning(|| None);
945        batch_transformer.expect_name().returning(|| "mock");
946        let mut chunker = MockChunkerTransformer::new();
947        chunker
948            .expect_transform_node()
949            .returning(|node| vec![node].into());
950        chunker.expect_concurrency().returning(|| None);
951        chunker.expect_name().returning(|| "mock");
952
953        let mut storage = MockPersist::new();
954        storage.expect_setup().returning(|| Ok(()));
955        storage.expect_store().returning(Ok);
956        storage.expect_batch_size().returning(|| None);
957        storage.expect_name().returning(|| "mock");
958
959        let pipeline = Pipeline::from_loader(Box::new(loader) as Box<dyn Loader<Output = String>>)
960            .then(Box::new(transformer) as Box<dyn Transformer<Input = String, Output = String>>)
961            .then_in_batch(Box::new(batch_transformer) as Box<dyn BatchableTransformer<Input = String, Output = String>>)
962            .then_chunk(Box::new(chunker) as Box<dyn ChunkerTransformer<Input = String, Output = String>>)
963            .then_store_with(Box::new(storage) as Box<dyn Persist<Input = String, Output = String>>);
964        pipeline.run().await.unwrap();
965    }
966}