weaviate_community/
query.rs

1use crate::collections::{
2    error::GraphQLError,
3    query::{AggregateQuery, ExploreQuery, GetQuery, RawQuery},
4};
5use reqwest::Url;
6use std::error::Error;
7use std::sync::Arc;
8
9/// All GraphQL related endpoints and functionality described in
10/// [Weaviate GraphQL API documentation](https://weaviate.io/developers/weaviate/api/graphql)
11#[derive(Debug)]
12pub struct Query {
13    endpoint: Url,
14    client: Arc<reqwest::Client>,
15}
16
17impl Query {
18    /// Create a new Query object. The query object is intended to like inside the WeaviateClient
19    /// and be called through the WeaviateClient.
20    pub(super) fn new(url: &Url, client: Arc<reqwest::Client>) -> Result<Self, Box<dyn Error>> {
21        let endpoint = url.join("/v1/graphql")?;
22        Ok(Query { endpoint, client })
23    }
24
25    /// Execute the Get{} GraphQL query
26    ///
27    /// # Parameters
28    /// - query: the query to execute
29    ///
30    /// # Example
31    /// ```no_run
32    /// use weaviate_community::WeaviateClient;
33    /// use weaviate_community::collections::query::GetBuilder;
34    ///
35    /// #[tokio::main]
36    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
37    ///     let client = WeaviateClient::builder("http://localhost:8080").build()?;
38    ///     let query = GetBuilder::new(
39    ///         "JeopardyQuestion",
40    ///         vec![
41    ///             "question",
42    ///             "answer",
43    ///             "points",
44    ///             "hasCategory { ... on JeopardyCategory { title }}"
45    ///         ])
46    ///         .with_limit(1)
47    ///         .with_additional(vec!["id"])
48    ///         .build();
49    ///     let res = client.query.get(query).await;
50    ///
51    ///     Ok(())
52    /// }
53    /// ```
54    pub async fn get(&self, query: GetQuery) -> Result<serde_json::Value, Box<dyn Error>> {
55        let payload = serde_json::to_value(query).unwrap();
56        let res = self
57            .client
58            .post(self.endpoint.clone())
59            .json(&payload)
60            .send()
61            .await?;
62        match res.status() {
63            reqwest::StatusCode::OK => {
64                let res = res.json::<serde_json::Value>().await?;
65                Ok(res)
66            }
67            _ => Err(Box::new(GraphQLError(format!(
68                "status code {} received when executing GraphQL Get.",
69                res.status()
70            )))),
71        }
72    }
73
74    /// Execute the Aggregate{} GraphQL query
75    ///
76    ///
77    /// # Parameters
78    /// - query: the query to execute
79    ///
80    /// # Example
81    /// ```no_run
82    /// use weaviate_community::WeaviateClient;
83    /// use weaviate_community::collections::query::AggregateBuilder;
84    ///
85    /// #[tokio::main]
86    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
87    ///     let client = WeaviateClient::builder("http://localhost:8080").build()?;
88    ///     let query = AggregateBuilder::new("Article")
89    ///         .with_meta_count()
90    ///         .with_fields(vec!["wordCount {count maximum mean median minimum mode sum type}"])
91    ///         .build();
92    ///     let res = client.query.aggregate(query).await;
93    ///     Ok(())
94    /// }
95    /// ```
96    pub async fn aggregate(
97        &self,
98        query: AggregateQuery,
99    ) -> Result<serde_json::Value, Box<dyn Error>> {
100        let payload = serde_json::to_value(query).unwrap();
101        let res = self
102            .client
103            .post(self.endpoint.clone())
104            .json(&payload)
105            .send()
106            .await?;
107        match res.status() {
108            reqwest::StatusCode::OK => {
109                let res = res.json::<serde_json::Value>().await?;
110                Ok(res)
111            }
112            _ => Err(Box::new(GraphQLError(format!(
113                "status code {} received when executing GraphQL Aggregate.",
114                res.status()
115            )))),
116        }
117    }
118
119    /// Execute the Explore{} GraphQL query
120    ///
121    /// # Parameters
122    /// - query: the query to execute
123    ///
124    /// # Example
125    /// ```no_run
126    /// use weaviate_community::WeaviateClient;
127    /// use weaviate_community::collections::query::ExploreBuilder;
128    ///
129    /// #[tokio::main]
130    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
131    ///     let client = WeaviateClient::builder("http://localhost:8080").build()?;
132    ///     let query = ExploreBuilder::new()
133    ///         .with_limit(1)
134    ///         .with_near_vector("{vector: [-0.36840257,0.13973749,-0.28994447]}")
135    ///         .with_fields(vec!["className"])
136    ///         .build();
137    ///     let res = client.query.explore(query).await;
138    ///     Ok(())
139    /// }
140    /// ```
141    pub async fn explore(&self, query: ExploreQuery) -> Result<serde_json::Value, Box<dyn Error>> {
142        let payload = serde_json::to_value(query).unwrap();
143        let res = self
144            .client
145            .post(self.endpoint.clone())
146            .json(&payload)
147            .send()
148            .await?;
149        match res.status() {
150            reqwest::StatusCode::OK => {
151                let res = res.json::<serde_json::Value>().await?;
152                Ok(res)
153            }
154            _ => Err(Box::new(GraphQLError(format!(
155                "status code {} received when executing GraphQL Explore.",
156                res.status()
157            )))),
158        }
159    }
160
161    /// Execute a raw GraphQL query.
162    ///
163    /// This method has been implemented to allow you to run your own query that doesn't fit in
164    /// with the format that is set out in this crate.
165    ///
166    /// If there is a query that you think should be added, please open up a new feature request on
167    /// GitHub.
168    ///
169    /// # Parameters
170    /// - query: the query to execute
171    ///
172    /// # Example
173    /// ```no_run
174    /// use weaviate_community::WeaviateClient;
175    /// use weaviate_community::collections::query::RawQuery;
176    ///
177    /// #[tokio::main]
178    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
179    ///     let client = WeaviateClient::builder("http://localhost:8080").build()?;
180    ///     let query = RawQuery::new("{Get{JeopardyQuestion{question answer points}}}");
181    ///     let res = client.query.raw(query).await;
182    ///     Ok(())
183    ///
184    /// }
185    /// ```
186    pub async fn raw(&self, query: RawQuery) -> Result<serde_json::Value, Box<dyn Error>> {
187        let payload = serde_json::to_value(query).unwrap();
188        let res = self
189            .client
190            .post(self.endpoint.clone())
191            .json(&payload)
192            .send()
193            .await?;
194        match res.status() {
195            reqwest::StatusCode::OK => {
196                let res = res.json::<serde_json::Value>().await?;
197                Ok(res)
198            }
199            _ => Err(Box::new(GraphQLError(format!(
200                "status code {} received when executing GraphQL raw query.",
201                res.status()
202            )))),
203        }
204    }
205}
206
#[cfg(test)]
mod tests {
    use crate::collections::query::RawQuery;
    use crate::collections::query::{AggregateBuilder, ExploreBuilder, GetBuilder};
    use crate::WeaviateClient;

