stac_duckdb/
client.rs

1use crate::{Error, Extension, Result};
2use arrow_array::{RecordBatch, RecordBatchIterator};
3use chrono::DateTime;
4use cql2::{Expr, ToDuckSQL};
5use duckdb::{Connection, types::Value};
6use geo::BoundingRect;
7use geojson::Geometry;
8use stac::{Collection, SpatialExtent, TemporalExtent, geoarrow::DATETIME_COLUMNS};
9use stac_api::{Direction, Search};
10use std::ops::{Deref, DerefMut};
11
/// Default value for [`Client::use_hive_partitioning`].
pub const DEFAULT_USE_HIVE_PARTITIONING: bool = false;

/// Default value for [`Client::convert_wkb`].
pub const DEFAULT_CONVERT_WKB: bool = true;

/// The description assigned to every collection built by [`Client::collections`].
pub const DEFAULT_COLLECTION_DESCRIPTION: &str =
    "Auto-generated collection from stac-geoparquet extents";

/// Default value for [`Client::union_by_name`].
pub const DEFAULT_UNION_BY_NAME: bool = true;
24
/// A client for making DuckDB requests for STAC objects.
///
/// The client wraps a [`duckdb::Connection`] (and dereferences to it) together
/// with the options that control how `read_parquet` is invoked and how
/// geometries are returned.
#[derive(Debug)]
pub struct Client {
    // The wrapped DuckDB connection that every query is run against.
    connection: Connection,

    /// Whether to pass `hive_partitioning=true` to `read_parquet`.
    ///
    /// Defaults to [`DEFAULT_USE_HIVE_PARTITIONING`].
    pub use_hive_partitioning: bool,

    /// Whether to convert WKB to native geometries in search results.
    ///
    /// If false, WKB metadata is added to the geometry column instead.
    /// Defaults to [`DEFAULT_CONVERT_WKB`].
    pub convert_wkb: bool,

