// swiftide_indexing/pipeline.rs

use anyhow::Result;
use futures_util::{
    FutureExt, StreamExt, TryFutureExt, TryStreamExt,
    future::{BoxFuture, Shared},
};
use swiftide_core::{
    BatchableTransformer, ChunkerTransformer, Loader, NodeCache, Persist, SimplePrompt,
    Transformer, WithBatchIndexingDefaults, WithIndexingDefaults, indexing::IndexingDefaults,
};
use tokio::{
    sync::mpsc,
    task::{self},
};
use tracing::Instrument;

use std::{collections::HashSet, sync::Arc, time::Duration};

use swiftide_core::indexing::{EmbedMode, IndexingStream, Node};

macro_rules! trace_span {
    ($op:literal, $step:expr) => {
        tracing::trace_span!($op, "otel.name" = format!("{}.{}", $op, $step.name()),)
    };

    ($op:literal) => {
        tracing::trace_span!($op, "otel.name" = format!("{}", $op),)
    };
}

macro_rules! node_trace_log {
    ($step:expr, $node:expr, $msg:literal) => {
        tracing::trace!(
            node = ?$node,
            node_id = ?$node.id(),
            step = $step.name(),
            $msg
        )
    };
}

macro_rules! batch_node_trace_log {
    ($step:expr, $nodes:expr, $msg:literal) => {
        tracing::trace!(batch_size = $nodes.len(), nodes = ?$nodes, step = $step.name(), $msg)
    };
}

/// The default batch size for batch processing.
const DEFAULT_BATCH_SIZE: usize = 256;

/// A pipeline for indexing files, adding metadata, chunking, transforming, embedding, and then
/// storing them.
///
/// The `Pipeline` struct orchestrates the entire file indexing process. It is designed to be
/// flexible and performant, allowing for various stages of data transformation and storage to be
/// configured and executed asynchronously.
///
/// # Fields
///
/// * `stream` - The stream of `Node` items to be processed.
/// * `storage` - The storage backends where the processed nodes will be stored.
/// * `concurrency` - The level of concurrency for processing nodes.
/// * `indexing_defaults` - Defaults, such as the LLM client, shared with transformers.
/// * `batch_size` - The default batch size for batch transformers and batched storage.
/// * `cache_sender` - Sends the keys of successfully processed nodes to the cache task.
/// * `cache_handle` - Handle to the background task that writes cache entries.
pub struct Pipeline {
    stream: IndexingStream,
    storage: Vec<Arc<dyn Persist>>,
    concurrency: usize,
    indexing_defaults: IndexingDefaults,
    batch_size: usize,
    cache_sender: Option<mpsc::Sender<uuid::Uuid>>,
    cache_handle: Option<Shared<BoxFuture<'static, ()>>>,
}

impl Default for Pipeline {
    /// Creates a default `Pipeline` with an empty stream, no storage, and a concurrency level equal
    /// to the number of CPUs.
    fn default() -> Self {
        Self {
            stream: IndexingStream::empty(),
            storage: Vec::default(),
            concurrency: num_cpus::get(),
            indexing_defaults: IndexingDefaults::default(),
            batch_size: DEFAULT_BATCH_SIZE,
            cache_sender: None,
            cache_handle: None,
        }
    }
}

impl Pipeline {
    /// Creates a `Pipeline` from a given loader.
    ///
    /// # Arguments
    ///
    /// * `loader` - A loader that implements the `Loader` trait.
    ///
    /// # Returns
    ///
    /// An instance of `Pipeline` initialized with the provided loader.
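    ///
    /// # Example
    ///
    /// A minimal sketch, assuming the `FileLoader` from this crate's `loaders` module
    /// is available:
    ///
    /// ```no_run
    /// use swiftide_indexing::{Pipeline, loaders::FileLoader};
    ///
    /// let pipeline = Pipeline::from_loader(FileLoader::new("./src"));
    /// ```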
    pub fn from_loader(loader: impl Loader + 'static) -> Self {
        let stream = loader.into_stream();
        Self {
            stream,
            ..Default::default()
        }
    }

    /// Sets the default LLM client to be used for prompts by all transformers in the
    /// pipeline.
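    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `my_loader` and `llm_client` are values implementing
    /// `Loader` and `SimplePrompt` respectively (both names are illustrative):
    ///
    /// ```ignore
    /// let pipeline = Pipeline::from_loader(my_loader).with_default_llm_client(llm_client);
    /// ```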
    #[must_use]
    pub fn with_default_llm_client(mut self, client: impl SimplePrompt + 'static) -> Self {
        self.indexing_defaults = IndexingDefaults::from_simple_prompt(Box::new(client));
        self
    }

    /// Creates a `Pipeline` from a given stream.
    ///
    /// # Arguments
    ///
    /// * `stream` - An `IndexingStream` containing the nodes to be processed.
    ///
    /// # Returns
    ///
    /// An instance of `Pipeline` initialized with the provided stream.
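    ///
    /// # Example
    ///
    /// A minimal sketch; a `Vec<Result<Node>>` converts into an `IndexingStream`:
    ///
    /// ```no_run
    /// use anyhow::Result;
    /// use swiftide_core::indexing::Node;
    /// use swiftide_indexing::Pipeline;
    ///
    /// let nodes: Vec<Result<Node>> = vec![Ok(Node::new("hello"))];
    /// let pipeline = Pipeline::from_stream(nodes);
    /// ```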
    pub fn from_stream(stream: impl Into<IndexingStream>) -> Self {
        Self {
            stream: stream.into(),
            ..Default::default()
        }
    }

    /// Sets the concurrency level for the pipeline. By default the concurrency is set to the
    /// number of CPUs.
    ///
    /// # Arguments
    ///
    /// * `concurrency` - The desired level of concurrency.
    ///
    /// # Returns
    ///
    /// An instance of `Pipeline` with the updated concurrency level.
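    ///
    /// # Example
    ///
    /// A minimal sketch, processing at most eight nodes at a time per step:
    ///
    /// ```no_run
    /// use swiftide_indexing::Pipeline;
    ///
    /// let pipeline = Pipeline::default().with_concurrency(8);
    /// ```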
    #[must_use]
    pub fn with_concurrency(mut self, concurrency: usize) -> Self {
        self.concurrency = concurrency;
        self
    }

    /// Sets the embed mode for the pipeline. The embed mode controls which (combination of)
    /// fields of a [`Node`] are embedded with a vector when transforming with
    /// [`crate::transformers::Embed`].
    ///
    /// See also [`swiftide_core::indexing::EmbedMode`].
    ///
    /// # Arguments
    ///
    /// * `embed_mode` - The desired embed mode.
    ///
    /// # Returns
    ///
    /// An instance of `Pipeline` with the updated embed mode.
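    ///
    /// # Example
    ///
    /// A minimal sketch; `EmbedMode::PerField` is assumed to be one of the available
    /// variants, see [`swiftide_core::indexing::EmbedMode`] for the full list:
    ///
    /// ```no_run
    /// use swiftide_core::indexing::EmbedMode;
    /// use swiftide_indexing::Pipeline;
    ///
    /// let pipeline = Pipeline::default().with_embed_mode(EmbedMode::PerField);
    /// ```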
    #[must_use]
    pub fn with_embed_mode(mut self, embed_mode: EmbedMode) -> Self {
        self.stream = self
            .stream
            .map_ok(move |mut node| {
                node.embed_mode = embed_mode;
                node
            })
            .boxed()
            .into();
        self
    }

    /// Filters out cached nodes using the provided cache.
    ///
    /// A background task is also spawned that marks nodes in the cache once they have been
    /// successfully processed by the full pipeline.
    ///
    /// # Arguments
    ///
    /// * `cache` - A cache that implements the `NodeCache` trait.
    ///
    /// # Returns
    ///
    /// An instance of `Pipeline` with the updated stream that filters out cached nodes.
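    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `my_loader` and `my_cache` implement `Loader` and
    /// `NodeCache` respectively (both names are illustrative):
    ///
    /// ```ignore
    /// let pipeline = Pipeline::from_loader(my_loader).filter_cached(my_cache);
    /// ```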
    #[must_use]
    pub fn filter_cached(mut self, cache: impl NodeCache + 'static) -> Self {
        let cache = Arc::new(cache);
        let (cache_sender, mut cache_receiver) = mpsc::channel(1000);

        let cache_for_task = Arc::clone(&cache);

        let handle = tokio::spawn(async move {
            while let Some(cache_key) = cache_receiver.recv().await {
                let cache_node = Node {
                    parent_id: Some(cache_key),
                    ..Default::default()
                };
                cache_for_task.set(&cache_node).await;
            }
        });

        self.cache_handle = Some(
            handle
                .map(|result| {
                    if let Err(err) = result {
                        tracing::error!("Cache task failed: {err}");
                    }
                })
                .boxed()
                .shared(),
        );

        self.stream = self
            .stream
            .try_filter_map(move |node| {
                let cache = Arc::clone(&cache);
                let span = trace_span!("filter_cached", cache);

                async move {
                    if cache.get(&node).await {
                        node_trace_log!(cache, node, "node in cache, skipping");
                        Ok(None)
                    } else {
                        node_trace_log!(cache, node, "node not in cache, processing");
                        Ok(Some(node))
                    }
                }
                .instrument(span.or_current())
            })
            .boxed()
            .into();
        self.cache_sender = Some(cache_sender);
        self
    }

    /// Adds a transformer to the pipeline.
    ///
    /// Closures can also be provided as transformers.
    ///
    /// # Arguments
    ///
    /// * `transformer` - A transformer that implements the `Transformer` trait.
    ///
    /// # Returns
    ///
    /// An instance of `Pipeline` with the updated stream that applies the transformer to each node.
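    ///
    /// # Example
    ///
    /// A minimal sketch using a closure as the transformer, mirroring the closure
    /// tests below:
    ///
    /// ```no_run
    /// use swiftide_core::indexing::Node;
    /// use swiftide_indexing::Pipeline;
    ///
    /// let pipeline = Pipeline::default().then(|mut node: Node| {
    ///     node.chunk = node.chunk.to_uppercase();
    ///     Ok(node)
    /// });
    /// ```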
    #[must_use]
    pub fn then(
        mut self,
        mut transformer: impl Transformer + WithIndexingDefaults + 'static,
    ) -> Self {
        let concurrency = transformer.concurrency().unwrap_or(self.concurrency);

        transformer.with_indexing_defaults(self.indexing_defaults.clone());

        let transformer = Arc::new(transformer);
        self.stream = self
            .stream
            .map_ok(move |node| {
                let transformer = transformer.clone();
                let span = trace_span!("then", transformer);

                task::spawn(
                    async move {
                        node_trace_log!(transformer, node, "Transforming node");
                        transformer.transform_node(node).await
                    }
                    .instrument(span.or_current()),
                )
                .err_into::<anyhow::Error>()
            })
            .try_buffer_unordered(concurrency)
            .map(|x| x.and_then(|x| x))
            .boxed()
            .into();

        self
    }

    /// Adds a batch transformer to the pipeline.
    ///
    /// If the transformer has a batch size set, the batch size from the transformer is used;
    /// otherwise the pipeline default batch size ([`DEFAULT_BATCH_SIZE`]) applies.
    ///
    /// # Arguments
    ///
    /// * `transformer` - A transformer that implements the `BatchableTransformer` trait.
    ///
    /// # Returns
    ///
    /// An instance of `Pipeline` with the updated stream that applies the batch transformer to each
    /// batch of nodes.
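    ///
    /// # Example
    ///
    /// A minimal sketch using a closure over a whole batch, mirroring the batch
    /// closure test below:
    ///
    /// ```no_run
    /// use swiftide_core::indexing::{IndexingStream, Node};
    /// use swiftide_indexing::Pipeline;
    ///
    /// let pipeline = Pipeline::default()
    ///     .then_in_batch(|nodes: Vec<Node>| IndexingStream::iter(nodes.into_iter().map(Ok)));
    /// ```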
    #[must_use]
    pub fn then_in_batch(
        mut self,
        mut transformer: impl BatchableTransformer + WithBatchIndexingDefaults + 'static,
    ) -> Self {
        let concurrency = transformer.concurrency().unwrap_or(self.concurrency);

        transformer.with_indexing_defaults(self.indexing_defaults.clone());

        let transformer = Arc::new(transformer);
        self.stream = self
            .stream
            .try_chunks(transformer.batch_size().unwrap_or(self.batch_size))
            .map_ok(move |nodes| {
                let transformer = Arc::clone(&transformer);
                let span = trace_span!("then_in_batch", transformer);

                tokio::spawn(
                    async move {
                        batch_node_trace_log!(transformer, nodes, "batch transforming nodes");
                        transformer.batch_transform(nodes).await
                    }
                    .instrument(span.or_current()),
                )
                .map_err(anyhow::Error::from)
            })
            .err_into::<anyhow::Error>()
            .try_buffer_unordered(concurrency) // First get the streams from each future
            .try_flatten_unordered(None) // Then flatten all the streams back into one
            .boxed()
            .into();
        self
    }

    /// Adds a chunker transformer to the pipeline.
    ///
    /// # Arguments
    ///
    /// * `chunker` - A transformer that implements the `ChunkerTransformer` trait.
    ///
    /// # Returns
    ///
    /// An instance of `Pipeline` with the updated stream that applies the chunker transformer to
    /// each node.
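    ///
    /// # Example
    ///
    /// A minimal sketch, assuming a markdown chunker like this crate's
    /// `transformers::ChunkMarkdown` (the constructor name is illustrative):
    ///
    /// ```ignore
    /// let pipeline = Pipeline::from_loader(my_loader)
    ///     .then_chunk(ChunkMarkdown::from_chunk_range(10..512));
    /// ```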
    #[must_use]
    pub fn then_chunk(mut self, chunker: impl ChunkerTransformer + 'static) -> Self {
        let chunker = Arc::new(chunker);
        let concurrency = chunker.concurrency().unwrap_or(self.concurrency);
        self.stream = self
            .stream
            .map_ok(move |node| {
                let chunker = Arc::clone(&chunker);
                let span = trace_span!("then_chunk", chunker);

                tokio::spawn(
                    async move {
                        node_trace_log!(chunker, node, "Chunking node");
                        chunker.transform_node(node).await
                    }
                    .instrument(span.or_current()),
                )
                .map_err(anyhow::Error::from)
            })
            .err_into::<anyhow::Error>()
            .try_buffer_unordered(concurrency)
            .try_flatten_unordered(None)
            .boxed()
            .into();

        self
    }

    /// Persists indexing nodes using the provided storage backend.
    ///
    /// # Arguments
    ///
    /// * `storage` - A storage backend that implements the `Persist` trait.
    ///
    /// # Returns
    ///
    /// An instance of `Pipeline` with the configured storage backend.
    ///
    /// # Panics
    ///
    /// Panics if batch storing is invoked while the backend reports no batch size. The
    /// pipeline only invokes batch storing when `batch_size()` returns `Some`, so this
    /// should not occur in practice.
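    ///
    /// # Example
    ///
    /// A minimal sketch using the in-memory storage from this crate's `persist`
    /// module, as the tests below do (`my_loader` is illustrative):
    ///
    /// ```ignore
    /// let storage = MemoryStorage::default();
    /// let pipeline = Pipeline::from_loader(my_loader).then_store_with(storage.clone());
    /// ```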
    #[must_use]
    pub fn then_store_with(mut self, storage: impl Persist + 'static) -> Self {
        let storage = Arc::new(storage);
        self.storage.push(storage.clone());
        // Add storage to the stream instead of doing it at the end
        if storage.batch_size().is_some() {
            self.stream = self
                .stream
                .try_chunks(storage.batch_size().unwrap())
                .map_ok(move |nodes| {
                    let storage = Arc::clone(&storage);
                    let span = trace_span!("then_store_with_batched", storage);

                    tokio::spawn(
                        async move {
                            batch_node_trace_log!(storage, nodes, "batch storing nodes");
                            storage.batch_store(nodes).await
                        }
                        .instrument(span.or_current()),
                    )
                    .map_err(anyhow::Error::from)
                })
                .err_into::<anyhow::Error>()
                .try_buffer_unordered(self.concurrency)
                .try_flatten_unordered(None)
                .boxed()
                .into();
        } else {
            self.stream = self
                .stream
                .map_ok(move |node| {
                    let storage = Arc::clone(&storage);
                    let span = trace_span!("then_store_with", storage);

                    tokio::spawn(
                        async move {
                            node_trace_log!(storage, node, "Storing node");

                            storage.store(node).await
                        }
                        .instrument(span.or_current()),
                    )
                    .err_into::<anyhow::Error>()
                })
                .try_buffer_unordered(self.concurrency)
                .map(|x| x.and_then(|x| x))
                .boxed()
                .into();
        }

        self
    }

    /// Splits the stream into two streams based on a predicate.
    ///
    /// Note that this is not lazy. It will start consuming the stream immediately
    /// and send each item to the left or right stream based on the predicate.
    ///
    /// The resulting streams are buffered, but should be consumed as soon as possible.
    /// The channels of the resulting streams are bounded and the parent stream will panic
    /// if sending fails.
    ///
    /// The two streams can be run concurrently, alternated between, or merged back together.
    ///
    /// # Panics
    ///
    /// Panics if a receiving pipeline's buffer is full or unavailable.
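    ///
    /// # Example
    ///
    /// A minimal sketch that routes nodes on their chunk content and merges the halves
    /// back together, mirroring the split-and-merge test below:
    ///
    /// ```ignore
    /// let (left, right) = pipeline.split_by(|result| {
    ///     result.as_ref().map_or(false, |node| node.chunk.starts_with("left"))
    /// });
    /// left.merge(right).run().await?;
    /// ```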
    #[must_use]
    pub fn split_by<P>(self, predicate: P) -> (Self, Self)
    where
        P: Fn(&Result<Node>) -> bool + Send + Sync + 'static,
    {
        let predicate = Arc::new(predicate);

        let (left_tx, left_rx) = mpsc::channel(1000);
        let (right_tx, right_rx) = mpsc::channel(1000);

        let stream = self.stream;
        let span = trace_span!("split_by");
        tokio::spawn(
            async move {
                stream
                    .for_each_concurrent(self.concurrency, move |item| {
                        let predicate = Arc::clone(&predicate);
                        let left_tx = left_tx.clone();
                        let right_tx = right_tx.clone();
                        async move {
                            if predicate(&item) {
                                tracing::trace!(?item, "Sending to left stream");
                                left_tx
                                    .send(item)
                                    .await
                                    .expect("Failed to send to left stream");
                            } else {
                                tracing::trace!(?item, "Sending to right stream");
                                right_tx
                                    .send(item)
                                    .await
                                    .expect("Failed to send to right stream");
                            }
                        }
                    })
                    .await;
            }
            .instrument(span.or_current()),
        );

        let left_pipeline = Self {
            stream: left_rx.into(),
            storage: self.storage.clone(),
            concurrency: self.concurrency,
            indexing_defaults: self.indexing_defaults.clone(),
            batch_size: self.batch_size,
            cache_sender: self.cache_sender.clone(),
            cache_handle: self.cache_handle.clone(),
        };

        let right_pipeline = Self {
            stream: right_rx.into(),
            storage: self.storage.clone(),
            concurrency: self.concurrency,
            indexing_defaults: self.indexing_defaults.clone(),
            batch_size: self.batch_size,
            cache_sender: self.cache_sender.clone(),
            cache_handle: self.cache_handle.clone(),
        };

        (left_pipeline, right_pipeline)
    }

    /// Merges two streams into one.
    ///
    /// This is useful for merging two streams that have been split using the `split_by` method.
    ///
    /// The full stream can then be processed using the `run` method.
    #[must_use]
    pub fn merge(self, other: Self) -> Self {
        let stream = tokio_stream::StreamExt::merge(self.stream, other.stream);

        Self {
            stream: stream.boxed().into(),
            ..self
        }
    }

    /// Throttles the stream of nodes, limiting the rate to 1 per duration.
    ///
    /// Useful for rate limiting the indexing pipeline. Uses `tokio_stream::StreamExt::throttle`
    /// internally, which has a granularity of 1ms.
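    ///
    /// # Example
    ///
    /// A minimal sketch limiting throughput to roughly ten nodes per second:
    ///
    /// ```no_run
    /// use std::time::Duration;
    /// use swiftide_indexing::Pipeline;
    ///
    /// let pipeline = Pipeline::default().throttle(Duration::from_millis(100));
    /// ```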
    #[must_use]
    pub fn throttle(mut self, duration: impl Into<Duration>) -> Self {
        self.stream = tokio_stream::StreamExt::throttle(self.stream, duration.into())
            .boxed()
            .into();
        self
    }

    /// Silently filters out errors encountered by the pipeline.
    ///
    /// This method filters out errors encountered by the pipeline, preventing them from bubbling up
    /// and terminating the stream. Note that errors are not logged.
    #[must_use]
    pub fn filter_errors(mut self) -> Self {
        self.stream = self
            .stream
            .filter_map(|result| async {
                match result {
                    Ok(node) => Some(Ok(node)),
                    Err(_e) => None,
                }
            })
            .boxed()
            .into();
        self
    }

    /// Provide a closure to selectively filter nodes or errors.
    ///
    /// This allows you to skip specific errors or nodes, or do ad hoc inspection.
    ///
    /// If the closure returns true, the result is kept; otherwise it is skipped.
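    ///
    /// # Example
    ///
    /// A minimal sketch that drops nodes whose chunk equals `"skip"` while keeping
    /// errors in the stream:
    ///
    /// ```no_run
    /// use swiftide_core::indexing::Node;
    /// use swiftide_indexing::Pipeline;
    ///
    /// let pipeline = Pipeline::default()
    ///     .filter(|result| result.as_ref().map_or(true, |node| node.chunk != "skip"));
    /// ```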
    #[must_use]
    pub fn filter<F>(mut self, filter: F) -> Self
    where
        F: Fn(&Result<Node>) -> bool + Send + Sync + 'static,
    {
        self.stream = self
            .stream
            .filter(move |result| {
                let will_retain = filter(result);

                async move { will_retain }
            })
            .boxed()
            .into();
        self
    }

    /// Logs all results processed by the pipeline.
    ///
    /// This method logs both nodes and errors: nodes at the `DEBUG` level and errors at the
    /// `ERROR` level.
    #[must_use]
    pub fn log_all(self) -> Self {
        self.log_errors().log_nodes()
    }

    /// Logs all errors encountered by the pipeline.
    ///
    /// This method logs all errors encountered by the pipeline at the `ERROR` level.
    #[must_use]
    pub fn log_errors(mut self) -> Self {
        self.stream = self
            .stream
            .inspect_err(|e| tracing::error!(?e, "Error processing node"))
            .boxed()
            .into();
        self
    }

    /// Logs all nodes processed by the pipeline.
    ///
    /// This method logs all nodes processed by the pipeline at the `DEBUG` level.
    #[must_use]
    pub fn log_nodes(mut self) -> Self {
        self.stream = self
            .stream
            .inspect_ok(|node| tracing::debug!(?node, "Processed node"))
            .boxed()
            .into();
        self
    }

    /// Runs the indexing pipeline.
    ///
    /// This method processes the stream of nodes, applying all configured transformations and
    /// storing the results.
    ///
    /// # Returns
    ///
    /// A `Result` indicating the success or failure of the pipeline execution.
    ///
    /// # Errors
    ///
    /// Returns an error if no storage backend is configured or if any stage of the pipeline fails.
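    ///
    /// # Example
    ///
    /// A minimal end-to-end sketch; the loader, transformer, and storage names are
    /// illustrative stand-ins for any implementations of the respective traits:
    ///
    /// ```ignore
    /// Pipeline::from_loader(my_loader)
    ///     .then(my_transformer)
    ///     .then_store_with(my_storage)
    ///     .run()
    ///     .await?;
    /// ```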
    #[tracing::instrument(skip_all, fields(total_nodes), name = "indexing_pipeline.run")]
    pub async fn run(mut self) -> Result<()> {
        tracing::info!(
            "Starting indexing pipeline with {} concurrency",
            self.concurrency
        );
        let now = std::time::Instant::now();
        if self.storage.is_empty() {
            anyhow::bail!("No storage configured for indexing pipeline");
        }

        // Ensure all storage backends are set up before processing nodes
        let setup_futures = self
            .storage
            .into_iter()
            .map(|storage| async move { storage.setup().await })
            .collect::<Vec<_>>();
        futures_util::future::try_join_all(setup_futures).await?;

        let mut cache_keys = HashSet::new();

        let mut total_nodes = 0;

        while let Some(node) = self.stream.try_next().await? {
            total_nodes += 1;

            let cache_key = node.parent_id().unwrap_or(node.id());

            if let Some(sender) = &self.cache_sender {
                if cache_keys.insert(cache_key) {
                    let _ = sender.send(cache_key).await;
                }
            }
        }

        let elapsed_in_seconds = now.elapsed().as_secs();
        tracing::info!(
            elapsed_in_seconds,
            "Processed {total_nodes} nodes in {elapsed_in_seconds} seconds",
        );

        tracing::Span::current().record("total_nodes", total_nodes);

        // Drop the sender to signal that no more cache keys will be sent
        drop(self.cache_sender);

        if let Some(handle) = self.cache_handle {
            handle.await;
        }

        Ok(())
    }
}

#[cfg(test)]
mod tests {

    use super::*;
    use crate::persist::MemoryStorage;
    use mockall::Sequence;
    use swiftide_core::indexing::*;

    /// Tests a simple run of the indexing pipeline.
    #[test_log::test(tokio::test)]
    async fn test_simple_run() {
        let mut loader = MockLoader::new();
        let mut transformer = MockTransformer::new();
        let mut batch_transformer = MockBatchableTransformer::new();
        let mut chunker = MockChunkerTransformer::new();
        let mut storage = MockPersist::new();

        let mut seq = Sequence::new();

        loader
            .expect_into_stream()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| vec![Ok(Node::default())].into());

        transformer.expect_transform_node().returning(|mut node| {
            node.chunk = "transformed".to_string();
            Ok(node)
        });
        transformer.expect_concurrency().returning(|| None);
        transformer.expect_name().returning(|| "transformer");

        batch_transformer
            .expect_batch_transform()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|nodes| IndexingStream::iter(nodes.into_iter().map(Ok)));
        batch_transformer.expect_concurrency().returning(|| None);
        batch_transformer.expect_name().returning(|| "transformer");
        batch_transformer.expect_batch_size().returning(|| None);

        chunker
            .expect_transform_node()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|node| {
                let mut nodes = vec![];
                for i in 0..3 {
                    let mut node = node.clone();
                    node.chunk = format!("transformed_chunk_{i}");
                    nodes.push(Ok(node));
                }
                nodes.into()
            });
        chunker.expect_concurrency().returning(|| None);
        chunker.expect_name().returning(|| "chunker");

        storage.expect_setup().returning(|| Ok(()));
        storage.expect_batch_size().returning(|| None);
        storage
            .expect_store()
            .times(3)
            .in_sequence(&mut seq)
            .withf(|node| node.chunk.starts_with("transformed_chunk_"))
            .returning(Ok);
        storage.expect_name().returning(|| "storage");

        let pipeline = Pipeline::from_loader(loader)
            .then(transformer)
            .then_in_batch(batch_transformer)
            .then_chunk(chunker)
            .then_store_with(storage);

        pipeline.run().await.unwrap();
    }

    #[tokio::test]
    async fn test_skipping_errors() {
        let mut loader = MockLoader::new();
        let mut transformer = MockTransformer::new();
        let mut storage = MockPersist::new();
        let mut seq = Sequence::new();
        loader
            .expect_into_stream()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| vec![Ok(Node::default())].into());
        transformer
            .expect_transform_node()
            .returning(|_node| Err(anyhow::anyhow!("Error transforming node")));
        transformer.expect_concurrency().returning(|| None);
        transformer.expect_name().returning(|| "mock");
        storage.expect_setup().returning(|| Ok(()));
        storage.expect_batch_size().returning(|| None);
        storage.expect_store().times(0).returning(Ok);
        let pipeline = Pipeline::from_loader(loader)
            .then(transformer)
            .then_store_with(storage)
            .filter_errors();
        pipeline.run().await.unwrap();
    }

    #[tokio::test]
    async fn test_concurrent_calls_with_simple_transformer() {
        let mut loader = MockLoader::new();
        let mut transformer = MockTransformer::new();
        let mut storage = MockPersist::new();
        let mut seq = Sequence::new();
        loader
            .expect_into_stream()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| {
                vec![
                    Ok(Node::default()),
                    Ok(Node::default()),
                    Ok(Node::default()),
                ]
                .into()
            });
        transformer
            .expect_transform_node()
            .times(3)
            .in_sequence(&mut seq)
            .returning(|mut node| {
                node.chunk = "transformed".to_string();
                Ok(node)
            });
        transformer.expect_concurrency().returning(|| Some(3));
        transformer.expect_name().returning(|| "transformer");
        storage.expect_setup().returning(|| Ok(()));
        storage.expect_batch_size().returning(|| None);
        storage.expect_store().times(3).returning(Ok);
        storage.expect_name().returning(|| "storage");

        let pipeline = Pipeline::from_loader(loader)
            .then(transformer)
            .then_store_with(storage);
        pipeline.run().await.unwrap();
    }

    #[tokio::test]
    async fn test_arbitrary_closures_as_transformer() {
        let mut loader = MockLoader::new();
        let transformer = |node: Node| {
            let mut node = node;
            node.chunk = "transformed".to_string();
            Ok(node)
        };
        let storage = MemoryStorage::default();
        let mut seq = Sequence::new();
        loader
            .expect_into_stream()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| vec![Ok(Node::default())].into());

        let pipeline = Pipeline::from_loader(loader)
            .then(transformer)
            .then_store_with(storage.clone());
        pipeline.run().await.unwrap();

        dbg!(storage.clone());
        let processed_node = storage.get("0").await.unwrap();
        assert_eq!(processed_node.chunk, "transformed");
    }

    #[tokio::test]
    async fn test_arbitrary_closures_as_batch_transformer() {
        let mut loader = MockLoader::new();
        let batch_transformer = |nodes: Vec<Node>| {
            IndexingStream::iter(nodes.into_iter().map(|mut node| {
                node.chunk = "transformed".to_string();
                Ok(node)
            }))
        };
        let storage = MemoryStorage::default();
        let mut seq = Sequence::new();
        loader
            .expect_into_stream()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| vec![Ok(Node::default())].into());

        let pipeline = Pipeline::from_loader(loader)
            .then_in_batch(batch_transformer)
            .then_store_with(storage.clone());
        pipeline.run().await.unwrap();

        dbg!(storage.clone());
        let processed_node = storage.get("0").await.unwrap();
        assert_eq!(processed_node.chunk, "transformed");
    }

    #[tokio::test]
    async fn test_filter_closure() {
        let mut loader = MockLoader::new();
        let storage = MemoryStorage::default();
        let mut seq = Sequence::new();
        loader
            .expect_into_stream()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| {
                vec![
                    Ok(Node::default()),
                    Ok(Node::new("skip")),
                    Ok(Node::default()),
                ]
                .into()
            });
        let pipeline = Pipeline::from_loader(loader)
            .filter(|result| {
                let node = result.as_ref().unwrap();
                node.chunk != "skip"
            })
            .then_store_with(storage.clone());
        pipeline.run().await.unwrap();
        let nodes = storage.get_all().await;
        assert_eq!(nodes.len(), 2);
    }

    #[test_log::test(tokio::test)]
    async fn test_split_and_merge() {
        let mut loader = MockLoader::new();
        let storage = MemoryStorage::default();
        let mut seq = Sequence::new();
        loader
            .expect_into_stream()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| {
                vec![
                    Ok(Node::default()),
                    Ok(Node::new("will go left")),
                    Ok(Node::default()),
                ]
                .into()
            });

        let pipeline = Pipeline::from_loader(loader);
        let (mut left, mut right) = pipeline.split_by(|node| {
            if let Ok(node) = node {
                node.chunk.starts_with("will go left")
            } else {
                false
            }
        });

        // Change the chunk to 'left'
        left = left
            .then(move |mut node: Node| {
                node.chunk = "left".to_string();

                Ok(node)
            })
            .log_all();

        right = right.then(move |mut node: Node| {
            node.chunk = "right".to_string();
            Ok(node)
        });

        left.merge(right)
            .then_store_with(storage.clone())
            .run()
            .await
            .unwrap();
        dbg!(storage.clone());

        let all_nodes = storage.get_all_values().await;
        assert_eq!(
            all_nodes.iter().filter(|node| node.chunk == "left").count(),
            1
        );
        assert_eq!(
            all_nodes
                .iter()
                .filter(|node| node.chunk == "right")
                .count(),
            2
        );
    }

    #[tokio::test]
    async fn test_all_steps_should_work_as_dyn_box() {
        let mut loader = MockLoader::new();
        loader
            .expect_into_stream_boxed()
            .returning(|| vec![Ok(Node::default())].into());

        let mut transformer = MockTransformer::new();
        transformer.expect_transform_node().returning(Ok);
        transformer.expect_concurrency().returning(|| None);
        transformer.expect_name().returning(|| "mock");

        let mut batch_transformer = MockBatchableTransformer::new();
        batch_transformer
            .expect_batch_transform()
            .returning(std::convert::Into::into);
        batch_transformer.expect_concurrency().returning(|| None);
        batch_transformer.expect_name().returning(|| "mock");
        let mut chunker = MockChunkerTransformer::new();
        chunker
            .expect_transform_node()
            .returning(|node| vec![node].into());
        chunker.expect_concurrency().returning(|| None);
        chunker.expect_name().returning(|| "mock");

        let mut storage = MockPersist::new();
        storage.expect_setup().returning(|| Ok(()));
        storage.expect_store().returning(Ok);
        storage.expect_batch_size().returning(|| None);
        storage.expect_name().returning(|| "mock");

        let pipeline = Pipeline::from_loader(Box::new(loader) as Box<dyn Loader>)
            .then(Box::new(transformer) as Box<dyn Transformer>)
            .then_in_batch(Box::new(batch_transformer) as Box<dyn BatchableTransformer>)
            .then_chunk(Box::new(chunker) as Box<dyn ChunkerTransformer>)
            .then_store_with(Box::new(storage) as Box<dyn Persist>);
        pipeline.run().await.unwrap();
    }

    #[tokio::test]
    async fn test_nodes_only_cached_on_success() {
        let mut loader = MockLoader::new();
        let mut cache = MockNodeCache::new();
        let mut storage = MockPersist::new();
        let mut transformer = MockTransformer::new();

        loader
            .expect_into_stream()
            .returning(|| vec![Ok(Node::default())].into());

        cache.expect_get().times(1).returning(|_| false);
        cache.expect_name().returning(|| "test_cache");

        transformer
            .expect_transform_node()
            .returning(|_| Err(anyhow::anyhow!("Transformation failed")));
        transformer.expect_concurrency().returning(|| None);
        transformer
            .expect_name()
            .returning(|| "failing_transformer");

        storage.expect_setup().returning(|| Ok(()));
        storage.expect_batch_size().returning(|| None);

        cache.expect_set().times(0);

        let pipeline = Pipeline::from_loader(loader)
            .filter_cached(cache)
            .then(transformer)
            .then_store_with(storage)
            .filter_errors();

        pipeline.run().await.unwrap();
    }

    #[tokio::test]
    async fn test_nodes_cached_on_successful_storage() {
        let mut loader = MockLoader::new();
        let mut cache = MockNodeCache::new();
        let storage = MemoryStorage::default();

        loader
            .expect_into_stream()
            .returning(|| vec![Ok(Node::default())].into());

        cache.expect_name().returning(|| "test_cache");
        cache.expect_get().times(1).returning(|_| false);

        cache.expect_set().times(1).returning(|_| ());

        let pipeline = Pipeline::from_loader(loader)
            .filter_cached(cache)
            .then_store_with(storage);

        pipeline.run().await.unwrap();
    }
}