1use crate::{Error, Extension, Result};
2use arrow_array::{RecordBatch, RecordBatchIterator};
3use chrono::DateTime;
4use cql2::{Expr, ToDuckSQL};
5use duckdb::{Connection, types::Value};
6use geo::BoundingRect;
7use geojson::Geometry;
8use stac::{Collection, SpatialExtent, TemporalExtent, geoarrow::DATETIME_COLUMNS};
9use stac_api::{Direction, Search};
10use std::ops::{Deref, DerefMut};
11
/// Default for [Client::use_hive_partitioning].
pub const DEFAULT_USE_HIVE_PARTITIONING: bool = false;

/// Default for [Client::convert_wkb].
pub const DEFAULT_CONVERT_WKB: bool = true;

/// Description given to collections synthesized by [Client::collections],
/// since stac-geoparquet files don't carry collection descriptions.
pub const DEFAULT_COLLECTION_DESCRIPTION: &str =
    "Auto-generated collection from stac-geoparquet extents";

/// Default for [Client::union_by_name].
pub const DEFAULT_UNION_BY_NAME: bool = true;
24
/// A client for querying stac-geoparquet files with DuckDB.
#[derive(Debug)]
pub struct Client {
    // The underlying DuckDB connection, also reachable through the
    // Deref/DerefMut impls below (e.g. `self.prepare(..)`).
    connection: Connection,

    /// Whether to pass `hive_partitioning=true` to DuckDB's `read_parquet`.
    pub use_hive_partitioning: bool,

    /// Whether [Client::search_to_arrow] converts the WKB `geometry` column
    /// to native geoarrow geometry; when false, the WKB column is kept and
    /// tagged with geoarrow extension metadata instead.
    pub convert_wkb: bool,

    /// Whether to pass `union_by_name=true` to `read_parquet`, allowing
    /// multiple files with differing schemas to be read together.
    pub union_by_name: bool,
}
43
impl Client {
    /// Creates a new client backed by an in-memory DuckDB connection.
    ///
    /// Installs and loads the `spatial` and `icu` DuckDB extensions, which
    /// the query builders below rely on (`ST_*` functions, `TIMESTAMPTZ`
    /// casts).
    ///
    /// # Errors
    ///
    /// Returns an error if the connection cannot be opened or an extension
    /// cannot be installed or loaded.
    pub fn new() -> Result<Client> {
        let connection = Connection::open_in_memory()?;
        connection.execute("INSTALL spatial", [])?;
        connection.execute("LOAD spatial", [])?;
        connection.execute("INSTALL icu", [])?;
        connection.execute("LOAD icu", [])?;
        Ok(connection.into())
    }

    /// Returns one [Extension] per row of DuckDB's `duckdb_extensions()`
    /// table function.
    ///
    /// # Errors
    ///
    /// Returns an error if the query fails or a column has an unexpected
    /// type.
    pub fn extensions(&self) -> Result<Vec<Extension>> {
        let mut statement = self.prepare(
            "SELECT extension_name, loaded, installed, install_path, description, extension_version, install_mode, installed_from FROM duckdb_extensions();",
        )?;
        let extensions = statement
            .query_map([], |row| {
                Ok(Extension {
                    name: row.get("extension_name")?,
                    loaded: row.get("loaded")?,
                    installed: row.get("installed")?,
                    install_path: row.get("install_path")?,
                    description: row.get("description")?,
                    version: row.get("extension_version")?,
                    install_mode: row.get("install_mode")?,
                    installed_from: row.get("installed_from")?,
                })
            })?
            // Short-circuit on the first row error instead of dropping rows.
            .collect::<std::result::Result<Vec<_>, duckdb::Error>>()?;
        Ok(extensions)
    }

