swiftide_core/
indexing_traits.rs

//! Traits in Swiftide allow for easy extensibility
2//!
3//! All steps defined in the indexing pipeline and the generic transformers can also take a
4//! trait. To bring your own transformers, models and loaders, all you need to do is implement the
5//! trait and it should work out of the box.
6use crate::Embeddings;
7use crate::node::Node;
8use crate::{
9    SparseEmbeddings, indexing_defaults::IndexingDefaults, indexing_stream::IndexingStream,
10};
11use std::fmt::Debug;
12use std::sync::Arc;
13
14use crate::chat_completion::errors::LanguageModelError;
15use crate::prompt::Prompt;
16use anyhow::Result;
17use async_trait::async_trait;
18
19pub use dyn_clone::DynClone;
20/// All traits are easily mockable under tests
21#[cfg(feature = "test-utils")]
22#[doc(hidden)]
23use mockall::{mock, predicate::str};
24
25#[async_trait]
26/// Transforms single nodes into single nodes
27pub trait Transformer: Send + Sync + DynClone {
28    async fn transform_node(&self, node: Node) -> Result<Node>;
29
30    /// Overrides the default concurrency of the pipeline
31    fn concurrency(&self) -> Option<usize> {
32        None
33    }
34
35    fn name(&self) -> &'static str {
36        let name = std::any::type_name::<Self>();
37        name.split("::").last().unwrap_or(name)
38    }
39}
40
41dyn_clone::clone_trait_object!(Transformer);
42
// Mock of `Transformer` (generated as `MockTransformer`); only compiled when
// the `test-utils` feature is enabled.
#[cfg(feature = "test-utils")]
mock! {
    #[derive(Debug)]
    pub Transformer {}

    #[async_trait]
    impl Transformer for Transformer {
        fn name(&self) -> &'static str;
        fn concurrency(&self) -> Option<usize>;
        async fn transform_node(&self, node: Node) -> Result<Node>;
    }

    impl Clone for Transformer {
        fn clone(&self) -> Self;
    }
}
59
60#[async_trait]
61impl Transformer for Box<dyn Transformer> {
62    async fn transform_node(&self, node: Node) -> Result<Node> {
63        self.as_ref().transform_node(node).await
64    }
65    fn concurrency(&self) -> Option<usize> {
66        self.as_ref().concurrency()
67    }
68    fn name(&self) -> &'static str {
69        self.as_ref().name()
70    }
71}
72
73#[async_trait]
74impl Transformer for Arc<dyn Transformer> {
75    async fn transform_node(&self, node: Node) -> Result<Node> {
76        self.as_ref().transform_node(node).await
77    }
78    fn concurrency(&self) -> Option<usize> {
79        self.as_ref().concurrency()
80    }
81    fn name(&self) -> &'static str {
82        self.as_ref().name()
83    }
84}
85
86#[async_trait]
87impl Transformer for &dyn Transformer {
88    async fn transform_node(&self, node: Node) -> Result<Node> {
89        (*self).transform_node(node).await
90    }
91    fn concurrency(&self) -> Option<usize> {
92        (*self).concurrency()
93    }
94}
95
96#[async_trait]
97/// Use a closure as a transformer
98impl<F> Transformer for F
99where
100    F: Fn(Node) -> Result<Node> + Send + Sync + Clone,
101{
102    async fn transform_node(&self, node: Node) -> Result<Node> {
103        self(node)
104    }
105}
106
107#[async_trait]
108/// Transforms batched single nodes into streams of nodes
109pub trait BatchableTransformer: Send + Sync + DynClone {
110    /// Transforms a batch of nodes into a stream of nodes
111    async fn batch_transform(&self, nodes: Vec<Node>) -> IndexingStream;
112
113    /// Overrides the default concurrency of the pipeline
114    fn concurrency(&self) -> Option<usize> {
115        None
116    }
117
118    fn name(&self) -> &'static str {
119        let name = std::any::type_name::<Self>();
120        name.split("::").last().unwrap_or(name)
121    }
122
123    /// Overrides the default batch size of the pipeline
124    fn batch_size(&self) -> Option<usize> {
125        None
126    }
127}
128
129dyn_clone::clone_trait_object!(BatchableTransformer);
130
// Mock of `BatchableTransformer` (generated as `MockBatchableTransformer`);
// only compiled when the `test-utils` feature is enabled.
#[cfg(feature = "test-utils")]
mock! {
    #[derive(Debug)]
    pub BatchableTransformer {}

    #[async_trait]
    impl BatchableTransformer for BatchableTransformer {
        fn concurrency(&self) -> Option<usize>;
        fn batch_size(&self) -> Option<usize>;
        fn name(&self) -> &'static str;
        async fn batch_transform(&self, nodes: Vec<Node>) -> IndexingStream;
    }

    impl Clone for BatchableTransformer {
        fn clone(&self) -> Self;
    }
}
148#[async_trait]
149/// Use a closure as a batchable transformer
150impl<F> BatchableTransformer for F
151where
152    F: Fn(Vec<Node>) -> IndexingStream + Send + Sync + Clone,
153{
154    async fn batch_transform(&self, nodes: Vec<Node>) -> IndexingStream {
155        self(nodes)
156    }
157}
158
159#[async_trait]
160impl BatchableTransformer for Box<dyn BatchableTransformer> {
161    async fn batch_transform(&self, nodes: Vec<Node>) -> IndexingStream {
162        self.as_ref().batch_transform(nodes).await
163    }
164    fn concurrency(&self) -> Option<usize> {
165        self.as_ref().concurrency()
166    }
167    fn name(&self) -> &'static str {
168        self.as_ref().name()
169    }
170}
171
172#[async_trait]
173impl BatchableTransformer for Arc<dyn BatchableTransformer> {
174    async fn batch_transform(&self, nodes: Vec<Node>) -> IndexingStream {
175        self.as_ref().batch_transform(nodes).await
176    }
177    fn concurrency(&self) -> Option<usize> {
178        self.as_ref().concurrency()
179    }
180    fn name(&self) -> &'static str {
181        self.as_ref().name()
182    }
183}
184
185#[async_trait]
186impl BatchableTransformer for &dyn BatchableTransformer {
187    async fn batch_transform(&self, nodes: Vec<Node>) -> IndexingStream {
188        (*self).batch_transform(nodes).await
189    }
190    fn concurrency(&self) -> Option<usize> {
191        (*self).concurrency()
192    }
193}
194
195/// Starting point of a stream
196pub trait Loader: DynClone {
197    fn into_stream(self) -> IndexingStream;
198
199    /// Intended for use with Box<dyn Loader>
200    ///
201    /// Only needed if you use trait objects (Box<dyn Loader>)
202    ///
203    /// # Example
204    ///
205    /// ```ignore
206    /// fn into_stream_boxed(self: Box<Self>) -> IndexingStream {
207    ///    self.into_stream()
208    ///  }
209    /// ```
210    fn into_stream_boxed(self: Box<Self>) -> IndexingStream {
211        unimplemented!(
212            "Please implement into_stream_boxed for your loader, it needs to be implemented on the concrete type"
213        )
214    }
215
216    fn name(&self) -> &'static str {
217        let name = std::any::type_name::<Self>();
218        name.split("::").last().unwrap_or(name)
219    }
220}
221
222dyn_clone::clone_trait_object!(Loader);
223
// Mock of `Loader`; only compiled when the `test-utils` feature is enabled.
//
// `Loader` has no async methods, so the impl below carries no
// `#[async_trait]` attribute.
#[cfg(feature = "test-utils")]
mock! {
    #[derive(Debug)]
    pub Loader {}

    impl Loader for Loader {
        fn into_stream(self) -> IndexingStream;
        fn into_stream_boxed(self: Box<Self>) -> IndexingStream;
        fn name(&self) -> &'static str;
    }

    impl Clone for Loader {
        fn clone(&self) -> Self;
    }
}
240
241impl Loader for Box<dyn Loader> {
242    fn into_stream(self) -> IndexingStream {
243        Loader::into_stream_boxed(self)
244    }
245
246    fn into_stream_boxed(self: Box<Self>) -> IndexingStream {
247        Loader::into_stream(*self)
248    }
249    fn name(&self) -> &'static str {
250        self.as_ref().name()
251    }
252}
253
254impl Loader for &dyn Loader {
255    fn into_stream(self) -> IndexingStream {
256        Loader::into_stream_boxed(Box::new(self))
257    }
258
259    fn into_stream_boxed(self: Box<Self>) -> IndexingStream {
260        Loader::into_stream(*self)
261    }
262}
263
264#[async_trait]
265/// Turns one node into many nodes
266pub trait ChunkerTransformer: Send + Sync + Debug + DynClone {
267    async fn transform_node(&self, node: Node) -> IndexingStream;
268
269    /// Overrides the default concurrency of the pipeline
270    fn concurrency(&self) -> Option<usize> {
271        None
272    }
273
274    fn name(&self) -> &'static str {
275        let name = std::any::type_name::<Self>();
276        name.split("::").last().unwrap_or(name)
277    }
278}
279
280dyn_clone::clone_trait_object!(ChunkerTransformer);
281
// Mock of `ChunkerTransformer`; only compiled when the `test-utils` feature
// is enabled.
#[cfg(feature = "test-utils")]
mock! {
    #[derive(Debug)]
    pub ChunkerTransformer {}

    #[async_trait]
    impl ChunkerTransformer for ChunkerTransformer {
        fn concurrency(&self) -> Option<usize>;
        fn name(&self) -> &'static str;
        async fn transform_node(&self, node: Node) -> IndexingStream;
    }

    impl Clone for ChunkerTransformer {
        fn clone(&self) -> Self;
    }
}
298#[async_trait]
299impl ChunkerTransformer for Box<dyn ChunkerTransformer> {
300    async fn transform_node(&self, node: Node) -> IndexingStream {
301        self.as_ref().transform_node(node).await
302    }
303    fn concurrency(&self) -> Option<usize> {
304        self.as_ref().concurrency()
305    }
306    fn name(&self) -> &'static str {
307        self.as_ref().name()
308    }
309}
310
311#[async_trait]
312impl ChunkerTransformer for Arc<dyn ChunkerTransformer> {
313    async fn transform_node(&self, node: Node) -> IndexingStream {
314        self.as_ref().transform_node(node).await
315    }
316    fn concurrency(&self) -> Option<usize> {
317        self.as_ref().concurrency()
318    }
319    fn name(&self) -> &'static str {
320        self.as_ref().name()
321    }
322}
323
324#[async_trait]
325impl ChunkerTransformer for &dyn ChunkerTransformer {
326    async fn transform_node(&self, node: Node) -> IndexingStream {
327        (*self).transform_node(node).await
328    }
329    fn concurrency(&self) -> Option<usize> {
330        (*self).concurrency()
331    }
332}
333
334// #[cfg_attr(feature = "test-utils", automock)]
335#[async_trait]
336/// Caches nodes, typically by their path and hash
337/// Recommended to namespace on the storage
338///
339/// For now just bool return value for easy filter
340pub trait NodeCache: Send + Sync + Debug + DynClone {
341    async fn get(&self, node: &Node) -> bool;
342    async fn set(&self, node: &Node);
343
344    /// Optionally provide a method to clear the cache
345    async fn clear(&self) -> Result<()> {
346        unimplemented!("Clear not implemented")
347    }
348
349    fn name(&self) -> &'static str {
350        let name = std::any::type_name::<Self>();
351        name.split("::").last().unwrap_or(name)
352    }
353}
354
355dyn_clone::clone_trait_object!(NodeCache);
356
// Mock of `NodeCache`; only compiled when the `test-utils` feature is enabled.
#[cfg(feature = "test-utils")]
mock! {
    #[derive(Debug)]
    pub NodeCache {}

    #[async_trait]
    impl NodeCache for NodeCache {
        fn name(&self) -> &'static str;
        async fn clear(&self) -> Result<()>;
        async fn set(&self, node: &Node);
        async fn get(&self, node: &Node) -> bool;
    }

    impl Clone for NodeCache {
        fn clone(&self) -> Self;
    }
}
375
376#[async_trait]
377impl NodeCache for Box<dyn NodeCache> {
378    async fn get(&self, node: &Node) -> bool {
379        self.as_ref().get(node).await
380    }
381    async fn set(&self, node: &Node) {
382        self.as_ref().set(node).await;
383    }
384    async fn clear(&self) -> Result<()> {
385        self.as_ref().clear().await
386    }
387    fn name(&self) -> &'static str {
388        self.as_ref().name()
389    }
390}
391
392#[async_trait]
393impl NodeCache for Arc<dyn NodeCache> {
394    async fn get(&self, node: &Node) -> bool {
395        self.as_ref().get(node).await
396    }
397    async fn set(&self, node: &Node) {
398        self.as_ref().set(node).await;
399    }
400    async fn clear(&self) -> Result<()> {
401        self.as_ref().clear().await
402    }
403    fn name(&self) -> &'static str {
404        self.as_ref().name()
405    }
406}
407
408#[async_trait]
409impl NodeCache for &dyn NodeCache {
410    async fn get(&self, node: &Node) -> bool {
411        (*self).get(node).await
412    }
413    async fn set(&self, node: &Node) {
414        (*self).set(node).await;
415    }
416    async fn clear(&self) -> Result<()> {
417        (*self).clear().await
418    }
419}
420
421#[async_trait]
422/// Embeds a list of strings and returns its embeddings.
423/// Assumes the strings will be moved.
424pub trait EmbeddingModel: Send + Sync + Debug + DynClone {
425    async fn embed(&self, input: Vec<String>) -> Result<Embeddings, LanguageModelError>;
426
427    fn name(&self) -> &'static str {
428        let name = std::any::type_name::<Self>();
429        name.split("::").last().unwrap_or(name)
430    }
431}
432
433dyn_clone::clone_trait_object!(EmbeddingModel);
434
// Mock of `EmbeddingModel`; only compiled when the `test-utils` feature is
// enabled.
#[cfg(feature = "test-utils")]
mock! {
    #[derive(Debug)]
    pub EmbeddingModel {}

    #[async_trait]
    impl EmbeddingModel for EmbeddingModel {
        fn name(&self) -> &'static str;
        async fn embed(&self, input: Vec<String>) -> Result<Embeddings, LanguageModelError>;
    }

    impl Clone for EmbeddingModel {
        fn clone(&self) -> Self;
    }
}
450
451#[async_trait]
452impl EmbeddingModel for Box<dyn EmbeddingModel> {
453    async fn embed(&self, input: Vec<String>) -> Result<Embeddings, LanguageModelError> {
454        self.as_ref().embed(input).await
455    }
456
457    fn name(&self) -> &'static str {
458        self.as_ref().name()
459    }
460}
461
462#[async_trait]
463impl EmbeddingModel for Arc<dyn EmbeddingModel> {
464    async fn embed(&self, input: Vec<String>) -> Result<Embeddings, LanguageModelError> {
465        self.as_ref().embed(input).await
466    }
467
468    fn name(&self) -> &'static str {
469        self.as_ref().name()
470    }
471}
472
473#[async_trait]
474impl EmbeddingModel for &dyn EmbeddingModel {
475    async fn embed(&self, input: Vec<String>) -> Result<Embeddings, LanguageModelError> {
476        (*self).embed(input).await
477    }
478}
479
480#[async_trait]
481/// Embeds a list of strings and returns its embeddings.
482/// Assumes the strings will be moved.
483pub trait SparseEmbeddingModel: Send + Sync + Debug + DynClone {
484    async fn sparse_embed(
485        &self,
486        input: Vec<String>,
487    ) -> Result<SparseEmbeddings, LanguageModelError>;
488
489    fn name(&self) -> &'static str {
490        let name = std::any::type_name::<Self>();
491        name.split("::").last().unwrap_or(name)
492    }
493}
494
495dyn_clone::clone_trait_object!(SparseEmbeddingModel);
496
// Mock of `SparseEmbeddingModel`; only compiled when the `test-utils` feature
// is enabled.
#[cfg(feature = "test-utils")]
mock! {
    #[derive(Debug)]
    pub SparseEmbeddingModel {}

    #[async_trait]
    impl SparseEmbeddingModel for SparseEmbeddingModel {
        fn name(&self) -> &'static str;
        async fn sparse_embed(&self, input: Vec<String>) -> Result<SparseEmbeddings, LanguageModelError>;
    }

    impl Clone for SparseEmbeddingModel {
        fn clone(&self) -> Self;
    }
}
512
513#[async_trait]
514impl SparseEmbeddingModel for Box<dyn SparseEmbeddingModel> {
515    async fn sparse_embed(
516        &self,
517        input: Vec<String>,
518    ) -> Result<SparseEmbeddings, LanguageModelError> {
519        self.as_ref().sparse_embed(input).await
520    }
521
522    fn name(&self) -> &'static str {
523        self.as_ref().name()
524    }
525}
526
527#[async_trait]
528impl SparseEmbeddingModel for Arc<dyn SparseEmbeddingModel> {
529    async fn sparse_embed(
530        &self,
531        input: Vec<String>,
532    ) -> Result<SparseEmbeddings, LanguageModelError> {
533        self.as_ref().sparse_embed(input).await
534    }
535
536    fn name(&self) -> &'static str {
537        self.as_ref().name()
538    }
539}
540
541#[async_trait]
542impl SparseEmbeddingModel for &dyn SparseEmbeddingModel {
543    async fn sparse_embed(
544        &self,
545        input: Vec<String>,
546    ) -> Result<SparseEmbeddings, LanguageModelError> {
547        (*self).sparse_embed(input).await
548    }
549}
550
551#[async_trait]
552/// Given a string prompt, queries an LLM
553pub trait SimplePrompt: Debug + Send + Sync + DynClone {
554    // Takes a simple prompt, prompts the llm and returns the response
555    async fn prompt(&self, prompt: Prompt) -> Result<String, LanguageModelError>;
556
557    fn name(&self) -> &'static str {
558        let name = std::any::type_name::<Self>();
559        name.split("::").last().unwrap_or(name)
560    }
561}
562
563dyn_clone::clone_trait_object!(SimplePrompt);
564
// Mock of `SimplePrompt`; only compiled when the `test-utils` feature is
// enabled.
#[cfg(feature = "test-utils")]
mock! {
    #[derive(Debug)]
    pub SimplePrompt {}

    #[async_trait]
    impl SimplePrompt for SimplePrompt {
        fn name(&self) -> &'static str;
        async fn prompt(&self, prompt: Prompt) -> Result<String, LanguageModelError>;
    }

    impl Clone for SimplePrompt {
        fn clone(&self) -> Self;
    }
}
580
581#[async_trait]
582impl SimplePrompt for Box<dyn SimplePrompt> {
583    async fn prompt(&self, prompt: Prompt) -> Result<String, LanguageModelError> {
584        self.as_ref().prompt(prompt).await
585    }
586
587    fn name(&self) -> &'static str {
588        self.as_ref().name()
589    }
590}
591
592#[async_trait]
593impl SimplePrompt for Arc<dyn SimplePrompt> {
594    async fn prompt(&self, prompt: Prompt) -> Result<String, LanguageModelError> {
595        self.as_ref().prompt(prompt).await
596    }
597
598    fn name(&self) -> &'static str {
599        self.as_ref().name()
600    }
601}
602
603#[async_trait]
604impl SimplePrompt for &dyn SimplePrompt {
605    async fn prompt(&self, prompt: Prompt) -> Result<String, LanguageModelError> {
606        (*self).prompt(prompt).await
607    }
608}
609
610#[async_trait]
611/// Persists nodes
612pub trait Persist: Debug + Send + Sync + DynClone {
613    async fn setup(&self) -> Result<()>;
614    async fn store(&self, node: Node) -> Result<Node>;
615    async fn batch_store(&self, nodes: Vec<Node>) -> IndexingStream;
616    fn batch_size(&self) -> Option<usize> {
617        None
618    }
619
620    fn name(&self) -> &'static str {
621        let name = std::any::type_name::<Self>();
622        name.split("::").last().unwrap_or(name)
623    }
624}
625
626dyn_clone::clone_trait_object!(Persist);
627
// Mock of `Persist`; only compiled when the `test-utils` feature is enabled.
#[cfg(feature = "test-utils")]
mock! {
    #[derive(Debug)]
    pub Persist {}

    #[async_trait]
    impl Persist for Persist {
        fn name(&self) -> &'static str;
        fn batch_size(&self) -> Option<usize>;
        async fn batch_store(&self, nodes: Vec<Node>) -> IndexingStream;
        async fn store(&self, node: Node) -> Result<Node>;
        async fn setup(&self) -> Result<()>;
    }

    impl Clone for Persist {
        fn clone(&self) -> Self;
    }
}
647
648#[async_trait]
649impl Persist for Box<dyn Persist> {
650    async fn setup(&self) -> Result<()> {
651        self.as_ref().setup().await
652    }
653    async fn store(&self, node: Node) -> Result<Node> {
654        self.as_ref().store(node).await
655    }
656    async fn batch_store(&self, nodes: Vec<Node>) -> IndexingStream {
657        self.as_ref().batch_store(nodes).await
658    }
659    fn batch_size(&self) -> Option<usize> {
660        self.as_ref().batch_size()
661    }
662    fn name(&self) -> &'static str {
663        self.as_ref().name()
664    }
665}
666
667#[async_trait]
668impl Persist for Arc<dyn Persist> {
669    async fn setup(&self) -> Result<()> {
670        self.as_ref().setup().await
671    }
672    async fn store(&self, node: Node) -> Result<Node> {
673        self.as_ref().store(node).await
674    }
675    async fn batch_store(&self, nodes: Vec<Node>) -> IndexingStream {
676        self.as_ref().batch_store(nodes).await
677    }
678    fn batch_size(&self) -> Option<usize> {
679        self.as_ref().batch_size()
680    }
681    fn name(&self) -> &'static str {
682        self.as_ref().name()
683    }
684}
685
686#[async_trait]
687impl Persist for &dyn Persist {
688    async fn setup(&self) -> Result<()> {
689        (*self).setup().await
690    }
691    async fn store(&self, node: Node) -> Result<Node> {
692        (*self).store(node).await
693    }
694    async fn batch_store(&self, nodes: Vec<Node>) -> IndexingStream {
695        (*self).batch_store(nodes).await
696    }
697    fn batch_size(&self) -> Option<usize> {
698        (*self).batch_size()
699    }
700}
701
702/// Allows for passing defaults from the pipeline to the transformer
703/// Required for batch transformers as at least a marker, implementation is not required
704pub trait WithIndexingDefaults {
705    fn with_indexing_defaults(&mut self, _indexing_defaults: IndexingDefaults) {}
706}
707
708/// Allows for passing defaults from the pipeline to the batch transformer
709/// Required for batch transformers as at least a marker, implementation is not required
710pub trait WithBatchIndexingDefaults {
711    fn with_indexing_defaults(&mut self, _indexing_defaults: IndexingDefaults) {}
712}
713
714impl WithIndexingDefaults for dyn Transformer {}
715impl WithIndexingDefaults for Box<dyn Transformer> {
716    fn with_indexing_defaults(&mut self, indexing_defaults: IndexingDefaults) {
717        self.as_mut().with_indexing_defaults(indexing_defaults);
718    }
719}
720impl WithBatchIndexingDefaults for dyn BatchableTransformer {}
721impl WithBatchIndexingDefaults for Box<dyn BatchableTransformer> {
722    fn with_indexing_defaults(&mut self, indexing_defaults: IndexingDefaults) {
723        self.as_mut().with_indexing_defaults(indexing_defaults);
724    }
725}
726
727impl<F> WithIndexingDefaults for F where F: Fn(Node) -> Result<Node> {}
728impl<F> WithBatchIndexingDefaults for F where F: Fn(Vec<Node>) -> IndexingStream {}
729
// The mocks get the marker traits too (no-op defaults), but only when the
// `test-utils` feature is enabled.
#[cfg(feature = "test-utils")]
impl WithIndexingDefaults for MockTransformer {}

#[cfg(feature = "test-utils")]
impl WithBatchIndexingDefaults for MockBatchableTransformer {}