// zarr_datafusion/datasource/zarr.rs

1use arrow::datatypes::SchemaRef;
2use async_trait::async_trait;
3use datafusion::catalog::Session;
4use datafusion::common::stats::{ColumnStatistics, Precision, Statistics};
5use datafusion::logical_expr::{Expr, TableProviderFilterPushDown};
6use datafusion::{datasource::TableProvider, error::Result, physical_plan::ExecutionPlan};
7use std::sync::Arc;
8use tracing::{debug, info};
9use zarrs::storage::AsyncReadableListableStorage;
10use zarrs_object_store::object_store::path::Path as ObjectPath;
11
12use crate::physical_plan::zarr_exec::ZarrExec;
13use crate::reader::filter::parse_coord_filters;
14use crate::reader::schema_inference::ZarrStoreMeta;
15
/// Cached remote store info (store, prefix, metadata)
///
/// `None` for local paths. `Some` when the table was built via
/// [`ZarrTable::with_cached_remote`], so each scan can reuse the already-open
/// async store and its prefix instead of reconnecting per query.
pub type CachedRemoteStore = Option<(AsyncReadableListableStorage, ObjectPath, ZarrStoreMeta)>;
18
/// A DataFusion table backed by a Zarr store.
///
/// Holds the Arrow schema plus optional cached store handles and metadata so
/// repeated queries avoid re-opening the store or re-reading its metadata.
pub struct ZarrTable {
    // Arrow schema of the table (coordinate columns + data variables).
    schema: SchemaRef,
    // Path or URL of the Zarr store; forwarded to `ZarrExec` at scan time.
    path: String,
    /// Cached async store and metadata for remote URLs (avoids recreating on each query)
    cached_remote: CachedRemoteStore,
    /// Store metadata for statistics (used for count optimization)
    store_meta: Option<ZarrStoreMeta>,
}
27
28impl std::fmt::Debug for ZarrTable {
29    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30        f.debug_struct("ZarrTable")
31            .field("schema", &self.schema)
32            .field("path", &self.path)
33            .field(
34                "cached_remote",
35                &self.cached_remote.as_ref().map(|(_, p, _)| p),
36            )
37            .field(
38                "total_rows",
39                &self.store_meta.as_ref().map(|m| m.total_rows),
40            )
41            .finish()
42    }
43}
44
45impl ZarrTable {
46    pub fn new(schema: SchemaRef, path: impl Into<String>) -> Self {
47        Self {
48            schema,
49            path: path.into(),
50            cached_remote: None,
51            store_meta: None,
52        }
53    }
54
55    /// Create a ZarrTable with store metadata (for local paths)
56    pub fn with_metadata(
57        schema: SchemaRef,
58        path: impl Into<String>,
59        metadata: ZarrStoreMeta,
60    ) -> Self {
61        Self {
62            schema,
63            path: path.into(),
64            cached_remote: None,
65            store_meta: Some(metadata),
66        }
67    }
68
69    /// Create a ZarrTable with a cached async store and metadata (for remote URLs)
70    pub fn with_cached_remote(
71        schema: SchemaRef,
72        path: impl Into<String>,
73        store: AsyncReadableListableStorage,
74        prefix: ObjectPath,
75        metadata: ZarrStoreMeta,
76    ) -> Self {
77        Self {
78            schema,
79            path: path.into(),
80            cached_remote: Some((store, prefix, metadata.clone())),
81            store_meta: Some(metadata),
82        }
83    }
84}
85
86#[async_trait]
87impl TableProvider for ZarrTable {
88    fn as_any(&self) -> &dyn std::any::Any {
89        self
90    }
91
92    fn schema(&self) -> SchemaRef {
93        self.schema.clone()
94    }
95
96    fn table_type(&self) -> datafusion::datasource::TableType {
97        datafusion::datasource::TableType::Base
98    }
99
100    /// Indicate which filters can be pushed down to the scan
101    ///
102    /// Returns `Inexact` for all filters - we'll handle coordinate equality
103    /// filters during scan, but DataFusion should still apply filters post-scan
104    /// for correctness (in case we miss any).
105    fn supports_filters_pushdown(
106        &self,
107        filters: &[&Expr],
108    ) -> Result<Vec<TableProviderFilterPushDown>> {
109        Ok(filters
110            .iter()
111            .map(|_| TableProviderFilterPushDown::Inexact)
112            .collect())
113    }
114
115    async fn scan(
116        &self,
117        _state: &dyn Session,
118        projection: Option<&Vec<usize>>,
119        filters: &[datafusion::logical_expr::Expr],
120        limit: Option<usize>,
121    ) -> Result<Arc<dyn ExecutionPlan>> {
122        // Log projection pushdown
123        let total_columns = self.schema.fields().len();
124        if let Some(indices) = projection {
125            let projected_names: Vec<_> = indices
126                .iter()
127                .map(|&i| self.schema.field(i).name().as_str())
128                .collect();
129            info!(
130                projected = indices.len(),
131                total = total_columns,
132                columns = ?projected_names,
133                "Projection pushdown"
134            );
135        } else {
136            info!(
137                projected = total_columns,
138                total = total_columns,
139                "No projection pushdown (all columns)"
140            );
141        }
142
143        // Log limit pushdown
144        if let Some(limit) = limit {
145            info!(limit, "Limit pushdown");
146        }
147
148        // Parse coordinate filters for filter pushdown
149        debug!(
150            num_filters = filters.len(),
151            filters = ?filters,
152            "Filters passed to scan()"
153        );
154        let coord_filters = if let Some(meta) = &self.store_meta {
155            let coord_names: Vec<String> = meta.coords.iter().map(|c| c.name.clone()).collect();
156            debug!(?coord_names, "Coordinate names from metadata");
157            let parsed = parse_coord_filters(filters, &coord_names);
158            if !parsed.is_empty() {
159                info!(
160                    num_filters = parsed.len(),
161                    coords = ?parsed.filters.keys().collect::<Vec<_>>(),
162                    "Filter pushdown"
163                );
164                Some(parsed)
165            } else {
166                None
167            }
168        } else {
169            // No metadata available - can't do filter pushdown
170            None
171        };
172
173        Ok(Arc::new(ZarrExec::new(
174            self.schema.clone(),
175            self.path.clone(),
176            projection.cloned(),
177            limit,
178            self.cached_remote.clone(),
179            coord_filters,
180        )))
181    }
182
183    /// Return statistics for this table
184    ///
185    /// This enables DataFusion's optimizer to convert count(*) and count(column)
186    /// queries into constant values without scanning the data.
187    ///
188    /// For coordinate columns, we also provide:
189    /// - min_value/max_value: Enables MIN(coord)/MAX(coord) optimization
190    /// - distinct_count: Number of unique coordinate values
191    fn statistics(&self) -> Option<Statistics> {
192        let meta = self.store_meta.as_ref()?;
193
194        // Build column statistics
195        let column_statistics: Vec<ColumnStatistics> = self
196            .schema
197            .fields()
198            .iter()
199            .map(|field| {
200                let field_name = field.name();
201
202                // Check if this is a coordinate column with min/max
203                if let Some(coord) = meta.coords.iter().find(|c| &c.name == field_name) {
204                    if let Some((min, max)) = coord.coord_min_max {
205                        // Coordinates have distinct_count = shape[0] (number of unique values)
206                        let distinct_count = coord.shape[0] as usize;
207
208                        // Convert min/max to ScalarValue based on the underlying type
209                        // Dictionary types have a value type inside
210                        let (min_value, max_value) = match field.data_type() {
211                            arrow::datatypes::DataType::Dictionary(_, value_type) => {
212                                scalar_values_from_f64(min, max, value_type.as_ref())
213                            }
214                            dt => scalar_values_from_f64(min, max, dt),
215                        };
216
217                        info!(
218                            coord = %field_name,
219                            min = %min_value,
220                            max = %max_value,
221                            distinct = distinct_count,
222                            "Coordinate statistics"
223                        );
224
225                        return ColumnStatistics {
226                            null_count: Precision::Exact(0),
227                            min_value: Precision::Exact(min_value),
228                            max_value: Precision::Exact(max_value),
229                            distinct_count: Precision::Exact(distinct_count),
230                            ..Default::default()
231                        };
232                    }
233                }
234
235                // Default: only null_count for data variables
236                ColumnStatistics {
237                    null_count: Precision::Exact(0),
238                    ..Default::default()
239                }
240            })
241            .collect();
242
243        info!(
244            total_rows = meta.total_rows,
245            num_columns = column_statistics.len(),
246            "Providing statistics for query optimization"
247        );
248
249        Some(Statistics {
250            num_rows: Precision::Exact(meta.total_rows),
251            total_byte_size: Precision::Absent,
252            column_statistics,
253        })
254    }
255}
256
257/// Convert f64 min/max values to appropriate ScalarValue based on Arrow data type
258fn scalar_values_from_f64(
259    min: f64,
260    max: f64,
261    data_type: &arrow::datatypes::DataType,
262) -> (
263    datafusion::common::ScalarValue,
264    datafusion::common::ScalarValue,
265) {
266    use arrow::datatypes::DataType;
267    use datafusion::common::ScalarValue;
268
269    match data_type {
270        DataType::Float64 => (
271            ScalarValue::Float64(Some(min)),
272            ScalarValue::Float64(Some(max)),
273        ),
274        DataType::Float32 => (
275            ScalarValue::Float32(Some(min as f32)),
276            ScalarValue::Float32(Some(max as f32)),
277        ),
278        DataType::Int64 => (
279            ScalarValue::Int64(Some(min as i64)),
280            ScalarValue::Int64(Some(max as i64)),
281        ),
282        DataType::Int32 => (
283            ScalarValue::Int32(Some(min as i32)),
284            ScalarValue::Int32(Some(max as i32)),
285        ),
286        DataType::Int16 => (
287            ScalarValue::Int16(Some(min as i16)),
288            ScalarValue::Int16(Some(max as i16)),
289        ),
290        DataType::UInt64 => (
291            ScalarValue::UInt64(Some(min as u64)),
292            ScalarValue::UInt64(Some(max as u64)),
293        ),
294        DataType::UInt32 => (
295            ScalarValue::UInt32(Some(min as u32)),
296            ScalarValue::UInt32(Some(max as u32)),
297        ),
298        // Fallback to Float64
299        _ => (
300            ScalarValue::Float64(Some(min)),
301            ScalarValue::Float64(Some(max)),
302        ),
303    }
304}