// synaptic_vectorstores/multi_vector.rs
use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use async_trait::async_trait;
use synaptic_core::SynapseError;
use synaptic_embeddings::Embeddings;
use synaptic_retrieval::{Document, Retriever};
use tokio::sync::RwLock;

use crate::VectorStore;
11
12pub struct MultiVectorRetriever<S: VectorStore> {
18 vectorstore: Arc<S>,
19 embeddings: Arc<dyn Embeddings>,
20 docstore: Arc<RwLock<HashMap<String, Document>>>,
22 id_key: String,
24 k: usize,
25}
26
27impl<S: VectorStore + 'static> MultiVectorRetriever<S> {
28 pub fn new(vectorstore: Arc<S>, embeddings: Arc<dyn Embeddings>, k: usize) -> Self {
34 Self {
35 vectorstore,
36 embeddings,
37 docstore: Arc::new(RwLock::new(HashMap::new())),
38 id_key: "parent_id".to_string(),
39 k,
40 }
41 }
42
43 pub fn with_id_key(mut self, key: impl Into<String>) -> Self {
46 self.id_key = key.into();
47 self
48 }
49
50 pub async fn add_documents(
56 &self,
57 parent_docs: Vec<Document>,
58 child_docs: Vec<Document>,
59 ) -> Result<(), SynapseError> {
60 {
62 let mut store = self.docstore.write().await;
63 for doc in parent_docs {
64 store.insert(doc.id.clone(), doc);
65 }
66 }
67
68 self.vectorstore
70 .add_documents(child_docs, self.embeddings.as_ref())
71 .await?;
72
73 Ok(())
74 }
75}
76
77#[async_trait]
78impl<S: VectorStore + 'static> Retriever for MultiVectorRetriever<S> {
79 async fn retrieve(&self, query: &str, top_k: usize) -> Result<Vec<Document>, SynapseError> {
80 let k = if top_k > 0 { top_k } else { self.k };
81
82 let children = self
84 .vectorstore
85 .similarity_search(query, k, self.embeddings.as_ref())
86 .await?;
87
88 let docstore = self.docstore.read().await;
90 let mut seen = std::collections::HashSet::new();
91 let mut parents = Vec::new();
92
93 for child in &children {
94 if let Some(parent_id_value) = child.metadata.get(&self.id_key) {
95 if let Some(parent_id) = parent_id_value.as_str() {
96 if seen.insert(parent_id.to_string()) {
97 if let Some(parent) = docstore.get(parent_id) {
98 parents.push(parent.clone());
99 }
100 }
101 }
102 }
103 }
104
105 Ok(parents)
106 }
107}