1use crate::{Error, Result};
4use async_stream::try_stream;
5use futures::{Stream, StreamExt, pin_mut};
6use http::header::{HeaderName, USER_AGENT};
7use reqwest::{ClientBuilder, IntoUrl, Method, StatusCode, header::HeaderMap};
8use serde::{Serialize, de::DeserializeOwned};
9use serde_json::{Map, Value};
10use stac::api::{GetItems, Item, ItemCollection, Items, Search, UrlBuilder};
11use stac::{Collection, Link, Links, SelfHref};
12use std::pin::Pin;
13use tokio::{
14 runtime::{Builder, Runtime},
15 sync::mpsc::{self, error::SendError},
16 task::JoinHandle,
17};
18
/// Default capacity, in pages, of the channel that carries fetched pages from
/// the background paging task to the item stream (see `stream_items`).
const DEFAULT_CHANNEL_BUFFER: usize = 4;
20
21pub async fn search(
23 href: &str,
24 mut search: Search,
25 max_items: Option<usize>,
26) -> Result<ItemCollection> {
27 let client = Client::new(href)?;
28 if search.limit.is_none() {
29 if let Some(max_items) = max_items {
30 search.limit = Some(max_items.try_into()?);
31 }
32 }
33 let stream = client.search(search).await?;
34 let mut items = if let Some(max_items) = max_items {
35 if max_items == 0 {
36 return Ok(ItemCollection::default());
37 }
38 Vec::with_capacity(max_items)
39 } else {
40 Vec::new()
41 };
42 pin_mut!(stream);
43 while let Some(item) = stream.next().await {
44 let item = item?;
45 items.push(item);
46 if let Some(max_items) = max_items {
47 if items.len() >= max_items {
48 break;
49 }
50 }
51 }
52 let item_collection = ItemCollection::new(items)?;
53 Ok(item_collection)
54}
55
56#[derive(Clone, Debug)]
58pub struct Client {
59 client: reqwest::Client,
60 channel_buffer: usize,
61 url_builder: UrlBuilder,
62}
63
/// A blocking wrapper around [`Client`]; each search runs on its own
/// current-thread Tokio runtime.
#[derive(Debug)]
pub struct BlockingClient(Client);
67
68#[allow(missing_debug_implementations)]
70pub struct BlockingIterator {
71 runtime: Runtime,
72 stream: Pin<Box<dyn Stream<Item = Result<Item>>>>,
73}
74
75impl Client {
76 pub fn new(url: &str) -> Result<Client> {
85 let mut headers = HeaderMap::new();
87 let _ = headers.insert(
88 USER_AGENT,
89 format!("rustac/{}", env!("CARGO_PKG_VERSION")).parse()?,
90 );
91 let client = ClientBuilder::new().default_headers(headers).build()?;
92 Client::with_client(client, url)
93 }
94
95 pub fn with_client(client: reqwest::Client, url: &str) -> Result<Client> {
109 Ok(Client {
110 client,
111 channel_buffer: DEFAULT_CHANNEL_BUFFER,
112 url_builder: UrlBuilder::new(url)?,
113 })
114 }
115
116 pub async fn collection(&self, id: &str) -> Result<Option<Collection>> {
128 let url = self.url_builder.collection(id)?;
129 not_found_to_none(self.get(url).await)
130 }
131
132 pub async fn items(
162 &self,
163 id: &str,
164 items: impl Into<Option<Items>>,
165 ) -> Result<impl Stream<Item = Result<Item>>> {
166 let url = self.url_builder.items(id)?; let items = match items.into() {
168 Some(items) => Some(GetItems::try_from(items)?),
169 _ => None,
170 };
171 let page = self
172 .request(Method::GET, url.clone(), items.as_ref(), None)
173 .await?;
174 Ok(stream_items(self.clone(), page, self.channel_buffer))
175 }
176
177 pub async fn search(&self, search: Search) -> Result<impl Stream<Item = Result<Item>> + use<>> {
201 let url = self.url_builder.search().clone();
202 tracing::debug!("searching {url}");
203 let page = self.post(url.clone(), &search).await?;
205 Ok(stream_items(self.clone(), page, self.channel_buffer))
206 }
207
208 async fn get<V>(&self, url: impl IntoUrl) -> Result<V>
209 where
210 V: DeserializeOwned + SelfHref,
211 {
212 let url = url.into_url()?;
213 let mut value = self
214 .request::<(), V>(Method::GET, url.clone(), None, None)
215 .await?;
216 value.set_self_href(url);
217 Ok(value)
218 }
219
220 async fn post<S, R>(&self, url: impl IntoUrl, data: &S) -> Result<R>
221 where
222 S: Serialize + 'static,
223 R: DeserializeOwned,
224 {
225 self.request(Method::POST, url, Some(data), None).await
226 }
227
228 async fn request<S, R>(
229 &self,
230 method: Method,
231 url: impl IntoUrl,
232 params: impl Into<Option<&S>>,
233 headers: impl Into<Option<HeaderMap>>,
234 ) -> Result<R>
235 where
236 S: Serialize + 'static,
237 R: DeserializeOwned,
238 {
239 let url = url.into_url()?;
240 let mut request = match method {
241 Method::GET => {
242 let mut request = self.client.get(url);
243 if let Some(query) = params.into() {
244 request = request.query(query);
245 }
246 request
247 }
248 Method::POST => {
249 let mut request = self.client.post(url);
250 if let Some(data) = params.into() {
251 request = request.json(&data);
252 }
253 request
254 }
255 _ => unimplemented!(),
256 };
257 if let Some(headers) = headers.into() {
258 request = request.headers(headers);
259 }
260 let response = request.send().await?.error_for_status()?;
261 response.json().await.map_err(Error::from)
262 }
263
264 async fn request_from_link<R>(&self, link: Link) -> Result<R>
265 where
266 R: DeserializeOwned,
267 {
268 let method = if let Some(method) = link.method {
269 method.parse()?
270 } else {
271 Method::GET
272 };
273 let headers = if let Some(headers) = link.headers {
274 let mut header_map = HeaderMap::new();
275 for (key, value) in headers.into_iter() {
276 let header_name: HeaderName = key.parse()?;
277 let _ = header_map.insert(header_name, value.to_string().parse()?);
278 }
279 Some(header_map)
280 } else {
281 None
282 };
283 self.request::<Map<String, Value>, R>(method, link.href.as_str(), &link.body, headers)
284 .await
285 }
286}
287
288impl BlockingClient {
289 pub fn new(url: &str) -> Result<BlockingClient> {
299 Client::new(url).map(Self)
300 }
301
302 pub fn search(&self, search: Search) -> Result<BlockingIterator> {
323 let runtime = Builder::new_current_thread().enable_all().build()?;
324 let stream = runtime.block_on(async move { self.0.search(search).await })?;
325 Ok(BlockingIterator {
326 runtime,
327 stream: Box::pin(stream),
328 })
329 }
330}
331
332impl Iterator for BlockingIterator {
333 type Item = Result<Item>;
334
335 fn next(&mut self) -> Option<Self::Item> {
336 self.runtime.block_on(self.stream.next())
337 }
338}
339
340fn stream_items(
341 client: Client,
342 page: ItemCollection,
343 channel_buffer: usize,
344) -> impl Stream<Item = Result<Item>> {
345 let (tx, mut rx) = mpsc::channel(channel_buffer);
346 let handle: JoinHandle<std::result::Result<(), SendError<_>>> = tokio::spawn(async move {
347 let pages = stream_pages(client, page);
348 pin_mut!(pages);
349 while let Some(result) = pages.next().await {
350 match result {
351 Ok(page) => tx.send(Ok(page)).await?,
352 Err(err) => {
353 tx.send(Err(err)).await?;
354 return Ok(());
355 }
356 }
357 }
358 Ok(())
359 });
360 try_stream! {
361 while let Some(result) = rx.recv().await {
362 let page = result?;
363 for item in page.items {
364 yield item;
365 }
366 }
367 let _ = handle.await?;
368 }
369}
370
371fn stream_pages(
372 client: Client,
373 mut page: ItemCollection,
374) -> impl Stream<Item = Result<ItemCollection>> {
375 try_stream! {
376 loop {
377 if page.items.is_empty() {
378 break;
379 }
380 let next_link = page.link("next").cloned();
381 yield page;
382 if let Some(next_link) = next_link {
383 if let Some(next_page) = client.request_from_link(next_link).await? {
384 page = next_page;
385 } else {
386 break;
387 }
388 } else {
389 break;
390 }
391 }
392 }
393}
394
395fn not_found_to_none<T>(result: Result<T>) -> Result<Option<T>> {
396 let mut result = result.map(Some);
397 if let Err(Error::Reqwest(ref err)) = result {
398 if err
399 .status()
400 .map(|s| s == StatusCode::NOT_FOUND)
401 .unwrap_or_default()
402 {
403 result = Ok(None);
404 }
405 }
406 result
407}
408
// Tests for `Client`, run against a local mockito HTTP server.
#[cfg(test)]
mod tests {
    use super::Client;
    use futures::StreamExt;
    use mockito::{Matcher, Server};
    use serde_json::json;
    use stac::Links;
    use stac::api::{ItemCollection, Items, Search};
    use url::Url;