    /// Builds one [Collection] per distinct `collection` value in the
    /// stac-geoparquet file(s) at `href`.
    ///
    /// The spatial extent is aggregated with `ST_Extent_Agg`, and the
    /// temporal extent from the min/max datetimes, preferring
    /// `start_datetime`/`end_datetime` over `datetime` when those columns
    /// exist. Descriptions are set to [DEFAULT_COLLECTION_DESCRIPTION].
    ///
    /// # Errors
    ///
    /// Returns an error if a query fails, the aggregated geometry is not
    /// valid GeoJSON, or a datetime string cannot be parsed.
    pub fn collections(&self, href: &str) -> Result<Vec<Collection>> {
        // Probe the schema (via DESCRIBE) for an optional start_datetime
        // column; coalesce onto plain datetime when it exists, otherwise use
        // datetime alone.
        let start_datetime= if self.prepare(&format!(
            "SELECT column_name FROM (DESCRIBE SELECT * from {}) where column_name = 'start_datetime'",
            self.format_parquet_href(href)
        ))?.query([])?.next()?.is_some() {
            "strftime(min(coalesce(start_datetime, datetime)), '%xT%X%z')"
        } else {
            "strftime(min(datetime), '%xT%X%z')"
        };
        // Same probe for end_datetime.
        let end_datetime = if self
            .prepare(&format!(
                "SELECT column_name FROM (DESCRIBE SELECT * from {}) where column_name = 'end_datetime'",
                self.format_parquet_href(href)
            ))?
            .query([])?
            .next()?
            .is_some()
        {
            "strftime(max(coalesce(end_datetime, datetime)), '%xT%X%z')"
        } else {
            "strftime(max(datetime), '%xT%X%z')"
        };
        let mut statement = self.prepare(&format!(
            "SELECT DISTINCT collection FROM {}",
            self.format_parquet_href(href)
        ))?;
        let mut collections = Vec::new();
        for row in statement.query_map([], |row| row.get::<_, String>(0))? {
            let collection_id = row?;
            // One aggregate query per collection: extent geometry as GeoJSON
            // plus the min/max datetime expressions chosen above.
            let mut statement = self.connection.prepare(&
            format!("SELECT ST_AsGeoJSON(ST_Extent_Agg(geometry)), {}, {} FROM {} WHERE collection = $1", start_datetime, end_datetime,
                self.format_parquet_href(href)
            ))?;
            let row = statement.query_row([&collection_id], |row| {
                Ok((
                    row.get::<_, String>(0)?,
                    row.get::<_, String>(1)?,
                    row.get::<_, String>(2)?,
                ))
            })?;
            let mut collection = Collection::new(collection_id, DEFAULT_COLLECTION_DESCRIPTION);
            // Round-trip the aggregated extent through GeoJSON into a `geo`
            // geometry so we can take its bounding rectangle.
            let geometry: geo::Geometry = Geometry::from_json_value(serde_json::from_str(&row.0)?)
                .map_err(Box::new)?
                .try_into()
                .map_err(Box::new)?;
            if let Some(bbox) = geometry.bounding_rect() {
                collection.extent.spatial = SpatialExtent {
                    bbox: vec![bbox.into()],
                };
            }
            // NOTE(review): this chrono parse format must stay in sync with
            // the DuckDB strftime format ('%xT%X%z') selected above.
            collection.extent.temporal = TemporalExtent {
                interval: vec![[
                    Some(DateTime::parse_from_str(&row.1, "%FT%T%#z")?.into()),
                    Some(DateTime::parse_from_str(&row.2, "%FT%T%#z")?.into()),
                ]],
            };
            collections.push(collection);
        }
        Ok(collections)
    }

    /// Searches the stac-geoparquet file(s) at `href`, returning a STAC API
    /// item collection.
    ///
    /// # Errors
    ///
    /// Propagates errors from [Client::search_to_arrow] and from the
    /// Arrow-to-JSON conversion.
    pub fn search(&self, href: &str, search: Search) -> Result<stac_api::ItemCollection> {
        let record_batches = self.search_to_arrow(href, search)?;
        if record_batches.is_empty() {
            Ok(Default::default())
        } else {
            // All batches come from one query, so they share the first
            // batch's schema.
            let schema = record_batches[0].schema();
            let item_collection = stac::geoarrow::json::from_record_batch_reader(
                RecordBatchIterator::new(record_batches.into_iter().map(Ok), schema),
            )?;
            Ok(item_collection.into())
        }
    }

    /// Searches the stac-geoparquet file(s) at `href`, returning the raw
    /// Arrow record batches.
    ///
    /// When [Client::convert_wkb] is true the `geometry` column is converted
    /// from WKB to native geoarrow geometry; otherwise the WKB column is
    /// kept and geoarrow metadata is added to it.
    ///
    /// Returns an empty vector when [Client::build_query] determines the
    /// search can't match anything (e.g. its filter references a column that
    /// doesn't exist).
    ///
    /// # Errors
    ///
    /// Propagates query-building, DuckDB, and geometry-conversion errors.
    pub fn search_to_arrow(&self, href: &str, search: Search) -> Result<Vec<RecordBatch>> {
        if let Some((sql, params)) = self.build_query(href, search)? {
            log::debug!("duckdb sql: {sql}");
            let mut statement = self.prepare(&sql)?;
            statement
                .query_arrow(duckdb::params_from_iter(params))?
                .map(|record_batch| {
                    let record_batch = if self.convert_wkb {
                        stac::geoarrow::with_native_geometry(record_batch, "geometry")?
                    } else {
                        stac::geoarrow::add_wkb_metadata(record_batch, "geometry")?
                    };
                    Ok(record_batch)
                })
                .collect::<Result<_>>()
        } else {
            Ok(Vec::new())
        }
    }

