Skip to main content

synaptic_weaviate/
vector_store.rs

1use std::collections::HashMap;
2
3use async_trait::async_trait;
4use serde_json::{json, Value};
5use synaptic_core::{Document, Embeddings, SynapticError, VectorStore};
6use uuid::Uuid;
7
8// ---------------------------------------------------------------------------
9// WeaviateConfig
10// ---------------------------------------------------------------------------
11
/// Configuration for connecting to a Weaviate instance.
///
/// `Debug` is implemented by hand so the API key is never printed: a derived
/// impl would leak the secret into any log line that formats the config.
#[derive(Clone)]
pub struct WeaviateConfig {
    /// HTTP scheme: `http` or `https`.
    pub scheme: String,
    /// Host and port, e.g. `localhost:8080` or `my-cluster.weaviate.network`.
    pub host: String,
    /// Weaviate class (collection) name. Must start with an uppercase letter.
    pub class_name: String,
    /// Optional API key for authentication (Weaviate Cloud Services).
    pub api_key: Option<String>,
}

impl std::fmt::Debug for WeaviateConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("WeaviateConfig")
            .field("scheme", &self.scheme)
            .field("host", &self.host)
            .field("class_name", &self.class_name)
            // Show only whether a key is configured, never its value.
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}

impl WeaviateConfig {
    /// Create a configuration without an API key.
    pub fn new(
        scheme: impl Into<String>,
        host: impl Into<String>,
        class_name: impl Into<String>,
    ) -> Self {
        Self {
            scheme: scheme.into(),
            host: host.into(),
            class_name: class_name.into(),
            api_key: None,
        }
    }