    // A 404 response for a collection is reported as `Ok(None)`, not an error.
    #[tokio::test]
    async fn collection_not_found() {
        let mut server = Server::new_async().await;
        let collection = server
            .mock("GET", "/collections/not-a-collection")
            .with_body(include_str!("../mocks/not-a-collection.json"))
            .with_header("content-type", "application/json")
            .with_status(404)
            .create_async()
            .await;

        let client = Client::new(&server.url()).unwrap();
        assert!(
            client
                .collection("not-a-collection")
                .await
                .unwrap()
                .is_none()
        );
        collection.assert_async().await;
    }

    // POST search paging: the client follows the first page's `next` link to
    // fetch a second page, and the two items are different.
    #[tokio::test]
    async fn search_with_paging() {
        let mut server = Server::new_async().await;
        let mut page_1_body: ItemCollection =
            serde_json::from_str(include_str!("../mocks/search-page-1.json")).unwrap();
        // Point the fixture's `next` link at the mock server instead of the
        // original host.
        let mut next_link = page_1_body.link("next").unwrap().clone();
        next_link.href = format!("{}/search", server.url());
        page_1_body.set_link(next_link);
        let page_1 = server
            .mock("POST", "/search")
            .match_body(Matcher::Json(json!({
                "collections": ["sentinel-2-l2a"],
                "limit": 1
            })))
            .with_body(serde_json::to_string(&page_1_body).unwrap())
            .with_header("content-type", "application/geo+json")
            .create_async()
            .await;
        // Matches only the follow-up request; the `token` presumably comes from
        // the body of the fixture's `next` link (verify against the fixture).
        let page_2 = server
            .mock("POST", "/search")
            .match_body(Matcher::Json(json!({
                "collections": ["sentinel-2-l2a"],
                "limit": 1,
                "token": "next:S2A_MSIL2A_20230216T150721_R082_T19PHS_20230217T082924"
            })))
            .with_body(include_str!("../mocks/search-page-2.json"))
            .with_header("content-type", "application/geo+json")
            .create_async()
            .await;

        let client = Client::new(&server.url()).unwrap();
        let mut search = Search {
            collections: vec!["sentinel-2-l2a".to_string()],
            ..Default::default()
        };
        search.items.limit = Some(1);
        // One item per page, so taking two items requires two page requests.
        let items: Vec<_> = client
            .search(search)
            .await
            .unwrap()
            .map(|result| result.unwrap())
            .take(2)
            .collect()
            .await;
        page_1.assert_async().await;
        page_2.assert_async().await;
        assert_eq!(items.len(), 2);
        assert!(items[0]["id"] != items[1]["id"]);
    }