    /// Whether to pass `union_by_name=true` to `read_parquet` when querying.
    ///
    /// Defaults to true (see [`DEFAULT_UNION_BY_NAME`]).
    pub union_by_name: bool,
}
43
44impl Client {
45    /// Creates a new client with an in-memory DuckDB connection.
46    ///
47    /// This function will install the spatial extension. If you'd like to
48    /// manage your own extensions (e.g. if your extensions are stored in a
49    /// different location), set things up then use `connection.into()` to get a
50    /// new `Client`.
51    ///
52    /// # Examples
53    ///
54    /// ```
55    /// use stac_duckdb::Client;
56    ///
57    /// let client = Client::new().unwrap();
58    /// ```
59    pub fn new() -> Result<Client> {
60        let connection = Connection::open_in_memory()?;
61        connection.execute("INSTALL spatial", [])?;
62        connection.execute("LOAD spatial", [])?;
63        connection.execute("INSTALL icu", [])?;
64        connection.execute("LOAD icu", [])?;
65        Ok(connection.into())
66    }
67
68    /// Returns a vector of all extensions.
69    ///
70    /// # Examples
71    ///
72    /// ```
73    /// use stac_duckdb::Client;
74    ///
75    /// let client = Client::new().unwrap();
76    /// let extensions = client.extensions().unwrap();
77    /// ```
78    pub fn extensions(&self) -> Result<Vec<Extension>> {
79        let mut statement = self.prepare(
80            "SELECT extension_name, loaded, installed, install_path, description, extension_version, install_mode, installed_from FROM duckdb_extensions();",
81        )?;
82        let extensions = statement
83            .query_map([], |row| {
84                Ok(Extension {
85                    name: row.get("extension_name")?,
86                    loaded: row.get("loaded")?,
87                    installed: row.get("installed")?,
88                    install_path: row.get("install_path")?,
89                    description: row.get("description")?,
90                    version: row.get("extension_version")?,
91                    install_mode: row.get("install_mode")?,
92                    installed_from: row.get("installed_from")?,
93                })
94            })?
95            .collect::<std::result::Result<Vec<_>, duckdb::Error>>()?;
96        Ok(extensions)
97    }
98
99    /// Returns one or more [stac::Collection] from the items in the stac-geoparquet file.
100    ///
101    /// # Examples
102    ///
103    /// ```
104    /// use stac_duckdb::Client;
105    ///
106    /// let client = Client::new().unwrap();
107    /// let collections = client.collections("data/100-sentinel-2-items.parquet").unwrap();
108    /// ```
109    pub fn collections(&self, href: &str) -> Result<Vec<Collection>> {
110        let start_datetime= if self.prepare(&format!(
111            "SELECT column_name FROM (DESCRIBE SELECT * from {}) where column_name = 'start_datetime'",
112            self.format_parquet_href(href)
113        ))?.query([])?.next()?.is_some() {
114            "strftime(min(coalesce(start_datetime, datetime)), '%xT%X%z')"
115        } else {
116            "strftime(min(datetime), '%xT%X%z')"
117        };
118        let end_datetime = if self
119            .prepare(&format!(
120            "SELECT column_name FROM (DESCRIBE SELECT * from {}) where column_name = 'end_datetime'",
121            self.format_parquet_href(href)
122        ))?
123            .query([])?
124            .next()?
125            .is_some()
126        {
127            "strftime(max(coalesce(end_datetime, datetime)), '%xT%X%z')"
128        } else {
129            "strftime(max(datetime), '%xT%X%z')"
130        };
131        let mut statement = self.prepare(&format!(
132            "SELECT DISTINCT collection FROM {}",
133            self.format_parquet_href(href)
134        ))?;
135        let mut collections = Vec::new();
136        for row in statement.query_map([], |row| row.get::<_, String>(0))? {
137            let collection_id = row?;
138            let mut statement = self.connection.prepare(&
139                format!("SELECT ST_AsGeoJSON(ST_Extent_Agg(geometry)), {}, {} FROM {} WHERE collection = $1", start_datetime, end_datetime,
140                self.format_parquet_href(href)
141            ))?;
142            let row = statement.query_row([&collection_id], |row| {
143                Ok((
144                    row.get::<_, String>(0)?,
145                    row.get::<_, String>(1)?,
146                    row.get::<_, String>(2)?,
147                ))
148            })?;
149            let mut collection = Collection::new(collection_id, DEFAULT_COLLECTION_DESCRIPTION);
150            let geometry: geo::Geometry = Geometry::from_json_value(serde_json::from_str(&row.0)?)
151                .map_err(Box::new)?
152                .try_into()
153                .map_err(Box::new)?;
154            if let Some(bbox) = geometry.bounding_rect() {
155                collection.extent.spatial = SpatialExtent {
156                    bbox: vec![bbox.into()],
157                };
158            }
159            collection.extent.temporal = TemporalExtent {
160                interval: vec![[
161                    Some(DateTime::parse_from_str(&row.1, "%FT%T%#z")?.into()),
162                    Some(DateTime::parse_from_str(&row.2, "%FT%T%#z")?.into()),
163                ]],
164            };
165            collections.push(collection);
166        }
167        Ok(collections)
168    }
169
170    /// Searches a single stac-geoparquet file.
171    ///
172    /// # Examples
173    ///
174    /// ```
175    /// use stac_duckdb::Client;
176    ///
177    /// let client = Client::new().unwrap();
178    /// let item_collection = client.search("data/100-sentinel-2-items.parquet", Default::default()).unwrap();
179    /// ```
180    pub fn search(&self, href: &str, search: Search) -> Result<stac_api::ItemCollection> {
181        let record_batches = self.search_to_arrow(href, search)?;
182        if record_batches.is_empty() {
183            Ok(Default::default())
184        } else {
185            let schema = record_batches[0].schema();
186            let item_collection = stac::geoarrow::json::from_record_batch_reader(
187                RecordBatchIterator::new(record_batches.into_iter().map(Ok), schema),
188            )?;
189            Ok(item_collection.into())
190        }
191    }
192
193    /// Searches to an iterator of record batches.
194    ///
195    /// # Examples
196    ///
197    /// ```
198    /// use stac_duckdb::Client;
199    ///
200    /// let client = Client::new().unwrap();
201    /// let record_batches = client.search_to_arrow("data/100-sentinel-2-items.parquet", Default::default()).unwrap();
202    /// ```
203    pub fn search_to_arrow(&self, href: &str, search: Search) -> Result<Vec<RecordBatch>> {
204        // TODO can we return an iterator?
205
206        // Note that we pull out some fields early so we can avoid closing some search strings below.
207
208        if search.items.query.is_some() {
209            return Err(Error::QueryNotImplemented);
210        }
211
212        // Check which columns we'll be selecting
213        let mut statement = self.prepare(&format!(
214            "SELECT column_name FROM (DESCRIBE SELECT * from {})",
215            self.format_parquet_href(href)
216        ))?;
217        let mut has_start_datetime = false;
218        let mut has_end_datetime = false;
219        let mut column_names = Vec::new();
220        let mut columns = Vec::new();
221        for row in statement.query_map([], |row| row.get::<_, String>(0))? {
222            let column = row?;
223            if column == "start_datetime" {
224                has_start_datetime = true;
225            }
226            if column == "end_datetime" {
227                has_end_datetime = true;
228            }
229
230            if let Some(fields) = search.fields.as_ref() {
231                if fields.exclude.contains(&column)
232                    || !(fields.include.is_empty() || fields.include.contains(&column))
233                {
234                    continue;
235                }
236            }
237
238            if column == "geometry" {
239                columns.push("ST_AsWKB(geometry) geometry".to_string());
240            } else if DATETIME_COLUMNS.contains(&column.as_str()) {
241                columns.push(format!("\"{column}\"::TIMESTAMPTZ {column}"))
242            } else {
243                columns.push(format!("\"{column}\""));
244            }
245            column_names.push(column);
246        }
247
248        // Get limit and offset
249        let limit = search.items.limit;
250        let offset = search
251            .items
252            .additional_fields
253            .get("offset")
254            .and_then(|v| v.as_i64());
255
256        // Build order_by
257        let mut order_by = Vec::with_capacity(search.sortby.len());
258        for sortby in &search.sortby {
259            order_by.push(format!(
260                "\"{}\" {}",
261                sortby.field,
262                match sortby.direction {
263                    Direction::Ascending => "ASC",
264                    Direction::Descending => "DESC",
265                }
266            ));
267        }
268
269        // Build wheres and params
270        let mut wheres = Vec::new();
271        let mut params = Vec::new();
272        if !search.ids.is_empty() {
273            wheres.push(format!(
274                "id IN ({})",
275                (0..search.ids.len())
276                    .map(|_| "?")
277                    .collect::<Vec<_>>()
278                    .join(",")
279            ));
280            params.extend(search.ids.into_iter().map(Value::Text));
281        }
282        if let Some(intersects) = search.intersects {
283            wheres.push("ST_Intersects(geometry, ST_GeomFromGeoJSON(?))".to_string());
284            params.push(Value::Text(intersects.to_string()));
285        }
286        if !search.collections.is_empty() {
287            wheres.push(format!(
288                "collection IN ({})",
289                (0..search.collections.len())
290                    .map(|_| "?")
291                    .collect::<Vec<_>>()
292                    .join(",")
293            ));
294            params.extend(search.collections.into_iter().map(Value::Text));
295        }
296        if let Some(bbox) = search.items.bbox {
297            wheres.push("ST_Intersects(geometry, ST_GeomFromGeoJSON(?))".to_string());
298            params.push(Value::Text(bbox.to_geometry().to_string()));
299        }
300        if let Some(datetime) = search.items.datetime {
301            let interval = stac::datetime::parse(&datetime)?;
302            if let Some(start) = interval.0 {
303                wheres.push(format!(
304                    "?::TIMESTAMPTZ <= {}",
305                    if has_start_datetime {
306                        "start_datetime"
307                    } else {
308                        "datetime"
309                    }
310                ));
311                params.push(Value::Text(start.to_rfc3339()));
312            }
313            if let Some(end) = interval.1 {
314                wheres.push(format!(
315                    "?::TIMESTAMPTZ >= {}", // Inclusive, https://github.com/radiantearth/stac-spec/pull/1280
316                    if has_end_datetime {
317                        "end_datetime"
318                    } else {
319                        "datetime"
320                    }
321                ));
322                params.push(Value::Text(end.to_rfc3339()));
323            }
324        }
325        if let Some(filter) = search.items.filter {
326            let expr: Expr = filter.try_into()?;
327            if expr_properties_match(&expr, &column_names) {
328                let sql = expr.to_ducksql().map_err(Box::new)?;
329                wheres.push(sql);
330            } else {
331                return Ok(Vec::new());
332            }
333        }
334
335        let mut suffix = String::new();
336        if !wheres.is_empty() {
337            suffix.push_str(&format!(" WHERE {}", wheres.join(" AND ")));
338        }
339        if !order_by.is_empty() {
340            suffix.push_str(&format!(" ORDER BY {}", order_by.join(", ")));
341        }
342        if let Some(limit) = limit {
343            suffix.push_str(&format!(" LIMIT {limit}"));
344        }
345        if let Some(offset) = offset {
346            suffix.push_str(&format!(" OFFSET {offset}"));
347        }
348
349        let sql = format!(
350            "SELECT {} FROM {}{}",
351            columns.join(","),
352            self.format_parquet_href(href),
353            suffix,
354        );
355        log::debug!("duckdb sql: {sql}");
356        let mut statement = self.prepare(&sql)?;
357        statement
358            .query_arrow(duckdb::params_from_iter(params))?
359            .map(|record_batch| {
360                let record_batch = if self.convert_wkb {
361                    stac::geoarrow::with_native_geometry(record_batch, "geometry")?
362                } else {
363                    stac::geoarrow::add_wkb_metadata(record_batch, "geometry")?
364                };
365                Ok(record_batch)
366            })
367            .collect::<Result<_>>()
368    }
369
370    fn format_parquet_href(&self, href: &str) -> String {
371        format!(
372            "read_parquet('{}', filename=true, hive_partitioning={}, union_by_name={})",
373            href,
374            if self.use_hive_partitioning {
375                "true"
376            } else {
377                "false"
378            },
379            if self.union_by_name { "true" } else { "false" }
380        )
381    }
382}
383
384fn expr_properties_match(expr: &Expr, properties: &[String]) -> bool {
385    use Expr::*;
386
387    match expr {
388        Property { property } => properties.contains(property),
389        Float(_) | Literal(_) | Bool(_) | Geometry(_) => true,
390        Operation { args, .. } => args
391            .iter()
392            .all(|expr| expr_properties_match(expr, properties)),
393        Interval { interval } => interval
394            .iter()
395            .all(|expr| expr_properties_match(expr, properties)),
396        Timestamp { timestamp } => expr_properties_match(timestamp, properties),
397        Date { date } => expr_properties_match(date, properties),
398        Array(exprs) => exprs
399            .iter()
400            .all(|expr| expr_properties_match(expr, properties)),
401        BBox { bbox } => bbox
402            .iter()
403            .all(|expr| expr_properties_match(expr, properties)),
404    }
405}
406
// Lets a `Client` be used wherever a shared reference to the wrapped
// `Connection` is needed (e.g. `client.prepare(...)`).
impl Deref for Client {
    type Target = Connection;

