1use crate::{Error, GetItems, Item, ItemCollection, Items, Result, Search, UrlBuilder};
4use async_stream::try_stream;
5use futures::{Stream, StreamExt, pin_mut};
6use http::header::{HeaderName, USER_AGENT};
7use reqwest::{ClientBuilder, IntoUrl, Method, StatusCode, header::HeaderMap};
8use serde::{Serialize, de::DeserializeOwned};
9use serde_json::{Map, Value};
10use stac::{Collection, Link, Links, SelfHref};
11use std::pin::Pin;
12use tokio::{
13 runtime::{Builder, Runtime},
14 sync::mpsc::{self, error::SendError},
15 task::JoinHandle,
16};
17
/// Default capacity of the channel that buffers fetched pages between the
/// background paging task and the consumer of an item stream.
const DEFAULT_CHANNEL_BUFFER: usize = 4;
19
20pub async fn search(
22 href: &str,
23 mut search: Search,
24 max_items: Option<usize>,
25) -> Result<ItemCollection> {
26 let client = Client::new(href)?;
27 if search.limit.is_none() {
28 if let Some(max_items) = max_items {
29 search.limit = Some(max_items.try_into()?);
30 }
31 }
32 let stream = client.search(search).await?;
33 let mut items = if let Some(max_items) = max_items {
34 if max_items == 0 {
35 return Ok(ItemCollection::default());
36 }
37 Vec::with_capacity(max_items)
38 } else {
39 Vec::new()
40 };
41 pin_mut!(stream);
42 while let Some(item) = stream.next().await {
43 let item = item?;
44 items.push(item);
45 if let Some(max_items) = max_items {
46 if items.len() >= max_items {
47 break;
48 }
49 }
50 }
51 ItemCollection::new(items)
52}
53
/// An asynchronous client for a STAC API.
///
/// Wraps a [`reqwest::Client`] together with a [`UrlBuilder`] that produces
/// the endpoint URLs (collections, items, search) for one API root.
#[derive(Clone, Debug)]
pub struct Client {
    /// HTTP client used for every request.
    client: reqwest::Client,
    /// Capacity of the page channel created for each item stream.
    channel_buffer: usize,
    /// Builds endpoint URLs relative to the API root.
    url_builder: UrlBuilder,
}
61
/// A blocking wrapper around [`Client`].
///
/// Async calls are driven to completion on a Tokio runtime owned by the
/// returned iterator, so this can be used from synchronous code.
#[derive(Debug)]
pub struct BlockingClient(Client);
65
/// A blocking iterator over the items returned by a search.
///
/// Created by [`BlockingClient::search`]; each call to `next` blocks on the
/// owned runtime until the underlying item stream yields.
// `Debug` cannot be derived: the boxed `dyn Stream` field does not implement it.
#[allow(missing_debug_implementations)]
pub struct BlockingIterator {
    /// Runtime used to poll `stream`.
    runtime: Runtime,
    /// The asynchronous stream of search results.
    stream: Pin<Box<dyn Stream<Item = Result<Item>>>>,
}
72
73impl Client {
74 pub fn new(url: &str) -> Result<Client> {
83 let mut headers = HeaderMap::new();
85 let _ = headers.insert(
86 USER_AGENT,
87 format!("rustac/{}", env!("CARGO_PKG_VERSION")).parse()?,
88 );
89 let client = ClientBuilder::new().default_headers(headers).build()?;
90 Client::with_client(client, url)
91 }
92
93 pub fn with_client(client: reqwest::Client, url: &str) -> Result<Client> {
107 Ok(Client {
108 client,
109 channel_buffer: DEFAULT_CHANNEL_BUFFER,
110 url_builder: UrlBuilder::new(url)?,
111 })
112 }
113
114 pub async fn collection(&self, id: &str) -> Result<Option<Collection>> {
126 let url = self.url_builder.collection(id)?;
127 not_found_to_none(self.get(url).await)
128 }
129
130 pub async fn items(
159 &self,
160 id: &str,
161 items: impl Into<Option<Items>>,
162 ) -> Result<impl Stream<Item = Result<Item>>> {
163 let url = self.url_builder.items(id)?; let items = match items.into() {
165 Some(items) => Some(GetItems::try_from(items)?),
166 _ => None,
167 };
168 let page = self
169 .request(Method::GET, url.clone(), items.as_ref(), None)
170 .await?;
171 Ok(stream_items(self.clone(), page, self.channel_buffer))
172 }
173
174 pub async fn search(&self, search: Search) -> Result<impl Stream<Item = Result<Item>> + use<>> {
197 let url = self.url_builder.search().clone();
198 tracing::debug!("searching {url}");
199 let page = self.post(url.clone(), &search).await?;
201 Ok(stream_items(self.clone(), page, self.channel_buffer))
202 }
203
204 async fn get<V>(&self, url: impl IntoUrl) -> Result<V>
205 where
206 V: DeserializeOwned + SelfHref,
207 {
208 let url = url.into_url()?;
209 let mut value = self
210 .request::<(), V>(Method::GET, url.clone(), None, None)
211 .await?;
212 value.set_self_href(url);
213 Ok(value)
214 }
215
216 async fn post<S, R>(&self, url: impl IntoUrl, data: &S) -> Result<R>
217 where
218 S: Serialize + 'static,
219 R: DeserializeOwned,
220 {
221 self.request(Method::POST, url, Some(data), None).await
222 }
223
224 async fn request<S, R>(
225 &self,
226 method: Method,
227 url: impl IntoUrl,
228 params: impl Into<Option<&S>>,
229 headers: impl Into<Option<HeaderMap>>,
230 ) -> Result<R>
231 where
232 S: Serialize + 'static,
233 R: DeserializeOwned,
234 {
235 let url = url.into_url()?;
236 let mut request = match method {
237 Method::GET => {
238 let mut request = self.client.get(url);
239 if let Some(query) = params.into() {
240 request = request.query(query);
241 }
242 request
243 }
244 Method::POST => {
245 let mut request = self.client.post(url);
246 if let Some(data) = params.into() {
247 request = request.json(&data);
248 }
249 request
250 }
251 _ => unimplemented!(),
252 };
253 if let Some(headers) = headers.into() {
254 request = request.headers(headers);
255 }
256 let response = request.send().await?.error_for_status()?;
257 response.json().await.map_err(Error::from)
258 }
259
260 async fn request_from_link<R>(&self, link: Link) -> Result<R>
261 where
262 R: DeserializeOwned,
263 {
264 let method = if let Some(method) = link.method {
265 method.parse()?
266 } else {
267 Method::GET
268 };
269 let headers = if let Some(headers) = link.headers {
270 let mut header_map = HeaderMap::new();
271 for (key, value) in headers.into_iter() {
272 let header_name: HeaderName = key.parse()?;
273 let _ = header_map.insert(header_name, value.to_string().parse()?);
274 }
275 Some(header_map)
276 } else {
277 None
278 };
279 self.request::<Map<String, Value>, R>(method, link.href.as_str(), &link.body, headers)
280 .await
281 }
282}
283
284impl BlockingClient {
285 pub fn new(url: &str) -> Result<BlockingClient> {
295 Client::new(url).map(Self)
296 }
297
298 pub fn search(&self, search: Search) -> Result<BlockingIterator> {
318 let runtime = Builder::new_current_thread().enable_all().build()?;
319 let stream = runtime.block_on(async move { self.0.search(search).await })?;
320 Ok(BlockingIterator {
321 runtime,
322 stream: Box::pin(stream),
323 })
324 }
325}
326
327impl Iterator for BlockingIterator {
328 type Item = Result<Item>;
329
330 fn next(&mut self) -> Option<Self::Item> {
331 self.runtime.block_on(self.stream.next())
332 }
333}
334
335fn stream_items(
336 client: Client,
337 page: ItemCollection,
338 channel_buffer: usize,
339) -> impl Stream<Item = Result<Item>> {
340 let (tx, mut rx) = mpsc::channel(channel_buffer);
341 let handle: JoinHandle<std::result::Result<(), SendError<_>>> = tokio::spawn(async move {
342 let pages = stream_pages(client, page);
343 pin_mut!(pages);
344 while let Some(result) = pages.next().await {
345 match result {
346 Ok(page) => tx.send(Ok(page)).await?,
347 Err(err) => {
348 tx.send(Err(err)).await?;
349 return Ok(());
350 }
351 }
352 }
353 Ok(())
354 });
355 try_stream! {
356 while let Some(result) = rx.recv().await {
357 let page = result?;
358 for item in page.items {
359 yield item;
360 }
361 }
362 let _ = handle.await?;
363 }
364}
365
366fn stream_pages(
367 client: Client,
368 mut page: ItemCollection,
369) -> impl Stream<Item = Result<ItemCollection>> {
370 try_stream! {
371 loop {
372 if page.items.is_empty() {
373 break;
374 }
375 let next_link = page.link("next").cloned();
376 yield page;
377 if let Some(next_link) = next_link {
378 if let Some(next_page) = client.request_from_link(next_link).await? {
379 page = next_page;
380 } else {
381 break;
382 }
383 } else {
384 break;
385 }
386 }
387 }
388}
389
390fn not_found_to_none<T>(result: Result<T>) -> Result<Option<T>> {
391 let mut result = result.map(Some);
392 if let Err(Error::Reqwest(ref err)) = result {
393 if err
394 .status()
395 .map(|s| s == StatusCode::NOT_FOUND)
396 .unwrap_or_default()
397 {
398 result = Ok(None);
399 }
400 }
401 result
402}
403
#[cfg(test)]
mod tests {
    use super::Client;
    use crate::{ItemCollection, Items, Search};
    use futures::StreamExt;
    use mockito::{Matcher, Server};
    use serde_json::json;
    use stac::Links;
    use url::Url;

