use anyhow::Result;
use futures_util::{StreamExt, TryFutureExt, TryStreamExt};
use swiftide_core::{
    BatchableTransformer, ChunkerTransformer, Loader, NodeCache, Persist, SimplePrompt,
    Transformer, WithBatchIndexingDefaults, WithIndexingDefaults, indexing::IndexingDefaults,
};
use tokio::{sync::mpsc, task};
use tracing::Instrument;

use std::{sync::Arc, time::Duration};

use swiftide_core::indexing::{EmbedMode, IndexingStream, Node};

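// Shadows `tracing::trace_span!` within this module: spans created through this macro
// carry an `otel.name` field of the form `<op>.<step name>` (or just `<op>`).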
macro_rules! trace_span {
    ($op:literal, $step:expr) => {
        tracing::trace_span!($op, "otel.name" = format!("{}.{}", $op, $step.name()),)
    };

    ($op:literal) => {
        tracing::trace_span!($op, "otel.name" = format!("{}", $op),)
    };
}

macro_rules! node_trace_log {
    ($step:expr, $node:expr, $msg:literal) => {
        tracing::trace!(
            node = ?$node,
            node_id = ?$node.id(),
            step = $step.name(),
            $msg
        )
    };
}

macro_rules! batch_node_trace_log {
    ($step:expr, $nodes:expr, $msg:literal) => {
        tracing::trace!(batch_size = $nodes.len(), nodes = ?$nodes, step = $step.name(), $msg)
    };
}

const DEFAULT_BATCH_SIZE: usize = 256;

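/// A streaming indexing pipeline: load nodes, transform and chunk them, then persist
/// the results.
///
/// Built up fluently from a [`Loader`] or an existing [`IndexingStream`]. Each step
/// rewraps the internal stream; the work happens when [`Pipeline::run`] drains it.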
pub struct Pipeline {
    stream: IndexingStream,
    storage: Vec<Arc<dyn Persist>>,
    concurrency: usize,
    indexing_defaults: IndexingDefaults,
    batch_size: usize,
}

impl Default for Pipeline {
    fn default() -> Self {
        Self {
            stream: IndexingStream::empty(),
            storage: Vec::default(),
            concurrency: num_cpus::get(),
            indexing_defaults: IndexingDefaults::default(),
            batch_size: DEFAULT_BATCH_SIZE,
        }
    }
}

impl Pipeline {
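    /// Creates a pipeline from a loader; the loader's stream becomes the pipeline input.
    ///
    /// Illustrative sketch; `FileLoader` is a placeholder for any type implementing
    /// [`Loader`]:
    ///
    /// ```ignore
    /// let pipeline = Pipeline::from_loader(FileLoader::new("./my_docs"));
    /// ```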
    pub fn from_loader(loader: impl Loader + 'static) -> Self {
        let stream = loader.into_stream();
        Self {
            stream,
            ..Default::default()
        }
    }

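    /// Sets a default LLM client on the pipeline's indexing defaults. Transformers added
    /// via [`Pipeline::then`] and [`Pipeline::then_in_batch`] receive these defaults and
    /// can fall back to this client if they have no [`SimplePrompt`] of their own.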
    #[must_use]
    pub fn with_default_llm_client(mut self, client: impl SimplePrompt + 'static) -> Self {
        self.indexing_defaults = IndexingDefaults::from_simple_prompt(Box::new(client));
        self
    }

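    /// Creates a pipeline from anything that converts into an [`IndexingStream`], such
    /// as a `Vec<Result<Node>>` or an `mpsc` receiver.
    ///
    /// A minimal sketch:
    ///
    /// ```ignore
    /// let nodes = vec![Ok(Node::new("chunk goes here"))];
    /// let pipeline = Pipeline::from_stream(nodes);
    /// ```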
    pub fn from_stream(stream: impl Into<IndexingStream>) -> Self {
        Self {
            stream: stream.into(),
            ..Default::default()
        }
    }

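    /// Sets the default concurrency for pipeline steps that don't specify their own;
    /// without this the pipeline defaults to the number of available CPUs.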
    #[must_use]
    pub fn with_concurrency(mut self, concurrency: usize) -> Self {
        self.concurrency = concurrency;
        self
    }

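    /// Sets the [`EmbedMode`] on every node flowing through the pipeline, controlling
    /// what parts of a node are embedded downstream.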
    #[must_use]
    pub fn with_embed_mode(mut self, embed_mode: EmbedMode) -> Self {
        self.stream = self
            .stream
            .map_ok(move |mut node| {
                node.embed_mode = embed_mode;
                node
            })
            .boxed()
            .into();
        self
    }

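    /// Filters out nodes that are already present in the given [`NodeCache`] and marks
    /// the remaining nodes as seen, so repeated runs skip unchanged work.
    ///
    /// Illustrative sketch; `MyNodeCache` is a placeholder for any [`NodeCache`]
    /// implementation:
    ///
    /// ```ignore
    /// let pipeline = Pipeline::from_loader(loader).filter_cached(MyNodeCache::default());
    /// ```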
    #[must_use]
    pub fn filter_cached(mut self, cache: impl NodeCache + 'static) -> Self {
        let cache = Arc::new(cache);
        self.stream = self
            .stream
            .try_filter_map(move |node| {
                let cache = Arc::clone(&cache);
                let span = trace_span!("filter_cached", cache);

                async move {
                    if cache.get(&node).await {
                        node_trace_log!(cache, node, "node in cache, skipping");
                        Ok(None)
                    } else {
                        node_trace_log!(cache, node, "node not in cache, processing");
                        cache.set(&node).await;
                        Ok(Some(node))
                    }
                }
                .instrument(span.or_current())
            })
            .boxed()
            .into();
        self
    }

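    /// Adds a transformer that operates on single nodes. Nodes are processed
    /// concurrently, using the transformer's own concurrency if set, otherwise the
    /// pipeline default. As the tests below show, a closure `Fn(Node) -> Result<Node>`
    /// works as a transformer too:
    ///
    /// ```ignore
    /// let pipeline = pipeline.then(|mut node: Node| {
    ///     node.chunk = node.chunk.to_uppercase();
    ///     Ok(node)
    /// });
    /// ```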
    #[must_use]
    pub fn then(
        mut self,
        mut transformer: impl Transformer + WithIndexingDefaults + 'static,
    ) -> Self {
        let concurrency = transformer.concurrency().unwrap_or(self.concurrency);

        transformer.with_indexing_defaults(self.indexing_defaults.clone());

        let transformer = Arc::new(transformer);
        self.stream = self
            .stream
            .map_ok(move |node| {
                let transformer = transformer.clone();
                let span = trace_span!("then", transformer);

                task::spawn(
                    async move {
                        node_trace_log!(transformer, node, "Transforming node");
                        transformer.transform_node(node).await
                    }
                    .instrument(span.or_current()),
                )
                .err_into::<anyhow::Error>()
            })
            .try_buffer_unordered(concurrency)
            .map(|x| x.and_then(|x| x))
            .boxed()
            .into();

        self
    }

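    /// Adds a transformer that operates on batches of nodes, chunking the stream with
    /// the transformer's preferred batch size or falling back to the pipeline default
    /// (`DEFAULT_BATCH_SIZE`). Batches are processed concurrently and the resulting
    /// streams are flattened back into the pipeline.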
    #[must_use]
    pub fn then_in_batch(
        mut self,
        mut transformer: impl BatchableTransformer + WithBatchIndexingDefaults + 'static,
    ) -> Self {
        let concurrency = transformer.concurrency().unwrap_or(self.concurrency);

        transformer.with_indexing_defaults(self.indexing_defaults.clone());

        let transformer = Arc::new(transformer);
        self.stream = self
            .stream
            .try_chunks(transformer.batch_size().unwrap_or(self.batch_size))
            .map_ok(move |nodes| {
                let transformer = Arc::clone(&transformer);
                let span = trace_span!("then_in_batch", transformer);

                tokio::spawn(
                    async move {
                        batch_node_trace_log!(transformer, nodes, "batch transforming nodes");
                        transformer.batch_transform(nodes).await
                    }
                    .instrument(span.or_current()),
                )
                .map_err(anyhow::Error::from)
            })
            .err_into::<anyhow::Error>()
            .try_buffer_unordered(concurrency)
            .try_flatten_unordered(None)
            .boxed()
            .into();
        self
    }

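    /// Adds a chunker that splits each node into multiple smaller nodes; the per-node
    /// streams it returns are flattened back into the pipeline, unordered.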
    #[must_use]
    pub fn then_chunk(mut self, chunker: impl ChunkerTransformer + 'static) -> Self {
        let chunker = Arc::new(chunker);
        let concurrency = chunker.concurrency().unwrap_or(self.concurrency);
        self.stream = self
            .stream
            .map_ok(move |node| {
                let chunker = Arc::clone(&chunker);
                let span = trace_span!("then_chunk", chunker);

                tokio::spawn(
                    async move {
                        node_trace_log!(chunker, node, "Chunking node");
                        chunker.transform_node(node).await
                    }
                    .instrument(span.or_current()),
                )
                .map_err(anyhow::Error::from)
            })
            .err_into::<anyhow::Error>()
            .try_buffer_unordered(concurrency)
            .try_flatten_unordered(None)
            .boxed()
            .into();

        self
    }

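    /// Persists nodes using the given storage. If the storage reports a batch size,
    /// nodes are stored in batches; otherwise they are stored one by one. Can be called
    /// multiple times to store in several backends; `run` sets up every storage added.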
    #[must_use]
    pub fn then_store_with(mut self, storage: impl Persist + 'static) -> Self {
        let storage = Arc::new(storage);
        self.storage.push(storage.clone());
        if let Some(batch_size) = storage.batch_size() {
            self.stream = self
                .stream
                .try_chunks(batch_size)
                .map_ok(move |nodes| {
                    let storage = Arc::clone(&storage);
                    let span = trace_span!("then_store_with_batched", storage);

                    tokio::spawn(
                        async move {
                            batch_node_trace_log!(storage, nodes, "batch storing nodes");
                            storage.batch_store(nodes).await
                        }
                        .instrument(span.or_current()),
                    )
                    .map_err(anyhow::Error::from)
                })
                .err_into::<anyhow::Error>()
                .try_buffer_unordered(self.concurrency)
                .try_flatten_unordered(None)
                .boxed()
                .into();
        } else {
            self.stream = self
                .stream
                .map_ok(move |node| {
                    let storage = Arc::clone(&storage);
                    let span = trace_span!("then_store_with", storage);

                    tokio::spawn(
                        async move {
                            node_trace_log!(storage, node, "Storing node");

                            storage.store(node).await
                        }
                        .instrument(span.or_current()),
                    )
                    .err_into::<anyhow::Error>()
                })
                .try_buffer_unordered(self.concurrency)
                .map(|x| x.and_then(|x| x))
                .boxed()
                .into();
        }

        self
    }

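    /// Splits the stream in two based on a predicate over `&Result<Node>`: matching
    /// items go to the left pipeline, the rest to the right. Both halves buffer up to
    /// 1000 items, so consume both (e.g. by merging them back) to avoid stalling.
    ///
    /// A minimal sketch:
    ///
    /// ```ignore
    /// let (markdown, rest) = pipeline.split_by(|result| {
    ///     result.as_ref().map(|node| node.chunk.starts_with("#")).unwrap_or(false)
    /// });
    /// ```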
    #[must_use]
    pub fn split_by<P>(self, predicate: P) -> (Self, Self)
    where
        P: Fn(&Result<Node>) -> bool + Send + Sync + 'static,
    {
        let predicate = Arc::new(predicate);

        let (left_tx, left_rx) = mpsc::channel(1000);
        let (right_tx, right_rx) = mpsc::channel(1000);

        let stream = self.stream;
        let span = trace_span!("split_by");
        tokio::spawn(
            async move {
                stream
                    .for_each_concurrent(self.concurrency, move |item| {
                        let predicate = Arc::clone(&predicate);
                        let left_tx = left_tx.clone();
                        let right_tx = right_tx.clone();
                        async move {
                            if predicate(&item) {
                                tracing::trace!(?item, "Sending to left stream");
                                left_tx
                                    .send(item)
                                    .await
                                    .expect("Failed to send to left stream");
                            } else {
                                tracing::trace!(?item, "Sending to right stream");
                                right_tx
                                    .send(item)
                                    .await
                                    .expect("Failed to send to right stream");
                            }
                        }
                    })
                    .await;
            }
            .instrument(span.or_current()),
        );

        let left_pipeline = Self {
            stream: left_rx.into(),
            storage: self.storage.clone(),
            concurrency: self.concurrency,
            indexing_defaults: self.indexing_defaults.clone(),
            batch_size: self.batch_size,
        };

        let right_pipeline = Self {
            stream: right_rx.into(),
            storage: self.storage.clone(),
            concurrency: self.concurrency,
            indexing_defaults: self.indexing_defaults.clone(),
            batch_size: self.batch_size,
        };

        (left_pipeline, right_pipeline)
    }

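    /// Merges two pipelines into one, interleaving their streams; the counterpart to
    /// [`Pipeline::split_by`].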
    #[must_use]
    pub fn merge(self, other: Self) -> Self {
        let stream = tokio_stream::StreamExt::merge(self.stream, other.stream);

        Self {
            stream: stream.boxed().into(),
            ..self
        }
    }

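    /// Throttles the stream, yielding at most one node per `duration`.
    ///
    /// ```ignore
    /// let pipeline = pipeline.throttle(std::time::Duration::from_millis(100));
    /// ```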
    #[must_use]
    pub fn throttle(mut self, duration: impl Into<Duration>) -> Self {
        self.stream = tokio_stream::StreamExt::throttle(self.stream, duration.into())
            .boxed()
            .into();
        self
    }

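    /// Silently drops failed nodes from the stream. Errors are discarded without
    /// logging; chain [`Pipeline::log_errors`] before this to keep a trace.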
    #[must_use]
    pub fn filter_errors(mut self) -> Self {
        self.stream = self
            .stream
            .filter_map(|result| async {
                match result {
                    Ok(node) => Some(Ok(node)),
                    Err(_e) => None,
                }
            })
            .boxed()
            .into();
        self
    }

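    /// Filters the stream with an arbitrary predicate; items for which it returns
    /// `false` are dropped, errors included.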
    #[must_use]
    pub fn filter<F>(mut self, filter: F) -> Self
    where
        F: Fn(&Result<Node>) -> bool + Send + Sync + 'static,
    {
        self.stream = self
            .stream
            .filter(move |result| {
                let will_retain = filter(result);

                async move { will_retain }
            })
            .boxed()
            .into();
        self
    }

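    /// Logs all errors and nodes passing through; shorthand for chaining
    /// [`Pipeline::log_errors`] and [`Pipeline::log_nodes`].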
    #[must_use]
    pub fn log_all(self) -> Self {
        self.log_errors().log_nodes()
    }

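    /// Logs every error passing through the pipeline at `error` level.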
    #[must_use]
    pub fn log_errors(mut self) -> Self {
        self.stream = self
            .stream
            .inspect_err(|e| tracing::error!(?e, "Error processing node"))
            .boxed()
            .into();
        self
    }

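    /// Logs every successfully processed node at `debug` level.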
    #[must_use]
    pub fn log_nodes(mut self) -> Self {
        self.stream = self
            .stream
            .inspect_ok(|node| tracing::debug!(?node, "Processed node"))
            .boxed()
            .into();
        self
    }

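    /// Runs the pipeline: bails if no storage was configured, sets up all storages,
    /// then drains the stream to completion, counting the processed nodes.
    ///
    /// Illustrative end-to-end sketch; the loader, transformer, and storage are
    /// placeholders for your own implementations:
    ///
    /// ```ignore
    /// Pipeline::from_loader(my_loader)
    ///     .then(my_transformer)
    ///     .then_store_with(my_storage)
    ///     .run()
    ///     .await?;
    /// ```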
    #[tracing::instrument(skip_all, fields(total_nodes), name = "indexing_pipeline.run")]
    pub async fn run(mut self) -> Result<()> {
        tracing::info!(
            "Starting indexing pipeline with {} concurrency",
            self.concurrency
        );
        let now = std::time::Instant::now();
        if self.storage.is_empty() {
            anyhow::bail!("No storage configured for indexing pipeline");
        }

        let setup_futures = self
            .storage
            .into_iter()
            .map(|storage| async move { storage.setup().await })
            .collect::<Vec<_>>();
        futures_util::future::try_join_all(setup_futures).await?;

        let mut total_nodes = 0;
        while self.stream.try_next().await?.is_some() {
            total_nodes += 1;
        }

        let elapsed_in_seconds = now.elapsed().as_secs();
        tracing::info!(
            elapsed_in_seconds,
            "Processed {} nodes in {} seconds",
            total_nodes,
            elapsed_in_seconds
        );
        tracing::Span::current().record("total_nodes", total_nodes);

        Ok(())
    }
}

#[cfg(test)]
mod tests {

    use super::*;
    use crate::persist::MemoryStorage;
    use mockall::Sequence;
    use swiftide_core::indexing::*;

    #[test_log::test(tokio::test)]
    async fn test_simple_run() {
        let mut loader = MockLoader::new();
        let mut transformer = MockTransformer::new();
        let mut batch_transformer = MockBatchableTransformer::new();
        let mut chunker = MockChunkerTransformer::new();
        let mut storage = MockPersist::new();

        let mut seq = Sequence::new();

        loader
            .expect_into_stream()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| vec![Ok(Node::default())].into());

        transformer.expect_transform_node().returning(|mut node| {
            node.chunk = "transformed".to_string();
            Ok(node)
        });
        transformer.expect_concurrency().returning(|| None);
        transformer.expect_name().returning(|| "transformer");

        batch_transformer
            .expect_batch_transform()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|nodes| IndexingStream::iter(nodes.into_iter().map(Ok)));
        batch_transformer.expect_concurrency().returning(|| None);
        batch_transformer.expect_name().returning(|| "transformer");
        batch_transformer.expect_batch_size().returning(|| None);

        chunker
            .expect_transform_node()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|node| {
                let mut nodes = vec![];
                for i in 0..3 {
                    let mut node = node.clone();
                    node.chunk = format!("transformed_chunk_{i}");
                    nodes.push(Ok(node));
                }
                nodes.into()
            });
        chunker.expect_concurrency().returning(|| None);
        chunker.expect_name().returning(|| "chunker");

        storage.expect_setup().returning(|| Ok(()));
        storage.expect_batch_size().returning(|| None);
        storage
            .expect_store()
            .times(3)
            .in_sequence(&mut seq)
            .withf(|node| node.chunk.starts_with("transformed_chunk_"))
            .returning(Ok);
        storage.expect_name().returning(|| "storage");

        let pipeline = Pipeline::from_loader(loader)
            .then(transformer)
            .then_in_batch(batch_transformer)
            .then_chunk(chunker)
            .then_store_with(storage);

        pipeline.run().await.unwrap();
    }

    #[tokio::test]
    async fn test_skipping_errors() {
        let mut loader = MockLoader::new();
        let mut transformer = MockTransformer::new();
        let mut storage = MockPersist::new();
        let mut seq = Sequence::new();
        loader
            .expect_into_stream()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| vec![Ok(Node::default())].into());
        transformer
            .expect_transform_node()
            .returning(|_node| Err(anyhow::anyhow!("Error transforming node")));
        transformer.expect_concurrency().returning(|| None);
        transformer.expect_name().returning(|| "mock");
        storage.expect_setup().returning(|| Ok(()));
        storage.expect_batch_size().returning(|| None);
        storage.expect_store().times(0).returning(Ok);
        let pipeline = Pipeline::from_loader(loader)
            .then(transformer)
            .then_store_with(storage)
            .filter_errors();
        pipeline.run().await.unwrap();
    }

    #[tokio::test]
    async fn test_concurrent_calls_with_simple_transformer() {
        let mut loader = MockLoader::new();
        let mut transformer = MockTransformer::new();
        let mut storage = MockPersist::new();
        let mut seq = Sequence::new();
        loader
            .expect_into_stream()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| {
                vec![
                    Ok(Node::default()),
                    Ok(Node::default()),
                    Ok(Node::default()),
                ]
                .into()
            });
        transformer
            .expect_transform_node()
            .times(3)
            .in_sequence(&mut seq)
            .returning(|mut node| {
                node.chunk = "transformed".to_string();
                Ok(node)
            });
        transformer.expect_concurrency().returning(|| Some(3));
        transformer.expect_name().returning(|| "transformer");
        storage.expect_setup().returning(|| Ok(()));
        storage.expect_batch_size().returning(|| None);
        storage.expect_store().times(3).returning(Ok);
        storage.expect_name().returning(|| "storage");

        let pipeline = Pipeline::from_loader(loader)
            .then(transformer)
            .then_store_with(storage);
        pipeline.run().await.unwrap();
    }

    #[tokio::test]
    async fn test_arbitrary_closures_as_transformer() {
        let mut loader = MockLoader::new();
        let transformer = |node: Node| {
            let mut node = node;
            node.chunk = "transformed".to_string();
            Ok(node)
        };
        let storage = MemoryStorage::default();
        let mut seq = Sequence::new();
        loader
            .expect_into_stream()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| vec![Ok(Node::default())].into());

        let pipeline = Pipeline::from_loader(loader)
            .then(transformer)
            .then_store_with(storage.clone());
        pipeline.run().await.unwrap();

        dbg!(storage.clone());
        let processed_node = storage.get("0").await.unwrap();
        assert_eq!(processed_node.chunk, "transformed");
    }

    #[tokio::test]
    async fn test_arbitrary_closures_as_batch_transformer() {
        let mut loader = MockLoader::new();
        let batch_transformer = |nodes: Vec<Node>| {
            IndexingStream::iter(nodes.into_iter().map(|mut node| {
                node.chunk = "transformed".to_string();
                Ok(node)
            }))
        };
        let storage = MemoryStorage::default();
        let mut seq = Sequence::new();
        loader
            .expect_into_stream()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| vec![Ok(Node::default())].into());

        let pipeline = Pipeline::from_loader(loader)
            .then_in_batch(batch_transformer)
            .then_store_with(storage.clone());
        pipeline.run().await.unwrap();

        dbg!(storage.clone());
        let processed_node = storage.get("0").await.unwrap();
        assert_eq!(processed_node.chunk, "transformed");
    }

    #[tokio::test]
    async fn test_filter_closure() {
        let mut loader = MockLoader::new();
        let storage = MemoryStorage::default();
        let mut seq = Sequence::new();
        loader
            .expect_into_stream()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| {
                vec![
                    Ok(Node::default()),
                    Ok(Node::new("skip")),
                    Ok(Node::default()),
                ]
                .into()
            });
        let pipeline = Pipeline::from_loader(loader)
            .filter(|result| {
                let node = result.as_ref().unwrap();
                node.chunk != "skip"
            })
            .then_store_with(storage.clone());
        pipeline.run().await.unwrap();
        let nodes = storage.get_all().await;
        assert_eq!(nodes.len(), 2);
    }

    #[test_log::test(tokio::test)]
    async fn test_split_and_merge() {
        let mut loader = MockLoader::new();
        let storage = MemoryStorage::default();
        let mut seq = Sequence::new();
        loader
            .expect_into_stream()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| {
                vec![
                    Ok(Node::default()),
                    Ok(Node::new("will go left")),
                    Ok(Node::default()),
                ]
                .into()
            });

        let pipeline = Pipeline::from_loader(loader);
        let (mut left, mut right) = pipeline.split_by(|node| {
            if let Ok(node) = node {
                node.chunk.starts_with("will go left")
            } else {
                false
            }
        });

        left = left
            .then(move |mut node: Node| {
                node.chunk = "left".to_string();

                Ok(node)
            })
            .log_all();

        right = right.then(move |mut node: Node| {
            node.chunk = "right".to_string();
            Ok(node)
        });

        left.merge(right)
            .then_store_with(storage.clone())
            .run()
            .await
            .unwrap();
        dbg!(storage.clone());

        let all_nodes = storage.get_all_values().await;
        assert_eq!(
            all_nodes.iter().filter(|node| node.chunk == "left").count(),
            1
        );
        assert_eq!(
            all_nodes
                .iter()
                .filter(|node| node.chunk == "right")
                .count(),
            2
        );
    }

    #[tokio::test]
    async fn test_all_steps_should_work_as_dyn_box() {
        let mut loader = MockLoader::new();
        loader
            .expect_into_stream_boxed()
            .returning(|| vec![Ok(Node::default())].into());

        let mut transformer = MockTransformer::new();
        transformer.expect_transform_node().returning(Ok);
        transformer.expect_concurrency().returning(|| None);
        transformer.expect_name().returning(|| "mock");

        let mut batch_transformer = MockBatchableTransformer::new();
        batch_transformer
            .expect_batch_transform()
            .returning(std::convert::Into::into);
        batch_transformer.expect_concurrency().returning(|| None);
        batch_transformer.expect_batch_size().returning(|| None);
        batch_transformer.expect_name().returning(|| "mock");
        let mut chunker = MockChunkerTransformer::new();
        chunker
            .expect_transform_node()
            .returning(|node| vec![node].into());
        chunker.expect_concurrency().returning(|| None);
        chunker.expect_name().returning(|| "mock");

        let mut storage = MockPersist::new();
        storage.expect_setup().returning(|| Ok(()));
        storage.expect_store().returning(Ok);
        storage.expect_batch_size().returning(|| None);
        storage.expect_name().returning(|| "mock");

        let pipeline = Pipeline::from_loader(Box::new(loader) as Box<dyn Loader>)
            .then(Box::new(transformer) as Box<dyn Transformer>)
            .then_in_batch(Box::new(batch_transformer) as Box<dyn BatchableTransformer>)
            .then_chunk(Box::new(chunker) as Box<dyn ChunkerTransformer>)
            .then_store_with(Box::new(storage) as Box<dyn Persist>);
        pipeline.run().await.unwrap();
    }
}