1use crate::{Error, GetItems, Item, ItemCollection, Items, Result, Search, UrlBuilder};
4use async_stream::try_stream;
5use futures::{pin_mut, Stream, StreamExt};
6use http::header::{HeaderName, USER_AGENT};
7use reqwest::{header::HeaderMap, ClientBuilder, IntoUrl, Method, StatusCode};
8use serde::{de::DeserializeOwned, Serialize};
9use serde_json::{Map, Value};
10use stac::{Collection, Link, Links, SelfHref};
11use std::pin::Pin;
12use tokio::{
13 runtime::{Builder, Runtime},
14 sync::mpsc::{self, error::SendError},
15 task::JoinHandle,
16};
17
/// Default capacity of the channel through which pages fetched by the background task in
/// `stream_items` are handed to the item stream; bounds how far that task can run ahead.
const DEFAULT_CHANNEL_BUFFER: usize = 4;
19
20pub async fn search(
22 href: &str,
23 mut search: Search,
24 max_items: Option<usize>,
25) -> Result<ItemCollection> {
26 let client = Client::new(href)?;
27 if search.limit.is_none() {
28 if let Some(max_items) = max_items {
29 search.limit = Some(max_items.try_into()?);
30 }
31 }
32 let stream = client.search(search).await.unwrap();
33 let mut items = if let Some(max_items) = max_items {
34 if max_items == 0 {
35 return Ok(ItemCollection::default());
36 }
37 Vec::with_capacity(max_items)
38 } else {
39 Vec::new()
40 };
41 pin_mut!(stream);
42 while let Some(item) = stream.next().await {
43 let item = item?;
44 items.push(item);
45 if let Some(max_items) = max_items {
46 if items.len() >= max_items {
47 break;
48 }
49 }
50 }
51 ItemCollection::new(items)
52}
53
/// An asynchronous client for a STAC API.
///
/// The client is cloned internally so that the background task that fetches result pages
/// can own its own handle.
#[derive(Clone, Debug)]
pub struct Client {
    // HTTP client used for every request.
    client: reqwest::Client,
    // Capacity of the channel between the background page fetcher and the item stream.
    channel_buffer: usize,
    // Builds the collection, items, and search endpoint URLs from the API root.
    url_builder: UrlBuilder,
}
61
/// A blocking wrapper around [`Client`].
///
/// Each search runs on its own single-threaded Tokio runtime owned by the returned iterator.
#[derive(Debug)]
pub struct BlockingClient(Client);
65
66#[allow(missing_debug_implementations)]
68pub struct BlockingIterator {
69 runtime: Runtime,
70 stream: Pin<Box<dyn Stream<Item = Result<Item>>>>,
71}
72
73impl Client {
74 pub fn new(url: &str) -> Result<Client> {
83 let mut headers = HeaderMap::new();
85 let _ = headers.insert(
86 USER_AGENT,
87 format!("stac-rs/{}", env!("CARGO_PKG_VERSION")).parse()?,
88 );
89 let client = ClientBuilder::new().default_headers(headers).build()?;
90 Client::with_client(client, url)
91 }
92
93 pub fn with_client(client: reqwest::Client, url: &str) -> Result<Client> {
107 Ok(Client {
108 client,
109 channel_buffer: DEFAULT_CHANNEL_BUFFER,
110 url_builder: UrlBuilder::new(url)?,
111 })
112 }
113
114 pub async fn collection(&self, id: &str) -> Result<Option<Collection>> {
126 let url = self.url_builder.collection(id)?;
127 not_found_to_none(self.get(url).await)
128 }
129
130 pub async fn items(
159 &self,
160 id: &str,
161 items: impl Into<Option<Items>>,
162 ) -> Result<impl Stream<Item = Result<Item>>> {
163 let url = self.url_builder.items(id)?; let items = if let Some(items) = items.into() {
165 Some(GetItems::try_from(items)?)
166 } else {
167 None
168 };
169 let page = self
170 .request(Method::GET, url.clone(), items.as_ref(), None)
171 .await?;
172 Ok(stream_items(self.clone(), page, self.channel_buffer))
173 }
174
175 pub async fn search(&self, search: Search) -> Result<impl Stream<Item = Result<Item>>> {
198 let url = self.url_builder.search().clone();
199 tracing::debug!("searching {url}");
200 let page = self.post(url.clone(), &search).await?;
202 Ok(stream_items(self.clone(), page, self.channel_buffer))
203 }
204
205 async fn get<V>(&self, url: impl IntoUrl) -> Result<V>
206 where
207 V: DeserializeOwned + SelfHref,
208 {
209 let url = url.into_url()?;
210 let mut value = self
211 .request::<(), V>(Method::GET, url.clone(), None, None)
212 .await?;
213 *value.self_href_mut() = Some(url.into());
214 Ok(value)
215 }
216
217 async fn post<S, R>(&self, url: impl IntoUrl, data: &S) -> Result<R>
218 where
219 S: Serialize + 'static,
220 R: DeserializeOwned,
221 {
222 self.request(Method::POST, url, Some(data), None).await
223 }
224
225 async fn request<S, R>(
226 &self,
227 method: Method,
228 url: impl IntoUrl,
229 params: impl Into<Option<&S>>,
230 headers: impl Into<Option<HeaderMap>>,
231 ) -> Result<R>
232 where
233 S: Serialize + 'static,
234 R: DeserializeOwned,
235 {
236 let url = url.into_url()?;
237 let mut request = match method {
238 Method::GET => {
239 let mut request = self.client.get(url);
240 if let Some(query) = params.into() {
241 request = request.query(query);
242 }
243 request
244 }
245 Method::POST => {
246 let mut request = self.client.post(url);
247 if let Some(data) = params.into() {
248 request = request.json(&data);
249 }
250 request
251 }
252 _ => unimplemented!(),
253 };
254 if let Some(headers) = headers.into() {
255 request = request.headers(headers);
256 }
257 let response = request.send().await?.error_for_status()?;
258 response.json().await.map_err(Error::from)
259 }
260
261 async fn request_from_link<R>(&self, link: Link) -> Result<R>
262 where
263 R: DeserializeOwned,
264 {
265 let method = if let Some(method) = link.method {
266 method.parse()?
267 } else {
268 Method::GET
269 };
270 let headers = if let Some(headers) = link.headers {
271 let mut header_map = HeaderMap::new();
272 for (key, value) in headers.into_iter() {
273 let header_name: HeaderName = key.parse()?;
274 let _ = header_map.insert(header_name, value.to_string().parse()?);
275 }
276 Some(header_map)
277 } else {
278 None
279 };
280 self.request::<Map<String, Value>, R>(method, link.href.as_str(), &link.body, headers)
281 .await
282 }
283}
284
285impl BlockingClient {
286 pub fn new(url: &str) -> Result<BlockingClient> {
296 Client::new(url).map(Self)
297 }
298
299 pub fn search(&self, search: Search) -> Result<BlockingIterator> {
319 let runtime = Builder::new_current_thread().enable_all().build()?;
320 let stream = runtime.block_on(async move { self.0.search(search).await })?;
321 Ok(BlockingIterator {
322 runtime,
323 stream: Box::pin(stream),
324 })
325 }
326}
327
328impl Iterator for BlockingIterator {
329 type Item = Result<Item>;
330
331 fn next(&mut self) -> Option<Self::Item> {
332 self.runtime.block_on(self.stream.next())
333 }
334}
335
336fn stream_items(
337 client: Client,
338 page: ItemCollection,
339 channel_buffer: usize,
340) -> impl Stream<Item = Result<Item>> {
341 let (tx, mut rx) = mpsc::channel(channel_buffer);
342 let handle: JoinHandle<std::result::Result<(), SendError<_>>> = tokio::spawn(async move {
343 let pages = stream_pages(client, page);
344 pin_mut!(pages);
345 while let Some(result) = pages.next().await {
346 match result {
347 Ok(page) => tx.send(Ok(page)).await?,
348 Err(err) => {
349 tx.send(Err(err)).await?;
350 return Ok(());
351 }
352 }
353 }
354 Ok(())
355 });
356 try_stream! {
357 while let Some(result) = rx.recv().await {
358 let page = result?;
359 for item in page.items {
360 yield item;
361 }
362 }
363 let _ = handle.await?;
364 }
365}
366
367fn stream_pages(
368 client: Client,
369 mut page: ItemCollection,
370) -> impl Stream<Item = Result<ItemCollection>> {
371 try_stream! {
372 loop {
373 if page.items.is_empty() {
374 break;
375 }
376 let next_link = page.link("next").cloned();
377 yield page;
378 if let Some(next_link) = next_link {
379 if let Some(next_page) = client.request_from_link(next_link).await? {
380 page = next_page;
381 } else {
382 break;
383 }
384 } else {
385 break;
386 }
387 }
388 }
389}
390
391fn not_found_to_none<T>(result: Result<T>) -> Result<Option<T>> {
392 let mut result = result.map(Some);
393 if let Err(Error::Reqwest(ref err)) = result {
394 if err
395 .status()
396 .map(|s| s == StatusCode::NOT_FOUND)
397 .unwrap_or_default()
398 {
399 result = Ok(None);
400 }
401 }
402 result
403}
404
#[cfg(test)]
mod tests {
    use super::Client;
    use crate::{ItemCollection, Items, Search};
    use futures::StreamExt;
    use mockito::{Matcher, Server};
    use serde_json::json;
    use stac::Links;
    use url::Url;

