Skip to main content

synwire_core/retrievers/
traits.rs

1//! Retriever trait and types.
2
3use crate::BoxFuture;
4use crate::documents::Document;
5use crate::embeddings::Embeddings;
6use crate::error::SynwireError;
7use crate::vectorstores::VectorStore;
8use crate::vectorstores::mmr::maximal_marginal_relevance;
9
10/// Trait for document retrievers.
11///
12/// A retriever takes a natural-language query and returns relevant documents.
13/// This is the primary abstraction for retrieval-augmented generation (RAG).
14pub trait Retriever: Send + Sync {
15    /// Retrieve documents relevant to the query.
16    fn get_relevant_documents<'a>(
17        &'a self,
18        query: &'a str,
19    ) -> BoxFuture<'a, Result<Vec<Document>, SynwireError>>;
20}
21
/// The similarity search strategy to use.
///
/// Implements [`PartialEq`] so configurations can be compared directly
/// (for example in assertions).
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum SearchType {
    /// Plain top-`k` similarity search, using whatever similarity metric the
    /// underlying vector store implements.
    Similarity,
    /// Maximal Marginal Relevance search, balancing relevance and diversity.
    Mmr {
        /// Controls the relevance-diversity trade-off.
        /// `1.0` = pure relevance, `0.0` = maximum diversity.
        lambda: f32,
    },
}
35
/// The retrieval mode (dense, sparse, or hybrid).
///
/// Implements [`PartialEq`] so configurations can be compared directly
/// (for example in assertions).
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum RetrievalMode {
    /// Dense vector retrieval. This is the only mode that
    /// `VectorStoreRetriever` accepts; the other modes are rejected with an
    /// error at query time.
    Dense,
    /// Sparse keyword-based retrieval (e.g., BM25).
    Sparse,
    /// Hybrid retrieval combining dense and sparse.
    Hybrid {
        /// Weight for dense retrieval. `1.0` = pure dense, `0.0` = pure sparse.
        alpha: f32,
    },
}
50
51/// A retriever backed by a [`VectorStore`] and [`Embeddings`] model.
52///
53/// Wraps a vector store to provide the [`Retriever`] interface with
54/// configurable search type and retrieval mode.
55pub struct VectorStoreRetriever {
56    store: Box<dyn VectorStore>,
57    embeddings: Box<dyn Embeddings>,
58    k: usize,
59    search_type: SearchType,
60    retrieval_mode: RetrievalMode,
61}
62
63impl VectorStoreRetriever {
64    /// Creates a new `VectorStoreRetriever`.
65    pub fn new(
66        store: Box<dyn VectorStore>,
67        embeddings: Box<dyn Embeddings>,
68        k: usize,
69        search_type: SearchType,
70        retrieval_mode: RetrievalMode,
71    ) -> Self {
72        Self {
73            store,
74            embeddings,
75            k,
76            search_type,
77            retrieval_mode,
78        }
79    }
80}
81
82impl Retriever for VectorStoreRetriever {
83    fn get_relevant_documents<'a>(
84        &'a self,
85        query: &'a str,
86    ) -> BoxFuture<'a, Result<Vec<Document>, SynwireError>> {
87        Box::pin(async move {
88            // Reject unsupported modes
89            match &self.retrieval_mode {
90                RetrievalMode::Sparse => {
91                    return Err(SynwireError::Other(
92                        "sparse retrieval is not supported by VectorStoreRetriever".into(),
93                    ));
94                }
95                RetrievalMode::Hybrid { .. } => {
96                    return Err(SynwireError::Other(
97                        "hybrid retrieval is not supported by VectorStoreRetriever".into(),
98                    ));
99                }
100                RetrievalMode::Dense => {}
101            }
102
103            match &self.search_type {
104                SearchType::Similarity => {
105                    self.store
106                        .similarity_search(query, self.k, self.embeddings.as_ref())
107                        .await
108                }
109                SearchType::Mmr { lambda } => {
110                    // Fetch more candidates than k for MMR re-ranking
111                    let fetch_k = self.k * 4;
112                    let candidates = self
113                        .store
114                        .similarity_search_with_score(query, fetch_k, self.embeddings.as_ref())
115                        .await?;
116
117                    if candidates.is_empty() {
118                        return Ok(Vec::new());
119                    }
120
121                    let query_vec = self.embeddings.embed_query(query).await?;
122                    let texts: Vec<String> = candidates
123                        .iter()
124                        .map(|(doc, _)| doc.page_content.clone())
125                        .collect();
126                    let candidate_embeddings = self.embeddings.embed_documents(&texts).await?;
127
128                    let indices = maximal_marginal_relevance(
129                        &query_vec,
130                        &candidate_embeddings,
131                        self.k,
132                        *lambda,
133                    );
134
135                    Ok(indices
136                        .into_iter()
137                        .filter_map(|i| candidates.get(i).map(|(doc, _)| doc.clone()))
138                        .collect())
139                }
140            }
141        })
142    }
143}
144
#[cfg(test)]
// Tests should fail loudly on any unexpected error, so `unwrap` is fine here.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::embeddings::FakeEmbeddings;
    use crate::vectorstores::InMemoryVectorStore;

    /// Number of documents every test retriever asks for.
    const TOP_K: usize = 2;

    /// Creates an empty in-memory store and a fake embedding model.
    fn fixture() -> (InMemoryVectorStore, FakeEmbeddings) {
        (InMemoryVectorStore::new(), FakeEmbeddings::new(32))
    }

    /// Boxes the fixture parts into a retriever that requests `TOP_K` documents.
    fn retriever_over(
        store: InMemoryVectorStore,
        embeddings: FakeEmbeddings,
        search_type: SearchType,
        retrieval_mode: RetrievalMode,
    ) -> VectorStoreRetriever {
        VectorStoreRetriever::new(
            Box::new(store),
            Box::new(embeddings),
            TOP_K,
            search_type,
            retrieval_mode,
        )
    }

    #[tokio::test]
    async fn vector_store_retriever_wraps_store() {
        let (store, embeddings) = fixture();

        let docs = vec![
            Document::new("rust programming"),
            Document::new("python scripting"),
            Document::new("rust ownership model"),
        ];
        let _ = store.add_documents(&docs, &embeddings).await.unwrap();

        let retriever = retriever_over(
            store,
            embeddings,
            SearchType::Similarity,
            RetrievalMode::Dense,
        );

        let found = retriever.get_relevant_documents("rust").await.unwrap();
        assert_eq!(found.len(), 2);
    }

    #[tokio::test]
    async fn vector_store_retriever_mmr_search() {
        let (store, embeddings) = fixture();

        let docs = vec![
            Document::new("alpha beta"),
            Document::new("alpha gamma"),
            Document::new("delta epsilon"),
        ];
        let _ = store.add_documents(&docs, &embeddings).await.unwrap();

        let retriever = retriever_over(
            store,
            embeddings,
            SearchType::Mmr { lambda: 0.5 },
            RetrievalMode::Dense,
        );

        let found = retriever.get_relevant_documents("alpha").await.unwrap();
        assert_eq!(found.len(), 2);
    }

    #[tokio::test]
    async fn retriever_sparse_mode_returns_error() {
        let (store, embeddings) = fixture();
        let retriever = retriever_over(
            store,
            embeddings,
            SearchType::Similarity,
            RetrievalMode::Sparse,
        );

        let outcome = retriever.get_relevant_documents("test").await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn retriever_hybrid_mode_returns_error() {
        let (store, embeddings) = fixture();
        let retriever = retriever_over(
            store,
            embeddings,
            SearchType::Similarity,
            RetrievalMode::Hybrid { alpha: 0.5 },
        );

        let outcome = retriever.get_relevant_documents("test").await;
        assert!(outcome.is_err());
    }
}