    /// A `404` from the collection endpoint is reported as `Ok(None)`.
    #[tokio::test]
    async fn collection_not_found() {
        let mut server = Server::new_async().await;
        let collection = server
            .mock("GET", "/collections/not-a-collection")
            .with_body(include_str!("../mocks/not-a-collection.json"))
            .with_header("content-type", "application/json")
            .with_status(404)
            .create_async()
            .await;

        let client = Client::new(&server.url()).unwrap();
        assert!(
            client
                .collection("not-a-collection")
                .await
                .unwrap()
                .is_none()
        );
        collection.assert_async().await;
    }

    /// Search follows the POST `next` link to fetch a second page.
    #[tokio::test]
    async fn search_with_paging() {
        let mut server = Server::new_async().await;
        let mut page_1_body: ItemCollection =
            serde_json::from_str(include_str!("../mocks/search-page-1.json")).unwrap();
        let mut next_link = page_1_body.link("next").unwrap().clone();
        next_link.href = format!("{}/search", server.url());
        page_1_body.set_link(next_link);
        let page_1 = server
            .mock("POST", "/search")
            .match_body(Matcher::Json(json!({
                "collections": ["sentinel-2-l2a"],
                "limit": 1
            })))
            .with_body(serde_json::to_string(&page_1_body).unwrap())
            .with_header("content-type", "application/geo+json")
            .create_async()
            .await;
        let page_2 = server
            .mock("POST", "/search")
            .match_body(Matcher::Json(json!({
                "collections": ["sentinel-2-l2a"],
                "limit": 1,
                "token": "next:S2A_MSIL2A_20230216T150721_R082_T19PHS_20230217T082924"
            })))
            .with_body(include_str!("../mocks/search-page-2.json"))
            .with_header("content-type", "application/geo+json")
            .create_async()
            .await;

        let client = Client::new(&server.url()).unwrap();
        let mut search = Search {
            collections: vec!["sentinel-2-l2a".to_string()],
            ..Default::default()
        };
        search.items.limit = Some(1);
        let items: Vec<_> = client
            .search(search)
            .await
            .unwrap()
            .map(|result| result.unwrap())
            .take(2)
            .collect()
            .await;
        page_1.assert_async().await;
        page_2.assert_async().await;
        assert_eq!(items.len(), 2);
        assert!(items[0]["id"] != items[1]["id"]);
    }

