use anyhow::Result;
use futures_util::{StreamExt, TryFutureExt, TryStreamExt};
use swiftide_core::{
    indexing::IndexingDefaults, BatchableTransformer, ChunkerTransformer, Loader, NodeCache,
    Persist, SimplePrompt, Transformer, WithBatchIndexingDefaults, WithIndexingDefaults,
};
use tokio::{sync::mpsc, task};
use tracing::Instrument;

use std::{sync::Arc, time::Duration};

use swiftide_core::indexing::{EmbedMode, IndexingStream, Node};

const DEFAULT_BATCH_SIZE: usize = 256;

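/// A streaming indexing pipeline: nodes come from a [`Loader`], flow through
/// transformers, chunkers, and caches, and are persisted by one or more
/// [`Persist`] backends. Steps are lazy; nothing runs until [`Pipeline::run`].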
pub struct Pipeline {
    stream: IndexingStream,
    storage: Vec<Arc<dyn Persist>>,
    concurrency: usize,
    indexing_defaults: IndexingDefaults,
    batch_size: usize,
}

impl Default for Pipeline {
    fn default() -> Self {
        Self {
            stream: IndexingStream::empty(),
            storage: Vec::default(),
            concurrency: num_cpus::get(),
            indexing_defaults: IndexingDefaults::default(),
            batch_size: DEFAULT_BATCH_SIZE,
        }
    }
}

impl Pipeline {
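    /// Creates a pipeline from a [`Loader`]; the loader's stream becomes the
    /// pipeline's input.
    ///
    /// A minimal sketch (`FileLoader` is illustrative; any loader works):
    ///
    /// ```ignore
    /// let pipeline = Pipeline::from_loader(FileLoader::new("./src"));
    /// ```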
    pub fn from_loader(loader: impl Loader + 'static) -> Self {
        let stream = loader.into_stream();
        Self {
            stream,
            ..Default::default()
        }
    }

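    /// Sets the default LLM client. Transformers that were not given a client
    /// of their own fall back to this one via the pipeline's
    /// [`IndexingDefaults`].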
    #[must_use]
    pub fn with_default_llm_client(mut self, client: impl SimplePrompt + 'static) -> Self {
        self.indexing_defaults = IndexingDefaults::from_simple_prompt(Box::new(client));
        self
    }

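    /// Creates a pipeline from anything that converts into an
    /// [`IndexingStream`], such as a `Vec<Result<Node>>`.
    ///
    /// ```ignore
    /// let pipeline = Pipeline::from_stream(vec![Ok(Node::default())]);
    /// ```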
    pub fn from_stream(stream: impl Into<IndexingStream>) -> Self {
        Self {
            stream: stream.into(),
            ..Default::default()
        }
    }

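    /// Overrides the pipeline-wide concurrency, which defaults to the number
    /// of CPUs. Individual transformers can still override this per step.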
    #[must_use]
    pub fn with_concurrency(mut self, concurrency: usize) -> Self {
        self.concurrency = concurrency;
        self
    }

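    /// Sets the [`EmbedMode`] on every node that flows through the pipeline,
    /// controlling how nodes will be embedded downstream.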
    #[must_use]
    pub fn with_embed_mode(mut self, embed_mode: EmbedMode) -> Self {
        self.stream = self
            .stream
            .map_ok(move |mut node| {
                node.embed_mode = embed_mode;
                node
            })
            .boxed()
            .into();
        self
    }

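    /// Filters out nodes that are already present in the [`NodeCache`]; nodes
    /// that are not cached are recorded and passed on. Note that a node is
    /// marked as cached *before* the rest of the pipeline processes it, so a
    /// downstream failure will not cause it to be reprocessed on the next run.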
    #[must_use]
    pub fn filter_cached(mut self, cache: impl NodeCache + 'static) -> Self {
        let cache = Arc::new(cache);
        self.stream = self
            .stream
            .try_filter_map(move |node| {
                let cache = Arc::clone(&cache);
                let span =
                    tracing::trace_span!("filter_cached", node_cache = ?cache, node = ?node);
                async move {
                    if cache.get(&node).await {
                        tracing::debug!(node = ?node, node_cache = cache.name(), "Node in cache, skipping");
                        Ok(None)
                    } else {
                        cache.set(&node).await;
                        tracing::debug!(node = ?node, node_cache = cache.name(), "Node not in cache, processing");
                        Ok(Some(node))
                    }
                }
                .instrument(span.or_current())
            })
            .boxed()
            .into();
        self
    }

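    /// Applies a [`Transformer`] to every node, spawning each transformation
    /// on its own task and buffering up to `concurrency` of them at a time.
    ///
    /// Closures over a single `Node` also act as transformers; a minimal sketch:
    ///
    /// ```ignore
    /// let pipeline = pipeline.then(|mut node: Node| {
    ///     node.chunk = node.chunk.to_uppercase();
    ///     Ok(node)
    /// });
    /// ```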
    #[must_use]
    pub fn then(
        mut self,
        mut transformer: impl Transformer + WithIndexingDefaults + 'static,
    ) -> Self {
        let concurrency = transformer.concurrency().unwrap_or(self.concurrency);

        transformer.with_indexing_defaults(self.indexing_defaults.clone());

        let transformer = Arc::new(transformer);
        self.stream = self
            .stream
            .map_ok(move |node| {
                let transformer = transformer.clone();
                let span = tracing::trace_span!("then", node = ?node);

                task::spawn(
                    async move {
                        tracing::debug!(node = ?node, transformer = transformer.name(), "Transforming node");
                        transformer.transform_node(node).await
                    }
                    .instrument(span.or_current()),
                )
                .err_into::<anyhow::Error>()
            })
            .try_buffer_unordered(concurrency)
            // Flatten the nested `Result` from the spawned task into a single `Result<Node>`
            .map(|x| x.and_then(|x| x))
            .boxed()
            .into();

        self
    }

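    /// Applies a [`BatchableTransformer`] to chunks of nodes. The batch size
    /// comes from the transformer if it defines one, otherwise from the
    /// pipeline default.
    ///
    /// A minimal sketch using a closure over a whole batch:
    ///
    /// ```ignore
    /// let pipeline = pipeline.then_in_batch(|nodes: Vec<Node>| {
    ///     IndexingStream::iter(nodes.into_iter().map(Ok))
    /// });
    /// ```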
    #[must_use]
    pub fn then_in_batch(
        mut self,
        mut transformer: impl BatchableTransformer + WithBatchIndexingDefaults + 'static,
    ) -> Self {
        let concurrency = transformer.concurrency().unwrap_or(self.concurrency);

        transformer.with_indexing_defaults(self.indexing_defaults.clone());

        let transformer = Arc::new(transformer);
        self.stream = self
            .stream
            .try_chunks(transformer.batch_size().unwrap_or(self.batch_size))
            .map_ok(move |nodes| {
                let transformer = Arc::clone(&transformer);
                let span = tracing::trace_span!("then_in_batch", nodes = ?nodes);

                tokio::spawn(
                    async move {
                        tracing::debug!(
                            batch_transformer = transformer.name(),
                            num_nodes = nodes.len(),
                            "Batch transforming nodes"
                        );
                        transformer.batch_transform(nodes).await
                    }
                    .instrument(span.or_current()),
                )
                .map_err(anyhow::Error::from)
            })
            .err_into::<anyhow::Error>()
            .try_buffer_unordered(concurrency)
            .try_flatten_unordered(None)
            .boxed()
            .into();
        self
    }

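    /// Splits each node into smaller nodes with a [`ChunkerTransformer`]; each
    /// input node may yield many output nodes, which are flattened back into
    /// the stream.
    ///
    /// ```ignore
    /// // Illustrative only; assumes a chunker such as transformers::ChunkMarkdown
    /// // is available in your build.
    /// let pipeline = pipeline.then_chunk(ChunkMarkdown::from_chunk_range(10..512));
    /// ```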
    #[must_use]
    pub fn then_chunk(mut self, chunker: impl ChunkerTransformer + 'static) -> Self {
        let chunker = Arc::new(chunker);
        let concurrency = chunker.concurrency().unwrap_or(self.concurrency);
        self.stream = self
            .stream
            .map_ok(move |node| {
                let chunker = Arc::clone(&chunker);
                let span = tracing::trace_span!("then_chunk", chunker = ?chunker, node = ?node);

                tokio::spawn(
                    async move {
                        tracing::debug!(chunker = chunker.name(), "Chunking node");
                        chunker.transform_node(node).await
                    }
                    .instrument(span.or_current()),
                )
                .map_err(anyhow::Error::from)
            })
            .err_into::<anyhow::Error>()
            .try_buffer_unordered(concurrency)
            .try_flatten_unordered(None)
            .boxed()
            .into();

        self
    }

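    /// Persists nodes with a [`Persist`] backend. If the backend reports a
    /// batch size, nodes are stored in batches; otherwise they are stored one
    /// by one. Multiple storage backends can be added.
    ///
    /// ```ignore
    /// // Illustrative; MemoryStorage is the in-crate demo/test backend.
    /// let pipeline = pipeline.then_store_with(MemoryStorage::default());
    /// ```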
    #[must_use]
    pub fn then_store_with(mut self, storage: impl Persist + 'static) -> Self {
        let storage = Arc::new(storage);
        self.storage.push(storage.clone());
        if let Some(batch_size) = storage.batch_size() {
            self.stream = self
                .stream
                .try_chunks(batch_size)
                .map_ok(move |nodes| {
                    let storage = Arc::clone(&storage);
                    let span = tracing::trace_span!("then_store_with_batched", storage = ?storage, nodes = ?nodes);

                    tokio::spawn(
                        async move {
                            tracing::debug!(storage = storage.name(), num_nodes = nodes.len(), "Batch storing nodes");
                            storage.batch_store(nodes).await
                        }
                        .instrument(span.or_current()),
                    )
                    .map_err(anyhow::Error::from)
                })
                .err_into::<anyhow::Error>()
                .try_buffer_unordered(self.concurrency)
                .try_flatten_unordered(None)
                .boxed()
                .into();
        } else {
            self.stream = self
                .stream
                .map_ok(move |node| {
                    let storage = Arc::clone(&storage);
                    let span =
                        tracing::trace_span!("then_store_with", storage = ?storage, node = ?node);

                    tokio::spawn(
                        async move {
                            tracing::debug!(storage = storage.name(), "Storing node");

                            storage.store(node).await
                        }
                        .instrument(span.or_current()),
                    )
                    .err_into::<anyhow::Error>()
                })
                .try_buffer_unordered(self.concurrency)
                // Flatten the nested `Result` from the spawned task into a single `Result<Node>`
                .map(|x| x.and_then(|x| x))
                .boxed()
                .into();
        }

        self
    }

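    /// Splits the stream in two on a predicate: items for which the predicate
    /// returns `true` go to the left pipeline, the rest to the right. Both
    /// sides inherit the storage, concurrency, and defaults, fed through
    /// bounded channels from a spawned forwarding task.
    ///
    /// ```ignore
    /// let (markdown, rest) = pipeline.split_by(|result| {
    ///     result
    ///         .as_ref()
    ///         .map(|node| node.chunk.starts_with("#"))
    ///         .unwrap_or(false)
    /// });
    /// ```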
    #[must_use]
    pub fn split_by<P>(self, predicate: P) -> (Self, Self)
    where
        P: Fn(&Result<Node>) -> bool + Send + Sync + 'static,
    {
        let predicate = Arc::new(predicate);

        let (left_tx, left_rx) = mpsc::channel(1000);
        let (right_tx, right_rx) = mpsc::channel(1000);

        let stream = self.stream;
        let span = tracing::trace_span!("split_by");
        tokio::spawn(
            async move {
                stream
                    .for_each_concurrent(self.concurrency, move |item| {
                        let predicate = Arc::clone(&predicate);
                        let left_tx = left_tx.clone();
                        let right_tx = right_tx.clone();
                        async move {
                            if predicate(&item) {
                                tracing::debug!(?item, "Sending to left stream");
                                left_tx
                                    .send(item)
                                    .await
                                    .expect("Failed to send to left stream");
                            } else {
                                tracing::debug!(?item, "Sending to right stream");
                                right_tx
                                    .send(item)
                                    .await
                                    .expect("Failed to send to right stream");
                            }
                        }
                    })
                    .await;
            }
            .instrument(span.or_current()),
        );

        let left_pipeline = Self {
            stream: left_rx.into(),
            storage: self.storage.clone(),
            concurrency: self.concurrency,
            indexing_defaults: self.indexing_defaults.clone(),
            batch_size: self.batch_size,
        };

        let right_pipeline = Self {
            stream: right_rx.into(),
            storage: self.storage.clone(),
            concurrency: self.concurrency,
            indexing_defaults: self.indexing_defaults.clone(),
            batch_size: self.batch_size,
        };

        (left_pipeline, right_pipeline)
    }

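    /// Merges two pipelines into one, interleaving both streams; typically
    /// used to rejoin the halves of a [`Pipeline::split_by`].
    ///
    /// ```ignore
    /// let pipeline = left.merge(right);
    /// ```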
    #[must_use]
    pub fn merge(self, other: Self) -> Self {
        let stream = tokio_stream::StreamExt::merge(self.stream, other.stream);

        Self {
            stream: stream.boxed().into(),
            ..self
        }
    }

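    /// Throttles the stream, yielding at most one node per `duration`.
    ///
    /// ```ignore
    /// let pipeline = pipeline.throttle(std::time::Duration::from_millis(100));
    /// ```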
    #[must_use]
    pub fn throttle(mut self, duration: impl Into<Duration>) -> Self {
        self.stream = tokio_stream::StreamExt::throttle(self.stream, duration.into())
            .boxed()
            .into();
        self
    }

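    /// Silently drops all errored items from the stream, letting the pipeline
    /// continue instead of failing on the first error.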
    #[must_use]
    pub fn filter_errors(mut self) -> Self {
        self.stream = self
            .stream
            .filter_map(|result| async {
                match result {
                    Ok(node) => Some(Ok(node)),
                    Err(_e) => None,
                }
            })
            .boxed()
            .into();
        self
    }

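    /// Keeps only the items for which `filter` returns `true`; both `Ok` and
    /// `Err` items are passed to the predicate.
    ///
    /// ```ignore
    /// let pipeline = pipeline.filter(|result| {
    ///     result
    ///         .as_ref()
    ///         .map(|node| !node.chunk.is_empty())
    ///         .unwrap_or(true)
    /// });
    /// ```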
    #[must_use]
    pub fn filter<F>(mut self, filter: F) -> Self
    where
        F: Fn(&Result<Node>) -> bool + Send + Sync + 'static,
    {
        self.stream = self
            .stream
            .filter(move |result| {
                let will_retain = filter(result);

                async move { will_retain }
            })
            .boxed()
            .into();
        self
    }

    #[must_use]
    pub fn log_all(self) -> Self {
        self.log_errors().log_nodes()
    }

    #[must_use]
    pub fn log_errors(mut self) -> Self {
        self.stream = self
            .stream
            .inspect_err(|e| tracing::error!("Error processing node: {:?}", e))
            .boxed()
            .into();
        self
    }

    #[must_use]
    pub fn log_nodes(mut self) -> Self {
        self.stream = self
            .stream
            .inspect_ok(|node| tracing::debug!("Processed node: {:?}", node))
            .boxed()
            .into();
        self
    }

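    /// Runs the pipeline to completion: sets up all configured storage
    /// backends concurrently, then drains the stream, counting processed
    /// nodes. Bails early if no storage was configured, and returns the first
    /// error encountered in the stream.
    ///
    /// ```ignore
    /// pipeline.run().await?;
    /// ```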
    #[tracing::instrument(skip_all, fields(total_nodes), name = "indexing_pipeline.run")]
    pub async fn run(mut self) -> Result<()> {
        tracing::info!(
            "Starting indexing pipeline with {} concurrency",
            self.concurrency
        );
        let now = std::time::Instant::now();
        if self.storage.is_empty() {
            anyhow::bail!("No storage configured for indexing pipeline");
        }

        let setup_futures = self
            .storage
            .into_iter()
            .map(|storage| async move { storage.setup().await })
            .collect::<Vec<_>>();
        futures_util::future::try_join_all(setup_futures).await?;

        let mut total_nodes = 0;
        while self.stream.try_next().await?.is_some() {
            total_nodes += 1;
        }

        let elapsed_in_seconds = now.elapsed().as_secs();
        tracing::warn!(
            elapsed_in_seconds,
            "Processed {} nodes in {} seconds",
            total_nodes,
            elapsed_in_seconds
        );
        tracing::Span::current().record("total_nodes", total_nodes);

        Ok(())
    }
}

#[cfg(test)]
mod tests {

    use super::*;
    use crate::persist::MemoryStorage;
    use mockall::Sequence;
    use swiftide_core::indexing::*;

    #[test_log::test(tokio::test)]
    async fn test_simple_run() {
        let mut loader = MockLoader::new();
        let mut transformer = MockTransformer::new();
        let mut batch_transformer = MockBatchableTransformer::new();
        let mut chunker = MockChunkerTransformer::new();
        let mut storage = MockPersist::new();

        let mut seq = Sequence::new();

        loader
            .expect_into_stream()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| vec![Ok(Node::default())].into());

        transformer.expect_transform_node().returning(|mut node| {
            node.chunk = "transformed".to_string();
            Ok(node)
        });
        transformer.expect_concurrency().returning(|| None);
        transformer.expect_name().returning(|| "transformer");

        batch_transformer
            .expect_batch_transform()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|nodes| IndexingStream::iter(nodes.into_iter().map(Ok)));
        batch_transformer.expect_concurrency().returning(|| None);
        batch_transformer.expect_name().returning(|| "transformer");
        batch_transformer.expect_batch_size().returning(|| None);

        chunker
            .expect_transform_node()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|node| {
                let mut nodes = vec![];
                for i in 0..3 {
                    let mut node = node.clone();
                    node.chunk = format!("transformed_chunk_{i}");
                    nodes.push(Ok(node));
                }
                nodes.into()
            });
        chunker.expect_concurrency().returning(|| None);
        chunker.expect_name().returning(|| "chunker");

        storage.expect_setup().returning(|| Ok(()));
        storage.expect_batch_size().returning(|| None);
        storage
            .expect_store()
            .times(3)
            .in_sequence(&mut seq)
            .withf(|node| node.chunk.starts_with("transformed_chunk_"))
            .returning(Ok);
        storage.expect_name().returning(|| "storage");

        let pipeline = Pipeline::from_loader(loader)
            .then(transformer)
            .then_in_batch(batch_transformer)
            .then_chunk(chunker)
            .then_store_with(storage);

        pipeline.run().await.unwrap();
    }

    #[tokio::test]
    async fn test_skipping_errors() {
        let mut loader = MockLoader::new();
        let mut transformer = MockTransformer::new();
        let mut storage = MockPersist::new();
        let mut seq = Sequence::new();
        loader
            .expect_into_stream()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| vec![Ok(Node::default())].into());
        transformer
            .expect_transform_node()
            .returning(|_node| Err(anyhow::anyhow!("Error transforming node")));
        transformer.expect_concurrency().returning(|| None);
        storage.expect_setup().returning(|| Ok(()));
        storage.expect_batch_size().returning(|| None);
        storage.expect_store().times(0).returning(Ok);
        let pipeline = Pipeline::from_loader(loader)
            .then(transformer)
            .then_store_with(storage)
            .filter_errors();
        pipeline.run().await.unwrap();
    }

    #[tokio::test]
    async fn test_concurrent_calls_with_simple_transformer() {
        let mut loader = MockLoader::new();
        let mut transformer = MockTransformer::new();
        let mut storage = MockPersist::new();
        let mut seq = Sequence::new();
        loader
            .expect_into_stream()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| {
                vec![
                    Ok(Node::default()),
                    Ok(Node::default()),
                    Ok(Node::default()),
                ]
                .into()
            });
        transformer
            .expect_transform_node()
            .times(3)
            .in_sequence(&mut seq)
            .returning(|mut node| {
                node.chunk = "transformed".to_string();
                Ok(node)
            });
        transformer.expect_concurrency().returning(|| Some(3));
        transformer.expect_name().returning(|| "transformer");
        storage.expect_setup().returning(|| Ok(()));
        storage.expect_batch_size().returning(|| None);
        storage.expect_store().times(3).returning(Ok);
        storage.expect_name().returning(|| "storage");

        let pipeline = Pipeline::from_loader(loader)
            .then(transformer)
            .then_store_with(storage);
        pipeline.run().await.unwrap();
    }

    #[tokio::test]
    async fn test_arbitrary_closures_as_transformer() {
        let mut loader = MockLoader::new();
        let transformer = |mut node: Node| {
            node.chunk = "transformed".to_string();
            Ok(node)
        };
        let storage = MemoryStorage::default();
        let mut seq = Sequence::new();
        loader
            .expect_into_stream()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| vec![Ok(Node::default())].into());

        let pipeline = Pipeline::from_loader(loader)
            .then(transformer)
            .then_store_with(storage.clone());
        pipeline.run().await.unwrap();

        dbg!(storage.clone());
        let processed_node = storage.get("0").await.unwrap();
        assert_eq!(processed_node.chunk, "transformed");
    }

    #[tokio::test]
    async fn test_arbitrary_closures_as_batch_transformer() {
        let mut loader = MockLoader::new();
        let batch_transformer = |nodes: Vec<Node>| {
            IndexingStream::iter(nodes.into_iter().map(|mut node| {
                node.chunk = "transformed".to_string();
                Ok(node)
            }))
        };
        let storage = MemoryStorage::default();
        let mut seq = Sequence::new();
        loader
            .expect_into_stream()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| vec![Ok(Node::default())].into());

        let pipeline = Pipeline::from_loader(loader)
            .then_in_batch(batch_transformer)
            .then_store_with(storage.clone());
        pipeline.run().await.unwrap();

        dbg!(storage.clone());
        let processed_node = storage.get("0").await.unwrap();
        assert_eq!(processed_node.chunk, "transformed");
    }

    #[tokio::test]
    async fn test_filter_closure() {
        let mut loader = MockLoader::new();
        let storage = MemoryStorage::default();
        let mut seq = Sequence::new();
        loader
            .expect_into_stream()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| {
                vec![
                    Ok(Node::default()),
                    Ok(Node::new("skip")),
                    Ok(Node::default()),
                ]
                .into()
            });
        let pipeline = Pipeline::from_loader(loader)
            .filter(|result| {
                let node = result.as_ref().unwrap();
                node.chunk != "skip"
            })
            .then_store_with(storage.clone());
        pipeline.run().await.unwrap();
        let nodes = storage.get_all().await;
        assert_eq!(nodes.len(), 2);
    }

    #[test_log::test(tokio::test)]
    async fn test_split_and_merge() {
        let mut loader = MockLoader::new();
        let storage = MemoryStorage::default();
        let mut seq = Sequence::new();
        loader
            .expect_into_stream()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| {
                vec![
                    Ok(Node::default()),
                    Ok(Node::new("will go left")),
                    Ok(Node::default()),
                ]
                .into()
            });

        let pipeline = Pipeline::from_loader(loader);
        let (mut left, mut right) = pipeline.split_by(|node| {
            if let Ok(node) = node {
                node.chunk.starts_with("will go left")
            } else {
                false
            }
        });

        left = left
            .then(move |mut node: Node| {
                node.chunk = "left".to_string();

                Ok(node)
            })
            .log_all();

        right = right.then(move |mut node: Node| {
            node.chunk = "right".to_string();
            Ok(node)
        });

        left.merge(right)
            .then_store_with(storage.clone())
            .run()
            .await
            .unwrap();
        dbg!(storage.clone());

        let all_nodes = storage.get_all_values().await;
        assert_eq!(
            all_nodes.iter().filter(|node| node.chunk == "left").count(),
            1
        );
        assert_eq!(
            all_nodes
                .iter()
                .filter(|node| node.chunk == "right")
                .count(),
            2
        );
    }

    #[tokio::test]
    async fn test_all_steps_should_work_as_dyn_box() {
        let mut loader = MockLoader::new();
        loader
            .expect_into_stream_boxed()
            .returning(|| vec![Ok(Node::default())].into());

        let mut transformer = MockTransformer::new();
        transformer.expect_transform_node().returning(Ok);
        transformer.expect_concurrency().returning(|| None);

        let mut batch_transformer = MockBatchableTransformer::new();
        batch_transformer
            .expect_batch_transform()
            .returning(std::convert::Into::into);
        batch_transformer.expect_concurrency().returning(|| None);
        let mut chunker = MockChunkerTransformer::new();
        chunker
            .expect_transform_node()
            .returning(|node| vec![node].into());
        chunker.expect_concurrency().returning(|| None);

        let mut storage = MockPersist::new();
        storage.expect_setup().returning(|| Ok(()));
        storage.expect_store().returning(Ok);
        storage.expect_batch_size().returning(|| None);

        let pipeline = Pipeline::from_loader(Box::new(loader) as Box<dyn Loader>)
            .then(Box::new(transformer) as Box<dyn Transformer>)
            .then_in_batch(Box::new(batch_transformer) as Box<dyn BatchableTransformer>)
            .then_chunk(Box::new(chunker) as Box<dyn ChunkerTransformer>)
            .then_store_with(Box::new(storage) as Box<dyn Persist>);
        pipeline.run().await.unwrap();
    }
}