1use datafusion::common::ScalarValue;
12use datafusion::logical_expr::Expr;
13use std::collections::HashMap;
14use tracing::{debug, info, trace, warn};
15
16#[derive(Debug, Clone)]
22pub struct CoordFilter {
23 pub coord_name: String,
25 pub value: ScalarValue,
27}
28
29#[derive(Debug, Clone, Default)]
31pub struct CoordFilters {
32 pub filters: HashMap<String, ScalarValue>,
34}
35
36impl CoordFilters {
37 pub fn new() -> Self {
38 Self {
39 filters: HashMap::new(),
40 }
41 }
42
43 pub fn is_empty(&self) -> bool {
45 self.filters.is_empty()
46 }
47
48 pub fn get(&self, coord_name: &str) -> Option<&ScalarValue> {
50 self.filters.get(coord_name)
51 }
52
53 pub fn len(&self) -> usize {
55 self.filters.len()
56 }
57}
58
59pub fn parse_coord_filters(filters: &[Expr], coord_names: &[String]) -> CoordFilters {
71 let mut result = CoordFilters::new();
72
73 for filter in filters {
74 extract_equality_filters(filter, coord_names, &mut result);
75 }
76
77 if !result.is_empty() {
78 info!(
79 num_filters = result.len(),
80 filters = ?result.filters.keys().collect::<Vec<_>>(),
81 "Extracted coordinate filters for pushdown"
82 );
83 } else {
84 debug!("No coordinate equality filters found for pushdown");
85 }
86
87 result
88}
89
90fn extract_equality_filters(expr: &Expr, coord_names: &[String], result: &mut CoordFilters) {
92 match expr {
93 Expr::BinaryExpr(binary) if binary.op == datafusion::logical_expr::Operator::And => {
95 extract_equality_filters(&binary.left, coord_names, result);
96 extract_equality_filters(&binary.right, coord_names, result);
97 }
98
99 Expr::BinaryExpr(binary) if binary.op == datafusion::logical_expr::Operator::Eq => {
101 if let Some((col_name, value)) = extract_column_literal_eq(&binary.left, &binary.right)
102 {
103 if coord_names.contains(&col_name) {
104 debug!(
105 coord = %col_name,
106 value = %value,
107 "Found coordinate equality filter"
108 );
109 result.filters.insert(col_name, value);
110 } else {
111 trace!(
112 column = %col_name,
113 "Equality filter on non-coordinate column, skipping"
114 );
115 }
116 }
117 }
118
119 Expr::Cast(cast) => {
121 extract_equality_filters(&cast.expr, coord_names, result);
122 }
123
124 other => {
126 trace!(expr_type = %other.variant_name(), "Skipping non-equality filter expression");
127 }
128 }
129}
130
131fn extract_column_literal_eq(left: &Expr, right: &Expr) -> Option<(String, ScalarValue)> {
138 if let (Some(col_name), Some(value)) = (extract_column_name(left), extract_literal(right)) {
140 return Some((col_name, value));
141 }
142
143 if let (Some(value), Some(col_name)) = (extract_literal(left), extract_column_name(right)) {
145 return Some((col_name, value));
146 }
147
148 None
149}
150
151fn extract_column_name(expr: &Expr) -> Option<String> {
153 match expr {
154 Expr::Column(col) => Some(col.name.clone()),
155 Expr::Cast(cast) => extract_column_name(&cast.expr),
156 Expr::TryCast(cast) => extract_column_name(&cast.expr),
157 _ => None,
158 }
159}
160
161fn extract_literal(expr: &Expr) -> Option<ScalarValue> {
166 match expr {
167 Expr::Literal(value, _) => Some(unwrap_dictionary_value(value.clone())),
168 Expr::Cast(cast) => {
169 if let Expr::Literal(value, _) = cast.expr.as_ref() {
171 value
173 .cast_to(&cast.data_type)
174 .ok()
175 .map(unwrap_dictionary_value)
176 } else {
177 None
178 }
179 }
180 _ => None,
181 }
182}
183
184fn unwrap_dictionary_value(value: ScalarValue) -> ScalarValue {
189 match value {
190 ScalarValue::Dictionary(_, inner) => unwrap_dictionary_value(*inner),
191 other => other,
192 }
193}
194
195pub fn calculate_coord_ranges(
204 filters: &CoordFilters,
205 coord_names: &[String],
206 coord_values: &[CoordValuesRef<'_>],
207) -> Option<Vec<(usize, usize)>> {
208 let mut ranges = Vec::with_capacity(coord_names.len());
209
210 for (i, name) in coord_names.iter().enumerate() {
211 let values = &coord_values[i];
212 let range = if let Some(filter_value) = filters.get(name) {
213 if let Some(idx) = find_value_index(values, filter_value) {
215 debug!(
216 coord = %name,
217 filter_value = %filter_value,
218 index = idx,
219 "Found filter value at index"
220 );
221 (idx, idx + 1) } else {
223 warn!(
224 coord = %name,
225 filter_value = %filter_value,
226 "Filter value not found in coordinate - query will return no results"
227 );
228 return None; }
230 } else {
231 (0, values.len())
233 };
234 ranges.push(range);
235 }
236
237 Some(ranges)
238}
239
/// Borrowed view over the values of one coordinate, tagged by element type.
///
/// The enum only holds a slice reference, so it is cheap to copy and can be
/// printed or compared directly.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CoordValuesRef<'a> {
    /// Coordinate stored as 64-bit signed integers.
    Int64(&'a [i64]),
    /// Coordinate stored as 32-bit floats.
    Float32(&'a [f32]),
    /// Coordinate stored as 64-bit floats.
    Float64(&'a [f64]),
}

impl CoordValuesRef<'_> {
    /// Number of values in the underlying slice.
    pub fn len(&self) -> usize {
        match self {
            CoordValuesRef::Int64(v) => v.len(),
            CoordValuesRef::Float32(v) => v.len(),
            CoordValuesRef::Float64(v) => v.len(),
        }
    }

    /// Returns `true` when the underlying slice holds no values.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
260
261fn find_value_index(values: &CoordValuesRef<'_>, target: &ScalarValue) -> Option<usize> {
263 match (values, target) {
264 (CoordValuesRef::Int64(vals), ScalarValue::Int64(Some(v))) => {
265 vals.iter().position(|x| x == v)
266 }
267 (CoordValuesRef::Int64(vals), ScalarValue::Int32(Some(v))) => {
268 let v64 = *v as i64;
269 vals.iter().position(|x| *x == v64)
270 }
271 (CoordValuesRef::Float32(vals), ScalarValue::Float32(Some(v))) => {
272 vals.iter().position(|x| (x - v).abs() < f32::EPSILON)
273 }
274 (CoordValuesRef::Float32(vals), ScalarValue::Float64(Some(v))) => {
275 let v32 = *v as f32;
276 vals.iter().position(|x| (x - v32).abs() < f32::EPSILON)
277 }
278 (CoordValuesRef::Float64(vals), ScalarValue::Float64(Some(v))) => {
279 vals.iter().position(|x| (x - v).abs() < f64::EPSILON)
280 }
281 (CoordValuesRef::Float64(vals), ScalarValue::Float32(Some(v))) => {
282 let v64 = *v as f64;
283 vals.iter().position(|x| (x - v64).abs() < f64::EPSILON)
284 }
285 (CoordValuesRef::Float32(vals), ScalarValue::Int64(Some(v))) => {
287 let vf = *v as f32;
288 vals.iter().position(|x| (x - vf).abs() < f32::EPSILON)
289 }
290 (CoordValuesRef::Float64(vals), ScalarValue::Int64(Some(v))) => {
291 let vf = *v as f64;
292 vals.iter().position(|x| (x - vf).abs() < f64::EPSILON)
293 }
294 _ => {
295 debug!(
296 target_type = ?std::mem::discriminant(target),
297 "Unsupported filter value type for coordinate lookup"
298 );
299 None
300 }
301 }
302}
303
/// Number of rows selected by a set of per-coordinate `(start, end)` ranges,
/// i.e. the product of the range lengths.
///
/// An empty slice yields `1` (the empty product). Arithmetic saturates instead
/// of panicking or wrapping: an inverted range (`end < start`) counts as
/// length `0`, and a product that does not fit in `usize` is reported as
/// `usize::MAX`.
pub fn calculate_filtered_rows(coord_ranges: &[(usize, usize)]) -> usize {
    coord_ranges
        .iter()
        .map(|&(start, end)| end.saturating_sub(start))
        .fold(1usize, usize::saturating_mul)
}
311
/// Converts `(start, end)` index pairs into `Range<u64>` values, preserving
/// order and bounds.
pub fn coord_ranges_to_array_ranges(coord_ranges: &[(usize, usize)]) -> Vec<std::ops::Range<u64>> {
    let mut array_ranges = Vec::with_capacity(coord_ranges.len());
    for &(start, end) in coord_ranges {
        array_ranges.push((start as u64)..(end as u64));
    }
    array_ranges
}
322
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::prelude::*;

    /// Builds an owned list of coordinate names from string literals.
    fn names(raw: &[&str]) -> Vec<String> {
        raw.iter().map(|name| name.to_string()).collect()
    }

    /// A single `coord = literal` predicate is picked up.
    #[test]
    fn test_parse_simple_equality() {
        let coord_names = names(&["time", "lat"]);

        let parsed = parse_coord_filters(&[col("time").eq(lit(100i64))], &coord_names);

        assert_eq!(parsed.len(), 1);
        assert!(parsed.get("time").is_some());
    }

    /// Both sides of an `AND` contribute a filter.
    #[test]
    fn test_parse_and_filters() {
        let coord_names = names(&["time", "hybrid", "lat"]);

        let on_time = col("time").eq(lit(100i64));
        let on_hybrid = col("hybrid").eq(lit(50i64));
        let parsed = parse_coord_filters(&[on_time.and(on_hybrid)], &coord_names);

        assert_eq!(parsed.len(), 2);
        assert!(parsed.get("time").is_some());
        assert!(parsed.get("hybrid").is_some());
    }

    /// Equalities on columns that are not coordinates are ignored.
    #[test]
    fn test_ignore_non_coord_columns() {
        let coord_names = names(&["time"]);

        let parsed = parse_coord_filters(&[col("temperature").eq(lit(20i64))], &coord_names);

        assert!(parsed.is_empty());
    }

    /// Present values resolve to their index; absent values to `None`.
    #[test]
    fn test_find_value_index() {
        let vals = [10i64, 20, 30, 40, 50];
        let values_ref = CoordValuesRef::Int64(&vals);

        let present = find_value_index(&values_ref, &ScalarValue::Int64(Some(30)));
        assert_eq!(present, Some(2));

        let absent = find_value_index(&values_ref, &ScalarValue::Int64(Some(100)));
        assert_eq!(absent, None);
    }

    /// Row count is the product of the range lengths.
    #[test]
    fn test_calculate_filtered_rows() {
        let ranges = [(5, 6), (10, 11), (0, 721), (0, 1440)];

        assert_eq!(calculate_filtered_rows(&ranges), 721 * 1440);
    }
}