1use reqwest::Url;
2use std::error::Error;
3use std::sync::Arc;
4
5use crate::collections::{
6 batch::{BatchAddObjects, BatchAddReferencesResponse, BatchDeleteRequest, BatchDeleteResponse},
7 error::BatchError,
8 objects::{ConsistencyLevel, MultiObjects, References},
9};
10
11#[derive(Debug)]
14pub struct Batch {
15 endpoint: Url,
16 client: Arc<reqwest::Client>,
17}
18
19impl Batch {
20 pub(super) fn new(url: &Url, client: Arc<reqwest::Client>) -> Result<Self, Box<dyn Error>> {
21 let endpoint = url.join("/v1/batch/")?;
22 Ok(Batch { endpoint, client })
23 }
24
25 pub async fn objects_batch_add(
67 &self,
68 objects: MultiObjects,
69 consistency_level: Option<ConsistencyLevel>,
70 tenant: Option<&str>,
71 ) -> Result<BatchAddObjects, Box<dyn Error>> {
72 let mut endpoint = self.endpoint.join("objects")?;
73 if let Some(x) = consistency_level {
74 endpoint
75 .query_pairs_mut()
76 .append_pair("consistency_level", x.value());
77 }
78
79 if let Some(t) = tenant {
80 endpoint.query_pairs_mut().append_pair("tenant", t);
81 }
82
83 let payload = serde_json::to_value(&objects)?;
84 let res = self.client.post(endpoint).json(&payload).send().await?;
85 match res.status() {
86 reqwest::StatusCode::OK => {
87 let res: BatchAddObjects = res.json().await?;
88 Ok(res)
89 }
90 _ => Err(Box::new(BatchError(format!(
91 "status code {} received.",
92 res.status()
93 )))),
94 }
95 }
96
97 pub async fn objects_batch_delete(
134 &self,
135 request_body: BatchDeleteRequest,
136 consistency_level: Option<ConsistencyLevel>,
137 tenant: Option<&str>,
138 ) -> Result<BatchDeleteResponse, Box<dyn Error>> {
139 let mut endpoint = self.endpoint.join("objects")?;
140 if let Some(x) = consistency_level {
141 endpoint
142 .query_pairs_mut()
143 .append_pair("consistency_level", x.value());
144 }
145
146 if let Some(t) = tenant {
147 endpoint.query_pairs_mut().append_pair("tenant", t);
148 }
149
150 let payload = serde_json::to_value(&request_body)?;
151 let res = self.client.delete(endpoint).json(&payload).send().await?;
152 match res.status() {
153 reqwest::StatusCode::OK => {
154 let res: BatchDeleteResponse = res.json().await?;
155 Ok(res)
156 }
157 _ => Err(Box::new(BatchError(format!(
158 "status code {} received.",
159 res.status()
160 )))),
161 }
162 }
163
164 pub async fn references_batch_add(
214 &self,
215 references: References,
216 consistency_level: Option<ConsistencyLevel>,
217 tenant: Option<&str>,
218 ) -> Result<BatchAddReferencesResponse, Box<dyn Error>> {
219 let mut converted: Vec<serde_json::Value> = Vec::new();
220 for reference in references.0 {
221 let new_ref = serde_json::json!({
222 "from": format!(
223 "weaviate://localhost/{}/{}/{}",
224 reference.from_class_name,
225 reference.from_uuid,
226 reference.from_property_name
227 ),
228 "to": format!(
229 "weaviate://localhost/{}/{}",
230 reference.to_class_name,
231 reference.to_uuid
232 ),
233 });
234 converted.push(new_ref);
235 }
236 let payload = serde_json::json!(converted);
237
238 let mut endpoint = self.endpoint.join("references")?;
239 if let Some(cl) = consistency_level {
240 endpoint
241 .query_pairs_mut()
242 .append_pair("consistency_level", &cl.value());
243 }
244
245 if let Some(t) = tenant {
246 endpoint.query_pairs_mut().append_pair("tenant", t);
247 }
248
249 let res = self.client.post(endpoint).json(&payload).send().await?;
250 match res.status() {
251 reqwest::StatusCode::OK => {
252 let res: BatchAddReferencesResponse = res.json().await?;
253 Ok(res)
254 }
255 _ => Err(Box::new(BatchError(format!(
256 "status code {} received.",
257 res.status()
258 )))),
259 }
260 }
261}
262
#[cfg(test)]
mod tests {
    use uuid::Uuid;

    use crate::{
        collections::{
            batch::{
                BatchAddObject, BatchDeleteRequest, BatchDeleteResponse, BatchDeleteResult,
                GeneralStatus, MatchConfig, ResultStatus,
            },
            objects::{MultiObjects, Object, Reference, References},
        },
        WeaviateClient,
    };

    /// Starts a mockito server and builds a `WeaviateClient` pointed at it.
    fn get_test_harness() -> (mockito::ServerGuard, WeaviateClient) {
        let mock_server = mockito::Server::new();
        let host = format!("http://{}", mock_server.host_with_port());
        let client = WeaviateClient::builder(&host).build().unwrap();
        (mock_server, client)
    }

    /// Property payload shared by the request and response object fixtures.
    fn test_properties() -> serde_json::Value {
        serde_json::json!({
            "name": "test",
            "number": 123,
        })
    }

    /// Filter shared by the batch-delete request and response fixtures.
    fn test_delete_filter() -> serde_json::Value {
        serde_json::json!({
            "operator": "NotEqual",
            "path": ["name"],
            "valueText": "aaa"
        })
    }

