zarr_datafusion/optimizer/
minmax_optimization.rs1use datafusion::common::stats::Precision;
15use datafusion::common::tree_node::Transformed;
16use datafusion::common::{Result, ScalarValue};
17use datafusion::datasource::source_as_provider;
18use datafusion::logical_expr::expr::AggregateFunction;
19use datafusion::logical_expr::{EmptyRelation, Expr, LogicalPlan, Projection, TableScan};
20use datafusion::optimizer::optimizer::ApplyOrder;
21use datafusion::optimizer::OptimizerRule;
22use datafusion::prelude::lit;
23use std::sync::Arc;
24use tracing::{debug, info, trace, warn};
25
/// Logical optimizer rule that answers ungrouped `MIN`/`MAX` aggregates
/// from table statistics instead of scanning the table.
///
/// The rule carries no state; all decisions are made from the plan it is
/// handed in `rewrite`.
#[derive(Default, Debug)]
pub struct MinMaxStatisticsRule;

impl MinMaxStatisticsRule {
    /// Creates the rule. Equivalent to `MinMaxStatisticsRule::default()`.
    pub fn new() -> Self {
        Self::default()
    }
}
35
/// Which extreme of a column an aggregate expression asks for.
#[derive(Clone, Copy, Debug)]
enum MinMaxType {
    /// The aggregate is `MIN(..)`.
    Min,
    /// The aggregate is `MAX(..)`.
    Max,
}
42
43impl OptimizerRule for MinMaxStatisticsRule {
44 fn name(&self) -> &str {
45 "minmax_statistics"
46 }
47
48 fn apply_order(&self) -> Option<ApplyOrder> {
49 Some(ApplyOrder::BottomUp)
50 }
51
52 fn supports_rewrite(&self) -> bool {
53 true
54 }
55
56 fn rewrite(
57 &self,
58 plan: LogicalPlan,
59 _config: &dyn datafusion::optimizer::OptimizerConfig,
60 ) -> Result<Transformed<LogicalPlan>> {
61 let LogicalPlan::Aggregate(aggregate) = &plan else {
63 trace!("Skipping non-Aggregate node");
64 return Ok(Transformed::no(plan));
65 };
66
67 debug!(
68 group_by_count = aggregate.group_expr.len(),
69 aggr_count = aggregate.aggr_expr.len(),
70 "Evaluating Aggregate node for MIN/MAX optimization"
71 );
72
73 if !aggregate.group_expr.is_empty() {
75 debug!(
76 group_by_count = aggregate.group_expr.len(),
77 "Skipping: has GROUP BY clause"
78 );
79 return Ok(Transformed::no(plan));
80 }
81
82 if aggregate.aggr_expr.is_empty() {
84 debug!("Skipping: no aggregate expressions");
85 return Ok(Transformed::no(plan));
86 }
87
88 let input = aggregate.input.as_ref();
90 let table_scan = match unwrap_to_table_scan(input) {
91 Some(scan) => scan,
92 None => {
93 debug!("Skipping: could not find TableScan in input");
94 return Ok(Transformed::no(plan));
95 }
96 };
97
98 let table_name = table_scan.table_name.to_string();
99 debug!(table = %table_name, "Found TableScan");
100
101 if !table_scan.filters.is_empty() {
103 debug!(
104 table = %table_name,
105 filter_count = table_scan.filters.len(),
106 "Skipping: query has WHERE filters"
107 );
108 return Ok(Transformed::no(plan));
109 }
110
111 let provider = match source_as_provider(&table_scan.source) {
113 Ok(p) => p,
114 Err(e) => {
115 warn!(
116 table = %table_name,
117 error = %e,
118 "Could not get TableProvider from TableSource"
119 );
120 return Ok(Transformed::no(plan));
121 }
122 };
123
124 let statistics = match provider.statistics() {
126 Some(stats) => stats,
127 None => {
128 debug!(
129 table = %table_name,
130 "Skipping: table does not provide statistics"
131 );
132 return Ok(Transformed::no(plan));
133 }
134 };
135
136 let schema = provider.schema();
137
138 let mut minmax_values: Vec<(String, ScalarValue)> = Vec::new();
140
141 for aggr_expr in &aggregate.aggr_expr {
142 match extract_minmax_info(aggr_expr, input) {
143 Some((minmax_type, column_name)) => {
144 let alias_name = get_expr_alias(aggr_expr, &aggregate.schema);
146
147 let col_idx = schema
149 .fields()
150 .iter()
151 .position(|f| f.name() == &column_name);
152
153 if let Some(idx) = col_idx {
154 let col_stats = &statistics.column_statistics;
155 if idx < col_stats.len() {
156 let value = match minmax_type {
157 MinMaxType::Min => match &col_stats[idx].min_value {
158 Precision::Exact(v) => {
159 debug!(
160 alias = %alias_name,
161 column = %column_name,
162 value = %v,
163 "Found exact MIN value from statistics"
164 );
165 v.clone()
166 }
167 Precision::Inexact(_) => {
168 debug!(
169 column = %column_name,
170 "Skipping: min value is inexact"
171 );
172 return Ok(Transformed::no(plan));
173 }
174 Precision::Absent => {
175 debug!(
176 column = %column_name,
177 "Skipping: min value not available"
178 );
179 return Ok(Transformed::no(plan));
180 }
181 },
182 MinMaxType::Max => match &col_stats[idx].max_value {
183 Precision::Exact(v) => {
184 debug!(
185 alias = %alias_name,
186 column = %column_name,
187 value = %v,
188 "Found exact MAX value from statistics"
189 );
190 v.clone()
191 }
192 Precision::Inexact(_) => {
193 debug!(
194 column = %column_name,
195 "Skipping: max value is inexact"
196 );
197 return Ok(Transformed::no(plan));
198 }
199 Precision::Absent => {
200 debug!(
201 column = %column_name,
202 "Skipping: max value not available"
203 );
204 return Ok(Transformed::no(plan));
205 }
206 },
207 };
208
209 minmax_values.push((alias_name, value));
210 } else {
211 debug!(
212 column = %column_name,
213 "Skipping: column statistics not available"
214 );
215 return Ok(Transformed::no(plan));
216 }
217 } else {
218 debug!(column = %column_name, "Skipping: column not found in schema");
219 return Ok(Transformed::no(plan));
220 }
221 }
222 None => {
223 debug!("Skipping: found non-MIN/MAX aggregate function");
225 return Ok(Transformed::no(plan));
226 }
227 }
228 }
229
230 info!(
231 table = %table_name,
232 minmax_expressions = minmax_values.len(),
233 values = ?minmax_values,
234 "MIN/MAX optimization applied - replacing with constants"
235 );
236
237 create_minmax_plan(&minmax_values)
238 }
239}
240
241fn extract_minmax_info(expr: &Expr, input_plan: &LogicalPlan) -> Option<(MinMaxType, String)> {
244 let inner = unwrap_alias(expr);
245 match inner {
246 Expr::AggregateFunction(AggregateFunction { func, params }) => {
247 let func_name = func.name().to_lowercase();
248 let minmax_type = match func_name.as_str() {
249 "min" => MinMaxType::Min,
250 "max" => MinMaxType::Max,
251 _ => {
252 trace!(func_name = %func_name, "Not a MIN/MAX function");
253 return None;
254 }
255 };
256
257 let arg = params.args.first()?;
261 let inner_arg = unwrap_alias(arg);
262
263 match inner_arg {
264 Expr::Column(col) => {
265 if col.name.starts_with("__common_expr") {
267 trace_column_in_plan(&col.name, input_plan, minmax_type)
268 } else {
269 Some((minmax_type, col.name.clone()))
270 }
271 }
272 Expr::Cast(cast) => {
273 if let Expr::Column(col) = unwrap_alias(cast.expr.as_ref()) {
275 return Some((minmax_type, col.name.clone()));
276 }
277 debug!(cast_expr_type = %cast.expr.variant_name(), "MIN/MAX cast argument is not a column");
278 None
279 }
280 other => {
281 debug!(expr_type = %other.variant_name(), "MIN/MAX argument is not a column or cast");
282 None
283 }
284 }
285 }
286 _ => {
287 trace!("Not an AggregateFunction expression");
288 None
289 }
290 }
291}
292
293fn trace_column_in_plan(
295 expr_name: &str,
296 plan: &LogicalPlan,
297 minmax_type: MinMaxType,
298) -> Option<(MinMaxType, String)> {
299 match plan {
300 LogicalPlan::Projection(proj) => {
301 for proj_expr in &proj.expr {
303 if let Expr::Alias(alias) = proj_expr {
304 if alias.name == expr_name {
305 let inner = unwrap_alias(&alias.expr);
307 if let Expr::Cast(cast) = inner {
308 if let Expr::Column(col) = unwrap_alias(cast.expr.as_ref()) {
309 debug!(
310 common_expr = %expr_name,
311 original_col = %col.name,
312 "Traced common expression to original column"
313 );
314 return Some((minmax_type, col.name.clone()));
315 }
316 } else if let Expr::Column(col) = inner {
317 debug!(
318 common_expr = %expr_name,
319 original_col = %col.name,
320 "Traced common expression to original column"
321 );
322 return Some((minmax_type, col.name.clone()));
323 }
324 }
325 }
326 }
327 trace_column_in_plan(expr_name, &proj.input, minmax_type)
329 }
330 LogicalPlan::TableScan(_) => None,
331 LogicalPlan::SubqueryAlias(alias) => {
332 trace_column_in_plan(expr_name, &alias.input, minmax_type)
333 }
334 other => {
335 for input in other.inputs() {
337 if let Some(result) = trace_column_in_plan(expr_name, input, minmax_type) {
338 return Some(result);
339 }
340 }
341 None
342 }
343 }
344}
345
346fn get_expr_alias(expr: &Expr, _schema: &Arc<datafusion::common::DFSchema>) -> String {
348 if let Expr::Alias(alias) = expr {
350 return alias.name.clone();
351 }
352
353 expr.schema_name().to_string()
355}
356
357fn unwrap_alias(expr: &Expr) -> &Expr {
359 match expr {
360 Expr::Alias(alias) => unwrap_alias(&alias.expr),
361 other => other,
362 }
363}
364
365fn unwrap_to_table_scan(plan: &LogicalPlan) -> Option<&TableScan> {
368 match plan {
369 LogicalPlan::TableScan(scan) => Some(scan),
370 LogicalPlan::Projection(Projection { input, .. }) => unwrap_to_table_scan(input),
371 LogicalPlan::SubqueryAlias(alias) => unwrap_to_table_scan(&alias.input),
372 LogicalPlan::Filter(_) => {
373 debug!("Found Filter node - min/max would change after filtering");
375 None
376 }
377 other => {
378 debug!(node_type = %other.display(), "Unsupported node type in input");
379 None
380 }
381 }
382}
383
384fn create_minmax_plan(minmax_values: &[(String, ScalarValue)]) -> Result<Transformed<LogicalPlan>> {
386 let exprs: Vec<Expr> = minmax_values
388 .iter()
389 .map(|(alias, value)| lit(value.clone()).alias(alias))
390 .collect();
391
392 let empty = LogicalPlan::EmptyRelation(EmptyRelation {
394 produce_one_row: true,
395 schema: Arc::new(arrow::datatypes::Schema::empty().try_into()?),
396 });
397
398 let projection = LogicalPlan::Projection(Projection::try_new(exprs, Arc::new(empty))?);
400
401 Ok(Transformed::yes(projection))
402}