    // GET items paging: the `next` link's query string is preserved while its
    // host is rewritten to the mock server.
    #[tokio::test]
    async fn items_with_paging() {
        let mut server = Server::new_async().await;
        let mut page_1_body: ItemCollection =
            serde_json::from_str(include_str!("../mocks/items-page-1.json")).unwrap();
        let mut next_link = page_1_body.link("next").unwrap().clone();
        let url: Url = next_link.href.as_str().parse().unwrap();
        let query = url.query().unwrap();
        next_link.href = format!(
            "{}/collections/sentinel-2-l2a/items?{}",
            server.url(),
            query
        );
        page_1_body.set_link(next_link);
        let page_1 = server
            .mock("GET", "/collections/sentinel-2-l2a/items?limit=1")
            .with_body(serde_json::to_string(&page_1_body).unwrap())
            .with_header("content-type", "application/geo+json")
            .create_async()
            .await;
        let page_2 = server
            .mock("GET", "/collections/sentinel-2-l2a/items?limit=1&token=next:S2A_MSIL2A_20230216T235751_R087_T52CEB_20230217T134604")
            .with_body(include_str!("../mocks/items-page-2.json"))
            .with_header("content-type", "application/geo+json")
            .create_async()
            .await;

        let client = Client::new(&server.url()).unwrap();
        let items = Items {
            limit: Some(1),
            ..Default::default()
        };
        let items: Vec<_> = client
            .items("sentinel-2-l2a", Some(items))
            .await
            .unwrap()
            .map(|result| result.unwrap())
            .take(2)
            .collect()
            .await;
        page_1.assert_async().await;
        page_2.assert_async().await;
        assert_eq!(items.len(), 2);
        assert!(items[0]["id"] != items[1]["id"]);
    }

