1use crate::{Error, Extension, Result};
2use arrow_array::{RecordBatch, RecordBatchIterator};
3use chrono::DateTime;
4use cql2::{Expr, ToDuckSQL};
5use duckdb::{Connection, types::Value};
6use geo::BoundingRect;
7use geojson::Geometry;
8use stac::{Collection, SpatialExtent, TemporalExtent, geoarrow::DATETIME_COLUMNS};
9use stac_api::{Direction, Search};
10use std::ops::{Deref, DerefMut};
11
/// Default value for [`Client::use_hive_partitioning`].
pub const DEFAULT_USE_HIVE_PARTITIONING: bool = false;

/// Default value for [`Client::convert_wkb`].
pub const DEFAULT_CONVERT_WKB: bool = true;

/// Description assigned to every collection built by [`Client::collections`].
pub const DEFAULT_COLLECTION_DESCRIPTION: &str =
    "Auto-generated collection from stac-geoparquet extents";

/// Default value for [`Client::union_by_name`].
pub const DEFAULT_UNION_BY_NAME: bool = true;
24
/// A client for querying stac-geoparquet data through a DuckDB connection.
///
/// The client dereferences to the wrapped [`Connection`], so connection
/// methods can be called on it directly.
#[derive(Debug)]
pub struct Client {
    /// The DuckDB connection every query is executed on.
    connection: Connection,

    /// Value passed as `hive_partitioning` to DuckDB's `read_parquet`.
    pub use_hive_partitioning: bool,

    /// When `true`, search results have their `geometry` column passed through
    /// `stac::geoarrow::with_native_geometry`; when `false`, the column is
    /// passed through `stac::geoarrow::add_wkb_metadata` instead.
    pub convert_wkb: bool,

