use crate::{Client, Error, Result};
use async_stream::try_stream;
use futures_core::stream::Stream;
use futures_util::{pin_mut, StreamExt};
use reqwest::Method;
use stac::{Collection, Links};
use stac_api::{GetItems, Item, ItemCollection, Items, Search, UrlBuilder};
use tokio::sync::mpsc;
/// Default capacity, in pages, of the bounded channel that buffers results
/// between the background paging task and the item stream it feeds.
const DEFAULT_CHANNEL_BUFFER: usize = 4;
/// Client for a STAC API, constructed from the API's base URL.
///
/// Fetches single collections and exposes paged `items` and `search`
/// responses as streams of items, following `next` links between pages.
#[derive(Debug)]
pub struct ApiClient {
/// HTTP client used for every request; cloned into background paging tasks.
client: Client,
/// Capacity of the channel that buffers fetched pages for item streams.
channel_buffer: usize,
/// Builds the collection, items, and search endpoint URLs.
url_builder: UrlBuilder,
}
impl ApiClient {
pub fn new(url: &str) -> Result<ApiClient> {
Ok(ApiClient {
client: Client::new(),
channel_buffer: DEFAULT_CHANNEL_BUFFER,
url_builder: UrlBuilder::new(url)?,
})
}
pub async fn collection(&self, id: &str) -> Result<Option<Collection>> {
let url = self.url_builder.collection(id)?;
self.client.get(url).await
}
pub async fn items(
&self,
id: &str,
items: impl Into<Option<Items>>,
) -> Result<impl Stream<Item = Result<Item>>> {
let url = self.url_builder.items(id)?; let items = if let Some(items) = items.into() {
Some(GetItems::try_from(items)?)
} else {
None
};
let page: Option<ItemCollection> = self
.client
.request(Method::GET, url.clone(), items.as_ref(), None)
.await?;
if let Some(page) = page {
Ok(stream_items(self.client.clone(), page, self.channel_buffer))
} else {
Err(Error::NotFound(url))
}
}
pub async fn search(&self, search: Search) -> Result<impl Stream<Item = Result<Item>>> {
let url = self.url_builder.search().clone();
let page: Option<ItemCollection> = self.client.post(url.clone(), &search).await?;
if let Some(page) = page {
Ok(stream_items(self.client.clone(), page, self.channel_buffer))
} else {
Err(Error::NotFound(url))
}
}
}
/// Converts a first page of results into a stream of the items on it and on
/// every following page.
///
/// A background task spawned with `tokio::spawn` walks the pages produced by
/// [`stream_pages`] and sends them through a bounded channel holding up to
/// `channel_buffer` pages, so later pages can be fetched while earlier items
/// are still being consumed.
///
/// The first paging error is yielded to the consumer and ends the stream. If
/// the consumer drops the stream early (e.g. after `take(n)`), the background
/// task sees the closed channel and exits quietly instead of panicking.
///
/// # Panics
///
/// Panics if called outside a Tokio runtime or if `channel_buffer` is 0 (both
/// requirements of the Tokio APIs used here).
fn stream_items(
    client: Client,
    page: ItemCollection,
    channel_buffer: usize,
) -> impl Stream<Item = Result<Item>> {
    let (tx, mut rx) = mpsc::channel(channel_buffer);
    let handle = tokio::spawn(async move {
        let pages = stream_pages(client, page);
        pin_mut!(pages);
        while let Some(result) = pages.next().await {
            match result {
                Ok(page) => {
                    // `send` only fails once the receiver is gone, i.e. the
                    // consumer dropped the stream. Nobody needs more pages, so
                    // stop paging rather than panicking the task.
                    if tx.send(Ok(page)).await.is_err() {
                        return;
                    }
                }
                Err(err) => {
                    // Best effort: if the consumer is already gone there is
                    // nobody to report the error to. Either way, stop paging.
                    let _ = tx.send(Err(err)).await;
                    return;
                }
            }
        }
    });
    try_stream! {
        while let Some(result) = rx.recv().await {
            let page = result?;
            for item in page.items {
                yield item;
            }
        }
        // Every page has been delivered; surface a panic or cancellation of
        // the background task as an error.
        handle.await?;
    }
}
/// Streams `page` followed by every page reachable through `next` links.
///
/// Paging stops at the first page with no items (that page is not yielded),
/// after a page without a `next` link, or when following a `next` link
/// returns no page. An error from following a link is yielded and ends the
/// stream.
fn stream_pages(
    client: Client,
    page: ItemCollection,
) -> impl Stream<Item = Result<ItemCollection>> {
    try_stream! {
        let mut current = page;
        while !current.items.is_empty() {
            // Grab the link before `current` is moved out by `yield`.
            let next_link = current.link("next").cloned();
            yield current;
            let next_page: Option<ItemCollection> = match next_link {
                Some(link) => client.request_from_link(link).await?,
                None => None,
            };
            match next_page {
                Some(fetched) => current = fetched,
                None => break,
            }
        }
    }
}
// Tests for `ApiClient`, run against a local mockito server. Response bodies
// come from JSON fixtures in the `mocks` directory, loaded with `include_str!`
// relative to this source file.
#[cfg(test)]
mod tests {
use super::ApiClient;
use futures_util::stream::StreamExt;
use mockito::{Matcher, Server};
use serde_json::json;
use stac::Links;
use stac_api::{ItemCollection, Items, Search};
use url::Url;
// A 404 response for a collection surfaces as `Ok(None)`, not as an error.
#[tokio::test]
async fn collection_not_found() {
let mut server = Server::new_async().await;
let collection = server
.mock("GET", "/collections/not-a-collection")
.with_body(include_str!("../mocks/not-a-collection.json"))
.with_header("content-type", "application/json")
.with_status(404)
.create_async()
.await;
let client = ApiClient::new(&server.url()).unwrap();
assert!(client
.collection("not-a-collection")
.await
.unwrap()
.is_none());
collection.assert_async().await;
}
// Search paging over two POST /search requests. The fixture's `next` link is
// re-hosted on the mock server so paging never leaves the test.
#[tokio::test]
async fn search_with_paging() {
let mut server = Server::new_async().await;
let mut page_1_body: ItemCollection =
serde_json::from_str(include_str!("../mocks/search-page-1.json")).unwrap();
// Point the `next` link at the mock server instead of the fixture's host.
let mut next_link = page_1_body.link("next").unwrap().clone();
next_link.href = format!("{}/search", server.url());
page_1_body.set_link(next_link);
// First page: matched on the exact initial search body.
let page_1 = server
.mock("POST", "/search")
.match_body(Matcher::Json(json!({
"collections": ["sentinel-2-l2a"],
"limit": 1
})))
.with_body(serde_json::to_string(&page_1_body).unwrap())
.with_header("content-type", "application/geo+json")
.create_async()
.await;
// Second page: the same search plus a pagination `token`. NOTE(review):
// presumably the token comes from the `next` link in search-page-1.json —
// confirm against the fixture.
let page_2 = server
.mock("POST", "/search")
.match_body(Matcher::Json(json!({
"collections": ["sentinel-2-l2a"],
"limit": 1,
"token": "next:S2A_MSIL2A_20230216T150721_R082_T19PHS_20230217T082924"
})))
.with_body(include_str!("../mocks/search-page-2.json"))
.with_header("content-type", "application/geo+json")
.create_async()
.await;
let client = ApiClient::new(&server.url()).unwrap();
let search = Search {
collections: Some(vec!["sentinel-2-l2a".to_string()]),
limit: Some(1),
..Default::default()
};
// With `limit: 1`, presumably one item per page, so two items require a
// second page to be fetched.
let items: Vec<_> = client
.search(search)
.await
.unwrap()
.map(|result| result.unwrap())
.take(2)
.collect()
.await;
page_1.assert_async().await;
page_2.assert_async().await;
assert_eq!(items.len(), 2);
assert!(items[0]["id"] != items[1]["id"]);
}
// Items paging over two GET requests. The fixture's `next` link keeps its
// query string but is re-hosted on the mock server.
#[tokio::test]
async fn items_with_paging() {
let mut server = Server::new_async().await;
let mut page_1_body: ItemCollection =
serde_json::from_str(include_str!("../mocks/items-page-1.json")).unwrap();
let mut next_link = page_1_body.link("next").unwrap().clone();
let url: Url = next_link.href.parse().unwrap();
let query = url.query().unwrap();
next_link.href = format!(
"{}/collections/sentinel-2-l2a/items?{}",
server.url(),
query
);
page_1_body.set_link(next_link);
let page_1 = server
.mock("GET", "/collections/sentinel-2-l2a/items?limit=1")
.with_body(serde_json::to_string(&page_1_body).unwrap())
.with_header("content-type", "application/geo+json")
.create_async()
.await;
let page_2 = server
.mock("GET", "/collections/sentinel-2-l2a/items?limit=1&token=next:S2A_MSIL2A_20230216T235751_R087_T52CEB_20230217T134604")
.with_body(include_str!("../mocks/items-page-2.json"))
.with_header("content-type", "application/geo+json")
.create_async()
.await;
let client = ApiClient::new(&server.url()).unwrap();
let items = Items {
limit: Some(1),
..Default::default()
};
let items: Vec<_> = client
.items("sentinel-2-l2a", Some(items))
.await
.unwrap()
.map(|result| result.unwrap())
.take(2)
.collect()
.await;
page_1.assert_async().await;
page_2.assert_async().await;
assert_eq!(items.len(), 2);
assert!(items[0]["id"] != items[1]["id"]);
}
// A first page with no items ends the stream immediately, even though that
// page still carries a `next` link; only the first page is mocked.
#[tokio::test]
async fn stop_on_empty_page() {
let mut server = Server::new_async().await;
let mut page_body: ItemCollection =
serde_json::from_str(include_str!("../mocks/items-page-1.json")).unwrap();
let mut next_link = page_body.link("next").unwrap().clone();
let url: Url = next_link.href.parse().unwrap();
let query = url.query().unwrap();
next_link.href = format!(
"{}/collections/sentinel-2-l2a/items?{}",
server.url(),
query
);
page_body.set_link(next_link);
page_body.items = vec![];
let page = server
.mock("GET", "/collections/sentinel-2-l2a/items?limit=1")
.with_body(serde_json::to_string(&page_body).unwrap())
.with_header("content-type", "application/geo+json")
.create_async()
.await;
let client = ApiClient::new(&server.url()).unwrap();
let items = Items {
limit: Some(1),
..Default::default()
};
let items: Vec<_> = client
.items("sentinel-2-l2a", Some(items))
.await
.unwrap()
.map(|result| result.unwrap())
.collect()
.await;
page.assert_async().await;
assert!(items.is_empty());
}
}