stac_api/
client.rs

1//! A STAC API client.
2
3use crate::{Error, GetItems, Item, ItemCollection, Items, Result, Search, UrlBuilder};
4use async_stream::try_stream;
5use futures::{Stream, StreamExt, pin_mut};
6use http::header::{HeaderName, USER_AGENT};
7use reqwest::{ClientBuilder, IntoUrl, Method, StatusCode, header::HeaderMap};
8use serde::{Serialize, de::DeserializeOwned};
9use serde_json::{Map, Value};
10use stac::{Collection, Link, Links, SelfHref};
11use std::pin::Pin;
12use tokio::{
13    runtime::{Builder, Runtime},
14    sync::mpsc::{self, error::SendError},
15    task::JoinHandle,
16};
17
/// Default capacity of the bounded channel that hands pages of items from the
/// background paging task to the item stream.
const DEFAULT_CHANNEL_BUFFER: usize = 4;
19
20/// Searches a STAC API.
21pub async fn search(
22    href: &str,
23    mut search: Search,
24    max_items: Option<usize>,
25) -> Result<ItemCollection> {
26    let client = Client::new(href)?;
27    if search.limit.is_none() {
28        if let Some(max_items) = max_items {
29            search.limit = Some(max_items.try_into()?);
30        }
31    }
32    let stream = client.search(search).await?;
33    let mut items = if let Some(max_items) = max_items {
34        if max_items == 0 {
35            return Ok(ItemCollection::default());
36        }
37        Vec::with_capacity(max_items)
38    } else {
39        Vec::new()
40    };
41    pin_mut!(stream);
42    while let Some(item) = stream.next().await {
43        let item = item?;
44        items.push(item);
45        if let Some(max_items) = max_items {
46            if items.len() >= max_items {
47                break;
48            }
49        }
50    }
51    ItemCollection::new(items)
52}
53
/// A client for interacting with STAC APIs.
///
/// The client is cheap to hand around by cloning; item streams keep their own
/// clone so that paging can continue in a background task.
#[derive(Clone, Debug)]
pub struct Client {
    /// The HTTP client used for every request.
    client: reqwest::Client,
    /// Capacity of the channel that buffers pages for item streams.
    channel_buffer: usize,
    /// Builds the collection, items, and search urls for the API.
    url_builder: UrlBuilder,
}
61
/// A client for interacting with STAC APIs without async.
///
/// This is a thin wrapper around [Client] whose futures are driven to
/// completion on a Tokio runtime.
#[derive(Debug)]
pub struct BlockingClient(Client);
65
/// A blocking iterator over items.
///
/// Each call to `next` blocks the current thread on the owned runtime until
/// the wrapped stream yields its next item.
// `Debug` is not derived because the boxed `dyn Stream` carries no `Debug` bound.
#[allow(missing_debug_implementations)]
pub struct BlockingIterator {
    /// The runtime that drives `stream`.
    runtime: Runtime,
    /// The asynchronous stream of items being adapted into an iterator.
    stream: Pin<Box<dyn Stream<Item = Result<Item>>>>,
}
72
73impl Client {
74    /// Creates a new API client.
75    ///
76    /// # Examples
77    ///
78    /// ```
79    /// # use stac_api::Client;
80    /// let client = Client::new("https://planetarycomputer.microsoft.com/api/stac/v1").unwrap();
81    /// ```
82    pub fn new(url: &str) -> Result<Client> {
83        // TODO support HATEOS (aka look up the urls from the root catalog)
84        let mut headers = HeaderMap::new();
85        let _ = headers.insert(
86            USER_AGENT,
87            format!("rustac/{}", env!("CARGO_PKG_VERSION")).parse()?,
88        );
89        let client = ClientBuilder::new().default_headers(headers).build()?;
90        Client::with_client(client, url)
91    }
92
93    /// Creates a new API client with the given [Client].
94    ///
95    /// Useful if you want to customize the behavior of the underlying `Client`,
96    /// as documented in [Client::new].
97    ///
98    /// # Examples
99    ///
100    /// ```
101    /// use stac_api::Client;
102    ///
103    /// let client = reqwest::Client::new();
104    /// let client = Client::with_client(client, "https://earth-search.aws.element84.com/v1/").unwrap();
105    /// ```
106    pub fn with_client(client: reqwest::Client, url: &str) -> Result<Client> {
107        Ok(Client {
108            client,
109            channel_buffer: DEFAULT_CHANNEL_BUFFER,
110            url_builder: UrlBuilder::new(url)?,
111        })
112    }
113
114    /// Returns a single collection.
115    ///
116    /// # Examples
117    ///
118    /// ```no_run
119    /// # use stac_api::Client;
120    /// let client = Client::new("https://planetarycomputer.microsoft.com/api/stac/v1").unwrap();
121    /// # tokio_test::block_on(async {
122    /// let collection = client.collection("sentinel-2-l2a").await.unwrap().unwrap();
123    /// # })
124    /// ```
125    pub async fn collection(&self, id: &str) -> Result<Option<Collection>> {
126        let url = self.url_builder.collection(id)?;
127        not_found_to_none(self.get(url).await)
128    }
129
130    /// Returns a stream of items belonging to a collection, using the [items
131    /// endpoint](https://github.com/radiantearth/stac-api-spec/tree/main/ogcapi-features#collection-items-collectionscollectioniditems).
132    ///
133    /// The `items` argument can be used to filter, sort, and otherwise
134    /// configure the request.
135    ///
136    /// # Examples
137    ///
138    /// ```no_run
139    /// use stac_api::{Items, Client};
140    /// use futures::StreamExt;
141    ///
142    /// let client = Client::new("https://planetarycomputer.microsoft.com/api/stac/v1").unwrap();
143    /// let items = Items {
144    ///     limit: Some(1),
145    ///     ..Default::default()
146    /// };
147    /// # tokio_test::block_on(async {
148    /// let items: Vec<_> = client
149    ///     .items("sentinel-2-l2a", items)
150    ///     .await
151    ///     .unwrap()
152    ///     .map(|result| result.unwrap())
153    ///     .collect()
154    ///     .await;
155    /// assert_eq!(items.len(), 1);
156    /// # })
157    /// ```
158    pub async fn items(
159        &self,
160        id: &str,
161        items: impl Into<Option<Items>>,
162    ) -> Result<impl Stream<Item = Result<Item>>> {
163        let url = self.url_builder.items(id)?; // TODO HATEOS
164        let items = match items.into() {
165            Some(items) => Some(GetItems::try_from(items)?),
166            _ => None,
167        };
168        let page = self
169            .request(Method::GET, url.clone(), items.as_ref(), None)
170            .await?;
171        Ok(stream_items(self.clone(), page, self.channel_buffer))
172    }
173
174    /// Searches an API, returning a stream of items.
175    ///
176    /// # Examples
177    ///
178    /// ```no_run
179    /// use stac_api::{Search, Client};
180    /// use futures::StreamExt;
181    ///
182    /// let client = Client::new("https://planetarycomputer.microsoft.com/api/stac/v1").unwrap();
183    /// let mut search = Search { collections: vec!["sentinel-2-l2a".to_string()], ..Default::default() };
184    /// # tokio_test::block_on(async {
185    /// let items: Vec<_> = client
186    ///     .search(search)
187    ///     .await
188    ///     .unwrap()
189    ///     .take(1)
190    ///     .map(|result| result.unwrap())
191    ///     .collect()
192    ///     .await;
193    /// assert_eq!(items.len(), 1);
194    /// # })
195    /// ```
196    pub async fn search(&self, search: Search) -> Result<impl Stream<Item = Result<Item>> + use<>> {
197        let url = self.url_builder.search().clone();
198        tracing::debug!("searching {url}");
199        // TODO support GET
200        let page = self.post(url.clone(), &search).await?;
201        Ok(stream_items(self.clone(), page, self.channel_buffer))
202    }
203
204    async fn get<V>(&self, url: impl IntoUrl) -> Result<V>
205    where
206        V: DeserializeOwned + SelfHref,
207    {
208        let url = url.into_url()?;
209        let mut value = self
210            .request::<(), V>(Method::GET, url.clone(), None, None)
211            .await?;
212        value.set_self_href(url);
213        Ok(value)
214    }
215
216    async fn post<S, R>(&self, url: impl IntoUrl, data: &S) -> Result<R>
217    where
218        S: Serialize + 'static,
219        R: DeserializeOwned,
220    {
221        self.request(Method::POST, url, Some(data), None).await
222    }
223
224    async fn request<S, R>(
225        &self,
226        method: Method,
227        url: impl IntoUrl,
228        params: impl Into<Option<&S>>,
229        headers: impl Into<Option<HeaderMap>>,
230    ) -> Result<R>
231    where
232        S: Serialize + 'static,
233        R: DeserializeOwned,
234    {
235        let url = url.into_url()?;
236        let mut request = match method {
237            Method::GET => {
238                let mut request = self.client.get(url);
239                if let Some(query) = params.into() {
240                    request = request.query(query);
241                }
242                request
243            }
244            Method::POST => {
245                let mut request = self.client.post(url);
246                if let Some(data) = params.into() {
247                    request = request.json(&data);
248                }
249                request
250            }
251            _ => unimplemented!(),
252        };
253        if let Some(headers) = headers.into() {
254            request = request.headers(headers);
255        }
256        let response = request.send().await?.error_for_status()?;
257        response.json().await.map_err(Error::from)
258    }
259
260    async fn request_from_link<R>(&self, link: Link) -> Result<R>
261    where
262        R: DeserializeOwned,
263    {
264        let method = if let Some(method) = link.method {
265            method.parse()?
266        } else {
267            Method::GET
268        };
269        let headers = if let Some(headers) = link.headers {
270            let mut header_map = HeaderMap::new();
271            for (key, value) in headers.into_iter() {
272                let header_name: HeaderName = key.parse()?;
273                let _ = header_map.insert(header_name, value.to_string().parse()?);
274            }
275            Some(header_map)
276        } else {
277            None
278        };
279        self.request::<Map<String, Value>, R>(method, link.href.as_str(), &link.body, headers)
280            .await
281    }
282}
283
284impl BlockingClient {
285    /// Creates a new blocking client.
286    ///
287    /// # Examples
288    ///
289    /// ```
290    /// use stac_api::BlockingClient;
291    ///
292    /// let client = BlockingClient::new("https://planetarycomputer.microsoft.com/api/stac/vi").unwrap();
293    /// ```
294    pub fn new(url: &str) -> Result<BlockingClient> {
295        Client::new(url).map(Self)
296    }
297
298    /// Searches an API, returning an iterable of items.
299    ///
300    /// To prevent fetching _all_ the items (which might be a lot), it is recommended to pass a `max_items`.
301    ///
302    /// # Examples
303    ///
304    /// ```no_run
305    /// use stac_api::{Search, BlockingClient};
306    ///
307    /// let client = BlockingClient::new("https://planetarycomputer.microsoft.com/api/stac/v1").unwrap();
308    /// let mut search = Search { collections: vec!["sentinel-2-l2a".to_string()], ..Default::default() };
309    /// let items: Vec<_> = client
310    ///     .search(search)
311    ///     .unwrap()
312    ///     .map(|result| result.unwrap())
313    ///     .take(1)
314    ///     .collect();
315    /// assert_eq!(items.len(), 1);
316    /// ```
317    pub fn search(&self, search: Search) -> Result<BlockingIterator> {
318        let runtime = Builder::new_current_thread().enable_all().build()?;
319        let stream = runtime.block_on(async move { self.0.search(search).await })?;
320        Ok(BlockingIterator {
321            runtime,
322            stream: Box::pin(stream),
323        })
324    }
325}
326
327impl Iterator for BlockingIterator {
328    type Item = Result<Item>;
329
330    fn next(&mut self) -> Option<Self::Item> {
331        self.runtime.block_on(self.stream.next())
332    }
333}
334
335fn stream_items(
336    client: Client,
337    page: ItemCollection,
338    channel_buffer: usize,
339) -> impl Stream<Item = Result<Item>> {
340    let (tx, mut rx) = mpsc::channel(channel_buffer);
341    let handle: JoinHandle<std::result::Result<(), SendError<_>>> = tokio::spawn(async move {
342        let pages = stream_pages(client, page);
343        pin_mut!(pages);
344        while let Some(result) = pages.next().await {
345            match result {
346                Ok(page) => tx.send(Ok(page)).await?,
347                Err(err) => {
348                    tx.send(Err(err)).await?;
349                    return Ok(());
350                }
351            }
352        }
353        Ok(())
354    });
355    try_stream! {
356        while let Some(result) = rx.recv().await {
357            let page = result?;
358            for item in page.items {
359                yield item;
360            }
361        }
362        let _ = handle.await?;
363    }
364}
365
366fn stream_pages(
367    client: Client,
368    mut page: ItemCollection,
369) -> impl Stream<Item = Result<ItemCollection>> {
370    try_stream! {
371        loop {
372            if page.items.is_empty() {
373                break;
374            }
375            let next_link = page.link("next").cloned();
376            yield page;
377            if let Some(next_link) = next_link {
378                if let Some(next_page) = client.request_from_link(next_link).await? {
379                    page = next_page;
380                } else {
381                    break;
382                }
383            } else {
384                break;
385            }
386        }
387    }
388}
389
390fn not_found_to_none<T>(result: Result<T>) -> Result<Option<T>> {
391    let mut result = result.map(Some);
392    if let Err(Error::Reqwest(ref err)) = result {
393        if err
394            .status()
395            .map(|s| s == StatusCode::NOT_FOUND)
396            .unwrap_or_default()
397        {
398            result = Ok(None);
399        }
400    }
401    result
402}
403
// Tests for the async client, run against a local mock HTTP server that
// replays JSON fixtures from the `mocks` directory.
#[cfg(test)]
mod tests {
    use super::Client;
    use crate::{ItemCollection, Items, Search};
    use futures::StreamExt;
    use mockito::{Matcher, Server};
    use serde_json::json;
    use stac::Links;
    use url::Url;

