weaviate_community/
batch.rs

1use reqwest::Url;
2use std::error::Error;
3use std::sync::Arc;
4
5use crate::collections::{
6    batch::{BatchAddObjects, BatchAddReferencesResponse, BatchDeleteRequest, BatchDeleteResponse},
7    error::BatchError,
8    objects::{ConsistencyLevel, MultiObjects, References},
9};
10
11/// All batch related endpoints and functionality described in
12/// [Weaviate meta API documentation](https://weaviate.io/developers/weaviate/api/rest/batch)
13#[derive(Debug)]
14pub struct Batch {
15    endpoint: Url,
16    client: Arc<reqwest::Client>,
17}
18
19impl Batch {
20    pub(super) fn new(url: &Url, client: Arc<reqwest::Client>) -> Result<Self, Box<dyn Error>> {
21        let endpoint = url.join("/v1/batch/")?;
22        Ok(Batch { endpoint, client })
23    }
24
25    /// Batch add objects.
26    ///
27    /// # Parameters
28    /// - objects: the objects to add
29    /// - consistency_level: the consistency level to use
30    ///
31    /// # Example
32    /// ```rust
33    /// use uuid::Uuid;
34    /// use weaviate_community::WeaviateClient;
35    /// use weaviate_community::collections::objects::{Object, MultiObjects, ConsistencyLevel};
36    ///
37    /// #[tokio::main]
38    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
39    ///     let client = WeaviateClient::builder("http://localhost:8080").build()?;
40    ///
41    ///     let author_uuid = Uuid::parse_str("36ddd591-2dee-4e7e-a3cc-eb86d30a4303").unwrap();
42    ///     let article_a_uuid = Uuid::parse_str("6bb06a43-e7f0-393e-9ecf-3c0f4e129064").unwrap();
43    ///     let article_b_uuid = Uuid::parse_str("b72912b9-e5d7-304e-a654-66dc63c55b32").unwrap();
44    ///
45    ///     let article_a = Object::builder("Article", serde_json::json!({}))
46    ///         .with_id(article_a_uuid.clone())
47    ///         .build();
48    ///
49    ///     let article_b = Object::builder("Article", serde_json::json!({}))
50    ///         .with_id(article_b_uuid.clone())
51    ///         .build();
52    ///
53    ///     let author = Object::builder("Author", serde_json::json!({}))
54    ///         .with_id(author_uuid.clone())
55    ///         .build();
56    ///
57    ///     let res = client.batch.objects_batch_add(
58    ///         MultiObjects::new(vec![article_a, article_b, author]),
59    ///         Some(ConsistencyLevel::ALL),
60    ///         None
61    ///     ).await;
62    ///
63    ///     Ok(())
64    /// }
65    /// ```
66    pub async fn objects_batch_add(
67        &self,
68        objects: MultiObjects,
69        consistency_level: Option<ConsistencyLevel>,
70        tenant: Option<&str>,
71    ) -> Result<BatchAddObjects, Box<dyn Error>> {
72        let mut endpoint = self.endpoint.join("objects")?;
73        if let Some(x) = consistency_level {
74            endpoint
75                .query_pairs_mut()
76                .append_pair("consistency_level", x.value());
77        }
78
79        if let Some(t) = tenant {
80            endpoint.query_pairs_mut().append_pair("tenant", t);
81        }
82
83        let payload = serde_json::to_value(&objects)?;
84        let res = self.client.post(endpoint).json(&payload).send().await?;
85        match res.status() {
86            reqwest::StatusCode::OK => {
87                let res: BatchAddObjects = res.json().await?;
88                Ok(res)
89            }
90            _ => Err(Box::new(BatchError(format!(
91                "status code {} received.",
92                res.status()
93            )))),
94        }
95    }
96
97    /// Batch delete objects.
98    ///
99    /// # Parameters
100    /// - request_body: the config to use for deletion
101    /// - consistency_level: the consistency level to use
102    ///
103    /// # Example
104    /// ```rust
105    /// use uuid::Uuid;
106    /// use weaviate_community::WeaviateClient;
107    /// use weaviate_community::collections::objects::{Object, MultiObjects, ConsistencyLevel};
108    /// use weaviate_community::collections::batch::{BatchDeleteRequest, MatchConfig};
109    ///
110    /// #[tokio::main]
111    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
112    ///     let client = WeaviateClient::builder("http://localhost:8080").build()?;
113    ///     let req = BatchDeleteRequest::builder(
114    ///         MatchConfig::new(
115    ///             "Article",
116    ///             serde_json::json!({
117    ///                 "operator": "Like",
118    ///                 "path": ["id"],
119    ///                 "valueText": "*4*",
120    ///             })
121    ///         )
122    ///     ).build();
123    ///
124    ///     let res = client.batch.objects_batch_delete(
125    ///         req,
126    ///         Some(ConsistencyLevel::ALL),
127    ///         None
128    ///     ).await;
129    ///
130    ///     Ok(())
131    /// }
132    /// ```
133    pub async fn objects_batch_delete(
134        &self,
135        request_body: BatchDeleteRequest,
136        consistency_level: Option<ConsistencyLevel>,
137        tenant: Option<&str>,
138    ) -> Result<BatchDeleteResponse, Box<dyn Error>> {
139        let mut endpoint = self.endpoint.join("objects")?;
140        if let Some(x) = consistency_level {
141            endpoint
142                .query_pairs_mut()
143                .append_pair("consistency_level", x.value());
144        }
145
146        if let Some(t) = tenant {
147            endpoint.query_pairs_mut().append_pair("tenant", t);
148        }
149
150        let payload = serde_json::to_value(&request_body)?;
151        let res = self.client.delete(endpoint).json(&payload).send().await?;
152        match res.status() {
153            reqwest::StatusCode::OK => {
154                let res: BatchDeleteResponse = res.json().await?;
155                Ok(res)
156            }
157            _ => Err(Box::new(BatchError(format!(
158                "status code {} received.",
159                res.status()
160            )))),
161        }
162    }
163
164    /// Batch add references.
165    ///
166    /// Note that the consistency_level and tenant_name in the `Reference` items contained within
167    /// the `References` input bare no effect on this method and will be ignored.
168    ///
169    /// # Parameters
170    /// - references: the references to add
171    /// - consistency_level: the consistency level to use
172    ///
173    /// # Example
174    /// ```rust
175    /// use uuid::Uuid;
176    /// use weaviate_community::WeaviateClient;
177    /// use weaviate_community::collections::objects::{Reference, References, ConsistencyLevel};
178    ///
179    /// #[tokio::main]
180    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
181    ///     let client = WeaviateClient::builder("http://localhost:8080").build()?;
182    ///
183    ///     let author_uuid = Uuid::parse_str("36ddd591-2dee-4e7e-a3cc-eb86d30a4303").unwrap();
184    ///     let article_a_uuid = Uuid::parse_str("6bb06a43-e7f0-393e-9ecf-3c0f4e129064").unwrap();
185    ///     let article_b_uuid = Uuid::parse_str("b72912b9-e5d7-304e-a654-66dc63c55b32").unwrap();
186    ///
187    ///     let references = References::new(vec![
188    ///         Reference::new(
189    ///             "Author",
190    ///             &author_uuid,
191    ///             "wroteArticles",
192    ///             "Article",
193    ///             &article_a_uuid,
194    ///         ),
195    ///         Reference::new(
196    ///             "Author",
197    ///             &author_uuid,
198    ///             "wroteArticles",
199    ///             "Article",
200    ///             &article_b_uuid,
201    ///         ),
202    ///     ]);
203    ///
204    ///     let res = client.batch.references_batch_add(
205    ///         references,
206    ///         Some(ConsistencyLevel::ALL),
207    ///         None
208    ///     ).await;
209    ///
210    ///     Ok(())
211    /// }
212    /// ```
213    pub async fn references_batch_add(
214        &self,
215        references: References,
216        consistency_level: Option<ConsistencyLevel>,
217        tenant: Option<&str>,
218    ) -> Result<BatchAddReferencesResponse, Box<dyn Error>> {
219        let mut converted: Vec<serde_json::Value> = Vec::new();
220        for reference in references.0 {
221            let new_ref = serde_json::json!({
222                "from": format!(
223                    "weaviate://localhost/{}/{}/{}",
224                    reference.from_class_name,
225                    reference.from_uuid,
226                    reference.from_property_name
227                ),
228                "to": format!(
229                    "weaviate://localhost/{}/{}",
230                    reference.to_class_name,
231                    reference.to_uuid
232                ),
233            });
234            converted.push(new_ref);
235        }
236        let payload = serde_json::json!(converted);
237
238        let mut endpoint = self.endpoint.join("references")?;
239        if let Some(cl) = consistency_level {
240            endpoint
241                .query_pairs_mut()
242                .append_pair("consistency_level", &cl.value());
243        }
244
245        if let Some(t) = tenant {
246            endpoint.query_pairs_mut().append_pair("tenant", t);
247        }
248
249        let res = self.client.post(endpoint).json(&payload).send().await?;
250        match res.status() {
251            reqwest::StatusCode::OK => {
252                let res: BatchAddReferencesResponse = res.json().await?;
253                Ok(res)
254            }
255            _ => Err(Box::new(BatchError(format!(
256                "status code {} received.",
257                res.status()
258            )))),
259        }
260    }
261}
262
#[cfg(test)]
mod tests {
    use uuid::Uuid;

