1use std::collections::HashMap;
2
3use async_trait::async_trait;
4use bson::{doc, Bson, Document as BsonDocument};
5use futures::TryStreamExt;
6use mongodb::Client;
7use serde_json::Value;
8use synaptic_core::{Document, Embeddings, SynapticError, VectorStore};
9
/// Settings that tell [`MongoVectorStore`] where documents live and how to query them.
#[derive(Debug, Clone)]
pub struct MongoVectorConfig {
    /// Name of the MongoDB database that holds the collection.
    pub database: String,
    /// Name of the collection documents are inserted into and searched in.
    pub collection: String,
    /// Name of the vector search index, sent as `index` in the `$vectorSearch` stage.
    pub index_name: String,
    /// Document field that stores the embedding; also sent as `path` in `$vectorSearch`.
    pub vector_field: String,
    /// Document field that stores the text content of each document.
    pub content_field: String,
    /// Optional explicit `numCandidates` for `$vectorSearch`.
    ///
    /// When `None`, the store derives a value from the number of results requested.
    pub num_candidates: Option<i64>,
}
30
31impl MongoVectorConfig {
32 pub fn new(database: impl Into<String>, collection: impl Into<String>) -> Self {
34 Self {
35 database: database.into(),
36 collection: collection.into(),
37 index_name: "vector_index".to_string(),
38 vector_field: "embedding".to_string(),
39 content_field: "content".to_string(),
40 num_candidates: None,
41 }
42 }
43
44 pub fn with_index_name(mut self, index_name: impl Into<String>) -> Self {
46 self.index_name = index_name.into();
47 self
48 }
49
50 pub fn with_vector_field(mut self, vector_field: impl Into<String>) -> Self {
52 self.vector_field = vector_field.into();
53 self
54 }
55
56 pub fn with_content_field(mut self, content_field: impl Into<String>) -> Self {
58 self.content_field = content_field.into();
59 self
60 }
61
62 pub fn with_num_candidates(mut self, num_candidates: i64) -> Self {
66 self.num_candidates = Some(num_candidates);
67 self
68 }
69}
70
/// Vector store that keeps documents and their embeddings in a MongoDB collection.
pub struct MongoVectorStore {
    /// Field names, index name and candidate settings used for inserts and searches.
    config: MongoVectorConfig,
    /// MongoDB client the collection handle was obtained from.
    client: Client,
    /// Handle to the configured collection, resolved once at construction time.
    collection: mongodb::Collection<BsonDocument>,
}
90
91impl MongoVectorStore {
92 pub async fn from_uri(uri: &str, config: MongoVectorConfig) -> Result<Self, SynapticError> {
94 let client = Client::with_uri_str(uri).await.map_err(|e| {
95 SynapticError::VectorStore(format!("failed to connect to MongoDB: {e}"))
96 })?;
97
98 Ok(Self::from_client(client, config))
99 }
100
101 pub fn from_client(client: Client, config: MongoVectorConfig) -> Self {
103 let db = client.database(&config.database);
104 let collection = db.collection::<BsonDocument>(&config.collection);
105 Self {
106 config,
107 client,
108 collection,
109 }
110 }
111
112 pub fn client(&self) -> &Client {
114 &self.client
115 }
116
117 pub fn config(&self) -> &MongoVectorConfig {
119 &self.config
120 }
121
122 pub fn collection(&self) -> &mongodb::Collection<BsonDocument> {
124 &self.collection
125 }
126
127 fn num_candidates(&self, k: usize) -> i64 {
129 self.config
130 .num_candidates
131 .unwrap_or_else(|| (k as i64) * 10)
132 }
133}
134
135#[async_trait]
140impl VectorStore for MongoVectorStore {
141 async fn add_documents(
142 &self,
143 docs: Vec<Document>,
144 embeddings: &dyn Embeddings,
145 ) -> Result<Vec<String>, SynapticError> {
146 if docs.is_empty() {
147 return Ok(Vec::new());
148 }
149
150 let texts: Vec<&str> = docs.iter().map(|d| d.content.as_str()).collect();
152 let vectors = embeddings.embed_documents(&texts).await?;
153
154 let mut ids = Vec::with_capacity(docs.len());
155 let mut bson_docs = Vec::with_capacity(docs.len());
156
157 for (doc, vector) in docs.into_iter().zip(vectors) {
158 let id = if doc.id.is_empty() {
159 bson::oid::ObjectId::new().to_hex()
160 } else {
161 doc.id.clone()
162 };
163
164 let bson_vector: Vec<Bson> =
166 vector.into_iter().map(|v| Bson::Double(v as f64)).collect();
167
168 let metadata_bson = json_map_to_bson(&doc.metadata);
170
171 let bson_doc = doc! {
172 "_id": &id,
173 &self.config.content_field: &doc.content,
174 &self.config.vector_field: bson_vector,
175 "metadata": metadata_bson,
176 };
177
178 ids.push(id);
179 bson_docs.push(bson_doc);
180 }
181
182 self.collection
183 .insert_many(bson_docs)
184 .await
185 .map_err(|e| SynapticError::VectorStore(format!("MongoDB insert failed: {e}")))?;
186
187 Ok(ids)
188 }
189
190 async fn similarity_search(
191 &self,
192 query: &str,
193 k: usize,
194 embeddings: &dyn Embeddings,
195 ) -> Result<Vec<Document>, SynapticError> {
196 let results = self
197 .similarity_search_with_score(query, k, embeddings)
198 .await?;
199 Ok(results.into_iter().map(|(doc, _)| doc).collect())
200 }
201
202 async fn similarity_search_with_score(
203 &self,
204 query: &str,
205 k: usize,
206 embeddings: &dyn Embeddings,
207 ) -> Result<Vec<(Document, f32)>, SynapticError> {
208 let query_vec = embeddings.embed_query(query).await?;
209 self.similarity_search_by_vector_with_score(&query_vec, k)
210 .await
211 }
212
213 async fn similarity_search_by_vector(
214 &self,
215 embedding: &[f32],
216 k: usize,
217 ) -> Result<Vec<Document>, SynapticError> {
218 let results = self
219 .similarity_search_by_vector_with_score(embedding, k)
220 .await?;
221 Ok(results.into_iter().map(|(doc, _)| doc).collect())
222 }
223
224 async fn delete(&self, ids: &[&str]) -> Result<(), SynapticError> {
225 if ids.is_empty() {
226 return Ok(());
227 }
228
229 let id_values: Vec<Bson> = ids.iter().map(|id| Bson::String(id.to_string())).collect();
230
231 self.collection
232 .delete_many(doc! { "_id": { "$in": id_values } })
233 .await
234 .map_err(|e| SynapticError::VectorStore(format!("MongoDB delete failed: {e}")))?;
235
236 Ok(())
237 }
238}
239
240impl MongoVectorStore {
241 async fn similarity_search_by_vector_with_score(
246 &self,
247 embedding: &[f32],
248 k: usize,
249 ) -> Result<Vec<(Document, f32)>, SynapticError> {
250 let num_candidates = self.num_candidates(k);
251
252 let query_vector: Vec<Bson> = embedding.iter().map(|v| Bson::Double(*v as f64)).collect();
254
255 let vector_search_stage = doc! {
257 "$vectorSearch": {
258 "index": &self.config.index_name,
259 "path": &self.config.vector_field,
260 "queryVector": query_vector,
261 "numCandidates": num_candidates,
262 "limit": k as i64,
263 }
264 };
265
266 let project_stage = doc! {
268 "$project": {
269 "_id": 1,
270 &self.config.content_field: 1,
271 "metadata": 1,
272 "score": { "$meta": "vectorSearchScore" },
273 }
274 };
275
276 let pipeline = vec![vector_search_stage, project_stage];
277
278 let mut cursor =
279 self.collection.aggregate(pipeline).await.map_err(|e| {
280 SynapticError::VectorStore(format!("MongoDB aggregation failed: {e}"))
281 })?;
282
283 let mut results = Vec::new();
284
285 while let Some(bson_doc) = cursor
286 .try_next()
287 .await
288 .map_err(|e| SynapticError::VectorStore(format!("MongoDB cursor error: {e}")))?
289 {
290 let id = bson_doc.get_str("_id").unwrap_or("").to_string();
291
292 let content = bson_doc
293 .get_str(&self.config.content_field)
294 .unwrap_or("")
295 .to_string();
296
297 let score = bson_doc.get_f64("score").unwrap_or(0.0) as f32;
298
299 let metadata = bson_doc
300 .get_document("metadata")
301 .ok()
302 .map(bson_doc_to_json_map)
303 .unwrap_or_default();
304
305 let doc = Document::with_metadata(id, content, metadata);
306 results.push((doc, score));
307 }
308
309 Ok(results)
310 }
311}
312
313fn json_map_to_bson(map: &HashMap<String, Value>) -> BsonDocument {
319 let mut doc = BsonDocument::new();
320 for (k, v) in map {
321 doc.insert(k.clone(), json_to_bson(v));
322 }
323 doc
324}
325
326fn json_to_bson(value: &Value) -> Bson {
328 match value {
329 Value::Null => Bson::Null,
330 Value::Bool(b) => Bson::Boolean(*b),
331 Value::Number(n) => {
332 if let Some(i) = n.as_i64() {
333 Bson::Int64(i)
334 } else if let Some(f) = n.as_f64() {
335 Bson::Double(f)
336 } else {
337 Bson::Null
338 }
339 }
340 Value::String(s) => Bson::String(s.clone()),
341 Value::Array(arr) => Bson::Array(arr.iter().map(json_to_bson).collect()),
342 Value::Object(map) => {
343 let mut doc = BsonDocument::new();
344 for (k, v) in map {
345 doc.insert(k.clone(), json_to_bson(v));
346 }
347 Bson::Document(doc)
348 }
349 }
350}
351
352fn bson_doc_to_json_map(doc: &BsonDocument) -> HashMap<String, Value> {
354 let mut map = HashMap::new();
355 for (k, v) in doc {
356 map.insert(k.clone(), bson_to_json(v));
357 }
358 map
359}
360
361fn bson_to_json(bson: &Bson) -> Value {
363 match bson {
364 Bson::Null => Value::Null,
365 Bson::Boolean(b) => Value::Bool(*b),
366 Bson::Int32(i) => Value::Number((*i as i64).into()),
367 Bson::Int64(i) => Value::Number((*i).into()),
368 Bson::Double(f) => serde_json::Number::from_f64(*f)
369 .map(Value::Number)
370 .unwrap_or(Value::Null),
371 Bson::String(s) => Value::String(s.clone()),
372 Bson::Array(arr) => Value::Array(arr.iter().map(bson_to_json).collect()),
373 Bson::Document(doc) => {
374 let map: serde_json::Map<String, Value> = doc
375 .iter()
376 .map(|(k, v)| (k.clone(), bson_to_json(v)))
377 .collect();
378 Value::Object(map)
379 }
380 Bson::ObjectId(oid) => Value::String(oid.to_hex()),
381 Bson::DateTime(dt) => Value::String(dt.to_string()),
382 Bson::Binary(bin) => Value::String(format!("<binary {} bytes>", bin.bytes.len())),
383 _ => Value::String(format!("{bson}")),
384 }
385}
386
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts `json` to BSON and straight back to JSON.
    fn roundtrip(json: &Value) -> Value {
        bson_to_json(&json_to_bson(json))
    }

    /// Resolves a candidate count: an explicit override wins, otherwise ten per
    /// requested result.
    fn resolve_candidates(config: &MongoVectorConfig, k: usize) -> i64 {
        config.num_candidates.unwrap_or_else(|| (k as i64) * 10)
    }

    #[test]
    fn config_new_sets_defaults() {
        let cfg = MongoVectorConfig::new("my_db", "my_collection");

        // Caller-supplied values.
        assert_eq!(cfg.database, "my_db");
        assert_eq!(cfg.collection, "my_collection");

        // Built-in defaults.
        assert_eq!(cfg.index_name, "vector_index");
        assert_eq!(cfg.vector_field, "embedding");
        assert_eq!(cfg.content_field, "content");
        assert!(cfg.num_candidates.is_none());
    }

    #[test]
    fn config_with_index_name() {
        let cfg = MongoVectorConfig::new("db", "col").with_index_name("custom_index");
        assert_eq!(cfg.index_name, "custom_index");
    }

    #[test]
    fn config_with_vector_field() {
        let cfg = MongoVectorConfig::new("db", "col").with_vector_field("vec");
        assert_eq!(cfg.vector_field, "vec");
    }

    #[test]
    fn config_with_content_field() {
        let cfg = MongoVectorConfig::new("db", "col").with_content_field("text");
        assert_eq!(cfg.content_field, "text");
    }

    #[test]
    fn config_with_num_candidates() {
        let cfg = MongoVectorConfig::new("db", "col").with_num_candidates(200);
        assert_eq!(cfg.num_candidates, Some(200));
    }

    #[test]
    fn config_builder_chain() {
        let base = MongoVectorConfig::new("test_db", "embeddings");
        let cfg = base
            .with_index_name("my_vs_index")
            .with_vector_field("vec_field")
            .with_content_field("text_field")
            .with_num_candidates(500);

        assert_eq!(cfg.database, "test_db");
        assert_eq!(cfg.collection, "embeddings");
        assert_eq!(cfg.index_name, "my_vs_index");
        assert_eq!(cfg.vector_field, "vec_field");
        assert_eq!(cfg.content_field, "text_field");
        assert_eq!(cfg.num_candidates, Some(500));
    }

    #[test]
    fn json_to_bson_roundtrip_string() {
        let original = Value::String("hello".into());
        assert_eq!(original, roundtrip(&original));
    }

    #[test]
    fn json_to_bson_roundtrip_number_int() {
        let original = serde_json::json!(42);
        assert_eq!(original, roundtrip(&original));
    }

    #[test]
    fn json_to_bson_roundtrip_number_float() {
        let original = serde_json::json!(3.14);
        assert_eq!(original, roundtrip(&original));
    }

    #[test]
    fn json_to_bson_roundtrip_bool() {
        let original = Value::Bool(true);
        assert_eq!(original, roundtrip(&original));
    }

    #[test]
    fn json_to_bson_roundtrip_null() {
        let original = Value::Null;
        assert_eq!(original, roundtrip(&original));
    }

    #[test]
    fn json_to_bson_roundtrip_array() {
        let original = serde_json::json!([1, "two", true]);
        assert_eq!(original, roundtrip(&original));
    }

    #[test]
    fn json_to_bson_roundtrip_object() {
        let original = serde_json::json!({"key": "value", "num": 42});
        assert_eq!(original, roundtrip(&original));
    }

    #[test]
    fn json_map_to_bson_and_back() {
        let expected: HashMap<String, Value> = [
            ("source".to_string(), Value::String("test".into())),
            ("page".to_string(), serde_json::json!(42)),
        ]
        .into_iter()
        .collect();

        let restored = bson_doc_to_json_map(&json_map_to_bson(&expected));

        assert_eq!(expected, restored);
    }

    #[test]
    fn num_candidates_default() {
        let cfg = MongoVectorConfig::new("db", "col");
        assert_eq!(resolve_candidates(&cfg, 10), 100);
    }

    #[test]
    fn num_candidates_custom() {
        let cfg = MongoVectorConfig::new("db", "col").with_num_candidates(200);
        assert_eq!(resolve_candidates(&cfg, 10), 200);
    }
}