    // A 404 from the collection endpoint is reported as `Ok(None)`, not as an error.
    #[tokio::test]
    async fn collection_not_found() {
        let mut server = Server::new_async().await;
        let collection = server
            .mock("GET", "/collections/not-a-collection")
            .with_body(include_str!("../mocks/not-a-collection.json"))
            .with_header("content-type", "application/json")
            .with_status(404)
            .create_async()
            .await;

        let client = Client::new(&server.url()).unwrap();
        assert!(
            client
                .collection("not-a-collection")
                .await
                .unwrap()
                .is_none()
        );
        collection.assert_async().await;
    }

    // A POST search follows the `next` link of the first page to fetch a second page.
    #[tokio::test]
    async fn search_with_paging() {
        let mut server = Server::new_async().await;
        let mut page_1_body: ItemCollection =
            serde_json::from_str(include_str!("../mocks/search-page-1.json")).unwrap();
        // Point the fixture's `next` link at the mock server.
        let mut next_link = page_1_body.link("next").unwrap().clone();
        next_link.href = format!("{}/search", server.url());
        page_1_body.set_link(next_link);
        let page_1 = server
            .mock("POST", "/search")
            .match_body(Matcher::Json(json!({
                "collections": ["sentinel-2-l2a"],
                "limit": 1
            })))
            .with_body(serde_json::to_string(&page_1_body).unwrap())
            .with_header("content-type", "application/geo+json")
            .create_async()
            .await;
        // The follow-up request is expected to carry the paging token in its body.
        let page_2 = server
            .mock("POST", "/search")
            .match_body(Matcher::Json(json!({
                "collections": ["sentinel-2-l2a"],
                "limit": 1,
                "token": "next:S2A_MSIL2A_20230216T150721_R082_T19PHS_20230217T082924"
            })))
            .with_body(include_str!("../mocks/search-page-2.json"))
            .with_header("content-type", "application/geo+json")
            .create_async()
            .await;

        let client = Client::new(&server.url()).unwrap();
        let mut search = Search {
            collections: vec!["sentinel-2-l2a".to_string()],
            ..Default::default()
        };
        search.items.limit = Some(1);
        let items: Vec<_> = client
            .search(search)
            .await
            .unwrap()
            .map(|result| result.unwrap())
            .take(2)
            .collect()
            .await;
        page_1.assert_async().await;
        page_2.assert_async().await;
        assert_eq!(items.len(), 2);
        assert!(items[0]["id"] != items[1]["id"]);
    }