    /// A single `Test` object with a random id, ready for a batch add.
    fn test_create_objects() -> MultiObjects {
        let object = Object {
            class: "Test".into(),
            properties: test_properties(),
            id: Some(Uuid::new_v4()),
            vector: None,
            tenant: None,
            creation_time_unix: None,
            last_update_time_unix: None,
            vector_weights: None,
            additional: None,
        };
        MultiObjects {
            objects: vec![object],
        }
    }

    /// JSON body of a successful batch-add response holding one object.
    fn test_batch_add_object_response() -> String {
        let added = vec![BatchAddObject {
            class: "Test".into(),
            properties: test_properties(),
            id: None,
            vector: None,
            tenant: None,
            creation_time_unix: None,
            last_update_time_unix: None,
            vector_weights: None,
            result: ResultStatus {
                status: GeneralStatus::SUCCESS,
            },
        }];
        serde_json::to_string(&added).unwrap()
    }

    /// Batch-delete request for `Test` objects matching the shared filter.
    fn test_delete_objects() -> BatchDeleteRequest {
        let matches = MatchConfig::new("Test", test_delete_filter());
        BatchDeleteRequest::builder(matches).build()
    }

    /// Batch-delete response echoing the shared filter.
    fn test_delete_response() -> BatchDeleteResponse {
        let results = BatchDeleteResult {
            matches: 0,
            limit: 1,
            successful: 1,
            failed: 0,
            objects: None,
        };
        BatchDeleteResponse {
            matches: MatchConfig::new("Test", test_delete_filter()),
            output: None,
            dry_run: None,
            results,
        }
    }

    /// Two references from the same `Test` object to two `Other` objects.
    fn test_references() -> References {
        let parse = |text: &str| Uuid::parse_str(text).unwrap();
        let source = parse("36ddd591-2dee-4e7e-a3cc-eb86d30a4303");
        let first_target = parse("6bb06a43-e7f0-393e-9ecf-3c0f4e129064");
        let second_target = parse("b72912b9-e5d7-304e-a654-66dc63c55b32");
        References::new(vec![
            Reference::new("Test", &source, "testProp", "Other", &first_target),
            Reference::new("Test", &source, "testProp", "Other", &second_target),
        ])
    }

    /// JSON body of a batch-references response reporting one failed reference.
    fn test_add_references_response() -> String {
        let body = serde_json::json!([{
            "result": {
                "errors": {
                    "error": [
                        {
                            "message": "test"
                        }
                    ]
                },
                "status": "FAILED"
            }
        }]);
        serde_json::to_string(&body).unwrap()
    }

    /// Registers a mock answering `method endpoint` with `status_code` and a JSON `body`.
    fn mock_json(
        server: &mut mockito::ServerGuard,
        method: &str,
        endpoint: &str,
        status_code: usize,
        body: &str,
    ) -> mockito::Mock {
        server
            .mock(method, endpoint)
            .with_status(status_code)
            .with_header("content-type", "application/json")
            .with_body(body)
            .create()
    }

    fn mock_post(
        server: &mut mockito::ServerGuard,
        endpoint: &str,
        status_code: usize,
        body: &str,
    ) -> mockito::Mock {
        mock_json(server, "POST", endpoint, status_code, body)
    }

    fn mock_delete(
        server: &mut mockito::ServerGuard,
        endpoint: &str,
        status_code: usize,
        body: &str,
    ) -> mockito::Mock {
        mock_json(server, "DELETE", endpoint, status_code, body)
    }

    #[tokio::test]
    async fn test_objects_batch_add_ok() {
        let (mut mock_server, client) = get_test_harness();
        let body = test_batch_add_object_response();
        let mock = mock_post(&mut mock_server, "/v1/batch/objects", 200, &body);
        let res = client
            .batch
            .objects_batch_add(test_create_objects(), None, None)
            .await;
        mock.assert();
        assert!(res.is_ok());
    }

    #[tokio::test]
    async fn test_objects_batch_add_err() {
        let (mut mock_server, client) = get_test_harness();
        let mock = mock_post(&mut mock_server, "/v1/batch/objects", 404, "");
        let res = client
            .batch
            .objects_batch_add(test_create_objects(), None, None)
            .await;
        mock.assert();
        assert!(res.is_err());
    }

    #[tokio::test]
    async fn test_objects_batch_delete_ok() {
        let (mut mock_server, client) = get_test_harness();
        let body = serde_json::to_string(&test_delete_response()).unwrap();
        let mock = mock_delete(&mut mock_server, "/v1/batch/objects", 200, &body);
        let res = client
            .batch
            .objects_batch_delete(test_delete_objects(), None, None)
            .await;
        mock.assert();
        assert!(res.is_ok());
    }

    #[tokio::test]
    async fn test_objects_batch_delete_err() {
        let (mut mock_server, client) = get_test_harness();
        let mock = mock_delete(&mut mock_server, "/v1/batch/objects", 401, "");
        let res = client
            .batch
            .objects_batch_delete(test_delete_objects(), None, None)
            .await;
        mock.assert();
        assert!(res.is_err());
    }

    #[tokio::test]
    async fn test_references_batch_add_ok() {
        let (mut mock_server, client) = get_test_harness();
        let body = test_add_references_response();
        let mock = mock_post(&mut mock_server, "/v1/batch/references", 200, &body);
        let res = client
            .batch
            .references_batch_add(test_references(), None, None)
            .await;
        mock.assert();
        assert!(res.is_ok());
    }

    #[tokio::test]
    async fn test_references_batch_add_err() {
        let (mut mock_server, client) = get_test_harness();
        let mock = mock_post(&mut mock_server, "/v1/batch/references", 500, "");
        let res = client
            .batch
            .references_batch_add(test_references(), None, None)
            .await;
        mock.assert();
        assert!(res.is_err());
    }
}
475}