Skip to main content

synaptic_mongodb/
vector_store.rs

1use std::collections::HashMap;
2
3use async_trait::async_trait;
4use bson::{doc, Bson, Document as BsonDocument};
5use futures::TryStreamExt;
6use mongodb::Client;
7use serde_json::Value;
8use synaptic_core::{Document, Embeddings, SynapticError, VectorStore};
9
10// ---------------------------------------------------------------------------
11// MongoVectorConfig
12// ---------------------------------------------------------------------------
13
/// Configuration for a [`MongoVectorStore`].
///
/// Create one with [`MongoVectorConfig::new`], which fills in the defaults
/// noted on each field, then adjust individual settings with the `with_*`
/// builder methods.
#[derive(Debug, Clone)]
pub struct MongoVectorConfig {
    /// MongoDB database name.
    pub database: String,
    /// MongoDB collection name.
    pub collection: String,
    /// Name of the Atlas Vector Search index (default: `vector_index`).
    pub index_name: String,
    /// Field name storing the embedding vector (default: `embedding`).
    ///
    /// This is also used as the `path` of the `$vectorSearch` stage.
    pub vector_field: String,
    /// Field name storing the document content (default: `content`).
    pub content_field: String,
    /// Number of candidates for `$vectorSearch` (default: `10 * k`).
    ///
    /// `None` means "derive it from `k` at query time".
    pub num_candidates: Option<i64>,
}
30
31impl MongoVectorConfig {
32    /// Create a new config with the required database and collection names.
33    pub fn new(database: impl Into<String>, collection: impl Into<String>) -> Self {
34        Self {
35            database: database.into(),
36            collection: collection.into(),
37            index_name: "vector_index".to_string(),
38            vector_field: "embedding".to_string(),
39            content_field: "content".to_string(),
40            num_candidates: None,
41        }
42    }
43
44    /// Set the vector search index name.
45    pub fn with_index_name(mut self, index_name: impl Into<String>) -> Self {
46        self.index_name = index_name.into();
47        self
48    }
49
50    /// Set the field name for storing embedding vectors.
51    pub fn with_vector_field(mut self, vector_field: impl Into<String>) -> Self {
52        self.vector_field = vector_field.into();
53        self
54    }
55
56    /// Set the field name for storing document content.
57    pub fn with_content_field(mut self, content_field: impl Into<String>) -> Self {
58        self.content_field = content_field.into();
59        self
60    }
61
62    /// Set the number of candidates for `$vectorSearch`.
63    ///
64    /// If not set, defaults to `10 * k` at query time.
65    pub fn with_num_candidates(mut self, num_candidates: i64) -> Self {
66        self.num_candidates = Some(num_candidates);
67        self
68    }
69}
70
71// ---------------------------------------------------------------------------
72// MongoVectorStore
73// ---------------------------------------------------------------------------
74
/// A [`VectorStore`] implementation backed by MongoDB Atlas Vector Search.
///
/// Documents are stored in a MongoDB collection with fields:
/// - `_id`: the document ID, stored as a string (an ObjectId hex string is
///   generated when a document arrives with an empty ID)
/// - `content`: the document text (the field name is taken from
///   [`MongoVectorConfig::content_field`])
/// - `embedding`: the vector embedding, an array of doubles (the field name
///   is taken from [`MongoVectorConfig::vector_field`])
/// - `metadata`: an embedded document with arbitrary metadata
///
/// Similarity search uses the `$vectorSearch` aggregation stage, which requires
/// a pre-configured Atlas Vector Search index on the collection.
pub struct MongoVectorStore {
    /// Settings naming the database, collection, index and document fields.
    config: MongoVectorConfig,
    /// Client the store was built from; exposed through `client()`.
    client: Client,
    /// Handle to the configured collection, resolved once at construction.
    collection: mongodb::Collection<BsonDocument>,
}
90
91impl MongoVectorStore {
92    /// Create a new store by connecting to MongoDB at the given URI.
93    pub async fn from_uri(uri: &str, config: MongoVectorConfig) -> Result<Self, SynapticError> {
94        let client = Client::with_uri_str(uri).await.map_err(|e| {
95            SynapticError::VectorStore(format!("failed to connect to MongoDB: {e}"))
96        })?;
97
98        Ok(Self::from_client(client, config))
99    }
100
101    /// Create a new store from an existing MongoDB client.
102    pub fn from_client(client: Client, config: MongoVectorConfig) -> Self {
103        let db = client.database(&config.database);
104        let collection = db.collection::<BsonDocument>(&config.collection);
105        Self {
106            config,
107            client,
108            collection,
109        }
110    }
111
112    /// Return a reference to the underlying MongoDB client.
113    pub fn client(&self) -> &Client {
114        &self.client
115    }
116
117    /// Return a reference to the configuration.
118    pub fn config(&self) -> &MongoVectorConfig {
119        &self.config
120    }
121
122    /// Return a reference to the underlying MongoDB collection.
123    pub fn collection(&self) -> &mongodb::Collection<BsonDocument> {
124        &self.collection
125    }
126
127    /// Compute the number of candidates to use in `$vectorSearch`.
128    fn num_candidates(&self, k: usize) -> i64 {
129        self.config
130            .num_candidates
131            .unwrap_or_else(|| (k as i64) * 10)
132    }
133}
134
135// ---------------------------------------------------------------------------
136// VectorStore implementation
137// ---------------------------------------------------------------------------
138
139#[async_trait]
140impl VectorStore for MongoVectorStore {
141    async fn add_documents(
142        &self,
143        docs: Vec<Document>,
144        embeddings: &dyn Embeddings,
145    ) -> Result<Vec<String>, SynapticError> {
146        if docs.is_empty() {
147            return Ok(Vec::new());
148        }
149
150        // Compute embeddings for all documents.
151        let texts: Vec<&str> = docs.iter().map(|d| d.content.as_str()).collect();
152        let vectors = embeddings.embed_documents(&texts).await?;
153
154        let mut ids = Vec::with_capacity(docs.len());
155        let mut bson_docs = Vec::with_capacity(docs.len());
156
157        for (doc, vector) in docs.into_iter().zip(vectors) {
158            let id = if doc.id.is_empty() {
159                bson::oid::ObjectId::new().to_hex()
160            } else {
161                doc.id.clone()
162            };
163
164            // Convert the embedding vector to BSON array of doubles.
165            let bson_vector: Vec<Bson> =
166                vector.into_iter().map(|v| Bson::Double(v as f64)).collect();
167
168            // Convert metadata to BSON document.
169            let metadata_bson = json_map_to_bson(&doc.metadata);
170
171            let bson_doc = doc! {
172                "_id": &id,
173                &self.config.content_field: &doc.content,
174                &self.config.vector_field: bson_vector,
175                "metadata": metadata_bson,
176            };
177
178            ids.push(id);
179            bson_docs.push(bson_doc);
180        }
181
182        self.collection
183            .insert_many(bson_docs)
184            .await
185            .map_err(|e| SynapticError::VectorStore(format!("MongoDB insert failed: {e}")))?;
186
187        Ok(ids)
188    }
189
190    async fn similarity_search(
191        &self,
192        query: &str,
193        k: usize,
194        embeddings: &dyn Embeddings,
195    ) -> Result<Vec<Document>, SynapticError> {
196        let results = self
197            .similarity_search_with_score(query, k, embeddings)
198            .await?;
199        Ok(results.into_iter().map(|(doc, _)| doc).collect())
200    }
201
202    async fn similarity_search_with_score(
203        &self,
204        query: &str,
205        k: usize,
206        embeddings: &dyn Embeddings,
207    ) -> Result<Vec<(Document, f32)>, SynapticError> {
208        let query_vec = embeddings.embed_query(query).await?;
209        self.similarity_search_by_vector_with_score(&query_vec, k)
210            .await
211    }
212
213    async fn similarity_search_by_vector(
214        &self,
215        embedding: &[f32],
216        k: usize,
217    ) -> Result<Vec<Document>, SynapticError> {
218        let results = self
219            .similarity_search_by_vector_with_score(embedding, k)
220            .await?;
221        Ok(results.into_iter().map(|(doc, _)| doc).collect())
222    }
223
224    async fn delete(&self, ids: &[&str]) -> Result<(), SynapticError> {
225        if ids.is_empty() {
226            return Ok(());
227        }
228
229        let id_values: Vec<Bson> = ids.iter().map(|id| Bson::String(id.to_string())).collect();
230
231        self.collection
232            .delete_many(doc! { "_id": { "$in": id_values } })
233            .await
234            .map_err(|e| SynapticError::VectorStore(format!("MongoDB delete failed: {e}")))?;
235
236        Ok(())
237    }
238}
239
240impl MongoVectorStore {
241    /// Search by vector and return documents with their similarity scores.
242    ///
243    /// Uses the `$vectorSearch` aggregation pipeline stage available in
244    /// MongoDB Atlas.
245    async fn similarity_search_by_vector_with_score(
246        &self,
247        embedding: &[f32],
248        k: usize,
249    ) -> Result<Vec<(Document, f32)>, SynapticError> {
250        let num_candidates = self.num_candidates(k);
251
252        // Convert embedding to BSON array.
253        let query_vector: Vec<Bson> = embedding.iter().map(|v| Bson::Double(*v as f64)).collect();
254
255        // Build the $vectorSearch stage.
256        let vector_search_stage = doc! {
257            "$vectorSearch": {
258                "index": &self.config.index_name,
259                "path": &self.config.vector_field,
260                "queryVector": query_vector,
261                "numCandidates": num_candidates,
262                "limit": k as i64,
263            }
264        };
265
266        // Build the $project stage to include the score.
267        let project_stage = doc! {
268            "$project": {
269                "_id": 1,
270                &self.config.content_field: 1,
271                "metadata": 1,
272                "score": { "$meta": "vectorSearchScore" },
273            }
274        };
275
276        let pipeline = vec![vector_search_stage, project_stage];
277
278        let mut cursor =
279            self.collection.aggregate(pipeline).await.map_err(|e| {
280                SynapticError::VectorStore(format!("MongoDB aggregation failed: {e}"))
281            })?;
282
283        let mut results = Vec::new();
284
285        while let Some(bson_doc) = cursor
286            .try_next()
287            .await
288            .map_err(|e| SynapticError::VectorStore(format!("MongoDB cursor error: {e}")))?
289        {
290            let id = bson_doc.get_str("_id").unwrap_or("").to_string();
291
292            let content = bson_doc
293                .get_str(&self.config.content_field)
294                .unwrap_or("")
295                .to_string();
296
297            let score = bson_doc.get_f64("score").unwrap_or(0.0) as f32;
298
299            let metadata = bson_doc
300                .get_document("metadata")
301                .ok()
302                .map(bson_doc_to_json_map)
303                .unwrap_or_default();
304
305            let doc = Document::with_metadata(id, content, metadata);
306            results.push((doc, score));
307        }
308
309        Ok(results)
310    }
311}
312
313// ---------------------------------------------------------------------------
314// Conversion helpers
315// ---------------------------------------------------------------------------
316
317/// Convert a JSON metadata map to a BSON document.
318fn json_map_to_bson(map: &HashMap<String, Value>) -> BsonDocument {
319    let mut doc = BsonDocument::new();
320    for (k, v) in map {
321        doc.insert(k.clone(), json_to_bson(v));
322    }
323    doc
324}
325
326/// Convert a `serde_json::Value` to a `bson::Bson` value.
327fn json_to_bson(value: &Value) -> Bson {
328    match value {
329        Value::Null => Bson::Null,
330        Value::Bool(b) => Bson::Boolean(*b),
331        Value::Number(n) => {
332            if let Some(i) = n.as_i64() {
333                Bson::Int64(i)
334            } else if let Some(f) = n.as_f64() {
335                Bson::Double(f)
336            } else {
337                Bson::Null
338            }
339        }
340        Value::String(s) => Bson::String(s.clone()),
341        Value::Array(arr) => Bson::Array(arr.iter().map(json_to_bson).collect()),
342        Value::Object(map) => {
343            let mut doc = BsonDocument::new();
344            for (k, v) in map {
345                doc.insert(k.clone(), json_to_bson(v));
346            }
347            Bson::Document(doc)
348        }
349    }
350}
351
352/// Convert a BSON document to a JSON metadata map.
353fn bson_doc_to_json_map(doc: &BsonDocument) -> HashMap<String, Value> {
354    let mut map = HashMap::new();
355    for (k, v) in doc {
356        map.insert(k.clone(), bson_to_json(v));
357    }
358    map
359}
360
361/// Convert a `bson::Bson` value to a `serde_json::Value`.
362fn bson_to_json(bson: &Bson) -> Value {
363    match bson {
364        Bson::Null => Value::Null,
365        Bson::Boolean(b) => Value::Bool(*b),
366        Bson::Int32(i) => Value::Number((*i as i64).into()),
367        Bson::Int64(i) => Value::Number((*i).into()),
368        Bson::Double(f) => serde_json::Number::from_f64(*f)
369            .map(Value::Number)
370            .unwrap_or(Value::Null),
371        Bson::String(s) => Value::String(s.clone()),
372        Bson::Array(arr) => Value::Array(arr.iter().map(bson_to_json).collect()),
373        Bson::Document(doc) => {
374            let map: serde_json::Map<String, Value> = doc
375                .iter()
376                .map(|(k, v)| (k.clone(), bson_to_json(v)))
377                .collect();
378            Value::Object(map)
379        }
380        Bson::ObjectId(oid) => Value::String(oid.to_hex()),
381        Bson::DateTime(dt) => Value::String(dt.to_string()),
382        Bson::Binary(bin) => Value::String(format!("<binary {} bytes>", bin.bytes.len())),
383        _ => Value::String(format!("{bson}")),
384    }
385}
386
387// ---------------------------------------------------------------------------
388// Tests
389// ---------------------------------------------------------------------------
390
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for the minimal configuration most tests start from.
    fn base_config() -> MongoVectorConfig {
        MongoVectorConfig::new("db", "col")
    }

    /// Convert `json` to BSON and back, asserting the value is unchanged.
    fn assert_roundtrip(json: Value) {
        let restored = bson_to_json(&json_to_bson(&json));
        assert_eq!(json, restored);
    }

    /// Resolve the candidate count from a config alone.
    ///
    /// We cannot call `num_candidates()` without a `MongoVectorStore`, so the
    /// configured-value-or-`10 * k` logic is exercised directly.
    fn resolve_candidates(config: &MongoVectorConfig, k: usize) -> i64 {
        config.num_candidates.unwrap_or_else(|| (k as i64) * 10)
    }

    #[test]
    fn config_new_sets_defaults() {
        let MongoVectorConfig {
            database,
            collection,
            index_name,
            vector_field,
            content_field,
            num_candidates,
        } = MongoVectorConfig::new("my_db", "my_collection");

        assert_eq!(database, "my_db");
        assert_eq!(collection, "my_collection");
        assert_eq!(index_name, "vector_index");
        assert_eq!(vector_field, "embedding");
        assert_eq!(content_field, "content");
        assert!(num_candidates.is_none());
    }

    #[test]
    fn config_with_index_name() {
        let updated = base_config().with_index_name("custom_index");
        assert_eq!(updated.index_name, "custom_index");
    }

    #[test]
    fn config_with_vector_field() {
        let updated = base_config().with_vector_field("vec");
        assert_eq!(updated.vector_field, "vec");
    }

    #[test]
    fn config_with_content_field() {
        let updated = base_config().with_content_field("text");
        assert_eq!(updated.content_field, "text");
    }

    #[test]
    fn config_with_num_candidates() {
        let updated = base_config().with_num_candidates(200);
        assert_eq!(updated.num_candidates, Some(200));
    }

    #[test]
    fn config_builder_chain() {
        let MongoVectorConfig {
            database,
            collection,
            index_name,
            vector_field,
            content_field,
            num_candidates,
        } = MongoVectorConfig::new("test_db", "embeddings")
            .with_index_name("my_vs_index")
            .with_vector_field("vec_field")
            .with_content_field("text_field")
            .with_num_candidates(500);

        assert_eq!(database, "test_db");
        assert_eq!(collection, "embeddings");
        assert_eq!(index_name, "my_vs_index");
        assert_eq!(vector_field, "vec_field");
        assert_eq!(content_field, "text_field");
        assert_eq!(num_candidates, Some(500));
    }

    #[test]
    fn json_to_bson_roundtrip_string() {
        assert_roundtrip(Value::String("hello".into()));
    }

    #[test]
    fn json_to_bson_roundtrip_number_int() {
        assert_roundtrip(serde_json::json!(42));
    }

    #[test]
    fn json_to_bson_roundtrip_number_float() {
        assert_roundtrip(serde_json::json!(3.14));
    }

    #[test]
    fn json_to_bson_roundtrip_bool() {
        assert_roundtrip(Value::Bool(true));
    }

    #[test]
    fn json_to_bson_roundtrip_null() {
        assert_roundtrip(Value::Null);
    }

    #[test]
    fn json_to_bson_roundtrip_array() {
        assert_roundtrip(serde_json::json!([1, "two", true]));
    }

    #[test]
    fn json_to_bson_roundtrip_object() {
        assert_roundtrip(serde_json::json!({"key": "value", "num": 42}));
    }

    #[test]
    fn json_map_to_bson_and_back() {
        let original: HashMap<String, Value> = [
            ("source".to_string(), Value::String("test".into())),
            ("page".to_string(), serde_json::json!(42)),
        ]
        .into_iter()
        .collect();

        let restored = bson_doc_to_json_map(&json_map_to_bson(&original));

        assert_eq!(original, restored);
    }

    #[test]
    fn num_candidates_default() {
        let config = base_config();
        assert_eq!(resolve_candidates(&config, 10), 100);
    }

    #[test]
    fn num_candidates_custom() {
        let config = base_config().with_num_candidates(200);
        assert_eq!(resolve_candidates(&config, 10), 200);
    }
}
531}