    /// Value passed as `union_by_name` to DuckDB's `read_parquet`.
    pub union_by_name: bool,
}
43
44impl Client {
45 pub fn new() -> Result<Client> {
60 let connection = Connection::open_in_memory()?;
61 connection.execute("INSTALL spatial", [])?;
62 connection.execute("LOAD spatial", [])?;
63 connection.execute("INSTALL icu", [])?;
64 connection.execute("LOAD icu", [])?;
65 Ok(connection.into())
66 }
67
68 pub fn extensions(&self) -> Result<Vec<Extension>> {
79 let mut statement = self.prepare(
80 "SELECT extension_name, loaded, installed, install_path, description, extension_version, install_mode, installed_from FROM duckdb_extensions();",
81 )?;
82 let extensions = statement
83 .query_map([], |row| {
84 Ok(Extension {
85 name: row.get("extension_name")?,
86 loaded: row.get("loaded")?,
87 installed: row.get("installed")?,
88 install_path: row.get("install_path")?,
89 description: row.get("description")?,
90 version: row.get("extension_version")?,
91 install_mode: row.get("install_mode")?,
92 installed_from: row.get("installed_from")?,
93 })
94 })?
95 .collect::<std::result::Result<Vec<_>, duckdb::Error>>()?;
96 Ok(extensions)
97 }
98
99 pub fn collections(&self, href: &str) -> Result<Vec<Collection>> {
110 let start_datetime= if self.prepare(&format!(
111 "SELECT column_name FROM (DESCRIBE SELECT * from {}) where column_name = 'start_datetime'",
112 self.format_parquet_href(href)
113 ))?.query([])?.next()?.is_some() {
114 "strftime(min(coalesce(start_datetime, datetime)), '%xT%X%z')"
115 } else {
116 "strftime(min(datetime), '%xT%X%z')"
117 };
118 let end_datetime = if self
119 .prepare(&format!(
120 "SELECT column_name FROM (DESCRIBE SELECT * from {}) where column_name = 'end_datetime'",
121 self.format_parquet_href(href)
122 ))?
123 .query([])?
124 .next()?
125 .is_some()
126 {
127 "strftime(max(coalesce(end_datetime, datetime)), '%xT%X%z')"
128 } else {
129 "strftime(max(datetime), '%xT%X%z')"
130 };
131 let mut statement = self.prepare(&format!(
132 "SELECT DISTINCT collection FROM {}",
133 self.format_parquet_href(href)
134 ))?;
135 let mut collections = Vec::new();
136 for row in statement.query_map([], |row| row.get::<_, String>(0))? {
137 let collection_id = row?;
138 let mut statement = self.connection.prepare(&
139 format!("SELECT ST_AsGeoJSON(ST_Extent_Agg(geometry)), {}, {} FROM {} WHERE collection = $1", start_datetime, end_datetime,
140 self.format_parquet_href(href)
141 ))?;
142 let row = statement.query_row([&collection_id], |row| {
143 Ok((
144 row.get::<_, String>(0)?,
145 row.get::<_, String>(1)?,
146 row.get::<_, String>(2)?,
147 ))
148 })?;
149 let mut collection = Collection::new(collection_id, DEFAULT_COLLECTION_DESCRIPTION);
150 let geometry: geo::Geometry = Geometry::from_json_value(serde_json::from_str(&row.0)?)
151 .map_err(Box::new)?
152 .try_into()
153 .map_err(Box::new)?;
154 if let Some(bbox) = geometry.bounding_rect() {
155 collection.extent.spatial = SpatialExtent {
156 bbox: vec![bbox.into()],
157 };
158 }
159 collection.extent.temporal = TemporalExtent {
160 interval: vec![[
161 Some(DateTime::parse_from_str(&row.1, "%FT%T%#z")?.into()),
162 Some(DateTime::parse_from_str(&row.2, "%FT%T%#z")?.into()),
163 ]],
164 };
165 collections.push(collection);
166 }
167 Ok(collections)
168 }
169
170 pub fn search(&self, href: &str, search: Search) -> Result<stac_api::ItemCollection> {
181 let record_batches = self.search_to_arrow(href, search)?;
182 if record_batches.is_empty() {
183 Ok(Default::default())
184 } else {
185 let schema = record_batches[0].schema();
186 let item_collection = stac::geoarrow::json::from_record_batch_reader(
187 RecordBatchIterator::new(record_batches.into_iter().map(Ok), schema),
188 )?;
189 Ok(item_collection.into())
190 }
191 }
192
193 pub fn search_to_arrow(&self, href: &str, search: Search) -> Result<Vec<RecordBatch>> {
204 if search.items.query.is_some() {
209 return Err(Error::QueryNotImplemented);
210 }
211
212 let mut statement = self.prepare(&format!(
214 "SELECT column_name FROM (DESCRIBE SELECT * from {})",
215 self.format_parquet_href(href)
216 ))?;
217 let mut has_start_datetime = false;
218 let mut has_end_datetime = false;
219 let mut column_names = Vec::new();
220 let mut columns = Vec::new();
221 for row in statement.query_map([], |row| row.get::<_, String>(0))? {
222 let column = row?;
223 if column == "start_datetime" {
224 has_start_datetime = true;
225 }
226 if column == "end_datetime" {
227 has_end_datetime = true;
228 }
229
230 if let Some(fields) = search.fields.as_ref() {
231 if fields.exclude.contains(&column)
232 || !(fields.include.is_empty() || fields.include.contains(&column))
233 {
234 continue;
235 }
236 }
237
238 if column == "geometry" {
239 columns.push("ST_AsWKB(geometry) geometry".to_string());
240 } else if DATETIME_COLUMNS.contains(&column.as_str()) {
241 columns.push(format!("\"{column}\"::TIMESTAMPTZ {column}"))
242 } else {
243 columns.push(format!("\"{column}\""));
244 }
245 column_names.push(column);
246 }
247
248 let limit = search.items.limit;
250 let offset = search
251 .items
252 .additional_fields
253 .get("offset")
254 .and_then(|v| v.as_i64());
255
256 let mut order_by = Vec::with_capacity(search.sortby.len());
258 for sortby in &search.sortby {
259 order_by.push(format!(
260 "\"{}\" {}",
261 sortby.field,
262 match sortby.direction {
263 Direction::Ascending => "ASC",
264 Direction::Descending => "DESC",
265 }
266 ));
267 }
268
269 let mut wheres = Vec::new();
271 let mut params = Vec::new();
272 if !search.ids.is_empty() {
273 wheres.push(format!(
274 "id IN ({})",
275 (0..search.ids.len())
276 .map(|_| "?")
277 .collect::<Vec<_>>()
278 .join(",")
279 ));
280 params.extend(search.ids.into_iter().map(Value::Text));
281 }
282 if let Some(intersects) = search.intersects {
283 wheres.push("ST_Intersects(geometry, ST_GeomFromGeoJSON(?))".to_string());
284 params.push(Value::Text(intersects.to_string()));
285 }
286 if !search.collections.is_empty() {
287 wheres.push(format!(
288 "collection IN ({})",
289 (0..search.collections.len())
290 .map(|_| "?")
291 .collect::<Vec<_>>()
292 .join(",")
293 ));
294 params.extend(search.collections.into_iter().map(Value::Text));
295 }
296 if let Some(bbox) = search.items.bbox {
297 wheres.push("ST_Intersects(geometry, ST_GeomFromGeoJSON(?))".to_string());
298 params.push(Value::Text(bbox.to_geometry().to_string()));
299 }
300 if let Some(datetime) = search.items.datetime {
301 let interval = stac::datetime::parse(&datetime)?;
302 if let Some(start) = interval.0 {
303 wheres.push(format!(
304 "?::TIMESTAMPTZ <= {}",
305 if has_start_datetime {
306 "start_datetime"
307 } else {
308 "datetime"
309 }
310 ));
311 params.push(Value::Text(start.to_rfc3339()));
312 }
313 if let Some(end) = interval.1 {
314 wheres.push(format!(
315 "?::TIMESTAMPTZ >= {}", if has_end_datetime {
317 "end_datetime"
318 } else {
319 "datetime"
320 }
321 ));
322 params.push(Value::Text(end.to_rfc3339()));
323 }
324 }
325 if let Some(filter) = search.items.filter {
326 let expr: Expr = filter.try_into()?;
327 if expr_properties_match(&expr, &column_names) {
328 let sql = expr.to_ducksql().map_err(Box::new)?;
329 wheres.push(sql);
330 } else {
331 return Ok(Vec::new());
332 }
333 }
334
335 let mut suffix = String::new();
336 if !wheres.is_empty() {
337 suffix.push_str(&format!(" WHERE {}", wheres.join(" AND ")));
338 }
339 if !order_by.is_empty() {
340 suffix.push_str(&format!(" ORDER BY {}", order_by.join(", ")));
341 }
342 if let Some(limit) = limit {
343 suffix.push_str(&format!(" LIMIT {limit}"));
344 }
345 if let Some(offset) = offset {
346 suffix.push_str(&format!(" OFFSET {offset}"));
347 }
348
349 let sql = format!(
350 "SELECT {} FROM {}{}",
351 columns.join(","),
352 self.format_parquet_href(href),
353 suffix,
354 );
355 log::debug!("duckdb sql: {sql}");
356 let mut statement = self.prepare(&sql)?;
357 statement
358 .query_arrow(duckdb::params_from_iter(params))?
359 .map(|record_batch| {
360 let record_batch = if self.convert_wkb {
361 stac::geoarrow::with_native_geometry(record_batch, "geometry")?
362 } else {
363 stac::geoarrow::add_wkb_metadata(record_batch, "geometry")?
364 };
365 Ok(record_batch)
366 })
367 .collect::<Result<_>>()
368 }
369
370 fn format_parquet_href(&self, href: &str) -> String {
371 format!(
372 "read_parquet('{}', filename=true, hive_partitioning={}, union_by_name={})",
373 href,
374 if self.use_hive_partitioning {
375 "true"
376 } else {
377 "false"
378 },
379 if self.union_by_name { "true" } else { "false" }
380 )
381 }
382}
383
384fn expr_properties_match(expr: &Expr, properties: &[String]) -> bool {
385 use Expr::*;
386
387 match expr {
388 Property { property } => properties.contains(property),
389 Float(_) | Literal(_) | Bool(_) | Geometry(_) => true,
390 Operation { args, .. } => args
391 .iter()
392 .all(|expr| expr_properties_match(expr, properties)),
393 Interval { interval } => interval
394 .iter()
395 .all(|expr| expr_properties_match(expr, properties)),
396 Timestamp { timestamp } => expr_properties_match(timestamp, properties),
397 Date { date } => expr_properties_match(date, properties),
398 Array(exprs) => exprs
399 .iter()
400 .all(|expr| expr_properties_match(expr, properties)),
401 BBox { bbox } => bbox
402 .iter()
403 .all(|expr| expr_properties_match(expr, properties)),
404 }
405}
406
407impl Deref for Client {
408 type Target = Connection;
409
410 fn deref(&self) -> &Self::Target {
411 &self.connection
412 }
413}
414
415impl DerefMut for Client {
416 fn deref_mut(&mut self) -> &mut Self::Target {
417 &mut self.connection
418 }
419}
420
421impl From<Connection> for Client {
422 fn from(connection: Connection) -> Self {
423 Client {
424 connection,
425 use_hive_partitioning: DEFAULT_USE_HIVE_PARTITIONING,
426 convert_wkb: DEFAULT_CONVERT_WKB,
427 union_by_name: DEFAULT_UNION_BY_NAME,
428 }
429 }
430}
431
#[cfg(test)]
mod tests {
    use super::Client;
    use duckdb::Connection;
    use geo::Geometry;
    use rstest::{fixture, rstest};
    use stac::Bbox;
    use stac_api::{Items, Search, Sortby};
    use stac_validate::Validate;

