1use anyhow::Result;
2use futures_util::{StreamExt, TryFutureExt, TryStreamExt};
3use swiftide_core::{
4 BatchableTransformer, ChunkerTransformer, Loader, NodeCache, Persist, SimplePrompt,
5 Transformer, WithBatchIndexingDefaults, WithIndexingDefaults,
6 indexing::{Chunk, IndexingDefaults},
7};
8use tokio::{
9 sync::{Mutex, mpsc},
10 task,
11};
12use tracing::Instrument;
13
14use std::{pin::Pin, sync::Arc, time::Duration};
15
16use swiftide_core::indexing::{EmbedMode, IndexingStream, Node};
17
// Creates a trace-level span for a pipeline operation.
//
// With a step argument the `otel.name` field becomes `"<operation>.<step name>"`;
// without one it is just the operation name.
macro_rules! trace_span {
    ($operation:literal, $stage:expr) => {
        tracing::trace_span!(
            $operation,
            "otel.name" = format!("{}.{}", $operation, $stage.name()),
        )
    };

    ($operation:literal) => {
        tracing::trace_span!($operation, "otel.name" = format!("{}", $operation),)
    };
}
27
// Emits a trace event for a single node, recording the node itself, its id and
// the name of the step that is handling it.
macro_rules! node_trace_log {
    ($stage:expr, $item:expr, $message:literal) => {
        tracing::trace!(
            node = ?$item,
            node_id = ?$item.id(),
            step = $stage.name(),
            $message
        )
    };
}
38
// Emits a trace event for a batch of nodes, recording the batch size, the nodes
// and the name of the step that is handling them.
macro_rules! batch_node_trace_log {
    ($stage:expr, $batch:expr, $message:literal) => {
        tracing::trace!(
            batch_size = $batch.len(),
            nodes = ?$batch,
            step = $stage.name(),
            $message
        )
    };
}
44
// Builds a new `Pipeline` around `$next_stream`, carrying over every other
// setting (storage setup functions, concurrency, indexing defaults, batch size)
// from `$source`. Used by steps that change the node type of the stream.
macro_rules! pipeline_with_new_stream {
    ($source:expr, $next_stream:expr) => {
        Pipeline {
            stream: $next_stream.into(),
            storage_setup_fns: $source.storage_setup_fns.clone(),
            concurrency: $source.concurrency,
            indexing_defaults: $source.indexing_defaults.clone(),
            batch_size: $source.batch_size,
        }
    };
}
56
/// Batch size a pipeline starts with; batch transformers that do not report
/// their own batch size fall back to the pipeline's value.
const DEFAULT_BATCH_SIZE: usize = 256;
59
/// A stream-based indexing pipeline over nodes with chunk type `T`.
///
/// Steps such as `then`, `then_in_batch`, `then_chunk` and `then_store_with`
/// each wrap the current stream in a new one; nothing is consumed until `run`
/// is called.
pub struct Pipeline<T: Chunk> {
    /// The stream of (fallible) nodes flowing through the pipeline.
    stream: IndexingStream<T>,
    /// Setup callbacks registered by `then_store_with`; `run` awaits all of
    /// them before consuming the stream.
    storage_setup_fns: Vec<DynStorageSetupFn>,
    /// Concurrency used by steps that do not specify their own.
    concurrency: usize,
    /// Defaults handed to transformers via `with_indexing_defaults`.
    indexing_defaults: IndexingDefaults,
    /// Batch size used by `then_in_batch` when the transformer has none.
    batch_size: usize,
}
80
/// A shareable, thread-safe callback that returns a boxed future performing
/// storage setup. Stored per pipeline and invoked by `run`.
type DynStorageSetupFn =
    Arc<dyn Fn() -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync>;
