swiftide_core/
indexing_traits.rs

//! Traits in Swiftide allow for easy extensibility
//!
//! All steps defined in the indexing pipeline and the generic transformers accept any
//! implementation of the corresponding trait. To bring your own transformers, models and
//! loaders, all you need to do is implement the trait and it should work out of the box.
6use crate::Embeddings;
7use crate::node::{Chunk, Node};
8use crate::{
9    SparseEmbeddings, indexing_defaults::IndexingDefaults, indexing_stream::IndexingStream,
10};
11use std::fmt::Debug;
12use std::sync::Arc;
13
14use crate::chat_completion::errors::LanguageModelError;
15use crate::prompt::Prompt;
16use anyhow::Result;
17use async_trait::async_trait;
18
19pub use dyn_clone::DynClone;
20/// All traits are easily mockable under tests
21#[cfg(feature = "test-utils")]
22#[doc(hidden)]
23use mockall::{mock, predicate::str};
24use schemars::{JsonSchema, schema_for};
25use serde::de::DeserializeOwned;
26
27#[async_trait]
28/// Transforms single nodes into single nodes
29pub trait Transformer: Send + Sync + DynClone {
30    type Input: Chunk;
31    type Output: Chunk;
32
33    async fn transform_node(&self, node: Node<Self::Input>) -> Result<Node<Self::Output>>;
34
35    /// Overrides the default concurrency of the pipeline
36    fn concurrency(&self) -> Option<usize> {
37        None
38    }
39
40    fn name(&self) -> &'static str {
41        let name = std::any::type_name::<Self>();
42        name.split("::").last().unwrap_or(name)
43    }
44}
45
46dyn_clone::clone_trait_object!(<I, O> Transformer<Input = I, Output = O>);
47
// Mock of `Transformer` over `String` chunks, only compiled with the `test-utils` feature.
#[cfg(feature = "test-utils")]
mock! {
    #[derive(Debug)]
    pub Transformer {}

    #[async_trait]
    impl Transformer for Transformer {
        type Input = String;
        type Output = String;

        fn name(&self) -> &'static str;
        fn concurrency(&self) -> Option<usize>;
        async fn transform_node(&self, node: Node<String>) -> Result<Node<String>>;
    }

    // `Transformer` requires `DynClone`, so the mock has to be `Clone` as well.
    impl Clone for Transformer {
        fn clone(&self) -> Self;
    }
}
67
68#[async_trait]
69impl<I: Chunk, O: Chunk> Transformer for Box<dyn Transformer<Input = I, Output = O>> {
70    type Input = I;
71    type Output = O;
72
73    async fn transform_node(&self, node: Node<Self::Input>) -> Result<Node<Self::Output>> {
74        self.as_ref().transform_node(node).await
75    }
76    fn concurrency(&self) -> Option<usize> {
77        self.as_ref().concurrency()
78    }
79    fn name(&self) -> &'static str {
80        self.as_ref().name()
81    }
82}
83
84#[async_trait]
85impl<I: Chunk, O: Chunk> Transformer for Arc<dyn Transformer<Input = I, Output = O>> {
86    type Input = I;
87    type Output = O;
88
89    async fn transform_node(&self, node: Node<Self::Input>) -> Result<Node<Self::Output>> {
90        self.as_ref().transform_node(node).await
91    }
92    fn concurrency(&self) -> Option<usize> {
93        self.as_ref().concurrency()
94    }
95    fn name(&self) -> &'static str {
96        self.as_ref().name()
97    }
98}
99
100#[async_trait]
101impl<I: Chunk, O: Chunk> Transformer for &dyn Transformer<Input = I, Output = O> {
102    type Input = I;
103    type Output = O;
104
105    async fn transform_node(&self, node: Node<Self::Input>) -> Result<Node<Self::Output>> {
106        (*self).transform_node(node).await
107    }
108    fn concurrency(&self) -> Option<usize> {
109        (*self).concurrency()
110    }
111}
112
113#[async_trait]
114/// Use a closure as a transformer
115// TODO: Find a way to make this work with full generics
116impl<F> Transformer for F
117where
118    F: Fn(Node<String>) -> Result<Node<String>> + Send + Sync + Clone,
119{
120    type Input = String;
121    type Output = String;
122
123    async fn transform_node(&self, node: Node<Self::Input>) -> Result<Node<Self::Output>> {
124        self(node)
125    }
126}
127
128#[async_trait]
129/// Transforms batched single nodes into streams of nodes
130pub trait BatchableTransformer: Send + Sync + DynClone {
131    type Input: Chunk;
132    type Output: Chunk;
133
134    /// Transforms a batch of nodes into a stream of nodes
135    async fn batch_transform(&self, nodes: Vec<Node<Self::Input>>) -> IndexingStream<Self::Output>;
136
137    /// Overrides the default concurrency of the pipeline
138    fn concurrency(&self) -> Option<usize> {
139        None
140    }
141
142    fn name(&self) -> &'static str {
143        let name = std::any::type_name::<Self>();
144        name.split("::").last().unwrap_or(name)
145    }
146
147    /// Overrides the default batch size of the pipeline
148    fn batch_size(&self) -> Option<usize> {
149        None
150    }
151}
152
153dyn_clone::clone_trait_object!(<I, O> BatchableTransformer<Input = I, Output = O>);
154
// Mock of `BatchableTransformer` over `String` chunks, only compiled with the
// `test-utils` feature.
#[cfg(feature = "test-utils")]
mock! {
    #[derive(Debug)]
    pub BatchableTransformer {}

    #[async_trait]
    impl BatchableTransformer for BatchableTransformer {
        type Input = String;
        type Output = String;

        fn concurrency(&self) -> Option<usize>;
        fn batch_size(&self) -> Option<usize>;
        fn name(&self) -> &'static str;
        async fn batch_transform(&self, nodes: Vec<Node<String>>) -> IndexingStream<String>;
    }

    // `BatchableTransformer` requires `DynClone`, so the mock has to be `Clone` as well.
    impl Clone for BatchableTransformer {
        fn clone(&self) -> Self;
    }
}
175#[async_trait]
176/// Use a closure as a batchable transformer
177impl<F> BatchableTransformer for F
178where
179    F: Fn(Vec<Node<String>>) -> IndexingStream<String> + Send + Sync + Clone,
180{
181    type Input = String;
182    type Output = String;
183
184    async fn batch_transform(&self, nodes: Vec<Node<String>>) -> IndexingStream<String> {
185        self(nodes)
186    }
187}
188
189#[async_trait]
190impl<I: Chunk, O: Chunk> BatchableTransformer
191    for Box<dyn BatchableTransformer<Input = I, Output = O>>
192{
193    type Input = I;
194    type Output = O;
195
196    async fn batch_transform(&self, nodes: Vec<Node<Self::Input>>) -> IndexingStream<Self::Output> {
197        self.as_ref().batch_transform(nodes).await
198    }
199    fn concurrency(&self) -> Option<usize> {
200        self.as_ref().concurrency()
201    }
202    fn name(&self) -> &'static str {
203        self.as_ref().name()
204    }
205}
206
207#[async_trait]
208impl<I: Chunk, O: Chunk> BatchableTransformer
209    for Arc<dyn BatchableTransformer<Input = I, Output = O>>
210{
211    type Input = I;
212    type Output = O;
213
214    async fn batch_transform(&self, nodes: Vec<Node<Self::Input>>) -> IndexingStream<Self::Output> {
215        self.as_ref().batch_transform(nodes).await
216    }
217    fn concurrency(&self) -> Option<usize> {
218        self.as_ref().concurrency()
219    }
220    fn name(&self) -> &'static str {
221        self.as_ref().name()
222    }
223}
224
225#[async_trait]
226impl<I: Chunk, O: Chunk> BatchableTransformer for &dyn BatchableTransformer<Input = I, Output = O> {
227    type Input = I;
228    type Output = O;
229
230    async fn batch_transform(&self, nodes: Vec<Node<Self::Input>>) -> IndexingStream<Self::Output> {
231        (*self).batch_transform(nodes).await
232    }
233    fn concurrency(&self) -> Option<usize> {
234        (*self).concurrency()
235    }
236}
237
238/// Starting point of a stream
239pub trait Loader: DynClone + Send + Sync {
240    type Output: Chunk;
241
242    fn into_stream(self) -> IndexingStream<Self::Output>;
243
244    /// Intended for use with Box<dyn Loader>
245    ///
246    /// Only needed if you use trait objects (Box<dyn Loader>)
247    ///
248    /// # Example
249    ///
250    /// ```ignore
251    /// fn into_stream_boxed(self: Box<Self>) -> IndexingStream {
252    ///    self.into_stream()
253    ///  }
254    /// ```
255    fn into_stream_boxed(self: Box<Self>) -> IndexingStream<Self::Output> {
256        unimplemented!(
257            "Please implement into_stream_boxed for your loader, it needs to be implemented on the concrete type"
258        )
259    }
260
261    fn name(&self) -> &'static str {
262        let name = std::any::type_name::<Self>();
263        name.split("::").last().unwrap_or(name)
264    }
265}
266
267dyn_clone::clone_trait_object!(<O> Loader<Output = O>);
268
// Mock of `Loader` producing `String` chunks, only compiled with the `test-utils` feature.
#[cfg(feature = "test-utils")]
mock! {
    #[derive(Debug)]
    pub Loader {}

    #[async_trait]
    impl Loader for Loader {
        type Output = String;

        fn name(&self) -> &'static str;
        fn into_stream(self) -> IndexingStream<String>;
        fn into_stream_boxed(self: Box<Self>) -> IndexingStream<String>;
    }

    // `Loader` requires `DynClone`, so the mock has to be `Clone` as well.
    impl Clone for Loader {
        fn clone(&self) -> Self;
    }
}
287
288impl<O: Chunk> Loader for Box<dyn Loader<Output = O>> {
289    type Output = O;
290
291    fn into_stream(self) -> IndexingStream<Self::Output> {
292        Loader::into_stream_boxed(self)
293    }
294
295    fn into_stream_boxed(self: Box<Self>) -> IndexingStream<Self::Output> {
296        Loader::into_stream(*self)
297    }
298    fn name(&self) -> &'static str {
299        self.as_ref().name()
300    }
301}
302
303impl<O: Chunk> Loader for &dyn Loader<Output = O> {
304    type Output = O;
305
306    fn into_stream(self) -> IndexingStream<Self::Output> {
307        Loader::into_stream_boxed(Box::new(self))
308    }
309
310    fn into_stream_boxed(self: Box<Self>) -> IndexingStream<Self::Output> {
311        Loader::into_stream(*self)
312    }
313}
314
315#[async_trait]
316/// Turns one node into many nodes
317pub trait ChunkerTransformer: Send + Sync + DynClone {
318    type Input: Chunk;
319    type Output: Chunk;
320
321    async fn transform_node(&self, node: Node<Self::Input>) -> IndexingStream<Self::Output>;
322
323    /// Overrides the default concurrency of the pipeline
324    fn concurrency(&self) -> Option<usize> {
325        None
326    }
327
328    fn name(&self) -> &'static str {
329        let name = std::any::type_name::<Self>();
330        name.split("::").last().unwrap_or(name)
331    }
332}
333
334dyn_clone::clone_trait_object!(<I, O> ChunkerTransformer<Input = I, Output = O>);
335
// Mock of `ChunkerTransformer` over `String` chunks, only compiled with the
// `test-utils` feature.
#[cfg(feature = "test-utils")]
mock! {
    #[derive(Debug)]
    pub ChunkerTransformer {}

    #[async_trait]
    impl ChunkerTransformer for ChunkerTransformer {
        type Input = String;
        type Output = String;

        fn concurrency(&self) -> Option<usize>;
        fn name(&self) -> &'static str;
        async fn transform_node(&self, node: Node<String>) -> IndexingStream<String>;
    }

    // `ChunkerTransformer` requires `DynClone`, so the mock has to be `Clone` as well.
    impl Clone for ChunkerTransformer {
        fn clone(&self) -> Self;
    }
}
355#[async_trait]
356impl<I: Chunk, O: Chunk> ChunkerTransformer for Box<dyn ChunkerTransformer<Input = I, Output = O>> {
357    type Input = I;
358    type Output = O;
359
360    async fn transform_node(&self, node: Node<I>) -> IndexingStream<O> {
361        self.as_ref().transform_node(node).await
362    }
363    fn concurrency(&self) -> Option<usize> {
364        self.as_ref().concurrency()
365    }
366    fn name(&self) -> &'static str {
367        self.as_ref().name()
368    }
369}
370
371#[async_trait]
372impl<I: Chunk, O: Chunk> ChunkerTransformer for Arc<dyn ChunkerTransformer<Input = I, Output = O>> {
373    type Input = I;
374    type Output = O;
375
376    async fn transform_node(&self, node: Node<I>) -> IndexingStream<O> {
377        self.as_ref().transform_node(node).await
378    }
379    fn concurrency(&self) -> Option<usize> {
380        self.as_ref().concurrency()
381    }
382    fn name(&self) -> &'static str {
383        self.as_ref().name()
384    }
385}
386
387#[async_trait]
388impl<I: Chunk, O: Chunk> ChunkerTransformer for &dyn ChunkerTransformer<Input = I, Output = O> {
389    type Input = I;
390    type Output = O;
391
392    async fn transform_node(&self, node: Node<I>) -> IndexingStream<O> {
393        (*self).transform_node(node).await
394    }
395    fn concurrency(&self) -> Option<usize> {
396        (*self).concurrency()
397    }
398}
399
400#[async_trait]
401impl<F> ChunkerTransformer for F
402where
403    F: Fn(Node<String>) -> IndexingStream<String> + Send + Sync + Clone,
404{
405    async fn transform_node(&self, node: Node<String>) -> IndexingStream<String> {
406        self(node)
407    }
408
409    type Input = String;
410
411    type Output = String;
412}
413
414#[async_trait]
415/// Caches nodes, typically by their path and hash
416/// Recommended to namespace on the storage
417///
418/// For now just bool return value for easy filter
419pub trait NodeCache: Send + Sync + Debug + DynClone {
420    type Input: Chunk;
421
422    async fn get(&self, node: &Node<Self::Input>) -> bool;
423    async fn set(&self, node: &Node<Self::Input>);
424
425    /// Optionally provide a method to clear the cache
426    async fn clear(&self) -> Result<()> {
427        unimplemented!("Clear not implemented")
428    }
429
430    fn name(&self) -> &'static str {
431        let name = std::any::type_name::<Self>();
432        name.split("::").last().unwrap_or(name)
433    }
434}
435
436dyn_clone::clone_trait_object!(<T> NodeCache<Input = T>);
437
// Mock of `NodeCache` over `String` chunks, only compiled with the `test-utils` feature.
#[cfg(feature = "test-utils")]
mock! {
    #[derive(Debug)]
    pub NodeCache {}

    #[async_trait]
    impl NodeCache for NodeCache {
        type Input = String;

        fn name(&self) -> &'static str;
        async fn clear(&self) -> Result<()>;
        async fn set(&self, node: &Node<String>);
        async fn get(&self, node: &Node<String>) -> bool;
    }

    // `NodeCache` requires `DynClone`, so the mock has to be `Clone` as well.
    impl Clone for NodeCache {
        fn clone(&self) -> Self;
    }
}
457
458#[async_trait]
459impl<T: Chunk> NodeCache for Box<dyn NodeCache<Input = T>> {
460    type Input = T;
461
462    async fn get(&self, node: &Node<T>) -> bool {
463        self.as_ref().get(node).await
464    }
465    async fn set(&self, node: &Node<T>) {
466        self.as_ref().set(node).await;
467    }
468    async fn clear(&self) -> Result<()> {
469        self.as_ref().clear().await
470    }
471    fn name(&self) -> &'static str {
472        self.as_ref().name()
473    }
474}
475
476#[async_trait]
477impl<T: Chunk> NodeCache for Arc<dyn NodeCache<Input = T>> {
478    type Input = T;
479    async fn get(&self, node: &Node<T>) -> bool {
480        self.as_ref().get(node).await
481    }
482    async fn set(&self, node: &Node<T>) {
483        self.as_ref().set(node).await;
484    }
485    async fn clear(&self) -> Result<()> {
486        self.as_ref().clear().await
487    }
488    fn name(&self) -> &'static str {
489        self.as_ref().name()
490    }
491}
492
493#[async_trait]
494impl<T: Chunk> NodeCache for &dyn NodeCache<Input = T> {
495    type Input = T;
496    async fn get(&self, node: &Node<T>) -> bool {
497        (*self).get(node).await
498    }
499    async fn set(&self, node: &Node<T>) {
500        (*self).set(node).await;
501    }
502    async fn clear(&self) -> Result<()> {
503        (*self).clear().await
504    }
505}
506
507#[async_trait]
508/// Embeds a list of strings and returns its embeddings.
509/// Assumes the strings will be moved.
510pub trait EmbeddingModel: Send + Sync + Debug + DynClone {
511    async fn embed(&self, input: Vec<String>) -> Result<Embeddings, LanguageModelError>;
512
513    fn name(&self) -> &'static str {
514        let name = std::any::type_name::<Self>();
515        name.split("::").last().unwrap_or(name)
516    }
517}
518
519dyn_clone::clone_trait_object!(EmbeddingModel);
520
// Mock of `EmbeddingModel`, only compiled with the `test-utils` feature.
#[cfg(feature = "test-utils")]
mock! {
    #[derive(Debug)]
    pub EmbeddingModel {}

    #[async_trait]
    impl EmbeddingModel for EmbeddingModel {
        fn name(&self) -> &'static str;
        async fn embed(&self, input: Vec<String>) -> Result<Embeddings, LanguageModelError>;
    }

    // `EmbeddingModel` requires `DynClone`, so the mock has to be `Clone` as well.
    impl Clone for EmbeddingModel {
        fn clone(&self) -> Self;
    }
}
536
537#[async_trait]
538impl EmbeddingModel for Box<dyn EmbeddingModel> {
539    async fn embed(&self, input: Vec<String>) -> Result<Embeddings, LanguageModelError> {
540        self.as_ref().embed(input).await
541    }
542
543    fn name(&self) -> &'static str {
544        self.as_ref().name()
545    }
546}
547
548#[async_trait]
549impl EmbeddingModel for Arc<dyn EmbeddingModel> {
550    async fn embed(&self, input: Vec<String>) -> Result<Embeddings, LanguageModelError> {
551        self.as_ref().embed(input).await
552    }
553
554    fn name(&self) -> &'static str {
555        self.as_ref().name()
556    }
557}
558
559#[async_trait]
560impl EmbeddingModel for &dyn EmbeddingModel {
561    async fn embed(&self, input: Vec<String>) -> Result<Embeddings, LanguageModelError> {
562        (*self).embed(input).await
563    }
564}
565
566#[async_trait]
567/// Embeds a list of strings and returns its embeddings.
568/// Assumes the strings will be moved.
569pub trait SparseEmbeddingModel: Send + Sync + Debug + DynClone {
570    async fn sparse_embed(
571        &self,
572        input: Vec<String>,
573    ) -> Result<SparseEmbeddings, LanguageModelError>;
574
575    fn name(&self) -> &'static str {
576        let name = std::any::type_name::<Self>();
577        name.split("::").last().unwrap_or(name)
578    }
579}
580
581dyn_clone::clone_trait_object!(SparseEmbeddingModel);
582
// Mock of `SparseEmbeddingModel`, only compiled with the `test-utils` feature.
#[cfg(feature = "test-utils")]
mock! {
    #[derive(Debug)]
    pub SparseEmbeddingModel {}

    #[async_trait]
    impl SparseEmbeddingModel for SparseEmbeddingModel {
        fn name(&self) -> &'static str;
        async fn sparse_embed(&self, input: Vec<String>) -> Result<SparseEmbeddings, LanguageModelError>;
    }

    // `SparseEmbeddingModel` requires `DynClone`, so the mock has to be `Clone` as well.
    impl Clone for SparseEmbeddingModel {
        fn clone(&self) -> Self;
    }
}
598
599#[async_trait]
600impl SparseEmbeddingModel for Box<dyn SparseEmbeddingModel> {
601    async fn sparse_embed(
602        &self,
603        input: Vec<String>,
604    ) -> Result<SparseEmbeddings, LanguageModelError> {
605        self.as_ref().sparse_embed(input).await
606    }
607
608    fn name(&self) -> &'static str {
609        self.as_ref().name()
610    }
611}
612
613#[async_trait]
614impl SparseEmbeddingModel for Arc<dyn SparseEmbeddingModel> {
615    async fn sparse_embed(
616        &self,
617        input: Vec<String>,
618    ) -> Result<SparseEmbeddings, LanguageModelError> {
619        self.as_ref().sparse_embed(input).await
620    }
621
622    fn name(&self) -> &'static str {
623        self.as_ref().name()
624    }
625}
626
627#[async_trait]
628impl SparseEmbeddingModel for &dyn SparseEmbeddingModel {
629    async fn sparse_embed(
630        &self,
631        input: Vec<String>,
632    ) -> Result<SparseEmbeddings, LanguageModelError> {
633        (*self).sparse_embed(input).await
634    }
635}
636
637#[async_trait]
638/// Given a string prompt, queries an LLM
639pub trait SimplePrompt: Debug + Send + Sync + DynClone {
640    // Takes a simple prompt, prompts the llm and returns the response
641    async fn prompt(&self, prompt: Prompt) -> Result<String, LanguageModelError>;
642
643    fn name(&self) -> &'static str {
644        let name = std::any::type_name::<Self>();
645        name.split("::").last().unwrap_or(name)
646    }
647}
648
649dyn_clone::clone_trait_object!(SimplePrompt);
650
// Mock of `SimplePrompt`, only compiled with the `test-utils` feature.
#[cfg(feature = "test-utils")]
mock! {
    #[derive(Debug)]
    pub SimplePrompt {}

    #[async_trait]
    impl SimplePrompt for SimplePrompt {
        fn name(&self) -> &'static str;
        async fn prompt(&self, prompt: Prompt) -> Result<String, LanguageModelError>;
    }

    // `SimplePrompt` requires `DynClone`, so the mock has to be `Clone` as well.
    impl Clone for SimplePrompt {
        fn clone(&self) -> Self;
    }
}
666
667#[async_trait]
668impl SimplePrompt for Box<dyn SimplePrompt> {
669    async fn prompt(&self, prompt: Prompt) -> Result<String, LanguageModelError> {
670        self.as_ref().prompt(prompt).await
671    }
672
673    fn name(&self) -> &'static str {
674        self.as_ref().name()
675    }
676}
677
678#[async_trait]
679impl SimplePrompt for Arc<dyn SimplePrompt> {
680    async fn prompt(&self, prompt: Prompt) -> Result<String, LanguageModelError> {
681        self.as_ref().prompt(prompt).await
682    }
683
684    fn name(&self) -> &'static str {
685        self.as_ref().name()
686    }
687}
688
689#[async_trait]
690impl SimplePrompt for &dyn SimplePrompt {
691    async fn prompt(&self, prompt: Prompt) -> Result<String, LanguageModelError> {
692        (*self).prompt(prompt).await
693    }
694}
695
696#[async_trait]
697/// Persists nodes
698pub trait Persist: Debug + Send + Sync + DynClone {
699    type Input: Chunk;
700    type Output: Chunk;
701
702    async fn setup(&self) -> Result<()>;
703    async fn store(&self, node: Node<Self::Input>) -> Result<Node<Self::Output>>;
704    async fn batch_store(&self, nodes: Vec<Node<Self::Input>>) -> IndexingStream<Self::Output>;
705    fn batch_size(&self) -> Option<usize> {
706        None
707    }
708
709    fn name(&self) -> &'static str {
710        let name = std::any::type_name::<Self>();
711        name.split("::").last().unwrap_or(name)
712    }
713}
714
715dyn_clone::clone_trait_object!(<I, O> Persist<Input = I, Output = O>);
716
// Mock of `Persist` over `String` chunks, only compiled with the `test-utils` feature.
#[cfg(feature = "test-utils")]
mock! {
    #[derive(Debug)]
    pub Persist {}

    #[async_trait]
    impl Persist for Persist {
        type Input = String;
        type Output = String;

        fn name(&self) -> &'static str;
        fn batch_size(&self) -> Option<usize>;

        async fn setup(&self) -> Result<()>;
        async fn store(&self, node: Node<String>) -> Result<Node<String>>;
        async fn batch_store(&self, nodes: Vec<Node<String>>) -> IndexingStream<String>;
    }

    // `Persist` requires `DynClone`, so the mock has to be `Clone` as well.
    impl Clone for Persist {
        fn clone(&self) -> Self;
    }
}
739
740#[async_trait]
741impl<I: Chunk, O: Chunk> Persist for Box<dyn Persist<Input = I, Output = O>> {
742    type Input = I;
743    type Output = O;
744
745    async fn setup(&self) -> Result<()> {
746        self.as_ref().setup().await
747    }
748    async fn store(&self, node: Node<I>) -> Result<Node<O>> {
749        self.as_ref().store(node).await
750    }
751    async fn batch_store(&self, nodes: Vec<Node<I>>) -> IndexingStream<O> {
752        self.as_ref().batch_store(nodes).await
753    }
754    fn batch_size(&self) -> Option<usize> {
755        self.as_ref().batch_size()
756    }
757    fn name(&self) -> &'static str {
758        self.as_ref().name()
759    }
760}
761
762#[async_trait]
763impl<I: Chunk, O: Chunk> Persist for Arc<dyn Persist<Input = I, Output = O>> {
764    type Input = I;
765    type Output = O;
766
767    async fn setup(&self) -> Result<()> {
768        self.as_ref().setup().await
769    }
770    async fn store(&self, node: Node<I>) -> Result<Node<O>> {
771        self.as_ref().store(node).await
772    }
773    async fn batch_store(&self, nodes: Vec<Node<I>>) -> IndexingStream<O> {
774        self.as_ref().batch_store(nodes).await
775    }
776    fn batch_size(&self) -> Option<usize> {
777        self.as_ref().batch_size()
778    }
779    fn name(&self) -> &'static str {
780        self.as_ref().name()
781    }
782}
783
784#[async_trait]
785impl<I: Chunk, O: Chunk> Persist for &dyn Persist<Input = I, Output = O> {
786    type Input = I;
787    type Output = O;
788
789    async fn setup(&self) -> Result<()> {
790        (*self).setup().await
791    }
792    async fn store(&self, node: Node<I>) -> Result<Node<O>> {
793        (*self).store(node).await
794    }
795    async fn batch_store(&self, nodes: Vec<Node<I>>) -> IndexingStream<O> {
796        (*self).batch_store(nodes).await
797    }
798    fn batch_size(&self) -> Option<usize> {
799        (*self).batch_size()
800    }
801}
802
803/// Allows for passing defaults from the pipeline to the transformer
804/// Required for batch transformers as at least a marker, implementation is not required
805pub trait WithIndexingDefaults {
806    fn with_indexing_defaults(&mut self, _indexing_defaults: IndexingDefaults) {}
807}
808
809/// Allows for passing defaults from the pipeline to the batch transformer
810/// Required for batch transformers as at least a marker, implementation is not required
811pub trait WithBatchIndexingDefaults {
812    fn with_indexing_defaults(&mut self, _indexing_defaults: IndexingDefaults) {}
813}
814
815impl<I, O> WithIndexingDefaults for dyn Transformer<Input = I, Output = O> {}
816impl<I, O> WithIndexingDefaults for Box<dyn Transformer<Input = I, Output = O>> {
817    fn with_indexing_defaults(&mut self, indexing_defaults: IndexingDefaults) {
818        self.as_mut().with_indexing_defaults(indexing_defaults);
819    }
820}
821impl<I, O> WithBatchIndexingDefaults for dyn BatchableTransformer<Input = I, Output = O> {}
822impl<I, O> WithBatchIndexingDefaults for Box<dyn BatchableTransformer<Input = I, Output = O>> {
823    fn with_indexing_defaults(&mut self, indexing_defaults: IndexingDefaults) {
824        self.as_mut().with_indexing_defaults(indexing_defaults);
825    }
826}
827
828impl<F> WithIndexingDefaults for F where F: Fn(Node<String>) -> Result<Node<String>> {}
829impl<F> WithBatchIndexingDefaults for F where F: Fn(Vec<Node<String>>) -> IndexingStream<String> {}
830
// The mocks get the no-op defaults as well (only with the `test-utils` feature).
#[cfg(feature = "test-utils")]
impl WithIndexingDefaults for MockTransformer {}