    /// Path of the parquet file most tests search.
    const SENTINEL_2_ITEMS: &str = "data/100-sentinel-2-items.parquet";

    /// Runs `search` against [`SENTINEL_2_ITEMS`] and returns how many items matched.
    fn count_items(client: &Client, search: Search) -> usize {
        client.search(SENTINEL_2_ITEMS, search).unwrap().items.len()
    }

    /// Builds a search that only carries the given CQL2 text filter.
    fn filter_search(filter: &str) -> Search {
        Search {
            items: Items {
                filter: Some(filter.parse().unwrap()),
                ..Default::default()
            },
            ..Default::default()
        }
    }

    /// Installs the spatial extension a single time for the whole test run.
    #[fixture]
    #[once]
    fn install_spatial() {
        let connection = Connection::open_in_memory().unwrap();
        connection.execute("INSTALL spatial", []).unwrap();
    }

    /// A fresh client, created after the spatial extension has been installed.
    #[allow(unused_variables)]
    #[fixture]
    fn client(install_spatial: ()) -> Client {
        Client::new().unwrap()
    }

    #[allow(unused_variables)]
    #[rstest]
    fn new(install_spatial: ()) {
        Client::new().unwrap();
    }

    #[rstest]
    fn extensions(client: Client) {
        let _ = client.extensions().unwrap();
    }

