zarr_datafusion/optimizer/
minmax_optimization.rs

//! Custom optimizer rule to replace MIN/MAX aggregates with constants from statistics
//!
//! This rule transforms:
//!   `SELECT MIN(coord) FROM table` → `SELECT 0` (from statistics)
//!   `SELECT MAX(coord) FROM table` → `SELECT 90` (from statistics)
//!   `SELECT MIN(lat), MAX(lat), MIN(lon) FROM table` → `SELECT 0, 90, 0`
//!
//! The optimization is only applied when:
//! 1. All aggregate expressions are MIN() or MAX() functions
//! 2. No GROUP BY clause
//! 3. No WHERE filters
//! 4. Statistics provide exact min_value/max_value for the column
13
14use datafusion::common::stats::Precision;
15use datafusion::common::tree_node::Transformed;
16use datafusion::common::{Result, ScalarValue};
17use datafusion::datasource::source_as_provider;
18use datafusion::logical_expr::expr::AggregateFunction;
19use datafusion::logical_expr::{EmptyRelation, Expr, LogicalPlan, Projection, TableScan};
20use datafusion::optimizer::optimizer::ApplyOrder;
21use datafusion::optimizer::OptimizerRule;
22use datafusion::prelude::lit;
23use std::sync::Arc;
24use tracing::{debug, info, trace, warn};
25
/// Logical optimizer rule that answers `MIN()` / `MAX()` aggregates directly
/// from table statistics instead of scanning the data.
///
/// The rule carries no state, so every instance behaves the same.
#[derive(Default, Debug)]
pub struct MinMaxStatisticsRule;

