use anyhow::Result;
use futures_util::{
    FutureExt, StreamExt, TryFutureExt, TryStreamExt,
    future::{BoxFuture, Shared},
};
use swiftide_core::{
    BatchableTransformer, ChunkerTransformer, Loader, NodeCache, Persist, SimplePrompt,
    Transformer, WithBatchIndexingDefaults, WithIndexingDefaults, indexing::IndexingDefaults,
};
use tokio::{
    sync::mpsc,
    task::{self},
};
use tracing::Instrument;

use std::{collections::HashSet, sync::Arc, time::Duration};

use swiftide_core::indexing::{EmbedMode, IndexingStream, Node};

macro_rules! trace_span {
    ($op:literal, $step:expr) => {
        tracing::trace_span!($op, "otel.name" = format!("{}.{}", $op, $step.name()),)
    };

    ($op:literal) => {
        tracing::trace_span!($op, "otel.name" = format!("{}", $op),)
    };
}

macro_rules! node_trace_log {
    ($step:expr, $node:expr, $msg:literal) => {
        tracing::trace!(
            node = ?$node,
            node_id = ?$node.id(),
            step = $step.name(),
            $msg
        )
    };
}

macro_rules! batch_node_trace_log {
    ($step:expr, $nodes:expr, $msg:literal) => {
        tracing::trace!(batch_size = $nodes.len(), nodes = ?$nodes, step = $step.name(), $msg)
    };
}

const DEFAULT_BATCH_SIZE: usize = 256;

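/// An asynchronous indexing pipeline.
///
/// Streams `Node`s from a `Loader` through transformers, chunkers, and one or
/// more `Persist` backends, processing each step concurrently.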
pub struct Pipeline {
    stream: IndexingStream,
    storage: Vec<Arc<dyn Persist>>,
    concurrency: usize,
    indexing_defaults: IndexingDefaults,
    batch_size: usize,
    cache_sender: Option<mpsc::Sender<uuid::Uuid>>,
    cache_handle: Option<Shared<BoxFuture<'static, ()>>>,
}

impl Default for Pipeline {
    fn default() -> Self {
        Self {
            stream: IndexingStream::empty(),
            storage: Vec::default(),
            concurrency: num_cpus::get(),
            indexing_defaults: IndexingDefaults::default(),
            batch_size: DEFAULT_BATCH_SIZE,
            cache_sender: None,
            cache_handle: None,
        }
    }
}

impl Pipeline {
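    /// Creates a pipeline from a loader, with defaults for everything else.
    ///
    /// A minimal sketch; `FileLoader` is an illustrative stand-in for any
    /// `Loader` implementation:
    ///
    /// ```ignore
    /// let pipeline = Pipeline::from_loader(FileLoader::new("./src"));
    /// ```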
    pub fn from_loader(loader: impl Loader + 'static) -> Self {
        let stream = loader.into_stream();
        Self {
            stream,
            ..Default::default()
        }
    }

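    /// Sets a default LLM client, exposed to transformers through the
    /// pipeline's `IndexingDefaults`.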
    #[must_use]
    pub fn with_default_llm_client(mut self, client: impl SimplePrompt + 'static) -> Self {
        self.indexing_defaults = IndexingDefaults::from_simple_prompt(Box::new(client));
        self
    }

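    /// Creates a pipeline from anything that converts into an
    /// `IndexingStream`.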
    pub fn from_stream(stream: impl Into<IndexingStream>) -> Self {
        Self {
            stream: stream.into(),
            ..Default::default()
        }
    }

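    /// Sets the default concurrency for pipeline steps; steps that report
    /// their own concurrency take precedence. Defaults to the number of CPUs.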
    #[must_use]
    pub fn with_concurrency(mut self, concurrency: usize) -> Self {
        self.concurrency = concurrency;
        self
    }

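    /// Sets the embed mode on every node passing through the pipeline.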
    #[must_use]
    pub fn with_embed_mode(mut self, embed_mode: EmbedMode) -> Self {
        self.stream = self
            .stream
            .map_ok(move |mut node| {
                node.embed_mode = embed_mode;
                node
            })
            .boxed()
            .into();
        self
    }

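    /// Filters out nodes that are already present in the given cache.
    ///
    /// Cache entries are written by a background task, and only for nodes that
    /// make it through the entire pipeline (see `run`), so nodes that fail
    /// along the way are processed again on the next run.
    ///
    /// A sketch, with `some_cache` standing in for any `NodeCache`
    /// implementation:
    ///
    /// ```ignore
    /// pipeline.filter_cached(some_cache)
    /// ```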
    #[must_use]
    pub fn filter_cached(mut self, cache: impl NodeCache + 'static) -> Self {
        let cache = Arc::new(cache);
        let (cache_sender, mut cache_receiver) = mpsc::channel(1000);

        let cache_for_task = Arc::clone(&cache);

        let handle = tokio::spawn(async move {
            while let Some(cache_key) = cache_receiver.recv().await {
                let cache_node = Node {
                    parent_id: Some(cache_key),
                    ..Default::default()
                };
                cache_for_task.set(&cache_node).await;
            }
        });

        self.cache_handle = Some(
            handle
                .map(|result| {
                    if let Err(err) = result {
                        tracing::error!("Cache task failed: {err}");
                    }
                })
                .boxed()
                .shared(),
        );

        self.stream = self
            .stream
            .try_filter_map(move |node| {
                let cache = Arc::clone(&cache);
                let span = trace_span!("filter_cached", cache);

                async move {
                    if cache.get(&node).await {
                        node_trace_log!(cache, node, "node in cache, skipping");
                        Ok(None)
                    } else {
                        node_trace_log!(cache, node, "node not in cache, processing");
                        Ok(Some(node))
                    }
                }
                .instrument(span.or_current())
            })
            .boxed()
            .into();
        self.cache_sender = Some(cache_sender);
        self
    }

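    /// Adds a transformer to the pipeline. Nodes are transformed concurrently
    /// on spawned tasks, using the transformer's own concurrency if it reports
    /// one, otherwise the pipeline default.
    ///
    /// Closures over a `Node` work as transformers too, as the tests in this
    /// module demonstrate:
    ///
    /// ```ignore
    /// pipeline.then(|mut node: Node| {
    ///     node.chunk = node.chunk.to_uppercase();
    ///     Ok(node)
    /// })
    /// ```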
    #[must_use]
    pub fn then(
        mut self,
        mut transformer: impl Transformer + WithIndexingDefaults + 'static,
    ) -> Self {
        let concurrency = transformer.concurrency().unwrap_or(self.concurrency);

        transformer.with_indexing_defaults(self.indexing_defaults.clone());

        let transformer = Arc::new(transformer);
        self.stream = self
            .stream
            .map_ok(move |node| {
                let transformer = transformer.clone();
                let span = trace_span!("then", transformer);

                task::spawn(
                    async move {
                        node_trace_log!(transformer, node, "Transforming node");
                        transformer.transform_node(node).await
                    }
                    .instrument(span.or_current()),
                )
                .err_into::<anyhow::Error>()
            })
            .try_buffer_unordered(concurrency)
            .map(|x| x.and_then(|x| x))
            .boxed()
            .into();

        self
    }

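    /// Adds a batch transformer. Nodes are chunked to the transformer's batch
    /// size, falling back to the pipeline's `batch_size` (`DEFAULT_BATCH_SIZE`
    /// by default), and each batch is transformed on its own task.
    ///
    /// Closures over a `Vec<Node>` work here as well:
    ///
    /// ```ignore
    /// pipeline.then_in_batch(|nodes: Vec<Node>| {
    ///     IndexingStream::iter(nodes.into_iter().map(Ok))
    /// })
    /// ```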
    #[must_use]
    pub fn then_in_batch(
        mut self,
        mut transformer: impl BatchableTransformer + WithBatchIndexingDefaults + 'static,
    ) -> Self {
        let concurrency = transformer.concurrency().unwrap_or(self.concurrency);

        transformer.with_indexing_defaults(self.indexing_defaults.clone());

        let transformer = Arc::new(transformer);
        self.stream = self
            .stream
            .try_chunks(transformer.batch_size().unwrap_or(self.batch_size))
            .map_ok(move |nodes| {
                let transformer = Arc::clone(&transformer);
                let span = trace_span!("then_in_batch", transformer);

                tokio::spawn(
                    async move {
                        batch_node_trace_log!(transformer, nodes, "batch transforming nodes");
                        transformer.batch_transform(nodes).await
                    }
                    .instrument(span.or_current()),
                )
                .map_err(anyhow::Error::from)
            })
            .err_into::<anyhow::Error>()
            .try_buffer_unordered(concurrency)
            .try_flatten_unordered(None)
            .boxed()
            .into();
        self
    }

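    /// Adds a chunker that splits each node into multiple nodes; the resulting
    /// streams are flattened back into the pipeline.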
    #[must_use]
    pub fn then_chunk(mut self, chunker: impl ChunkerTransformer + 'static) -> Self {
        let chunker = Arc::new(chunker);
        let concurrency = chunker.concurrency().unwrap_or(self.concurrency);
        self.stream = self
            .stream
            .map_ok(move |node| {
                let chunker = Arc::clone(&chunker);
                let span = trace_span!("then_chunk", chunker);

                tokio::spawn(
                    async move {
                        node_trace_log!(chunker, node, "Chunking node");
                        chunker.transform_node(node).await
                    }
                    .instrument(span.or_current()),
                )
                .map_err(anyhow::Error::from)
            })
            .err_into::<anyhow::Error>()
            .try_buffer_unordered(concurrency)
            .try_flatten_unordered(None)
            .boxed()
            .into();

        self
    }

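    /// Persists nodes using the given storage backend. If the backend reports
    /// a batch size, nodes are stored in batches; otherwise one at a time.
    /// Can be called multiple times to store in several backends; all of them
    /// are set up when the pipeline runs.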
    #[must_use]
    pub fn then_store_with(mut self, storage: impl Persist + 'static) -> Self {
        let storage = Arc::new(storage);
        self.storage.push(storage.clone());
        if let Some(batch_size) = storage.batch_size() {
            self.stream = self
                .stream
                .try_chunks(batch_size)
                .map_ok(move |nodes| {
                    let storage = Arc::clone(&storage);
                    let span = trace_span!("then_store_with_batched", storage);

                    tokio::spawn(
                        async move {
                            batch_node_trace_log!(storage, nodes, "batch storing nodes");
                            storage.batch_store(nodes).await
                        }
                        .instrument(span.or_current()),
                    )
                    .map_err(anyhow::Error::from)
                })
                .err_into::<anyhow::Error>()
                .try_buffer_unordered(self.concurrency)
                .try_flatten_unordered(None)
                .boxed()
                .into();
        } else {
            self.stream = self
                .stream
                .map_ok(move |node| {
                    let storage = Arc::clone(&storage);
                    let span = trace_span!("then_store_with", storage);

                    tokio::spawn(
                        async move {
                            node_trace_log!(storage, node, "Storing node");

                            storage.store(node).await
                        }
                        .instrument(span.or_current()),
                    )
                    .err_into::<anyhow::Error>()
                })
                .try_buffer_unordered(self.concurrency)
                .map(|x| x.and_then(|x| x))
                .boxed()
                .into();
        }

        self
    }

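    /// Splits the stream in two based on a predicate: items for which the
    /// predicate returns `true` go to the left pipeline, the rest to the
    /// right. Both halves share this pipeline's storage, defaults, and cache,
    /// and can be recombined later with `merge`.
    ///
    /// A sketch:
    ///
    /// ```ignore
    /// let (ok_nodes, failed) = pipeline.split_by(|result| result.is_ok());
    /// ```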
    #[must_use]
    pub fn split_by<P>(self, predicate: P) -> (Self, Self)
    where
        P: Fn(&Result<Node>) -> bool + Send + Sync + 'static,
    {
        let predicate = Arc::new(predicate);

        let (left_tx, left_rx) = mpsc::channel(1000);
        let (right_tx, right_rx) = mpsc::channel(1000);

        let stream = self.stream;
        let span = trace_span!("split_by");
        tokio::spawn(
            async move {
                stream
                    .for_each_concurrent(self.concurrency, move |item| {
                        let predicate = Arc::clone(&predicate);
                        let left_tx = left_tx.clone();
                        let right_tx = right_tx.clone();
                        async move {
                            if predicate(&item) {
                                tracing::trace!(?item, "Sending to left stream");
                                left_tx
                                    .send(item)
                                    .await
                                    .expect("Failed to send to left stream");
                            } else {
                                tracing::trace!(?item, "Sending to right stream");
                                right_tx
                                    .send(item)
                                    .await
                                    .expect("Failed to send to right stream");
                            }
                        }
                    })
                    .await;
            }
            .instrument(span.or_current()),
        );

        let left_pipeline = Self {
            stream: left_rx.into(),
            storage: self.storage.clone(),
            concurrency: self.concurrency,
            indexing_defaults: self.indexing_defaults.clone(),
            batch_size: self.batch_size,
            cache_sender: self.cache_sender.clone(),
            cache_handle: self.cache_handle.clone(),
        };

        let right_pipeline = Self {
            stream: right_rx.into(),
            storage: self.storage.clone(),
            concurrency: self.concurrency,
            indexing_defaults: self.indexing_defaults.clone(),
            batch_size: self.batch_size,
            cache_sender: self.cache_sender.clone(),
            cache_handle: self.cache_handle.clone(),
        };

        (left_pipeline, right_pipeline)
    }

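    /// Merges two pipelines into one, interleaving their streams. The
    /// storage, defaults, and cache of `self` are kept.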
    #[must_use]
    pub fn merge(self, other: Self) -> Self {
        let stream = tokio_stream::StreamExt::merge(self.stream, other.stream);

        Self {
            stream: stream.boxed().into(),
            ..self
        }
    }

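    /// Throttles the stream, enforcing at least `duration` between items.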
    #[must_use]
    pub fn throttle(mut self, duration: impl Into<Duration>) -> Self {
        self.stream = tokio_stream::StreamExt::throttle(self.stream, duration.into())
            .boxed()
            .into();
        self
    }

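    /// Silently drops failed nodes from the stream so the rest of the
    /// pipeline keeps going.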
    #[must_use]
    pub fn filter_errors(mut self) -> Self {
        self.stream = self
            .stream
            .filter_map(|result| async {
                match result {
                    Ok(node) => Some(Ok(node)),
                    Err(_e) => None,
                }
            })
            .boxed()
            .into();
        self
    }

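    /// Filters the stream with an arbitrary predicate over `Result<Node>`.
    ///
    /// ```ignore
    /// pipeline.filter(|result| {
    ///     result.as_ref().is_ok_and(|node| node.chunk != "skip")
    /// })
    /// ```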
    #[must_use]
    pub fn filter<F>(mut self, filter: F) -> Self
    where
        F: Fn(&Result<Node>) -> bool + Send + Sync + 'static,
    {
        self.stream = self
            .stream
            .filter(move |result| {
                let will_retain = filter(result);

                async move { will_retain }
            })
            .boxed()
            .into();
        self
    }

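    /// Logs both errors and processed nodes; shorthand for
    /// `log_errors().log_nodes()`.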
    #[must_use]
    pub fn log_all(self) -> Self {
        self.log_errors().log_nodes()
    }

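    /// Logs errors passing through at `ERROR` level without consuming them.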
    #[must_use]
    pub fn log_errors(mut self) -> Self {
        self.stream = self
            .stream
            .inspect_err(|e| tracing::error!(?e, "Error processing node"))
            .boxed()
            .into();
        self
    }

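    /// Logs successfully processed nodes at `DEBUG` level.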
    #[must_use]
    pub fn log_nodes(mut self) -> Self {
        self.stream = self
            .stream
            .inspect_ok(|node| tracing::debug!(?node, "Processed node"))
            .boxed()
            .into();
        self
    }

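    /// Runs the pipeline to completion. All storage backends are set up in
    /// parallel first; the stream is then drained, and successfully processed
    /// nodes are marked in the cache if one was configured via
    /// `filter_cached`. Fails if no storage was configured or if the stream
    /// yields an error.
    ///
    /// A sketch, with `loader` and `storage` standing in for concrete
    /// implementations:
    ///
    /// ```ignore
    /// Pipeline::from_loader(loader)
    ///     .then_store_with(storage)
    ///     .run()
    ///     .await?;
    /// ```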
    #[tracing::instrument(skip_all, fields(total_nodes), name = "indexing_pipeline.run")]
    pub async fn run(mut self) -> Result<()> {
        tracing::info!(
            "Starting indexing pipeline with concurrency {}",
            self.concurrency
        );
        let now = std::time::Instant::now();
        if self.storage.is_empty() {
            anyhow::bail!("No storage configured for indexing pipeline");
        }

        let setup_futures = self
            .storage
            .into_iter()
            .map(|storage| async move { storage.setup().await })
            .collect::<Vec<_>>();
        futures_util::future::try_join_all(setup_futures).await?;

        let mut cache_keys = HashSet::new();

        let mut total_nodes = 0;

        while let Some(node) = self.stream.try_next().await? {
            total_nodes += 1;

            let cache_key = node.parent_id().unwrap_or(node.id());

            if let Some(sender) = &self.cache_sender {
                if cache_keys.insert(cache_key) {
                    let _ = sender.send(cache_key).await;
                }
            }
        }

        let elapsed_in_seconds = now.elapsed().as_secs();
        tracing::info!(
            elapsed_in_seconds,
            "Processed {total_nodes} nodes in {elapsed_in_seconds} seconds",
        );

        tracing::Span::current().record("total_nodes", total_nodes);

        // Close the channel so the background cache task drains and exits.
        drop(self.cache_sender);

        if let Some(handle) = self.cache_handle {
            handle.await;
        }

        Ok(())
    }
}

#[cfg(test)]
mod tests {

    use super::*;
    use crate::persist::MemoryStorage;
    use mockall::Sequence;
    use swiftide_core::indexing::*;

    #[test_log::test(tokio::test)]
    async fn test_simple_run() {
        let mut loader = MockLoader::new();
        let mut transformer = MockTransformer::new();
        let mut batch_transformer = MockBatchableTransformer::new();
        let mut chunker = MockChunkerTransformer::new();
        let mut storage = MockPersist::new();

        let mut seq = Sequence::new();

        loader
            .expect_into_stream()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| vec![Ok(Node::default())].into());

        transformer.expect_transform_node().returning(|mut node| {
            node.chunk = "transformed".to_string();
            Ok(node)
        });
        transformer.expect_concurrency().returning(|| None);
        transformer.expect_name().returning(|| "transformer");

        batch_transformer
            .expect_batch_transform()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|nodes| IndexingStream::iter(nodes.into_iter().map(Ok)));
        batch_transformer.expect_concurrency().returning(|| None);
        batch_transformer.expect_name().returning(|| "transformer");
        batch_transformer.expect_batch_size().returning(|| None);

        chunker
            .expect_transform_node()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|node| {
                let mut nodes = vec![];
                for i in 0..3 {
                    let mut node = node.clone();
                    node.chunk = format!("transformed_chunk_{i}");
                    nodes.push(Ok(node));
                }
                nodes.into()
            });
        chunker.expect_concurrency().returning(|| None);
        chunker.expect_name().returning(|| "chunker");

        storage.expect_setup().returning(|| Ok(()));
        storage.expect_batch_size().returning(|| None);
        storage
            .expect_store()
            .times(3)
            .in_sequence(&mut seq)
            .withf(|node| node.chunk.starts_with("transformed_chunk_"))
            .returning(Ok);
        storage.expect_name().returning(|| "storage");

        let pipeline = Pipeline::from_loader(loader)
            .then(transformer)
            .then_in_batch(batch_transformer)
            .then_chunk(chunker)
            .then_store_with(storage);

        pipeline.run().await.unwrap();
    }

    #[tokio::test]
    async fn test_skipping_errors() {
        let mut loader = MockLoader::new();
        let mut transformer = MockTransformer::new();
        let mut storage = MockPersist::new();
        let mut seq = Sequence::new();
        loader
            .expect_into_stream()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| vec![Ok(Node::default())].into());
        transformer
            .expect_transform_node()
            .returning(|_node| Err(anyhow::anyhow!("Error transforming node")));
        transformer.expect_concurrency().returning(|| None);
        transformer.expect_name().returning(|| "mock");
        storage.expect_setup().returning(|| Ok(()));
        storage.expect_batch_size().returning(|| None);
        storage.expect_store().times(0).returning(Ok);
        let pipeline = Pipeline::from_loader(loader)
            .then(transformer)
            .then_store_with(storage)
            .filter_errors();
        pipeline.run().await.unwrap();
    }

    #[tokio::test]
    async fn test_concurrent_calls_with_simple_transformer() {
        let mut loader = MockLoader::new();
        let mut transformer = MockTransformer::new();
        let mut storage = MockPersist::new();
        let mut seq = Sequence::new();
        loader
            .expect_into_stream()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| {
                vec![
                    Ok(Node::default()),
                    Ok(Node::default()),
                    Ok(Node::default()),
                ]
                .into()
            });
        transformer
            .expect_transform_node()
            .times(3)
            .in_sequence(&mut seq)
            .returning(|mut node| {
                node.chunk = "transformed".to_string();
                Ok(node)
            });
        transformer.expect_concurrency().returning(|| Some(3));
        transformer.expect_name().returning(|| "transformer");
        storage.expect_setup().returning(|| Ok(()));
        storage.expect_batch_size().returning(|| None);
        storage.expect_store().times(3).returning(Ok);
        storage.expect_name().returning(|| "storage");

        let pipeline = Pipeline::from_loader(loader)
            .then(transformer)
            .then_store_with(storage);
        pipeline.run().await.unwrap();
    }

    #[tokio::test]
    async fn test_arbitrary_closures_as_transformer() {
        let mut loader = MockLoader::new();
        let transformer = |node: Node| {
            let mut node = node;
            node.chunk = "transformed".to_string();
            Ok(node)
        };
        let storage = MemoryStorage::default();
        let mut seq = Sequence::new();
        loader
            .expect_into_stream()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| vec![Ok(Node::default())].into());

        let pipeline = Pipeline::from_loader(loader)
            .then(transformer)
            .then_store_with(storage.clone());
        pipeline.run().await.unwrap();

        dbg!(storage.clone());
        let processed_node = storage.get("0").await.unwrap();
        assert_eq!(processed_node.chunk, "transformed");
    }

    #[tokio::test]
    async fn test_arbitrary_closures_as_batch_transformer() {
        let mut loader = MockLoader::new();
        let batch_transformer = |nodes: Vec<Node>| {
            IndexingStream::iter(nodes.into_iter().map(|mut node| {
                node.chunk = "transformed".to_string();
                Ok(node)
            }))
        };
        let storage = MemoryStorage::default();
        let mut seq = Sequence::new();
        loader
            .expect_into_stream()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| vec![Ok(Node::default())].into());

        let pipeline = Pipeline::from_loader(loader)
            .then_in_batch(batch_transformer)
            .then_store_with(storage.clone());
        pipeline.run().await.unwrap();

        dbg!(storage.clone());
        let processed_node = storage.get("0").await.unwrap();
        assert_eq!(processed_node.chunk, "transformed");
    }

    #[tokio::test]
    async fn test_filter_closure() {
        let mut loader = MockLoader::new();
        let storage = MemoryStorage::default();
        let mut seq = Sequence::new();
        loader
            .expect_into_stream()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| {
                vec![
                    Ok(Node::default()),
                    Ok(Node::new("skip")),
                    Ok(Node::default()),
                ]
                .into()
            });
        let pipeline = Pipeline::from_loader(loader)
            .filter(|result| {
                let node = result.as_ref().unwrap();
                node.chunk != "skip"
            })
            .then_store_with(storage.clone());
        pipeline.run().await.unwrap();
        let nodes = storage.get_all().await;
        assert_eq!(nodes.len(), 2);
    }

    #[test_log::test(tokio::test)]
    async fn test_split_and_merge() {
        let mut loader = MockLoader::new();
        let storage = MemoryStorage::default();
        let mut seq = Sequence::new();
        loader
            .expect_into_stream()
            .times(1)
            .in_sequence(&mut seq)
            .returning(|| {
                vec![
                    Ok(Node::default()),
                    Ok(Node::new("will go left")),
                    Ok(Node::default()),
                ]
                .into()
            });

        let pipeline = Pipeline::from_loader(loader);
        let (mut left, mut right) = pipeline.split_by(|node| {
            if let Ok(node) = node {
                node.chunk.starts_with("will go left")
            } else {
                false
            }
        });

        left = left
            .then(move |mut node: Node| {
                node.chunk = "left".to_string();

                Ok(node)
            })
            .log_all();

        right = right.then(move |mut node: Node| {
            node.chunk = "right".to_string();
            Ok(node)
        });

        left.merge(right)
            .then_store_with(storage.clone())
            .run()
            .await
            .unwrap();
        dbg!(storage.clone());

        let all_nodes = storage.get_all_values().await;
        assert_eq!(
            all_nodes.iter().filter(|node| node.chunk == "left").count(),
            1
        );
        assert_eq!(
            all_nodes
                .iter()
                .filter(|node| node.chunk == "right")
                .count(),
            2
        );
    }

    #[tokio::test]
    async fn test_all_steps_should_work_as_dyn_box() {
        let mut loader = MockLoader::new();
        loader
            .expect_into_stream_boxed()
            .returning(|| vec![Ok(Node::default())].into());

        let mut transformer = MockTransformer::new();
        transformer.expect_transform_node().returning(Ok);
        transformer.expect_concurrency().returning(|| None);
        transformer.expect_name().returning(|| "mock");

        let mut batch_transformer = MockBatchableTransformer::new();
        batch_transformer
            .expect_batch_transform()
            .returning(std::convert::Into::into);
        batch_transformer.expect_concurrency().returning(|| None);
        batch_transformer.expect_name().returning(|| "mock");
        let mut chunker = MockChunkerTransformer::new();
        chunker
            .expect_transform_node()
            .returning(|node| vec![node].into());
        chunker.expect_concurrency().returning(|| None);
        chunker.expect_name().returning(|| "mock");

        let mut storage = MockPersist::new();
        storage.expect_setup().returning(|| Ok(()));
        storage.expect_store().returning(Ok);
        storage.expect_batch_size().returning(|| None);
        storage.expect_name().returning(|| "mock");

        let pipeline = Pipeline::from_loader(Box::new(loader) as Box<dyn Loader>)
            .then(Box::new(transformer) as Box<dyn Transformer>)
            .then_in_batch(Box::new(batch_transformer) as Box<dyn BatchableTransformer>)
            .then_chunk(Box::new(chunker) as Box<dyn ChunkerTransformer>)
            .then_store_with(Box::new(storage) as Box<dyn Persist>);
        pipeline.run().await.unwrap();
    }

    #[tokio::test]
    async fn test_nodes_only_cached_on_success() {
        let mut loader = MockLoader::new();
        let mut cache = MockNodeCache::new();
        let mut storage = MockPersist::new();
        let mut transformer = MockTransformer::new();

        loader
            .expect_into_stream()
            .returning(|| vec![Ok(Node::default())].into());

        cache.expect_get().times(1).returning(|_| false);
        cache.expect_name().returning(|| "test_cache");

        transformer
            .expect_transform_node()
            .returning(|_| Err(anyhow::anyhow!("Transformation failed")));
        transformer.expect_concurrency().returning(|| None);
        transformer
            .expect_name()
            .returning(|| "failing_transformer");

        storage.expect_setup().returning(|| Ok(()));
        storage.expect_batch_size().returning(|| None);

        cache.expect_set().times(0);

        let pipeline = Pipeline::from_loader(loader)
            .filter_cached(cache)
            .then(transformer)
            .then_store_with(storage)
            .filter_errors();

        pipeline.run().await.unwrap();
    }

    #[tokio::test]
    async fn test_nodes_cached_on_successful_storage() {
        let mut loader = MockLoader::new();
        let mut cache = MockNodeCache::new();
        let storage = MemoryStorage::default();

        loader
            .expect_into_stream()
            .returning(|| vec![Ok(Node::default())].into());

        cache.expect_name().returning(|| "test_cache");
        cache.expect_get().times(1).returning(|_| false);

        cache.expect_set().times(1).returning(|_| ());

        let pipeline = Pipeline::from_loader(loader)
            .filter_cached(cache)
            .then_store_with(storage);

        pipeline.run().await.unwrap();
    }
}