swiftide_core/
indexing_traits.rs

1//! Traits in Swiftide allow for easy extendability
2//!
3//! All steps defined in the indexing pipeline and the generic transformers can also take a
4//! trait. To bring your own transformers, models and loaders, all you need to do is implement the
5//! trait and it should work out of the box.
6use crate::Embeddings;
7use crate::node::Node;
8use crate::{
9    SparseEmbeddings, indexing_defaults::IndexingDefaults, indexing_stream::IndexingStream,
10};
11use std::fmt::Debug;
12use std::sync::Arc;
13
14use crate::chat_completion::errors::LanguageModelError;
15use crate::prompt::Prompt;
16use anyhow::Result;
17use async_trait::async_trait;
18
19pub use dyn_clone::DynClone;
20/// All traits are easily mockable under tests
21#[cfg(feature = "test-utils")]
22#[doc(hidden)]
23use mockall::{mock, predicate::str};
24
25#[async_trait]
26/// Transforms single nodes into single nodes
27pub trait Transformer: Send + Sync + DynClone {
28    async fn transform_node(&self, node: Node) -> Result<Node>;
29
30    /// Overrides the default concurrency of the pipeline
31    fn concurrency(&self) -> Option<usize> {
32        None
33    }
34
35    fn name(&self) -> &'static str {
36        let name = std::any::type_name::<Self>();
37        name.split("::").last().unwrap_or(name)
38    }
39}
40
41dyn_clone::clone_trait_object!(Transformer);
42
// Mock of `Transformer` (generated as `MockTransformer`), only built with the
// `test-utils` feature.
#[cfg(feature = "test-utils")]
mock! {
    #[derive(Debug)]
    pub Transformer {}

    #[async_trait]
    impl Transformer for Transformer {
        fn name(&self) -> &'static str;
        fn concurrency(&self) -> Option<usize>;
        async fn transform_node(&self, node: Node) -> Result<Node>;
    }

    // `Transformer` has `DynClone` as a supertrait, so the mock must be clonable.
    impl Clone for Transformer {
        fn clone(&self) -> Self;
    }
}
59
60#[async_trait]
61impl Transformer for Box<dyn Transformer> {
62    async fn transform_node(&self, node: Node) -> Result<Node> {
63        self.as_ref().transform_node(node).await
64    }
65    fn concurrency(&self) -> Option<usize> {
66        self.as_ref().concurrency()
67    }
68    fn name(&self) -> &'static str {
69        self.as_ref().name()
70    }
71}
72
73#[async_trait]
74impl Transformer for Arc<dyn Transformer> {
75    async fn transform_node(&self, node: Node) -> Result<Node> {
76        self.as_ref().transform_node(node).await
77    }
78    fn concurrency(&self) -> Option<usize> {
79        self.as_ref().concurrency()
80    }
81    fn name(&self) -> &'static str {
82        self.as_ref().name()
83    }
84}
85
86#[async_trait]
87impl Transformer for &dyn Transformer {
88    async fn transform_node(&self, node: Node) -> Result<Node> {
89        (*self).transform_node(node).await
90    }
91    fn concurrency(&self) -> Option<usize> {
92        (*self).concurrency()
93    }
94}
95
96#[async_trait]
97/// Use a closure as a transformer
98impl<F> Transformer for F
99where
100    F: Fn(Node) -> Result<Node> + Send + Sync + Clone,
101{
102    async fn transform_node(&self, node: Node) -> Result<Node> {
103        self(node)
104    }
105}
106
107#[async_trait]
108/// Transforms batched single nodes into streams of nodes
109pub trait BatchableTransformer: Send + Sync + DynClone {
110    /// Transforms a batch of nodes into a stream of nodes
111    async fn batch_transform(&self, nodes: Vec<Node>) -> IndexingStream;
112
113    /// Overrides the default concurrency of the pipeline
114    fn concurrency(&self) -> Option<usize> {
115        None
116    }
117
118    fn name(&self) -> &'static str {
119        let name = std::any::type_name::<Self>();
120        name.split("::").last().unwrap_or(name)
121    }
122
123    /// Overrides the default batch size of the pipeline
124    fn batch_size(&self) -> Option<usize> {
125        None
126    }
127}
128
129dyn_clone::clone_trait_object!(BatchableTransformer);
130
// Mock of `BatchableTransformer` (generated as `MockBatchableTransformer`), only built with
// the `test-utils` feature.
#[cfg(feature = "test-utils")]
mock! {
    #[derive(Debug)]
    pub BatchableTransformer {}

    #[async_trait]
    impl BatchableTransformer for BatchableTransformer {
        fn name(&self) -> &'static str;
        fn concurrency(&self) -> Option<usize>;
        fn batch_size(&self) -> Option<usize>;
        async fn batch_transform(&self, nodes: Vec<Node>) -> IndexingStream;
    }

    // `BatchableTransformer` has `DynClone` as a supertrait, so the mock must be clonable.
    impl Clone for BatchableTransformer {
        fn clone(&self) -> Self;
    }
}
148#[async_trait]
149/// Use a closure as a batchable transformer
150impl<F> BatchableTransformer for F
151where
152    F: Fn(Vec<Node>) -> IndexingStream + Send + Sync + Clone,
153{
154    async fn batch_transform(&self, nodes: Vec<Node>) -> IndexingStream {
155        self(nodes)
156    }
157}
158
159#[async_trait]
160impl BatchableTransformer for Box<dyn BatchableTransformer> {
161    async fn batch_transform(&self, nodes: Vec<Node>) -> IndexingStream {
162        self.as_ref().batch_transform(nodes).await
163    }
164    fn concurrency(&self) -> Option<usize> {
165        self.as_ref().concurrency()
166    }
167    fn name(&self) -> &'static str {
168        self.as_ref().name()
169    }
170}
171
172#[async_trait]
173impl BatchableTransformer for Arc<dyn BatchableTransformer> {
174    async fn batch_transform(&self, nodes: Vec<Node>) -> IndexingStream {
175        self.as_ref().batch_transform(nodes).await
176    }
177    fn concurrency(&self) -> Option<usize> {
178        self.as_ref().concurrency()
179    }
180    fn name(&self) -> &'static str {
181        self.as_ref().name()
182    }
183}
184
185#[async_trait]
186impl BatchableTransformer for &dyn BatchableTransformer {
187    async fn batch_transform(&self, nodes: Vec<Node>) -> IndexingStream {
188        (*self).batch_transform(nodes).await
189    }
190    fn concurrency(&self) -> Option<usize> {
191        (*self).concurrency()
192    }
193}
194
195/// Starting point of a stream
196pub trait Loader: DynClone + Send + Sync {
197    fn into_stream(self) -> IndexingStream;
198
199    /// Intended for use with Box<dyn Loader>
200    ///
201    /// Only needed if you use trait objects (Box<dyn Loader>)
202    ///
203    /// # Example
204    ///
205    /// ```ignore
206    /// fn into_stream_boxed(self: Box<Self>) -> IndexingStream {
207    ///    self.into_stream()
208    ///  }
209    /// ```
210    fn into_stream_boxed(self: Box<Self>) -> IndexingStream {
211        unimplemented!(
212            "Please implement into_stream_boxed for your loader, it needs to be implemented on the concrete type"
213        )
214    }
215
216    fn name(&self) -> &'static str {
217        let name = std::any::type_name::<Self>();
218        name.split("::").last().unwrap_or(name)
219    }
220}
221
222dyn_clone::clone_trait_object!(Loader);
223
// Mock of `Loader`, only built with the `test-utils` feature.
#[cfg(feature = "test-utils")]
mock! {
    #[derive(Debug)]
    pub Loader {}

    #[async_trait]
    impl Loader for Loader {
        fn name(&self) -> &'static str;
        fn into_stream(self) -> IndexingStream;
        fn into_stream_boxed(self: Box<Self>) -> IndexingStream;
    }

    // `Loader` has `DynClone` as a supertrait, so the mock must be clonable.
    impl Clone for Loader {
        fn clone(&self) -> Self;
    }
}
240
241impl Loader for Box<dyn Loader> {
242    fn into_stream(self) -> IndexingStream {
243        Loader::into_stream_boxed(self)
244    }
245
246    fn into_stream_boxed(self: Box<Self>) -> IndexingStream {
247        Loader::into_stream(*self)
248    }
249    fn name(&self) -> &'static str {
250        self.as_ref().name()
251    }
252}
253
254impl Loader for &dyn Loader {
255    fn into_stream(self) -> IndexingStream {
256        Loader::into_stream_boxed(Box::new(self))
257    }
258
259    fn into_stream_boxed(self: Box<Self>) -> IndexingStream {
260        Loader::into_stream(*self)
261    }
262}
263
264#[async_trait]
265/// Turns one node into many nodes
266pub trait ChunkerTransformer: Send + Sync + DynClone {
267    async fn transform_node(&self, node: Node) -> IndexingStream;
268
269    /// Overrides the default concurrency of the pipeline
270    fn concurrency(&self) -> Option<usize> {
271        None
272    }
273
274    fn name(&self) -> &'static str {
275        let name = std::any::type_name::<Self>();
276        name.split("::").last().unwrap_or(name)
277    }
278}
279
280dyn_clone::clone_trait_object!(ChunkerTransformer);
281
// Mock of `ChunkerTransformer`, only built with the `test-utils` feature.
#[cfg(feature = "test-utils")]
mock! {
    #[derive(Debug)]
    pub ChunkerTransformer {}

    #[async_trait]
    impl ChunkerTransformer for ChunkerTransformer {
        fn name(&self) -> &'static str;
        fn concurrency(&self) -> Option<usize>;
        async fn transform_node(&self, node: Node) -> IndexingStream;
    }

    // `ChunkerTransformer` has `DynClone` as a supertrait, so the mock must be clonable.
    impl Clone for ChunkerTransformer {
        fn clone(&self) -> Self;
    }
}
298#[async_trait]
299impl ChunkerTransformer for Box<dyn ChunkerTransformer> {
300    async fn transform_node(&self, node: Node) -> IndexingStream {
301        self.as_ref().transform_node(node).await
302    }
303    fn concurrency(&self) -> Option<usize> {
304        self.as_ref().concurrency()
305    }
306    fn name(&self) -> &'static str {
307        self.as_ref().name()
308    }
309}
310
311#[async_trait]
312impl ChunkerTransformer for Arc<dyn ChunkerTransformer> {
313    async fn transform_node(&self, node: Node) -> IndexingStream {
314        self.as_ref().transform_node(node).await
315    }
316    fn concurrency(&self) -> Option<usize> {
317        self.as_ref().concurrency()
318    }
319    fn name(&self) -> &'static str {
320        self.as_ref().name()
321    }
322}
323
324#[async_trait]
325impl ChunkerTransformer for &dyn ChunkerTransformer {
326    async fn transform_node(&self, node: Node) -> IndexingStream {
327        (*self).transform_node(node).await
328    }
329    fn concurrency(&self) -> Option<usize> {
330        (*self).concurrency()
331    }
332}
333
334#[async_trait]
335impl<F> ChunkerTransformer for F
336where
337    F: Fn(Node) -> IndexingStream + Send + Sync + Clone,
338{
339    async fn transform_node(&self, node: Node) -> IndexingStream {
340        self(node)
341    }
342}
343
344// #[cfg_attr(feature = "test-utils", automock)]
345#[async_trait]
346/// Caches nodes, typically by their path and hash
347/// Recommended to namespace on the storage
348///
349/// For now just bool return value for easy filter
350pub trait NodeCache: Send + Sync + Debug + DynClone {
351    async fn get(&self, node: &Node) -> bool;
352    async fn set(&self, node: &Node);
353
354    /// Optionally provide a method to clear the cache
355    async fn clear(&self) -> Result<()> {
356        unimplemented!("Clear not implemented")
357    }
358
359    fn name(&self) -> &'static str {
360        let name = std::any::type_name::<Self>();
361        name.split("::").last().unwrap_or(name)
362    }
363}
364
365dyn_clone::clone_trait_object!(NodeCache);
366
// Mock of `NodeCache`, only built with the `test-utils` feature.
#[cfg(feature = "test-utils")]
mock! {
    #[derive(Debug)]
    pub NodeCache {}

    #[async_trait]
    impl NodeCache for NodeCache {
        fn name(&self) -> &'static str;
        async fn get(&self, node: &Node) -> bool;
        async fn set(&self, node: &Node);
        async fn clear(&self) -> Result<()>;
    }

    // `NodeCache` has `DynClone` as a supertrait, so the mock must be clonable.
    impl Clone for NodeCache {
        fn clone(&self) -> Self;
    }
}
385
386#[async_trait]
387impl NodeCache for Box<dyn NodeCache> {
388    async fn get(&self, node: &Node) -> bool {
389        self.as_ref().get(node).await
390    }
391    async fn set(&self, node: &Node) {
392        self.as_ref().set(node).await;
393    }
394    async fn clear(&self) -> Result<()> {
395        self.as_ref().clear().await
396    }
397    fn name(&self) -> &'static str {
398        self.as_ref().name()
399    }
400}
401
402#[async_trait]
403impl NodeCache for Arc<dyn NodeCache> {
404    async fn get(&self, node: &Node) -> bool {
405        self.as_ref().get(node).await
406    }
407    async fn set(&self, node: &Node) {
408        self.as_ref().set(node).await;
409    }
410    async fn clear(&self) -> Result<()> {
411        self.as_ref().clear().await
412    }
413    fn name(&self) -> &'static str {
414        self.as_ref().name()
415    }
416}
417
418#[async_trait]
419impl NodeCache for &dyn NodeCache {
420    async fn get(&self, node: &Node) -> bool {
421        (*self).get(node).await
422    }
423    async fn set(&self, node: &Node) {
424        (*self).set(node).await;
425    }
426    async fn clear(&self) -> Result<()> {
427        (*self).clear().await
428    }
429}
430
431#[async_trait]
432/// Embeds a list of strings and returns its embeddings.
433/// Assumes the strings will be moved.
434pub trait EmbeddingModel: Send + Sync + Debug + DynClone {
435    async fn embed(&self, input: Vec<String>) -> Result<Embeddings, LanguageModelError>;
436
437    fn name(&self) -> &'static str {
438        let name = std::any::type_name::<Self>();
439        name.split("::").last().unwrap_or(name)
440    }
441}
442
443dyn_clone::clone_trait_object!(EmbeddingModel);
444
// Mock of `EmbeddingModel`, only built with the `test-utils` feature.
#[cfg(feature = "test-utils")]
mock! {
    #[derive(Debug)]
    pub EmbeddingModel {}

    #[async_trait]
    impl EmbeddingModel for EmbeddingModel {
        fn name(&self) -> &'static str;
        async fn embed(&self, input: Vec<String>) -> Result<Embeddings, LanguageModelError>;
    }

    // `EmbeddingModel` has `DynClone` as a supertrait, so the mock must be clonable.
    impl Clone for EmbeddingModel {
        fn clone(&self) -> Self;
    }
}
460
461#[async_trait]
462impl EmbeddingModel for Box<dyn EmbeddingModel> {
463    async fn embed(&self, input: Vec<String>) -> Result<Embeddings, LanguageModelError> {
464        self.as_ref().embed(input).await
465    }
466
467    fn name(&self) -> &'static str {
468        self.as_ref().name()
469    }
470}
471
472#[async_trait]
473impl EmbeddingModel for Arc<dyn EmbeddingModel> {
474    async fn embed(&self, input: Vec<String>) -> Result<Embeddings, LanguageModelError> {
475        self.as_ref().embed(input).await
476    }
477
478    fn name(&self) -> &'static str {
479        self.as_ref().name()
480    }
481}
482
483#[async_trait]
484impl EmbeddingModel for &dyn EmbeddingModel {
485    async fn embed(&self, input: Vec<String>) -> Result<Embeddings, LanguageModelError> {
486        (*self).embed(input).await
487    }
488}
489
490#[async_trait]
491/// Embeds a list of strings and returns its embeddings.
492/// Assumes the strings will be moved.
493pub trait SparseEmbeddingModel: Send + Sync + Debug + DynClone {
494    async fn sparse_embed(
495        &self,
496        input: Vec<String>,
497    ) -> Result<SparseEmbeddings, LanguageModelError>;
498
499    fn name(&self) -> &'static str {
500        let name = std::any::type_name::<Self>();
501        name.split("::").last().unwrap_or(name)
502    }
503}
504
505dyn_clone::clone_trait_object!(SparseEmbeddingModel);
506
// Mock of `SparseEmbeddingModel`, only built with the `test-utils` feature.
#[cfg(feature = "test-utils")]
mock! {
    #[derive(Debug)]
    pub SparseEmbeddingModel {}

    #[async_trait]
    impl SparseEmbeddingModel for SparseEmbeddingModel {
        fn name(&self) -> &'static str;
        async fn sparse_embed(&self, input: Vec<String>) -> Result<SparseEmbeddings, LanguageModelError>;
    }

    // `SparseEmbeddingModel` has `DynClone` as a supertrait, so the mock must be clonable.
    impl Clone for SparseEmbeddingModel {
        fn clone(&self) -> Self;
    }
}
522
523#[async_trait]
524impl SparseEmbeddingModel for Box<dyn SparseEmbeddingModel> {
525    async fn sparse_embed(
526        &self,
527        input: Vec<String>,
528    ) -> Result<SparseEmbeddings, LanguageModelError> {
529        self.as_ref().sparse_embed(input).await
530    }
531
532    fn name(&self) -> &'static str {
533        self.as_ref().name()
534    }
535}
536
537#[async_trait]
538impl SparseEmbeddingModel for Arc<dyn SparseEmbeddingModel> {
539    async fn sparse_embed(
540        &self,
541        input: Vec<String>,
542    ) -> Result<SparseEmbeddings, LanguageModelError> {
543        self.as_ref().sparse_embed(input).await
544    }
545
546    fn name(&self) -> &'static str {
547        self.as_ref().name()
548    }
549}
550
551#[async_trait]
552impl SparseEmbeddingModel for &dyn SparseEmbeddingModel {
553    async fn sparse_embed(
554        &self,
555        input: Vec<String>,
556    ) -> Result<SparseEmbeddings, LanguageModelError> {
557        (*self).sparse_embed(input).await
558    }
559}
560
561#[async_trait]
562/// Given a string prompt, queries an LLM
563pub trait SimplePrompt: Debug + Send + Sync + DynClone {
564    // Takes a simple prompt, prompts the llm and returns the response
565    async fn prompt(&self, prompt: Prompt) -> Result<String, LanguageModelError>;
566
567    fn name(&self) -> &'static str {
568        let name = std::any::type_name::<Self>();
569        name.split("::").last().unwrap_or(name)
570    }
571}
572
573dyn_clone::clone_trait_object!(SimplePrompt);
574
// Mock of `SimplePrompt`, only built with the `test-utils` feature.
#[cfg(feature = "test-utils")]
mock! {
    #[derive(Debug)]
    pub SimplePrompt {}

    #[async_trait]
    impl SimplePrompt for SimplePrompt {
        fn name(&self) -> &'static str;
        async fn prompt(&self, prompt: Prompt) -> Result<String, LanguageModelError>;
    }

    // `SimplePrompt` has `DynClone` as a supertrait, so the mock must be clonable.
    impl Clone for SimplePrompt {
        fn clone(&self) -> Self;
    }
}
590
591#[async_trait]
592impl SimplePrompt for Box<dyn SimplePrompt> {
593    async fn prompt(&self, prompt: Prompt) -> Result<String, LanguageModelError> {
594        self.as_ref().prompt(prompt).await
595    }
596
597    fn name(&self) -> &'static str {
598        self.as_ref().name()
599    }
600}
601
602#[async_trait]
603impl SimplePrompt for Arc<dyn SimplePrompt> {
604    async fn prompt(&self, prompt: Prompt) -> Result<String, LanguageModelError> {
605        self.as_ref().prompt(prompt).await
606    }
607
608    fn name(&self) -> &'static str {
609        self.as_ref().name()
610    }
611}
612
613#[async_trait]
614impl SimplePrompt for &dyn SimplePrompt {
615    async fn prompt(&self, prompt: Prompt) -> Result<String, LanguageModelError> {
616        (*self).prompt(prompt).await
617    }
618}
619
620#[async_trait]
621/// Persists nodes
622pub trait Persist: Debug + Send + Sync + DynClone {
623    async fn setup(&self) -> Result<()>;
624    async fn store(&self, node: Node) -> Result<Node>;
625    async fn batch_store(&self, nodes: Vec<Node>) -> IndexingStream;
626    fn batch_size(&self) -> Option<usize> {
627        None
628    }
629
630    fn name(&self) -> &'static str {
631        let name = std::any::type_name::<Self>();
632        name.split("::").last().unwrap_or(name)
633    }
634}
635
636dyn_clone::clone_trait_object!(Persist);
637
// Mock of `Persist`, only built with the `test-utils` feature.
#[cfg(feature = "test-utils")]
mock! {
    #[derive(Debug)]
    pub Persist {}

    #[async_trait]
    impl Persist for Persist {
        fn name(&self) -> &'static str;
        fn batch_size(&self) -> Option<usize>;

        async fn setup(&self) -> Result<()>;
        async fn store(&self, node: Node) -> Result<Node>;
        async fn batch_store(&self, nodes: Vec<Node>) -> IndexingStream;
    }

    // `Persist` has `DynClone` as a supertrait, so the mock must be clonable.
    impl Clone for Persist {
        fn clone(&self) -> Self;
    }
}
657
658#[async_trait]
659impl Persist for Box<dyn Persist> {
660    async fn setup(&self) -> Result<()> {
661        self.as_ref().setup().await
662    }
663    async fn store(&self, node: Node) -> Result<Node> {
664        self.as_ref().store(node).await
665    }
666    async fn batch_store(&self, nodes: Vec<Node>) -> IndexingStream {
667        self.as_ref().batch_store(nodes).await
668    }
669    fn batch_size(&self) -> Option<usize> {
670        self.as_ref().batch_size()
671    }
672    fn name(&self) -> &'static str {
673        self.as_ref().name()
674    }
675}
676
677#[async_trait]
678impl Persist for Arc<dyn Persist> {
679    async fn setup(&self) -> Result<()> {
680        self.as_ref().setup().await
681    }
682    async fn store(&self, node: Node) -> Result<Node> {
683        self.as_ref().store(node).await
684    }
685    async fn batch_store(&self, nodes: Vec<Node>) -> IndexingStream {
686        self.as_ref().batch_store(nodes).await
687    }
688    fn batch_size(&self) -> Option<usize> {
689        self.as_ref().batch_size()
690    }
691    fn name(&self) -> &'static str {
692        self.as_ref().name()
693    }
694}
695
696#[async_trait]
697impl Persist for &dyn Persist {
698    async fn setup(&self) -> Result<()> {
699        (*self).setup().await
700    }
701    async fn store(&self, node: Node) -> Result<Node> {
702        (*self).store(node).await
703    }
704    async fn batch_store(&self, nodes: Vec<Node>) -> IndexingStream {
705        (*self).batch_store(nodes).await
706    }
707    fn batch_size(&self) -> Option<usize> {
708        (*self).batch_size()
709    }
710}
711
712/// Allows for passing defaults from the pipeline to the transformer
713/// Required for batch transformers as at least a marker, implementation is not required
714pub trait WithIndexingDefaults {
715    fn with_indexing_defaults(&mut self, _indexing_defaults: IndexingDefaults) {}
716}
717
718/// Allows for passing defaults from the pipeline to the batch transformer
719/// Required for batch transformers as at least a marker, implementation is not required
720pub trait WithBatchIndexingDefaults {
721    fn with_indexing_defaults(&mut self, _indexing_defaults: IndexingDefaults) {}
722}
723
724impl WithIndexingDefaults for dyn Transformer {}
725impl WithIndexingDefaults for Box<dyn Transformer> {
726    fn with_indexing_defaults(&mut self, indexing_defaults: IndexingDefaults) {
727        self.as_mut().with_indexing_defaults(indexing_defaults);
728    }
729}
730impl WithBatchIndexingDefaults for dyn BatchableTransformer {}
731impl WithBatchIndexingDefaults for Box<dyn BatchableTransformer> {
732    fn with_indexing_defaults(&mut self, indexing_defaults: IndexingDefaults) {
733        self.as_mut().with_indexing_defaults(indexing_defaults);
734    }
735}
736
737impl<F> WithIndexingDefaults for F where F: Fn(Node) -> Result<Node> {}
738impl<F> WithBatchIndexingDefaults for F where F: Fn(Vec<Node>) -> IndexingStream {}
739
// The mocks carry the marker traits too (provided no-op method), so they can stand in
// wherever the marker is required. Only built with the `test-utils` feature.
#[cfg(feature = "test-utils")]
impl WithIndexingDefaults for MockTransformer {}

#[cfg(feature = "test-utils")]
impl WithBatchIndexingDefaults for MockBatchableTransformer {}