    // The items endpoint (GET) follows the `next` link of the first page as well.
    #[tokio::test]
    async fn items_with_paging() {
        let mut server = Server::new_async().await;
        let mut page_1_body: ItemCollection =
            serde_json::from_str(include_str!("../mocks/items-page-1.json")).unwrap();
        // Keep the fixture's query string but point the `next` link at the mock server.
        let mut next_link = page_1_body.link("next").unwrap().clone();
        let url: Url = next_link.href.as_str().parse().unwrap();
        let query = url.query().unwrap();
        next_link.href = format!(
            "{}/collections/sentinel-2-l2a/items?{}",
            server.url(),
            query
        );
        page_1_body.set_link(next_link);
        let page_1 = server
            .mock("GET", "/collections/sentinel-2-l2a/items?limit=1")
            .with_body(serde_json::to_string(&page_1_body).unwrap())
            .with_header("content-type", "application/geo+json")
            .create_async()
            .await;
        let page_2 = server
            .mock("GET", "/collections/sentinel-2-l2a/items?limit=1&token=next:S2A_MSIL2A_20230216T235751_R087_T52CEB_20230217T134604")
            .with_body(include_str!("../mocks/items-page-2.json"))
            .with_header("content-type", "application/geo+json")
            .create_async()
            .await;

        let client = Client::new(&server.url()).unwrap();
        let items = Items {
            limit: Some(1),
            ..Default::default()
        };
        let items: Vec<_> = client
            .items("sentinel-2-l2a", Some(items))
            .await
            .unwrap()
            .map(|result| result.unwrap())
            .take(2)
            .collect()
            .await;
        page_1.assert_async().await;
        page_2.assert_async().await;
        assert_eq!(items.len(), 2);
        assert!(items[0]["id"] != items[1]["id"]);
    }

