swiftide_core/
indexing_traits.rs

//! Traits in Swiftide allow for easy extensibility
//!
//! All steps defined in the indexing pipeline and the generic transformers accept any
//! implementation of the corresponding trait. To bring your own transformers, models and
//! loaders, all you need to do is implement the trait and it should work out of the box.
6use crate::node::Node;
7use crate::Embeddings;
8use crate::{
9    indexing_defaults::IndexingDefaults, indexing_stream::IndexingStream, SparseEmbeddings,
10};
11use std::fmt::Debug;
12use std::sync::Arc;
13
14use crate::prompt::Prompt;
15use anyhow::Result;
16use async_trait::async_trait;
17
18pub use dyn_clone::DynClone;
19/// All traits are easily mockable under tests
20#[cfg(feature = "test-utils")]
21#[doc(hidden)]
22use mockall::{mock, predicate::str};
23
24#[async_trait]
25/// Transforms single nodes into single nodes
26pub trait Transformer: Send + Sync + DynClone {
27    async fn transform_node(&self, node: Node) -> Result<Node>;
28
29    /// Overrides the default concurrency of the pipeline
30    fn concurrency(&self) -> Option<usize> {
31        None
32    }
33
34    fn name(&self) -> &'static str {
35        let name = std::any::type_name::<Self>();
36        name.split("::").last().unwrap_or(name)
37    }
38}
39
40dyn_clone::clone_trait_object!(Transformer);
41
// Mockall definition producing `MockTransformer`; only compiled with the `test-utils` feature.
#[cfg(feature = "test-utils")]
mock! {
    #[derive(Debug)]
    pub Transformer {}

    #[async_trait]
    impl Transformer for Transformer {
        fn name(&self) -> &'static str;
        fn concurrency(&self) -> Option<usize>;
        async fn transform_node(&self, node: Node) -> Result<Node>;
    }

    // `Clone` is mocked too, which lets the mock satisfy the trait's `DynClone` bound.
    impl Clone for Transformer {
        fn clone(&self) -> Self;
    }
}
58
59#[async_trait]
60impl Transformer for Box<dyn Transformer> {
61    async fn transform_node(&self, node: Node) -> Result<Node> {
62        self.as_ref().transform_node(node).await
63    }
64    fn concurrency(&self) -> Option<usize> {
65        self.as_ref().concurrency()
66    }
67    fn name(&self) -> &'static str {
68        self.as_ref().name()
69    }
70}
71
72#[async_trait]
73impl Transformer for Arc<dyn Transformer> {
74    async fn transform_node(&self, node: Node) -> Result<Node> {
75        self.as_ref().transform_node(node).await
76    }
77    fn concurrency(&self) -> Option<usize> {
78        self.as_ref().concurrency()
79    }
80    fn name(&self) -> &'static str {
81        self.as_ref().name()
82    }
83}
84
85#[async_trait]
86impl Transformer for &dyn Transformer {
87    async fn transform_node(&self, node: Node) -> Result<Node> {
88        (*self).transform_node(node).await
89    }
90    fn concurrency(&self) -> Option<usize> {
91        (*self).concurrency()
92    }
93}
94
95#[async_trait]
96/// Use a closure as a transformer
97impl<F> Transformer for F
98where
99    F: Fn(Node) -> Result<Node> + Send + Sync + Clone,
100{
101    async fn transform_node(&self, node: Node) -> Result<Node> {
102        self(node)
103    }
104}
105
106#[async_trait]
107/// Transforms batched single nodes into streams of nodes
108pub trait BatchableTransformer: Send + Sync + DynClone {
109    /// Transforms a batch of nodes into a stream of nodes
110    async fn batch_transform(&self, nodes: Vec<Node>) -> IndexingStream;
111
112    /// Overrides the default concurrency of the pipeline
113    fn concurrency(&self) -> Option<usize> {
114        None
115    }
116
117    fn name(&self) -> &'static str {
118        let name = std::any::type_name::<Self>();
119        name.split("::").last().unwrap_or(name)
120    }
121
122    /// Overrides the default batch size of the pipeline
123    fn batch_size(&self) -> Option<usize> {
124        None
125    }
126}
127
128dyn_clone::clone_trait_object!(BatchableTransformer);
129
// Mockall definition producing `MockBatchableTransformer`; only compiled with the
// `test-utils` feature.
#[cfg(feature = "test-utils")]
mock! {
    #[derive(Debug)]
    pub BatchableTransformer {}

    #[async_trait]
    impl BatchableTransformer for BatchableTransformer {
        fn concurrency(&self) -> Option<usize>;
        fn batch_size(&self) -> Option<usize>;
        fn name(&self) -> &'static str;
        async fn batch_transform(&self, nodes: Vec<Node>) -> IndexingStream;
    }

    // `Clone` is mocked too, which lets the mock satisfy the trait's `DynClone` bound.
    impl Clone for BatchableTransformer {
        fn clone(&self) -> Self;
    }
}
147#[async_trait]
148/// Use a closure as a batchable transformer
149impl<F> BatchableTransformer for F
150where
151    F: Fn(Vec<Node>) -> IndexingStream + Send + Sync + Clone,
152{
153    async fn batch_transform(&self, nodes: Vec<Node>) -> IndexingStream {
154        self(nodes)
155    }
156}
157
158#[async_trait]
159impl BatchableTransformer for Box<dyn BatchableTransformer> {
160    async fn batch_transform(&self, nodes: Vec<Node>) -> IndexingStream {
161        self.as_ref().batch_transform(nodes).await
162    }
163    fn concurrency(&self) -> Option<usize> {
164        self.as_ref().concurrency()
165    }
166    fn name(&self) -> &'static str {
167        self.as_ref().name()
168    }
169}
170
171#[async_trait]
172impl BatchableTransformer for Arc<dyn BatchableTransformer> {
173    async fn batch_transform(&self, nodes: Vec<Node>) -> IndexingStream {
174        self.as_ref().batch_transform(nodes).await
175    }
176    fn concurrency(&self) -> Option<usize> {
177        self.as_ref().concurrency()
178    }
179    fn name(&self) -> &'static str {
180        self.as_ref().name()
181    }
182}
183
184#[async_trait]
185impl BatchableTransformer for &dyn BatchableTransformer {
186    async fn batch_transform(&self, nodes: Vec<Node>) -> IndexingStream {
187        (*self).batch_transform(nodes).await
188    }
189    fn concurrency(&self) -> Option<usize> {
190        (*self).concurrency()
191    }
192}
193
194/// Starting point of a stream
195pub trait Loader: DynClone {
196    fn into_stream(self) -> IndexingStream;
197
198    /// Intended for use with Box<dyn Loader>
199    ///
200    /// Only needed if you use trait objects (Box<dyn Loader>)
201    ///
202    /// # Example
203    ///
204    /// ```ignore
205    /// fn into_stream_boxed(self: Box<Self>) -> IndexingStream {
206    ///    self.into_stream()
207    ///  }
208    /// ```
209    fn into_stream_boxed(self: Box<Self>) -> IndexingStream {
210        unimplemented!("Please implement into_stream_boxed for your loader, it needs to be implemented on the concrete type")
211    }
212
213    fn name(&self) -> &'static str {
214        let name = std::any::type_name::<Self>();
215        name.split("::").last().unwrap_or(name)
216    }
217}
218
219dyn_clone::clone_trait_object!(Loader);
220
// Mockall definition producing `MockLoader`; only compiled with the `test-utils` feature.
#[cfg(feature = "test-utils")]
mock! {
    #[derive(Debug)]
    pub Loader {}

    #[async_trait]
    impl Loader for Loader {
        fn name(&self) -> &'static str;
        fn into_stream_boxed(self: Box<Self>) -> IndexingStream;
        fn into_stream(self) -> IndexingStream;
    }

    // `Clone` is mocked too, which lets the mock satisfy the trait's `DynClone` bound.
    impl Clone for Loader {
        fn clone(&self) -> Self;
    }
}
237
238impl Loader for Box<dyn Loader> {
239    fn into_stream(self) -> IndexingStream {
240        Loader::into_stream_boxed(self)
241    }
242
243    fn into_stream_boxed(self: Box<Self>) -> IndexingStream {
244        Loader::into_stream(*self)
245    }
246    fn name(&self) -> &'static str {
247        self.as_ref().name()
248    }
249}
250
251impl Loader for &dyn Loader {
252    fn into_stream(self) -> IndexingStream {
253        Loader::into_stream_boxed(Box::new(self))
254    }
255
256    fn into_stream_boxed(self: Box<Self>) -> IndexingStream {
257        Loader::into_stream(*self)
258    }
259}
260
261#[async_trait]
262/// Turns one node into many nodes
263pub trait ChunkerTransformer: Send + Sync + Debug + DynClone {
264    async fn transform_node(&self, node: Node) -> IndexingStream;
265
266    /// Overrides the default concurrency of the pipeline
267    fn concurrency(&self) -> Option<usize> {
268        None
269    }
270
271    fn name(&self) -> &'static str {
272        let name = std::any::type_name::<Self>();
273        name.split("::").last().unwrap_or(name)
274    }
275}
276
277dyn_clone::clone_trait_object!(ChunkerTransformer);
278
// Mockall definition producing `MockChunkerTransformer`; only compiled with the
// `test-utils` feature.
#[cfg(feature = "test-utils")]
mock! {
    #[derive(Debug)]
    pub ChunkerTransformer {}

    #[async_trait]
    impl ChunkerTransformer for ChunkerTransformer {
        fn concurrency(&self) -> Option<usize>;
        fn name(&self) -> &'static str;
        async fn transform_node(&self, node: Node) -> IndexingStream;
    }

    // `Clone` is mocked too, which lets the mock satisfy the trait's `DynClone` bound.
    impl Clone for ChunkerTransformer {
        fn clone(&self) -> Self;
    }
}
295#[async_trait]
296impl ChunkerTransformer for Box<dyn ChunkerTransformer> {
297    async fn transform_node(&self, node: Node) -> IndexingStream {
298        self.as_ref().transform_node(node).await
299    }
300    fn concurrency(&self) -> Option<usize> {
301        self.as_ref().concurrency()
302    }
303    fn name(&self) -> &'static str {
304        self.as_ref().name()
305    }
306}
307
308#[async_trait]
309impl ChunkerTransformer for Arc<dyn ChunkerTransformer> {
310    async fn transform_node(&self, node: Node) -> IndexingStream {
311        self.as_ref().transform_node(node).await
312    }
313    fn concurrency(&self) -> Option<usize> {
314        self.as_ref().concurrency()
315    }
316    fn name(&self) -> &'static str {
317        self.as_ref().name()
318    }
319}
320
321#[async_trait]
322impl ChunkerTransformer for &dyn ChunkerTransformer {
323    async fn transform_node(&self, node: Node) -> IndexingStream {
324        (*self).transform_node(node).await
325    }
326    fn concurrency(&self) -> Option<usize> {
327        (*self).concurrency()
328    }
329}
330
331// #[cfg_attr(feature = "test-utils", automock)]
332#[async_trait]
333/// Caches nodes, typically by their path and hash
334/// Recommended to namespace on the storage
335///
336/// For now just bool return value for easy filter
337pub trait NodeCache: Send + Sync + Debug + DynClone {
338    async fn get(&self, node: &Node) -> bool;
339    async fn set(&self, node: &Node);
340
341    /// Optionally provide a method to clear the cache
342    async fn clear(&self) -> Result<()> {
343        unimplemented!("Clear not implemented")
344    }
345
346    fn name(&self) -> &'static str {
347        let name = std::any::type_name::<Self>();
348        name.split("::").last().unwrap_or(name)
349    }
350}
351
352dyn_clone::clone_trait_object!(NodeCache);
353
// Mockall definition producing `MockNodeCache`; only compiled with the `test-utils` feature.
#[cfg(feature = "test-utils")]
mock! {
    #[derive(Debug)]
    pub NodeCache {}

    #[async_trait]
    impl NodeCache for NodeCache {
        fn name(&self) -> &'static str;
        async fn clear(&self) -> Result<()>;
        async fn set(&self, node: &Node);
        async fn get(&self, node: &Node) -> bool;
    }

    // `Clone` is mocked too, which lets the mock satisfy the trait's `DynClone` bound.
    impl Clone for NodeCache {
        fn clone(&self) -> Self;
    }
}
372
373#[async_trait]
374impl NodeCache for Box<dyn NodeCache> {
375    async fn get(&self, node: &Node) -> bool {
376        self.as_ref().get(node).await
377    }
378    async fn set(&self, node: &Node) {
379        self.as_ref().set(node).await;
380    }
381    async fn clear(&self) -> Result<()> {
382        self.as_ref().clear().await
383    }
384    fn name(&self) -> &'static str {
385        self.as_ref().name()
386    }
387}
388
389#[async_trait]
390impl NodeCache for Arc<dyn NodeCache> {
391    async fn get(&self, node: &Node) -> bool {
392        self.as_ref().get(node).await
393    }
394    async fn set(&self, node: &Node) {
395        self.as_ref().set(node).await;
396    }
397    async fn clear(&self) -> Result<()> {
398        self.as_ref().clear().await
399    }
400    fn name(&self) -> &'static str {
401        self.as_ref().name()
402    }
403}
404
405#[async_trait]
406impl NodeCache for &dyn NodeCache {
407    async fn get(&self, node: &Node) -> bool {
408        (*self).get(node).await
409    }
410    async fn set(&self, node: &Node) {
411        (*self).set(node).await;
412    }
413    async fn clear(&self) -> Result<()> {
414        (*self).clear().await
415    }
416}
417
418#[async_trait]
419/// Embeds a list of strings and returns its embeddings.
420/// Assumes the strings will be moved.
421pub trait EmbeddingModel: Send + Sync + Debug + DynClone {
422    async fn embed(&self, input: Vec<String>) -> Result<Embeddings>;
423
424    fn name(&self) -> &'static str {
425        let name = std::any::type_name::<Self>();
426        name.split("::").last().unwrap_or(name)
427    }
428}
429
430dyn_clone::clone_trait_object!(EmbeddingModel);
431
// Mockall definition producing `MockEmbeddingModel`; only compiled with the `test-utils`
// feature.
#[cfg(feature = "test-utils")]
mock! {
    #[derive(Debug)]
    pub EmbeddingModel {}

    #[async_trait]
    impl EmbeddingModel for EmbeddingModel {
        fn name(&self) -> &'static str;
        async fn embed(&self, input: Vec<String>) -> Result<Embeddings>;
    }

    // `Clone` is mocked too, which lets the mock satisfy the trait's `DynClone` bound.
    impl Clone for EmbeddingModel {
        fn clone(&self) -> Self;
    }
}
447
448#[async_trait]
449impl EmbeddingModel for Box<dyn EmbeddingModel> {
450    async fn embed(&self, input: Vec<String>) -> Result<Embeddings> {
451        self.as_ref().embed(input).await
452    }
453
454    fn name(&self) -> &'static str {
455        self.as_ref().name()
456    }
457}
458
459#[async_trait]
460impl EmbeddingModel for Arc<dyn EmbeddingModel> {
461    async fn embed(&self, input: Vec<String>) -> Result<Embeddings> {
462        self.as_ref().embed(input).await
463    }
464
465    fn name(&self) -> &'static str {
466        self.as_ref().name()
467    }
468}
469
470#[async_trait]
471impl EmbeddingModel for &dyn EmbeddingModel {
472    async fn embed(&self, input: Vec<String>) -> Result<Embeddings> {
473        (*self).embed(input).await
474    }
475}
476
477#[async_trait]
478/// Embeds a list of strings and returns its embeddings.
479/// Assumes the strings will be moved.
480pub trait SparseEmbeddingModel: Send + Sync + Debug + DynClone {
481    async fn sparse_embed(&self, input: Vec<String>) -> Result<SparseEmbeddings>;
482
483    fn name(&self) -> &'static str {
484        let name = std::any::type_name::<Self>();
485        name.split("::").last().unwrap_or(name)
486    }
487}
488
489dyn_clone::clone_trait_object!(SparseEmbeddingModel);
490
// Mockall definition producing `MockSparseEmbeddingModel`; only compiled with the
// `test-utils` feature.
#[cfg(feature = "test-utils")]
mock! {
    #[derive(Debug)]
    pub SparseEmbeddingModel {}

    #[async_trait]
    impl SparseEmbeddingModel for SparseEmbeddingModel {
        fn name(&self) -> &'static str;
        async fn sparse_embed(&self, input: Vec<String>) -> Result<SparseEmbeddings>;
    }

    // `Clone` is mocked too, which lets the mock satisfy the trait's `DynClone` bound.
    impl Clone for SparseEmbeddingModel {
        fn clone(&self) -> Self;
    }
}
506
507#[async_trait]
508impl SparseEmbeddingModel for Box<dyn SparseEmbeddingModel> {
509    async fn sparse_embed(&self, input: Vec<String>) -> Result<SparseEmbeddings> {
510        self.as_ref().sparse_embed(input).await
511    }
512
513    fn name(&self) -> &'static str {
514        self.as_ref().name()
515    }
516}
517
518#[async_trait]
519impl SparseEmbeddingModel for Arc<dyn SparseEmbeddingModel> {
520    async fn sparse_embed(&self, input: Vec<String>) -> Result<SparseEmbeddings> {
521        self.as_ref().sparse_embed(input).await
522    }
523
524    fn name(&self) -> &'static str {
525        self.as_ref().name()
526    }
527}
528
529#[async_trait]
530impl SparseEmbeddingModel for &dyn SparseEmbeddingModel {
531    async fn sparse_embed(&self, input: Vec<String>) -> Result<SparseEmbeddings> {
532        (*self).sparse_embed(input).await
533    }
534}
535
536#[async_trait]
537/// Given a string prompt, queries an LLM
538pub trait SimplePrompt: Debug + Send + Sync + DynClone {
539    // Takes a simple prompt, prompts the llm and returns the response
540    async fn prompt(&self, prompt: Prompt) -> Result<String>;
541
542    fn name(&self) -> &'static str {
543        let name = std::any::type_name::<Self>();
544        name.split("::").last().unwrap_or(name)
545    }
546}
547
548dyn_clone::clone_trait_object!(SimplePrompt);
549
// Mockall definition producing `MockSimplePrompt`; only compiled with the `test-utils`
// feature.
#[cfg(feature = "test-utils")]
mock! {
    #[derive(Debug)]
    pub SimplePrompt {}

    #[async_trait]
    impl SimplePrompt for SimplePrompt {
        fn name(&self) -> &'static str;
        async fn prompt(&self, prompt: Prompt) -> Result<String>;
    }

    // `Clone` is mocked too, which lets the mock satisfy the trait's `DynClone` bound.
    impl Clone for SimplePrompt {
        fn clone(&self) -> Self;
    }
}
565
566#[async_trait]
567impl SimplePrompt for Box<dyn SimplePrompt> {
568    async fn prompt(&self, prompt: Prompt) -> Result<String> {
569        self.as_ref().prompt(prompt).await
570    }
571
572    fn name(&self) -> &'static str {
573        self.as_ref().name()
574    }
575}
576
577#[async_trait]
578impl SimplePrompt for Arc<dyn SimplePrompt> {
579    async fn prompt(&self, prompt: Prompt) -> Result<String> {
580        self.as_ref().prompt(prompt).await
581    }
582
583    fn name(&self) -> &'static str {
584        self.as_ref().name()
585    }
586}
587
588#[async_trait]
589impl SimplePrompt for &dyn SimplePrompt {
590    async fn prompt(&self, prompt: Prompt) -> Result<String> {
591        (*self).prompt(prompt).await
592    }
593}
594
595#[async_trait]
596/// Persists nodes
597pub trait Persist: Debug + Send + Sync + DynClone {
598    async fn setup(&self) -> Result<()>;
599    async fn store(&self, node: Node) -> Result<Node>;
600    async fn batch_store(&self, nodes: Vec<Node>) -> IndexingStream;
601    fn batch_size(&self) -> Option<usize> {
602        None
603    }
604
605    fn name(&self) -> &'static str {
606        let name = std::any::type_name::<Self>();
607        name.split("::").last().unwrap_or(name)
608    }
609}
610
611dyn_clone::clone_trait_object!(Persist);
612
// Mockall definition producing `MockPersist`; only compiled with the `test-utils` feature.
#[cfg(feature = "test-utils")]
mock! {
    #[derive(Debug)]
    pub Persist {}

    #[async_trait]
    impl Persist for Persist {
        fn name(&self) -> &'static str;
        fn batch_size(&self) -> Option<usize>;

        async fn batch_store(&self, nodes: Vec<Node>) -> IndexingStream;
        async fn store(&self, node: Node) -> Result<Node>;
        async fn setup(&self) -> Result<()>;
    }

    // `Clone` is mocked too, which lets the mock satisfy the trait's `DynClone` bound.
    impl Clone for Persist {
        fn clone(&self) -> Self;
    }
}
632
633#[async_trait]
634impl Persist for Box<dyn Persist> {
635    async fn setup(&self) -> Result<()> {
636        self.as_ref().setup().await
637    }
638    async fn store(&self, node: Node) -> Result<Node> {
639        self.as_ref().store(node).await
640    }
641    async fn batch_store(&self, nodes: Vec<Node>) -> IndexingStream {
642        self.as_ref().batch_store(nodes).await
643    }
644    fn batch_size(&self) -> Option<usize> {
645        self.as_ref().batch_size()
646    }
647    fn name(&self) -> &'static str {
648        self.as_ref().name()
649    }
650}
651
652#[async_trait]
653impl Persist for Arc<dyn Persist> {
654    async fn setup(&self) -> Result<()> {
655        self.as_ref().setup().await
656    }
657    async fn store(&self, node: Node) -> Result<Node> {
658        self.as_ref().store(node).await
659    }
660    async fn batch_store(&self, nodes: Vec<Node>) -> IndexingStream {
661        self.as_ref().batch_store(nodes).await
662    }
663    fn batch_size(&self) -> Option<usize> {
664        self.as_ref().batch_size()
665    }
666    fn name(&self) -> &'static str {
667        self.as_ref().name()
668    }
669}
670
671#[async_trait]
672impl Persist for &dyn Persist {
673    async fn setup(&self) -> Result<()> {
674        (*self).setup().await
675    }
676    async fn store(&self, node: Node) -> Result<Node> {
677        (*self).store(node).await
678    }
679    async fn batch_store(&self, nodes: Vec<Node>) -> IndexingStream {
680        (*self).batch_store(nodes).await
681    }
682    fn batch_size(&self) -> Option<usize> {
683        (*self).batch_size()
684    }
685}
686
687/// Allows for passing defaults from the pipeline to the transformer
688/// Required for batch transformers as at least a marker, implementation is not required
689pub trait WithIndexingDefaults {
690    fn with_indexing_defaults(&mut self, _indexing_defaults: IndexingDefaults) {}
691}
692
693/// Allows for passing defaults from the pipeline to the batch transformer
694/// Required for batch transformers as at least a marker, implementation is not required
695pub trait WithBatchIndexingDefaults {
696    fn with_indexing_defaults(&mut self, _indexing_defaults: IndexingDefaults) {}
697}
698
699impl WithIndexingDefaults for dyn Transformer {}
700impl WithIndexingDefaults for Box<dyn Transformer> {
701    fn with_indexing_defaults(&mut self, indexing_defaults: IndexingDefaults) {
702        self.as_mut().with_indexing_defaults(indexing_defaults);
703    }
704}
705impl WithBatchIndexingDefaults for dyn BatchableTransformer {}
706impl WithBatchIndexingDefaults for Box<dyn BatchableTransformer> {
707    fn with_indexing_defaults(&mut self, indexing_defaults: IndexingDefaults) {
708        self.as_mut().with_indexing_defaults(indexing_defaults);
709    }
710}
711
712impl<F> WithIndexingDefaults for F where F: Fn(Node) -> Result<Node> {}
713impl<F> WithBatchIndexingDefaults for F where F: Fn(Vec<Node>) -> IndexingStream {}
714
// Marker impls for the generated mocks; only compiled with the `test-utils` feature.
#[cfg(feature = "test-utils")]
impl WithIndexingDefaults for MockTransformer {}

#[cfg(feature = "test-utils")]
impl WithBatchIndexingDefaults for MockBatchableTransformer {}