Skip to main content

synaptic_vectorstores/
in_memory.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use synaptic_core::{Document, Embeddings, Retriever, SynapticError};
6use tokio::sync::RwLock;
7
8use crate::VectorStore;
9
10/// Stored document with its embedding vector.
11struct StoredEntry {
12    document: Document,
13    embedding: Vec<f32>,
14}
15
16/// In-memory vector store using cosine similarity.
17pub struct InMemoryVectorStore {
18    entries: RwLock<HashMap<String, StoredEntry>>,
19}
20
21impl InMemoryVectorStore {
22    pub fn new() -> Self {
23        Self {
24            entries: RwLock::new(HashMap::new()),
25        }
26    }
27
28    /// Create a new store pre-populated with texts.
29    pub async fn from_texts(
30        texts: Vec<(&str, &str)>,
31        embeddings: &dyn Embeddings,
32    ) -> Result<Self, SynapticError> {
33        let store = Self::new();
34        let docs = texts
35            .into_iter()
36            .map(|(id, content)| Document::new(id, content))
37            .collect();
38        store.add_documents(docs, embeddings).await?;
39        Ok(store)
40    }
41
42    /// Create a new store pre-populated with documents.
43    pub async fn from_documents(
44        documents: Vec<Document>,
45        embeddings: &dyn Embeddings,
46    ) -> Result<Self, SynapticError> {
47        let store = Self::new();
48        store.add_documents(documents, embeddings).await?;
49        Ok(store)
50    }
51
52    /// Maximum Marginal Relevance search for diverse results.
53    ///
54    /// `lambda_mult` controls the trade-off between relevance and diversity:
55    /// - 1.0 = pure relevance (equivalent to standard similarity search)
56    /// - 0.0 = maximum diversity
57    /// - 0.5 = balanced (typical default)
58    ///
59    /// `fetch_k` is the number of initial candidates to fetch before MMR filtering.
60    pub async fn max_marginal_relevance_search(
61        &self,
62        query: &str,
63        k: usize,
64        fetch_k: usize,
65        lambda_mult: f32,
66        embeddings: &dyn Embeddings,
67    ) -> Result<Vec<Document>, SynapticError> {
68        let query_vec = embeddings.embed_query(query).await?;
69        let entries = self.entries.read().await;
70
71        // Score all candidates against the query
72        let mut candidates: Vec<(String, Document, Vec<f32>, f32)> = entries
73            .values()
74            .map(|entry| {
75                let score = cosine_similarity(&query_vec, &entry.embedding);
76                (
77                    entry.document.id.clone(),
78                    entry.document.clone(),
79                    entry.embedding.clone(),
80                    score,
81                )
82            })
83            .collect();
84
85        // Sort by query similarity descending and take top fetch_k
86        candidates.sort_by(|a, b| b.3.partial_cmp(&a.3).unwrap_or(std::cmp::Ordering::Equal));
87        candidates.truncate(fetch_k);
88
89        if candidates.is_empty() || k == 0 {
90            return Ok(Vec::new());
91        }
92
93        // Greedy MMR selection
94        let mut selected: Vec<(Document, Vec<f32>)> = Vec::with_capacity(k);
95        let mut remaining = candidates;
96
97        while selected.len() < k && !remaining.is_empty() {
98            let mut best_idx = 0;
99            let mut best_score = f32::NEG_INFINITY;
100
101            for (i, (_id, _doc, emb, query_sim)) in remaining.iter().enumerate() {
102                // Compute max similarity to already-selected documents
103                let max_sim_to_selected = if selected.is_empty() {
104                    0.0
105                } else {
106                    selected
107                        .iter()
108                        .map(|(_, sel_emb)| cosine_similarity(sel_emb, emb))
109                        .fold(f32::NEG_INFINITY, f32::max)
110                };
111
112                let mmr_score = lambda_mult * query_sim - (1.0 - lambda_mult) * max_sim_to_selected;
113
114                if mmr_score > best_score {
115                    best_score = mmr_score;
116                    best_idx = i;
117                }
118            }
119
120            let (_id, doc, emb, _query_sim) = remaining.remove(best_idx);
121            selected.push((doc, emb));
122        }
123
124        Ok(selected.into_iter().map(|(doc, _)| doc).collect())
125    }
126}
127
128impl Default for InMemoryVectorStore {
129    fn default() -> Self {
130        Self::new()
131    }
132}
133
134#[async_trait]
135impl VectorStore for InMemoryVectorStore {
136    async fn add_documents(
137        &self,
138        docs: Vec<Document>,
139        embeddings: &dyn Embeddings,
140    ) -> Result<Vec<String>, SynapticError> {
141        let texts: Vec<&str> = docs.iter().map(|d| d.content.as_str()).collect();
142        let vectors = embeddings.embed_documents(&texts).await?;
143
144        let mut entries = self.entries.write().await;
145        let mut ids = Vec::with_capacity(docs.len());
146
147        for (doc, embedding) in docs.into_iter().zip(vectors) {
148            ids.push(doc.id.clone());
149            entries.insert(
150                doc.id.clone(),
151                StoredEntry {
152                    document: doc,
153                    embedding,
154                },
155            );
156        }
157
158        Ok(ids)
159    }
160
161    async fn similarity_search(
162        &self,
163        query: &str,
164        k: usize,
165        embeddings: &dyn Embeddings,
166    ) -> Result<Vec<Document>, SynapticError> {
167        let results = self
168            .similarity_search_with_score(query, k, embeddings)
169            .await?;
170        Ok(results.into_iter().map(|(doc, _)| doc).collect())
171    }
172
173    async fn similarity_search_with_score(
174        &self,
175        query: &str,
176        k: usize,
177        embeddings: &dyn Embeddings,
178    ) -> Result<Vec<(Document, f32)>, SynapticError> {
179        let query_vec = embeddings.embed_query(query).await?;
180        let entries = self.entries.read().await;
181
182        let mut scored: Vec<(Document, f32)> = entries
183            .values()
184            .map(|entry| {
185                let score = cosine_similarity(&query_vec, &entry.embedding);
186                (entry.document.clone(), score)
187            })
188            .collect();
189
190        // Sort by score descending
191        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
192        scored.truncate(k);
193
194        Ok(scored)
195    }
196
197    async fn similarity_search_by_vector(
198        &self,
199        embedding: &[f32],
200        k: usize,
201    ) -> Result<Vec<Document>, SynapticError> {
202        let entries = self.entries.read().await;
203
204        let mut scored: Vec<(Document, f32)> = entries
205            .values()
206            .map(|entry| {
207                let score = cosine_similarity(embedding, &entry.embedding);
208                (entry.document.clone(), score)
209            })
210            .collect();
211
212        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
213        scored.truncate(k);
214
215        Ok(scored.into_iter().map(|(doc, _)| doc).collect())
216    }
217
218    async fn delete(&self, ids: &[&str]) -> Result<(), SynapticError> {
219        let mut entries = self.entries.write().await;
220        for id in ids {
221            entries.remove(*id);
222        }
223        Ok(())
224    }
225}
226
227/// A retriever that wraps a VectorStore, bridging it to the `Retriever` trait.
228pub struct VectorStoreRetriever<S: VectorStore> {
229    store: Arc<S>,
230    embeddings: Arc<dyn Embeddings>,
231    k: usize,
232    score_threshold: Option<f32>,
233}
234
235impl<S: VectorStore + 'static> VectorStoreRetriever<S> {
236    pub fn new(store: Arc<S>, embeddings: Arc<dyn Embeddings>, k: usize) -> Self {
237        Self {
238            store,
239            embeddings,
240            k,
241            score_threshold: None,
242        }
243    }
244
245    /// Set a minimum similarity score threshold. Only documents with a score
246    /// greater than or equal to the threshold will be returned.
247    pub fn with_score_threshold(mut self, threshold: f32) -> Self {
248        self.score_threshold = Some(threshold);
249        self
250    }
251}
252
253#[async_trait]
254impl<S: VectorStore + 'static> Retriever for VectorStoreRetriever<S> {
255    async fn retrieve(&self, query: &str, top_k: usize) -> Result<Vec<Document>, SynapticError> {
256        let k = if top_k > 0 { top_k } else { self.k };
257
258        if let Some(threshold) = self.score_threshold {
259            let scored = self
260                .store
261                .similarity_search_with_score(query, k, self.embeddings.as_ref())
262                .await?;
263            Ok(scored
264                .into_iter()
265                .filter(|(_, score)| *score >= threshold)
266                .map(|(doc, _)| doc)
267                .collect())
268        } else {
269            self.store
270                .similarity_search(query, k, self.embeddings.as_ref())
271                .await
272        }
273    }
274}
275
/// Compute the cosine similarity between two vectors.
///
/// Returns `dot(a, b) / (|a| * |b|)`, clamped to `[-1.0, 1.0]`. The clamp
/// absorbs floating-point rounding that can push, for example, the similarity
/// of a vector with itself just past `1.0`.
///
/// Degenerate inputs yield `0.0` ("no similarity") rather than `NaN`:
/// - the slices differ in length, or are empty;
/// - either vector has zero magnitude, or the magnitude product underflows to 0;
/// - the magnitudes are not finite (a component is `NaN` or infinite, or the
///   product overflows).
///
/// Never returning `NaN` matters because callers sort scores with
/// `partial_cmp(..).unwrap_or(Equal)`. That comparator is not a total order
/// once `NaN` is involved, so the resulting order would be unspecified and
/// the standard-library sort may panic.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let mag_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let mag_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();

    let denom = mag_a * mag_b;
    // Covers zero vectors, underflow to 0 (which would yield ±inf or NaN),
    // and NaN/infinite magnitudes (which would yield NaN).
    if denom == 0.0 || !denom.is_finite() {
        return 0.0;
    }

    let similarity = dot / denom;
    if similarity.is_nan() {
        0.0
    } else {
        similarity.clamp(-1.0, 1.0)
    }
}