1use anyhow::Result;
2use futures_util::{StreamExt, TryFutureExt, TryStreamExt};
3use swiftide_core::{
4 indexing::IndexingDefaults, BatchableTransformer, ChunkerTransformer, Loader, NodeCache,
5 Persist, SimplePrompt, Transformer, WithBatchIndexingDefaults, WithIndexingDefaults,
6};
7use tokio::{sync::mpsc, task};
8use tracing::Instrument;
9
10use std::{sync::Arc, time::Duration};
11
12use swiftide_core::indexing::{EmbedMode, IndexingStream, Node};
13
/// Default number of nodes per batch for batchable transformers and batched
/// storage when no step-specific batch size is configured.
const DEFAULT_BATCH_SIZE: usize = 256;
16
/// A pipeline for indexing: loads nodes, transforms, chunks, and persists them
/// as a stream of steps.
pub struct Pipeline {
    // The stream of nodes flowing through the configured steps.
    stream: IndexingStream,
    // Every storage registered via `then_store_with`; `run` fails if empty.
    storage: Vec<Arc<dyn Persist>>,
    // Default maximum number of nodes processed concurrently per step.
    concurrency: usize,
    // Defaults (e.g. an LLM client) handed to transformers that accept them.
    indexing_defaults: IndexingDefaults,
    // Fallback batch size for batch transformers without their own setting.
    batch_size: usize,
}
35
impl Default for Pipeline {
    /// An empty pipeline: no nodes, no storage, concurrency equal to the
    /// number of CPUs, and [`DEFAULT_BATCH_SIZE`] as the batch size.
    fn default() -> Self {
        Self {
            stream: IndexingStream::empty(),
            storage: Vec::default(),
            // One concurrent task per CPU by default.
            concurrency: num_cpus::get(),
            indexing_defaults: IndexingDefaults::default(),
            batch_size: DEFAULT_BATCH_SIZE,
        }
    }
}
48
49impl Pipeline {
50 pub fn from_loader(loader: impl Loader + 'static) -> Self {
60 let stream = loader.into_stream();
61 Self {
62 stream,
63 ..Default::default()
64 }
65 }
66
67 #[must_use]
70 pub fn with_default_llm_client(mut self, client: impl SimplePrompt + 'static) -> Self {
71 self.indexing_defaults = IndexingDefaults::from_simple_prompt(Box::new(client));
72 self
73 }
74
75 pub fn from_stream(stream: impl Into<IndexingStream>) -> Self {
85 Self {
86 stream: stream.into(),
87 ..Default::default()
88 }
89 }
90
91 #[must_use]
102 pub fn with_concurrency(mut self, concurrency: usize) -> Self {
103 self.concurrency = concurrency;
104 self
105 }
106
107 #[must_use]
120 pub fn with_embed_mode(mut self, embed_mode: EmbedMode) -> Self {
121 self.stream = self
122 .stream
123 .map_ok(move |mut node| {
124 node.embed_mode = embed_mode;
125 node
126 })
127 .boxed()
128 .into();
129 self
130 }
131
    /// Filters the stream through a node cache.
    ///
    /// Nodes already in the cache are dropped from the stream; nodes that are
    /// not are added to the cache *before* they are passed downstream.
    #[must_use]
    pub fn filter_cached(mut self, cache: impl NodeCache + 'static) -> Self {
        let cache = Arc::new(cache);
        self.stream = self
            .stream
            .try_filter_map(move |node| {
                // Each item gets its own handle on the cache for the async block.
                let cache = Arc::clone(&cache);
                let span =
                    tracing::trace_span!("filter_cached", node_cache = ?cache, node = ?node );
                async move {
                    if cache.get(&node).await {
                        tracing::debug!(node = ?node, node_cache = cache.name(), "Node in cache, skipping");
                        // `None` removes the node from the stream.
                        Ok(None)
                    } else {
                        // Mark the node as seen before it continues downstream.
                        cache.set(&node).await;
                        tracing::debug!(node = ?node, node_cache = cache.name(), "Node not in cache, processing");
                        Ok(Some(node))
                    }
                }
                .instrument(span.or_current())
            })
            .boxed()
            .into();
        self
    }
166
    /// Adds a transformer to the pipeline; each node is transformed in its own
    /// spawned Tokio task.
    ///
    /// Up to `concurrency` transformations (the transformer's own setting, or
    /// the pipeline default) run at once; results may arrive out of order.
    #[must_use]
    pub fn then(
        mut self,
        mut transformer: impl Transformer + WithIndexingDefaults + 'static,
    ) -> Self {
        // A transformer may override the pipeline-wide concurrency.
        let concurrency = transformer.concurrency().unwrap_or(self.concurrency);

        transformer.with_indexing_defaults(self.indexing_defaults.clone());

        let transformer = Arc::new(transformer);
        self.stream = self
            .stream
            .map_ok(move |node| {
                let transformer = transformer.clone();
                let span = tracing::trace_span!("then", node = ?node);

                task::spawn(async move {
                    tracing::debug!(node = ?node, transformer = transformer.name(), "Transforming node");
                    transformer.transform_node(node).await
                }.instrument(span.or_current())
                )
                // Convert the spawned task's JoinError into anyhow.
                .err_into::<anyhow::Error>()
            })
            .try_buffer_unordered(concurrency)
            // Flatten the nested Result (join result around transform result).
            .map(|x| x.and_then(|x| x))
            .boxed()
            .into();

        self
    }
208
209 #[must_use]
221 pub fn then_in_batch(
222 mut self,
223 mut transformer: impl BatchableTransformer + WithBatchIndexingDefaults + 'static,
224 ) -> Self {
225 let concurrency = transformer.concurrency().unwrap_or(self.concurrency);
226
227 transformer.with_indexing_defaults(self.indexing_defaults.clone());
228
229 let transformer = Arc::new(transformer);
230 self.stream = self
231 .stream
232 .try_chunks(transformer.batch_size().unwrap_or(self.batch_size))
233 .map_ok(move |nodes| {
234 let transformer = Arc::clone(&transformer);
235 let span = tracing::trace_span!("then_in_batch", nodes = ?nodes );
236
237 tokio::spawn(
238 async move {
239 tracing::debug!(
240 batch_transformer = transformer.name(),
241 num_nodes = nodes.len(),
242 "Batch transforming nodes"
243 );
244 transformer.batch_transform(nodes).await
245 }
246 .instrument(span.or_current()),
247 )
248 .map_err(anyhow::Error::from)
249 })
250 .err_into::<anyhow::Error>()
251 .try_buffer_unordered(concurrency) .try_flatten_unordered(None) .boxed()
254 .into();
255 self
256 }
257
    /// Adds a chunker to the pipeline; each node may be split into zero or
    /// more nodes, each chunking running in its own spawned Tokio task.
    #[must_use]
    pub fn then_chunk(mut self, chunker: impl ChunkerTransformer + 'static) -> Self {
        let chunker = Arc::new(chunker);
        // A chunker may override the pipeline-wide concurrency.
        let concurrency = chunker.concurrency().unwrap_or(self.concurrency);
        self.stream = self
            .stream
            .map_ok(move |node| {
                let chunker = Arc::clone(&chunker);
                let span = tracing::trace_span!("then_chunk", chunker = ?chunker, node = ?node );

                tokio::spawn(
                    async move {
                        tracing::debug!(chunker = chunker.name(), "Chunking node");
                        chunker.transform_node(node).await
                    }
                    .instrument(span.or_current()),
                )
                // Convert the spawned task's JoinError into anyhow.
                .map_err(anyhow::Error::from)
            })
            .err_into::<anyhow::Error>()
            .try_buffer_unordered(concurrency)
            // Each node expands into a stream of chunked nodes; flatten them.
            .try_flatten_unordered(None)
            .boxed()
            .into();

        self
    }
294
295 #[must_use]
310 pub fn then_store_with(mut self, storage: impl Persist + 'static) -> Self {
311 let storage = Arc::new(storage);
312 self.storage.push(storage.clone());
313 if storage.batch_size().is_some() {
315 self.stream = self
316 .stream
317 .try_chunks(storage.batch_size().unwrap())
318 .map_ok(move |nodes| {
319 let storage = Arc::clone(&storage);
320 let span = tracing::trace_span!("then_store_with_batched", storage = ?storage, nodes = ?nodes );
321
322 tokio::spawn(async move {
323 tracing::debug!(storage = storage.name(), num_nodes = nodes.len(), "Batch Storing nodes");
324 storage.batch_store(nodes).await
325 }
326 .instrument(span.or_current())
327 )
328 .map_err(anyhow::Error::from)
329
330 })
331 .err_into::<anyhow::Error>()
332 .try_buffer_unordered(self.concurrency)
333 .try_flatten_unordered(None)
334 .boxed().into();
335 } else {
336 self.stream = self
337 .stream
338 .map_ok(move |node| {
339 let storage = Arc::clone(&storage);
340 let span =
341 tracing::trace_span!("then_store_with", storage = ?storage, node = ?node );
342
343 tokio::spawn(
344 async move {
345 tracing::debug!(storage = storage.name(), "Storing node");
346
347 storage.store(node).await
348 }
349 .instrument(span.or_current()),
350 )
351 .err_into::<anyhow::Error>()
352 })
353 .try_buffer_unordered(self.concurrency)
354 .map(|x| x.and_then(|x| x))
355 .boxed()
356 .into();
357 }
358
359 self
360 }
361
362 #[must_use]
377 pub fn split_by<P>(self, predicate: P) -> (Self, Self)
378 where
379 P: Fn(&Result<Node>) -> bool + Send + Sync + 'static,
380 {
381 let predicate = Arc::new(predicate);
382
383 let (left_tx, left_rx) = mpsc::channel(1000);
384 let (right_tx, right_rx) = mpsc::channel(1000);
385
386 let stream = self.stream;
387 let span = tracing::trace_span!("split_by");
388 tokio::spawn(
389 async move {
390 stream
391 .for_each(move |item| {
392 let predicate = Arc::clone(&predicate);
393 let left_tx = left_tx.clone();
394 let right_tx = right_tx.clone();
395 async move {
396 if predicate(&item) {
397 tracing::debug!(?item, "Sending to left stream");
398 left_tx
399 .send(item)
400 .await
401 .expect("Failed to send to left stream");
402 } else {
403 tracing::debug!(?item, "Sending to right stream");
404 right_tx
405 .send(item)
406 .await
407 .expect("Failed to send to right stream");
408 }
409 }
410 })
411 .await;
412 }
413 .instrument(span.or_current()),
414 );
415
416 let left_pipeline = Self {
417 stream: left_rx.into(),
418 storage: self.storage.clone(),
419 concurrency: self.concurrency,
420 indexing_defaults: self.indexing_defaults.clone(),
421 batch_size: self.batch_size,
422 };
423
424 let right_pipeline = Self {
425 stream: right_rx.into(),
426 storage: self.storage.clone(),
427 concurrency: self.concurrency,
428 indexing_defaults: self.indexing_defaults.clone(),
429 batch_size: self.batch_size,
430 };
431
432 (left_pipeline, right_pipeline)
433 }
434
435 #[must_use]
441 pub fn merge(self, other: Self) -> Self {
442 let stream = tokio_stream::StreamExt::merge(self.stream, other.stream);
443
444 Self {
445 stream: stream.boxed().into(),
446 ..self
447 }
448 }
449
450 #[must_use]
454 pub fn throttle(mut self, duration: impl Into<Duration>) -> Self {
455 self.stream = tokio_stream::StreamExt::throttle(self.stream, duration.into())
456 .boxed()
457 .into();
458 self
459 }
460
461 #[must_use]
466 pub fn filter_errors(mut self) -> Self {
467 self.stream = self
468 .stream
469 .filter_map(|result| async {
470 match result {
471 Ok(node) => Some(Ok(node)),
472 Err(_e) => None,
473 }
474 })
475 .boxed()
476 .into();
477 self
478 }
479
480 #[must_use]
486 pub fn filter<F>(mut self, filter: F) -> Self
487 where
488 F: Fn(&Result<Node>) -> bool + Send + Sync + 'static,
489 {
490 self.stream = self
491 .stream
492 .filter(move |result| {
493 let will_retain = filter(result);
494
495 async move { will_retain }
496 })
497 .boxed()
498 .into();
499 self
500 }
501
502 #[must_use]
506 pub fn log_all(self) -> Self {
507 self.log_errors().log_nodes()
508 }
509
510 #[must_use]
514 pub fn log_errors(mut self) -> Self {
515 self.stream = self
516 .stream
517 .inspect_err(|e| tracing::error!("Error processing node: {:?}", e))
518 .boxed()
519 .into();
520 self
521 }
522
523 #[must_use]
527 pub fn log_nodes(mut self) -> Self {
528 self.stream = self
529 .stream
530 .inspect_ok(|node| tracing::debug!("Processed node: {:?}", node))
531 .boxed()
532 .into();
533 self
534 }
535
    /// Runs the pipeline to completion, consuming it.
    ///
    /// Sets up all configured storages concurrently, then drains the stream,
    /// counting how many nodes were processed.
    ///
    /// # Errors
    ///
    /// Fails if no storage is configured, any storage setup fails, or the
    /// stream yields an error.
    #[tracing::instrument(skip_all, fields(total_nodes), name = "indexing_pipeline.run")]
    pub async fn run(mut self) -> Result<()> {
        tracing::info!(
            "Starting indexing pipeline with {} concurrency",
            self.concurrency
        );
        let now = std::time::Instant::now();
        // Without storage all work would be silently discarded.
        if self.storage.is_empty() {
            anyhow::bail!("No storage configured for indexing pipeline");
        }

        // Run every storage's setup concurrently before pulling any nodes.
        let setup_futures = self
            .storage
            .into_iter()
            .map(|storage| async move { storage.setup().await })
            .collect::<Vec<_>>();
        futures_util::future::try_join_all(setup_futures).await?;

        let mut total_nodes = 0;
        // Draining the stream is what actually drives every configured step.
        while self.stream.try_next().await?.is_some() {
            total_nodes += 1;
        }

        let elapsed_in_seconds = now.elapsed().as_secs();
        tracing::warn!(
            elapsed_in_seconds,
            "Processed {} nodes in {} seconds",
            total_nodes,
            elapsed_in_seconds
        );
        tracing::Span::current().record("total_nodes", total_nodes);

        Ok(())
    }
582}
583
584#[cfg(test)]
585mod tests {
586
587 use super::*;
588 use crate::persist::MemoryStorage;
589 use mockall::Sequence;
590 use swiftide_core::indexing::*;
591
    /// Happy path: one node flows through transformer, batch transformer,
    /// chunker (fan-out to three nodes), and unbatched storage, in sequence.
    #[test_log::test(tokio::test)]
    async fn test_simple_run() {
        let mut loader = MockLoader::new();
        let mut transformer = MockTransformer::new();
        let mut batch_transformer = MockBatchableTransformer::new();
        let mut chunker = MockChunkerTransformer::new();
        let mut storage = MockPersist::new();

        let mut seq = Sequence::new();

        loader
            .expect_into_stream()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| vec![Ok(Node::default())].into());

        transformer.expect_transform_node().returning(|mut node| {
            node.chunk = "transformed".to_string();
            Ok(node)
        });
        transformer.expect_concurrency().returning(|| None);
        transformer.expect_name().returning(|| "transformer");

        batch_transformer
            .expect_batch_transform()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|nodes| IndexingStream::iter(nodes.into_iter().map(Ok)));
        batch_transformer.expect_concurrency().returning(|| None);
        batch_transformer.expect_name().returning(|| "transformer");
        batch_transformer.expect_batch_size().returning(|| None);

        // The chunker fans each node out into three chunked nodes.
        chunker
            .expect_transform_node()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|node| {
                let mut nodes = vec![];
                for i in 0..3 {
                    let mut node = node.clone();
                    node.chunk = format!("transformed_chunk_{i}");
                    nodes.push(Ok(node));
                }
                nodes.into()
            });
        chunker.expect_concurrency().returning(|| None);
        chunker.expect_name().returning(|| "chunker");

        storage.expect_setup().returning(|| Ok(()));
        storage.expect_batch_size().returning(|| None);
        // All three chunked nodes must reach storage.
        storage
            .expect_store()
            .times(3)
            .in_sequence(&mut seq)
            .withf(|node| node.chunk.starts_with("transformed_chunk_"))
            .returning(Ok);
        storage.expect_name().returning(|| "storage");

        let pipeline = Pipeline::from_loader(loader)
            .then(transformer)
            .then_in_batch(batch_transformer)
            .then_chunk(chunker)
            .then_store_with(storage);

        pipeline.run().await.unwrap();
    }
