stac_api/
client.rs

//! A STAC API client.
2
3use crate::{Error, GetItems, Item, ItemCollection, Items, Result, Search, UrlBuilder};
4use async_stream::try_stream;
5use futures::{pin_mut, Stream, StreamExt};
6use http::header::{HeaderName, USER_AGENT};
7use reqwest::{header::HeaderMap, ClientBuilder, IntoUrl, Method, StatusCode};
8use serde::{de::DeserializeOwned, Serialize};
9use serde_json::{Map, Value};
10use stac::{Collection, Link, Links, SelfHref};
11use std::pin::Pin;
12use tokio::{
13    runtime::{Builder, Runtime},
14    sync::mpsc::{self, error::SendError},
15    task::JoinHandle,
16};
17
/// Default capacity of the channel that carries fetched pages from the
/// background paging task to the item stream.
const DEFAULT_CHANNEL_BUFFER: usize = 4;
19
20/// Searches a STAC API.
21pub async fn search(
22    href: &str,
23    mut search: Search,
24    max_items: Option<usize>,
25) -> Result<ItemCollection> {
26    let client = Client::new(href)?;
27    if search.limit.is_none() {
28        if let Some(max_items) = max_items {
29            search.limit = Some(max_items.try_into()?);
30        }
31    }
32    let stream = client.search(search).await.unwrap();
33    let mut items = if let Some(max_items) = max_items {
34        if max_items == 0 {
35            return Ok(ItemCollection::default());
36        }
37        Vec::with_capacity(max_items)
38    } else {
39        Vec::new()
40    };
41    pin_mut!(stream);
42    while let Some(item) = stream.next().await {
43        let item = item?;
44        items.push(item);
45        if let Some(max_items) = max_items {
46            if items.len() >= max_items {
47                break;
48            }
49        }
50    }
51    ItemCollection::new(items)
52}
53
54/// A client for interacting with STAC APIs.
55#[derive(Clone, Debug)]
56pub struct Client {
57    client: reqwest::Client,
58    channel_buffer: usize,
59    url_builder: UrlBuilder,
60}
61
62/// A client for interacting with STAC APIs without async.
63#[derive(Debug)]
64pub struct BlockingClient(Client);
65
66/// A blocking iterator over items.
67#[allow(missing_debug_implementations)]
68pub struct BlockingIterator {
69    runtime: Runtime,
70    stream: Pin<Box<dyn Stream<Item = Result<Item>>>>,
71}
72
73impl Client {
74    /// Creates a new API client.
75    ///
76    /// # Examples
77    ///
78    /// ```
79    /// # use stac_api::Client;
80    /// let client = Client::new("https://planetarycomputer.microsoft.com/api/stac/v1").unwrap();
81    /// ```
82    pub fn new(url: &str) -> Result<Client> {
83        // TODO support HATEOS (aka look up the urls from the root catalog)
84        let mut headers = HeaderMap::new();
85        let _ = headers.insert(
86            USER_AGENT,
87            format!("stac-rs/{}", env!("CARGO_PKG_VERSION")).parse()?,
88        );
89        let client = ClientBuilder::new().default_headers(headers).build()?;
90        Client::with_client(client, url)
91    }
92
93    /// Creates a new API client with the given [Client].
94    ///
95    /// Useful if you want to customize the behavior of the underlying `Client`,
96    /// as documented in [Client::new].
97    ///
98    /// # Examples
99    ///
100    /// ```
101    /// use stac_api::Client;
102    ///
103    /// let client = reqwest::Client::new();
104    /// let client = Client::with_client(client, "https://earth-search.aws.element84.com/v1/").unwrap();
105    /// ```
106    pub fn with_client(client: reqwest::Client, url: &str) -> Result<Client> {
107        Ok(Client {
108            client,
109            channel_buffer: DEFAULT_CHANNEL_BUFFER,
110            url_builder: UrlBuilder::new(url)?,
111        })
112    }
113
114    /// Returns a single collection.
115    ///
116    /// # Examples
117    ///
118    /// ```no_run
119    /// # use stac_api::Client;
120    /// let client = Client::new("https://planetarycomputer.microsoft.com/api/stac/v1").unwrap();
121    /// # tokio_test::block_on(async {
122    /// let collection = client.collection("sentinel-2-l2a").await.unwrap().unwrap();
123    /// # })
124    /// ```
125    pub async fn collection(&self, id: &str) -> Result<Option<Collection>> {
126        let url = self.url_builder.collection(id)?;
127        not_found_to_none(self.get(url).await)
128    }
129
130    /// Returns a stream of items belonging to a collection, using the [items
131    /// endpoint](https://github.com/radiantearth/stac-api-spec/tree/main/ogcapi-features#collection-items-collectionscollectioniditems).
132    ///
133    /// The `items` argument can be used to filter, sort, and otherwise
134    /// configure the request.
135    ///
136    /// # Examples
137    ///
138    /// ```no_run
139    /// use stac_api::{Items, Client};
140    /// use futures::StreamExt;
141    ///
142    /// let client = Client::new("https://planetarycomputer.microsoft.com/api/stac/v1").unwrap();
143    /// let items = Items {
144    ///     limit: Some(1),
145    ///     ..Default::default()
146    /// };
147    /// # tokio_test::block_on(async {
148    /// let items: Vec<_> = client
149    ///     .items("sentinel-2-l2a", items)
150    ///     .await
151    ///     .unwrap()
152    ///     .map(|result| result.unwrap())
153    ///     .collect()
154    ///     .await;
155    /// assert_eq!(items.len(), 1);
156    /// # })
157    /// ```
158    pub async fn items(
159        &self,
160        id: &str,
161        items: impl Into<Option<Items>>,
162    ) -> Result<impl Stream<Item = Result<Item>>> {
163        let url = self.url_builder.items(id)?; // TODO HATEOS
164        let items = if let Some(items) = items.into() {
165            Some(GetItems::try_from(items)?)
166        } else {
167            None
168        };
169        let page = self
170            .request(Method::GET, url.clone(), items.as_ref(), None)
171            .await?;
172        Ok(stream_items(self.clone(), page, self.channel_buffer))
173    }
174
175    /// Searches an API, returning a stream of items.
176    ///
177    /// # Examples
178    ///
179    /// ```no_run
180    /// use stac_api::{Search, Client};
181    /// use futures::StreamExt;
182    ///
183    /// let client = Client::new("https://planetarycomputer.microsoft.com/api/stac/v1").unwrap();
184    /// let mut search = Search { collections: vec!["sentinel-2-l2a".to_string()], ..Default::default() };
185    /// # tokio_test::block_on(async {
186    /// let items: Vec<_> = client
187    ///     .search(search)
188    ///     .await
189    ///     .unwrap()
190    ///     .take(1)
191    ///     .map(|result| result.unwrap())
192    ///     .collect()
193    ///     .await;
194    /// assert_eq!(items.len(), 1);
195    /// # })
196    /// ```
197    pub async fn search(&self, search: Search) -> Result<impl Stream<Item = Result<Item>>> {
198        let url = self.url_builder.search().clone();
199        tracing::debug!("searching {url}");
200        // TODO support GET
201        let page = self.post(url.clone(), &search).await?;
202        Ok(stream_items(self.clone(), page, self.channel_buffer))
203    }
204
205    async fn get<V>(&self, url: impl IntoUrl) -> Result<V>
206    where
207        V: DeserializeOwned + SelfHref,
208    {
209        let url = url.into_url()?;
210        let mut value = self
211            .request::<(), V>(Method::GET, url.clone(), None, None)
212            .await?;
213        *value.self_href_mut() = Some(url.into());
214        Ok(value)
215    }
216
217    async fn post<S, R>(&self, url: impl IntoUrl, data: &S) -> Result<R>
218    where
219        S: Serialize + 'static,
220        R: DeserializeOwned,
221    {
222        self.request(Method::POST, url, Some(data), None).await
223    }
224
225    async fn request<S, R>(
226        &self,
227        method: Method,
228        url: impl IntoUrl,
229        params: impl Into<Option<&S>>,
230        headers: impl Into<Option<HeaderMap>>,
231    ) -> Result<R>
232    where
233        S: Serialize + 'static,
234        R: DeserializeOwned,
235    {
236        let url = url.into_url()?;
237        let mut request = match method {
238            Method::GET => {
239                let mut request = self.client.get(url);
240                if let Some(query) = params.into() {
241                    request = request.query(query);
242                }
243                request
244            }
245            Method::POST => {
246                let mut request = self.client.post(url);
247                if let Some(data) = params.into() {
248                    request = request.json(&data);
249                }
250                request
251            }
252            _ => unimplemented!(),
253        };
254        if let Some(headers) = headers.into() {
255            request = request.headers(headers);
256        }
257        let response = request.send().await?.error_for_status()?;
258        response.json().await.map_err(Error::from)
259    }
260
261    async fn request_from_link<R>(&self, link: Link) -> Result<R>
262    where
263        R: DeserializeOwned,
264    {
265        let method = if let Some(method) = link.method {
266            method.parse()?
267        } else {
268            Method::GET
269        };
270        let headers = if let Some(headers) = link.headers {
271            let mut header_map = HeaderMap::new();
272            for (key, value) in headers.into_iter() {
273                let header_name: HeaderName = key.parse()?;
274                let _ = header_map.insert(header_name, value.to_string().parse()?);
275            }
276            Some(header_map)
277        } else {
278            None
279        };
280        self.request::<Map<String, Value>, R>(method, link.href.as_str(), &link.body, headers)
281            .await
282    }
283}
284
285impl BlockingClient {
286    /// Creates a new blocking client.
287    ///
288    /// # Examples
289    ///
290    /// ```
291    /// use stac_api::BlockingClient;
292    ///
293    /// let client = BlockingClient::new("https://planetarycomputer.microsoft.com/api/stac/vi").unwrap();
294    /// ```
295    pub fn new(url: &str) -> Result<BlockingClient> {
296        Client::new(url).map(Self)
297    }
298
299    /// Searches an API, returning an iterable of items.
300    ///
301    /// To prevent fetching _all_ the items (which might be a lot), it is recommended to pass a `max_items`.
302    ///
303    /// # Examples
304    ///
305    /// ```no_run
306    /// use stac_api::{Search, BlockingClient};
307    ///
308    /// let client = BlockingClient::new("https://planetarycomputer.microsoft.com/api/stac/v1").unwrap();
309    /// let mut search = Search { collections: vec!["sentinel-2-l2a".to_string()], ..Default::default() };
310    /// let items: Vec<_> = client
311    ///     .search(search)
312    ///     .unwrap()
313    ///     .map(|result| result.unwrap())
314    ///     .take(1)
315    ///     .collect();
316    /// assert_eq!(items.len(), 1);
317    /// ```
318    pub fn search(&self, search: Search) -> Result<BlockingIterator> {
319        let runtime = Builder::new_current_thread().enable_all().build()?;
320        let stream = runtime.block_on(async move { self.0.search(search).await })?;
321        Ok(BlockingIterator {
322            runtime,
323            stream: Box::pin(stream),
324        })
325    }
326}
327
328impl Iterator for BlockingIterator {
329    type Item = Result<Item>;
330
331    fn next(&mut self) -> Option<Self::Item> {
332        self.runtime.block_on(self.stream.next())
333    }
334}
335
336fn stream_items(
337    client: Client,
338    page: ItemCollection,
339    channel_buffer: usize,
340) -> impl Stream<Item = Result<Item>> {
341    let (tx, mut rx) = mpsc::channel(channel_buffer);
342    let handle: JoinHandle<std::result::Result<(), SendError<_>>> = tokio::spawn(async move {
343        let pages = stream_pages(client, page);
344        pin_mut!(pages);
345        while let Some(result) = pages.next().await {
346            match result {
347                Ok(page) => tx.send(Ok(page)).await?,
348                Err(err) => {
349                    tx.send(Err(err)).await?;
350                    return Ok(());
351                }
352            }
353        }
354        Ok(())
355    });
356    try_stream! {
357        while let Some(result) = rx.recv().await {
358            let page = result?;
359            for item in page.items {
360                yield item;
361            }
362        }
363        let _ = handle.await?;
364    }
365}
366
367fn stream_pages(
368    client: Client,
369    mut page: ItemCollection,
370) -> impl Stream<Item = Result<ItemCollection>> {
371    try_stream! {
372        loop {
373            if page.items.is_empty() {
374                break;
375            }
376            let next_link = page.link("next").cloned();
377            yield page;
378            if let Some(next_link) = next_link {
379                if let Some(next_page) = client.request_from_link(next_link).await? {
380                    page = next_page;
381                } else {
382                    break;
383                }
384            } else {
385                break;
386            }
387        }
388    }
389}
390
391fn not_found_to_none<T>(result: Result<T>) -> Result<Option<T>> {
392    let mut result = result.map(Some);
393    if let Err(Error::Reqwest(ref err)) = result {
394        if err
395            .status()
396            .map(|s| s == StatusCode::NOT_FOUND)
397            .unwrap_or_default()
398        {
399            result = Ok(None);
400        }
401    }
402    result
403}
404
#[cfg(test)]
mod tests {
    use super::Client;
    use crate::{ItemCollection, Items, Search};
    use futures::StreamExt;
    use mockito::{Matcher, Server};
    use serde_json::json;
    use stac::Links;
    use url::Url;

