Skip to main content

synaptic_vectorstores/
in_memory.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use synaptic_core::SynapseError;
6use synaptic_embeddings::Embeddings;
7use synaptic_retrieval::{Document, Retriever};
8use tokio::sync::RwLock;
9
10use crate::VectorStore;
11
12/// Stored document with its embedding vector.
13struct StoredEntry {
14    document: Document,
15    embedding: Vec<f32>,
16}
17
18/// In-memory vector store using cosine similarity.
19pub struct InMemoryVectorStore {
20    entries: RwLock<HashMap<String, StoredEntry>>,
21}
22
23impl InMemoryVectorStore {
24    pub fn new() -> Self {
25        Self {
26            entries: RwLock::new(HashMap::new()),
27        }
28    }
29
30    /// Create a new store pre-populated with texts.
31    pub async fn from_texts(
32        texts: Vec<(&str, &str)>,
33        embeddings: &dyn Embeddings,
34    ) -> Result<Self, SynapseError> {
35        let store = Self::new();
36        let docs = texts
37            .into_iter()
38            .map(|(id, content)| Document::new(id, content))
39            .collect();
40        store.add_documents(docs, embeddings).await?;
41        Ok(store)
42    }
43
44    /// Create a new store pre-populated with documents.
45    pub async fn from_documents(
46        documents: Vec<Document>,
47        embeddings: &dyn Embeddings,
48    ) -> Result<Self, SynapseError> {
49        let store = Self::new();
50        store.add_documents(documents, embeddings).await?;
51        Ok(store)
52    }
53
54    /// Maximum Marginal Relevance search for diverse results.
55    ///
56    /// `lambda_mult` controls the trade-off between relevance and diversity:
57    /// - 1.0 = pure relevance (equivalent to standard similarity search)
58    /// - 0.0 = maximum diversity
59    /// - 0.5 = balanced (typical default)
60    ///
61    /// `fetch_k` is the number of initial candidates to fetch before MMR filtering.
62    pub async fn max_marginal_relevance_search(
63        &self,
64        query: &str,
65        k: usize,
66        fetch_k: usize,
67        lambda_mult: f32,
68        embeddings: &dyn Embeddings,
69    ) -> Result<Vec<Document>, SynapseError> {
70        let query_vec = embeddings.embed_query(query).await?;
71        let entries = self.entries.read().await;
72
73        // Score all candidates against the query
74        let mut candidates: Vec<(String, Document, Vec<f32>, f32)> = entries
75            .values()
76            .map(|entry| {
77                let score = cosine_similarity(&query_vec, &entry.embedding);
78                (
79                    entry.document.id.clone(),
80                    entry.document.clone(),
81                    entry.embedding.clone(),
82                    score,
83                )
84            })
85            .collect();
86
87        // Sort by query similarity descending and take top fetch_k
88        candidates.sort_by(|a, b| b.3.partial_cmp(&a.3).unwrap_or(std::cmp::Ordering::Equal));
89        candidates.truncate(fetch_k);
90
91        if candidates.is_empty() || k == 0 {
92            return Ok(Vec::new());
93        }
94
95        // Greedy MMR selection
96        let mut selected: Vec<(Document, Vec<f32>)> = Vec::with_capacity(k);
97        let mut remaining = candidates;
98
99        while selected.len() < k && !remaining.is_empty() {
100            let mut best_idx = 0;
101            let mut best_score = f32::NEG_INFINITY;
102
103            for (i, (_id, _doc, emb, query_sim)) in remaining.iter().enumerate() {
104                // Compute max similarity to already-selected documents
105                let max_sim_to_selected = if selected.is_empty() {
106                    0.0
107                } else {
108                    selected
109                        .iter()
110                        .map(|(_, sel_emb)| cosine_similarity(sel_emb, emb))
111                        .fold(f32::NEG_INFINITY, f32::max)
112                };
113
114                let mmr_score = lambda_mult * query_sim - (1.0 - lambda_mult) * max_sim_to_selected;
115
116                if mmr_score > best_score {
117                    best_score = mmr_score;
118                    best_idx = i;
119                }
120            }
121
122            let (_id, doc, emb, _query_sim) = remaining.remove(best_idx);
123            selected.push((doc, emb));
124        }
125
126        Ok(selected.into_iter().map(|(doc, _)| doc).collect())
127    }
128}
129
130impl Default for InMemoryVectorStore {
131    fn default() -> Self {
132        Self::new()
133    }
134}
135
136#[async_trait]
137impl VectorStore for InMemoryVectorStore {
138    async fn add_documents(
139        &self,
140        docs: Vec<Document>,
141        embeddings: &dyn Embeddings,
142    ) -> Result<Vec<String>, SynapseError> {
143        let texts: Vec<&str> = docs.iter().map(|d| d.content.as_str()).collect();
144        let vectors = embeddings.embed_documents(&texts).await?;
145
146        let mut entries = self.entries.write().await;
147        let mut ids = Vec::with_capacity(docs.len());
148
149        for (doc, embedding) in docs.into_iter().zip(vectors) {
150            ids.push(doc.id.clone());
151            entries.insert(
152                doc.id.clone(),
153                StoredEntry {
154                    document: doc,
155                    embedding,
156                },
157            );
158        }
159
160        Ok(ids)
161    }
162
163    async fn similarity_search(
164        &self,
165        query: &str,
166        k: usize,
167        embeddings: &dyn Embeddings,
168    ) -> Result<Vec<Document>, SynapseError> {
169        let results = self
170            .similarity_search_with_score(query, k, embeddings)
171            .await?;
172        Ok(results.into_iter().map(|(doc, _)| doc).collect())
173    }
174
175    async fn similarity_search_with_score(
176        &self,
177        query: &str,
178        k: usize,
179        embeddings: &dyn Embeddings,
180    ) -> Result<Vec<(Document, f32)>, SynapseError> {
181        let query_vec = embeddings.embed_query(query).await?;
182        let entries = self.entries.read().await;
183
184        let mut scored: Vec<(Document, f32)> = entries
185            .values()
186            .map(|entry| {
187                let score = cosine_similarity(&query_vec, &entry.embedding);
188                (entry.document.clone(), score)
189            })
190            .collect();
191
192        // Sort by score descending
193        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
194        scored.truncate(k);
195
196        Ok(scored)
197    }
198
199    async fn similarity_search_by_vector(
200        &self,
201        embedding: &[f32],
202        k: usize,
203    ) -> Result<Vec<Document>, SynapseError> {
204        let entries = self.entries.read().await;
205
206        let mut scored: Vec<(Document, f32)> = entries
207            .values()
208            .map(|entry| {
209                let score = cosine_similarity(embedding, &entry.embedding);
210                (entry.document.clone(), score)
211            })
212            .collect();
213
214        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
215        scored.truncate(k);
216
217        Ok(scored.into_iter().map(|(doc, _)| doc).collect())
218    }
219
220    async fn delete(&self, ids: &[&str]) -> Result<(), SynapseError> {
221        let mut entries = self.entries.write().await;
222        for id in ids {
223            entries.remove(*id);
224        }
225        Ok(())
226    }
227}
228
229/// A retriever that wraps a VectorStore, bridging it to the `Retriever` trait.
230pub struct VectorStoreRetriever<S: VectorStore> {
231    store: Arc<S>,
232    embeddings: Arc<dyn Embeddings>,
233    k: usize,
234    score_threshold: Option<f32>,
235}
236
237impl<S: VectorStore + 'static> VectorStoreRetriever<S> {
238    pub fn new(store: Arc<S>, embeddings: Arc<dyn Embeddings>, k: usize) -> Self {
239        Self {
240            store,
241            embeddings,
242            k,
243            score_threshold: None,
244        }
245    }
246
247    /// Set a minimum similarity score threshold. Only documents with a score
248    /// greater than or equal to the threshold will be returned.
249    pub fn with_score_threshold(mut self, threshold: f32) -> Self {
250        self.score_threshold = Some(threshold);
251        self
252    }
253}
254
255#[async_trait]
256impl<S: VectorStore + 'static> Retriever for VectorStoreRetriever<S> {
257    async fn retrieve(&self, query: &str, top_k: usize) -> Result<Vec<Document>, SynapseError> {
258        let k = if top_k > 0 { top_k } else { self.k };
259
260        if let Some(threshold) = self.score_threshold {
261            let scored = self
262                .store
263                .similarity_search_with_score(query, k, self.embeddings.as_ref())
264                .await?;
265            Ok(scored
266                .into_iter()
267                .filter(|(_, score)| *score >= threshold)
268                .map(|(doc, _)| doc)
269                .collect())
270        } else {
271            self.store
272                .similarity_search(query, k, self.embeddings.as_ref())
273                .await
274        }
275    }
276}
277
/// Compute cosine similarity between two vectors.
///
/// Returns a value in `[-1.0, 1.0]`. The result is `0.0` when the vectors
/// differ in length, are empty, either has zero magnitude, or the inputs
/// contain non-finite values that make the similarity undefined.
///
/// Sums are accumulated in `f64`, so squaring very large or very small `f32`
/// components cannot overflow to infinity or underflow to zero. The result is
/// therefore never NaN, which keeps score-based sorting well defined.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    // Single pass: dot product and both squared norms.
    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, norm_a_sq, norm_b_sq), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, norm_a_sq + x * x, norm_b_sq + y * y)
        },
    );

    if norm_a_sq == 0.0 || norm_b_sq == 0.0 {
        return 0.0;
    }

    let cosine = dot / (norm_a_sq.sqrt() * norm_b_sq.sqrt());
    if cosine.is_nan() {
        // Only reachable with NaN or infinite components in the input.
        return 0.0;
    }

    // Rounding can nudge the ratio just outside [-1, 1]; after clamping, the
    // narrowing cast to f32 cannot lose range.
    cosine.clamp(-1.0, 1.0) as f32
}