1use std::collections::HashMap;
2
3use async_trait::async_trait;
4use serde_json::{json, Value};
5use synaptic_core::{Document, Embeddings, SynapticError, VectorStore};
6use uuid::Uuid;
7
/// Connection settings for a Weaviate instance and the class used as the
/// backing collection.
///
/// `Debug` is implemented by hand so that the API key is never printed in
/// logs or panic messages.
#[derive(Clone, PartialEq, Eq)]
pub struct WeaviateConfig {
    /// URL scheme, e.g. `http` or `https`.
    pub scheme: String,
    /// Host, optionally with a port, e.g. `localhost:8080`.
    pub host: String,
    /// Name of the Weaviate class that stores the documents.
    pub class_name: String,
    /// Optional API key, sent as `Authorization: Bearer <key>` when set.
    pub api_key: Option<String>,
}

impl std::fmt::Debug for WeaviateConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("WeaviateConfig")
            .field("scheme", &self.scheme)
            .field("host", &self.host)
            .field("class_name", &self.class_name)
            // Only reveal whether a key is configured, never its value.
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
24
25impl WeaviateConfig {
26 pub fn new(
27 scheme: impl Into<String>,
28 host: impl Into<String>,
29 class_name: impl Into<String>,
30 ) -> Self {
31 Self {
32 scheme: scheme.into(),
33 host: host.into(),
34 class_name: class_name.into(),
35 api_key: None,
36 }
37 }
38
39 pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
41 self.api_key = Some(api_key.into());
42 self
43 }
44
45 pub fn base_url(&self) -> String {
47 format!("{}://{}", self.scheme, self.host)
48 }
49}
50
51pub struct WeaviateVectorStore {
65 config: WeaviateConfig,
66 client: reqwest::Client,
67}
68
69impl WeaviateVectorStore {
70 pub fn new(config: WeaviateConfig) -> Self {
72 Self {
73 config,
74 client: reqwest::Client::new(),
75 }
76 }
77
78 pub fn with_client(config: WeaviateConfig, client: reqwest::Client) -> Self {
80 Self { config, client }
81 }
82
83 pub fn config(&self) -> &WeaviateConfig {
85 &self.config
86 }
87
88 pub async fn initialize(&self) -> Result<(), SynapticError> {
94 let url = format!("{}/v1/schema", self.config.base_url());
95
96 let schema = json!({
97 "class": self.config.class_name,
98 "description": format!("Synaptic vector store: {}", self.config.class_name),
99 "properties": [
100 {
101 "name": "content",
102 "dataType": ["text"],
103 "description": "Document content"
104 },
105 {
106 "name": "docId",
107 "dataType": ["text"],
108 "description": "Original document ID"
109 },
110 {
111 "name": "metadata",
112 "dataType": ["text"],
113 "description": "JSON-serialized document metadata"
114 }
115 ],
116 "vectorizer": "none"
117 });
118
119 let mut req = self.client.post(&url).json(&schema);
120 if let Some(ref key) = self.config.api_key {
121 req = req.header("Authorization", format!("Bearer {key}"));
122 }
123
124 let resp = req
125 .send()
126 .await
127 .map_err(|e| SynapticError::VectorStore(format!("Weaviate initialize: {e}")))?;
128
129 let status = resp.status().as_u16();
130 if status != 200 && status != 422 {
132 let body = resp.text().await.unwrap_or_default();
133 return Err(SynapticError::VectorStore(format!(
134 "Weaviate schema error (HTTP {status}): {body}"
135 )));
136 }
137
138 Ok(())
139 }
140
141 fn apply_auth(&self, req: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
142 if let Some(ref key) = self.config.api_key {
143 req.header("Authorization", format!("Bearer {key}"))
144 } else {
145 req
146 }
147 }
148
149 async fn near_vector_query(
151 &self,
152 vector: &[f32],
153 k: usize,
154 with_score: bool,
155 ) -> Result<Vec<Value>, SynapticError> {
156 let additional = if with_score {
157 "_additional { id distance }"
158 } else {
159 "_additional { id }"
160 };
161
162 let graphql_query = format!(
163 "{{ Get {{ {class}(limit: {k}, nearVector: {{ vector: {vector} }}) {{ content docId metadata {additional} }} }} }}",
164 class = self.config.class_name,
165 k = k,
166 vector = serde_json::to_string(vector).unwrap_or_default(),
167 );
168
169 let url = format!("{}/v1/graphql", self.config.base_url());
170 let req = self.apply_auth(
171 self.client
172 .post(&url)
173 .json(&json!({ "query": graphql_query })),
174 );
175
176 let resp = req
177 .send()
178 .await
179 .map_err(|e| SynapticError::VectorStore(format!("Weaviate search: {e}")))?;
180
181 let status = resp.status().as_u16();
182 let body: Value = resp
183 .json()
184 .await
185 .map_err(|e| SynapticError::VectorStore(format!("Weaviate search parse: {e}")))?;
186
187 if status != 200 {
188 return Err(SynapticError::VectorStore(format!(
189 "Weaviate search error (HTTP {status}): {body}"
190 )));
191 }
192
193 Ok(body["data"]["Get"][&self.config.class_name]
194 .as_array()
195 .cloned()
196 .unwrap_or_default())
197 }
198
199 fn item_to_document(item: &Value) -> Document {
200 let content = item["content"].as_str().unwrap_or("").to_string();
201 let id = item["docId"].as_str().unwrap_or("").to_string();
202 let metadata: HashMap<String, Value> = item["metadata"]
203 .as_str()
204 .and_then(|s| serde_json::from_str(s).ok())
205 .unwrap_or_default();
206 Document {
207 id,
208 content,
209 metadata,
210 }
211 }
212}
213
214#[async_trait]
215impl VectorStore for WeaviateVectorStore {
216 async fn add_documents(
217 &self,
218 documents: Vec<Document>,
219 embeddings: &dyn Embeddings,
220 ) -> Result<Vec<String>, SynapticError> {
221 if documents.is_empty() {
222 return Ok(vec![]);
223 }
224
225 let texts: Vec<&str> = documents.iter().map(|d| d.content.as_str()).collect();
226 let vectors = embeddings.embed_documents(&texts).await?;
227
228 let mut objects = Vec::with_capacity(documents.len());
229 let mut ids = Vec::with_capacity(documents.len());
230
231 for (doc, vector) in documents.iter().zip(vectors.iter()) {
232 let weaviate_id = Uuid::new_v4().to_string();
233 ids.push(weaviate_id.clone());
234
235 let metadata_str =
236 serde_json::to_string(&doc.metadata).unwrap_or_else(|_| "{}".to_string());
237
238 objects.push(json!({
239 "class": self.config.class_name,
240 "id": weaviate_id,
241 "properties": {
242 "content": doc.content,
243 "docId": doc.id,
244 "metadata": metadata_str,
245 },
246 "vector": vector,
247 }));
248 }
249
250 let url = format!("{}/v1/batch/objects", self.config.base_url());
251 let body = json!({ "objects": objects });
252
253 let req = self.apply_auth(self.client.post(&url).json(&body));
254 let resp = req
255 .send()
256 .await
257 .map_err(|e| SynapticError::VectorStore(format!("Weaviate batch add: {e}")))?;
258
259 let status = resp.status().as_u16();
260 if status != 200 {
261 let text = resp.text().await.unwrap_or_default();
262 return Err(SynapticError::VectorStore(format!(
263 "Weaviate batch add error (HTTP {status}): {text}"
264 )));
265 }
266
267 Ok(ids)
268 }
269
270 async fn similarity_search(
271 &self,
272 query: &str,
273 k: usize,
274 embeddings: &dyn Embeddings,
275 ) -> Result<Vec<Document>, SynapticError> {
276 let query_vector = embeddings.embed_query(query).await?;
277 let items = self.near_vector_query(&query_vector, k, false).await?;
278 Ok(items.iter().map(Self::item_to_document).collect())
279 }
280
281 async fn similarity_search_with_score(
282 &self,
283 query: &str,
284 k: usize,
285 embeddings: &dyn Embeddings,
286 ) -> Result<Vec<(Document, f32)>, SynapticError> {
287 let query_vector = embeddings.embed_query(query).await?;
288 let items = self.near_vector_query(&query_vector, k, true).await?;
289 Ok(items
290 .iter()
291 .map(|item| {
292 let doc = Self::item_to_document(item);
293 let distance = item["_additional"]["distance"].as_f64().unwrap_or(1.0) as f32;
296 let score = 1.0 - distance / 2.0;
297 (doc, score)
298 })
299 .collect())
300 }
301
302 async fn similarity_search_by_vector(
303 &self,
304 embedding: &[f32],
305 k: usize,
306 ) -> Result<Vec<Document>, SynapticError> {
307 let items = self.near_vector_query(embedding, k, false).await?;
308 Ok(items.iter().map(Self::item_to_document).collect())
309 }
310
311 async fn delete(&self, ids: &[&str]) -> Result<(), SynapticError> {
312 for id in ids {
313 let url = format!(
314 "{}/v1/objects/{}/{}",
315 self.config.base_url(),
316 self.config.class_name,
317 id
318 );
319 let req = self.apply_auth(self.client.delete(&url));
320 let resp = req
321 .send()
322 .await
323 .map_err(|e| SynapticError::VectorStore(format!("Weaviate delete: {e}")))?;
324
325 let status = resp.status().as_u16();
326 if status != 204 && status != 404 {
328 let text = resp.text().await.unwrap_or_default();
329 return Err(SynapticError::VectorStore(format!(
330 "Weaviate delete error (HTTP {status}): {text}"
331 )));
332 }
333 }
334 Ok(())
335 }
336}
337
#[cfg(test)]
mod tests {
    use super::*;

    /// Config pointing at an unauthenticated local instance.
    fn local_config(class_name: &str) -> WeaviateConfig {
        WeaviateConfig::new("http", "localhost:8080", class_name)
    }

    #[test]
    fn config_base_url() {
        let url = local_config("Document").base_url();
        assert_eq!(url, "http://localhost:8080");
    }

    #[test]
    fn config_with_api_key() {
        let secret = "wcs-secret-key";
        let cfg =
            WeaviateConfig::new("https", "cluster.weaviate.network", "MyClass").with_api_key(secret);
        assert_eq!(cfg.api_key.as_deref(), Some(secret));
    }

    #[test]
    fn config_class_name() {
        let cfg = local_config("SynapticDocs");
        assert_eq!(cfg.class_name, "SynapticDocs");
    }
}