    #[rstest]
    fn search(client: Client) {
        let found = client
            .search(SENTINEL_2_ITEMS, Search::default())
            .unwrap();
        assert_eq!(found.items.len(), 100);
        found.items[0].validate().unwrap();
    }

    #[rstest]
    fn search_to_arrow(client: Client) {
        let batches = client
            .search_to_arrow(SENTINEL_2_ITEMS, Search::default())
            .unwrap();
        assert_eq!(batches.len(), 1);
    }

    #[rstest]
    fn search_ids(client: Client) {
        let ids = vec!["S2A_MSIL2A_20240326T174951_R141_T13TDE_20240329T224429".to_string()];
        let found = client
            .search(SENTINEL_2_ITEMS, Search::default().ids(ids))
            .unwrap();
        assert_eq!(found.items.len(), 1);
        assert_eq!(
            found.items[0]["id"],
            "S2A_MSIL2A_20240326T174951_R141_T13TDE_20240329T224429"
        );
    }

    #[rstest]
    fn search_intersects(client: Client) {
        let point = Geometry::Point(geo::point! { x: -106., y: 40.5 });
        let search = Search::default().intersects(&point);
        assert_eq!(count_items(&client, search), 50);
    }

    #[rstest]
    fn search_collections(client: Client) {
        let known = Search::default().collections(vec!["sentinel-2-l2a".to_string()]);
        assert_eq!(count_items(&client, known), 100);

        let unknown = Search::default().collections(vec!["foobar".to_string()]);
        assert_eq!(count_items(&client, unknown), 0);
    }