    #[tokio::test]
    async fn collection_not_found() {
        let mut server = Server::new_async().await;
        let collection = server
            .mock("GET", "/collections/not-a-collection")
            .with_body(include_str!("../mocks/not-a-collection.json"))
            .with_header("content-type", "application/json")
            .with_status(404)
            .create_async()
            .await;

        let client = Client::new(&server.url()).unwrap();
        assert!(client
            .collection("not-a-collection")
            .await
            .unwrap()
            .is_none());
        collection.assert_async().await;
    }

    #[tokio::test]
    async fn search_with_paging() {
        let mut server = Server::new_async().await;
        let mut page_1_body: ItemCollection =
            serde_json::from_str(include_str!("../mocks/search-page-1.json")).unwrap();
        let mut next_link = page_1_body.link("next").unwrap().clone();
        next_link.href = format!("{}/search", server.url()).into();
        page_1_body.set_link(next_link);
        let page_1 = server
            .mock("POST", "/search")
            .match_body(Matcher::Json(json!({
                "collections": ["sentinel-2-l2a"],
                "limit": 1
            })))
            .with_body(serde_json::to_string(&page_1_body).unwrap())
            .with_header("content-type", "application/geo+json")
            .create_async()
            .await;
        let page_2 = server
            .mock("POST", "/search")
            .match_body(Matcher::Json(json!({
                "collections": ["sentinel-2-l2a"],
                "limit": 1,
                "token": "next:S2A_MSIL2A_20230216T150721_R082_T19PHS_20230217T082924"
            })))
            .with_body(include_str!("../mocks/search-page-2.json"))
            .with_header("content-type", "application/geo+json")
            .create_async()
            .await;

        let client = Client::new(&server.url()).unwrap();
        let mut search = Search {
            collections: vec!["sentinel-2-l2a".to_string()],
            ..Default::default()
        };
        search.items.limit = Some(1);
        let items: Vec<_> = client
            .search(search)
            .await
            .unwrap()
            .map(|result| result.unwrap())
            .take(2)
            .collect()
            .await;
        page_1.assert_async().await;
        page_2.assert_async().await;
        assert_eq!(items.len(), 2);
        assert!(items[0]["id"] != items[1]["id"]);
    }