659
    /// A transformer error is dropped by `filter_errors`, so the run succeeds
    /// and nothing is stored.
    #[tokio::test]
    async fn test_skipping_errors() {
        let mut loader = MockLoader::new();
        let mut transformer = MockTransformer::new();
        let mut storage = MockPersist::new();
        let mut seq = Sequence::new();
        loader
            .expect_into_stream()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| vec![Ok(Node::default())].into());
        transformer
            .expect_transform_node()
            .returning(|_node| Err(anyhow::anyhow!("Error transforming node")));
        transformer.expect_concurrency().returning(|| None);
        storage.expect_setup().returning(|| Ok(()));
        storage.expect_batch_size().returning(|| None);
        // The errored node must never reach storage.
        storage.expect_store().times(0).returning(Ok);
        let pipeline = Pipeline::from_loader(loader)
            .then(transformer)
            .then_store_with(storage)
            .filter_errors();
        pipeline.run().await.unwrap();
    }
684
    /// Three nodes are all transformed and stored when the transformer
    /// requests a concurrency of three.
    #[tokio::test]
    async fn test_concurrent_calls_with_simple_transformer() {
        let mut loader = MockLoader::new();
        let mut transformer = MockTransformer::new();
        let mut storage = MockPersist::new();
        let mut seq = Sequence::new();
        loader
            .expect_into_stream()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| {
                vec![
                    Ok(Node::default()),
                    Ok(Node::default()),
                    Ok(Node::default()),
                ]
                .into()
            });
        transformer
            .expect_transform_node()
            .times(3)
            .in_sequence(&mut seq)
            .returning(|mut node| {
                node.chunk = "transformed".to_string();
                Ok(node)
            });
        // Step-level concurrency overrides the pipeline default.
        transformer.expect_concurrency().returning(|| Some(3));
        transformer.expect_name().returning(|| "transformer");
        storage.expect_setup().returning(|| Ok(()));
        storage.expect_batch_size().returning(|| None);
        storage.expect_store().times(3).returning(Ok);
        storage.expect_name().returning(|| "storage");

        let pipeline = Pipeline::from_loader(loader)
            .then(transformer)
            .then_store_with(storage);
        pipeline.run().await.unwrap();
    }
