synwire_core/retrievers/
traits.rs1use crate::BoxFuture;
4use crate::documents::Document;
5use crate::embeddings::Embeddings;
6use crate::error::SynwireError;
7use crate::vectorstores::VectorStore;
8use crate::vectorstores::mmr::maximal_marginal_relevance;
9
10pub trait Retriever: Send + Sync {
15 fn get_relevant_documents<'a>(
17 &'a self,
18 query: &'a str,
19 ) -> BoxFuture<'a, Result<Vec<Document>, SynwireError>>;
20}
21
/// How a vector-store-backed retriever chooses which documents to return.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub enum SearchType {
    /// Return the top `k` documents ranked by similarity to the query.
    Similarity,
    /// Maximal Marginal Relevance: re-rank an over-fetched candidate set to
    /// trade relevance to the query against diversity among the results.
    Mmr {
        /// Trade-off weight handed to the MMR re-ranker.
        ///
        /// By the usual MMR convention, values near `1.0` favour relevance and
        /// values near `0.0` favour diversity. NOTE(review): confirm against
        /// `maximal_marginal_relevance`.
        lambda: f32,
    },
}
35
/// Which scoring signal a retriever should search with.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub enum RetrievalMode {
    /// Dense (embedding-vector) retrieval.
    Dense,
    /// Sparse retrieval. NOTE(review): presumably lexical/keyword scoring;
    /// confirm with the implementations that support it.
    Sparse,
    /// A blend of dense and sparse retrieval.
    Hybrid {
        /// Blend weight between the dense and sparse signals.
        ///
        /// NOTE(review): which side `alpha` weights is not established in
        /// this file; confirm with the implementations that support it.
        alpha: f32,
    },
}
50
51pub struct VectorStoreRetriever {
56 store: Box<dyn VectorStore>,
57 embeddings: Box<dyn Embeddings>,
58 k: usize,
59 search_type: SearchType,
60 retrieval_mode: RetrievalMode,
61}
62
63impl VectorStoreRetriever {
64 pub fn new(
66 store: Box<dyn VectorStore>,
67 embeddings: Box<dyn Embeddings>,
68 k: usize,
69 search_type: SearchType,
70 retrieval_mode: RetrievalMode,
71 ) -> Self {
72 Self {
73 store,
74 embeddings,
75 k,
76 search_type,
77 retrieval_mode,
78 }
79 }
80}
81
82impl Retriever for VectorStoreRetriever {
83 fn get_relevant_documents<'a>(
84 &'a self,
85 query: &'a str,
86 ) -> BoxFuture<'a, Result<Vec<Document>, SynwireError>> {
87 Box::pin(async move {
88 match &self.retrieval_mode {
90 RetrievalMode::Sparse => {
91 return Err(SynwireError::Other(
92 "sparse retrieval is not supported by VectorStoreRetriever".into(),
93 ));
94 }
95 RetrievalMode::Hybrid { .. } => {
96 return Err(SynwireError::Other(
97 "hybrid retrieval is not supported by VectorStoreRetriever".into(),
98 ));
99 }
100 RetrievalMode::Dense => {}
101 }
102
103 match &self.search_type {
104 SearchType::Similarity => {
105 self.store
106 .similarity_search(query, self.k, self.embeddings.as_ref())
107 .await
108 }
109 SearchType::Mmr { lambda } => {
110 let fetch_k = self.k * 4;
112 let candidates = self
113 .store
114 .similarity_search_with_score(query, fetch_k, self.embeddings.as_ref())
115 .await?;
116
117 if candidates.is_empty() {
118 return Ok(Vec::new());
119 }
120
121 let query_vec = self.embeddings.embed_query(query).await?;
122 let texts: Vec<String> = candidates
123 .iter()
124 .map(|(doc, _)| doc.page_content.clone())
125 .collect();
126 let candidate_embeddings = self.embeddings.embed_documents(&texts).await?;
127
128 let indices = maximal_marginal_relevance(
129 &query_vec,
130 &candidate_embeddings,
131 self.k,
132 *lambda,
133 );
134
135 Ok(indices
136 .into_iter()
137 .filter_map(|i| candidates.get(i).map(|(doc, _)| doc.clone()))
138 .collect())
139 }
140 }
141 })
142 }
143}
144
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::embeddings::FakeEmbeddings;
    use crate::vectorstores::InMemoryVectorStore;

    /// Builds a retriever over a fresh in-memory store seeded with `texts`.
    ///
    /// When `texts` is empty the store is left untouched, so error-path and
    /// empty-store tests never call `add_documents`.
    async fn retriever_over(
        texts: &[&str],
        k: usize,
        search_type: SearchType,
        retrieval_mode: RetrievalMode,
    ) -> VectorStoreRetriever {
        let store = InMemoryVectorStore::new();
        let embeddings = FakeEmbeddings::new(32);

        if !texts.is_empty() {
            let docs: Vec<Document> = texts.iter().map(|text| Document::new(*text)).collect();
            let _ = store.add_documents(&docs, &embeddings).await.unwrap();
        }

        VectorStoreRetriever::new(
            Box::new(store),
            Box::new(embeddings),
            k,
            search_type,
            retrieval_mode,
        )
    }

    #[tokio::test]
    async fn vector_store_retriever_wraps_store() {
        let retriever = retriever_over(
            &["rust programming", "python scripting", "rust ownership model"],
            2,
            SearchType::Similarity,
            RetrievalMode::Dense,
        )
        .await;

        let results = retriever.get_relevant_documents("rust").await.unwrap();
        assert_eq!(results.len(), 2);
    }

    #[tokio::test]
    async fn vector_store_retriever_mmr_search() {
        let retriever = retriever_over(
            &["alpha beta", "alpha gamma", "delta epsilon"],
            2,
            SearchType::Mmr { lambda: 0.5 },
            RetrievalMode::Dense,
        )
        .await;

        let results = retriever.get_relevant_documents("alpha").await.unwrap();
        assert_eq!(results.len(), 2);
        // All seeded texts are distinct, so MMR picking the same candidate
        // twice would show up as duplicate content.
        assert_ne!(results[0].page_content, results[1].page_content);
    }

    #[tokio::test]
    async fn vector_store_retriever_mmr_on_empty_store_returns_empty() {
        let retriever =
            retriever_over(&[], 2, SearchType::Mmr { lambda: 0.5 }, RetrievalMode::Dense).await;

        let results = retriever.get_relevant_documents("alpha").await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn retriever_sparse_mode_returns_error() {
        let retriever =
            retriever_over(&[], 2, SearchType::Similarity, RetrievalMode::Sparse).await;

        let result = retriever.get_relevant_documents("test").await;
        // Expect the `Other` variant raised by the unsupported-mode guard.
        assert!(matches!(result, Err(SynwireError::Other(_))));
    }

    #[tokio::test]
    async fn retriever_hybrid_mode_returns_error() {
        let retriever = retriever_over(
            &[],
            2,
            SearchType::Similarity,
            RetrievalMode::Hybrid { alpha: 0.5 },
        )
        .await;

        let result = retriever.get_relevant_documents("test").await;
        // Expect the `Other` variant raised by the unsupported-mode guard.
        assert!(matches!(result, Err(SynwireError::Other(_))));
    }
}