    /// Start a mock server and build a `WeaviateClient` that targets it.
    fn get_test_harness() -> (mockito::ServerGuard, WeaviateClient) {
        let mock_server = mockito::Server::new();
        let host = format!("http://{}", mock_server.host_with_port());
        let client = WeaviateClient::builder(&host).build().unwrap();
        (mock_server, client)
    }

    /// Register a POST mock on `endpoint` answering with `status_code` and a JSON `body`.
    fn mock_post(
        server: &mut mockito::ServerGuard,
        endpoint: &str,
        status_code: usize,
        body: &str,
    ) -> mockito::Mock {
        server
            .mock("POST", endpoint)
            .with_status(status_code)
            .with_header("content-type", "application/json")
            .with_body(body)
            .create()
    }

    /// Canned response for a Get{} query returning one JeopardyQuestion.
    fn test_get_response() -> String {
        serde_json::json!({
            "data": {
                "Get": {
                    "JeopardyQuestion": [
                        {
                            "answer": "Jonah",
                            "points": 100,
                            "question": "This prophet passed the time he spent inside a fish offering up prayers"
                        },
                    ]
                }
            }
        })
        .to_string()
    }

    /// Canned response for an Aggregate{} query over Article.
    fn test_aggregate_response() -> String {
        serde_json::json!({
          "data": {
            "Aggregate": {
              "Article": [
                {
                  "inPublication": {
                    "pointingTo": [
                      "Publication"
                    ],
                    "type": "cref"
                  },
                  "meta": {
                    "count": 4403
                  },
                  "wordCount": {
                    "count": 4403,
                    "maximum": 16852,
                    "mean": 966.0113558937088,
                    "median": 680,
                    "minimum": 109,
                    "mode": 575,
                    "sum": 4253348,
                    "type": "int"
                  }
                }
              ]
            }
          }
        })
        .to_string()
    }

    /// Canned response for an Explore{} query returning one hit.
    fn test_explore_response() -> String {
        serde_json::json!({
          "data": {
            "Explore": [
              {
                "beacon": "weaviate://localhost/7e9b9ffe-e645-302d-9d94-517670623b35",
                "certainty": 0.975523,
                "className": "Publication"
              }
            ]
          },
          "errors": null
        })
        .to_string()
    }