723
724 #[tokio::test]
725 async fn test_arbitrary_closures_as_transformer() {
726 let mut loader = MockLoader::new();
727 let transformer = |node: Node| {
728 let mut node = node;
729 node.chunk = "transformed".to_string();
730 Ok(node)
731 };
732 let storage = MemoryStorage::default();
733 let mut seq = Sequence::new();
734 loader
735 .expect_into_stream()
736 .times(1)
737 .in_sequence(&mut seq)
738 .returning(|| vec![Ok(Node::default())].into());
739
740 let pipeline = Pipeline::from_loader(loader)
741 .then(transformer)
742 .then_store_with(storage.clone());
743 pipeline.run().await.unwrap();
744
745 dbg!(storage.clone());
746 let processed_node = storage.get("0").await.unwrap();
747 assert_eq!(processed_node.chunk, "transformed");
748 }
749
750 #[tokio::test]
751 async fn test_arbitrary_closures_as_batch_transformer() {
752 let mut loader = MockLoader::new();
753 let batch_transformer = |nodes: Vec<Node>| {
754 IndexingStream::iter(nodes.into_iter().map(|mut node| {
755 node.chunk = "transformed".to_string();
756 Ok(node)
757 }))
758 };
759 let storage = MemoryStorage::default();
760 let mut seq = Sequence::new();
761 loader
762 .expect_into_stream()
763 .times(1)
764 .in_sequence(&mut seq)
765 .returning(|| vec![Ok(Node::default())].into());
766
767 let pipeline = Pipeline::from_loader(loader)
768 .then_in_batch(batch_transformer)
769 .then_store_with(storage.clone());
770 pipeline.run().await.unwrap();
771
772 dbg!(storage.clone());
773 let processed_node = storage.get("0").await.unwrap();
774 assert_eq!(processed_node.chunk, "transformed");
775 }
776
    /// The `filter` step drops the node whose chunk is "skip"; the other two
    /// are stored.
    #[tokio::test]
    async fn test_filter_closure() {
        let mut loader = MockLoader::new();
        let storage = MemoryStorage::default();
        let mut seq = Sequence::new();
        loader
            .expect_into_stream()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| {
                vec![
                    Ok(Node::default()),
                    Ok(Node::new("skip")),
                    Ok(Node::default()),
                ]
                .into()
            });
        let pipeline = Pipeline::from_loader(loader)
            .filter(|result| {
                let node = result.as_ref().unwrap();
                node.chunk != "skip"
            })
            .then_store_with(storage.clone());
        pipeline.run().await.unwrap();
        let nodes = storage.get_all().await;
        assert_eq!(nodes.len(), 2);
    }