    use crate::collections::batch::{
        BatchAddObject, BatchDeleteRequest, BatchDeleteResponse, BatchDeleteResult, GeneralStatus,
        MatchConfig, ResultStatus,
    };
    use crate::collections::objects::{MultiObjects, Object, Reference, References};
    use crate::WeaviateClient;

    /// Start a mock HTTP server and build a client that talks to it.
    fn get_test_harness() -> (mockito::ServerGuard, WeaviateClient) {
        let mock_server = mockito::Server::new();
        let host = format!("http://{}", mock_server.host_with_port());
        let client = WeaviateClient::builder(&host).build().unwrap();
        (mock_server, client)
    }

    /// Properties shared by the object fixtures below.
    fn test_properties() -> serde_json::Value {
        serde_json::json!({
            "name": "test",
            "number": 123,
        })
    }

    /// Match configuration shared by the delete request and delete response fixtures.
    fn test_match_config() -> MatchConfig {
        // this will eventually be defined with the graphql stuff later on
        let filter = serde_json::json!({
            "operator": "NotEqual",
            "path": ["name"],
            "valueText": "aaa"
        });
        MatchConfig::new("Test", filter)
    }

    /// A single-object batch with a freshly generated id.
    fn test_create_objects() -> MultiObjects {
        let object = Object {
            class: "Test".into(),
            properties: test_properties(),
            id: Some(Uuid::new_v4()),
            vector: None,
            tenant: None,
            creation_time_unix: None,
            last_update_time_unix: None,
            vector_weights: None,
            additional: None,
        };
        MultiObjects {
            objects: vec![object],
        }
    }