    // A 404 from the collection endpoint surfaces as `Ok(None)`.
    #[tokio::test]
    async fn collection_not_found() {
        let mut server = Server::new_async().await;
        let collection = server
            .mock("GET", "/collections/not-a-collection")
            .with_body(include_str!("../mocks/not-a-collection.json"))
            .with_header("content-type", "application/json")
            .with_status(404)
            .create_async()
            .await;

        let client = Client::new(&server.url()).unwrap();
        assert!(client
            .collection("not-a-collection")
            .await
            .unwrap()
            .is_none());
        collection.assert_async().await;
    }

    // Search follows a POST `next` link to a second page.
    #[tokio::test]
    async fn search_with_paging() {
        let mut server = Server::new_async().await;
        let mut page_1_body: ItemCollection =
            serde_json::from_str(include_str!("../mocks/search-page-1.json")).unwrap();
        // Point the `next` link at the mock server instead of the recorded host.
        let mut next_link = page_1_body.link("next").unwrap().clone();
        next_link.href = format!("{}/search", server.url()).into();
        page_1_body.set_link(next_link);
        let page_1 = server
            .mock("POST", "/search")
            .match_body(Matcher::Json(json!({
                "collections": ["sentinel-2-l2a"],
                "limit": 1
            })))
            .with_body(serde_json::to_string(&page_1_body).unwrap())
            .with_header("content-type", "application/geo+json")
            .create_async()
            .await;
        let page_2 = server
            .mock("POST", "/search")
            .match_body(Matcher::Json(json!({
                "collections": ["sentinel-2-l2a"],
                "limit": 1,
                "token": "next:S2A_MSIL2A_20230216T150721_R082_T19PHS_20230217T082924"
            })))
            .with_body(include_str!("../mocks/search-page-2.json"))
            .with_header("content-type", "application/geo+json")
            .create_async()
            .await;

        let client = Client::new(&server.url()).unwrap();
        let mut search = Search {
            collections: vec!["sentinel-2-l2a".to_string()],
            ..Default::default()
        };
        search.items.limit = Some(1);
        let items: Vec<_> = client
            .search(search)
            .await
            .unwrap()
            .map(|result| result.unwrap())
            .take(2)
            .collect()
            .await;
        page_1.assert_async().await;
        page_2.assert_async().await;
        assert_eq!(items.len(), 2);
        assert!(items[0]["id"] != items[1]["id"]);
    }