    fn deref(&self) -> &Self::Target {
        &self.connection
    }
}
414
// Gives mutable access to the wrapped `Connection` through the client.
impl DerefMut for Client {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.connection
    }
}
420
// Wraps an existing connection without installing or loading any extension;
// all options start at their `DEFAULT_*` values.
impl From<Connection> for Client {
    fn from(connection: Connection) -> Self {
        Client {
            connection,
            use_hive_partitioning: DEFAULT_USE_HIVE_PARTITIONING,
            convert_wkb: DEFAULT_CONVERT_WKB,
            union_by_name: DEFAULT_UNION_BY_NAME,
        }
    }
}
431
#[cfg(test)]
mod tests {
    use super::Client;
    use duckdb::Connection;
    use geo::Geometry;
    use rstest::{fixture, rstest};
    use stac::Bbox;
    use stac_api::{Items, Search, Sortby};
    use stac_validate::Validate;

    /// Installs the extensions that `Client::new` needs exactly once, so
    /// tests running in parallel do not race on the installation.
    #[fixture]
    #[once]
    fn install_spatial() {
        let connection = Connection::open_in_memory().unwrap();
        connection.execute("INSTALL spatial", []).unwrap();
        // `Client::new` installs icu as well, so pre-install it here too.
        connection.execute("INSTALL icu", []).unwrap();
    }