    /// Translates a STAC API search into a DuckDB SQL string plus its bound
    /// parameters.
    ///
    /// Returns `Ok(None)` when the search can be statically determined to
    /// match nothing, i.e. its CQL2 filter references a property that isn't
    /// a column in the parquet schema.
    ///
    /// # Errors
    ///
    /// Returns [Error::QueryNotImplemented] if the search uses the `query`
    /// extension, and propagates schema-probe, datetime-parse, and CQL2
    /// conversion errors.
    pub fn build_query(&self, href: &str, search: Search) -> Result<Option<(String, Vec<Value>)>> {
        // The (deprecated) STAC API query extension is not supported.
        if search.items.query.is_some() {
            return Err(Error::QueryNotImplemented);
        }

        // Probe the schema once to learn the column names.
        let mut statement = self.prepare(&format!(
            "SELECT column_name FROM (DESCRIBE SELECT * from {})",
            self.format_parquet_href(href)
        ))?;
        let mut has_start_datetime = false;
        let mut has_end_datetime = false;
        let mut column_names = Vec::new();
        let mut columns = Vec::new();
        for row in statement.query_map([], |row| row.get::<_, String>(0))? {
            let column = row?;
            // Track these before applying the fields filter so datetime
            // searches still work when the columns are excluded from output.
            if column == "start_datetime" {
                has_start_datetime = true;
            }
            if column == "end_datetime" {
                has_end_datetime = true;
            }

            // Apply the fields extension's include/exclude lists.
            if let Some(fields) = search.fields.as_ref() {
                if fields.exclude.contains(&column)
                    || !(fields.include.is_empty() || fields.include.contains(&column))
                {
                    continue;
                }
            }

            if column == "geometry" {
                // Serialize geometry as WKB for the Arrow output.
                columns.push("ST_AsWKB(geometry) geometry".to_string());
            } else if DATETIME_COLUMNS.contains(&column.as_str()) {
                // Normalize known datetime columns to TIMESTAMPTZ.
                columns.push(format!("\"{column}\"::TIMESTAMPTZ {column}"))
            } else {
                columns.push(format!("\"{column}\""));
            }
            column_names.push(column);
        }

        let limit = search.items.limit;
        // `offset` isn't a core Items field; it arrives via additional_fields.
        let offset = search
            .items
            .additional_fields
            .get("offset")
            .and_then(|v| v.as_i64());

        let mut order_by = Vec::with_capacity(search.sortby.len());
        for sortby in &search.sortby {
            order_by.push(format!(
                "\"{}\" {}",
                sortby.field,
                match sortby.direction {
                    Direction::Ascending => "ASC",
                    Direction::Descending => "DESC",
                }
            ));
        }

        // WHERE clauses and their positional `?` parameters, kept in the
        // same order so params bind correctly.
        let mut wheres = Vec::new();
        let mut params = Vec::new();
        if !search.ids.is_empty() {
            wheres.push(format!(
                "id IN ({})",
                (0..search.ids.len())
                    .map(|_| "?")
                    .collect::<Vec<_>>()
                    .join(",")
            ));
            params.extend(search.ids.into_iter().map(Value::Text));
        }
        if let Some(intersects) = search.intersects {
            wheres.push("ST_Intersects(geometry, ST_GeomFromGeoJSON(?))".to_string());
            params.push(Value::Text(intersects.to_string()));
        }
        if !search.collections.is_empty() {
            wheres.push(format!(
                "collection IN ({})",
                (0..search.collections.len())
                    .map(|_| "?")
                    .collect::<Vec<_>>()
                    .join(",")
            ));
            params.extend(search.collections.into_iter().map(Value::Text));
        }
        if let Some(bbox) = search.items.bbox {
            // A bbox is converted to a GeoJSON geometry and handled exactly
            // like `intersects`.
            wheres.push("ST_Intersects(geometry, ST_GeomFromGeoJSON(?))".to_string());
            params.push(Value::Text(bbox.to_geometry().to_string()));
        }
        if let Some(datetime) = search.items.datetime {
            // Open-ended intervals ("../..") yield only one bound.
            let interval = stac::datetime::parse(&datetime)?;
            if let Some(start) = interval.0 {
                wheres.push(format!(
                    "?::TIMESTAMPTZ <= {}",
                    if has_start_datetime {
                        "start_datetime"
                    } else {
                        "datetime"
                    }
                ));
                params.push(Value::Text(start.to_rfc3339()));
            }
            if let Some(end) = interval.1 {
                wheres.push(format!(
                    "?::TIMESTAMPTZ >= {}", if has_end_datetime {
                        "end_datetime"
                    } else {
                        "datetime"
                    }
                ));
                params.push(Value::Text(end.to_rfc3339()));
            }
        }
        if let Some(filter) = search.items.filter {
            let expr: Expr = filter.try_into()?;
            // Only push the filter down to DuckDB if every property it
            // references is a real column; otherwise report "matches
            // nothing" instead of hitting a binder error.
            if expr_properties_match(&expr, &column_names) {
                let sql = expr.to_ducksql().map_err(Box::new)?;
                wheres.push(sql);
            } else {
                return Ok(None);
            }
        }

        let mut suffix = String::new();
        if !wheres.is_empty() {
            suffix.push_str(&format!(" WHERE {}", wheres.join(" AND ")));
        }
        if !order_by.is_empty() {
            suffix.push_str(&format!(" ORDER BY {}", order_by.join(", ")));
        }
        if let Some(limit) = limit {
            suffix.push_str(&format!(" LIMIT {limit}"));
        }
        if let Some(offset) = offset {
            suffix.push_str(&format!(" OFFSET {offset}"));
        }

        let sql = format!(
            "SELECT {} FROM {}{}",
            columns.join(","),
            self.format_parquet_href(href),
            suffix,
        );
        Ok(Some((sql, params)))
    }