#[cfg(feature = "test-utils")]
impl WithBatchIndexingDefaults for MockBatchableTransformer {}
836
837#[async_trait]
838/// Given a string prompt, queries an LLM to return structured data
839pub trait StructuredPrompt: Debug + Send + Sync + DynClone {
840    async fn structured_prompt<T: serde::Serialize + DeserializeOwned + JsonSchema>(
841        &self,
842        prompt: Prompt,
843    ) -> Result<T, LanguageModelError>;
844
845    fn name(&self) -> &'static str {
846        let name = std::any::type_name::<Self>();
847        name.split("::").last().unwrap_or(name)
848    }
849}
850
851/// Helper trait object to call structured prompt with dynamic dispatch
852///
853/// Internally Swiftide only implements this trait, as implementing `DynStructuredPrompt` gives
854/// `StructuredPrompt` for free
855#[async_trait]
856pub trait DynStructuredPrompt: Debug + Send + Sync + DynClone {
857    async fn structured_prompt_dyn(
858        &self,
859        prompt: Prompt,
860        schema: schemars::Schema,
861    ) -> Result<serde_json::Value, LanguageModelError>;
862
863    fn name(&self) -> &'static str {
864        let name = std::any::type_name::<Self>();
865        name.split("::").last().unwrap_or(name)
866    }
867}
868
869dyn_clone::clone_trait_object!(DynStructuredPrompt);
870
871#[async_trait]
872impl<C> StructuredPrompt for C
873where
874    C: DynStructuredPrompt + Debug + Send + Sync + DynClone,
875{
876    async fn structured_prompt<T: serde::Serialize + DeserializeOwned + JsonSchema>(
877        &self,
878        prompt: Prompt,
879    ) -> Result<T, LanguageModelError> {
880        // Call with T = serde_json::Value
881        let schema = schema_for!(T);
882        let val = self.structured_prompt_dyn(prompt, schema).await?;
883
884        let parsed = serde_json::from_value(val).map_err(LanguageModelError::permanent)?;
885
886        Ok(parsed)
887    }
888}