#![warn(unused_crate_dependencies)]
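
//! Search [STAC](https://stacspec.org/) items in geoparquet files with [DuckDB](https://duckdb.org/).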
use arrow::{
array::{GenericByteArray, RecordBatch},
datatypes::{GenericBinaryType, SchemaBuilder},
};
use duckdb::{types::Value, Connection};
use geoarrow::{
array::{CoordType, WKBArray},
datatypes::NativeType,
table::Table,
};
use stac_api::{Direction, Search};
use std::fmt::Debug;
use thiserror::Error;
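
/// A crate-specific error enum.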
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum Error {
#[error(transparent)]
Arrow(#[from] arrow::error::ArrowError),
#[error(transparent)]
DuckDB(#[from] duckdb::Error),
#[error(transparent)]
GeoArrow(#[from] geoarrow::error::GeoArrowError),
#[error(transparent)]
Stac(#[from] stac::Error),
#[error(transparent)]
StacApi(#[from] stac_api::Error),
}
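
/// A crate-specific result type.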
pub type Result<T> = std::result::Result<T, Error>;
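
/// A client for searching stac-geoparquet files with DuckDB.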
#[derive(Debug)]
pub struct Client {
connection: Connection,
}
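
/// A SQL query and the parameters bound to it.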
#[derive(Debug)]
pub struct Query {
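    /// The SQL string, with `?` placeholders for its parameters.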
pub sql: String,
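    /// The values bound to the placeholders in [Query::sql].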
pub params: Vec<Value>,
}
impl Client {
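    /// Creates a new client with an in-memory DuckDB connection.
    ///
    /// The connection installs and loads DuckDB's `spatial` extension, which
    /// provides the `ST_*` functions used to build and run queries.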
pub fn new() -> Result<Client> {
let connection = Connection::open_in_memory()?;
connection.execute("INSTALL spatial", [])?;
connection.execute("LOAD spatial", [])?;
Ok(Client { connection })
}
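
    /// Searches a stac-geoparquet file, returning a [stac::ItemCollection].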
pub fn search(&self, href: &str, search: impl Into<Search>) -> Result<stac::ItemCollection> {
let record_batches = self.search_to_arrow(href, search)?;
if record_batches.is_empty() {
return Ok(Vec::new().into());
}
let schema = record_batches[0].schema();
let table = Table::try_new(record_batches, schema)?;
let items = stac::geoarrow::from_table(table)?;
Ok(items)
}
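
    /// Searches a stac-geoparquet file, returning a [stac_api::ItemCollection].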
pub fn search_to_json(
&self,
href: &str,
search: impl Into<Search>,
) -> Result<stac_api::ItemCollection> {
let record_batches = self.search_to_arrow(href, search)?;
if record_batches.is_empty() {
return Ok(Vec::new().into());
}
let schema = record_batches[0].schema();
let table = Table::try_new(record_batches, schema)?;
let items = stac::geoarrow::json::from_table(table)?;
let item_collection = stac_api::ItemCollection::new(items)?;
Ok(item_collection)
}
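
    /// Searches a stac-geoparquet file, returning [RecordBatch]es with the
    /// `geometry` column converted from WKB to a native geoarrow geometry array.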
pub fn search_to_arrow(
&self,
href: &str,
search: impl Into<Search>,
) -> Result<Vec<RecordBatch>> {
let query = self.query(search, href)?;
let mut statement = self.connection.prepare(&query.sql)?;
statement
.query_arrow(duckdb::params_from_iter(query.params))?
.map(to_geoarrow_record_batch)
.collect::<Result<_>>()
}
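
    // Builds the SQL and bound parameters for a search against a stac-geoparquet href.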
fn query(&self, search: impl Into<Search>, href: &str) -> Result<Query> {
let mut search: Search = search.into();
let limit = search.items.limit.take();
let sortby = std::mem::take(&mut search.items.sortby);
let fields = std::mem::take(&mut search.items.fields);
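
        // DESCRIBE the parquet file to discover its columns, so the fields extension
        // can be applied and start_datetime/end_datetime can be detected.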
        let mut statement = self.connection.prepare(&format!(
            "SELECT column_name FROM (DESCRIBE SELECT * FROM read_parquet('{}'))",
            href
        ))?;
let mut columns = Vec::new();
let mut has_start_datetime = false;
        let mut has_end_datetime = false;
for row in statement.query_map([], |row| row.get::<_, String>(0))? {
let column = row?;
if column == "start_datetime" {
has_start_datetime = true;
}
if column == "end_datetime" {
has_end_datetime = true;
}
if let Some(fields) = fields.as_ref() {
if fields.exclude.contains(&column)
|| !(fields.include.is_empty() || fields.include.contains(&column))
{
continue;
}
}
if column == "geometry" {
columns.push("ST_AsWKB(geometry) geometry".to_string());
} else {
columns.push(format!("\"{}\"", column));
}
}
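
        // Build the WHERE clauses and their bound parameters from the search.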
let mut wheres = Vec::new();
let mut params = Vec::new();
if !search.ids.is_empty() {
wheres.push(format!(
"id IN ({})",
(0..search.ids.len())
.map(|_| "?")
.collect::<Vec<_>>()
.join(",")
));
params.extend(search.ids.into_iter().map(Value::Text));
}
if let Some(intersects) = search.intersects {
wheres.push("ST_Intersects(geometry, ST_GeomFromGeoJSON(?))".to_string());
params.push(Value::Text(intersects.to_string()));
}
if !search.collections.is_empty() {
wheres.push(format!(
"collection IN ({})",
(0..search.collections.len())
.map(|_| "?")
.collect::<Vec<_>>()
.join(",")
));
params.extend(search.collections.into_iter().map(Value::Text));
}
if let Some(bbox) = search.items.bbox {
wheres.push("ST_Intersects(geometry, ST_GeomFromGeoJSON(?))".to_string());
params.push(Value::Text(bbox.to_geometry().to_string()));
}
        if let Some(datetime) = search.items.datetime {
            let interval = stac::datetime::parse(&datetime)?;
            // Datetime search is an interval intersection: the search start must not be
            // after the item's end, and the search end must not be before the item's start.
            if let Some(start) = interval.0 {
                wheres.push(format!(
                    "?::TIMESTAMPTZ <= {}",
                    if has_end_datetime {
                        "end_datetime"
                    } else {
                        "datetime"
                    }
                ));
                params.push(Value::Text(start.to_rfc3339()));
            }
            if let Some(end) = interval.1 {
                wheres.push(format!(
                    "?::TIMESTAMPTZ >= {}",
                    if has_start_datetime {
                        "start_datetime"
                    } else {
                        "datetime"
                    }
                ));
                params.push(Value::Text(end.to_rfc3339()));
            }
        }
if search.items.filter.is_some() {
todo!("Implement the filter extension");
}
if search.items.query.is_some() {
todo!("Implement the query extension");
}
let mut suffix = String::new();
if !wheres.is_empty() {
suffix.push_str(&format!(" WHERE {}", wheres.join(" AND ")));
}
if !sortby.is_empty() {
let mut order_by = Vec::with_capacity(sortby.len());
for sortby in sortby {
order_by.push(format!(
"{} {}",
sortby.field,
match sortby.direction {
Direction::Ascending => "ASC",
Direction::Descending => "DESC",
}
));
}
suffix.push_str(&format!(" ORDER BY {}", order_by.join(", ")));
}
if let Some(limit) = limit {
suffix.push_str(&format!(" LIMIT {}", limit));
}
Ok(Query {
sql: format!(
"SELECT {} FROM read_parquet('{}'){}",
columns.join(","),
href,
suffix,
),
params,
})
}
}
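
/// Returns this crate's version.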
pub fn version() -> &'static str {
env!("CARGO_PKG_VERSION")
}
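
// Converts the WKB `geometry` column (produced by `ST_AsWKB`) into a native geoarrow
// geometry array and appends it, with its geoarrow extension field, to the record batch.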
fn to_geoarrow_record_batch(mut record_batch: RecordBatch) -> Result<RecordBatch> {
if let Some((index, _)) = record_batch.schema().column_with_name("geometry") {
let geometry_column = record_batch.remove_column(index);
let binary_array: GenericByteArray<GenericBinaryType<i32>> =
arrow::array::downcast_array(&geometry_column);
let wkb_array = WKBArray::new(binary_array, Default::default());
let geometry_array = geoarrow::io::wkb::from_wkb(
&wkb_array,
NativeType::Geometry(CoordType::Interleaved),
false,
)?;
let mut columns = record_batch.columns().to_vec();
let mut schema_builder = SchemaBuilder::from(&*record_batch.schema());
schema_builder.push(geometry_array.extension_field());
let schema = schema_builder.finish();
columns.push(geometry_array.to_array_ref());
record_batch = RecordBatch::try_new(schema.into(), columns)?;
}
Ok(record_batch)
}
#[cfg(test)]
mod tests {
use super::Client;
use geo::Geometry;
use rstest::{fixture, rstest};
use stac::{Bbox, ValidateBlocking};
use stac_api::{Search, Sortby};
use std::sync::Mutex;
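
    // Creating a client installs and loads the DuckDB spatial extension, so
    // serialize client creation across tests.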
static MUTEX: Mutex<()> = Mutex::new(());
#[fixture]
fn client() -> Client {
let _mutex = MUTEX.lock().unwrap();
Client::new().unwrap()
}
#[rstest]
fn search_all(client: Client) {
let item_collection = client
.search("data/100-sentinel-2-items.parquet", Search::default())
.unwrap();
assert_eq!(item_collection.items.len(), 100);
item_collection.items[0].validate_blocking().unwrap();
}
#[rstest]
fn search_ids(client: Client) {
let item_collection = client
.search(
"data/100-sentinel-2-items.parquet",
Search::default().ids(vec![
"S2A_MSIL2A_20240326T174951_R141_T13TDE_20240329T224429".to_string(),
]),
)
.unwrap();
assert_eq!(item_collection.items.len(), 1);
assert_eq!(
item_collection.items[0].id,
"S2A_MSIL2A_20240326T174951_R141_T13TDE_20240329T224429"
);
}
#[rstest]
fn search_intersects(client: Client) {
let item_collection = client
.search(
"data/100-sentinel-2-items.parquet",
Search::default().intersects(&Geometry::Point(geo::point! { x: -106., y: 40.5 })),
)
.unwrap();
assert_eq!(item_collection.items.len(), 50);
}
#[rstest]
fn search_collections(client: Client) {
let item_collection = client
.search(
"data/100-sentinel-2-items.parquet",
Search::default().collections(vec!["sentinel-2-l2a".to_string()]),
)
.unwrap();
assert_eq!(item_collection.items.len(), 100);
let item_collection = client
.search(
"data/100-sentinel-2-items.parquet",
Search::default().collections(vec!["foobar".to_string()]),
)
.unwrap();
assert_eq!(item_collection.items.len(), 0);
}
#[rstest]
fn search_bbox(client: Client) {
let item_collection = client
.search(
"data/100-sentinel-2-items.parquet",
Search::default().bbox(Bbox::new(-106.1, 40.5, -106.0, 40.6)),
)
.unwrap();
assert_eq!(item_collection.items.len(), 50);
}
#[rstest]
fn search_datetime(client: Client) {
let item_collection = client
.search(
"data/100-sentinel-2-items.parquet",
Search::default().datetime("2024-12-02T00:00:00Z/.."),
)
.unwrap();
assert_eq!(item_collection.items.len(), 1);
let item_collection = client
.search(
"data/100-sentinel-2-items.parquet",
Search::default().datetime("../2024-12-02T00:00:00Z"),
)
.unwrap();
assert_eq!(item_collection.items.len(), 99);
}
#[rstest]
fn search_limit(client: Client) {
let item_collection = client
.search(
"data/100-sentinel-2-items.parquet",
Search::default().limit(42),
)
.unwrap();
assert_eq!(item_collection.items.len(), 42);
}
#[rstest]
fn search_sortby(client: Client) {
let item_collection = client
.search(
"data/100-sentinel-2-items.parquet",
Search::default()
.sortby(vec![Sortby::asc("datetime")])
.limit(1),
)
.unwrap();
assert_eq!(
item_collection.items[0].id,
"S2A_MSIL2A_20240326T174951_R141_T13TDE_20240329T224429"
);
let item_collection = client
.search(
"data/100-sentinel-2-items.parquet",
Search::default()
.sortby(vec![Sortby::desc("datetime")])
.limit(1),
)
.unwrap();
assert_eq!(
item_collection.items[0].id,
"S2B_MSIL2A_20241203T174629_R098_T13TDE_20241203T211406"
);
}
#[rstest]
fn search_fields(client: Client) {
let item_collection = client
.search_to_json(
"data/100-sentinel-2-items.parquet",
Search::default().fields("+id".parse().unwrap()).limit(1),
)
.unwrap();
assert_eq!(item_collection.items[0].len(), 1);
}
}