    /// JSON body of a successful batch-add response containing one object.
    fn test_batch_add_object_response() -> String {
        let added = BatchAddObject {
            class: "Test".into(),
            properties: test_properties(),
            id: None,
            vector: None,
            tenant: None,
            creation_time_unix: None,
            last_update_time_unix: None,
            vector_weights: None,
            result: ResultStatus {
                status: GeneralStatus::SUCCESS,
            },
        };
        serde_json::to_string(&vec![added]).unwrap()
    }

    /// Batch delete request built from the shared match configuration.
    fn test_delete_objects() -> BatchDeleteRequest {
        BatchDeleteRequest::builder(test_match_config()).build()
    }

    /// Batch delete response mirroring the request produced by `test_delete_objects`.
    fn test_delete_response() -> BatchDeleteResponse {
        let results = BatchDeleteResult {
            matches: 0,
            limit: 1,
            successful: 1,
            failed: 0,
            objects: None,
        };
        BatchDeleteResponse {
            matches: test_match_config(),
            output: None,
            dry_run: None,
            results,
        }
    }

    /// Two references from the same source object to two different targets.
    fn test_references() -> References {
        let source = Uuid::parse_str("36ddd591-2dee-4e7e-a3cc-eb86d30a4303").unwrap();
        let first_target = Uuid::parse_str("6bb06a43-e7f0-393e-9ecf-3c0f4e129064").unwrap();
        let second_target = Uuid::parse_str("b72912b9-e5d7-304e-a654-66dc63c55b32").unwrap();
        References::new(vec![
            Reference::new("Test", &source, "testProp", "Other", &first_target),
            Reference::new("Test", &source, "testProp", "Other", &second_target),
        ])
    }