    /// A fresh client, created after the extensions have been installed.
    #[allow(unused_variables)]
    #[fixture]
    fn client(install_spatial: ()) -> Client {
        Client::new().unwrap()
    }

    #[allow(unused_variables)]
    #[rstest]
    fn new(install_spatial: ()) {
        Client::new().unwrap();
    }

    #[rstest]
    fn extensions(client: Client) {
        let _ = client.extensions().unwrap();
    }

    #[rstest]
    fn search(client: Client) {
        let item_collection = client
            .search("data/100-sentinel-2-items.parquet", Search::default())
            .unwrap();
        assert_eq!(item_collection.items.len(), 100);
        item_collection.items[0].validate().unwrap();
    }

    #[rstest]
    fn search_to_arrow(client: Client) {
        let record_batches = client
            .search_to_arrow("data/100-sentinel-2-items.parquet", Search::default())
            .unwrap();
        assert_eq!(record_batches.len(), 1);
    }

    #[rstest]
    fn search_ids(client: Client) {
        let item_collection = client
            .search(
                "data/100-sentinel-2-items.parquet",
                Search::default().ids(vec![
                    "S2A_MSIL2A_20240326T174951_R141_T13TDE_20240329T224429".to_string(),
                ]),
            )
            .unwrap();
        assert_eq!(item_collection.items.len(), 1);
        assert_eq!(
            item_collection.items[0]["id"],
            "S2A_MSIL2A_20240326T174951_R141_T13TDE_20240329T224429"
        );
    }