    /// Wraps `href` in a `read_parquet(...)` table function call reflecting
    /// this client's configuration.
    ///
    /// NOTE(review): `href` is interpolated unescaped inside single quotes;
    /// an href containing `'` would break (or alter) the query.
    fn format_parquet_href(&self, href: &str) -> String {
        format!(
            "read_parquet('{}', filename=true, hive_partitioning={}, union_by_name={})",
            href,
            if self.use_hive_partitioning {
                "true"
            } else {
                "false"
            },
            if self.union_by_name { "true" } else { "false" }
        )
    }
}
402
403fn expr_properties_match(expr: &Expr, properties: &[String]) -> bool {
404 use Expr::*;
405
406 match expr {
407 Property { property } => properties.contains(property),
408 Float(_) | Literal(_) | Bool(_) | Geometry(_) => true,
409 Operation { args, .. } => args
410 .iter()
411 .all(|expr| expr_properties_match(expr, properties)),
412 Interval { interval } => interval
413 .iter()
414 .all(|expr| expr_properties_match(expr, properties)),
415 Timestamp { timestamp } => expr_properties_match(timestamp, properties),
416 Date { date } => expr_properties_match(date, properties),
417 Array(exprs) => exprs
418 .iter()
419 .all(|expr| expr_properties_match(expr, properties)),
420 BBox { bbox } => bbox
421 .iter()
422 .all(|expr| expr_properties_match(expr, properties)),
423 Null => expr_properties_match(expr, properties),
424 }
425}
426
// Deref to the wrapped connection so `Connection` methods (e.g.
// `prepare`) can be called directly on a `Client`.
impl Deref for Client {
    type Target = Connection;

    fn deref(&self) -> &Self::Target {
        &self.connection
    }
}
434
// Mutable counterpart to the Deref impl above.
impl DerefMut for Client {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.connection
    }
}
440
441impl From<Connection> for Client {
442 fn from(connection: Connection) -> Self {
443 Client {
444 connection,
445 use_hive_partitioning: DEFAULT_USE_HIVE_PARTITIONING,
446 convert_wkb: DEFAULT_CONVERT_WKB,
447 union_by_name: DEFAULT_UNION_BY_NAME,
448 }
449 }
450}
451
#[cfg(test)]
mod tests {
    // All tests query the checked-in fixture file
    // data/100-sentinel-2-items.parquet (100 Sentinel-2 L2A items).
    use super::Client;
    use duckdb::Connection;
    use geo::Geometry;
    use rstest::{fixture, rstest};
    use stac::Bbox;
    use stac_api::{Items, Search, Sortby};
    use stac_validate::Validate;