    /// JSON body of a batch-add-references response reporting one failed reference.
    fn test_add_references_response() -> String {
        let body = serde_json::json!([{
            "result": {
                "errors": {
                    "error": [
                        {
                            "message": "test"
                        }
                    ]
                },
                "status": "FAILED"
            }
        }]);
        serde_json::to_string(&body).unwrap()
    }

    /// Register a JSON mock for `method` on `endpoint` with the given status and body.
    fn mock_endpoint(
        server: &mut mockito::ServerGuard,
        method: &str,
        endpoint: &str,
        status_code: usize,
        body: &str,
    ) -> mockito::Mock {
        server
            .mock(method, endpoint)
            .with_status(status_code)
            .with_header("content-type", "application/json")
            .with_body(body)
            .create()
    }

    /// Register a `POST` mock on `endpoint`.
    fn mock_post(
        server: &mut mockito::ServerGuard,
        endpoint: &str,
        status_code: usize,
        body: &str,
    ) -> mockito::Mock {
        mock_endpoint(server, "POST", endpoint, status_code, body)
    }

    /// Register a `DELETE` mock on `endpoint`.
    fn mock_delete(
        server: &mut mockito::ServerGuard,
        endpoint: &str,
        status_code: usize,
        body: &str,
    ) -> mockito::Mock {
        mock_endpoint(server, "DELETE", endpoint, status_code, body)
    }

    #[tokio::test]
    async fn test_objects_batch_add_ok() {
        let payload = test_create_objects();
        let body = test_batch_add_object_response();
        let (mut server, client) = get_test_harness();
        let mock = mock_post(&mut server, "/v1/batch/objects", 200, &body);
        let outcome = client.batch.objects_batch_add(payload, None, None).await;
        mock.assert();
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_objects_batch_add_err() {
        let payload = test_create_objects();
        let (mut server, client) = get_test_harness();
        let mock = mock_post(&mut server, "/v1/batch/objects", 404, "");
        let outcome = client.batch.objects_batch_add(payload, None, None).await;
        mock.assert();
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_objects_batch_delete_ok() {
        let request = test_delete_objects();
        let body = serde_json::to_string(&test_delete_response()).unwrap();
        let (mut server, client) = get_test_harness();
        let mock = mock_delete(&mut server, "/v1/batch/objects", 200, &body);
        let outcome = client.batch.objects_batch_delete(request, None, None).await;
        mock.assert();
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_objects_batch_delete_err() {
        let request = test_delete_objects();
        let (mut server, client) = get_test_harness();
        let mock = mock_delete(&mut server, "/v1/batch/objects", 401, "");
        let outcome = client.batch.objects_batch_delete(request, None, None).await;
        mock.assert();
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_references_batch_add_ok() {
        let references = test_references();
        let body = test_add_references_response();
        let (mut server, client) = get_test_harness();
        let mock = mock_post(&mut server, "/v1/batch/references", 200, &body);
        let outcome = client
            .batch
            .references_batch_add(references, None, None)
            .await;
        mock.assert();
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_references_batch_add_err() {
        let references = test_references();
        let (mut server, client) = get_test_harness();
        let mock = mock_post(&mut server, "/v1/batch/references", 500, "");
        let outcome = client
            .batch
            .references_batch_add(references, None, None)
            .await;
        mock.assert();
        assert!(outcome.is_err());
    }
}