    /// Items follows the GET `next` link to fetch a second page.
    #[tokio::test]
    async fn items_with_paging() {
        let mut server = Server::new_async().await;
        let mut page_1_body: ItemCollection =
            serde_json::from_str(include_str!("../mocks/items-page-1.json")).unwrap();
        let mut next_link = page_1_body.link("next").unwrap().clone();
        let url: Url = next_link.href.as_str().parse().unwrap();
        let query = url.query().unwrap();
        next_link.href = format!(
            "{}/collections/sentinel-2-l2a/items?{}",
            server.url(),
            query
        );
        page_1_body.set_link(next_link);
        let page_1 = server
            .mock("GET", "/collections/sentinel-2-l2a/items?limit=1")
            .with_body(serde_json::to_string(&page_1_body).unwrap())
            .with_header("content-type", "application/geo+json")
            .create_async()
            .await;
        let page_2 = server
            .mock("GET", "/collections/sentinel-2-l2a/items?limit=1&token=next:S2A_MSIL2A_20230216T235751_R087_T52CEB_20230217T134604")
            .with_body(include_str!("../mocks/items-page-2.json"))
            .with_header("content-type", "application/geo+json")
            .create_async()
            .await;

        let client = Client::new(&server.url()).unwrap();
        let items = Items {
            limit: Some(1),
            ..Default::default()
        };
        let items: Vec<_> = client
            .items("sentinel-2-l2a", Some(items))
            .await
            .unwrap()
            .map(|result| result.unwrap())
            .take(2)
            .collect()
            .await;
        page_1.assert_async().await;
        page_2.assert_async().await;
        assert_eq!(items.len(), 2);
        assert!(items[0]["id"] != items[1]["id"]);
    }