    // Install the DuckDB extensions once for the whole test run so each
    // per-test `Client::new` only has to load (not download) them.
    #[fixture]
    #[once]
    fn install_extensions() {
        let connection = Connection::open_in_memory().unwrap();
        connection.execute("INSTALL icu", []).unwrap();
        connection.execute("INSTALL spatial", []).unwrap();
    }

    // Fresh client per test; depends on the one-time extension install.
    #[allow(unused_variables)]
    #[fixture]
    fn client(install_extensions: ()) -> Client {
        Client::new().unwrap()
    }

    #[rstest]
    fn extensions(client: Client) {
        let _ = client.extensions().unwrap();
    }

    #[rstest]
    #[tokio::test]
    async fn search(client: Client) {
        let item_collection = client
            .search("data/100-sentinel-2-items.parquet", Search::default())
            .unwrap();
        assert_eq!(item_collection.items.len(), 100);
        item_collection.items[0].validate().await.unwrap();
    }

    #[rstest]
    fn search_to_arrow(client: Client) {
        let record_batches = client
            .search_to_arrow("data/100-sentinel-2-items.parquet", Search::default())
            .unwrap();
        assert_eq!(record_batches.len(), 1);
    }

    #[rstest]
    fn search_ids(client: Client) {
        let item_collection = client
            .search(
                "data/100-sentinel-2-items.parquet",
                Search::default().ids(vec![
                    "S2A_MSIL2A_20240326T174951_R141_T13TDE_20240329T224429".to_string(),
                ]),
            )
            .unwrap();
        assert_eq!(item_collection.items.len(), 1);
        assert_eq!(
            item_collection.items[0]["id"],
            "S2A_MSIL2A_20240326T174951_R141_T13TDE_20240329T224429"
        );
    }

    #[rstest]
    fn search_intersects(client: Client) {
        let item_collection = client
            .search(
                "data/100-sentinel-2-items.parquet",
                Search::default().intersects(&Geometry::Point(geo::point! { x: -106., y: 40.5 })),
            )
            .unwrap();
        assert_eq!(item_collection.items.len(), 50);
    }

    #[rstest]
    fn search_collections(client: Client) {
        let item_collection = client
            .search(
                "data/100-sentinel-2-items.parquet",
                Search::default().collections(vec!["sentinel-2-l2a".to_string()]),
            )
            .unwrap();
        assert_eq!(item_collection.items.len(), 100);

        // A collection id that doesn't exist should match nothing.
        let item_collection = client
            .search(
                "data/100-sentinel-2-items.parquet",
                Search::default().collections(vec!["foobar".to_string()]),
            )
            .unwrap();
        assert_eq!(item_collection.items.len(), 0);
    }

    #[rstest]
    fn search_bbox(client: Client) {
        let item_collection = client
            .search(
                "data/100-sentinel-2-items.parquet",
                Search::default().bbox(Bbox::new(-106.1, 40.5, -106.0, 40.6)),
            )
            .unwrap();
        assert_eq!(item_collection.items.len(), 50);
    }

    #[rstest]
    fn search_datetime(client: Client) {
        // Open-ended "from" interval.
        let item_collection = client
            .search(
                "data/100-sentinel-2-items.parquet",
                Search::default().datetime("2024-12-02T00:00:00Z/.."),
            )
            .unwrap();
        assert_eq!(item_collection.items.len(), 1);
        // Open-ended "until" interval.
        let item_collection = client
            .search(
                "data/100-sentinel-2-items.parquet",
                Search::default().datetime("../2024-12-02T00:00:00Z"),
            )
            .unwrap();
        assert_eq!(item_collection.items.len(), 99);
    }

    #[rstest]
    fn search_datetime_empty_interval(client: Client) {
        // A trailing-slash interval with no end should behave like "/..".
        let item_collection = client
            .search(
                "data/100-sentinel-2-items.parquet",
                Search::default().datetime("2024-12-02T00:00:00Z/"),
            )
            .unwrap();
        assert_eq!(item_collection.items.len(), 1);
    }

