// wesichain_weaviate/lib.rs

1//! Weaviate vector store integration for Wesichain.
2
3mod config;
4mod error;
5pub mod filter;
6pub mod mapper;
7
8use std::fmt;
9
10use mapper::{
11    build_near_vector_query, class_schema_request, doc_to_object, graphql_hits_to_results,
12    GraphQlRequest, GraphQlResponse,
13};
14use serde::Deserialize;
15use serde_json::Value as JsonValue;
16use wesichain_core::{Document, MetadataFilter, SearchResult, StoreError, VectorStore};
17
18use crate::filter::to_weaviate_filter;
19
20pub use config::WeaviateStoreBuilder;
21pub use error::WeaviateStoreError;
22
23#[derive(Clone)]
24pub struct WeaviateVectorStore {
25    client: reqwest::Client,
26    base_url: String,
27    class_name: String,
28    api_key: Option<String>,
29    auto_create_class: bool,
30}
31
32impl fmt::Debug for WeaviateVectorStore {
33    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34        let api_key = if self.api_key.is_some() {
35            "<redacted>"
36        } else {
37            "<none>"
38        };
39
40        f.debug_struct("WeaviateVectorStore")
41            .field("base_url", &self.base_url)
42            .field("class_name", &self.class_name)
43            .field("api_key", &api_key)
44            .field("auto_create_class", &self.auto_create_class)
45            .finish()
46    }
47}
48
49impl WeaviateVectorStore {
50    pub fn builder() -> WeaviateStoreBuilder {
51        WeaviateStoreBuilder::new()
52    }
53
54    pub fn base_url(&self) -> &str {
55        &self.base_url
56    }
57
58    pub fn class_name(&self) -> &str {
59        &self.class_name
60    }
61
62    pub fn api_key(&self) -> Option<&str> {
63        self.api_key.as_deref()
64    }
65
66    pub async fn scored_search(
67        &self,
68        query_embedding: &[f32],
69        top_k: usize,
70        filter: Option<&MetadataFilter>,
71    ) -> Result<Vec<SearchResult>, StoreError> {
72        <Self as VectorStore>::search(self, query_embedding, top_k, filter).await
73    }
74
75    pub fn auto_create_class(&self) -> bool {
76        self.auto_create_class
77    }
78
79    fn endpoint(&self, path: &str) -> String {
80        format!(
81            "{}/{}",
82            self.base_url.trim_end_matches('/'),
83            path.trim_start_matches('/')
84        )
85    }
86
87    fn request_builder(&self, method: reqwest::Method, path: &str) -> reqwest::RequestBuilder {
88        let request = self.client.request(method, self.endpoint(path));
89        if let Some(api_key) = self.api_key() {
90            request.header("Authorization", format!("Bearer {api_key}"))
91        } else {
92            request
93        }
94    }
95
96    async fn send_json(
97        &self,
98        request: reqwest::RequestBuilder,
99    ) -> Result<JsonValue, WeaviateStoreError> {
100        let response = request.send().await.map_err(WeaviateStoreError::from)?;
101        let status = response.status();
102        let body = response.text().await.map_err(WeaviateStoreError::from)?;
103
104        if !status.is_success() {
105            return Err(self.http_error_from_response(status.as_u16(), &body));
106        }
107
108        if body.trim().is_empty() {
109            return Ok(JsonValue::Null);
110        }
111
112        serde_json::from_str(&body).map_err(|err| WeaviateStoreError::InvalidResponse {
113            message: format!("failed to decode weaviate response body: {err}"),
114        })
115    }
116
117    async fn create_class_schema(&self) -> Result<(), WeaviateStoreError> {
118        let schema = class_schema_request(&self.class_name);
119        match self
120            .send_json(
121                self.request_builder(reqwest::Method::POST, "v1/schema")
122                    .json(&schema),
123            )
124            .await
125        {
126            Ok(_) => Ok(()),
127            Err(WeaviateStoreError::HttpStatus { status, message })
128                if (status == 409 || status == 422)
129                    && is_class_already_exists_message(&message) =>
130            {
131                Ok(())
132            }
133            Err(err) => Err(err),
134        }
135    }
136
137    async fn add_once(&self, docs: Vec<Document>) -> Result<(), WeaviateStoreError> {
138        let mut objects = Vec::with_capacity(docs.len());
139        let mut expected_dimension: Option<usize> = None;
140        for doc in docs {
141            let object = doc_to_object(doc, &self.class_name)?;
142            match expected_dimension {
143                Some(expected) if expected != object.vector.len() => {
144                    return Err(WeaviateStoreError::InvalidResponse {
145                        message: format!(
146                            "dimension mismatch in batch: expected {expected}, got {}",
147                            object.vector.len()
148                        ),
149                    });
150                }
151                None => expected_dimension = Some(object.vector.len()),
152                _ => {}
153            }
154            objects.push(object);
155        }
156
157        for object in objects {
158            let request = self
159                .request_builder(reqwest::Method::POST, "v1/objects")
160                .json(&object);
161
162            let _ = self.send_json(request).await?;
163        }
164
165        Ok(())
166    }
167
168    fn http_error_from_response(&self, status: u16, body: &str) -> WeaviateStoreError {
169        let message = weaviate_error_message(body);
170        if is_class_not_found_message(&message) {
171            return WeaviateStoreError::ClassNotFound {
172                class_name: self.class_name.clone(),
173                message,
174            };
175        }
176
177        WeaviateStoreError::HttpStatus { status, message }
178    }
179}
180
181#[async_trait::async_trait]
182impl VectorStore for WeaviateVectorStore {
183    async fn add(&self, docs: Vec<Document>) -> Result<(), StoreError> {
184        if docs.is_empty() {
185            return Ok(());
186        }
187
188        match self.add_once(docs.clone()).await {
189            Ok(()) => Ok(()),
190            Err(WeaviateStoreError::ClassNotFound { .. }) if self.auto_create_class => {
191                self.create_class_schema().await.map_err(StoreError::from)?;
192                self.add_once(docs).await.map_err(StoreError::from)
193            }
194            Err(error) => Err(StoreError::from(error)),
195        }
196    }
197
198    async fn search(
199        &self,
200        query_embedding: &[f32],
201        top_k: usize,
202        filter: Option<&MetadataFilter>,
203    ) -> Result<Vec<SearchResult>, StoreError> {
204        if query_embedding.is_empty() || top_k == 0 {
205            return Ok(Vec::new());
206        }
207
208        let where_clause = filter
209            .map(to_weaviate_filter)
210            .transpose()
211            .map_err(StoreError::from)?;
212
213        let query = build_near_vector_query(
214            &self.class_name,
215            query_embedding,
216            top_k,
217            where_clause.as_deref(),
218        );
219        let response = self
220            .send_json(
221                self.request_builder(reqwest::Method::POST, "v1/graphql")
222                    .json(&GraphQlRequest { query }),
223            )
224            .await
225            .map_err(StoreError::from)?;
226
227        let gql_response: GraphQlResponse = serde_json::from_value(response).map_err(|err| {
228            StoreError::from(WeaviateStoreError::InvalidResponse {
229                message: format!("failed to decode GraphQL envelope: {err}"),
230            })
231        })?;
232
233        if let Some(first_error) = gql_response.errors.first() {
234            let message = first_error.message.clone();
235            if is_class_not_found_message(&message) {
236                return Err(StoreError::from(WeaviateStoreError::ClassNotFound {
237                    class_name: self.class_name.clone(),
238                    message,
239                }));
240            }
241
242            return Err(StoreError::from(WeaviateStoreError::InvalidResponse {
243                message,
244            }));
245        }
246
247        let data = gql_response.data.ok_or_else(|| {
248            StoreError::from(WeaviateStoreError::InvalidResponse {
249                message: "missing GraphQL data in response".to_string(),
250            })
251        })?;
252
253        let mut results =
254            graphql_hits_to_results(data, &self.class_name).map_err(StoreError::from)?;
255        results.sort_by(|left, right| right.score.total_cmp(&left.score));
256        Ok(results)
257    }
258
259    async fn delete(&self, ids: &[String]) -> Result<(), StoreError> {
260        if ids.is_empty() {
261            return Ok(());
262        }
263
264        for id in ids {
265            if id.trim().is_empty() {
266                return Err(StoreError::InvalidId(id.clone()));
267            }
268
269            let encoded_id = urlencoding::encode(id);
270            let path = format!("v1/objects/{}/{}", self.class_name, encoded_id);
271            let _ = self
272                .send_json(self.request_builder(reqwest::Method::DELETE, &path))
273                .await
274                .map_err(StoreError::from)?;
275        }
276
277        Ok(())
278    }
279}
280
281#[derive(Debug, Deserialize)]
282struct WeaviateErrorEnvelope {
283    #[serde(default)]
284    error: Vec<WeaviateErrorMessage>,
285}
286
287#[derive(Debug, Deserialize)]
288struct WeaviateErrorMessage {
289    message: String,
290}
291
292fn weaviate_error_message(body: &str) -> String {
293    let trimmed = body.trim();
294    if trimmed.is_empty() {
295        return "unknown weaviate error".to_string();
296    }
297
298    serde_json::from_str::<WeaviateErrorEnvelope>(trimmed)
299        .ok()
300        .and_then(|envelope| envelope.error.into_iter().next().map(|entry| entry.message))
301        .unwrap_or_else(|| trimmed.to_string())
302}
303
/// Returns `true` when `message` mentions both "class" and "not found",
/// ignoring case.
fn is_class_not_found_message(message: &str) -> bool {
    let lowered = message.to_lowercase();
    ["class", "not found"]
        .iter()
        .all(|&fragment| lowered.contains(fragment))
}
308
/// Returns `true` when `message` mentions "class" together with
/// "already exist" (which also covers "already exists"), ignoring case.
fn is_class_already_exists_message(message: &str) -> bool {
    let normalized = message.to_lowercase();
    // "already exist" is a prefix of "already exists", so one substring
    // check covers both spellings.
    normalized.contains("class") && normalized.contains("already exist")
}