1use crate::collections::{
2 error::GraphQLError,
3 query::{AggregateQuery, ExploreQuery, GetQuery, RawQuery},
4};
5use reqwest::Url;
6use std::error::Error;
7use std::sync::Arc;
8
9#[derive(Debug)]
12pub struct Query {
13 endpoint: Url,
14 client: Arc<reqwest::Client>,
15}
16
17impl Query {
18 pub(super) fn new(url: &Url, client: Arc<reqwest::Client>) -> Result<Self, Box<dyn Error>> {
21 let endpoint = url.join("/v1/graphql")?;
22 Ok(Query { endpoint, client })
23 }
24
25 pub async fn get(&self, query: GetQuery) -> Result<serde_json::Value, Box<dyn Error>> {
55 let payload = serde_json::to_value(query).unwrap();
56 let res = self
57 .client
58 .post(self.endpoint.clone())
59 .json(&payload)
60 .send()
61 .await?;
62 match res.status() {
63 reqwest::StatusCode::OK => {
64 let res = res.json::<serde_json::Value>().await?;
65 Ok(res)
66 }
67 _ => Err(Box::new(GraphQLError(format!(
68 "status code {} received when executing GraphQL Get.",
69 res.status()
70 )))),
71 }
72 }
73
74 pub async fn aggregate(
97 &self,
98 query: AggregateQuery,
99 ) -> Result<serde_json::Value, Box<dyn Error>> {
100 let payload = serde_json::to_value(query).unwrap();
101 let res = self
102 .client
103 .post(self.endpoint.clone())
104 .json(&payload)
105 .send()
106 .await?;
107 match res.status() {
108 reqwest::StatusCode::OK => {
109 let res = res.json::<serde_json::Value>().await?;
110 Ok(res)
111 }
112 _ => Err(Box::new(GraphQLError(format!(
113 "status code {} received when executing GraphQL Aggregate.",
114 res.status()
115 )))),
116 }
117 }
118
119 pub async fn explore(&self, query: ExploreQuery) -> Result<serde_json::Value, Box<dyn Error>> {
142 let payload = serde_json::to_value(query).unwrap();
143 let res = self
144 .client
145 .post(self.endpoint.clone())
146 .json(&payload)
147 .send()
148 .await?;
149 match res.status() {
150 reqwest::StatusCode::OK => {
151 let res = res.json::<serde_json::Value>().await?;
152 Ok(res)
153 }
154 _ => Err(Box::new(GraphQLError(format!(
155 "status code {} received when executing GraphQL Explore.",
156 res.status()
157 )))),
158 }
159 }
160
161 pub async fn raw(&self, query: RawQuery) -> Result<serde_json::Value, Box<dyn Error>> {
187 let payload = serde_json::to_value(query).unwrap();
188 let res = self
189 .client
190 .post(self.endpoint.clone())
191 .json(&payload)
192 .send()
193 .await?;
194 match res.status() {
195 reqwest::StatusCode::OK => {
196 let res = res.json::<serde_json::Value>().await?;
197 Ok(res)
198 }
199 _ => Err(Box::new(GraphQLError(format!(
200 "status code {} received when executing GraphQL raw query.",
201 res.status()
202 )))),
203 }
204 }
205}
206
#[cfg(test)]
mod tests {
    use crate::collections::query::RawQuery;
    use crate::collections::query::{AggregateBuilder, ExploreBuilder, GetBuilder};
    use crate::WeaviateClient;

    /// Start a mockito server and build a `WeaviateClient` pointed at it.
    ///
    /// The server guard must be kept alive for the duration of the test.
    fn get_test_harness() -> (mockito::ServerGuard, WeaviateClient) {
        let mock_server = mockito::Server::new();
        let host = format!("http://{}", mock_server.host_with_port());
        let client = WeaviateClient::builder(&host).build().unwrap();
        (mock_server, client)
    }

    /// Register a POST mock on `endpoint` replying with `status_code` and a JSON `body`.
    fn mock_post(
        server: &mut mockito::ServerGuard,
        endpoint: &str,
        status_code: usize,
        body: &str,
    ) -> mockito::Mock {
        server
            .mock("POST", endpoint)
            .with_status(status_code)
            .with_header("content-type", "application/json")
            .with_body(body)
            .create()
    }

    /// Canned `Get` response containing a single `JeopardyQuestion`.
    fn test_get_response() -> String {
        serde_json::json!({
            "data": {
                "Get": {
                    "JeopardyQuestion": [
                        {
                            "answer": "Jonah",
                            "points": 100,
                            "question": "This prophet passed the time he spent inside a fish offering up prayers"
                        },
                    ]
                }
            }
        })
        .to_string()
    }

    /// Canned `Aggregate` response containing a single `Article` aggregation.
    fn test_aggregate_response() -> String {
        serde_json::json!({
            "data": {
                "Aggregate": {
                    "Article": [
                        {
                            "inPublication": {
                                "pointingTo": [
                                    "Publication"
                                ],
                                "type": "cref"
                            },
                            "meta": {
                                "count": 4403
                            },
                            "wordCount": {
                                "count": 4403,
                                "maximum": 16852,
                                "mean": 966.0113558937088,
                                "median": 680,
                                "minimum": 109,
                                "mode": 575,
                                "sum": 4253348,
                                "type": "int"
                            }
                        }
                    ]
                }
            }
        })
        .to_string()
    }