    #[rstest]
    fn search_limit(client: Client) {
        let item_collection = client
            .search(
                "data/100-sentinel-2-items.parquet",
                Search::default().limit(42),
            )
            .unwrap();
        assert_eq!(item_collection.items.len(), 42);
    }

    #[rstest]
    fn search_offset(client: Client) {
        // `offset` rides along in additional_fields, not a core Items field.
        let mut search = Search::default().limit(1);
        search
            .items
            .additional_fields
            .insert("offset".to_string(), 1.into());
        let item_collection = client
            .search("data/100-sentinel-2-items.parquet", search)
            .unwrap();
        assert_eq!(
            item_collection.items[0]["id"],
            "S2A_MSIL2A_20241201T175721_R141_T13TDE_20241201T213150"
        );
    }

    #[rstest]
    fn search_sortby(client: Client) {
        // Ascending sort: earliest item first.
        let item_collection = client
            .search(
                "data/100-sentinel-2-items.parquet",
                Search::default()
                    .sortby(vec![Sortby::asc("datetime")])
                    .limit(1),
            )
            .unwrap();
        assert_eq!(
            item_collection.items[0]["id"],
            "S2A_MSIL2A_20240326T174951_R141_T13TDE_20240329T224429"
        );

        // Descending sort: latest item first.
        let item_collection = client
            .search(
                "data/100-sentinel-2-items.parquet",
                Search::default()
                    .sortby(vec![Sortby::desc("datetime")])
                    .limit(1),
            )
            .unwrap();
        assert_eq!(
            item_collection.items[0]["id"],
            "S2B_MSIL2A_20241203T174629_R098_T13TDE_20241203T211406"
        );
    }

    #[rstest]
    fn search_fields(client: Client) {
        // Including only `id` should yield single-field items.
        let item_collection = client
            .search(
                "data/100-sentinel-2-items.parquet",
                Search::default().fields("+id".parse().unwrap()).limit(1),
            )
            .unwrap();
        assert_eq!(item_collection.items[0].len(), 1);
    }

    #[rstest]
    fn collections(client: Client) {
        let collections = client
            .collections("data/100-sentinel-2-items.parquet")
            .unwrap();
        assert_eq!(collections.len(), 1);
    }

    #[rstest]
    fn no_convert_wkb(mut client: Client) {
        // With conversion disabled, the geometry column keeps WKB but gains
        // the geoarrow extension metadata.
        client.convert_wkb = false;
        let record_batches = client
            .search_to_arrow("data/100-sentinel-2-items.parquet", Search::default())
            .unwrap();
        let schema = record_batches[0].schema();
        assert_eq!(
            schema.field_with_name("geometry").unwrap().metadata()["ARROW:extension:name"],
            "geoarrow.wkb"
        );
    }

    #[rstest]
    fn filter(client: Client) {
        let search = Search {
            items: Items {
                filter: Some("sat:relative_orbit = 98".parse().unwrap()),
                ..Default::default()
            },
            ..Default::default()
        };
        let item_collection = client
            .search("data/100-sentinel-2-items.parquet", search)
            .unwrap();
        assert_eq!(item_collection.items.len(), 49);
    }

    #[rstest]
    fn filter_no_column(client: Client) {
        // A filter on a nonexistent column short-circuits to zero results
        // rather than erroring (see build_query / expr_properties_match).
        let search = Search {
            items: Items {
                filter: Some("foo:bar = 42".parse().unwrap()),
                ..Default::default()
            },
            ..Default::default()
        };
        let item_collection = client
            .search("data/100-sentinel-2-items.parquet", search)
            .unwrap();
        assert_eq!(item_collection.items.len(), 0);
    }

    #[rstest]
    fn sortby_property(client: Client) {
        let search = Search {
            items: Items {
                sortby: vec!["eo:cloud_cover".parse().unwrap()],
                ..Default::default()
            },
            ..Default::default()
        };
        let item_collection = client
            .search("data/100-sentinel-2-items.parquet", search)
            .unwrap();
        assert_eq!(item_collection.items.len(), 100);
    }

    #[rstest]
    fn union_by_name(client: Client) {
        // Globbing multiple files with differing schemas works when
        // union_by_name is on (the default).
        let _ = client.search("data/*.parquet", Default::default()).unwrap();
    }

    #[rstest]
    fn no_union_by_name(mut client: Client) {
        // ...and fails when it's off.
        client.union_by_name = false;
        let _ = client
            .search("data/*.parquet", Default::default())
            .unwrap_err();
    }
}