    #[tokio::test]
    async fn test_get_query_ok() {
        let (mut mock_server, client) = get_test_harness();
        let mock = mock_post(&mut mock_server, "/v1/graphql", 200, &test_get_response());
        let query = GetBuilder::new(
            "JeopardyQuestion",
            vec![
                "question",
                "answer",
                "points",
                "hasCategory { ... on JeopardyCategory { title }}",
            ],
        )
        .with_limit(1)
        .with_additional(vec!["id"])
        .build();
        let res = client.query.get(query).await;
        mock.assert();
        assert!(res.is_ok());
        assert_eq!(
            res.unwrap()["data"]["Get"]["JeopardyQuestion"]
                .as_array()
                .unwrap()
                .len(),
            1
        );
    }

    #[tokio::test]
    async fn test_get_query_err() {
        let (mut mock_server, client) = get_test_harness();
        let mock = mock_post(&mut mock_server, "/v1/graphql", 422, "");
        let query = GetBuilder::new(
            "JeopardyQuestion",
            vec![
                "question",
                "answer",
                "points",
                "hasCategory { ... on JeopardyCategory { title }}",
            ],
        )
        .with_limit(1)
        .with_additional(vec!["id"])
        .build();
        let res = client.query.get(query).await;
        mock.assert();
        assert!(res.is_err());
    }

    #[tokio::test]
    async fn test_aggregate_query_ok() {
        let (mut mock_server, client) = get_test_harness();
        let mock = mock_post(
            &mut mock_server,
            "/v1/graphql",
            200,
            &test_aggregate_response(),
        );
        let query = AggregateBuilder::new("Article")
            .with_meta_count()
            .with_fields(vec![
                "wordCount {count maximum mean median minimum mode sum type}",
            ])
            .build();
        let res = client.query.aggregate(query).await;
        mock.assert();
        assert!(res.is_ok());
        assert_eq!(
            res.unwrap()["data"]["Aggregate"]["Article"]
                .as_array()
                .unwrap()
                .len(),
            1
        );
    }

    #[tokio::test]
    async fn test_aggregate_query_err() {
        let (mut mock_server, client) = get_test_harness();
        let mock = mock_post(&mut mock_server, "/v1/graphql", 422, "");
        let query = AggregateBuilder::new("JeopardyQuestion").build();
        let res = client.query.aggregate(query).await;
        mock.assert();
        assert!(res.is_err());
    }

    #[tokio::test]
    async fn test_explore_query_ok() {
        let (mut mock_server, client) = get_test_harness();
        let mock = mock_post(
            &mut mock_server,
            "/v1/graphql",
            200,
            &test_explore_response(),
        );
        let query = ExploreBuilder::new()
            .with_limit(1)
            .with_near_vector("{vector: [-0.36840257,0.13973749,-0.28994447]}")
            .with_fields(vec!["className"])
            .build();
        let res = client.query.explore(query).await;
        mock.assert();
        assert!(res.is_ok());
        // Check the decoded body too, matching the other `_ok` tests.
        assert_eq!(
            res.unwrap()["data"]["Explore"].as_array().unwrap().len(),
            1
        );
    }

    #[tokio::test]
    async fn test_explore_query_err() {
        let (mut mock_server, client) = get_test_harness();
        let mock = mock_post(&mut mock_server, "/v1/graphql", 422, "");
        let query = ExploreBuilder::new().build();
        let res = client.query.explore(query).await;
        mock.assert();
        assert!(res.is_err());
    }

    #[tokio::test]
    async fn test_raw_query_ok() {
        let (mut mock_server, client) = get_test_harness();
        let mock = mock_post(&mut mock_server, "/v1/graphql", 200, &test_get_response());
        let query = RawQuery::new("{ Get { JeopardyQuestion { question answer points } } }");
        let res = client.query.raw(query).await;
        mock.assert();
        assert!(res.is_ok());
        assert_eq!(
            res.unwrap()["data"]["Get"]["JeopardyQuestion"]
                .as_array()
                .unwrap()
                .len(),
            1
        );
    }

    #[tokio::test]
    async fn test_raw_query_err() {
        let (mut mock_server, client) = get_test_harness();
        let mock = mock_post(&mut mock_server, "/v1/graphql", 422, "");
        let query = RawQuery::new("{ Get { JeopardyQuestion { question answer points } } }");
        let res = client.query.raw(query).await;
        mock.assert();
        assert!(res.is_err());
    }
}