    /// An empty page ends the stream even when it carries a `next` link.
    #[tokio::test]
    async fn stop_on_empty_page() {
        let mut server = Server::new_async().await;
        let mut page_body: ItemCollection =
            serde_json::from_str(include_str!("../mocks/items-page-1.json")).unwrap();
        let mut next_link = page_body.link("next").unwrap().clone();
        let url: Url = next_link.href.as_str().parse().unwrap();
        let query = url.query().unwrap();
        next_link.href = format!(
            "{}/collections/sentinel-2-l2a/items?{}",
            server.url(),
            query
        );
        page_body.set_link(next_link);
        page_body.items = vec![];
        let page = server
            .mock("GET", "/collections/sentinel-2-l2a/items?limit=1")
            .with_body(serde_json::to_string(&page_body).unwrap())
            .with_header("content-type", "application/geo+json")
            .create_async()
            .await;

        let client = Client::new(&server.url()).unwrap();
        let items = Items {
            limit: Some(1),
            ..Default::default()
        };
        let items: Vec<_> = client
            .items("sentinel-2-l2a", Some(items))
            .await
            .unwrap()
            .map(|result| result.unwrap())
            .collect()
            .await;
        page.assert_async().await;
        assert!(items.is_empty());
    }

    /// Requests from a client built with `Client::new` carry the
    /// `rustac/<version>` user agent.
    #[tokio::test]
    async fn user_agent() {
        let mut server = Server::new_async().await;
        // Keep the mock handle so the test asserts that the request actually
        // matched (including the `user-agent` header) instead of discarding it.
        let mock = server
            .mock("POST", "/search")
            .with_body(include_str!("../mocks/items-page-1.json"))
            .with_header("content-type", "application/geo+json")
            .match_header(
                "user-agent",
                format!("rustac/{}", env!("CARGO_PKG_VERSION")).as_str(),
            )
            .create_async()
            .await;
        let client = Client::new(&server.url()).unwrap();
        let _ = client.search(Default::default()).await.unwrap();
        mock.assert_async().await;
    }
}