1mod config;
4mod error;
5pub mod filter;
6pub mod mapper;
7
8use std::fmt;
9
10use mapper::{
11 build_near_vector_query, class_schema_request, doc_to_object, graphql_hits_to_results,
12 GraphQlRequest, GraphQlResponse,
13};
14use serde::Deserialize;
15use serde_json::Value as JsonValue;
16use wesichain_core::{Document, MetadataFilter, SearchResult, StoreError, VectorStore};
17
18use crate::filter::to_weaviate_filter;
19
20pub use config::WeaviateStoreBuilder;
21pub use error::WeaviateStoreError;
22
/// Handle for talking to a single Weaviate class over HTTP.
///
/// Instances are configured through [`WeaviateStoreBuilder`] (see
/// `WeaviateVectorStore::builder`).
#[derive(Clone)]
pub struct WeaviateVectorStore {
    /// HTTP client used to issue every request.
    client: reqwest::Client,
    /// Base URL that request paths are appended to; trailing slashes are
    /// trimmed when an endpoint is built.
    base_url: String,
    /// Name of the class this store reads from and writes to.
    class_name: String,
    /// Optional API key; when set it is sent as an `Authorization: Bearer` header.
    api_key: Option<String>,
    /// When `true`, `add` creates the class schema and retries once after a
    /// class-not-found failure.
    auto_create_class: bool,
}
31
32impl fmt::Debug for WeaviateVectorStore {
33 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34 let api_key = if self.api_key.is_some() {
35 "<redacted>"
36 } else {
37 "<none>"
38 };
39
40 f.debug_struct("WeaviateVectorStore")
41 .field("base_url", &self.base_url)
42 .field("class_name", &self.class_name)
43 .field("api_key", &api_key)
44 .field("auto_create_class", &self.auto_create_class)
45 .finish()
46 }
47}
48
49impl WeaviateVectorStore {
50 pub fn builder() -> WeaviateStoreBuilder {
51 WeaviateStoreBuilder::new()
52 }
53
54 pub fn base_url(&self) -> &str {
55 &self.base_url
56 }
57
58 pub fn class_name(&self) -> &str {
59 &self.class_name
60 }
61
62 pub fn api_key(&self) -> Option<&str> {
63 self.api_key.as_deref()
64 }
65
66 pub async fn scored_search(
67 &self,
68 query_embedding: &[f32],
69 top_k: usize,
70 filter: Option<&MetadataFilter>,
71 ) -> Result<Vec<SearchResult>, StoreError> {
72 <Self as VectorStore>::search(self, query_embedding, top_k, filter).await
73 }
74
75 pub fn auto_create_class(&self) -> bool {
76 self.auto_create_class
77 }
78
79 fn endpoint(&self, path: &str) -> String {
80 format!(
81 "{}/{}",
82 self.base_url.trim_end_matches('/'),
83 path.trim_start_matches('/')
84 )
85 }
86
87 fn request_builder(&self, method: reqwest::Method, path: &str) -> reqwest::RequestBuilder {
88 let request = self.client.request(method, self.endpoint(path));
89 if let Some(api_key) = self.api_key() {
90 request.header("Authorization", format!("Bearer {api_key}"))
91 } else {
92 request
93 }
94 }
95
96 async fn send_json(
97 &self,
98 request: reqwest::RequestBuilder,
99 ) -> Result<JsonValue, WeaviateStoreError> {
100 let response = request.send().await.map_err(WeaviateStoreError::from)?;
101 let status = response.status();
102 let body = response.text().await.map_err(WeaviateStoreError::from)?;
103
104 if !status.is_success() {
105 return Err(self.http_error_from_response(status.as_u16(), &body));
106 }
107
108 if body.trim().is_empty() {
109 return Ok(JsonValue::Null);
110 }
111
112 serde_json::from_str(&body).map_err(|err| WeaviateStoreError::InvalidResponse {
113 message: format!("failed to decode weaviate response body: {err}"),
114 })
115 }
116
117 async fn create_class_schema(&self) -> Result<(), WeaviateStoreError> {
118 let schema = class_schema_request(&self.class_name);
119 match self
120 .send_json(
121 self.request_builder(reqwest::Method::POST, "v1/schema")
122 .json(&schema),
123 )
124 .await
125 {
126 Ok(_) => Ok(()),
127 Err(WeaviateStoreError::HttpStatus { status, message })
128 if (status == 409 || status == 422)
129 && is_class_already_exists_message(&message) =>
130 {
131 Ok(())
132 }
133 Err(err) => Err(err),
134 }
135 }
136
137 async fn add_once(&self, docs: Vec<Document>) -> Result<(), WeaviateStoreError> {
138 let mut objects = Vec::with_capacity(docs.len());
139 let mut expected_dimension: Option<usize> = None;
140 for doc in docs {
141 let object = doc_to_object(doc, &self.class_name)?;
142 match expected_dimension {
143 Some(expected) if expected != object.vector.len() => {
144 return Err(WeaviateStoreError::InvalidResponse {
145 message: format!(
146 "dimension mismatch in batch: expected {expected}, got {}",
147 object.vector.len()
148 ),
149 });
150 }
151 None => expected_dimension = Some(object.vector.len()),
152 _ => {}
153 }
154 objects.push(object);
155 }
156
157 for object in objects {
158 let request = self
159 .request_builder(reqwest::Method::POST, "v1/objects")
160 .json(&object);
161
162 let _ = self.send_json(request).await?;
163 }
164
165 Ok(())
166 }
167
168 fn http_error_from_response(&self, status: u16, body: &str) -> WeaviateStoreError {
169 let message = weaviate_error_message(body);
170 if is_class_not_found_message(&message) {
171 return WeaviateStoreError::ClassNotFound {
172 class_name: self.class_name.clone(),
173 message,
174 };
175 }
176
177 WeaviateStoreError::HttpStatus { status, message }
178 }
179}
180
181#[async_trait::async_trait]
182impl VectorStore for WeaviateVectorStore {
183 async fn add(&self, docs: Vec<Document>) -> Result<(), StoreError> {
184 if docs.is_empty() {
185 return Ok(());
186 }
187
188 match self.add_once(docs.clone()).await {
189 Ok(()) => Ok(()),
190 Err(WeaviateStoreError::ClassNotFound { .. }) if self.auto_create_class => {
191 self.create_class_schema().await.map_err(StoreError::from)?;
192 self.add_once(docs).await.map_err(StoreError::from)
193 }
194 Err(error) => Err(StoreError::from(error)),
195 }
196 }
197
198 async fn search(
199 &self,
200 query_embedding: &[f32],
201 top_k: usize,
202 filter: Option<&MetadataFilter>,
203 ) -> Result<Vec<SearchResult>, StoreError> {
204 if query_embedding.is_empty() || top_k == 0 {
205 return Ok(Vec::new());
206 }
207
208 let where_clause = filter
209 .map(to_weaviate_filter)
210 .transpose()
211 .map_err(StoreError::from)?;
212
213 let query = build_near_vector_query(
214 &self.class_name,
215 query_embedding,
216 top_k,
217 where_clause.as_deref(),
218 );
219 let response = self
220 .send_json(
221 self.request_builder(reqwest::Method::POST, "v1/graphql")
222 .json(&GraphQlRequest { query }),
223 )
224 .await
225 .map_err(StoreError::from)?;
226
227 let gql_response: GraphQlResponse = serde_json::from_value(response).map_err(|err| {
228 StoreError::from(WeaviateStoreError::InvalidResponse {
229 message: format!("failed to decode GraphQL envelope: {err}"),
230 })
231 })?;
232
233 if let Some(first_error) = gql_response.errors.first() {
234 let message = first_error.message.clone();
235 if is_class_not_found_message(&message) {
236 return Err(StoreError::from(WeaviateStoreError::ClassNotFound {
237 class_name: self.class_name.clone(),
238 message,
239 }));
240 }
241
242 return Err(StoreError::from(WeaviateStoreError::InvalidResponse {
243 message,
244 }));
245 }
246
247 let data = gql_response.data.ok_or_else(|| {
248 StoreError::from(WeaviateStoreError::InvalidResponse {
249 message: "missing GraphQL data in response".to_string(),
250 })
251 })?;
252
253 let mut results =
254 graphql_hits_to_results(data, &self.class_name).map_err(StoreError::from)?;
255 results.sort_by(|left, right| right.score.total_cmp(&left.score));
256 Ok(results)
257 }
258
259 async fn delete(&self, ids: &[String]) -> Result<(), StoreError> {
260 if ids.is_empty() {
261 return Ok(());
262 }
263
264 for id in ids {
265 if id.trim().is_empty() {
266 return Err(StoreError::InvalidId(id.clone()));
267 }
268
269 let encoded_id = urlencoding::encode(id);
270 let path = format!("v1/objects/{}/{}", self.class_name, encoded_id);
271 let _ = self
272 .send_json(self.request_builder(reqwest::Method::DELETE, &path))
273 .await
274 .map_err(StoreError::from)?;
275 }
276
277 Ok(())
278 }
279}
280
/// JSON shape that `weaviate_error_message` tries to decode an error response
/// body into: an object with an `error` list.
#[derive(Debug, Deserialize)]
struct WeaviateErrorEnvelope {
    /// Reported errors; an absent field deserializes to an empty list.
    #[serde(default)]
    error: Vec<WeaviateErrorMessage>,
}
286
/// One entry of the `error` list in a [`WeaviateErrorEnvelope`].
#[derive(Debug, Deserialize)]
struct WeaviateErrorMessage {
    /// Error text carried by this entry.
    message: String,
}
291
292fn weaviate_error_message(body: &str) -> String {
293 let trimmed = body.trim();
294 if trimmed.is_empty() {
295 return "unknown weaviate error".to_string();
296 }
297
298 serde_json::from_str::<WeaviateErrorEnvelope>(trimmed)
299 .ok()
300 .and_then(|envelope| envelope.error.into_iter().next().map(|entry| entry.message))
301 .unwrap_or_else(|| trimmed.to_string())
302}
303
/// Returns `true` when `message` contains both "class" and "not found",
/// compared case-insensitively and in any order.
fn is_class_not_found_message(message: &str) -> bool {
    let lowered = message.to_lowercase();
    ["class", "not found"]
        .iter()
        .all(|needle| lowered.contains(*needle))
}
308
/// Returns `true` when `message` mentions "class" together with
/// "already exist" / "already exists", compared case-insensitively.
fn is_class_already_exists_message(message: &str) -> bool {
    let lowered = message.to_lowercase();
    if !lowered.contains("class") {
        return false;
    }

    // "already exists" always contains "already exist", so a single substring
    // check covers both spellings.
    lowered.contains("already exist")
}