83
84impl<T: Chunk> Default for Pipeline<T> {
85 fn default() -> Self {
88 Self {
89 stream: IndexingStream::<T>::empty(),
90 storage_setup_fns: Vec::new(),
91 concurrency: num_cpus::get(),
92 indexing_defaults: IndexingDefaults::default(),
93 batch_size: DEFAULT_BATCH_SIZE,
94 }
95 }
96}
97
98impl<T: Chunk> Pipeline<T> {
99 pub fn from_loader(loader: impl Loader<Output = T> + 'static) -> Self {
109 let stream = loader.into_stream();
110 Self {
111 stream,
112 ..Default::default()
113 }
114 }
115
116 #[must_use]
119 pub fn with_default_llm_client(mut self, client: impl SimplePrompt + 'static) -> Self {
120 self.indexing_defaults = IndexingDefaults::from_simple_prompt(Box::new(client));
121 self
122 }
123
124 pub fn from_stream(stream: impl Into<IndexingStream<T>>) -> Self {
134 Self {
135 stream: stream.into(),
136 ..Default::default()
137 }
138 }
139
140 #[must_use]
151 pub fn with_concurrency(mut self, concurrency: usize) -> Self {
152 self.concurrency = concurrency;
153 self
154 }
155
156 #[must_use]
169 pub fn with_embed_mode(mut self, embed_mode: EmbedMode) -> Self {
170 self.stream = self
171 .stream
172 .map_ok(move |mut node| {
173 node.embed_mode = embed_mode;
174 node
175 })
176 .boxed()
177 .into();
178 self
179 }
180
181 #[must_use]
191 pub fn filter_cached(mut self, cache: impl NodeCache<Input = T> + 'static) -> Self {
192 let cache = Arc::new(cache);
193 self.stream = self
194 .stream
195 .try_filter_map(move |node| {
196 let cache = Arc::clone(&cache);
197 let span = trace_span!("filter_cached", cache);
198
199 async move {
200 if cache.get(&node).await {
201 node_trace_log!(cache, node, "node in cache, skipping");
202 Ok(None)
203 } else {
204 node_trace_log!(cache, node, "node not in cache, processing");
205 cache.set(&node).await;
206 Ok(Some(node))
207 }
208 }
209 .instrument(span.or_current())
210 })
211 .boxed()
212 .into();
213 self
214 }
215
216 #[must_use]
228 pub fn then<Output: Chunk>(
229 self,
230 mut transformer: impl Transformer<Input = T, Output = Output> + WithIndexingDefaults + 'static,
231 ) -> Pipeline<Output> {
232 let concurrency = transformer.concurrency().unwrap_or(self.concurrency);
233
234 transformer.with_indexing_defaults(self.indexing_defaults.clone());
235
236 let transformer = Arc::new(transformer);
237 let stream = self
238 .stream
239 .map_ok(move |node| {
240 let transformer = transformer.clone();
241 let span = trace_span!("then", transformer);
242
243 task::spawn(
244 async move {
245 node_trace_log!(transformer, node, "Transforming node");
246 transformer.transform_node(node).await
247 }
248 .instrument(span.or_current()),
249 )
250 .err_into::<anyhow::Error>()
251 })
252 .try_buffer_unordered(concurrency)
253 .map(|x| x.and_then(|x| x));
254
255 pipeline_with_new_stream!(self, stream.boxed())
256 }
257
258 #[must_use]
272 pub fn then_in_batch<Output: Chunk>(
273 self,
274 mut transformer: impl BatchableTransformer<Input = T, Output = Output>
275 + WithBatchIndexingDefaults
276 + 'static,
277 ) -> Pipeline<Output> {
278 let concurrency = transformer.concurrency().unwrap_or(self.concurrency);
279
280 transformer.with_indexing_defaults(self.indexing_defaults.clone());
281
282 let transformer = Arc::new(transformer);
283 let stream = self
284 .stream
285 .try_chunks(transformer.batch_size().unwrap_or(self.batch_size))
286 .map_ok(move |nodes| {
287 let transformer = Arc::clone(&transformer);
288 let span = trace_span!("then_in_batch", transformer);
289
290 tokio::spawn(
291 async move {
292 batch_node_trace_log!(transformer, nodes, "batch transforming nodes");
293 transformer.batch_transform(nodes).await
294 }
295 .instrument(span.or_current()),
296 )
297 .map_err(anyhow::Error::from)
298 })
299 .err_into::<anyhow::Error>()
300 .try_buffer_unordered(concurrency) .try_flatten_unordered(None) .boxed();
303
304 pipeline_with_new_stream!(self, stream)
305 }
306
307 #[must_use]
318 pub fn then_chunk<Output: Chunk>(
319 self,
320 chunker: impl ChunkerTransformer<Input = T, Output = Output> + 'static,
321 ) -> Pipeline<Output> {
322 let chunker = Arc::new(chunker);
323 let concurrency = chunker.concurrency().unwrap_or(self.concurrency);
324 let stream = self
325 .stream
326 .map_ok(move |node| {
327 let chunker = Arc::clone(&chunker);
328 let span = trace_span!("then_chunk", chunker);
329
330 tokio::spawn(
331 async move {
332 node_trace_log!(chunker, node, "Chunking node");
333 chunker.transform_node(node).await
334 }
335 .instrument(span.or_current()),
336 )
337 .map_err(anyhow::Error::from)
338 })
339 .err_into::<anyhow::Error>()
340 .try_buffer_unordered(concurrency)
341 .try_flatten_unordered(None);
342
343 pipeline_with_new_stream!(self, stream.boxed())
344 }
345
346 #[must_use]
361 pub fn then_store_with<Output: Chunk>(
362 mut self,
363 storage: impl Persist<Input = T, Output = Output> + 'static,
364 ) -> Pipeline<Output> {
365 let storage = Arc::new(storage);
366
367 let storage_closure = storage.clone();
368
369 let completed = Arc::new(Mutex::new(false));
371 let setup_fn: DynStorageSetupFn = Arc::new(move || {
372 let completed = Arc::clone(&completed);
373 let storage_closure = Arc::clone(&storage_closure);
374 Box::pin(async move {
375 let mut lock = completed.lock().await;
376
377 tracing::trace!(?storage_closure, "Setting up storage");
378 storage_closure.setup().await?;
379 *lock = true;
380 Ok(())
381 })
382 });
383 self.storage_setup_fns.push(setup_fn);
384
385 let stream = if storage.batch_size().is_some() {
387 self.stream
388 .try_chunks(storage.batch_size().unwrap())
389 .map_ok(move |nodes| {
390 let storage = Arc::clone(&storage);
391 let span = trace_span!("then_store_with_batched", storage);
392
393 tokio::spawn(
394 async move {
395 batch_node_trace_log!(storage, nodes, "batch storing nodes");
396 storage.batch_store(nodes).await
397 }
398 .instrument(span.or_current()),
399 )
400 .map_err(anyhow::Error::from)
401 })
402 .err_into::<anyhow::Error>()
403 .try_buffer_unordered(self.concurrency)
404 .try_flatten_unordered(None)
405 .boxed()
406 } else {
407 self.stream
408 .map_ok(move |node| {
409 let storage = Arc::clone(&storage);
410 let span = trace_span!("then_store_with", storage);
411
412 tokio::spawn(
413 async move {
414 node_trace_log!(storage, node, "Storing node");
415
416 storage.store(node).await
417 }
418 .instrument(span.or_current()),
419 )
420 .err_into::<anyhow::Error>()
421 })
422 .try_buffer_unordered(self.concurrency)
423 .map(|x| x.and_then(|x| x))
424 .boxed()
425 };
426
427 pipeline_with_new_stream!(self, stream)
428 }
429
430 #[must_use]
445 pub fn split_by<P>(self, predicate: P) -> (Self, Self)
446 where
447 P: Fn(&Result<Node<T>>) -> bool + Send + Sync + 'static,
448 {
449 let predicate = Arc::new(predicate);
450
451 let (left_tx, left_rx) = mpsc::channel(1000);
452 let (right_tx, right_rx) = mpsc::channel(1000);
453
454 let stream = self.stream;
455 let span = trace_span!("split_by");
456 tokio::spawn(
457 async move {
458 stream
459 .for_each_concurrent(self.concurrency, move |item| {
460 let predicate = Arc::clone(&predicate);
461 let left_tx = left_tx.clone();
462 let right_tx = right_tx.clone();
463 async move {
464 if predicate(&item) {
465 tracing::trace!(?item, "Sending to left stream");
466 left_tx
467 .send(item)
468 .await
469 .expect("Failed to send to left stream");
470 } else {
471 tracing::trace!(?item, "Sending to right stream");
472 right_tx
473 .send(item)
474 .await
475 .expect("Failed to send to right stream");
476 }
477 }
478 })
479 .await;
480 }
481 .instrument(span.or_current()),
482 );
483
484 let left_pipeline = pipeline_with_new_stream!(self, left_rx);
485
486 let right_pipeline = pipeline_with_new_stream!(self, right_rx);
487
488 (left_pipeline, right_pipeline)
489 }
490
491 #[must_use]
497 pub fn merge(self, other: Self) -> Self {
498 let stream = tokio_stream::StreamExt::merge(self.stream, other.stream);
499
500 Self {
501 stream: stream.boxed().into(),
502 ..self
503 }
504 }
505
506 #[must_use]
511 pub fn throttle(mut self, duration: impl Into<Duration>) -> Self {
512 self.stream = tokio_stream::StreamExt::throttle(self.stream, duration.into())
513 .boxed()
514 .into();
515 self
516 }
517
518 #[must_use]
523 pub fn filter_errors(mut self) -> Self {
524 self.stream = self
525 .stream
526 .filter_map(|result| async {
527 match result {
528 Ok(node) => Some(Ok(node)),
529 Err(_e) => None,
530 }
531 })
532 .boxed()
533 .into();
534 self
535 }
536
537 #[must_use]
543 pub fn filter<F>(mut self, filter: F) -> Self
544 where
545 F: Fn(&Result<Node<T>>) -> bool + Send + Sync + 'static,
546 {
547 self.stream = self
548 .stream
549 .filter(move |result| {
550 let will_retain = filter(result);
551
552 async move { will_retain }
553 })
554 .boxed()
555 .into();
556 self
557 }
558
559 #[must_use]
563 pub fn log_all(self) -> Self {
564 self.log_errors().log_nodes()
565 }
566
567 #[must_use]
571 pub fn log_errors(mut self) -> Self {
572 self.stream = self
573 .stream
574 .inspect_err(|e| tracing::error!(?e, "Error processing node"))
575 .boxed()
576 .into();
577 self
578 }
579
580 #[must_use]
584 pub fn log_nodes(mut self) -> Self {
585 self.stream = self
586 .stream
587 .inspect_ok(|node| tracing::debug!(?node, "Processed node: {:?}", node))
588 .boxed()
589 .into();
590 self
591 }
592
593 #[tracing::instrument(skip_all, fields(total_nodes), name = "indexing_pipeline.run")]
606 pub async fn run(mut self) -> Result<()> {
607 tracing::info!(
608 "Starting indexing pipeline with {} concurrency",
609 self.concurrency
610 );
611 let now = std::time::Instant::now();
612
613 let setup_futures = self
620 .storage_setup_fns
621 .into_iter()
622 .map(|func| async move { func().await })
623 .collect::<Vec<_>>();
624 futures_util::future::try_join_all(setup_futures).await?;
625
626 let mut total_nodes = 0;
627 while self.stream.try_next().await?.is_some() {
628 total_nodes += 1;
629 }
630
631 let elapsed_in_seconds = now.elapsed().as_secs();
632 tracing::info!(
633 elapsed_in_seconds,
634 "Processed {} nodes in {} seconds",
635 total_nodes,
636 elapsed_in_seconds
637 );
638 tracing::Span::current().record("total_nodes", total_nodes);
639
640 Ok(())
641 }
642}
643
644#[cfg(test)]
645mod tests {
646
647 use super::*;
648 use crate::persist::MemoryStorage;
649 use mockall::Sequence;
650 use swiftide_core::indexing::*;
651
    /// Sends one default node through a transformer, a batch transformer, a
    /// chunker (which fans it out into three nodes) and a storage, relying on
    /// the mock expectations to verify call counts and ordering.
    #[test_log::test(tokio::test)]
    async fn test_simple_run() {
        let mut loader = MockLoader::new();
        let mut transformer = MockTransformer::new();
        let mut batch_transformer = MockBatchableTransformer::new();
        let mut chunker = MockChunkerTransformer::new();
        let mut storage = MockPersist::new();

        // Expectations added to this sequence must be satisfied in order.
        let mut seq = Sequence::new();

        loader
            .expect_into_stream()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| vec![Ok(Node::default())].into());

        transformer.expect_transform_node().returning(|mut node| {
            node.chunk = "transformed".to_string();
            Ok(node)
        });
        transformer.expect_concurrency().returning(|| None);
        transformer.expect_name().returning(|| "transformer");

        batch_transformer
            .expect_batch_transform()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|nodes| IndexingStream::iter(nodes.into_iter().map(Ok)));
        batch_transformer.expect_concurrency().returning(|| None);
        batch_transformer.expect_name().returning(|| "transformer");
        batch_transformer.expect_batch_size().returning(|| None);

        // The chunker turns the single node into three distinct chunks.
        chunker
            .expect_transform_node()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|node| {
                let mut nodes = vec![];
                for i in 0..3 {
                    let mut node = node.clone();
                    node.chunk = format!("transformed_chunk_{i}");
                    nodes.push(Ok(node));
                }
                nodes.into()
            });
        chunker.expect_concurrency().returning(|| None);
        chunker.expect_name().returning(|| "chunker");

        // No batch size, so every chunk is stored with an individual call.
        storage.expect_setup().returning(|| Ok(()));
        storage.expect_batch_size().returning(|| None);
        storage
            .expect_store()
            .times(3)
            .in_sequence(&mut seq)
            .withf(|node| node.chunk.starts_with("transformed_chunk_"))
            .returning(Ok);
        storage.expect_name().returning(|| "storage");

        let pipeline = Pipeline::from_loader(loader)
            .then(transformer)
            .then_in_batch(batch_transformer)
            .then_chunk(chunker)
            .then_store_with(storage);

        pipeline.run().await.unwrap();
    }