    #[tokio::test]
    async fn items_with_paging() {
        let mut server = Server::new_async().await;
        let mut page_1_body: ItemCollection =
            serde_json::from_str(include_str!("../mocks/items-page-1.json")).unwrap();
        let mut next_link = page_1_body.link("next").unwrap().clone();
        let url: Url = next_link.href.as_str().parse().unwrap();
        let query = url.query().unwrap();
        next_link.href = format!(
            "{}/collections/sentinel-2-l2a/items?{}",
            server.url(),
            query
        )
        .into();
        page_1_body.set_link(next_link);
        let page_1 = server
            .mock("GET", "/collections/sentinel-2-l2a/items?limit=1")
            .with_body(serde_json::to_string(&page_1_body).unwrap())
            .with_header("content-type", "application/geo+json")
            .create_async()
            .await;
        let page_2 = server
            .mock("GET", "/collections/sentinel-2-l2a/items?limit=1&token=next:S2A_MSIL2A_20230216T235751_R087_T52CEB_20230217T134604")
            .with_body(include_str!("../mocks/items-page-2.json"))
            .with_header("content-type", "application/geo+json")
            .create_async()
            .await;

        let client = Client::new(&server.url()).unwrap();
        let items = Items {
            limit: Some(1),
            ..Default::default()
        };
        let items: Vec<_> = client
            .items("sentinel-2-l2a", Some(items))
            .await
            .unwrap()
            .map(|result| result.unwrap())
            .take(2)
            .collect()
            .await;
        page_1.assert_async().await;
        page_2.assert_async().await;
        assert_eq!(items.len(), 2);
        assert!(items[0]["id"] != items[1]["id"]);
    }

