stac_io/
api.rs

1//! A STAC API client.
2
3use crate::{Error, Result};
4use async_stream::try_stream;
5use futures::{Stream, StreamExt, pin_mut};
6use http::header::{HeaderName, USER_AGENT};
7use reqwest::{ClientBuilder, IntoUrl, Method, StatusCode, header::HeaderMap};
8use serde::{Serialize, de::DeserializeOwned};
9use serde_json::{Map, Value};
10use stac::api::{GetItems, Item, ItemCollection, Items, Search, UrlBuilder};
11use stac::{Collection, Link, Links, SelfHref};
12use std::pin::Pin;
13use tokio::{
14    runtime::{Builder, Runtime},
15    sync::mpsc::{self, error::SendError},
16    task::JoinHandle,
17};
18
/// Default capacity of the channel that carries fetched pages from the
/// background paging task to the item stream.
const DEFAULT_CHANNEL_BUFFER: usize = 4;
20
21/// Searches a STAC API.
22pub async fn search(
23    href: &str,
24    mut search: Search,
25    max_items: Option<usize>,
26) -> Result<ItemCollection> {
27    let client = Client::new(href)?;
28    if search.limit.is_none() {
29        if let Some(max_items) = max_items {
30            search.limit = Some(max_items.try_into()?);
31        }
32    }
33    let stream = client.search(search).await?;
34    let mut items = if let Some(max_items) = max_items {
35        if max_items == 0 {
36            return Ok(ItemCollection::default());
37        }
38        Vec::with_capacity(max_items)
39    } else {
40        Vec::new()
41    };
42    pin_mut!(stream);
43    while let Some(item) = stream.next().await {
44        let item = item?;
45        items.push(item);
46        if let Some(max_items) = max_items {
47            if items.len() >= max_items {
48                break;
49            }
50        }
51    }
52    let item_collection = ItemCollection::new(items)?;
53    Ok(item_collection)
54}
55
56/// A client for interacting with STAC APIs.
57#[derive(Clone, Debug)]
58pub struct Client {
59    client: reqwest::Client,
60    channel_buffer: usize,
61    url_builder: UrlBuilder,
62}
63
64/// A client for interacting with STAC APIs without async.
65#[derive(Debug)]
66pub struct BlockingClient(Client);
67
68/// A blocking iterator over items.
69#[allow(missing_debug_implementations)]
70pub struct BlockingIterator {
71    runtime: Runtime,
72    stream: Pin<Box<dyn Stream<Item = Result<Item>>>>,
73}
74
75impl Client {
76    /// Creates a new API client.
77    ///
78    /// # Examples
79    ///
80    /// ```
81    /// # use stac_io::api::Client;
82    /// let client = Client::new("https://planetarycomputer.microsoft.com/api/stac/v1").unwrap();
83    /// ```
84    pub fn new(url: &str) -> Result<Client> {
85        // TODO support HATEOS (aka look up the urls from the root catalog)
86        let mut headers = HeaderMap::new();
87        let _ = headers.insert(
88            USER_AGENT,
89            format!("rustac/{}", env!("CARGO_PKG_VERSION")).parse()?,
90        );
91        let client = ClientBuilder::new().default_headers(headers).build()?;
92        Client::with_client(client, url)
93    }
94
95    /// Creates a new API client with the given [Client].
96    ///
97    /// Useful if you want to customize the behavior of the underlying `Client`,
98    /// as documented in [Client::new].
99    ///
100    /// # Examples
101    ///
102    /// ```
103    /// use stac_io::api::Client;
104    ///
105    /// let client = reqwest::Client::new();
106    /// let client = Client::with_client(client, "https://earth-search.aws.element84.com/v1/").unwrap();
107    /// ```
108    pub fn with_client(client: reqwest::Client, url: &str) -> Result<Client> {
109        Ok(Client {
110            client,
111            channel_buffer: DEFAULT_CHANNEL_BUFFER,
112            url_builder: UrlBuilder::new(url)?,
113        })
114    }
115
116    /// Returns a single collection.
117    ///
118    /// # Examples
119    ///
120    /// ```no_run
121    /// # use stac_io::api::Client;
122    /// let client = Client::new("https://planetarycomputer.microsoft.com/api/stac/v1").unwrap();
123    /// # tokio_test::block_on(async {
124    /// let collection = client.collection("sentinel-2-l2a").await.unwrap().unwrap();
125    /// # })
126    /// ```
127    pub async fn collection(&self, id: &str) -> Result<Option<Collection>> {
128        let url = self.url_builder.collection(id)?;
129        not_found_to_none(self.get(url).await)
130    }
131
132    /// Returns a stream of items belonging to a collection, using the [items
133    /// endpoint](https://github.com/radiantearth/stac-api-spec/tree/main/ogcapi-features#collection-items-collectionscollectioniditems).
134    ///
135    /// The `items` argument can be used to filter, sort, and otherwise
136    /// configure the request.
137    ///
138    /// # Examples
139    ///
140    /// ```no_run
141    /// use stac_io::api::Client;
142    /// use stac::api::Items;
143    /// use futures::StreamExt;
144    ///
145    /// let client = Client::new("https://planetarycomputer.microsoft.com/api/stac/v1").unwrap();
146    /// let items = Items {
147    ///     limit: Some(1),
148    ///     ..Default::default()
149    /// };
150    /// # tokio_test::block_on(async {
151    /// let items: Vec<_> = client
152    ///     .items("sentinel-2-l2a", items)
153    ///     .await
154    ///     .unwrap()
155    ///     .map(|result| result.unwrap())
156    ///     .collect()
157    ///     .await;
158    /// assert_eq!(items.len(), 1);
159    /// # })
160    /// ```
161    pub async fn items(
162        &self,
163        id: &str,
164        items: impl Into<Option<Items>>,
165    ) -> Result<impl Stream<Item = Result<Item>>> {
166        let url = self.url_builder.items(id)?; // TODO HATEOS
167        let items = match items.into() {
168            Some(items) => Some(GetItems::try_from(items)?),
169            _ => None,
170        };
171        let page = self
172            .request(Method::GET, url.clone(), items.as_ref(), None)
173            .await?;
174        Ok(stream_items(self.clone(), page, self.channel_buffer))
175    }
176
177    /// Searches an API, returning a stream of items.
178    ///
179    /// # Examples
180    ///
181    /// ```no_run
182    /// use stac_io::api::Client;
183    /// use stac::api::Search;
184    /// use futures::StreamExt;
185    ///
186    /// let client = Client::new("https://planetarycomputer.microsoft.com/api/stac/v1").unwrap();
187    /// let mut search = Search { collections: vec!["sentinel-2-l2a".to_string()], ..Default::default() };
188    /// # tokio_test::block_on(async {
189    /// let items: Vec<_> = client
190    ///     .search(search)
191    ///     .await
192    ///     .unwrap()
193    ///     .take(1)
194    ///     .map(|result| result.unwrap())
195    ///     .collect()
196    ///     .await;
197    /// assert_eq!(items.len(), 1);
198    /// # })
199    /// ```
200    pub async fn search(&self, search: Search) -> Result<impl Stream<Item = Result<Item>> + use<>> {
201        let url = self.url_builder.search().clone();
202        tracing::debug!("searching {url}");
203        // TODO support GET
204        let page = self.post(url.clone(), &search).await?;
205        Ok(stream_items(self.clone(), page, self.channel_buffer))
206    }
207
208    async fn get<V>(&self, url: impl IntoUrl) -> Result<V>
209    where
210        V: DeserializeOwned + SelfHref,
211    {
212        let url = url.into_url()?;
213        let mut value = self
214            .request::<(), V>(Method::GET, url.clone(), None, None)
215            .await?;
216        value.set_self_href(url);
217        Ok(value)
218    }
219
220    async fn post<S, R>(&self, url: impl IntoUrl, data: &S) -> Result<R>
221    where
222        S: Serialize + 'static,
223        R: DeserializeOwned,
224    {
225        self.request(Method::POST, url, Some(data), None).await
226    }
227
228    async fn request<S, R>(
229        &self,
230        method: Method,
231        url: impl IntoUrl,
232        params: impl Into<Option<&S>>,
233        headers: impl Into<Option<HeaderMap>>,
234    ) -> Result<R>
235    where
236        S: Serialize + 'static,
237        R: DeserializeOwned,
238    {
239        let url = url.into_url()?;
240        let mut request = match method {
241            Method::GET => {
242                let mut request = self.client.get(url);
243                if let Some(query) = params.into() {
244                    request = request.query(query);
245                }
246                request
247            }
248            Method::POST => {
249                let mut request = self.client.post(url);
250                if let Some(data) = params.into() {
251                    request = request.json(&data);
252                }
253                request
254            }
255            _ => unimplemented!(),
256        };
257        if let Some(headers) = headers.into() {
258            request = request.headers(headers);
259        }
260        let response = request.send().await?.error_for_status()?;
261        response.json().await.map_err(Error::from)
262    }
263
264    async fn request_from_link<R>(&self, link: Link) -> Result<R>
265    where
266        R: DeserializeOwned,
267    {
268        let method = if let Some(method) = link.method {
269            method.parse()?
270        } else {
271            Method::GET
272        };
273        let headers = if let Some(headers) = link.headers {
274            let mut header_map = HeaderMap::new();
275            for (key, value) in headers.into_iter() {
276                let header_name: HeaderName = key.parse()?;
277                let _ = header_map.insert(header_name, value.to_string().parse()?);
278            }
279            Some(header_map)
280        } else {
281            None
282        };
283        self.request::<Map<String, Value>, R>(method, link.href.as_str(), &link.body, headers)
284            .await
285    }
286}
287
288impl BlockingClient {
289    /// Creates a new blocking client.
290    ///
291    /// # Examples
292    ///
293    /// ```
294    /// use stac_io::api::BlockingClient;
295    ///
296    /// let client = BlockingClient::new("https://planetarycomputer.microsoft.com/api/stac/vi").unwrap();
297    /// ```
298    pub fn new(url: &str) -> Result<BlockingClient> {
299        Client::new(url).map(Self)
300    }
301
302    /// Searches an API, returning an iterable of items.
303    ///
304    /// To prevent fetching _all_ the items (which might be a lot), it is recommended to pass a `max_items`.
305    ///
306    /// # Examples
307    ///
308    /// ```no_run
309    /// use stac_io::api::BlockingClient;
310    /// use stac::api::Search;
311    ///
312    /// let client = BlockingClient::new("https://planetarycomputer.microsoft.com/api/stac/v1").unwrap();
313    /// let mut search = Search { collections: vec!["sentinel-2-l2a".to_string()], ..Default::default() };
314    /// let items: Vec<_> = client
315    ///     .search(search)
316    ///     .unwrap()
317    ///     .map(|result| result.unwrap())
318    ///     .take(1)
319    ///     .collect();
320    /// assert_eq!(items.len(), 1);
321    /// ```
322    pub fn search(&self, search: Search) -> Result<BlockingIterator> {
323        let runtime = Builder::new_current_thread().enable_all().build()?;
324        let stream = runtime.block_on(async move { self.0.search(search).await })?;
325        Ok(BlockingIterator {
326            runtime,
327            stream: Box::pin(stream),
328        })
329    }
330}
331
332impl Iterator for BlockingIterator {
333    type Item = Result<Item>;
334
335    fn next(&mut self) -> Option<Self::Item> {
336        self.runtime.block_on(self.stream.next())
337    }
338}
339
340fn stream_items(
341    client: Client,
342    page: ItemCollection,
343    channel_buffer: usize,
344) -> impl Stream<Item = Result<Item>> {
345    let (tx, mut rx) = mpsc::channel(channel_buffer);
346    let handle: JoinHandle<std::result::Result<(), SendError<_>>> = tokio::spawn(async move {
347        let pages = stream_pages(client, page);
348        pin_mut!(pages);
349        while let Some(result) = pages.next().await {
350            match result {
351                Ok(page) => tx.send(Ok(page)).await?,
352                Err(err) => {
353                    tx.send(Err(err)).await?;
354                    return Ok(());
355                }
356            }
357        }
358        Ok(())
359    });
360    try_stream! {
361        while let Some(result) = rx.recv().await {
362            let page = result?;
363            for item in page.items {
364                yield item;
365            }
366        }
367        let _ = handle.await?;
368    }
369}
370
371fn stream_pages(
372    client: Client,
373    mut page: ItemCollection,
374) -> impl Stream<Item = Result<ItemCollection>> {
375    try_stream! {
376        loop {
377            if page.items.is_empty() {
378                break;
379            }
380            let next_link = page.link("next").cloned();
381            yield page;
382            if let Some(next_link) = next_link {
383                if let Some(next_page) = client.request_from_link(next_link).await? {
384                    page = next_page;
385                } else {
386                    break;
387                }
388            } else {
389                break;
390            }
391        }
392    }
393}
394
395fn not_found_to_none<T>(result: Result<T>) -> Result<Option<T>> {
396    let mut result = result.map(Some);
397    if let Err(Error::Reqwest(ref err)) = result {
398        if err
399            .status()
400            .map(|s| s == StatusCode::NOT_FOUND)
401            .unwrap_or_default()
402        {
403            result = Ok(None);
404        }
405    }
406    result
407}
408
409#[cfg(test)]
410mod tests {
411    use super::Client;
412    use futures::StreamExt;
413    use mockito::{Matcher, Server};
414    use serde_json::json;
415    use stac::Links;
416    use stac::api::{ItemCollection, Items, Search};
417    use url::Url;
418
419    #[tokio::test]
420    async fn collection_not_found() {
421        let mut server = Server::new_async().await;
422        let collection = server
423            .mock("GET", "/collections/not-a-collection")
424            .with_body(include_str!("../mocks/not-a-collection.json"))
425            .with_header("content-type", "application/json")
426            .with_status(404)
427            .create_async()
428            .await;
429
430        let client = Client::new(&server.url()).unwrap();
431        assert!(
432            client
433                .collection("not-a-collection")
434                .await
435                .unwrap()
436                .is_none()
437        );
438        collection.assert_async().await;
439    }
440
441    #[tokio::test]
442    async fn search_with_paging() {
443        let mut server = Server::new_async().await;
444        let mut page_1_body: ItemCollection =
445            serde_json::from_str(include_str!("../mocks/search-page-1.json")).unwrap();
446        let mut next_link = page_1_body.link("next").unwrap().clone();
447        next_link.href = format!("{}/search", server.url());
448        page_1_body.set_link(next_link);
449        let page_1 = server
450            .mock("POST", "/search")
451            .match_body(Matcher::Json(json!({
452                "collections": ["sentinel-2-l2a"],
453                "limit": 1
454            })))
455            .with_body(serde_json::to_string(&page_1_body).unwrap())
456            .with_header("content-type", "application/geo+json")
457            .create_async()
458            .await;
459        let page_2 = server
460            .mock("POST", "/search")
461            .match_body(Matcher::Json(json!({
462                "collections": ["sentinel-2-l2a"],
463                "limit": 1,
464                "token": "next:S2A_MSIL2A_20230216T150721_R082_T19PHS_20230217T082924"
465            })))
466            .with_body(include_str!("../mocks/search-page-2.json"))
467            .with_header("content-type", "application/geo+json")
468            .create_async()
469            .await;
470
471        let client = Client::new(&server.url()).unwrap();
472        let mut search = Search {
473            collections: vec!["sentinel-2-l2a".to_string()],
474            ..Default::default()
475        };
476        search.items.limit = Some(1);
477        let items: Vec<_> = client
478            .search(search)
479            .await
480            .unwrap()
481            .map(|result| result.unwrap())
482            .take(2)
483            .collect()
484            .await;
485        page_1.assert_async().await;
486        page_2.assert_async().await;
487        assert_eq!(items.len(), 2);
488        assert!(items[0]["id"] != items[1]["id"]);
489    }
490
491    #[tokio::test]
492    async fn items_with_paging() {
493        let mut server = Server::new_async().await;
494        let mut page_1_body: ItemCollection =
495            serde_json::from_str(include_str!("../mocks/items-page-1.json")).unwrap();
496        let mut next_link = page_1_body.link("next").unwrap().clone();
497        let url: Url = next_link.href.as_str().parse().unwrap();
498        let query = url.query().unwrap();
499        next_link.href = format!(
500            "{}/collections/sentinel-2-l2a/items?{}",
501            server.url(),
502            query
503        );
504        page_1_body.set_link(next_link);
505        let page_1 = server
506            .mock("GET", "/collections/sentinel-2-l2a/items?limit=1")
507            .with_body(serde_json::to_string(&page_1_body).unwrap())
508            .with_header("content-type", "application/geo+json")
509            .create_async()
510            .await;
511        let page_2 = server
512            .mock("GET", "/collections/sentinel-2-l2a/items?limit=1&token=next:S2A_MSIL2A_20230216T235751_R087_T52CEB_20230217T134604")
513            .with_body(include_str!("../mocks/items-page-2.json"))
514            .with_header("content-type", "application/geo+json")
515            .create_async()
516            .await;
517
518        let client = Client::new(&server.url()).unwrap();
519        let items = Items {
520            limit: Some(1),
521            ..Default::default()
522        };
523        let items: Vec<_> = client
524            .items("sentinel-2-l2a", Some(items))
525            .await
526            .unwrap()
527            .map(|result| result.unwrap())
528            .take(2)
529            .collect()
530            .await;
531        page_1.assert_async().await;
532        page_2.assert_async().await;
533        assert_eq!(items.len(), 2);
534        assert!(items[0]["id"] != items[1]["id"]);
535    }
536
537    #[tokio::test]
538    async fn stop_on_empty_page() {
539        let mut server = Server::new_async().await;
540        let mut page_body: ItemCollection =
541            serde_json::from_str(include_str!("../mocks/items-page-1.json")).unwrap();
542        let mut next_link = page_body.link("next").unwrap().clone();
543        let url: Url = next_link.href.as_str().parse().unwrap();
544        let query = url.query().unwrap();
545        next_link.href = format!(
546            "{}/collections/sentinel-2-l2a/items?{}",
547            server.url(),
548            query
549        );
550        page_body.set_link(next_link);
551        page_body.items = vec![];
552        let page = server
553            .mock("GET", "/collections/sentinel-2-l2a/items?limit=1")
554            .with_body(serde_json::to_string(&page_body).unwrap())
555            .with_header("content-type", "application/geo+json")
556            .create_async()
557            .await;
558
559        let client = Client::new(&server.url()).unwrap();
560        let items = Items {
561            limit: Some(1),
562            ..Default::default()
563        };
564        let items: Vec<_> = client
565            .items("sentinel-2-l2a", Some(items))
566            .await
567            .unwrap()
568            .map(|result| result.unwrap())
569            .collect()
570            .await;
571        page.assert_async().await;
572        assert!(items.is_empty());
573    }
574
575    #[tokio::test]
576    async fn user_agent() {
577        let mut server = Server::new_async().await;
578        let _ = server
579            .mock("POST", "/search")
580            .with_body_from_file("mocks/items-page-1.json")
581            .match_header(
582                "user-agent",
583                format!("rustac/{}", env!("CARGO_PKG_VERSION")).as_str(),
584            )
585            .create_async()
586            .await;
587        let client = Client::new(&server.url()).unwrap();
588        let _ = client.search(Default::default()).await.unwrap();
589    }
590}