    // The items endpoint follows a GET `next` link to a second page.
    #[tokio::test]
    async fn items_with_paging() {
        let mut server = Server::new_async().await;
        let mut page_1_body: ItemCollection =
            serde_json::from_str(include_str!("../mocks/items-page-1.json")).unwrap();
        // Point the `next` link at the mock server, keeping the recorded query string.
        let mut next_link = page_1_body.link("next").unwrap().clone();
        let url: Url = next_link.href.as_str().parse().unwrap();
        let query = url.query().unwrap();
        next_link.href = format!(
            "{}/collections/sentinel-2-l2a/items?{}",
            server.url(),
            query
        )
        .into();
        page_1_body.set_link(next_link);
        let page_1 = server
            .mock("GET", "/collections/sentinel-2-l2a/items?limit=1")
            .with_body(serde_json::to_string(&page_1_body).unwrap())
            .with_header("content-type", "application/geo+json")
            .create_async()
            .await;
        let page_2 = server
            .mock("GET", "/collections/sentinel-2-l2a/items?limit=1&token=next:S2A_MSIL2A_20230216T235751_R087_T52CEB_20230217T134604")
            .with_body(include_str!("../mocks/items-page-2.json"))
            .with_header("content-type", "application/geo+json")
            .create_async()
            .await;

        let client = Client::new(&server.url()).unwrap();
        let items = Items {
            limit: Some(1),
            ..Default::default()
        };
        let items: Vec<_> = client
            .items("sentinel-2-l2a", Some(items))
            .await
            .unwrap()
            .map(|result| result.unwrap())
            .take(2)
            .collect()
            .await;
        page_1.assert_async().await;
        page_2.assert_async().await;
        assert_eq!(items.len(), 2);
        assert!(items[0]["id"] != items[1]["id"]);
    }