    /// Add an API key for authentication (sent as a `Bearer` token).
    pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
        self.api_key = Some(api_key.into());
        self
    }

    /// Build the base URL (`scheme://host`) that request paths are appended to.
    ///
    /// Trailing `/` characters on `host` are dropped. Otherwise appending
    /// `/v1/...` would produce a `//v1/...` path.
    pub fn base_url(&self) -> String {
        format!("{}://{}", self.scheme, self.host.trim_end_matches('/'))
    }
}
50
51// ---------------------------------------------------------------------------
52// WeaviateVectorStore
53// ---------------------------------------------------------------------------
54
55/// Weaviate-backed vector store.
56///
57/// Implements [`VectorStore`] using the Weaviate v1 REST API:
58/// - Batch add: `POST /v1/batch/objects`
59/// - Similarity search: `POST /v1/graphql` with `nearVector`
60/// - Delete: `DELETE /v1/objects/{class}/{id}`
61///
62/// Call [`WeaviateVectorStore::initialize`] once to create the class schema
63/// before adding documents.
64pub struct WeaviateVectorStore {
65    config: WeaviateConfig,
66    client: reqwest::Client,
67}
68
69impl WeaviateVectorStore {
70    /// Create a new store with the given configuration.
71    pub fn new(config: WeaviateConfig) -> Self {
72        Self {
73            config,
74            client: reqwest::Client::new(),
75        }
76    }
77
78    /// Create with a custom reqwest client.
79    pub fn with_client(config: WeaviateConfig, client: reqwest::Client) -> Self {
80        Self { config, client }
81    }
82
83    /// Return a reference to the configuration.
84    pub fn config(&self) -> &WeaviateConfig {
85        &self.config
86    }
87
88    /// Create the Weaviate class schema for this store (idempotent).
89    ///
90    /// Creates a class with `content` (text), `metadata` (text), and
91    /// `docId` (text) properties. Uses the `cosine` distance metric via
92    /// `"vectorizer": "none"` (caller supplies vectors).
93    pub async fn initialize(&self) -> Result<(), SynapticError> {
94        let url = format!("{}/v1/schema", self.config.base_url());
95
96        let schema = json!({
97            "class": self.config.class_name,
98            "description": format!("Synaptic vector store: {}", self.config.class_name),
99            "properties": [
100                {
101                    "name": "content",
102                    "dataType": ["text"],
103                    "description": "Document content"
104                },
105                {
106                    "name": "docId",
107                    "dataType": ["text"],
108                    "description": "Original document ID"
109                },
110                {
111                    "name": "metadata",
112                    "dataType": ["text"],
113                    "description": "JSON-serialized document metadata"
114                }
115            ],
116            "vectorizer": "none"
117        });
118
119        let mut req = self.client.post(&url).json(&schema);
120        if let Some(ref key) = self.config.api_key {
121            req = req.header("Authorization", format!("Bearer {key}"));
122        }
123
124        let resp = req
125            .send()
126            .await
127            .map_err(|e| SynapticError::VectorStore(format!("Weaviate initialize: {e}")))?;
128
129        let status = resp.status().as_u16();
130        // 200 = created; 422 = class already exists — both are acceptable
131        if status != 200 && status != 422 {
132            let body = resp.text().await.unwrap_or_default();
133            return Err(SynapticError::VectorStore(format!(
134                "Weaviate schema error (HTTP {status}): {body}"
135            )));
136        }
137
138        Ok(())
139    }
140
141    fn apply_auth(&self, req: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
142        if let Some(ref key) = self.config.api_key {
143            req.header("Authorization", format!("Bearer {key}"))
144        } else {
145            req
146        }
147    }
148
149    /// Execute a nearVector GraphQL query and return raw items.
150    async fn near_vector_query(
151        &self,
152        vector: &[f32],
153        k: usize,
154        with_score: bool,
155    ) -> Result<Vec<Value>, SynapticError> {
156        let additional = if with_score {
157            "_additional { id distance }"
158        } else {
159            "_additional { id }"
160        };
161
162        let graphql_query = format!(
163            "{{ Get {{ {class}(limit: {k}, nearVector: {{ vector: {vector} }}) {{ content docId metadata {additional} }} }} }}",
164            class = self.config.class_name,
165            k = k,
166            vector = serde_json::to_string(vector).unwrap_or_default(),
167        );
168
169        let url = format!("{}/v1/graphql", self.config.base_url());
170        let req = self.apply_auth(
171            self.client
172                .post(&url)
173                .json(&json!({ "query": graphql_query })),
174        );
175
176        let resp = req
177            .send()
178            .await
179            .map_err(|e| SynapticError::VectorStore(format!("Weaviate search: {e}")))?;
180
181        let status = resp.status().as_u16();
182        let body: Value = resp
183            .json()
184            .await
185            .map_err(|e| SynapticError::VectorStore(format!("Weaviate search parse: {e}")))?;
186
187        if status != 200 {
188            return Err(SynapticError::VectorStore(format!(
189                "Weaviate search error (HTTP {status}): {body}"
190            )));
191        }
192
193        Ok(body["data"]["Get"][&self.config.class_name]
194            .as_array()
195            .cloned()
196            .unwrap_or_default())
197    }
198
199    fn item_to_document(item: &Value) -> Document {
200        let content = item["content"].as_str().unwrap_or("").to_string();
201        let id = item["docId"].as_str().unwrap_or("").to_string();
202        let metadata: HashMap<String, Value> = item["metadata"]
203            .as_str()
204            .and_then(|s| serde_json::from_str(s).ok())
205            .unwrap_or_default();
206        Document {
207            id,
208            content,
209            metadata,
210        }
211    }
212}
213
214#[async_trait]
215impl VectorStore for WeaviateVectorStore {
216    async fn add_documents(
217        &self,
218        documents: Vec<Document>,
219        embeddings: &dyn Embeddings,
220    ) -> Result<Vec<String>, SynapticError> {
221        if documents.is_empty() {
222            return Ok(vec![]);
223        }
224
225        let texts: Vec<&str> = documents.iter().map(|d| d.content.as_str()).collect();
226        let vectors = embeddings.embed_documents(&texts).await?;
227
228        let mut objects = Vec::with_capacity(documents.len());
229        let mut ids = Vec::with_capacity(documents.len());
230
231        for (doc, vector) in documents.iter().zip(vectors.iter()) {
232            let weaviate_id = Uuid::new_v4().to_string();
233            ids.push(weaviate_id.clone());
234
235            let metadata_str =
236                serde_json::to_string(&doc.metadata).unwrap_or_else(|_| "{}".to_string());
237
238            objects.push(json!({
239                "class": self.config.class_name,
240                "id": weaviate_id,
241                "properties": {
242                    "content": doc.content,
243                    "docId": doc.id,
244                    "metadata": metadata_str,
245                },
246                "vector": vector,
247            }));
248        }
249
250        let url = format!("{}/v1/batch/objects", self.config.base_url());
251        let body = json!({ "objects": objects });
252
253        let req = self.apply_auth(self.client.post(&url).json(&body));
254        let resp = req
255            .send()
256            .await
257            .map_err(|e| SynapticError::VectorStore(format!("Weaviate batch add: {e}")))?;
258
259        let status = resp.status().as_u16();
260        if status != 200 {
261            let text = resp.text().await.unwrap_or_default();
262            return Err(SynapticError::VectorStore(format!(
263                "Weaviate batch add error (HTTP {status}): {text}"
264            )));
265        }
266
267        Ok(ids)
268    }
269
270    async fn similarity_search(
271        &self,
272        query: &str,
273        k: usize,
274        embeddings: &dyn Embeddings,
275    ) -> Result<Vec<Document>, SynapticError> {
276        let query_vector = embeddings.embed_query(query).await?;
277        let items = self.near_vector_query(&query_vector, k, false).await?;
278        Ok(items.iter().map(Self::item_to_document).collect())
279    }
280
281    async fn similarity_search_with_score(
282        &self,
283        query: &str,
284        k: usize,
285        embeddings: &dyn Embeddings,
286    ) -> Result<Vec<(Document, f32)>, SynapticError> {
287        let query_vector = embeddings.embed_query(query).await?;
288        let items = self.near_vector_query(&query_vector, k, true).await?;
289        Ok(items
290            .iter()
291            .map(|item| {
292                let doc = Self::item_to_document(item);
293                // Weaviate returns cosine distance (0=identical, 2=opposite)
294                // Convert to similarity score: 1 - distance/2
295                let distance = item["_additional"]["distance"].as_f64().unwrap_or(1.0) as f32;
296                let score = 1.0 - distance / 2.0;
297                (doc, score)
298            })
299            .collect())
300    }
301
302    async fn similarity_search_by_vector(
303        &self,
304        embedding: &[f32],
305        k: usize,
306    ) -> Result<Vec<Document>, SynapticError> {
307        let items = self.near_vector_query(embedding, k, false).await?;
308        Ok(items.iter().map(Self::item_to_document).collect())
309    }
310
311    async fn delete(&self, ids: &[&str]) -> Result<(), SynapticError> {
312        for id in ids {
313            let url = format!(
314                "{}/v1/objects/{}/{}",
315                self.config.base_url(),
316                self.config.class_name,
317                id
318            );
319            let req = self.apply_auth(self.client.delete(&url));
320            let resp = req
321                .send()
322                .await
323                .map_err(|e| SynapticError::VectorStore(format!("Weaviate delete: {e}")))?;
324
325            let status = resp.status().as_u16();
326            // 204 = deleted; 404 = already gone (OK to ignore)
327            if status != 204 && status != 404 {
328                let text = resp.text().await.unwrap_or_default();
329                return Err(SynapticError::VectorStore(format!(
330                    "Weaviate delete error (HTTP {status}): {text}"
331                )));
332            }
333        }
334        Ok(())
335    }
336}
337
#[cfg(test)]
mod tests {
    //! Offline unit tests for `WeaviateConfig`; no Weaviate server is needed.
    use super::*;

    /// Shared fixture: a plain-HTTP config pointing at a local instance.
    fn local(class: &str) -> WeaviateConfig {
        WeaviateConfig::new("http", "localhost:8080", class)
    }

    #[test]
    fn config_base_url() {
        let base = local("Document").base_url();
        assert_eq!(base, "http://localhost:8080");
    }

    #[test]
    fn config_with_api_key() {
        let plain = WeaviateConfig::new("https", "cluster.weaviate.network", "MyClass");
        let keyed = plain.with_api_key("wcs-secret-key");
        assert_eq!(keyed.api_key.as_deref(), Some("wcs-secret-key"));
    }

    #[test]
    fn config_class_name() {
        assert_eq!(local("SynapticDocs").class_name, "SynapticDocs");
    }
}