    #[tokio::test]
    async fn stop_on_empty_page() {
        let mut server = Server::new_async().await;
        let mut page_body: ItemCollection =
            serde_json::from_str(include_str!("../mocks/items-page-1.json")).unwrap();
        let mut next_link = page_body.link("next").unwrap().clone();
        let url: Url = next_link.href.as_str().parse().unwrap();
        let query = url.query().unwrap();
        next_link.href = format!(
            "{}/collections/sentinel-2-l2a/items?{}",
            server.url(),
            query
        )
        .into();
        page_body.set_link(next_link);
        page_body.items = vec![];
        let page = server
            .mock("GET", "/collections/sentinel-2-l2a/items?limit=1")
            .with_body(serde_json::to_string(&page_body).unwrap())
            .with_header("content-type", "application/geo+json")
            .create_async()
            .await;

        let client = Client::new(&server.url()).unwrap();
        let items = Items {
            limit: Some(1),
            ..Default::default()
        };
        let items: Vec<_> = client
            .items("sentinel-2-l2a", Some(items))
            .await
            .unwrap()
            .map(|result| result.unwrap())
            .collect()
            .await;
        page.assert_async().await;
        assert!(items.is_empty());
    }

    #[tokio::test]
    async fn user_agent() {
        let mut server = Server::new_async().await;
        // Keep the mock handle alive and assert on it, like the other tests.
        let mock = server
            .mock("POST", "/search")
            .match_header(
                "user-agent",
                format!("stac-rs/{}", env!("CARGO_PKG_VERSION")).as_str(),
            )
            .with_body(include_str!("../mocks/items-page-1.json"))
            .with_header("content-type", "application/geo+json")
            .create_async()
            .await;
        let client = Client::new(&server.url()).unwrap();
        let _ = client.search(Default::default()).await.unwrap();
        mock.assert_async().await;
    }
}