719
    /// A transformer that always fails must not reach the storage, and with
    /// `filter_errors` applied the run still succeeds.
    #[tokio::test]
    async fn test_skipping_errors() {
        let mut loader = MockLoader::new();
        let mut transformer = MockTransformer::new();
        let mut storage = MockPersist::new();
        let mut seq = Sequence::new();
        loader
            .expect_into_stream()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| vec![Ok(Node::default())].into());
        transformer
            .expect_transform_node()
            .returning(|_node| Err(anyhow::anyhow!("Error transforming node")));
        transformer.expect_concurrency().returning(|| None);
        transformer.expect_name().returning(|| "mock");
        storage.expect_setup().returning(|| Ok(()));
        storage.expect_batch_size().returning(|| None);
        // The failed node must never be stored.
        storage.expect_store().times(0).returning(Ok);
        let pipeline = Pipeline::from_loader(loader)
            .then(transformer)
            .then_store_with(storage)
            .filter_errors();
        pipeline.run().await.unwrap();
    }
745
    /// Three nodes run through a transformer that requests a concurrency of
    /// three; each node must be transformed and stored exactly once.
    #[tokio::test]
    async fn test_concurrent_calls_with_simple_transformer() {
        let mut loader = MockLoader::new();
        let mut transformer = MockTransformer::new();
        let mut storage = MockPersist::new();
        let mut seq = Sequence::new();
        loader
            .expect_into_stream()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| {
                vec![
                    Ok(Node::default()),
                    Ok(Node::default()),
                    Ok(Node::default()),
                ]
                .into()
            });
        transformer
            .expect_transform_node()
            .times(3)
            .in_sequence(&mut seq)
            .returning(|mut node| {
                node.chunk = "transformed".to_string();
                Ok(node)
            });
        // The transformer's own concurrency takes precedence over the pipeline's.
        transformer.expect_concurrency().returning(|| Some(3));
        transformer.expect_name().returning(|| "transformer");
        storage.expect_setup().returning(|| Ok(()));
        storage.expect_batch_size().returning(|| None);
        storage.expect_store().times(3).returning(Ok);
        storage.expect_name().returning(|| "storage");

        let pipeline = Pipeline::from_loader(loader)
            .then(transformer)
            .then_store_with(storage);
        pipeline.run().await.unwrap();
    }