    // An empty page ends the stream even when it still has a `next` link.
    #[tokio::test]
    async fn stop_on_empty_page() {
        let mut server = Server::new_async().await;
        let mut page_body: ItemCollection =
            serde_json::from_str(include_str!("../mocks/items-page-1.json")).unwrap();
        let mut next_link = page_body.link("next").unwrap().clone();
        let url: Url = next_link.href.as_str().parse().unwrap();
        let query = url.query().unwrap();
        next_link.href = format!(
            "{}/collections/sentinel-2-l2a/items?{}",
            server.url(),
            query
        )
        .into();
        page_body.set_link(next_link);
        page_body.items = vec![];
        let page = server
            .mock("GET", "/collections/sentinel-2-l2a/items?limit=1")
            .with_body(serde_json::to_string(&page_body).unwrap())
            .with_header("content-type", "application/geo+json")
            .create_async()
            .await;

        let client = Client::new(&server.url()).unwrap();
        let items = Items {
            limit: Some(1),
            ..Default::default()
        };
        let items: Vec<_> = client
            .items("sentinel-2-l2a", Some(items))
            .await
            .unwrap()
            .map(|result| result.unwrap())
            .collect()
            .await;
        page.assert_async().await;
        assert!(items.is_empty());
    }

    // Requests carry the `stac-rs/<version>` User-Agent header.
    #[tokio::test]
    async fn user_agent() {
        let mut server = Server::new_async().await;
        // Keep the mock bound and assert on it, so the test fails explicitly when no request
        // with the expected header reached the server.
        let mock = server
            .mock("POST", "/search")
            .with_body_from_file("mocks/items-page-1.json")
            .match_header(
                "user-agent",
                format!("stac-rs/{}", env!("CARGO_PKG_VERSION")).as_str(),
            )
            .create_async()
            .await;
        let client = Client::new(&server.url()).unwrap();
        let _ = client.search(Default::default()).await.unwrap();
        mock.assert_async().await;
    }
}