    // A page without items ends the stream even though it still has a `next` link.
    #[tokio::test]
    async fn stop_on_empty_page() {
        let mut server = Server::new_async().await;
        let mut page_body: ItemCollection =
            serde_json::from_str(include_str!("../mocks/items-page-1.json")).unwrap();
        let mut next_link = page_body.link("next").unwrap().clone();
        let url: Url = next_link.href.as_str().parse().unwrap();
        let query = url.query().unwrap();
        next_link.href = format!(
            "{}/collections/sentinel-2-l2a/items?{}",
            server.url(),
            query
        );
        page_body.set_link(next_link);
        // Empty the page while keeping its `next` link.
        page_body.items = vec![];
        // Only the first page is mocked; no mock exists for the `next` link.
        let page = server
            .mock("GET", "/collections/sentinel-2-l2a/items?limit=1")
            .with_body(serde_json::to_string(&page_body).unwrap())
            .with_header("content-type", "application/geo+json")
            .create_async()
            .await;

        let client = Client::new(&server.url()).unwrap();
        let items = Items {
            limit: Some(1),
            ..Default::default()
        };
        let items: Vec<_> = client
            .items("sentinel-2-l2a", Some(items))
            .await
            .unwrap()
            .map(|result| result.unwrap())
            .collect()
            .await;
        page.assert_async().await;
        assert!(items.is_empty());
    }

    // Requests made by `Client::new` clients carry the `rustac/<version>` user agent.
    #[tokio::test]
    async fn user_agent() {
        let mut server = Server::new_async().await;
        // NOTE(review): the mock is never asserted; presumably a request with a
        // different user agent is not matched and makes the `unwrap` below
        // fail — confirm against mockito's behavior for unmatched requests.
        let _ = server
            .mock("POST", "/search")
            .with_body_from_file("mocks/items-page-1.json")
            .match_header(
                "user-agent",
                format!("rustac/{}", env!("CARGO_PKG_VERSION")).as_str(),
            )
            .create_async()
            .await;
        let client = Client::new(&server.url()).unwrap();
        let _ = client.search(Default::default()).await.unwrap();
    }
}