impl MinMaxStatisticsRule {
    /// Build a new rule instance; equivalent to [`Default::default`].
    pub fn new() -> Self {
        Self::default()
    }
}
35
/// Which bound of a column an aggregate expression asks for.
#[derive(Clone, Copy, Debug)]
enum MinMaxType {
    /// The aggregate is `MIN(...)`; answered from the column's `min_value` statistic.
    Min,
    /// The aggregate is `MAX(...)`; answered from the column's `max_value` statistic.
    Max,
}
42
43impl OptimizerRule for MinMaxStatisticsRule {
44    fn name(&self) -> &str {
45        "minmax_statistics"
46    }
47
48    fn apply_order(&self) -> Option<ApplyOrder> {
49        Some(ApplyOrder::BottomUp)
50    }
51
52    fn supports_rewrite(&self) -> bool {
53        true
54    }
55
56    fn rewrite(
57        &self,
58        plan: LogicalPlan,
59        _config: &dyn datafusion::optimizer::OptimizerConfig,
60    ) -> Result<Transformed<LogicalPlan>> {
61        // Only optimize Aggregate nodes
62        let LogicalPlan::Aggregate(aggregate) = &plan else {
63            trace!("Skipping non-Aggregate node");
64            return Ok(Transformed::no(plan));
65        };
66
67        debug!(
68            group_by_count = aggregate.group_expr.len(),
69            aggr_count = aggregate.aggr_expr.len(),
70            "Evaluating Aggregate node for MIN/MAX optimization"
71        );
72
73        // Must have no GROUP BY (simple aggregate)
74        if !aggregate.group_expr.is_empty() {
75            debug!(
76                group_by_count = aggregate.group_expr.len(),
77                "Skipping: has GROUP BY clause"
78            );
79            return Ok(Transformed::no(plan));
80        }
81
82        // Must have at least one aggregate expression
83        if aggregate.aggr_expr.is_empty() {
84            debug!("Skipping: no aggregate expressions");
85            return Ok(Transformed::no(plan));
86        }
87
88        // Input must be a simple TableScan with no filters
89        let input = aggregate.input.as_ref();
90        let table_scan = match unwrap_to_table_scan(input) {
91            Some(scan) => scan,
92            None => {
93                debug!("Skipping: could not find TableScan in input");
94                return Ok(Transformed::no(plan));
95            }
96        };
97
98        let table_name = table_scan.table_name.to_string();
99        debug!(table = %table_name, "Found TableScan");
100
101        // Must have no filters
102        if !table_scan.filters.is_empty() {
103            debug!(
104                table = %table_name,
105                filter_count = table_scan.filters.len(),
106                "Skipping: query has WHERE filters"
107            );
108            return Ok(Transformed::no(plan));
109        }
110
111        // Get TableProvider from TableSource to access statistics
112        let provider = match source_as_provider(&table_scan.source) {
113            Ok(p) => p,
114            Err(e) => {
115                warn!(
116                    table = %table_name,
117                    error = %e,
118                    "Could not get TableProvider from TableSource"
119                );
120                return Ok(Transformed::no(plan));
121            }
122        };
123
124        // Get statistics from the table
125        let statistics = match provider.statistics() {
126            Some(stats) => stats,
127            None => {
128                debug!(
129                    table = %table_name,
130                    "Skipping: table does not provide statistics"
131                );
132                return Ok(Transformed::no(plan));
133            }
134        };
135
136        let schema = provider.schema();
137
138        // Process all aggregate expressions - all must be MIN() or MAX() functions
139        let mut minmax_values: Vec<(String, ScalarValue)> = Vec::new();
140
141        for aggr_expr in &aggregate.aggr_expr {
142            match extract_minmax_info(aggr_expr, input) {
143                Some((minmax_type, column_name)) => {
144                    // Get the alias name for this expression
145                    let alias_name = get_expr_alias(aggr_expr, &aggregate.schema);
146
147                    // Find column index to get statistics
148                    let col_idx = schema
149                        .fields()
150                        .iter()
151                        .position(|f| f.name() == &column_name);
152
153                    if let Some(idx) = col_idx {
154                        let col_stats = &statistics.column_statistics;
155                        if idx < col_stats.len() {
156                            let value = match minmax_type {
157                                MinMaxType::Min => match &col_stats[idx].min_value {
158                                    Precision::Exact(v) => {
159                                        debug!(
160                                            alias = %alias_name,
161                                            column = %column_name,
162                                            value = %v,
163                                            "Found exact MIN value from statistics"
164                                        );
165                                        v.clone()
166                                    }
167                                    Precision::Inexact(_) => {
168                                        debug!(
169                                            column = %column_name,
170                                            "Skipping: min value is inexact"
171                                        );
172                                        return Ok(Transformed::no(plan));
173                                    }
174                                    Precision::Absent => {
175                                        debug!(
176                                            column = %column_name,
177                                            "Skipping: min value not available"
178                                        );
179                                        return Ok(Transformed::no(plan));
180                                    }
181                                },
182                                MinMaxType::Max => match &col_stats[idx].max_value {
183                                    Precision::Exact(v) => {
184                                        debug!(
185                                            alias = %alias_name,
186                                            column = %column_name,
187                                            value = %v,
188                                            "Found exact MAX value from statistics"
189                                        );
190                                        v.clone()
191                                    }
192                                    Precision::Inexact(_) => {
193                                        debug!(
194                                            column = %column_name,
195                                            "Skipping: max value is inexact"
196                                        );
197                                        return Ok(Transformed::no(plan));
198                                    }
199                                    Precision::Absent => {
200                                        debug!(
201                                            column = %column_name,
202                                            "Skipping: max value not available"
203                                        );
204                                        return Ok(Transformed::no(plan));
205                                    }
206                                },
207                            };
208
209                            minmax_values.push((alias_name, value));
210                        } else {
211                            debug!(
212                                column = %column_name,
213                                "Skipping: column statistics not available"
214                            );
215                            return Ok(Transformed::no(plan));
216                        }
217                    } else {
218                        debug!(column = %column_name, "Skipping: column not found in schema");
219                        return Ok(Transformed::no(plan));
220                    }
221                }
222                None => {
223                    // Not a MIN/MAX function - can't optimize this aggregate
224                    debug!("Skipping: found non-MIN/MAX aggregate function");
225                    return Ok(Transformed::no(plan));
226                }
227            }
228        }
229
230        info!(
231            table = %table_name,
232            minmax_expressions = minmax_values.len(),
233            values = ?minmax_values,
234            "MIN/MAX optimization applied - replacing with constants"
235        );
236
237        create_minmax_plan(&minmax_values)
238    }
239}
240
241/// Extract MIN/MAX function info from an expression
242/// Returns Some((MinMaxType, column_name)) if it's a MIN or MAX function, None otherwise
243fn extract_minmax_info(expr: &Expr, input_plan: &LogicalPlan) -> Option<(MinMaxType, String)> {
244    let inner = unwrap_alias(expr);
245    match inner {
246        Expr::AggregateFunction(AggregateFunction { func, params }) => {
247            let func_name = func.name().to_lowercase();
248            let minmax_type = match func_name.as_str() {
249                "min" => MinMaxType::Min,
250                "max" => MinMaxType::Max,
251                _ => {
252                    trace!(func_name = %func_name, "Not a MIN/MAX function");
253                    return None;
254                }
255            };
256
257            // Extract column name from arguments
258            // Need to handle: Column, Alias(Column), Cast(Column), Alias(Cast(Column))
259            // Also handle __common_expr_N which needs to be traced back
260            let arg = params.args.first()?;
261            let inner_arg = unwrap_alias(arg);
262
263            match inner_arg {
264                Expr::Column(col) => {
265                    // Check if this is a __common_expr reference that needs tracing
266                    if col.name.starts_with("__common_expr") {
267                        trace_column_in_plan(&col.name, input_plan, minmax_type)
268                    } else {
269                        Some((minmax_type, col.name.clone()))
270                    }
271                }
272                Expr::Cast(cast) => {
273                    // Dictionary columns get cast to their value type
274                    if let Expr::Column(col) = unwrap_alias(cast.expr.as_ref()) {
275                        return Some((minmax_type, col.name.clone()));
276                    }
277                    debug!(cast_expr_type = %cast.expr.variant_name(), "MIN/MAX cast argument is not a column");
278                    None
279                }
280                other => {
281                    debug!(expr_type = %other.variant_name(), "MIN/MAX argument is not a column or cast");
282                    None
283                }
284            }
285        }
286        _ => {
287            trace!("Not an AggregateFunction expression");
288            None
289        }
290    }
291}
292
293/// Trace a __common_expr column back to its original column through projections
294fn trace_column_in_plan(
295    expr_name: &str,
296    plan: &LogicalPlan,
297    minmax_type: MinMaxType,
298) -> Option<(MinMaxType, String)> {
299    match plan {
300        LogicalPlan::Projection(proj) => {
301            // Find the expression that defines this common expr
302            for proj_expr in &proj.expr {
303                if let Expr::Alias(alias) = proj_expr {
304                    if alias.name == expr_name {
305                        // Found the definition, extract column from the inner expression
306                        let inner = unwrap_alias(&alias.expr);
307                        if let Expr::Cast(cast) = inner {
308                            if let Expr::Column(col) = unwrap_alias(cast.expr.as_ref()) {
309                                debug!(
310                                    common_expr = %expr_name,
311                                    original_col = %col.name,
312                                    "Traced common expression to original column"
313                                );
314                                return Some((minmax_type, col.name.clone()));
315                            }
316                        } else if let Expr::Column(col) = inner {
317                            debug!(
318                                common_expr = %expr_name,
319                                original_col = %col.name,
320                                "Traced common expression to original column"
321                            );
322                            return Some((minmax_type, col.name.clone()));
323                        }
324                    }
325                }
326            }
327            // Try in input plan
328            trace_column_in_plan(expr_name, &proj.input, minmax_type)
329        }
330        LogicalPlan::TableScan(_) => None,
331        LogicalPlan::SubqueryAlias(alias) => {
332            trace_column_in_plan(expr_name, &alias.input, minmax_type)
333        }
334        other => {
335            // Try to trace through other plan types
336            for input in other.inputs() {
337                if let Some(result) = trace_column_in_plan(expr_name, input, minmax_type) {
338                    return Some(result);
339                }
340            }
341            None
342        }
343    }
344}
345
346/// Get the alias name for an expression
347fn get_expr_alias(expr: &Expr, _schema: &Arc<datafusion::common::DFSchema>) -> String {
348    // Try to get alias from the expression itself
349    if let Expr::Alias(alias) = expr {
350        return alias.name.clone();
351    }
352
353    // Use the expression's schema name (e.g., "min(lat)" or "max(lon)")
354    expr.schema_name().to_string()
355}
356
357/// Unwrap Alias expressions to get the inner expression
358fn unwrap_alias(expr: &Expr) -> &Expr {
359    match expr {
360        Expr::Alias(alias) => unwrap_alias(&alias.expr),
361        other => other,
362    }
363}
364
365/// Unwrap projections to find the underlying TableScan
366/// Returns None if there's a Filter (min/max would change after filtering)
367fn unwrap_to_table_scan(plan: &LogicalPlan) -> Option<&TableScan> {
368    match plan {
369        LogicalPlan::TableScan(scan) => Some(scan),
370        LogicalPlan::Projection(Projection { input, .. }) => unwrap_to_table_scan(input),
371        LogicalPlan::SubqueryAlias(alias) => unwrap_to_table_scan(&alias.input),
372        LogicalPlan::Filter(_) => {
373            // Filter changes the min/max - can't optimize
374            debug!("Found Filter node - min/max would change after filtering");
375            None
376        }
377        other => {
378            debug!(node_type = %other.display(), "Unsupported node type in input");
379            None
380        }
381    }
382}
383
384/// Create a new plan that returns MIN/MAX values as constants
385fn create_minmax_plan(minmax_values: &[(String, ScalarValue)]) -> Result<Transformed<LogicalPlan>> {
386    // Create literal expressions for each value with their aliases
387    let exprs: Vec<Expr> = minmax_values
388        .iter()
389        .map(|(alias, value)| lit(value.clone()).alias(alias))
390        .collect();
391
392    // Create an empty relation (no input needed)
393    let empty = LogicalPlan::EmptyRelation(EmptyRelation {
394        produce_one_row: true,
395        schema: Arc::new(arrow::datatypes::Schema::empty().try_into()?),
396    });
397
398    // Create projection with all the constants
399    let projection = LogicalPlan::Projection(Projection::try_new(exprs, Arc::new(empty))?);
400
401    Ok(Transformed::yes(projection))
402}