784
785 #[tokio::test]
786 async fn test_arbitrary_closures_as_transformer() {
787 let mut loader = MockLoader::new();
788 let transformer = |node: TextNode| {
789 let mut node = node;
790 node.chunk = "transformed".to_string();
791 Ok(node)
792 };
793 let storage = MemoryStorage::default();
794 let mut seq = Sequence::new();
795 loader
796 .expect_into_stream()
797 .times(1)
798 .in_sequence(&mut seq)
799 .returning(|| vec![Ok(TextNode::default())].into());
800
801 let pipeline = Pipeline::from_loader(loader)
802 .then(transformer)
803 .then_store_with(storage.clone());
804 pipeline.run().await.unwrap();
805
806 dbg!(storage.clone());
807 let processed_node = storage.get("0").await.unwrap();
808 assert_eq!(processed_node.chunk, "transformed");
809 }
810
811 #[tokio::test]
812 async fn test_arbitrary_closures_as_batch_transformer() {
813 let mut loader = MockLoader::new();
814 let batch_transformer = |nodes: Vec<TextNode>| {
815 IndexingStream::iter(nodes.into_iter().map(|mut node| {
816 node.chunk = "transformed".to_string();
817 Ok(node)
818 }))
819 };
820 let storage = MemoryStorage::default();
821 let mut seq = Sequence::new();
822 loader
823 .expect_into_stream()
824 .times(1)
825 .in_sequence(&mut seq)
826 .returning(|| vec![Ok(TextNode::default())].into());
827
828 let pipeline = Pipeline::from_loader(loader)
829 .then_in_batch(batch_transformer)
830 .then_store_with(storage.clone());
831 pipeline.run().await.unwrap();
832
833 dbg!(storage.clone());
834 let processed_node = storage.get("0").await.unwrap();
835 assert_eq!(processed_node.chunk, "transformed");
836 }
837
838 #[tokio::test]
839 async fn test_filter_closure() {
840 let mut loader = MockLoader::new();
841 let storage = MemoryStorage::default();
842 let mut seq = Sequence::new();
843 loader
844 .expect_into_stream()
845 .times(1)
846 .in_sequence(&mut seq)
847 .returning(|| {
848 vec![
849 Ok(TextNode::default()),
850 Ok(TextNode::new("skip")),
851 Ok(TextNode::default()),
852 ]
853 .into()
854 });
855 let pipeline = Pipeline::from_loader(loader)
856 .filter(|result| {
857 let node = result.as_ref().unwrap();
858 node.chunk != "skip"
859 })
860 .then_store_with(storage.clone());
861 pipeline.run().await.unwrap();
862 let nodes = storage.get_all().await;
863 assert_eq!(nodes.len(), 2);
864 }
865
866 #[test_log::test(tokio::test)]
867 async fn test_split_and_merge() {
868 let mut loader = MockLoader::new();
869 let storage = MemoryStorage::default();
870 let mut seq = Sequence::new();
871 loader
872 .expect_into_stream()
873 .times(1)
874 .in_sequence(&mut seq)
875 .returning(|| {
876 vec![
877 Ok(TextNode::default()),
878 Ok(TextNode::new("will go left")),
879 Ok(TextNode::default()),
880 ]
881 .into()
882 });
883
884 let pipeline = Pipeline::from_loader(loader);
885 let (mut left, mut right) = pipeline.split_by(|node| {
886 if let Ok(node) = node {
887 node.chunk.starts_with("will go left")
888 } else {
889 false
890 }
891 });
892
893 left = left
895 .then(move |mut node: TextNode| {
896 node.chunk = "left".to_string();
897
898 Ok(node)
899 })
900 .log_all();
901
902 right = right.then(move |mut node: TextNode| {
903 node.chunk = "right".to_string();
904 Ok(node)
905 });
906
907 left.merge(right)
908 .then_store_with(storage.clone())
909 .run()
910 .await
911 .unwrap();
912 dbg!(storage.clone());
913
914 let all_nodes = storage.get_all_values().await;
915 assert_eq!(
916 all_nodes.iter().filter(|node| node.chunk == "left").count(),
917 1
918 );
919 assert_eq!(
920 all_nodes
921 .iter()
922 .filter(|node| node.chunk == "right")
923 .count(),
924 2
925 );
926 }
927
    /// Every pipeline step must also accept its component as a boxed trait
    /// object; the pipeline is built entirely from `Box<dyn ...>` values and
    /// run end to end.
    #[tokio::test]
    async fn test_all_steps_should_work_as_dyn_box() {
        let mut loader = MockLoader::new();
        loader
            .expect_into_stream_boxed()
            .returning(|| vec![Ok(TextNode::default())].into());

        let mut transformer = MockTransformer::new();
        transformer.expect_transform_node().returning(Ok);
        transformer.expect_concurrency().returning(|| None);
        transformer.expect_name().returning(|| "mock");

        let mut batch_transformer = MockBatchableTransformer::new();
        batch_transformer
            .expect_batch_transform()
            .returning(std::convert::Into::into);
        batch_transformer.expect_concurrency().returning(|| None);
        batch_transformer.expect_name().returning(|| "mock");
        let mut chunker = MockChunkerTransformer::new();
        chunker
            .expect_transform_node()
            .returning(|node| vec![node].into());
        chunker.expect_concurrency().returning(|| None);
        chunker.expect_name().returning(|| "mock");

        let mut storage = MockPersist::new();
        storage.expect_setup().returning(|| Ok(()));
        storage.expect_store().returning(Ok);
        storage.expect_batch_size().returning(|| None);
        storage.expect_name().returning(|| "mock");

        // Each component is explicitly coerced to a boxed trait object.
        let pipeline = Pipeline::from_loader(Box::new(loader) as Box<dyn Loader<Output = String>>)
            .then(Box::new(transformer) as Box<dyn Transformer<Input = String, Output = String>>)
            .then_in_batch(Box::new(batch_transformer) as Box<dyn BatchableTransformer<Input = String, Output = String>>)
            .then_chunk(Box::new(chunker) as Box<dyn ChunkerTransformer<Input = String, Output = String>>)
            .then_store_with(Box::new(storage) as Box<dyn Persist<Input = String, Output = String>>);
        pipeline.run().await.unwrap();
    }
966}