    #[rstest]
    fn search_bbox(client: Client) {
        let search = Search::default().bbox(Bbox::new(-106.1, 40.5, -106.0, 40.6));
        assert_eq!(count_items(&client, search), 50);
    }

    #[rstest]
    fn search_datetime(client: Client) {
        let after = Search::default().datetime("2024-12-02T00:00:00Z/..");
        assert_eq!(count_items(&client, after), 1);

        let before = Search::default().datetime("../2024-12-02T00:00:00Z");
        assert_eq!(count_items(&client, before), 99);
    }

    #[rstest]
    fn search_datetime_empty_interval(client: Client) {
        let search = Search::default().datetime("2024-12-02T00:00:00Z/");
        assert_eq!(count_items(&client, search), 1);
    }

    #[rstest]
    fn search_limit(client: Client) {
        assert_eq!(count_items(&client, Search::default().limit(42)), 42);
    }

    #[rstest]
    fn search_offset(client: Client) {
        let mut search = Search::default().limit(1);
        search
            .items
            .additional_fields
            .insert("offset".to_string(), 1.into());
        let found = client.search(SENTINEL_2_ITEMS, search).unwrap();
        assert_eq!(
            found.items[0]["id"],
            "S2A_MSIL2A_20241201T175721_R141_T13TDE_20241201T213150"
        );
    }

    #[rstest]
    fn search_sortby(client: Client) {
        let ascending = Search::default()
            .sortby(vec![Sortby::asc("datetime")])
            .limit(1);
        let found = client.search(SENTINEL_2_ITEMS, ascending).unwrap();
        assert_eq!(
            found.items[0]["id"],
            "S2A_MSIL2A_20240326T174951_R141_T13TDE_20240329T224429"
        );

        let descending = Search::default()
            .sortby(vec![Sortby::desc("datetime")])
            .limit(1);
        let found = client.search(SENTINEL_2_ITEMS, descending).unwrap();
        assert_eq!(
            found.items[0]["id"],
            "S2B_MSIL2A_20241203T174629_R098_T13TDE_20241203T211406"
        );
    }

    #[rstest]
    fn search_fields(client: Client) {
        let search = Search::default().fields("+id".parse().unwrap()).limit(1);
        let found = client.search(SENTINEL_2_ITEMS, search).unwrap();
        assert_eq!(found.items[0].len(), 1);
    }

    #[rstest]
    fn collections(client: Client) {
        let found = client.collections(SENTINEL_2_ITEMS).unwrap();
        assert_eq!(found.len(), 1);
    }

    #[rstest]
    fn no_convert_wkb(mut client: Client) {
        client.convert_wkb = false;
        let batches = client
            .search_to_arrow(SENTINEL_2_ITEMS, Search::default())
            .unwrap();
        let schema = batches[0].schema();
        assert_eq!(
            schema.field_with_name("geometry").unwrap().metadata()["ARROW:extension:name"],
            "geoarrow.wkb"
        );
    }

    #[rstest]
    fn filter(client: Client) {
        let search = filter_search("sat:relative_orbit = 98");
        assert_eq!(count_items(&client, search), 49);
    }

    #[rstest]
    fn filter_no_column(client: Client) {
        let search = filter_search("foo:bar = 42");
        assert_eq!(count_items(&client, search), 0);
    }

    #[rstest]
    fn sortby_property(client: Client) {
        let search = Search {
            items: Items {
                sortby: vec!["eo:cloud_cover".parse().unwrap()],
                ..Default::default()
            },
            ..Default::default()
        };
        assert_eq!(count_items(&client, search), 100);
    }

    #[rstest]
    fn union_by_name(client: Client) {
        let _ = client.search("data/*.parquet", Default::default()).unwrap();
    }

    #[rstest]
    fn no_union_by_name(mut client: Client) {
        client.union_by_name = false;
        let _ = client
            .search("data/*.parquet", Default::default())
            .unwrap_err();
    }
}