swiftide_core/
indexing_traits.rs

//! Traits in Swiftide allow for easy extensibility.
//!
//! All steps defined in the indexing pipeline, as well as the generic transformers, accept
//! implementations of these traits. To bring your own transformers, models and loaders, all you
//! need to do is implement the trait and it should work out of the box.
6use crate::node::Node;
7use crate::Embeddings;
8use crate::{
9    indexing_defaults::IndexingDefaults, indexing_stream::IndexingStream, SparseEmbeddings,
10};
11use std::fmt::Debug;
12use std::sync::Arc;
13
14use crate::chat_completion::errors::LanguageModelError;
15use crate::prompt::Prompt;
16use anyhow::Result;
17use async_trait::async_trait;
18
19pub use dyn_clone::DynClone;
20/// All traits are easily mockable under tests
21#[cfg(feature = "test-utils")]
22#[doc(hidden)]
23use mockall::{mock, predicate::str};
24
25#[async_trait]
26/// Transforms single nodes into single nodes
27pub trait Transformer: Send + Sync + DynClone {
28    async fn transform_node(&self, node: Node) -> Result<Node>;
29
30    /// Overrides the default concurrency of the pipeline
31    fn concurrency(&self) -> Option<usize> {
32        None
33    }
34
35    fn name(&self) -> &'static str {
36        let name = std::any::type_name::<Self>();
37        name.split("::").last().unwrap_or(name)
38    }
39}
40
41dyn_clone::clone_trait_object!(Transformer);
42
// mockall test double for `Transformer`, generating `MockTransformer`.
#[cfg(feature = "test-utils")]
mock! {
    #[derive(Debug)]
    pub Transformer {}

    #[async_trait]
    impl Transformer for Transformer {
        fn name(&self) -> &'static str;
        fn concurrency(&self) -> Option<usize>;
        async fn transform_node(&self, node: Node) -> Result<Node>;
    }

    impl Clone for Transformer {
        fn clone(&self) -> Self;
    }
}
59
60#[async_trait]
61impl Transformer for Box<dyn Transformer> {
62    async fn transform_node(&self, node: Node) -> Result<Node> {
63        self.as_ref().transform_node(node).await
64    }
65    fn concurrency(&self) -> Option<usize> {
66        self.as_ref().concurrency()
67    }
68    fn name(&self) -> &'static str {
69        self.as_ref().name()
70    }
71}
72
73#[async_trait]
74impl Transformer for Arc<dyn Transformer> {
75    async fn transform_node(&self, node: Node) -> Result<Node> {
76        self.as_ref().transform_node(node).await
77    }
78    fn concurrency(&self) -> Option<usize> {
79        self.as_ref().concurrency()
80    }
81    fn name(&self) -> &'static str {
82        self.as_ref().name()
83    }
84}
85
86#[async_trait]
87impl Transformer for &dyn Transformer {
88    async fn transform_node(&self, node: Node) -> Result<Node> {
89        (*self).transform_node(node).await
90    }
91    fn concurrency(&self) -> Option<usize> {
92        (*self).concurrency()
93    }
94}
95
96#[async_trait]
97/// Use a closure as a transformer
98impl<F> Transformer for F
99where
100    F: Fn(Node) -> Result<Node> + Send + Sync + Clone,
101{
102    async fn transform_node(&self, node: Node) -> Result<Node> {
103        self(node)
104    }
105}
106
107#[async_trait]
108/// Transforms batched single nodes into streams of nodes
109pub trait BatchableTransformer: Send + Sync + DynClone {
110    /// Transforms a batch of nodes into a stream of nodes
111    async fn batch_transform(&self, nodes: Vec<Node>) -> IndexingStream;
112
113    /// Overrides the default concurrency of the pipeline
114    fn concurrency(&self) -> Option<usize> {
115        None
116    }
117
118    fn name(&self) -> &'static str {
119        let name = std::any::type_name::<Self>();
120        name.split("::").last().unwrap_or(name)
121    }
122
123    /// Overrides the default batch size of the pipeline
124    fn batch_size(&self) -> Option<usize> {
125        None
126    }
127}
128
129dyn_clone::clone_trait_object!(BatchableTransformer);
130
// mockall test double for `BatchableTransformer`, generating `MockBatchableTransformer`.
#[cfg(feature = "test-utils")]
mock! {
    #[derive(Debug)]
    pub BatchableTransformer {}

    #[async_trait]
    impl BatchableTransformer for BatchableTransformer {
        fn concurrency(&self) -> Option<usize>;
        fn batch_size(&self) -> Option<usize>;
        fn name(&self) -> &'static str;
        async fn batch_transform(&self, nodes: Vec<Node>) -> IndexingStream;
    }

    impl Clone for BatchableTransformer {
        fn clone(&self) -> Self;
    }
}
148#[async_trait]
149/// Use a closure as a batchable transformer
150impl<F> BatchableTransformer for F
151where
152    F: Fn(Vec<Node>) -> IndexingStream + Send + Sync + Clone,
153{
154    async fn batch_transform(&self, nodes: Vec<Node>) -> IndexingStream {
155        self(nodes)
156    }
157}
158
159#[async_trait]
160impl BatchableTransformer for Box<dyn BatchableTransformer> {
161    async fn batch_transform(&self, nodes: Vec<Node>) -> IndexingStream {
162        self.as_ref().batch_transform(nodes).await
163    }
164    fn concurrency(&self) -> Option<usize> {
165        self.as_ref().concurrency()
166    }
167    fn name(&self) -> &'static str {
168        self.as_ref().name()
169    }
170}
171
172#[async_trait]
173impl BatchableTransformer for Arc<dyn BatchableTransformer> {
174    async fn batch_transform(&self, nodes: Vec<Node>) -> IndexingStream {
175        self.as_ref().batch_transform(nodes).await
176    }
177    fn concurrency(&self) -> Option<usize> {
178        self.as_ref().concurrency()
179    }
180    fn name(&self) -> &'static str {
181        self.as_ref().name()
182    }
183}
184
185#[async_trait]
186impl BatchableTransformer for &dyn BatchableTransformer {
187    async fn batch_transform(&self, nodes: Vec<Node>) -> IndexingStream {
188        (*self).batch_transform(nodes).await
189    }
190    fn concurrency(&self) -> Option<usize> {
191        (*self).concurrency()
192    }
193}
194
195/// Starting point of a stream
196pub trait Loader: DynClone {
197    fn into_stream(self) -> IndexingStream;
198
199    /// Intended for use with Box<dyn Loader>
200    ///
201    /// Only needed if you use trait objects (Box<dyn Loader>)
202    ///
203    /// # Example
204    ///
205    /// ```ignore
206    /// fn into_stream_boxed(self: Box<Self>) -> IndexingStream {
207    ///    self.into_stream()
208    ///  }
209    /// ```
210    fn into_stream_boxed(self: Box<Self>) -> IndexingStream {
211        unimplemented!("Please implement into_stream_boxed for your loader, it needs to be implemented on the concrete type")
212    }
213
214    fn name(&self) -> &'static str {
215        let name = std::any::type_name::<Self>();
216        name.split("::").last().unwrap_or(name)
217    }
218}
219
220dyn_clone::clone_trait_object!(Loader);
221
// mockall test double for `Loader`, generating `MockLoader`.
#[cfg(feature = "test-utils")]
mock! {
    #[derive(Debug)]
    pub Loader {}

    #[async_trait]
    impl Loader for Loader {
        fn name(&self) -> &'static str;
        fn into_stream_boxed(self: Box<Self>) -> IndexingStream;
        fn into_stream(self) -> IndexingStream;
    }

    impl Clone for Loader {
        fn clone(&self) -> Self;
    }
}
238
239impl Loader for Box<dyn Loader> {
240    fn into_stream(self) -> IndexingStream {
241        Loader::into_stream_boxed(self)
242    }
243
244    fn into_stream_boxed(self: Box<Self>) -> IndexingStream {
245        Loader::into_stream(*self)
246    }
247    fn name(&self) -> &'static str {
248        self.as_ref().name()
249    }
250}
251
252impl Loader for &dyn Loader {
253    fn into_stream(self) -> IndexingStream {
254        Loader::into_stream_boxed(Box::new(self))
255    }
256
257    fn into_stream_boxed(self: Box<Self>) -> IndexingStream {
258        Loader::into_stream(*self)
259    }
260}
261
262#[async_trait]
263/// Turns one node into many nodes
264pub trait ChunkerTransformer: Send + Sync + Debug + DynClone {
265    async fn transform_node(&self, node: Node) -> IndexingStream;
266
267    /// Overrides the default concurrency of the pipeline
268    fn concurrency(&self) -> Option<usize> {
269        None
270    }
271
272    fn name(&self) -> &'static str {
273        let name = std::any::type_name::<Self>();
274        name.split("::").last().unwrap_or(name)
275    }
276}
277
278dyn_clone::clone_trait_object!(ChunkerTransformer);
279
// mockall test double for `ChunkerTransformer`, generating `MockChunkerTransformer`.
#[cfg(feature = "test-utils")]
mock! {
    #[derive(Debug)]
    pub ChunkerTransformer {}

    #[async_trait]
    impl ChunkerTransformer for ChunkerTransformer {
        fn concurrency(&self) -> Option<usize>;
        fn name(&self) -> &'static str;
        async fn transform_node(&self, node: Node) -> IndexingStream;
    }

    impl Clone for ChunkerTransformer {
        fn clone(&self) -> Self;
    }
}
296#[async_trait]
297impl ChunkerTransformer for Box<dyn ChunkerTransformer> {
298    async fn transform_node(&self, node: Node) -> IndexingStream {
299        self.as_ref().transform_node(node).await
300    }
301    fn concurrency(&self) -> Option<usize> {
302        self.as_ref().concurrency()
303    }
304    fn name(&self) -> &'static str {
305        self.as_ref().name()
306    }
307}
308
309#[async_trait]
310impl ChunkerTransformer for Arc<dyn ChunkerTransformer> {
311    async fn transform_node(&self, node: Node) -> IndexingStream {
312        self.as_ref().transform_node(node).await
313    }
314    fn concurrency(&self) -> Option<usize> {
315        self.as_ref().concurrency()
316    }
317    fn name(&self) -> &'static str {
318        self.as_ref().name()
319    }
320}
321
322#[async_trait]
323impl ChunkerTransformer for &dyn ChunkerTransformer {
324    async fn transform_node(&self, node: Node) -> IndexingStream {
325        (*self).transform_node(node).await
326    }
327    fn concurrency(&self) -> Option<usize> {
328        (*self).concurrency()
329    }
330}
331
332// #[cfg_attr(feature = "test-utils", automock)]
333#[async_trait]
334/// Caches nodes, typically by their path and hash
335/// Recommended to namespace on the storage
336///
337/// For now just bool return value for easy filter
338pub trait NodeCache: Send + Sync + Debug + DynClone {
339    async fn get(&self, node: &Node) -> bool;
340    async fn set(&self, node: &Node);
341
342    /// Optionally provide a method to clear the cache
343    async fn clear(&self) -> Result<()> {
344        unimplemented!("Clear not implemented")
345    }
346
347    fn name(&self) -> &'static str {
348        let name = std::any::type_name::<Self>();
349        name.split("::").last().unwrap_or(name)
350    }
351}
352
353dyn_clone::clone_trait_object!(NodeCache);
354
// mockall test double for `NodeCache`, generating `MockNodeCache`.
#[cfg(feature = "test-utils")]
mock! {
    #[derive(Debug)]
    pub NodeCache {}

    #[async_trait]
    impl NodeCache for NodeCache {
        fn name(&self) -> &'static str;
        async fn clear(&self) -> Result<()>;
        async fn set(&self, node: &Node);
        async fn get(&self, node: &Node) -> bool;
    }

    impl Clone for NodeCache {
        fn clone(&self) -> Self;
    }
}
373
374#[async_trait]
375impl NodeCache for Box<dyn NodeCache> {
376    async fn get(&self, node: &Node) -> bool {
377        self.as_ref().get(node).await
378    }
379    async fn set(&self, node: &Node) {
380        self.as_ref().set(node).await;
381    }
382    async fn clear(&self) -> Result<()> {
383        self.as_ref().clear().await
384    }
385    fn name(&self) -> &'static str {
386        self.as_ref().name()
387    }
388}
389
390#[async_trait]
391impl NodeCache for Arc<dyn NodeCache> {
392    async fn get(&self, node: &Node) -> bool {
393        self.as_ref().get(node).await
394    }
395    async fn set(&self, node: &Node) {
396        self.as_ref().set(node).await;
397    }
398    async fn clear(&self) -> Result<()> {
399        self.as_ref().clear().await
400    }
401    fn name(&self) -> &'static str {
402        self.as_ref().name()
403    }
404}
405
406#[async_trait]
407impl NodeCache for &dyn NodeCache {
408    async fn get(&self, node: &Node) -> bool {
409        (*self).get(node).await
410    }
411    async fn set(&self, node: &Node) {
412        (*self).set(node).await;
413    }
414    async fn clear(&self) -> Result<()> {
415        (*self).clear().await
416    }
417}
418
419#[async_trait]
420/// Embeds a list of strings and returns its embeddings.
421/// Assumes the strings will be moved.
422pub trait EmbeddingModel: Send + Sync + Debug + DynClone {
423    async fn embed(&self, input: Vec<String>) -> Result<Embeddings, LanguageModelError>;
424
425    fn name(&self) -> &'static str {
426        let name = std::any::type_name::<Self>();
427        name.split("::").last().unwrap_or(name)
428    }
429}
430
431dyn_clone::clone_trait_object!(EmbeddingModel);
432
// mockall test double for `EmbeddingModel`, generating `MockEmbeddingModel`.
#[cfg(feature = "test-utils")]
mock! {
    #[derive(Debug)]
    pub EmbeddingModel {}

    #[async_trait]
    impl EmbeddingModel for EmbeddingModel {
        fn name(&self) -> &'static str;
        async fn embed(&self, input: Vec<String>) -> Result<Embeddings, LanguageModelError>;
    }

    impl Clone for EmbeddingModel {
        fn clone(&self) -> Self;
    }
}
448
449#[async_trait]
450impl EmbeddingModel for Box<dyn EmbeddingModel> {
451    async fn embed(&self, input: Vec<String>) -> Result<Embeddings, LanguageModelError> {
452        self.as_ref().embed(input).await
453    }
454
455    fn name(&self) -> &'static str {
456        self.as_ref().name()
457    }
458}
459
460#[async_trait]
461impl EmbeddingModel for Arc<dyn EmbeddingModel> {
462    async fn embed(&self, input: Vec<String>) -> Result<Embeddings, LanguageModelError> {
463        self.as_ref().embed(input).await
464    }
465
466    fn name(&self) -> &'static str {
467        self.as_ref().name()
468    }
469}
470
471#[async_trait]
472impl EmbeddingModel for &dyn EmbeddingModel {
473    async fn embed(&self, input: Vec<String>) -> Result<Embeddings, LanguageModelError> {
474        (*self).embed(input).await
475    }
476}
477
478#[async_trait]
479/// Embeds a list of strings and returns its embeddings.
480/// Assumes the strings will be moved.
481pub trait SparseEmbeddingModel: Send + Sync + Debug + DynClone {
482    async fn sparse_embed(
483        &self,
484        input: Vec<String>,
485    ) -> Result<SparseEmbeddings, LanguageModelError>;
486
487    fn name(&self) -> &'static str {
488        let name = std::any::type_name::<Self>();
489        name.split("::").last().unwrap_or(name)
490    }
491}
492
493dyn_clone::clone_trait_object!(SparseEmbeddingModel);
494
// mockall test double for `SparseEmbeddingModel`, generating `MockSparseEmbeddingModel`.
#[cfg(feature = "test-utils")]
mock! {
    #[derive(Debug)]
    pub SparseEmbeddingModel {}

    #[async_trait]
    impl SparseEmbeddingModel for SparseEmbeddingModel {
        fn name(&self) -> &'static str;
        async fn sparse_embed(&self, input: Vec<String>) -> Result<SparseEmbeddings, LanguageModelError>;
    }

    impl Clone for SparseEmbeddingModel {
        fn clone(&self) -> Self;
    }
}
510
511#[async_trait]
512impl SparseEmbeddingModel for Box<dyn SparseEmbeddingModel> {
513    async fn sparse_embed(
514        &self,
515        input: Vec<String>,
516    ) -> Result<SparseEmbeddings, LanguageModelError> {
517        self.as_ref().sparse_embed(input).await
518    }
519
520    fn name(&self) -> &'static str {
521        self.as_ref().name()
522    }
523}
524
525#[async_trait]
526impl SparseEmbeddingModel for Arc<dyn SparseEmbeddingModel> {
527    async fn sparse_embed(
528        &self,
529        input: Vec<String>,
530    ) -> Result<SparseEmbeddings, LanguageModelError> {
531        self.as_ref().sparse_embed(input).await
532    }
533
534    fn name(&self) -> &'static str {
535        self.as_ref().name()
536    }
537}
538
539#[async_trait]
540impl SparseEmbeddingModel for &dyn SparseEmbeddingModel {
541    async fn sparse_embed(
542        &self,
543        input: Vec<String>,
544    ) -> Result<SparseEmbeddings, LanguageModelError> {
545        (*self).sparse_embed(input).await
546    }
547}
548
549#[async_trait]
550/// Given a string prompt, queries an LLM
551pub trait SimplePrompt: Debug + Send + Sync + DynClone {
552    // Takes a simple prompt, prompts the llm and returns the response
553    async fn prompt(&self, prompt: Prompt) -> Result<String, LanguageModelError>;
554
555    fn name(&self) -> &'static str {
556        let name = std::any::type_name::<Self>();
557        name.split("::").last().unwrap_or(name)
558    }
559}
560
561dyn_clone::clone_trait_object!(SimplePrompt);
562
// mockall test double for `SimplePrompt`, generating `MockSimplePrompt`.
#[cfg(feature = "test-utils")]
mock! {
    #[derive(Debug)]
    pub SimplePrompt {}

    #[async_trait]
    impl SimplePrompt for SimplePrompt {
        fn name(&self) -> &'static str;
        async fn prompt(&self, prompt: Prompt) -> Result<String, LanguageModelError>;
    }

    impl Clone for SimplePrompt {
        fn clone(&self) -> Self;
    }
}
578
579#[async_trait]
580impl SimplePrompt for Box<dyn SimplePrompt> {
581    async fn prompt(&self, prompt: Prompt) -> Result<String, LanguageModelError> {
582        self.as_ref().prompt(prompt).await
583    }
584
585    fn name(&self) -> &'static str {
586        self.as_ref().name()
587    }
588}
589
590#[async_trait]
591impl SimplePrompt for Arc<dyn SimplePrompt> {
592    async fn prompt(&self, prompt: Prompt) -> Result<String, LanguageModelError> {
593        self.as_ref().prompt(prompt).await
594    }
595
596    fn name(&self) -> &'static str {
597        self.as_ref().name()
598    }
599}
600
601#[async_trait]
602impl SimplePrompt for &dyn SimplePrompt {
603    async fn prompt(&self, prompt: Prompt) -> Result<String, LanguageModelError> {
604        (*self).prompt(prompt).await
605    }
606}
607
608#[async_trait]
609/// Persists nodes
610pub trait Persist: Debug + Send + Sync + DynClone {
611    async fn setup(&self) -> Result<()>;
612    async fn store(&self, node: Node) -> Result<Node>;
613    async fn batch_store(&self, nodes: Vec<Node>) -> IndexingStream;
614    fn batch_size(&self) -> Option<usize> {
615        None
616    }
617
618    fn name(&self) -> &'static str {
619        let name = std::any::type_name::<Self>();
620        name.split("::").last().unwrap_or(name)
621    }
622}
623
624dyn_clone::clone_trait_object!(Persist);
625
// mockall test double for `Persist`, generating `MockPersist`.
#[cfg(feature = "test-utils")]
mock! {
    #[derive(Debug)]
    pub Persist {}

    #[async_trait]
    impl Persist for Persist {
        fn name(&self) -> &'static str;
        fn batch_size(&self) -> Option<usize>;
        async fn batch_store(&self, nodes: Vec<Node>) -> IndexingStream;
        async fn store(&self, node: Node) -> Result<Node>;
        async fn setup(&self) -> Result<()>;
    }

    impl Clone for Persist {
        fn clone(&self) -> Self;
    }
}
645
646#[async_trait]
647impl Persist for Box<dyn Persist> {
648    async fn setup(&self) -> Result<()> {
649        self.as_ref().setup().await
650    }
651    async fn store(&self, node: Node) -> Result<Node> {
652        self.as_ref().store(node).await
653    }
654    async fn batch_store(&self, nodes: Vec<Node>) -> IndexingStream {
655        self.as_ref().batch_store(nodes).await
656    }
657    fn batch_size(&self) -> Option<usize> {
658        self.as_ref().batch_size()
659    }
660    fn name(&self) -> &'static str {
661        self.as_ref().name()
662    }
663}
664
665#[async_trait]
666impl Persist for Arc<dyn Persist> {
667    async fn setup(&self) -> Result<()> {
668        self.as_ref().setup().await
669    }
670    async fn store(&self, node: Node) -> Result<Node> {
671        self.as_ref().store(node).await
672    }
673    async fn batch_store(&self, nodes: Vec<Node>) -> IndexingStream {
674        self.as_ref().batch_store(nodes).await
675    }
676    fn batch_size(&self) -> Option<usize> {
677        self.as_ref().batch_size()
678    }
679    fn name(&self) -> &'static str {
680        self.as_ref().name()
681    }
682}
683
684#[async_trait]
685impl Persist for &dyn Persist {
686    async fn setup(&self) -> Result<()> {
687        (*self).setup().await
688    }
689    async fn store(&self, node: Node) -> Result<Node> {
690        (*self).store(node).await
691    }
692    async fn batch_store(&self, nodes: Vec<Node>) -> IndexingStream {
693        (*self).batch_store(nodes).await
694    }
695    fn batch_size(&self) -> Option<usize> {
696        (*self).batch_size()
697    }
698}
699
700/// Allows for passing defaults from the pipeline to the transformer
701/// Required for batch transformers as at least a marker, implementation is not required
702pub trait WithIndexingDefaults {
703    fn with_indexing_defaults(&mut self, _indexing_defaults: IndexingDefaults) {}
704}
705
706/// Allows for passing defaults from the pipeline to the batch transformer
707/// Required for batch transformers as at least a marker, implementation is not required
708pub trait WithBatchIndexingDefaults {
709    fn with_indexing_defaults(&mut self, _indexing_defaults: IndexingDefaults) {}
710}
711
712impl WithIndexingDefaults for dyn Transformer {}
713impl WithIndexingDefaults for Box<dyn Transformer> {
714    fn with_indexing_defaults(&mut self, indexing_defaults: IndexingDefaults) {
715        self.as_mut().with_indexing_defaults(indexing_defaults);
716    }
717}
718impl WithBatchIndexingDefaults for dyn BatchableTransformer {}
719impl WithBatchIndexingDefaults for Box<dyn BatchableTransformer> {
720    fn with_indexing_defaults(&mut self, indexing_defaults: IndexingDefaults) {
721        self.as_mut().with_indexing_defaults(indexing_defaults);
722    }
723}
724
725impl<F> WithIndexingDefaults for F where F: Fn(Node) -> Result<Node> {}
726impl<F> WithBatchIndexingDefaults for F where F: Fn(Vec<Node>) -> IndexingStream {}
727
728#[cfg(feature = "test-utils")]
729impl WithIndexingDefaults for MockTransformer {}
730//
731#[cfg(feature = "test-utils")]
732impl WithBatchIndexingDefaults for MockBatchableTransformer {}