    // A page with no items ends the stream even though it has a `next` link;
    // no mock exists for the next page, so following it would fail the test.
    #[tokio::test]
    async fn stop_on_empty_page() {
        let mut server = Server::new_async().await;
        let mut page_body: ItemCollection =
            serde_json::from_str(include_str!("../mocks/items-page-1.json")).unwrap();
        let mut next_link = page_body.link("next").unwrap().clone();
        let url: Url = next_link.href.as_str().parse().unwrap();
        let query = url.query().unwrap();
        next_link.href = format!(
            "{}/collections/sentinel-2-l2a/items?{}",
            server.url(),
            query
        );
        page_body.set_link(next_link);
        page_body.items = vec![];
        let page = server
            .mock("GET", "/collections/sentinel-2-l2a/items?limit=1")
            .with_body(serde_json::to_string(&page_body).unwrap())
            .with_header("content-type", "application/geo+json")
            .create_async()
            .await;

        let client = Client::new(&server.url()).unwrap();
        let items = Items {
            limit: Some(1),
            ..Default::default()
        };
        let items: Vec<_> = client
            .items("sentinel-2-l2a", Some(items))
            .await
            .unwrap()
            .map(|result| result.unwrap())
            .collect()
            .await;
        page.assert_async().await;
        assert!(items.is_empty());
    }

    // The client sends `User-Agent: rustac/<version>`. The mock only matches
    // requests carrying that header; an unmatched request presumably gets an
    // error status from mockito, which would make the `unwrap` below panic.
    #[tokio::test]
    async fn user_agent() {
        let mut server = Server::new_async().await;
        // `with_body_from_file` reads the file at runtime, so this path is
        // resolved against the test's working directory, unlike the
        // `include_str!` paths above, which are relative to this source file.
        let _ = server
            .mock("POST", "/search")
            .with_body_from_file("mocks/items-page-1.json")
            .match_header(
                "user-agent",
                format!("rustac/{}", env!("CARGO_PKG_VERSION")).as_str(),
            )
            .create_async()
            .await;
        let client = Client::new(&server.url()).unwrap();
        let _ = client.search(Default::default()).await.unwrap();
    }
}