804
805 #[test_log::test(tokio::test)]
806 async fn test_split_and_merge() {
807 let mut loader = MockLoader::new();
808 let storage = MemoryStorage::default();
809 let mut seq = Sequence::new();
810 loader
811 .expect_into_stream()
812 .times(1)
813 .in_sequence(&mut seq)
814 .returning(|| {
815 vec![
816 Ok(Node::default()),
817 Ok(Node::new("will go left")),
818 Ok(Node::default()),
819 ]
820 .into()
821 });
822
823 let pipeline = Pipeline::from_loader(loader);
824 let (mut left, mut right) = pipeline.split_by(|node| {
825 if let Ok(node) = node {
826 node.chunk.starts_with("will go left")
827 } else {
828 false
829 }
830 });
831
832 left = left
834 .then(move |mut node: Node| {
835 node.chunk = "left".to_string();
836
837 Ok(node)
838 })
839 .log_all();
840
841 right = right.then(move |mut node: Node| {
842 node.chunk = "right".to_string();
843 Ok(node)
844 });
845
846 left.merge(right)
847 .then_store_with(storage.clone())
848 .run()
849 .await
850 .unwrap();
851 dbg!(storage.clone());
852
853 let all_nodes = storage.get_all_values().await;
854 assert_eq!(
855 all_nodes.iter().filter(|node| node.chunk == "left").count(),
856 1
857 );
858 assert_eq!(
859 all_nodes
860 .iter()
861 .filter(|node| node.chunk == "right")
862 .count(),
863 2
864 );
865 }
866
    /// Every pipeline step also accepts boxed trait objects.
    #[tokio::test]
    async fn test_all_steps_should_work_as_dyn_box() {
        let mut loader = MockLoader::new();
        loader
            .expect_into_stream_boxed()
            .returning(|| vec![Ok(Node::default())].into());

        let mut transformer = MockTransformer::new();
        transformer.expect_transform_node().returning(Ok);
        transformer.expect_concurrency().returning(|| None);

        let mut batch_transformer = MockBatchableTransformer::new();
        batch_transformer
            .expect_batch_transform()
            .returning(std::convert::Into::into);
        batch_transformer.expect_concurrency().returning(|| None);
        let mut chunker = MockChunkerTransformer::new();
        chunker
            .expect_transform_node()
            .returning(|node| vec![node].into());
        chunker.expect_concurrency().returning(|| None);

        let mut storage = MockPersist::new();
        storage.expect_setup().returning(|| Ok(()));
        storage.expect_store().returning(Ok);
        storage.expect_batch_size().returning(|| None);

        let pipeline = Pipeline::from_loader(Box::new(loader) as Box<dyn Loader>)
            .then(Box::new(transformer) as Box<dyn Transformer>)
            .then_in_batch(Box::new(batch_transformer) as Box<dyn BatchableTransformer>)
            .then_chunk(Box::new(chunker) as Box<dyn ChunkerTransformer>)
            .then_store_with(Box::new(storage) as Box<dyn Persist>);
        pipeline.run().await.unwrap();
    }
901}