1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use synaptic_core::SynapticError;
6use synaptic_embeddings::Embeddings;
7use synaptic_retrieval::{Document, Retriever};
8use tokio::sync::RwLock;
9
10use crate::VectorStore;
11
12struct StoredEntry {
14 document: Document,
15 embedding: Vec<f32>,
16}
17
18pub struct InMemoryVectorStore {
20 entries: RwLock<HashMap<String, StoredEntry>>,
21}
22
23impl InMemoryVectorStore {
24 pub fn new() -> Self {
25 Self {
26 entries: RwLock::new(HashMap::new()),
27 }
28 }
29
30 pub async fn from_texts(
32 texts: Vec<(&str, &str)>,
33 embeddings: &dyn Embeddings,
34 ) -> Result<Self, SynapticError> {
35 let store = Self::new();
36 let docs = texts
37 .into_iter()
38 .map(|(id, content)| Document::new(id, content))
39 .collect();
40 store.add_documents(docs, embeddings).await?;
41 Ok(store)
42 }
43
44 pub async fn from_documents(
46 documents: Vec<Document>,
47 embeddings: &dyn Embeddings,
48 ) -> Result<Self, SynapticError> {
49 let store = Self::new();
50 store.add_documents(documents, embeddings).await?;
51 Ok(store)
52 }
53
54 pub async fn max_marginal_relevance_search(
63 &self,
64 query: &str,
65 k: usize,
66 fetch_k: usize,
67 lambda_mult: f32,
68 embeddings: &dyn Embeddings,
69 ) -> Result<Vec<Document>, SynapticError> {
70 let query_vec = embeddings.embed_query(query).await?;
71 let entries = self.entries.read().await;
72
73 let mut candidates: Vec<(String, Document, Vec<f32>, f32)> = entries
75 .values()
76 .map(|entry| {
77 let score = cosine_similarity(&query_vec, &entry.embedding);
78 (
79 entry.document.id.clone(),
80 entry.document.clone(),
81 entry.embedding.clone(),
82 score,
83 )
84 })
85 .collect();
86
87 candidates.sort_by(|a, b| b.3.partial_cmp(&a.3).unwrap_or(std::cmp::Ordering::Equal));
89 candidates.truncate(fetch_k);
90
91 if candidates.is_empty() || k == 0 {
92 return Ok(Vec::new());
93 }
94
95 let mut selected: Vec<(Document, Vec<f32>)> = Vec::with_capacity(k);
97 let mut remaining = candidates;
98
99 while selected.len() < k && !remaining.is_empty() {
100 let mut best_idx = 0;
101 let mut best_score = f32::NEG_INFINITY;
102
103 for (i, (_id, _doc, emb, query_sim)) in remaining.iter().enumerate() {
104 let max_sim_to_selected = if selected.is_empty() {
106 0.0
107 } else {
108 selected
109 .iter()
110 .map(|(_, sel_emb)| cosine_similarity(sel_emb, emb))
111 .fold(f32::NEG_INFINITY, f32::max)
112 };
113
114 let mmr_score = lambda_mult * query_sim - (1.0 - lambda_mult) * max_sim_to_selected;
115
116 if mmr_score > best_score {
117 best_score = mmr_score;
118 best_idx = i;
119 }
120 }
121
122 let (_id, doc, emb, _query_sim) = remaining.remove(best_idx);
123 selected.push((doc, emb));
124 }
125
126 Ok(selected.into_iter().map(|(doc, _)| doc).collect())
127 }
128}
129
130impl Default for InMemoryVectorStore {
131 fn default() -> Self {
132 Self::new()
133 }
134}
135
136#[async_trait]
137impl VectorStore for InMemoryVectorStore {
138 async fn add_documents(
139 &self,
140 docs: Vec<Document>,
141 embeddings: &dyn Embeddings,
142 ) -> Result<Vec<String>, SynapticError> {
143 let texts: Vec<&str> = docs.iter().map(|d| d.content.as_str()).collect();
144 let vectors = embeddings.embed_documents(&texts).await?;
145
146 let mut entries = self.entries.write().await;
147 let mut ids = Vec::with_capacity(docs.len());
148
149 for (doc, embedding) in docs.into_iter().zip(vectors) {
150 ids.push(doc.id.clone());
151 entries.insert(
152 doc.id.clone(),
153 StoredEntry {
154 document: doc,
155 embedding,
156 },
157 );
158 }
159
160 Ok(ids)
161 }
162
163 async fn similarity_search(
164 &self,
165 query: &str,
166 k: usize,
167 embeddings: &dyn Embeddings,
168 ) -> Result<Vec<Document>, SynapticError> {
169 let results = self
170 .similarity_search_with_score(query, k, embeddings)
171 .await?;
172 Ok(results.into_iter().map(|(doc, _)| doc).collect())
173 }
174
175 async fn similarity_search_with_score(
176 &self,
177 query: &str,
178 k: usize,
179 embeddings: &dyn Embeddings,
180 ) -> Result<Vec<(Document, f32)>, SynapticError> {
181 let query_vec = embeddings.embed_query(query).await?;
182 let entries = self.entries.read().await;
183
184 let mut scored: Vec<(Document, f32)> = entries
185 .values()
186 .map(|entry| {
187 let score = cosine_similarity(&query_vec, &entry.embedding);
188 (entry.document.clone(), score)
189 })
190 .collect();
191
192 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
194 scored.truncate(k);
195
196 Ok(scored)
197 }
198
199 async fn similarity_search_by_vector(
200 &self,
201 embedding: &[f32],
202 k: usize,
203 ) -> Result<Vec<Document>, SynapticError> {
204 let entries = self.entries.read().await;
205
206 let mut scored: Vec<(Document, f32)> = entries
207 .values()
208 .map(|entry| {
209 let score = cosine_similarity(embedding, &entry.embedding);
210 (entry.document.clone(), score)
211 })
212 .collect();
213
214 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
215 scored.truncate(k);
216
217 Ok(scored.into_iter().map(|(doc, _)| doc).collect())
218 }
219
220 async fn delete(&self, ids: &[&str]) -> Result<(), SynapticError> {
221 let mut entries = self.entries.write().await;
222 for id in ids {
223 entries.remove(*id);
224 }
225 Ok(())
226 }
227}
228
229pub struct VectorStoreRetriever<S: VectorStore> {
231 store: Arc<S>,
232 embeddings: Arc<dyn Embeddings>,
233 k: usize,
234 score_threshold: Option<f32>,
235}
236
237impl<S: VectorStore + 'static> VectorStoreRetriever<S> {
238 pub fn new(store: Arc<S>, embeddings: Arc<dyn Embeddings>, k: usize) -> Self {
239 Self {
240 store,
241 embeddings,
242 k,
243 score_threshold: None,
244 }
245 }
246
247 pub fn with_score_threshold(mut self, threshold: f32) -> Self {
250 self.score_threshold = Some(threshold);
251 self
252 }
253}
254
255#[async_trait]
256impl<S: VectorStore + 'static> Retriever for VectorStoreRetriever<S> {
257 async fn retrieve(&self, query: &str, top_k: usize) -> Result<Vec<Document>, SynapticError> {
258 let k = if top_k > 0 { top_k } else { self.k };
259
260 if let Some(threshold) = self.score_threshold {
261 let scored = self
262 .store
263 .similarity_search_with_score(query, k, self.embeddings.as_ref())
264 .await?;
265 Ok(scored
266 .into_iter()
267 .filter(|(_, score)| *score >= threshold)
268 .map(|(doc, _)| doc)
269 .collect())
270 } else {
271 self.store
272 .similarity_search(query, k, self.embeddings.as_ref())
273 .await
274 }
275 }
276}
277
/// Computes the cosine similarity between two vectors.
///
/// Returns `0.0` when the vectors differ in length, when they are empty, or
/// when either one has zero magnitude; otherwise returns
/// `dot(a, b) / (|a| * |b|)`.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // Vectors with no components or mismatched dimensions are not comparable.
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    let norm = |v: &[f32]| v.iter().map(|c| c * c).sum::<f32>().sqrt();
    let (norm_a, norm_b) = (norm(a), norm(b));

    // A zero-magnitude vector has no direction; also avoids dividing by zero.
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    let dot_product = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum::<f32>();
    dot_product / (norm_a * norm_b)
}