    #[rstest]
    fn search_intersects(client: Client) {
        let item_collection = client
            .search(
                "data/100-sentinel-2-items.parquet",
                Search::default().intersects(&Geometry::Point(geo::point! { x: -106., y: 40.5 })),
            )
            .unwrap();
        assert_eq!(item_collection.items.len(), 50);
    }

    #[rstest]
    fn search_collections(client: Client) {
        let item_collection = client
            .search(
                "data/100-sentinel-2-items.parquet",
                Search::default().collections(vec!["sentinel-2-l2a".to_string()]),
            )
            .unwrap();
        assert_eq!(item_collection.items.len(), 100);

        let item_collection = client
            .search(
                "data/100-sentinel-2-items.parquet",
                Search::default().collections(vec!["foobar".to_string()]),
            )
            .unwrap();
        assert_eq!(item_collection.items.len(), 0);
    }

    #[rstest]
    fn search_bbox(client: Client) {
        let item_collection = client
            .search(
                "data/100-sentinel-2-items.parquet",
                Search::default().bbox(Bbox::new(-106.1, 40.5, -106.0, 40.6)),
            )
            .unwrap();
        assert_eq!(item_collection.items.len(), 50);
    }

    #[rstest]
    fn search_datetime(client: Client) {
        let item_collection = client
            .search(
                "data/100-sentinel-2-items.parquet",
                Search::default().datetime("2024-12-02T00:00:00Z/.."),
            )
            .unwrap();
        assert_eq!(item_collection.items.len(), 1);
        let item_collection = client
            .search(
                "data/100-sentinel-2-items.parquet",
                Search::default().datetime("../2024-12-02T00:00:00Z"),
            )
            .unwrap();
        assert_eq!(item_collection.items.len(), 99);
    }

    #[rstest]
    fn search_datetime_empty_interval(client: Client) {
        let item_collection = client
            .search(
                "data/100-sentinel-2-items.parquet",
                Search::default().datetime("2024-12-02T00:00:00Z/"),
            )
            .unwrap();
        assert_eq!(item_collection.items.len(), 1);
    }