    /// Canned `Explore` response containing a single `Publication` beacon.
    fn test_explore_response() -> String {
        serde_json::json!({
            "data": {
                "Explore": [
                    {
                        "beacon": "weaviate://localhost/7e9b9ffe-e645-302d-9d94-517670623b35",
                        "certainty": 0.975523,
                        "className": "Publication"
                    }
                ]
            },
            "errors": null
        })
        .to_string()
    }

    #[tokio::test]
    async fn test_get_query_ok() {
        let (mut mock_server, client) = get_test_harness();
        let mock = mock_post(&mut mock_server, "/v1/graphql", 200, &test_get_response());
        let query = GetBuilder::new(
            "JeopardyQuestion",
            vec![
                "question",
                "answer",
                "points",
                "hasCategory { ... on JeopardyCategory { title }}",
            ],
        )
        .with_limit(1)
        .with_additional(vec!["id"])
        .build();
        let res = client.query.get(query).await;
        mock.assert();
        let body = res.expect("GraphQL Get should succeed on a 200 response");
        assert_eq!(
            body["data"]["Get"]["JeopardyQuestion"]
                .as_array()
                .unwrap()
                .len(),
            1
        );
    }

    #[tokio::test]
    async fn test_get_query_err() {
        let (mut mock_server, client) = get_test_harness();
        let mock = mock_post(&mut mock_server, "/v1/graphql", 422, "");
        let query = GetBuilder::new(
            "JeopardyQuestion",
            vec![
                "question",
                "answer",
                "points",
                "hasCategory { ... on JeopardyCategory { title }}",
            ],
        )
        .with_limit(1)
        .with_additional(vec!["id"])
        .build();
        let res = client.query.get(query).await;
        mock.assert();
        assert!(res.is_err());
    }

    #[tokio::test]
    async fn test_aggregate_query_ok() {
        let (mut mock_server, client) = get_test_harness();
        let mock = mock_post(
            &mut mock_server,
            "/v1/graphql",
            200,
            &test_aggregate_response(),
        );
        let query = AggregateBuilder::new("Article")
            .with_meta_count()
            .with_fields(vec![
                "wordCount {count maximum mean median minimum mode sum type}",
            ])
            .build();
        let res = client.query.aggregate(query).await;
        mock.assert();
        let body = res.expect("GraphQL Aggregate should succeed on a 200 response");
        assert_eq!(
            body["data"]["Aggregate"]["Article"]
                .as_array()
                .unwrap()
                .len(),
            1
        );
    }

    #[tokio::test]
    async fn test_aggregate_query_err() {
        let (mut mock_server, client) = get_test_harness();
        let mock = mock_post(&mut mock_server, "/v1/graphql", 422, "");
        let query = AggregateBuilder::new("JeopardyQuestion").build();
        let res = client.query.aggregate(query).await;
        mock.assert();
        assert!(res.is_err());
    }

    #[tokio::test]
    async fn test_explore_query_ok() {
        let (mut mock_server, client) = get_test_harness();
        let mock = mock_post(
            &mut mock_server,
            "/v1/graphql",
            200,
            &test_explore_response(),
        );
        let query = ExploreBuilder::new()
            .with_limit(1)
            .with_near_vector("{vector: [-0.36840257,0.13973749,-0.28994447]}")
            .with_fields(vec!["className"])
            .build();
        let res = client.query.explore(query).await;
        mock.assert();
        let body = res.expect("GraphQL Explore should succeed on a 200 response");
        let results = body["data"]["Explore"].as_array().unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0]["className"], "Publication");
    }

    #[tokio::test]
    async fn test_explore_query_err() {
        let (mut mock_server, client) = get_test_harness();
        let mock = mock_post(&mut mock_server, "/v1/graphql", 422, "");
        let query = ExploreBuilder::new().build();
        let res = client.query.explore(query).await;
        mock.assert();
        assert!(res.is_err());
    }

    #[tokio::test]
    async fn test_raw_query_ok() {
        let (mut mock_server, client) = get_test_harness();
        let mock = mock_post(&mut mock_server, "/v1/graphql", 200, &test_get_response());
        let query = RawQuery::new("{ Get { JeopardyQuestion { question answer points } } }");
        let res = client.query.raw(query).await;
        mock.assert();
        let body = res.expect("raw GraphQL query should succeed on a 200 response");
        assert_eq!(
            body["data"]["Get"]["JeopardyQuestion"]
                .as_array()
                .unwrap()
                .len(),
            1
        );
    }

    #[tokio::test]
    async fn test_raw_query_err() {
        let (mut mock_server, client) = get_test_harness();
        let mock = mock_post(&mut mock_server, "/v1/graphql", 422, "");
        let query = RawQuery::new("{ Get { JeopardyQuestion { question answer points } } }");
        let res = client.query.raw(query).await;
        mock.assert();
        assert!(res.is_err());
    }
}