zarr_datafusion/reader/
filter.rs

1//! Filter pushdown support for Zarr queries
2//!
3//! This module parses DataFusion filter expressions to extract coordinate
4//! equality filters (e.g., `time = 1323647`), which can be used to read
5//! only the relevant subset of Zarr arrays.
6//!
7//! For a Zarr store with coordinates [time, hybrid, lat, lon], a filter like
8//! `time = X AND hybrid = Y` allows us to read only the slice of data where
9//! those coordinates match, dramatically reducing memory usage.
10
11use datafusion::common::ScalarValue;
12use datafusion::logical_expr::Expr;
13use std::collections::HashMap;
14use tracing::{debug, info, trace, warn};
15
16/// Represents a parsed coordinate filter
17///
18/// For a filter like `time = 1323647`, this stores:
19/// - coord_name: "time"
20/// - value: ScalarValue::Int64(1323647)
21#[derive(Debug, Clone)]
22pub struct CoordFilter {
23    /// Name of the coordinate column
24    pub coord_name: String,
25    /// Value to match (must be equality filter)
26    pub value: ScalarValue,
27}
28
29/// Collection of coordinate filters extracted from a WHERE clause
30#[derive(Debug, Clone, Default)]
31pub struct CoordFilters {
32    /// Map from coordinate name to filter value
33    pub filters: HashMap<String, ScalarValue>,
34}
35
36impl CoordFilters {
37    pub fn new() -> Self {
38        Self {
39            filters: HashMap::new(),
40        }
41    }
42
43    /// Check if any filters were extracted
44    pub fn is_empty(&self) -> bool {
45        self.filters.is_empty()
46    }
47
48    /// Get the filter value for a coordinate, if any
49    pub fn get(&self, coord_name: &str) -> Option<&ScalarValue> {
50        self.filters.get(coord_name)
51    }
52
53    /// Number of coordinate filters
54    pub fn len(&self) -> usize {
55        self.filters.len()
56    }
57}
58
59/// Parse DataFusion filter expressions to extract coordinate equality filters
60///
61/// Only extracts simple equality filters of the form:
62/// - `coord = value` (Column = Literal)
63/// - `value = coord` (Literal = Column)
64///
65/// Combined with AND:
66/// - `coord1 = value1 AND coord2 = value2`
67///
68/// Other filter types (OR, >, <, LIKE, etc.) are ignored and left for
69/// DataFusion to handle post-scan.
70pub fn parse_coord_filters(filters: &[Expr], coord_names: &[String]) -> CoordFilters {
71    let mut result = CoordFilters::new();
72
73    for filter in filters {
74        extract_equality_filters(filter, coord_names, &mut result);
75    }
76
77    if !result.is_empty() {
78        info!(
79            num_filters = result.len(),
80            filters = ?result.filters.keys().collect::<Vec<_>>(),
81            "Extracted coordinate filters for pushdown"
82        );
83    } else {
84        debug!("No coordinate equality filters found for pushdown");
85    }
86
87    result
88}
89
90/// Recursively extract equality filters from an expression
91fn extract_equality_filters(expr: &Expr, coord_names: &[String], result: &mut CoordFilters) {
92    match expr {
93        // Handle AND: recurse into both sides
94        Expr::BinaryExpr(binary) if binary.op == datafusion::logical_expr::Operator::And => {
95            extract_equality_filters(&binary.left, coord_names, result);
96            extract_equality_filters(&binary.right, coord_names, result);
97        }
98
99        // Handle equality: Column = Literal or Literal = Column
100        Expr::BinaryExpr(binary) if binary.op == datafusion::logical_expr::Operator::Eq => {
101            if let Some((col_name, value)) = extract_column_literal_eq(&binary.left, &binary.right)
102            {
103                if coord_names.contains(&col_name) {
104                    debug!(
105                        coord = %col_name,
106                        value = %value,
107                        "Found coordinate equality filter"
108                    );
109                    result.filters.insert(col_name, value);
110                } else {
111                    trace!(
112                        column = %col_name,
113                        "Equality filter on non-coordinate column, skipping"
114                    );
115                }
116            }
117        }
118
119        // Handle CAST expressions that wrap the filter
120        Expr::Cast(cast) => {
121            extract_equality_filters(&cast.expr, coord_names, result);
122        }
123
124        // Other expressions: OR, >, <, etc. - skip for now
125        other => {
126            trace!(expr_type = %other.variant_name(), "Skipping non-equality filter expression");
127        }
128    }
129}
130
131/// Extract column name and literal value from an equality expression
132///
133/// Returns Some((column_name, value)) for patterns like:
134/// - Column = Literal
135/// - Literal = Column
136/// - Cast(Column) = Literal
137fn extract_column_literal_eq(left: &Expr, right: &Expr) -> Option<(String, ScalarValue)> {
138    // Try Column = Literal
139    if let (Some(col_name), Some(value)) = (extract_column_name(left), extract_literal(right)) {
140        return Some((col_name, value));
141    }
142
143    // Try Literal = Column
144    if let (Some(value), Some(col_name)) = (extract_literal(left), extract_column_name(right)) {
145        return Some((col_name, value));
146    }
147
148    None
149}
150
151/// Extract column name from expression, handling Cast wrappers
152fn extract_column_name(expr: &Expr) -> Option<String> {
153    match expr {
154        Expr::Column(col) => Some(col.name.clone()),
155        Expr::Cast(cast) => extract_column_name(&cast.expr),
156        Expr::TryCast(cast) => extract_column_name(&cast.expr),
157        _ => None,
158    }
159}
160
161/// Extract literal value from expression
162///
163/// Unwraps Dictionary scalar values to get the underlying value,
164/// since coordinate filters compare against raw values, not dictionary indices.
165fn extract_literal(expr: &Expr) -> Option<ScalarValue> {
166    match expr {
167        Expr::Literal(value, _) => Some(unwrap_dictionary_value(value.clone())),
168        Expr::Cast(cast) => {
169            // Handle cast of literal
170            if let Expr::Literal(value, _) = cast.expr.as_ref() {
171                // Try to cast the value to the target type
172                value
173                    .cast_to(&cast.data_type)
174                    .ok()
175                    .map(unwrap_dictionary_value)
176            } else {
177                None
178            }
179        }
180        _ => None,
181    }
182}
183
184/// Unwrap Dictionary scalar values to get the inner value
185///
186/// DataFusion wraps literal values in Dictionary type when comparing against
187/// Dictionary columns. We need the raw value for coordinate lookup.
188fn unwrap_dictionary_value(value: ScalarValue) -> ScalarValue {
189    match value {
190        ScalarValue::Dictionary(_, inner) => unwrap_dictionary_value(*inner),
191        other => other,
192    }
193}
194
195/// Calculate which indices to read from each coordinate based on filters
196///
197/// For each coordinate:
198/// - If filtered (e.g., `time = X`), find the index of X in the coordinate values
199/// - If not filtered, read all values
200///
201/// Returns a map from coordinate name to (start_idx, end_idx) range.
202/// If a filter value is not found in the coordinate, returns None (no matches).
203pub fn calculate_coord_ranges(
204    filters: &CoordFilters,
205    coord_names: &[String],
206    coord_values: &[CoordValuesRef<'_>],
207) -> Option<Vec<(usize, usize)>> {
208    let mut ranges = Vec::with_capacity(coord_names.len());
209
210    for (i, name) in coord_names.iter().enumerate() {
211        let values = &coord_values[i];
212        let range = if let Some(filter_value) = filters.get(name) {
213            // Find the index of the filter value in this coordinate
214            if let Some(idx) = find_value_index(values, filter_value) {
215                debug!(
216                    coord = %name,
217                    filter_value = %filter_value,
218                    index = idx,
219                    "Found filter value at index"
220                );
221                (idx, idx + 1) // Single value range
222            } else {
223                warn!(
224                    coord = %name,
225                    filter_value = %filter_value,
226                    "Filter value not found in coordinate - query will return no results"
227                );
228                return None; // No matches possible
229            }
230        } else {
231            // No filter on this coordinate - read all values
232            (0, values.len())
233        };
234        ranges.push(range);
235    }
236
237    Some(ranges)
238}
239
240/// Reference to coordinate values for searching
241pub enum CoordValuesRef<'a> {
242    Int64(&'a [i64]),
243    Float32(&'a [f32]),
244    Float64(&'a [f64]),
245}
246
247impl<'a> CoordValuesRef<'a> {
248    pub fn len(&self) -> usize {
249        match self {
250            CoordValuesRef::Int64(v) => v.len(),
251            CoordValuesRef::Float32(v) => v.len(),
252            CoordValuesRef::Float64(v) => v.len(),
253        }
254    }
255
256    pub fn is_empty(&self) -> bool {
257        self.len() == 0
258    }
259}
260
261/// Find the index of a scalar value in coordinate values
262fn find_value_index(values: &CoordValuesRef<'_>, target: &ScalarValue) -> Option<usize> {
263    match (values, target) {
264        (CoordValuesRef::Int64(vals), ScalarValue::Int64(Some(v))) => {
265            vals.iter().position(|x| x == v)
266        }
267        (CoordValuesRef::Int64(vals), ScalarValue::Int32(Some(v))) => {
268            let v64 = *v as i64;
269            vals.iter().position(|x| *x == v64)
270        }
271        (CoordValuesRef::Float32(vals), ScalarValue::Float32(Some(v))) => {
272            vals.iter().position(|x| (x - v).abs() < f32::EPSILON)
273        }
274        (CoordValuesRef::Float32(vals), ScalarValue::Float64(Some(v))) => {
275            let v32 = *v as f32;
276            vals.iter().position(|x| (x - v32).abs() < f32::EPSILON)
277        }
278        (CoordValuesRef::Float64(vals), ScalarValue::Float64(Some(v))) => {
279            vals.iter().position(|x| (x - v).abs() < f64::EPSILON)
280        }
281        (CoordValuesRef::Float64(vals), ScalarValue::Float32(Some(v))) => {
282            let v64 = *v as f64;
283            vals.iter().position(|x| (x - v64).abs() < f64::EPSILON)
284        }
285        // Handle integer to float comparisons
286        (CoordValuesRef::Float32(vals), ScalarValue::Int64(Some(v))) => {
287            let vf = *v as f32;
288            vals.iter().position(|x| (x - vf).abs() < f32::EPSILON)
289        }
290        (CoordValuesRef::Float64(vals), ScalarValue::Int64(Some(v))) => {
291            let vf = *v as f64;
292            vals.iter().position(|x| (x - vf).abs() < f64::EPSILON)
293        }
294        _ => {
295            debug!(
296                target_type = ?std::mem::discriminant(target),
297                "Unsupported filter value type for coordinate lookup"
298            );
299            None
300        }
301    }
302}
303
304/// Calculate the total number of rows after applying coordinate filters
305pub fn calculate_filtered_rows(coord_ranges: &[(usize, usize)]) -> usize {
306    coord_ranges
307        .iter()
308        .map(|(start, end)| end - start)
309        .product()
310}
311
312/// Calculate Zarr array subset ranges from coordinate filter ranges
313///
314/// Converts coordinate ranges to ArraySubset ranges for reading
315/// a specific slice of an nD Zarr array.
316pub fn coord_ranges_to_array_ranges(coord_ranges: &[(usize, usize)]) -> Vec<std::ops::Range<u64>> {
317    coord_ranges
318        .iter()
319        .map(|(start, end)| (*start as u64)..(*end as u64))
320        .collect()
321}
322
323#[cfg(test)]
324mod tests {
325    use super::*;
326    use datafusion::prelude::*;
327
328    #[test]
329    fn test_parse_simple_equality() {
330        let coord_names = vec!["time".to_string(), "lat".to_string()];
331
332        // time = 100
333        let filter = col("time").eq(lit(100i64));
334        let filters = parse_coord_filters(&[filter], &coord_names);
335
336        assert_eq!(filters.len(), 1);
337        assert!(filters.get("time").is_some());
338    }
339
340    #[test]
341    fn test_parse_and_filters() {
342        let coord_names = vec!["time".to_string(), "hybrid".to_string(), "lat".to_string()];
343
344        // time = 100 AND hybrid = 50
345        let filter = col("time")
346            .eq(lit(100i64))
347            .and(col("hybrid").eq(lit(50i64)));
348        let filters = parse_coord_filters(&[filter], &coord_names);
349
350        assert_eq!(filters.len(), 2);
351        assert!(filters.get("time").is_some());
352        assert!(filters.get("hybrid").is_some());
353    }
354
355    #[test]
356    fn test_ignore_non_coord_columns() {
357        let coord_names = vec!["time".to_string()];
358
359        // temperature = 20 (not a coordinate)
360        let filter = col("temperature").eq(lit(20i64));
361        let filters = parse_coord_filters(&[filter], &coord_names);
362
363        assert!(filters.is_empty());
364    }
365
366    #[test]
367    fn test_find_value_index() {
368        let vals = vec![10i64, 20, 30, 40, 50];
369        let values_ref = CoordValuesRef::Int64(&vals);
370
371        assert_eq!(
372            find_value_index(&values_ref, &ScalarValue::Int64(Some(30))),
373            Some(2)
374        );
375        assert_eq!(
376            find_value_index(&values_ref, &ScalarValue::Int64(Some(100))),
377            None
378        );
379    }
380
381    #[test]
382    fn test_calculate_filtered_rows() {
383        // time: 1 value, hybrid: 1 value, lat: 721, lon: 1440
384        let ranges = vec![(5, 6), (10, 11), (0, 721), (0, 1440)];
385        let rows = calculate_filtered_rows(&ranges);
386        assert_eq!(rows, 1 * 1 * 721 * 1440);
387    }
388}