    #[rstest]
    fn search_limit(client: Client) {
        let item_collection = client
            .search(
                "data/100-sentinel-2-items.parquet",
                Search::default().limit(42),
            )
            .unwrap();
        assert_eq!(item_collection.items.len(), 42);
    }

    #[rstest]
    fn search_offset(client: Client) {
        let mut search = Search::default().limit(1);
        search
            .items
            .additional_fields
            .insert("offset".to_string(), 1.into());
        let item_collection = client
            .search("data/100-sentinel-2-items.parquet", search)
            .unwrap();
        assert_eq!(
            item_collection.items[0]["id"],
            "S2A_MSIL2A_20241201T175721_R141_T13TDE_20241201T213150"
        );
    }

    #[rstest]
    fn search_sortby(client: Client) {
        let item_collection = client
            .search(
                "data/100-sentinel-2-items.parquet",
                Search::default()
                    .sortby(vec![Sortby::asc("datetime")])
                    .limit(1),
            )
            .unwrap();
        assert_eq!(
            item_collection.items[0]["id"],
            "S2A_MSIL2A_20240326T174951_R141_T13TDE_20240329T224429"
        );

        let item_collection = client
            .search(
                "data/100-sentinel-2-items.parquet",
                Search::default()
                    .sortby(vec![Sortby::desc("datetime")])
                    .limit(1),
            )
            .unwrap();
        assert_eq!(
            item_collection.items[0]["id"],
            "S2B_MSIL2A_20241203T174629_R098_T13TDE_20241203T211406"
        );
    }

    #[rstest]
    fn search_fields(client: Client) {
        let item_collection = client
            .search(
                "data/100-sentinel-2-items.parquet",
                Search::default().fields("+id".parse().unwrap()).limit(1),
            )
            .unwrap();
        assert_eq!(item_collection.items[0].len(), 1);
    }

    #[rstest]
    fn collections(client: Client) {
        let collections = client
            .collections("data/100-sentinel-2-items.parquet")
            .unwrap();
        assert_eq!(collections.len(), 1);
    }

    #[rstest]
    fn no_convert_wkb(mut client: Client) {
        client.convert_wkb = false;
        let record_batches = client
            .search_to_arrow("data/100-sentinel-2-items.parquet", Search::default())
            .unwrap();
        let schema = record_batches[0].schema();
        assert_eq!(
            schema.field_with_name("geometry").unwrap().metadata()["ARROW:extension:name"],
            "geoarrow.wkb"
        );
    }

    #[rstest]
    fn filter(client: Client) {
        let search = Search {
            items: Items {
                filter: Some("sat:relative_orbit = 98".parse().unwrap()),
                ..Default::default()
            },
            ..Default::default()
        };
        let item_collection = client
            .search("data/100-sentinel-2-items.parquet", search)
            .unwrap();
        assert_eq!(item_collection.items.len(), 49);
    }

    #[rstest]
    fn filter_no_column(client: Client) {
        let search = Search {
            items: Items {
                filter: Some("foo:bar = 42".parse().unwrap()),
                ..Default::default()
            },
            ..Default::default()
        };
        let item_collection = client
            .search("data/100-sentinel-2-items.parquet", search)
            .unwrap();
        assert_eq!(item_collection.items.len(), 0);
    }

    #[rstest]
    fn sortby_property(client: Client) {
        let search = Search {
            items: Items {
                sortby: vec!["eo:cloud_cover".parse().unwrap()],
                ..Default::default()
            },
            ..Default::default()
        };
        let item_collection = client
            .search("data/100-sentinel-2-items.parquet", search)
            .unwrap();
        assert_eq!(item_collection.items.len(), 100);
    }

    #[rstest]
    fn union_by_name(client: Client) {
        let _ = client.search("data/*.parquet", Default::default()).unwrap();
    }

    #[rstest]
    fn no_union_by_name(mut client: Client) {
        client.union_by_name = false;
        let _ = client
            .search("data/*.parquet", Default::default())
            .unwrap_err();
    }
}
713            .unwrap_err();
714    }
715}