1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use synaptic_core::{Document, Embeddings, Retriever, SynapticError};
6use tokio::sync::RwLock;
7
8use crate::VectorStore;
9
10struct StoredEntry {
12 document: Document,
13 embedding: Vec<f32>,
14}
15
16pub struct InMemoryVectorStore {
18 entries: RwLock<HashMap<String, StoredEntry>>,
19}
20
21impl InMemoryVectorStore {
22 pub fn new() -> Self {
23 Self {
24 entries: RwLock::new(HashMap::new()),
25 }
26 }
27
28 pub async fn from_texts(
30 texts: Vec<(&str, &str)>,
31 embeddings: &dyn Embeddings,
32 ) -> Result<Self, SynapticError> {
33 let store = Self::new();
34 let docs = texts
35 .into_iter()
36 .map(|(id, content)| Document::new(id, content))
37 .collect();
38 store.add_documents(docs, embeddings).await?;
39 Ok(store)
40 }
41
42 pub async fn from_documents(
44 documents: Vec<Document>,
45 embeddings: &dyn Embeddings,
46 ) -> Result<Self, SynapticError> {
47 let store = Self::new();
48 store.add_documents(documents, embeddings).await?;
49 Ok(store)
50 }
51
52 pub async fn max_marginal_relevance_search(
61 &self,
62 query: &str,
63 k: usize,
64 fetch_k: usize,
65 lambda_mult: f32,
66 embeddings: &dyn Embeddings,
67 ) -> Result<Vec<Document>, SynapticError> {
68 let query_vec = embeddings.embed_query(query).await?;
69 let entries = self.entries.read().await;
70
71 let mut candidates: Vec<(String, Document, Vec<f32>, f32)> = entries
73 .values()
74 .map(|entry| {
75 let score = cosine_similarity(&query_vec, &entry.embedding);
76 (
77 entry.document.id.clone(),
78 entry.document.clone(),
79 entry.embedding.clone(),
80 score,
81 )
82 })
83 .collect();
84
85 candidates.sort_by(|a, b| b.3.partial_cmp(&a.3).unwrap_or(std::cmp::Ordering::Equal));
87 candidates.truncate(fetch_k);
88
89 if candidates.is_empty() || k == 0 {
90 return Ok(Vec::new());
91 }
92
93 let mut selected: Vec<(Document, Vec<f32>)> = Vec::with_capacity(k);
95 let mut remaining = candidates;
96
97 while selected.len() < k && !remaining.is_empty() {
98 let mut best_idx = 0;
99 let mut best_score = f32::NEG_INFINITY;
100
101 for (i, (_id, _doc, emb, query_sim)) in remaining.iter().enumerate() {
102 let max_sim_to_selected = if selected.is_empty() {
104 0.0
105 } else {
106 selected
107 .iter()
108 .map(|(_, sel_emb)| cosine_similarity(sel_emb, emb))
109 .fold(f32::NEG_INFINITY, f32::max)
110 };
111
112 let mmr_score = lambda_mult * query_sim - (1.0 - lambda_mult) * max_sim_to_selected;
113
114 if mmr_score > best_score {
115 best_score = mmr_score;
116 best_idx = i;
117 }
118 }
119
120 let (_id, doc, emb, _query_sim) = remaining.remove(best_idx);
121 selected.push((doc, emb));
122 }
123
124 Ok(selected.into_iter().map(|(doc, _)| doc).collect())
125 }
126}
127
128impl Default for InMemoryVectorStore {
129 fn default() -> Self {
130 Self::new()
131 }
132}
133
134#[async_trait]
135impl VectorStore for InMemoryVectorStore {
136 async fn add_documents(
137 &self,
138 docs: Vec<Document>,
139 embeddings: &dyn Embeddings,
140 ) -> Result<Vec<String>, SynapticError> {
141 let texts: Vec<&str> = docs.iter().map(|d| d.content.as_str()).collect();
142 let vectors = embeddings.embed_documents(&texts).await?;
143
144 let mut entries = self.entries.write().await;
145 let mut ids = Vec::with_capacity(docs.len());
146
147 for (doc, embedding) in docs.into_iter().zip(vectors) {
148 ids.push(doc.id.clone());
149 entries.insert(
150 doc.id.clone(),
151 StoredEntry {
152 document: doc,
153 embedding,
154 },
155 );
156 }
157
158 Ok(ids)
159 }
160
161 async fn similarity_search(
162 &self,
163 query: &str,
164 k: usize,
165 embeddings: &dyn Embeddings,
166 ) -> Result<Vec<Document>, SynapticError> {
167 let results = self
168 .similarity_search_with_score(query, k, embeddings)
169 .await?;
170 Ok(results.into_iter().map(|(doc, _)| doc).collect())
171 }
172
173 async fn similarity_search_with_score(
174 &self,
175 query: &str,
176 k: usize,
177 embeddings: &dyn Embeddings,
178 ) -> Result<Vec<(Document, f32)>, SynapticError> {
179 let query_vec = embeddings.embed_query(query).await?;
180 let entries = self.entries.read().await;
181
182 let mut scored: Vec<(Document, f32)> = entries
183 .values()
184 .map(|entry| {
185 let score = cosine_similarity(&query_vec, &entry.embedding);
186 (entry.document.clone(), score)
187 })
188 .collect();
189
190 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
192 scored.truncate(k);
193
194 Ok(scored)
195 }
196
197 async fn similarity_search_by_vector(
198 &self,
199 embedding: &[f32],
200 k: usize,
201 ) -> Result<Vec<Document>, SynapticError> {
202 let entries = self.entries.read().await;
203
204 let mut scored: Vec<(Document, f32)> = entries
205 .values()
206 .map(|entry| {
207 let score = cosine_similarity(embedding, &entry.embedding);
208 (entry.document.clone(), score)
209 })
210 .collect();
211
212 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
213 scored.truncate(k);
214
215 Ok(scored.into_iter().map(|(doc, _)| doc).collect())
216 }
217
218 async fn delete(&self, ids: &[&str]) -> Result<(), SynapticError> {
219 let mut entries = self.entries.write().await;
220 for id in ids {
221 entries.remove(*id);
222 }
223 Ok(())
224 }
225}
226
227pub struct VectorStoreRetriever<S: VectorStore> {
229 store: Arc<S>,
230 embeddings: Arc<dyn Embeddings>,
231 k: usize,
232 score_threshold: Option<f32>,
233}
234
235impl<S: VectorStore + 'static> VectorStoreRetriever<S> {
236 pub fn new(store: Arc<S>, embeddings: Arc<dyn Embeddings>, k: usize) -> Self {
237 Self {
238 store,
239 embeddings,
240 k,
241 score_threshold: None,
242 }
243 }
244
245 pub fn with_score_threshold(mut self, threshold: f32) -> Self {
248 self.score_threshold = Some(threshold);
249 self
250 }
251}
252
253#[async_trait]
254impl<S: VectorStore + 'static> Retriever for VectorStoreRetriever<S> {
255 async fn retrieve(&self, query: &str, top_k: usize) -> Result<Vec<Document>, SynapticError> {
256 let k = if top_k > 0 { top_k } else { self.k };
257
258 if let Some(threshold) = self.score_threshold {
259 let scored = self
260 .store
261 .similarity_search_with_score(query, k, self.embeddings.as_ref())
262 .await?;
263 Ok(scored
264 .into_iter()
265 .filter(|(_, score)| *score >= threshold)
266 .map(|(doc, _)| doc)
267 .collect())
268 } else {
269 self.store
270 .similarity_search(query, k, self.embeddings.as_ref())
271 .await
272 }
273 }
274}
275
/// Computes the cosine similarity of `a` and `b`.
///
/// Returns `0.0` when the slices differ in length, are empty, when either has
/// zero magnitude, or when the result is not a number (for example because an
/// input contains NaN or an infinity). Otherwise the result lies in
/// `[-1.0, 1.0]`.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    // Accumulate in f64 in a single pass: squaring large or tiny f32 components
    // would overflow to infinity or underflow to zero in f32 and corrupt the
    // result.
    let (dot, norm_sq_a, norm_sq_b) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, na, nb), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, na + x * x, nb + y * y)
        },
    );

    if norm_sq_a == 0.0 || norm_sq_b == 0.0 {
        return 0.0;
    }

    let similarity = dot / (norm_sq_a.sqrt() * norm_sq_b.sqrt());
    if similarity.is_nan() {
        // Keep NaN out of sort comparators and threshold filters.
        return 0.0;
    }

    // Clamp away rounding error that can push the ratio just past +/-1;
    // narrowing to f32 is then lossless in range.
    similarity.clamp(-1.0, 1.0) as f32
}